Skip to content

Commit

Permalink
more progress
Browse files Browse the repository at this point in the history
  • Loading branch information
mfbalin committed Jan 28, 2024
1 parent 3547c00 commit 1dce39e
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 7 deletions.
6 changes: 5 additions & 1 deletion graphbolt/src/cuda/index_select_csc_impl.cu
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include <numeric>

#include "./common.h"
#include "./max_uva_threads.h"
#include "./utils.h"

namespace graphbolt {
Expand Down Expand Up @@ -130,7 +131,10 @@ std::tuple<torch::Tensor, torch::Tensor> UVAIndexSelectCSCCopyIndices(
torch::Tensor output_indices =
torch::empty(output_size.value(), options.dtype(indices.scalar_type()));
const dim3 block(BLOCK_SIZE);
const dim3 grid((edge_count_aligned + BLOCK_SIZE - 1) / BLOCK_SIZE);
const dim3 grid(
(std::min(edge_count_aligned, cuda::max_uva_threads.value_or(1 << 20)) +
BLOCK_SIZE - 1) /
BLOCK_SIZE);

// Find the smallest integer type to store the coo_aligned_rows tensor.
const int num_bits = cuda::NumberOfBits(num_nodes);
Expand Down
22 changes: 19 additions & 3 deletions python/dgl/graphbolt/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,23 @@ def __iter__(self):
yield data


class FutureWaiter(dp.iter.IterDataPipe):
"""Calls the result function of all items and returns their results."""

def __init__(self, datapipe):
self.datapipe = datapipe

def __iter__(self):
for data in self.datapipe:
yield data.result()


class FetcherAndSampler(dp.iter.IterDataPipe):
def __init__(self, datapipe, sampler):
datapipe = datapipe.fetch_insubgraph_data(sampler)
def __init__(self, datapipe, sampler, stream):
datapipe = datapipe.fetch_insubgraph_data(sampler, stream)
datapipe = Bufferer(datapipe, 1)
datapipe = FutureWaiter(datapipe)
datapipe = Awaiter(datapipe)
self.datapipe = datapipe.sample_per_layer_from_fetched_subgraph(sampler)

def __iter__(self):
Expand Down Expand Up @@ -260,7 +274,9 @@ def __init__(
datapipe_graph = dp_utils.replace_dp(
datapipe_graph,
sampler,
FetcherAndSampler(sampler.datapipe, sampler),
FetcherAndSampler(
sampler.datapipe, sampler, _get_uva_stream()
),
)

# (4) Cut datapipe at CopyTo and wrap with prefetcher. This enables the
Expand Down
39 changes: 36 additions & 3 deletions python/dgl/graphbolt/impl/neighbor_sampler.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Neighbor subgraph samplers for GraphBolt."""

from concurrent.futures import ThreadPoolExecutor
from functools import partial

import torch
Expand All @@ -25,22 +26,34 @@
class FetchInsubgraphData(Mapper):
""""""

def __init__(self, datapipe, sample_per_layer_obj):
def __init__(self, datapipe, sample_per_layer_obj, stream=None):
super().__init__(datapipe, self._fetch_per_layer)
self.graph = sample_per_layer_obj.sampler.__self__
self.prob_name = sample_per_layer_obj.prob_name
self.stream = stream
self.executor = ThreadPoolExecutor(max_workers=1)

def _fetch_per_layer(self, minibatch):
def _fetch_per_layer_helper(self, minibatch, stream):
index = minibatch.input_nodes
if index.is_cuda:
index.record_stream(torch.cuda.current_stream())
index_select = partial(
torch.ops.graphbolt.index_select_csc, self.graph.csc_indptr
)

def record_stream(tensor):
if stream is not None and tensor.is_cuda:
tensor.record_stream(stream)

indptr, indices = index_select(self.graph.indices, index, None)
record_stream(indptr)
record_stream(indices)
output_size = len(indices)
if self.graph.type_per_edge is not None:
_, type_per_edge = index_select(
self.graph.type_per_edge, index, output_size
)
record_stream(type_per_edge)
else:
type_per_edge = None
if self.graph.edge_attributes is not None:
Expand All @@ -49,6 +62,7 @@ def _fetch_per_layer(self, minibatch):
_, probs_or_mask = index_select(
probs_or_mask, index, output_size
)
record_stream(probs_or_mask)
else:
probs_or_mask = None
subgraph = fused_csc_sampling_graph(
Expand All @@ -59,7 +73,26 @@ def _fetch_per_layer(self, minibatch):
if self.prob_name is not None and probs_or_mask is not None:
subgraph.edge_attributes = {self.prob_name: probs_or_mask}

return subgraph, minibatch
if self.stream is not None:
event = torch.cuda.current_stream().record_event()

class WaitableTuple(tuple):
def wait(self):
event.wait()

return WaitableTuple((subgraph, minibatch))
else:
return subgraph, minibatch

def _fetch_per_layer(self, minibatch):
current_stream = None
if self.stream is not None:
current_stream = torch.cuda.current_stream()
self.stream.wait_stream(current_stream)
with torch.cuda.stream(self.stream):
return self.executor.submit(
self._fetch_per_layer_helper, minibatch, current_stream
)


@functional_datapipe("sample_per_layer_from_fetched_subgraph")
Expand Down

0 comments on commit 1dce39e

Please sign in to comment.