[Example] Update GCN on ogbn-arxiv dataset #2153

Merged (14 commits) on Sep 5, 2020
examples/pytorch/ogb/ogbn-arxiv/README.md (11 changes: 6 additions & 5 deletions)
@@ -2,7 +2,7 @@

Requires DGL 0.5 or later versions.

-Run `gcn.py` with `--use-linear` and `use-labels` enabled and you should directly see the result.
+Run `gcn.py` with `--use-linear` and `--use-labels` enabled and you should directly see the result.

```bash
python3 gcn.py --use-linear --use-labels
```
@@ -17,12 +17,12 @@ usage: GCN on OGBN-Arxiv [-h] [--cpu] [--gpu GPU] [--n-runs N_RUNS] [--n-epochs

optional arguments:
  -h, --help           show this help message and exit
-  --cpu                CPU mode. This option overrides --gpu.
-  --gpu GPU            GPU device ID.
+  --cpu                CPU mode. This option overrides --gpu. (default: False)
+  --gpu GPU            GPU device ID. (default: 0)
  --n-runs N_RUNS
  --n-epochs N_EPOCHS
-  --use-labels         Use labels in the training set as input features.
-  --use-linear         Use linear layers.
+  --use-labels         Use labels in the training set as input features. (default: False)
+  --use-linear         Use linear layer. (default: False)
  --lr LR
  --n-layers N_LAYERS
  --n-hidden N_HIDDEN
@@ -41,3 +41,4 @@ Here are the results over 10 runs.
| Val acc | 0.7361 ± 0.0009 | 0.7397 ± 0.0010 | 0.7399 ± 0.0008 | 0.7442 ± 0.0012 |
| Test acc | 0.7246 ± 0.0021 | 0.7270 ± 0.0016 | 0.7259 ± 0.0006 | 0.7306 ± 0.0024 |
| Parameters | 109608 | 218152 | 119848 | 238632 |

examples/pytorch/ogb/ogbn-arxiv/gcn.py (78 changes: 39 additions & 39 deletions)
@@ -11,33 +11,46 @@
import torch.optim as optim
from matplotlib import pyplot as plt
from matplotlib.ticker import AutoMinorLocator, MultipleLocator
-from ogb.nodeproppred import DglNodePropPredDataset
+from ogb.nodeproppred import DglNodePropPredDataset, Evaluator

from models import GCN

device = None
in_feats, n_classes = None, None


-def compute_acc(pred, labels):
-    """
-    Compute the accuracy of prediction given the labels.
-    """
-    return (th.argmax(pred, dim=1) == labels).float().sum() / len(pred)
+def gen_model(args):
+    if args.use_labels:
+        model = GCN(
+            in_feats + n_classes, args.n_hidden, n_classes, args.n_layers, F.relu, args.dropout, args.use_linear
+        )
+    else:
+        model = GCN(in_feats, args.n_hidden, n_classes, args.n_layers, F.relu, args.dropout, args.use_linear)
+    return model
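When `--use-labels` is on, `gen_model` widens the first layer by `n_classes`, because `add_labels` (below) concatenates a one-hot label block onto the node features. For ogbn-arxiv, whose nodes carry 128-dimensional features over 40 classes, the width arithmetic is just:

```python
# Input width of the first GCN layer under each setting (ogbn-arxiv sizes).
in_feats, n_classes = 128, 40
print(in_feats + n_classes)  # 168 with --use-labels
print(in_feats)              # 128 without
```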


def cross_entropy(x, labels):
-    y = F.cross_entropy(x, labels, reduction="none")
+    y = F.cross_entropy(x, labels[:, 0], reduction="none")
    y = th.log(0.5 + y) - math.log(0.5)
    return th.mean(y)
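The log transform in `cross_entropy` damps large per-node losses: near zero it is roughly linear (slope 2 at y = 0, since d/dy log(0.5 + y) = 1/(0.5 + y)), while for hard nodes it grows only logarithmically, so outliers pull less on the mean. A standalone sketch of the arithmetic (the sample values are illustrative, not from the PR):

```python
import math

# th.log(0.5 + y) - math.log(0.5), evaluated for a few raw loss values.
for y in (0.1, 1.0, 10.0):
    smoothed = math.log(0.5 + y) - math.log(0.5)
    print(f"raw loss {y:5.1f} -> smoothed {smoothed:.3f}")
# raw loss   0.1 -> smoothed 0.182
# raw loss   1.0 -> smoothed 1.099
# raw loss  10.0 -> smoothed 3.045
```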


+def compute_acc(pred, labels, evaluator):
+    return evaluator.eval({"y_pred": pred.argmax(dim=-1, keepdim=True), "y_true": labels})["acc"]
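`compute_acc` now delegates to the official OGB `Evaluator`. The one subtlety is shape: both `y_true` and `y_pred` must be `[N, 1]`, which is why `argmax` uses `keepdim=True` and why `main` below stops flattening `labels`. A minimal usage sketch with made-up labels:

```python
import torch as th
from ogb.nodeproppred import Evaluator

evaluator = Evaluator(name="ogbn-arxiv")
y_true = th.tensor([[0], [1], [2]])  # labels kept 2-D, matching main() below
y_pred = th.tensor([[0], [1], [1]])  # e.g. pred.argmax(dim=-1, keepdim=True)
print(evaluator.eval({"y_true": y_true, "y_pred": y_pred})["acc"])  # 2/3 correct
```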


def add_labels(feat, labels, idx):
    onehot = th.zeros([feat.shape[0], n_classes]).to(device)
-    onehot[idx, labels[idx]] = 1
+    onehot[idx, labels[idx, 0]] = 1
    return th.cat([feat, onehot], dim=-1)
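`add_labels` is the label-reuse trick: ground-truth labels are appended to the input features as a one-hot block, but only for the nodes in `idx`; every other node gets all zeros, so the model cannot read the labels it is asked to predict. A tiny illustration with made-up numbers:

```python
import torch as th

n_classes = 3
feat = th.ones(4, 2)                      # 4 nodes, 2 input features each
labels = th.tensor([[1], [0], [2], [1]])  # shape [N, 1], as in this script
idx = th.tensor([0, 2])                   # reveal labels for these nodes only

onehot = th.zeros(feat.shape[0], n_classes)
onehot[idx, labels[idx, 0]] = 1
print(th.cat([feat, onehot], dim=-1))     # [4, 5]; rows 1 and 3 end in zeros
```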


+def adjust_learning_rate(optimizer, lr, epoch):
+    if epoch <= 50:
+        for param_group in optimizer.param_groups:
+            param_group["lr"] = lr * epoch / 50
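The relocated `adjust_learning_rate` is a linear warmup: over the first 50 epochs the rate ramps from `lr / 50` to the full `lr`, after which this function leaves the optimizer alone and the separate `lr_scheduler.step(loss)` call in `run` handles any decay. Worked values, taking `lr = 0.005` purely as an example:

```python
lr = 0.005  # example value only; the script's actual default may differ
for epoch in (1, 25, 50, 51):
    effective = lr * epoch / 50 if epoch <= 50 else lr
    print(epoch, effective)  # 1 0.0001, 25 0.0025, 50 0.005, 51 0.005
```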


def train(model, graph, labels, train_idx, optimizer, use_labels):
    model.train()

@@ -54,6 +67,7 @@ def train(model, graph, labels, train_idx, optimizer, use_labels):
    else:
        mask_rate = 0.5
        mask = th.rand(train_idx.shape) < mask_rate
+
        train_pred_idx = train_idx[mask]

    optimizer.zero_grad()
@@ -66,7 +80,7 @@ def train(model, graph, labels, train_idx, optimizer, use_labels):


@th.no_grad()
-def evaluate(model, graph, labels, train_idx, val_idx, test_idx, use_labels):
+def evaluate(model, graph, labels, train_idx, val_idx, test_idx, use_labels, evaluator):
    model.eval()

    feat = graph.ndata["feat"]
@@ -80,32 +94,16 @@ def evaluate(model, graph, labels, train_idx, val_idx, test_idx, use_labels):
    test_loss = cross_entropy(pred[test_idx], labels[test_idx])

    return (
-        compute_acc(pred[train_idx], labels[train_idx]),
-        compute_acc(pred[val_idx], labels[val_idx]),
-        compute_acc(pred[test_idx], labels[test_idx]),
+        compute_acc(pred[train_idx], labels[train_idx], evaluator),
+        compute_acc(pred[val_idx], labels[val_idx], evaluator),
+        compute_acc(pred[test_idx], labels[test_idx], evaluator),
        train_loss,
        val_loss,
        test_loss,
    )


-def adjust_learning_rate(optimizer, lr, epoch):
-    if epoch <= 50:
-        for param_group in optimizer.param_groups:
-            param_group["lr"] = lr * epoch / 50
-
-
-def gen_model(args):
-    if args.use_labels:
-        model = GCN(
-            in_feats + n_classes, args.n_hidden, n_classes, args.n_layers, F.relu, args.dropout, args.use_linear
-        )
-    else:
-        model = GCN(in_feats, args.n_hidden, n_classes, args.n_layers, F.relu, args.dropout, args.use_linear)
-    return model


-def run(args, graph, labels, train_idx, val_idx, test_idx, n_running):
+def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running):
    # define model and optimizer
    model = gen_model(args)
    model = model.to(device)
@@ -128,27 +126,28 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, n_running):
        adjust_learning_rate(optimizer, args.lr, epoch)

        loss, pred = train(model, graph, labels, train_idx, optimizer, args.use_labels)
-        acc = compute_acc(pred[train_idx], labels[train_idx])
+        acc = compute_acc(pred[train_idx], labels[train_idx], evaluator)

        train_acc, val_acc, test_acc, train_loss, val_loss, test_loss = evaluate(
-            model, graph, labels, train_idx, val_idx, test_idx, args.use_labels
+            model, graph, labels, train_idx, val_idx, test_idx, args.use_labels, evaluator
        )

        lr_scheduler.step(loss)

        toc = time.time()
        total_time += toc - tic
-        train_acc, val_acc, test_acc, train_loss, val_loss, test_loss = evaluate(
-            model, graph, labels, train_idx, val_idx, test_idx, args.use_labels
-        )
Review comment from @mufeili (Member), Sep 5, 2020: Seems to be a duplicate by accident. @espylapiza
        # if val_acc > best_val_acc:
        if val_loss < best_val_loss:
-            best_val_loss = val_loss.item()
-            best_val_acc = val_acc.item()
-            best_test_acc = test_acc.item()
+            best_val_loss = val_loss
+            best_val_acc = val_acc
+            best_test_acc = test_acc

        if epoch % args.log_every == 0:
            print(f"Epoch: {epoch}/{args.n_epochs}")
            print(
-                f"Loss: {loss.item():.4f}, Acc: {acc.item():.4f}\n"
+                f"Loss: {loss.item():.4f}, Acc: {acc:.4f}\n"
                f"Train/Val/Test loss: {train_loss:.4f}/{val_loss:.4f}/{test_loss:.4f}\n"
                f"Train/Val/Test/Best val/Best test acc: {train_acc:.4f}/{val_acc:.4f}/{test_acc:.4f}/{best_val_acc:.4f}/{best_test_acc:.4f}"
            )
@@ -234,10 +233,11 @@ def main():

    # load data
    data = DglNodePropPredDataset(name="ogbn-arxiv")
+    evaluator = Evaluator(name="ogbn-arxiv")

    splitted_idx = data.get_idx_split()
    train_idx, val_idx, test_idx = splitted_idx["train"], splitted_idx["valid"], splitted_idx["test"]
    graph, labels = data[0]
-    labels = labels[:, 0]

    # add reverse edges
    srcs, dsts = graph.all_edges()
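The collapsed lines in this hunk presumably finish symmetrizing the graph: ogbn-arxiv is a directed citation graph, and GCN-style message passing generally wants each edge in both directions. A hedged sketch of that step using the standard DGL API, not the PR's literal code:

```python
# Sketch only: make every citation edge bidirectional.
srcs, dsts = graph.all_edges()
graph.add_edges(dsts, srcs)
```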
@@ -263,7 +263,7 @@
    test_accs = []

    for i in range(args.n_runs):
-        val_acc, test_acc = run(args, graph, labels, train_idx, val_idx, test_idx, i)
+        val_acc, test_acc = run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, i)
        val_accs.append(val_acc)
        test_accs.append(test_acc)
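The collapsed tail of `main` presumably aggregates these lists; the README table reports mean ± standard deviation over 10 runs, which would be computed along these lines (a sketch, not the PR's code):

```python
import numpy as np

# Summarize the per-run accuracies gathered in the loop above.
print(f"Val acc:  {np.mean(val_accs):.4f} ± {np.std(val_accs):.4f}")
print(f"Test acc: {np.mean(test_accs):.4f} ± {np.std(test_accs):.4f}")
```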
