From e4d594d2b1afa746b9ba0ee7ebba0490341cd911 Mon Sep 17 00:00:00 2001
From: Matthijs Douze
Date: Fri, 20 Sep 2024 03:24:41 -0700
Subject: [PATCH] torch.distributed kmeans (#3876)

Summary:
Pull Request resolved: https://github.com/facebookresearch/faiss/pull/3876

Demo script for distributed kmeans. It provides a `DatasetAssign` object and
shows how to run it with torch.distributed.

Reviewed By: asadoughi, pankajsingh88

Differential Revision: D63013820
---
 contrib/clustering.py                    |  32 ++---
 contrib/torch/clustering.py              |   2 +-
 contrib/torch/quantization.py            |  30 ++--
 contrib/torch_utils.py                   |   6 +-
 demos/demo_distributed_kmeans_torch.py   | 171 +++++++++++++++++++++++
 faiss/gpu/test/torch_test_contrib_gpu.py |   2 +-
 faiss/python/CMakeLists.txt              |   4 +-
 faiss/python/setup.py                    |   2 +-
 tests/torch_test_contrib.py              |   4 +-
 9 files changed, 212 insertions(+), 41 deletions(-)
 create mode 100644 demos/demo_distributed_kmeans_torch.py

diff --git a/contrib/clustering.py b/contrib/clustering.py
index c1e8775c9b..19c2656dc1 100644
--- a/contrib/clustering.py
+++ b/contrib/clustering.py
@@ -155,7 +155,7 @@ def assign_to(self, centroids, weights=None):
         sum_per_centroid = np.zeros((nc, d), dtype='float32')
         if weights is None:
             np.add.at(sum_per_centroid, I, self.x)
-        else: 
+        else:
             np.add.at(sum_per_centroid, I, weights[:, np.newaxis] * self.x)

         return I, D, sum_per_centroid
@@ -183,7 +183,7 @@ def perform_search(self, centroids):

 def sparse_assign_to_dense(xq, xb, xq_norms=None, xb_norms=None):
     """ assignment function for xq is sparse, xb is dense
-    uses a matrix multiplication. The squared norms can be provided if 
+    uses a matrix multiplication. The squared norms can be provided if
     available.
     """
     nq = xq.shape[0]
@@ -271,7 +271,7 @@ def assign_to(self, centroids, weights=None):
         if weights is None:
             weights = np.ones(n, dtype='float32')
         nc = len(centroids)
-        
+
         m = scipy.sparse.csc_matrix(
             (weights, I, np.arange(n + 1)),
             shape=(nc, n))
@@ -289,7 +289,7 @@ def check_if_torch(x):
     if x.__class__ == np.ndarray:
         return False
     import torch
-    if isinstance(x, torch.Tensor): 
+    if isinstance(x, torch.Tensor):
         return True
     raise NotImplementedError(f"Unknown tensor type {type(x)}")

@@ -307,11 +307,11 @@ def reassign_centroids(hassign, centroids, rs=None):
     if len(empty_cents) == 0:
         return 0

-    if is_torch: 
+    if is_torch:
         import torch
-        fac = torch.ones_like(centroids[0]) 
-    else: 
-        fac = np.ones_like(centroids[0]) 
+        fac = torch.ones_like(centroids[0])
+    else:
+        fac = np.ones_like(centroids[0])

     fac[::2] += 1 / 1024.
     fac[1::2] -= 1 / 1024.
@@ -347,9 +347,9 @@ def kmeans(k, data, niter=25, seed=1234, checkpoint=None, verbose=True,
            return_stats=False):
     """Pure python kmeans implementation. Follows the Faiss C++ version quite
     closely, but takes a DatasetAssign instead of a training data
-    matrix. Also redo is not implemented. 
-    
-    For the torch implementation, the centroids are tensors (possibly on GPU), 
+    matrix. Also redo is not implemented.
+
+    For the torch implementation, the centroids are tensors (possibly on GPU),
     but the indices remain numpy on CPU.
""" n, d = data.count(), data.dim() @@ -382,7 +382,7 @@ def kmeans(k, data, niter=25, seed=1234, checkpoint=None, verbose=True, t_search_tot += time.time() - t0s; err = D.sum() - if is_torch: + if is_torch: err = err.item() obj.append(err) @@ -390,7 +390,7 @@ def kmeans(k, data, niter=25, seed=1234, checkpoint=None, verbose=True, fac = hassign.reshape(-1, 1).astype('float32') fac[fac == 0] = 1 # quiet warning - if is_torch: + if is_torch: import torch fac = torch.from_numpy(fac).to(sums.device) @@ -402,7 +402,7 @@ def kmeans(k, data, niter=25, seed=1234, checkpoint=None, verbose=True, "obj": err, "time": (time.time() - t0), "time_search": t_search_tot, - "imbalance_factor": imbalance_factor (k, assign), + "imbalance_factor": imbalance_factor(k, assign), "nsplit": nsplit } @@ -416,10 +416,10 @@ def kmeans(k, data, niter=25, seed=1234, checkpoint=None, verbose=True, if checkpoint is not None: log('storing centroids in', checkpoint) - if is_torch: + if is_torch: import torch torch.save(centroids, checkpoint) - else: + else: np.save(checkpoint, centroids) if return_stats: diff --git a/contrib/torch/clustering.py b/contrib/torch/clustering.py index bdaa0a1f9a..32d3a75732 100644 --- a/contrib/torch/clustering.py +++ b/contrib/torch/clustering.py @@ -52,7 +52,7 @@ def assign_to(self, centroids, weights=None): class DatasetAssignGPU(DatasetAssign): - def __init__(self, res, x): + def __init__(self, res, x): DatasetAssign.__init__(self, x) self.res = res diff --git a/contrib/torch/quantization.py b/contrib/torch/quantization.py index 550c17dbb7..8d6b17fa8f 100644 --- a/contrib/torch/quantization.py +++ b/contrib/torch/quantization.py @@ -8,46 +8,46 @@ """ import numpy as np -import torch +import torch import faiss from faiss.contrib import torch_utils -class Quantizer: +class Quantizer: - def __init__(self, d, code_size): - self.d = d + def __init__(self, d, code_size): + self.d = d self.code_size = code_size - def train(self, x): + def train(self, x): pass - - def encode(self, x): + + def encode(self, x): pass - - def decode(self, x): + + def decode(self, x): pass -class VectorQuantizer(Quantizer): +class VectorQuantizer(Quantizer): - def __init__(self, d, k): + def __init__(self, d, k): code_size = int(torch.ceil(torch.log2(k) / 8)) Quantizer.__init__(d, code_size) self.k = k - def train(self, x): + def train(self, x): pass -class ProductQuantizer(Quantizer): +class ProductQuantizer(Quantizer): - def __init__(self, d, M, nbits): + def __init__(self, d, M, nbits): code_size = int(torch.ceil(M * nbits / 8)) Quantizer.__init__(d, code_size) self.M = M self.nbits = nbits - def train(self, x): + def train(self, x): pass diff --git a/contrib/torch_utils.py b/contrib/torch_utils.py index 21e6439726..9b4855ea3a 100644 --- a/contrib/torch_utils.py +++ b/contrib/torch_utils.py @@ -73,7 +73,7 @@ def swig_ptr_from_IndicesTensor(x): x.untyped_storage().data_ptr() + x.storage_offset() * 8) ################################################################## -# utilities +# utilities ################################################################## @contextlib.contextmanager @@ -519,7 +519,7 @@ def torch_replacement_knn(xq, xb, k, metric=faiss.METRIC_L2, metric_arg=0): assert xb.is_contiguous() assert xb.dtype == torch.float32 assert not xb.is_cuda, "use knn_gpu for GPU tensors" - + nq, d2 = xq.size() assert d2 == d assert xq.is_contiguous() @@ -543,7 +543,7 @@ def torch_replacement_knn(xq, xb, k, metric=faiss.METRIC_L2, metric_arg=0): xq_ptr, xb_ptr, d, nq, nb, k, D_ptr, I_ptr ) - else: + else: 
         faiss.knn_extra_metrics(
             xq_ptr, xb_ptr,
             d, nq, nb, metric, metric_arg, k, D_ptr, I_ptr
diff --git a/demos/demo_distributed_kmeans_torch.py b/demos/demo_distributed_kmeans_torch.py
new file mode 100644
index 0000000000..e868570dc9
--- /dev/null
+++ b/demos/demo_distributed_kmeans_torch.py
@@ -0,0 +1,171 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import numpy as np
+
+import torch
+import torch.distributed
+
+import faiss
+
+# did not manage to run it in fbcode, so grafted in the github version with torchrun
+#
+# import new_contrib
+new_contrib = faiss.contrib
+
+import new_contrib.torch_utils
+from new_contrib.torch import clustering
+from new_contrib import datasets
+
+
+class DatasetAssignDistributedGPU(clustering.DatasetAssign):
+    """
+    There is one instance per worker, each worker has a dataset shard.
+    The non-master workers do not run through the k-means function, so some
+    code has to run to keep the workers in sync.
+    """
+
+
+    def __init__(self, res, x, rank, nproc):
+        clustering.DatasetAssign.__init__(self, x)
+        self.res = res
+        self.rank = rank
+        self.nproc = nproc
+        self.device = x.device
+
+        n = len(x)
+        sizes = torch.zeros(nproc, device=self.device, dtype=torch.int64)
+        sizes[rank] = n
+        torch.distributed.all_gather([sizes[i:i+1] for i in range(nproc)], sizes[rank:rank+1])
+        self.sizes = sizes.cpu().numpy()
+
+        # begin & end of each shard
+        self.cs = np.zeros(nproc + 1, dtype='int64')
+        self.cs[1:] = np.cumsum(self.sizes)
+
+    def count(self):
+        return int(self.sizes.sum())
+
+    def int_to_slaves(self, i):
+        " broadcast an int to all workers "
+        rank, nproc = self.rank, self.nproc
+        tab = torch.zeros(1, device=self.device, dtype=torch.int64)
+        if rank == 0:
+            tab[0] = i
+        else:
+            assert i is None
+        torch.distributed.broadcast(tab, 0)
+        return tab.item()
+
+    def get_subset(self, indices):
+        rank, nproc = self.rank, self.nproc
+        assert rank == 0 or indices is None
+
+        len_indices = self.int_to_slaves(len(indices) if rank == 0 else None)
+
+        if rank == 0:
+            indices = torch.from_numpy(indices).to(self.device)
+        else:
+            indices = torch.zeros(len_indices, dtype=torch.int64, device=self.device)
+        torch.distributed.broadcast(indices, 0)
+
+        # select subset of indices
+
+        i0, i1 = self.cs[rank], self.cs[rank + 1]
+
+        mask = torch.logical_and(indices < i1, indices >= i0)
+        output = torch.zeros(len_indices, self.x.shape[1], dtype=self.x.dtype, device=self.device)
+        output[mask] = self.x[indices[mask] - i0]
+        torch.distributed.reduce(output, 0)  # sum
+        if rank == 0:
+            return output
+        else:
+            return None
+
+    def perform_search(self, centroids):
+        assert False, "should not be called"
+
+    def assign_to(self, centroids, weights=None):
+        assert weights is None
+
+        rank, nproc = self.rank, self.nproc
+        assert rank == 0 or centroids is None
+        nc = self.int_to_slaves(len(centroids) if rank == 0 else None)
+
+        if rank != 0:
+            centroids = torch.zeros(nc, self.x.shape[1], dtype=self.x.dtype, device=self.device)
+        torch.distributed.broadcast(centroids, 0)
+
+        # perform search
+        D, I = faiss.knn_gpu(self.res, self.x, centroids, 1, device=self.device.index)
+
+        I = I.ravel()
+        D = D.ravel()
+
+        sum_per_centroid = torch.zeros_like(centroids)
+        if weights is None:
+            sum_per_centroid.index_add_(0, I, self.x)
+        else:
+            sum_per_centroid.index_add_(0, I, self.x * weights[:, None])
+
+        torch.distributed.reduce(sum_per_centroid, 0)
+
+        if rank == 0:
+            # gather does not support tensors of different sizes
+            # should be implemented with point-to-point communication
+            assert np.all(self.sizes == self.sizes[0])
+            all_I = torch.zeros(self.count(), dtype=I.dtype, device=self.device)
+            all_D = torch.zeros(self.count(), dtype=D.dtype, device=self.device)
+            torch.distributed.gather(
+                I, [all_I[self.cs[r]:self.cs[r + 1]] for r in range(nproc)],
+                dst=0,
+            )
+            torch.distributed.gather(
+                D, [all_D[self.cs[r]:self.cs[r + 1]] for r in range(nproc)],
+                dst=0,
+            )
+            return all_I.cpu().numpy(), all_D, sum_per_centroid
+        else:
+            torch.distributed.gather(I, None, dst=0)
+            torch.distributed.gather(D, None, dst=0)
+            return None
+
+
+if __name__ == "__main__":
+
+    torch.distributed.init_process_group(
+        backend="nccl",
+    )
+    rank = torch.distributed.get_rank()
+    nproc = torch.distributed.get_world_size()
+
+    # the current version only supports shards of the same size
+    ds = datasets.SyntheticDataset(32, 10000, 0, 0, seed=1234 + rank)
+    x = ds.get_train()
+
+    device = torch.device(f"cuda:{rank}")
+
+    torch.cuda.set_device(device)
+    x = torch.from_numpy(x).to(device)
+    res = faiss.StandardGpuResources()
+
+    da = DatasetAssignDistributedGPU(res, x, rank, nproc)
+
+    k = 1000
+    niter = 25
+
+    if rank == 0:
+        print(f"sizes = {da.sizes}")
+        centroids, iteration_stats = clustering.kmeans(k, da, niter=niter, return_stats=True)
+        print("clusters:", centroids.cpu().numpy())
+    else:
+        # make sure the iterations are aligned with master
+        da.get_subset(None)
+
+        for it in range(niter):
+            da.assign_to(None)
+
+    torch.distributed.barrier()
+    print("Done")
diff --git a/faiss/gpu/test/torch_test_contrib_gpu.py b/faiss/gpu/test/torch_test_contrib_gpu.py
index 6c58b37b25..1f6f27ecca 100644
--- a/faiss/gpu/test/torch_test_contrib_gpu.py
+++ b/faiss/gpu/test/torch_test_contrib_gpu.py
@@ -501,5 +501,5 @@ def test_python_kmeans(self):
         err2 = faiss.knn(xt, centroids, 1)[0].sum()

         # 33498.332 33380.477
-        print(err, err2) 
+        print(err, err2)
         self.assertLess(err2, err * 1.1)
diff --git a/faiss/python/CMakeLists.txt b/faiss/python/CMakeLists.txt
index c7b22d19c8..aea99af795 100644
--- a/faiss/python/CMakeLists.txt
+++ b/faiss/python/CMakeLists.txt
@@ -261,5 +261,5 @@ configure_file(gpu_wrappers.py gpu_wrappers.py COPYONLY)
 configure_file(extra_wrappers.py extra_wrappers.py COPYONLY)
 configure_file(array_conversions.py array_conversions.py COPYONLY)

-file(GLOB files "${PROJECT_SOURCE_DIR}/../../contrib/*.py")
-file(COPY ${files} DESTINATION contrib/)
+# file(GLOB files "${PROJECT_SOURCE_DIR}/../../contrib/*.py")
+file(COPY ${PROJECT_SOURCE_DIR}/../../contrib DESTINATION .)
diff --git a/faiss/python/setup.py b/faiss/python/setup.py
index ea623ee1b2..ad4733472d 100644
--- a/faiss/python/setup.py
+++ b/faiss/python/setup.py
@@ -84,7 +84,7 @@

     keywords='search nearest neighbors',
     install_requires=['numpy', 'packaging'],
-    packages=['faiss', 'faiss.contrib'],
+    packages=['faiss', 'faiss.contrib', 'faiss.contrib.torch'],
     package_data={
         'faiss': ['*.so', '*.pyd'],
     },
diff --git a/tests/torch_test_contrib.py b/tests/torch_test_contrib.py
index 41311d6c78..d3dd8c0ae8 100644
--- a/tests/torch_test_contrib.py
+++ b/tests/torch_test_contrib.py
@@ -9,7 +9,7 @@
 import faiss  # usort: skip
 import faiss.contrib.torch_utils  # usort: skip

-from faiss.contrib import datasets 
+from faiss.contrib import datasets
 from faiss.contrib.torch import clustering


@@ -365,7 +365,7 @@ def test_python_kmeans(self):
         km_ref.train(xt)
         err = faiss.knn(xt, km_ref.centroids, 1)[0].sum()

-        xt_torch = torch.from_numpy(xt) 
+        xt_torch = torch.from_numpy(xt)
         data = clustering.DatasetAssign(xt_torch)
         centroids = clustering.kmeans(100, data, 10)
         centroids = centroids.numpy()
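
Usage note: the torch `DatasetAssign` / `kmeans` pair exercised in `tests/torch_test_contrib.py` above can also be driven directly in a single process, without `torch.distributed`. The sketch below is illustrative only (the random data, dimensions, k and niter values are made up, not taken from the patch); it assumes a Faiss install that ships the `faiss.contrib.torch` package added by this change.

    import torch

    import faiss
    import faiss.contrib.torch_utils  # registers torch tensor support in the faiss wrappers
    from faiss.contrib.torch import clustering

    # illustrative training set: 10000 random float32 vectors in 32 dimensions
    xt = torch.rand(10000, 32)

    # DatasetAssign wraps the training tensor; kmeans() runs the pure-python loop
    data = clustering.DatasetAssign(xt)
    centroids = clustering.kmeans(100, data, niter=10)
    print(centroids.shape)  # torch tensor of shape (100, 32)

The distributed demo itself is meant to be launched through torchrun with one process per GPU, as its header comment notes; the exact invocation, e.g. `torchrun --nproc_per_node=<ngpus> demos/demo_distributed_kmeans_torch.py`, is an assumption and not part of the patch.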