diff --git a/blindsign/blindrsa/pss.go b/blindsign/blindrsa/pss.go index dca0abe6a..79ab75a87 100644 --- a/blindsign/blindrsa/pss.go +++ b/blindsign/blindrsa/pss.go @@ -35,14 +35,8 @@ package blindrsa // This file implements the RSASSA-PSS signature scheme according to RFC 8017. import ( - "bytes" - "crypto" - "crypto/rsa" "errors" - "fmt" "hash" - "io" - "math/big" ) // Per RFC 8017, Section 9.1 @@ -132,227 +126,3 @@ func emsaPSSEncode(mHash []byte, emBits int, salt []byte, hash hash.Hash) ([]byt // 13. Output EM. return em, nil } - -func emsaPSSVerify(mHash, em []byte, emBits, sLen int, hash hash.Hash) error { - // See RFC 8017, Section 9.1.2. - - hLen := hash.Size() - if sLen == PSSSaltLengthEqualsHash { - sLen = hLen - } - emLen := (emBits + 7) / 8 - if emLen != len(em) { - return errors.New("rsa: internal error: inconsistent length") - } - - // 1. If the length of M is greater than the input limitation for the - // hash function (2^61 - 1 octets for SHA-1), output "inconsistent" - // and stop. - // - // 2. Let mHash = Hash(M), an octet string of length hLen. - if hLen != len(mHash) { - fmt.Println("here3", hLen, len(mHash)) - return ErrVerification - } - - // 3. If emLen < hLen + sLen + 2, output "inconsistent" and stop. - if emLen < hLen+sLen+2 { - fmt.Println("here2") - return ErrVerification - } - - // 4. If the rightmost octet of EM does not have hexadecimal value - // 0xbc, output "inconsistent" and stop. - if em[emLen-1] != 0xbc { - fmt.Println("here") - return ErrVerification - } - - // 5. Let maskedDB be the leftmost emLen - hLen - 1 octets of EM, and - // let H be the next hLen octets. - db := em[:emLen-hLen-1] - h := em[emLen-hLen-1 : emLen-1] - - // 6. If the leftmost 8 * emLen - emBits bits of the leftmost octet in - // maskedDB are not all equal to zero, output "inconsistent" and - // stop. - var bitMask byte = 0xff >> (8*emLen - emBits) - if em[0] & ^bitMask != 0 { - fmt.Println("here4") - return ErrVerification - } - - // 7. Let dbMask = MGF(H, emLen - hLen - 1). - // - // 8. Let DB = maskedDB \xor dbMask. - mgf1XOR(db, hash, h) - - // 9. Set the leftmost 8 * emLen - emBits bits of the leftmost octet in DB - // to zero. - db[0] &= bitMask - - // If we don't know the salt length, look for the 0x01 delimiter. - if sLen == PSSSaltLengthAuto { - psLen := bytes.IndexByte(db, 0x01) - if psLen < 0 { - fmt.Println("here5") - return ErrVerification - } - sLen = len(db) - psLen - 1 - } - - // 10. If the emLen - hLen - sLen - 2 leftmost octets of DB are not zero - // or if the octet at position emLen - hLen - sLen - 1 (the leftmost - // position is "position 1") does not have hexadecimal value 0x01, - // output "inconsistent" and stop. - psLen := emLen - hLen - sLen - 2 - for _, e := range db[:psLen] { - if e != 0x00 { - fmt.Println("here6") - return ErrVerification - } - } - if db[psLen] != 0x01 { - fmt.Println("here7") - return ErrVerification - } - - // 11. Let salt be the last sLen octets of DB. - salt := db[len(db)-sLen:] - - // 12. Let - // M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt ; - // M' is an octet string of length 8 + hLen + sLen with eight - // initial zero octets. - // - // 13. Let H' = Hash(M'), an octet string of length hLen. - var prefix [8]byte - hash.Write(prefix[:]) - hash.Write(mHash) - hash.Write(salt) - - h0 := hash.Sum(nil) - - // 14. If H = H', output "consistent." Otherwise, output "inconsistent." - if !bytes.Equal(h0, h) { // TODO: constant time? - fmt.Println("here8") - return ErrVerification - } - return nil -} - -// signPSSWithSalt calculates the signature of hashed using PSS with specified salt. -// Note that hashed must be the result of hashing the input message using the -// given hash function. salt is a random sequence of bytes whose length will be -// later used to verify the signature. -func signPSSWithSalt(rand io.Reader, priv *rsa.PrivateKey, hash crypto.Hash, hashed, salt []byte) ([]byte, error) { - emBits := priv.N.BitLen() - 1 - em, err := emsaPSSEncode(hashed, emBits, salt, hash.New()) - if err != nil { - return nil, err - } - m := new(big.Int).SetBytes(em) - c, err := decryptAndCheck(rand, priv, m) - if err != nil { - return nil, err - } - s := make([]byte, priv.Size()) - copyWithLeftPad(s, c.Bytes()) - return s, nil -} - -const ( - // PSSSaltLengthAuto causes the salt in a PSS signature to be as large - // as possible when signing, and to be auto-detected when verifying. - PSSSaltLengthAuto = 0 - // PSSSaltLengthEqualsHash causes the salt length to equal the length - // of the hash used in the signature. - PSSSaltLengthEqualsHash = -1 -) - -// PSSOptions contains options for creating and verifying PSS signatures. -type PSSOptions struct { - // SaltLength controls the length of the salt used in the PSS - // signature. It can either be a number of bytes, or one of the special - // PSSSaltLength constants. - SaltLength int - - // Hash is the hash function used to generate the message digest. If not - // zero, it overrides the hash function passed to SignPSS. It's required - // when using PrivateKey.Sign. - Hash crypto.Hash -} - -// HashFunc returns opts.Hash so that PSSOptions implements crypto.SignerOpts. -func (opts *PSSOptions) HashFunc() crypto.Hash { - return opts.Hash -} - -func (opts *PSSOptions) saltLength() int { - if opts == nil { - return PSSSaltLengthAuto - } - return opts.SaltLength -} - -// SignPSS calculates the signature of digest using PSS. -// -// digest must be the result of hashing the input message using the given hash -// function. The opts argument may be nil, in which case sensible defaults are -// used. If opts.Hash is set, it overrides hash. -func SignPSS(rand io.Reader, priv *rsa.PrivateKey, hash crypto.Hash, digest []byte, opts *PSSOptions) ([]byte, error) { - if opts != nil && opts.Hash != 0 { - hash = opts.Hash - } - - saltLength := opts.saltLength() - switch saltLength { - case PSSSaltLengthAuto: - saltLength = priv.Size() - 2 - hash.Size() - case PSSSaltLengthEqualsHash: - saltLength = hash.Size() - } - - salt := make([]byte, saltLength) - if _, err := io.ReadFull(rand, salt); err != nil { - return nil, err - } - return signPSSWithSalt(rand, priv, hash, digest, salt) -} - -// VerifyPSS verifies a PSS signature. -// -// A valid signature is indicated by returning a nil error. digest must be the -// result of hashing the input message using the given hash function. The opts -// argument may be nil, in which case sensible defaults are used. opts.Hash is -// ignored. -func VerifyPSS(pub *rsa.PublicKey, hash hash.Hash, digest []byte, sig []byte, opts *PSSOptions) error { - if len(sig) != pub.Size() { - fmt.Println("1") - return ErrVerification - } - s := new(big.Int).SetBytes(sig) - m := encrypt(new(big.Int), pub, s) - emBits := pub.N.BitLen() - 1 - emLen := (emBits + 7) / 8 - emBytes := m.Bytes() - if m.BitLen() > emLen*8 { - fmt.Println("2") - return ErrVerification - } - - em := make([]byte, emLen) - copyWithLeftPad(em, emBytes) - - return emsaPSSVerify(digest, em, emBits, opts.saltLength(), hash) -} - -// copyWithLeftPad copies src to the end of dest, padding with zero bytes as -// needed. -func copyWithLeftPad(dest, src []byte) { - numPaddingBytes := len(dest) - len(src) - for i := 0; i < numPaddingBytes; i++ { - dest[i] = 0 - } - copy(dest[numPaddingBytes:], src) -} diff --git a/blindsign/blindrsa/rsa.go b/blindsign/blindrsa/rsa.go index 2f825a023..648d2730b 100644 --- a/blindsign/blindrsa/rsa.go +++ b/blindsign/blindrsa/rsa.go @@ -42,10 +42,6 @@ var ( bigOne = big.NewInt(1) ) -// ErrVerification represents a failure to verify a signature. -// It is deliberately vague to avoid adaptive attacks. -var ErrVerification = errors.New("crypto/rsa: verification error") - // incCounter increments a four byte, big-endian counter. func incCounter(c *[4]byte) { if c[3]++; c[3] != 0 { @@ -87,12 +83,17 @@ func encrypt(c *big.Int, pub *rsa.PublicKey, m *big.Int) *big.Int { return c } +// decrypt performs an RSA decryption, resulting in a plaintext integer. If a +// random source is given, RSA blinding is used. func decrypt(random io.Reader, priv *rsa.PrivateKey, c *big.Int) (m *big.Int, err error) { // TODO(agl): can we get away with reusing blinds? if c.Cmp(priv.N) > 0 { err = rsa.ErrDecryption return } + if priv.N.Sign() == 0 { + return nil, rsa.ErrDecryption + } var ir *big.Int if random != nil { @@ -102,7 +103,7 @@ func decrypt(random io.Reader, priv *rsa.PrivateKey, c *big.Int) (m *big.Int, er // by multiplying by the multiplicative inverse of r. var r *big.Int - + ir = new(big.Int) for { r, err = rand.Int(random, priv.N) if err != nil { @@ -111,13 +112,13 @@ func decrypt(random io.Reader, priv *rsa.PrivateKey, c *big.Int) (m *big.Int, er if r.Cmp(bigZero) == 0 { r = bigOne } - ir = new(big.Int).ModInverse(r, priv.N) - if ir != nil { + ok := ir.ModInverse(r, priv.N) + if ok != nil { break } } bigE := big.NewInt(int64(priv.E)) - rpowe := new(big.Int).Exp(r, bigE, priv.N) + rpowe := new(big.Int).Exp(r, bigE, priv.N) // N != 0 cCopy := new(big.Int).Set(c) cCopy.Mul(cCopy, rpowe) cCopy.Mod(cCopy, priv.N)