Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[GraphBolt] Labor dependent template specialization. #7220

Merged
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions graphbolt/include/graphbolt/continuous_seed.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,30 @@ class continuous_seed {
#endif // __CUDA_ARCH__
};

/**
 * @brief Random source built from one 64-bit seed.
 *
 * Stores a single seed value `s` and exposes `uniform(t)`, which constructs a
 * fresh generator from `s` and `t` on every call and draws one float from it.
 * No generator state is kept between calls.
 */
class single_seed {
// The stored seed; set once by either constructor.
uint64_t s;
mfbalin marked this conversation as resolved.
Show resolved Hide resolved

public:
// Intentionally non-explicit so a plain integer can be used where a
// single_seed is expected. The signed value is converted to uint64_t.
/* implicit */ single_seed(const int64_t seed) : s(seed) {} // NOLINT

// Also non-explicit. Reads the first int64_t element of `seed_arr` through
// data_ptr<int64_t>().
// NOTE(review): assumes the tensor has at least one element and that its
// storage can be dereferenced from host code -- confirm against callers.
single_seed(torch::Tensor seed_arr) : s(seed_arr.data_ptr<int64_t>()[0]) {}

#ifdef __CUDACC__
// Device-only variant, selected when __CUDACC__ is defined.
// @param t Index that selects which random number to produce for this seed.
// @return One float drawn with curand_uniform from a Philox state that is
//         initialised from (kCurandSeed, s, t) on every call.
__device__ inline float uniform(const uint64_t t) const {
mfbalin marked this conversation as resolved.
Show resolved Hide resolved
const uint64_t kCurandSeed = 999961; // Could be any random number.
curandStatePhilox4_32_10_t rng;
// In curand_init's argument order, `s` is the subsequence and `t` the offset.
curand_init(kCurandSeed, s, t, &rng);
return curand_uniform(&rng);
}
#else
// Host variant, selected when __CUDACC__ is not defined.
// @param t Index that selects which random number to produce for this seed.
// @return One float drawn from a default-constructed
//         std::uniform_real_distribution<float> (default range [0, 1)).
inline float uniform(const uint64_t t) const {
// NOTE(review): presumably pcg32's (seed, stream) constructor -- confirm.
pcg32 ng0(s, t);
std::uniform_real_distribution<float> uni;
return uni(ng0);
}
#endif // __CUDACC__
};

} // namespace graphbolt

#endif // GRAPHBOLT_CONTINUOUS_SEED_H_
31 changes: 21 additions & 10 deletions graphbolt/include/graphbolt/fused_csc_sampling_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@
namespace graphbolt {
namespace sampling {

enum SamplerType { NEIGHBOR, LABOR };
enum SamplerType { NEIGHBOR, LABOR, LABOR_DEPENDENT };

/**
 * @brief Tells whether a sampler type is one of the LABOR variants.
 *
 * @param S The sampler type to classify.
 * @return true for SamplerType::LABOR and SamplerType::LABOR_DEPENDENT,
 *         false for every other value.
 */
constexpr bool is_labor(SamplerType S) {
  switch (S) {
    case SamplerType::LABOR:
    case SamplerType::LABOR_DEPENDENT:
      return true;
    default:
      return false;
  }
}

template <SamplerType S>
struct SamplerArgs;
Expand All @@ -27,6 +31,13 @@ struct SamplerArgs<SamplerType::NEIGHBOR> {};

/**
 * @brief Extra arguments carried by the SamplerType::LABOR specialization.
 *
 * This is a plain aggregate: it is brace-initialised in member order
 * (indices, random_seed, num_nodes), so the order of the fields is part of
 * its interface.
 */
template <>
struct SamplerArgs<SamplerType::LABOR> {
// Held by const reference, not by value: the referenced tensor must outlive
// this object and every copy of it.
const torch::Tensor& indices;
// Single-seed random source; stored by value.
single_seed random_seed;
// NOTE(review): presumably the number of nodes in the graph -- confirm
// against the code that constructs this struct.
int64_t num_nodes;
};

template <>
struct SamplerArgs<SamplerType::LABOR_DEPENDENT> {
const torch::Tensor& indices;
continuous_seed random_seed;
int64_t num_nodes;
Expand Down Expand Up @@ -555,12 +566,12 @@ int64_t Pick(
const torch::optional<torch::Tensor>& probs_or_mask,
SamplerArgs<SamplerType::NEIGHBOR> args, PickedType* picked_data_ptr);

template <typename PickedType>
int64_t Pick(
template <SamplerType S, typename PickedType>
std::enable_if_t<is_labor(S), int64_t> Pick(
int64_t offset, int64_t num_neighbors, int64_t fanout, bool replace,
const torch::TensorOptions& options,
const torch::optional<torch::Tensor>& probs_or_mask,
SamplerArgs<SamplerType::LABOR> args, PickedType* picked_data_ptr);
const torch::optional<torch::Tensor>& probs_or_mask, SamplerArgs<S> args,
PickedType* picked_data_ptr);

template <typename PickedType>
int64_t TemporalPick(
Expand Down Expand Up @@ -619,13 +630,13 @@ int64_t TemporalPickByEtype(
PickedType* picked_data_ptr);

template <
bool NonUniform, bool Replace, typename ProbsType, typename PickedType,
int StackSize = 1024>
int64_t LaborPick(
bool NonUniform, bool Replace, typename ProbsType, SamplerType S,
typename PickedType, int StackSize = 1024>
std::enable_if_t<is_labor(S), int64_t> LaborPick(
int64_t offset, int64_t num_neighbors, int64_t fanout,
const torch::TensorOptions& options,
const torch::optional<torch::Tensor>& probs_or_mask,
SamplerArgs<SamplerType::LABOR> args, PickedType* picked_data_ptr);
const torch::optional<torch::Tensor>& probs_or_mask, SamplerArgs<S> args,
PickedType* picked_data_ptr);

} // namespace sampling
} // namespace graphbolt
Expand Down
80 changes: 46 additions & 34 deletions graphbolt/src/fused_csc_sampling_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include <limits>
#include <numeric>
#include <tuple>
#include <type_traits>
#include <vector>

#include "./macro.h"
Expand Down Expand Up @@ -660,26 +661,37 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors(
}

if (layer) {
SamplerArgs<SamplerType::LABOR> args = [&] {
if (random_seed.has_value()) {
return SamplerArgs<SamplerType::LABOR>{
indices_,
{random_seed.value(), static_cast<float>(seed2_contribution)},
NumNodes()};
} else {
return SamplerArgs<SamplerType::LABOR>{
indices_,
RandomEngine::ThreadLocal()->RandInt(
static_cast<int64_t>(0), std::numeric_limits<int64_t>::max()),
NumNodes()};
}
}();
return SampleNeighborsImpl(
nodes.value(), return_eids,
GetNumPickFn(fanouts, replace, type_per_edge_, probs_or_mask),
GetPickFn(
fanouts, replace, indptr_.options(), type_per_edge_, probs_or_mask,
args));
if (random_seed.has_value() && random_seed->numel() >= 2) {
SamplerArgs<SamplerType::LABOR_DEPENDENT> args{
indices_,
{random_seed.value(), static_cast<float>(seed2_contribution)},
NumNodes()};
return SampleNeighborsImpl(
nodes.value(), return_eids,
GetNumPickFn(fanouts, replace, type_per_edge_, probs_or_mask),
GetPickFn(
fanouts, replace, indptr_.options(), type_per_edge_,
probs_or_mask, args));
} else {
auto args = [&] {
if (random_seed.has_value() && random_seed->numel() == 1) {
return SamplerArgs<SamplerType::LABOR>{
indices_, random_seed.value(), NumNodes()};
} else {
return SamplerArgs<SamplerType::LABOR>{
indices_,
RandomEngine::ThreadLocal()->RandInt(
static_cast<int64_t>(0), std::numeric_limits<int64_t>::max()),
NumNodes()};
}
}();
return SampleNeighborsImpl(
nodes.value(), return_eids,
GetNumPickFn(fanouts, replace, type_per_edge_, probs_or_mask),
GetPickFn(
fanouts, replace, indptr_.options(), type_per_edge_,
probs_or_mask, args));
}
} else {
SamplerArgs<SamplerType::NEIGHBOR> args;
return SampleNeighborsImpl(
Expand Down Expand Up @@ -1297,7 +1309,7 @@ int64_t TemporalPick(
}
return picked_indices.numel();
}
if constexpr (S == SamplerType::LABOR) {
if constexpr (is_labor(S)) {
return Pick(
offset, num_neighbors, fanout, replace, options, masked_prob, args,
picked_data_ptr);
Expand Down Expand Up @@ -1383,12 +1395,12 @@ int64_t TemporalPickByEtype(
return pick_offset;
}

template <typename PickedType>
int64_t Pick(
template <SamplerType S, typename PickedType>
std::enable_if_t<is_labor(S), int64_t> Pick(
int64_t offset, int64_t num_neighbors, int64_t fanout, bool replace,
const torch::TensorOptions& options,
const torch::optional<torch::Tensor>& probs_or_mask,
SamplerArgs<SamplerType::LABOR> args, PickedType* picked_data_ptr) {
const torch::optional<torch::Tensor>& probs_or_mask, SamplerArgs<S> args,
PickedType* picked_data_ptr) {
if (fanout == 0) return 0;
if (probs_or_mask.has_value()) {
if (fanout < 0) {
Expand Down Expand Up @@ -1438,9 +1450,9 @@ inline T invcdf(T u, int64_t n, T rem) {
return rem * (one - std::pow(one - u, one / n));
}

template <typename T>
template <typename T, typename seed_t>
inline T jth_sorted_uniform_random(
continuous_seed seed, int64_t t, int64_t c, int64_t j, T& rem, int64_t n) {
seed_t seed, int64_t t, int64_t c, int64_t j, T& rem, int64_t n) {
const T u = seed.uniform(t + j * c);
// https://mathematica.stackexchange.com/a/256707
rem -= invcdf(u, n, rem);
Expand Down Expand Up @@ -1474,13 +1486,13 @@ inline T jth_sorted_uniform_random(
* should be put. Enough memory space should be allocated in advance.
*/
template <
bool NonUniform, bool Replace, typename ProbsType, typename PickedType,
int StackSize>
inline int64_t LaborPick(
bool NonUniform, bool Replace, typename ProbsType, SamplerType S,
typename PickedType, int StackSize>
inline std::enable_if_t<is_labor(S), int64_t> LaborPick(
int64_t offset, int64_t num_neighbors, int64_t fanout,
const torch::TensorOptions& options,
const torch::optional<torch::Tensor>& probs_or_mask,
SamplerArgs<SamplerType::LABOR> args, PickedType* picked_data_ptr) {
const torch::optional<torch::Tensor>& probs_or_mask, SamplerArgs<S> args,
PickedType* picked_data_ptr) {
fanout = Replace ? fanout : std::min(fanout, num_neighbors);
if (!NonUniform && !Replace && fanout >= num_neighbors) {
std::iota(picked_data_ptr, picked_data_ptr + num_neighbors, offset);
Expand All @@ -1504,8 +1516,8 @@ inline int64_t LaborPick(
}
AT_DISPATCH_INDEX_TYPES(
args.indices.scalar_type(), "LaborPickMain", ([&] {
const index_t* local_indices_data =
args.indices.data_ptr<index_t>() + offset;
const auto local_indices_data =
reinterpret_cast<index_t*>(args.indices.data_ptr()) + offset;
if constexpr (Replace) {
// [Algorithm] @mfbalin
// Use a max-heap to get rid of the big random numbers and filter the
Expand Down
Loading