diff --git a/paddle/fluid/distributed/ps/table/common_graph_table.cc b/paddle/fluid/distributed/ps/table/common_graph_table.cc index 430d8ef3ed06a..8527c5031e70d 100644 --- a/paddle/fluid/distributed/ps/table/common_graph_table.cc +++ b/paddle/fluid/distributed/ps/table/common_graph_table.cc @@ -54,28 +54,36 @@ int32_t GraphTable::Load_to_ssd(const std::string &path, paddle::framework::GpuPsCommGraphFea GraphTable::make_gpu_ps_graph_fea( std::vector &node_ids, int slot_num) { std::vector> bags(task_pool_size_); + for (int i = 0; i < task_pool_size_; i++) { + auto predsize = node_ids.size() / task_pool_size_; + bags[i].reserve(predsize * 1.2); + } + for (auto x : node_ids) { int location = x % shard_num % task_pool_size_; bags[location].push_back(x); } + std::vector> tasks; std::vector feature_array[task_pool_size_]; std::vector slot_id_array[task_pool_size_]; - std::vector - node_fea_array[task_pool_size_]; + std::vector node_id_array[task_pool_size_]; + std::vector + node_fea_info_array[task_pool_size_]; for (size_t i = 0; i < bags.size(); i++) { if (bags[i].size() > 0) { tasks.push_back(_shards_task_pool[i]->enqueue([&, i, this]() -> int { - paddle::framework::GpuPsGraphFeaNode x; + uint64_t node_id; + paddle::framework::GpuPsFeaInfo x; std::vector feature_ids; for (size_t j = 0; j < bags[i].size(); j++) { // TODO use FEATURE_TABLE instead Node *v = find_node(1, bags[i][j]); - x.node_id = bags[i][j]; + node_id = bags[i][j]; if (v == NULL) { x.feature_size = 0; x.feature_offset = 0; - node_fea_array[i].push_back(x); + node_fea_info_array[i].push_back(x); } else { // x <- v x.feature_offset = feature_array[i].size(); @@ -91,8 +99,9 @@ paddle::framework::GpuPsCommGraphFea GraphTable::make_gpu_ps_graph_fea( } } x.feature_size = total_feature_size; - node_fea_array[i].push_back(x); + node_fea_info_array[i].push_back(x); } + node_id_array[i].push_back(node_id); } return 0; })); @@ -109,9 +118,10 @@ paddle::framework::GpuPsCommGraphFea GraphTable::make_gpu_ps_graph_fea( res.init_on_cpu(tot_len, (unsigned int)node_ids.size(), slot_num); unsigned int offset = 0, ind = 0; for (int i = 0; i < task_pool_size_; i++) { - for (int j = 0; j < (int)node_fea_array[i].size(); j++) { - res.node_list[ind] = node_fea_array[i][j]; - res.node_list[ind++].feature_offset += offset; + for (int j = 0; j < (int)node_id_array[i].size(); j++) { + res.node_list[ind] = node_id_array[i][j]; + res.fea_info_list[ind] = node_fea_info_array[i][j]; + res.fea_info_list[ind++].feature_offset += offset; } for (size_t j = 0; j < feature_array[i].size(); j++) { res.feature_list[offset + j] = feature_array[i][j]; @@ -125,49 +135,62 @@ paddle::framework::GpuPsCommGraphFea GraphTable::make_gpu_ps_graph_fea( paddle::framework::GpuPsCommGraph GraphTable::make_gpu_ps_graph( int idx, std::vector ids) { std::vector> bags(task_pool_size_); + for (int i = 0; i < task_pool_size_; i++) { + auto predsize = ids.size() / task_pool_size_; + bags[i].reserve(predsize * 1.2); + } for (auto x : ids) { int location = x % shard_num % task_pool_size_; bags[location].push_back(x); } + std::vector> tasks; - std::vector edge_array[task_pool_size_]; - std::vector node_array[task_pool_size_]; + std::vector node_array[task_pool_size_]; // node id list + std::vector info_array[task_pool_size_]; + std::vector edge_array[task_pool_size_]; // edge id list + for (size_t i = 0; i < bags.size(); i++) { if (bags[i].size() > 0) { tasks.push_back(_shards_task_pool[i]->enqueue([&, i, this]() -> int { - paddle::framework::GpuPsGraphNode x; + 
node_array[i].resize(bags[i].size()); + info_array[i].resize(bags[i].size()); + edge_array[i].reserve(bags[i].size()); + for (size_t j = 0; j < bags[i].size(); j++) { - Node *v = find_node(0, idx, bags[i][j]); - x.node_id = bags[i][j]; - if (v == NULL) { - x.neighbor_size = 0; - x.neighbor_offset = 0; - node_array[i].push_back(x); - } else { - x.neighbor_size = v->get_neighbor_size(); - x.neighbor_offset = edge_array[i].size(); - node_array[i].push_back(x); - for (size_t k = 0; k < (size_t)x.neighbor_size; k++) { - edge_array[i].push_back(v->get_neighbor_id(k)); + auto node_id = bags[i][j]; + node_array[i][j] = node_id; + Node *v = find_node(0, idx, node_id); + if (v != nullptr) { + info_array[i][j].neighbor_offset = edge_array[i].size(); + info_array[i][j].neighbor_size = v->get_neighbor_size(); + for (size_t k = 0; k < v->get_neighbor_size(); k++) { + edge_array[i].push_back(v->get_neighbor_id(k)); + } + } + else { + info_array[i][j].neighbor_offset = 0; + info_array[i][j].neighbor_size = 0; } - } } return 0; })); } } for (int i = 0; i < (int)tasks.size(); i++) tasks[i].get(); - paddle::framework::GpuPsCommGraph res; + int64_t tot_len = 0; for (int i = 0; i < task_pool_size_; i++) { tot_len += edge_array[i].size(); } + + paddle::framework::GpuPsCommGraph res; res.init_on_cpu(tot_len, ids.size()); int64_t offset = 0, ind = 0; for (int i = 0; i < task_pool_size_; i++) { for (int j = 0; j < (int)node_array[i].size(); j++) { res.node_list[ind] = node_array[i][j]; - res.node_list[ind++].neighbor_offset += offset; + res.node_info_list[ind] = info_array[i][j]; + res.node_info_list[ind++].neighbor_offset += offset; } for (size_t j = 0; j < edge_array[i].size(); j++) { res.neighbor_list[offset + j] = edge_array[i][j]; @@ -275,7 +298,7 @@ int64_t GraphTable::load_graph_to_memory_from_ssd(int idx, std::string str; if (_db->get(i, ch, sizeof(int) * 2 + sizeof(uint64_t), str) == 0) { count[i] += (int64_t)str.size(); - for (int j = 0; j < (int)str.size(); j += sizeof(uint64_t)) { + for (size_t j = 0; j < (int)str.size(); j += sizeof(uint64_t)) { uint64_t id = *(uint64_t *)(str.c_str() + j); add_comm_edge(idx, v, id); } @@ -345,7 +368,7 @@ void GraphTable::make_partitions(int idx, int64_t byte_size, int device_len) { score[i] = 0; } } - for (int j = 0; j < (int)value.size(); j += sizeof(uint64_t)) { + for (size_t j = 0; j < (int)value.size(); j += sizeof(uint64_t)) { uint64_t v = *((uint64_t *)(value.c_str() + j)); int index = -1; if (id_map.find(v) != id_map.end()) { @@ -438,7 +461,7 @@ void GraphTable::clear_graph(int idx) { } } int32_t GraphTable::load_next_partition(int idx) { - if (next_partition >= partitions[idx].size()) { + if (next_partition >= (int)partitions[idx].size()) { VLOG(0) << "partition iteration is done"; return -1; } @@ -500,7 +523,7 @@ int32_t GraphTable::dump_edges_to_ssd(int idx) { std::vector &v = shards[i]->get_bucket(); for (size_t j = 0; j < v.size(); j++) { std::vector s; - for (int k = 0; k < (int)v[j]->get_neighbor_size(); k++) { + for (size_t k = 0; k < (int)v[j]->get_neighbor_size(); k++) { s.push_back(v[j]->get_neighbor_id(k)); } cost += v[j]->get_neighbor_size() * sizeof(uint64_t); @@ -1794,7 +1817,7 @@ std::vector> GraphTable::get_all_id(int type_id, int idx, auto &search_shards = type_id == 0 ? 
edge_shards[idx] : feature_shards[idx]; std::vector>> tasks; VLOG(0) << "begin task, task_pool_size_[" << task_pool_size_ << "]"; - for (int i = 0; i < search_shards.size(); i++) { + for (size_t i = 0; i < search_shards.size(); i++) { tasks.push_back(_shards_task_pool[i % task_pool_size_]->enqueue( [&search_shards, i]() -> std::vector { return search_shards[i]->get_all_id(); diff --git a/paddle/fluid/framework/data_feed.h b/paddle/fluid/framework/data_feed.h index 40e72aff434a5..7833e9760c476 100644 --- a/paddle/fluid/framework/data_feed.h +++ b/paddle/fluid/framework/data_feed.h @@ -45,6 +45,7 @@ limitations under the License. */ #if defined(PADDLE_WITH_CUDA) #include "paddle/fluid/platform/cuda_device_guard.h" #include "paddle/fluid/platform/device/gpu/gpu_info.h" +#include "paddle/fluid/framework/fleet/heter_ps/gpu_graph_utils.h" #endif DECLARE_int32(record_pool_max_size); @@ -422,7 +423,6 @@ struct UsedSlotGpuType { }; #if defined(PADDLE_WITH_CUDA) && defined(PADDLE_WITH_HETERPS) -#define CUDA_CHECK(val) CHECK(val == gpuSuccess) template struct CudaBuffer { T* cu_buffer; diff --git a/paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h b/paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h index dcdca8944b142..1b996a9b9359b 100644 --- a/paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h +++ b/paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h @@ -23,52 +23,63 @@ #include "paddle/phi/core/enforce.h" namespace paddle { namespace framework { -struct GpuPsGraphNode { - uint64_t node_id; - int64_t neighbor_size, neighbor_offset; +struct GpuPsNodeInfo { + uint32_t neighbor_size, neighbor_offset; + GpuPsNodeInfo() : neighbor_size(0), neighbor_offset(0) {} // this node's neighbor is stored on [neighbor_offset,neighbor_offset + // neighbor_size) of int64_t *neighbor_list; }; struct GpuPsCommGraph { - uint64_t *neighbor_list; - GpuPsGraphNode *node_list; - int64_t neighbor_size, node_size; - // the size of neighbor array and graph_node_list array + uint64_t *node_list; //locate on both side + int64_t node_size; // the size of node_list + GpuPsNodeInfo *node_info_list; // only locate on host side + uint64_t *neighbor_list; //locate on both side + int64_t neighbor_size; //the size of neighbor_list GpuPsCommGraph() - : neighbor_list(NULL), node_list(NULL), neighbor_size(0), node_size(0) {} - GpuPsCommGraph(uint64_t *neighbor_list_, GpuPsGraphNode *node_list_, - int64_t neighbor_size_, int64_t node_size_) - : neighbor_list(neighbor_list_), - node_list(node_list_), - neighbor_size(neighbor_size_), - node_size(node_size_) {} - void init_on_cpu(int64_t neighbor_size, int64_t node_size) { - this->neighbor_size = neighbor_size; - this->node_size = node_size; - this->neighbor_list = new uint64_t[neighbor_size]; - this->node_list = new paddle::framework::GpuPsGraphNode[node_size]; + : node_list(nullptr), node_size(0), node_info_list(nullptr), + neighbor_list(nullptr), neighbor_size(0) {} + GpuPsCommGraph(uint64_t *node_list_, int64_t node_size_, + GpuPsNodeInfo *node_info_list_, + uint64_t *neighbor_list_, int64_t neighbor_size_) + : node_list(node_list_), node_size(node_size_), + node_info_list(node_info_list_), + neighbor_list(neighbor_list_), + neighbor_size(neighbor_size_) {} + void init_on_cpu(int64_t neighbor_size_, int64_t node_size_) { + if (node_size_ > 0) { + this->node_size = node_size_; + this->node_list = new uint64_t[node_size_]; + this->node_info_list = new paddle::framework::GpuPsNodeInfo[node_size_]; + } + if (neighbor_size_) { + this->neighbor_size = neighbor_size_; + 
this->neighbor_list = new uint64_t[neighbor_size_]; + } } void release_on_cpu() { - delete[] neighbor_list; - delete[] node_list; + +#define DEL_PTR_ARRAY(p) \ + if (p != nullptr) { \ + delete [] p; \ + p = nullptr; \ + } + DEL_PTR_ARRAY(node_list); + DEL_PTR_ARRAY(neighbor_list); + DEL_PTR_ARRAY(node_info_list); + node_size = 0; + neighbor_size = 0; } - void display_on_cpu() { + void display_on_cpu() const { VLOG(0) << "neighbor_size = " << neighbor_size; VLOG(0) << "node_size = " << node_size; for (int64_t i = 0; i < neighbor_size; i++) { VLOG(0) << "neighbor " << i << " " << neighbor_list[i]; } for (int64_t i = 0; i < node_size; i++) { - VLOG(0) << "node i " << node_list[i].node_id - << " neighbor_size = " << node_list[i].neighbor_size; - std::string str; - int offset = node_list[i].neighbor_offset; - for (int64_t j = 0; j < node_list[i].neighbor_size; j++) { - if (j > 0) str += ","; - str += std::to_string(neighbor_list[j + offset]); - } - VLOG(0) << str; + auto id = node_list[i]; + auto val = node_info_list[i]; + VLOG(0) << "node id " << id << "," << val.neighbor_offset << ":" << val.neighbor_size; } } }; @@ -108,19 +119,11 @@ node 9:[14,14] node 17:[15,15] ... by the above information, -we generate a node_list:GpuPsGraphNode *graph_node_list in GpuPsCommGraph -of size 9, -where node_list[i].id = u_id[i] -then we have: -node_list[0]-> node_id:0, neighbor_size:2, neighbor_offset:0 -node_list[1]-> node_id:5, neighbor_size:2, neighbor_offset:2 -node_list[2]-> node_id:1, neighbor_size:1, neighbor_offset:4 -node_list[3]-> node_id:2, neighbor_size:1, neighbor_offset:5 -node_list[4]-> node_id:7, neighbor_size:3, neighbor_offset:6 -node_list[5]-> node_id:3, neighbor_size:4, neighbor_offset:9 -node_list[6]-> node_id:8, neighbor_size:1, neighbor_offset:13 -node_list[7]-> node_id:9, neighbor_size:1, neighbor_offset:14 -node_list[8]-> node_id:17, neighbor_size:1, neighbor_offset:15 +we generate a node_list and node_info_list in GpuPsCommGraph, +node_list: [0,5,1,2,7,3,8,9,17] +node_info_list: [(2,0),(2,2),(1,4),(1,5),(3,6),(4,9),(1,13),(1,14),(1,15)] +Here, we design the data in this format to better +adapt to gpu and avoid to convert again. 
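+For example, node_list[1] = 5 together with node_info_list[1] = (2,2) means node 5 has 2 neighbors, stored at neighbor_list[2] and neighbor_list[3].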
*/ struct NeighborSampleQuery { int gpu_id; @@ -263,8 +266,6 @@ struct NodeQueryResult { platform::CUDAPlace place = platform::CUDAPlace(dev_id); val_mem = memory::AllocShared(place, query_size * sizeof(uint64_t)); val = (uint64_t *)val_mem->ptr(); - - // cudaMalloc((void **)&val, query_size * sizeof(int64_t)); actual_sample_size = 0; } void display() { @@ -290,31 +291,33 @@ struct NodeQueryResult { ~NodeQueryResult() {} }; // end of struct NodeQueryResult -struct GpuPsGraphFeaNode { - uint64_t node_id; - uint64_t feature_size, feature_offset; +struct GpuPsFeaInfo { + uint32_t feature_size, feature_offset; // this node's feature is stored on [feature_offset,feature_offset + // feature_size) of int64_t *feature_list; }; struct GpuPsCommGraphFea { - uint64_t *feature_list; - uint8_t *slot_id_list; - GpuPsGraphFeaNode *node_list; + uint64_t *node_list; // only locate on host side, the list of node id + uint64_t *feature_list; //locate on both side + uint8_t *slot_id_list; //locate on both side + GpuPsFeaInfo *fea_info_list;// only locate on host side, the list of fea_info uint64_t feature_size, node_size; // the size of feature array and graph_node_list array GpuPsCommGraphFea() - : feature_list(NULL), + : node_list(NULL), + feature_list(NULL), slot_id_list(NULL), - node_list(NULL), + fea_info_list(NULL), feature_size(0), node_size(0) {} - GpuPsCommGraphFea(uint64_t *feature_list_, uint8_t *slot_id_list_, - GpuPsGraphFeaNode *node_list_, uint64_t feature_size_, + GpuPsCommGraphFea(uint64_t *node_list_, uint64_t *feature_list_, uint8_t *slot_id_list_, + GpuPsFeaInfo *fea_info_list_, uint64_t feature_size_, uint64_t node_size_) - : feature_list(feature_list_), + : node_list(node_list_), + feature_list(feature_list_), slot_id_list(slot_id_list_), - node_list(node_list_), + fea_info_list(fea_info_list_), feature_size(feature_size_), node_size(node_size_) {} void init_on_cpu(uint64_t feature_size, uint64_t node_size, @@ -322,27 +325,34 @@ struct GpuPsCommGraphFea { PADDLE_ENFORCE_LE(slot_num, 255); this->feature_size = feature_size; this->node_size = node_size; + this->node_list = new uint64_t[node_size]; this->feature_list = new uint64_t[feature_size]; this->slot_id_list = new uint8_t[feature_size]; - this->node_list = new GpuPsGraphFeaNode[node_size]; + this->fea_info_list = new GpuPsFeaInfo[node_size]; } void release_on_cpu() { - delete[] feature_list; - delete[] slot_id_list; - delete[] node_list; +#define DEL_PTR_ARRAY(p) \ + if (p != nullptr) { \ + delete [] p; \ + p = nullptr; \ + } + DEL_PTR_ARRAY(node_list); + DEL_PTR_ARRAY(feature_list); + DEL_PTR_ARRAY(slot_id_list); + DEL_PTR_ARRAY(fea_info_list); } - void display_on_cpu() { + void display_on_cpu() const { VLOG(1) << "feature_size = " << feature_size; VLOG(1) << "node_size = " << node_size; for (uint64_t i = 0; i < feature_size; i++) { VLOG(1) << "feature_list[" << i << "] = " << feature_list[i]; } for (uint64_t i = 0; i < node_size; i++) { - VLOG(1) << "node_id[" << node_list[i].node_id - << "] feature_size = " << node_list[i].feature_size; + VLOG(1) << "node_id[" << node_list[i] + << "] feature_size = " << fea_info_list[i].feature_size; std::string str; - int offset = node_list[i].feature_offset; - for (uint64_t j = 0; j < node_list[i].feature_size; j++) { + uint32_t offset = fea_info_list[i].feature_offset; + for (uint64_t j = 0; j < fea_info_list[i].feature_size; j++) { if (j > 0) str += ","; str += std::to_string(slot_id_list[j + offset]); str += ":"; diff --git a/paddle/fluid/framework/fleet/heter_ps/gpu_graph_utils.h 
b/paddle/fluid/framework/fleet/heter_ps/gpu_graph_utils.h index e63043e414bbe..7e899eef1b67a 100644 --- a/paddle/fluid/framework/fleet/heter_ps/gpu_graph_utils.h +++ b/paddle/fluid/framework/fleet/heter_ps/gpu_graph_utils.h @@ -12,16 +12,57 @@ // See the License for the specific language governing permissions and // limitations under the License. -namespace paddle { -namespace framework { - +#pragma once #include #include #include #include #include "paddle/fluid/platform/enforce.h" +namespace paddle { +namespace framework { + +#define CUDA_CHECK(cmd) do { \ + cudaError_t e = cmd; \ + CHECK(e == cudaSuccess) \ + << "Cuda failure " \ + << __FILE__ << ":" \ + << __LINE__ << " " \ + << cudaGetErrorString(e) \ + << std::endl; \ +} while (0) + +class CudaDeviceRestorer { +public: + CudaDeviceRestorer() { + cudaGetDevice(&dev_); + } + ~CudaDeviceRestorer() { + cudaSetDevice(dev_); + } +private: + int dev_; +}; + +inline void debug_gpu_memory_info(int gpu_id, const char* desc) { + CudaDeviceRestorer r; + + size_t avail{0}; + size_t total{0}; + cudaSetDevice(gpu_id); + auto err = cudaMemGetInfo(&avail, &total); + PADDLE_ENFORCE_EQ(err, cudaSuccess, + platform::errors::InvalidArgument("cudaMemGetInfo failed!")); + VLOG(0) << "updatex gpu memory on device " << gpu_id << ", " + << "avail=" << avail/1024.0/1024.0/1024.0 << "g, " + << "total=" << total/1024.0/1024.0/1024.0 << "g, " + << "use_rate=" << (total-avail)/double(total) << "%, " + << "desc=" << desc; +} + inline void debug_gpu_memory_info(const char* desc) { + CudaDeviceRestorer r; + int device_num = 0; auto err = cudaGetDeviceCount(&device_num); PADDLE_ENFORCE_EQ(err, cudaSuccess, diff --git a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table.h b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table.h index 9ec85bab5975d..11a52d631729c 100644 --- a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table.h +++ b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table.h @@ -27,7 +27,7 @@ DECLARE_double(gpugraph_hbm_table_load_factor); namespace paddle { namespace framework { enum GraphTableType { EDGE_TABLE, FEATURE_TABLE }; -class GpuPsGraphTable : public HeterComm { +class GpuPsGraphTable : public HeterComm { public: int get_table_offset(int gpu_id, GraphTableType type, int idx) const { int type_id = type; @@ -36,9 +36,10 @@ class GpuPsGraphTable : public HeterComm { } GpuPsGraphTable(std::shared_ptr resource, int topo_aware, int graph_table_num) - : HeterComm(1, resource) { + : HeterComm(1, resource) { load_factor_ = FLAGS_gpugraph_hbm_table_load_factor; VLOG(0) << "load_factor = " << load_factor_; + rw_lock.reset(new pthread_rwlock_t()); this->graph_table_num_ = graph_table_num; this->feature_table_num_ = 1; @@ -109,15 +110,15 @@ class GpuPsGraphTable : public HeterComm { } ~GpuPsGraphTable() { } - void build_graph_on_single_gpu(GpuPsCommGraph &g, int gpu_id, int idx); - void build_graph_fea_on_single_gpu(GpuPsCommGraphFea &g, int gpu_id); + void build_graph_on_single_gpu(const GpuPsCommGraph &g, int gpu_id, int idx); + void build_graph_fea_on_single_gpu(const GpuPsCommGraphFea &g, int gpu_id); void clear_graph_info(int gpu_id, int index); void clear_graph_info(int index); void clear_feature_info(int gpu_id, int index); void clear_feature_info(int index); - void build_graph_from_cpu(std::vector &cpu_node_list, + void build_graph_from_cpu(const std::vector &cpu_node_list, int idx); - void build_graph_fea_from_cpu(std::vector &cpu_node_list, + void build_graph_fea_from_cpu(const std::vector &cpu_node_list, int idx); NodeQueryResult 
graph_node_sample(int gpu_id, int sample_size); NeighborSampleResult graph_neighbor_sample_v3(NeighborSampleQuery q, diff --git a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.cu b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.cu index 9e8e12f2e2775..fbc2abf02e91e 100644 --- a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.cu +++ b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.cu @@ -18,6 +18,7 @@ #include #pragma once #ifdef PADDLE_WITH_HETERPS +#include "paddle/fluid/framework/fleet/heter_ps/gpu_graph_utils.h" #include "paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table.h" namespace paddle { namespace framework { @@ -68,13 +69,13 @@ __global__ void copy_buffer_ac_to_final_place( } } -__global__ void get_features_kernel(GpuPsCommGraphFea graph, uint64_t* node_offset_array, +__global__ void get_features_kernel(GpuPsCommGraphFea graph, GpuPsFeaInfo* fea_info_array, int* actual_size, uint64_t* feature, int slot_num, int n) { int idx = blockIdx.x * blockDim.y + threadIdx.y; if (idx < n) { - int node_offset = node_offset_array[idx]; + int node_offset = fea_info_array[idx].feature_offset; int offset = idx * slot_num; - if (node_offset == -1) { + if (node_offset == 0) { for (int k = 0; k < slot_num; ++ k) { feature[offset + k] = 0; } @@ -82,12 +83,11 @@ __global__ void get_features_kernel(GpuPsCommGraphFea graph, uint64_t* node_offs return; } - GpuPsGraphFeaNode* node = &(graph.node_list[node_offset]); - uint64_t* feature_start = &(graph.feature_list[node->feature_offset]); - uint8_t* slot_id_start = &(graph.slot_id_list[node->feature_offset]); + uint64_t* feature_start = &(graph.feature_list[fea_info_array[idx].feature_offset]); + uint8_t* slot_id_start = &(graph.slot_id_list[fea_info_array[idx].feature_offset]); int m = 0; for (int k = 0; k < slot_num; ++k) { - if (m >= node->feature_size || k < slot_id_start[m]) { + if (m >= fea_info_array[idx].feature_size || k < slot_id_start[m]) { feature[offset + k] = 0; } else if (k == slot_id_start[m]) { feature[offset + k] = feature_start[m]; @@ -102,7 +102,8 @@ __global__ void get_features_kernel(GpuPsCommGraphFea graph, uint64_t* node_offs template __global__ void neighbor_sample_kernel(GpuPsCommGraph graph, - int64_t* node_index, int* actual_size, + GpuPsNodeInfo* node_info_list, + int* actual_size, uint64_t* res, int sample_len, int n, int default_value) { assert(blockDim.x == WARP_SIZE); @@ -112,15 +113,14 @@ __global__ void neighbor_sample_kernel(GpuPsCommGraph graph, const int last_idx = min(static_cast(blockIdx.x + 1) * TILE_SIZE, n); curandState rng; curand_init(blockIdx.x, threadIdx.y * WARP_SIZE + threadIdx.x, 0, &rng); - while (i < last_idx) { - if (node_index[i] == -1) { + if (node_info_list[i].neighbor_size == 0) { actual_size[i] = default_value; i += BLOCK_WARPS; continue; } - int neighbor_len = (int)graph.node_list[node_index[i]].neighbor_size; - int64_t data_offset = graph.node_list[node_index[i]].neighbor_offset; + int neighbor_len = (int)node_info_list[i].neighbor_size; + uint32_t data_offset = node_info_list[i].neighbor_offset; int offset = i * sample_len; uint64_t* data = graph.neighbor_list; if (neighbor_len <= sample_len) { @@ -225,30 +225,30 @@ void GpuPsGraphTable::move_result_to_source_gpu( shard_len[i] = h_right[i] - h_left[i] + 1; int cur_step = (int)path_[start_index][i].nodes_.size() - 1; for (int j = cur_step; j > 0; j--) { - cudaMemcpyAsync(path_[start_index][i].nodes_[j - 1].val_storage, + CUDA_CHECK(cudaMemcpyAsync(path_[start_index][i].nodes_[j - 
1].val_storage, path_[start_index][i].nodes_[j].val_storage, path_[start_index][i].nodes_[j - 1].val_bytes_len, cudaMemcpyDefault, - path_[start_index][i].nodes_[j - 1].out_stream); + path_[start_index][i].nodes_[j - 1].out_stream)); } auto& node = path_[start_index][i].nodes_.front(); - cudaMemcpyAsync( + CUDA_CHECK(cudaMemcpyAsync( reinterpret_cast(src_sample_res + h_left[i] * sample_size), node.val_storage + sizeof(int64_t) * shard_len[i] + sizeof(int) * (shard_len[i] + shard_len[i] % 2), sizeof(uint64_t) * shard_len[i] * sample_size, cudaMemcpyDefault, - node.out_stream); - cudaMemcpyAsync(reinterpret_cast(actual_sample_size + h_left[i]), + node.out_stream)); + CUDA_CHECK(cudaMemcpyAsync(reinterpret_cast(actual_sample_size + h_left[i]), node.val_storage + sizeof(int64_t) * shard_len[i], sizeof(int) * shard_len[i], cudaMemcpyDefault, - node.out_stream); + node.out_stream)); } for (int i = 0; i < gpu_num; ++i) { if (h_left[i] == -1 || h_right[i] == -1) { continue; } auto& node = path_[start_index][i].nodes_.front(); - cudaStreamSynchronize(node.out_stream); + CUDA_CHECK(cudaStreamSynchronize(node.out_stream)); // cudaStreamSynchronize(resource_->remote_stream(i, start_index)); } } @@ -297,7 +297,7 @@ __global__ void node_query_example(GpuPsCommGraph graph, int start, int size, uint64_t* res) { const size_t i = blockIdx.x * blockDim.x + threadIdx.x; if (i < size) { - res[i] = graph.node_list[start + i].node_id; + res[i] = graph.node_list[start + i]; } } @@ -317,14 +317,12 @@ void GpuPsGraphTable::clear_feature_info(int gpu_id) { auto& graph = gpu_graph_fea_list_[graph_fea_idx]; if (graph.feature_list != NULL) { cudaFree(graph.feature_list); + graph.feature_list = NULL; } if (graph.slot_id_list != NULL) { cudaFree(graph.slot_id_list); - } - - if (graph.node_list != NULL) { - cudaFree(graph.node_list); + graph.slot_id_list = NULL; } } @@ -338,9 +336,11 @@ void GpuPsGraphTable::clear_graph_info(int gpu_id, int idx) { auto& graph = gpu_graph_list_[gpu_id * graph_table_num_ + idx]; if (graph.neighbor_list != NULL) { cudaFree(graph.neighbor_list); + graph.neighbor_list = nullptr; } if (graph.node_list != NULL) { cudaFree(graph.node_list); + graph.node_list = nullptr; } } void GpuPsGraphTable::clear_graph_info(int idx) { @@ -354,11 +354,11 @@ for the ith GpuPsCommGraph, any the node's key satisfies that key % gpu_number In this function, memory is allocated on each gpu to save the graphs, gpu i saves the ith graph from cpu_graph_list */ -void GpuPsGraphTable::build_graph_fea_on_single_gpu(GpuPsCommGraphFea& g, +void GpuPsGraphTable::build_graph_fea_on_single_gpu(const GpuPsCommGraphFea& g, int gpu_id) { clear_feature_info(gpu_id); int ntype_id = 0; - + platform::CUDADeviceGuard guard(resource_->dev_id(gpu_id)); int offset = gpu_id * feature_table_num_ + ntype_id; @@ -370,19 +370,9 @@ void GpuPsGraphTable::build_graph_fea_on_single_gpu(GpuPsCommGraphFea& g, size_t capacity = std::max((uint64_t)1, g.node_size) / load_factor_; tables_[table_offset] = new Table(capacity); if (g.node_size > 0) { - std::vector keys; - std::vector offsets; - // TODO - cudaMalloc((void**)&gpu_graph_fea_list_[offset].node_list, - g.node_size * sizeof(GpuPsGraphFeaNode)); - cudaMemcpy(gpu_graph_fea_list_[offset].node_list, g.node_list, - g.node_size * sizeof(GpuPsGraphFeaNode), cudaMemcpyHostToDevice); - for (int64_t j = 0; j < g.node_size; j++) { - keys.push_back(g.node_list[j].node_id); - offsets.push_back(j); - } - build_ps(gpu_id, keys.data(), offsets.data(), keys.size(), 1024, 8, + build_ps(gpu_id, g.node_list, 
(uint64_t*)g.fea_info_list, g.node_size, 1024, 8, table_offset); + gpu_graph_fea_list_[offset].node_list = NULL; gpu_graph_fea_list_[offset].node_size = g.node_size; } else { build_ps(gpu_id, NULL, NULL, 0, 1024, 8, table_offset); @@ -401,8 +391,8 @@ void GpuPsGraphTable::build_graph_fea_on_single_gpu(GpuPsCommGraphFea& g, VLOG(0) << "sucessfully allocate " << g.feature_size * sizeof(uint64_t) << " bytes of memory for graph-feature on gpu " << resource_->dev_id(gpu_id); - cudaMemcpy(gpu_graph_fea_list_[offset].feature_list, g.feature_list, - g.feature_size * sizeof(uint64_t), cudaMemcpyHostToDevice); + CUDA_CHECK(cudaMemcpy(gpu_graph_fea_list_[offset].feature_list, g.feature_list, + g.feature_size * sizeof(uint64_t), cudaMemcpyHostToDevice)); // TODO cudaStatus = cudaMalloc((void**)&gpu_graph_fea_list_[offset].slot_id_list, @@ -423,6 +413,8 @@ void GpuPsGraphTable::build_graph_fea_on_single_gpu(GpuPsCommGraphFea& g, gpu_graph_fea_list_[offset].slot_id_list = NULL; gpu_graph_fea_list_[offset].feature_size = 0; } + VLOG(0) << "gpu node_feature info card :" << gpu_id << " ,node_size is " << gpu_graph_fea_list_[offset].node_size + << ", feature_size is " << gpu_graph_fea_list_[offset].feature_size; } /* @@ -433,7 +425,7 @@ for the ith GpuPsCommGraph, any the node's key satisfies that key % gpu_number In this function, memory is allocated on each gpu to save the graphs, gpu i saves the ith graph from cpu_graph_list */ -void GpuPsGraphTable::build_graph_on_single_gpu(GpuPsCommGraph& g, int i, +void GpuPsGraphTable::build_graph_on_single_gpu(const GpuPsCommGraph& g, int i, int idx) { clear_graph_info(i, idx); platform::CUDADeviceGuard guard(resource_->dev_id(i)); @@ -443,18 +435,13 @@ void GpuPsGraphTable::build_graph_on_single_gpu(GpuPsCommGraph& g, int i, size_t capacity = std::max((uint64_t)1, (uint64_t)g.node_size) / load_factor_; tables_[table_offset] = new Table(capacity); if (g.node_size > 0) { - std::vector keys; - std::vector offsets; - cudaMalloc((void**)&gpu_graph_list_[offset].node_list, - g.node_size * sizeof(GpuPsGraphNode)); - cudaMemcpy(gpu_graph_list_[offset].node_list, g.node_list, - g.node_size * sizeof(GpuPsGraphNode), cudaMemcpyHostToDevice); - for (int64_t j = 0; j < g.node_size; j++) { - keys.push_back(g.node_list[j].node_id); - offsets.push_back(j); - } - build_ps(i, (uint64_t*)keys.data(), offsets.data(), keys.size(), 1024, 8, - table_offset); + CUDA_CHECK(cudaMalloc((void**)&gpu_graph_list_[offset].node_list, + g.node_size * sizeof(uint64_t))); + CUDA_CHECK(cudaMemcpy(gpu_graph_list_[offset].node_list, g.node_list, + g.node_size * sizeof(uint64_t), cudaMemcpyHostToDevice)); + + build_ps(i, g.node_list, (uint64_t*)(g.node_info_list), + g.node_size, 1024, 8, table_offset); gpu_graph_list_[offset].node_size = g.node_size; } else { build_ps(i, NULL, NULL, 0, 1024, 8, table_offset); @@ -471,17 +458,20 @@ void GpuPsGraphTable::build_graph_on_single_gpu(GpuPsCommGraph& g, int i, VLOG(0) << "sucessfully allocate " << g.neighbor_size * sizeof(uint64_t) << " bytes of memory for graph-edges on gpu " << resource_->dev_id(i); - cudaMemcpy(gpu_graph_list_[offset].neighbor_list, g.neighbor_list, - g.neighbor_size * sizeof(uint64_t), cudaMemcpyHostToDevice); + CUDA_CHECK(cudaMemcpy(gpu_graph_list_[offset].neighbor_list, g.neighbor_list, + g.neighbor_size * sizeof(uint64_t), cudaMemcpyHostToDevice)); gpu_graph_list_[offset].neighbor_size = g.neighbor_size; } else { gpu_graph_list_[offset].neighbor_list = NULL; gpu_graph_list_[offset].neighbor_size = 0; } + VLOG(0) << " gpu node_neighbor 
info card: " << i << " ,node_size is " << gpu_graph_list_[offset].node_size + << ", neighbor_size is " << gpu_graph_list_[offset].neighbor_size; + } void GpuPsGraphTable::build_graph_fea_from_cpu( - std::vector& cpu_graph_fea_list, int ntype_id) { + const std::vector& cpu_graph_fea_list, int ntype_id) { PADDLE_ENFORCE_EQ( cpu_graph_fea_list.size(), resource_->total_device(), platform::errors::InvalidArgument("the cpu node list size doesn't match " @@ -497,21 +487,8 @@ void GpuPsGraphTable::build_graph_fea_from_cpu( std::max((uint64_t)1, (uint64_t)cpu_graph_fea_list[i].node_size) / load_factor_); if (cpu_graph_fea_list[i].node_size > 0) { - std::vector keys; - std::vector offsets; - // TODO - cudaMalloc((void**)&gpu_graph_fea_list_[offset].node_list, - cpu_graph_fea_list[i].node_size * sizeof(GpuPsGraphNode)); - cudaMemcpy(gpu_graph_fea_list_[offset].node_list, - cpu_graph_fea_list[i].node_list, - cpu_graph_fea_list[i].node_size * sizeof(GpuPsGraphNode), - cudaMemcpyHostToDevice); - for (int64_t j = 0; j < cpu_graph_fea_list[i].node_size; j++) { - keys.push_back(cpu_graph_fea_list[i].node_list[j].node_id); - offsets.push_back(j); - } - build_ps(i, (uint64_t*)(keys.data()), offsets.data(), keys.size(), 1024, - 8, table_offset); + build_ps(i, cpu_graph_fea_list[i].node_list, (uint64_t*)cpu_graph_fea_list[i].fea_info_list, + cpu_graph_fea_list[i].node_size, 1024, 8, table_offset); gpu_graph_fea_list_[offset].node_size = cpu_graph_fea_list[i].node_size; } else { build_ps(i, NULL, NULL, 0, 1024, 8, table_offset); @@ -520,22 +497,22 @@ void GpuPsGraphTable::build_graph_fea_from_cpu( } if (cpu_graph_fea_list[i].feature_size) { // TODO - cudaMalloc((void**)&gpu_graph_fea_list_[offset].feature_list, - cpu_graph_fea_list[i].feature_size * sizeof(uint64_t)); + CUDA_CHECK(cudaMalloc((void**)&gpu_graph_fea_list_[offset].feature_list, + cpu_graph_fea_list[i].feature_size * sizeof(uint64_t))); - cudaMemcpy(gpu_graph_fea_list_[offset].feature_list, + CUDA_CHECK(cudaMemcpy(gpu_graph_fea_list_[offset].feature_list, cpu_graph_fea_list[i].feature_list, cpu_graph_fea_list[i].feature_size * sizeof(uint64_t), - cudaMemcpyHostToDevice); + cudaMemcpyHostToDevice)); // TODO - cudaMalloc((void**)&gpu_graph_fea_list_[offset].slot_id_list, - cpu_graph_fea_list[i].feature_size * sizeof(uint8_t)); + CUDA_CHECK(cudaMalloc((void**)&gpu_graph_fea_list_[offset].slot_id_list, + cpu_graph_fea_list[i].feature_size * sizeof(uint8_t))); - cudaMemcpy(gpu_graph_fea_list_[offset].slot_id_list, + CUDA_CHECK(cudaMemcpy(gpu_graph_fea_list_[offset].slot_id_list, cpu_graph_fea_list[i].slot_id_list, cpu_graph_fea_list[i].feature_size * sizeof(uint8_t), - cudaMemcpyHostToDevice); + cudaMemcpyHostToDevice)); gpu_graph_fea_list_[offset].feature_size = cpu_graph_fea_list[i].feature_size; @@ -549,7 +526,7 @@ void GpuPsGraphTable::build_graph_fea_from_cpu( } void GpuPsGraphTable::build_graph_from_cpu( - std::vector& cpu_graph_list, int idx) { + const std::vector& cpu_graph_list, int idx) { VLOG(0) << "in build_graph_from_cpu cpu_graph_list size = " << cpu_graph_list.size(); PADDLE_ENFORCE_EQ( @@ -566,19 +543,16 @@ void GpuPsGraphTable::build_graph_from_cpu( new Table(std::max((uint64_t)1, (uint64_t)cpu_graph_list[i].node_size) / load_factor_); if (cpu_graph_list[i].node_size > 0) { - std::vector keys; - std::vector offsets; - cudaMalloc((void**)&gpu_graph_list_[offset].node_list, - cpu_graph_list[i].node_size * sizeof(GpuPsGraphNode)); - cudaMemcpy(gpu_graph_list_[offset].node_list, cpu_graph_list[i].node_list, - cpu_graph_list[i].node_size * 
sizeof(GpuPsGraphNode), - cudaMemcpyHostToDevice); - for (int64_t j = 0; j < cpu_graph_list[i].node_size; j++) { - keys.push_back(cpu_graph_list[i].node_list[j].node_id); - offsets.push_back(j); - } - build_ps(i, (uint64_t*)(keys.data()), offsets.data(), keys.size(), 1024, - 8, table_offset); + CUDA_CHECK(cudaMalloc((void**)&gpu_graph_list_[offset].node_list, + cpu_graph_list[i].node_size * sizeof(uint64_t))); + CUDA_CHECK(cudaMemcpy(gpu_graph_list_[offset].node_list, + cpu_graph_list[i].node_list, + cpu_graph_list[i].node_size * sizeof(uint64_t), + cudaMemcpyHostToDevice)); + build_ps(i, cpu_graph_list[i].node_list, + (uint64_t*)(cpu_graph_list[i].node_info_list), + cpu_graph_list[i].node_size, + 1024, 8, table_offset); gpu_graph_list_[offset].node_size = cpu_graph_list[i].node_size; } else { build_ps(i, NULL, NULL, 0, 1024, 8, table_offset); @@ -586,20 +560,20 @@ void GpuPsGraphTable::build_graph_from_cpu( gpu_graph_list_[offset].node_size = 0; } if (cpu_graph_list[i].neighbor_size) { - cudaMalloc((void**)&gpu_graph_list_[offset].neighbor_list, - cpu_graph_list[i].neighbor_size * sizeof(uint64_t)); + CUDA_CHECK(cudaMalloc((void**)&gpu_graph_list_[offset].neighbor_list, + cpu_graph_list[i].neighbor_size * sizeof(uint64_t))); - cudaMemcpy(gpu_graph_list_[offset].neighbor_list, + CUDA_CHECK(cudaMemcpy(gpu_graph_list_[offset].neighbor_list, cpu_graph_list[i].neighbor_list, cpu_graph_list[i].neighbor_size * sizeof(uint64_t), - cudaMemcpyHostToDevice); + cudaMemcpyHostToDevice)); gpu_graph_list_[offset].neighbor_size = cpu_graph_list[i].neighbor_size; } else { gpu_graph_list_[offset].neighbor_list = NULL; gpu_graph_list_[offset].neighbor_size = 0; } } - cudaDeviceSynchronize(); + CUDA_CHECK(cudaDeviceSynchronize()); } NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample_v3( @@ -628,6 +602,7 @@ NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample_v2( platform::CUDAPlace place = platform::CUDAPlace(resource_->dev_id(gpu_id)); platform::CUDADeviceGuard guard(resource_->dev_id(gpu_id)); + int* actual_sample_size = result.actual_sample_size; uint64_t* val = result.val; int total_gpu = resource_->total_device(); @@ -647,8 +622,8 @@ NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample_v2( default_value = -1; } - cudaMemsetAsync(d_left_ptr, -1, total_gpu * sizeof(int), stream); - cudaMemsetAsync(d_right_ptr, -1, total_gpu * sizeof(int), stream); + CUDA_CHECK(cudaMemsetAsync(d_left_ptr, -1, total_gpu * sizeof(int), stream)); + CUDA_CHECK(cudaMemsetAsync(d_right_ptr, -1, total_gpu * sizeof(int), stream)); // auto d_idx = memory::Alloc(place, len * sizeof(int)); int* d_idx_ptr = reinterpret_cast(d_idx->ptr()); @@ -668,12 +643,12 @@ NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample_v2( heter_comm_kernel_->fill_shard_key(d_shard_keys_ptr, key, d_idx_ptr, len, stream); - cudaStreamSynchronize(stream); + CUDA_CHECK(cudaStreamSynchronize(stream)); - cudaMemcpy(h_left, d_left_ptr, total_gpu * sizeof(int), - cudaMemcpyDeviceToHost); - cudaMemcpy(h_right, d_right_ptr, total_gpu * sizeof(int), - cudaMemcpyDeviceToHost); + CUDA_CHECK(cudaMemcpy(h_left, d_left_ptr, total_gpu * sizeof(int), + cudaMemcpyDeviceToHost)); + CUDA_CHECK(cudaMemcpy(h_right, d_right_ptr, total_gpu * sizeof(int), + cudaMemcpyDeviceToHost)); for (int i = 0; i < total_gpu; ++i) { int shard_len = h_left[i] == -1 ? 
0 : h_right[i] - h_left[i] + 1; if (shard_len == 0) { @@ -681,7 +656,7 @@ NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample_v2( } create_storage(gpu_id, i, shard_len * sizeof(uint64_t), shard_len * sample_size * sizeof(uint64_t) + - shard_len * sizeof(int64_t) + + shard_len * sizeof(uint64_t) + sizeof(int) * (shard_len + shard_len % 2)); } walk_to_dest(gpu_id, total_gpu, h_left, h_right, @@ -693,21 +668,22 @@ NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample_v2( } int shard_len = h_left[i] == -1 ? 0 : h_right[i] - h_left[i] + 1; auto& node = path_[gpu_id][i].nodes_.back(); - cudaMemsetAsync(node.val_storage, -1, shard_len * sizeof(int64_t), - node.in_stream); - cudaStreamSynchronize(node.in_stream); + + CUDA_CHECK(cudaMemsetAsync(node.val_storage, 0, shard_len * sizeof(int64_t), + node.in_stream)); + CUDA_CHECK(cudaStreamSynchronize(node.in_stream)); platform::CUDADeviceGuard guard(resource_->dev_id(i)); // If not found, val is -1. int table_offset = get_table_offset(i, GraphTableType::EDGE_TABLE, idx); int offset = i * graph_table_num_ + idx; tables_[table_offset]->get(reinterpret_cast(node.key_storage), - reinterpret_cast(node.val_storage), - h_right[i] - h_left[i] + 1, + reinterpret_cast(node.val_storage), + (size_t)(h_right[i] - h_left[i] + 1), resource_->remote_stream(i, gpu_id)); - - auto graph = gpu_graph_list_[offset]; - int64_t* id_array = reinterpret_cast(node.val_storage); - int* actual_size_array = (int*)(id_array + shard_len); + + auto graph = gpu_graph_list_[offset]; + GpuPsNodeInfo* node_info_list = reinterpret_cast(node.val_storage); + int* actual_size_array = (int*)(node_info_list + shard_len); uint64_t* sample_array = (uint64_t*)(actual_size_array + shard_len + shard_len % 2); constexpr int WARP_SIZE = 32; @@ -715,10 +691,11 @@ NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample_v2( constexpr int TILE_SIZE = BLOCK_WARPS * 16; const dim3 block(WARP_SIZE, BLOCK_WARPS); const dim3 grid((shard_len + TILE_SIZE - 1) / TILE_SIZE); + neighbor_sample_kernel< WARP_SIZE, BLOCK_WARPS, TILE_SIZE><<remote_stream(i, gpu_id)>>>( - graph, id_array, actual_size_array, sample_array, sample_size, + graph, node_info_list, actual_size_array, sample_array, sample_size, shard_len, default_value); } @@ -726,7 +703,7 @@ NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample_v2( if (h_left[i] == -1) { continue; } - cudaStreamSynchronize(resource_->remote_stream(i, gpu_id)); + CUDA_CHECK(cudaStreamSynchronize(resource_->remote_stream(i, gpu_id))); } move_result_to_source_gpu(gpu_id, total_gpu, sample_size, h_left, h_right, d_shard_vals_ptr, @@ -735,7 +712,7 @@ NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample_v2( d_shard_vals_ptr, val, d_shard_actual_sample_size_ptr, actual_sample_size, d_idx_ptr, sample_size, len); - cudaStreamSynchronize(stream); + CUDA_CHECK(cudaStreamSynchronize(stream)); if (cpu_query_switch) { // Get cpu keys and corresponding position. 
@@ -746,15 +723,15 @@ NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample_v2( thrust::raw_pointer_cast(t_index.data()), thrust::raw_pointer_cast(t_index.data()) + 1, len); - cudaStreamSynchronize(stream); + CUDA_CHECK(cudaStreamSynchronize(stream)); int number_on_cpu = 0; - cudaMemcpy(&number_on_cpu, thrust::raw_pointer_cast(t_index.data()), - sizeof(int), cudaMemcpyDeviceToHost); + CUDA_CHECK(cudaMemcpy(&number_on_cpu, thrust::raw_pointer_cast(t_index.data()), + sizeof(int), cudaMemcpyDeviceToHost)); if (number_on_cpu > 0) { uint64_t* cpu_keys = new uint64_t[number_on_cpu]; - cudaMemcpy(cpu_keys, thrust::raw_pointer_cast(t_cpu_keys.data()), - number_on_cpu * sizeof(uint64_t), cudaMemcpyDeviceToHost); + CUDA_CHECK(cudaMemcpy(cpu_keys, thrust::raw_pointer_cast(t_cpu_keys.data()), + number_on_cpu * sizeof(uint64_t), cudaMemcpyDeviceToHost)); std::vector> buffers(number_on_cpu); std::vector ac(number_on_cpu); @@ -778,11 +755,11 @@ NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample_v2( thrust::device_vector gpu_ac(number_on_cpu); uint64_t* gpu_buffers_ptr = thrust::raw_pointer_cast(gpu_buffers.data()); int* gpu_ac_ptr = thrust::raw_pointer_cast(gpu_ac.data()); - cudaMemcpyAsync(gpu_buffers_ptr, merge_buffers, + CUDA_CHECK(cudaMemcpyAsync(gpu_buffers_ptr, merge_buffers, total_cpu_sample_size * sizeof(uint64_t), - cudaMemcpyHostToDevice, stream); - cudaMemcpyAsync(gpu_ac_ptr, ac.data(), number_on_cpu * sizeof(int), - cudaMemcpyHostToDevice, stream); + cudaMemcpyHostToDevice, stream)); + CUDA_CHECK(cudaMemcpyAsync(gpu_ac_ptr, ac.data(), number_on_cpu * sizeof(int), + cudaMemcpyHostToDevice, stream)); // Copy gpu_buffers and gpu_ac using kernel. // Kernel divide for gpu_ac_ptr. @@ -790,7 +767,7 @@ NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample_v2( get_actual_gpu_ac<<>>(gpu_ac_ptr, number_on_cpu); - cudaStreamSynchronize(stream); + CUDA_CHECK(cudaStreamSynchronize(stream)); thrust::device_vector cumsum_gpu_ac(number_on_cpu); thrust::exclusive_scan(gpu_ac.begin(), gpu_ac.end(), @@ -814,11 +791,12 @@ NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample_v2( } { - cudaStreamSynchronize(stream); + CUDA_CHECK(cudaStreamSynchronize(stream)); platform::CUDAPlace place = platform::CUDAPlace(resource_->dev_id(gpu_id)); platform::CUDADeviceGuard guard(resource_->dev_id(gpu_id)); - + thrust::device_vector t_actual_sample_size(len); + thrust::copy(actual_sample_size, actual_sample_size + len, t_actual_sample_size.begin()); int total_sample_size = thrust::reduce(t_actual_sample_size.begin(), @@ -829,7 +807,6 @@ NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample_v2( result.actual_val = (uint64_t*)(result.actual_val_mem)->ptr(); result.set_total_sample_size(total_sample_size); - thrust::device_vector cumsum_actual_sample_size(len); thrust::exclusive_scan(t_actual_sample_size.begin(), t_actual_sample_size.end(), @@ -846,7 +823,6 @@ NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample_v2( } destroy_storage(gpu_id, i); } - cudaStreamSynchronize(stream); return result; } @@ -859,13 +835,9 @@ NodeQueryResult GpuPsGraphTable::graph_node_sample(int gpu_id, NodeQueryResult GpuPsGraphTable::query_node_list(int gpu_id, int idx, int start, int query_size) { NodeQueryResult result; + result.actual_sample_size = 0; if (query_size <= 0) return result; - int& actual_size = result.actual_sample_size; - actual_size = 0; - // int dev_id = resource_->dev_id(gpu_id); - // platform::CUDADeviceGuard guard(dev_id); std::vector gpu_begin_pos, local_begin_pos; - int sample_size; 
std::function range_check = []( int x, int y, int x1, int y1, int& x2, int& y2) { if (y <= x1 || x >= y1) return 0; @@ -873,7 +845,9 @@ NodeQueryResult GpuPsGraphTable::query_node_list(int gpu_id, int idx, int start, x2 = max(x1, x); return y2 - x2; }; - auto graph = gpu_graph_list_[gpu_id]; + + int offset = gpu_id * graph_table_num_ + idx; + const auto& graph = gpu_graph_list_[offset]; if (graph.node_size == 0) { return result; } @@ -883,19 +857,18 @@ NodeQueryResult GpuPsGraphTable::query_node_list(int gpu_id, int idx, int start, if (len == 0) { return result; } - uint64_t* val; - sample_size = len; + result.initialize(len, resource_->dev_id(gpu_id)); - actual_size = len; - val = result.val; + result.actual_sample_size = len; + uint64_t* val = result.val; + int dev_id_i = resource_->dev_id(gpu_id); platform::CUDADeviceGuard guard(dev_id_i); int grid_size = (len - 1) / block_size_ + 1; - int offset = gpu_id * graph_table_num_ + idx; node_query_example<<remote_stream(gpu_id, gpu_id)>>>( - gpu_graph_list_[offset], x2, len, (uint64_t*)val); - cudaStreamSynchronize(resource_->remote_stream(gpu_id, gpu_id)); + graph, x2, len, (uint64_t*)val); + CUDA_CHECK(cudaStreamSynchronize(resource_->remote_stream(gpu_id, gpu_id))); return result; } @@ -915,8 +888,8 @@ int GpuPsGraphTable::get_feature_of_nodes(int gpu_id, uint64_t* d_nodes, int* d_left_ptr = reinterpret_cast(d_left->ptr()); int* d_right_ptr = reinterpret_cast(d_right->ptr()); - cudaMemsetAsync(d_left_ptr, -1, total_gpu * sizeof(int), stream); - cudaMemsetAsync(d_right_ptr, -1, total_gpu * sizeof(int), stream); + CUDA_CHECK(cudaMemsetAsync(d_left_ptr, -1, total_gpu * sizeof(int), stream)); + CUDA_CHECK(cudaMemsetAsync(d_right_ptr, -1, total_gpu * sizeof(int), stream)); // auto d_idx = memory::Alloc(place, node_num * sizeof(int)); int* d_idx_ptr = reinterpret_cast(d_idx->ptr()); @@ -931,12 +904,12 @@ int GpuPsGraphTable::get_feature_of_nodes(int gpu_id, uint64_t* d_nodes, split_input_to_shard(d_nodes, d_idx_ptr, node_num, d_left_ptr, d_right_ptr, gpu_id); heter_comm_kernel_->fill_shard_key(d_shard_keys_ptr, d_nodes, d_idx_ptr, node_num, stream); - cudaStreamSynchronize(stream); + CUDA_CHECK(cudaStreamSynchronize(stream)); int h_left[total_gpu]; // NOLINT - cudaMemcpy(h_left, d_left_ptr, total_gpu * sizeof(int), cudaMemcpyDeviceToHost); + CUDA_CHECK(cudaMemcpy(h_left, d_left_ptr, total_gpu * sizeof(int), cudaMemcpyDeviceToHost)); int h_right[total_gpu]; // NOLINT - cudaMemcpy(h_right, d_right_ptr, total_gpu * sizeof(int), cudaMemcpyDeviceToHost); + CUDA_CHECK(cudaMemcpy(h_right, d_right_ptr, total_gpu * sizeof(int), cudaMemcpyDeviceToHost)); for (int i = 0; i < total_gpu; ++i) { int shard_len = h_left[i] == -1 ? 0 : h_right[i] - h_left[i] + 1; if (shard_len == 0) { @@ -955,19 +928,21 @@ int GpuPsGraphTable::get_feature_of_nodes(int gpu_id, uint64_t* d_nodes, } int shard_len = h_left[i] == -1 ? 0 : h_right[i] - h_left[i] + 1; auto& node = path_[gpu_id][i].nodes_.back(); - cudaMemsetAsync(node.val_storage, -1, shard_len * sizeof(uint64_t), node.in_stream); - cudaStreamSynchronize(node.in_stream); + + CUDA_CHECK(cudaMemsetAsync(node.val_storage, 0, shard_len * sizeof(uint64_t), node.in_stream)); + CUDA_CHECK(cudaStreamSynchronize(node.in_stream)); platform::CUDADeviceGuard guard(resource_->dev_id(i)); // If not found, val is -1. 
int table_offset = get_table_offset(i, GraphTableType::FEATURE_TABLE, 0); tables_[table_offset]->get(reinterpret_cast(node.key_storage), - reinterpret_cast(node.val_storage), - h_right[i] - h_left[i] + 1, + reinterpret_cast(node.val_storage), + (size_t)(h_right[i] - h_left[i] + 1), resource_->remote_stream(i, gpu_id)); int offset = i * feature_table_num_; auto graph = gpu_graph_fea_list_[offset]; - uint64_t* val_array = reinterpret_cast(node.val_storage); + + GpuPsFeaInfo* val_array = reinterpret_cast(node.val_storage); int* actual_size_array = (int*)(val_array + shard_len); uint64_t* feature_array = (uint64_t*)(actual_size_array + shard_len + shard_len % 2); dim3 grid((shard_len - 1) / dim_y + 1); @@ -980,7 +955,7 @@ int GpuPsGraphTable::get_feature_of_nodes(int gpu_id, uint64_t* d_nodes, if (h_left[i] == -1) { continue; } - cudaStreamSynchronize(resource_->remote_stream(i, gpu_id)); + CUDA_CHECK(cudaStreamSynchronize(resource_->remote_stream(i, gpu_id))); } move_result_to_source_gpu(gpu_id, total_gpu, slot_num, h_left, h_right, @@ -998,7 +973,7 @@ int GpuPsGraphTable::get_feature_of_nodes(int gpu_id, uint64_t* d_nodes, destroy_storage(gpu_id, i); } - cudaStreamSynchronize(stream); + CUDA_CHECK(cudaStreamSynchronize(stream)); return 0; } diff --git a/paddle/fluid/framework/fleet/heter_ps/graph_sampler.h b/paddle/fluid/framework/fleet/heter_ps/graph_sampler.h index 335508217fb04..261d3bca7e83b 100644 --- a/paddle/fluid/framework/fleet/heter_ps/graph_sampler.h +++ b/paddle/fluid/framework/fleet/heter_ps/graph_sampler.h @@ -100,7 +100,8 @@ class AllInGpuGraphSampler : public GraphSampler { protected: paddle::distributed::GraphTable *graph_table; GpuPsGraphTable *gpu_table; - std::vector> sample_nodes; + std::vector> sample_node_ids; + std::vector> sample_node_infos; std::vector> sample_neighbors; std::vector sample_res; // std::shared_ptr random; diff --git a/paddle/fluid/framework/fleet/heter_ps/graph_sampler_inl.h b/paddle/fluid/framework/fleet/heter_ps/graph_sampler_inl.h index ae05398c14844..2a75386634782 100644 --- a/paddle/fluid/framework/fleet/heter_ps/graph_sampler_inl.h +++ b/paddle/fluid/framework/fleet/heter_ps/graph_sampler_inl.h @@ -78,18 +78,23 @@ void CommonGraphSampler::init(GpuPsGraphTable *g, int AllInGpuGraphSampler::run_graph_sampling() { return 0; } int AllInGpuGraphSampler::load_from_ssd(std::string path) { graph_table->load_edges(path, false); - sample_nodes.clear(); + sample_node_ids.clear() + sample_node_infos.clear() sample_neighbors.clear(); sample_res.clear(); - sample_nodes.resize(gpu_num); + sample_node_ids.resize(gpu_num); + sample_node_infos.resize(gpu_num); sample_neighbors.resize(gpu_num); sample_res.resize(gpu_num); - std::vector>> - sample_nodes_ex(graph_table->task_pool_size_); + std::vector>> + sample_node_ids_ex(graph_table->task_pool_size_); + std::vector>> + sample_node_infos_ex(graph_table->task_pool_size_); std::vector>> sample_neighbors_ex( graph_table->task_pool_size_); for (int i = 0; i < graph_table->task_pool_size_; i++) { - sample_nodes_ex[i].resize(gpu_num); + sample_node_ids_ex[i].resize(gpu_num); + sample_node_infos_ex[i].resize(gpu_num); sample_neighbors_ex[i].resize(gpu_num); } std::vector> tasks; @@ -98,17 +103,15 @@ int AllInGpuGraphSampler::load_from_ssd(std::string path) { graph_table->_shards_task_pool[i % graph_table->task_pool_size_] ->enqueue([&, i, this]() -> int { if (this->status == GraphSamplerStatus::terminating) return 0; - paddle::framework::GpuPsGraphNode node; + paddle::framework::GpuPsNodeInfo info; std::vector &v = 
this->graph_table->shards[i]->get_bucket(); size_t ind = i % this->graph_table->task_pool_size_; for (size_t j = 0; j < v.size(); j++) { - size_t location = v[j]->get_id() % this->gpu_num; - node.node_id = v[j]->get_id(); - node.neighbor_size = v[j]->get_neighbor_size(); - node.neighbor_offset = - (int)sample_neighbors_ex[ind][location].size(); - sample_nodes_ex[ind][location].emplace_back(node); + info.neighbor_size = v[j]->get_neighbor_size(); + info.neighbor_offset = sample_neighbors_ex[ind][location].size(); + sample_node_infos_ex[ind][location].emplace_back(info); + sample_node_ids_ex[ind][location].emplace_back(v[j]->get_id()); for (int k = 0; k < node.neighbor_size; k++) sample_neighbors_ex[ind][location].push_back( v[j]->get_neighbor_id(k)); @@ -126,9 +129,10 @@ int AllInGpuGraphSampler::load_from_ssd(std::string path) { int total_offset = 0; size_t ind = i; for (int j = 0; j < this->graph_table->task_pool_size_; j++) { - for (size_t k = 0; k < sample_nodes_ex[j][ind].size(); k++) { - sample_nodes[ind].push_back(sample_nodes_ex[j][ind][k]); - sample_nodes[ind].back().neighbor_offset += total_offset; + for (size_t k = 0; k < sample_node_ids_ex[j][ind].size(); k++) { + sample_node_ids[ind].push_back(sample_node_ids_ex[j][ind][k]); + sample_node_infos[ind].push_back(sample_node_infos_ex[j][ind][k]); + sample_node_infos[ind].back().neighbor_offset += total_offset; } size_t neighbor_size = sample_neighbors_ex[j][ind].size(); total_offset += neighbor_size; @@ -142,9 +146,10 @@ int AllInGpuGraphSampler::load_from_ssd(std::string path) { } for (size_t i = 0; i < tasks.size(); i++) tasks[i].get(); for (size_t i = 0; i < gpu_num; i++) { - sample_res[i].node_list = sample_nodes[i].data(); + sample_res[i].node_list = sample_node_ids[i].data(); + sample_res[i].node_info_list = sample_node_infos[i].data(); sample_res[i].neighbor_list = sample_neighbors[i].data(); - sample_res[i].node_size = sample_nodes[i].size(); + sample_res[i].node_size = sample_node_ids[i].size(); sample_res[i].neighbor_size = sample_neighbors[i].size(); } diff --git a/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu b/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu index ffaddf9965336..f1b332428b6c6 100644 --- a/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu +++ b/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu @@ -341,6 +341,7 @@ template class HashTable; template class HashTable; template class HashTable; template class HashTable; +template class HashTable; template class HashTable; template class HashTable; template class HashTable; @@ -362,6 +363,8 @@ template void HashTable::get(const long* d_keys, template void HashTable::get( const unsigned long* d_keys, int* d_vals, size_t len, cudaStream_t stream); +template void HashTable::get( + const unsigned long* d_keys, unsigned long* d_vals, size_t len, cudaStream_t stream); template void HashTable::get( const unsigned long* d_keys, long* d_vals, size_t len, cudaStream_t stream); template void HashTable::get( @@ -410,6 +413,10 @@ template void HashTable::insert( template void HashTable::insert( const long* d_keys, const unsigned int* d_vals, size_t len, cudaStream_t stream); + +template void HashTable::insert( + const unsigned long* d_keys, const unsigned long* d_vals, size_t len, + cudaStream_t stream); template void HashTable:: dump_to_cpu(int devid, cudaStream_t stream); diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h b/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h index 9531d74adbaa4..396bea4d55503 
100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h +++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h @@ -17,6 +17,7 @@ limitations under the License. */ #include "paddle/fluid/framework/fleet/heter_ps/feature_value.h" #include "paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.h" #include "paddle/fluid/platform/device_context.h" +#include "paddle/fluid/framework/fleet/heter_ps/gpu_graph_utils.h" #ifdef PADDLE_WITH_XPU_KP #include "paddle/fluid/platform/device/xpu/xpu_info.h" #endif @@ -153,9 +154,9 @@ void HeterComm::memory_copy( DstPlace dst_place, void* dst, SrcPlace src_place, const void* src, size_t count, StreamType stream) { #if defined(PADDLE_WITH_CUDA) - cudaMemcpyAsync(dst, src, count, cudaMemcpyDefault, stream); + CUDA_CHECK(cudaMemcpyAsync(dst, src, count, cudaMemcpyDefault, stream)); if (stream == 0) { - cudaStreamSynchronize(0); + CUDA_CHECK(cudaStreamSynchronize(0)); } #elif defined(PADDLE_WITH_XPU_KP) memory::Copy(dst_place, dst, src_place, src, count); @@ -304,36 +305,36 @@ void HeterComm::walk_to_dest( auto& node = path_[start_index][i].nodes_[0]; CopyTask t(&path_[start_index][i], 0); que.push(t); - cudaMemcpyAsync(node.key_storage, + CUDA_CHECK(cudaMemcpyAsync(node.key_storage, reinterpret_cast(src_key + h_left[i]), - node.key_bytes_len, cudaMemcpyDefault, node.in_stream); + node.key_bytes_len, cudaMemcpyDefault, node.in_stream)); if (need_copy_val) { - cudaMemcpyAsync(node.val_storage, + CUDA_CHECK(cudaMemcpyAsync(node.val_storage, src_val + uint64_t(h_left[i]) * uint64_t(val_size), - node.val_bytes_len, cudaMemcpyDefault, node.in_stream); + node.val_bytes_len, cudaMemcpyDefault, node.in_stream)); } } while (!que.empty()) { CopyTask& cur_task = que.front(); que.pop(); if (cur_task.path->nodes_[cur_task.step].sync) { - cudaStreamSynchronize(cur_task.path->nodes_[cur_task.step].in_stream); + CUDA_CHECK(cudaStreamSynchronize(cur_task.path->nodes_[cur_task.step].in_stream)); } if (cur_task.step != cur_task.path->nodes_.size() - 1) { int cur_step = cur_task.step; CopyTask c(cur_task.path, cur_step + 1); que.push(c); - cudaMemcpyAsync(cur_task.path->nodes_[cur_step + 1].key_storage, + CUDA_CHECK(cudaMemcpyAsync(cur_task.path->nodes_[cur_step + 1].key_storage, cur_task.path->nodes_[cur_step].key_storage, cur_task.path->nodes_[cur_step + 1].key_bytes_len, cudaMemcpyDefault, - cur_task.path->nodes_[cur_step + 1].in_stream); + cur_task.path->nodes_[cur_step + 1].in_stream)); if (need_copy_val) { - cudaMemcpyAsync(cur_task.path->nodes_[cur_step + 1].val_storage, + CUDA_CHECK(cudaMemcpyAsync(cur_task.path->nodes_[cur_step + 1].val_storage, cur_task.path->nodes_[cur_step].val_storage, cur_task.path->nodes_[cur_step + 1].val_bytes_len, cudaMemcpyDefault, - cur_task.path->nodes_[cur_step + 1].in_stream); + cur_task.path->nodes_[cur_step + 1].in_stream)); } } } @@ -351,17 +352,17 @@ void HeterComm::walk_to_src( int cur_step = path_[start_index][i].nodes_.size() - 1; auto& node = path_[start_index][i].nodes_[cur_step]; if (cur_step == 0) { - cudaMemcpyAsync(src_val + uint64_t(h_left[i]) * val_size, + CUDA_CHECK(cudaMemcpyAsync(src_val + uint64_t(h_left[i]) * val_size, node.val_storage, node.val_bytes_len, cudaMemcpyDefault, - node.out_stream); + node.out_stream)); } else { CopyTask t(&path_[start_index][i], cur_step - 1); que.push(t); - cudaMemcpyAsync(path_[start_index][i].nodes_[cur_step - 1].val_storage, + CUDA_CHECK(cudaMemcpyAsync(path_[start_index][i].nodes_[cur_step - 1].val_storage, node.val_storage, path_[start_index][i].nodes_[cur_step - 
1].val_bytes_len, cudaMemcpyDefault, - path_[start_index][i].nodes_[cur_step - 1].out_stream); + path_[start_index][i].nodes_[cur_step - 1].out_stream)); } } while (!que.empty()) { @@ -374,18 +375,18 @@ void HeterComm::walk_to_src( if (cur_step > 0) { CopyTask c(cur_task.path, cur_step - 1); que.push(c); - cudaMemcpyAsync(cur_task.path->nodes_[cur_step - 1].val_storage, + CUDA_CHECK(cudaMemcpyAsync(cur_task.path->nodes_[cur_step - 1].val_storage, cur_task.path->nodes_[cur_step].val_storage, cur_task.path->nodes_[cur_step - 1].val_bytes_len, cudaMemcpyDefault, - cur_task.path->nodes_[cur_step - 1].out_stream); + cur_task.path->nodes_[cur_step - 1].out_stream)); } else if (cur_step == 0) { int end_index = cur_task.path->nodes_.back().dev_num; - cudaMemcpyAsync(src_val + uint64_t(h_left[end_index]) * val_size, + CUDA_CHECK(cudaMemcpyAsync(src_val + uint64_t(h_left[end_index]) * val_size, cur_task.path->nodes_[cur_step].val_storage, cur_task.path->nodes_[cur_step].val_bytes_len, cudaMemcpyDefault, - cur_task.path->nodes_[cur_step].out_stream); + cur_task.path->nodes_[cur_step].out_stream)); } } } @@ -511,7 +512,7 @@ void HeterComm::build_ps( if (offset == -1) offset = dev_num; tables_[offset]->insert( reinterpret_cast(d_key_bufs[cur_stream]->ptr()), - reinterpret_cast(d_val_bufs[cur_stream]->ptr()), tmp_len, + reinterpret_cast(d_val_bufs[cur_stream]->ptr()), (size_t)tmp_len, cur_use_stream); cur_stream += 1; diff --git a/paddle/fluid/framework/fleet/heter_ps/test_graph.cu b/paddle/fluid/framework/fleet/heter_ps/test_graph.cu index 06c7026eb51ca..cda7202192b33 100644 --- a/paddle/fluid/framework/fleet/heter_ps/test_graph.cu +++ b/paddle/fluid/framework/fleet/heter_ps/test_graph.cu @@ -48,15 +48,16 @@ TEST(TEST_FLEET, graph_comm) { } std::vector neighbor_offset(gpu_count, 0), node_index(gpu_count, 0); for (int i = 0; i < graph_list.size(); i++) { - graph_list[i].node_list = new GpuPsGraphNode[graph_list[i].node_size]; + graph_list[i].node_list = new uint64_t[graph_list[i].node_size]; + graph_list[i].node_info_list = new GpuPsNodeInfo[graph_list[i].node_size]; graph_list[i].neighbor_list = new int64_t[graph_list[i].neighbor_size]; } for (int i = 0; i < node_count; i++) { ind = i % gpu_count; - graph_list[ind].node_list[node_index[ind]].node_id = i; - graph_list[ind].node_list[node_index[ind]].neighbor_offset = + graph_list[ind].node_list[node_index[ind]] = i; + graph_list[ind].node_info_list[node_index[ind]].neighbor_offset = neighbor_offset[ind]; - graph_list[ind].node_list[node_index[ind]].neighbor_size = + graph_list[ind].node_info_list[node_index[ind]].neighbor_size = neighbors[i].size(); for (auto x : neighbors[i]) { graph_list[ind].neighbor_list[neighbor_offset[ind]++] = x;
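The following is an illustrative sketch only and is not part of this patch: it shows how the split node_list / node_info_list layout introduced in gpu_graph_node.h can be filled on the CPU side for a toy graph, using only the GpuPsCommGraph interface visible in this diff (init_on_cpu, node_list, node_info_list, neighbor_list, release_on_cpu). The function and variable names build_toy_graph, ids and adj are hypothetical.

#include <cstdint>
#include <vector>
#include "paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h"

paddle::framework::GpuPsCommGraph build_toy_graph() {
  // toy adjacency: 0 -> {1, 2}, 1 -> {2}, 2 -> {}
  const std::vector<uint64_t> ids = {0, 1, 2};
  const std::vector<std::vector<uint64_t>> adj = {{1, 2}, {2}, {}};

  int64_t neighbor_size = 0;
  for (const auto &n : adj) neighbor_size += n.size();

  paddle::framework::GpuPsCommGraph g;
  // allocates node_list, node_info_list and neighbor_list on the host
  g.init_on_cpu(neighbor_size, ids.size());

  int64_t offset = 0;
  for (size_t i = 0; i < ids.size(); ++i) {
    g.node_list[i] = ids[i];                       // node id only
    g.node_info_list[i].neighbor_offset = offset;  // per-node metadata lives
    g.node_info_list[i].neighbor_size = adj[i].size();  // in a separate array
    for (auto nb : adj[i]) g.neighbor_list[offset++] = nb;
  }
  return g;  // call release_on_cpu() to free the three host arrays when done
}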