Skip to content

Commit

Permalink
Fix prbvolpath NEE throughput
Browse files Browse the repository at this point in the history
This commit moves the phase function sampling after NEE. Currently, the path throughput used in NEE contains the phase function sampling weight which is wrong. We never noticed this issue beforehand as most phase functions return a sampling weight of one.

Thanks to elite-sheep for tracking down this bug.
  • Loading branch information
Sebastien Speierer authored and njroussel committed Aug 18, 2023
1 parent a456bed commit 91b0b7e
Showing 1 changed file with 42 additions and 27 deletions.
69 changes: 42 additions & 27 deletions src/python/python/ad/integrators/prbvolpath.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,35 +120,40 @@ def sample(self,
last_scatter_event = dr.zeros(mi.Interaction3f)
last_scatter_direction_pdf = mi.Float(1.0)

# TODO: Support sensors inside media
# medium = mi.MediumPtr(medium)
# TODO: support sensors inside media
medium = dr.zeros(mi.MediumPtr)

channel = 0
depth = mi.UInt32(0)
valid_ray = mi.Bool(False)
specular_chain = mi.Bool(True)

if mi.is_rgb: # Sample a color channel to sample free-flight distances
if mi.is_rgb:
# Sample a color channel to sample free-flight distances
n_channels = dr.size_v(mi.Spectrum)
channel = mi.UInt32(dr.minimum(n_channels * sampler.next_1d(active), n_channels - 1))

loop = mi.Loop(name=f"Path Replay Backpropagation ({mode.name})",
state=lambda: (sampler, active, depth, ray, medium, si,
throughput, L, needs_intersection,
last_scatter_event, specular_chain, η,
last_scatter_direction_pdf, valid_ray))
state=lambda: (sampler, active, depth, ray, medium, si,
throughput, L, needs_intersection,
last_scatter_event, specular_chain, η,
last_scatter_direction_pdf, valid_ray))
while loop(active):
active &= dr.any(dr.neq(throughput, 0.0))

#--------------------- Perform russian roulette --------------------

q = dr.minimum(dr.max(throughput) * dr.sqr(η), 0.99)
perform_rr = (depth > self.rr_depth)
active &= (sampler.next_1d(active) < q) | ~perform_rr
throughput[perform_rr] = throughput * dr.rcp(q)

active_medium = active & dr.neq(medium, None) # TODO this is not necessary
active_medium = active & dr.neq(medium, None)
active_surface = active & ~active_medium

with dr.resume_grad(when=not is_primal):
#--------------------- Sample medium interaction -------------------

# Handle medium sampling and potential medium escape
u = sampler.next_1d(active_medium)
mei = medium.sample_interaction(ray, u, channel, active_medium)
Expand Down Expand Up @@ -202,22 +207,14 @@ def sample(self,
phase = mei.medium.phase_function()
phase[~act_medium_scatter] = dr.zeros(mi.PhaseFunctionPtr)

valid_ray |= act_medium_scatter
with dr.suspend_grad():
wo, phase_weight, phase_pdf = phase.sample(phase_ctx, mei, sampler.next_1d(act_medium_scatter), sampler.next_2d(act_medium_scatter), act_medium_scatter)
act_medium_scatter &= phase_pdf > 0.0
new_ray = mei.spawn_ray(wo)
ray[act_medium_scatter] = new_ray
needs_intersection |= act_medium_scatter
last_scatter_direction_pdf[act_medium_scatter] = phase_pdf
throughput[act_medium_scatter] *= phase_weight
#--------------------- Surface Interactions --------------------

#--------------------- Surface Interactions ---------------------
active_surface |= escaped_medium
intersect = active_surface & needs_intersection
si[intersect] = scene.ray_intersect(ray, intersect)

# ---------------- Intersection with emitters ----------------
# ----------------- Intersection with emitters -----------------

ray_from_camera = active_surface & dr.eq(depth, 0)
count_direct = ray_from_camera | specular_chain
emitter = si.emitter(scene)
Expand All @@ -240,7 +237,8 @@ def sample(self,
ctx = mi.BSDFContext()
bsdf = si.bsdf(ray)

# --------------------- Emitter sampling ---------------------
# ---------------------- Emitter sampling ----------------------

if self.use_nee:
active_e_surface = active_surface & mi.has_flag(bsdf.flags(), mi.BSDFFlags.Smooth) & (depth + 1 < self.max_depth)
sample_emitters = mei.medium.use_emitter_sampling()
Expand Down Expand Up @@ -271,10 +269,27 @@ def sample(self,
if dr.grad_enabled(nee_weight) or dr.grad_enabled(emitted):
dr.backward(δL * contrib)

# ----------------------- BSDF sampling ----------------------
#-------------------- Phase function sampling ------------------

valid_ray |= act_medium_scatter
with dr.suspend_grad():
bs, bsdf_weight = bsdf.sample(ctx, si, sampler.next_1d(active_surface),
sampler.next_2d(active_surface), active_surface)
wo, phase_weight, phase_pdf = phase.sample(phase_ctx, mei,
sampler.next_1d(act_medium_scatter),
sampler.next_2d(act_medium_scatter),
act_medium_scatter)
act_medium_scatter &= phase_pdf > 0.0
ray[act_medium_scatter] = mei.spawn_ray(wo)
needs_intersection |= act_medium_scatter
last_scatter_direction_pdf[act_medium_scatter] = phase_pdf
throughput[act_medium_scatter] *= phase_weight

# ------------------------ BSDF sampling -----------------------

with dr.suspend_grad():
bs, bsdf_weight = bsdf.sample(ctx, si,
sampler.next_1d(active_surface),
sampler.next_2d(active_surface),
active_surface)
active_surface &= bs.pdf > 0

bsdf_eval = bsdf.eval(ctx, si, bs.wo, active_surface)
Expand Down Expand Up @@ -310,7 +325,6 @@ def sample(self,

def sample_emitter(self, mei, si, active_medium, active_surface, scene, sampler, medium, channel,
active, adj_emitted=None, δL=None, mode=None):

is_primal = mode == dr.ADMode.Primal

active = mi.Bool(active)
Expand All @@ -319,7 +333,9 @@ def sample_emitter(self, mei, si, active_medium, active_surface, scene, sampler,
ref_interaction[active_medium] = mei
ref_interaction[active_surface] = si

ds, emitter_val = scene.sample_emitter_direction(ref_interaction, sampler.next_2d(active), False, active)
ds, emitter_val = scene.sample_emitter_direction(ref_interaction,
sampler.next_2d(active),
False, active)
ds = dr.detach(ds)
invalid = dr.eq(ds.pdf, 0.0)
emitter_val[invalid] = 0.0
Expand Down Expand Up @@ -387,8 +403,7 @@ def sample_emitter(self, mei, si, active_medium, active_surface, scene, sampler,
transmittance *= dr.detach(tr_multiplier)

# Update the ray with new origin & t parameter
new_ray = si.spawn_ray(mi.Vector3f(ray.d))
ray[active_surface] = dr.detach(new_ray)
ray[active_surface] = dr.detach(si.spawn_ray(mi.Vector3f(ray.d)))
ray.maxt = dr.detach(remaining_dist)
needs_intersection |= active_surface

Expand Down

0 comments on commit 91b0b7e

Please sign in to comment.