Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[GraphBolt][CUDA] sampled_edge_ids to fetch indices later. #7664

Merged
merged 2 commits into from
Aug 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 32 additions & 15 deletions python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,23 +573,28 @@ def _convert_to_sampled_subgraph(
indices = C_sampled_subgraph.indices
type_per_edge = C_sampled_subgraph.type_per_edge
column = C_sampled_subgraph.original_column_node_ids
original_edge_ids = C_sampled_subgraph.original_edge_ids
sampled_edge_ids = C_sampled_subgraph.original_edge_ids
Copy link
Collaborator

@frozenbugs frozenbugs Aug 15, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we also rename the original_edge_ids in C_sampled_subgraph?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In #7703 I updated the docstring to clarify. It is indeed hard to rename.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sampled_edge_ids -> edge_ids_in_fused_csc_sampling_graph.

etype_offsets = C_sampled_subgraph.etype_offsets
if etype_offsets is not None:
etype_offsets = etype_offsets.tolist()

has_original_eids = (
original_edge_ids is not None
and self.edge_attributes is not None
self.edge_attributes is not None
and ORIGINAL_EDGE_ID in self.edge_attributes
)
if has_original_eids:
original_edge_ids = torch.ops.graphbolt.index_select(
self.edge_attributes[ORIGINAL_EDGE_ID], original_edge_ids
original_edge_ids = (
torch.ops.graphbolt.index_select(
self.edge_attributes[ORIGINAL_EDGE_ID], sampled_edge_ids
)
if has_original_eids
else sampled_edge_ids
)
if type_per_edge is None and etype_offsets is None:
# The sampled graph is already a homogeneous graph.
sampled_csc = CSCFormatBase(indptr=indptr, indices=indices)
if indices is not None:
# Only needed to fetch indices.
sampled_edge_ids = None
else:
offset = self._node_type_offset_list

Expand Down Expand Up @@ -626,11 +631,12 @@ def _convert_to_sampled_subgraph(
sub_indptr[etype] = torch.cat(
(torch.tensor([0], device=indptr.device), cum_edges)
)
if original_edge_ids is not None:
original_hetero_edge_ids[etype] = original_edge_ids[
eids
]
original_hetero_edge_ids[etype] = original_edge_ids[
eids
]
sampled_hetero_edge_ids = None
else:
sampled_hetero_edge_ids = {}
edge_offsets = [0]
for etype, etype_id in self.edge_type_to_id.items():
src_ntype, _, dst_ntype = etype_str_to_tuple(etype)
Expand All @@ -647,18 +653,28 @@ def _convert_to_sampled_subgraph(
sub_indptr[etype] = indptr[
edge_offsets[etype_id] : edge_offsets[etype_id + 1]
]
sub_indices[etype] = indices[
sub_indices[etype] = (
None
if indices is None
else indices[
etype_offsets[etype_id] : etype_offsets[
etype_id + 1
]
]
)
original_hetero_edge_ids[etype] = original_edge_ids[
etype_offsets[etype_id] : etype_offsets[etype_id + 1]
]
if original_edge_ids is not None:
original_hetero_edge_ids[etype] = original_edge_ids[
if indices is None:
# Only needed to fetch indices.
sampled_hetero_edge_ids[etype] = sampled_edge_ids[
etype_offsets[etype_id] : etype_offsets[
etype_id + 1
]
]

if original_edge_ids is not None:
original_edge_ids = original_hetero_edge_ids
original_edge_ids = original_hetero_edge_ids
sampled_edge_ids = sampled_hetero_edge_ids
sampled_csc = {
etype: CSCFormatBase(
indptr=sub_indptr[etype],
Expand All @@ -669,6 +685,7 @@ def _convert_to_sampled_subgraph(
return SampledSubgraphImpl(
sampled_csc=sampled_csc,
original_edge_ids=original_edge_ids,
_sampled_edge_ids=sampled_edge_ids,
)

def sample_neighbors(
Expand Down
44 changes: 30 additions & 14 deletions python/dgl/graphbolt/impl/sampled_subgraph_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ class SampledSubgraphImpl(SampledSubgraph):
] = None
original_row_node_ids: Union[Dict[str, torch.Tensor], torch.Tensor] = None
original_edge_ids: Union[Dict[str, torch.Tensor], torch.Tensor] = None
# Used to fetch sampled_csc.indices if it is missing.
_sampled_edge_ids: Union[Dict[str, torch.Tensor], torch.Tensor] = None
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto: _edge_ids_in_fused_csc_sampling_graph

The name has to be clear since it is indeed non-intuitive.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


def __post_init__(self):
if isinstance(self.sampled_csc, dict):
Expand All @@ -53,22 +55,34 @@ def __post_init__(self):
isinstance(etype, str)
and len(etype_str_to_tuple(etype)) == 3
), "Edge type should be a string in format of str:str:str."
assert (
pair.indptr is not None and pair.indices is not None
), "Node pair should be have indptr and indice."
assert isinstance(pair.indptr, torch.Tensor) and isinstance(
pair.indices, torch.Tensor
), "Nodes in pairs should be of type torch.Tensor."
assert pair.indptr is not None and isinstance(
pair.indptr, torch.Tensor
), "Node pair should be have indptr of type torch.Tensor."
# For CUDA, indices may be None because it will be fetched later.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am wondering how much performance we can gain by fetching indices asynchronously, especially since we still fetch original_edge_ids synchronously.

Please balance the performance gain against the added code structure complexity.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

25%

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

25% is a lot, so we needed to do it.

if not pair.indptr.is_cuda or pair.indices is not None:
assert isinstance(
pair.indices, torch.Tensor
), "Node pair should be have indices of type torch.Tensor."
else:
assert isinstance(
self._sampled_edge_ids.get(etype, None), torch.Tensor
), "When indices is missing, sampled edge ids needs to be provided."
else:
assert (
self.sampled_csc.indptr is not None
and self.sampled_csc.indices is not None
), "Node pair should be have indptr and indice."
assert isinstance(
assert self.sampled_csc.indptr is not None and isinstance(
self.sampled_csc.indptr, torch.Tensor
) and isinstance(
self.sampled_csc.indices, torch.Tensor
), "Nodes in pairs should be of type torch.Tensor."
), "Node pair should be have torch.Tensor indptr."
# For CUDA, indices may be None because it will be fetched later.
if (
not self.sampled_csc.indptr.is_cuda
or self.sampled_csc.indices is not None
):
assert isinstance(
self.sampled_csc.indices, torch.Tensor
), "Node pair should have a torch.Tensor indices."
else:
assert isinstance(
self._sampled_edge_ids, torch.Tensor
), "When indices is missing, sampled edge ids needs to be provided."

def __repr__(self) -> str:
return _sampled_subgraph_str(self, "SampledSubgraphImpl")
Expand All @@ -81,6 +95,8 @@ def _sampled_subgraph_str(sampled_subgraph: SampledSubgraph, classname) -> str:
attributes.reverse()

for name in attributes:
if name in "_sampled_edge_ids":
continue
val = getattr(sampled_subgraph, name)

def _add_indent(_str, indent):
Expand Down
Loading