Skip to content

Commit

Permalink
[GraphBolt] Change hash table for performance. (#7631)
Browse files Browse the repository at this point in the history
  • Loading branch information
mfbalin committed Aug 2, 2024
1 parent 1a8cf7e commit f724ec0
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 10 deletions.
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,6 @@
[submodule "third_party/taskflow"]
path = third_party/taskflow
url = https://github.com/taskflow/taskflow.git
[submodule "third_party/tsl_robin_map"]
path = third_party/tsl_robin_map
url = https://github.com/Tessil/robin-map.git
2 changes: 1 addition & 1 deletion graphbolt/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ include_directories(BEFORE ${BOLT_DIR}
# `std::atomic_ref`, `std::counting_semaphore`
"../third_party/cccl/libcudacxx/include"
"../third_party/pcg/include"
"../third_party/phmap")
"../third_party/tsl_robin_map/include")
target_link_libraries(${LIB_GRAPHBOLT_NAME} "${TORCH_LIBRARIES}")
if(BUILD_WITH_TASKFLOW)
target_include_directories(${LIB_GRAPHBOLT_NAME} PRIVATE "../third_party/taskflow")
Expand Down
23 changes: 14 additions & 9 deletions graphbolt/src/cache_policy.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,10 @@
#ifndef GRAPHBOLT_CACHE_POLICY_H_
#define GRAPHBOLT_CACHE_POLICY_H_

#include <parallel_hashmap/phmap.h>
#include <torch/custom_class.h>
#include <torch/torch.h>
#include <tsl/robin_map.h>
#include <tsl/robin_set.h>

#include <limits>
#include <mutex>
Expand Down Expand Up @@ -178,9 +179,13 @@ class BaseCachePolicy {

protected:
template <typename K, typename V>
using map_t = phmap::flat_hash_map<K, V>;
using map_t = tsl::robin_map<K, V>;
template <typename K>
using set_t = phmap::flat_hash_set<K>;
using set_t = tsl::robin_set<K>;
template <typename iterator>
static auto& mutable_value_ref(iterator it) {
return it.value();
}
static constexpr int kCapacityFactor = 2;

template <typename CachePolicy>
Expand Down Expand Up @@ -298,7 +303,7 @@ class S3FifoCachePolicy : public BaseCachePolicy {
const auto in_ghost_queue = ghost_set_.erase(key);
auto& queue = in_ghost_queue ? main_queue_ : small_queue_;
auto cache_key_ptr = queue.Push(CacheKey(key));
it->second = cache_key_ptr;
mutable_value_ref(it) = cache_key_ptr;
return &cache_key_ptr->setPos(Evict());
}

Expand All @@ -318,7 +323,7 @@ class S3FifoCachePolicy : public BaseCachePolicy {
auto it = key_to_cache_key_.find(evicted.getKey());
if (evicted.getFreq() > 0 || evicted.InUse()) {
evicted.Decrement();
it->second = main_queue_.Push(evicted);
mutable_value_ref(it) = main_queue_.Push(evicted);
} else {
key_to_cache_key_.erase(it);
return evicted.getPos();
Expand All @@ -332,7 +337,7 @@ class S3FifoCachePolicy : public BaseCachePolicy {
auto evicted = small_queue_.Pop();
auto it = key_to_cache_key_.find(evicted.getKey());
if (evicted.getFreq() > 0 || evicted.InUse()) {
it->second = main_queue_.Push(evicted.ResetFreq());
mutable_value_ref(it) = main_queue_.Push(evicted.ResetFreq());
} else {
key_to_cache_key_.erase(it);
const auto evicted_key = evicted.getKey();
Expand Down Expand Up @@ -449,7 +454,7 @@ class SieveCachePolicy : public BaseCachePolicy {
const auto key = it->first;
queue_.push_front(CacheKey(key));
auto cache_key_ptr = &queue_.front();
it->second = cache_key_ptr;
mutable_value_ref(it) = cache_key_ptr;
return &cache_key_ptr->setPos(Evict());
}

Expand Down Expand Up @@ -589,7 +594,7 @@ class LruCachePolicy : public BaseCachePolicy {
CacheKey* Insert(map_iterator it) {
const auto key = it->first;
queue_.push_front(CacheKey(key));
it->second = queue_.begin();
mutable_value_ref(it) = queue_.begin();
auto cache_key_ptr = &*queue_.begin();
return &cache_key_ptr->setPos(Evict());
}
Expand Down Expand Up @@ -718,7 +723,7 @@ class ClockCachePolicy : public BaseCachePolicy {
CacheKey* Insert(map_iterator it) {
const auto key = it->first;
auto cache_key_ptr = queue_.Push(CacheKey(key));
it->second = cache_key_ptr;
mutable_value_ref(it) = cache_key_ptr;
return &cache_key_ptr->setPos(Evict());
}

Expand Down
1 change: 1 addition & 0 deletions third_party/tsl_robin_map
Submodule tsl_robin_map added at 1115da

0 comments on commit f724ec0

Please sign in to comment.