diff --git a/paddle/fluid/distributed/table/common_graph_table.cc b/paddle/fluid/distributed/table/common_graph_table.cc index 92f8304a8bf62..997c395f5147d 100644 --- a/paddle/fluid/distributed/table/common_graph_table.cc +++ b/paddle/fluid/distributed/table/common_graph_table.cc @@ -20,7 +20,10 @@ #include "paddle/fluid/distributed/common/utils.h" #include "paddle/fluid/distributed/table/graph/graph_node.h" #include "paddle/fluid/string/printf.h" +#include #include "paddle/fluid/string/string_helper.h" +#include "paddle/fluid/framework/generator.h" + namespace paddle { namespace distributed { @@ -399,7 +402,11 @@ int32_t GraphTable::random_sample_neighboors( uint64_t &node_id = node_ids[idx]; std::unique_ptr &buffer = buffers[idx]; int &actual_size = actual_sizes[idx]; - tasks.push_back(_shards_task_pool[get_thread_pool_index(node_id)]->enqueue( + + int thread_pool_index = get_thread_pool_index(node_id); + auto rng = _shards_task_rng_pool[thread_pool_index]; + + tasks.push_back(_shards_task_pool[thread_pool_index]->enqueue( - [&]() -> int { + [&, rng]() -> int { /* rng is a loop-local shared_ptr: copy it into the closure so the pooled task never reads it through a dangling reference */ Node *node = find_node(node_id); @@ -407,7 +414,7 @@ int32_t GraphTable::random_sample_neighboors( actual_size = 0; return 0; } - std::vector res = node->sample_k(sample_size); + std::vector res = node->sample_k(sample_size, rng); actual_size = res.size() * (Node::id_size + Node::weight_size); int offset = 0; uint64_t id; @@ -546,6 +553,7 @@ int32_t GraphTable::initialize() { _shards_task_pool.resize(task_pool_size_); for (size_t i = 0; i < _shards_task_pool.size(); ++i) { _shards_task_pool[i].reset(new ::ThreadPool(1)); + _shards_task_rng_pool.push_back(paddle::framework::GetCPURandomEngine(0)); } server_num = _shard_num; // VLOG(0) << "in init graph table server num = " << server_num; diff --git a/paddle/fluid/distributed/table/common_graph_table.h b/paddle/fluid/distributed/table/common_graph_table.h index 5eeb3915f5b1f..6ccce44c7ead6 100644 --- a/paddle/fluid/distributed/table/common_graph_table.h +++ 
b/paddle/fluid/distributed/table/common_graph_table.h @@ -136,6 +136,7 @@ class GraphTable : public SparseTable { std::string table_type; std::vector> _shards_task_pool; + std::vector> _shards_task_rng_pool; }; } // namespace distributed diff --git a/paddle/fluid/distributed/table/graph/graph_node.h b/paddle/fluid/distributed/table/graph/graph_node.h index 8ad795ac97b54..940896188bf60 100644 --- a/paddle/fluid/distributed/table/graph/graph_node.h +++ b/paddle/fluid/distributed/table/graph/graph_node.h @@ -18,6 +18,7 @@ #include #include #include "paddle/fluid/distributed/table/graph/graph_weighted_sampler.h" +#include namespace paddle { namespace distributed { @@ -33,7 +34,7 @@ class Node { virtual void build_edges(bool is_weighted) {} virtual void build_sampler(std::string sample_type) {} virtual void add_edge(uint64_t id, float weight) {} - virtual std::vector sample_k(int k) { return std::vector(); } + virtual std::vector sample_k(int k, const std::shared_ptr rng) { return std::vector(); } virtual uint64_t get_neighbor_id(int idx) { return 0; } virtual float get_neighbor_weight(int idx) { return 1.; } @@ -59,7 +60,7 @@ class GraphNode : public Node { virtual void add_edge(uint64_t id, float weight) { edges->add_edge(id, weight); } - virtual std::vector sample_k(int k) { return sampler->sample_k(k); } + virtual std::vector sample_k(int k, const std::shared_ptr rng) { return sampler->sample_k(k, rng); } virtual uint64_t get_neighbor_id(int idx) { return edges->get_id(idx); } virtual float get_neighbor_weight(int idx) { return edges->get_weight(idx); } diff --git a/paddle/fluid/distributed/table/graph/graph_weighted_sampler.cc b/paddle/fluid/distributed/table/graph/graph_weighted_sampler.cc index 3a680875e3df4..4520513588267 100644 --- a/paddle/fluid/distributed/table/graph/graph_weighted_sampler.cc +++ b/paddle/fluid/distributed/table/graph/graph_weighted_sampler.cc @@ -13,39 +13,58 @@ // limitations under the License. 
#include "paddle/fluid/distributed/table/graph/graph_weighted_sampler.h" +#include "paddle/fluid/framework/generator.h" #include #include +#include namespace paddle { namespace distributed { void RandomSampler::build(GraphEdgeBlob *edges) { this->edges = edges; } -std::vector RandomSampler::sample_k(int k) { +std::vector RandomSampler::sample_k(int k, const std::shared_ptr rng) { int n = edges->size(); if (k > n) { k = n; } - struct timespec tn; - clock_gettime(CLOCK_REALTIME, &tn); - srand(tn.tv_nsec); std::vector sample_result; + for(int i = 0;i < k;i ++ ) { + sample_result.push_back(i); + } + if (k == n) { + return sample_result; + } + + std::uniform_int_distribution distrib(0, n - 1); std::unordered_map replace_map; - while (k--) { - int rand_int = rand() % n; - auto iter = replace_map.find(rand_int); - if (iter == replace_map.end()) { - sample_result.push_back(rand_int); - } else { - sample_result.push_back(iter->second); - } - iter = replace_map.find(n - 1); - if (iter == replace_map.end()) { - replace_map[rand_int] = n - 1; + for(int i = 0; i < k; i ++) { + int j = distrib(*rng, decltype(distrib)::param_type(i, n - 1)); /* Fisher-Yates step: j is drawn from [i, n-1]; drawing from [0, n-1] on every step makes some k-subsets more likely than others. With this range the j < i branch below is never taken. */ + if (j >= i) { + // buff_nid[offset + i] = nid[j] if m.find(j) == m.end() else nid[m[j]] + auto iter_j = replace_map.find(j); + if(iter_j == replace_map.end()) { + sample_result[i] = j; + } else { + sample_result[i] = iter_j -> second; + } + // m[j] = i if m.find(i) == m.end() else m[i] + auto iter_i = replace_map.find(i); + if(iter_i == replace_map.end()) { + replace_map[j] = i; + } else { + replace_map[j] = (iter_i -> second); + } } else { - replace_map[rand_int] = iter->second; + sample_result[i] = sample_result[j]; + // buff_nid[offset + j] = nid[i] if m.find(i) == m.end() else nid[m[i]] + auto iter_i = replace_map.find(i); + if(iter_i == replace_map.end()) { + sample_result[j] = i; + } else { + sample_result[j] = (iter_i -> second); + } } - --n; } return sample_result; } @@ -98,19 +117,22 @@ void WeightedSampler::build_one(WeightedGraphEdgeBlob *edges, int start, count = 
left->count + right->count; } } -std::vector WeightedSampler::sample_k(int k) { - if (k > count) { +std::vector WeightedSampler::sample_k(int k, const std::shared_ptr rng) { + if (k >= count) { k = count; + std::vector sample_result; + for (int i = 0; i < k; i++) { + sample_result.push_back(i); + } + return sample_result; } std::vector sample_result; float subtract; std::unordered_map subtract_weight_map; std::unordered_map subtract_count_map; - struct timespec tn; - clock_gettime(CLOCK_REALTIME, &tn); - srand(tn.tv_nsec); + std::uniform_real_distribution distrib(0, 1.0); while (k--) { - float query_weight = rand() % 100000 / 100000.0; + float query_weight = distrib(*rng); query_weight *= weight - subtract_weight_map[this]; sample_result.push_back(sample(query_weight, subtract_weight_map, subtract_count_map, subtract)); diff --git a/paddle/fluid/distributed/table/graph/graph_weighted_sampler.h b/paddle/fluid/distributed/table/graph/graph_weighted_sampler.h index 1787ab23b0431..a23207cc31eab 100644 --- a/paddle/fluid/distributed/table/graph/graph_weighted_sampler.h +++ b/paddle/fluid/distributed/table/graph/graph_weighted_sampler.h @@ -17,6 +17,8 @@ #include #include #include "paddle/fluid/distributed/table/graph/graph_edge.h" +#include +#include namespace paddle { namespace distributed { @@ -24,14 +26,14 @@ class Sampler { public: virtual ~Sampler() {} virtual void build(GraphEdgeBlob *edges) = 0; - virtual std::vector sample_k(int k) = 0; + virtual std::vector sample_k(int k, const std::shared_ptr rng) = 0; }; class RandomSampler : public Sampler { public: virtual ~RandomSampler() {} virtual void build(GraphEdgeBlob *edges); - virtual std::vector sample_k(int k); + virtual std::vector sample_k(int k, const std::shared_ptr rng); GraphEdgeBlob *edges; }; @@ -46,7 +48,7 @@ class WeightedSampler : public Sampler { GraphEdgeBlob *edges; virtual void build(GraphEdgeBlob *edges); virtual void build_one(WeightedGraphEdgeBlob *edges, int start, int end); - virtual 
std::vector sample_k(int k); + virtual std::vector sample_k(int k, const std::shared_ptr rng); private: int sample(float query_weight,