From 12f20d799e06e4ef16dd32686bf5a33e87c56c91 Mon Sep 17 00:00:00 2001
From: David Novotny
Date: Wed, 9 Feb 2022 12:48:47 -0800
Subject: [PATCH] Convert from PyTorch3D NDC coordinates to grid_sample
 coordinates.

Summary: Implements a utility function to convert from 2D coordinates in the
PyTorch3D NDC space to the coordinate convention of
`torch.nn.functional.grid_sample`.

Reviewed By: shapovalov

Differential Revision: D33741394

fbshipit-source-id: 88981653356588fe646e6dea48fe7f7298738437
---
 pytorch3d/renderer/__init__.py |   7 +-
 pytorch3d/renderer/utils.py    |  79 ++++++++++++++-
 tests/test_rendering_utils.py  | 177 ++++++++++++++++++++++++++++++++-
 3 files changed, 260 insertions(+), 3 deletions(-)

diff --git a/pytorch3d/renderer/__init__.py b/pytorch3d/renderer/__init__.py
index ef82733d2..c0a2d13c3 100644
--- a/pytorch3d/renderer/__init__.py
+++ b/pytorch3d/renderer/__init__.py
@@ -70,7 +70,12 @@
     PulsarPointsRenderer,
     rasterize_points,
 )
-from .utils import TensorProperties, convert_to_tensors_and_broadcast
+from .utils import (
+    TensorProperties,
+    convert_to_tensors_and_broadcast,
+    ndc_grid_sample,
+    ndc_to_grid_sample_coords,
+)
 
 
 __all__ = [k for k in globals().keys() if not k.startswith("_")]
diff --git a/pytorch3d/renderer/utils.py b/pytorch3d/renderer/utils.py
index 7d6f5acf3..3984cf141 100644
--- a/pytorch3d/renderer/utils.py
+++ b/pytorch3d/renderer/utils.py
@@ -8,7 +8,7 @@
 import copy
 import inspect
 import warnings
-from typing import Any, Optional, Union
+from typing import Any, Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -350,3 +350,80 @@ def convert_to_tensors_and_broadcast(
         args_Nd.append(c.expand(*expand_sizes))
 
     return args_Nd
+
+
+def ndc_grid_sample(
+    input: torch.Tensor,
+    grid_ndc: torch.Tensor,
+    **grid_sample_kwargs,
+) -> torch.Tensor:
+    """
+    Samples a tensor `input` of shape `(B, dim, H, W)` at 2D locations
+    specified by a tensor `grid_ndc` of shape `(B, ..., 2)` using
+    the `torch.nn.functional.grid_sample` function.
+    `grid_ndc` is specified in the PyTorch3D NDC coordinate frame.
+
+    Args:
+        input: The tensor of shape `(B, dim, H, W)` to be sampled.
+        grid_ndc: A tensor of shape `(B, ..., 2)` denoting the set of
+            2D locations at which `input` is sampled.
+            See [1] for a detailed description of the NDC coordinates.
+        grid_sample_kwargs: Additional arguments forwarded to the
+            `torch.nn.functional.grid_sample` call; see its docstring
+            for the list of supported arguments.
+
+    Returns:
+        sampled_input: A tensor of shape `(B, dim, ...)` containing the
+            samples of `input` at the 2D locations `grid_ndc`.
+
+    References:
+        [1] https://pytorch3d.org/docs/cameras
+    """
+
+    batch, *spatial_size, pt_dim = grid_ndc.shape
+    if batch != input.shape[0]:
+        raise ValueError("'input' and 'grid_ndc' have to have the same batch size.")
+    if input.ndim != 4:
+        raise ValueError("'input' has to be a 4-dimensional Tensor.")
+    if pt_dim != 2:
+        raise ValueError("The last dimension of 'grid_ndc' has to be 2.")
+
+    grid_ndc_flat = grid_ndc.reshape(batch, -1, 1, 2)  # flatten the spatial dims
+
+    grid_flat = ndc_to_grid_sample_coords(grid_ndc_flat, input.shape[2:])
+
+    sampled_input_flat = torch.nn.functional.grid_sample(
+        input, grid_flat, **grid_sample_kwargs
+    )
+
+    sampled_input = sampled_input_flat.reshape([batch, input.shape[1], *spatial_size])
+
+    return sampled_input
+
+
+def ndc_to_grid_sample_coords(
+    xy_ndc: torch.Tensor,
+    image_size_hw: Tuple[int, int],
+) -> torch.Tensor:
+    """
+    Convert from PyTorch3D's NDC coordinates to
+    `torch.nn.functional.grid_sample`'s coordinates.
+
+    Args:
+        xy_ndc: Tensor of shape `(..., 2)` containing 2D points in
+            PyTorch3D's NDC coordinates.
+        image_size_hw: A tuple `(image_height, image_width)` denoting the
+            height and width of the image tensor to sample.
+    Returns:
+        xy_grid_sample: Tensor of shape `(..., 2)` containing 2D points in
+            `torch.nn.functional.grid_sample`'s coordinates.
+    """
+    if len(image_size_hw) != 2 or any(s <= 0 for s in image_size_hw):
+        raise ValueError("'image_size_hw' has to be a 2-tuple of positive integers")
+    aspect = min(image_size_hw) / max(image_size_hw)
+    xy_grid_sample = -xy_ndc  # both NDC axes point opposite to grid_sample axes
+    if image_size_hw[0] >= image_size_hw[1]:
+        xy_grid_sample[..., 1] *= aspect  # rescale the longer (y) axis to [-1, 1]
+    else:
+        xy_grid_sample[..., 0] *= aspect  # rescale the longer (x) axis to [-1, 1]
+    return xy_grid_sample
diff --git a/tests/test_rendering_utils.py b/tests/test_rendering_utils.py
index e803fd3cc..6037a6598 100644
--- a/tests/test_rendering_utils.py
+++ b/tests/test_rendering_utils.py
@@ -10,7 +10,20 @@
 import numpy as np
 import torch
 from common_testing import TestCaseMixin
-from pytorch3d.renderer.utils import TensorProperties
+from pytorch3d.ops import eyes
+from pytorch3d.renderer import (
+    AlphaCompositor,
+    PerspectiveCameras,
+    PointsRasterizationSettings,
+    PointsRasterizer,
+    PointsRenderer,
+)
+from pytorch3d.renderer.utils import (
+    TensorProperties,
+    ndc_grid_sample,
+    ndc_to_grid_sample_coords,
+)
+from pytorch3d.structures import Pointclouds
 
 
 # Example class for testing
@@ -96,3 +109,165 @@ def test_gather_props(self):
         # the input.
         self.assertClose(test_class_gathered.x[inds].mean(dim=0), x[i, ...])
         self.assertClose(test_class_gathered.y[inds].mean(dim=0), y[i, ...])
+
+    def test_ndc_grid_sample_rendering(self):
+        """
+        Use the PyTorch3D point renderer to render a colored point cloud,
+        then sample the rendered image at the locations of the point
+        projections with `ndc_grid_sample`. Finally, assert that the sampled
+        colors are equal to the original point cloud colors.
+
+        Note that, in order to ensure correctness, we use a nearest-neighbor
+        assignment point renderer (i.e. no soft splatting).
+        """
+
+        # generate 3D points on a regular grid in the plane z = z_plane
+        n_grid_pts = 10
+        grid_scale = 0.9
+        z_plane = 2.0
+        image_size = [128, 128]
+        point_radius = 0.015
+        n_pts = n_grid_pts * n_grid_pts
+        pts = torch.stack(
+            torch.meshgrid(
+                [torch.linspace(-grid_scale, grid_scale, n_grid_pts)] * 2, indexing="ij"
+            ),
+            dim=-1,
+        )
+        pts = torch.cat([pts, z_plane * torch.ones_like(pts[..., :1])], dim=-1)
+        pts = pts.reshape(1, n_pts, 3)
+
+        # color the points randomly
+        pts_colors = torch.rand(1, n_pts, 3)
+
+        # make trivial rendering cameras
+        cameras = PerspectiveCameras(
+            R=eyes(dim=3, N=1),
+            device=pts.device,
+            T=torch.zeros(1, 3, dtype=torch.float32, device=pts.device),
+        )
+
+        # render the point cloud
+        pcl = Pointclouds(points=pts, features=pts_colors)
+        renderer = NearestNeighborPointsRenderer(
+            rasterizer=PointsRasterizer(
+                cameras=cameras,
+                raster_settings=PointsRasterizationSettings(
+                    image_size=image_size,
+                    radius=point_radius,
+                    points_per_pixel=1,
+                ),
+            ),
+            compositor=AlphaCompositor(),
+        )
+        im_render = renderer(pcl)
+
+        # sample the rendered image at the projected point locations
+        pts_proj = cameras.transform_points(pcl.points_padded())[..., :2]
+        pts_colors_sampled = ndc_grid_sample(
+            im_render,
+            pts_proj,
+            mode="nearest",
+            align_corners=False,
+        ).permute(0, 2, 1)
+
+        # assert that the sampled colors match the original point colors
+        self.assertClose(pts_colors, pts_colors_sampled, atol=1e-4)
+
+    def test_ndc_to_grid_sample_coords(self):
+        """
+        Test the conversion from NDC to grid_sample coordinates by comparing
+        to known conversion results.
+        """
+
+        # square image tests
+        image_size_square = [100, 100]
+        xy_ndc_gs_square = torch.FloatTensor(
+            [
+                # 4 corners
+                [[-1.0, -1.0], [1.0, 1.0]],
+                [[-1.0, 1.0], [1.0, -1.0]],
+                [[1.0, -1.0], [-1.0, 1.0]],
+                [[1.0, 1.0], [-1.0, -1.0]],
+                # center
+                [[0.0, 0.0], [0.0, 0.0]],
+            ]
+        )
+
+        # non-batched version
+        for xy_ndc, xy_gs in xy_ndc_gs_square:
+            xy_gs_predicted = ndc_to_grid_sample_coords(
+                xy_ndc,
+                image_size_square,
+            )
+            self.assertClose(xy_gs_predicted, xy_gs)
+
+        # batched version
+        xy_ndc, xy_gs = xy_ndc_gs_square[:, 0], xy_ndc_gs_square[:, 1]
+        xy_gs_predicted = ndc_to_grid_sample_coords(
+            xy_ndc,
+            image_size_square,
+        )
+        self.assertClose(xy_gs_predicted, xy_gs)
+
+        # non-square image tests
+        image_size = [100, 200]
+        xy_ndc_gs = torch.FloatTensor(
+            [
+                # 4 corners
+                [[-2.0, -1.0], [1.0, 1.0]],
+                [[2.0, -1.0], [-1.0, 1.0]],
+                [[-2.0, 1.0], [1.0, -1.0]],
+                [[2.0, 1.0], [-1.0, -1.0]],
+                # center
+                [[0.0, 0.0], [0.0, 0.0]],
+                # non-corner points
+                [[4.0, 0.5], [-2.0, -0.5]],
+                [[1.0, -0.5], [-0.5, 0.5]],
+            ]
+        )
+
+        # check both H > W and W > H
+        for flip_axes in [False, True]:
+
+            # non-batched version
+            for xy_ndc, xy_gs in xy_ndc_gs:
+                xy_gs_predicted = ndc_to_grid_sample_coords(
+                    xy_ndc.flip(dims=(-1,)) if flip_axes else xy_ndc,
+                    list(reversed(image_size)) if flip_axes else image_size,
+                )
+                self.assertClose(
+                    xy_gs_predicted,
+                    xy_gs.flip(dims=(-1,)) if flip_axes else xy_gs,
+                )
+
+            # batched version
+            xy_ndc, xy_gs = xy_ndc_gs[:, 0], xy_ndc_gs[:, 1]
+            xy_gs_predicted = ndc_to_grid_sample_coords(
+                xy_ndc.flip(dims=(-1,)) if flip_axes else xy_ndc,
+                list(reversed(image_size)) if flip_axes else image_size,
+            )
+            self.assertClose(
+                xy_gs_predicted,
+                xy_gs.flip(dims=(-1,)) if flip_axes else xy_gs,
+            )
+
+
+class NearestNeighborPointsRenderer(PointsRenderer):
+    """
+    A class for rendering a batch of points by a trivial nearest-neighbor
+    assignment.
+    """
+
+    def forward(self, point_clouds, **kwargs) -> torch.Tensor:
+        fragments = self.rasterizer(point_clouds, **kwargs)
+        # set all weights to one so each pixel takes the color of its nearest point
+        dists2 = fragments.dists.permute(0, 3, 1, 2)
+        weights = torch.ones_like(dists2)
+        images = self.compositor(
+            fragments.idx.long().permute(0, 3, 1, 2),
+            weights,
+            point_clouds.features_packed().permute(1, 0),
+            **kwargs,
+        )
+        return images
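
A minimal usage sketch of the two new utilities (not part of the diff above;
it assumes the patch has been applied, and the shapes and values below are
illustrative only):

    import torch
    from pytorch3d.renderer import ndc_grid_sample, ndc_to_grid_sample_coords

    # a batch of 2 feature maps with 8 channels and spatial size H=100 x W=200
    features = torch.randn(2, 8, 100, 200)

    # 5 sampling locations per batch element in PyTorch3D NDC coordinates;
    # with W > H, the NDC x-axis spans [-W/H, W/H] = [-2, 2], y spans [-1, 1]
    xy_ndc = torch.empty(2, 5, 2).uniform_(-1.0, 1.0)
    xy_ndc[..., 0] *= 2.0

    # explicit conversion: both axes are negated and the longer-side axis
    # is rescaled into grid_sample's [-1, 1] range
    xy_gs = ndc_to_grid_sample_coords(xy_ndc, (100, 200))
    assert float(xy_gs.abs().max()) <= 1.0

    # or sample the feature maps directly at the NDC locations in one call
    sampled = ndc_grid_sample(features, xy_ndc, mode="bilinear", align_corners=False)
    assert sampled.shape == (2, 8, 5)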