diff --git a/contrib/evaluation.py b/contrib/evaluation.py index 9c762b6081..1f4068734e 100644 --- a/contrib/evaluation.py +++ b/contrib/evaluation.py @@ -226,23 +226,41 @@ def compute_PR_for(q): # Functions that compare search results with a reference result. # They are intended for use in tests -def check_ref_knn_with_draws(Dref, Iref, Dnew, Inew): - """ test that knn search results are identical, raise if not """ - np.testing.assert_array_almost_equal(Dref, Dnew, decimal=5) +def _cluster_tables_with_tolerance(tab1, tab2, thr): + """ for two tables, cluster them by merging values closer than thr. + Returns the cluster ids for each table element """ + tab = np.hstack([tab1, tab2]) + tab.sort() + n = len(tab) + diffs = np.ones(n) + diffs[1:] = tab[1:] - tab[:-1] + unique_vals = tab[diffs > thr] + idx1 = np.searchsorted(unique_vals, tab1, side='right') - 1 + idx2 = np.searchsorted(unique_vals, tab2, side='right') - 1 + return idx1, idx2 + + +def check_ref_knn_with_draws(Dref, Iref, Dnew, Inew, rtol=1e-5): + """ test that knn search results are identical, with possible ties. + Raise if not. """ + np.testing.assert_allclose(Dref, Dnew, rtol=rtol) # here we have to be careful because of draws testcase = unittest.TestCase() # because it makes nice error messages for i in range(len(Iref)): if np.all(Iref[i] == Inew[i]): # easy case continue - # we can deduce nothing about the latest line - skip_dis = Dref[i, -1] - for dis in np.unique(Dref): - if dis == skip_dis: + + # otherwise collect elements per distance + r = rtol * Dref[i].max() + + DrefC, DnewC = _cluster_tables_with_tolerance(Dref[i], Dnew[i], r) + + for dis in np.unique(DrefC): + if dis == DrefC[-1]: continue - mask = Dref[i, :] == dis + mask = DrefC == dis testcase.assertEqual(set(Iref[i, mask]), set(Inew[i, mask])) - def check_ref_range_results(Lref, Dref, Iref, Lnew, Dnew, Inew): """ compare range search results wrt. a reference result, diff --git a/tests/test_index_composite.py b/tests/test_index_composite.py index d4f99b92d0..81a00cb938 100644 --- a/tests/test_index_composite.py +++ b/tests/test_index_composite.py @@ -17,7 +17,7 @@ from common_faiss_tests import get_dataset_2 from faiss.contrib.datasets import SyntheticDataset from faiss.contrib.inspect_tools import make_LinearTransform_matrix - +from faiss.contrib.evaluation import check_ref_knn_with_draws class TestRemoveFastScan(unittest.TestCase): def do_test(self, ntotal, removed): @@ -430,12 +430,6 @@ def test_mmappedIO_pretrans(self): class TestIVFFlatDedup(unittest.TestCase): - def normalize_res(self, D, I): - dmax = D[-1] - res = [(d, i) for d, i in zip(D, I) if d < dmax] - res.sort() - return res - def test_dedup(self): d = 10 nb = 1000 @@ -471,10 +465,7 @@ def test_dedup(self): Dref, Iref = index_ref.search(xq, 20) Dnew, Inew = index_new.search(xq, 20) - for i in range(nq): - ref = self.normalize_res(Dref[i], Iref[i]) - new = self.normalize_res(Dnew[i], Inew[i]) - assert ref == new + check_ref_knn_with_draws(Dref, Iref, Dnew, Inew) # test I/O fd, tmpfile = tempfile.mkstemp() @@ -487,10 +478,7 @@ def test_dedup(self): os.unlink(tmpfile) Dst, Ist = index_st.search(xq, 20) - for i in range(nq): - new = self.normalize_res(Dnew[i], Inew[i]) - st = self.normalize_res(Dst[i], Ist[i]) - assert st == new + check_ref_knn_with_draws(Dnew, Inew, Dst, Ist) # test remove toremove = np.hstack((np.arange(3, 1000, 5), np.arange(850, 950))) @@ -501,10 +489,7 @@ def test_dedup(self): Dref, Iref = index_ref.search(xq, 20) Dnew, Inew = index_new.search(xq, 20) - for i in range(nq): - ref = self.normalize_res(Dref[i], Iref[i]) - new = self.normalize_res(Dnew[i], Inew[i]) - assert ref == new + check_ref_knn_with_draws(Dref, Iref, Dnew, Inew) class TestSerialize(unittest.TestCase):