Skip to content

Commit

Permalink
Generalize ResultHanlder, support range search for HNSW and Fast Scan (
Browse files Browse the repository at this point in the history
…facebookresearch#3190)

Summary:
Pull Request resolved: facebookresearch#3190

This diff adds more result handlers in order to expose them externally.
This enables range search for HSNW and Fast Scan, and nprobe parameter support for FastScan.

Reviewed By: pemazare

Differential Revision: D52547384

fbshipit-source-id: 271da5ffea6411df3d8e50641abade18bd7b774b
  • Loading branch information
mdouze authored and facebook-github-bot committed Jan 11, 2024
1 parent 0013c70 commit 32f0e8c
Show file tree
Hide file tree
Showing 38 changed files with 1,995 additions and 2,015 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ We try to indicate most contributions here with the contributor names who are no
the Facebook Faiss team. Feel free to add entries here if you submit a PR.

## [Unreleased]
- Support for range search in HNSW and Fast scan IVF.
## [1.7.4] - 2023-04-12
### Added
- Added big batch IVF search for conducting efficient search with big batches of queries
Expand Down
2 changes: 2 additions & 0 deletions benchs/link_and_code/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ The code runs on top of Faiss. The HNSW index can be extended with a
`ReconstructFromNeighbors` C++ object that refines the distances. The
training is implemented in Python.

Update: 2023-12-28: the current Faiss dropped support for reconstruction with
this method.

Reproducing Table 2 in the paper
--------------------------------
Expand Down
1 change: 1 addition & 0 deletions contrib/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,7 @@ def check_ref_knn_with_draws(Dref, Iref, Dnew, Inew, rtol=1e-5):
mask = DrefC == dis
testcase.assertEqual(set(Iref[i, mask]), set(Inew[i, mask]))


def check_ref_range_results(Lref, Dref, Iref,
Lnew, Dnew, Inew):
""" compare range search results wrt. a reference result,
Expand Down
27 changes: 17 additions & 10 deletions faiss/IndexAdditiveQuantizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,18 +114,19 @@ struct AQDistanceComputerLUT : FlatCodesDistanceComputer {
* scanning implementation for search
************************************************************/

template <class VectorDistance, class ResultHandler>
template <class VectorDistance, class BlockResultHandler>
void search_with_decompress(
const IndexAdditiveQuantizer& ir,
const float* xq,
VectorDistance& vd,
ResultHandler& res) {
BlockResultHandler& res) {
const uint8_t* codes = ir.codes.data();
size_t ntotal = ir.ntotal;
size_t code_size = ir.code_size;
const AdditiveQuantizer* aq = ir.aq;

using SingleResultHandler = typename ResultHandler::SingleResultHandler;
using SingleResultHandler =
typename BlockResultHandler::SingleResultHandler;

#pragma omp parallel for if (res.nq > 100)
for (int64_t q = 0; q < res.nq; q++) {
Expand All @@ -142,19 +143,23 @@ void search_with_decompress(
}
}

template <bool is_IP, AdditiveQuantizer::Search_type_t st, class ResultHandler>
template <
bool is_IP,
AdditiveQuantizer::Search_type_t st,
class BlockResultHandler>
void search_with_LUT(
const IndexAdditiveQuantizer& ir,
const float* xq,
ResultHandler& res) {
BlockResultHandler& res) {
const AdditiveQuantizer& aq = *ir.aq;
const uint8_t* codes = ir.codes.data();
size_t ntotal = ir.ntotal;
size_t code_size = aq.code_size;
size_t nq = res.nq;
size_t d = ir.d;

using SingleResultHandler = typename ResultHandler::SingleResultHandler;
using SingleResultHandler =
typename BlockResultHandler::SingleResultHandler;
std::unique_ptr<float[]> LUT(new float[nq * aq.total_codebook_size]);

aq.compute_LUT(nq, xq, LUT.get());
Expand Down Expand Up @@ -241,21 +246,23 @@ void IndexAdditiveQuantizer::search(
if (metric_type == METRIC_L2) {
using VD = VectorDistance<METRIC_L2>;
VD vd = {size_t(d), metric_arg};
HeapResultHandler<VD::C> rh(n, distances, labels, k);
HeapBlockResultHandler<VD::C> rh(n, distances, labels, k);
search_with_decompress(*this, x, vd, rh);
} else if (metric_type == METRIC_INNER_PRODUCT) {
using VD = VectorDistance<METRIC_INNER_PRODUCT>;
VD vd = {size_t(d), metric_arg};
HeapResultHandler<VD::C> rh(n, distances, labels, k);
HeapBlockResultHandler<VD::C> rh(n, distances, labels, k);
search_with_decompress(*this, x, vd, rh);
}
} else {
if (metric_type == METRIC_INNER_PRODUCT) {
HeapResultHandler<CMin<float, idx_t>> rh(n, distances, labels, k);
HeapBlockResultHandler<CMin<float, idx_t>> rh(
n, distances, labels, k);
search_with_LUT<true, AdditiveQuantizer::ST_LUT_nonorm>(
*this, x, rh);
} else {
HeapResultHandler<CMax<float, idx_t>> rh(n, distances, labels, k);
HeapBlockResultHandler<CMax<float, idx_t>> rh(
n, distances, labels, k);
switch (aq->search_type) {
#define DISPATCH(st) \
case AdditiveQuantizer::st: \
Expand Down
4 changes: 2 additions & 2 deletions faiss/IndexAdditiveQuantizerFastScan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -203,9 +203,9 @@ void IndexAdditiveQuantizerFastScan::search(

NormTableScaler scaler(norm_scale);
if (metric_type == METRIC_L2) {
search_dispatch_implem<true>(n, x, k, distances, labels, scaler);
search_dispatch_implem<true>(n, x, k, distances, labels, &scaler);
} else {
search_dispatch_implem<false>(n, x, k, distances, labels, scaler);
search_dispatch_implem<false>(n, x, k, distances, labels, &scaler);
}
}

Expand Down
23 changes: 13 additions & 10 deletions faiss/IndexBinaryHNSW.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
* LICENSE file in the root directory of this source tree.
*/

// -*- c++ -*-

#include <faiss/IndexBinaryHNSW.h>

#include <omp.h>
Expand All @@ -28,6 +26,7 @@
#include <faiss/impl/AuxIndexStructures.h>
#include <faiss/impl/DistanceComputer.h>
#include <faiss/impl/FaissAssert.h>
#include <faiss/impl/ResultHandler.h>
#include <faiss/utils/Heap.h>
#include <faiss/utils/hamming.h>
#include <faiss/utils/random.h>
Expand Down Expand Up @@ -201,27 +200,31 @@ void IndexBinaryHNSW::search(
!params, "search params not supported for this index");
FAISS_THROW_IF_NOT(k > 0);

// we use the buffer for distances as float but convert them back
// to int in the end
float* distances_f = (float*)distances;

using RH = HeapBlockResultHandler<HNSW::C>;
RH bres(n, distances_f, labels, k);

#pragma omp parallel
{
VisitedTable vt(ntotal);
std::unique_ptr<DistanceComputer> dis(get_distance_computer());
RH::SingleResultHandler res(bres);

#pragma omp for
for (idx_t i = 0; i < n; i++) {
idx_t* idxi = labels + i * k;
float* simi = (float*)(distances + i * k);

res.begin(i);
dis->set_query((float*)(x + i * code_size));

maxheap_heapify(k, simi, idxi);
hnsw.search(*dis, k, idxi, simi, vt);
maxheap_reorder(k, simi, idxi);
hnsw.search(*dis, res, vt);
res.end();
}
}

#pragma omp parallel for
for (int i = 0; i < n * k; ++i) {
distances[i] = std::round(((float*)distances)[i]);
distances[i] = std::round(distances_f[i]);
}
}

Expand Down
Loading

0 comments on commit 32f0e8c

Please sign in to comment.