Skip to content

Commit

Permalink
Merge pull request seemingwang#7 from seemingwang/develop
Browse files Browse the repository at this point in the history
Merge
  • Loading branch information
Yelrose committed Mar 17, 2021
2 parents b08a36f + 832cab8 commit 6c38fa0
Show file tree
Hide file tree
Showing 6 changed files with 160 additions and 130 deletions.
121 changes: 65 additions & 56 deletions paddle/fluid/distributed/service/graph_brpc_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,90 +35,99 @@ int GraphBrpcClient::get_server_index_by_id(uint64_t id) {
return id % shard_num / shard_per_server;
}
// char* &buffer,int &actual_size
std::future<int32_t> GraphBrpcClient::batch_sample(uint32_t table_id,
std::vector<uint64_t> node_ids, int sample_size,
std::vector<std::vector<std::pair<uint64_t, float> > > &res) {

std::future<int32_t> GraphBrpcClient::batch_sample(
uint32_t table_id, std::vector<uint64_t> node_ids, int sample_size,
std::vector<std::vector<std::pair<uint64_t, float>>> &res) {
std::vector<int> request2server;
std::vector<int> server2request(server_size, -1);
res.clear();
for (int query_idx = 0; query_idx < node_ids.size(); ++query_idx){
for (int query_idx = 0; query_idx < node_ids.size(); ++query_idx) {
int server_index = get_server_index_by_id(node_ids[query_idx]);
if(server2request[server_index] == -1){
if (server2request[server_index] == -1) {
server2request[server_index] = request2server.size();
request2server.push_back(server_index);
}
//res.push_back(std::vector<GraphNode>());
// res.push_back(std::vector<GraphNode>());
res.push_back(std::vector<std::pair<uint64_t, float>>());
}
size_t request_call_num = request2server.size();
std::vector<std::vector<uint64_t> > node_id_buckets(request_call_num);
std::vector<std::vector<int> > query_idx_buckets(request_call_num);
for (int query_idx = 0; query_idx < node_ids.size(); ++query_idx){
std::vector<std::vector<uint64_t>> node_id_buckets(request_call_num);
std::vector<std::vector<int>> query_idx_buckets(request_call_num);
for (int query_idx = 0; query_idx < node_ids.size(); ++query_idx) {
int server_index = get_server_index_by_id(node_ids[query_idx]);
int request_idx = server2request[server_index];
node_id_buckets[request_idx].push_back(node_ids[query_idx]);
query_idx_buckets[request_idx].push_back(query_idx);
}

DownpourBrpcClosure *closure = new DownpourBrpcClosure(request_call_num, [&, node_id_buckets, query_idx_buckets, request_call_num](void *done) {
int ret = 0;
auto *closure = (DownpourBrpcClosure *)done;
int fail_num = 0;
for (int request_idx = 0; request_idx < request_call_num; ++request_idx){
if (closure->check_response(request_idx, PS_GRAPH_SAMPLE) != 0) {
++fail_num;
} else {
VLOG(0) << "check sample response: "
<< " " << closure->check_response(request_idx, PS_GRAPH_SAMPLE);
auto &res_io_buffer = closure->cntl(request_idx)->response_attachment();
butil::IOBufBytesIterator io_buffer_itr(res_io_buffer);
size_t bytes_size = io_buffer_itr.bytes_left();
char *buffer = new char[bytes_size];
io_buffer_itr.copy_and_forward((void *)(buffer), bytes_size);
DownpourBrpcClosure *closure = new DownpourBrpcClosure(
request_call_num,
[&, node_id_buckets, query_idx_buckets, request_call_num](void *done) {
int ret = 0;
auto *closure = (DownpourBrpcClosure *)done;
int fail_num = 0;
for (int request_idx = 0; request_idx < request_call_num;
++request_idx) {
if (closure->check_response(request_idx, PS_GRAPH_SAMPLE) != 0) {
++fail_num;
} else {
auto &res_io_buffer =
closure->cntl(request_idx)->response_attachment();
butil::IOBufBytesIterator io_buffer_itr(res_io_buffer);
size_t bytes_size = io_buffer_itr.bytes_left();
// char buffer[bytes_size];
std::unique_ptr<char[]> buffer_wrapper(new char[bytes_size]);
char *buffer = buffer_wrapper.get();
io_buffer_itr.copy_and_forward((void *)(buffer), bytes_size);

size_t node_num = *(size_t *)buffer;
int *actual_sizes = (int *)(buffer + sizeof(size_t));
char *node_buffer = buffer + sizeof(size_t) + sizeof(int) * node_num;

int offset = 0;
for (size_t node_idx = 0; node_idx < node_num; ++node_idx){
int query_idx = query_idx_buckets.at(request_idx).at(node_idx);
int actual_size = actual_sizes[node_idx];
int start = 0;
while (start < actual_size) {
res[query_idx].push_back({*(uint64_t *)(node_buffer + offset + start),
*(float *)(node_buffer + offset + start + GraphNode::id_size)});
start += GraphNode::id_size + GraphNode::weight_size;
size_t node_num = *(size_t *)buffer;
int *actual_sizes = (int *)(buffer + sizeof(size_t));
char *node_buffer =
buffer + sizeof(size_t) + sizeof(int) * node_num;

int offset = 0;
for (size_t node_idx = 0; node_idx < node_num; ++node_idx) {
int query_idx = query_idx_buckets.at(request_idx).at(node_idx);
int actual_size = actual_sizes[node_idx];
int start = 0;
while (start < actual_size) {
res[query_idx].push_back(
{*(uint64_t *)(node_buffer + offset + start),
*(float *)(node_buffer + offset + start +
GraphNode::id_size)});
start += GraphNode::id_size + GraphNode::weight_size;
}
offset += actual_size;
}
}
if (fail_num == request_call_num) {
ret = -1;
}
offset += actual_size;
}
}
if (fail_num == request_call_num){
ret = -1;
}
}
closure->set_promise_value(ret);
});
closure->set_promise_value(ret);
});

auto promise = std::make_shared<std::promise<int32_t>>();
closure->add_promise(promise);
std::future<int> fut = promise->get_future();
for (int request_idx = 0; request_idx < request_call_num; ++request_idx){

for (int request_idx = 0; request_idx < request_call_num; ++request_idx) {
int server_index = request2server[request_idx];
closure->request(request_idx)->set_cmd_id(PS_GRAPH_SAMPLE);
closure->request(request_idx)->set_table_id(table_id);
closure->request(request_idx)->set_client_id(_client_id);
// std::string type_str = GraphNode::node_type_to_string(type);
size_t node_num = node_id_buckets[request_idx].size();

closure->request(request_idx)->add_params((char *)node_id_buckets[request_idx].data(), sizeof(uint64_t)*node_num);
closure->request(request_idx)->add_params((char *)&sample_size, sizeof(int));

closure->request(request_idx)
->add_params((char *)node_id_buckets[request_idx].data(),
sizeof(uint64_t) * node_num);
closure->request(request_idx)
->add_params((char *)&sample_size, sizeof(int));
PsService_Stub rpc_stub(get_cmd_channel(server_index));
closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms());
rpc_stub.service(closure->cntl(request_idx), closure->request(request_idx), closure->response(request_idx),
closure);
rpc_stub.service(closure->cntl(request_idx), closure->request(request_idx),
closure->response(request_idx), closure);
}

return fut;
Expand All @@ -133,12 +142,12 @@ std::future<int32_t> GraphBrpcClient::pull_graph_list(
if (closure->check_response(0, PS_PULL_GRAPH_LIST) != 0) {
ret = -1;
} else {
VLOG(0) << "check sample response: "
<< " " << closure->check_response(0, PS_PULL_GRAPH_LIST);
// VLOG(0) << "check sample response: "
// << " " << closure->check_response(0, PS_PULL_GRAPH_LIST);
auto &res_io_buffer = closure->cntl(0)->response_attachment();
butil::IOBufBytesIterator io_buffer_itr(res_io_buffer);
size_t bytes_size = io_buffer_itr.bytes_left();
char *buffer = new char[bytes_size];
char buffer[bytes_size];
io_buffer_itr.copy_and_forward((void *)(buffer), bytes_size);
int index = 0;
while (index < bytes_size) {
Expand Down
28 changes: 17 additions & 11 deletions paddle/fluid/distributed/service/graph_brpc_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -265,11 +265,10 @@ int32_t GraphBrpcService::pull_graph_list(Table *table,
}
int start = *(int *)(request.params(0).c_str());
int size = *(int *)(request.params(1).c_str());
std::vector<float> res_data;
char *buffer;
std::unique_ptr<char[]> buffer;
int actual_size;
table->pull_graph_list(start, size, buffer, actual_size);
cntl->response_attachment().append(buffer, actual_size);
cntl->response_attachment().append(buffer.get(), actual_size);
return 0;
}
int32_t GraphBrpcService::graph_random_sample(Table *table,
Expand All @@ -287,19 +286,26 @@ int32_t GraphBrpcService::graph_random_sample(Table *table,
uint64_t *node_data = (uint64_t *)(request.params(0).c_str());
int sample_size = *(uint64_t *)(request.params(1).c_str());

std::vector<char*> buffers(node_num, nullptr);
std::vector<std::unique_ptr<char[]>> buffers(node_num);
std::vector<int> actual_sizes(node_num, 0);
table->random_sample(node_data, sample_size, buffers, actual_sizes);

cntl->response_attachment().append(&node_num, sizeof(size_t));
cntl->response_attachment().append(actual_sizes.data(), sizeof(int)*node_num);
for (size_t idx = 0; idx < node_num; ++idx){
cntl->response_attachment().append(buffers[idx], actual_sizes[idx]);
if (buffers[idx] != nullptr){
delete buffers[idx];
buffers[idx] = nullptr;
}
cntl->response_attachment().append(actual_sizes.data(),
sizeof(int) * node_num);
for (size_t idx = 0; idx < node_num; ++idx) {
cntl->response_attachment().append(buffers[idx].get(), actual_sizes[idx]);
// if (buffers[idx] != nullptr){
// delete buffers[idx];
// buffers[idx] = nullptr;
// }
}
// =======
// std::unique_ptr<char[]> buffer;
// int actual_size;
// table->random_sample(node_id, sample_size, buffer, actual_size);
// cntl->response_attachment().append(buffer.get(), actual_size);
// >>>>>>> Stashed changes
return 0;
}

Expand Down
96 changes: 47 additions & 49 deletions paddle/fluid/distributed/table/common_graph_table.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,8 @@ size_t GraphShard::get_size() {
return res;
}

std::list<GraphNode *>::iterator GraphShard::add_node(uint64_t id, std::string feature) {
std::list<GraphNode *>::iterator GraphShard::add_node(uint64_t id,
std::string feature) {
if (node_location.find(id) != node_location.end())
return node_location.find(id)->second;

Expand All @@ -89,14 +90,13 @@ GraphNode *GraphShard::find_node(uint64_t id) {

// Entry point for loading graph data from `path`.
//
// `param` is split on "|" into a set of option strings:
//   - "edge":    when present, the call is forwarded to load_edges();
//                otherwise it is forwarded to load_nodes().
//   - "reverse": when present, reverse_edge is true and is passed through to
//                load_edges().
// Returns whatever the selected loader returns.
//
// NOTE(review): this span comes from a rendered diff, so the pre-change and
// post-change lines appear back to back (cmd_set is declared twice and the
// if/else dispatch is shown in both its old and reformatted form). The two
// variants differ only in formatting.
int32_t GraphTable::load(const std::string &path, const std::string &param) {
auto cmd = paddle::string::split_string<std::string>(param, "|");
std::set<std::string> cmd_set(cmd.begin(), cmd.end());
std::set<std::string> cmd_set(cmd.begin(), cmd.end());
// count() yields 0 or 1 for a std::set, so these act as presence flags.
bool reverse_edge = cmd_set.count(std::string("reverse"));
bool load_edge = cmd_set.count(std::string("edge"));
if(load_edge) {
return this -> load_edges(path, reverse_edge);
}
else {
return this -> load_nodes(path);
if (load_edge) {
return this->load_edges(path, reverse_edge);
} else {
return this->load_nodes(path);
}
}

Expand All @@ -110,33 +110,28 @@ int32_t GraphTable::load_nodes(const std::string &path) {
if (values.size() < 2) continue;
auto id = std::stoull(values[1]);


size_t shard_id = id % shard_num;
if (shard_id >= shard_end || shard_id < shard_start) {
VLOG(4) << "will not load " << id << " from " << path
<< ", please check id distribution";
continue;

}

std::string node_type = values[0];
std::vector<std::string > feature;
std::vector<std::string> feature;
feature.push_back(node_type);
for(size_t slice = 2; slice < values.size(); slice ++) {
for (size_t slice = 2; slice < values.size(); slice++) {
feature.push_back(values[slice]);
}
auto feat = paddle::string::join_strings(feature, '\t');
size_t index = shard_id - shard_start;
shards[index].add_node(id, feat);

}
}
return 0;
}


int32_t GraphTable::load_edges(const std::string &path, bool reverse_edge) {

auto paths = paddle::string::split_string<std::string>(path, ";");
int count = 0;

Expand Down Expand Up @@ -173,7 +168,7 @@ int32_t GraphTable::load_edges(const std::string &path, bool reverse_edge) {
VLOG(0) << "Load Finished Total Edge Count " << count;

// Build Sampler j

for (auto &shard : shards) {
auto bucket = shard.get_bucket();
for (int i = 0; i < bucket.size(); i++) {
Expand All @@ -200,46 +195,49 @@ GraphNode *GraphTable::find_node(uint64_t id) {
// Picks the task-pool slot for a node: the id is reduced modulo
// shard_num_per_table first, and that remainder is then reduced modulo
// task_pool_size_.
uint32_t GraphTable::get_thread_pool_index(uint64_t node_id) {
  const auto slot_in_table = node_id % shard_num_per_table;
  return slot_in_table % task_pool_size_;
}
// Samples neighbours for a batch of nodes and serialises each result into a
// per-node byte buffer.
//
// Parameters:
//   node_ids     - array of node ids; entries [0, buffers.size()) are read.
//   sample_size  - forwarded to GraphNode::sample_k() for every node.
//   buffers      - one output slot per node; buffers.size() defines how many
//                  nodes are processed.
//   actual_sizes - one output slot per node, set to the number of bytes
//                  written into the matching buffer (0 when the node is not
//                  found).
// Returns 0 after every enqueued task has been waited on.
//
// Per node, a task is enqueued on
// _shards_task_pool[get_thread_pool_index(node_id)]. The task looks the node
// up with find_node(); for a found node it calls sample_k(sample_size),
// allocates res.size() * (id_size + weight_size) bytes and memcpy's each
// sampled edge as its id followed by its weight.
//
// NOTE(review): this span comes from a rendered diff, so pre-change and
// post-change lines are interleaved. The older lines use std::vector<char*>
// and assign a raw `new char[]` to `buffer`; the newer lines use
// std::vector<std::unique_ptr<char[]>> and hand the allocation over with
// buffer.reset(buffer_addr).
//
// NOTE(review): the lambda uses a [&] capture and reads `node_id`, which is
// declared inside the loop body. If enqueue() can run the task after that
// loop iteration has finished, this is a dangling reference; capturing
// node_id by value would avoid it. enqueue() is not visible here - verify.
int GraphTable::random_sample(uint64_t* node_ids, int sample_size,
std::vector<char*>& buffers, std::vector<int> &actual_sizes) {
int GraphTable::random_sample(uint64_t *node_ids, int sample_size,
std::vector<std::unique_ptr<char[]>> &buffers,
std::vector<int> &actual_sizes) {
size_t node_num = buffers.size();
std::vector<std::future<int>> tasks;
for (size_t idx = 0; idx < node_num; ++idx){
for (size_t idx = 0; idx < node_num; ++idx) {
uint64_t node_id = node_ids[idx];
char* & buffer = buffers[idx];
int& actual_size = actual_sizes[idx];
tasks.push_back(_shards_task_pool[get_thread_pool_index(node_id)]
->enqueue([&]() -> int {
GraphNode *node = find_node(node_id);
if (node == NULL) {
actual_size = 0;
std::unique_ptr<char[]> &buffer = buffers[idx];
int &actual_size = actual_sizes[idx];
tasks.push_back(_shards_task_pool[get_thread_pool_index(node_id)]->enqueue(
[&]() -> int {
GraphNode *node = find_node(node_id);
if (node == NULL) {
actual_size = 0;
return 0;
}
std::vector<GraphEdge *> res = node->sample_k(sample_size);
actual_size =
res.size() * (GraphNode::id_size + GraphNode::weight_size);
int offset = 0;
uint64_t id;
float weight;
char *buffer_addr = new char[actual_size];
buffer.reset(buffer_addr);
for (auto &x : res) {
id = x->get_id();
weight = x->get_weight();
memcpy(buffer_addr + offset, &id, GraphNode::id_size);
offset += GraphNode::id_size;
memcpy(buffer_addr + offset, &weight, GraphNode::weight_size);
offset += GraphNode::weight_size;
return 0;
}
return 0;
}
std::vector<GraphEdge *> res = node->sample_k(sample_size);
std::vector<GraphNode> node_list;
actual_size =
res.size() * (GraphNode::id_size + GraphNode::weight_size);
buffer = new char[actual_size];
int offset = 0;
uint64_t id;
float weight;
for (auto &x : res) {
id = x->get_id();
weight = x->get_weight();
memcpy(buffer + offset, &id, GraphNode::id_size);
offset += GraphNode::id_size;
memcpy(buffer + offset, &weight, GraphNode::weight_size);
offset += GraphNode::weight_size;
}
return 0;
}));
}));
}
// Block until every per-node task has finished before returning.
for (size_t idx = 0; idx < node_num; ++idx){
for (size_t idx = 0; idx < node_num; ++idx) {
tasks[idx].get();
}
return 0;
}
int32_t GraphTable::pull_graph_list(int start, int total_size, char *&buffer,
int32_t GraphTable::pull_graph_list(int start, int total_size,
std::unique_ptr<char[]> &buffer,
int &actual_size) {
if (start < 0) start = 0;
int size = 0, cur_size;
Expand Down Expand Up @@ -283,11 +281,12 @@ int32_t GraphTable::pull_graph_list(int start, int total_size, char *&buffer,
size += res.back()[j]->get_size();
}
}
buffer = new char[size];
char *buffer_addr = new char[size];
buffer.reset(buffer_addr);
int index = 0;
for (size_t i = 0; i < res.size(); i++) {
for (size_t j = 0; j < res[i].size(); j++) {
res[i][j]->to_buffer(buffer + index);
res[i][j]->to_buffer(buffer_addr + index);
index += res[i][j]->get_size();
}
}
Expand Down Expand Up @@ -321,4 +320,3 @@ int32_t GraphTable::initialize() {
}
}
};

Loading

0 comments on commit 6c38fa0

Please sign in to comment.