
[Example] add ogc method #6437

Merged · 13 commits · Nov 14, 2023
8 changes: 8 additions & 0 deletions examples/README.md
@@ -5,6 +5,14 @@ The folder contains example implementations of selected research papers related
* For examples working with a certain release, check out `https://github.com/dmlc/dgl/tree/<release_version>/examples` (E.g., https://github.com/dmlc/dgl/tree/0.5.x/examples)

To quickly locate the examples of your interest, search for the tagged keywords or use the search tool on [dgl.ai](https://www.dgl.ai/).

## 2023

- <a name="labor"></a> Zheng Wang et al. From Cluster Assumption to Graph Convolution: Graph-based Semi-Supervised Learning Revisited. [Paper link](https://arxiv.org/abs/2210.13339)
- Example code: [PyTorch](../examples/pytorch/ogc)

- Tags: semi-supervised node classification

## 2022
- <a name="labor"></a> Balin et al. Layer-Neighbor Sampling -- Defusing Neighborhood Explosion in GNNs. [Paper link](https://arxiv.org/abs/2210.13339)
- Example code: [PyTorch](../examples/labor/train_lightning.py)
37 changes: 37 additions & 0 deletions examples/pytorch/ogc/README.md
@@ -0,0 +1,37 @@
# Optimized Graph Convolution (OGC)

This DGL example implements the OGC method from the paper: [From Cluster Assumption to Graph Convolution: Graph-based Semi-Supervised Learning Revisited](https://arxiv.org/abs/2309.13599).
With only one trainable layer, OGC is a very simple but powerful graph convolution method.
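
Sketching the update that `ogc.py` implements (with $\hat{A}$ the symmetrically normalized adjacency matrix, $S$ a diagonal mask over the labeled nodes, $Y$ the one-hot labels, and $\hat{Y}$, $W$ the predictions and weight matrix of the one-layer linear classifier), each iteration refines the node embeddings $U$ with a lazy graph convolution (LGC) step followed by a supervised correction (SEB):

$$
U \leftarrow \big((1-\beta)I + \beta\hat{A}\big)U - \eta_{\mathrm{sup}} \cdot 2\,S(\hat{Y}-Y)W
$$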


## Example Implementor

This example was implemented by [Sinuo Xu](https://github.com/SinuoXu) when she was an undergraduate at SJTU.



## Dataset

This example uses DGL's built-in Cora, Citeseer and Pubmed datasets, summarized below:

| Dataset | #Nodes | #Edges | #Feats | #Classes | #Train Nodes | #Val Nodes | #Test Nodes |
| :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: |
| Citeseer | 3,327 | 9,228 | 3,703 | 6 | 120 | 500 | 1,000 |
| Cora | 2,708 | 10,556 | 1,433 | 7 | 140 | 500 | 1,000 |
| Pubmed | 19,717 | 88,651 | 500 | 3 | 60 | 500 | 1,000 |


## Usage

```bash
python ogc.py --dataset cora
python ogc.py --dataset citeseer
python ogc.py --dataset pubmed
```
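
Each run prints the per-epoch train/val and test accuracy, followed by the final test accuracy and the total running time.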

## Performance

| Dataset | Cora | Citeseer | Pubmed |
| :-: | :-: | :-: | :-: |
| OGC (DGL) | **86.9(±0.2)** | **77.4(±0.1)** | **83.6(±0.1)** |
| OGC (Reported) | **86.9(±0.0)** | **77.4(±0.0)** | 83.4(±0.0) |
103 changes: 103 additions & 0 deletions examples/pytorch/ogc/ogc.py
@@ -0,0 +1,103 @@
import argparse
import time

import scipy.sparse as sp
import torch
import torch.nn.functional as F

from dgl import AddSelfLoop
from dgl.data import CiteseerGraphDataset, CoraGraphDataset, PubmedGraphDataset

from utils import (
    LinearNeuralNetwork,
    sparse_mx_to_torch_sparse_tensor,
    symmetric_normalize_adjacency,
)


# Training settings.
decline = 0.9  # the decline rate of the supervised learning rate
eta_sup = 0.001  # the learning rate for the supervised loss
eta_W = 0.5  # the learning rate for updating W
beta = 0.1  # in [0, 1], the moving probability that a node moves to its neighbors
max_similar_tol = 0.995  # the max tolerated similarity of test predictions between two iterations
max_patience = 2  # the tolerance for consecutively getting very similar test predictions


def update_U(U, Y, predY, W):
    # Uses the module-level `lazy_adj`, `S`, `device` and `decline` defined
    # below, and decays the global supervised learning rate `eta_sup`.
    global eta_sup

    # ------ Update the smoothness loss via LGC ------
    U = torch.spmm(lazy_adj.to(device), U)

    # ------ Update the supervised loss via SEB ------
    dU_sup = 2 * torch.mm(torch.sparse.mm(S, -Y + predY), W)
    U = U - eta_sup * dU_sup

    eta_sup = eta_sup * decline
    return U


def OGC(linear_clf, U, g):
    patience = 0
    _, _, last_acc, last_outp = linear_clf.test(U, g)
    for i in range(64):
        # Update W by training a simple linear supervised model Y = W * X.
        predY, W = linear_clf.update_W(U, g, eta_W)

        # Update U by LGC and SEB jointly.
        U = update_U(U, F.one_hot(g.ndata["label"]).float(), predY, W)

        loss_tv, acc_tv, acc_test, pred = linear_clf.test(U, g)
        print(
            "epoch {} loss_tv {:.4f} acc_train_val {:.4f} acc_test {:.4f}".format(
                i + 1, loss_tv, acc_tv, acc_test
            )
        )

        # Stop early once the test predictions barely change across iterations.
        sim_rate = float((pred == last_outp).sum()) / pred.shape[0]
        if sim_rate > max_similar_tol:
            patience += 1
            if patience > max_patience:
                break

        last_acc = acc_test
        last_outp = pred
    return last_acc


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
'--dataset',
type=str,
default="citeseer",
choices=["cora", "citeseer", "pubmed"],
help='Dataset to use.')
args, _ = parser.parse_known_args()

# load and preprocess dataset
transform = (AddSelfLoop())
if args.dataset == "cora":
data = CoraGraphDataset(transform=transform)
elif args.dataset == "citeseer":
data = CiteseerGraphDataset(transform=transform)
elif args.dataset == "pubmed":
data = PubmedGraphDataset(transform=transform)
else:
raise ValueError("Unknown dataset: {}".format(args.dataset))

g = data[0]
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
g = g.int().to(device)
features = g.ndata["feat"]

adj = symmetric_normalize_adjacency(g)
I_N = sp.eye(features.shape[0])
# lazy random walk (also known as lazy graph convolution)
lazy_adj = (1 - beta) * I_N + beta * adj
lazy_adj = sparse_mx_to_torch_sparse_tensor(lazy_adj)
frozenbugs marked this conversation as resolved.
Show resolved Hide resolved
# LIM track, else use both train and validation set to construct S
S = torch.diag(g.ndata["train_mask"]).float().to_sparse()

linear_clf = LinearNeuralNetwork(nfeat=g.ndata["feat"].size(1),
nclass=g.ndata["label"].max().item()+1,
bias=False).to(device)

start_time = time.time()
res = OGC(linear_clf, features, g)
time_tot = time.time() - start_time

print(f'Test Acc:{res:.4f}')
print(f'Total Time:{time_tot:.4f}')
66 changes: 66 additions & 0 deletions examples/pytorch/ogc/utils.py
@@ -0,0 +1,66 @@
import numpy as np
import scipy.sparse as sp
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

def sparse_mx_to_torch_sparse_tensor(sparse_mx):
    """Convert a scipy sparse matrix to a torch sparse tensor."""
    sparse_mx = sparse_mx.tocoo().astype(np.float32)
    indices = torch.from_numpy(
        np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64)
    )
    values = torch.from_numpy(sparse_mx.data)
    shape = torch.Size(sparse_mx.shape)
    return torch.sparse_coo_tensor(indices, values, shape)


def symmetric_normalize_adjacency(graph):
    """Symmetrically normalize the graph adjacency matrix: D^-1/2 * A * D^-1/2."""
    adj = graph.adjacency_matrix()
    in_degs = graph.in_degrees().float()
    in_norm = torch.pow(in_degs, -0.5).unsqueeze(-1)
    # Build the sparse diagonal degree matrix D^-1/2.
    degi = torch.diag(torch.squeeze(torch.t(in_norm)))
    degi = sp.coo_matrix(degi.cpu()).tocsr()
    adj = sp.csr_matrix(
        (adj.val.cpu(), (adj.row.cpu(), adj.col.cpu())), shape=adj.shape
    )
    adj = degi.dot(adj.dot(degi))
    return adj


class LinearNeuralNetwork(nn.Module):
    def __init__(self, nfeat, nclass, bias=True):
        super(LinearNeuralNetwork, self).__init__()
        self.W = nn.Linear(nfeat, nclass, bias=bias)

    def forward(self, x):
        return self.W(x)

    def test(self, U, g):
        self.eval()
        with torch.no_grad():
            output = self(U)
            pred = output.argmax(dim=-1)
            labels = g.ndata["label"]
            test_mask = g.ndata["test_mask"]
            tv_mask = g.ndata["train_mask"] | g.ndata["val_mask"]
            loss_tv = F.mse_loss(
                output[tv_mask], F.one_hot(labels).float()[tv_mask]
            )
            accs = []
            for mask in [tv_mask, test_mask]:
                accs.append(
                    float((pred[mask] == labels[mask]).sum() / mask.sum())
                )
            return loss_tv.item(), accs[0], accs[1], pred

    def update_W(self, U, g, eta_W):
        # Take one SGD step on the linear classifier, then return the updated
        # predictions and weight matrix.
        optimizer = optim.SGD(self.parameters(), lr=eta_W)
        self.train()
        optimizer.zero_grad()
        output = self(U)
        labels = g.ndata["label"]
        tv_mask = g.ndata["train_mask"] | g.ndata["val_mask"]
        loss_tv = F.mse_loss(
            output[tv_mask],
            F.one_hot(labels).float()[tv_mask],
            reduction="sum",
        )
        loss_tv.backward()
        optimizer.step()
        return self(U).data, self.W.weight.data