From 05aebd80017a412138c782ee4096210087f72550 Mon Sep 17 00:00:00 2001 From: Tingyu Wang Date: Thu, 4 May 2023 05:11:22 -0400 Subject: [PATCH] [Model] Update cugraph-ops models for 23.04 release (#5540) Co-authored-by: schmidt-ju <79865721+schmidt-ju@users.noreply.github.com> Co-authored-by: Quan (Andy) Gan Co-authored-by: Mufei Li Co-authored-by: Ubuntu Co-authored-by: peizhou001 <110809584+peizhou001@users.noreply.github.com> Co-authored-by: Ubuntu --- python/dgl/nn/pytorch/conv/cugraph_base.py | 57 +++++++++++++++++ python/dgl/nn/pytorch/conv/cugraph_gatconv.py | 43 ++++++------- .../nn/pytorch/conv/cugraph_relgraphconv.py | 61 ++++++++----------- .../dgl/nn/pytorch/conv/cugraph_sageconv.py | 41 ++++++------- .../cugraph-ops/test_cugraph_gatconv.py | 1 - .../cugraph-ops/test_cugraph_relgraphconv.py | 1 - .../cugraph-ops/test_cugraph_sageconv.py | 1 - 7 files changed, 118 insertions(+), 87 deletions(-) create mode 100644 python/dgl/nn/pytorch/conv/cugraph_base.py diff --git a/python/dgl/nn/pytorch/conv/cugraph_base.py b/python/dgl/nn/pytorch/conv/cugraph_base.py new file mode 100644 index 000000000000..8fab0dbb2fcc --- /dev/null +++ b/python/dgl/nn/pytorch/conv/cugraph_base.py @@ -0,0 +1,57 @@ +"""An abstract base class for cugraph-ops nn module.""" +import torch +from torch import nn + + +class CuGraphBaseConv(nn.Module): + r"""An abstract base class for cugraph-ops nn module.""" + + def __init__(self): + super().__init__() + self._cached_offsets_fg = None + + def reset_parameters(self): + r"""Resets all learnable parameters of the module.""" + raise NotImplementedError + + def forward(self, *args): + r"""Runs the forward pass of the module.""" + raise NotImplementedError + + def pad_offsets(self, offsets: torch.Tensor, size: int) -> torch.Tensor: + r"""Pad zero-in-degree nodes to the end of offsets to reach size. 
+ + cugraph-ops often provides two variants of aggregation functions for a + specific model: one intended for sampled-graph use cases, one for + full-graph ones. The former is in general more performant; however, it + only works when the sample size (the max of in-degrees) is small (<200), + due to the limit of GPU shared memory. For graphs with a larger max + in-degree, we need to fall back to the full-graph option, which requires + converting a DGL block to a full graph. With the csc-representation, + this is equivalent to padding zero-in-degree nodes to the end of the offsets + array (also called indptr or colptr). + + Parameters + ---------- + offsets : torch.Tensor + The (monotonically increasing) index pointer array in a CSC-format + graph. + size : int + The length of offsets after padding. + + Returns + ------- + torch.Tensor + The augmented offsets array. + """ + if self._cached_offsets_fg is None: + self._cached_offsets_fg = torch.empty( + size, dtype=offsets.dtype, device=offsets.device + ) + elif self._cached_offsets_fg.numel() < size: + self._cached_offsets_fg.resize_(size) + + self._cached_offsets_fg[: offsets.numel()] = offsets + self._cached_offsets_fg[offsets.numel() : size] = offsets[-1] + + return self._cached_offsets_fg[:size] diff --git a/python/dgl/nn/pytorch/conv/cugraph_gatconv.py b/python/dgl/nn/pytorch/conv/cugraph_gatconv.py index 20e47ba9c7d7..6bd428f5339f 100644 --- a/python/dgl/nn/pytorch/conv/cugraph_gatconv.py +++ b/python/dgl/nn/pytorch/conv/cugraph_gatconv.py @@ -5,16 +5,18 @@ import torch from torch import nn +from .cugraph_base import CuGraphBaseConv + try: - from pylibcugraphops import make_fg_csr, make_mfg_csr - from pylibcugraphops.torch.autograd import mha_gat_n2n as GATConvAgg + from pylibcugraphops.pytorch import SampledCSC, StaticCSC + from pylibcugraphops.pytorch.operators import mha_gat_n2n as GATConvAgg + + HAS_PYLIBCUGRAPHOPS = True except ImportError: - has_pylibcugraphops = False -else: - has_pylibcugraphops = True +
HAS_PYLIBCUGRAPHOPS = False -class CuGraphGATConv(nn.Module): +class CuGraphGATConv(CuGraphBaseConv): r"""Graph attention layer from `Graph Attention Networks `__, with the sparse aggregation accelerated by cugraph-ops. @@ -22,7 +24,8 @@ class CuGraphGATConv(nn.Module): See :class:`dgl.nn.pytorch.conv.GATConv` for mathematical model. This module depends on :code:`pylibcugraphops` package, which can be - installed via :code:`conda install -c nvidia pylibcugraphops>=23.02`. + installed via :code:`conda install -c nvidia pylibcugraphops=23.04`. + :code:`pylibcugraphops` 23.04 requires python 3.8.x or 3.10.x. .. note:: This is an **experimental** feature. @@ -78,7 +81,7 @@ class CuGraphGATConv(nn.Module): [ 1.6477, -1.9986], [ 1.1138, -1.9302]]], device='cuda:0', grad_fn=) """ - MAX_IN_DEGREE_MFG = 500 + MAX_IN_DEGREE_MFG = 200 def __init__( self, @@ -91,10 +94,11 @@ def __init__( activation=None, bias=True, ): - if has_pylibcugraphops is False: + if HAS_PYLIBCUGRAPHOPS is False: raise ModuleNotFoundError( - f"{self.__class__.__name__} requires pylibcugraphops >= 23.02. " - f"Install via `conda install -c nvidia 'pylibcugraphops>=23.02'`." + f"{self.__class__.__name__} requires pylibcugraphops=23.04. " + f"Install via `conda install -c nvidia 'pylibcugraphops=23.04'`. " + f"pylibcugraphops requires Python 3.8 or 3.10."
) super().__init__() self.in_feats = in_feats @@ -170,25 +174,17 @@ def forward(self, g, feat, max_in_degree=None): max_in_degree = g.in_degrees().max().item() if max_in_degree < self.MAX_IN_DEGREE_MFG: - _graph = make_mfg_csr( - g.dstnodes(), + _graph = SampledCSC( offsets, indices, max_in_degree, g.num_src_nodes(), ) else: - offsets_fg = torch.empty( - g.num_src_nodes() + 1, - dtype=offsets.dtype, - device=offsets.device, - ) - offsets_fg[: offsets.numel()] = offsets - offsets_fg[offsets.numel() :] = offsets[-1] - - _graph = make_fg_csr(offsets_fg, indices) + offsets_fg = self.pad_offsets(offsets, g.num_src_nodes() + 1) + _graph = StaticCSC(offsets_fg, indices) else: - _graph = make_fg_csr(offsets, indices) + _graph = StaticCSC(offsets, indices) feat = self.feat_drop(feat) feat_transformed = self.fc(feat) @@ -199,7 +195,6 @@ def forward(self, g, feat, max_in_degree=None): self.num_heads, "LeakyReLU", self.negative_slope, - add_own_node=False, concat_heads=True, )[: g.num_dst_nodes()].view(-1, self.num_heads, self.out_feats) diff --git a/python/dgl/nn/pytorch/conv/cugraph_relgraphconv.py b/python/dgl/nn/pytorch/conv/cugraph_relgraphconv.py index 251aba0103f1..3b1adc73090c 100644 --- a/python/dgl/nn/pytorch/conv/cugraph_relgraphconv.py +++ b/python/dgl/nn/pytorch/conv/cugraph_relgraphconv.py @@ -6,18 +6,20 @@ import torch from torch import nn +from .cugraph_base import CuGraphBaseConv + try: - from pylibcugraphops import make_fg_csr_hg, make_mfg_csr_hg - from pylibcugraphops.torch.autograd import ( + from pylibcugraphops.pytorch import SampledHeteroCSC, StaticHeteroCSC + from pylibcugraphops.pytorch.operators import ( agg_hg_basis_n2n_post as RelGraphConvAgg, ) + + HAS_PYLIBCUGRAPHOPS = True except ImportError: - has_pylibcugraphops = False -else: - has_pylibcugraphops = True + HAS_PYLIBCUGRAPHOPS = False -class CuGraphRelGraphConv(nn.Module): +class CuGraphRelGraphConv(CuGraphBaseConv): r"""An accelerated relational graph convolution layer from `Modeling 
Relational Data with Graph Convolutional Networks `__ that leverages the highly-optimized @@ -26,7 +28,8 @@ class CuGraphRelGraphConv(nn.Module): See :class:`dgl.nn.pytorch.conv.RelGraphConv` for mathematical model. This module depends on :code:`pylibcugraphops` package, which can be - installed via :code:`conda install -c nvidia pylibcugraphops>=23.02`. + installed via :code:`conda install -c nvidia pylibcugraphops=23.04`. + :code:`pylibcugraphops` 23.04 requires python 3.8.x or 3.10.x. .. note:: This is an **experimental** feature. @@ -92,10 +95,11 @@ def __init__( dropout=0.0, apply_norm=False, ): - if has_pylibcugraphops is False: + if HAS_PYLIBCUGRAPHOPS is False: raise ModuleNotFoundError( - f"{self.__class__.__name__} requires pylibcugraphops >= 23.02 " - f"to be installed." + f"{self.__class__.__name__} requires pylibcugraphops=23.04. " + f"Install via `conda install -c nvidia 'pylibcugraphops=23.04'`. " + f"pylibcugraphops requires Python 3.8 or 3.10." ) super().__init__() self.in_feat = in_feat @@ -176,53 +180,36 @@ def forward(self, g, feat, etypes, max_in_degree=None): torch.Tensor New node features. Shape: :math:`(|V|, D_{out})`. """ - # Create csc-representation and cast etypes to int32. offsets, indices, edge_ids = g.adj_tensors("csc") edge_types_perm = etypes[edge_ids.long()].int() - # Create cugraph-ops graph.
if g.is_block: if max_in_degree is None: max_in_degree = g.in_degrees().max().item() if max_in_degree < self.MAX_IN_DEGREE_MFG: - _graph = make_mfg_csr_hg( - g.dstnodes(), + _graph = SampledHeteroCSC( offsets, indices, + edge_types_perm, max_in_degree, g.num_src_nodes(), - n_node_types=0, - n_edge_types=self.num_rels, - out_node_types=None, - in_node_types=None, - edge_types=edge_types_perm, + self.num_rels, ) else: - offsets_fg = torch.empty( - g.num_src_nodes() + 1, - dtype=offsets.dtype, - device=offsets.device, - ) - offsets_fg[: offsets.numel()] = offsets - offsets_fg[offsets.numel() :] = offsets[-1] - - _graph = make_fg_csr_hg( + offsets_fg = self.pad_offsets(offsets, g.num_src_nodes() + 1) + _graph = StaticHeteroCSC( offsets_fg, indices, - n_node_types=0, - n_edge_types=self.num_rels, - node_types=None, - edge_types=edge_types_perm, + edge_types_perm, + self.num_rels, ) else: - _graph = make_fg_csr_hg( + _graph = StaticHeteroCSC( offsets, indices, - n_node_types=0, - n_edge_types=self.num_rels, - node_types=None, - edge_types=edge_types_perm, + edge_types_perm, + self.num_rels, ) h = RelGraphConvAgg( diff --git a/python/dgl/nn/pytorch/conv/cugraph_sageconv.py b/python/dgl/nn/pytorch/conv/cugraph_sageconv.py index 86ce1b7d13ef..ffcab632e0dc 100644 --- a/python/dgl/nn/pytorch/conv/cugraph_sageconv.py +++ b/python/dgl/nn/pytorch/conv/cugraph_sageconv.py @@ -2,19 +2,20 @@ cugraph-ops""" # pylint: disable=no-member, arguments-differ, invalid-name, too-many-arguments -import torch from torch import nn +from .cugraph_base import CuGraphBaseConv + try: - from pylibcugraphops import make_fg_csr, make_mfg_csr - from pylibcugraphops.torch.autograd import agg_concat_n2n as SAGEConvAgg + from pylibcugraphops.pytorch import SampledCSC, StaticCSC + from pylibcugraphops.pytorch.operators import agg_concat_n2n as SAGEConvAgg + + HAS_PYLIBCUGRAPHOPS = True except ImportError: - has_pylibcugraphops = False -else: - has_pylibcugraphops = True + HAS_PYLIBCUGRAPHOPS = False 
-class CuGraphSAGEConv(nn.Module): +class CuGraphSAGEConv(CuGraphBaseConv): r"""An accelerated GraphSAGE layer from `Inductive Representation Learning on Large Graphs `__ that leverages the highly-optimized aggregation primitives in cugraph-ops: @@ -27,7 +28,8 @@ class CuGraphSAGEConv(nn.Module): (h_{i}^{l}, h_{\mathcal{N}(i)}^{(l+1)}) This module depends on :code:`pylibcugraphops` package, which can be - installed via :code:`conda install -c nvidia pylibcugraphops>=23.02`. + installed via :code:`conda install -c nvidia pylibcugraphops=23.04`. + :code:`pylibcugraphops` 23.04 requires python 3.8.x or 3.10.x. .. note:: This is an **experimental** feature. @@ -74,10 +76,11 @@ def __init__( feat_drop=0.0, bias=True, ): - if has_pylibcugraphops is False: + if HAS_PYLIBCUGRAPHOPS is False: raise ModuleNotFoundError( - f"{self.__class__.__name__} requires pylibcugraphops >= 23.02. " - f"Install via `conda install -c nvidia 'pylibcugraphops>=23.02'`." + f"{self.__class__.__name__} requires pylibcugraphops=23.04. " + f"Install via `conda install -c nvidia 'pylibcugraphops=23.04'`. " + f"pylibcugraphops requires Python 3.8 or 3.10."
) valid_aggr_types = {"max", "min", "mean", "sum"} @@ -126,25 +129,17 @@ def forward(self, g, feat, max_in_degree=None): max_in_degree = g.in_degrees().max().item() if max_in_degree < self.MAX_IN_DEGREE_MFG: - _graph = make_mfg_csr( - g.dstnodes(), + _graph = SampledCSC( offsets, indices, max_in_degree, g.num_src_nodes(), ) else: - offsets_fg = torch.empty( - g.num_src_nodes() + 1, - dtype=offsets.dtype, - device=offsets.device, - ) - offsets_fg[: offsets.numel()] = offsets - offsets_fg[offsets.numel() :] = offsets[-1] - - _graph = make_fg_csr(offsets_fg, indices) + offsets_fg = self.pad_offsets(offsets, g.num_src_nodes() + 1) + _graph = StaticCSC(offsets_fg, indices) else: - _graph = make_fg_csr(offsets, indices) + _graph = StaticCSC(offsets, indices) feat = self.feat_drop(feat) h = SAGEConvAgg(feat, _graph, self.aggr)[: g.num_dst_nodes()] diff --git a/tests/cugraph/cugraph-ops/test_cugraph_gatconv.py b/tests/cugraph/cugraph-ops/test_cugraph_gatconv.py index 28d79e76bc24..a1b5524638ee 100644 --- a/tests/cugraph/cugraph-ops/test_cugraph_gatconv.py +++ b/tests/cugraph/cugraph-ops/test_cugraph_gatconv.py @@ -24,7 +24,6 @@ def generate_graph(): return g -@pytest.mark.skip() @pytest.mark.parametrize(",".join(options.keys()), product(*options.values())) def test_gatconv_equality(idtype_int, max_in_degree, num_heads, to_block): device = "cuda:0" diff --git a/tests/cugraph/cugraph-ops/test_cugraph_relgraphconv.py b/tests/cugraph/cugraph-ops/test_cugraph_relgraphconv.py index ddacf8d27484..2dd1b9260228 100644 --- a/tests/cugraph/cugraph-ops/test_cugraph_relgraphconv.py +++ b/tests/cugraph/cugraph-ops/test_cugraph_relgraphconv.py @@ -27,7 +27,6 @@ def generate_graph(): return g -@pytest.mark.skip() @pytest.mark.parametrize(",".join(options.keys()), product(*options.values())) def test_relgraphconv_equality( idtype_int, max_in_degree, num_bases, regularizer, self_loop, to_block diff --git a/tests/cugraph/cugraph-ops/test_cugraph_sageconv.py 
b/tests/cugraph/cugraph-ops/test_cugraph_sageconv.py index 5ebad58f9825..cec5d3fd68ef 100644 --- a/tests/cugraph/cugraph-ops/test_cugraph_sageconv.py +++ b/tests/cugraph/cugraph-ops/test_cugraph_sageconv.py @@ -23,7 +23,6 @@ def generate_graph(): return g -@pytest.mark.skip() @pytest.mark.parametrize(",".join(options.keys()), product(*options.values())) def test_SAGEConv_equality(idtype_int, max_in_degree, to_block): device = "cuda:0"