diff --git a/CHANGELOG.md b/CHANGELOG.md index 9a38f4e2d5..ae418b09b4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ We try to indicate most contributions here with the contributor names who are no the Facebook Faiss team. Feel free to add entries here if you submit a PR. ## [Unreleased] +- Support for range search in HNSW and Fast scan IVF. ## [1.7.4] - 2023-04-12 ### Added - Added big batch IVF search for conducting efficient search with big batches of queries diff --git a/benchs/link_and_code/README.md b/benchs/link_and_code/README.md index bbf034bc60..697c7bdfc6 100644 --- a/benchs/link_and_code/README.md +++ b/benchs/link_and_code/README.md @@ -39,6 +39,8 @@ The code runs on top of Faiss. The HNSW index can be extended with a `ReconstructFromNeighbors` C++ object that refines the distances. The training is implemented in Python. +Update (2023-12-28): the current version of Faiss has dropped support for +reconstruction with this method. Reproducing Table 2 in the paper -------------------------------- diff --git a/contrib/evaluation.py b/contrib/evaluation.py index 50e8a93319..435c390594 100644 --- a/contrib/evaluation.py +++ b/contrib/evaluation.py @@ -261,6 +261,7 @@ def check_ref_knn_with_draws(Dref, Iref, Dnew, Inew, rtol=1e-5): mask = DrefC == dis testcase.assertEqual(set(Iref[i, mask]), set(Inew[i, mask])) + def check_ref_range_results(Lref, Dref, Iref, Lnew, Dnew, Inew): """ compare range search results wrt. 
a reference result, diff --git a/faiss/IndexAdditiveQuantizer.cpp b/faiss/IndexAdditiveQuantizer.cpp index 5bf06c4a4a..719dcafbc9 100644 --- a/faiss/IndexAdditiveQuantizer.cpp +++ b/faiss/IndexAdditiveQuantizer.cpp @@ -114,18 +114,19 @@ struct AQDistanceComputerLUT : FlatCodesDistanceComputer { * scanning implementation for search ************************************************************/ -template +template void search_with_decompress( const IndexAdditiveQuantizer& ir, const float* xq, VectorDistance& vd, - ResultHandler& res) { + BlockResultHandler& res) { const uint8_t* codes = ir.codes.data(); size_t ntotal = ir.ntotal; size_t code_size = ir.code_size; const AdditiveQuantizer* aq = ir.aq; - using SingleResultHandler = typename ResultHandler::SingleResultHandler; + using SingleResultHandler = + typename BlockResultHandler::SingleResultHandler; #pragma omp parallel for if (res.nq > 100) for (int64_t q = 0; q < res.nq; q++) { @@ -142,11 +143,14 @@ void search_with_decompress( } } -template +template < + bool is_IP, + AdditiveQuantizer::Search_type_t st, + class BlockResultHandler> void search_with_LUT( const IndexAdditiveQuantizer& ir, const float* xq, - ResultHandler& res) { + BlockResultHandler& res) { const AdditiveQuantizer& aq = *ir.aq; const uint8_t* codes = ir.codes.data(); size_t ntotal = ir.ntotal; @@ -154,7 +158,8 @@ void search_with_LUT( size_t nq = res.nq; size_t d = ir.d; - using SingleResultHandler = typename ResultHandler::SingleResultHandler; + using SingleResultHandler = + typename BlockResultHandler::SingleResultHandler; std::unique_ptr LUT(new float[nq * aq.total_codebook_size]); aq.compute_LUT(nq, xq, LUT.get()); @@ -241,21 +246,23 @@ void IndexAdditiveQuantizer::search( if (metric_type == METRIC_L2) { using VD = VectorDistance; VD vd = {size_t(d), metric_arg}; - HeapResultHandler rh(n, distances, labels, k); + HeapBlockResultHandler rh(n, distances, labels, k); search_with_decompress(*this, x, vd, rh); } else if (metric_type == 
METRIC_INNER_PRODUCT) { using VD = VectorDistance; VD vd = {size_t(d), metric_arg}; - HeapResultHandler rh(n, distances, labels, k); + HeapBlockResultHandler rh(n, distances, labels, k); search_with_decompress(*this, x, vd, rh); } } else { if (metric_type == METRIC_INNER_PRODUCT) { - HeapResultHandler> rh(n, distances, labels, k); + HeapBlockResultHandler> rh( + n, distances, labels, k); search_with_LUT( *this, x, rh); } else { - HeapResultHandler> rh(n, distances, labels, k); + HeapBlockResultHandler> rh( + n, distances, labels, k); switch (aq->search_type) { #define DISPATCH(st) \ case AdditiveQuantizer::st: \ diff --git a/faiss/IndexAdditiveQuantizerFastScan.cpp b/faiss/IndexAdditiveQuantizerFastScan.cpp index 709ccc87e2..85a78647f3 100644 --- a/faiss/IndexAdditiveQuantizerFastScan.cpp +++ b/faiss/IndexAdditiveQuantizerFastScan.cpp @@ -203,9 +203,9 @@ void IndexAdditiveQuantizerFastScan::search( NormTableScaler scaler(norm_scale); if (metric_type == METRIC_L2) { - search_dispatch_implem(n, x, k, distances, labels, scaler); + search_dispatch_implem(n, x, k, distances, labels, &scaler); } else { - search_dispatch_implem(n, x, k, distances, labels, scaler); + search_dispatch_implem(n, x, k, distances, labels, &scaler); } } diff --git a/faiss/IndexBinaryHNSW.cpp b/faiss/IndexBinaryHNSW.cpp index e6fda8e4bf..f1bda08fbc 100644 --- a/faiss/IndexBinaryHNSW.cpp +++ b/faiss/IndexBinaryHNSW.cpp @@ -5,8 +5,6 @@ * LICENSE file in the root directory of this source tree. 
*/ -// -*- c++ -*- - #include #include @@ -28,6 +26,7 @@ #include #include #include +#include #include #include #include @@ -201,27 +200,31 @@ void IndexBinaryHNSW::search( !params, "search params not supported for this index"); FAISS_THROW_IF_NOT(k > 0); + // we use the buffer for distances as float but convert them back + // to int in the end + float* distances_f = (float*)distances; + + using RH = HeapBlockResultHandler; + RH bres(n, distances_f, labels, k); + #pragma omp parallel { VisitedTable vt(ntotal); std::unique_ptr dis(get_distance_computer()); + RH::SingleResultHandler res(bres); #pragma omp for for (idx_t i = 0; i < n; i++) { - idx_t* idxi = labels + i * k; - float* simi = (float*)(distances + i * k); - + res.begin(i); dis->set_query((float*)(x + i * code_size)); - - maxheap_heapify(k, simi, idxi); - hnsw.search(*dis, k, idxi, simi, vt); - maxheap_reorder(k, simi, idxi); + hnsw.search(*dis, res, vt); + res.end(); } } #pragma omp parallel for for (int i = 0; i < n * k; ++i) { - distances[i] = std::round(((float*)distances)[i]); + distances[i] = std::round(distances_f[i]); } } diff --git a/faiss/IndexFastScan.cpp b/faiss/IndexFastScan.cpp index 02840767d1..ca02af4168 100644 --- a/faiss/IndexFastScan.cpp +++ b/faiss/IndexFastScan.cpp @@ -158,7 +158,7 @@ void IndexFastScan::merge_from(Index& otherIndex, idx_t add_id) { namespace { -template +template void estimators_from_tables_generic( const IndexFastScan& index, const uint8_t* codes, @@ -167,25 +167,28 @@ void estimators_from_tables_generic( size_t k, typename C::T* heap_dis, int64_t* heap_ids, - const Scaler& scaler) { + const NormTableScaler* scaler) { using accu_t = typename C::T; for (size_t j = 0; j < ncodes; ++j) { BitstringReader bsr(codes + j * index.code_size, index.code_size); accu_t dis = 0; const dis_t* dt = dis_table; - for (size_t m = 0; m < index.M - scaler.nscale; m++) { + int nscale = scaler ? 
scaler->nscale : 0; + + for (size_t m = 0; m < index.M - nscale; m++) { uint64_t c = bsr.read(index.nbits); dis += dt[c]; dt += index.ksub; } - for (size_t m = 0; m < scaler.nscale; m++) { - uint64_t c = bsr.read(index.nbits); - dis += scaler.scale_one(dt[c]); - dt += index.ksub; + if (nscale) { + for (size_t m = 0; m < nscale; m++) { + uint64_t c = bsr.read(index.nbits); + dis += scaler->scale_one(dt[c]); + dt += index.ksub; + } } - if (C::cmp(heap_dis[0], dis)) { heap_pop(k, heap_dis, heap_ids); heap_push(k, heap_dis, heap_ids, dis, j); @@ -193,6 +196,27 @@ void estimators_from_tables_generic( } } +template +ResultHandlerCompare* make_knn_handler( + int impl, + idx_t n, + idx_t k, + size_t ntotal, + float* distances, + idx_t* labels) { + using HeapHC = HeapHandler; + using ReservoirHC = ReservoirHandler; + using SingleResultHC = SingleResultHandler; + + if (k == 1) { + return new SingleResultHC(n, ntotal, distances, labels); + } else if (impl % 2 == 0) { + return new HeapHC(n, ntotal, k, distances, labels); + } else /* if (impl % 2 == 1) */ { + return new ReservoirHC(n, ntotal, k, 2 * k, distances, labels); + } +} + } // anonymous namespace using namespace quantize_lut; @@ -241,22 +265,21 @@ void IndexFastScan::search( !params, "search params not supported for this index"); FAISS_THROW_IF_NOT(k > 0); - DummyScaler scaler; if (metric_type == METRIC_L2) { - search_dispatch_implem(n, x, k, distances, labels, scaler); + search_dispatch_implem(n, x, k, distances, labels, nullptr); } else { - search_dispatch_implem(n, x, k, distances, labels, scaler); + search_dispatch_implem(n, x, k, distances, labels, nullptr); } } -template +template void IndexFastScan::search_dispatch_implem( idx_t n, const float* x, idx_t k, float* distances, idx_t* labels, - const Scaler& scaler) const { + const NormTableScaler* scaler) const { using Cfloat = typename std::conditional< is_max, CMax, @@ -319,14 +342,14 @@ void IndexFastScan::search_dispatch_implem( } } -template +template void 
IndexFastScan::search_implem_234( idx_t n, const float* x, idx_t k, float* distances, idx_t* labels, - const Scaler& scaler) const { + const NormTableScaler* scaler) const { FAISS_THROW_IF_NOT(implem == 2 || implem == 3 || implem == 4); const size_t dim12 = ksub * M; @@ -378,7 +401,7 @@ void IndexFastScan::search_implem_234( } } -template +template void IndexFastScan::search_implem_12( idx_t n, const float* x, @@ -386,7 +409,8 @@ void IndexFastScan::search_implem_12( float* distances, idx_t* labels, int impl, - const Scaler& scaler) const { + const NormTableScaler* scaler) const { + using RH = ResultHandlerCompare; FAISS_THROW_IF_NOT(bbs == 32); // handle qbs2 blocking by recursive call @@ -432,63 +456,31 @@ void IndexFastScan::search_implem_12( pq4_pack_LUT_qbs(qbs, M2, quantized_dis_tables.get(), LUT.get()); FAISS_THROW_IF_NOT(LUT_nq == n); - if (k == 1) { - SingleResultHandler handler(n, ntotal); - if (skip & 4) { - // pass - } else { - handler.disable = bool(skip & 2); - pq4_accumulate_loop_qbs( - qbs, ntotal2, M2, codes.get(), LUT.get(), handler, scaler); - } + std::unique_ptr handler( + make_knn_handler(impl, n, k, ntotal, distances, labels)); + handler->disable = bool(skip & 2); + handler->normalizers = normalizers.get(); - handler.to_flat_arrays(distances, labels, normalizers.get()); - - } else if (impl == 12) { - std::vector tmp_dis(n * k); - std::vector tmp_ids(n * k); - - if (skip & 4) { - // skip - } else { - HeapHandler handler( - n, tmp_dis.data(), tmp_ids.data(), k, ntotal); - handler.disable = bool(skip & 2); - - pq4_accumulate_loop_qbs( - qbs, ntotal2, M2, codes.get(), LUT.get(), handler, scaler); - - if (!(skip & 8)) { - handler.to_flat_arrays(distances, labels, normalizers.get()); - } - } - - } else { // impl == 13 - - ReservoirHandler handler(n, ntotal, k, 2 * k); - handler.disable = bool(skip & 2); - - if (skip & 4) { - // skip - } else { - pq4_accumulate_loop_qbs( - qbs, ntotal2, M2, codes.get(), LUT.get(), handler, scaler); - } - - if (!(skip 
& 8)) { - handler.to_flat_arrays(distances, labels, normalizers.get()); - } - - FastScan_stats.t0 += handler.times[0]; - FastScan_stats.t1 += handler.times[1]; - FastScan_stats.t2 += handler.times[2]; - FastScan_stats.t3 += handler.times[3]; + if (skip & 4) { + // pass + } else { + pq4_accumulate_loop_qbs( + qbs, + ntotal2, + M2, + codes.get(), + LUT.get(), + *handler.get(), + scaler); + } + if (!(skip & 8)) { + handler->end(); } } FastScanStats FastScan_stats; -template +template void IndexFastScan::search_implem_14( idx_t n, const float* x, @@ -496,7 +488,8 @@ void IndexFastScan::search_implem_14( float* distances, idx_t* labels, int impl, - const Scaler& scaler) const { + const NormTableScaler* scaler) const { + using RH = ResultHandlerCompare; FAISS_THROW_IF_NOT(bbs % 32 == 0); int qbs2 = qbs == 0 ? 4 : qbs; @@ -531,91 +524,29 @@ void IndexFastScan::search_implem_14( AlignedTable LUT(n * dim12); pq4_pack_LUT(n, M2, quantized_dis_tables.get(), LUT.get()); - if (k == 1) { - SingleResultHandler handler(n, ntotal); - if (skip & 4) { - // pass - } else { - handler.disable = bool(skip & 2); - pq4_accumulate_loop( - n, - ntotal2, - bbs, - M2, - codes.get(), - LUT.get(), - handler, - scaler); - } - handler.to_flat_arrays(distances, labels, normalizers.get()); - - } else if (impl == 14) { - std::vector tmp_dis(n * k); - std::vector tmp_ids(n * k); - - if (skip & 4) { - // skip - } else if (k > 1) { - HeapHandler handler( - n, tmp_dis.data(), tmp_ids.data(), k, ntotal); - handler.disable = bool(skip & 2); - - pq4_accumulate_loop( - n, - ntotal2, - bbs, - M2, - codes.get(), - LUT.get(), - handler, - scaler); - - if (!(skip & 8)) { - handler.to_flat_arrays(distances, labels, normalizers.get()); - } - } - - } else { // impl == 15 - - ReservoirHandler handler(n, ntotal, k, 2 * k); - handler.disable = bool(skip & 2); - - if (skip & 4) { - // skip - } else { - pq4_accumulate_loop( - n, - ntotal2, - bbs, - M2, - codes.get(), - LUT.get(), - handler, - scaler); - } + 
std::unique_ptr handler( + make_knn_handler(impl, n, k, ntotal, distances, labels)); + handler->disable = bool(skip & 2); + handler->normalizers = normalizers.get(); - if (!(skip & 8)) { - handler.to_flat_arrays(distances, labels, normalizers.get()); - } + if (skip & 4) { + // pass + } else { + pq4_accumulate_loop( + n, + ntotal2, + bbs, + M2, + codes.get(), + LUT.get(), + *handler.get(), + scaler); + } + if (!(skip & 8)) { + handler->end(); } } -template void IndexFastScan::search_dispatch_implem( - idx_t n, - const float* x, - idx_t k, - float* distances, - idx_t* labels, - const NormTableScaler& scaler) const; - -template void IndexFastScan::search_dispatch_implem( - idx_t n, - const float* x, - idx_t k, - float* distances, - idx_t* labels, - const NormTableScaler& scaler) const; - void IndexFastScan::reconstruct(idx_t key, float* recons) const { std::vector code(code_size, 0); BitstringWriter bsw(code.data(), code_size); diff --git a/faiss/IndexFastScan.h b/faiss/IndexFastScan.h index 19aad2a8ee..3c89dcf928 100644 --- a/faiss/IndexFastScan.h +++ b/faiss/IndexFastScan.h @@ -13,6 +13,7 @@ namespace faiss { struct CodePacker; +struct NormTableScaler; /** Fast scan version of IndexPQ and IndexAQ. Works for 4-bit PQ and AQ for now. 
* @@ -87,25 +88,25 @@ struct IndexFastScan : Index { uint8_t* lut, float* normalizers) const; - template + template void search_dispatch_implem( idx_t n, const float* x, idx_t k, float* distances, idx_t* labels, - const Scaler& scaler) const; + const NormTableScaler* scaler) const; - template + template void search_implem_234( idx_t n, const float* x, idx_t k, float* distances, idx_t* labels, - const Scaler& scaler) const; + const NormTableScaler* scaler) const; - template + template void search_implem_12( idx_t n, const float* x, @@ -113,9 +114,9 @@ struct IndexFastScan : Index { float* distances, idx_t* labels, int impl, - const Scaler& scaler) const; + const NormTableScaler* scaler) const; - template + template void search_implem_14( idx_t n, const float* x, @@ -123,7 +124,7 @@ struct IndexFastScan : Index { float* distances, idx_t* labels, int impl, - const Scaler& scaler) const; + const NormTableScaler* scaler) const; void reconstruct(idx_t key, float* recons) const override; size_t remove_ids(const IDSelector& sel) override; diff --git a/faiss/IndexHNSW.cpp b/faiss/IndexHNSW.cpp index 8c0e0afde8..c40136749b 100644 --- a/faiss/IndexHNSW.cpp +++ b/faiss/IndexHNSW.cpp @@ -29,7 +29,7 @@ #include #include #include -#include +#include #include #include #include @@ -286,18 +286,20 @@ void IndexHNSW::train(idx_t n, const float* x) { is_trained = true; } -void IndexHNSW::search( +namespace { + +template +void hnsw_search( + const IndexHNSW* index, idx_t n, const float* x, - idx_t k, - float* distances, - idx_t* labels, - const SearchParameters* params_in) const { - FAISS_THROW_IF_NOT(k > 0); + BlockResultHandler& bres, + const SearchParameters* params_in) { FAISS_THROW_IF_NOT_MSG( - storage, + index->storage, "Please use IndexHNSWFlat (or variants) instead of IndexHNSW directly"); const SearchParametersHNSW* params = nullptr; + const HNSW& hnsw = index->hnsw; int efSearch = hnsw.efSearch; if (params_in) { @@ -307,61 +309,81 @@ void IndexHNSW::search( } size_t n1 = 0, 
n2 = 0, n3 = 0, ndis = 0, nreorder = 0; - idx_t check_period = - InterruptCallback::get_period_hint(hnsw.max_level * d * efSearch); + idx_t check_period = InterruptCallback::get_period_hint( + hnsw.max_level * index->d * efSearch); for (idx_t i0 = 0; i0 < n; i0 += check_period) { idx_t i1 = std::min(i0 + check_period, n); #pragma omp parallel { - VisitedTable vt(ntotal); + VisitedTable vt(index->ntotal); + typename BlockResultHandler::SingleResultHandler res(bres); std::unique_ptr dis( - storage_distance_computer(storage)); + storage_distance_computer(index->storage)); #pragma omp for reduction(+ : n1, n2, n3, ndis, nreorder) schedule(guided) for (idx_t i = i0; i < i1; i++) { - idx_t* idxi = labels + i * k; - float* simi = distances + i * k; - dis->set_query(x + i * d); + res.begin(i); + dis->set_query(x + i * index->d); - maxheap_heapify(k, simi, idxi); - HNSWStats stats = hnsw.search(*dis, k, idxi, simi, vt, params); + HNSWStats stats = hnsw.search(*dis, res, vt, params); n1 += stats.n1; n2 += stats.n2; n3 += stats.n3; ndis += stats.ndis; nreorder += stats.nreorder; - maxheap_reorder(k, simi, idxi); - - if (reconstruct_from_neighbors && - reconstruct_from_neighbors->k_reorder != 0) { - int k_reorder = reconstruct_from_neighbors->k_reorder; - if (k_reorder == -1 || k_reorder > k) - k_reorder = k; - - nreorder += reconstruct_from_neighbors->compute_distances( - k_reorder, idxi, x + i * d, simi); - - // sort top k_reorder - maxheap_heapify( - k_reorder, simi, idxi, simi, idxi, k_reorder); - maxheap_reorder(k_reorder, simi, idxi); - } + res.end(); } } InterruptCallback::check(); } - if (is_similarity_metric(metric_type)) { + hnsw_stats.combine({n1, n2, n3, ndis, nreorder}); +} + +} // anonymous namespace + +void IndexHNSW::search( + idx_t n, + const float* x, + idx_t k, + float* distances, + idx_t* labels, + const SearchParameters* params_in) const { + FAISS_THROW_IF_NOT(k > 0); + + using RH = HeapBlockResultHandler; + RH bres(n, distances, labels, k); + + 
hnsw_search(this, n, x, bres, params_in); + + if (is_similarity_metric(this->metric_type)) { // we need to revert the negated distances for (size_t i = 0; i < k * n; i++) { distances[i] = -distances[i]; } } +} - hnsw_stats.combine({n1, n2, n3, ndis, nreorder}); +void IndexHNSW::range_search( + idx_t n, + const float* x, + float radius, + RangeSearchResult* result, + const SearchParameters* params) const { + using RH = RangeSearchBlockResultHandler; + RH bres(result, radius); + + hnsw_search(this, n, x, bres, params); + + if (is_similarity_metric(this->metric_type)) { + // we need to revert the negated distances + for (size_t i = 0; i < result->lims[result->nq]; i++) { + result->distances[i] = -result->distances[i]; + } + } } void IndexHNSW::add(idx_t n, const float* x) { @@ -437,35 +459,33 @@ void IndexHNSW::search_level_0( storage_idx_t ntotal = hnsw.levels.size(); + using RH = HeapBlockResultHandler; + RH bres(n, distances, labels, k); + #pragma omp parallel { std::unique_ptr qdis( storage_distance_computer(storage)); HNSWStats search_stats; VisitedTable vt(ntotal); + RH::SingleResultHandler res(bres); #pragma omp for for (idx_t i = 0; i < n; i++) { - idx_t* idxi = labels + i * k; - float* simi = distances + i * k; - + res.begin(i); qdis->set_query(x + i * d); - maxheap_heapify(k, simi, idxi); hnsw.search_level_0( *qdis.get(), - k, - idxi, - simi, + res, nprobe, nearest + i * nprobe, nearest_d + i * nprobe, search_type, search_stats, vt); - + res.end(); vt.advance(); - maxheap_reorder(k, simi, idxi); } #pragma omp critical { hnsw_stats.combine(search_stats); } @@ -630,246 +650,6 @@ void IndexHNSW::permute_entries(const idx_t* perm) { hnsw.permute_entries(perm); } -/************************************************************** - * ReconstructFromNeighbors implementation - **************************************************************/ - -ReconstructFromNeighbors::ReconstructFromNeighbors( - const IndexHNSW& index, - size_t k, - size_t nsq) - : index(index), k(k), 
nsq(nsq) { - M = index.hnsw.nb_neighbors(0); - FAISS_ASSERT(k <= 256); - code_size = k == 1 ? 0 : nsq; - ntotal = 0; - d = index.d; - FAISS_ASSERT(d % nsq == 0); - dsub = d / nsq; - k_reorder = -1; -} - -void ReconstructFromNeighbors::reconstruct( - storage_idx_t i, - float* x, - float* tmp) const { - const HNSW& hnsw = index.hnsw; - size_t begin, end; - hnsw.neighbor_range(i, 0, &begin, &end); - - if (k == 1 || nsq == 1) { - const float* beta; - if (k == 1) { - beta = codebook.data(); - } else { - int idx = codes[i]; - beta = codebook.data() + idx * (M + 1); - } - - float w0 = beta[0]; // weight of image itself - index.storage->reconstruct(i, tmp); - - for (int l = 0; l < d; l++) - x[l] = w0 * tmp[l]; - - for (size_t j = begin; j < end; j++) { - storage_idx_t ji = hnsw.neighbors[j]; - if (ji < 0) - ji = i; - float w = beta[j - begin + 1]; - index.storage->reconstruct(ji, tmp); - for (int l = 0; l < d; l++) - x[l] += w * tmp[l]; - } - } else if (nsq == 2) { - int idx0 = codes[2 * i]; - int idx1 = codes[2 * i + 1]; - - const float* beta0 = codebook.data() + idx0 * (M + 1); - const float* beta1 = codebook.data() + (idx1 + k) * (M + 1); - - index.storage->reconstruct(i, tmp); - - float w0; - - w0 = beta0[0]; - for (int l = 0; l < dsub; l++) - x[l] = w0 * tmp[l]; - - w0 = beta1[0]; - for (int l = dsub; l < d; l++) - x[l] = w0 * tmp[l]; - - for (size_t j = begin; j < end; j++) { - storage_idx_t ji = hnsw.neighbors[j]; - if (ji < 0) - ji = i; - index.storage->reconstruct(ji, tmp); - float w; - w = beta0[j - begin + 1]; - for (int l = 0; l < dsub; l++) - x[l] += w * tmp[l]; - - w = beta1[j - begin + 1]; - for (int l = dsub; l < d; l++) - x[l] += w * tmp[l]; - } - } else { - std::vector betas(nsq); - { - const float* b = codebook.data(); - const uint8_t* c = &codes[i * code_size]; - for (int sq = 0; sq < nsq; sq++) { - betas[sq] = b + (*c++) * (M + 1); - b += (M + 1) * k; - } - } - - index.storage->reconstruct(i, tmp); - { - int d0 = 0; - for (int sq = 0; sq < nsq; sq++) { 
- float w = *(betas[sq]++); - int d1 = d0 + dsub; - for (int l = d0; l < d1; l++) { - x[l] = w * tmp[l]; - } - d0 = d1; - } - } - - for (size_t j = begin; j < end; j++) { - storage_idx_t ji = hnsw.neighbors[j]; - if (ji < 0) - ji = i; - - index.storage->reconstruct(ji, tmp); - int d0 = 0; - for (int sq = 0; sq < nsq; sq++) { - float w = *(betas[sq]++); - int d1 = d0 + dsub; - for (int l = d0; l < d1; l++) { - x[l] += w * tmp[l]; - } - d0 = d1; - } - } - } -} - -void ReconstructFromNeighbors::reconstruct_n( - storage_idx_t n0, - storage_idx_t ni, - float* x) const { -#pragma omp parallel - { - std::vector tmp(index.d); -#pragma omp for - for (storage_idx_t i = 0; i < ni; i++) { - reconstruct(n0 + i, x + i * index.d, tmp.data()); - } - } -} - -size_t ReconstructFromNeighbors::compute_distances( - size_t n, - const idx_t* shortlist, - const float* query, - float* distances) const { - std::vector tmp(2 * index.d); - size_t ncomp = 0; - for (int i = 0; i < n; i++) { - if (shortlist[i] < 0) - break; - reconstruct(shortlist[i], tmp.data(), tmp.data() + index.d); - distances[i] = fvec_L2sqr(query, tmp.data(), index.d); - ncomp++; - } - return ncomp; -} - -void ReconstructFromNeighbors::get_neighbor_table(storage_idx_t i, float* tmp1) - const { - const HNSW& hnsw = index.hnsw; - size_t begin, end; - hnsw.neighbor_range(i, 0, &begin, &end); - size_t d = index.d; - - index.storage->reconstruct(i, tmp1); - - for (size_t j = begin; j < end; j++) { - storage_idx_t ji = hnsw.neighbors[j]; - if (ji < 0) - ji = i; - index.storage->reconstruct(ji, tmp1 + (j - begin + 1) * d); - } -} - -/// called by add_codes -void ReconstructFromNeighbors::estimate_code( - const float* x, - storage_idx_t i, - uint8_t* code) const { - // fill in tmp table with the neighbor values - std::unique_ptr tmp1(new float[d * (M + 1) + (d * k)]); - float* tmp2 = tmp1.get() + d * (M + 1); - - // collect coordinates of base - get_neighbor_table(i, tmp1.get()); - - for (size_t sq = 0; sq < nsq; sq++) { - int d0 
= sq * dsub; - - { - FINTEGER ki = k, di = d, m1 = M + 1; - FINTEGER dsubi = dsub; - float zero = 0, one = 1; - - sgemm_("N", - "N", - &dsubi, - &ki, - &m1, - &one, - tmp1.get() + d0, - &di, - codebook.data() + sq * (m1 * k), - &m1, - &zero, - tmp2, - &dsubi); - } - - float min = HUGE_VAL; - int argmin = -1; - for (size_t j = 0; j < k; j++) { - float dis = fvec_L2sqr(x + d0, tmp2 + j * dsub, dsub); - if (dis < min) { - min = dis; - argmin = j; - } - } - code[sq] = argmin; - } -} - -void ReconstructFromNeighbors::add_codes(size_t n, const float* x) { - if (k == 1) { // nothing to encode - ntotal += n; - return; - } - codes.resize(codes.size() + code_size * n); -#pragma omp parallel for - for (int i = 0; i < n; i++) { - estimate_code( - x + i * index.d, - ntotal + i, - codes.data() + (ntotal + i) * code_size); - } - ntotal += n; - FAISS_ASSERT(codes.size() == ntotal * code_size); -} - /************************************************************** * IndexHNSWFlat implementation **************************************************************/ diff --git a/faiss/IndexHNSW.h b/faiss/IndexHNSW.h index 13855d3037..e0b65fca9d 100644 --- a/faiss/IndexHNSW.h +++ b/faiss/IndexHNSW.h @@ -21,49 +21,6 @@ namespace faiss { struct IndexHNSW; -struct ReconstructFromNeighbors { - typedef HNSW::storage_idx_t storage_idx_t; - - const IndexHNSW& index; - size_t M; // number of neighbors - size_t k; // number of codebook entries - size_t nsq; // number of subvectors - size_t code_size; - int k_reorder; // nb to reorder. 
-1 = all - - std::vector codebook; // size nsq * k * (M + 1) - - std::vector codes; // size ntotal * code_size - size_t ntotal; - size_t d, dsub; // derived values - - explicit ReconstructFromNeighbors( - const IndexHNSW& index, - size_t k = 256, - size_t nsq = 1); - - /// codes must be added in the correct order and the IndexHNSW - /// must be populated and sorted - void add_codes(size_t n, const float* x); - - size_t compute_distances( - size_t n, - const idx_t* shortlist, - const float* query, - float* distances) const; - - /// called by add_codes - void estimate_code(const float* x, storage_idx_t i, uint8_t* code) const; - - /// called by compute_distances - void reconstruct(storage_idx_t i, float* x, float* tmp) const; - - void reconstruct_n(storage_idx_t n0, storage_idx_t ni, float* x) const; - - /// get the M+1 -by-d table for neighbor coordinates for vector i - void get_neighbor_table(storage_idx_t i, float* out) const; -}; - /** The HNSW index is a normal random-access index with a HNSW * link structure built on top */ @@ -77,8 +34,6 @@ struct IndexHNSW : Index { bool own_fields = false; Index* storage = nullptr; - ReconstructFromNeighbors* reconstruct_from_neighbors = nullptr; - explicit IndexHNSW(int d = 0, int M = 32, MetricType metric = METRIC_L2); explicit IndexHNSW(Index* storage, int M = 32); @@ -98,6 +53,13 @@ struct IndexHNSW : Index { idx_t* labels, const SearchParameters* params = nullptr) const override; + void range_search( + idx_t n, + const float* x, + float radius, + RangeSearchResult* result, + const SearchParameters* params = nullptr) const override; + void reconstruct(idx_t key, float* recons) const override; void reset() override; diff --git a/faiss/IndexIVFAdditiveQuantizerFastScan.cpp b/faiss/IndexIVFAdditiveQuantizerFastScan.cpp index 25c3aa2b06..23a2de554d 100644 --- a/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +++ b/faiss/IndexIVFAdditiveQuantizerFastScan.cpp @@ -211,7 +211,8 @@ void 
IndexIVFAdditiveQuantizerFastScan::estimate_norm_scale( size_t index_nprobe = nprobe; nprobe = 1; - compute_LUT(n, x, coarse_ids.data(), coarse_dis.data(), dis_tables, biases); + CoarseQuantized cq{index_nprobe, coarse_dis.data(), coarse_ids.data()}; + compute_LUT(n, x, cq, dis_tables, biases); nprobe = index_nprobe; float scale = 0; @@ -313,13 +314,8 @@ void IndexIVFAdditiveQuantizerFastScan::search( } NormTableScaler scaler(norm_scale); - if (metric_type == METRIC_L2) { - search_dispatch_implem( - n, x, k, distances, labels, nullptr, nullptr, scaler); - } else { - search_dispatch_implem( - n, x, k, distances, labels, nullptr, nullptr, scaler); - } + IndexIVFFastScan::CoarseQuantized cq{nprobe}; + search_dispatch_implem(n, x, k, distances, labels, cq, &scaler); } /********************************************************* @@ -385,12 +381,12 @@ bool IndexIVFAdditiveQuantizerFastScan::lookup_table_is_3d() const { void IndexIVFAdditiveQuantizerFastScan::compute_LUT( size_t n, const float* x, - const idx_t* coarse_ids, - const float*, + const CoarseQuantized& cq, AlignedTable& dis_tables, AlignedTable& biases) const { const size_t dim12 = ksub * M; const size_t ip_dim12 = aq->M * ksub; + const size_t nprobe = cq.nprobe; dis_tables.resize(n * dim12); @@ -411,7 +407,7 @@ void IndexIVFAdditiveQuantizerFastScan::compute_LUT( #pragma omp for for (idx_t ij = 0; ij < n * nprobe; ij++) { int i = ij / nprobe; - quantizer->reconstruct(coarse_ids[ij], c); + quantizer->reconstruct(cq.ids[ij], c); biases[ij] = coef * fvec_inner_product(c, x + i * d, d); } } diff --git a/faiss/IndexIVFAdditiveQuantizerFastScan.h b/faiss/IndexIVFAdditiveQuantizerFastScan.h index 24ce7287ec..643628dec1 100644 --- a/faiss/IndexIVFAdditiveQuantizerFastScan.h +++ b/faiss/IndexIVFAdditiveQuantizerFastScan.h @@ -93,8 +93,7 @@ struct IndexIVFAdditiveQuantizerFastScan : IndexIVFFastScan { void compute_LUT( size_t n, const float* x, - const idx_t* coarse_ids, - const float* coarse_dis, + const 
CoarseQuantized& cq, AlignedTable& dis_tables, AlignedTable& biases) const override; diff --git a/faiss/IndexIVFFastScan.cpp b/faiss/IndexIVFFastScan.cpp index 0b9c4e0992..45dae01f27 100644 --- a/faiss/IndexIVFFastScan.cpp +++ b/faiss/IndexIVFFastScan.cpp @@ -198,7 +198,7 @@ CodePacker* IndexIVFFastScan::get_CodePacker() const { namespace { -template +template void estimators_from_tables_generic( const IndexIVFFastScan& index, const uint8_t* codes, @@ -209,22 +209,26 @@ void estimators_from_tables_generic( size_t k, typename C::T* heap_dis, int64_t* heap_ids, - const Scaler& scaler) { + const NormTableScaler* scaler) { using accu_t = typename C::T; + int nscale = scaler ? scaler->nscale : 0; for (size_t j = 0; j < ncodes; ++j) { BitstringReader bsr(codes + j * index.code_size, index.code_size); accu_t dis = bias; const dis_t* __restrict dt = dis_table; - for (size_t m = 0; m < index.M - scaler.nscale; m++) { + + for (size_t m = 0; m < index.M - nscale; m++) { uint64_t c = bsr.read(index.nbits); dis += dt[c]; dt += index.ksub; } - for (size_t m = 0; m < scaler.nscale; m++) { - uint64_t c = bsr.read(index.nbits); - dis += scaler.scale_one(dt[c]); - dt += index.ksub; + if (scaler) { + for (size_t m = 0; m < nscale; m++) { + uint64_t c = bsr.read(index.nbits); + dis += scaler->scale_one(dt[c]); + dt += index.ksub; + } } if (C::cmp(heap_dis[0], dis)) { @@ -245,18 +249,15 @@ using namespace quantize_lut; void IndexIVFFastScan::compute_LUT_uint8( size_t n, const float* x, - const idx_t* coarse_ids, - const float* coarse_dis, + const CoarseQuantized& cq, AlignedTable& dis_tables, AlignedTable& biases, float* normalizers) const { AlignedTable dis_tables_float; AlignedTable biases_float; - uint64_t t0 = get_cy(); - compute_LUT(n, x, coarse_ids, coarse_dis, dis_tables_float, biases_float); - IVFFastScan_stats.t_compute_distance_tables += get_cy() - t0; - + compute_LUT(n, x, cq, dis_tables_float, biases_float); + size_t nprobe = cq.nprobe; bool lut_is_3d = 
lookup_table_is_3d(); size_t dim123 = ksub * M; size_t dim123_2 = ksub * M2; @@ -268,7 +269,6 @@ void IndexIVFFastScan::compute_LUT_uint8( if (biases_float.get()) { biases.resize(n * nprobe); } - uint64_t t1 = get_cy(); #pragma omp parallel for if (n > 100) for (int64_t i = 0; i < n; i++) { @@ -294,7 +294,6 @@ void IndexIVFFastScan::compute_LUT_uint8( normalizers + 2 * i, normalizers + 2 * i + 1); } - IVFFastScan_stats.t_round += get_cy() - t1; } /********************************************************* @@ -308,18 +307,10 @@ void IndexIVFFastScan::search( float* distances, idx_t* labels, const SearchParameters* params) const { - FAISS_THROW_IF_NOT_MSG( - !params, "search params not supported for this index"); - FAISS_THROW_IF_NOT(k > 0); - - DummyScaler scaler; - if (metric_type == METRIC_L2) { - search_dispatch_implem( - n, x, k, distances, labels, nullptr, nullptr, scaler); - } else { - search_dispatch_implem( - n, x, k, distances, labels, nullptr, nullptr, scaler); - } + auto paramsi = dynamic_cast(params); + FAISS_THROW_IF_NOT_MSG(!params || paramsi, "need IVFSearchParameters"); + search_preassigned( + n, x, k, nullptr, nullptr, distances, labels, false, paramsi); } void IndexIVFFastScan::search_preassigned( @@ -333,51 +324,144 @@ void IndexIVFFastScan::search_preassigned( bool store_pairs, const IVFSearchParameters* params, IndexIVFStats* stats) const { - FAISS_THROW_IF_NOT_MSG( - !params, "search params not supported for this index"); + size_t nprobe = this->nprobe; + if (params) { + FAISS_THROW_IF_NOT_MSG( + !params->quantizer_params, "quantizer params not supported"); + FAISS_THROW_IF_NOT(params->max_codes == 0); + nprobe = params->nprobe; + } FAISS_THROW_IF_NOT_MSG( !store_pairs, "store_pairs not supported for this index"); FAISS_THROW_IF_NOT_MSG(!stats, "stats not supported for this index"); FAISS_THROW_IF_NOT(k > 0); - DummyScaler scaler; - if (metric_type == METRIC_L2) { - search_dispatch_implem( - n, x, k, distances, labels, assign, centroid_dis, 
scaler); + const CoarseQuantized cq = {nprobe, centroid_dis, assign}; + search_dispatch_implem(n, x, k, distances, labels, cq, nullptr); +} + +void IndexIVFFastScan::range_search( + idx_t n, + const float* x, + float radius, + RangeSearchResult* result, + const SearchParameters* params) const { + FAISS_THROW_IF_NOT(!params); + const CoarseQuantized cq = {nprobe, nullptr, nullptr}; + range_search_dispatch_implem(n, x, radius, *result, cq, nullptr); +} + +namespace { + +template +ResultHandlerCompare* make_knn_handler_fixC( + int impl, + idx_t n, + idx_t k, + float* distances, + idx_t* labels) { + using HeapHC = HeapHandler; + using ReservoirHC = ReservoirHandler; + using SingleResultHC = SingleResultHandler; + + if (k == 1) { + return new SingleResultHC(n, 0, distances, labels); + } else if (impl % 2 == 0) { + return new HeapHC(n, 0, k, distances, labels); + } else /* if (impl % 2 == 1) */ { + return new ReservoirHC(n, 0, k, 2 * k, distances, labels); + } +} + +SIMDResultHandlerToFloat* make_knn_handler( + bool is_max, + int impl, + idx_t n, + idx_t k, + float* distances, + idx_t* labels) { + if (is_max) { + return make_knn_handler_fixC>( + impl, n, k, distances, labels); } else { - search_dispatch_implem( - n, x, k, distances, labels, assign, centroid_dis, scaler); + return make_knn_handler_fixC>( + impl, n, k, distances, labels); } } -void IndexIVFFastScan::range_search( - idx_t, - const float*, - float, - RangeSearchResult*, - const SearchParameters*) const { - FAISS_THROW_MSG("not implemented"); +using CoarseQuantized = IndexIVFFastScan::CoarseQuantized; + +struct CoarseQuantizedWithBuffer : CoarseQuantized { + explicit CoarseQuantizedWithBuffer(const CoarseQuantized& cq) + : CoarseQuantized(cq) {} + + bool done() const { + return ids != nullptr; + } + + std::vector ids_buffer; + std::vector dis_buffer; + + void quantize(const Index* quantizer, idx_t n, const float* x) { + dis_buffer.resize(nprobe * n); + ids_buffer.resize(nprobe * n); + quantizer->search(n, x, 
nprobe, dis_buffer.data(), ids_buffer.data()); + dis = dis_buffer.data(); + ids = ids_buffer.data(); + } +}; + +struct CoarseQuantizedSlice : CoarseQuantizedWithBuffer { + size_t i0, i1; + CoarseQuantizedSlice(const CoarseQuantized& cq, size_t i0, size_t i1) + : CoarseQuantizedWithBuffer(cq), i0(i0), i1(i1) { + if (done()) { + dis += nprobe * i0; + ids += nprobe * i0; + } + } + + void quantize_slice(const Index* quantizer, const float* x) { + quantize(quantizer, i1 - i0, x + quantizer->d * i0); + } +}; + +int compute_search_nslice( + const IndexIVFFastScan* index, + size_t n, + size_t nprobe) { + int nslice; + if (n <= omp_get_max_threads()) { + nslice = n; + } else if (index->lookup_table_is_3d()) { + // make sure we don't make too big LUT tables + size_t lut_size_per_query = index->M * index->ksub * nprobe * + (sizeof(float) + sizeof(uint8_t)); + + size_t max_lut_size = precomputed_table_max_bytes; + // how many queries we can handle within mem budget + size_t nq_ok = std::max(max_lut_size / lut_size_per_query, size_t(1)); + nslice = roundup( + std::max(size_t(n / nq_ok), size_t(1)), omp_get_max_threads()); + } else { + // LUTs unlikely to be a limiting factor + nslice = omp_get_max_threads(); + } + return nslice; } -template +}; // namespace + void IndexIVFFastScan::search_dispatch_implem( idx_t n, const float* x, idx_t k, float* distances, idx_t* labels, - const idx_t* coarse_ids, - const float* coarse_dis, - const Scaler& scaler) const { - using Cfloat = typename std::conditional< - is_max, - CMax, - CMin>::type; - - using C = typename std::conditional< - is_max, - CMax, - CMin>::type; + const CoarseQuantized& cq_in, + const NormTableScaler* scaler) const { + bool is_max = !is_similarity_metric(metric_type); + using RH = SIMDResultHandlerToFloat; if (n == 0) { return; @@ -392,94 +476,74 @@ void IndexIVFFastScan::search_dispatch_implem( } else { impl = 10; } - if (k > 20) { + if (k > 20) { // use reservoir rather than heap impl++; } } + bool multiple_threads = 
+ n > 1 && impl >= 10 && impl <= 13 && omp_get_max_threads() > 1; + if (impl >= 100) { + multiple_threads = false; + impl -= 100; + } + + CoarseQuantizedWithBuffer cq(cq_in); + + if (!cq.done() && !multiple_threads) { + // we do the coarse quantization here execpt when search is + // sliced over threads (then it is more efficient to have each thread do + // its own coarse quantization) + cq.quantize(quantizer, n, x); + } + if (impl == 1) { - search_implem_1( - n, x, k, distances, labels, coarse_ids, coarse_dis, scaler); + if (is_max) { + search_implem_1>( + n, x, k, distances, labels, cq, scaler); + } else { + search_implem_1>( + n, x, k, distances, labels, cq, scaler); + } } else if (impl == 2) { - search_implem_2( - n, x, k, distances, labels, coarse_ids, coarse_dis, scaler); + if (is_max) { + search_implem_2>( + n, x, k, distances, labels, cq, scaler); + } else { + search_implem_2>( + n, x, k, distances, labels, cq, scaler); + } } else if (impl >= 10 && impl <= 15) { size_t ndis = 0, nlist_visited = 0; - if (n < 2) { + if (!multiple_threads) { + // clang-format off if (impl == 12 || impl == 13) { - search_implem_12( - n, - x, - k, - distances, - labels, - coarse_ids, - coarse_dis, - impl, - &ndis, - &nlist_visited, - scaler); + std::unique_ptr handler(make_knn_handler(is_max, impl, n, k, distances, labels)); + search_implem_12( + n, x, *handler.get(), + cq, &ndis, &nlist_visited, scaler); + } else if (impl == 14 || impl == 15) { - search_implem_14( - n, - x, - k, - distances, - labels, - coarse_ids, - coarse_dis, - impl, - scaler); + + search_implem_14( + n, x, k, distances, labels, + cq, impl, scaler); } else { - search_implem_10( - n, - x, - k, - distances, - labels, - coarse_ids, - coarse_dis, - impl, - &ndis, - &nlist_visited, - scaler); + std::unique_ptr handler(make_knn_handler(is_max, impl, n, k, distances, labels)); + search_implem_10( + n, x, *handler.get(), cq, + &ndis, &nlist_visited, scaler); } + // clang-format on } else { // explicitly slice over 
threads - int nslice; - if (n <= omp_get_max_threads()) { - nslice = n; - } else if (lookup_table_is_3d()) { - // make sure we don't make too big LUT tables - size_t lut_size_per_query = - M * ksub * nprobe * (sizeof(float) + sizeof(uint8_t)); - - size_t max_lut_size = precomputed_table_max_bytes; - // how many queries we can handle within mem budget - size_t nq_ok = - std::max(max_lut_size / lut_size_per_query, size_t(1)); - nslice = - roundup(std::max(size_t(n / nq_ok), size_t(1)), - omp_get_max_threads()); - } else { - // LUTs unlikely to be a limiting factor - nslice = omp_get_max_threads(); - } - if (impl == 14 || - impl == 15) { // this might require slicing if there are too - // many queries (for now we keep this simple) - search_implem_14( - n, - x, - k, - distances, - labels, - coarse_ids, - coarse_dis, - impl, - scaler); + int nslice = compute_search_nslice(this, n, cq.nprobe); + if (impl == 14 || impl == 15) { + // this might require slicing if there are too + // many queries (for now we keep this simple) + search_implem_14(n, x, k, distances, labels, cq, impl, scaler); } else { #pragma omp parallel for reduction(+ : ndis, nlist_visited) for (int slice = 0; slice < nslice; slice++) { @@ -487,39 +551,23 @@ void IndexIVFFastScan::search_dispatch_implem( idx_t i1 = n * (slice + 1) / nslice; float* dis_i = distances + i0 * k; idx_t* lab_i = labels + i0 * k; - const idx_t* coarse_ids_i = coarse_ids != nullptr - ? coarse_ids + i0 * nprobe - : nullptr; - const float* coarse_dis_i = coarse_dis != nullptr - ? 
coarse_dis + i0 * nprobe - : nullptr; + CoarseQuantizedSlice cq_i(cq, i0, i1); + if (!cq_i.done()) { + cq_i.quantize_slice(quantizer, x); + } + std::unique_ptr handler(make_knn_handler( + is_max, impl, i1 - i0, k, dis_i, lab_i)); + // clang-format off if (impl == 12 || impl == 13) { - search_implem_12( - i1 - i0, - x + i0 * d, - k, - dis_i, - lab_i, - coarse_ids_i, - coarse_dis_i, - impl, - &ndis, - &nlist_visited, - scaler); + search_implem_12( + i1 - i0, x + i0 * d, *handler.get(), + cq_i, &ndis, &nlist_visited, scaler); } else { - search_implem_10( - i1 - i0, - x + i0 * d, - k, - dis_i, - lab_i, - coarse_ids_i, - coarse_dis_i, - impl, - &ndis, - &nlist_visited, - scaler); + search_implem_10( + i1 - i0, x + i0 * d, *handler.get(), + cq_i, &ndis, &nlist_visited, scaler); } + // clang-format on } } } @@ -531,46 +579,139 @@ void IndexIVFFastScan::search_dispatch_implem( } } -#define COARSE_QUANTIZE \ - std::unique_ptr coarse_ids_buffer; \ - std::unique_ptr coarse_dis_buffer; \ - if (coarse_ids == nullptr || coarse_dis == nullptr) { \ - coarse_ids_buffer.reset(new idx_t[n * nprobe]); \ - coarse_dis_buffer.reset(new float[n * nprobe]); \ - quantizer->search( \ - n, \ - x, \ - nprobe, \ - coarse_dis_buffer.get(), \ - coarse_ids_buffer.get()); \ - coarse_ids = coarse_ids_buffer.get(); \ - coarse_dis = coarse_dis_buffer.get(); \ +void IndexIVFFastScan::range_search_dispatch_implem( + idx_t n, + const float* x, + float radius, + RangeSearchResult& rres, + const CoarseQuantized& cq_in, + const NormTableScaler* scaler) const { + bool is_max = !is_similarity_metric(metric_type); + + if (n == 0) { + return; + } + + // actual implementation used + int impl = implem; + + if (impl == 0) { + if (bbs == 32) { + impl = 12; + } else { + impl = 10; + } + } + + CoarseQuantizedWithBuffer cq(cq_in); + + bool multiple_threads = + n > 1 && impl >= 10 && impl <= 13 && omp_get_max_threads() > 1; + if (impl >= 100) { + multiple_threads = false; + impl -= 100; + } + + if (!multiple_threads && 
!cq.done()) { + cq.quantize(quantizer, n, x); } -template + size_t ndis = 0, nlist_visited = 0; + + if (!multiple_threads) { // single thread + std::unique_ptr handler; + if (is_max) { + handler.reset(new RangeHandler, true>( + rres, radius, 0)); + } else { + handler.reset(new RangeHandler, true>( + rres, radius, 0)); + } + if (impl == 12) { + search_implem_12( + n, x, *handler.get(), cq, &ndis, &nlist_visited, scaler); + } else if (impl == 10) { + search_implem_10( + n, x, *handler.get(), cq, &ndis, &nlist_visited, scaler); + } else { + FAISS_THROW_FMT("Range search implem %d not impemented", impl); + } + } else { + // explicitly slice over threads + int nslice = compute_search_nslice(this, n, cq.nprobe); +#pragma omp parallel + { + RangeSearchPartialResult pres(&rres); + +#pragma omp parallel for reduction(+ : ndis, nlist_visited) + for (int slice = 0; slice < nslice; slice++) { + idx_t i0 = n * slice / nslice; + idx_t i1 = n * (slice + 1) / nslice; + CoarseQuantizedSlice cq_i(cq, i0, i1); + if (!cq_i.done()) { + cq_i.quantize_slice(quantizer, x); + } + std::unique_ptr handler; + if (is_max) { + handler.reset(new PartialRangeHandler< + CMax, + true>(pres, radius, 0, i0, i1)); + } else { + handler.reset(new PartialRangeHandler< + CMin, + true>(pres, radius, 0, i0, i1)); + } + + if (impl == 12 || impl == 13) { + search_implem_12( + i1 - i0, + x + i0 * d, + *handler.get(), + cq_i, + &ndis, + &nlist_visited, + scaler); + } else { + search_implem_10( + i1 - i0, + x + i0 * d, + *handler.get(), + cq_i, + &ndis, + &nlist_visited, + scaler); + } + } + pres.finalize(); + } + } + + indexIVF_stats.nq += n; + indexIVF_stats.ndis += ndis; + indexIVF_stats.nlist += nlist_visited; +} + +template void IndexIVFFastScan::search_implem_1( idx_t n, const float* x, idx_t k, float* distances, idx_t* labels, - const idx_t* coarse_ids, - const float* coarse_dis, - const Scaler& scaler) const { + const CoarseQuantized& cq, + const NormTableScaler* scaler) const { 
FAISS_THROW_IF_NOT(orig_invlists); - COARSE_QUANTIZE; - size_t dim12 = ksub * M; AlignedTable dis_tables; AlignedTable biases; - compute_LUT(n, x, coarse_ids, coarse_dis, dis_tables, biases); + compute_LUT(n, x, cq, dis_tables, biases); bool single_LUT = !lookup_table_is_3d(); size_t ndis = 0, nlist_visited = 0; - + size_t nprobe = cq.nprobe; #pragma omp parallel for reduction(+ : ndis, nlist_visited) for (idx_t i = 0; i < n; i++) { int64_t* heap_ids = labels + i * k; @@ -585,7 +726,7 @@ void IndexIVFFastScan::search_implem_1( if (!single_LUT) { LUT = dis_tables.get() + (i * nprobe + j) * dim12; } - idx_t list_no = coarse_ids[i * nprobe + j]; + idx_t list_no = cq.ids[i * nprobe + j]; if (list_no < 0) continue; size_t ls = orig_invlists->list_size(list_no); @@ -617,36 +758,28 @@ void IndexIVFFastScan::search_implem_1( indexIVF_stats.nlist += nlist_visited; } -template +template void IndexIVFFastScan::search_implem_2( idx_t n, const float* x, idx_t k, float* distances, idx_t* labels, - const idx_t* coarse_ids, - const float* coarse_dis, - const Scaler& scaler) const { + const CoarseQuantized& cq, + const NormTableScaler* scaler) const { FAISS_THROW_IF_NOT(orig_invlists); - COARSE_QUANTIZE; size_t dim12 = ksub * M2; AlignedTable dis_tables; AlignedTable biases; std::unique_ptr normalizers(new float[2 * n]); - compute_LUT_uint8( - n, - x, - coarse_ids, - coarse_dis, - dis_tables, - biases, - normalizers.get()); + compute_LUT_uint8(n, x, cq, dis_tables, biases, normalizers.get()); bool single_LUT = !lookup_table_is_3d(); size_t ndis = 0, nlist_visited = 0; + size_t nprobe = cq.nprobe; #pragma omp parallel for reduction(+ : ndis, nlist_visited) for (idx_t i = 0; i < n; i++) { @@ -663,7 +796,7 @@ void IndexIVFFastScan::search_implem_2( if (!single_LUT) { LUT = dis_tables.get() + (i * nprobe + j) * dim12; } - idx_t list_no = coarse_ids[i * nprobe + j]; + idx_t list_no = cq.ids[i * nprobe + j]; if (list_no < 0) continue; size_t ls = orig_invlists->list_size(list_no); @@ 
-708,169 +841,100 @@ void IndexIVFFastScan::search_implem_2( indexIVF_stats.nlist += nlist_visited; } -template void IndexIVFFastScan::search_implem_10( idx_t n, const float* x, - idx_t k, - float* distances, - idx_t* labels, - const idx_t* coarse_ids, - const float* coarse_dis, - int impl, + SIMDResultHandlerToFloat& handler, + const CoarseQuantized& cq, size_t* ndis_out, size_t* nlist_out, - const Scaler& scaler) const { - memset(distances, -1, sizeof(float) * k * n); - memset(labels, -1, sizeof(idx_t) * k * n); - - using HeapHC = HeapHandler; - using ReservoirHC = ReservoirHandler; - using SingleResultHC = SingleResultHandler; - - uint64_t times[10]; - memset(times, 0, sizeof(times)); - int ti = 0; -#define TIC times[ti++] = get_cy() - TIC; - - COARSE_QUANTIZE; - - TIC; - + const NormTableScaler* scaler) const { size_t dim12 = ksub * M2; AlignedTable dis_tables; AlignedTable biases; std::unique_ptr normalizers(new float[2 * n]); - compute_LUT_uint8( - n, - x, - coarse_ids, - coarse_dis, - dis_tables, - biases, - normalizers.get()); - - TIC; + compute_LUT_uint8(n, x, cq, dis_tables, biases, normalizers.get()); bool single_LUT = !lookup_table_is_3d(); - TIC; size_t ndis = 0, nlist_visited = 0; + int qmap1[1]; - { - AlignedTable tmp_distances(k); - for (idx_t i = 0; i < n; i++) { - const uint8_t* LUT = nullptr; - int qmap1[1] = {0}; - std::unique_ptr> handler; - - if (k == 1) { - handler.reset(new SingleResultHC(1, 0)); - } else if (impl == 10) { - handler.reset(new HeapHC( - 1, tmp_distances.get(), labels + i * k, k, 0)); - } else if (impl == 11) { - handler.reset(new ReservoirHC(1, 0, k, 2 * k)); - } else { - FAISS_THROW_MSG("invalid"); - } + handler.q_map = qmap1; + handler.begin(skip & 16 ? 
nullptr : normalizers.get()); + size_t nprobe = cq.nprobe; - handler->q_map = qmap1; + for (idx_t i = 0; i < n; i++) { + const uint8_t* LUT = nullptr; + qmap1[0] = i; - if (single_LUT) { - LUT = dis_tables.get() + i * dim12; + if (single_LUT) { + LUT = dis_tables.get() + i * dim12; + } + for (idx_t j = 0; j < nprobe; j++) { + size_t ij = i * nprobe + j; + if (!single_LUT) { + LUT = dis_tables.get() + ij * dim12; + } + if (biases.get()) { + handler.dbias = biases.get() + ij; } - for (idx_t j = 0; j < nprobe; j++) { - size_t ij = i * nprobe + j; - if (!single_LUT) { - LUT = dis_tables.get() + ij * dim12; - } - if (biases.get()) { - handler->dbias = biases.get() + ij; - } - - idx_t list_no = coarse_ids[ij]; - if (list_no < 0) - continue; - size_t ls = invlists->list_size(list_no); - if (ls == 0) - continue; - InvertedLists::ScopedCodes codes(invlists, list_no); - InvertedLists::ScopedIds ids(invlists, list_no); + idx_t list_no = cq.ids[ij]; + if (list_no < 0) { + continue; + } + size_t ls = invlists->list_size(list_no); + if (ls == 0) { + continue; + } - handler->ntotal = ls; - handler->id_map = ids.get(); + InvertedLists::ScopedCodes codes(invlists, list_no); + InvertedLists::ScopedIds ids(invlists, list_no); -#define DISPATCH(classHC) \ - if (dynamic_cast(handler.get())) { \ - auto* res = static_cast(handler.get()); \ - pq4_accumulate_loop( \ - 1, roundup(ls, bbs), bbs, M2, codes.get(), LUT, *res, scaler); \ - } - DISPATCH(HeapHC) - else DISPATCH(ReservoirHC) else DISPATCH(SingleResultHC) -#undef DISPATCH + handler.ntotal = ls; + handler.id_map = ids.get(); - nlist_visited++; - ndis++; - } + pq4_accumulate_loop( + 1, + roundup(ls, bbs), + bbs, + M2, + codes.get(), + LUT, + handler, + scaler); - handler->to_flat_arrays( - distances + i * k, - labels + i * k, - skip & 16 ? 
nullptr : normalizers.get() + i * 2); + nlist_visited++; + ndis++; } } + handler.end(); *ndis_out = ndis; *nlist_out = nlist; } -template void IndexIVFFastScan::search_implem_12( idx_t n, const float* x, - idx_t k, - float* distances, - idx_t* labels, - const idx_t* coarse_ids, - const float* coarse_dis, - int impl, + SIMDResultHandlerToFloat& handler, + const CoarseQuantized& cq, size_t* ndis_out, size_t* nlist_out, - const Scaler& scaler) const { + const NormTableScaler* scaler) const { if (n == 0) { // does not work well with reservoir return; } FAISS_THROW_IF_NOT(bbs == 32); - uint64_t times[10]; - memset(times, 0, sizeof(times)); - int ti = 0; -#define TIC times[ti++] = get_cy() - TIC; - - COARSE_QUANTIZE; - - TIC; - size_t dim12 = ksub * M2; AlignedTable dis_tables; AlignedTable biases; std::unique_ptr normalizers(new float[2 * n]); - compute_LUT_uint8( - n, - x, - coarse_ids, - coarse_dis, - dis_tables, - biases, - normalizers.get()); - - TIC; + compute_LUT_uint8(n, x, cq, dis_tables, biases, normalizers.get()); + handler.begin(skip & 16 ? 
nullptr : normalizers.get()); struct QC { int qno; // sequence number of the query @@ -878,14 +942,15 @@ void IndexIVFFastScan::search_implem_12( int rank; // this is the rank'th result of the coarse quantizer }; bool single_LUT = !lookup_table_is_3d(); + size_t nprobe = cq.nprobe; std::vector qcs; { int ij = 0; for (int i = 0; i < n; i++) { for (int j = 0; j < nprobe; j++) { - if (coarse_ids[ij] >= 0) { - qcs.push_back(QC{i, int(coarse_ids[ij]), int(j)}); + if (cq.ids[ij] >= 0) { + qcs.push_back(QC{i, int(cq.ids[ij]), int(j)}); } ij++; } @@ -894,42 +959,21 @@ void IndexIVFFastScan::search_implem_12( return a.list_no < b.list_no; }); } - TIC; - // prepare the result handlers - std::unique_ptr> handler; - AlignedTable tmp_distances; - - using HeapHC = HeapHandler; - using ReservoirHC = ReservoirHandler; - using SingleResultHC = SingleResultHandler; - - if (k == 1) { - handler.reset(new SingleResultHC(n, 0)); - } else if (impl == 12) { - tmp_distances.resize(n * k); - handler.reset(new HeapHC(n, tmp_distances.get(), labels, k, 0)); - } else if (impl == 13) { - handler.reset(new ReservoirHC(n, 0, k, 2 * k)); - } - int qbs2 = this->qbs2 ? 
this->qbs2 : 11; std::vector tmp_bias; if (biases.get()) { tmp_bias.resize(qbs2); - handler->dbias = tmp_bias.data(); + handler.dbias = tmp_bias.data(); } - TIC; size_t ndis = 0; size_t i0 = 0; uint64_t t_copy_pack = 0, t_scan = 0; while (i0 < qcs.size()) { - uint64_t tt0 = get_cy(); - // find all queries that access this inverted list int list_no = qcs[i0].list_no; size_t i1 = i0 + 1; @@ -977,92 +1021,47 @@ void IndexIVFFastScan::search_implem_12( // prepare the handler - handler->ntotal = list_size; - handler->q_map = q_map.data(); - handler->id_map = ids.get(); - uint64_t tt1 = get_cy(); - -#define DISPATCH(classHC) \ - if (dynamic_cast(handler.get())) { \ - auto* res = static_cast(handler.get()); \ - pq4_accumulate_loop_qbs( \ - qbs, list_size, M2, codes.get(), LUT.get(), *res, scaler); \ - } - DISPATCH(HeapHC) - else DISPATCH(ReservoirHC) else DISPATCH(SingleResultHC) - - // prepare for next loop - i0 = i1; + handler.ntotal = list_size; + handler.q_map = q_map.data(); + handler.id_map = ids.get(); - uint64_t tt2 = get_cy(); - t_copy_pack += tt1 - tt0; - t_scan += tt2 - tt1; + pq4_accumulate_loop_qbs( + qbs, list_size, M2, codes.get(), LUT.get(), handler, scaler); + // prepare for next loop + i0 = i1; } - TIC; - - // labels is in-place for HeapHC - handler->to_flat_arrays( - distances, labels, skip & 16 ? 
nullptr : normalizers.get()); - TIC; + handler.end(); // these stats are not thread-safe - for (int i = 1; i < ti; i++) { - IVFFastScan_stats.times[i] += times[i] - times[i - 1]; - } IVFFastScan_stats.t_copy_pack += t_copy_pack; IVFFastScan_stats.t_scan += t_scan; - if (auto* rh = dynamic_cast(handler.get())) { - for (int i = 0; i < 4; i++) { - IVFFastScan_stats.reservoir_times[i] += rh->times[i]; - } - } - *ndis_out = ndis; *nlist_out = nlist; } -template void IndexIVFFastScan::search_implem_14( idx_t n, const float* x, idx_t k, float* distances, idx_t* labels, - const idx_t* coarse_ids, - const float* coarse_dis, + const CoarseQuantized& cq, int impl, - const Scaler& scaler) const { + const NormTableScaler* scaler) const { if (n == 0) { // does not work well with reservoir return; } FAISS_THROW_IF_NOT(bbs == 32); - uint64_t ttg0 = get_cy(); - - COARSE_QUANTIZE; - - uint64_t ttg1 = get_cy(); - uint64_t coarse_search_tt = ttg1 - ttg0; - size_t dim12 = ksub * M2; AlignedTable dis_tables; AlignedTable biases; std::unique_ptr normalizers(new float[2 * n]); - compute_LUT_uint8( - n, - x, - coarse_ids, - coarse_dis, - dis_tables, - biases, - normalizers.get()); - - uint64_t ttg2 = get_cy(); - uint64_t lut_compute_tt = ttg2 - ttg1; + compute_LUT_uint8(n, x, cq, dis_tables, biases, normalizers.get()); struct QC { int qno; // sequence number of the query @@ -1070,14 +1069,15 @@ void IndexIVFFastScan::search_implem_14( int rank; // this is the rank'th result of the coarse quantizer }; bool single_LUT = !lookup_table_is_3d(); + size_t nprobe = cq.nprobe; std::vector qcs; { int ij = 0; for (int i = 0; i < n; i++) { for (int j = 0; j < nprobe; j++) { - if (coarse_ids[ij] >= 0) { - qcs.push_back(QC{i, int(coarse_ids[ij]), int(j)}); + if (cq.ids[ij] >= 0) { + qcs.push_back(QC{i, int(cq.ids[ij]), int(j)}); } ij++; } @@ -1115,14 +1115,13 @@ void IndexIVFFastScan::search_implem_14( ses.push_back(SE{i0_l, i1, list_size}); i0_l = i1; } - uint64_t ttg3 = get_cy(); - uint64_t 
compute_clusters_tt = ttg3 - ttg2; // function to handle the global heap + bool is_max = !is_similarity_metric(metric_type); using HeapForIP = CMin; using HeapForL2 = CMax; auto init_result = [&](float* simi, idx_t* idxi) { - if (metric_type == METRIC_INNER_PRODUCT) { + if (!is_max) { heap_heapify(k, simi, idxi); } else { heap_heapify(k, simi, idxi); @@ -1133,7 +1132,7 @@ void IndexIVFFastScan::search_implem_14( const idx_t* local_idx, float* simi, idx_t* idxi) { - if (metric_type == METRIC_INNER_PRODUCT) { + if (!is_max) { heap_addn(k, simi, idxi, local_dis, local_idx, k); } else { heap_addn(k, simi, idxi, local_dis, local_idx, k); @@ -1141,14 +1140,12 @@ void IndexIVFFastScan::search_implem_14( }; auto reorder_result = [&](float* simi, idx_t* idxi) { - if (metric_type == METRIC_INNER_PRODUCT) { + if (!is_max) { heap_reorder(k, simi, idxi); } else { heap_reorder(k, simi, idxi); } }; - uint64_t ttg4 = get_cy(); - uint64_t fn_tt = ttg4 - ttg3; size_t ndis = 0; size_t nlist_visited = 0; @@ -1160,22 +1157,9 @@ void IndexIVFFastScan::search_implem_14( std::vector local_dis(k * n); // prepare the result handlers - std::unique_ptr> handler; - AlignedTable tmp_distances; - - using HeapHC = HeapHandler; - using ReservoirHC = ReservoirHandler; - using SingleResultHC = SingleResultHandler; - - if (k == 1) { - handler.reset(new SingleResultHC(n, 0)); - } else if (impl == 14) { - tmp_distances.resize(n * k); - handler.reset( - new HeapHC(n, tmp_distances.get(), local_idx.data(), k, 0)); - } else if (impl == 15) { - handler.reset(new ReservoirHC(n, 0, k, 2 * k)); - } + std::unique_ptr handler(make_knn_handler( + is_max, impl, n, k, local_dis.data(), local_idx.data())); + handler->begin(normalizers.get()); int qbs2 = this->qbs2 ? 
this->qbs2 : 11; @@ -1184,15 +1168,10 @@ void IndexIVFFastScan::search_implem_14( tmp_bias.resize(qbs2); handler->dbias = tmp_bias.data(); } - - uint64_t ttg5 = get_cy(); - uint64_t handler_tt = ttg5 - ttg4; - std::set q_set; uint64_t t_copy_pack = 0, t_scan = 0; #pragma omp for schedule(dynamic) for (idx_t cluster = 0; cluster < ses.size(); cluster++) { - uint64_t tt0 = get_cy(); size_t i0 = ses[cluster].start; size_t i1 = ses[cluster].end; size_t list_size = ses[cluster].list_size; @@ -1232,28 +1211,21 @@ void IndexIVFFastScan::search_implem_14( handler->ntotal = list_size; handler->q_map = q_map.data(); handler->id_map = ids.get(); - uint64_t tt1 = get_cy(); -#define DISPATCH(classHC) \ - if (dynamic_cast(handler.get())) { \ - auto* res = static_cast(handler.get()); \ - pq4_accumulate_loop_qbs( \ - qbs, list_size, M2, codes.get(), LUT.get(), *res, scaler); \ - } - DISPATCH(HeapHC) - else DISPATCH(ReservoirHC) else DISPATCH(SingleResultHC) - - uint64_t tt2 = get_cy(); - t_copy_pack += tt1 - tt0; - t_scan += tt2 - tt1; + pq4_accumulate_loop_qbs( + qbs, + list_size, + M2, + codes.get(), + LUT.get(), + *handler.get(), + scaler); } // labels is in-place for HeapHC - handler->to_flat_arrays( - local_dis.data(), - local_idx.data(), - skip & 16 ? 
nullptr : normalizers.get()); + handler->end(); + // merge per-thread results #pragma omp single { // we init the results as a heap @@ -1276,12 +1248,6 @@ void IndexIVFFastScan::search_implem_14( IVFFastScan_stats.t_copy_pack += t_copy_pack; IVFFastScan_stats.t_scan += t_scan; - - if (auto* rh = dynamic_cast(handler.get())) { - for (int i = 0; i < 4; i++) { - IVFFastScan_stats.reservoir_times[i] += rh->times[i]; - } - } } #pragma omp barrier #pragma omp single @@ -1351,24 +1317,4 @@ void IndexIVFFastScan::reconstruct_orig_invlists() { IVFFastScanStats IVFFastScan_stats; -template void IndexIVFFastScan::search_dispatch_implem( - idx_t n, - const float* x, - idx_t k, - float* distances, - idx_t* labels, - const idx_t* coarse_ids, - const float* coarse_dis, - const NormTableScaler& scaler) const; - -template void IndexIVFFastScan::search_dispatch_implem( - idx_t n, - const float* x, - idx_t k, - float* distances, - idx_t* labels, - const idx_t* coarse_ids, - const float* coarse_dis, - const NormTableScaler& scaler) const; - } // namespace faiss diff --git a/faiss/IndexIVFFastScan.h b/faiss/IndexIVFFastScan.h index 824e63ed28..159a3a7098 100644 --- a/faiss/IndexIVFFastScan.h +++ b/faiss/IndexIVFFastScan.h @@ -14,6 +14,9 @@ namespace faiss { +struct NormTableScaler; +struct SIMDResultHandlerToFloat; + /** Fast scan version of IVFPQ and IVFAQ. Works for 4-bit PQ/AQ for now. * * The codes in the inverted lists are not stored sequentially but @@ -28,6 +31,12 @@ namespace faiss { * 11: idem, collect results in reservoir * 12: optimizer int16 search, collect results in heap, uses qbs * 13: idem, collect results in reservoir + * 14: internally multithreaded implem over nq * nprobe + * 15: same with reservoir + * + * For range search, only 10 and 12 are supported. + * add 100 to the implem to force single-thread scanning (the coarse quantizer + * may still use multiple threads). 
*/ struct IndexIVFFastScan : IndexIVF { @@ -80,19 +89,24 @@ struct IndexIVFFastScan : IndexIVF { virtual bool lookup_table_is_3d() const = 0; + // compact way of conveying coarse quantization results + struct CoarseQuantized { + size_t nprobe; + const float* dis = nullptr; + const idx_t* ids = nullptr; + }; + virtual void compute_LUT( size_t n, const float* x, - const idx_t* coarse_ids, - const float* coarse_dis, + const CoarseQuantized& cq, AlignedTable& dis_tables, AlignedTable& biases) const = 0; void compute_LUT_uint8( size_t n, const float* x, - const idx_t* coarse_ids, - const float* coarse_dis, + const CoarseQuantized& cq, AlignedTable& dis_tables, AlignedTable& biases, float* normalizers) const; @@ -117,7 +131,6 @@ struct IndexIVFFastScan : IndexIVF { const IVFSearchParameters* params = nullptr, IndexIVFStats* stats = nullptr) const override; - /// will just fail void range_search( idx_t n, const float* x, @@ -127,81 +140,75 @@ struct IndexIVFFastScan : IndexIVF { // internal search funcs - template + // dispatch to implementations and parallelize void search_dispatch_implem( idx_t n, const float* x, idx_t k, float* distances, idx_t* labels, - const idx_t* coarse_ids, - const float* coarse_dis, - const Scaler& scaler) const; + const CoarseQuantized& cq, + const NormTableScaler* scaler) const; + + void range_search_dispatch_implem( + idx_t n, + const float* x, + float radius, + RangeSearchResult& rres, + const CoarseQuantized& cq_in, + const NormTableScaler* scaler) const; - template + // impl 1 and 2 are just for verification + template void search_implem_1( idx_t n, const float* x, idx_t k, float* distances, idx_t* labels, - const idx_t* coarse_ids, - const float* coarse_dis, - const Scaler& scaler) const; + const CoarseQuantized& cq, + const NormTableScaler* scaler) const; - template + template void search_implem_2( idx_t n, const float* x, idx_t k, float* distances, idx_t* labels, - const idx_t* coarse_ids, - const float* coarse_dis, - const Scaler& 
scaler) const; + const CoarseQuantized& cq, + const NormTableScaler* scaler) const; // implem 10 and 12 are not multithreaded internally, so // export search stats - template void search_implem_10( idx_t n, const float* x, - idx_t k, - float* distances, - idx_t* labels, - const idx_t* coarse_ids, - const float* coarse_dis, - int impl, + SIMDResultHandlerToFloat& handler, + const CoarseQuantized& cq, size_t* ndis_out, size_t* nlist_out, - const Scaler& scaler) const; + const NormTableScaler* scaler) const; - template void search_implem_12( idx_t n, const float* x, - idx_t k, - float* distances, - idx_t* labels, - const idx_t* coarse_ids, - const float* coarse_dis, - int impl, + SIMDResultHandlerToFloat& handler, + const CoarseQuantized& cq, size_t* ndis_out, size_t* nlist_out, - const Scaler& scaler) const; + const NormTableScaler* scaler) const; // implem 14 is multithreaded internally across nprobes and queries - template void search_implem_14( idx_t n, const float* x, idx_t k, float* distances, idx_t* labels, - const idx_t* coarse_ids, - const float* coarse_dis, + const CoarseQuantized& cq, int impl, - const Scaler& scaler) const; + const NormTableScaler* scaler) const; // reconstruct vectors from packed invlists void reconstruct_from_offset(int64_t list_no, int64_t offset, float* recons) diff --git a/faiss/IndexIVFPQ.cpp b/faiss/IndexIVFPQ.cpp index c433991c9b..6de78b9539 100644 --- a/faiss/IndexIVFPQ.cpp +++ b/faiss/IndexIVFPQ.cpp @@ -749,7 +749,7 @@ struct QueryTables { } }; -// This way of handling the sleector is not optimal since all distances +// This way of handling the selector is not optimal since all distances // are computed even if the id would filter it out. 
template struct KnnSearchResults { diff --git a/faiss/IndexIVFPQFastScan.cpp b/faiss/IndexIVFPQFastScan.cpp index b44b71ec67..d069db1354 100644 --- a/faiss/IndexIVFPQFastScan.cpp +++ b/faiss/IndexIVFPQFastScan.cpp @@ -171,7 +171,7 @@ void IndexIVFPQFastScan::encode_vectors( * Look-Up Table functions *********************************************************/ -void fvec_madd_avx( +void fvec_madd_simd( size_t n, const float* a, float bf, @@ -202,12 +202,12 @@ bool IndexIVFPQFastScan::lookup_table_is_3d() const { void IndexIVFPQFastScan::compute_LUT( size_t n, const float* x, - const idx_t* coarse_ids, - const float* coarse_dis, + const CoarseQuantized& cq, AlignedTable& dis_tables, AlignedTable& biases) const { size_t dim12 = pq.ksub * pq.M; size_t d = pq.d; + size_t nprobe = this->nprobe; if (by_residual) { if (metric_type == METRIC_L2) { @@ -215,7 +215,7 @@ void IndexIVFPQFastScan::compute_LUT( if (use_precomputed_table == 1) { biases.resize(n * nprobe); - memcpy(biases.get(), coarse_dis, sizeof(float) * n * nprobe); + memcpy(biases.get(), cq.dis, sizeof(float) * n * nprobe); AlignedTable ip_table(n * dim12); pq.compute_inner_prod_tables(n, x, ip_table.get()); @@ -224,10 +224,10 @@ void IndexIVFPQFastScan::compute_LUT( for (idx_t ij = 0; ij < n * nprobe; ij++) { idx_t i = ij / nprobe; float* tab = dis_tables.get() + ij * dim12; - idx_t cij = coarse_ids[ij]; + idx_t cij = cq.ids[ij]; if (cij >= 0) { - fvec_madd_avx( + fvec_madd_simd( dim12, precomputed_table.get() + cij * dim12, -2, @@ -249,7 +249,7 @@ void IndexIVFPQFastScan::compute_LUT( for (idx_t ij = 0; ij < n * nprobe; ij++) { idx_t i = ij / nprobe; float* xij = &xrel[ij * d]; - idx_t cij = coarse_ids[ij]; + idx_t cij = cq.ids[ij]; if (cij >= 0) { quantizer->compute_residual(x + i * d, xij, cij); @@ -269,7 +269,7 @@ void IndexIVFPQFastScan::compute_LUT( // compute_inner_prod_tables(pq, n, x, dis_tables.get()); biases.resize(n * nprobe); - memcpy(biases.get(), coarse_dis, sizeof(float) * n * nprobe); + 
memcpy(biases.get(), cq.dis, sizeof(float) * n * nprobe); } else { FAISS_THROW_FMT("metric %d not supported", metric_type); } diff --git a/faiss/IndexIVFPQFastScan.h b/faiss/IndexIVFPQFastScan.h index 9a79833591..00dd2f11dd 100644 --- a/faiss/IndexIVFPQFastScan.h +++ b/faiss/IndexIVFPQFastScan.h @@ -77,8 +77,7 @@ struct IndexIVFPQFastScan : IndexIVFFastScan { void compute_LUT( size_t n, const float* x, - const idx_t* coarse_ids, - const float* coarse_dis, + const CoarseQuantized& cq, AlignedTable& dis_tables, AlignedTable& biases) const override; diff --git a/faiss/impl/AuxIndexStructures.h b/faiss/impl/AuxIndexStructures.h index 344a708b78..f8b5cca842 100644 --- a/faiss/impl/AuxIndexStructures.h +++ b/faiss/impl/AuxIndexStructures.h @@ -41,7 +41,6 @@ struct RangeSearchResult { /// called when lims contains the nb of elements result entries /// for each query - virtual void do_allocation(); virtual ~RangeSearchResult(); diff --git a/faiss/impl/HNSW.cpp b/faiss/impl/HNSW.cpp index 9fc201ea39..8c4c5be87f 100644 --- a/faiss/impl/HNSW.cpp +++ b/faiss/impl/HNSW.cpp @@ -5,8 +5,6 @@ * LICENSE file in the root directory of this source tree. */ -// -*- c++ -*- - #include #include @@ -14,6 +12,7 @@ #include #include #include +#include #include #include @@ -513,17 +512,15 @@ void HNSW::add_with_locks( **************************************************************/ namespace { - using MinimaxHeap = HNSW::MinimaxHeap; using Node = HNSW::Node; +using C = HNSW::C; /** Do a BFS on the candidates list */ int search_from_candidates( const HNSW& hnsw, DistanceComputer& qdis, - int k, - idx_t* I, - float* D, + ResultHandler& res, MinimaxHeap& candidates, VisitedTable& vt, HNSWStats& stats, @@ -539,15 +536,16 @@ int search_from_candidates( int efSearch = params ? params->efSearch : hnsw.efSearch; const IDSelector* sel = params ? 
params->sel : nullptr; + C::T threshold = res.threshold; for (int i = 0; i < candidates.size(); i++) { idx_t v1 = candidates.ids[i]; float d = candidates.dis[i]; FAISS_ASSERT(v1 >= 0); if (!sel || sel->is_member(v1)) { - if (nres < k) { - faiss::maxheap_push(++nres, D, I, d, v1); - } else if (d < D[0]) { - faiss::maxheap_replace_top(nres, D, I, d, v1); + if (d < threshold) { + if (res.add_result(d, v1)) { + threshold = res.threshold; + } } } vt.set(v1); @@ -609,13 +607,14 @@ int search_from_candidates( size_t saved_j[4]; ndis += jmax - begin; + threshold = res.threshold; auto add_to_heap = [&](const size_t idx, const float dis) { if (!sel || sel->is_member(idx)) { - if (nres < k) { - faiss::maxheap_push(++nres, D, I, dis, idx); - } else if (dis < D[0]) { - faiss::maxheap_replace_top(nres, D, I, dis, idx); + if (dis < threshold) { + if (res.add_result(dis, idx)) { + threshold = res.threshold; + } } } candidates.push(idx, dis); @@ -799,19 +798,28 @@ std::priority_queue search_from_candidate_unbounded( return top_candidates; } +// just used as a lower bound for the minmaxheap, but it is set for heap search +int extract_k_from_ResultHandler(ResultHandler& res) { + using RH = HeapBlockResultHandler; + if (auto hres = dynamic_cast(&res)) { + return hres->k; + } + return 1; +} + } // anonymous namespace HNSWStats HNSW::search( DistanceComputer& qdis, - int k, - idx_t* I, - float* D, + ResultHandler& res, VisitedTable& vt, const SearchParametersHNSW* params) const { HNSWStats stats; if (entry_point == -1) { return stats; } + int k = extract_k_from_ResultHandler(res); + if (upper_beam == 1) { // greedy search on upper levels storage_idx_t nearest = entry_point; @@ -828,7 +836,7 @@ HNSWStats HNSW::search( candidates.push(nearest, d_nearest); search_from_candidates( - *this, qdis, k, I, D, candidates, vt, stats, 0, 0, params); + *this, qdis, res, candidates, vt, stats, 0, 0, params); } else { std::priority_queue top_candidates = search_from_candidate_unbounded( @@ -848,7 
+856,8 @@ HNSWStats HNSW::search( float d; storage_idx_t label; std::tie(d, label) = top_candidates.top(); - faiss::maxheap_push(++nres, D, I, d, label); + res.add_result(d, label); + nres++; top_candidates.pop(); } } @@ -862,6 +871,10 @@ HNSWStats HNSW::search( std::vector I_to_next(candidates_size); std::vector D_to_next(candidates_size); + HeapBlockResultHandler block_resh( + 1, D_to_next.data(), I_to_next.data(), candidates_size); + HeapBlockResultHandler::SingleResultHandler resh(block_resh); + int nres = 1; I_to_next[0] = entry_point; D_to_next[0] = qdis(entry_point); @@ -877,18 +890,12 @@ HNSWStats HNSW::search( if (level == 0) { nres = search_from_candidates( - *this, qdis, k, I, D, candidates, vt, stats, 0); + *this, qdis, res, candidates, vt, stats, 0); } else { + resh.begin(0); nres = search_from_candidates( - *this, - qdis, - candidates_size, - I_to_next.data(), - D_to_next.data(), - candidates, - vt, - stats, - level); + *this, qdis, resh, candidates, vt, stats, level); + resh.end(); } vt.advance(); } @@ -899,9 +906,7 @@ HNSWStats HNSW::search( void HNSW::search_level_0( DistanceComputer& qdis, - int k, - idx_t* idxi, - float* simi, + ResultHandler& res, idx_t nprobe, const storage_idx_t* nearest_i, const float* nearest_d, @@ -909,7 +914,7 @@ void HNSW::search_level_0( HNSWStats& search_stats, VisitedTable& vt) const { const HNSW& hnsw = *this; - + int k = extract_k_from_ResultHandler(res); if (search_type == 1) { int nres = 0; @@ -922,22 +927,13 @@ void HNSW::search_level_0( if (vt.get(cj)) continue; - int candidates_size = std::max(hnsw.efSearch, int(k)); + int candidates_size = std::max(hnsw.efSearch, k); MinimaxHeap candidates(candidates_size); candidates.push(cj, nearest_d[j]); nres = search_from_candidates( - hnsw, - qdis, - k, - idxi, - simi, - candidates, - vt, - search_stats, - 0, - nres); + hnsw, qdis, res, candidates, vt, search_stats, 0, nres); } } else if (search_type == 2) { int candidates_size = std::max(hnsw.efSearch, int(k)); @@ -953,7 
+949,7 @@ void HNSW::search_level_0( } search_from_candidates( - hnsw, qdis, k, idxi, simi, candidates, vt, search_stats, 0); + hnsw, qdis, res, candidates, vt, search_stats, 0); } } diff --git a/faiss/impl/HNSW.h b/faiss/impl/HNSW.h index c923e0a6ae..cb6b422c3d 100644 --- a/faiss/impl/HNSW.h +++ b/faiss/impl/HNSW.h @@ -5,8 +5,6 @@ * LICENSE file in the root directory of this source tree. */ -// -*- c++ -*- - #pragma once #include @@ -42,6 +40,8 @@ namespace faiss { struct VisitedTable; struct DistanceComputer; // from AuxIndexStructures struct HNSWStats; +template +struct ResultHandler; struct SearchParametersHNSW : SearchParameters { int efSearch = 16; @@ -54,6 +54,9 @@ struct HNSW { /// internal storage of vectors (32 bits: this is expensive) using storage_idx_t = int32_t; + // for now we do only these distances + using C = CMax; + typedef std::pair Node; /** Heap structure that allows fast @@ -195,18 +198,14 @@ struct HNSW { /// search interface for 1 point, single thread HNSWStats search( DistanceComputer& qdis, - int k, - idx_t* I, - float* D, + ResultHandler& res, VisitedTable& vt, const SearchParametersHNSW* params = nullptr) const; /// search only in level 0 from a given vertex void search_level_0( DistanceComputer& qdis, - int k, - idx_t* idxi, - float* simi, + ResultHandler& res, idx_t nprobe, const storage_idx_t* nearest_i, const float* nearest_d, diff --git a/faiss/impl/Quantizer.h b/faiss/impl/Quantizer.h index 34673211d7..9171448ef5 100644 --- a/faiss/impl/Quantizer.h +++ b/faiss/impl/Quantizer.h @@ -11,7 +11,7 @@ namespace faiss { -/** Product Quantizer. 
Implemented only for METRIC_L2 */ +/** General interface for quantizer objects */ struct Quantizer { size_t d; ///< size of the input vectors size_t code_size; ///< bytes per indexed vector diff --git a/faiss/impl/ResultHandler.h b/faiss/impl/ResultHandler.h index c6f731de91..a532cf2a95 100644 --- a/faiss/impl/ResultHandler.h +++ b/faiss/impl/ResultHandler.h @@ -17,23 +17,170 @@ namespace faiss { +/***************************************************************** + * The classes below are intended to be used as template arguments + * they handle results for batches of queries (size nq). + * They can be called in two ways: + * - by instanciating a SingleResultHandler that tracks results for a single + * query + * - with begin_multiple/add_results/end_multiple calls where a whole block of + * resutls is submitted + * All classes are templated on C which to define wheter the min or the max of + * results is to be kept. + *****************************************************************/ + +template +struct BlockResultHandler { + size_t nq; // number of queries for which we search + + explicit BlockResultHandler(size_t nq) : nq(nq) {} + + // currently handled query range + size_t i0 = 0, i1 = 0; + + // start collecting results for queries [i0, i1) + virtual void begin_multiple(size_t i0, size_t i1) { + this->i0 = i0; + this->i1 = i1; + } + + // add results for queries [i0, i1) and database [j0, j1) + virtual void add_results(size_t, size_t, const typename C::T*) {} + + // series of results for queries i0..i1 is done + virtual void end_multiple() {} + + virtual ~BlockResultHandler() {} +}; + +// handler for a single query +template +struct ResultHandler { + // if not better than threshold, then not necessary to call add_result + typename C::T threshold = 0; + + // return whether threshold was updated + virtual bool add_result(typename C::T dis, typename C::TI idx) = 0; + + virtual ~ResultHandler() {} +}; + 
+/***************************************************************** + * Single best result handler. + * Tracks the only best result, thus avoiding storing + * some temporary data in memory. + *****************************************************************/ + +template +struct Top1BlockResultHandler : BlockResultHandler { + using T = typename C::T; + using TI = typename C::TI; + using BlockResultHandler::i0; + using BlockResultHandler::i1; + + // contains exactly nq elements + T* dis_tab; + // contains exactly nq elements + TI* ids_tab; + + Top1BlockResultHandler(size_t nq, T* dis_tab, TI* ids_tab) + : BlockResultHandler(nq), dis_tab(dis_tab), ids_tab(ids_tab) {} + + struct SingleResultHandler : ResultHandler { + Top1BlockResultHandler& hr; + using ResultHandler::threshold; + + TI min_idx; + size_t current_idx = 0; + + explicit SingleResultHandler(Top1BlockResultHandler& hr) : hr(hr) {} + + /// begin results for query # i + void begin(const size_t current_idx) { + this->current_idx = current_idx; + threshold = C::neutral(); + min_idx = -1; + } + + /// add one result for query i + bool add_result(T dis, TI idx) final { + if (C::cmp(this->threshold, dis)) { + threshold = dis; + min_idx = idx; + return true; + } + return false; + } + + /// series of results for query i is done + void end() { + hr.dis_tab[current_idx] = threshold; + hr.ids_tab[current_idx] = min_idx; + } + }; + + /// begin + void begin_multiple(size_t i0, size_t i1) final { + this->i0 = i0; + this->i1 = i1; + + for (size_t i = i0; i < i1; i++) { + this->dis_tab[i] = C::neutral(); + } + } + + /// add results for query i0..i1 and j0..j1 + void add_results(size_t j0, size_t j1, const T* dis_tab) final { + for (int64_t i = i0; i < i1; i++) { + const T* dis_tab_i = dis_tab + (j1 - j0) * (i - i0) - j0; + + auto& min_distance = this->dis_tab[i]; + auto& min_index = this->ids_tab[i]; + + for (size_t j = j0; j < j1; j++) { + const T distance = dis_tab_i[j]; + + if (C::cmp(min_distance, distance)) { + 
min_distance = distance; + min_index = j; + } + } + } + } + + void add_result(const size_t i, const T dis, const TI idx) { + auto& min_distance = this->dis_tab[i]; + auto& min_index = this->ids_tab[i]; + + if (C::cmp(min_distance, dis)) { + min_distance = dis; + min_index = idx; + } + } +}; + /***************************************************************** * Heap based result handler *****************************************************************/ template -struct HeapResultHandler { +struct HeapBlockResultHandler : BlockResultHandler { using T = typename C::T; using TI = typename C::TI; + using BlockResultHandler::i0; + using BlockResultHandler::i1; - int nq; T* heap_dis_tab; TI* heap_ids_tab; int64_t k; // number of results to keep - HeapResultHandler(size_t nq, T* heap_dis_tab, TI* heap_ids_tab, size_t k) - : nq(nq), + HeapBlockResultHandler( + size_t nq, + T* heap_dis_tab, + TI* heap_ids_tab, + size_t k) + : BlockResultHandler(nq), heap_dis_tab(heap_dis_tab), heap_ids_tab(heap_ids_tab), k(k) {} @@ -43,30 +190,33 @@ struct HeapResultHandler { * called from 1 thread) */ - struct SingleResultHandler { - HeapResultHandler& hr; + struct SingleResultHandler : ResultHandler { + HeapBlockResultHandler& hr; + using ResultHandler::threshold; size_t k; T* heap_dis; TI* heap_ids; - T thresh; - SingleResultHandler(HeapResultHandler& hr) : hr(hr), k(hr.k) {} + explicit SingleResultHandler(HeapBlockResultHandler& hr) + : hr(hr), k(hr.k) {} /// begin results for query # i void begin(size_t i) { heap_dis = hr.heap_dis_tab + i * k; heap_ids = hr.heap_ids_tab + i * k; heap_heapify(k, heap_dis, heap_ids); - thresh = heap_dis[0]; + threshold = heap_dis[0]; } /// add one result for query i - void add_result(T dis, TI idx) { - if (C::cmp(heap_dis[0], dis)) { + bool add_result(T dis, TI idx) final { + if (C::cmp(threshold, dis)) { heap_replace_top(k, heap_dis, heap_ids, dis, idx); - thresh = heap_dis[0]; + threshold = heap_dis[0]; + return true; } + return false; } /// series of 
results for query i is done @@ -79,19 +229,17 @@ struct HeapResultHandler { * API for multiple results (called from 1 thread) */ - size_t i0, i1; - /// begin - void begin_multiple(size_t i0_2, size_t i1_2) { + void begin_multiple(size_t i0_2, size_t i1_2) final { this->i0 = i0_2; this->i1 = i1_2; - for (size_t i = i0_2; i < i1_2; i++) { + for (size_t i = i0; i < i1; i++) { heap_heapify(k, heap_dis_tab + i * k, heap_ids_tab + i * k); } } /// add results for query i0..i1 and j0..j1 - void add_results(size_t j0, size_t j1, const T* dis_tab) { + void add_results(size_t j0, size_t j1, const T* dis_tab) final { #pragma omp parallel for for (int64_t i = i0; i < i1; i++) { T* heap_dis = heap_dis_tab + i * k; @@ -109,7 +257,7 @@ struct HeapResultHandler { } /// series of results for queries i0..i1 is done - void end_multiple() { + void end_multiple() final { // maybe parallel for for (size_t i = i0; i < i1; i++) { heap_reorder(k, heap_dis_tab + i * k, heap_ids_tab + i * k); @@ -128,9 +276,10 @@ struct HeapResultHandler { /// Reservoir for a single query template -struct ReservoirTopN { +struct ReservoirTopN : ResultHandler { using T = typename C::T; using TI = typename C::TI; + using ResultHandler::threshold; T* vals; TI* ids; @@ -139,8 +288,6 @@ struct ReservoirTopN { size_t n; // number of requested elements size_t capacity; // size of storage - T threshold; // current threshold - ReservoirTopN() {} ReservoirTopN(size_t n, size_t capacity, T* vals, TI* ids) @@ -149,15 +296,22 @@ struct ReservoirTopN { threshold = C::neutral(); } - void add(T val, TI id) { + bool add_result(T val, TI id) final { + bool updated_threshold = false; if (C::cmp(threshold, val)) { if (i == capacity) { shrink_fuzzy(); + updated_threshold = true; } vals[i] = val; ids[i] = id; i++; } + return updated_threshold; + } + + void add(T val, TI id) { + add_result(val, id); } // reduce storage from capacity to anything @@ -169,6 +323,11 @@ struct ReservoirTopN { vals, ids, capacity, n, (capacity + n) / 2, 
&i); } + void shrink() { + threshold = partition(vals, ids, i, n); + i = n; + } + void to_result(T* heap_dis, TI* heap_ids) const { for (int j = 0; j < std::min(i, n); j++) { heap_push(j + 1, heap_dis, heap_ids, vals[j], ids[j]); @@ -187,23 +346,24 @@ struct ReservoirTopN { }; template -struct ReservoirResultHandler { +struct ReservoirBlockResultHandler : BlockResultHandler { using T = typename C::T; using TI = typename C::TI; + using BlockResultHandler::i0; + using BlockResultHandler::i1; - int nq; T* heap_dis_tab; TI* heap_ids_tab; int64_t k; // number of results to keep size_t capacity; // capacity of the reservoirs - ReservoirResultHandler( + ReservoirBlockResultHandler( size_t nq, T* heap_dis_tab, TI* heap_ids_tab, size_t k) - : nq(nq), + : BlockResultHandler(nq), heap_dis_tab(heap_dis_tab), heap_ids_tab(heap_ids_tab), k(k) { @@ -216,40 +376,34 @@ struct ReservoirResultHandler { * called from 1 thread) */ - struct SingleResultHandler { - ReservoirResultHandler& hr; + struct SingleResultHandler : ReservoirTopN { + ReservoirBlockResultHandler& hr; std::vector reservoir_dis; std::vector reservoir_ids; - ReservoirTopN res1; - SingleResultHandler(ReservoirResultHandler& hr) - : hr(hr), - reservoir_dis(hr.capacity), - reservoir_ids(hr.capacity) {} + explicit SingleResultHandler(ReservoirBlockResultHandler& hr) + : ReservoirTopN(hr.k, hr.capacity, nullptr, nullptr), + hr(hr) {} - size_t i; + size_t qno; /// begin results for query # i - void begin(size_t i_2) { - res1 = ReservoirTopN( - hr.k, - hr.capacity, - reservoir_dis.data(), - reservoir_ids.data()); - this->i = i_2; - } - - /// add one result for query i - void add_result(T dis, TI idx) { - res1.add(dis, idx); + void begin(size_t qno) { + reservoir_dis.resize(hr.capacity); + reservoir_ids.resize(hr.capacity); + this->vals = reservoir_dis.data(); + this->ids = reservoir_ids.data(); + this->i = 0; // size of reservoir + this->threshold = C::neutral(); + this->qno = qno; } - /// series of results for query i is 
done + /// series of results for query qno is done void end() { - T* heap_dis = hr.heap_dis_tab + i * hr.k; - TI* heap_ids = hr.heap_ids_tab + i * hr.k; - res1.to_result(heap_dis, heap_ids); + T* heap_dis = hr.heap_dis_tab + qno * hr.k; + TI* heap_ids = hr.heap_ids_tab + qno * hr.k; + this->to_result(heap_dis, heap_ids); } }; @@ -257,8 +411,6 @@ struct ReservoirResultHandler { * API for multiple results (called from 1 thread) */ - size_t i0, i1; - std::vector reservoir_dis; std::vector reservoir_ids; std::vector> reservoirs; @@ -267,8 +419,8 @@ struct ReservoirResultHandler { void begin_multiple(size_t i0_2, size_t i1_2) { this->i0 = i0_2; this->i1 = i1_2; - reservoir_dis.resize((i1_2 - i0_2) * capacity); - reservoir_ids.resize((i1_2 - i0_2) * capacity); + reservoir_dis.resize((i1 - i0) * capacity); + reservoir_ids.resize((i1 - i0) * capacity); reservoirs.clear(); for (size_t i = i0_2; i < i1_2; i++) { reservoirs.emplace_back( @@ -281,20 +433,19 @@ struct ReservoirResultHandler { /// add results for query i0..i1 and j0..j1 void add_results(size_t j0, size_t j1, const T* dis_tab) { - // maybe parallel for #pragma omp parallel for for (int64_t i = i0; i < i1; i++) { ReservoirTopN& reservoir = reservoirs[i - i0]; const T* dis_tab_i = dis_tab + (j1 - j0) * (i - i0) - j0; for (size_t j = j0; j < j1; j++) { T dis = dis_tab_i[j]; - reservoir.add(dis, j); + reservoir.add_result(dis, j); } } } /// series of results for queries i0..i1 is done - void end_multiple() { + void end_multiple() final { // maybe parallel for for (size_t i = i0; i < i1; i++) { reservoirs[i - i0].to_result( @@ -308,29 +459,33 @@ struct ReservoirResultHandler { *****************************************************************/ template -struct RangeSearchResultHandler { +struct RangeSearchBlockResultHandler : BlockResultHandler { using T = typename C::T; using TI = typename C::TI; + using BlockResultHandler::i0; + using BlockResultHandler::i1; RangeSearchResult* res; - float radius; + T radius; - 
RangeSearchResultHandler(RangeSearchResult* res, float radius) - : res(res), radius(radius) {} + RangeSearchBlockResultHandler(RangeSearchResult* res, float radius) + : BlockResultHandler(res->nq), res(res), radius(radius) {} /****************************************************** * API for 1 result at a time (each SingleResultHandler is * called from 1 thread) ******************************************************/ - struct SingleResultHandler { + struct SingleResultHandler : ResultHandler { // almost the same interface as RangeSearchResultHandler + using ResultHandler::threshold; RangeSearchPartialResult pres; - float radius; RangeQueryResult* qr = nullptr; - SingleResultHandler(RangeSearchResultHandler& rh) - : pres(rh.res), radius(rh.radius) {} + explicit SingleResultHandler(RangeSearchBlockResultHandler& rh) + : pres(rh.res) { + threshold = rh.radius; + } /// begin results for query # i void begin(size_t i) { @@ -338,10 +493,11 @@ struct RangeSearchResultHandler { } /// add one result for query i - void add_result(T dis, TI idx) { - if (C::cmp(radius, dis)) { + bool add_result(T dis, TI idx) final { + if (C::cmp(threshold, dis)) { qr->add(dis, idx); } + return false; } /// series of results for query i is done @@ -356,8 +512,6 @@ struct RangeSearchResultHandler { * API for multiple results (called from 1 thread) ******************************************************/ - size_t i0, i1; - std::vector partial_results; std::vector j0s; int pr = 0; @@ -404,109 +558,11 @@ struct RangeSearchResultHandler { } } - void end_multiple() {} - - ~RangeSearchResultHandler() { + ~RangeSearchBlockResultHandler() { if (partial_results.size() > 0) { RangeSearchPartialResult::merge(partial_results); } } }; -/***************************************************************** - * Single best result handler. - * Tracks the only best result, thus avoiding storing - * some temporary data in memory. 
- *****************************************************************/ - -template -struct SingleBestResultHandler { - using T = typename C::T; - using TI = typename C::TI; - - int nq; - // contains exactly nq elements - T* dis_tab; - // contains exactly nq elements - TI* ids_tab; - - SingleBestResultHandler(size_t nq, T* dis_tab, TI* ids_tab) - : nq(nq), dis_tab(dis_tab), ids_tab(ids_tab) {} - - struct SingleResultHandler { - SingleBestResultHandler& hr; - - T min_dis; - TI min_idx; - size_t current_idx = 0; - - SingleResultHandler(SingleBestResultHandler& hr) : hr(hr) {} - - /// begin results for query # i - void begin(const size_t current_idx_2) { - this->current_idx = current_idx_2; - min_dis = HUGE_VALF; - min_idx = -1; - } - - /// add one result for query i - void add_result(T dis, TI idx) { - if (C::cmp(min_dis, dis)) { - min_dis = dis; - min_idx = idx; - } - } - - /// series of results for query i is done - void end() { - hr.dis_tab[current_idx] = min_dis; - hr.ids_tab[current_idx] = min_idx; - } - }; - - size_t i0, i1; - - /// begin - void begin_multiple(size_t i0_2, size_t i1_2) { - this->i0 = i0_2; - this->i1 = i1_2; - - for (size_t i = i0_2; i < i1_2; i++) { - this->dis_tab[i] = HUGE_VALF; - } - } - - /// add results for query i0..i1 and j0..j1 - void add_results(size_t j0, size_t j1, const T* dis_tab_2) { - for (int64_t i = i0; i < i1; i++) { - const T* dis_tab_i = dis_tab_2 + (j1 - j0) * (i - i0) - j0; - - auto& min_distance = this->dis_tab[i]; - auto& min_index = this->ids_tab[i]; - - for (size_t j = j0; j < j1; j++) { - const T distance = dis_tab_i[j]; - - if (C::cmp(min_distance, distance)) { - min_distance = distance; - min_index = j; - } - } - } - } - - void add_result(const size_t i, const T dis, const TI idx) { - auto& min_distance = this->dis_tab[i]; - auto& min_index = this->ids_tab[i]; - - if (C::cmp(min_distance, dis)) { - min_distance = dis; - min_index = idx; - } - } - - /// series of results for queries i0..i1 is done - void end_multiple() 
{} -}; - } // namespace faiss diff --git a/faiss/impl/pq4_fast_scan.cpp b/faiss/impl/pq4_fast_scan.cpp index d2cca15de3..6173ecef47 100644 --- a/faiss/impl/pq4_fast_scan.cpp +++ b/faiss/impl/pq4_fast_scan.cpp @@ -54,6 +54,9 @@ void pq4_pack_codes( FAISS_THROW_IF_NOT(nb % bbs == 0); FAISS_THROW_IF_NOT(nsq % 2 == 0); + if (nb == 0) { + return; + } memset(blocks, 0, nb * nsq / 2); const uint8_t perm0[16] = { 0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15}; diff --git a/faiss/impl/pq4_fast_scan.h b/faiss/impl/pq4_fast_scan.h index 2e6931f8d3..9f95f76cc1 100644 --- a/faiss/impl/pq4_fast_scan.h +++ b/faiss/impl/pq4_fast_scan.h @@ -24,6 +24,9 @@ namespace faiss { +struct NormTableScaler; +struct SIMDResultHandler; + /** Pack codes for consumption by the SIMD kernels. * The unused bytes are set to 0. * @@ -117,7 +120,6 @@ void pq4_pack_LUT(int nq, int nsq, const uint8_t* src, uint8_t* dest); * @param LUT packed look-up table * @param scaler scaler to scale the encoded norm */ -template void pq4_accumulate_loop( int nq, size_t nb, @@ -125,8 +127,8 @@ void pq4_accumulate_loop( int nsq, const uint8_t* codes, const uint8_t* LUT, - ResultHandler& res, - const Scaler& scaler); + SIMDResultHandler& res, + const NormTableScaler* scaler); /* qbs versions, supported only for bbs=32. 
* @@ -178,14 +180,13 @@ int pq4_pack_LUT_qbs_q_map( * @param res call-back for the resutls * @param scaler scaler to scale the encoded norm */ -template void pq4_accumulate_loop_qbs( int qbs, size_t nb, int nsq, const uint8_t* codes, const uint8_t* LUT, - ResultHandler& res, - const Scaler& scaler); + SIMDResultHandler& res, + const NormTableScaler* scaler = nullptr); } // namespace faiss diff --git a/faiss/impl/pq4_fast_scan_search_1.cpp b/faiss/impl/pq4_fast_scan_search_1.cpp index 6197c2be78..ca41f287f2 100644 --- a/faiss/impl/pq4_fast_scan_search_1.cpp +++ b/faiss/impl/pq4_fast_scan_search_1.cpp @@ -134,10 +134,8 @@ void accumulate_fixed_blocks( } } -} // anonymous namespace - template -void pq4_accumulate_loop( +void pq4_accumulate_loop_fixed_scaler( int nq, size_t nb, int bbs, @@ -172,39 +170,55 @@ void pq4_accumulate_loop( #undef DISPATCH } -// explicit template instantiations - -#define INSTANTIATE_ACCUMULATE(TH, C, with_id_map, S) \ - template void pq4_accumulate_loop, S>( \ - int, \ - size_t, \ - int, \ - int, \ - const uint8_t*, \ - const uint8_t*, \ - TH&, \ - const S&); - -using DS = DummyScaler; -using NS = NormTableScaler; - -#define INSTANTIATE_3(C, with_id_map) \ - INSTANTIATE_ACCUMULATE(SingleResultHandler, C, with_id_map, DS) \ - INSTANTIATE_ACCUMULATE(HeapHandler, C, with_id_map, DS) \ - INSTANTIATE_ACCUMULATE(ReservoirHandler, C, with_id_map, DS) \ - \ - INSTANTIATE_ACCUMULATE(SingleResultHandler, C, with_id_map, NS) \ - INSTANTIATE_ACCUMULATE(HeapHandler, C, with_id_map, NS) \ - INSTANTIATE_ACCUMULATE(ReservoirHandler, C, with_id_map, NS) - -using Csi = CMax; -INSTANTIATE_3(Csi, false); -using CsiMin = CMin; -INSTANTIATE_3(CsiMin, false); - -using Csl = CMax; -INSTANTIATE_3(Csl, true); -using CslMin = CMin; -INSTANTIATE_3(CslMin, true); +template +void pq4_accumulate_loop_fixed_handler( + int nq, + size_t nb, + int bbs, + int nsq, + const uint8_t* codes, + const uint8_t* LUT, + ResultHandler& res, + const NormTableScaler* scaler) { + if 
(scaler) { + pq4_accumulate_loop_fixed_scaler( + nq, nb, bbs, nsq, codes, LUT, res, *scaler); + } else { + DummyScaler dscaler; + pq4_accumulate_loop_fixed_scaler( + nq, nb, bbs, nsq, codes, LUT, res, dscaler); + } +} + +struct Run_pq4_accumulate_loop { + template + void f(ResultHandler& res, + int nq, + size_t nb, + int bbs, + int nsq, + const uint8_t* codes, + const uint8_t* LUT, + const NormTableScaler* scaler) { + pq4_accumulate_loop_fixed_handler( + nq, nb, bbs, nsq, codes, LUT, res, scaler); + } +}; + +} // anonymous namespace + +void pq4_accumulate_loop( + int nq, + size_t nb, + int bbs, + int nsq, + const uint8_t* codes, + const uint8_t* LUT, + SIMDResultHandler& res, + const NormTableScaler* scaler) { + Run_pq4_accumulate_loop consumer; + dispatch_SIMDResultHanlder( + res, consumer, nq, nb, bbs, nsq, codes, LUT, scaler); +} } // namespace faiss diff --git a/faiss/impl/pq4_fast_scan_search_qbs.cpp b/faiss/impl/pq4_fast_scan_search_qbs.cpp index 50c0f6217b..d69542c309 100644 --- a/faiss/impl/pq4_fast_scan_search_qbs.cpp +++ b/faiss/impl/pq4_fast_scan_search_qbs.cpp @@ -14,6 +14,9 @@ namespace faiss { +// declared in simd_result_handlers.h +bool simd_result_handlers_accept_virtual = true; + using namespace simd_result_handlers; /************************************************************ @@ -194,10 +197,8 @@ void accumulate( #undef DISPATCH } -} // namespace - template -void pq4_accumulate_loop_qbs( +void pq4_accumulate_loop_qbs_fixed_scaler( int qbs, size_t ntotal2, int nsq, @@ -272,49 +273,39 @@ void pq4_accumulate_loop_qbs( } } -// explicit template instantiations - -#define INSTANTIATE_ACCUMULATE_Q(RH) \ - template void pq4_accumulate_loop_qbs( \ - int, \ - size_t, \ - int, \ - const uint8_t*, \ - const uint8_t*, \ - RH&, \ - const DummyScaler&); \ - template void pq4_accumulate_loop_qbs( \ - int, \ - size_t, \ - int, \ - const uint8_t*, \ - const uint8_t*, \ - RH&, \ - const NormTableScaler&); - -using Csi = CMax; 
-INSTANTIATE_ACCUMULATE_Q(SingleResultHandler) -INSTANTIATE_ACCUMULATE_Q(HeapHandler) -INSTANTIATE_ACCUMULATE_Q(ReservoirHandler) -using Csi2 = CMin; -INSTANTIATE_ACCUMULATE_Q(SingleResultHandler) -INSTANTIATE_ACCUMULATE_Q(HeapHandler) -INSTANTIATE_ACCUMULATE_Q(ReservoirHandler) - -using Cfl = CMax; -using HHCsl = HeapHandler; -using RHCsl = ReservoirHandler; -using SHCsl = SingleResultHandler; -INSTANTIATE_ACCUMULATE_Q(HHCsl) -INSTANTIATE_ACCUMULATE_Q(RHCsl) -INSTANTIATE_ACCUMULATE_Q(SHCsl) -using Cfl2 = CMin; -using HHCsl2 = HeapHandler; -using RHCsl2 = ReservoirHandler; -using SHCsl2 = SingleResultHandler; -INSTANTIATE_ACCUMULATE_Q(HHCsl2) -INSTANTIATE_ACCUMULATE_Q(RHCsl2) -INSTANTIATE_ACCUMULATE_Q(SHCsl2) +struct Run_pq4_accumulate_loop_qbs { + template + void f(ResultHandler& res, + int qbs, + size_t nb, + int nsq, + const uint8_t* codes, + const uint8_t* LUT, + const NormTableScaler* scaler) { + if (scaler) { + pq4_accumulate_loop_qbs_fixed_scaler( + qbs, nb, nsq, codes, LUT, res, *scaler); + } else { + DummyScaler dummy; + pq4_accumulate_loop_qbs_fixed_scaler( + qbs, nb, nsq, codes, LUT, res, dummy); + } + } +}; + +} // namespace + +void pq4_accumulate_loop_qbs( + int qbs, + size_t nb, + int nsq, + const uint8_t* codes, + const uint8_t* LUT, + SIMDResultHandler& res, + const NormTableScaler* scaler) { + Run_pq4_accumulate_loop_qbs consumer; + dispatch_SIMDResultHanlder(res, consumer, qbs, nb, nsq, codes, LUT, scaler); +} /*************************************************************** * Packing functions diff --git a/faiss/impl/simd_result_handlers.h b/faiss/impl/simd_result_handlers.h index 94a2541e03..2d8e5388d9 100644 --- a/faiss/impl/simd_result_handlers.h +++ b/faiss/impl/simd_result_handlers.h @@ -14,40 +14,86 @@ #include #include +#include +#include #include #include #include /** This file contains callbacks for kernels that compute distances. - * - * The SIMDResultHandler object is intended to be templated and inlined. 
- * Methods: - * - handle(): called when 32 distances are computed and provided in two - * simd16uint16. (q, b) indicate which entry it is in the block. - * - set_block_origin(): set the sub-matrix that is being computed */ namespace faiss { +struct SIMDResultHandler { + // used to dispatch templates + bool is_CMax = false; + uint8_t sizeof_ids = 0; + bool with_fields = false; + + /** called when 32 distances are computed and provided in two + * simd16uint16. (q, b) indicate which entry it is in the block. */ + virtual void handle( + size_t q, + size_t b, + simd16uint16 d0, + simd16uint16 d1) = 0; + + /// set the sub-matrix that is being computed + virtual void set_block_origin(size_t i0, size_t j0) = 0; + + virtual ~SIMDResultHandler() {} +}; + +/* Result handler that will return float results eventually */ +struct SIMDResultHandlerToFloat : SIMDResultHandler { + size_t nq; // number of queries + size_t ntotal; // ignore excess elements after ntotal + + /// these fields are used mainly for the IVF variants (with_id_map=true) + const idx_t* id_map = nullptr; // map offset in invlist to vector id + const int* q_map = nullptr; // map q to global query + const uint16_t* dbias = + nullptr; // table of biases to add to each query (for IVF L2 search) + const float* normalizers = nullptr; // size 2 * nq, to convert + + SIMDResultHandlerToFloat(size_t nq, size_t ntotal) + : nq(nq), ntotal(ntotal) {} + + virtual void begin(const float* norms) { + normalizers = norms; + } + + // called at end of search to convert int16 distances to float, before + // normalizers are deallocated + virtual void end() { + normalizers = nullptr; + } +}; + +FAISS_API extern bool simd_result_handlers_accept_virtual; + namespace simd_result_handlers { -/** Dummy structure that just computes a checksum on results +/** Dummy structure that just computes a checksum on results * (to avoid the computation to be optimized away) */ -struct DummyResultHandler { +struct DummyResultHandler : 
SIMDResultHandler { size_t cs = 0; - void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) { + void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) final { cs += q * 123 + b * 789 + d0.get_scalar_0() + d1.get_scalar_0(); } - void set_block_origin(size_t, size_t) {} + void set_block_origin(size_t, size_t) final {} + + ~DummyResultHandler() {} }; /** memorize results in a nq-by-nb matrix. * * j0 is the current upper-left block of the matrix */ -struct StoreResultHandler { +struct StoreResultHandler : SIMDResultHandler { uint16_t* data; size_t ld; // total number of columns size_t i0 = 0; @@ -55,32 +101,32 @@ struct StoreResultHandler { StoreResultHandler(uint16_t* data, size_t ld) : data(data), ld(ld) {} - void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) { + void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) final { size_t ofs = (q + i0) * ld + j0 + b * 32; d0.store(data + ofs); d1.store(data + ofs + 16); } - void set_block_origin(size_t i0_2, size_t j0_2) { - this->i0 = i0_2; - this->j0 = j0_2; + void set_block_origin(size_t i0_in, size_t j0_in) final { + this->i0 = i0_in; + this->j0 = j0_in; } }; /** stores results in fixed-size matrix. 
*/ template -struct FixedStorageHandler { +struct FixedStorageHandler : SIMDResultHandler { simd16uint16 dis[NQ][BB]; int i0 = 0; - void handle(int q, int b, simd16uint16 d0, simd16uint16 d1) { + void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) final { dis[q + i0][2 * b] = d0; dis[q + i0][2 * b + 1] = d1; } - void set_block_origin(size_t i0_2, size_t j0) { - this->i0 = i0_2; - assert(j0 == 0); + void set_block_origin(size_t i0_in, size_t j0_in) final { + this->i0 = i0_in; + assert(j0_in == 0); } template @@ -91,30 +137,29 @@ struct FixedStorageHandler { } } } + virtual ~FixedStorageHandler() {} }; -/** Record origin of current block */ +/** Result handler that compares distances to check if they need to be kept */ template -struct SIMDResultHandler { +struct ResultHandlerCompare : SIMDResultHandlerToFloat { using TI = typename C::TI; bool disable = false; int64_t i0 = 0; // query origin int64_t j0 = 0; // db origin - size_t ntotal; // ignore excess elements after ntotal - - /// these fields are used mainly for the IVF variants (with_id_map=true) - const TI* id_map; // map offset in invlist to vector id - const int* q_map; // map q to global query - const uint16_t* dbias; // table of biases to add to each query - explicit SIMDResultHandler(size_t ntotal) - : ntotal(ntotal), id_map(nullptr), q_map(nullptr), dbias(nullptr) {} + ResultHandlerCompare(size_t nq, size_t ntotal) + : SIMDResultHandlerToFloat(nq, ntotal) { + this->is_CMax = C::is_max; + this->sizeof_ids = sizeof(typename C::TI); + this->with_fields = with_id_map; + } - void set_block_origin(size_t i0_2, size_t j0_2) { - this->i0 = i0_2; - this->j0 = j0_2; + void set_block_origin(size_t i0_in, size_t j0_in) final { + this->i0 = i0_in; + this->j0 = j0_in; } // adjust handler data for IVF. 
@@ -172,43 +217,37 @@ struct SIMDResultHandler { return lt_mask; } - virtual void to_flat_arrays( - float* distances, - int64_t* labels, - const float* normalizers = nullptr) = 0; - - virtual ~SIMDResultHandler() {} + virtual ~ResultHandlerCompare() {} }; /** Special version for k=1 */ template -struct SingleResultHandler : SIMDResultHandler { +struct SingleResultHandler : ResultHandlerCompare { using T = typename C::T; using TI = typename C::TI; + using RHC = ResultHandlerCompare; + using RHC::normalizers; - struct Result { - T val; - TI id; - }; - std::vector results; + std::vector idis; + float* dis; + int64_t* ids; - SingleResultHandler(size_t nq, size_t ntotal) - : SIMDResultHandler(ntotal), results(nq) { + SingleResultHandler(size_t nq, size_t ntotal, float* dis, int64_t* ids) + : RHC(nq, ntotal), idis(nq), dis(dis), ids(ids) { for (int i = 0; i < nq; i++) { - Result res = {C::neutral(), -1}; - results[i] = res; + ids[i] = -1; + idis[i] = C::neutral(); } } - void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) { + void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) final { if (this->disable) { return; } this->adjust_with_origin(q, d0, d1); - Result& res = results[q]; - uint32_t lt_mask = this->get_lt_mask(res.val, b, d0, d1); + uint32_t lt_mask = this->get_lt_mask(idis[q], b, d0, d1); if (!lt_mask) { return; } @@ -221,70 +260,61 @@ struct SingleResultHandler : SIMDResultHandler { // find first non-zero int j = __builtin_ctz(lt_mask); lt_mask -= 1 << j; - T dis = d32tab[j]; - if (C::cmp(res.val, dis)) { - res.val = dis; - res.id = this->adjust_id(b, j); + T d = d32tab[j]; + if (C::cmp(idis[q], d)) { + idis[q] = d; + ids[q] = this->adjust_id(b, j); } } } - void to_flat_arrays( - float* distances, - int64_t* labels, - const float* normalizers = nullptr) override { - for (int q = 0; q < results.size(); q++) { + void end() { + for (int q = 0; q < this->nq; q++) { if (!normalizers) { - distances[q] = results[q].val; + dis[q] = idis[q]; } 
else { float one_a = 1 / normalizers[2 * q]; float b = normalizers[2 * q + 1]; - distances[q] = b + results[q].val * one_a; + dis[q] = b + idis[q] * one_a; } - labels[q] = results[q].id; } } }; /** Structure that collects results in a min- or max-heap */ template -struct HeapHandler : SIMDResultHandler { +struct HeapHandler : ResultHandlerCompare { using T = typename C::T; using TI = typename C::TI; + using RHC = ResultHandlerCompare; + using RHC::normalizers; - int nq; - T* heap_dis_tab; - TI* heap_ids_tab; + std::vector idis; + std::vector iids; + float* dis; + int64_t* ids; int64_t k; // number of results to keep - HeapHandler( - int nq, - T* heap_dis_tab, - TI* heap_ids_tab, - size_t k, - size_t ntotal) - : SIMDResultHandler(ntotal), - nq(nq), - heap_dis_tab(heap_dis_tab), - heap_ids_tab(heap_ids_tab), + HeapHandler(size_t nq, size_t ntotal, int64_t k, float* dis, int64_t* ids) + : RHC(nq, ntotal), + idis(nq * k), + iids(nq * k), + dis(dis), + ids(ids), k(k) { - for (int q = 0; q < nq; q++) { - T* heap_dis_in = heap_dis_tab + q * k; - TI* heap_ids_in = heap_ids_tab + q * k; - heap_heapify(k, heap_dis_in, heap_ids_in); - } + heap_heapify(k * nq, idis.data(), iids.data()); } - void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) { + void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) final { if (this->disable) { return; } this->adjust_with_origin(q, d0, d1); - T* heap_dis = heap_dis_tab + q * k; - TI* heap_ids = heap_ids_tab + q * k; + T* heap_dis = idis.data() + q * k; + TI* heap_ids = iids.data() + q * k; uint16_t cur_thresh = heap_dis[0] < 65536 ? 
(uint16_t)(heap_dis[0]) : 0xffff; @@ -313,16 +343,13 @@ struct HeapHandler : SIMDResultHandler { } } - void to_flat_arrays( - float* distances, - int64_t* labels, - const float* normalizers = nullptr) override { - for (int q = 0; q < nq; q++) { - T* heap_dis_in = heap_dis_tab + q * k; - TI* heap_ids_in = heap_ids_tab + q * k; + void end() override { + for (int q = 0; q < this->nq; q++) { + T* heap_dis_in = idis.data() + q * k; + TI* heap_ids_in = iids.data() + q * k; heap_reorder(k, heap_dis_in, heap_ids_in); - int64_t* heap_ids = labels + q * k; - float* heap_dis = distances + q * k; + float* heap_dis = dis + q * k; + int64_t* heap_ids = ids + q * k; float one_a = 1.0, b = 0.0; if (normalizers) { @@ -330,8 +357,8 @@ struct HeapHandler : SIMDResultHandler { b = normalizers[2 * q + 1]; } for (int j = 0; j < k; j++) { - heap_ids[j] = heap_ids_in[j]; heap_dis[j] = heap_dis_in[j] * one_a + b; + heap_ids[j] = heap_ids_in[j]; } } } @@ -342,114 +369,45 @@ struct HeapHandler : SIMDResultHandler { * Results are stored when they are below the threshold until the capacity is * reached. Then a partition sort is used to update the threshold. 
*/ -namespace { - -uint64_t get_cy() { -#ifdef MICRO_BENCHMARK - uint32_t high, low; - asm volatile("rdtsc \n\t" : "=a"(low), "=d"(high)); - return ((uint64_t)high << 32) | (low); -#else - return 0; -#endif -} - -} // anonymous namespace - -template -struct ReservoirTopN { - using T = typename C::T; - using TI = typename C::TI; - - T* vals; - TI* ids; - - size_t i; // number of stored elements - size_t n; // number of requested elements - size_t capacity; // size of storage - size_t cycles = 0; - - T threshold; // current threshold - - ReservoirTopN(size_t n, size_t capacity, T* vals, TI* ids) - : vals(vals), ids(ids), i(0), n(n), capacity(capacity) { - assert(n < capacity); - threshold = C::neutral(); - } - - void add(T val, TI id) { - if (C::cmp(threshold, val)) { - if (i == capacity) { - shrink_fuzzy(); - } - vals[i] = val; - ids[i] = id; - i++; - } - } - - /// shrink number of stored elements to n - void shrink_xx() { - uint64_t t0 = get_cy(); - qselect(vals, ids, i, n); - i = n; // forget all elements above i = n - threshold = C::Crev::neutral(); - for (size_t j = 0; j < n; j++) { - if (C::cmp(vals[j], threshold)) { - threshold = vals[j]; - } - } - cycles += get_cy() - t0; - } - - void shrink() { - uint64_t t0 = get_cy(); - threshold = partition(vals, ids, i, n); - i = n; - cycles += get_cy() - t0; - } - - void shrink_fuzzy() { - uint64_t t0 = get_cy(); - assert(i == capacity); - threshold = partition_fuzzy( - vals, ids, capacity, n, (capacity + n) / 2, &i); - cycles += get_cy() - t0; - } -}; - /** Handler built from several ReservoirTopN (one per query) */ template -struct ReservoirHandler : SIMDResultHandler { +struct ReservoirHandler : ResultHandlerCompare { using T = typename C::T; using TI = typename C::TI; + using RHC = ResultHandlerCompare; + using RHC::normalizers; size_t capacity; // rounded up to multiple of 16 + + // where the final results will be written + float* dis; + int64_t* ids; + std::vector all_ids; AlignedTable all_vals; - std::vector> 
reservoirs; - uint64_t times[4]; - - ReservoirHandler(size_t nq, size_t ntotal, size_t n, size_t capacity_in) - : SIMDResultHandler(ntotal), - capacity((capacity_in + 15) & ~15), - all_ids(nq * capacity), - all_vals(nq * capacity) { + ReservoirHandler( + size_t nq, + size_t ntotal, + size_t k, + size_t cap, + float* dis, + int64_t* ids) + : RHC(nq, ntotal), capacity((cap + 15) & ~15), dis(dis), ids(ids) { assert(capacity % 16 == 0); - for (size_t i = 0; i < nq; i++) { + all_ids.resize(nq * capacity); + all_vals.resize(nq * capacity); + for (size_t q = 0; q < nq; q++) { reservoirs.emplace_back( - n, + k, capacity, - all_vals.get() + i * capacity, - all_ids.data() + i * capacity); + all_vals.get() + q * capacity, + all_ids.data() + q * capacity); } - times[0] = times[1] = times[2] = times[3] = 0; } - void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) { - uint64_t t0 = get_cy(); + void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) final { if (this->disable) { return; } @@ -457,8 +415,6 @@ struct ReservoirHandler : SIMDResultHandler { ReservoirTopN& res = reservoirs[q]; uint32_t lt_mask = this->get_lt_mask(res.threshold, b, d0, d1); - uint64_t t1 = get_cy(); - times[0] += t1 - t0; if (!lt_mask) { return; @@ -474,20 +430,14 @@ struct ReservoirHandler : SIMDResultHandler { T dis = d32tab[j]; res.add(dis, this->adjust_id(b, j)); } - times[1] += get_cy() - t1; } - void to_flat_arrays( - float* distances, - int64_t* labels, - const float* normalizers = nullptr) override { + void end() override { using Cf = typename std::conditional< C::is_max, CMax, CMin>::type; - uint64_t t0 = get_cy(); - uint64_t t3 = 0; std::vector perm(reservoirs[0].n); for (int q = 0; q < reservoirs.size(); q++) { ReservoirTopN& res = reservoirs[q]; @@ -496,8 +446,8 @@ struct ReservoirHandler : SIMDResultHandler { if (res.i > res.n) { res.shrink(); } - int64_t* heap_ids = labels + q * n; - float* heap_dis = distances + q * n; + int64_t* heap_ids = ids + q * n; + float* 
heap_dis = dis + q * n; float one_a = 1.0, b = 0.0; if (normalizers) { @@ -518,14 +468,236 @@ struct ReservoirHandler : SIMDResultHandler { // possibly add empty results heap_heapify(n - res.i, heap_dis + res.i, heap_ids + res.i); + } + } +}; + +/** Result handler for range search. The difficulty is that the range distances + * have to be scaled using the scaler. + */ + +template +struct RangeHandler : ResultHandlerCompare { + using T = typename C::T; + using TI = typename C::TI; + using RHC = ResultHandlerCompare; + using RHC::normalizers; + using RHC::nq; + + RangeSearchResult& rres; + float radius; + std::vector thresholds; + std::vector n_per_query; + size_t q0 = 0; + + // we cannot use the RangeSearchPartialResult interface because queries can + // be performed by batches + struct Triplet { + idx_t q; + idx_t b; + uint16_t dis; + }; + std::vector triplets; + + RangeHandler(RangeSearchResult& rres, float radius, size_t ntotal) + : RHC(rres.nq, ntotal), rres(rres), radius(radius) { + thresholds.resize(nq); + n_per_query.resize(nq + 1); + } + + virtual void begin(const float* norms) { + normalizers = norms; + for (int q = 0; q < nq; ++q) { + thresholds[q] = + normalizers[2 * q] * (radius - normalizers[2 * q + 1]); + } + } + + void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) final { + if (this->disable) { + return; + } + this->adjust_with_origin(q, d0, d1); + + uint32_t lt_mask = this->get_lt_mask(thresholds[q], b, d0, d1); + + if (!lt_mask) { + return; + } + ALIGNED(32) uint16_t d32tab[32]; + d0.store(d32tab); + d1.store(d32tab + 16); + + while (lt_mask) { + // find first non-zero + int j = __builtin_ctz(lt_mask); + lt_mask -= 1 << j; + T dis = d32tab[j]; + n_per_query[q]++; + triplets.push_back({idx_t(q + q0), this->adjust_id(b, j), dis}); + } + } - t3 += res.cycles; + void end() override { + memcpy(rres.lims, n_per_query.data(), sizeof(n_per_query[0]) * nq); + rres.do_allocation(); + for (auto it = triplets.begin(); it != triplets.end(); ++it) 
{ + size_t& l = rres.lims[it->q]; + rres.distances[l] = it->dis; + rres.labels[l] = it->b; + l++; + } + memmove(rres.lims + 1, rres.lims, sizeof(*rres.lims) * rres.nq); + rres.lims[0] = 0; + + for (int q = 0; q < nq; q++) { + float one_a = 1 / normalizers[2 * q]; + float b = normalizers[2 * q + 1]; + for (size_t i = rres.lims[q]; i < rres.lims[q + 1]; i++) { + rres.distances[i] = rres.distances[i] * one_a + b; + } } - times[2] += get_cy() - t0; - times[3] += t3; } }; +#ifndef SWIG + +// handler for a subset of queries +template +struct PartialRangeHandler : RangeHandler { + using T = typename C::T; + using TI = typename C::TI; + using RHC = RangeHandler; + using RHC::normalizers; + using RHC::nq, RHC::q0, RHC::triplets, RHC::n_per_query; + + RangeSearchPartialResult& pres; + + PartialRangeHandler( + RangeSearchPartialResult& pres, + float radius, + size_t ntotal, + size_t q0, + size_t q1) + : RangeHandler(*pres.res, radius, ntotal), + pres(pres) { + nq = q1 - q0; + this->q0 = q0; + } + + // shift left n_per_query + void shift_n_per_query() { + memmove(n_per_query.data() + 1, + n_per_query.data(), + nq * sizeof(n_per_query[0])); + n_per_query[0] = 0; + } + + // commit to partial result instead of full RangeResult + void end() override { + std::vector sorted_triplets(triplets.size()); + for (int q = 0; q < nq; q++) { + n_per_query[q + 1] += n_per_query[q]; + } + shift_n_per_query(); + + for (size_t i = 0; i < triplets.size(); i++) { + sorted_triplets[n_per_query[triplets[i].q - q0]++] = triplets[i]; + } + shift_n_per_query(); + + size_t* lims = n_per_query.data(); + + for (int q = 0; q < nq; q++) { + float one_a = 1 / normalizers[2 * q]; + float b = normalizers[2 * q + 1]; + RangeQueryResult& qres = pres.new_result(q + q0); + for (size_t i = lims[q]; i < lims[q + 1]; i++) { + qres.add( + sorted_triplets[i].dis * one_a + b, + sorted_triplets[i].b); + } + } + } +}; + +#endif + +/******************************************************************************** + * 
Dynamic dispatching function. The consumer should have a templatized method f + * that will be replaced with the actual SIMDResultHandler that is determined + * dynamically. + */ + +template +void dispatch_SIMDResultHanlder_fixedCW( + SIMDResultHandler& res, + Consumer& consumer, + Types... args) { + if (auto resh = dynamic_cast*>(&res)) { + consumer.template f>(*resh, args...); + } else if (auto resh = dynamic_cast*>(&res)) { + consumer.template f>(*resh, args...); + } else if (auto resh = dynamic_cast*>(&res)) { + consumer.template f>(*resh, args...); + } else { // generic handler -- will not be inlined + FAISS_THROW_IF_NOT_FMT( + simd_result_handlers_accept_virtual, + "Running vitrual handler for %s", + typeid(res).name()); + consumer.template f(res, args...); + } +} + +template +void dispatch_SIMDResultHanlder_fixedC( + SIMDResultHandler& res, + Consumer& consumer, + Types... args) { + if (res.with_fields) { + dispatch_SIMDResultHanlder_fixedCW(res, consumer, args...); + } else { + dispatch_SIMDResultHanlder_fixedCW(res, consumer, args...); + } +} + +template +void dispatch_SIMDResultHanlder( + SIMDResultHandler& res, + Consumer& consumer, + Types... 
args) { + if (res.sizeof_ids == 0) { + if (auto resh = dynamic_cast(&res)) { + consumer.template f(*resh, args...); + } else if (auto resh = dynamic_cast(&res)) { + consumer.template f(*resh, args...); + } else { // generic path + FAISS_THROW_IF_NOT_FMT( + simd_result_handlers_accept_virtual, + "Running vitrual handler for %s", + typeid(res).name()); + consumer.template f(res, args...); + } + } else if (res.sizeof_ids == sizeof(int)) { + if (res.is_CMax) { + dispatch_SIMDResultHanlder_fixedC>( + res, consumer, args...); + } else { + dispatch_SIMDResultHanlder_fixedC>( + res, consumer, args...); + } + } else if (res.sizeof_ids == sizeof(int64_t)) { + if (res.is_CMax) { + dispatch_SIMDResultHanlder_fixedC>( + res, consumer, args...); + } else { + dispatch_SIMDResultHanlder_fixedC>( + res, consumer, args...); + } + } else { + FAISS_THROW_FMT("Unknown id size %d", res.sizeof_ids); + } +} } // namespace simd_result_handlers } // namespace faiss diff --git a/faiss/python/swigfaiss.swig b/faiss/python/swigfaiss.swig index 3d6f94604a..fb7f50dd2e 100644 --- a/faiss/python/swigfaiss.swig +++ b/faiss/python/swigfaiss.swig @@ -81,6 +81,9 @@ typedef uint64_t size_t; #include #include #include +#include +#include + #include #include #include @@ -490,6 +493,11 @@ void gpu_sync_all_devices() %include %include %include + +// NOTE(matthijs) let's not go into wrapping simdlib +struct faiss::simd16uint16 {}; + +%include %include %include %include diff --git a/faiss/utils/distances.cpp b/faiss/utils/distances.cpp index 5b66158c09..784793c9a9 100644 --- a/faiss/utils/distances.cpp +++ b/faiss/utils/distances.cpp @@ -5,8 +5,6 @@ * LICENSE file in the root directory of this source tree. 
*/ -// -*- c++ -*- - #include #include @@ -131,16 +129,17 @@ void fvec_renorm_L2(size_t d, size_t nx, float* __restrict x) { namespace { /* Find the nearest neighbors for nx queries in a set of ny vectors */ -template +template void exhaustive_inner_product_seq( const float* x, const float* y, size_t d, size_t nx, size_t ny, - ResultHandler& res, + BlockResultHandler& res, const IDSelector* sel = nullptr) { - using SingleResultHandler = typename ResultHandler::SingleResultHandler; + using SingleResultHandler = + typename BlockResultHandler::SingleResultHandler; int nt = std::min(int(nx), omp_get_max_threads()); FAISS_ASSERT(use_sel == (sel != nullptr)); @@ -167,16 +166,17 @@ void exhaustive_inner_product_seq( } } -template +template void exhaustive_L2sqr_seq( const float* x, const float* y, size_t d, size_t nx, size_t ny, - ResultHandler& res, + BlockResultHandler& res, const IDSelector* sel = nullptr) { - using SingleResultHandler = typename ResultHandler::SingleResultHandler; + using SingleResultHandler = + typename BlockResultHandler::SingleResultHandler; int nt = std::min(int(nx), omp_get_max_threads()); FAISS_ASSERT(use_sel == (sel != nullptr)); @@ -202,14 +202,14 @@ void exhaustive_L2sqr_seq( } /** Find the nearest neighbors for nx queries in a set of ny vectors */ -template +template void exhaustive_inner_product_blas( const float* x, const float* y, size_t d, size_t nx, size_t ny, - ResultHandler& res) { + BlockResultHandler& res) { // BLAS does not like empty matrices if (nx == 0 || ny == 0) return; @@ -258,14 +258,14 @@ void exhaustive_inner_product_blas( // distance correction is an operator that can be applied to transform // the distances -template +template void exhaustive_L2sqr_blas_default_impl( const float* x, const float* y, size_t d, size_t nx, size_t ny, - ResultHandler& res, + BlockResultHandler& res, const float* y_norms = nullptr) { // BLAS does not like empty matrices if (nx == 0 || ny == 0) @@ -341,14 +341,14 @@ void 
exhaustive_L2sqr_blas_default_impl( } } -template +template void exhaustive_L2sqr_blas( const float* x, const float* y, size_t d, size_t nx, size_t ny, - ResultHandler& res, + BlockResultHandler& res, const float* y_norms = nullptr) { exhaustive_L2sqr_blas_default_impl(x, y, d, nx, ny, res); } @@ -360,7 +360,7 @@ void exhaustive_L2sqr_blas_cmax_avx2( size_t d, size_t nx, size_t ny, - SingleBestResultHandler>& res, + Top1BlockResultHandler>& res, const float* y_norms) { // BLAS does not like empty matrices if (nx == 0 || ny == 0) @@ -563,13 +563,13 @@ void exhaustive_L2sqr_blas_cmax_avx2( // an override if only a single closest point is needed template <> -void exhaustive_L2sqr_blas>>( +void exhaustive_L2sqr_blas>>( const float* x, const float* y, size_t d, size_t nx, size_t ny, - SingleBestResultHandler>& res, + Top1BlockResultHandler>& res, const float* y_norms) { #if defined(__AVX2__) // use a faster fused kernel if available @@ -590,28 +590,29 @@ void exhaustive_L2sqr_blas>>( // run the default implementation exhaustive_L2sqr_blas_default_impl< - SingleBestResultHandler>>( + Top1BlockResultHandler>>( x, y, d, nx, ny, res, y_norms); #else // run the default implementation exhaustive_L2sqr_blas_default_impl< - SingleBestResultHandler>>( + Top1BlockResultHandler>>( x, y, d, nx, ny, res, y_norms); #endif } -template +template void knn_L2sqr_select( const float* x, const float* y, size_t d, size_t nx, size_t ny, - ResultHandler& res, + BlockResultHandler& res, const float* y_norm2, const IDSelector* sel) { if (sel) { - exhaustive_L2sqr_seq(x, y, d, nx, ny, res, sel); + exhaustive_L2sqr_seq( + x, y, d, nx, ny, res, sel); } else if (nx < distance_compute_blas_threshold) { exhaustive_L2sqr_seq(x, y, d, nx, ny, res); } else { @@ -619,6 +620,25 @@ void knn_L2sqr_select( } } +template +void knn_inner_product_select( + const float* x, + const float* y, + size_t d, + size_t nx, + size_t ny, + BlockResultHandler& res, + const IDSelector* sel) { + if (sel) { + 
exhaustive_inner_product_seq( + x, y, d, nx, ny, res, sel); + } else if (nx < distance_compute_blas_threshold) { + exhaustive_inner_product_seq(x, y, d, nx, ny, res); + } else { + exhaustive_inner_product_blas(x, y, d, nx, ny, res); + } +} + } // anonymous namespace /******************************************************* @@ -637,7 +657,7 @@ void knn_inner_product( size_t nx, size_t ny, size_t k, - float* val, + float* vals, int64_t* ids, const IDSelector* sel) { int64_t imin = 0; @@ -650,30 +670,21 @@ void knn_inner_product( } if (auto sela = dynamic_cast(sel)) { knn_inner_products_by_idx( - x, y, sela->ids, d, nx, sela->n, k, val, ids, 0); + x, y, sela->ids, d, nx, sela->n, k, vals, ids, 0); return; } - if (k < distance_compute_min_k_reservoir) { - using RH = HeapResultHandler>; - RH res(nx, val, ids, k); - if (sel) { - exhaustive_inner_product_seq(x, y, d, nx, ny, res, sel); - } else if (nx < distance_compute_blas_threshold) { - exhaustive_inner_product_seq(x, y, d, nx, ny, res); - } else { - exhaustive_inner_product_blas(x, y, d, nx, ny, res); - } + + if (k == 1) { + Top1BlockResultHandler> res(nx, vals, ids); + knn_inner_product_select(x, y, d, nx, ny, res, sel); + } else if (k < distance_compute_min_k_reservoir) { + HeapBlockResultHandler> res(nx, vals, ids, k); + knn_inner_product_select(x, y, d, nx, ny, res, sel); } else { - using RH = ReservoirResultHandler>; - RH res(nx, val, ids, k); - if (sel) { - exhaustive_inner_product_seq(x, y, d, nx, ny, res, sel); - } else if (nx < distance_compute_blas_threshold) { - exhaustive_inner_product_seq(x, y, d, nx, ny, res, nullptr); - } else { - exhaustive_inner_product_blas(x, y, d, nx, ny, res); - } + ReservoirBlockResultHandler> res(nx, vals, ids, k); + knn_inner_product_select(x, y, d, nx, ny, res, sel); } + if (imin != 0) { for (size_t i = 0; i < nx * k; i++) { if (ids[i] >= 0) { @@ -719,13 +730,13 @@ void knn_L2sqr( return; } if (k == 1) { - SingleBestResultHandler> res(nx, vals, ids); + Top1BlockResultHandler> 
res(nx, vals, ids); knn_L2sqr_select(x, y, d, nx, ny, res, y_norm2, sel); } else if (k < distance_compute_min_k_reservoir) { - HeapResultHandler> res(nx, vals, ids, k); + HeapBlockResultHandler> res(nx, vals, ids, k); knn_L2sqr_select(x, y, d, nx, ny, res, y_norm2, sel); } else { - ReservoirResultHandler> res(nx, vals, ids, k); + ReservoirBlockResultHandler> res(nx, vals, ids, k); knn_L2sqr_select(x, y, d, nx, ny, res, y_norm2, sel); } if (imin != 0) { @@ -763,7 +774,7 @@ void range_search_L2sqr( float radius, RangeSearchResult* res, const IDSelector* sel) { - using RH = RangeSearchResultHandler>; + using RH = RangeSearchBlockResultHandler>; RH resh(res, radius); if (sel) { exhaustive_L2sqr_seq(x, y, d, nx, ny, resh, sel); @@ -783,7 +794,7 @@ void range_search_inner_product( float radius, RangeSearchResult* res, const IDSelector* sel) { - using RH = RangeSearchResultHandler>; + using RH = RangeSearchBlockResultHandler>; RH resh(res, radius); if (sel) { exhaustive_inner_product_seq(x, y, d, nx, ny, resh, sel); diff --git a/faiss/utils/distances_fused/avx512.cpp b/faiss/utils/distances_fused/avx512.cpp index b5ff70f9e4..d4c442c79b 100644 --- a/faiss/utils/distances_fused/avx512.cpp +++ b/faiss/utils/distances_fused/avx512.cpp @@ -68,7 +68,7 @@ void kernel( const float* const __restrict y, const float* const __restrict y_transposed, size_t ny, - SingleBestResultHandler>& res, + Top1BlockResultHandler>& res, const float* __restrict y_norms, size_t i) { const size_t ny_p = @@ -231,7 +231,7 @@ void exhaustive_L2sqr_fused_cmax( const float* const __restrict y, size_t nx, size_t ny, - SingleBestResultHandler>& res, + Top1BlockResultHandler>& res, const float* __restrict y_norms) { // BLAS does not like empty matrices if (nx == 0 || ny == 0) { @@ -275,7 +275,7 @@ void exhaustive_L2sqr_fused_cmax( x, y, y_transposed.data(), ny, res, y_norms, i); } - // Does nothing for SingleBestResultHandler, but + // Does nothing for Top1BlockResultHandler, but // keeping the call for the 
consistency. res.end_multiple(); InterruptCallback::check(); @@ -289,7 +289,7 @@ bool exhaustive_L2sqr_fused_cmax_AVX512( size_t d, size_t nx, size_t ny, - SingleBestResultHandler>& res, + Top1BlockResultHandler>& res, const float* y_norms) { // process only cases with certain dimensionalities diff --git a/faiss/utils/distances_fused/avx512.h b/faiss/utils/distances_fused/avx512.h index b6d5fc0556..4cb62771a2 100644 --- a/faiss/utils/distances_fused/avx512.h +++ b/faiss/utils/distances_fused/avx512.h @@ -28,7 +28,7 @@ bool exhaustive_L2sqr_fused_cmax_AVX512( size_t d, size_t nx, size_t ny, - SingleBestResultHandler>& res, + Top1BlockResultHandler>& res, const float* y_norms); } // namespace faiss diff --git a/faiss/utils/distances_fused/distances_fused.cpp b/faiss/utils/distances_fused/distances_fused.cpp index a0af971c5c..2ba7e29014 100644 --- a/faiss/utils/distances_fused/distances_fused.cpp +++ b/faiss/utils/distances_fused/distances_fused.cpp @@ -20,7 +20,7 @@ bool exhaustive_L2sqr_fused_cmax( size_t d, size_t nx, size_t ny, - SingleBestResultHandler>& res, + Top1BlockResultHandler>& res, const float* y_norms) { if (nx == 0 || ny == 0) { // nothing to do diff --git a/faiss/utils/distances_fused/distances_fused.h b/faiss/utils/distances_fused/distances_fused.h index e6e35c209e..54b58752b1 100644 --- a/faiss/utils/distances_fused/distances_fused.h +++ b/faiss/utils/distances_fused/distances_fused.h @@ -34,7 +34,7 @@ bool exhaustive_L2sqr_fused_cmax( size_t d, size_t nx, size_t ny, - SingleBestResultHandler>& res, + Top1BlockResultHandler>& res, const float* y_norms); } // namespace faiss diff --git a/faiss/utils/distances_fused/simdlib_based.cpp b/faiss/utils/distances_fused/simdlib_based.cpp index 97ededd2f0..31239e866b 100644 --- a/faiss/utils/distances_fused/simdlib_based.cpp +++ b/faiss/utils/distances_fused/simdlib_based.cpp @@ -62,7 +62,7 @@ void kernel( const float* const __restrict y, const float* const __restrict y_transposed, const size_t ny, - 
SingleBestResultHandler>& res, + Top1BlockResultHandler>& res, const float* __restrict y_norms, const size_t i) { const size_t ny_p = @@ -226,7 +226,7 @@ void exhaustive_L2sqr_fused_cmax( const float* const __restrict y, size_t nx, size_t ny, - SingleBestResultHandler>& res, + Top1BlockResultHandler>& res, const float* __restrict y_norms) { // BLAS does not like empty matrices if (nx == 0 || ny == 0) { @@ -270,7 +270,7 @@ void exhaustive_L2sqr_fused_cmax( x, y, y_transposed.data(), ny, res, y_norms, i); } - // Does nothing for SingleBestResultHandler, but + // Does nothing for Top1BlockResultHandler, but // keeping the call for the consistency. res.end_multiple(); InterruptCallback::check(); @@ -284,7 +284,7 @@ bool exhaustive_L2sqr_fused_cmax_simdlib( size_t d, size_t nx, size_t ny, - SingleBestResultHandler>& res, + Top1BlockResultHandler>& res, const float* y_norms) { // Process only cases with certain dimensionalities. // An acceptable dimensionality value is limited by the number of diff --git a/faiss/utils/distances_fused/simdlib_based.h b/faiss/utils/distances_fused/simdlib_based.h index b60da7b193..6240a8f110 100644 --- a/faiss/utils/distances_fused/simdlib_based.h +++ b/faiss/utils/distances_fused/simdlib_based.h @@ -24,7 +24,7 @@ bool exhaustive_L2sqr_fused_cmax_simdlib( size_t d, size_t nx, size_t ny, - SingleBestResultHandler>& res, + Top1BlockResultHandler>& res, const float* y_norms); } // namespace faiss diff --git a/tests/test_fast_scan_ivf.py b/tests/test_fast_scan_ivf.py index 5a57a39ca9..d6dad8fec3 100644 --- a/tests/test_fast_scan_ivf.py +++ b/tests/test_fast_scan_ivf.py @@ -133,8 +133,6 @@ def test_by_residual_L2_v2(self): self.do_test(LUT, bias, nprobe, alt_3d=True) - - ########################################################## # Tests for various IndexPQFastScan implementations ########################################################## @@ -209,7 +207,6 @@ def test_by_residual_ip(self): self.do_test(True, faiss.METRIC_INNER_PRODUCT) - class 
TestIVFImplem2(unittest.TestCase): """ Verify implem 2 (search with original invlists with uint8 LUTs) against IndexIVFPQ. Entails some loss in accuracy. """ @@ -259,6 +256,7 @@ def test_qloss_no_residual_ip(self): def test_qloss_by_residual_ip(self): self.eval_quant_loss(True, faiss.METRIC_INNER_PRODUCT) + class TestEquivPQ(unittest.TestCase): def test_equiv_pq(self): @@ -309,6 +307,7 @@ def do_test(self, by_residual, metric=faiss.METRIC_L2, d=32, nq=200): index.add(ds.get_database()) index.nprobe = 4 + # compare against implem = 2, which includes quantized LUTs index2 = faiss.IndexIVFPQFastScan(index) index2.implem = 2 Dref, Iref = index2.search(ds.get_queries(), 4) @@ -370,7 +369,6 @@ def test_by_residual_odd_dim_single_query(self): self.do_test(True, d=30, nq=1) - class TestIVFImplem10(TestIVFImplem12): IMPLEM = 10 @@ -432,7 +430,6 @@ def do_test(self, by_residual=False, metric=faiss.METRIC_L2, d=32, bbs=32): new_code_i = new_code_per_id[the_id] np.testing.assert_array_equal(ref_code_i, new_code_i) - def test_add(self): self.do_test() @@ -812,3 +809,75 @@ def subtest_io(self, factory_str): def test_io(self): self.subtest_io('IVF16,PLSQ2x3x4fsr_Nlsq2x4') self.subtest_io('IVF16,PRQ2x3x4fs_Nrq2x4') + + +class TestSearchParams(unittest.TestCase): + + def test_search_params(self): + ds = datasets.SyntheticDataset(32, 500, 100, 10) + + index = faiss.index_factory(ds.d, "IVF32,PQ16x4fs") + index.train(ds.get_train()) + index.add(ds.get_database()) + + index.nprobe + index.nprobe = 4 + Dref4, Iref4 = index.search(ds.get_queries(), 10) + # index.nprobe = 16 + # Dref16, Iref16 = index.search(ds.get_queries(), 10) + + index.nprobe = 1 + Dnew4, Inew4 = index.search( + ds.get_queries(), 10, params=faiss.IVFSearchParameters(nprobe=4)) + np.testing.assert_array_equal(Dref4, Dnew4) + np.testing.assert_array_equal(Iref4, Inew4) + + +class TestRangeSearchImplem12(unittest.TestCase): + IMPLEM = 12 + + def do_test(self, metric=faiss.METRIC_L2): + ds = datasets.SyntheticDataset(32, 
750, 200, 100) + + index = faiss.index_factory(ds.d, "IVF32,PQ16x4np", metric) + index.train(ds.get_train()) + index.add(ds.get_database()) + index.nprobe = 4 + + # find a reasonable radius + D, I = index.search(ds.get_queries(), 10) + radius = np.median(D[:, -1]) + # print("radius=", radius) + lims1, D1, I1 = index.range_search(ds.get_queries(), radius) + + index2 = faiss.IndexIVFPQFastScan(index) + index2.implem = self.IMPLEM + lims2, D2, I2 = index2.range_search(ds.get_queries(), radius) + + nmiss = 0 + nextra = 0 + + for i in range(ds.nq): + ref = set(I1[lims1[i]: lims1[i + 1]]) + new = set(I2[lims2[i]: lims2[i + 1]]) + print(ref, new) + nmiss += len(ref - new) + nextra += len(new - ref) + + # need some tolerance because the look-up tables are quantized + self.assertLess(nmiss, 10) + self.assertLess(nextra, 10) + + def test_L2(self): + self.do_test() + + def test_IP(self): + self.do_test(metric=faiss.METRIC_INNER_PRODUCT) + + +class TestRangeSearchImplem10(TestRangeSearchImplem12): + IMPLEM = 10 + + +class TestRangeSearchImplem110(TestRangeSearchImplem12): + IMPLEM = 110 diff --git a/tests/test_graph_based.py b/tests/test_graph_based.py new file mode 100644 index 0000000000..914fac3ff1 --- /dev/null +++ b/tests/test_graph_based.py @@ -0,0 +1,426 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +""" a few tests for graph-based indices (HNSW and NSG)""" + +import numpy as np +import unittest +import faiss +import tempfile +import os + +from common_faiss_tests import get_dataset_2 + + +class TestHNSW(unittest.TestCase): + + def __init__(self, *args, **kwargs): + unittest.TestCase.__init__(self, *args, **kwargs) + d = 32 + nt = 0 + nb = 1500 + nq = 500 + + (_, self.xb, self.xq) = get_dataset_2(d, nt, nb, nq) + index = faiss.IndexFlatL2(d) + index.add(self.xb) + Dref, Iref = index.search(self.xq, 1) + self.Iref = Iref + + def test_hnsw(self): + d = self.xq.shape[1] + + index = faiss.IndexHNSWFlat(d, 16) + index.add(self.xb) + Dhnsw, Ihnsw = index.search(self.xq, 1) + + self.assertGreaterEqual((self.Iref == Ihnsw).sum(), 460) + + self.io_and_retest(index, Dhnsw, Ihnsw) + + def test_range_search(self): + index_flat = faiss.IndexFlat(self.xb.shape[1]) + index_flat.add(self.xb) + D, _ = index_flat.search(self.xq, 10) + radius = np.median(D[:, -1]) + lims_ref, Dref, Iref = index_flat.range_search(self.xq, radius) + + index = faiss.IndexHNSWFlat(self.xb.shape[1], 16) + index.add(self.xb) + lims, D, I = index.range_search(self.xq, radius) + + nmiss = 0 + # check if returned resutls are a subset of the reference results + for i in range(len(self.xq)): + ref = Iref[lims_ref[i]: lims_ref[i + 1]] + new = I[lims[i]: lims[i + 1]] + self.assertLessEqual(set(new), set(ref)) + nmiss += len(ref) - len(new) + # currenly we miss 405 / 6019 neighbors + self.assertLessEqual(nmiss, lims_ref[-1] * 0.1) + + def test_hnsw_unbounded_queue(self): + d = self.xq.shape[1] + + index = faiss.IndexHNSWFlat(d, 16) + index.add(self.xb) + index.search_bounded_queue = False + Dhnsw, Ihnsw = index.search(self.xq, 1) + + self.assertGreaterEqual((self.Iref == Ihnsw).sum(), 460) + + self.io_and_retest(index, Dhnsw, Ihnsw) + + def io_and_retest(self, index, Dhnsw, Ihnsw): + index2 = faiss.deserialize_index(faiss.serialize_index(index)) + Dhnsw2, Ihnsw2 = index2.search(self.xq, 1) + + 
self.assertTrue(np.all(Dhnsw2 == Dhnsw)) + self.assertTrue(np.all(Ihnsw2 == Ihnsw)) + + # also test clone + index3 = faiss.clone_index(index) + Dhnsw3, Ihnsw3 = index3.search(self.xq, 1) + + self.assertTrue(np.all(Dhnsw3 == Dhnsw)) + self.assertTrue(np.all(Ihnsw3 == Ihnsw)) + + def test_hnsw_2level(self): + d = self.xq.shape[1] + + quant = faiss.IndexFlatL2(d) + + index = faiss.IndexHNSW2Level(quant, 256, 8, 8) + index.train(self.xb) + index.add(self.xb) + Dhnsw, Ihnsw = index.search(self.xq, 1) + + self.assertGreaterEqual((self.Iref == Ihnsw).sum(), 307) + + self.io_and_retest(index, Dhnsw, Ihnsw) + + def test_add_0_vecs(self): + index = faiss.IndexHNSWFlat(10, 16) + zero_vecs = np.zeros((0, 10), dtype='float32') + # infinite loop + index.add(zero_vecs) + + def test_hnsw_IP(self): + d = self.xq.shape[1] + + index_IP = faiss.IndexFlatIP(d) + index_IP.add(self.xb) + Dref, Iref = index_IP.search(self.xq, 1) + + index = faiss.IndexHNSWFlat(d, 16, faiss.METRIC_INNER_PRODUCT) + index.add(self.xb) + Dhnsw, Ihnsw = index.search(self.xq, 1) + + self.assertGreaterEqual((Iref == Ihnsw).sum(), 470) + + mask = Iref[:, 0] == Ihnsw[:, 0] + assert np.allclose(Dref[mask, 0], Dhnsw[mask, 0]) + + +class TestNSG(unittest.TestCase): + + def __init__(self, *args, **kwargs): + unittest.TestCase.__init__(self, *args, **kwargs) + d = 32 + nt = 0 + nb = 1500 + nq = 500 + self.GK = 32 + + _, self.xb, self.xq = get_dataset_2(d, nt, nb, nq) + + def make_knn_graph(self, metric): + n = self.xb.shape[0] + d = self.xb.shape[1] + index = faiss.IndexFlat(d, metric) + index.add(self.xb) + _, I = index.search(self.xb, self.GK + 1) + knn_graph = np.zeros((n, self.GK), dtype=np.int64) + + # For the inner product distance, the distance between a vector and + # itself may not be the smallest, so it is not guaranteed that I[:, 0] + # is the query itself. 
+ for i in range(n): + cnt = 0 + for j in range(self.GK + 1): + if I[i, j] != i: + knn_graph[i, cnt] = I[i, j] + cnt += 1 + if cnt == self.GK: + break + return knn_graph + + def subtest_io_and_clone(self, index, Dnsg, Insg): + fd, tmpfile = tempfile.mkstemp() + os.close(fd) + try: + faiss.write_index(index, tmpfile) + index2 = faiss.read_index(tmpfile) + finally: + if os.path.exists(tmpfile): + os.unlink(tmpfile) + + Dnsg2, Insg2 = index2.search(self.xq, 1) + np.testing.assert_array_equal(Dnsg2, Dnsg) + np.testing.assert_array_equal(Insg2, Insg) + + # also test clone + index3 = faiss.clone_index(index) + Dnsg3, Insg3 = index3.search(self.xq, 1) + np.testing.assert_array_equal(Dnsg3, Dnsg) + np.testing.assert_array_equal(Insg3, Insg) + + def subtest_connectivity(self, index, nb): + vt = faiss.VisitedTable(nb) + count = index.nsg.dfs(vt, index.nsg.enterpoint, 0) + self.assertEqual(count, nb) + + def subtest_add(self, build_type, thresh, metric=faiss.METRIC_L2): + d = self.xq.shape[1] + metrics = {faiss.METRIC_L2: 'L2', + faiss.METRIC_INNER_PRODUCT: 'IP'} + + flat_index = faiss.IndexFlat(d, metric) + flat_index.add(self.xb) + Dref, Iref = flat_index.search(self.xq, 1) + + index = faiss.IndexNSGFlat(d, 16, metric) + index.verbose = True + index.build_type = build_type + index.GK = self.GK + index.add(self.xb) + Dnsg, Insg = index.search(self.xq, 1) + + recalls = (Iref == Insg).sum() + print('metric: {}, nb equal: {}'.format(metrics[metric], recalls)) + self.assertGreaterEqual(recalls, thresh) + self.subtest_connectivity(index, self.xb.shape[0]) + self.subtest_io_and_clone(index, Dnsg, Insg) + + def subtest_build(self, knn_graph, thresh, metric=faiss.METRIC_L2): + d = self.xq.shape[1] + metrics = {faiss.METRIC_L2: 'L2', + faiss.METRIC_INNER_PRODUCT: 'IP'} + + flat_index = faiss.IndexFlat(d, metric) + flat_index.add(self.xb) + Dref, Iref = flat_index.search(self.xq, 1) + + index = faiss.IndexNSGFlat(d, 16, metric) + index.verbose = True + + index.build(self.xb, 
knn_graph) + Dnsg, Insg = index.search(self.xq, 1) + + recalls = (Iref == Insg).sum() + print('metric: {}, nb equal: {}'.format(metrics[metric], recalls)) + self.assertGreaterEqual(recalls, thresh) + self.subtest_connectivity(index, self.xb.shape[0]) + + def test_add_bruteforce_L2(self): + self.subtest_add(0, 475, faiss.METRIC_L2) + + def test_add_nndescent_L2(self): + self.subtest_add(1, 475, faiss.METRIC_L2) + + def test_add_bruteforce_IP(self): + self.subtest_add(0, 480, faiss.METRIC_INNER_PRODUCT) + + def test_add_nndescent_IP(self): + self.subtest_add(1, 480, faiss.METRIC_INNER_PRODUCT) + + def test_build_L2(self): + knn_graph = self.make_knn_graph(faiss.METRIC_L2) + self.subtest_build(knn_graph, 475, faiss.METRIC_L2) + + def test_build_IP(self): + knn_graph = self.make_knn_graph(faiss.METRIC_INNER_PRODUCT) + self.subtest_build(knn_graph, 480, faiss.METRIC_INNER_PRODUCT) + + def test_build_invalid_knng(self): + """Make some invalid entries in the input knn graph. + + It would cause a warning but IndexNSG should be able + to handel this. 
+ """ + knn_graph = self.make_knn_graph(faiss.METRIC_L2) + knn_graph[:100, 5] = -111 + self.subtest_build(knn_graph, 475, faiss.METRIC_L2) + + knn_graph = self.make_knn_graph(faiss.METRIC_INNER_PRODUCT) + knn_graph[:100, 5] = -111 + self.subtest_build(knn_graph, 480, faiss.METRIC_INNER_PRODUCT) + + def test_reset(self): + """test IndexNSG.reset()""" + d = self.xq.shape[1] + metrics = {faiss.METRIC_L2: 'L2', + faiss.METRIC_INNER_PRODUCT: 'IP'} + + metric = faiss.METRIC_L2 + flat_index = faiss.IndexFlat(d, metric) + flat_index.add(self.xb) + Dref, Iref = flat_index.search(self.xq, 1) + + index = faiss.IndexNSGFlat(d, 16) + index.verbose = True + index.GK = 32 + + index.add(self.xb) + Dnsg, Insg = index.search(self.xq, 1) + recalls = (Iref == Insg).sum() + print('metric: {}, nb equal: {}'.format(metrics[metric], recalls)) + self.assertGreaterEqual(recalls, 475) + self.subtest_connectivity(index, self.xb.shape[0]) + + index.reset() + index.add(self.xb) + Dnsg, Insg = index.search(self.xq, 1) + recalls = (Iref == Insg).sum() + print('metric: {}, nb equal: {}'.format(metrics[metric], recalls)) + self.assertGreaterEqual(recalls, 475) + self.subtest_connectivity(index, self.xb.shape[0]) + + def test_order(self): + """make sure that output results are sorted""" + d = self.xq.shape[1] + index = faiss.IndexNSGFlat(d, 32) + + index.train(self.xb) + index.add(self.xb) + + k = 10 + nq = self.xq.shape[0] + D, _ = index.search(self.xq, k) + + indices = np.argsort(D, axis=1) + gt = np.arange(0, k)[np.newaxis, :] # [1, k] + gt = np.repeat(gt, nq, axis=0) # [nq, k] + np.testing.assert_array_equal(indices, gt) + + def test_nsg_pq(self): + """Test IndexNSGPQ""" + d = self.xq.shape[1] + R, pq_M = 32, 4 + index = faiss.index_factory(d, f"NSG{R}_PQ{pq_M}np") + assert isinstance(index, faiss.IndexNSGPQ) + idxpq = faiss.downcast_index(index.storage) + assert index.nsg.R == R and idxpq.pq.M == pq_M + + flat_index = faiss.IndexFlat(d) + flat_index.add(self.xb) + Dref, Iref = 
flat_index.search(self.xq, k=1) + + index.GK = 32 + index.train(self.xb) + index.add(self.xb) + D, I = index.search(self.xq, k=1) + + # test accuracy + recalls = (Iref == I).sum() + print("IndexNSGPQ", recalls) + self.assertGreaterEqual(recalls, 190) # 193 + + # test I/O + self.subtest_io_and_clone(index, D, I) + + def test_nsg_sq(self): + """Test IndexNSGSQ""" + d = self.xq.shape[1] + R = 32 + index = faiss.index_factory(d, f"NSG{R}_SQ8") + assert isinstance(index, faiss.IndexNSGSQ) + idxsq = faiss.downcast_index(index.storage) + assert index.nsg.R == R + assert idxsq.sq.qtype == faiss.ScalarQuantizer.QT_8bit + + flat_index = faiss.IndexFlat(d) + flat_index.add(self.xb) + Dref, Iref = flat_index.search(self.xq, k=1) + + index.train(self.xb) + index.add(self.xb) + D, I = index.search(self.xq, k=1) + + # test accuracy + recalls = (Iref == I).sum() + print("IndexNSGSQ", recalls) + self.assertGreaterEqual(recalls, 405) # 411 + + # test I/O + self.subtest_io_and_clone(index, D, I) + + +class TestNNDescent(unittest.TestCase): + + def __init__(self, *args, **kwargs): + unittest.TestCase.__init__(self, *args, **kwargs) + d = 32 + nt = 0 + nb = 1500 + nq = 500 + self.GK = 32 + + _, self.xb, self.xq = get_dataset_2(d, nt, nb, nq) + + def test_nndescentflat(self): + d = self.xq.shape[1] + index = faiss.IndexNNDescentFlat(d, 32) + index.nndescent.search_L = 8 + + flat_index = faiss.IndexFlat(d) + flat_index.add(self.xb) + Dref, Iref = flat_index.search(self.xq, k=1) + + index.train(self.xb) + index.add(self.xb) + D, I = index.search(self.xq, k=1) + + # test accuracy + recalls = (Iref == I).sum() + print("IndexNNDescentFlat", recalls) + self.assertGreaterEqual(recalls, 450) # 462 + + # do some IO tests + fd, tmpfile = tempfile.mkstemp() + os.close(fd) + try: + faiss.write_index(index, tmpfile) + index2 = faiss.read_index(tmpfile) + finally: + if os.path.exists(tmpfile): + os.unlink(tmpfile) + + D2, I2 = index2.search(self.xq, 1) + np.testing.assert_array_equal(D2, D) + 
np.testing.assert_array_equal(I2, I) + + # also test clone + index3 = faiss.clone_index(index) + D3, I3 = index3.search(self.xq, 1) + np.testing.assert_array_equal(D3, D) + np.testing.assert_array_equal(I3, I) + + def test_order(self): + """make sure that output results are sorted""" + d = self.xq.shape[1] + index = faiss.IndexNNDescentFlat(d, 32) + + index.train(self.xb) + index.add(self.xb) + + k = 10 + nq = self.xq.shape[0] + D, _ = index.search(self.xq, k) + + indices = np.argsort(D, axis=1) + gt = np.arange(0, k)[np.newaxis, :] # [1, k] + gt = np.repeat(gt, nq, axis=0) # [nq, k] + np.testing.assert_array_equal(indices, gt) diff --git a/tests/test_index.py b/tests/test_index.py index 0e828e08c1..f46c6a94bf 100644 --- a/tests/test_index.py +++ b/tests/test_index.py @@ -526,406 +526,6 @@ def test_IndexTransform(self): self.run_search_and_reconstruct(index, xb, xq) -class TestHNSW(unittest.TestCase): - - def __init__(self, *args, **kwargs): - unittest.TestCase.__init__(self, *args, **kwargs) - d = 32 - nt = 0 - nb = 1500 - nq = 500 - - (_, self.xb, self.xq) = get_dataset_2(d, nt, nb, nq) - index = faiss.IndexFlatL2(d) - index.add(self.xb) - Dref, Iref = index.search(self.xq, 1) - self.Iref = Iref - - def test_hnsw(self): - d = self.xq.shape[1] - - index = faiss.IndexHNSWFlat(d, 16) - index.add(self.xb) - Dhnsw, Ihnsw = index.search(self.xq, 1) - - self.assertGreaterEqual((self.Iref == Ihnsw).sum(), 460) - - self.io_and_retest(index, Dhnsw, Ihnsw) - - def test_hnsw_unbounded_queue(self): - d = self.xq.shape[1] - - index = faiss.IndexHNSWFlat(d, 16) - index.add(self.xb) - index.search_bounded_queue = False - Dhnsw, Ihnsw = index.search(self.xq, 1) - - self.assertGreaterEqual((self.Iref == Ihnsw).sum(), 460) - - self.io_and_retest(index, Dhnsw, Ihnsw) - - def io_and_retest(self, index, Dhnsw, Ihnsw): - fd, tmpfile = tempfile.mkstemp() - os.close(fd) - try: - faiss.write_index(index, tmpfile) - index2 = faiss.read_index(tmpfile) - finally: - if 
os.path.exists(tmpfile): - os.unlink(tmpfile) - - Dhnsw2, Ihnsw2 = index2.search(self.xq, 1) - - self.assertTrue(np.all(Dhnsw2 == Dhnsw)) - self.assertTrue(np.all(Ihnsw2 == Ihnsw)) - - # also test clone - index3 = faiss.clone_index(index) - Dhnsw3, Ihnsw3 = index3.search(self.xq, 1) - - self.assertTrue(np.all(Dhnsw3 == Dhnsw)) - self.assertTrue(np.all(Ihnsw3 == Ihnsw)) - - - def test_hnsw_2level(self): - d = self.xq.shape[1] - - quant = faiss.IndexFlatL2(d) - - index = faiss.IndexHNSW2Level(quant, 256, 8, 8) - index.train(self.xb) - index.add(self.xb) - Dhnsw, Ihnsw = index.search(self.xq, 1) - - self.assertGreaterEqual((self.Iref == Ihnsw).sum(), 307) - - self.io_and_retest(index, Dhnsw, Ihnsw) - - def test_add_0_vecs(self): - index = faiss.IndexHNSWFlat(10, 16) - zero_vecs = np.zeros((0, 10), dtype='float32') - # infinite loop - index.add(zero_vecs) - - def test_hnsw_IP(self): - d = self.xq.shape[1] - - index_IP = faiss.IndexFlatIP(d) - index_IP.add(self.xb) - Dref, Iref = index_IP.search(self.xq, 1) - - index = faiss.IndexHNSWFlat(d, 16, faiss.METRIC_INNER_PRODUCT) - index.add(self.xb) - Dhnsw, Ihnsw = index.search(self.xq, 1) - - print('nb equal: ', (Iref == Ihnsw).sum()) - - self.assertGreaterEqual((Iref == Ihnsw).sum(), 470) - - mask = Iref[:, 0] == Ihnsw[:, 0] - assert np.allclose(Dref[mask, 0], Dhnsw[mask, 0]) - - -class TestNSG(unittest.TestCase): - - def __init__(self, *args, **kwargs): - unittest.TestCase.__init__(self, *args, **kwargs) - d = 32 - nt = 0 - nb = 1500 - nq = 500 - self.GK = 32 - - _, self.xb, self.xq = get_dataset_2(d, nt, nb, nq) - - def make_knn_graph(self, metric): - n = self.xb.shape[0] - d = self.xb.shape[1] - index = faiss.IndexFlat(d, metric) - index.add(self.xb) - _, I = index.search(self.xb, self.GK + 1) - knn_graph = np.zeros((n, self.GK), dtype=np.int64) - - # For the inner product distance, the distance between a vector and itself - # may not be the smallest, so it is not guaranteed that I[:, 0] is the query itself. 
- for i in range(n): - cnt = 0 - for j in range(self.GK + 1): - if I[i, j] != i: - knn_graph[i, cnt] = I[i, j] - cnt += 1 - if cnt == self.GK: - break - return knn_graph - - def subtest_io_and_clone(self, index, Dnsg, Insg): - fd, tmpfile = tempfile.mkstemp() - os.close(fd) - try: - faiss.write_index(index, tmpfile) - index2 = faiss.read_index(tmpfile) - finally: - if os.path.exists(tmpfile): - os.unlink(tmpfile) - - Dnsg2, Insg2 = index2.search(self.xq, 1) - np.testing.assert_array_equal(Dnsg2, Dnsg) - np.testing.assert_array_equal(Insg2, Insg) - - # also test clone - index3 = faiss.clone_index(index) - Dnsg3, Insg3 = index3.search(self.xq, 1) - np.testing.assert_array_equal(Dnsg3, Dnsg) - np.testing.assert_array_equal(Insg3, Insg) - - def subtest_connectivity(self, index, nb): - vt = faiss.VisitedTable(nb) - count = index.nsg.dfs(vt, index.nsg.enterpoint, 0) - self.assertEqual(count, nb) - - def subtest_add(self, build_type, thresh, metric=faiss.METRIC_L2): - d = self.xq.shape[1] - metrics = {faiss.METRIC_L2: 'L2', - faiss.METRIC_INNER_PRODUCT: 'IP'} - - flat_index = faiss.IndexFlat(d, metric) - flat_index.add(self.xb) - Dref, Iref = flat_index.search(self.xq, 1) - - index = faiss.IndexNSGFlat(d, 16, metric) - index.verbose = True - index.build_type = build_type - index.GK = self.GK - index.add(self.xb) - Dnsg, Insg = index.search(self.xq, 1) - - recalls = (Iref == Insg).sum() - print('metric: {}, nb equal: {}'.format(metrics[metric], recalls)) - self.assertGreaterEqual(recalls, thresh) - self.subtest_connectivity(index, self.xb.shape[0]) - self.subtest_io_and_clone(index, Dnsg, Insg) - - def subtest_build(self, knn_graph, thresh, metric=faiss.METRIC_L2): - d = self.xq.shape[1] - metrics = {faiss.METRIC_L2: 'L2', - faiss.METRIC_INNER_PRODUCT: 'IP'} - - flat_index = faiss.IndexFlat(d, metric) - flat_index.add(self.xb) - Dref, Iref = flat_index.search(self.xq, 1) - - index = faiss.IndexNSGFlat(d, 16, metric) - index.verbose = True - - index.build(self.xb, 
knn_graph) - Dnsg, Insg = index.search(self.xq, 1) - - recalls = (Iref == Insg).sum() - print('metric: {}, nb equal: {}'.format(metrics[metric], recalls)) - self.assertGreaterEqual(recalls, thresh) - self.subtest_connectivity(index, self.xb.shape[0]) - - def test_add_bruteforce_L2(self): - self.subtest_add(0, 475, faiss.METRIC_L2) - - def test_add_nndescent_L2(self): - self.subtest_add(1, 475, faiss.METRIC_L2) - - def test_add_bruteforce_IP(self): - self.subtest_add(0, 480, faiss.METRIC_INNER_PRODUCT) - - def test_add_nndescent_IP(self): - self.subtest_add(1, 480, faiss.METRIC_INNER_PRODUCT) - - def test_build_L2(self): - knn_graph = self.make_knn_graph(faiss.METRIC_L2) - self.subtest_build(knn_graph, 475, faiss.METRIC_L2) - - def test_build_IP(self): - knn_graph = self.make_knn_graph(faiss.METRIC_INNER_PRODUCT) - self.subtest_build(knn_graph, 480, faiss.METRIC_INNER_PRODUCT) - - def test_build_invalid_knng(self): - """Make some invalid entries in the input knn graph. - - It would cause a warning but IndexNSG should be able - to handel this. 
- """ - knn_graph = self.make_knn_graph(faiss.METRIC_L2) - knn_graph[:100, 5] = -111 - self.subtest_build(knn_graph, 475, faiss.METRIC_L2) - - knn_graph = self.make_knn_graph(faiss.METRIC_INNER_PRODUCT) - knn_graph[:100, 5] = -111 - self.subtest_build(knn_graph, 480, faiss.METRIC_INNER_PRODUCT) - - def test_reset(self): - """test IndexNSG.reset()""" - d = self.xq.shape[1] - metrics = {faiss.METRIC_L2: 'L2', - faiss.METRIC_INNER_PRODUCT: 'IP'} - - metric = faiss.METRIC_L2 - flat_index = faiss.IndexFlat(d, metric) - flat_index.add(self.xb) - Dref, Iref = flat_index.search(self.xq, 1) - - index = faiss.IndexNSGFlat(d, 16) - index.verbose = True - index.GK = 32 - - index.add(self.xb) - Dnsg, Insg = index.search(self.xq, 1) - recalls = (Iref == Insg).sum() - print('metric: {}, nb equal: {}'.format(metrics[metric], recalls)) - self.assertGreaterEqual(recalls, 475) - self.subtest_connectivity(index, self.xb.shape[0]) - - index.reset() - index.add(self.xb) - Dnsg, Insg = index.search(self.xq, 1) - recalls = (Iref == Insg).sum() - print('metric: {}, nb equal: {}'.format(metrics[metric], recalls)) - self.assertGreaterEqual(recalls, 475) - self.subtest_connectivity(index, self.xb.shape[0]) - - def test_order(self): - """make sure that output results are sorted""" - d = self.xq.shape[1] - index = faiss.IndexNSGFlat(d, 32) - - index.train(self.xb) - index.add(self.xb) - - k = 10 - nq = self.xq.shape[0] - D, _ = index.search(self.xq, k) - - indices = np.argsort(D, axis=1) - gt = np.arange(0, k)[np.newaxis, :] # [1, k] - gt = np.repeat(gt, nq, axis=0) # [nq, k] - np.testing.assert_array_equal(indices, gt) - - def test_nsg_pq(self): - """Test IndexNSGPQ""" - d = self.xq.shape[1] - R, pq_M = 32, 4 - index = faiss.index_factory(d, f"NSG{R}_PQ{pq_M}np") - assert isinstance(index, faiss.IndexNSGPQ) - idxpq = faiss.downcast_index(index.storage) - assert index.nsg.R == R and idxpq.pq.M == pq_M - - flat_index = faiss.IndexFlat(d) - flat_index.add(self.xb) - Dref, Iref = 
flat_index.search(self.xq, k=1) - - index.GK = 32 - index.train(self.xb) - index.add(self.xb) - D, I = index.search(self.xq, k=1) - - # test accuracy - recalls = (Iref == I).sum() - print("IndexNSGPQ", recalls) - self.assertGreaterEqual(recalls, 190) # 193 - - # test I/O - self.subtest_io_and_clone(index, D, I) - - def test_nsg_sq(self): - """Test IndexNSGSQ""" - d = self.xq.shape[1] - R = 32 - index = faiss.index_factory(d, f"NSG{R}_SQ8") - assert isinstance(index, faiss.IndexNSGSQ) - idxsq = faiss.downcast_index(index.storage) - assert index.nsg.R == R - assert idxsq.sq.qtype == faiss.ScalarQuantizer.QT_8bit - - flat_index = faiss.IndexFlat(d) - flat_index.add(self.xb) - Dref, Iref = flat_index.search(self.xq, k=1) - - index.train(self.xb) - index.add(self.xb) - D, I = index.search(self.xq, k=1) - - # test accuracy - recalls = (Iref == I).sum() - print("IndexNSGSQ", recalls) - self.assertGreaterEqual(recalls, 405) # 411 - - # test I/O - self.subtest_io_and_clone(index, D, I) - - -class TestNNDescent(unittest.TestCase): - - def __init__(self, *args, **kwargs): - unittest.TestCase.__init__(self, *args, **kwargs) - d = 32 - nt = 0 - nb = 1500 - nq = 500 - self.GK = 32 - - _, self.xb, self.xq = get_dataset_2(d, nt, nb, nq) - - def test_nndescentflat(self): - d = self.xq.shape[1] - index = faiss.IndexNNDescentFlat(d, 32) - index.nndescent.search_L = 8 - - flat_index = faiss.IndexFlat(d) - flat_index.add(self.xb) - Dref, Iref = flat_index.search(self.xq, k=1) - - index.train(self.xb) - index.add(self.xb) - D, I = index.search(self.xq, k=1) - - # test accuracy - recalls = (Iref == I).sum() - print("IndexNNDescentFlat", recalls) - self.assertGreaterEqual(recalls, 450) # 462 - - # do some IO tests - fd, tmpfile = tempfile.mkstemp() - os.close(fd) - try: - faiss.write_index(index, tmpfile) - index2 = faiss.read_index(tmpfile) - finally: - if os.path.exists(tmpfile): - os.unlink(tmpfile) - - D2, I2 = index2.search(self.xq, 1) - np.testing.assert_array_equal(D2, D) - 
np.testing.assert_array_equal(I2, I) - - # also test clone - index3 = faiss.clone_index(index) - D3, I3 = index3.search(self.xq, 1) - np.testing.assert_array_equal(D3, D) - np.testing.assert_array_equal(I3, I) - - def test_order(self): - """make sure that output results are sorted""" - d = self.xq.shape[1] - index = faiss.IndexNNDescentFlat(d, 32) - - index.train(self.xb) - index.add(self.xb) - - k = 10 - nq = self.xq.shape[0] - D, _ = index.search(self.xq, k) - - indices = np.argsort(D, axis=1) - gt = np.arange(0, k)[np.newaxis, :] # [1, k] - gt = np.repeat(gt, nq, axis=0) # [nq, k] - np.testing.assert_array_equal(indices, gt) - class TestDistancesPositive(unittest.TestCase):