[GraphBolt] torch.compile() support for gb.expand_indptr. #7188

Merged · 11 commits · Mar 5, 2024
15 changes: 15 additions & 0 deletions graphbolt/src/expand_indptr.cc
@@ -5,6 +5,7 @@
* @brief ExpandIndptr operators.
*/
#include <graphbolt/cuda_ops.h>
#include <torch/autograd.h>

#include "./macro.h"
#include "./utils.h"
@@ -29,5 +30,19 @@ torch::Tensor ExpandIndptr(
indptr.diff(), 0, output_size);
}

TORCH_LIBRARY_IMPL(graphbolt, CPU, m) {
m.impl("expand_indptr", &ExpandIndptr);
}

#ifdef GRAPHBOLT_USE_CUDA
TORCH_LIBRARY_IMPL(graphbolt, CUDA, m) {
m.impl("expand_indptr", &ExpandIndptrImpl);
}
#endif

TORCH_LIBRARY_IMPL(graphbolt, Autograd, m) {
m.impl("expand_indptr", torch::autograd::autogradNotImplementedFallback());
}

} // namespace ops
} // namespace graphbolt
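
The three TORCH_LIBRARY_IMPL blocks register one kernel per dispatch key: ExpandIndptr for CPU tensors, ExpandIndptrImpl for CUDA tensors when GraphBolt is built with CUDA support, and autogradNotImplementedFallback for the Autograd key, which lets the op run under autograd while making it explicit that no backward is implemented. Once registered, the operator is reachable from Python through torch.ops. A minimal CPU usage sketch, assuming a DGL build that ships GraphBolt; the expected output assumes node_ids defaults to torch.arange over the number of nodes:

import torch
import dgl.graphbolt  # loading GraphBolt runs the TORCH_LIBRARY registrations

indptr = torch.tensor([0, 2, 5])

# The dispatcher routes to the CPU or CUDA kernel based on indptr's device.
# Schema: expand_indptr(Tensor indptr, ScalarType dtype, Tensor? node_ids, SymInt? output_size)
row_ids = torch.ops.graphbolt.expand_indptr(indptr, torch.int64, None, None)
print(row_ids)  # expected: tensor([0, 0, 1, 1, 1]), assuming default arange node_ids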
12 changes: 11 additions & 1 deletion graphbolt/src/python_binding.cc
@@ -88,11 +88,21 @@ TORCH_LIBRARY(graphbolt, m) {
m.def("isin", &IsIn);
m.def("index_select", &ops::IndexSelect);
m.def("index_select_csc", &ops::IndexSelectCSC);
m.def("expand_indptr", &ops::ExpandIndptr);
m.def("set_seed", &RandomEngine::SetManualSeed);
#ifdef GRAPHBOLT_USE_CUDA
m.def("set_max_uva_threads", &cuda::set_max_uva_threads);
#endif
#ifdef HAS_IMPL_ABSTRACT_PYSTUB
m.impl_abstract_pystub("dgl.graphbolt.base", "//dgl.graphbolt.base");
#endif
m.def(
"expand_indptr(Tensor indptr, ScalarType dtype, Tensor? node_ids, "
"SymInt? output_size) -> Tensor"
#ifdef HAS_PT2_COMPLIANT_TAG
,
{at::Tag::pt2_compliant_tag}
#endif
);
}

} // namespace sampling
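The substantive change in this file replaces the one-line m.def("expand_indptr", &ops::ExpandIndptr) registration with a full schema string. Declaring output_size as SymInt? lets the PT2 stack trace it as a symbolic integer rather than baking in a Python constant, impl_abstract_pystub tells the dispatcher that the abstract (FakeTensor) implementation lives in dgl.graphbolt.base, and pt2_compliant_tag (on PyTorch builds that define it) marks the operator as conforming to the torch.compile custom-op contract. A rough Python-side analogue of the schema registration, using a hypothetical mylib namespace purely for illustration:

import torch

# Hypothetical sketch: the real operator is defined from C++ as above;
# torch.library.define is the Python counterpart of m.def with a schema string.
torch.library.define(
    "mylib::expand_indptr",
    "(Tensor indptr, ScalarType dtype, Tensor? node_ids, SymInt? output_size)"
    " -> Tensor",
)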
40 changes: 21 additions & 19 deletions python/dgl/graphbolt/__init__.py
@@ -5,25 +5,6 @@
import torch

from .._ffi import libinfo
from .base import *
from .minibatch import *
from .dataloader import *
from .dataset import *
from .feature_fetcher import *
from .feature_store import *
from .impl import *
from .itemset import *
from .item_sampler import *
from .minibatch_transformer import *
from .negative_sampler import *
from .sampled_subgraph import *
from .subgraph_sampler import *
from .internal import (
compact_csc_format,
unique_and_compact,
unique_and_compact_csc_formats,
)
from .utils import add_reverse_edges, add_reverse_edges_2, exclude_seed_edges


def load_graphbolt():
@@ -53,3 +34,24 @@ def load_graphbolt():


load_graphbolt()

# pylint: disable=wrong-import-position
from .base import *
from .minibatch import *
from .dataloader import *
from .dataset import *
from .feature_fetcher import *
from .feature_store import *
from .impl import *
from .itemset import *
from .item_sampler import *
from .minibatch_transformer import *
from .negative_sampler import *
from .sampled_subgraph import *
from .subgraph_sampler import *
from .internal import (
compact_csc_format,
unique_and_compact,
unique_and_compact_csc_formats,
)
from .utils import add_reverse_edges, add_reverse_edges_2, exclude_seed_edges
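
The import block itself is unchanged; it only moves below load_graphbolt(). The ordering now matters: base.py registers an abstract implementation for graphbolt::expand_indptr at import time (see below), which requires the C++ operator definition to already be loaded. The # pylint: disable=wrong-import-position comment acknowledges the deliberately late imports.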
13 changes: 13 additions & 0 deletions python/dgl/graphbolt/base.py
@@ -4,6 +4,7 @@
from dataclasses import dataclass

import torch
from torch.torch_version import TorchVersion
from torch.utils.data import functional_datapipe
from torchdata.datapipes.iter import IterDataPipe

@@ -63,6 +64,18 @@ def isin(elements, test_elements):
return torch.ops.graphbolt.isin(elements, test_elements)


if TorchVersion(torch.__version__) >= TorchVersion("2.2.0a0"):

@torch.library.impl_abstract("graphbolt::expand_indptr")
def expand_indptr_abstract(indptr, dtype, node_ids, output_size):
"""Abstract implementation of expand_indptr for torch.compile() support."""
if output_size is None:
output_size = torch.library.get_ctx().new_dynamic_size()
if dtype is None:
dtype = node_ids.dtype
return indptr.new_empty(output_size, dtype=dtype)


def expand_indptr(indptr, dtype=None, node_ids=None, output_size=None):
"""Converts a given indptr offset tensor to a COO format tensor. If
node_ids is not given, it is assumed to be equal to
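The abstract (or "fake") implementation is what FakeTensor tracing executes instead of the real kernel: it computes only the output's shape and dtype. When output_size is not provided, torch.library.get_ctx().new_dynamic_size() creates an unbacked symbolic size, which is how the compiler models the data-dependent output length. A minimal end-to-end sketch of compiling through the op, assuming PyTorch >= 2.2 and a GraphBolt-enabled DGL build (tensor values are illustrative):

import torch
import dgl.graphbolt as gb

def to_coo_rows(indptr, node_ids, output_size):
    # An explicit output_size keeps the traced graph free of
    # data-dependent output shapes.
    return gb.expand_indptr(indptr, torch.int64, node_ids, output_size)

compiled = torch.compile(to_coo_rows)

indptr = torch.tensor([0, 2, 5])
node_ids = torch.tensor([10, 20])
print(compiled(indptr, node_ids, 5))  # expected: tensor([10, 10, 20, 20, 20])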
27 changes: 27 additions & 0 deletions tests/python/pytorch/graphbolt/test_base.py
@@ -7,6 +7,7 @@
import dgl.graphbolt as gb
import pytest
import torch
from torch.torch_version import TorchVersion

from . import gb_test_utils

@@ -296,6 +297,32 @@ def test_expand_indptr(nodes, dtype):
gb_result = gb.expand_indptr(indptr, dtype, nodes, indptr[-1].item())
assert torch.equal(torch_result, gb_result)

if TorchVersion(torch.__version__) >= TorchVersion("2.2.0a0"):
import torch._dynamo as dynamo
from torch.testing._internal.optests import opcheck

# Tests torch.compile compatibility
for output_size in [None, indptr[-1].item()]:
kwargs = {"node_ids": nodes, "output_size": output_size}
opcheck(
torch.ops.graphbolt.expand_indptr,
(indptr, dtype),
kwargs,
test_utils=[
"test_schema",
"test_autograd_registration",
"test_faketensor",
"test_aot_dispatch_dynamic",
],
raise_exception=True,
)

explanation = dynamo.explain(gb.expand_indptr)(
indptr, dtype, nodes, output_size
)
expected_breaks = -1 if output_size is None else 0
assert explanation.graph_break_count == expected_breaks
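
The opcheck call above exercises the operator against PyTorch's custom-op test suites: schema consistency, autograd registration, FakeTensor propagation, and dynamic-shape AOT dispatch. The dynamo.explain check then confirms that, with an explicit output_size, gb.expand_indptr traces as a single graph with zero graph breaks.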


def test_csc_format_base_representation():
csc_format_base = gb.CSCFormatBase(