Skip to content

Commit

Permalink
WIP make Perspective sensor parameters differential
Browse files Browse the repository at this point in the history
  • Loading branch information
Speierers committed Oct 20, 2022
1 parent d3a7580 commit ef9f559
Show file tree
Hide file tree
Showing 7 changed files with 44 additions and 29 deletions.
42 changes: 24 additions & 18 deletions src/python/python/ad/integrators/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,9 +294,9 @@ def sample_rays(
scene: mi.Scene,
sensor: mi.Sensor,
sampler: mi.Sampler,
reparam: Callable[[mi.Ray3f, mi.Bool],
Tuple[mi.Ray3f, mi.Float]] = None
) -> Tuple[mi.RayDifferential3f, mi.Spectrum, mi.Vector2f]:
reparam: Callable[[mi.Ray3f, mi.UInt32, mi.Bool],
Tuple[mi.Vector3f, mi.Float]] = None
) -> Tuple[mi.RayDifferential3f, mi.Spectrum, mi.Vector2f, mi.Float]:
"""
Sample a 2D grid of primary rays for a given sensor
Expand Down Expand Up @@ -363,12 +363,13 @@ def sample_rays(
if mi.is_spectral:
wavelength_sample = sampler.next_1d()

ray, weight = sensor.sample_ray_differential(
time=time,
sample1=wavelength_sample,
sample2=pos_adjusted,
sample3=aperture_sample
)
with dr.resume_grad():
ray, weight = sensor.sample_ray_differential(
time=time,
sample1=wavelength_sample,
sample2=pos_adjusted,
sample3=aperture_sample
)

reparam_det = 1.0

Expand Down Expand Up @@ -396,10 +397,16 @@ def sample_rays(

with dr.resume_grad():
# Reparameterize the camera ray
reparam_d, reparam_det = reparam(ray=ray, depth=mi.UInt32(0))
reparam_d, reparam_det = reparam(ray=dr.detach(ray),
depth=mi.UInt32(0))

# TODO better understand why this is necessary
# Reparameterize the camera ray to handle camera translations
if dr.grad_enabled(ray.o):
reparam_d, _ = reparam(ray=ray, depth=mi.UInt32(0))

# Create a fake interaction along the sampled ray and use it to the
# position with derivative tracking
# Create a fake interaction along the sampled ray and use it to
# recompute the position with derivative tracking
it = dr.zeros(mi.Interaction3f)
it.p = ray.o + reparam_d
ds, _ = sensor.sample_direction(it, aperture_sample)
Expand Down Expand Up @@ -484,10 +491,9 @@ def sample(self,
δL: Optional[mi.Spectrum],
state_in: Any,
reparam: Optional[
Callable[[mi.Ray3f, mi.Bool],
Tuple[mi.Ray3f, mi.Float]]],
active: mi.Bool) -> Tuple[mi.Spectrum,
mi.Bool]:
Callable[[mi.Ray3f, mi.UInt32, mi.Bool],
Tuple[mi.Vector3f, mi.Float]]],
active: mi.Bool) -> Tuple[mi.Spectrum, mi.Bool]:
"""
This function does the main work of differentiable rendering and
remains unimplemented here. It is provided by subclasses of the
Expand Down Expand Up @@ -1168,8 +1174,8 @@ def __init__(self,
params: Any,
reparam: Callable[
[mi.Scene, mi.PCG32, Any,
mi.Ray3f, mi.Bool],
Tuple[mi.Ray3f, mi.Float]],
mi.Ray3f, mi.UInt32, mi.Bool],
Tuple[mi.Vector3f, mi.Float]],
wavefront_size : int,
seed : int):

Expand Down
2 changes: 1 addition & 1 deletion src/python/python/ad/integrators/direct_reparam.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def sample(self,
bsdf_ctx = mi.BSDFContext()
L = mi.Spectrum(0)

ray_reparam = mi.Ray3f(ray)
ray_reparam = mi.Ray3f(dr.detach(ray))
if not primal:
# Camera ray reparameterization determinant multiplied in ADIntegrator.sample_rays()
ray_reparam.d, _ = reparam(ray, depth=0, active=active)
Expand Down
2 changes: 1 addition & 1 deletion src/python/python/ad/integrators/prb.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def sample(self,
# --------------------- Configure loop state ----------------------

# Copy input arguments to avoid mutating the caller's state
ray = mi.Ray3f(ray)
ray = mi.Ray3f(dr.detach(ray))
depth = mi.UInt32(0) # Depth of current vertex
L = mi.Spectrum(0 if primal else state_in) # Radiance accumulator
δL = mi.Spectrum(δL if δL is not None else 0) # Differential/adjoint radiance
Expand Down
2 changes: 1 addition & 1 deletion src/python/python/ad/integrators/prb_reparam.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ def sample(self,
# Initialize loop state variables caching the rays and preliminary
# intersections of the previous (zero-initialized) and current vertex
ray_prev = dr.zeros(mi.Ray3f)
ray_cur = mi.Ray3f(ray)
ray_cur = mi.Ray3f(dr.detach(ray))
pi_prev = dr.zeros(mi.PreliminaryIntersection3f)
pi_cur = scene.ray_intersect_preliminary(ray_cur, coherent=True,
active=active)
Expand Down
21 changes: 15 additions & 6 deletions src/python/python/ad/reparam.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def _sample_warp_field(scene: mi.Scene,
coherent=False)

# Convert into a direction at 'ray.o'. When no surface was intersected,
# copy the original (static) direction
# copy the original direction
hit = si.is_valid()
V_direct = dr.select(hit, dr.normalize(si.p - ray.o), ray.d)

Expand Down Expand Up @@ -168,7 +168,7 @@ def forward(self):
it = mi.UInt32(0)
rng = self.rng
ray_grad_o = self.grad_in('ray').o
ray_frame = mi.Frame3f(self.ray.d)
ray_grad_d = self.grad_in('ray').d

loop = mi.Loop(name="reparameterize_ray(): forward propagation",
state=lambda: (it, Z, dZ, grad_V, grad_div_lhs, rng.state))
Expand All @@ -182,10 +182,12 @@ def forward(self):
ray = mi.Ray3f(self.ray)
dr.enable_grad(ray.o)
dr.set_grad(ray.o, ray_grad_o)
dr.enable_grad(ray.d)
dr.set_grad(ray.d, ray_grad_d)
ray_frame = mi.Frame3f(ray.d)

rng_state_backup = rng.state
sample = mi.Point2f(rng.next_float32(),
rng.next_float32())
sample = mi.Point2f(rng.next_float32(), rng.next_float32())

if self.antithetic:
repeat = dr.eq(it & 1, 0)
Expand Down Expand Up @@ -285,9 +287,10 @@ def backward_symbolic(self):

it = mi.UInt32(0)
ray_grad_o = mi.Point3f(0)
ray_grad_d = mi.Vector3f(0)

loop = mi.Loop(name="reparameterize_ray(): backpropagation",
state=lambda: (it, rng.state, ray_grad_o))
state=lambda: (it, rng.state, ray_grad_o, ray_grad_d))

# Unroll the entire loop in wavefront mode
# loop.set_uniform(True) # TODO can we turn this back on? (see self.active in loop condition)
Expand All @@ -297,6 +300,8 @@ def backward_symbolic(self):
while loop(self.active & (it < self.num_rays)):
ray = mi.Ray3f(self.ray)
dr.enable_grad(ray.o)
dr.enable_grad(ray.d)
ray_frame = mi.Frame3f(ray.d)

rng_state_backup = rng.state

Expand All @@ -319,10 +324,12 @@ def backward_symbolic(self):
dr.enqueue(dr.ADMode.Backward, V_i, div_V_1_i)
dr.traverse(mi.Float, dr.ADMode.Backward, dr.ADFlag.ClearVertices)
ray_grad_o += dr.grad(ray.o)
ray_grad_d += dr.grad(ray.d)
it += 1

ray_grad = dr.detach(dr.zeros(type(self.ray)))
ray_grad.o = ray_grad_o
ray_grad.d = ray_grad_d
self.set_grad_in('ray', ray_grad)


Expand All @@ -334,12 +341,13 @@ def backward_unroll(self):
# Ignore inactive lanes
grad_direction = dr.select(self.active, grad_direction, 0.0)
grad_divergence = dr.select(self.active, grad_divergence, 0.0)
ray_frame = mi.Frame3f(self.ray.d)
rng = self.rng
rng_clone = mi.PCG32(rng)

ray = mi.Ray3f(self.ray)
dr.enable_grad(ray.o)
dr.enable_grad(ray.d)
ray_frame = mi.Frame3f(ray.d)

warp_fields = []
for i in range(self.num_rays):
Expand Down Expand Up @@ -384,6 +392,7 @@ def backward_unroll(self):

ray_grad = dr.detach(dr.zeros(type(self.ray)))
ray_grad.o = dr.grad(ray.o)
ray_grad.d = dr.grad(ray.d)
self.set_grad_in('ray', ray_grad)


Expand Down
2 changes: 1 addition & 1 deletion src/sensors/perspective.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ class PerspectiveCamera final : public ProjectiveCamera<Float, Spectrum> {
void traverse(TraversalCallback *callback) override {
Base::traverse(callback);
callback->put_parameter("x_fov", m_x_fov, +ParamFlags::NonDifferentiable);
callback->put_parameter("to_world", *m_to_world.ptr(), +ParamFlags::NonDifferentiable);
callback->put_parameter("to_world", *m_to_world.ptr(), +ParamFlags::Differentiable);
}

void parameters_changed(const std::vector<std::string> &keys) override {
Expand Down
2 changes: 1 addition & 1 deletion src/sensors/thinlens.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ class ThinLensCamera final : public ProjectiveCamera<Float, Spectrum> {
callback->put_parameter("aperture_radius", m_aperture_radius, +ParamFlags::NonDifferentiable);
callback->put_parameter("focus_distance", m_focus_distance, +ParamFlags::NonDifferentiable);
callback->put_parameter("x_fov", m_x_fov, +ParamFlags::NonDifferentiable);
callback->put_parameter("to_world", *m_to_world.ptr(), +ParamFlags::NonDifferentiable);
callback->put_parameter("to_world", *m_to_world.ptr(), +ParamFlags::Differentiable);
}

void parameters_changed(const std::vector<std::string> &keys) override {
Expand Down

0 comments on commit ef9f559

Please sign in to comment.