diff --git a/pytorch3d/renderer/__init__.py b/pytorch3d/renderer/__init__.py
index e983540b4..0b3c7d444 100644
--- a/pytorch3d/renderer/__init__.py
+++ b/pytorch3d/renderer/__init__.py
@@ -10,7 +10,7 @@
     sigmoid_alpha_blend,
     softmax_rgb_blend,
 )
-from .camera_utils import rotate_on_spot
+from .camera_utils import join_cameras_as_batch, rotate_on_spot
 from .cameras import OpenGLOrthographicCameras  # deprecated
 from .cameras import OpenGLPerspectiveCameras  # deprecated
 from .cameras import SfMOrthographicCameras  # deprecated
@@ -29,6 +29,7 @@
     AbsorptionOnlyRaymarcher,
     EmissionAbsorptionRaymarcher,
     GridRaysampler,
+    HarmonicEmbedding,
     ImplicitRenderer,
     MonteCarloRaysampler,
     NDCGridRaysampler,
@@ -37,7 +38,6 @@
     VolumeSampler,
     ray_bundle_to_ray_points,
     ray_bundle_variables_to_ray_points,
-    HarmonicEmbedding,
 )
 from .lighting import AmbientLights, DirectionalLights, PointLights, diffuse, specular
 from .materials import Materials
diff --git a/pytorch3d/renderer/camera_utils.py b/pytorch3d/renderer/camera_utils.py
index 90bd02089..1bddcaf23 100644
--- a/pytorch3d/renderer/camera_utils.py
+++ b/pytorch3d/renderer/camera_utils.py
@@ -4,11 +4,13 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Tuple
+from typing import Sequence, Tuple
 
 import torch
 from pytorch3d.transforms import Transform3d
 
+from .cameras import CamerasBase
+
 
 def camera_to_eye_at_up(
     world_to_view_transform: Transform3d,
@@ -141,3 +143,65 @@ def rotate_on_spot(
     new_T = torch.matmul(new_R.transpose(1, 2), old_RT)[:, :, 0]
 
     return new_R, new_T
+
+
+def join_cameras_as_batch(cameras_list: Sequence[CamerasBase]) -> CamerasBase:
+    """
+    Create a batched cameras object by concatenating a list of input
+    cameras objects. All the tensor attributes will be joined along
+    the batch dimension.
+
+    Args:
+        cameras_list: List of cameras, all of the same type and
+            on the same device. Each represents one or more cameras.
+    Returns:
+        cameras: single batched cameras object of the same
+            type as all the objects in the input list.
+    """
+    # Get the type and fields to join from the first camera in the batch
+    c0 = cameras_list[0]
+    fields = c0._FIELDS
+    shared_fields = c0._SHARED_FIELDS
+
+    if not all(isinstance(c, CamerasBase) for c in cameras_list):
+        raise ValueError("cameras in cameras_list must inherit from CamerasBase")
+
+    if not all(type(c) is type(c0) for c in cameras_list[1:]):
+        raise ValueError("All cameras must be of the same type")
+
+    if not all(c.device == c0.device for c in cameras_list[1:]):
+        raise ValueError("All cameras in the batch must be on the same device")
+
+    # Concat the fields to make a batched tensor
+    kwargs = {}
+    kwargs["device"] = c0.device
+
+    for field in fields:
+        field_not_none = [(getattr(c, field) is not None) for c in cameras_list]
+        if not any(field_not_none):
+            continue
+        if not all(field_not_none):
+            raise ValueError(f"Attribute {field} is inconsistently present")
+
+        attrs_list = [getattr(c, field) for c in cameras_list]
+
+        if field in shared_fields:
+            # Only needs to be set once
+            if not all(a == attrs_list[0] for a in attrs_list):
+                raise ValueError(f"Attribute {field} is not constant across inputs")
+
+            # e.g. "in_ndc" is set as attribute "_in_ndc" on the class
+            # but provided as "in_ndc" in the input args
+            if field.startswith("_"):
+                field = field[1:]
+
+            kwargs[field] = attrs_list[0]
+        elif isinstance(attrs_list[0], torch.Tensor):
+            # In the init, all inputs will be converted to
+            # batched tensors before being set as attributes.
+            # Join as a tensor along the batch dimension.
+            kwargs[field] = torch.cat(attrs_list, dim=0)
+        else:
+            raise ValueError(f"Field {field} type is not supported for batching")
+
+    return c0.__class__(**kwargs)
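For reference, a minimal usage sketch of the new helper (the shapes and values below are illustrative, not taken from the diff; `join_cameras_as_batch` is re-exported from `pytorch3d.renderer` by the `__init__.py` change above):

```python
import torch
from pytorch3d.renderer import PerspectiveCameras, join_cameras_as_batch

# Two camera batches of different sizes (2 and 3 cameras). All tensor
# attributes are concatenated along the batch dimension; shared fields
# such as in_ndc must agree across the inputs.
cams_a = PerspectiveCameras(R=torch.randn(2, 3, 3), focal_length=1.5)
cams_b = PerspectiveCameras(R=torch.randn(3, 3, 3), focal_length=torch.randn(3, 2))

batch = join_cameras_as_batch([cams_a, cams_b])
print(len(batch))                # 5
print(batch.focal_length.shape)  # torch.Size([5, 2])
```

The result is a new cameras object of the same type; the inputs are not modified.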
"in_ndc" is set as attribute "_in_ndc" on the class + # but provided as "in_ndc" in the input args + if field.startswith("_"): + field = field[1:] + + kwargs[field] = attrs_list[0] + elif isinstance(attrs_list[0], torch.Tensor): + # In the init, all inputs will be converted to + # batched tensors before set as attributes + # Join as a tensor along the batch dimension + kwargs[field] = torch.cat(attrs_list, dim=0) + else: + raise ValueError(f"Field {field} type is not supported for batching") + + return c0.__class__(**kwargs) diff --git a/pytorch3d/renderer/cameras.py b/pytorch3d/renderer/cameras.py index 9860bb7a7..1225381ee 100644 --- a/pytorch3d/renderer/cameras.py +++ b/pytorch3d/renderer/cameras.py @@ -77,7 +77,12 @@ class CamerasBase(TensorProperties): # Used in __getitem__ to index the relevant fields # When creating a new camera, this should be set in the __init__ - _FIELDS: Tuple = () + _FIELDS: Tuple[str, ...] = () + + # Names of fields which are a constant property of the whole batch, rather + # than themselves a batch of data. + # When joining objects into a batch, they will have to agree. + _SHARED_FIELDS: Tuple[str, ...] = () def get_projection_transform(self): """ @@ -499,6 +504,8 @@ class FoVPerspectiveCameras(CamerasBase): "degrees", ) + _SHARED_FIELDS = ("degrees",) + def __init__( self, znear=1.0, @@ -997,6 +1004,8 @@ class PerspectiveCameras(CamerasBase): "image_size", ) + _SHARED_FIELDS = ("_in_ndc",) + def __init__( self, focal_length=1.0, @@ -1047,6 +1056,12 @@ def __init__( else: self.image_size = None + # When focal length is provided as one value, expand to + # create (N, 2) shape tensor + if self.focal_length.ndim == 1: # (N,) + self.focal_length = self.focal_length[:, None] # (N, 1) + self.focal_length = self.focal_length.expand(-1, 2) # (N, 2) + def get_projection_transform(self, **kwargs) -> Transform3d: """ Calculate the projection matrix using the @@ -1227,6 +1242,8 @@ class OrthographicCameras(CamerasBase): "image_size", ) + _SHARED_FIELDS = ("_in_ndc",) + def __init__( self, focal_length=1.0, @@ -1276,6 +1293,12 @@ def __init__( else: self.image_size = None + # When focal length is provided as one value, expand to + # create (N, 2) shape tensor + if self.focal_length.ndim == 1: # (N,) + self.focal_length = self.focal_length[:, None] # (N, 1) + self.focal_length = self.focal_length.expand(-1, 2) # (N, 2) + def get_projection_transform(self, **kwargs) -> Transform3d: """ Calculate the projection matrix using diff --git a/tests/test_camera_pixels.py b/tests/test_camera_pixels.py index 1235439e6..98f2621ec 100644 --- a/tests/test_camera_pixels.py +++ b/tests/test_camera_pixels.py @@ -250,7 +250,4 @@ def test_camera(self): ], dim=1, ) - - print(wanted) - print(camera_points[batch_idx]) self.assertClose(camera_points[batch_idx], wanted) diff --git a/tests/test_cameras.py b/tests/test_cameras.py index a55799e47..b2a97a7cc 100644 --- a/tests/test_cameras.py +++ b/tests/test_cameras.py @@ -36,6 +36,7 @@ import numpy as np import torch from common_testing import TestCaseMixin +from pytorch3d.renderer.camera_utils import join_cameras_as_batch from pytorch3d.renderer.cameras import ( CamerasBase, FoVOrthographicCameras, @@ -688,6 +689,99 @@ def test_clone(self, batch_size: int = 10): else: self.assertTrue(val == val_clone) + def test_join_cameras_as_batch_errors(self): + cam0 = PerspectiveCameras(device="cuda:0") + cam1 = OrthographicCameras(device="cuda:0") + + # Cameras not of the same type + with self.assertRaisesRegex(ValueError, "same type"): + 
diff --git a/tests/test_camera_pixels.py b/tests/test_camera_pixels.py
index 1235439e6..98f2621ec 100644
--- a/tests/test_camera_pixels.py
+++ b/tests/test_camera_pixels.py
@@ -250,7 +250,4 @@ def test_camera(self):
             ],
             dim=1,
         )
-
-        print(wanted)
-        print(camera_points[batch_idx])
         self.assertClose(camera_points[batch_idx], wanted)
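The `test_cameras.py` changes below exercise the new error paths. In particular, `_in_ndc` is declared a shared field, so joining an NDC camera with a screen-space camera raises rather than silently mixing coordinate conventions. A sketch of that failure mode (mirroring the test, on CPU):

```python
from pytorch3d.renderer import OrthographicCameras, join_cameras_as_batch

ndc_cam = OrthographicCameras()                 # in_ndc defaults to True
screen_cam = OrthographicCameras(in_ndc=False)  # screen-space convention

try:
    join_cameras_as_batch([ndc_cam, screen_cam])
except ValueError as err:
    print(err)  # Attribute _in_ndc is not constant across inputs
```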
diff --git a/tests/test_cameras.py b/tests/test_cameras.py
index a55799e47..b2a97a7cc 100644
--- a/tests/test_cameras.py
+++ b/tests/test_cameras.py
@@ -36,6 +36,7 @@
 import numpy as np
 import torch
 from common_testing import TestCaseMixin
+from pytorch3d.renderer.camera_utils import join_cameras_as_batch
 from pytorch3d.renderer.cameras import (
     CamerasBase,
     FoVOrthographicCameras,
@@ -688,6 +689,99 @@ def test_clone(self, batch_size: int = 10):
             else:
                 self.assertTrue(val == val_clone)
 
+    def test_join_cameras_as_batch_errors(self):
+        cam0 = PerspectiveCameras(device="cuda:0")
+        cam1 = OrthographicCameras(device="cuda:0")
+
+        # Cameras not of the same type
+        with self.assertRaisesRegex(ValueError, "same type"):
+            join_cameras_as_batch([cam0, cam1])
+
+        cam2 = OrthographicCameras(device="cpu")
+        # Cameras not on the same device
+        with self.assertRaisesRegex(ValueError, "same device"):
+            join_cameras_as_batch([cam1, cam2])
+
+        cam3 = OrthographicCameras(in_ndc=False, device="cuda:0")
+        # Different coordinate systems -- all should be in NDC or all in screen
+        with self.assertRaisesRegex(
+            ValueError, "Attribute _in_ndc is not constant across inputs"
+        ):
+            join_cameras_as_batch([cam1, cam3])
+
+    def join_cameras_as_batch_fov(self, camera_cls):
+        R0 = torch.randn((6, 3, 3))
+        R1 = torch.randn((3, 3, 3))
+        cam0 = camera_cls(znear=10.0, zfar=100.0, R=R0, device="cuda:0")
+        cam1 = camera_cls(znear=10.0, zfar=200.0, R=R1, device="cuda:0")
+
+        cam_batch = join_cameras_as_batch([cam0, cam1])
+
+        self.assertEqual(cam_batch._N, cam0._N + cam1._N)
+        self.assertEqual(cam_batch.device, cam0.device)
+        self.assertClose(cam_batch.R, torch.cat((R0, R1), dim=0).to(device="cuda:0"))
+
+    def join_cameras_as_batch(self, camera_cls):
+        R0 = torch.randn((6, 3, 3))
+        R1 = torch.randn((3, 3, 3))
+        p0 = torch.randn((6, 2, 1))
+        p1 = torch.randn((3, 2, 1))
+        f0 = 5.0
+        f1 = torch.randn(3, 2)
+        f2 = torch.randn(3, 1)
+        cam0 = camera_cls(
+            R=R0,
+            focal_length=f0,
+            principal_point=p0,
+        )
+        cam1 = camera_cls(
+            R=R1,
+            focal_length=f0,
+            principal_point=p1,
+        )
+        cam2 = camera_cls(
+            R=R1,
+            focal_length=f1,
+            principal_point=p1,
+        )
+        cam3 = camera_cls(
+            R=R1,
+            focal_length=f2,
+            principal_point=p1,
+        )
+        cam_batch = join_cameras_as_batch([cam0, cam1])
+
+        self.assertEqual(cam_batch._N, cam0._N + cam1._N)
+        self.assertEqual(cam_batch.device, cam0.device)
+        self.assertClose(cam_batch.R, torch.cat((R0, R1), dim=0))
+        self.assertClose(cam_batch.principal_point, torch.cat((p0, p1), dim=0))
+        self.assertEqual(cam_batch._in_ndc, cam0._in_ndc)
+
+        # Test one broadcast value and one fixed value:
+        # focal length given as a scalar in one camera and as (N, 2) in the other
+        cam_batch = join_cameras_as_batch([cam0, cam2])
+        self.assertEqual(cam_batch._N, cam0._N + cam2._N)
+        self.assertClose(cam_batch.R, torch.cat((R0, R1), dim=0))
+        self.assertClose(
+            cam_batch.focal_length,
+            torch.cat([torch.tensor([[f0, f0]]).expand(6, -1), f1], dim=0),
+        )
+
+        # Focal length as (N, 1) in one camera and (N, 2) in the other
+        cam_batch = join_cameras_as_batch([cam2, cam3])
+        self.assertClose(
+            cam_batch.focal_length,
+            torch.cat([f1, f2.expand(-1, 2)], dim=0),
+        )
+
+    def test_join_batch_perspective(self):
+        self.join_cameras_as_batch_fov(FoVPerspectiveCameras)
+        self.join_cameras_as_batch(PerspectiveCameras)
+
+    def test_join_batch_orthographic(self):
+        self.join_cameras_as_batch_fov(FoVOrthographicCameras)
+        self.join_cameras_as_batch(OrthographicCameras)
+
 
 ############################################################
 # FoVPerspective Camera                                    #
@@ -1055,7 +1149,7 @@ def test_getitem(self):
         index = torch.tensor([1, 3, 5], dtype=torch.int64)
         c135 = cam[index]
         self.assertEqual(len(c135), 3)
-        self.assertClose(c135.focal_length, torch.tensor([5.0] * 3))
+        self.assertClose(c135.focal_length, torch.tensor([[5.0, 5.0]] * 3))
         self.assertClose(c135.R, R_matrix[[1, 3, 5], ...])
         self.assertClose(c135.principal_point, principal_point[[1, 3, 5], ...])
 
@@ -1131,7 +1225,7 @@ def test_getitem(self):
         index = torch.tensor([1, 3, 5], dtype=torch.int64)
         c135 = cam[index]
         self.assertEqual(len(c135), 3)
-        self.assertClose(c135.focal_length, torch.tensor([5.0] * 3))
+        self.assertClose(c135.focal_length, torch.tensor([[5.0, 5.0]] * 3))
         self.assertClose(c135.R, R_matrix[[1, 3, 5], ...])
         self.assertClose(c135.principal_point, principal_point[[1, 3, 5], ...])
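One knock-on effect of the `(N, 2)` expansion, visible in the two `test_getitem` updates above: a camera built with a scalar focal length now yields a two-column tensor when indexed. A small sketch (shapes assumed, not from the diff):

```python
import torch
from pytorch3d.renderer import PerspectiveCameras

cam = PerspectiveCameras(focal_length=5.0, R=torch.randn(10, 3, 3))
sub = cam[torch.tensor([1, 3, 5])]
# Previously this was shape (3,); the init-time expansion makes it (3, 2).
print(sub.focal_length.shape)  # torch.Size([3, 2])
```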