Skip to content

Commit

Permalink
some pointcloud typing
Browse files Browse the repository at this point in the history
Summary: Make clear that features_padded() etc can return None

Reviewed By: patricklabatut

Differential Revision: D31795088

fbshipit-source-id: 7b0bbb6f3b7ad7f7b6e6a727129537af1d1873af
  • Loading branch information
bottler authored and facebook-github-bot committed Oct 28, 2021
1 parent 73a14d7 commit bfeb82e
Showing 1 changed file with 61 additions and 45 deletions.
106 changes: 61 additions & 45 deletions pytorch3d/structures/pointclouds.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# LICENSE file in the root directory of this source tree.

from itertools import zip_longest
from typing import Sequence, Union
from typing import List, Optional, Sequence, Tuple, Union

import numpy as np
import torch
Expand Down Expand Up @@ -240,7 +240,9 @@ def __init__(self, points, normals=None, features=None) -> None:
if features_C is not None:
self._C = features_C

def _parse_auxiliary_input(self, aux_input):
def _parse_auxiliary_input(
self, aux_input
) -> Tuple[Optional[List[torch.Tensor]], Optional[torch.Tensor], Optional[int]]:
"""
Interpret the auxiliary inputs (normals, features) given to __init__.
Expand Down Expand Up @@ -323,24 +325,26 @@ def __getitem__(self, index) -> "Pointclouds":
Pointclouds object with selected clouds. The tensors are not cloned.
"""
normals, features = None, None
normals_list = self.normals_list()
features_list = self.features_list()
if isinstance(index, int):
points = [self.points_list()[index]]
if self.normals_list() is not None:
normals = [self.normals_list()[index]]
if self.features_list() is not None:
features = [self.features_list()[index]]
if normals_list is not None:
normals = [normals_list[index]]
if features_list is not None:
features = [features_list[index]]
elif isinstance(index, slice):
points = self.points_list()[index]
if self.normals_list() is not None:
normals = self.normals_list()[index]
if self.features_list() is not None:
features = self.features_list()[index]
if normals_list is not None:
normals = normals_list[index]
if features_list is not None:
features = features_list[index]
elif isinstance(index, list):
points = [self.points_list()[i] for i in index]
if self.normals_list() is not None:
normals = [self.normals_list()[i] for i in index]
if self.features_list() is not None:
features = [self.features_list()[i] for i in index]
if normals_list is not None:
normals = [normals_list[i] for i in index]
if features_list is not None:
features = [features_list[i] for i in index]
elif isinstance(index, torch.Tensor):
if index.dim() != 1 or index.dtype.is_floating_point:
raise IndexError(index)
Expand All @@ -351,10 +355,10 @@ def __getitem__(self, index) -> "Pointclouds":
index = index.squeeze(1) if index.numel() > 0 else index
index = index.tolist()
points = [self.points_list()[i] for i in index]
if self.normals_list() is not None:
normals = [self.normals_list()[i] for i in index]
if self.features_list() is not None:
features = [self.features_list()[i] for i in index]
if normals_list is not None:
normals = [normals_list[i] for i in index]
if features_list is not None:
features = [features_list[i] for i in index]
else:
raise IndexError(index)

Expand All @@ -369,7 +373,7 @@ def isempty(self) -> bool:
"""
return self._N == 0 or self.valid.eq(False).all()

def points_list(self):
def points_list(self) -> List[torch.Tensor]:
"""
Get the list representation of the points.
Expand All @@ -388,9 +392,10 @@ def points_list(self):
self._points_list = points_list
return self._points_list

def normals_list(self):
def normals_list(self) -> Optional[List[torch.Tensor]]:
"""
Get the list representation of the normals.
Get the list representation of the normals,
or None if there are no normals.
Returns:
list of tensors of normals of shape (P_n, 3).
Expand All @@ -404,9 +409,10 @@ def normals_list(self):
)
return self._normals_list

def features_list(self):
def features_list(self) -> Optional[List[torch.Tensor]]:
"""
Get the list representation of the features.
Get the list representation of the features,
or None if there are no features.
Returns:
list of tensors of features of shape (P_n, C).
Expand All @@ -420,7 +426,7 @@ def features_list(self):
)
return self._features_list

def points_packed(self):
def points_packed(self) -> torch.Tensor:
"""
Get the packed representation of the points.
Expand All @@ -430,22 +436,24 @@ def points_packed(self):
self._compute_packed()
return self._points_packed

def normals_packed(self):
def normals_packed(self) -> Optional[torch.Tensor]:
"""
Get the packed representation of the normals.
Returns:
tensor of normals of shape (sum(P_n), 3).
tensor of normals of shape (sum(P_n), 3),
or None if there are no normals.
"""
self._compute_packed()
return self._normals_packed

def features_packed(self):
def features_packed(self) -> Optional[torch.Tensor]:
"""
Get the packed representation of the features.
Returns:
tensor of features of shape (sum(P_n), C).
tensor of features of shape (sum(P_n), C),
or None if there are no features
"""
self._compute_packed()
return self._features_packed
Expand Down Expand Up @@ -483,7 +491,7 @@ def num_points_per_cloud(self):
"""
return self._num_points_per_cloud

def points_padded(self):
def points_padded(self) -> torch.Tensor:
"""
Get the padded representation of the points.
Expand All @@ -493,19 +501,21 @@ def points_padded(self):
self._compute_padded()
return self._points_padded

def normals_padded(self):
def normals_padded(self) -> Optional[torch.Tensor]:
"""
Get the padded representation of the normals.
Get the padded representation of the normals,
or None if there are no normals.
Returns:
tensor of normals of shape (N, max(P_n), 3).
"""
self._compute_padded()
return self._normals_padded

def features_padded(self):
def features_padded(self) -> Optional[torch.Tensor]:
"""
Get the padded representation of the features.
Get the padded representation of the features,
or None if there are no features.
Returns:
tensor of features of shape (N, max(P_n), 3).
Expand Down Expand Up @@ -562,16 +572,18 @@ def _compute_padded(self, refresh: bool = False):
pad_value=0.0,
equisized=self.equisized,
)
if self.normals_list() is not None:
normals_list = self.normals_list()
if normals_list is not None:
self._normals_padded = struct_utils.list_to_padded(
self.normals_list(),
normals_list,
(self._P, 3),
pad_value=0.0,
equisized=self.equisized,
)
if self.features_list() is not None:
features_list = self.features_list()
if features_list is not None:
self._features_padded = struct_utils.list_to_padded(
self.features_list(),
features_list,
(self._P, self._C),
pad_value=0.0,
equisized=self.equisized,
Expand Down Expand Up @@ -772,10 +784,12 @@ def get_cloud(self, index: int):
)
points = self.points_list()[index]
normals, features = None, None
if self.normals_list() is not None:
normals = self.normals_list()[index]
if self.features_list() is not None:
features = self.features_list()[index]
normals_list = self.normals_list()
if normals_list is not None:
normals = normals_list[index]
features_list = self.features_list()
if features_list is not None:
features = features_list[index]
return points, normals, features

# TODO(nikhilar) Move function to a utils file.
Expand Down Expand Up @@ -1022,13 +1036,15 @@ def extend(self, N: int):
new_points_list, new_normals_list, new_features_list = [], None, None
for points in self.points_list():
new_points_list.extend(points.clone() for _ in range(N))
if self.normals_list() is not None:
normals_list = self.normals_list()
if normals_list is not None:
new_normals_list = []
for normals in self.normals_list():
for normals in normals_list:
new_normals_list.extend(normals.clone() for _ in range(N))
if self.features_list() is not None:
features_list = self.features_list()
if features_list is not None:
new_features_list = []
for features in self.features_list():
for features in features_list:
new_features_list.extend(features.clone() for _ in range(N))
return self.__class__(
points=new_points_list, normals=new_normals_list, features=new_features_list
Expand Down

0 comments on commit bfeb82e

Please sign in to comment.