Skip to content

Commit

Permalink
Update Rotate transform to use device of input rotation
Browse files Browse the repository at this point in the history
Summary: Currently the Rotate transform does not consider R's device at all, so it silently defaults to CPU and raises errors when callers expect the transform to be on CUDA. This updates the transform to respect R's device.

Reviewed By: nikhilaravi

Differential Revision: D27828118

fbshipit-source-id: ddd99f73eadbd990688eb22f3d1ffbacbe168c81
  • Loading branch information
theschnitz authored and facebook-github-bot committed May 19, 2021
1 parent c9dea62 commit cd5af25
Showing 1 changed file with 28 additions and 10 deletions.
38 changes: 28 additions & 10 deletions pytorch3d/transforms/transform3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,7 @@ def cuda(self):


class Translate(Transform3d):
def __init__(self, x, y=None, z=None, dtype=torch.float32, device="cpu"):
def __init__(self, x, y=None, z=None, dtype=torch.float32, device=None):
"""
Create a new Transform3d representing 3D translations.
Expand All @@ -448,11 +448,11 @@ def __init__(self, x, y=None, z=None, dtype=torch.float32, device="cpu"):
- A torch scalar
- A 1D torch tensor
"""
super().__init__(device=device)
xyz = _handle_input(x, y, z, dtype, device, "Translate")
super().__init__(device=xyz.device)
N = xyz.shape[0]

mat = torch.eye(4, dtype=dtype, device=device)
mat = torch.eye(4, dtype=dtype, device=self.device)
mat = mat.view(1, 4, 4).repeat(N, 1, 1)
mat[:, 3, :3] = xyz
self._matrix = mat
Expand All @@ -468,7 +468,7 @@ def _get_matrix_inverse(self):


class Scale(Transform3d):
def __init__(self, x, y=None, z=None, dtype=torch.float32, device="cpu"):
def __init__(self, x, y=None, z=None, dtype=torch.float32, device=None):
"""
A Transform3d representing a scaling operation, with different scale
factors along each coordinate axis.
Expand All @@ -485,12 +485,12 @@ def __init__(self, x, y=None, z=None, dtype=torch.float32, device="cpu"):
- torch scalar
- 1D torch tensor
"""
super().__init__(device=device)
xyz = _handle_input(x, y, z, dtype, device, "scale", allow_singleton=True)
super().__init__(device=xyz.device)
N = xyz.shape[0]

# TODO: Can we do this all in one go somehow?
mat = torch.eye(4, dtype=dtype, device=device)
mat = torch.eye(4, dtype=dtype, device=self.device)
mat = mat.view(1, 4, 4).repeat(N, 1, 1)
mat[:, 0, 0] = xyz[:, 0]
mat[:, 1, 1] = xyz[:, 1]
Expand All @@ -509,7 +509,7 @@ def _get_matrix_inverse(self):

class Rotate(Transform3d):
def __init__(
self, R, dtype=torch.float32, device="cpu", orthogonal_tol: float = 1e-5
self, R, dtype=torch.float32, device=None, orthogonal_tol: float = 1e-5
):
"""
Create a new Transform3d representing 3D rotation using a rotation
Expand All @@ -520,6 +520,7 @@ def __init__(
orthogonal_tol: tolerance for the test of the orthogonality of R
"""
device = _get_device(R, device)
super().__init__(device=device)
if R.dim() == 2:
R = R[None]
Expand Down Expand Up @@ -548,7 +549,7 @@ def __init__(
axis: str = "X",
degrees: bool = True,
dtype=torch.float64,
device="cpu",
device=None,
):
"""
Create a new Transform3d representing 3D rotation about an axis
Expand Down Expand Up @@ -578,7 +579,7 @@ def __init__(
# is for transforming column vectors. Therefore we transpose this matrix.
# R will always be of shape (N, 3, 3)
R = _axis_angle_rotation(axis, angle).transpose(1, 2)
super().__init__(device=device, R=R)
super().__init__(device=angle.device, R=R)


def _handle_coord(c, dtype, device):
Expand All @@ -595,9 +596,24 @@ def _handle_coord(c, dtype, device):
c = torch.tensor(c, dtype=dtype, device=device)
if c.dim() == 0:
c = c.view(1)
if c.device != device:
c = c.to(device=device)
return c


def _get_device(x, device=None):
if device is not None:
# User overriding device, leave
device = device
elif torch.is_tensor(x):
# Set device based on input tensor
device = x.device
else:
# Default device is cpu
device = "cpu"
return device


def _handle_input(x, y, z, dtype, device, name: str, allow_singleton: bool = False):
"""
Helper function to handle parsing logic for building transforms. The output
Expand Down Expand Up @@ -626,6 +642,7 @@ def _handle_input(x, y, z, dtype, device, name: str, allow_singleton: bool = Fal
Returns:
xyz: Tensor of shape (N, 3)
"""
device = _get_device(x, device)
# If x is actually a tensor of shape (N, 3) then just return it
if torch.is_tensor(x) and x.dim() == 2:
if x.shape[1] != 3:
Expand All @@ -634,7 +651,7 @@ def _handle_input(x, y, z, dtype, device, name: str, allow_singleton: bool = Fal
if y is not None or z is not None:
msg = "Expected y and z to be None (in %s)" % name
raise ValueError(msg)
return x
return x.to(device=device)

if allow_singleton and y is None and z is None:
y = x
Expand Down Expand Up @@ -665,6 +682,7 @@ def _handle_angle_input(x, dtype, device, name: str):
- Python scalar
- Torch scalar
"""
device = _get_device(x, device)
if torch.is_tensor(x) and x.dim() > 1:
msg = "Expected tensor of shape (N,); got %r (in %s)"
raise ValueError(msg % (x.shape, name))
Expand Down

0 comments on commit cd5af25

Please sign in to comment.