Mark items currently being read as unevictable.
mfbalin committed Jul 9, 2024
Parent: 1f1332c · Commit: c312794
Showing 6 changed files with 183 additions and 102 deletions.
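
The commit changes the cache-policy Query() methods to additionally return the keys that were found in the cache, and adds a ReadingCompleted() method that unmarks those keys (via policy.Unmark) once the caller has finished reading them, so cache slots stay unevictable only while a read is in flight. Below is a minimal caller-side sketch of how the two calls might be paired; the function ReadThroughCache, the header name, and the surrounding setup are illustrative assumptions, not part of this commit.

// Hypothetical usage sketch (not from this commit): pair the four-tensor
// Query() with ReadingCompleted() so cached slots remain unevictable only
// while they are being read.
#include <torch/torch.h>

#include "cache_policy.h"  // assumed header declaring the cache policies

void ReadThroughCache(
    graphbolt::storage::S3FifoCachePolicy& policy, torch::Tensor keys) {
  // Query() marks every cache hit as "being read" and returns, in order:
  // the cache positions of the hits, a permutation of the input indices
  // (hits first, misses last), the missing keys, and the found keys.
  auto [positions, indices, missing_keys, found_keys] = policy.Query(keys);

  // ... copy the cached slots at `positions` into the output buffer and
  // fetch `missing_keys` from the backing storage (then Replace() them) ...

  // Unmark the hits once their data has been copied out; only after this
  // call may the policy evict those entries again.
  policy.ReadingCompleted(found_keys);
}
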
graphbolt/src/cache_policy.cc: 57 changes (43 additions, 14 deletions)
@@ -23,39 +23,40 @@ namespace graphbolt {
 namespace storage {
 
 template <typename CachePolicy>
-std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>
+std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
 BaseCachePolicy::QueryImpl(CachePolicy& policy, torch::Tensor keys) {
   auto positions = torch::empty_like(
       keys,
       keys.options().dtype(torch::kInt64).pinned_memory(keys.is_pinned()));
   auto indices = torch::empty_like(
       keys,
       keys.options().dtype(torch::kInt64).pinned_memory(keys.is_pinned()));
-  auto missing_keys =
+  auto filtered_keys =
       torch::empty_like(keys, keys.options().pinned_memory(keys.is_pinned()));
   int64_t found_cnt = 0;
   int64_t missing_cnt = keys.size(0);
   AT_DISPATCH_INDEX_TYPES(
-      keys.scalar_type(), "S3FifoCachePolicy::Query::DispatchForKeys", ([&] {
+      keys.scalar_type(), "BaseCachePolicy::Query::DispatchForKeys", ([&] {
         auto keys_ptr = keys.data_ptr<index_t>();
         auto positions_ptr = positions.data_ptr<int64_t>();
         auto indices_ptr = indices.data_ptr<int64_t>();
-        auto missing_keys_ptr = missing_keys.data_ptr<index_t>();
+        auto filtered_keys_ptr = filtered_keys.data_ptr<index_t>();
         for (int64_t i = 0; i < keys.size(0); i++) {
           const auto key = keys_ptr[i];
           auto pos = policy.Read(key);
           if (pos.has_value()) {
             positions_ptr[found_cnt] = *pos;
+            filtered_keys_ptr[found_cnt] = key;
             indices_ptr[found_cnt++] = i;
           } else {
             indices_ptr[--missing_cnt] = i;
-            missing_keys_ptr[missing_cnt] = key;
+            filtered_keys_ptr[missing_cnt] = key;
           }
         }
       }));
   return {
       positions.slice(0, 0, found_cnt), indices,
-      missing_keys.slice(0, found_cnt)};
+      filtered_keys.slice(0, found_cnt), filtered_keys.slice(0, 0, found_cnt)};
 }
 
 template <typename CachePolicy>
@@ -65,7 +66,7 @@ torch::Tensor BaseCachePolicy::ReplaceImpl(
       keys,
       keys.options().dtype(torch::kInt64).pinned_memory(keys.is_pinned()));
   AT_DISPATCH_INDEX_TYPES(
-      keys.scalar_type(), "S3FifoCachePolicy::Replace", ([&] {
+      keys.scalar_type(), "BaseCachePolicy::Replace", ([&] {
         auto keys_ptr = keys.data_ptr<index_t>();
         auto positions_ptr = positions.data_ptr<int64_t>();
         phmap::flat_hash_set<int64_t> position_set;
@@ -83,6 +84,18 @@ torch::Tensor BaseCachePolicy::ReplaceImpl(
   return positions;
 }
 
+template <typename CachePolicy>
+void BaseCachePolicy::ReadingCompletedImpl(
+    CachePolicy& policy, torch::Tensor keys) {
+  AT_DISPATCH_INDEX_TYPES(
+      keys.scalar_type(), "BaseCachePolicy::ReadingCompleted", ([&] {
+        auto keys_ptr = keys.data_ptr<index_t>();
+        for (int64_t i = 0; i < keys.size(0); i++) {
+          policy.Unmark(keys_ptr[i]);
+        }
+      }));
+}
+
 S3FifoCachePolicy::S3FifoCachePolicy(int64_t capacity)
     : small_queue_(capacity),
       main_queue_(capacity),
@@ -95,7 +108,7 @@ S3FifoCachePolicy::S3FifoCachePolicy(int64_t capacity)
   key_to_cache_key_.reserve(capacity);
 }
 
-std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>
+std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
 S3FifoCachePolicy::Query(torch::Tensor keys) {
   return QueryImpl(*this, keys);
 }
@@ -104,50 +117,66 @@ torch::Tensor S3FifoCachePolicy::Replace(torch::Tensor keys) {
   return ReplaceImpl(*this, keys);
 }
 
+void S3FifoCachePolicy::ReadingCompleted(torch::Tensor keys) {
+  ReadingCompletedImpl(*this, keys);
+}
+
 SieveCachePolicy::SieveCachePolicy(int64_t capacity)
     : hand_(queue_.end()), capacity_(capacity), cache_usage_(0) {
   TORCH_CHECK(capacity > 0, "Capacity needs to be positive.");
   key_to_cache_key_.reserve(capacity);
 }
 
-std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> SieveCachePolicy::Query(
-    torch::Tensor keys) {
+std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
+SieveCachePolicy::Query(torch::Tensor keys) {
   return QueryImpl(*this, keys);
 }
 
 torch::Tensor SieveCachePolicy::Replace(torch::Tensor keys) {
   return ReplaceImpl(*this, keys);
 }
 
+void SieveCachePolicy::ReadingCompleted(torch::Tensor keys) {
+  ReadingCompletedImpl(*this, keys);
+}
+
 LruCachePolicy::LruCachePolicy(int64_t capacity)
     : capacity_(capacity), cache_usage_(0) {
   TORCH_CHECK(capacity > 0, "Capacity needs to be positive.");
   key_to_cache_key_.reserve(capacity);
 }
 
-std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> LruCachePolicy::Query(
-    torch::Tensor keys) {
+std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
+LruCachePolicy::Query(torch::Tensor keys) {
   return QueryImpl(*this, keys);
 }
 
 torch::Tensor LruCachePolicy::Replace(torch::Tensor keys) {
   return ReplaceImpl(*this, keys);
 }
 
+void LruCachePolicy::ReadingCompleted(torch::Tensor keys) {
+  ReadingCompletedImpl(*this, keys);
+}
+
 ClockCachePolicy::ClockCachePolicy(int64_t capacity)
     : queue_(capacity), capacity_(capacity), cache_usage_(0) {
   TORCH_CHECK(capacity > 0, "Capacity needs to be positive.");
   key_to_cache_key_.reserve(capacity);
 }
 
-std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> ClockCachePolicy::Query(
-    torch::Tensor keys) {
+std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
+ClockCachePolicy::Query(torch::Tensor keys) {
   return QueryImpl(*this, keys);
 }
 
 torch::Tensor ClockCachePolicy::Replace(torch::Tensor keys) {
   return ReplaceImpl(*this, keys);
 }
 
+void ClockCachePolicy::ReadingCompleted(torch::Tensor keys) {
+  ReadingCompletedImpl(*this, keys);
+}
+
 } // namespace storage
 } // namespace graphbolt