Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add "remove feature_node pyblind"; add set_features #20

Merged
merged 5 commits into from
Aug 18, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 96 additions & 0 deletions paddle/fluid/distributed/service/graph_brpc_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,102 @@ std::future<int32_t> GraphBrpcClient::pull_graph_list(
closure);
return fut;
}

std::future<int32_t> GraphBrpcClient::set_node_feat(
const uint32_t &table_id, const std::vector<uint64_t> &node_ids,
const std::vector<std::string> &feature_names,
const std::vector<std::vector<std::string>> &features) {
std::vector<int> request2server;
std::vector<int> server2request(server_size, -1);
for (int query_idx = 0; query_idx < node_ids.size(); ++query_idx) {
int server_index = get_server_index_by_id(node_ids[query_idx]);
if (server2request[server_index] == -1) {
server2request[server_index] = request2server.size();
request2server.push_back(server_index);
}
}
size_t request_call_num = request2server.size();
std::vector<std::vector<uint64_t>> node_id_buckets(request_call_num);
std::vector<std::vector<int>> query_idx_buckets(request_call_num);
std::vector<std::vector<std::vector<std::string>>> features_idx_buckets(
request_call_num);
for (int query_idx = 0; query_idx < node_ids.size(); ++query_idx) {
int server_index = get_server_index_by_id(node_ids[query_idx]);
int request_idx = server2request[server_index];
node_id_buckets[request_idx].push_back(node_ids[query_idx]);
query_idx_buckets[request_idx].push_back(query_idx);
if (features_idx_buckets[request_idx].size() == 0) {
features_idx_buckets[request_idx].resize(feature_names.size());
}
for (int feat_idx = 0; feat_idx < feature_names.size(); ++feat_idx) {
features_idx_buckets[request_idx][feat_idx].push_back(
features[feat_idx][query_idx]);
}
}

DownpourBrpcClosure *closure = new DownpourBrpcClosure(
request_call_num,
[&, node_id_buckets, query_idx_buckets, request_call_num](void *done) {
int ret = 0;
auto *closure = (DownpourBrpcClosure *)done;
size_t fail_num = 0;
for (int request_idx = 0; request_idx < request_call_num;
++request_idx) {
if (closure->check_response(request_idx, PS_GRAPH_SET_NODE_FEAT) !=
0) {
++fail_num;
}
if (fail_num == request_call_num) {
ret = -1;
}
}
closure->set_promise_value(ret);
});

auto promise = std::make_shared<std::promise<int32_t>>();
closure->add_promise(promise);
std::future<int> fut = promise->get_future();

for (int request_idx = 0; request_idx < request_call_num; ++request_idx) {
int server_index = request2server[request_idx];
closure->request(request_idx)->set_cmd_id(PS_GRAPH_SET_NODE_FEAT);
closure->request(request_idx)->set_table_id(table_id);
closure->request(request_idx)->set_client_id(_client_id);
size_t node_num = node_id_buckets[request_idx].size();

closure->request(request_idx)
->add_params((char *)node_id_buckets[request_idx].data(),
sizeof(uint64_t) * node_num);
std::string joint_feature_name =
paddle::string::join_strings(feature_names, '\t');
closure->request(request_idx)
->add_params(joint_feature_name.c_str(), joint_feature_name.size());

// set features
std::string set_feature = "";
for (size_t feat_idx = 0; feat_idx < feature_names.size(); ++feat_idx) {
for (size_t node_idx = 0; node_idx < node_num; ++node_idx) {
size_t feat_len =
features_idx_buckets[request_idx][feat_idx][node_idx].size();
set_feature.append((char *)&feat_len, sizeof(size_t));
set_feature.append(
features_idx_buckets[request_idx][feat_idx][node_idx].data(),
feat_len);
}
}
closure->request(request_idx)
->add_params(set_feature.c_str(), set_feature.size());

GraphPsService_Stub rpc_stub =
getServiceStub(get_cmd_channel(server_index));
closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms());
rpc_stub.service(closure->cntl(request_idx), closure->request(request_idx),
closure->response(request_idx), closure);
}

return fut;
}

int32_t GraphBrpcClient::initialize() {
// set_shard_num(_config.shard_num());
BrpcPsClient::initialize();
Expand Down
5 changes: 5 additions & 0 deletions paddle/fluid/distributed/service/graph_brpc_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,11 @@ class GraphBrpcClient : public BrpcPsClient {
const std::vector<std::string>& feature_names,
std::vector<std::vector<std::string>>& res);

virtual std::future<int32_t> set_node_feat(
const uint32_t& table_id, const std::vector<uint64_t>& node_ids,
const std::vector<std::string>& feature_names,
const std::vector<std::vector<std::string>>& features);

virtual std::future<int32_t> clear_nodes(uint32_t table_id);
virtual std::future<int32_t> add_graph_node(
uint32_t table_id, std::vector<uint64_t>& node_id_list,
Expand Down
42 changes: 42 additions & 0 deletions paddle/fluid/distributed/service/graph_brpc_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "paddle/fluid/distributed/service/brpc_ps_server.h"

#include <thread> // NOLINT
#include <utility>
#include "butil/endpoint.h"
#include "iomanip"
#include "paddle/fluid/distributed/service/brpc_ps_client.h"
Expand Down Expand Up @@ -157,6 +158,8 @@ int32_t GraphBrpcService::initialize() {
&GraphBrpcService::add_graph_node;
_service_handler_map[PS_GRAPH_REMOVE_GRAPH_NODE] =
&GraphBrpcService::remove_graph_node;
_service_handler_map[PS_GRAPH_SET_NODE_FEAT] =
&GraphBrpcService::graph_set_node_feat;
// shard初始化,server启动后才可从env获取到server_list的shard信息
initialize_shard_info();

Expand Down Expand Up @@ -400,5 +403,44 @@ int32_t GraphBrpcService::graph_get_node_feat(Table *table,

return 0;
}

int32_t GraphBrpcService::graph_set_node_feat(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
CHECK_TABLE_EXIST(table, request, response)
if (request.params_size() < 3) {
set_response_code(
response, -1,
"graph_set_node_feat request requires at least 2 arguments");
return 0;
}
size_t node_num = request.params(0).size() / sizeof(uint64_t);
uint64_t *node_data = (uint64_t *)(request.params(0).c_str());
std::vector<uint64_t> node_ids(node_data, node_data + node_num);

std::vector<std::string> feature_names =
paddle::string::split_string<std::string>(request.params(1), "\t");

std::vector<std::vector<std::string>> features(
feature_names.size(), std::vector<std::string>(node_num));

const char *buffer = request.params(2).c_str();

for (size_t feat_idx = 0; feat_idx < feature_names.size(); ++feat_idx) {
for (size_t node_idx = 0; node_idx < node_num; ++node_idx) {
size_t feat_len = *(size_t *)(buffer);
buffer += sizeof(size_t);
auto feat = std::string(buffer, feat_len);
features[feat_idx][node_idx] = feat;
buffer += feat_len;
}
}

((GraphTable *)table)->set_node_feat(node_ids, feature_names, features);

return 0;
}

} // namespace distributed
} // namespace paddle
4 changes: 4 additions & 0 deletions paddle/fluid/distributed/service/graph_brpc_server.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,13 @@ class GraphBrpcService : public PsBaseService {
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl);

int32_t graph_get_node_feat(Table *table, const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl);
int32_t graph_set_node_feat(Table *table, const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl);
int32_t clear_nodes(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t add_graph_node(Table *table, const PsRequestMessage &request,
Expand Down
13 changes: 13 additions & 0 deletions paddle/fluid/distributed/service/graph_py_service.cc
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,19 @@ std::vector<std::vector<std::string>> GraphPyClient::get_node_feat(
return v;
}

void GraphPyClient::set_node_feat(
std::string node_type, std::vector<uint64_t> node_ids,
std::vector<std::string> feature_names,
const std::vector<std::vector<std::string>> features) {
if (this->table_id_map.count(node_type)) {
uint32_t table_id = this->table_id_map[node_type];
auto status =
worker_ptr->set_node_feat(table_id, node_ids, feature_names, features);
status.wait();
}
return;
}

std::vector<FeatureNode> GraphPyClient::pull_graph_list(std::string name,
int server_index,
int start, int size,
Expand Down
3 changes: 3 additions & 0 deletions paddle/fluid/distributed/service/graph_py_service.h
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,9 @@ class GraphPyClient : public GraphPyService {
std::vector<std::vector<std::string>> get_node_feat(
std::string node_type, std::vector<uint64_t> node_ids,
std::vector<std::string> feature_names);
void set_node_feat(std::string node_type, std::vector<uint64_t> node_ids,
std::vector<std::string> feature_names,
const std::vector<std::vector<std::string>> features);
std::vector<FeatureNode> pull_graph_list(std::string name, int server_index,
int start, int size, int step = 1);
::paddle::distributed::PSParameter GetWorkerProto();
Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/distributed/service/sendrecv.proto
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ enum PsCmdID {
PS_GRAPH_CLEAR = 34;
PS_GRAPH_ADD_GRAPH_NODE = 35;
PS_GRAPH_REMOVE_GRAPH_NODE = 36;
PS_GRAPH_SET_NODE_FEAT = 37;
}

message PsRequestMessage {
Expand Down
79 changes: 53 additions & 26 deletions paddle/fluid/distributed/table/common_graph_table.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,14 @@
#include "paddle/fluid/distributed/table/common_graph_table.h"
#include <time.h>
#include <algorithm>
#include <chrono>
#include <set>
#include <sstream>
#include "paddle/fluid/distributed/common/utils.h"
#include "paddle/fluid/distributed/table/graph/graph_node.h"
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/string/printf.h"
#include <chrono>
#include "paddle/fluid/string/string_helper.h"
#include "paddle/fluid/framework/generator.h"

namespace paddle {
namespace distributed {
Expand Down Expand Up @@ -406,31 +406,30 @@ int32_t GraphTable::random_sample_neighboors(
int thread_pool_index = get_thread_pool_index(node_id);
auto rng = _shards_task_rng_pool[thread_pool_index];

tasks.push_back(_shards_task_pool[thread_pool_index]->enqueue(
[&]() -> int {
Node *node = find_node(node_id);
tasks.push_back(_shards_task_pool[thread_pool_index]->enqueue([&]() -> int {
Node *node = find_node(node_id);

if (node == nullptr) {
actual_size = 0;
return 0;
}
std::vector<int> res = node->sample_k(sample_size, rng);
actual_size = res.size() * (Node::id_size + Node::weight_size);
int offset = 0;
uint64_t id;
float weight;
char *buffer_addr = new char[actual_size];
buffer.reset(buffer_addr);
for (int &x : res) {
id = node->get_neighbor_id(x);
weight = node->get_neighbor_weight(x);
memcpy(buffer_addr + offset, &id, Node::id_size);
offset += Node::id_size;
memcpy(buffer_addr + offset, &weight, Node::weight_size);
offset += Node::weight_size;
}
return 0;
}));
if (node == nullptr) {
actual_size = 0;
return 0;
}
std::vector<int> res = node->sample_k(sample_size, rng);
actual_size = res.size() * (Node::id_size + Node::weight_size);
int offset = 0;
uint64_t id;
float weight;
char *buffer_addr = new char[actual_size];
buffer.reset(buffer_addr);
for (int &x : res) {
id = node->get_neighbor_id(x);
weight = node->get_neighbor_weight(x);
memcpy(buffer_addr + offset, &id, Node::id_size);
offset += Node::id_size;
memcpy(buffer_addr + offset, &weight, Node::weight_size);
offset += Node::weight_size;
}
return 0;
}));
}
for (size_t idx = 0; idx < node_num; ++idx) {
tasks[idx].get();
Expand Down Expand Up @@ -470,6 +469,34 @@ int32_t GraphTable::get_node_feat(const std::vector<uint64_t> &node_ids,
return 0;
}

int32_t GraphTable::set_node_feat(
const std::vector<uint64_t> &node_ids,
const std::vector<std::string> &feature_names,
const std::vector<std::vector<std::string>> &res) {
size_t node_num = node_ids.size();
std::vector<std::future<int>> tasks;
for (size_t idx = 0; idx < node_num; ++idx) {
uint64_t node_id = node_ids[idx];
tasks.push_back(_shards_task_pool[get_thread_pool_index(node_id)]->enqueue(
[&, idx, node_id]() -> int {
size_t index = node_id % this->shard_num - this->shard_start;
auto node = shards[index].add_feature_node(node_id);
node->set_feature_size(this->feat_name.size());
for (int feat_idx = 0; feat_idx < feature_names.size(); ++feat_idx) {
const std::string &feature_name = feature_names[feat_idx];
if (feat_id_map.find(feature_name) != feat_id_map.end()) {
node->set_feature(feat_id_map[feature_name], res[feat_idx][idx]);
}
}
return 0;
}));
}
for (size_t idx = 0; idx < node_num; ++idx) {
tasks[idx].get();
}
return 0;
}

std::pair<int32_t, std::string> GraphTable::parse_feature(
std::string feat_str) {
// Return (feat_id, btyes) if name are in this->feat_name, else return (-1,
Expand Down
6 changes: 6 additions & 0 deletions paddle/fluid/distributed/table/common_graph_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class GraphShard {
}
return res;
}

GraphNode *add_graph_node(uint64_t id);
FeatureNode *add_feature_node(uint64_t id);
Node *find_node(uint64_t id);
Expand Down Expand Up @@ -122,6 +123,11 @@ class GraphTable : public SparseTable {
const std::vector<std::string> &feature_names,
std::vector<std::vector<std::string>> &res);

virtual int32_t set_node_feat(
const std::vector<uint64_t> &node_ids,
const std::vector<std::string> &feature_names,
const std::vector<std::vector<std::string>> &res);

protected:
std::vector<GraphShard> shards;
size_t shard_start, shard_end, server_num, shard_num_per_table, shard_num;
Expand Down
11 changes: 11 additions & 0 deletions paddle/fluid/distributed/test/graph_node_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -558,6 +558,17 @@ void RunBrpcPushSparse() {
VLOG(0) << "get_node_feat: " << node_feat[1][0];
VLOG(0) << "get_node_feat: " << node_feat[1][1];

node_feat[1][0] = "helloworld";

client1.set_node_feat(std::string("user"), node_ids, feature_names,
node_feat);

// sleep(5);
node_feat =
client1.get_node_feat(std::string("user"), node_ids, feature_names);
VLOG(0) << "get_node_feat: " << node_feat[1][0];
ASSERT_TRUE(node_feat[1][0] == "helloworld");

// Test string
node_ids.clear();
node_ids.push_back(37);
Expand Down
Loading