From 8ee52d94059ae93aebcfdc27c6ddb59a0072dfc8 Mon Sep 17 00:00:00 2001
From: root
Date: Wed, 10 May 2023 12:26:17 +0800
Subject: [PATCH] move function

---
 paddle/fluid/framework/data_feed.cu | 1467 +++++++++++++--------------
 1 file changed, 707 insertions(+), 760 deletions(-)

diff --git a/paddle/fluid/framework/data_feed.cu b/paddle/fluid/framework/data_feed.cu
index f353423ea9320..a4bf2faf38afb 100644
--- a/paddle/fluid/framework/data_feed.cu
+++ b/paddle/fluid/framework/data_feed.cu
@@ -2541,614 +2541,205 @@ uint64_t CopyUniqueNodes(
   return 0;
 }
 
-int FillWalkBuf(
-    const std::vector<uint64_t> &h_device_keys_len,
-    const std::vector<std::shared_ptr<phi::Allocation>> &d_device_keys,
-    const std::vector<std::vector<int>> &meta_path,
-    const GraphDataGeneratorConfig &conf,
-    bool *epoch_finish_ptr,
-    uint64_t *copy_unique_len_ptr,
-    const paddle::platform::Place &place,
-    const std::vector<int> &first_node_type,
-    std::unordered_map<int, size_t> *node_type_start_ptr,
-    std::set<int> *finish_node_type_ptr,
-    uint64_t *d_walk,  // output
-    uint8_t *d_walk_ntype,
-    std::shared_ptr<phi::Allocation> *d_uniq_node_num,
-    int *d_random_row_ptr,
-    int *d_random_row_col_shift_ptr,
-    phi::DenseTensor *multi_node_sync_stat_ptr,
-    std::vector<uint64_t> *host_vec_ptr,
-    int *total_row_ptr,
-    size_t *jump_rows_ptr,
-    int *shuffle_seed_ptr,
-    HashTable<uint64_t, uint64_t> *table,
-    BufState *buf_state,
-    cudaStream_t stream);
-
-int FillWalkBufMultiPath(
-    const std::vector<uint64_t> &h_device_keys_len,
-    const std::vector<std::shared_ptr<phi::Allocation>> &d_device_keys_ptr,
-    const std::vector<std::vector<int>> &meta_path,
-    const GraphDataGeneratorConfig &conf,
-    bool *epoch_finish_ptr,
-    uint64_t *copy_unique_len_ptr,
-    const paddle::platform::Place &place,
-    const std::vector<int> &first_node_type,
-    std::unordered_map<int, size_t> *node_type_start_ptr,
-    uint64_t *d_walk,  // output
-    uint8_t *d_walk_ntype,
-    std::shared_ptr<phi::Allocation> *d_uniq_node_num,
-    int *d_random_row_ptr,
-    int *d_random_row_col_shift_ptr,
-    std::vector<uint64_t> *host_vec_ptr,
-    int *total_row_ptr,
-    size_t *jump_rows_ptr,
-    int *shuffle_seed_ptr,
-    uint64_t *d_train_metapath_keys,
-    uint64_t *h_train_metapath_keys_len_ptr,
-    HashTable<uint64_t, uint64_t> *table,
-    BufState *buf_state,
-    cudaStream_t stream);
-
 int multi_node_sync_sample(int flag,
                            const ncclRedOp_t &op,
                            const paddle::platform::Place &place,
-                           phi::DenseTensor *multi_node_sync_stat_ptr);
+                           phi::DenseTensor *multi_node_sync_stat_ptr) {
+  if (flag < 0 && flag > 2) {
+    VLOG(0) << "invalid flag! 
" << flag; + assert(false); + return -1; + } -void GraphDataGenerator::DoWalkandSage() { - int device_id = place_.GetDeviceId(); - debug_gpu_memory_info(device_id, "DoWalkandSage start"); - platform::CUDADeviceGuard guard(conf_.gpuid); - auto gpu_graph_ptr = GraphGpuWrapper::GetInstance(); - if (conf_.gpu_graph_training) { - // train - bool train_flag; - uint8_t *walk_ntype = NULL; - if (conf_.need_walk_ntype) { - walk_ntype = reinterpret_cast(d_walk_ntype_->ptr()); - } + int ret = 0; +#if defined(PADDLE_WITH_CUDA) && defined(PADDLE_WITH_GPU_GRAPH) + int *stat_ptr = multi_node_sync_stat_ptr->data(); + auto comm = platform::NCCLCommContext::Instance().Get(0, place.GetDeviceId()); + auto stream = comm->stream(); + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce( + &stat_ptr[flag], &stat_ptr[3], 1, ncclInt, op, comm->comm(), stream)); + PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpyAsync(&ret, // output + &stat_ptr[3], + sizeof(int), + cudaMemcpyDeviceToHost, + stream)); + PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream)); +#endif + return ret; +} - if (FLAGS_graph_metapath_split_opt) { - train_flag = FillWalkBufMultiPath( - h_device_keys_len_, - d_device_keys_, - gpu_graph_ptr->meta_path_, - conf_, - &epoch_finish_, - ©_unique_len_, - place_, - gpu_graph_ptr->first_node_type_, - &(gpu_graph_ptr->node_type_start_[conf_.gpuid]), - reinterpret_cast(d_walk_->ptr()), - walk_ntype, - &d_uniq_node_num_, - reinterpret_cast(d_random_row_->ptr()), - reinterpret_cast(d_random_row_col_shift_->ptr()), - &host_vec_, - &total_row_, - &jump_rows_, - &shuffle_seed_, - reinterpret_cast(d_train_metapath_keys_->ptr()), - &h_train_metapath_keys_len_, - table_, - &buf_state_, - sample_stream_); - } else { - train_flag = - FillWalkBuf(h_device_keys_len_, - d_device_keys_, - gpu_graph_ptr->meta_path_, - conf_, - &epoch_finish_, - ©_unique_len_, - place_, - gpu_graph_ptr->first_node_type_, - &(gpu_graph_ptr->node_type_start_[conf_.gpuid]), - &(gpu_graph_ptr->finish_node_type_[conf_.gpuid]), - reinterpret_cast(d_walk_->ptr()), - walk_ntype, - &d_uniq_node_num_, - reinterpret_cast(d_random_row_->ptr()), - reinterpret_cast(d_random_row_col_shift_->ptr()), - &multi_node_sync_stat_, - &host_vec_, - &total_row_, - &jump_rows_, - &shuffle_seed_, - table_, - &buf_state_, - sample_stream_); - } +int FillWalkBuf(const std::vector &h_device_keys_len, + const std::vector> + &d_device_keys, // input + const std::vector> &meta_path, // input + const GraphDataGeneratorConfig &conf, + bool *epoch_finish_ptr, + uint64_t *copy_unique_len_ptr, + const paddle::platform::Place &place, + const std::vector &first_node_type, + std::unordered_map *node_type_start_ptr, + std::set *finish_node_type_ptr, + uint64_t *walk, // output + uint8_t *walk_ntype, + std::shared_ptr *d_uniq_node_num, + int *d_random_row, + int *d_random_row_col_shift, + phi::DenseTensor *multi_node_sync_stat_ptr, + std::vector *host_vec_ptr, + int *total_row_ptr, + size_t *jump_rows_ptr, + int *shuffle_seed_ptr, + HashTable *table, + BufState *buf_state, + cudaStream_t stream) { + platform::CUDADeviceGuard guard(conf.gpuid); - if (conf_.sage_mode) { - sage_batch_num_ = 0; - if (train_flag) { - int total_instance = 0, uniq_instance = 0; - bool ins_pair_flag = true; - int sage_pass_end = 0; - uint64_t *ins_buf, *ins_cursor; - while (ins_pair_flag) { - int res = 0; - while (ins_buf_pair_len_ < conf_.batch_size) { - int32_t *pair_label_buf = NULL; - if (d_pair_label_buf_ != NULL) { - pair_label_buf = - reinterpret_cast(d_pair_label_buf_->ptr()); - } - res = 
FillInsBuf(d_walk_, - d_walk_ntype_, - conf_, - d_random_row_, - d_random_row_col_shift_, - &buf_state_, - reinterpret_cast(d_ins_buf_->ptr()), - pair_label_buf, - reinterpret_cast(d_pair_num_->ptr()), - &ins_buf_pair_len_, - sample_stream_); - if (res == -1) { - if (ins_buf_pair_len_ == 0) { - if (conf_.is_multi_node) { - sage_pass_end = 1; - if (total_row_ != 0) { - buf_state_.Reset(total_row_); - VLOG(1) << "reset buf state to make batch num equal in " - "multi node"; - } - } else { - ins_pair_flag = false; - break; - } - } else { - break; - } - } - } + //////// + uint64_t *h_walk; + if (conf.debug_mode) { + h_walk = new uint64_t[conf.buf_size]; + } + /////// + auto gpu_graph_ptr = GraphGpuWrapper::GetInstance(); + cudaMemsetAsync(walk, 0, conf.buf_size * sizeof(uint64_t), stream); + if (conf.need_walk_ntype) { + cudaMemsetAsync(walk_ntype, 0, conf.buf_size * sizeof(uint8_t), stream); + } + int sample_times = 0; + int i = 0; + *total_row_ptr = 0; - // check whether reach sage pass end - if (conf_.is_multi_node) { - int res = multi_node_sync_sample( - sage_pass_end, ncclProd, place_, &multi_node_sync_stat_); - if (res) { - ins_pair_flag = false; - } - } + std::vector> d_sampleidx2rows; + d_sampleidx2rows.push_back(memory::AllocShared( + place, + conf.once_max_sample_keynum * sizeof(int), + phi::Stream(reinterpret_cast(stream)))); + d_sampleidx2rows.push_back(memory::AllocShared( + place, + conf.once_max_sample_keynum * sizeof(int), + phi::Stream(reinterpret_cast(stream)))); + int cur_sampleidx2row = 0; - if (!ins_pair_flag) { - break; - } + // 获取全局采样状态 + auto &type_to_index = gpu_graph_ptr->get_graph_type_to_index(); + auto &cursor = gpu_graph_ptr->cursor_[conf.thread_id]; + size_t node_type_len = first_node_type.size(); + int remain_size = conf.buf_size - conf.walk_degree * + conf.once_sample_startid_len * + conf.walk_len; + int total_samples = 0; - total_instance = ins_buf_pair_len_ < conf_.batch_size - ? ins_buf_pair_len_ - : conf_.batch_size; - total_instance *= 2; + // Definition of variables related to multi machine sampling + int switch_flag = EVENT_NOT_SWTICH; // Mark whether the local machine needs + // to switch metapath + int switch_command = EVENT_NOT_SWTICH; // Mark whether to switch metapath, + // after multi node sync + int sample_flag = EVENT_CONTINUE_SAMPLE; // Mark whether the local machine + // needs to continue sampling + int sample_command = + EVENT_CONTINUE_SAMPLE; // Mark whether to continue sampling, after multi + // node sync - ins_buf = reinterpret_cast(d_ins_buf_->ptr()); - ins_cursor = ins_buf + ins_buf_pair_len_ * 2 - total_instance; - auto inverse = memory::AllocShared( - place_, - total_instance * sizeof(int), - phi::Stream(reinterpret_cast(sample_stream_))); - int *inverse_ptr = reinterpret_cast(inverse->ptr()); - auto final_sage_nodes = GenerateSampleGraph(ins_cursor, - total_instance, - &uniq_instance, - inverse_ptr, - conf_, - &graph_edges_vec_, - &edges_split_num_vec_, - &edge_type_graph_, - place_, - sample_stream_); - uint64_t *final_sage_nodes_ptr = - reinterpret_cast(final_sage_nodes->ptr()); - if (conf_.get_degree) { - auto node_degrees = GetNodeDegree(final_sage_nodes_ptr, - uniq_instance, - conf_, - place_, - sample_stream_); - node_degree_vec_.emplace_back(node_degrees); - } + // In the case of a single machine, for scenarios where the d_walk buffer is + // full, epoch sampling ends, and metapath switching occurs, direct decisions + // are made to end the current card sampling or perform metapath switching. 
+ // However, in the case of multiple machines, further decisions can only be + // made after waiting for the multiple machines to synchronize and exchange + // information. + while (1) { + if (i > remain_size) { + // scenarios 1: d_walk is full + if (FLAGS_enable_graph_multi_node_sampling) { + sample_flag = EVENT_WALKBUF_FULL; + } else { + break; + } + } - if (conf_.enable_pair_label) { - auto pair_label = memory::AllocShared( - place_, - total_instance / 2 * sizeof(int), - phi::Stream(reinterpret_cast(sample_stream_))); - int32_t *pair_label_buf = - reinterpret_cast(d_pair_label_buf_->ptr()); - int32_t *pair_label_cursor = - pair_label_buf + ins_buf_pair_len_ - total_instance / 2; - cudaMemcpyAsync(pair_label->ptr(), - pair_label_cursor, - sizeof(int32_t) * total_instance / 2, - cudaMemcpyDeviceToDevice, - sample_stream_); - pair_label_vec_.emplace_back(pair_label); - } + int cur_node_idx = cursor % node_type_len; + int node_type = first_node_type[cur_node_idx]; + auto &path = meta_path[cur_node_idx]; + size_t start = (*node_type_start_ptr)[node_type]; + int type_index = type_to_index[node_type]; + size_t device_key_size = h_device_keys_len[type_index]; + uint64_t *d_type_keys = + reinterpret_cast(d_device_keys[type_index]->ptr()); + int tmp_len = start + conf.once_sample_startid_len > device_key_size + ? device_key_size - start + : conf.once_sample_startid_len; + VLOG(2) << "choose node_type: " << node_type + << " cur_node_idx: " << cur_node_idx + << " meta_path.size: " << meta_path.size() + << " key_size: " << device_key_size << " start: " << start + << " tmp_len: " << tmp_len; + if (tmp_len == 0) { + finish_node_type_ptr->insert(node_type); + if (finish_node_type_ptr->size() == node_type_start_ptr->size()) { + // scenarios 2: epoch finish + if (FLAGS_enable_graph_multi_node_sampling) { + sample_flag = EVENT_FINISH_EPOCH; + } else { + cursor = 0; + *epoch_finish_ptr = true; + break; + } + } - cudaStreamSynchronize(sample_stream_); - if (FLAGS_gpugraph_storage_mode != GpuGraphStorageMode::WHOLE_HBM) { - uint64_t *final_sage_nodes_ptr = - reinterpret_cast(final_sage_nodes->ptr()); - InsertTable(final_sage_nodes_ptr, - uniq_instance, - &d_uniq_node_num_, - conf_, - ©_unique_len_, - place_, - table_, - &host_vec_, - sample_stream_); - } - final_sage_nodes_vec_.emplace_back(final_sage_nodes); - inverse_vec_.emplace_back(inverse); - uniq_instance_vec_.emplace_back(uniq_instance); - total_instance_vec_.emplace_back(total_instance); - ins_buf_pair_len_ -= total_instance / 2; - sage_batch_num_ += 1; + // scenarios 3: switch metapath + if (FLAGS_enable_graph_multi_node_sampling) { + if (sample_flag == EVENT_CONTINUE_SAMPLE) { + // Switching only occurs when multi machine sampling continues + switch_flag = EVENT_SWTICH_METAPATH; } - uint64_t h_uniq_node_num = CopyUniqueNodes(table_, - copy_unique_len_, - place_, - d_uniq_node_num_, - &host_vec_, - sample_stream_); - VLOG(1) << "train sage_batch_num: " << sage_batch_num_; + } else { + cursor += 1; + continue; } } - } else { - // infer - bool infer_flag = FillInferBuf(); - if (conf_.sage_mode) { - sage_batch_num_ = 0; - if (infer_flag) { - // Set new batch size for multi_node - if (conf_.is_multi_node) { - int new_batch_size = dynamic_adjust_batch_num_for_sage(); - conf_.batch_size = new_batch_size; - } - - int total_instance = 0, uniq_instance = 0; - total_instance = - (infer_node_start_ + conf_.batch_size <= infer_node_end_) - ? 
conf_.batch_size - : infer_node_end_ - infer_node_start_; - total_instance *= 2; - while (total_instance != 0) { - uint64_t *d_type_keys = reinterpret_cast( - d_device_keys_[infer_cursor_]->ptr()); - d_type_keys += infer_node_start_; - infer_node_start_ += total_instance / 2; - auto node_buf = memory::AllocShared( - place_, - total_instance * sizeof(uint64_t), - phi::Stream(reinterpret_cast(sample_stream_))); - int64_t *node_buf_ptr = reinterpret_cast(node_buf->ptr()); - CopyDuplicateKeys<<>>( - node_buf_ptr, d_type_keys, total_instance / 2); - uint64_t *node_buf_ptr_ = - reinterpret_cast(node_buf->ptr()); - auto inverse = memory::AllocShared( - place_, - total_instance * sizeof(int), - phi::Stream(reinterpret_cast(sample_stream_))); - int *inverse_ptr = reinterpret_cast(inverse->ptr()); - auto final_sage_nodes = GenerateSampleGraph(node_buf_ptr_, - total_instance, - &uniq_instance, - inverse_ptr, - conf_, - &graph_edges_vec_, - &edges_split_num_vec_, - &edge_type_graph_, - place_, - sample_stream_); - uint64_t *final_sage_nodes_ptr = - reinterpret_cast(final_sage_nodes->ptr()); - if (conf_.get_degree) { - auto node_degrees = GetNodeDegree(final_sage_nodes_ptr, - uniq_instance, - conf_, - place_, - sample_stream_); - node_degree_vec_.emplace_back(node_degrees); - } - cudaStreamSynchronize(sample_stream_); - if (FLAGS_gpugraph_storage_mode != GpuGraphStorageMode::WHOLE_HBM) { - uint64_t *final_sage_nodes_ptr = - reinterpret_cast(final_sage_nodes->ptr()); - InsertTable(final_sage_nodes_ptr, - uniq_instance, - &d_uniq_node_num_, - conf_, - ©_unique_len_, - place_, - table_, - &host_vec_, - sample_stream_); - } - final_sage_nodes_vec_.emplace_back(final_sage_nodes); - inverse_vec_.emplace_back(inverse); - uniq_instance_vec_.emplace_back(uniq_instance); - total_instance_vec_.emplace_back(total_instance); - sage_batch_num_ += 1; - total_instance = - (infer_node_start_ + conf_.batch_size <= infer_node_end_) - ? 
conf_.batch_size - : infer_node_end_ - infer_node_start_; - total_instance *= 2; - } + // Perform synchronous information exchange between multiple machines + // to decide whether to continue sampling + if (FLAGS_enable_graph_multi_node_sampling) { + switch_command = multi_node_sync_sample( + switch_flag, ncclProd, place, multi_node_sync_stat_ptr); + VLOG(2) << "gpuid:" << conf.gpuid << " multi node sample sync" + << " switch_flag:" << switch_flag << "," << switch_command; + if (switch_command) { + cursor += 1; + switch_flag = EVENT_NOT_SWTICH; + continue; + } - uint64_t h_uniq_node_num = CopyUniqueNodes(table_, - copy_unique_len_, - place_, - d_uniq_node_num_, - &host_vec_, - sample_stream_); - VLOG(1) << "infer sage_batch_num: " << sage_batch_num_; + sample_command = multi_node_sync_sample( + sample_flag, ncclMax, place, multi_node_sync_stat_ptr); + VLOG(2) << "gpuid:" << conf.gpuid << " multi node sample sync" + << " sample_flag:" << sample_flag << "," << sample_command; + if (sample_command == EVENT_FINISH_EPOCH) { + // end sampling current epoch + cursor = 0; + *epoch_finish_ptr = true; + VLOG(0) << "sample epoch finish!"; + break; + } else if (sample_command == EVENT_WALKBUF_FULL) { + // end sampling current pass + VLOG(0) << "sample pass finish!"; + break; + } else if (sample_command == EVENT_CONTINUE_SAMPLE) { + // continue sampling + } else { + // shouldn't come here + VLOG(0) << "should not come here, sample_command:" << sample_command; + assert(false); } } - } - debug_gpu_memory_info(device_id, "DoWalkandSage end"); -} -void GraphDataGenerator::clear_gpu_mem() { - platform::CUDADeviceGuard guard(conf_.gpuid); - delete table_; -} - -int GraphDataGenerator::FillInferBuf() { - platform::CUDADeviceGuard guard(conf_.gpuid); - auto gpu_graph_ptr = GraphGpuWrapper::GetInstance(); - auto &global_infer_node_type_start = - gpu_graph_ptr->global_infer_node_type_start_[conf_.gpuid]; - auto &infer_cursor = gpu_graph_ptr->infer_cursor_[conf_.thread_id]; - total_row_ = 0; - if (infer_cursor < h_device_keys_len_.size()) { - while (global_infer_node_type_start[infer_cursor] >= - h_device_keys_len_[infer_cursor]) { - infer_cursor++; - if (infer_cursor >= h_device_keys_len_.size()) { - return 0; - } - } - if (!infer_node_type_index_set_.empty()) { - while (infer_cursor < h_device_keys_len_.size()) { - if (infer_node_type_index_set_.find(infer_cursor) == - infer_node_type_index_set_.end()) { - VLOG(2) << "Skip cursor[" << infer_cursor << "]"; - infer_cursor++; - continue; - } else { - VLOG(2) << "Not skip cursor[" << infer_cursor << "]"; - break; - } - } - if (infer_cursor >= h_device_keys_len_.size()) { - return 0; - } - } - - size_t device_key_size = h_device_keys_len_[infer_cursor]; - total_row_ = - (global_infer_node_type_start[infer_cursor] + conf_.buf_size <= - device_key_size) - ? 
conf_.buf_size - : device_key_size - global_infer_node_type_start[infer_cursor]; - - uint64_t *d_type_keys = - reinterpret_cast(d_device_keys_[infer_cursor]->ptr()); - if (!conf_.sage_mode) { - host_vec_.resize(total_row_); - cudaMemcpyAsync(host_vec_.data(), - d_type_keys + global_infer_node_type_start[infer_cursor], - sizeof(uint64_t) * total_row_, - cudaMemcpyDeviceToHost, - sample_stream_); - cudaStreamSynchronize(sample_stream_); - } - VLOG(1) << "cursor: " << infer_cursor - << " start: " << global_infer_node_type_start[infer_cursor] - << " num: " << total_row_; - infer_node_start_ = global_infer_node_type_start[infer_cursor]; - global_infer_node_type_start[infer_cursor] += total_row_; - infer_node_end_ = global_infer_node_type_start[infer_cursor]; - infer_cursor_ = infer_cursor; - return 1; - } - return 0; -} - -void GraphDataGenerator::ClearSampleState() { - auto gpu_graph_ptr = GraphGpuWrapper::GetInstance(); - auto &finish_node_type = gpu_graph_ptr->finish_node_type_[conf_.gpuid]; - auto &node_type_start = gpu_graph_ptr->node_type_start_[conf_.gpuid]; - finish_node_type.clear(); - for (auto iter = node_type_start.begin(); iter != node_type_start.end(); - iter++) { - iter->second = 0; - } -} - -int FillWalkBuf(const std::vector &h_device_keys_len, - const std::vector> - &d_device_keys, // input - const std::vector> &meta_path, // input - const GraphDataGeneratorConfig &conf, - bool *epoch_finish_ptr, - uint64_t *copy_unique_len_ptr, - const paddle::platform::Place &place, - const std::vector &first_node_type, - std::unordered_map *node_type_start_ptr, - std::set *finish_node_type_ptr, - uint64_t *walk, // output - uint8_t *walk_ntype, - std::shared_ptr *d_uniq_node_num, - int *d_random_row, - int *d_random_row_col_shift, - phi::DenseTensor *multi_node_sync_stat_ptr, - std::vector *host_vec_ptr, - int *total_row_ptr, - size_t *jump_rows_ptr, - int *shuffle_seed_ptr, - HashTable *table, - BufState *buf_state, - cudaStream_t stream) { - platform::CUDADeviceGuard guard(conf.gpuid); - - //////// - uint64_t *h_walk; - if (conf.debug_mode) { - h_walk = new uint64_t[conf.buf_size]; - } - /////// - auto gpu_graph_ptr = GraphGpuWrapper::GetInstance(); - cudaMemsetAsync(walk, 0, conf.buf_size * sizeof(uint64_t), stream); - if (conf.need_walk_ntype) { - cudaMemsetAsync(walk_ntype, 0, conf.buf_size * sizeof(uint8_t), stream); - } - int sample_times = 0; - int i = 0; - *total_row_ptr = 0; - - std::vector> d_sampleidx2rows; - d_sampleidx2rows.push_back(memory::AllocShared( - place, - conf.once_max_sample_keynum * sizeof(int), - phi::Stream(reinterpret_cast(stream)))); - d_sampleidx2rows.push_back(memory::AllocShared( - place, - conf.once_max_sample_keynum * sizeof(int), - phi::Stream(reinterpret_cast(stream)))); - int cur_sampleidx2row = 0; - - // 获取全局采样状态 - auto &type_to_index = gpu_graph_ptr->get_graph_type_to_index(); - auto &cursor = gpu_graph_ptr->cursor_[conf.thread_id]; - size_t node_type_len = first_node_type.size(); - int remain_size = conf.buf_size - conf.walk_degree * - conf.once_sample_startid_len * - conf.walk_len; - int total_samples = 0; - - // Definition of variables related to multi machine sampling - int switch_flag = EVENT_NOT_SWTICH; // Mark whether the local machine needs - // to switch metapath - int switch_command = EVENT_NOT_SWTICH; // Mark whether to switch metapath, - // after multi node sync - int sample_flag = EVENT_CONTINUE_SAMPLE; // Mark whether the local machine - // needs to continue sampling - int sample_command = - EVENT_CONTINUE_SAMPLE; // Mark whether to 
continue sampling, after multi - // node sync - - // In the case of a single machine, for scenarios where the d_walk buffer is - // full, epoch sampling ends, and metapath switching occurs, direct decisions - // are made to end the current card sampling or perform metapath switching. - // However, in the case of multiple machines, further decisions can only be - // made after waiting for the multiple machines to synchronize and exchange - // information. - while (1) { - if (i > remain_size) { - // scenarios 1: d_walk is full - if (FLAGS_enable_graph_multi_node_sampling) { - sample_flag = EVENT_WALKBUF_FULL; - } else { - break; - } - } - - int cur_node_idx = cursor % node_type_len; - int node_type = first_node_type[cur_node_idx]; - auto &path = meta_path[cur_node_idx]; - size_t start = (*node_type_start_ptr)[node_type]; - int type_index = type_to_index[node_type]; - size_t device_key_size = h_device_keys_len[type_index]; - uint64_t *d_type_keys = - reinterpret_cast(d_device_keys[type_index]->ptr()); - int tmp_len = start + conf.once_sample_startid_len > device_key_size - ? device_key_size - start - : conf.once_sample_startid_len; - VLOG(2) << "choose node_type: " << node_type - << " cur_node_idx: " << cur_node_idx - << " meta_path.size: " << meta_path.size() - << " key_size: " << device_key_size << " start: " << start - << " tmp_len: " << tmp_len; - if (tmp_len == 0) { - finish_node_type_ptr->insert(node_type); - if (finish_node_type_ptr->size() == node_type_start_ptr->size()) { - // scenarios 2: epoch finish - if (FLAGS_enable_graph_multi_node_sampling) { - sample_flag = EVENT_FINISH_EPOCH; - } else { - cursor = 0; - *epoch_finish_ptr = true; - break; - } - } - - // scenarios 3: switch metapath - if (FLAGS_enable_graph_multi_node_sampling) { - if (sample_flag == EVENT_CONTINUE_SAMPLE) { - // Switching only occurs when multi machine sampling continues - switch_flag = EVENT_SWTICH_METAPATH; - } - } else { - cursor += 1; - continue; - } - } - - // Perform synchronous information exchange between multiple machines - // to decide whether to continue sampling - if (FLAGS_enable_graph_multi_node_sampling) { - switch_command = multi_node_sync_sample( - switch_flag, ncclProd, place, multi_node_sync_stat_ptr); - VLOG(2) << "gpuid:" << conf.gpuid << " multi node sample sync" - << " switch_flag:" << switch_flag << "," << switch_command; - if (switch_command) { - cursor += 1; - switch_flag = EVENT_NOT_SWTICH; - continue; - } - - sample_command = multi_node_sync_sample( - sample_flag, ncclMax, place, multi_node_sync_stat_ptr); - VLOG(2) << "gpuid:" << conf.gpuid << " multi node sample sync" - << " sample_flag:" << sample_flag << "," << sample_command; - if (sample_command == EVENT_FINISH_EPOCH) { - // end sampling current epoch - cursor = 0; - *epoch_finish_ptr = true; - VLOG(0) << "sample epoch finish!"; - break; - } else if (sample_command == EVENT_WALKBUF_FULL) { - // end sampling current pass - VLOG(0) << "sample pass finish!"; - break; - } else if (sample_command == EVENT_CONTINUE_SAMPLE) { - // continue sampling - } else { - // shouldn't come here - VLOG(0) << "should not come here, sample_command:" << sample_command; - assert(false); - } - } - - int step = 1; - bool update = true; - uint64_t *cur_walk = walk + i; - uint8_t *cur_walk_ntype = NULL; - if (conf.need_walk_ntype) { - cur_walk_ntype = walk_ntype + i; - } + int step = 1; + bool update = true; + uint64_t *cur_walk = walk + i; + uint8_t *cur_walk_ntype = NULL; + if (conf.need_walk_ntype) { + cur_walk_ntype = walk_ntype + i; + } 
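+    // The two reductions above act as a small consensus protocol (sketch of
+    // the intended semantics, assuming the EVENT_* constants are the small
+    // non-negative ints that multi_node_sync_sample expects):
+    //   - ncclProd on switch_flag: the product is non-zero only if every
+    //     rank raised EVENT_SWTICH_METAPATH, so metapath switching happens
+    //     in lock step across ranks.
+    //   - ncclMax on sample_flag: the max propagates the strongest event
+    //     seen on any rank, e.g. one full d_walk buffer (EVENT_WALKBUF_FULL)
+    //     ends the pass for all ranks; EVENT_FINISH_EPOCH is assumed to
+    //     compare higher than EVENT_WALKBUF_FULL for that ordering to hold.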
NeighborSampleQuery q; q.initialize(conf.gpuid, @@ -3393,6 +2984,7 @@ int FillWalkBuf(const std::vector &h_device_keys_len, return *total_row_ptr != 0; } + int FillWalkBufMultiPath( const std::vector &h_device_keys_len, const std::vector> &d_device_keys_ptr, @@ -3533,172 +3125,554 @@ int FillWalkBufMultiPath( update = false; break; } - } - } + } + } + + FillOneStep(d_type_keys + start, + path[0], + cur_walk, + cur_walk_ntype, + tmp_len, + &sample_res, + conf.walk_degree, + step, + conf, + &d_sampleidx2rows, + &cur_sampleidx2row, + place, + stream); + ///////// + if (conf.debug_mode) { + cudaMemcpy(h_walk, + walk, + conf.buf_size * sizeof(uint64_t), + cudaMemcpyDeviceToHost); + for (int xx = 0; xx < conf.buf_size; xx++) { + VLOG(2) << "h_walk[" << xx << "]: " << h_walk[xx]; + } + } + + VLOG(2) << "sample, step=" << step << " sample_keys=" << tmp_len + << " sample_res_len=" << sample_res.total_sample_size; + + ///////// + step++; + size_t path_len = path.size(); + for (; step < conf.walk_len; step++) { + if (sample_res.total_sample_size == 0) { + VLOG(2) << "sample finish, step=" << step; + break; + } + auto sample_key_mem = sample_res.actual_val_mem; + uint64_t *sample_keys_ptr = + reinterpret_cast(sample_key_mem->ptr()); + int edge_type_id = path[(step - 1) % path_len]; + VLOG(2) << "sample edge type: " << edge_type_id << " step: " << step; + q.initialize(conf.gpuid, + edge_type_id, + (uint64_t)sample_keys_ptr, + 1, + sample_res.total_sample_size, + step); + int sample_key_len = sample_res.total_sample_size; + sample_res = gpu_graph_ptr->graph_neighbor_sample_v3( + q, false, true, conf.weighted_sample); + total_samples += sample_res.total_sample_size; + if (!conf.sage_mode) { + if (FLAGS_gpugraph_storage_mode != GpuGraphStorageMode::WHOLE_HBM) { + if (InsertTable(sample_res.actual_val, + sample_res.total_sample_size, + d_uniq_node_num, + conf, + copy_unique_len_ptr, + place, + table, + host_vec_ptr, + stream) != 0) { + VLOG(2) << "in step: " << step << ", table is full"; + update = false; + break; + } + } + } + FillOneStep(d_type_keys + start, + edge_type_id, + cur_walk, + cur_walk_ntype, + sample_key_len, + &sample_res, + 1, + step, + conf, + &d_sampleidx2rows, + &cur_sampleidx2row, + place, + stream); + if (conf.debug_mode) { + cudaMemcpy(h_walk, + walk, + conf.buf_size * sizeof(uint64_t), + cudaMemcpyDeviceToHost); + for (int xx = 0; xx < conf.buf_size; xx++) { + VLOG(2) << "h_walk[" << xx << "]: " << h_walk[xx]; + } + } + + VLOG(2) << "sample, step=" << step << " sample_keys=" << sample_key_len + << " sample_res_len=" << sample_res.total_sample_size; + } + // 此时更新全局采样状态 + if (update == true) { + cur_metapath_start = tmp_len + start; + i += *jump_rows_ptr * conf.walk_len; + *total_row_ptr += *jump_rows_ptr; + sample_times++; + } else { + VLOG(2) << "table is full, not update stat!"; + break; + } + } + buf_state->Reset(*total_row_ptr); + + paddle::memory::ThrustAllocator allocator(place, stream); + thrust::random::default_random_engine engine(*shuffle_seed_ptr); + const auto &exec_policy = thrust::cuda::par(allocator).on(stream); + thrust::counting_iterator cnt_iter(0); + thrust::shuffle_copy(exec_policy, + cnt_iter, + cnt_iter + *total_row_ptr, + thrust::device_pointer_cast(d_random_row), + engine); + + thrust::transform(exec_policy, + cnt_iter, + cnt_iter + *total_row_ptr, + thrust::device_pointer_cast(d_random_row_col_shift), + RandInt(0, conf.walk_len)); + + cudaStreamSynchronize(stream); + *shuffle_seed_ptr = engine(); + + if (conf.debug_mode) { + int *h_random_row = new 
int[*total_row_ptr + 10]; + cudaMemcpy(h_random_row, + d_random_row, + *total_row_ptr * sizeof(int), + cudaMemcpyDeviceToHost); + for (int xx = 0; xx < *total_row_ptr; xx++) { + VLOG(2) << "h_random_row[" << xx << "]: " << h_random_row[xx]; + } + delete[] h_random_row; + delete[] h_walk; + } + + if (!conf.sage_mode) { + uint64_t h_uniq_node_num = CopyUniqueNodes(table, + *copy_unique_len_ptr, + place, + *d_uniq_node_num, + host_vec_ptr, + stream); + VLOG(1) << "sample_times:" << sample_times + << ", d_walk_size:" << conf.buf_size << ", d_walk_offset:" << i + << ", total_rows:" << *total_row_ptr + << ", h_uniq_node_num:" << h_uniq_node_num + << ", total_samples:" << total_samples; + } else { + VLOG(1) << "sample_times:" << sample_times + << ", d_walk_size:" << conf.buf_size << ", d_walk_offset:" << i + << ", total_rows:" << *total_row_ptr + << ", total_samples:" << total_samples; + } + + return *total_row_ptr != 0; +} + +void GraphDataGenerator::DoWalkandSage() { + int device_id = place_.GetDeviceId(); + debug_gpu_memory_info(device_id, "DoWalkandSage start"); + platform::CUDADeviceGuard guard(conf_.gpuid); + auto gpu_graph_ptr = GraphGpuWrapper::GetInstance(); + if (conf_.gpu_graph_training) { + // train + bool train_flag; + uint8_t *walk_ntype = NULL; + if (conf_.need_walk_ntype) { + walk_ntype = reinterpret_cast(d_walk_ntype_->ptr()); + } + + if (FLAGS_graph_metapath_split_opt) { + train_flag = FillWalkBufMultiPath( + h_device_keys_len_, + d_device_keys_, + gpu_graph_ptr->meta_path_, + conf_, + &epoch_finish_, + ©_unique_len_, + place_, + gpu_graph_ptr->first_node_type_, + &(gpu_graph_ptr->node_type_start_[conf_.gpuid]), + reinterpret_cast(d_walk_->ptr()), + walk_ntype, + &d_uniq_node_num_, + reinterpret_cast(d_random_row_->ptr()), + reinterpret_cast(d_random_row_col_shift_->ptr()), + &host_vec_, + &total_row_, + &jump_rows_, + &shuffle_seed_, + reinterpret_cast(d_train_metapath_keys_->ptr()), + &h_train_metapath_keys_len_, + table_, + &buf_state_, + sample_stream_); + } else { + train_flag = + FillWalkBuf(h_device_keys_len_, + d_device_keys_, + gpu_graph_ptr->meta_path_, + conf_, + &epoch_finish_, + ©_unique_len_, + place_, + gpu_graph_ptr->first_node_type_, + &(gpu_graph_ptr->node_type_start_[conf_.gpuid]), + &(gpu_graph_ptr->finish_node_type_[conf_.gpuid]), + reinterpret_cast(d_walk_->ptr()), + walk_ntype, + &d_uniq_node_num_, + reinterpret_cast(d_random_row_->ptr()), + reinterpret_cast(d_random_row_col_shift_->ptr()), + &multi_node_sync_stat_, + &host_vec_, + &total_row_, + &jump_rows_, + &shuffle_seed_, + table_, + &buf_state_, + sample_stream_); + } + + if (conf_.sage_mode) { + sage_batch_num_ = 0; + if (train_flag) { + int total_instance = 0, uniq_instance = 0; + bool ins_pair_flag = true; + int sage_pass_end = 0; + uint64_t *ins_buf, *ins_cursor; + while (ins_pair_flag) { + int res = 0; + while (ins_buf_pair_len_ < conf_.batch_size) { + int32_t *pair_label_buf = NULL; + if (d_pair_label_buf_ != NULL) { + pair_label_buf = + reinterpret_cast(d_pair_label_buf_->ptr()); + } + res = FillInsBuf(d_walk_, + d_walk_ntype_, + conf_, + d_random_row_, + d_random_row_col_shift_, + &buf_state_, + reinterpret_cast(d_ins_buf_->ptr()), + pair_label_buf, + reinterpret_cast(d_pair_num_->ptr()), + &ins_buf_pair_len_, + sample_stream_); + if (res == -1) { + if (ins_buf_pair_len_ == 0) { + if (conf_.is_multi_node) { + sage_pass_end = 1; + if (total_row_ != 0) { + buf_state_.Reset(total_row_); + VLOG(1) << "reset buf state to make batch num equal in " + "multi node"; + } + } else { + ins_pair_flag = 
false; + break; + } + } else { + break; + } + } + } + + // check whether reach sage pass end + if (conf_.is_multi_node) { + int res = multi_node_sync_sample( + sage_pass_end, ncclProd, place_, &multi_node_sync_stat_); + if (res) { + ins_pair_flag = false; + } + } + + if (!ins_pair_flag) { + break; + } + + total_instance = ins_buf_pair_len_ < conf_.batch_size + ? ins_buf_pair_len_ + : conf_.batch_size; + total_instance *= 2; + + ins_buf = reinterpret_cast(d_ins_buf_->ptr()); + ins_cursor = ins_buf + ins_buf_pair_len_ * 2 - total_instance; + auto inverse = memory::AllocShared( + place_, + total_instance * sizeof(int), + phi::Stream(reinterpret_cast(sample_stream_))); + int *inverse_ptr = reinterpret_cast(inverse->ptr()); + auto final_sage_nodes = GenerateSampleGraph(ins_cursor, + total_instance, + &uniq_instance, + inverse_ptr, + conf_, + &graph_edges_vec_, + &edges_split_num_vec_, + &edge_type_graph_, + place_, + sample_stream_); + uint64_t *final_sage_nodes_ptr = + reinterpret_cast(final_sage_nodes->ptr()); + if (conf_.get_degree) { + auto node_degrees = GetNodeDegree(final_sage_nodes_ptr, + uniq_instance, + conf_, + place_, + sample_stream_); + node_degree_vec_.emplace_back(node_degrees); + } + + if (conf_.enable_pair_label) { + auto pair_label = memory::AllocShared( + place_, + total_instance / 2 * sizeof(int), + phi::Stream(reinterpret_cast(sample_stream_))); + int32_t *pair_label_buf = + reinterpret_cast(d_pair_label_buf_->ptr()); + int32_t *pair_label_cursor = + pair_label_buf + ins_buf_pair_len_ - total_instance / 2; + cudaMemcpyAsync(pair_label->ptr(), + pair_label_cursor, + sizeof(int32_t) * total_instance / 2, + cudaMemcpyDeviceToDevice, + sample_stream_); + pair_label_vec_.emplace_back(pair_label); + } + + cudaStreamSynchronize(sample_stream_); + if (FLAGS_gpugraph_storage_mode != GpuGraphStorageMode::WHOLE_HBM) { + uint64_t *final_sage_nodes_ptr = + reinterpret_cast(final_sage_nodes->ptr()); + InsertTable(final_sage_nodes_ptr, + uniq_instance, + &d_uniq_node_num_, + conf_, + ©_unique_len_, + place_, + table_, + &host_vec_, + sample_stream_); + } + final_sage_nodes_vec_.emplace_back(final_sage_nodes); + inverse_vec_.emplace_back(inverse); + uniq_instance_vec_.emplace_back(uniq_instance); + total_instance_vec_.emplace_back(total_instance); + ins_buf_pair_len_ -= total_instance / 2; + sage_batch_num_ += 1; + } + uint64_t h_uniq_node_num = CopyUniqueNodes(table_, + copy_unique_len_, + place_, + d_uniq_node_num_, + &host_vec_, + sample_stream_); + VLOG(1) << "train sage_batch_num: " << sage_batch_num_; + } + } + } else { + // infer + bool infer_flag = FillInferBuf(); + if (conf_.sage_mode) { + sage_batch_num_ = 0; + if (infer_flag) { + // Set new batch size for multi_node + if (conf_.is_multi_node) { + int new_batch_size = dynamic_adjust_batch_num_for_sage(); + conf_.batch_size = new_batch_size; + } + + int total_instance = 0, uniq_instance = 0; + total_instance = + (infer_node_start_ + conf_.batch_size <= infer_node_end_) + ? 
conf_.batch_size + : infer_node_end_ - infer_node_start_; + total_instance *= 2; + while (total_instance != 0) { + uint64_t *d_type_keys = reinterpret_cast( + d_device_keys_[infer_cursor_]->ptr()); + d_type_keys += infer_node_start_; + infer_node_start_ += total_instance / 2; + auto node_buf = memory::AllocShared( + place_, + total_instance * sizeof(uint64_t), + phi::Stream(reinterpret_cast(sample_stream_))); + int64_t *node_buf_ptr = reinterpret_cast(node_buf->ptr()); + CopyDuplicateKeys<<>>( + node_buf_ptr, d_type_keys, total_instance / 2); + uint64_t *node_buf_ptr_ = + reinterpret_cast(node_buf->ptr()); + auto inverse = memory::AllocShared( + place_, + total_instance * sizeof(int), + phi::Stream(reinterpret_cast(sample_stream_))); + int *inverse_ptr = reinterpret_cast(inverse->ptr()); + auto final_sage_nodes = GenerateSampleGraph(node_buf_ptr_, + total_instance, + &uniq_instance, + inverse_ptr, + conf_, + &graph_edges_vec_, + &edges_split_num_vec_, + &edge_type_graph_, + place_, + sample_stream_); + uint64_t *final_sage_nodes_ptr = + reinterpret_cast(final_sage_nodes->ptr()); + if (conf_.get_degree) { + auto node_degrees = GetNodeDegree(final_sage_nodes_ptr, + uniq_instance, + conf_, + place_, + sample_stream_); + node_degree_vec_.emplace_back(node_degrees); + } + cudaStreamSynchronize(sample_stream_); + if (FLAGS_gpugraph_storage_mode != GpuGraphStorageMode::WHOLE_HBM) { + uint64_t *final_sage_nodes_ptr = + reinterpret_cast(final_sage_nodes->ptr()); + InsertTable(final_sage_nodes_ptr, + uniq_instance, + &d_uniq_node_num_, + conf_, + ©_unique_len_, + place_, + table_, + &host_vec_, + sample_stream_); + } + final_sage_nodes_vec_.emplace_back(final_sage_nodes); + inverse_vec_.emplace_back(inverse); + uniq_instance_vec_.emplace_back(uniq_instance); + total_instance_vec_.emplace_back(total_instance); + sage_batch_num_ += 1; + + total_instance = + (infer_node_start_ + conf_.batch_size <= infer_node_end_) + ? 
conf_.batch_size + : infer_node_end_ - infer_node_start_; + total_instance *= 2; + } - FillOneStep(d_type_keys + start, - path[0], - cur_walk, - cur_walk_ntype, - tmp_len, - &sample_res, - conf.walk_degree, - step, - conf, - &d_sampleidx2rows, - &cur_sampleidx2row, - place, - stream); - ///////// - if (conf.debug_mode) { - cudaMemcpy(h_walk, - walk, - conf.buf_size * sizeof(uint64_t), - cudaMemcpyDeviceToHost); - for (int xx = 0; xx < conf.buf_size; xx++) { - VLOG(2) << "h_walk[" << xx << "]: " << h_walk[xx]; + uint64_t h_uniq_node_num = CopyUniqueNodes(table_, + copy_unique_len_, + place_, + d_uniq_node_num_, + &host_vec_, + sample_stream_); + VLOG(1) << "infer sage_batch_num: " << sage_batch_num_; } } + } + debug_gpu_memory_info(device_id, "DoWalkandSage end"); +} - VLOG(2) << "sample, step=" << step << " sample_keys=" << tmp_len - << " sample_res_len=" << sample_res.total_sample_size; +void GraphDataGenerator::clear_gpu_mem() { + platform::CUDADeviceGuard guard(conf_.gpuid); + delete table_; +} - ///////// - step++; - size_t path_len = path.size(); - for (; step < conf.walk_len; step++) { - if (sample_res.total_sample_size == 0) { - VLOG(2) << "sample finish, step=" << step; - break; +int GraphDataGenerator::FillInferBuf() { + platform::CUDADeviceGuard guard(conf_.gpuid); + auto gpu_graph_ptr = GraphGpuWrapper::GetInstance(); + auto &global_infer_node_type_start = + gpu_graph_ptr->global_infer_node_type_start_[conf_.gpuid]; + auto &infer_cursor = gpu_graph_ptr->infer_cursor_[conf_.thread_id]; + total_row_ = 0; + if (infer_cursor < h_device_keys_len_.size()) { + while (global_infer_node_type_start[infer_cursor] >= + h_device_keys_len_[infer_cursor]) { + infer_cursor++; + if (infer_cursor >= h_device_keys_len_.size()) { + return 0; } - auto sample_key_mem = sample_res.actual_val_mem; - uint64_t *sample_keys_ptr = - reinterpret_cast(sample_key_mem->ptr()); - int edge_type_id = path[(step - 1) % path_len]; - VLOG(2) << "sample edge type: " << edge_type_id << " step: " << step; - q.initialize(conf.gpuid, - edge_type_id, - (uint64_t)sample_keys_ptr, - 1, - sample_res.total_sample_size, - step); - int sample_key_len = sample_res.total_sample_size; - sample_res = gpu_graph_ptr->graph_neighbor_sample_v3( - q, false, true, conf.weighted_sample); - total_samples += sample_res.total_sample_size; - if (!conf.sage_mode) { - if (FLAGS_gpugraph_storage_mode != GpuGraphStorageMode::WHOLE_HBM) { - if (InsertTable(sample_res.actual_val, - sample_res.total_sample_size, - d_uniq_node_num, - conf, - copy_unique_len_ptr, - place, - table, - host_vec_ptr, - stream) != 0) { - VLOG(2) << "in step: " << step << ", table is full"; - update = false; - break; - } + } + if (!infer_node_type_index_set_.empty()) { + while (infer_cursor < h_device_keys_len_.size()) { + if (infer_node_type_index_set_.find(infer_cursor) == + infer_node_type_index_set_.end()) { + VLOG(2) << "Skip cursor[" << infer_cursor << "]"; + infer_cursor++; + continue; + } else { + VLOG(2) << "Not skip cursor[" << infer_cursor << "]"; + break; } } - FillOneStep(d_type_keys + start, - edge_type_id, - cur_walk, - cur_walk_ntype, - sample_key_len, - &sample_res, - 1, - step, - conf, - &d_sampleidx2rows, - &cur_sampleidx2row, - place, - stream); - if (conf.debug_mode) { - cudaMemcpy(h_walk, - walk, - conf.buf_size * sizeof(uint64_t), - cudaMemcpyDeviceToHost); - for (int xx = 0; xx < conf.buf_size; xx++) { - VLOG(2) << "h_walk[" << xx << "]: " << h_walk[xx]; - } + if (infer_cursor >= h_device_keys_len_.size()) { + return 0; } - - VLOG(2) << "sample, 
step=" << step << " sample_keys=" << sample_key_len - << " sample_res_len=" << sample_res.total_sample_size; - } - // 此时更新全局采样状态 - if (update == true) { - cur_metapath_start = tmp_len + start; - i += *jump_rows_ptr * conf.walk_len; - *total_row_ptr += *jump_rows_ptr; - sample_times++; - } else { - VLOG(2) << "table is full, not update stat!"; - break; } - } - buf_state->Reset(*total_row_ptr); - - paddle::memory::ThrustAllocator allocator(place, stream); - thrust::random::default_random_engine engine(*shuffle_seed_ptr); - const auto &exec_policy = thrust::cuda::par(allocator).on(stream); - thrust::counting_iterator cnt_iter(0); - thrust::shuffle_copy(exec_policy, - cnt_iter, - cnt_iter + *total_row_ptr, - thrust::device_pointer_cast(d_random_row), - engine); - - thrust::transform(exec_policy, - cnt_iter, - cnt_iter + *total_row_ptr, - thrust::device_pointer_cast(d_random_row_col_shift), - RandInt(0, conf.walk_len)); - cudaStreamSynchronize(stream); - *shuffle_seed_ptr = engine(); + size_t device_key_size = h_device_keys_len_[infer_cursor]; + total_row_ = + (global_infer_node_type_start[infer_cursor] + conf_.buf_size <= + device_key_size) + ? conf_.buf_size + : device_key_size - global_infer_node_type_start[infer_cursor]; - if (conf.debug_mode) { - int *h_random_row = new int[*total_row_ptr + 10]; - cudaMemcpy(h_random_row, - d_random_row, - *total_row_ptr * sizeof(int), - cudaMemcpyDeviceToHost); - for (int xx = 0; xx < *total_row_ptr; xx++) { - VLOG(2) << "h_random_row[" << xx << "]: " << h_random_row[xx]; + uint64_t *d_type_keys = + reinterpret_cast(d_device_keys_[infer_cursor]->ptr()); + if (!conf_.sage_mode) { + host_vec_.resize(total_row_); + cudaMemcpyAsync(host_vec_.data(), + d_type_keys + global_infer_node_type_start[infer_cursor], + sizeof(uint64_t) * total_row_, + cudaMemcpyDeviceToHost, + sample_stream_); + cudaStreamSynchronize(sample_stream_); } - delete[] h_random_row; - delete[] h_walk; + VLOG(1) << "cursor: " << infer_cursor + << " start: " << global_infer_node_type_start[infer_cursor] + << " num: " << total_row_; + infer_node_start_ = global_infer_node_type_start[infer_cursor]; + global_infer_node_type_start[infer_cursor] += total_row_; + infer_node_end_ = global_infer_node_type_start[infer_cursor]; + infer_cursor_ = infer_cursor; + return 1; } + return 0; +} - if (!conf.sage_mode) { - uint64_t h_uniq_node_num = CopyUniqueNodes(table, - *copy_unique_len_ptr, - place, - *d_uniq_node_num, - host_vec_ptr, - stream); - VLOG(1) << "sample_times:" << sample_times - << ", d_walk_size:" << conf.buf_size << ", d_walk_offset:" << i - << ", total_rows:" << *total_row_ptr - << ", h_uniq_node_num:" << h_uniq_node_num - << ", total_samples:" << total_samples; - } else { - VLOG(1) << "sample_times:" << sample_times - << ", d_walk_size:" << conf.buf_size << ", d_walk_offset:" << i - << ", total_rows:" << *total_row_ptr - << ", total_samples:" << total_samples; +void GraphDataGenerator::ClearSampleState() { + auto gpu_graph_ptr = GraphGpuWrapper::GetInstance(); + auto &finish_node_type = gpu_graph_ptr->finish_node_type_[conf_.gpuid]; + auto &node_type_start = gpu_graph_ptr->node_type_start_[conf_.gpuid]; + finish_node_type.clear(); + for (auto iter = node_type_start.begin(); iter != node_type_start.end(); + iter++) { + iter->second = 0; } - - return *total_row_ptr != 0; } void GraphDataGenerator::SetFeedVec(std::vector feed_vec) { feed_vec_ = feed_vec; } + void GraphDataGenerator::SetFeedInfo(std::vector* feed_info) { feed_info_ = feed_info; for (int i = 0; i < conf_.slot_num; i++) { @@ 
-4056,33 +4030,6 @@ void GraphDataGenerator::DumpWalkPath(std::string dump_path, size_t dump_rate) {
 #endif
 }
 
-int multi_node_sync_sample(int flag,
-                           const ncclRedOp_t &op,
-                           const paddle::platform::Place &place,
-                           phi::DenseTensor *multi_node_sync_stat_ptr) {
-  if (flag < 0 && flag > 2) {
-    VLOG(0) << "invalid flag! " << flag;
-    assert(false);
-    return -1;
-  }
-
-  int ret = 0;
-#if defined(PADDLE_WITH_CUDA) && defined(PADDLE_WITH_GPU_GRAPH)
-  int *stat_ptr = multi_node_sync_stat_ptr->data<int>();
-  auto comm = platform::NCCLCommContext::Instance().Get(0, place.GetDeviceId());
-  auto stream = comm->stream();
-  PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(
-      &stat_ptr[flag], &stat_ptr[3], 1, ncclInt, op, comm->comm(), stream));
-  PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpyAsync(&ret,  // output
-                                             &stat_ptr[3],
-                                             sizeof(int),
-                                             cudaMemcpyDeviceToHost,
-                                             stream));
-  PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream));
-#endif
-  return ret;
-}
-
 int GraphDataGenerator::dynamic_adjust_batch_num_for_sage() {
   int batch_num = (total_row_ + conf_.batch_size - 1) / conf_.batch_size;
   auto send_buff = memory::Alloc(
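The helper deleted above (and re-added earlier in this patch) reduces a single event flag across all trainer ranks through a four-slot int buffer: slots 0..2 hold the per-rank event flags, slot 3 receives the reduced value. A minimal standalone sketch of that pattern, assuming an already-initialized ncclComm_t and CUDA stream (the names sync_event_flag and stat are illustrative, not Paddle API):

    #include <cuda_runtime.h>
    #include <nccl.h>

    // stat: device buffer of 4 ints; slot `flag` holds this rank's event,
    // slot 3 receives the cross-rank reduction.
    int sync_event_flag(int flag, ncclRedOp_t op, int *stat,
                        ncclComm_t comm, cudaStream_t stream) {
      int out = 0;
      ncclAllReduce(&stat[flag], &stat[3], 1, ncclInt, op, comm, stream);
      cudaMemcpyAsync(&out, &stat[3], sizeof(int),
                      cudaMemcpyDeviceToHost, stream);
      cudaStreamSynchronize(stream);  // `out` is valid only after the copy lands
      return out;  // ncclProd: all ranks agreed; ncclMax: strongest event wins
    }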