Skip to content

Commit

Permalink
Make sphere shape differentiable
Browse files Browse the repository at this point in the history
  • Loading branch information
Speierers committed Nov 11, 2022
1 parent 685d0ea commit f5dbede
Show file tree
Hide file tree
Showing 2 changed files with 182 additions and 47 deletions.
143 changes: 100 additions & 43 deletions src/shapes/sphere.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -142,24 +142,38 @@ class Sphere final : public Shape<Float, Spectrum> {
if (!(dr::abs(S[0][0] - S[1][1]) < 1e-6f && dr::abs(S[0][0] - S[2][2]) < 1e-6f))
Log(Warn, "'to_world' transform shouldn't contain non-uniform scaling!");

m_radius = S[0][0];
m_center = ScalarPoint3f(T);
m_radius = dr::norm(m_to_world.value().transform_affine(Vector3f(1.f, 0.f, 0.f)));
m_center = m_to_world.value().transform_affine(Point3f(0.f));

if (m_radius.scalar() <= 0.f) {
m_radius = dr::abs(m_radius.scalar());
m_radius = dr::abs(m_radius.value());
m_flip_normals = !m_flip_normals;
}

// Reconstruct the to_world transform with uniform scaling and no shear
m_to_world = dr::transform_compose<ScalarMatrix4f>(ScalarMatrix3f(m_radius.scalar()), Q, T);
m_to_object = m_to_world.scalar().inverse();
m_to_object = m_to_world.value().inverse();

m_inv_surface_area = dr::rcp(surface_area());

dr::make_opaque(m_radius, m_center, m_inv_surface_area);
mark_dirty();
}

void traverse(TraversalCallback *callback) override {
Base::traverse(callback);
callback->put_parameter("to_world", *m_to_world.ptr(), +ParamFlags::Differentiable);
}

void parameters_changed(const std::vector<std::string> &keys) override {
if (keys.empty() || string::contains(keys, "to_world")) {
// Update the scalar value of the matrix
m_to_world = m_to_world.value();
update();
}

Base::parameters_changed(keys);
}

ScalarBoundingBox3f bbox() const override {
ScalarBoundingBox3f bbox;
bbox.min = m_center.scalar() - m_radius.scalar();
Expand Down Expand Up @@ -293,6 +307,24 @@ class Sphere final : public Shape<Float, Spectrum> {
);
}

SurfaceInteraction3f eval_parameterization(const Point2f &uv,
uint32_t ray_flags,
Mask active) const override {
Point3f local = warp::square_to_uniform_sphere(uv);
Point3f p = dr::fmadd(local, m_radius.value(), m_center.value());

Ray3f ray(p + local, -local, 0, Wavelength(0));

PreliminaryIntersection3f pi = ray_intersect_preliminary(ray, active);

if (dr::none_or<false>(pi.is_valid()))
return dr::zeros<SurfaceInteraction3f>();

pi.shape = this;

return pi.compute_surface_interaction(ray, ray_flags, active);
}

//! @}
// =============================================================

Expand All @@ -306,12 +338,12 @@ class Sphere final : public Shape<Float, Spectrum> {
ray_intersect_preliminary_impl(const Ray3fP &ray,
dr::mask_t<FloatP> active) const {
MI_MASK_ARGUMENT(active);

using Value = std::conditional_t<dr::is_cuda_v<FloatP> ||
dr::is_diff_v<Float>,
dr::float32_array_t<FloatP>,
dr::float64_array_t<FloatP>>;
dr::is_diff_v<Float>,
dr::float32_array_t<FloatP>,
dr::float64_array_t<FloatP>>;
using Value3 = Vector<Value, 3>;

using ScalarValue = dr::scalar_t<Value>;
using ScalarValue3 = Vector<ScalarValue, 3>;

Expand Down Expand Up @@ -402,35 +434,70 @@ class Sphere final : public Shape<Float, Spectrum> {
uint32_t recursion_depth,
Mask active) const override {
MI_MASK_ARGUMENT(active);
constexpr bool IsDiff = dr::is_diff_v<Float>;

// Early exit when tracing isn't necessary
if (!m_is_instance && recursion_depth > 0)
return dr::zeros<SurfaceInteraction3f>();

// Recompute ray intersection to get differentiable t
Float t = pi.t;
if constexpr (dr::is_diff_v<Float>)
t = dr::replace_grad(t, ray_intersect_preliminary(ray, active).t);

// TODO handle RayFlags::FollowShape and RayFlags::DetachShape

// Fields requirement dependencies
bool need_dn_duv = has_flag(ray_flags, RayFlags::dNSdUV) ||
has_flag(ray_flags, RayFlags::dNGdUV);
bool need_dp_duv = has_flag(ray_flags, RayFlags::dPdUV) || need_dn_duv;
bool need_uv = has_flag(ray_flags, RayFlags::UV) || need_dp_duv;
bool need_dn_duv = has_flag(ray_flags, RayFlags::dNSdUV) ||
has_flag(ray_flags, RayFlags::dNGdUV);
bool need_dp_duv = has_flag(ray_flags, RayFlags::dPdUV) || need_dn_duv;
bool need_uv = has_flag(ray_flags, RayFlags::UV) || need_dp_duv;
bool detach_shape = has_flag(ray_flags, RayFlags::DetachShape);
bool follow_shape = has_flag(ray_flags, RayFlags::FollowShape);

const Point3f& center = m_center.value();
const Float& radius = m_radius.value();
const Transform4f& to_world = m_to_world.value();
const Transform4f& to_object = m_to_object.value();

/* If necessary, temporally suspend gradient tracking for all shape
parameters to construct a surface interaction completely detach from
the shape. */
dr::suspend_grad<Float> scope(detach_shape, center, radius, to_world, to_object);

SurfaceInteraction3f si = dr::zeros<SurfaceInteraction3f>();
si.t = dr::select(active, t, dr::Infinity<Float>);

si.sh_frame.n = dr::normalize(ray(t) - m_center.value());
Point3f local;
if constexpr (IsDiff) {
if (follow_shape) {
/* FollowShape glues the interaction point with the sphere,
therefore to also needs to account for a possible differential
rotation in to_world. we first compute a detached intersection
point in local space and transform it back in world space to
get a point rigidly attached to the sphere attached point,
including translation, scaling and rotation */
Normal3f n = dr::normalize(ray(pi.t) - center);
local = to_object.transform_affine(dr::fmadd(n, radius, center));
/* With FollowShape the local position should always be static as
the intersection point follows any motion of the sphere. */
local = dr::detach(local);
si.p = to_world.transform_affine(local);
si.sh_frame.n = (si.p - center) / radius;
si.t = dr::sqrt(dr::squared_norm(si.p - ray.o) / dr::squared_norm(ray.d));
} else {
/* To ensure that the differential interaction point stays along
the traced ray, we first recompute the intersection distance
in a differentiable way (w.r.t. to the sphere parameters) and
then compute the corresponding point along the ray. */
si.t = dr::replace_grad(pi.t, ray_intersect_preliminary(ray, active).t);
si.p = ray(si.t);
si.sh_frame.n = dr::normalize(si.p - center);
local = to_object.transform_affine(si.p);
}
} else {
si.t = pi.t;
si.sh_frame.n = dr::normalize(ray(si.t) - center);
// Re-project onto the sphere to improve accuracy
si.p = dr::fmadd(si.sh_frame.n, radius, center);
local = to_object.transform_affine(si.p);
}

// Re-project onto the sphere to improve accuracy
si.p = dr::fmadd(si.sh_frame.n, m_radius.value(), m_center.value());
si.t = dr::select(active, si.t, dr::Infinity<Float>);

if (likely(need_uv)) {
Vector3f local = m_to_object.value().transform_affine(si.p);

Float rd_2 = dr::sqr(local.x()) + dr::sqr(local.y()),
theta = unit_angle_z(local),
phi = dr::atan2(local.y(), local.x());
Expand All @@ -454,8 +521,8 @@ class Sphere final : public Shape<Float, Spectrum> {
if (unlikely(dr::any_or<true>(singularity_mask)))
si.dp_dv[singularity_mask] = Vector3f(1.f, 0.f, 0.f);

si.dp_du = m_to_world.value() * si.dp_du * (2.f * dr::Pi<Float>);
si.dp_dv = m_to_world.value() * si.dp_dv * dr::Pi<Float>;
si.dp_du = to_world * si.dp_du * (2.f * dr::Pi<Float>);
si.dp_dv = to_world * si.dp_dv * dr::Pi<Float>;
}
}

Expand All @@ -466,7 +533,7 @@ class Sphere final : public Shape<Float, Spectrum> {

if (need_dn_duv) {
Float inv_radius =
(m_flip_normals ? -1.f : 1.f) * dr::rcp(m_radius.value());
(m_flip_normals ? -1.f : 1.f) * dr::rcp(radius);
si.dn_du = si.dp_du * inv_radius;
si.dn_dv = si.dp_dv * inv_radius;
}
Expand All @@ -480,22 +547,12 @@ class Sphere final : public Shape<Float, Spectrum> {
return si;
}

//! @}
// =============================================================

void traverse(TraversalCallback *callback) override {
callback->put_parameter("to_world", *m_to_world.ptr(), +ParamFlags::NonDifferentiable);
Base::traverse(callback);
bool parameters_grad_enabled() const override {
return dr::grad_enabled(m_radius) || dr::grad_enabled(m_center);
}

void parameters_changed(const std::vector<std::string> &keys) override {
if (keys.empty() || string::contains(keys, "to_world")) {
// Update the scalar value of the matrix
m_to_world = m_to_world.value();
update();
}
Base::parameters_changed();
}
//! @}
// =============================================================

#if defined(MI_ENABLE_CUDA)
using Base::m_optix_data_ptr;
Expand Down
86 changes: 82 additions & 4 deletions src/shapes/tests/test_sphere.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,10 +129,10 @@ def sample_cone(sample, cos_theta_max):
for xi_2 in dr.linspace(Float, 1e-3, 1 - 1e-3, 10):
sample = sphere.sample_direction(it, [xi_2, 1 - xi_1])
d = sample_cone([xi_1, xi_2], cos_cone_angle)
its = sphere.ray_intersect(mi.Ray3f(it.p, d))
si = sphere.ray_intersect(mi.Ray3f(it.p, d))
assert dr.allclose(d, sample.d, atol=1e-5, rtol=1e-5)
assert dr.allclose(its.t, sample.dist, atol=1e-5, rtol=1e-5)
assert dr.allclose(its.p, sample.p, atol=1e-5, rtol=1e-5)
assert dr.allclose(si.t, sample.dist, atol=1e-5, rtol=1e-5)
assert dr.allclose(si.p, sample.p, atol=1e-5, rtol=1e-5)


def test06_differentiable_surface_interaction_ray_forward(variants_all_ad_rgb):
Expand Down Expand Up @@ -217,7 +217,85 @@ def test07_differentiable_surface_interaction_ray_backward(variants_all_ad_rgb):
assert dr.allclose(dr.grad(ray.o), [0, 0, -1])


def test08_si_singularity(variants_all_rgb):
def test08_differentiable_surface_interaction_ray_forward_follow_shape(variants_all_ad_rgb):
shape = mi.load_dict({'type' : 'sphere'})
params = mi.traverse(shape)

# Test 01: When the sphere is inflating, the point hitting the center will
# move back along the ray. The normal isn't changing nor the UVs.

ray = mi.Ray3f(mi.Vector3f(0, 0, -2), mi.Vector3f(0, 0, 1))

theta = mi.Float(0)
dr.enable_grad(theta)
params['to_world'] = mi.Transform4f.scale(1 + theta)
params.update()
si = shape.ray_intersect(ray, mi.RayFlags.All)

dr.forward(theta)

assert dr.allclose(dr.grad(si.t), -1)
assert dr.allclose(dr.grad(si.p), [0, 0, -1])
assert dr.allclose(dr.grad(si.n), 0)
assert dr.allclose(dr.grad(si.uv), 0)

# Test 02: With FollowShape, an intersection point at the pole of a translating
# sphere should move with the pole. The normal and the UVs should be static.

ray = mi.Ray3f(mi.Vector3f(0.0, 0.0, -2.0), mi.Vector3f(0.0, 0.0, 1.0))

theta = mi.Float(0.0)
dr.enable_grad(theta)
params['to_world'] = mi.Transform4f.translate([theta, 0.0, 0.0])
params.update()
si = shape.ray_intersect(ray, mi.RayFlags.All | mi.RayFlags.FollowShape)

dr.forward(theta)

assert dr.allclose(dr.grad(si.p), [1.0, 0.0, 0.0])
assert dr.allclose(dr.grad(si.n), 0.0)
assert dr.allclose(dr.grad(si.uv), 0.0)

# Test 03: With FollowShape, an intersection point and normal at the pole of
# a rotating sphere should follow the rotation speed along the
# tangent direction. The UVs should be static.

ray = mi.Ray3f(mi.Vector3f(0.0, 0.0, -2.0), mi.Vector3f(0.0, 0.0, 1.0))

theta = mi.Float(0.0)
dr.enable_grad(theta)
params['to_world'] = mi.Transform4f.rotate([0, 1, 0], 90 * theta)
params.update()
si = shape.ray_intersect(ray, mi.RayFlags.All | mi.RayFlags.FollowShape)

dr.forward(theta)

assert dr.allclose(dr.grad(si.p), [-dr.pi / 2.0, 0.0, 0.0])
assert dr.allclose(dr.grad(si.n), [-dr.pi / 2.0, 0.0, 0.0])
assert dr.allclose(dr.grad(si.uv), 0.0)

# Test 04: Without FollowShape, a sphere that is only rotating shouldn't
# produce any gradients for the intersection point and normal, but
# for the UVs.

ray = mi.Ray3f(mi.Vector3f(0.0, -2.0, 0.0), mi.Vector3f(0.0, 1.0, 0.0))

theta = mi.Float(0.0)
dr.enable_grad(theta)
params['to_world'] = mi.Transform4f.rotate([1, 0, 0], 90 * theta)
params.update()
si = shape.ray_intersect(ray, mi.RayFlags.All)

dr.forward(theta)

assert dr.allclose(dr.grad(si.p), 0.0)
assert dr.allclose(dr.grad(si.n), 0.0)
assert dr.allclose(dr.grad(si.uv), [0.0, -0.5])




def test09_si_singularity(variants_all_rgb):
scene = mi.load_dict({"type" : "scene", 's': { 'type': 'sphere' }})
ray = mi.Ray3f([0, 0, -1], [0, 0, 1])

Expand Down

0 comments on commit f5dbede

Please sign in to comment.