diff --git a/pytorch3d/renderer/blending.py b/pytorch3d/renderer/blending.py index cdfb4e60c..bfdae6c9c 100644 --- a/pytorch3d/renderer/blending.py +++ b/pytorch3d/renderer/blending.py @@ -9,6 +9,7 @@ import torch from pytorch3d import _C +from pytorch3d.common.datatypes import Device # Example functions for blending the top K colors per pixel using the outputs @@ -37,6 +38,17 @@ class BlendParams(NamedTuple): background_color: Union[torch.Tensor, Sequence[float]] = (1.0, 1.0, 1.0) +def _get_background_color( + blend_params: BlendParams, device: Device, dtype=torch.float32 +) -> torch.Tensor: + background_color_ = blend_params.background_color + if isinstance(background_color_, torch.Tensor): + background_color = background_color_.to(device) + else: + background_color = torch.tensor(background_color_, dtype=dtype, device=device) + return background_color + + def hard_rgb_blend( colors: torch.Tensor, fragments, blend_params: BlendParams ) -> torch.Tensor: @@ -57,18 +69,11 @@ def hard_rgb_blend( Returns: RGBA pixel_colors: (N, H, W, 4) """ - N, H, W, K = fragments.pix_to_face.shape - device = fragments.pix_to_face.device + background_color = _get_background_color(blend_params, fragments.pix_to_face.device) # Mask for the background. is_background = fragments.pix_to_face[..., 0] < 0 # (N, H, W) - background_color_ = blend_params.background_color - if isinstance(background_color_, torch.Tensor): - background_color = background_color_.to(device) - else: - background_color = colors.new_tensor(background_color_) - # Find out how much background_color needs to be expanded to be used for masked_scatter. num_background_pixels = is_background.sum() @@ -182,13 +187,8 @@ def softmax_rgb_blend( """ N, H, W, K = fragments.pix_to_face.shape - device = fragments.pix_to_face.device pixel_colors = torch.ones((N, H, W, 4), dtype=colors.dtype, device=colors.device) - background_ = blend_params.background_color - if not isinstance(background_, torch.Tensor): - background = torch.tensor(background_, dtype=torch.float32, device=device) - else: - background = background_.to(device) + background_color = _get_background_color(blend_params, fragments.pix_to_face.device) # Weight for background color eps = 1e-10 @@ -233,7 +233,7 @@ def softmax_rgb_blend( # Sum: weights * textures + background color weighted_colors = (weights_num[..., None] * colors).sum(dim=-2) - weighted_background = delta * background + weighted_background = delta * background_color pixel_colors[..., :3] = (weighted_colors + weighted_background) / denom pixel_colors[..., 3] = 1.0 - alpha