Skip to content

Commit

Permalink
Implemented a new bisection algorithm
Browse files Browse the repository at this point in the history
  • Loading branch information
schunkes authored and Speierers committed Sep 26, 2022
1 parent 69b7289 commit 7ca09a3
Showing 1 changed file with 42 additions and 18 deletions.
60 changes: 42 additions & 18 deletions include/mitsuba/core/math.h
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,42 @@ MI_INLINE Index find_interval(dr::scalar_t<Index> size,
return dr::binary_search<Index>(1, size - 1, pred) - 1;
}

/**
* \brief This function computes a suitable middle point for use in the \ref bisect() function
*
* To mitigate the issue of varying density of floating point numbers on the
* number line, the floats are reinterpreted as unsigned integers. As long as
* sign of both numbers is the same, this maps the floats to the evenly spaced
* set of integers. The middle of these integers ensures that the space of
* numbers is halved on each iteration of the bisection.
*
* Note that this strategy does not work if the numbers have different sign.
* In this case the function always returns 0.0 as the middle.
*/
template <typename Scalar>
Scalar middle(Scalar left, Scalar right) {
using ScalarUInt = dr::uint_array_t<Scalar>;

// Propagate invalid values (infinities, NaN) back to the caller
if (!dr::isfinite(left) || !dr::isfinite(right))
return left + right;

// always return zero if left and right have different signs
if (dr::sign(left) != dr::sign(right) && left != Scalar(0.0) && right != Scalar(0.0))
return Scalar(0.0);

// we reinterpret as UInt using the absolute value, so we store the sign
// to reapply after interpreting the result back to Float
bool negate = left < Scalar(0.0) || right < Scalar(0.0);
left = dr::abs(left);
right = dr::abs(right);
ScalarUInt left_int = dr::reinterpret_array<ScalarUInt>(left);
ScalarUInt right_int = dr::reinterpret_array<ScalarUInt>(right);
ScalarUInt mid_int = (left_int + right_int) >> 1;
Scalar mid = dr::reinterpret_array<Scalar>(mid_int);
return negate ? -mid : mid;
}

/**
* \brief Bisect a floating point interval given a predicate function
*
Expand All @@ -252,27 +288,15 @@ MI_INLINE Index find_interval(dr::scalar_t<Index> size,

template <typename Scalar, typename Predicate>
Scalar bisect(Scalar left, Scalar right, const Predicate &pred) {
int it = 0;
while (true) {
Scalar middle = (left + right) * Scalar(0.5);
Scalar mid = middle(left, right);

/* Paranoid stopping criterion */
if (middle <= left || middle >= right) {
middle = dr::next_float(left);

if (middle == right)
break;
}

if (pred(middle))
left = middle;
while (left < mid && mid < right) {
if (pred(mid))
left = mid;
else
right = middle;
it++;
if (it > (dr::is_floating_point_v<Scalar> ? 100 : 150))
throw std::runtime_error("Internal error in util::bisect!");
right = mid;
mid = middle(left, right);
}

return left;
}

Expand Down

0 comments on commit 7ca09a3

Please sign in to comment.