Skip to content

Commit

Permalink
Adding join_mesh in pytorch3d.structures.meshes
Browse files Browse the repository at this point in the history
Summary: Adding a function in pytorch3d.structures.meshes to join multiple meshes into a Meshes object representing a single mesh. The function currently ignores all textures.

Reviewed By: nikhilaravi

Differential Revision: D21876908

fbshipit-source-id: 448602857e9d3d3f774d18bb4e93076f78329823
  • Loading branch information
megluyagao authored and facebook-github-bot committed Jun 9, 2020
1 parent 4b78e95 commit e053d7c
Show file tree
Hide file tree
Showing 6 changed files with 132 additions and 3 deletions.
27 changes: 26 additions & 1 deletion pytorch3d/structures/meshes.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

from typing import List
from typing import List, Union

import torch

Expand Down Expand Up @@ -1539,3 +1539,28 @@ def join_meshes_as_batch(meshes: List[Meshes], include_textures: bool = True):

tex = Textures(**kwargs)
return Meshes(verts=verts, faces=faces, textures=tex)


def join_mesh(meshes: Union[Meshes, List[Meshes]]) -> Meshes:
"""
Joins a batch of meshes in the form of a Meshes object or a list of Meshes
objects as a single mesh. If the input is a list, the Meshes objects in the list
must all be on the same device. This version ignores all textures in the input mehses.
Args:
meshes: Meshes object that contains a batch of meshes or a list of Meshes objects
Returns:
new Meshes object containing a single mesh
"""
if isinstance(meshes, List):
meshes = join_meshes_as_batch(meshes, include_textures=False)

if len(meshes) == 1:
return meshes
verts = meshes.verts_packed() # (sum(V_n), 3)
# Offset automatically done by faces_packed
faces = meshes.faces_packed() # (sum(F_n), 3)

mesh = Meshes(verts=verts.unsqueeze(0), faces=faces.unsqueeze(0))
return mesh
Binary file added tests/data/test_joined_spheres_flat.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/data/test_joined_spheres_gouraud.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/data/test_joined_spheres_phong.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
37 changes: 37 additions & 0 deletions tests/test_obj_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
_bilinear_interpolation_vectorized,
)
from pytorch3d.structures import Meshes, Textures, join_meshes_as_batch
from pytorch3d.structures.meshes import join_mesh
from pytorch3d.utils import torus


Expand Down Expand Up @@ -648,6 +649,42 @@ def check_item(x, y):
self.assertClose(cow3_tea.verts_list()[3], mesh_teapot.verts_list()[0])
self.assertClose(cow3_tea.faces_list()[3], mesh_teapot.faces_list()[0])

def test_join_meshes(self):
"""
Test that join_mesh joins single meshes and the corresponding values are
consistent with the single meshes.
"""

# Load cow mesh.
DATA_DIR = Path(__file__).resolve().parent.parent / "docs/tutorials/data"
cow_obj = DATA_DIR / "cow_mesh/cow.obj"

cow_mesh = load_objs_as_meshes([cow_obj])
cow_verts, cow_faces = cow_mesh.get_mesh_verts_faces(0)
# Join a batch of three single meshes and check that the values are consistent
# with the individual meshes.
cow_mesh3 = join_mesh([cow_mesh, cow_mesh, cow_mesh])

def check_item(x, y, offset):
self.assertClose(torch.cat([x, x + offset, x + 2 * offset], dim=1), y)

check_item(cow_mesh.verts_padded(), cow_mesh3.verts_padded(), 0)
check_item(cow_mesh.faces_padded(), cow_mesh3.faces_padded(), cow_mesh._V)

# Test the joining of meshes of different sizes.
teapot_obj = DATA_DIR / "teapot.obj"
teapot_mesh = load_objs_as_meshes([teapot_obj])
teapot_verts, teapot_faces = teapot_mesh.get_mesh_verts_faces(0)

mix_mesh = join_mesh([cow_mesh, teapot_mesh])
mix_verts, mix_faces = mix_mesh.get_mesh_verts_faces(0)
self.assertEqual(len(mix_mesh), 1)

self.assertClose(mix_verts[: cow_mesh._V], cow_verts)
self.assertClose(mix_faces[: cow_mesh._F], cow_faces)
self.assertClose(mix_verts[cow_mesh._V :], teapot_verts)
self.assertClose(mix_faces[cow_mesh._F :], teapot_faces + cow_mesh._V)

@staticmethod
def _bm_save_obj(verts: torch.Tensor, faces: torch.Tensor, decimal_places: int):
return lambda: save_obj(StringIO(), verts, faces, decimal_places)
Expand Down
71 changes: 69 additions & 2 deletions tests/test_render_meshes.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
TexturedSoftPhongShader,
)
from pytorch3d.renderer.mesh.texturing import Textures
from pytorch3d.structures.meshes import Meshes
from pytorch3d.structures.meshes import Meshes, join_mesh
from pytorch3d.utils.ico_sphere import ico_sphere


Expand Down Expand Up @@ -176,7 +176,7 @@ def test_simple_sphere_batched(self):
# Init renderer
rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings)
shaders = {
"phong": HardGouraudShader,
"phong": HardPhongShader,
"gouraud": HardGouraudShader,
"flat": HardFlatShader,
}
Expand Down Expand Up @@ -369,3 +369,70 @@ def test_texture_map(self):
)

self.assertClose(rgb, image_ref, atol=0.05)

def test_joined_spheres(self):
"""
Test a list of Meshes can be joined as a single mesh and
the single mesh is rendered correctly with Phong, Gouraud
and Flat Shaders.
"""
device = torch.device("cuda:0")

# Init mesh with vertex textures.
# Initialize a list containing two ico spheres of different sizes.
sphere_list = [ico_sphere(3, device), ico_sphere(4, device)]
# [(42 verts, 80 faces), (162 verts, 320 faces)]
# The scale the vertices need to be set at to resize the spheres
scales = [0.25, 1]
# The distance the spheres ought to be offset horizontally to prevent overlap.
offsets = [1.2, -0.3]
# Initialize a list containing the adjusted sphere meshes.
sphere_mesh_list = []
for i in range(len(sphere_list)):
verts = sphere_list[i].verts_padded() * scales[i]
verts[0, :, 0] += offsets[i]
sphere_mesh_list.append(
Meshes(verts=verts, faces=sphere_list[i].faces_padded())
)
joined_sphere_mesh = join_mesh(sphere_mesh_list)
joined_sphere_mesh.textures = Textures(
verts_rgb=torch.ones_like(joined_sphere_mesh.verts_padded())
)

# Init rasterizer settings
R, T = look_at_view_transform(2.7, 0.0, 0.0)
cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)
raster_settings = RasterizationSettings(
image_size=512, blur_radius=0.0, faces_per_pixel=1
)

# Init shader settings
materials = Materials(device=device)
lights = PointLights(device=device)
lights.location = torch.tensor([0.0, 0.0, +2.0], device=device)[None]
blend_params = BlendParams(1e-4, 1e-4, (0, 0, 0))

# Init renderer
rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings)
shaders = {
"phong": HardPhongShader,
"gouraud": HardGouraudShader,
"flat": HardFlatShader,
}
for (name, shader_init) in shaders.items():
shader = shader_init(
lights=lights,
cameras=cameras,
materials=materials,
blend_params=blend_params,
)
renderer = MeshRenderer(rasterizer=rasterizer, shader=shader)
image = renderer(joined_sphere_mesh)
rgb = image[..., :3].squeeze().cpu()
if DEBUG:
file_name = "DEBUG_joined_spheres_%s.png" % name
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
DATA_DIR / file_name
)
image_ref = load_rgb_image("test_joined_spheres_%s.png" % name, DATA_DIR)
self.assertClose(rgb, image_ref, atol=0.05)

1 comment on commit e053d7c

@ANABUR920
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The latest version seems cannot find this function

has no attribute ”join_mesh“

Please sign in to comment.