Skip to content

Commit

Permalink
Further reduce the size size of the dataset descriptor and add explic…
Browse files Browse the repository at this point in the history
…it loading from shmem for more of its members
  • Loading branch information
achirkin committed Sep 2, 2024
1 parent dc75f7a commit 6630a99
Show file tree
Hide file tree
Showing 7 changed files with 115 additions and 56 deletions.
64 changes: 54 additions & 10 deletions cpp/src/neighbors/detail/cagra/compute_distance.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
// TODO: This shouldn't be invoking spatial/knn
#include "../ann_utils.cuh"

#include <raft/util/device_loads_stores.cuh>
#include <raft/util/vectorized.cuh>

#include <type_traits>
Expand Down Expand Up @@ -67,6 +68,30 @@ struct alignas(device::LOAD_128BIT_T) dataset_descriptor_base_t {
}
};

struct smem_and_team_size_t {
uint32_t value;
RAFT_INLINE_FUNCTION constexpr smem_and_team_size_t(uint32_t smem_size_bytes,
uint32_t team_size_bitshift)
: value{(team_size_bitshift << 24) | smem_size_bytes}
{
}
/** Total dynamic shared memory required by the descriptor. */
RAFT_INLINE_FUNCTION constexpr auto smem_ws_size_in_bytes() const noexcept -> uint32_t
{
return value & 0xffffffu;
}
RAFT_INLINE_FUNCTION constexpr auto team_size_bitshift() const noexcept -> uint32_t
{
return (value >> 24) & 0xffu;
}
/** How many threads are involved in computing a single distance. */
RAFT_INLINE_FUNCTION constexpr auto team_size() const noexcept -> uint32_t
{
return 1u << team_size_bitshift();
}
};
static_assert(sizeof(smem_and_team_size_t) == sizeof(uint32_t));

using setup_workspace_type = const base_type*(const base_type*, void*, const DATA_T*, uint32_t);
using compute_distance_type = DISTANCE_T(const args_t, const INDEX_T);

Expand All @@ -79,10 +104,7 @@ struct alignas(device::LOAD_128BIT_T) dataset_descriptor_base_t {
* given by the dataset_index. */
compute_distance_type* compute_distance_impl;
void* extra_ptr3;
/** How many threads are involved in computing a single distance. */
uint32_t team_size;
/** Total dynamic shared memory required by the descriptor. */
uint32_t smem_ws_size_in_bytes;
smem_and_team_size_t smem_and_team_size;

/** Number of records in the database. */
INDEX_T size;
Expand All @@ -91,17 +113,39 @@ struct alignas(device::LOAD_128BIT_T) dataset_descriptor_base_t {
compute_distance_type* compute_distance_impl,
INDEX_T size,
uint32_t dim,
uint32_t team_size,
uint32_t team_size_bitshift,
uint32_t smem_ws_size_in_bytes)
: setup_workspace_impl(setup_workspace_impl),
compute_distance_impl(compute_distance_impl),
size(size),
team_size(team_size),
smem_ws_size_in_bytes(smem_ws_size_in_bytes),
smem_and_team_size(smem_ws_size_in_bytes, team_size_bitshift),
args{nullptr, nullptr, 0, dim, 0, 0}
{
}

/** Total dynamic shared memory required by the descriptor. */
RAFT_INLINE_FUNCTION constexpr auto smem_ws_size_in_bytes() const noexcept -> uint32_t
{
return smem_and_team_size.smem_ws_size_in_bytes();
}
RAFT_INLINE_FUNCTION constexpr auto team_size_bitshift() const noexcept -> uint32_t
{
return smem_and_team_size.team_size_bitshift();
}
RAFT_DEVICE_INLINE_FUNCTION constexpr auto team_size_bitshift_from_smem() const noexcept
-> uint32_t
{
uint32_t sts;
raft::lds(sts, reinterpret_cast<const uint32_t*>(&smem_and_team_size));
return reinterpret_cast<smem_and_team_size_t&>(sts).team_size_bitshift();
}

/** How many threads are involved in computing a single distance. */
RAFT_INLINE_FUNCTION constexpr auto team_size() const noexcept -> uint32_t
{
return smem_and_team_size.team_size();
}

RAFT_DEVICE_INLINE_FUNCTION auto setup_workspace(void* smem_ptr,
const DATA_T* queries_ptr,
uint32_t query_id) const -> const base_type*
Expand All @@ -113,7 +157,7 @@ struct alignas(device::LOAD_128BIT_T) dataset_descriptor_base_t {
-> DISTANCE_T
{
auto per_thread_distances = valid ? compute_distance_impl(args.load(), dataset_index) : 0;
return device::team_sum(per_thread_distances, this->team_size);
return device::team_sum(per_thread_distances, team_size_bitshift_from_smem());
}
};

Expand All @@ -130,8 +174,8 @@ struct dataset_descriptor_host {
rmm::cuda_stream_view stream,
uint32_t dataset_block_dim)
: stream_{stream},
smem_ws_size_in_bytes{dd_host.smem_ws_size_in_bytes},
team_size{dd_host.team_size},
smem_ws_size_in_bytes{dd_host.smem_ws_size_in_bytes()},
team_size{dd_host.team_size()},
dataset_block_dim{dataset_block_dim}
{
RAFT_CUDA_TRY(cudaMallocAsync(&dev_ptr, sizeof(DescriptorImpl), stream_));
Expand Down
4 changes: 2 additions & 2 deletions cpp/src/neighbors/detail/cagra/compute_distance_standard.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
#include <raft/core/logger-macros.hpp>
#include <raft/core/operators.hpp>
#include <raft/util/device_loads_stores.cuh>

#include <raft/util/pow2_utils.cuh>
#include <raft/util/vectorized.cuh>

#include <type_traits>
Expand Down Expand Up @@ -113,7 +113,7 @@ struct standard_dataset_descriptor_t : public dataset_descriptor_base_t<DataT, I
compute_distance_impl,
size,
dim,
TeamSize,
raft::Pow2<TeamSize>::Log2,
get_smem_ws_size_in_bytes(dim))
{
standard_dataset_descriptor_t::ptr(args) = ptr;
Expand Down
3 changes: 2 additions & 1 deletion cpp/src/neighbors/detail/cagra/compute_distance_vpq.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include <cuvs/distance/distance.hpp>
#include <raft/util/device_loads_stores.cuh>
#include <raft/util/integer_utils.hpp>
#include <raft/util/pow2_utils.cuh>

namespace cuvs::neighbors::cagra::detail {

Expand Down Expand Up @@ -124,7 +125,7 @@ struct cagra_q_dataset_descriptor_t : public dataset_descriptor_base_t<DataT, In
compute_distance_impl,
size,
dim,
TeamSize,
raft::Pow2<TeamSize>::Log2,
get_smem_ws_size_in_bytes(dim))
{
cagra_q_dataset_descriptor_t::encoded_dataset_ptr(args) = encoded_dataset_ptr;
Expand Down
81 changes: 47 additions & 34 deletions cpp/src/neighbors/detail/cagra/device_common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,29 @@ RAFT_DEVICE_INLINE_FUNCTION constexpr T swizzling(T x)
}
}

template <uint32_t TeamSize, typename T>
RAFT_DEVICE_INLINE_FUNCTION auto team_sum(T x) -> T
{
#pragma unroll
for (uint32_t stride = TeamSize >> 1; stride > 0; stride >>= 1) {
x += raft::shfl_xor(x, stride, TeamSize);
}
return x;
}

template <typename T>
RAFT_DEVICE_INLINE_FUNCTION auto team_sum(T x, uint32_t team_size_bitshift) -> T
{
switch (team_size_bitshift) {
case 5: x += raft::shfl_xor(x, 16);
case 4: x += raft::shfl_xor(x, 8);
case 3: x += raft::shfl_xor(x, 4);
case 2: x += raft::shfl_xor(x, 2);
case 1: x += raft::shfl_xor(x, 1);
default: return x;
}
}

template <typename IndexT,
typename DistanceT,
typename DATASET_DESCRIPTOR_T>
Expand All @@ -88,10 +111,12 @@ RAFT_DEVICE_INLINE_FUNCTION void compute_distance_to_random_nodes(
const uint32_t block_id = 0,
const uint32_t num_blocks = 1)
{
const auto team_size = dataset_desc.team_size;
const auto max_i = raft::round_up_safe<uint32_t>(num_pickup, warp_size / team_size);
const auto team_size_bits = dataset_desc.team_size_bitshift_from_smem();
const auto max_i = raft::round_up_safe<uint32_t>(num_pickup, warp_size >> team_size_bits);
const auto compute_distance = dataset_desc.compute_distance_impl;
const auto args = dataset_desc.args.load();

for (uint32_t i = threadIdx.x / team_size; i < max_i; i += blockDim.x / team_size) {
for (uint32_t i = threadIdx.x >> team_size_bits; i < max_i; i += (blockDim.x >> team_size_bits)) {
const bool valid_i = (i < num_pickup);

IndexT best_index_team_local;
Expand All @@ -109,15 +134,19 @@ RAFT_DEVICE_INLINE_FUNCTION void compute_distance_to_random_nodes(
}
}

auto norm2 = dataset_desc.compute_distance(seed_index, valid_i);
// This is the `dataset_desc.compute_distance` manually inlined to move the fetching of
// dataset_desc from smem out of the loop.
// const auto norm2 = dataset_desc.compute_distance(seed_index, valid_i);
const auto norm2 =
device::team_sum(valid_i ? compute_distance(args, seed_index) : 0, team_size_bits);

if (valid_i && (norm2 < best_norm2_team_local)) {
best_norm2_team_local = norm2;
best_index_team_local = seed_index;
}
}

const unsigned lane_id = threadIdx.x % team_size;
const unsigned lane_id = threadIdx.x & ((1u << team_size_bits) - 1u);
if (valid_i && lane_id == 0) {
if (hashmap::insert(visited_hash_ptr, hash_bitlen, best_index_team_local)) {
result_distances_ptr[i] = best_norm2_team_local;
Expand Down Expand Up @@ -168,18 +197,25 @@ RAFT_DEVICE_INLINE_FUNCTION void compute_distance_to_child_nodes(
__syncthreads();

// Compute the distance to child nodes
const auto team_size = dataset_desc.team_size;
const auto max_i = raft::round_up_safe(knn_k * search_width, warp_size / team_size);
for (uint32_t tid = threadIdx.x; tid < max_i * team_size; tid += blockDim.x) {
const auto i = tid / team_size;
const auto team_size_bits = dataset_desc.team_size_bitshift_from_smem();
const auto max_i = raft::round_up_safe(knn_k * search_width, warp_size >> team_size_bits)
<< team_size_bits;
const auto compute_distance = dataset_desc.compute_distance_impl;
const auto args = dataset_desc.args.load();
for (uint32_t tid = threadIdx.x; tid < max_i; tid += blockDim.x) {
const auto i = tid >> team_size_bits;
const bool valid_i = (i < (knn_k * search_width));
IndexT child_id = invalid_index;
if (valid_i) { child_id = result_child_indices_ptr[i]; }

auto norm2 = dataset_desc.compute_distance(child_id, child_id != invalid_index);
// This is the `dataset_desc.compute_distance` manually inlined to move the fetching of
// dataset_desc from smem out of the loop.
// const auto norm2 = dataset_desc.compute_distance(child_id, child_id != invalid_index);
const auto norm2 = device::team_sum(
(child_id != invalid_index) ? compute_distance(args, child_id) : 0, team_size_bits);

// Store the distance
const unsigned lane_id = threadIdx.x % team_size;
const unsigned lane_id = threadIdx.x & ((1u << team_size_bits) - 1u);
if (valid_i && lane_id == 0) {
if (child_id != invalid_index) {
result_child_distances_ptr[i] = norm2;
Expand All @@ -190,29 +226,6 @@ RAFT_DEVICE_INLINE_FUNCTION void compute_distance_to_child_nodes(
}
}

template <uint32_t TeamSize, typename T>
RAFT_DEVICE_INLINE_FUNCTION auto team_sum(T x) -> T
{
#pragma unroll
for (uint32_t stride = TeamSize >> 1; stride > 0; stride >>= 1) {
x += raft::shfl_xor(x, stride, TeamSize);
}
return x;
}

template <typename T>
RAFT_DEVICE_INLINE_FUNCTION auto team_sum(T x, uint32_t team_size) -> T
{
switch (team_size) {
case 32: x += raft::shfl_xor(x, 16);
case 16: x += raft::shfl_xor(x, 8);
case 8: x += raft::shfl_xor(x, 4);
case 4: x += raft::shfl_xor(x, 2);
case 2: x += raft::shfl_xor(x, 1);
default: return x;
}
}

RAFT_DEVICE_INLINE_FUNCTION void lds(float& x, uint32_t addr)
{
asm volatile("ld.shared.f32 {%0}, [%1];" : "=f"(x) : "r"(addr));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ RAFT_KERNEL search_kernel(
dataset_desc = dataset_desc->setup_workspace(smem, queries_ptr, query_id);

auto result_indices_buffer =
reinterpret_cast<INDEX_T*>(smem + dataset_desc->smem_ws_size_in_bytes);
reinterpret_cast<INDEX_T*>(smem + dataset_desc->smem_ws_size_in_bytes());
auto result_distances_buffer =
reinterpret_cast<DISTANCE_T*>(result_indices_buffer + result_buffer_size_32);
auto parent_indices_buffer =
Expand Down
15 changes: 8 additions & 7 deletions cpp/src/neighbors/detail/cagra/search_multi_kernel.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,9 @@ RAFT_KERNEL random_pickup_kernel(
using INDEX_T = typename DATASET_DESCRIPTOR_T::INDEX_T;
using DISTANCE_T = typename DATASET_DESCRIPTOR_T::DISTANCE_T;

const auto team_size = dataset_desc->team_size;
const auto team_size_bits = dataset_desc->team_size_bitshift();
const auto ldb = hashmap::get_size(hash_bitlen);
const auto global_team_index = (blockIdx.x * blockDim.x + threadIdx.x) / team_size;
const auto global_team_index = (blockIdx.x * blockDim.x + threadIdx.x) >> team_size_bits;
const uint32_t query_id = blockIdx.y;
if (global_team_index >= num_pickup) { return; }
extern __shared__ uint8_t smem[];
Expand All @@ -140,7 +140,7 @@ RAFT_KERNEL random_pickup_kernel(
}

const auto store_gmem_index = global_team_index + (ldr * query_id);
if (threadIdx.x % team_size == 0) {
if ((threadIdx.x & ((1u << team_size_bits) - 1u)) == 0) {
if (hashmap::insert(
visited_hashmap_ptr + (ldb * query_id), hash_bitlen, best_index_team_local)) {
result_distances_ptr[store_gmem_index] = best_norm2_team_local;
Expand Down Expand Up @@ -316,10 +316,11 @@ RAFT_KERNEL compute_distance_to_child_nodes_kernel(
using INDEX_T = typename DATASET_DESCRIPTOR_T::INDEX_T;
using DISTANCE_T = typename DATASET_DESCRIPTOR_T::DISTANCE_T;

const auto team_size = dataset_desc->team_size;
const auto team_size_bits = dataset_desc->team_size_bitshift();
const auto team_size = 1u << team_size_bits;
const uint32_t ldb = hashmap::get_size(hash_bitlen);
const auto tid = threadIdx.x + blockDim.x * blockIdx.x;
const auto global_team_id = tid / team_size;
const auto global_team_id = tid >> team_size_bits;
const auto query_id = blockIdx.y;

extern __shared__ uint8_t smem[];
Expand Down Expand Up @@ -353,12 +354,12 @@ RAFT_KERNEL compute_distance_to_child_nodes_kernel(
DISTANCE_T norm2 = dataset_desc->compute_distance(child_id, compute_distance_flag);

if (compute_distance_flag) {
if (threadIdx.x % team_size == 0) {
if ((threadIdx.x & (team_size - 1)) == 0) {
result_indices_ptr[ldd * blockIdx.y + global_team_id] = child_id;
result_distances_ptr[ldd * blockIdx.y + global_team_id] = norm2;
}
} else {
if (threadIdx.x % team_size == 0) {
if ((threadIdx.x & (team_size - 1)) == 0) {
result_distances_ptr[ldd * blockIdx.y + global_team_id] = utils::get_max_value<DISTANCE_T>();
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -529,7 +529,7 @@ RAFT_KERNEL search_kernel(
dataset_desc = dataset_desc->setup_workspace(smem, queries_ptr, query_id);

auto result_indices_buffer =
reinterpret_cast<INDEX_T*>(smem + dataset_desc->smem_ws_size_in_bytes);
reinterpret_cast<INDEX_T*>(smem + dataset_desc->smem_ws_size_in_bytes());
auto result_distances_buffer =
reinterpret_cast<DISTANCE_T*>(result_indices_buffer + result_buffer_size_32);
auto visited_hash_buffer =
Expand Down

0 comments on commit 6630a99

Please sign in to comment.