[GraphBolt][CUDA] Overlap original edge ids fetch.
mfbalin committed Aug 17, 2024
1 parent 2521081 commit 0f3604b
Showing 2 changed files with 102 additions and 16 deletions.
70 changes: 57 additions & 13 deletions python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
@@ -23,19 +23,32 @@


class _SampleNeighborsWaiter:
def __init__(self, fn, future, seed_offsets):
def __init__(
self, fn, future, seed_offsets, fetching_original_edge_ids_is_optional
):
self.fn = fn
self.future = future
self.seed_offsets = seed_offsets
self.fetching_original_edge_ids_is_optional = (
fetching_original_edge_ids_is_optional
)

def wait(self):
"""Returns the stored value when invoked."""
fn = self.fn
C_sampled_subgraph = self.future.wait()
seed_offsets = self.seed_offsets
fetching_original_edge_ids_is_optional = (
self.fetching_original_edge_ids_is_optional
)
# Ensure there is no memory leak.
self.fn = self.future = self.seed_offsets = None
return fn(C_sampled_subgraph, seed_offsets)
self.fetching_original_edge_ids_is_optional = None
return fn(
C_sampled_subgraph,
seed_offsets,
fetching_original_edge_ids_is_optional,
)


class FusedCSCSamplingGraph(SamplingGraph):
@@ -592,6 +605,7 @@ def _convert_to_sampled_subgraph(
self,
C_sampled_subgraph: torch.ScriptObject,
seed_offsets: Optional[list] = None,
fetching_original_edge_ids_is_optional: bool = False,
) -> SampledSubgraphImpl:
"""An internal function used to convert a fused homogeneous sampled
subgraph to general struct 'SampledSubgraphImpl'."""
@@ -611,18 +625,24 @@
and ORIGINAL_EDGE_ID in self.edge_attributes
)
original_edge_ids = (
torch.ops.graphbolt.index_select(
self.edge_attributes[ORIGINAL_EDGE_ID],
edge_ids_in_fused_csc_sampling_graph,
(
torch.ops.graphbolt.index_select(
self.edge_attributes[ORIGINAL_EDGE_ID],
edge_ids_in_fused_csc_sampling_graph,
)
if not fetching_original_edge_ids_is_optional
or not edge_ids_in_fused_csc_sampling_graph.is_cuda
or not self.edge_attributes[ORIGINAL_EDGE_ID].is_pinned()
else None
)
if has_original_eids
else edge_ids_in_fused_csc_sampling_graph
)
if type_per_edge is None and etype_offsets is None:
# The sampled graph is already a homogeneous graph.
sampled_csc = CSCFormatBase(indptr=indptr, indices=indices)
if indices is not None:
# Only needed to fetch indices.
if indices is not None and original_edge_ids is not None:
# Only needed to fetch indices or original_edge_ids.
edge_ids_in_fused_csc_sampling_graph = None
else:
offset = self._node_type_offset_list
@@ -691,10 +711,16 @@
]
]
)
original_hetero_edge_ids[etype] = original_edge_ids[
etype_offsets[etype_id] : etype_offsets[etype_id + 1]
]
if indices is None:
original_hetero_edge_ids[etype] = (
None
if original_edge_ids is None
else original_edge_ids[
etype_offsets[etype_id] : etype_offsets[
etype_id + 1
]
]
)
if indices is None or original_edge_ids is None:
# Only needed to fetch indices.
sampled_hetero_edge_ids_in_fused_csc_sampling_graph[
etype
@@ -728,6 +754,7 @@ def sample_neighbors(
replace: bool = False,
probs_name: Optional[str] = None,
returning_indices_is_optional: bool = False,
fetching_original_edge_ids_is_optional: bool = False,
async_op: bool = False,
) -> SampledSubgraphImpl:
"""Sample neighboring edges of the given nodes and return the induced
@@ -772,6 +799,11 @@
Boolean indicating whether it is okay for the call to this function
to leave the indices tensor uninitialized. In this case, it is the
user's responsibility to gather it using the edge ids.
fetching_original_edge_ids_is_optional: bool
Boolean indicating whether it is okay for the call to this function
to leave the original edge ids tensor uninitialized. In this case,
it is the user's responsibility to gather it using
_edge_ids_in_fused_csc_sampling_graph.
async_op: bool
Boolean indicating whether the call is asynchronous. If so, the
result can be obtained by calling wait on the returned future.
@@ -826,10 +858,13 @@
self._convert_to_sampled_subgraph,
C_sampled_subgraph,
seed_offsets,
fetching_original_edge_ids_is_optional,
)
else:
return self._convert_to_sampled_subgraph(
C_sampled_subgraph, seed_offsets
C_sampled_subgraph,
seed_offsets,
fetching_original_edge_ids_is_optional,
)

def _check_sampler_arguments(self, nodes, fanouts, probs_or_mask):
@@ -957,6 +992,7 @@ def sample_layer_neighbors(
replace: bool = False,
probs_name: Optional[str] = None,
returning_indices_is_optional: bool = False,
fetching_original_edge_ids_is_optional: bool = False,
random_seed: torch.Tensor = None,
seed2_contribution: float = 0.0,
async_op: bool = False,
@@ -1005,6 +1041,11 @@
Boolean indicating whether it is okay for the call to this function
to leave the indices tensor uninitialized. In this case, it is the
user's responsibility to gather it using the edge ids.
fetching_original_edge_ids_is_optional: bool
Boolean indicating whether it is okay for the call to this function
to leave the original edge ids tensor uninitialized. In this case,
it is the user's responsibility to gather it using
_edge_ids_in_fused_csc_sampling_graph.
random_seed: torch.Tensor, optional
An int64 tensor with one or two elements.
@@ -1102,10 +1143,13 @@
self._convert_to_sampled_subgraph,
C_sampled_subgraph,
seed_offsets,
fetching_original_edge_ids_is_optional,
)
else:
return self._convert_to_sampled_subgraph(
C_sampled_subgraph, seed_offsets
C_sampled_subgraph,
seed_offsets,
fetching_original_edge_ids_is_optional,
)

def temporal_sample_neighbors(
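The new `fetching_original_edge_ids_is_optional` flag lets `_convert_to_sampled_subgraph` leave `original_edge_ids` as `None` when the flag is set, the sampled edge IDs already live on the GPU, and the `ORIGINAL_EDGE_ID` attribute is pinned; the subgraph then keeps `_edge_ids_in_fused_csc_sampling_graph` so the caller can run the gather itself. A minimal caller-side sketch of that contract follows; the `graph` and `subgraph` variables, the helper name, and the import path for `ORIGINAL_EDGE_ID` are illustrative assumptions, not part of this commit.

```python
import torch
from dgl.graphbolt import ORIGINAL_EDGE_ID  # assumed import path for the attribute key


def backfill_original_edge_ids(graph, subgraph):
    """Gather the original edge ids that sample_neighbors was allowed to skip."""
    if subgraph.original_edge_ids is not None:
        return subgraph  # The gather already happened inside sample_neighbors.
    # The flag only skips the gather; the subgraph still carries the edge ids
    # local to the fused graph, so the caller performs the pinned-memory
    # index_select itself (e.g. on a separate CUDA stream, as the
    # NeighborSampler change in the next file does).
    subgraph.original_edge_ids = torch.ops.graphbolt.index_select(
        graph.edge_attributes[ORIGINAL_EDGE_ID],
        subgraph._edge_ids_in_fused_csc_sampling_graph,
    )
    return subgraph
```

For heterogeneous graphs both fields are per-edge-type dictionaries, so the same gather runs once per edge type, which is how `_fetch_indices_and_original_edge_ids` handles it in the next file.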
48 changes: 45 additions & 3 deletions python/dgl/graphbolt/impl/neighbor_sampler.py
@@ -252,6 +252,21 @@ def __init__(
):
graph = sampler.__self__
self.returning_indices_is_optional = False
original_edge_ids = (
None
if graph.edge_attributes is None
else graph.edge_attributes.get(ORIGINAL_EDGE_ID, None)
)
self.fetching_original_edge_ids_is_optional = (
overlap_fetch
and original_edge_ids is not None
and original_edge_ids.is_pinned()
)
fetch_indices_and_original_edge_ids_fn = partial(
self._fetch_indices_and_original_edge_ids,
graph.indices,
original_edge_ids,
)
if (
overlap_fetch
and sampler.__name__ == "sample_neighbors"
@@ -263,7 +278,7 @@
datapipe = datapipe.buffer()
datapipe = datapipe.transform(self._wait_subgraph_future)
datapipe = (
datapipe.transform(partial(self._fetch_indices, graph.indices))
datapipe.transform(fetch_indices_and_original_edge_ids_fn)
.buffer()
.wait()
)
@@ -285,6 +300,12 @@
if asynchronous:
datapipe = datapipe.buffer()
datapipe = datapipe.transform(self._wait_subgraph_future)
if self.fetching_original_edge_ids_is_optional:
datapipe = (
datapipe.transform(fetch_indices_and_original_edge_ids_fn)
.buffer()
.wait()
)
else:
datapipe = datapipe.transform(self._sample_per_layer)
if asynchronous:
@@ -310,6 +331,7 @@ def _sample_per_layer(self, minibatch):
self.replace,
self.prob_name,
self.returning_indices_is_optional,
self.fetching_original_edge_ids_is_optional,
async_op=self.asynchronous,
**kwargs,
)
@@ -329,6 +351,8 @@ def _sample_per_layer_from_fetched_subgraph(self, minibatch):
self.fanout,
self.replace,
self.prob_name,
False,
self.fetching_original_edge_ids_is_optional,
async_op=self.asynchronous,
**kwargs,
)
@@ -341,7 +365,7 @@ def _wait_subgraph_future(minibatch):
return minibatch

@staticmethod
def _fetch_indices(indices, minibatch):
def _fetch_indices_and_original_edge_ids(indices, orig_edge_ids, minibatch):
stream = torch.cuda.current_stream()
host_to_device_stream = get_host_to_device_uva_stream()
host_to_device_stream.wait_stream(stream)
@@ -366,6 +390,13 @@ def record_stream(tensor):
index_select(indices, edge_ids)
)
minibatch._indices_needs_offset_subtraction = True
if (
orig_edge_ids is not None
and subgraph.original_edge_ids[etype] is None
):
subgraph.original_edge_ids[etype] = record_stream(
index_select(orig_edge_ids, edge_ids)
)
elif subgraph.sampled_csc.indices is None:
subgraph._edge_ids_in_fused_csc_sampling_graph.record_stream(
torch.cuda.current_stream()
@@ -375,7 +406,18 @@ def record_stream(tensor):
indices, subgraph._edge_ids_in_fused_csc_sampling_graph
)
)
minibatch._indices_needs_offset_subtraction = True
# The homogeneous case does not need offsets subtracted from indices.
minibatch._indices_needs_offset_subtraction = False
if (
orig_edge_ids is not None
and subgraph.original_edge_ids is None
):
subgraph.original_edge_ids = record_stream(
index_select(
orig_edge_ids,
subgraph._edge_ids_in_fused_csc_sampling_graph,
)
)
subgraph._edge_ids_in_fused_csc_sampling_graph = None
minibatch.wait = torch.cuda.current_stream().record_event().wait

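The overlap itself comes from running the pinned-memory gather on a dedicated host-to-device UVA stream while sampling continues on the main stream, which is what `_fetch_indices_and_original_edge_ids` does above. Below is a condensed sketch of that stream pattern under stated assumptions: a pinned source tensor, a CUDA index tensor, and a caller-supplied side stream (GraphBolt obtains its stream from `get_host_to_device_uva_stream()`); the helper name is illustrative and not part of the commit.

```python
import torch
import dgl.graphbolt  # noqa: F401  registers the torch.ops.graphbolt operators


def gather_on_side_stream(pinned_source, cuda_index, side_stream):
    """Overlap a pinned-memory gather with work queued on the current stream."""
    main_stream = torch.cuda.current_stream()
    # Do not start reading the indices before the kernels that produce them
    # have been enqueued on the main stream.
    side_stream.wait_stream(main_stream)
    with torch.cuda.stream(side_stream):
        # Tell the caching allocator the index tensor is also used on this stream.
        cuda_index.record_stream(side_stream)
        # UVA gather from pinned host memory; it overlaps with whatever the
        # main stream is doing (e.g. sampling the next layer).
        result = torch.ops.graphbolt.index_select(pinned_source, cuda_index)
        # The result will be consumed on the main stream later.
        result.record_stream(main_stream)
        done = torch.cuda.current_stream().record_event()
    return result, done
```

The consumer calls `done.wait()` before touching `result` on the main stream, mirroring how the datapipe above stores `torch.cuda.current_stream().record_event().wait` as `minibatch.wait`.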
