Skip to content

Commit

Permalink
Use sampler-based Random API (#206)
Browse files Browse the repository at this point in the history
This prevents `rand` from returning a `ReinterpretArray` to avoid the performance problem with `ReinterpretArray` .
This also supports specifying the RNG option.
  • Loading branch information
kimikage committed Aug 8, 2020
1 parent ae6b911 commit 3e41a6a
Show file tree
Hide file tree
Showing 5 changed files with 16 additions and 5 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93"
version = "0.8.4"

[deps]
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
Expand Down
16 changes: 12 additions & 4 deletions src/FixedPointNumbers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@ import Base: ==, <, <=, -, +, *, /, ~, isapprox,
big, rationalize, float, trunc, round, floor, ceil, bswap, clamp,
div, fld, rem, mod, mod1, fld1, min, max, minmax,
signed, unsigned, copysign, flipsign, signbit,
rand, length
length

import Statistics # for _mean_promote
import Random: Random, AbstractRNG, SamplerType, rand!

using Base.Checked: checked_add, checked_sub, checked_div

Expand Down Expand Up @@ -315,7 +316,7 @@ const UF = (N0f8, N6f10, N4f12, N2f14, N0f16)
promote_rule(::Type{X}, ::Type{Tf}) where {X <: FixedPoint, Tf <: AbstractFloat} =
promote_type(floattype(X), Tf)

# Note that `Tr` does not always have enough domains.
# Note that `Tr` does not always have enough domains.
promote_rule(::Type{X}, ::Type{Tr}) where {X <: FixedPoint, Tr <: Rational} = Tr

promote_rule(::Type{X}, ::Type{Ti}) where {X <: FixedPoint, Ti <: Integer} = floattype(X)
Expand Down Expand Up @@ -382,8 +383,15 @@ scaledual(::Type{Tdual}, x::AbstractArray{T}) where {Tdual, T <: FixedPoint} =
throw(ArgumentError(String(take!(io))))
end

rand(::Type{T}) where {T <: FixedPoint} = reinterpret(T, rand(rawtype(T)))
rand(::Type{T}, sz::Dims) where {T <: FixedPoint} = reinterpret(T, rand(rawtype(T), sz))
function Random.rand(r::AbstractRNG, ::SamplerType{X}) where X <: FixedPoint
X(rand(r, rawtype(X)), 0)
end

function rand!(r::AbstractRNG, A::Array{X}, ::SamplerType{X}) where {T, X <: FixedPoint{T}}
At = unsafe_wrap(Array, reinterpret(Ptr{T}, pointer(A)), size(A))
Random.rand!(r, At, SamplerType{T}())
A
end

if VERSION >= v"1.1" # work around https://github.com/JuliaLang/julia/issues/34121
include("precompile.jl")
Expand Down
2 changes: 1 addition & 1 deletion test/common.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using FixedPointNumbers, Statistics, Test
using FixedPointNumbers, Statistics, Random, Test
using FixedPointNumbers: bitwidth, rawtype, nbitsfrac

"""
Expand Down
1 change: 1 addition & 0 deletions test/fixed.jl
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,7 @@ end
@test ndims(a) == 2 && eltype(a) === F
@test size(a) == (3,5)
end
@test rand(MersenneTwister(1234), Q0f7) === -0.156Q0f7
end

@testset "Promotion within Fixed" begin
Expand Down
1 change: 1 addition & 0 deletions test/normed.jl
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,7 @@ end
@test ndims(a) == 2 && eltype(a) === N
@test size(a) == (3,5)
end
@test rand(MersenneTwister(1234), N0f8) === 0.925N0f8
end

@testset "Promotion within Normed" begin
Expand Down

0 comments on commit 3e41a6a

Please sign in to comment.