Skip to content

Commit

Permalink
Add Rust bindings for CAGRA-Q
Browse files Browse the repository at this point in the history
  • Loading branch information
benfred committed Apr 5, 2024
1 parent 3b65faf commit a4e3bc8
Show file tree
Hide file tree
Showing 3 changed files with 128 additions and 7 deletions.
18 changes: 15 additions & 3 deletions rust/cuvs/src/cagra/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,7 @@ mod tests {
use ndarray_rand::rand_distr::Uniform;
use ndarray_rand::RandomExt;

#[test]
fn test_cagra_index() {
fn test_cagra(build_params: IndexParams) {
let res = Resources::new().unwrap();

// Create a new random dataset to index
Expand All @@ -117,7 +116,6 @@ mod tests {
ndarray::Array::<f32, _>::random((n_datapoints, n_features), Uniform::new(0., 1.0));

// build the cagra index
let build_params = IndexParams::new().unwrap();
let index =
Index::build(&res, &build_params, &dataset).expect("failed to create cagra index");

Expand Down Expand Up @@ -159,4 +157,18 @@ mod tests {
assert_eq!(neighbors_host[[2, 0]], 2);
assert_eq!(neighbors_host[[3, 0]], 3);
}

/// Runs the shared CAGRA build-and-search check with the parameters that
/// `IndexParams::new` produces, without any further customisation.
#[test]
fn test_cagra_index() {
    test_cagra(IndexParams::new().unwrap());
}

/// Runs the shared CAGRA build-and-search check with compression parameters
/// attached to the index parameters.
#[test]
fn test_cagra_compression() {
    use crate::cagra::CompressionParams;

    // Create the index parameters first, then the compression parameters,
    // and hand the latter over to the former.
    let index_params = IndexParams::new().unwrap();
    let compression = CompressionParams::new().unwrap();
    test_cagra(index_params.set_compression(compression));
}
}
115 changes: 112 additions & 3 deletions rust/cuvs/src/cagra/index_params.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,87 @@ use std::io::{stderr, Write};
pub type BuildAlgo = ffi::cuvsCagraGraphBuildAlgo;

/// Supplemental parameters to build CAGRA Index
pub struct IndexParams(pub ffi::cuvsCagraIndexParams_t);
/// Compression (PQ / VQ) parameters that can be attached to CAGRA index
/// parameters.
///
/// Wraps the raw `ffi::cuvsCagraCompressionParams_t` handle. The handle is
/// handed to `ffi::cuvsCagraCompressionParamsDestroy` when this value is
/// dropped.
pub struct CompressionParams(pub ffi::cuvsCagraCompressionParams_t);

impl CompressionParams {
    /// Creates a fresh set of compression parameters by calling
    /// `cuvsCagraCompressionParamsCreate`.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `check_cuvs` when the create call does
    /// not report success.
    pub fn new() -> Result<CompressionParams> {
        let mut handle = std::mem::MaybeUninit::<ffi::cuvsCagraCompressionParams_t>::uninit();
        // SAFETY: `handle` is a live local, so the pointer handed to the C API
        // is valid for writes for the duration of the call.
        let status = unsafe { ffi::cuvsCagraCompressionParamsCreate(handle.as_mut_ptr()) };
        check_cuvs(status)?;
        // SAFETY: only reached when the create call reported success, which is
        // assumed to mean it stored a valid handle in `handle`.
        Ok(CompressionParams(unsafe { handle.assume_init() }))
    }

    /// Sets the bit length of a vector element after PQ compression.
    pub fn set_pq_bits(self, pq_bits: u32) -> CompressionParams {
        let raw = self.0;
        // SAFETY: `raw` is expected to be the handle obtained in `new`, which
        // stays alive until `self` is dropped. The field is `pub`, so anyone
        // constructing the tuple by hand must uphold this.
        unsafe { (*raw).pq_bits = pq_bits };
        self
    }

    /// Sets the dimensionality of a vector after PQ compression. Passing
    /// zero lets a heuristic choose an optimal value.
    pub fn set_pq_dim(self, pq_dim: u32) -> CompressionParams {
        let raw = self.0;
        // SAFETY: see `set_pq_bits`; `raw` must be the live handle from `new`.
        unsafe { (*raw).pq_dim = pq_dim };
        self
    }

    /// Sets the Vector Quantization (VQ) codebook size, i.e. the number of
    /// "coarse cluster centers". Passing zero lets a heuristic choose an
    /// optimal value.
    pub fn set_vq_n_centers(self, vq_n_centers: u32) -> CompressionParams {
        let raw = self.0;
        // SAFETY: see `set_pq_bits`; `raw` must be the live handle from `new`.
        unsafe { (*raw).vq_n_centers = vq_n_centers };
        self
    }

    /// Sets how many iterations are spent searching for kmeans centers
    /// (applies to both the VQ and the PQ phase).
    pub fn set_kmeans_n_iters(self, kmeans_n_iters: u32) -> CompressionParams {
        let raw = self.0;
        // SAFETY: see `set_pq_bits`; `raw` must be the live handle from `new`.
        unsafe { (*raw).kmeans_n_iters = kmeans_n_iters };
        self
    }

    /// Sets the fraction of the data used during iterative kmeans building
    /// in the VQ phase. Passing zero lets a heuristic choose an optimal
    /// value.
    pub fn set_vq_kmeans_trainset_fraction(
        self,
        vq_kmeans_trainset_fraction: f64,
    ) -> CompressionParams {
        let raw = self.0;
        // SAFETY: see `set_pq_bits`; `raw` must be the live handle from `new`.
        unsafe { (*raw).vq_kmeans_trainset_fraction = vq_kmeans_trainset_fraction };
        self
    }

    /// Sets the fraction of the data used during iterative kmeans building
    /// in the PQ phase. Passing zero lets a heuristic choose an optimal
    /// value.
    pub fn set_pq_kmeans_trainset_fraction(
        self,
        pq_kmeans_trainset_fraction: f64,
    ) -> CompressionParams {
        let raw = self.0;
        // SAFETY: see `set_pq_bits`; `raw` must be the live handle from `new`.
        unsafe { (*raw).pq_kmeans_trainset_fraction = pq_kmeans_trainset_fraction };
        self
    }
}

/// Supplemental parameters to build a CAGRA index.
///
/// Field 0 is the raw `ffi::cuvsCagraIndexParams_t` handle. Field 1 holds the
/// `CompressionParams` given to `set_compression` (if any), so that they are
/// not dropped while the C struct still stores their pointer.
pub struct IndexParams(pub ffi::cuvsCagraIndexParams_t, Option<CompressionParams>);

impl IndexParams {
/// Returns a new IndexParams
///
/// The handle is obtained from `cuvsCagraIndexParamsCreate`; any failure
/// reported through `check_cuvs` is returned as the `Err` variant. No
/// compression parameters are attached initially (second field is `None`).
pub fn new() -> Result<IndexParams> {
unsafe {
// The C API writes the handle into this uninitialised slot.
let mut params = std::mem::MaybeUninit::<ffi::cuvsCagraIndexParams_t>::uninit();
check_cuvs(ffi::cuvsCagraIndexParamsCreate(params.as_mut_ptr()))?;
// NOTE(review): the single-field constructor call on the next line looks
// like the pre-change side of the diff; the two-field form after it matches
// the `IndexParams(handle, Option<CompressionParams>)` definition - confirm.
Ok(IndexParams(params.assume_init()))
Ok(IndexParams(params.assume_init(), None))
}
}

Expand Down Expand Up @@ -64,6 +136,16 @@ impl IndexParams {
}
self
}

/// Attaches compression parameters to these index parameters.
///
/// The raw compression handle is stored in the underlying C struct, and
/// `compression` itself is kept inside `self` so that it outlives that
/// stored pointer.
pub fn set_compression(mut self, compression: CompressionParams) -> IndexParams {
    let raw_compression = compression.0;
    // SAFETY: `self.0` is expected to be the live handle created in
    // `IndexParams::new`; only a pointer-sized field is written here.
    unsafe { (*self.0).compression = raw_compression };
    // Take ownership of `compression`: dropping it here would destroy the
    // handle that was just stored above and leave a dangling pointer behind.
    self.1 = Some(compression);
    self
}
}

impl fmt::Debug for IndexParams {
Expand All @@ -74,6 +156,12 @@ impl fmt::Debug for IndexParams {
}
}

/// Formats the value as `CompressionParams(..)`, where the inner part is the
/// `Debug` output of the C struct behind the handle.
impl fmt::Debug for CompressionParams {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // SAFETY: `self.0` is expected to be the live handle created in
        // `CompressionParams::new`; the struct behind it is copied out for
        // printing.
        let raw = unsafe { *self.0 };
        write!(f, "CompressionParams({:?})", raw)
    }
}

impl Drop for IndexParams {
fn drop(&mut self) {
if let Err(e) = check_cuvs(unsafe { ffi::cuvsCagraIndexParamsDestroy(self.0) }) {
Expand All @@ -87,6 +175,19 @@ impl Drop for IndexParams {
}
}

/// Releases the underlying handle through `cuvsCagraCompressionParamsDestroy`.
///
/// `drop` cannot return an error, so a failure is written to stderr instead.
impl Drop for CompressionParams {
    fn drop(&mut self) {
        // SAFETY: `self.0` is expected to be the handle created in
        // `CompressionParams::new`; it is not used again after this call.
        let status = check_cuvs(unsafe { ffi::cuvsCagraCompressionParamsDestroy(self.0) });
        if let Err(e) = status {
            write!(
                stderr(),
                "failed to call cuvsCagraCompressionParamsDestroy {:?}",
                e
            )
            .expect("failed to write to stderr");
        }
    }
}

#[cfg(test)]
mod tests {
use super::*;
Expand All @@ -98,14 +199,22 @@ mod tests {
.set_intermediate_graph_degree(128)
.set_graph_degree(16)
.set_build_algo(BuildAlgo::NN_DESCENT)
.set_nn_descent_niter(10);
.set_nn_descent_niter(10)
.set_compression(
CompressionParams::new()
.unwrap()
.set_pq_bits(4)
.set_pq_dim(8),
);

// make sure the setters actually updated internal representation on the c-struct
unsafe {
assert_eq!((*params.0).graph_degree, 16);
assert_eq!((*params.0).intermediate_graph_degree, 128);
assert_eq!((*params.0).build_algo, BuildAlgo::NN_DESCENT);
assert_eq!((*params.0).nn_descent_niter, 10);
assert_eq!((*(*params.0).compression).pq_dim, 8);
assert_eq!((*(*params.0).compression).pq_bits, 4);
}
}
}
2 changes: 1 addition & 1 deletion rust/cuvs/src/cagra/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,5 +82,5 @@ mod index_params;
mod search_params;

pub use index::Index;
pub use index_params::{BuildAlgo, IndexParams};
pub use index_params::{BuildAlgo, CompressionParams, IndexParams};
pub use search_params::{HashMode, SearchAlgo, SearchParams};

0 comments on commit a4e3bc8

Please sign in to comment.