Skip to content

Commit

Permalink
multigpu mesh rendering fixes
Browse files Browse the repository at this point in the history
Summary:
Small fix and updated tests for multigpu rendering case.

This resolves the issue seen in: #401

Reviewed By: gkioxari

Differential Revision: D24314681

fbshipit-source-id: 84c5a5359844c77518b48044001daa9a86f3c43a
  • Loading branch information
nikhilaravi authored and facebook-github-bot committed Oct 16, 2020
1 parent 4d52f9f commit 563d441
Show file tree
Hide file tree
Showing 3 changed files with 162 additions and 66 deletions.
5 changes: 3 additions & 2 deletions pytorch3d/renderer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@

import numpy as np
import torch
import torch.nn as nn


class TensorAccessor(object):
class TensorAccessor(nn.Module):
"""
A helper class to be used with the __getitem__ method. This can be used for
getting/setting the values for an attribute of a class at one particular
Expand Down Expand Up @@ -82,7 +83,7 @@ def __getattr__(self, name: str):
BROADCAST_TYPES = (float, int, list, tuple, torch.Tensor, np.ndarray)


class TensorProperties(object):
class TensorProperties(nn.Module):
"""
A mix-in class for storing tensors as properties with helper methods.
"""
Expand Down
64 changes: 0 additions & 64 deletions tests/test_render_meshes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1042,67 +1042,3 @@ def test_simple_sphere_outside_zfar(self):
)

self.assertClose(rgb, image_ref, atol=0.05)

def test_to(self):
# Test moving all the tensors in the renderer to a new device
# to support multigpu rendering.
device1 = torch.device("cpu")

R, T = look_at_view_transform(1500, 0.0, 0.0)

# Init shader settings
materials = Materials(device=device1)
lights = PointLights(device=device1)
lights.location = torch.tensor([0.0, 0.0, +1000.0], device=device1)[None]

raster_settings = RasterizationSettings(
image_size=256, blur_radius=0.0, faces_per_pixel=1
)
cameras = FoVPerspectiveCameras(
device=device1, R=R, T=T, aspect_ratio=1.0, fov=60.0, zfar=100
)
rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings)

blend_params = BlendParams(
1e-4,
1e-4,
background_color=torch.zeros(3, dtype=torch.float32, device=device1),
)

shader = SoftPhongShader(
lights=lights,
cameras=cameras,
materials=materials,
blend_params=blend_params,
)
renderer = MeshRenderer(rasterizer=rasterizer, shader=shader)

def _check_props_on_device(renderer, device):
self.assertEqual(renderer.rasterizer.cameras.device, device)
self.assertEqual(renderer.shader.cameras.device, device)
self.assertEqual(renderer.shader.lights.device, device)
self.assertEqual(renderer.shader.lights.ambient_color.device, device)
self.assertEqual(renderer.shader.materials.device, device)
self.assertEqual(renderer.shader.materials.ambient_color.device, device)

mesh = ico_sphere(2, device1)
verts_padded = mesh.verts_padded()
textures = TexturesVertex(
verts_features=torch.ones_like(verts_padded, device=device1)
)
mesh.textures = textures
_check_props_on_device(renderer, device1)

# Test rendering on cpu
output_images = renderer(mesh)
self.assertEqual(output_images.device, device1)

# Move renderer and mesh to another device and re render
# This also tests that background_color is correctly moved to
# the new device
device2 = torch.device("cuda:0")
renderer.to(device2)
mesh = mesh.to(device2)
_check_props_on_device(renderer, device2)
output_images = renderer(mesh)
self.assertEqual(output_images.device, device2)
159 changes: 159 additions & 0 deletions tests/test_render_multigpu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

import unittest

import torch
import torch.nn as nn
from common_testing import TestCaseMixin, get_random_cuda_device
from pytorch3d.renderer import (
BlendParams,
HardGouraudShader,
Materials,
MeshRasterizer,
MeshRenderer,
PointLights,
RasterizationSettings,
SoftPhongShader,
TexturesVertex,
)
from pytorch3d.renderer.cameras import FoVPerspectiveCameras, look_at_view_transform
from pytorch3d.structures.meshes import Meshes
from pytorch3d.utils.ico_sphere import ico_sphere


# Set the number of GPUS you want to test with
NUM_GPUS = 3
GPU_LIST = list({get_random_cuda_device() for _ in range(NUM_GPUS)})
print("GPUs: %s" % ", ".join(GPU_LIST))


class TestRenderMultiGPU(TestCaseMixin, unittest.TestCase):
def _check_mesh_renderer_props_on_device(self, renderer, device):
"""
Helper function to check that all the properties of the mesh
renderer have been moved to the correct device.
"""
# Cameras
self.assertEqual(renderer.rasterizer.cameras.device, device)
self.assertEqual(renderer.rasterizer.cameras.R.device, device)
self.assertEqual(renderer.rasterizer.cameras.T.device, device)
self.assertEqual(renderer.shader.cameras.device, device)
self.assertEqual(renderer.shader.cameras.R.device, device)
self.assertEqual(renderer.shader.cameras.T.device, device)

# Lights and Materials
self.assertEqual(renderer.shader.lights.device, device)
self.assertEqual(renderer.shader.lights.ambient_color.device, device)
self.assertEqual(renderer.shader.materials.device, device)
self.assertEqual(renderer.shader.materials.ambient_color.device, device)

def test_mesh_renderer_to(self):
"""
Test moving all the tensors in the mesh renderer to a new device.
"""

device1 = torch.device("cpu")

R, T = look_at_view_transform(1500, 0.0, 0.0)

# Init shader settings
materials = Materials(device=device1)
lights = PointLights(device=device1)
lights.location = torch.tensor([0.0, 0.0, +1000.0], device=device1)[None]

raster_settings = RasterizationSettings(
image_size=256, blur_radius=0.0, faces_per_pixel=1
)
cameras = FoVPerspectiveCameras(
device=device1, R=R, T=T, aspect_ratio=1.0, fov=60.0, zfar=100
)
rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings)

blend_params = BlendParams(
1e-4,
1e-4,
background_color=torch.zeros(3, dtype=torch.float32, device=device1),
)

shader = SoftPhongShader(
lights=lights,
cameras=cameras,
materials=materials,
blend_params=blend_params,
)
renderer = MeshRenderer(rasterizer=rasterizer, shader=shader)

mesh = ico_sphere(2, device1)
verts_padded = mesh.verts_padded()
textures = TexturesVertex(
verts_features=torch.ones_like(verts_padded, device=device1)
)
mesh.textures = textures
self._check_mesh_renderer_props_on_device(renderer, device1)

# Test rendering on cpu
output_images = renderer(mesh)
self.assertEqual(output_images.device, device1)

# Move renderer and mesh to another device and re render
# This also tests that background_color is correctly moved to
# the new device
device2 = torch.device("cuda:0")
renderer.to(device2)
mesh = mesh.to(device2)
self._check_mesh_renderer_props_on_device(renderer, device2)
output_images = renderer(mesh)
self.assertEqual(output_images.device, device2)

def test_render_meshes(self):
test = self

class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
mesh = ico_sphere(3)
self.register_buffer("faces", mesh.faces_padded())
self.renderer = self.init_render()

def init_render(self):

cameras = FoVPerspectiveCameras()
raster_settings = RasterizationSettings(
image_size=128, blur_radius=0.0, faces_per_pixel=1
)
lights = PointLights(
ambient_color=((1.0, 1.0, 1.0),),
diffuse_color=((0, 0.0, 0),),
specular_color=((0.0, 0, 0),),
location=((0.0, 0.0, 1e5),),
)
renderer = MeshRenderer(
rasterizer=MeshRasterizer(
cameras=cameras, raster_settings=raster_settings
),
shader=HardGouraudShader(cameras=cameras, lights=lights),
)
return renderer

def forward(self, verts, texs):
batch_size = verts.size(0)
self.renderer.to(verts.device)
tex = TexturesVertex(verts_features=texs)
faces = self.faces.expand(batch_size, -1, -1).to(verts.device)
mesh = Meshes(verts, faces, tex).to(verts.device)

test._check_mesh_renderer_props_on_device(self.renderer, verts.device)
img_render = self.renderer(mesh)
return img_render[:, :, :, :3]

# DataParallel requires every input tensor be provided
# on the first device in its device_ids list.
verts = ico_sphere(3).verts_padded()
texs = verts.new_ones(verts.shape)
model = Model()
model = nn.DataParallel(model, device_ids=GPU_LIST)
model.to(f"cuda:{model.device_ids[0]}")

# Test a few iterations
for _ in range(100):
model(verts, texs)

0 comments on commit 563d441

Please sign in to comment.