Skip to content

Commit

Permalink
reduce sample threads when cache is not used
Browse files Browse the repository at this point in the history
  • Loading branch information
seemingwang committed Nov 10, 2021
1 parent 0c0b63e commit 24e1df2
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 117 deletions.
166 changes: 65 additions & 101 deletions paddle/fluid/distributed/table/common_graph_table.cc
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ FeatureNode *GraphShard::add_feature_node(uint64_t id) {
return (FeatureNode *)bucket[node_location[id]];
}

void GraphShard::add_neighboor(uint64_t id, uint64_t dst_id, float weight) {
void GraphShard::add_neighbor(uint64_t id, uint64_t dst_id, float weight) {
find_node(id)->add_edge(dst_id, weight);
}

Expand Down Expand Up @@ -277,7 +277,7 @@ int32_t GraphTable::load_edges(const std::string &path, bool reverse_edge) {

size_t index = src_shard_id - shard_start;
shards[index].add_graph_node(src_id)->build_edges(is_weighted);
shards[index].add_neighboor(src_id, dst_id, weight);
shards[index].add_neighbor(src_id, dst_id, weight);
valid_count++;
}
}
Expand Down Expand Up @@ -399,114 +399,78 @@ int32_t GraphTable::random_sample_neighbors(
size_t node_num = buffers.size();
std::function<void(char *)> char_del = [](char *c) { delete[] c; };
std::vector<std::future<int>> tasks;
if (use_cache) {
std::vector<std::vector<uint32_t>> seq_id(shard_end - shard_start);
std::vector<std::vector<SampleKey>> id_list(shard_end - shard_start);
size_t index;
for (size_t idx = 0; idx < node_num; ++idx) {
index = get_thread_pool_index(node_ids[idx]);
seq_id[index].emplace_back(idx);
id_list[index].emplace_back(node_ids[idx], sample_size);
}
for (int i = 0; i < seq_id.size(); i++) {
if (seq_id[i].size() == 0) continue;
tasks.push_back(_shards_task_pool[i]->enqueue([&, i, this]() -> int {
uint64_t node_id;
std::vector<std::pair<SampleKey, SampleResult>> r;
auto response =
std::vector<std::vector<uint32_t>> seq_id(shard_end - shard_start);
std::vector<std::vector<SampleKey>> id_list(shard_end - shard_start);
size_t index;
for (size_t idx = 0; idx < node_num; ++idx) {
index = get_thread_pool_index(node_ids[idx]);
seq_id[index].emplace_back(idx);
id_list[index].emplace_back(node_ids[idx], sample_size);
}
for (int i = 0; i < seq_id.size(); i++) {
if (seq_id[i].size() == 0) continue;
tasks.push_back(_shards_task_pool[i]->enqueue([&, i, this]() -> int {
uint64_t node_id;
std::vector<std::pair<SampleKey, SampleResult>> r;
LRUResponse response = LRUResponse::blocked;
if (use_cache) {
response =
scaled_lru->query(i, id_list[i].data(), id_list[i].size(), r);
int index = 0;
uint32_t idx;
std::vector<SampleResult> sample_res;
std::vector<SampleKey> sample_keys;
auto &rng = _shards_task_rng_pool[i];
for (size_t k = 0; k < id_list[i].size(); k++) {
if (index < r.size() &&
r[index].first.node_key == id_list[i][k].node_key) {
idx = seq_id[i][k];
actual_sizes[idx] = r[index].second.actual_size;
buffers[idx] = r[index].second.buffer;
index++;
}
int index = 0;
uint32_t idx;
std::vector<SampleResult> sample_res;
std::vector<SampleKey> sample_keys;
auto &rng = _shards_task_rng_pool[i];
for (size_t k = 0; k < id_list[i].size(); k++) {
if (index < r.size() &&
r[index].first.node_key == id_list[i][k].node_key) {
idx = seq_id[i][k];
actual_sizes[idx] = r[index].second.actual_size;
buffers[idx] = r[index].second.buffer;
index++;
} else {
node_id = id_list[i][k].node_key;
Node *node = find_node(node_id);
idx = seq_id[i][k];
int &actual_size = actual_sizes[idx];
if (node == nullptr) {
actual_size = 0;
continue;
}
std::shared_ptr<char> &buffer = buffers[idx];
std::vector<int> res = node->sample_k(sample_size, rng);
actual_size = res.size() * (Node::id_size + Node::weight_size);
int offset = 0;
uint64_t id;
float weight;
char *buffer_addr = new char[actual_size];
if (response == LRUResponse::ok) {
sample_keys.emplace_back(node_id, sample_size);
sample_res.emplace_back(actual_size, buffer_addr);
buffer = sample_res.back().buffer;
} else {
node_id = id_list[i][k].node_key;
Node *node = find_node(node_id);
idx = seq_id[i][k];
int &actual_size = actual_sizes[idx];
if (node == nullptr) {
actual_size = 0;
continue;
}
std::shared_ptr<char> &buffer = buffers[idx];
std::vector<int> res = node->sample_k(sample_size, rng);
actual_size = res.size() * (Node::id_size + Node::weight_size);
int offset = 0;
uint64_t id;
float weight;
char *buffer_addr = new char[actual_size];
if (response == LRUResponse::ok) {
sample_keys.emplace_back(node_id, sample_size);
sample_res.emplace_back(actual_size, buffer_addr);
buffer = sample_res.back().buffer;
} else {
buffer.reset(buffer_addr, char_del);
}
for (int &x : res) {
id = node->get_neighbor_id(x);
weight = node->get_neighbor_weight(x);
memcpy(buffer_addr + offset, &id, Node::id_size);
offset += Node::id_size;
memcpy(buffer_addr + offset, &weight, Node::weight_size);
offset += Node::weight_size;
}
buffer.reset(buffer_addr, char_del);
}
for (int &x : res) {
id = node->get_neighbor_id(x);
weight = node->get_neighbor_weight(x);
memcpy(buffer_addr + offset, &id, Node::id_size);
offset += Node::id_size;
memcpy(buffer_addr + offset, &weight, Node::weight_size);
offset += Node::weight_size;
}
}
if (sample_res.size()) {
scaled_lru->insert(i, sample_keys.data(), sample_res.data(),
sample_keys.size());
}
return 0;
}));
}
for (auto &t : tasks) {
t.get();
}
return 0;
}
for (size_t idx = 0; idx < node_num; ++idx) {
uint64_t &node_id = node_ids[idx];
std::shared_ptr<char> &buffer = buffers[idx];
int &actual_size = actual_sizes[idx];

int thread_pool_index = get_thread_pool_index(node_id);
auto rng = _shards_task_rng_pool[thread_pool_index];

tasks.push_back(_shards_task_pool[thread_pool_index]->enqueue([&]() -> int {
Node *node = find_node(node_id);

if (node == nullptr) {
actual_size = 0;
return 0;
}
std::vector<int> res = node->sample_k(sample_size, rng);
actual_size = res.size() * (Node::id_size + Node::weight_size);
int offset = 0;
uint64_t id;
float weight;
char *buffer_addr = new char[actual_size];
buffer.reset(buffer_addr, char_del);
for (int &x : res) {
id = node->get_neighbor_id(x);
weight = node->get_neighbor_weight(x);
memcpy(buffer_addr + offset, &id, Node::id_size);
offset += Node::id_size;
memcpy(buffer_addr + offset, &weight, Node::weight_size);
offset += Node::weight_size;
if (sample_res.size()) {
scaled_lru->insert(i, sample_keys.data(), sample_res.data(),
sample_keys.size());
}
return 0;
}));
}
for (size_t idx = 0; idx < node_num; ++idx) {
tasks[idx].get();
for (auto &t : tasks) {
t.get();
}
return 0;
}
Expand Down
23 changes: 7 additions & 16 deletions paddle/fluid/distributed/table/common_graph_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ class GraphShard {
Node *find_node(uint64_t id);
void delete_node(uint64_t id);
void clear();
void add_neighboor(uint64_t id, uint64_t dst_id, float weight);
void add_neighbor(uint64_t id, uint64_t dst_id, float weight);
// Returns node_location by value, so the caller gets its own copy of the map
// rather than a reference to this shard's member. Elsewhere in this file the
// mapped value is used as an index into bucket (bucket[node_location[id]]).
std::unordered_map<uint64_t, int> get_node_location() {
return node_location;
}
Expand All @@ -81,9 +81,7 @@ struct SampleKey {
uint64_t node_key;
size_t sample_size;
SampleKey(uint64_t _node_key, size_t _sample_size)
: node_key(_node_key), sample_size(_sample_size) {
// std::cerr<<"in constructor of samplekey\n";
}
: node_key(_node_key), sample_size(_sample_size) {}
// Two keys compare equal only when both node_key and sample_size match.
bool operator==(const SampleKey &s) const {
return node_key == s.node_key && sample_size == s.sample_size;
}
Expand Down Expand Up @@ -143,7 +141,7 @@ class RandomSampleLRU {
for (size_t i = 0; i < length; i++) {
auto iter = key_map.find(keys[i]);
if (iter != key_map.end()) {
res.push_back({keys[i], iter->second->data});
res.emplace_back(keys[i], iter->second->data);
iter->second->ttl--;
if (iter->second->ttl == 0) {
remove(iter->second, true);
Expand Down Expand Up @@ -252,8 +250,6 @@ class ScaledLRU {
}
}

// shrink();
// std::cerr<<"shrink job in queue\n";
auto status =
thread_pool->enqueue([this]() -> int { return shrink(); });
status.wait();
Expand All @@ -263,10 +259,8 @@ class ScaledLRU {
}
~ScaledLRU() {
std::unique_lock<std::mutex> lock(mutex_);
// std::cerr<<"cancel shrink job\n";
stop = true;
cv_.notify_one();
// pthread_cancel(shrink_job.native_handle());
}
LRUResponse query(size_t index, K *keys, size_t length,
std::vector<std::pair<K, V>> &res) {
Expand All @@ -280,10 +274,7 @@ class ScaledLRU {
std::string t = "";
for (size_t i = 0; i < lru_pool.size(); i++) {
node_size += lru_pool[i].node_size;
// t += std::to_string(i) + "->" + std::to_string(lru_pool[i].node_size) +
// " ";
}
// std::cout<<t<<std::endl;

if (node_size <= size_limit) return 0;
if (pthread_rwlock_wrlock(&rwlock) == 0) {
Expand All @@ -299,7 +290,7 @@ class ScaledLRU {
}
}
if (global_count > size_limit) {
// std::cout<<"before shrinking cache, cached nodes count =
// VLOG(0)<<"before shrinking cache, cached nodes count =
// "<<global_count<<std::endl;
size_t remove = global_count - size_limit;
while (remove--) {
Expand All @@ -313,11 +304,11 @@ class ScaledLRU {
remove_node.lru_pointer->key_map.erase(remove_node.node->key);
remove_node.lru_pointer->remove(remove_node.node, true);
}
// std::cout<<"after shrinking cache, cached nodes count =
// VLOG(0)<<"after shrinking cache, cached nodes count =
// "<<global_count<<std::endl;
}
} catch (...) {
// std::cout << "shrink cache failed"<<std::endl;
// VLOG(0) << "shrink cache failed"<<std::endl;
pthread_rwlock_unlock(&rwlock);
return -1;
}
Expand All @@ -330,7 +321,7 @@ class ScaledLRU {
if (diff != 0) {
__sync_fetch_and_add(&global_count, diff);
if (global_count > int(1.5 * size_limit)) {
// std::cout<<"global_count too large "<<global_count<<" enter start
// VLOG(0)<<"global_count too large "<<global_count<<" enter start
// shrink task\n";
thread_pool->enqueue([this]() -> int { return shrink(); });
}
Expand Down
22 changes: 22 additions & 0 deletions paddle/fluid/distributed/test/graph_node_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -681,6 +681,28 @@ void testCache() {
}
st.query(0, &skey, 1, r);
ASSERT_EQ((int)r.size(), 0);
::paddle::distributed::ScaledLRU<::paddle::distributed::SampleKey,
::paddle::distributed::SampleResult>
cache1(2, 1, 4);
str = new char[18];
strcpy(str, "3433776521");
result = new ::paddle::distributed::SampleResult(strlen(str), str);
cache1.insert(1, &skey, result, 1);
::paddle::distributed::SampleKey skey1 = {8, 1};
char* str1 = new char[18];
strcpy(str1, "3xcf2eersfd");
usleep(3000); // sleep 3ms to guarantee that skey1's time stamp is larger
// than skey's
auto result1 = new ::paddle::distributed::SampleResult(strlen(str1), str1);
cache1.insert(0, &skey1, result1, 1);
sleep(1); // sleep 1 s to guarantee that shrinking work is done
cache1.query(1, &skey, 1, r);
ASSERT_EQ((int)r.size(), 0);
cache1.query(0, &skey1, 1, r);
ASSERT_EQ((int)r.size(), 1);
char* p1 = (char*)r[0].second.buffer.get();
for (int j = 0; j < r[0].second.actual_size; j++) ASSERT_EQ(p1[j], str1[j]);
r.clear();
}
void testGraphToBuffer() {
::paddle::distributed::GraphNode s, s1;
Expand Down

0 comments on commit 24e1df2

Please sign in to comment.