Skip to content

Commit

Permalink
[GraphBolt][CUDA] Make fetching indices optional for NS.
Browse files Browse the repository at this point in the history
  • Loading branch information
mfbalin committed Aug 6, 2024
1 parent 4978313 commit 9c09380
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 23 deletions.
5 changes: 4 additions & 1 deletion graphbolt/include/graphbolt/cuda_sampling_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ namespace ops {
* sampled edges, see arXiv:2210.13339.
* @param return_eids Boolean indicating whether edge IDs need to be returned,
* typically used when edge features are required.
* @param indices_return_is_optional Boolean indicating whether returning the
* output indices is optional, typically used if the fetch of indices will
* happen later. Even when true, the output indices are still gathered if the
* indices tensor is on the GPU.
* @param type_per_edge A tensor representing the type of each edge, if present.
* @param probs_or_mask An optional tensor with (unnormalized) probabilities
* corresponding to each neighboring edge of a node. It must be
Expand All @@ -78,7 +81,7 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
torch::optional<torch::Tensor> seeds,
torch::optional<std::vector<int64_t>> seed_offsets,
const std::vector<int64_t>& fanouts, bool replace, bool layer,
bool return_eids,
bool return_eids, bool indices_return_is_optional,
torch::optional<torch::Tensor> type_per_edge = torch::nullopt,
torch::optional<torch::Tensor> probs_or_mask = torch::nullopt,
torch::optional<torch::Tensor> node_type_offset = torch::nullopt,
Expand Down
12 changes: 9 additions & 3 deletions graphbolt/include/graphbolt/fused_sampled_subgraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ struct FusedSampledSubgraph : torch::CustomClassHolder {
* edges that are sorted w.r.t. edge types.
*/
FusedSampledSubgraph(
torch::Tensor indptr, torch::Tensor indices,
torch::Tensor indptr, torch::optional<torch::Tensor> indices,
torch::optional<torch::Tensor> original_column_node_ids,
torch::optional<torch::Tensor> original_row_node_ids = torch::nullopt,
torch::optional<torch::Tensor> original_edge_ids = torch::nullopt,
Expand All @@ -67,7 +67,10 @@ struct FusedSampledSubgraph : torch::CustomClassHolder {
original_row_node_ids(original_row_node_ids),
original_edge_ids(original_edge_ids),
type_per_edge(type_per_edge),
etype_offsets(etype_offsets) {}
etype_offsets(etype_offsets) {
// If indices will be fetched later, we need original_edge_ids.
TORCH_CHECK(indices.has_value() || original_edge_ids.has_value());
}

FusedSampledSubgraph() = default;

Expand All @@ -84,8 +87,11 @@ struct FusedSampledSubgraph : torch::CustomClassHolder {
* original ids. If compacted, the original ids are stored in the
* `original_row_node_ids` field. The indices are sorted w.r.t. their edge
* types for the heterogeneous case.
*
* @note This is optional: it may be left empty when the indices will be
* fetched later using the original_edge_ids tensor, which must then be present.
*/
torch::Tensor indices;
torch::optional<torch::Tensor> indices;

/**
* @brief Column's reverse node ids in the original graph. A graph structure
Expand Down
29 changes: 17 additions & 12 deletions graphbolt/src/cuda/neighbor_sampler.cu
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,8 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
torch::optional<torch::Tensor> seeds,
torch::optional<std::vector<int64_t>> seed_offsets,
const std::vector<int64_t>& fanouts, bool replace, bool layer,
bool return_eids, torch::optional<torch::Tensor> type_per_edge,
bool return_eids, bool indices_return_is_optional,
torch::optional<torch::Tensor> type_per_edge,
torch::optional<torch::Tensor> probs_or_mask,
torch::optional<torch::Tensor> node_type_offset,
torch::optional<torch::Dict<std::string, int64_t>> node_type_to_id,
Expand Down Expand Up @@ -296,7 +297,7 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
}();
auto output_indptr = torch::empty_like(sub_indptr);
torch::Tensor picked_eids;
torch::Tensor output_indices;
torch::optional<torch::Tensor> output_indices;

AT_DISPATCH_INDEX_TYPES(
indptr.scalar_type(), "SampleNeighborsIndptr", ([&] {
Expand Down Expand Up @@ -519,7 +520,9 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
}
}

output_indices = Gather(indices, picked_eids);
if (!indices_return_is_optional || utils::is_on_gpu(indices)) {
output_indices = Gather(indices, picked_eids);
}
}));

torch::optional<torch::Tensor> output_type_per_edge;
Expand Down Expand Up @@ -565,7 +568,7 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
}
auto indices_offsets_device = torch::empty(
etype_id_to_src_ntype_id.size(0),
output_indices.options().dtype(torch::kLong));
picked_eids.options().dtype(torch::kLong));
AT_DISPATCH_INDEX_TYPES(
node_type_offset->scalar_type(), "SampleNeighborsNodeTypeOffset", ([&] {
THRUST_CALL(
Expand Down Expand Up @@ -628,14 +631,16 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
edge_offsets_pinned_device_pair);
}));
edge_offsets_event.record();
auto indices_offset_subtract = ExpandIndptrImpl(
edge_offsets_device, indices.scalar_type(), indices_offsets_device,
output_indices.size(0));
// The output_indices is permuted here.
std::tie(output_indptr, output_indices) = IndexSelectCSCImpl(
output_in_degree, sliced_output_indptr, output_indices, permutation,
num_rows - 1, output_indices.size(0));
output_indices -= indices_offset_subtract;
if (output_indices.has_value()) {
auto indices_offset_subtract = ExpandIndptrImpl(
edge_offsets_device, indices.scalar_type(), indices_offsets_device,
output_indices->size(0));
// The output_indices is permuted here.
std::tie(output_indptr, output_indices) = IndexSelectCSCImpl(
output_in_degree, sliced_output_indptr, *output_indices, permutation,
num_rows - 1, output_indices->size(0));
*output_indices -= indices_offset_subtract;
}
auto output_indptr_offsets = torch::empty(
num_etypes * 2,
c10::TensorOptions().dtype(torch::kLong).pinned_memory(true));
Expand Down
6 changes: 3 additions & 3 deletions graphbolt/src/fused_csc_sampling_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -836,9 +836,9 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors(
c10::DeviceType::CUDA, "SampleNeighbors", {
return ops::SampleNeighbors(
indptr_, indices_, seeds, seed_offsets, fanouts, replace, layer,
return_eids, type_per_edge_, probs_or_mask, node_type_offset_,
node_type_to_id_, edge_type_to_id_, random_seed,
seed2_contribution);
return_eids, false, type_per_edge_, probs_or_mask,
node_type_offset_, node_type_to_id_, edge_type_to_id_,
random_seed, seed2_contribution);
});
}
TORCH_CHECK(seeds.has_value(), "Nodes can not be None on the CPU.");
Expand Down
7 changes: 3 additions & 4 deletions graphbolt/src/index_select.cc
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ std::tuple<torch::Tensor, torch::Tensor> IndexSelectCSC(
}
sampling::FusedCSCSamplingGraph g(indptr, indices);
const auto res = g.InSubgraph(nodes);
return std::make_tuple(res->indptr, res->indices);
return std::make_tuple(res->indptr, res->indices.value());
}

std::tuple<torch::Tensor, std::vector<torch::Tensor>> IndexSelectCSCBatched(
Expand All @@ -136,9 +136,8 @@ std::tuple<torch::Tensor, std::vector<torch::Tensor>> IndexSelectCSCBatched(
sampling::FusedCSCSamplingGraph g(indptr, indices);
const auto res = g.InSubgraph(nodes);
output_indptr = res->indptr;
results.push_back(res->indices);
TORCH_CHECK(res->original_edge_ids.has_value());
edge_ids = *res->original_edge_ids;
results.push_back(res->indices.value());
edge_ids = res->original_edge_ids.value();
}
if (with_edge_ids) results.push_back(edge_ids);
return std::make_tuple(output_indptr, results);
Expand Down

0 comments on commit 9c09380

Please sign in to comment.