From 35c602476ed11b1a54881239b8c72b8a612117a2 Mon Sep 17 00:00:00 2001 From: itsdevbear Date: Tue, 25 Jun 2024 10:40:02 -0400 Subject: [PATCH] chore(ssz): Make generics a little less nasty --- mod/da/pkg/blob/factory.go | 11 +++-- mod/primitives/pkg/merkle/root.go | 24 +++++----- mod/primitives/pkg/merkle/root_test.go | 8 ++-- mod/primitives/pkg/merkle/tree.go | 50 ++++++++++----------- mod/primitives/pkg/merkle/tree_fuzz_test.go | 2 +- mod/primitives/pkg/merkle/tree_test.go | 22 ++++----- mod/primitives/pkg/ssz/merkleize.go | 20 +++++---- 7 files changed, 70 insertions(+), 67 deletions(-) diff --git a/mod/da/pkg/blob/factory.go b/mod/da/pkg/blob/factory.go index 0c928d0ddf..0c8cc69054 100644 --- a/mod/da/pkg/blob/factory.go +++ b/mod/da/pkg/blob/factory.go @@ -141,9 +141,10 @@ func (f *SidecarFactory[BeaconBlockT, BeaconBlockBodyT]) BuildBlockBodyProof( return nil, err } - tree, err := merkle.NewTreeWithMaxLeaves[ - [32]byte, [32]byte, - ](membersRoots, body.Length()-1) + tree, err := merkle.NewTreeWithMaxLeaves[[32]byte]( + membersRoots, + body.Length()-1, + ) if err != nil { return nil, err } @@ -159,9 +160,7 @@ func (f *SidecarFactory[BeaconBlockT, BeaconBlockBodyT]) BuildCommitmentProof( startTime := time.Now() defer f.metrics.measureBuildCommitmentProofDuration(startTime) - bodyTree, err := merkle.NewTreeWithMaxLeaves[ - [32]byte, [32]byte, - ]( + bodyTree, err := merkle.NewTreeWithMaxLeaves[[32]byte]( body.GetBlobKzgCommitments().Leafify(), f.chainSpec.MaxBlobCommitmentsPerBlock(), ) diff --git a/mod/primitives/pkg/merkle/root.go b/mod/primitives/pkg/merkle/root.go index 295f49ee7f..bf69b12d39 100644 --- a/mod/primitives/pkg/merkle/root.go +++ b/mod/primitives/pkg/merkle/root.go @@ -43,18 +43,18 @@ const ( ) // NewRootWithMaxLeaves constructs a Merkle tree root from a set of. -func NewRootWithMaxLeaves[U64T U64[U64T], LeafT, RootT ~[32]byte]( - leaves []LeafT, +func NewRootWithMaxLeaves[U64T U64[U64T], RootT ~[32]byte]( + leaves []RootT, length uint64, ) (RootT, error) { - return NewRootWithDepth[LeafT, RootT]( + return NewRootWithDepth[RootT]( leaves, math.U64(length).NextPowerOfTwo().ILog2Ceil(), ) } // NewRootWithDepth constructs a Merkle tree root from a set of leaves. -func NewRootWithDepth[LeafT, RootT ~[32]byte]( - leaves []LeafT, +func NewRootWithDepth[RootT ~[32]byte]( + leaves []RootT, depth uint8, ) (RootT, error) { // Return zerohash at depth @@ -70,7 +70,7 @@ func NewRootWithDepth[LeafT, RootT ~[32]byte]( leaves = append(leaves, zerohash) } var err error - leaves, err = BuildParentTreeRoots[LeafT, LeafT](leaves) + leaves, err = BuildParentTreeRoots[RootT](leaves) if err != nil { return zero.Hashes[depth], err } @@ -78,15 +78,15 @@ func NewRootWithDepth[LeafT, RootT ~[32]byte]( if len(leaves) != 1 { return zero.Hashes[depth], nil } - return RootT(leaves[0]), nil + return leaves[0], nil } // BuildParentTreeRoots calls BuildParentTreeRootsWithNRoutines with the // number of routines set to runtime.GOMAXPROCS(0)-1. -func BuildParentTreeRoots[LeafT, RootT ~[32]byte]( - inputList []LeafT, +func BuildParentTreeRoots[RootT ~[32]byte]( + inputList []RootT, ) ([]RootT, error) { - return BuildParentTreeRootsWithNRoutines[LeafT, RootT]( + return BuildParentTreeRootsWithNRoutines[RootT]( inputList, runtime.GOMAXPROCS(0)-1, ) } @@ -95,8 +95,8 @@ func BuildParentTreeRoots[LeafT, RootT ~[32]byte]( // using CPU-specific vector instructions and parallel processing. This // method adapts to the host machine's hardware for potential performance // gains over sequential hashing. -func BuildParentTreeRootsWithNRoutines[LeafT, RootT ~[32]byte]( - inputList []LeafT, n int, +func BuildParentTreeRootsWithNRoutines[RootT ~[32]byte]( + inputList []RootT, n int, ) ([]RootT, error) { // Validate input list length. inputLength := len(inputList) diff --git a/mod/primitives/pkg/merkle/root_test.go b/mod/primitives/pkg/merkle/root_test.go index 4af82656e7..205b45e259 100644 --- a/mod/primitives/pkg/merkle/root_test.go +++ b/mod/primitives/pkg/merkle/root_test.go @@ -63,7 +63,7 @@ func Test_HashTreeRootEqualInputs(t *testing.T) { go func() { defer wg.Done() var tempHash [][32]byte - tempHash, err = merkle.BuildParentTreeRoots[[32]byte, [32]byte]( + tempHash, err = merkle.BuildParentTreeRoots[[32]byte]( largeSlice, ) copy(hash1, tempHash) @@ -71,7 +71,7 @@ func Test_HashTreeRootEqualInputs(t *testing.T) { wg.Wait() require.NoError(t, err) - hash2, err = merkle.BuildParentTreeRoots[[32]byte, [32]byte]( + hash2, err = merkle.BuildParentTreeRoots[[32]byte]( secondLargeSlice, ) require.NoError(t, err) @@ -155,7 +155,7 @@ func TestBuildParentTreeRootsWithNRoutines_DivisionByZero(t *testing.T) { // Attempt to call BuildParentTreeRootsWithNRoutines with n set to 0 // to test handling of division by zero. inputList := make([][32]byte, 10) // Arbitrary size larger than 0 - _, err := merkle.BuildParentTreeRootsWithNRoutines[[32]byte, [32]byte]( + _, err := merkle.BuildParentTreeRootsWithNRoutines[[32]byte]( inputList, 0, ) @@ -183,7 +183,7 @@ func requireGoHashTreeEquivalence( go func() { defer wg.Done() var err error - output, err = merkle.BuildParentTreeRootsWithNRoutines[[32]byte, [32]byte]( + output, err = merkle.BuildParentTreeRootsWithNRoutines[[32]byte]( inputList, numRoutines, ) diff --git a/mod/primitives/pkg/merkle/tree.go b/mod/primitives/pkg/merkle/tree.go index 28de55626e..7ef1812ccd 100644 --- a/mod/primitives/pkg/merkle/tree.go +++ b/mod/primitives/pkg/merkle/tree.go @@ -36,20 +36,20 @@ const ( MaxTreeDepth = 62 ) -// Tree[LeafT, RootT] implements a Merkle tree that has been optimized to +// Tree[RootT] implements a Merkle tree that has been optimized to // handle leaves that are 32 bytes in size. -type Tree[LeafT, RootT ~[32]byte] struct { +type Tree[RootT ~[32]byte] struct { depth uint8 - branches [][]LeafT - leaves []LeafT + branches [][]RootT + leaves []RootT } // NewTreeFromLeaves constructs a Merkle tree, with the minimum // depth required to support the number of leaves. -func NewTreeFromLeaves[LeafT, RootT ~[32]byte]( - leaves []LeafT, -) (*Tree[LeafT, RootT], error) { - return NewTreeFromLeavesWithDepth[LeafT, RootT]( +func NewTreeFromLeaves[RootT ~[32]byte]( + leaves []RootT, +) (*Tree[RootT], error) { + return NewTreeFromLeavesWithDepth[RootT]( leaves, math.U64(len(leaves)).NextPowerOfTwo().ILog2Ceil(), ) @@ -57,11 +57,11 @@ func NewTreeFromLeaves[LeafT, RootT ~[32]byte]( // NewTreeWithMaxLeaves constructs a Merkle tree with a maximum number of // leaves. -func NewTreeWithMaxLeaves[LeafT, RootT ~[32]byte]( - leaves []LeafT, +func NewTreeWithMaxLeaves[RootT ~[32]byte]( + leaves []RootT, maxLeaves uint64, -) (*Tree[LeafT, RootT], error) { - return NewTreeFromLeavesWithDepth[LeafT, RootT]( +) (*Tree[RootT], error) { + return NewTreeFromLeavesWithDepth[RootT]( leaves, math.U64(maxLeaves).NextPowerOfTwo().ILog2Ceil(), ) @@ -69,15 +69,15 @@ func NewTreeWithMaxLeaves[LeafT, RootT ~[32]byte]( // NewTreeFromLeaves constructs a Merkle tree from a sequence of byte slices. // It will fill the tree with zero hashes to create the required depth. -func NewTreeFromLeavesWithDepth[LeafT, RootT ~[32]byte]( - leaves []LeafT, +func NewTreeFromLeavesWithDepth[RootT ~[32]byte]( + leaves []RootT, depth uint8, -) (*Tree[LeafT, RootT], error) { +) (*Tree[RootT], error) { if err := verifySufficientDepth(len(leaves), depth); err != nil { - return &Tree[LeafT, RootT]{}, err + return &Tree[RootT]{}, err } - layers := make([][]LeafT, depth+1) + layers := make([][]RootT, depth+1) layers[0] = leaves var err error @@ -86,13 +86,13 @@ func NewTreeFromLeavesWithDepth[LeafT, RootT ~[32]byte]( if len(currentLayer)%2 == 1 { currentLayer = append(currentLayer, zero.Hashes[d]) } - layers[d+1], err = BuildParentTreeRoots[LeafT, LeafT](currentLayer) + layers[d+1], err = BuildParentTreeRoots[RootT](currentLayer) if err != nil { - return &Tree[LeafT, RootT]{}, err + return &Tree[RootT]{}, err } } - return &Tree[LeafT, RootT]{ + return &Tree[RootT]{ branches: layers, leaves: leaves, depth: depth, @@ -100,7 +100,7 @@ func NewTreeFromLeavesWithDepth[LeafT, RootT ~[32]byte]( } // Insert an item into the tree. -func (m *Tree[LeafT, RootT]) Insert(item [32]byte, index int) error { +func (m *Tree[RootT]) Insert(item [32]byte, index int) error { if index < 0 { return errors.Wrap(ErrNegativeIndex, fmt.Sprintf("index: %d", index)) } @@ -159,13 +159,13 @@ func (m *Tree[LeafT, RootT]) Insert(item [32]byte, index int) error { } // Root returns the root of the Merkle tree. -func (m *Tree[LeafT, RootT]) Root() [32]byte { +func (m *Tree[RootT]) Root() [32]byte { return m.branches[len(m.branches)-1][0] } // HashTreeRoot returns the Root of the Merkle tree with the // number of leaves mixed in. -func (m *Tree[LeafT, RootT]) HashTreeRoot() ([32]byte, error) { +func (m *Tree[RootT]) HashTreeRoot() ([32]byte, error) { numItems := uint64(len(m.leaves)) if len(m.leaves) == 1 && m.leaves[0] == zero.Hashes[0] { @@ -175,7 +175,7 @@ func (m *Tree[LeafT, RootT]) HashTreeRoot() ([32]byte, error) { } // MerkleProof computes a proof from a tree's branches using a Merkle index. -func (m *Tree[LeafT, RootT]) MerkleProof(leafIndex uint64) ([][32]byte, error) { +func (m *Tree[RootT]) MerkleProof(leafIndex uint64) ([][32]byte, error) { numLeaves := uint64(len(m.branches[0])) if leafIndex >= numLeaves { return nil, errors.Newf( @@ -198,7 +198,7 @@ func (m *Tree[LeafT, RootT]) MerkleProof(leafIndex uint64) ([][32]byte, error) { // MerkleProofWithMixin computes a proof from a tree's branches using a Merkle // index. -func (m *Tree[LeafT, RootT]) MerkleProofWithMixin( +func (m *Tree[RootT]) MerkleProofWithMixin( index uint64, ) ([][32]byte, error) { proof, err := m.MerkleProof(index) diff --git a/mod/primitives/pkg/merkle/tree_fuzz_test.go b/mod/primitives/pkg/merkle/tree_fuzz_test.go index 939538c15c..f2b0eaa5b3 100644 --- a/mod/primitives/pkg/merkle/tree_fuzz_test.go +++ b/mod/primitives/pkg/merkle/tree_fuzz_test.go @@ -55,7 +55,7 @@ func FuzzTree_IsValidMerkleBranch(f *testing.F) { byteslib.ToBytes32([]byte("G")), byteslib.ToBytes32([]byte("H")), } - m, err := merkle.NewTreeFromLeavesWithDepth[[32]byte, [32]byte]( + m, err := merkle.NewTreeFromLeavesWithDepth[[32]byte]( items, depth, ) diff --git a/mod/primitives/pkg/merkle/tree_test.go b/mod/primitives/pkg/merkle/tree_test.go index 39face3178..7fd221c12d 100644 --- a/mod/primitives/pkg/merkle/tree_test.go +++ b/mod/primitives/pkg/merkle/tree_test.go @@ -34,7 +34,7 @@ const ( ) func TestNewTreeFromLeavesWithDepth_NoItemsProvided(t *testing.T) { - _, err := merkle.NewTreeFromLeavesWithDepth[[32]byte, [32]byte]( + _, err := merkle.NewTreeFromLeavesWithDepth[[32]byte]( nil, treeDepth, ) @@ -52,7 +52,7 @@ func TestNewTreeFromLeavesWithDepth_DepthSupport(t *testing.T) { byteslib.ToBytes32([]byte("GGGGGGG")), } // Supported depth - m1, err := merkle.NewTreeFromLeavesWithDepth[[32]byte, [32]byte]( + m1, err := merkle.NewTreeFromLeavesWithDepth[[32]byte]( items, merkle.MaxTreeDepth, ) @@ -61,7 +61,7 @@ func TestNewTreeFromLeavesWithDepth_DepthSupport(t *testing.T) { require.NoError(t, err) require.Len(t, proof, int(merkle.MaxTreeDepth)+1) // Unsupported depth - _, err = merkle.NewTreeFromLeavesWithDepth[[32]byte, [32]byte]( + _, err = merkle.NewTreeFromLeavesWithDepth[[32]byte]( items, merkle.MaxTreeDepth+1, ) @@ -79,7 +79,7 @@ func TestMerkleTree_IsValidMerkleBranch(t *testing.T) { byteslib.ToBytes32([]byte("G")), byteslib.ToBytes32([]byte("H")), } - m, err := merkle.NewTreeFromLeavesWithDepth[[32]byte, [32]byte]( + m, err := merkle.NewTreeFromLeavesWithDepth[[32]byte]( items, treeDepth, ) @@ -131,7 +131,7 @@ func TestMerkleTree_VerifyProof(t *testing.T) { byteslib.ToBytes32([]byte("H")), } - m, err := merkle.NewTreeFromLeavesWithDepth[[32]byte, [32]byte]( + m, err := merkle.NewTreeFromLeavesWithDepth[[32]byte]( items, treeDepth, ) @@ -173,7 +173,7 @@ func TestMerkleTree_NegativeIndexes(t *testing.T) { byteslib.ToBytes32([]byte("G")), byteslib.ToBytes32([]byte("H")), } - m, err := merkle.NewTreeFromLeavesWithDepth[[32]byte, [32]byte]( + m, err := merkle.NewTreeFromLeavesWithDepth[[32]byte]( items, treeDepth, ) @@ -189,7 +189,7 @@ func TestMerkleTree_VerifyProof_TrieUpdated(t *testing.T) { {3}, {4}, } - m, err := merkle.NewTreeFromLeavesWithDepth[[32]byte, [32]byte]( + m, err := merkle.NewTreeFromLeavesWithDepth[[32]byte]( items, treeDepth+1, ) @@ -236,7 +236,7 @@ func BenchmarkNewTreeFromLeavesWithDepth(b *testing.B) { byteslib.ToBytes32([]byte("GGGGGGG")), } for i := 0; i < b.N; i++ { - _, err := merkle.NewTreeFromLeavesWithDepth[[32]byte, [32]byte]( + _, err := merkle.NewTreeFromLeavesWithDepth[[32]byte]( items, treeDepth, ) @@ -251,7 +251,7 @@ func BenchmarkInsertTrie_Optimized(b *testing.B) { for i := range numDeposits { items[i] = byteslib.ToBytes32([]byte(strconv.Itoa(i))) } - tr, err := merkle.NewTreeFromLeavesWithDepth[[32]byte, [32]byte]( + tr, err := merkle.NewTreeFromLeavesWithDepth[[32]byte]( items, treeDepth, ) @@ -275,7 +275,7 @@ func BenchmarkGenerateProof(b *testing.B) { byteslib.ToBytes32([]byte("FFFFFF")), byteslib.ToBytes32([]byte("GGGGGGG")), } - goodTree, err := merkle.NewTreeFromLeavesWithDepth[[32]byte, [32]byte]( + goodTree, err := merkle.NewTreeFromLeavesWithDepth[[32]byte]( items, treeDepth, ) @@ -299,7 +299,7 @@ func BenchmarkIsValidMerkleBranch(b *testing.B) { byteslib.ToBytes32([]byte("FFFFFF")), byteslib.ToBytes32([]byte("GGGGGGG")), } - m, err := merkle.NewTreeFromLeavesWithDepth[[32]byte, [32]byte]( + m, err := merkle.NewTreeFromLeavesWithDepth[[32]byte]( items, treeDepth, ) diff --git a/mod/primitives/pkg/ssz/merkleize.go b/mod/primitives/pkg/ssz/merkleize.go index 476f78834d..0c526cc71e 100644 --- a/mod/primitives/pkg/ssz/merkleize.go +++ b/mod/primitives/pkg/ssz/merkleize.go @@ -24,6 +24,7 @@ import ( "reflect" "github.com/berachain/beacon-kit/mod/errors" + "github.com/berachain/beacon-kit/mod/primitives/pkg/common" "github.com/berachain/beacon-kit/mod/primitives/pkg/merkle" ) @@ -150,7 +151,10 @@ func MerkleizeVecComposite[ copy(htrs.Bytes[i][:], htr[:]) } } - return Merkleize[U64T, RootT](htrs.Bytes) + r, err := Merkleize[U64T, common.Root]( + htrs.Bytes, + ) + return RootT(r), err } // MerkleizeListComposite implements the SSZ merkleization algorithm for a list @@ -180,14 +184,14 @@ func MerkleizeListComposite[ return RootT{}, errors.New("htrs.Bytes is nil") } } - root, err := Merkleize[U64T, RootT]( + root, err := Merkleize[U64T, common.Root]( htrs.Bytes, ChunkCountCompositeList[C](value, limit), ) if err != nil { return RootT{}, err } - return merkle.MixinLength(root, uint64(len(value))), nil + return RootT(merkle.MixinLength(root, uint64(len(value)))), nil } // Merkleize hashes a list of chunks and returns the HTR of the list of. @@ -211,13 +215,13 @@ func MerkleizeListComposite[ // Then, merkleize the chunks (empty input is padded to 1 zero chunk): // If 1 chunk: the root is the chunk itself. // If > 1 chunks: merkleize as binary tree. -func Merkleize[U64T U64[U64T], RootT, ChunkT ~[32]byte]( - chunks []ChunkT, +func Merkleize[U64T U64[U64T], RootT ~[32]byte]( + chunks []RootT, limit ...uint64, ) (RootT, error) { var ( effectiveLimit U64T - effectiveChunks []ChunkT + effectiveChunks []RootT lenChunks = uint64(len(chunks)) ) @@ -240,10 +244,10 @@ func Merkleize[U64T U64[U64T], RootT, ChunkT ~[32]byte]( effectiveChunks = PadTo(chunks, effectiveLimit) if len(effectiveChunks) == 1 { - return RootT(effectiveChunks[0]), nil + return effectiveChunks[0], nil } - return merkle.NewRootWithMaxLeaves[U64T, ChunkT, RootT]( + return merkle.NewRootWithMaxLeaves[U64T]( effectiveChunks, //#nosec:G701 // This is a safe operation. uint64(effectiveLimit),