Skip to content

Commit

Permalink
Merge pull request #18 from Liwb5/engine2.0
Browse files Browse the repository at this point in the history
speed up graph neighbors sampling
  • Loading branch information
seemingwang committed Jul 11, 2021
2 parents 5b07c6c + 2037c51 commit fc74f83
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 30 deletions.
12 changes: 10 additions & 2 deletions paddle/fluid/distributed/table/common_graph_table.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@
#include "paddle/fluid/distributed/common/utils.h"
#include "paddle/fluid/distributed/table/graph/graph_node.h"
#include "paddle/fluid/string/printf.h"
#include <chrono>
#include "paddle/fluid/string/string_helper.h"
#include "paddle/fluid/framework/generator.h"

namespace paddle {
namespace distributed {

Expand Down Expand Up @@ -399,15 +402,19 @@ int32_t GraphTable::random_sample_neighboors(
uint64_t &node_id = node_ids[idx];
std::unique_ptr<char[]> &buffer = buffers[idx];
int &actual_size = actual_sizes[idx];
tasks.push_back(_shards_task_pool[get_thread_pool_index(node_id)]->enqueue(

int thread_pool_index = get_thread_pool_index(node_id);
auto rng = _shards_task_rng_pool[thread_pool_index];

tasks.push_back(_shards_task_pool[thread_pool_index]->enqueue(
[&]() -> int {
Node *node = find_node(node_id);

if (node == nullptr) {
actual_size = 0;
return 0;
}
std::vector<int> res = node->sample_k(sample_size);
std::vector<int> res = node->sample_k(sample_size, rng);
actual_size = res.size() * (Node::id_size + Node::weight_size);
int offset = 0;
uint64_t id;
Expand Down Expand Up @@ -546,6 +553,7 @@ int32_t GraphTable::initialize() {
_shards_task_pool.resize(task_pool_size_);
for (size_t i = 0; i < _shards_task_pool.size(); ++i) {
_shards_task_pool[i].reset(new ::ThreadPool(1));
_shards_task_rng_pool.push_back(paddle::framework::GetCPURandomEngine(0));
}
server_num = _shard_num;
// VLOG(0) << "in init graph table server num = " << server_num;
Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/distributed/table/common_graph_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ class GraphTable : public SparseTable {
std::string table_type;

std::vector<std::shared_ptr<::ThreadPool>> _shards_task_pool;
std::vector<std::shared_ptr<std::mt19937_64>> _shards_task_rng_pool;
};
} // namespace distributed

Expand Down
5 changes: 3 additions & 2 deletions paddle/fluid/distributed/table/graph/graph_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include <sstream>
#include <vector>
#include "paddle/fluid/distributed/table/graph/graph_weighted_sampler.h"
#include <memory>
namespace paddle {
namespace distributed {

Expand All @@ -33,7 +34,7 @@ class Node {
virtual void build_edges(bool is_weighted) {}
virtual void build_sampler(std::string sample_type) {}
virtual void add_edge(uint64_t id, float weight) {}
virtual std::vector<int> sample_k(int k) { return std::vector<int>(); }
virtual std::vector<int> sample_k(int k, const std::shared_ptr<std::mt19937_64> rng) { return std::vector<int>(); }
virtual uint64_t get_neighbor_id(int idx) { return 0; }
virtual float get_neighbor_weight(int idx) { return 1.; }

Expand All @@ -59,7 +60,7 @@ class GraphNode : public Node {
virtual void add_edge(uint64_t id, float weight) {
edges->add_edge(id, weight);
}
virtual std::vector<int> sample_k(int k) { return sampler->sample_k(k); }
virtual std::vector<int> sample_k(int k, const std::shared_ptr<std::mt19937_64> rng) { return sampler->sample_k(k, rng); }
virtual uint64_t get_neighbor_id(int idx) { return edges->get_id(idx); }
virtual float get_neighbor_weight(int idx) { return edges->get_weight(idx); }

Expand Down
68 changes: 45 additions & 23 deletions paddle/fluid/distributed/table/graph/graph_weighted_sampler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,39 +13,58 @@
// limitations under the License.

#include "paddle/fluid/distributed/table/graph/graph_weighted_sampler.h"
#include "paddle/fluid/framework/generator.h"
#include <iostream>
#include <unordered_map>
#include <memory>
namespace paddle {
namespace distributed {

void RandomSampler::build(GraphEdgeBlob *edges) { this->edges = edges; }

std::vector<int> RandomSampler::sample_k(int k) {
std::vector<int> RandomSampler::sample_k(int k, const std::shared_ptr<std::mt19937_64> rng) {
int n = edges->size();
if (k > n) {
k = n;
}
struct timespec tn;
clock_gettime(CLOCK_REALTIME, &tn);
srand(tn.tv_nsec);
std::vector<int> sample_result;
for(int i = 0;i < k;i ++ ) {
sample_result.push_back(i);
}
if (k == n) {
return sample_result;
}

std::uniform_int_distribution<int> distrib(0, n - 1);
std::unordered_map<int, int> replace_map;
while (k--) {
int rand_int = rand() % n;
auto iter = replace_map.find(rand_int);
if (iter == replace_map.end()) {
sample_result.push_back(rand_int);
} else {
sample_result.push_back(iter->second);
}

iter = replace_map.find(n - 1);
if (iter == replace_map.end()) {
replace_map[rand_int] = n - 1;
for(int i = 0; i < k; i ++) {
int j = distrib(*rng);
if (j >= i) {
// buff_nid[offset + i] = nid[j] if m.find(j) == m.end() else nid[m[j]]
auto iter_j = replace_map.find(j);
if(iter_j == replace_map.end()) {
sample_result[i] = j;
} else {
sample_result[i] = iter_j -> second;
}
// m[j] = i if m.find(i) == m.end() else m[i]
auto iter_i = replace_map.find(i);
if(iter_i == replace_map.end()) {
replace_map[j] = i;
} else {
replace_map[j] = (iter_i -> second);
}
} else {
replace_map[rand_int] = iter->second;
sample_result[i] = sample_result[j];
// buff_nid[offset + j] = nid[i] if m.find(i) == m.end() else nid[m[i]]
auto iter_i = replace_map.find(i);
if(iter_i == replace_map.end()) {
sample_result[j] = i;
} else {
sample_result[j] = (iter_i -> second);
}
}
--n;
}
return sample_result;
}
Expand Down Expand Up @@ -98,19 +117,22 @@ void WeightedSampler::build_one(WeightedGraphEdgeBlob *edges, int start,
count = left->count + right->count;
}
}
std::vector<int> WeightedSampler::sample_k(int k) {
if (k > count) {
std::vector<int> WeightedSampler::sample_k(int k, const std::shared_ptr<std::mt19937_64> rng) {
if (k >= count) {
k = count;
std::vector<int> sample_result;
for (int i = 0; i < k; i++) {
sample_result.push_back(i);
}
return sample_result;
}
std::vector<int> sample_result;
float subtract;
std::unordered_map<WeightedSampler *, float> subtract_weight_map;
std::unordered_map<WeightedSampler *, int> subtract_count_map;
struct timespec tn;
clock_gettime(CLOCK_REALTIME, &tn);
srand(tn.tv_nsec);
std::uniform_real_distribution<float> distrib(0, 1.0);
while (k--) {
float query_weight = rand() % 100000 / 100000.0;
float query_weight = distrib(*rng);
query_weight *= weight - subtract_weight_map[this];
sample_result.push_back(sample(query_weight, subtract_weight_map,
subtract_count_map, subtract));
Expand Down
8 changes: 5 additions & 3 deletions paddle/fluid/distributed/table/graph/graph_weighted_sampler.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,23 @@
#include <unordered_map>
#include <vector>
#include "paddle/fluid/distributed/table/graph/graph_edge.h"
#include <random>
#include <memory>
namespace paddle {
namespace distributed {

class Sampler {
public:
virtual ~Sampler() {}
virtual void build(GraphEdgeBlob *edges) = 0;
virtual std::vector<int> sample_k(int k) = 0;
virtual std::vector<int> sample_k(int k, const std::shared_ptr<std::mt19937_64> rng) = 0;
};

class RandomSampler : public Sampler {
public:
virtual ~RandomSampler() {}
virtual void build(GraphEdgeBlob *edges);
virtual std::vector<int> sample_k(int k);
virtual std::vector<int> sample_k(int k, const std::shared_ptr<std::mt19937_64> rng);
GraphEdgeBlob *edges;
};

Expand All @@ -46,7 +48,7 @@ class WeightedSampler : public Sampler {
GraphEdgeBlob *edges;
virtual void build(GraphEdgeBlob *edges);
virtual void build_one(WeightedGraphEdgeBlob *edges, int start, int end);
virtual std::vector<int> sample_k(int k);
virtual std::vector<int> sample_k(int k, const std::shared_ptr<std::mt19937_64> rng);

private:
int sample(float query_weight,
Expand Down

1 comment on commit fc74f83

@paddle-bot-old
Copy link

@paddle-bot-old paddle-bot-old bot commented on fc74f83 Jul 11, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🕵️ CI failures summary

🔍PR: #18 Commit ID: fc74f83 contains failed CI.

Please sign in to comment.