From bb6df4fe1d222dab90d3f32d1b55e63173a9ac5b Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Thu, 8 Feb 2024 17:38:27 -0500 Subject: [PATCH 01/26] fix iterator modification analysis --- vyper/semantics/analysis/base.py | 3 +- vyper/semantics/analysis/local.py | 115 ++++++++-------------------- vyper/semantics/environment.py | 4 +- vyper/semantics/types/__init__.py | 2 +- vyper/semantics/types/primitives.py | 6 ++ 5 files changed, 45 insertions(+), 85 deletions(-) diff --git a/vyper/semantics/analysis/base.py b/vyper/semantics/analysis/base.py index 2086e5f9da..e59132bd82 100644 --- a/vyper/semantics/analysis/base.py +++ b/vyper/semantics/analysis/base.py @@ -7,6 +7,7 @@ from vyper.exceptions import CompilerPanic, StructureException from vyper.semantics.data_locations import DataLocation from vyper.semantics.types.base import VyperType +from vyper.semantics.types.primitives import SelfT from vyper.utils import OrderedSet, StringEnum if TYPE_CHECKING: @@ -224,7 +225,7 @@ def __post_init__(self): # `self.my_struct.x.y` will return varinfo for `self.my_struct` def get_root_varinfo(self) -> Optional[VarInfo]: for expr_info in self.attribute_chain + [self]: - if expr_info.var_info is not None: + if expr_info.var_info is not None and not isinstance(expr_info.typ, SelfT): return expr_info.var_info return None diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index d96215ede0..758adefa49 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -1,5 +1,6 @@ # CMC 2024-02-03 TODO: split me into function.py and expr.py +import contextlib from typing import Optional from vyper import ast as vy_ast @@ -99,36 +100,6 @@ def find_terminating_node(node_list: list) -> Optional[vy_ast.VyperNode]: return ret -def _check_iterator_modification( - target_node: vy_ast.VyperNode, search_node: vy_ast.VyperNode -) -> Optional[vy_ast.VyperNode]: - similar_nodes = [ - n - for n in search_node.get_descendants(type(target_node)) - if vy_ast.compare_nodes(target_node, n) - ] - - for node in similar_nodes: - # raise if the node is the target of an assignment statement - assign_node = node.get_ancestor((vy_ast.Assign, vy_ast.AugAssign)) - # note the use of get_descendants() blocks statements like - # self.my_array[i] = x - if assign_node and node in assign_node.target.get_descendants(include_self=True): - return node - - attr_node = node.get_ancestor(vy_ast.Attribute) - # note the use of get_descendants() blocks statements like - # self.my_array[i].append(x) - if ( - attr_node is not None - and node in attr_node.value.get_descendants(include_self=True) - and attr_node.attr in ("append", "pop", "extend") - ): - return node - - return None - - # helpers def _validate_address_code(node: vy_ast.Attribute, value_type: VyperType) -> None: if isinstance(value_type, AddressT) and node.attr == "code": @@ -196,6 +167,8 @@ def __init__( self.func = fn_node._metadata["func_type"] self.expr_visitor = ExprVisitor(self) + self.loop_variables: list[Optional[VarInfo]] = [] + def analyze(self): # allow internal function params to be mutable if self.func.is_internal: @@ -225,6 +198,14 @@ def analyze(self): for kwarg in self.func.keyword_args: self.expr_visitor.visit(kwarg.default_value, kwarg.typ) + @contextlib.contextmanager + def enter_for_loop(self, varinfo: Optional[VarInfo]): + self.loop_variables.append(varinfo) + try: + yield + finally: + self.loop_variables.pop() + def visit(self, node): super().visit(node) @@ -409,6 +390,8 @@ def visit_For(self, node): target_type = type_from_annotation(node.target.annotation, DataLocation.MEMORY) + iter_varinfo = None + if isinstance(node.iter, vy_ast.Call): # iteration via range() if node.iter.get("func.id") != "range": @@ -419,7 +402,10 @@ def visit_For(self, node): else: # iteration over a variable or literal list - iter_val = node.iter.get_folded_value() if node.iter.has_folded_value else node.iter + iter_val = node.iter + if iter_val.has_folded_value: + iter_val = node.iter.get_folded_value() + if isinstance(iter_val, vy_ast.List) and len(iter_val.elements) == 0: raise StructureException("For loop must have at least 1 iteration", node.iter) @@ -428,50 +414,11 @@ def visit_For(self, node): ): raise InvalidType("Not an iterable type", node.iter) - if isinstance(node.iter, (vy_ast.Name, vy_ast.Attribute)): - # check for references to the iterated value within the body of the loop - assign = _check_iterator_modification(node.iter, node) - if assign: - raise ImmutableViolation("Cannot modify array during iteration", assign) - - # Check if `iter` is a storage variable. get_descendants` is used to check for - # nested `self` (e.g. structs) - # NOTE: this analysis will be borked once stateful modules are allowed! - iter_is_storage_var = ( - isinstance(node.iter, vy_ast.Attribute) - and len(node.iter.get_descendants(vy_ast.Name, {"id": "self"})) > 0 - ) - - if iter_is_storage_var: - # check if iterated value may be modified by function calls inside the loop - iter_name = node.iter.attr - for call_node in node.get_descendants(vy_ast.Call, {"func.value.id": "self"}): - fn_name = call_node.func.attr - - fn_node = self.vyper_module.get_children(vy_ast.FunctionDef, {"name": fn_name})[0] - if _check_iterator_modification(node.iter, fn_node): - # check for direct modification - raise ImmutableViolation( - f"Cannot call '{fn_name}' inside for loop, it potentially " - f"modifies iterated storage variable '{iter_name}'", - call_node, - ) + info = get_expr_info(iter_val) + iter_varinfo = info.get_root_varinfo() - for reachable_t in ( - self.namespace["self"].typ.members[fn_name].reachable_internal_functions - ): - # check for indirect modification - name = reachable_t.name - fn_node = self.vyper_module.get_children(vy_ast.FunctionDef, {"name": name})[0] - if _check_iterator_modification(node.iter, fn_node): - raise ImmutableViolation( - f"Cannot call '{fn_name}' inside for loop, it may call to '{name}' " - f"which potentially modifies iterated storage variable '{iter_name}'", - call_node, - ) - - target_name = node.target.target.id - with self.namespace.enter_scope(): + with self.namespace.enter_scope(), self.enter_for_loop(iter_varinfo): + target_name = node.target.target.id self.namespace[target_name] = VarInfo( target_type, modifiability=Modifiability.RUNTIME_CONSTANT ) @@ -581,7 +528,17 @@ def visit(self, node, typ): if varinfo is not None: info._reads.add(varinfo) - if self.func: + if self.function_analyzer: + for s in self.function_analyzer.loop_variables: + if s is None: + continue + + if s in info._writes: + msg = "Cannot modify loop variable" + if s.decl_node is not None: + msg += f" `{s.decl_node.target.id}`" + raise ImmutableViolation(msg, s.decl_node, node) + variable_accesses = info._writes | info._reads for s in variable_accesses: if s.is_module_variable(): @@ -641,7 +598,6 @@ def _check_call_mutability(self, call_mutability: StateMutability): def visit_Call(self, node: vy_ast.Call, typ: VyperType) -> None: func_info = get_expr_info(node.func, is_callable=True) func_type = func_info.typ - self.visit(node.func, func_type) if isinstance(func_type, ContractFunctionT): # function calls @@ -650,13 +606,8 @@ def visit_Call(self, node: vy_ast.Call, typ: VyperType) -> None: func_info._reads.update(func_type._variable_reads) if self.function_analyzer: - if func_type.is_internal: - self.func.called_functions.add(func_type) - self._check_call_mutability(func_type.mutability) - # check that if the function accesses state, the defining - # module has been `used` or `initialized`. for s in func_type._variable_accesses: if s.is_module_variable(): self.function_analyzer._check_module_use(node.func) @@ -702,6 +653,8 @@ def visit_Call(self, node: vy_ast.Call, typ: VyperType) -> None: for kwarg in node.keywords: self.visit(kwarg.value, kwarg_types[kwarg.arg]) + self.visit(node.func, func_type) + def visit_Compare(self, node: vy_ast.Compare, typ: VyperType) -> None: if isinstance(node.op, (vy_ast.In, vy_ast.NotIn)): # membership in list literal - `x in [a, b, c]` diff --git a/vyper/semantics/environment.py b/vyper/semantics/environment.py index 38bac0a63d..94a26157af 100644 --- a/vyper/semantics/environment.py +++ b/vyper/semantics/environment.py @@ -1,7 +1,7 @@ from typing import Dict from vyper.semantics.analysis.base import Modifiability, VarInfo -from vyper.semantics.types import AddressT, BytesT, VyperType +from vyper.semantics.types import AddressT, BytesT, SelfT, VyperType from vyper.semantics.types.shortcuts import BYTES32_T, UINT256_T @@ -57,7 +57,7 @@ def get_constant_vars() -> Dict: return result -MUTABLE_ENVIRONMENT_VARS: Dict[str, type] = {"self": AddressT} +MUTABLE_ENVIRONMENT_VARS: Dict[str, type] = {"self": SelfT} def get_mutable_vars() -> Dict: diff --git a/vyper/semantics/types/__init__.py b/vyper/semantics/types/__init__.py index a04632b96f..59a20dd99f 100644 --- a/vyper/semantics/types/__init__.py +++ b/vyper/semantics/types/__init__.py @@ -3,7 +3,7 @@ from .bytestrings import BytesT, StringT, _BytestringT from .function import MemberFunctionT from .module import InterfaceT -from .primitives import AddressT, BoolT, BytesM_T, DecimalT, IntegerT +from .primitives import AddressT, BoolT, BytesM_T, DecimalT, IntegerT, SelfT from .subscriptable import DArrayT, HashMapT, SArrayT, TupleT from .user import EventT, FlagT, StructT diff --git a/vyper/semantics/types/primitives.py b/vyper/semantics/types/primitives.py index 07d1a21a94..6115a9e6fb 100644 --- a/vyper/semantics/types/primitives.py +++ b/vyper/semantics/types/primitives.py @@ -340,3 +340,9 @@ def validate_literal(self, node: vy_ast.Constant) -> None: f"address, the correct checksummed form is: {checksum_encode(addr)}", node, ) + + +# type for "self" +# refactoring note: it might be best for this to be a ModuleT actually +class SelfT(AddressT): + pass From aa4c5828ec89fa2636f19b6c21405c34c9a5d024 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Thu, 8 Feb 2024 17:54:10 -0500 Subject: [PATCH 02/26] impose topsort on function analysis --- vyper/semantics/analysis/local.py | 31 ++++++++++++++++++++++++------- 1 file changed, 24 insertions(+), 7 deletions(-) diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index 758adefa49..60034dbba2 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -4,6 +4,7 @@ from typing import Optional from vyper import ast as vy_ast +from vyper.utils import OrderedSet from vyper.ast.validation import validate_call_args from vyper.exceptions import ( CallViolation, @@ -60,15 +61,31 @@ def validate_functions(vy_module: vy_ast.Module) -> None: """Analyzes a vyper ast and validates the function bodies""" err_list = ExceptionList() namespace = get_namespace() + + seen = OrderedSet() for node in vy_module.get_children(vy_ast.FunctionDef): - with namespace.enter_scope(): - try: - analyzer = FunctionAnalyzer(vy_module, node, namespace) - analyzer.analyze() - except VyperException as e: - err_list.append(e) + _validate_function_r(vy_module, node, seen, err_list) + +def _validate_function_r(vy_module: vy_ast.Module, node: vy_ast.FunctionDef, seen: OrderedSet, err_list: ExceptionList): + func_t = node._metadata["func_type"] + + if func_t in seen: + return + + for call_t in func_t.called_functions: + if call_t in seen: + continue + if isinstance(call_t, ContractFunctionT): + _validate_function_r(vy_module, call_t.ast_def, seen, err_list) + + namespace = get_namespace() - err_list.raise_if_not_empty() + try: + analyzer = FunctionAnalyzer(vy_module, node, namespace) + analyzer.analyze() + seen.add(func_t) + except VyperException as e: + err_list.append(e) # finds the terminus node for a list of nodes. From 94546875795f451caf231da27944e0b0981a1e7d Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Fri, 9 Feb 2024 08:51:10 -0500 Subject: [PATCH 03/26] refactor and clean up FunctionAnalyzer.visit_For --- vyper/semantics/analysis/local.py | 88 +++++++++++++++++-------------- 1 file changed, 47 insertions(+), 41 deletions(-) diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index 60034dbba2..b8775320d2 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -4,7 +4,6 @@ from typing import Optional from vyper import ast as vy_ast -from vyper.utils import OrderedSet from vyper.ast.validation import validate_call_args from vyper.exceptions import ( CallViolation, @@ -55,18 +54,23 @@ ) from vyper.semantics.types.function import ContractFunctionT, MemberFunctionT, StateMutability from vyper.semantics.types.utils import type_from_annotation +from vyper.utils import OrderedSet def validate_functions(vy_module: vy_ast.Module) -> None: """Analyzes a vyper ast and validates the function bodies""" err_list = ExceptionList() - namespace = get_namespace() + seen: OrderedSet[ContractFunctionT] = OrderedSet() - seen = OrderedSet() for node in vy_module.get_children(vy_ast.FunctionDef): _validate_function_r(vy_module, node, seen, err_list) -def _validate_function_r(vy_module: vy_ast.Module, node: vy_ast.FunctionDef, seen: OrderedSet, err_list: ExceptionList): + err_list.raise_if_not_empty() + + +def _validate_function_r( + vy_module: vy_ast.Module, node: vy_ast.FunctionDef, seen: OrderedSet, err_list: ExceptionList +): func_t = node._metadata["func_type"] if func_t in seen: @@ -76,6 +80,7 @@ def _validate_function_r(vy_module: vy_ast.Module, node: vy_ast.FunctionDef, see if call_t in seen: continue if isinstance(call_t, ContractFunctionT): + assert isinstance(call_t.ast_def, vy_ast.FunctionDef) # help mypy _validate_function_r(vy_module, call_t.ast_def, seen, err_list) namespace = get_namespace() @@ -401,6 +406,40 @@ def visit_Expr(self, node): ) self.expr_visitor.visit(node.value, return_value) + def _analyse_range_iter(self, iter_node, target_type): + # iteration via range() + if iter_node.get("func.id") != "range": + raise IteratorException("Cannot iterate over the result of a function call", iter_node) + _validate_range_call(iter_node) + + args = iter_node.args + kwargs = [s.value for s in iter_node.keywords] + for arg in (*args, *kwargs): + self.expr_visitor.visit(arg, target_type) + + def _analyse_list_iter(self, iter_node, target_type): + # iteration over a variable or literal list + iter_val = iter_node + if iter_val.has_folded_value: + iter_val = iter_val.get_folded_value() + + if isinstance(iter_val, vy_ast.List): + len_ = len(iter_val.elements) + if len_ == 0: + raise StructureException("For loop must have at least 1 iteration", iter_node) + self.expr_visitor.visit(iter_node, SArrayT(target_type, len_)) + else: + iter_type = get_exact_type_from_node(iter_node) + self.expr_visitor.visit(iter_node, iter_type) + + try: + validate_expected_type(iter_node, (DArrayT.any(), SArrayT.any())) + except (TypeMismatch, InvalidType): + raise InvalidType("Not an iterable type", iter_node) + + info = get_expr_info(iter_val) + return info.get_root_varinfo() + def visit_For(self, node): if not isinstance(node.target.target, vy_ast.Name): raise StructureException("Invalid syntax for loop iterator", node.target.target) @@ -408,55 +447,22 @@ def visit_For(self, node): target_type = type_from_annotation(node.target.annotation, DataLocation.MEMORY) iter_varinfo = None - if isinstance(node.iter, vy_ast.Call): - # iteration via range() - if node.iter.get("func.id") != "range": - raise IteratorException( - "Cannot iterate over the result of a function call", node.iter - ) - _validate_range_call(node.iter) - + self._analyse_range_iter(node.iter, target_type) else: - # iteration over a variable or literal list - iter_val = node.iter - if iter_val.has_folded_value: - iter_val = node.iter.get_folded_value() - - if isinstance(iter_val, vy_ast.List) and len(iter_val.elements) == 0: - raise StructureException("For loop must have at least 1 iteration", node.iter) - - if not any( - isinstance(i, (DArrayT, SArrayT)) for i in get_possible_types_from_node(node.iter) - ): - raise InvalidType("Not an iterable type", node.iter) - - info = get_expr_info(iter_val) - iter_varinfo = info.get_root_varinfo() + iter_varinfo = self._analyse_list_iter(node.iter, target_type) with self.namespace.enter_scope(), self.enter_for_loop(iter_varinfo): target_name = node.target.target.id + # maybe we should introduce a new Modifiability: LOOP_VARIABLE self.namespace[target_name] = VarInfo( target_type, modifiability=Modifiability.RUNTIME_CONSTANT ) + self.expr_visitor.visit(node.target.target, target_type) for stmt in node.body: self.visit(stmt) - self.expr_visitor.visit(node.target.target, target_type) - - if isinstance(node.iter, vy_ast.List): - len_ = len(node.iter.elements) - self.expr_visitor.visit(node.iter, SArrayT(target_type, len_)) - elif isinstance(node.iter, vy_ast.Call) and node.iter.func.id == "range": - args = node.iter.args - kwargs = [s.value for s in node.iter.keywords] - for arg in (*args, *kwargs): - self.expr_visitor.visit(arg, target_type) - else: - iter_type = get_exact_type_from_node(node.iter) - self.expr_visitor.visit(node.iter, iter_type) - def visit_If(self, node): self.expr_visitor.visit(node.test, BoolT()) with self.namespace.enter_scope(): From b997fe0ac7450315f185768132e7057b1eb71ceb Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Fri, 9 Feb 2024 09:11:13 -0500 Subject: [PATCH 04/26] add a comment --- vyper/semantics/analysis/local.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index b8775320d2..be7c248fba 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -437,6 +437,8 @@ def _analyse_list_iter(self, iter_node, target_type): except (TypeMismatch, InvalidType): raise InvalidType("Not an iterable type", iter_node) + # get the root varinfo from iter_val in case we need to peer + # through folded constants info = get_expr_info(iter_val) return info.get_root_varinfo() From a057f3976cf3a1ec0c19c9ea95e086def3f44e4c Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Fri, 9 Feb 2024 09:39:58 -0500 Subject: [PATCH 05/26] fix missing enter_scope() --- vyper/semantics/analysis/local.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index be7c248fba..5719179292 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -86,8 +86,10 @@ def _validate_function_r( namespace = get_namespace() try: - analyzer = FunctionAnalyzer(vy_module, node, namespace) - analyzer.analyze() + with namespace.enter_scope(): + analyzer = FunctionAnalyzer(vy_module, node, namespace) + analyzer.analyze() + seen.add(func_t) except VyperException as e: err_list.append(e) From d8fa41a7203fdf652a03716019cac4b2c7071ed6 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Fri, 9 Feb 2024 09:40:26 -0500 Subject: [PATCH 06/26] fix typechecker for darray, sarray --- vyper/semantics/analysis/local.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index 5719179292..3a28c8db54 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -429,16 +429,17 @@ def _analyse_list_iter(self, iter_node, target_type): len_ = len(iter_val.elements) if len_ == 0: raise StructureException("For loop must have at least 1 iteration", iter_node) - self.expr_visitor.visit(iter_node, SArrayT(target_type, len_)) + iter_type = SArrayT(target_type, len_) else: iter_type = get_exact_type_from_node(iter_node) - self.expr_visitor.visit(iter_node, iter_type) - try: - validate_expected_type(iter_node, (DArrayT.any(), SArrayT.any())) - except (TypeMismatch, InvalidType): + # CMC 2024-02-09 TODO: use validate_expected_type once we have DArrays + # with generic length. + if not isinstance(iter_type, (DArrayT, SArrayT)): raise InvalidType("Not an iterable type", iter_node) + self.expr_visitor.visit(iter_node, iter_type) + # get the root varinfo from iter_val in case we need to peer # through folded constants info = get_expr_info(iter_val) From c1d693dc68067b5c9ab02d8531bbedd766749098 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Fri, 9 Feb 2024 10:04:31 -0500 Subject: [PATCH 07/26] add tests for repros from issue --- .../unit/semantics/analysis/test_for_loop.py | 39 +++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/tests/unit/semantics/analysis/test_for_loop.py b/tests/unit/semantics/analysis/test_for_loop.py index 607587cc28..c538ce6b7d 100644 --- a/tests/unit/semantics/analysis/test_for_loop.py +++ b/tests/unit/semantics/analysis/test_for_loop.py @@ -134,6 +134,45 @@ def baz(): validate_semantics(vyper_module, dummy_input_bundle) +def test_modify_iterator_through_struct(dummy_input_bundle): + # GH issue 3429 + code = """ +struct A: + iter: DynArray[uint256, 5] + +a: A + +@external +def foo(): + self.a.iter = [1, 2, 3] + for i: uint256 in self.a.iter: + self.a = A({iter: [1, 2, 3, 4]}) + """ + vyper_module = parse_to_ast(code) + with pytest.raises(ImmutableViolation) as e: + validate_semantics(vyper_module, dummy_input_bundle) + + assert e.value._message == "Cannot modify loop variable `a`" + + +def test_modify_iterator_complex_expr(dummy_input_bundle): + # GH issue 3429 + # avoid false positive! + code = """ +a: DynArray[uint256, 5] +b: uint256[10] + +@external +def foo(): + self.a = [1, 2, 3] + for i: uint256 in self.a: + self.b[self.a[1]] = i + """ + + vyper_module = parse_to_ast(code) + validate_semantics(vyper_module, dummy_input_bundle) + + iterator_inference_codes = [ """ @external From 0021efaec3f66d49c49d489fdffc92cf6f959482 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Fri, 9 Feb 2024 10:04:43 -0500 Subject: [PATCH 08/26] add tests for iterators imported from modules --- .../features/iteration/test_for_in_list.py | 56 ++++++++++++++++++- vyper/semantics/analysis/local.py | 5 +- 2 files changed, 59 insertions(+), 2 deletions(-) diff --git a/tests/functional/codegen/features/iteration/test_for_in_list.py b/tests/functional/codegen/features/iteration/test_for_in_list.py index 36252701c4..e1bd8f313d 100644 --- a/tests/functional/codegen/features/iteration/test_for_in_list.py +++ b/tests/functional/codegen/features/iteration/test_for_in_list.py @@ -3,6 +3,7 @@ import pytest +from vyper.compiler import compile_code from vyper.exceptions import ( ArgumentException, ImmutableViolation, @@ -841,6 +842,59 @@ def foo(): ] +# TODO: move these to tests/functional/syntax @pytest.mark.parametrize("code,err", BAD_CODE, ids=bad_code_names) def test_bad_code(assert_compile_failed, get_contract, code, err): - assert_compile_failed(lambda: get_contract(code), err) + with pytest.raises(err): + compile_code(code) + + +def test_iterator_modification_module_attribute(make_input_bundle): + # test modifying iterator via attribute + lib1 = """ +queue: DynArray[uint256, 5] + """ + main = """ +import lib1 + +initializes: lib1 + +@external +def foo(): + for i: uint256 in lib1.queue: + lib1.queue.pop() + """ + + input_bundle = make_input_bundle({"lib1.vy": lib1}) + + with pytest.raises(ImmutableViolation) as e: + compile_code(main, input_bundle=input_bundle) + + assert e.value._message == "Cannot modify loop variable `queue`" + + +def test_iterator_modification_module_function_call(make_input_bundle): + lib1 = """ +queue: DynArray[uint256, 5] + +@internal +def popqueue(): + self.queue.pop() + """ + main = """ +import lib1 + +initializes: lib1 + +@external +def foo(): + for i: uint256 in lib1.queue: + lib1.popqueue() + """ + + input_bundle = make_input_bundle({"lib1.vy": lib1}) + + with pytest.raises(ImmutableViolation) as e: + compile_code(main, input_bundle=input_bundle) + + assert e.value._message == "Cannot modify loop variable `queue`" diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index 3a28c8db54..a9b6a120ce 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -431,7 +431,10 @@ def _analyse_list_iter(self, iter_node, target_type): raise StructureException("For loop must have at least 1 iteration", iter_node) iter_type = SArrayT(target_type, len_) else: - iter_type = get_exact_type_from_node(iter_node) + try: + iter_type = get_exact_type_from_node(iter_node) + except (InvalidType, StructureException): + raise InvalidType("Not an iterable type", iter_node) # CMC 2024-02-09 TODO: use validate_expected_type once we have DArrays # with generic length. From 600553232dfeb5341eeb4daeef6ece1e64aa74c4 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Fri, 9 Feb 2024 10:07:37 -0500 Subject: [PATCH 09/26] add test for topsort analysis --- .../unit/semantics/analysis/test_for_loop.py | 25 +++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/tests/unit/semantics/analysis/test_for_loop.py b/tests/unit/semantics/analysis/test_for_loop.py index c538ce6b7d..c755706d01 100644 --- a/tests/unit/semantics/analysis/test_for_loop.py +++ b/tests/unit/semantics/analysis/test_for_loop.py @@ -134,6 +134,31 @@ def baz(): validate_semantics(vyper_module, dummy_input_bundle) +def test_modify_iterator_recursive_function_call_topsort(dummy_input_bundle): + # test the analysis works no matter the order of functions + code = """ +a: uint256[3] + +@internal +def baz(): + for i: uint256 in self.a: + self.bar() + +@internal +def bar(): + self.foo() + +@internal +def foo(): + self.a[0] = 1 + """ + vyper_module = parse_to_ast(code) + with pytest.raises(ImmutableViolation) as e: + validate_semantics(vyper_module, dummy_input_bundle) + + assert e.value._message == "Cannot modify loop variable `a`" + + def test_modify_iterator_through_struct(dummy_input_bundle): # GH issue 3429 code = """ From f3f683c08a98c071c69869e28b9e569ae5716da9 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Fri, 9 Feb 2024 15:27:51 +0000 Subject: [PATCH 10/26] fix topsort for function calls --- vyper/semantics/analysis/local.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index a9b6a120ce..fafa8331d0 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -85,12 +85,13 @@ def _validate_function_r( namespace = get_namespace() + # add to seen before analysing, if it throws an exception which gets + # caught, we don't want to analyse again. + seen.add(func_t) try: with namespace.enter_scope(): analyzer = FunctionAnalyzer(vy_module, node, namespace) analyzer.analyze() - - seen.add(func_t) except VyperException as e: err_list.append(e) From dc0908cb2f228619923f9319f13bfac76a28755b Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Fri, 9 Feb 2024 15:28:00 +0000 Subject: [PATCH 11/26] fix type comparison for SelfT --- vyper/semantics/types/primitives.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/vyper/semantics/types/primitives.py b/vyper/semantics/types/primitives.py index 6115a9e6fb..d383f72ab2 100644 --- a/vyper/semantics/types/primitives.py +++ b/vyper/semantics/types/primitives.py @@ -345,4 +345,8 @@ def validate_literal(self, node: vy_ast.Constant) -> None: # type for "self" # refactoring note: it might be best for this to be a ModuleT actually class SelfT(AddressT): - pass + _id = "self" + + def compare_type(self, other): + # compares true to AddressT + return isinstance(other, type(self)) or isinstance(self, type(other)) From 4ffd3aa5c6417d08a185de60b07810a898933662 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Fri, 9 Feb 2024 15:34:01 +0000 Subject: [PATCH 12/26] fix topsort (again!) --- vyper/semantics/analysis/local.py | 23 ++++++++++------------- vyper/semantics/types/function.py | 2 ++ 2 files changed, 12 insertions(+), 13 deletions(-) diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index fafa8331d0..548d2f56dc 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -54,40 +54,30 @@ ) from vyper.semantics.types.function import ContractFunctionT, MemberFunctionT, StateMutability from vyper.semantics.types.utils import type_from_annotation -from vyper.utils import OrderedSet def validate_functions(vy_module: vy_ast.Module) -> None: """Analyzes a vyper ast and validates the function bodies""" err_list = ExceptionList() - seen: OrderedSet[ContractFunctionT] = OrderedSet() for node in vy_module.get_children(vy_ast.FunctionDef): - _validate_function_r(vy_module, node, seen, err_list) + _validate_function_r(vy_module, node, err_list) err_list.raise_if_not_empty() def _validate_function_r( - vy_module: vy_ast.Module, node: vy_ast.FunctionDef, seen: OrderedSet, err_list: ExceptionList + vy_module: vy_ast.Module, node: vy_ast.FunctionDef, err_list: ExceptionList ): func_t = node._metadata["func_type"] - if func_t in seen: - return - for call_t in func_t.called_functions: - if call_t in seen: - continue if isinstance(call_t, ContractFunctionT): assert isinstance(call_t.ast_def, vy_ast.FunctionDef) # help mypy - _validate_function_r(vy_module, call_t.ast_def, seen, err_list) + _validate_function_r(vy_module, call_t.ast_def, err_list) namespace = get_namespace() - # add to seen before analysing, if it throws an exception which gets - # caught, we don't want to analyse again. - seen.add(func_t) try: with namespace.enter_scope(): analyzer = FunctionAnalyzer(vy_module, node, namespace) @@ -195,6 +185,13 @@ def __init__( self.loop_variables: list[Optional[VarInfo]] = [] def analyze(self): + if self.func._analyzed: + return + + # mark seen before analysing, if analysis throws an exception which + # gets caught, we don't want to analyse again. + self.func._analyzed = True + # allow internal function params to be mutable if self.func.is_internal: location, modifiability = (DataLocation.MEMORY, Modifiability.MODIFIABLE) diff --git a/vyper/semantics/types/function.py b/vyper/semantics/types/function.py index 62f9c60585..fa2474e6ad 100644 --- a/vyper/semantics/types/function.py +++ b/vyper/semantics/types/function.py @@ -107,6 +107,8 @@ def __init__( self.ast_def = ast_def + self._analyzed = False + # a list of internal functions this function calls. # to be populated during analysis self.called_functions: OrderedSet[ContractFunctionT] = OrderedSet() From 7729ab753c459564876ba44f694cb6cb13b2622c Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Sun, 11 Feb 2024 10:14:10 -0500 Subject: [PATCH 13/26] refactor: improve the API for ContractFunctionT, protect some private members behind methods --- vyper/semantics/analysis/local.py | 17 +++++------ vyper/semantics/analysis/module.py | 2 +- vyper/semantics/types/function.py | 46 ++++++++++++++++++++++++++++-- 3 files changed, 54 insertions(+), 11 deletions(-) diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index 548d2f56dc..0ee03fc9bd 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -185,12 +185,12 @@ def __init__( self.loop_variables: list[Optional[VarInfo]] = [] def analyze(self): - if self.func._analyzed: + if self.func.analysed: return # mark seen before analysing, if analysis throws an exception which # gets caught, we don't want to analyse again. - self.func._analyzed = True + self.func.mark_analysed() # allow internal function params to be mutable if self.func.is_internal: @@ -355,7 +355,7 @@ def _check_module_use(self, target: vy_ast.ExprNode): root_module_info = module_infos[0] # log the access - self.func._used_modules.add(root_module_info) + self.func.mark_used_module(root_module_info) def visit_Assign(self, node): self._assign_helper(node) @@ -573,8 +573,8 @@ def visit(self, node, typ): if s.is_module_variable(): self.function_analyzer._check_module_use(node) - self.func._variable_writes.update(info._writes) - self.func._variable_reads.update(info._reads) + self.func.mark_variable_writes(info._writes) + self.func.mark_variable_reads(info._reads) # validate and annotate folded value if node.has_folded_value: @@ -631,13 +631,14 @@ def visit_Call(self, node: vy_ast.Call, typ: VyperType) -> None: if isinstance(func_type, ContractFunctionT): # function calls - func_info._writes.update(func_type._variable_writes) - func_info._reads.update(func_type._variable_reads) + if not func_type.from_interface: + func_info._writes.update(func_type.get_variable_writes()) + func_info._reads.update(func_type.get_variable_reads()) if self.function_analyzer: self._check_call_mutability(func_type.mutability) - for s in func_type._variable_accesses: + for s in func_type.get_variable_accesses(): if s.is_module_variable(): self.function_analyzer._check_module_use(node.func) diff --git a/vyper/semantics/analysis/module.py b/vyper/semantics/analysis/module.py index e50c3e6d6f..a7d8300083 100644 --- a/vyper/semantics/analysis/module.py +++ b/vyper/semantics/analysis/module.py @@ -244,7 +244,7 @@ def validate_used_modules(self): all_used_modules = OrderedSet() for f in module_t.functions.values(): - for u in f._used_modules: + for u in f.get_used_modules(): all_used_modules.add(u.module_t) for used_module in all_used_modules: diff --git a/vyper/semantics/types/function.py b/vyper/semantics/types/function.py index fa2474e6ad..e4b1e7a732 100644 --- a/vyper/semantics/types/function.py +++ b/vyper/semantics/types/function.py @@ -92,6 +92,7 @@ def __init__( return_type: Optional[VyperType], function_visibility: FunctionVisibility, state_mutability: StateMutability, + from_interface: bool = False, nonreentrant: Optional[str] = None, ast_def: Optional[vy_ast.VyperNode] = None, ) -> None: @@ -104,10 +105,11 @@ def __init__( self.visibility = function_visibility self.mutability = state_mutability self.nonreentrant = nonreentrant + self.from_interface = from_interface self.ast_def = ast_def - self._analyzed = False + self._analysed = False # a list of internal functions this function calls. # to be populated during analysis @@ -129,10 +131,45 @@ def __init__( self._ir_info: Any = None self._function_id: Optional[int] = None + def _protect_analysed(self): + if self.from_interface: + return + if not self._analysed: # pragma: nocover + raise CompilerPanic(f"unreachable {self}") + + def mark_analysed(self): + assert not self._analysed + self._analysed = True + @property - def _variable_accesses(self): + def analysed(self): + return self._analysed + + def get_variable_reads(self): + self._protect_analysed() + return self._variable_reads + + def get_variable_writes(self): + self._protect_analysed() + return self._variable_writes + + def get_variable_accesses(self): + self._protect_analysed() return self._variable_reads | self._variable_writes + def get_used_modules(self): + self._protect_analysed() + return self._used_modules + + def mark_used_module(self, module_info): + self._used_modules.add(module_info) + + def mark_variable_writes(self, var_infos): + self._variable_writes.update(var_infos) + + def mark_variable_reads(self, var_infos): + self._variable_reads.update(var_infos) + @property def modifiability(self): return Modifiability.from_state_mutability(self.mutability) @@ -191,6 +228,7 @@ def from_abi(cls, abi: dict) -> "ContractFunctionT": positional_args, [], return_type, + from_interface=True, function_visibility=FunctionVisibility.EXTERNAL, state_mutability=StateMutability.from_abi(abi), ) @@ -250,6 +288,7 @@ def from_InterfaceDef(cls, funcdef: vy_ast.FunctionDef) -> "ContractFunctionT": return_type, function_visibility, state_mutability, + from_interface=True, nonreentrant=None, ast_def=funcdef, ) @@ -302,6 +341,7 @@ def from_vyi(cls, funcdef: vy_ast.FunctionDef) -> "ContractFunctionT": return_type, function_visibility, state_mutability, + from_interface=True, nonreentrant=nonreentrant_key, ast_def=funcdef, ) @@ -372,6 +412,7 @@ def from_FunctionDef(cls, funcdef: vy_ast.FunctionDef) -> "ContractFunctionT": return_type, function_visibility, state_mutability, + from_interface=False, nonreentrant=nonreentrant_key, ast_def=funcdef, ) @@ -412,6 +453,7 @@ def getter_from_VariableDecl(cls, node: vy_ast.VariableDecl) -> "ContractFunctio args, [], return_type, + from_interface=False, function_visibility=FunctionVisibility.EXTERNAL, state_mutability=StateMutability.VIEW, ast_def=node, From 049f6e8c8be364e9e98b7aecde773aaa909c36d3 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Sun, 11 Feb 2024 14:28:41 -0500 Subject: [PATCH 14/26] remove protect_analysed it turns out it was not a very good idea --- vyper/semantics/types/function.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/vyper/semantics/types/function.py b/vyper/semantics/types/function.py index e4b1e7a732..1b612c9b81 100644 --- a/vyper/semantics/types/function.py +++ b/vyper/semantics/types/function.py @@ -131,12 +131,6 @@ def __init__( self._ir_info: Any = None self._function_id: Optional[int] = None - def _protect_analysed(self): - if self.from_interface: - return - if not self._analysed: # pragma: nocover - raise CompilerPanic(f"unreachable {self}") - def mark_analysed(self): assert not self._analysed self._analysed = True @@ -146,19 +140,15 @@ def analysed(self): return self._analysed def get_variable_reads(self): - self._protect_analysed() return self._variable_reads def get_variable_writes(self): - self._protect_analysed() return self._variable_writes def get_variable_accesses(self): - self._protect_analysed() return self._variable_reads | self._variable_writes def get_used_modules(self): - self._protect_analysed() return self._used_modules def mark_used_module(self, module_info): From d8353ae3e302064a98e5a938777637adac476cb3 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Sun, 11 Feb 2024 14:16:39 -0500 Subject: [PATCH 15/26] fix: struct touching --- vyper/codegen/expr.py | 61 +++++++++++++++++-------------- vyper/semantics/analysis/base.py | 56 +++++++++++++++++++--------- vyper/semantics/analysis/local.py | 12 +++--- vyper/semantics/analysis/utils.py | 17 ++++++--- 4 files changed, 89 insertions(+), 57 deletions(-) diff --git a/vyper/codegen/expr.py b/vyper/codegen/expr.py index 335cfefb87..2aed6af4b2 100644 --- a/vyper/codegen/expr.py +++ b/vyper/codegen/expr.py @@ -37,6 +37,7 @@ VyperException, tag_exceptions, ) +from vyper.semantics.analysis.base import VarAttributeInfo from vyper.semantics.types import ( AddressT, BoolT, @@ -263,24 +264,6 @@ def parse_Attribute(self): if addr.value == "address": # for `self.code` return IRnode.from_list(["~selfcode"], typ=BytesT(0)) return IRnode.from_list(["~extcode", addr], typ=BytesT(0)) - # self.x: global attribute - elif (varinfo := self.expr._expr_info.var_info) is not None: - if varinfo.is_constant: - return Expr.parse_value_expr(varinfo.decl_node.value, self.context) - - location = data_location_to_address_space( - varinfo.location, self.context.is_ctor_context - ) - - ret = IRnode.from_list( - varinfo.position.position, - typ=varinfo.typ, - location=location, - annotation="self." + self.expr.attr, - ) - ret._referenced_variables = {varinfo} - - return ret # Reserved keywords elif ( @@ -336,17 +319,39 @@ def parse_Attribute(self): "chain.id is unavailable prior to istanbul ruleset", self.expr ) return IRnode.from_list(["chainid"], typ=UINT256_T) + # Other variables - else: - sub = Expr(self.expr.value, self.context).ir_node - # contract type - if isinstance(sub.typ, InterfaceT): - # MyInterface.address - assert self.expr.attr == "address" - sub.typ = typ - return sub - if isinstance(sub.typ, StructT) and self.expr.attr in sub.typ.member_types: - return get_element_ptr(sub, self.expr.attr) + + # self.x: global attribute + if (varinfo := self.expr._expr_info.var_info) is not None and not isinstance( + varinfo, VarAttributeInfo + ): + if varinfo.is_constant: + return Expr.parse_value_expr(varinfo.decl_node.value, self.context) + + location = data_location_to_address_space( + varinfo.location, self.context.is_ctor_context + ) + + ret = IRnode.from_list( + varinfo.position.position, + typ=varinfo.typ, + location=location, + annotation="self." + self.expr.attr, + ) + ret._referenced_variables = {varinfo} + + return ret + + sub = Expr(self.expr.value, self.context).ir_node + # contract type + if isinstance(sub.typ, InterfaceT): + # MyInterface.address + assert self.expr.attr == "address" + sub.typ = typ + return sub + if isinstance(sub.typ, StructT) and self.expr.attr in sub.typ.member_types: + return get_element_ptr(sub, self.expr.attr) def parse_Subscript(self): sub = Expr(self.expr.value, self.context).ir_node diff --git a/vyper/semantics/analysis/base.py b/vyper/semantics/analysis/base.py index e59132bd82..b1ab262b9b 100644 --- a/vyper/semantics/analysis/base.py +++ b/vyper/semantics/analysis/base.py @@ -194,6 +194,29 @@ def is_constant(self): return res +@dataclass(kw_only=True) +class VarAttributeInfo(VarInfo): + attr: str + parent: VarInfo + + def __hash__(self): + return super().__hash__() + + @classmethod + def from_varinfo(cls, varinfo: VarInfo, attr: str, typ: VyperType): + location = varinfo.location + modifiability = varinfo.modifiability + return cls( + typ=typ, location=location, modifiability=modifiability, attr=attr, parent=varinfo + ) + + +@dataclass +class AttributeInfo: + attr: str + expr_info: "ExprInfo" + + @dataclass class ExprInfo: """ @@ -205,9 +228,7 @@ class ExprInfo: module_info: Optional[ModuleInfo] = None location: DataLocation = DataLocation.UNSET modifiability: Modifiability = Modifiability.MODIFIABLE - - # the chain of attribute parents for this expr - attribute_chain: list["ExprInfo"] = field(default_factory=list) + attribute_chain: list[AttributeInfo] = field(default_factory=list) def __post_init__(self): should_match = ("typ", "location", "modifiability") @@ -216,6 +237,8 @@ def __post_init__(self): if getattr(self.var_info, attr) != getattr(self, attr): raise CompilerPanic("Bad analysis: non-matching {attr}: {self}") + self.attribute_chain = self.attribute_chain or [] + self._writes: OrderedSet[VarInfo] = OrderedSet() self._reads: OrderedSet[VarInfo] = OrderedSet() @@ -223,41 +246,40 @@ def __post_init__(self): # e.x. `x` will return varinfo for `x` # `module.foo` will return varinfo for `module.foo` # `self.my_struct.x.y` will return varinfo for `self.my_struct` - def get_root_varinfo(self) -> Optional[VarInfo]: - for expr_info in self.attribute_chain + [self]: - if expr_info.var_info is not None and not isinstance(expr_info.typ, SelfT): - return expr_info.var_info + def get_closest_varinfo(self) -> Optional[VarInfo]: + for attr_info in reversed(self.attribute_chain + [self]): + var_info = getattr(attr_info, "expr_info", attr_info).var_info # type: ignore + if var_info is not None and not isinstance(var_info, SelfT): + return var_info return None @classmethod - def from_varinfo(cls, var_info: VarInfo, attribute_chain=None) -> "ExprInfo": + def from_varinfo(cls, var_info: VarInfo, **kwargs) -> "ExprInfo": return cls( var_info.typ, var_info=var_info, location=var_info.location, modifiability=var_info.modifiability, - attribute_chain=attribute_chain or [], + **kwargs, ) @classmethod - def from_moduleinfo(cls, module_info: ModuleInfo, attribute_chain=None) -> "ExprInfo": + def from_moduleinfo(cls, module_info: ModuleInfo, **kwargs) -> "ExprInfo": modifiability = Modifiability.RUNTIME_CONSTANT if module_info.ownership >= ModuleOwnership.USES: modifiability = Modifiability.MODIFIABLE return cls( - module_info.module_t, - module_info=module_info, - modifiability=modifiability, - attribute_chain=attribute_chain or [], + module_info.module_t, module_info=module_info, modifiability=modifiability, **kwargs ) - def copy_with_type(self, typ: VyperType, attribute_chain=None) -> "ExprInfo": + def copy_with_type(self, typ: VyperType, **kwargs) -> "ExprInfo": """ Return a copy of the ExprInfo but with the type set to something else """ to_copy = ("location", "modifiability") fields = {k: getattr(self, k) for k in to_copy} - if attribute_chain is not None: - fields["attribute_chain"] = attribute_chain + for t in to_copy: + assert t not in kwargs + fields.update(kwargs) return self.__class__(typ=typ, **fields) diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index 0ee03fc9bd..a2ff81a363 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -329,16 +329,14 @@ def _handle_modification(self, target: vy_ast.ExprNode): if info.modifiability == Modifiability.CONSTANT: raise ImmutableViolation("Constant value cannot be written to.") - var_info = info.get_root_varinfo() - assert var_info is not None - - info._writes.add(var_info) + assert (varinfo := info.get_closest_varinfo()) is not None + info._writes.add(varinfo) def _check_module_use(self, target: vy_ast.ExprNode): module_infos = [] for t in get_expr_info(target).attribute_chain: - if t.module_info is not None: - module_infos.append(t.module_info) + if t.expr_info.module_info is not None: + module_infos.append(t.expr_info.module_info) if len(module_infos) == 0: return @@ -444,7 +442,7 @@ def _analyse_list_iter(self, iter_node, target_type): # get the root varinfo from iter_val in case we need to peer # through folded constants info = get_expr_info(iter_val) - return info.get_root_varinfo() + return info.get_closest_varinfo() def visit_For(self, node): if not isinstance(node.target.target, vy_ast.Name): diff --git a/vyper/semantics/analysis/utils.py b/vyper/semantics/analysis/utils.py index f1f0f48a86..ea0d6bc98d 100644 --- a/vyper/semantics/analysis/utils.py +++ b/vyper/semantics/analysis/utils.py @@ -17,7 +17,13 @@ ZeroDivisionException, ) from vyper.semantics import types -from vyper.semantics.analysis.base import ExprInfo, Modifiability, ModuleInfo, VarInfo +from vyper.semantics.analysis.base import ( + ExprInfo, + Modifiability, + ModuleInfo, + VarAttributeInfo, + VarInfo, +) from vyper.semantics.analysis.levenshtein_utils import get_levenshtein_error_suggestions from vyper.semantics.namespace import get_namespace from vyper.semantics.types.base import TYPE_T, VyperType @@ -84,12 +90,11 @@ def get_expr_info(self, node: vy_ast.VyperNode, is_callable: bool = False) -> Ex # propagate the parent exprinfo members down into the new expr # note: Attribute(expr value, identifier attr) - name = node.attr info = self.get_expr_info(node.value, is_callable=is_callable) - attribute_chain = info.attribute_chain + [info] - t = info.typ.get_member(name, node) + attr = node.attr + t = info.typ.get_member(attr, node) # it's a top-level variable if isinstance(t, VarInfo): @@ -99,7 +104,9 @@ def get_expr_info(self, node: vy_ast.VyperNode, is_callable: bool = False) -> Ex return ExprInfo.from_moduleinfo(t, attribute_chain=attribute_chain) # it's something else, like my_struct.foo - return info.copy_with_type(t, attribute_chain=attribute_chain) + assert (varinfo := info.var_info) is not None + child_varinfo = VarAttributeInfo.from_varinfo(varinfo=varinfo, attr=attr, typ=t) + return ExprInfo.from_varinfo(child_varinfo) # If it's a Subscript, propagate the subscriptable varinfo if isinstance(node, vy_ast.Subscript): From 35a9bca2f067a149220e2fc599a6876787b736e1 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Sun, 11 Feb 2024 15:13:45 -0500 Subject: [PATCH 16/26] yeet VarAttributeInfo --- vyper/codegen/expr.py | 5 +-- vyper/semantics/analysis/base.py | 56 ++++++++++++++----------------- vyper/semantics/analysis/local.py | 47 +++++++++++++++----------- vyper/semantics/analysis/utils.py | 23 ++++--------- vyper/semantics/types/function.py | 6 ++-- 5 files changed, 64 insertions(+), 73 deletions(-) diff --git a/vyper/codegen/expr.py b/vyper/codegen/expr.py index 2aed6af4b2..9c7f11dcb3 100644 --- a/vyper/codegen/expr.py +++ b/vyper/codegen/expr.py @@ -37,7 +37,6 @@ VyperException, tag_exceptions, ) -from vyper.semantics.analysis.base import VarAttributeInfo from vyper.semantics.types import ( AddressT, BoolT, @@ -323,9 +322,7 @@ def parse_Attribute(self): # Other variables # self.x: global attribute - if (varinfo := self.expr._expr_info.var_info) is not None and not isinstance( - varinfo, VarAttributeInfo - ): + if (varinfo := self.expr._expr_info.var_info) is not None: if varinfo.is_constant: return Expr.parse_value_expr(varinfo.decl_node.value, self.context) diff --git a/vyper/semantics/analysis/base.py b/vyper/semantics/analysis/base.py index b1ab262b9b..40da185d0d 100644 --- a/vyper/semantics/analysis/base.py +++ b/vyper/semantics/analysis/base.py @@ -194,27 +194,15 @@ def is_constant(self): return res -@dataclass(kw_only=True) -class VarAttributeInfo(VarInfo): - attr: str - parent: VarInfo +@dataclass(frozen=True) +class VarAccess: + variable: VarInfo + attrs: tuple[str, ...] - def __hash__(self): - return super().__hash__() - - @classmethod - def from_varinfo(cls, varinfo: VarInfo, attr: str, typ: VyperType): - location = varinfo.location - modifiability = varinfo.modifiability - return cls( - typ=typ, location=location, modifiability=modifiability, attr=attr, parent=varinfo - ) - - -@dataclass -class AttributeInfo: - attr: str - expr_info: "ExprInfo" + def contains(self, other): + # VarAccess("v", ("a")) `contains` VarAccess("v", ("a", "b", "c")) + sub_attrs = other.attrs[: len(self.attrs)] + return self.variable == other.variable and sub_attrs == self.attrs @dataclass @@ -228,7 +216,8 @@ class ExprInfo: module_info: Optional[ModuleInfo] = None location: DataLocation = DataLocation.UNSET modifiability: Modifiability = Modifiability.MODIFIABLE - attribute_chain: list[AttributeInfo] = field(default_factory=list) + attribute_chain: list["ExprInfo"] = field(default_factory=list) + attr: Optional[str] = None def __post_init__(self): should_match = ("typ", "location", "modifiability") @@ -239,18 +228,24 @@ def __post_init__(self): self.attribute_chain = self.attribute_chain or [] - self._writes: OrderedSet[VarInfo] = OrderedSet() - self._reads: OrderedSet[VarInfo] = OrderedSet() + self._writes: OrderedSet[VarAccess] = OrderedSet() + self._reads: OrderedSet[VarAccess] = OrderedSet() # find exprinfo in the attribute chain which has a varinfo # e.x. `x` will return varinfo for `x` # `module.foo` will return varinfo for `module.foo` - # `self.my_struct.x.y` will return varinfo for `self.my_struct` - def get_closest_varinfo(self) -> Optional[VarInfo]: - for attr_info in reversed(self.attribute_chain + [self]): - var_info = getattr(attr_info, "expr_info", attr_info).var_info # type: ignore - if var_info is not None and not isinstance(var_info, SelfT): - return var_info + # `self.my_struct.x.y` will return varinfo for `self.my_struct.x.y` + def get_variable_access(self) -> Optional[VarAccess]: + chain = self.attribute_chain + [self] + for i, expr_info in enumerate(chain): + varinfo = expr_info.var_info + if varinfo is not None and not isinstance(varinfo, SelfT): + attrs = [] + for expr_info in chain[i:]: + if expr_info.attr is None: + continue + attrs.append(expr_info.attr) + return VarAccess(varinfo, tuple(attrs)) return None @classmethod @@ -281,5 +276,4 @@ def copy_with_type(self, typ: VyperType, **kwargs) -> "ExprInfo": fields = {k: getattr(self, k) for k in to_copy} for t in to_copy: assert t not in kwargs - fields.update(kwargs) - return self.__class__(typ=typ, **fields) + return self.__class__(typ=typ, **fields, **kwargs) diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index a2ff81a363..bc49fdf562 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -20,7 +20,7 @@ VariableDeclarationException, VyperException, ) -from vyper.semantics.analysis.base import Modifiability, ModuleOwnership, VarInfo +from vyper.semantics.analysis.base import Modifiability, ModuleOwnership, VarAccess, VarInfo from vyper.semantics.analysis.common import VyperNodeVisitorBase from vyper.semantics.analysis.utils import ( get_common_types, @@ -182,7 +182,7 @@ def __init__( self.func = fn_node._metadata["func_type"] self.expr_visitor = ExprVisitor(self) - self.loop_variables: list[Optional[VarInfo]] = [] + self.loop_variables: list[Optional[VarAccess]] = [] def analyze(self): if self.func.analysed: @@ -221,8 +221,8 @@ def analyze(self): self.expr_visitor.visit(kwarg.default_value, kwarg.typ) @contextlib.contextmanager - def enter_for_loop(self, varinfo: Optional[VarInfo]): - self.loop_variables.append(varinfo) + def enter_for_loop(self, varaccess: Optional[VarAccess]): + self.loop_variables.append(varaccess) try: yield finally: @@ -329,14 +329,19 @@ def _handle_modification(self, target: vy_ast.ExprNode): if info.modifiability == Modifiability.CONSTANT: raise ImmutableViolation("Constant value cannot be written to.") - assert (varinfo := info.get_closest_varinfo()) is not None - info._writes.add(varinfo) + base_var = target + while isinstance(base_var, vy_ast.Subscript): + base_var = base_var.value + + base_info = get_expr_info(base_var) + assert (var_access := base_info.get_variable_access()) is not None + info._writes.add(var_access) def _check_module_use(self, target: vy_ast.ExprNode): module_infos = [] for t in get_expr_info(target).attribute_chain: - if t.expr_info.module_info is not None: - module_infos.append(t.expr_info.module_info) + if t.module_info is not None: + module_infos.append(t.module_info) if len(module_infos) == 0: return @@ -442,7 +447,7 @@ def _analyse_list_iter(self, iter_node, target_type): # get the root varinfo from iter_val in case we need to peer # through folded constants info = get_expr_info(iter_val) - return info.get_closest_varinfo() + return info.get_variable_access() def visit_For(self, node): if not isinstance(node.target.target, vy_ast.Name): @@ -450,13 +455,13 @@ def visit_For(self, node): target_type = type_from_annotation(node.target.annotation, DataLocation.MEMORY) - iter_varinfo = None + iter_var = None if isinstance(node.iter, vy_ast.Call): self._analyse_range_iter(node.iter, target_type) else: - iter_varinfo = self._analyse_list_iter(node.iter, target_type) + iter_var = self._analyse_list_iter(node.iter, target_type) - with self.namespace.enter_scope(), self.enter_for_loop(iter_varinfo): + with self.namespace.enter_scope(), self.enter_for_loop(iter_var): target_name = node.target.target.id # maybe we should introduce a new Modifiability: LOOP_VARIABLE self.namespace[target_name] = VarInfo( @@ -551,24 +556,28 @@ def visit(self, node, typ): # log variable accesses. # (note writes will get logged as both read+write) - varinfo = info.var_info - if varinfo is not None: - info._reads.add(varinfo) + var_access = info.get_variable_access() + if var_access is not None: + info._reads.add(var_access) if self.function_analyzer: for s in self.function_analyzer.loop_variables: if s is None: continue - if s in info._writes: + for v in info._writes: + if not v.contains(s): + continue + msg = "Cannot modify loop variable" - if s.decl_node is not None: + var = s.variable + if var.decl_node is not None: msg += f" `{s.decl_node.target.id}`" - raise ImmutableViolation(msg, s.decl_node, node) + raise ImmutableViolation(msg, var.decl_node, node) variable_accesses = info._writes | info._reads for s in variable_accesses: - if s.is_module_variable(): + if s.variable.is_module_variable(): self.function_analyzer._check_module_use(node) self.func.mark_variable_writes(info._writes) diff --git a/vyper/semantics/analysis/utils.py b/vyper/semantics/analysis/utils.py index ea0d6bc98d..64af036242 100644 --- a/vyper/semantics/analysis/utils.py +++ b/vyper/semantics/analysis/utils.py @@ -17,13 +17,7 @@ ZeroDivisionException, ) from vyper.semantics import types -from vyper.semantics.analysis.base import ( - ExprInfo, - Modifiability, - ModuleInfo, - VarAttributeInfo, - VarInfo, -) +from vyper.semantics.analysis.base import ExprInfo, Modifiability, ModuleInfo, VarInfo from vyper.semantics.analysis.levenshtein_utils import get_levenshtein_error_suggestions from vyper.semantics.namespace import get_namespace from vyper.semantics.types.base import TYPE_T, VyperType @@ -91,28 +85,25 @@ def get_expr_info(self, node: vy_ast.VyperNode, is_callable: bool = False) -> Ex # note: Attribute(expr value, identifier attr) info = self.get_expr_info(node.value, is_callable=is_callable) + attr = node.attr + attribute_chain = info.attribute_chain + [info] - attr = node.attr t = info.typ.get_member(attr, node) # it's a top-level variable if isinstance(t, VarInfo): - return ExprInfo.from_varinfo(t, attribute_chain=attribute_chain) + return ExprInfo.from_varinfo(t, attribute_chain=attribute_chain, attr=attr) if isinstance(t, ModuleInfo): - return ExprInfo.from_moduleinfo(t, attribute_chain=attribute_chain) + return ExprInfo.from_moduleinfo(t, attribute_chain=attribute_chain, attr=attr) - # it's something else, like my_struct.foo - assert (varinfo := info.var_info) is not None - child_varinfo = VarAttributeInfo.from_varinfo(varinfo=varinfo, attr=attr, typ=t) - return ExprInfo.from_varinfo(child_varinfo) + return info.copy_with_type(t, attribute_chain=attribute_chain, attr=attr) # If it's a Subscript, propagate the subscriptable varinfo if isinstance(node, vy_ast.Subscript): info = self.get_expr_info(node.value) - attribute_chain = info.attribute_chain + [info] - return info.copy_with_type(t, attribute_chain=attribute_chain) + return info.copy_with_type(t) return ExprInfo(t) diff --git a/vyper/semantics/types/function.py b/vyper/semantics/types/function.py index 1b612c9b81..705470a798 100644 --- a/vyper/semantics/types/function.py +++ b/vyper/semantics/types/function.py @@ -21,7 +21,7 @@ Modifiability, ModuleInfo, StateMutability, - VarInfo, + VarAccess, VarOffset, ) from vyper.semantics.analysis.utils import ( @@ -119,10 +119,10 @@ def __init__( self.reachable_internal_functions: OrderedSet[ContractFunctionT] = OrderedSet() # writes to variables from this function - self._variable_writes: OrderedSet[VarInfo] = OrderedSet() + self._variable_writes: OrderedSet[VarAccess] = OrderedSet() # reads of variables from this function - self._variable_reads: OrderedSet[VarInfo] = OrderedSet() + self._variable_reads: OrderedSet[VarAccess] = OrderedSet() # list of modules used (accessed state) by this function self._used_modules: OrderedSet[ModuleInfo] = OrderedSet() From 9c2af79de3817b80cd49f7c2914c03d94337b195 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Sun, 11 Feb 2024 22:57:34 -0500 Subject: [PATCH 17/26] fix mypy --- vyper/ast/nodes.pyi | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vyper/ast/nodes.pyi b/vyper/ast/nodes.pyi index 7f863a8db9..342c84876a 100644 --- a/vyper/ast/nodes.pyi +++ b/vyper/ast/nodes.pyi @@ -200,13 +200,13 @@ class Call(ExprNode): class keyword(VyperNode): ... -class Attribute(VyperNode): +class Attribute(ExprNode): attr: str = ... value: ExprNode = ... -class Subscript(VyperNode): - slice: VyperNode = ... - value: VyperNode = ... +class Subscript(ExprNode): + slice: ExprNode = ... + value: ExprNode = ... class Assign(VyperNode): ... From 3d32b7683c913d5d13efbea8e3ca2fda67127c52 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Sun, 11 Feb 2024 23:14:50 -0500 Subject: [PATCH 18/26] fix more complicated case --- vyper/semantics/analysis/local.py | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index bc49fdf562..3e58cf6bfd 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -169,6 +169,15 @@ def _validate_self_reference(node: vy_ast.Name) -> None: raise StateAccessViolation("not allowed to query self in pure functions", node) +def _get_base_var(node: vy_ast.ExprNode): + info = get_expr_info(node) + while (var_access := info.get_variable_access()) is None: + assert isinstance(node, (vy_ast.Subscript, vy_ast.Attribute)) + node = node.value + info = get_expr_info(node) + return var_access + + class FunctionAnalyzer(VyperNodeVisitorBase): ignored_types = (vy_ast.Pass,) scope_name = "function" @@ -329,12 +338,9 @@ def _handle_modification(self, target: vy_ast.ExprNode): if info.modifiability == Modifiability.CONSTANT: raise ImmutableViolation("Constant value cannot be written to.") - base_var = target - while isinstance(base_var, vy_ast.Subscript): - base_var = base_var.value + var_access = _get_base_var(target) + assert var_access is not None - base_info = get_expr_info(base_var) - assert (var_access := base_info.get_variable_access()) is not None info._writes.add(var_access) def _check_module_use(self, target: vy_ast.ExprNode): @@ -446,8 +452,7 @@ def _analyse_list_iter(self, iter_node, target_type): # get the root varinfo from iter_val in case we need to peer # through folded constants - info = get_expr_info(iter_val) - return info.get_variable_access() + return _get_base_var(iter_val) def visit_For(self, node): if not isinstance(node.target.target, vy_ast.Name): @@ -561,6 +566,8 @@ def visit(self, node, typ): info._reads.add(var_access) if self.function_analyzer: + # note to self: check if moving this to _handle_modification + # breaks tests for s in self.function_analyzer.loop_variables: if s is None: continue From 0a5376cbfa8ed94e2597fd8b58deaa1487250312 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Sun, 11 Feb 2024 23:27:22 -0500 Subject: [PATCH 19/26] fix bugs --- vyper/semantics/analysis/base.py | 17 ++++++++++------- vyper/semantics/analysis/local.py | 17 +++++++++++------ 2 files changed, 21 insertions(+), 13 deletions(-) diff --git a/vyper/semantics/analysis/base.py b/vyper/semantics/analysis/base.py index 40da185d0d..573da6d3cd 100644 --- a/vyper/semantics/analysis/base.py +++ b/vyper/semantics/analysis/base.py @@ -239,13 +239,16 @@ def get_variable_access(self) -> Optional[VarAccess]: chain = self.attribute_chain + [self] for i, expr_info in enumerate(chain): varinfo = expr_info.var_info - if varinfo is not None and not isinstance(varinfo, SelfT): - attrs = [] - for expr_info in chain[i:]: - if expr_info.attr is None: - continue - attrs.append(expr_info.attr) - return VarAccess(varinfo, tuple(attrs)) + if varinfo is None or isinstance(varinfo.typ, SelfT): + continue + + attrs = [] + for expr_info in chain[i:]: + if expr_info.attr is None: + continue + attrs.append(expr_info.attr) + return VarAccess(varinfo, tuple(attrs)) + return None @classmethod diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index 3e58cf6bfd..24c323d98a 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -172,7 +172,8 @@ def _validate_self_reference(node: vy_ast.Name) -> None: def _get_base_var(node: vy_ast.ExprNode): info = get_expr_info(node) while (var_access := info.get_variable_access()) is None: - assert isinstance(node, (vy_ast.Subscript, vy_ast.Attribute)) + if not isinstance(node, (vy_ast.Subscript, vy_ast.Attribute)): + return None node = node.value info = get_expr_info(node) return var_access @@ -579,7 +580,7 @@ def visit(self, node, typ): msg = "Cannot modify loop variable" var = s.variable if var.decl_node is not None: - msg += f" `{s.decl_node.target.id}`" + msg += f" `{var.decl_node.target.id}`" raise ImmutableViolation(msg, var.decl_node, node) variable_accesses = info._writes | info._reads @@ -646,14 +647,18 @@ def visit_Call(self, node: vy_ast.Call, typ: VyperType) -> None: # function calls if not func_type.from_interface: - func_info._writes.update(func_type.get_variable_writes()) - func_info._reads.update(func_type.get_variable_reads()) + for s in func_type.get_variable_writes(): + if s.variable.is_module_variable(): + func_info._writes.add(s) + for s in func_type.get_variable_reads(): + if s.variable.is_module_variable(): + func_info._reads.add(s) if self.function_analyzer: self._check_call_mutability(func_type.mutability) for s in func_type.get_variable_accesses(): - if s.is_module_variable(): + if s.variable.is_module_variable(): self.function_analyzer._check_module_use(node.func) if func_type.is_deploy and not self.func.is_deploy: @@ -684,7 +689,7 @@ def visit_Call(self, node: vy_ast.Call, typ: VyperType) -> None: elif isinstance(func_type, MemberFunctionT): if func_type.is_modifying and self.function_analyzer is not None: # TODO refactor this - self.function_analyzer._handle_modification(node.func) + self.function_analyzer._handle_modification(node.func.value) assert len(node.args) == len(func_type.arg_types) for arg, arg_type in zip(node.args, func_type.arg_types): self.visit(arg, arg_type) From 5e9006679b3782a35b2f61b9a63f8d9d23b69dfc Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Sun, 11 Feb 2024 23:28:26 -0500 Subject: [PATCH 20/26] fix mypy --- vyper/semantics/analysis/local.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index 24c323d98a..384f3e519d 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -689,6 +689,7 @@ def visit_Call(self, node: vy_ast.Call, typ: VyperType) -> None: elif isinstance(func_type, MemberFunctionT): if func_type.is_modifying and self.function_analyzer is not None: # TODO refactor this + assert isinstance(node.func, vy_ast.Attribute) # help mypy self.function_analyzer._handle_modification(node.func.value) assert len(node.args) == len(func_type.arg_types) for arg, arg_type in zip(node.args, func_type.arg_types): From b647e97b64cb40f8e867500e0e10ca7fab580bf4 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Sun, 11 Feb 2024 23:32:10 -0500 Subject: [PATCH 21/26] add more tests --- .../unit/semantics/analysis/test_for_loop.py | 41 +++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/tests/unit/semantics/analysis/test_for_loop.py b/tests/unit/semantics/analysis/test_for_loop.py index c755706d01..c97c9c095e 100644 --- a/tests/unit/semantics/analysis/test_for_loop.py +++ b/tests/unit/semantics/analysis/test_for_loop.py @@ -193,11 +193,52 @@ def foo(): for i: uint256 in self.a: self.b[self.a[1]] = i """ + vyper_module = parse_to_ast(code) + validate_semantics(vyper_module, dummy_input_bundle) + + +def test_modify_iterator_siblings(dummy_input_bundle): + # test we can modify siblings in an access tree + code = """ +struct Foo: + a: uint256[2] + b: uint256 + +f: Foo +@external +def foo(): + for i: uint256 in self.f.a: + self.f.b += i + """ vyper_module = parse_to_ast(code) validate_semantics(vyper_module, dummy_input_bundle) +def test_modify_subscript_barrier(dummy_input_bundle): + # test that Subscript nodes are a barrier for analysis + code = """ +struct Foo: + x: uint256[2] + y: uint256 + +struct Bar: + f: Foo[2] + +b: Bar + +@external +def foo(): + for i: uint256 in self.b.f[1].x: + self.b.f[0].y += i + """ + vyper_module = parse_to_ast(code) + with pytest.raises(ImmutableViolation) as e: + validate_semantics(vyper_module, dummy_input_bundle) + + assert e.value._message == "Cannot modify loop variable `b`" + + iterator_inference_codes = [ """ @external From 111333e34d9ec61d1cc7fd7f043608f5e0f55229 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Sun, 11 Feb 2024 23:36:53 -0500 Subject: [PATCH 22/26] remove a comment --- vyper/semantics/analysis/local.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index 384f3e519d..314b385b7e 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -567,8 +567,6 @@ def visit(self, node, typ): info._reads.add(var_access) if self.function_analyzer: - # note to self: check if moving this to _handle_modification - # breaks tests for s in self.function_analyzer.loop_variables: if s is None: continue From 274de108bf1dc5432f4315be3e0f2e0532293564 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Mon, 12 Feb 2024 09:10:16 -0500 Subject: [PATCH 23/26] refactor get_variable_access --- vyper/semantics/analysis/base.py | 21 ------------------- vyper/semantics/analysis/local.py | 35 +++++++++++++++++++++++++------ 2 files changed, 29 insertions(+), 27 deletions(-) diff --git a/vyper/semantics/analysis/base.py b/vyper/semantics/analysis/base.py index 573da6d3cd..ff269aa2cf 100644 --- a/vyper/semantics/analysis/base.py +++ b/vyper/semantics/analysis/base.py @@ -7,7 +7,6 @@ from vyper.exceptions import CompilerPanic, StructureException from vyper.semantics.data_locations import DataLocation from vyper.semantics.types.base import VyperType -from vyper.semantics.types.primitives import SelfT from vyper.utils import OrderedSet, StringEnum if TYPE_CHECKING: @@ -231,26 +230,6 @@ def __post_init__(self): self._writes: OrderedSet[VarAccess] = OrderedSet() self._reads: OrderedSet[VarAccess] = OrderedSet() - # find exprinfo in the attribute chain which has a varinfo - # e.x. `x` will return varinfo for `x` - # `module.foo` will return varinfo for `module.foo` - # `self.my_struct.x.y` will return varinfo for `self.my_struct.x.y` - def get_variable_access(self) -> Optional[VarAccess]: - chain = self.attribute_chain + [self] - for i, expr_info in enumerate(chain): - varinfo = expr_info.var_info - if varinfo is None or isinstance(varinfo.typ, SelfT): - continue - - attrs = [] - for expr_info in chain[i:]: - if expr_info.attr is None: - continue - attrs.append(expr_info.attr) - return VarAccess(varinfo, tuple(attrs)) - - return None - @classmethod def from_varinfo(cls, var_info: VarInfo, **kwargs) -> "ExprInfo": return cls( diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index 314b385b7e..0c58f47ea4 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -169,14 +169,37 @@ def _validate_self_reference(node: vy_ast.Name) -> None: raise StateAccessViolation("not allowed to query self in pure functions", node) -def _get_base_var(node: vy_ast.ExprNode): +# analyse the variable access for the attribute chain for a node +# e.x. `x` will return varinfo for `x` +# `module.foo` will return VarAccess for `module.foo` +# `self.my_struct.x.y` will return VarAccess for `self.my_struct.x.y` +def _get_variable_access(node: vy_ast.ExprNode) -> Optional[VarAccess]: + attrs: list[str] = [] info = get_expr_info(node) - while (var_access := info.get_variable_access()) is None: + + while info.var_info is None: if not isinstance(node, (vy_ast.Subscript, vy_ast.Attribute)): + # it's something like a literal return None + + if isinstance(node, vy_ast.Subscript): + # Subscript is an analysis barrier + # we cannot analyse if `x.y[ix1].z` overlaps with `x.y[ix2].z`. + attrs.clear() + + if (attr := info.attr) is not None: + attrs.append(attr) + + assert isinstance(node, (vy_ast.Subscript, vy_ast.Attribute)) # help mypy node = node.value info = get_expr_info(node) - return var_access + + # ignore `self.` as it interferes with VarAccess comparison across modules + if len(attrs) > 0 and attrs[-1] == "self": + attrs.pop() + attrs.reverse() + + return VarAccess(info.var_info, tuple(attrs)) class FunctionAnalyzer(VyperNodeVisitorBase): @@ -339,7 +362,7 @@ def _handle_modification(self, target: vy_ast.ExprNode): if info.modifiability == Modifiability.CONSTANT: raise ImmutableViolation("Constant value cannot be written to.") - var_access = _get_base_var(target) + var_access = _get_variable_access(target) assert var_access is not None info._writes.add(var_access) @@ -453,7 +476,7 @@ def _analyse_list_iter(self, iter_node, target_type): # get the root varinfo from iter_val in case we need to peer # through folded constants - return _get_base_var(iter_val) + return _get_variable_access(iter_val) def visit_For(self, node): if not isinstance(node.target.target, vy_ast.Name): @@ -562,7 +585,7 @@ def visit(self, node, typ): # log variable accesses. # (note writes will get logged as both read+write) - var_access = info.get_variable_access() + var_access = _get_variable_access(node) if var_access is not None: info._reads.add(var_access) From b4e739011cbef30b985ff0233f1cd9a1780c8c54 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Mon, 12 Feb 2024 17:54:32 -0500 Subject: [PATCH 24/26] remove attribute_chain, use explicit traversal of the attribute/subscript tree --- vyper/semantics/analysis/base.py | 5 +---- vyper/semantics/analysis/local.py | 33 ++++++++++++++++++++++++++----- vyper/semantics/analysis/utils.py | 8 +++----- 3 files changed, 32 insertions(+), 14 deletions(-) diff --git a/vyper/semantics/analysis/base.py b/vyper/semantics/analysis/base.py index ff269aa2cf..49b867aae5 100644 --- a/vyper/semantics/analysis/base.py +++ b/vyper/semantics/analysis/base.py @@ -1,5 +1,5 @@ import enum -from dataclasses import dataclass, field +from dataclasses import dataclass from typing import TYPE_CHECKING, Dict, Optional, Union from vyper import ast as vy_ast @@ -215,7 +215,6 @@ class ExprInfo: module_info: Optional[ModuleInfo] = None location: DataLocation = DataLocation.UNSET modifiability: Modifiability = Modifiability.MODIFIABLE - attribute_chain: list["ExprInfo"] = field(default_factory=list) attr: Optional[str] = None def __post_init__(self): @@ -225,8 +224,6 @@ def __post_init__(self): if getattr(self.var_info, attr) != getattr(self, attr): raise CompilerPanic("Bad analysis: non-matching {attr}: {self}") - self.attribute_chain = self.attribute_chain or [] - self._writes: OrderedSet[VarAccess] = OrderedSet() self._reads: OrderedSet[VarAccess] = OrderedSet() diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index 0c58f47ea4..29c779a29c 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -20,7 +20,13 @@ VariableDeclarationException, VyperException, ) -from vyper.semantics.analysis.base import Modifiability, ModuleOwnership, VarAccess, VarInfo +from vyper.semantics.analysis.base import ( + Modifiability, + ModuleInfo, + ModuleOwnership, + VarAccess, + VarInfo, +) from vyper.semantics.analysis.common import VyperNodeVisitorBase from vyper.semantics.analysis.utils import ( get_common_types, @@ -202,6 +208,26 @@ def _get_variable_access(node: vy_ast.ExprNode) -> Optional[VarAccess]: return VarAccess(info.var_info, tuple(attrs)) +# get the chain of modules, e.g. +# mod1.mod2.x.y -> [ModuleInfo(mod1), ModuleInfo(mod2)] +# CMC 2024-02-12 note that the Attribute/Subscript traversal in this and +# _get_variable_access() are a bit gross and could probably +# be refactored into data on ExprInfo. +def _get_module_chain(node: vy_ast.ExprNode) -> list[ModuleInfo]: + ret: list[ModuleInfo] = [] + info = get_expr_info(node) + + while isinstance(node, (vy_ast.Subscript, vy_ast.Attribute)): + if info.module_info is not None: + ret.append(info.module_info) + + node = node.value + info = get_expr_info(node) + + ret.reverse() + return ret + + class FunctionAnalyzer(VyperNodeVisitorBase): ignored_types = (vy_ast.Pass,) scope_name = "function" @@ -368,10 +394,7 @@ def _handle_modification(self, target: vy_ast.ExprNode): info._writes.add(var_access) def _check_module_use(self, target: vy_ast.ExprNode): - module_infos = [] - for t in get_expr_info(target).attribute_chain: - if t.module_info is not None: - module_infos.append(t.module_info) + module_infos = _get_module_chain(target) if len(module_infos) == 0: return diff --git a/vyper/semantics/analysis/utils.py b/vyper/semantics/analysis/utils.py index 64af036242..034cd8c46e 100644 --- a/vyper/semantics/analysis/utils.py +++ b/vyper/semantics/analysis/utils.py @@ -87,18 +87,16 @@ def get_expr_info(self, node: vy_ast.VyperNode, is_callable: bool = False) -> Ex info = self.get_expr_info(node.value, is_callable=is_callable) attr = node.attr - attribute_chain = info.attribute_chain + [info] - t = info.typ.get_member(attr, node) # it's a top-level variable if isinstance(t, VarInfo): - return ExprInfo.from_varinfo(t, attribute_chain=attribute_chain, attr=attr) + return ExprInfo.from_varinfo(t, attr=attr) if isinstance(t, ModuleInfo): - return ExprInfo.from_moduleinfo(t, attribute_chain=attribute_chain, attr=attr) + return ExprInfo.from_moduleinfo(t, attr=attr) - return info.copy_with_type(t, attribute_chain=attribute_chain, attr=attr) + return info.copy_with_type(t, attr=attr) # If it's a Subscript, propagate the subscriptable varinfo if isinstance(node, vy_ast.Subscript): From 68faed2ffa6e524fc48c602c2339a5ad5dfb93a3 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Mon, 12 Feb 2024 18:01:05 -0500 Subject: [PATCH 25/26] fix: while -> do while --- vyper/semantics/analysis/local.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index 29c779a29c..39a1c59290 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -217,10 +217,13 @@ def _get_module_chain(node: vy_ast.ExprNode) -> list[ModuleInfo]: ret: list[ModuleInfo] = [] info = get_expr_info(node) - while isinstance(node, (vy_ast.Subscript, vy_ast.Attribute)): + while True: if info.module_info is not None: ret.append(info.module_info) + if not isinstance(node, (vy_ast.Subscript, vy_ast.Attribute)): + break + node = node.value info = get_expr_info(node) From 591b97b50c8f3ac0b4b828202e04a2a0ec7fe351 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Mon, 12 Feb 2024 18:03:15 -0500 Subject: [PATCH 26/26] add some subscript/attribute tests for module uses (note these also pass on master) --- .../syntax/modules/test_initializers.py | 42 +++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/tests/functional/syntax/modules/test_initializers.py b/tests/functional/syntax/modules/test_initializers.py index a12f5f57ea..0412e83c7d 100644 --- a/tests/functional/syntax/modules/test_initializers.py +++ b/tests/functional/syntax/modules/test_initializers.py @@ -741,6 +741,48 @@ def foo(new_value: uint256): assert e.value._hint == expected_hint +def test_missing_uses_subscript(make_input_bundle): + # test missing uses through nested subscript/attribute access + lib1 = """ +struct Foo: + array: uint256[5] + +foos: Foo[5] + """ + lib2 = """ +import lib1 + +counter: uint256 + +@internal +def foo(): + pass + """ + main = """ +import lib1 +import lib2 + +initializes: lib1 + +# did not `use` or `initialize` lib2! + +@external +def foo(new_value: uint256): + # cannot access lib1 state through lib2 + lib2.lib1.foos[0].array[1] = new_value + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + with pytest.raises(ImmutableViolation) as e: + compile_code(main, input_bundle=input_bundle) + + assert e.value._message == "Cannot access `lib2` state!" + + expected_hint = "add `uses: lib2` or `initializes: lib2` as a " + expected_hint += "top-level statement to your contract" + assert e.value._hint == expected_hint + + def test_missing_uses_nested_attribute_function_call(make_input_bundle): # test missing uses through nested attribute access lib1 = """