
[Example] Label Propagation and Correct&Smooth #2852

Merged · 16 commits · May 12, 2021
examples/README.md — 15 changes: 12 additions & 3 deletions

@@ -95,12 +95,17 @@ The folder contains example implementations of selected research papers related
| [DeeperGCN: All You Need to Train Deeper GCNs](#deepergcn) | | | :heavy_check_mark: | | :heavy_check_mark: |
| [Diffusion Convolutional Recurrent Neural Network: Data-Driven Traffic Forecasting](#dcrnn) | | | :heavy_check_mark: | | |
| [GaAN: Gated Attention Networks for Learning on Large and Spatiotemporal Graphs](#gaan) | | | :heavy_check_mark: | | |
| [Combining Label Propagation and Simple Models Out-performs Graph Neural Networks](#correct_and_smooth) | :heavy_check_mark: | | | | :heavy_check_mark: |
| [Learning from Labeled and Unlabeled Data with Label Propagation](#label_propagation) | :heavy_check_mark: | | | | |

## 2021

- <a name="bgnn"></a> Ivanov et al. Boost then Convolve: Gradient Boosting Meets Graph Neural Networks. [Paper link](https://openreview.net/forum?id=ebS5NUfoMKL).
- Example code: [PyTorch](../examples/pytorch/bgnn)
- Tags: semi-supervised node classification, tabular data, GBDT
- <a name="correct_and_smooth"></a> Huang et al. Combining Label Propagation and Simple Models Out-performs Graph Neural Networks. [Paper link](https://arxiv.org/abs/2010.13993).
- Example code: [PyTorch](../examples/pytorch/correct_and_smooth)
- Tags: efficiency, node classification, label propagation

## 2020

@@ -142,7 +147,7 @@ The folder contains example implementations of selected research papers related
- Tags: molecules, molecular property prediction, quantum chemistry
- <a name="tgn"></a> Rossi et al. Temporal Graph Networks For Deep Learning on Dynamic Graphs. [Paper link](https://arxiv.org/abs/2006.10637).
- Example code: [PyTorch](../examples/pytorch/tgn)
- Tags: over-smoothing, node classification
- Tags: temporal, node classification
- <a name="compgcn"></a> Vashishth, Shikhar, et al. Composition-based Multi-Relational Graph Convolutional Networks. [Paper link](https://arxiv.org/abs/1911.03082).
- Example code: [PyTorch](../examples/pytorch/compGCN)
- Tags: multi-relational graphs, graph neural network
@@ -152,7 +157,6 @@ The folder contains example implementations of selected research papers related

## 2019


- <a name="infograph"></a> Sun et al. InfoGraph: Unsupervised and Semi-supervised Graph-Level Representation Learning via Mutual Information Maximization. [Paper link](https://arxiv.org/abs/1908.01000).
- Example code: [PyTorch](../examples/pytorch/infograph)
- Tags: semi-supervised graph regression, unsupervised graph classification
@@ -229,7 +233,6 @@ The folder contains example implementations of selected research papers related
- Example code: [PyTorch](../examples/pytorch/gnn_explainer)
- Tags: Graph Neural Network, Explainability


## 2018

- <a name="dgmg"></a> Li et al. Learning Deep Generative Models of Graphs. [Paper link](https://arxiv.org/abs/1803.03324).
@@ -419,6 +422,12 @@ The folder contains example implementations of selected research papers related
- Example code: [PyTorch](../examples/pytorch/graph_matching)
- Tags: graph edit distance, graph matching

## 2002

- <a name="label_propagation"></a> Zhu & Ghahramani. Learning from Labeled and Unlabeled Data with Label Propagation. [Paper link](https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.14.3864&rep=rep1&type=pdf). A minimal sketch of the propagation rule follows this entry.
- Example code: [PyTorch](../examples/pytorch/label_propagation)
- Tags: node classification, label propagation
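
For intuition, here is a minimal dense-matrix sketch of the propagation rule from this paper. The helper below is an illustrative assumption, not the linked example's API: repeatedly average neighboring label distributions and re-clamp the known labels after every step.

```python
import torch
import torch.nn.functional as F

def label_propagation(adj, labels, labeled_idx, num_iters=50):
    # adj: dense (N, N) adjacency; labels: (N,) integer class ids;
    # labeled_idx: indices of nodes whose labels are known
    n, c = adj.size(0), int(labels.max()) + 1
    T = adj / adj.sum(1, keepdim=True).clamp(min=1)  # row-normalized transition matrix
    y = torch.full((n, c), 1.0 / c)                  # uniform prior on unlabeled nodes
    y[labeled_idx] = F.one_hot(labels[labeled_idx], c).float()
    clamp = y[labeled_idx].clone()
    for _ in range(num_iters):
        y = T @ y                # propagate one step
        y[labeled_idx] = clamp   # re-clamp the known labels
    return y.argmax(1)
```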

## 1998

- <a name="pagerank"></a> Page et al. The PageRank Citation Ranking: Bringing Order to the Web. [Paper link](http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.38.5427).
examples/pytorch/correct_and_smooth/README.md — 75 changes: 75 additions & 0 deletions

@@ -0,0 +1,75 @@
# DGL Implementation of CorrectAndSmooth

This DGL example implements the method proposed in the paper [Combining Label Propagation and Simple Models Out-performs Graph Neural Networks](https://arxiv.org/abs/2010.13993). For the original implementation, see [here](https://github.com/CUAI/CorrectAndSmooth).
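
Correct & Smooth trains a cheap base predictor (an MLP or a linear model) on node features alone and uses the graph only for post-processing: first the residual errors observed on labeled nodes are propagated to correct the base predictions, then the corrected label distributions are smoothed by further propagation. The snippet below is a minimal dense-matrix sketch of these two steps for intuition only; the helper names, the dense normalized adjacency `A_hat`, the one-hot `y_true`, and the fixed-scale correction variant are assumptions of this sketch, not the example's API. The example itself uses the `CorrectAndSmooth` module defined in `model.py` on a `DGLGraph`.

```python
import torch

def propagate(A_hat, x, alpha, num_layers):
    # personalized-PageRank-style iteration: x <- (1 - alpha) * x0 + alpha * A_hat @ x
    x0 = x
    for _ in range(num_layers):
        x = (1 - alpha) * x0 + alpha * (A_hat @ x)
    return x

def correct_and_smooth(A_hat, y_soft, y_true, train_idx,
                       correction_alpha=0.979, smoothing_alpha=0.756,
                       scale=20.0, num_layers=50):
    # correct: propagate the residual error observed on the training nodes
    err = torch.zeros_like(y_soft)
    err[train_idx] = y_true[train_idx] - y_soft[train_idx]
    y = y_soft + scale * propagate(A_hat, err, correction_alpha, num_layers)
    # smooth: overwrite the training rows with ground truth, then propagate again
    y[train_idx] = y_true[train_idx]
    return propagate(A_hat, y, smoothing_alpha, num_layers)
```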

Contributor: [xnuohz](https://github.com/xnuohz)

### Requirements
The codebase is implemented in Python 3.7. For the version requirements of the packages, see below.

```
dgl 0.6.0.post1
torch 1.7.0
ogb 1.3.0
```

### The graph datasets used in this example

Open Graph Benchmark (OGB). Dataset summary:

| Dataset | #Nodes | #Edges | #Node Feats | Metric |
| :-----------: | :-------: | :--------: | :---------: | :------: |
| ogbn-arxiv | 169,343 | 1,166,243 | 128 | Accuracy |
| ogbn-products | 2,449,029 | 61,859,140 | 100 | Accuracy |
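
Both datasets are loaded through OGB's DGL interface, exactly as `main.py` does:

```python
from ogb.nodeproppred import DglNodePropPredDataset

dataset = DglNodePropPredDataset(name='ogbn-arxiv')
split_idx = dataset.get_idx_split()  # dict with train / valid / test node indices
g, labels = dataset[0]               # a DGLGraph and a (num_nodes, num_tasks) label tensor
```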

### Usage

Train a **base predictor** first, then apply **Correct & Smooth**, following the original hyperparameters for each dataset.

##### ogbn-arxiv

* **MLP + C&S**

```bash
python main.py --dropout 0.5
python main.py --pretrain --correction-adj DA --smoothing-adj AD
```

* **Linear + C&S**

```bash
python main.py --model linear --dropout 0.5 --epochs 1000
python main.py --model linear --pretrain --correction-alpha 0.8 --smoothing-alpha 0.6 --correction-adj AD
```

##### ogbn-products

* **Linear + C&S**

```bash
python main.py --dataset ogbn-products --model linear --dropout 0.5 --epochs 1000 --lr 0.1
python main.py --dataset ogbn-products --model linear --pretrain --correction-alpha 0.6 --smoothing-alpha 0.9
```

### Performance

#### ogbn-arxiv

| | MLP | MLP + C&S | Linear | Linear + C&S |
| :-------------: | :---: | :-------: | :----: | :----------: |
| Results (Author) | 55.58 | 68.72 | 51.06 | 70.24 |
| Results (DGL)    | 56.12 | 68.63 | 52.49 | 71.69 |

#### ogbn-products

| | Linear | Linear + C&S |
| :-------------: | :----: | :----------: |
| Results (Author) | 47.67 | 82.34 |
| Results (DGL)    | 47.71 | 79.57 |

### Speed

| ogbn-arxiv           | Time   | GPU Memory | Params |
| :------------------: | :----: | :--------: | :----: |
| Author, Linear + C&S | 6.3e-3 | 1,248M     | 5,160  |
| DGL, Linear + C&S    | 5.6e-3 | 1,252M     | 5,160  |
examples/pytorch/correct_and_smooth/main.py — 170 changes: 170 additions & 0 deletions

@@ -0,0 +1,170 @@
import argparse
import copy
import os
import torch
import torch.nn.functional as F
import torch.optim as optim
import dgl
from ogb.nodeproppred import DglNodePropPredDataset, Evaluator
from model import MLP, MLPLinear, CorrectAndSmooth


def evaluate(y_pred, y_true, idx, evaluator):
    return evaluator.eval({
        'y_true': y_true[idx],
        'y_pred': y_pred[idx]
    })['acc']


def main():
    # check cuda
    device = f'cuda:{args.gpu}' if torch.cuda.is_available() and args.gpu >= 0 else 'cpu'
    # load data
    dataset = DglNodePropPredDataset(name=args.dataset)
    evaluator = Evaluator(name=args.dataset)

    split_idx = dataset.get_idx_split()
    g, labels = dataset[0]  # graph: DGLGraph object, labels: torch tensor of shape (num_nodes, num_tasks)

    if args.dataset == 'ogbn-arxiv':
        # ogbn-arxiv is directed; make it bidirected and standardize the node features
        g = dgl.to_bidirected(g, copy_ndata=True)

        feat = g.ndata['feat']
        feat = (feat - feat.mean(0)) / feat.std(0)
        g.ndata['feat'] = feat

    g = g.to(device)
    feats = g.ndata['feat']
    labels = labels.to(device)

    # load masks for train / validation / test
    train_idx = split_idx["train"].to(device)
    valid_idx = split_idx["valid"].to(device)
    test_idx = split_idx["test"].to(device)

    n_features = feats.size()[-1]
    n_classes = dataset.num_classes

    # load model
    if args.model == 'mlp':
        model = MLP(n_features, args.hid_dim, n_classes, args.num_layers, args.dropout)
    elif args.model == 'linear':
        model = MLPLinear(n_features, n_classes)
    else:
        raise NotImplementedError(f'Model {args.model} is not supported.')

    model = model.to(device)
    print(f'Model parameters: {sum(p.numel() for p in model.parameters())}')

    if args.pretrain:
        print('---------- Before ----------')
        model.load_state_dict(torch.load(f'base/{args.dataset}-{args.model}.pt'))
        model.eval()

        # the base predictor outputs log-probabilities; exp() recovers class probabilities
        y_soft = model(feats).exp()

        y_pred = y_soft.argmax(dim=-1, keepdim=True)
        valid_acc = evaluate(y_pred, labels, valid_idx, evaluator)
        test_acc = evaluate(y_pred, labels, test_idx, evaluator)
        print(f'Valid acc: {valid_acc:.4f} | Test acc: {test_acc:.4f}')

        print('---------- Correct & Smoothing ----------')
        cs = CorrectAndSmooth(num_correction_layers=args.num_correction_layers,
                              correction_alpha=args.correction_alpha,
                              correction_adj=args.correction_adj,
                              num_smoothing_layers=args.num_smoothing_layers,
                              smoothing_alpha=args.smoothing_alpha,
                              smoothing_adj=args.smoothing_adj,
                              scale=args.scale)

        mask_idx = torch.cat([train_idx, valid_idx])
        y_soft = cs.correct(g, y_soft, labels[mask_idx], mask_idx)
        y_soft = cs.smooth(g, y_soft, labels[mask_idx], mask_idx)
        y_pred = y_soft.argmax(dim=-1, keepdim=True)
        valid_acc = evaluate(y_pred, labels, valid_idx, evaluator)
        test_acc = evaluate(y_pred, labels, test_idx, evaluator)
        print(f'Valid acc: {valid_acc:.4f} | Test acc: {test_acc:.4f}')
    else:
        opt = optim.Adam(model.parameters(), lr=args.lr)

        best_acc = 0
        best_model = copy.deepcopy(model)

        # training
        print('---------- Training ----------')
        for i in range(args.epochs):
            model.train()
            opt.zero_grad()

            logits = model(feats)
            train_loss = F.nll_loss(logits[train_idx], labels.squeeze(1)[train_idx])
            train_loss.backward()
            opt.step()

            model.eval()
            with torch.no_grad():
                logits = model(feats)
                y_pred = logits.argmax(dim=-1, keepdim=True)
                train_acc = evaluate(y_pred, labels, train_idx, evaluator)
                valid_acc = evaluate(y_pred, labels, valid_idx, evaluator)

            print(f'Epoch {i} | Train loss: {train_loss.item():.4f} | Train acc: {train_acc:.4f} | Valid acc: {valid_acc:.4f}')

            if valid_acc > best_acc:
                best_acc = valid_acc
                best_model = copy.deepcopy(model)

        # testing & saving model
        print('---------- Testing ----------')
        best_model.eval()

        logits = best_model(feats)
        y_pred = logits.argmax(dim=-1, keepdim=True)
        test_acc = evaluate(y_pred, labels, test_idx, evaluator)
        print(f'Test acc: {test_acc:.4f}')

        if not os.path.exists('base'):
            os.makedirs('base')

        torch.save(best_model.state_dict(), f'base/{args.dataset}-{args.model}.pt')


if __name__ == '__main__':
    # Correct & Smooth hyperparameters
    parser = argparse.ArgumentParser(description='Base predictor (C&S)')

    # Dataset
    parser.add_argument('--gpu', type=int, default=0, help='-1 for cpu')
    parser.add_argument('--dataset', type=str, default='ogbn-arxiv', choices=['ogbn-arxiv', 'ogbn-products'])
    # Base predictor
    parser.add_argument('--model', type=str, default='mlp', choices=['mlp', 'linear'])
    parser.add_argument('--num-layers', type=int, default=3)
    parser.add_argument('--hid-dim', type=int, default=256)
    parser.add_argument('--dropout', type=float, default=0.4)
    parser.add_argument('--lr', type=float, default=0.01)
    parser.add_argument('--epochs', type=int, default=300)
    # extra options for GAT (not used by the MLP / linear models in this example)
    parser.add_argument('--n-heads', type=int, default=3)
    parser.add_argument('--attn_drop', type=float, default=0.05)
    # C & S
    parser.add_argument('--pretrain', action='store_true', help='Whether to perform C & S')
    parser.add_argument('--num-correction-layers', type=int, default=50)
    parser.add_argument('--correction-alpha', type=float, default=0.979)
    parser.add_argument('--correction-adj', type=str, default='DAD')
    parser.add_argument('--num-smoothing-layers', type=int, default=50)
    parser.add_argument('--smoothing-alpha', type=float, default=0.756)
    parser.add_argument('--smoothing-adj', type=str, default='DAD')
    parser.add_argument('--scale', type=float, default=20.)

    args = parser.parse_args()
    print(args)

    main()