Skip to content

Commit

Permalink
Reduce the register usage in compute_distance_standard further
Browse files Browse the repository at this point in the history
  • Loading branch information
achirkin committed Sep 3, 2024
1 parent 790e79c commit 7599331
Showing 1 changed file with 24 additions and 16 deletions.
40 changes: 24 additions & 16 deletions cpp/src/neighbors/detail/cagra/compute_distance_standard.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -175,45 +175,42 @@ _RAFT_DEVICE __noinline__ auto setup_workspace_standard(
}

template <typename DescriptorT>
_RAFT_DEVICE __noinline__ auto compute_distance_standard(
const typename DescriptorT::args_t args, const typename DescriptorT::INDEX_T dataset_index) ->
typename DescriptorT::DISTANCE_T
RAFT_DEVICE_INLINE_FUNCTION auto compute_distance_standard_worker(
const typename DescriptorT::DATA_T* __restrict__ dataset_ptr,
uint32_t dim,
uint32_t query_smem_ptr) -> typename DescriptorT::DISTANCE_T
{
using DATA_T = typename DescriptorT::DATA_T;
using DISTANCE_T = typename DescriptorT::DISTANCE_T;
using LOAD_T = typename DescriptorT::LOAD_T;
using QUERY_T = typename DescriptorT::QUERY_T;
constexpr auto kTeamSize = DescriptorT::kTeamSize;
constexpr auto kDatasetBlockDim = DescriptorT::kDatasetBlockDim;

// const auto* __restrict__ query_ptr = reinterpret_cast<const QUERY_T*>(args.smem_ws_ptr);
const auto* __restrict__ dataset_ptr =
DescriptorT::ptr(args) + (static_cast<std::uint64_t>(DescriptorT::ld(args)) * dataset_index);
const auto lane_id = threadIdx.x % kTeamSize;
constexpr auto vlen = device::get_vlen<LOAD_T, DATA_T>();
constexpr auto reg_nelem = raft::ceildiv<uint32_t>(kDatasetBlockDim, kTeamSize * vlen);

DISTANCE_T r = 0;
for (uint32_t elem_offset = 0; elem_offset < args.dim; elem_offset += kDatasetBlockDim) {
constexpr unsigned vlen = device::get_vlen<LOAD_T, DATA_T>();
constexpr unsigned reg_nelem = raft::ceildiv<unsigned>(kDatasetBlockDim, kTeamSize * vlen);
for (uint32_t elem_offset = (threadIdx.x % kTeamSize) * vlen; elem_offset < dim;
elem_offset += kDatasetBlockDim) {
raft::TxN_t<DATA_T, vlen> dl_buff[reg_nelem];
#pragma unroll
for (uint32_t e = 0; e < reg_nelem; e++) {
const uint32_t k = (lane_id + (kTeamSize * e)) * vlen + elem_offset;
if (k >= args.dim) break;
const uint32_t k = e * (kTeamSize * vlen) + elem_offset;
if (k >= dim) break;
dl_buff[e].load(dataset_ptr, k);
}
#pragma unroll
for (uint32_t e = 0; e < reg_nelem; e++) {
const uint32_t k = (lane_id + (kTeamSize * e)) * vlen + elem_offset;
if (k >= args.dim) break;
const uint32_t k = e * (kTeamSize * vlen) + elem_offset;
if (k >= dim) break;
#pragma unroll
for (uint32_t v = 0; v < vlen; v++) {
// Note this loop can go above the dataset_dim for padded arrays. This is not a problem
// because:
// - Above the last element (dataset_dim-1), the query array is filled with zeros.
// - The data buffer has to be also padded with zeros.
DISTANCE_T d;
device::lds(d, args.smem_ws_ptr + sizeof(QUERY_T) * device::swizzling(k + v));
device::lds(d, query_smem_ptr + sizeof(QUERY_T) * device::swizzling(k + v));
r += dist_op<DISTANCE_T, DescriptorT::kMetric>(
d, cuvs::spatial::knn::detail::utils::mapping<DISTANCE_T>{}(dl_buff[e].val.data[v]));
}
Expand All @@ -222,6 +219,17 @@ _RAFT_DEVICE __noinline__ auto compute_distance_standard(
return r;
}

template <typename DescriptorT>
_RAFT_DEVICE __noinline__ auto compute_distance_standard(
const typename DescriptorT::args_t args, const typename DescriptorT::INDEX_T dataset_index) ->
typename DescriptorT::DISTANCE_T
{
return compute_distance_standard_worker<DescriptorT>(
DescriptorT::ptr(args) + (static_cast<std::uint64_t>(DescriptorT::ld(args)) * dataset_index),
args.dim,
args.smem_ws_ptr);
}

template <cuvs::distance::DistanceType Metric,
uint32_t TeamSize,
uint32_t DatasetBlockDim,
Expand Down

0 comments on commit 7599331

Please sign in to comment.