Skip to content

Commit

Permalink
take back cache changes.
Browse files Browse the repository at this point in the history
  • Loading branch information
mfbalin committed Jul 16, 2024
1 parent 80fa1de commit 2e416a1
Show file tree
Hide file tree
Showing 5 changed files with 2 additions and 50 deletions.
10 changes: 0 additions & 10 deletions graphbolt/src/feature_cache.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,6 @@ torch::Tensor FeatureCache::Query(
return values;
}

c10::intrusive_ptr<Future<torch::Tensor>> FeatureCache::QueryAsync(
torch::Tensor positions, torch::Tensor indices, int64_t size) {
return async([=] { return Query(positions, indices, size); });
}

void FeatureCache::Replace(torch::Tensor positions, torch::Tensor values) {
const auto row_bytes = values.slice(0, 0, 1).numel() * values.element_size();
auto values_ptr = reinterpret_cast<std::byte*>(values.data_ptr());
Expand All @@ -72,11 +67,6 @@ void FeatureCache::Replace(torch::Tensor positions, torch::Tensor values) {
});
}

c10::intrusive_ptr<Future<void>> FeatureCache::ReplaceAsync(
torch::Tensor positions, torch::Tensor values) {
return async([=] { return Replace(positions, values); });
}

c10::intrusive_ptr<FeatureCache> FeatureCache::Create(
const std::vector<int64_t>& shape, torch::ScalarType dtype) {
return c10::make_intrusive<FeatureCache>(shape, dtype);
Expand Down
6 changes: 0 additions & 6 deletions graphbolt/src/feature_cache.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,6 @@ struct FeatureCache : public torch::CustomClassHolder {
torch::Tensor Query(
torch::Tensor positions, torch::Tensor indices, int64_t size);

c10::intrusive_ptr<Future<torch::Tensor>> QueryAsync(
torch::Tensor positions, torch::Tensor indices, int64_t size);

/**
* @brief The cache replace function.
*
Expand All @@ -69,9 +66,6 @@ struct FeatureCache : public torch::CustomClassHolder {
*/
void Replace(torch::Tensor positions, torch::Tensor values);

c10::intrusive_ptr<Future<void>> ReplaceAsync(
torch::Tensor positions, torch::Tensor values);

static c10::intrusive_ptr<FeatureCache> Create(
const std::vector<int64_t>& shape, torch::ScalarType dtype);

Expand Down
18 changes: 0 additions & 18 deletions graphbolt/src/partitioned_cache_policy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -187,14 +187,6 @@ PartitionedCachePolicy::Query(torch::Tensor keys) {
return std::make_tuple(positions, output_indices, missing_keys, found_keys);
}

c10::intrusive_ptr<Future<std::vector<torch::Tensor>>>
PartitionedCachePolicy::QueryAsync(torch::Tensor keys) {
return async([=] {
auto [positions, output_indices, missing_keys, found_keys] = Query(keys);
return std::vector{positions, output_indices, missing_keys, found_keys};
});
}

torch::Tensor PartitionedCachePolicy::Replace(torch::Tensor keys) {
if (policies_.size() == 1) return policies_[0]->Replace(keys);
torch::Tensor offsets, indices, permuted_keys;
Expand All @@ -221,11 +213,6 @@ torch::Tensor PartitionedCachePolicy::Replace(torch::Tensor keys) {
return output_positions;
}

c10::intrusive_ptr<Future<torch::Tensor>> PartitionedCachePolicy::ReplaceAsync(
torch::Tensor keys) {
return async([=] { return Replace(keys); });
}

void PartitionedCachePolicy::ReadingCompleted(torch::Tensor keys) {
if (policies_.size() == 1) {
policies_[0]->ReadingCompleted(keys);
Expand All @@ -243,11 +230,6 @@ void PartitionedCachePolicy::ReadingCompleted(torch::Tensor keys) {
});
}

c10::intrusive_ptr<Future<void>> PartitionedCachePolicy::ReadingCompletedAsync(
torch::Tensor keys) {
return async([=] { return ReadingCompleted(keys); });
}

template <typename CachePolicy>
c10::intrusive_ptr<PartitionedCachePolicy> PartitionedCachePolicy::Create(
int64_t capacity, int64_t num_partitions) {
Expand Down
7 changes: 0 additions & 7 deletions graphbolt/src/partitioned_cache_policy.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,6 @@ class PartitionedCachePolicy : public BaseCachePolicy,
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> Query(
torch::Tensor keys);

c10::intrusive_ptr<Future<std::vector<torch::Tensor>>> QueryAsync(
torch::Tensor keys);

/**
* @brief The policy replace function.
* @param keys The keys to query the cache.
Expand All @@ -78,16 +75,12 @@ class PartitionedCachePolicy : public BaseCachePolicy,
*/
torch::Tensor Replace(torch::Tensor keys);

c10::intrusive_ptr<Future<torch::Tensor>> ReplaceAsync(torch::Tensor keys);

/**
* @brief A reader has finished reading these keys, so they can be evicted.
* @param keys The keys to unmark.
*/
void ReadingCompleted(torch::Tensor keys);

c10::intrusive_ptr<Future<void>> ReadingCompletedAsync(torch::Tensor keys);

template <typename CachePolicy>
static c10::intrusive_ptr<PartitionedCachePolicy> Create(
int64_t capacity, int64_t num_partitions);
Expand Down
11 changes: 2 additions & 9 deletions graphbolt/src/python_binding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -105,15 +105,10 @@ TORCH_LIBRARY(graphbolt, m) {
m.def("fused_csc_sampling_graph", &FusedCSCSamplingGraph::Create);
m.class_<storage::PartitionedCachePolicy>("PartitionedCachePolicy")
.def("query", &storage::PartitionedCachePolicy::Query)
.def("query_async", &storage::PartitionedCachePolicy::QueryAsync)
.def("replace", &storage::PartitionedCachePolicy::Replace)
.def("replace_async", &storage::PartitionedCachePolicy::ReplaceAsync)
.def(
"reading_completed",
&storage::PartitionedCachePolicy::ReadingCompleted)
.def(
"reading_completed_async",
&storage::PartitionedCachePolicy::ReadingCompletedAsync);
&storage::PartitionedCachePolicy::ReadingCompleted);
m.def(
"s3_fifo_cache_policy",
&storage::PartitionedCachePolicy::Create<storage::S3FifoCachePolicy>);
Expand All @@ -128,9 +123,7 @@ TORCH_LIBRARY(graphbolt, m) {
&storage::PartitionedCachePolicy::Create<storage::ClockCachePolicy>);
m.class_<storage::FeatureCache>("FeatureCache")
.def("query", &storage::FeatureCache::Query)
.def("query_async", &storage::FeatureCache::QueryAsync)
.def("replace", &storage::FeatureCache::Replace)
.def("replace_async", &storage::FeatureCache::ReplaceAsync);
.def("replace", &storage::FeatureCache::Replace);
m.def("feature_cache", &storage::FeatureCache::Create);
m.def(
"load_from_shared_memory", &FusedCSCSamplingGraph::LoadFromSharedMemory);
Expand Down

0 comments on commit 2e416a1

Please sign in to comment.