diff --git a/rust/cuvs/src/cagra/index.rs b/rust/cuvs/src/cagra/index.rs
index 6a5149f07..959959f60 100644
--- a/rust/cuvs/src/cagra/index.rs
+++ b/rust/cuvs/src/cagra/index.rs
@@ -106,8 +106,7 @@ mod tests {
     use ndarray_rand::rand_distr::Uniform;
     use ndarray_rand::RandomExt;
 
-    #[test]
-    fn test_cagra_index() {
+    fn test_cagra(build_params: IndexParams) {
         let res = Resources::new().unwrap();
 
         // Create a new random dataset to index
@@ -117,7 +116,6 @@ mod tests {
             ndarray::Array::<f32, _>::random((n_datapoints, n_features), Uniform::new(0., 1.0));
 
         // build the cagra index
-        let build_params = IndexParams::new().unwrap();
         let index =
             Index::build(&res, &build_params, &dataset).expect("failed to create cagra index");
 
@@ -159,4 +157,19 @@ mod tests {
         assert_eq!(neighbors_host[[2, 0]], 2);
         assert_eq!(neighbors_host[[3, 0]], 3);
     }
+
+    #[test]
+    fn test_cagra_index() {
+        let build_params = IndexParams::new().unwrap();
+        test_cagra(build_params);
+    }
+
+    #[test]
+    fn test_cagra_compression() {
+        use crate::cagra::CompressionParams;
+        let build_params = IndexParams::new()
+            .unwrap()
+            .set_compression(CompressionParams::new().unwrap());
+        test_cagra(build_params);
+    }
 }
diff --git a/rust/cuvs/src/cagra/index_params.rs b/rust/cuvs/src/cagra/index_params.rs
index 2e3367e06..9a481ef9e 100644
--- a/rust/cuvs/src/cagra/index_params.rs
+++ b/rust/cuvs/src/cagra/index_params.rs
@@ -21,7 +21,80 @@ use std::io::{stderr, Write};
 pub type BuildAlgo = ffi::cuvsCagraGraphBuildAlgo;
 
+/// Parameters controlling dataset compression (VQ and PQ) when building a CAGRA Index
+pub struct CompressionParams(pub ffi::cuvsCagraCompressionParams_t);
+
+impl CompressionParams {
+    /// Returns a new CompressionParams
+    pub fn new() -> Result<CompressionParams> {
+        unsafe {
+            let mut params = std::mem::MaybeUninit::<ffi::cuvsCagraCompressionParams_t>::uninit();
+            check_cuvs(ffi::cuvsCagraCompressionParamsCreate(params.as_mut_ptr()))?;
+            Ok(CompressionParams(params.assume_init()))
+        }
+    }
+
+    /// The bit length of the vector element after compression by PQ.
+    pub fn set_pq_bits(self, pq_bits: u32) -> CompressionParams {
+        unsafe {
+            (*self.0).pq_bits = pq_bits;
+        }
+        self
+    }
+
+    /// The dimensionality of the vector after compression by PQ. When zero,
+    /// an optimal value is selected using a heuristic.
+    pub fn set_pq_dim(self, pq_dim: u32) -> CompressionParams {
+        unsafe {
+            (*self.0).pq_dim = pq_dim;
+        }
+        self
+    }
+
+    /// Vector Quantization (VQ) codebook size - number of "coarse cluster
+    /// centers". When zero, an optimal value is selected using a heuristic.
+    pub fn set_vq_n_centers(self, vq_n_centers: u32) -> CompressionParams {
+        unsafe {
+            (*self.0).vq_n_centers = vq_n_centers;
+        }
+        self
+    }
+
+    /// The number of iterations searching for kmeans centers (both VQ & PQ
+    /// phases).
+    pub fn set_kmeans_n_iters(self, kmeans_n_iters: u32) -> CompressionParams {
+        unsafe {
+            (*self.0).kmeans_n_iters = kmeans_n_iters;
+        }
+        self
+    }
+
+    /// The fraction of data to use during iterative kmeans building (VQ
+    /// phase). When zero, an optimal value is selected using a heuristic.
+    pub fn set_vq_kmeans_trainset_fraction(
+        self,
+        vq_kmeans_trainset_fraction: f64,
+    ) -> CompressionParams {
+        unsafe {
+            (*self.0).vq_kmeans_trainset_fraction = vq_kmeans_trainset_fraction;
+        }
+        self
+    }
+
+    /// The fraction of data to use during iterative kmeans building (PQ
+    /// phase). When zero, an optimal value is selected using a heuristic.
+    pub fn set_pq_kmeans_trainset_fraction(
+        self,
+        pq_kmeans_trainset_fraction: f64,
+    ) -> CompressionParams {
+        unsafe {
+            (*self.0).pq_kmeans_trainset_fraction = pq_kmeans_trainset_fraction;
+        }
+        self
+    }
+}
+
 /// Supplemental parameters to build CAGRA Index
-pub struct IndexParams(pub ffi::cuvsCagraIndexParams_t);
+pub struct IndexParams(pub ffi::cuvsCagraIndexParams_t, Option<CompressionParams>);
 
 impl IndexParams {
     /// Returns a new IndexParams
@@ -29,7 +102,7 @@ impl IndexParams {
         unsafe {
             let mut params = std::mem::MaybeUninit::<ffi::cuvsCagraIndexParams_t>::uninit();
             check_cuvs(ffi::cuvsCagraIndexParamsCreate(params.as_mut_ptr()))?;
-            Ok(IndexParams(params.assume_init()))
+            Ok(IndexParams(params.assume_init(), None))
         }
     }
 
@@ -64,6 +137,18 @@ impl IndexParams {
         }
         self
     }
+
+    /// Sets the compression parameters to use when building the index.
+    /// Takes ownership of `compression` so that it outlives the pointer stored here.
+    pub fn set_compression(mut self, compression: CompressionParams) -> IndexParams {
+        unsafe {
+            (*self.0).compression = compression.0;
+        }
+        // Note: we're moving the ownership of compression here to avoid having it cleaned up
+        // and leaving a dangling pointer
+        self.1 = Some(compression);
+        self
+    }
 }
 
 impl fmt::Debug for IndexParams {
@@ -74,6 +159,12 @@ impl fmt::Debug for IndexParams {
     }
 }
 
+impl fmt::Debug for CompressionParams {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        write!(f, "CompressionParams({:?})", unsafe { *self.0 })
+    }
+}
+
 impl Drop for IndexParams {
     fn drop(&mut self) {
         if let Err(e) = check_cuvs(unsafe { ffi::cuvsCagraIndexParamsDestroy(self.0) }) {
@@ -87,6 +178,19 @@ impl Drop for IndexParams {
         }
     }
 }
+impl Drop for CompressionParams {
+    fn drop(&mut self) {
+        if let Err(e) = check_cuvs(unsafe { ffi::cuvsCagraCompressionParamsDestroy(self.0) }) {
+            write!(
+                stderr(),
+                "failed to call cuvsCagraCompressionParamsDestroy {:?}",
+                e
+            )
+            .expect("failed to write to stderr");
+        }
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -98,7 +202,13 @@ mod tests {
             .set_intermediate_graph_degree(128)
             .set_graph_degree(16)
             .set_build_algo(BuildAlgo::NN_DESCENT)
-            .set_nn_descent_niter(10);
+            .set_nn_descent_niter(10)
+            .set_compression(
+                CompressionParams::new()
+                    .unwrap()
+                    .set_pq_bits(4)
+                    .set_pq_dim(8),
+            );
 
         // make sure the setters actually updated internal representation on the c-struct
         unsafe {
@@ -106,6 +216,8 @@ mod tests {
             assert_eq!((*params.0).intermediate_graph_degree, 128);
             assert_eq!((*params.0).build_algo, BuildAlgo::NN_DESCENT);
             assert_eq!((*params.0).nn_descent_niter, 10);
+            assert_eq!((*(*params.0).compression).pq_dim, 8);
+            assert_eq!((*(*params.0).compression).pq_bits, 4);
         }
     }
 }
diff --git a/rust/cuvs/src/cagra/mod.rs b/rust/cuvs/src/cagra/mod.rs
index 417ed9b0d..c7db85842 100644
--- a/rust/cuvs/src/cagra/mod.rs
+++ b/rust/cuvs/src/cagra/mod.rs
@@ -82,5 +82,5 @@ mod index_params;
 mod search_params;
 
 pub use index::Index;
-pub use index_params::{BuildAlgo, IndexParams};
+pub use index_params::{BuildAlgo, CompressionParams, IndexParams};
 pub use search_params::{HashMode, SearchAlgo, SearchParams};