diff --git a/demos/offline_ivf/README.md b/demos/offline_ivf/README.md
new file mode 100644
index 0000000000..df848ba0ab
--- /dev/null
+++ b/demos/offline_ivf/README.md
@@ -0,0 +1,52 @@
+
+# Offline IVF
+
+This folder contains the code for the offline IVF algorithm, powered by Faiss big batch search.
+
+Create a conda env:
+
+`conda create --name oivf python=3.10`
+
+`conda activate oivf`
+
+`conda install -c pytorch/label/nightly -c nvidia faiss-gpu=1.7.4`
+
+`conda install tqdm`
+
+`conda install pyyaml`
+
+`conda install -c conda-forge submitit`
+
+
+## Run book
+
+1. Optionally shard your dataset (see `create_sharded_ssnpp_files.py`) and create the corresponding yaml file, `config_ssnpp.yaml`. You can use `generate_config.py` to generate it, specifying the root directory of your dataset and the files with the data shards:
+
+`python generate_config.py`
+
+2. Run the train index command:
+
+`python run.py --command train_index --config config_ssnpp.yaml --xb ssnpp_1B`
+
+
+3. Run the index-shard command so that it produces the sharded indexes required for the search step:
+
+`python run.py --command index_shard --config config_ssnpp.yaml --xb ssnpp_1B`
+
+
+4. Send jobs to the cluster to run search:
+
+`python run.py --command search --config config_ssnpp.yaml --xb ssnpp_1B --cluster_run --partition <partition-name>`
+
+
+Remarks about the `search` command: it is assumed that the database vectors are also the query vectors when performing the search step.
+a. If the query vectors are different from the database vectors, they should be passed via the `--xq` argument.
+b. A new dataset needs to be prepared (step 1) before passing it to the query vectors argument `--xq`:
+
+`python run.py --command search --config config_ssnpp.yaml --xb ssnpp_1B --xq <queries-dataset-name>`
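+
+For reference, a query dataset is declared in the config file in the same way as the database dataset. A minimal sketch (the paths, names and sizes below are illustrative only):
+
+```yaml
+datasets:
+  my_queries:
+    root: /path/to/query/shards
+    size: 10000000
+    files:
+      - dtype: uint8
+        format: npy
+        name: queries_0000000000.npy
+        size: 10000000
+```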
+
+5. We can always run the consistency-check for sanity checks!
+
+`python run.py --command consistency_check --config config_ssnpp.yaml --xb ssnpp_1B`
+
diff --git a/demos/offline_ivf/__init__.py b/demos/offline_ivf/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/demos/offline_ivf/config_ssnpp.yaml b/demos/offline_ivf/config_ssnpp.yaml
new file mode 100644
index 0000000000..690f0de156
--- /dev/null
+++ b/demos/offline_ivf/config_ssnpp.yaml
@@ -0,0 +1,109 @@
+d: 256
+output: /checkpoint/marialomeli/offline_faiss/ssnpp
+index:
+  prod:
+    - 'IVF8192,PQ128'
+  non-prod:
+    - 'IVF16384,PQ128'
+    - 'IVF32768,PQ128'
+nprobe:
+  prod:
+    - 512
+  non-prod:
+    - 256
+    - 128
+    - 1024
+    - 2048
+    - 4096
+    - 8192
+
+k: 50
+index_shard_size: 50000000
+query_batch_size: 50000000
+evaluation_sample: 10000
+training_sample: 1572864
+datasets:
+  ssnpp_1B:
+    root: /checkpoint/marialomeli/ssnpp_data
+    size: 1000000000
+    files:
+      - dtype: uint8
+        format: npy
+        name: ssnpp_0000000000.npy
+        size: 50000000
+      - dtype: uint8
+        format: npy
+        name: ssnpp_0000000001.npy
+        size: 50000000
+      - dtype: uint8
+        format: npy
+        name: ssnpp_0000000002.npy
+        size: 50000000
+      - dtype: uint8
+        format: npy
+        name: ssnpp_0000000003.npy
+        size: 50000000
+      - dtype: uint8
+        format: npy
+        name: ssnpp_0000000004.npy
+        size: 50000000
+      - dtype: uint8
+        format: npy
+        name: ssnpp_0000000005.npy
+        size: 50000000
+      - dtype: uint8
+        format: npy
+        name: ssnpp_0000000006.npy
+        size: 50000000
+      - dtype: uint8
+        format: npy
+        name: ssnpp_0000000007.npy
+        size: 50000000
+      - dtype: uint8
+        format: npy
+        name: ssnpp_0000000008.npy
+        size: 50000000
+      - dtype: uint8
+        format: npy
+        name: ssnpp_0000000009.npy
+        size: 50000000
+      - dtype: uint8
+        format: npy
+        name: ssnpp_0000000010.npy
+        size: 50000000
+      - dtype: uint8
+        format: npy
+        name: ssnpp_0000000011.npy
+        size: 50000000
+      - dtype: uint8
+        format: npy
+        name: ssnpp_0000000012.npy
+        size: 50000000
+      - dtype: uint8
+        format: npy
+        name: ssnpp_0000000013.npy
+        size: 50000000
+      - dtype: uint8
+        format: npy
+        name: ssnpp_0000000014.npy
+        size: 50000000
+      - dtype: uint8
+        format: npy
+        name: ssnpp_0000000015.npy
+        size: 50000000
+      - dtype: uint8
+        format: npy
+        name: ssnpp_0000000016.npy
+        size: 50000000
+      - dtype: uint8
+        format: npy
+        name: ssnpp_0000000017.npy
+        size: 50000000
+      - dtype: uint8
+        format: npy
+        name: ssnpp_0000000018.npy
+        size: 50000000
+      - dtype: uint8
+        format: npy
+        name: ssnpp_0000000019.npy
+        size: 50000000
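+
+# Note on how run.py consumes the 'prod'/'non-prod' lists above: local runs
+# (without --cluster_run) use the last 'prod' entry of 'index' and 'nprobe';
+# with --cluster_run, a job is launched for every index string (prod and
+# non-prod), and the 'evaluate' command additionally sweeps every nprobe value.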
diff --git a/demos/offline_ivf/create_sharded_ssnpp_files.py b/demos/offline_ivf/create_sharded_ssnpp_files.py
new file mode 100644
index 0000000000..1dd22d2be8
--- /dev/null
+++ b/demos/offline_ivf/create_sharded_ssnpp_files.py
@@ -0,0 +1,63 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import numpy as np
+import argparse
+import os
+
+
+def xbin_mmap(fname, dtype, maxn=-1):
+    """
+    Code from
+    https://github.com/harsha-simhadri/big-ann-benchmarks/blob/main/benchmark/dataset_io.py#L94
+    mmap the competition file format for a given type of items
+    """
+    n, d = map(int, np.fromfile(fname, dtype="uint32", count=2))
+    assert os.stat(fname).st_size == 8 + n * d * np.dtype(dtype).itemsize
+    if maxn > 0:
+        n = min(n, maxn)
+    return np.memmap(fname, dtype=dtype, mode="r", offset=8, shape=(n, d))
+
+
+def main(args: argparse.Namespace):
+    ssnpp_data = xbin_mmap(fname=args.filepath, dtype="uint8")
+    num_batches = ssnpp_data.shape[0] // args.data_batch
+    assert (
+        ssnpp_data.shape[0] % args.data_batch == 0
+    ), "num of embeddings per file should divide total num of embeddings"
+    for i in range(num_batches):
+        xb_batch = ssnpp_data[
+            i * args.data_batch:(i + 1) * args.data_batch, :
+        ]
+        filename = args.output_dir + f"/ssnpp_{(i):010}.npy"
+        np.save(filename, xb_batch)
+        print(f"File {filename} is saved!")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--data_batch",
+        dest="data_batch",
+        type=int,
+        default=50000000,
+        help="Number of embeddings per file, should be a divisor of 1B",
+    )
+    parser.add_argument(
+        "--filepath",
+        dest="filepath",
+        type=str,
+        default="/datasets01/big-ann-challenge-data/FB_ssnpp/FB_ssnpp_database.u8bin",
+        help="path of 1B ssnpp database vectors' original file",
+    )
+    parser.add_argument(
+        "--output_dir",
+        dest="output_dir",
+        type=str,
+        default="/checkpoint/marialomeli/ssnpp_data",
+        help="path to put sharded files",
+    )
+
+    args = parser.parse_args()
+    main(args)
diff --git a/demos/offline_ivf/dataset.py b/demos/offline_ivf/dataset.py
new file mode 100644
index 0000000000..f9e30009c5
--- /dev/null
+++ b/demos/offline_ivf/dataset.py
@@ -0,0 +1,173 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import numpy as np
+import faiss
+from typing import List
+import random
+import logging
+from functools import lru_cache
+
+
+def create_dataset_from_oivf_config(cfg, ds_name):
+    normalise = cfg["normalise"] if "normalise" in cfg else False
+    return MultiFileVectorDataset(
+        cfg["datasets"][ds_name]["root"],
+        [
+            FileDescriptor(
+                f["name"], f["format"], np.dtype(f["dtype"]), f["size"]
+            )
+            for f in cfg["datasets"][ds_name]["files"]
+        ],
+        cfg["d"],
+        normalise,
+        cfg["datasets"][ds_name]["size"],
+    )
+
+
+@lru_cache(maxsize=100)
+def _memmap_vecs(
+    file_name: str, format: str, dtype: np.dtype, size: int, d: int
+) -> np.array:
+    """
+    If the file is in raw format, the file size will
+    be divisible by the dimensionality and by the size
+    of the data type.
+    Otherwise, the file contains a header and we assume
+    it is of .npy type. It then returns the memmapped file.
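+    For example, a raw file holding `size` vectors of dimension `d` is exactly
+    size * d * dtype.itemsize bytes, whereas an .npy file is slightly larger
+    because of the numpy header.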
+ """ + + assert os.path.exists(file_name), f"file does not exist {file_name}" + if format == "raw": + fl = os.path.getsize(file_name) + nb = fl // d // dtype.itemsize + assert nb == size, f"{nb} is different than config's {size}" + assert fl == d * dtype.itemsize * nb # no header + return np.memmap(file_name, shape=(nb, d), dtype=dtype, mode="r") + elif format == "npy": + vecs = np.load(file_name, mmap_mode="r") + assert vecs.shape[0] == size, f"size:{size},shape {vecs.shape[0]}" + assert vecs.shape[1] == d + assert vecs.dtype == dtype + return vecs + else: + ValueError("The file cannot be loaded in the current format.") + + +class FileDescriptor: + def __init__(self, name: str, format: str, dtype: np.dtype, size: int): + self.name = name + self.format = format + self.dtype = dtype + self.size = size + + +class MultiFileVectorDataset: + def __init__( + self, + root: str, + file_descriptors: List[FileDescriptor], + d: int, + normalize: bool, + size: int, + ): + assert os.path.exists(root) + self.root = root + self.file_descriptors = file_descriptors + self.d = d + self.normalize = normalize + self.size = size + self.file_offsets = [0] + t = 0 + for f in self.file_descriptors: + xb = _memmap_vecs( + f"{self.root}/{f.name}", f.format, f.dtype, f.size, self.d + ) + t += xb.shape[0] + self.file_offsets.append(t) + assert ( + t == self.size + ), "the sum of num of embeddings per file!=total num of embeddings" + + def iterate(self, start: int, batch_size: int, dt: np.dtype): + buffer = np.empty(shape=(batch_size, self.d), dtype=dt) + rem = 0 + for f in self.file_descriptors: + if start >= f.size: + start -= f.size + continue + logging.info(f"processing: {f.name}...") + xb = _memmap_vecs( + f"{self.root}/{f.name}", + f.format, + f.dtype, + f.size, + self.d, + ) + if start > 0: + xb = xb[start:] + start = 0 + req = min(batch_size - rem, xb.shape[0]) + buffer[rem:rem + req] = xb[:req] + rem += req + if rem == batch_size: + if self.normalize: + faiss.normalize_L2(buffer) + yield buffer.copy() + rem = 0 + for i in range(req, xb.shape[0], batch_size): + j = i + batch_size + if j <= xb.shape[0]: + tmp = xb[i:j].astype(dt) + if self.normalize: + faiss.normalize_L2(tmp) + yield tmp + else: + rem = xb.shape[0] - i + buffer[:rem] = xb[i:j] + if rem > 0: + tmp = buffer[:rem] + if self.normalize: + faiss.normalize_L2(tmp) + yield tmp + + def get(self, idx: List[int]): + n = len(idx) + fidx = np.searchsorted(self.file_offsets, idx, "right") + res = np.empty(shape=(len(idx), self.d), dtype=np.float32) + for r, id, fid in zip(range(n), idx, fidx): + assert fid > 0 and fid <= len(self.file_descriptors), f"{fid}" + f = self.file_descriptors[fid - 1] + # deferring normalization until after reading the vec + vecs = _memmap_vecs( + f"{self.root}/{f.name}", f.format, f.dtype, f.size, self.d + ) + i = id - self.file_offsets[fid - 1] + assert i >= 0 and i < vecs.shape[0] + res[r, :] = vecs[i] # TODO: find a faster way + if self.normalize: + faiss.normalize_L2(res) + return res + + def sample(self, n, idx_fn, vecs_fn): + if vecs_fn and os.path.exists(vecs_fn): + vecs = np.load(vecs_fn) + assert vecs.shape == (n, self.d) + return vecs + if idx_fn and os.path.exists(idx_fn): + idx = np.load(idx_fn) + assert idx.size == n + else: + idx = np.array(sorted(random.sample(range(self.size), n))) + if idx_fn: + np.save(idx_fn, idx) + vecs = self.get(idx) + if vecs_fn: + np.save(vecs_fn, vecs) + return vecs + + def get_first_n(self, n, dt): + assert n <= self.size + return next(self.iterate(0, n, dt)) diff --git 
a/demos/offline_ivf/generate_config.py b/demos/offline_ivf/generate_config.py new file mode 100644 index 0000000000..b5a12645ab --- /dev/null +++ b/demos/offline_ivf/generate_config.py @@ -0,0 +1,45 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import os +import yaml + +# with ssnpp sharded data +root = "/checkpoint/marialomeli/ssnpp_data" +file_names = [f"ssnpp_{i:010}.npy" for i in range(20)] +d = 256 +dt = np.dtype(np.uint8) + + +def read_embeddings(fp): + fl = os.path.getsize(fp) + nb = fl // d // dt.itemsize + print(nb) + if fl == d * dt.itemsize * nb: # no header + return ("raw", np.memmap(fp, shape=(nb, d), dtype=dt, mode="r")) + else: # assume npy + vecs = np.load(fp, mmap_mode="r") + assert vecs.shape[1] == d + assert vecs.dtype == dt + return ("npy", vecs) + + +cfg = {} +files = [] +size = 0 +for fn in file_names: + fp = f"{root}/{fn}" + assert os.path.exists(fp), f"{fp} is missing" + ft, xb = read_embeddings(fp) + files.append( + {"name": fn, "size": xb.shape[0], "dtype": dt.name, "format": ft} + ) + size += xb.shape[0] + +cfg["size"] = size +cfg["root"] = root +cfg["d"] = d +cfg["files"] = files +print(yaml.dump(cfg)) diff --git a/demos/offline_ivf/offline_ivf.py b/demos/offline_ivf/offline_ivf.py new file mode 100644 index 0000000000..5c316178cb --- /dev/null +++ b/demos/offline_ivf/offline_ivf.py @@ -0,0 +1,938 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import faiss +import numpy as np +import os +from tqdm import tqdm, trange +import sys +import logging +from faiss.contrib.ondisk import merge_ondisk +from faiss.contrib.big_batch_search import big_batch_search +from faiss.contrib.exhaustive_search import knn_ground_truth +from faiss.contrib.evaluation import knn_intersection_measure +from utils import ( + get_intersection_cardinality_frequencies, + margin, + is_pretransform_index, +) +from dataset import create_dataset_from_oivf_config + +logging.basicConfig( + format=( + "%(asctime)s.%(msecs)03d %(levelname)-8s %(threadName)-12s %(message)s" + ), + level=logging.INFO, + datefmt="%Y-%m-%d %H:%M:%S", + force=True, +) + +EMBEDDINGS_BATCH_SIZE: int = 100_000 +NUM_SUBSAMPLES: int = 100 +SMALL_DATA_SAMPLE: int = 10000 + + +class OfflineIVF: + def __init__(self, cfg, args, nprobe, index_factory_str): + self.input_d = cfg["d"] + self.dt = cfg["datasets"][args.xb]["files"][0]["dtype"] + assert self.input_d > 0 + output_dir = cfg["output"] + assert os.path.exists(output_dir) + self.index_factory = index_factory_str + assert self.index_factory is not None + self.index_factory_fn = self.index_factory.replace(",", "_") + self.index_template_file = ( + f"{output_dir}/{args.xb}/{self.index_factory_fn}.empty.faissindex" + ) + logging.info(f"index template: {self.index_template_file}") + + if not args.xq: + args.xq = args.xb + + self.by_residual = True + if args.no_residuals: + self.by_residual = False + + xb_output_dir = f"{output_dir}/{args.xb}" + if not os.path.exists(xb_output_dir): + os.makedirs(xb_output_dir) + xq_output_dir = f"{output_dir}/{args.xq}" + if not os.path.exists(xq_output_dir): + os.makedirs(xq_output_dir) + search_output_dir = f"{output_dir}/{args.xq}_in_{args.xb}" + if not os.path.exists(search_output_dir): + os.makedirs(search_output_dir) + self.knn_dir = 
f"{search_output_dir}/knn" + if not os.path.exists(self.knn_dir): + os.makedirs(self.knn_dir) + self.eval_dir = f"{search_output_dir}/eval" + if not os.path.exists(self.eval_dir): + os.makedirs(self.eval_dir) + self.index = {} # to keep a reference to opened indices, + self.ivls = {} # hstack inverted lists, + self.index_shards = {} # and index shards + self.index_shard_prefix = ( + f"{xb_output_dir}/{self.index_factory_fn}.shard_" + ) + self.xq_index_shard_prefix = ( + f"{xq_output_dir}/{self.index_factory_fn}.shard_" + ) + self.index_file = ( # TODO: added back temporarily for evaluate, handle name of non-sharded index file and remove. + f"{xb_output_dir}/{self.index_factory_fn}.faissindex" + ) + self.xq_index_file = ( + f"{xq_output_dir}/{self.index_factory_fn}.faissindex" + ) + self.training_sample = cfg["training_sample"] + self.evaluation_sample = cfg["evaluation_sample"] + self.xq_ds = create_dataset_from_oivf_config(cfg, args.xq) + self.xb_ds = create_dataset_from_oivf_config(cfg, args.xb) + file_descriptors = self.xq_ds.file_descriptors + self.file_sizes = [fd.size for fd in file_descriptors] + self.shard_size = cfg["index_shard_size"] # ~100GB + self.nshards = self.xb_ds.size // self.shard_size + if self.xb_ds.size % self.shard_size != 0: + self.nshards += 1 + self.xq_nshards = self.xq_ds.size // self.shard_size + if self.xq_ds.size % self.shard_size != 0: + self.xq_nshards += 1 + self.nprobe = nprobe + assert self.nprobe > 0, "Invalid nprobe parameter." + if "deduper" in cfg: + self.deduper = cfg["deduper"] + self.deduper_codec_fn = [ + f"{xb_output_dir}/deduper_codec_{codec.replace(',', '_')}" + for codec in self.deduper + ] + self.deduper_idx_fn = [ + f"{xb_output_dir}/deduper_idx_{codec.replace(',', '_')}" + for codec in self.deduper + ] + else: + self.deduper = None + self.k = cfg["k"] + assert self.k > 0, "Invalid number of neighbours parameter." + self.knn_output_file_suffix = ( + f"{self.index_factory_fn}_np{self.nprobe}.npy" + ) + + fp = 32 + if self.dt == "float16": + fp = 16 + + self.xq_bs = cfg["query_batch_size"] + if "metric" in cfg: + self.metric = eval(f'faiss.{cfg["metric"]}') + else: + self.metric = faiss.METRIC_L2 + + if "evaluate_by_margin" in cfg: + self.evaluate_by_margin = cfg["evaluate_by_margin"] + else: + self.evaluate_by_margin = False + + os.system("grep -m1 'model name' < /proc/cpuinfo") + os.system("grep -E 'MemTotal|MemFree' /proc/meminfo") + os.system("nvidia-smi") + os.system("nvcc --version") + + self.knn_queries_memory_limit = 4 * 1024 * 1024 * 1024 # 4 GB + self.knn_vectors_memory_limit = 8 * 1024 * 1024 * 1024 # 8 GB + + def input_stats(self): + """ + Trains the index using a subsample of the first chunk of data in the database and saves it in the template file (with no vectors added). 
+ """ + xb_sample = self.xb_ds.get_first_n(self.training_sample, np.float32) + logging.info(f"input shape: {xb_sample.shape}") + logging.info("running MatrixStats on training sample...") + logging.info(faiss.MatrixStats(xb_sample).comments) + logging.info("done") + + def dedupe(self): + logging.info(self.deduper) + if self.deduper is None: + logging.info("No deduper configured") + return + codecs = [] + codesets = [] + idxs = [] + for factory, filename in zip(self.deduper, self.deduper_codec_fn): + if os.path.exists(filename): + logging.info(f"loading trained dedupe codec: {filename}") + codec = faiss.read_index(filename) + else: + logging.info(f"training dedupe codec: {factory}") + codec = faiss.index_factory(self.input_d, factory) + xb_sample = np.unique( + self.xb_ds.get_first_n(100_000, np.float32), axis=0 + ) + faiss.ParameterSpace().set_index_parameter(codec, "verbose", 1) + codec.train(xb_sample) + logging.info(f"writing trained dedupe codec: {filename}") + faiss.write_index(codec, filename) + codecs.append(codec) + codesets.append(faiss.CodeSet(codec.sa_code_size())) + idxs.append(np.empty((0,), dtype=np.uint32)) + bs = 1_000_000 + i = 0 + for buffer in tqdm(self.xb_ds.iterate(0, bs, np.float32)): + for j in range(len(codecs)): + codec, codeset, idx = codecs[j], codesets[j], idxs[j] + uniq = codeset.insert(codec.sa_encode(buffer)) + idxs[j] = np.append( + idx, + np.arange(i, i + buffer.shape[0], dtype=np.uint32)[uniq], + ) + i += buffer.shape[0] + for idx, filename in zip(idxs, self.deduper_idx_fn): + logging.info(f"writing {filename}, shape: {idx.shape}") + np.save(filename, idx) + logging.info("done") + + def train_index(self): + """ + Trains the index using a subsample of the first chunk of data in the database and saves it in the template file (with no vectors added). + """ + assert not os.path.exists(self.index_template_file), ( + "The train command has been ran, the index template file already" + " exists." + ) + xb_sample = np.unique( + self.xb_ds.get_first_n(self.training_sample, np.float32), axis=0 + ) + logging.info(f"input shape: {xb_sample.shape}") + index = faiss.index_factory( + self.input_d, self.index_factory, self.metric + ) + index_ivf = faiss.downcast_index(faiss.extract_index_ivf(index)) + index_ivf.by_residual = True + faiss.ParameterSpace().set_index_parameter(index, "verbose", 1) + logging.info("running training...") + index.train(xb_sample) + logging.info(f"writing trained index {self.index_template_file}...") + faiss.write_index(index, self.index_template_file) + logging.info("done") + + def _iterate_transformed(self, ds, start, batch_size, dt): + assert os.path.exists(self.index_template_file) + index = faiss.read_index(self.index_template_file) + if is_pretransform_index(index): + vt = index.chain.at(0) # fetch pretransform + for buffer in ds.iterate(start, batch_size, dt): + yield vt.apply(buffer) + else: + for buffer in ds.iterate(start, batch_size, dt): + yield buffer + + def index_shard_and_quantize(self): + assert os.path.exists(self.index_template_file) + index = faiss.read_index(self.index_template_file) + index_ivf = faiss.downcast_index(faiss.extract_index_ivf(index)) + assert self.nprobe <= index_ivf.quantizer.ntotal, ( + f"the number of vectors {index_ivf.quantizer.ntotal} is not enough" + f" to retrieve {self.nprobe} neighbours, check." 
+ ) + + if is_pretransform_index(index): + d = index.chain.at(0).d_out + else: + d = self.input_d + for i in range(0, self.nshards): + sfn = f"{self.index_shard_prefix}{i}" + cqfn = f"{self.coarse_quantization_prefix}{i}" # fixme + if os.path.exists(sfn) or os.path.exists(cqfn): + logging.info(f"skipping shard: {i}") + continue + try: + with open(cqfn, "xb") as cqf: + index.reset() + start = i * self.shard_size + j = 0 + quantizer = faiss.index_cpu_to_all_gpus( + index_ivf.quantizer + ) + for xb_j in tqdm( + self._iterate_transformed( + self.xb_ds, + start, + EMBEDDINGS_BATCH_SIZE, + np.float32, + ), + file=sys.stdout, + ): + assert xb_j.shape[1] == d + _, I = quantizer.search(xb_j, self.nprobe) + assert np.amin(I) >= 0, f"{I}" + assert np.amax(I) < index_ivf.nlist + cqf.write(I) + self._index_add_core_wrapper( # fixme + index_ivf, + xb_j, + np.arange(start + j, start + j + xb_j.shape[0]), + I[:, 0], + ) + j += xb_j.shape[0] + assert j <= self.shard_size + if j == self.shard_size: + break + logging.info(f"writing {sfn}...") + faiss.write_index(index, sfn) + except FileExistsError: + logging.info(f"skipping shard: {i}") + continue + logging.info("done") + + def index_shard(self): + assert os.path.exists(self.index_template_file) + index = faiss.read_index(self.index_template_file) + index_ivf = faiss.downcast_index(faiss.extract_index_ivf(index)) + assert self.nprobe <= index_ivf.quantizer.ntotal, ( + f"the number of vectors {index_ivf.quantizer.ntotal} is not enough" + f" to retrieve {self.nprobe} neighbours, check." + ) + cpu_quantizer = index_ivf.quantizer + gpu_quantizer = faiss.index_cpu_to_all_gpus(cpu_quantizer) + + for i in range(0, self.nshards): + sfn = f"{self.index_shard_prefix}{i}" + try: + index.reset() + index_ivf.quantizer = gpu_quantizer + with open(sfn, "xb"): + start = i * self.shard_size + jj = 0 + embeddings_batch_size = min( + EMBEDDINGS_BATCH_SIZE, self.shard_size + ) + assert ( + self.shard_size % embeddings_batch_size == 0 + or EMBEDDINGS_BATCH_SIZE % embeddings_batch_size == 0 + ), ( + f"the shard size {self.shard_size} and embeddings" + f" shard size {EMBEDDINGS_BATCH_SIZE} are not" + " divisible" + ) + + for xb_j in tqdm( + self._iterate_transformed( + self.xb_ds, + start, + embeddings_batch_size, + np.float32, + ), + file=sys.stdout, + ): + assert xb_j.shape[1] == index.d + index.add_with_ids( + xb_j, + np.arange(start + jj, start + jj + xb_j.shape[0]), + ) + jj += xb_j.shape[0] + logging.info(jj) + assert ( + jj <= self.shard_size + ), f"jj {jj} and shard_zide {self.shard_size}" + if jj == self.shard_size: + break + logging.info(f"writing {sfn}...") + index_ivf.quantizer = cpu_quantizer + faiss.write_index(index, sfn) + except FileExistsError: + logging.info(f"skipping shard: {i}") + continue + logging.info("done") + + def merge_index(self): + ivf_file = f"{self.index_file}.ivfdata" + + assert os.path.exists(self.index_template_file) + assert not os.path.exists( + ivf_file + ), f"file with embeddings data {ivf_file} not found, check." 
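+        # merge_ondisk (called below) concatenates the inverted lists of all
+        # index shards into the single on-disk file ivf_file, and the merged
+        # index written to self.index_file references that file.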
+ assert not os.path.exists(self.index_file) + index = faiss.read_index(self.index_template_file) + block_fnames = [ + f"{self.index_shard_prefix}{i}" for i in range(self.nshards) + ] + for fn in block_fnames: + assert os.path.exists(fn) + logging.info(block_fnames) + logging.info("merging...") + merge_ondisk(index, block_fnames, ivf_file) + logging.info("writing index...") + faiss.write_index(index, self.index_file) + logging.info("done") + + def _cached_search( + self, + sample, + xq_ds, + xb_ds, + idx_file, + vecs_file, + I_file, + D_file, + index_file=None, + nprobe=None, + ): + if not os.path.exists(I_file): + assert not os.path.exists(I_file), f"file {I_file} does not exist " + assert not os.path.exists(D_file), f"file {D_file} does not exist " + xq = xq_ds.sample(sample, idx_file, vecs_file) + + if index_file: + D, I = self._index_nonsharded_search(index_file, xq, nprobe) + else: + logging.info("ground truth computations") + db_iterator = xb_ds.iterate(0, 100_000, np.float32) + D, I = knn_ground_truth( + xq, db_iterator, self.k, metric_type=self.metric + ) + assert np.amin(I) >= 0 + + np.save(I_file, I) + np.save(D_file, D) + else: + assert os.path.exists(idx_file), f"file {idx_file} does not exist " + assert os.path.exists( + vecs_file + ), f"file {vecs_file} does not exist " + assert os.path.exists(I_file), f"file {I_file} does not exist " + assert os.path.exists(D_file), f"file {D_file} does not exist " + I = np.load(I_file) + D = np.load(D_file) + assert I.shape == (sample, self.k), f"{I_file} shape mismatch" + assert D.shape == (sample, self.k), f"{D_file} shape mismatch" + return (D, I) + + def _index_search(self, index_shard_prefix, xq, nprobe): + assert nprobe is not None + logging.info( + f"open sharded index: {index_shard_prefix}, {self.nshards}" + ) + index = self._open_sharded_index(index_shard_prefix) + index_ivf = faiss.downcast_index(faiss.extract_index_ivf(index)) + logging.info(f"setting nprobe to {nprobe}") + index_ivf.nprobe = nprobe + return index.search(xq, self.k) + + def _index_nonsharded_search(self, index_file, xq, nprobe): + assert nprobe is not None + logging.info(f"index {index_file}") + assert os.path.exists(index_file), f"file {index_file} does not exist " + index = faiss.read_index(index_file, faiss.IO_FLAG_ONDISK_SAME_DIR) + logging.info(f"index size {index.ntotal} ") + index_ivf = faiss.downcast_index(faiss.extract_index_ivf(index)) + logging.info(f"setting nprobe to {nprobe}") + index_ivf.nprobe = nprobe + return index.search(xq, self.k) + + def _refine_distances(self, xq_ds, idx, xb_ds, I): + xq = xq_ds.get(idx).repeat(self.k, axis=0) + xb = xb_ds.get(I.reshape(-1)) + if self.metric == faiss.METRIC_INNER_PRODUCT: + return (xq * xb).sum(axis=1).reshape(I.shape) + elif self.metric == faiss.METRIC_L2: + return ((xq - xb) ** 2).sum(axis=1).reshape(I.shape) + else: + raise ValueError(f"metric not supported {self.metric}") + + def evaluate(self): + self._evaluate( + self.index_factory_fn, + self.index_file, + self.xq_index_file, + self.nprobe, + ) + + def _evaluate(self, index_factory_fn, index_file, xq_index_file, nprobe): + idx_a_file = f"{self.eval_dir}/idx_a.npy" + idx_b_gt_file = f"{self.eval_dir}/idx_b_gt.npy" + idx_b_ann_file = ( + f"{self.eval_dir}/idx_b_ann_{index_factory_fn}_np{nprobe}.npy" + ) + vecs_a_file = f"{self.eval_dir}/vecs_a.npy" + vecs_b_gt_file = f"{self.eval_dir}/vecs_b_gt.npy" + vecs_b_ann_file = ( + f"{self.eval_dir}/vecs_b_ann_{index_factory_fn}_np{nprobe}.npy" + ) + D_a_gt_file = f"{self.eval_dir}/D_a_gt.npy" + D_a_ann_file = ( 
+ f"{self.eval_dir}/D_a_ann_{index_factory_fn}_np{nprobe}.npy" + ) + D_a_ann_refined_file = f"{self.eval_dir}/D_a_ann_refined_{index_factory_fn}_np{nprobe}.npy" + D_b_gt_file = f"{self.eval_dir}/D_b_gt.npy" + D_b_ann_file = ( + f"{self.eval_dir}/D_b_ann_{index_factory_fn}_np{nprobe}.npy" + ) + D_b_ann_gt_file = ( + f"{self.eval_dir}/D_b_ann_gt_{index_factory_fn}_np{nprobe}.npy" + ) + I_a_gt_file = f"{self.eval_dir}/I_a_gt.npy" + I_a_ann_file = ( + f"{self.eval_dir}/I_a_ann_{index_factory_fn}_np{nprobe}.npy" + ) + I_b_gt_file = f"{self.eval_dir}/I_b_gt.npy" + I_b_ann_file = ( + f"{self.eval_dir}/I_b_ann_{index_factory_fn}_np{nprobe}.npy" + ) + I_b_ann_gt_file = ( + f"{self.eval_dir}/I_b_ann_gt_{index_factory_fn}_np{nprobe}.npy" + ) + margin_gt_file = f"{self.eval_dir}/margin_gt.npy" + margin_refined_file = ( + f"{self.eval_dir}/margin_refined_{index_factory_fn}_np{nprobe}.npy" + ) + margin_ann_file = ( + f"{self.eval_dir}/margin_ann_{index_factory_fn}_np{nprobe}.npy" + ) + + logging.info("exact search forward") + # xq -> xb AKA a -> b + D_a_gt, I_a_gt = self._cached_search( + self.evaluation_sample, + self.xq_ds, + self.xb_ds, + idx_a_file, + vecs_a_file, + I_a_gt_file, + D_a_gt_file, + ) + idx_a = np.load(idx_a_file) + + logging.info("approximate search forward") + D_a_ann, I_a_ann = self._cached_search( + self.evaluation_sample, + self.xq_ds, + self.xb_ds, + idx_a_file, + vecs_a_file, + I_a_ann_file, + D_a_ann_file, + index_file, + nprobe, + ) + + logging.info( + "calculate refined distances on approximate search forward" + ) + if os.path.exists(D_a_ann_refined_file): + D_a_ann_refined = np.load(D_a_ann_refined_file) + assert D_a_ann.shape == D_a_ann_refined.shape + else: + D_a_ann_refined = self._refine_distances( + self.xq_ds, idx_a, self.xb_ds, I_a_ann + ) + np.save(D_a_ann_refined_file, D_a_ann_refined) + + if self.evaluate_by_margin: + k_extract = self.k + margin_threshold = 1.05 + logging.info( + "exact search backward from the k_extract NN results of" + " forward search" + ) + # xb -> xq AKA b -> a + D_a_b_gt = D_a_gt[:, :k_extract].ravel() + idx_b_gt = I_a_gt[:, :k_extract].ravel() + assert len(idx_b_gt) == self.evaluation_sample * k_extract + np.save(idx_b_gt_file, idx_b_gt) + # exact search + D_b_gt, _ = self._cached_search( + len(idx_b_gt), + self.xb_ds, + self.xq_ds, + idx_b_gt_file, + vecs_b_gt_file, + I_b_gt_file, + D_b_gt_file, + ) # xb and xq ^^^ are inverted + + logging.info("margin on exact search") + margin_gt = margin( + self.evaluation_sample, + idx_a, + idx_b_gt, + D_a_b_gt, + D_a_gt, + D_b_gt, + self.k, + k_extract, + margin_threshold, + ) + np.save(margin_gt_file, margin_gt) + + logging.info( + "exact search backward from the k_extract NN results of" + " approximate forward search" + ) + D_a_b_refined = D_a_ann_refined[:, :k_extract].ravel() + idx_b_ann = I_a_ann[:, :k_extract].ravel() + assert len(idx_b_ann) == self.evaluation_sample * k_extract + np.save(idx_b_ann_file, idx_b_ann) + # exact search + D_b_ann_gt, _ = self._cached_search( + len(idx_b_ann), + self.xb_ds, + self.xq_ds, + idx_b_ann_file, + vecs_b_ann_file, + I_b_ann_gt_file, + D_b_ann_gt_file, + ) # xb and xq ^^^ are inverted + + logging.info("refined margin on approximate search") + margin_refined = margin( + self.evaluation_sample, + idx_a, + idx_b_ann, + D_a_b_refined, + D_a_gt, # not D_a_ann_refined(!) 
+ D_b_ann_gt, + self.k, + k_extract, + margin_threshold, + ) + np.save(margin_refined_file, margin_refined) + + D_b_ann, I_b_ann = self._cached_search( + len(idx_b_ann), + self.xb_ds, + self.xq_ds, + idx_b_ann_file, + vecs_b_ann_file, + I_b_ann_file, + D_b_ann_file, + xq_index_file, + nprobe, + ) + + D_a_b_ann = D_a_ann[:, :k_extract].ravel() + + logging.info("approximate search margin") + + margin_ann = margin( + self.evaluation_sample, + idx_a, + idx_b_ann, + D_a_b_ann, + D_a_ann, + D_b_ann, + self.k, + k_extract, + margin_threshold, + ) + np.save(margin_ann_file, margin_ann) + + logging.info("intersection") + logging.info(I_a_gt) + logging.info(I_a_ann) + + for i in range(1, self.k + 1): + logging.info( + f"{i}: {knn_intersection_measure(I_a_gt[:,:i], I_a_ann[:,:i])}" + ) + + logging.info(f"mean of gt distances: {D_a_gt.mean()}") + logging.info(f"mean of approx distances: {D_a_ann.mean()}") + logging.info(f"mean of refined distances: {D_a_ann_refined.mean()}") + + logging.info("intersection cardinality frequencies") + logging.info(get_intersection_cardinality_frequencies(I_a_ann, I_a_gt)) + + logging.info("done") + pass + + def _knn_function(self, xq, xb, k, metric, thread_id=None): + try: + return faiss.knn_gpu( + self.all_gpu_resources[thread_id], + xq, + xb, + k, + metric=metric, + device=thread_id, + vectorsMemoryLimit=self.knn_vectors_memory_limit, + queriesMemoryLimit=self.knn_queries_memory_limit, + ) + except Exception: + logging.info(f"knn_function failed: {xq.shape}, {xb.shape}") + raise + + def _coarse_quantize(self, index_ivf, xq, nprobe): + assert nprobe <= index_ivf.quantizer.ntotal + quantizer = faiss.index_cpu_to_all_gpus(index_ivf.quantizer) + bs = 100_000 + nq = len(xq) + q_assign = np.empty((nq, nprobe), dtype="int32") + for i0 in trange(0, nq, bs): + i1 = min(nq, i0 + bs) + _, q_assign_i = quantizer.search(xq[i0:i1], nprobe) + q_assign[i0:i1] = q_assign_i + return q_assign + + def search(self): + logging.info(f"search: {self.knn_dir}") + slurm_job_id = os.environ.get("SLURM_JOB_ID") + + ngpu = faiss.get_num_gpus() + logging.info(f"number of gpus: {ngpu}") + self.all_gpu_resources = [ + faiss.StandardGpuResources() for _ in range(ngpu) + ] + self._knn_function( + np.zeros((10, 10), dtype=np.float16), + np.zeros((10, 10), dtype=np.float16), + self.k, + metric=self.metric, + thread_id=0, + ) + + index = self._open_sharded_index() + index_ivf = faiss.downcast_index(faiss.extract_index_ivf(index)) + logging.info(f"setting nprobe to {self.nprobe}") + index_ivf.nprobe = self.nprobe + # quantizer = faiss.index_cpu_to_all_gpus(index_ivf.quantizer) + for i in range(0, self.xq_ds.size, self.xq_bs): + Ifn = f"{self.knn_dir}/I{(i):010}_{self.knn_output_file_suffix}" + Dfn = f"{self.knn_dir}/D_approx{(i):010}_{self.knn_output_file_suffix}" + CPfn = f"{self.knn_dir}/CP{(i):010}_{self.knn_output_file_suffix}" + + if slurm_job_id: + worker_record = ( + self.knn_dir + + f"/record_{(i):010}_{self.knn_output_file_suffix}.txt" + ) + if not os.path.exists(worker_record): + logging.info( + f"creating record file {worker_record} and saving job" + f" id: {slurm_job_id}" + ) + with open(worker_record, "w") as h: + h.write(slurm_job_id) + else: + old_slurm_id = open(worker_record, "r").read() + logging.info( + f"old job slurm id {old_slurm_id} and current job id:" + f" {slurm_job_id}" + ) + if old_slurm_id == slurm_job_id: + if os.path.getsize(Ifn) == 0: + logging.info( + f"cleaning up zero length files {Ifn} and" + f" {Dfn}" + ) + os.remove(Ifn) + os.remove(Dfn) + + try: # TODO: modify shape 
for pretransform case + with open(Ifn, "xb") as f, open(Dfn, "xb") as g: + xq_i = np.empty( + shape=(self.xq_bs, self.input_d), dtype=np.float16 + ) + q_assign = np.empty( + (self.xq_bs, self.nprobe), dtype=np.int32 + ) + j = 0 + quantizer = faiss.index_cpu_to_all_gpus( + index_ivf.quantizer + ) + for xq_i_j in tqdm( + self._iterate_transformed( + self.xq_ds, i, min(100_000, self.xq_bs), np.float16 + ), + file=sys.stdout, + ): + xq_i[j:j + xq_i_j.shape[0]] = xq_i_j + ( + _, + q_assign[j:j + xq_i_j.shape[0]], + ) = quantizer.search(xq_i_j, self.nprobe) + j += xq_i_j.shape[0] + assert j <= xq_i.shape[0] + if j == xq_i.shape[0]: + break + xq_i = xq_i[:j] + q_assign = q_assign[:j] + + assert q_assign.shape == (xq_i.shape[0], index_ivf.nprobe) + del quantizer + logging.info(f"computing: {Ifn}") + logging.info(f"computing: {Dfn}") + prefetch_threads = faiss.get_num_gpus() + D_ann, I = big_batch_search( + index_ivf, + xq_i, + self.k, + verbose=10, + method="knn_function", + knn=self._knn_function, + threaded=faiss.get_num_gpus() * 8, + use_float16=True, + prefetch_threads=prefetch_threads, + computation_threads=faiss.get_num_gpus(), + q_assign=q_assign, + checkpoint=CPfn, + checkpoint_freq=7200, # in seconds + ) + assert ( + np.amin(I) >= 0 + ), f"{I}, there exists negative indices, check" + logging.info(f"saving: {Ifn}") + np.save(f, I) + logging.info(f"saving: {Dfn}") + np.save(g, D_ann) + + if os.path.exists(CPfn): + logging.info(f"removing: {CPfn}") + os.remove(CPfn) + + except FileExistsError: + logging.info(f"skipping {Ifn}, already exists") + logging.info(f"skipping {Dfn}, already exists") + continue + + def _open_index_shard(self, fn): + if fn in self.index_shards: + index_shard = self.index_shards[fn] + else: + logging.info(f"open index shard: {fn}") + index_shard = faiss.read_index( + fn, faiss.IO_FLAG_MMAP | faiss.IO_FLAG_READ_ONLY + ) + self.index_shards[fn] = index_shard + return index_shard + + def _open_sharded_index(self, index_shard_prefix=None): + if index_shard_prefix is None: + index_shard_prefix = self.index_shard_prefix + if index_shard_prefix in self.index: + return self.index[index_shard_prefix] + assert os.path.exists( + self.index_template_file + ), f"file {self.index_template_file} does not exist " + logging.info(f"open index template: {self.index_template_file}") + index = faiss.read_index(self.index_template_file) + index_ivf = faiss.downcast_index(faiss.extract_index_ivf(index)) + ilv = faiss.InvertedListsPtrVector() + for i in range(self.nshards): + fn = f"{index_shard_prefix}{i}" + assert os.path.exists(fn), f"file {fn} does not exist " + logging.info(fn) + index_shard = self._open_index_shard(fn) + il = faiss.downcast_index( + faiss.extract_index_ivf(index_shard) + ).invlists + ilv.push_back(il) + hsil = faiss.HStackInvertedLists(ilv.size(), ilv.data()) + index_ivf.replace_invlists(hsil, False) + self.ivls[index_shard_prefix] = hsil + self.index[index_shard_prefix] = index + return index + + def index_shard_stats(self): + for i in range(self.nshards): + fn = f"{self.index_shard_prefix}{i}" + assert os.path.exists(fn) + index = faiss.read_index( + fn, faiss.IO_FLAG_MMAP | faiss.IO_FLAG_READ_ONLY + ) + index_ivf = faiss.downcast_index(faiss.extract_index_ivf(index)) + il = index_ivf.invlists + il.print_stats() + + def index_stats(self): + index = self._open_sharded_index() + index_ivf = faiss.downcast_index(faiss.extract_index_ivf(index)) + il = index_ivf.invlists + list_sizes = [il.list_size(i) for i in range(il.nlist)] + logging.info(np.max(list_sizes)) + 
logging.info(np.mean(list_sizes)) + logging.info(np.argmax(list_sizes)) + logging.info("index_stats:") + il.print_stats() + + def consistency_check(self): + logging.info("consistency-check") + + logging.info("index template...") + + assert os.path.exists(self.index_template_file) + index = faiss.read_index(self.index_template_file) + + offset = 0 # 2**24 + assert self.shard_size > offset + SMALL_DATA_SAMPLE + + logging.info("index shards...") + for i in range(self.nshards): + r = i * self.shard_size + offset + xb = next(self.xb_ds.iterate(r, SMALL_DATA_SAMPLE, np.float32)) + fn = f"{self.index_shard_prefix}{i}" + assert os.path.exists(fn), f"There is no index shard file {fn}" + index = self._open_index_shard(fn) + index_ivf = faiss.downcast_index(faiss.extract_index_ivf(index)) + index_ivf.nprobe = 1 + _, I = index.search(xb, 100) + for j in range(SMALL_DATA_SAMPLE): + assert np.where(I[j] == j + r)[0].size > 0, ( + f"I[j]: {I[j]}, j: {j}, i: {i}, shard_size:" + f" {self.shard_size}" + ) + + logging.info("merged index...") + index = self._open_sharded_index() + index_ivf = faiss.downcast_index(faiss.extract_index_ivf(index)) + index_ivf.nprobe = 1 + for i in range(self.nshards): + r = i * self.shard_size + offset + xb = next(self.xb_ds.iterate(r, SMALL_DATA_SAMPLE, np.float32)) + _, I = index.search(xb, 100) + for j in range(SMALL_DATA_SAMPLE): + assert np.where(I[j] == j + r)[0].size > 0, ( + f"I[j]: {I[j]}, j: {j}, i: {i}, shard_size:" + f" {self.shard_size}" + ) + + logging.info("search results...") + index_ivf.nprobe = self.nprobe + for i in range(0, self.xq_ds.size, self.xq_bs): + Ifn = f"{self.knn_dir}/I{i:010}_{self.index_factory_fn}_np{self.nprobe}.npy" + assert os.path.exists(Ifn) + assert os.path.getsize(Ifn) > 0, f"The file {Ifn} is empty." + logging.info(Ifn) + I = np.load(Ifn, mmap_mode="r") + + assert I.shape[1] == self.k + assert I.shape[0] == min(self.xq_bs, self.xq_ds.size - i) + assert np.all(I[:, 1] >= 0) + + Dfn = f"{self.knn_dir}/D_approx{i:010}_{self.index_factory_fn}_np{self.nprobe}.npy" + assert os.path.exists(Dfn) + assert os.path.getsize(Dfn) > 0, f"The file {Dfn} is empty." + logging.info(Dfn) + D = np.load(Dfn, mmap_mode="r") + assert D.shape == I.shape + + xq = next(self.xq_ds.iterate(i, SMALL_DATA_SAMPLE, np.float32)) + D_online, I_online = index.search(xq, self.k) + assert ( + np.where(I[:SMALL_DATA_SAMPLE] == I_online)[0].size + / (self.k * SMALL_DATA_SAMPLE) + > 0.95 + ), ( + "the ratio is" + f" {np.where(I[:SMALL_DATA_SAMPLE] == I_online)[0].size / (self.k * SMALL_DATA_SAMPLE)}" + ) + assert np.allclose( + D[:SMALL_DATA_SAMPLE].sum(axis=1), + D_online.sum(axis=1), + rtol=0.01, + ), ( + "the difference is" + f" {D[:SMALL_DATA_SAMPLE].sum(axis=1), D_online.sum(axis=1)}" + ) + + logging.info("done") diff --git a/demos/offline_ivf/run.py b/demos/offline_ivf/run.py new file mode 100644 index 0000000000..dfa831d6f0 --- /dev/null +++ b/demos/offline_ivf/run.py @@ -0,0 +1,218 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +from utils import ( + load_config, + add_group_args, +) +from offline_ivf import OfflineIVF +import faiss +from typing import List, Callable, Dict +import submitit + + +def join_lists_in_dict(poss: List[str]) -> List[str]: + """ + Joins two lists of prod and non-prod values, checking if the prod value is already included. + If there is no non-prod list, it returns the prod list. 
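+    For example, {"prod": [512], "non-prod": [256, 1024]} yields
+    [256, 1024, 512].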
+ """ + if "non-prod" in poss.keys(): + all_poss = poss["non-prod"] + if poss["prod"][-1] not in poss["non-prod"]: + all_poss += poss["prod"] + return all_poss + else: + return poss["prod"] + + +def main( + args: argparse.Namespace, + cfg: Dict[str, str], + nprobe: int, + index_factory_str: str, +) -> None: + oivf = OfflineIVF(cfg, args, nprobe, index_factory_str) + eval(f"oivf.{args.command}()") + + +def process_options_and_run_jobs(args: argparse.Namespace) -> None: + """ + If "--cluster_run", it launches an array of jobs to the cluster using the submitit library for all the index strings. In + the case of evaluate, it launches a job for each index string and nprobe pair. Otherwise, it launches a single job + that is ran locally with the prod values for index string and nprobe. + """ + + cfg = load_config(args.config) + index_strings = cfg["index"] + nprobes = cfg["nprobe"] + if args.command == "evaluate": + if args.cluster_run: + all_nprobes = join_lists_in_dict(nprobes) + all_index_strings = join_lists_in_dict(index_strings) + for index_factory_str in all_index_strings: + for nprobe in all_nprobes: + launch_job(main, args, cfg, nprobe, index_factory_str) + else: + launch_job( + main, args, cfg, nprobes["prod"][-1], index_strings["prod"][-1] + ) + else: + if args.cluster_run: + all_index_strings = join_lists_in_dict(index_strings) + for index_factory_str in all_index_strings: + launch_job( + main, args, cfg, nprobes["prod"][-1], index_factory_str + ) + else: + launch_job( + main, args, cfg, nprobes["prod"][-1], index_strings["prod"][-1] + ) + + +def launch_job( + func: Callable, + args: argparse.Namespace, + cfg: Dict[str, str], + n_probe: int, + index_str: str, +) -> None: + """ + Launches an array of slurm jobs to the cluster using the submitit library. 
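+    If args.cluster_run is not set, func is simply called locally with the
+    given arguments instead of being submitted.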
+    """
+
+    if args.cluster_run:
+        assert args.num_nodes >= 1
+        executor = submitit.AutoExecutor(folder=args.logs_dir)
+
+        executor.update_parameters(
+            nodes=args.num_nodes,
+            gpus_per_node=args.gpus_per_node,
+            cpus_per_task=args.cpus_per_task,
+            tasks_per_node=args.tasks_per_node,
+            name=args.job_name,
+            slurm_partition=args.partition,
+            slurm_time=70 * 60,
+        )
+        if args.slurm_constraint:
+            executor.update_parameters(slurm_constraint=args.slurm_constraint)
+
+        job = executor.submit(func, args, cfg, n_probe, index_str)
+        print(f"Job id: {job.job_id}")
+    else:
+        func(args, cfg, n_probe, index_str)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    group = parser.add_argument_group("general")
+
+    add_group_args(group, "--command", required=True, help="command to run")
+    add_group_args(
+        group,
+        "--config",
+        required=True,
+        help="config yaml with the dataset specs",
+    )
+    add_group_args(
+        group, "--nt", type=int, default=96, help="nb search threads"
+    )
+    add_group_args(
+        group,
+        "--no_residuals",
+        action="store_false",
+        help="set index.by_residual to False during train index.",
+    )
+
+    group = parser.add_argument_group("slurm_job")
+
+    add_group_args(
+        group,
+        "--cluster_run",
+        action="store_true",
+        help="if True, runs on the cluster",
+    )
+    add_group_args(
+        group,
+        "--job_name",
+        type=str,
+        default="oivf",
+        help="cluster job name",
+    )
+    add_group_args(
+        group,
+        "--num_nodes",
+        type=int,
+        default=1,
+        help="num of nodes per job",
+    )
+    add_group_args(
+        group,
+        "--tasks_per_node",
+        type=int,
+        default=1,
+        help="tasks per job",
+    )
+
+    add_group_args(
+        group,
+        "--gpus_per_node",
+        type=int,
+        default=8,
+        help="number of gpus per node",
+    )
+    add_group_args(
+        group,
+        "--cpus_per_task",
+        type=int,
+        default=80,
+        help="number of cpus per task",
+    )
+
+    add_group_args(
+        group,
+        "--logs_dir",
+        type=str,
+        default="/checkpoint/marialomeli/offline_faiss/logs",
+        help="directory for the submitit logs",
+    )
+
+    add_group_args(
+        group,
+        "--slurm_constraint",
+        type=str,
+        default=None,
+        help="can be volta32gb for the fair cluster",
+    )
+
+    add_group_args(
+        group,
+        "--partition",
+        type=str,
+        default="learnlab",
+        help="specify which partition to use if run on the cluster with job arrays",
+        choices=[
+            "learnfair",
+            "devlab",
+            "scavenge",
+            "learnlab",
+            "nllb",
+            "seamless",
+            "seamless_medium",
+            "learnaccel",
+            "onellm_low",
+            "learn",
+        ],
+    )
+
+    group = parser.add_argument_group("dataset")
+
+    add_group_args(group, "--xb", required=True, help="database vectors")
+    add_group_args(group, "--xq", help="query vectors")
+
+    args = parser.parse_args()
+    print("args:", args)
+    faiss.omp_set_num_threads(args.nt)
+    process_options_and_run_jobs(args=args)
diff --git a/demos/offline_ivf/tests/test_iterate_input.py b/demos/offline_ivf/tests/test_iterate_input.py
new file mode 100644
index 0000000000..3f59071102
--- /dev/null
+++ b/demos/offline_ivf/tests/test_iterate_input.py
@@ -0,0 +1,132 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
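+
+# These tests exercise MultiFileVectorDataset.iterate() across file boundaries
+# (files smaller and larger than the batch size) and check that iterate() and
+# get() return the same vectors.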
+ +import numpy as np +import unittest +from typing import List +from utils import load_config +from tests.testing_utils import TestDataCreator +import tempfile +from dataset import create_dataset_from_oivf_config + +DIMENSION: int = 768 +SMALL_FILE_SIZES: List[int] = [100, 210, 450] +LARGE_FILE_SIZES: List[int] = [1253, 3459, 890] +TEST_BATCH_SIZE: int = 500 +SMALL_SAMPLE_SIZE: int = 1000 +NUM_FILES: int = 3 + + +class TestUtilsMethods(unittest.TestCase): + """ + Unit tests for iterate and decreasing_matrix methods. + """ + + def test_iterate_input_file_smaller_than_batch(self): + """ + Tests when batch size is larger than the file size. + """ + with tempfile.TemporaryDirectory() as tmpdirname: + data_creator = TestDataCreator( + tempdir=tmpdirname, + dimension=DIMENSION, + data_type=np.float16, + file_sizes=SMALL_FILE_SIZES, + ) + data_creator.create_test_data() + args = data_creator.setup_cli() + cfg = load_config(args.config) + db_iterator = create_dataset_from_oivf_config( + cfg, args.xb + ).iterate(0, TEST_BATCH_SIZE, np.float32) + + for i in range(len(SMALL_FILE_SIZES) - 1): + vecs = next(db_iterator) + if i != 1: + self.assertEqual(vecs.shape[0], TEST_BATCH_SIZE) + else: + self.assertEqual( + vecs.shape[0], sum(SMALL_FILE_SIZES) - TEST_BATCH_SIZE + ) + + def test_iterate_input_file_larger_than_batch(self): + """ + Tests when batch size is smaller than the file size. + """ + with tempfile.TemporaryDirectory() as tmpdirname: + data_creator = TestDataCreator( + tempdir=tmpdirname, + dimension=DIMENSION, + data_type=np.float16, + file_sizes=LARGE_FILE_SIZES, + ) + data_creator.create_test_data() + args = data_creator.setup_cli() + cfg = load_config(args.config) + db_iterator = create_dataset_from_oivf_config( + cfg, args.xb + ).iterate(0, TEST_BATCH_SIZE, np.float32) + + for i in range(len(LARGE_FILE_SIZES) - 1): + vecs = next(db_iterator) + if i != 9: + self.assertEqual(vecs.shape[0], TEST_BATCH_SIZE) + else: + self.assertEqual( + vecs.shape[0], + sum(LARGE_FILE_SIZES) - TEST_BATCH_SIZE * 9, + ) + + def test_get_vs_iterate(self) -> None: + """ + Loads vectors with iterator and get, and checks that they match, non-aligned by file size case. + """ + with tempfile.TemporaryDirectory() as tmpdirname: + data_creator = TestDataCreator( + tempdir=tmpdirname, + dimension=DIMENSION, + data_type=np.float16, + file_size=SMALL_SAMPLE_SIZE, + num_files=NUM_FILES, + normalise=True, + ) + data_creator.create_test_data() + args = data_creator.setup_cli() + cfg = load_config(args.config) + ds = create_dataset_from_oivf_config(cfg, args.xb) + vecs_by_iterator = np.vstack(list(ds.iterate(0, 317, np.float32))) + self.assertEqual( + vecs_by_iterator.shape[0], SMALL_SAMPLE_SIZE * NUM_FILES + ) + vecs_by_get = ds.get(list(range(vecs_by_iterator.shape[0]))) + self.assertTrue(np.all(vecs_by_iterator == vecs_by_get)) + + def test_iterate_back(self) -> None: + """ + Loads vectors with iterator and get, and checks that they match, non-aligned by file size case. 
+ """ + with tempfile.TemporaryDirectory() as tmpdirname: + data_creator = TestDataCreator( + tempdir=tmpdirname, + dimension=DIMENSION, + data_type=np.float16, + file_size=SMALL_SAMPLE_SIZE, + num_files=NUM_FILES, + normalise=True, + ) + data_creator.create_test_data() + args = data_creator.setup_cli() + cfg = load_config(args.config) + ds = create_dataset_from_oivf_config(cfg, args.xb) + vecs_by_iterator = np.vstack(list(ds.iterate(0, 317, np.float32))) + self.assertEqual( + vecs_by_iterator.shape[0], SMALL_SAMPLE_SIZE * NUM_FILES + ) + vecs_chunk = np.vstack( + [ + next(ds.iterate(i, 543, np.float32)) + for i in range(0, SMALL_SAMPLE_SIZE * NUM_FILES, 543) + ] + ) + self.assertTrue(np.all(vecs_by_iterator == vecs_chunk)) diff --git a/demos/offline_ivf/tests/test_offline_ivf.py b/demos/offline_ivf/tests/test_offline_ivf.py new file mode 100644 index 0000000000..557a0b37dd --- /dev/null +++ b/demos/offline_ivf/tests/test_offline_ivf.py @@ -0,0 +1,288 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import unittest +from utils import load_config +import pathlib as pl +import tempfile +from typing import List +from tests.testing_utils import TestDataCreator +from run import process_options_and_run_jobs + +KNN_RESULTS_FILE: str = ( + "/my_test_data_in_my_test_data/knn/I0000000000_IVF256_PQ4_np2.npy" +) + +A_INDEX_FILES: List[str] = [ + "I_a_gt.npy", + "D_a_gt.npy", + "vecs_a.npy", + "D_a_ann_IVF256_PQ4_np2.npy", + "I_a_ann_IVF256_PQ4_np2.npy", + "D_a_ann_refined_IVF256_PQ4_np2.npy", +] + +A_INDEX_OPQ_FILES: List[str] = [ + "I_a_gt.npy", + "D_a_gt.npy", + "vecs_a.npy", + "D_a_ann_OPQ4_IVF256_PQ4_np200.npy", + "I_a_ann_OPQ4_IVF256_PQ4_np200.npy", + "D_a_ann_refined_OPQ4_IVF256_PQ4_np200.npy", +] + + +class TestOIVF(unittest.TestCase): + """ + Unit tests for OIVF. Some of these unit tests first copy the required test data objects and puts them in the tempdir created by the context manager. + """ + + def assert_file_exists(self, filepath: str) -> None: + path = pl.Path(filepath) + self.assertEqual((str(path), path.is_file()), (str(path), True)) + + def test_consistency_check(self) -> None: + """ + Test the OIVF consistency check step, that it throws if no other steps have been ran. + """ + with tempfile.TemporaryDirectory() as tmpdirname: + data_creator = TestDataCreator( + tempdir=tmpdirname, + dimension=8, + data_type=np.float16, + index_factory=["OPQ4,IVF256,PQ4"], + training_sample=9984, + num_files=3, + file_size=10000, + nprobe=2, + k=2, + metric="METRIC_L2", + ) + data_creator.create_test_data() + test_args = data_creator.setup_cli("consistency_check") + self.assertRaises( + AssertionError, process_options_and_run_jobs, test_args + ) + + def test_train_index(self) -> None: + """ + Test the OIVF train index step, that it correctly produces the empty.faissindex template file. 
+ """ + with tempfile.TemporaryDirectory() as tmpdirname: + data_creator = TestDataCreator( + tempdir=tmpdirname, + dimension=8, + data_type=np.float16, + index_factory=["OPQ4,IVF256,PQ4"], + training_sample=9984, + num_files=3, + file_size=10000, + nprobe=2, + k=2, + metric="METRIC_L2", + ) + data_creator.create_test_data() + test_args = data_creator.setup_cli("train_index") + cfg = load_config(test_args.config) + process_options_and_run_jobs(test_args) + empty_index = ( + cfg["output"] + + "/my_test_data/" + + cfg["index"]["prod"][-1].replace(",", "_") + + ".empty.faissindex" + ) + self.assert_file_exists(empty_index) + + def test_index_shard_equal_file_sizes(self) -> None: + """ + Test the case where the shard size is a divisor of the database size and it is equal to the first file size. + """ + + with tempfile.TemporaryDirectory() as tmpdirname: + index_shard_size = 10000 + num_files = 3 + file_size = 10000 + xb_ds_size = num_files * file_size + data_creator = TestDataCreator( + tempdir=tmpdirname, + dimension=8, + data_type=np.float16, + index_factory=["IVF256,PQ4"], + training_sample=9984, + num_files=num_files, + file_size=file_size, + nprobe=2, + k=2, + metric="METRIC_L2", + index_shard_size=index_shard_size, + query_batch_size=1000, + evaluation_sample=100, + ) + data_creator.create_test_data() + test_args = data_creator.setup_cli("train_index") + process_options_and_run_jobs(test_args) + test_args = data_creator.setup_cli("index_shard") + cfg = load_config(test_args.config) + process_options_and_run_jobs(test_args) + num_shards = xb_ds_size // index_shard_size + if xb_ds_size % index_shard_size != 0: + num_shards += 1 + print(f"number of shards:{num_shards}") + for i in range(num_shards): + index_shard_file = ( + cfg["output"] + + "/my_test_data/" + + cfg["index"]["prod"][-1].replace(",", "_") + + f".shard_{i}" + ) + self.assert_file_exists(index_shard_file) + + def test_index_shard_unequal_file_sizes(self) -> None: + """ + Test the case where the shard size is not a divisor of the database size and is greater than the first file size. + """ + with tempfile.TemporaryDirectory() as tmpdirname: + file_sizes = [20000, 15001, 13990] + xb_ds_size = sum(file_sizes) + index_shard_size = 30000 + data_creator = TestDataCreator( + tempdir=tmpdirname, + dimension=8, + data_type=np.float16, + index_factory=["IVF256,PQ4"], + training_sample=9984, + file_sizes=file_sizes, + nprobe=2, + k=2, + metric="METRIC_L2", + index_shard_size=index_shard_size, + evaluation_sample=100, + ) + data_creator.create_test_data() + test_args = data_creator.setup_cli("train_index") + process_options_and_run_jobs(test_args) + test_args = data_creator.setup_cli("index_shard") + cfg = load_config(test_args.config) + process_options_and_run_jobs(test_args) + num_shards = xb_ds_size // index_shard_size + if xb_ds_size % index_shard_size != 0: + num_shards += 1 + print(f"number of shards:{num_shards}") + for i in range(num_shards): + index_shard_file = ( + cfg["output"] + + "/my_test_data/" + + cfg["index"]["prod"][-1].replace(",", "_") + + f".shard_{i}" + ) + self.assert_file_exists(index_shard_file) + + def test_search(self) -> None: + """ + Test search step using test data objects to bypass dependencies on previous steps. 
+ """ + with tempfile.TemporaryDirectory() as tmpdirname: + num_files = 3 + file_size = 10000 + query_batch_size = 10000 + total_batches = num_files * file_size // query_batch_size + if num_files * file_size % query_batch_size != 0: + total_batches += 1 + data_creator = TestDataCreator( + tempdir=tmpdirname, + dimension=8, + data_type=np.float32, + index_factory=["IVF256,PQ4"], + training_sample=9984, + num_files=3, + file_size=10000, + nprobe=2, + k=2, + metric="METRIC_L2", + index_shard_size=10000, + query_batch_size=query_batch_size, + evaluation_sample=100, + ) + data_creator.create_test_data() + test_args = data_creator.setup_cli("train_index") + process_options_and_run_jobs(test_args) + test_args = data_creator.setup_cli("index_shard") + process_options_and_run_jobs(test_args) + test_args = data_creator.setup_cli("search") + cfg = load_config(test_args.config) + process_options_and_run_jobs(test_args) + # TODO: add check that there are number of batches total of files + knn_file = cfg["output"] + KNN_RESULTS_FILE + self.assert_file_exists(knn_file) + + def test_evaluate_without_margin(self) -> None: + """ + Test evaluate step using test data objects, no margin evaluation, single index. + """ + with tempfile.TemporaryDirectory() as tmpdirname: + data_creator = TestDataCreator( + tempdir=tmpdirname, + dimension=8, + data_type=np.float32, + index_factory=["IVF256,PQ4"], + training_sample=9984, + num_files=3, + file_size=10000, + nprobe=2, + k=2, + metric="METRIC_L2", + index_shard_size=10000, + query_batch_size=10000, + evaluation_sample=100, + with_queries_ds=True, + ) + data_creator.create_test_data() + test_args = data_creator.setup_cli("train_index") + process_options_and_run_jobs(test_args) + test_args = data_creator.setup_cli("index_shard") + process_options_and_run_jobs(test_args) + test_args = data_creator.setup_cli("merge_index") + process_options_and_run_jobs(test_args) + test_args = data_creator.setup_cli("evaluate") + process_options_and_run_jobs(test_args) + common_path = tmpdirname + "/my_queries_data_in_my_test_data/eval/" + for filename in A_INDEX_FILES: + file_to_check = common_path + "/" + filename + self.assert_file_exists(file_to_check) + + def test_evaluate_without_margin_OPQ(self) -> None: + """ + Test evaluate step using test data objects, no margin evaluation, single index. 
+ """ + with tempfile.TemporaryDirectory() as tmpdirname: + data_creator = TestDataCreator( + tempdir=tmpdirname, + dimension=8, + data_type=np.float32, + index_factory=["OPQ4,IVF256,PQ4"], + training_sample=9984, + num_files=3, + file_size=10000, + nprobe=200, + k=2, + metric="METRIC_L2", + index_shard_size=10000, + query_batch_size=10000, + evaluation_sample=100, + with_queries_ds=True, + ) + data_creator.create_test_data() + test_args = data_creator.setup_cli("train_index") + process_options_and_run_jobs(test_args) + test_args = data_creator.setup_cli("index_shard") + process_options_and_run_jobs(test_args) + test_args = data_creator.setup_cli("merge_index") + process_options_and_run_jobs(test_args) + test_args = data_creator.setup_cli("evaluate") + process_options_and_run_jobs(test_args) + common_path = tmpdirname + "/my_queries_data_in_my_test_data/eval/" + for filename in A_INDEX_OPQ_FILES: + file_to_check = common_path + filename + self.assert_file_exists(file_to_check) diff --git a/demos/offline_ivf/tests/testing_utils.py b/demos/offline_ivf/tests/testing_utils.py new file mode 100644 index 0000000000..34751f278a --- /dev/null +++ b/demos/offline_ivf/tests/testing_utils.py @@ -0,0 +1,180 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import yaml +import numpy as np +from typing import Dict, List, Optional + +OIVF_TEST_ARGS: List[str] = [ + "--config", + "--xb", + "--xq", + "--command", + "--cluster_run", + "--no_residuals", +] + + +def get_test_parser(args) -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + for arg in args: + parser.add_argument(arg) + return parser + + +class TestDataCreator: + def __init__( + self, + tempdir: str, + dimension: int, + data_type: np.dtype, + index_factory: Optional[List] = ["OPQ4,IVF256,PQ4"], + training_sample: Optional[int] = 9984, + index_shard_size: Optional[int] = 1000, + query_batch_size: Optional[int] = 1000, + evaluation_sample: Optional[int] = 100, + num_files: Optional[int] = None, + file_size: Optional[int] = None, + file_sizes: Optional[List] = None, + nprobe: Optional[int] = 64, + k: Optional[int] = 10, + metric: Optional[str] = "METRIC_L2", + normalise: Optional[bool] = False, + with_queries_ds: Optional[bool] = False, + evaluate_by_margin: Optional[bool] = False, + ) -> None: + self.tempdir = tempdir + self.dimension = dimension + self.data_type = np.dtype(data_type).name + self.index_factory = {"prod": index_factory} + if file_size and num_files: + self.file_sizes = [file_size for _ in range(num_files)] + elif file_sizes: + self.file_sizes = file_sizes + else: + raise ValueError("no file sizes provided") + self.num_files = len(self.file_sizes) + self.training_sample = training_sample + self.index_shard_size = index_shard_size + self.query_batch_size = query_batch_size + self.evaluation_sample = evaluation_sample + self.nprobe = {"prod": [nprobe]} + self.k = k + self.metric = metric + self.normalise = normalise + self.config_file = self.tempdir + "/config_test.yaml" + self.ds_name = "my_test_data" + self.qs_name = "my_queries_data" + self.evaluate_by_margin = evaluate_by_margin + self.with_queries_ds = with_queries_ds + + def create_test_data(self) -> None: + datafiles = self._create_data_files() + files_info = [] + + for i, file in enumerate(datafiles): + files_info.append( + { + "dtype": self.data_type, + "format": "npy", + "name": file, + "size": 
self.file_sizes[i], + } + ) + + config_for_yaml = { + "d": self.dimension, + "output": self.tempdir, + "index": self.index_factory, + "nprobe": self.nprobe, + "k": self.k, + "normalise": self.normalise, + "metric": self.metric, + "training_sample": self.training_sample, + "evaluation_sample": self.evaluation_sample, + "index_shard_size": self.index_shard_size, + "query_batch_size": self.query_batch_size, + "datasets": { + self.ds_name: { + "root": self.tempdir, + "size": sum(self.file_sizes), + "files": files_info, + } + }, + } + if self.evaluate_by_margin: + config_for_yaml["evaluate_by_margin"] = self.evaluate_by_margin + q_datafiles = self._create_data_files("my_q_data") + q_files_info = [] + + for i, file in enumerate(q_datafiles): + q_files_info.append( + { + "dtype": self.data_type, + "format": "npy", + "name": file, + "size": self.file_sizes[i], + } + ) + if self.with_queries_ds: + config_for_yaml["datasets"][self.qs_name] = { + "root": self.tempdir, + "size": sum(self.file_sizes), + "files": q_files_info, + } + + self._create_config_yaml(config_for_yaml) + + def setup_cli(self, command="consistency_check") -> argparse.Namespace: + parser = get_test_parser(OIVF_TEST_ARGS) + + if self.with_queries_ds: + return parser.parse_args( + [ + "--xb", + self.ds_name, + "--config", + self.config_file, + "--command", + command, + "--xq", + self.qs_name, + ] + ) + return parser.parse_args( + [ + "--xb", + self.ds_name, + "--config", + self.config_file, + "--command", + command, + ] + ) + + def _create_data_files(self, name_of_file="my_data") -> List[str]: + """ + Creates a dataset "my_test_data" with number of files (num_files), using padding in the files + name. If self.with_queries is True, it adds an extra dataset "my_queries_data" with the same number of files + as the "my_test_data". The default name for embeddings files is "my_data" + .npy. + """ + filenames = [] + for i, file_size in enumerate(self.file_sizes): + # np.random.seed(i) + db_vectors = np.random.random((file_size, self.dimension)).astype( + self.data_type + ) + filename = name_of_file + f"{i:02}" + ".npy" + filenames.append(filename) + np.save(self.tempdir + "/" + filename, db_vectors) + return filenames + + def _create_config_yaml(self, dict_file: Dict[str, str]) -> None: + """ + Creates a yaml file in dir (can be a temporary dir for tests). + """ + filename = self.tempdir + "/config_test.yaml" + with open(filename, "w") as file: + yaml.dump(dict_file, file, default_flow_style=False) diff --git a/demos/offline_ivf/utils.py b/demos/offline_ivf/utils.py new file mode 100644 index 0000000000..378af00c30 --- /dev/null +++ b/demos/offline_ivf/utils.py @@ -0,0 +1,94 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+
+import numpy as np
+import os
+from typing import Dict
+import yaml
+import faiss
+from faiss.contrib.datasets import SyntheticDataset
+
+
+def load_config(config):
+    assert os.path.exists(config), f"config file {config} does not exist"
+    with open(config, "r") as f:
+        return yaml.safe_load(f)
+
+
+def faiss_sanity_check():
+    """
+    Checks that a CPU flat index and its GPU clone return the same search results.
+    """
+    ds = SyntheticDataset(256, 0, 100, 100)
+    xq = ds.get_queries()
+    xb = ds.get_database()
+    index_cpu = faiss.IndexFlat(ds.d)
+    index_gpu = faiss.index_cpu_to_all_gpus(index_cpu)
+    index_cpu.add(xb)
+    index_gpu.add(xb)
+    D_cpu, I_cpu = index_cpu.search(xq, 10)
+    D_gpu, I_gpu = index_gpu.search(xq, 10)
+    assert np.all(I_cpu == I_gpu), "faiss sanity check failed"
+    assert np.all(np.isclose(D_cpu, D_gpu)), "faiss sanity check failed"
+
+
+def margin(sample, idx_a, idx_b, D_a_b, D_a, D_b, k, k_extract, threshold):
+    """
+    Computes margin scores for candidate pairs drawn from two datasets xa and xb,
+    and reports the pairs whose margin exceeds the threshold.
+    sample - n, the number of query vectors taken from xa
+    idx_a - (n,) - query vector ids in xa
+    idx_b - (n * k_extract,) - ids in xb of the k_extract candidate neighbours of each query
+    D_a_b - (n * k_extract,) - pairwise distances between xa[idx_a] and xb[idx_b]
+    D_a - (n, k) - distances between vectors xa[idx_a] and their k nearest neighbours in xb
+    D_b - (n * k_extract, k) - distances between vectors xb[idx_b] and their k nearest neighbours in xa
+    k - number of nearest neighbours used to normalise the margin
+    k_extract - number of nearest neighbours of each query in xb considered for margin calculation and filtering
+    threshold - margin threshold
+    """
+
+    n = sample
+    nk = n * k_extract
+    assert idx_a.shape == (n,)
+    idx_a_k = idx_a.repeat(k_extract)
+    assert idx_a_k.shape == (nk,)
+    assert idx_b.shape == (nk,)
+    assert D_a_b.shape == (nk,)
+    assert D_a.shape == (n, k)
+    assert D_b.shape == (nk, k)
+    mean_a = np.mean(D_a, axis=1)
+    assert mean_a.shape == (n,)
+    mean_a_k = mean_a.repeat(k_extract)
+    assert mean_a_k.shape == (nk,)
+    mean_b = np.mean(D_b, axis=1)
+    assert mean_b.shape == (nk,)
+    margin = 2 * D_a_b / (mean_a_k + mean_b)
+    above_threshold = margin > threshold
+    # log the pairs whose margin exceeds the threshold
+    print(np.count_nonzero(above_threshold))
+    print(idx_a_k[above_threshold])
+    print(idx_b[above_threshold])
+    print(margin[above_threshold])
+    return margin
+
+
+def add_group_args(group, *args, **kwargs):
+    return group.add_argument(*args, **kwargs)
+
+
+def get_intersection_cardinality_frequencies(
+    I: np.ndarray, I_gt: np.ndarray
+) -> Dict[int, int]:
+    """
+    Computes, for each query, the cardinality of the intersection between the
+    retrieved neighbour ids I and the reference ids I_gt, and returns a
+    histogram mapping each cardinality to the number of queries with that cardinality.
+    """
+    nq = I.shape[0]
+    res = []
+    for ell in range(nq):
+        res.append(len(np.intersect1d(I[ell, :], I_gt[ell, :])))
+    values, counts = np.unique(res, return_counts=True)
+    return dict(zip(values, counts))
+
+
+def is_pretransform_index(index):
+    """
+    Returns True if the index is a faiss.IndexPreTransform (e.g. with an OPQ transform), False otherwise.
+    """
+    if index.__class__ == faiss.IndexPreTransform:
+        assert hasattr(index, "chain")
+        return True
+    else:
+        assert not hasattr(index, "chain")
+        return False
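For reference, a minimal sketch of how `margin` can be exercised on synthetic inputs. The array shapes follow the assertions inside the function; the random values and the `from utils import margin` path are purely illustrative (running it also assumes a working faiss install, since `utils.py` imports faiss at module level).

```python
import numpy as np

from utils import margin  # illustrative import; assumes demos/offline_ivf is on PYTHONPATH

n, k, k_extract, threshold = 4, 3, 2, 1.0
rng = np.random.default_rng(0)

idx_a = np.arange(n)                               # (n,) query ids in xa
idx_b = rng.integers(0, 100, size=n * k_extract)   # (n * k_extract,) candidate neighbour ids in xb
D_a_b = rng.random(n * k_extract)                  # (n * k_extract,) query/candidate pairwise distances
D_a = rng.random((n, k))                           # (n, k) distances from each xa query to its k NNs in xb
D_b = rng.random((n * k_extract, k))               # (n * k_extract, k) distances from each xb candidate to its k NNs in xa

scores = margin(n, idx_a, idx_b, D_a_b, D_a, D_b, k, k_extract, threshold)
print(scores.shape)  # (n * k_extract,): one margin score per candidate pair
```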
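Likewise, a small made-up illustration of `get_intersection_cardinality_frequencies`, which summarises how many retrieved neighbour ids per query match a reference set (the import path is again illustrative).

```python
import numpy as np

from utils import get_intersection_cardinality_frequencies  # illustrative import path

# two queries with k=3 retrieved ids each, compared against reference ids
I = np.array([[1, 2, 3], [4, 5, 6]])
I_gt = np.array([[3, 2, 9], [7, 8, 6]])

# query 0 shares {2, 3} with the reference and query 1 shares {6},
# so the histogram has one query at cardinality 2 and one at cardinality 1
print(get_intersection_cardinality_frequencies(I, I_gt))
```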