Merge pull request #12 from WeiyueSu/FeatureNode
Feature node
seemingwang committed Mar 24, 2021
2 parents 4a32d64 + 959198a commit bb48ece
Showing 14 changed files with 265 additions and 30 deletions.
88 changes: 88 additions & 0 deletions paddle/fluid/distributed/service/graph_brpc_client.cc
@@ -52,6 +52,94 @@ int GraphBrpcClient::get_server_index_by_id(uint64_t id) {
: shard_num / server_size + 1;
return id % shard_num / shard_per_server;
}

std::future<int32_t> GraphBrpcClient::get_node_feat(
const uint32_t& table_id, const std::vector<uint64_t>& node_ids, const std::vector<std::string>& feature_names,
std::vector<std::vector<std::string> > &res) {
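  // Bucket the requested node ids by the server that owns their shard, so a
  // single RPC is issued per distinct server.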

std::vector<int> request2server;
std::vector<int> server2request(server_size, -1);
for (int query_idx = 0; query_idx < node_ids.size(); ++query_idx) {
int server_index = get_server_index_by_id(node_ids[query_idx]);
if (server2request[server_index] == -1) {
server2request[server_index] = request2server.size();
request2server.push_back(server_index);
}
}
size_t request_call_num = request2server.size();
std::vector<std::vector<uint64_t>> node_id_buckets(request_call_num);
std::vector<std::vector<int>> query_idx_buckets(request_call_num);
for (int query_idx = 0; query_idx < node_ids.size(); ++query_idx) {
int server_index = get_server_index_by_id(node_ids[query_idx]);
int request_idx = server2request[server_index];
node_id_buckets[request_idx].push_back(node_ids[query_idx]);
query_idx_buckets[request_idx].push_back(query_idx);
}

DownpourBrpcClosure *closure = new DownpourBrpcClosure(
request_call_num,
[&, node_id_buckets, query_idx_buckets, request_call_num](void *done) {
int ret = 0;
auto *closure = (DownpourBrpcClosure *)done;
int fail_num = 0;
for (int request_idx = 0; request_idx < request_call_num;
++request_idx) {
        if (closure->check_response(request_idx,
                                    PS_GRAPH_GET_NODE_FEAT) != 0) {
++fail_num;
} else {
auto &res_io_buffer =
closure->cntl(request_idx)->response_attachment();
butil::IOBufBytesIterator io_buffer_itr(res_io_buffer);
size_t bytes_size = io_buffer_itr.bytes_left();
std::unique_ptr<char[]> buffer_wrapper(new char[bytes_size]);
char *buffer = buffer_wrapper.get();
io_buffer_itr.copy_and_forward((void *)(buffer), bytes_size);
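            // Decode in the same feature-major order the server wrote the
            // attachment: for each feature, then each node of this bucket,
            // [size_t feat_len][feat_len raw bytes].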

for (size_t feat_idx = 0; feat_idx < feature_names.size(); ++feat_idx) {
for (size_t node_idx = 0; node_idx < query_idx_buckets.at(request_idx).size(); ++node_idx) {
int query_idx = query_idx_buckets.at(request_idx).at(node_idx);
size_t feat_len = *(size_t *)(buffer);
buffer += sizeof(size_t);
auto feature = std::string(buffer, feat_len);
res[feat_idx][query_idx] = feature;
buffer += feat_len;
}
}
          }
        }
        if (fail_num == request_call_num) {
          ret = -1;
        }
        closure->set_promise_value(ret);
});

auto promise = std::make_shared<std::promise<int32_t>>();
closure->add_promise(promise);
std::future<int> fut = promise->get_future();

for (int request_idx = 0; request_idx < request_call_num; ++request_idx) {
int server_index = request2server[request_idx];
closure->request(request_idx)->set_cmd_id(PS_GRAPH_GET_NODE_FEAT);
closure->request(request_idx)->set_table_id(table_id);
closure->request(request_idx)->set_client_id(_client_id);
size_t node_num = node_id_buckets[request_idx].size();
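    // Request layout: params(0) carries the packed uint64 node ids for this
    // server, params(1) the feature names joined by '\t'.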

closure->request(request_idx)
->add_params((char *)node_id_buckets[request_idx].data(),
sizeof(uint64_t) * node_num);
std::string joint_feature_name = paddle::string::join_strings(feature_names, '\t');
closure->request(request_idx)
->add_params(joint_feature_name.c_str(), joint_feature_name.size());

PsService_Stub rpc_stub(get_cmd_channel(server_index));
closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms());
rpc_stub.service(closure->cntl(request_idx), closure->request(request_idx),
closure->response(request_idx), closure);
}

return fut;
}
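
For orientation, a minimal caller-side sketch of the new client API (not part of this change; client stands for an already-initialized GraphBrpcClient, and the table id, node ids, and feature names are purely illustrative):

// Sketch: fetch features "a" and "b" for two nodes from table 0 and block
// until the scattered RPCs complete.
std::vector<uint64_t> ids = {10001, 10002};
std::vector<std::string> names = {"a", "b"};
std::vector<std::vector<std::string>> feats(
    names.size(), std::vector<std::string>(ids.size()));
auto status = client->get_node_feat(0, ids, names, feats);
status.wait();
// feats[f][i] now holds the raw bytes of feature names[f] for node ids[i].
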
// char* &buffer,int &actual_size
std::future<int32_t> GraphBrpcClient::batch_sample_neighboors(
uint32_t table_id, std::vector<uint64_t> node_ids, int sample_size,
5 changes: 5 additions & 0 deletions paddle/fluid/distributed/service/graph_brpc_client.h
@@ -71,6 +71,11 @@ class GraphBrpcClient : public BrpcPsClient {
int server_index,
int sample_size,
std::vector<uint64_t>& ids);
virtual std::future<int32_t> get_node_feat(
const uint32_t& table_id,
const std::vector<uint64_t>& node_ids,
const std::vector<std::string>& feature_names,
std::vector<std::vector<std::string> > &res);
virtual int32_t initialize();
int get_shard_num() { return shard_num; }
void set_shard_num(int shard_num) { this->shard_num = shard_num; }
37 changes: 37 additions & 0 deletions paddle/fluid/distributed/service/graph_brpc_server.cc
@@ -100,6 +100,8 @@ int32_t GraphBrpcService::initialize() {
&GraphBrpcService::graph_random_sample_neighboors;
_service_handler_map[PS_GRAPH_SAMPLE_NODES] =
&GraphBrpcService::graph_random_sample_nodes;
_service_handler_map[PS_GRAPH_GET_NODE_FEAT] =
&GraphBrpcService::graph_get_node_feat;

// Shard initialization: shard info of server_list can only be obtained from env after the server has started.
initialize_shard_info();
@@ -314,5 +316,40 @@ int32_t GraphBrpcService::graph_random_sample_nodes(

return 0;
}


int32_t GraphBrpcService::graph_get_node_feat(
Table *table, const PsRequestMessage &request, PsResponseMessage &response,
brpc::Controller *cntl) {
CHECK_TABLE_EXIST(table, request, response)
if (request.params_size() < 2) {
set_response_code(
response, -1,
"graph_get_node_feat request requires at least 2 arguments");
return 0;
}
size_t node_num = request.params(0).size() / sizeof(uint64_t);
uint64_t *node_data = (uint64_t *)(request.params(0).c_str());
std::vector<uint64_t> node_ids(node_data, node_data+node_num);

std::vector<std::string> feature_names =
paddle::string::split_string<std::string>(request.params(1), "\t");


  std::vector<std::vector<std::string>> feature(
      feature_names.size(), std::vector<std::string>(node_num));

table->get_node_feat(node_ids, feature_names, feature);

for (size_t feat_idx = 0; feat_idx < feature_names.size(); ++feat_idx) {
for (size_t node_idx = 0; node_idx < node_num; ++node_idx) {
size_t feat_len = feature[feat_idx][node_idx].size();
cntl->response_attachment().append(&feat_len, sizeof(size_t));
cntl->response_attachment().append(feature[feat_idx][node_idx].data(), feat_len);
}
}

return 0;
}
} // namespace distributed
} // namespace paddle
3 changes: 3 additions & 0 deletions paddle/fluid/distributed/service/graph_brpc_server.h
@@ -87,6 +87,9 @@ class GraphBrpcService : public PsBaseService {
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl);
int32_t graph_get_node_feat(
Table *table, const PsRequestMessage &request, PsResponseMessage &response,
brpc::Controller *cntl);
int32_t barrier(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t load_one_table(Table *table, const PsRequestMessage &request,
18 changes: 18 additions & 0 deletions paddle/fluid/distributed/service/graph_py_service.cc
@@ -289,6 +289,24 @@ std::vector<uint64_t> GraphPyClient::random_sample_nodes(std::string name,
}
return v;
}

// Returns res[feature_index][node_index]: raw feature bytes per (node, feature) pair.
std::vector<std::vector<std::string>> GraphPyClient::get_node_feat(
    std::string node_type, std::vector<uint64_t> node_ids,
    std::vector<std::string> feature_names) {

std::vector<std::vector<std::string> > v(feature_names.size(),
std::vector<std::string>(node_ids.size()));
if (this->table_id_map.count(node_type)) {
uint32_t table_id = this->table_id_map[node_type];
auto status =
worker_ptr->get_node_feat(table_id, node_ids, feature_names, v);
status.wait();
}
return v;
}

std::vector<FeatureNode> GraphPyClient::pull_graph_list(std::string name,
int server_index,
int start, int size,
3 changes: 3 additions & 0 deletions paddle/fluid/distributed/service/graph_py_service.h
@@ -163,6 +163,9 @@ class GraphPyClient : public GraphPyService {
std::string name, std::vector<uint64_t> node_ids, int sample_size);
std::vector<uint64_t> random_sample_nodes(std::string name, int server_index,
int sample_size);
std::vector<std::vector<std::string> >
get_node_feat(std::string node_type, std::vector<uint64_t> node_ids,
std::vector<std::string> feature_names);
std::vector<FeatureNode> pull_graph_list(std::string name, int server_index,
int start, int size, int step = 1);
::paddle::distributed::PSParameter GetWorkerProto();
11 changes: 11 additions & 0 deletions paddle/fluid/distributed/service/ps_client.h
@@ -185,6 +185,17 @@ class PSClient {
promise.set_value(-1);
return fut;
}
virtual std::future<int32_t> get_node_feat(
const uint32_t& table_id,
const std::vector<uint64_t>& node_ids,
const std::vector<std::string>& feature_names,
std::vector<std::vector<std::string> > &res){
LOG(FATAL) << "Did not implement";
std::promise<int32_t> promise;
std::future<int> fut = promise.get_future();
promise.set_value(-1);
return fut;
}
// client-to-client message handling: std::function<int32_t(int, int, const std::string&)
// -> ret (msg_type, from_client_id, msg)
typedef std::function<int32_t(int, int, const std::string &)> MsgHandlerFunc;
3 changes: 2 additions & 1 deletion paddle/fluid/distributed/service/sendrecv.proto
@@ -51,6 +51,7 @@ enum PsCmdID {
PS_PULL_GRAPH_LIST = 30;
PS_GRAPH_SAMPLE_NEIGHBOORS = 31;
PS_GRAPH_SAMPLE_NODES = 32;
PS_GRAPH_GET_NODE_FEAT = 33;
}

message PsRequestMessage {
@@ -114,4 +115,4 @@ message MultiVariableMessage {
service PsService {
rpc service(PsRequestMessage) returns (PsResponseMessage);
rpc SendAndRecvVariable(MultiVariableMessage) returns (MultiVariableMessage);
};
};
47 changes: 40 additions & 7 deletions paddle/fluid/distributed/table/common_graph_table.cc
@@ -142,15 +142,17 @@ int32_t GraphTable::load_nodes(const std::string &path, std::string node_type) {

auto node = shards[index].add_feature_node(id);

auto mutable_feature = node->get_mutable_feature();
//auto mutable_feature = node->get_mutable_feature();

mutable_feature.clear();
mutable_feature.resize(this->feat_name.size());
//mutable_feature.clear();
//mutable_feature.resize(this->feat_name.size());
node->set_feature_size(feat_name.size());

for (size_t slice = 2; slice < values.size(); slice++) {
auto feat = this->parse_feature(values[slice]);
if (feat.first >= 0) {
mutable_feature[feat.first] = feat.second;
//mutable_feature[feat.first] = feat.second;
node->set_feature(feat.first, feat.second);
}
}
}
@@ -212,7 +214,7 @@ int32_t GraphTable::load_edges(const std::string &path, bool reverse_edge) {
Node *GraphTable::find_node(uint64_t id) {
size_t shard_id = id % shard_num;
if (shard_id >= shard_end || shard_id < shard_start) {
return NULL;
return nullptr;
}
size_t index = shard_id - shard_start;
Node *node = shards[index].find_node(id);
@@ -287,7 +289,7 @@ int32_t GraphTable::random_sample_nodes(int sample_size,
memcpy(pointer, res.data(), actual_size);
return 0;
}
int GraphTable::random_sample_neighboors(
int32_t GraphTable::random_sample_neighboors(
uint64_t *node_ids, int sample_size,
std::vector<std::unique_ptr<char[]>> &buffers,
std::vector<int> &actual_sizes) {
@@ -301,7 +303,7 @@ int GraphTable::random_sample_neighboors(
[&]() -> int {
Node *node = find_node(node_id);

if (node == NULL) {
if (node == nullptr) {
actual_size = 0;
return 0;
}
@@ -330,6 +332,37 @@ int GraphTable::random_sample_neighboors(
return 0;
}

int32_t GraphTable::get_node_feat(
const std::vector<uint64_t>& node_ids,
const std::vector<std::string>& feature_names,
std::vector<std::vector<std::string> > &res){
size_t node_num = node_ids.size();
std::vector<std::future<int>> tasks;
for (size_t idx = 0; idx < node_num; ++idx) {
uint64_t node_id = node_ids[idx];
tasks.push_back(_shards_task_pool[get_thread_pool_index(node_id)]->enqueue(
[&, idx, node_id]() -> int {
Node *node = find_node(node_id);
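          // A node that is missing (not found, or not owned by this shard
          // range) simply leaves its entries in res as empty strings.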

if (node == nullptr) {
return 0;
}
for (size_t feat_idx = 0; feat_idx < feature_names.size(); ++feat_idx) {
const std::string &feature_name = feature_names[feat_idx];
if (feat_id_map.find(feature_name) != feat_id_map.end()){
//res[feat_idx][idx] = node->get_feature(feat_id_map[feature_name]);
auto feat = node->get_feature(feat_id_map[feature_name]);
res[feat_idx][idx] = feat;
}
}
return 0;
}));
}
for (size_t idx = 0; idx < node_num; ++idx) {
tasks[idx].get();
}
return 0;
}

std::pair<int32_t, std::string> GraphTable::parse_feature(std::string feat_str) {
// Return (feat_id, bytes) if the name is in this->feat_name, else return (-1, "")
5 changes: 5 additions & 0 deletions paddle/fluid/distributed/table/common_graph_table.h
@@ -121,6 +121,11 @@ class GraphTable : public SparseTable {
virtual uint32_t get_thread_pool_index(uint64_t node_id);
virtual std::pair<int32_t, std::string> parse_feature(std::string feat_str);

virtual int32_t get_node_feat(
const std::vector<uint64_t>& node_ids,
const std::vector<std::string>& feature_names,
std::vector<std::vector<std::string> > &res);

protected:
std::vector<GraphShard> shards;
size_t shard_start, shard_end, server_num, shard_num_per_table, shard_num;
23 changes: 20 additions & 3 deletions paddle/fluid/distributed/table/graph_node.h
@@ -41,8 +41,10 @@ class Node {
virtual int get_size(bool need_feature);
virtual void to_buffer(char *buffer, bool need_feature);
virtual void recover_from_buffer(char *buffer);
virtual void add_feature(std::string feature) { }
virtual std::string get_feature(int idx) { return std::string(""); }
virtual void set_feature(int idx, std::string str) {}
virtual void set_feature_size(int size) {}
virtual int get_feature_size() {return 0;}

protected:
uint64_t id;
@@ -77,8 +79,23 @@ class FeatureNode: public Node{
virtual int get_size(bool need_feature);
virtual void to_buffer(char *buffer, bool need_feature);
virtual void recover_from_buffer(char *buffer);
virtual std::string get_feature(int idx) { return this->feature[idx]; }
virtual std::vector<std::string> & get_mutable_feature() { return this->feature; }
  virtual std::string get_feature(int idx) {
    if (idx < (int)this->feature.size()) {
      return this->feature[idx];
    } else {
      return std::string("");
    }
  }

virtual void set_feature(int idx, std::string str) {
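    // Grow the backing vector on demand so set_feature can be called with any
    // index.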
if (idx >= (int)this->feature.size()){
this->feature.resize(idx+1);
}
this->feature[idx] = str;
}
virtual void set_feature_size(int size) {this->feature.resize(size);}
virtual int get_feature_size() {return this->feature.size();}

template <typename T>
static std::string parse_value_to_bytes(std::vector<std::string> feat_str) {
6 changes: 6 additions & 0 deletions paddle/fluid/distributed/table/table.h
@@ -107,6 +107,12 @@ class Table {
int &actual_sizes) {
return 0;
}
virtual int32_t get_node_feat(
const std::vector<uint64_t>& node_ids,
const std::vector<std::string>& feature_names,
std::vector<std::vector<std::string> > &res) {
return 0;
}
virtual int32_t pour() { return 0; }

virtual void clear() = 0;