diff --git a/pytorch3d/renderer/mesh/__init__.py b/pytorch3d/renderer/mesh/__init__.py index d8b6b13b9..6501acba7 100644 --- a/pytorch3d/renderer/mesh/__init__.py +++ b/pytorch3d/renderer/mesh/__init__.py @@ -1,6 +1,7 @@ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +from .clip import ClipFrustum, ClippedFaces, clip_faces from .rasterize_meshes import rasterize_meshes from .rasterizer import MeshRasterizer, RasterizationSettings from .renderer import MeshRenderer diff --git a/pytorch3d/renderer/mesh/clip.py b/pytorch3d/renderer/mesh/clip.py new file mode 100644 index 000000000..bf4681df3 --- /dev/null +++ b/pytorch3d/renderer/mesh/clip.py @@ -0,0 +1,600 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. + +from typing import Any, List, Optional, Tuple + +import torch + + +""" +Mesh clipping is done before rasterization and is implemented using 4 cases +(these will be referred to throughout the functions below) + +Case 1: the triangle is completely in front of the clipping plane (it is left + unchanged) +Case 2: the triangle is completely behind the clipping plane (it is culled) +Case 3: the triangle has exactly two vertices behind the clipping plane (it is + clipped into a smaller triangle) +Case 4: the triangle has exactly one vertex behind the clipping plane (it is clipped + into a smaller quadrilateral and divided into two triangular faces) + +After rasterization, the Fragments from the clipped/modified triangles +are mapped back to the triangles in the original mesh. The indices, +barycentric coordinates and distances are all relative to original mesh triangles. + +NOTE: It is assumed that all z-coordinates are in world coordinates (not NDC +coordinates), while x and y coordinates may be in NDC/screen coordinates +(i.e after applying a projective transform e.g. cameras.transform_points(points)). +""" + + +class ClippedFaces: + """ + Helper class to store the data for the clipped version of a Meshes object + (face_verts, mesh_to_face_first_idx, num_faces_per_mesh) along with + conversion information (faces_clipped_to_unclipped_idx, barycentric_conversion, + faces_clipped_to_conversion_idx, clipped_faces_neighbor_idx) required to convert + barycentric coordinates from rasterization of the clipped Meshes to barycentric + coordinates in terms of the unclipped Meshes. + + Args: + face_verts: FloatTensor of shape (F_clipped, 3, 3) giving the verts of + each of the clipped faces + mesh_to_face_first_idx: an tensor of shape (N,), where N is the number of meshes + in the batch. The ith element stores the index into face_verts + of the first face of the ith mesh. + num_faces_per_mesh: a tensor of shape (N,) storing the number of faces in each mesh. + faces_clipped_to_unclipped_idx: (F_clipped,) shaped LongTensor mapping each clipped + face back to the face in faces_unclipped (i.e. the faces in the original meshes + obtained using meshes.faces_packed()) + barycentric_conversion: (T, 3, 3) FloatTensor, where barycentric_conversion[i, :, k] + stores the barycentric weights in terms of the world coordinates of the original + (big) unclipped triangle for the kth vertex in the clipped (small) triangle. + If the rasterizer then expresses some NDC coordinate in terms of barycentric + world coordinates for the clipped (small) triangle as alpha_clipped[i,:], + alpha_unclipped[i, :] = barycentric_conversion[i, :, :]*alpha_clipped[i, :] + faces_clipped_to_conversion_idx: (F_clipped,) shaped LongTensor mapping each clipped + face to the applicable row of barycentric_conversion (or set to -1 if conversion is + not needed). + clipped_faces_neighbor_idx: LongTensor of shape (F_clipped,) giving the index of the + neighboring face for each case 4 triangle. e.g. for a case 4 face with f split + into two triangles (t1, t2): clipped_faces_neighbor_idx[t1_idx] = t2_idx. + Faces which are not clipped and subdivided are set to -1 (i.e cases 1/2/3). + """ + + __slots__ = [ + "face_verts", + "mesh_to_face_first_idx", + "num_faces_per_mesh", + "faces_clipped_to_unclipped_idx", + "barycentric_conversion", + "faces_clipped_to_conversion_idx", + "clipped_faces_neighbor_idx", + ] + + def __init__( + self, + face_verts: torch.Tensor, + mesh_to_face_first_idx: torch.Tensor, + num_faces_per_mesh: torch.Tensor, + faces_clipped_to_unclipped_idx: Optional[torch.Tensor] = None, + barycentric_conversion: Optional[torch.Tensor] = None, + faces_clipped_to_conversion_idx: Optional[torch.Tensor] = None, + clipped_faces_neighbor_idx: Optional[torch.Tensor] = None, + ): + self.face_verts = face_verts + self.mesh_to_face_first_idx = mesh_to_face_first_idx + self.num_faces_per_mesh = num_faces_per_mesh + self.faces_clipped_to_unclipped_idx = faces_clipped_to_unclipped_idx + self.barycentric_conversion = barycentric_conversion + self.faces_clipped_to_conversion_idx = faces_clipped_to_conversion_idx + self.clipped_faces_neighbor_idx = clipped_faces_neighbor_idx + + +class ClipFrustum: + """ + Helper class to store the information needed to represent a view frustum + (left, right, top, bottom, znear, zfar), which is used to clip or cull triangles. + Values left as None mean that culling should not be performed for that axis. + The parameters perspective_correct, cull, and z_clip_value are used to define + behavior for clipping triangles to the frustum. + + Args: + left: NDC coordinate of the left clipping plane (along x axis) + right: NDC coordinate of the right clipping plane (along x axis) + top: NDC coordinate of the top clipping plane (along y axis) + bottom: NDC coordinate of the bottom clipping plane (along y axis) + znear: world space z coordinate of the near clipping plane + zfar: world space z coordinate of the far clipping plane + perspective_correct: should be set to True for a perspective camera + cull: if True, triangles outside the frustum should be culled + z_clip_value: if not None, then triangles should be clipped (possibly into + smaller triangles) such that z >= z_clip_value. This avoids projections + that go to infinity as z->0 + """ + + __slots__ = [ + "left", + "right", + "top", + "bottom", + "znear", + "zfar", + "perspective_correct", + "cull", + "z_clip_value", + ] + + def __init__( + self, + left: Optional[float] = None, + right: Optional[float] = None, + top: Optional[float] = None, + bottom: Optional[float] = None, + znear: Optional[float] = None, + zfar: Optional[float] = None, + perspective_correct: bool = False, + cull: bool = True, + z_clip_value: Optional[float] = None, + ): + self.left = left + self.right = right + self.top = top + self.bottom = bottom + self.znear = znear + self.zfar = zfar + self.perspective_correct = perspective_correct + self.cull = cull + self.z_clip_value = z_clip_value + + +def _get_culled_faces(face_verts: torch.Tensor, frustum: ClipFrustum) -> torch.Tensor: + """ + Helper function used to find all the faces in Meshes which are + fully outside the view frustum. A face is culled if all 3 vertices are outside + the same axis of the view frustum. + + Args: + face_verts: An (F,3,3) tensor, where F is the number of faces in + the packed representation of Meshes. The 2nd dimension represents the 3 vertices + of a triangle, and the 3rd dimension stores the xyz locations of each + vertex. + frustum: An instance of the ClipFrustum class with the information on the + position of the clipping planes. + + Returns: + faces_culled: An boolean tensor of size F specifying whether or not each face should be + culled. + """ + clipping_planes = ( + (frustum.left, 0, "<"), + (frustum.right, 0, ">"), + (frustum.top, 1, "<"), + (frustum.bottom, 1, ">"), + (frustum.znear, 2, "<"), + (frustum.zfar, 2, ">"), + ) + faces_culled = torch.zeros( + [face_verts.shape[0]], dtype=torch.bool, device=face_verts.device + ) + for plane in clipping_planes: + clip_value, axis, op = plane + # If clip_value is None then don't clip along that plane + if frustum.cull and clip_value is not None: + if op == "<": + verts_clipped = face_verts[:, axis] < clip_value + else: + verts_clipped = face_verts[:, axis] > clip_value + + # If all verts are clipped then face is outside the frustum + faces_culled |= verts_clipped.sum(1) == 3 + + return faces_culled + + +def _find_verts_intersecting_clipping_plane( + face_verts: torch.Tensor, + p1_face_ind: torch.Tensor, + clip_value: float, + perspective_correct: bool, +) -> Tuple[Tuple[Any, Any, Any, Any, Any], List[Any]]: + r""" + Helper function to find the vertices used to form a new triangle for case 3/case 4 faces. + + Given a list of triangles that are already known to intersect the clipping plane, + solve for the two vertices p4 and p5 where the edges of the triangle intersects the + clipping plane. + + p1 + /\ + / \ + / t \ + _____________p4/______\p5__________ clip_value + / \ + /____ \ + p2 ---____\p3 + + Args: + face_verts: An (F,3,3) tensor, where F is the number of faces in + the packed representation of the Meshes, the 2nd dimension represents + the 3 vertices of the face, and the 3rd dimension stores the xyz locations of each + vertex. The z-coordinates must be represented in world coordinates, while + the xy-coordinates may be in NDC/screen coordinates (i.e. after projection). + p1_face_ind: A tensor of shape (N,) with values in the range of 0 to 2. In each + case 3/case 4 triangle, two vertices are on the same side of the + clipping plane and the 3rd is on the other side. p1_face_ind stores the index of + the vertex that is not on the same side as any other vertex in the triangle. + clip_value: Float, the z-value defining where to clip the triangle. + perspective_correct: Bool, Should be set to true if a perspective camera was + used and xy-coordinates of face_verts_unclipped are in NDC/screen coordinates. + + Returns: + A 2-tuple + p: (p1, p2, p3, p4, p5)) + p_barycentric (p1_bary, p2_bary, p3_bary, p4_bary, p5_bary) + + Each of p1...p5 is an (F,3) tensor of the xyz locations of the 5 points in the + diagram above for case 3/case 4 faces. Each p1_bary...p5_bary is an (F, 3) tensor + storing the barycentric weights used to encode p1...p5 in terms of the the original + unclipped triangle. + """ + + # Let T be number of triangles in face_verts (note that these correspond to the subset + # of case 1 or case 2 triangles). p1_face_ind, p2_face_ind, and p3_face_ind are (T) + # tensors with values in the range of 0 to 2. p1_face_ind stores the index of the + # vertex that is not on the same side as any other vertex in the triangle, and + # p2_face_ind and p3_face_ind are the indices of the other two vertices preserving + # the same counterclockwise or clockwise ordering + T = face_verts.shape[0] + p2_face_ind = torch.remainder(p1_face_ind + 1, 3) + p3_face_ind = torch.remainder(p1_face_ind + 2, 3) + + # p1, p2, p3 are (T, 3) tensors storing the corresponding (x, y, z) coordinates + # of p1_face_ind, p2_face_ind, p3_face_ind + # pyre-ignore[16] + p1 = face_verts.gather(1, p1_face_ind[:, None, None].expand(-1, -1, 3)).squeeze(1) + p2 = face_verts.gather(1, p2_face_ind[:, None, None].expand(-1, -1, 3)).squeeze(1) + p3 = face_verts.gather(1, p3_face_ind[:, None, None].expand(-1, -1, 3)).squeeze(1) + + ################################## + # Solve for intersection point p4 + ################################## + + # p4 is a (T, 3) tensor is the point on the segment between p1 and p2 that + # intersects the clipping plane. + # Solve for the weight w2 such that p1.z*(1-w2) + p2.z*w2 = clip_value. + # Then interpolate p4 = p1*(1-w2) + p2*w2 where it is assumed that z-coordinates + # are expressed in world coordinates (since we want to clip z in world coordinates). + w2 = (p1[:, 2] - clip_value) / (p1[:, 2] - p2[:, 2]) + p4 = p1 * (1 - w2[:, None]) + p2 * w2[:, None] + if perspective_correct: + # It is assumed that all z-coordinates are in world coordinates (not NDC + # coordinates), while x and y coordinates may be in NDC/screen coordinates. + # If x and y are in NDC/screen coordinates and a projective transform was used + # in a perspective camera, then we effectively want to: + # 1. Convert back to world coordinates (by multiplying by z) + # 2. Interpolate using w2 + # 3. Convert back to NDC/screen coordinates (by dividing by the new z=clip_value) + p1_world = p1[:, :2] * p1[:, 2:3] + p2_world = p2[:, :2] * p2[:, 2:3] + p4[:, :2] = (p1_world * (1 - w2[:, None]) + p2_world * w2[:, None]) / clip_value + + ################################## + # Solve for intersection point p5 + ################################## + + # p5 is a (T, 3) tensor representing the point on the segment between p1 and p3 that + # intersects the clipping plane. + # Solve for the weight w3 such that p1.z * (1-w3) + p2.z * w3 = clip_value, + # and then interpolate p5 = p1 * (1-w3) + p3 * w3 + w3 = (p1[:, 2] - clip_value) / (p1[:, 2] - p3[:, 2]) + w3 = w3.detach() + p5 = p1 * (1 - w3[:, None]) + p3 * w3[:, None] + if perspective_correct: + # Again if using a perspective camera, convert back to world coordinates + # interpolate and convert back + p1_world = p1[:, :2] * p1[:, 2:3] + p3_world = p3[:, :2] * p3[:, 2:3] + p5[:, :2] = (p1_world * (1 - w3[:, None]) + p3_world * w3[:, None]) / clip_value + + # Set the barycentric coordinates of p1,p2,p3,p4,p5 in terms of the original + # unclipped triangle in face_verts. + T_idx = torch.arange(T, device=face_verts.device) + p_barycentric = [torch.zeros((T, 3), device=face_verts.device) for i in range(5)] + p_barycentric[0][(T_idx, p1_face_ind)] = 1 + p_barycentric[1][(T_idx, p2_face_ind)] = 1 + p_barycentric[2][(T_idx, p3_face_ind)] = 1 + p_barycentric[3][(T_idx, p1_face_ind)] = 1 - w2 + p_barycentric[3][(T_idx, p2_face_ind)] = w2 + p_barycentric[4][(T_idx, p1_face_ind)] = 1 - w3 + p_barycentric[4][(T_idx, p3_face_ind)] = w3 + + p = (p1, p2, p3, p4, p5) + + return p, p_barycentric + + +################### +# Main Entry point +################### +def clip_faces( + face_verts_unclipped: torch.Tensor, + mesh_to_face_first_idx: torch.Tensor, + num_faces_per_mesh: torch.Tensor, + frustum: ClipFrustum, +) -> ClippedFaces: + """ + Clip a mesh to the portion contained within a view frustum and with z > z_clip_value. + + There are two types of clipping: + 1) Cull triangles that are completely outside the view frustum. This is purely + to save computation by reducing the number of triangles that need to be + rasterized. + 2) Clip triangles into the portion of the triangle where z > z_clip_value. The + clipped region may be a quadrilateral, which results in splitting a triangle + into two triangles. This does not save computation, but is necessary to + correctly rasterize using perspective cameras for triangles that pass through + z <= 0, because NDC/screen coordinates go to infinity at z=0. + + Args: + face_verts_unclipped: An (F, 3, 3) tensor, where F is the number of faces in + the packed representation of Meshes, the 2nd dimension represents the 3 vertices + of the triangle, and the 3rd dimension stores the xyz locations of each + vertex. The z-coordinates must be represented in world coordinates, while + the xy-coordinates may be in NDC/screen coordinates + mesh_to_face_first_idx: an tensor of shape (N,), where N is the number of meshes + in the batch. The ith element stores the index into face_verts_unclipped + of the first face of the ith mesh. + num_faces_per_mesh: a tensor of shape (N,) storing the number of faces in each mesh. + frustum: a ClipFrustum object defining the frustum used to cull faces. + + Returns: + clipped_faces: ClippedFaces object storing a clipped version of the Meshes + along with tensors that can be used to convert barycentric coordinates + returned by rasterization of the clipped meshes into a barycentric + coordinates for the unclipped meshes. + """ + F = face_verts_unclipped.shape[0] + device = face_verts_unclipped.device + + # Triangles completely outside the view frustum will be culled + # faces_culled is of shape (F, ) + faces_culled = _get_culled_faces(face_verts_unclipped, frustum) + + # Triangles that are partially behind the z clipping plane will be clipped to + # smaller triangles + z_clip_value = frustum.z_clip_value + perspective_correct = frustum.perspective_correct + if z_clip_value is not None: + # (F, 3) tensor (where F is the number of triangles) marking whether each vertex + # in a triangle is behind the clipping plane + faces_clipped_verts = face_verts_unclipped[:, :, 2] < z_clip_value + + # (F) dim tensor containing the number of clipped vertices in each triangle + faces_num_clipped_verts = faces_clipped_verts.sum(1) + else: + faces_num_clipped_verts = torch.zeros([F, 3], device=device) + + # If no triangles need to be clipped or culled, avoid unnecessary computation + # and return early + if faces_num_clipped_verts.sum().item() == 0 and faces_culled.sum().item() == 0: + return ClippedFaces( + face_verts=face_verts_unclipped, + mesh_to_face_first_idx=mesh_to_face_first_idx, + num_faces_per_mesh=num_faces_per_mesh, + ) + + ##################################################################################### + # Classify faces into the 4 relevant cases: + # 1) The triangle is completely in front of the clipping plane (it is left + # unchanged) + # 2) The triangle is completely behind the clipping plane (it is culled) + # 3) The triangle has exactly two vertices behind the clipping plane (it is + # clipped into a smaller triangle) + # 4) The triangle has exactly one vertex behind the clipping plane (it is clipped + # into a smaller quadrilateral and split into two triangles) + ##################################################################################### + + # pyre-ignore[16]: + faces_unculled = ~faces_culled + # Case 1: no clipped verts or culled faces + cases1_unclipped = torch.logical_and(faces_num_clipped_verts == 0, faces_unculled) + case1_unclipped_idx = cases1_unclipped.nonzero(as_tuple=True)[0] + # Case 2: all verts clipped + case2_unclipped = torch.logical_or(faces_num_clipped_verts == 3, faces_culled) + # Case 3: two verts clipped + case3_unclipped = torch.logical_and(faces_num_clipped_verts == 2, faces_unculled) + case3_unclipped_idx = case3_unclipped.nonzero(as_tuple=True)[0] + # Case 4: one vert clipped + case4_unclipped = torch.logical_and(faces_num_clipped_verts == 1, faces_unculled) + case4_unclipped_idx = case4_unclipped.nonzero(as_tuple=True)[0] + + # faces_unclipped_to_clipped_idx is an (F) dim tensor storing the index of each + # face to the corresponding face in face_verts_clipped. + # Each case 2 triangle will be culled (deleted from face_verts_clipped), + # while each case 4 triangle will be split into two smaller triangles + # (replaced by two consecutive triangles in face_verts_clipped) + + # case2_unclipped is an (F,) dim 0/1 tensor of all the case2 faces + # case4_unclipped is an (F,) dim 0/1 tensor of all the case4 faces + faces_delta = case4_unclipped.int() - case2_unclipped.int() + # faces_delta_cum gives the per face change in index. Faces which are + # clipped in the original mesh are mapped to the closest non clipped face + # in face_verts_clipped (this doesn't matter as they are not used + # during rasterization anyway). + faces_delta_cum = faces_delta.cumsum(0) - faces_delta + delta = 1 + case4_unclipped.int() - case2_unclipped.int() + # pyre-ignore[16] + faces_unclipped_to_clipped_idx = delta.cumsum(0) - delta + + ########################################### + # Allocate tensors for the output Meshes. + # These will then be filled in for each case. + ########################################### + F_clipped = ( + F + faces_delta_cum[-1].item() + faces_delta[-1].item() + ) # Total number of faces in the new Meshes + face_verts_clipped = torch.zeros( + (F_clipped, 3, 3), dtype=face_verts_unclipped.dtype, device=device + ) + faces_clipped_to_unclipped_idx = torch.zeros( + [F_clipped], dtype=torch.int64, device=device + ) + + # Update version of mesh_to_face_first_idx and num_faces_per_mesh applicable to + # face_verts_clipped + mesh_to_face_first_idx_clipped = faces_unclipped_to_clipped_idx[ + mesh_to_face_first_idx + ] + F_clipped_t = torch.full([1], F_clipped, dtype=torch.int64, device=device) + num_faces_next = torch.cat((mesh_to_face_first_idx_clipped[1:], F_clipped_t)) + num_faces_per_mesh_clipped = num_faces_next - mesh_to_face_first_idx_clipped + + ################# Start Case 1 ######################################## + + # Case 1: Triangles are fully visible, copy unchanged triangles into the + # appropriate position in the new list of faces + case1_clipped_idx = faces_unclipped_to_clipped_idx[case1_unclipped_idx] + face_verts_clipped[case1_clipped_idx] = face_verts_unclipped[case1_unclipped_idx] + faces_clipped_to_unclipped_idx[case1_clipped_idx] = case1_unclipped_idx + + # If no triangles need to be clipped but some triangles were culled, avoid + # unnecessary clipping computation + if case3_unclipped_idx.shape[0] + case4_unclipped_idx.shape[0] == 0: + return ClippedFaces( + face_verts=face_verts_clipped, + mesh_to_face_first_idx=mesh_to_face_first_idx_clipped, + num_faces_per_mesh=num_faces_per_mesh_clipped, + faces_clipped_to_unclipped_idx=faces_clipped_to_unclipped_idx, + ) + + ################# End Case 1 ########################################## + + ################# Start Case 3 ######################################## + + # Case 3: exactly two vertices are behind the camera, clipping the triangle into a + # triangle. In the diagram below, we clip the bottom part of the triangle, and add + # new vertices p4 and p5 by intersecting with the clipping plane. The updated + # triangle is the triangle between p4, p1, p5 + # + # p1 (unclipped vertex) + # /\ + # / \ + # / t \ + # _____________p4/______\p5__________ clip_value + # xxxxxxxxxxxxxx/ \xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx + # xxxxxxxxxxxxx/____ \xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx + # xxxxxxxxxx p2 xxxx---____\p3 xxxxxxxxxxxxxxxxxxxxxxxxxxx + # xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx + faces_case3 = face_verts_unclipped[case3_unclipped_idx] + + # index (0, 1, or 2) of the vertex in front of the clipping plane + p1_face_ind = torch.where(~faces_clipped_verts[case3_unclipped_idx])[1] + + # Solve for the points p4, p5 that intersect the clipping plane + p, p_barycentric = _find_verts_intersecting_clipping_plane( + faces_case3, p1_face_ind, z_clip_value, perspective_correct + ) + + p1, _, _, p4, p5 = p + p1_barycentric, _, _, p4_barycentric, p5_barycentric = p_barycentric + + # Store clipped triangle + case3_clipped_idx = faces_unclipped_to_clipped_idx[case3_unclipped_idx] + t_barycentric = torch.stack((p4_barycentric, p5_barycentric, p1_barycentric), 2) + face_verts_clipped[case3_clipped_idx] = torch.stack((p4, p5, p1), 1) + faces_clipped_to_unclipped_idx[case3_clipped_idx] = case3_unclipped_idx + + ################# End Case 3 ########################################## + + ################# Start Case 4 ######################################## + + # Case 4: exactly one vertex is behind the camera, clip the triangle into a + # quadrilateral. In the diagram below, we clip the bottom part of the triangle, + # and add new vertices p4 and p5 by intersecting with the cliiping plane. The + # unclipped region is a quadrilateral, which is split into two triangles: + # t1: p4, p2, p5 + # t2: p5, p2, p3 + # + # p3_____________________p2 + # \ __--/ + # \ t2 __-- / + # \ __-- t1 / + # ______________p5\__--_________/p4_________clip_value + # xxxxxxxxxxxxxxxxx\ /xxxxxxxxxxxxxxxxxx + # xxxxxxxxxxxxxxxxxx\ /xxxxxxxxxxxxxxxxxxx + # xxxxxxxxxxxxxxxxxxx\ /xxxxxxxxxxxxxxxxxxxx + # xxxxxxxxxxxxxxxxxxxx\ /xxxxxxxxxxxxxxxxxxxxx + # xxxxxxxxxxxxxxxxxxxxx\ /xxxxxxxxxxxxxxxxxxxxx + # xxxxxxxxxxxxxxxxxxxxxx\ /xxxxxxxxxxxxxxxxxxxxx + # p1 (clipped vertex) + + faces_case4 = face_verts_unclipped[case4_unclipped_idx] + + # index (0, 1, or 2) of the vertex behind the clipping plane + p1_face_ind = torch.where(faces_clipped_verts[case4_unclipped_idx])[1] + + # Solve for the points p4, p5 that intersect the clipping plane + p, p_barycentric = _find_verts_intersecting_clipping_plane( + faces_case4, p1_face_ind, z_clip_value, perspective_correct + ) + _, p2, p3, p4, p5 = p + _, p2_barycentric, p3_barycentric, p4_barycentric, p5_barycentric = p_barycentric + + # Store clipped triangles + case4_clipped_idx = faces_unclipped_to_clipped_idx[case4_unclipped_idx] + face_verts_clipped[case4_clipped_idx] = torch.stack((p4, p2, p5), 1) + face_verts_clipped[case4_clipped_idx + 1] = torch.stack((p5, p2, p3), 1) + t1_barycentric = torch.stack((p4_barycentric, p2_barycentric, p5_barycentric), 2) + t2_barycentric = torch.stack((p5_barycentric, p2_barycentric, p3_barycentric), 2) + faces_clipped_to_unclipped_idx[case4_clipped_idx] = case4_unclipped_idx + faces_clipped_to_unclipped_idx[case4_clipped_idx + 1] = case4_unclipped_idx + + ##################### End Case 4 ######################### + + # Triangles that were clipped (case 3 & case 4) will require conversion of + # barycentric coordinates from being in terms of the smaller clipped triangle to in terms + # of the original big triangle. If there are T clipped triangles, + # barycentric_conversion is a (T, 3, 3) tensor, where barycentric_conversion[i, :, k] + # stores the barycentric weights in terms of the world coordinates of the original + # (big) triangle for the kth vertex in the clipped (small) triangle. If our + # rasterizer then expresses some NDC coordinate in terms of barycentric + # world coordinates for the clipped (small) triangle as alpha_clipped[i,:], + # alpha_unclipped[i, :] = barycentric_conversion[i, :, :]*alpha_clipped[i, :] + barycentric_conversion = torch.cat((t_barycentric, t1_barycentric, t2_barycentric)) + + # faces_clipped_to_conversion_idx is an (F_clipped,) shape tensor mapping each output + # face to the applicable row of barycentric_conversion (or set to -1 if conversion is + # not needed) + faces_to_convert_idx = torch.cat( + (case3_clipped_idx, case4_clipped_idx, case4_clipped_idx + 1), 0 + ) + barycentric_idx = torch.arange( + barycentric_conversion.shape[0], dtype=torch.int64, device=device + ) + faces_clipped_to_conversion_idx = torch.full( + [F_clipped], -1, dtype=torch.int64, device=device + ) + faces_clipped_to_conversion_idx[faces_to_convert_idx] = barycentric_idx + + # clipped_faces_quadrilateral_ind is an (F_clipped) dim tensor + # For case 4 clipped triangles (where a big triangle is split in two smaller triangles), + # store the index of the neighboring clipped triangle. + # This will be needed because if the soft rasterizer includes both + # triangles in the list of top K nearest triangles, we + # should only use the one with the smaller distance. + clipped_faces_neighbor_idx = torch.full( + [F_clipped], -1, dtype=torch.int64, device=device + ) + clipped_faces_neighbor_idx[case4_clipped_idx] = case4_clipped_idx + 1 + clipped_faces_neighbor_idx[case4_clipped_idx + 1] = case4_clipped_idx + + clipped_faces = ClippedFaces( + face_verts=face_verts_clipped, + mesh_to_face_first_idx=mesh_to_face_first_idx_clipped, + num_faces_per_mesh=num_faces_per_mesh_clipped, + faces_clipped_to_unclipped_idx=faces_clipped_to_unclipped_idx, + barycentric_conversion=barycentric_conversion, + faces_clipped_to_conversion_idx=faces_clipped_to_conversion_idx, + clipped_faces_neighbor_idx=clipped_faces_neighbor_idx, + ) + return clipped_faces diff --git a/tests/test_render_meshes_clipped.py b/tests/test_render_meshes_clipped.py new file mode 100644 index 000000000..474d730ff --- /dev/null +++ b/tests/test_render_meshes_clipped.py @@ -0,0 +1,352 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. + + +""" +Checks for mesh rasterization in the case where the camera enters the +inside of the mesh and some mesh faces are partially +behind the image plane. These faces are clipped and then rasterized. +See pytorch3d/renderer/mesh/clip.py for more details about the +clipping process. +""" +import unittest + +import torch +from common_testing import TestCaseMixin +from pytorch3d.renderer.mesh import ClipFrustum, clip_faces +from pytorch3d.structures.meshes import Meshes + + +class TestRenderMeshesClipping(TestCaseMixin, unittest.TestCase): + @staticmethod + def clip_faces(meshes): + verts_packed = meshes.verts_packed() + faces_packed = meshes.faces_packed() + face_verts = verts_packed[faces_packed] + mesh_to_face_first_idx = meshes.mesh_to_faces_packed_first_idx() + num_faces_per_mesh = meshes.num_faces_per_mesh() + + frustum = ClipFrustum( + left=-1, + right=1, + top=-1, + bottom=1, + # In the unit tests for each case below the triangles are asummed + # to have already been projected onto the image plane. + perspective_correct=False, + z_clip_value=1e-2, + cull=True, # Cull to frustrum + ) + + clipped_faces = clip_faces( + face_verts, mesh_to_face_first_idx, num_faces_per_mesh, frustum + ) + return clipped_faces + + def test_case_1(self): + """ + Case 1: Single triangle fully in front of the image plane (z=0) + Triangle is not clipped or culled. The triangle is asummed to have + already been projected onto the image plane so no perspective + correction is needed. + """ + device = "cuda:0" + verts = torch.tensor( + [[0.0, 0.0, 1.0], [1.0, 0.0, 1.0], [0.0, 1.0, 1.0]], + dtype=torch.float32, + device=device, + ) + faces = torch.tensor( + [ + [0, 1, 2], + ], + dtype=torch.int64, + device=device, + ) + meshes = Meshes(verts=[verts], faces=[faces]) + clipped_faces = self.clip_faces(meshes) + + self.assertClose(clipped_faces.face_verts, verts[faces]) + self.assertEqual(clipped_faces.mesh_to_face_first_idx.item(), 0) + self.assertEqual(clipped_faces.num_faces_per_mesh.item(), 1) + self.assertIsNone(clipped_faces.faces_clipped_to_unclipped_idx) + self.assertIsNone(clipped_faces.faces_clipped_to_conversion_idx) + self.assertIsNone(clipped_faces.clipped_faces_neighbor_idx) + self.assertIsNone(clipped_faces.barycentric_conversion) + + def test_case_2(self): + """ + Case 2 triangles are fully behind the image plane (z=0) so are completely culled. + Test with a single triangle behind the image plane. + """ + + device = "cuda:0" + verts = torch.tensor( + [[-1.0, 0.0, -1.0], [0.0, 1.0, -1.0], [1.0, 0.0, -1.0]], + dtype=torch.float32, + device=device, + ) + faces = torch.tensor( + [ + [0, 1, 2], + ], + dtype=torch.int64, + device=device, + ) + meshes = Meshes(verts=[verts], faces=[faces]) + clipped_faces = self.clip_faces(meshes) + + zero_t = torch.zeros(size=(1,), dtype=torch.int64, device=device) + self.assertClose( + clipped_faces.face_verts, torch.empty(device=device, size=(0, 3, 3)) + ) + self.assertClose(clipped_faces.mesh_to_face_first_idx, zero_t) + self.assertClose(clipped_faces.num_faces_per_mesh, zero_t) + self.assertClose( + clipped_faces.faces_clipped_to_unclipped_idx, + torch.empty(device=device, dtype=torch.int64, size=(0,)), + ) + self.assertIsNone(clipped_faces.faces_clipped_to_conversion_idx) + self.assertIsNone(clipped_faces.clipped_faces_neighbor_idx) + self.assertIsNone(clipped_faces.barycentric_conversion) + + def test_case_3(self): + """ + Case 3 triangles have exactly two vertices behind the clipping plane (z=0) so are + clipped into a smaller triangle. + + Test with a single triangle parallel to the z axis which intersects with + the image plane. + """ + + device = "cuda:0" + verts = torch.tensor( + [[-1.0, 0.0, -1.0], [0.0, 0.0, 1.0], [1.0, 0.0, -1.0]], + dtype=torch.float32, + device=device, + ) + faces = torch.tensor( + [ + [0, 1, 2], + ], + dtype=torch.int64, + device=device, + ) + meshes = Meshes(verts=[verts], faces=[faces]) + clipped_faces = self.clip_faces(meshes) + + zero_t = torch.zeros(size=(1,), dtype=torch.int64, device=device) + clipped_face_verts = torch.tensor( + [ + [ + [0.4950, 0.0000, 0.0100], + [-0.4950, 0.0000, 0.0100], + [0.0000, 0.0000, 1.0000], + ] + ], + device=device, + dtype=torch.float32, + ) + + # barycentric_conversion[i, :, k] stores the barycentric weights + # in terms of the world coordinates of the original + # (big) triangle for the kth vertex in the clipped (small) triangle. + barycentric_conversion = torch.tensor( + [ + [ + [0.0000, 0.4950, 0.0000], + [0.5050, 0.5050, 1.0000], + [0.4950, 0.0000, 0.0000], + ] + ], + device=device, + dtype=torch.float32, + ) + + self.assertClose(clipped_faces.face_verts, clipped_face_verts) + self.assertEqual(clipped_faces.mesh_to_face_first_idx.item(), 0) + self.assertEqual(clipped_faces.num_faces_per_mesh.item(), 1) + self.assertClose(clipped_faces.faces_clipped_to_unclipped_idx, zero_t) + self.assertClose(clipped_faces.faces_clipped_to_conversion_idx, zero_t) + self.assertClose( + clipped_faces.clipped_faces_neighbor_idx, + zero_t - 1, # default is -1 + ) + self.assertClose(clipped_faces.barycentric_conversion, barycentric_conversion) + + def test_case_4(self): + """ + Case 4 triangles have exactly 1 vertex behind the clipping plane (z=0) so + are clipped into a smaller quadrilateral and then divided into two triangles. + + Test with a single triangle parallel to the z axis which intersects with + the image plane. + """ + + device = "cuda:0" + verts = torch.tensor( + [[0.0, 0.0, -1.0], [-1.0, 0.0, 1.0], [1.0, 0.0, 1.0]], + dtype=torch.float32, + device=device, + ) + faces = torch.tensor( + [ + [0, 1, 2], + ], + dtype=torch.int64, + device=device, + ) + meshes = Meshes(verts=[verts], faces=[faces]) + clipped_faces = self.clip_faces(meshes) + + clipped_face_verts = torch.tensor( + [ + # t1 + [ + [-0.5050, 0.0000, 0.0100], + [-1.0000, 0.0000, 1.0000], + [0.5050, 0.0000, 0.0100], + ], + # t2 + [ + [0.5050, 0.0000, 0.0100], + [-1.0000, 0.0000, 1.0000], + [1.0000, 0.0000, 1.0000], + ], + ], + device=device, + dtype=torch.float32, + ) + + barycentric_conversion = torch.tensor( + [ + [ + [0.4950, 0.0000, 0.4950], + [0.5050, 1.0000, 0.0000], + [0.0000, 0.0000, 0.5050], + ], + [ + [0.4950, 0.0000, 0.0000], + [0.0000, 1.0000, 0.0000], + [0.5050, 0.0000, 1.0000], + ], + ], + device=device, + dtype=torch.float32, + ) + + self.assertClose(clipped_faces.face_verts, clipped_face_verts) + self.assertEqual(clipped_faces.mesh_to_face_first_idx.item(), 0) + self.assertEqual( + clipped_faces.num_faces_per_mesh.item(), 2 + ) # now two faces instead of 1 + self.assertClose( + clipped_faces.faces_clipped_to_unclipped_idx, + torch.tensor([0, 0], device=device, dtype=torch.int64), + ) + # Neighboring face for each of the sub triangles e.g. for t1, neighbor is t2, + # and for t2, neighbor is t1 + self.assertClose( + clipped_faces.clipped_faces_neighbor_idx, + torch.tensor([1, 0], device=device, dtype=torch.int64), + ) + # barycentric_conversion is of shape (F_clipped) + self.assertEqual(clipped_faces.barycentric_conversion.shape[0], 2) + self.assertClose(clipped_faces.barycentric_conversion, barycentric_conversion) + # Index into barycentric_conversion for each clipped face. + self.assertClose( + clipped_faces.faces_clipped_to_conversion_idx, + torch.tensor([0, 1], device=device, dtype=torch.int64), + ) + + def test_mixture_of_cases(self): + """ + Test with two meshes composed of different cases to check all the + indexing is correct. + Case 4 faces are subdivided into two faces which are referred + to as t1 and t2. + """ + device = "cuda:0" + # fmt: off + verts = [ + torch.tensor( + [ + [-1.0, 0.0, -1.0], # noqa: E241, E201 + [ 0.0, 1.0, -1.0], # noqa: E241, E201 + [ 1.0, 0.0, -1.0], # noqa: E241, E201 + [ 0.0, -1.0, -1.0], # noqa: E241, E201 + [-1.0, 0.5, 0.5], # noqa: E241, E201 + [ 1.0, 1.0, 1.0], # noqa: E241, E201 + [ 0.0, -1.0, 1.0], # noqa: E241, E201 + [-1.0, 0.5, -0.5], # noqa: E241, E201 + [ 1.0, 1.0, -1.0], # noqa: E241, E201 + [-1.0, 0.0, 1.0], # noqa: E241, E201 + [ 0.0, 1.0, 1.0], # noqa: E241, E201 + [ 1.0, 0.0, 1.0], # noqa: E241, E201 + ], + dtype=torch.float32, + device=device, + ), + torch.tensor( + [ + [ 0.0, -1.0, -1.0], # noqa: E241, E201 + [-1.0, 0.5, 0.5], # noqa: E241, E201 + [ 1.0, 1.0, 1.0], # noqa: E241, E201 + ], + dtype=torch.float32, + device=device + ) + ] + faces = [ + torch.tensor( + [ + [0, 1, 2], # noqa: E241, E201 Case 2 fully clipped + [3, 4, 5], # noqa: E241, E201 Case 4 clipped and subdivided + [5, 4, 3], # noqa: E241, E201 Repeat of Case 4 + [6, 7, 8], # noqa: E241, E201 Case 3 clipped + [9, 10, 11], # noqa: E241, E201 Case 1 untouched + ], + dtype=torch.int64, + device=device, + ), + torch.tensor( + [ + [0, 1, 2], # noqa: E241, E201 Case 4 + ], + dtype=torch.int64, + device=device, + ), + ] + # fmt: on + meshes = Meshes(verts=verts, faces=faces) + + # Clip meshes + clipped_faces = self.clip_faces(meshes) + + # mesh 1: 4x faces (from Case 4) + 1 (from Case 3) + 1 (from Case 1) + # mesh 2: 2x faces (from Case 4) + self.assertEqual(clipped_faces.face_verts.shape[0], 6 + 2) + + # dummy idx type tensor to avoid having to initialize the dype/device each time + idx = torch.empty(size=(1,), dtype=torch.int64, device=device) + unclipped_idx = idx.new_tensor([1, 1, 2, 2, 3, 4, 5, 5]) + neighbors = idx.new_tensor([1, 0, 3, 2, -1, -1, 7, 6]) + first_idx = idx.new_tensor([0, 6]) + num_faces = idx.new_tensor([6, 2]) + + self.assertClose(clipped_faces.clipped_faces_neighbor_idx, neighbors) + self.assertClose(clipped_faces.faces_clipped_to_unclipped_idx, unclipped_idx) + self.assertClose(clipped_faces.mesh_to_face_first_idx, first_idx) + self.assertClose(clipped_faces.num_faces_per_mesh, num_faces) + + # faces_clipped_to_conversion_idx maps each output face to the + # corresponding row of the barycentric_conversion matrix. + # The barycentric_conversion matrix is composed by + # finding the barycentric conversion weights for case 3 faces + # case 4 (t1) faces and case 4 (t2) faces. These are then + # concatenated. Therefore case 3 faces will be the first rows of + # the barycentric_conversion matrix followed by t1 and then t2. + # Case type of all faces: [4 (t1), 4 (t2), 4 (t1), 4 (t2), 3, 1, 4 (t1), 4 (t2)] + # Based on this information we can calculate the indices into the + # barycentric conversion matrix. + bary_idx = idx.new_tensor([1, 4, 2, 5, 0, -1, 3, 6]) + self.assertClose(clipped_faces.faces_clipped_to_conversion_idx, bary_idx)