Skip to content

Commit

Permalink
Make other sensor parameters differentiable
Browse files Browse the repository at this point in the history
  • Loading branch information
Speierers committed Oct 20, 2022
1 parent fad031a commit ea513f7
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 10 deletions.
4 changes: 2 additions & 2 deletions include/mitsuba/render/sensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ class MI_EXPORT_LIB ProjectiveCamera : public Sensor<Float, Spectrum> {
ScalarFloat far_clip() const { return m_far_clip; }

/// Return the distance to the focal plane
ScalarFloat focus_distance() const { return m_focus_distance; }
Float focus_distance() const { return m_focus_distance; }

void traverse(TraversalCallback *callback) override {
callback->put_parameter("near_clip", m_near_clip, +ParamFlags::NonDifferentiable);
Expand All @@ -213,7 +213,7 @@ class MI_EXPORT_LIB ProjectiveCamera : public Sensor<Float, Spectrum> {
protected:
ScalarFloat m_near_clip;
ScalarFloat m_far_clip;
ScalarFloat m_focus_distance;
Float m_focus_distance;
};

// ========================================================================
Expand Down
4 changes: 2 additions & 2 deletions src/sensors/perspective.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -153,8 +153,8 @@ class PerspectiveCamera final : public ProjectiveCamera<Float, Spectrum> {

void traverse(TraversalCallback *callback) override {
Base::traverse(callback);
callback->put_parameter("x_fov", m_x_fov, +ParamFlags::NonDifferentiable);
callback->put_parameter("to_world", *m_to_world.ptr(), +ParamFlags::Differentiable);
callback->put_parameter("x_fov", m_x_fov, ParamFlags::Differentiable | ParamFlags::Discontinuous);
callback->put_parameter("to_world", *m_to_world.ptr(), ParamFlags::Differentiable | ParamFlags::Discontinuous);
}

void parameters_changed(const std::vector<std::string> &keys) override {
Expand Down
12 changes: 6 additions & 6 deletions src/sensors/thinlens.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ class ThinLensCamera final : public ProjectiveCamera<Float, Spectrum> {

m_aperture_radius = props.get<ScalarFloat>("aperture_radius");

if (m_aperture_radius == 0.f) {
if (dr::all(dr::eq(m_aperture_radius, 0.f))) {
Log(Warn, "Can't have a zero aperture radius -- setting to %f", dr::Epsilon<Float>);
m_aperture_radius = dr::Epsilon<Float>;
}
Expand All @@ -170,10 +170,10 @@ class ThinLensCamera final : public ProjectiveCamera<Float, Spectrum> {

void traverse(TraversalCallback *callback) override {
Base::traverse(callback);
callback->put_parameter("aperture_radius", m_aperture_radius, +ParamFlags::NonDifferentiable);
callback->put_parameter("focus_distance", m_focus_distance, +ParamFlags::NonDifferentiable);
callback->put_parameter("x_fov", m_x_fov, +ParamFlags::NonDifferentiable);
callback->put_parameter("to_world", *m_to_world.ptr(), +ParamFlags::Differentiable);
callback->put_parameter("aperture_radius", m_aperture_radius, ParamFlags::Differentiable | ParamFlags::Discontinuous);
callback->put_parameter("focus_distance", m_focus_distance, ParamFlags::Differentiable | ParamFlags::Discontinuous);
callback->put_parameter("x_fov", m_x_fov, ParamFlags::Differentiable | ParamFlags::Discontinuous);
callback->put_parameter("to_world", *m_to_world.ptr(), ParamFlags::Differentiable | ParamFlags::Discontinuous);
}

void parameters_changed(const std::vector<std::string> &keys) override {
Expand Down Expand Up @@ -338,7 +338,7 @@ class ThinLensCamera final : public ProjectiveCamera<Float, Spectrum> {
if (dr::none_or<false>(valid))
return { ds, dr::zeros<Spectrum>() };

ds.uv = dr::head<2>(scr) * m_resolution;
ds.uv = dr::head<2>(scr) * m_resolution;
ds.p = trafo.transform_affine(aperture_p);
ds.d = (ds.p - it.p) * inv_dist;
ds.dist = dist;
Expand Down

0 comments on commit ea513f7

Please sign in to comment.