Skip to content

Commit

Permalink
distinction between homo layer input and hetero layer input.
Browse files Browse the repository at this point in the history
  • Loading branch information
mfbalin committed Aug 26, 2024
1 parent c046a47 commit f37a455
Showing 1 changed file with 20 additions and 6 deletions.
26 changes: 20 additions & 6 deletions python/dgl/graphbolt/sampled_subgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,25 @@


class PyGLayerData(NamedTuple):
"""A named tuple class to represent the inputs to a PyG model layer.
The fields are x (input features), edge_index and size (source and destination sizes.)
"""A named tuple class to represent homogenous inputs to a PyG model layer.
The fields are x (input features), edge_index and size
(source and destination sizes).
"""

x: Union[torch.Tensor, Dict[str, torch.Tensor]]
edge_index: Union[torch.Tensor, Dict[str, torch.Tensor]]
size: Union[Tuple[int, int], Dict[str, Tuple[int, int]]]
x: torch.Tensor
edge_index: torch.Tensor
size: Tuple[int, int]


class PyGLayerHeteroData(NamedTuple):
"""A named tuple class to represent heterogenous inputs to a PyG model
layer. The fields are x (input features), edge_index and size
(source and destination sizes), and all fields are dictionaries.
"""

x: Dict[str, torch.Tensor]
edge_index: Dict[str, torch.Tensor]
size: Dict[str, Tuple[int, int]]


class SampledSubgraph:
Expand Down Expand Up @@ -292,7 +304,9 @@ def to_pyg(self, x: Union[torch.Tensor, Dict[str, torch.Tensor]]):
edge_index_dict[etype] = edge_index
sizes_dict[etype] = (x[src_ntype].size(0), dst_size)

return PyGLayerData((x, x_dst_dict), edge_index_dict, sizes_dict)
return PyGLayerHeteroData(
(x, x_dst_dict), edge_index_dict, sizes_dict
)

def to(
self, device: torch.device, non_blocking=False
Expand Down

0 comments on commit f37a455

Please sign in to comment.