Skip to content

Commit

Permalink
Remove pytorch3d's wrappers for eigh, solve, lstsq, qr
Browse files Browse the repository at this point in the history
Summary: Remove the compat functions eigh, solve, lstsq, and qr. Migrate callers to use torch.linalg directly.

Reviewed By: bottler

Differential Revision: D39172949

fbshipit-source-id: 484230a553237808f06ee5cdfde64651cba91c4c
  • Loading branch information
Chris Lambert authored and facebook-github-bot committed Aug 31, 2022
1 parent 9a1213e commit d4a1051
Show file tree
Hide file tree
Showing 9 changed files with 11 additions and 67 deletions.
47 changes: 0 additions & 47 deletions pytorch3d/common/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,53 +14,6 @@
"""


def solve(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: # pragma: no cover
"""
Like torch.linalg.solve, tries to return X
such that AX=B, with A square.
"""
if hasattr(torch, "linalg") and hasattr(torch.linalg, "solve"):
# PyTorch version >= 1.8.0
return torch.linalg.solve(A, B)

# pyre-fixme[16]: `Tuple` has no attribute `solution`.
return torch.solve(B, A).solution


def lstsq(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: # pragma: no cover
"""
Like torch.linalg.lstsq, tries to return X
such that AX=B.
"""
if hasattr(torch, "linalg") and hasattr(torch.linalg, "lstsq"):
# PyTorch version >= 1.9
return torch.linalg.lstsq(A, B).solution

solution = torch.lstsq(B, A).solution
if A.shape[1] < A.shape[0]:
return solution[: A.shape[1]]
return solution


def qr(A: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: # pragma: no cover
"""
Like torch.linalg.qr.
"""
if hasattr(torch, "linalg") and hasattr(torch.linalg, "qr"):
# PyTorch version >= 1.9
return torch.linalg.qr(A)
return torch.qr(A)


def eigh(A: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: # pragma: no cover
"""
Like torch.linalg.eigh, assuming the argument is a symmetric real matrix.
"""
if hasattr(torch, "linalg") and hasattr(torch.linalg, "eigh"):
return torch.linalg.eigh(A)
return torch.symeig(A, eigenvectors=True)


def meshgrid_ij(
*A: Union[torch.Tensor, Sequence[torch.Tensor]]
) -> Tuple[torch.Tensor, ...]: # pragma: no cover
Expand Down
6 changes: 2 additions & 4 deletions pytorch3d/implicitron/tools/circle_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from typing import Optional

import torch
from pytorch3d.common.compat import eigh, lstsq


def _get_rotation_to_best_fit_xy(
Expand All @@ -28,7 +27,7 @@ def _get_rotation_to_best_fit_xy(
(3,3) tensor rotation matrix
"""
points_centered = points - centroid[None]
return eigh(points_centered.t() @ points_centered)[1][:, [1, 2, 0]]
return torch.linalg.eigh(points_centered.t() @ points_centered)[1][:, [1, 2, 0]]


def _signed_area(path: torch.Tensor) -> torch.Tensor:
Expand Down Expand Up @@ -106,9 +105,8 @@ def fit_circle_in_2d(
n_provided = points2d.shape[0]
if n_provided < 3:
raise ValueError(f"{n_provided} points are not enough to determine a circle")
solution = lstsq(design, rhs[:, None])
solution = torch.linalg.lstsq(design, rhs[:, None]).solution
center = solution[:2, 0] / 2
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
radius = torch.sqrt(solution[2, 0] + (center**2).sum())
if n_points > 0:
if angles is not None:
Expand Down
3 changes: 1 addition & 2 deletions pytorch3d/implicitron/tools/eval_video_trajectory.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from typing import Optional, Tuple

import torch
from pytorch3d.common.compat import eigh
from pytorch3d.implicitron.tools import utils
from pytorch3d.implicitron.tools.circle_fitting import fit_circle_in_3d
from pytorch3d.renderer import look_at_view_transform, PerspectiveCameras
Expand Down Expand Up @@ -205,7 +204,7 @@ def _disambiguate_normal(normal, up):
def _fit_plane(x):
x = x - x.mean(dim=0)[None]
cov = (x.t() @ x) / x.shape[0]
_, e_vec = eigh(cov)
_, e_vec = torch.linalg.eigh(cov)
return e_vec


Expand Down
3 changes: 1 addition & 2 deletions pytorch3d/ops/perspective_n_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

import torch
import torch.nn.functional as F
from pytorch3d.common.compat import eigh
from pytorch3d.ops import points_alignment, utils as oputil


Expand Down Expand Up @@ -106,7 +105,7 @@ def _null_space(m, kernel_dim):
kernel vectors, of size B x kernel_dim
"""
mTm = torch.bmm(m.transpose(1, 2), m)
s, v = eigh(mTm)
s, v = torch.linalg.eigh(mTm)
return v[:, :, :kernel_dim].reshape(-1, 4, 3, kernel_dim), s[:, :kernel_dim]


Expand Down
3 changes: 1 addition & 2 deletions pytorch3d/ops/points_normals.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from typing import Tuple, TYPE_CHECKING, Union

import torch
from pytorch3d.common.compat import eigh
from pytorch3d.common.workaround import symeig3x3

from .utils import convert_pointclouds_to_tensor, get_point_covariances
Expand Down Expand Up @@ -147,7 +146,7 @@ def estimate_pointcloud_local_coord_frames(
if use_symeig_workaround:
curvatures, local_coord_frames = symeig3x3(cov, eigenvectors=True)
else:
curvatures, local_coord_frames = eigh(cov)
curvatures, local_coord_frames = torch.linalg.eigh(cov)

# disambiguate the directions of individual principal vectors
if disambiguate_directions:
Expand Down
3 changes: 1 addition & 2 deletions pytorch3d/transforms/se3.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
# LICENSE file in the root directory of this source tree.

import torch
from pytorch3d.common.compat import solve

from .so3 import _so3_exp_map, hat, so3_log_map

Expand Down Expand Up @@ -174,7 +173,7 @@ def se3_log_map(
# log_translation is V^-1 @ T
T = transform[:, 3, :3]
V = _se3_V_matrix(*_get_se3_V_input(log_rotation), eps=eps)
log_translation = solve(V, T[:, :, None])[:, :, 0]
log_translation = torch.linalg.solve(V, T[:, :, None])[:, :, 0]

return torch.cat((log_translation, log_rotation), dim=1)

Expand Down
3 changes: 1 addition & 2 deletions tests/test_acos_linear_extrapolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

import numpy as np
import torch
from pytorch3d.common.compat import lstsq
from pytorch3d.transforms import acos_linear_extrapolation

from .common_testing import TestCaseMixin
Expand Down Expand Up @@ -66,7 +65,7 @@ def _test_acos_outside_bounds(self, x, y, dydx, bound):
bound_t = torch.tensor(bound, device=x.device, dtype=x.dtype)
# fit a line: slope * x + bias = y
x_1 = torch.stack([x, torch.ones_like(x)], dim=-1)
slope, bias = lstsq(x_1, y[:, None]).view(-1)[:2]
slope, bias = torch.linalg.lstsq(x_1, y[:, None]).solution.view(-1)[:2]
desired_slope = (-1.0) / torch.sqrt(1.0 - bound_t**2)
# test that the desired slope is the same as the fitted one
self.assertClose(desired_slope.view(1), slope.view(1), atol=1e-2)
Expand Down
3 changes: 1 addition & 2 deletions tests/test_se3.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

import numpy as np
import torch
from pytorch3d.common.compat import qr
from pytorch3d.transforms.rotation_conversions import random_rotations
from pytorch3d.transforms.se3 import se3_exp_map, se3_log_map
from pytorch3d.transforms.so3 import so3_exp_map, so3_log_map, so3_rotation_angle
Expand Down Expand Up @@ -199,7 +198,7 @@ def test_se3_log_singularity(self, batch_size: int = 100):
r = [identity, rot180]
r.extend(
[
qr(identity + torch.randn_like(identity) * 1e-6)[0]
torch.linalg.qr(identity + torch.randn_like(identity) * 1e-6)[0]
+ float(i > batch_size // 2) * (0.5 - torch.rand_like(identity)) * 1e-8
# this adds random noise to the second half
# of the random orthogonal matrices to generate
Expand Down
7 changes: 3 additions & 4 deletions tests/test_so3.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@

import numpy as np
import torch
from pytorch3d.common.compat import qr
from pytorch3d.transforms.so3 import (
hat,
so3_exp_map,
Expand Down Expand Up @@ -49,7 +48,7 @@ def init_rot(batch_size: int = 10):
# TODO(dnovotny): replace with random_rotation from random_rotation.py
rot = []
for _ in range(batch_size):
r = qr(torch.randn((3, 3), device=device))[0]
r = torch.linalg.qr(torch.randn((3, 3), device=device))[0]
f = torch.randint(2, (3,), device=device, dtype=torch.float32)
if f.sum() % 2 == 0:
f = 1 - f
Expand Down Expand Up @@ -145,7 +144,7 @@ def test_so3_log_singularity(self, batch_size: int = 100):
# add random rotations and random almost orthonormal matrices
r.extend(
[
qr(identity + torch.randn_like(identity) * 1e-4)[0]
torch.linalg.qr(identity + torch.randn_like(identity) * 1e-4)[0]
+ float(i > batch_size // 2) * (0.5 - torch.rand_like(identity)) * 1e-3
# this adds random noise to the second half
# of the random orthogonal matrices to generate
Expand Down Expand Up @@ -245,7 +244,7 @@ def test_so3_cos_bound(self, batch_size: int = 100):
r = [identity, rot180]
r.extend(
[
qr(identity + torch.randn_like(identity) * 1e-4)[0]
torch.linalg.qr(identity + torch.randn_like(identity) * 1e-4)[0]
for _ in range(batch_size - 2)
]
)
Expand Down

0 comments on commit d4a1051

Please sign in to comment.