Skip to content

Commit

Permalink
passing tests
Browse files Browse the repository at this point in the history
  • Loading branch information
divyegala committed Sep 11, 2024
1 parent 53bcf5d commit 0c2d082
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 17 deletions.
16 changes: 5 additions & 11 deletions cpp/src/neighbors/hnsw_c.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
#include <cuvs/neighbors/hnsw.hpp>

namespace {
template <typename T, typename QueriesT>
template <typename T>
void _search(cuvsResources_t res,
cuvsHnswSearchParams params,
cuvsHnswIndex index,
Expand All @@ -46,7 +46,7 @@ void _search(cuvsResources_t res,
search_params.ef = params.ef;
search_params.num_threads = params.numThreads;

using queries_mdspan_type = raft::host_matrix_view<QueriesT const, int64_t, raft::row_major>;
using queries_mdspan_type = raft::host_matrix_view<T const, int64_t, raft::row_major>;
using neighbors_mdspan_type = raft::host_matrix_view<uint64_t, int64_t, raft::row_major>;
using distances_mdspan_type = raft::host_matrix_view<float, int64_t, raft::row_major>;
auto queries_mds = cuvs::core::from_dlpack<queries_mdspan_type>(queries_tensor);
Expand Down Expand Up @@ -127,16 +127,13 @@ extern "C" cuvsError_t cuvsHnswSearch(cuvsResources_t res,

auto index = *index_c_ptr;
RAFT_EXPECTS(queries.dtype.code == index.dtype.code, "type mismatch between index and queries");
RAFT_EXPECTS(queries.dtype.bits == 32, "number of bits in queries dtype should be 32");

if (index.dtype.code == kDLFloat) {
_search<float, float>(
res, *params, index, queries_tensor, neighbors_tensor, distances_tensor);
_search<float>(res, *params, index, queries_tensor, neighbors_tensor, distances_tensor);
} else if (index.dtype.code == kDLUInt) {
_search<uint8_t, int>(
res, *params, index, queries_tensor, neighbors_tensor, distances_tensor);
_search<uint8_t>(res, *params, index, queries_tensor, neighbors_tensor, distances_tensor);
} else if (index.dtype.code == kDLInt) {
_search<int8_t, int>(res, *params, index, queries_tensor, neighbors_tensor, distances_tensor);
_search<int8_t>(res, *params, index, queries_tensor, neighbors_tensor, distances_tensor);
} else {
RAFT_FAIL("Unsupported index dtype: %d and bits: %d", queries.dtype.code, queries.dtype.bits);
}
Expand All @@ -152,13 +149,10 @@ extern "C" cuvsError_t cuvsHnswDeserialize(cuvsResources_t res,
return cuvs::core::translate_exceptions([=] {
if (index->dtype.code == kDLFloat && index->dtype.bits == 32) {
index->addr = reinterpret_cast<uintptr_t>(_deserialize<float>(res, filename, dim, metric));
index->dtype.code = kDLFloat;
} else if (index->dtype.code == kDLUInt && index->dtype.bits == 8) {
index->addr = reinterpret_cast<uintptr_t>(_deserialize<uint8_t>(res, filename, dim, metric));
index->dtype.code = kDLInt;
} else if (index->dtype.code == kDLInt && index->dtype.bits == 8) {
index->addr = reinterpret_cast<uintptr_t>(_deserialize<int8_t>(res, filename, dim, metric));
index->dtype.code = kDLUInt;
} else {
RAFT_FAIL("Unsupported dtype in file %s", filename);
}
Expand Down
2 changes: 1 addition & 1 deletion python/cuvs/cuvs/neighbors/hnsw/hnsw.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ from cuvs.distance_type cimport cuvsDistanceType
cdef extern from "cuvs/neighbors/hnsw.h" nogil:
ctypedef struct cuvsHnswSearchParams:
int32_t ef
int32_t num_threads
int32_t numThreads

ctypedef cuvsHnswSearchParams* cuvsHnswSearchParams_t

Expand Down
17 changes: 12 additions & 5 deletions python/cuvs/cuvs/neighbors/hnsw/hnsw.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ cdef class SearchParams:
ef=200,
num_threads=0):
self.params.ef = ef
self.params.num_threads = num_threads
self.params.numThreads = num_threads

def __repr__(self):
attr_str = [attr + "=" + str(getattr(self, attr))
Expand All @@ -72,7 +72,7 @@ cdef class SearchParams:

@property
def num_threads(self):
return self.params.num_threads
return self.params.numThreads


cdef class Index:
Expand Down Expand Up @@ -106,6 +106,9 @@ cdef class Index:
def save(filename, cagra.Index index, resources=None):
"""
Saves the CAGRA index to a file as an hnswlib index.
The saved index is immutable and can only be searched by the hnswlib
wrapper in cuVS, as the format is not compatible with the original
hnswlib.
Saving / loading the index is experimental. The serialization format is
subject to change.
Expand Down Expand Up @@ -142,11 +145,13 @@ def save(filename, cagra.Index index, resources=None):
def load(filename, dim, dtype, metric="sqeuclidean", resources=None):
"""
Loads base-layer-only hnswlib index from file, which was originally
saved as a built CAGRA index.
saved as a built CAGRA index. The loaded index is immutable and can only
be searched by the hnswlib wrapper in cuVS, as the format is not
compatible with the original hnswlib.
Saving / loading the index is experimental. The serialization format is
subject to change, therefore loading an index saved with a previous
version of raft is not guaranteed to work.
version of cuVS is not guaranteed to work.
Parameters
----------
Expand Down Expand Up @@ -224,7 +229,9 @@ def from_cagra(cagra.Index index, resources=None):
NOTE: This method uses the filesystem to write the CAGRA index in
`/tmp/<random_number>.bin` before reading it as an hnswlib index,
then deleting the temporary file.
then deleting the temporary file. The returned index is immutable
and can only be searched by the hnswlib wrapper in cuVS, as the
format is not compatible with the original hnswlib.
Saving / loading the index is experimental. The serialization format is
subject to change.
Expand Down

0 comments on commit 0c2d082

Please sign in to comment.