From 4c66c8c10a9164e464a3cf786bc199f854a9beac Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Thu, 9 May 2024 18:26:28 -0400 Subject: [PATCH] fix[lang]: recursion in `uses` analysis for nonreentrant functions (#3971) this commit fixes `uses` analysis for nonreentrant functions, which are called recursively. a partial fix for this was applied in cb940684a9137, but it missed the case where a nonreentrant function is deep in the call tree. --- .../syntax/modules/test_initializers.py | 80 +++++++++++++------ vyper/semantics/types/function.py | 6 +- 2 files changed, 62 insertions(+), 24 deletions(-) diff --git a/tests/functional/syntax/modules/test_initializers.py b/tests/functional/syntax/modules/test_initializers.py index 2193050a5f..29d611d54a 100644 --- a/tests/functional/syntax/modules/test_initializers.py +++ b/tests/functional/syntax/modules/test_initializers.py @@ -1300,52 +1300,86 @@ def foo(): assert e.value._hint == "try importing lib1 first" -def test_nonreentrant_exports(make_input_bundle): +@pytest.fixture +def nonreentrant_library_bundle(make_input_bundle): + # test simple case lib1 = """ # lib1.vy -@external +@internal @nonreentrant def bar(): pass + +# lib1.vy +@external +@nonreentrant +def ext_bar(): + pass """ - main = """ + # test case with recursion + lib2 = """ +@internal +def bar(): + self.baz() + +@external +def ext_bar(): + self.baz() + +@nonreentrant +@internal +def baz(): + return + """ + # test case with nested recursion + lib3 = """ import lib1 +uses: lib1 -exports: lib1.bar # line 4 +@internal +def bar(): + lib1.bar() + +@external +def ext_bar(): + lib1.bar() + """ + + return make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2, "lib3.vy": lib3}) + + +@pytest.mark.parametrize("lib", ("lib1", "lib2", "lib3")) +def test_nonreentrant_exports(nonreentrant_library_bundle, lib): + main = f""" +import {lib} + +exports: {lib}.ext_bar # line 4 @external def foo(): pass """ - input_bundle = make_input_bundle({"lib1.vy": lib1}) with pytest.raises(ImmutableViolation) as e: - compile_code(main, input_bundle=input_bundle) - assert e.value._message == "Cannot access `lib1` state!" + NONREENTRANT_NOTE - hint = "add `uses: lib1` or `initializes: lib1` as a top-level statement to your contract" + compile_code(main, input_bundle=nonreentrant_library_bundle) + assert e.value._message == f"Cannot access `{lib}` state!" + NONREENTRANT_NOTE + hint = f"add `uses: {lib}` or `initializes: {lib}` as a top-level statement to your contract" assert e.value._hint == hint assert e.value.annotations[0].lineno == 4 -def test_internal_nonreentrant_import(make_input_bundle): - lib1 = """ -# lib1.vy -@internal -@nonreentrant -def bar(): - pass - """ - main = """ -import lib1 +@pytest.mark.parametrize("lib", ("lib1", "lib2", "lib3")) +def test_internal_nonreentrant_import(nonreentrant_library_bundle, lib): + main = f""" +import {lib} @external def foo(): - lib1.bar() # line 6 + {lib}.bar() # line 6 """ - input_bundle = make_input_bundle({"lib1.vy": lib1}) with pytest.raises(ImmutableViolation) as e: - compile_code(main, input_bundle=input_bundle) - assert e.value._message == "Cannot access `lib1` state!" + NONREENTRANT_NOTE + compile_code(main, input_bundle=nonreentrant_library_bundle) + assert e.value._message == f"Cannot access `{lib}` state!" + NONREENTRANT_NOTE - hint = "add `uses: lib1` or `initializes: lib1` as a top-level statement to your contract" + hint = f"add `uses: {lib}` or `initializes: {lib}` as a top-level statement to your contract" assert e.value._hint == hint assert e.value.annotations[0].lineno == 6 diff --git a/vyper/semantics/types/function.py b/vyper/semantics/types/function.py index 86fd90f0f9..7eab0958a6 100644 --- a/vyper/semantics/types/function.py +++ b/vyper/semantics/types/function.py @@ -165,7 +165,11 @@ def get_variable_accesses(self): return self._variable_reads | self._variable_writes def uses_state(self): - return self.nonreentrant or uses_state(self.get_variable_accesses()) + return ( + self.nonreentrant + or uses_state(self.get_variable_accesses()) + or any(f.nonreentrant for f in self.reachable_internal_functions) + ) def get_used_modules(self): # _used_modules is populated during analysis