Skip to content

Commit

Permalink
Merge branch 'master' into gb_cuda_sampled_edge_ids
Browse files Browse the repository at this point in the history
  • Loading branch information
mfbalin committed Aug 7, 2024
2 parents 590b8e2 + ae45716 commit 5d77d1f
Show file tree
Hide file tree
Showing 8 changed files with 33 additions and 65 deletions.
3 changes: 0 additions & 3 deletions graphbolt/include/graphbolt/cuda_sampling_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,6 @@ namespace ops {
* @param layer Boolean indicating whether neighbors should be sampled in a
* layer sampling fashion. Uses the LABOR-0 algorithm to increase overlap of
* sampled edges, see arXiv:2210.13339.
* @param return_eids Boolean indicating whether edge IDs need to be returned,
* typically used when edge features are required.
* @param type_per_edge A tensor representing the type of each edge, if present.
* @param probs_or_mask An optional tensor with (unnormalized) probabilities
* corresponding to each neighboring edge of a node. It must be
Expand All @@ -78,7 +76,6 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
torch::optional<torch::Tensor> seeds,
torch::optional<std::vector<int64_t>> seed_offsets,
const std::vector<int64_t>& fanouts, bool replace, bool layer,
bool return_eids,
torch::optional<torch::Tensor> type_per_edge = torch::nullopt,
torch::optional<torch::Tensor> probs_or_mask = torch::nullopt,
torch::optional<torch::Tensor> node_type_offset = torch::nullopt,
Expand Down
11 changes: 3 additions & 8 deletions graphbolt/include/graphbolt/fused_csc_sampling_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -339,8 +339,6 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
* @param layer Boolean indicating whether neighbors should be sampled in a
* layer sampling fashion. Uses the LABOR-0 algorithm to increase overlap of
* sampled edges, see arXiv:2210.13339.
* @param return_eids Boolean indicating whether edge IDs need to be returned,
* typically used when edge features are required.
* @param probs_or_mask An optional edge attribute tensor for probabilities
* or masks. This attribute tensor should contain (unnormalized)
* probabilities corresponding to each neighboring edge of a node. It must be
Expand All @@ -357,7 +355,7 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
torch::optional<torch::Tensor> seeds,
torch::optional<std::vector<int64_t>> seed_offsets,
const std::vector<int64_t>& fanouts, bool replace, bool layer,
bool return_eids, torch::optional<torch::Tensor> probs_or_mask,
torch::optional<torch::Tensor> probs_or_mask,
torch::optional<torch::Tensor> random_seed,
double seed2_contribution) const;

Expand All @@ -378,8 +376,6 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
* @param layer Boolean indicating whether neighbors should be sampled in a
* layer sampling fashion. Uses the LABOR-0 algorithm to increase overlap of
* sampled edges, see arXiv:2210.13339.
* @param return_eids Boolean indicating whether edge IDs need to be returned,
* typically used when edge features are required.
* @param input_nodes_pre_time_window The time window of the nodes represents
* a period of time before `input_nodes_timestamp`. If provided, only
* neighbors and related edges whose timestamps fall within
Expand All @@ -400,7 +396,6 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
const torch::Tensor& input_nodes,
const torch::Tensor& input_nodes_timestamp,
const std::vector<int64_t>& fanouts, bool replace, bool layer,
bool return_eids,
torch::optional<torch::Tensor> input_nodes_pre_time_window,
torch::optional<torch::Tensor> probs_or_mask,
torch::optional<std::string> node_timestamp_attr_name,
Expand Down Expand Up @@ -445,8 +440,8 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
c10::intrusive_ptr<FusedSampledSubgraph> SampleNeighborsImpl(
const torch::Tensor& seeds,
torch::optional<std::vector<int64_t>>& seed_offsets,
const std::vector<int64_t>& fanouts, bool return_eids,
NumPickFn num_pick_fn, PickFn pick_fn) const;
const std::vector<int64_t>& fanouts, NumPickFn num_pick_fn,
PickFn pick_fn) const;

/** @brief CSC format index pointer array. */
torch::Tensor indptr_;
Expand Down
25 changes: 11 additions & 14 deletions graphbolt/include/graphbolt/fused_sampled_subgraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,21 +56,18 @@ struct FusedSampledSubgraph : torch::CustomClassHolder {
*/
FusedSampledSubgraph(
torch::Tensor indptr, torch::optional<torch::Tensor> indices,
torch::Tensor original_edge_ids,
torch::optional<torch::Tensor> original_column_node_ids,
torch::optional<torch::Tensor> original_row_node_ids = torch::nullopt,
torch::optional<torch::Tensor> original_edge_ids = torch::nullopt,
torch::optional<torch::Tensor> type_per_edge = torch::nullopt,
torch::optional<torch::Tensor> etype_offsets = torch::nullopt)
: indptr(indptr),
indices(indices),
original_edge_ids(original_edge_ids),
original_column_node_ids(original_column_node_ids),
original_row_node_ids(original_row_node_ids),
original_edge_ids(original_edge_ids),
type_per_edge(type_per_edge),
etype_offsets(etype_offsets) {
// If indices will be fetched later, we need original_edge_ids.
TORCH_CHECK(indices.has_value() || original_edge_ids.has_value());
}
etype_offsets(etype_offsets) {}

FusedSampledSubgraph() = default;

Expand All @@ -93,6 +90,14 @@ struct FusedSampledSubgraph : torch::CustomClassHolder {
*/
torch::optional<torch::Tensor> indices;

/**
* @brief Reverse edge ids in the original graph: the edge with id
* `original_edge_ids[i]` in the original graph is mapped to `i` in this
* subgraph. This is useful when edge features are needed. The edges are
* sorted w.r.t. their edge types for the heterogeneous case.
*/
torch::Tensor original_edge_ids;

/**
* @brief Column's reverse node ids in the original graph. A graph structure
* can be treated as a coordinated row and column pair, and this is the
Expand All @@ -114,14 +119,6 @@ struct FusedSampledSubgraph : torch::CustomClassHolder {
*/
torch::optional<torch::Tensor> original_row_node_ids;

/**
* @brief Reverse edge ids in the original graph: the edge with id
* `original_edge_ids[i]` in the original graph is mapped to `i` in this
* subgraph. This is useful when edge features are needed. The edges are
* sorted w.r.t. their edge types for the heterogeneous case.
*/
torch::optional<torch::Tensor> original_edge_ids;

/**
* @brief Type id of each edge, where type id is the corresponding index of
* edge types. The length of it is equal to the number of edges in the
Expand Down
2 changes: 1 addition & 1 deletion graphbolt/src/cuda/insubgraph.cu
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> InSubgraph(
output_indptr, sliced_indptr.scalar_type(), sliced_indptr, num_edges);

return c10::make_intrusive<sampling::FusedSampledSubgraph>(
output_indptr, output_indices, nodes, torch::nullopt, edge_ids,
output_indptr, output_indices, edge_ids, nodes, torch::nullopt,
output_type_per_edge);
}

Expand Down
9 changes: 3 additions & 6 deletions graphbolt/src/cuda/neighbor_sampler.cu
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
torch::optional<torch::Tensor> seeds,
torch::optional<std::vector<int64_t>> seed_offsets,
const std::vector<int64_t>& fanouts, bool replace, bool layer,
bool return_eids, torch::optional<torch::Tensor> type_per_edge,
torch::optional<torch::Tensor> type_per_edge,
torch::optional<torch::Tensor> probs_or_mask,
torch::optional<torch::Tensor> node_type_offset,
torch::optional<torch::Dict<std::string, int64_t>> node_type_to_id,
Expand Down Expand Up @@ -697,12 +697,9 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
output_type_per_edge = Gather(*type_per_edge, picked_eids);
}

torch::optional<torch::Tensor> subgraph_reverse_edge_ids = torch::nullopt;
if (return_eids) subgraph_reverse_edge_ids = std::move(picked_eids);

return c10::make_intrusive<sampling::FusedSampledSubgraph>(
output_indptr, output_indices, seeds, torch::nullopt,
subgraph_reverse_edge_ids, output_type_per_edge, edge_offsets);
output_indptr, output_indices, picked_eids, seeds, torch::nullopt,
output_type_per_edge, edge_offsets);
}

} // namespace ops
Expand Down
37 changes: 14 additions & 23 deletions graphbolt/src/fused_csc_sampling_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -298,8 +298,6 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::InSubgraph(
const auto num_seeds = nodes.size(0);
torch::Tensor indptr = torch::empty({num_seeds + 1}, indptr_.dtype());
std::vector<torch::Tensor> indices_arr(num_seeds);
torch::Tensor original_column_node_ids =
torch::empty({num_seeds}, nodes.dtype());
std::vector<torch::Tensor> edge_ids_arr(num_seeds);
std::vector<torch::Tensor> type_per_edge_arr(num_seeds);

Expand All @@ -311,8 +309,6 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::InSubgraph(
AT_DISPATCH_INDEX_TYPES(
nodes.scalar_type(), "InSubgraph::nodes", ([&] {
const auto nodes_data = nodes.data_ptr<index_t>();
auto column_ids_data =
original_column_node_ids.data_ptr<index_t>();
torch::parallel_for(
0, num_seeds, kDefaultGrainSize,
[&](size_t start, size_t end) {
Expand All @@ -321,7 +317,6 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::InSubgraph(
const auto start_idx = indptr_data[node_id];
const auto end_idx = indptr_data[node_id + 1];
out_indptr_data[i + 1] = end_idx - start_idx;
column_ids_data[i] = node_id;
indices_arr[i] = indices_.slice(0, start_idx, end_idx);
edge_ids_arr[i] = torch::arange(
start_idx, end_idx, indptr_.scalar_type());
Expand All @@ -335,8 +330,8 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::InSubgraph(
}));

return c10::make_intrusive<FusedSampledSubgraph>(
indptr.cumsum(0), torch::cat(indices_arr), original_column_node_ids,
torch::arange(0, NumNodes()), torch::cat(edge_ids_arr),
indptr.cumsum(0), torch::cat(indices_arr), torch::cat(edge_ids_arr),
nodes, torch::arange(0, NumNodes()),
type_per_edge_
? torch::optional<torch::Tensor>{torch::cat(type_per_edge_arr)}
: torch::nullopt);
Expand Down Expand Up @@ -513,8 +508,8 @@ c10::intrusive_ptr<FusedSampledSubgraph>
FusedCSCSamplingGraph::SampleNeighborsImpl(
const torch::Tensor& seeds,
torch::optional<std::vector<int64_t>>& seed_offsets,
const std::vector<int64_t>& fanouts, bool return_eids,
NumPickFn num_pick_fn, PickFn pick_fn) const {
const std::vector<int64_t>& fanouts, NumPickFn num_pick_fn,
PickFn pick_fn) const {
const int64_t num_seeds = seeds.size(0);
const auto indptr_options = indptr_.options();

Expand Down Expand Up @@ -796,23 +791,20 @@ FusedCSCSamplingGraph::SampleNeighborsImpl(
}));
}));

torch::optional<torch::Tensor> subgraph_reverse_edge_ids = torch::nullopt;
if (return_eids) subgraph_reverse_edge_ids = std::move(picked_eids);

if (subgraph_indptr_substract.has_value()) {
subgraph_indptr -= subgraph_indptr_substract.value();
}

return c10::make_intrusive<FusedSampledSubgraph>(
subgraph_indptr, subgraph_indices, seeds, torch::nullopt,
subgraph_reverse_edge_ids, subgraph_type_per_edge, edge_offsets);
subgraph_indptr, subgraph_indices, picked_eids, seeds, torch::nullopt,
subgraph_type_per_edge, edge_offsets);
}

c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors(
torch::optional<torch::Tensor> seeds,
torch::optional<std::vector<int64_t>> seed_offsets,
const std::vector<int64_t>& fanouts, bool replace, bool layer,
bool return_eids, torch::optional<torch::Tensor> probs_or_mask,
torch::optional<torch::Tensor> probs_or_mask,
torch::optional<torch::Tensor> random_seed,
double seed2_contribution) const {
// If seeds does not have a value, then we expect all arguments to be resident
Expand All @@ -836,7 +828,7 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors(
c10::DeviceType::CUDA, "SampleNeighbors", {
return ops::SampleNeighbors(
indptr_, indices_, seeds, seed_offsets, fanouts, replace, layer,
return_eids, type_per_edge_, probs_or_mask, node_type_offset_,
type_per_edge_, probs_or_mask, node_type_offset_,
node_type_to_id_, edge_type_to_id_, random_seed,
seed2_contribution);
});
Expand All @@ -862,7 +854,7 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors(
{random_seed.value(), static_cast<float>(seed2_contribution)},
NumNodes()};
return SampleNeighborsImpl<TemporalOption::NOT_TEMPORAL>(
seeds.value(), seed_offsets, fanouts, return_eids,
seeds.value(), seed_offsets, fanouts,
GetNumPickFn(
fanouts, replace, type_per_edge_, probs_or_mask,
with_seed_offsets),
Expand All @@ -883,7 +875,7 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors(
}
}();
return SampleNeighborsImpl<TemporalOption::NOT_TEMPORAL>(
seeds.value(), seed_offsets, fanouts, return_eids,
seeds.value(), seed_offsets, fanouts,
GetNumPickFn(
fanouts, replace, type_per_edge_, probs_or_mask,
with_seed_offsets),
Expand All @@ -894,7 +886,7 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors(
} else {
SamplerArgs<SamplerType::NEIGHBOR> args;
return SampleNeighborsImpl<TemporalOption::NOT_TEMPORAL>(
seeds.value(), seed_offsets, fanouts, return_eids,
seeds.value(), seed_offsets, fanouts,
GetNumPickFn(
fanouts, replace, type_per_edge_, probs_or_mask, with_seed_offsets),
GetPickFn(
Expand All @@ -908,7 +900,6 @@ FusedCSCSamplingGraph::TemporalSampleNeighbors(
const torch::Tensor& input_nodes,
const torch::Tensor& input_nodes_timestamp,
const std::vector<int64_t>& fanouts, bool replace, bool layer,
bool return_eids,
torch::optional<torch::Tensor> input_nodes_pre_time_window,
torch::optional<torch::Tensor> probs_or_mask,
torch::optional<std::string> node_timestamp_attr_name,
Expand Down Expand Up @@ -938,7 +929,7 @@ FusedCSCSamplingGraph::TemporalSampleNeighbors(
{random_seed.value(), static_cast<float>(seed2_contribution)},
NumNodes()};
return SampleNeighborsImpl<TemporalOption::TEMPORAL>(
input_nodes, seed_offsets, fanouts, return_eids,
input_nodes, seed_offsets, fanouts,
GetTemporalNumPickFn(
input_nodes_timestamp, indices_, fanouts, replace, type_per_edge_,
input_nodes_pre_time_window, probs_or_mask, node_timestamp,
Expand All @@ -961,7 +952,7 @@ FusedCSCSamplingGraph::TemporalSampleNeighbors(
}
}();
return SampleNeighborsImpl<TemporalOption::TEMPORAL>(
input_nodes, seed_offsets, fanouts, return_eids,
input_nodes, seed_offsets, fanouts,
GetTemporalNumPickFn(
input_nodes_timestamp, indices_, fanouts, replace, type_per_edge_,
input_nodes_pre_time_window, probs_or_mask, node_timestamp,
Expand All @@ -974,7 +965,7 @@ FusedCSCSamplingGraph::TemporalSampleNeighbors(
} else {
SamplerArgs<SamplerType::NEIGHBOR> args;
return SampleNeighborsImpl<TemporalOption::TEMPORAL>(
input_nodes, seed_offsets, fanouts, return_eids,
input_nodes, seed_offsets, fanouts,
GetTemporalNumPickFn(
input_nodes_timestamp, this->indices_, fanouts, replace,
type_per_edge_, input_nodes_pre_time_window, probs_or_mask,
Expand Down
2 changes: 1 addition & 1 deletion graphbolt/src/index_select.cc
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ std::tuple<torch::Tensor, std::vector<torch::Tensor>> IndexSelectCSCBatched(
const auto res = g.InSubgraph(nodes);
output_indptr = res->indptr;
results.push_back(res->indices.value());
edge_ids = res->original_edge_ids.value();
edge_ids = res->original_edge_ids;
}
if (with_edge_ids) results.push_back(edge_ids);
return std::make_tuple(output_indptr, results);
Expand Down
9 changes: 0 additions & 9 deletions python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -776,7 +776,6 @@ def sample_neighbors(
fanouts,
replace=replace,
probs_or_mask=probs_or_mask,
return_eids=True,
)
return self._convert_to_sampled_subgraph(
C_sampled_subgraph, seed_offsets
Expand Down Expand Up @@ -828,7 +827,6 @@ def _sample_neighbors(
fanouts: torch.Tensor,
replace: bool = False,
probs_or_mask: Optional[torch.Tensor] = None,
return_eids: bool = False,
) -> torch.ScriptObject:
"""Sample neighboring edges of the given nodes and return the induced
subgraph.
Expand Down Expand Up @@ -867,9 +865,6 @@ def _sample_neighbors(
corresponding to each neighboring edge of a node. It must be a 1D
floating-point or boolean tensor, with the number of elements
equalling the total number of edges.
return_eids: bool, optional
Boolean indicating whether to return the original edge IDs of the
sampled edges.
Returns
-------
Expand All @@ -884,7 +879,6 @@ def _sample_neighbors(
fanouts.tolist(),
replace,
False, # is_labor
return_eids,
probs_or_mask,
None, # random_seed, labor parameter
0, # seed2_contribution, labor_parameter
Expand Down Expand Up @@ -1018,7 +1012,6 @@ def sample_layer_neighbors(
fanouts.tolist(),
replace,
True, # is_labor
True, # return_eids
probs_or_mask,
random_seed,
seed2_contribution,
Expand Down Expand Up @@ -1114,7 +1107,6 @@ def temporal_sample_neighbors(
fanouts.tolist(),
replace,
False, # is_labor
True, # return_eids
input_nodes_pre_time_window,
probs_or_mask,
node_timestamp_attr_name,
Expand Down Expand Up @@ -1241,7 +1233,6 @@ def temporal_sample_layer_neighbors(
fanouts.tolist(),
replace,
True, # is_labor
True, # return_eids
input_nodes_pre_time_window,
probs_or_mask,
node_timestamp_attr_name,
Expand Down

0 comments on commit 5d77d1f

Please sign in to comment.