From 1548f9faae9526e355325e3ed58767168ebda669 Mon Sep 17 00:00:00 2001 From: rian Date: Wed, 17 Jul 2024 11:18:31 +0300 Subject: [PATCH] use junos pedersen hash funct --- contracts/contracts.go | 7 +-- curve/curve.go | 66 ++++++---------------------- curve/curve_test.go | 98 +++++++++++++++++++----------------------- hash/hash.go | 6 +-- merkle/merkle.go | 24 +++-------- merkle/merkle_test.go | 14 +++--- typed/typed.go | 19 ++++---- 7 files changed, 84 insertions(+), 150 deletions(-) diff --git a/contracts/contracts.go b/contracts/contracts.go index 79916108..2fc7d5d5 100644 --- a/contracts/contracts.go +++ b/contracts/contracts.go @@ -72,12 +72,9 @@ func PrecomputeAddress(deployerAddress *felt.Felt, salt *felt.Felt, classHash *f }) constructorCalldataBigIntArr := utils.FeltArrToBigIntArr(constructorCalldata) - constructorCallDataHashInt, _ := curve.Curve.ComputeHashOnElements(constructorCalldataBigIntArr) + constructorCallDataHashInt := curve.Curve.ComputeHashOnElements(constructorCalldataBigIntArr) bigIntArr = append(bigIntArr, constructorCallDataHashInt) - preBigInt, err := curve.Curve.ComputeHashOnElements(bigIntArr) - if err != nil { - return nil, err - } + preBigInt := curve.Curve.ComputeHashOnElements(bigIntArr) return utils.BigIntToFelt(preBigInt), nil } diff --git a/curve/curve.go b/curve/curve.go index c7199524..164cb8e6 100644 --- a/curve/curve.go +++ b/curve/curve.go @@ -528,19 +528,18 @@ func (sc StarkCurve) SignFelt(msgHash, privKey *felt.Felt) (*felt.Felt, *felt.Fe // Returns: // - hash: The hash of the list of elements // - err: An error if any -func (sc StarkCurve) HashElements(elems []*big.Int) (hash *big.Int, err error) { +func (sc StarkCurve) HashElements(elems []*big.Int) *big.Int { if len(elems) == 0 { elems = append(elems, big.NewInt(0)) } - hash = big.NewInt(0) + hash := new(felt.Felt).SetUint64(0) for _, h := range elems { - hash, err = sc.PedersenHash([]*big.Int{hash, h}) - if err != nil { - return hash, err - } + hFelt := new(felt.Felt).SetBytes(h.Bytes()) + hash = sc.PedersenHash(hash, hFelt) } - return hash, err + hashBytes := hash.Bytes() + return new(big.Int).SetBytes(hashBytes[:]) } // ComputeHashOnElements computes the hash on the given elements using a golang Pedersen Hash implementation. @@ -555,60 +554,21 @@ func (sc StarkCurve) HashElements(elems []*big.Int) (hash *big.Int, err error) { // Returns: // - hash: The hash of the list of elements // - err: An error if any -func (sc StarkCurve) ComputeHashOnElements(elems []*big.Int) (hash *big.Int, err error) { +func (sc StarkCurve) ComputeHashOnElements(elems []*big.Int) (hash *big.Int) { elems = append(elems, big.NewInt(int64(len(elems)))) return Curve.HashElements((elems)) } // PedersenHash calculates the Pedersen hash of the given elements. -// NOTE: This function assumes the curve has been initialized with constant points -// (ref: https://github.com/seanjameshan/starknet.js/blob/main/src/utils/ellipticCurve.ts) -// -// The function requires that the precomputed constant points have been initiated. -// If the length of `sc.ConstantPoints` is zero, an error is returned. -// The function iterates over the elements in `elems` and performs the Pedersen hash calculation. -// For each element, it checks if the value is within the valid range. -// If the value is invalid, an error is returned. -// For each bit in the element, the function performs an addition operation on `ptx` and `pty` -// using the corresponding constant point from the precomputed constant points. -// If the constant point is a duplicate of `ptx`, an error is returned. -// The function returns the resulting hash and a nil error if the calculation is successful. -// Otherwise, it returns `ptx` and an error describing the issue encountered. +// NOTE: This function just wraps Junos PedersenHash function +// (ref: https://github.com/NethermindEth/juno/blob/main/core/crypto/pedersen_hash.go#L12) // // Parameters: -// - elems: An array of big integers representing the elements to hash. +// - elems: An array of felts representing the elements to hash. // Returns: -// - hash: The resulting Pedersen hash as a big integer. -// - err: An error, if any, encountered during the calculation. -func (sc StarkCurve) PedersenHash(elems []*big.Int) (hash *big.Int, err error) { - if len(sc.ConstantPoints) == 0 { - return hash, fmt.Errorf("must initiate precomputed constant points") - } - - ptx := new(big.Int).Set(sc.Gx) - pty := new(big.Int).Set(sc.Gy) - for i, elem := range elems { - x := new(big.Int).Set(elem) - - if x.Cmp(big.NewInt(0)) == -1 || x.Cmp(sc.P) >= 0 { - return ptx, fmt.Errorf("invalid x: %v", x) - } - - for j := 0; j < 252; j++ { - idx := 2 + (i * 252) + j - xin := new(big.Int).Set(sc.ConstantPoints[idx][0]) - yin := new(big.Int).Set(sc.ConstantPoints[idx][1]) - if xin.Cmp(ptx) == 0 { - return hash, fmt.Errorf("constant point duplication: %v %v", ptx, xin) - } - if x.Bit(0) == 1 { - ptx, pty = sc.Add(ptx, pty, xin, yin) - } - x = x.Rsh(x, 1) - } - } - - return ptx, nil +// - hash: The resulting Pedersen hash as a felt. +func (sc StarkCurve) PedersenHash(felts ...*felt.Felt) (hash *felt.Felt) { + return junoCrypto.PedersenArray(felts...) } // PoseidonArray is a function that takes a variadic number of felt.Felt pointers as parameters and diff --git a/curve/curve_test.go b/curve/curve_test.go index fbbfa5e6..8a0c6a86 100644 --- a/curve/curve_test.go +++ b/curve/curve_test.go @@ -7,6 +7,7 @@ import ( "math/big" "testing" + "github.com/NethermindEth/juno/core/felt" "github.com/NethermindEth/starknet.go/utils" ) @@ -21,22 +22,21 @@ import ( // // none func BenchmarkPedersenHash(b *testing.B) { - suite := [][]*big.Int{ - {utils.HexToBN("0x12773"), utils.HexToBN("0x872362")}, - {utils.HexToBN("0x1277312773"), utils.HexToBN("0x872362872362")}, - {utils.HexToBN("0x1277312773"), utils.HexToBN("0xCD2a3d9F938E13CD947Ec05AbC7FE734Df8DD826")}, - {utils.HexToBN("0xbBbBBBBbbBBBbbbBbbBbbbbBBbBbbbbBbBbbBBbB"), utils.HexToBN("0x872362872362")}, - {utils.HexToBN("0xCD2a3d9F938E13CD947Ec05AbC7FE734Df8DD826"), utils.HexToBN("0xbBbBBBBbbBBBbbbBbbBbbbbBBbBbbbbBbBbbBBbB")}, - {utils.HexToBN("0x7f15c38ea577a26f4f553282fcfe4f1feeb8ecfaad8f221ae41abf8224cbddd"), utils.HexToBN("0x13d41f388b8ea4db56c5aa6562f13359fab192b3db57651af916790f9debee9")}, - {utils.HexToBN("0x7f15c38ea577a26f4f553282fcfe4f1feeb8ecfaad8f221ae41abf8224cbddd"), utils.HexToBN("0x7f15c38ea577a26f4f553282fcfe4f1feeb8ecfaad8f221ae41abf8224cbdde")}, + suite := [][]*felt.Felt{ + {utils.TestHexToFelt(b, "0x12773"), utils.TestHexToFelt(b, "0x872362")}, + {utils.TestHexToFelt(b, "0x1277312773"), utils.TestHexToFelt(b, "0x872362872362")}, + {utils.TestHexToFelt(b, "0x1277312773"), utils.TestHexToFelt(b, "0xCD2a3d9F938E13CD947Ec05AbC7FE734Df8DD826")}, + {utils.TestHexToFelt(b, "0xbBbBBBBbbBBBbbbBbbBbbbbBBbBbbbbBbBbbBBbB"), utils.TestHexToFelt(b, "0x872362872362")}, + {utils.TestHexToFelt(b, "0xCD2a3d9F938E13CD947Ec05AbC7FE734Df8DD826"), utils.TestHexToFelt(b, "0xbBbBBBBbbBBBbbbBbbBbbbbBBbBbbbbBbBbbBBbB")}, + {utils.TestHexToFelt(b, "0x7f15c38ea577a26f4f553282fcfe4f1feeb8ecfaad8f221ae41abf8224cbddd"), utils.TestHexToFelt(b, "0x13d41f388b8ea4db56c5aa6562f13359fab192b3db57651af916790f9debee9")}, + {utils.TestHexToFelt(b, "0x7f15c38ea577a26f4f553282fcfe4f1feeb8ecfaad8f221ae41abf8224cbddd"), utils.TestHexToFelt(b, "0x7f15c38ea577a26f4f553282fcfe4f1feeb8ecfaad8f221ae41abf8224cbdde")}, } for _, test := range suite { - b.Run(fmt.Sprintf("input_size_%d_%d", test[0].BitLen(), test[1].BitLen()), func(b *testing.B) { - if _, err := Curve.PedersenHash(test); err != nil { - log.Fatal(err) - } - }) + b.Run( + fmt.Sprintf("input_size_%d_%d", len(test[0].Bits()), len(test[1].Bits())), func(b *testing.B) { + Curve.PedersenHash(test...) + }) } } @@ -92,21 +92,22 @@ func BenchmarkSignatureVerify(b *testing.B) { private, _ := Curve.GetRandomPrivateKey() x, y, _ := Curve.PrivateToPoint(private) - hash, _ := Curve.PedersenHash( - []*big.Int{ - utils.HexToBN("0x7f15c38ea577a26f4f553282fcfe4f1feeb8ecfaad8f221ae41abf8224cbddd"), - utils.HexToBN("0x7f15c38ea577a26f4f553282fcfe4f1feeb8ecfaad8f221ae41abf8224cbdde"), - }) - - r, s, _ := Curve.Sign(hash, private) - - b.Run(fmt.Sprintf("sign_input_size_%d", hash.BitLen()), func(b *testing.B) { - if _, _, err := Curve.Sign(hash, private); err != nil { + hash := Curve.PedersenHash( + []*felt.Felt{ + utils.TestHexToFelt(b, "0x7f15c38ea577a26f4f553282fcfe4f1feeb8ecfaad8f221ae41abf8224cbddd"), + utils.TestHexToFelt(b, "0x7f15c38ea577a26f4f553282fcfe4f1feeb8ecfaad8f221ae41abf8224cbdde"), + }...) + hashBytes := hash.Bytes() + hashBigInt := new(big.Int).SetBytes(hashBytes[:]) + r, s, _ := Curve.Sign(hashBigInt, private) + + b.Run(fmt.Sprintf("sign_input_size_%d", hashBigInt.BitLen()), func(b *testing.B) { + if _, _, err := Curve.Sign(hashBigInt, private); err != nil { log.Fatal(err) } }) - b.Run(fmt.Sprintf("verify_input_size_%d", hash.BitLen()), func(b *testing.B) { - Curve.Verify(hash, r, s, x, y) + b.Run(fmt.Sprintf("verify_input_size_%d", hashBigInt.BitLen()), func(b *testing.B) { + Curve.Verify(hashBigInt, r, s, x, y) }) } @@ -140,29 +141,28 @@ func TestGeneral_PrivateToPoint(t *testing.T) { // none func TestGeneral_PedersenHash(t *testing.T) { testPedersen := []struct { - elements []*big.Int + elements []*felt.Felt expected *big.Int }{ { - elements: []*big.Int{utils.HexToBN("0x12773"), utils.HexToBN("0x872362")}, + elements: []*felt.Felt{utils.TestHexToFelt(t, "0x12773"), utils.TestHexToFelt(t, "0x872362")}, expected: utils.HexToBN("0x5ed2703dfdb505c587700ce2ebfcab5b3515cd7e6114817e6026ec9d4b364ca"), }, { - elements: []*big.Int{utils.HexToBN("0x13d41f388b8ea4db56c5aa6562f13359fab192b3db57651af916790f9debee9"), utils.HexToBN("0x537461726b4e6574204d61696c")}, + elements: []*felt.Felt{utils.TestHexToFelt(t, "0x13d41f388b8ea4db56c5aa6562f13359fab192b3db57651af916790f9debee9"), utils.TestHexToFelt(t, "0x537461726b4e6574204d61696c")}, expected: utils.HexToBN("0x180c0a3d13c1adfaa5cbc251f4fc93cc0e26cec30ca4c247305a7ce50ac807c"), }, { - elements: []*big.Int{big.NewInt(100), big.NewInt(1000)}, + elements: []*felt.Felt{new(felt.Felt).SetUint64(100), new(felt.Felt).SetUint64(1000)}, expected: utils.HexToBN("0x45a62091df6da02dce4250cb67597444d1f465319908486b836f48d0f8bf6e7"), }, } for _, tt := range testPedersen { - hash, err := Curve.PedersenHash(tt.elements) - if err != nil { - t.Errorf("Hashing err: %v\n", err) - } - if hash.Cmp(tt.expected) != 0 { + hash := Curve.PedersenHash(tt.elements...) + hashBytes := hash.Bytes() + hashBigInt := new(big.Int).SetBytes(hashBytes[:]) + if hashBigInt.Cmp(tt.expected) != 0 { t.Errorf("incorrect hash: got %v expected %v\n", hash, tt.expected) } } @@ -312,25 +312,19 @@ func TestGeneral_MultAir(t *testing.T) { // // none func TestGeneral_ComputeHashOnElements(t *testing.T) { - hashEmptyArray, err := Curve.ComputeHashOnElements([]*big.Int{}) + hashEmptyArray := Curve.ComputeHashOnElements([]*big.Int{}) expectedHashEmmptyArray := utils.HexToBN("0x49ee3eba8c1600700ee1b87eb599f16716b0b1022947733551fde4050ca6804") - if err != nil { - t.Errorf("Could no hash an empty array %v\n", err) - } if hashEmptyArray.Cmp(expectedHashEmmptyArray) != 0 { t.Errorf("Hash empty array wrong value. Expected %v got %v\n", expectedHashEmmptyArray, hashEmptyArray) } - hashFilledArray, err := Curve.ComputeHashOnElements([]*big.Int{ + hashFilledArray := Curve.ComputeHashOnElements([]*big.Int{ big.NewInt(123782376), big.NewInt(213984), big.NewInt(128763521321), }) expectedHashFilledArray := utils.HexToBN("0x7b422405da6571242dfc245a43de3b0fe695e7021c148b918cd9cdb462cac59") - if err != nil { - t.Errorf("Could no hash an array with values %v\n", err) - } if hashFilledArray.Cmp(expectedHashFilledArray) != 0 { t.Errorf("Hash filled array wrong value. Expected %v got %v\n", expectedHashFilledArray, hashFilledArray) } @@ -344,14 +338,11 @@ func TestGeneral_ComputeHashOnElements(t *testing.T) { // // none func TestGeneral_HashAndSign(t *testing.T) { - hashy, err := Curve.HashElements([]*big.Int{ + hashy := Curve.HashElements([]*big.Int{ big.NewInt(1953658213), big.NewInt(126947999705460), big.NewInt(1953658213), }) - if err != nil { - t.Errorf("Hasing elements: %v\n", err) - } priv, _ := Curve.GetRandomPrivateKey() x, y, err := Curve.PrivateToPoint(priv) @@ -415,10 +406,9 @@ func TestGeneral_ComputeFact(t *testing.T) { // // none func TestGeneral_BadSignature(t *testing.T) { - hash, err := Curve.PedersenHash([]*big.Int{utils.HexToBN("0x12773"), utils.HexToBN("0x872362")}) - if err != nil { - t.Errorf("Hashing err: %v\n", err) - } + hash := Curve.PedersenHash([]*felt.Felt{utils.TestHexToFelt(t, "0x12773"), utils.TestHexToFelt(t, "0x872362")}...) + hashBytes := hash.Bytes() + hashBigInt := new(big.Int).SetBytes(hashBytes[:]) priv, _ := Curve.GetRandomPrivateKey() x, y, err := Curve.PrivateToPoint(priv) @@ -426,22 +416,22 @@ func TestGeneral_BadSignature(t *testing.T) { t.Errorf("Could not convert random private key to point: %v\n", err) } - r, s, err := Curve.Sign(hash, priv) + r, s, err := Curve.Sign(hashBigInt, priv) if err != nil { t.Errorf("Could not convert gen signature: %v\n", err) } badR := new(big.Int).Add(r, big.NewInt(1)) - if Curve.Verify(hash, badR, s, x, y) { + if Curve.Verify(hashBigInt, badR, s, x, y) { t.Errorf("Verified bad signature %v %v\n", r, s) } badS := new(big.Int).Add(s, big.NewInt(1)) - if Curve.Verify(hash, r, badS, x, y) { + if Curve.Verify(hashBigInt, r, badS, x, y) { t.Errorf("Verified bad signature %v %v\n", r, s) } - badHash := new(big.Int).Add(hash, big.NewInt(1)) + badHash := new(big.Int).Add(hashBigInt, big.NewInt(1)) if Curve.Verify(badHash, r, s, x, y) { t.Errorf("Verified bad signature %v %v\n", r, s) } diff --git a/hash/hash.go b/hash/hash.go index b669753b..301993ae 100644 --- a/hash/hash.go +++ b/hash/hash.go @@ -17,11 +17,7 @@ import ( // - error: an error if any func ComputeHashOnElementsFelt(feltArr []*felt.Felt) (*felt.Felt, error) { bigIntArr := utils.FeltArrToBigIntArr(feltArr) - - hash, err := curve.Curve.ComputeHashOnElements(bigIntArr) - if err != nil { - return nil, err - } + hash := curve.Curve.ComputeHashOnElements(bigIntArr) return utils.BigIntToFelt(hash), nil } diff --git a/merkle/merkle.go b/merkle/merkle.go index b6813bf0..c244b5bd 100644 --- a/merkle/merkle.go +++ b/merkle/merkle.go @@ -45,7 +45,7 @@ func NewFixedSizeMerkleTree(leaves ...*big.Int) (*FixedSizeMerkleTree, error) { // Returns: // - *big.Int: the Merkle hash of the two big integers // - error: an error if the calculation fails -func MerkleHash(x, y *big.Int) (*big.Int, error) { +func MerkleHash(x, y *big.Int) *big.Int { if x.Cmp(y) <= 0 { return curve.Curve.HashElements([]*big.Int{x, y}) } @@ -67,17 +67,12 @@ func (mt *FixedSizeMerkleTree) build(leaves []*big.Int) (*big.Int, error) { newLeaves := []*big.Int{} for i := 0; i < len(leaves); i += 2 { if i+1 == len(leaves) { - hash, err := MerkleHash(leaves[i], big.NewInt(0)) - if err != nil { - return nil, err - } + hash := MerkleHash(leaves[i], big.NewInt(0)) + newLeaves = append(newLeaves, hash) break } - hash, err := MerkleHash(leaves[i], leaves[i+1]) - if err != nil { - return nil, err - } + hash := MerkleHash(leaves[i], leaves[i+1]) newLeaves = append(newLeaves, hash) } return mt.build(newLeaves) @@ -125,10 +120,8 @@ func (mt *FixedSizeMerkleTree) recursiveProof(leaf *big.Int, branchIndex int, ha if index%2 != 0 { nextProof = branch[index-1] } - newLeaf, err := MerkleHash(leaf, nextProof) - if err != nil { - return nil, fmt.Errorf("nextproof error: %v", err) - } + newLeaf := MerkleHash(leaf, nextProof) + newHashPath := append(hashPath, nextProof) return mt.recursiveProof(newLeaf, branchIndex+1, newHashPath) } @@ -150,9 +143,6 @@ func ProofMerklePath(root *big.Int, leaf *big.Int, path []*big.Int) bool { if len(path) == 0 { return root.Cmp(leaf) == 0 } - nexLeaf, err := MerkleHash(leaf, path[0]) - if err != nil { - return false - } + nexLeaf := MerkleHash(leaf, path[0]) return ProofMerklePath(root, nexLeaf, path[1:]) } diff --git a/merkle/merkle_test.go b/merkle/merkle_test.go index f2dd6094..7e241fa4 100644 --- a/merkle/merkle_test.go +++ b/merkle/merkle_test.go @@ -11,7 +11,8 @@ import ( // - t: a pointer to the testing.T object // - proofs: a slice of pointers to big.Int objects representing the proofs // Returns: -// none +// +// none func debugProof(t *testing.T, proofs []*big.Int) { t.Log("...proof") for k, v := range proofs { @@ -26,14 +27,15 @@ func debugProof(t *testing.T, proofs []*big.Int) { // Parameters: // - t: A testing.T object used for reporting test failures and logging. // Returns: -// none +// +// none func TestGeneral_FixedSizeMerkleTree_Check1(t *testing.T) { leaves := []*big.Int{big.NewInt(1), big.NewInt(2), big.NewInt(3), big.NewInt(4), big.NewInt(5), big.NewInt(6), big.NewInt(7)} merkleTree, err := NewFixedSizeMerkleTree(leaves...) - proof_7_0, _ := MerkleHash(big.NewInt(7), big.NewInt(0)) - proof_1_2, _ := MerkleHash(big.NewInt(1), big.NewInt(2)) - proof_3_4, _ := MerkleHash(big.NewInt(3), big.NewInt(4)) - proof_1_2_3_4, _ := MerkleHash(proof_1_2, proof_3_4) + proof_7_0 := MerkleHash(big.NewInt(7), big.NewInt(0)) + proof_1_2 := MerkleHash(big.NewInt(1), big.NewInt(2)) + proof_3_4 := MerkleHash(big.NewInt(3), big.NewInt(4)) + proof_1_2_3_4 := MerkleHash(proof_1_2, proof_3_4) manualProof := []*big.Int{ big.NewInt(6), proof_7_0, diff --git a/typed/typed.go b/typed/typed.go index 67373bcf..c9427cf0 100644 --- a/typed/typed.go +++ b/typed/typed.go @@ -148,19 +148,19 @@ func (td TypedData) GetMessageHash(account *big.Int, msg TypedMessage, sc curve. } elements = append(elements, msgEnc) - hash, err = sc.ComputeHashOnElements(elements) - return hash, err + return sc.ComputeHashOnElements(elements), nil } // GetTypedMessageHash calculates the hash of a typed message using the provided StarkCurve. // // Parameters: -// - inType: the type of the message -// - msg: the typed message -// - sc: the StarkCurve used for hashing +// - inType: the type of the message +// - msg: the typed message +// - sc: the StarkCurve used for hashing +// // Returns: -// - hash: the calculated hash -// - err: any error if any +// - hash: the calculated hash +// - err: any error if any func (td TypedData) GetTypedMessageHash(inType string, msg TypedMessage, sc curve.StarkCurve) (hash *big.Int, err error) { prim := td.Types[inType] elements := []*big.Int{prim.Encoding} @@ -179,15 +179,14 @@ func (td TypedData) GetTypedMessageHash(inType string, msg TypedMessage, sc curv innerElements = append(innerElements, fmtDefinitions...) innerElements = append(innerElements, big.NewInt(int64(len(innerElements)))) - innerHash, err := sc.HashElements(innerElements) + innerHash := sc.HashElements(innerElements) if err != nil { return hash, fmt.Errorf("error hashing internal elements: %v %w", innerElements, err) } elements = append(elements, innerHash) } - hash, err = sc.ComputeHashOnElements(elements) - return hash, err + return sc.ComputeHashOnElements(elements), nil } // GetTypeHash returns the hash of the given type.