Skip to content

Commit

Permalink
Implement K-Nearest Neighbors
Browse files Browse the repository at this point in the history
Summary:
Implements K-Nearest Neighbors with C++ and CUDA versions.

KNN in CUDA is highly nontrivial. I've implemented a few different versions of the kernel, and we heuristically dispatch to different kernels based on the problem size. Some of the kernels rely on template specialization on either D or K, so we use template metaprogramming to compile specialized versions for ranges of D and K.

These kernels are up to 3x faster than our existing 1-nearest-neighbor kernels, so we should also consider swapping out `nn_points_idx` to use these kernels in the backend.

I've been working mostly on the CUDA kernels, and haven't converged on the correct Python API.

I still want to benchmark against FAISS to see how far away we are from their performance.

Reviewed By: bottler

Differential Revision: D19729286

fbshipit-source-id: 608ffbb7030c21fe4008f330522f4890f0c3c21a
  • Loading branch information
jcjohnson authored and facebook-github-bot committed Mar 26, 2020
1 parent 02d4968 commit 870290d
Show file tree
Hide file tree
Showing 12 changed files with 1,328 additions and 1 deletion.
261 changes: 261 additions & 0 deletions pytorch3d/csrc/dispatch.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,261 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
//
// This file provides utilities for dispatching to specialized versions of functions.
// This is especially useful for CUDA kernels, since specializing them to particular
// input sizes can often allow the compiler to unroll loops and place arrays into
// registers, which can give huge performance speedups.
//
// As an example, suppose we have the following function which is specialized
// based on a compile-time int64_t value:
//
// template<typename T, int64_t x>
// struct SquareOffset {
// static void run(T y) {
// T val = x * x + y;
// std::cout << val << std::endl;
// }
// }
//
// This function takes one compile-time argument x, and one run-time argument y.
// We might want to compile specialized versions of this for x=0, x=1, etc and
// then dispatch to the correct one based on the runtime value of x.
// One simple way to achieve this is with a lookup table:
//
// template<typename T>
// void DispatchSquareOffset(const int64_t x, T y) {
// if (x == 0) {
// SquareOffset<T, 0>::run(y);
// } else if (x == 1) {
// SquareOffset<T, 1>::run(y);
// } else if (x == 2) {
// SquareOffset<T, 2>::run(y);
// }
// }
//
// This function takes both x and y as run-time arguments, and dispatches to
// different specialized versions of SquareOffset based on the run-time value
// of x. This works, but it's tedious and error-prone. If we want to change the
// set of x values for which we provide compile-time specializations, then we
// will need to do a lot of tedius editing of the dispatch function. Also, if we
// want to provide compile-time specializations for another function other than
// SquareOffset, we will need to duplicate the entire lookup table.
//
// To solve these problems, we can use the DispatchKernel1D function provided by
// this file instead:
//
// template<typename T>
// void DispatchSquareOffset(const int64_t x, T y) {
// constexpr int64_t xmin = 0;
// constexpr int64_t xmax = 2;
// DispatchKernel1D<SquareOffset, T, xmin, xmax>(x, y);
// }
//
// DispatchKernel1D uses template metaprogramming to compile specialized
// versions of SquareOffset for all values of x with xmin <= x <= xmax, and
// then dispatches to the correct one based on the run-time value of x. If we
// want to change the range of x values for which SquareOffset is specialized
// at compile-time, then all we have to do is change the values of the
// compile-time constants xmin and xmax.
//
// This file also allows us to similarly dispatch functions that depend on two
// compile-time int64_t values, using the DispatchKernel2D function like this:
//
// template<typename T, int64_t x, int64_t y>
// struct Sum {
// static void run(T z, T w) {
// T val = x + y + z + w;
// std::cout << val << std::endl;
// }
// }
//
// template<typename T>
// void DispatchSum(const int64_t x, const int64_t y, int z, int w) {
// constexpr int64_t xmin = 1;
// constexpr int64_t xmax = 3;
// constexpr int64_t ymin = 2;
// constexpr int64_t ymax = 5;
// DispatchKernel2D<Sum, T, xmin, xmax, ymin, ymax>(x, y, z, w);
// }
//
// Like its 1D counterpart, DispatchKernel2D uses template metaprogramming to
// compile specialized versions of sum for all values of (x, y) with
// xmin <= x <= xmax and ymin <= y <= ymax, then dispatches to the correct
// specialized version based on the runtime values of x and y.

// Define some helper structs in an anonymous namespace.
namespace {

// 1D dispatch: general case.
// Kernel is the function we want to dispatch to; it should take a typename and
// an int64_t as template args, and it should define a static void function
// run which takes any number of arguments of any type.
// In order to dispatch, we will take an additional template argument curN,
// and increment it via template recursion until it is equal to the run-time
// argument N.
template<
template<typename, int64_t> class Kernel,
typename T,
int64_t minN,
int64_t maxN,
int64_t curN,
typename... Args
>
struct DispatchKernelHelper1D {
static void run(const int64_t N, Args... args) {
if (curN == N) {
// The compile-time value curN is equal to the run-time value N, so we
// can dispatch to the run method of the Kernel.
Kernel<T, curN>::run(args...);
} else if (curN < N) {
// Increment curN via template recursion
DispatchKernelHelper1D<Kernel, T, minN, maxN, curN + 1, Args...>::run(N, args...);
}
// We shouldn't get here -- throw an error?
}
};


// 1D dispatch: Specialization when curN == maxN
// We need this base case to avoid infinite template recursion.
template<
template<typename, int64_t> class Kernel,
typename T,
int64_t minN,
int64_t maxN,
typename... Args
>
struct DispatchKernelHelper1D<Kernel, T, minN, maxN, maxN, Args...> {
static void run(const int64_t N, Args... args) {
if (N == maxN) {
Kernel<T, maxN>::run(args...);
}
// We shouldn't get here -- throw an error?
}
};


// 2D dispatch, general case.
// This is similar to the 1D case: we take additional template args curN and
// curM, and increment them via template recursion until they are equal to
// the run-time values of N and M, at which point we dispatch to the run
// method of the kernel.
template<
template<typename, int64_t, int64_t> class Kernel,
typename T,
int64_t minN, int64_t maxN, int64_t curN,
int64_t minM, int64_t maxM, int64_t curM,
typename... Args
>
struct DispatchKernelHelper2D {
static void run(const int64_t N, const int64_t M, Args... args) {
if (curN == N && curM == M) {
Kernel<T, curN, curM>::run(args...);
} else if (curN < N && curM < M) {
// Increment both curN and curM. This isn't strictly necessary; we could
// just increment one or the other at each step. But this helps to cut
// on the number of recursive calls we make.
DispatchKernelHelper2D<Kernel, T, minN, maxN, curN + 1, minM, maxM, curM + 1, Args...>::run(N, M, args...);
} else if (curN < N) {
// Increment curN only
DispatchKernelHelper2D<Kernel, T, minN, maxN, curN + 1, minM, maxM, curM, Args...>::run(N, M, args...);
} else if (curM < M) {
// Increment curM only
DispatchKernelHelper2D<Kernel, T, minN, maxN, curN, minM, maxM, curM + 1, Args...>::run(N, M, args...);
}
}
};


// 2D dispatch, specialization for curN == maxN
template<
template<typename, int64_t, int64_t> class Kernel,
typename T,
int64_t minN, int64_t maxN,
int64_t minM, int64_t maxM, int64_t curM,
typename... Args
>
struct DispatchKernelHelper2D<Kernel, T, minN, maxN, maxN, minM, maxM, curM, Args...> {
static void run(const int64_t N, const int64_t M, Args... args) {
if (maxN == N && curM == M) {
Kernel<T, maxN, curM>::run(args...);
} else if (curM < maxM) {
DispatchKernelHelper2D<Kernel, T, minN, maxN, maxN, minM, maxM, curM + 1, Args...>::run(N, M, args...);
}
// We should not get here -- throw an error?
}
};


// 2D dispatch, specialization for curM == maxM
template<
template<typename, int64_t, int64_t> class Kernel,
typename T,
int64_t minN, int64_t maxN, int64_t curN,
int64_t minM, int64_t maxM,
typename... Args
>
struct DispatchKernelHelper2D<Kernel, T, minN, maxN, curN, minM, maxM, maxM, Args...> {
static void run(const int64_t N, const int64_t M, Args... args) {
if (curN == N && maxM == M) {
Kernel<T, curN, maxM>::run(args...);
} else if (curN < maxN) {
DispatchKernelHelper2D<Kernel, T, minN, maxN, curN + 1, minM, maxM, maxM, Args...>::run(N, M, args...);
}
// We should not get here -- throw an error?
}
};


// 2D dispatch, specialization for curN == maxN, curM == maxM
template<
template<typename, int64_t, int64_t> class Kernel,
typename T,
int64_t minN, int64_t maxN,
int64_t minM, int64_t maxM,
typename... Args
>
struct DispatchKernelHelper2D<Kernel, T, minN, maxN, maxN, minM, maxM, maxM, Args...> {
static void run(const int64_t N, const int64_t M, Args... args) {
if (maxN == N && maxM == M) {
Kernel<T, maxN, maxM>::run(args...);
}
// We should not get here -- throw an error?
}
};

} // namespace


// This is the function we expect users to call to dispatch to 1D functions
template<
template<typename, int64_t> class Kernel,
typename T,
int64_t minN,
int64_t maxN,
typename... Args
>
void DispatchKernel1D(const int64_t N, Args... args) {
if (minN <= N && N <= maxN) {
// Kick off the template recursion by calling the Helper with curN = minN
DispatchKernelHelper1D<Kernel, T, minN, maxN, minN, Args...>::run(N, args...);
}
// Maybe throw an error if we tried to dispatch outside the allowed range?
}


// This is the function we expect users to call to dispatch to 2D functions
template<
template<typename, int64_t, int64_t> class Kernel,
typename T,
int64_t minN, int64_t maxN,
int64_t minM, int64_t maxM,
typename... Args
>
void DispatchKernel2D(const int64_t N, const int64_t M, Args... args) {
if (minN <= N && N <= maxN && minM <= M && M <= maxM) {
// Kick off the template recursion by calling the Helper with curN = minN
// and curM = minM
DispatchKernelHelper2D<Kernel, T, minN, maxN, minN, minM, maxM, minM, Args...>::run(N, M, args...);
}
// Maybe throw an error if we tried to dispatch outside the specified range?
}
2 changes: 2 additions & 0 deletions pytorch3d/csrc/ext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "compositing/weighted_sum.h"
#include "face_areas_normals/face_areas_normals.h"
#include "gather_scatter/gather_scatter.h"
#include "knn/knn.h"
#include "nearest_neighbor_points/nearest_neighbor_points.h"
#include "packed_to_padded_tensor/packed_to_padded_tensor.h"
#include "rasterize_meshes/rasterize_meshes.h"
Expand All @@ -16,6 +17,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("face_areas_normals_backward", &FaceAreasNormalsBackward);
m.def("packed_to_padded", &PackedToPadded);
m.def("padded_to_packed", &PaddedToPacked);
m.def("knn_points_idx", &KNearestNeighborIdx);
m.def("nn_points_idx", &NearestNeighborIdx);
m.def("gather_scatter", &gather_scatter);
m.def("rasterize_points", &RasterizePoints);
Expand Down
120 changes: 120 additions & 0 deletions pytorch3d/csrc/index_utils.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

// This converts dynamic array lookups into static array lookups, for small
// arrays up to size 32.
//
// Suppose we have a small thread-local array:
//
// float vals[10];
//
// Ideally we should only index this array using static indices:
//
// for (int i = 0; i < 10; ++i) vals[i] = i * i;
//
// If we do so, then the CUDA compiler may be able to place the array into
// registers, which can have a big performance improvement. However if we
// access the array dynamically, the the compiler may force the array into
// local memory, which has the same latency as global memory.
//
// These functions convert dynamic array access into static array access
// using a brute-force lookup table. It can be used like this:
//
// float vals[10];
// int idx = 3;
// float val = 3.14f;
// RegisterIndexUtils<float, 10>::set(vals, idx, val);
// float val2 = RegisterIndexUtils<float, 10>::get(vals, idx);
//
// The implementation is based on fbcuda/RegisterUtils.cuh:
// https://github.com/facebook/fbcuda/blob/master/RegisterUtils.cuh
// To avoid depending on the entire library, we just reimplement these two
// functions. The fbcuda implementation is a bit more sophisticated, and uses
// the preprocessor to generate switch statements that go up to N for each
// value of N. We are lazy and just have a giant explicit switch statement.
//
// We might be able to use a template metaprogramming approach similar to
// DispatchKernel1D for this. However DispatchKernel1D is intended to be used
// for dispatching to the correct CUDA kernel on the host, while this is
// is intended to run on the device. I was concerned that a metaprogramming
// approach for this might lead to extra function calls at runtime if the
// compiler fails to optimize them away, which could be very slow on device.
// However I didn't actually benchmark or test this.
template<typename T, int N>
struct RegisterIndexUtils {
__device__ __forceinline__ static T get(const T arr[N], int idx) {
if (idx < 0 || idx >= N) return T();
switch (idx) {
case 0: return arr[0];
case 1: return arr[1];
case 2: return arr[2];
case 3: return arr[3];
case 4: return arr[4];
case 5: return arr[5];
case 6: return arr[6];
case 7: return arr[7];
case 8: return arr[8];
case 9: return arr[9];
case 10: return arr[10];
case 11: return arr[11];
case 12: return arr[12];
case 13: return arr[13];
case 14: return arr[14];
case 15: return arr[15];
case 16: return arr[16];
case 17: return arr[17];
case 18: return arr[18];
case 19: return arr[19];
case 20: return arr[20];
case 21: return arr[21];
case 22: return arr[22];
case 23: return arr[23];
case 24: return arr[24];
case 25: return arr[25];
case 26: return arr[26];
case 27: return arr[27];
case 28: return arr[28];
case 29: return arr[29];
case 30: return arr[30];
case 31: return arr[31];
};
return T();
}

__device__ __forceinline__ static void set(T arr[N], int idx, T val) {
if (idx < 0 || idx >= N) return;
switch (idx) {
case 0: arr[0] = val; break;
case 1: arr[1] = val; break;
case 2: arr[2] = val; break;
case 3: arr[3] = val; break;
case 4: arr[4] = val; break;
case 5: arr[5] = val; break;
case 6: arr[6] = val; break;
case 7: arr[7] = val; break;
case 8: arr[8] = val; break;
case 9: arr[9] = val; break;
case 10: arr[10] = val; break;
case 11: arr[11] = val; break;
case 12: arr[12] = val; break;
case 13: arr[13] = val; break;
case 14: arr[14] = val; break;
case 15: arr[15] = val; break;
case 16: arr[16] = val; break;
case 17: arr[17] = val; break;
case 18: arr[18] = val; break;
case 19: arr[19] = val; break;
case 20: arr[20] = val; break;
case 21: arr[21] = val; break;
case 22: arr[22] = val; break;
case 23: arr[23] = val; break;
case 24: arr[24] = val; break;
case 25: arr[25] = val; break;
case 26: arr[26] = val; break;
case 27: arr[27] = val; break;
case 28: arr[28] = val; break;
case 29: arr[29] = val; break;
case 30: arr[30] = val; break;
case 31: arr[31] = val; break;
}
}
};
Loading

0 comments on commit 870290d

Please sign in to comment.