[GraphBolt][CUDA] Use async for GPUGraphCache. #7707

Status: Merged (8 commits, Aug 16, 2024)
22 changes: 22 additions & 0 deletions graphbolt/src/cuda/extension/gpu_graph_cache.cu
@@ -25,6 +25,7 @@
#include <cub/cub.cuh>
#include <cuco/static_map.cuh>
#include <cuda/std/atomic>
#include <limits>
#include <numeric>
#include <type_traits>

@@ -168,6 +169,7 @@ std::tuple<torch::Tensor, torch::Tensor, int64_t, int64_t> GpuGraphCache::Query(
seeds.device().index() == device_id_,
"Seeds should be on the correct CUDA device.");
TORCH_CHECK(seeds.sizes().size() == 1, "Keys should be a 1D tensor.");
std::lock_guard lock(mtx_);
auto allocator = cuda::GetAllocator();
auto index_dtype = cached_edge_tensors_.at(0).scalar_type();
const dim3 block(kIntBlockSize);
@@ -237,6 +239,12 @@ std::tuple<torch::Tensor, torch::Tensor, int64_t, int64_t> GpuGraphCache::Query(
}));
}

c10::intrusive_ptr<
Future<std::tuple<torch::Tensor, torch::Tensor, int64_t, int64_t>>>
GpuGraphCache::QueryAsync(torch::Tensor seeds) {
return async([=] { return Query(seeds); });
}

std::tuple<torch::Tensor, std::vector<torch::Tensor>> GpuGraphCache::Replace(
torch::Tensor seeds, torch::Tensor indices, torch::Tensor positions,
int64_t num_hit, int64_t num_threshold, torch::Tensor indptr,
@@ -250,6 +258,7 @@ std::tuple<torch::Tensor, std::vector<torch::Tensor>> GpuGraphCache::Replace(
TORCH_CHECK(
indptr.size(0) == num_nodes - num_hit + 1,
"(indptr.size(0) == seeds.size(0) - num_hit + 1) failed.");
std::lock_guard lock(mtx_);
const int64_t num_buffers = num_nodes * num_tensors;
auto allocator = cuda::GetAllocator();
auto index_dtype = cached_edge_tensors_.at(0).scalar_type();
@@ -490,5 +499,18 @@ std::tuple<torch::Tensor, std::vector<torch::Tensor>> GpuGraphCache::Replace(
}));
}

c10::intrusive_ptr<
Future<std::tuple<torch::Tensor, std::vector<torch::Tensor>>>>
GpuGraphCache::ReplaceAsync(
torch::Tensor seeds, torch::Tensor indices, torch::Tensor positions,
int64_t num_hit, int64_t num_threshold, torch::Tensor indptr,
std::vector<torch::Tensor> edge_tensors) {
return async([=] {
return Replace(
seeds, indices, positions, num_hit, num_threshold, indptr,
edge_tensors);
});
}

} // namespace cuda
} // namespace graphbolt
16 changes: 14 additions & 2 deletions graphbolt/src/cuda/extension/gpu_graph_cache.h
@@ -21,11 +21,11 @@
#ifndef GRAPHBOLT_GPU_GRAPH_CACHE_H_
#define GRAPHBOLT_GPU_GRAPH_CACHE_H_

#include <graphbolt/async.h>
#include <torch/custom_class.h>
#include <torch/torch.h>

#include <limits>
#include <type_traits>
#include <mutex>

namespace graphbolt {
namespace cuda {
@@ -69,6 +69,10 @@ class GpuGraphCache : public torch::CustomClassHolder {
std::tuple<torch::Tensor, torch::Tensor, int64_t, int64_t> Query(
torch::Tensor seeds);

c10::intrusive_ptr<
Future<std::tuple<torch::Tensor, torch::Tensor, int64_t, int64_t>>>
QueryAsync(torch::Tensor seeds);

/**
* @brief After the graph structure for the missing node ids are fetched, it
* inserts the node ids which passes the threshold and returns the final
@@ -96,6 +100,13 @@ class GpuGraphCache : public torch::CustomClassHolder {
int64_t num_hit, int64_t num_threshold, torch::Tensor indptr,
std::vector<torch::Tensor> edge_tensors);

c10::intrusive_ptr<
Future<std::tuple<torch::Tensor, std::vector<torch::Tensor>>>>
ReplaceAsync(
torch::Tensor seeds, torch::Tensor indices, torch::Tensor positions,
int64_t num_hit, int64_t num_threshold, torch::Tensor indptr,
std::vector<torch::Tensor> edge_tensors);

static c10::intrusive_ptr<GpuGraphCache> Create(
const int64_t num_edges, const int64_t threshold,
torch::ScalarType indptr_dtype, std::vector<torch::ScalarType> dtypes);
@@ -111,6 +122,7 @@ class GpuGraphCache : public torch::CustomClassHolder {
torch::Tensor offset_; // The original graph's sliced_indptr tensor.
std::vector<torch::Tensor> cached_edge_tensors_; // The cached graph
// structure edge tensors.
std::mutex mtx_; // Protects the data structure and makes it threadsafe.
};

} // namespace cuda
15 changes: 14 additions & 1 deletion graphbolt/src/python_binding.cc
@@ -58,6 +58,17 @@ TORCH_LIBRARY(graphbolt, m) {
"wait",
&Future<std::vector<
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>>>::Wait);
m.class_<Future<std::tuple<torch::Tensor, torch::Tensor, int64_t, int64_t>>>(
"GpuGraphCacheQueryFuture")
.def(
"wait",
&Future<std::tuple<torch::Tensor, torch::Tensor, int64_t, int64_t>>::
Wait);
m.class_<Future<std::tuple<torch::Tensor, std::vector<torch::Tensor>>>>(
"GpuGraphCacheReplaceFuture")
.def(
"wait",
&Future<std::tuple<torch::Tensor, std::vector<torch::Tensor>>>::Wait);
m.class_<storage::OnDiskNpyArray>("OnDiskNpyArray")
.def("index_select", &storage::OnDiskNpyArray::IndexSelect);
m.class_<FusedCSCSamplingGraph>("FusedCSCSamplingGraph")
@@ -114,7 +125,9 @@ TORCH_LIBRARY(graphbolt, m) {
m.def("gpu_cache", &cuda::GpuCache::Create);
m.class_<cuda::GpuGraphCache>("GpuGraphCache")
.def("query", &cuda::GpuGraphCache::Query)
.def("replace", &cuda::GpuGraphCache::Replace);
.def("query_async", &cuda::GpuGraphCache::QueryAsync)
.def("replace", &cuda::GpuGraphCache::Replace)
.def("replace_async", &cuda::GpuGraphCache::ReplaceAsync);
m.def("gpu_graph_cache", &cuda::GpuGraphCache::Create);
#endif
m.def("fused_csc_sampling_graph", &FusedCSCSamplingGraph::Create);
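The bindings above register the new `query_async`/`replace_async` methods and a `wait`-only interface for each future type. Below is a minimal sketch of how these ops could be exercised directly from Python; the constructor arguments and tensors are placeholders, and in practice the object is created and driven through the GPUGraphCache Python wrapper rather than by hand:

import torch

# Hedged sketch: all arguments below are placeholder values for illustration only.
cache = torch.ops.graphbolt.gpu_graph_cache(
    num_edges, threshold, torch.int64, [torch.int64]
)
query_future = cache.query_async(seeds)  # a GpuGraphCacheQueryFuture
index, position, num_hit, num_threshold = query_future.wait()
replace_future = cache.replace_async(  # a GpuGraphCacheReplaceFuture
    seeds, index, position, num_hit, num_threshold, indptr, edge_tensors
)
csc_indptr, output_edge_tensors = replace_future.wait()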
39 changes: 39 additions & 0 deletions python/dgl/graphbolt/impl/gpu_graph_cache.py
@@ -68,6 +68,45 @@ def replace_functional(missing_indptr, missing_edge_tensors):

return keys[index[num_hit:]], replace_functional

def query_async(self, keys):
"""Queries the GPU cache asynchronously.

Parameters
----------
keys : Tensor
The keys to query the GPU graph cache with.

Returns
-------
A generator object.
The returned generator object returns the missing keys on the second
invocation and expects the fetched indptr and edge tensors on the
next invocation. The third and last invocation returns a future
object and the return result can be accessed by calling `.wait()`
on the returned future object. It is undefined behavior to call
`.wait()` more than once.
"""
future = self._cache.query_async(keys)

yield

index, position, num_hit, num_threshold = future.wait()

self.total_queries += keys.shape[0]
self.total_miss += keys.shape[0] - num_hit

missing_indptr, missing_edge_tensors = yield keys[index[num_hit:]]

yield self._cache.replace_async(
keys,
index,
position,
num_hit,
num_threshold,
missing_indptr,
missing_edge_tensors,
)

@property
def miss_rate(self):
"""Returns the cache miss rate since creation."""
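Taken together, `query_async` implements the three-step generator protocol described in its docstring. A minimal usage sketch follows; `gpu_graph_cache`, `keys`, and `fetch_subgraph` are hypothetical stand-ins for the cache wrapper, the query tensor, and whatever routine fetches the missing graph structure:

handle = gpu_graph_cache.query_async(keys)  # create the generator
next(handle)                        # 1st step: launch the asynchronous cache query
missing_keys = next(handle)         # 2nd step: wait for the query, get the missing keys
indptr, edge_tensors = fetch_subgraph(missing_keys)  # hypothetical fetch of missing structure
future = handle.send((indptr, edge_tensors))         # 3rd step: launch the asynchronous replace
csc_indptr, edge_tensors = future.wait()             # call .wait() exactly once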
49 changes: 32 additions & 17 deletions python/dgl/graphbolt/impl/neighbor_sampler.py
@@ -32,18 +32,25 @@

@functional_datapipe("fetch_cached_insubgraph_data")
class FetchCachedInsubgraphData(Mapper):
"""Queries the GPUGraphCache and returns the missing seeds and a lambda
function that can be called with the fetched graph structure.
"""Queries the GPUGraphCache and returns the missing seeds and a generator
handle that can be called with the fetched graph structure.
"""

def __init__(self, datapipe, gpu_graph_cache):
super().__init__(datapipe, self._fetch_per_layer)
datapipe = datapipe.transform(self._fetch_per_layer).buffer()
super().__init__(datapipe, self._wait_query_future)
self.cache = gpu_graph_cache

def _fetch_per_layer(self, minibatch):
minibatch._seeds, minibatch._replace = self.cache.query(
minibatch._seeds
)
minibatch._async_handle = self.cache.query_async(minibatch._seeds)
# Start first stage
next(minibatch._async_handle)

return minibatch

@staticmethod
def _wait_query_future(minibatch):
minibatch._seeds = next(minibatch._async_handle)

return minibatch

@@ -55,7 +62,8 @@ class CombineCachedAndFetchedInSubgraph(Mapper):
"""

def __init__(self, datapipe, prob_name):
super().__init__(datapipe, self._combine_per_layer)
datapipe = datapipe.transform(self._combine_per_layer).buffer()
super().__init__(datapipe, self._wait_replace_future)
self.prob_name = prob_name

def _combine_per_layer(self, minibatch):
@@ -69,16 +77,24 @@ def _combine_per_layer(self, minibatch):
edge_tensors.append(probs_or_mask)
edge_tensors.append(subgraph.edge_attribute(ORIGINAL_EDGE_ID))

subgraph.csc_indptr, edge_tensors = minibatch._replace(
subgraph.csc_indptr, edge_tensors
minibatch._future = minibatch._async_handle.send(
(subgraph.csc_indptr, edge_tensors)
)
delattr(minibatch, "_replace")
delattr(minibatch, "_async_handle")

return minibatch

def _wait_replace_future(self, minibatch):
subgraph = minibatch._sliced_sampling_graph
subgraph.csc_indptr, edge_tensors = minibatch._future.wait()
delattr(minibatch, "_future")

subgraph.indices = edge_tensors[0]
edge_tensors = edge_tensors[1:]
if subgraph.type_per_edge is not None:
subgraph.type_per_edge = edge_tensors[0]
edge_tensors = edge_tensors[1:]
probs_or_mask = subgraph.edge_attribute(self.prob_name)
if probs_or_mask is not None:
subgraph.add_edge_attribute(self.prob_name, edge_tensors[0])
edge_tensors = edge_tensors[1:]
@@ -113,7 +129,7 @@ def _concat(self, minibatch):


@functional_datapipe("fetch_insubgraph_data")
class FetchInsubgraphData(Mapper):
class FetchInsubgraphData(MiniBatchTransformer):
"""Fetches the insubgraph and wraps it in a FusedCSCSamplingGraph object. If
the provided sample_per_layer_obj has a valid prob_name, then it reads the
probabilies of all the fetched edges. Furthermore, if type_per_array tensor
@@ -131,9 +147,13 @@ def __init__(
datapipe = datapipe.fetch_cached_insubgraph_data(
graph._gpu_graph_cache
)
datapipe = datapipe.transform(self._fetch_per_layer)
datapipe = datapipe.buffer().wait()
if graph._gpu_graph_cache is not None:
datapipe = datapipe.combine_cached_and_fetched_insubgraph(prob_name)
super().__init__(datapipe)
self.graph = graph
self.prob_name = prob_name
super().__init__(datapipe, self._fetch_per_layer)

def _fetch_per_layer(self, minibatch):
stream = torch.cuda.current_stream()
@@ -260,11 +280,6 @@ def __init__(
self.returning_indices_is_optional = True
elif overlap_fetch:
datapipe = datapipe.fetch_insubgraph_data(graph, prob_name)
datapipe = datapipe.buffer().wait()
if graph._gpu_graph_cache is not None:
datapipe = datapipe.combine_cached_and_fetched_insubgraph(
prob_name
)
datapipe = datapipe.transform(
self._sample_per_layer_from_fetched_subgraph
)
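The datapipe changes above split each cache interaction into a launch stage and a wait stage with a `buffer()` in between, so the asynchronous query/replace of one minibatch can overlap with host-side work on later minibatches. A condensed sketch of the assumed wiring; the stage functions are placeholders for the `_fetch_per_layer`/`_wait_query_future` and `_combine_per_layer`/`_wait_replace_future` pairs above:

# launch -> buffer -> wait: the buffer holds minibatches whose async work is
# still in flight, so the wait stage usually finds the result already computed.
datapipe = datapipe.transform(launch_cache_query).buffer()    # starts cache.query_async
datapipe = datapipe.transform(wait_cache_query)               # collects the missing seeds
datapipe = datapipe.transform(fetch_missing_structure)        # placeholder: fetch missing edges
datapipe = datapipe.transform(launch_cache_replace).buffer()  # starts cache.replace_async
datapipe = datapipe.transform(wait_cache_replace)             # combined graph is ready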
4 changes: 4 additions & 0 deletions tests/python/pytorch/graphbolt/test_dataloader.py
@@ -138,6 +138,10 @@ def test_gpu_sampling_DataLoader(
awaiter_cnt += num_layers
if asynchronous:
bufferer_cnt += 2 * num_layers
if overlap_graph_fetch:
bufferer_cnt += 0 * num_layers
if num_gpu_cached_edges > 0:
bufferer_cnt += 2 * num_layers
datapipe = dataloader.dataset
datapipe_graph = traverse_dps(datapipe)
awaiters = find_dps(