Skip to content

Commit

Permalink
Merge pull request PaddlePaddle#26 from DesmonDay/gpu_graph_engine2
Browse files Browse the repository at this point in the history
Add actual neighbor sample result
  • Loading branch information
seemingwang committed May 1, 2022
2 parents acb8ac0 + 7762561 commit 7cfe661
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 2 deletions.
19 changes: 17 additions & 2 deletions paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -139,9 +139,12 @@ struct NeighborSampleQuery {
};
struct NeighborSampleResult {
int64_t *val;
int64_t *actual_val;
int *actual_sample_size, sample_size, key_size;
std::shared_ptr<memory::Allocation> val_mem, actual_sample_size_mem;
std::shared_ptr<memory::Allocation> actual_val_mem;
int64_t *get_val() { return val; }
int64_t *get_actual_val() { return actual_val; }
int *get_actual_sample_size() { return actual_sample_size; }
int get_sample_size() { return sample_size; }
int get_key_size() { return key_size; }
Expand All @@ -165,18 +168,30 @@ struct NeighborSampleResult {
int *ac_size = new int[key_size];
cudaMemcpy(ac_size, actual_sample_size, key_size * sizeof(int),
cudaMemcpyDeviceToHost); // 3, 1, 3
int total_sample_size = 0;
for (int i = 0; i < key_size; i++) {
total_sample_size += ac_size[i];
}
int64_t *res2 = new int64_t[total_sample_size];
cudaMemcpy(res2, actual_val, total_sample_size * sizeof(int64_t),
cudaMemcpyDeviceToHost);

int start = 0;
for (int i = 0; i < key_size; i++) {
VLOG(0) << "actual sample size for " << i << "th key is " << ac_size[i];
VLOG(0) << "sampled neighbors are ";
std::string neighbor;
std::string neighbor, neighbor2;
for (int j = 0; j < ac_size[i]; j++) {
if (neighbor.size() > 0) neighbor += ";";
if (neighbor2.size() > 0) neighbor2 += ";";
neighbor += std::to_string(res[i * sample_size + j]);
neighbor2 += std::to_string(res2[start + j]);
}
VLOG(0) << neighbor;
VLOG(0) << neighbor << " " << neighbor2;
start += ac_size[i];
}
delete[] res;
delete[] res2;
delete[] ac_size;
VLOG(0) << " ------------------";
}
Expand Down
30 changes: 30 additions & 0 deletions paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
// limitations under the License.

#include <thrust/device_vector.h>
#include <thrust/reduce.h>
#include <thrust/scan.h>
#include <functional>
#pragma once
#ifdef PADDLE_WITH_HETERPS
Expand Down Expand Up @@ -374,6 +376,18 @@ __global__ void fill_dvalues(int64_t* d_shard_vals, int64_t* d_vals,
}
}

__global__ void fill_actual_vals(int64_t* vals, int64_t* actual_vals,
int* actual_sample_size,
int* cumsum_actual_sample_size,
int sample_size, int len) {
const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < len) {
for (int j = 0; j < actual_sample_size[i]; j++) {
actual_vals[cumsum_actual_sample_size[i] + j] = vals[sample_size * i + j];
}
}
}

__global__ void node_query_example(GpuPsCommGraph graph, int start, int size,
int64_t* res) {
const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
Expand Down Expand Up @@ -846,6 +860,22 @@ NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample_v2(
fill_dvalues<<<grid_size, block_size_, 0, stream>>>(
d_shard_vals_ptr, val, d_shard_actual_sample_size_ptr, actual_sample_size,
d_idx_ptr, sample_size, len);

thrust::device_ptr<int> t_actual_sample_size(actual_sample_size);
int total_sample_size =
thrust::reduce(t_actual_sample_size, t_actual_sample_size + len);
result.actual_val_mem =
memory::AllocShared(place, total_sample_size * sizeof(int64_t));
result.actual_val = (int64_t*)(result.actual_val_mem)->ptr();

thrust::device_vector<int> cumsum_actual_sample_size(len);
thrust::exclusive_scan(t_actual_sample_size, t_actual_sample_size + len,
cumsum_actual_sample_size.begin(), 0);
fill_actual_vals<<<grid_size, block_size_, 0, stream>>>(
val, result.actual_val, actual_sample_size,
thrust::raw_pointer_cast(cumsum_actual_sample_size.data()), sample_size,
len);

for (int i = 0; i < total_gpu; ++i) {
int shard_len = h_left[i] == -1 ? 0 : h_right[i] - h_left[i] + 1;
if (shard_len == 0) {
Expand Down

0 comments on commit 7cfe661

Please sign in to comment.