diff --git a/include/mitsuba/core/distr_1d.h b/include/mitsuba/core/distr_1d.h index 5772ca005..57a8c1447 100644 --- a/include/mitsuba/core/distr_1d.h +++ b/include/mitsuba/core/distr_1d.h @@ -21,6 +21,7 @@ template struct DiscreteDistribution { using Float = std::conditional_t, dr::value_t, Value>; using FloatStorage = DynamicBuffer; + using UInt32 = dr::uint32_array_t; using Index = dr::uint32_array_t; using Mask = dr::mask_t; @@ -46,17 +47,15 @@ template struct DiscreteDistribution { /// Initialize from a given floating point array DiscreteDistribution(const ScalarFloat *values, size_t size) : m_pmf(dr::load(values, size)) { - compute_cdf(values, size); + compute_cdf_scalar(values, size); } /// Update the internal state. Must be invoked when changing the pmf. void update() { if constexpr (dr::is_jit_v) { - FloatStorage temp = dr::migrate(m_pmf, AllocType::Host); - dr::sync_thread(); - compute_cdf(temp.data(), temp.size()); + compute_cdf(); } else { - compute_cdf(m_pmf.data(), m_pmf.size()); + compute_cdf_scalar(m_pmf.data(), m_pmf.size()); } } @@ -108,21 +107,28 @@ template struct DiscreteDistribution { * \brief %Transform a uniformly distributed sample to the stored * distribution * - * \param value + * \param sample * A uniformly distributed sample on the interval [0, 1]. * * \return * The discrete index associated with the sample */ - Index sample(Value value, Mask active = true) const { + Index sample(Value sample, Mask active = true) const { MI_MASK_ARGUMENT(active); - value *= m_sum; + sample *= m_sum; return dr::binary_search( m_valid.x(), m_valid.y(), [&](Index index) DRJIT_INLINE_LAMBDA { - return dr::gather(m_cdf, index, active) < value; + Value value = dr::gather(m_cdf, index, active); + if constexpr (!dr::is_jit_v) { + return value < sample; + } else { + // `m_valid` is not computed in JIT variants + return ((value < sample) || (dr::eq(value, 0))) && + dr::neq(value, m_sum); + } } ); } @@ -209,7 +215,21 @@ template struct DiscreteDistribution { } private: - void compute_cdf(const ScalarFloat *pmf, size_t size) { + void compute_cdf() { + if (m_pmf.empty()) + Throw("DiscreteDistribution: empty distribution!"); + if (!dr::all(m_pmf >= 0.f)) + Throw("DiscreteDistribution: entries must be non-negative!"); + if (!dr::any(m_pmf > 0.f)) + Throw("DiscreteDistribution: no probability mass found!"); + + m_cdf = dr::prefix_sum(m_pmf, false); + m_valid = ScalarVector2u(0, m_pmf.size() - 1); + m_sum = dr::gather(m_cdf, UInt32(m_pmf.size() - 1)); + m_normalization = 1.0 / m_sum; + } + + void compute_cdf_scalar(const ScalarFloat *pmf, size_t size) { if (size == 0) Throw("DiscreteDistribution: empty distribution!"); @@ -225,7 +245,7 @@ template struct DiscreteDistribution { if (value < 0.0) { Throw("DiscreteDistribution: entries must be non-negative!"); } else if (value > 0.0) { - // Determine the first and last wavelength bin with nonzero density + // Determine the first and last bin with nonzero density if (m_valid.x() == (uint32_t) -1) m_valid.x() = i; m_valid.y() = i;