From 6a9b04ca281a26d1c7fcfb5c575b69469b75955a Mon Sep 17 00:00:00 2001 From: Kimi Wu Date: Fri, 27 Oct 2023 10:47:57 +0800 Subject: [PATCH] Feat/#318 precompile ecrecover (#495) * doc: ecRecover.md spec file * doc: signature circuit, copied from Sroll's design and revisted to fit our architecture * feat: impl. sig_circuit * test: add more cases * feat: add sig_table * doc: complete constraints desc. * feat: impl. ecRecover * test: add a normal case * test: complete testing * feat: remove rlc usage in sig circuit * feat: correct public key to little-endian * fix: is_success is always true and using iz_zero gadget for sig_r/v * test: fix testing data * fix return data length when the addr is not recoverable * doc: refinement --- specs/precompile/01ecRecover.md | 33 ++++ specs/sig-proof.md | 51 ++++++ specs/tables.md | 14 ++ src/zkevm_specs/evm_circuit/__init__.py | 1 + .../evm_circuit/execution/__init__.py | 3 +- .../execution/precompiles/ecrecover.py | 65 +++++++ src/zkevm_specs/evm_circuit/instruction.py | 11 ++ src/zkevm_specs/evm_circuit/table.py | 33 ++++ src/zkevm_specs/sig_circuit.py | 122 ++++++++++++ src/zkevm_specs/tx_circuit.py | 2 +- src/zkevm_specs/util/__init__.py | 2 + src/zkevm_specs/util/ec.py | 113 ++++++++++++ src/zkevm_specs/util/tables.py | 33 ++++ tests/evm/precompiles/test_ecRecover.py | 173 ++++++++++++++++++ tests/test_sig_circuit.py | 164 +++++++++++++++++ 15 files changed, 818 insertions(+), 2 deletions(-) create mode 100644 specs/precompile/01ecRecover.md create mode 100644 specs/sig-proof.md create mode 100644 src/zkevm_specs/evm_circuit/execution/precompiles/ecrecover.py create mode 100644 src/zkevm_specs/sig_circuit.py create mode 100644 src/zkevm_specs/util/ec.py create mode 100644 src/zkevm_specs/util/tables.py create mode 100644 tests/evm/precompiles/test_ecRecover.py create mode 100644 tests/test_sig_circuit.py diff --git a/specs/precompile/01ecRecover.md b/specs/precompile/01ecRecover.md new file mode 100644 index 000000000..e65869419 --- /dev/null +++ b/specs/precompile/01ecRecover.md @@ -0,0 +1,33 @@ +# ecRecover precompile + +## Procedure + +To recover the signer from a signature. It returns signer's address if input signature is valid, otherwise returns 0. + +## EVM behavior + +### Inputs + +The length of inputs is 128 bytes. The first 32 bytes is keccak hash of the message, and the following 96 bytes are v, r, s values. The value v is either 27 or 28. + +### Output + +The recovered 20-byte address right aligned to 32 byte. If an address can't be recovered or not enough gas was given, then the output is 0. + +### Gas cost + +A constant gas cost: 3000 + +## Constraints + +1. If gas_left < gas_required, then is_success == false and return data is zero. +1. v, r and s are valid + - v is 27 or 28 + - both of r and s are less than `secp256k1N (0xfffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141)` + - both of r and s are greater than `1` +2. `sig_table` lookups +3. recovered address is zero if the signature can't be recovered. + +## Code + +Please refer to `src/zkevm_specs/evm_circuit/execution/precompiles/ecrecover.py`. diff --git a/specs/sig-proof.md b/specs/sig-proof.md new file mode 100644 index 000000000..183933e60 --- /dev/null +++ b/specs/sig-proof.md @@ -0,0 +1,51 @@ +# Signature Proof + +[Elliptic Curve Digital Signature Algorithm]: https://en.wikipedia.org/wiki/Elliptic_Curve_Digital_Signature_Algorithm + +According to the [Elliptic Curve Digital Signature Algorithm] (ECDSA), the signatures `(r,s)` are calculated via ECDSA from `msg_hash` and a `public_key` using the formula + +`(r,s)=ecdsa(msg_hash, public_key)` + +The `public_key` is obtained from `private_key` by mapping the latter to an elliptic curve (EC) point. The `r` is the x-component of an EC point, and the same EC point's y-component will be used to determine the recovery id `v = y%2` (the parity of y). Given the signature `(v, r, s)`, the `public_key` can be recovered from `(v, r, s)` and `msg_hash` using `ecrecover`. + + +## Circuit behavior + +SigTable built inside zkevm-circuits is used to verify signatures. It has the following columns: +- `msg_hash`: Advice Column, the Keccak256 hash of the message that's signed; +- `sig_v`: Advice Column, the recovery id, either 0 or 1, it should be the parity of y; +- `sig_r`: Advice Column, the signature's `r` component; +- `sig_s`: Advice Column, the signature's `s` component; +- `recovered_addr`: Advice Column, the recovered address, i.e. the 20-bytes address that must have signed the message; +- `is_valid`: Advice Column, indicates whether or not the signature is valid or not upon signature verification. + +Constraints on the shape of the table is like: + +| 0 msg_hash | 1 sig_v | 2 sig_r | 3 sig_s | 4 recovered_addr | 5 is_valid | +| ------------- | ------ | ------------- | ------------- | ---------------- | ---------- | +| $value{Lo,Hi} | 0/1 | $value{Lo,Hi} | $value{Lo,Hi} | $value{Lo,Hi} | bool | + + +The Sig Circuit aims at proving the correctness of SigTable. This mainly includes the following type of constraints: +- Checking that the signature is obtained correctly. This is done by the ECDSA chip, and the correctness of `v` is checked separately; +- Checking that `msg_hash` is obtained correctly from Keccak hash function. This is done by lookup to Keccak table; + + +## Constraints + +`assign_ecdsa` method takes the signature data and uses ECDSA chip to verify its correctness. The verification result `sig_is_valid` will be returned. The recovery id `v` value will be computed and verified. + +`sign_data_decomposition` method takes the signature data and the return values of `assign_ecdsa`, and returns the cells for byte decomposition of the keys and messages in the form of `SignDataDecomposed`. The latter consists of the following contents: +- `SignDataDecomposed` + - `pk_hash_cells`: byte cells for keccak256 hash of public key; + - `msg_hash_cells`: byte cells for `msg_hash`; + - `pk_cells`: byte cells for the EC coordinates of public key; + - `address`: RLC of `pk_hash` last 20 bytes; + - `is_address_zero`: check if address is zero; + - `r_cells`, `s_cells`: byte cells for signatures `r` and `s`. + +The decomposed sign data are sent to `assign_sign_verify` method to compute and verify their RLC values and perform Keccak lookup checks. + +## Code + +Please refer to `src/zkevm-specs/sig_circuit.py` \ No newline at end of file diff --git a/specs/tables.md b/specs/tables.md index 17077b3e8..677d68a8f 100644 --- a/specs/tables.md +++ b/specs/tables.md @@ -365,3 +365,17 @@ Row(is_step=1, identifier=rwc, is_last=0, base_limbs=[3, 0, 0, 0], exponent_lo_h ``` Row(is_step=1, identifier=rwc, is_last=1, base_limbs=[3, 0, 0, 0], exponent_lo_hi=[2, 0], exponentiation_lo_hi=[9, 0]) ``` + + +## `sig_table` + +Provided by the Signature circuit. + +The circuit verifies the correctness of signatures. + +| 0 msg_hash | 1 sig_v | 2 sig_r | 3 sig_s | 4 recovered_addr | 5 is_valid | +| ------------- | ------ | ------------- | ------------- | ---------------- | ---------- | +| $value{Lo,Hi} | 0/1 | $value{Lo,Hi} | $value{Lo,Hi} | $value{Lo,Hi} | bool | + +NOTE: +- `sig_v` is either 0 or 1 so boolean type is used here. diff --git a/src/zkevm_specs/evm_circuit/__init__.py b/src/zkevm_specs/evm_circuit/__init__.py index ee05962d0..d445ee9a1 100644 --- a/src/zkevm_specs/evm_circuit/__init__.py +++ b/src/zkevm_specs/evm_circuit/__init__.py @@ -7,3 +7,4 @@ from .table import * from .typing import * from .util import * +from .precompile import * diff --git a/src/zkevm_specs/evm_circuit/execution/__init__.py b/src/zkevm_specs/evm_circuit/execution/__init__.py index 1987c48d0..2523d9a8d 100644 --- a/src/zkevm_specs/evm_circuit/execution/__init__.py +++ b/src/zkevm_specs/evm_circuit/execution/__init__.py @@ -76,6 +76,7 @@ from .error_oog_static_memory_expansion import * from .error_oog_sload_sstore import * from .error_oog_create import * +from .precompiles.ecrecover import * EXECUTION_STATE_IMPL: Dict[ExecutionState, Callable] = { @@ -152,7 +153,7 @@ ExecutionState.ErrorOutOfGasSloadSstore: error_oog_sload_sstore, ExecutionState.ErrorReturnDataOutOfBound: error_return_data_out_of_bound, ExecutionState.ErrorOutOfGasCREATE: error_oog_create, - # ExecutionState.ECRECOVER: , + ExecutionState.ECRECOVER: ecRecover, # ExecutionState.SHA256: , # ExecutionState.RIPEMD160: , ExecutionState.DATACOPY: dataCopy, diff --git a/src/zkevm_specs/evm_circuit/execution/precompiles/ecrecover.py b/src/zkevm_specs/evm_circuit/execution/precompiles/ecrecover.py new file mode 100644 index 000000000..779726ada --- /dev/null +++ b/src/zkevm_specs/evm_circuit/execution/precompiles/ecrecover.py @@ -0,0 +1,65 @@ +from zkevm_specs.evm_circuit.instruction import Instruction +from zkevm_specs.evm_circuit.table import ( + CallContextFieldTag, + FixedTableTag, + RW, +) +from zkevm_specs.util import FQ, Word, EcrecoverGas + +SECP256K1N = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141 + + +def ecRecover(instruction: Instruction): + is_success = instruction.call_context_lookup(CallContextFieldTag.IsSuccess, RW.Read) + address_word = instruction.call_context_lookup_word(CallContextFieldTag.CalleeAddress) + address = instruction.word_to_address(address_word) + instruction.fixed_lookup( + FixedTableTag.PrecompileInfo, + FQ(instruction.curr.execution_state), + address, + FQ(EcrecoverGas), + ) + + # Get msg_hash, signature and recovered address from aux_data + msg_hash: Word = instruction.curr.aux_data[0] + sig_v: Word = instruction.curr.aux_data[1] + sig_r: Word = instruction.curr.aux_data[2] + sig_s: Word = instruction.curr.aux_data[3] + recovered_addr: FQ = instruction.curr.aux_data[4] + + is_recovered = FQ(instruction.is_zero(recovered_addr) != FQ(1)) + + # is_success is always true + # ref: https://github.com/ethereum/execution-specs/blob/master/src/ethereum/shanghai/vm/precompiled_contracts/ecrecover.py + instruction.constrain_equal(is_success, FQ(1)) + + # verify r and s + sig_r_upper_bound, _ = instruction.compare_word(sig_r, Word(SECP256K1N)) + sig_s_upper_bound, _ = instruction.compare_word(sig_s, Word(SECP256K1N)) + sig_r_is_non_zero = FQ(instruction.is_zero_word(sig_r) != FQ(1)) + sig_s_is_non_zero = FQ(instruction.is_zero_word(sig_s) != FQ(1)) + valid_r_s = instruction.is_equal( + sig_r_upper_bound + sig_s_upper_bound + sig_r_is_non_zero + sig_s_is_non_zero, FQ(4) + ) + + # verify v + is_equal_27 = instruction.is_equal_word(sig_v, Word(27)) + is_equal_28 = instruction.is_equal_word(sig_v, Word(28)) + valid_v = instruction.is_equal(is_equal_27 + is_equal_28, FQ(1)) + + if valid_r_s + valid_v == FQ(2): + # sig table lookups + instruction.sig_lookup( + msg_hash, sig_v.lo.expr() - FQ(27), sig_r, sig_s, recovered_addr, is_recovered + ) + else: + instruction.constrain_zero(is_recovered) + instruction.constrain_zero(recovered_addr) + + # Restore caller state to next StepState + instruction.step_state_transition_to_restored_context( + rw_counter_delta=instruction.rw_counter_offset, + return_data_offset=FQ.zero(), + return_data_length=FQ(32) if is_recovered == FQ(1) else FQ.zero(), + gas_left=instruction.curr.gas_left - EcrecoverGas, + ) diff --git a/src/zkevm_specs/evm_circuit/instruction.py b/src/zkevm_specs/evm_circuit/instruction.py index f2f0f0b63..5748dad8a 100644 --- a/src/zkevm_specs/evm_circuit/instruction.py +++ b/src/zkevm_specs/evm_circuit/instruction.py @@ -1398,6 +1398,17 @@ def exp_lookup( exp_table_row = self.tables.exp_lookup(identifier, is_last, base_limbs, exponent) return exp_table_row.exponentiation + def sig_lookup( + self, + msg_hash: Word, + sig_v: Expression, + sig_r: Word, + sig_s: Word, + recovered_addr: FQ, + is_valid: Expression, + ) -> Word: + return self.tables.sig_lookup(msg_hash, sig_v, sig_r, sig_s, recovered_addr, is_valid) + def constrain_error_state(self, rw_counter_delta: int): # Current call must fail. rw_counter_delta += 1 diff --git a/src/zkevm_specs/evm_circuit/table.py b/src/zkevm_specs/evm_circuit/table.py index 12ce69193..a53fba839 100644 --- a/src/zkevm_specs/evm_circuit/table.py +++ b/src/zkevm_specs/evm_circuit/table.py @@ -538,6 +538,16 @@ class ExpTableRow(TableRow): exponentiation: Word +@dataclass(frozen=True) +class SigTableRow(TableRow): + msg_hash: Word + sig_v: FQ + sig_r: Word + sig_s: Word + recovered_addr: FQ + is_valid: FQ + + class Tables: """ A collection of lookup tables used in EVM circuit. @@ -552,6 +562,7 @@ class Tables: copy_table: Set[CopyTableRow] keccak_table: Set[KeccakTableRow] exp_table: Set[ExpTableRow] + sig_table: Set[SigTableRow] def __init__( self, @@ -563,6 +574,7 @@ def __init__( copy_circuit: Optional[Sequence[CopyCircuitRow]] = None, keccak_table: Optional[Sequence[KeccakTableRow]] = None, exp_circuit: Optional[Sequence[ExpCircuitRow]] = None, + sig_table: Optional[Sequence[SigTableRow]] = None, ) -> None: self.block_table = block_table self.tx_table = tx_table @@ -578,6 +590,8 @@ def __init__( self.keccak_table = set(keccak_table) if exp_circuit is not None: self.exp_table = self._convert_exp_circuit_to_table(exp_circuit) + if sig_table is not None: + self.sig_table = set(sig_table) def _convert_copy_circuit_to_table(self, copy_circuit: Sequence[CopyCircuitRow]): rows: List[CopyTableRow] = [] @@ -768,6 +782,25 @@ def exp_lookup( } return lookup(ExpTableRow, self.exp_table, query) + def sig_lookup( + self, + msg_hash: Word, + sig_v: Expression, + sig_r: Word, + sig_s: Word, + recovered_addr: FQ, + is_valid: Expression, + ) -> SigTableRow: + query = { + "msg_hash": msg_hash, + "sig_v": sig_v, + "sig_r": sig_r, + "sig_s": sig_s, + "recovered_addr": recovered_addr, + "is_valid": is_valid, + } + return lookup(SigTableRow, self.sig_table, query) + T = TypeVar("T", bound=TableRow) diff --git a/src/zkevm_specs/sig_circuit.py b/src/zkevm_specs/sig_circuit.py new file mode 100644 index 000000000..bbc1a0844 --- /dev/null +++ b/src/zkevm_specs/sig_circuit.py @@ -0,0 +1,122 @@ +from typing import List, NamedTuple +from .util import FQ, RLC, Word, ECDSAVerifyChip, KeccakTable, is_circuit_code +from eth_keys import KeyAPI # type: ignore +from eth_utils import keccak + + +class Row: + """ + Signature circuit + Verify a message hash is signed by an Ethereum Address. + """ + + msg_hash: Word + sig_v: FQ + sig_r: Word + sig_s: Word + recovered_addr: FQ + is_valid: FQ + + ecdsa_chip: ECDSAVerifyChip + pub_key_hash: bytes + pub_key_x_bytes: bytes + pub_key_y_bytes: bytes + msg_hash_bytes: bytes + + def __init__( + self, + pub_key_hash: bytes, + address: FQ, + msg_hash: Word, + ecdsa_chip: ECDSAVerifyChip, + is_valid: bool = True, + ) -> None: + self.ecdsa_chip = ecdsa_chip + self.pub_key_x_bytes = ecdsa_chip.pub_key_x_bytes + self.pub_key_y_bytes = ecdsa_chip.pub_key_y_bytes + self.msg_hash_bytes = ecdsa_chip.msg_hash_bytes + + # table + self.msg_hash = msg_hash + self.sig_v = FQ(int.from_bytes(self.ecdsa_chip.sig_v.le_bytes, "little")) + self.sig_r = Word(int.from_bytes(self.ecdsa_chip.sig_r.le_bytes, "little")) + self.sig_s = Word(int.from_bytes(self.ecdsa_chip.sig_s.le_bytes, "little")) + self.recovered_addr = address + self.is_valid = is_valid + + self.pub_key_hash = pub_key_hash + + @classmethod + def assign( + cls, + signature: KeyAPI.Signature, + pub_key: KeyAPI.PublicKey, + msg_hash: bytes, + is_valid: bool = True, + ): + pub_key_hash = keccak(pub_key.to_bytes()) + self_pub_key_hash = pub_key_hash + self_address = FQ(int.from_bytes(pub_key_hash[-20:], "big")) + self_msg_hash = Word(int.from_bytes(msg_hash, "big")) + self_ecdsa_chip = ECDSAVerifyChip.assign(signature, pub_key, msg_hash) + return cls(self_pub_key_hash, self_address, self_msg_hash, self_ecdsa_chip, is_valid) + + def verify(self, keccak_table: KeccakTable, keccak_randomness: FQ, assert_msg: str): + # 0. Copy constraints between pub_key, msg_hash and signature of this chip + # and the ones in ECDSA chip + assert self.pub_key_x_bytes == self.ecdsa_chip.pub_key_x_bytes + assert self.pub_key_y_bytes == self.ecdsa_chip.pub_key_y_bytes + assert self.msg_hash_bytes == self.ecdsa_chip.msg_hash_bytes + assert self.sig_r.int_value() == int.from_bytes(self.ecdsa_chip.sig_r.le_bytes, "little") + assert self.sig_s.int_value() == int.from_bytes(self.ecdsa_chip.sig_s.le_bytes, "little") + + # 1. Constrain v to be equal 0 or 1 + assert self.sig_v == 0 or self.sig_v == 1 + + # 2. Verify that keccak(pub_key_bytes) = pub_key_hash by keccak table + # lookup, where pub_key_bytes is built from the pub_key in the + # ecdsa_chip + pub_key_bytes = bytes(reversed(self.pub_key_x_bytes)) + bytes( + reversed(self.pub_key_y_bytes) + ) + keccak_table.lookup( + True, + RLC(bytes(reversed(pub_key_bytes)), keccak_randomness, n_bytes=64).expr(), + FQ(64), + Word(self.pub_key_hash), + assert_msg, + ) + + # 3. Verify that the least significant 20 bytes of the pub_key_hash equals `recovered_addr` + addr_expr = FQ(int.from_bytes(bytes(self.pub_key_hash[-20:]), "big")) + assert ( + addr_expr == self.recovered_addr + ), f"{assert_msg}: {hex(addr_expr.n)} != {hex(self.recovered_addr.n)}" + + # 4. Verify that the signed message in the ecdsa_chip equals `msg_hash` + msg_hash = Word(self.msg_hash_bytes) + assert ( + msg_hash == self.msg_hash + ), f"{assert_msg}: {hex(msg_hash.int_value())} != {hex(self.msg_hash.int_value())}" + + # 5. Verify the ECDSA signature + is_valid = self.ecdsa_chip.verify() + assert is_valid == self.is_valid, f"{assert_msg}: {is_valid} != {self.is_valid}" + + +class Witness(NamedTuple): + rows: List[Row] # Transaction table rows + keccak_table: KeccakTable + + +@is_circuit_code +def verify_circuit( + witness: Witness, + keccak_randomness: FQ, +) -> None: + """ + Entry level circuit verification function + """ + for i, row in enumerate(witness.rows): + assert_msg = f"Constraints failed at row = {i}" + row.verify(witness.keccak_table, keccak_randomness, assert_msg) diff --git a/src/zkevm_specs/tx_circuit.py b/src/zkevm_specs/tx_circuit.py index 5770aa143..b19cda2b6 100644 --- a/src/zkevm_specs/tx_circuit.py +++ b/src/zkevm_specs/tx_circuit.py @@ -226,7 +226,7 @@ def verify(self, keccak_table: KeccakTable, keccak_randomness: FQ, assert_msg: s assert_msg, ) - # 2. Verify that the first 20 bytes of the pub_key_hash equal the address + # 2. Verify that the least significant 20 bytes of the pub_key_hash equal the address addr_expr = linear_combine_bytes(list(reversed(self.pub_key_hash[-20:])), FQ(2**8)) assert ( addr_expr == self.address diff --git a/src/zkevm_specs/util/__init__.py b/src/zkevm_specs/util/__init__.py index 8f4940a05..c361f55bc 100644 --- a/src/zkevm_specs/util/__init__.py +++ b/src/zkevm_specs/util/__init__.py @@ -3,3 +3,5 @@ from .hash import * from .param import * from .typing import * +from .ec import * +from .tables import * diff --git a/src/zkevm_specs/util/ec.py b/src/zkevm_specs/util/ec.py new file mode 100644 index 000000000..eea102109 --- /dev/null +++ b/src/zkevm_specs/util/ec.py @@ -0,0 +1,113 @@ +from typing import Tuple +from .arithmetic import FQ +from eth_keys import KeyAPI # type: ignore + + +class WrongFieldInteger: + """ + Wrong Field arithmetic Integer, representing the implementation at + https://github.com/privacy-scaling-explorations/halo2wrong/blob/master/integer/src/integer.rs + """ + + limbs: Tuple[FQ, FQ, FQ, FQ] # Little-Endian limbs of [72, 72, 72, 40] bits + le_bytes: bytes # Little-Endian bytes + + def __init__(self, value: int) -> None: + mask = (1 << 72) - 1 + l0 = (value >> 0 * 72) & mask + l1 = (value >> 1 * 72) & mask + l2 = (value >> 2 * 72) & mask + l3 = (value >> 3 * 72) & mask + self.limbs = (FQ(l0), FQ(l1), FQ(l2), FQ(l3)) + self.le_bytes = value.to_bytes(32, "little") + + def to_le_bytes(self) -> bytes: + (l0, l1, l2, l3) = self.limbs + val = l0.n + (l1.n << 1 * 72) + (l2.n << 2 * 72) + (l3.n << 3 * 72) + return val.to_bytes(32, "little") + + def to_be_bytes(self) -> bytes: + (l0, l1, l2, l3) = self.limbs + val = l0.n + (l1.n << 1 * 72) + (l2.n << 2 * 72) + (l3.n << 3 * 72) + return val.to_bytes(32, "big") + + +class Secp256k1BaseField(WrongFieldInteger): + """ + Secp256k1 Base Field. + """ + + def __init__(self, value: int) -> None: + WrongFieldInteger.__init__(self, value) + + +class Secp256k1ScalarField(WrongFieldInteger): + """ + Secp256k1 Scalar Field. + """ + + def __init__(self, value: int) -> None: + WrongFieldInteger.__init__(self, value) + + +# TODO: There is another one used in tx_circuit, try to merge into one. +# Reminder: endianness of public key is differ with the one in tx_circuit +class ECDSAVerifyChip: + """ + ECDSA Signature Verification Chip. This represents an ECDSA signature + verification Chip as implemented in + https://github.com/privacy-scaling-explorations/halo2wrong/blob/master/ecdsa/src/ecdsa.rs + """ + + sig_v: Secp256k1ScalarField + sig_r: Secp256k1ScalarField + sig_s: Secp256k1ScalarField + pub_key: Tuple[Secp256k1BaseField, Secp256k1BaseField] + pub_key_x_bytes: bytes + pub_key_y_bytes: bytes + msg_hash: Secp256k1ScalarField + msg_hash_bytes: bytes + + def __init__( + self, + signature: Tuple[Secp256k1ScalarField, Secp256k1ScalarField, Secp256k1ScalarField], + pub_key: Tuple[Secp256k1BaseField, Secp256k1BaseField], + msg_hash: Secp256k1ScalarField, + ) -> None: + self.sig_v = signature[0] + self.sig_r = signature[1] + self.sig_s = signature[2] + self.pub_key = pub_key + self.msg_hash = msg_hash + self.pub_key_x_bytes = pub_key[0].to_le_bytes() + self.pub_key_y_bytes = pub_key[1].to_le_bytes() + self.msg_hash_bytes = msg_hash.to_be_bytes() + # NOTE: The circuit must constrain that all elements in the `*_bytes` + # parameters are in range 0..255 and that they represent the same + # value as their corresponding WrongFieldInteger limbs. + + @classmethod + def assign(cls, signature: KeyAPI.Signature, pub_key: KeyAPI.PublicKey, msg_hash: bytes): + # signature + self_sig_v = Secp256k1ScalarField(signature.v) + self_sig_r = Secp256k1ScalarField(signature.r) + self_sig_s = Secp256k1ScalarField(signature.s) + # public key + pub_key_bytes = pub_key.to_bytes() + pub_key_bytes_x, pub_key_bytes_y = pub_key_bytes[:32], pub_key_bytes[32:] + pub_key_x = int.from_bytes(pub_key_bytes_x, "big") + pub_key_y = int.from_bytes(pub_key_bytes_y, "big") + self_pub_key = (Secp256k1BaseField(pub_key_x), Secp256k1BaseField(pub_key_y)) + # message hash + self_msg_hash = Secp256k1ScalarField(int.from_bytes(msg_hash, "big")) + return cls((self_sig_v, self_sig_r, self_sig_s), self_pub_key, self_msg_hash) + + def verify(self) -> bool: + sig_v = int.from_bytes(self.sig_v.to_le_bytes(), "little") + sig_r = int.from_bytes(self.sig_r.to_le_bytes(), "little") + sig_s = int.from_bytes(self.sig_s.to_le_bytes(), "little") + signature = KeyAPI.Signature(vrs=[sig_v, sig_r, sig_s]) + + msg_hash = bytes(self.msg_hash.to_be_bytes()) + public_key = KeyAPI.PublicKey(self.pub_key[0].to_be_bytes() + self.pub_key[1].to_be_bytes()) + return KeyAPI().ecdsa_verify(msg_hash, signature, public_key) diff --git a/src/zkevm_specs/util/tables.py b/src/zkevm_specs/util/tables.py new file mode 100644 index 000000000..0816bfb58 --- /dev/null +++ b/src/zkevm_specs/util/tables.py @@ -0,0 +1,33 @@ +from typing import Tuple, Set +from .arithmetic import ( + FQ, + RLC, + Word, +) +from eth_utils import keccak + + +class KeccakTable: + # The columns are: (is_enabled, input_rlc, input_len, output) + table: Set[Tuple[FQ, FQ, FQ, Word]] + + def __init__(self): + self.table = set() + self.table.add((FQ(0), FQ(0), FQ(0), Word(0))) # Add all 0s row + + def add(self, input: bytes, keccak_randomness: FQ): + output = keccak(input) + self.table.add( + ( + FQ(1), + RLC(bytes(reversed(input)), keccak_randomness, n_bytes=64).expr(), + FQ(len(input)), + Word(output), + ) + ) + + def lookup(self, is_enabled: FQ, input_rlc: FQ, input_len: FQ, output: Word, assert_msg: str): + assert (is_enabled, input_rlc, input_len, output) in self.table, ( + f"{assert_msg}: {(is_enabled, input_rlc, input_len, output)} " + + "not found in the lookup table" + ) diff --git a/tests/evm/precompiles/test_ecRecover.py b/tests/evm/precompiles/test_ecRecover.py new file mode 100644 index 000000000..9bf610a46 --- /dev/null +++ b/tests/evm/precompiles/test_ecRecover.py @@ -0,0 +1,173 @@ +import pytest + +from typing import List +from common import CallContext, rand_fq +from eth_keys import keys # type: ignore +from eth_utils import keccak +from zkevm_specs.evm_circuit import ( + Bytecode, + CallContextFieldTag, + ExecutionState, + Precompile, + RWDictionary, + StepState, + Tables, + verify_steps, +) +from zkevm_specs.evm_circuit.execution.precompiles.ecrecover import SECP256K1N +from zkevm_specs.util import ( + Word, + FQ, +) +from zkevm_specs.evm_circuit.table import SigTableRow + + +def gen_testing_data(): + # basic + msg = "Hello World!" + sk = keys.PrivateKey(rand_fq().n.to_bytes(32, "little")) + address = sk.public_key.to_canonical_address() + msg_hash = keccak(bytes(msg, "utf-8")) + sig = sk.sign_msg_hash(msg_hash) + v = sig.v + r = sig.r + s = sig.s + + # successful case + normal = [CallContext(), msg_hash, v, r, s, address] + + # failure cases + zero_addr = [CallContext(), msg_hash, v, r, s, bytes(0)] + sig_r_over_ub = [CallContext(), msg_hash, v, SECP256K1N, s, bytes(0)] + sig_s_over_ub = [CallContext(), msg_hash, v, r, SECP256K1N, bytes(0)] + sig_r_zero = [CallContext(), msg_hash, v, 0, s, bytes(0)] + sig_s_zero = [CallContext(), msg_hash, v, r, 0, bytes(0)] + sig_v_29 = [CallContext(), msg_hash, v, SECP256K1N, s, bytes(0)] + + return [normal, zero_addr, sig_r_over_ub, sig_s_over_ub, sig_r_zero, sig_s_zero, sig_v_29] + + +TESTING_DATA = gen_testing_data() + + +@pytest.mark.parametrize( + "caller_ctx, msg_hash, v, r, s, address", + TESTING_DATA, +) +def test_ecRecover( + caller_ctx: CallContext, + msg_hash: bytes, + v: int, + r: int, + s: int, + address: bytes, +): + call_id = 1 + callee_id = 2 + gas = Precompile.ECRECOVER.base_gas_cost() + + recovered = True if len(address) != 0 else False + call_data_offset = 0 + call_data_length = 0x80 + return_data_offset = 0 + return_data_length = 0x20 if recovered else 0 + + aux_data = [ + Word(msg_hash), + Word(v + 27), + Word(r), + Word(s), + FQ(int.from_bytes(address, "big")), + ] + + # assign sig_table + sig_row: List[SigTableRow] = [] + sig_row.append( + SigTableRow( + Word(msg_hash), + FQ(v), + Word(r), + Word(s), + FQ(int.from_bytes(address, "big")), + FQ(recovered), + ) + ) + + code = ( + Bytecode() + .call( + gas, + Precompile.ECRECOVER, + 0, + call_data_offset, + call_data_length, + return_data_offset, + return_data_length, + ) + .stop() + ) + code_hash = Word(code.hash()) + + rw_dictionary = ( + # fmt: off + RWDictionary(1) + .call_context_read(callee_id, CallContextFieldTag.IsSuccess, FQ(1)) + .call_context_read(callee_id, CallContextFieldTag.CalleeAddress, Word(Precompile.ECRECOVER)) + # fmt: on + ) + + rw_dictionary = ( + # fmt: off + rw_dictionary + .call_context_read(callee_id, CallContextFieldTag.CallerId, call_id) + .call_context_read(call_id, CallContextFieldTag.IsRoot, False) + .call_context_read(call_id, CallContextFieldTag.IsCreate, False) + .call_context_read(call_id, CallContextFieldTag.CodeHash, code_hash) + .call_context_read(call_id, CallContextFieldTag.ProgramCounter, caller_ctx.program_counter) + .call_context_read(call_id, CallContextFieldTag.StackPointer, caller_ctx.stack_pointer) + .call_context_read(call_id, CallContextFieldTag.GasLeft, caller_ctx.gas_left) + .call_context_read(call_id, CallContextFieldTag.MemorySize, caller_ctx.memory_word_size) + .call_context_read(call_id, CallContextFieldTag.ReversibleWriteCounter, caller_ctx.reversible_write_counter) + .call_context_write(call_id, CallContextFieldTag.LastCalleeId, callee_id) + .call_context_write(call_id, CallContextFieldTag.LastCalleeReturnDataOffset, FQ(return_data_offset)) + .call_context_write(call_id, CallContextFieldTag.LastCalleeReturnDataLength, FQ(return_data_length)) + # fmt: on + ) + + tables = Tables( + block_table=set(), + tx_table=set(), + withdrawal_table=set(), + bytecode_table=set(code.table_assignments()), + rw_table=set(rw_dictionary.rws), + sig_table=set(sig_row), + ) + + verify_steps( + tables, + steps=[ + StepState( + execution_state=ExecutionState.ECRECOVER, + rw_counter=1, + call_id=callee_id, + is_root=False, + code_hash=code_hash, + program_counter=caller_ctx.program_counter - 1, + stack_pointer=1023, + memory_word_size=call_data_length, + gas_left=gas, + aux_data=aux_data, + ), + StepState( + execution_state=ExecutionState.STOP, + rw_counter=rw_dictionary.rw_counter, + call_id=call_id, + is_root=False, + code_hash=code_hash, + program_counter=caller_ctx.program_counter, + stack_pointer=caller_ctx.stack_pointer, + memory_word_size=caller_ctx.memory_word_size, + gas_left=0, + ), + ], + ) diff --git a/tests/test_sig_circuit.py b/tests/test_sig_circuit.py new file mode 100644 index 000000000..aebb68c6a --- /dev/null +++ b/tests/test_sig_circuit.py @@ -0,0 +1,164 @@ +from typing import NamedTuple, List +from eth_keys import keys # type: ignore +from eth_utils import keccak +from zkevm_specs.sig_circuit import * +from zkevm_specs.util import FQ +from common import rand_fq +from zkevm_specs.util import ( + FQ, + Word, + U160, + U256, +) + +keccak_randomness = rand_fq() +r = keccak_randomness + + +class SignedData(NamedTuple): + msg_hash: bytes + sig_v: U256 + sig_r: U256 + sig_s: U256 + addr: U160 + is_valid: bool + + +def sign_msg(sk: keys.PrivateKey, msg: bytes, valid: bool = True) -> SignedData: + """ + Return a copy of the signed data + """ + + msg_hash = keccak(msg) + sig = sk.sign_msg_hash(msg_hash) + sig_v = sig.v + sig_r = sig.r if valid else U256(1) + sig_s = sig.s if valid else U256(1) + return SignedData(msg_hash, sig_v, sig_r, sig_s, int(sk.public_key.to_address(), 16), valid) + + +def signedData2witness( + signed_data: List[SignedData], + keccak_randomness: FQ, +) -> Witness: + """ + Generate the complete witness of a list of signed data. + """ + + rows: List[Row] = [] + keccak_table = KeccakTable() + for i, data in enumerate(signed_data): + sig = KeyAPI.Signature(vrs=(data.sig_v, data.sig_r, data.sig_s)) + pk = sig.recover_public_key_from_msg_hash(data.msg_hash) + ecdsa_chip = ECDSAVerifyChip.assign(sig, pk, data.msg_hash) + + pk_bytes = pk.to_bytes() + keccak_table.add(pk_bytes, keccak_randomness) + pk_hash = keccak(pk_bytes) + rows.append( + Row( + pk_hash, + FQ(data.addr), + Word(data.msg_hash), + ecdsa_chip, + ) + ) + + return Witness(rows, keccak_table) + + +def gen_witness(num: int = 10, valid: bool = True) -> Witness: + sks = [keys.PrivateKey(bytes([byte + 1]) * 32) for byte in range(num)] + + list: List[SignedData] = [] + for sk in sks: + signed_msg = sign_msg(sk, bytes("Message", "utf-8"), valid) + list.append(signed_msg) + + witness = signedData2witness(list, r) + return witness + + +def verify( + witness: Witness, + keccak_randomness: FQ, + success: bool = True, +): + """ + Verify the circuit with the assigned witness (or the witness calculated + from the transactions). If `success` is False, expect the verification to + fail. + """ + + exception = None + try: + verify_circuit( + witness, + keccak_randomness, + ) + except AssertionError as e: + exception = e + + if success: + if exception: + raise exception + assert exception is None + else: + assert exception is not None + + +def test_ecdsa_verify_chip(): + sk = keys.PrivateKey(b"\x02" * 32) + pk = sk.public_key + msg_hash = b"\xae" * 32 + sig = sk.sign_msg_hash(msg_hash) + + ecdsa_chip = ECDSAVerifyChip.assign(sig, pk, msg_hash) + assert ecdsa_chip.verify() == True + + +def test_sig_verify(): + witness = gen_witness() + verify(witness, r) + + +def test_sig_incorrect_keccak(): + witness = gen_witness() + # Set empty keccak lookup table + witness = Witness(witness.rows, KeccakTable()) + verify(witness, r, success=False) + + +def test_sig_incorrect_signature(): + witness = gen_witness(10, False) + verify(witness, r, success=False) + + +def test_sig_incorrect_signature_v(): + witness = gen_witness(1) + witness.rows[0].sig_v = FQ(2) + verify(witness, r, success=False) + + +def test_sig_incorrect_msg_hash(): + witness = gen_witness(1) + witness.rows[0].msg_hash = Word(1) + verify(witness, r, success=False) + + +def test_sig_inconsistent_msg_hash(): + witness = gen_witness(1) + witness.rows[0].ecdsa_chip.msg_hash_bytes = Word(1) + verify(witness, r, success=False) + + +def test_sig_inconsistent_pub_key_hash(): + witness = gen_witness(1) + witness.rows[0].pub_key_hash = Word(1) + verify(witness, r, success=False) + + +def test_sig_incorrect_address(): + witness = gen_witness(1) + witness.rows[0].recovered_addr = FQ(1) + verify(witness, r, success=False)