Skip to content

Commit

Permalink
Implicit/Volume renderer
Browse files Browse the repository at this point in the history
Summary: Implements the `ImplicitRenderer` and `VolumeRenderer`.

Reviewed By: gkioxari

Differential Revision: D24418791

fbshipit-source-id: 127f21186d8e210895db1dcd0681f09f230d81a4
  • Loading branch information
davnov134 authored and facebook-github-bot committed Jan 6, 2021
1 parent e6a32bf commit b466c38
Show file tree
Hide file tree
Showing 8 changed files with 1,575 additions and 3 deletions.
3 changes: 3 additions & 0 deletions pytorch3d/renderer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,12 @@
AbsorptionOnlyRaymarcher,
EmissionAbsorptionRaymarcher,
GridRaysampler,
ImplicitRenderer,
MonteCarloRaysampler,
NDCGridRaysampler,
RayBundle,
VolumeRenderer,
VolumeSampler,
ray_bundle_to_ray_points,
ray_bundle_variables_to_ray_points,
)
Expand Down
1 change: 1 addition & 0 deletions pytorch3d/renderer/implicit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from .raymarching import AbsorptionOnlyRaymarcher, EmissionAbsorptionRaymarcher
from .raysampling import GridRaysampler, MonteCarloRaysampler, NDCGridRaysampler
from .renderer import ImplicitRenderer, VolumeRenderer, VolumeSampler
from .utils import (
RayBundle,
ray_bundle_to_ray_points,
Expand Down
372 changes: 372 additions & 0 deletions pytorch3d/renderer/implicit/renderer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,372 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from typing import Callable, Tuple

import torch

from ...ops.utils import eyes
from ...structures import Volumes
from ...transforms import Transform3d
from ..cameras import CamerasBase
from .raysampling import RayBundle
from .utils import _validate_ray_bundle_variables, ray_bundle_variables_to_ray_points


# The implicit renderer class should be initialized with a
# function for raysampling and a function for raymarching.

# During the forward pass:
# 1) The raysampler:
# - samples rays from input cameras
# - transforms the rays to world coordinates
# 2) The volumetric_function (which is a callable argument of the forwad pass)
# evaluates ray_densities and ray_features at the sampled ray-points.
# 3) The raymarcher takes ray_densities and ray_features and uses a raymarching
# algorithm to render each ray.


class ImplicitRenderer(torch.nn.Module):
"""
A class for rendering a batch of implicit surfaces. The class should
be initialized with a raysampler and raymarcher class which both have
to be a `Callable`.
VOLUMETRIC_FUNCTION
The `forward` function of the renderer accepts as input the rendering cameras as well
as the `volumetric_function` `Callable`, which defines a field of opacity
and feature vectors over the 3D domain of the scene.
A standard `volumetric_function` has the following signature:
```
def volumetric_function(ray_bundle: RayBundle) -> Tuple[torch.Tensor, torch.Tensor]
```
With the following arguments:
`ray_bundle`: A RayBundle object containing the following variables:
`rays_origins`: A tensor of shape `(minibatch, ..., 3)` denoting
the origins of the rendering rays.
`rays_directions`: A tensor of shape `(minibatch, ..., 3)`
containing the direction vectors of rendering rays.
`rays_lengths`: A tensor of shape
`(minibatch, ..., num_points_per_ray)`containing the
lengths at which the ray points are sampled.
Calling `volumetric_function` then returns the following:
`rays_densities`: A tensor of shape
`(minibatch, ..., num_points_per_ray, opacity_dim)` containing
the an opacity vector for each ray point.
`rays_features`: A tensor of shape
`(minibatch, ..., num_points_per_ray, feature_dim)` containing
the an feature vector for each ray point.
Example:
A simple volumetric function of a 0-centered
RGB sphere with a unit diameter is defined as follows:
```
def volumetric_function(
ray_bundle: RayBundle,
) -> Tuple[torch.Tensor, torch.Tensor]:
# first convert the ray origins, directions and lengths
# to 3D ray point locations in world coords
rays_points_world = ray_bundle_to_ray_points(ray_bundle)
# set the densities as an inverse sigmoid of the
# ray point distance from the sphere centroid
rays_densities = torch.sigmoid(
-100.0 * rays_points_world.norm(dim=-1, keepdim=True)
)
# set the ray features to RGB colors proportional
# to the 3D location of the projection of ray points
# on the sphere surface
rays_features = torch.nn.functional.normalize(
rays_points_world, dim=-1
) * 0.5 + 0.5
return rays_densities, rays_features
```
"""

def __init__(self, raysampler: Callable, raymarcher: Callable):
"""
Args:
raysampler: A `Callable` that takes as input scene cameras
(an instance of `CamerasBase`) and returns a `RayBundle` that
describes the rays emitted from the cameras.
raymarcher: A `Callable` that receives the response of the
`volumetric_function` (an input to `self.forward`) evaluated
along the sampled rays, and renders the rays with a
ray-marching algorithm.
"""
super().__init__()

if not callable(raysampler):
raise ValueError('"raysampler" has to be a "Callable" object.')
if not callable(raymarcher):
raise ValueError('"raymarcher" has to be a "Callable" object.')

self.raysampler = raysampler
self.raymarcher = raymarcher

def forward(
self, cameras: CamerasBase, volumetric_function: Callable, **kwargs
) -> Tuple[torch.Tensor, RayBundle]:
"""
Render a batch of images using a volumetric function
represented as a callable (e.g. a Pytorch module).
Args:
cameras: A batch of cameras that render the scene. A `self.raysampler`
takes the cameras as input and samples rays that pass through the
domain of the volumentric function.
volumetric_function: A `Callable` that accepts the parametrizations
of the rendering rays and returns the densities and features
at the respective 3D of the rendering rays. Please refer to
the main class documentation for details.
Returns:
images: A tensor of shape `(minibatch, ..., feature_dim + opacity_dim)`
containing the result of the rendering.
ray_bundle: A `RayBundle` containing the parametrizations of the
sampled rendering rays.
"""

if not callable(volumetric_function):
raise ValueError('"volumetric_function" has to be a "Callable" object.')

# first call the ray sampler that returns the RayBundle parametrizing
# the rendering rays.
ray_bundle = self.raysampler(
cameras=cameras, volumetric_function=volumetric_function, **kwargs
)
# ray_bundle.origins - minibatch x ... x 3
# ray_bundle.directions - minibatch x ... x 3
# ray_bundle.lengths - minibatch x ... x n_pts_per_ray
# ray_bundle.xys - minibatch x ... x 2

# given sampled rays, call the volumetric function that
# evaluates the densities and features at the locations of the
# ray points
rays_densities, rays_features = volumetric_function(
ray_bundle=ray_bundle, cameras=cameras, **kwargs
)
# ray_densities - minibatch x ... x n_pts_per_ray x density_dim
# ray_features - minibatch x ... x n_pts_per_ray x feature_dim

# finally, march along the sampled rays to obtain the renders
images = self.raymarcher(
rays_densities=rays_densities,
rays_features=rays_features,
ray_bundle=ray_bundle,
**kwargs
)
# images - minibatch x ... x (feature_dim + opacity_dim)

return images, ray_bundle


# The volume renderer class should be initialized with a
# function for raysampling and a function for raymarching.

# During the forward pass:
# 1) The raysampler:
# - samples rays from input cameras
# - transforms the rays to world coordinates
# 2) The scene volumes (which are an argument of the forward function)
# are then sampled at the locations of the ray-points to generate
# ray_densities and ray_features.
# 3) The raymarcher takes ray_densities and ray_features and uses a raymarching
# algorithm to render each ray.


class VolumeRenderer(torch.nn.Module):
"""
A class for rendering a batch of Volumes. The class should
be initialized with a raysampler and a raymarcher class which both have
to be a `Callable`.
"""

def __init__(
self, raysampler: Callable, raymarcher: Callable, sample_mode: str = "bilinear"
):
"""
Args:
raysampler: A `Callable` that takes as input scene cameras
(an instance of `CamerasBase`) and returns a `RayBundle` that
describes the rays emitted from the cameras.
raymarcher: A `Callable` that receives the `volumes`
(an instance of `Volumes` input to `self.forward`)
sampled at the ray-points, and renders the rays with a
ray-marching algorithm.
sample_mode: Defines the algorithm used to sample the volumetric
voxel grid. Can be either "bilinear" or "nearest".
"""
super().__init__()

self.renderer = ImplicitRenderer(raysampler, raymarcher)
self._sample_mode = sample_mode

def forward(
self, cameras: CamerasBase, volumes: Volumes, **kwargs
) -> Tuple[torch.Tensor, RayBundle]:
"""
Render a batch of images using raymarching over rays cast through
input `Volumes`.
Args:
cameras: A batch of cameras that render the scene. A `self.raysampler`
takes the cameras as input and samples rays that pass through the
domain of the volumentric function.
volumes: An instance of the `Volumes` class representing a
batch of volumes that are being rendered.
Returns:
images: A tensor of shape `(minibatch, ..., (feature_dim + opacity_dim)`
containing the result of the rendering.
ray_bundle: A `RayBundle` containing the parametrizations of the
sampled rendering rays.
"""
volumetric_function = VolumeSampler(volumes, sample_mode=self._sample_mode)
return self.renderer(
cameras=cameras, volumetric_function=volumetric_function, **kwargs
)


class VolumeSampler(torch.nn.Module):
"""
A class that allows to sample a batch of volumes `Volumes`
at 3D points sampled along projection rays.
"""

def __init__(self, volumes: Volumes, sample_mode: str = "bilinear"):
"""
Args:
volumes: An instance of the `Volumes` class representing a
batch if volumes that are being rendered.
sample_mode: Defines the algorithm used to sample the volumetric
voxel grid. Can be either "bilinear" or "nearest".
"""
super().__init__()
if not isinstance(volumes, Volumes):
raise ValueError("'volumes' have to be an instance of the 'Volumes' class.")
self._volumes = volumes
self._sample_mode = sample_mode

def _get_ray_directions_transform(self):
"""
Compose the ray-directions transform by removing the translation component
from the volume global-to-local coords transform.
"""
world2local = self._volumes.get_world_to_local_coords_transform().get_matrix()
directions_transform_matrix = eyes(
4,
N=world2local.shape[0],
device=world2local.device,
dtype=world2local.dtype,
)
directions_transform_matrix[:, :3, :3] = world2local[:, :3, :3]
directions_transform = Transform3d(matrix=directions_transform_matrix)
return directions_transform

def forward(
self, ray_bundle: RayBundle, **kwargs
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Given an input ray parametrization, the forward function samples
`self._volumes` at the respective 3D ray-points.
Args:
ray_bundle: A RayBundle object with the following fields:
rays_origins_world: A tensor of shape `(minibatch, ..., 3)` denoting the
origins of the sampling rays in world coords.
rays_directions_world: A tensor of shape `(minibatch, ..., 3)`
containing the direction vectors of sampling rays in world coords.
rays_lengths: A tensor of shape `(minibatch, ..., num_points_per_ray)`
containing the lengths at which the rays are sampled.
Returns:
rays_densities: A tensor of shape
`(minibatch, ..., num_points_per_ray, opacity_dim)` containing the
densitity vectors sampled from the volume at the locations of
the ray points.
rays_features: A tensor of shape
`(minibatch, ..., num_points_per_ray, feature_dim)` containing the
feature vectors sampled from the volume at the locations of
the ray points.
"""

# take out the interesting parts of ray_bundle
rays_origins_world = ray_bundle.origins
rays_directions_world = ray_bundle.directions
rays_lengths = ray_bundle.lengths

# validate the inputs
_validate_ray_bundle_variables(
rays_origins_world, rays_directions_world, rays_lengths
)
if self._volumes.densities().shape[0] != rays_origins_world.shape[0]:
raise ValueError("Input volumes have to have the same batch size as rays.")

#########################################################
# 1) convert the origins/directions to the local coords #
#########################################################

# origins are mapped with the world_to_local transform of the volumes
rays_origins_local = self._volumes.world_to_local_coords(rays_origins_world)

# obtain the Transform3d object that transforms ray directions to local coords
directions_transform = self._get_ray_directions_transform()

# transform the directions to the local coords
rays_directions_local = directions_transform.transform_points(
rays_directions_world.view(rays_lengths.shape[0], -1, 3)
).view(rays_directions_world.shape)

############################
# 2) obtain the ray points #
############################

# this op produces a fairly big tensor (minibatch, ..., n_samples_per_ray, 3)
rays_points_local = ray_bundle_variables_to_ray_points(
rays_origins_local, rays_directions_local, rays_lengths
)

########################
# 3) sample the volume #
########################

# generate the tensor for sampling
volumes_densities = self._volumes.densities()
dim_density = volumes_densities.shape[1]
volumes_features = self._volumes.features()
# adjust the volumes_features variable in case we have a feature-less volume
if volumes_features is None:
dim_feature = 0
data_to_sample = volumes_densities
else:
dim_feature = volumes_features.shape[1]
data_to_sample = torch.cat((volumes_densities, volumes_features), dim=1)

# reshape to a size which grid_sample likes
rays_points_local_flat = rays_points_local.view(
rays_points_local.shape[0], -1, 1, 1, 3
)

# run the grid sampler
data_sampled = torch.nn.functional.grid_sample(
data_to_sample,
rays_points_local_flat,
align_corners=True,
mode=self._sample_mode,
)

# permute the dimensions & reshape after sampling
data_sampled = data_sampled.permute(0, 2, 3, 4, 1).view(
*rays_points_local.shape[:-1], data_sampled.shape[1]
)

# split back to densities and features
rays_densities, rays_features = data_sampled.split(
[dim_density, dim_feature], dim=-1
)

return rays_densities, rays_features
Loading

0 comments on commit b466c38

Please sign in to comment.