Skip to content

Commit

Permalink
Non square image rasterization for meshes
Browse files Browse the repository at this point in the history
Summary:
There are a couple of options for supporting non square images:
1) NDC stays at [-1, 1] in both directions with the distance calculations all modified by (W/H). There are a lot of distance based calculations (e.g. triangle areas for barycentric coordinates etc) so this requires changes in many places.
2) NDC is scaled by (W/H) so the smallest side has [-1, 1]. In this case none of the distance calculations need to be updated and only the pixel to NDC calculation needs to be modified.

I decided to go with option 2 after trying option 1!

API Changes:
- Image size can now be specified optionally as a tuple

TODO:
- add a benchmark test for the non square case.

Reviewed By: jcjohnson

Differential Revision: D24404975

fbshipit-source-id: 545efb67c822d748ec35999b35762bce58db2cf4
  • Loading branch information
nikhilaravi authored and facebook-github-bot committed Dec 9, 2020
1 parent 0216e46 commit d07307a
Show file tree
Hide file tree
Showing 13 changed files with 774 additions and 115 deletions.
17 changes: 17 additions & 0 deletions docs/notes/renderer_getting_started.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,19 @@ While we tried to emulate several aspects of OpenGL, there are differences in th

---

### Rasterizing Non Square Images

To rasterize an image where H != W, you can specify the `image_size` in the `RasterizationSettings` as a tuple of (H, W).

The aspect ratio needs special consideration. There are two aspect ratios to be aware of:
- the aspect ratio of each pixel
- the aspect ratio of the output image
In the cameras e.g. `FoVPerspectiveCameras`, the `aspect_ratio` argument can be used to set the pixel aspect ratio. In the rasterizer, we assume square pixels, but variable image aspect ratio (i.e rectangle images).

In most cases you will want to set the camera aspect ratio to 1.0 (i.e. square pixels) and only vary the `image_size` in the `RasterizationSettings`(i.e. the output image dimensions in pixels).

---

### The pulsar backend

Since v0.3, [pulsar](https://arxiv.org/abs/2004.07484) can be used as a backend for point-rendering. It has a focus on efficiency, which comes with pros and cons: it is highly optimized and all rendering stages are integrated in the CUDA kernels. This leads to significantly higher speed and better scaling behavior. We use it at Facebook Reality Labs to render and optimize scenes with millions of spheres in resolutions up to 4K. You can find a runtime comparison plot below (settings: `bin_size=None`, `points_per_pixel=5`, `image_size=1024`, `radius=1e-2`, `composite_params.radius=1e-4`; benchmarked on an RTX 2070 GPU).
Expand All @@ -75,6 +88,8 @@ For mesh texturing we offer several options (in `pytorch3d/renderer/mesh/texturi

<img src="assets/texturing.jpg" width="1000">

---

### A simple renderer

A renderer in PyTorch3D is composed of a **rasterizer** and a **shader**. Create a renderer in a few simple steps:
Expand Down Expand Up @@ -108,6 +123,8 @@ renderer = MeshRenderer(
)
```

---

### A custom shader

Shaders are the most flexible part of the PyTorch3D rendering API. We have created some examples of shaders in `shaders.py` but this is a non exhaustive set.
Expand Down
129 changes: 83 additions & 46 deletions pytorch3d/csrc/rasterize_meshes/rasterize_meshes.cu
Original file line number Diff line number Diff line change
Expand Up @@ -234,8 +234,8 @@ __global__ void RasterizeMeshesNaiveCudaKernel(
const int xi = W - 1 - pix_idx % W;

// screen coordinates to ndc coordiantes of pixel.
const float xf = PixToNdc(xi, W);
const float yf = PixToNdc(yi, H);
const float xf = PixToNonSquareNdc(xi, W, H);
const float yf = PixToNonSquareNdc(yi, H, W);
const float2 pxy = make_float2(xf, yf);

// For keeping track of the K closest points we want a data structure
Expand All @@ -262,6 +262,7 @@ __global__ void RasterizeMeshesNaiveCudaKernel(
for (int f = face_start_idx; f < face_stop_idx; ++f) {
// Check if the pixel pxy is inside the face bounding box and if it is,
// update q, q_size, q_max_z and q_max_idx in place.

CheckPixelInsideFace(
face_verts,
f,
Expand All @@ -280,6 +281,7 @@ __global__ void RasterizeMeshesNaiveCudaKernel(
// TODO: make sorting an option as only top k is needed, not sorted values.
BubbleSort(q, q_size);
int idx = n * H * W * K + pix_idx * K;

for (int k = 0; k < q_size; ++k) {
face_idxs[idx + k] = q[k].idx;
zbuf[idx + k] = q[k].z;
Expand All @@ -296,7 +298,7 @@ RasterizeMeshesNaiveCuda(
const at::Tensor& face_verts,
const at::Tensor& mesh_to_faces_packed_first_idx,
const at::Tensor& num_faces_per_mesh,
const int image_size,
const std::tuple<int, int> image_size,
const float blur_radius,
const int num_closest,
const bool perspective_correct,
Expand Down Expand Up @@ -332,8 +334,8 @@ RasterizeMeshesNaiveCuda(
cudaStream_t stream = at::cuda::getCurrentCUDAStream();

const int N = num_faces_per_mesh.size(0); // batch size.
const int H = image_size; // Assume square images.
const int W = image_size;
const int H = std::get<0>(image_size);
const int W = std::get<1>(image_size);
const int K = num_closest;

auto long_opts = num_faces_per_mesh.options().dtype(at::kLong);
Expand Down Expand Up @@ -405,8 +407,8 @@ __global__ void RasterizeMeshesBackwardCudaKernel(
const int yi = H - 1 - pix_idx / W;
const int xi = W - 1 - pix_idx % W;

const float xf = PixToNdc(xi, W);
const float yf = PixToNdc(yi, H);
const float xf = PixToNonSquareNdc(xi, W, H);
const float yf = PixToNonSquareNdc(yi, H, W);
const float2 pxy = make_float2(xf, yf);

// Loop over all the faces for this pixel.
Expand Down Expand Up @@ -589,12 +591,25 @@ __global__ void RasterizeMeshesCoarseCudaKernel(
int* bin_faces) {
extern __shared__ char sbuf[];
const int M = max_faces_per_bin;
const int num_bins = 1 + (W - 1) / bin_size; // Integer divide round up
const float half_pix = 1.0f / W; // Size of half a pixel in NDC units
// Integer divide round up
const int num_bins_x = 1 + (W - 1) / bin_size;
const int num_bins_y = 1 + (H - 1) / bin_size;

// NDC range depends on the ratio of W/H
// The shorter side from (H, W) is given an NDC range of 2.0 and
// the other side is scaled by the ratio of H:W.
const float NDC_x_half_range = NonSquareNdcRange(W, H) / 2.0f;
const float NDC_y_half_range = NonSquareNdcRange(H, W) / 2.0f;

// Size of half a pixel in NDC units is the NDC half range
// divided by the corresponding image dimension
const float half_pix_x = NDC_x_half_range / W;
const float half_pix_y = NDC_y_half_range / H;

// This is a boolean array of shape (num_bins, num_bins, chunk_size)
// stored in shared memory that will track whether each point in the chunk
// falls into each bin of the image.
BitMask binmask((unsigned int*)sbuf, num_bins, num_bins, chunk_size);
BitMask binmask((unsigned int*)sbuf, num_bins_y, num_bins_x, chunk_size);

// Have each block handle a chunk of faces
const int chunks_per_batch = 1 + (F - 1) / chunk_size;
Expand Down Expand Up @@ -641,21 +656,24 @@ __global__ void RasterizeMeshesCoarseCudaKernel(
}

// Brute-force search over all bins; TODO(T54294966) something smarter.
for (int by = 0; by < num_bins; ++by) {
for (int by = 0; by < num_bins_y; ++by) {
// Y coordinate of the top and bottom of the bin.
// PixToNdc gives the location of the center of each pixel, so we
// need to add/subtract a half pixel to get the true extent of the bin.
// Reverse ordering of Y axis so that +Y is upwards in the image.
const float bin_y_min = PixToNdc(by * bin_size, H) - half_pix;
const float bin_y_max = PixToNdc((by + 1) * bin_size - 1, H) + half_pix;
const float bin_y_min =
PixToNonSquareNdc(by * bin_size, H, W) - half_pix_y;
const float bin_y_max =
PixToNonSquareNdc((by + 1) * bin_size - 1, H, W) + half_pix_y;
const bool y_overlap = (ymin <= bin_y_max) && (bin_y_min < ymax);

for (int bx = 0; bx < num_bins; ++bx) {
for (int bx = 0; bx < num_bins_x; ++bx) {
// X coordinate of the left and right of the bin.
// Reverse ordering of x axis so that +X is left.
const float bin_x_max =
PixToNdc((bx + 1) * bin_size - 1, W) + half_pix;
const float bin_x_min = PixToNdc(bx * bin_size, W) - half_pix;
PixToNonSquareNdc((bx + 1) * bin_size - 1, W, H) + half_pix_x;
const float bin_x_min =
PixToNonSquareNdc(bx * bin_size, W, H) - half_pix_x;

const bool x_overlap = (xmin <= bin_x_max) && (bin_x_min < xmax);
if (y_overlap && x_overlap) {
Expand All @@ -668,12 +686,13 @@ __global__ void RasterizeMeshesCoarseCudaKernel(
// Now we have processed every face in the current chunk. We need to
// count the number of faces in each bin so we can write the indices
// out to global memory. We have each thread handle a different bin.
for (int byx = threadIdx.x; byx < num_bins * num_bins; byx += blockDim.x) {
const int by = byx / num_bins;
const int bx = byx % num_bins;
for (int byx = threadIdx.x; byx < num_bins_y * num_bins_x;
byx += blockDim.x) {
const int by = byx / num_bins_x;
const int bx = byx % num_bins_x;
const int count = binmask.count(by, bx);
const int faces_per_bin_idx =
batch_idx * num_bins * num_bins + by * num_bins + bx;
batch_idx * num_bins_y * num_bins_x + by * num_bins_x + bx;

// This atomically increments the (global) number of faces found
// in the current bin, and gets the previous value of the counter;
Expand All @@ -683,8 +702,8 @@ __global__ void RasterizeMeshesCoarseCudaKernel(

// Now loop over the binmask and write the active bits for this bin
// out to bin_faces.
int next_idx = batch_idx * num_bins * num_bins * M + by * num_bins * M +
bx * M + start;
int next_idx = batch_idx * num_bins_y * num_bins_x * M +
by * num_bins_x * M + bx * M + start;
for (int f = 0; f < chunk_size; ++f) {
if (binmask.get(by, bx, f)) {
// TODO(T54296346) find the correct method for handling errors in
Expand All @@ -703,7 +722,7 @@ at::Tensor RasterizeMeshesCoarseCuda(
const at::Tensor& face_verts,
const at::Tensor& mesh_to_face_first_idx,
const at::Tensor& num_faces_per_mesh,
const int image_size,
const std::tuple<int, int> image_size,
const float blur_radius,
const int bin_size,
const int max_faces_per_bin) {
Expand All @@ -725,29 +744,35 @@ at::Tensor RasterizeMeshesCoarseCuda(
at::cuda::CUDAGuard device_guard(face_verts.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();

const int W = image_size;
const int H = image_size;
const int H = std::get<0>(image_size);
const int W = std::get<1>(image_size);

const int F = face_verts.size(0);
const int N = num_faces_per_mesh.size(0);
const int num_bins = 1 + (image_size - 1) / bin_size; // Divide round up.
const int M = max_faces_per_bin;

if (num_bins >= kMaxFacesPerBin) {
// Integer divide round up.
const int num_bins_y = 1 + (H - 1) / bin_size;
const int num_bins_x = 1 + (W - 1) / bin_size;

if (num_bins_y >= kMaxFacesPerBin || num_bins_x >= kMaxFacesPerBin) {
std::stringstream ss;
ss << "Got " << num_bins << "; that's too many!";
ss << "In Coarse Rasterizer got num_bins_y: " << num_bins_y
<< ", num_bins_x: " << num_bins_x << ", "
<< "; that's too many!";
AT_ERROR(ss.str());
}
auto opts = num_faces_per_mesh.options().dtype(at::kInt);
at::Tensor faces_per_bin = at::zeros({N, num_bins, num_bins}, opts);
at::Tensor bin_faces = at::full({N, num_bins, num_bins, M}, -1, opts);
at::Tensor faces_per_bin = at::zeros({N, num_bins_y, num_bins_x}, opts);
at::Tensor bin_faces = at::full({N, num_bins_y, num_bins_x, M}, -1, opts);

if (bin_faces.numel() == 0) {
AT_CUDA_CHECK(cudaGetLastError());
return bin_faces;
}

const int chunk_size = 512;
const size_t shared_size = num_bins * num_bins * chunk_size / 8;
const size_t shared_size = num_bins_y * num_bins_x * chunk_size / 8;
const size_t blocks = 64;
const size_t threads = 512;

Expand Down Expand Up @@ -782,7 +807,8 @@ __global__ void RasterizeMeshesFineCudaKernel(
const bool clip_barycentric_coords,
const bool cull_backfaces,
const int N,
const int B,
const int BH,
const int BW,
const int M,
const int H,
const int W,
Expand All @@ -793,7 +819,7 @@ __global__ void RasterizeMeshesFineCudaKernel(
float* bary // (N, S, S, K, 3)
) {
// This can be more than S^2 if S % bin_size != 0
int num_pixels = N * B * B * bin_size * bin_size;
int num_pixels = N * BH * BW * bin_size * bin_size;
int num_threads = gridDim.x * blockDim.x;
int tid = blockIdx.x * blockDim.x + threadIdx.x;

Expand All @@ -803,20 +829,26 @@ __global__ void RasterizeMeshesFineCudaKernel(
// into the same bin; this should give them coalesced memory reads when
// they read from faces and bin_faces.
int i = pid;
const int n = i / (B * B * bin_size * bin_size);
i %= B * B * bin_size * bin_size;
const int by = i / (B * bin_size * bin_size);
i %= B * bin_size * bin_size;
const int n = i / (BH * BW * bin_size * bin_size);
i %= BH * BW * bin_size * bin_size;
// bin index y
const int by = i / (BW * bin_size * bin_size);
i %= BW * bin_size * bin_size;
// bin index y
const int bx = i / (bin_size * bin_size);
// pixel within the bin
i %= bin_size * bin_size;

// Pixel x, y indices
const int yi = i / bin_size + by * bin_size;
const int xi = i % bin_size + bx * bin_size;

if (yi >= H || xi >= W)
continue;

const float xf = PixToNdc(xi, W);
const float yf = PixToNdc(yi, H);
const float xf = PixToNonSquareNdc(xi, W, H);
const float yf = PixToNonSquareNdc(yi, H, W);

const float2 pxy = make_float2(xf, yf);

// This part looks like the naive rasterization kernel, except we use
Expand All @@ -828,7 +860,7 @@ __global__ void RasterizeMeshesFineCudaKernel(
float q_max_z = -1000;
int q_max_idx = -1;
for (int m = 0; m < M; m++) {
const int f = bin_faces[n * B * B * M + by * B * M + bx * M + m];
const int f = bin_faces[n * BH * BW * M + by * BW * M + bx * M + m];
if (f < 0) {
continue; // bin_faces uses -1 as a sentinal value.
}
Expand Down Expand Up @@ -858,7 +890,8 @@ __global__ void RasterizeMeshesFineCudaKernel(
// in the image +Y is pointing up and +X is pointing left.
const int yidx = H - 1 - yi;
const int xidx = W - 1 - xi;
const int pix_idx = n * H * W * K + yidx * H * K + xidx * K;

const int pix_idx = n * H * W * K + yidx * W * K + xidx * K;
for (int k = 0; k < q_size; k++) {
face_idxs[pix_idx + k] = q[k].idx;
zbuf[pix_idx + k] = q[k].z;
Expand All @@ -874,7 +907,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
RasterizeMeshesFineCuda(
const at::Tensor& face_verts,
const at::Tensor& bin_faces,
const int image_size,
const std::tuple<int, int> image_size,
const float blur_radius,
const int bin_size,
const int faces_per_pixel,
Expand All @@ -897,12 +930,15 @@ RasterizeMeshesFineCuda(
at::cuda::CUDAGuard device_guard(face_verts.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();

// bin_faces shape (N, BH, BW, M)
const int N = bin_faces.size(0);
const int B = bin_faces.size(1);
const int BH = bin_faces.size(1);
const int BW = bin_faces.size(2);
const int M = bin_faces.size(3);
const int K = faces_per_pixel;
const int H = image_size; // Assume square images only.
const int W = image_size;

const int H = std::get<0>(image_size);
const int W = std::get<1>(image_size);

if (K > kMaxPointsPerPixel) {
AT_ERROR("Must have num_closest <= 150");
Expand Down Expand Up @@ -932,7 +968,8 @@ RasterizeMeshesFineCuda(
clip_barycentric_coords,
cull_backfaces,
N,
B,
BH,
BW,
M,
H,
W,
Expand Down
Loading

0 comments on commit d07307a

Please sign in to comment.