diff --git a/docs/source/api/python/nn-pytorch.rst b/docs/source/api/python/nn-pytorch.rst
index 7416624066e7..ed9a41bd1aa7 100644
--- a/docs/source/api/python/nn-pytorch.rst
+++ b/docs/source/api/python/nn-pytorch.rst
@@ -14,7 +14,6 @@ Conv Layers
     ~dgl.nn.pytorch.conv.GraphConv
     ~dgl.nn.pytorch.conv.EdgeWeightNorm
     ~dgl.nn.pytorch.conv.RelGraphConv
-    ~dgl.nn.pytorch.conv.CuGraphRelGraphConv
     ~dgl.nn.pytorch.conv.TAGConv
     ~dgl.nn.pytorch.conv.GATConv
     ~dgl.nn.pytorch.conv.GATv2Conv
@@ -42,6 +41,17 @@ Conv Layers
     ~dgl.nn.pytorch.conv.PNAConv
     ~dgl.nn.pytorch.conv.DGNConv
 
+CuGraph Conv Layers
+----------------------------------------
+
+.. autosummary::
+    :toctree: ../../generated/
+    :nosignatures:
+    :template: classtemplate.rst
+
+    ~dgl.nn.pytorch.conv.CuGraphRelGraphConv
+    ~dgl.nn.pytorch.conv.CuGraphSAGEConv
+
 Dense Conv Layers
 ----------------------------------------
 
diff --git a/examples/advanced/cugraph/graphsage.py b/examples/advanced/cugraph/graphsage.py
new file mode 100644
index 000000000000..bb9e12af7173
--- /dev/null
+++ b/examples/advanced/cugraph/graphsage.py
@@ -0,0 +1,200 @@
+import argparse
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torchmetrics.functional as MF
+import tqdm
+from dgl.data import AsNodePredDataset
+from dgl.dataloading import (
+    DataLoader,
+    MultiLayerFullNeighborSampler,
+    NeighborSampler,
+)
+from dgl.nn import CuGraphSAGEConv
+from ogb.nodeproppred import DglNodePropPredDataset
+
+
+class SAGE(nn.Module):
+    def __init__(self, in_size, hid_size, out_size):
+        super().__init__()
+        self.layers = nn.ModuleList()
+        # three-layer GraphSAGE-mean
+        self.layers.append(CuGraphSAGEConv(in_size, hid_size, "mean"))
+        self.layers.append(CuGraphSAGEConv(hid_size, hid_size, "mean"))
+        self.layers.append(CuGraphSAGEConv(hid_size, out_size, "mean"))
+        self.dropout = nn.Dropout(0.5)
+        self.hid_size = hid_size
+        self.out_size = out_size
+
+    def forward(self, blocks, x):
+        h = x
+        for l, (layer, block) in enumerate(zip(self.layers, blocks)):
+            h = layer(block, h)
+            if l != len(self.layers) - 1:
+                h = F.relu(h)
+                h = self.dropout(h)
+        return h
+
+    def inference(self, g, device, batch_size):
+        """Conduct layer-wise inference to get all the node embeddings."""
+        feat = g.ndata["feat"]
+        sampler = MultiLayerFullNeighborSampler(1, prefetch_node_feats=["feat"])
+        dataloader = DataLoader(
+            g,
+            torch.arange(g.num_nodes()).to(g.device),
+            sampler,
+            device=device,
+            batch_size=batch_size,
+            shuffle=False,
+            drop_last=False,
+            num_workers=0,
+        )
+        buffer_device = torch.device("cpu")
+        pin_memory = buffer_device != device
+
+        for l, layer in enumerate(self.layers):
+            y = torch.empty(
+                g.num_nodes(),
+                self.hid_size if l != len(self.layers) - 1 else self.out_size,
+                device=buffer_device,
+                pin_memory=pin_memory,
+            )
+            feat = feat.to(device)
+            for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):
+                x = feat[input_nodes]
+                h = layer(blocks[0], x)  # len(blocks) = 1
+                if l != len(self.layers) - 1:
+                    h = F.relu(h)
+                    h = self.dropout(h)
+                # by design, our output nodes are contiguous
+                y[output_nodes[0] : output_nodes[-1] + 1] = h.to(buffer_device)
+            feat = y
+        return y
+
+
+def evaluate(model, graph, dataloader):
+    model.eval()
+    ys = []
+    y_hats = []
+    for it, (input_nodes, output_nodes, blocks) in enumerate(dataloader):
+        with torch.no_grad():
+            x = blocks[0].srcdata["feat"]
+            ys.append(blocks[-1].dstdata["label"])
+            y_hats.append(model(blocks, x))
+    num_classes = y_hats[0].shape[1]
+    return MF.accuracy(
+        torch.cat(y_hats),
+        torch.cat(ys),
task="multiclass", + num_classes=num_classes, + ) + + +def layerwise_infer(device, graph, nid, model, batch_size): + model.eval() + with torch.no_grad(): + pred = model.inference( + graph, device, batch_size + ) # pred in buffer_device + pred = pred[nid] + label = graph.ndata["label"][nid].to(pred.device) + num_classes = pred.shape[1] + return MF.accuracy( + pred, label, task="multiclass", num_classes=num_classes + ) + + +def train(args, device, g, dataset, model): + # create sampler & dataloader + train_idx = dataset.train_idx.to(device) + val_idx = dataset.val_idx.to(device) + sampler = NeighborSampler( + [10, 10, 10], # fanout for [layer-0, layer-1, layer-2] + prefetch_node_feats=["feat"], + prefetch_labels=["label"], + ) + use_uva = args.mode == "mixed" + train_dataloader = DataLoader( + g, + train_idx, + sampler, + device=device, + batch_size=1024, + shuffle=True, + drop_last=False, + num_workers=0, + use_uva=use_uva, + ) + + val_dataloader = DataLoader( + g, + val_idx, + sampler, + device=device, + batch_size=1024, + shuffle=True, + drop_last=False, + num_workers=0, + use_uva=use_uva, + ) + + opt = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=5e-4) + + for epoch in range(10): + model.train() + total_loss = 0 + for it, (input_nodes, output_nodes, blocks) in enumerate( + train_dataloader + ): + x = blocks[0].srcdata["feat"] + y = blocks[-1].dstdata["label"] + y_hat = model(blocks, x) + loss = F.cross_entropy(y_hat, y) + opt.zero_grad() + loss.backward() + opt.step() + + total_loss += loss.item() + acc = evaluate(model, g, val_dataloader) + print( + "Epoch {:05d} | Loss {:.4f} | Accuracy {:.4f} ".format( + epoch, total_loss / (it + 1), acc.item() + ) + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--mode", + default="mixed", + choices=["mixed", "puregpu"], + help="Training mode. 
+        "'puregpu' for pure-GPU training.",
+    )
+    args = parser.parse_args()
+    if not torch.cuda.is_available():
+        raise RuntimeError("CuGraphSAGEConv requires a CUDA-capable GPU.")
+    print(f"Training in {args.mode} mode.")
+
+    # load and preprocess dataset
+    print("Loading data")
+    dataset = AsNodePredDataset(DglNodePropPredDataset("ogbn-products"))
+    g = dataset[0]
+    g = g.to("cuda" if args.mode == "puregpu" else "cpu")
+    device = torch.device("cuda")
+
+    # create GraphSAGE model
+    in_size = g.ndata["feat"].shape[1]
+    out_size = dataset.num_classes
+    model = SAGE(in_size, 256, out_size).to(device)
+
+    # model training
+    print("Training...")
+    train(args, device, g, dataset, model)
+
+    # test the model
+    print("Testing...")
+    acc = layerwise_infer(device, g, dataset.test_idx, model, batch_size=4096)
+    print("Test Accuracy {:.4f}".format(acc.item()))
diff --git a/python/dgl/nn/pytorch/conv/__init__.py b/python/dgl/nn/pytorch/conv/__init__.py
index 938fd5589373..4cccd8bcee9b 100644
--- a/python/dgl/nn/pytorch/conv/__init__.py
+++ b/python/dgl/nn/pytorch/conv/__init__.py
@@ -7,6 +7,7 @@
 from .cfconv import CFConv
 from .chebconv import ChebConv
 from .cugraph_relgraphconv import CuGraphRelGraphConv
+from .cugraph_sageconv import CuGraphSAGEConv
 from .densechebconv import DenseChebConv
 from .densegraphconv import DenseGraphConv
 from .densesageconv import DenseSAGEConv
@@ -67,4 +68,5 @@
     "PNAConv",
     "DGNConv",
     "CuGraphRelGraphConv",
+    "CuGraphSAGEConv",
 ]
diff --git a/python/dgl/nn/pytorch/conv/cugraph_sageconv.py b/python/dgl/nn/pytorch/conv/cugraph_sageconv.py
new file mode 100644
index 000000000000..b15aca8a9ce6
--- /dev/null
+++ b/python/dgl/nn/pytorch/conv/cugraph_sageconv.py
@@ -0,0 +1,153 @@
+"""Torch Module for GraphSAGE layer using the aggregation primitives in
+cugraph-ops"""
+# pylint: disable=no-member, arguments-differ, invalid-name, too-many-arguments
+
+import torch
+from torch import nn
+
+try:
+    from pylibcugraphops import make_fg_csr, make_mfg_csr
+    from pylibcugraphops.torch.autograd import agg_concat_n2n as SAGEConvAgg
+except ImportError:
+    has_pylibcugraphops = False
+else:
+    has_pylibcugraphops = True
+
+
+class CuGraphSAGEConv(nn.Module):
+    r"""An accelerated GraphSAGE layer from `Inductive Representation Learning
+    on Large Graphs <https://arxiv.org/abs/1706.02216>`__ that leverages the
+    highly-optimized aggregation primitives in cugraph-ops:
+
+    .. math::
+        h_{\mathcal{N}(i)}^{(l+1)} &= \mathrm{aggregate}
+        \left(\{h_{j}^{l}, \forall j \in \mathcal{N}(i) \}\right)
+
+        h_{i}^{(l+1)} &= W \cdot \mathrm{concat}
+        (h_{i}^{l}, h_{\mathcal{N}(i)}^{(l+1)})
+
+    This module depends on the :code:`pylibcugraphops` package, which can be
+    installed via :code:`conda install -c nvidia 'pylibcugraphops>=23.02'`.
+
+    .. note::
+        This is an **experimental** feature.
+
+    Parameters
+    ----------
+    in_feats : int
+        Input feature size.
+    out_feats : int
+        Output feature size.
+    aggregator_type : str
+        Aggregator type to use (``mean``, ``sum``, ``min``, ``max``).
+    feat_drop : float
+        Dropout rate on features. Default: ``0``.
+    bias : bool
+        If True, adds a learnable bias to the output. Default: ``True``.
+
+    Examples
+    --------
+    >>> import dgl
+    >>> import torch
+    >>> from dgl.nn import CuGraphSAGEConv
+    >>> device = 'cuda'
+    >>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3])).to(device)
+    >>> g = dgl.add_self_loop(g)
+    >>> feat = torch.ones(6, 10).to(device)
+    >>> conv = CuGraphSAGEConv(10, 2, 'mean').to(device)
+    >>> res = conv(g, feat)
+    >>> res
+    tensor([[-1.1690,  0.1952],
+            [-1.1690,  0.1952],
+            [-1.1690,  0.1952],
+            [-1.1690,  0.1952],
+            [-1.1690,  0.1952],
+            [-1.1690,  0.1952]], device='cuda:0', grad_fn=<AddmmBackward0>)
+    """
+    MAX_IN_DEGREE_MFG = 500
+
+    def __init__(
+        self,
+        in_feats,
+        out_feats,
+        aggregator_type="mean",
+        feat_drop=0.0,
+        bias=True,
+    ):
+        if not has_pylibcugraphops:
+            raise ModuleNotFoundError(
+                f"{self.__class__.__name__} requires pylibcugraphops >= 23.02. "
+                "Install via `conda install -c nvidia 'pylibcugraphops>=23.02'`."
+            )
+
+        valid_aggr_types = {"max", "min", "mean", "sum"}
+        if aggregator_type not in valid_aggr_types:
+            raise ValueError(
+                f"Invalid aggregator_type. Must be one of {valid_aggr_types}. "
+                f"But got '{aggregator_type}' instead."
+            )
+
+        super().__init__()
+        self.in_feats = in_feats
+        self.out_feats = out_feats
+        self.aggr = aggregator_type
+        self.feat_drop = nn.Dropout(feat_drop)
+        self.linear = nn.Linear(2 * in_feats, out_feats, bias=bias)
+
+    def reset_parameters(self):
+        r"""Reinitialize learnable parameters."""
+        self.linear.reset_parameters()
+
+    def forward(self, g, feat, max_in_degree=None):
+        r"""Forward computation.
+
+        Parameters
+        ----------
+        g : DGLGraph
+            The graph.
+        feat : torch.Tensor
+            Node features. Shape: :math:`(N, D_{in})`.
+        max_in_degree : int, optional
+            Maximum in-degree of destination nodes. It is only effective when
+            :attr:`g` is a :class:`DGLBlock`, i.e., a bipartite graph. When
+            :attr:`g` is generated from a neighbor sampler, the value should be
+            set to the corresponding :attr:`fanout`. If not given,
+            :attr:`max_in_degree` will be calculated on-the-fly.
+
+        Returns
+        -------
+        torch.Tensor
+            Output node features. Shape: :math:`(N, D_{out})`.
+ """ + offsets, indices, _ = g.adj_sparse("csc") + + if g.is_block: + if max_in_degree is None: + max_in_degree = g.in_degrees().max().item() + + if max_in_degree < self.MAX_IN_DEGREE_MFG: + _graph = make_mfg_csr( + g.dstnodes(), + offsets, + indices, + max_in_degree, + g.num_src_nodes(), + ) + else: + offsets_fg = torch.empty( + g.num_src_nodes() + 1, + dtype=offsets.dtype, + device=offsets.device, + ) + offsets_fg[: offsets.numel()] = offsets + offsets_fg[offsets.numel() :] = offsets[-1] + + _graph = make_fg_csr(offsets_fg, indices) + else: + _graph = make_fg_csr(offsets, indices) + + feat = self.feat_drop(feat) + h = SAGEConvAgg(feat, _graph, self.aggr)[: g.num_dst_nodes()] + h = self.linear(h) + + return h diff --git a/tests/cugraph/cugraph-ops/test_cugraph_sageconv.py b/tests/cugraph/cugraph-ops/test_cugraph_sageconv.py new file mode 100644 index 000000000000..5ebad58f9825 --- /dev/null +++ b/tests/cugraph/cugraph-ops/test_cugraph_sageconv.py @@ -0,0 +1,69 @@ +# pylint: disable=too-many-arguments, too-many-locals +from collections import OrderedDict +from itertools import product + +import dgl +import pytest +import torch +from dgl.nn import CuGraphSAGEConv, SAGEConv + +options = OrderedDict( + { + "idtype_int": [False, True], + "max_in_degree": [None, 8], + "to_block": [False, True], + } +) + + +def generate_graph(): + u = torch.tensor([0, 1, 0, 2, 3, 0, 4, 0, 5, 0, 6, 7, 0, 8, 9]) + v = torch.tensor([1, 9, 2, 9, 9, 4, 9, 5, 9, 6, 9, 9, 8, 9, 0]) + g = dgl.graph((u, v)) + return g + + +@pytest.mark.skip() +@pytest.mark.parametrize(",".join(options.keys()), product(*options.values())) +def test_SAGEConv_equality(idtype_int, max_in_degree, to_block): + device = "cuda:0" + in_feat, out_feat = 5, 2 + kwargs = {"aggregator_type": "mean"} + g = generate_graph().to(device) + if idtype_int: + g = g.int() + if to_block: + g = dgl.to_block(g) + feat = torch.rand(g.num_src_nodes(), in_feat).to(device) + + torch.manual_seed(0) + conv1 = SAGEConv(in_feat, out_feat, **kwargs).to(device) + + torch.manual_seed(0) + conv2 = CuGraphSAGEConv(in_feat, out_feat, **kwargs).to(device) + + with torch.no_grad(): + conv2.linear.weight.data[:, :in_feat] = conv1.fc_neigh.weight.data + conv2.linear.weight.data[:, in_feat:] = conv1.fc_self.weight.data + conv2.linear.bias.data[:] = conv1.fc_self.bias.data + + out1 = conv1(g, feat) + out2 = conv2(g, feat, max_in_degree=max_in_degree) + assert torch.allclose(out1, out2, atol=1e-06) + + grad_out = torch.rand_like(out1) + out1.backward(grad_out) + out2.backward(grad_out) + assert torch.allclose( + conv1.fc_neigh.weight.grad, + conv2.linear.weight.grad[:, :in_feat], + atol=1e-6, + ) + assert torch.allclose( + conv1.fc_self.weight.grad, + conv2.linear.weight.grad[:, in_feat:], + atol=1e-6, + ) + assert torch.allclose( + conv1.fc_self.bias.grad, conv2.linear.bias.grad, atol=1e-6 + )