Skip to content

Commit

Permalink
Add the python files.
Browse files Browse the repository at this point in the history
  • Loading branch information
mfbalin committed Jul 26, 2024
1 parent 8c9bd68 commit d905c10
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 0 deletions.
52 changes: 52 additions & 0 deletions python/dgl/graphbolt/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
"isin",
"index_select",
"expand_indptr",
"indptr_edge_ids",
"CSCFormatBase",
"seed",
"seed_type_str_to_ntypes",
Expand Down Expand Up @@ -158,6 +159,57 @@ def expand_indptr(indptr, dtype=None, node_ids=None, output_size=None):
)


if TorchVersion(torch.__version__) >= TorchVersion("2.2.0a0"):

torch_fake_decorator = (
torch.library.impl_abstract
if TorchVersion(torch.__version__) < TorchVersion("2.4.0a0")
else torch.library.register_fake
)

@torch_fake_decorator("graphbolt::indptr_edge_ids")
def indptr_edge_ids_fake(indptr, dtype, offset, output_size):
"""Fake implementation of indptr_edge_ids for torch.compile() support."""
if output_size is None:
output_size = torch.library.get_ctx().new_dynamic_size()
if dtype is None:
dtype = offset.dtype
return indptr.new_empty(output_size, dtype=dtype)


def indptr_edge_ids(indptr, dtype=None, offset=None, output_size=None):
"""Converts a given indptr offset tensor to a COO format tensor for the edge
ids. For a given indptr [0, 2, 5, 7] and offset tensor [0, 100, 200], the
output will be [0, 1, 100, 101, 102, 201, 202]. If offset was not provided,
the output would be [0, 1, 0, 1, 2, 0, 1].
Parameters
----------
indptr : torch.Tensor
A 1D tensor represents the csc_indptr tensor.
dtype : Optional[torch.dtype]
The dtype of the returned output tensor.
offset : Optional[torch.Tensor]
A 1D tensor represents the offsets that the returned tensor will be
populated with.
output_size : Optional[int]
The size of the output tensor. Should be equal to indptr[-1]. Using this
argument avoids a stream synchronization to calculate the output shape.
Returns
-------
torch.Tensor
The converted COO edge ids tensor.
"""
assert indptr.dim() == 1, "Indptr should be 1D tensor."
assert offset is None or offset.dim() == 1, "Offset should be 1D tensor."
if dtype is None:
dtype = offset.dtype
return torch.ops.graphbolt.indptr_edge_ids(
indptr, dtype, offset, output_size
)


def index_select(tensor, index):
"""Returns a new tensor which indexes the input tensor along dimension dim
using the entries in index.
Expand Down
48 changes: 48 additions & 0 deletions tests/python/pytorch/graphbolt/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,54 @@ def test_expand_indptr(nodes, dtype):
assert explanation.graph_break_count == expected_breaks


@unittest.skipIf(
F._default_context_str != "gpu", "Only GPU implementation is available."
)
@pytest.mark.parametrize("offset", [None, True])
@pytest.mark.parametrize("dtype", [torch.int32, torch.int64])
def test_indptr_edge_ids(offset, dtype):
indptr = torch.tensor([0, 2, 2, 7, 10, 12], device=F.ctx())
if offset:
offset = indptr[:-1]
ref_result = torch.arange(
0, indptr[-1].item(), dtype=dtype, device=F.ctx()
)
else:
ref_result = torch.tensor(
[0, 1, 0, 1, 2, 3, 4, 0, 1, 2, 0, 1], dtype=dtype, device=F.ctx()
)
gb_result = gb.indptr_edge_ids(indptr, dtype, offset)
assert torch.equal(ref_result, gb_result)
gb_result = gb.indptr_edge_ids(indptr, dtype, offset, indptr[-1].item())
assert torch.equal(ref_result, gb_result)

if TorchVersion(torch.__version__) >= TorchVersion("2.2.0a0"):
import torch._dynamo as dynamo
from torch.testing._internal.optests import opcheck

# Tests torch.compile compatibility
for output_size in [None, indptr[-1].item()]:
kwargs = {"offset": offset, "output_size": output_size}
opcheck(
torch.ops.graphbolt.indptr_edge_ids,
(indptr, dtype),
kwargs,
test_utils=[
"test_schema",
"test_autograd_registration",
"test_faketensor",
"test_aot_dispatch_dynamic",
],
raise_exception=True,
)

explanation = dynamo.explain(gb.indptr_edge_ids)(
indptr, dtype, offset, output_size
)
expected_breaks = -1 if output_size is None else 0
assert explanation.graph_break_count == expected_breaks


def test_csc_format_base_representation():
csc_format_base = gb.CSCFormatBase(
indptr=torch.tensor([0, 2, 4]),
Expand Down

0 comments on commit d905c10

Please sign in to comment.