From d4a1051e0f9f93a2052487cb42e34fe6e1987f84 Mon Sep 17 00:00:00 2001 From: Chris Lambert Date: Wed, 31 Aug 2022 13:04:07 -0700 Subject: [PATCH] Remove pytorch3d's wrappers for eigh, solve, lstsq, qr Summary: Remove the compat functions eigh, solve, lstsq, and qr. Migrate callers to use torch.linalg directly. Reviewed By: bottler Differential Revision: D39172949 fbshipit-source-id: 484230a553237808f06ee5cdfde64651cba91c4c --- pytorch3d/common/compat.py | 47 ------------------- pytorch3d/implicitron/tools/circle_fitting.py | 6 +-- .../tools/eval_video_trajectory.py | 3 +- pytorch3d/ops/perspective_n_points.py | 3 +- pytorch3d/ops/points_normals.py | 3 +- pytorch3d/transforms/se3.py | 3 +- tests/test_acos_linear_extrapolation.py | 3 +- tests/test_se3.py | 3 +- tests/test_so3.py | 7 ++- 9 files changed, 11 insertions(+), 67 deletions(-) diff --git a/pytorch3d/common/compat.py b/pytorch3d/common/compat.py index 5cb0d5c63..5c155f12f 100644 --- a/pytorch3d/common/compat.py +++ b/pytorch3d/common/compat.py @@ -14,53 +14,6 @@ """ -def solve(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: # pragma: no cover - """ - Like torch.linalg.solve, tries to return X - such that AX=B, with A square. - """ - if hasattr(torch, "linalg") and hasattr(torch.linalg, "solve"): - # PyTorch version >= 1.8.0 - return torch.linalg.solve(A, B) - - # pyre-fixme[16]: `Tuple` has no attribute `solution`. - return torch.solve(B, A).solution - - -def lstsq(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: # pragma: no cover - """ - Like torch.linalg.lstsq, tries to return X - such that AX=B. - """ - if hasattr(torch, "linalg") and hasattr(torch.linalg, "lstsq"): - # PyTorch version >= 1.9 - return torch.linalg.lstsq(A, B).solution - - solution = torch.lstsq(B, A).solution - if A.shape[1] < A.shape[0]: - return solution[: A.shape[1]] - return solution - - -def qr(A: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: # pragma: no cover - """ - Like torch.linalg.qr. - """ - if hasattr(torch, "linalg") and hasattr(torch.linalg, "qr"): - # PyTorch version >= 1.9 - return torch.linalg.qr(A) - return torch.qr(A) - - -def eigh(A: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: # pragma: no cover - """ - Like torch.linalg.eigh, assuming the argument is a symmetric real matrix. - """ - if hasattr(torch, "linalg") and hasattr(torch.linalg, "eigh"): - return torch.linalg.eigh(A) - return torch.symeig(A, eigenvectors=True) - - def meshgrid_ij( *A: Union[torch.Tensor, Sequence[torch.Tensor]] ) -> Tuple[torch.Tensor, ...]: # pragma: no cover diff --git a/pytorch3d/implicitron/tools/circle_fitting.py b/pytorch3d/implicitron/tools/circle_fitting.py index 97165ed1c..966b9ff9b 100644 --- a/pytorch3d/implicitron/tools/circle_fitting.py +++ b/pytorch3d/implicitron/tools/circle_fitting.py @@ -10,7 +10,6 @@ from typing import Optional import torch -from pytorch3d.common.compat import eigh, lstsq def _get_rotation_to_best_fit_xy( @@ -28,7 +27,7 @@ def _get_rotation_to_best_fit_xy( (3,3) tensor rotation matrix """ points_centered = points - centroid[None] - return eigh(points_centered.t() @ points_centered)[1][:, [1, 2, 0]] + return torch.linalg.eigh(points_centered.t() @ points_centered)[1][:, [1, 2, 0]] def _signed_area(path: torch.Tensor) -> torch.Tensor: @@ -106,9 +105,8 @@ def fit_circle_in_2d( n_provided = points2d.shape[0] if n_provided < 3: raise ValueError(f"{n_provided} points are not enough to determine a circle") - solution = lstsq(design, rhs[:, None]) + solution = torch.linalg.lstsq(design, rhs[:, None]).solution center = solution[:2, 0] / 2 - # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`. radius = torch.sqrt(solution[2, 0] + (center**2).sum()) if n_points > 0: if angles is not None: diff --git a/pytorch3d/implicitron/tools/eval_video_trajectory.py b/pytorch3d/implicitron/tools/eval_video_trajectory.py index e540a3452..6b5174dc2 100644 --- a/pytorch3d/implicitron/tools/eval_video_trajectory.py +++ b/pytorch3d/implicitron/tools/eval_video_trajectory.py @@ -9,7 +9,6 @@ from typing import Optional, Tuple import torch -from pytorch3d.common.compat import eigh from pytorch3d.implicitron.tools import utils from pytorch3d.implicitron.tools.circle_fitting import fit_circle_in_3d from pytorch3d.renderer import look_at_view_transform, PerspectiveCameras @@ -205,7 +204,7 @@ def _disambiguate_normal(normal, up): def _fit_plane(x): x = x - x.mean(dim=0)[None] cov = (x.t() @ x) / x.shape[0] - _, e_vec = eigh(cov) + _, e_vec = torch.linalg.eigh(cov) return e_vec diff --git a/pytorch3d/ops/perspective_n_points.py b/pytorch3d/ops/perspective_n_points.py index 2f552a6e5..c6b7d6816 100644 --- a/pytorch3d/ops/perspective_n_points.py +++ b/pytorch3d/ops/perspective_n_points.py @@ -16,7 +16,6 @@ import torch import torch.nn.functional as F -from pytorch3d.common.compat import eigh from pytorch3d.ops import points_alignment, utils as oputil @@ -106,7 +105,7 @@ def _null_space(m, kernel_dim): kernel vectors, of size B x kernel_dim """ mTm = torch.bmm(m.transpose(1, 2), m) - s, v = eigh(mTm) + s, v = torch.linalg.eigh(mTm) return v[:, :, :kernel_dim].reshape(-1, 4, 3, kernel_dim), s[:, :kernel_dim] diff --git a/pytorch3d/ops/points_normals.py b/pytorch3d/ops/points_normals.py index 2515020c2..63aeefbd8 100644 --- a/pytorch3d/ops/points_normals.py +++ b/pytorch3d/ops/points_normals.py @@ -7,7 +7,6 @@ from typing import Tuple, TYPE_CHECKING, Union import torch -from pytorch3d.common.compat import eigh from pytorch3d.common.workaround import symeig3x3 from .utils import convert_pointclouds_to_tensor, get_point_covariances @@ -147,7 +146,7 @@ def estimate_pointcloud_local_coord_frames( if use_symeig_workaround: curvatures, local_coord_frames = symeig3x3(cov, eigenvectors=True) else: - curvatures, local_coord_frames = eigh(cov) + curvatures, local_coord_frames = torch.linalg.eigh(cov) # disambiguate the directions of individual principal vectors if disambiguate_directions: diff --git a/pytorch3d/transforms/se3.py b/pytorch3d/transforms/se3.py index 81a17c755..1c8a5a1b1 100644 --- a/pytorch3d/transforms/se3.py +++ b/pytorch3d/transforms/se3.py @@ -5,7 +5,6 @@ # LICENSE file in the root directory of this source tree. import torch -from pytorch3d.common.compat import solve from .so3 import _so3_exp_map, hat, so3_log_map @@ -174,7 +173,7 @@ def se3_log_map( # log_translation is V^-1 @ T T = transform[:, 3, :3] V = _se3_V_matrix(*_get_se3_V_input(log_rotation), eps=eps) - log_translation = solve(V, T[:, :, None])[:, :, 0] + log_translation = torch.linalg.solve(V, T[:, :, None])[:, :, 0] return torch.cat((log_translation, log_rotation), dim=1) diff --git a/tests/test_acos_linear_extrapolation.py b/tests/test_acos_linear_extrapolation.py index 1cadf723b..6a4513d19 100644 --- a/tests/test_acos_linear_extrapolation.py +++ b/tests/test_acos_linear_extrapolation.py @@ -9,7 +9,6 @@ import numpy as np import torch -from pytorch3d.common.compat import lstsq from pytorch3d.transforms import acos_linear_extrapolation from .common_testing import TestCaseMixin @@ -66,7 +65,7 @@ def _test_acos_outside_bounds(self, x, y, dydx, bound): bound_t = torch.tensor(bound, device=x.device, dtype=x.dtype) # fit a line: slope * x + bias = y x_1 = torch.stack([x, torch.ones_like(x)], dim=-1) - slope, bias = lstsq(x_1, y[:, None]).view(-1)[:2] + slope, bias = torch.linalg.lstsq(x_1, y[:, None]).solution.view(-1)[:2] desired_slope = (-1.0) / torch.sqrt(1.0 - bound_t**2) # test that the desired slope is the same as the fitted one self.assertClose(desired_slope.view(1), slope.view(1), atol=1e-2) diff --git a/tests/test_se3.py b/tests/test_se3.py index 27c2f1a82..4bd4f8eb0 100644 --- a/tests/test_se3.py +++ b/tests/test_se3.py @@ -9,7 +9,6 @@ import numpy as np import torch -from pytorch3d.common.compat import qr from pytorch3d.transforms.rotation_conversions import random_rotations from pytorch3d.transforms.se3 import se3_exp_map, se3_log_map from pytorch3d.transforms.so3 import so3_exp_map, so3_log_map, so3_rotation_angle @@ -199,7 +198,7 @@ def test_se3_log_singularity(self, batch_size: int = 100): r = [identity, rot180] r.extend( [ - qr(identity + torch.randn_like(identity) * 1e-6)[0] + torch.linalg.qr(identity + torch.randn_like(identity) * 1e-6)[0] + float(i > batch_size // 2) * (0.5 - torch.rand_like(identity)) * 1e-8 # this adds random noise to the second half # of the random orthogonal matrices to generate diff --git a/tests/test_so3.py b/tests/test_so3.py index fa02d6ddf..5320f5976 100644 --- a/tests/test_so3.py +++ b/tests/test_so3.py @@ -11,7 +11,6 @@ import numpy as np import torch -from pytorch3d.common.compat import qr from pytorch3d.transforms.so3 import ( hat, so3_exp_map, @@ -49,7 +48,7 @@ def init_rot(batch_size: int = 10): # TODO(dnovotny): replace with random_rotation from random_rotation.py rot = [] for _ in range(batch_size): - r = qr(torch.randn((3, 3), device=device))[0] + r = torch.linalg.qr(torch.randn((3, 3), device=device))[0] f = torch.randint(2, (3,), device=device, dtype=torch.float32) if f.sum() % 2 == 0: f = 1 - f @@ -145,7 +144,7 @@ def test_so3_log_singularity(self, batch_size: int = 100): # add random rotations and random almost orthonormal matrices r.extend( [ - qr(identity + torch.randn_like(identity) * 1e-4)[0] + torch.linalg.qr(identity + torch.randn_like(identity) * 1e-4)[0] + float(i > batch_size // 2) * (0.5 - torch.rand_like(identity)) * 1e-3 # this adds random noise to the second half # of the random orthogonal matrices to generate @@ -245,7 +244,7 @@ def test_so3_cos_bound(self, batch_size: int = 100): r = [identity, rot180] r.extend( [ - qr(identity + torch.randn_like(identity) * 1e-4)[0] + torch.linalg.qr(identity + torch.randn_like(identity) * 1e-4)[0] for _ in range(batch_size - 2) ] )