diff --git a/tests/evm/precompiles/test_ecRecover.py b/tests/evm/precompiles/test_ecRecover.py index 8ed8d55bf..2f46f88c4 100644 --- a/tests/evm/precompiles/test_ecRecover.py +++ b/tests/evm/precompiles/test_ecRecover.py @@ -4,7 +4,6 @@ from common import CallContext, rand_fq from eth_keys import keys # type: ignore from eth_utils import keccak -from itertools import product from zkevm_specs.evm_circuit import ( Bytecode, CallContextFieldTag, @@ -15,15 +14,13 @@ Tables, verify_steps, ) +from zkevm_specs.evm_circuit.execution.precompiles.ecrecover import SECP256K1N from zkevm_specs.util import ( Word, FQ, ) from zkevm_specs.evm_circuit.table import SigTableRow -# FIXME: import from ecRecover.py -SECP256K1N = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141 - def gen_testing_data(): # basic @@ -32,49 +29,44 @@ def gen_testing_data(): address = sk.public_key.to_canonical_address() msg_hash = keccak(bytes(msg, "utf-8")) sig = sk.sign_msg_hash(msg_hash) + v = sig.v + r = sig.r + s = sig.s - # sig_r is over upper bound - # sig2 = sig - # sig2.r = SECP256K1N - - # # sig_r is over upper bound - # sig3 = sig - # sig3.r = 1 - - # sig_v is 29 - # sig4 = sig - # sig4.v = 29 - - # print(f"** {sig} =?= {sig4}") + # successful case + normal = [CallContext(), msg_hash, v, r, s, address] - normal = [CallContext(), msg_hash, sig, address, True] - # zero_addr = [CallContext(), msg_hash, sig, 0, False] - # sig_r_over_ub = [CallContext(), msg_hash, sig2, 0, False] - # sig_r_zero = [CallContext(), msg_hash, sig3, 0, False] - # sig_v_29 = [CallContext(), msg_hash, sig4, 0, False] + # failure cases + zero_addr = [CallContext(), msg_hash, v, r, s, bytes(0)] + sig_r_over_ub = [CallContext(), msg_hash, v, SECP256K1N, s, bytes(0)] + sig_s_over_ub = [CallContext(), msg_hash, v, r, SECP256K1N, bytes(0)] + sig_r_one = [CallContext(), msg_hash, v, 1, s, bytes(0)] + sig_s_one = [CallContext(), msg_hash, v, r, 1, bytes(0)] + sig_v_29 = [CallContext(), msg_hash, v, SECP256K1N, s, bytes(0)] - return [normal] - # return [normal, zero_addr, sig_r_over_ub, sig_r_zero, sig_v_29] + return [normal, zero_addr, sig_r_over_ub, sig_s_over_ub, sig_r_one, sig_s_one, sig_v_29] TESTING_DATA = gen_testing_data() @pytest.mark.parametrize( - "caller_ctx, msg_hash, sig, address, success", + "caller_ctx, msg_hash, v, r, s, address", TESTING_DATA, ) def test_ecRecover( caller_ctx: CallContext, msg_hash: bytes, - sig: keys.Signature, + v: int, + r: int, + s: int, address: bytes, - success: bool, ): call_id = 1 callee_id = 2 gas = Precompile.ECRECOVER.base_gas_cost() + success = True if len(address) != 0 else False call_data_offset = 0 call_data_length = 0x80 return_data_offset = 0 @@ -82,19 +74,20 @@ def test_ecRecover( aux_data = [ Word(msg_hash), - Word(sig.v + 27), - Word(sig.r), - Word(sig.s), + Word(v + 27), + Word(r), + Word(s), FQ(int.from_bytes(address, "big")), ] + # assign sig_table sig_row: List[SigTableRow] = [] sig_row.append( SigTableRow( Word(msg_hash), - FQ(sig.v), - Word(sig.r), - Word(sig.s), + FQ(v), + Word(r), + Word(s), FQ(int.from_bytes(address, "big")), FQ(success), )