diff --git a/examples/rollup/circuit.go b/examples/rollup/circuit.go index 8d15a115ae..68b94a6bc0 100644 --- a/examples/rollup/circuit.go +++ b/examples/rollup/circuit.go @@ -17,6 +17,7 @@ limitations under the License. package rollup import ( + tedwards "github.com/consensys/gnark-crypto/ecc/twistededwards" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/std/accumulator/merkle" "github.com/consensys/gnark/std/algebra/twistededwards" @@ -87,18 +88,8 @@ type TransferConstraints struct { } func (circuit *Circuit) postInit(api frontend.API) error { - // edward curve params - params, err := twistededwards.NewEdCurve(api.Compiler().Curve()) - if err != nil { - return err - } for i := 0; i < batchSize; i++ { - // setting sender public key - circuit.PublicKeysSender[i].Curve = params - - // setting receiver public key - circuit.PublicKeysReceiver[i].Curve = params // setting the sender accounts before update circuit.SenderAccountsBefore[i].PubKey = circuit.PublicKeysSender[i] @@ -163,7 +154,13 @@ func verifyTransferSignature(api frontend.API, t TransferConstraints, hFunc mimc hFunc.Write(t.Nonce, t.Amount, t.SenderPubKey.A.X, t.SenderPubKey.A.Y, t.ReceiverPubKey.A.X, t.ReceiverPubKey.A.Y) htransfer := hFunc.Sum() - err := eddsa.Verify(api, t.Signature, htransfer, t.SenderPubKey) + curve, err := twistededwards.NewEdCurve(api, tedwards.BN254) + if err != nil { + return err + } + + hFunc.Reset() + err = eddsa.Verify(curve, t.Signature, htransfer, t.SenderPubKey, &hFunc) if err != nil { return err } diff --git a/examples/rollup/rollup_test.go b/examples/rollup/rollup_test.go index 6759b30f10..d4a6d74360 100644 --- a/examples/rollup/rollup_test.go +++ b/examples/rollup/rollup_test.go @@ -154,7 +154,12 @@ func createAccount(i int) (Account, eddsa.PrivateKey) { src := rand.NewSource(int64(i)) r := rand.New(src) - privkey, _ = eddsa.GenerateKey(r) + pkey, err := eddsa.GenerateKey(r) + if err != nil { + panic(err) + } + privkey = *pkey + acc.pubKey = privkey.PublicKey return acc, privkey diff --git a/frontend/compile.go b/frontend/compile.go index a0f98c19f7..c6d1140b87 100644 --- a/frontend/compile.go +++ b/frontend/compile.go @@ -81,6 +81,7 @@ func parseCircuit(builder Builder, circuit Circuit) (err error) { // leafs are Constraints that need to be initialized in the context of compiling a circuit var handler schema.LeafHandler = func(visibility schema.Visibility, name string, tInput reflect.Value) error { if tInput.CanSet() { + // log.Trace().Str("name", name).Str("visibility", visibility.String()).Msg("init input wire") switch visibility { case schema.Secret: tInput.Set(reflect.ValueOf(builder.AddSecretVariable(name))) diff --git a/go.mod b/go.mod index f1635be729..d07a8940e2 100644 --- a/go.mod +++ b/go.mod @@ -4,8 +4,8 @@ go 1.17 require ( github.com/consensys/bavard v0.1.10 - github.com/consensys/gnark-crypto v0.6.1 - github.com/fxamacker/cbor/v2 v2.4.0 + github.com/consensys/gnark-crypto v0.6.2-0.20220317143658-fb0d80a11bf4 + github.com/fxamacker/cbor/v2 v2.2.0 github.com/leanovate/gopter v0.2.9 github.com/rs/zerolog v1.26.1 github.com/stretchr/testify v1.7.1 diff --git a/go.sum b/go.sum index c184a0164e..94c6aeee84 100644 --- a/go.sum +++ b/go.sum @@ -1,15 +1,14 @@ -github.com/consensys/bavard v0.1.9/go.mod h1:9ItSMtA/dXMAiL7BG6bqW2m3NdSEObYWoH223nGHukI= github.com/consensys/bavard v0.1.10 h1:1I/IvY7bkX/O7QLNCEuV2+YBKdTetzw3gnBbvFaWiEE= github.com/consensys/bavard v0.1.10/go.mod h1:9ItSMtA/dXMAiL7BG6bqW2m3NdSEObYWoH223nGHukI= -github.com/consensys/gnark-crypto v0.6.1 h1:MuWaJyWzSw8wQUOfiZOlRwYjfweIj8dM/u2NN6m0O04= -github.com/consensys/gnark-crypto v0.6.1/go.mod h1:s41Bl3YIpNgu/zdvlSzf/xZkyV8MUmoBY96RmuB8x70= +github.com/consensys/gnark-crypto v0.6.2-0.20220317143658-fb0d80a11bf4 h1:ZsuTwNqDe83xtYP8SplQ9iOoXgOoLg9WzP04VfqOjGc= +github.com/consensys/gnark-crypto v0.6.2-0.20220317143658-fb0d80a11bf4/go.mod h1:BnexKTAHX6j7zpGXR/s6E/R0tyYtbnXlbhIMQkNdcPs= github.com/coreos/go-systemd/v22 v22.3.2/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/fxamacker/cbor/v2 v2.4.0 h1:ri0ArlOR+5XunOP8CRUowT0pSJOwhW098ZCUyskZD88= -github.com/fxamacker/cbor/v2 v2.4.0/go.mod h1:TA1xS00nchWmaBnEIxPSE5oHLuJBAVvqrtAnWBwBCVo= +github.com/fxamacker/cbor/v2 v2.2.0 h1:6eXqdDDe588rSYAi1HfZKbx6YYQO4mxQ9eC6xYpU/JQ= +github.com/fxamacker/cbor/v2 v2.2.0/go.mod h1:TA1xS00nchWmaBnEIxPSE5oHLuJBAVvqrtAnWBwBCVo= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= diff --git a/internal/stats/latest.stats b/internal/stats/latest.stats index dd68123181..4959594248 100644 Binary files a/internal/stats/latest.stats and b/internal/stats/latest.stats differ diff --git a/std/algebra/twistededwards/bandersnatch/curve.go b/std/algebra/twistededwards/bandersnatch/curve.go deleted file mode 100644 index 0cfb3c071e..0000000000 --- a/std/algebra/twistededwards/bandersnatch/curve.go +++ /dev/null @@ -1,76 +0,0 @@ -/* -Copyright © 2020 ConsenSys - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package bandersnatch - -import ( - "errors" - "math/big" - - "github.com/consensys/gnark-crypto/ecc" - bandersnatch "github.com/consensys/gnark-crypto/ecc/bls12-381/twistededwards/bandersnatch" - "github.com/consensys/gnark/internal/utils" -) - -// Coordinates of a point on a twisted Edwards curve -type Coord struct { - X, Y big.Int -} - -// EdCurve stores the info on the chosen edwards curve -type EdCurve struct { - A, D, Cofactor, Order big.Int - Base Coord - ID ecc.ID -} - -var constructors map[ecc.ID]func() EdCurve - -func init() { - constructors = map[ecc.ID]func() EdCurve{ - ecc.BLS12_381: newBandersnatch, - } -} - -// NewEdCurve returns an Edwards curve parameters -func NewEdCurve(id ecc.ID) (EdCurve, error) { - if constructor, ok := constructors[id]; ok { - return constructor(), nil - } - return EdCurve{}, errors.New("unknown curve id") -} - -// ------------------------------------------------------------------------------------------------- -// constructors - -func newBandersnatch() EdCurve { - - edcurve := bandersnatch.GetEdwardsCurve() - edcurve.Cofactor.FromMont() - - return EdCurve{ - A: utils.FromInterface(edcurve.A), - D: utils.FromInterface(edcurve.D), - Cofactor: utils.FromInterface(edcurve.Cofactor), - Order: utils.FromInterface(edcurve.Order), - Base: Coord{ - X: utils.FromInterface(edcurve.Base.X), - Y: utils.FromInterface(edcurve.Base.Y), - }, - ID: ecc.BLS12_381, - } - -} diff --git a/std/algebra/twistededwards/bandersnatch/point.go b/std/algebra/twistededwards/bandersnatch/point.go deleted file mode 100644 index 61eb558522..0000000000 --- a/std/algebra/twistededwards/bandersnatch/point.go +++ /dev/null @@ -1,179 +0,0 @@ -/* -Copyright © 2020 ConsenSys - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package bandersnatch - -import ( - "math/big" - - "github.com/consensys/gnark/frontend" -) - -// Point point on a twisted Edwards curve in a Snark cs -type Point struct { - X, Y frontend.Variable -} - -// Neg computes the negative of a point in SNARK coordinates -func (p *Point) Neg(api frontend.API, p1 *Point) *Point { - p.X = api.Neg(p1.X) - p.Y = p1.Y - return p -} - -// MustBeOnCurve checks if a point is on the reduced twisted Edwards curve -// a*x² + y² = 1 + d*x²*y². -func (p *Point) MustBeOnCurve(api frontend.API, curve EdCurve) { - - one := big.NewInt(1) - - xx := api.Mul(p.X, p.X) - yy := api.Mul(p.Y, p.Y) - axx := api.Mul(xx, &curve.A) - lhs := api.Add(axx, yy) - - dxx := api.Mul(xx, &curve.D) - dxxyy := api.Mul(dxx, yy) - rhs := api.Add(dxxyy, one) - - api.AssertIsEqual(lhs, rhs) - -} - -// Add Adds two points on a twisted edwards curve (eg jubjub) -// p1, p2, c are respectively: the point to add, a known base point, and the parameters of the twisted edwards curve -func (p *Point) Add(api frontend.API, p1, p2 *Point, curve EdCurve) *Point { - - // u = (x1 + y1) * (x2 + y2) - u1 := api.Mul(p1.X, &curve.A) - u1 = api.Sub(p1.Y, u1) - u2 := api.Add(p2.X, p2.Y) - u := api.Mul(u1, u2) - - // v0 = x1 * y2 - v0 := api.Mul(p2.Y, p1.X) - - // v1 = x2 * y1 - v1 := api.Mul(p2.X, p1.Y) - - // v2 = d * v0 * v1 - v2 := api.Mul(&curve.D, v0, v1) - - // x = (v0 + v1) / (1 + v2) - p.X = api.Add(v0, v1) - p.X = api.DivUnchecked(p.X, api.Add(1, v2)) - - // y = (u + a * v0 - v1) / (1 - v2) - p.Y = api.Mul(&curve.A, v0) - p.Y = api.Sub(p.Y, v1) - p.Y = api.Add(p.Y, u) - p.Y = api.DivUnchecked(p.Y, api.Sub(1, v2)) - - return p -} - -// Double doubles a points in SNARK coordinates -func (p *Point) Double(api frontend.API, p1 *Point, curve EdCurve) *Point { - - u := api.Mul(p1.X, p1.Y) - v := api.Mul(p1.X, p1.X) - w := api.Mul(p1.Y, p1.Y) - - n1 := api.Mul(2, u) - av := api.Mul(v, &curve.A) - n2 := api.Sub(w, av) - d1 := api.Add(w, av) - d2 := api.Sub(2, d1) - - p.X = api.DivUnchecked(n1, d1) - p.Y = api.DivUnchecked(n2, d2) - - return p -} - -// ScalarMul computes the scalar multiplication of a point on a twisted Edwards curve -// p1: base point (as snark point) -// curve: parameters of the Edwards curve -// scal: scalar as a SNARK constraint -// Standard left to right double and add -func (p *Point) ScalarMul(api frontend.API, p1 *Point, scalar frontend.Variable, curve EdCurve) *Point { - - // first unpack the scalar - b := api.ToBinary(scalar) - - res := Point{} - tmp := Point{} - A := Point{} - B := Point{} - - A.Double(api, p1, curve) - B.Add(api, &A, p1, curve) - - n := len(b) - 1 - res.X = api.Lookup2(b[n], b[n-1], 0, A.X, p1.X, B.X) - res.Y = api.Lookup2(b[n], b[n-1], 1, A.Y, p1.Y, B.Y) - - for i := n - 2; i >= 1; i -= 2 { - res.Double(api, &res, curve). - Double(api, &res, curve) - tmp.X = api.Lookup2(b[i], b[i-1], 0, A.X, p1.X, B.X) - tmp.Y = api.Lookup2(b[i], b[i-1], 1, A.Y, p1.Y, B.Y) - res.Add(api, &res, &tmp, curve) - } - - if n%2 == 0 { - res.Double(api, &res, curve) - tmp.Add(api, &res, p1, curve) - res.X = api.Select(b[0], tmp.X, res.X) - res.Y = api.Select(b[0], tmp.Y, res.Y) - } - - p.X = res.X - p.Y = res.Y - - return p -} - -// DoubleBaseScalarMul computes s1*P1+s2*P2 -// where P1 and P2 are points on a twisted Edwards curve -// and s1, s2 scalars. -func (p *Point) DoubleBaseScalarMul(api frontend.API, p1, p2 *Point, s1, s2 frontend.Variable, curve EdCurve) *Point { - - // first unpack the scalars - b1 := api.ToBinary(s1) - b2 := api.ToBinary(s2) - - res := Point{} - tmp := Point{} - sum := Point{} - sum.Add(api, p1, p2, curve) - - n := len(b1) - res.X = api.Lookup2(b1[n-1], b2[n-1], 0, p1.X, p2.X, sum.X) - res.Y = api.Lookup2(b1[n-1], b2[n-1], 1, p1.Y, p2.Y, sum.Y) - - for i := n - 2; i >= 0; i-- { - res.Double(api, &res, curve) - tmp.X = api.Lookup2(b1[i], b2[i], 0, p1.X, p2.X, sum.X) - tmp.Y = api.Lookup2(b1[i], b2[i], 1, p1.Y, p2.Y, sum.Y) - res.Add(api, &res, &tmp, curve) - } - - p.X = res.X - p.Y = res.Y - - return p -} diff --git a/std/algebra/twistededwards/bandersnatch/point_test.go b/std/algebra/twistededwards/bandersnatch/point_test.go deleted file mode 100644 index 2c07b89d57..0000000000 --- a/std/algebra/twistededwards/bandersnatch/point_test.go +++ /dev/null @@ -1,399 +0,0 @@ -/* -Copyright © 2020 ConsenSys - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package bandersnatch - -import ( - "math/big" - "testing" - - "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark-crypto/ecc/bls12-381/twistededwards/bandersnatch" - "github.com/consensys/gnark/frontend" - "github.com/consensys/gnark/frontend/cs/r1cs" - "github.com/consensys/gnark/test" -) - -type mustBeOnCurve struct { - P Point -} - -func (circuit *mustBeOnCurve) Define(api frontend.API) error { - - // get edwards curve params - params, err := NewEdCurve(api.Compiler().Curve()) - if err != nil { - return err - } - - circuit.P.MustBeOnCurve(api, params) - - return nil -} - -func TestIsOnCurve(t *testing.T) { - - assert := test.NewAssert(t) - - var circuit, witness mustBeOnCurve - - params, err := NewEdCurve(ecc.BLS12_381) - if err != nil { - t.Fatal(err) - } - - witness.P.X = (params.Base.X) - witness.P.Y = (params.Base.Y) - - assert.SolvingSucceeded(&circuit, &witness, test.WithCurves(ecc.BLS12_381)) - -} - -type add struct { - P, E Point -} - -func (circuit *add) Define(api frontend.API) error { - - // get edwards curve params - params, err := NewEdCurve(api.Compiler().Curve()) - if err != nil { - return err - } - - p := Point{} - p.X = params.Base.X - p.Y = params.Base.Y - res := circuit.P.Add(api, &circuit.P, &p, params) - - api.AssertIsEqual(res.X, circuit.E.X) - api.AssertIsEqual(res.Y, circuit.E.Y) - - return nil -} - -func TestAddFixedPoint(t *testing.T) { - - assert := test.NewAssert(t) - - var circuit, witness add - - // generate a random point, and compute expected_point = base + random_point - params, err := NewEdCurve(ecc.BLS12_381) - if err != nil { - t.Fatal(err) - } - var base, point, expected bandersnatch.PointAffine - base.X.SetBigInt(¶ms.Base.X) - base.Y.SetBigInt(¶ms.Base.Y) - point.Set(&base) - r := big.NewInt(5) - point.ScalarMul(&point, r) - expected.Add(&base, &point) - - // populate witness - witness.P.X = (point.X.String()) - witness.P.Y = (point.Y.String()) - witness.E.X = (expected.X.String()) - witness.E.Y = (expected.Y.String()) - - // creates r1cs - assert.SolvingSucceeded(&circuit, &witness, test.WithCurves(ecc.BLS12_381)) - -} - -//------------------------------------------------------------- -// addGeneric - -type addGeneric struct { - P1, P2, E Point -} - -func (circuit *addGeneric) Define(api frontend.API) error { - - // get edwards curve params - params, err := NewEdCurve(api.Compiler().Curve()) - if err != nil { - return err - } - - res := circuit.P1.Add(api, &circuit.P1, &circuit.P2, params) - - api.AssertIsEqual(res.X, circuit.E.X) - api.AssertIsEqual(res.Y, circuit.E.Y) - - return nil -} - -func TestAddGeneric(t *testing.T) { - - assert := test.NewAssert(t) - var circuit, witness addGeneric - - // generate random points, and compute expected_point = point1 + point2s - params, err := NewEdCurve(ecc.BLS12_381) - if err != nil { - t.Fatal(err) - } - var point1, point2, expected bandersnatch.PointAffine - point1.X.SetBigInt(¶ms.Base.X) - point1.Y.SetBigInt(¶ms.Base.Y) - point2.Set(&point1) - r1 := big.NewInt(5) - r2 := big.NewInt(12) - point1.ScalarMul(&point1, r1) - point2.ScalarMul(&point2, r2) - expected.Add(&point1, &point2) - - // populate witness - witness.P1.X = (point1.X.String()) - witness.P1.Y = (point1.Y.String()) - witness.P2.X = (point2.X.String()) - witness.P2.Y = (point2.Y.String()) - witness.E.X = (expected.X.String()) - witness.E.Y = (expected.Y.String()) - - // creates r1cs - assert.SolvingSucceeded(&circuit, &witness, test.WithCurves(ecc.BLS12_381)) - -} - -//------------------------------------------------------------- -// Double -type double struct { - P, E Point -} - -func (circuit *double) Define(api frontend.API) error { - - // get edwards curve params - params, err := NewEdCurve(api.Compiler().Curve()) - if err != nil { - return err - } - - res := circuit.P.Double(api, &circuit.P, params) - - api.AssertIsEqual(res.X, circuit.E.X) - api.AssertIsEqual(res.Y, circuit.E.Y) - - return nil -} - -func TestDouble(t *testing.T) { - - assert := test.NewAssert(t) - - var circuit, witness double - - // generate witness data - params, err := NewEdCurve(ecc.BLS12_381) - if err != nil { - t.Fatal(err) - } - var base, expected bandersnatch.PointAffine - base.X.SetBigInt(¶ms.Base.X) - base.Y.SetBigInt(¶ms.Base.Y) - expected.Double(&base) - - // populate witness - witness.P.X = (base.X.String()) - witness.P.Y = (base.Y.String()) - witness.E.X = (expected.X.String()) - witness.E.Y = (expected.Y.String()) - - // creates r1cs - assert.SolvingSucceeded(&circuit, &witness, test.WithCurves(ecc.BLS12_381)) - -} - -type scalarMulFixed struct { - E Point - S frontend.Variable -} - -func (circuit *scalarMulFixed) Define(api frontend.API) error { - - // get edwards curve params - params, err := NewEdCurve(api.Compiler().Curve()) - if err != nil { - return err - } - - var resFixed, p Point - p.X = params.Base.X - p.Y = params.Base.Y - resFixed.ScalarMul(api, &p, circuit.S, params) - - api.AssertIsEqual(resFixed.X, circuit.E.X) - api.AssertIsEqual(resFixed.Y, circuit.E.Y) - - return nil -} - -func TestScalarMulFixed(t *testing.T) { - - assert := test.NewAssert(t) - - var circuit, witness scalarMulFixed - - // generate witness data - params, err := NewEdCurve(ecc.BLS12_381) - if err != nil { - t.Fatal(err) - } - var base, expected bandersnatch.PointAffine - base.X.SetBigInt(¶ms.Base.X) - base.Y.SetBigInt(¶ms.Base.Y) - r := big.NewInt(928323002) - expected.ScalarMul(&base, r) - - // populate witness - witness.E.X = (expected.X.String()) - witness.E.Y = (expected.Y.String()) - witness.S = (r) - - // creates r1cs - assert.SolvingSucceeded(&circuit, &witness, test.WithCurves(ecc.BLS12_381)) - -} - -type scalarMulGeneric struct { - P, E Point - S frontend.Variable -} - -func (circuit *scalarMulGeneric) Define(api frontend.API) error { - - // get edwards curve params - params, err := NewEdCurve(api.Compiler().Curve()) - if err != nil { - return err - } - - resGeneric := circuit.P.ScalarMul(api, &circuit.P, circuit.S, params) - - api.AssertIsEqual(resGeneric.X, circuit.E.X) - api.AssertIsEqual(resGeneric.Y, circuit.E.Y) - - return nil -} - -func TestScalarMulGeneric(t *testing.T) { - - assert := test.NewAssert(t) - - var circuit, witness scalarMulGeneric - - // generate witness data - params, err := NewEdCurve(ecc.BLS12_381) - if err != nil { - t.Fatal(err) - } - var base, point, expected bandersnatch.PointAffine - base.X.SetBigInt(¶ms.Base.X) - base.Y.SetBigInt(¶ms.Base.Y) - s := big.NewInt(902) - point.ScalarMul(&base, s) // random point - r := big.NewInt(230928302) - expected.ScalarMul(&point, r) - - // populate witness - witness.P.X = (point.X.String()) - witness.P.Y = (point.Y.String()) - witness.E.X = (expected.X.String()) - witness.E.Y = (expected.Y.String()) - witness.S = (r) - - // creates r1cs - assert.SolvingSucceeded(&circuit, &witness, test.WithCurves(ecc.BLS12_381)) - -} - -type neg struct { - P, E Point -} - -func (circuit *neg) Define(api frontend.API) error { - - circuit.P.Neg(api, &circuit.P) - api.AssertIsEqual(circuit.P.X, circuit.E.X) - api.AssertIsEqual(circuit.P.Y, circuit.E.Y) - - return nil -} - -func TestNeg(t *testing.T) { - - assert := test.NewAssert(t) - - // generate witness data - params, err := NewEdCurve(ecc.BLS12_381) - if err != nil { - t.Fatal(err) - } - var base, expected bandersnatch.PointAffine - base.X.SetBigInt(¶ms.Base.X) - base.Y.SetBigInt(¶ms.Base.Y) - expected.Neg(&base) - - // generate witness - var circuit, witness neg - witness.P.X = (base.X) - witness.P.Y = (base.Y) - witness.E.X = (expected.X) - witness.E.Y = (expected.Y) - - assert.SolvingSucceeded(&circuit, &witness, test.WithCurves(ecc.BLS12_381)) - -} - -// Bench -func BenchmarkDouble(b *testing.B) { - var c double - ccsBench, _ := frontend.Compile(ecc.BLS12_381, r1cs.NewBuilder, &c) - b.Log("groth16", ccsBench.GetNbConstraints()) -} - -func BenchmarkAddGeneric(b *testing.B) { - var c addGeneric - ccsBench, _ := frontend.Compile(ecc.BLS12_381, r1cs.NewBuilder, &c) - b.Log("groth16", ccsBench.GetNbConstraints()) -} - -func BenchmarkAddFixedPoint(b *testing.B) { - var c add - ccsBench, _ := frontend.Compile(ecc.BLS12_381, r1cs.NewBuilder, &c) - b.Log("groth16", ccsBench.GetNbConstraints()) -} - -func BenchmarkMustBeOnCurve(b *testing.B) { - var c mustBeOnCurve - ccsBench, _ := frontend.Compile(ecc.BLS12_381, r1cs.NewBuilder, &c) - b.Log("groth16", ccsBench.GetNbConstraints()) -} - -func BenchmarkScalarMulGeneric(b *testing.B) { - var c scalarMulGeneric - ccsBench, _ := frontend.Compile(ecc.BLS12_381, r1cs.NewBuilder, &c) - b.Log("groth16", ccsBench.GetNbConstraints()) -} - -func BenchmarkScalarMulFixed(b *testing.B) { - var c scalarMulFixed - ccsBench, _ := frontend.Compile(ecc.BLS12_381, r1cs.NewBuilder, &c) - b.Log("groth16", ccsBench.GetNbConstraints()) -} diff --git a/std/algebra/twistededwards/curve.go b/std/algebra/twistededwards/curve.go index 432210c588..bcc5f36119 100644 --- a/std/algebra/twistededwards/curve.go +++ b/std/algebra/twistededwards/curve.go @@ -1,182 +1,63 @@ -/* -Copyright © 2020 ConsenSys - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - package twistededwards import ( - "errors" - "math/big" - - "github.com/consensys/gnark-crypto/ecc" - edbls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377/twistededwards" - edbls12381 "github.com/consensys/gnark-crypto/ecc/bls12-381/twistededwards" - edbls24315 "github.com/consensys/gnark-crypto/ecc/bls24-315/twistededwards" - edbn254 "github.com/consensys/gnark-crypto/ecc/bn254/twistededwards" - edbw6633 "github.com/consensys/gnark-crypto/ecc/bw6-633/twistededwards" - edbw6761 "github.com/consensys/gnark-crypto/ecc/bw6-761/twistededwards" - "github.com/consensys/gnark/internal/utils" + "github.com/consensys/gnark-crypto/ecc/twistededwards" + "github.com/consensys/gnark/frontend" ) -// Coordinates of a point on a twisted Edwards curve -type Coord struct { - X, Y big.Int +// curve curve is the default twisted edwards companion curve (defined on api.Curve().Fr) +type curve struct { + api frontend.API + id twistededwards.ID + params *CurveParams + endo *EndoParams } -// EdCurve stores the info on the chosen edwards curve -// note that all curves implemented in gnark-crypto have A = -1 -type EdCurve struct { - A, D, Cofactor, Order big.Int - Base Coord - ID ecc.ID +func (c *curve) Params() *CurveParams { + return c.params } -var constructors map[ecc.ID]func() EdCurve - -func init() { - constructors = map[ecc.ID]func() EdCurve{ - ecc.BLS12_381: newEdBLS381, - ecc.BN254: newEdBN254, - ecc.BLS12_377: newEdBLS377, - ecc.BW6_761: newEdBW761, - ecc.BLS24_315: newEdBLS315, - ecc.BW6_633: newEdBW633, - } +func (c *curve) API() frontend.API { + return c.api } -// NewEdCurve returns an Edwards curve parameters -func NewEdCurve(id ecc.ID) (EdCurve, error) { - if constructor, ok := constructors[id]; ok { - return constructor(), nil - } - return EdCurve{}, errors.New("unknown curve id") +func (c *curve) Endo() *EndoParams { + return c.endo } -// ------------------------------------------------------------------------------------------------- -// constructors - -func newEdBN254() EdCurve { - - edcurve := edbn254.GetEdwardsCurve() - edcurve.Cofactor.FromMont() - - return EdCurve{ - A: utils.FromInterface(edcurve.A), - D: utils.FromInterface(edcurve.D), - Cofactor: utils.FromInterface(edcurve.Cofactor), - Order: utils.FromInterface(edcurve.Order), - Base: Coord{ - X: utils.FromInterface(edcurve.Base.X), - Y: utils.FromInterface(edcurve.Base.Y), - }, - ID: ecc.BN254, - } - +func (c *curve) Add(p1, p2 Point) Point { + var p Point + p.add(c.api, &p1, &p2, c.params) + return p } -func newEdBLS381() EdCurve { - - edcurve := edbls12381.GetEdwardsCurve() - edcurve.Cofactor.FromMont() - - return EdCurve{ - A: utils.FromInterface(edcurve.A), - D: utils.FromInterface(edcurve.D), - Cofactor: utils.FromInterface(edcurve.Cofactor), - Order: utils.FromInterface(edcurve.Order), - Base: Coord{ - X: utils.FromInterface(edcurve.Base.X), - Y: utils.FromInterface(edcurve.Base.Y), - }, - ID: ecc.BLS12_381, - } - +func (c *curve) Double(p1 Point) Point { + var p Point + p.double(c.api, &p1, c.params) + return p } - -func newEdBLS377() EdCurve { - - edcurve := edbls12377.GetEdwardsCurve() - edcurve.Cofactor.FromMont() - - return EdCurve{ - A: utils.FromInterface(edcurve.A), - D: utils.FromInterface(edcurve.D), - Cofactor: utils.FromInterface(edcurve.Cofactor), - Order: utils.FromInterface(edcurve.Order), - Base: Coord{ - X: utils.FromInterface(edcurve.Base.X), - Y: utils.FromInterface(edcurve.Base.Y), - }, - ID: ecc.BLS12_377, - } - +func (c *curve) Neg(p1 Point) Point { + var p Point + p.neg(c.api, &p1) + return p } - -func newEdBW633() EdCurve { - - edcurve := edbw6633.GetEdwardsCurve() - edcurve.Cofactor.FromMont() - - return EdCurve{ - A: utils.FromInterface(edcurve.A), - D: utils.FromInterface(edcurve.D), - Cofactor: utils.FromInterface(edcurve.Cofactor), - Order: utils.FromInterface(edcurve.Order), - Base: Coord{ - X: utils.FromInterface(edcurve.Base.X), - Y: utils.FromInterface(edcurve.Base.Y), - }, - ID: ecc.BW6_633, - } - +func (c *curve) AssertIsOnCurve(p1 Point) { + p1.assertIsOnCurve(c.api, c.params) } - -func newEdBW761() EdCurve { - - edcurve := edbw6761.GetEdwardsCurve() - edcurve.Cofactor.FromMont() - - return EdCurve{ - A: utils.FromInterface(edcurve.A), - D: utils.FromInterface(edcurve.D), - Cofactor: utils.FromInterface(edcurve.Cofactor), - Order: utils.FromInterface(edcurve.Order), - Base: Coord{ - X: utils.FromInterface(edcurve.Base.X), - Y: utils.FromInterface(edcurve.Base.Y), - }, - ID: ecc.BW6_761, +func (c *curve) ScalarMul(p1 Point, scalar frontend.Variable) Point { + var p Point + if c.endo != nil { + // TODO restore + // this is disabled until this issue is solved https://github.com/ConsenSys/gnark/issues/268 + // p.scalarMulGLV(c.api, &p1, scalar, c.params, c.endo) + p.scalarMul(c.api, &p1, scalar, c.params) + } else { + p.scalarMul(c.api, &p1, scalar, c.params) } - + return p } - -func newEdBLS315() EdCurve { - - edcurve := edbls24315.GetEdwardsCurve() - edcurve.Cofactor.FromMont() - - return EdCurve{ - A: utils.FromInterface(edcurve.A), - D: utils.FromInterface(edcurve.D), - Cofactor: utils.FromInterface(edcurve.Cofactor), - Order: utils.FromInterface(edcurve.Order), - Base: Coord{ - X: utils.FromInterface(edcurve.Base.X), - Y: utils.FromInterface(edcurve.Base.Y), - }, - ID: ecc.BLS24_315, - } - +func (c *curve) DoubleBaseScalarMul(p1, p2 Point, s1, s2 frontend.Variable) Point { + var p Point + p.doubleBaseScalarMul(c.api, &p1, &p2, s1, s2, c.params) + return p } diff --git a/std/algebra/twistededwards/curve_test.go b/std/algebra/twistededwards/curve_test.go new file mode 100644 index 0000000000..583b2c7b46 --- /dev/null +++ b/std/algebra/twistededwards/curve_test.go @@ -0,0 +1,384 @@ +/* +Copyright © 2020 ConsenSys + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package twistededwards + +import ( + "crypto/rand" + "math/big" + "testing" + + tbls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377/twistededwards" + tbls12381_bandersnatch "github.com/consensys/gnark-crypto/ecc/bls12-381/bandersnatch" + tbls12381 "github.com/consensys/gnark-crypto/ecc/bls12-381/twistededwards" + tbls24315 "github.com/consensys/gnark-crypto/ecc/bls24-315/twistededwards" + tbn254 "github.com/consensys/gnark-crypto/ecc/bn254/twistededwards" + tbw6633 "github.com/consensys/gnark-crypto/ecc/bw6-633/twistededwards" + tbw6761 "github.com/consensys/gnark-crypto/ecc/bw6-761/twistededwards" + "github.com/consensys/gnark-crypto/ecc/twistededwards" + + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/test" +) + +var curves = []twistededwards.ID{twistededwards.BN254, twistededwards.BLS12_377, twistededwards.BLS12_381, twistededwards.BLS12_381_BANDERSNATCH, twistededwards.BW6_761, twistededwards.BW6_633, twistededwards.BLS24_315} + +type mustBeOnCurve struct { + curveID twistededwards.ID + P Point +} + +func (circuit *mustBeOnCurve) Define(api frontend.API) error { + + // get edwards curve curve + curve, err := NewEdCurve(api, circuit.curveID) + if err != nil { + return err + } + + curve.AssertIsOnCurve(circuit.P) + + return nil +} + +func TestIsOnCurve(t *testing.T) { + + assert := test.NewAssert(t) + + for _, curve := range curves { + var circuit, witness mustBeOnCurve + circuit.curveID = curve + + // get matching snark curve + snarkCurve, err := GetSnarkCurve(curve) + assert.NoError(err) + + // get curve params + params, err := GetCurveParams(curve) + assert.NoError(err) + + witness.P.X = params.Base[0] + witness.P.Y = params.Base[1] + + assert.SolvingSucceeded(&circuit, &witness, test.WithCurves(snarkCurve)) + + witness.P.X = params.Base[0] + witness.P.Y = params.randomScalar() + + assert.SolvingFailed(&circuit, &witness, test.WithCurves(snarkCurve)) + } + +} + +type addCircuit struct { + curveID twistededwards.ID + P1, P2 Point + AddResult Point + DoubleResult Point + ScalarMulResult Point + DoubleScalarMulResult Point + NegResult Point + S1, S2 frontend.Variable + fixedPoint Point +} + +func (circuit *addCircuit) Define(api frontend.API) error { + + // get edwards curve curve + curve, err := NewEdCurve(api, circuit.curveID) + if err != nil { + return err + } + + { + // addition 2 variable points + res := curve.Add(circuit.P1, circuit.P2) + api.AssertIsEqual(res.X, circuit.AddResult.X) + api.AssertIsEqual(res.Y, circuit.AddResult.Y) + } + + { + // addition 1 fixed + 1 variable point + res := curve.Add(circuit.fixedPoint, circuit.P1) + api.AssertIsEqual(res.X, circuit.AddResult.X) + api.AssertIsEqual(res.Y, circuit.AddResult.Y) + } + + { + // doubling + res := curve.Double(circuit.P1) + api.AssertIsEqual(res.X, circuit.DoubleResult.X) + api.AssertIsEqual(res.Y, circuit.DoubleResult.Y) + } + + { + // Neg + res := curve.Neg(circuit.P2) + api.AssertIsEqual(res.X, circuit.NegResult.X) + api.AssertIsEqual(res.Y, circuit.NegResult.Y) + } + + { + // scalar mul + res := curve.ScalarMul(circuit.P2, circuit.S2) + api.AssertIsEqual(res.X, circuit.ScalarMulResult.X) + api.AssertIsEqual(res.Y, circuit.ScalarMulResult.Y) + } + + { + // scalar mul fixed + res := curve.ScalarMul(circuit.fixedPoint, circuit.S2) + api.AssertIsEqual(res.X, circuit.ScalarMulResult.X) + api.AssertIsEqual(res.Y, circuit.ScalarMulResult.Y) + } + + { + // double scalar mul + res := curve.DoubleBaseScalarMul(circuit.P1, circuit.P2, circuit.S1, circuit.S2) + api.AssertIsEqual(res.X, circuit.DoubleScalarMulResult.X) + api.AssertIsEqual(res.Y, circuit.DoubleScalarMulResult.Y) + } + + return nil +} + +func TestCurve(t *testing.T) { + assert := test.NewAssert(t) + for _, curve := range curves { + var circuit, witness addCircuit + circuit.curveID = curve + + // get matching snark curve + snarkCurve, err := GetSnarkCurve(curve) + assert.NoError(err) + + // get curve params + params, err := GetCurveParams(curve) + assert.NoError(err) + + witness.P1, + witness.P2, + witness.AddResult, + witness.DoubleResult, + witness.ScalarMulResult, + witness.DoubleScalarMulResult, + witness.NegResult, + witness.S1, witness.S2 = testData(params, curve) + + circuit.fixedPoint = witness.P2 + + assert.SolvingSucceeded(&circuit, &witness, test.WithCurves(snarkCurve)) + + witness.P1.Y = params.randomScalar() + + assert.SolvingFailed(&circuit, &witness, test.WithCurves(snarkCurve)) + } +} + +// testData generates random test data for given curve +// returns p1, p2 and r, d such that p1 + p2 == r and p1 + p1 == d +// returns rs1, rs12, s1, s2 such that rs1 = p2 * s2 and rs12 = p1*s1 + p2 * s2 +// retunrs n such that n = -p2 +func testData(params *CurveParams, curveID twistededwards.ID) ( + _p1, + _p2, + _r, + _d, + _rs1, + _rs12, + _n Point, + s1, s2 frontend.Variable) { + scalar1 := params.randomScalar() + scalar2 := params.randomScalar() + + switch curveID { + case twistededwards.BN254: + var p1, p2, r, d, rs1, rs12, n tbn254.PointAffine + p1.X.SetBigInt(params.Base[0]) + p1.Y.SetBigInt(params.Base[1]) + p2.Set(&p1) + p1.ScalarMul(&p1, scalar1) + p2.ScalarMul(&p2, scalar2) + r.Add(&p1, &p2) + d.Double(&p1) + rs1.ScalarMul(&p2, scalar2) + rs12.ScalarMul(&p1, scalar1) + rs12.Add(&rs12, &rs1) + n.Neg(&p2) + + return Point{p1.X, p1.Y}, + Point{p2.X, p2.Y}, + Point{r.X, r.Y}, + Point{d.X, d.Y}, + Point{rs1.X, rs1.Y}, + Point{rs12.X, rs12.Y}, + Point{n.X, n.Y}, + scalar1, scalar2 + + case twistededwards.BLS12_381: + var p1, p2, r, d, rs1, rs12, n tbls12381.PointAffine + p1.X.SetBigInt(params.Base[0]) + p1.Y.SetBigInt(params.Base[1]) + p2.Set(&p1) + + p1.ScalarMul(&p1, scalar1) + p2.ScalarMul(&p2, scalar2) + r.Add(&p1, &p2) + d.Double(&p1) + rs1.ScalarMul(&p2, scalar2) + rs12.ScalarMul(&p1, scalar1) + rs12.Add(&rs12, &rs1) + n.Neg(&p2) + + return Point{p1.X, p1.Y}, + Point{p2.X, p2.Y}, + Point{r.X, r.Y}, + Point{d.X, d.Y}, + Point{rs1.X, rs1.Y}, + Point{rs12.X, rs12.Y}, + Point{n.X, n.Y}, + scalar1, scalar2 + + case twistededwards.BLS12_381_BANDERSNATCH: + var p1, p2, r, d, rs1, rs12, n tbls12381_bandersnatch.PointAffine + p1.X.SetBigInt(params.Base[0]) + p1.Y.SetBigInt(params.Base[1]) + p2.Set(&p1) + + p1.ScalarMul(&p1, scalar1) + p2.ScalarMul(&p2, scalar2) + r.Add(&p1, &p2) + d.Double(&p1) + rs1.ScalarMul(&p2, scalar2) + rs12.ScalarMul(&p1, scalar1) + rs12.Add(&rs12, &rs1) + n.Neg(&p2) + + return Point{p1.X, p1.Y}, + Point{p2.X, p2.Y}, + Point{r.X, r.Y}, + Point{d.X, d.Y}, + Point{rs1.X, rs1.Y}, + Point{rs12.X, rs12.Y}, + Point{n.X, n.Y}, + scalar1, scalar2 + + case twistededwards.BLS12_377: + var p1, p2, r, d, rs1, rs12, n tbls12377.PointAffine + p1.X.SetBigInt(params.Base[0]) + p1.Y.SetBigInt(params.Base[1]) + p2.Set(&p1) + + p1.ScalarMul(&p1, scalar1) + p2.ScalarMul(&p2, scalar2) + r.Add(&p1, &p2) + d.Double(&p1) + rs1.ScalarMul(&p2, scalar2) + rs12.ScalarMul(&p1, scalar1) + rs12.Add(&rs12, &rs1) + n.Neg(&p2) + + return Point{p1.X, p1.Y}, + Point{p2.X, p2.Y}, + Point{r.X, r.Y}, + Point{d.X, d.Y}, + Point{rs1.X, rs1.Y}, + Point{rs12.X, rs12.Y}, + Point{n.X, n.Y}, + scalar1, scalar2 + + case twistededwards.BLS24_315: + var p1, p2, r, d, rs1, rs12, n tbls24315.PointAffine + p1.X.SetBigInt(params.Base[0]) + p1.Y.SetBigInt(params.Base[1]) + p2.Set(&p1) + + p1.ScalarMul(&p1, scalar1) + p2.ScalarMul(&p2, scalar2) + r.Add(&p1, &p2) + d.Double(&p1) + rs1.ScalarMul(&p2, scalar2) + rs12.ScalarMul(&p1, scalar1) + rs12.Add(&rs12, &rs1) + n.Neg(&p2) + + return Point{p1.X, p1.Y}, + Point{p2.X, p2.Y}, + Point{r.X, r.Y}, + Point{d.X, d.Y}, + Point{rs1.X, rs1.Y}, + Point{rs12.X, rs12.Y}, + Point{n.X, n.Y}, + scalar1, scalar2 + + case twistededwards.BW6_633: + var p1, p2, r, d, rs1, rs12, n tbw6633.PointAffine + p1.X.SetBigInt(params.Base[0]) + p1.Y.SetBigInt(params.Base[1]) + p2.Set(&p1) + + p1.ScalarMul(&p1, scalar1) + p2.ScalarMul(&p2, scalar2) + r.Add(&p1, &p2) + d.Double(&p1) + rs1.ScalarMul(&p2, scalar2) + rs12.ScalarMul(&p1, scalar1) + rs12.Add(&rs12, &rs1) + n.Neg(&p2) + + return Point{p1.X, p1.Y}, + Point{p2.X, p2.Y}, + Point{r.X, r.Y}, + Point{d.X, d.Y}, + Point{rs1.X, rs1.Y}, + Point{rs12.X, rs12.Y}, + Point{n.X, n.Y}, + scalar1, scalar2 + + case twistededwards.BW6_761: + var p1, p2, r, d, rs1, rs12, n tbw6761.PointAffine + p1.X.SetBigInt(params.Base[0]) + p1.Y.SetBigInt(params.Base[1]) + p2.Set(&p1) + + p1.ScalarMul(&p1, scalar1) + p2.ScalarMul(&p2, scalar2) + r.Add(&p1, &p2) + d.Double(&p1) + rs1.ScalarMul(&p2, scalar2) + rs12.ScalarMul(&p1, scalar1) + rs12.Add(&rs12, &rs1) + n.Neg(&p2) + + return Point{p1.X, p1.Y}, + Point{p2.X, p2.Y}, + Point{r.X, r.Y}, + Point{d.X, d.Y}, + Point{rs1.X, rs1.Y}, + Point{rs12.X, rs12.Y}, + Point{n.X, n.Y}, + scalar1, scalar2 + + default: + panic("not implemented") + } +} + +// randomScalar returns a scalar <= p.Order +func (p *CurveParams) randomScalar() *big.Int { + r, _ := rand.Int(rand.Reader, p.Order) + return r +} diff --git a/std/algebra/twistededwards/point.go b/std/algebra/twistededwards/point.go index fcc969edbe..dbacdb30d5 100644 --- a/std/algebra/twistededwards/point.go +++ b/std/algebra/twistededwards/point.go @@ -17,48 +17,39 @@ limitations under the License. package twistededwards import ( - "math/big" - "github.com/consensys/gnark/frontend" ) -// Point point on a twisted Edwards curve in a Snark cs -type Point struct { - X, Y frontend.Variable -} - -// Neg computes the negative of a point in SNARK coordinates -func (p *Point) Neg(api frontend.API, p1 *Point) *Point { +// neg computes the negative of a point in SNARK coordinates +func (p *Point) neg(api frontend.API, p1 *Point) *Point { p.X = api.Neg(p1.X) p.Y = p1.Y return p } -// MustBeOnCurve checks if a point is on the reduced twisted Edwards curve +// assertIsOnCurve checks if a point is on the reduced twisted Edwards curve // a*x² + y² = 1 + d*x²*y². -func (p *Point) MustBeOnCurve(api frontend.API, curve EdCurve) { - - one := big.NewInt(1) +func (p *Point) assertIsOnCurve(api frontend.API, curve *CurveParams) { xx := api.Mul(p.X, p.X) yy := api.Mul(p.Y, p.Y) - axx := api.Mul(xx, &curve.A) + axx := api.Mul(xx, curve.A) lhs := api.Add(axx, yy) - dxx := api.Mul(xx, &curve.D) + dxx := api.Mul(xx, curve.D) dxxyy := api.Mul(dxx, yy) - rhs := api.Add(dxxyy, one) + rhs := api.Add(dxxyy, 1) api.AssertIsEqual(lhs, rhs) } -// Add Adds two points on a twisted edwards curve (eg jubjub) +// add Adds two points on a twisted edwards curve (eg jubjub) // p1, p2, c are respectively: the point to add, a known base point, and the parameters of the twisted edwards curve -func (p *Point) Add(api frontend.API, p1, p2 *Point, curve EdCurve) *Point { +func (p *Point) add(api frontend.API, p1, p2 *Point, curve *CurveParams) *Point { // u = (x1 + y1) * (x2 + y2) - u1 := api.Mul(p1.X, &curve.A) + u1 := api.Mul(p1.X, curve.A) u1 = api.Sub(p1.Y, u1) u2 := api.Add(p2.X, p2.Y) u := api.Mul(u1, u2) @@ -70,14 +61,14 @@ func (p *Point) Add(api frontend.API, p1, p2 *Point, curve EdCurve) *Point { v1 := api.Mul(p2.X, p1.Y) // v2 = d * v0 * v1 - v2 := api.Mul(&curve.D, v0, v1) + v2 := api.Mul(curve.D, v0, v1) // x = (v0 + v1) / (1 + v2) p.X = api.Add(v0, v1) p.X = api.DivUnchecked(p.X, api.Add(1, v2)) // y = (u + a * v0 - v1) / (1 - v2) - p.Y = api.Mul(&curve.A, v0) + p.Y = api.Mul(curve.A, v0) p.Y = api.Sub(p.Y, v1) p.Y = api.Add(p.Y, u) p.Y = api.DivUnchecked(p.Y, api.Sub(1, v2)) @@ -85,15 +76,15 @@ func (p *Point) Add(api frontend.API, p1, p2 *Point, curve EdCurve) *Point { return p } -// Double doubles a points in SNARK coordinates -func (p *Point) Double(api frontend.API, p1 *Point, curve EdCurve) *Point { +// double doubles a points in SNARK coordinates +func (p *Point) double(api frontend.API, p1 *Point, curve *CurveParams) *Point { u := api.Mul(p1.X, p1.Y) v := api.Mul(p1.X, p1.X) w := api.Mul(p1.Y, p1.Y) n1 := api.Mul(2, u) - av := api.Mul(v, &curve.A) + av := api.Mul(v, curve.A) n2 := api.Sub(w, av) d1 := api.Add(w, av) d2 := api.Sub(2, d1) @@ -104,12 +95,16 @@ func (p *Point) Double(api frontend.API, p1 *Point, curve EdCurve) *Point { return p } -// ScalarMul computes the scalar multiplication of a point on a twisted Edwards curve +// scalarMul computes the scalar multiplication of a point on a twisted Edwards curve // p1: base point (as snark point) // curve: parameters of the Edwards curve // scal: scalar as a SNARK constraint // Standard left to right double and add -func (p *Point) ScalarMul(api frontend.API, p1 *Point, scalar frontend.Variable, curve EdCurve) *Point { +func (p *Point) scalarMul(api frontend.API, p1 *Point, scalar frontend.Variable, curve *CurveParams, endo ...*EndoParams) *Point { + if len(endo) == 1 && endo[0] != nil { + // use glv + return p.scalarMulGLV(api, p1, scalar, curve, endo[0]) + } // first unpack the scalar b := api.ToBinary(scalar) @@ -119,24 +114,24 @@ func (p *Point) ScalarMul(api frontend.API, p1 *Point, scalar frontend.Variable, A := Point{} B := Point{} - A.Double(api, p1, curve) - B.Add(api, &A, p1, curve) + A.double(api, p1, curve) + B.add(api, &A, p1, curve) n := len(b) - 1 res.X = api.Lookup2(b[n], b[n-1], 0, A.X, p1.X, B.X) res.Y = api.Lookup2(b[n], b[n-1], 1, A.Y, p1.Y, B.Y) for i := n - 2; i >= 1; i -= 2 { - res.Double(api, &res, curve). - Double(api, &res, curve) + res.double(api, &res, curve). + double(api, &res, curve) tmp.X = api.Lookup2(b[i], b[i-1], 0, A.X, p1.X, B.X) tmp.Y = api.Lookup2(b[i], b[i-1], 1, A.Y, p1.Y, B.Y) - res.Add(api, &res, &tmp, curve) + res.add(api, &res, &tmp, curve) } if n%2 == 0 { - res.Double(api, &res, curve) - tmp.Add(api, &res, p1, curve) + res.double(api, &res, curve) + tmp.add(api, &res, p1, curve) res.X = api.Select(b[0], tmp.X, res.X) res.Y = api.Select(b[0], tmp.Y, res.Y) } @@ -147,10 +142,10 @@ func (p *Point) ScalarMul(api frontend.API, p1 *Point, scalar frontend.Variable, return p } -// DoubleBaseScalarMul computes s1*P1+s2*P2 +// doubleBaseScalarMul computes s1*P1+s2*P2 // where P1 and P2 are points on a twisted Edwards curve // and s1, s2 scalars. -func (p *Point) DoubleBaseScalarMul(api frontend.API, p1, p2 *Point, s1, s2 frontend.Variable, curve EdCurve) *Point { +func (p *Point) doubleBaseScalarMul(api frontend.API, p1, p2 *Point, s1, s2 frontend.Variable, curve *CurveParams) *Point { // first unpack the scalars b1 := api.ToBinary(s1) @@ -159,17 +154,17 @@ func (p *Point) DoubleBaseScalarMul(api frontend.API, p1, p2 *Point, s1, s2 fron res := Point{} tmp := Point{} sum := Point{} - sum.Add(api, p1, p2, curve) + sum.add(api, p1, p2, curve) n := len(b1) res.X = api.Lookup2(b1[n-1], b2[n-1], 0, p1.X, p2.X, sum.X) res.Y = api.Lookup2(b1[n-1], b2[n-1], 1, p1.Y, p2.Y, sum.Y) for i := n - 2; i >= 0; i-- { - res.Double(api, &res, curve) + res.double(api, &res, curve) tmp.X = api.Lookup2(b1[i], b2[i], 0, p1.X, p2.X, sum.X) tmp.Y = api.Lookup2(b1[i], b2[i], 1, p1.Y, p2.Y, sum.Y) - res.Add(api, &res, &tmp, curve) + res.add(api, &res, &tmp, curve) } p.X = res.X diff --git a/std/algebra/twistededwards/point_test.go b/std/algebra/twistededwards/point_test.go deleted file mode 100644 index 14dcb47b08..0000000000 --- a/std/algebra/twistededwards/point_test.go +++ /dev/null @@ -1,820 +0,0 @@ -/* -Copyright © 2020 ConsenSys - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package twistededwards - -import ( - "math/big" - "testing" - - "github.com/consensys/gnark-crypto/ecc" - tbls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377/twistededwards" - tbls12381 "github.com/consensys/gnark-crypto/ecc/bls12-381/twistededwards" - tbls24315 "github.com/consensys/gnark-crypto/ecc/bls24-315/twistededwards" - tbn254 "github.com/consensys/gnark-crypto/ecc/bn254/twistededwards" - tbw6633 "github.com/consensys/gnark-crypto/ecc/bw6-633/twistededwards" - tbw6761 "github.com/consensys/gnark-crypto/ecc/bw6-761/twistededwards" - - "github.com/consensys/gnark/frontend" - "github.com/consensys/gnark/test" -) - -type mustBeOnCurve struct { - P Point -} - -func (circuit *mustBeOnCurve) Define(api frontend.API) error { - - // get edwards curve params - params, err := NewEdCurve(api.Compiler().Curve()) - if err != nil { - return err - } - - circuit.P.MustBeOnCurve(api, params) - - return nil -} - -func TestIsOnCurve(t *testing.T) { - - assert := test.NewAssert(t) - - var circuit, witness mustBeOnCurve - - params, err := NewEdCurve(ecc.BN254) - if err != nil { - t.Fatal(err) - } - - witness.P.X = (params.Base.X) - witness.P.Y = (params.Base.Y) - - assert.SolvingSucceeded(&circuit, &witness, test.WithCurves(ecc.BN254)) - -} - -type add struct { - P, E Point -} - -func (circuit *add) Define(api frontend.API) error { - - // get edwards curve params - params, err := NewEdCurve(api.Compiler().Curve()) - if err != nil { - return err - } - - p := Point{} - p.X = params.Base.X - p.Y = params.Base.Y - res := circuit.P.Add(api, &circuit.P, &p, params) - - api.AssertIsEqual(res.X, circuit.E.X) - api.AssertIsEqual(res.Y, circuit.E.Y) - - return nil -} - -func TestAddFixedPoint(t *testing.T) { - - assert := test.NewAssert(t) - - var circuit, witness add - - // generate a random point, and compute expected_point = base + random_point - params, err := NewEdCurve(ecc.BN254) - if err != nil { - t.Fatal(err) - } - var base, point, expected tbn254.PointAffine - base.X.SetBigInt(¶ms.Base.X) - base.Y.SetBigInt(¶ms.Base.Y) - point.Set(&base) - r := big.NewInt(5) - point.ScalarMul(&point, r) - expected.Add(&base, &point) - - // populate witness - witness.P.X = (point.X.String()) - witness.P.Y = (point.Y.String()) - witness.E.X = (expected.X.String()) - witness.E.Y = (expected.Y.String()) - - // creates r1cs - assert.SolvingSucceeded(&circuit, &witness, test.WithCurves(ecc.BN254)) - -} - -//------------------------------------------------------------- -// addGeneric - -type addGeneric struct { - P1, P2, E Point -} - -func (circuit *addGeneric) Define(api frontend.API) error { - - // get edwards curve params - params, err := NewEdCurve(api.Compiler().Curve()) - if err != nil { - return err - } - - res := circuit.P1.Add(api, &circuit.P1, &circuit.P2, params) - - api.AssertIsEqual(res.X, circuit.E.X) - api.AssertIsEqual(res.Y, circuit.E.Y) - - return nil -} - -func TestAddGeneric(t *testing.T) { - - assert := test.NewAssert(t) - var circuit, witness addGeneric - - // generate witness data - for _, id := range ecc.Implemented() { - - params, err := NewEdCurve(id) - if err != nil { - t.Fatal(err) - } - - switch id { - case ecc.BN254: - var op1, op2, expected tbn254.PointAffine - op1.X.SetBigInt(¶ms.Base.X) - op1.Y.SetBigInt(¶ms.Base.Y) - op2.Set(&op1) - r1 := big.NewInt(5) - r2 := big.NewInt(12) - op1.ScalarMul(&op1, r1) - op2.ScalarMul(&op2, r2) - expected.Add(&op1, &op2) - witness.P1.X = (op1.X.String()) - witness.P1.Y = (op1.Y.String()) - witness.P2.X = (op2.X.String()) - witness.P2.Y = (op2.Y.String()) - witness.E.X = (expected.X.String()) - witness.E.Y = (expected.Y.String()) - case ecc.BLS12_381: - var op1, op2, expected tbls12381.PointAffine - op1.X.SetBigInt(¶ms.Base.X) - op1.Y.SetBigInt(¶ms.Base.Y) - op2.Set(&op1) - r1 := big.NewInt(5) - r2 := big.NewInt(12) - op1.ScalarMul(&op1, r1) - op2.ScalarMul(&op2, r2) - expected.Add(&op1, &op2) - witness.P1.X = (op1.X.String()) - witness.P1.Y = (op1.Y.String()) - witness.P2.X = (op2.X.String()) - witness.P2.Y = (op2.Y.String()) - witness.E.X = (expected.X.String()) - witness.E.Y = (expected.Y.String()) - case ecc.BLS12_377: - var op1, op2, expected tbls12377.PointAffine - op1.X.SetBigInt(¶ms.Base.X) - op1.Y.SetBigInt(¶ms.Base.Y) - op2.Set(&op1) - r1 := big.NewInt(5) - r2 := big.NewInt(12) - op1.ScalarMul(&op1, r1) - op2.ScalarMul(&op2, r2) - expected.Add(&op1, &op2) - witness.P1.X = (op1.X.String()) - witness.P1.Y = (op1.Y.String()) - witness.P2.X = (op2.X.String()) - witness.P2.Y = (op2.Y.String()) - witness.E.X = (expected.X.String()) - witness.E.Y = (expected.Y.String()) - case ecc.BLS24_315: - var op1, op2, expected tbls24315.PointAffine - op1.X.SetBigInt(¶ms.Base.X) - op1.Y.SetBigInt(¶ms.Base.Y) - op2.Set(&op1) - r1 := big.NewInt(5) - r2 := big.NewInt(12) - op1.ScalarMul(&op1, r1) - op2.ScalarMul(&op2, r2) - expected.Add(&op1, &op2) - witness.P1.X = (op1.X.String()) - witness.P1.Y = (op1.Y.String()) - witness.P2.X = (op2.X.String()) - witness.P2.Y = (op2.Y.String()) - witness.E.X = (expected.X.String()) - witness.E.Y = (expected.Y.String()) - case ecc.BW6_633: - var op1, op2, expected tbw6633.PointAffine - op1.X.SetBigInt(¶ms.Base.X) - op1.Y.SetBigInt(¶ms.Base.Y) - op2.Set(&op1) - r1 := big.NewInt(5) - r2 := big.NewInt(12) - op1.ScalarMul(&op1, r1) - op2.ScalarMul(&op2, r2) - expected.Add(&op1, &op2) - witness.P1.X = (op1.X.String()) - witness.P1.Y = (op1.Y.String()) - witness.P2.X = (op2.X.String()) - witness.P2.Y = (op2.Y.String()) - witness.E.X = (expected.X.String()) - witness.E.Y = (expected.Y.String()) - case ecc.BW6_761: - var op1, op2, expected tbw6761.PointAffine - op1.X.SetBigInt(¶ms.Base.X) - op1.Y.SetBigInt(¶ms.Base.Y) - op2.Set(&op1) - r1 := big.NewInt(5) - r2 := big.NewInt(12) - op1.ScalarMul(&op1, r1) - op2.ScalarMul(&op2, r2) - expected.Add(&op1, &op2) - witness.P1.X = (op1.X.String()) - witness.P1.Y = (op1.Y.String()) - witness.P2.X = (op2.X.String()) - witness.P2.Y = (op2.Y.String()) - witness.E.X = (expected.X.String()) - witness.E.Y = (expected.Y.String()) - } - - // creates r1cs - assert.SolvingSucceeded(&circuit, &witness, test.WithCurves(id)) - } - -} - -//------------------------------------------------------------- -// Double - -type double struct { - P, E Point -} - -func (circuit *double) Define(api frontend.API) error { - - // get edwards curve params - params, err := NewEdCurve(api.Compiler().Curve()) - if err != nil { - return err - } - - res := circuit.P.Double(api, &circuit.P, params) - - api.AssertIsEqual(res.X, circuit.E.X) - api.AssertIsEqual(res.Y, circuit.E.Y) - - return nil -} - -func TestDouble(t *testing.T) { - - assert := test.NewAssert(t) - - var circuit, witness double - - // generate witness data - for _, id := range ecc.Implemented() { - - params, err := NewEdCurve(id) - if err != nil { - t.Fatal(err) - } - - switch id { - case ecc.BN254: - var base, expected tbn254.PointAffine - base.X.SetBigInt(¶ms.Base.X) - base.Y.SetBigInt(¶ms.Base.Y) - expected.Double(&base) - witness.P.X = (base.X.String()) - witness.P.Y = (base.Y.String()) - witness.E.X = (expected.X.String()) - witness.E.Y = (expected.Y.String()) - case ecc.BLS12_381: - var base, expected tbls12381.PointAffine - base.X.SetBigInt(¶ms.Base.X) - base.Y.SetBigInt(¶ms.Base.Y) - expected.Double(&base) - witness.P.X = (base.X.String()) - witness.P.Y = (base.Y.String()) - witness.E.X = (expected.X.String()) - witness.E.Y = (expected.Y.String()) - case ecc.BLS12_377: - var base, expected tbls12377.PointAffine - base.X.SetBigInt(¶ms.Base.X) - base.Y.SetBigInt(¶ms.Base.Y) - expected.Double(&base) - witness.P.X = (base.X.String()) - witness.P.Y = (base.Y.String()) - witness.E.X = (expected.X.String()) - witness.E.Y = (expected.Y.String()) - case ecc.BLS24_315: - var base, expected tbls24315.PointAffine - base.X.SetBigInt(¶ms.Base.X) - base.Y.SetBigInt(¶ms.Base.Y) - expected.Double(&base) - witness.P.X = (base.X.String()) - witness.P.Y = (base.Y.String()) - witness.E.X = (expected.X.String()) - witness.E.Y = (expected.Y.String()) - case ecc.BW6_633: - var base, expected tbw6633.PointAffine - base.X.SetBigInt(¶ms.Base.X) - base.Y.SetBigInt(¶ms.Base.Y) - expected.Double(&base) - witness.P.X = (base.X.String()) - witness.P.Y = (base.Y.String()) - witness.E.X = (expected.X.String()) - witness.E.Y = (expected.Y.String()) - case ecc.BW6_761: - var base, expected tbw6761.PointAffine - base.X.SetBigInt(¶ms.Base.X) - base.Y.SetBigInt(¶ms.Base.Y) - expected.Double(&base) - witness.P.X = (base.X.String()) - witness.P.Y = (base.Y.String()) - witness.E.X = (expected.X.String()) - witness.E.Y = (expected.Y.String()) - } - - // creates r1cs - assert.SolvingSucceeded(&circuit, &witness, test.WithCurves(id)) - } - -} - -//------------------------------------------------------------- -// scalarMulFixed - -type scalarMulFixed struct { - E Point - S frontend.Variable -} - -func (circuit *scalarMulFixed) Define(api frontend.API) error { - - // get edwards curve params - params, err := NewEdCurve(api.Compiler().Curve()) - if err != nil { - return err - } - - var resFixed, p Point - p.X = params.Base.X - p.Y = params.Base.Y - resFixed.ScalarMul(api, &p, circuit.S, params) - - api.AssertIsEqual(resFixed.X, circuit.E.X) - api.AssertIsEqual(resFixed.Y, circuit.E.Y) - - return nil -} - -func TestScalarMulFixed(t *testing.T) { - - assert := test.NewAssert(t) - - var circuit, witness scalarMulFixed - - // generate witness data - for _, id := range ecc.Implemented() { - - params, err := NewEdCurve(id) - if err != nil { - t.Fatal(err) - } - - switch id { - case ecc.BN254: - var base, expected tbn254.PointAffine - base.X.SetBigInt(¶ms.Base.X) - base.Y.SetBigInt(¶ms.Base.Y) - r := big.NewInt(928323002) - expected.ScalarMul(&base, r) - witness.E.X = (expected.X.String()) - witness.E.Y = (expected.Y.String()) - witness.S = (r) - case ecc.BLS12_381: - var base, expected tbls12381.PointAffine - base.X.SetBigInt(¶ms.Base.X) - base.Y.SetBigInt(¶ms.Base.Y) - r := big.NewInt(928323002) - expected.ScalarMul(&base, r) - witness.E.X = (expected.X.String()) - witness.E.Y = (expected.Y.String()) - witness.S = (r) - case ecc.BLS12_377: - var base, expected tbls12377.PointAffine - base.X.SetBigInt(¶ms.Base.X) - base.Y.SetBigInt(¶ms.Base.Y) - r := big.NewInt(928323002) - expected.ScalarMul(&base, r) - witness.E.X = (expected.X.String()) - witness.E.Y = (expected.Y.String()) - witness.S = (r) - case ecc.BLS24_315: - var base, expected tbls24315.PointAffine - base.X.SetBigInt(¶ms.Base.X) - base.Y.SetBigInt(¶ms.Base.Y) - r := big.NewInt(928323002) - expected.ScalarMul(&base, r) - witness.E.X = (expected.X.String()) - witness.E.Y = (expected.Y.String()) - witness.S = (r) - case ecc.BW6_633: - var base, expected tbw6633.PointAffine - base.X.SetBigInt(¶ms.Base.X) - base.Y.SetBigInt(¶ms.Base.Y) - r := big.NewInt(928323002) - expected.ScalarMul(&base, r) - witness.E.X = (expected.X.String()) - witness.E.Y = (expected.Y.String()) - witness.S = (r) - case ecc.BW6_761: - var base, expected tbw6761.PointAffine - base.X.SetBigInt(¶ms.Base.X) - base.Y.SetBigInt(¶ms.Base.Y) - r := big.NewInt(928323002) - expected.ScalarMul(&base, r) - witness.E.X = (expected.X.String()) - witness.E.Y = (expected.Y.String()) - witness.S = (r) - } - - // creates r1cs - assert.SolvingSucceeded(&circuit, &witness, test.WithCurves(id)) - } - -} - -type scalarMulGeneric struct { - P, E Point - S frontend.Variable -} - -func (circuit *scalarMulGeneric) Define(api frontend.API) error { - - // get edwards curve params - params, err := NewEdCurve(api.Compiler().Curve()) - if err != nil { - return err - } - - resGeneric := circuit.P.ScalarMul(api, &circuit.P, circuit.S, params) - - api.AssertIsEqual(resGeneric.X, circuit.E.X) - api.AssertIsEqual(resGeneric.Y, circuit.E.Y) - - return nil -} - -func TestScalarMulGeneric(t *testing.T) { - - assert := test.NewAssert(t) - - var circuit, witness scalarMulGeneric - - // generate witness data - for _, id := range ecc.Implemented() { - - params, err := NewEdCurve(id) - if err != nil { - t.Fatal(err) - } - - switch id { - case ecc.BN254: - var base, point, expected tbn254.PointAffine - base.X.SetBigInt(¶ms.Base.X) - base.Y.SetBigInt(¶ms.Base.Y) - s := big.NewInt(902) - point.ScalarMul(&base, s) // random point - r := big.NewInt(230928302) - expected.ScalarMul(&point, r) - - // populate witness - witness.P.X = (point.X.String()) - witness.P.Y = (point.Y.String()) - witness.E.X = (expected.X.String()) - witness.E.Y = (expected.Y.String()) - witness.S = (r) - case ecc.BLS12_377: - var base, point, expected tbls12377.PointAffine - base.X.SetBigInt(¶ms.Base.X) - base.Y.SetBigInt(¶ms.Base.Y) - s := big.NewInt(902) - point.ScalarMul(&base, s) // random point - r := big.NewInt(230928302) - expected.ScalarMul(&point, r) - - // populate witness - witness.P.X = (point.X.String()) - witness.P.Y = (point.Y.String()) - witness.E.X = (expected.X.String()) - witness.E.Y = (expected.Y.String()) - witness.S = (r) - case ecc.BLS12_381: - var base, point, expected tbls12381.PointAffine - base.X.SetBigInt(¶ms.Base.X) - base.Y.SetBigInt(¶ms.Base.Y) - s := big.NewInt(902) - point.ScalarMul(&base, s) // random point - r := big.NewInt(230928302) - expected.ScalarMul(&point, r) - - // populate witness - witness.P.X = (point.X.String()) - witness.P.Y = (point.Y.String()) - witness.E.X = (expected.X.String()) - witness.E.Y = (expected.Y.String()) - witness.S = (r) - case ecc.BLS24_315: - var base, point, expected tbls24315.PointAffine - base.X.SetBigInt(¶ms.Base.X) - base.Y.SetBigInt(¶ms.Base.Y) - s := big.NewInt(902) - point.ScalarMul(&base, s) // random point - r := big.NewInt(230928302) - expected.ScalarMul(&point, r) - - // populate witness - witness.P.X = (point.X.String()) - witness.P.Y = (point.Y.String()) - witness.E.X = (expected.X.String()) - witness.E.Y = (expected.Y.String()) - witness.S = (r) - case ecc.BW6_761: - var base, point, expected tbw6761.PointAffine - base.X.SetBigInt(¶ms.Base.X) - base.Y.SetBigInt(¶ms.Base.Y) - s := big.NewInt(902) - point.ScalarMul(&base, s) // random point - r := big.NewInt(230928302) - expected.ScalarMul(&point, r) - - // populate witness - witness.P.X = (point.X.String()) - witness.P.Y = (point.Y.String()) - witness.E.X = (expected.X.String()) - witness.E.Y = (expected.Y.String()) - witness.S = (r) - case ecc.BW6_633: - var base, point, expected tbw6633.PointAffine - base.X.SetBigInt(¶ms.Base.X) - base.Y.SetBigInt(¶ms.Base.Y) - s := big.NewInt(902) - point.ScalarMul(&base, s) // random point - r := big.NewInt(230928302) - expected.ScalarMul(&point, r) - - // populate witness - witness.P.X = (point.X.String()) - witness.P.Y = (point.Y.String()) - witness.E.X = (expected.X.String()) - witness.E.Y = (expected.Y.String()) - witness.S = (r) - } - - // creates r1cs - assert.SolvingSucceeded(&circuit, &witness, test.WithCurves(id)) - } -} - -// - -type doubleScalarMulGeneric struct { - P1, P2, E Point - S1, S2 frontend.Variable -} - -func (circuit *doubleScalarMulGeneric) Define(api frontend.API) error { - - // get edwards curve params - params, err := NewEdCurve(api.Compiler().Curve()) - if err != nil { - return err - } - - resGeneric := circuit.P1.DoubleBaseScalarMul(api, &circuit.P1, &circuit.P2, circuit.S1, circuit.S2, params) - - api.AssertIsEqual(resGeneric.X, circuit.E.X) - api.AssertIsEqual(resGeneric.Y, circuit.E.Y) - - return nil -} - -func TestDoubleScalarMulGeneric(t *testing.T) { - - assert := test.NewAssert(t) - - var circuit, witness doubleScalarMulGeneric - - // generate witness data - for _, id := range ecc.Implemented() { - - params, err := NewEdCurve(id) - if err != nil { - t.Fatal(err) - } - - switch id { - case ecc.BN254: - var base, point1, point2, tmp, expected tbn254.PointAffine - base.X.SetBigInt(¶ms.Base.X) - base.Y.SetBigInt(¶ms.Base.Y) - s1 := big.NewInt(902) - s2 := big.NewInt(891) - point1.ScalarMul(&base, s1) // random point - point2.ScalarMul(&base, s2) // random point - r1 := big.NewInt(230928303) - r2 := big.NewInt(2830309) - tmp.ScalarMul(&point1, r1) - expected.ScalarMul(&point2, r2). - Add(&expected, &tmp) - - // populate witness - witness.P1.X = (point1.X.String()) - witness.P1.Y = (point1.Y.String()) - witness.P2.X = (point2.X.String()) - witness.P2.Y = (point2.Y.String()) - witness.E.X = (expected.X.String()) - witness.E.Y = (expected.Y.String()) - witness.S1 = (r1) - witness.S2 = (r2) - case ecc.BLS12_377: - var base, point1, point2, tmp, expected tbls12377.PointAffine - base.X.SetBigInt(¶ms.Base.X) - base.Y.SetBigInt(¶ms.Base.Y) - s1 := big.NewInt(902) - s2 := big.NewInt(891) - point1.ScalarMul(&base, s1) // random point - point2.ScalarMul(&base, s2) // random point - r1 := big.NewInt(230928303) - r2 := big.NewInt(2830309) - tmp.ScalarMul(&point1, r1) - expected.ScalarMul(&point2, r2). - Add(&expected, &tmp) - - // populate witness - witness.P1.X = (point1.X.String()) - witness.P1.Y = (point1.Y.String()) - witness.P2.X = (point2.X.String()) - witness.P2.Y = (point2.Y.String()) - witness.E.X = (expected.X.String()) - witness.E.Y = (expected.Y.String()) - witness.S1 = (r1) - witness.S2 = (r2) - case ecc.BLS12_381: - var base, point1, point2, tmp, expected tbls12381.PointAffine - base.X.SetBigInt(¶ms.Base.X) - base.Y.SetBigInt(¶ms.Base.Y) - s1 := big.NewInt(902) - s2 := big.NewInt(891) - point1.ScalarMul(&base, s1) // random point - point2.ScalarMul(&base, s2) // random point - r1 := big.NewInt(230928303) - r2 := big.NewInt(2830309) - tmp.ScalarMul(&point1, r1) - expected.ScalarMul(&point2, r2). - Add(&expected, &tmp) - - // populate witness - witness.P1.X = (point1.X.String()) - witness.P1.Y = (point1.Y.String()) - witness.P2.X = (point2.X.String()) - witness.P2.Y = (point2.Y.String()) - witness.E.X = (expected.X.String()) - witness.E.Y = (expected.Y.String()) - witness.S1 = (r1) - witness.S2 = (r2) - case ecc.BLS24_315: - var base, point1, point2, tmp, expected tbls24315.PointAffine - base.X.SetBigInt(¶ms.Base.X) - base.Y.SetBigInt(¶ms.Base.Y) - s1 := big.NewInt(902) - s2 := big.NewInt(891) - point1.ScalarMul(&base, s1) // random point - point2.ScalarMul(&base, s2) // random point - r1 := big.NewInt(230928303) - r2 := big.NewInt(2830309) - tmp.ScalarMul(&point1, r1) - expected.ScalarMul(&point2, r2). - Add(&expected, &tmp) - - // populate witness - witness.P1.X = (point1.X.String()) - witness.P1.Y = (point1.Y.String()) - witness.P2.X = (point2.X.String()) - witness.P2.Y = (point2.Y.String()) - witness.E.X = (expected.X.String()) - witness.E.Y = (expected.Y.String()) - witness.S1 = (r1) - witness.S2 = (r2) - case ecc.BW6_761: - var base, point1, point2, tmp, expected tbw6761.PointAffine - base.X.SetBigInt(¶ms.Base.X) - base.Y.SetBigInt(¶ms.Base.Y) - s1 := big.NewInt(902) - s2 := big.NewInt(891) - point1.ScalarMul(&base, s1) // random point - point2.ScalarMul(&base, s2) // random point - r1 := big.NewInt(230928303) - r2 := big.NewInt(2830309) - tmp.ScalarMul(&point1, r1) - expected.ScalarMul(&point2, r2). - Add(&expected, &tmp) - - // populate witness - witness.P1.X = (point1.X.String()) - witness.P1.Y = (point1.Y.String()) - witness.P2.X = (point2.X.String()) - witness.P2.Y = (point2.Y.String()) - witness.E.X = (expected.X.String()) - witness.E.Y = (expected.Y.String()) - witness.S1 = (r1) - witness.S2 = (r2) - case ecc.BW6_633: - var base, point1, point2, tmp, expected tbw6633.PointAffine - base.X.SetBigInt(¶ms.Base.X) - base.Y.SetBigInt(¶ms.Base.Y) - s1 := big.NewInt(902) - s2 := big.NewInt(891) - point1.ScalarMul(&base, s1) // random point - point2.ScalarMul(&base, s2) // random point - r1 := big.NewInt(230928303) - r2 := big.NewInt(2830309) - tmp.ScalarMul(&point1, r1) - expected.ScalarMul(&point2, r2). - Add(&expected, &tmp) - - // populate witness - witness.P1.X = (point1.X.String()) - witness.P1.Y = (point1.Y.String()) - witness.P2.X = (point2.X.String()) - witness.P2.Y = (point2.Y.String()) - witness.E.X = (expected.X.String()) - witness.E.Y = (expected.Y.String()) - witness.S1 = (r1) - witness.S2 = (r2) - } - - // creates r1cs - assert.SolvingSucceeded(&circuit, &witness, test.WithCurves(id)) - } -} - -type neg struct { - P, E Point -} - -func (circuit *neg) Define(api frontend.API) error { - - circuit.P.Neg(api, &circuit.P) - api.AssertIsEqual(circuit.P.X, circuit.E.X) - api.AssertIsEqual(circuit.P.Y, circuit.E.Y) - - return nil -} - -func TestNeg(t *testing.T) { - - assert := test.NewAssert(t) - - // generate witness data - params, err := NewEdCurve(ecc.BN254) - if err != nil { - t.Fatal(err) - } - var base, expected tbn254.PointAffine - base.X.SetBigInt(¶ms.Base.X) - base.Y.SetBigInt(¶ms.Base.Y) - expected.Neg(&base) - - // generate witness - var circuit, witness neg - witness.P.X = (base.X) - witness.P.Y = (base.Y) - witness.E.X = (expected.X) - witness.E.Y = (expected.Y) - - assert.SolvingSucceeded(&circuit, &witness, test.WithCurves(ecc.BN254)) - -} diff --git a/std/algebra/twistededwards/scalarmul_glv.go b/std/algebra/twistededwards/scalarmul_glv.go new file mode 100644 index 0000000000..193a1f72e1 --- /dev/null +++ b/std/algebra/twistededwards/scalarmul_glv.go @@ -0,0 +1,135 @@ +/* +Copyright © 2022 ConsenSys Software Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package twistededwards + +import ( + "errors" + "math/big" + "sync" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/backend/hint" + "github.com/consensys/gnark/frontend" +) + +// phi endomorphism √-2 ∈ 𝒪₋₈ +// (x,y) → λ × (x,y) s.t. λ² = -2 mod Order +func (p *Point) phi(api frontend.API, p1 *Point, curve *CurveParams, endo *EndoParams) *Point { + + xy := api.Mul(p1.X, p1.Y) + yy := api.Mul(p1.Y, p1.Y) + f := api.Sub(1, yy) + f = api.Mul(f, endo.Endo[1]) + g := api.Add(yy, endo.Endo[0]) + g = api.Mul(g, endo.Endo[0]) + h := api.Sub(yy, endo.Endo[0]) + + p.X = api.DivUnchecked(f, xy) + p.Y = api.DivUnchecked(g, h) + + return p +} + +type glvParams struct { + lambda, order big.Int + glvBasis ecc.Lattice +} + +var DecomposeScalar = func(curve ecc.ID, inputs []*big.Int, res []*big.Int) error { + // the efficient endomorphism exists on Bandersnatch only + if curve != ecc.BLS12_381 { + return errors.New("no efficient endomorphism is available on this curve") + } + var glv glvParams + var init sync.Once + init.Do(func() { + glv.lambda.SetString("8913659658109529928382530854484400854125314752504019737736543920008458395397", 10) + glv.order.SetString("13108968793781547619861935127046491459309155893440570251786403306729687672801", 10) + ecc.PrecomputeLattice(&glv.order, &glv.lambda, &glv.glvBasis) + }) + + // sp[0] is always negative because, in SplitScalar(), we always round above + // the determinant/2 computed in PrecomputeLattice() which is negative for Bandersnatch. + // Thus taking -sp[0] here and negating the point in ScalarMul(). + // If we keep -sp[0] it will be reduced mod r (the BLS12-381 prime order) + // and not the Bandersnatch prime order (Order) and the result will be incorrect. + // Also, if we reduce it mod Order here, we can't use api.ToBinary(sp[0], 129) + // and hence we can't reduce optimally the number of constraints. + sp := ecc.SplitScalar(inputs[0], &glv.glvBasis) + res[0].Neg(&(sp[0])) + res[1].Set(&(sp[1])) + + // figure out how many times we have overflowed + res[2].Mul(res[1], &glv.lambda).Sub(res[2], res[0]) + res[2].Sub(res[2], inputs[0]) + res[2].Div(res[2], &glv.order) + + return nil +} + +func init() { + hint.Register(DecomposeScalar) +} + +// ScalarMul computes the scalar multiplication of a point on a twisted Edwards curve +// p1: base point (as snark point) +// curve: parameters of the Edwards curve +// scal: scalar as a SNARK constraint +// Standard left to right double and add +func (p *Point) scalarMulGLV(api frontend.API, p1 *Point, scalar frontend.Variable, curve *CurveParams, endo *EndoParams) *Point { + // the hints allow to decompose the scalar s into s1 and s2 such that + // s1 + λ * s2 == s mod Order, + // with λ s.t. λ² = -2 mod Order. + sd, err := api.NewHint(DecomposeScalar, 3, scalar) + if err != nil { + // err is non-nil only for invalid number of inputs + panic(err) + } + + s1, s2 := sd[0], sd[1] + + // -s1 + λ * s2 == s + k*Order + api.AssertIsEqual(api.Sub(api.Mul(s2, endo.Lambda), s1), api.Add(scalar, api.Mul(curve.Order, sd[2]))) + + // Normally s1 and s2 are of the max size sqrt(Order) = 128 + // But in a circuit, we force s1 to be negative by rounding always above. + // This changes the size bounds to 2*sqrt(Order) = 129. + n := 129 + + b1 := api.ToBinary(s1, n) + b2 := api.ToBinary(s2, n) + + var res, _p1, p2, p3, tmp Point + _p1.neg(api, p1) + p2.phi(api, p1, curve, endo) + p3.add(api, &_p1, &p2, curve) + + res.X = api.Lookup2(b1[n-1], b2[n-1], 0, _p1.X, p2.X, p3.X) + res.Y = api.Lookup2(b1[n-1], b2[n-1], 1, _p1.Y, p2.Y, p3.Y) + + for i := n - 2; i >= 0; i-- { + res.double(api, &res, curve) + tmp.X = api.Lookup2(b1[i], b2[i], 0, _p1.X, p2.X, p3.X) + tmp.Y = api.Lookup2(b1[i], b2[i], 1, _p1.Y, p2.Y, p3.Y) + res.add(api, &res, &tmp, curve) + } + + p.X = res.X + p.Y = res.Y + + return p +} diff --git a/std/algebra/twistededwards/twistededwards.go b/std/algebra/twistededwards/twistededwards.go new file mode 100644 index 0000000000..70828a1cb4 --- /dev/null +++ b/std/algebra/twistededwards/twistededwards.go @@ -0,0 +1,254 @@ +/* +Copyright © 2020 ConsenSys + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package twistededwards + +import ( + "errors" + "math/big" + + "github.com/consensys/gnark-crypto/ecc" + edbls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377/twistededwards" + edbls12381_bandersnatch "github.com/consensys/gnark-crypto/ecc/bls12-381/bandersnatch" + edbls12381 "github.com/consensys/gnark-crypto/ecc/bls12-381/twistededwards" + edbls24315 "github.com/consensys/gnark-crypto/ecc/bls24-315/twistededwards" + edbn254 "github.com/consensys/gnark-crypto/ecc/bn254/twistededwards" + edbw6633 "github.com/consensys/gnark-crypto/ecc/bw6-633/twistededwards" + edbw6761 "github.com/consensys/gnark-crypto/ecc/bw6-761/twistededwards" + "github.com/consensys/gnark-crypto/ecc/twistededwards" + "github.com/consensys/gnark/frontend" +) + +// Curve methods implemented by a twisted edwards curve inside a circuit +type Curve interface { + Params() *CurveParams + Endo() *EndoParams + Add(p1, p2 Point) Point + Double(p1 Point) Point + Neg(p1 Point) Point + AssertIsOnCurve(p1 Point) + ScalarMul(p1 Point, scalar frontend.Variable) Point + DoubleBaseScalarMul(p1, p2 Point, s1, s2 frontend.Variable) Point + API() frontend.API +} + +// Point represent a pair of X, Y coordinates inside a circuit +type Point struct { + X, Y frontend.Variable +} + +// CurveParams twisted edwards curve parameters ax^2 + y^2 = 1 + d*x^2*y^2 +// Matches gnark-crypto curve specific params +type CurveParams struct { + A, D, Cofactor, Order *big.Int + Base [2]*big.Int // base point coordinates +} + +// EndoParams endomorphism parameters for the curve, if they exist +type EndoParams struct { + Endo [2]*big.Int + Lambda *big.Int +} + +// NewEdCurve returns a new Edwards curve +func NewEdCurve(api frontend.API, id twistededwards.ID) (Curve, error) { + snarkCurve, err := GetSnarkCurve(id) + if err != nil { + return nil, err + } + if api.Curve() != snarkCurve { + return nil, errors.New("invalid curve pair; snark field doesn't match twisted edwards field") + } + params, err := GetCurveParams(id) + if err != nil { + return nil, err + } + var endo *EndoParams + + // bandersnatch + if id == twistededwards.BLS12_381_BANDERSNATCH { + endo = &EndoParams{ + Endo: [2]*big.Int{new(big.Int), new(big.Int)}, + Lambda: new(big.Int), + } + endo.Endo[0].SetString("37446463827641770816307242315180085052603635617490163568005256780843403514036", 10) + endo.Endo[1].SetString("49199877423542878313146170939139662862850515542392585932876811575731455068989", 10) + endo.Lambda.SetString("8913659658109529928382530854484400854125314752504019737736543920008458395397", 10) + } + + // default + return &curve{api: api, params: params, endo: endo, id: id}, nil +} + +func GetCurveParams(id twistededwards.ID) (*CurveParams, error) { + var params *CurveParams + switch id { + case twistededwards.BN254: + params = newEdBN254() + case twistededwards.BLS12_377: + params = newEdBLS12_377() + case twistededwards.BLS12_381: + params = newEdBLS12_381() + case twistededwards.BLS12_381_BANDERSNATCH: + params = newEdBLS12_381_BANDERSNATCH() + case twistededwards.BLS24_315: + params = newEdBLS24_315() + case twistededwards.BW6_761: + params = newEdBW6_761() + case twistededwards.BW6_633: + params = newEdBW6_633() + default: + return nil, errors.New("unknown twisted edwards curve id") + } + return params, nil +} + +// GetSnarkCurve returns the matching snark curve for a twisted edwards curve +func GetSnarkCurve(id twistededwards.ID) (ecc.ID, error) { + switch id { + case twistededwards.BN254: + return ecc.BN254, nil + case twistededwards.BLS12_377: + return ecc.BLS12_377, nil + case twistededwards.BLS12_381, twistededwards.BLS12_381_BANDERSNATCH: + return ecc.BLS12_381, nil + case twistededwards.BLS24_315: + return ecc.BLS24_315, nil + case twistededwards.BW6_761: + return ecc.BW6_761, nil + case twistededwards.BW6_633: + return ecc.BW6_633, nil + default: + return ecc.UNKNOWN, errors.New("unknown twisted edwards curve id") + } +} + +// ------------------------------------------------------------------------------------------------- +// constructors + +func newCurveParams() *CurveParams { + return &CurveParams{ + A: new(big.Int), + D: new(big.Int), + Cofactor: new(big.Int), + Order: new(big.Int), + Base: [2]*big.Int{new(big.Int), new(big.Int)}, + } +} + +func newEdBN254() *CurveParams { + + edcurve := edbn254.GetEdwardsCurve() + r := newCurveParams() + edcurve.A.ToBigIntRegular(r.A) + edcurve.D.ToBigIntRegular(r.D) + edcurve.Cofactor.ToBigIntRegular(r.Cofactor) + r.Order.Set(&edcurve.Order) + edcurve.Base.X.ToBigIntRegular(r.Base[0]) + edcurve.Base.Y.ToBigIntRegular(r.Base[1]) + return r + +} + +func newEdBLS12_381() *CurveParams { + + edcurve := edbls12381.GetEdwardsCurve() + + r := newCurveParams() + edcurve.A.ToBigIntRegular(r.A) + edcurve.D.ToBigIntRegular(r.D) + edcurve.Cofactor.ToBigIntRegular(r.Cofactor) + r.Order.Set(&edcurve.Order) + edcurve.Base.X.ToBigIntRegular(r.Base[0]) + edcurve.Base.Y.ToBigIntRegular(r.Base[1]) + return r + +} + +func newEdBLS12_381_BANDERSNATCH() *CurveParams { + + edcurve := edbls12381_bandersnatch.GetEdwardsCurve() + + r := newCurveParams() + edcurve.A.ToBigIntRegular(r.A) + edcurve.D.ToBigIntRegular(r.D) + edcurve.Cofactor.ToBigIntRegular(r.Cofactor) + r.Order.Set(&edcurve.Order) + edcurve.Base.X.ToBigIntRegular(r.Base[0]) + edcurve.Base.Y.ToBigIntRegular(r.Base[1]) + return r + +} + +func newEdBLS12_377() *CurveParams { + + edcurve := edbls12377.GetEdwardsCurve() + + r := newCurveParams() + edcurve.A.ToBigIntRegular(r.A) + edcurve.D.ToBigIntRegular(r.D) + edcurve.Cofactor.ToBigIntRegular(r.Cofactor) + r.Order.Set(&edcurve.Order) + edcurve.Base.X.ToBigIntRegular(r.Base[0]) + edcurve.Base.Y.ToBigIntRegular(r.Base[1]) + return r + +} + +func newEdBW6_633() *CurveParams { + + edcurve := edbw6633.GetEdwardsCurve() + + r := newCurveParams() + edcurve.A.ToBigIntRegular(r.A) + edcurve.D.ToBigIntRegular(r.D) + edcurve.Cofactor.ToBigIntRegular(r.Cofactor) + r.Order.Set(&edcurve.Order) + edcurve.Base.X.ToBigIntRegular(r.Base[0]) + edcurve.Base.Y.ToBigIntRegular(r.Base[1]) + return r + +} + +func newEdBW6_761() *CurveParams { + + edcurve := edbw6761.GetEdwardsCurve() + + r := newCurveParams() + edcurve.A.ToBigIntRegular(r.A) + edcurve.D.ToBigIntRegular(r.D) + edcurve.Cofactor.ToBigIntRegular(r.Cofactor) + r.Order.Set(&edcurve.Order) + edcurve.Base.X.ToBigIntRegular(r.Base[0]) + edcurve.Base.Y.ToBigIntRegular(r.Base[1]) + return r + +} + +func newEdBLS24_315() *CurveParams { + + edcurve := edbls24315.GetEdwardsCurve() + + r := newCurveParams() + edcurve.A.ToBigIntRegular(r.A) + edcurve.D.ToBigIntRegular(r.D) + edcurve.Cofactor.ToBigIntRegular(r.Cofactor) + r.Order.Set(&edcurve.Order) + edcurve.Base.X.ToBigIntRegular(r.Base[0]) + edcurve.Base.Y.ToBigIntRegular(r.Base[1]) + return r + +} diff --git a/std/signature/eddsa/eddsa.go b/std/signature/eddsa/eddsa.go index c266b33a7e..cf0e56b99a 100644 --- a/std/signature/eddsa/eddsa.go +++ b/std/signature/eddsa/eddsa.go @@ -18,15 +18,26 @@ limitations under the License. package eddsa import ( + "errors" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/logger" + "github.com/consensys/gnark/std/hash" + "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/std/algebra/twistededwards" - "github.com/consensys/gnark/std/hash/mimc" + + edwardsbls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377/twistededwards" + edwardsbls12381 "github.com/consensys/gnark-crypto/ecc/bls12-381/twistededwards" + edwardsbls24315 "github.com/consensys/gnark-crypto/ecc/bls24-315/twistededwards" + edwardsbn254 "github.com/consensys/gnark-crypto/ecc/bn254/twistededwards" + edwardsbw6633 "github.com/consensys/gnark-crypto/ecc/bw6-633/twistededwards" + edwardsbw6761 "github.com/consensys/gnark-crypto/ecc/bw6-761/twistededwards" ) // PublicKey stores an eddsa public key (to be used in gnark circuit) type PublicKey struct { - A twistededwards.Point - Curve twistededwards.EdCurve + A twistededwards.Point } // Signature stores a signature (to be used in gnark circuit) @@ -40,53 +51,203 @@ type Signature struct { S frontend.Variable } -// Verify verifies an eddsa signature +// Verify verifies an eddsa signature using MiMC hash function // cf https://en.wikipedia.org/wiki/EdDSA -func Verify(api frontend.API, sig Signature, msg frontend.Variable, pubKey PublicKey) error { - - // compute H(R, A, M), all parameters in data are in Montgomery form - data := []frontend.Variable{ - sig.R.X, - sig.R.Y, - pubKey.A.X, - pubKey.A.Y, - msg, - } - - hash, err := mimc.NewMiMC(api) - if err != nil { - return err +func Verify(curve twistededwards.Curve, sig Signature, msg frontend.Variable, pubKey PublicKey, hash hash.Hash) error { + + // compute H(R, A, M) + hash.Write(sig.R.X) + hash.Write(sig.R.Y) + hash.Write(pubKey.A.X) + hash.Write(pubKey.A.Y) + hash.Write(msg) + hRAM := hash.Sum() + + base := twistededwards.Point{ + X: curve.Params().Base[0], + Y: curve.Params().Base[1], } - hash.Write(data...) - hramConstant := hash.Sum() - - base := twistededwards.Point{} - base.X = pubKey.Curve.Base.X - base.Y = pubKey.Curve.Base.Y //[S]G-[H(R,A,M)]*A - cofactor := pubKey.Curve.Cofactor.Uint64() - Q := twistededwards.Point{} - _A := twistededwards.Point{} - _A.Neg(api, &pubKey.A) - Q.DoubleBaseScalarMul(api, &base, &_A, sig.S, hramConstant, pubKey.Curve) - Q.MustBeOnCurve(api, pubKey.Curve) + _A := curve.Neg(pubKey.A) + Q := curve.DoubleBaseScalarMul(base, _A, sig.S, hRAM) + curve.AssertIsOnCurve(Q) //[S]G-[H(R,A,M)]*A-R - Q.Neg(api, &Q).Add(api, &Q, &sig.R, pubKey.Curve) + Q = curve.Add(curve.Neg(Q), sig.R) // [cofactor]*(lhs-rhs) + log := logger.Logger() + if !curve.Params().Cofactor.IsUint64() { + err := errors.New("invalid cofactor") + log.Err(err).Str("cofactor", curve.Params().Cofactor.String()).Send() + return err + } + cofactor := curve.Params().Cofactor.Uint64() switch cofactor { case 4: - Q.Double(api, &Q, pubKey.Curve). - Double(api, &Q, pubKey.Curve) + Q = curve.Double(curve.Double(Q)) case 8: - Q.Double(api, &Q, pubKey.Curve). - Double(api, &Q, pubKey.Curve).Double(api, &Q, pubKey.Curve) + Q = curve.Double(curve.Double(curve.Double(Q))) + default: + log.Warn().Str("cofactor", curve.Params().Cofactor.String()).Msg("curve cofactor is not implemented") } - api.AssertIsEqual(Q.X, 0) - api.AssertIsEqual(Q.Y, 1) + curve.API().AssertIsEqual(Q.X, 0) + curve.API().AssertIsEqual(Q.Y, 1) return nil } + +// Assign is a helper to assigned a compressed binary public key representation into its uncompressed form +func (p *PublicKey) Assign(curveID ecc.ID, buf []byte) { + ax, ay, err := parsePoint(curveID, buf) + if err != nil { + panic(err) + } + p.A.X = ax + p.A.Y = ay +} + +// Assign is a helper to assigned a compressed binary signature representation into its uncompressed form +func (s *Signature) Assign(curveID ecc.ID, buf []byte) { + rx, ry, S, err := parseSignature(curveID, buf) + if err != nil { + panic(err) + } + s.R.X = rx + s.R.Y = ry + s.S = S +} + +// parseSignature parses a compressed binary signature into uncompressed R.X, R.Y and S +func parseSignature(curveID ecc.ID, buf []byte) ([]byte, []byte, []byte, error) { + + var pointbn254 edwardsbn254.PointAffine + var pointbls12381 edwardsbls12381.PointAffine + var pointbls12377 edwardsbls12377.PointAffine + var pointbw6761 edwardsbw6761.PointAffine + var pointbls24315 edwardsbls24315.PointAffine + var pointbw6633 edwardsbw6633.PointAffine + + switch curveID { + case ecc.BN254: + if _, err := pointbn254.SetBytes(buf[:32]); err != nil { + return nil, nil, nil, err + } + a, b, err := parsePoint(curveID, buf) + if err != nil { + return nil, nil, nil, err + } + s := buf[32:] + return a, b, s, nil + case ecc.BLS12_381: + if _, err := pointbls12381.SetBytes(buf[:32]); err != nil { + return nil, nil, nil, err + } + a, b, err := parsePoint(curveID, buf) + if err != nil { + return nil, nil, nil, err + } + s := buf[32:] + return a, b, s, nil + case ecc.BLS12_377: + if _, err := pointbls12377.SetBytes(buf[:32]); err != nil { + return nil, nil, nil, err + } + a, b, err := parsePoint(curveID, buf) + if err != nil { + return nil, nil, nil, err + } + s := buf[32:] + return a, b, s, nil + case ecc.BW6_761: + if _, err := pointbw6761.SetBytes(buf[:48]); err != nil { + return nil, nil, nil, err + } + a, b, err := parsePoint(curveID, buf) + if err != nil { + return nil, nil, nil, err + } + s := buf[48:] + return a, b, s, nil + case ecc.BLS24_315: + if _, err := pointbls24315.SetBytes(buf[:32]); err != nil { + return nil, nil, nil, err + } + a, b, err := parsePoint(curveID, buf) + if err != nil { + return nil, nil, nil, err + } + s := buf[32:] + return a, b, s, nil + case ecc.BW6_633: + if _, err := pointbw6633.SetBytes(buf[:40]); err != nil { + return nil, nil, nil, err + } + a, b, err := parsePoint(curveID, buf) + if err != nil { + return nil, nil, nil, err + } + s := buf[40:] + return a, b, s, nil + default: + panic("not implemented") + } +} + +// parsePoint parses a compressed binary point into uncompressed P.X and P.Y +func parsePoint(curveID ecc.ID, buf []byte) ([]byte, []byte, error) { + var pointbn254 edwardsbn254.PointAffine + var pointbls12381 edwardsbls12381.PointAffine + var pointbls12377 edwardsbls12377.PointAffine + var pointbw6761 edwardsbw6761.PointAffine + var pointbls24315 edwardsbls24315.PointAffine + var pointbw6633 edwardsbw6633.PointAffine + switch curveID { + case ecc.BN254: + if _, err := pointbn254.SetBytes(buf[:32]); err != nil { + return nil, nil, err + } + a := pointbn254.X.Bytes() + b := pointbn254.Y.Bytes() + return a[:], b[:], nil + case ecc.BLS12_381: + if _, err := pointbls12381.SetBytes(buf[:32]); err != nil { + return nil, nil, err + } + a := pointbls12381.X.Bytes() + b := pointbls12381.Y.Bytes() + return a[:], b[:], nil + case ecc.BLS12_377: + if _, err := pointbls12377.SetBytes(buf[:32]); err != nil { + return nil, nil, err + } + a := pointbls12377.X.Bytes() + b := pointbls12377.Y.Bytes() + return a[:], b[:], nil + case ecc.BW6_761: + if _, err := pointbw6761.SetBytes(buf[:48]); err != nil { + return nil, nil, err + } + a := pointbw6761.X.Bytes() + b := pointbw6761.Y.Bytes() + return a[:], b[:], nil + case ecc.BLS24_315: + if _, err := pointbls24315.SetBytes(buf[:32]); err != nil { + return nil, nil, err + } + a := pointbls24315.X.Bytes() + b := pointbls24315.Y.Bytes() + return a[:], b[:], nil + case ecc.BW6_633: + if _, err := pointbw6633.SetBytes(buf[:40]); err != nil { + return nil, nil, err + } + a := pointbw6633.X.Bytes() + b := pointbw6633.Y.Bytes() + return a[:], b[:], nil + default: + panic("not implemented") + } +} diff --git a/std/signature/eddsa/eddsa_test.go b/std/signature/eddsa/eddsa_test.go index 9bc4265074..574b4ba9d2 100644 --- a/std/signature/eddsa/eddsa_test.go +++ b/std/signature/eddsa/eddsa_test.go @@ -20,235 +20,121 @@ import ( "math/big" "math/rand" "testing" + "time" - "github.com/consensys/gnark-crypto/ecc" - edwardsbls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377/twistededwards" - eddsabls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377/twistededwards/eddsa" - edwardsbls12381 "github.com/consensys/gnark-crypto/ecc/bls12-381/twistededwards" - eddsabls12381 "github.com/consensys/gnark-crypto/ecc/bls12-381/twistededwards/eddsa" - edwardsbls24315 "github.com/consensys/gnark-crypto/ecc/bls24-315/twistededwards" - eddsabls24315 "github.com/consensys/gnark-crypto/ecc/bls24-315/twistededwards/eddsa" - edwardsbn254 "github.com/consensys/gnark-crypto/ecc/bn254/twistededwards" - eddsabn254 "github.com/consensys/gnark-crypto/ecc/bn254/twistededwards/eddsa" - edwardsbw6633 "github.com/consensys/gnark-crypto/ecc/bw6-633/twistededwards" - eddsabw6633 "github.com/consensys/gnark-crypto/ecc/bw6-633/twistededwards/eddsa" - edwardsbw6761 "github.com/consensys/gnark-crypto/ecc/bw6-761/twistededwards" - eddsabw6761 "github.com/consensys/gnark-crypto/ecc/bw6-761/twistededwards/eddsa" + tedwards "github.com/consensys/gnark-crypto/ecc/twistededwards" "github.com/consensys/gnark-crypto/hash" - "github.com/consensys/gnark-crypto/signature" + "github.com/consensys/gnark-crypto/signature/eddsa" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/std/algebra/twistededwards" + "github.com/consensys/gnark/std/hash/mimc" "github.com/consensys/gnark/test" ) type eddsaCircuit struct { + curveID tedwards.ID PublicKey PublicKey `gnark:",public"` Signature Signature `gnark:",public"` Message frontend.Variable `gnark:",public"` } -//func parseSignature(id ecc.ID, buf []byte) ([]byte, []byte, []byte) { -func parseSignature(id ecc.ID, buf []byte) ([]byte, []byte, []byte) { - - var pointbn254 edwardsbn254.PointAffine - var pointbls12381 edwardsbls12381.PointAffine - var pointbls12377 edwardsbls12377.PointAffine - var pointbw6761 edwardsbw6761.PointAffine - var pointbls24315 edwardsbls24315.PointAffine - var pointbw6633 edwardsbw6633.PointAffine - - switch id { - case ecc.BN254: - pointbn254.SetBytes(buf[:32]) - a, b := parsePoint(id, buf) - s := buf[32:] - return a[:], b[:], s - case ecc.BLS12_381: - pointbls12381.SetBytes(buf[:32]) - a, b := parsePoint(id, buf) - s := buf[32:] - return a[:], b[:], s - case ecc.BLS12_377: - pointbls12377.SetBytes(buf[:32]) - a, b := parsePoint(id, buf) - s := buf[32:] - return a[:], b[:], s - case ecc.BW6_761: - pointbw6761.SetBytes(buf[:48]) - a, b := parsePoint(id, buf) - s := buf[48:] - return a[:], b[:], s - case ecc.BLS24_315: - pointbls24315.SetBytes(buf[:32]) - a, b := parsePoint(id, buf) - s := buf[32:] - return a[:], b[:], s - case ecc.BW6_633: - pointbw6633.SetBytes(buf[:40]) - a, b := parsePoint(id, buf) - s := buf[40:] - return a[:], b[:], s - default: - return buf, buf, buf - } -} +func (circuit *eddsaCircuit) Define(api frontend.API) error { -func parsePoint(id ecc.ID, buf []byte) ([]byte, []byte) { - var pointbn254 edwardsbn254.PointAffine - var pointbls12381 edwardsbls12381.PointAffine - var pointbls12377 edwardsbls12377.PointAffine - var pointbw6761 edwardsbw6761.PointAffine - var pointbls24315 edwardsbls24315.PointAffine - var pointbw6633 edwardsbw6633.PointAffine - switch id { - case ecc.BN254: - pointbn254.SetBytes(buf[:32]) - a := pointbn254.X.Bytes() - b := pointbn254.Y.Bytes() - return a[:], b[:] - case ecc.BLS12_381: - pointbls12381.SetBytes(buf[:32]) - a := pointbls12381.X.Bytes() - b := pointbls12381.Y.Bytes() - return a[:], b[:] - case ecc.BLS12_377: - pointbls12377.SetBytes(buf[:32]) - a := pointbls12377.X.Bytes() - b := pointbls12377.Y.Bytes() - return a[:], b[:] - case ecc.BW6_761: - pointbw6761.SetBytes(buf[:48]) - a := pointbw6761.X.Bytes() - b := pointbw6761.Y.Bytes() - return a[:], b[:] - case ecc.BLS24_315: - pointbls24315.SetBytes(buf[:32]) - a := pointbls24315.X.Bytes() - b := pointbls24315.Y.Bytes() - return a[:], b[:] - case ecc.BW6_633: - pointbw6633.SetBytes(buf[:40]) - a := pointbw6633.X.Bytes() - b := pointbw6633.Y.Bytes() - return a[:], b[:] - default: - return buf, buf + curve, err := twistededwards.NewEdCurve(api, circuit.curveID) + if err != nil { + return err } -} -func (circuit *eddsaCircuit) Define(api frontend.API) error { - - params, err := twistededwards.NewEdCurve(api.Compiler().Curve()) + mimc, err := mimc.NewMiMC(api) if err != nil { return err } - circuit.PublicKey.Curve = params // verify the signature in the cs - Verify(api, circuit.Signature, circuit.Message, circuit.PublicKey) - - return nil + return Verify(curve, circuit.Signature, circuit.Message, circuit.PublicKey, &mimc) } func TestEddsa(t *testing.T) { assert := test.NewAssert(t) - type confSig struct { - h hash.Hash - s signature.SignatureScheme + type testData struct { + hash hash.Hash + curve tedwards.ID } - signature.Register(signature.EDDSA_BN254, eddsabn254.GenerateKeyInterfaces) - signature.Register(signature.EDDSA_BLS12_381, eddsabls12381.GenerateKeyInterfaces) - signature.Register(signature.EDDSA_BLS12_377, eddsabls12377.GenerateKeyInterfaces) - signature.Register(signature.EDDSA_BW6_761, eddsabw6761.GenerateKeyInterfaces) - signature.Register(signature.EDDSA_BLS24_315, eddsabls24315.GenerateKeyInterfaces) - signature.Register(signature.EDDSA_BW6_633, eddsabw6633.GenerateKeyInterfaces) - - confs := map[ecc.ID]confSig{ - ecc.BN254: {hash.MIMC_BN254, signature.EDDSA_BN254}, - ecc.BLS12_381: {hash.MIMC_BLS12_381, signature.EDDSA_BLS12_381}, - ecc.BLS12_377: {hash.MIMC_BLS12_377, signature.EDDSA_BLS12_377}, - ecc.BW6_761: {hash.MIMC_BW6_761, signature.EDDSA_BW6_761}, - ecc.BLS24_315: {hash.MIMC_BLS24_315, signature.EDDSA_BLS24_315}, - ecc.BW6_633: {hash.MIMC_BW6_633, signature.EDDSA_BW6_633}, + confs := []testData{ + {hash.MIMC_BN254, tedwards.BN254}, + {hash.MIMC_BLS12_381, tedwards.BLS12_381}, + // {hash.MIMC_BLS12_381, tedwards.BLS12_381_BANDERSNATCH}, + {hash.MIMC_BLS12_377, tedwards.BLS12_377}, + {hash.MIMC_BW6_761, tedwards.BW6_761}, + {hash.MIMC_BLS24_315, tedwards.BLS24_315}, + {hash.MIMC_BW6_633, tedwards.BW6_633}, } - for id, conf := range confs { - - // generate parameters for the signatures - hFunc := conf.h.New() - src := rand.NewSource(0) - r := rand.New(src) - privKey, err := conf.s.New(r) - if err != nil { - t.Fatal(err) - } - pubKey := privKey.Public() - // pick a message to sign - var frMsg big.Int - frMsg.SetString("44717650746155748460101257525078853138837311576962212923649547644148297035978", 10) - msgBin := frMsg.Bytes() + bound := 5 + if testing.Short() { + bound = 1 + } - // generate signature - signature, err := privKey.Sign(msgBin[:], hFunc) - if err != nil { - t.Fatal(err) - } + for i := 0; i < bound; i++ { + seed := time.Now().Unix() + t.Logf("setting seed in rand %d", seed) + randomness := rand.New(rand.NewSource(seed)) - // check if there is no problem in the signature - checkSig, err := pubKey.Verify(signature, msgBin[:], hFunc) - if err != nil { - t.Fatal(err) - } - if !checkSig { - t.Fatal("Unexpected failed signature verification") - } + for _, conf := range confs { - // create and compile the circuit for signature verification - var circuit eddsaCircuit - - // verification with the correct Message - { - var witness eddsaCircuit - witness.Message = frMsg - - pubkeyAx, pubkeyAy := parsePoint(id, pubKey.Bytes()) - var pbAx, pbAy big.Int - pbAx.SetBytes(pubkeyAx) - pbAy.SetBytes(pubkeyAy) - witness.PublicKey.A.X = pubkeyAx - witness.PublicKey.A.Y = pubkeyAy - - // sigRx, sigRy, sigS1, sigS2 := parseSignature(id, signature) - sigRx, sigRy, sigS := parseSignature(id, signature) - witness.Signature.R.X = sigRx - witness.Signature.R.Y = sigRy - // witness.Signature.S1 = sigS1 - // witness.Signature.S2 = sigS2 - witness.Signature.S = sigS - - assert.SolvingSucceeded(&circuit, &witness, test.WithCurves(id)) - } + snarkCurve, err := twistededwards.GetSnarkCurve(conf.curve) + assert.NoError(err) - // verification with incorrect Message - { - var witness eddsaCircuit - witness.Message = "44717650746155748460101257525078853138837311576962212923649547644148297035979" + // generate parameters for the signatures + privKey, err := eddsa.New(conf.curve, randomness) + assert.NoError(err, "generating eddsa key pair") - pubkeyAx, pubkeyAy := parsePoint(id, pubKey.Bytes()) - witness.PublicKey.A.X = pubkeyAx - witness.PublicKey.A.Y = pubkeyAy + // pick a message to sign + var msg big.Int + msg.Rand(randomness, snarkCurve.Info().Fr.Modulus()) + t.Log("msg to sign", msg.String()) + msgData := msg.Bytes() - // sigRx, sigRy, sigS1, sigS2 := parseSignature(id, signature) - sigRx, sigRy, sigS := parseSignature(id, signature) - witness.Signature.R.X = sigRx - witness.Signature.R.Y = sigRy - witness.Signature.S = sigS + // generate signature + signature, err := privKey.Sign(msgData[:], conf.hash.New()) + assert.NoError(err, "signing message") - assert.SolvingFailed(&circuit, &witness, test.WithCurves(id)) - } + // check if there is no problem in the signature + pubKey := privKey.Public() + checkSig, err := pubKey.Verify(signature, msgData[:], conf.hash.New()) + assert.NoError(err, "verifying signature") + assert.True(checkSig, "signature verification failed") + + // create and compile the circuit for signature verification + var circuit eddsaCircuit + circuit.curveID = conf.curve + + // verification with the correct Message + { + var witness eddsaCircuit + witness.Message = msg + witness.PublicKey.Assign(snarkCurve, pubKey.Bytes()) + witness.Signature.Assign(snarkCurve, signature) + assert.SolvingSucceeded(&circuit, &witness, test.WithCurves(snarkCurve)) + } + + // verification with incorrect Message + { + var witness eddsaCircuit + + msg.Rand(randomness, snarkCurve.Info().Fr.Modulus()) + witness.Message = msg + witness.PublicKey.Assign(snarkCurve, pubKey.Bytes()) + witness.Signature.Assign(snarkCurve, signature) + + assert.SolvingFailed(&circuit, &witness, test.WithCurves(snarkCurve)) + } + + } } + } diff --git a/test/assert.go b/test/assert.go index 4c461c1403..9f849cd467 100644 --- a/test/assert.go +++ b/test/assert.go @@ -443,7 +443,7 @@ func (assert *Assert) compile(circuit frontend.Circuit, curveID ecc.ID, backendI return nil, ErrCompilationNotDeterministic } - // add the compiled circuit to the cache + // // add the compiled circuit to the cache assert.compiled[key] = ccs return ccs, nil