Skip to content

Commit

Permalink
Move arithmetic functions into submodule FixedPointArithmetic
Browse files Browse the repository at this point in the history
  • Loading branch information
kimikage committed Apr 29, 2024
1 parent 2604b5a commit 65c6363
Show file tree
Hide file tree
Showing 7 changed files with 437 additions and 305 deletions.
227 changes: 28 additions & 199 deletions src/FixedPointNumbers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,6 @@ import Base: ==, <, <=, -, +, *, /, ~, isapprox,

import Random: Random, AbstractRNG, SamplerType, rand!

import Base.Checked: checked_neg, checked_abs, checked_add, checked_sub, checked_mul,
checked_div, checked_fld, checked_cld, checked_rem, checked_mod

using Base: @pure

"""
Expand All @@ -35,14 +32,11 @@ export
# "special" typealiases
# Q and N typealiases are exported in separate source files
# Functions
scaledual,
wrapping_neg, wrapping_abs, wrapping_add, wrapping_sub, wrapping_mul,
wrapping_div, wrapping_fld, wrapping_cld, wrapping_rem, wrapping_mod,
saturating_neg, saturating_abs, saturating_add, saturating_sub, saturating_mul,
saturating_div, saturating_fld, saturating_cld, saturating_rem, saturating_mod,
wrapping_fdiv, saturating_fdiv, checked_fdiv
scaledual

include("utilities.jl")
using .Utilities
import .Utilities: floattype, rawone, nbitsfrac, rawtype, signbits, nbitsint, scaledual

# reinterpretation
reinterpret(x::FixedPoint) = x.i
Expand All @@ -57,18 +51,6 @@ rawtype(::Type{X}) where {T, X <: FixedPoint{T}} = T
signbits(::Type{X}) where {T, X <: FixedPoint{T}} = T <: Unsigned ? 0 : 1
nbitsint(::Type{X}) where {X <: FixedPoint} = bitwidth(X) - nbitsfrac(X) - signbits(X)

# construction using the (approximate) intended value, i.e., N0f8
*(x::Real, ::Type{X}) where {X <: FixedPoint} = _convert(X, x)
wrapping_mul(x::Real, ::Type{X}) where {X <: FixedPoint} = x % X
saturating_mul(x::Real, ::Type{X}) where {X <: FixedPoint} = clamp(x, X)
checked_mul(x::Real, ::Type{X}) where {X <: FixedPoint} = _convert(X, x)

# type modulus
rem(x::Real, ::Type{X}) where {X <: FixedPoint} = _rem(x, X)
wrapping_rem(x::Real, ::Type{X}) where {X <: FixedPoint} = _rem(x, X)
saturating_rem(x::Real, ::Type{X}) where {X <: FixedPoint} = _rem(x, X)
checked_rem(x::Real, ::Type{X}) where {X <: FixedPoint} = _rem(x, X)

# constructor-style conversions
(::Type{X})(x::X) where {X <: FixedPoint} = x
(::Type{X})(x::Number) where {X <: FixedPoint} = _convert(X, x)
Expand Down Expand Up @@ -139,9 +121,6 @@ zero(::Type{X}) where {X <: FixedPoint} = X(zero(rawtype(X)), 0)
oneunit(::Type{X}) where {X <: FixedPoint} = X(rawone(X), 0)
one(::Type{X}) where {X <: FixedPoint} = oneunit(X)

# for Julia v1.0, which does not fold `div_float` before inlining
inv_rawone(x) = (@generated) ? (y = 1.0 / rawone(x); :($y)) : 1.0 / rawone(x)

# traits
eps(::Type{X}) where {X <: FixedPoint} = X(oneunit(rawtype(X)), 0)
typemax(::Type{T}) where {T <: FixedPoint} = T(typemax(rawtype(T)), 0)
Expand Down Expand Up @@ -192,164 +171,12 @@ RGB{Float32}
`RGB` itself is not a subtype of `AbstractFloat`, but unlike `RGB{N0f8}` operations with `RGB{Float32}` are not subject to integer overflow.
"""
floattype(::Type{T}) where {T <: AbstractFloat} = T # fallback (we want a MethodError if no method producing AbstractFloat is defined)
floattype(::Type{T}) where {T <: Union{ShortInts, Bool}} = Float32
floattype(::Type{T}) where {T <: Integer} = Float64
floattype(::Type{T}) where {T <: LongInts} = BigFloat
floattype(::Type{T}) where {I <: Integer, T <: Rational{I}} = typeof(zero(I)/oneunit(I))
floattype(::Type{<:AbstractIrrational}) = Float64
floattype(::Type{X}) where {T <: ShortInts, X <: FixedPoint{T}} = Float32
floattype(::Type{X}) where {T <: Integer, X <: FixedPoint{T}} = Float64
floattype(::Type{X}) where {T <: LongInts, X <: FixedPoint{T}} = BigFloat

# Non-Real types
floattype(::Type{Complex{T}}) where T = Complex{floattype(T)}
floattype(::Type{Base.TwicePrecision{Float64}}) = Float64 # wider would be nice, but hardware support is paramount
floattype(::Type{Base.TwicePrecision{T}}) where T<:Union{Float16,Float32} = widen(T)

float(x::FixedPoint) = convert(floattype(x), x)

# wrapping arithmetic
wrapping_neg(x::X) where {X <: FixedPoint} = X(-x.i, 0)
wrapping_abs(x::X) where {X <: FixedPoint} = X(abs(x.i), 0)
wrapping_add(x::X, y::X) where {X <: FixedPoint} = X(x.i + y.i, 0)
wrapping_sub(x::X, y::X) where {X <: FixedPoint} = X(x.i - y.i, 0)
wrapping_mul(x::X, y::X) where {X <: FixedPoint} = (float(x) * float(y)) % X
function wrapping_fdiv(x::X, y::X) where {X <: FixedPoint}
z = floattype(X)(x.i) / floattype(X)(y.i)
isfinite(z) ? z % X : zero(X)
end
function wrapping_div(x::X, y::X, r::RoundingMode = RoundToZero) where {T, X <: FixedPoint{T}}
z = round(floattype(X)(x.i) / floattype(X)(y.i), r)
isfinite(z) || return zero(T)
if T <: Unsigned
_unsafe_trunc(T, z)
else
z > typemax(T) ? typemin(T) : _unsafe_trunc(T, z)
end
end
wrapping_fld(x::X, y::X) where {X <: FixedPoint} = wrapping_div(x, y, RoundDown)
wrapping_cld(x::X, y::X) where {X <: FixedPoint} = wrapping_div(x, y, RoundUp)
wrapping_rem(x::X, y::X, r::RoundingMode = RoundToZero) where {T, X <: FixedPoint{T}} =
X(x.i - wrapping_div(x, y, r) * y.i, 0)
wrapping_mod(x::X, y::X) where {X <: FixedPoint} = wrapping_rem(x, y, RoundDown)

# saturating arithmetic
saturating_neg(x::X) where {X <: FixedPoint} = X(~min(x.i - true, x.i), 0)
saturating_neg(x::X) where {X <: FixedPoint{<:Unsigned}} = zero(X)

saturating_abs(x::X) where {X <: FixedPoint} =
X(ifelse(signbit(abs(x.i)), typemax(x.i), abs(x.i)), 0)

saturating_add(x::X, y::X) where {X <: FixedPoint} =
X(x.i + ifelse(x.i < 0, max(y.i, typemin(x.i) - x.i), min(y.i, typemax(x.i) - x.i)), 0)
saturating_add(x::X, y::X) where {X <: FixedPoint{<:Unsigned}} = X(x.i + min(~x.i, y.i), 0)

saturating_sub(x::X, y::X) where {X <: FixedPoint} =
X(x.i - ifelse(x.i < 0, min(y.i, x.i - typemin(x.i)), max(y.i, x.i - typemax(x.i))), 0)
saturating_sub(x::X, y::X) where {X <: FixedPoint{<:Unsigned}} = X(x.i - min(x.i, y.i), 0)

saturating_mul(x::X, y::X) where {X <: FixedPoint} = clamp(float(x) * float(y), X)

saturating_fdiv(x::X, y::X) where {X <: FixedPoint} =
clamp(floattype(X)(x.i) / floattype(X)(y.i), X)

function saturating_div(x::X, y::X, r::RoundingMode = RoundToZero) where {T, X <: FixedPoint{T}}
z = round(floattype(X)(x.i) / floattype(X)(y.i), r)
isnan(z) && return zero(T)
if T <: Unsigned
isfinite(z) ? _unsafe_trunc(T, z) : typemax(T)
else
_unsafe_trunc(T, clamp(z, typemin(T), typemax(T)))
end
end
saturating_fld(x::X, y::X) where {X <: FixedPoint} = saturating_div(x, y, RoundDown)
saturating_cld(x::X, y::X) where {X <: FixedPoint} = saturating_div(x, y, RoundUp)
function saturating_rem(x::X, y::X, r::RoundingMode = RoundToZero) where {T, X <: FixedPoint{T}}
T <: Unsigned && r isa RoundingMode{:Up} && return zero(X)
X(x.i - saturating_div(x, y, r) * y.i, 0)
end
saturating_mod(x::X, y::X) where {X <: FixedPoint} = saturating_rem(x, y, RoundDown)

# checked arithmetic
checked_neg(x::X) where {X <: FixedPoint} = checked_sub(zero(X), x)
function checked_abs(x::X) where {X <: FixedPoint}
abs(x.i) >= 0 || throw_overflowerror_abs(x)
X(abs(x.i), 0)
end
function checked_add(x::X, y::X) where {X <: FixedPoint}
r, f = Base.Checked.add_with_overflow(x.i, y.i)
z = X(r, 0) # store first
f && throw_overflowerror(:+, x, y)
z
end
function checked_sub(x::X, y::X) where {X <: FixedPoint}
r, f = Base.Checked.sub_with_overflow(x.i, y.i)
z = X(r, 0) # store first
f && throw_overflowerror(:-, x, y)
z
end
function checked_mul(x::X, y::X) where {X <: FixedPoint}
z = float(x) * float(y)
typemin(X) - eps(X)/2 <= z < typemax(X) + eps(X)/2 || throw_overflowerror(:*, x, y)
z % X
end
function checked_fdiv(x::X, y::X) where {T, X <: FixedPoint{T}}
y === zero(X) && throw(DivideError())
z = floattype(X)(x.i) / floattype(X)(y.i)
if T <: Unsigned
z < typemax(X) + eps(X)/2 || throw_overflowerror(:/, x, y)
else
typemin(X) - eps(X)/2 <= z < typemax(X) + eps(X)/2 || throw_overflowerror(:/, x, y)
end
z % X
end
function checked_div(x::X, y::X, r::RoundingMode = RoundToZero) where {T, X <: FixedPoint{T}}
y === zero(X) && throw(DivideError())
z = round(floattype(X)(x.i) / floattype(X)(y.i), r)
if T <: Signed
z <= typemax(T) || throw_overflowerror_div(r, x, y)
end
_unsafe_trunc(T, z)
end
checked_fld(x::X, y::X) where {X <: FixedPoint} = checked_div(x, y, RoundDown)
checked_cld(x::X, y::X) where {X <: FixedPoint} = checked_div(x, y, RoundUp)
function checked_rem(x::X, y::X, r::RoundingMode = RoundToZero) where {T, X <: FixedPoint{T}}
y === zero(X) && throw(DivideError())
fx, fy = floattype(X)(x.i), floattype(X)(y.i)
z = fx - round(fx / fy, r) * fy
if T <: Unsigned && r isa RoundingMode{:Up}
z >= zero(z) || throw_overflowerror_rem(r, x, y)
end
X(_unsafe_trunc(T, z), 0)
end
checked_mod(x::X, y::X) where {X <: FixedPoint} = checked_rem(x, y, RoundDown)

# default arithmetic
const DEFAULT_ARITHMETIC = :wrapping

for (op, name) in ((:-, :neg), (:abs, :abs))
f = Symbol(DEFAULT_ARITHMETIC, :_, name)
@eval begin
$op(x::X) where {X <: FixedPoint} = $f(x)
end
end
for (op, name) in ((:+, :add), (:-, :sub), (:*, :mul))
f = Symbol(DEFAULT_ARITHMETIC, :_, name)
@eval begin
$op(x::X, y::X) where {X <: FixedPoint} = $f(x, y)
end
end
# force checked arithmetic
/(x::X, y::X) where {X <: FixedPoint} = checked_fdiv(x, y)
div(x::X, y::X, r::RoundingMode = RoundToZero) where {X <: FixedPoint} = checked_div(x, y, r)
fld(x::X, y::X) where {X <: FixedPoint} = checked_div(x, y, RoundDown)
cld(x::X, y::X) where {X <: FixedPoint} = checked_div(x, y, RoundUp)
rem(x::X, y::X) where {X <: FixedPoint} = checked_rem(x, y, RoundToZero)
rem(x::X, y::X, ::RoundingMode{:Down}) where {X <: FixedPoint} = checked_rem(x, y, RoundDown)
rem(x::X, y::X, ::RoundingMode{:Up}) where {X <: FixedPoint} = checked_rem(x, y, RoundUp)
mod(x::X, y::X) where {X <: FixedPoint} = checked_rem(x, y, RoundDown)

function minmax(x::X, y::X) where {X <: FixedPoint}
a, b = minmax(reinterpret(x), reinterpret(y))
X(a,0), X(b,0)
Expand Down Expand Up @@ -518,6 +345,31 @@ include("normed.jl")
include("deprecations.jl")
const UF = (N0f8, N6f10, N4f12, N2f14, N0f16)

include("arithmetic/arithmetic.jl")
using .FixedPointArithmetic
# re-export
for name in names(FixedPointArithmetic.Wrapping)
@eval export $name
end
for name in names(FixedPointArithmetic.Saturating)
@eval export $name
end
for name in names(FixedPointArithmetic.Checked)
@eval export $name
end

# construction using the (approximate) intended value, i.e., N0f8
*(x::Real, ::Type{X}) where {X <: FixedPoint} = _convert(X, x)
Wrapping.wrapping_mul(x::Real, ::Type{X}) where {X <: FixedPoint} = x % X
Saturating.saturating_mul(x::Real, ::Type{X}) where {X <: FixedPoint} = clamp(x, X)
Checked.checked_mul(x::Real, ::Type{X}) where {X <: FixedPoint} = _convert(X, x)

# type modulus
rem(x::Real, ::Type{X}) where {X <: FixedPoint} = _rem(x, X)
Wrapping.wrapping_rem(x::Real, ::Type{X}) where {X<:FixedPoint} = _rem(x, X)
Saturating.saturating_rem(x::Real, ::Type{X}) where {X<:FixedPoint} = _rem(x, X)
Checked.checked_rem(x::Real, ::Type{X}) where {X<:FixedPoint} = _rem(x, X)

# Promotions
promote_rule(::Type{X}, ::Type{Tf}) where {X <: FixedPoint, Tf <: AbstractFloat} =
promote_type(floattype(X), Tf)
Expand Down Expand Up @@ -585,29 +437,6 @@ scaledual(::Type{Tdual}, x::AbstractArray{T}) where {Tdual, T <: FixedPoint} =
throw(ArgumentError(String(take!(io))))
end

@noinline function throw_overflowerror(op::Symbol, @nospecialize(x), @nospecialize(y))
io = IOBuffer()
print(io, x, ' ', op, ' ', y, " overflowed for type ")
showtype(io, typeof(x))
throw(OverflowError(String(take!(io))))
end
@noinline function throw_overflowerror_abs(@nospecialize(x))
io = IOBuffer()
print(io, "abs(", x, ") overflowed for type ")
showtype(io, typeof(x))
throw(OverflowError(String(take!(io))))
end
@noinline function throw_overflowerror_div(r::RoundingMode, @nospecialize(x), @nospecialize(y))
io = IOBuffer()
op = r === RoundUp ? "cld(" : r === RoundDown ? "fld(" : "div("
print(io, op, x, ", ", y, ") overflowed for type ", rawtype(x))
throw(OverflowError(String(take!(io))))
end
@noinline function throw_overflowerror_rem(r::RoundingMode, @nospecialize(x), @nospecialize(y))
io = IOBuffer()
print(io, "rem(", x, ", ", y, ", ", r, ") overflowed for type ", typeof(x))
throw(OverflowError(String(take!(io))))
end

function Random.rand(r::AbstractRNG, ::SamplerType{X}) where X <: FixedPoint
X(rand(r, rawtype(X)), 0)
Expand Down
Loading

0 comments on commit 65c6363

Please sign in to comment.