Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

<random>: Implement Lemire's fast integer generation #3012

Merged
merged 12 commits into from
Sep 22, 2022
6 changes: 2 additions & 4 deletions benchmarks/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,5 @@ function(add_benchmark name)
target_link_libraries(benchmark-${name} PRIVATE benchmark::benchmark)
endfunction()

add_benchmark(std_copy
src/std_copy.cpp
CXX_STANDARD 23
)
add_benchmark(std_copy src/std_copy.cpp)
add_benchmark(random_integer_generation src/random_integer_generation.cpp)
102 changes: 102 additions & 0 deletions benchmarks/src/random_integer_generation.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include <benchmark/benchmark.h>
StephanTLavavej marked this conversation as resolved.
Show resolved Hide resolved
#include <cstdint>
#include <random>

/// Test URBGs alone

static void BM_mt19937(benchmark::State& state) {
std::mt19937 gen;
for (auto _ : state) {
benchmark::DoNotOptimize(gen());
}
}
BENCHMARK(BM_mt19937);

static void BM_mt19937_64(benchmark::State& state) {
std::mt19937_64 gen;
for (auto _ : state) {
benchmark::DoNotOptimize(gen());
}
}
BENCHMARK(BM_mt19937_64);

static void BM_lcg(benchmark::State& state) {
std::minstd_rand gen;
for (auto _ : state) {
benchmark::DoNotOptimize(gen());
}
}
BENCHMARK(BM_lcg);

std::uint32_t GetMax() {
std::random_device gen;
std::uniform_int_distribution<std::uint32_t> dist(10'000'000, 20'000'000);
return dist(gen);
}

static const std::uint32_t maximum = GetMax(); // random divisor to prevent strength reduction

/// Test mt19937

static void BM_raw_mt19937_old(benchmark::State& state) {
std::mt19937 gen;
std::_Rng_from_urng<std::uint32_t, decltype(gen)> rng(gen);
for (auto _ : state) {
benchmark::DoNotOptimize(rng(maximum));
}
}
BENCHMARK(BM_raw_mt19937_old);

static void BM_raw_mt19937_new(benchmark::State& state) {
std::mt19937 gen;
std::_Rng_from_urng_v2<std::uint32_t, decltype(gen)> rng(gen);
for (auto _ : state) {
benchmark::DoNotOptimize(rng(maximum));
}
}
BENCHMARK(BM_raw_mt19937_new);

/// Test mt19937_64

static void BM_raw_mt19937_64_old(benchmark::State& state) {
std::mt19937_64 gen;
std::_Rng_from_urng<std::uint64_t, decltype(gen)> rng(gen);
for (auto _ : state) {
benchmark::DoNotOptimize(rng(maximum));
}
}
BENCHMARK(BM_raw_mt19937_64_old);

static void BM_raw_mt19937_64_new(benchmark::State& state) {
std::mt19937_64 gen;
std::_Rng_from_urng_v2<std::uint64_t, decltype(gen)> rng(gen);
for (auto _ : state) {
benchmark::DoNotOptimize(rng(maximum));
}
}
BENCHMARK(BM_raw_mt19937_64_new);

/// Test minstd_rand

static void BM_raw_lcg_old(benchmark::State& state) {
std::minstd_rand gen;
std::_Rng_from_urng<std::uint32_t, decltype(gen)> rng(gen);
for (auto _ : state) {
benchmark::DoNotOptimize(rng(maximum));
}
}
BENCHMARK(BM_raw_lcg_old);

static void BM_raw_lcg_new(benchmark::State& state) {
std::minstd_rand gen;
std::_Rng_from_urng_v2<std::uint32_t, decltype(gen)> rng(gen);
for (auto _ : state) {
benchmark::DoNotOptimize(rng(maximum));
}
}
BENCHMARK(BM_raw_lcg_new);

BENCHMARK_MAIN();
6 changes: 4 additions & 2 deletions stl/inc/random
Original file line number Diff line number Diff line change
Expand Up @@ -1844,7 +1844,9 @@ private:

template <class _Engine>
result_type _Eval(_Engine& _Eng, _Ty _Min, _Ty _Max) const { // compute next value in range [_Min, _Max]
_Rng_from_urng<_Uty, _Engine> _Generator(_Eng);
conditional_t<_Has_static_min_max<_Engine>::value, _Rng_from_urng_v2<_Uty, _Engine>,
_Rng_from_urng<_Uty, _Engine>>
_Generator(_Eng);

const _Uty _Umin = _Adjust(static_cast<_Uty>(_Min));
const _Uty _Umax = _Adjust(static_cast<_Uty>(_Max));
Expand All @@ -1862,7 +1864,7 @@ private:

static _Uty _Adjust(_Uty _Uval) { // convert signed ranges to unsigned ranges and vice versa
if constexpr (is_signed_v<_Ty>) {
const _Uty _Adjuster = (static_cast<_Uty>(-1) >> 1) + 1; // 2^(N-1)
constexpr _Uty _Adjuster = (static_cast<_Uty>(-1) >> 1) + 1; // 2^(N-1)

if (_Uval < _Adjuster) {
return static_cast<_Uty>(_Uval + _Adjuster);
Expand Down
132 changes: 128 additions & 4 deletions stl/inc/xutility
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <yvals.h>
#if _STL_COMPILER_PREPROCESSOR

#include <__msvc_int128.hpp>
#include <__msvc_iter_core.hpp>
#include <climits>
#include <cstdlib>
Expand Down Expand Up @@ -6029,7 +6030,8 @@ public:

using _Udiff = conditional_t<sizeof(_Ty1) < sizeof(_Ty0), _Ty0, _Ty1>;

explicit _Rng_from_urng(_Urng& _Func) : _Ref(_Func), _Bits(CHAR_BIT * sizeof(_Udiff)), _Bmask(_Udiff(-1)) {
explicit _Rng_from_urng(_Urng& _Func)
: _Ref(_Func), _Bits(CHAR_BIT * sizeof(_Udiff)), _Bmask(static_cast<_Udiff>(-1)) {
for (; static_cast<_Udiff>((_Urng::max)() - (_Urng::min)()) < _Bmask; _Bmask >>= 1) {
--_Bits;
}
Expand All @@ -6040,7 +6042,7 @@ public:
_Udiff _Ret = 0; // random bits
_Udiff _Mask = 0; // 2^N - 1, _Ret is within [0, _Mask]

while (_Mask < _Udiff(_Index - 1)) { // need more random bits
while (_Mask < static_cast<_Udiff>(_Index - 1)) { // need more random bits
_Ret <<= _Bits - 1; // avoid full shift
_Ret <<= 1;
_Ret |= _Get_bits();
Expand All @@ -6050,7 +6052,7 @@ public:
}

// _Ret is [0, _Mask], _Index - 1 <= _Mask, return if unbiased
if (_Ret / _Index < _Mask / _Index || _Mask % _Index == _Udiff(_Index - 1)) {
if (_Ret / _Index < _Mask / _Index || _Mask % _Index == static_cast<_Udiff>(_Index - 1)) {
return static_cast<_Diff>(_Ret % _Index);
}
}
Expand All @@ -6074,7 +6076,7 @@ public:
private:
_Udiff _Get_bits() { // return a random value within [0, _Bmask]
for (;;) { // repeat until random value is in range
_Udiff _Val = static_cast<_Udiff>(_Ref() - (_Urng::min)());
const _Udiff _Val = static_cast<_Udiff>(_Ref() - (_Urng::min)());

if (_Val <= _Bmask) {
return _Val;
Expand All @@ -6087,6 +6089,128 @@ private:
_Udiff _Bmask; // 2^_Bits - 1
};

template <class _Diff, class _Urng>
class _Rng_from_urng_v2 { // wrap a URNG as an RNG
public:
using _Ty0 = make_unsigned_t<_Diff>;
using _Ty1 = _Invoke_result_t<_Urng&>;

using _Udiff = conditional_t<sizeof(_Ty1) < sizeof(_Ty0), _Ty0, _Ty1>;
static constexpr unsigned int _Udiff_bits = sizeof(_Udiff) * CHAR_BIT;
using _Uprod = conditional_t<_Udiff_bits <= 16, uint32_t, conditional_t<_Udiff_bits <= 32, uint64_t, _Unsigned128>>;

explicit _Rng_from_urng_v2(_Urng& _Func) : _Ref(_Func) {}

_Diff operator()(_Diff _Index) { // adapt _Urng closed range to [0, _Index)
MattStephanson marked this conversation as resolved.
Show resolved Hide resolved
// From Daniel Lemire, "Fast Random Integer Generation in an Interval", ACM Trans. Model. Comput. Simul. 29 (1),
// 2019.
//
// Algorithm 5 <-> This Code:
// m <-> _Product
// l <-> _Rem
// s <-> _Index
// t <-> _Threshold
// L <-> _Generated_bits
MattStephanson marked this conversation as resolved.
Show resolved Hide resolved
// 2^L - 1 <-> _Mask

_Udiff _Mask = _Bmask;
unsigned int _Niter = 1;

if constexpr (_Bits < _Udiff_bits) {
while (_Mask < static_cast<_Udiff>(_Index - 1)) {
_Mask <<= _Bits;
_Mask |= _Bmask;
++_Niter;
}
}

// x <- random integer in [0, 2^L)
// m <- x * s
auto _Product = _Get_random_product(_Index, _Niter);
// l <- m mod 2^L
auto _Rem = static_cast<_Udiff>(_Product) & _Mask;

if (_Rem < _Index) {
// t <- (2^L - s) mod s
const auto _Threshold = (_Mask - _Index + 1) % _Index;
while (_Rem < _Threshold) {
_Product = _Get_random_product(_Index, _Niter);
_Rem = static_cast<_Udiff>(_Product) & _Mask;
}
}

unsigned int _Generated_bits;
if constexpr (_Bits < _Udiff_bits) {
_Generated_bits = static_cast<unsigned int>(_Popcount(_Mask));
} else {
_Generated_bits = _Udiff_bits;
}

// m / 2^L
return static_cast<_Diff>(_Product >> _Generated_bits);
}

_Udiff _Get_all_bits() {
_Udiff _Ret = _Get_bits();

if constexpr (_Bits < _Udiff_bits) {
for (unsigned int _Num = _Bits; _Num < _Udiff_bits; _Num += _Bits) { // don't mask away any bits
MattStephanson marked this conversation as resolved.
Show resolved Hide resolved
_Ret <<= _Bits;
_Ret |= _Get_bits();
}
}

return _Ret;
}

_Rng_from_urng_v2(const _Rng_from_urng_v2&) = delete;
_Rng_from_urng_v2& operator=(const _Rng_from_urng_v2&) = delete;

private:
_Udiff _Get_bits() { // return a random value within [0, _Bmask]
static constexpr auto _Urng_min = (_Urng::min)();
for (;;) { // repeat until random value is in range
const _Udiff _Val = _Ref() - _Urng_min;

if (_Val <= _Bmask) {
return _Val;
}
}
}

static constexpr size_t _Calc_bits() {
auto _Bits_local = _Udiff_bits;
auto _Bmask_local = static_cast<_Udiff>(-1);
for (; (_Urng::max)() - (_Urng::min)() < _Bmask_local; _Bmask_local >>= 1) {
--_Bits_local;
}

return _Bits_local;
}

_Uprod _Get_random_product(const _Diff _Index, unsigned int _Niter) {
_Udiff _Ret = _Get_bits();
if constexpr (_Bits < _Udiff_bits) {
while (--_Niter > 0) {
_Ret <<= _Bits;
_Ret |= _Get_bits();
}
}

if constexpr (is_same_v<_Udiff, uint64_t>) {
uint64_t _High;
const auto _Low = _Base128::_UMul128(_Ret, static_cast<_Udiff>(_Index), _High);
return _Uprod{_Low, _High};
} else {
return _Uprod{_Ret} * _Uprod{_Index};
}
}

_Urng& _Ref; // reference to URNG
static constexpr size_t _Bits = _Calc_bits(); // number of random bits generated by _Get_bits()
static constexpr _Udiff _Bmask = static_cast<_Udiff>(-1) >> (_Udiff_bits - _Bits); // 2^_Bits - 1
};

extern "C++" [[noreturn]] _CRTIMP2_PURE void __CLRCALL_PURE_OR_CDECL _Xbad_alloc();
extern "C++" [[noreturn]] _CRTIMP2_PURE void __CLRCALL_PURE_OR_CDECL _Xinvalid_argument(_In_z_ const char*);
extern "C++" [[noreturn]] _CRTIMP2_PURE void __CLRCALL_PURE_OR_CDECL _Xlength_error(_In_z_ const char*);
Expand Down
1 change: 1 addition & 0 deletions tests/std/test.lst
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ tests\Dev11_1150223_shared_mutex
tests\Dev11_1158803_regex_thread_safety
tests\Dev11_1180290_filesystem_error_code
tests\GH_000177_forbidden_aliasing
tests\GH_000178_uniform_int
tests\GH_000342_filebuf_close
tests\GH_000431_copy_move_family
tests\GH_000431_equal_family
Expand Down
4 changes: 4 additions & 0 deletions tests/std/tests/GH_000178_uniform_int/env.lst
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

RUNALL_INCLUDE ..\usual_matrix.lst
Loading