Commit afd080a: change shards to pointer vector

seemingwang committed Dec 1, 2021 (parent: c972210)
Showing 2 changed files with 40 additions and 28 deletions.
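The commit switches GraphTable's shard containers from std::vector&lt;GraphShard&gt; to std::vector&lt;GraphShard *&gt;: every call site changes from `.` to `->`, shards are heap-allocated with `new`, and rebuilding a container swaps pointers instead of copying whole shard objects. A minimal self-contained sketch of that trade-off (illustrative names, not the Paddle sources):

```cpp
// Minimal sketch of the refactor's idea: a pointer vector keeps each
// shard's address stable and makes container rebuilds cheap, but the
// owner must delete what it allocated.
#include <vector>

struct GraphShard { /* bucket, node_location, ... */ };

int main() {
  // Before: shards held by value; resizing copies/moves entire shards.
  std::vector<GraphShard> by_value(4, GraphShard());

  // After: shards held by raw pointer; resizing only copies pointers.
  std::vector<GraphShard *> by_pointer;
  for (int i = 0; i < 4; ++i) by_pointer.push_back(new GraphShard());

  // The cost of raw pointers: explicit cleanup.
  for (auto *shard : by_pointer) delete shard;
  return 0;
}
```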
paddle/fluid/distributed/table/common_graph_table.cc (38 additions, 26 deletions)
@@ -56,7 +56,7 @@ int32_t GraphTable::add_graph_node(std::vector<uint64_t> &id_list,
tasks.push_back(_shards_task_pool[i]->enqueue([&batch, i, this]() -> int {
for (auto &p : batch[i]) {
size_t index = p.first % this->shard_num - this->shard_start;
- this->shards[index].add_graph_node(p.first)->build_edges(p.second);
+ this->shards[index]->add_graph_node(p.first)->build_edges(p.second);
}
return 0;
}));
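The `p.first % this->shard_num - this->shard_start` expression above maps a global node id to a slot in this server's local shard vector. A worked example with assumed values for shard_num and shard_start:

```cpp
// Worked example of the shard-index math (the concrete values are
// assumptions for illustration, not Paddle defaults).
#include <cstdint>
#include <cstdio>

int main() {
  uint64_t id = 21;
  size_t shard_num = 8;    // total shards across all servers
  size_t shard_start = 4;  // first shard owned by this server
  size_t index = id % shard_num - shard_start;  // 21 % 8 - 4 == 1
  std::printf("node %llu -> local shard %zu\n",
              (unsigned long long)id, index);
  return 0;
}
```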
@@ -79,7 +79,7 @@ int32_t GraphTable::remove_graph_node(std::vector<uint64_t> &id_list) {
tasks.push_back(_shards_task_pool[i]->enqueue([&batch, i, this]() -> int {
for (auto &p : batch[i]) {
size_t index = p % this->shard_num - this->shard_start;
- this->shards[index].delete_node(p);
+ this->shards[index]->delete_node(p);
}
return 0;
}));
@@ -178,7 +178,7 @@ int32_t GraphTable::get_nodes_ids_by_ranges(
res.clear();
std::vector<std::future<std::vector<uint64_t>>> tasks;
for (size_t i = 0; i < shards.size() && index < (int)ranges.size(); i++) {
- end = total_size + shards[i].get_size();
+ end = total_size + shards[i]->get_size();
start = total_size;
while (start < end && index < ranges.size()) {
if (ranges[index].second <= start)
@@ -193,11 +193,11 @@ int32_t GraphTable::get_nodes_ids_by_ranges(
second -= total_size;
tasks.push_back(_shards_task_pool[i % task_pool_size_]->enqueue(
[this, first, second, i]() -> std::vector<uint64_t> {
- return shards[i].get_ids_by_range(first, second);
+ return shards[i]->get_ids_by_range(first, second);
}));
}
}
- total_size += shards[i].get_size();
+ total_size += shards[i]->get_size();
}
for (size_t i = 0; i < tasks.size(); i++) {
auto vec = tasks[i].get();
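The loop in this hunk walks the shards as one concatenated global index space: shard i occupies [total_size, total_size + size), and each requested range is clipped against that window and rebased to shard-local offsets. A self-contained sketch with assumed shard sizes:

```cpp
// Illustration of the global-to-shard range mapping (sizes are assumed).
#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
  std::vector<int> shard_sizes = {3, 5, 2};  // global ranges [0,3) [3,8) [8,10)
  int query_first = 4, query_second = 7;     // one requested global id range
  int total_size = 0;
  for (size_t i = 0; i < shard_sizes.size(); i++) {
    int start = total_size, end = total_size + shard_sizes[i];
    int first = std::max(start, query_first);
    int second = std::min(end, query_second);
    if (first < second)  // overlap found; rebase to shard-local offsets
      std::printf("shard %zu serves [%d, %d)\n", i, first - total_size,
                  second - total_size);
    total_size = end;
  }
  return 0;  // prints: shard 1 serves [1, 4)
}
```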
@@ -241,7 +241,7 @@ int32_t GraphTable::load_nodes(const std::string &path, std::string node_type) {

size_t index = shard_id - shard_start;

- auto node = shards[index].add_feature_node(id);
+ auto node = shards[index]->add_feature_node(id);

node->set_feature_size(feat_name.size());

@@ -307,8 +307,8 @@ int32_t GraphTable::load_edges(const std::string &path, bool reverse_edge) {
extra_alloc_index %= task_pool_size_;
extra_nodes_to_thread_index[src_id] = index;
}
- extra_shards[index].add_graph_node(src_id)->build_edges(is_weighted);
- extra_shards[index].add_neighbor(src_id, dst_id, weight);
+ extra_shards[index]->add_graph_node(src_id)->build_edges(is_weighted);
+ extra_shards[index]->add_neighbor(src_id, dst_id, weight);
valid_count++;
continue;
}
@@ -318,8 +318,8 @@ int32_t GraphTable::load_edges(const std::string &path, bool reverse_edge) {
}

size_t index = src_shard_id - shard_start;
- shards[index].add_graph_node(src_id)->build_edges(is_weighted);
- shards[index].add_neighbor(src_id, dst_id, weight);
+ shards[index]->add_graph_node(src_id)->build_edges(is_weighted);
+ shards[index]->add_neighbor(src_id, dst_id, weight);
valid_count++;
}
}
@@ -330,7 +330,7 @@ int32_t GraphTable::load_edges(const std::string &path, bool reverse_edge) {
// Build Sampler j

for (auto &shard : shards) {
- auto bucket = shard.get_bucket();
+ auto bucket = shard->get_bucket();
for (size_t i = 0; i < bucket.size(); i++) {
bucket[i]->build_sampler(sample_type);
used[get_thread_pool_index(bucket[i]->get_id())]++;
@@ -340,7 +340,7 @@ int32_t GraphTable::load_edges(const std::string &path, bool reverse_edge) {
relocate the duplicate nodes to make them distributed evenly among threads.
*/
for (auto &shard : extra_shards) {
- auto bucket = shard.get_bucket();
+ auto bucket = shard->get_bucket();
for (size_t i = 0; i < bucket.size(); i++) {
bucket[i]->build_sampler(sample_type);
}
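Per the comment above, duplicate nodes held in extra_shards are relocated so each worker thread serves a similar share; the `used` counter a few hunks up records the current per-thread load before relocation. A hedged, self-contained sketch of that counting step (the types and the id-to-thread mapping are assumptions standing in for get_thread_pool_index):

```cpp
// Sketch of the load-counting step before rebalancing (illustrative).
#include <cstdint>
#include <vector>

int main() {
  const int task_pool_size = 24;
  std::vector<uint64_t> duplicate_ids = {7, 31, 55, 103};  // assumed ids
  std::vector<int> used(task_pool_size, 0);
  for (uint64_t id : duplicate_ids)
    used[id % task_pool_size]++;  // stand-in for get_thread_pool_index(id)
  // Overloaded threads can now hand nodes to underloaded ones.
  return 0;
}
```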
@@ -393,17 +393,25 @@ int32_t GraphTable::load_edges(const std::string &path, bool reverse_edge) {
if (has_alloc[index[right]] - alloc[index[right]] == 0) right--;
if (alloc[index[left]] - has_alloc[index[left]] == 0) left++;
}
- std::vector<GraphShard> extra_shards_copy(task_pool_size_, GraphShard());
+ std::vector<GraphShard *> extra_shards_copy;
+ for (int i = 0; i < task_pool_size_; ++i) {
+   extra_shards_copy.push_back(new GraphShard());
+ }
for (auto &shard : extra_shards) {
- auto bucket = shard.get_bucket();
+ auto bucket = shard->get_bucket();
+ auto node_location = shard->get_node_location();
while (bucket.size()) {
Node *temp = bucket.back();
bucket.pop_back();
+ node_location.erase(temp->get_id());
- extra_shards_copy[extra_nodes_to_thread_index[temp->get_id()]]
-     .add_graph_node(temp);
+ extra_shards_copy[extra_nodes_to_thread_index[temp->get_id()]]
+     ->add_graph_node(temp);
}
}
- extra_shards = extra_shards_copy;
+ for (int i = 0; i < task_pool_size_; ++i) {
+   delete extra_shards[i];
+   extra_shards[i] = extra_shards_copy[i];
+ }
return 0;
}
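The rebuild above must `delete` each old shard by hand before adopting its replacement, because the members are raw pointers. A hedged alternative sketch, which is an assumption rather than what the commit does: with std::unique_ptr members, the swap would release the old shards automatically.

```cpp
// Alternative ownership sketch (assumes unique_ptr members, which this
// commit does not use): the old shards are destroyed by the move-assign.
#include <memory>
#include <utility>
#include <vector>

struct GraphShard {};

int main() {
  const int task_pool_size = 24;
  std::vector<std::unique_ptr<GraphShard>> extra_shards, extra_shards_copy;
  for (int i = 0; i < task_pool_size; ++i) {
    extra_shards.push_back(std::make_unique<GraphShard>());
    extra_shards_copy.push_back(std::make_unique<GraphShard>());
  }
  // ... rebalanced nodes would be moved into extra_shards_copy here ...
  extra_shards = std::move(extra_shards_copy);  // old shards freed here
  return 0;
}
```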

@@ -416,11 +424,11 @@ Node *GraphTable::find_node(uint64_t id) {
if (iter == extra_nodes_to_thread_index.end())
return nullptr;
else {
- return extra_shards[iter->second].find_node(id);
+ return extra_shards[iter->second]->find_node(id);
}
}
size_t index = shard_id - shard_start;
- Node *node = shards[index].find_node(id);
+ Node *node = shards[index]->find_node(id);
return node;
}
uint32_t GraphTable::get_thread_pool_index(uint64_t node_id) {
@@ -446,13 +454,13 @@ int32_t GraphTable::clear_nodes() {
for (size_t i = 0; i < shards.size(); i++) {
tasks.push_back(
_shards_task_pool[i % task_pool_size_]->enqueue([this, i]() -> int {
- this->shards[i].clear();
+ this->shards[i]->clear();
return 0;
}));
}
for (size_t i = 0; i < extra_shards.size(); i++) {
tasks.push_back(_shards_task_pool[i]->enqueue([this, i]() -> int {
- this->extra_shards[i].clear();
+ this->extra_shards[i]->clear();
return 0;
}));
}
@@ -465,7 +473,7 @@ int32_t GraphTable::random_sample_nodes(int sample_size,
int &actual_size) {
int total_size = 0;
for (int i = 0; i < shards.size(); i++) {
- total_size += shards[i].get_size();
+ total_size += shards[i]->get_size();
}
if (sample_size > total_size) sample_size = total_size;
int range_num = random_sample_nodes_ranges;
@@ -655,7 +663,7 @@ int32_t GraphTable::set_node_feat(
tasks.push_back(_shards_task_pool[get_thread_pool_index(node_id)]->enqueue(
[&, idx, node_id]() -> int {
size_t index = node_id % this->shard_num - this->shard_start;
- auto node = shards[index].add_feature_node(node_id);
+ auto node = shards[index]->add_feature_node(node_id);
node->set_feature_size(this->feat_name.size());
for (int feat_idx = 0; feat_idx < feature_names.size(); ++feat_idx) {
const std::string &feature_name = feature_names[feat_idx];
@@ -712,7 +720,7 @@ int32_t GraphTable::pull_graph_list(int start, int total_size,
int size = 0, cur_size;
std::vector<std::future<std::vector<Node *>>> tasks;
for (size_t i = 0; i < shards.size() && total_size > 0; i++) {
- cur_size = shards[i].get_size();
+ cur_size = shards[i]->get_size();
if (size + cur_size <= start) {
size += cur_size;
continue;
@@ -721,7 +729,7 @@ int32_t GraphTable::pull_graph_list(int start, int total_size,
int end = start + (count - 1) * step + 1;
tasks.push_back(_shards_task_pool[i % task_pool_size_]->enqueue(
[this, i, start, end, step, size]() -> std::vector<Node *> {
- return this->shards[i].get_batch(start - size, end - size, step);
+ return this->shards[i]->get_batch(start - size, end - size, step);
}));
start += count * step;
total_size -= count;
@@ -796,9 +804,13 @@ int32_t GraphTable::initialize() {
shard_end = shard_start + shard_num_per_server;
VLOG(0) << "in init graph table shard idx = " << _shard_idx << " shard_start "
<< shard_start << " shard_end " << shard_end;
- shards = std::vector<GraphShard>(shard_num_per_server, GraphShard());
+ for (int i = 0; i < shard_num_per_server; i++) {
+   shards.push_back(new GraphShard());
+ }
use_duplicate_nodes = false;
- extra_shards = std::vector<GraphShard>(task_pool_size_, GraphShard());
+ for (int i = 0; i < task_pool_size_; i++) {
+   extra_shards.push_back(new GraphShard());
+ }

return 0;
}
paddle/fluid/distributed/table/common_graph_table.h (2 additions, 2 deletions)
@@ -65,7 +65,7 @@ class GraphShard {
void delete_node(uint64_t id);
void clear();
void add_neighbor(uint64_t id, uint64_t dst_id, float weight);
- std::unordered_map<uint64_t, int> get_node_location() {
+ std::unordered_map<uint64_t, int> &get_node_location() {
return node_location;
}
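Returning the map by reference matters because load_edges now erases entries obtained through this getter; a by-value return would hand back a temporary copy and the erase would be lost. A self-contained contrast with an assumed shard layout:

```cpp
// Why the reference return matters (illustrative types, not Paddle's).
#include <cstdint>
#include <unordered_map>

struct Shard {
  std::unordered_map<uint64_t, int> node_location;
  std::unordered_map<uint64_t, int> by_value() { return node_location; }
  std::unordered_map<uint64_t, int> &by_ref() { return node_location; }
};

int main() {
  Shard shard;
  shard.node_location[42] = 0;

  auto copy = shard.by_value();  // old signature: a temporary copy
  copy.erase(42);                // shard.node_location still holds 42

  auto &ref = shard.by_ref();    // new signature: aliases the real map
  ref.erase(42);                 // shard.node_location is now empty
  return 0;
}
```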

@@ -434,7 +434,7 @@ class GraphTable : public SparseTable {
}

protected:
- std::vector<GraphShard> shards, extra_shards;
+ std::vector<GraphShard *> shards, extra_shards;
size_t shard_start, shard_end, server_num, shard_num_per_server, shard_num;
const int task_pool_size_ = 24;
const int random_sample_nodes_ranges = 3;
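With shards and extra_shards now holding raw pointers, the objects must be freed somewhere; this diff shows no destructor change, so the sketch below is purely an assumption about what cleanup could look like, not code from the commit:

```cpp
// Hypothetical cleanup sketch (not part of this commit): with raw-pointer
// members, the table must release its shards itself.
#include <vector>

struct GraphShard {};

class GraphTableSketch {
 public:
  ~GraphTableSketch() {
    for (auto *shard : shards) delete shard;
    for (auto *shard : extra_shards) delete shard;
  }
  std::vector<GraphShard *> shards, extra_shards;
};

int main() {
  GraphTableSketch table;  // shards freed when table goes out of scope
  return 0;
}
```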
