Skip to content

Commit

Permalink
[GraphBolt][PyG] Use SampledSubgraph.to_pyg in examples. (#7747)
Browse files Browse the repository at this point in the history
  • Loading branch information
mfbalin committed Aug 26, 2024
1 parent 8eccbfa commit 6bce0cd
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 74 deletions.
32 changes: 2 additions & 30 deletions examples/graphbolt/pyg/hetero/node_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,34 +74,6 @@ def create_dataloader(
return gb.DataLoader(datapipe, num_workers=args.num_workers)


def convert_to_pyg(h, subgraph):
    """Convert a heterogeneous sampled subgraph into PyG layer inputs.

    #####################################################################
    # (HIGHLIGHT) Convert given features to be consumed by a PyG layer.
    #
    # GraphBolt provides the sampled edges of every edge type in CSC
    # format; each one is converted to COO with gb.expand_indptr.
    #####################################################################

    Args:
        h: Mapping from node type to that node type's feature tensor.
        subgraph: Object whose ``sampled_csc`` maps each edge type string
            to a CSC structure exposing ``indptr`` and ``indices``.

    Returns:
        A 3-tuple ``((h, h_dst), edge_index, sizes)`` where ``h_dst`` maps
        destination node type to the leading ``dst_size`` rows of its
        features, ``edge_index`` maps edge type to a stacked ``[src, dst]``
        int64 tensor, and ``sizes`` maps edge type to
        ``(num_src_rows, dst_size)``.
    """
    dst_feats = {}
    coo_by_etype = {}
    size_by_etype = {}
    for etype, csc in subgraph.sampled_csc.items():
        rows = csc.indices
        # Expand the compressed column pointer into one dst id per edge.
        cols = gb.expand_indptr(
            csc.indptr, dtype=rows.dtype, output_size=rows.size(0)
        )
        # One indptr entry per destination node, plus the leading zero.
        num_dst = csc.indptr.size(0) - 1
        src_ntype, _, dst_ntype = gb.etype_str_to_tuple(etype)
        coo_by_etype[etype] = torch.stack([rows, cols], dim=0).long()
        # h holds the source features; its first num_dst rows are the
        # destination features.
        dst_feats[dst_ntype] = h[dst_ntype][:num_dst]
        size_by_etype[etype] = (h[src_ntype].size(0), num_dst)

    return (h, dst_feats), coo_by_etype, size_by_etype


class RelGraphConvLayer(nn.Module):
def __init__(
self,
Expand Down Expand Up @@ -153,7 +125,7 @@ def forward(self, subgraph, x):
# only on the destination nodes' features. By doing so, we ensure the
# feature dimensions match and prevent any misuse of incorrect node
# features.
(h, h_dst), edge_index, size = convert_to_pyg(x, subgraph)
(h, h_dst), edge_index, size = subgraph.to_pyg(x)

h_out = {}
for etype in edge_index:
Expand Down Expand Up @@ -514,7 +486,7 @@ def main():

# Initialize the entity classification model.
model = EntityClassify(
graph, feat_size, hidden_channels, num_classes, 3
graph, feat_size, hidden_channels, num_classes, len(args.fanout)
).to(args.device)

print(
Expand Down
25 changes: 3 additions & 22 deletions examples/graphbolt/pyg/labor/node_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,25 +30,6 @@ def accuracy(out, labels):
return (labels == predictions).sum(dtype=torch.float64) / labels.size(0)


def convert_to_pyg(h, subgraph):
    """Convert a sampled subgraph and its features into PyG layer inputs.

    #####################################################################
    # (HIGHLIGHT) Convert given features to be consumed by a PyG layer.
    #
    # GraphBolt provides the sampled edges in CSC format; they are
    # converted to COO with gb.expand_indptr.
    #####################################################################

    Args:
        h: Feature tensor of the source nodes.
        subgraph: Object whose ``sampled_csc`` exposes ``indptr`` and
            ``indices``.

    Returns:
        A 3-tuple ``((h, h[:dst_size]), edge_index, (h.size(0), dst_size))``
        where ``edge_index`` is the stacked ``[src, dst]`` int64 tensor.
    """
    csc = subgraph.sampled_csc
    rows = csc.indices
    # Expand the compressed column pointer into one dst id per edge.
    cols = gb.expand_indptr(
        csc.indptr, dtype=rows.dtype, output_size=rows.size(0)
    )
    coo = torch.stack([rows, cols], dim=0).long()
    # One indptr entry per destination node, plus the leading zero.
    num_dst = csc.indptr.size(0) - 1
    # h and h[:num_dst] are the source and destination features resp.
    feature_pair = (h, h[:num_dst])
    return feature_pair, coo, (h.size(0), num_dst)


class GraphSAGE(torch.nn.Module):
def __init__(
self, in_size, hidden_size, out_size, n_layers, dropout, variant
Expand All @@ -75,7 +56,7 @@ def __init__(
def forward(self, subgraphs, x):
h = x
for i, (layer, subgraph) in enumerate(zip(self.layers, subgraphs)):
h, edge_index, size = convert_to_pyg(h, subgraph)
h, edge_index, size = subgraph.to_pyg(h)
h = layer(h, edge_index, size=size)
if self.variant == "custom":
h = self.activation(h)
Expand All @@ -101,8 +82,8 @@ def inference(self, graph, features, dataloader, storage_device):
)
for data in tqdm(dataloader, "Inferencing"):
# len(data.sampled_subgraphs) = 1
h, edge_index, size = convert_to_pyg(
data.node_features["feat"], data.sampled_subgraphs[0]
h, edge_index, size = data.sampled_subgraphs[0].to_pyg(
data.node_features["feat"]
)
hidden_x = layer(h, edge_index, size=size)
if self.variant == "custom":
Expand Down
25 changes: 3 additions & 22 deletions examples/graphbolt/pyg/node_classification_advanced.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,25 +75,6 @@ def accuracy(out, labels):
return (labels == predictions).sum(dtype=torch.float64) / labels.size(0)


def convert_to_pyg(h, subgraph):
    """Produce the (features, edge_index, size) triple a PyG layer takes.

    #####################################################################
    # (HIGHLIGHT) Convert given features to be consumed by a PyG layer.
    #
    # The sampled edges arrive from GraphBolt in CSC format and are
    # turned into COO by calling gb.expand_indptr.
    #####################################################################

    Args:
        h: Feature tensor of the source nodes.
        subgraph: Object whose ``sampled_csc`` exposes ``indptr`` and
            ``indices``.

    Returns:
        ``((h, h[:dst_size]), edge_index, (h.size(0), dst_size))``, with
        ``edge_index`` being the ``[src, dst]`` stack cast to int64.
    """
    indptr = subgraph.sampled_csc.indptr
    src_ids = subgraph.sampled_csc.indices
    # Repeat each destination id once per incoming sampled edge.
    dst_ids = gb.expand_indptr(
        indptr, dtype=src_ids.dtype, output_size=src_ids.size(0)
    )
    # indptr has one more entry than there are destination nodes.
    n_dst = indptr.size(0) - 1
    # The destination features are the leading n_dst rows of h.
    return (
        (h, h[:n_dst]),
        torch.stack([src_ids, dst_ids], dim=0).long(),
        (h.size(0), n_dst),
    )


class GraphSAGE(torch.nn.Module):
#####################################################################
# (HIGHLIGHT) Define the GraphSAGE model architecture.
Expand Down Expand Up @@ -123,7 +104,7 @@ def forward(self, subgraphs, x):
# given features to get src and dst features to use the PyG layers
# in the more efficient bipartite mode.
#####################################################################
h, edge_index, size = convert_to_pyg(h, subgraph)
h, edge_index, size = subgraph.to_pyg(h)
h = layer(h, edge_index, size=size)
if i != len(subgraphs) - 1:
h = F.relu(h)
Expand All @@ -146,8 +127,8 @@ def inference(self, graph, features, dataloader, storage_device):
)
for data in tqdm(dataloader, "Inferencing"):
# len(data.sampled_subgraphs) = 1
h, edge_index, size = convert_to_pyg(
data.node_features["feat"], data.sampled_subgraphs[0]
h, edge_index, size = data.sampled_subgraphs[0].to_pyg(
data.node_features["feat"]
)
hidden_x = layer(h, edge_index, size=size)
if not is_last_layer: