Skip to content

Commit

Permalink
[GraphBolt][CUDA] sampled_edge_ids to fetch indices later.
Browse files: browse the repository at this point in the history
  • Loading branch information
mfbalin committed Aug 7, 2024
1 parent d6f0771 commit 590b8e2
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 29 deletions.
47 changes: 32 additions & 15 deletions python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,23 +573,28 @@ def _convert_to_sampled_subgraph(
indices = C_sampled_subgraph.indices
type_per_edge = C_sampled_subgraph.type_per_edge
column = C_sampled_subgraph.original_column_node_ids
original_edge_ids = C_sampled_subgraph.original_edge_ids
sampled_edge_ids = C_sampled_subgraph.original_edge_ids
etype_offsets = C_sampled_subgraph.etype_offsets
if etype_offsets is not None:
etype_offsets = etype_offsets.tolist()

has_original_eids = (
original_edge_ids is not None
and self.edge_attributes is not None
self.edge_attributes is not None
and ORIGINAL_EDGE_ID in self.edge_attributes
)
if has_original_eids:
original_edge_ids = torch.ops.graphbolt.index_select(
self.edge_attributes[ORIGINAL_EDGE_ID], original_edge_ids
original_edge_ids = (
torch.ops.graphbolt.index_select(
self.edge_attributes[ORIGINAL_EDGE_ID], sampled_edge_ids
)
if has_original_eids
else sampled_edge_ids
)
if type_per_edge is None and etype_offsets is None:
# The sampled graph is already a homogeneous graph.
sampled_csc = CSCFormatBase(indptr=indptr, indices=indices)
if indices is not None:
# Only needed to fetch indices.
sampled_edge_ids = None
else:
offset = self._node_type_offset_list

Expand Down Expand Up @@ -626,11 +631,12 @@ def _convert_to_sampled_subgraph(
sub_indptr[etype] = torch.cat(
(torch.tensor([0], device=indptr.device), cum_edges)
)
if original_edge_ids is not None:
original_hetero_edge_ids[etype] = original_edge_ids[
eids
]
original_hetero_edge_ids[etype] = original_edge_ids[
eids
]
sampled_hetero_edge_ids = None
else:
sampled_hetero_edge_ids = {}
edge_offsets = [0]
for etype, etype_id in self.edge_type_to_id.items():
src_ntype, _, dst_ntype = etype_str_to_tuple(etype)
Expand All @@ -647,18 +653,28 @@ def _convert_to_sampled_subgraph(
sub_indptr[etype] = indptr[
edge_offsets[etype_id] : edge_offsets[etype_id + 1]
]
sub_indices[etype] = indices[
sub_indices[etype] = (
None
if indices is None
else indices[
etype_offsets[etype_id] : etype_offsets[
etype_id + 1
]
]
)
original_hetero_edge_ids[etype] = original_edge_ids[
etype_offsets[etype_id] : etype_offsets[etype_id + 1]
]
if original_edge_ids is not None:
original_hetero_edge_ids[etype] = original_edge_ids[
if indices is None:
# Only needed to fetch indices.
sampled_hetero_edge_ids[etype] = sampled_edge_ids[
etype_offsets[etype_id] : etype_offsets[
etype_id + 1
]
]

if original_edge_ids is not None:
original_edge_ids = original_hetero_edge_ids
original_edge_ids = original_hetero_edge_ids
sampled_edge_ids = sampled_hetero_edge_ids
sampled_csc = {
etype: CSCFormatBase(
indptr=sub_indptr[etype],
Expand All @@ -669,6 +685,7 @@ def _convert_to_sampled_subgraph(
return SampledSubgraphImpl(
sampled_csc=sampled_csc,
original_edge_ids=original_edge_ids,
_sampled_edge_ids=sampled_edge_ids,
)

def sample_neighbors(
Expand Down
44 changes: 30 additions & 14 deletions python/dgl/graphbolt/impl/sampled_subgraph_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ class SampledSubgraphImpl(SampledSubgraph):
] = None
original_row_node_ids: Union[Dict[str, torch.Tensor], torch.Tensor] = None
original_edge_ids: Union[Dict[str, torch.Tensor], torch.Tensor] = None
# Used to fetch sampled_csc.indices if it is missing.
_sampled_edge_ids: Union[Dict[str, torch.Tensor], torch.Tensor] = None

def __post_init__(self):
if isinstance(self.sampled_csc, dict):
Expand All @@ -53,22 +55,34 @@ def __post_init__(self):
isinstance(etype, str)
and len(etype_str_to_tuple(etype)) == 3
), "Edge type should be a string in format of str:str:str."
assert (
pair.indptr is not None and pair.indices is not None
), "Node pair should be have indptr and indice."
assert isinstance(pair.indptr, torch.Tensor) and isinstance(
pair.indices, torch.Tensor
), "Nodes in pairs should be of type torch.Tensor."
assert pair.indptr is not None and isinstance(
pair.indptr, torch.Tensor
), "Node pair should be have indptr of type torch.Tensor."
# For CUDA, indices may be None because it will be fetched later.
if not pair.indptr.is_cuda or pair.indices is not None:
assert isinstance(
pair.indices, torch.Tensor
), "Node pair should be have indices of type torch.Tensor."
else:
assert isinstance(
self._sampled_edge_ids.get(etype, None), torch.Tensor
), "When indices is missing, sampled edge ids needs to be provided."
else:
assert (
self.sampled_csc.indptr is not None
and self.sampled_csc.indices is not None
), "Node pair should be have indptr and indice."
assert isinstance(
assert self.sampled_csc.indptr is not None and isinstance(
self.sampled_csc.indptr, torch.Tensor
) and isinstance(
self.sampled_csc.indices, torch.Tensor
), "Nodes in pairs should be of type torch.Tensor."
), "Node pair should be have torch.Tensor indptr."
# For CUDA, indices may be None because it will be fetched later.
if (
not self.sampled_csc.indptr.is_cuda
or self.sampled_csc.indices is not None
):
assert isinstance(
self.sampled_csc.indices, torch.Tensor
), "Node pair should have a torch.Tensor indices."
else:
assert isinstance(
self._sampled_edge_ids, torch.Tensor
), "When indices is missing, sampled edge ids needs to be provided."

def __repr__(self) -> str:
return _sampled_subgraph_str(self, "SampledSubgraphImpl")
Expand All @@ -81,6 +95,8 @@ def _sampled_subgraph_str(sampled_subgraph: SampledSubgraph, classname) -> str:
attributes.reverse()

for name in attributes:
if name in "_sampled_edge_ids":
continue
val = getattr(sampled_subgraph, name)

def _add_indent(_str, indent):
Expand Down

0 comments on commit 590b8e2

Please sign in to comment.