diff --git a/graphbolt/src/cuda/index_select_csc_impl.cu b/graphbolt/src/cuda/index_select_csc_impl.cu
index d1a6a89af18f..700babe1e53f 100644
--- a/graphbolt/src/cuda/index_select_csc_impl.cu
+++ b/graphbolt/src/cuda/index_select_csc_impl.cu
@@ -14,6 +14,7 @@
 #include
 
 #include "./common.h"
+#include "./max_uva_threads.h"
 #include "./utils.h"
 
 namespace graphbolt {
@@ -130,7 +131,10 @@ std::tuple UVAIndexSelectCSCCopyIndices(
   torch::Tensor output_indices =
       torch::empty(output_size.value(), options.dtype(indices.scalar_type()));
   const dim3 block(BLOCK_SIZE);
-  const dim3 grid((edge_count_aligned + BLOCK_SIZE - 1) / BLOCK_SIZE);
+  const dim3 grid(
+      (std::min(edge_count_aligned, cuda::max_uva_threads.value_or(1 << 20)) +
+       BLOCK_SIZE - 1) /
+      BLOCK_SIZE);
 
   // Find the smallest integer type to store the coo_aligned_rows tensor.
   const int num_bits = cuda::NumberOfBits(num_nodes);
diff --git a/python/dgl/graphbolt/dataloader.py b/python/dgl/graphbolt/dataloader.py
index 8e1f308c4ac4..95f1a11ff414 100644
--- a/python/dgl/graphbolt/dataloader.py
+++ b/python/dgl/graphbolt/dataloader.py
@@ -95,9 +95,23 @@ def __iter__(self):
             yield data
 
 
+class FutureWaiter(dp.iter.IterDataPipe):
+    """Calls the result function of all items and returns their results."""
+
+    def __init__(self, datapipe):
+        self.datapipe = datapipe
+
+    def __iter__(self):
+        for data in self.datapipe:
+            yield data.result()
+
+
 class FetcherAndSampler(dp.iter.IterDataPipe):
-    def __init__(self, datapipe, sampler):
-        datapipe = datapipe.fetch_insubgraph_data(sampler)
+    def __init__(self, datapipe, sampler, stream):
+        datapipe = datapipe.fetch_insubgraph_data(sampler, stream)
+        datapipe = Bufferer(datapipe, 1)
+        datapipe = FutureWaiter(datapipe)
+        datapipe = Awaiter(datapipe)
         self.datapipe = datapipe.sample_per_layer_from_fetched_subgraph(sampler)
 
     def __iter__(self):
@@ -260,7 +274,9 @@ def __init__(
                 datapipe_graph = dp_utils.replace_dp(
                     datapipe_graph,
                     sampler,
-                    FetcherAndSampler(sampler.datapipe, sampler),
+                    FetcherAndSampler(
+                        sampler.datapipe, sampler, _get_uva_stream()
+                    ),
                 )
 
         # (4) Cut datapipe at CopyTo and wrap with prefetcher. This enables the
diff --git a/python/dgl/graphbolt/impl/neighbor_sampler.py b/python/dgl/graphbolt/impl/neighbor_sampler.py
index 3eed3723cf43..54ac12a59ce0 100644
--- a/python/dgl/graphbolt/impl/neighbor_sampler.py
+++ b/python/dgl/graphbolt/impl/neighbor_sampler.py
@@ -1,5 +1,6 @@
 """Neighbor subgraph samplers for GraphBolt."""
 
+from concurrent.futures import ThreadPoolExecutor
 from functools import partial
 
 import torch
@@ -25,22 +26,34 @@
 class FetchInsubgraphData(Mapper):
     """"""
 
-    def __init__(self, datapipe, sample_per_layer_obj):
+    def __init__(self, datapipe, sample_per_layer_obj, stream=None):
         super().__init__(datapipe, self._fetch_per_layer)
         self.graph = sample_per_layer_obj.sampler.__self__
         self.prob_name = sample_per_layer_obj.prob_name
+        self.stream = stream
+        self.executor = ThreadPoolExecutor(max_workers=1)
 
-    def _fetch_per_layer(self, minibatch):
+    def _fetch_per_layer_helper(self, minibatch, stream):
         index = minibatch.input_nodes
+        if index.is_cuda:
+            index.record_stream(torch.cuda.current_stream())
         index_select = partial(
             torch.ops.graphbolt.index_select_csc, self.graph.csc_indptr
         )
+
+        def record_stream(tensor):
+            if stream is not None and tensor.is_cuda:
+                tensor.record_stream(stream)
+
         indptr, indices = index_select(self.graph.indices, index, None)
+        record_stream(indptr)
+        record_stream(indices)
         output_size = len(indices)
         if self.graph.type_per_edge is not None:
             _, type_per_edge = index_select(
                 self.graph.type_per_edge, index, output_size
             )
+            record_stream(type_per_edge)
         else:
             type_per_edge = None
         if self.graph.edge_attributes is not None:
@@ -49,6 +62,7 @@ def _fetch_per_layer(self, minibatch):
                 _, probs_or_mask = index_select(
                     probs_or_mask, index, output_size
                 )
+                record_stream(probs_or_mask)
         else:
             probs_or_mask = None
         subgraph = fused_csc_sampling_graph(
@@ -59,7 +73,36 @@ def _fetch_per_layer(self, minibatch):
         if self.prob_name is not None and probs_or_mask is not None:
             subgraph.edge_attributes = {self.prob_name: probs_or_mask}
 
-        return subgraph, minibatch
+        if self.stream is not None:
+            event = torch.cuda.current_stream().record_event()
+
+            class WaitableTuple(tuple):
+                def wait(self):
+                    event.wait()
+
+            return WaitableTuple((subgraph, minibatch))
+        else:
+            return subgraph, minibatch
+
+    def _fetch_per_layer(self, minibatch):
+        """Submit the fetch for ``minibatch`` to the worker thread.
+
+        Returns the ``Future`` produced by ``self.executor.submit``.
+        """
+        current_stream = None
+        if self.stream is not None:
+            current_stream = torch.cuda.current_stream()
+            self.stream.wait_stream(current_stream)
+
+        def fetch_on_stream():
+            # PyTorch tracks the current CUDA stream per thread, so the
+            # stream context must be entered inside the worker thread;
+            # wrapping only the `submit` call leaves the worker on its own
+            # current stream.
+            with torch.cuda.stream(self.stream):
+                return self._fetch_per_layer_helper(minibatch, current_stream)
+
+        return self.executor.submit(fetch_on_stream)
 
 
 @functional_datapipe("sample_per_layer_from_fetched_subgraph")