Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/graph_engine1.0' into develop
Browse files Browse the repository at this point in the history
  • Loading branch information
seemingwang committed Mar 28, 2021
2 parents bdd9404 + 2cb2eaf commit 88cd27a
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 2 deletions.
1 change: 1 addition & 0 deletions paddle/fluid/distributed/service/graph_brpc_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include <memory>
#include <string>
#include <vector>
#include <ThreadPool.h>

#include <utility>
#include "ThreadPool.h"
Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/distributed/service/graph_py_service.cc
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,7 @@ void GraphPyClient::load_edge_file(std::string name, std::string filepath,
params += ">";
}
if (this->table_id_map.count(name)) {
VLOG(0) << "loadding data with type " << name << " from " << filepath;
uint32_t table_id = this->table_id_map[name];
auto status =
get_ps_client()->load(table_id, std::string(filepath), params);
Expand Down
18 changes: 17 additions & 1 deletion paddle/fluid/distributed/table/common_graph_table.cc
Original file line number Diff line number Diff line change
Expand Up @@ -118,10 +118,13 @@ int32_t GraphTable::get_nodes_ids_by_ranges(

int32_t GraphTable::load_nodes(const std::string &path, std::string node_type) {
auto paths = paddle::string::split_string<std::string>(path, ";");
int64_t count = 0;
int64_t valid_count = 0;
for (auto path : paths) {
std::ifstream file(path);
std::string line;
while (std::getline(file, line)) {
count ++;
auto values = paddle::string::split_string<std::string>(line, "\t");
if (values.size() < 2) continue;
auto id = std::stoull(values[1]);
Expand All @@ -133,6 +136,10 @@ int32_t GraphTable::load_nodes(const std::string &path, std::string node_type) {
continue;
}

if (count % 1000000 == 0) {
VLOG(0) << count << " nodes are loaded from filepath";
}

std::string nt = values[0];
if (nt != node_type) {
continue;
Expand All @@ -153,8 +160,12 @@ int32_t GraphTable::load_nodes(const std::string &path, std::string node_type) {
<< " not in feature_map.";
}
}
valid_count ++;
}
}

VLOG(0) << valid_count << "/" << count << " nodes in type " <<
node_type << " are loaded successfully in " << path;
return 0;
}

Expand All @@ -163,6 +174,7 @@ int32_t GraphTable::load_edges(const std::string &path, bool reverse_edge) {
int count = 0;
std::string sample_type = "random";
bool is_weighted = false;
int valid_count = 0;

for (auto path : paths) {
std::ifstream file(path);
Expand Down Expand Up @@ -190,13 +202,17 @@ int32_t GraphTable::load_edges(const std::string &path, bool reverse_edge) {
<< ", please check id distribution";
continue;
}
if (count % 1000000 == 0) {
VLOG(0) << count << " edges are loaded from filepath";
}

size_t index = src_shard_id - shard_start;
shards[index].add_graph_node(src_id)->build_edges(is_weighted);
shards[index].add_neighboor(src_id, dst_id, weight);
valid_count ++;
}
}
VLOG(0) << "Load Finished Total Edge Count " << count;
VLOG(0) << valid_count << "/" << count << " edges are loaded successfully in " << path;

// Build Sampler j

Expand Down
5 changes: 4 additions & 1 deletion paddle/fluid/pybind/fleet_py.cc
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,10 @@ void BindGraphNode(py::module* m) {
.def("get_feature", &GraphNode::get_feature);
}
void BindGraphPyFeatureNode(py::module* m) {
py::class_<FeatureNode>(*m, "FeatureNode").def(py::init<>());
py::class_<FeatureNode>(*m, "FeatureNode")
.def(py::init<>())
.def("get_id", &GraphNode::get_id)
.def("get_feature", &GraphNode::get_feature);
}

void BindGraphPyService(py::module* m) {
Expand Down

0 comments on commit 88cd27a

Please sign in to comment.