torch.distributed kmeans (#3876)
Summary:
Pull Request resolved: #3876

Demo script for distributed kmeans. It provides a `DatasetAssign` object and shows how to run it with torch.distributed.

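For context, a minimal single-process sketch of the contrib API that the demo distributes (the dataset size, dimensionality, `k` and device below are illustrative placeholders, and a CUDA-enabled faiss build is assumed):

# Sketch only, not part of this commit: wrap a data shard in a
# DatasetAssign object and hand it to the pure-Python kmeans.
import torch
import faiss
import faiss.contrib.torch_utils  # lets faiss routines accept torch tensors
from faiss.contrib.torch import clustering

x = torch.rand(10000, 32, device="cuda:0")         # training vectors (float32)
res = faiss.StandardGpuResources()
da = clustering.DatasetAssignGPU(res, x)           # wraps the local data shard
centroids = clustering.kmeans(1000, da, niter=25)  # torch tensor of centroids
print(centroids.shape)

The distributed demo below replaces this wrapper with a `DatasetAssignDistributedGPU` that broadcasts centroids and reduces per-centroid sums across workers; it is presumably launched with a torch.distributed launcher such as torchrun, since `init_process_group(backend="nccl")` reads the rank and world size from the environment.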
Reviewed By: asadoughi, pankajsingh88

Differential Revision: D63013820
mdouze authored and facebook-github-bot committed Sep 20, 2024
1 parent 866e3fe commit c96f0f4
Showing 9 changed files with 215 additions and 43 deletions.
contrib/clustering.py: 16 additions, 16 deletions
@@ -155,7 +155,7 @@ def assign_to(self, centroids, weights=None):
sum_per_centroid = np.zeros((nc, d), dtype='float32')
if weights is None:
np.add.at(sum_per_centroid, I, self.x)
else:
np.add.at(sum_per_centroid, I, weights[:, np.newaxis] * self.x)

return I, D, sum_per_centroid
@@ -183,7 +183,7 @@ def perform_search(self, centroids):

def sparse_assign_to_dense(xq, xb, xq_norms=None, xb_norms=None):
""" assignment function for xq is sparse, xb is dense
uses a matrix multiplication. The squared norms can be provided if
available.
"""
nq = xq.shape[0]
@@ -271,7 +271,7 @@ def assign_to(self, centroids, weights=None):
if weights is None:
weights = np.ones(n, dtype='float32')
nc = len(centroids)

m = scipy.sparse.csc_matrix(
(weights, I, np.arange(n + 1)),
shape=(nc, n))
@@ -289,7 +289,7 @@ def check_if_torch(x):
if x.__class__ == np.ndarray:
return False
import torch
if isinstance(x, torch.Tensor):
return True
raise NotImplementedError(f"Unknown tensor type {type(x)}")

@@ -307,11 +307,11 @@ def reassign_centroids(hassign, centroids, rs=None):
if len(empty_cents) == 0:
return 0

if is_torch:
import torch
fac = torch.ones_like(centroids[0])
else:
fac = np.ones_like(centroids[0])
fac[::2] += 1 / 1024.
fac[1::2] -= 1 / 1024.

@@ -347,9 +347,9 @@ def kmeans(k, data, niter=25, seed=1234, checkpoint=None, verbose=True,
return_stats=False):
"""Pure python kmeans implementation. Follows the Faiss C++ version
quite closely, but takes a DatasetAssign instead of a training data
matrix. Also redo is not implemented.
For the torch implementation, the centroids are tensors (possibly on GPU),
but the indices remain numpy on CPU.
"""
n, d = data.count(), data.dim()
@@ -382,15 +382,15 @@ def kmeans(k, data, niter=25, seed=1234, checkpoint=None, verbose=True,
t_search_tot += time.time() - t0s;

err = D.sum()
if is_torch:
err = err.item()
obj.append(err)

hassign = np.bincount(assign, minlength=k)

fac = hassign.reshape(-1, 1).astype('float32')
fac[fac == 0] = 1 # quiet warning
if is_torch:
import torch
fac = torch.from_numpy(fac).to(sums.device)

@@ -402,7 +402,7 @@ def kmeans(k, data, niter=25, seed=1234, checkpoint=None, verbose=True,
"obj": err,
"time": (time.time() - t0),
"time_search": t_search_tot,
"imbalance_factor": imbalance_factor (k, assign),
"imbalance_factor": imbalance_factor(k, assign),
"nsplit": nsplit
}

@@ -416,10 +416,10 @@ def kmeans(k, data, niter=25, seed=1234, checkpoint=None, verbose=True,

if checkpoint is not None:
log('storing centroids in', checkpoint)
if is_torch:
import torch
torch.save(centroids, checkpoint)
else:
np.save(checkpoint, centroids)

if return_stats:
contrib/torch/clustering.py: 2 additions, 3 deletions
@@ -11,8 +11,7 @@
import torch

# the kmeans can produce both torch and numpy centroids
from faiss.contrib.clustering import kmeans

class DatasetAssign:
"""Wrapper for a tensor that offers a function to assign the vectors
@@ -52,7 +51,7 @@ def assign_to(self, centroids, weights=None):

class DatasetAssignGPU(DatasetAssign):

def __init__(self, res, x):
DatasetAssign.__init__(self, x)
self.res = res

contrib/torch/quantization.py: 15 additions, 15 deletions
@@ -8,46 +8,46 @@
"""

import numpy as np
import torch
import faiss

from faiss.contrib import torch_utils


class Quantizer:

    def __init__(self, d, code_size):
        self.d = d
        self.code_size = code_size

    def train(self, x):
        pass

    def encode(self, x):
        pass

    def decode(self, x):
        pass


class VectorQuantizer(Quantizer):

    def __init__(self, d, k):
        # bytes needed to store an index in [0, k); numpy is used because
        # torch.log2 / torch.ceil expect tensor inputs
        code_size = int(np.ceil(np.log2(k) / 8))
        Quantizer.__init__(self, d, code_size)
        self.k = k

    def train(self, x):
        pass


class ProductQuantizer(Quantizer):

    def __init__(self, d, M, nbits):
        code_size = int(np.ceil(M * nbits / 8))  # total code bits rounded up to bytes
        Quantizer.__init__(self, d, code_size)
        self.M = M
        self.nbits = nbits

    def train(self, x):
        pass
contrib/torch_utils.py: 3 additions, 3 deletions
@@ -73,7 +73,7 @@ def swig_ptr_from_IndicesTensor(x):
x.untyped_storage().data_ptr() + x.storage_offset() * 8)

##################################################################
# utilities
##################################################################

@contextlib.contextmanager
@@ -519,7 +519,7 @@ def torch_replacement_knn(xq, xb, k, metric=faiss.METRIC_L2, metric_arg=0):
assert xb.is_contiguous()
assert xb.dtype == torch.float32
assert not xb.is_cuda, "use knn_gpu for GPU tensors"

nq, d2 = xq.size()
assert d2 == d
assert xq.is_contiguous()
@@ -543,7 +543,7 @@ def torch_replacement_knn(xq, xb, k, metric=faiss.METRIC_L2, metric_arg=0):
xq_ptr, xb_ptr,
d, nq, nb, k, D_ptr, I_ptr
)
else:
faiss.knn_extra_metrics(
xq_ptr, xb_ptr,
d, nq, nb, metric, metric_arg, k, D_ptr, I_ptr
demos/demo_distributed_kmeans_torch.py: 173 additions, 0 deletions (new file)
@@ -0,0 +1,173 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np

import torch
import torch.distributed

import faiss

import faiss.contrib.torch_utils
from faiss.contrib.torch import clustering
from faiss.contrib import datasets


class DatasetAssignDistributedGPU(clustering.DatasetAssign):
"""
There is one instance per worker, each worker has a dataset shard.
The non-master workers do not run through the k-means function, so some
code has to run on them to keep the workers in sync.
"""

def __init__(self, res, x, rank, nproc):
clustering.DatasetAssign.__init__(self, x)
self.res = res
self.rank = rank
self.nproc = nproc
self.device = x.device

n = len(x)
sizes = torch.zeros(nproc, device=self.device, dtype=torch.int64)
sizes[rank] = n
torch.distributed.all_gather(
[sizes[i:i + 1] for i in range(nproc)], sizes[rank:rank + 1])
self.sizes = sizes.cpu().numpy()

# begin & end of each shard
self.cs = np.zeros(nproc + 1, dtype='int64')
self.cs[1:] = np.cumsum(self.sizes)

def count(self):
return int(self.sizes.sum())

def int_to_slaves(self, i):
" broadcast an int to all workers "
rank = self.rank
tab = torch.zeros(1, device=self.device, dtype=torch.int64)
if rank == 0:
tab[0] = i
else:
assert i is None
torch.distributed.broadcast(tab, 0)
return tab.item()

def get_subset(self, indices):
rank = self.rank
assert rank == 0 or indices is None

len_indices = self.int_to_slaves(len(indices) if rank == 0 else None)

if rank == 0:
indices = torch.from_numpy(indices).to(self.device)
else:
indices = torch.zeros(
len_indices, dtype=torch.int64, device=self.device)
torch.distributed.broadcast(indices, 0)

# select subset of indices

i0, i1 = self.cs[rank], self.cs[rank + 1]

mask = torch.logical_and(indices < i1, indices >= i0)
output = torch.zeros(
len_indices, self.x.shape[1],
dtype=self.x.dtype, device=self.device)
output[mask] = self.x[indices[mask] - i0]
torch.distributed.reduce(output, 0) # sum
if rank == 0:
return output
else:
return None

def perform_search(self, centroids):
assert False, "should not be called"

def assign_to(self, centroids, weights=None):
assert weights is None

rank, nproc = self.rank, self.nproc
assert rank == 0 or centroids is None
nc = self.int_to_slaves(len(centroids) if rank == 0 else None)

if rank != 0:
centroids = torch.zeros(
nc, self.x.shape[1], dtype=self.x.dtype, device=self.device)
torch.distributed.broadcast(centroids, 0)

# perform search
D, I = faiss.knn_gpu(
self.res, self.x, centroids, 1, device=self.device.index)

I = I.ravel()
D = D.ravel()

sum_per_centroid = torch.zeros_like(centroids)
if weights is None:
sum_per_centroid.index_add_(0, I, self.x)
else:
sum_per_centroid.index_add_(0, I, self.x * weights[:, None])

torch.distributed.reduce(sum_per_centroid, 0)

if rank == 0:
# gather does not support tensors of different sizes
# should be implemented with point-to-point communication
assert np.all(self.sizes == self.sizes[0])
device = self.device
all_I = torch.zeros(self.count(), dtype=I.dtype, device=device)
all_D = torch.zeros(self.count(), dtype=D.dtype, device=device)
torch.distributed.gather(
I, [all_I[self.cs[r]:self.cs[r + 1]] for r in range(nproc)],
dst=0,
)
torch.distributed.gather(
D, [all_D[self.cs[r]:self.cs[r + 1]] for r in range(nproc)],
dst=0,
)
return all_I.cpu().numpy(), all_D, sum_per_centroid
else:
torch.distributed.gather(I, None, dst=0)
torch.distributed.gather(D, None, dst=0)
return None


if __name__ == "__main__":

torch.distributed.init_process_group(
backend="nccl",
)
rank = torch.distributed.get_rank()
nproc = torch.distributed.get_world_size()

# the current version only supports shards of the same size
ds = datasets.SyntheticDataset(32, 10000, 0, 0, seed=1234 + rank)
x = ds.get_train()

device = torch.device(f"cuda:{rank}")

torch.cuda.set_device(device)
x = torch.from_numpy(x).to(device)
res = faiss.StandardGpuResources()

da = DatasetAssignDistributedGPU(res, x, rank, nproc)

k = 1000
niter = 25

if rank == 0:
print(f"sizes = {da.sizes}")
centroids, iteration_stats = clustering.kmeans(
k, da, niter=niter, return_stats=True)
print("clusters:", centroids.cpu().numpy())
else:
# make sure the iterations stay aligned with the master worker
da.get_subset(None)

for _ in range(niter):
da.assign_to(None)

torch.distributed.barrier()
print("Done")
faiss/gpu/test/torch_test_contrib_gpu.py: 1 addition, 1 deletion
@@ -501,5 +501,5 @@ def test_python_kmeans(self):
err2 = faiss.knn(xt, centroids, 1)[0].sum()

# 33498.332 33380.477
print(err, err2)
self.assertLess(err2, err * 1.1)