Skip to content

Commit

Permalink
Alpha channel to return the mask
Browse files Browse the repository at this point in the history
Summary: Updated the alpha channel in the `hard_rgb_blend` function to return the mask of the pixels which have overlapping mesh faces.

Reviewed By: bottler

Differential Revision: D29001604

fbshipit-source-id: 22a2173d769f2d3ad34892d68ceb628f073bca22
  • Loading branch information
nikhilaravi authored and facebook-github-bot committed Jun 9, 2021
1 parent ac6c07f commit a15c33a
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 3 deletions.
3 changes: 2 additions & 1 deletion pytorch3d/renderer/blending.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@ def hard_rgb_blend(colors, fragments, blend_params) -> torch.Tensor:
) # (N, H, W, 3)

# Concat with the alpha channel.
alpha = torch.ones((N, H, W, 1), dtype=colors.dtype, device=device)
alpha = (~is_background).type_as(pixel_colors)[..., None]

return torch.cat([pixel_colors, alpha], dim=-1) # (N, H, W, 4)


Expand Down
4 changes: 2 additions & 2 deletions tests/test_blending.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,8 +184,8 @@ def test_hard_rgb_blend(self):
channel_color = blend_params.background_color[i]
self.assertTrue(images[~is_foreground][..., i].eq(channel_color).all())

# Examine the alpha channel is correct
self.assertTrue(images[..., 3].eq(1).all())
# Examine the alpha channel
self.assertClose(images[..., 3], (pix_to_face[..., 0] >= 0).float())

def test_sigmoid_alpha_blend_manual_gradients(self):
# Create dummy outputs of rasterization
Expand Down
8 changes: 8 additions & 0 deletions tests/test_render_meshes.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,10 @@ def test_simple_sphere(self, elevated_camera=False, check_depth=False):
)
images, fragments = renderer(sphere_mesh)
self.assertClose(fragments.zbuf, rasterizer(sphere_mesh).zbuf)
# Check the alpha channel is the mask
self.assertClose(
images[..., -1], (fragments.pix_to_face[..., 0] >= 0).float()
)
else:
renderer = MeshRenderer(rasterizer=rasterizer, shader=shader)
images = renderer(sphere_mesh)
Expand Down Expand Up @@ -165,6 +169,10 @@ def test_simple_sphere(self, elevated_camera=False, check_depth=False):
self.assertClose(
fragments.zbuf, rasterizer(sphere_mesh, lights=lights).zbuf
)
# Check the alpha channel is the mask
self.assertClose(
images[..., -1], (fragments.pix_to_face[..., 0] >= 0).float()
)
else:
phong_renderer = MeshRenderer(
rasterizer=rasterizer, shader=phong_shader
Expand Down

0 comments on commit a15c33a

Please sign in to comment.