Accumulate points (#4)
Summary:
Code for accumulating points in the z-buffer in three ways:
1. weighted sum
2. normalised weighted sum
3. alpha compositing
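
For orientation, here is a rough per-pixel sketch of what the three schemes compute, written in plain PyTorch. It is not the code added by this commit: the alpha-compositing loop follows the formula documented in pytorch3d/csrc/compositing/alpha_composite.h below, the helper names are hypothetical, and the epsilon in the normalised variant is an assumption added for numerical safety.

import torch

# Per-pixel sketch: `features` is (C, K) for the K z-sorted points hitting one
# pixel, `weights` / `alphas` is (K,). Hypothetical helpers, not this commit's API.
def weighted_sum(features, weights):
    return features @ weights                            # sum_k w_k * f_k

def norm_weighted_sum(features, weights, eps=1e-10):
    return features @ weights / (weights.sum() + eps)    # divide by the total weight

def alpha_composite(features, alphas):
    out = torch.zeros(features.shape[0])
    cum_alpha = 1.0
    for k in range(alphas.shape[0]):                     # front to back
        out = out + features[:, k] * cum_alpha * alphas[k]
        cum_alpha = cum_alpha * (1.0 - alphas[k])
    return out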

Pull Request resolved: fairinternal/pytorch3d#4

Reviewed By: nikhilaravi

Differential Revision: D20522422

Pulled By: gkioxari

fbshipit-source-id: 5023baa05f15e338f3821ef08f5552c2dcbfc06c
oawiles authored and facebook-github-bot committed Mar 19, 2020
1 parent 5218f45 commit 5359977
Showing 21 changed files with 2,466 additions and 4 deletions.
6 changes: 3 additions & 3 deletions docs/tutorials/deform_source_mesh_to_target_mesh.ipynb
@@ -673,9 +673,9 @@
"provenance": []
},
"kernelspec": {
- "display_name": "pytorch3d (local)",
+ "display_name": "Python 3",
"language": "python",
- "name": "pytorch3d_local"
+ "name": "python3"
},
"language_info": {
"codemirror_mode": {
@@ -687,7 +687,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.7.5+"
+ "version": "3.7.6"
},
"widgets": {
"application/vnd.jupyter.widget-state+json": {
303 changes: 303 additions & 0 deletions docs/tutorials/render_coloured_points.ipynb

Large diffs are not rendered by default.

187 changes: 187 additions & 0 deletions pytorch3d/csrc/compositing/alpha_composite.cu
@@ -0,0 +1,187 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

#include <torch/extension.h>

#include <cuda.h>
#include <cuda_runtime.h>

#include <stdio.h>
#include <vector>

// TODO(gkioxari) support all data types once AtomicAdd supports doubles.
// Currently, support is for floats only.
__global__ void alphaCompositeCudaForwardKernel(
// clang-format off
torch::PackedTensorAccessor<float, 4, torch::RestrictPtrTraits, size_t> result,
const torch::PackedTensorAccessor<float, 2, torch::RestrictPtrTraits, size_t> features,
const torch::PackedTensorAccessor<float, 4, torch::RestrictPtrTraits, size_t> alphas,
const torch::PackedTensorAccessor<int64_t, 4, torch::RestrictPtrTraits, size_t> points_idx) {
// clang-format on
const int64_t batch_size = result.size(0);
const int64_t C = features.size(0);
const int64_t H = points_idx.size(2);
const int64_t W = points_idx.size(3);

// Get the batch and index
const int batch = blockIdx.x;

const int num_pixels = C * W * H;
const int num_threads = gridDim.y * blockDim.x;
const int tid = blockIdx.y * blockDim.x + threadIdx.x;

// Iterate over each feature in each pixel
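// Grid-stride loop: each thread handles entries tid, tid + num_threads, ...,
// so the launched threads together cover all C * H * W entries of this image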
for (int pid = tid; pid < num_pixels; pid += num_threads) {
int ch = pid / (W * H);
int j = (pid % (W * H)) / H;
int i = (pid % (W * H)) % H;

// Alpha-composite the z-buffer entries for this pixel, front to back
float cum_alpha = 1.;
// Iterate through the closest K points for this pixel
for (int k = 0; k < points_idx.size(1); ++k) {
int n_idx = points_idx[batch][k][j][i];

// Sentinel value is -1 indicating no point overlaps the pixel
if (n_idx < 0) {
continue;
}

float alpha = alphas[batch][k][j][i];
// TODO(gkioxari) It might be more efficient to have threads write in a
// local variable, and move atomicAdd outside of the loop such that
// atomicAdd is executed once per thread.
atomicAdd(
&result[batch][ch][j][i], features[ch][n_idx] * cum_alpha * alpha);
cum_alpha = cum_alpha * (1 - alpha);
}
}
}

// TODO(gkioxari) support all data types once AtomicAdd supports doubles.
// Currently, support is for floats only.
__global__ void alphaCompositeCudaBackwardKernel(
// clang-format off
torch::PackedTensorAccessor<float, 2, torch::RestrictPtrTraits, size_t> grad_features,
torch::PackedTensorAccessor<float, 4, torch::RestrictPtrTraits, size_t> grad_alphas,
const torch::PackedTensorAccessor<float, 4, torch::RestrictPtrTraits, size_t> grad_outputs,
const torch::PackedTensorAccessor<float, 2, torch::RestrictPtrTraits, size_t> features,
const torch::PackedTensorAccessor<float, 4, torch::RestrictPtrTraits, size_t> alphas,
const torch::PackedTensorAccessor<int64_t, 4, torch::RestrictPtrTraits, size_t> points_idx) {
// clang-format on
const int64_t batch_size = points_idx.size(0);
const int64_t C = features.size(0);
const int64_t H = points_idx.size(2);
const int64_t W = points_idx.size(3);

// Get the batch and index
const int batch = blockIdx.x;

const int num_pixels = C * W * H;
const int num_threads = gridDim.y * blockDim.x;
const int tid = blockIdx.y * blockDim.x + threadIdx.x;

// Parallelize over each feature in each pixel in images of size H * W,
// for each image in the batch of size batch_size
for (int pid = tid; pid < num_pixels; pid += num_threads) {
int ch = pid / (W * H);
int j = (pid % (W * H)) / H;
int i = (pid % (W * H)) % H;

// Alpha-composite the z-buffer entries for this pixel, front to back
float cum_alpha = 1.;
// Iterate through the closest K points for this pixel
for (int k = 0; k < points_idx.size(1); ++k) {
int n_idx = points_idx[batch][k][j][i];

// Sentinel value is -1 indicating no point overlaps the pixel
if (n_idx < 0) {
continue;
}
float alpha = alphas[batch][k][j][i];

// TODO(gkioxari) It might be more efficient to have threads write in a
// local variable, and move atomicAdd outside of the loop such that
// atomicAdd is executed once per thread.
atomicAdd(
&grad_alphas[batch][k][j][i],
cum_alpha * features[ch][n_idx] * grad_outputs[batch][ch][j][i]);
atomicAdd(
&grad_features[ch][n_idx],
cum_alpha * alpha * grad_outputs[batch][ch][j][i]);

// Propagate gradients to the alphas of the points in front of point k
// (t < k), since their (1 - alpha_t) factors appear in cum_alpha
for (int t = 0; t < k; ++t) {
int t_idx = points_idx[batch][t][j][i];
// Sentinel value is -1, indicating no point overlaps this pixel
if (t_idx < 0) {
continue;
}
float alpha_tvalue = alphas[batch][t][j][i];
// TODO(gkioxari) It might be more efficient to have threads write in a
// local variable, and move atomicAdd outside of the loop such that
// atomicAdd is executed once per thread.
atomicAdd(
&grad_alphas[batch][t][j][i],
-grad_outputs[batch][ch][j][i] * features[ch][n_idx] * cum_alpha *
alpha / (1 - alpha_tvalue));
}

cum_alpha = cum_alpha * (1 - alphas[batch][k][j][i]);
}
}
}

torch::Tensor alphaCompositeCudaForward(
const torch::Tensor& features,
const torch::Tensor& alphas,
const torch::Tensor& points_idx) {
const int64_t batch_size = points_idx.size(0);
const int64_t C = features.size(0);
const int64_t H = points_idx.size(2);
const int64_t W = points_idx.size(3);

auto result = torch::zeros({batch_size, C, H, W}, features.options());

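// Launch one row of blocks per batch element (blockIdx.x); gridDim.y * blockDim.x
// threads then stride over the C * H * W entries of that image (see the kernel).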
const dim3 threadsPerBlock(64);
const dim3 numBlocks(batch_size, 1024 / batch_size + 1);

// TODO(gkioxari) add AT_DISPATCH_FLOATING_TYPES once atomicAdd supports
// doubles. Currently, support is for floats only.
alphaCompositeCudaForwardKernel<<<numBlocks, threadsPerBlock>>>(
// clang-format off
result.packed_accessor<float, 4, torch::RestrictPtrTraits, size_t>(),
features.packed_accessor<float, 2, torch::RestrictPtrTraits, size_t>(),
alphas.packed_accessor<float, 4, torch::RestrictPtrTraits, size_t>(),
points_idx.packed_accessor<int64_t, 4, torch::RestrictPtrTraits, size_t>());
// clang-format on

return result;
}

std::tuple<torch::Tensor, torch::Tensor> alphaCompositeCudaBackward(
const torch::Tensor& grad_outputs,
const torch::Tensor& features,
const torch::Tensor& alphas,
const torch::Tensor& points_idx) {
auto grad_features = torch::zeros_like(features);
auto grad_alphas = torch::zeros_like(alphas);

const int64_t bs = alphas.size(0);

const dim3 threadsPerBlock(64);
const dim3 numBlocks(bs, 1024 / bs + 1);

// TODO(gkioxari) add AT_DISPATCH_FLOATING_TYPES once atomicAdd supports
// doubles. Currently, support is for floats only.
alphaCompositeCudaBackwardKernel<<<numBlocks, threadsPerBlock>>>(
// clang-format off
grad_features.packed_accessor<float, 2, torch::RestrictPtrTraits, size_t>(),
grad_alphas.packed_accessor<float, 4, torch::RestrictPtrTraits, size_t>(),
grad_outputs.packed_accessor<float, 4, torch::RestrictPtrTraits, size_t>(),
features.packed_accessor<float, 2, torch::RestrictPtrTraits, size_t>(),
alphas.packed_accessor<float, 4, torch::RestrictPtrTraits, size_t>(),
points_idx.packed_accessor<int64_t, 4, torch::RestrictPtrTraits, size_t>());
// clang-format on

return std::make_tuple(grad_features, grad_alphas);
}
110 changes: 110 additions & 0 deletions pytorch3d/csrc/compositing/alpha_composite.h
@@ -0,0 +1,110 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

#include <torch/extension.h>
#include "pytorch3d_cutils.h"

#include <vector>

// Perform alpha compositing of points in a z-buffer.
//
// Inputs:
// features: FloatTensor of shape (C, P) giving the features of each point,
// where C is the feature dimension and P the number of points.
// alphas: FloatTensor of shape (N, points_per_pixel, W, W) where
// points_per_pixel is the number of points in the z-buffer
// sorted in z-order, and W is the image size.
// points_idx: LongTensor (int64) of shape (N, points_per_pixel, W, W) giving
// the indices of the nearest points at each pixel, sorted in z-order;
// a value of -1 marks a pixel that no point overlaps.
// Returns:
// weighted_fs: FloatTensor of shape (N, C, W, W) giving the accumulated
// feature at each pixel. Concretely, it gives:
// weighted_fs[b,c,i,j] = sum_k cum_alpha_k *
// features[c,points_idx[b,k,i,j]]
// where cum_alpha_k =
// alphas[b,k,i,j] * prod_l=0..k-1 (1 - alphas[b,l,i,j])

// CUDA declarations
#ifdef WITH_CUDA
torch::Tensor alphaCompositeCudaForward(
const torch::Tensor& features,
const torch::Tensor& alphas,
const torch::Tensor& points_idx);

std::tuple<torch::Tensor, torch::Tensor> alphaCompositeCudaBackward(
const torch::Tensor& grad_outputs,
const torch::Tensor& features,
const torch::Tensor& alphas,
const torch::Tensor& points_idx);
#endif

// C++ declarations
torch::Tensor alphaCompositeCpuForward(
const torch::Tensor& features,
const torch::Tensor& alphas,
const torch::Tensor& points_idx);

std::tuple<torch::Tensor, torch::Tensor> alphaCompositeCpuBackward(
const torch::Tensor& grad_outputs,
const torch::Tensor& features,
const torch::Tensor& alphas,
const torch::Tensor& points_idx);

torch::Tensor alphaCompositeForward(
torch::Tensor& features,
torch::Tensor& alphas,
torch::Tensor& points_idx) {
features = features.contiguous();
alphas = alphas.contiguous();
points_idx = points_idx.contiguous();

if (features.type().is_cuda()) {
#ifdef WITH_CUDA
CHECK_CONTIGUOUS_CUDA(features);
CHECK_CONTIGUOUS_CUDA(alphas);
CHECK_CONTIGUOUS_CUDA(points_idx);
#else
AT_ERROR("Not compiled with GPU support");
#endif
return alphaCompositeCudaForward(features, alphas, points_idx);
} else {
CHECK_CONTIGUOUS(features);
CHECK_CONTIGUOUS(alphas);
CHECK_CONTIGUOUS(points_idx);

return alphaCompositeCpuForward(features, alphas, points_idx);
}
}

std::tuple<torch::Tensor, torch::Tensor> alphaCompositeBackward(
torch::Tensor& grad_outputs,
torch::Tensor& features,
torch::Tensor& alphas,
torch::Tensor& points_idx) {
grad_outputs = grad_outputs.contiguous();
features = features.contiguous();
alphas = alphas.contiguous();
points_idx = points_idx.contiguous();

if (grad_outputs.type().is_cuda()) {
#ifdef WITH_CUDA
CHECK_CONTIGUOUS_CUDA(grad_outputs);
CHECK_CONTIGUOUS_CUDA(features);
CHECK_CONTIGUOUS_CUDA(alphas);
CHECK_CONTIGUOUS_CUDA(points_idx);
#else
AT_ERROR("Not compiled with GPU support");
#endif

return alphaCompositeCudaBackward(
grad_outputs, features, alphas, points_idx);
} else {
CHECK_CONTIGUOUS(grad_outputs);
CHECK_CONTIGUOUS(features);
CHECK_CONTIGUOUS(alphas);
CHECK_CONTIGUOUS(points_idx);

return alphaCompositeCpuBackward(
grad_outputs, features, alphas, points_idx);
}
}
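
As a sanity check on the formula documented in alpha_composite.h, a naive PyTorch reference of the forward pass could look like the sketch below. This is an illustration only, assuming square W x W images and the -1 sentinel described above; it is not part of this commit.

import torch

def alpha_composite_reference(features, alphas, points_idx):
    # features: (C, P), alphas: (N, K, W, W), points_idx: (N, K, W, W), -1 = no point
    C = features.shape[0]
    N, K, H, W = points_idx.shape
    result = torch.zeros(N, C, H, W, dtype=features.dtype)
    for b in range(N):
        for y in range(H):
            for x in range(W):
                cum_alpha = 1.0
                for k in range(K):            # z-sorted, nearest point first
                    idx = int(points_idx[b, k, y, x])
                    if idx < 0:               # sentinel: no point covers this pixel
                        continue
                    a = float(alphas[b, k, y, x])
                    result[b, :, y, x] += features[:, idx] * cum_alpha * a
                    cum_alpha *= 1.0 - a
    return result

The CUDA kernels above perform the same per-pixel loop, parallelised over batch, channel and pixel.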