Commit c8d91fb: cache api on client
seemingwang committed Nov 4, 2021 (1 parent e51abd0)
Showing 11 changed files with 126 additions and 46 deletions.
49 changes: 45 additions & 4 deletions paddle/fluid/distributed/service/graph_brpc_client.cc
@@ -302,7 +302,7 @@ std::future<int32_t> GraphBrpcClient::remove_graph_node(
return fut;
}
// char* &buffer,int &actual_size
-std::future<int32_t> GraphBrpcClient::batch_sample_neighboors(
+std::future<int32_t> GraphBrpcClient::batch_sample_neighbors(
uint32_t table_id, std::vector<uint64_t> node_ids, int sample_size,
std::vector<std::vector<std::pair<uint64_t, float>>> &res,
int server_index) {
@@ -390,8 +390,8 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighboors(
size_t fail_num = 0;
for (size_t request_idx = 0; request_idx < request_call_num;
++request_idx) {
-      if (closure->check_response(request_idx,
-                                  PS_GRAPH_SAMPLE_NEIGHBOORS) != 0) {
+      if (closure->check_response(request_idx, PS_GRAPH_SAMPLE_NEIGHBORS) !=
+          0) {
++fail_num;
} else {
auto &res_io_buffer =
@@ -435,7 +435,7 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighboors(

for (int request_idx = 0; request_idx < request_call_num; ++request_idx) {
int server_index = request2server[request_idx];
-    closure->request(request_idx)->set_cmd_id(PS_GRAPH_SAMPLE_NEIGHBOORS);
+    closure->request(request_idx)->set_cmd_id(PS_GRAPH_SAMPLE_NEIGHBORS);
closure->request(request_idx)->set_table_id(table_id);
closure->request(request_idx)->set_client_id(_client_id);
size_t node_num = node_id_buckets[request_idx].size();
@@ -494,6 +494,47 @@ std::future<int32_t> GraphBrpcClient::random_sample_nodes(
closure);
return fut;
}

+std::future<int32_t> GraphBrpcClient::use_neighbors_sample_cache(
+    uint32_t table_id, size_t total_size_limit, size_t ttl) {
+  DownpourBrpcClosure *closure = new DownpourBrpcClosure(
+      server_size, [&, server_size = this->server_size ](void *done) {
+        int ret = 0;
+        auto *closure = (DownpourBrpcClosure *)done;
+        size_t fail_num = 0;
+        for (size_t request_idx = 0; request_idx < server_size; ++request_idx) {
+          if (closure->check_response(
+                  request_idx, PS_GRAPH_USE_NEIGHBORS_SAMPLE_CACHE) != 0) {
+            ++fail_num;
+            break;
+          }
+        }
+        ret = fail_num == 0 ? 0 : -1;
+        closure->set_promise_value(ret);
+      });
+  auto promise = std::make_shared<std::promise<int32_t>>();
+  closure->add_promise(promise);
+  size_t size_limit = total_size_limit / server_size +
+                      (total_size_limit % server_size != 0 ? 1 : 0);
+  std::future<int> fut = promise->get_future();
+  for (size_t i = 0; i < server_size; i++) {
+    int server_index = i;
+    closure->request(server_index)
+        ->set_cmd_id(PS_GRAPH_USE_NEIGHBORS_SAMPLE_CACHE);
+    closure->request(server_index)->set_table_id(table_id);
+    closure->request(server_index)->set_client_id(_client_id);
+    closure->request(server_index)
+        ->add_params((char *)&size_limit, sizeof(size_t));
+    closure->request(server_index)->add_params((char *)&ttl, sizeof(size_t));
+    GraphPsService_Stub rpc_stub =
+        getServiceStub(get_cmd_channel(server_index));
+    closure->cntl(server_index)->set_log_id(butil::gettimeofday_ms());
+    rpc_stub.service(closure->cntl(server_index),
+                     closure->request(server_index),
+                     closure->response(server_index), closure);
+  }
+  return fut;
+}
std::future<int32_t> GraphBrpcClient::pull_graph_list(
uint32_t table_id, int server_index, int start, int size, int step,
std::vector<FeatureNode> &res) {
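Note on the new client call above: use_neighbors_sample_cache fans one RPC out to every graph server and gives each server the ceiling of total_size_limit / server_size as its cache budget (e.g. a total limit of 100 across 3 servers gives each server 34, so the combined capacity can slightly exceed the requested total), and the returned future resolves to 0 only if every server acknowledges. A minimal caller-side sketch, not part of this commit, assuming an already initialized GraphBrpcClient; the table id, budget, and TTL values are hypothetical:

// Caller-side sketch (illustration only).
#include <cstdint>
#include <utility>
#include <vector>
#include "paddle/fluid/distributed/service/graph_brpc_client.h"

void sample_with_cache(paddle::distributed::GraphBrpcClient &client) {
  uint32_t table_id = 0;             // hypothetical graph table id
  size_t total_size_limit = 100000;  // hypothetical total cache budget
  size_t ttl = 20;                   // hypothetical time-to-live

  // One request per server; 0 means every server enabled its cache.
  auto cache_status = client.use_neighbors_sample_cache(table_id, total_size_limit, ttl);
  if (cache_status.get() != 0) {
    // Sampling still works without the cache; treat this as a soft failure.
  }

  // Subsequent neighbor sampling can now be served from each server's cache.
  std::vector<uint64_t> node_ids = {37, 96};
  std::vector<std::vector<std::pair<uint64_t, float>>> neighbors;
  client.batch_sample_neighbors(table_id, node_ids, /*sample_size=*/10, neighbors).wait();
}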
7 changes: 5 additions & 2 deletions paddle/fluid/distributed/service/graph_brpc_client.h
@@ -61,8 +61,8 @@ class GraphBrpcClient : public BrpcPsClient {
public:
GraphBrpcClient() {}
virtual ~GraphBrpcClient() {}
-  // given a batch of nodes, sample graph_neighboors for each of them
-  virtual std::future<int32_t> batch_sample_neighboors(
+  // given a batch of nodes, sample graph_neighbors for each of them
+  virtual std::future<int32_t> batch_sample_neighbors(
uint32_t table_id, std::vector<uint64_t> node_ids, int sample_size,
std::vector<std::vector<std::pair<uint64_t, float>>>& res,
int server_index = -1);
Expand All @@ -89,6 +89,9 @@ class GraphBrpcClient : public BrpcPsClient {
virtual std::future<int32_t> add_graph_node(
uint32_t table_id, std::vector<uint64_t>& node_id_list,
std::vector<bool>& is_weighted_list);
+  virtual std::future<int32_t> use_neighbors_sample_cache(uint32_t table_id,
+                                                          size_t size_limit,
+                                                          size_t ttl);
virtual std::future<int32_t> remove_graph_node(
uint32_t table_id, std::vector<uint64_t>& node_id_list);
virtual int32_t initialize();
40 changes: 28 additions & 12 deletions paddle/fluid/distributed/service/graph_brpc_server.cc
@@ -187,8 +187,8 @@ int32_t GraphBrpcService::initialize() {
_service_handler_map[PS_STOP_PROFILER] = &GraphBrpcService::stop_profiler;

_service_handler_map[PS_PULL_GRAPH_LIST] = &GraphBrpcService::pull_graph_list;
-  _service_handler_map[PS_GRAPH_SAMPLE_NEIGHBOORS] =
-      &GraphBrpcService::graph_random_sample_neighboors;
+  _service_handler_map[PS_GRAPH_SAMPLE_NEIGHBORS] =
+      &GraphBrpcService::graph_random_sample_neighbors;
_service_handler_map[PS_GRAPH_SAMPLE_NODES] =
&GraphBrpcService::graph_random_sample_nodes;
_service_handler_map[PS_GRAPH_GET_NODE_FEAT] =
Expand All @@ -201,8 +201,9 @@ int32_t GraphBrpcService::initialize() {
_service_handler_map[PS_GRAPH_SET_NODE_FEAT] =
&GraphBrpcService::graph_set_node_feat;
_service_handler_map[PS_GRAPH_SAMPLE_NODES_FROM_ONE_SERVER] =
-      &GraphBrpcService::sample_neighboors_across_multi_servers;
-
+      &GraphBrpcService::sample_neighbors_across_multi_servers;
+  _service_handler_map[PS_GRAPH_USE_NEIGHBORS_SAMPLE_CACHE] =
+      &GraphBrpcService::use_neighbors_sample_cache;
  // shard initialization; the shard info of server_list can be read from env only after the server starts
initialize_shard_info();

@@ -373,7 +374,7 @@ int32_t GraphBrpcService::pull_graph_list(Table *table,
cntl->response_attachment().append(buffer.get(), actual_size);
return 0;
}
-int32_t GraphBrpcService::graph_random_sample_neighboors(
+int32_t GraphBrpcService::graph_random_sample_neighbors(
Table *table, const PsRequestMessage &request, PsResponseMessage &response,
brpc::Controller *cntl) {
CHECK_TABLE_EXIST(table, request, response)
Expand All @@ -389,7 +390,7 @@ int32_t GraphBrpcService::graph_random_sample_neighboors(
std::vector<std::shared_ptr<char>> buffers(node_num);
std::vector<int> actual_sizes(node_num, 0);
((GraphTable *)table)
-      ->random_sample_neighboors(node_data, sample_size, buffers, actual_sizes);
+      ->random_sample_neighbors(node_data, sample_size, buffers, actual_sizes);

cntl->response_attachment().append(&node_num, sizeof(size_t));
cntl->response_attachment().append(actual_sizes.data(),
@@ -448,15 +449,15 @@ int32_t GraphBrpcService::graph_get_node_feat(Table *table,

return 0;
}
-int32_t GraphBrpcService::sample_neighboors_across_multi_servers(
+int32_t GraphBrpcService::sample_neighbors_across_multi_servers(
Table *table, const PsRequestMessage &request, PsResponseMessage &response,
brpc::Controller *cntl) {
// sleep(5);
CHECK_TABLE_EXIST(table, request, response)
if (request.params_size() < 2) {
set_response_code(
response, -1,
"graph_random_sample request requires at least 2 arguments");
"graph_random_neighbors_sample request requires at least 2 arguments");
return 0;
}
size_t node_num = request.params(0).size() / sizeof(uint64_t),
@@ -519,7 +520,7 @@ int32_t GraphBrpcService::sample_neighboors_across_multi_servers(
remote_call_num);
size_t fail_num = 0;
for (size_t request_idx = 0; request_idx < remote_call_num; ++request_idx) {
-    if (closure->check_response(request_idx, PS_GRAPH_SAMPLE_NEIGHBOORS) !=
+    if (closure->check_response(request_idx, PS_GRAPH_SAMPLE_NEIGHBORS) !=
0) {
++fail_num;
failed[request2server[request_idx]] = true;
@@ -570,7 +571,7 @@ int32_t GraphBrpcService::sample_neighboors_across_multi_servers(

for (int request_idx = 0; request_idx < remote_call_num; ++request_idx) {
int server_index = request2server[request_idx];
-    closure->request(request_idx)->set_cmd_id(PS_GRAPH_SAMPLE_NEIGHBOORS);
+    closure->request(request_idx)->set_cmd_id(PS_GRAPH_SAMPLE_NEIGHBORS);
closure->request(request_idx)->set_table_id(request.table_id());
closure->request(request_idx)->set_client_id(rank);
size_t node_num = node_id_buckets[request_idx].size();
Expand All @@ -590,8 +591,8 @@ int32_t GraphBrpcService::sample_neighboors_across_multi_servers(
}
if (server2request[rank] != -1) {
((GraphTable *)table)
-        ->random_sample_neighboors(node_id_buckets.back().data(), sample_size,
-                                   local_buffers, local_actual_sizes);
+        ->random_sample_neighbors(node_id_buckets.back().data(), sample_size,
+                                  local_buffers, local_actual_sizes);
}
local_promise.get()->set_value(0);
if (remote_call_num == 0) func(closure);
@@ -636,5 +637,20 @@ int32_t GraphBrpcService::graph_set_node_feat(Table *table,
return 0;
}

+int32_t GraphBrpcService::use_neighbors_sample_cache(
+    Table *table, const PsRequestMessage &request, PsResponseMessage &response,
+    brpc::Controller *cntl) {
+  CHECK_TABLE_EXIST(table, request, response)
+  if (request.params_size() < 2) {
+    set_response_code(response, -1,
+                      "use_neighbors_sample_cache request requires at least 2 "
+                      "arguments[cache_size, ttl]");
+    return 0;
+  }
+  size_t size_limit = *(size_t *)(request.params(0).c_str());
+  size_t ttl = *(size_t *)(request.params(1).c_str());
+  ((GraphTable *)table)->make_neighbor_sample_cache(size_limit, ttl);
+  return 0;
+}
} // namespace distributed
} // namespace paddle
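The new handler relies on the raw-bytes parameter convention used throughout this client/server pair: each size_t argument is shipped as its in-memory bytes via add_params on the client and cast back out of request.params(i) on the server. A sketch of that round trip (my illustration, not code from this commit); it assumes both ends agree on sizeof(size_t) and byte order, which holds when client and server are built for the same architecture:

#include <cstddef>
#include <cstring>
#include <string>

// Pack a size_t exactly the way the client's
// add_params((char *)&v, sizeof(size_t)) does.
std::string pack_size_t(size_t v) {
  return std::string(reinterpret_cast<const char *>(&v), sizeof(size_t));
}

// Recover it the way the server's *(size_t *)(param.c_str()) does,
// but via memcpy to sidestep alignment concerns.
size_t unpack_size_t(const std::string &param) {
  size_t v = 0;
  std::memcpy(&v, param.data(), sizeof(size_t));
  return v;
}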
20 changes: 13 additions & 7 deletions paddle/fluid/distributed/service/graph_brpc_server.h
@@ -78,10 +78,10 @@ class GraphBrpcService : public PsBaseService {
int32_t initialize_shard_info();
int32_t pull_graph_list(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
-  int32_t graph_random_sample_neighboors(Table *table,
-                                         const PsRequestMessage &request,
-                                         PsResponseMessage &response,
-                                         brpc::Controller *cntl);
+  int32_t graph_random_sample_neighbors(Table *table,
+                                        const PsRequestMessage &request,
+                                        PsResponseMessage &response,
+                                        brpc::Controller *cntl);
int32_t graph_random_sample_nodes(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
@@ -116,9 +116,15 @@ class GraphBrpcService : public PsBaseService {
int32_t print_table_stat(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);

-  int32_t sample_neighboors_across_multi_servers(
-      Table *table, const PsRequestMessage &request,
-      PsResponseMessage &response, brpc::Controller *cntl);
+  int32_t sample_neighbors_across_multi_servers(Table *table,
+                                                const PsRequestMessage &request,
+                                                PsResponseMessage &response,
+                                                brpc::Controller *cntl);
+
+  int32_t use_neighbors_sample_cache(Table *table,
+                                     const PsRequestMessage &request,
+                                     PsResponseMessage &response,
+                                     brpc::Controller *cntl);

private:
bool _is_initialize_shard_info;
18 changes: 14 additions & 4 deletions paddle/fluid/distributed/service/graph_py_service.cc
@@ -290,19 +290,29 @@ void GraphPyClient::load_node_file(std::string name, std::string filepath) {
}
}
std::vector<std::vector<std::pair<uint64_t, float>>>
-GraphPyClient::batch_sample_neighboors(std::string name,
-                                       std::vector<uint64_t> node_ids,
-                                       int sample_size) {
+GraphPyClient::batch_sample_neighbors(std::string name,
+                                      std::vector<uint64_t> node_ids,
+                                      int sample_size) {
std::vector<std::vector<std::pair<uint64_t, float>>> v;
if (this->table_id_map.count(name)) {
uint32_t table_id = this->table_id_map[name];
auto status =
-        worker_ptr->batch_sample_neighboors(table_id, node_ids, sample_size, v);
+        worker_ptr->batch_sample_neighbors(table_id, node_ids, sample_size, v);
status.wait();
}
return v;
}

+void GraphPyClient::use_neighbors_sample_cache(std::string name,
+                                               size_t total_size_limit,
+                                               size_t ttl) {
+  if (this->table_id_map.count(name)) {
+    uint32_t table_id = this->table_id_map[name];
+    auto status =
+        worker_ptr->use_neighbors_sample_cache(table_id, total_size_limit, ttl);
+    status.wait();
+  }
+}
std::vector<uint64_t> GraphPyClient::random_sample_nodes(std::string name,
int server_index,
int sample_size) {
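At the GraphPyClient level the call is keyed by table name rather than table id, and an unknown name is silently ignored (the branch above does nothing when table_id_map has no entry), so a misspelled name leaves the cache disabled without any error. A hypothetical usage sketch with an invented table name and sizes, not taken from this commit:

#include <cstdint>
#include <vector>
#include "paddle/fluid/distributed/service/graph_py_service.h"

void enable_cache_and_sample(paddle::distributed::GraphPyClient &client) {
  // "u2u" is a hypothetical edge-table name assumed to be in table_id_map.
  client.use_neighbors_sample_cache("u2u", /*total_size_limit=*/200000,
                                    /*ttl=*/10);
  std::vector<uint64_t> node_ids = {1, 2, 3};
  // Each inner vector holds (neighbor_id, weight) pairs for one input node.
  auto neighbors = client.batch_sample_neighbors("u2u", node_ids,
                                                 /*sample_size=*/5);
}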
4 changes: 3 additions & 1 deletion paddle/fluid/distributed/service/graph_py_service.h
@@ -148,13 +148,15 @@ class GraphPyClient : public GraphPyService {
int get_client_id() { return client_id; }
void set_client_id(int client_id) { this->client_id = client_id; }
void start_client();
-  std::vector<std::vector<std::pair<uint64_t, float>>> batch_sample_neighboors(
+  std::vector<std::vector<std::pair<uint64_t, float>>> batch_sample_neighbors(
std::string name, std::vector<uint64_t> node_ids, int sample_size);
std::vector<uint64_t> random_sample_nodes(std::string name, int server_index,
int sample_size);
std::vector<std::vector<std::string>> get_node_feat(
std::string node_type, std::vector<uint64_t> node_ids,
std::vector<std::string> feature_names);
+  void use_neighbors_sample_cache(std::string name, size_t total_size_limit,
+                                  size_t ttl);
void set_node_feat(std::string node_type, std::vector<uint64_t> node_ids,
std::vector<std::string> feature_names,
const std::vector<std::vector<std::string>> features);
3 changes: 2 additions & 1 deletion paddle/fluid/distributed/service/sendrecv.proto
@@ -49,14 +49,15 @@ enum PsCmdID {
PS_STOP_PROFILER = 28;
PS_PUSH_GLOBAL_STEP = 29;
PS_PULL_GRAPH_LIST = 30;
-  PS_GRAPH_SAMPLE_NEIGHBOORS = 31;
+  PS_GRAPH_SAMPLE_NEIGHBORS = 31;
PS_GRAPH_SAMPLE_NODES = 32;
PS_GRAPH_GET_NODE_FEAT = 33;
PS_GRAPH_CLEAR = 34;
PS_GRAPH_ADD_GRAPH_NODE = 35;
PS_GRAPH_REMOVE_GRAPH_NODE = 36;
PS_GRAPH_SET_NODE_FEAT = 37;
PS_GRAPH_SAMPLE_NODES_FROM_ONE_SERVER = 38;
+  PS_GRAPH_USE_NEIGHBORS_SAMPLE_CACHE = 39;
}

message PsRequestMessage {
2 changes: 1 addition & 1 deletion paddle/fluid/distributed/table/common_graph_table.cc
@@ -392,7 +392,7 @@ int32_t GraphTable::random_sample_nodes(int sample_size,
memcpy(pointer, res.data(), actual_size);
return 0;
}
-int32_t GraphTable::random_sample_neighboors(
+int32_t GraphTable::random_sample_neighbors(
uint64_t *node_ids, int sample_size,
std::vector<std::shared_ptr<char>> &buffers,
std::vector<int> &actual_sizes) {
4 changes: 2 additions & 2 deletions paddle/fluid/distributed/table/common_graph_table.h
@@ -367,7 +367,7 @@ class GraphTable : public SparseTable {
int &actual_size, bool need_feature,
int step);

-  virtual int32_t random_sample_neighboors(
+  virtual int32_t random_sample_neighbors(
uint64_t *node_ids, int sample_size,
std::vector<std::shared_ptr<char>> &buffers,
std::vector<int> &actual_sizes);
@@ -427,7 +427,7 @@ class GraphTable : public SparseTable {

size_t get_server_num() { return server_num; }

-  virtual int32_t make_neigh_sample_cache(size_t size_limit, size_t ttl) {
+  virtual int32_t make_neighbor_sample_cache(size_t size_limit, size_t ttl) {
{
std::unique_lock<std::mutex> lock(mutex_);
if (use_cache == false) {
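For context, make_neighbor_sample_cache(size_limit, ttl) is what the new RPC ultimately reaches on each server: the hunk above shows it creating the table's sample cache under a lock the first time it is called. As a rough mental model only (a deliberately simplified sketch, not the table's actual cache implementation; here ttl is counted in cache operations rather than wall-clock time, which is an assumption), a size-limited TTL cache can look like this:

#include <cstddef>
#include <unordered_map>
#include <utility>

template <typename K, typename V>
class SampleCache {
 public:
  SampleCache(size_t size_limit, size_t ttl)
      : size_limit_(size_limit), ttl_(ttl) {}

  // Returns true and fills `out` when `key` is present and not expired.
  bool query(const K &key, V *out) {
    ++tick_;
    auto it = map_.find(key);
    if (it == map_.end() || tick_ - it->second.second > ttl_) return false;
    *out = it->second.first;
    return true;
  }

  // Inserts under the size budget; eviction here is crude (an arbitrary
  // entry), where a production cache would evict least-recently-used entries.
  void insert(const K &key, V value) {
    ++tick_;
    if (map_.size() >= size_limit_ && map_.find(key) == map_.end()) {
      map_.erase(map_.begin());
    }
    map_[key] = {std::move(value), tick_};
  }

 private:
  size_t size_limit_;
  size_t ttl_;
  size_t tick_ = 0;
  std::unordered_map<K, std::pair<V, size_t>> map_;  // key -> (value, insert tick)
};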

1 comment on commit c8d91fb

@paddle-bot-old (bot) commented on c8d91fb, Nov 4, 2021

🕵️ CI failures summary

🔍 PR: #36990 Commit ID: c8d91fb contains failed CI.

🔹 Failed: PR-CI-CINN

Unknown Failed
2021-11-04 18:00:00 7e1b4c4c62f8a6f074c52b5d422549ca747b4fef Update the README QQ group number to exchange group 7 (#36964)
2021-11-04 18:00:00 9c81a9bb1e601c4ebb38057060f7724e59bfc934 Fix PTen thread safety error (#36960)
2021-11-04 18:00:00 + DCFL=Dockerfile.cuda11_cudnn8_gcc82_ubuntu18_cinn
2021-11-04 18:00:00 + echo '[==========================================]
2021-11-04 18:00:00 DockerFile : Dockerfile.cuda11_cudnn8_gcc82_ubuntu18_cinn
2021-11-04 18:00:00 [==========================================]'
2021-11-04 18:00:00 [==========================================]
2021-11-04 18:00:00 DockerFile : Dockerfile.cuda11_cudnn8_gcc82_ubuntu18_cinn
2021-11-04 18:00:00 [==========================================]
2021-11-04 18:00:00 + cd tools/dockerfile
2021-11-04 18:00:00 + bash ci_dockerfile.sh
2021-11-04 18:00:00 + md5sum Dockerfile.cuda11_cudnn8_gcc82_ubuntu18_cinn
2021-11-04 18:00:00 + cut '-d ' -f1
2021-11-04 18:00:00 + xargs echo md5=
2021-11-04 18:00:00 md5sum: Dockerfile.cuda11_cudnn8_gcc82_ubuntu18_cinn: No such file or directory
2021-11-04 18:00:00 md5=
2021-11-04 18:00:00 PADDLE DOCKER BUILD md5=md5
2021-11-04 18:00:00 check docker md5 fail !
2021-11-04 18:00:00 the build(a85acc16ca35450e944471bd9bb9188b) state is BUILD_CODE_FAIL

🔹 Failed: PR-CI-iScan-C

Unknown Failed

🔹 Failed: PR-CI-Windows-Inference

build_failed
2021-11-04 18:06:47 Note: including file: ..\paddle/fluid/platform/event.h
2021-11-04 18:06:47 Note: including file: C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.2\include\cuda_runtime.h
2021-11-04 18:06:47 Note: including file: .\paddle/fluid/platform/profiler.pb.h
2021-11-04 18:06:47 Note: including file: ..\paddle/fluid/inference/api/helper.h
2021-11-04 18:06:47 Note: including file: .\paddle/fluid/inference/api/paddle_inference_pass.h
2021-11-04 18:06:47 Note: including file: ..\paddle/fluid/inference/tests/api/config_printer.h
2021-11-04 18:06:47 Note: including file: ..\paddle/fluid/inference/tests/test_helper.h
2021-11-04 18:06:47 Note: including file: ..\paddle/fluid/inference/utils/benchmark.h
2021-11-04 18:06:47 ninja: build stopped: subcommand failed.
2021-11-04 18:06:47 7
2021-11-04 18:06:47 Build Paddle failed, will exit
2021-11-04 18:06:49 EXCODE: 7
