Skip to content

Commit

Permalink
change to index_select_csc_with_indptr.
Browse files Browse the repository at this point in the history
  • Loading branch information
mfbalin committed Feb 4, 2024
1 parent 23ade71 commit 6b21742
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions python/dgl/graphbolt/impl/neighbor_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,20 +56,22 @@ def _fetch_per_layer_impl(self, minibatch, stream):
else:
minibatch._subgraph_seed_nodes = original_positions
index.record_stream(torch.cuda.current_stream())
index_select = partial(
index_select_csc_with_indptr = partial(
torch.ops.graphbolt.index_select_csc, self.graph.csc_indptr
)

def record_stream(tensor):
if stream is not None and tensor.is_cuda:
tensor.record_stream(stream)

indptr, indices = index_select(self.graph.indices, index, None)
indptr, indices = index_select_csc_with_indptr(
self.graph.indices, index, None
)
record_stream(indptr)
record_stream(indices)
output_size = len(indices)
if self.graph.type_per_edge is not None:
_, type_per_edge = index_select(
_, type_per_edge = index_select_csc_with_indptr(
self.graph.type_per_edge, index, output_size
)
record_stream(type_per_edge)
Expand All @@ -80,7 +82,7 @@ def record_stream(tensor):
self.prob_name, None
)
if probs_or_mask is not None:
_, probs_or_mask = index_select(
_, probs_or_mask = index_select_csc_with_indptr(
probs_or_mask, index, output_size
)
record_stream(probs_or_mask)
Expand Down

0 comments on commit 6b21742

Please sign in to comment.