[GraphBolt] Refactor S3-FIFO and add SIEVE, LRU and CLOCK. #7508

Merged · 12 commits · Jul 9, 2024

graphbolt/src/cache_policy.cc · 174 changes: 128 additions & 46 deletions
@@ -22,79 +22,161 @@
namespace graphbolt {
namespace storage {

S3FifoCachePolicy::S3FifoCachePolicy(int64_t capacity)
: small_queue_(capacity / 10),
main_queue_(capacity - capacity / 10),
ghost_queue_time_(0),
capacity_(capacity),
cache_usage_(0) {}

std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>
S3FifoCachePolicy::Query(torch::Tensor keys) {
auto positions = torch::empty_like(keys, keys.options().dtype(torch::kInt64));
auto indices = torch::empty_like(keys, keys.options().dtype(torch::kInt64));
auto missing_keys = torch::empty_like(keys);
template <typename CachePolicy>
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
BaseCachePolicy::QueryImpl(CachePolicy& policy, torch::Tensor keys) {
auto positions = torch::empty_like(
keys,
keys.options().dtype(torch::kInt64).pinned_memory(keys.is_pinned()));
auto indices = torch::empty_like(
keys,
keys.options().dtype(torch::kInt64).pinned_memory(keys.is_pinned()));
auto filtered_keys =
torch::empty_like(keys, keys.options().pinned_memory(keys.is_pinned()));
int64_t found_cnt = 0;
int64_t missing_cnt = keys.size(0);
AT_DISPATCH_INDEX_TYPES(
keys.scalar_type(), "S3FifoCachePolicy::Query::DispatchForKeys", ([&] {
keys.scalar_type(), "BaseCachePolicy::Query::DispatchForKeys", ([&] {
auto keys_ptr = keys.data_ptr<index_t>();
auto positions_ptr = positions.data_ptr<int64_t>();
auto indices_ptr = indices.data_ptr<int64_t>();
auto missing_keys_ptr = missing_keys.data_ptr<index_t>();
auto filtered_keys_ptr = filtered_keys.data_ptr<index_t>();
for (int64_t i = 0; i < keys.size(0); i++) {
const auto key = keys_ptr[i];
auto it = key_to_cache_key_.find(key);
if (it != key_to_cache_key_.end()) {
auto& cache_key = *it->second;
cache_key.Increment();
positions_ptr[found_cnt] = cache_key.getPos();
auto pos = policy.Read(key);
if (pos.has_value()) {
positions_ptr[found_cnt] = *pos;
filtered_keys_ptr[found_cnt] = key;
indices_ptr[found_cnt++] = i;
} else {
indices_ptr[--missing_cnt] = i;
missing_keys_ptr[missing_cnt] = key;
filtered_keys_ptr[missing_cnt] = key;
}
}
}));
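// Returned tuple: (positions of the hits in the cache, a permutation of the
// input indices with hits first and misses last, the missing keys, the found
// keys).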
return {
positions.slice(0, 0, found_cnt), indices,
missing_keys.slice(0, found_cnt)};
filtered_keys.slice(0, found_cnt), filtered_keys.slice(0, 0, found_cnt)};
}

torch::Tensor S3FifoCachePolicy::Replace(torch::Tensor keys) {
auto positions = torch::empty_like(keys, keys.options().dtype(torch::kInt64));
template <typename CachePolicy>
torch::Tensor BaseCachePolicy::ReplaceImpl(
CachePolicy& policy, torch::Tensor keys) {
auto positions = torch::empty_like(
keys,
keys.options().dtype(torch::kInt64).pinned_memory(keys.is_pinned()));
AT_DISPATCH_INDEX_TYPES(
keys.scalar_type(), "S3FifoCachePolicy::Replace", ([&] {
keys.scalar_type(), "BaseCachePolicy::Replace", ([&] {
auto keys_ptr = keys.data_ptr<index_t>();
auto positions_ptr = positions.data_ptr<int64_t>();
phmap::flat_hash_set<int64_t> position_set;
position_set.reserve(keys.size(0));
for (int64_t i = 0; i < keys.size(0); i++) {
const auto key = keys_ptr[i];
auto it = key_to_cache_key_.find(key);
if (it !=
key_to_cache_key_.end()) { // Already in the cache, inc freq.
auto& cache_key = *it->second;
cache_key.Increment();
positions_ptr[i] = cache_key.getPos();
} else {
const auto in_ghost_queue = InGhostQueue(key);
auto& queue = in_ghost_queue ? main_queue_ : small_queue_;
int64_t pos;
if (queue.IsFull()) {
// When the queue is full, we need to make a space by evicting.
// Inside ghost queue means insertion into M, otherwise S.
pos = (in_ghost_queue ? EvictMainQueue() : EvictSmallQueue());
} else { // If the cache is not full yet, get an unused empty slot.
pos = cache_usage_++;
}
TORCH_CHECK(0 <= pos && pos < capacity_, "Position out of bounds!");
key_to_cache_key_[key] = queue.Push(CacheKey(key, pos));
positions_ptr[i] = pos;
}
const auto pos_optional = policy.Read(key);
const auto pos = pos_optional ? *pos_optional : policy.Insert(key);
positions_ptr[i] = pos;
TORCH_CHECK(
std::get<1>(position_set.insert(pos)),
"Can't insert all, larger cache capacity is needed.");
}
}));
TrimGhostQueue();
return positions;
}

template <typename CachePolicy>
void BaseCachePolicy::ReadingCompletedImpl(
CachePolicy& policy, torch::Tensor keys) {
AT_DISPATCH_INDEX_TYPES(
keys.scalar_type(), "BaseCachePolicy::ReadingCompleted", ([&] {
auto keys_ptr = keys.data_ptr<index_t>();
for (int64_t i = 0; i < keys.size(0); i++) {
policy.Unmark(keys_ptr[i]);
}
}));
}
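// Illustrative sketch, not part of this diff: the interface each concrete
// policy is assumed to provide so that QueryImpl / ReplaceImpl /
// ReadingCompletedImpl above can be shared across S3-FIFO, SIEVE, LRU and
// CLOCK. The member semantics are inferred from the calls the templates make
// (an optional-returning Read, an Insert that yields a position, an Unmark),
// and std::optional is assumed to be available.
struct ExamplePolicyInterface {
  // Returns the cache position of `key` on a hit, std::nullopt on a miss.
  // Used by QueryImpl, and by ReplaceImpl before falling back to Insert.
  std::optional<int64_t> Read(int64_t key);
  // Admits `key`, evicting an entry if the policy is full, and returns the
  // position assigned to it. Used by ReplaceImpl on a miss.
  int64_t Insert(int64_t key);
  // Releases whatever read mark Read()/Insert() took on `key` so the entry
  // becomes evictable again. Used by ReadingCompletedImpl.
  void Unmark(int64_t key);
};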

S3FifoCachePolicy::S3FifoCachePolicy(int64_t capacity)
: small_queue_(capacity),
main_queue_(capacity),
ghost_queue_(capacity - capacity / 10),
capacity_(capacity),
cache_usage_(0),
small_queue_size_target_(capacity / 10) {
TORCH_CHECK(small_queue_size_target_ > 0, "Capacity is not large enough.");
ghost_set_.reserve(ghost_queue_.Capacity());
key_to_cache_key_.reserve(capacity);
}
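// Note on the sizing above: the small FIFO is targeted at roughly 10% of the
// total capacity (small_queue_size_target_ = capacity / 10) while the ghost
// queue tracks the remaining ~90% worth of recently evicted keys, which
// mirrors the configuration suggested in the S3-FIFO paper.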

std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
S3FifoCachePolicy::Query(torch::Tensor keys) {
return QueryImpl(*this, keys);
}

torch::Tensor S3FifoCachePolicy::Replace(torch::Tensor keys) {
return ReplaceImpl(*this, keys);
}

void S3FifoCachePolicy::ReadingCompleted(torch::Tensor keys) {
ReadingCompletedImpl(*this, keys);
}
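// Illustrative usage sketch, not part of this diff: how a caller might drive
// the tensor-level API of any of these policies. The feature copies are only
// hinted at in comments; the capacity and key range are arbitrary, and the
// read-mark lifetime ("presumably" below) is inferred from the method names.
inline void ExampleUsage() {
  S3FifoCachePolicy policy(/*capacity=*/1024);
  auto keys = torch::randint(0, 4096, {256}, torch::kInt64);
  // Split the batch into cache hits and misses.
  auto [positions, indices, missing_keys, found_keys] = policy.Query(keys);
  // ... read the hit features from `positions`, fetch `missing_keys` from the
  // backing storage ...
  // Ask the policy where the fetched misses should be written.
  auto new_positions = policy.Replace(missing_keys);
  // ... copy the fetched features into `new_positions` ...
  // Presumably both Query (for hits) and Replace keep the touched entries
  // marked while they are being read or written, so release them afterwards.
  policy.ReadingCompleted(found_keys);
  policy.ReadingCompleted(missing_keys);
}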

SieveCachePolicy::SieveCachePolicy(int64_t capacity)
: hand_(queue_.end()), capacity_(capacity), cache_usage_(0) {
TORCH_CHECK(capacity > 0, "Capacity needs to be positive.");
key_to_cache_key_.reserve(capacity);
}

std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
SieveCachePolicy::Query(torch::Tensor keys) {
return QueryImpl(*this, keys);
}

torch::Tensor SieveCachePolicy::Replace(torch::Tensor keys) {
return ReplaceImpl(*this, keys);
}

void SieveCachePolicy::ReadingCompleted(torch::Tensor keys) {
ReadingCompletedImpl(*this, keys);
}

LruCachePolicy::LruCachePolicy(int64_t capacity)
: capacity_(capacity), cache_usage_(0) {
TORCH_CHECK(capacity > 0, "Capacity needs to be positive.");
key_to_cache_key_.reserve(capacity);
}

std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
LruCachePolicy::Query(torch::Tensor keys) {
return QueryImpl(*this, keys);
}

torch::Tensor LruCachePolicy::Replace(torch::Tensor keys) {
return ReplaceImpl(*this, keys);
}

void LruCachePolicy::ReadingCompleted(torch::Tensor keys) {
ReadingCompletedImpl(*this, keys);
}

ClockCachePolicy::ClockCachePolicy(int64_t capacity)
: queue_(capacity), capacity_(capacity), cache_usage_(0) {
TORCH_CHECK(capacity > 0, "Capacity needs to be positive.");
key_to_cache_key_.reserve(capacity);
}

std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
ClockCachePolicy::Query(torch::Tensor keys) {
return QueryImpl(*this, keys);
}

torch::Tensor ClockCachePolicy::Replace(torch::Tensor keys) {
return ReplaceImpl(*this, keys);
}

void ClockCachePolicy::ReadingCompleted(torch::Tensor keys) {
ReadingCompletedImpl(*this, keys);
}

} // namespace storage
} // namespace graphbolt