Skip to content

Commit

Permalink
[GraphBolt][CUDA] Add FeatureCache::IndexSelect. (#7526)
Browse files Browse the repository at this point in the history
  • Loading branch information
mfbalin committed Jul 16, 2024
1 parent 50a0ae8 commit 0984ad9
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 10 deletions.
17 changes: 13 additions & 4 deletions graphbolt/src/feature_cache.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,18 @@
*/
#include "./feature_cache.h"

#include "./index_select.h"

namespace graphbolt {
namespace storage {

constexpr int kIntGrainSize = 64;

FeatureCache::FeatureCache(
const std::vector<int64_t>& shape, torch::ScalarType dtype)
: tensor_(torch::empty(shape, c10::TensorOptions().dtype(dtype))) {}
const std::vector<int64_t>& shape, torch::ScalarType dtype, bool pin_memory)
: tensor_(torch::empty(
shape, c10::TensorOptions().dtype(dtype).pinned_memory(pin_memory))) {
}

torch::Tensor FeatureCache::Query(
torch::Tensor positions, torch::Tensor indices, int64_t size) {
Expand All @@ -52,6 +56,10 @@ torch::Tensor FeatureCache::Query(
return values;
}

torch::Tensor FeatureCache::IndexSelect(torch::Tensor positions) {
  // Thin wrapper over the shared graphbolt index_select op: gathers the rows
  // of the cache storage tensor at the requested positions.
  auto selected_rows = ops::IndexSelect(tensor_, positions);
  return selected_rows;
}

void FeatureCache::Replace(torch::Tensor positions, torch::Tensor values) {
const auto row_bytes = values.slice(0, 0, 1).numel() * values.element_size();
auto values_ptr = reinterpret_cast<std::byte*>(values.data_ptr());
Expand All @@ -68,8 +76,9 @@ void FeatureCache::Replace(torch::Tensor positions, torch::Tensor values) {
}

c10::intrusive_ptr<FeatureCache> FeatureCache::Create(
const std::vector<int64_t>& shape, torch::ScalarType dtype) {
return c10::make_intrusive<FeatureCache>(shape, dtype);
const std::vector<int64_t>& shape, torch::ScalarType dtype,
bool pin_memory) {
return c10::make_intrusive<FeatureCache>(shape, dtype, pin_memory);
}

} // namespace storage
Expand Down
20 changes: 16 additions & 4 deletions graphbolt/src/feature_cache.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,29 +34,40 @@ struct FeatureCache : public torch::CustomClassHolder {
*
* @param shape The shape of the cache.
* @param dtype The dtype of elements stored in the cache.
* @param pin_memory Whether to pin the memory of the cache storage tensor.
*/
FeatureCache(const std::vector<int64_t>& shape, torch::ScalarType dtype);
FeatureCache(
const std::vector<int64_t>& shape, torch::ScalarType dtype,
bool pin_memory);

/**
* @brief The cache query function. Allocates an empty tensor `values` with
* size as the first dimension and runs
* values[indices[:positions.size(0)]] = cache_tensor[positions] before
* returning it.
*
* @param positions The positions of the queries items.
* @param positions The positions of the queried items.
* @param indices The indices of the queried items among the original keys.
* Only the first portion corresponding to the provided positions tensor is
* used, e.g. indices[:positions.size(0)].
* @param size The size of the original keys, hence the first dimension of
* the output shape.
   * @param pin_memory Whether to pin the memory of the output values tensor.
   * NOTE(review): the Query signature below takes no pin_memory argument —
   * this @param looks stale; presumably the output inherits pinned memory
   * from the cache storage tensor when it was constructed with
   * pin_memory=true. Confirm against the implementation and update.
   *
   * @return The values tensor is returned. Its memory is pinned if pin_memory
   * is true.
*/
torch::Tensor Query(
torch::Tensor positions, torch::Tensor indices, int64_t size);

/**
* @brief The cache tensor index_select returns cache_tensor[positions].
*
* @param positions The positions of the queried items.
*
* @return The values tensor is returned on the same device as positions.
*/
torch::Tensor IndexSelect(torch::Tensor positions);

/**
* @brief The cache replace function.
*
Expand All @@ -66,7 +77,8 @@ struct FeatureCache : public torch::CustomClassHolder {
void Replace(torch::Tensor positions, torch::Tensor values);

static c10::intrusive_ptr<FeatureCache> Create(
const std::vector<int64_t>& shape, torch::ScalarType dtype);
const std::vector<int64_t>& shape, torch::ScalarType dtype,
bool pin_memory);

private:
torch::Tensor tensor_;
Expand Down
1 change: 1 addition & 0 deletions graphbolt/src/python_binding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ TORCH_LIBRARY(graphbolt, m) {
"clock_cache_policy",
&storage::PartitionedCachePolicy::Create<storage::ClockCachePolicy>);
m.class_<storage::FeatureCache>("FeatureCache")
.def("index_select", &storage::FeatureCache::IndexSelect)
.def("query", &storage::FeatureCache::Query)
.def("replace", &storage::FeatureCache::Replace);
m.def("feature_cache", &storage::FeatureCache::Create);
Expand Down
10 changes: 8 additions & 2 deletions python/dgl/graphbolt/impl/feature_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,20 @@ class FeatureCache(object):
policy: str, optional
The cache policy. Default is "sieve". "s3-fifo", "lru" and "clock" are
also available.
pin_memory: bool, optional
Whether the cache storage should be pinned.
"""

def __init__(self, cache_shape, dtype, num_parts=1, policy="sieve"):
def __init__(
    self, cache_shape, dtype, num_parts=1, policy="sieve", pin_memory=False
):
    """Initialize the cache: an eviction policy plus a C++-backed storage
    tensor of shape ``cache_shape`` and dtype ``dtype``.

    Parameters are documented in the class docstring above.
    """
    # Reject unknown policy names up front with the list of valid choices.
    assert (
        policy in caching_policies
    ), f"{list(caching_policies.keys())} are the available caching policies."
    # The policy decides which of the cache_shape[0] row slots to evict,
    # optionally partitioned across num_parts parts.
    self._policy = caching_policies[policy](cache_shape[0], num_parts)
    # C++ FeatureCache storage; pin_memory pins the underlying tensor
    # (presumably to enable faster/async host-to-GPU copies — see the
    # C++ side's pinned_memory option).
    self._cache = torch.ops.graphbolt.feature_cache(
        cache_shape, dtype, pin_memory
    )
    # Running counters used for cache hit-rate statistics.
    self.total_miss = 0
    self.total_queries = 0

Expand Down

0 comments on commit 0984ad9

Please sign in to comment.