Skip to content

Commit

Permalink
begin torch_contrib
Browse files Browse the repository at this point in the history
Summary:
The contrib.torch subdirectory is intended to receive modules in python that are useful for similarity search and that apply to CPU or GPU pytorch tensors.

The current version includes CPU clustering on torch tensors. To be added:
* implementation of PQ

Differential Revision: D62759207
  • Loading branch information
mdouze authored and facebook-github-bot committed Sep 20, 2024
1 parent 281c604 commit 866e3fe
Show file tree
Hide file tree
Showing 9 changed files with 282 additions and 10 deletions.
45 changes: 37 additions & 8 deletions contrib/clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,25 +285,40 @@ def imbalance_factor(k, assign):
return faiss.imbalance_factor(len(assign), k, faiss.swig_ptr(assign))


def check_if_torch(x):
    """Tell whether `x` is a torch tensor or a numpy array.

    torch is only imported when `x` is not a numpy array, so numpy-only
    callers do not need torch to be importable.

    :param x: object to inspect
    :return: False for a numpy array (or subclass), True for a torch.Tensor
    :raises NotImplementedError: if `x` is neither of the two
    """
    # isinstance (rather than an exact class comparison) also accepts
    # ndarray subclasses such as np.memmap
    if isinstance(x, np.ndarray):
        return False
    import torch
    if isinstance(x, torch.Tensor):
        return True
    raise NotImplementedError(f"Unknown tensor type {type(x)}")


def reassign_centroids(hassign, centroids, rs=None):
""" reassign centroids when some of them collapse """
if rs is None:
rs = np.random
k, d = centroids.shape
nsplit = 0
is_torch = check_if_torch(centroids)

empty_cents = np.where(hassign == 0)[0]

if empty_cents.size == 0:
if len(empty_cents) == 0:
return 0

fac = np.ones(d)
if is_torch:
import torch
fac = torch.ones_like(centroids[0])
else:
fac = np.ones_like(centroids[0])
fac[::2] += 1 / 1024.
fac[1::2] -= 1 / 1024.

# this is a single pass unless there are more than k/2
# empty centroids
while empty_cents.size > 0:
# choose which centroids to split
while len(empty_cents) > 0:
# choose which centroids to split (numpy)
probas = hassign.astype('float') - 1
probas[probas < 0] = 0
probas /= probas.sum()
Expand All @@ -327,13 +342,17 @@ def reassign_centroids(hassign, centroids, rs=None):
return nsplit



def kmeans(k, data, niter=25, seed=1234, checkpoint=None, verbose=True,
return_stats=False):
"""Pure python kmeans implementation. Follows the Faiss C++ version
quite closely, but takes a DatasetAssign instead of a training data
matrix. Also redo is not implemented. """
matrix. Also redo is not implemented.
For the torch implementation, the centroids are tensors (possibly on GPU),
but the indices remain numpy on CPU.
"""
n, d = data.count(), data.dim()

log = print if verbose else print_nop

log(("Clustering %d points in %dD to %d clusters, " +
Expand All @@ -345,6 +364,7 @@ def kmeans(k, data, niter=25, seed=1234, checkpoint=None, verbose=True,
# initialization
perm = rs.choice(n, size=k, replace=False)
centroids = data.get_subset(perm)
is_torch = check_if_torch(centroids)

iteration_stats = []

Expand All @@ -362,12 +382,17 @@ def kmeans(k, data, niter=25, seed=1234, checkpoint=None, verbose=True,
t_search_tot += time.time() - t0s;

err = D.sum()
if is_torch:
err = err.item()
obj.append(err)

hassign = np.bincount(assign, minlength=k)

fac = hassign.reshape(-1, 1).astype('float32')
fac[fac == 0] = 1 # quiet warning
fac[fac == 0] = 1 # quiet warning
if is_torch:
import torch
fac = torch.from_numpy(fac).to(sums.device)

centroids = sums / fac

Expand All @@ -391,7 +416,11 @@ def kmeans(k, data, niter=25, seed=1234, checkpoint=None, verbose=True,

if checkpoint is not None:
log('storing centroids in', checkpoint)
np.save(checkpoint, centroids)
if is_torch:
import torch
torch.save(centroids, checkpoint)
else:
np.save(checkpoint, centroids)

if return_stats:
return centroids, iteration_stats
Expand Down
6 changes: 6 additions & 0 deletions contrib/torch/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# The Torch contrib

This contrib directory contains a few PyTorch routines that
are useful for similarity search. They do not necessarily depend on Faiss.

The code is designed to work with CPU and GPU tensors.
Empty file added contrib/torch/__init__.py
Empty file.
60 changes: 60 additions & 0 deletions contrib/torch/clustering.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
This contrib module contains Pytorch code for k-means clustering
"""
import faiss
import faiss.contrib.torch_utils
import torch

# the kmeans can produce both torch and numpy centroids
from faiss.contrib.clustering import DatasetAssign, kmeans


class DatasetAssign:
    """Wrapper for a tensor that offers a function to assign the vectors
    to centroids. All other implementations offer the same interface.

    The wrapped tensor `x` is used as a 2D matrix: `count()` is
    `x.shape[0]` and `dim()` is `x.shape[1]`.
    """
    # NOTE: this definition shadows the DatasetAssign name imported from
    # faiss.contrib.clustering at the top of this module.

    def __init__(self, x):
        self.x = x

    def count(self):
        """Number of vectors in the dataset."""
        return self.x.shape[0]

    def dim(self):
        """Dimensionality of the vectors."""
        return self.x.shape[1]

    def get_subset(self, indices):
        """Return the rows of the dataset selected by `indices`."""
        return self.x[indices]

    def perform_search(self, centroids):
        """Return (D, I), the 1-nearest-centroid search result for every
        vector of the dataset. Subclasses override this to change the
        search backend."""
        return faiss.knn(self.x, centroids, 1)

    def assign_to(self, centroids, weights=None):
        """Assign every vector to its nearest centroid.

        :param centroids: tensor of centroids, one per row
        :param weights: optional per-vector weights (torch tensor or
            anything torch.as_tensor accepts, e.g. a numpy array)
        :return: (I, D, sum_per_centroid) where I is a numpy array with
            the assigned centroid of each vector, D the corresponding
            distances and sum_per_centroid a tensor shaped like
            `centroids` holding the (weighted) sum of the vectors
            assigned to each centroid
        """
        D, I = self.perform_search(centroids)

        I = I.ravel()
        D = D.ravel()

        x = self.x
        if weights is not None:
            # bring the weights to the dtype / device of the data so that
            # the product below keeps the dtype index_add_ expects
            weights = torch.as_tensor(weights, dtype=x.dtype, device=x.device)
            x = x * weights[:, None]

        sum_per_centroid = torch.zeros_like(centroids)
        sum_per_centroid.index_add_(0, I, x)

        # the indices are returned as a numpy array on CPU
        return I.cpu().numpy(), D, sum_per_centroid


class DatasetAssignGPU(DatasetAssign):
    """DatasetAssign variant whose nearest-centroid search goes through
    faiss.knn_gpu, using the GPU resources object `res`."""

    def __init__(self, res, x):
        super().__init__(x)
        self.res = res

    def perform_search(self, centroids):
        """Return the 1-nearest-centroid search result computed by
        faiss.knn_gpu."""
        resources = self.res
        vectors = self.x
        return faiss.knn_gpu(resources, vectors, centroids, 1)
53 changes: 53 additions & 0 deletions contrib/torch/quantization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
This contrib module contains Pytorch code for quantization.
"""

import numpy as np
import torch
import faiss

from faiss.contrib import torch_utils


class Quantizer:
    """Base class for quantizers: holds the vector dimension `d` and the
    size in bytes of a code, `code_size`. The train / encode / decode
    methods are placeholders that do nothing in this base class."""

    def __init__(self, d, code_size):
        self.d = d
        self.code_size = code_size

    def train(self, x):
        """Placeholder, does nothing and returns None."""

    def encode(self, x):
        """Placeholder, does nothing and returns None."""

    def decode(self, x):
        """Placeholder, does nothing and returns None."""


class VectorQuantizer(Quantizer):
    """Quantizer with a single codebook of `k` entries."""

    def __init__(self, d, k):
        """
        :param d: dimension of the vectors
        :param k: number of codebook entries (>= 1)
        :raises ValueError: if k < 1
        """
        if k < 1:
            raise ValueError(f"k must be >= 1, got {k}")
        # ceil(log2(k)) bits are needed to index k entries; computed on
        # integers because torch.log2 / torch.ceil only accept tensors
        nbits = (int(k) - 1).bit_length()
        code_size = (nbits + 7) // 8
        super().__init__(d, code_size)
        self.k = k

    def train(self, x):
        """Placeholder, does nothing and returns None."""


class ProductQuantizer(Quantizer):
    """Quantizer made of `M` sub-quantizers of `nbits` bits each."""

    def __init__(self, d, M, nbits):
        """
        :param d: dimension of the vectors
        :param M: number of sub-quantizers
        :param nbits: number of bits per sub-quantizer index
        """
        # total number of bits rounded up to whole bytes
        code_size = (M * nbits + 7) // 8
        super().__init__(d, code_size)
        self.M = M
        self.nbits = nbits

    def train(self, x):
        """Placeholder, does nothing and returns None."""
62 changes: 62 additions & 0 deletions contrib/torch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,18 @@
import sys
import numpy as np

##################################################################
# Equivalent of swig_ptr for Torch tensors
##################################################################

def swig_ptr_from_UInt8Tensor(x):
    """ returns a Faiss SWIG uint8 pointer to the data of a contiguous
    torch.uint8 tensor (on CPU or GPU) """
    assert x.is_contiguous()
    assert x.dtype == torch.uint8
    # elements are one byte wide, so the storage offset is used unscaled
    address = x.untyped_storage().data_ptr() + x.storage_offset()
    return faiss.cast_integer_to_uint8_ptr(address)


def swig_ptr_from_HalfTensor(x):
""" gets a Faiss SWIG pointer from a pytorch tensor (on CPU or GPU) """
assert x.is_contiguous()
Expand All @@ -43,27 +48,34 @@ def swig_ptr_from_HalfTensor(x):
return faiss.cast_integer_to_void_ptr(
x.untyped_storage().data_ptr() + x.storage_offset() * 2)


def swig_ptr_from_FloatTensor(x):
    """ returns a Faiss SWIG float pointer to the data of a contiguous
    torch.float32 tensor (on CPU or GPU) """
    assert x.is_contiguous()
    assert x.dtype == torch.float32
    # the storage offset is scaled by 4, the byte width of a float32
    address = x.untyped_storage().data_ptr() + x.storage_offset() * 4
    return faiss.cast_integer_to_float_ptr(address)


def swig_ptr_from_IntTensor(x):
    """ returns a Faiss SWIG int pointer to the data of a contiguous
    torch.int32 tensor (on CPU or GPU) """
    assert x.is_contiguous()
    assert x.dtype == torch.int32, 'dtype=%s' % x.dtype
    # the storage offset is scaled by 4, the byte width of an int32
    address = x.untyped_storage().data_ptr() + x.storage_offset() * 4
    return faiss.cast_integer_to_int_ptr(address)


def swig_ptr_from_IndicesTensor(x):
    """ returns a Faiss SWIG idx_t pointer to the data of a contiguous
    torch.int64 tensor (on CPU or GPU) """
    assert x.is_contiguous()
    assert x.dtype == torch.int64, 'dtype=%s' % x.dtype
    # the storage offset is scaled by 8, the byte width of an int64
    address = x.untyped_storage().data_ptr() + x.storage_offset() * 8
    return faiss.cast_integer_to_idx_t_ptr(address)

##################################################################
# utilities
##################################################################

@contextlib.contextmanager
def using_stream(res, pytorch_stream=None):
""" Creates a scoping object to make Faiss GPU use the same stream
Expand Down Expand Up @@ -107,6 +119,10 @@ def torch_replace_method(the_class, name, replacement,
setattr(the_class, name + '_numpy', orig_method)
setattr(the_class, name, replacement)

##################################################################
# Setup wrappers
##################################################################

def handle_torch_Index(the_class):
def torch_replacement_add(self, x):
if type(x) is np.ndarray:
Expand Down Expand Up @@ -493,6 +509,52 @@ def torch_replacement_sa_decode(self, codes, x=None):
handle_torch_Index(the_class)


# allows torch tensor usage with knn
def torch_replacement_knn(xq, xb, k, metric=faiss.METRIC_L2, metric_arg=0):
    """ Brute-force k-nearest-neighbor search that accepts CPU torch tensors.

    When `xb` is a numpy array the call is forwarded to faiss.knn_numpy.
    Otherwise `xq` and `xb` must be contiguous float32 CPU tensors with the
    same number of columns; the result is a pair (D, I) of tensors of shape
    (nq, k), float32 distances and int64 indices.
    """
    if type(xb) is np.ndarray:
        # numpy input: use the base implementation from faiss __init__.py
        return faiss.knn_numpy(xq, xb, k, metric=metric, metric_arg=metric_arg)

    # database vectors
    nb, d = xb.size()
    assert xb.is_contiguous()
    assert xb.dtype == torch.float32
    assert not xb.is_cuda, "use knn_gpu for GPU tensors"

    # query vectors
    nq, query_dim = xq.size()
    assert query_dim == d
    assert xq.is_contiguous()
    assert xq.dtype == torch.float32
    assert not xq.is_cuda, "use knn_gpu for GPU tensors"

    # output buffers, filled in place through their raw pointers
    distances = torch.empty(nq, k, device=xb.device, dtype=torch.float32)
    labels = torch.empty(nq, k, device=xb.device, dtype=torch.int64)
    labels_ptr = swig_ptr_from_IndicesTensor(labels)
    distances_ptr = swig_ptr_from_FloatTensor(distances)
    database_ptr = swig_ptr_from_FloatTensor(xb)
    queries_ptr = swig_ptr_from_FloatTensor(xq)

    if metric == faiss.METRIC_L2:
        faiss.knn_L2sqr(
            queries_ptr, database_ptr,
            d, nq, nb, k, distances_ptr, labels_ptr
        )
    elif metric == faiss.METRIC_INNER_PRODUCT:
        faiss.knn_inner_product(
            queries_ptr, database_ptr,
            d, nq, nb, k, distances_ptr, labels_ptr
        )
    else:
        # any other metric goes through the generic implementation
        faiss.knn_extra_metrics(
            queries_ptr, database_ptr,
            d, nq, nb, metric, metric_arg, k, distances_ptr, labels_ptr
        )

    return distances, labels


# Install the torch-aware wrapper as the `knn` attribute of faiss_module.
# The numpy code path of the wrapper forwards to faiss.knn_numpy, so the
# original function is expected to stay reachable under that name.
# NOTE(review): the meaning of the two trailing True flags is not visible
# in this chunk -- confirm against the torch_replace_method signature.
torch_replace_method(faiss_module, 'knn', torch_replacement_knn, True, True)


# allows torch tensor usage with bfKnn
def torch_replacement_knn_gpu(res, xq, xb, k, D=None, I=None, metric=faiss.METRIC_L2, device=-1, use_raft=False):
if type(xb) is np.ndarray:
Expand Down
33 changes: 33 additions & 0 deletions faiss/gpu/test/torch_test_contrib_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@
import faiss
import faiss.contrib.torch_utils

from faiss.contrib import datasets
from faiss.contrib.torch import clustering


def to_column_major_torch(x):
if hasattr(torch, 'contiguous_format'):
return x.t().clone(memory_format=torch.contiguous_format).t()
Expand Down Expand Up @@ -377,6 +381,7 @@ def test_knn_gpu_datatypes(self, use_raft=False):
self.assertTrue(torch.equal(torch.from_numpy(I).long(), gt_I))
self.assertLess((torch.from_numpy(D) - gt_D).abs().max(), 1.5e-3)


class TestTorchUtilsPairwiseDistanceGpu(unittest.TestCase):
def test_pairwise_distance_gpu(self):
torch.manual_seed(10)
Expand Down Expand Up @@ -470,3 +475,31 @@ def test_pairwise_distance_gpu(self):
D, _ = torch.sort(D, dim=1)

self.assertLess((D.cpu() - gt_D[4:8]).abs().max(), 1e-4)


class TestClustering(unittest.TestCase):

    def test_python_kmeans(self):
        """ Compare the python kmeans run on a GPU torch tensor with the
        Faiss CPU k-means: its quantization error must stay below 1.1
        times the reference error """
        dataset = datasets.SyntheticDataset(32, 10000, 0, 0)
        train = dataset.get_train()

        # bad distribution to stress-test the split code: the first half
        # of the training points are all copies of the same vector
        xt = train[:10000].copy()
        xt[:5000] = train[0]

        # CPU baseline
        ref_kmeans = faiss.Kmeans(dataset.d, 100, niter=10)
        ref_kmeans.train(xt)
        ref_distances = faiss.knn(xt, ref_kmeans.centroids, 1)[0]
        err = ref_distances.sum()

        # python kmeans on a torch tensor living on the first GPU
        xt_torch = torch.from_numpy(xt).to("cuda:0")
        res = faiss.StandardGpuResources()
        data = clustering.DatasetAssignGPU(res, xt_torch)
        centroids = clustering.kmeans(100, data, 10)
        centroids = centroids.cpu().numpy()
        new_distances = faiss.knn(xt, centroids, 1)[0]
        err2 = new_distances.sum()

        # values recorded in the original comment: 33498.332 33380.477
        print(err, err2)
        self.assertLess(err2, err * 1.1)
Loading

0 comments on commit 866e3fe

Please sign in to comment.