Skip to content

Commit

Permalink
Add annotations to vision/fair/pytorch3d
Browse files Browse the repository at this point in the history
Reviewed By: shannonzhu

Differential Revision: D33970393

fbshipit-source-id: 9b4dfaccfc3793fd37705a923d689cb14c9d26ba
  • Loading branch information
Pyre Bot Jr authored and facebook-github-bot committed Feb 3, 2022
1 parent c2862ff commit e9fb6c2
Show file tree
Hide file tree
Showing 21 changed files with 65 additions and 49 deletions.
7 changes: 5 additions & 2 deletions pytorch3d/datasets/r2n2/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,9 @@ def collate_batched_R2N2(batch: List[Dict]): # pragma: no cover
return collated_dict


def compute_extrinsic_matrix(azimuth, elevation, distance): # pragma: no cover
def compute_extrinsic_matrix(
azimuth: float, elevation: float, distance: float
): # pragma: no cover
"""
Copied from meshrcnn codebase:
https://github.com/facebookresearch/meshrcnn/blob/main/shapenet/utils/coords.py#L96
Expand Down Expand Up @@ -138,6 +140,7 @@ def compute_extrinsic_matrix(azimuth, elevation, distance): # pragma: no cover
# rotates the model 90 degrees about the x axis. To compensate for this quirk we
# roll that rotation into the extrinsic matrix here
rot = torch.tensor([[1, 0, 0, 0], [0, 0, -1, 0], [0, 1, 0, 0], [0, 0, 0, 1]])
# pyre-fixme[16]: `Tensor` has no attribute `mm`.
RT = RT.mm(rot.to(RT))

return RT
Expand Down Expand Up @@ -384,7 +387,7 @@ def voxelize(voxel_coords, P, V): # pragma: no cover
return voxels


def project_verts(verts, P, eps=1e-1): # pragma: no cover
def project_verts(verts, P, eps: float = 1e-1): # pragma: no cover
"""
Copied from meshrcnn codebase:
https://github.com/facebookresearch/meshrcnn/blob/main/shapenet/utils/coords.py#L159
Expand Down
10 changes: 6 additions & 4 deletions pytorch3d/io/obj_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
)


def _format_faces_indices(faces_indices, max_index, device, pad_value=None):
def _format_faces_indices(faces_indices, max_index: int, device, pad_value=None):
"""
Format indices and check for invalid values. Indices can refer to
values in one of the face properties: vertices, textures or normals.
Expand All @@ -57,6 +57,7 @@ def _format_faces_indices(faces_indices, max_index, device, pad_value=None):
)

if pad_value is not None:
# pyre-fixme[28]: Unexpected keyword argument `dim`.
mask = faces_indices.eq(pad_value).all(dim=-1)

# Change to 0 based indexing.
Expand All @@ -66,14 +67,15 @@ def _format_faces_indices(faces_indices, max_index, device, pad_value=None):
faces_indices[(faces_indices < 0)] += max_index

if pad_value is not None:
# pyre-fixme[61]: `mask` is undefined, or not always defined.
faces_indices[mask] = pad_value

return _check_faces_indices(faces_indices, max_index, pad_value)


def load_obj(
f,
load_textures=True,
load_textures: bool = True,
create_texture_atlas: bool = False,
texture_atlas_size: int = 4,
texture_wrap: Optional[str] = "repeat",
Expand Down Expand Up @@ -351,7 +353,7 @@ def _parse_face(
faces_normals_idx,
faces_textures_idx,
faces_materials_idx,
):
) -> None:
face = tokens[1:]
face_list = [f.split("/") for f in face]
face_verts = []
Expand Down Expand Up @@ -546,7 +548,7 @@ def _load_materials(
def _load_obj(
f_obj,
*,
data_dir,
data_dir: str,
load_textures: bool = True,
create_texture_atlas: bool = False,
texture_atlas_size: int = 4,
Expand Down
4 changes: 3 additions & 1 deletion pytorch3d/io/ply_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,9 @@ def _read_ply_element_ascii(f, definition: _PlyElementType):
return data


def _read_raw_array(f, aim: str, length: int, dtype: type = np.uint8, dtype_size=1):
def _read_raw_array(
f, aim: str, length: int, dtype: type = np.uint8, dtype_size: int = 1
):
"""
Read [length] elements from a file.
Expand Down
2 changes: 1 addition & 1 deletion pytorch3d/io/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def nullcontext(x):
PathOrStr = Union[pathlib.Path, str]


def _open_file(f, path_manager: PathManager, mode="r") -> ContextManager[IO]:
def _open_file(f, path_manager: PathManager, mode: str = "r") -> ContextManager[IO]:
if isinstance(f, str):
f = path_manager.open(f, mode)
return contextlib.closing(f)
Expand Down
2 changes: 1 addition & 1 deletion pytorch3d/loss/chamfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

def _validate_chamfer_reduction_inputs(
batch_reduction: Union[str, None], point_reduction: str
):
) -> None:
"""Check the requested reductions are valid.
Args:
Expand Down
2 changes: 1 addition & 1 deletion pytorch3d/ops/knn.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def knn_points(
version: int = -1,
return_nn: bool = False,
return_sorted: bool = True,
):
) -> _KNN:
"""
K-Nearest neighbors on point clouds.
Expand Down
3 changes: 2 additions & 1 deletion pytorch3d/ops/points_normals.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def estimate_pointcloud_local_coord_frames(
return curvatures, local_coord_frames


def _disambiguate_vector_directions(pcl, knns, vecs):
def _disambiguate_vector_directions(pcl, knns, vecs: float) -> float:
"""
Disambiguates normal directions according to [1].
Expand All @@ -180,6 +180,7 @@ def _disambiguate_vector_directions(pcl, knns, vecs):
# each element of the neighborhood
df = knns - pcl[:, :, None]
# projection of the difference on the principal direction
# pyre-fixme[16]: `float` has no attribute `__getitem__`.
proj = (vecs[:, :, None] * df).sum(3)
# check how many projections are positive
n_pos = (proj > 0).type_as(knns).sum(2, keepdim=True)
Expand Down
2 changes: 1 addition & 1 deletion pytorch3d/ops/points_to_volumes.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,7 +479,7 @@ def _check_points_to_volumes_inputs(
volume_features: torch.Tensor,
grid_sizes: torch.LongTensor,
mask: Optional[torch.Tensor] = None,
):
) -> None:

max_grid_size = grid_sizes.max(dim=0).values
if torch.prod(max_grid_size) > volume_densities.shape[1]:
Expand Down
4 changes: 3 additions & 1 deletion pytorch3d/ops/subdivide_meshes.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,7 +400,7 @@ def create_verts_index(verts_per_mesh, edges_per_mesh, device=None):
return verts_idx


def create_faces_index(faces_per_mesh, device=None):
def create_faces_index(faces_per_mesh: int, device=None):
"""
Helper function to group the faces indices for each mesh. New faces are
stacked at the end of the original faces tensor, so in order to have
Expand All @@ -417,7 +417,9 @@ def create_faces_index(faces_per_mesh, device=None):
"""
# e.g. faces_per_mesh = [2, 5, 3]

# pyre-fixme[16]: `int` has no attribute `sum`.
F = faces_per_mesh.sum() # e.g. 10
# pyre-fixme[16]: `int` has no attribute `cumsum`.
faces_per_mesh_cumsum = faces_per_mesh.cumsum(dim=0) # (N,) e.g. (2, 7, 10)

switch1_idx = faces_per_mesh_cumsum.clone()
Expand Down
2 changes: 1 addition & 1 deletion pytorch3d/ops/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def convert_pointclouds_to_tensor(pcl: Union[torch.Tensor, "Pointclouds"]):
return X, num_points


def is_pointclouds(pcl: Union[torch.Tensor, "Pointclouds"]):
def is_pointclouds(pcl: Union[torch.Tensor, "Pointclouds"]) -> bool:
"""Checks whether the input `pcl` is an instance of `Pointclouds`
by checking the existence of `points_padded` and `num_points_per_cloud`
functions.
Expand Down
30 changes: 15 additions & 15 deletions pytorch3d/renderer/cameras.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,10 +427,10 @@ def __getitem__(


def OpenGLPerspectiveCameras(
znear=1.0,
zfar=100.0,
aspect_ratio=1.0,
fov=60.0,
znear: float = 1.0,
zfar: float = 100.0,
aspect_ratio: float = 1.0,
fov: float = 60.0,
degrees: bool = True,
R: torch.Tensor = _R,
T: torch.Tensor = _T,
Expand Down Expand Up @@ -709,12 +709,12 @@ def in_ndc(self):


def OpenGLOrthographicCameras(
znear=1.0,
zfar=100.0,
top=1.0,
bottom=-1.0,
left=-1.0,
right=1.0,
znear: float = 1.0,
zfar: float = 100.0,
top: float = 1.0,
bottom: float = -1.0,
left: float = -1.0,
right: float = 1.0,
scale_xyz=((1.0, 1.0, 1.0),), # (1, 3)
R: torch.Tensor = _R,
T: torch.Tensor = _T,
Expand Down Expand Up @@ -956,7 +956,7 @@ def in_ndc(self):


def SfMPerspectiveCameras(
focal_length=1.0,
focal_length: float = 1.0,
principal_point=((0.0, 0.0),),
R: torch.Tensor = _R,
T: torch.Tensor = _T,
Expand Down Expand Up @@ -1194,7 +1194,7 @@ def in_ndc(self):


def SfMOrthographicCameras(
focal_length=1.0,
focal_length: float = 1.0,
principal_point=((0.0, 0.0),),
R: torch.Tensor = _R,
T: torch.Tensor = _T,
Expand Down Expand Up @@ -1645,9 +1645,9 @@ def look_at_rotation(


def look_at_view_transform(
dist=1.0,
elev=0.0,
azim=0.0,
dist: float = 1.0,
elev: float = 0.0,
azim: float = 0.0,
degrees: bool = True,
eye: Optional[Sequence] = None,
at=((0, 0, 0),), # (1, 3)
Expand Down
6 changes: 3 additions & 3 deletions pytorch3d/renderer/implicit/raymarching.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def forward(
return opacities


def _shifted_cumprod(x, shift=1):
def _shifted_cumprod(x, shift: int = 1):
"""
Computes `torch.cumprod(x, dim=-1)` and prepends `shift` number of
ones and removes `shift` trailing elements to/from the last dimension
Expand All @@ -177,7 +177,7 @@ def _shifted_cumprod(x, shift=1):

def _check_density_bounds(
rays_densities: torch.Tensor, bounds: Tuple[float, float] = (0.0, 1.0)
):
) -> None:
"""
Checks whether the elements of `rays_densities` range within `bounds`.
If not issues a warning.
Expand All @@ -197,7 +197,7 @@ def _check_raymarcher_inputs(
features_can_be_none: bool = False,
z_can_be_none: bool = False,
density_1d: bool = True,
):
) -> None:
"""
Checks the validity of the inputs to raymarching algorithms.
"""
Expand Down
2 changes: 1 addition & 1 deletion pytorch3d/renderer/implicit/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def _validate_ray_bundle_variables(
rays_origins: torch.Tensor,
rays_directions: torch.Tensor,
rays_lengths: torch.Tensor,
):
) -> None:
"""
Validate the shapes of RayBundle variables
`rays_origins`, `rays_directions`, and `rays_lengths`.
Expand Down
2 changes: 1 addition & 1 deletion pytorch3d/renderer/lighting.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,7 @@ def specular(self, normals, points, camera_position, shininess) -> torch.Tensor:
return torch.zeros_like(points)


def _validate_light_properties(obj):
def _validate_light_properties(obj) -> None:
props = ("ambient_color", "diffuse_color", "specular_color")
for n in props:
t = getattr(obj, n)
Expand Down
2 changes: 1 addition & 1 deletion pytorch3d/renderer/mesh/shader.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ def TexturedSoftPhongShader(
lights: Optional[TensorProperties] = None,
materials: Optional[Materials] = None,
blend_params: Optional[BlendParams] = None,
):
) -> SoftPhongShader:
"""
TexturedSoftPhongShader class has been DEPRECATED. Use SoftPhongShader instead.
Preserving TexturedSoftPhongShader as a function for backwards compatibility.
Expand Down
2 changes: 1 addition & 1 deletion pytorch3d/structures/meshes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1557,7 +1557,7 @@ def sample_textures(self, fragments):
raise ValueError("Meshes does not have textures")


def join_meshes_as_batch(meshes: List[Meshes], include_textures: bool = True):
def join_meshes_as_batch(meshes: List[Meshes], include_textures: bool = True) -> Meshes:
"""
Merge multiple Meshes objects, i.e. concatenate the meshes objects. They
must all be on the same device. If include_textures is true, they must all
Expand Down
2 changes: 1 addition & 1 deletion pytorch3d/structures/pointclouds.py
Original file line number Diff line number Diff line change
Expand Up @@ -1224,7 +1224,7 @@ def inside_box(self, box):
return coord_inside.all(dim=-1)


def join_pointclouds_as_batch(pointclouds: Sequence[Pointclouds]):
def join_pointclouds_as_batch(pointclouds: Sequence[Pointclouds]) -> Pointclouds:
"""
Merge a list of Pointclouds objects into a single batched Pointclouds
object. All pointclouds must be on the same device.
Expand Down
2 changes: 1 addition & 1 deletion pytorch3d/transforms/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import torch


DEFAULT_ACOS_BOUND = 1.0 - 1e-4
DEFAULT_ACOS_BOUND: float = 1.0 - 1e-4


def acos_linear_extrapolation(
Expand Down
2 changes: 1 addition & 1 deletion pytorch3d/transforms/transform3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -754,7 +754,7 @@ def _broadcast_bmm(a, b):


@torch.no_grad()
def _check_valid_rotation_matrix(R, tol: float = 1e-7):
def _check_valid_rotation_matrix(R, tol: float = 1e-7) -> None:
"""
Determine if R is a valid rotation matrix by checking it satisfies the
following conditions:
Expand Down
Loading

0 comments on commit e9fb6c2

Please sign in to comment.