Skip to content

Commit

Permalink
use junos pedersen hash funct
Browse files Browse the repository at this point in the history
  • Loading branch information
rian committed Jul 17, 2024
1 parent a1e2e28 commit 1548f9f
Show file tree
Hide file tree
Showing 7 changed files with 84 additions and 150 deletions.
7 changes: 2 additions & 5 deletions contracts/contracts.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,12 +72,9 @@ func PrecomputeAddress(deployerAddress *felt.Felt, salt *felt.Felt, classHash *f
})

constructorCalldataBigIntArr := utils.FeltArrToBigIntArr(constructorCalldata)
constructorCallDataHashInt, _ := curve.Curve.ComputeHashOnElements(constructorCalldataBigIntArr)
constructorCallDataHashInt := curve.Curve.ComputeHashOnElements(constructorCalldataBigIntArr)
bigIntArr = append(bigIntArr, constructorCallDataHashInt)

preBigInt, err := curve.Curve.ComputeHashOnElements(bigIntArr)
if err != nil {
return nil, err
}
preBigInt := curve.Curve.ComputeHashOnElements(bigIntArr)
return utils.BigIntToFelt(preBigInt), nil
}
66 changes: 13 additions & 53 deletions curve/curve.go
Original file line number Diff line number Diff line change
Expand Up @@ -528,19 +528,18 @@ func (sc StarkCurve) SignFelt(msgHash, privKey *felt.Felt) (*felt.Felt, *felt.Fe
// Returns:
// - hash: The hash of the list of elements
// - err: An error if any
func (sc StarkCurve) HashElements(elems []*big.Int) (hash *big.Int, err error) {
func (sc StarkCurve) HashElements(elems []*big.Int) *big.Int {
if len(elems) == 0 {
elems = append(elems, big.NewInt(0))
}

hash = big.NewInt(0)
hash := new(felt.Felt).SetUint64(0)
for _, h := range elems {
hash, err = sc.PedersenHash([]*big.Int{hash, h})
if err != nil {
return hash, err
}
hFelt := new(felt.Felt).SetBytes(h.Bytes())
hash = sc.PedersenHash(hash, hFelt)
}
return hash, err
hashBytes := hash.Bytes()
return new(big.Int).SetBytes(hashBytes[:])
}

// ComputeHashOnElements computes the hash on the given elements using a golang Pedersen Hash implementation.
Expand All @@ -555,60 +554,21 @@ func (sc StarkCurve) HashElements(elems []*big.Int) (hash *big.Int, err error) {
// Returns:
// - hash: The hash of the list of elements
// - err: An error if any
func (sc StarkCurve) ComputeHashOnElements(elems []*big.Int) (hash *big.Int, err error) {
func (sc StarkCurve) ComputeHashOnElements(elems []*big.Int) (hash *big.Int) {
elems = append(elems, big.NewInt(int64(len(elems))))
return Curve.HashElements((elems))
}

// PedersenHash calculates the Pedersen hash of the given elements.
// NOTE: This function assumes the curve has been initialized with constant points
// (ref: https://github.com/seanjameshan/starknet.js/blob/main/src/utils/ellipticCurve.ts)
//
// The function requires that the precomputed constant points have been initiated.
// If the length of `sc.ConstantPoints` is zero, an error is returned.
// The function iterates over the elements in `elems` and performs the Pedersen hash calculation.
// For each element, it checks if the value is within the valid range.
// If the value is invalid, an error is returned.
// For each bit in the element, the function performs an addition operation on `ptx` and `pty`
// using the corresponding constant point from the precomputed constant points.
// If the constant point is a duplicate of `ptx`, an error is returned.
// The function returns the resulting hash and a nil error if the calculation is successful.
// Otherwise, it returns `ptx` and an error describing the issue encountered.
// NOTE: This function just wraps Junos PedersenHash function
// (ref: https://github.com/NethermindEth/juno/blob/main/core/crypto/pedersen_hash.go#L12)
//
// Parameters:
// - elems: An array of big integers representing the elements to hash.
// - elems: An array of felts representing the elements to hash.
// Returns:
// - hash: The resulting Pedersen hash as a big integer.
// - err: An error, if any, encountered during the calculation.
func (sc StarkCurve) PedersenHash(elems []*big.Int) (hash *big.Int, err error) {
if len(sc.ConstantPoints) == 0 {
return hash, fmt.Errorf("must initiate precomputed constant points")
}

ptx := new(big.Int).Set(sc.Gx)
pty := new(big.Int).Set(sc.Gy)
for i, elem := range elems {
x := new(big.Int).Set(elem)

if x.Cmp(big.NewInt(0)) == -1 || x.Cmp(sc.P) >= 0 {
return ptx, fmt.Errorf("invalid x: %v", x)
}

for j := 0; j < 252; j++ {
idx := 2 + (i * 252) + j
xin := new(big.Int).Set(sc.ConstantPoints[idx][0])
yin := new(big.Int).Set(sc.ConstantPoints[idx][1])
if xin.Cmp(ptx) == 0 {
return hash, fmt.Errorf("constant point duplication: %v %v", ptx, xin)
}
if x.Bit(0) == 1 {
ptx, pty = sc.Add(ptx, pty, xin, yin)
}
x = x.Rsh(x, 1)
}
}

return ptx, nil
// - hash: The resulting Pedersen hash as a felt.
func (sc StarkCurve) PedersenHash(felts ...*felt.Felt) (hash *felt.Felt) {
return junoCrypto.PedersenArray(felts...)
}

// PoseidonArray is a function that takes a variadic number of felt.Felt pointers as parameters and
Expand Down
98 changes: 44 additions & 54 deletions curve/curve_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"math/big"
"testing"

"github.com/NethermindEth/juno/core/felt"
"github.com/NethermindEth/starknet.go/utils"
)

Expand All @@ -21,22 +22,21 @@ import (
//
// none
func BenchmarkPedersenHash(b *testing.B) {
suite := [][]*big.Int{
{utils.HexToBN("0x12773"), utils.HexToBN("0x872362")},
{utils.HexToBN("0x1277312773"), utils.HexToBN("0x872362872362")},
{utils.HexToBN("0x1277312773"), utils.HexToBN("0xCD2a3d9F938E13CD947Ec05AbC7FE734Df8DD826")},
{utils.HexToBN("0xbBbBBBBbbBBBbbbBbbBbbbbBBbBbbbbBbBbbBBbB"), utils.HexToBN("0x872362872362")},
{utils.HexToBN("0xCD2a3d9F938E13CD947Ec05AbC7FE734Df8DD826"), utils.HexToBN("0xbBbBBBBbbBBBbbbBbbBbbbbBBbBbbbbBbBbbBBbB")},
{utils.HexToBN("0x7f15c38ea577a26f4f553282fcfe4f1feeb8ecfaad8f221ae41abf8224cbddd"), utils.HexToBN("0x13d41f388b8ea4db56c5aa6562f13359fab192b3db57651af916790f9debee9")},
{utils.HexToBN("0x7f15c38ea577a26f4f553282fcfe4f1feeb8ecfaad8f221ae41abf8224cbddd"), utils.HexToBN("0x7f15c38ea577a26f4f553282fcfe4f1feeb8ecfaad8f221ae41abf8224cbdde")},
suite := [][]*felt.Felt{
{utils.TestHexToFelt(b, "0x12773"), utils.TestHexToFelt(b, "0x872362")},
{utils.TestHexToFelt(b, "0x1277312773"), utils.TestHexToFelt(b, "0x872362872362")},
{utils.TestHexToFelt(b, "0x1277312773"), utils.TestHexToFelt(b, "0xCD2a3d9F938E13CD947Ec05AbC7FE734Df8DD826")},
{utils.TestHexToFelt(b, "0xbBbBBBBbbBBBbbbBbbBbbbbBBbBbbbbBbBbbBBbB"), utils.TestHexToFelt(b, "0x872362872362")},
{utils.TestHexToFelt(b, "0xCD2a3d9F938E13CD947Ec05AbC7FE734Df8DD826"), utils.TestHexToFelt(b, "0xbBbBBBBbbBBBbbbBbbBbbbbBBbBbbbbBbBbbBBbB")},
{utils.TestHexToFelt(b, "0x7f15c38ea577a26f4f553282fcfe4f1feeb8ecfaad8f221ae41abf8224cbddd"), utils.TestHexToFelt(b, "0x13d41f388b8ea4db56c5aa6562f13359fab192b3db57651af916790f9debee9")},
{utils.TestHexToFelt(b, "0x7f15c38ea577a26f4f553282fcfe4f1feeb8ecfaad8f221ae41abf8224cbddd"), utils.TestHexToFelt(b, "0x7f15c38ea577a26f4f553282fcfe4f1feeb8ecfaad8f221ae41abf8224cbdde")},
}

for _, test := range suite {
b.Run(fmt.Sprintf("input_size_%d_%d", test[0].BitLen(), test[1].BitLen()), func(b *testing.B) {
if _, err := Curve.PedersenHash(test); err != nil {
log.Fatal(err)
}
})
b.Run(
fmt.Sprintf("input_size_%d_%d", len(test[0].Bits()), len(test[1].Bits())), func(b *testing.B) {
Curve.PedersenHash(test...)
})
}
}

Expand Down Expand Up @@ -92,21 +92,22 @@ func BenchmarkSignatureVerify(b *testing.B) {
private, _ := Curve.GetRandomPrivateKey()
x, y, _ := Curve.PrivateToPoint(private)

hash, _ := Curve.PedersenHash(
[]*big.Int{
utils.HexToBN("0x7f15c38ea577a26f4f553282fcfe4f1feeb8ecfaad8f221ae41abf8224cbddd"),
utils.HexToBN("0x7f15c38ea577a26f4f553282fcfe4f1feeb8ecfaad8f221ae41abf8224cbdde"),
})

r, s, _ := Curve.Sign(hash, private)

b.Run(fmt.Sprintf("sign_input_size_%d", hash.BitLen()), func(b *testing.B) {
if _, _, err := Curve.Sign(hash, private); err != nil {
hash := Curve.PedersenHash(
[]*felt.Felt{
utils.TestHexToFelt(b, "0x7f15c38ea577a26f4f553282fcfe4f1feeb8ecfaad8f221ae41abf8224cbddd"),
utils.TestHexToFelt(b, "0x7f15c38ea577a26f4f553282fcfe4f1feeb8ecfaad8f221ae41abf8224cbdde"),
}...)
hashBytes := hash.Bytes()
hashBigInt := new(big.Int).SetBytes(hashBytes[:])
r, s, _ := Curve.Sign(hashBigInt, private)

b.Run(fmt.Sprintf("sign_input_size_%d", hashBigInt.BitLen()), func(b *testing.B) {
if _, _, err := Curve.Sign(hashBigInt, private); err != nil {
log.Fatal(err)
}
})
b.Run(fmt.Sprintf("verify_input_size_%d", hash.BitLen()), func(b *testing.B) {
Curve.Verify(hash, r, s, x, y)
b.Run(fmt.Sprintf("verify_input_size_%d", hashBigInt.BitLen()), func(b *testing.B) {
Curve.Verify(hashBigInt, r, s, x, y)
})
}

Expand Down Expand Up @@ -140,29 +141,28 @@ func TestGeneral_PrivateToPoint(t *testing.T) {
// none
func TestGeneral_PedersenHash(t *testing.T) {
testPedersen := []struct {
elements []*big.Int
elements []*felt.Felt
expected *big.Int
}{
{
elements: []*big.Int{utils.HexToBN("0x12773"), utils.HexToBN("0x872362")},
elements: []*felt.Felt{utils.TestHexToFelt(t, "0x12773"), utils.TestHexToFelt(t, "0x872362")},
expected: utils.HexToBN("0x5ed2703dfdb505c587700ce2ebfcab5b3515cd7e6114817e6026ec9d4b364ca"),
},
{
elements: []*big.Int{utils.HexToBN("0x13d41f388b8ea4db56c5aa6562f13359fab192b3db57651af916790f9debee9"), utils.HexToBN("0x537461726b4e6574204d61696c")},
elements: []*felt.Felt{utils.TestHexToFelt(t, "0x13d41f388b8ea4db56c5aa6562f13359fab192b3db57651af916790f9debee9"), utils.TestHexToFelt(t, "0x537461726b4e6574204d61696c")},
expected: utils.HexToBN("0x180c0a3d13c1adfaa5cbc251f4fc93cc0e26cec30ca4c247305a7ce50ac807c"),
},
{
elements: []*big.Int{big.NewInt(100), big.NewInt(1000)},
elements: []*felt.Felt{new(felt.Felt).SetUint64(100), new(felt.Felt).SetUint64(1000)},
expected: utils.HexToBN("0x45a62091df6da02dce4250cb67597444d1f465319908486b836f48d0f8bf6e7"),
},
}

for _, tt := range testPedersen {
hash, err := Curve.PedersenHash(tt.elements)
if err != nil {
t.Errorf("Hashing err: %v\n", err)
}
if hash.Cmp(tt.expected) != 0 {
hash := Curve.PedersenHash(tt.elements...)
hashBytes := hash.Bytes()
hashBigInt := new(big.Int).SetBytes(hashBytes[:])
if hashBigInt.Cmp(tt.expected) != 0 {
t.Errorf("incorrect hash: got %v expected %v\n", hash, tt.expected)
}
}
Expand Down Expand Up @@ -312,25 +312,19 @@ func TestGeneral_MultAir(t *testing.T) {
//
// none
func TestGeneral_ComputeHashOnElements(t *testing.T) {
hashEmptyArray, err := Curve.ComputeHashOnElements([]*big.Int{})
hashEmptyArray := Curve.ComputeHashOnElements([]*big.Int{})
expectedHashEmmptyArray := utils.HexToBN("0x49ee3eba8c1600700ee1b87eb599f16716b0b1022947733551fde4050ca6804")
if err != nil {
t.Errorf("Could no hash an empty array %v\n", err)
}
if hashEmptyArray.Cmp(expectedHashEmmptyArray) != 0 {
t.Errorf("Hash empty array wrong value. Expected %v got %v\n", expectedHashEmmptyArray, hashEmptyArray)
}

hashFilledArray, err := Curve.ComputeHashOnElements([]*big.Int{
hashFilledArray := Curve.ComputeHashOnElements([]*big.Int{
big.NewInt(123782376),
big.NewInt(213984),
big.NewInt(128763521321),
})
expectedHashFilledArray := utils.HexToBN("0x7b422405da6571242dfc245a43de3b0fe695e7021c148b918cd9cdb462cac59")

if err != nil {
t.Errorf("Could no hash an array with values %v\n", err)
}
if hashFilledArray.Cmp(expectedHashFilledArray) != 0 {
t.Errorf("Hash filled array wrong value. Expected %v got %v\n", expectedHashFilledArray, hashFilledArray)
}
Expand All @@ -344,14 +338,11 @@ func TestGeneral_ComputeHashOnElements(t *testing.T) {
//
// none
func TestGeneral_HashAndSign(t *testing.T) {
hashy, err := Curve.HashElements([]*big.Int{
hashy := Curve.HashElements([]*big.Int{
big.NewInt(1953658213),
big.NewInt(126947999705460),
big.NewInt(1953658213),
})
if err != nil {
t.Errorf("Hasing elements: %v\n", err)
}

priv, _ := Curve.GetRandomPrivateKey()
x, y, err := Curve.PrivateToPoint(priv)
Expand Down Expand Up @@ -415,33 +406,32 @@ func TestGeneral_ComputeFact(t *testing.T) {
//
// none
func TestGeneral_BadSignature(t *testing.T) {
hash, err := Curve.PedersenHash([]*big.Int{utils.HexToBN("0x12773"), utils.HexToBN("0x872362")})
if err != nil {
t.Errorf("Hashing err: %v\n", err)
}
hash := Curve.PedersenHash([]*felt.Felt{utils.TestHexToFelt(t, "0x12773"), utils.TestHexToFelt(t, "0x872362")}...)
hashBytes := hash.Bytes()
hashBigInt := new(big.Int).SetBytes(hashBytes[:])

priv, _ := Curve.GetRandomPrivateKey()
x, y, err := Curve.PrivateToPoint(priv)
if err != nil {
t.Errorf("Could not convert random private key to point: %v\n", err)
}

r, s, err := Curve.Sign(hash, priv)
r, s, err := Curve.Sign(hashBigInt, priv)
if err != nil {
t.Errorf("Could not convert gen signature: %v\n", err)
}

badR := new(big.Int).Add(r, big.NewInt(1))
if Curve.Verify(hash, badR, s, x, y) {
if Curve.Verify(hashBigInt, badR, s, x, y) {
t.Errorf("Verified bad signature %v %v\n", r, s)
}

badS := new(big.Int).Add(s, big.NewInt(1))
if Curve.Verify(hash, r, badS, x, y) {
if Curve.Verify(hashBigInt, r, badS, x, y) {
t.Errorf("Verified bad signature %v %v\n", r, s)
}

badHash := new(big.Int).Add(hash, big.NewInt(1))
badHash := new(big.Int).Add(hashBigInt, big.NewInt(1))
if Curve.Verify(badHash, r, s, x, y) {
t.Errorf("Verified bad signature %v %v\n", r, s)
}
Expand Down
6 changes: 1 addition & 5 deletions hash/hash.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,7 @@ import (
// - error: an error if any
func ComputeHashOnElementsFelt(feltArr []*felt.Felt) (*felt.Felt, error) {
bigIntArr := utils.FeltArrToBigIntArr(feltArr)

hash, err := curve.Curve.ComputeHashOnElements(bigIntArr)
if err != nil {
return nil, err
}
hash := curve.Curve.ComputeHashOnElements(bigIntArr)
return utils.BigIntToFelt(hash), nil
}

Expand Down
24 changes: 7 additions & 17 deletions merkle/merkle.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ func NewFixedSizeMerkleTree(leaves ...*big.Int) (*FixedSizeMerkleTree, error) {
// Returns:
// - *big.Int: the Merkle hash of the two big integers
// - error: an error if the calculation fails
func MerkleHash(x, y *big.Int) (*big.Int, error) {
func MerkleHash(x, y *big.Int) *big.Int {
if x.Cmp(y) <= 0 {
return curve.Curve.HashElements([]*big.Int{x, y})
}
Expand All @@ -67,17 +67,12 @@ func (mt *FixedSizeMerkleTree) build(leaves []*big.Int) (*big.Int, error) {
newLeaves := []*big.Int{}
for i := 0; i < len(leaves); i += 2 {
if i+1 == len(leaves) {
hash, err := MerkleHash(leaves[i], big.NewInt(0))
if err != nil {
return nil, err
}
hash := MerkleHash(leaves[i], big.NewInt(0))

newLeaves = append(newLeaves, hash)
break
}
hash, err := MerkleHash(leaves[i], leaves[i+1])
if err != nil {
return nil, err
}
hash := MerkleHash(leaves[i], leaves[i+1])
newLeaves = append(newLeaves, hash)
}
return mt.build(newLeaves)
Expand Down Expand Up @@ -125,10 +120,8 @@ func (mt *FixedSizeMerkleTree) recursiveProof(leaf *big.Int, branchIndex int, ha
if index%2 != 0 {
nextProof = branch[index-1]
}
newLeaf, err := MerkleHash(leaf, nextProof)
if err != nil {
return nil, fmt.Errorf("nextproof error: %v", err)
}
newLeaf := MerkleHash(leaf, nextProof)

newHashPath := append(hashPath, nextProof)
return mt.recursiveProof(newLeaf, branchIndex+1, newHashPath)
}
Expand All @@ -150,9 +143,6 @@ func ProofMerklePath(root *big.Int, leaf *big.Int, path []*big.Int) bool {
if len(path) == 0 {
return root.Cmp(leaf) == 0
}
nexLeaf, err := MerkleHash(leaf, path[0])
if err != nil {
return false
}
nexLeaf := MerkleHash(leaf, path[0])
return ProofMerklePath(root, nexLeaf, path[1:])
}
Loading

0 comments on commit 1548f9f

Please sign in to comment.