
[Model] reformat and update performance of grace #3009

Merged
merged 4 commits on Jun 15, 2021
Changes from 2 commits
38 changes: 30 additions & 8 deletions examples/pytorch/grace/README.md
@@ -32,25 +32,47 @@ This example was implemented by [Hengrui Zhang](https://github.com/hengruizhang9
--dataname str The graph dataset name. Default is 'cora'.
--gpu int GPU index. Default is 0.
--split str Dataset splitting method. Default is 'random'.
--epochs int Number of training periods. Default is 500.
--lr float Learning rate. Default is 0.001.
--wd float Weight decay. Default is 1e-5.
--temp float Temperature. Default is 1.0.
--act_fn str Activation function. Default is relu.
--hid_dim int Hidden dimension. Default is 256.
--out_dim int Output dimension. Default is 256.
--num_layers int Number of GNN layers. Default is 2.
--der1 float Drop edge ratio 1. Default is 0.2.
--der2 float Drop edge ratio 2. Default is 0.2.
--dfr1 float Drop feature ratio 1. Default is 0.2.
--dfr2 float Drop feature ratio 2. Default is 0.2.
```

## How to run examples

In the paper (as well as the authors' repo), the training and test sets are split randomly with a 1:9 ratio. To compare fairly with other models under the split they used, termed the public split, this repo also provides experiment results using the public split. To run the examples,
In the paper (as well as the authors' repo), the training and test sets are split randomly with a 1:9 ratio. To compare fairly with other methods under the public split (20 training nodes per class), this repo also provides its results using the public split (with fine-tuned hyperparameters). To run the examples, follow the instructions below.

```python
# Cora with random split
python main.py --dataname cora
python main.py --dataname cora --epochs 200 --lr 5e-4 --wd 1e-5 --hid_dim 128 --out_dim 128 --act_fn relu --der1 0.2 --der2 0.4 --dfr1 0.3 --dfr2 0.4 --temp 0.4

# Cora with public split
python main.py --dataname cora --split public
python main.py --dataname cora --split public --epochs 400 --lr 5e-4 --wd 1e-5 --hid_dim 256 --out_dim 256 --act_fn relu --der1 0.3 --der2 0.4 --dfr1 0.3 --dfr2 0.4 --temp 0.4

# Citeseer with random split
python main.py --dataname citeseer --epochs 200 --lr 1e-3 --wd 1e-5 --hid_dim 256 --out_dim 256 --act_fn prelu --der1 0.2 --der2 0.0 --dfr1 0.3 --dfr2 0.2 --temp 0.9

# Citeseer with public split
python main.py --dataname citeseer --split public --epochs 100 --lr 1e-3 --wd 1e-5 --hid_dim 512 --out_dim 512 --act_fn prelu --der1 0.3 --der2 0.3 --dfr1 0.3 --dfr2 0.3 --temp 0.4

replace 'cora' with 'citeseer' or 'pubmed' if you would like to run this example for other datasets.
# Pubmed with random split
python main.py --dataname pubmed --epochs 1500 --lr 1e-3 --wd 1e-5 --hid_dim 256 --out_dim 256 --act_fn relu --der1 0.4 --der2 0.1 --dfr1 0.0 --dfr2 0.2 --temp 0.7

# Pubmed with public split
python main.py --dataname pubmed --split public --epochs 1500 --lr 1e-3 --wd 1e-5 --hid_dim 256 --out_dim 256 --act_fn relu --der1 0.4 --der2 0.1 --dfr1 0.0 --dfr2 0.2 --temp 0.7
```

## Performance

We use the same hyperparameter settings as provided by the author; you can check config.yaml for the detailed hyperparameters for each dataset.
For the random split, we use the hyperparameters stated in the paper. For the public split, we found that the given hyperparameters led to poor performance, so we selected the hyperparameters via a small grid search.
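
The exact grid is not listed in this PR; the following is only a minimal sketch of how such a search could be scripted against the CLI flags that main.py now exposes, with a purely hypothetical search space.

```python
# Hypothetical grid search over a few of the CLI flags exposed by main.py.
# The actual grid used for the public-split numbers is not documented in this PR.
import itertools
import subprocess

search_space = {
    "--lr": ["5e-4", "1e-3"],
    "--temp": ["0.4", "0.7", "0.9"],
    "--der1": ["0.2", "0.3", "0.4"],
    "--dfr1": ["0.2", "0.3"],
}

keys = list(search_space)
for values in itertools.product(*(search_space[k] for k in keys)):
    cmd = ["python", "main.py", "--dataname", "cora", "--split", "public"]
    for flag, value in zip(keys, values):
        cmd += [flag, value]
    # Each run prints its final evaluation; collect the accuracy from stdout or a log.
    print(" ".join(cmd))
    subprocess.run(cmd, check=True)
```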

Random split (Train/Test = 1:9)

@@ -64,6 +86,6 @@ Public split

| Dataset | Cora | Citeseer | Pubmed |
| :-----------: | :--: | :------: | :----: |
| Author's Code | 79.9 | 68.6 | 81.3 |
| DGL | 80.1 | 68.9 | 81.2 |
| Author's Code | 81.9 | 71.2 | 80.6 |
| DGL | 82.2 | 71.4 | 80.2 |

39 changes: 0 additions & 39 deletions examples/pytorch/grace/config.yaml

This file was deleted.

69 changes: 43 additions & 26 deletions examples/pytorch/grace/main.py
@@ -1,65 +1,81 @@
import argparse

from model import Grace
from aug import aug
from dataset import load

import numpy as np
import torch as th
import torch.nn as nn

import yaml
from yaml import SafeLoader

from eval import label_classification
import warnings

warnings.filterwarnings('ignore')


def count_parameters(model):
return sum([np.prod(p.size()) for p in model.parameters() if p.requires_grad])


parser = argparse.ArgumentParser()
parser.add_argument('--dataname', type=str, default='cora', choices = ['cora', 'citeseer', 'pubmed'])
parser.add_argument('--dataname', type=str, default='cora')
parser.add_argument('--gpu', type=int, default=0)
parser.add_argument('--split', type=str, default='random', choices = ['random', 'public'])
parser.add_argument('--split', type=str, default='random')

parser.add_argument('--epochs', type=int, default=500, help='Number of training periods.')
parser.add_argument('--lr', type=float, default=0.001, help='Learning rate.')
parser.add_argument('--wd', type=float, default=1e-5, help='Weight decay.')
parser.add_argument('--temp', type=float, default=1.0, help='Temperature.')

parser.add_argument('--act_fn', type=str, default='relu')

parser.add_argument("--hid_dim", type=int, default=256, help='Hidden layer dim.')
parser.add_argument("--out_dim", type=int, default=256, help='Output layer dim.')

parser.add_argument("--num_layers", type=int, default=2, help='Number of GNN layers.')
parser.add_argument('--der1', type=float, default=0.2, help='Drop edge ratio of the 1st augmentation.')
parser.add_argument('--der2', type=float, default=0.2, help='Drop edge ratio of the 2nd augmentation.')
parser.add_argument('--dfr1', type=float, default=0.2, help='Drop feature ratio of the 1st augmentation.')
parser.add_argument('--dfr2', type=float, default=0.2, help='Drop feature ratio of the 2nd augmentation.')

args = parser.parse_args()

if args.gpu != -1 and th.cuda.is_available():
args.device = 'cuda:{}'.format(args.gpu)
else:
args.device = 'cpu'


if __name__ == '__main__':

# Step 1: Load hyperparameters =================================================================== #
config = 'config.yaml'
config = yaml.load(open(config), Loader=SafeLoader)[args.dataname]
lr = config['learning_rate']
hid_dim = config['num_hidden']
out_dim = config['num_proj_hidden']
lr = args.lr
hid_dim = args.hid_dim
out_dim = args.out_dim

num_layers = config['num_layers']
act_fn = ({'relu': nn.ReLU(), 'prelu': nn.PReLU()})[config['activation']]
num_layers = args.num_layers
act_fn = ({'relu': nn.ReLU(), 'prelu': nn.PReLU()})[args.act_fn]

drop_edge_rate_1 = config['drop_edge_rate_1']
drop_edge_rate_2 = config['drop_edge_rate_2']
drop_feature_rate_1 = config['drop_feature_rate_1']
drop_feature_rate_2 = config['drop_feature_rate_2']
drop_edge_rate_1 = args.der1
drop_edge_rate_2 = args.der2
drop_feature_rate_1 = args.dfr1
drop_feature_rate_2 = args.dfr2

temp = config['tau']
epochs = config['num_epochs']
wd = config['weight_decay']
temp = args.temp
epochs = args.epochs
wd = args.wd

# Step 2: Prepare data =================================================================== #
graph, feat, labels, train_mask, test_mask = load(args.dataname)

in_dim = feat.shape[1]

# Step 3: Create model =================================================================== #
model = Grace(in_dim, hid_dim, out_dim, num_layers, act_fn, temp)
model = model.to(args.device)
print(f'# params: {count_parameters(model)}')

optimizer = th.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)

# Step 4: Training ======================================================================= #
for epoch in range(epochs):
model.train()
optimizer.zero_grad()
@@ -79,11 +95,12 @@
print(f'Epoch={epoch:03d}, loss={loss.item():.4f}')

# Step 5: Linear evaluation ============================================================== #
print("=== Final Evaluation ===")
print("=== Final ===")

graph = graph.add_self_loop()
graph = graph.to(args.device)
feat = feat.to(args.device)
embeds = model.get_embedding(graph, feat)

'''Evaluate the learned embeddings'''
label_classification(embeds, labels, train_mask, test_mask, split=args.split, ratio=0.1)
label_classification(embeds, labels, train_mask, test_mask, split=args.split)
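
The hunk above elides the body of the training loop (between optimizer.zero_grad() and the loss printout). As a rough guide only, here is a sketch of what one GRACE training step typically looks like, assuming aug(graph, feat, drop_feature_rate, drop_edge_rate) returns an augmented (graph, feature) pair and that Grace.forward takes the two augmented views plus their features and returns the contrastive loss; the real code sits in the elided lines and may differ.

```python
# Sketch of the elided loop body (not the literal code from this PR).
graph1, feat1 = aug(graph, feat, drop_feature_rate_1, drop_edge_rate_1)
graph2, feat2 = aug(graph, feat, drop_feature_rate_2, drop_edge_rate_2)

graph1, feat1 = graph1.to(args.device), feat1.to(args.device)
graph2, feat2 = graph2.to(args.device), feat2.to(args.device)

# Forward both augmented views and compute the contrastive loss.
loss = model(graph1, graph2, feat1, feat2)
loss.backward()
optimizer.step()
```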
2 changes: 1 addition & 1 deletion examples/pytorch/grace/model.py
@@ -72,7 +72,7 @@ def sim(self, z1, z2):

def get_loss(self, z1, z2):
# calculate SimCLR loss

f = lambda x: th.exp(x / self.temp)

refl_sim = f(self.sim(z1, z1)) # intra-view pairs
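
The hunk is truncated right after the intra-view similarities. For reference, here is a sketch of how the loss is usually completed from refl_sim and between_sim, following the formulation in the GRACE paper (the positive pair is the same node in the other view; negatives are all other nodes in both views). This is a reconstruction for illustration, not the code hidden behind the truncation.

```python
import torch as th

def grace_loss_from_sims(refl_sim, between_sim):
    """Per-node contrastive loss given exponentiated similarity matrices.

    refl_sim[i, j]    = exp(sim(z1_i, z1_j) / temp)  # intra-view pairs
    between_sim[i, j] = exp(sim(z1_i, z2_j) / temp)  # inter-view pairs
    """
    # Positive pair: the same node across the two views.
    pos = between_sim.diag()
    # Denominator: all inter-view pairs plus all intra-view pairs,
    # excluding each node's similarity with itself in its own view.
    denom = between_sim.sum(1) + refl_sim.sum(1) - refl_sim.diag()
    return -th.log(pos / denom)
```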