Skip to content

Commit

Permalink
fix bug.
Browse files Browse the repository at this point in the history
  • Loading branch information
mfbalin committed Aug 18, 2024
1 parent 7831056 commit 425b7bb
Showing 1 changed file with 12 additions and 10 deletions.
22 changes: 12 additions & 10 deletions python/dgl/graphbolt/impl/neighbor_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,17 +401,19 @@ def record_stream(tensor):
subgraph.original_edge_ids[etype] = record_stream(
index_select(orig_edge_ids, edge_ids)
)
elif subgraph.sampled_csc.indices is None:
subgraph._edge_ids_in_fused_csc_sampling_graph.record_stream(
torch.cuda.current_stream()
)
subgraph.sampled_csc.indices = record_stream(
index_select(
indices, subgraph._edge_ids_in_fused_csc_sampling_graph
else:
if subgraph.sampled_csc.indices is None:
subgraph._edge_ids_in_fused_csc_sampling_graph.record_stream(
torch.cuda.current_stream()
)
)
# homo case does not need subtraction of offsets from indices.
minibatch._indices_needs_offset_subtraction = False
subgraph.sampled_csc.indices = record_stream(
index_select(
indices,
subgraph._edge_ids_in_fused_csc_sampling_graph,
)
)
# homo case does not need subtraction of offsets from indices.
minibatch._indices_needs_offset_subtraction = False
if (
orig_edge_ids is not None
and subgraph.original_edge_ids is None
Expand Down

0 comments on commit 425b7bb

Please sign in to comment.