Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Model] Add dgl.nn.CuGraphSAGEConv model #5137

Merged
merged 14 commits into from
Feb 22, 2023
Merged
6 changes: 5 additions & 1 deletion docs/source/api/python/nn-pytorch.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ Conv Layers
~dgl.nn.pytorch.conv.GraphConv
~dgl.nn.pytorch.conv.EdgeWeightNorm
~dgl.nn.pytorch.conv.RelGraphConv
~dgl.nn.pytorch.conv.CuGraphRelGraphConv
~dgl.nn.pytorch.conv.TAGConv
~dgl.nn.pytorch.conv.GATConv
~dgl.nn.pytorch.conv.GATv2Conv
Expand Down Expand Up @@ -42,6 +41,11 @@ Conv Layers
~dgl.nn.pytorch.conv.PNAConv
~dgl.nn.pytorch.conv.DGNConv

CuGraph Conv Layers
----------------------------------------
tingyu66 marked this conversation as resolved.
Show resolved Hide resolved
~dgl.nn.pytorch.conv.CuGraphRelGraphConv
~dgl.nn.pytorch.conv.CuGraphSAGEConv

Dense Conv Layers
----------------------------------------

Expand Down
200 changes: 200 additions & 0 deletions examples/advanced/cugraph/graphsage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
import argparse
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did you run this script? If so, what performance number did you obtain?

Copy link
Contributor Author

@tingyu66 tingyu66 Feb 18, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, in terms of pure training time (not including dataloading), SAGEConv takes 2.5s per epoch, while CuGraphSAGEConv takes 2.0s, despite the overhead of coo-to-csc conversion. Test accuracy is also the same.

Edit: add timings for both mode in the example

mode mixed (uva) pure gpu
CuGraphSAGEConv 2.0 s 1.2 s
SAGEConv 2.5 s 1.7 s


import torch
import torch.nn as nn
import torch.nn.functional as F
import torchmetrics.functional as MF
import tqdm
from dgl.data import AsNodePredDataset
from dgl.dataloading import (
DataLoader,
MultiLayerFullNeighborSampler,
NeighborSampler,
)
from dgl.nn import CuGraphSAGEConv
from ogb.nodeproppred import DglNodePropPredDataset


class SAGE(nn.Module):
def __init__(self, in_size, hid_size, out_size):
super().__init__()
self.layers = nn.ModuleList()
# three-layer GraphSAGE-mean
self.layers.append(CuGraphSAGEConv(in_size, hid_size, "mean"))
self.layers.append(CuGraphSAGEConv(hid_size, hid_size, "mean"))
self.layers.append(CuGraphSAGEConv(hid_size, out_size, "mean"))
self.dropout = nn.Dropout(0.5)
self.hid_size = hid_size
self.out_size = out_size

def forward(self, blocks, x):
h = x
for l, (layer, block) in enumerate(zip(self.layers, blocks)):
h = layer(block, h, max_in_degree=10)
Copy link
Member

@mufeili mufeili Feb 17, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems a bit ugly. Perhaps it's better to pass the argument to SAGE.__init__.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you explain what needs to be done here? Are you suggesting to unpack to loop like this?

h = F.relu(self.conv1(g[0], x))
h = F.relu(self.conv2(g[1], h))
...

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I meant the specification of max_in_degree.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see and I do agree that it is not an ideal interface. We did not make max_in_degree an attribute of CuGraphSAGEConv since it is a property of the graph (i.e., block), rather than the model. I have removed it from the example as this flag is optional.
In the meantime, we are improving our aggregation primitives to be more flexible to eventually ditch this option.

if l != len(self.layers) - 1:
h = F.relu(h)
h = self.dropout(h)
return h

def inference(self, g, device, batch_size):
"""Conduct layer-wise inference to get all the node embeddings."""
feat = g.ndata["feat"]
sampler = MultiLayerFullNeighborSampler(1, prefetch_node_feats=["feat"])
dataloader = DataLoader(
g,
torch.arange(g.num_nodes()).to(g.device),
sampler,
device=device,
batch_size=batch_size,
shuffle=False,
drop_last=False,
num_workers=0,
)
buffer_device = torch.device("cpu")
pin_memory = buffer_device != device

for l, layer in enumerate(self.layers):
y = torch.empty(
g.num_nodes(),
self.hid_size if l != len(self.layers) - 1 else self.out_size,
device=buffer_device,
pin_memory=pin_memory,
)
feat = feat.to(device)
for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):
x = feat[input_nodes]
h = layer(blocks[0], x) # len(blocks) = 1
if l != len(self.layers) - 1:
h = F.relu(h)
h = self.dropout(h)
# by design, our output nodes are contiguous
y[output_nodes[0] : output_nodes[-1] + 1] = h.to(buffer_device)
feat = y
return y


def evaluate(model, graph, dataloader):
model.eval()
ys = []
y_hats = []
for it, (input_nodes, output_nodes, blocks) in enumerate(dataloader):
with torch.no_grad():
x = blocks[0].srcdata["feat"]
ys.append(blocks[-1].dstdata["label"])
y_hats.append(model(blocks, x))
num_classes = y_hats[0].shape[1]
return MF.accuracy(
torch.cat(y_hats),
torch.cat(ys),
task="multiclass",
num_classes=num_classes,
)


def layerwise_infer(device, graph, nid, model, batch_size):
model.eval()
with torch.no_grad():
pred = model.inference(
graph, device, batch_size
) # pred in buffer_device
pred = pred[nid]
label = graph.ndata["label"][nid].to(pred.device)
num_classes = pred.shape[1]
return MF.accuracy(
pred, label, task="multiclass", num_classes=num_classes
)


def train(args, device, g, dataset, model):
# create sampler & dataloader
train_idx = dataset.train_idx.to(device)
val_idx = dataset.val_idx.to(device)
sampler = NeighborSampler(
[10, 10, 10], # fanout for [layer-0, layer-1, layer-2]
prefetch_node_feats=["feat"],
prefetch_labels=["label"],
)
use_uva = args.mode == "mixed"
train_dataloader = DataLoader(
g,
train_idx,
sampler,
device=device,
batch_size=1024,
shuffle=True,
drop_last=False,
num_workers=0,
use_uva=use_uva,
)

val_dataloader = DataLoader(
g,
val_idx,
sampler,
device=device,
batch_size=1024,
shuffle=True,
drop_last=False,
num_workers=0,
use_uva=use_uva,
)

opt = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=5e-4)

for epoch in range(10):
model.train()
total_loss = 0
for it, (input_nodes, output_nodes, blocks) in enumerate(
train_dataloader
):
x = blocks[0].srcdata["feat"]
y = blocks[-1].dstdata["label"]
y_hat = model(blocks, x)
loss = F.cross_entropy(y_hat, y)
opt.zero_grad()
loss.backward()
opt.step()

total_loss += loss.item()
acc = evaluate(model, g, val_dataloader)
print(
"Epoch {:05d} | Loss {:.4f} | Accuracy {:.4f} ".format(
epoch, total_loss / (it + 1), acc.item()
)
)


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--mode",
default="mixed",
choices=["cpu", "mixed", "puregpu"],
help="Training mode. 'cpu' for CPU training, 'mixed' for CPU-GPU mixed training, "
"'puregpu' for pure-GPU training.",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fix indent

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is automatically formatted by lintrunner. I removed the cpu mode, as it is not supported by the model

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changes pushed.

)
args = parser.parse_args()
if not torch.cuda.is_available():
args.mode = "cpu"
print(f"Training in {args.mode} mode.")

# load and preprocess dataset
print("Loading data")
dataset = AsNodePredDataset(DglNodePropPredDataset("ogbn-products"))
g = dataset[0]
g = g.to("cuda" if args.mode == "puregpu" else "cpu")
device = torch.device("cpu" if args.mode == "cpu" else "cuda")

# create GraphSAGE model
in_size = g.ndata["feat"].shape[1]
out_size = dataset.num_classes
model = SAGE(in_size, 256, out_size).to(device)

# model training
print("Training...")
train(args, device, g, dataset, model)

# test the model
print("Testing...")
acc = layerwise_infer(device, g, dataset.test_idx, model, batch_size=4096)
print("Test Accuracy {:.4f}".format(acc.item()))
2 changes: 2 additions & 0 deletions python/dgl/nn/pytorch/conv/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from .cfconv import CFConv
from .chebconv import ChebConv
from .cugraph_relgraphconv import CuGraphRelGraphConv
from .cugraph_sageconv import CuGraphSAGEConv
from .densechebconv import DenseChebConv
from .densegraphconv import DenseGraphConv
from .densesageconv import DenseSAGEConv
Expand Down Expand Up @@ -67,4 +68,5 @@
"PNAConv",
"DGNConv",
"CuGraphRelGraphConv",
"CuGraphSAGEConv",
]
153 changes: 153 additions & 0 deletions python/dgl/nn/pytorch/conv/cugraph_sageconv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
"""Torch Module for GraphSAGE layer using the aggregation primitives in
cugraph-ops"""
# pylint: disable=no-member, arguments-differ, invalid-name, too-many-arguments

import torch
from torch import nn

try:
from pylibcugraphops import make_fg_csr, make_mfg_csr
from pylibcugraphops.torch.autograd import agg_concat_n2n as SAGEConvAgg
except ImportError:
has_pylibcugraphops = False
else:
has_pylibcugraphops = True


class CuGraphSAGEConv(nn.Module):
r"""An accelerated GraphSAGE layer from `Inductive Representation Learning
on Large Graphs <https://arxiv.org/pdf/1706.02216.pdf>`__ that leverages the
highly-optimized aggregation primitives in cugraph-ops:

.. math::
h_{\mathcal{N}(i)}^{(l+1)} &= \mathrm{aggregate}
\left(\{h_{j}^{l}, \forall j \in \mathcal{N}(i) \}\right)

h_{i}^{(l+1)} &= W \cdot \mathrm{concat}
(h_{i}^{l}, h_{\mathcal{N}(i)}^{(l+1)})

This module depends on :code:`pylibcugraphops` package, which can be
installed via :code:`conda install -c nvidia pylibcugraphops>=23.02`.

.. note::
This is an **experimental** feature.

Parameters
----------
in_feats : int
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SageConv supports source and destination nodes with different feature size. I assume this is not the case for this implementation.

Input feature size.
out_feats : int
Output feature size.
aggregator_type : str
Aggregator type to use (``mean``, ``sum``, ``min``, ``max``).
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The options seem to be different from the ones for GraphConv, which are mean, gcn, pool, lstm.

feat_drop : float
Dropout rate on features, default: ``0``.
bias : bool
If True, adds a learnable bias to the output. Default: ``True``.
tingyu66 marked this conversation as resolved.
Show resolved Hide resolved

Examples
--------
>>> import dgl
>>> import torch
>>> from dgl.nn import CuGraphSAGEConv
>>> device = 'cuda'
>>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3])).to(device)
>>> g = dgl.add_self_loop(g)
>>> feat = torch.ones(6, 10).to(device)
>>> conv = CuGraphSAGEConv(10, 2, 'mean').to(device)
>>> res = conv(g, feat)
>>> res
tensor([[-1.1690, 0.1952],
[-1.1690, 0.1952],
[-1.1690, 0.1952],
[-1.1690, 0.1952],
[-1.1690, 0.1952],
[-1.1690, 0.1952]], device='cuda:0', grad_fn=<AddmmBackward0>)
"""
MAX_IN_DEGREE_MFG = 500

def __init__(
self,
in_feats,
out_feats,
aggregator_type="mean",
feat_drop=0.0,
bias=True,
):
if has_pylibcugraphops is False:
raise ModuleNotFoundError(
f"{self.__class__.__name__} requires pylibcugraphops >= 23.02. "
f"Install via `conda install -c nvidia 'pylibcugraphops>=23.02'`."
)

valid_aggr_types = {"max", "min", "mean", "sum"}
if aggregator_type not in valid_aggr_types:
raise ValueError(
f"Invalid aggregator_type. Must be one of {valid_aggr_types}. "
f"But got '{aggregator_type}' instead."
)

super().__init__()
self.in_feats = in_feats
self.out_feats = out_feats
self.aggr = aggregator_type
self.feat_drop = nn.Dropout(feat_drop)
self.linear = nn.Linear(2 * in_feats, out_feats, bias=bias)

def reset_parameters(self):
r"""Reinitialize learnable parameters."""
self.linear.reset_parameters()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Previously SageConv considers Xavier uniform while nn.Linear.reset_parameters considers Kaiming uniform. I'm not sure about the effects of this difference.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think Kaiming is more suitable here as ReLU is often the choice for the nonlinearity in GNN; Xavier was designed for sigmoid function.


def forward(self, g, feat, max_in_degree=None):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

another difference, lack of support for edge_weight

r"""Forward computation.

Parameters
----------
g : DGLGraph
The graph.
feat : torch.Tensor
Node features. Shape: :math:`(N, D_{in})`.
max_in_degree : int
Maximum in-degree of destination nodes. It is only effective when
:attr:`g` is a :class:`DGLBlock`, i.e., bipartite graph. When
:attr:`g` is generated from a neighbor sampler, the value should be
set to the corresponding :attr:`fanout`. If not given,
:attr:`max_in_degree` will be calculated on-the-fly.

Returns
-------
torch.Tensor
Output node features. Shape: :math:`(N, D_{out})`.
"""
offsets, indices, _ = g.adj_sparse("csc")

if g.is_block:
if max_in_degree is None:
max_in_degree = g.in_degrees().max().item()

if max_in_degree < self.MAX_IN_DEGREE_MFG:
_graph = make_mfg_csr(
g.dstnodes(),
offsets,
indices,
max_in_degree,
g.num_src_nodes(),
)
else:
offsets_fg = torch.empty(
g.num_src_nodes() + 1,
dtype=offsets.dtype,
device=offsets.device,
)
offsets_fg[: offsets.numel()] = offsets
offsets_fg[offsets.numel() :] = offsets[-1]

_graph = make_fg_csr(offsets_fg, indices)
else:
_graph = make_fg_csr(offsets, indices)

feat = self.feat_drop(feat)
h = SAGEConvAgg(feat, _graph, self.aggr)[: g.num_dst_nodes()]
tingyu66 marked this conversation as resolved.
Show resolved Hide resolved
h = self.linear(h)

return h
Loading