Skip to content

Commit

Permalink
[Model] Update cugraph-ops models for 23.04 release (#5540)
Browse files Browse the repository at this point in the history
Co-authored-by: schmidt-ju <79865721+schmidt-ju@users.noreply.github.com>
Co-authored-by: Quan (Andy) Gan <coin2028@hotmail.com>
Co-authored-by: Mufei Li <mufeili1996@gmail.com>
Co-authored-by: Ubuntu <ubuntu@ip-172-31-36-188.ap-northeast-1.compute.internal>
Co-authored-by: peizhou001 <110809584+peizhou001@users.noreply.github.com>
Co-authored-by: Ubuntu <ubuntu@ip-172-31-16-19.ap-northeast-1.compute.internal>
  • Loading branch information
7 people authored and Rhett-Ying committed Jun 23, 2023
1 parent 9892abd commit 05aebd8
Show file tree
Hide file tree
Showing 7 changed files with 118 additions and 87 deletions.
57 changes: 57 additions & 0 deletions python/dgl/nn/pytorch/conv/cugraph_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
"""An abstract base class for cugraph-ops nn module."""
import torch
from torch import nn


class CuGraphBaseConv(nn.Module):
r"""An abstract base class for cugraph-ops nn module."""

def __init__(self):
super().__init__()
self._cached_offsets_fg = None

def reset_parameters(self):
r"""Resets all learnable parameters of the module."""
raise NotImplementedError

def forward(self, *args):
r"""Runs the forward pass of the module."""
raise NotImplementedError

def pad_offsets(self, offsets: torch.Tensor, size: int) -> torch.Tensor:
r"""Pad zero-in-degree nodes to the end of offsets to reach size.
cugraph-ops often provides two variants of aggregation functions for a
specific model: one intended for sampled-graph use cases, one for
full-graph ones. The former is in general more performant, however, it
only works when the sample size (the max of in-degrees) is small (<200),
due to the limit of GPU shared memory. For graphs with a larger max
in-degree, we need to fall back to the full-graph option, which requires
to convert a DGL block to a full graph. With the csc-representation,
this is equivalent to pad zero-in-degree nodes to the end of the offsets
array (also called indptr or colptr).
Parameters
----------
offsets :
The (monotonically increasing) index pointer array in a CSC-format
graph.
size : int
The length of offsets after padding.
Returns
-------
torch.Tensor
The augmented offsets array.
"""
if self._cached_offsets_fg is None:
self._cached_offsets_fg = torch.empty(
size, dtype=offsets.dtype, device=offsets.device
)
elif self._cached_offsets_fg.numel() < size:
self._cached_offsets_fg.resize_(size)

self._cached_offsets_fg[: offsets.numel()] = offsets
self._cached_offsets_fg[offsets.numel() : size] = offsets[-1]

return self._cached_offsets_fg[:size]
43 changes: 19 additions & 24 deletions python/dgl/nn/pytorch/conv/cugraph_gatconv.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,27 @@
import torch
from torch import nn

from .cugraph_base import CuGraphBaseConv

try:
from pylibcugraphops import make_fg_csr, make_mfg_csr
from pylibcugraphops.torch.autograd import mha_gat_n2n as GATConvAgg
from pylibcugraphops.pytorch import SampledCSC, StaticCSC
from pylibcugraphops.pytorch.operators import mha_gat_n2n as GATConvAgg

HAS_PYLIBCUGRAPHOPS = True
except ImportError:
has_pylibcugraphops = False
else:
has_pylibcugraphops = True
HAS_PYLIBCUGRAPHOPS = False


class CuGraphGATConv(nn.Module):
class CuGraphGATConv(CuGraphBaseConv):
r"""Graph attention layer from `Graph Attention Networks
<https://arxiv.org/pdf/1710.10903.pdf>`__, with the sparse aggregation
accelerated by cugraph-ops.
See :class:`dgl.nn.pytorch.conv.GATConv` for mathematical model.
This module depends on :code:`pylibcugraphops` package, which can be
installed via :code:`conda install -c nvidia pylibcugraphops>=23.02`.
installed via :code:`conda install -c nvidia pylibcugraphops=23.04`.
:code:`pylibcugraphops` 23.04 requires python 3.8.x or 3.10.x.
.. note::
This is an **experimental** feature.
Expand Down Expand Up @@ -78,7 +81,7 @@ class CuGraphGATConv(nn.Module):
[ 1.6477, -1.9986],
[ 1.1138, -1.9302]]], device='cuda:0', grad_fn=<ViewBackward0>)
"""
MAX_IN_DEGREE_MFG = 500
MAX_IN_DEGREE_MFG = 200

def __init__(
self,
Expand All @@ -91,10 +94,11 @@ def __init__(
activation=None,
bias=True,
):
if has_pylibcugraphops is False:
if HAS_PYLIBCUGRAPHOPS is False:
raise ModuleNotFoundError(
f"{self.__class__.__name__} requires pylibcugraphops >= 23.02. "
f"Install via `conda install -c nvidia 'pylibcugraphops>=23.02'`."
f"{self.__class__.__name__} requires pylibcugraphops=23.04. "
f"Install via `conda install -c nvidia 'pylibcugraphops=23.04'`."
f"pylibcugraphops requires Python 3.8 or 3.10."
)
super().__init__()
self.in_feats = in_feats
Expand Down Expand Up @@ -170,25 +174,17 @@ def forward(self, g, feat, max_in_degree=None):
max_in_degree = g.in_degrees().max().item()

if max_in_degree < self.MAX_IN_DEGREE_MFG:
_graph = make_mfg_csr(
g.dstnodes(),
_graph = SampledCSC(
offsets,
indices,
max_in_degree,
g.num_src_nodes(),
)
else:
offsets_fg = torch.empty(
g.num_src_nodes() + 1,
dtype=offsets.dtype,
device=offsets.device,
)
offsets_fg[: offsets.numel()] = offsets
offsets_fg[offsets.numel() :] = offsets[-1]

_graph = make_fg_csr(offsets_fg, indices)
offsets_fg = self.pad_offsets(offsets, g.num_src_nodes() + 1)
_graph = StaticCSC(offsets_fg, indices)
else:
_graph = make_fg_csr(offsets, indices)
_graph = StaticCSC(offsets, indices)

feat = self.feat_drop(feat)
feat_transformed = self.fc(feat)
Expand All @@ -199,7 +195,6 @@ def forward(self, g, feat, max_in_degree=None):
self.num_heads,
"LeakyReLU",
self.negative_slope,
add_own_node=False,
concat_heads=True,
)[: g.num_dst_nodes()].view(-1, self.num_heads, self.out_feats)

Expand Down
61 changes: 24 additions & 37 deletions python/dgl/nn/pytorch/conv/cugraph_relgraphconv.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,20 @@
import torch
from torch import nn

from .cugraph_base import CuGraphBaseConv

try:
from pylibcugraphops import make_fg_csr_hg, make_mfg_csr_hg
from pylibcugraphops.torch.autograd import (
from pylibcugraphops.pytorch import SampledHeteroCSC, StaticHeteroCSC
from pylibcugraphops.pytorch.operators import (
agg_hg_basis_n2n_post as RelGraphConvAgg,
)

HAS_PYLIBCUGRAPHOPS = True
except ImportError:
has_pylibcugraphops = False
else:
has_pylibcugraphops = True
HAS_PYLIBCUGRAPHOPS = False


class CuGraphRelGraphConv(nn.Module):
class CuGraphRelGraphConv(CuGraphBaseConv):
r"""An accelerated relational graph convolution layer from `Modeling
Relational Data with Graph Convolutional Networks
<https://arxiv.org/abs/1703.06103>`__ that leverages the highly-optimized
Expand All @@ -26,7 +28,8 @@ class CuGraphRelGraphConv(nn.Module):
See :class:`dgl.nn.pytorch.conv.RelGraphConv` for mathematical model.
This module depends on :code:`pylibcugraphops` package, which can be
installed via :code:`conda install -c nvidia pylibcugraphops>=23.02`.
installed via :code:`conda install -c nvidia pylibcugraphops=23.04`.
:code:`pylibcugraphops` 23.04 requires python 3.8.x or 3.10.x.
.. note::
This is an **experimental** feature.
Expand Down Expand Up @@ -92,10 +95,11 @@ def __init__(
dropout=0.0,
apply_norm=False,
):
if has_pylibcugraphops is False:
if HAS_PYLIBCUGRAPHOPS is False:
raise ModuleNotFoundError(
f"{self.__class__.__name__} requires pylibcugraphops >= 23.02 "
f"to be installed."
f"{self.__class__.__name__} requires pylibcugraphops=23.04. "
f"Install via `conda install -c nvidia 'pylibcugraphops=23.04'`."
f"pylibcugraphops requires Python 3.8 or 3.10."
)
super().__init__()
self.in_feat = in_feat
Expand Down Expand Up @@ -176,53 +180,36 @@ def forward(self, g, feat, etypes, max_in_degree=None):
torch.Tensor
New node features. Shape: :math:`(|V|, D_{out})`.
"""
# Create csc-representation and cast etypes to int32.
offsets, indices, edge_ids = g.adj_tensors("csc")
edge_types_perm = etypes[edge_ids.long()].int()

# Create cugraph-ops graph.
if g.is_block:
if max_in_degree is None:
max_in_degree = g.in_degrees().max().item()

if max_in_degree < self.MAX_IN_DEGREE_MFG:
_graph = make_mfg_csr_hg(
g.dstnodes(),
_graph = SampledHeteroCSC(
offsets,
indices,
edge_types_perm,
max_in_degree,
g.num_src_nodes(),
n_node_types=0,
n_edge_types=self.num_rels,
out_node_types=None,
in_node_types=None,
edge_types=edge_types_perm,
self.num_rels,
)
else:
offsets_fg = torch.empty(
g.num_src_nodes() + 1,
dtype=offsets.dtype,
device=offsets.device,
)
offsets_fg[: offsets.numel()] = offsets
offsets_fg[offsets.numel() :] = offsets[-1]

_graph = make_fg_csr_hg(
offsets_fg = self.pad_offsets(offsets, g.num_src_nodes() + 1)
_graph = StaticHeteroCSC(
offsets_fg,
indices,
n_node_types=0,
n_edge_types=self.num_rels,
node_types=None,
edge_types=edge_types_perm,
edge_types_perm,
self.num_rels,
)
else:
_graph = make_fg_csr_hg(
_graph = StaticHeteroCSC(
offsets,
indices,
n_node_types=0,
n_edge_types=self.num_rels,
node_types=None,
edge_types=edge_types_perm,
edge_types_perm,
self.num_rels,
)

h = RelGraphConvAgg(
Expand Down
41 changes: 18 additions & 23 deletions python/dgl/nn/pytorch/conv/cugraph_sageconv.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,20 @@
cugraph-ops"""
# pylint: disable=no-member, arguments-differ, invalid-name, too-many-arguments

import torch
from torch import nn

from .cugraph_base import CuGraphBaseConv

try:
from pylibcugraphops import make_fg_csr, make_mfg_csr
from pylibcugraphops.torch.autograd import agg_concat_n2n as SAGEConvAgg
from pylibcugraphops.pytorch import SampledCSC, StaticCSC
from pylibcugraphops.pytorch.operators import agg_concat_n2n as SAGEConvAgg

HAS_PYLIBCUGRAPHOPS = True
except ImportError:
has_pylibcugraphops = False
else:
has_pylibcugraphops = True
HAS_PYLIBCUGRAPHOPS = False


class CuGraphSAGEConv(nn.Module):
class CuGraphSAGEConv(CuGraphBaseConv):
r"""An accelerated GraphSAGE layer from `Inductive Representation Learning
on Large Graphs <https://arxiv.org/pdf/1706.02216.pdf>`__ that leverages the
highly-optimized aggregation primitives in cugraph-ops:
Expand All @@ -27,7 +28,8 @@ class CuGraphSAGEConv(nn.Module):
(h_{i}^{l}, h_{\mathcal{N}(i)}^{(l+1)})
This module depends on :code:`pylibcugraphops` package, which can be
installed via :code:`conda install -c nvidia pylibcugraphops>=23.02`.
installed via :code:`conda install -c nvidia pylibcugraphops=23.04`.
:code:`pylibcugraphops` 23.04 requires python 3.8.x or 3.10.x.
.. note::
This is an **experimental** feature.
Expand Down Expand Up @@ -74,10 +76,11 @@ def __init__(
feat_drop=0.0,
bias=True,
):
if has_pylibcugraphops is False:
if HAS_PYLIBCUGRAPHOPS is False:
raise ModuleNotFoundError(
f"{self.__class__.__name__} requires pylibcugraphops >= 23.02. "
f"Install via `conda install -c nvidia 'pylibcugraphops>=23.02'`."
f"{self.__class__.__name__} requires pylibcugraphops=23.04. "
f"Install via `conda install -c nvidia 'pylibcugraphops=23.04'`."
f"pylibcugraphops requires Python 3.8 or 3.10."
)

valid_aggr_types = {"max", "min", "mean", "sum"}
Expand Down Expand Up @@ -126,25 +129,17 @@ def forward(self, g, feat, max_in_degree=None):
max_in_degree = g.in_degrees().max().item()

if max_in_degree < self.MAX_IN_DEGREE_MFG:
_graph = make_mfg_csr(
g.dstnodes(),
_graph = SampledCSC(
offsets,
indices,
max_in_degree,
g.num_src_nodes(),
)
else:
offsets_fg = torch.empty(
g.num_src_nodes() + 1,
dtype=offsets.dtype,
device=offsets.device,
)
offsets_fg[: offsets.numel()] = offsets
offsets_fg[offsets.numel() :] = offsets[-1]

_graph = make_fg_csr(offsets_fg, indices)
offsets_fg = self.pad_offsets(offsets, g.num_src_nodes() + 1)
_graph = StaticCSC(offsets_fg, indices)
else:
_graph = make_fg_csr(offsets, indices)
_graph = StaticCSC(offsets, indices)

feat = self.feat_drop(feat)
h = SAGEConvAgg(feat, _graph, self.aggr)[: g.num_dst_nodes()]
Expand Down
1 change: 0 additions & 1 deletion tests/cugraph/cugraph-ops/test_cugraph_gatconv.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ def generate_graph():
return g


@pytest.mark.skip()
@pytest.mark.parametrize(",".join(options.keys()), product(*options.values()))
def test_gatconv_equality(idtype_int, max_in_degree, num_heads, to_block):
device = "cuda:0"
Expand Down
1 change: 0 additions & 1 deletion tests/cugraph/cugraph-ops/test_cugraph_relgraphconv.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ def generate_graph():
return g


@pytest.mark.skip()
@pytest.mark.parametrize(",".join(options.keys()), product(*options.values()))
def test_relgraphconv_equality(
idtype_int, max_in_degree, num_bases, regularizer, self_loop, to_block
Expand Down
1 change: 0 additions & 1 deletion tests/cugraph/cugraph-ops/test_cugraph_sageconv.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ def generate_graph():
return g


@pytest.mark.skip()
@pytest.mark.parametrize(",".join(options.keys()), product(*options.values()))
def test_SAGEConv_equality(idtype_int, max_in_degree, to_block):
device = "cuda:0"
Expand Down

0 comments on commit 05aebd8

Please sign in to comment.