-
Notifications
You must be signed in to change notification settings - Fork 1.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
sample_pdf CUDA and C++ implementations.
Summary: Implement the sample_pdf function from the NeRF project as compiled operators.. The binary search (in searchsorted) is replaced with a low tech linear search, but this is not a problem for the envisaged numbers of bins. Reviewed By: gkioxari Differential Revision: D26312535 fbshipit-source-id: df1c3119cd63d944380ed1b2657b6ad81d743e49
- Loading branch information
1 parent
7d7d00f
commit 1ea2b72
Showing
7 changed files
with
488 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,153 @@ | ||
/* | ||
* Copyright (c) Facebook, Inc. and its affiliates. | ||
* All rights reserved. | ||
* | ||
* This source code is licensed under the BSD-style license found in the | ||
* LICENSE file in the root directory of this source tree. | ||
*/ | ||
|
||
#include <ATen/ATen.h> | ||
#include <ATen/cuda/CUDAContext.h> | ||
#include <c10/cuda/CUDAGuard.h> | ||
|
||
// There is no intermediate memory, so no reason not to have blocksize=32. | ||
// The host launcher below caps the grid at max_blocks = 1024 blocks.
|
||
// DESIGN | ||
// We exploit the fact that n_samples is not tiny. | ||
// A chunk of work is T*blocksize many samples from | ||
// a single batch element.
// For each batch element there will be | ||
// chunks_per_batch = 1 + (n_samples-1)/(T*blocksize) of them. | ||
// The number of potential chunks to do is | ||
// n_chunks = chunks_per_batch * n_batches. | ||
// These chunks are divided among the gridSize-many blocks. | ||
// In block b, we work on chunks b, b+gridSize, b+2*gridSize etc . | ||
// In chunk i, we work on batch_element i/chunks_per_batch | ||
// on samples starting from (i%chunks_per_batch) * (T*blocksize) | ||
|
||
// BEGIN HYPOTHETICAL | ||
// Another option (not implemented) if batch_size was always large | ||
// would be as follows. | ||
|
||
// A chunk of work is S samples from each of blocksize-many | ||
// batch elements. | ||
// For each batch element there will be | ||
// chunks_per_batch = (1+(n_samples-1)/S) of them. | ||
// The number of potential chunks to do is | ||
// n_chunks = chunks_per_batch * (1+(n_batches-1)/blocksize) | ||
// These chunks are divided among the gridSize-many blocks. | ||
// In block b, we work on chunks b, b+gridSize, b+2*gridSize etc . | ||
// In chunk i, we work on samples starting from S*(i%chunks_per_batch) | ||
// on batch elements starting from blocksize*(i/chunks_per_batch). | ||
// END HYPOTHETICAL | ||
|
||
// Draws samples from batched piecewise-constant distributions by walking the
// bin weights until the requested quantile is reached, then interpolating
// linearly inside that bin.
//
// Launch layout: 1D grid of 1D blocks. Work is split into chunks of
// T * blockDim.x consecutive samples of one batch element; block b handles
// chunks b, b + gridDim.x, b + 2 * gridDim.x, ... The kernel declares no
// shared memory and performs no block-level synchronization.
//
// Args:
//   bins: bin edges, read as (1 + n_bins) floats per batch element.
//   weights: per-bin weights, read as n_bins floats per batch element.
//   outputs: n_samples floats per batch element. On entry each element holds
//       the quantile to draw; it is overwritten in place with the sample.
//   eps: added to the total weight of each batch element; a bin whose weight
//       is not greater than eps is never divided by.
//   T: number of samples each thread produces within one chunk.
//   batch_size, n_bins, n_samples: extents used for the indexing above.
__global__ void SamplePdfCudaKernel(
    const float* __restrict__ bins,
    const float* __restrict__ weights,
    float* __restrict__ outputs,
    float eps,
    const int T,
    const int64_t batch_size,
    const int64_t n_bins,
    const int64_t n_samples) {
  // The product is formed in the same (unsigned) type as before and only
  // then widened, so every later expression sees the same value.
  const int64_t samples_per_chunk = T * blockDim.x;
  const int64_t chunks_per_batch = 1 + (n_samples - 1) / samples_per_chunk;
  const int64_t total_chunks = chunks_per_batch * batch_size;

  for (int64_t chunk = blockIdx.x; chunk < total_chunks; chunk += gridDim.x) {
    // Locate the batch element and the first sample covered by this chunk.
    const int64_t batch_idx = chunk / chunks_per_batch;
    const int64_t first_sample = (chunk % chunks_per_batch) * samples_per_chunk;
    const float* const w = weights + n_bins * batch_idx;
    const float* const edges = bins + (1 + n_bins) * batch_idx;
    float* const out = outputs + n_samples * batch_idx + first_sample;

    // Per-batch-element preprocessing: the total weight. Every thread of the
    // block computes the same value for itself.
    float weight_sum = eps;
    for (int64_t b = 0; b < n_bins; ++b) {
      weight_sum += w[b];
    }

    for (int t = 0; t < T; ++t) {
      // Each thread produces T samples per chunk, blockDim.x apart.
      const int64_t offset = threadIdx.x + t * blockDim.x;
      if (first_sample + offset < n_samples) {
        // out[offset] holds the quantile this thread is converting.
        float residual = weight_sum * out[offset];

        // Walk along the weights to find the bin containing the quantile.
        // The trip count depends on the thread's own quantile.
        int64_t bin = 0;
        for (; bin + 1 < n_bins && residual > w[bin]; ++bin) {
          residual -= w[bin];
        }

        // Linear interpolation inside the chosen bin, written back in place.
        const float lo = edges[bin];
        const float hi = edges[bin + 1];
        const float bin_w = w[bin];
        float sample;
        if (residual > bin_w) {
          sample = hi;
        } else if (bin_w > eps) {
          sample = lo + (residual / bin_w) * (hi - lo);
        } else {
          sample = lo;
        }
        out[offset] = sample;
      }
      // Otherwise this thread's sample lies past the end of the batch element.
    }
  }
}
|
||
// Host-side launcher for SamplePdfCudaKernel.
//
// Args:
//   bins: float tensor; bins.size(0) is used as the batch size. A contiguous
//       version is passed to the kernel.
//   weights: float tensor; weights.size(1) is used as n_bins. A contiguous
//       version is passed to the kernel.
//   outputs: float tensor; outputs.size(1) is used as n_samples. Its raw data
//       pointer is handed to the kernel, so it must already be contiguous
//       (checked by the caller in the header). Holds the quantiles on entry
//       and is overwritten in place with the samples.
//   eps: forwarded unchanged to the kernel.
//
// The kernel is enqueued on the current CUDA stream of bins' device; this
// function does not synchronize with it.
void SamplePdfCuda(
    const at::Tensor& bins,
    const at::Tensor& weights,
    const at::Tensor& outputs,
    float eps) {
  // Check inputs are on the same device
  at::TensorArg bins_t{bins, "bins", 1}, weights_t{weights, "weights", 2},
      outputs_t{outputs, "outputs", 3};
  at::CheckedFrom c = "SamplePdfCuda";
  at::checkAllSameGPU(c, {bins_t, weights_t, outputs_t});
  at::checkAllSameType(c, {bins_t, weights_t, outputs_t});

  // Set the device for the kernel launch based on the device of the input
  at::cuda::CUDAGuard device_guard(bins.device());
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  const int64_t batch_size = bins.size(0);
  const int64_t n_bins = weights.size(1);
  const int64_t n_samples = outputs.size(1);

  // Nothing to sample. Returning here also avoids launching a grid with zero
  // blocks (batch_size == 0), which is an invalid launch configuration.
  if (batch_size == 0 || n_samples == 0) {
    AT_CUDA_CHECK(cudaGetLastError());
    return;
  }

  // The kernel unconditionally reads weights[0] and bins[1] of each batch
  // element, so at least one bin is required once there is work to do.
  TORCH_CHECK(n_bins > 0, "SamplePdfCuda: weights must contain at least one bin");

  const int64_t threads = 32;
  // Samples produced per thread within a chunk.
  const int T = n_samples <= threads ? 1 : 2;
  const int64_t chunks_per_batch = 1 + (n_samples - 1) / (T * threads);
  const int64_t n_chunks = chunks_per_batch * batch_size;

  const int64_t max_blocks = 1024;
  const int64_t blocks = n_chunks < max_blocks ? n_chunks : max_blocks;

  // Keep the (possibly newly allocated) contiguous tensors in named locals so
  // their lifetime visibly spans the launch.
  const at::Tensor bins_contig = bins.contiguous();
  const at::Tensor weights_contig = weights.contiguous();

  SamplePdfCudaKernel<<<blocks, threads, 0, stream>>>(
      bins_contig.data_ptr<float>(),
      weights_contig.data_ptr<float>(),
      outputs.data_ptr<float>(), // Checked contiguous in header file.
      eps,
      T,
      batch_size,
      n_bins,
      n_samples);

  AT_CUDA_CHECK(cudaGetLastError());
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
/* | ||
* Copyright (c) Facebook, Inc. and its affiliates. | ||
* All rights reserved. | ||
* | ||
* This source code is licensed under the BSD-style license found in the | ||
* LICENSE file in the root directory of this source tree. | ||
*/ | ||
|
||
#pragma once | ||
#include <torch/extension.h> | ||
#include <cstdio> | ||
#include <tuple> | ||
#include "utils/pytorch3d_cutils.h" | ||
|
||
// **************************************************************************** | ||
// * SamplePdf * | ||
// **************************************************************************** | ||
|
||
// Samples probability density functions defined by bin edges `bins` and
// the non-negative per-bin probabilities `weights`.
|
||
// Args: | ||
// bins: FloatTensor of shape `(batch_size, n_bins+1)` denoting the edges | ||
// of the sampling bins. | ||
|
||
// weights: FloatTensor of shape `(batch_size, n_bins)` containing | ||
// non-negative numbers representing the probability of sampling the | ||
// corresponding bin. | ||
|
||
// uniforms: Not a separate argument of these functions. The quantiles to
// draw, a FloatTensor of shape `(batch_size, n_samples)`, are passed in
// through `outputs` (see below).
|
||
// outputs: On call, this contains the quantiles to draw. It is overwritten
// with the drawn samples. FloatTensor of shape
// `(batch_size, n_samples)`, where `n_samples` samples are drawn from each
// distribution.
|
||
// eps: A constant preventing division by zero in case empty bins are | ||
// present. | ||
|
||
// Not differentiable | ||
|
||
#ifdef WITH_CUDA | ||
void SamplePdfCuda( | ||
const torch::Tensor& bins, | ||
const torch::Tensor& weights, | ||
const torch::Tensor& outputs, | ||
float eps); | ||
#endif | ||
|
||
void SamplePdfCpu( | ||
const torch::Tensor& bins, | ||
const torch::Tensor& weights, | ||
const torch::Tensor& outputs, | ||
float eps); | ||
|
||
inline void SamplePdf( | ||
const torch::Tensor& bins, | ||
const torch::Tensor& weights, | ||
const torch::Tensor& outputs, | ||
float eps) { | ||
if (bins.is_cuda()) { | ||
#ifdef WITH_CUDA | ||
CHECK_CUDA(weights); | ||
CHECK_CONTIGUOUS_CUDA(outputs); | ||
SamplePdfCuda(bins, weights, outputs, eps); | ||
return; | ||
#else | ||
AT_ERROR("Not compiled with GPU support."); | ||
#endif | ||
} | ||
CHECK_CONTIGUOUS(outputs); | ||
SamplePdfCpu(bins, weights, outputs, eps); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,141 @@ | ||
/* | ||
* Copyright (c) Facebook, Inc. and its affiliates. | ||
* All rights reserved. | ||
* | ||
* This source code is licensed under the BSD-style license found in the | ||
* LICENSE file in the root directory of this source tree. | ||
*/ | ||
|
||
#include <torch/extension.h> | ||
#include <algorithm> | ||
#include <thread> | ||
#include <vector> | ||
|
||
// If the number of bins is the typical 64, it is | ||
// quicker to use binary search than linear scan. | ||
// With more bins, it is more important. | ||
// There is no equivalent CUDA implementation yet. | ||
#define USE_BINARY_SEARCH | ||
|
||
namespace { | ||
// This worker function does the job of SamplePdf but only on | ||
// batch elements in [start_batch, end_batch). | ||
void SamplePdfCpu_worker( | ||
const torch::Tensor& bins, | ||
const torch::Tensor& weights, | ||
const torch::Tensor& outputs, | ||
float eps, | ||
int64_t start_batch, | ||
int64_t end_batch) { | ||
const int64_t n_bins = weights.size(1); | ||
const int64_t n_samples = outputs.size(1); | ||
|
||
auto bins_a = bins.accessor<float, 2>(); | ||
auto weights_a = weights.accessor<float, 2>(); | ||
float* __restrict__ output_p = | ||
outputs.data_ptr<float>() + start_batch * n_samples; | ||
|
||
#ifdef USE_BINARY_SEARCH | ||
std::vector<float> partial_sums(n_bins); | ||
#endif | ||
|
||
for (int64_t i_batch_elt = start_batch; i_batch_elt < end_batch; | ||
++i_batch_elt) { | ||
auto bin_a = bins_a[i_batch_elt]; | ||
auto weight_a = weights_a[i_batch_elt]; | ||
|
||
// Here we do the work which has to be done once per batch element. | ||
// i.e. (1) finding the total weight. (2) If using binary search, | ||
// precompute the partial sums of the weights. | ||
|
||
float total_weight = 0; | ||
for (int64_t i_bin = 0; i_bin < n_bins; ++i_bin) { | ||
total_weight += weight_a[i_bin]; | ||
#ifdef USE_BINARY_SEARCH | ||
partial_sums[i_bin] = total_weight; | ||
#endif | ||
} | ||
total_weight += eps; | ||
|
||
for (int64_t i_sample = 0; i_sample < n_samples; ++i_sample) { | ||
// Here we are taking a single random quantile (which is stored | ||
// in *output_p) and using it to make a single sample, which we | ||
// write back to the same location. First we find which bin | ||
// the quantile lives in, either by binary search in the | ||
// precomputed partial sums, or by scanning through the weights. | ||
|
||
float uniform = total_weight * *output_p; | ||
#ifdef USE_BINARY_SEARCH | ||
int64_t i_bin = std::lower_bound( | ||
partial_sums.begin(), --partial_sums.end(), uniform) - | ||
partial_sums.begin(); | ||
if (i_bin > 0) { | ||
uniform -= partial_sums[i_bin - 1]; | ||
} | ||
#else | ||
int64_t i_bin = 0; | ||
while (i_bin + 1 < n_bins && uniform > weight_a[i_bin]) { | ||
uniform -= weight_a[i_bin]; | ||
++i_bin; | ||
} | ||
#endif | ||
|
||
// Now i_bin identifies the bin the quantile lives in, we use | ||
// straight line interpolation to find the position of the | ||
// quantile within the bin, and write it to *output_p. | ||
|
||
float bin_start = bin_a[i_bin]; | ||
float bin_end = bin_a[i_bin + 1]; | ||
float bin_weight = weight_a[i_bin]; | ||
float output_value = bin_start; | ||
if (uniform > bin_weight) { | ||
output_value = bin_end; | ||
} else if (bin_weight > eps) { | ||
output_value += (uniform / bin_weight) * (bin_end - bin_start); | ||
} | ||
*output_p = output_value; | ||
++output_p; | ||
} | ||
} | ||
} | ||
|
||
} // anonymous namespace | ||
|
||
void SamplePdfCpu( | ||
const torch::Tensor& bins, | ||
const torch::Tensor& weights, | ||
const torch::Tensor& outputs, | ||
float eps) { | ||
const int64_t batch_size = bins.size(0); | ||
const int64_t max_threads = std::min(4, at::get_num_threads()); | ||
const int64_t n_threads = std::min(max_threads, batch_size); | ||
if (batch_size == 0) { | ||
return; | ||
} | ||
|
||
// SamplePdfCpu_worker does the work of this function. We send separate ranges | ||
// of batch elements to that function in nThreads-1 separate threads. | ||
|
||
std::vector<std::thread> threads; | ||
threads.reserve(n_threads - 1); | ||
const int64_t batch_elements_per_thread = 1 + (batch_size - 1) / n_threads; | ||
int64_t start_batch = 0; | ||
for (int iThread = 0; iThread < n_threads - 1; ++iThread) { | ||
threads.emplace_back( | ||
SamplePdfCpu_worker, | ||
bins, | ||
weights, | ||
outputs, | ||
eps, | ||
start_batch, | ||
start_batch + batch_elements_per_thread); | ||
start_batch += batch_elements_per_thread; | ||
} | ||
|
||
// The remaining batch elements are calculated in this threads. If nThreads is | ||
// 1 then all the work happens in this line. | ||
SamplePdfCpu_worker(bins, weights, outputs, eps, start_batch, batch_size); | ||
for (auto&& thread : threads) { | ||
thread.join(); | ||
} | ||
} |
Oops, something went wrong.