forked from dmlc/dgl
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Sparse] A hetero-relational GCN example (dmlc#6157)
Co-authored-by: Hongzhi (Steve), Chen <chenhongzhi.nkcs@gmail.com>
- Loading branch information
1 parent
deb6902
commit cdce548
Showing
1 changed file
with
284 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,284 @@ | ||
""" | ||
Modeling Relational Data with Graph Convolutional Networks | ||
Paper: https://arxiv.org/abs/1703.06103 | ||
Reference Code: https://github.com/tkipf/relational-gcn
This script trains and tests a Hetero Relational Graph Convolutional Networks | ||
(Hetero-RGCN) model based on the information of a full graph. | ||
This flowchart describes the main functional sequence of the provided example. | ||
main | ||
│ | ||
├───> Load and preprocess full dataset | ||
│ | ||
├───> Instantiate Hetero-RGCN model | ||
│ | ||
├───> train | ||
│ │ | ||
│ └───> Training loop | ||
│ │ | ||
│ └───> Hetero-RGCN.forward | ||
└───> test | ||
│ | ||
└───> Evaluate the model | ||
""" | ||
import argparse | ||
import time | ||
|
||
import dgl | ||
import dgl.sparse as dglsp | ||
|
||
import numpy as np | ||
|
||
import torch as th | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
|
||
from dgl.data.rdf import AIFBDataset, AMDataset, BGSDataset, MUTAGDataset | ||
|
||
|
||
class RelGraphEmbed(nn.Module):
    r"""Embedding layer for a featureless heterograph.

    Learns one embedding matrix per node type, used as the input features
    of the Hetero-RGCN model.

    Parameters
    ----------
    ntype_num : dict[str, int]
        Mapping from node type name to the number of nodes of that type.
    embed_size : int
        Dimension of the learnable embedding for each node.
    """

    def __init__(
        self,
        ntype_num,
        embed_size,
    ):
        super(RelGraphEmbed, self).__init__()
        self.embed_size = embed_size
        # NOTE: the original version also created an nn.Dropout(0.0) here
        # that was never applied in forward(); removed as dead code.

        # One learnable embedding matrix per node type, initialized with
        # Xavier/Glorot uniform (gain tuned for the downstream ReLU).
        self.embeds = nn.ParameterDict()
        for ntype, num_nodes in ntype_num.items():
            embed = nn.Parameter(th.Tensor(num_nodes, self.embed_size))
            nn.init.xavier_uniform_(embed, gain=nn.init.calculate_gain("relu"))
            self.embeds[ntype] = embed

    def forward(self):
        """Return the dict of per-node-type embedding parameters."""
        return self.embeds
|
||
|
||
class HeteroRelationalGraphConv(nn.Module):
    r"""Hetero-relational graph convolution layer.

    Parameters
    ----------
    in_size : int
        Input feature size.
    out_size : int
        Output feature size.
    relation_names : list[str]
        Relation names.
    activation : callable, optional
        Activation applied to each destination type's aggregated output
        (default: None, i.e. no activation).
    """

    def __init__(
        self,
        in_size,
        out_size,
        relation_names,
        activation=None,
    ):
        super(HeteroRelationalGraphConv, self).__init__()
        self.in_size = in_size
        self.out_size = out_size
        self.relation_names = relation_names
        self.activation = activation

        ########################################################################
        # (HIGHLIGHT) HeteroGraphConv is a graph convolution operator over
        # heterogeneous graphs. A dictionary is passed where the key is the
        # relation name and the value is the instance of the conv layer.
        ########################################################################
        self.W = nn.ModuleDict(
            {str(rel): nn.Linear(in_size, out_size) for rel in relation_names}
        )

        # Kept at p=0.0 (identity) to mirror the reference implementation;
        # raise it to regularize.
        self.dropout = nn.Dropout(0.0)

    def forward(self, A, inputs):
        """Forward computation.

        Parameters
        ----------
        A : dict[tuple[str, str, str], SparseMatrix]
            Input graph: one sparse adjacency matrix per
            (source type, edge type, destination type) relation.
        inputs : dict[str, torch.Tensor]
            Node feature for each node type.

        Returns
        -------
        dict[str, torch.Tensor]
            New node features for each node type.
        """
        hs = {}
        for rel in A:
            src_type, edge_type, dst_type = rel
            if dst_type not in hs:
                # BUGFIX: allocate the accumulator with the same device and
                # dtype as the input features, so the layer also works on GPU
                # and with non-default floating dtypes.
                hs[dst_type] = th.zeros(
                    inputs[dst_type].shape[0],
                    self.out_size,
                    device=inputs[dst_type].device,
                    dtype=inputs[dst_type].dtype,
                )
            ####################################################################
            # (HIGHLIGHT) The sparse library uses a hetero sparse matrix to
            # represent heterogeneous graphs. A dictionary is passed where the
            # key is the tuple of (source node type, edge type, destination
            # node type) and the value is the sparse matrix constructed from
            # the key on the global graph. The convolution operation is the
            # multiplication of the sparse matrix and the convolutional layer.
            ####################################################################
            hs[dst_type] = hs[dst_type] + (
                A[rel].T @ self.W[str(edge_type)](inputs[src_type])
            )
            if self.activation:
                hs[dst_type] = self.activation(hs[dst_type])
            hs[dst_type] = self.dropout(hs[dst_type])

        return hs
|
||
|
||
class EntityClassify(nn.Module):
    """Two-layer Hetero-RGCN entity classifier.

    Parameters
    ----------
    in_size : int
        Input (and hidden) feature size.
    out_size : int
        Output size, i.e. the number of classes.
    relation_names : list[str]
        Relation (edge type) names of the graph.
    embed_layer : callable
        Zero-argument callable returning a dict[str, Tensor] of input
        embeddings per node type (e.g. RelGraphEmbed).
    """

    def __init__(
        self,
        in_size,
        out_size,
        relation_names,
        embed_layer,
    ):
        super(EntityClassify, self).__init__()
        self.in_size = in_size
        self.out_size = out_size
        # BUGFIX: use a sorted copy instead of list.sort(), so constructing
        # the model does not mutate the caller's list as a side effect.
        self.relation_names = sorted(relation_names)
        self.embed_layer = embed_layer

        self.layers = nn.ModuleList()
        # Input to hidden.
        self.layers.append(
            HeteroRelationalGraphConv(
                self.in_size,
                self.in_size,
                self.relation_names,
                activation=F.relu,
            )
        )
        # Hidden to output.
        self.layers.append(
            HeteroRelationalGraphConv(
                self.in_size,
                self.out_size,
                self.relation_names,
            )
        )

    def forward(self, A):
        """Run both conv layers on the sparse-matrix dict ``A``.

        Returns the dict of output logits per node type.
        """
        h = self.embed_layer()
        for layer in self.layers:
            h = layer(A, h)
        return h
|
||
|
||
def main(args):
    """Train and evaluate a Hetero-RGCN entity classifier on an RDF dataset.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed command-line arguments; ``args.dataset`` selects one of
        "aifb", "mutag", "bgs" or "am".

    Raises
    ------
    ValueError
        If ``args.dataset`` is not one of the supported dataset names.
    """
    # Load graph data.
    if args.dataset == "aifb":
        dataset = AIFBDataset()
    elif args.dataset == "mutag":
        dataset = MUTAGDataset()
    elif args.dataset == "bgs":
        dataset = BGSDataset()
    elif args.dataset == "am":
        dataset = AMDataset()
    else:
        # BUGFIX: raise with a message so the user knows what went wrong.
        raise ValueError(f"Unknown dataset: {args.dataset!r}")

    g = dataset[0]
    category = dataset.predict_category
    num_classes = dataset.num_classes
    train_mask = g.nodes[category].data.pop("train_mask")
    test_mask = g.nodes[category].data.pop("test_mask")
    train_idx = th.nonzero(train_mask, as_tuple=False).squeeze()
    test_idx = th.nonzero(test_mask, as_tuple=False).squeeze()
    labels = g.nodes[category].data.pop("labels")

    # Split dataset into train, validate, test: the first 20% of the
    # original training indices become the validation split.
    val_idx = train_idx[: len(train_idx) // 5]
    train_idx = train_idx[len(train_idx) // 5 :]

    embed_layer = RelGraphEmbed(
        {ntype: g.num_nodes(ntype) for ntype in g.ntypes}, 16
    )

    # Create model.
    model = EntityClassify(
        16,
        num_classes,
        list(set(g.etypes)),
        embed_layer,
    )

    # Optimizer.
    optimizer = th.optim.Adam(model.parameters(), lr=1e-2, weight_decay=0)

    # Construct hetero sparse matrix: one normalized adjacency matrix per
    # canonical (src type, edge type, dst type) relation.
    A = {}
    for rel in g.canonical_etypes:
        stype, etype, dtype = rel
        eg = g[rel]
        indices = th.stack(eg.edges("uv"))
        A[rel] = dglsp.spmatrix(
            indices, shape=(g.num_nodes(stype), g.num_nodes(dtype))
        )
        ###########################################################
        # (HIGHLIGHT) Compute the normalized adjacency matrix with
        # Sparse Matrix API
        ###########################################################
        D1_hat = dglsp.diag(A[rel].sum(1)) ** -0.5
        D2_hat = dglsp.diag(A[rel].sum(0)) ** -0.5
        A[rel] = D1_hat @ A[rel] @ D2_hat

    # Training loop.
    print("start training...")
    model.train()
    for epoch in range(20):
        optimizer.zero_grad()
        logits = model(A)[category]
        loss = F.cross_entropy(logits[train_idx], labels[train_idx])
        loss.backward()
        optimizer.step()

        # Metrics are computed from this step's forward pass (no extra pass).
        train_acc = th.sum(
            logits[train_idx].argmax(dim=1) == labels[train_idx]
        ).item() / len(train_idx)
        val_loss = F.cross_entropy(logits[val_idx], labels[val_idx])
        val_acc = th.sum(
            logits[val_idx].argmax(dim=1) == labels[val_idx]
        ).item() / len(val_idx)
        print(
            f"Epoch {epoch:05d} | Train Acc: {train_acc:.4f} | "
            f"Train Loss: {loss.item():.4f} | Valid Acc: {val_acc:.4f} | "
            f"Valid loss: {val_loss.item():.4f} "
        )
    print()

    # Evaluate on the test split without tracking gradients.
    model.eval()
    with th.no_grad():
        logits = model(A)[category]
    test_loss = F.cross_entropy(logits[test_idx], labels[test_idx])
    test_acc = th.sum(
        logits[test_idx].argmax(dim=1) == labels[test_idx]
    ).item() / len(test_idx)
    print(
        "Test Acc: {:.4f} | Test loss: {:.4f}".format(
            test_acc, test_loss.item()
        )
    )
    print()
|
||
|
||
if __name__ == "__main__":
    # Command-line entry point: parse the dataset choice and run training.
    arg_parser = argparse.ArgumentParser(description="RGCN")
    arg_parser.add_argument(
        "-d", "--dataset", type=str, required=True, help="dataset to use"
    )
    parsed = arg_parser.parse_args()
    print(parsed)
    main(parsed)