Skip to content

Commit

Permalink
simplify.
Browse files Browse the repository at this point in the history
  • Loading branch information
mfbalin committed Aug 18, 2024
1 parent 18f9aee commit 0a52ad2
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 41 deletions.
44 changes: 18 additions & 26 deletions python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -753,8 +753,7 @@ def sample_neighbors(
fanouts: torch.Tensor,
replace: bool = False,
probs_name: Optional[str] = None,
returning_indices_is_optional: bool = False,
fetching_original_edge_ids_is_optional: bool = False,
returning_indices_and_original_edge_ids_are_optional: bool = False,
async_op: bool = False,
) -> SampledSubgraphImpl:
"""Sample neighboring edges of the given nodes and return the induced
Expand Down Expand Up @@ -795,15 +794,12 @@ def sample_neighbors(
corresponding to each neighboring edge of a node. It must be a 1D
floating-point or boolean tensor, with the number of elements
equalling the total number of edges.
returning_indices_is_optional: bool
returning_indices_and_original_edge_ids_are_optional: bool
Boolean indicating whether it is okay for the call to this function
to leave the indices tensor uninitialized. In this case, it is the
user's responsibility to gather it using the edge ids.
fetching_original_edge_ids_is_optional: bool
Boolean indicating whether it is okay for the call to this function
to leave the original edge ids tensor uninitialized. In this case,
it is the user's responsibility to gather it using
_edge_ids_in_fused_csc_sampling_graph.
to leave the indices and the original edge ids tensors

Check warning on line 799 in python/dgl/graphbolt/impl/fused_csc_sampling_graph.py

View workflow job for this annotation

GitHub Actions / lintrunner

UFMT format

Run `lintrunner -a` to apply this patch.
uninitialized. In this case, it is the user's responsibility to
gather whichever of them is missing, using
_edge_ids_in_fused_csc_sampling_graph.
async_op: bool
Boolean indicating whether the call is asynchronous. If so, the
result can be obtained by calling wait on the returned future.
Expand Down Expand Up @@ -850,21 +846,21 @@ def sample_neighbors(
fanouts,
replace=replace,
probs_or_mask=probs_or_mask,
returning_indices_is_optional=returning_indices_is_optional,
returning_indices_is_optional=returning_indices_and_original_edge_ids_are_optional,
async_op=async_op,
)
if async_op:
return _SampleNeighborsWaiter(
self._convert_to_sampled_subgraph,
C_sampled_subgraph,
seed_offsets,
fetching_original_edge_ids_is_optional,
returning_indices_and_original_edge_ids_are_optional,
)
else:
return self._convert_to_sampled_subgraph(
C_sampled_subgraph,
seed_offsets,
fetching_original_edge_ids_is_optional,
returning_indices_and_original_edge_ids_are_optional,
)

def _check_sampler_arguments(self, nodes, fanouts, probs_or_mask):
Expand Down Expand Up @@ -991,8 +987,7 @@ def sample_layer_neighbors(
fanouts: torch.Tensor,
replace: bool = False,
probs_name: Optional[str] = None,
returning_indices_is_optional: bool = False,
fetching_original_edge_ids_is_optional: bool = False,
returning_indices_and_original_edge_ids_are_optional: bool = False,
random_seed: torch.Tensor = None,
seed2_contribution: float = 0.0,
async_op: bool = False,
Expand Down Expand Up @@ -1037,15 +1032,12 @@ def sample_layer_neighbors(
corresponding to each neighboring edge of a node. It must be a 1D
floating-point or boolean tensor, with the number of elements
equalling the total number of edges.
returning_indices_is_optional: bool
Boolean indicating whether it is okay for the call to this function
to leave the indices tensor uninitialized. In this case, it is the
user's responsibility to gather it using the edge ids.
fetching_original_edge_ids_is_optional: bool
returning_indices_and_original_edge_ids_are_optional: bool
Boolean indicating whether it is okay for the call to this function
to leave the original edge ids tensor uninitialized. In this case,
it is the user's responsibility to gather it using
_edge_ids_in_fused_csc_sampling_graph.
to leave the indices and the original edge ids tensors
uninitialized. In this case, it is the user's responsibility to
gather whichever of them is missing, using
_edge_ids_in_fused_csc_sampling_graph.
random_seed: torch.Tensor, optional
An int64 tensor with one or two elements.
Expand Down Expand Up @@ -1133,7 +1125,7 @@ def sample_layer_neighbors(
fanouts.tolist(),
replace,
True, # is_labor
returning_indices_is_optional,
returning_indices_and_original_edge_ids_are_optional,
probs_or_mask,
random_seed,
seed2_contribution,
Expand All @@ -1143,13 +1135,13 @@ def sample_layer_neighbors(
self._convert_to_sampled_subgraph,
C_sampled_subgraph,
seed_offsets,
fetching_original_edge_ids_is_optional,
returning_indices_and_original_edge_ids_are_optional,
)
else:
return self._convert_to_sampled_subgraph(
C_sampled_subgraph,
seed_offsets,
fetching_original_edge_ids_is_optional,
returning_indices_and_original_edge_ids_are_optional,
)

def temporal_sample_neighbors(
Expand Down
24 changes: 9 additions & 15 deletions python/dgl/graphbolt/impl/neighbor_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,32 +263,27 @@ def __init__(
asynchronous=False,
):
graph = sampler.__self__
self.returning_indices_is_optional = False
self.returning_indices_and_original_edge_ids_are_optional = False
original_edge_ids = (
None
if graph.edge_attributes is None
else graph.edge_attributes.get(ORIGINAL_EDGE_ID, None)
)
self.fetching_original_edge_ids_is_optional = (
overlap_fetch
and original_edge_ids is not None
and original_edge_ids.is_pinned()
)
fetch_indices_and_original_edge_ids_fn = partial(
self._fetch_indices_and_original_edge_ids,
graph.indices,
original_edge_ids,
)
if (
overlap_fetch
and sampler.__name__ == "sample_neighbors"

Check warning on line 274 in python/dgl/graphbolt/impl/neighbor_sampler.py

View workflow job for this annotation

GitHub Actions / lintrunner

UFMT format

Run `lintrunner -a` to apply this patch.
and graph.indices.is_pinned()
and (graph.indices.is_pinned() or (original_edge_ids is not None and original_edge_ids.is_pinned()))
and graph._gpu_graph_cache is None
):
datapipe = datapipe.transform(self._sample_per_layer)
if asynchronous:
datapipe = datapipe.buffer()
datapipe = datapipe.transform(self._wait_subgraph_future)
fetch_indices_and_original_edge_ids_fn = partial(
self._fetch_indices_and_original_edge_ids,
graph.indices,
original_edge_ids,
)
datapipe = (
datapipe.transform(fetch_indices_and_original_edge_ids_fn)
.buffer()
Expand All @@ -303,7 +298,7 @@ def __init__(
graph.node_type_to_id,
)
)
self.returning_indices_is_optional = True
self.returning_indices_and_original_edge_ids_are_optional = True
elif overlap_fetch:
datapipe = datapipe.fetch_insubgraph_data(graph, prob_name)
datapipe = datapipe.transform(
Expand Down Expand Up @@ -336,8 +331,7 @@ def _sample_per_layer(self, minibatch):
self.fanout,
self.replace,
self.prob_name,
self.returning_indices_is_optional,
self.fetching_original_edge_ids_is_optional,
self.returning_indices_and_original_edge_ids_are_optional,
async_op=self.asynchronous,
**kwargs,
)
Expand Down

0 comments on commit 0a52ad2

Please sign in to comment.