[GraphBolt][PyG] Add to_pyg for layer input conversion. (#7745)
mfbalin committed Aug 26, 2024
1 parent 0331009 commit 8eccbfa
Showing 2 changed files with 179 additions and 1 deletion.
80 changes: 79 additions & 1 deletion python/dgl/graphbolt/sampled_subgraph.py
@@ -1,7 +1,7 @@
"""Graphbolt sampled subgraph."""

# pylint: disable= invalid-name
from typing import Dict, Tuple, Union
from typing import Dict, NamedTuple, Tuple, Union

import torch

@@ -20,6 +20,28 @@
__all__ = ["SampledSubgraph"]


class PyGLayerData(NamedTuple):
"""A named tuple class to represent homogenous inputs to a PyG model layer.
The fields are x (input features), edge_index and size
(source and destination sizes).
"""

x: torch.Tensor
edge_index: torch.Tensor
size: Tuple[int, int]


class PyGLayerHeteroData(NamedTuple):
"""A named tuple class to represent heterogenous inputs to a PyG model
layer. The fields are x (input features), edge_index and size
(source and destination sizes), and all fields are dictionaries.
"""

x: Dict[str, torch.Tensor]
edge_index: Dict[str, torch.Tensor]
size: Dict[str, Tuple[int, int]]


class SampledSubgraph:
r"""An abstract class for sampled subgraph. In the context of a
heterogeneous graph, each field should be of `Dict` type. Otherwise,
@@ -233,6 +255,62 @@ def exclude_edges(
)
return calling_class(*_slice_subgraph(self, index))

def to_pyg(
self, x: Union[torch.Tensor, Dict[str, torch.Tensor]]
) -> Union[PyGLayerData, PyGLayerHeteroData]:
"""
        Process layer inputs so that they can be consumed by a PyG model layer.

        Parameters
        ----------
        x : Union[torch.Tensor, Dict[str, torch.Tensor]]
            The input node features to the GNN layer.

        Returns
        -------
        Union[PyGLayerData, PyGLayerHeteroData]
            A named tuple with `x`, `edge_index` and `size` fields.
Typically, a PyG GNN layer's forward method will accept these as
arguments.
"""
if isinstance(x, torch.Tensor):
            # Homogeneous
src = self.sampled_csc.indices
dst = expand_indptr(
self.sampled_csc.indptr,
dtype=src.dtype,
output_size=src.size(0),
)
edge_index = torch.stack([src, dst], dim=0).long()
dst_size = self.sampled_csc.indptr.size(0) - 1
            # x and x[:dst_size] correspond to source and destination features resp.
return PyGLayerData(
(x, x[:dst_size]), edge_index, (x.size(0), dst_size)
)
else:
            # Heterogeneous
x_dst_dict = {}
edge_index_dict = {}
sizes_dict = {}
for etype, sampled_csc in self.sampled_csc.items():
src = sampled_csc.indices
dst = expand_indptr(
sampled_csc.indptr,
dtype=src.dtype,
output_size=src.size(0),
)
edge_index = torch.stack([src, dst], dim=0).long()
dst_size = sampled_csc.indptr.size(0) - 1
                # x[src_ntype] and x[dst_ntype][:dst_size] correspond to source
                # and destination features resp.
src_ntype, _, dst_ntype = etype_str_to_tuple(etype)
x_dst_dict[dst_ntype] = x[dst_ntype][:dst_size]
edge_index_dict[etype] = edge_index
sizes_dict[etype] = (x[src_ntype].size(0), dst_size)

return PyGLayerHeteroData(
(x, x_dst_dict), edge_index_dict, sizes_dict
)

def to(
self, device: torch.device, non_blocking=False
) -> None: # pylint: disable=invalid-name
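
For orientation, the named tuple returned by `to_pyg` unpacks directly into the arguments a bipartite PyG layer such as `SAGEConv` expects. The following is a minimal illustrative sketch, not part of this commit: the `GraphSAGE` class name and layer sizes are made up, and `torch_geometric` is assumed to be installed.

import torch
from torch_geometric.nn import SAGEConv


class GraphSAGE(torch.nn.Module):
    def __init__(self, in_size, hidden_size, out_size):
        super().__init__()
        self.layers = torch.nn.ModuleList(
            [SAGEConv(in_size, hidden_size), SAGEConv(hidden_size, out_size)]
        )

    def forward(self, sampled_subgraphs, x):
        h = x
        for layer, subgraph in zip(self.layers, sampled_subgraphs):
            # to_pyg yields ((h_src, h_dst), edge_index, (src_size, dst_size)),
            # exactly the bipartite inputs SAGEConv.forward accepts.
            h_pair, edge_index, size = subgraph.to_pyg(h)
            h = layer(h_pair, edge_index, size=size)
        return h
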
100 changes: 100 additions & 0 deletions tests/python/pytorch/graphbolt/impl/test_sampled_subgraph_impl.py
@@ -2,6 +2,7 @@

import backend as F

import dgl
import dgl.graphbolt as gb
import pytest
import torch
@@ -505,6 +506,105 @@ def test_exclude_edges_hetero_duplicated_tensor(reverse_row, reverse_column):
_assert_container_equal(result.original_edge_ids, expected_edge_ids)


def test_to_pyg_homo():
graph = dgl.graph(([5, 0, 7, 7, 2, 4], [0, 1, 2, 2, 3, 4]))
graph = gb.from_dglgraph(graph, is_homogeneous=True).to(F.ctx())
items = torch.LongTensor([[0, 3], [4, 4]])
names = "seeds"
itemset = gb.ItemSet(items, names=names)
datapipe = gb.ItemSampler(itemset, batch_size=4).copy_to(F.ctx())
num_layer = 2
fanouts = [torch.LongTensor([-1]) for _ in range(num_layer)]
sampler = gb.NeighborSampler
datapipe = sampler(
datapipe,
graph,
fanouts,
deduplicate=True,
)
for minibatch in datapipe:
x = torch.randn((minibatch.node_ids().size(0), 2), dtype=torch.float32)
for subgraph in minibatch.sampled_subgraphs:
(x_src, x_dst), edge_index, sizes = subgraph.to_pyg(x)
assert torch.equal(x_src, x)
dst_size = subgraph.original_column_node_ids.size(0)
assert torch.equal(x_dst, x[:dst_size])
src_size = subgraph.original_row_node_ids.size(0)
assert dst_size == sizes[1]
assert src_size == sizes[0]
assert torch.equal(edge_index[0], subgraph.sampled_csc.indices)
assert torch.equal(
edge_index[1],
gb.expand_indptr(
subgraph.sampled_csc.indptr,
subgraph.sampled_csc.indices.dtype,
),
)
x = x_dst
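
The `edge_index` check above mirrors how `to_pyg` builds the destination row: `gb.expand_indptr` repeats each destination node id once per incident edge recorded in the CSC `indptr`. A tiny sketch with made-up numbers, assuming the default `node_ids` of `arange`:

import torch
import dgl.graphbolt as gb

indptr = torch.tensor([0, 2, 3, 3, 5])
# Destination node 0 has 2 edges, node 1 has 1, node 2 has 0, node 3 has 2.
dst = gb.expand_indptr(indptr, torch.int64)
# dst is tensor([0, 0, 1, 3, 3]); stacking sampled_csc.indices on top of it
# gives edge_index in PyG's (source, destination) row order.
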


def test_to_pyg_hetero():
# COO graph:
# [0, 0, 1, 1, 2, 2, 3, 3, 4, 4]
# [2, 4, 2, 3, 0, 1, 1, 0, 0, 1]
    # [1, 1, 1, 1, 0, 0, 0, 0, 0, 0] -> edge type.
# num_nodes = 5, num_n1 = 2, num_n2 = 3
ntypes = {"n1": 0, "n2": 1}
etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1}
indptr = torch.LongTensor([0, 2, 4, 6, 8, 10])
indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 0, 1])
type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0, 0])
node_type_offset = torch.LongTensor([0, 2, 5])
graph = gb.fused_csc_sampling_graph(
indptr,
indices,
node_type_offset=node_type_offset,
type_per_edge=type_per_edge,
node_type_to_id=ntypes,
edge_type_to_id=etypes,
).to(F.ctx())
itemset = gb.HeteroItemSet(
{"n1:e1:n2": gb.ItemSet(torch.tensor([[0, 1]]), names="seeds")}
)
item_sampler = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())
num_layer = 2
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
Sampler = gb.NeighborSampler
datapipe = Sampler(
item_sampler,
graph,
fanouts,
deduplicate=True,
)
for minibatch in datapipe:
x = {}
for key, ids in minibatch.node_ids().items():
x[key] = torch.randn((ids.size(0), 2), dtype=torch.float32)
for subgraph in minibatch.sampled_subgraphs:
(x_src, x_dst), edge_index, sizes = subgraph.to_pyg(x)
            for ntype in x:
                assert torch.equal(x_src[ntype], x[ntype])
                dst_size = subgraph.original_column_node_ids[ntype].size(0)
                assert torch.equal(x_dst[ntype], x[ntype][:dst_size])
for etype in subgraph.sampled_csc:
src_ntype, _, dst_ntype = gb.etype_str_to_tuple(etype)
src_size = subgraph.original_row_node_ids[src_ntype].size(0)
dst_size = subgraph.original_column_node_ids[dst_ntype].size(0)
assert dst_size == sizes[etype][1]
assert src_size == sizes[etype][0]
assert torch.equal(
edge_index[etype][0], subgraph.sampled_csc[etype].indices
)
assert torch.equal(
edge_index[etype][1],
gb.expand_indptr(
subgraph.sampled_csc[etype].indptr,
subgraph.sampled_csc[etype].indices.dtype,
),
)
x = x_dst
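
For the heterogeneous case, the dictionaries returned by `to_pyg` can drive one message-passing step per edge type. The sketch below is illustrative only: `hetero_layer_forward` and `convs` are hypothetical names, a plain sum merges results landing on the same destination node type, and the feature sizes simply match the 2-dimensional random features used above.

import dgl.graphbolt as gb
from torch_geometric.nn import SAGEConv


def hetero_layer_forward(convs, subgraph, x_dict):
    # One message-passing step over every edge type of a sampled subgraph.
    (x_src, x_dst), edge_index, size = subgraph.to_pyg(x_dict)
    out = {}
    for etype, conv in convs.items():
        src_ntype, _, dst_ntype = gb.etype_str_to_tuple(etype)
        # Each relation is a bipartite graph from src_ntype to dst_ntype.
        h = conv(
            (x_src[src_ntype], x_dst[dst_ntype]),
            edge_index[etype],
            size=size[etype],
        )
        # Sum contributions arriving at the same destination node type.
        out[dst_ntype] = h if dst_ntype not in out else out[dst_ntype] + h
    return out


# One bipartite SAGEConv per canonical edge type; the output size is made up.
convs = {etype: SAGEConv((2, 2), 4) for etype in ("n1:e1:n2", "n2:e2:n1")}
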


@unittest.skipIf(
F._default_context_str == "cpu",
reason="`to` function needs GPU to test.",