Commit

Umeyama
Summary:
Umeyama's method estimates a similarity transformation (rotation, translation, and optionally a scale) between two sets of corresponding points.
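A minimal usage sketch (not part of this commit; the ground-truth transform and shapes are illustrative). It follows the row-vector convention `s X R + T = Y` from the docstring below:

```python
import math
import torch
from pytorch3d.ops import corresponding_points_alignment

# a batch of 2 point clouds with 100 3D points each
X = torch.randn(2, 100, 3)

# a known ground-truth transform: rotation about the z-axis plus a translation
a = math.pi / 4
R_gt = torch.tensor(
    [[math.cos(a), math.sin(a), 0.0],
     [-math.sin(a), math.cos(a), 0.0],
     [0.0, 0.0, 1.0]]
).expand(2, 3, 3)
T_gt = torch.tensor([1.0, 2.0, 3.0]).expand(2, 3)

# the docstring convention treats points as row vectors: Y = s X R + T
Y = torch.bmm(X, R_gt) + T_gt[:, None]

R, T, s = corresponding_points_alignment(X, Y)
# R ~ R_gt, T ~ T_gt, and s is all ones (estimate_scale defaults to False)
```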

Benchmark output for `bm_points_alignment`

```
Arguments key: [<allow_reflection>_<batch_size>_<dim>_<estimate_scale>_<n_points>_<use_pointclouds>]
Benchmark                                                        Avg Time(μs)      Peak Time(μs) Iterations
--------------------------------------------------------------------------------
CorrespodingPointsAlignment_True_1_3_True_100_False                   7382            9833             68
CorrespodingPointsAlignment_True_1_3_True_10000_False                 8183           10500             62
CorrespodingPointsAlignment_True_1_3_False_100_False                  7301            9263             69
CorrespodingPointsAlignment_True_1_3_False_10000_False                7945            9746             64
CorrespodingPointsAlignment_True_1_20_True_100_False                 13706           41623             37
CorrespodingPointsAlignment_True_1_20_True_10000_False               11044           33766             46
CorrespodingPointsAlignment_True_1_20_False_100_False                 9908           28791             51
CorrespodingPointsAlignment_True_1_20_False_10000_False               9523           18680             53
CorrespodingPointsAlignment_True_10_3_True_100_False                 29585           32026             17
CorrespodingPointsAlignment_True_10_3_True_10000_False               29626           36324             18
CorrespodingPointsAlignment_True_10_3_False_100_False                26013           29253             20
CorrespodingPointsAlignment_True_10_3_False_10000_False              25000           33820             20
CorrespodingPointsAlignment_True_10_20_True_100_False                40955           41592             13
CorrespodingPointsAlignment_True_10_20_True_10000_False              42087           42393             12
CorrespodingPointsAlignment_True_10_20_False_100_False               39863           40381             13
CorrespodingPointsAlignment_True_10_20_False_10000_False             40813           41699             13
CorrespodingPointsAlignment_True_100_3_True_100_False               183146          194745              3
CorrespodingPointsAlignment_True_100_3_True_10000_False             213789          231466              3
CorrespodingPointsAlignment_True_100_3_False_100_False              177805          180796              3
CorrespodingPointsAlignment_True_100_3_False_10000_False            184963          185695              3
CorrespodingPointsAlignment_True_100_20_True_100_False              347181          347325              2
CorrespodingPointsAlignment_True_100_20_True_10000_False            363259          363613              2
CorrespodingPointsAlignment_True_100_20_False_100_False             351769          352496              2
CorrespodingPointsAlignment_True_100_20_False_10000_False           375629          379818              2
CorrespodingPointsAlignment_False_1_3_True_100_False                 11155           13770             45
CorrespodingPointsAlignment_False_1_3_True_10000_False               10743           13938             47
CorrespodingPointsAlignment_False_1_3_False_100_False                 9578           11511             53
CorrespodingPointsAlignment_False_1_3_False_10000_False               9549           11984             53
CorrespodingPointsAlignment_False_1_20_True_100_False                13809           14183             37
CorrespodingPointsAlignment_False_1_20_True_10000_False              14084           15082             36
CorrespodingPointsAlignment_False_1_20_False_100_False               12765           14177             40
CorrespodingPointsAlignment_False_1_20_False_10000_False             12811           13096             40
CorrespodingPointsAlignment_False_10_3_True_100_False                28823           39384             18
CorrespodingPointsAlignment_False_10_3_True_10000_False              27135           27525             19
CorrespodingPointsAlignment_False_10_3_False_100_False               26236           28980             20
CorrespodingPointsAlignment_False_10_3_False_10000_False             42324           45123             12
CorrespodingPointsAlignment_False_10_20_True_100_False              723902          723902              1
CorrespodingPointsAlignment_False_10_20_True_10000_False            220007          252886              3
CorrespodingPointsAlignment_False_10_20_False_100_False              55593           71636              9
CorrespodingPointsAlignment_False_10_20_False_10000_False            44419           71861             12
CorrespodingPointsAlignment_False_100_3_True_100_False              184768          185199              3
CorrespodingPointsAlignment_False_100_3_True_10000_False            198657          213868              3
CorrespodingPointsAlignment_False_100_3_False_100_False             224598          309645              3
CorrespodingPointsAlignment_False_100_3_False_10000_False           197863          202002              3
CorrespodingPointsAlignment_False_100_20_True_100_False             293484          309459              2
CorrespodingPointsAlignment_False_100_20_True_10000_False           327253          366644              2
CorrespodingPointsAlignment_False_100_20_False_100_False            420793          422194              2
CorrespodingPointsAlignment_False_100_20_False_10000_False          462634          485542              2
CorrespodingPointsAlignment_True_1_3_True_100_True                    7664            9909             66
CorrespodingPointsAlignment_True_1_3_True_10000_True                  7190            8366             70
CorrespodingPointsAlignment_True_1_3_False_100_True                   6549            8316             77
CorrespodingPointsAlignment_True_1_3_False_10000_True                 6534            7710             77
CorrespodingPointsAlignment_True_10_3_True_100_True                  29052           32940             18
CorrespodingPointsAlignment_True_10_3_True_10000_True                30526           33453             17
CorrespodingPointsAlignment_True_10_3_False_100_True                 28708           32993             18
CorrespodingPointsAlignment_True_10_3_False_10000_True               30630           35973             17
CorrespodingPointsAlignment_True_100_3_True_100_True                264909          320820              3
CorrespodingPointsAlignment_True_100_3_True_10000_True              310902          322604              2
CorrespodingPointsAlignment_True_100_3_False_100_True               246832          250634              3
CorrespodingPointsAlignment_True_100_3_False_10000_True             276006          289061              2
CorrespodingPointsAlignment_False_1_3_True_100_True                  11421           13757             44
CorrespodingPointsAlignment_False_1_3_True_10000_True                11199           12532             45
CorrespodingPointsAlignment_False_1_3_False_100_True                 11474           15841             44
CorrespodingPointsAlignment_False_1_3_False_10000_True               10384           13188             49
CorrespodingPointsAlignment_False_10_3_True_100_True                 36599           47340             14
CorrespodingPointsAlignment_False_10_3_True_10000_True               40702           50754             13
CorrespodingPointsAlignment_False_10_3_False_100_True                41277           52149             13
CorrespodingPointsAlignment_False_10_3_False_10000_True              34286           37091             15
CorrespodingPointsAlignment_False_100_3_True_100_True               254991          258578              2
CorrespodingPointsAlignment_False_100_3_True_10000_True             257999          261285              2
CorrespodingPointsAlignment_False_100_3_False_100_True              247511          248693              3
CorrespodingPointsAlignment_False_100_3_False_10000_True            251807          263865              3
```
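Each row name encodes its settings according to the arguments key above; for example, `CorrespodingPointsAlignment_True_1_3_True_100_False` is the case `allow_reflection=True, batch_size=1, dim=3, estimate_scale=True, n_points=100, use_pointclouds=False`. A small sketch that decodes a row name back into kwargs (`decode_case` is a hypothetical helper, not part of this commit):

```python
def decode_case(name: str) -> dict:
    """Map a benchmark row name back to its kwargs (hypothetical helper)."""
    keys = [
        "allow_reflection", "batch_size", "dim",
        "estimate_scale", "n_points", "use_pointclouds",
    ]

    def parse(v: str):
        return v == "True" if v in ("True", "False") else int(v)

    # drop the "CorrespodingPointsAlignment" prefix, then parse each field
    fields = name.split("_")[1:]
    return dict(zip(keys, (parse(f) for f in fields)))

# decode_case("CorrespodingPointsAlignment_True_1_3_True_100_False")
# -> {'allow_reflection': True, 'batch_size': 1, 'dim': 3,
#     'estimate_scale': True, 'n_points': 100, 'use_pointclouds': False}
```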

Reviewed By: gkioxari

Differential Revision: D19808389

fbshipit-source-id: 83305a58627d2fc5dcaf3c3015132d8148f28c29
davnov134 authored and facebook-github-bot committed Apr 2, 2020
1 parent 745aaf3 commit e5b1d6d
Showing 4 changed files with 550 additions and 0 deletions.
1 change: 1 addition & 0 deletions pytorch3d/ops/__init__.py
@@ -6,6 +6,7 @@
from .mesh_face_areas_normals import mesh_face_areas_normals
from .nearest_neighbor_points import nn_points_idx
from .packed_to_padded import packed_to_padded, padded_to_packed
from .points_alignment import corresponding_points_alignment
from .sample_points_from_meshes import sample_points_from_meshes
from .subdivide_meshes import SubdivideMeshes
from .vert_align import vert_align
151 changes: 151 additions & 0 deletions pytorch3d/ops/points_alignment.py
@@ -0,0 +1,151 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

import warnings
from typing import Tuple, Union

import torch

from pytorch3d.structures.pointclouds import Pointclouds


def corresponding_points_alignment(
    X: Union[torch.Tensor, Pointclouds],
    Y: Union[torch.Tensor, Pointclouds],
    estimate_scale: bool = False,
    allow_reflection: bool = False,
    eps: float = 1e-8,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Finds a similarity transformation (rotation `R`, translation `T`
    and optionally scale `s`) between two given sets of corresponding
    `d`-dimensional points `X` and `Y` such that:

    `s[i] X[i] R[i] + T[i] = Y[i]`,

    for all batch indexes `i` in the least squares sense.

    The algorithm is also known as Umeyama [1].

    Args:
        X: Batch of `d`-dimensional points of shape `(minibatch, num_point, d)`
            or a `Pointclouds` object.
        Y: Batch of `d`-dimensional points of shape `(minibatch, num_point, d)`
            or a `Pointclouds` object.
        estimate_scale: If `True`, also estimates a scaling component `s`
            of the transformation. Otherwise assumes an identity
            scale and returns a tensor of ones.
        allow_reflection: If `True`, allows the algorithm to return `R`
            which is orthonormal but has determinant == -1.
        eps: A scalar for clamping to avoid dividing by zero. Active for the
            code that estimates the output scale `s`.

    Returns:
        3-element tuple containing
        - **R**: Batch of orthonormal matrices of shape `(minibatch, d, d)`.
        - **T**: Batch of translations of shape `(minibatch, d)`.
        - **s**: Batch of scaling factors of shape `(minibatch, )`.

    References:
        [1] Shinji Umeyama: Least-Squares Estimation of
        Transformation Parameters Between Two Point Patterns
    """

    # make sure we convert input Pointclouds structures to tensors
    Xt, num_points = _convert_point_cloud_to_tensor(X)
    Yt, num_points_Y = _convert_point_cloud_to_tensor(Y)

    if (Xt.shape != Yt.shape) or (num_points != num_points_Y).any():
        raise ValueError(
            "Point sets X and Y have to have the same "
            "number of batches, points and dimensions."
        )

    b, n, dim = Xt.shape

    # compute the centroids of the point sets
    Xmu = Xt.sum(1) / torch.clamp(num_points[:, None], 1)
    Ymu = Yt.sum(1) / torch.clamp(num_points[:, None], 1)

    # mean-center the point sets
    Xc = Xt - Xmu[:, None]
    Yc = Yt - Ymu[:, None]

    if (num_points < Xt.shape[1]).any() or (num_points < Yt.shape[1]).any():
        # in case we got Pointclouds as input, mask the unused entries in Xc, Yc
        mask = (
            torch.arange(n, dtype=torch.int64, device=Xc.device)[None]
            < num_points[:, None]
        ).type_as(Xc)
        Xc *= mask[:, :, None]
        Yc *= mask[:, :, None]

    if (num_points < (dim + 1)).any():
        warnings.warn(
            "The size of one of the point clouds is < dim+1. "
            + "corresponding_points_alignment can't return a unique solution."
        )

    # compute the covariance XYcov between the point sets Xc, Yc
    XYcov = torch.bmm(Xc.transpose(2, 1), Yc)
    XYcov = XYcov / torch.clamp(num_points[:, None, None], 1)

    # decompose the covariance matrix XYcov
    U, S, V = torch.svd(XYcov)

    # identity matrix used for fixing reflections
    E = torch.eye(dim, dtype=XYcov.dtype, device=XYcov.device)[None].repeat(b, 1, 1)

    if not allow_reflection:
        # reflection test:
        #   checks whether the estimated rotation has det==1,
        #   if not, finds the nearest rotation s.t. det==1 by
        #   flipping the sign of the last singular vector U
        R_test = torch.bmm(U, V.transpose(2, 1))
        E[:, -1, -1] = torch.det(R_test)

    # find the rotation matrix by composing U and V again
    R = torch.bmm(torch.bmm(U, E), V.transpose(2, 1))

    if estimate_scale:
        # estimate the scaling component of the transformation
        trace_ES = (torch.diagonal(E, dim1=1, dim2=2) * S).sum(1)
        Xcov = (Xc * Xc).sum((1, 2)) / torch.clamp(num_points, 1)

        # the scaling component
        s = trace_ES / torch.clamp(Xcov, eps)

        # translation component
        T = Ymu - s[:, None] * torch.bmm(Xmu[:, None], R)[:, 0, :]

    else:
        # translation component
        T = Ymu - torch.bmm(Xmu[:, None], R)[:, 0]

        # unit scaling since we do not estimate scale
        s = T.new_ones(b)

    return R, T, s


def _convert_point_cloud_to_tensor(pcl: Union[torch.Tensor, Pointclouds]):
    """
    If `type(pcl) == Pointclouds`, converts `pcl` to a padded
    representation and returns it together with the number of points
    per batch. Otherwise, returns the input itself with the number of points
    set to the size of the second dimension of `pcl`.
    """
    if isinstance(pcl, Pointclouds):
        X = pcl.points_padded()
        num_points = pcl.num_points_per_cloud()
    elif torch.is_tensor(pcl):
        X = pcl
        num_points = X.shape[1] * torch.ones(
            X.shape[0], device=X.device, dtype=torch.int64
        )
    else:
        raise ValueError(
            "The inputs X, Y should be either Pointclouds objects or tensors."
        )
    return X, num_points
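The `Pointclouds` code path also supports batches whose clouds have different numbers of points; the masking and `num_points` clamping above exist for exactly this case. A short sketch (not part of this commit; sizes and the transform are illustrative):

```python
import torch
from pytorch3d.ops import corresponding_points_alignment
from pytorch3d.structures import Pointclouds

# two clouds with different numbers of points, batched as one Pointclouds
pts = [torch.randn(128, 3), torch.randn(256, 3)]
X = Pointclouds(pts)

# scale each cloud by 2 and shift it by (0, 1, 0)
Y = Pointclouds([2.0 * p + torch.tensor([0.0, 1.0, 0.0]) for p in pts])

R, T, s = corresponding_points_alignment(X, Y, estimate_scale=True)
# per cloud: R ~ identity, T ~ (0, 1, 0), s ~ 2
```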
40 changes: 40 additions & 0 deletions tests/bm_points_alignment.py
@@ -0,0 +1,40 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

from copy import deepcopy
from itertools import product
from fvcore.common.benchmark import benchmark

from test_points_alignment import TestCorrespondingPointsAlignment


def bm_corresponding_points_alignment() -> None:

    case_grid = {
        "allow_reflection": [True, False],
        "batch_size": [1, 10, 100],
        "dim": [3, 20],
        "estimate_scale": [True, False],
        "n_points": [100, 10000],
        "use_pointclouds": [False],
    }

    test_args = sorted(case_grid.keys())
    test_cases = product(*[case_grid[k] for k in test_args])
    kwargs_list = [dict(zip(test_args, case)) for case in test_cases]

    # add the use_pointclouds=True test cases whenever we have dim==3
    kwargs_to_add = []
    for entry in kwargs_list:
        if entry["dim"] == 3:
            entry_add = deepcopy(entry)
            entry_add["use_pointclouds"] = True
            kwargs_to_add.append(entry_add)
    kwargs_list.extend(kwargs_to_add)

    benchmark(
        TestCorrespondingPointsAlignment.corresponding_points_alignment,
        "CorrespodingPointsAlignment",
        kwargs_list,
        warmup_iters=1,
    )
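To reproduce the table in the summary, one can call the entry point directly (a sketch that assumes the pytorch3d `tests/` directory is the working directory, so that `test_points_alignment` resolves on import):

```python
from bm_points_alignment import bm_corresponding_points_alignment

# prints one Avg/Peak/Iterations row per kwargs combination
bm_corresponding_points_alignment()
```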