-
Notifications
You must be signed in to change notification settings - Fork 1.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Summary: Implements K-Nearest Neighbors with C++ and CUDA versions. KNN in CUDA is highly nontrivial. I've implemented a few different versions of the kernel, and we heuristically dispatch to different kernels based on the problem size. Some of the kernels rely on template specialization on either D or K, so we use template metaprogramming to compile specialized versions for ranges of D and K. These kernels are up to 3x faster than our existing 1-nearest-neighbor kernels, so we should also consider swapping out `nn_points_idx` to use these kernels in the backend. I've been working mostly on the CUDA kernels, and haven't converged on the correct Python API. I still want to benchmark against FAISS to see how far away we are from their performance. Reviewed By: bottler Differential Revision: D19729286 fbshipit-source-id: 608ffbb7030c21fe4008f330522f4890f0c3c21a
- Loading branch information
1 parent
02d4968
commit 870290d
Showing
12 changed files
with
1,328 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,261 @@ | ||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. | ||
// | ||
// This file provides utilities for dispatching to specialized versions of functions. | ||
// This is especially useful for CUDA kernels, since specializing them to particular | ||
// input sizes can often allow the compiler to unroll loops and place arrays into | ||
// registers, which can give huge performance speedups. | ||
// | ||
// As an example, suppose we have the following function which is specialized | ||
// based on a compile-time int64_t value: | ||
// | ||
// template<typename T, int64_t x> | ||
// struct SquareOffset { | ||
// static void run(T y) { | ||
// T val = x * x + y; | ||
// std::cout << val << std::endl; | ||
// } | ||
// } | ||
// | ||
// This function takes one compile-time argument x, and one run-time argument y. | ||
// We might want to compile specialized versions of this for x=0, x=1, etc and | ||
// then dispatch to the correct one based on the runtime value of x. | ||
// One simple way to achieve this is with a lookup table: | ||
// | ||
// template<typename T> | ||
// void DispatchSquareOffset(const int64_t x, T y) { | ||
// if (x == 0) { | ||
// SquareOffset<T, 0>::run(y); | ||
// } else if (x == 1) { | ||
// SquareOffset<T, 1>::run(y); | ||
// } else if (x == 2) { | ||
// SquareOffset<T, 2>::run(y); | ||
// } | ||
// } | ||
// | ||
// This function takes both x and y as run-time arguments, and dispatches to | ||
// different specialized versions of SquareOffset based on the run-time value | ||
// of x. This works, but it's tedious and error-prone. If we want to change the | ||
// set of x values for which we provide compile-time specializations, then we | ||
// will need to do a lot of tedious editing of the dispatch function. Also, if we
// want to provide compile-time specializations for another function other than | ||
// SquareOffset, we will need to duplicate the entire lookup table. | ||
// | ||
// To solve these problems, we can use the DispatchKernel1D function provided by | ||
// this file instead: | ||
// | ||
// template<typename T> | ||
// void DispatchSquareOffset(const int64_t x, T y) { | ||
// constexpr int64_t xmin = 0; | ||
// constexpr int64_t xmax = 2; | ||
// DispatchKernel1D<SquareOffset, T, xmin, xmax>(x, y); | ||
// } | ||
// | ||
// DispatchKernel1D uses template metaprogramming to compile specialized | ||
// versions of SquareOffset for all values of x with xmin <= x <= xmax, and | ||
// then dispatches to the correct one based on the run-time value of x. If we | ||
// want to change the range of x values for which SquareOffset is specialized | ||
// at compile-time, then all we have to do is change the values of the | ||
// compile-time constants xmin and xmax. | ||
// | ||
// This file also allows us to similarly dispatch functions that depend on two | ||
// compile-time int64_t values, using the DispatchKernel2D function like this: | ||
// | ||
// template<typename T, int64_t x, int64_t y> | ||
// struct Sum { | ||
// static void run(T z, T w) { | ||
// T val = x + y + z + w; | ||
// std::cout << val << std::endl; | ||
// } | ||
// } | ||
// | ||
// template<typename T> | ||
// void DispatchSum(const int64_t x, const int64_t y, int z, int w) { | ||
// constexpr int64_t xmin = 1; | ||
// constexpr int64_t xmax = 3; | ||
// constexpr int64_t ymin = 2; | ||
// constexpr int64_t ymax = 5; | ||
// DispatchKernel2D<Sum, T, xmin, xmax, ymin, ymax>(x, y, z, w); | ||
// } | ||
// | ||
// Like its 1D counterpart, DispatchKernel2D uses template metaprogramming to | ||
// compile specialized versions of Sum for all values of (x, y) with
// xmin <= x <= xmax and ymin <= y <= ymax, then dispatches to the correct | ||
// specialized version based on the runtime values of x and y. | ||
|
||
// Define some helper structs in an anonymous namespace. | ||
namespace { | ||
|
||
// 1D dispatch: general case.
// Kernel is the function we want to dispatch to; it should take a typename
// and an int64_t as template args, and it should define a static void
// function run which takes any number of arguments of any type.
// To dispatch, we take an extra template argument curN and advance it by one
// per recursive instantiation until it matches the run-time argument N.
template <
    template <typename, int64_t> class Kernel,
    typename T,
    int64_t minN,
    int64_t maxN,
    int64_t curN,
    typename... Args>
struct DispatchKernelHelper1D {
  static void run(const int64_t N, Args... args) {
    if (curN == N) {
      // The compile-time index has caught up with the run-time value, so
      // this is the specialization we want to invoke.
      Kernel<T, curN>::run(args...);
      return;
    }
    if (curN < N) {
      // Not there yet: recurse with curN + 1. The curN == maxN partial
      // specialization below terminates this recursion.
      DispatchKernelHelper1D<Kernel, T, minN, maxN, curN + 1, Args...>::run(
          N, args...);
    }
    // Reaching here (curN > N) should not happen; the public entry point
    // DispatchKernel1D range-checks N first -- throw an error?
  }
};
|
||
|
||
// 1D dispatch: Specialization when curN == maxN | ||
// We need this base case to avoid infinite template recursion. | ||
template< | ||
template<typename, int64_t> class Kernel, | ||
typename T, | ||
int64_t minN, | ||
int64_t maxN, | ||
typename... Args | ||
> | ||
struct DispatchKernelHelper1D<Kernel, T, minN, maxN, maxN, Args...> { | ||
static void run(const int64_t N, Args... args) { | ||
if (N == maxN) { | ||
Kernel<T, maxN>::run(args...); | ||
} | ||
// We shouldn't get here -- throw an error? | ||
} | ||
}; | ||
|
||
|
||
// 2D dispatch, general case.
// Analogous to the 1D case: curN and curM are advanced via template
// recursion until they equal the run-time values N and M, at which point we
// dispatch to the run method of the kernel.
template <
    template <typename, int64_t, int64_t> class Kernel,
    typename T,
    int64_t minN, int64_t maxN, int64_t curN,
    int64_t minM, int64_t maxM, int64_t curM,
    typename... Args>
struct DispatchKernelHelper2D {
  static void run(const int64_t N, const int64_t M, Args... args) {
    const bool needN = curN < N; // curN still has to catch up to N
    const bool needM = curM < M; // curM still has to catch up to M
    if (curN == N && curM == M) {
      // Both compile-time indices match the run-time values: dispatch.
      Kernel<T, curN, curM>::run(args...);
    } else if (needN && needM) {
      // Advance both indices at once. Stepping one at a time would also
      // work, but moving diagonally cuts down on recursive instantiations.
      DispatchKernelHelper2D<
          Kernel, T, minN, maxN, curN + 1, minM, maxM, curM + 1, Args...>::
          run(N, M, args...);
    } else if (needN) {
      // Only the N axis still needs to advance.
      DispatchKernelHelper2D<
          Kernel, T, minN, maxN, curN + 1, minM, maxM, curM, Args...>::
          run(N, M, args...);
    } else if (needM) {
      // Only the M axis still needs to advance.
      DispatchKernelHelper2D<
          Kernel, T, minN, maxN, curN, minM, maxM, curM + 1, Args...>::
          run(N, M, args...);
    }
    // If no branch fired, (N, M) was outside the dispatch range; the public
    // DispatchKernel2D range-checks first, so this is a silent no-op.
  }
};
|
||
|
||
// 2D dispatch, specialization for curN == maxN | ||
template< | ||
template<typename, int64_t, int64_t> class Kernel, | ||
typename T, | ||
int64_t minN, int64_t maxN, | ||
int64_t minM, int64_t maxM, int64_t curM, | ||
typename... Args | ||
> | ||
struct DispatchKernelHelper2D<Kernel, T, minN, maxN, maxN, minM, maxM, curM, Args...> { | ||
static void run(const int64_t N, const int64_t M, Args... args) { | ||
if (maxN == N && curM == M) { | ||
Kernel<T, maxN, curM>::run(args...); | ||
} else if (curM < maxM) { | ||
DispatchKernelHelper2D<Kernel, T, minN, maxN, maxN, minM, maxM, curM + 1, Args...>::run(N, M, args...); | ||
} | ||
// We should not get here -- throw an error? | ||
} | ||
}; | ||
|
||
|
||
// 2D dispatch, specialization for curM == maxM | ||
template< | ||
template<typename, int64_t, int64_t> class Kernel, | ||
typename T, | ||
int64_t minN, int64_t maxN, int64_t curN, | ||
int64_t minM, int64_t maxM, | ||
typename... Args | ||
> | ||
struct DispatchKernelHelper2D<Kernel, T, minN, maxN, curN, minM, maxM, maxM, Args...> { | ||
static void run(const int64_t N, const int64_t M, Args... args) { | ||
if (curN == N && maxM == M) { | ||
Kernel<T, curN, maxM>::run(args...); | ||
} else if (curN < maxN) { | ||
DispatchKernelHelper2D<Kernel, T, minN, maxN, curN + 1, minM, maxM, maxM, Args...>::run(N, M, args...); | ||
} | ||
// We should not get here -- throw an error? | ||
} | ||
}; | ||
|
||
|
||
// 2D dispatch, specialization for curN == maxN, curM == maxM | ||
template< | ||
template<typename, int64_t, int64_t> class Kernel, | ||
typename T, | ||
int64_t minN, int64_t maxN, | ||
int64_t minM, int64_t maxM, | ||
typename... Args | ||
> | ||
struct DispatchKernelHelper2D<Kernel, T, minN, maxN, maxN, minM, maxM, maxM, Args...> { | ||
static void run(const int64_t N, const int64_t M, Args... args) { | ||
if (maxN == N && maxM == M) { | ||
Kernel<T, maxN, maxM>::run(args...); | ||
} | ||
// We should not get here -- throw an error? | ||
} | ||
}; | ||
|
||
} // namespace | ||
|
||
|
||
// This is the function we expect users to call to dispatch to 1D functions | ||
template< | ||
template<typename, int64_t> class Kernel, | ||
typename T, | ||
int64_t minN, | ||
int64_t maxN, | ||
typename... Args | ||
> | ||
void DispatchKernel1D(const int64_t N, Args... args) { | ||
if (minN <= N && N <= maxN) { | ||
// Kick off the template recursion by calling the Helper with curN = minN | ||
DispatchKernelHelper1D<Kernel, T, minN, maxN, minN, Args...>::run(N, args...); | ||
} | ||
// Maybe throw an error if we tried to dispatch outside the allowed range? | ||
} | ||
|
||
|
||
// This is the function we expect users to call to dispatch to 2D functions | ||
template< | ||
template<typename, int64_t, int64_t> class Kernel, | ||
typename T, | ||
int64_t minN, int64_t maxN, | ||
int64_t minM, int64_t maxM, | ||
typename... Args | ||
> | ||
void DispatchKernel2D(const int64_t N, const int64_t M, Args... args) { | ||
if (minN <= N && N <= maxN && minM <= M && M <= maxM) { | ||
// Kick off the template recursion by calling the Helper with curN = minN | ||
// and curM = minM | ||
DispatchKernelHelper2D<Kernel, T, minN, maxN, minN, minM, maxM, minM, Args...>::run(N, M, args...); | ||
} | ||
// Maybe throw an error if we tried to dispatch outside the specified range? | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. | ||
|
||
// This converts dynamic array lookups into static array lookups, for small | ||
// arrays up to size 32. | ||
// | ||
// Suppose we have a small thread-local array: | ||
// | ||
// float vals[10]; | ||
// | ||
// Ideally we should only index this array using static indices: | ||
// | ||
// for (int i = 0; i < 10; ++i) vals[i] = i * i; | ||
// | ||
// If we do so, then the CUDA compiler may be able to place the array into | ||
// registers, which can have a big performance improvement. However if we | ||
// access the array dynamically, then the compiler may force the array into
// local memory, which has the same latency as global memory. | ||
// | ||
// These functions convert dynamic array access into static array access | ||
// using a brute-force lookup table. It can be used like this: | ||
// | ||
// float vals[10]; | ||
// int idx = 3; | ||
// float val = 3.14f; | ||
// RegisterIndexUtils<float, 10>::set(vals, idx, val); | ||
// float val2 = RegisterIndexUtils<float, 10>::get(vals, idx); | ||
// | ||
// The implementation is based on fbcuda/RegisterUtils.cuh: | ||
// https://github.com/facebook/fbcuda/blob/master/RegisterUtils.cuh | ||
// To avoid depending on the entire library, we just reimplement these two | ||
// functions. The fbcuda implementation is a bit more sophisticated, and uses | ||
// the preprocessor to generate switch statements that go up to N for each | ||
// value of N. We are lazy and just have a giant explicit switch statement. | ||
// | ||
// We might be able to use a template metaprogramming approach similar to | ||
// DispatchKernel1D for this. However DispatchKernel1D is intended to be used | ||
// for dispatching to the correct CUDA kernel on the host, while this
// is intended to run on the device. I was concerned that a metaprogramming
// approach for this might lead to extra function calls at runtime if the | ||
// compiler fails to optimize them away, which could be very slow on device. | ||
// However I didn't actually benchmark or test this. | ||
// Static-indexing helpers for small thread-local arrays.
// get/set turn a run-time index into one of up to 32 compile-time array
// accesses via a switch, which lets the CUDA compiler keep the array in
// registers instead of spilling it to local memory.
template <typename T, int N>
struct RegisterIndexUtils {
  // The switch statements below only enumerate indices 0..31, so an array
  // with N > 32 would silently return T() / drop stores for valid indices
  // >= 32. Enforce the documented "up to size 32" contract at compile time.
  static_assert(
      N > 0 && N <= 32, "RegisterIndexUtils only supports 1 <= N <= 32");

  // Return arr[idx] using only static indices.
  // An out-of-range idx (idx < 0 or idx >= N) returns a value-initialized T.
  __device__ __forceinline__ static T get(const T arr[N], int idx) {
    if (idx < 0 || idx >= N)
      return T();
    switch (idx) {
      case 0: return arr[0];
      case 1: return arr[1];
      case 2: return arr[2];
      case 3: return arr[3];
      case 4: return arr[4];
      case 5: return arr[5];
      case 6: return arr[6];
      case 7: return arr[7];
      case 8: return arr[8];
      case 9: return arr[9];
      case 10: return arr[10];
      case 11: return arr[11];
      case 12: return arr[12];
      case 13: return arr[13];
      case 14: return arr[14];
      case 15: return arr[15];
      case 16: return arr[16];
      case 17: return arr[17];
      case 18: return arr[18];
      case 19: return arr[19];
      case 20: return arr[20];
      case 21: return arr[21];
      case 22: return arr[22];
      case 23: return arr[23];
      case 24: return arr[24];
      case 25: return arr[25];
      case 26: return arr[26];
      case 27: return arr[27];
      case 28: return arr[28];
      case 29: return arr[29];
      case 30: return arr[30];
      case 31: return arr[31];
    }
    // Unreachable: the bounds check plus static_assert guarantee a case
    // above was taken. Kept so all control paths return a value.
    return T();
  }

  // Store val into arr[idx] using only static indices.
  // An out-of-range idx (idx < 0 or idx >= N) is a silent no-op.
  __device__ __forceinline__ static void set(T arr[N], int idx, T val) {
    if (idx < 0 || idx >= N)
      return;
    switch (idx) {
      case 0: arr[0] = val; break;
      case 1: arr[1] = val; break;
      case 2: arr[2] = val; break;
      case 3: arr[3] = val; break;
      case 4: arr[4] = val; break;
      case 5: arr[5] = val; break;
      case 6: arr[6] = val; break;
      case 7: arr[7] = val; break;
      case 8: arr[8] = val; break;
      case 9: arr[9] = val; break;
      case 10: arr[10] = val; break;
      case 11: arr[11] = val; break;
      case 12: arr[12] = val; break;
      case 13: arr[13] = val; break;
      case 14: arr[14] = val; break;
      case 15: arr[15] = val; break;
      case 16: arr[16] = val; break;
      case 17: arr[17] = val; break;
      case 18: arr[18] = val; break;
      case 19: arr[19] = val; break;
      case 20: arr[20] = val; break;
      case 21: arr[21] = val; break;
      case 22: arr[22] = val; break;
      case 23: arr[23] = val; break;
      case 24: arr[24] = val; break;
      case 25: arr[25] = val; break;
      case 26: arr[26] = val; break;
      case 27: arr[27] = val; break;
      case 28: arr[28] = val; break;
      case 29: arr[29] = val; break;
      case 30: arr[30] = val; break;
      case 31: arr[31] = val; break;
    }
  }
};
Oops, something went wrong.