Skip to content

Commit

Permalink
resolve conflicts
Browse files Browse the repository at this point in the history
  • Loading branch information
seemingwang committed Mar 26, 2021
2 parents f95b4d2 + 74d2167 commit af03ec9
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 8 deletions.
34 changes: 27 additions & 7 deletions paddle/fluid/distributed/table/common_graph_table.cc
Original file line number Diff line number Diff line change
Expand Up @@ -142,17 +142,15 @@ int32_t GraphTable::load_nodes(const std::string &path, std::string node_type) {

auto node = shards[index].add_feature_node(id);

// auto mutable_feature = node->get_mutable_feature();

// mutable_feature.clear();
// mutable_feature.resize(this->feat_name.size());
node->set_feature_size(feat_name.size());

for (size_t slice = 2; slice < values.size(); slice++) {
auto feat = this->parse_feature(values[slice]);
if (feat.first > 0) {
// mutable_feature[feat.first] = feat.second;
if (feat.first >= 0) {
node->set_feature(feat.first, feat.second);
} else {
VLOG(4) << "Node feature: " << values[slice]
<< " not in feature_map.";
}
}
}
Expand Down Expand Up @@ -363,6 +361,27 @@ int32_t GraphTable::get_node_feat(const std::vector<uint64_t> &node_ids,
return 0;
}

if (node == nullptr) {
return 0;
}
for (int feat_idx = 0; feat_idx < feature_names.size(); ++feat_idx) {
const std::string &feature_name = feature_names[feat_idx];
if (feat_id_map.find(feature_name) != feat_id_map.end()) {
// res[feat_idx][idx] =
// node->get_feature(feat_id_map[feature_name]);
auto feat = node->get_feature(feat_id_map[feature_name]);
res[feat_idx][idx] = feat;
}
}
return 0;
}));
}
for (size_t idx = 0; idx < node_num; ++idx) {
tasks[idx].get();
}
return 0;
}

std::pair<int32_t, std::string> GraphTable::parse_feature(
std::string feat_str) {
// Return (feat_id, btyes) if name are in this->feat_name, else return (-1,
Expand Down Expand Up @@ -488,4 +507,5 @@ int32_t GraphTable::initialize() {
return 0;
}
}
};
}
;
17 changes: 17 additions & 0 deletions paddle/fluid/distributed/test/graph_node_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -526,6 +526,23 @@ void RunBrpcPushSparse() {
std::cout << "get_node_feat: " << node_feat[1][0] << std::endl;
std::cout << "get_node_feat: " << node_feat[1][1] << std::endl;

// Test string
node_ids.clear();
node_ids.push_back(37);
node_ids.push_back(96);
// std::vector<std::string> feature_names;
feature_names.clear();
feature_names.push_back(std::string("a"));
feature_names.push_back(std::string("b"));
node_feat =
client1.get_node_feat(std::string("user"), node_ids, feature_names);
ASSERT_EQ(node_feat.size(), 2);
ASSERT_EQ(node_feat[0].size(), 2);
std::cout << "get_node_feat: " << node_feat[0][0].size() << std::endl;
std::cout << "get_node_feat: " << node_feat[0][1].size() << std::endl;
std::cout << "get_node_feat: " << node_feat[1][0].size() << std::endl;
std::cout << "get_node_feat: " << node_feat[1][1].size() << std::endl;

std::remove(edge_file_name);
std::remove(node_file_name);
LOG(INFO) << "Run stop_server";
Expand Down
15 changes: 14 additions & 1 deletion paddle/fluid/pybind/fleet_py.cc
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,20 @@ void BindGraphPyClient(py::module* m) {
.def("start_client", &GraphPyClient::start_client)
.def("batch_sample_neighboors", &GraphPyClient::batch_sample_neighboors)
.def("random_sample_nodes", &GraphPyClient::random_sample_nodes)
.def("get_node_feat", &GraphPyClient::get_node_feat)
.def("get_node_feat",
[](GraphPyClient& self, std::string node_type,
std::vector<uint64_t> node_ids,
std::vector<std::string> feature_names) {
auto feats =
self.get_node_feat(node_type, node_ids, feature_names);
std::vector<std::vector<py::bytes>> bytes_feats(feats.size());
for (int i = 0; i < feats.size(); ++i) {
for (int j = 0; j < feats[i].size(); ++j) {
bytes_feats[i].push_back(py::bytes(feats[i][j]));
}
}
return bytes_feats;
})
.def("bind_local_server", &GraphPyClient::bind_local_server);
}

Expand Down

0 comments on commit af03ec9

Please sign in to comment.