-
Notifications
You must be signed in to change notification settings - Fork 3.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'main' into fix_gcc7_arm
- Loading branch information
Showing
16 changed files
with
2,060 additions
and
68 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
# The Torch contrib | ||
|
||
This contrib directory contains a few PyTorch routines that | ||
are useful for similarity search. They do not necessarily depend on Faiss. | ||
|
||
The code is designed to work with CPU and GPU tensors. |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
""" | ||
This contrib module contains PyTorch code for k-means clustering. | ||
""" | ||
import faiss | ||
import faiss.contrib.torch_utils | ||
import torch | ||
|
||
# the kmeans can produce both torch and numpy centroids | ||
from faiss.contrib.clustering import kmeans | ||
|
||
class DatasetAssign:
    """Hold a tensor of vectors and assign each of them to a centroid.

    Other dataset wrappers are expected to expose the same methods.
    """

    def __init__(self, x):
        # x is indexed as x.shape[0] (number of vectors) and x.shape[1]
        # (vector dimension) by the accessors below.
        self.x = x

    def count(self):
        """Return the number of stored vectors (first axis of ``x``)."""
        n_vectors = self.x.shape[0]
        return n_vectors

    def dim(self):
        """Return the dimension of the stored vectors (second axis of ``x``)."""
        n_dims = self.x.shape[1]
        return n_dims

    def get_subset(self, indices):
        """Return ``x`` indexed by ``indices``."""
        return self.x[indices]

    def perform_search(self, centroids):
        """Run a 1-nearest-neighbor search of the stored vectors in ``centroids``.

        Returns whatever ``faiss.knn`` returns; ``assign_to`` unpacks it as
        a ``(distances, indices)`` pair.
        """
        search_result = faiss.knn(self.x, centroids, 1)
        return search_result

    def assign_to(self, centroids, weights=None):
        """Assign every stored vector to a centroid and accumulate per-centroid sums.

        Args:
            centroids: 2D tensor of centroids.
            weights: optional per-vector weights; when given, each vector is
                multiplied by its weight before being accumulated.

        Returns:
            A tuple ``(labels, distances, sum_per_centroid)`` where
            ``labels`` is a numpy array of the assigned centroid indices,
            ``distances`` is the flattened distance output of the search and
            ``sum_per_centroid`` has the same shape as ``centroids`` and
            holds the (optionally weighted) sum of the vectors assigned to
            each centroid.
        """
        distances, labels = self.perform_search(centroids)
        labels = labels.ravel()
        distances = distances.ravel()

        # The values are unused, but the unpacking requires `centroids` to
        # have exactly two dimensions.
        _n_centroids, _n_dims = centroids.shape

        sum_per_centroid = torch.zeros_like(centroids)
        contributions = self.x if weights is None else self.x * weights[:, None]
        sum_per_centroid.index_add_(0, labels, contributions)

        # Only the labels are converted to numpy; the other outputs stay
        # torch tensors.
        return labels.cpu().numpy(), distances, sum_per_centroid
|
||
|
||
class DatasetAssignGPU(DatasetAssign):
    """Variant of DatasetAssign whose search goes through ``faiss.knn_gpu``."""

    def __init__(self, res, x):
        """Store the vectors ``x`` and the resource object ``res``.

        ``res`` is forwarded as the first argument of ``faiss.knn_gpu``
        (presumably a Faiss GPU resources object -- confirm with callers).
        """
        DatasetAssign.__init__(self, x)
        self.res = res

    def perform_search(self, centroids):
        """Run a 1-nearest-neighbor search of the stored vectors in ``centroids``."""
        search_result = faiss.knn_gpu(self.res, self.x, centroids, 1)
        return search_result
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
""" | ||
This contrib module contains PyTorch code for quantization. | ||
""" | ||
|
||
import numpy as np | ||
import torch | ||
import faiss | ||
|
||
from faiss.contrib import torch_utils | ||
|
||
|
||
class Quantizer:
    """Base class for quantizers.

    It only records ``d`` and ``code_size``; ``train``, ``encode`` and
    ``decode`` are placeholders that do nothing and return None.
    """

    def __init__(self, d, code_size):
        # d: presumably the dimension of the input vectors -- confirm.
        # code_size: size of one code, as computed by the caller.
        self.d, self.code_size = d, code_size

    def train(self, x):
        """Placeholder: does nothing and returns None."""

    def encode(self, x):
        """Placeholder: does nothing and returns None."""

    def decode(self, x):
        """Placeholder: does nothing and returns None."""
|
||
|
||
class VectorQuantizer(Quantizer):
    """Quantizer whose codes index one of ``k`` values.

    The code size is the number of whole bytes needed to store
    ``ceil(log2(k))`` bits.
    """

    def __init__(self, d, k):
        """Set up the quantizer.

        Args:
            d: forwarded to ``Quantizer.__init__`` (presumably the vector
                dimension -- confirm).
            k: number of distinct code values (presumably the number of
                centroids); must be a positive integer.

        Raises:
            ValueError: if ``k`` is smaller than 1.
        """
        n_values = int(k)
        if n_values < 1:
            raise ValueError(
                "k must be a positive integer, got {!r}".format(k)
            )
        # torch.log2 / torch.ceil only accept tensors, so the size is
        # computed with exact integer arithmetic instead:
        # (k - 1).bit_length() == ceil(log2(k)) for every k >= 1.
        n_bits = (n_values - 1).bit_length()
        code_size = (n_bits + 7) // 8  # round up to whole bytes
        # Calling the base initializer through the class requires passing
        # the instance explicitly.
        Quantizer.__init__(self, d, code_size)
        self.k = k

    def train(self, x):
        """Placeholder: training is not implemented; returns None."""
        pass
|
||
|
||
class ProductQuantizer(Quantizer):
    """Quantizer whose codes are made of ``M`` fields of ``nbits`` bits each.

    The code size is the number of whole bytes needed to store
    ``M * nbits`` bits.
    """

    def __init__(self, d, M, nbits):
        """Set up the quantizer.

        Args:
            d: forwarded to ``Quantizer.__init__`` (presumably the vector
                dimension -- confirm).
            M: number of fields per code (presumably the number of
                sub-quantizers -- confirm).
            nbits: number of bits of each field.
        """
        # torch.ceil only accepts tensors, so the ceiling division is done
        # with exact integer arithmetic instead.
        total_bits = int(M) * int(nbits)
        code_size = (total_bits + 7) // 8  # round up to whole bytes
        # Calling the base initializer through the class requires passing
        # the instance explicitly.
        Quantizer.__init__(self, d, code_size)
        self.M = M
        self.nbits = nbits

    def train(self, x):
        """Placeholder: training is not implemented; returns None."""
        pass
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.