Skip to content

Commit

Permalink
[GraphBolt][CUDA] Fix overlap_graph_fetch edge_ids.
Browse files Browse the repository at this point in the history
  • Loading branch information
mfbalin committed Jul 26, 2024
1 parent ea28196 commit cc04d3a
Show file tree
Hide file tree
Showing 8 changed files with 51 additions and 12 deletions.
5 changes: 4 additions & 1 deletion graphbolt/include/graphbolt/cuda_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -131,14 +131,17 @@ std::tuple<torch::Tensor, torch::Tensor> IndexSelectCSCImpl(
* @param indices_list Vector of indices tensors with edge information, each of
* shape (indptr[N],).
* @param nodes Nodes tensor with shape (M,).
* @param with_edge_ids Whether to append the edge ids tensor corresponding to
* the sliced edges as the last element of the output vector.
* @param output_size The total number of edges being copied.
* @return (torch::Tensor, std::vector<torch::Tensor>) Output indptr and vector
* of indices tensors of shapes (M + 1,) and ((indptr[nodes + 1] -
* indptr[nodes]).sum(),).
*/
std::tuple<torch::Tensor, std::vector<torch::Tensor>> IndexSelectCSCBatchedImpl(
torch::Tensor indptr, std::vector<torch::Tensor> indices_list,
torch::Tensor nodes, torch::optional<int64_t> output_size);
torch::Tensor nodes, bool with_edge_ids,
torch::optional<int64_t> output_size);

/**
* @brief Slices the indptr tensor with nodes and returns the indegrees of the
Expand Down
8 changes: 7 additions & 1 deletion graphbolt/src/cuda/index_select_csc_impl.cu
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,8 @@ std::tuple<torch::Tensor, torch::Tensor> IndexSelectCSCImpl(

std::tuple<torch::Tensor, std::vector<torch::Tensor>> IndexSelectCSCBatchedImpl(
torch::Tensor indptr, std::vector<torch::Tensor> indices_list,
torch::Tensor nodes, torch::optional<int64_t> output_size) {
torch::Tensor nodes, bool with_edge_ids,
torch::optional<int64_t> output_size) {
auto [in_degree, sliced_indptr] = SliceCSCIndptr(indptr, nodes);
std::vector<torch::Tensor> results;
results.reserve(indices_list.size());
Expand All @@ -322,6 +323,11 @@ std::tuple<torch::Tensor, std::vector<torch::Tensor>> IndexSelectCSCBatchedImpl(
TORCH_CHECK(*output_size == output_indices.size(0));
results.push_back(output_indices);
}
if (with_edge_ids) {
results.push_back(IndptrEdgeIdsImpl(
output_indptr, sliced_indptr.scalar_type(), sliced_indptr,
output_size));
}
return {output_indptr, results};
}

Expand Down
9 changes: 7 additions & 2 deletions graphbolt/src/index_select.cc
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,8 @@ std::tuple<torch::Tensor, torch::Tensor> IndexSelectCSC(

std::tuple<torch::Tensor, std::vector<torch::Tensor>> IndexSelectCSCBatched(
torch::Tensor indptr, std::vector<torch::Tensor> indices_list,
torch::Tensor nodes, torch::optional<int64_t> output_size) {
torch::Tensor nodes, bool with_edge_ids,
torch::optional<int64_t> output_size) {
for (auto& indices : indices_list) {
TORCH_CHECK(
indices.sizes().size() == 1,
Expand All @@ -129,11 +130,12 @@ std::tuple<torch::Tensor, std::vector<torch::Tensor>> IndexSelectCSCBatched(
GRAPHBOLT_DISPATCH_CUDA_ONLY_DEVICE(
c10::DeviceType::CUDA, "IndexSelectCSCImpl", {
return IndexSelectCSCBatchedImpl(
indptr, indices_list, nodes, output_size);
indptr, indices_list, nodes, with_edge_ids, output_size);
});
}
std::vector<torch::Tensor> results;
torch::Tensor output_indptr;
torch::Tensor edge_ids;
for (auto& indices : indices_list) {
// @todo: The CPU supports only integer dtypes for indices tensor.
TORCH_CHECK(
Expand All @@ -144,7 +146,10 @@ std::tuple<torch::Tensor, std::vector<torch::Tensor>> IndexSelectCSCBatched(
const auto res = g.InSubgraph(nodes);
output_indptr = res->indptr;
results.push_back(res->indices);
TORCH_CHECK(res->original_edge_ids.has_value());
edge_ids = *res->original_edge_ids;
}
if (with_edge_ids) results.push_back(edge_ids);
return std::make_tuple(output_indptr, results);
}

Expand Down
6 changes: 5 additions & 1 deletion graphbolt/src/index_select.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,14 +79,18 @@ c10::intrusive_ptr<Future<torch::Tensor>> ScatterAsync(
* @param indices_list Vector of indices tensors with edge information, each of
* shape (indptr[N],).
* @param nodes Nodes tensor with shape (M,).
* @param with_edge_ids Whether to append the edge ids tensor corresponding to
* the sliced edges as the last element of the output vector.
* @param output_size The total number of edges being copied.
*
* @return (torch::Tensor, std::vector<torch::Tensor>) Output indptr and vector
* of indices tensors of shapes (M + 1,) and ((indptr[nodes + 1] -
* indptr[nodes]).sum(),).
*/
std::tuple<torch::Tensor, std::vector<torch::Tensor>> IndexSelectCSCBatched(
torch::Tensor indptr, std::vector<torch::Tensor> indices_list,
torch::Tensor nodes, torch::optional<int64_t> output_size);
torch::Tensor nodes, bool with_edge_ids,
torch::optional<int64_t> output_size);

} // namespace ops
} // namespace graphbolt
Expand Down
13 changes: 11 additions & 2 deletions python/dgl/graphbolt/impl/neighbor_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from torch.utils.data import functional_datapipe
from torchdata.datapipes.iter import Mapper

from ..base import ORIGINAL_EDGE_ID
from ..internal import compact_csc_format, unique_and_compact_csc_formats
from ..minibatch_transformer import MiniBatchTransformer

Expand Down Expand Up @@ -78,6 +79,10 @@ def _combine_per_layer(self, minibatch):
subgraph.add_edge_attribute(self.prob_name, edge_tensors[0])
edge_tensors = edge_tensors[1:]
assert len(edge_tensors) == 0
# TODO @mfbalin: remove these lines after fixing cache for edge ids.
edge_attributes = subgraph.edge_attributes
edge_attributes.pop(ORIGINAL_EDGE_ID)
subgraph.edge_attributes = edge_attributes

return minibatch

Expand Down Expand Up @@ -167,7 +172,7 @@ def record_stream(tensor):
indptr,
sliced_tensors,
) = torch.ops.graphbolt.index_select_csc_batched(
self.graph.csc_indptr, tensors_to_be_sliced, seeds, None
self.graph.csc_indptr, tensors_to_be_sliced, seeds, True, None
)
for tensor in [indptr] + sliced_tensors:
record_stream(tensor)
Expand All @@ -185,6 +190,9 @@ def record_stream(tensor):
if has_probs_or_mask:
probs_or_mask = sliced_tensors[0]
sliced_tensors = sliced_tensors[1:]

edge_ids = sliced_tensors[0]
sliced_tensors = sliced_tensors[1:]
assert len(sliced_tensors) == 0

subgraph = fused_csc_sampling_graph(
Expand All @@ -196,7 +204,8 @@ def record_stream(tensor):
edge_type_to_id=self.graph.edge_type_to_id,
)
if self.prob_name is not None and probs_or_mask is not None:
subgraph.edge_attributes = {self.prob_name: probs_or_mask}
subgraph.add_edge_attribute(self.prob_name, probs_or_mask)
subgraph.add_edge_attribute(ORIGINAL_EDGE_ID, edge_ids)

subgraph._indptr_node_type_offset_list = seed_offsets
minibatch._sliced_sampling_graph = subgraph
Expand Down
10 changes: 8 additions & 2 deletions tests/python/pytorch/graphbolt/impl/test_in_subgraph_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,10 @@
)
@pytest.mark.parametrize("idtype", [torch.int32, torch.int64])
@pytest.mark.parametrize("is_pinned", [False, True])
@pytest.mark.parametrize("with_edge_ids", [False, True])
@pytest.mark.parametrize("output_size", [None, True])
def test_index_select_csc(
indptr_dtype, indices_dtype, idtype, is_pinned, output_size
indptr_dtype, indices_dtype, idtype, is_pinned, with_edge_ids, output_size
):
"""Original graph in COO:
1 0 1 0 1 0
Expand All @@ -50,6 +51,9 @@ def test_index_select_csc(
indptr = indptr.cuda()
indices = indices.cuda()
index = index.cuda()
edge_ids = torch.tensor(
[0, 1, 2, 12, 13, 7, 8], dtype=indptr_dtype, device=index.device
)

if output_size:
output_size = len(cpu_indices)
Expand All @@ -75,12 +79,14 @@ def test_index_select_csc(
gpu_indptr2,
gpu_indices_list,
) = torch.ops.graphbolt.index_select_csc_batched(
indptr, indices_list, index, output_size_selection
indptr, indices_list, index, with_edge_ids, output_size_selection
)

assert torch.equal(gpu_indptr, gpu_indptr2)
assert torch.equal(gpu_indices_list[0], gpu_indices)
assert torch.equal(gpu_indices_list[1], gpu_indices.int())
if with_edge_ids:
assert torch.equal(gpu_indices_list[2], edge_ids)


def test_InSubgraphSampler_homo():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def test_NeighborSampler_GraphFetch(
assert len(expected_results) == len(new_results)
for a, b in zip(expected_results, new_results):
# TODO @mfbalin: Fix the edge id bug with cached edges and enable this check
# when num_cached_edges != 0 as well.
assert True or repr(a) == repr(b)
assert num_cached_edges != 0 or repr(a) == repr(b)


@pytest.mark.parametrize("layer_dependency", [False, True])
Expand Down
10 changes: 8 additions & 2 deletions tests/python/pytorch/graphbolt/test_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,10 @@ def test_gpu_sampling_DataLoader(
assert "b" in minibatch.node_features
assert "c" in minibatch.node_features
# TODO @mfbalin: enable this if.
if False and sampler_name == "LayerNeighborSampler":
if (
num_gpu_cached_edges == 0
and sampler_name == "LayerNeighborSampler"
):
assert torch.equal(
minibatch.node_features["a"], minibatch2.node_features["a"]
)
Expand All @@ -174,6 +177,9 @@ def test_gpu_sampling_DataLoader(
edge_feature = minibatch.edge_features[layer_id]["d"]
edge_feature_ref = minibatch2.edge_features[layer_id]["d"]
# TODO @mfbalin: enable this if.
if False and sampler_name == "LayerNeighborSampler":
if (
num_gpu_cached_edges == 0
and sampler_name == "LayerNeighborSampler"
):
assert torch.equal(edge_feature, edge_feature_ref)
assert len(list(dataloader)) == N // B

0 comments on commit cc04d3a

Please sign in to comment.