Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

<random>: Implement Lemire's fast integer generation #3012

Merged
merged 12 commits into from
Sep 22, 2022
6 changes: 4 additions & 2 deletions stl/inc/random
Original file line number Diff line number Diff line change
Expand Up @@ -1844,7 +1844,9 @@ private:

template <class _Engine>
result_type _Eval(_Engine& _Eng, _Ty _Min, _Ty _Max) const { // compute next value in range [_Min, _Max]
_Rng_from_urng<_Uty, _Engine> _Generator(_Eng);
conditional_t<_Has_static_min_max<_Engine>::value, _Rng_from_urng_v2<_Uty, _Engine>,
_Rng_from_urng<_Uty, _Engine>>
_Generator(_Eng);

const _Uty _Umin = _Adjust(static_cast<_Uty>(_Min));
const _Uty _Umax = _Adjust(static_cast<_Uty>(_Max));
Expand All @@ -1862,7 +1864,7 @@ private:

static _Uty _Adjust(_Uty _Uval) { // convert signed ranges to unsigned ranges and vice versa
if constexpr (is_signed_v<_Ty>) {
const _Uty _Adjuster = (static_cast<_Uty>(-1) >> 1) + 1; // 2^(N-1)
constexpr _Uty _Adjuster = (static_cast<_Uty>(-1) >> 1) + 1; // 2^(N-1)

if (_Uval < _Adjuster) {
return static_cast<_Uty>(_Uval + _Adjuster);
Expand Down
118 changes: 114 additions & 4 deletions stl/inc/xutility
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <yvals.h>
#if _STL_COMPILER_PREPROCESSOR

#include <__msvc_int128.hpp>
#include <__msvc_iter_core.hpp>
#include <climits>
#include <cstdlib>
Expand Down Expand Up @@ -6012,7 +6013,8 @@ public:

using _Udiff = conditional_t<sizeof(_Ty1) < sizeof(_Ty0), _Ty0, _Ty1>;

explicit _Rng_from_urng(_Urng& _Func) : _Ref(_Func), _Bits(CHAR_BIT * sizeof(_Udiff)), _Bmask(_Udiff(-1)) {
explicit _Rng_from_urng(_Urng& _Func)
: _Ref(_Func), _Bits(CHAR_BIT * sizeof(_Udiff)), _Bmask(static_cast<_Udiff>(-1)) {
for (; static_cast<_Udiff>((_Urng::max)() - (_Urng::min)()) < _Bmask; _Bmask >>= 1) {
--_Bits;
}
Expand All @@ -6023,7 +6025,7 @@ public:
_Udiff _Ret = 0; // random bits
_Udiff _Mask = 0; // 2^N - 1, _Ret is within [0, _Mask]

while (_Mask < _Udiff(_Index - 1)) { // need more random bits
while (_Mask < static_cast<_Udiff>(_Index - 1)) { // need more random bits
_Ret <<= _Bits - 1; // avoid full shift
_Ret <<= 1;
_Ret |= _Get_bits();
Expand All @@ -6033,7 +6035,7 @@ public:
}

// _Ret is [0, _Mask], _Index - 1 <= _Mask, return if unbiased
if (_Ret / _Index < _Mask / _Index || _Mask % _Index == _Udiff(_Index - 1)) {
if (_Ret / _Index < _Mask / _Index || _Mask % _Index == static_cast<_Udiff>(_Index - 1)) {
return static_cast<_Diff>(_Ret % _Index);
}
}
Expand All @@ -6057,7 +6059,7 @@ public:
private:
_Udiff _Get_bits() { // return a random value within [0, _Bmask]
for (;;) { // repeat until random value is in range
_Udiff _Val = static_cast<_Udiff>(_Ref() - (_Urng::min)());
const _Udiff _Val = static_cast<_Udiff>(_Ref() - (_Urng::min)());

if (_Val <= _Bmask) {
return _Val;
Expand All @@ -6070,6 +6072,114 @@ private:
_Udiff _Bmask; // 2^_Bits - 1
};

template <class _Diff, class _Urng>
class _Rng_from_urng_v2 { // wrap a URNG as an RNG
    // Adapts _Urng so that operator() returns a value in [0, _Index), following the multiply-and-reject scheme of
    // the Lemire paper cited in operator(). _Bits and _Bmask are computed at compile time from (_Urng::min)() and
    // (_Urng::max)(), so both must be usable in constant expressions.
public:
    using _Ty0 = make_unsigned_t<_Diff>;
    using _Ty1 = _Invoke_result_t<_Urng&>;

    // Working type: _Ty0 when the URNG's result type is strictly smaller than it, otherwise the URNG's result type.
    using _Udiff = conditional_t<sizeof(_Ty1) < sizeof(_Ty0), _Ty0, _Ty1>;
    static constexpr unsigned int _Udiff_bits = sizeof(_Udiff) * CHAR_BIT;
    // Type returned by _Get_random_product(): the product of two _Udiff values (presumably twice as wide as _Udiff).
    using _Uprod =
        conditional_t<_Udiff_bits <= 16, uint32_t, conditional_t<_Udiff_bits <= 32, uint64_t, _Unsigned128>>;

    explicit _Rng_from_urng_v2(_Urng& _Func) : _Ref(_Func) {}

    _Diff operator()(_Diff _Index) { // adapt _Urng closed range to [0, _Index)
        // From Daniel Lemire, "Fast Random Integer Generation in an Interval", ACM Trans. Model. Comput. Simul. 29 (1),
        // 2019.
        // NOTE(review): _Index is used as a modulus below, so it is presumably required to be positive.

        // _Mask is a run of low one bits, i.e. 2^L - 1 where L is the number of random bits drawn per attempt;
        // _Niter is the number of _Get_bits() calls needed to produce those L bits.
        _Udiff _Mask = _Bmask;
        unsigned int _Niter = 1;

        if constexpr (_Bits < _Udiff_bits) {
            // Append _Bits more one bits per extra URNG call until the mask covers _Index - 1. Bits shifted past the
            // top of _Udiff are dropped, leaving an all-ones mask.
            while (_Mask < static_cast<_Udiff>(_Index - 1)) {
                _Mask <<= _Bits;
                _Mask |= _Bmask;
                ++_Niter;
            }
        }

        // _Product = (random L-bit value) * _Index; its low L bits are _Rem, its high part is the candidate result.
        auto _Product = _Get_random_product(_Index, _Niter);
        auto _Rem = static_cast<_Udiff>(_Product) & _Mask;

        if (_Rem < _Index) {
            // Only when the low part is below _Index can the draw be one that must be rejected, so the comparatively
            // expensive modulo is computed lazily. _Mask - (_Index - 1) equals 2^L - _Index.
            const auto _Threshold = (_Mask - (_Index - 1)) % _Index;
            while (_Rem < _Threshold) {
                _Product = _Get_random_product(_Index, _Niter);
                _Rem = static_cast<_Udiff>(_Product) & _Mask;
            }
        }

        // L, the number of random bits that went into _Product: the count of one bits in _Mask.
        unsigned int _Generated_bits;
        if constexpr (_Bits < _Udiff_bits) {
            _Generated_bits = static_cast<unsigned int>(_Popcount(_Mask));
        } else {
            _Generated_bits = _Udiff_bits;
        }

        // Discard the low L bits; what remains is the result in [0, _Index).
        return static_cast<_Diff>(_Product >> _Generated_bits);
    }

    _Udiff _Get_all_bits() {
        // Concatenates _Get_bits() results until at least _Udiff_bits bits have been shifted in. Not called from
        // within this class.
        _Udiff _Ret = _Get_bits();

        if constexpr (_Bits < _Udiff_bits) {
            for (unsigned int _Num = _Bits; _Num < _Udiff_bits; _Num += _Bits) { // don't mask away any bits
                _Ret <<= _Bits;
                _Ret |= _Get_bits();
            }
        }

        return _Ret;
    }

    _Rng_from_urng_v2(const _Rng_from_urng_v2&) = delete;
    _Rng_from_urng_v2& operator=(const _Rng_from_urng_v2&) = delete;

private:
    _Udiff _Get_bits() { // return a random value within [0, _Bmask]
        static constexpr auto _Urng_min = (_Urng::min)();
        for (;;) { // repeat until random value is in range
            const _Udiff _Val = _Ref() - _Urng_min;

            if (_Val <= _Bmask) {
                return _Val;
            }
        }
    }

    static constexpr size_t _Calc_bits() {
        // Returns the largest bit count b such that 2^b - 1 <= (_Urng::max)() - (_Urng::min)(): start from an
        // all-ones mask and drop one bit at a time while the mask still exceeds the URNG's span.
        auto _Bits_local = _Udiff_bits;
        auto _Bmask_local = static_cast<_Udiff>(-1);
        for (; (_Urng::max)() - (_Urng::min)() < _Bmask_local; _Bmask_local >>= 1) {
            --_Bits_local;
        }

        return _Bits_local;
    }

    _Uprod _Get_random_product(const _Diff _Index, unsigned int _Niter) {
        // Builds a random value from _Niter calls to _Get_bits() (each contributing _Bits bits) and returns its
        // product with _Index as a _Uprod.
        _Udiff _Ret = _Get_bits();
        if constexpr (_Bits < _Udiff_bits) {
            while (--_Niter > 0) {
                _Ret <<= _Bits;
                _Ret |= _Get_bits();
            }
        }

        if constexpr (is_same_v<_Udiff, uint64_t>) {
            // 64-bit operands: _UMul128 returns the low half and writes the high half through _High.
            uint64_t _High;
            const auto _Low = _Base128::_UMul128(_Ret, static_cast<_Udiff>(_Index), _High);
            return _Uprod{_Low, _High};
        } else {
            return _Uprod{_Ret} * _Uprod{_Index};
        }
    }

    _Urng& _Ref; // reference to URNG
    static constexpr size_t _Bits = _Calc_bits(); // number of random bits generated by _Get_bits()
    static constexpr _Udiff _Bmask = static_cast<_Udiff>(-1) >> (_Udiff_bits - _Bits); // 2^_Bits - 1
};

[[noreturn]] _CRTIMP2_PURE void __CLRCALL_PURE_OR_CDECL _Xbad_alloc();
[[noreturn]] _CRTIMP2_PURE void __CLRCALL_PURE_OR_CDECL _Xinvalid_argument(_In_z_ const char*);
[[noreturn]] _CRTIMP2_PURE void __CLRCALL_PURE_OR_CDECL _Xlength_error(_In_z_ const char*);
Expand Down
1 change: 1 addition & 0 deletions tests/std/test.lst
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ tests\Dev11_1150223_shared_mutex
tests\Dev11_1158803_regex_thread_safety
tests\Dev11_1180290_filesystem_error_code
tests\GH_000177_forbidden_aliasing
tests\GH_000178_uniform_int
tests\GH_000342_filebuf_close
tests\GH_000431_copy_move_family
tests\GH_000431_equal_family
Expand Down
4 changes: 4 additions & 0 deletions tests/std/tests/GH_000178_uniform_int/env.lst
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

RUNALL_INCLUDE ..\usual_matrix.lst
80 changes: 80 additions & 0 deletions tests/std/tests/GH_000178_uniform_int/test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <random>

using namespace std;

template <class Generator>
bool basic_test() {
    // Draws values uniformly from [0, 2^60) using a default-constructed Generator, histograms them into
    // equal-width bins (the final bin also absorbs whatever the integer division leaves over), and returns
    // whether Pearson's chi-squared statistic is within the critical value.
    constexpr unsigned long long range_size = 1ull << 60;
    constexpr unsigned long long bin_count = 20ull; // Changing this requires looking up a new critical value.
    constexpr unsigned long long last_bin = bin_count - 1u;
    constexpr unsigned long long bin_width = range_size / bin_count; // the last bin may be wider
    constexpr double bin_share = static_cast<double>(bin_width) / range_size;
    constexpr double last_bin_extra = static_cast<double>(range_size % bin_width) / range_size;
    // chi-squared critical value for d.f. = 20 and p = 0.05
    // NOTE(review): a goodness-of-fit test over 20 bins conventionally has 19 degrees of freedom; confirm that the
    // more lenient d.f. = 20 value is intended.
    constexpr double critical_value = 31.410;

    const int sample_count = 20'000;

    Generator gen;
    std::uniform_int_distribution<std::uint64_t> dist(0, range_size - 1u);

    int histogram[bin_count] = {};
    for (int drawn = 0; drawn != sample_count; ++drawn) {
        const auto raw_bin = dist(gen) / bin_width;
        // Values beyond the last full-width bin are folded into the last bin.
        const auto bin = last_bin < raw_bin ? last_bin : raw_bin;
        ++histogram[bin];
    }

    double statistic = 0.0;
    for (unsigned int bin = 0; bin < bin_count; ++bin) {
        const bool is_last_bin = bin == last_bin;
        const double expected_count = (bin_share + (is_last_bin ? last_bin_extra : 0.0)) * sample_count;
        const double deviation = histogram[bin] - expected_count;
        statistic += deviation * deviation / expected_count;
    }

    return statistic <= critical_value;
}

bool test_modulus_bias() {
    // Detects modulus bias. When random integers in [0,s) are generated from a URBG with range [0,R), R mod s
    // values have to be rejected to keep the result uniform. Choosing parameters that make (R mod s)/R large
    // means an incorrect rejection step shows up as a large bias.

    constexpr int outcome_count = 5; // Changing this requires looking up a new critical value.
    // chi-squared critical value for d.f. = 5 and p = 0.05
    // NOTE(review): a goodness-of-fit test over 5 outcomes conventionally has 4 degrees of freedom; confirm that
    // the more lenient d.f. = 5 value is intended.
    constexpr double critical_value = 11.07;

    std::uniform_int_distribution<> dist(0, outcome_count - 1);
    std::independent_bits_engine<std::mt19937, 3, std::uint32_t> three_bit_gen; // 3 random bits per call

    const int sample_count = 1'000;
    int histogram[outcome_count] = {};
    for (int drawn = 0; drawn != sample_count; ++drawn) {
        const int value = dist(three_bit_gen);
        ++histogram[value];
    }

    // Every outcome is equally likely, so all of them share the same expected count.
    const double expected_count = static_cast<double>(sample_count) / outcome_count;

    double statistic = 0.0;
    for (const int observed : histogram) {
        const double deviation = observed - expected_count;
        statistic += deviation * deviation / expected_count;
    }

    return statistic <= critical_value;
}

int main() {
    // basic_test is run with four generators, one per case:
    // (1) URBG provides enough bits to completely fill the underlying type
    using fills_type = std::mt19937_64;
    // (2) URBG provides enough bits for our upper bound, but not enough to fill the type
    using covers_bound_only = std::independent_bits_engine<std::mt19937_64, 61, std::uint64_t>;
    // (3) URBG is called multiple times, but doesn't fill the type
    using several_calls = std::independent_bits_engine<std::mt19937, 31, std::uint32_t>;
    // (4) URBG is called multiple times and overflows the number of bits in the type
    using several_calls_overflowing = std::independent_bits_engine<std::mt19937, 25, std::uint32_t>;

    assert(basic_test<fills_type>());
    assert(basic_test<covers_bound_only>());
    assert(basic_test<several_calls>());
    assert(basic_test<several_calls_overflowing>());

    assert(test_modulus_bias());

    return 0;
}