Skip to content

Commit

Permalink
feat: introduces SaturatingAdd and SaturatingSub (#3519)
Browse files Browse the repository at this point in the history
  • Loading branch information
EclesioMeloJunior committed Oct 25, 2023
1 parent 2d63ae1 commit daa9e25
Show file tree
Hide file tree
Showing 3 changed files with 130 additions and 10 deletions.
14 changes: 4 additions & 10 deletions dot/state/slot.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (

"github.com/ChainSafe/gossamer/dot/types"
"github.com/ChainSafe/gossamer/internal/database"
"github.com/ChainSafe/gossamer/lib/primitives"
"github.com/ChainSafe/gossamer/pkg/scale"
)

Expand Down Expand Up @@ -45,10 +46,10 @@ type headerAndSigner struct {
}

func (s *SlotState) CheckEquivocation(slotNow, slot uint64, header *types.Header,
signer types.AuthorityID) (*types.BabeEquivocationProof, error) {
signer types.AuthorityID) (*types.BabeEquivocationProof, error) { //skipcq: GO-R1005
// We don't check equivocations for old headers out of our capacity.
// checking slotNow is greater than slot to avoid overflow, same as saturating_sub
if saturatingSub(slotNow, slot) > maxSlotCapacity {
if primitives.SaturatingSub(slotNow, slot) > maxSlotCapacity {
return nil, nil
}

Expand Down Expand Up @@ -127,7 +128,7 @@ func (s *SlotState) CheckEquivocation(slotNow, slot uint64, header *types.Header
newFirstSavedSlot := firstSavedSlot

if slotNow-firstSavedSlot >= pruningBound {
newFirstSavedSlot = saturatingSub(slotNow, maxSlotCapacity)
newFirstSavedSlot = primitives.SaturatingSub(slotNow, maxSlotCapacity)

for s := firstSavedSlot; s < newFirstSavedSlot; s++ {
slotEncoded := make([]byte, 8)
Expand Down Expand Up @@ -184,10 +185,3 @@ func (s *SlotState) CheckEquivocation(slotNow, slot uint64, header *types.Header

return nil, nil
}

func saturatingSub(a, b uint64) uint64 {
if a > b {
return a - b
}
return 0
}
89 changes: 89 additions & 0 deletions lib/primitives/math.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
// Copyright 2023 ChainSafe Systems (ON)
// SPDX-License-Identifier: LGPL-3.0-only

package primitives

import (
"fmt"
"unsafe"

"golang.org/x/exp/constraints"
)

// saturatingOperations applies the correct operation
// given the input types
func saturatingOperations[T constraints.Integer](a, b T,
signedSaturatingOperation func(T, T, T, T) T,
unsignedSaturatingOperation func(T, T) T,
) T {
switch any(a).(type) {
case int, int8, int16, int32, int64:
// #nosec G103
sizeOf := (unsafe.Sizeof(a) * 8) - 1

var (
maxValueOfSignedType T = 1<<sizeOf - 1
minValueOfSignedType T = ^maxValueOfSignedType
)

return signedSaturatingOperation(a, b, maxValueOfSignedType, minValueOfSignedType)
case uint, uint8, uint16, uint32, uint64, uintptr:
// the operation ^T(0) gives us the max value of type T
// eg. if T is uint8 then it gives us 255
return unsignedSaturatingOperation(a, b)
}

panic(fmt.Sprintf("type %T not supported while performing SaturatingAdd", a))
}

// SaturatingAdd computes a + b saturating at the numeric bounds instead of overflowing
func SaturatingAdd[T constraints.Integer](a, b T) T {
return saturatingOperations(a, b, saturatingAddSigned, saturatingAddUnsigned)
}

func saturatingAddSigned[T constraints.Integer](a, b, max, min T) T {
if b > 0 && a > max-b {
return max
}

if b < 0 && a < min-b {
return min
}

return a + b
}

func saturatingAddUnsigned[T constraints.Integer](a, b T) T {
// the operation ^T(0) gives us the max value of type T
// eg. if T is uint8 then it gives us 255
max := ^T(0)

if a > max-b {
return max
}
return a + b
}

// SaturatingSub computes a - b saturating at the numeric bounds instead of overflowing
func SaturatingSub[T constraints.Integer](a, b T) T {
return saturatingOperations(a, b, saturatingSubSigned, saturatingSubUnsigned)
}

func saturatingSubSigned[T constraints.Integer](a, b, max, min T) T {
if b < 0 && a > max+b {
return max
}

if b > 0 && a < min+b {
return min
}

return a - b
}

func saturatingSubUnsigned[T constraints.Integer](a, b T) T {
if a > b {
return a - b
}
return 0
}
37 changes: 37 additions & 0 deletions lib/primitives/math_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// Copyright 2023 ChainSafe Systems (ON)
// SPDX-License-Identifier: LGPL-3.0-only

package primitives

import (
"testing"

"github.com/ethereum/go-ethereum/common/math"
"github.com/stretchr/testify/require"
)

func TestSaturatingAdd(t *testing.T) {
require.Equal(t, uint8(2), SaturatingAdd(uint8(1), uint8(1)))
require.Equal(t, uint8(math.MaxUint8), SaturatingAdd(uint8(math.MaxUint8), 100))

require.Equal(t, uint32(math.MaxUint32), SaturatingAdd(uint32(math.MaxUint32), 100))
require.Equal(t, uint32(100), SaturatingAdd(uint32(0), 100))

// should not be able to overflow in the opposite direction as well
require.Equal(t, int64(math.MinInt64), SaturatingAdd(int64(math.MinInt64), -100))
require.Equal(t, int8(127), SaturatingAdd(int8(120), 7))
require.Equal(t, int8(127), SaturatingAdd(int8(120), 8))
}

func TestSaturatingSub(t *testing.T) {
// -128 - 100 overflows, so it should return just -128
require.Equal(t, int8(math.MinInt8), SaturatingSub(int8(math.MinInt8), 100))
require.Equal(t, int8(0), SaturatingSub(int8(100), 100))

// max - (-1) = max + 1 = overflows, so it should return just max
require.Equal(t, int64(math.MaxInt64), SaturatingSub(int64(math.MaxInt64), -1))

// 2 - 10 = -8 which overflows, then should return just 0
require.Equal(t, uint32(0), SaturatingSub(uint32(2), uint32(10)))
require.Equal(t, uint64(math.MaxUint64), SaturatingSub(uint64(math.MaxUint64), uint64(0)))
}

0 comments on commit daa9e25

Please sign in to comment.