diff --git a/tests/functional/codegen/modules/test_exports.py b/tests/functional/codegen/modules/test_exports.py index 2dc90bfe74..b02ed6ba9e 100644 --- a/tests/functional/codegen/modules/test_exports.py +++ b/tests/functional/codegen/modules/test_exports.py @@ -1,3 +1,9 @@ +import pytest + +from vyper.compiler import compile_code +from vyper.utils import method_id + + def test_simple_export(make_input_bundle, get_contract): lib1 = """ @external @@ -147,3 +153,290 @@ def foo() -> uint256: c = get_contract(main, input_bundle=input_bundle) assert c.foo() == 5 + + +@pytest.fixture +def simple_library(make_input_bundle): + ifoo = """ +@external +def foo() -> uint256: + ... + +@external +def bar() -> uint256: + ... + """ + ibar = """ +@external +def bar() -> uint256: + ... + +@external +def qux() -> uint256: + ... + """ + lib1 = """ +import ifoo +import ibar + +implements: ifoo +implements: ibar + +@external +def foo() -> uint256: + return 1 + +@external +def bar() -> uint256: + return 2 + +@external +def qux() -> uint256: + return 3 + """ + return make_input_bundle({"lib1.vy": lib1, "ifoo.vyi": ifoo, "ibar.vyi": ibar}) + + +@pytest.fixture +def send_failing_tx_to_signature(w3, tx_failed): + def _send_transaction(c, method_sig): + data = method_id(method_sig) + with tx_failed(): + w3.eth.send_transaction({"to": c.address, "data": data}) + + return _send_transaction + + +def test_exports_interface_simple(get_contract, simple_library): + main = """ +import lib1 + +exports: lib1.__interface__ + """ + c = get_contract(main, input_bundle=simple_library) + assert c.foo() == 1 + assert c.bar() == 2 + assert c.qux() == 3 + + +def test_exports_interface2(get_contract, send_failing_tx_to_signature, simple_library): + main = """ +import lib1 + +exports: lib1.ifoo + """ + out = compile_code( + main, output_formats=["abi"], contract_path="main.vy", input_bundle=simple_library + ) + fnames = [item["name"] for item in out["abi"]] + assert fnames == ["foo", "bar"] + + c = get_contract(main, input_bundle=simple_library) + assert c.foo() == 1 + assert c.bar() == 2 + assert not hasattr(c, "qux") + send_failing_tx_to_signature(c, "qux()") + + +def test_exported_fun_part_of_interface(get_contract, make_input_bundle): + main = """ +import lib2 + +exports: lib2.__interface__ + """ + lib1 = """ +@external +def bar() -> uint256: + return 1 + """ + lib2 = """ +import lib1 + +@external +def foo() -> uint256: + return 2 + +exports: lib1.bar + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + c = get_contract(main, input_bundle=input_bundle) + assert c.bar() == 1 + assert c.foo() == 2 + + +def test_imported_module_not_part_of_interface( + send_failing_tx_to_signature, get_contract, make_input_bundle +): + main = """ +import lib2 + +exports: lib2.__interface__ + """ + lib1 = """ +@external +def bar() -> uint256: + return 1 + """ + lib2 = """ +import lib1 + +@external +def foo() -> uint256: + return 2 + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + c = get_contract(main, input_bundle=input_bundle) + assert c.foo() == 2 + send_failing_tx_to_signature(c, "bar()") + + +def test_export_unimplemented_function( + send_failing_tx_to_signature, get_contract, make_input_bundle +): + ifoo = """ +@external +def foo() -> uint256: + ... + """ + lib1 = """ +import ifoo +implements: ifoo + +@external +def foo() -> uint256: + return 1 + +@external +def bar() -> uint256: + return 2 + """ + main = """ +import lib1 + +exports: lib1.ifoo + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "ifoo.vyi": ifoo}) + c = get_contract(main, input_bundle=input_bundle) + assert c.foo() == 1 + send_failing_tx_to_signature(c, "bar()") + + +# sanity check that when multiple modules implement an interface, the +# correct one (specified by the user) gets selected for export. +def test_export_interface_multiple_choices(get_contract, make_input_bundle): + ifoo = """ +@external +def foo() -> uint256: + ... + """ + lib1 = """ +import ifoo +implements: ifoo + +@external +def foo() -> uint256: + return 1 + """ + lib2 = """ +import ifoo +implements: ifoo + +@external +def foo() -> uint256: + return 2 + """ + main = """ +import lib1 +import lib2 + +exports: lib1.ifoo + """ + main2 = """ +import lib1 +import lib2 + +exports: lib2.ifoo + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2, "ifoo.vyi": ifoo}) + + c = get_contract(main, input_bundle=input_bundle) + assert c.foo() == 1 + + c = get_contract(main2, input_bundle=input_bundle) + assert c.foo() == 2 + + +def test_export_module_with_init(get_contract, make_input_bundle): + lib1 = """ +@deploy +def __init__(): + pass + +@external +def foo() -> uint256: + return 1 + """ + main = """ +import lib1 + +exports: lib1.__interface__ + """ + input_bundle = make_input_bundle({"lib1.vy": lib1}) + c = get_contract(main, input_bundle=input_bundle) + assert c.foo() == 1 + + +def test_export_module_with_getter(get_contract, make_input_bundle): + lib1 = """ +counter: public(uint256) + +@external +def foo(): + self.counter += 1 + """ + main = """ +import lib1 + +initializes: lib1 +exports: lib1.__interface__ + +@deploy +def __init__(): + lib1.counter = 100 + """ + input_bundle = make_input_bundle({"lib1.vy": lib1}) + c = get_contract(main, input_bundle=input_bundle) + assert c.counter() == 100 + c.foo(transact={}) + assert c.counter() == 101 + + +def test_export_module_with_default(w3, get_contract, make_input_bundle): + lib1 = """ +counter: public(uint256) + +@external +def foo() -> uint256: + return 1 + +@external +def __default__(): + self.counter += 1 + """ + main = """ +import lib1 +initializes: lib1 + +@deploy +def __init__(): + lib1.counter = 5 + +exports: lib1.__interface__ + """ + input_bundle = make_input_bundle({"lib1.vy": lib1}) + c = get_contract(main, input_bundle=input_bundle) + assert c.foo() == 1 + assert c.counter() == 5 + # call `c.__default__()` + w3.eth.send_transaction({"to": c.address}) + assert c.counter() == 6 diff --git a/tests/functional/syntax/modules/test_exports.py b/tests/functional/syntax/modules/test_exports.py index 1edb99bc7f..7b00d29c98 100644 --- a/tests/functional/syntax/modules/test_exports.py +++ b/tests/functional/syntax/modules/test_exports.py @@ -1,7 +1,12 @@ import pytest from vyper.compiler import compile_code -from vyper.exceptions import ImmutableViolation, NamespaceCollision, StructureException +from vyper.exceptions import ( + ImmutableViolation, + InterfaceViolation, + NamespaceCollision, + StructureException, +) from .helpers import NONREENTRANT_NOTE @@ -309,3 +314,133 @@ def bar(): assert e.value.prev_decl.col_offset == 9 assert e.value.prev_decl.node_source_code == "lib1.foo" assert e.value.prev_decl.module_node.path == "main.vy" + + +def test_interface_export_collision(make_input_bundle): + main = """ +import lib1 + +exports: lib1.__interface__ +exports: lib1.bar + """ + lib1 = """ +@external +def bar() -> uint256: + return 1 + """ + input_bundle = make_input_bundle({"lib1.vy": lib1}) + with pytest.raises(StructureException) as e: + compile_code(main, input_bundle=input_bundle) + assert e.value._message == "already exported!" + + +def test_no_export_missing_function(make_input_bundle): + ifoo = """ +@external +def do_xyz(): + ... + """ + lib1 = """ +import ifoo + +@external +@view +def bar() -> uint256: + return 1 + """ + main = """ +import lib1 + +exports: lib1.ifoo + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "ifoo.vyi": ifoo}) + with pytest.raises(InterfaceViolation) as e: + compile_code(main, input_bundle=input_bundle) + assert e.value._message == "requested `lib1.ifoo` but `lib1` does not implement `lib1.ifoo`!" + + +def test_no_export_unimplemented_interface(make_input_bundle): + ifoo = """ +@external +def do_xyz(): + ... + """ + lib1 = """ +import ifoo + +# technically implements ifoo, but missing `implements: ifoo` + +@external +def do_xyz(): + pass + """ + main = """ +import lib1 + +exports: lib1.ifoo + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "ifoo.vyi": ifoo}) + with pytest.raises(InterfaceViolation) as e: + compile_code(main, input_bundle=input_bundle) + assert e.value._message == "requested `lib1.ifoo` but `lib1` does not implement `lib1.ifoo`!" + + +def test_export_selector_conflict(make_input_bundle): + ifoo = """ +@external +def gsf(): + ... + """ + lib1 = """ +import ifoo + +@external +def gsf(): + pass + +@external +@view +def tgeo() -> uint256: + return 1 + """ + main = """ +import lib1 + +exports: (lib1.ifoo, lib1.tgeo) + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "ifoo.vyi": ifoo}) + with pytest.raises(StructureException) as e: + compile_code(main, input_bundle=input_bundle) + assert e.value._message == "Methods produce colliding method ID `0x67e43e43`: gsf(), tgeo()" + + +def test_export_different_return_type(make_input_bundle): + ifoo = """ +@external +def foo() -> uint256: + ... + """ + lib1 = """ +import ifoo + +foo: public(int256) + +@deploy +def __init__(): + self.foo = -1 + """ + main = """ +import lib1 + +initializes: lib1 + +exports: lib1.ifoo + +@deploy +def __init__(): + lib1.__init__() + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "ifoo.vyi": ifoo}) + with pytest.raises(InterfaceViolation) as e: + compile_code(main, input_bundle=input_bundle) + assert e.value._message == "requested `lib1.ifoo` but `lib1` does not implement `lib1.ifoo`!" diff --git a/vyper/semantics/analysis/module.py b/vyper/semantics/analysis/module.py index 9d3f9ae1ff..cb1dc8430f 100644 --- a/vyper/semantics/analysis/module.py +++ b/vyper/semantics/analysis/module.py @@ -20,6 +20,7 @@ ExceptionList, ImmutableViolation, InitializerException, + InterfaceViolation, InvalidLiteral, InvalidType, ModuleNotFound, @@ -530,42 +531,76 @@ def visit_InitializesDecl(self, node): def visit_ExportsDecl(self, node): items = vy_ast.as_tuple(node.annotation) - funcs = [] + exported_funcs = [] used_modules = OrderedSet() + # CMC 2024-04-13 TODO: reduce nesting in this function + for item in items: # set is_callable=True to give better error messages for imported # types, e.g. exports: some_module.MyEvent info = get_expr_info(item, is_callable=True) + if info.var_info is not None: - decl_node = info.var_info.decl_node + decl = info.var_info.decl_node if not info.var_info.is_public: - raise StructureException("not a public variable!", decl_node, item) - func_t = decl_node._expanded_getter._metadata["func_type"] - - else: + raise StructureException("not a public variable!", decl, item) + funcs = [decl._expanded_getter._metadata["func_type"]] + elif isinstance(info.typ, ContractFunctionT): # regular function - func_t = info.typ - decl_node = func_t.decl_node + funcs = [info.typ] + elif isinstance(info.typ, InterfaceT): + if not isinstance(item, vy_ast.Attribute): + raise StructureException( + "invalid export", + hint="exports should look like .", + ) + + module_info = get_expr_info(item.value).module_info + if module_info is None: + raise StructureException("not a valid module!", item.value) + + if info.typ not in module_info.typ.implemented_interfaces: + iface_str = item.node_source_code + module_str = item.value.node_source_code + msg = f"requested `{iface_str}` but `{module_str}`" + msg += f" does not implement `{iface_str}`!" + raise InterfaceViolation(msg, item) + + module_exposed_fns = {fn.name: fn for fn in module_info.typ.exposed_functions} + # find the specific implementation of the function in the module + funcs = [ + module_exposed_fns[fn.name] + for fn in info.typ.functions.values() + if fn.is_external + ] + else: + raise StructureException( + f"not a function or interface: `{info.typ}`", info.typ.decl_node, item + ) - if not isinstance(func_t, ContractFunctionT): - raise StructureException(f"not a function: `{func_t}`", decl_node, item) - if not func_t.is_external: - raise StructureException("can't export non-external functions!", decl_node, item) + for func_t in funcs: + if not func_t.is_external: + raise StructureException( + "can't export non-external functions!", func_t.decl_node, item + ) + + self._add_exposed_function(func_t, item, relax=False) + with tag_exceptions(item): # tag exceptions with specific item + self._self_t.typ.add_member(func_t.name, func_t) + + exported_funcs.append(func_t) - self._add_exposed_function(func_t, item, relax=False) - with tag_exceptions(item): # tag with specific item - self._self_t.typ.add_member(func_t.name, func_t) + # check module uses + if func_t.uses_state(): + module_info = check_module_uses(item) - funcs.append(func_t) + # guaranteed by above checks: + assert module_info is not None - # check module uses - if func_t.uses_state(): - module_info = check_module_uses(item) - assert module_info is not None # guaranteed by above checks - used_modules.add(module_info) + used_modules.add(module_info) - node._metadata["exports_info"] = ExportsInfo(funcs, used_modules) + node._metadata["exports_info"] = ExportsInfo(exported_funcs, used_modules) @property def _self_t(self): @@ -574,7 +609,7 @@ def _self_t(self): def _add_exposed_function(self, func_t, node, relax=True): # call this before self._self_t.typ.add_member() for exception raising # priority - if (prev_decl := self._exposed_functions.get(func_t)) is not None: + if not relax and (prev_decl := self._exposed_functions.get(func_t)) is not None: raise StructureException("already exported!", node, prev_decl=prev_decl) self._exposed_functions[func_t] = node diff --git a/vyper/semantics/types/module.py b/vyper/semantics/types/module.py index 4557fc9612..51d55a167e 100644 --- a/vyper/semantics/types/module.py +++ b/vyper/semantics/types/module.py @@ -331,6 +331,8 @@ def __init__(self, module: vy_ast.Module, name: Optional[str] = None): # can access interfaces in type position self._helper.add_member(name, TYPE_T(interface_t)) + self.add_member("__interface__", self.interface) + # __eq__ is very strict on ModuleT - object equality! this is because we # don't want to reason about where a module came from (i.e. input bundle, # search path, symlinked vs normalized path, etc.) @@ -371,6 +373,15 @@ def interface_defs(self): def implements_decls(self): return self._module.get_children(vy_ast.ImplementsDecl) + @cached_property + def implemented_interfaces(self): + ret = [node._metadata["interface_type"] for node in self.implements_decls] + + # a module implicitly implements module.__interface__. + ret.append(self.interface) + + return ret + @cached_property def interfaces(self) -> dict[str, InterfaceT]: ret = {}