diff --git a/docs/tutorials/render_textured_meshes.ipynb b/docs/tutorials/render_textured_meshes.ipynb index 0c5426a95..657bdb968 100644 --- a/docs/tutorials/render_textured_meshes.ipynb +++ b/docs/tutorials/render_textured_meshes.ipynb @@ -93,7 +93,7 @@ "\n", "# Data structures and functions for rendering\n", "from pytorch3d.structures import Meshes\n", - "from pytorch3d.vis import AxisArgs, plot_batch_individually, plot_scene\n", + "from pytorch3d.vis import AxisArgs, plot_batch_individually, plot_scene, texturesuv_image_matplotlib\n", "from pytorch3d.renderer import (\n", " look_at_view_transform,\n", " FoVPerspectiveCameras, \n", @@ -236,8 +236,7 @@ "obj_filename = os.path.join(DATA_DIR, \"cow_mesh/cow.obj\")\n", "\n", "# Load obj file\n", - "mesh = load_objs_as_meshes([obj_filename], device=device)\n", - "texture_image=mesh.textures.maps_padded()" + "mesh = load_objs_as_meshes([obj_filename], device=device)" ] }, { @@ -265,9 +264,29 @@ "outputs": [], "source": [ "plt.figure(figsize=(7,7))\n", + "texture_image=mesh.textures.maps_padded()\n", "plt.imshow(texture_image.squeeze().cpu().numpy())\n", "plt.grid(\"off\");\n", - "plt.axis('off');" + "plt.axis(\"off\");" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "PyTorch3D has a built-in way to view the texture map with matplotlib along with the points on the map corresponding to vertices. There is also a method, texturesuv_image_PIL, to get a similar image which can be saved to a file." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plt.figure(figsize=(7,7))\n", + "texturesuv_image_matplotlib(mesh.textures, subsample=None)\n", + "plt.grid(\"off\");\n", + "plt.axis(\"off\");" ] }, { diff --git a/pytorch3d/renderer/mesh/textures.py b/pytorch3d/renderer/mesh/textures.py index 5f28ce7f9..9a58b841b 100644 --- a/pytorch3d/renderer/mesh/textures.py +++ b/pytorch3d/renderer/mesh/textures.py @@ -1174,6 +1174,42 @@ def join_scene(self) -> "TexturesUV": padding_mode=self.padding_mode, ) + def centers_for_image(self, index): + """ + Return the locations in the texture map which correspond to the given + verts_uvs, for one of the meshes. This is potentially useful for + visualizing the data. See the texturesuv_image_matplotlib and + texturesuv_image_PIL functions. + + Args: + index: batch index of the mesh whose centers to return. + + Returns: + centers: coordinates of points in the texture image + - a FloatTensor of shape (V,2) + """ + if self._N != 1: + raise ValueError( + "This function only supports plotting textures for one mesh." + ) + texture_image = self.maps_padded() + verts_uvs = self.verts_uvs_list()[index][None] + _, H, W, _3 = texture_image.shape + coord1 = torch.arange(W).expand(H, W) + coord2 = torch.arange(H)[:, None].expand(H, W) + coords = torch.stack([coord1, coord2])[None] + with torch.no_grad(): + # Get xy cartesian coordinates based on the uv coordinates + centers = F.grid_sample( + torch.flip(coords.to(texture_image), [2]), + # Convert from [0, 1] -> [-1, 1] range expected by grid sample + verts_uvs[:, None] * 2.0 - 1, + align_corners=self.align_corners, + padding_mode=self.padding_mode, + ).cpu() + centers = centers[0, :, 0].T + return centers + class TexturesVertex(TexturesBase): def __init__( diff --git a/pytorch3d/vis/__init__.py b/pytorch3d/vis/__init__.py index 3dfbf532a..6f7d948b5 100644 --- a/pytorch3d/vis/__init__.py +++ b/pytorch3d/vis/__init__.py @@ -1,6 +1,7 @@ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. from .plotly_vis import AxisArgs, Lighting, plot_batch_individually, plot_scene +from .texture_vis import texturesuv_image_matplotlib, texturesuv_image_PIL __all__ = [k for k in globals().keys() if not k.startswith("_")] diff --git a/pytorch3d/vis/texture_vis.py b/pytorch3d/vis/texture_vis.py new file mode 100644 index 000000000..05d375932 --- /dev/null +++ b/pytorch3d/vis/texture_vis.py @@ -0,0 +1,104 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +from typing import Optional + +import numpy as np +from PIL import Image, ImageDraw +from pytorch3d.renderer.mesh import TexturesUV + + +def texturesuv_image_matplotlib( + texture: TexturesUV, + *, + texture_index: int = 0, + radius: float = 1, + color=(1.0, 0.0, 0.0), + subsample: Optional[int] = 10000, + origin: str = "upper", +): + """ + Plot the texture image for one element of a TexturesUV with + matplotlib together with verts_uvs positions circled. + In particular a value in verts_uvs which is never referenced + in faces_uvs will still be plotted. + This is for debugging purposes, e.g. to align the map with + the uv coordinates. In particular, matplotlib + is used which is not an official dependency of PyTorch3D. + + Args: + texture: a TexturesUV object with one mesh + texture_index: index in the batch to plot + radius: plotted circle radius in pixels + color: any matplotlib-understood color for the circles. + subsample: if not None, number of points to plot. + Otherwise all points are plotted. + origin: "upper" or "lower" like matplotlib.imshow + """ + + import matplotlib.pyplot as plt + from matplotlib.patches import Circle + + texture_image = texture.maps_padded() + centers = texture.centers_for_image(index=texture_index).numpy() + + ax = plt.gca() + ax.imshow(texture_image[texture_index].detach().cpu().numpy(), origin=origin) + + n_points = centers.shape[0] + if subsample is None or n_points <= subsample: + indices = range(n_points) + else: + indices = np.random.choice(n_points, subsample, replace=False) + for i in indices: + # setting clip_on=False makes it obvious when + # we have UV coordinates outside the correct range + ax.add_patch(Circle(centers[i], radius, color=color, clip_on=False)) + + +def texturesuv_image_PIL( + texture: TexturesUV, + *, + texture_index: int = 0, + radius: float = 1, + color="red", + subsample: Optional[int] = 10000, +): + """ + Return a PIL image of the texture image of one element of the batch + from a TexturesUV, together with the verts_uvs positions circled. + In particular a value in verts_uvs which is never referenced + in faces_uvs will still be plotted. + This is for debugging purposes, e.g. to align the map with + the uv coordinates. In particular, matplotlib + is used which is not an official dependency of PyTorch3D. + + Args: + texture: a TexturesUV object with one mesh + texture_index: index in the batch to plot + radius: plotted circle radius in pixels + color: any PIL-understood color for the circles. + subsample: if not None, number of points to plot. + Otherwise all points are plotted. + + Returns: + PIL Image object. + """ + + centers = texture.centers_for_image(index=texture_index).numpy() + texture_image = texture.maps_padded() + texture_array = (texture_image[texture_index] * 255).cpu().numpy().astype(np.uint8) + + image = Image.fromarray(texture_array) + draw = ImageDraw.Draw(image) + + n_points = centers.shape[0] + if subsample is None or n_points <= subsample: + indices = range(n_points) + else: + indices = np.random.choice(n_points, subsample, replace=False) + + for i in indices: + x = centers[i][0] + y = centers[i][1] + draw.ellipse([(x - radius, y - radius), (x + radius, y + radius)], fill=color) + + return image diff --git a/tests/data/texturesuv_debug.png b/tests/data/texturesuv_debug.png new file mode 100644 index 000000000..285806541 Binary files /dev/null and b/tests/data/texturesuv_debug.png differ diff --git a/tests/test_texturing.py b/tests/test_texturing.py index c10b36325..6f29feb65 100644 --- a/tests/test_texturing.py +++ b/tests/test_texturing.py @@ -2,10 +2,13 @@ import unittest +from pathlib import Path +import numpy as np import torch import torch.nn.functional as F from common_testing import TestCaseMixin +from PIL import Image from pytorch3d.renderer.mesh.rasterizer import Fragments from pytorch3d.renderer.mesh.textures import ( TexturesAtlas, @@ -15,9 +18,14 @@ pack_rectangles, ) from pytorch3d.structures import Meshes, list_to_packed, packed_to_list +from pytorch3d.vis import texturesuv_image_PIL from test_meshes import TestMeshes +DEBUG = False +DATA_DIR = Path(__file__).resolve().parent / "data" + + def tryindex(self, index, tex, meshes, source): tex2 = tex[index] meshes2 = meshes[index] @@ -471,6 +479,10 @@ def test_getitem(self): class TestTexturesUV(TestCaseMixin, unittest.TestCase): + def setUp(self) -> None: + super().setUp() + torch.manual_seed(42) + def test_sample_textures_uv(self): barycentric_coords = torch.tensor( [[0.5, 0.3, 0.2], [0.3, 0.6, 0.1]], dtype=torch.float32 @@ -821,6 +833,22 @@ def test_getitem(self): tryindex(self, index, tex, meshes, source) tryindex(self, [2, 4], tex, meshes, source) + def test_png_debug(self): + maps = torch.rand(size=(1, 256, 128, 3)) * torch.tensor([0.8, 1, 0.8]) + verts_uvs = torch.rand(size=(1, 20, 2)) + faces_uvs = torch.zeros(size=(1, 0, 3), dtype=torch.int64) + tex = TexturesUV(maps=maps, faces_uvs=faces_uvs, verts_uvs=verts_uvs) + + image = texturesuv_image_PIL(tex, radius=3) + image_out = np.array(image) + if DEBUG: + image.save(DATA_DIR / "texturesuv_debug_.png") + + with Image.open(DATA_DIR / "texturesuv_debug.png") as image_ref_file: + image_ref = np.array(image_ref_file) + + self.assertClose(image_out, image_ref) + class TestRectanglePacking(TestCaseMixin, unittest.TestCase): def setUp(self) -> None: