From 870290df345873492d88f70b942893cd3b5deb87 Mon Sep 17 00:00:00 2001 From: Justin Johnson Date: Thu, 26 Mar 2020 13:37:32 -0700 Subject: [PATCH] Implement K-Nearest Neighbors Summary: Implements K-Nearest Neighbors with C++ and CUDA versions. KNN in CUDA is highly nontrivial. I've implemented a few different versions of the kernel, and we heuristically dispatch to different kernels based on the problem size. Some of the kernels rely on template specialization on either D or K, so we use template metaprogramming to compile specialized versions for ranges of D and K. These kernels are up to 3x faster than our existing 1-nearest-neighbor kernels, so we should also consider swapping out `nn_points_idx` to use these kernels in the backend. I've been working mostly on the CUDA kernels, and haven't converged on the correct Python API. I still want to benchmark against FAISS to see how far away we are from their performance. Reviewed By: bottler Differential Revision: D19729286 fbshipit-source-id: 608ffbb7030c21fe4008f330522f4890f0c3c21a --- pytorch3d/csrc/dispatch.cuh | 261 +++++++++++++ pytorch3d/csrc/ext.cpp | 2 + pytorch3d/csrc/index_utils.cuh | 120 ++++++ pytorch3d/csrc/knn/knn.cu | 369 ++++++++++++++++++ pytorch3d/csrc/knn/knn.h | 54 +++ pytorch3d/csrc/knn/knn_cpu.cpp | 52 +++ pytorch3d/csrc/mink.cuh | 162 ++++++++ .../nearest_neighbor_points.h | 2 +- pytorch3d/ops/knn.py | 67 ++++ tests/bm_knn.py | 174 +++++++++ tests/test_knn.py | 65 +++ tests/test_nearest_neighbor_points.py | 1 + 12 files changed, 1328 insertions(+), 1 deletion(-) create mode 100644 pytorch3d/csrc/dispatch.cuh create mode 100644 pytorch3d/csrc/index_utils.cuh create mode 100644 pytorch3d/csrc/knn/knn.cu create mode 100644 pytorch3d/csrc/knn/knn.h create mode 100644 pytorch3d/csrc/knn/knn_cpu.cpp create mode 100644 pytorch3d/csrc/mink.cuh create mode 100644 pytorch3d/ops/knn.py create mode 100644 tests/bm_knn.py create mode 100644 tests/test_knn.py diff --git a/pytorch3d/csrc/dispatch.cuh b/pytorch3d/csrc/dispatch.cuh new file mode 100644 index 000000000..5226a5d71 --- /dev/null +++ b/pytorch3d/csrc/dispatch.cuh @@ -0,0 +1,261 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +// +// This file provides utilities for dispatching to specialized versions of functions. +// This is especially useful for CUDA kernels, since specializing them to particular +// input sizes can often allow the compiler to unroll loops and place arrays into +// registers, which can give huge performance speedups. +// +// As an example, suppose we have the following function which is specialized +// based on a compile-time int64_t value: +// +// template +// struct SquareOffset { +// static void run(T y) { +// T val = x * x + y; +// std::cout << val << std::endl; +// } +// } +// +// This function takes one compile-time argument x, and one run-time argument y. +// We might want to compile specialized versions of this for x=0, x=1, etc and +// then dispatch to the correct one based on the runtime value of x. +// One simple way to achieve this is with a lookup table: +// +// template +// void DispatchSquareOffset(const int64_t x, T y) { +// if (x == 0) { +// SquareOffset::run(y); +// } else if (x == 1) { +// SquareOffset::run(y); +// } else if (x == 2) { +// SquareOffset::run(y); +// } +// } +// +// This function takes both x and y as run-time arguments, and dispatches to +// different specialized versions of SquareOffset based on the run-time value +// of x. This works, but it's tedious and error-prone. 
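+// Spelled out with its template parameters (a minimal sketch; the template
+// parameter order <typename T, int64_t x> is assumed here), the pattern above
+// looks like this:
+//
+// template <typename T, int64_t x>
+// struct SquareOffset {
+//   static void run(T y) {
+//     T val = x * x + y;
+//     std::cout << val << std::endl;
+//   }
+// };
+//
+// template <typename T>
+// void DispatchSquareOffset(const int64_t x, T y) {
+//   if (x == 0) {
+//     SquareOffset<T, 0>::run(y);
+//   } else if (x == 1) {
+//     SquareOffset<T, 1>::run(y);
+//   } else if (x == 2) {
+//     SquareOffset<T, 2>::run(y);
+//   }
+// }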
If we want to change the +// set of x values for which we provide compile-time specializations, then we +// will need to do a lot of tedius editing of the dispatch function. Also, if we +// want to provide compile-time specializations for another function other than +// SquareOffset, we will need to duplicate the entire lookup table. +// +// To solve these problems, we can use the DispatchKernel1D function provided by +// this file instead: +// +// template +// void DispatchSquareOffset(const int64_t x, T y) { +// constexpr int64_t xmin = 0; +// constexpr int64_t xmax = 2; +// DispatchKernel1D(x, y); +// } +// +// DispatchKernel1D uses template metaprogramming to compile specialized +// versions of SquareOffset for all values of x with xmin <= x <= xmax, and +// then dispatches to the correct one based on the run-time value of x. If we +// want to change the range of x values for which SquareOffset is specialized +// at compile-time, then all we have to do is change the values of the +// compile-time constants xmin and xmax. +// +// This file also allows us to similarly dispatch functions that depend on two +// compile-time int64_t values, using the DispatchKernel2D function like this: +// +// template +// struct Sum { +// static void run(T z, T w) { +// T val = x + y + z + w; +// std::cout << val << std::endl; +// } +// } +// +// template +// void DispatchSum(const int64_t x, const int64_t y, int z, int w) { +// constexpr int64_t xmin = 1; +// constexpr int64_t xmax = 3; +// constexpr int64_t ymin = 2; +// constexpr int64_t ymax = 5; +// DispatchKernel2D(x, y, z, w); +// } +// +// Like its 1D counterpart, DispatchKernel2D uses template metaprogramming to +// compile specialized versions of sum for all values of (x, y) with +// xmin <= x <= xmax and ymin <= y <= ymax, then dispatches to the correct +// specialized version based on the runtime values of x and y. + +// Define some helper structs in an anonymous namespace. +namespace { + +// 1D dispatch: general case. +// Kernel is the function we want to dispatch to; it should take a typename and +// an int64_t as template args, and it should define a static void function +// run which takes any number of arguments of any type. +// In order to dispatch, we will take an additional template argument curN, +// and increment it via template recursion until it is equal to the run-time +// argument N. +template< + template class Kernel, + typename T, + int64_t minN, + int64_t maxN, + int64_t curN, + typename... Args +> +struct DispatchKernelHelper1D { + static void run(const int64_t N, Args... args) { + if (curN == N) { + // The compile-time value curN is equal to the run-time value N, so we + // can dispatch to the run method of the Kernel. + Kernel::run(args...); + } else if (curN < N) { + // Increment curN via template recursion + DispatchKernelHelper1D::run(N, args...); + } + // We shouldn't get here -- throw an error? + } +}; + + +// 1D dispatch: Specialization when curN == maxN +// We need this base case to avoid infinite template recursion. +template< + template class Kernel, + typename T, + int64_t minN, + int64_t maxN, + typename... Args +> +struct DispatchKernelHelper1D { + static void run(const int64_t N, Args... args) { + if (N == maxN) { + Kernel::run(args...); + } + // We shouldn't get here -- throw an error? + } +}; + + +// 2D dispatch, general case. 
+// This is similar to the 1D case: we take additional template args curN and +// curM, and increment them via template recursion until they are equal to +// the run-time values of N and M, at which point we dispatch to the run +// method of the kernel. +template< + template class Kernel, + typename T, + int64_t minN, int64_t maxN, int64_t curN, + int64_t minM, int64_t maxM, int64_t curM, + typename... Args +> +struct DispatchKernelHelper2D { + static void run(const int64_t N, const int64_t M, Args... args) { + if (curN == N && curM == M) { + Kernel::run(args...); + } else if (curN < N && curM < M) { + // Increment both curN and curM. This isn't strictly necessary; we could + // just increment one or the other at each step. But this helps to cut + // on the number of recursive calls we make. + DispatchKernelHelper2D::run(N, M, args...); + } else if (curN < N) { + // Increment curN only + DispatchKernelHelper2D::run(N, M, args...); + } else if (curM < M) { + // Increment curM only + DispatchKernelHelper2D::run(N, M, args...); + } + } +}; + + +// 2D dispatch, specialization for curN == maxN +template< + template class Kernel, + typename T, + int64_t minN, int64_t maxN, + int64_t minM, int64_t maxM, int64_t curM, + typename... Args +> +struct DispatchKernelHelper2D { + static void run(const int64_t N, const int64_t M, Args... args) { + if (maxN == N && curM == M) { + Kernel::run(args...); + } else if (curM < maxM) { + DispatchKernelHelper2D::run(N, M, args...); + } + // We should not get here -- throw an error? + } +}; + + +// 2D dispatch, specialization for curM == maxM +template< + template class Kernel, + typename T, + int64_t minN, int64_t maxN, int64_t curN, + int64_t minM, int64_t maxM, + typename... Args +> +struct DispatchKernelHelper2D { + static void run(const int64_t N, const int64_t M, Args... args) { + if (curN == N && maxM == M) { + Kernel::run(args...); + } else if (curN < maxN) { + DispatchKernelHelper2D::run(N, M, args...); + } + // We should not get here -- throw an error? + } +}; + + +// 2D dispatch, specialization for curN == maxN, curM == maxM +template< + template class Kernel, + typename T, + int64_t minN, int64_t maxN, + int64_t minM, int64_t maxM, + typename... Args +> +struct DispatchKernelHelper2D { + static void run(const int64_t N, const int64_t M, Args... args) { + if (maxN == N && maxM == M) { + Kernel::run(args...); + } + // We should not get here -- throw an error? + } +}; + +} // namespace + + +// This is the function we expect users to call to dispatch to 1D functions +template< + template class Kernel, + typename T, + int64_t minN, + int64_t maxN, + typename... Args +> +void DispatchKernel1D(const int64_t N, Args... args) { + if (minN <= N && N <= maxN) { + // Kick off the template recursion by calling the Helper with curN = minN + DispatchKernelHelper1D::run(N, args...); + } + // Maybe throw an error if we tried to dispatch outside the allowed range? +} + + +// This is the function we expect users to call to dispatch to 2D functions +template< + template class Kernel, + typename T, + int64_t minN, int64_t maxN, + int64_t minM, int64_t maxM, + typename... Args +> +void DispatchKernel2D(const int64_t N, const int64_t M, Args... args) { + if (minN <= N && N <= maxN && minM <= M && M <= maxM) { + // Kick off the template recursion by calling the Helper with curN = minN + // and curM = minM + DispatchKernelHelper2D::run(N, M, args...); + } + // Maybe throw an error if we tried to dispatch outside the specified range? 
+} diff --git a/pytorch3d/csrc/ext.cpp b/pytorch3d/csrc/ext.cpp index 10105c0b4..38292a342 100644 --- a/pytorch3d/csrc/ext.cpp +++ b/pytorch3d/csrc/ext.cpp @@ -6,6 +6,7 @@ #include "compositing/weighted_sum.h" #include "face_areas_normals/face_areas_normals.h" #include "gather_scatter/gather_scatter.h" +#include "knn/knn.h" #include "nearest_neighbor_points/nearest_neighbor_points.h" #include "packed_to_padded_tensor/packed_to_padded_tensor.h" #include "rasterize_meshes/rasterize_meshes.h" @@ -16,6 +17,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("face_areas_normals_backward", &FaceAreasNormalsBackward); m.def("packed_to_padded", &PackedToPadded); m.def("padded_to_packed", &PaddedToPacked); + m.def("knn_points_idx", &KNearestNeighborIdx); m.def("nn_points_idx", &NearestNeighborIdx); m.def("gather_scatter", &gather_scatter); m.def("rasterize_points", &RasterizePoints); diff --git a/pytorch3d/csrc/index_utils.cuh b/pytorch3d/csrc/index_utils.cuh new file mode 100644 index 000000000..66460ebf4 --- /dev/null +++ b/pytorch3d/csrc/index_utils.cuh @@ -0,0 +1,120 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. + +// This converts dynamic array lookups into static array lookups, for small +// arrays up to size 32. +// +// Suppose we have a small thread-local array: +// +// float vals[10]; +// +// Ideally we should only index this array using static indices: +// +// for (int i = 0; i < 10; ++i) vals[i] = i * i; +// +// If we do so, then the CUDA compiler may be able to place the array into +// registers, which can have a big performance improvement. However if we +// access the array dynamically, the the compiler may force the array into +// local memory, which has the same latency as global memory. +// +// These functions convert dynamic array access into static array access +// using a brute-force lookup table. It can be used like this: +// +// float vals[10]; +// int idx = 3; +// float val = 3.14f; +// RegisterIndexUtils::set(vals, idx, val); +// float val2 = RegisterIndexUtils::get(vals, idx); +// +// The implementation is based on fbcuda/RegisterUtils.cuh: +// https://github.com/facebook/fbcuda/blob/master/RegisterUtils.cuh +// To avoid depending on the entire library, we just reimplement these two +// functions. The fbcuda implementation is a bit more sophisticated, and uses +// the preprocessor to generate switch statements that go up to N for each +// value of N. We are lazy and just have a giant explicit switch statement. +// +// We might be able to use a template metaprogramming approach similar to +// DispatchKernel1D for this. However DispatchKernel1D is intended to be used +// for dispatching to the correct CUDA kernel on the host, while this is +// is intended to run on the device. I was concerned that a metaprogramming +// approach for this might lead to extra function calls at runtime if the +// compiler fails to optimize them away, which could be very slow on device. +// However I didn't actually benchmark or test this. 
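+// With the template arguments written out (a sketch; the parameter order
+// <typename T, int N> is assumed here), the example above reads:
+//
+//   float vals[10];
+//   int idx = 3;
+//   float val = 3.14f;
+//   RegisterIndexUtils<float, 10>::set(vals, idx, val);
+//   float val2 = RegisterIndexUtils<float, 10>::get(vals, idx);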
+template +struct RegisterIndexUtils { + __device__ __forceinline__ static T get(const T arr[N], int idx) { + if (idx < 0 || idx >= N) return T(); + switch (idx) { + case 0: return arr[0]; + case 1: return arr[1]; + case 2: return arr[2]; + case 3: return arr[3]; + case 4: return arr[4]; + case 5: return arr[5]; + case 6: return arr[6]; + case 7: return arr[7]; + case 8: return arr[8]; + case 9: return arr[9]; + case 10: return arr[10]; + case 11: return arr[11]; + case 12: return arr[12]; + case 13: return arr[13]; + case 14: return arr[14]; + case 15: return arr[15]; + case 16: return arr[16]; + case 17: return arr[17]; + case 18: return arr[18]; + case 19: return arr[19]; + case 20: return arr[20]; + case 21: return arr[21]; + case 22: return arr[22]; + case 23: return arr[23]; + case 24: return arr[24]; + case 25: return arr[25]; + case 26: return arr[26]; + case 27: return arr[27]; + case 28: return arr[28]; + case 29: return arr[29]; + case 30: return arr[30]; + case 31: return arr[31]; + }; + return T(); + } + + __device__ __forceinline__ static void set(T arr[N], int idx, T val) { + if (idx < 0 || idx >= N) return; + switch (idx) { + case 0: arr[0] = val; break; + case 1: arr[1] = val; break; + case 2: arr[2] = val; break; + case 3: arr[3] = val; break; + case 4: arr[4] = val; break; + case 5: arr[5] = val; break; + case 6: arr[6] = val; break; + case 7: arr[7] = val; break; + case 8: arr[8] = val; break; + case 9: arr[9] = val; break; + case 10: arr[10] = val; break; + case 11: arr[11] = val; break; + case 12: arr[12] = val; break; + case 13: arr[13] = val; break; + case 14: arr[14] = val; break; + case 15: arr[15] = val; break; + case 16: arr[16] = val; break; + case 17: arr[17] = val; break; + case 18: arr[18] = val; break; + case 19: arr[19] = val; break; + case 20: arr[20] = val; break; + case 21: arr[21] = val; break; + case 22: arr[22] = val; break; + case 23: arr[23] = val; break; + case 24: arr[24] = val; break; + case 25: arr[25] = val; break; + case 26: arr[26] = val; break; + case 27: arr[27] = val; break; + case 28: arr[28] = val; break; + case 29: arr[29] = val; break; + case 30: arr[30] = val; break; + case 31: arr[31] = val; break; + } + } +}; diff --git a/pytorch3d/csrc/knn/knn.cu b/pytorch3d/csrc/knn/knn.cu new file mode 100644 index 000000000..b065d969c --- /dev/null +++ b/pytorch3d/csrc/knn/knn.cu @@ -0,0 +1,369 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. + +#include +#include +#include +#include + +#include "dispatch.cuh" +#include "mink.cuh" + +template +__global__ void KNearestNeighborKernelV0( + const scalar_t* __restrict__ points1, + const scalar_t* __restrict__ points2, + scalar_t* __restrict__ dists, + int64_t* __restrict__ idxs, + const size_t N, + const size_t P1, + const size_t P2, + const size_t D, + const size_t K) { + // Stupid version: Make each thread handle one query point and loop over + // all P2 target points. There are N * P1 input points to handle, so + // do a trivial parallelization over threads. + // Store both dists and indices for knn in global memory. 
+ const int tid = threadIdx.x + blockIdx.x * blockDim.x; + const int num_threads = blockDim.x * gridDim.x; + for (int np = tid; np < N * P1; np += num_threads) { + int n = np / P1; + int p1 = np % P1; + int offset = n * P1 * K + p1 * K; + MinK mink(dists + offset, idxs + offset, K); + for (int p2 = 0; p2 < P2; ++p2) { + // Find the distance between points1[n, p1] and points[n, p2] + scalar_t dist = 0; + for (int d = 0; d < D; ++d) { + scalar_t coord1 = points1[n * P1 * D + p1 * D + d]; + scalar_t coord2 = points2[n * P2 * D + p2 * D + d]; + scalar_t diff = coord1 - coord2; + dist += diff * diff; + } + mink.add(dist, p2); + } + } +} + +template +__global__ void KNearestNeighborKernelV1( + const scalar_t* __restrict__ points1, + const scalar_t* __restrict__ points2, + scalar_t* __restrict__ dists, + int64_t* __restrict__ idxs, + const size_t N, + const size_t P1, + const size_t P2, + const size_t K) { + // Same idea as the previous version, but hoist D into a template argument + // so we can cache the current point in a thread-local array. We still store + // the current best K dists and indices in global memory, so this should work + // for very large K and fairly large D. + const int tid = threadIdx.x + blockIdx.x * blockDim.x; + const int num_threads = blockDim.x * gridDim.x; + scalar_t cur_point[D]; + for (int np = tid; np < N * P1; np += num_threads) { + int n = np / P1; + int p1 = np % P1; + for (int d = 0; d < D; ++d) { + cur_point[d] = points1[n * P1 * D + p1 * D + d]; + } + int offset = n * P1 * K + p1 * K; + MinK mink(dists + offset, idxs + offset, K); + for (int p2 = 0; p2 < P2; ++p2) { + // Find the distance between cur_point and points[n, p2] + scalar_t dist = 0; + for (int d = 0; d < D; ++d) { + scalar_t diff = cur_point[d] - points2[n * P2 * D + p2 * D + d]; + dist += diff * diff; + } + mink.add(dist, p2); + } + } +} + +// This is a shim functor to allow us to dispatch using DispatchKernel1D +template +struct KNearestNeighborV1Functor { + static void run( + size_t blocks, + size_t threads, + const scalar_t* __restrict__ points1, + const scalar_t* __restrict__ points2, + scalar_t* __restrict__ dists, + int64_t* __restrict__ idxs, + const size_t N, + const size_t P1, + const size_t P2, + const size_t K) { + KNearestNeighborKernelV1 + <<>>(points1, points2, dists, idxs, N, P1, P2, K); + } +}; + +template +__global__ void KNearestNeighborKernelV2( + const scalar_t* __restrict__ points1, + const scalar_t* __restrict__ points2, + scalar_t* __restrict__ dists, + int64_t* __restrict__ idxs, + const int64_t N, + const int64_t P1, + const int64_t P2) { + // Same general implementation as V2, but also hoist K into a template arg. 
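+  // (That is: like V1, D is a template arg so the query point can be cached
+  // in a thread-local array; additionally K is a template arg so the running
+  // best-K dists/idxs below can also live in thread-local arrays rather than
+  // global memory, with results copied out once at the end.)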
+ const int tid = threadIdx.x + blockIdx.x * blockDim.x; + const int num_threads = blockDim.x * gridDim.x; + scalar_t cur_point[D]; + scalar_t min_dists[K]; + int min_idxs[K]; + for (int np = tid; np < N * P1; np += num_threads) { + int n = np / P1; + int p1 = np % P1; + for (int d = 0; d < D; ++d) { + cur_point[d] = points1[n * P1 * D + p1 * D + d]; + } + MinK mink(min_dists, min_idxs, K); + for (int p2 = 0; p2 < P2; ++p2) { + scalar_t dist = 0; + for (int d = 0; d < D; ++d) { + int offset = n * P2 * D + p2 * D + d; + scalar_t diff = cur_point[d] - points2[offset]; + dist += diff * diff; + } + mink.add(dist, p2); + } + for (int k = 0; k < mink.size(); ++k) { + idxs[n * P1 * K + p1 * K + k] = min_idxs[k]; + dists[n * P1 * K + p1 * K + k] = min_dists[k]; + } + } +} + +// This is a shim so we can dispatch using DispatchKernel2D +template +struct KNearestNeighborKernelV2Functor { + static void run( + size_t blocks, + size_t threads, + const scalar_t* __restrict__ points1, + const scalar_t* __restrict__ points2, + scalar_t* __restrict__ dists, + int64_t* __restrict__ idxs, + const int64_t N, + const int64_t P1, + const int64_t P2) { + KNearestNeighborKernelV2 + <<>>(points1, points2, dists, idxs, N, P1, P2); + } +}; + +template +__global__ void KNearestNeighborKernelV3( + const scalar_t* __restrict__ points1, + const scalar_t* __restrict__ points2, + scalar_t* __restrict__ dists, + int64_t* __restrict__ idxs, + const size_t N, + const size_t P1, + const size_t P2) { + // Same idea as V2, but use register indexing for thread-local arrays. + // Enabling sorting for this version leads to huge slowdowns; I suspect + // that it forces min_dists into local memory rather than registers. + // As a result this version is always unsorted. + const int tid = threadIdx.x + blockIdx.x * blockDim.x; + const int num_threads = blockDim.x * gridDim.x; + scalar_t cur_point[D]; + scalar_t min_dists[K]; + int min_idxs[K]; + for (int np = tid; np < N * P1; np += num_threads) { + int n = np / P1; + int p1 = np % P1; + for (int d = 0; d < D; ++d) { + cur_point[d] = points1[n * P1 * D + p1 * D + d]; + } + RegisterMinK mink(min_dists, min_idxs); + for (int p2 = 0; p2 < P2; ++p2) { + scalar_t dist = 0; + for (int d = 0; d < D; ++d) { + int offset = n * P2 * D + p2 * D + d; + scalar_t diff = cur_point[d] - points2[offset]; + dist += diff * diff; + } + mink.add(dist, p2); + } + for (int k = 0; k < mink.size(); ++k) { + idxs[n * P1 * K + p1 * K + k] = min_idxs[k]; + dists[n * P1 * K + p1 * K + k] = min_dists[k]; + } + } +} + +// This is a shim so we can dispatch using DispatchKernel2D +template +struct KNearestNeighborKernelV3Functor { + static void run( + size_t blocks, + size_t threads, + const scalar_t* __restrict__ points1, + const scalar_t* __restrict__ points2, + scalar_t* __restrict__ dists, + int64_t* __restrict__ idxs, + const size_t N, + const size_t P1, + const size_t P2) { + KNearestNeighborKernelV3 + <<>>(points1, points2, dists, idxs, N, P1, P2); + } +}; + +constexpr int V1_MIN_D = 1; +constexpr int V1_MAX_D = 32; + +constexpr int V2_MIN_D = 1; +constexpr int V2_MAX_D = 8; +constexpr int V2_MIN_K = 1; +constexpr int V2_MAX_K = 32; + +constexpr int V3_MIN_D = 1; +constexpr int V3_MAX_D = 8; +constexpr int V3_MIN_K = 1; +constexpr int V3_MAX_K = 4; + +bool InBounds(const int64_t min, const int64_t x, const int64_t max) { + return min <= x && x <= max; +} + +bool CheckVersion(int version, const int64_t D, const int64_t K) { + if (version == 0) { + return true; + } else if (version == 1) { + return 
InBounds(V1_MIN_D, D, V1_MAX_D); + } else if (version == 2) { + return InBounds(V2_MIN_D, D, V2_MAX_D) && InBounds(V2_MIN_K, K, V2_MAX_K); + } else if (version == 3) { + return InBounds(V3_MIN_D, D, V3_MAX_D) && InBounds(V3_MIN_K, K, V3_MAX_K); + } + return false; +} + +int ChooseVersion(const int64_t D, const int64_t K) { + for (int version = 3; version >= 1; version--) { + if (CheckVersion(version, D, K)) { + return version; + } + } + return 0; +} + +std::tuple KNearestNeighborIdxCuda( + const at::Tensor& p1, + const at::Tensor& p2, + int K, + int version) { + const auto N = p1.size(0); + const auto P1 = p1.size(1); + const auto P2 = p2.size(1); + const auto D = p2.size(2); + const int64_t K_64 = K; + + AT_ASSERTM(p2.size(2) == D, "Point sets must have the same last dimension"); + auto long_dtype = p1.options().dtype(at::kLong); + auto idxs = at::full({N, P1, K}, -1, long_dtype); + auto dists = at::full({N, P1, K}, -1, p1.options()); + + if (version < 0) { + version = ChooseVersion(D, K); + } else if (!CheckVersion(version, D, K)) { + int new_version = ChooseVersion(D, K); + std::cout << "WARNING: Requested KNN version " << version + << " is not compatible with D = " << D << "; K = " << K + << ". Falling back to version = " << new_version << std::endl; + version = new_version; + } + + // At this point we should have a valid version no matter what data the user + // gave us. But we can check once more to be sure; however this time + // assert fail since failing at this point means we have a bug in our version + // selection or checking code. + AT_ASSERTM(CheckVersion(version, D, K), "Invalid version"); + + const size_t threads = 256; + const size_t blocks = 256; + if (version == 0) { + AT_DISPATCH_FLOATING_TYPES(p1.type(), "knn_kernel_cuda", ([&] { + KNearestNeighborKernelV0 + <<>>( + p1.data_ptr(), + p2.data_ptr(), + dists.data_ptr(), + idxs.data_ptr(), + N, + P1, + P2, + D, + K); + })); + } else if (version == 1) { + AT_DISPATCH_FLOATING_TYPES(p1.type(), "knn_kernel_cuda", ([&] { + DispatchKernel1D< + KNearestNeighborV1Functor, + scalar_t, + V1_MIN_D, + V1_MAX_D>( + D, + blocks, + threads, + p1.data_ptr(), + p2.data_ptr(), + dists.data_ptr(), + idxs.data_ptr(), + N, + P1, + P2, + K); + })); + } else if (version == 2) { + AT_DISPATCH_FLOATING_TYPES(p1.type(), "knn_kernel_cuda", ([&] { + DispatchKernel2D< + KNearestNeighborKernelV2Functor, + scalar_t, + V2_MIN_D, + V2_MAX_D, + V2_MIN_K, + V2_MAX_K>( + D, + K_64, + blocks, + threads, + p1.data_ptr(), + p2.data_ptr(), + dists.data_ptr(), + idxs.data_ptr(), + N, + P1, + P2); + })); + } else if (version == 3) { + AT_DISPATCH_FLOATING_TYPES(p1.type(), "knn_kernel_cuda", ([&] { + DispatchKernel2D< + KNearestNeighborKernelV3Functor, + scalar_t, + V3_MIN_D, + V3_MAX_D, + V3_MIN_K, + V3_MAX_K>( + D, + K_64, + blocks, + threads, + p1.data_ptr(), + p2.data_ptr(), + dists.data_ptr(), + idxs.data_ptr(), + N, + P1, + P2); + })); + } + + return std::make_tuple(idxs, dists); +} diff --git a/pytorch3d/csrc/knn/knn.h b/pytorch3d/csrc/knn/knn.h new file mode 100644 index 000000000..cb760c346 --- /dev/null +++ b/pytorch3d/csrc/knn/knn.h @@ -0,0 +1,54 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. + +#pragma once +#include +#include +#include "pytorch3d_cutils.h" + +// Compute indices of K nearest neighbors in pointcloud p2 to points +// in pointcloud p1. +// +// Args: +// p1: FloatTensor of shape (N, P1, D) giving a batch of pointclouds each +// containing P1 points of dimension D. 
+// p2: FloatTensor of shape (N, P2, D) giving a batch of pointclouds each +// containing P2 points of dimension D. +// K: int giving the number of nearest points to return. +// sorted: bool telling whether to sort the K returned points by their +// distance version: Integer telling which implementation to use. +// TODO(jcjohns): Document this more, or maybe remove it before +// landing. +// +// Returns: +// p1_neighbor_idx: LongTensor of shape (N, P1, K), where +// p1_neighbor_idx[n, i, k] = j means that the kth nearest +// neighbor to p1[n, i] in the cloud p2[n] is p2[n, j]. + +// CPU implementation. +std::tuple +KNearestNeighborIdxCpu(const at::Tensor& p1, const at::Tensor& p2, int K); + +// CUDA implementation +std::tuple KNearestNeighborIdxCuda( + const at::Tensor& p1, + const at::Tensor& p2, + int K, + int version); + +// Implementation which is exposed. +std::tuple KNearestNeighborIdx( + const at::Tensor& p1, + const at::Tensor& p2, + int K, + int version) { + if (p1.type().is_cuda() || p2.type().is_cuda()) { +#ifdef WITH_CUDA + CHECK_CONTIGUOUS_CUDA(p1); + CHECK_CONTIGUOUS_CUDA(p2); + return KNearestNeighborIdxCuda(p1, p2, K, version); +#else + AT_ERROR("Not compiled with GPU support."); +#endif + } + return KNearestNeighborIdxCpu(p1, p2, K); +} diff --git a/pytorch3d/csrc/knn/knn_cpu.cpp b/pytorch3d/csrc/knn/knn_cpu.cpp new file mode 100644 index 000000000..dada972a5 --- /dev/null +++ b/pytorch3d/csrc/knn/knn_cpu.cpp @@ -0,0 +1,52 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. + +#include +#include +#include + + +std::tuple +KNearestNeighborIdxCpu(const at::Tensor& p1, const at::Tensor& p2, int K) { + const int N = p1.size(0); + const int P1 = p1.size(1); + const int D = p1.size(2); + const int P2 = p2.size(1); + + auto long_opts = p1.options().dtype(torch::kInt64); + torch::Tensor idxs = torch::full({N, P1, K}, -1, long_opts); + torch::Tensor dists = torch::full({N, P1, K}, 0, p1.options()); + + auto p1_a = p1.accessor(); + auto p2_a = p2.accessor(); + auto idxs_a = idxs.accessor(); + auto dists_a = dists.accessor(); + + for (int n = 0; n < N; ++n) { + for (int i1 = 0; i1 < P1; ++i1) { + // Use a priority queue to store (distance, index) tuples. + std::priority_queue> q; + for (int i2 = 0; i2 < P2; ++i2) { + float dist = 0; + for (int d = 0; d < D; ++d) { + float diff = p1_a[n][i1][d] - p2_a[n][i2][d]; + dist += diff * diff; + } + int size = static_cast(q.size()); + if (size < K || dist < std::get<0>(q.top())) { + q.emplace(dist, i2); + if (size >= K) { + q.pop(); + } + } + } + while (!q.empty()) { + auto t = q.top(); + q.pop(); + const int k = q.size(); + dists_a[n][i1][k] = std::get<0>(t); + idxs_a[n][i1][k] = std::get<1>(t); + } + } + } + return std::make_tuple(idxs, dists); +} diff --git a/pytorch3d/csrc/mink.cuh b/pytorch3d/csrc/mink.cuh new file mode 100644 index 000000000..5d7eb7300 --- /dev/null +++ b/pytorch3d/csrc/mink.cuh @@ -0,0 +1,162 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. + +#pragma once +#define MINK_H + +#include "index_utils.cuh" + + +// A data structure to keep track of the smallest K keys seen so far as well +// as their associated values, intended to be used in device code. +// This data structure doesn't allocate any memory; keys and values are stored +// in arrays passed to the constructor. +// +// The implementation is generic; it can be used for any key type that supports +// the < operator, and can be used with any value type. 
+// +// Example usage: +// +// float keys[K]; +// int values[K]; +// MinK mink(keys, values, K); +// for (...) { +// // Produce some key and value from somewhere +// mink.add(key, value); +// } +// mink.sort(); +// +// Now keys and values store the smallest K keys seen so far and the values +// associated to these keys: +// +// for (int k = 0; k < K; ++k) { +// float key_k = keys[k]; +// int value_k = values[k]; +// } +template +class MinK { + public: + + // Constructor. + // + // Arguments: + // keys: Array in which to store keys + // values: Array in which to store values + // K: How many values to keep track of + __device__ MinK(key_t *keys, value_t *vals, int K) : + keys(keys), vals(vals), K(K), _size(0) { } + + // Try to add a new key and associated value to the data structure. If the key + // is one of the smallest K seen so far then it will be kept; otherwise it + // it will not be kept. + // + // This takes O(1) operations if the new key is not kept, or if the structure + // currently contains fewer than K elements. Otherwise this takes O(K) time. + // + // Arguments: + // key: The key to add + // val: The value associated to the key + __device__ __forceinline__ void add(const key_t &key, const value_t &val) { + if (_size < K) { + keys[_size] = key; + vals[_size] = val; + if (_size == 0 || key > max_key) { + max_key = key; + max_idx = _size; + } + _size++; + } else if (key < max_key) { + keys[max_idx] = key; + vals[max_idx] = val; + max_key = key; + for (int k = 0; k < K; ++k) { + key_t cur_key = keys[k]; + if (cur_key > max_key) { + max_key = cur_key; + max_idx = k; + } + } + } + } + + // Get the number of items currently stored in the structure. + // This takes O(1) time. + __device__ __forceinline__ int size() { + return _size; + } + + // Sort the items stored in the structure using bubble sort. + // This takes O(K^2) time. + __device__ __forceinline__ void sort() { + for (int i = 0; i < _size - 1; ++i) { + for (int j = 0; j < _size - i - 1; ++j) { + if (keys[j + 1] < keys[j]) { + key_t key = keys[j]; + value_t val = vals[j]; + keys[j] = keys[j + 1]; + vals[j] = vals[j + 1]; + keys[j + 1] = key; + vals[j + 1] = val; + } + } + } + } + + private: + key_t *keys; + value_t *vals; + int K; + int _size; + key_t max_key; + int max_idx; +}; + + +// This is a version of MinK that only touches the arrays using static indexing +// via RegisterIndexUtils. If the keys and values are stored in thread-local +// arrays, then this may allow the compiler to place them in registers for +// fast access. +// +// This has the same API as RegisterMinK, but doesn't support sorting. +// We found that sorting via RegisterIndexUtils gave very poor performance, +// and suspect it may have prevented the compiler from placing the arrays +// into registers. 
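+//
+// Example usage (mirroring the MinK example above; the template parameter
+// order <typename key_t, typename value_t, int K> is assumed here):
+//
+//   constexpr int K = 8;
+//   float keys[K];
+//   int values[K];
+//   RegisterMinK<float, int, K> mink(keys, values);
+//   for (...) {
+//     // Produce some key and value from somewhere
+//     mink.add(key, value);
+//   }
+//   // keys and values now hold the (unsorted) smallest K keys seen so far
+//   // and their associated values.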
+template +class RegisterMinK { + public: + __device__ RegisterMinK(key_t *keys, value_t *vals) : + keys(keys), vals(vals), _size(0) {} + + __device__ __forceinline__ void add(const key_t &key, const value_t &val) { + if (_size < K) { + RegisterIndexUtils::set(keys, _size, key); + RegisterIndexUtils::set(vals, _size, val); + if (_size == 0 || key > max_key) { + max_key = key; + max_idx = _size; + } + _size++; + } else if (key < max_key) { + RegisterIndexUtils::set(keys, max_idx, key); + RegisterIndexUtils::set(vals, max_idx, val); + max_key = key; + for (int k = 0; k < K; ++k) { + key_t cur_key = RegisterIndexUtils::get(keys, k); + if (cur_key > max_key) { + max_key = cur_key; + max_idx = k; + } + } + } + } + + __device__ __forceinline__ int size() { + return _size; + } + + private: + key_t *keys; + value_t *vals; + int _size; + key_t max_key; + int max_idx; +}; diff --git a/pytorch3d/csrc/nearest_neighbor_points/nearest_neighbor_points.h b/pytorch3d/csrc/nearest_neighbor_points/nearest_neighbor_points.h index 7b447233e..27f9cc458 100644 --- a/pytorch3d/csrc/nearest_neighbor_points/nearest_neighbor_points.h +++ b/pytorch3d/csrc/nearest_neighbor_points/nearest_neighbor_points.h @@ -39,4 +39,4 @@ at::Tensor NearestNeighborIdx(at::Tensor p1, at::Tensor p2) { #endif } return NearestNeighborIdxCpu(p1, p2); -}; +} diff --git a/pytorch3d/ops/knn.py b/pytorch3d/ops/knn.py new file mode 100644 index 000000000..2ec35992b --- /dev/null +++ b/pytorch3d/ops/knn.py @@ -0,0 +1,67 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. + +import torch +from pytorch3d import _C + + +def knn_points_idx(p1, p2, K, sorted=False, version=-1): + """ + K-Nearest neighbors on point clouds. + + Args: + p1: Tensor of shape (N, P1, D) giving a batch of point clouds, each + containing P1 points of dimension D. + p2: Tensor of shape (N, P2, D) giving a batch of point clouds, each + containing P2 points of dimension D. + K: Integer giving the number of nearest neighbors to return + sorted: Whether to sort the resulting points. + version: Which KNN implementation to use in the backend. If version=-1, + the correct implementation is selected based on the shapes of + the inputs. + + Returns: + idx: LongTensor of shape (N, P1, K) giving the indices of the + K nearest neighbors from points in p1 to points in p2. + Concretely, if idx[n, i, k] = j then p2[n, j] is one of the K + nearest neighbor to p1[n, i] in p2[n]. If sorted=True, then + p2[n, j] is the kth nearest neighbor to p1[n, i]. + """ + idx, dists = _C.knn_points_idx(p1, p2, K, version) + if sorted: + dists, sort_idx = dists.sort(dim=2) + idx = idx.gather(2, sort_idx) + return idx, dists + + +@torch.no_grad() +def _knn_points_idx_naive(p1, p2, K, sorted=False) -> torch.Tensor: + """ + Naive PyTorch implementation of K-Nearest Neighbors. + + This is much less efficient than _C.knn_points_idx, but we include this + naive implementation for testing and benchmarking. + + Args: + p1: Tensor of shape (N, P1, D) giving a batch of point clouds, each + containing P1 points of dimension D. + p2: Tensor of shape (N, P2, D) giving a batch of point clouds, each + containing P2 points of dimension D. + K: Integer giving the number of nearest neighbors to return + sorted: Whether to sort the resulting points. + + Returns: + idx: LongTensor of shape (N, P1, K) giving the indices of the + K nearest neighbors from points in p1 to points in p2. + Concretely, if idx[n, i, k] = j then p2[n, j] is one of the K + nearest neighbor to p1[n, i] in p2[n]. 
If sorted=True, then + p2[n, j] is the kth nearest neighbor to p1[n, i]. + dists: Tensor of shape (N, P1, K) giving the distances to the nearest + neighbors. + """ + N, P1, D = p1.shape + _N, P2, _D = p2.shape + assert N == _N and D == _D + diffs = p1.view(N, P1, 1, D) - p2.view(N, 1, P2, D) + dists2 = (diffs * diffs).sum(dim=3) + out = dists2.topk(K, dim=2, largest=False, sorted=sorted) + return out.indices, out.values diff --git a/tests/bm_knn.py b/tests/bm_knn.py new file mode 100644 index 000000000..fd3915327 --- /dev/null +++ b/tests/bm_knn.py @@ -0,0 +1,174 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. + +from itertools import product + +import torch +from fvcore.common.benchmark import benchmark + +from pytorch3d import _C +from pytorch3d.ops.knn import _knn_points_idx_naive + + +def bm_knn() -> None: + """ Entry point for the benchmark """ + benchmark_knn_cpu() + benchmark_knn_cuda_vs_naive() + benchmark_knn_cuda_versions() + + +def benchmark_knn_cuda_versions() -> None: + # Compare our different KNN implementations, + # and also compare against our existing 1-NN + Ns = [1, 2] + Ps = [4096, 16384] + Ds = [3] + Ks = [1, 4, 16, 64] + versions = [0, 1, 2, 3] + knn_kwargs, nn_kwargs = [], [] + for N, P, D, K, version in product(Ns, Ps, Ds, Ks, versions): + if version == 2 and K > 32: + continue + if version == 3 and K > 4: + continue + knn_kwargs.append({'N': N, 'D': D, 'P': P, 'K': K, 'v': version}) + for N, P, D in product(Ns, Ps, Ds): + nn_kwargs.append({'N': N, 'D': D, 'P': P}) + benchmark( + knn_cuda_with_init, + 'KNN_CUDA_VERSIONS', + knn_kwargs, + warmup_iters=1, + ) + benchmark( + nn_cuda_with_init, + 'NN_CUDA', + nn_kwargs, + warmup_iters=1, + ) + + +def benchmark_knn_cuda_vs_naive() -> None: + # Compare against naive pytorch version of KNN + Ns = [1, 2, 4] + Ps = [1024, 4096, 16384, 65536] + Ds = [3] + Ks = [1, 2, 4, 8, 16] + knn_kwargs, naive_kwargs = [], [] + for N, P, D, K in product(Ns, Ps, Ds, Ks): + knn_kwargs.append({'N': N, 'D': D, 'P': P, 'K': K}) + if P <= 4096: + naive_kwargs.append({'N': N, 'D': D, 'P': P, 'K': K}) + benchmark( + knn_python_cuda_with_init, + 'KNN_CUDA_PYTHON', + naive_kwargs, + warmup_iters=1, + ) + benchmark( + knn_cuda_with_init, + 'KNN_CUDA', + knn_kwargs, + warmup_iters=1, + ) + + +def benchmark_knn_cpu() -> None: + Ns = [1, 2] + Ps = [256, 512] + Ds = [3] + Ks = [1, 2, 4] + knn_kwargs, nn_kwargs = [], [] + for N, P, D, K in product(Ns, Ps, Ds, Ks): + knn_kwargs.append({'N': N, 'D': D, 'P': P, 'K': K}) + for N, P, D in product(Ns, Ps, Ds): + nn_kwargs.append({'N': N, 'D': D, 'P': P}) + benchmark( + knn_python_cpu_with_init, + 'KNN_CPU_PYTHON', + knn_kwargs, + warmup_iters=1, + ) + benchmark( + knn_cpu_with_init, + 'KNN_CPU_CPP', + knn_kwargs, + warmup_iters=1, + ) + benchmark( + nn_cpu_with_init, + 'NN_CPU_CPP', + nn_kwargs, + warmup_iters=1, + ) + + +def knn_cuda_with_init(N, D, P, K, v=-1): + device = torch.device('cuda:0') + x = torch.randn(N, P, D, device=device) + y = torch.randn(N, P, D, device=device) + torch.cuda.synchronize() + + def knn(): + _C.knn_points_idx(x, y, K, v) + torch.cuda.synchronize() + + return knn + + +def knn_cpu_with_init(N, D, P, K): + device = torch.device('cpu') + x = torch.randn(N, P, D, device=device) + y = torch.randn(N, P, D, device=device) + + def knn(): + _C.knn_points_idx(x, y, K, 0) + + return knn + + +def knn_python_cuda_with_init(N, D, P, K): + device = torch.device('cuda') + x = torch.randn(N, P, D, device=device) + y = torch.randn(N, P, D, device=device) + 
torch.cuda.synchronize() + + def knn(): + _knn_points_idx_naive(x, y, K) + torch.cuda.synchronize() + + return knn + + +def knn_python_cpu_with_init(N, D, P, K): + device = torch.device('cpu') + x = torch.randn(N, P, D, device=device) + y = torch.randn(N, P, D, device=device) + + def knn(): + _knn_points_idx_naive(x, y, K) + + return knn + + +def nn_cuda_with_init(N, D, P): + device = torch.device('cuda') + x = torch.randn(N, P, D, device=device) + y = torch.randn(N, P, D, device=device) + torch.cuda.synchronize() + + def knn(): + _C.nn_points_idx(x, y) + torch.cuda.synchronize() + + return knn + + +def nn_cpu_with_init(N, D, P): + device = torch.device('cpu') + x = torch.randn(N, P, D, device=device) + y = torch.randn(N, P, D, device=device) + + def knn(): + _C.nn_points_idx(x, y) + + return knn diff --git a/tests/test_knn.py b/tests/test_knn.py new file mode 100644 index 000000000..9c9483d66 --- /dev/null +++ b/tests/test_knn.py @@ -0,0 +1,65 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. + +import unittest +from itertools import product +import torch + +from pytorch3d.ops.knn import _knn_points_idx_naive, knn_points_idx + + +class TestKNN(unittest.TestCase): + def _check_knn_result(self, out1, out2, sorted): + # When sorted=True, points should be sorted by distance and should + # match between implementations. When sorted=False we we only want to + # check that we got the same set of indices, so we sort the indices by + # index value. + idx1, dist1 = out1 + idx2, dist2 = out2 + if not sorted: + idx1 = idx1.sort(dim=2).values + idx2 = idx2.sort(dim=2).values + dist1 = dist1.sort(dim=2).values + dist2 = dist2.sort(dim=2).values + if not torch.all(idx1 == idx2): + print(idx1) + print(idx2) + self.assertTrue(torch.all(idx1 == idx2)) + self.assertTrue(torch.allclose(dist1, dist2)) + + def test_knn_vs_python_cpu(self): + """ Test CPU output vs PyTorch implementation """ + device = torch.device('cpu') + Ns = [1, 4] + Ds = [2, 3] + P1s = [1, 10, 101] + P2s = [10, 101] + Ks = [1, 3, 10] + sorts = [True, False] + factors = [Ns, Ds, P1s, P2s, Ks, sorts] + for N, D, P1, P2, K, sort in product(*factors): + x = torch.randn(N, P1, D, device=device) + y = torch.randn(N, P2, D, device=device) + out1 = _knn_points_idx_naive(x, y, K, sort) + out2 = knn_points_idx(x, y, K, sort) + self._check_knn_result(out1, out2, sort) + + def test_knn_vs_python_cuda(self): + """ Test CUDA output vs PyTorch implementation """ + device = torch.device('cuda') + Ns = [1, 4] + Ds = [2, 3, 8] + P1s = [1, 8, 64, 128, 1001] + P2s = [32, 128, 513] + Ks = [1, 3, 10] + sorts = [True, False] + versions = [0, 1, 2, 3] + factors = [Ns, Ds, P1s, P2s, Ks, sorts] + for N, D, P1, P2, K, sort in product(*factors): + x = torch.randn(N, P1, D, device=device) + y = torch.randn(N, P2, D, device=device) + out1 = _knn_points_idx_naive(x, y, K, sorted=sort) + for version in versions: + if version == 3 and K > 4: + continue + out2 = knn_points_idx(x, y, K, sort, version) + self._check_knn_result(out1, out2, sort) diff --git a/tests/test_nearest_neighbor_points.py b/tests/test_nearest_neighbor_points.py index 961434aa0..4332defb2 100644 --- a/tests/test_nearest_neighbor_points.py +++ b/tests/test_nearest_neighbor_points.py @@ -1,6 +1,7 @@ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. import unittest +from itertools import product import torch from pytorch3d import _C