Skip to content

Commit

Permalink
Merge pull request #20 from Yelrose/rm_node
Browse files Browse the repository at this point in the history
add "remove feature_node pyblind";  add set_features
  • Loading branch information
seemingwang committed Aug 18, 2021
2 parents fc74f83 + ead5c00 commit 703bbb1
Show file tree
Hide file tree
Showing 11 changed files with 249 additions and 26 deletions.
96 changes: 96 additions & 0 deletions paddle/fluid/distributed/service/graph_brpc_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,102 @@ std::future<int32_t> GraphBrpcClient::pull_graph_list(
closure);
return fut;
}

std::future<int32_t> GraphBrpcClient::set_node_feat(
const uint32_t &table_id, const std::vector<uint64_t> &node_ids,
const std::vector<std::string> &feature_names,
const std::vector<std::vector<std::string>> &features) {
std::vector<int> request2server;
std::vector<int> server2request(server_size, -1);
for (int query_idx = 0; query_idx < node_ids.size(); ++query_idx) {
int server_index = get_server_index_by_id(node_ids[query_idx]);
if (server2request[server_index] == -1) {
server2request[server_index] = request2server.size();
request2server.push_back(server_index);
}
}
size_t request_call_num = request2server.size();
std::vector<std::vector<uint64_t>> node_id_buckets(request_call_num);
std::vector<std::vector<int>> query_idx_buckets(request_call_num);
std::vector<std::vector<std::vector<std::string>>> features_idx_buckets(
request_call_num);
for (int query_idx = 0; query_idx < node_ids.size(); ++query_idx) {
int server_index = get_server_index_by_id(node_ids[query_idx]);
int request_idx = server2request[server_index];
node_id_buckets[request_idx].push_back(node_ids[query_idx]);
query_idx_buckets[request_idx].push_back(query_idx);
if (features_idx_buckets[request_idx].size() == 0) {
features_idx_buckets[request_idx].resize(feature_names.size());
}
for (int feat_idx = 0; feat_idx < feature_names.size(); ++feat_idx) {
features_idx_buckets[request_idx][feat_idx].push_back(
features[feat_idx][query_idx]);
}
}

DownpourBrpcClosure *closure = new DownpourBrpcClosure(
request_call_num,
[&, node_id_buckets, query_idx_buckets, request_call_num](void *done) {
int ret = 0;
auto *closure = (DownpourBrpcClosure *)done;
size_t fail_num = 0;
for (int request_idx = 0; request_idx < request_call_num;
++request_idx) {
if (closure->check_response(request_idx, PS_GRAPH_SET_NODE_FEAT) !=
0) {
++fail_num;
}
if (fail_num == request_call_num) {
ret = -1;
}
}
closure->set_promise_value(ret);
});

auto promise = std::make_shared<std::promise<int32_t>>();
closure->add_promise(promise);
std::future<int> fut = promise->get_future();

for (int request_idx = 0; request_idx < request_call_num; ++request_idx) {
int server_index = request2server[request_idx];
closure->request(request_idx)->set_cmd_id(PS_GRAPH_SET_NODE_FEAT);
closure->request(request_idx)->set_table_id(table_id);
closure->request(request_idx)->set_client_id(_client_id);
size_t node_num = node_id_buckets[request_idx].size();

closure->request(request_idx)
->add_params((char *)node_id_buckets[request_idx].data(),
sizeof(uint64_t) * node_num);
std::string joint_feature_name =
paddle::string::join_strings(feature_names, '\t');
closure->request(request_idx)
->add_params(joint_feature_name.c_str(), joint_feature_name.size());

// set features
std::string set_feature = "";
for (size_t feat_idx = 0; feat_idx < feature_names.size(); ++feat_idx) {
for (size_t node_idx = 0; node_idx < node_num; ++node_idx) {
size_t feat_len =
features_idx_buckets[request_idx][feat_idx][node_idx].size();
set_feature.append((char *)&feat_len, sizeof(size_t));
set_feature.append(
features_idx_buckets[request_idx][feat_idx][node_idx].data(),
feat_len);
}
}
closure->request(request_idx)
->add_params(set_feature.c_str(), set_feature.size());

GraphPsService_Stub rpc_stub =
getServiceStub(get_cmd_channel(server_index));
closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms());
rpc_stub.service(closure->cntl(request_idx), closure->request(request_idx),
closure->response(request_idx), closure);
}

return fut;
}

int32_t GraphBrpcClient::initialize() {
// set_shard_num(_config.shard_num());
BrpcPsClient::initialize();
Expand Down
5 changes: 5 additions & 0 deletions paddle/fluid/distributed/service/graph_brpc_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,11 @@ class GraphBrpcClient : public BrpcPsClient {
const std::vector<std::string>& feature_names,
std::vector<std::vector<std::string>>& res);

virtual std::future<int32_t> set_node_feat(
const uint32_t& table_id, const std::vector<uint64_t>& node_ids,
const std::vector<std::string>& feature_names,
const std::vector<std::vector<std::string>>& features);

virtual std::future<int32_t> clear_nodes(uint32_t table_id);
virtual std::future<int32_t> add_graph_node(
uint32_t table_id, std::vector<uint64_t>& node_id_list,
Expand Down
42 changes: 42 additions & 0 deletions paddle/fluid/distributed/service/graph_brpc_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "paddle/fluid/distributed/service/brpc_ps_server.h"

#include <thread> // NOLINT
#include <utility>
#include "butil/endpoint.h"
#include "iomanip"
#include "paddle/fluid/distributed/service/brpc_ps_client.h"
Expand Down Expand Up @@ -157,6 +158,8 @@ int32_t GraphBrpcService::initialize() {
&GraphBrpcService::add_graph_node;
_service_handler_map[PS_GRAPH_REMOVE_GRAPH_NODE] =
&GraphBrpcService::remove_graph_node;
_service_handler_map[PS_GRAPH_SET_NODE_FEAT] =
&GraphBrpcService::graph_set_node_feat;
// shard初始化,server启动后才可从env获取到server_list的shard信息
initialize_shard_info();

Expand Down Expand Up @@ -400,5 +403,44 @@ int32_t GraphBrpcService::graph_get_node_feat(Table *table,

return 0;
}

int32_t GraphBrpcService::graph_set_node_feat(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
CHECK_TABLE_EXIST(table, request, response)
if (request.params_size() < 3) {
set_response_code(
response, -1,
"graph_set_node_feat request requires at least 2 arguments");
return 0;
}
size_t node_num = request.params(0).size() / sizeof(uint64_t);
uint64_t *node_data = (uint64_t *)(request.params(0).c_str());
std::vector<uint64_t> node_ids(node_data, node_data + node_num);

std::vector<std::string> feature_names =
paddle::string::split_string<std::string>(request.params(1), "\t");

std::vector<std::vector<std::string>> features(
feature_names.size(), std::vector<std::string>(node_num));

const char *buffer = request.params(2).c_str();

for (size_t feat_idx = 0; feat_idx < feature_names.size(); ++feat_idx) {
for (size_t node_idx = 0; node_idx < node_num; ++node_idx) {
size_t feat_len = *(size_t *)(buffer);
buffer += sizeof(size_t);
auto feat = std::string(buffer, feat_len);
features[feat_idx][node_idx] = feat;
buffer += feat_len;
}
}

((GraphTable *)table)->set_node_feat(node_ids, feature_names, features);

return 0;
}

} // namespace distributed
} // namespace paddle
4 changes: 4 additions & 0 deletions paddle/fluid/distributed/service/graph_brpc_server.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,13 @@ class GraphBrpcService : public PsBaseService {
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl);

int32_t graph_get_node_feat(Table *table, const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl);
int32_t graph_set_node_feat(Table *table, const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl);
int32_t clear_nodes(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t add_graph_node(Table *table, const PsRequestMessage &request,
Expand Down
13 changes: 13 additions & 0 deletions paddle/fluid/distributed/service/graph_py_service.cc
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,19 @@ std::vector<std::vector<std::string>> GraphPyClient::get_node_feat(
return v;
}

void GraphPyClient::set_node_feat(
std::string node_type, std::vector<uint64_t> node_ids,
std::vector<std::string> feature_names,
const std::vector<std::vector<std::string>> features) {
if (this->table_id_map.count(node_type)) {
uint32_t table_id = this->table_id_map[node_type];
auto status =
worker_ptr->set_node_feat(table_id, node_ids, feature_names, features);
status.wait();
}
return;
}

std::vector<FeatureNode> GraphPyClient::pull_graph_list(std::string name,
int server_index,
int start, int size,
Expand Down
3 changes: 3 additions & 0 deletions paddle/fluid/distributed/service/graph_py_service.h
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,9 @@ class GraphPyClient : public GraphPyService {
std::vector<std::vector<std::string>> get_node_feat(
std::string node_type, std::vector<uint64_t> node_ids,
std::vector<std::string> feature_names);
void set_node_feat(std::string node_type, std::vector<uint64_t> node_ids,
std::vector<std::string> feature_names,
const std::vector<std::vector<std::string>> features);
std::vector<FeatureNode> pull_graph_list(std::string name, int server_index,
int start, int size, int step = 1);
::paddle::distributed::PSParameter GetWorkerProto();
Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/distributed/service/sendrecv.proto
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ enum PsCmdID {
PS_GRAPH_CLEAR = 34;
PS_GRAPH_ADD_GRAPH_NODE = 35;
PS_GRAPH_REMOVE_GRAPH_NODE = 36;
PS_GRAPH_SET_NODE_FEAT = 37;
}

message PsRequestMessage {
Expand Down
79 changes: 53 additions & 26 deletions paddle/fluid/distributed/table/common_graph_table.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,14 @@
#include "paddle/fluid/distributed/table/common_graph_table.h"
#include <time.h>
#include <algorithm>
#include <chrono>
#include <set>
#include <sstream>
#include "paddle/fluid/distributed/common/utils.h"
#include "paddle/fluid/distributed/table/graph/graph_node.h"
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/string/printf.h"
#include <chrono>
#include "paddle/fluid/string/string_helper.h"
#include "paddle/fluid/framework/generator.h"

namespace paddle {
namespace distributed {
Expand Down Expand Up @@ -406,31 +406,30 @@ int32_t GraphTable::random_sample_neighboors(
int thread_pool_index = get_thread_pool_index(node_id);
auto rng = _shards_task_rng_pool[thread_pool_index];

tasks.push_back(_shards_task_pool[thread_pool_index]->enqueue(
[&]() -> int {
Node *node = find_node(node_id);
tasks.push_back(_shards_task_pool[thread_pool_index]->enqueue([&]() -> int {
Node *node = find_node(node_id);

if (node == nullptr) {
actual_size = 0;
return 0;
}
std::vector<int> res = node->sample_k(sample_size, rng);
actual_size = res.size() * (Node::id_size + Node::weight_size);
int offset = 0;
uint64_t id;
float weight;
char *buffer_addr = new char[actual_size];
buffer.reset(buffer_addr);
for (int &x : res) {
id = node->get_neighbor_id(x);
weight = node->get_neighbor_weight(x);
memcpy(buffer_addr + offset, &id, Node::id_size);
offset += Node::id_size;
memcpy(buffer_addr + offset, &weight, Node::weight_size);
offset += Node::weight_size;
}
return 0;
}));
if (node == nullptr) {
actual_size = 0;
return 0;
}
std::vector<int> res = node->sample_k(sample_size, rng);
actual_size = res.size() * (Node::id_size + Node::weight_size);
int offset = 0;
uint64_t id;
float weight;
char *buffer_addr = new char[actual_size];
buffer.reset(buffer_addr);
for (int &x : res) {
id = node->get_neighbor_id(x);
weight = node->get_neighbor_weight(x);
memcpy(buffer_addr + offset, &id, Node::id_size);
offset += Node::id_size;
memcpy(buffer_addr + offset, &weight, Node::weight_size);
offset += Node::weight_size;
}
return 0;
}));
}
for (size_t idx = 0; idx < node_num; ++idx) {
tasks[idx].get();
Expand Down Expand Up @@ -470,6 +469,34 @@ int32_t GraphTable::get_node_feat(const std::vector<uint64_t> &node_ids,
return 0;
}

int32_t GraphTable::set_node_feat(
const std::vector<uint64_t> &node_ids,
const std::vector<std::string> &feature_names,
const std::vector<std::vector<std::string>> &res) {
size_t node_num = node_ids.size();
std::vector<std::future<int>> tasks;
for (size_t idx = 0; idx < node_num; ++idx) {
uint64_t node_id = node_ids[idx];
tasks.push_back(_shards_task_pool[get_thread_pool_index(node_id)]->enqueue(
[&, idx, node_id]() -> int {
size_t index = node_id % this->shard_num - this->shard_start;
auto node = shards[index].add_feature_node(node_id);
node->set_feature_size(this->feat_name.size());
for (int feat_idx = 0; feat_idx < feature_names.size(); ++feat_idx) {
const std::string &feature_name = feature_names[feat_idx];
if (feat_id_map.find(feature_name) != feat_id_map.end()) {
node->set_feature(feat_id_map[feature_name], res[feat_idx][idx]);
}
}
return 0;
}));
}
for (size_t idx = 0; idx < node_num; ++idx) {
tasks[idx].get();
}
return 0;
}

std::pair<int32_t, std::string> GraphTable::parse_feature(
std::string feat_str) {
// Return (feat_id, btyes) if name are in this->feat_name, else return (-1,
Expand Down
6 changes: 6 additions & 0 deletions paddle/fluid/distributed/table/common_graph_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class GraphShard {
}
return res;
}

GraphNode *add_graph_node(uint64_t id);
FeatureNode *add_feature_node(uint64_t id);
Node *find_node(uint64_t id);
Expand Down Expand Up @@ -122,6 +123,11 @@ class GraphTable : public SparseTable {
const std::vector<std::string> &feature_names,
std::vector<std::vector<std::string>> &res);

virtual int32_t set_node_feat(
const std::vector<uint64_t> &node_ids,
const std::vector<std::string> &feature_names,
const std::vector<std::vector<std::string>> &res);

protected:
std::vector<GraphShard> shards;
size_t shard_start, shard_end, server_num, shard_num_per_table, shard_num;
Expand Down
11 changes: 11 additions & 0 deletions paddle/fluid/distributed/test/graph_node_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -558,6 +558,17 @@ void RunBrpcPushSparse() {
VLOG(0) << "get_node_feat: " << node_feat[1][0];
VLOG(0) << "get_node_feat: " << node_feat[1][1];

node_feat[1][0] = "helloworld";

client1.set_node_feat(std::string("user"), node_ids, feature_names,
node_feat);

// sleep(5);
node_feat =
client1.get_node_feat(std::string("user"), node_ids, feature_names);
VLOG(0) << "get_node_feat: " << node_feat[1][0];
ASSERT_TRUE(node_feat[1][0] == "helloworld");

// Test string
node_ids.clear();
node_ids.push_back(37);
Expand Down
Loading

1 comment on commit 703bbb1

@paddle-bot-old
Copy link

@paddle-bot-old paddle-bot-old bot commented on 703bbb1 Aug 18, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🕵️ CI failures summary

🔍PR: #20 Commit ID: 703bbb1 contains failed CI.

Please sign in to comment.