Skip to content

Commit

Permalink
Adapt RayPointRefiner and RayMarcher to support bins.
Browse files Browse the repository at this point in the history
Summary:
## Context

Bins are used in mipnerf to allow to manipulate easily intervals. For example, by doing the following, `bins[..., :-1]` you will obtain all the left coordinates of your intervals, while doing `bins[..., 1:]` is equals to the right coordinates of your intervals.

We introduce here the support of bins like in MipNerf implementation.

## RayPointRefiner

Small changes have been made to modify RayPointRefiner.
- If bins is None

```
mids = torch.lerp(ray_bundle.lengths[..., 1:], ray_bundle.lengths[…, :-1], 0.5)
z_samples = sample_pdf(
		mids, # [..., npt]
		weights[..., 1:-1], # [..., npt - 1]
               ….
            )
```

- If bins is not None
In the MipNerf implementation the sampling is done on all the bins. It allows us to use the full weights tensor without slashing it.

```
z_samples = sample_pdf(
		ray_bundle.bins, # [..., npt + 1]
		weights, # [..., npt]
               ...
            )
```

## RayMarcher

Add a ray_deltas optional argument. If None, keep the same deltas computation from ray_lengths.

Reviewed By: shapovalov

Differential Revision: D46389092

fbshipit-source-id: d4f1963310065bd31c1c7fac1adfe11cbeaba606
  • Loading branch information
EmGarr authored and facebook-github-bot committed Jul 6, 2023
1 parent 5910d81 commit 3d011a9
Show file tree
Hide file tree
Showing 5 changed files with 107 additions and 18 deletions.
4 changes: 4 additions & 0 deletions pytorch3d/implicitron/models/renderer/multipass_ea.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,9 +157,13 @@ def _run_raymarcher(
else 0.0
)

ray_deltas = (
None if ray_bundle.bins is None else torch.diff(ray_bundle.bins, dim=-1)
)
output = self.raymarcher(
*implicit_functions[0](ray_bundle=ray_bundle),
ray_lengths=ray_bundle.lengths,
ray_deltas=ray_deltas,
density_noise_std=density_noise_std,
)
output.prev_stage = prev_stage
Expand Down
31 changes: 22 additions & 9 deletions pytorch3d/implicitron/models/renderer/ray_point_refiner.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,29 +78,42 @@ def forward(
"""

z_vals = input_ray_bundle.lengths
with torch.no_grad():
if self.blurpool_weights:
ray_weights = apply_blurpool_on_weights(ray_weights)

z_vals_mid = torch.lerp(z_vals[..., 1:], z_vals[..., :-1], 0.5)
n_pts_per_ray = self.n_pts_per_ray
ray_weights = ray_weights.view(-1, ray_weights.shape[-1])
if input_ray_bundle.bins is None:
z_vals: torch.Tensor = input_ray_bundle.lengths
ray_weights = ray_weights[..., 1:-1]
bins = torch.lerp(z_vals[..., 1:], z_vals[..., :-1], 0.5)
else:
z_vals = input_ray_bundle.bins
n_pts_per_ray += 1
bins = z_vals
z_samples = sample_pdf(
z_vals_mid.view(-1, z_vals_mid.shape[-1]),
ray_weights.view(-1, ray_weights.shape[-1])[..., 1:-1],
self.n_pts_per_ray,
bins.view(-1, bins.shape[-1]),
ray_weights,
n_pts_per_ray,
det=not self.random_sampling,
eps=self.sample_pdf_eps,
).view(*z_vals.shape[:-1], self.n_pts_per_ray)
).view(*z_vals.shape[:-1], n_pts_per_ray)

if self.add_input_samples:
z_vals = torch.cat((z_vals, z_samples), dim=-1)
else:
z_vals = z_samples
# Resort by depth.
z_vals, _ = torch.sort(z_vals, dim=-1)

new_bundle = ImplicitronRayBundle(**vars(input_ray_bundle))
new_bundle.lengths = z_vals
return new_bundle
kwargs_ray = dict(vars(input_ray_bundle))
if input_ray_bundle.bins is None:
kwargs_ray["lengths"] = z_vals
return ImplicitronRayBundle(**kwargs_ray)
kwargs_ray["bins"] = z_vals
del kwargs_ray["lengths"]
return ImplicitronRayBundle.from_bins(**kwargs_ray)


def apply_blurpool_on_weights(weights) -> torch.Tensor:
Expand Down
23 changes: 15 additions & 8 deletions pytorch3d/implicitron/models/renderer/raymarcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Callable, Dict, Tuple
from typing import Any, Callable, Dict, Optional, Tuple

import torch
from pytorch3d.implicitron.models.renderer.base import RendererOutput
Expand Down Expand Up @@ -119,6 +119,7 @@ def forward(
rays_features: torch.Tensor,
aux: Dict[str, Any],
ray_lengths: torch.Tensor,
ray_deltas: Optional[torch.Tensor] = None,
density_noise_std: float = 0.0,
**kwargs,
) -> RendererOutput:
Expand All @@ -131,6 +132,9 @@ def forward(
aux: a dictionary with extra information.
ray_lengths: Per-ray depth values represented with a tensor
of shape `(..., n_points_per_ray, feature_dim)`.
ray_deltas: Optional differences between consecutive elements along the ray bundle
represented with a tensor of shape `(..., n_points_per_ray)`. If None,
these differences are computed from ray_lengths.
density_noise_std: the magnitude of the noise added to densities.
Returns:
Expand All @@ -152,14 +156,17 @@ def forward(
density_1d=True,
)

ray_lengths_diffs = ray_lengths[..., 1:] - ray_lengths[..., :-1]
if self.replicate_last_interval:
last_interval = ray_lengths_diffs[..., -1:]
if ray_deltas is None:
ray_lengths_diffs = torch.diff(ray_lengths, dim=-1)
if self.replicate_last_interval:
last_interval = ray_lengths_diffs[..., -1:]
else:
last_interval = torch.full_like(
ray_lengths[..., :1], self.background_opacity
)
deltas = torch.cat((ray_lengths_diffs, last_interval), dim=-1)
else:
last_interval = torch.full_like(
ray_lengths[..., :1], self.background_opacity
)
deltas = torch.cat((ray_lengths_diffs, last_interval), dim=-1)
deltas = ray_deltas

rays_densities = rays_densities[..., 0]

Expand Down
2 changes: 1 addition & 1 deletion pytorch3d/renderer/implicit/harmonic_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def __init__(
and the integrated position encoding in
`MIP-NeRF <https://arxiv.org/abs/2103.13415>`_.
During, the inference you can provide the extra argument `diag_cov`.
During the inference you can provide the extra argument `diag_cov`.
If `diag_cov is None`, it converts
rays parametrized with a `ray_bundle` to 3D points by
Expand Down
65 changes: 65 additions & 0 deletions tests/implicitron/test_ray_point_refiner.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,71 @@ def test_simple(self):
(lengths_random[..., 1:] - lengths_random[..., :-1] > 0).all()
)

def test_simple_use_bins(self):
"""
Same spirit than test_simple but use bins in the ImplicitronRayBunle.
It has been duplicated to avoid cognitive overload while reading the
test (lot of if else).
"""
length = 15
n_pts_per_ray = 10

for add_input_samples, use_blurpool in product([False, True], [False, True]):
ray_point_refiner = RayPointRefiner(
n_pts_per_ray=n_pts_per_ray,
random_sampling=False,
add_input_samples=add_input_samples,
)

bundle = ImplicitronRayBundle(
lengths=None,
bins=torch.arange(length + 1, dtype=torch.float32).expand(
3, 25, length + 1
),
origins=None,
directions=None,
xys=None,
camera_ids=None,
camera_counts=None,
)
weights = torch.ones(3, 25, length)
refined = ray_point_refiner(bundle, weights, blurpool_weights=use_blurpool)

self.assertIsNone(refined.directions)
self.assertIsNone(refined.origins)
self.assertIsNone(refined.xys)
expected_bins = torch.linspace(0, length, n_pts_per_ray + 1)
expected_bins = expected_bins.expand(3, 25, n_pts_per_ray + 1)
if add_input_samples:
expected_bins = torch.cat((bundle.bins, expected_bins), dim=-1).sort()[
0
]
full_expected = torch.lerp(
expected_bins[..., :-1], expected_bins[..., 1:], 0.5
)

self.assertClose(refined.lengths, full_expected)

ray_point_refiner_random = RayPointRefiner(
n_pts_per_ray=n_pts_per_ray,
random_sampling=True,
add_input_samples=add_input_samples,
)

refined_random = ray_point_refiner_random(
bundle, weights, blurpool_weights=use_blurpool
)
lengths_random = refined_random.lengths
self.assertEqual(lengths_random.shape, full_expected.shape)
if not add_input_samples:
self.assertGreater(lengths_random.min().item(), 0)
self.assertLess(lengths_random.max().item(), length)

# Check sorted
self.assertTrue(
(lengths_random[..., 1:] - lengths_random[..., :-1] > 0).all()
)

def test_apply_blurpool_on_weights(self):
weights = torch.tensor(
[
Expand Down

0 comments on commit 3d011a9

Please sign in to comment.