Skip to content

Commit

Permalink
compute exact norms
Browse files Browse the repository at this point in the history
  • Loading branch information
tarang-jain committed Sep 12, 2024
1 parent 6f02147 commit a09157e
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 17 deletions.
2 changes: 2 additions & 0 deletions cpp/include/cuvs/neighbors/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -530,6 +530,8 @@ struct list {
raft::device_mdarray<index_type, raft::extent_1d<size_type>, raft::row_major> indices;
/** The actual size of the content. */
std::atomic<size_type> size;
/** Dataset norms. */
raft::device_mdarray<float, list_extents, raft::row_major> norms;

/** Allocate a new list capable of holding at least `n_rows` data records and indices. */
list(raft::resources const& res, const spec_type& spec, size_type n_rows);
Expand Down
6 changes: 6 additions & 0 deletions cpp/include/cuvs/neighbors/ivf_pq.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -446,6 +446,10 @@ struct index : cuvs::neighbors::index {
raft::device_matrix_view<float, uint32_t, raft::row_major> centers_rot() noexcept;
raft::device_matrix_view<const float, uint32_t, raft::row_major> centers_rot() const noexcept;

/** Pointers to the inverted lists (clusters) true norms [n_lists] */
std::optional<raft::device_vector_view<float*, uint32_t, raft::row_major>> data_norms_ptrs() noexcept;
std::optional<raft::device_vector_view<const float* const, uint32_t, raft::row_major>> data_norms_ptrs() const noexcept;

/** fetch size of a particular IVF list in bytes using the list extents.
* Usage example:
* @code{.cpp}
Expand Down Expand Up @@ -479,11 +483,13 @@ struct index : cuvs::neighbors::index {
raft::device_matrix<float, uint32_t, raft::row_major> centers_;
raft::device_matrix<float, uint32_t, raft::row_major> centers_rot_;
raft::device_matrix<float, uint32_t, raft::row_major> rotation_matrix_;
std::optional<raft::device_vector<float, uint32_t>> center_norms_;

// Computed members for accelerating search.
raft::device_vector<uint8_t*, uint32_t, raft::row_major> data_ptrs_;
raft::device_vector<IdxT*, uint32_t, raft::row_major> inds_ptrs_;
raft::host_vector<IdxT, uint32_t, raft::row_major> accum_sorted_sizes_;
std::optional<raft::device_vector<float*, uint32_t, raft::row_major>> dataset_norms_{std::nullopt};

/** Throw an error if the index content is inconsistent. */
void check_consistency();
Expand Down
1 change: 1 addition & 0 deletions cpp/src/neighbors/ivf_common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,7 @@ void recompute_internal_state(const raft::resources& res, Index& index)
auto& list = index.lists()[label];
const auto data_ptr = list ? list->data.data_handle() : nullptr;
const auto inds_ptr = list ? list->indices.data_handle() : nullptr;
const auto norms_ptr = list ?
raft::copy(&data_ptrs(label), &data_ptr, 1, stream);
raft::copy(&inds_ptrs(label), &inds_ptr, 1, stream);
}
Expand Down
22 changes: 17 additions & 5 deletions cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,7 @@ void set_centers(raft::resources const& handle, index<IdxT>* index, const float*
raft::linalg::L2Norm,
true,
stream);
if (index->metric == cuvs::distance::DistanceType::CosineExpanded)
RAFT_CUDA_TRY(cudaMemcpy2DAsync(index->centers().data_handle() + index->dim(),
sizeof(float) * index->dim_ext(),
center_norms.data(),
Expand Down Expand Up @@ -1159,6 +1160,7 @@ __launch_bounds__(BlockSize) static __global__ void process_and_fill_codes_kerne
raft::device_vector_view<uint8_t*, uint32_t, raft::row_major> data_ptrs,
raft::device_mdspan<const float, raft::extent_3d<uint32_t>, raft::row_major> pq_centers,
codebook_gen codebook_kind,
const float* src_norms,
std::optional<raft::device_vector_view<float, uint32_t, raft::row_major>> dataset_norms)
{
constexpr uint32_t kSubWarpSize = std::min<uint32_t>(raft::WarpSize, 1u << PqBits);
Expand All @@ -1181,8 +1183,8 @@ __launch_bounds__(BlockSize) static __global__ void process_and_fill_codes_kerne
pq_indices[out_ix] = std::get<const IdxT*>(src_offset_or_indices)[row_ix];
}
if (dataset_norms.has_value()) {
auto norms = dataset_norms.value()(cluster_ix);
norms[out_ix] = std::get<IdxT>(src_offset_or_indices) + row_ix;
auto norms = (dataset_norms.value())(cluster_ix);
norms[out_ix] = src_norms[row_ix];
}
}

Expand Down Expand Up @@ -1298,6 +1300,7 @@ void process_and_fill_codes(raft::resources const& handle,
const T* new_vectors,
std::variant<IdxT, const IdxT*> src_offset_or_indices,
const uint32_t* new_labels,
const float* new_norms,
IdxT n_rows,
rmm::device_async_resource_ref mr)
{
Expand Down Expand Up @@ -1335,7 +1338,9 @@ void process_and_fill_codes(raft::resources const& handle,
index.inds_ptrs(),
index.data_ptrs(),
index.pq_centers(),
index.codebook_kind());
index.codebook_kind(),
new_norms,
index.dataset_norms());
RAFT_CUDA_TRY(cudaPeekAtLastError());
}

Expand Down Expand Up @@ -1577,6 +1582,10 @@ void extend(raft::resources const& handle,
copy_stream = raft::resource::get_stream_from_stream_pool(handle);
}
}

if (index->metric() == cuvs::distance::DistanceType::CosineExpanded) {
auto dataset_norms = raft::make_device_vector<float>(handle, n_rows);
}
// Predict the cluster labels for the new data, in batches if necessary
utils::batch_load_iterator<T> vec_batches(
new_vectors, n_rows, index->dim(), max_batch_size, copy_stream, device_memory, enable_prefetch);
Expand Down Expand Up @@ -1607,9 +1616,9 @@ void extend(raft::resources const& handle,
cluster_centers.data(), n_clusters, index->dim());
cuvs::cluster::kmeans::balanced_params kmeans_params;
kmeans_params.metric = static_cast<cuvs::distance::DistanceType>((int)index->metric());
std::optional<raft::device_vector_view<const float, internal_extents_t>> X_norm = std::nullopt;
if (index->metric() == cuvs::distance::DistanceType::CosineExpanded) {
std::optional<raft::device_vector_view<const MathT, IndexT>> X_norm = std::nullopt;
raft::linalg::rowNorm(dataset_norms,
raft::linalg::rowNorm(dataset_norms.data_handle() + batch.offset(),
batch_data_view.data_handle(),
index->dim(),
batch.size(),
Expand Down Expand Up @@ -1668,6 +1677,8 @@ void extend(raft::resources const& handle,
// Fill the extended index with the new data (possibly, in batches)
utils::batch_load_iterator<IdxT> idx_batches(
new_indices, n_rows, 1, max_batch_size, stream, batches_mr);
utils::batch_load_iterator<float> norms_batches(
dataset_norms.data_handle(), n_rows, 1, max_batch_size, stream, batches_mr);
vec_batches.reset();
vec_batches.prefetch_next_batch();
for (const auto& vec_batch : vec_batches) {
Expand All @@ -1679,6 +1690,7 @@ void extend(raft::resources const& handle,
? std::variant<IdxT, const IdxT*>(idx_batch.data())
: std::variant<IdxT, const IdxT*>(IdxT(idx_batch.offset())),
new_data_labels.data() + vec_batch.offset(),
dataset_norms.data() + vec_batch.offset(),
IdxT(vec_batch.size()),
batches_mr);

Expand Down
28 changes: 16 additions & 12 deletions cpp/src/neighbors/ivf_pq/ivf_pq_search.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,11 @@
#include "ivf_pq_compute_similarity.cuh"
#include "ivf_pq_fp_8bit.cuh"

#include <cuda_runtime_api.h>
#include <cuvs/distance/distance.hpp>
#include <cuvs/neighbors/ivf_pq.hpp>
#include <cuvs/selection/select_k.hpp>
#include <driver_types.h>
#include <raft/core/cudart_utils.hpp>
#include <raft/core/device_mdarray.hpp>
#include <raft/core/logger-ext.hpp>
Expand Down Expand Up @@ -164,25 +166,27 @@ void select_clusters(raft::resources const& handle,
stream);

if (metric == cuvs::distance::DistanceType::CosineExpanded) {
auto centroidsNorm =
// TODO: store dataset norms in a different manner for the cosine metric to avoid the copy here
auto center_norms =
raft::make_device_mdarray<float, uint32_t>(handle, mr, raft::make_extents<uint32_t>(n_lists));
raft::linalg::rowNorm<float, uint32_t>(centroidsNorm.data_handle(),
cluster_centers,
dim,
n_lists,
raft::linalg::L2Norm,
true,
stream);

auto op = [] __device__(float a, float b) { return a / raft::sqrt(b); };

cudaMemcpy2DAsync(center_norms.data_handle(),
sizeof(float),
cluster_centers,
sizeof(float) * dim_ext,
sizeof(float),
n_lists,
cudaMemcpyDefault,
stream);

raft::linalg::matrixVectorOp(qc_distances.data(),
qc_distances.data(),
centroidsNorm.data_handle(),
center_norms.data_handle(),
n_lists,
n_queries,
true,
true,
op,
raft::div_op{},
stream);
}

Expand Down
26 changes: 26 additions & 0 deletions cpp/src/neighbors/ivf_pq_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
* limitations under the License.
*/

#include "cuvs/distance/distance.hpp"
#include <cuvs/neighbors/ivf_pq.hpp>

namespace cuvs::neighbors::ivf_pq {
Expand Down Expand Up @@ -92,6 +93,8 @@ index<IdxT>::index(raft::resources const& handle,
{
check_consistency();
accum_sorted_sizes_(n_lists) = 0;
if (metric == cuvs::distance::DistanceType::CosineExpanded)
dataset_norms_ = raft::make_device_vector<float*, uint32_t>(handle, n_lists);
}

template <typename IdxT>
Expand Down Expand Up @@ -289,6 +292,29 @@ raft::device_matrix_view<const float, uint32_t, raft::row_major> index<IdxT>::ce
return centers_rot_.view();
}

template <typename IdxT>
std::optional<raft::device_vector_view<float*, uint32_t, raft::row_major>>
index<IdxT>::data_norms_ptrs() noexcept
{
std::optional<raft::device_vector_view<float*, uint32_t, raft::row_major>> ret;
if (dataset_norms_.has_value()) ret = dataset_norms_->view();
return ret;
}

template <typename IdxT>
std::optional<raft::device_vector_view<const float* const, uint32_t, raft::row_major>>
index<IdxT>::data_norms_ptrs() const noexcept
{
std::optional<raft::device_vector_view<const float* const, uint32_t, raft::row_major>> ret;
if (dataset_norms_.has_value()) {
auto norms_ptrs_view =
raft::make_mdspan<const float* const, uint32_t, raft::row_major, false, true>(
dataset_norms_->data_handle(), dataset_norms_->extents());
ret = norms_ptrs_view;
}
return ret;
}

template <typename IdxT>
uint32_t index<IdxT>::get_list_size_in_bytes(uint32_t label)
{
Expand Down

0 comments on commit a09157e

Please sign in to comment.