
[Example] add ogc method #6437

Merged: 13 commits, Nov 14, 2023
8 changes: 8 additions & 0 deletions examples/README.md
@@ -5,6 +5,14 @@ The folder contains example implementations of selected research papers related
* For examples working with a certain release, check out `https://github.com/dmlc/dgl/tree/<release_version>/examples` (E.g., https://github.com/dmlc/dgl/tree/0.5.x/examples)

To quickly locate the examples of your interest, search for the tagged keywords or use the search tool on [dgl.ai](https://www.dgl.ai/).

## 2023

- <a name="ogc"></a> Zheng Wang et al. From Cluster Assumption to Graph Convolution: Graph-based Semi-Supervised Learning Revisited. [Paper link](https://arxiv.org/abs/2309.13599)
- Example code: [PyTorch](../examples/pytorch/ogc)
- Tags: semi-supervised node classification

## 2022
- <a name="labor"></a> Balin et al. Layer-Neighbor Sampling -- Defusing Neighborhood Explosion in GNNs. [Paper link](https://arxiv.org/abs/2210.13339)
- Example code: [PyTorch](../examples/labor/train_lightning.py)
44 changes: 44 additions & 0 deletions examples/pytorch/ogc/README.md
@@ -0,0 +1,44 @@
# Optimized Graph Convolution (OGC)

This DGL example implements the OGC method from the paper: [From Cluster Assumption to Graph Convolution: Graph-based Semi-Supervised Learning Revisited](https://arxiv.org/abs/2309.13599).
With only one trainable layer, OGC is a very simple but powerful graph convolution method.
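
For reference, the embedding update that `update_embeds` in `ogc.py` (below) iterates can be sketched as follows. The notation here is ours, reconstructed from the code rather than copied from the paper: $U$ denotes the node embeddings, $\hat{A}$ the symmetrically normalized adjacency matrix, $S$ the diagonal indicator of training nodes, $Y$ the one-hot labels, $W$ the classifier weight, $\beta$ the lazy-walk probability, and $\eta_t$ the decaying supervised learning rate:

$$
U^{(t+1)} = \underbrace{\big((1-\beta)I + \beta\hat{A}\big)U^{(t)}}_{\text{LGC smoothing}} - \eta_t \cdot \underbrace{2\,S\big(U^{(t)}W^{\top} - Y\big)W}_{\text{SEB correction}}, \qquad \eta_{t+1} = \gamma\,\eta_t
$$

where $\gamma$ corresponds to the `--decline` flag in `train.py`.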


## Example Implementor

This example was implemented by [Sinuo Xu](https://github.com/SinuoXu) when she was an undergraduate at SJTU.


## Dependencies

- Python 3.11.5
- PyTorch 2.0.1
- DGL 1.1.2
- scikit-learn 1.3.1
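
One way to install them with pip might be the following (a sketch assuming CPU builds; DGL's CUDA wheels come from a separate package index, so adjust for your platform):

```bash
pip install torch==2.0.1 dgl==1.1.2 scikit-learn==1.3.1
```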


## Dataset

This example uses DGL's built-in Cora, Citeseer, and Pubmed datasets, summarized below:

| Dataset | #Nodes | #Edges | #Feats | #Classes | #Train Nodes | #Val Nodes | #Test Nodes |
| :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: |
| Citeseer | 3,327 | 9,228 | 3,703 | 6 | 120 | 500 | 1,000 |
| Cora | 2,708 | 10,556 | 1,433 | 7 | 140 | 500 | 1,000 |
| Pubmed | 19,717 | 88,651 | 500 | 3 | 60 | 500 | 1,000 |


## Usage

```bash
python train.py --dataset cora
python train.py --dataset citeseer
python train.py --dataset pubmed
```
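
All hyperparameters have defaults (see the argparse block in `train.py` below); for example, a run on GPU with a few of them set explicitly might look like:

```bash
python train.py --dataset pubmed --device cuda --beta 0.1 --lr_sup 0.001 --lr_clf 0.5
```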

## Performance

| Dataset | Cora | Citeseer | Pubmed |
| :-: | :-: | :-: | :-: |
| OGC (DGL) | **86.9(±0.2)** | **77.4(±0.1)** | **83.6(±0.1)** |
| OGC (Reported) | **86.9(±0.0)** | **77.4(±0.0)** | 83.4(±0.0) |
44 changes: 44 additions & 0 deletions examples/pytorch/ogc/ogc.py
@@ -0,0 +1,44 @@
import dgl.sparse as dglsp
import torch.nn as nn
import torch.nn.functional as F

from utils import LinearNeuralNetwork


class OGC(nn.Module):
    """Optimized Graph Convolution: a single trainable linear classifier
    whose input embeddings are iteratively refined by LGC and SEB updates."""

    def __init__(self, graph):
        super(OGC, self).__init__()
        self.linear_clf = LinearNeuralNetwork(
            nfeat=graph.ndata["feat"].shape[1],
            nclass=graph.ndata["label"].max().item() + 1,
            bias=False,
        )

        self.label = graph.ndata["label"]
        self.label_one_hot = F.one_hot(graph.ndata["label"]).float()
        # Diagonal indicator of labeled nodes. With the LIM trick, only the
        # train set is used; otherwise, use both the train and val sets to
        # construct this matrix.
        self.label_idx_mat = dglsp.diag(graph.ndata["train_mask"]).float()

        self.test_mask = graph.ndata["test_mask"]
        self.tv_mask = graph.ndata["train_mask"] + graph.ndata["val_mask"]

    def forward(self, x):
        return self.linear_clf(x)

    def update_embeds(self, embeds, lazy_adj, args):
        """Update node embeddings with one smoothing step and one
        supervised correction step."""
        pred_label = self(embeds).data
        clf_weight = self.linear_clf.W.weight.data

        # Smoothness update via LGC (lazy graph convolution).
        embeds = dglsp.spmm(lazy_adj, embeds)

        # Supervised update via SEB: gradient of the squared error on the
        # labeled nodes with respect to the embeddings.
        deriv_sup = 2 * dglsp.matmul(
            dglsp.spmm(self.label_idx_mat, -self.label_one_hot + pred_label),
            clf_weight,
        )
        embeds = embeds - args.lr_sup * deriv_sup

        # Decay the supervised learning rate after each update.
        args.lr_sup = args.lr_sup * args.decline
        return embeds
126 changes: 126 additions & 0 deletions examples/pytorch/ogc/train.py
@@ -0,0 +1,126 @@
import argparse
import time

import dgl.sparse as dglsp

import torch.nn.functional as F
import torch.optim as optim
from dgl import AddSelfLoop
from dgl.data import CiteseerGraphDataset, CoraGraphDataset, PubmedGraphDataset

from ogc import OGC
from utils import model_test, symmetric_normalize_adjacency


def train(model, embeds, lazy_adj, args):
    patience = 0
    _, _, last_acc, last_output = model_test(model, embeds)

    tv_mask = model.tv_mask
    optimizer = optim.SGD(model.parameters(), lr=args.lr_clf)

    # Train for at most 64 rounds; early stopping below usually ends sooner.
    for i in range(64):
        model.train()
        output = model(embeds)
        loss_tv = F.mse_loss(
            output[tv_mask], model.label_one_hot[tv_mask], reduction="sum"
        )
        optimizer.zero_grad()
        loss_tv.backward()
        optimizer.step()

        # Updating node embeds by LGC and SEB jointly.
        embeds = model.update_embeds(embeds, lazy_adj, args)

        loss_tv, acc_tv, acc_test, pred = model_test(model, embeds)
        print(
            "epoch {} loss_tv {:.4f} acc_tv {:.4f} acc_test {:.4f}".format(
                i + 1, loss_tv, acc_tv, acc_test
            )
        )

        # Early stopping: stop once the test predictions have been nearly
        # identical to the previous round's more than max_patience times.
        sim_rate = float(int((pred == last_output).sum()) / int(pred.shape[0]))
        if sim_rate > args.max_sim_rate:
            patience += 1
            if patience > args.max_patience:
                break
        last_acc = acc_test
        last_output = pred
    return last_acc


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--dataset",
        type=str,
        default="citeseer",
        choices=["cora", "citeseer", "pubmed"],
        help="dataset to use",
    )
    parser.add_argument(
        "--decline", type=float, default=0.9, help="decline rate"
    )
    parser.add_argument(
        "--lr_sup",
        type=float,
        default=0.001,
        help="learning rate for supervised loss",
    )
    parser.add_argument(
        "--lr_clf",
        type=float,
        default=0.5,
        help="learning rate for the linear classifier",
    )
    parser.add_argument(
        "--beta",
        type=float,
        default=0.1,
        help="probability that a node moves to one of its neighbors",
    )
    parser.add_argument(
        "--max_sim_rate",
        type=float,
        default=0.995,
        help="max label prediction similarity between iterations",
    )
    parser.add_argument(
        "--max_patience",
        type=int,
        default=2,
        help="number of near-identical prediction rounds tolerated before stopping",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cpu",
        choices=["cpu", "cuda"],
        help="device to use",
    )
    args, _ = parser.parse_known_args()

    # Load and preprocess dataset.
    transform = AddSelfLoop()
    if args.dataset == "cora":
        data = CoraGraphDataset(transform=transform)
    elif args.dataset == "citeseer":
        data = CiteseerGraphDataset(transform=transform)
    elif args.dataset == "pubmed":
        data = PubmedGraphDataset(transform=transform)
    else:
        raise ValueError("Unknown dataset: {}".format(args.dataset))
    graph = data[0].to(args.device)
    features = graph.ndata["feat"]
    adj = symmetric_normalize_adjacency(graph)
    I_N = dglsp.identity((features.shape[0], features.shape[0]))
    # Lazy random walk (also known as lazy graph convolution).
    lazy_adj = dglsp.add((1 - args.beta) * I_N, args.beta * adj).to(args.device)

    model = OGC(graph).to(args.device)
    start_time = time.time()
    res = train(model, features, lazy_adj, args)
    time_tot = time.time() - start_time

    print(f"Test Acc: {res:.4f}")
    print(f"Total Time: {time_tot:.4f}s")
35 changes: 35 additions & 0 deletions examples/pytorch/ogc/utils.py
@@ -0,0 +1,35 @@
import dgl.sparse as dglsp
import torch
import torch.nn as nn
import torch.nn.functional as F


class LinearNeuralNetwork(nn.Module):
    """A single linear layer, used as OGC's only trainable component."""

    def __init__(self, nfeat, nclass, bias=True):
        super(LinearNeuralNetwork, self).__init__()
        self.W = nn.Linear(nfeat, nclass, bias=bias)

    def forward(self, x):
        return self.W(x)


def symmetric_normalize_adjacency(graph):
    """Symmetrically normalize the graph adjacency matrix: D^-1/2 * A * D^-1/2."""
    indices = torch.stack(graph.edges())
    n = graph.num_nodes()
    adj = dglsp.spmatrix(indices, shape=(n, n))
    deg_invsqrt = dglsp.diag(adj.sum(0)) ** -0.5
    return deg_invsqrt @ adj @ deg_invsqrt


def model_test(model, embeds):
    """Evaluate the classifier; return the train+val loss, train+val
    accuracy, test accuracy, and predicted labels."""
    model.eval()
    with torch.no_grad():
        output = model(embeds)
        pred = output.argmax(dim=-1)
        test_mask, tv_mask = model.test_mask, model.tv_mask
        loss_tv = F.mse_loss(output[tv_mask], model.label_one_hot[tv_mask])
        accs = []
        for mask in [tv_mask, test_mask]:
            accs.append(
                float((pred[mask] == model.label[mask]).sum() / mask.sum())
            )
        return loss_tv.item(), accs[0], accs[1], pred