diff --git a/pytorch3d/renderer/cameras.py b/pytorch3d/renderer/cameras.py
index 70e0f1ce8..09cdc72ba 100644
--- a/pytorch3d/renderer/cameras.py
+++ b/pytorch3d/renderer/cameras.py
@@ -75,6 +75,10 @@ class CamerasBase(TensorProperties):
     boolean argument of the function.
     """
 
+    # Used in __getitem__ to index the relevant fields
+    # Subclasses should override this with the names of their indexable fields
+    _FIELDS: Tuple = ()
+
     def get_projection_transform(self):
         """
         Calculate the projective transformation matrix.
@@ -362,6 +366,55 @@ def get_image_size(self):
         """
         return self.image_size if hasattr(self, "image_size") else None
 
+    def __getitem__(
+        self, index: Union[int, List[int], torch.LongTensor]
+    ) -> "CamerasBase":
+        """
+        Override for the __getitem__ method in TensorProperties which needs to be
+        refactored. Selects cameras from the batch and returns them as a new instance.
+
+        Args:
+            index: an int, a non-empty list of ints or a 1D int64 tensor used to index
+                all the fields in self._FIELDS. Negative values count from the end.
+        Returns:
+            an instance of the current cameras class with only the selected cameras.
+        Raises:
+            ValueError: if the type of `index` is not supported or it is out of bounds.
+        """
+        is_long = isinstance(index, torch.Tensor) and index.dtype == torch.int64
+        if isinstance(index, bool) or not (is_long or isinstance(index, (int, list))):
+            msg = "Invalid index type, expected int, List[int] or torch.LongTensor; got %r"
+            raise ValueError(msg % type(index))
+        # bool is a subclass of int, so it has to be rejected explicitly
+        if isinstance(index, list) and not all(type(i) is int for i in index):
+            raise ValueError("Invalid index type, expected a list of ints")
+        # An int (0-dim) index selects a single camera, i.e. a batch of size one
+        idx = torch.as_tensor(index, dtype=torch.int64)
+        if idx.ndim > 1 or idx.numel() == 0:
+            raise ValueError("Invalid index shape, expected a non-empty 1D index")
+        idx = idx.reshape(-1)
+        # Negative values are valid down to -len(self), as for python sequences
+        bad = idx[(idx >= len(self)) | (idx < -len(self))]
+        if bad.numel() > 0:
+            raise ValueError(f"Index {int(bad[0])} is out of bounds for select cameras")
+
+        kwargs = {"device": self.device}
+        for field in self._FIELDS:
+            val = getattr(self, field, None)
+            if val is None:
+                continue
+            # e.g. "in_ndc" is set as attribute "_in_ndc" on the class
+            # but provided as "in_ndc" on initialization
+            name = field[1:] if field.startswith("_") else field
+            if isinstance(val, (str, bool)):
+                kwargs[name] = val
+            elif isinstance(val, torch.Tensor):
+                # The field may live on a different device than `index`
+                kwargs[name] = val[idx.to(val.device)]
+            else:
+                raise ValueError(f"Field {name} type is not supported for indexing")
+        return self.__class__(**kwargs)
+
 
 ############################################################
 #             Field of View Camera Classes                 #
@@ -434,6 +487,18 @@ class FoVPerspectiveCameras(CamerasBase):
     for rasterization.
     """
 
+    # Names of the attributes re-indexed by __getitem__ and passed back to __init__
+    _FIELDS = (
+        "K",
+        "znear",
+        "zfar",
+        "aspect_ratio",
+        "fov",
+        "R",
+        "T",
+        "degrees",
+    )
+
     def __init__(
         self,
         znear=1.0,
@@ -590,7 +655,7 @@ def unproject_points(
         xy_depth: torch.Tensor,
         world_coordinates: bool = True,
         scaled_depth_input: bool = False,
-        **kwargs
+        **kwargs,
     ) -> torch.Tensor:
         """>!
         FoV cameras further allow for passing depth in world units
@@ -681,6 +746,20 @@ class FoVOrthographicCameras(CamerasBase):
     The definition of the parameters follow the OpenGL orthographic camera.
     """
 
+    # Names of the attributes re-indexed by __getitem__ and passed back to __init__
+    _FIELDS = (
+        "K",
+        "znear",
+        "zfar",
+        "R",
+        "T",
+        "max_y",
+        "min_y",
+        "max_x",
+        "min_x",
+        "scale_xyz",
+    )
+
     def __init__(
         self,
         znear=1.0,
@@ -819,7 +898,7 @@ def unproject_points(
         xy_depth: torch.Tensor,
         world_coordinates: bool = True,
         scaled_depth_input: bool = False,
-        **kwargs
+        **kwargs,
     ) -> torch.Tensor:
         """>!
         FoV cameras further allow for passing depth in world units
@@ -907,6 +986,17 @@ class PerspectiveCameras(CamerasBase):
     If parameters are specified in screen space, `in_ndc` must be set to False.
     """
 
+    # Names of the attributes re-indexed by __getitem__ and passed back to __init__
+    _FIELDS = (
+        "K",
+        "R",
+        "T",
+        "focal_length",
+        "principal_point",
+        "_in_ndc",  # arg is in_ndc but attribute set as _in_ndc
+        "image_size",
+    )
+
     def __init__(
         self,
         focal_length=1.0,
@@ -1007,7 +1097,7 @@ def unproject_points(
         xy_depth: torch.Tensor,
         world_coordinates: bool = True,
         from_ndc: bool = False,
-        **kwargs
+        **kwargs,
     ) -> torch.Tensor:
         """
         Args:
@@ -1126,6 +1216,17 @@ class OrthographicCameras(CamerasBase):
     If parameters are specified in screen space, `in_ndc` must be set to False.
     """
 
+    # Names of the attributes re-indexed by __getitem__ and passed back to __init__
+    _FIELDS = (
+        "K",
+        "R",
+        "T",
+        "focal_length",
+        "principal_point",
+        "_in_ndc",  # arg is in_ndc but attribute set as _in_ndc
+        "image_size",
+    )
+
     def __init__(
         self,
         focal_length=1.0,
@@ -1225,7 +1326,7 @@ def unproject_points(
         xy_depth: torch.Tensor,
         world_coordinates: bool = True,
         from_ndc: bool = False,
-        **kwargs
+        **kwargs,
     ) -> torch.Tensor:
         """
         Args:
diff --git a/pytorch3d/renderer/utils.py b/pytorch3d/renderer/utils.py
index fd99e8a43..33e5359d2 100644
--- a/pytorch3d/renderer/utils.py
+++ b/pytorch3d/renderer/utils.py
@@ -155,7 +155,7 @@ def __getitem__(self, index: Union[int, slice]) -> TensorAccessor:
         Returns:
             if `index` is an index int/slice return a TensorAccessor class
             with getattribute/setattribute methods which return/update the value
-            at the index in the original camera.
+            at the index in the original class.
""" if isinstance(index, (int, slice)): return TensorAccessor(class_object=self, index=index) diff --git a/tests/test_cameras.py b/tests/test_cameras.py index 15fd355a1..d2e646481 100644 --- a/tests/test_cameras.py +++ b/tests/test_cameras.py @@ -783,18 +783,53 @@ def test_camera_class_init(self): self.assertTrue(cam.znear.shape == (2,)) self.assertTrue(cam.zfar.shape == (2,)) - # update znear element 1 - cam[1].znear = 20.0 - self.assertTrue(cam.znear[1] == 20.0) - - # Get item and get value - c0 = cam[0] - self.assertTrue(c0.zfar == 100.0) - # Test to new_cam = cam.to(device=device) self.assertTrue(new_cam.device == device) + def test_getitem(self): + R_matrix = torch.randn((6, 3, 3)) + cam = FoVPerspectiveCameras(znear=10.0, zfar=100.0, R=R_matrix) + + # Check get item returns an instance of the same class + # with all the same keys + c0 = cam[0] + self.assertTrue(isinstance(c0, FoVPerspectiveCameras)) + self.assertEqual(cam.__dict__.keys(), c0.__dict__.keys()) + + # Check all fields correct in get item with int index + self.assertEqual(len(c0), 1) + self.assertClose(c0.zfar, torch.tensor([100.0])) + self.assertClose(c0.znear, torch.tensor([10.0])) + self.assertClose(c0.R, R_matrix[0:1, ...]) + self.assertEqual(c0.device, torch.device("cpu")) + + # Check list(int) index + c012 = cam[[0, 1, 2]] + self.assertEqual(len(c012), 3) + self.assertClose(c012.zfar, torch.tensor([100.0] * 3)) + self.assertClose(c012.znear, torch.tensor([10.0] * 3)) + self.assertClose(c012.R, R_matrix[0:3, ...]) + + # Check torch.LongTensor index + index = torch.tensor([1, 3, 5], dtype=torch.int64) + c135 = cam[index] + self.assertEqual(len(c135), 3) + self.assertClose(c135.zfar, torch.tensor([100.0] * 3)) + self.assertClose(c135.znear, torch.tensor([10.0] * 3)) + self.assertClose(c135.R, R_matrix[[1, 3, 5], ...]) + + # Check errors with get item + with self.assertRaisesRegex(ValueError, "out of bounds"): + cam[6] + + with self.assertRaisesRegex(ValueError, "Invalid index type"): + 
cam[slice(0, 1)] + + with self.assertRaisesRegex(ValueError, "Invalid index type"): + index = torch.tensor([1, 3, 5], dtype=torch.float32) + cam[index] + def test_get_full_transform(self): cam = FoVPerspectiveCameras() T = torch.tensor([0.0, 0.0, 1.0]).view(1, -1) @@ -919,6 +954,30 @@ def test_perspective_type(self): self.assertFalse(cam.is_perspective()) self.assertEqual(cam.get_znear(), 1.0) + def test_getitem(self): + R_matrix = torch.randn((6, 3, 3)) + scale = torch.tensor([[1.0, 1.0, 1.0]], requires_grad=True) + cam = FoVOrthographicCameras( + znear=10.0, zfar=100.0, R=R_matrix, scale_xyz=scale + ) + + # Check get item returns an instance of the same class + # with all the same keys + c0 = cam[0] + self.assertTrue(isinstance(c0, FoVOrthographicCameras)) + self.assertEqual(cam.__dict__.keys(), c0.__dict__.keys()) + + # Check torch.LongTensor index + index = torch.tensor([1, 3, 5], dtype=torch.int64) + c135 = cam[index] + self.assertEqual(len(c135), 3) + self.assertClose(c135.zfar, torch.tensor([100.0] * 3)) + self.assertClose(c135.znear, torch.tensor([10.0] * 3)) + self.assertClose(c135.min_x, torch.tensor([-1.0] * 3)) + self.assertClose(c135.max_x, torch.tensor([1.0] * 3)) + self.assertClose(c135.R, R_matrix[[1, 3, 5], ...]) + self.assertClose(c135.scale_xyz, scale.expand(3, -1)) + ############################################################ # Orthographic Camera # @@ -976,6 +1035,30 @@ def test_perspective_type(self): self.assertFalse(cam.is_perspective()) self.assertIsNone(cam.get_znear()) + def test_getitem(self): + R_matrix = torch.randn((6, 3, 3)) + principal_point = torch.randn((6, 2, 1)) + focal_length = 5.0 + cam = OrthographicCameras( + R=R_matrix, + focal_length=focal_length, + principal_point=principal_point, + ) + + # Check get item returns an instance of the same class + # with all the same keys + c0 = cam[0] + self.assertTrue(isinstance(c0, OrthographicCameras)) + self.assertEqual(cam.__dict__.keys(), c0.__dict__.keys()) + + # Check 
torch.LongTensor index + index = torch.tensor([1, 3, 5], dtype=torch.int64) + c135 = cam[index] + self.assertEqual(len(c135), 3) + self.assertClose(c135.focal_length, torch.tensor([5.0] * 3)) + self.assertClose(c135.R, R_matrix[[1, 3, 5], ...]) + self.assertClose(c135.principal_point, principal_point[[1, 3, 5], ...]) + ############################################################ # Perspective Camera # @@ -1027,3 +1110,30 @@ def test_perspective_type(self): cam = PerspectiveCameras(focal_length=5.0, principal_point=((2.5, 2.5),)) self.assertTrue(cam.is_perspective()) self.assertIsNone(cam.get_znear()) + + def test_getitem(self): + R_matrix = torch.randn((6, 3, 3)) + principal_point = torch.randn((6, 2, 1)) + focal_length = 5.0 + cam = PerspectiveCameras( + R=R_matrix, + focal_length=focal_length, + principal_point=principal_point, + ) + + # Check get item returns an instance of the same class + # with all the same keys + c0 = cam[0] + self.assertTrue(isinstance(c0, PerspectiveCameras)) + self.assertEqual(cam.__dict__.keys(), c0.__dict__.keys()) + + # Check torch.LongTensor index + index = torch.tensor([1, 3, 5], dtype=torch.int64) + c135 = cam[index] + self.assertEqual(len(c135), 3) + self.assertClose(c135.focal_length, torch.tensor([5.0] * 3)) + self.assertClose(c135.R, R_matrix[[1, 3, 5], ...]) + self.assertClose(c135.principal_point, principal_point[[1, 3, 5], ...]) + + # Check in_ndc is handled correctly + self.assertEqual(cam._in_ndc, c0._in_ndc)