Skip to content

Commit

Permalink
Add PyTorch3D->OpenCV camera parameter conversion.
Browse files Browse the repository at this point in the history
Summary: This diff implements the inverse of D28992470 (8006842): a function to extract OpenCV convention camera parameters from a PyTorch3D `PerspectiveCameras` object. This is the first part of the new PyTorch3d<>OpenCV<>Pulsar conversion functions.

Reviewed By: patricklabatut

Differential Revision: D29278411

fbshipit-source-id: 68d4555b508dbe8685d8239443f839d194cc2484
  • Loading branch information
classner authored and facebook-github-bot committed Jun 23, 2021
1 parent e4039aa commit da9974b
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 3 deletions.
5 changes: 4 additions & 1 deletion pytorch3d/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from .camera_conversions import cameras_from_opencv_projection
from .camera_conversions import (
cameras_from_opencv_projection,
opencv_from_cameras_projection,
)
from .ico_sphere import ico_sphere
from .torus import torus

Expand Down
54 changes: 52 additions & 2 deletions pytorch3d/utils/camera_conversions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Tuple

import torch

from ..renderer import PerspectiveCameras
from ..transforms import so3_exponential_map
from ..transforms import so3_exponential_map, so3_log_map


def cameras_from_opencv_projection(
Expand Down Expand Up @@ -35,7 +37,7 @@ def cameras_from_opencv_projection(
followed by the homogenization of `x_screen_opencv`.
Note:
The parameters `rvec, tvec, camera_matrix` correspond e.g. to the inputs
The parameters `rvec, tvec, camera_matrix` correspond, e.g., to the inputs
of `cv2.projectPoints`, or to the ouputs of `cv2.calibrateCamera`.
Args:
Expand Down Expand Up @@ -74,3 +76,51 @@ def cameras_from_opencv_projection(
focal_length=focal_pytorch3d,
principal_point=p0_pytorch3d,
)


def opencv_from_cameras_projection(
cameras: PerspectiveCameras,
image_size: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Converts a batch of `PerspectiveCameras` into OpenCV-convention
axis-angle rotation vectors `rvec`, translation vectors `tvec`, and the camera
calibration matrices `camera_matrix`. This operation is exactly the inverse
of `cameras_from_opencv_projection`.
Note:
The parameters `rvec, tvec, camera_matrix` correspond, e.g., to the inputs
of `cv2.projectPoints`, or to the ouputs of `cv2.calibrateCamera`.
Args:
cameras: A batch of `N` cameras in the PyTorch3D convention.
image_size: A tensor of shape `(N, 2)` containing the sizes of the images
(height, width) attached to each camera.
Returns:
rvec: A batch of axis-angle rotation vectors of shape `(N, 3)`.
tvec: A batch of translation vectors of shape `(N, 3)`.
camera_matrix: A batch of camera calibration matrices of shape `(N, 3, 3)`.
"""
R_pytorch3d = cameras.R
T_pytorch3d = cameras.T
focal_pytorch3d = cameras.focal_length
p0_pytorch3d = cameras.principal_point
T_pytorch3d[:, :2] *= -1 # pyre-ignore
R_pytorch3d[:, :, :2] *= -1 # pyre-ignore
tvec = T_pytorch3d.clone() # pyre-ignore
R = R_pytorch3d.permute(0, 2, 1) # pyre-ignore

# Retype the image_size correctly and flip to width, height.
image_size_wh = image_size.to(R).flip(dims=(1,))

principal_point = (-p0_pytorch3d + 1.0) * (0.5 * image_size_wh) # pyre-ignore
focal_length = focal_pytorch3d * (0.5 * image_size_wh)

camera_matrix = torch.zeros_like(R)
camera_matrix[:, :2, 2] = principal_point
camera_matrix[:, 2, 2] = 1.0
camera_matrix[:, 0, 0] = focal_length[:, 0]
camera_matrix[:, 1, 1] = focal_length[:, 1]
rvec = so3_log_map(R)
return rvec, tvec, camera_matrix
9 changes: 9 additions & 0 deletions tests/test_camera_conversions.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from pytorch3d.transforms import so3_exponential_map, so3_log_map
from pytorch3d.utils import (
cameras_from_opencv_projection,
opencv_from_cameras_projection,
)

DATA_DIR = get_tests_dir() / "data"
Expand Down Expand Up @@ -151,3 +152,11 @@ def test_opencv_conversion(self):
self.assertClose(
pts_proj_opencv_in_pytorch3d_screen, pts_proj_pytorch3d, atol=1e-5
)

# Check the inverse.
rvec_i, tvec_i, camera_matrix_i = opencv_from_cameras_projection(
cameras_opencv_to_pytorch3d, image_size
)
self.assertClose(rvec, rvec_i)
self.assertClose(tvec, tvec_i)
self.assertClose(camera_matrix, camera_matrix_i)

0 comments on commit da9974b

Please sign in to comment.