Skip to content

Commit

Permalink
Relax IVFFlatDedup test (facebookresearch#3077)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: facebookresearch#3077

This diff relaxes some IVFFlatDedup tests where distances are slighlty different over runs.
Should fix

https://app.circleci.com/pipelines/github/facebookresearch/faiss/4709/workflows/8c8213bf-8fe0-4c4e-9a7d-991f44bf1010/jobs/25551

https://app.circleci.com/pipelines/github/facebookresearch/faiss/4709/workflows/8c8213bf-8fe0-4c4e-9a7d-991f44bf1010/jobs/25547

Reviewed By: algoriddle

Differential Revision: D49732349

fbshipit-source-id: 728b9885c6b7d6ba697ccb6bacc0abd0ee2b0679
  • Loading branch information
mdouze authored and facebook-github-bot committed Sep 29, 2023
1 parent 0f18251 commit 9db1824
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 28 deletions.
36 changes: 27 additions & 9 deletions contrib/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,23 +226,41 @@ def compute_PR_for(q):
# Functions that compare search results with a reference result.
# They are intended for use in tests

def check_ref_knn_with_draws(Dref, Iref, Dnew, Inew):
""" test that knn search results are identical, raise if not """
np.testing.assert_array_almost_equal(Dref, Dnew, decimal=5)
def _cluster_tables_with_tolerance(tab1, tab2, thr):
""" for two tables, cluster them by merging values closer than thr.
Returns the cluster ids for each table element """
tab = np.hstack([tab1, tab2])
tab.sort()
n = len(tab)
diffs = np.ones(n)
diffs[1:] = tab[1:] - tab[:-1]
unique_vals = tab[diffs > thr]
idx1 = np.searchsorted(unique_vals, tab1, side='right') - 1
idx2 = np.searchsorted(unique_vals, tab2, side='right') - 1
return idx1, idx2


def check_ref_knn_with_draws(Dref, Iref, Dnew, Inew, rtol=1e-5):
""" test that knn search results are identical, with possible ties.
Raise if not. """
np.testing.assert_allclose(Dref, Dnew, rtol=rtol)
# here we have to be careful because of draws
testcase = unittest.TestCase() # because it makes nice error messages
for i in range(len(Iref)):
if np.all(Iref[i] == Inew[i]): # easy case
continue
# we can deduce nothing about the latest line
skip_dis = Dref[i, -1]
for dis in np.unique(Dref):
if dis == skip_dis:

# otherwise collect elements per distance
r = rtol * Dref[i].max()

DrefC, DnewC = _cluster_tables_with_tolerance(Dref[i], Dnew[i], r)

for dis in np.unique(DrefC):
if dis == DrefC[-1]:
continue
mask = Dref[i, :] == dis
mask = DrefC == dis
testcase.assertEqual(set(Iref[i, mask]), set(Inew[i, mask]))


def check_ref_range_results(Lref, Dref, Iref,
Lnew, Dnew, Inew):
""" compare range search results wrt. a reference result,
Expand Down
23 changes: 4 additions & 19 deletions tests/test_index_composite.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from common_faiss_tests import get_dataset_2
from faiss.contrib.datasets import SyntheticDataset
from faiss.contrib.inspect_tools import make_LinearTransform_matrix

from faiss.contrib.evaluation import check_ref_knn_with_draws

class TestRemoveFastScan(unittest.TestCase):
def do_test(self, ntotal, removed):
Expand Down Expand Up @@ -430,12 +430,6 @@ def test_mmappedIO_pretrans(self):

class TestIVFFlatDedup(unittest.TestCase):

def normalize_res(self, D, I):
dmax = D[-1]
res = [(d, i) for d, i in zip(D, I) if d < dmax]
res.sort()
return res

def test_dedup(self):
d = 10
nb = 1000
Expand Down Expand Up @@ -471,10 +465,7 @@ def test_dedup(self):
Dref, Iref = index_ref.search(xq, 20)
Dnew, Inew = index_new.search(xq, 20)

for i in range(nq):
ref = self.normalize_res(Dref[i], Iref[i])
new = self.normalize_res(Dnew[i], Inew[i])
assert ref == new
check_ref_knn_with_draws(Dref, Iref, Dnew, Inew)

# test I/O
fd, tmpfile = tempfile.mkstemp()
Expand All @@ -487,10 +478,7 @@ def test_dedup(self):
os.unlink(tmpfile)
Dst, Ist = index_st.search(xq, 20)

for i in range(nq):
new = self.normalize_res(Dnew[i], Inew[i])
st = self.normalize_res(Dst[i], Ist[i])
assert st == new
check_ref_knn_with_draws(Dnew, Inew, Dst, Ist)

# test remove
toremove = np.hstack((np.arange(3, 1000, 5), np.arange(850, 950)))
Expand All @@ -501,10 +489,7 @@ def test_dedup(self):
Dref, Iref = index_ref.search(xq, 20)
Dnew, Inew = index_new.search(xq, 20)

for i in range(nq):
ref = self.normalize_res(Dref[i], Iref[i])
new = self.normalize_res(Dnew[i], Inew[i])
assert ref == new
check_ref_knn_with_draws(Dref, Iref, Dnew, Inew)


class TestSerialize(unittest.TestCase):
Expand Down

0 comments on commit 9db1824

Please sign in to comment.