Summary: Code for accumulating points in the z-buffer in three ways:
1. weighted sum
2. normalised weighted sum
3. alpha compositing

Pull Request resolved: fairinternal/pytorch3d#4
Reviewed By: nikhilaravi
Differential Revision: D20522422
Pulled By: gkioxari
fbshipit-source-id: 5023baa05f15e338f3821ef08f5552c2dcbfc06c
Commit 5359977, 1 parent 5218f45. Showing 21 changed files with 2,466 additions and 4 deletions.
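The summary names three ways of accumulating z-buffer points into a pixel. Below is a minimal per-pixel reference in plain C++; it is illustrative only and not code from this commit. The alpha-compositing recurrence matches the formula documented in the header further down, while the weighted-sum and normalised weighted-sum definitions are assumed from their names.

#include <cstddef>
#include <vector>

// Per-pixel reference for the three accumulation modes, for a single feature
// channel. `alphas` and `feats` hold the K z-sorted z-buffer entries for one
// pixel (empty slots already removed). Names here are illustrative.
struct Accumulated {
  float weighted_sum;      // sum_k alpha_k * f_k (assumed definition)
  float norm_weighted_sum; // weighted_sum / sum_k alpha_k (assumed definition)
  float alpha_composite;   // sum_k f_k * alpha_k * prod_{l<k} (1 - alpha_l)
};

Accumulated composite_pixel(const std::vector<float>& alphas,
                            const std::vector<float>& feats) {
  Accumulated out{0.f, 0.f, 0.f};
  float alpha_total = 0.f;
  float cum_alpha = 1.f; // remaining transmittance, front to back
  for (std::size_t k = 0; k < alphas.size(); ++k) {
    out.weighted_sum += alphas[k] * feats[k];
    alpha_total += alphas[k];
    out.alpha_composite += feats[k] * cum_alpha * alphas[k];
    cum_alpha *= (1.f - alphas[k]);
  }
  out.norm_weighted_sum =
      alpha_total > 0.f ? out.weighted_sum / alpha_total : 0.f;
  return out;
}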
@@ -0,0 +1,187 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

#include <torch/extension.h>

#include <cuda.h>
#include <cuda_runtime.h>

#include <stdio.h>
#include <vector>

// TODO(gkioxari) support all data types once AtomicAdd supports doubles.
// Currently, support is for floats only.
__global__ void alphaCompositeCudaForwardKernel(
    // clang-format off
    torch::PackedTensorAccessor<float, 4, torch::RestrictPtrTraits, size_t> result,
    const torch::PackedTensorAccessor<float, 2, torch::RestrictPtrTraits, size_t> features,
    const torch::PackedTensorAccessor<float, 4, torch::RestrictPtrTraits, size_t> alphas,
    const torch::PackedTensorAccessor<int64_t, 4, torch::RestrictPtrTraits, size_t> points_idx) {
  // clang-format on
  const int64_t batch_size = result.size(0);
  const int64_t C = features.size(0);
  const int64_t H = points_idx.size(2);
  const int64_t W = points_idx.size(3);

  // Get the batch and index
  const int batch = blockIdx.x;

  const int num_pixels = C * W * H;
  const int num_threads = gridDim.y * blockDim.x;
  const int tid = blockIdx.y * blockDim.x + threadIdx.x;

  // Iterate over each feature in each pixel
  for (int pid = tid; pid < num_pixels; pid += num_threads) {
    int ch = pid / (W * H);
    int j = (pid % (W * H)) / H;
    int i = (pid % (W * H)) % H;

    // alphacomposite the different values
    float cum_alpha = 1.;
    // Iterate through the closest K points for this pixel
    for (int k = 0; k < points_idx.size(1); ++k) {
      int n_idx = points_idx[batch][k][j][i];

      // Sentinel value is -1 indicating no point overlaps the pixel
      if (n_idx < 0) {
        continue;
      }

      float alpha = alphas[batch][k][j][i];
      // TODO(gkioxari) It might be more efficient to have threads write in a
      // local variable, and move atomicAdd outside of the loop such that
      // atomicAdd is executed once per thread.
      atomicAdd(
          &result[batch][ch][j][i], features[ch][n_idx] * cum_alpha * alpha);
      cum_alpha = cum_alpha * (1 - alpha);
    }
  }
}

// TODO(gkioxari) support all data types once AtomicAdd supports doubles.
// Currently, support is for floats only.
__global__ void alphaCompositeCudaBackwardKernel(
    // clang-format off
    torch::PackedTensorAccessor<float, 2, torch::RestrictPtrTraits, size_t> grad_features,
    torch::PackedTensorAccessor<float, 4, torch::RestrictPtrTraits, size_t> grad_alphas,
    const torch::PackedTensorAccessor<float, 4, torch::RestrictPtrTraits, size_t> grad_outputs,
    const torch::PackedTensorAccessor<float, 2, torch::RestrictPtrTraits, size_t> features,
    const torch::PackedTensorAccessor<float, 4, torch::RestrictPtrTraits, size_t> alphas,
    const torch::PackedTensorAccessor<int64_t, 4, torch::RestrictPtrTraits, size_t> points_idx) {
  // clang-format on
  const int64_t batch_size = points_idx.size(0);
  const int64_t C = features.size(0);
  const int64_t H = points_idx.size(2);
  const int64_t W = points_idx.size(3);

  // Get the batch and index
  const int batch = blockIdx.x;

  const int num_pixels = C * W * H;
  const int num_threads = gridDim.y * blockDim.x;
  const int tid = blockIdx.y * blockDim.x + threadIdx.x;

  // Parallelize over each feature in each pixel in images of size H * W,
  // for each image in the batch of size batch_size
  for (int pid = tid; pid < num_pixels; pid += num_threads) {
    int ch = pid / (W * H);
    int j = (pid % (W * H)) / H;
    int i = (pid % (W * H)) % H;

    // alphacomposite the different values
    float cum_alpha = 1.;
    // Iterate through the closest K points for this pixel
    for (int k = 0; k < points_idx.size(1); ++k) {
      int n_idx = points_idx[batch][k][j][i];

      // Sentinel value is -1 indicating no point overlaps the pixel
      if (n_idx < 0) {
        continue;
      }
      float alpha = alphas[batch][k][j][i];

      // TODO(gkioxari) It might be more efficient to have threads write in a
      // local variable, and move atomicAdd outside of the loop such that
      // atomicAdd is executed once per thread.
      atomicAdd(
          &grad_alphas[batch][k][j][i],
          cum_alpha * features[ch][n_idx] * grad_outputs[batch][ch][j][i]);
      atomicAdd(
          &grad_features[ch][n_idx],
          cum_alpha * alpha * grad_outputs[batch][ch][j][i]);

      // Iterate over all (K-1) nearest points to update gradient
      for (int t = 0; t < k; ++t) {
        int t_idx = points_idx[batch][t][j][i];
        // Sentinel value is -1, indicating no point overlaps this pixel
        if (t_idx < 0) {
          continue;
        }
        float alpha_tvalue = alphas[batch][t][j][i];
        // TODO(gkioxari) It might be more efficient to have threads write in a
        // local variable, and move atomicAdd outside of the loop such that
        // atomicAdd is executed once per thread.
        atomicAdd(
            &grad_alphas[batch][t][j][i],
            -grad_outputs[batch][ch][j][i] * features[ch][n_idx] * cum_alpha *
                alpha / (1 - alpha_tvalue));
      }

      cum_alpha = cum_alpha * (1 - alphas[batch][k][j][i]);
    }
  }
}

torch::Tensor alphaCompositeCudaForward(
    const torch::Tensor& features,
    const torch::Tensor& alphas,
    const torch::Tensor& points_idx) {
  const int64_t batch_size = points_idx.size(0);
  const int64_t C = features.size(0);
  const int64_t H = points_idx.size(2);
  const int64_t W = points_idx.size(3);

  auto result = torch::zeros({batch_size, C, H, W}, features.options());

  const dim3 threadsPerBlock(64);
  const dim3 numBlocks(batch_size, 1024 / batch_size + 1);

  // TODO(gkioxari) add AT_DISPATCH_FLOATING_TYPES once atomicAdd supports
  // doubles. Currently, support is for floats only.
  alphaCompositeCudaForwardKernel<<<numBlocks, threadsPerBlock>>>(
      // clang-format off
      result.packed_accessor<float, 4, torch::RestrictPtrTraits, size_t>(),
      features.packed_accessor<float, 2, torch::RestrictPtrTraits, size_t>(),
      alphas.packed_accessor<float, 4, torch::RestrictPtrTraits, size_t>(),
      points_idx.packed_accessor<int64_t, 4, torch::RestrictPtrTraits, size_t>());
  // clang-format on

  return result;
}

std::tuple<torch::Tensor, torch::Tensor> alphaCompositeCudaBackward(
    const torch::Tensor& grad_outputs,
    const torch::Tensor& features,
    const torch::Tensor& alphas,
    const torch::Tensor& points_idx) {
  auto grad_features = torch::zeros_like(features);
  auto grad_alphas = torch::zeros_like(alphas);

  const int64_t bs = alphas.size(0);

  const dim3 threadsPerBlock(64);
  const dim3 numBlocks(bs, 1024 / bs + 1);

  // TODO(gkioxari) add AT_DISPATCH_FLOATING_TYPES once atomicAdd supports
  // doubles. Currently, support is for floats only.
  alphaCompositeCudaBackwardKernel<<<numBlocks, threadsPerBlock>>>(
      // clang-format off
      grad_features.packed_accessor<float, 2, torch::RestrictPtrTraits, size_t>(),
      grad_alphas.packed_accessor<float, 4, torch::RestrictPtrTraits, size_t>(),
      grad_outputs.packed_accessor<float, 4, torch::RestrictPtrTraits, size_t>(),
      features.packed_accessor<float, 2, torch::RestrictPtrTraits, size_t>(),
      alphas.packed_accessor<float, 4, torch::RestrictPtrTraits, size_t>(),
      points_idx.packed_accessor<int64_t, 4, torch::RestrictPtrTraits, size_t>());
  // clang-format on

  return std::make_tuple(grad_features, grad_alphas);
}
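A note on the indexing used by both kernels: each thread walks pid in a grid-stride loop and decomposes it into a channel/row/column triple (ch, j, i). The following host-side check is illustrative only and not part of the commit; it verifies that this decomposition visits every triple exactly once for the square images (H == W) assumed by these ops.

#include <cassert>
#include <vector>

int main() {
  const int C = 3, H = 4, W = 4; // small square image, as in the kernels
  std::vector<int> hits(C * H * W, 0);
  for (int pid = 0; pid < C * W * H; ++pid) {
    // Same decomposition as in the kernels above.
    int ch = pid / (W * H);
    int j = (pid % (W * H)) / H;
    int i = (pid % (W * H)) % H;
    ++hits[(ch * H + j) * W + i];
  }
  // Every (channel, row, column) triple is hit exactly once.
  for (int h : hits) assert(h == 1);
  return 0;
}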
@@ -0,0 +1,110 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

#include <torch/extension.h>
#include "pytorch3d_cutils.h"

#include <vector>

// Perform alpha compositing of points in a z-buffer.
//
// Inputs:
//     features: FloatTensor of shape (C, P) which gives the features
//             of each point where C is the size of the feature and
//             P the number of points.
//     alphas: FloatTensor of shape (N, points_per_pixel, W, W) where
//             points_per_pixel is the number of points in the z-buffer
//             sorted in z-order, and W is the image size.
//     points_idx: IntTensor of shape (N, points_per_pixel, W, W) giving the
//             indices of the nearest points at each pixel, sorted in z-order.
// Returns:
//     weighted_fs: FloatTensor of shape (N, C, W, W) giving the accumulated
//             feature for each point. Concretely, it gives:
//                  weighted_fs[b,c,i,j] = sum_k cum_alpha_k *
//                     features[c,points_idx[b,k,i,j]]
//                  where cum_alpha_k =
//                     alphas[b,k,i,j] * prod_l=0..k-1 (1 - alphas[b,l,i,j])

// CUDA declarations
#ifdef WITH_CUDA
torch::Tensor alphaCompositeCudaForward(
    const torch::Tensor& features,
    const torch::Tensor& alphas,
    const torch::Tensor& points_idx);

std::tuple<torch::Tensor, torch::Tensor> alphaCompositeCudaBackward(
    const torch::Tensor& grad_outputs,
    const torch::Tensor& features,
    const torch::Tensor& alphas,
    const torch::Tensor& points_idx);
#endif

// C++ declarations
torch::Tensor alphaCompositeCpuForward(
    const torch::Tensor& features,
    const torch::Tensor& alphas,
    const torch::Tensor& points_idx);

std::tuple<torch::Tensor, torch::Tensor> alphaCompositeCpuBackward(
    const torch::Tensor& grad_outputs,
    const torch::Tensor& features,
    const torch::Tensor& alphas,
    const torch::Tensor& points_idx);

torch::Tensor alphaCompositeForward(
    torch::Tensor& features,
    torch::Tensor& alphas,
    torch::Tensor& points_idx) {
  features = features.contiguous();
  alphas = alphas.contiguous();
  points_idx = points_idx.contiguous();

  if (features.type().is_cuda()) {
#ifdef WITH_CUDA
    CHECK_CONTIGUOUS_CUDA(features);
    CHECK_CONTIGUOUS_CUDA(alphas);
    CHECK_CONTIGUOUS_CUDA(points_idx);
#else
    AT_ERROR("Not compiled with GPU support");
#endif
    return alphaCompositeCudaForward(features, alphas, points_idx);
  } else {
    CHECK_CONTIGUOUS(features);
    CHECK_CONTIGUOUS(alphas);
    CHECK_CONTIGUOUS(points_idx);

    return alphaCompositeCpuForward(features, alphas, points_idx);
  }
}

std::tuple<torch::Tensor, torch::Tensor> alphaCompositeBackward(
    torch::Tensor& grad_outputs,
    torch::Tensor& features,
    torch::Tensor& alphas,
    torch::Tensor& points_idx) {
  grad_outputs = grad_outputs.contiguous();
  features = features.contiguous();
  alphas = alphas.contiguous();
  points_idx = points_idx.contiguous();

  if (grad_outputs.type().is_cuda()) {
#ifdef WITH_CUDA
    CHECK_CONTIGUOUS_CUDA(grad_outputs);
    CHECK_CONTIGUOUS_CUDA(features);
    CHECK_CONTIGUOUS_CUDA(alphas);
    CHECK_CONTIGUOUS_CUDA(points_idx);
#else
    AT_ERROR("Not compiled with GPU support");
#endif

    return alphaCompositeCudaBackward(
        grad_outputs, features, alphas, points_idx);
  } else {
    CHECK_CONTIGUOUS(grad_outputs);
    CHECK_CONTIGUOUS(features);
    CHECK_CONTIGUOUS(alphas);
    CHECK_CONTIGUOUS(points_idx);

    return alphaCompositeCpuBackward(
        grad_outputs, features, alphas, points_idx);
  }
}
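The CPU implementations are only declared in this file; their definitions live elsewhere in the commit and are not shown here. For reference, a minimal CPU forward pass that follows the documented formula could look roughly like the sketch below. The function name is assumed for illustration; this is not the commit's actual alphaCompositeCpuForward.

// Illustrative sketch of a CPU forward pass for alpha compositing.
#include <torch/extension.h>

torch::Tensor alphaCompositeCpuForwardSketch(
    const torch::Tensor& features,    // (C, P) float
    const torch::Tensor& alphas,      // (N, K, W, W) float
    const torch::Tensor& points_idx)  // (N, K, W, W) int64
{
  const int64_t N = points_idx.size(0), K = points_idx.size(1);
  const int64_t H = points_idx.size(2), W = points_idx.size(3);
  const int64_t C = features.size(0);
  auto result = torch::zeros({N, C, H, W}, features.options());

  auto f = features.accessor<float, 2>();
  auto a = alphas.accessor<float, 4>();
  auto idx = points_idx.accessor<int64_t, 4>();
  auto out = result.accessor<float, 4>();

  for (int64_t b = 0; b < N; ++b)
    for (int64_t c = 0; c < C; ++c)
      for (int64_t i = 0; i < H; ++i)
        for (int64_t j = 0; j < W; ++j) {
          float cum_alpha = 1.f; // remaining transmittance, front to back
          for (int64_t k = 0; k < K; ++k) {
            const int64_t p = idx[b][k][i][j];
            if (p < 0) continue; // -1 marks an empty z-buffer slot
            out[b][c][i][j] += f[c][p] * cum_alpha * a[b][k][i][j];
            cum_alpha *= (1.f - a[b][k][i][j]);
          }
        }
  return result;
}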