Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[GraphBolt] Change hash table for performance. #7631

Merged
merged 1 commit into from
Aug 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitmodules
mfbalin marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,6 @@
[submodule "third_party/taskflow"]
path = third_party/taskflow
url = https://github.com/taskflow/taskflow.git
[submodule "third_party/tsl_robin_map"]
path = third_party/tsl_robin_map
url = https://github.com/Tessil/robin-map.git
2 changes: 1 addition & 1 deletion graphbolt/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ include_directories(BEFORE ${BOLT_DIR}
# `std::atomic_ref`, `std::counting_semaphore`
"../third_party/cccl/libcudacxx/include"
"../third_party/pcg/include"
"../third_party/phmap")
"../third_party/tsl_robin_map/include")
target_link_libraries(${LIB_GRAPHBOLT_NAME} "${TORCH_LIBRARIES}")
if(BUILD_WITH_TASKFLOW)
target_include_directories(${LIB_GRAPHBOLT_NAME} PRIVATE "../third_party/taskflow")
Expand Down
23 changes: 14 additions & 9 deletions graphbolt/src/cache_policy.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,10 @@
#ifndef GRAPHBOLT_CACHE_POLICY_H_
#define GRAPHBOLT_CACHE_POLICY_H_

#include <parallel_hashmap/phmap.h>
#include <torch/custom_class.h>
#include <torch/torch.h>
#include <tsl/robin_map.h>
#include <tsl/robin_set.h>

#include <limits>
#include <mutex>
Expand Down Expand Up @@ -178,9 +179,13 @@ class BaseCachePolicy {

protected:
template <typename K, typename V>
using map_t = phmap::flat_hash_map<K, V>;
using map_t = tsl::robin_map<K, V>;
template <typename K>
using set_t = phmap::flat_hash_set<K>;
using set_t = tsl::robin_set<K>;
template <typename iterator>
static auto& mutable_value_ref(iterator it) {
return it.value();
}
static constexpr int kCapacityFactor = 2;

template <typename CachePolicy>
Expand Down Expand Up @@ -298,7 +303,7 @@ class S3FifoCachePolicy : public BaseCachePolicy {
const auto in_ghost_queue = ghost_set_.erase(key);
auto& queue = in_ghost_queue ? main_queue_ : small_queue_;
auto cache_key_ptr = queue.Push(CacheKey(key));
it->second = cache_key_ptr;
mutable_value_ref(it) = cache_key_ptr;
return &cache_key_ptr->setPos(Evict());
}

Expand All @@ -318,7 +323,7 @@ class S3FifoCachePolicy : public BaseCachePolicy {
auto it = key_to_cache_key_.find(evicted.getKey());
if (evicted.getFreq() > 0 || evicted.InUse()) {
evicted.Decrement();
it->second = main_queue_.Push(evicted);
mutable_value_ref(it) = main_queue_.Push(evicted);
} else {
key_to_cache_key_.erase(it);
return evicted.getPos();
Expand All @@ -332,7 +337,7 @@ class S3FifoCachePolicy : public BaseCachePolicy {
auto evicted = small_queue_.Pop();
auto it = key_to_cache_key_.find(evicted.getKey());
if (evicted.getFreq() > 0 || evicted.InUse()) {
it->second = main_queue_.Push(evicted.ResetFreq());
mutable_value_ref(it) = main_queue_.Push(evicted.ResetFreq());
} else {
key_to_cache_key_.erase(it);
const auto evicted_key = evicted.getKey();
Expand Down Expand Up @@ -449,7 +454,7 @@ class SieveCachePolicy : public BaseCachePolicy {
const auto key = it->first;
queue_.push_front(CacheKey(key));
auto cache_key_ptr = &queue_.front();
it->second = cache_key_ptr;
mutable_value_ref(it) = cache_key_ptr;
return &cache_key_ptr->setPos(Evict());
}

Expand Down Expand Up @@ -589,7 +594,7 @@ class LruCachePolicy : public BaseCachePolicy {
CacheKey* Insert(map_iterator it) {
const auto key = it->first;
queue_.push_front(CacheKey(key));
it->second = queue_.begin();
mutable_value_ref(it) = queue_.begin();
auto cache_key_ptr = &*queue_.begin();
return &cache_key_ptr->setPos(Evict());
}
Expand Down Expand Up @@ -718,7 +723,7 @@ class ClockCachePolicy : public BaseCachePolicy {
CacheKey* Insert(map_iterator it) {
const auto key = it->first;
auto cache_key_ptr = queue_.Push(CacheKey(key));
it->second = cache_key_ptr;
mutable_value_ref(it) = cache_key_ptr;
return &cache_key_ptr->setPos(Evict());
}

Expand Down
1 change: 1 addition & 0 deletions third_party/tsl_robin_map
Submodule tsl_robin_map added at 1115da
Loading