Skip to content

Commit

Permalink
cache optimization
Browse files Browse the repository at this point in the history
  • Loading branch information
seemingwang committed Nov 13, 2021
1 parent 1325315 commit 2e48442
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 123 deletions.
201 changes: 104 additions & 97 deletions paddle/fluid/distributed/table/common_graph_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,6 @@ class LRUNode {
LRUNode(K _key, V _data, size_t _ttl) : key(_key), data(_data), ttl(_ttl) {
next = pre = NULL;
}
std::chrono::milliseconds ms;
// the last hit time
K key;
V data;
size_t ttl;
Expand All @@ -119,12 +117,13 @@ class ScaledLRU;
template <typename K, typename V>
class RandomSampleLRU {
public:
RandomSampleLRU(ScaledLRU<K, V> *_father) : father(_father) {
RandomSampleLRU(ScaledLRU<K, V> *_father) {
father = _father;
remove_count = 0;
node_size = 0;
node_head = node_end = NULL;
global_ttl = father->ttl;
extra_penalty = 0;
size_limit = (father->size_limit / father->shard_num + 1);
total_diff = 0;
}

~RandomSampleLRU() {
Expand All @@ -138,63 +137,71 @@ class RandomSampleLRU {
LRUResponse query(K *keys, size_t length, std::vector<std::pair<K, V>> &res) {
if (pthread_rwlock_tryrdlock(&father->rwlock) != 0)
return LRUResponse::blocked;
int init_node_size = node_size;
try {
// pthread_rwlock_rdlock(&father->rwlock);
for (size_t i = 0; i < length; i++) {
auto iter = key_map.find(keys[i]);
if (iter != key_map.end()) {
res.emplace_back(keys[i], iter->second->data);
iter->second->ttl--;
if (iter->second->ttl == 0) {
remove(iter->second);
} else {
move_to_tail(iter->second);
}
// pthread_rwlock_rdlock(&father->rwlock);
int init_size = node_size - remove_count;
process_redundant(length * 3);

for (size_t i = 0; i < length; i++) {
auto iter = key_map.find(keys[i]);
if (iter != key_map.end()) {
res.emplace_back(keys[i], iter->second->data);
iter->second->ttl--;
if (iter->second->ttl == 0) {
remove(iter->second);
if (remove_count != 0) remove_count--;
} else {
move_to_tail(iter->second);
}
}
} catch (...) {
pthread_rwlock_unlock(&father->rwlock);
father->handle_size_diff(node_size - init_node_size);
return LRUResponse::err;
}
total_diff += node_size - remove_count - init_size;
if (total_diff >= 500 || total_diff < -500) {
father->handle_size_diff(total_diff);
total_diff = 0;
}
pthread_rwlock_unlock(&father->rwlock);
father->handle_size_diff(node_size - init_node_size);
return LRUResponse::ok;
}
LRUResponse insert(K *keys, V *data, size_t length) {
if (pthread_rwlock_tryrdlock(&father->rwlock) != 0)
return LRUResponse::blocked;
int init_node_size = node_size;
try {
for (size_t i = 0; i < length; i++) {
auto iter = key_map.find(keys[i]);
if (iter != key_map.end()) {
move_to_tail(iter->second);
iter->second->ttl = global_ttl;
iter->second->data = data[i];
} else {
LRUNode<K, V> *temp = new LRUNode<K, V>(keys[i], data[i], global_ttl);
add_new(temp);
}
// pthread_rwlock_rdlock(&father->rwlock);
int init_size = node_size - remove_count;
process_redundant(length * 3);
for (size_t i = 0; i < length; i++) {
auto iter = key_map.find(keys[i]);
if (iter != key_map.end()) {
move_to_tail(iter->second);
iter->second->ttl = global_ttl;
iter->second->data = data[i];
} else {
LRUNode<K, V> *temp = new LRUNode<K, V>(keys[i], data[i], global_ttl);
add_new(temp);
}
} catch (...) {
pthread_rwlock_unlock(&father->rwlock);
father->handle_size_diff(node_size - init_node_size);
return LRUResponse::err;
}
total_diff += node_size - remove_count - init_size;
if (total_diff >= 500 || total_diff < -500) {
father->handle_size_diff(total_diff);
total_diff = 0;
}

pthread_rwlock_unlock(&father->rwlock);
father->handle_size_diff(node_size - init_node_size);
return LRUResponse::ok;
}
void remove(LRUNode<K, V> *node) {
fetch(node);
node_size--;
key_map.erase(node->key);
delete node;
if (node_size >= size_limit) {
extra_penalty -= 1.0;
}

void process_redundant(int process_size) {
size_t length = std::min(remove_count, process_size);
while (length--) {
remove(node_head);
remove_count--;
}
// std::cerr<<"after remove_count = "<<remove_count<<std::endl;
}

void move_to_tail(LRUNode<K, V> *node) {
Expand All @@ -207,12 +214,6 @@ class RandomSampleLRU {
place_at_tail(node);
node_size++;
key_map[node->key] = node;
if (node_size > size_limit) {
extra_penalty += penalty_inc;
if (extra_penalty >= 1.0) {
remove(node_head);
}
}
}
void place_at_tail(LRUNode<K, V> *node) {
if (node_end == NULL) {
Expand All @@ -224,8 +225,6 @@ class RandomSampleLRU {
node->next = NULL;
node_end = node;
}
node->ms = std::chrono::duration_cast<std::chrono::milliseconds>(
std::chrono::system_clock::now().time_since_epoch());
}

void fetch(LRUNode<K, V> *node) {
Expand All @@ -245,11 +244,10 @@ class RandomSampleLRU {
std::unordered_map<K, LRUNode<K, V> *> key_map;
ScaledLRU<K, V> *father;
size_t global_ttl, size_limit;
int node_size;
int node_size, total_diff;
LRUNode<K, V> *node_head, *node_end;
friend class ScaledLRU<K, V>;
float extra_penalty;
const float penalty_inc = 0.75;
int remove_count;
};

template <typename K, typename V>
Expand All @@ -268,7 +266,7 @@ class ScaledLRU {
while (true) {
{
std::unique_lock<std::mutex> lock(mutex_);
cv_.wait_for(lock, std::chrono::milliseconds(20000));
cv_.wait_for(lock, std::chrono::milliseconds(3000));
if (stop) {
return;
}
Expand All @@ -295,52 +293,33 @@ class ScaledLRU {
int shrink() {
int node_size = 0;
for (size_t i = 0; i < lru_pool.size(); i++) {
node_size += lru_pool[i].node_size;
node_size += lru_pool[i].node_size - lru_pool[i].remove_count;
}

if (node_size <= 1.2 * size_limit) return 0;
if (node_size <= size_t(1.1 * size_limit) + 1) return 0;
if (pthread_rwlock_wrlock(&rwlock) == 0) {
try {
global_count = 0;
std::priority_queue<RemovedNode, std::vector<RemovedNode>,
std::greater<RemovedNode>>
q;
for (size_t i = 0; i < lru_pool.size(); i++) {
if (lru_pool[i].node_size > 0) {
global_count += lru_pool[i].node_size;
q.push({lru_pool[i].node_head, &lru_pool[i]});
}
}
if (global_count > size_limit) {
// VLOG(0)<<"before shrinking cache, cached nodes count =
// "<<global_count<<std::endl;
size_t remove = global_count - size_limit;
while (remove--) {
RemovedNode remove_node = q.top();
q.pop();
auto next = remove_node.node->next;
if (next) {
q.push({next, remove_node.lru_pointer});
}
global_count--;
remove_node.lru_pointer->remove(remove_node.node);
}
for (size_t i = 0; i < lru_pool.size(); i++) {
lru_pool[i].size_limit = lru_pool[i].node_size;
lru_pool[i].extra_penalty = 0;
}
// VLOG(0)<<"after shrinking cache, cached nodes count =
// // "<<global_count<<std::endl;
// std::cerr<<"in shrink\n";
global_count = 0;
for (size_t i = 0; i < lru_pool.size(); i++) {
global_count += lru_pool[i].node_size - lru_pool[i].remove_count;
}
// std::cerr<<"global_count "<<global_count<<"\n";
if (global_count > size_limit) {
size_t remove = global_count - size_limit;
for (int i = 0; i < lru_pool.size(); i++) {
lru_pool[i].total_diff = 0;
lru_pool[i].remove_count +=
1.0 * (lru_pool[i].node_size - lru_pool[i].remove_count) /
global_count * remove;
// std::cerr<<i<<" "<<lru_pool[i].remove_count<<std::endl;
}
} catch (...) {
pthread_rwlock_unlock(&rwlock);
return -1;
}
pthread_rwlock_unlock(&rwlock);
return 0;
}
return 0;
}

void handle_size_diff(int diff) {
if (diff != 0) {
__sync_fetch_and_add(&global_count, diff);
Expand All @@ -358,18 +337,13 @@ class ScaledLRU {
pthread_rwlock_t rwlock;
size_t shard_num;
int global_count;
size_t size_limit;
size_t size_limit, total, hit;
size_t ttl;
bool stop;
std::thread shrink_job;
std::vector<RandomSampleLRU<K, V>> lru_pool;
mutable std::mutex mutex_;
std::condition_variable cv_;
struct RemovedNode {
LRUNode<K, V> *node;
RandomSampleLRU<K, V> *lru_pointer;
bool operator>(const RemovedNode &a) const { return node->ms > a.node->ms; }
};
std::shared_ptr<::ThreadPool> thread_pool;
friend class RandomSampleLRU<K, V>;
};
Expand Down Expand Up @@ -448,13 +422,46 @@ class GraphTable : public SparseTable {
std::unique_lock<std::mutex> lock(mutex_);
if (use_cache == false) {
scaled_lru.reset(new ScaledLRU<SampleKey, SampleResult>(
shard_end - shard_start, size_limit, ttl));
task_pool_size_, size_limit, ttl));
use_cache = true;
}
}
return 0;
}

virtual int32_t test_sample_with_cache(int size, int batch_size,
int sample_size) {
std::vector<int> actual_sizes1, actual_sizes2;
std::vector<std::shared_ptr<char>> buffers1, buffers2;
std::vector<uint64_t> node_ids1(batch_size), node_ids2;
for (int i = 0; i <= size - batch_size; i += batch_size) {
for (int j = 0; j < batch_size; j++) {
node_ids1[j] = i + j;
}
actual_sizes1.resize(batch_size);
buffers1.resize(batch_size);
random_sample_neighbors(node_ids1.data(), sample_size, buffers1,
actual_sizes1);
node_ids2.clear();
for (int j = 0; j < batch_size; j++) {
if (actual_sizes1[j] != 0) {
int offset = 0;
char *p = buffers1[j].get();
while (offset < actual_sizes1[j]) {
node_ids2.push_back(*(uint64_t *)(p + offset));
offset += Node::id_size + Node::weight_size;
}
}
}
buffers2.resize(node_ids2.size());
actual_sizes2.resize(node_ids2.size());
random_sample_neighbors(node_ids2.data(), sample_size, buffers2,
actual_sizes2);
}

return 0;
}

protected:
std::vector<GraphShard> shards;
size_t shard_start, shard_end, server_num, shard_num_per_server, shard_num;
Expand Down
27 changes: 1 addition & 26 deletions paddle/fluid/distributed/test/graph_node_test.cc
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
Expand Down Expand Up @@ -681,28 +678,6 @@ void testCache() {
}
st.query(0, &skey, 1, r);
ASSERT_EQ((int)r.size(), 0);
::paddle::distributed::ScaledLRU<::paddle::distributed::SampleKey,
::paddle::distributed::SampleResult>
cache1(2, 1, 4);
str = new char[18];
strcpy(str, "3433776521");
result = new ::paddle::distributed::SampleResult(strlen(str), str);
cache1.insert(1, &skey, result, 1);
::paddle::distributed::SampleKey skey1 = {8, 1};
char* str1 = new char[18];
strcpy(str1, "3xcf2eersfd");
usleep(3000); // sleep 3ms to guaruntee that skey1's time stamp is larger
// than skey;
auto result1 = new ::paddle::distributed::SampleResult(strlen(str1), str1);
cache1.insert(0, &skey1, result1, 1);
sleep(1); // sleep 1 s to guarantee that shrinking work is done
cache1.query(1, &skey, 1, r);
ASSERT_EQ((int)r.size(), 0);
cache1.query(0, &skey1, 1, r);
ASSERT_EQ((int)r.size(), 1);
char* p1 = (char*)r[0].second.buffer.get();
for (int j = 0; j < r[0].second.actual_size; j++) ASSERT_EQ(p1[j], str1[j]);
r.clear();
}
void testGraphToBuffer() {
::paddle::distributed::GraphNode s, s1;
Expand All @@ -718,4 +693,4 @@ void testGraphToBuffer() {
VLOG(0) << s1.get_feature(0);
}

TEST(RunBrpcPushSparse, Run) { RunBrpcPushSparse(); }
TEST(RunBrpcPushSparse, Run) { RunBrpcPushSparse(); }

1 comment on commit 2e48442

@paddle-bot-old
Copy link

@paddle-bot-old paddle-bot-old bot commented on 2e48442 Nov 13, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🕵️ CI failures summary

🔍 PR: #37168 Commit ID: 2e48442 contains failed CI.

🔹 Failed: PR-CI-iScan-Python

Unknown Failed
Unknown Failed

🔹 Failed: PR-CI-Windows-Inference

Unknown Failed
2021-11-13 17:10:20 Updating files:  98% (6451/6582)
2021-11-13 17:10:20 Updating files: 99% (6517/6582)
2021-11-13 17:10:20 Updating files: 100% (6582/6582)
2021-11-13 17:10:20 Updating files: 100% (6582/6582), done.
2021-11-13 17:10:20 C:\Users\Administrator\Downloads\workspace\67d7c9af-274d-4ca6-bb50-23fac0451939\Paddle>if exist Paddle.git\index.lock del Paddle.git\index.lock
2021-11-13 17:10:20 C:\Users\Administrator\Downloads\workspace\67d7c9af-274d-4ca6-bb50-23fac0451939\Paddle>git config --global user.name "PaddleCI"
2021-11-13 17:10:20 C:\Users\Administrator\Downloads\workspace\67d7c9af-274d-4ca6-bb50-23fac0451939\Paddle>git config --global user.email "paddle_ci@example.com"
2021-11-13 17:10:20 C:\Users\Administrator\Downloads\workspace\67d7c9af-274d-4ca6-bb50-23fac0451939\Paddle>git fetch origin pull/37168/head
2021-11-13 17:11:33 From https://github.com/PaddlePaddle/Paddle
2021-11-13 17:11:33 * branch refs/pull/37168/head -> FETCH_HEAD
2021-11-13 17:11:33 C:\Users\Administrator\Downloads\workspace\67d7c9af-274d-4ca6-bb50-23fac0451939\Paddle>if 0 NEQ 0 exit /b 1
2021-11-13 17:11:33 C:\Users\Administrator\Downloads\workspace\67d7c9af-274d-4ca6-bb50-23fac0451939\Paddle>git checkout -b test_pr FETCH_HEAD
2021-11-13 17:11:33 Switched to a new branch 'test_pr'
2021-11-13 17:11:33 C:\Users\Administrator\Downloads\workspace\67d7c9af-274d-4ca6-bb50-23fac0451939\Paddle>if 0 NEQ 0 exit /b 1
2021-11-13 17:11:33 C:\Users\Administrator\Downloads\workspace\67d7c9af-274d-4ca6-bb50-23fac0451939\Paddle>git merge --no-edit develop
2021-11-13 17:11:33 Auto-merging paddle/fluid/distributed/table/common_graph_table.h
2021-11-13 17:11:33 CONFLICT (content): Merge conflict in paddle/fluid/distributed/table/common_graph_table.h
2021-11-13 17:11:33 Automatic merge failed; fix conflicts and then commit the result.
2021-11-13 17:11:33 C:\Users\Administrator\Downloads\workspace\67d7c9af-274d-4ca6-bb50-23fac0451939\Paddle>if 1 NEQ 0 exit /b 1

Please sign in to comment.