
Commit

get the implementation working
mfbalin committed Jan 28, 2024
1 parent 1dce39e commit 0ad178a
Showing 2 changed files with 52 additions and 50 deletions.
4 changes: 2 additions & 2 deletions python/dgl/graphbolt/dataloader.py
@@ -184,7 +184,7 @@ class DataLoader(torch.utils.data.DataLoader):
     of the computations can run simultaneously with it. Setting it to a too
     high value will limit the amount of overlap while setting it too low may
     cause the PCI-e bandwidth to not get fully utilized. Manually tuned
-    default is 6144, meaning around 3-4 Streaming Multiprocessors.
+    default is 12288, meaning around 6-8 Streaming Multiprocessors.
     """
 
     def __init__(
@@ -194,7 +194,7 @@ def __init__(
         persistent_workers=True,
         overlap_feature_fetch=True,
         overlap_graph_fetch=True,
-        max_uva_threads=6144,
+        max_uva_threads=12288,
     ):
         # Multiprocessing requires two modifications to the datapipe:
         #
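The doubled default of 12288 UVA threads matches the updated docstring: assuming the usual 2048 resident threads per SM it occupies about 6 SMs, or 8 SMs on parts with a 1536-thread limit. A minimal usage sketch showing where the knob is set (hypothetical: `datapipe` stands in for a GraphBolt sampling pipeline that is not part of this commit; only keyword arguments visible in this diff are used):

import dgl.graphbolt as gb

# `datapipe` is assumed to be a GraphBolt sampling pipeline ending in
# feature fetching and copy-to-device; it is not defined in this commit.
dataloader = gb.DataLoader(
    datapipe,
    overlap_graph_fetch=True,  # fetch graph structure on a side CUDA stream
    max_uva_threads=12288,     # cap GPU threads serving UVA reads over PCI-e
)
for minibatch in dataloader:
    ...  # forward/backward as usual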
98 changes: 50 additions & 48 deletions python/dgl/graphbolt/impl/neighbor_sampler.py
@@ -34,65 +34,67 @@ def __init__(self, datapipe, sample_per_layer_obj, stream=None):
         self.executor = ThreadPoolExecutor(max_workers=1)
 
     def _fetch_per_layer_helper(self, minibatch, stream):
-        index = minibatch.input_nodes
-        if index.is_cuda:
-            index.record_stream(torch.cuda.current_stream())
-        index_select = partial(
-            torch.ops.graphbolt.index_select_csc, self.graph.csc_indptr
-        )
+        with torch.cuda.stream(self.stream):
+            index = minibatch.input_nodes
+            if index.is_cuda:
+                index.record_stream(torch.cuda.current_stream())
+            index_select = partial(
+                torch.ops.graphbolt.index_select_csc, self.graph.csc_indptr
+            )
 
-        def record_stream(tensor):
-            if stream is not None and tensor.is_cuda:
-                tensor.record_stream(stream)
+            def record_stream(tensor):
+                if stream is not None and tensor.is_cuda:
+                    tensor.record_stream(stream)
 
-        indptr, indices = index_select(self.graph.indices, index, None)
-        record_stream(indptr)
-        record_stream(indices)
-        output_size = len(indices)
-        if self.graph.type_per_edge is not None:
-            _, type_per_edge = index_select(
-                self.graph.type_per_edge, index, output_size
-            )
-            record_stream(type_per_edge)
-        else:
-            type_per_edge = None
-        if self.graph.edge_attributes is not None:
-            probs_or_mask = self.graph.edge_attributes.get(self.prob_name, None)
-            if probs_or_mask is not None:
-                _, probs_or_mask = index_select(
-                    probs_or_mask, index, output_size
-                )
-                record_stream(probs_or_mask)
-            else:
-                probs_or_mask = None
-        subgraph = fused_csc_sampling_graph(
-            indptr,
-            indices,
-            type_per_edge=type_per_edge,
-        )
-        if self.prob_name is not None and probs_or_mask is not None:
-            subgraph.edge_attributes = {self.prob_name: probs_or_mask}
+            indptr, indices = index_select(self.graph.indices, index, None)
+            record_stream(indptr)
+            record_stream(indices)
+            output_size = len(indices)
+            if self.graph.type_per_edge is not None:
+                _, type_per_edge = index_select(
+                    self.graph.type_per_edge, index, output_size
+                )
+                record_stream(type_per_edge)
+            else:
+                type_per_edge = None
+            if self.graph.edge_attributes is not None:
+                probs_or_mask = self.graph.edge_attributes.get(
+                    self.prob_name, None
+                )
+                if probs_or_mask is not None:
+                    _, probs_or_mask = index_select(
+                        probs_or_mask, index, output_size
+                    )
+                    record_stream(probs_or_mask)
+                else:
+                    probs_or_mask = None
+            subgraph = fused_csc_sampling_graph(
+                indptr,
+                indices,
+                type_per_edge=type_per_edge,
+            )
+            if self.prob_name is not None and probs_or_mask is not None:
+                subgraph.edge_attributes = {self.prob_name: probs_or_mask}
 
-        if self.stream is not None:
-            event = torch.cuda.current_stream().record_event()
+            if self.stream is not None:
+                event = torch.cuda.current_stream().record_event()
 
-            class WaitableTuple(tuple):
-                def wait(self):
-                    event.wait()
+                class WaitableTuple(tuple):
+                    def wait(self):
+                        event.wait()
 
-            return WaitableTuple((subgraph, minibatch))
-        else:
-            return subgraph, minibatch
+                return WaitableTuple((subgraph, minibatch))
+            else:
+                return subgraph, minibatch
 
     def _fetch_per_layer(self, minibatch):
         current_stream = None
         if self.stream is not None:
             current_stream = torch.cuda.current_stream()
             self.stream.wait_stream(current_stream)
-        with torch.cuda.stream(self.stream):
-            return self.executor.submit(
-                self._fetch_per_layer_helper, minibatch, current_stream
-            )
+        return self.executor.submit(
+            self._fetch_per_layer_helper, minibatch, current_stream
+        )


@functional_datapipe("sample_per_layer_from_fetched_subgraph")
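The substantive fix here is where the CUDA stream context is entered. PyTorch tracks the current stream per thread, so the old code's `with torch.cuda.stream(self.stream):` wrapped around `executor.submit(...)` only switched streams on the submitting thread; `_fetch_per_layer_helper` still ran on the executor's worker thread with the default stream current. Moving the context into the helper makes the worker thread itself launch the sampling kernels on `self.stream`. An illustrative standalone sketch of the pitfall and the fix (not from the commit; requires a CUDA device):

from concurrent.futures import ThreadPoolExecutor

import torch

executor = ThreadPoolExecutor(max_workers=1)
side_stream = torch.cuda.Stream()

def worker_current_stream():
    # Runs on the pool's worker thread and reports that thread's stream.
    return torch.cuda.current_stream()

# Pitfall: the context manager only affects the submitting thread.
with torch.cuda.stream(side_stream):
    fut = executor.submit(worker_current_stream)
assert fut.result() != side_stream  # worker still used its default stream

def fixed_worker():
    # Fix (as in this commit): enter the context on the worker thread itself.
    with torch.cuda.stream(side_stream):
        return torch.cuda.current_stream()

assert executor.submit(fixed_worker).result() == side_stream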

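The `record_event`/`WaitableTuple` tail is cross-stream synchronization: once the sampling kernels are queued on the side stream, an event is recorded there, and the consumer's `wait()` orders its own stream after that event on the GPU timeline without blocking the host. The `record_stream` calls complement this by telling the caching allocator that tensors are in use on another stream, so their memory is not reused prematurely. A condensed sketch of the same producer/consumer pattern (illustrative names; requires a CUDA device):

import torch

side_stream = torch.cuda.Stream()
x = torch.randn(1024, 1024, device="cuda")

# Producer: order the side stream after x's initialization, queue the
# work there, and record an event marking its completion.
side_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(side_stream):
    x.record_stream(side_stream)  # x is also consumed on this stream
    y = x @ x
    done = torch.cuda.current_stream().record_event()

# Consumer: make the current stream (not the host) wait on the event and
# mark y as used here so the allocator keeps its memory alive.
done.wait()
y.record_stream(torch.cuda.current_stream())
z = y.sum()  # safely ordered after the matmul on the GPU timeline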