Skip to content

Commit

Permalink
CDF construction on device for DiscreteDistribution
Browse files Browse the repository at this point in the history
  • Loading branch information
njroussel committed Oct 23, 2023
1 parent 9267f6c commit 825f44f
Showing 1 changed file with 31 additions and 11 deletions.
42 changes: 31 additions & 11 deletions include/mitsuba/core/distr_1d.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ template <typename Value> struct DiscreteDistribution {
using Float = std::conditional_t<dr::is_static_array_v<Value>,
dr::value_t<Value>, Value>;
using FloatStorage = DynamicBuffer<Float>;
using UInt32 = dr::uint32_array_t<Float>;
using Index = dr::uint32_array_t<Value>;
using Mask = dr::mask_t<Value>;

Expand All @@ -46,17 +47,15 @@ template <typename Value> struct DiscreteDistribution {
/// Initialize from a given floating point array
DiscreteDistribution(const ScalarFloat *values, size_t size)
: m_pmf(dr::load<FloatStorage>(values, size)) {
compute_cdf(values, size);
compute_cdf_scalar(values, size);
}

/// Update the internal state. Must be invoked when changing the pmf.
void update() {
if constexpr (dr::is_jit_v<Float>) {
FloatStorage temp = dr::migrate(m_pmf, AllocType::Host);
dr::sync_thread();
compute_cdf(temp.data(), temp.size());
compute_cdf();
} else {
compute_cdf(m_pmf.data(), m_pmf.size());
compute_cdf_scalar(m_pmf.data(), m_pmf.size());
}
}

Expand Down Expand Up @@ -108,21 +107,28 @@ template <typename Value> struct DiscreteDistribution {
* \brief %Transform a uniformly distributed sample to the stored
* distribution
*
* \param value
* \param sample
* A uniformly distributed sample on the interval [0, 1].
*
* \return
* The discrete index associated with the sample
*/
Index sample(Value value, Mask active = true) const {
Index sample(Value sample, Mask active = true) const {
MI_MASK_ARGUMENT(active);

value *= m_sum;
sample *= m_sum;

return dr::binary_search<Index>(
m_valid.x(), m_valid.y(),
[&](Index index) DRJIT_INLINE_LAMBDA {
return dr::gather<Value>(m_cdf, index, active) < value;
Value value = dr::gather<Value>(m_cdf, index, active);
if constexpr (!dr::is_jit_v<Float>) {
return value < sample;
} else {
// `m_valid` is not computed in JIT variants
return ((value < sample) || (dr::eq(value, 0))) &&
dr::neq(value, m_sum);
}
}
);
}
Expand Down Expand Up @@ -209,7 +215,21 @@ template <typename Value> struct DiscreteDistribution {
}

private:
void compute_cdf(const ScalarFloat *pmf, size_t size) {
void compute_cdf() {
if (m_pmf.empty())
Throw("DiscreteDistribution: empty distribution!");
if (!dr::all(m_pmf >= 0.f))
Throw("DiscreteDistribution: entries must be non-negative!");
if (!dr::any(m_pmf > 0.f))
Throw("DiscreteDistribution: no probability mass found!");

m_cdf = dr::prefix_sum(m_pmf, false);
m_valid = ScalarVector2u(0, m_pmf.size() - 1);
m_sum = dr::gather<Float>(m_cdf, UInt32(m_pmf.size() - 1));
m_normalization = 1.0 / m_sum;
}

void compute_cdf_scalar(const ScalarFloat *pmf, size_t size) {
if (size == 0)
Throw("DiscreteDistribution: empty distribution!");

Expand All @@ -225,7 +245,7 @@ template <typename Value> struct DiscreteDistribution {
if (value < 0.0) {
Throw("DiscreteDistribution: entries must be non-negative!");
} else if (value > 0.0) {
// Determine the first and last wavelength bin with nonzero density
// Determine the first and last bin with nonzero density
if (m_valid.x() == (uint32_t) -1)
m_valid.x() = i;
m_valid.y() = i;
Expand Down

0 comments on commit 825f44f

Please sign in to comment.