Skip to content

Commit

Permalink
CDF construction on device for IrregularContinuousDistribution
Browse files Browse the repository at this point in the history
  • Loading branch information
njroussel committed Oct 23, 2023
1 parent b12f365 commit 8f71fe9
Showing 1 changed file with 66 additions and 21 deletions.
87 changes: 66 additions & 21 deletions include/mitsuba/core/distr_1d.h
Original file line number Diff line number Diff line change
Expand Up @@ -611,6 +611,7 @@ template <typename Value> struct IrregularContinuousDistribution {
using Float = std::conditional_t<dr::is_static_array_v<Value>,
dr::value_t<Value>, Value>;
using FloatStorage = DynamicBuffer<Float>;
using UInt32 = dr::uint32_array_t<Float>;
using Index = dr::uint32_array_t<Value>;
using Mask = dr::mask_t<Value>;

Expand Down Expand Up @@ -641,7 +642,7 @@ template <typename Value> struct IrregularContinuousDistribution {
const ScalarFloat *pdf,
size_t size)
: m_nodes(dr::load<FloatStorage>(nodes, size)), m_pdf(dr::load<FloatStorage>(pdf, size)) {
compute_cdf(nodes, pdf, size);
compute_cdf_scalar(nodes, pdf, size);
}

/// Update the internal state. Must be invoked when changing the pdf or range.
Expand All @@ -650,12 +651,9 @@ template <typename Value> struct IrregularContinuousDistribution {
Throw("IrregularContinuousDistribution: 'pdf' and 'nodes' size mismatch!");

if constexpr (dr::is_jit_v<Float>) {
FloatStorage temp_nodes = dr::migrate(m_nodes, AllocType::Host);
FloatStorage temp_pdf = dr::migrate(m_pdf, AllocType::Host);
dr::sync_thread();
compute_cdf(temp_nodes.data(), temp_pdf.data(), temp_nodes.size());
compute_cdf();
} else {
compute_cdf(m_nodes.data(), m_pdf.data(), m_nodes.size());
compute_cdf_scalar(m_nodes.data(), m_pdf.data(), m_nodes.size());
}
}

Expand Down Expand Up @@ -764,21 +762,28 @@ template <typename Value> struct IrregularContinuousDistribution {
* \brief %Transform a uniformly distributed sample to the stored
* distribution
*
* \param value
* \param sample
* A uniformly distributed sample on the interval [0, 1].
*
* \return
* The sampled position.
*/
Value sample(Value value, Mask active = true) const {
Value sample(Value sample, Mask active = true) const {
MI_MASK_ARGUMENT(active);

value *= m_integral;
sample *= m_integral;

Index index = dr::binary_search<Index>(
m_valid.x(), m_valid.y(),
[&](Index index) DRJIT_INLINE_LAMBDA {
return dr::gather<Value>(m_cdf, index, active) < value;
Value value = dr::gather<Value>(m_cdf, index, active);
if constexpr (!dr::is_jit_v<Float>) {
return value < sample;
} else {
// `m_valid` is not computed in JIT variants
return ((value < sample) || (dr::eq(value, 0))) &&
dr::neq(value, m_integral);
}
}
);

Expand All @@ -789,10 +794,10 @@ template <typename Value> struct IrregularContinuousDistribution {
c0 = dr::gather<Value>(m_cdf, index - 1u, active && index > 0),
w = x1 - x0;

value = (value - c0) / w;
sample = (sample - c0) / w;

Value t_linear = (y0 - dr::safe_sqrt(dr::sqr(y0) + 2.f * value * (y1 - y0))) / (y0 - y1),
t_const = value / y0,
Value t_linear = (y0 - dr::safe_sqrt(dr::sqr(y0) + 2.f * sample * (y1 - y0))) / (y0 - y1),
t_const = sample / y0,
t = dr::select(dr::eq(y0, y1), t_const, t_linear);

return dr::fmadd(t, w, x0);
Expand All @@ -802,7 +807,7 @@ template <typename Value> struct IrregularContinuousDistribution {
* \brief %Transform a uniformly distributed sample to the stored
* distribution
*
* \param value
* \param sample
* A uniformly distributed sample on the interval [0, 1].
*
* \return
Expand All @@ -811,15 +816,22 @@ template <typename Value> struct IrregularContinuousDistribution {
* 1. the sampled position.
* 2. the normalized probability density of the sample.
*/
std::pair<Value, Value> sample_pdf(Value value, Mask active = true) const {
std::pair<Value, Value> sample_pdf(Value sample, Mask active = true) const {
MI_MASK_ARGUMENT(active);

value *= m_integral;
sample *= m_integral;

Index index = dr::binary_search<Index>(
m_valid.x(), m_valid.y(),
[&](Index index) DRJIT_INLINE_LAMBDA {
return dr::gather<Value>(m_cdf, index, active) < value;
Value value = dr::gather<Value>(m_cdf, index, active);
if constexpr (!dr::is_jit_v<Float>) {
return value < sample;
} else {
// `m_valid` is not computed in JIT variants
return ((value < sample) || (dr::eq(value, 0))) &&
dr::neq(value, m_integral);
}
}
);

Expand All @@ -830,10 +842,10 @@ template <typename Value> struct IrregularContinuousDistribution {
c0 = dr::gather<Value>(m_cdf, index - 1u, active && index > 0),
w = x1 - x0;

value = (value - c0) / w;
sample = (sample - c0) / w;

Value t_linear = (y0 - dr::safe_sqrt(dr::sqr(y0) + 2.f * value * (y1 - y0))) / (y0 - y1),
t_const = value / y0,
Value t_linear = (y0 - dr::safe_sqrt(dr::sqr(y0) + 2.f * sample * (y1 - y0))) / (y0 - y1),
t_const = sample / y0,
t = dr::select(dr::eq(y0, y1), t_const, t_linear);

return { dr::fmadd(t, w, x0),
Expand All @@ -852,7 +864,40 @@ template <typename Value> struct IrregularContinuousDistribution {
}

private:
void compute_cdf(const ScalarFloat *nodes, const ScalarFloat *pdf, size_t size) {
void compute_cdf() {
if (m_pdf.size() < 2)
Throw("IrregularContinuousDistribution: needs at least two entries!");
if (!dr::all(m_pdf >= 0.f))
Throw("IrregularContinuousDistribution: entries must be non-negative!");
if (!dr::any(m_pdf > 0.f))
Throw("IrregularContinuousDistribution: no probability mass found!");

uint32_t size = m_pdf.size() - 1;
UInt32 index_curr = dr::arange<UInt32>(size);
UInt32 index_next = dr::arange<UInt32>(1, size + 1);

Float nodes_curr = dr::gather<Float>(m_nodes, index_curr);
Float nodes_next = dr::gather<Float>(m_nodes, index_next);

if (dr::any(nodes_next - nodes_curr <= 0))
Throw("IrregularContinuousDistribution: node positions must be strictly increasing!");

Float pdf_curr = dr::gather<Float>(m_pdf, index_curr);
Float pdf_next = dr::gather<Float>(m_pdf, index_next);

Float interval_integral =
0.5 * (nodes_next - nodes_curr) * (pdf_curr + pdf_next);
m_cdf = dr::prefix_sum(interval_integral, false);

m_integral = dr::gather<Float>(m_cdf, UInt32(size - 1));
m_normalization = 1.0 / m_integral;
m_range = ScalarVector2f(dr::slice(m_nodes, 0), dr::slice(m_nodes, size));
m_valid = ScalarVector2u(0, size);
m_interval_size = dr::slice(dr::min(nodes_next - nodes_curr));
m_max = dr::slice(dr::max(m_pdf));
}

void compute_cdf_scalar(const ScalarFloat *nodes, const ScalarFloat *pdf, size_t size) {
if (size < 2)
Throw("IrregularContinuousDistribution: needs at least two entries!");

Expand Down

0 comments on commit 8f71fe9

Please sign in to comment.