diff --git a/pytorch3d/csrc/ext.cpp b/pytorch3d/csrc/ext.cpp index 8a5a3d54d..4fa78defa 100644 --- a/pytorch3d/csrc/ext.cpp +++ b/pytorch3d/csrc/ext.cpp @@ -26,6 +26,7 @@ #include "point_mesh/point_mesh_cuda.h" #include "rasterize_meshes/rasterize_meshes.h" #include "rasterize_points/rasterize_points.h" +#include "sample_pdf/sample_pdf.h" PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("face_areas_normals_forward", &FaceAreasNormalsForward); @@ -83,6 +84,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("point_face_array_dist_forward", &PointFaceArrayDistanceForward); m.def("point_face_array_dist_backward", &PointFaceArrayDistanceBackward); + // Sample PDF + m.def("sample_pdf", &SamplePdf); + // Pulsar. #ifdef PULSAR_LOGGING_ENABLED c10::ShowLogInfoToStderr(); diff --git a/pytorch3d/csrc/sample_pdf/sample_pdf.cu b/pytorch3d/csrc/sample_pdf/sample_pdf.cu new file mode 100644 index 000000000..f39b93f45 --- /dev/null +++ b/pytorch3d/csrc/sample_pdf/sample_pdf.cu @@ -0,0 +1,153 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include + +// There is no intermediate memory, so no reason not to have blocksize=32. +// 256 is a reasonable number of blocks. + +// DESIGN +// We exploit the fact that n_samples is not tiny. +// A chunk of work is T*blocksize many samples from +// a single batch element. +// For each batch element there will be +// chunks_per_batch = 1 + (n_samples-1)/(T*blocksize) of them. +// The number of potential chunks to do is +// n_chunks = chunks_per_batch * n_batches. +// These chunks are divided among the gridSize-many blocks. +// In block b, we work on chunks b, b+gridSize, b+2*gridSize etc . 
+// In chunk i, we work on batch_element i/chunks_per_batch
+// on samples starting from (i%chunks_per_batch) * (T*blocksize)
+
+// BEGIN HYPOTHETICAL
+// Another option (not implemented) if batch_size was always large
+// would be as follows.
+
+// A chunk of work is S samples from each of blocksize-many
+// batch elements.
+// For each batch element there will be
+// chunks_per_batch = (1+(n_samples-1)/S) of them.
+// The number of potential chunks to do is
+// n_chunks = chunks_per_batch * (1+(n_batches-1)/blocksize)
+// These chunks are divided among the gridSize-many blocks.
+// In block b, we work on chunks b, b+gridSize, b+2*gridSize etc.
+// In chunk i, we work on samples starting from S*(i%chunks_per_batch)
+// on batch elements starting from blocksize*(i/chunks_per_batch).
+// END HYPOTHETICAL
+
+// Draws the samples for every batch element, in place.
+//
+// Args:
+//   bins: bin edges, read as batch_size rows of (n_bins + 1) floats.
+//   weights: per-bin weights, read as batch_size rows of n_bins floats.
+//   outputs: batch_size rows of n_samples floats. On entry each value is the
+//       quantile to draw; it is overwritten with the resulting sample.
+//   eps: added to the total weight of each row; a bin whose weight is not
+//       greater than eps is never interpolated within.
+//   T: number of samples each thread produces per chunk.
+//   batch_size, n_bins, n_samples: the sizes described above. n_bins must be
+//       at least 1 because bin 0 is always read.
+//
+// Block b handles chunks b, b + gridDim.x, b + 2 * gridDim.x, ...
+__global__ void SamplePdfCudaKernel(
+    const float* __restrict__ bins,
+    const float* __restrict__ weights,
+    float* __restrict__ outputs,
+    float eps,
+    const int T,
+    const int64_t batch_size,
+    const int64_t n_bins,
+    const int64_t n_samples) {
+  const int64_t chunks_per_batch = 1 + (n_samples - 1) / (T * blockDim.x);
+  const int64_t n_chunks = chunks_per_batch * batch_size;
+
+  for (int64_t i_chunk = blockIdx.x; i_chunk < n_chunks; i_chunk += gridDim.x) {
+    // Loop over the chunks.
+    int64_t i_batch_element = i_chunk / chunks_per_batch;
+    int64_t sample_start = (i_chunk % chunks_per_batch) * (T * blockDim.x);
+    const float* const weight_startp = weights + n_bins * i_batch_element;
+    const float* const bin_startp = bins + (1 + n_bins) * i_batch_element;
+
+    // Each chunk looks at a single batch element, so we do the preprocessing
+    // which depends on the batch element, namely finding the total weight.
+    // Identical work is being done in sync here by every thread of the block.
+    float total_weight = eps;
+    for (int64_t i_bin = 0; i_bin < n_bins; ++i_bin) {
+      total_weight += weight_startp[i_bin];
+    }
+
+    float* const output_startp =
+        outputs + n_samples * i_batch_element + sample_start;
+
+    for (int t = 0; t < T; ++t) {
+      // Loop over T, which is the number of samples each thread makes within
+      // the chunk.
+      const int64_t i_sample_within_chunk = threadIdx.x + t * blockDim.x;
+      if (sample_start + i_sample_within_chunk >= n_samples) {
+        // The last chunk of a batch element may be only partly filled, so
+        // this thread has no sample to make in this iteration.
+        continue;
+      }
+      // output_startp[i_sample_within_chunk] contains the quantile we (i.e.
+      // this thread) are calculating.
+      float uniform = total_weight * output_startp[i_sample_within_chunk];
+      int64_t i_bin = 0;
+      // We find the bin containing the quantile by walking along the weights.
+      // This loop must be thread dependent. I.e. the whole warp will wait until
+      // every thread has found the bin for its quantile.
+      // It may be best to write it differently.
+      while (i_bin + 1 < n_bins && uniform > weight_startp[i_bin]) {
+        uniform -= weight_startp[i_bin];
+        ++i_bin;
+      }
+
+      // Now we know which bin to look in, we use linear interpolation
+      // to find the location of the quantile within the bin, and
+      // write the answer back.
+      float bin_start = bin_startp[i_bin];
+      float bin_end = bin_startp[i_bin + 1];
+      float bin_weight = weight_startp[i_bin];
+      float output_value = bin_start;
+      if (uniform > bin_weight) {
+        output_value = bin_end;
+      } else if (bin_weight > eps) {
+        output_value += (uniform / bin_weight) * (bin_end - bin_start);
+      }
+      output_startp[i_sample_within_chunk] = output_value;
+    }
+  }
+}
+
+// Host entry point: checks that the tensors share a GPU and a dtype, chooses
+// the launch configuration and runs SamplePdfCudaKernel on the current stream.
+// `outputs` is overwritten in place and its data pointer is used directly, so
+// it must already be contiguous; `bins` and `weights` are made contiguous here.
+void SamplePdfCuda(
+    const at::Tensor& bins,
+    const at::Tensor& weights,
+    const at::Tensor& outputs,
+    float eps) {
+  // Check inputs are on the same device
+  at::TensorArg bins_t{bins, "bins", 1}, weights_t{weights, "weights", 2},
+      outputs_t{outputs, "outputs", 3};
+  at::CheckedFrom c = "SamplePdfCuda";
+  at::checkAllSameGPU(c, {bins_t, weights_t, outputs_t});
+  at::checkAllSameType(c, {bins_t, weights_t, outputs_t});
+
+  // Set the device for the kernel launch based on the device of the input
+  at::cuda::CUDAGuard device_guard(bins.device());
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  const int64_t batch_size = bins.size(0);
+  const int64_t n_bins = weights.size(1);
+  const int64_t n_samples = outputs.size(1);
+
+  // Nothing to draw. This must be handled before the launch: an empty batch
+  // would give blocks == 0, which is an invalid launch configuration.
+  if (batch_size == 0 || n_samples == 0) {
+    return;
+  }
+  // The kernel always reads weight 0 and bin edges 0 and 1 of a row.
+  TORCH_CHECK(n_bins > 0, "SamplePdfCuda: at least one bin is required.");
+
+  const int64_t threads = 32;
+  const int64_t T = n_samples <= threads ? 1 : 2;
+  const int64_t chunks_per_batch = 1 + (n_samples - 1) / (T * threads);
+  const int64_t n_chunks = chunks_per_batch * batch_size;
+
+  const int64_t max_blocks = 1024;
+  const int64_t blocks = n_chunks < max_blocks ? n_chunks : max_blocks;
+
+  SamplePdfCudaKernel<<<blocks, threads, 0, stream>>>(
+      bins.contiguous().data_ptr<float>(),
+      weights.contiguous().data_ptr<float>(),
+      outputs.data_ptr<float>(), // Checked contiguous in header file.
+      eps,
+      T,
+      batch_size,
+      n_bins,
+      n_samples);
+
+  AT_CUDA_CHECK(cudaGetLastError());
+}
diff --git a/pytorch3d/csrc/sample_pdf/sample_pdf.h b/pytorch3d/csrc/sample_pdf/sample_pdf.h
new file mode 100644
index 000000000..af963b2f7
--- /dev/null
+++ b/pytorch3d/csrc/sample_pdf/sample_pdf.h
@@ -0,0 +1,74 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once +#include +#include +#include +#include "utils/pytorch3d_cutils.h" + +// **************************************************************************** +// * SamplePdf * +// **************************************************************************** + +// Samples probability density functions defined by bin edges `bins` and +// the non-negative per-bin probabilities `weights`. + +// Args: +// bins: FloatTensor of shape `(batch_size, n_bins+1)` denoting the edges +// of the sampling bins. + +// weights: FloatTensor of shape `(batch_size, n_bins)` containing +// non-negative numbers representing the probability of sampling the +// corresponding bin. + +// uniforms: The quantiles to draw, FloatTensor of shape +// `(batch_size, n_samples)`. + +// outputs: On call, this contains the quantiles to draw. It is overwritten +// with the drawn samples. FloatTensor of shape +// `(batch_size, n_samples)`, where `n_samples` samples are drawn from each +// distribution. + +// eps: A constant preventing division by zero in case empty bins are +// present. 
+
+// Not differentiable
+
+#ifdef WITH_CUDA
+// CUDA implementation; defined in sample_pdf.cu.
+void SamplePdfCuda(
+    const torch::Tensor& bins,
+    const torch::Tensor& weights,
+    const torch::Tensor& outputs,
+    float eps);
+#endif
+
+// CPU implementation; defined in sample_pdf_cpu.cpp.
+void SamplePdfCpu(
+    const torch::Tensor& bins,
+    const torch::Tensor& weights,
+    const torch::Tensor& outputs,
+    float eps);
+
+// Validates the arguments and dispatches to the CUDA or CPU implementation
+// depending on the device of `bins`. `outputs` must be contiguous because it
+// is overwritten in place through its raw data pointer.
+inline void SamplePdf(
+    const torch::Tensor& bins,
+    const torch::Tensor& weights,
+    const torch::Tensor& outputs,
+    float eps) {
+  // Both implementations index raw memory using these sizes, so inconsistent
+  // shapes would read or write out of bounds instead of raising.
+  TORCH_CHECK(
+      bins.dim() == 2 && weights.dim() == 2 && outputs.dim() == 2,
+      "SamplePdf: bins, weights and outputs must be 2-dimensional.");
+  TORCH_CHECK(
+      weights.size(0) == bins.size(0) && outputs.size(0) == bins.size(0),
+      "SamplePdf: bins, weights and outputs must have the same batch size.");
+  TORCH_CHECK(
+      weights.size(1) >= 1 && bins.size(1) == weights.size(1) + 1,
+      "SamplePdf: expected n_bins >= 1 weights and n_bins + 1 bin edges.");
+  if (bins.is_cuda()) {
+#ifdef WITH_CUDA
+    CHECK_CUDA(weights);
+    CHECK_CONTIGUOUS_CUDA(outputs);
+    SamplePdfCuda(bins, weights, outputs, eps);
+    return;
+#else
+    AT_ERROR("Not compiled with GPU support.");
+#endif
+  }
+  // The CPU implementation dereferences host pointers for all three tensors.
+  TORCH_CHECK(
+      !weights.is_cuda() && !outputs.is_cuda(),
+      "SamplePdf: weights and outputs must be CPU tensors when bins is.");
+  CHECK_CONTIGUOUS(outputs);
+  SamplePdfCpu(bins, weights, outputs, eps);
+}
diff --git a/pytorch3d/csrc/sample_pdf/sample_pdf_cpu.cpp b/pytorch3d/csrc/sample_pdf/sample_pdf_cpu.cpp
new file mode 100644
index 000000000..c33005206
--- /dev/null
+++ b/pytorch3d/csrc/sample_pdf/sample_pdf_cpu.cpp
@@ -0,0 +1,141 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+// NOTE(review): the include targets were lost when this patch was extracted;
+// the four below are restored from what the file uses (torch::Tensor,
+// std::lower_bound / std::min, std::thread, std::vector) - confirm.
+#include <torch/extension.h>
+#include <algorithm>
+#include <thread>
+#include <vector>
+
+// If the number of bins is the typical 64, it is
+// quicker to use binary search than linear scan.
+// With more bins, it is more important.
+// There is no equivalent CUDA implementation yet.
+#define USE_BINARY_SEARCH
+
+namespace {
+// This worker function does the job of SamplePdf but only on
+// batch elements in [start_batch, end_batch).
+// `outputs` holds the quantiles on entry and is overwritten with the samples.
+// It is addressed through a raw pointer with row stride n_samples, so it must
+// be contiguous. The tensors must be 2-dimensional float tensors with at
+// least one bin whenever the range is non-empty.
+void SamplePdfCpu_worker(
+    const torch::Tensor& bins,
+    const torch::Tensor& weights,
+    const torch::Tensor& outputs,
+    float eps,
+    int64_t start_batch,
+    int64_t end_batch) {
+  const int64_t n_bins = weights.size(1);
+  const int64_t n_samples = outputs.size(1);
+
+  auto bins_a = bins.accessor<float, 2>();
+  auto weights_a = weights.accessor<float, 2>();
+  float* __restrict__ output_p =
+      outputs.data_ptr<float>() + start_batch * n_samples;
+
+#ifdef USE_BINARY_SEARCH
+  std::vector<float> partial_sums(n_bins);
+#endif
+
+  for (int64_t i_batch_elt = start_batch; i_batch_elt < end_batch;
+       ++i_batch_elt) {
+    auto bin_a = bins_a[i_batch_elt];
+    auto weight_a = weights_a[i_batch_elt];
+
+    // Here we do the work which has to be done once per batch element.
+    // i.e. (1) finding the total weight. (2) If using binary search,
+    // precompute the partial sums of the weights.
+
+    float total_weight = 0;
+    for (int64_t i_bin = 0; i_bin < n_bins; ++i_bin) {
+      total_weight += weight_a[i_bin];
+#ifdef USE_BINARY_SEARCH
+      partial_sums[i_bin] = total_weight;
+#endif
+    }
+    total_weight += eps;
+
+    for (int64_t i_sample = 0; i_sample < n_samples; ++i_sample) {
+      // Here we are taking a single random quantile (which is stored
+      // in *output_p) and using it to make a single sample, which we
+      // write back to the same location. First we find which bin
+      // the quantile lives in, either by binary search in the
+      // precomputed partial sums, or by scanning through the weights.
+
+      float uniform = total_weight * *output_p;
+#ifdef USE_BINARY_SEARCH
+      // The last partial sum is excluded from the search so that the result
+      // is always a valid bin index, i.e. at most n_bins - 1.
+      int64_t i_bin = std::lower_bound(
+                          partial_sums.begin(),
+                          partial_sums.end() - 1,
+                          uniform) -
+          partial_sums.begin();
+      if (i_bin > 0) {
+        uniform -= partial_sums[i_bin - 1];
+      }
+#else
+      int64_t i_bin = 0;
+      while (i_bin + 1 < n_bins && uniform > weight_a[i_bin]) {
+        uniform -= weight_a[i_bin];
+        ++i_bin;
+      }
+#endif
+
+      // Now i_bin identifies the bin the quantile lives in, we use
+      // straight line interpolation to find the position of the
+      // quantile within the bin, and write it to *output_p.
+
+      float bin_start = bin_a[i_bin];
+      float bin_end = bin_a[i_bin + 1];
+      float bin_weight = weight_a[i_bin];
+      float output_value = bin_start;
+      if (uniform > bin_weight) {
+        output_value = bin_end;
+      } else if (bin_weight > eps) {
+        output_value += (uniform / bin_weight) * (bin_end - bin_start);
+      }
+      *output_p = output_value;
+      ++output_p;
+    }
+  }
+}
+
+} // anonymous namespace
+
+// CPU implementation of SamplePdf. Splits the batch into contiguous ranges,
+// runs SamplePdfCpu_worker on each range (at most 4 threads, including the
+// calling one) and returns once every range has been processed.
+void SamplePdfCpu(
+    const torch::Tensor& bins,
+    const torch::Tensor& weights,
+    const torch::Tensor& outputs,
+    float eps) {
+  const int64_t batch_size = bins.size(0);
+  if (batch_size == 0) {
+    return;
+  }
+
+  // Validate here, on the calling thread and before any thread is spawned:
+  // the worker's accessors throw on a wrong dtype or rank, and an exception
+  // that escapes a std::thread (or unwinds past joinable threads) terminates
+  // the process.
+  TORCH_CHECK(
+      bins.scalar_type() == at::ScalarType::Float &&
+          weights.scalar_type() == at::ScalarType::Float &&
+          outputs.scalar_type() == at::ScalarType::Float,
+      "SamplePdfCpu: bins, weights and outputs must be float32 tensors.");
+  TORCH_CHECK(
+      bins.dim() == 2 && weights.dim() == 2 && outputs.dim() == 2,
+      "SamplePdfCpu: bins, weights and outputs must be 2-dimensional.");
+  TORCH_CHECK(
+      weights.size(1) >= 1, "SamplePdfCpu: at least one bin is required.");
+
+  const int64_t max_threads = std::min(4, at::get_num_threads());
+  const int64_t n_threads = std::min(max_threads, batch_size);
+
+  // SamplePdfCpu_worker does the work of this function. We send separate ranges
+  // of batch elements to that function in n_threads-1 separate threads.
+
+  std::vector<std::thread> threads;
+  threads.reserve(n_threads - 1);
+  const int64_t batch_elements_per_thread = 1 + (batch_size - 1) / n_threads;
+  int64_t start_batch = 0;
+  for (int64_t i_thread = 0; i_thread < n_threads - 1; ++i_thread) {
+    // Clamp the range. batch_elements_per_thread is rounded up, so the spawned
+    // threads alone can cover more than batch_size elements: batch_size = 5
+    // with 4 threads gives ranges of 2, and the third would be [4, 6).
+    const int64_t end_batch =
+        std::min(start_batch + batch_elements_per_thread, batch_size);
+    threads.emplace_back(
+        SamplePdfCpu_worker,
+        bins,
+        weights,
+        outputs,
+        eps,
+        start_batch,
+        end_batch);
+    start_batch = end_batch;
+  }
+
+  // The remaining batch elements are calculated in this thread. If n_threads
+  // is 1 then all the work happens in this line.
+  SamplePdfCpu_worker(bins, weights, outputs, eps, start_batch, batch_size);
+  for (auto&& thread : threads) {
+    thread.join();
+  }
+}
diff --git a/pytorch3d/renderer/implicit/sample_pdf.py b/pytorch3d/renderer/implicit/sample_pdf.py
index e3d7fedfc..d986b6829 100644
--- a/pytorch3d/renderer/implicit/sample_pdf.py
+++ b/pytorch3d/renderer/implicit/sample_pdf.py
@@ -6,6 +6,62 @@
 
 import torch
+from pytorch3d import _C
+
+
+def sample_pdf(
+    bins: torch.Tensor,
+    weights: torch.Tensor,
+    n_samples: int,
+    det: bool = False,
+    eps: float = 1e-5,
+) -> torch.Tensor:
+    """
+    Samples probability density functions defined by bin edges `bins` and
+    the non-negative per-bin probabilities `weights`.
+
+    Args:
+        bins: Tensor of shape `(..., n_bins+1)` denoting the edges of the sampling bins.
+        weights: Tensor of shape `(..., n_bins)` containing non-negative numbers
+            representing the probability of sampling the corresponding bin.
+        n_samples: The number of samples to draw from each set of bins.
+        det: If `False`, the sampling is random. `True` yields deterministic
+            uniformly-spaced sampling from the inverse cumulative density function.
+        eps: A constant preventing division by zero in case empty bins are present.
+
+    Returns:
+        samples: Tensor of shape `(..., n_samples)` containing `n_samples` samples
+            drawn from each probability distribution.
+
+    Refs:
+        [1] https://github.com/bmild/nerf/blob/55d8b00244d7b5178f4d003526ab6667683c9da9/run_nerf_helpers.py#L183  # noqa E501
+    """
+    if torch.is_grad_enabled() and (bins.requires_grad or weights.requires_grad):
+        raise NotImplementedError("sample_pdf differentiability.")
+    if weights.min() <= -eps:
+        raise ValueError("Negative weights provided.")
+    batch_shape = bins.shape[:-1]
+    n_bins = weights.shape[-1]
+    if n_bins + 1 != bins.shape[-1] or weights.shape[:-1] != batch_shape:
+        shapes = f"{bins.shape}{weights.shape}"
+        raise ValueError("Inconsistent shapes of bins and weights: " + shapes)
+    output_shape = batch_shape + (n_samples,)
+
+    if det:
+        u = torch.linspace(0.0, 1.0, n_samples, device=bins.device, dtype=torch.float32)
+        output = u.expand(output_shape).contiguous()
+    else:
+        output = torch.rand(output_shape, dtype=torch.float32, device=bins.device)
+
+    # pyre-fixme[16]: Module `pytorch3d` has no attribute `_C`.
+    _C.sample_pdf(
+        bins.reshape(-1, n_bins + 1),
+        weights.reshape(-1, n_bins),
+        output.reshape(-1, n_samples),
+        eps,
+    )
+
+    return output
 
 
 def sample_pdf_python(
@@ -16,6 +72,12 @@ def sample_pdf_python(
     eps: float = 1e-5,
 ) -> torch.Tensor:
     """
+    This is a pure python implementation of the `sample_pdf` function.
+    It may be faster than sample_pdf when the number of bins is very large,
+    because it behaves as O(batchsize * [n_bins + log(n_bins) * n_samples] )
+    whereas sample_pdf behaves as O(batchsize * n_bins * n_samples).
+    For 64 bins sample_pdf is much faster.
+
     Samples probability density functions defined by bin edges `bins` and
     the non-negative per-bin probabilities `weights`.
 
diff --git a/tests/bm_sample_pdf.py b/tests/bm_sample_pdf.py
index b56e62ccc..59a0f4cea 100644
--- a/tests/bm_sample_pdf.py
+++ b/tests/bm_sample_pdf.py
@@ -12,7 +12,7 @@
 
 
 def bm_sample_pdf() -> None:
-    backends = ["python_cuda", "python_cpu"]
+    backends = ["python_cuda", "cuda", "python_cpu", "cpu"]
 
     kwargs_list = []
     sample_counts = [64]
diff --git a/tests/test_sample_pdf.py b/tests/test_sample_pdf.py
index ed76cd839..4d1cf9ac5 100644
--- a/tests/test_sample_pdf.py
+++ b/tests/test_sample_pdf.py
@@ -5,10 +5,11 @@
 # LICENSE file in the root directory of this source tree.
 
 import unittest
+from itertools import product
 
 import torch
 from common_testing import TestCaseMixin
-from pytorch3d.renderer.implicit.sample_pdf import sample_pdf_python
+from pytorch3d.renderer.implicit.sample_pdf import sample_pdf, sample_pdf_python
 
 
 class TestSamplePDF(TestCaseMixin, unittest.TestCase):
@@ -23,9 +24,61 @@ def test_single_bin(self):
         calc = torch.linspace(17, 18, 100).expand(5, -1)
         self.assertClose(output, calc)
 
+    def test_simple_det(self):
+        for n_bins, n_samples, batch in product(
+            [7, 20], [2, 7, 31, 32, 33], [(), (1, 4), (31,), (32,), (33,)]
+        ):
+            weights = torch.rand(size=(batch + (n_bins,)))
+            bins = torch.cumsum(torch.rand(size=(batch + (n_bins + 1,))), dim=-1)
+            python = sample_pdf_python(bins, weights, n_samples, det=True)
+
+            cpp = sample_pdf(bins, weights, n_samples, det=True)
+            self.assertClose(cpp, python, atol=2e-3)
+
+            nthreads = torch.get_num_threads()
+            torch.set_num_threads(1)
+            try:
+                cpp_singlethread = sample_pdf(bins, weights, n_samples, det=True)
+                self.assertClose(cpp_singlethread, python, atol=2e-3)
+            finally:
+                torch.set_num_threads(nthreads)
+
+            device = torch.device("cuda:0")
+            cuda = sample_pdf(
+                bins.to(device), weights.to(device), n_samples, det=True
+            ).cpu()
+
+            self.assertClose(cuda, python, atol=2e-3)
+
+    def test_rand_cpu(self):
+        n_bins, n_samples, batch_size = 11, 17, 9
+        weights = torch.rand(size=(batch_size, n_bins))
+        bins = torch.cumsum(torch.rand(size=(batch_size, n_bins + 1)), dim=-1)
+        torch.manual_seed(1)
+        python = sample_pdf_python(bins, weights, n_samples)
+        torch.manual_seed(1)
+        cpp = sample_pdf(bins, weights, n_samples)
+
+        self.assertClose(cpp, python, atol=2e-3)
+
+    def test_rand_nogap(self):
+        # Case where random is actually deterministic
+        weights = torch.FloatTensor([0, 10, 0])
+        bins = torch.FloatTensor([0, 10, 10, 25])
+        n_samples = 8
+        predicted = torch.full((n_samples,), 10.0)
+        python = sample_pdf_python(bins, weights, n_samples)
+        self.assertClose(python, predicted)
+        cpp = sample_pdf(bins, weights, n_samples)
+        self.assertClose(cpp, predicted)
+
+        device = torch.device("cuda:0")
+        cuda = sample_pdf(bins.to(device), weights.to(device), n_samples).cpu()
+        self.assertClose(cuda, predicted)
+
     @staticmethod
     def bm_fn(*, backend: str, n_samples, batch_size, n_bins):
-        f = sample_pdf_python
+        f = sample_pdf_python if "python" in backend else sample_pdf
         weights = torch.rand(size=(batch_size, n_bins))
         bins = torch.cumsum(torch.rand(size=(batch_size, n_bins + 1)), dim=-1)