From 866e3fe16400a146cf530d9d76908d35a3925f4d Mon Sep 17 00:00:00 2001 From: Matthijs Douze Date: Fri, 20 Sep 2024 06:00:35 -0700 Subject: [PATCH] begin torch_contrib Summary: The contrib.torch subdirectory is intended to receive modules in python that are useful for similarity search and that apply to CPU or GPU pytorch tensors. The current version includes CPU clustering on torch tensors. To be added: * implementation of PQ Differential Revision: D62759207 --- contrib/clustering.py | 45 ++++++++++++++--- contrib/torch/README.md | 6 +++ contrib/torch/__init__.py | 0 contrib/torch/clustering.py | 60 +++++++++++++++++++++++ contrib/torch/quantization.py | 53 ++++++++++++++++++++ contrib/torch_utils.py | 62 ++++++++++++++++++++++++ faiss/gpu/test/torch_test_contrib_gpu.py | 33 +++++++++++++ tests/test_contrib.py | 3 +- tests/torch_test_contrib.py | 30 ++++++++++++ 9 files changed, 282 insertions(+), 10 deletions(-) create mode 100644 contrib/torch/README.md create mode 100644 contrib/torch/__init__.py create mode 100644 contrib/torch/clustering.py create mode 100644 contrib/torch/quantization.py diff --git a/contrib/clustering.py b/contrib/clustering.py index 79b6b05a5f..c1e8775c9b 100644 --- a/contrib/clustering.py +++ b/contrib/clustering.py @@ -285,25 +285,40 @@ def imbalance_factor(k, assign): return faiss.imbalance_factor(len(assign), k, faiss.swig_ptr(assign)) +def check_if_torch(x): + if x.__class__ == np.ndarray: + return False + import torch + if isinstance(x, torch.Tensor): + return True + raise NotImplementedError(f"Unknown tensor type {type(x)}") + + def reassign_centroids(hassign, centroids, rs=None): """ reassign centroids when some of them collapse """ if rs is None: rs = np.random k, d = centroids.shape nsplit = 0 + is_torch = check_if_torch(centroids) + empty_cents = np.where(hassign == 0)[0] - if empty_cents.size == 0: + if len(empty_cents) == 0: return 0 - fac = np.ones(d) + if is_torch: + import torch + fac = torch.ones_like(centroids[0]) + else: + fac = np.ones_like(centroids[0]) fac[::2] += 1 / 1024. fac[1::2] -= 1 / 1024. # this is a single pass unless there are more than k/2 # empty centroids - while empty_cents.size > 0: - # choose which centroids to split + while len(empty_cents) > 0: + # choose which centroids to split (numpy) probas = hassign.astype('float') - 1 probas[probas < 0] = 0 probas /= probas.sum() @@ -327,13 +342,17 @@ def reassign_centroids(hassign, centroids, rs=None): return nsplit + def kmeans(k, data, niter=25, seed=1234, checkpoint=None, verbose=True, return_stats=False): """Pure python kmeans implementation. Follows the Faiss C++ version quite closely, but takes a DatasetAssign instead of a training data - matrix. Also redo is not implemented. """ + matrix. Also redo is not implemented. + + For the torch implementation, the centroids are tensors (possibly on GPU), + but the indices remain numpy on CPU. + """ n, d = data.count(), data.dim() - log = print if verbose else print_nop log(("Clustering %d points in %dD to %d clusters, " + @@ -345,6 +364,7 @@ def kmeans(k, data, niter=25, seed=1234, checkpoint=None, verbose=True, # initialization perm = rs.choice(n, size=k, replace=False) centroids = data.get_subset(perm) + is_torch = check_if_torch(centroids) iteration_stats = [] @@ -362,12 +382,17 @@ def kmeans(k, data, niter=25, seed=1234, checkpoint=None, verbose=True, t_search_tot += time.time() - t0s; err = D.sum() + if is_torch: + err = err.item() obj.append(err) hassign = np.bincount(assign, minlength=k) fac = hassign.reshape(-1, 1).astype('float32') - fac[fac == 0] = 1 # quiet warning + fac[fac == 0] = 1 # quiet warning + if is_torch: + import torch + fac = torch.from_numpy(fac).to(sums.device) centroids = sums / fac @@ -391,7 +416,11 @@ def kmeans(k, data, niter=25, seed=1234, checkpoint=None, verbose=True, if checkpoint is not None: log('storing centroids in', checkpoint) - np.save(checkpoint, centroids) + if is_torch: + import torch + torch.save(centroids, checkpoint) + else: + np.save(checkpoint, centroids) if return_stats: return centroids, iteration_stats diff --git a/contrib/torch/README.md b/contrib/torch/README.md new file mode 100644 index 0000000000..470d062250 --- /dev/null +++ b/contrib/torch/README.md @@ -0,0 +1,6 @@ +# The Torch contrib + +This contrib directory contains a few Pytorch routines that +are useful for similarity search. They do not necessarily depend on Faiss. + +The code is designed to work with CPU and GPU tensors. diff --git a/contrib/torch/__init__.py b/contrib/torch/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/contrib/torch/clustering.py b/contrib/torch/clustering.py new file mode 100644 index 0000000000..bdaa0a1f9a --- /dev/null +++ b/contrib/torch/clustering.py @@ -0,0 +1,60 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +This contrib module contains Pytorch code for k-means clustering +""" +import faiss +import faiss.contrib.torch_utils +import torch + +# the kmeans can produce both torch and numpy centroids +from faiss.contrib.clustering import DatasetAssign, kmeans + + +class DatasetAssign: + """Wrapper for a tensor that offers a function to assign the vectors + to centroids. All other implementations offer the same interface""" + + def __init__(self, x): + self.x = x + + def count(self): + return self.x.shape[0] + + def dim(self): + return self.x.shape[1] + + def get_subset(self, indices): + return self.x[indices] + + def perform_search(self, centroids): + return faiss.knn(self.x, centroids, 1) + + def assign_to(self, centroids, weights=None): + D, I = self.perform_search(centroids) + + I = I.ravel() + D = D.ravel() + nc, d = centroids.shape + + sum_per_centroid = torch.zeros_like(centroids) + if weights is None: + sum_per_centroid.index_add_(0, I, self.x) + else: + sum_per_centroid.index_add_(0, I, self.x * weights[:, None]) + + # the indices are still in numpy. + return I.cpu().numpy(), D, sum_per_centroid + + +class DatasetAssignGPU(DatasetAssign): + + def __init__(self, res, x): + DatasetAssign.__init__(self, x) + self.res = res + + def perform_search(self, centroids): + return faiss.knn_gpu(self.res, self.x, centroids, 1) diff --git a/contrib/torch/quantization.py b/contrib/torch/quantization.py new file mode 100644 index 0000000000..550c17dbb7 --- /dev/null +++ b/contrib/torch/quantization.py @@ -0,0 +1,53 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +This contrib module contains Pytorch code for quantization. +""" + +import numpy as np +import torch +import faiss + +from faiss.contrib import torch_utils + + +class Quantizer: + + def __init__(self, d, code_size): + self.d = d + self.code_size = code_size + + def train(self, x): + pass + + def encode(self, x): + pass + + def decode(self, x): + pass + + +class VectorQuantizer(Quantizer): + + def __init__(self, d, k): + code_size = int(torch.ceil(torch.log2(k) / 8)) + Quantizer.__init__(d, code_size) + self.k = k + + def train(self, x): + pass + + +class ProductQuantizer(Quantizer): + + def __init__(self, d, M, nbits): + code_size = int(torch.ceil(M * nbits / 8)) + Quantizer.__init__(d, code_size) + self.M = M + self.nbits = nbits + + def train(self, x): + pass diff --git a/contrib/torch_utils.py b/contrib/torch_utils.py index 18f136e914..21e6439726 100644 --- a/contrib/torch_utils.py +++ b/contrib/torch_utils.py @@ -28,6 +28,10 @@ import sys import numpy as np +################################################################## +# Equivalent of swig_ptr for Torch tensors +################################################################## + def swig_ptr_from_UInt8Tensor(x): """ gets a Faiss SWIG pointer from a pytorch tensor (on CPU or GPU) """ assert x.is_contiguous() @@ -35,6 +39,7 @@ def swig_ptr_from_UInt8Tensor(x): return faiss.cast_integer_to_uint8_ptr( x.untyped_storage().data_ptr() + x.storage_offset()) + def swig_ptr_from_HalfTensor(x): """ gets a Faiss SWIG pointer from a pytorch tensor (on CPU or GPU) """ assert x.is_contiguous() @@ -43,6 +48,7 @@ def swig_ptr_from_HalfTensor(x): return faiss.cast_integer_to_void_ptr( x.untyped_storage().data_ptr() + x.storage_offset() * 2) + def swig_ptr_from_FloatTensor(x): """ gets a Faiss SWIG pointer from a pytorch tensor (on CPU or GPU) """ assert x.is_contiguous() @@ -50,6 +56,7 @@ def swig_ptr_from_FloatTensor(x): return faiss.cast_integer_to_float_ptr( x.untyped_storage().data_ptr() + x.storage_offset() * 4) + def swig_ptr_from_IntTensor(x): """ gets a Faiss SWIG pointer from a pytorch tensor (on CPU or GPU) """ assert x.is_contiguous() @@ -57,6 +64,7 @@ def swig_ptr_from_IntTensor(x): return faiss.cast_integer_to_int_ptr( x.untyped_storage().data_ptr() + x.storage_offset() * 4) + def swig_ptr_from_IndicesTensor(x): """ gets a Faiss SWIG pointer from a pytorch tensor (on CPU or GPU) """ assert x.is_contiguous() @@ -64,6 +72,10 @@ def swig_ptr_from_IndicesTensor(x): return faiss.cast_integer_to_idx_t_ptr( x.untyped_storage().data_ptr() + x.storage_offset() * 8) +################################################################## +# utilities +################################################################## + @contextlib.contextmanager def using_stream(res, pytorch_stream=None): """ Creates a scoping object to make Faiss GPU use the same stream @@ -107,6 +119,10 @@ def torch_replace_method(the_class, name, replacement, setattr(the_class, name + '_numpy', orig_method) setattr(the_class, name, replacement) +################################################################## +# Setup wrappers +################################################################## + def handle_torch_Index(the_class): def torch_replacement_add(self, x): if type(x) is np.ndarray: @@ -493,6 +509,52 @@ def torch_replacement_sa_decode(self, codes, x=None): handle_torch_Index(the_class) +# allows torch tensor usage with knn +def torch_replacement_knn(xq, xb, k, metric=faiss.METRIC_L2, metric_arg=0): + if type(xb) is np.ndarray: + # Forward to faiss __init__.py base method + return faiss.knn_numpy(xq, xb, k, metric=metric, metric_arg=metric_arg) + + nb, d = xb.size() + assert xb.is_contiguous() + assert xb.dtype == torch.float32 + assert not xb.is_cuda, "use knn_gpu for GPU tensors" + + nq, d2 = xq.size() + assert d2 == d + assert xq.is_contiguous() + assert xq.dtype == torch.float32 + assert not xq.is_cuda, "use knn_gpu for GPU tensors" + + D = torch.empty(nq, k, device=xb.device, dtype=torch.float32) + I = torch.empty(nq, k, device=xb.device, dtype=torch.int64) + I_ptr = swig_ptr_from_IndicesTensor(I) + D_ptr = swig_ptr_from_FloatTensor(D) + xb_ptr = swig_ptr_from_FloatTensor(xb) + xq_ptr = swig_ptr_from_FloatTensor(xq) + + if metric == faiss.METRIC_L2: + faiss.knn_L2sqr( + xq_ptr, xb_ptr, + d, nq, nb, k, D_ptr, I_ptr + ) + elif metric == faiss.METRIC_INNER_PRODUCT: + faiss.knn_inner_product( + xq_ptr, xb_ptr, + d, nq, nb, k, D_ptr, I_ptr + ) + else: + faiss.knn_extra_metrics( + xq_ptr, xb_ptr, + d, nq, nb, metric, metric_arg, k, D_ptr, I_ptr + ) + + return D, I + + +torch_replace_method(faiss_module, 'knn', torch_replacement_knn, True, True) + + # allows torch tensor usage with bfKnn def torch_replacement_knn_gpu(res, xq, xb, k, D=None, I=None, metric=faiss.METRIC_L2, device=-1, use_raft=False): if type(xb) is np.ndarray: diff --git a/faiss/gpu/test/torch_test_contrib_gpu.py b/faiss/gpu/test/torch_test_contrib_gpu.py index f1a92c33b3..6c58b37b25 100644 --- a/faiss/gpu/test/torch_test_contrib_gpu.py +++ b/faiss/gpu/test/torch_test_contrib_gpu.py @@ -9,6 +9,10 @@ import faiss import faiss.contrib.torch_utils +from faiss.contrib import datasets +from faiss.contrib.torch import clustering + + def to_column_major_torch(x): if hasattr(torch, 'contiguous_format'): return x.t().clone(memory_format=torch.contiguous_format).t() @@ -377,6 +381,7 @@ def test_knn_gpu_datatypes(self, use_raft=False): self.assertTrue(torch.equal(torch.from_numpy(I).long(), gt_I)) self.assertLess((torch.from_numpy(D) - gt_D).abs().max(), 1.5e-3) + class TestTorchUtilsPairwiseDistanceGpu(unittest.TestCase): def test_pairwise_distance_gpu(self): torch.manual_seed(10) @@ -470,3 +475,31 @@ def test_pairwise_distance_gpu(self): D, _ = torch.sort(D, dim=1) self.assertLess((D.cpu() - gt_D[4:8]).abs().max(), 1e-4) + + +class TestClustering(unittest.TestCase): + + def test_python_kmeans(self): + """ Test the python implementation of kmeans """ + ds = datasets.SyntheticDataset(32, 10000, 0, 0) + x = ds.get_train() + + # bad distribution to stress-test split code + xt = x[:10000].copy() + xt[:5000] = x[0] + + # CPU baseline + km_ref = faiss.Kmeans(ds.d, 100, niter=10) + km_ref.train(xt) + err = faiss.knn(xt, km_ref.centroids, 1)[0].sum() + + xt_torch = torch.from_numpy(xt).to("cuda:0") + res = faiss.StandardGpuResources() + data = clustering.DatasetAssignGPU(res, xt_torch) + centroids = clustering.kmeans(100, data, 10) + centroids = centroids.cpu().numpy() + err2 = faiss.knn(xt, centroids, 1)[0].sum() + + # 33498.332 33380.477 + print(err, err2) + self.assertLess(err2, err * 1.1) diff --git a/tests/test_contrib.py b/tests/test_contrib.py index fa5d85ab51..a2eb7046bd 100644 --- a/tests/test_contrib.py +++ b/tests/test_contrib.py @@ -26,8 +26,7 @@ range_search_max_results, exponential_query_iterator from contextlib import contextmanager -@unittest.skipIf(platform.python_version_tuple()[0] < '3', - 'Submodule import broken in python 2.') + class TestComputeGT(unittest.TestCase): def do_test_compute_GT(self, metric=faiss.METRIC_L2): diff --git a/tests/torch_test_contrib.py b/tests/torch_test_contrib.py index e26a79c6bb..41311d6c78 100644 --- a/tests/torch_test_contrib.py +++ b/tests/torch_test_contrib.py @@ -9,6 +9,10 @@ import faiss # usort: skip import faiss.contrib.torch_utils # usort: skip +from faiss.contrib import datasets +from faiss.contrib.torch import clustering + + class TestTorchUtilsCPU(unittest.TestCase): # tests add, search @@ -344,3 +348,29 @@ def test_non_contiguous(self): # disabled since we now accept non-contiguous arrays # with self.assertRaises(ValueError): # index.add(xb.numpy()) + + +class TestClustering(unittest.TestCase): + + def test_python_kmeans(self): + """ Test the python implementation of kmeans """ + ds = datasets.SyntheticDataset(32, 10000, 0, 0) + x = ds.get_train() + + # bad distribution to stress-test split code + xt = x[:10000].copy() + xt[:5000] = x[0] + + km_ref = faiss.Kmeans(ds.d, 100, niter=10) + km_ref.train(xt) + err = faiss.knn(xt, km_ref.centroids, 1)[0].sum() + + xt_torch = torch.from_numpy(xt) + data = clustering.DatasetAssign(xt_torch) + centroids = clustering.kmeans(100, data, 10) + centroids = centroids.numpy() + err2 = faiss.knn(xt, centroids, 1)[0].sum() + + # 33498.332 33380.477 + # print(err, err2) 1/0 + self.assertLess(err2, err * 1.1)