Skip to content

Commit

Permalink
graph-engine data transfer optimization
Browse files Browse the repository at this point in the history
  • Loading branch information
seemingwang committed Nov 18, 2021
1 parent 2cc00bd commit 6877d97
Show file tree
Hide file tree
Showing 8 changed files with 103 additions and 62 deletions.
43 changes: 31 additions & 12 deletions paddle/fluid/distributed/service/graph_brpc_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -304,10 +304,15 @@ std::future<int32_t> GraphBrpcClient::remove_graph_node(
// char* &buffer,int &actual_size
std::future<int32_t> GraphBrpcClient::batch_sample_neighbors(
uint32_t table_id, std::vector<uint64_t> node_ids, int sample_size,
std::vector<std::vector<std::pair<uint64_t, float>>> &res,
// std::vector<std::vector<std::pair<uint64_t, float>>> &res,
std::vector<std::vector<uint64_t>> &res,
std::vector<std::vector<float>> &res_weight, bool need_weight,
int server_index) {
if (server_index != -1) {
res.resize(node_ids.size());
if (need_weight) {
res_weight.resize(node_ids.size());
}
DownpourBrpcClosure *closure = new DownpourBrpcClosure(1, [&](void *done) {
int ret = 0;
auto *closure = (DownpourBrpcClosure *)done;
Expand All @@ -331,11 +336,14 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighbors(
int actual_size = actual_sizes[node_idx];
int start = 0;
while (start < actual_size) {
res[node_idx].push_back(
{*(uint64_t *)(node_buffer + offset + start),
*(float *)(node_buffer + offset + start +
GraphNode::id_size)});
start += GraphNode::id_size + GraphNode::weight_size;
res[node_idx].emplace_back(
*(uint64_t *)(node_buffer + offset + start));
start += GraphNode::id_size;
if (need_weight) {
res_weight[node_idx].emplace_back(
*(float *)(node_buffer + offset + start));
start += GraphNode::weight_size;
}
}
offset += actual_size;
}
Expand All @@ -352,6 +360,7 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighbors(
closure->request(0)->add_params((char *)node_ids.data(),
sizeof(uint64_t) * node_ids.size());
closure->request(0)->add_params((char *)&sample_size, sizeof(int));
closure->request(0)->add_params((char *)&need_weight, sizeof(bool));
;
// PsService_Stub rpc_stub(get_cmd_channel(server_index));
GraphPsService_Stub rpc_stub =
Expand All @@ -364,13 +373,18 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighbors(
std::vector<int> request2server;
std::vector<int> server2request(server_size, -1);
res.clear();
res_weight.clear();
for (int query_idx = 0; query_idx < node_ids.size(); ++query_idx) {
int server_index = get_server_index_by_id(node_ids[query_idx]);
if (server2request[server_index] == -1) {
server2request[server_index] = request2server.size();
request2server.push_back(server_index);
}
res.push_back(std::vector<std::pair<uint64_t, float>>());
// res.push_back(std::vector<std::pair<uint64_t, float>>());
res.push_back({});
if (need_weight) {
res_weight.push_back({});
}
}
size_t request_call_num = request2server.size();
std::vector<std::vector<uint64_t>> node_id_buckets(request_call_num);
Expand Down Expand Up @@ -413,11 +427,14 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighbors(
int actual_size = actual_sizes[node_idx];
int start = 0;
while (start < actual_size) {
res[query_idx].push_back(
{*(uint64_t *)(node_buffer + offset + start),
*(float *)(node_buffer + offset + start +
GraphNode::id_size)});
start += GraphNode::id_size + GraphNode::weight_size;
res[query_idx].emplace_back(
*(uint64_t *)(node_buffer + offset + start));
start += GraphNode::id_size;
if (need_weight) {
res_weight[query_idx].emplace_back(
*(float *)(node_buffer + offset + start));
start += GraphNode::weight_size;
}
}
offset += actual_size;
}
Expand Down Expand Up @@ -445,6 +462,8 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighbors(
sizeof(uint64_t) * node_num);
closure->request(request_idx)
->add_params((char *)&sample_size, sizeof(int));
closure->request(request_idx)
->add_params((char *)&need_weight, sizeof(bool));
// PsService_Stub rpc_stub(get_cmd_channel(server_index));
GraphPsService_Stub rpc_stub =
getServiceStub(get_cmd_channel(server_index));
Expand Down
3 changes: 2 additions & 1 deletion paddle/fluid/distributed/service/graph_brpc_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,8 @@ class GraphBrpcClient : public BrpcPsClient {
// given a batch of nodes, sample graph_neighbors for each of them
virtual std::future<int32_t> batch_sample_neighbors(
uint32_t table_id, std::vector<uint64_t> node_ids, int sample_size,
std::vector<std::vector<std::pair<uint64_t, float>>>& res,
std::vector<std::vector<uint64_t>>& res,
std::vector<std::vector<float>>& res_weight, bool need_weight,
int server_index = -1);

virtual std::future<int32_t> pull_graph_list(uint32_t table_id,
Expand Down
22 changes: 14 additions & 8 deletions paddle/fluid/distributed/service/graph_brpc_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -378,19 +378,21 @@ int32_t GraphBrpcService::graph_random_sample_neighbors(
Table *table, const PsRequestMessage &request, PsResponseMessage &response,
brpc::Controller *cntl) {
CHECK_TABLE_EXIST(table, request, response)
if (request.params_size() < 2) {
if (request.params_size() < 3) {
set_response_code(
response, -1,
"graph_random_sample request requires at least 2 arguments");
"graph_random_sample_neighbors request requires at least 3 arguments");
return 0;
}
size_t node_num = request.params(0).size() / sizeof(uint64_t);
uint64_t *node_data = (uint64_t *)(request.params(0).c_str());
int sample_size = *(uint64_t *)(request.params(1).c_str());
bool need_weight = *(bool *)(request.params(2).c_str());
std::vector<std::shared_ptr<char>> buffers(node_num);
std::vector<int> actual_sizes(node_num, 0);
((GraphTable *)table)
->random_sample_neighbors(node_data, sample_size, buffers, actual_sizes);
->random_sample_neighbors(node_data, sample_size, buffers, actual_sizes,
need_weight);

cntl->response_attachment().append(&node_num, sizeof(size_t));
cntl->response_attachment().append(actual_sizes.data(),
Expand Down Expand Up @@ -454,16 +456,17 @@ int32_t GraphBrpcService::sample_neighbors_across_multi_servers(
brpc::Controller *cntl) {
// sleep(5);
CHECK_TABLE_EXIST(table, request, response)
if (request.params_size() < 2) {
set_response_code(
response, -1,
"graph_random_neighbors_sample request requires at least 2 arguments");
if (request.params_size() < 3) {
set_response_code(response, -1,
"sample_neighbors_across_multi_servers request requires "
"at least 3 arguments");
return 0;
}
size_t node_num = request.params(0).size() / sizeof(uint64_t),
size_of_size_t = sizeof(size_t);
uint64_t *node_data = (uint64_t *)(request.params(0).c_str());
int sample_size = *(uint64_t *)(request.params(1).c_str());
bool need_weight = *(uint64_t *)(request.params(2).c_str());
// std::vector<uint64_t> res = ((GraphTable
// *)table).filter_out_non_exist_nodes(node_data, sample_size);
std::vector<int> request2server;
Expand Down Expand Up @@ -581,6 +584,8 @@ int32_t GraphBrpcService::sample_neighbors_across_multi_servers(
sizeof(uint64_t) * node_num);
closure->request(request_idx)
->add_params((char *)&sample_size, sizeof(int));
closure->request(request_idx)
->add_params((char *)&need_weight, sizeof(bool));
PsService_Stub rpc_stub(
((GraphBrpcServer *)get_server())->get_cmd_channel(server_index));
// GraphPsService_Stub rpc_stub =
Expand All @@ -592,7 +597,8 @@ int32_t GraphBrpcService::sample_neighbors_across_multi_servers(
if (server2request[rank] != -1) {
((GraphTable *)table)
->random_sample_neighbors(node_id_buckets.back().data(), sample_size,
local_buffers, local_actual_sizes);
local_buffers, local_actual_sizes,
need_weight);
}
local_promise.get()->set_value(0);
if (remote_call_num == 0) func(closure);
Expand Down
13 changes: 8 additions & 5 deletions paddle/fluid/distributed/service/graph_py_service.cc
Original file line number Diff line number Diff line change
Expand Up @@ -295,11 +295,13 @@ GraphPyClient::batch_sample_neighbors(std::string name,
std::vector<uint64_t> node_ids,
int sample_size, bool return_weight,
bool return_edges) {
std::vector<std::vector<std::pair<uint64_t, float>>> v;
// std::vector<std::vector<std::pair<uint64_t, float>>> v;
std::vector<std::vector<uint64_t>> v;
std::vector<std::vector<float>> v1;
if (this->table_id_map.count(name)) {
uint32_t table_id = this->table_id_map[name];
auto status =
worker_ptr->batch_sample_neighbors(table_id, node_ids, sample_size, v);
auto status = worker_ptr->batch_sample_neighbors(
table_id, node_ids, sample_size, v, v1, return_weight);
status.wait();
}

Expand All @@ -313,9 +315,10 @@ GraphPyClient::batch_sample_neighbors(std::string name,
if (return_edges) res.first.push_back({});
for (size_t i = 0; i < v.size(); i++) {
for (size_t j = 0; j < v[i].size(); j++) {
res.first[0].push_back(v[i][j].first);
// res.first[0].push_back(v[i][j].first);
res.first[0].push_back(v[i][j]);
if (return_edges) res.first[2].push_back(node_ids[i]);
if (return_weight) res.second.push_back(v[i][j].second);
if (return_weight) res.second.push_back(v1[i][j]);
}
if (i == v.size() - 1) break;

Expand Down
20 changes: 12 additions & 8 deletions paddle/fluid/distributed/table/common_graph_table.cc
Original file line number Diff line number Diff line change
Expand Up @@ -396,8 +396,8 @@ int32_t GraphTable::random_sample_nodes(int sample_size,
}
int32_t GraphTable::random_sample_neighbors(
uint64_t *node_ids, int sample_size,
std::vector<std::shared_ptr<char>> &buffers,
std::vector<int> &actual_sizes) {
std::vector<std::shared_ptr<char>> &buffers, std::vector<int> &actual_sizes,
bool need_weight) {
size_t node_num = buffers.size();
std::function<void(char *)> char_del = [](char *c) { delete[] c; };
std::vector<std::future<int>> tasks;
Expand All @@ -407,7 +407,7 @@ int32_t GraphTable::random_sample_neighbors(
for (size_t idx = 0; idx < node_num; ++idx) {
index = get_thread_pool_index(node_ids[idx]);
seq_id[index].emplace_back(idx);
id_list[index].emplace_back(node_ids[idx], sample_size);
id_list[index].emplace_back(node_ids[idx], sample_size, need_weight);
}
for (int i = 0; i < seq_id.size(); i++) {
if (seq_id[i].size() == 0) continue;
Expand Down Expand Up @@ -442,25 +442,29 @@ int32_t GraphTable::random_sample_neighbors(
}
std::shared_ptr<char> &buffer = buffers[idx];
std::vector<int> res = node->sample_k(sample_size, rng);
actual_size = res.size() * (Node::id_size + Node::weight_size);
actual_size =
res.size() * (need_weight ? (Node::id_size + Node::weight_size)
: Node::id_size);
int offset = 0;
uint64_t id;
float weight;
char *buffer_addr = new char[actual_size];
if (response == LRUResponse::ok) {
sample_keys.emplace_back(node_id, sample_size);
sample_keys.emplace_back(node_id, sample_size, need_weight);
sample_res.emplace_back(actual_size, buffer_addr);
buffer = sample_res.back().buffer;
} else {
buffer.reset(buffer_addr, char_del);
}
for (int &x : res) {
id = node->get_neighbor_id(x);
weight = node->get_neighbor_weight(x);
memcpy(buffer_addr + offset, &id, Node::id_size);
offset += Node::id_size;
memcpy(buffer_addr + offset, &weight, Node::weight_size);
offset += Node::weight_size;
if (need_weight) {
weight = node->get_neighbor_weight(x);
memcpy(buffer_addr + offset, &weight, Node::weight_size);
offset += Node::weight_size;
}
}
}
}
Expand Down
12 changes: 8 additions & 4 deletions paddle/fluid/distributed/table/common_graph_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,14 @@ enum LRUResponse { ok = 0, blocked = 1, err = 2 };
struct SampleKey {
uint64_t node_key;
size_t sample_size;
SampleKey(uint64_t _node_key, size_t _sample_size)
: node_key(_node_key), sample_size(_sample_size) {}
bool is_weighted;
SampleKey(uint64_t _node_key, size_t _sample_size, bool _is_weighted)
: node_key(_node_key),
sample_size(_sample_size),
is_weighted(_is_weighted) {}
bool operator==(const SampleKey &s) const {
return node_key == s.node_key && sample_size == s.sample_size;
return node_key == s.node_key && sample_size == s.sample_size &&
is_weighted == s.is_weighted;
}
};

Expand Down Expand Up @@ -360,7 +364,7 @@ class GraphTable : public SparseTable {
virtual int32_t random_sample_neighbors(
uint64_t *node_ids, int sample_size,
std::vector<std::shared_ptr<char>> &buffers,
std::vector<int> &actual_sizes);
std::vector<int> &actual_sizes, bool need_weight);

int32_t random_sample_nodes(int sample_size, std::unique_ptr<char[]> &buffers,
int &actual_sizes);
Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/distributed/table/graph/graph_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ class Node {

protected:
uint64_t id;
bool is_weighted;
};

class GraphNode : public Node {
Expand Down
Loading

1 comment on commit 6877d97

@paddle-bot-old
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Congratulation! Your pull request passed all required CI. You could ask reviewer(s) to approve and merge. 🎉

Please sign in to comment.