Skip to content

Commit

Permalink
add dset norms to lists
Browse files Browse the repository at this point in the history
  • Loading branch information
tarang-jain committed Sep 11, 2024
1 parent 259708a commit 3db9f68
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 24 deletions.
14 changes: 9 additions & 5 deletions cpp/src/neighbors/ivf_common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include <raft/matrix/detail/select_warpsort.cuh> // matrix::detail::select::warpsort::warp_sort_distributed

#include <cub/cub.cuh>
#include <raft/util/cudart_utils.hpp>

namespace cuvs::neighbors::ivf::detail {

Expand Down Expand Up @@ -219,6 +220,9 @@ void postprocess_distances(ScoreOutT* out, // [n_queries, topk]
bool account_for_max_close,
rmm::cuda_stream_view stream)
{
std::cout << "inside ivf_common post_process_distances" << std::endl;
std::cout << "scaling factor" << scaling_factor << std::endl;
raft::print_device_vector("input_distances", in, 100, std::cout);
constexpr bool needs_cast = !std::is_same<ScoreInT, ScoreOutT>::value;
const bool needs_copy = ((void*)in) != ((void*)out);
size_t len = size_t(n_queries) * size_t(topk);
Expand Down Expand Up @@ -254,7 +258,7 @@ void postprocess_distances(ScoreOutT* out, // [n_queries, topk]
raft::linalg::unaryOp(out, in, len, raft::sqrt_op{}, stream);
}
} break;
case distance::DistanceType::CosineExpanded:
// case distance::DistanceType::CosineExpanded:
case distance::DistanceType::InnerProduct: {
float factor = (account_for_max_close ? -1.0 : 1.0) * scaling_factor * scaling_factor;
if (factor != 1.0) {
Expand All @@ -269,16 +273,16 @@ void postprocess_distances(ScoreOutT* out, // [n_queries, topk]
}
} break;
case distance::DistanceType::CosineExpanded: {
if (needs_cast) {
// if (needs_cast) {
raft::linalg::unaryOp(
out,
in,
len,
raft::compose_op(raft::add_const_op<ScoreOutT>{1.0}, raft::cast_op<ScoreOutT>{}),
stream);
} else {
raft::linalg::unaryOp(out, in, len, raft::add_const_op<ScoreInT>{1.0}, stream);
}
// } else {
// raft::linalg::unaryOp(out, in, len, raft::add_const_op<ScoreInT>{1.0}, stream);
// }
} break;
default: RAFT_FAIL("Unexpected metric.");
}
Expand Down
45 changes: 44 additions & 1 deletion cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

// TODO (cjnolet): This should be using an exposed API instead of circumventing the public APIs.
#include "../../cluster/kmeans_balanced.cuh"
#include "raft/linalg/norm_types.hpp"

#include <raft/core/device_mdarray.hpp>
#include <raft/core/logger-ext.hpp>
Expand Down Expand Up @@ -1062,6 +1063,7 @@ struct encode_vectors {
uint32_t cluster_ix;
raft::device_mdspan<const float, raft::extent_3d<uint32_t>, raft::row_major> pq_centers;
raft::device_mdspan<const float, raft::extent_3d<IdxT>, raft::row_major> in_vectors;
raft::device_vector_view<const float, uint32_t> cluster_center;

__device__ inline encode_vectors(
raft::device_mdspan<const float, raft::extent_3d<uint32_t>, raft::row_major> pq_centers,
Expand Down Expand Up @@ -1122,6 +1124,29 @@ struct encode_vectors {
}
return code;
}

/**
 * Sum the components of a reconstructed sub-vector for one PQ code.
 *
 * For every component `k` in `[0, pq_len)` the value
 * `cluster_center[threadIdx.x * pq_len + k] + pq_centers(partition_ix, k, code)`
 * is added to the running result, where `pq_len` is `pq_centers.extent(1)`.
 *
 * @param code index of the PQ codebook entry (third index into `pq_centers`).
 * @param j    subspace index; used as the codebook partition when
 *             `codebook_kind` is `PER_SUBSPACE`.
 * @return the accumulated sum over all `pq_len` components.
 *
 * NOTE(review): the offset into `cluster_center` is derived from `threadIdx.x`
 * rather than from the subspace index `j`; this looks unintended for a
 * per-subspace slice of the center - confirm against the caller before relying
 * on the result.
 * NOTE(review): the result is a plain sum of components, not a sum of squares;
 * if a norm is intended, the accumulation needs to be revisited - TODO confirm.
 */
__device__ inline auto operator()(uint8_t code, uint32_t j) -> float
{
  const uint32_t pq_len = pq_centers.extent(1);

  // Select which codebook partition the code refers to.
  uint32_t partition_ix;
  switch (codebook_kind) {
    case codebook_gen::PER_CLUSTER: {
      // One codebook per cluster: partition is the current cluster.
      partition_ix = cluster_ix;
    } break;
    case codebook_gen::PER_SUBSPACE: {
      // One codebook per subspace: partition is the subspace index.
      partition_ix = j;
    } break;
    default: __builtin_unreachable();
  }

  float res = 0.0f;
  for (uint32_t k = 0; k < pq_len; k++) {
    // Keep the original left-to-right evaluation order so the floating-point
    // result is unchanged: (res + center) + pq_center.
    res = res + cluster_center[threadIdx.x * pq_len + k] +
          pq_centers(partition_ix, k, uint32_t(code));
  }
  return res;
}
};

template <uint32_t BlockSize, uint32_t PqBits, typename IdxT>
Expand All @@ -1133,7 +1158,8 @@ __launch_bounds__(BlockSize) static __global__ void process_and_fill_codes_kerne
raft::device_vector_view<IdxT*, uint32_t, raft::row_major> inds_ptrs,
raft::device_vector_view<uint8_t*, uint32_t, raft::row_major> data_ptrs,
raft::device_mdspan<const float, raft::extent_3d<uint32_t>, raft::row_major> pq_centers,
codebook_gen codebook_kind)
codebook_gen codebook_kind,
std::optional<raft::device_vector_view<float, uint32_t, raft::row_major>> dataset_norms)
{
constexpr uint32_t kSubWarpSize = std::min<uint32_t>(raft::WarpSize, 1u << PqBits);
using subwarp_align = raft::Pow2<kSubWarpSize>;
Expand All @@ -1154,6 +1180,10 @@ __launch_bounds__(BlockSize) static __global__ void process_and_fill_codes_kerne
} else {
pq_indices[out_ix] = std::get<const IdxT*>(src_offset_or_indices)[row_ix];
}
if (dataset_norms.has_value()) {
auto norms = dataset_norms.value()(cluster_ix);
norms[out_ix] = std::get<IdxT>(src_offset_or_indices) + row_ix;
}
}

// write the codes (one record per subwarp):
Expand Down Expand Up @@ -1577,6 +1607,16 @@ void extend(raft::resources const& handle,
cluster_centers.data(), n_clusters, index->dim());
cuvs::cluster::kmeans::balanced_params kmeans_params;
kmeans_params.metric = static_cast<cuvs::distance::DistanceType>((int)index->metric());
if (index->metric() == cuvs::distance::DistanceType::CosineExpanded) {
std::optional<raft::device_vector_view<const MathT, IndexT>> X_norm = std::nullopt;
raft::linalg::rowNorm(dataset_norms,
batch_data_view.data_handle(),
index->dim(),
batch.size(),
raft::linalg::NormType::L2Norm,
true,
stream);
}
cuvs::cluster::kmeans_balanced::predict(handle,
kmeans_params,
batch_data_view,
Expand Down Expand Up @@ -1641,6 +1681,7 @@ void extend(raft::resources const& handle,
new_data_labels.data() + vec_batch.offset(),
IdxT(vec_batch.size()),
batches_mr);

vec_batches.prefetch_next_batch();
// User needs to make sure kernel finishes its work before we overwrite batch in the next
// iteration if different streams are used for kernel and copy.
Expand Down Expand Up @@ -1804,6 +1845,8 @@ auto build(raft::resources const& handle,
if (params.add_data_on_build) {
detail::extend<T, IdxT>(handle, &index, dataset.data_handle(), nullptr, n_rows);
}

if (index.metric() == distance::DistanceType::CosineExpanded) { compute_c_pq_c_norms(); }
return index;
}

Expand Down
20 changes: 2 additions & 18 deletions cpp/src/neighbors/ivf_pq/ivf_pq_compute_similarity_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,7 @@ RAFT_KERNEL compute_similarity_kernel(uint32_t dim,
const uint8_t* const* pq_dataset,
const uint32_t* cluster_labels,
const uint32_t* _chunk_indices,
const float* dataset_norms,
const float* queries,
const uint32_t* index_list,
float* query_kths,
Expand Down Expand Up @@ -420,6 +421,7 @@ RAFT_KERNEL compute_similarity_kernel(uint32_t dim,
diff -= pq_c;
score += diff * diff;
} break;
case distance::DistanceType::CosineExpanded:
case distance::DistanceType::InnerProduct: {
// NB: we negate the scores as we hardcoded select-topk to always compute the minimum
float q;
Expand All @@ -433,28 +435,10 @@ RAFT_KERNEL compute_similarity_kernel(uint32_t dim,
}
score -= q * pq_c;
} break;
case distance::DistanceType::CosineExpanded: {
// NB: we negate the scores as we hardcoded select-topk to always compute the minimum
float q, c, quantized;
if constexpr (PrecompBaseDiff) {
float2 pvals = reinterpret_cast<float2*>(lut_end)[j];
q = pvals.x;
c = pvals.y;
} else {
q = query[j];
c = cluster_center[j];
}
quantized = c + pq_c;
score -= q * quantized;
q_norm += q * q;
c_norm += quantized * quantized;
} break;
default: __builtin_unreachable();
}
} while (++j < j_end);
lut_scores[i] = LutT(score);
if (metric == distance::DistanceType::CosineExpanded)
lut_scores[i] = LutT(score / sqrt(q_norm * c_norm));
}
}

Expand Down
13 changes: 13 additions & 0 deletions cpp/src/neighbors/ivf_pq/ivf_pq_search.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
#include <raft/linalg/matrix_vector_op.cuh>
#include <raft/linalg/norm.cuh>
#include <raft/linalg/norm_types.hpp>
#include <raft/linalg/normalize.cuh>
#include <raft/linalg/unary_op.cuh>
#include <raft/matrix/detail/select_warpsort.cuh>
#include <raft/util/cache.hpp>
Expand Down Expand Up @@ -466,6 +467,7 @@ void ivfpq_search_worker(raft::resources const& handle,
distances_buf.data(),
neighbors_ptr);

raft::print_device_vector("distances_buf", distances_buf.data(), 100, std::cout);
// Select topk vectors for each query
rmm::device_uvector<ScoreT> topk_dists(n_queries * topK, stream, mr);

Expand All @@ -486,7 +488,10 @@ void ivfpq_search_worker(raft::resources const& handle,
cuvs::selection::SelectAlgo::kAuto,
num_samples_vector);

raft::print_device_vector("topk_dists.data()", topk_dists.data(), 100, std::cout);

// Postprocessing
std::cout << "index.metric" << index.metric() << std::endl;
ivf::detail::postprocess_distances(
distances, topk_dists.data(), index.metric(), n_queries, topK, scaling_factor, true, stream);
ivf::detail::postprocess_neighbors(neighbors,
Expand Down Expand Up @@ -730,6 +735,14 @@ inline void search(raft::resources const& handle,
index.rot_dim(),
stream);

if (index.metric() == cuvs::distance::DistanceType::CosineExpanded) {
raft::linalg::row_normalize(
handle,
raft::make_device_matrix_view<const float>(rot_queries.data(), n_queries, index.rot_dim()),
raft::make_device_matrix_view<float>(rot_queries.data(), n_queries, index.rot_dim()),
raft::linalg::NormType::L2Norm);
}

for (uint32_t offset_b = 0; offset_b < queries_batch; offset_b += max_batch_size) {
uint32_t batch_size = min(max_batch_size, queries_batch - offset_b);
/* The distance calculation is done in the rotated/transformed space;
Expand Down
4 changes: 4 additions & 0 deletions cpp/test/neighbors/ann_ivf_pq.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include <raft/core/resource/cuda_stream_pool.hpp>
#include <raft/linalg/add.cuh>
#include <raft/matrix/gather.cuh>
#include <raft/util/cudart_utils.hpp>
#include <rmm/cuda_stream_pool.hpp>
#include <rmm/mr/device/managed_memory_resource.hpp>
#include <thrust/sequence.h>
Expand Down Expand Up @@ -642,6 +643,9 @@ class ivf_pq_filter_test : public ::testing::TestWithParam<ivf_pq_inputs> {
raft::update_host(indices_ivf_pq.data(), indices_ivf_pq_dev.data(), queries_size, stream_);
raft::resource::sync_stream(handle_);

raft::print_host_vector("indices_ivf_pq.data()", indices_ivf_pq.data(), 100, std::cout);
raft::print_host_vector("distances_ivf_pq.data()", distances_ivf_pq.data(), 100, std::cout);

// A very conservative lower bound on recall
double min_recall =
static_cast<double>(ps.search_params.n_probes) / static_cast<double>(ps.index_params.n_lists);
Expand Down

0 comments on commit 3db9f68

Please sign in to comment.