diff --git a/pytorch3d/renderer/blending.py b/pytorch3d/renderer/blending.py index 23df98003..d91d556ec 100644 --- a/pytorch3d/renderer/blending.py +++ b/pytorch3d/renderer/blending.py @@ -58,7 +58,8 @@ def hard_rgb_blend(colors, fragments, blend_params) -> torch.Tensor: ) # (N, H, W, 3) # Concat with the alpha channel. - alpha = torch.ones((N, H, W, 1), dtype=colors.dtype, device=device) + alpha = (~is_background).type_as(pixel_colors)[..., None] + return torch.cat([pixel_colors, alpha], dim=-1) # (N, H, W, 4) diff --git a/tests/test_blending.py b/tests/test_blending.py index f2fa2348e..8a2f90b39 100644 --- a/tests/test_blending.py +++ b/tests/test_blending.py @@ -184,8 +184,8 @@ def test_hard_rgb_blend(self): channel_color = blend_params.background_color[i] self.assertTrue(images[~is_foreground][..., i].eq(channel_color).all()) - # Examine the alpha channel is correct - self.assertTrue(images[..., 3].eq(1).all()) + # Examine the alpha channel + self.assertClose(images[..., 3], (pix_to_face[..., 0] >= 0).float()) def test_sigmoid_alpha_blend_manual_gradients(self): # Create dummy outputs of rasterization diff --git a/tests/test_render_meshes.py b/tests/test_render_meshes.py index e2341cc86..64bbe7fdb 100644 --- a/tests/test_render_meshes.py +++ b/tests/test_render_meshes.py @@ -125,6 +125,10 @@ def test_simple_sphere(self, elevated_camera=False, check_depth=False): ) images, fragments = renderer(sphere_mesh) self.assertClose(fragments.zbuf, rasterizer(sphere_mesh).zbuf) + # Check the alpha channel is the mask + self.assertClose( + images[..., -1], (fragments.pix_to_face[..., 0] >= 0).float() + ) else: renderer = MeshRenderer(rasterizer=rasterizer, shader=shader) images = renderer(sphere_mesh) @@ -165,6 +169,10 @@ def test_simple_sphere(self, elevated_camera=False, check_depth=False): self.assertClose( fragments.zbuf, rasterizer(sphere_mesh, lights=lights).zbuf ) + # Check the alpha channel is the mask + self.assertClose( + images[..., -1], (fragments.pix_to_face[..., 0] >= 0).float() + ) else: phong_renderer = MeshRenderer( rasterizer=rasterizer, shader=phong_shader