Skip to content

Commit

Permalink
Make instances differentiable
Browse files Browse the repository at this point in the history
  • Loading branch information
Speierers committed Nov 11, 2022
1 parent 29bcc75 commit 54d2d3a
Showing 1 changed file with 67 additions and 16 deletions.
83 changes: 67 additions & 16 deletions src/shapes/instance.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ details on how to create instances, refer to the :ref:`shape-shapegroup` plugin.
template <typename Float, typename Spectrum>
class Instance final: public Shape<Float, Spectrum> {
public:
MI_IMPORT_BASE(Shape, m_id, m_to_world, m_to_object)
MI_IMPORT_BASE(Shape, m_id, m_to_world, m_to_object, mark_dirty)
MI_IMPORT_TYPES(BSDF)

using typename Base::ScalarSize;
Expand All @@ -77,14 +77,15 @@ class Instance final: public Shape<Float, Spectrum> {
}

void traverse(TraversalCallback *callback) override {
callback->put_parameter("to_world", *m_to_world.ptr(), +ParamFlags::NonDifferentiable);
Base::traverse(callback);
callback->put_parameter("to_world", *m_to_world.ptr(), +ParamFlags::Differentiable | ParamFlags::Discontinuous);
}

void parameters_changed(const std::vector<std::string> &keys) override {
if (keys.empty() || string::contains(keys, "to_world")) {
// Update the scalar value of the matrix
m_to_world = m_to_world.value();
m_to_object = m_to_world.value().inverse();
mark_dirty();
}
Base::parameters_changed();
}
Expand Down Expand Up @@ -148,38 +149,84 @@ class Instance final: public Shape<Float, Spectrum> {
Mask active) const override {
MI_MASK_ARGUMENT(active);

const Transform4f& to_world = m_to_world.value();
const Transform4f& to_object = m_to_object.value();

constexpr bool IsDiff = dr::is_diff_v<Float>;
bool grad_enabled = dr::grad_enabled(to_world);

if constexpr (IsDiff) {
if (grad_enabled && m_shapegroup->parameters_grad_enabled())
Throw("Cannot differentiate instance parameters and shapegroup "
"internal parameters at the same time!");
}

// Nested instancing is not supported
if (recursion_depth > 0)
return dr::zeros<SurfaceInteraction3f>();

SurfaceInteraction3f si = m_shapegroup->compute_surface_interaction(
m_to_object.value().transform_affine(ray), pi, ray_flags, recursion_depth, active);
bool detach_shape = has_flag(ray_flags, RayFlags::DetachShape);
bool follow_shape = has_flag(ray_flags, RayFlags::FollowShape);

/* If necessary, temporally suspend gradient tracking for all shape
parameters to construct a surface interaction completely detach from
the shape. */
dr::suspend_grad<Float> scope(detach_shape, to_world, to_object);

SurfaceInteraction3f si;
{
/* Temporally suspend gradient tracking when `to_world` need to be
differentiated as the various terms of `si` will be recomputed
to account for the motion of `si` already. */
dr::suspend_grad<Float> scope(grad_enabled);
si = m_shapegroup->compute_surface_interaction(
to_object.transform_affine(ray), pi, ray_flags,
recursion_depth, active);
}

si.p = m_to_world.value().transform_affine(si.p);
si.n = dr::normalize(m_to_world.value().transform_affine(si.n));
// Hit point `si.p` is only attached to the surface motion
si.p = to_world.transform_affine(si.p);
si.n = dr::normalize(dr::detach(to_world).transform_affine(si.n));
if (likely(has_flag(ray_flags, RayFlags::ShadingFrame)))
si.sh_frame.n = dr::normalize(dr::detach(to_world).transform_affine(si.sh_frame.n));

if constexpr (IsDiff) {
if (follow_shape && grad_enabled) {
/* Recompute si.t in a differential manner as the distance
between the ray origin and the hit point following the moving
surface. */
si.t = dr::sqrt(dr::squared_norm(si.p - ray.o) / dr::squared_norm(ray.d));
} else if (!follow_shape && grad_enabled) {
/* Differential recomputation of the intersection of the ray
with the moving plane tangent to the hit point. In this
scenario, it is important that `si.p` stays along the ray as
the surface moves. */
si.t = (dr::dot(si.n, si.p) - dr::dot(si.n, ray.o)) / dr::dot(si.n, ray.d);
si.p = ray(si.t);
// TODO what can we do about the normals? Take into account curvature?
// TODO si.uv should be attached but we don't know about the underlying parameterization
}
}

if (likely(has_flag(ray_flags, RayFlags::ShadingFrame))) {
si.sh_frame.n = dr::normalize(m_to_world.value().transform_affine(si.sh_frame.n));
if (likely(has_flag(ray_flags, RayFlags::ShadingFrame)))
si.initialize_sh_frame();
}

if (likely(has_flag(ray_flags, RayFlags::dPdUV))) {
si.dp_du = m_to_world.value().transform_affine(si.dp_du);
si.dp_dv = m_to_world.value().transform_affine(si.dp_dv);
si.dp_du = to_world.transform_affine(si.dp_du);
si.dp_dv = to_world.transform_affine(si.dp_dv);
}

if (has_flag(ray_flags, RayFlags::dNGdUV) || has_flag(ray_flags, RayFlags::dNSdUV)) {
Normal3f n = has_flag(ray_flags, RayFlags::dNGdUV) ? si.n : si.sh_frame.n;

// Determine the length of the transformed normal before it was re-normalized
Normal3f tn = m_to_world.value().transform_affine(
dr::normalize(m_to_object.value().transform_affine(n)));
Normal3f tn = to_world.transform_affine(dr::normalize(to_object.transform_affine(n)));
Float inv_len = dr::rcp(dr::norm(tn));
tn *= inv_len;

// Apply transform to dn_du and dn_dv
si.dn_du = m_to_world.value().transform_affine(Normal3f(si.dn_du)) * inv_len;
si.dn_dv = m_to_world.value().transform_affine(Normal3f(si.dn_dv)) * inv_len;
si.dn_du = to_world.transform_affine(Normal3f(si.dn_du)) * inv_len;
si.dn_dv = to_world.transform_affine(Normal3f(si.dn_dv)) * inv_len;

si.dn_du -= tn * dr::dot(tn, si.dn_du);
si.dn_dv -= tn * dr::dot(tn, si.dn_dv);
Expand Down Expand Up @@ -233,6 +280,10 @@ class Instance final: public Shape<Float, Spectrum> {
}
#endif

bool parameters_grad_enabled() const override {
return dr::grad_enabled(m_to_world) | m_shapegroup->parameters_grad_enabled();
}

MI_DECLARE_CLASS()
private:
ref<ShapeGroup_> m_shapegroup;
Expand Down

0 comments on commit 54d2d3a

Please sign in to comment.