diff --git a/paddle/fluid/distributed/service/graph_brpc_client.cc b/paddle/fluid/distributed/service/graph_brpc_client.cc index 9f65a66708def..faef3d7905fd6 100644 --- a/paddle/fluid/distributed/service/graph_brpc_client.cc +++ b/paddle/fluid/distributed/service/graph_brpc_client.cc @@ -302,7 +302,7 @@ std::future GraphBrpcClient::remove_graph_node( return fut; } // char* &buffer,int &actual_size -std::future GraphBrpcClient::batch_sample_neighboors( +std::future GraphBrpcClient::batch_sample_neighbors( uint32_t table_id, std::vector node_ids, int sample_size, std::vector>> &res, int server_index) { @@ -390,8 +390,8 @@ std::future GraphBrpcClient::batch_sample_neighboors( size_t fail_num = 0; for (size_t request_idx = 0; request_idx < request_call_num; ++request_idx) { - if (closure->check_response(request_idx, - PS_GRAPH_SAMPLE_NEIGHBOORS) != 0) { + if (closure->check_response(request_idx, PS_GRAPH_SAMPLE_NEIGHBORS) != + 0) { ++fail_num; } else { auto &res_io_buffer = @@ -435,7 +435,7 @@ std::future GraphBrpcClient::batch_sample_neighboors( for (int request_idx = 0; request_idx < request_call_num; ++request_idx) { int server_index = request2server[request_idx]; - closure->request(request_idx)->set_cmd_id(PS_GRAPH_SAMPLE_NEIGHBOORS); + closure->request(request_idx)->set_cmd_id(PS_GRAPH_SAMPLE_NEIGHBORS); closure->request(request_idx)->set_table_id(table_id); closure->request(request_idx)->set_client_id(_client_id); size_t node_num = node_id_buckets[request_idx].size(); @@ -494,6 +494,47 @@ std::future GraphBrpcClient::random_sample_nodes( closure); return fut; } + +std::future GraphBrpcClient::use_neighbors_sample_cache( + uint32_t table_id, size_t total_size_limit, size_t ttl) { + DownpourBrpcClosure *closure = new DownpourBrpcClosure( + server_size, [&, server_size = this->server_size ](void *done) { + int ret = 0; + auto *closure = (DownpourBrpcClosure *)done; + size_t fail_num = 0; + for (size_t request_idx = 0; request_idx < server_size; ++request_idx) { + if (closure->check_response( + request_idx, PS_GRAPH_USE_NEIGHBORS_SAMPLE_CACHE) != 0) { + ++fail_num; + break; + } + } + ret = fail_num == 0 ? 0 : -1; + closure->set_promise_value(ret); + }); + auto promise = std::make_shared>(); + closure->add_promise(promise); + size_t size_limit = total_size_limit / server_size + + (total_size_limit % server_size != 0 ? 1 : 0); + std::future fut = promise->get_future(); + for (size_t i = 0; i < server_size; i++) { + int server_index = i; + closure->request(server_index) + ->set_cmd_id(PS_GRAPH_USE_NEIGHBORS_SAMPLE_CACHE); + closure->request(server_index)->set_table_id(table_id); + closure->request(server_index)->set_client_id(_client_id); + closure->request(server_index) + ->add_params((char *)&size_limit, sizeof(size_t)); + closure->request(server_index)->add_params((char *)&ttl, sizeof(size_t)); + GraphPsService_Stub rpc_stub = + getServiceStub(get_cmd_channel(server_index)); + closure->cntl(server_index)->set_log_id(butil::gettimeofday_ms()); + rpc_stub.service(closure->cntl(server_index), + closure->request(server_index), + closure->response(server_index), closure); + } + return fut; +} std::future GraphBrpcClient::pull_graph_list( uint32_t table_id, int server_index, int start, int size, int step, std::vector &res) { diff --git a/paddle/fluid/distributed/service/graph_brpc_client.h b/paddle/fluid/distributed/service/graph_brpc_client.h index 1fbb3fa9b0550..c1083afb71abf 100644 --- a/paddle/fluid/distributed/service/graph_brpc_client.h +++ b/paddle/fluid/distributed/service/graph_brpc_client.h @@ -61,8 +61,8 @@ class GraphBrpcClient : public BrpcPsClient { public: GraphBrpcClient() {} virtual ~GraphBrpcClient() {} - // given a batch of nodes, sample graph_neighboors for each of them - virtual std::future batch_sample_neighboors( + // given a batch of nodes, sample graph_neighbors for each of them + virtual std::future batch_sample_neighbors( uint32_t table_id, std::vector node_ids, int sample_size, std::vector>>& res, int server_index = -1); @@ -89,6 +89,9 @@ class GraphBrpcClient : public BrpcPsClient { virtual std::future add_graph_node( uint32_t table_id, std::vector& node_id_list, std::vector& is_weighted_list); + virtual std::future use_neighbors_sample_cache(uint32_t table_id, + size_t size_limit, + size_t ttl); virtual std::future remove_graph_node( uint32_t table_id, std::vector& node_id_list); virtual int32_t initialize(); diff --git a/paddle/fluid/distributed/service/graph_brpc_server.cc b/paddle/fluid/distributed/service/graph_brpc_server.cc index 424cf281bf397..0aba2b9f44ae7 100644 --- a/paddle/fluid/distributed/service/graph_brpc_server.cc +++ b/paddle/fluid/distributed/service/graph_brpc_server.cc @@ -187,8 +187,8 @@ int32_t GraphBrpcService::initialize() { _service_handler_map[PS_STOP_PROFILER] = &GraphBrpcService::stop_profiler; _service_handler_map[PS_PULL_GRAPH_LIST] = &GraphBrpcService::pull_graph_list; - _service_handler_map[PS_GRAPH_SAMPLE_NEIGHBOORS] = - &GraphBrpcService::graph_random_sample_neighboors; + _service_handler_map[PS_GRAPH_SAMPLE_NEIGHBORS] = + &GraphBrpcService::graph_random_sample_neighbors; _service_handler_map[PS_GRAPH_SAMPLE_NODES] = &GraphBrpcService::graph_random_sample_nodes; _service_handler_map[PS_GRAPH_GET_NODE_FEAT] = @@ -201,8 +201,9 @@ int32_t GraphBrpcService::initialize() { _service_handler_map[PS_GRAPH_SET_NODE_FEAT] = &GraphBrpcService::graph_set_node_feat; _service_handler_map[PS_GRAPH_SAMPLE_NODES_FROM_ONE_SERVER] = - &GraphBrpcService::sample_neighboors_across_multi_servers; - + &GraphBrpcService::sample_neighbors_across_multi_servers; + _service_handler_map[PS_GRAPH_USE_NEIGHBORS_SAMPLE_CACHE] = + &GraphBrpcService::use_neighbors_sample_cache; // shard初始化,server启动后才可从env获取到server_list的shard信息 initialize_shard_info(); @@ -373,7 +374,7 @@ int32_t GraphBrpcService::pull_graph_list(Table *table, cntl->response_attachment().append(buffer.get(), actual_size); return 0; } -int32_t GraphBrpcService::graph_random_sample_neighboors( +int32_t GraphBrpcService::graph_random_sample_neighbors( Table *table, const PsRequestMessage &request, PsResponseMessage &response, brpc::Controller *cntl) { CHECK_TABLE_EXIST(table, request, response) @@ -389,7 +390,7 @@ int32_t GraphBrpcService::graph_random_sample_neighboors( std::vector> buffers(node_num); std::vector actual_sizes(node_num, 0); ((GraphTable *)table) - ->random_sample_neighboors(node_data, sample_size, buffers, actual_sizes); + ->random_sample_neighbors(node_data, sample_size, buffers, actual_sizes); cntl->response_attachment().append(&node_num, sizeof(size_t)); cntl->response_attachment().append(actual_sizes.data(), @@ -448,7 +449,7 @@ int32_t GraphBrpcService::graph_get_node_feat(Table *table, return 0; } -int32_t GraphBrpcService::sample_neighboors_across_multi_servers( +int32_t GraphBrpcService::sample_neighbors_across_multi_servers( Table *table, const PsRequestMessage &request, PsResponseMessage &response, brpc::Controller *cntl) { // sleep(5); @@ -456,7 +457,7 @@ int32_t GraphBrpcService::sample_neighboors_across_multi_servers( if (request.params_size() < 2) { set_response_code( response, -1, - "graph_random_sample request requires at least 2 arguments"); + "graph_random_neighbors_sample request requires at least 2 arguments"); return 0; } size_t node_num = request.params(0).size() / sizeof(uint64_t), @@ -519,7 +520,7 @@ int32_t GraphBrpcService::sample_neighboors_across_multi_servers( remote_call_num); size_t fail_num = 0; for (size_t request_idx = 0; request_idx < remote_call_num; ++request_idx) { - if (closure->check_response(request_idx, PS_GRAPH_SAMPLE_NEIGHBOORS) != + if (closure->check_response(request_idx, PS_GRAPH_SAMPLE_NEIGHBORS) != 0) { ++fail_num; failed[request2server[request_idx]] = true; @@ -570,7 +571,7 @@ int32_t GraphBrpcService::sample_neighboors_across_multi_servers( for (int request_idx = 0; request_idx < remote_call_num; ++request_idx) { int server_index = request2server[request_idx]; - closure->request(request_idx)->set_cmd_id(PS_GRAPH_SAMPLE_NEIGHBOORS); + closure->request(request_idx)->set_cmd_id(PS_GRAPH_SAMPLE_NEIGHBORS); closure->request(request_idx)->set_table_id(request.table_id()); closure->request(request_idx)->set_client_id(rank); size_t node_num = node_id_buckets[request_idx].size(); @@ -590,8 +591,8 @@ int32_t GraphBrpcService::sample_neighboors_across_multi_servers( } if (server2request[rank] != -1) { ((GraphTable *)table) - ->random_sample_neighboors(node_id_buckets.back().data(), sample_size, - local_buffers, local_actual_sizes); + ->random_sample_neighbors(node_id_buckets.back().data(), sample_size, + local_buffers, local_actual_sizes); } local_promise.get()->set_value(0); if (remote_call_num == 0) func(closure); @@ -636,5 +637,20 @@ int32_t GraphBrpcService::graph_set_node_feat(Table *table, return 0; } +int32_t GraphBrpcService::use_neighbors_sample_cache( + Table *table, const PsRequestMessage &request, PsResponseMessage &response, + brpc::Controller *cntl) { + CHECK_TABLE_EXIST(table, request, response) + if (request.params_size() < 2) { + set_response_code(response, -1, + "use_neighbors_sample_cache request requires at least 2 " + "arguments[cache_size, ttl]"); + return 0; + } + size_t size_limit = *(size_t *)(request.params(0).c_str()); + size_t ttl = *(size_t *)(request.params(1).c_str()); + ((GraphTable *)table)->make_neighbor_sample_cache(size_limit, ttl); + return 0; +} } // namespace distributed } // namespace paddle diff --git a/paddle/fluid/distributed/service/graph_brpc_server.h b/paddle/fluid/distributed/service/graph_brpc_server.h index 817fe08331165..d1a6aa63604f3 100644 --- a/paddle/fluid/distributed/service/graph_brpc_server.h +++ b/paddle/fluid/distributed/service/graph_brpc_server.h @@ -78,10 +78,10 @@ class GraphBrpcService : public PsBaseService { int32_t initialize_shard_info(); int32_t pull_graph_list(Table *table, const PsRequestMessage &request, PsResponseMessage &response, brpc::Controller *cntl); - int32_t graph_random_sample_neighboors(Table *table, - const PsRequestMessage &request, - PsResponseMessage &response, - brpc::Controller *cntl); + int32_t graph_random_sample_neighbors(Table *table, + const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl); int32_t graph_random_sample_nodes(Table *table, const PsRequestMessage &request, PsResponseMessage &response, @@ -116,9 +116,15 @@ class GraphBrpcService : public PsBaseService { int32_t print_table_stat(Table *table, const PsRequestMessage &request, PsResponseMessage &response, brpc::Controller *cntl); - int32_t sample_neighboors_across_multi_servers( - Table *table, const PsRequestMessage &request, - PsResponseMessage &response, brpc::Controller *cntl); + int32_t sample_neighbors_across_multi_servers(Table *table, + const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl); + + int32_t use_neighbors_sample_cache(Table *table, + const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl); private: bool _is_initialize_shard_info; diff --git a/paddle/fluid/distributed/service/graph_py_service.cc b/paddle/fluid/distributed/service/graph_py_service.cc index 498805136417f..78f239f80d445 100644 --- a/paddle/fluid/distributed/service/graph_py_service.cc +++ b/paddle/fluid/distributed/service/graph_py_service.cc @@ -290,19 +290,29 @@ void GraphPyClient::load_node_file(std::string name, std::string filepath) { } } std::vector>> -GraphPyClient::batch_sample_neighboors(std::string name, - std::vector node_ids, - int sample_size) { +GraphPyClient::batch_sample_neighbors(std::string name, + std::vector node_ids, + int sample_size) { std::vector>> v; if (this->table_id_map.count(name)) { uint32_t table_id = this->table_id_map[name]; auto status = - worker_ptr->batch_sample_neighboors(table_id, node_ids, sample_size, v); + worker_ptr->batch_sample_neighbors(table_id, node_ids, sample_size, v); status.wait(); } return v; } +void GraphPyClient::use_neighbors_sample_cache(std::string name, + size_t total_size_limit, + size_t ttl) { + if (this->table_id_map.count(name)) { + uint32_t table_id = this->table_id_map[name]; + auto status = + worker_ptr->use_neighbors_sample_cache(table_id, total_size_limit, ttl); + status.wait(); + } +} std::vector GraphPyClient::random_sample_nodes(std::string name, int server_index, int sample_size) { diff --git a/paddle/fluid/distributed/service/graph_py_service.h b/paddle/fluid/distributed/service/graph_py_service.h index 8e03938801ce9..2d36edbf9c17d 100644 --- a/paddle/fluid/distributed/service/graph_py_service.h +++ b/paddle/fluid/distributed/service/graph_py_service.h @@ -148,13 +148,15 @@ class GraphPyClient : public GraphPyService { int get_client_id() { return client_id; } void set_client_id(int client_id) { this->client_id = client_id; } void start_client(); - std::vector>> batch_sample_neighboors( + std::vector>> batch_sample_neighbors( std::string name, std::vector node_ids, int sample_size); std::vector random_sample_nodes(std::string name, int server_index, int sample_size); std::vector> get_node_feat( std::string node_type, std::vector node_ids, std::vector feature_names); + void use_neighbors_sample_cache(std::string name, size_t total_size_limit, + size_t ttl); void set_node_feat(std::string node_type, std::vector node_ids, std::vector feature_names, const std::vector> features); diff --git a/paddle/fluid/distributed/service/sendrecv.proto b/paddle/fluid/distributed/service/sendrecv.proto index 42e25258ec3fe..8ee9b3590721a 100644 --- a/paddle/fluid/distributed/service/sendrecv.proto +++ b/paddle/fluid/distributed/service/sendrecv.proto @@ -49,7 +49,7 @@ enum PsCmdID { PS_STOP_PROFILER = 28; PS_PUSH_GLOBAL_STEP = 29; PS_PULL_GRAPH_LIST = 30; - PS_GRAPH_SAMPLE_NEIGHBOORS = 31; + PS_GRAPH_SAMPLE_NEIGHBORS = 31; PS_GRAPH_SAMPLE_NODES = 32; PS_GRAPH_GET_NODE_FEAT = 33; PS_GRAPH_CLEAR = 34; @@ -57,6 +57,7 @@ enum PsCmdID { PS_GRAPH_REMOVE_GRAPH_NODE = 36; PS_GRAPH_SET_NODE_FEAT = 37; PS_GRAPH_SAMPLE_NODES_FROM_ONE_SERVER = 38; + PS_GRAPH_USE_NEIGHBORS_SAMPLE_CACHE = 39; } message PsRequestMessage { diff --git a/paddle/fluid/distributed/table/common_graph_table.cc b/paddle/fluid/distributed/table/common_graph_table.cc index 47b966182e682..96ebf039aae77 100644 --- a/paddle/fluid/distributed/table/common_graph_table.cc +++ b/paddle/fluid/distributed/table/common_graph_table.cc @@ -392,7 +392,7 @@ int32_t GraphTable::random_sample_nodes(int sample_size, memcpy(pointer, res.data(), actual_size); return 0; } -int32_t GraphTable::random_sample_neighboors( +int32_t GraphTable::random_sample_neighbors( uint64_t *node_ids, int sample_size, std::vector> &buffers, std::vector &actual_sizes) { diff --git a/paddle/fluid/distributed/table/common_graph_table.h b/paddle/fluid/distributed/table/common_graph_table.h index e09ca71925c03..91f2b1c029d80 100644 --- a/paddle/fluid/distributed/table/common_graph_table.h +++ b/paddle/fluid/distributed/table/common_graph_table.h @@ -367,7 +367,7 @@ class GraphTable : public SparseTable { int &actual_size, bool need_feature, int step); - virtual int32_t random_sample_neighboors( + virtual int32_t random_sample_neighbors( uint64_t *node_ids, int sample_size, std::vector> &buffers, std::vector &actual_sizes); @@ -427,7 +427,7 @@ class GraphTable : public SparseTable { size_t get_server_num() { return server_num; } - virtual int32_t make_neigh_sample_cache(size_t size_limit, size_t ttl) { + virtual int32_t make_neighbor_sample_cache(size_t size_limit, size_t ttl) { { std::unique_lock lock(mutex_); if (use_cache == false) { diff --git a/paddle/fluid/distributed/test/graph_node_test.cc b/paddle/fluid/distributed/test/graph_node_test.cc index 0674ef13d7c45..c061fe0bb909d 100644 --- a/paddle/fluid/distributed/test/graph_node_test.cc +++ b/paddle/fluid/distributed/test/graph_node_test.cc @@ -111,7 +111,7 @@ void testFeatureNodeSerializeFloat64() { void testSingleSampleNeighboor( std::shared_ptr& worker_ptr_) { std::vector>> vs; - auto pull_status = worker_ptr_->batch_sample_neighboors( + auto pull_status = worker_ptr_->batch_sample_neighbors( 0, std::vector(1, 37), 4, vs); pull_status.wait(); @@ -127,7 +127,7 @@ void testSingleSampleNeighboor( s.clear(); s1.clear(); vs.clear(); - pull_status = worker_ptr_->batch_sample_neighboors( + pull_status = worker_ptr_->batch_sample_neighbors( 0, std::vector(1, 96), 4, vs); pull_status.wait(); s1 = {111, 48, 247}; @@ -139,7 +139,7 @@ void testSingleSampleNeighboor( ASSERT_EQ(true, s1.find(g) != s1.end()); } vs.clear(); - pull_status = worker_ptr_->batch_sample_neighboors(0, {96, 37}, 4, vs, 0); + pull_status = worker_ptr_->batch_sample_neighbors(0, {96, 37}, 4, vs, 0); pull_status.wait(); ASSERT_EQ(vs.size(), 2); } @@ -199,7 +199,7 @@ void testBatchSampleNeighboor( std::shared_ptr& worker_ptr_) { std::vector>> vs; std::vector v = {37, 96}; - auto pull_status = worker_ptr_->batch_sample_neighboors(0, v, 4, vs); + auto pull_status = worker_ptr_->batch_sample_neighbors(0, v, 4, vs); pull_status.wait(); std::unordered_set s; std::unordered_set s1 = {112, 45, 145}; @@ -435,24 +435,24 @@ void RunBrpcPushSparse() { sleep(5); testSingleSampleNeighboor(worker_ptr_); testBatchSampleNeighboor(worker_ptr_); - pull_status = worker_ptr_->batch_sample_neighboors( + pull_status = worker_ptr_->batch_sample_neighbors( 0, std::vector(1, 10240001024), 4, vs); pull_status.wait(); ASSERT_EQ(0, vs[0].size()); paddle::distributed::GraphTable* g = (paddle::distributed::GraphTable*)pserver_ptr_->table(0); size_t ttl = 6; - g->make_neigh_sample_cache(4, ttl); + g->make_neighbor_sample_cache(4, ttl); int round = 5; while (round--) { vs.clear(); - pull_status = worker_ptr_->batch_sample_neighboors( + pull_status = worker_ptr_->batch_sample_neighbors( 0, std::vector(1, 37), 1, vs); pull_status.wait(); for (int i = 0; i < ttl; i++) { std::vector>> vs1; - pull_status = worker_ptr_->batch_sample_neighboors( + pull_status = worker_ptr_->batch_sample_neighbors( 0, std::vector(1, 37), 1, vs1); pull_status.wait(); ASSERT_EQ(vs[0].size(), vs1[0].size()); @@ -559,13 +559,13 @@ void RunBrpcPushSparse() { ASSERT_EQ(count_item_nodes.size(), 12); } - vs = client1.batch_sample_neighboors(std::string("user2item"), - std::vector(1, 96), 4); + vs = client1.batch_sample_neighbors(std::string("user2item"), + std::vector(1, 96), 4); ASSERT_EQ(vs[0].size(), 3); std::vector node_ids; node_ids.push_back(96); node_ids.push_back(37); - vs = client1.batch_sample_neighboors(std::string("user2item"), node_ids, 4); + vs = client1.batch_sample_neighbors(std::string("user2item"), node_ids, 4); ASSERT_EQ(vs.size(), 2); std::vector nodes_ids = client2.random_sample_nodes("user", 0, 6); diff --git a/paddle/fluid/pybind/fleet_py.cc b/paddle/fluid/pybind/fleet_py.cc index 0a39f529387a2..e6b8238010a35 100644 --- a/paddle/fluid/pybind/fleet_py.cc +++ b/paddle/fluid/pybind/fleet_py.cc @@ -205,7 +205,8 @@ void BindGraphPyClient(py::module* m) { .def("add_table_feat_conf", &GraphPyClient::add_table_feat_conf) .def("pull_graph_list", &GraphPyClient::pull_graph_list) .def("start_client", &GraphPyClient::start_client) - .def("batch_sample_neighboors", &GraphPyClient::batch_sample_neighboors) + .def("batch_sample_neighboors", &GraphPyClient::batch_sample_neighbors) + .def("batch_sample_neighbors", &GraphPyClient::batch_sample_neighbors) .def("remove_graph_node", &GraphPyClient::remove_graph_node) .def("random_sample_nodes", &GraphPyClient::random_sample_nodes) .def("stop_server", &GraphPyClient::stop_server)