Join cameras as batch
Summary:
A function to join a list of cameras objects into a single batched cameras object.

FB: In the next diff I will remove the `concatenate_cameras` function in implicitron and update the callsites.

Reviewed By: nikhilaravi

Differential Revision: D33198209

fbshipit-source-id: 0c9f5f5df498a0def9dba756c984e6a946618158
bottler authored and facebook-github-bot committed Jan 21, 2022
1 parent 9e2bc3a commit 39bb2ce
Showing 5 changed files with 187 additions and 9 deletions.
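
As a usage sketch (not part of the diff; the tensor values are arbitrary), the new function concatenates per-camera tensor attributes along the batch dimension:

import torch
from pytorch3d.renderer import PerspectiveCameras, join_cameras_as_batch

# One cameras object holding 2 cameras, another holding 3.
cams_a = PerspectiveCameras(R=torch.randn(2, 3, 3), focal_length=torch.rand(2, 2))
cams_b = PerspectiveCameras(R=torch.randn(3, 3, 3), focal_length=torch.rand(3, 2))

batch = join_cameras_as_batch([cams_a, cams_b])
print(len(batch))     # 5 -- tensor attributes are concatenated along dim 0
print(batch.R.shape)  # torch.Size([5, 3, 3])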
4 changes: 2 additions & 2 deletions pytorch3d/renderer/__init__.py
@@ -10,7 +10,7 @@
     sigmoid_alpha_blend,
     softmax_rgb_blend,
 )
-from .camera_utils import rotate_on_spot
+from .camera_utils import join_cameras_as_batch, rotate_on_spot
 from .cameras import OpenGLOrthographicCameras  # deprecated
 from .cameras import OpenGLPerspectiveCameras  # deprecated
 from .cameras import SfMOrthographicCameras  # deprecated
@@ -29,6 +29,7 @@
     AbsorptionOnlyRaymarcher,
     EmissionAbsorptionRaymarcher,
     GridRaysampler,
+    HarmonicEmbedding,
     ImplicitRenderer,
     MonteCarloRaysampler,
     NDCGridRaysampler,
@@ -37,7 +38,6 @@
     VolumeSampler,
     ray_bundle_to_ray_points,
     ray_bundle_variables_to_ray_points,
-    HarmonicEmbedding,
 )
 from .lighting import AmbientLights, DirectionalLights, PointLights, diffuse, specular
 from .materials import Materials
66 changes: 65 additions & 1 deletion pytorch3d/renderer/camera_utils.py
@@ -4,11 +4,13 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Tuple
+from typing import Sequence, Tuple
 
 import torch
 from pytorch3d.transforms import Transform3d
 
+from .cameras import CamerasBase
+
 
 def camera_to_eye_at_up(
     world_to_view_transform: Transform3d,
@@ -141,3 +143,65 @@ def rotate_on_spot(
     new_T = torch.matmul(new_R.transpose(1, 2), old_RT)[:, :, 0]
 
     return new_R, new_T
+
+
+def join_cameras_as_batch(cameras_list: Sequence[CamerasBase]) -> CamerasBase:
+    """
+    Create a batched cameras object by concatenating a list of input
+    cameras objects. All the tensor attributes will be joined along
+    the batch dimension.
+    Args:
+        cameras_list: List of camera classes all of the same type and
+            on the same device. Each represents one or more cameras.
+    Returns:
+        cameras: single batched cameras object of the same
+            type as all the objects in the input list.
+    """
+    # Get the type and fields to join from the first camera in the batch
+    c0 = cameras_list[0]
+    fields = c0._FIELDS
+    shared_fields = c0._SHARED_FIELDS
+
+    if not all(isinstance(c, CamerasBase) for c in cameras_list):
+        raise ValueError("cameras in cameras_list must inherit from CamerasBase")
+
+    if not all(type(c) is type(c0) for c in cameras_list[1:]):
+        raise ValueError("All cameras must be of the same type")
+
+    if not all(c.device == c0.device for c in cameras_list[1:]):
+        raise ValueError("All cameras in the batch must be on the same device")
+
+    # Concat the fields to make a batched tensor
+    kwargs = {}
+    kwargs["device"] = c0.device
+
+    for field in fields:
+        field_not_none = [(getattr(c, field) is not None) for c in cameras_list]
+        if not any(field_not_none):
+            continue
+        if not all(field_not_none):
+            raise ValueError(f"Attribute {field} is inconsistently present")
+
+        attrs_list = [getattr(c, field) for c in cameras_list]
+
+        if field in shared_fields:
+            # Only needs to be set once
+            if not all(a == attrs_list[0] for a in attrs_list):
+                raise ValueError(f"Attribute {field} is not constant across inputs")
+
+            # e.g. "in_ndc" is set as attribute "_in_ndc" on the class
+            # but provided as "in_ndc" in the input args
+            if field.startswith("_"):
+                field = field[1:]
+
+            kwargs[field] = attrs_list[0]
+        elif isinstance(attrs_list[0], torch.Tensor):
+            # In the init, all inputs will be converted to
+            # batched tensors before set as attributes
+            # Join as a tensor along the batch dimension
+            kwargs[field] = torch.cat(attrs_list, dim=0)
+        else:
+            raise ValueError(f"Field {field} type is not supported for batching")
+
+    return c0.__class__(**kwargs)
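
A short illustrative sketch of the shared-field check above (mirroring the error tests further down): cameras defined in NDC space cannot be joined with screen-space cameras, because `_in_ndc` is a shared field that must agree across all inputs:

from pytorch3d.renderer import OrthographicCameras, join_cameras_as_batch

cam_ndc = OrthographicCameras()                 # in_ndc=True by default
cam_screen = OrthographicCameras(in_ndc=False)  # screen-space camera

try:
    join_cameras_as_batch([cam_ndc, cam_screen])
except ValueError as err:
    print(err)  # Attribute _in_ndc is not constant across inputs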
25 changes: 24 additions & 1 deletion pytorch3d/renderer/cameras.py
@@ -77,7 +77,12 @@ class CamerasBase(TensorProperties):
 
     # Used in __getitem__ to index the relevant fields
     # When creating a new camera, this should be set in the __init__
-    _FIELDS: Tuple = ()
+    _FIELDS: Tuple[str, ...] = ()
+
+    # Names of fields which are a constant property of the whole batch, rather
+    # than themselves a batch of data.
+    # When joining objects into a batch, they will have to agree.
+    _SHARED_FIELDS: Tuple[str, ...] = ()
 
     def get_projection_transform(self):
         """
@@ -499,6 +504,8 @@ class FoVPerspectiveCameras(CamerasBase):
         "degrees",
     )
 
+    _SHARED_FIELDS = ("degrees",)
+
     def __init__(
         self,
         znear=1.0,
@@ -997,6 +1004,8 @@ class PerspectiveCameras(CamerasBase):
         "image_size",
     )
 
+    _SHARED_FIELDS = ("_in_ndc",)
+
     def __init__(
         self,
         focal_length=1.0,
@@ -1047,6 +1056,12 @@ def __init__(
         else:
             self.image_size = None
 
+        # When focal length is provided as one value, expand to
+        # create (N, 2) shape tensor
+        if self.focal_length.ndim == 1:  # (N,)
+            self.focal_length = self.focal_length[:, None]  # (N, 1)
+        self.focal_length = self.focal_length.expand(-1, 2)  # (N, 2)
+
     def get_projection_transform(self, **kwargs) -> Transform3d:
         """
         Calculate the projection matrix using the
@@ -1227,6 +1242,8 @@ class OrthographicCameras(CamerasBase):
         "image_size",
     )
 
+    _SHARED_FIELDS = ("_in_ndc",)
+
     def __init__(
         self,
         focal_length=1.0,
@@ -1276,6 +1293,12 @@ def __init__(
         else:
             self.image_size = None
 
+        # When focal length is provided as one value, expand to
+        # create (N, 2) shape tensor
+        if self.focal_length.ndim == 1:  # (N,)
+            self.focal_length = self.focal_length[:, None]  # (N, 1)
+        self.focal_length = self.focal_length.expand(-1, 2)  # (N, 2)
+
     def get_projection_transform(self, **kwargs) -> Transform3d:
         """
         Calculate the projection matrix using
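
The two `__init__` additions above normalize `focal_length`: however it is provided (scalar, (N,), or (N, 1)), it ends up stored as an (N, 2) tensor. An illustrative sketch of the resulting behavior (not part of the diff):

import torch
from pytorch3d.renderer import PerspectiveCameras

# A scalar focal length becomes a batched (N, 2) tensor (N=1 here).
cam = PerspectiveCameras(focal_length=5.0)
print(cam.focal_length.shape)  # torch.Size([1, 2])

# An (N, 1) focal length is expanded to (N, 2) as well.
cam = PerspectiveCameras(focal_length=torch.full((4, 1), 5.0))
print(cam.focal_length.shape)  # torch.Size([4, 2])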
3 changes: 0 additions & 3 deletions tests/test_camera_pixels.py
@@ -250,7 +250,4 @@ def test_camera(self):
             ],
             dim=1,
         )
-
-        print(wanted)
-        print(camera_points[batch_idx])
         self.assertClose(camera_points[batch_idx], wanted)
98 changes: 96 additions & 2 deletions tests/test_cameras.py
@@ -36,6 +36,7 @@
 import numpy as np
 import torch
 from common_testing import TestCaseMixin
+from pytorch3d.renderer.camera_utils import join_cameras_as_batch
 from pytorch3d.renderer.cameras import (
     CamerasBase,
     FoVOrthographicCameras,
@@ -688,6 +689,99 @@ def test_clone(self, batch_size: int = 10):
             else:
                 self.assertTrue(val == val_clone)
 
+    def test_join_cameras_as_batch_errors(self):
+        cam0 = PerspectiveCameras(device="cuda:0")
+        cam1 = OrthographicCameras(device="cuda:0")
+
+        # Cameras not of the same type
+        with self.assertRaisesRegex(ValueError, "same type"):
+            join_cameras_as_batch([cam0, cam1])
+
+        cam2 = OrthographicCameras(device="cpu")
+        # Cameras not on the same device
+        with self.assertRaisesRegex(ValueError, "same device"):
+            join_cameras_as_batch([cam1, cam2])
+
+        cam3 = OrthographicCameras(in_ndc=False, device="cuda:0")
+        # Different coordinate systems -- all should be in ndc or in screen
+        with self.assertRaisesRegex(
+            ValueError, "Attribute _in_ndc is not constant across inputs"
+        ):
+            join_cameras_as_batch([cam1, cam3])
+
+    def join_cameras_as_batch_fov(self, camera_cls):
+        R0 = torch.randn((6, 3, 3))
+        R1 = torch.randn((3, 3, 3))
+        cam0 = camera_cls(znear=10.0, zfar=100.0, R=R0, device="cuda:0")
+        cam1 = camera_cls(znear=10.0, zfar=200.0, R=R1, device="cuda:0")
+
+        cam_batch = join_cameras_as_batch([cam0, cam1])
+
+        self.assertEqual(cam_batch._N, cam0._N + cam1._N)
+        self.assertEqual(cam_batch.device, cam0.device)
+        self.assertClose(cam_batch.R, torch.cat((R0, R1), dim=0).to(device="cuda:0"))
+
+    def join_cameras_as_batch(self, camera_cls):
+        R0 = torch.randn((6, 3, 3))
+        R1 = torch.randn((3, 3, 3))
+        p0 = torch.randn((6, 2, 1))
+        p1 = torch.randn((3, 2, 1))
+        f0 = 5.0
+        f1 = torch.randn(3, 2)
+        f2 = torch.randn(3, 1)
+        cam0 = camera_cls(
+            R=R0,
+            focal_length=f0,
+            principal_point=p0,
+        )
+        cam1 = camera_cls(
+            R=R1,
+            focal_length=f0,
+            principal_point=p1,
+        )
+        cam2 = camera_cls(
+            R=R1,
+            focal_length=f1,
+            principal_point=p1,
+        )
+        cam3 = camera_cls(
+            R=R1,
+            focal_length=f2,
+            principal_point=p1,
+        )
+        cam_batch = join_cameras_as_batch([cam0, cam1])
+
+        self.assertEqual(cam_batch._N, cam0._N + cam1._N)
+        self.assertEqual(cam_batch.device, cam0.device)
+        self.assertClose(cam_batch.R, torch.cat((R0, R1), dim=0))
+        self.assertClose(cam_batch.principal_point, torch.cat((p0, p1), dim=0))
+        self.assertEqual(cam_batch._in_ndc, cam0._in_ndc)
+
+        # Test one broadcasted value and one fixed value
+        # Focal length as (N,) in one camera and (N, 2) in the other
+        cam_batch = join_cameras_as_batch([cam0, cam2])
+        self.assertEqual(cam_batch._N, cam0._N + cam2._N)
+        self.assertClose(cam_batch.R, torch.cat((R0, R1), dim=0))
+        self.assertClose(
+            cam_batch.focal_length,
+            torch.cat([torch.tensor([[f0, f0]]).expand(6, -1), f1], dim=0),
+        )
+
+        # Focal length as (N, 1) in one camera and (N, 2) in the other
+        cam_batch = join_cameras_as_batch([cam2, cam3])
+        self.assertClose(
+            cam_batch.focal_length,
+            torch.cat([f1, f2.expand(-1, 2)], dim=0),
+        )
+
+    def test_join_batch_perspective(self):
+        self.join_cameras_as_batch_fov(FoVPerspectiveCameras)
+        self.join_cameras_as_batch(PerspectiveCameras)
+
+    def test_join_batch_orthographic(self):
+        self.join_cameras_as_batch_fov(FoVOrthographicCameras)
+        self.join_cameras_as_batch(OrthographicCameras)
+
 
 ############################################################
 # FoVPerspective Camera                                    #
@@ -1055,7 +1149,7 @@ def test_getitem(self):
         index = torch.tensor([1, 3, 5], dtype=torch.int64)
         c135 = cam[index]
         self.assertEqual(len(c135), 3)
-        self.assertClose(c135.focal_length, torch.tensor([5.0] * 3))
+        self.assertClose(c135.focal_length, torch.tensor([[5.0, 5.0]] * 3))
         self.assertClose(c135.R, R_matrix[[1, 3, 5], ...])
         self.assertClose(c135.principal_point, principal_point[[1, 3, 5], ...])

@@ -1131,7 +1225,7 @@ def test_getitem(self):
         index = torch.tensor([1, 3, 5], dtype=torch.int64)
         c135 = cam[index]
         self.assertEqual(len(c135), 3)
-        self.assertClose(c135.focal_length, torch.tensor([5.0] * 3))
+        self.assertClose(c135.focal_length, torch.tensor([[5.0, 5.0]] * 3))
         self.assertClose(c135.R, R_matrix[[1, 3, 5], ...])
         self.assertClose(c135.principal_point, principal_point[[1, 3, 5], ...])

