Skip to content

Commit

Permalink
Remove unnecessary argument.
Browse files Browse the repository at this point in the history
  • Loading branch information
mfbalin committed Aug 7, 2024
1 parent 9c09380 commit 93a519e
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 10 deletions.
5 changes: 1 addition & 4 deletions graphbolt/include/graphbolt/cuda_sampling_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,6 @@ namespace ops {
* sampled edges, see arXiv:2210.13339.
* @param return_eids Boolean indicating whether edge IDs need to be returned,
* typically used when edge features are required.
* @param indices_return_is_optional Boolean indicating whether returning the
* output indices is optional, typically used if the fetch of indices will
* happen later.
* @param type_per_edge A tensor representing the type of each edge, if present.
* @param probs_or_mask An optional tensor with (unnormalized) probabilities
* corresponding to each neighboring edge of a node. It must be
Expand All @@ -81,7 +78,7 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
torch::optional<torch::Tensor> seeds,
torch::optional<std::vector<int64_t>> seed_offsets,
const std::vector<int64_t>& fanouts, bool replace, bool layer,
bool return_eids, bool indices_return_is_optional,
bool return_eids,
torch::optional<torch::Tensor> type_per_edge = torch::nullopt,
torch::optional<torch::Tensor> probs_or_mask = torch::nullopt,
torch::optional<torch::Tensor> node_type_offset = torch::nullopt,
Expand Down
7 changes: 4 additions & 3 deletions graphbolt/src/cuda/neighbor_sampler.cu
Original file line number Diff line number Diff line change
Expand Up @@ -202,8 +202,7 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
torch::optional<torch::Tensor> seeds,
torch::optional<std::vector<int64_t>> seed_offsets,
const std::vector<int64_t>& fanouts, bool replace, bool layer,
bool return_eids, bool indices_return_is_optional,
torch::optional<torch::Tensor> type_per_edge,
bool return_eids, torch::optional<torch::Tensor> type_per_edge,
torch::optional<torch::Tensor> probs_or_mask,
torch::optional<torch::Tensor> node_type_offset,
torch::optional<torch::Dict<std::string, int64_t>> node_type_to_id,
Expand Down Expand Up @@ -520,7 +519,9 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
}
}

if (!indices_return_is_optional || utils::is_on_gpu(indices)) {
// TODO @mfbalin: remove true from here once fetching indices later is
// setup.
if (true || layer || utils::is_on_gpu(indices)) {
output_indices = Gather(indices, picked_eids);
}
}));
Expand Down
6 changes: 3 additions & 3 deletions graphbolt/src/fused_csc_sampling_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -836,9 +836,9 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors(
c10::DeviceType::CUDA, "SampleNeighbors", {
return ops::SampleNeighbors(
indptr_, indices_, seeds, seed_offsets, fanouts, replace, layer,
return_eids, false, type_per_edge_, probs_or_mask,
node_type_offset_, node_type_to_id_, edge_type_to_id_,
random_seed, seed2_contribution);
return_eids, type_per_edge_, probs_or_mask, node_type_offset_,
node_type_to_id_, edge_type_to_id_, random_seed,
seed2_contribution);
});
}
TORCH_CHECK(seeds.has_value(), "Nodes can not be None on the CPU.");
Expand Down

0 comments on commit 93a519e

Please sign in to comment.