Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[GraphBolt][PyG] Use SampledSubgraph.to_pyg in examples. #7747

Merged
merged 2 commits into from
Aug 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 2 additions & 30 deletions examples/graphbolt/pyg/hetero/node_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,34 +74,6 @@ def create_dataloader(
return gb.DataLoader(datapipe, num_workers=args.num_workers)


def convert_to_pyg(h, subgraph):
#####################################################################
# (HIGHLIGHT) Convert given features to be consumed by a PyG layer.
#
# We convert the provided sampled edges in CSC format from GraphBolt and
# convert to COO via using gb.expand_indptr.
#####################################################################
h_dst_dict = {}
edge_index_dict = {}
sizes_dict = {}
for etype, sampled_csc in subgraph.sampled_csc.items():
src = sampled_csc.indices
dst = gb.expand_indptr(
sampled_csc.indptr,
dtype=src.dtype,
output_size=src.size(0),
)
edge_index = torch.stack([src, dst], dim=0).long()
dst_size = sampled_csc.indptr.size(0) - 1
# h and h[:dst_size] correspond to source and destination features resp.
src_ntype, _, dst_ntype = gb.etype_str_to_tuple(etype)
h_dst_dict[dst_ntype] = h[dst_ntype][:dst_size]
edge_index_dict[etype] = edge_index
sizes_dict[etype] = (h[src_ntype].size(0), dst_size)

return (h, h_dst_dict), edge_index_dict, sizes_dict


class RelGraphConvLayer(nn.Module):
def __init__(
self,
Expand Down Expand Up @@ -153,7 +125,7 @@ def forward(self, subgraph, x):
# only on the destination nodes' features. By doing so, we ensure the
# feature dimensions match and prevent any misuse of incorrect node
# features.
(h, h_dst), edge_index, size = convert_to_pyg(x, subgraph)
(h, h_dst), edge_index, size = subgraph.to_pyg(x)

h_out = {}
for etype in edge_index:
Expand Down Expand Up @@ -514,7 +486,7 @@ def main():

# Initialize the entity classification model.
model = EntityClassify(
graph, feat_size, hidden_channels, num_classes, 3
graph, feat_size, hidden_channels, num_classes, len(args.fanout)
).to(args.device)

print(
Expand Down
25 changes: 3 additions & 22 deletions examples/graphbolt/pyg/labor/node_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,25 +30,6 @@ def accuracy(out, labels):
return (labels == predictions).sum(dtype=torch.float64) / labels.size(0)


def convert_to_pyg(h, subgraph):
#####################################################################
# (HIGHLIGHT) Convert given features to be consumed by a PyG layer.
#
# We convert the provided sampled edges in CSC format from GraphBolt and
# convert to COO via using gb.expand_indptr.
#####################################################################
src = subgraph.sampled_csc.indices
dst = gb.expand_indptr(
subgraph.sampled_csc.indptr,
dtype=src.dtype,
output_size=src.size(0),
)
edge_index = torch.stack([src, dst], dim=0).long()
dst_size = subgraph.sampled_csc.indptr.size(0) - 1
# h and h[:dst_size] correspond to source and destination features resp.
return (h, h[:dst_size]), edge_index, (h.size(0), dst_size)


class GraphSAGE(torch.nn.Module):
def __init__(
self, in_size, hidden_size, out_size, n_layers, dropout, variant
Expand All @@ -75,7 +56,7 @@ def __init__(
def forward(self, subgraphs, x):
h = x
for i, (layer, subgraph) in enumerate(zip(self.layers, subgraphs)):
h, edge_index, size = convert_to_pyg(h, subgraph)
h, edge_index, size = subgraph.to_pyg(h)
h = layer(h, edge_index, size=size)
if self.variant == "custom":
h = self.activation(h)
Expand All @@ -101,8 +82,8 @@ def inference(self, graph, features, dataloader, storage_device):
)
for data in tqdm(dataloader, "Inferencing"):
# len(data.sampled_subgraphs) = 1
h, edge_index, size = convert_to_pyg(
data.node_features["feat"], data.sampled_subgraphs[0]
h, edge_index, size = data.sampled_subgraphs[0].to_pyg(
data.node_features["feat"]
)
hidden_x = layer(h, edge_index, size=size)
if self.variant == "custom":
Expand Down
25 changes: 3 additions & 22 deletions examples/graphbolt/pyg/node_classification_advanced.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,25 +75,6 @@ def accuracy(out, labels):
return (labels == predictions).sum(dtype=torch.float64) / labels.size(0)


def convert_to_pyg(h, subgraph):
#####################################################################
# (HIGHLIGHT) Convert given features to be consumed by a PyG layer.
#
# We convert the provided sampled edges in CSC format from GraphBolt and
# convert to COO via using gb.expand_indptr.
#####################################################################
src = subgraph.sampled_csc.indices
dst = gb.expand_indptr(
subgraph.sampled_csc.indptr,
dtype=src.dtype,
output_size=src.size(0),
)
edge_index = torch.stack([src, dst], dim=0).long()
dst_size = subgraph.sampled_csc.indptr.size(0) - 1
# h and h[:dst_size] correspond to source and destination features resp.
return (h, h[:dst_size]), edge_index, (h.size(0), dst_size)


class GraphSAGE(torch.nn.Module):
#####################################################################
# (HIGHLIGHT) Define the GraphSAGE model architecture.
Expand Down Expand Up @@ -123,7 +104,7 @@ def forward(self, subgraphs, x):
# given features to get src and dst features to use the PyG layers
# in the more efficient bipartite mode.
#####################################################################
h, edge_index, size = convert_to_pyg(h, subgraph)
h, edge_index, size = subgraph.to_pyg(h)
h = layer(h, edge_index, size=size)
if i != len(subgraphs) - 1:
h = F.relu(h)
Expand All @@ -146,8 +127,8 @@ def inference(self, graph, features, dataloader, storage_device):
)
for data in tqdm(dataloader, "Inferencing"):
# len(data.sampled_subgraphs) = 1
h, edge_index, size = convert_to_pyg(
data.node_features["feat"], data.sampled_subgraphs[0]
h, edge_index, size = data.sampled_subgraphs[0].to_pyg(
data.node_features["feat"]
)
hidden_x = layer(h, edge_index, size=size)
if not is_last_layer:
Expand Down
Loading