From 3db9f6862c94d60a1c7adf83ebbeb57e59d4587e Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Tue, 10 Sep 2024 17:56:09 -0700 Subject: [PATCH] add dset norms to lists --- cpp/src/neighbors/ivf_common.cuh | 14 +++--- cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh | 45 ++++++++++++++++++- .../ivf_pq/ivf_pq_compute_similarity_impl.cuh | 20 +-------- cpp/src/neighbors/ivf_pq/ivf_pq_search.cuh | 13 ++++++ cpp/test/neighbors/ann_ivf_pq.cuh | 4 ++ 5 files changed, 72 insertions(+), 24 deletions(-) diff --git a/cpp/src/neighbors/ivf_common.cuh b/cpp/src/neighbors/ivf_common.cuh index c8f5c3354..20a90c105 100644 --- a/cpp/src/neighbors/ivf_common.cuh +++ b/cpp/src/neighbors/ivf_common.cuh @@ -21,6 +21,7 @@ #include // matrix::detail::select::warpsort::warp_sort_distributed #include +#include namespace cuvs::neighbors::ivf::detail { @@ -219,6 +220,9 @@ void postprocess_distances(ScoreOutT* out, // [n_queries, topk] bool account_for_max_close, rmm::cuda_stream_view stream) { + std::cout << "inside ivf_common post_process_distances" << std::endl; + std::cout << "scaling factor" << scaling_factor << std::endl; + raft::print_device_vector("input_distances", in, 100, std::cout); constexpr bool needs_cast = !std::is_same::value; const bool needs_copy = ((void*)in) != ((void*)out); size_t len = size_t(n_queries) * size_t(topk); @@ -254,7 +258,7 @@ void postprocess_distances(ScoreOutT* out, // [n_queries, topk] raft::linalg::unaryOp(out, in, len, raft::sqrt_op{}, stream); } } break; - case distance::DistanceType::CosineExpanded: + // case distance::DistanceType::CosineExpanded: case distance::DistanceType::InnerProduct: { float factor = (account_for_max_close ? -1.0 : 1.0) * scaling_factor * scaling_factor; if (factor != 1.0) { @@ -269,16 +273,16 @@ void postprocess_distances(ScoreOutT* out, // [n_queries, topk] } } break; case distance::DistanceType::CosineExpanded: { - if (needs_cast) { + // if (needs_cast) { raft::linalg::unaryOp( out, in, len, raft::compose_op(raft::add_const_op{1.0}, raft::cast_op{}), stream); - } else { - raft::linalg::unaryOp(out, in, len, raft::add_const_op{1.0}, stream); - } + // } else { + // raft::linalg::unaryOp(out, in, len, raft::add_const_op{1.0}, stream); + // } } break; default: RAFT_FAIL("Unexpected metric."); } diff --git a/cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh b/cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh index d5fc735e9..4f8840d84 100644 --- a/cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh +++ b/cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh @@ -28,6 +28,7 @@ // TODO (cjnolet): This should be using an exposed API instead of circumventing the public APIs. #include "../../cluster/kmeans_balanced.cuh" +#include "raft/linalg/norm_types.hpp" #include #include @@ -1062,6 +1063,7 @@ struct encode_vectors { uint32_t cluster_ix; raft::device_mdspan, raft::row_major> pq_centers; raft::device_mdspan, raft::row_major> in_vectors; + raft::device_vector_view cluster_center; __device__ inline encode_vectors( raft::device_mdspan, raft::row_major> pq_centers, @@ -1122,6 +1124,29 @@ struct encode_vectors { } return code; } + + __device__ inline auto operator()(uint8_t code, uint32_t j) -> float + { + float res = 0; + const uint32_t pq_book_size = pq_centers.extent(2); + const uint32_t pq_len = pq_centers.extent(1); + uint32_t partition_ix; + switch (codebook_kind) { + case codebook_gen::PER_CLUSTER: { + partition_ix = cluster_ix; + } break; + case codebook_gen::PER_SUBSPACE: { + partition_ix = j; + } break; + default: __builtin_unreachable(); + } + for (uint32_t k = 0; k < pq_len; k++) { + // NB: the L2 quantifiers on residuals are always trained on L2 metric. + res = res + cluster_center[threadIdx.x * pq_len + k] + + pq_centers(partition_ix, k, uint32_t(code)); + } + return res; + } }; template @@ -1133,7 +1158,8 @@ __launch_bounds__(BlockSize) static __global__ void process_and_fill_codes_kerne raft::device_vector_view inds_ptrs, raft::device_vector_view data_ptrs, raft::device_mdspan, raft::row_major> pq_centers, - codebook_gen codebook_kind) + codebook_gen codebook_kind, + std::optional> dataset_norms) { constexpr uint32_t kSubWarpSize = std::min(raft::WarpSize, 1u << PqBits); using subwarp_align = raft::Pow2; @@ -1154,6 +1180,10 @@ __launch_bounds__(BlockSize) static __global__ void process_and_fill_codes_kerne } else { pq_indices[out_ix] = std::get(src_offset_or_indices)[row_ix]; } + if (dataset_norms.has_value()) { + auto norms = dataset_norms.value()(cluster_ix); + norms[out_ix] = std::get(src_offset_or_indices) + row_ix; + } } // write the codes (one record per subwarp): @@ -1577,6 +1607,16 @@ void extend(raft::resources const& handle, cluster_centers.data(), n_clusters, index->dim()); cuvs::cluster::kmeans::balanced_params kmeans_params; kmeans_params.metric = static_cast((int)index->metric()); + if (index->metric() == cuvs::distance::DistanceType::CosineExpanded) { + std::optional> X_norm = std::nullopt; + raft::linalg::rowNorm(dataset_norms, + batch_data_view.data_handle(), + index->dim(), + batch.size(), + raft::linalg::NormType::L2Norm, + true, + stream); + } cuvs::cluster::kmeans_balanced::predict(handle, kmeans_params, batch_data_view, @@ -1641,6 +1681,7 @@ void extend(raft::resources const& handle, new_data_labels.data() + vec_batch.offset(), IdxT(vec_batch.size()), batches_mr); + vec_batches.prefetch_next_batch(); // User needs to make sure kernel finishes its work before we overwrite batch in the next // iteration if different streams are used for kernel and copy. @@ -1804,6 +1845,8 @@ auto build(raft::resources const& handle, if (params.add_data_on_build) { detail::extend(handle, &index, dataset.data_handle(), nullptr, n_rows); } + + if (index.metric() == distance::DistanceType::CosineExpanded) { compute_c_pq_c_norms(); } return index; } diff --git a/cpp/src/neighbors/ivf_pq/ivf_pq_compute_similarity_impl.cuh b/cpp/src/neighbors/ivf_pq/ivf_pq_compute_similarity_impl.cuh index f5ea9d906..cbffe99ac 100644 --- a/cpp/src/neighbors/ivf_pq/ivf_pq_compute_similarity_impl.cuh +++ b/cpp/src/neighbors/ivf_pq/ivf_pq_compute_similarity_impl.cuh @@ -285,6 +285,7 @@ RAFT_KERNEL compute_similarity_kernel(uint32_t dim, const uint8_t* const* pq_dataset, const uint32_t* cluster_labels, const uint32_t* _chunk_indices, + const float* dataset_norms, const float* queries, const uint32_t* index_list, float* query_kths, @@ -420,6 +421,7 @@ RAFT_KERNEL compute_similarity_kernel(uint32_t dim, diff -= pq_c; score += diff * diff; } break; + case distance::DistanceType::CosineExpanded: case distance::DistanceType::InnerProduct: { // NB: we negate the scores as we hardcoded select-topk to always compute the minimum float q; @@ -433,28 +435,10 @@ RAFT_KERNEL compute_similarity_kernel(uint32_t dim, } score -= q * pq_c; } break; - case distance::DistanceType::CosineExpanded: { - // NB: we negate the scores as we hardcoded select-topk to always compute the minimum - float q, c, quantized; - if constexpr (PrecompBaseDiff) { - float2 pvals = reinterpret_cast(lut_end)[j]; - q = pvals.x; - c = pvals.y; - } else { - q = query[j]; - c = cluster_center[j]; - } - quantized = c + pq_c; - score -= q * quantized; - q_norm += q * q; - c_norm += quantized * quantized; - } break; default: __builtin_unreachable(); } } while (++j < j_end); lut_scores[i] = LutT(score); - if (metric == distance::DistanceType::CosineExpanded) - lut_scores[i] = LutT(score / sqrt(q_norm * c_norm)); } } diff --git a/cpp/src/neighbors/ivf_pq/ivf_pq_search.cuh b/cpp/src/neighbors/ivf_pq/ivf_pq_search.cuh index c95179a7c..287a89a7c 100644 --- a/cpp/src/neighbors/ivf_pq/ivf_pq_search.cuh +++ b/cpp/src/neighbors/ivf_pq/ivf_pq_search.cuh @@ -40,6 +40,7 @@ #include #include #include +#include #include #include #include @@ -466,6 +467,7 @@ void ivfpq_search_worker(raft::resources const& handle, distances_buf.data(), neighbors_ptr); + raft::print_device_vector("distances_buf", distances_buf.data(), 100, std::cout); // Select topk vectors for each query rmm::device_uvector topk_dists(n_queries * topK, stream, mr); @@ -486,7 +488,10 @@ void ivfpq_search_worker(raft::resources const& handle, cuvs::selection::SelectAlgo::kAuto, num_samples_vector); + raft::print_device_vector("topk_dists.data()", topk_dists.data(), 100, std::cout); + // Postprocessing + std::cout << "index.metric" << index.metric() << std::endl; ivf::detail::postprocess_distances( distances, topk_dists.data(), index.metric(), n_queries, topK, scaling_factor, true, stream); ivf::detail::postprocess_neighbors(neighbors, @@ -730,6 +735,14 @@ inline void search(raft::resources const& handle, index.rot_dim(), stream); + if (index.metric() == cuvs::distance::DistanceType::CosineExpanded) { + raft::linalg::row_normalize( + handle, + raft::make_device_matrix_view(rot_queries.data(), n_queries, index.rot_dim()), + raft::make_device_matrix_view(rot_queries.data(), n_queries, index.rot_dim()), + raft::linalg::NormType::L2Norm); + } + for (uint32_t offset_b = 0; offset_b < queries_batch; offset_b += max_batch_size) { uint32_t batch_size = min(max_batch_size, queries_batch - offset_b); /* The distance calculation is done in the rotated/transformed space; diff --git a/cpp/test/neighbors/ann_ivf_pq.cuh b/cpp/test/neighbors/ann_ivf_pq.cuh index df53dd514..7dfcbc14b 100644 --- a/cpp/test/neighbors/ann_ivf_pq.cuh +++ b/cpp/test/neighbors/ann_ivf_pq.cuh @@ -25,6 +25,7 @@ #include #include #include +#include #include #include #include @@ -642,6 +643,9 @@ class ivf_pq_filter_test : public ::testing::TestWithParam { raft::update_host(indices_ivf_pq.data(), indices_ivf_pq_dev.data(), queries_size, stream_); raft::resource::sync_stream(handle_); + raft::print_host_vector("indices_ivf_pq.data()", indices_ivf_pq.data(), 100, std::cout); + raft::print_host_vector("distances_ivf_pq.data()", distances_ivf_pq.data(), 100, std::cout); + // A very conservative lower bound on recall double min_recall = static_cast(ps.search_params.n_probes) / static_cast(ps.index_params.n_lists);