[GraphBolt][CUDA] Specialize non-weighted neighbor sampling implementation.
mfbalin committed Mar 14, 2024
1 parent a272efe commit cd4b0e1
Showing 2 changed files with 219 additions and 112 deletions.
2 changes: 2 additions & 0 deletions graphbolt/include/graphbolt/continuous_seed.h
@@ -56,6 +56,8 @@ class continuous_seed {
c[1] = std::sin(pi * r / 2);
}

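// Returns the seed at index 0 when i == 0 and the seed at index 1 otherwise.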
uint64_t get_seed(int i) const { return s[i != 0]; }

#ifdef __CUDACC__
__device__ inline float uniform(const uint64_t t) const {
const uint64_t kCurandSeed = 999961; // Could be any random number.
329 changes: 217 additions & 112 deletions graphbolt/src/cuda/neighbor_sampler.cu
@@ -17,6 +17,7 @@
#include <algorithm>
#include <array>
#include <cub/cub.cuh>
#include <cuda/atomic>
#include <limits>
#include <numeric>
#include <type_traits>
@@ -30,6 +31,43 @@ namespace ops {

constexpr int BLOCK_SIZE = 128;

/**
* @brief Performs neighbor sampling and fills the edge_ids array with
* original edge ids if sliced_indptr is valid. If not, it fills the
* edge_ids array with numbers up to the node degree.
*/
template <typename indptr_t, typename indices_t>
__global__ void _ComputeRandomsNS(
const int64_t num_edges, const indptr_t* const sliced_indptr,
const indptr_t* const sub_indptr, const indptr_t* const output_indptr,
const indices_t* const csr_rows, const uint64_t random_seed,
indptr_t* edge_ids) {
int64_t i = blockIdx.x * blockDim.x + threadIdx.x;
const int stride = gridDim.x * blockDim.x;

curandStatePhilox4_32_10_t rng;
curand_init(random_seed, i, 0, &rng);

while (i < num_edges) {
const auto row_position = csr_rows[i];
const auto row_offset = i - sub_indptr[row_position];
const auto output_offset = output_indptr[row_position];
const auto fanout = output_indptr[row_position + 1] - output_offset;
const auto rnd =
row_offset < fanout ? row_offset : curand(&rng) % (row_offset + 1);
if (rnd < fanout) {
::cuda::atomic_ref<indptr_t, ::cuda::thread_scope_device> a(
edge_ids[output_offset + rnd]);
const auto edge_id =
row_offset + (sliced_indptr ? sliced_indptr[row_position] : 0);
a.fetch_max(
static_cast<indptr_t>(edge_id), ::cuda::std::memory_order_relaxed);
}

i += stride;
}
}
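/*
 * Illustrative sketch: SampleRowOffsets below is a hypothetical helper, not a
 * GraphBolt function, showing a sequential analogue of the selection rule the
 * kernel above evaluates in parallel for one row with `degree` in-edges and a
 * sample size of `fanout`. Offset t keeps its own slot while t < fanout;
 * afterwards it replaces a uniformly chosen earlier slot with probability
 * fanout / (t + 1), the standard reservoir-sampling step. The kernel resolves
 * concurrent writers to the same slot with an atomic max on the edge id.
 *
 *   #include <algorithm>
 *   #include <cstdint>
 *   #include <random>
 *   #include <vector>
 *
 *   std::vector<int64_t> SampleRowOffsets(
 *       int64_t degree, int64_t fanout, uint64_t seed) {
 *     std::vector<int64_t> picks(std::min(degree, fanout));
 *     std::mt19937_64 rng(seed);
 *     for (int64_t t = 0; t < degree; ++t) {
 *       const int64_t r =
 *           t < fanout ? t : std::uniform_int_distribution<int64_t>(0, t)(rng);
 *       if (r < fanout) picks[r] = t;  // t is the within-row edge offset.
 *     }
 *     return picks;
 *   }
 */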

/**
* @brief Fills the random_arr with random numbers and the edge_ids array with
* original edge ids. When random_arr is sorted along with edge_ids, the first
Expand Down Expand Up @@ -251,119 +289,186 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(

// Find the smallest integer type to store the edge id offsets. We
// synchronize the CUDAEvent so that the access is safe.
max_in_degree_event.synchronize();
const int num_bits =
cuda::NumberOfBits(max_in_degree.data_ptr<indptr_t>()[0]);
std::array<int, 4> type_bits = {8, 16, 32, 64};
const auto type_index =
std::lower_bound(type_bits.begin(), type_bits.end(), num_bits) -
type_bits.begin();
std::array<torch::ScalarType, 5> types = {
torch::kByte, torch::kInt16, torch::kInt32, torch::kLong,
torch::kLong};
auto edge_id_dtype = types[type_index];
AT_DISPATCH_INTEGRAL_TYPES(
edge_id_dtype, "SampleNeighborsEdgeIDs", ([&] {
using edge_id_t = std::make_unsigned_t<scalar_t>;
TORCH_CHECK(
num_bits <= sizeof(edge_id_t) * 8,
"Selected edge_id_t must be capable of storing edge_ids.");
// Using bfloat16 for random numbers works just as reliably as
float32 and provides around a 30% speedup.
using rnd_t = nv_bfloat16;
auto randoms =
allocator.AllocateStorage<rnd_t>(num_edges.value());
auto randoms_sorted =
allocator.AllocateStorage<rnd_t>(num_edges.value());
auto edge_id_segments =
allocator.AllocateStorage<edge_id_t>(num_edges.value());
auto sorted_edge_id_segments =
allocator.AllocateStorage<edge_id_t>(num_edges.value());
AT_DISPATCH_INDEX_TYPES(
indices.scalar_type(), "SampleNeighborsIndices", ([&] {
using indices_t = index_t;
auto probs_or_mask_scalar_type = torch::kFloat32;
if (probs_or_mask.has_value()) {
probs_or_mask_scalar_type =
probs_or_mask.value().scalar_type();
}
GRAPHBOLT_DISPATCH_ALL_TYPES(
probs_or_mask_scalar_type, "SampleNeighborsProbs",
([&] {
using probs_t = scalar_t;
probs_t* sliced_probs_ptr = nullptr;
if (sliced_probs_or_mask.has_value()) {
sliced_probs_ptr = sliced_probs_or_mask.value()
.data_ptr<probs_t>();
}
const indices_t* indices_ptr =
layer ? indices.data_ptr<indices_t>() : nullptr;
const dim3 block(BLOCK_SIZE);
const dim3 grid(
(num_edges.value() + BLOCK_SIZE - 1) /
BLOCK_SIZE);
// Compute row and random number pairs.
CUDA_KERNEL_CALL(
_ComputeRandoms, grid, block, 0,
num_edges.value(),
sliced_indptr.data_ptr<indptr_t>(),
sub_indptr.data_ptr<indptr_t>(),
coo_rows.data_ptr<indices_t>(), sliced_probs_ptr,
indices_ptr, random_seed, randoms.get(),
edge_id_segments.get());
}));
}));

// Sort the random numbers along with the edge ids; after
// sorting, the first fanout elements of each row give us
// the sampled edges.
CUB_CALL(
DeviceSegmentedSort::SortPairs, randoms.get(),
randoms_sorted.get(), edge_id_segments.get(),
sorted_edge_id_segments.get(), num_edges.value(), num_rows,
sub_indptr.data_ptr<indptr_t>(),
sub_indptr.data_ptr<indptr_t>() + 1);

picked_eids = torch::empty(
static_cast<indptr_t>(num_sampled_edges),
sub_indptr.options());

// Need to sort the sampled edges only when fanouts.size() == 1,
// since the multiple-fanout sampling case comes out sorted
// automatically.
if (type_per_edge && fanouts.size() == 1) {
// Ensuring sort result still ends up in sorted_edge_id_segments
std::swap(edge_id_segments, sorted_edge_id_segments);
auto sampled_segment_end_it = thrust::make_transform_iterator(
iota, SegmentEndFunc<indptr_t, decltype(sampled_degree)>{
sub_indptr.data_ptr<indptr_t>(), sampled_degree});
CUB_CALL(
DeviceSegmentedSort::SortKeys, edge_id_segments.get(),
sorted_edge_id_segments.get(), picked_eids.size(0),
num_rows, sub_indptr.data_ptr<indptr_t>(),
sampled_segment_end_it);
}

auto input_buffer_it = thrust::make_transform_iterator(
iota, IteratorFunc<indptr_t, edge_id_t>{
sub_indptr.data_ptr<indptr_t>(),
sorted_edge_id_segments.get()});
auto output_buffer_it = thrust::make_transform_iterator(
iota, IteratorFuncAddOffset<indptr_t, indptr_t>{
output_indptr.data_ptr<indptr_t>(),
sliced_indptr.data_ptr<indptr_t>(),
picked_eids.data_ptr<indptr_t>()});
constexpr int64_t max_copy_at_once =
std::numeric_limits<int32_t>::max();

// Copy the sampled edge ids into picked_eids tensor.
for (int64_t i = 0; i < num_rows; i += max_copy_at_once) {
auto compute_num_bits = [&] {
max_in_degree_event.synchronize();
return cuda::NumberOfBits(max_in_degree.data_ptr<indptr_t>()[0]);
};
if (layer || probs_or_mask.has_value()) {
const int num_bits = compute_num_bits();
std::array<int, 4> type_bits = {8, 16, 32, 64};
const auto type_index =
std::lower_bound(type_bits.begin(), type_bits.end(), num_bits) -
type_bits.begin();
std::array<torch::ScalarType, 5> types = {
torch::kByte, torch::kInt16, torch::kInt32, torch::kLong,
torch::kLong};
auto edge_id_dtype = types[type_index];
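// For example, num_bits == 12 yields type_index == 1, i.e. torch::kInt16.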
AT_DISPATCH_INTEGRAL_TYPES(
edge_id_dtype, "SampleNeighborsEdgeIDs", ([&] {
using edge_id_t = std::make_unsigned_t<scalar_t>;
TORCH_CHECK(
num_bits <= sizeof(edge_id_t) * 8,
"Selected edge_id_t must be capable of storing edge_ids.");
// Using bfloat16 for random numbers works just as reliably as
float32 and provides around a 30% speedup.
using rnd_t = nv_bfloat16;
auto randoms =
allocator.AllocateStorage<rnd_t>(num_edges.value());
auto randoms_sorted =
allocator.AllocateStorage<rnd_t>(num_edges.value());
auto edge_id_segments =
allocator.AllocateStorage<edge_id_t>(num_edges.value());
auto sorted_edge_id_segments =
allocator.AllocateStorage<edge_id_t>(num_edges.value());
AT_DISPATCH_INDEX_TYPES(
indices.scalar_type(), "SampleNeighborsIndices", ([&] {
using indices_t = index_t;
auto probs_or_mask_scalar_type = torch::kFloat32;
if (probs_or_mask.has_value()) {
probs_or_mask_scalar_type =
probs_or_mask.value().scalar_type();
}
GRAPHBOLT_DISPATCH_ALL_TYPES(
probs_or_mask_scalar_type, "SampleNeighborsProbs",
([&] {
using probs_t = scalar_t;
probs_t* sliced_probs_ptr = nullptr;
if (sliced_probs_or_mask.has_value()) {
sliced_probs_ptr = sliced_probs_or_mask.value()
.data_ptr<probs_t>();
}
const indices_t* indices_ptr =
layer ? indices.data_ptr<indices_t>() : nullptr;
const dim3 block(BLOCK_SIZE);
const dim3 grid(
(num_edges.value() + BLOCK_SIZE - 1) /
BLOCK_SIZE);
// Compute row and random number pairs.
CUDA_KERNEL_CALL(
_ComputeRandoms, grid, block, 0,
num_edges.value(),
sliced_indptr.data_ptr<indptr_t>(),
sub_indptr.data_ptr<indptr_t>(),
coo_rows.data_ptr<indices_t>(),
sliced_probs_ptr, indices_ptr, random_seed,
randoms.get(), edge_id_segments.get());
}));
}));

// Sort the random numbers along with the edge ids; after
// sorting, the first fanout elements of each row give us
// the sampled edges.
CUB_CALL(
DeviceCopy::Batched, input_buffer_it + i,
output_buffer_it + i, sampled_degree + i,
std::min(num_rows - i, max_copy_at_once));
}
}));
DeviceSegmentedSort::SortPairs, randoms.get(),
randoms_sorted.get(), edge_id_segments.get(),
sorted_edge_id_segments.get(), num_edges.value(), num_rows,
sub_indptr.data_ptr<indptr_t>(),
sub_indptr.data_ptr<indptr_t>() + 1);

picked_eids = torch::empty(
static_cast<indptr_t>(num_sampled_edges),
sub_indptr.options());

// Need to sort the sampled edges only when fanouts.size() == 1,
// since the multiple-fanout sampling case comes out sorted
// automatically.
if (type_per_edge && fanouts.size() == 1) {
// Ensuring sort result still ends up in
// sorted_edge_id_segments
std::swap(edge_id_segments, sorted_edge_id_segments);
auto sampled_segment_end_it = thrust::make_transform_iterator(
iota,
SegmentEndFunc<indptr_t, decltype(sampled_degree)>{
sub_indptr.data_ptr<indptr_t>(), sampled_degree});
CUB_CALL(
DeviceSegmentedSort::SortKeys, edge_id_segments.get(),
sorted_edge_id_segments.get(), picked_eids.size(0),
num_rows, sub_indptr.data_ptr<indptr_t>(),
sampled_segment_end_it);
}

auto input_buffer_it = thrust::make_transform_iterator(
iota, IteratorFunc<indptr_t, edge_id_t>{
sub_indptr.data_ptr<indptr_t>(),
sorted_edge_id_segments.get()});
auto output_buffer_it = thrust::make_transform_iterator(
iota, IteratorFuncAddOffset<indptr_t, indptr_t>{
output_indptr.data_ptr<indptr_t>(),
sliced_indptr.data_ptr<indptr_t>(),
picked_eids.data_ptr<indptr_t>()});
constexpr int64_t max_copy_at_once =
std::numeric_limits<int32_t>::max();

// Copy the sampled edge ids into picked_eids tensor.
for (int64_t i = 0; i < num_rows; i += max_copy_at_once) {
CUB_CALL(
DeviceCopy::Batched, input_buffer_it + i,
output_buffer_it + i, sampled_degree + i,
std::min(num_rows - i, max_copy_at_once));
}
}));
} else { // Non-weighted neighbor sampling.
picked_eids = torch::zeros(num_edges.value(), sub_indptr.options());
const auto sort_needed = type_per_edge && fanouts.size() == 1;
const auto sliced_indptr_ptr =
sort_needed ? nullptr : sliced_indptr.data_ptr<indptr_t>();
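// When the segmented sort below is needed, pass nullptr so that the kernel
// writes within-row edge offsets: they fit a narrow sort key type, and each
// row's base from sliced_indptr is added back after the sort.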

const dim3 block(BLOCK_SIZE);
const dim3 grid(
(std::min(num_edges.value(), static_cast<int64_t>(1 << 20)) +
BLOCK_SIZE - 1) /
BLOCK_SIZE);
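// The grid covers at most 1 << 20 edges directly; the kernel's grid-stride
// loop picks up any remaining edges.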
AT_DISPATCH_INDEX_TYPES(
indices.scalar_type(), "SampleNeighborsIndices", ([&] {
using indices_t = index_t;
// Compute row and random number pairs.
CUDA_KERNEL_CALL(
_ComputeRandomsNS, grid, block, 0, num_edges.value(),
sliced_indptr_ptr, sub_indptr.data_ptr<indptr_t>(),
output_indptr.data_ptr<indptr_t>(),
coo_rows.data_ptr<indices_t>(), random_seed.get_seed(0),
picked_eids.data_ptr<indptr_t>());
}));

picked_eids =
picked_eids.slice(0, 0, static_cast<indptr_t>(num_sampled_edges));

// Need to sort the sampled edges only when fanouts.size() == 1,
// since the multiple-fanout sampling case comes out sorted
// automatically.
if (sort_needed) {
const int num_bits = compute_num_bits();
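// kByte is unsigned, so all 8 bits hold values; the wider candidates are
// signed, which leaves 15/31/63 usable value bits.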
std::array<int, 4> type_bits = {8, 15, 31, 63};
const auto type_index =
std::lower_bound(type_bits.begin(), type_bits.end(), num_bits) -
type_bits.begin();
std::array<torch::ScalarType, 5> types = {
torch::kByte, torch::kInt16, torch::kInt32, torch::kLong,
torch::kLong};
auto edge_id_dtype = types[type_index];
AT_DISPATCH_INTEGRAL_TYPES(
edge_id_dtype, "SampleNeighborsEdgeIDs", ([&] {
using edge_id_t = scalar_t;
TORCH_CHECK(
num_bits <= sizeof(edge_id_t) * 8,
"Selected edge_id_t must be capable of storing "
"edge_ids.");
auto picked_offsets = picked_eids.to(edge_id_dtype);
auto sorted_offsets = torch::empty_like(picked_offsets);
CUB_CALL(
DeviceSegmentedSort::SortKeys,
picked_offsets.data_ptr<edge_id_t>(),
sorted_offsets.data_ptr<edge_id_t>(), picked_eids.size(0),
num_rows, output_indptr.data_ptr<indptr_t>(),
output_indptr.data_ptr<indptr_t>() + 1);
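// ExpandIndptrImpl repeats each row's sliced_indptr value across that row's
// sampled edges, so the addition below turns the sorted within-row offsets
// back into global edge ids.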
auto edge_id_offsets = ExpandIndptrImpl(
output_indptr, picked_eids.scalar_type(), sliced_indptr,
picked_eids.size(0));
picked_eids = sorted_offsets.to(picked_eids.scalar_type()) +
edge_id_offsets;
}));
}
}

output_indices = torch::empty(
picked_eids.size(0),
