Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[GraphBolt][CUDA] Make fetching indices optional for NS. #7662

Merged
merged 2 commits into from
Aug 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions graphbolt/include/graphbolt/fused_sampled_subgraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ struct FusedSampledSubgraph : torch::CustomClassHolder {
* edges that are sorted w.r.t. edge types.
*/
FusedSampledSubgraph(
torch::Tensor indptr, torch::Tensor indices,
torch::Tensor indptr, torch::optional<torch::Tensor> indices,
torch::optional<torch::Tensor> original_column_node_ids,
torch::optional<torch::Tensor> original_row_node_ids = torch::nullopt,
torch::optional<torch::Tensor> original_edge_ids = torch::nullopt,
Expand All @@ -67,7 +67,10 @@ struct FusedSampledSubgraph : torch::CustomClassHolder {
original_row_node_ids(original_row_node_ids),
original_edge_ids(original_edge_ids),
type_per_edge(type_per_edge),
etype_offsets(etype_offsets) {}
etype_offsets(etype_offsets) {
// If indices will be fetched later, we need original_edge_ids.
TORCH_CHECK(indices.has_value() || original_edge_ids.has_value());
}

FusedSampledSubgraph() = default;

Expand All @@ -84,8 +87,11 @@ struct FusedSampledSubgraph : torch::CustomClassHolder {
* original ids. If compacted, the original ids are stored in the
* `original_row_node_ids` field. The indices are sorted w.r.t. their edge
* types for the heterogenous case.
*
* @note This is optional if its fetch operation will be performed later using
* the original_edge_ids tensor.
*/
torch::Tensor indices;
torch::optional<torch::Tensor> indices;

/**
* @brief Column's reverse node ids in the original graph. A graph structure
Expand Down
28 changes: 17 additions & 11 deletions graphbolt/src/cuda/neighbor_sampler.cu
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
}();
auto output_indptr = torch::empty_like(sub_indptr);
torch::Tensor picked_eids;
torch::Tensor output_indices;
torch::optional<torch::Tensor> output_indices;

AT_DISPATCH_INDEX_TYPES(
indptr.scalar_type(), "SampleNeighborsIndptr", ([&] {
Expand Down Expand Up @@ -519,7 +519,11 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
}
}

output_indices = Gather(indices, picked_eids);
// TODO @mfbalin: remove true from here once fetching indices later is
// setup.
if (true || layer || utils::is_on_gpu(indices)) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't catch it in previous review, but layer is not a good name for boolean.

output_indices = Gather(indices, picked_eids);
}
}));

torch::optional<torch::Tensor> output_type_per_edge;
Expand Down Expand Up @@ -565,7 +569,7 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
}
auto indices_offsets_device = torch::empty(
etype_id_to_src_ntype_id.size(0),
output_indices.options().dtype(torch::kLong));
picked_eids.options().dtype(torch::kLong));
AT_DISPATCH_INDEX_TYPES(
node_type_offset->scalar_type(), "SampleNeighborsNodeTypeOffset", ([&] {
THRUST_CALL(
Expand Down Expand Up @@ -628,14 +632,16 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
edge_offsets_pinned_device_pair);
}));
edge_offsets_event.record();
auto indices_offset_subtract = ExpandIndptrImpl(
edge_offsets_device, indices.scalar_type(), indices_offsets_device,
output_indices.size(0));
// The output_indices is permuted here.
std::tie(output_indptr, output_indices) = IndexSelectCSCImpl(
output_in_degree, sliced_output_indptr, output_indices, permutation,
num_rows - 1, output_indices.size(0));
output_indices -= indices_offset_subtract;
if (output_indices.has_value()) {
auto indices_offset_subtract = ExpandIndptrImpl(
edge_offsets_device, indices.scalar_type(), indices_offsets_device,
output_indices->size(0));
// The output_indices is permuted here.
std::tie(output_indptr, output_indices) = IndexSelectCSCImpl(
output_in_degree, sliced_output_indptr, *output_indices, permutation,
num_rows - 1, output_indices->size(0));
*output_indices -= indices_offset_subtract;
}
auto output_indptr_offsets = torch::empty(
num_etypes * 2,
c10::TensorOptions().dtype(torch::kLong).pinned_memory(true));
Expand Down
7 changes: 3 additions & 4 deletions graphbolt/src/index_select.cc