Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Gpugraph #45

Merged
merged 4 commits into from
Jun 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
349 changes: 249 additions & 100 deletions paddle/fluid/distributed/ps/table/common_graph_table.cc

Large diffs are not rendered by default.

17 changes: 11 additions & 6 deletions paddle/fluid/distributed/ps/table/common_graph_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,11 @@ class GraphTable : public Table {
const FsClientParameter &fs_config);
virtual int32_t Initialize(const GraphParameter &config);
int32_t Load(const std::string &path, const std::string &param);

int32_t load_node_and_edge_file(std::string etype, std::string ntype, std::string epath,
std::string npath, int part_num, bool reverse);

std::string get_inverse_etype(std::string &etype);

int32_t load_edges(const std::string &path, bool reverse,
const std::string &edge_type);
Expand All @@ -506,11 +511,10 @@ class GraphTable : public Table {
int slice_num);
int get_all_feature_ids(int type, int idx,
int slice_num, std::vector<std::vector<uint64_t>>* output);
int32_t load_nodes(const std::string &path, std::string node_type);
int32_t parse_edge_file(const std::string &path, int idx, bool reverse,
uint64_t &count, uint64_t &valid_count);
int32_t parse_node_file(const std::string &path, const std::string &node_type,
int idx, uint64_t &count, uint64_t &valid_count);
int32_t load_nodes(const std::string &path, std::string node_type = std::string());
std::pair<uint64_t, uint64_t> parse_edge_file(const std::string &path, int idx, bool reverse);
std::pair<uint64_t, uint64_t> parse_node_file(const std::string &path, const std::string &node_type, int idx);
std::pair<uint64_t, uint64_t> parse_node_file(const std::string &path);
int32_t add_graph_node(int idx, std::vector<uint64_t> &id_list,
std::vector<bool> &is_weight_list);

Expand Down Expand Up @@ -620,7 +624,8 @@ class GraphTable : public Table {
std::vector<std::vector<GraphShard *>> edge_shards, feature_shards;
size_t shard_start, shard_end, server_num, shard_num_per_server, shard_num;
int task_pool_size_ = 24;
int load_thread_num = 150;
int load_thread_num = 160;

const int random_sample_nodes_ranges = 3;

std::vector<std::vector<std::unordered_map<uint64_t, double>>> node_weight;
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/framework/data_feed.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ limitations under the License. */
#if defined(PADDLE_WITH_CUDA)
#include "paddle/fluid/platform/cuda_device_guard.h"
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/framework/fleet/heter_ps/gpu_graph_utils.h"
#endif

DECLARE_int32(record_pool_max_size);
Expand Down Expand Up @@ -422,7 +423,6 @@ struct UsedSlotGpuType {
};

#if defined(PADDLE_WITH_CUDA) && defined(PADDLE_WITH_HETERPS)
#define CUDA_CHECK(val) CHECK(val == gpuSuccess)
template <typename T>
struct CudaBuffer {
T* cu_buffer;
Expand Down
196 changes: 178 additions & 18 deletions paddle/fluid/framework/device_worker.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ limitations under the License. */

#include "paddle/fluid/framework/device_worker.h"

#include <chrono>
#include "paddle/fluid/framework/convert_utils.h"

namespace phi {
class DenseTensor;
} // namespace phi
Expand Down Expand Up @@ -52,7 +52,55 @@ std::string PrintLodTensorType(Tensor* tensor, int64_t start, int64_t end,
}
return os.str();
}
template <typename T>
void PrintLodTensorType(Tensor* tensor, int64_t start, int64_t end,
                        std::string& out_val, char separator = ':',
                        bool need_leading_separator = true) {
  // Appends tensor->data<T>()[start, end) to out_val, joining elements with
  // `separator`. When need_leading_separator is false the first element is
  // emitted without a preceding separator.
  const int64_t total = tensor->numel();
  if (start < 0 || end > total) {
    VLOG(3) << "access violation";
    out_val += "access violation";
    return;
  }
  if (start >= end) return;
  const T* data = tensor->data<T>();
  int64_t idx = start;
  if (!need_leading_separator) {
    // First element carries no leading separator in this mode.
    out_val += std::to_string(data[idx]);
    ++idx;
  }
  while (idx < end) {
    out_val += separator;
    out_val += std::to_string(data[idx]);
    ++idx;
  }
}

#define FLOAT_EPS 1e-8
// Worst-case "%.9f" rendering of a float is sign + up to 39 integer digits +
// '.' + 9 fractional digits + NUL (~51 bytes); 64 leaves headroom. The old
// 40-byte buffer combined with sprintf could overflow for large magnitudes.
#define MAX_FLOAT_BUFF_SIZE 64
template <>
void PrintLodTensorType<float>(Tensor* tensor, int64_t start, int64_t end,
                               std::string& out_val, char separator,
                               bool need_leading_separator) {
  // float specialization: formats with fixed 9-digit precision instead of
  // std::to_string, and collapses near-zero values (|x| < FLOAT_EPS) to "0".
  char buf[MAX_FLOAT_BUFF_SIZE];
  auto count = tensor->numel();
  if (start < 0 || end > count) {
    VLOG(3) << "access violation";
    out_val += "access violation";
    return;
  }
  if (start >= end) return;
  for (int64_t i = start; i < end; i++) {
    if (i != start || need_leading_separator) out_val += separator;
    float v = tensor->data<float>()[i];
    if (v > -FLOAT_EPS && v < FLOAT_EPS) {
      out_val += "0";
    } else {
      // snprintf (not sprintf): bounds the write even if the format of some
      // value were ever to exceed the buffer.
      snprintf(buf, sizeof(buf), "%.9f", v);
      out_val += buf;
    }
  }
}
std::string PrintLodTensorIntType(Tensor* tensor, int64_t start, int64_t end,
char separator = ':',
bool need_leading_separator = true) {
Expand All @@ -74,6 +122,31 @@ std::string PrintLodTensorIntType(Tensor* tensor, int64_t start, int64_t end,
return os.str();
}

void PrintLodTensorIntType(Tensor* tensor, int64_t start, int64_t end,
                           std::string& out_val, char separator = ':',
                           bool need_leading_separator = true) {
  // Appends the int64 tensor slice [start, end) to out_val, printing each
  // value reinterpreted as uint64_t and joined with `separator`; when
  // need_leading_separator is false the first element has no separator.
  const int64_t total = tensor->numel();
  if (start < 0 || end > total) {
    VLOG(3) << "access violation";
    out_val += "access violation";
    return;
  }
  if (start >= end) return;
  const int64_t* data = tensor->data<int64_t>();
  for (int64_t i = start; i < end; ++i) {
    if (i != start || need_leading_separator) out_val += separator;
    out_val += std::to_string(static_cast<uint64_t>(data[i]));
  }
}

std::string PrintLodTensor(Tensor* tensor, int64_t start, int64_t end,
char separator, bool need_leading_separator) {
std::string out_val;
Expand All @@ -94,6 +167,25 @@ std::string PrintLodTensor(Tensor* tensor, int64_t start, int64_t end,
return out_val;
}

void PrintLodTensor(Tensor* tensor, int64_t start, int64_t end,
                    std::string& out_val, char separator,
                    bool need_leading_separator) {
  // Appending variant of PrintLodTensor: dispatches on the tensor's dtype
  // (FP32, INT64, FP64) and appends the formatted slice to out_val;
  // any other dtype appends the literal "unsupported type".
  const auto var_type = framework::TransToProtoVarType(tensor->dtype());
  switch (var_type) {
    case proto::VarType::FP32:
      PrintLodTensorType<float>(tensor, start, end, out_val, separator,
                                need_leading_separator);
      break;
    case proto::VarType::INT64:
      PrintLodTensorIntType(tensor, start, end, out_val, separator,
                            need_leading_separator);
      break;
    case proto::VarType::FP64:
      PrintLodTensorType<double>(tensor, start, end, out_val, separator,
                                 need_leading_separator);
      break;
    default:
      out_val += "unsupported type";
      break;
  }
}

std::pair<int64_t, int64_t> GetTensorBound(LoDTensor* tensor, int index) {
auto& dims = tensor->dims();
if (tensor->lod().size() != 0) {
Expand Down Expand Up @@ -164,7 +256,9 @@ void DeviceWorker::DumpField(const Scope& scope, int dump_mode,
int dump_interval) { // dump_mode: 0: no random,
// 1: random with insid hash,
// 2: random with random
// 3: simple mode
// 3: simple mode using multi-threads, for gpugraphps-mode
auto start1 = std::chrono::steady_clock::now();

size_t batch_size = device_reader_->GetCurBatchSize();
auto& ins_id_vec = device_reader_->GetInsIdVec();
auto& ins_content_vec = device_reader_->GetInsContentVec();
Expand All @@ -173,7 +267,81 @@ void DeviceWorker::DumpField(const Scope& scope, int dump_mode,
}
std::vector<std::string> ars(batch_size);
std::vector<bool> hit(batch_size, false);
if (dump_mode_ == 3) {
if (dump_fields_ == NULL || (*dump_fields_).size() == 0) {
return;
}
auto set_output_str = [&, this](size_t begin, size_t end,
LoDTensor* tensor) {
for (size_t i = begin; i < end; ++i) {
auto bound = GetTensorBound(tensor, i);
if (ars[i].size() > 0) ars[i] += "\t";
// ars[i] += '[';
PrintLodTensor(tensor, bound.first, bound.second, ars[i], ' ', false);
// ars[i] += ']';
// ars[i] += "<" + PrintLodTensor(tensor, bound.first, bound.second, '
// ', false) + ">";
}
};
std::thread threads[tensor_iterator_thread_num];
for (auto& field : *dump_fields_) {
Variable* var = scope.FindVar(field);
if (var == nullptr) {
VLOG(0) << "Note: field[" << field
<< "] cannot be find in scope, so it was skipped.";
continue;
}
LoDTensor* tensor = var->GetMutable<LoDTensor>();
if (!tensor->IsInitialized()) {
VLOG(0) << "Note: field[" << field
<< "] is not initialized, so it was skipped.";
continue;
}
framework::LoDTensor cpu_tensor;
if (platform::is_gpu_place(tensor->place())) {
TensorCopySync(*tensor, platform::CPUPlace(), &cpu_tensor);
cpu_tensor.set_lod(tensor->lod());
tensor = &cpu_tensor;
}
if (!CheckValidOutput(tensor, batch_size)) {
VLOG(0) << "Note: field[" << field << "] cannot pass check, so it was "
"skipped. Maybe the dimension is "
"wrong ";
continue;
}
size_t acutal_thread_num =
std::min((size_t)batch_size, tensor_iterator_thread_num);
for (size_t i = 0; i < acutal_thread_num; i++) {
size_t average_size = batch_size / acutal_thread_num;
size_t begin =
average_size * i + std::min(batch_size % acutal_thread_num, i);
size_t end =
begin + average_size + (i < batch_size % acutal_thread_num ? 1 : 0);
threads[i] = std::thread(set_output_str, begin, end, tensor);
}
for (size_t i = 0; i < acutal_thread_num; i++) threads[i].join();
}
auto end1 = std::chrono::steady_clock::now();
auto tt =
std::chrono::duration_cast<std::chrono::microseconds>(end1 - start1);
VLOG(1) << "writing a batch takes " << tt.count() << " us";

size_t acutal_thread_num =
std::min((size_t)batch_size, tensor_iterator_thread_num);
for (size_t i = 0; i < acutal_thread_num; i++) {
size_t average_size = batch_size / acutal_thread_num;
size_t begin =
average_size * i + std::min(batch_size % acutal_thread_num, i);
size_t end =
begin + average_size + (i < batch_size % acutal_thread_num ? 1 : 0);
for (size_t j = begin + 1; j < end; j++) {
if (ars[begin].size() > 0 && ars[j].size() > 0) ars[begin] += "\n";
ars[begin] += ars[j];
}
if (ars[begin].size() > 0) writer_ << ars[begin];
}
return;
}
std::default_random_engine engine(0);
std::uniform_int_distribution<size_t> dist(0U, INT_MAX);
for (size_t i = 0; i < batch_size; i++) {
Expand All @@ -189,14 +357,12 @@ void DeviceWorker::DumpField(const Scope& scope, int dump_mode,
hit[i] = true;
}

if (dump_mode_ != 3) {
for (size_t i = 0; i < ins_id_vec.size(); i++) {
if (!hit[i]) {
continue;
}
ars[i] += ins_id_vec[i];
ars[i] = ars[i] + "\t" + ins_content_vec[i];
for (size_t i = 0; i < ins_id_vec.size(); i++) {
if (!hit[i]) {
continue;
}
ars[i] += ins_id_vec[i];
ars[i] += "\t" + ins_content_vec[i];
}
for (auto& field : *dump_fields_) {
Variable* var = scope.FindVar(field);
Expand All @@ -223,22 +389,16 @@ void DeviceWorker::DumpField(const Scope& scope, int dump_mode,
"wrong ";
continue;
}

for (size_t i = 0; i < batch_size; ++i) {
if (!hit[i]) {
continue;
}
auto bound = GetTensorBound(tensor, i);
if (dump_mode_ == 3) {
if (ars[i].size() > 0) ars[i] += "\t";
ars[i] += PrintLodTensor(tensor, bound.first, bound.second, ' ', false);
} else {
ars[i] = ars[i] + "\t" + field + ":" +
std::to_string(bound.second - bound.first);
ars[i] += PrintLodTensor(tensor, bound.first, bound.second);
}
ars[i] += "\t" + field + ":" + std::to_string(bound.second - bound.first);
ars[i] += PrintLodTensor(tensor, bound.first, bound.second);
}
}

// #pragma omp parallel for
for (size_t i = 0; i < ars.size(); i++) {
if (ars[i].length() == 0) {
Expand Down
6 changes: 5 additions & 1 deletion paddle/fluid/framework/device_worker.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ limitations under the License. */
#include "paddle/fluid/distributed/ps/wrapper/fleet.h"
#endif

#include <map>
#include "paddle/fluid/framework/data_feed.h"
#include "paddle/fluid/framework/executor_gc_helper.h"
#include "paddle/fluid/framework/heter_util.h"
Expand Down Expand Up @@ -62,6 +63,9 @@ namespace framework {
std::string PrintLodTensor(Tensor* tensor, int64_t start, int64_t end,
char separator = ',',
bool need_leading_separator = false);
void PrintLodTensor(Tensor* tensor, int64_t start, int64_t end,
std::string& output_str, char separator = ',',
bool need_leading_separator = false);
std::pair<int64_t, int64_t> GetTensorBound(LoDTensor* tensor, int index);
bool CheckValidOutput(LoDTensor* tensor, size_t batch_size);

Expand Down Expand Up @@ -231,6 +235,7 @@ class DeviceWorker {
int dump_mode_ = 0;
int dump_interval_ = 10000;
ChannelWriter<std::string> writer_;
const size_t tensor_iterator_thread_num = 16;
platform::DeviceContext* dev_ctx_ = nullptr;
};

Expand Down Expand Up @@ -770,7 +775,6 @@ class HeterSectionWorker : public DeviceWorker {
static uint64_t batch_id_;
uint64_t total_ins_num_ = 0;
platform::DeviceContext* dev_ctx_ = nullptr;

bool debug_ = false;
std::vector<double> op_total_time_;
std::vector<std::string> op_name_;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -841,12 +841,14 @@ x.second );

__host__ void print_collision(int id) {
  // Dumps insert/query collision statistics for hbm table `id` as
  // "total:collisions:rate" pairs when collision statting is enabled.
  if (m_enable_collision_stat) {
    // Guard the rate computations: with zero inserts/queries the old code
    // divided by zero and printed inf/nan.
    double insert_rate =
        (m_insert_times > 0)
            ? m_insert_collisions / (double)m_insert_times
            : 0.0;
    double query_rate =
        (m_query_times > 0)
            ? m_query_collisions / (double)m_query_times
            : 0.0;
    printf("collision stat for hbm table %d, insert(%lu:%lu:%.2f), query(%lu:%lu:%.2f)\n",
           id,
           m_insert_times,
           m_insert_collisions,
           insert_rate,
           m_query_times,
           m_query_collisions,
           query_rate);
  }
}

Expand Down
3 changes: 3 additions & 0 deletions paddle/fluid/framework/fleet/heter_ps/feature_value.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ struct GpuAccessorInfo {
size_t dim;
// value各个维度的size
size_t size;
// embedx维度
size_t embedx_dim;
// push value维度
size_t update_dim;
// push value各个维度的size
Expand Down Expand Up @@ -192,6 +194,7 @@ class CommonFeatureValueAccessor : public FeatureValueAccessor {
? 8
: int(_config["embedx_dim"]);
// VLOG(0) << "feature value InitAccessorInfo embedx_dim:" << embedx_dim;
_accessor_info.embedx_dim = embedx_dim;
_accessor_info.update_dim = 5 + embedx_dim;
_accessor_info.update_size = _accessor_info.update_dim * sizeof(float);
_accessor_info.mf_size =
Expand Down
Loading