diff --git a/tests/functional/codegen/features/iteration/test_for_in_list.py b/tests/functional/codegen/features/iteration/test_for_in_list.py index 36252701c4..e1bd8f313d 100644 --- a/tests/functional/codegen/features/iteration/test_for_in_list.py +++ b/tests/functional/codegen/features/iteration/test_for_in_list.py @@ -3,6 +3,7 @@ import pytest +from vyper.compiler import compile_code from vyper.exceptions import ( ArgumentException, ImmutableViolation, @@ -841,6 +842,59 @@ def foo(): ] +# TODO: move these to tests/functional/syntax @pytest.mark.parametrize("code,err", BAD_CODE, ids=bad_code_names) def test_bad_code(assert_compile_failed, get_contract, code, err): - assert_compile_failed(lambda: get_contract(code), err) + with pytest.raises(err): + compile_code(code) + + +def test_iterator_modification_module_attribute(make_input_bundle): + # test modifying iterator via attribute + lib1 = """ +queue: DynArray[uint256, 5] + """ + main = """ +import lib1 + +initializes: lib1 + +@external +def foo(): + for i: uint256 in lib1.queue: + lib1.queue.pop() + """ + + input_bundle = make_input_bundle({"lib1.vy": lib1}) + + with pytest.raises(ImmutableViolation) as e: + compile_code(main, input_bundle=input_bundle) + + assert e.value._message == "Cannot modify loop variable `queue`" + + +def test_iterator_modification_module_function_call(make_input_bundle): + lib1 = """ +queue: DynArray[uint256, 5] + +@internal +def popqueue(): + self.queue.pop() + """ + main = """ +import lib1 + +initializes: lib1 + +@external +def foo(): + for i: uint256 in lib1.queue: + lib1.popqueue() + """ + + input_bundle = make_input_bundle({"lib1.vy": lib1}) + + with pytest.raises(ImmutableViolation) as e: + compile_code(main, input_bundle=input_bundle) + + assert e.value._message == "Cannot modify loop variable `queue`" diff --git a/tests/functional/syntax/modules/test_initializers.py b/tests/functional/syntax/modules/test_initializers.py index a12f5f57ea..0412e83c7d 100644 --- a/tests/functional/syntax/modules/test_initializers.py +++ b/tests/functional/syntax/modules/test_initializers.py @@ -741,6 +741,48 @@ def foo(new_value: uint256): assert e.value._hint == expected_hint +def test_missing_uses_subscript(make_input_bundle): + # test missing uses through nested subscript/attribute access + lib1 = """ +struct Foo: + array: uint256[5] + +foos: Foo[5] + """ + lib2 = """ +import lib1 + +counter: uint256 + +@internal +def foo(): + pass + """ + main = """ +import lib1 +import lib2 + +initializes: lib1 + +# did not `use` or `initialize` lib2! + +@external +def foo(new_value: uint256): + # cannot access lib1 state through lib2 + lib2.lib1.foos[0].array[1] = new_value + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + with pytest.raises(ImmutableViolation) as e: + compile_code(main, input_bundle=input_bundle) + + assert e.value._message == "Cannot access `lib2` state!" + + expected_hint = "add `uses: lib2` or `initializes: lib2` as a " + expected_hint += "top-level statement to your contract" + assert e.value._hint == expected_hint + + def test_missing_uses_nested_attribute_function_call(make_input_bundle): # test missing uses through nested attribute access lib1 = """ diff --git a/tests/unit/semantics/analysis/test_for_loop.py b/tests/unit/semantics/analysis/test_for_loop.py index 607587cc28..c97c9c095e 100644 --- a/tests/unit/semantics/analysis/test_for_loop.py +++ b/tests/unit/semantics/analysis/test_for_loop.py @@ -134,6 +134,111 @@ def baz(): validate_semantics(vyper_module, dummy_input_bundle) +def test_modify_iterator_recursive_function_call_topsort(dummy_input_bundle): + # test the analysis works no matter the order of functions + code = """ +a: uint256[3] + +@internal +def baz(): + for i: uint256 in self.a: + self.bar() + +@internal +def bar(): + self.foo() + +@internal +def foo(): + self.a[0] = 1 + """ + vyper_module = parse_to_ast(code) + with pytest.raises(ImmutableViolation) as e: + validate_semantics(vyper_module, dummy_input_bundle) + + assert e.value._message == "Cannot modify loop variable `a`" + + +def test_modify_iterator_through_struct(dummy_input_bundle): + # GH issue 3429 + code = """ +struct A: + iter: DynArray[uint256, 5] + +a: A + +@external +def foo(): + self.a.iter = [1, 2, 3] + for i: uint256 in self.a.iter: + self.a = A({iter: [1, 2, 3, 4]}) + """ + vyper_module = parse_to_ast(code) + with pytest.raises(ImmutableViolation) as e: + validate_semantics(vyper_module, dummy_input_bundle) + + assert e.value._message == "Cannot modify loop variable `a`" + + +def test_modify_iterator_complex_expr(dummy_input_bundle): + # GH issue 3429 + # avoid false positive! + code = """ +a: DynArray[uint256, 5] +b: uint256[10] + +@external +def foo(): + self.a = [1, 2, 3] + for i: uint256 in self.a: + self.b[self.a[1]] = i + """ + vyper_module = parse_to_ast(code) + validate_semantics(vyper_module, dummy_input_bundle) + + +def test_modify_iterator_siblings(dummy_input_bundle): + # test we can modify siblings in an access tree + code = """ +struct Foo: + a: uint256[2] + b: uint256 + +f: Foo + +@external +def foo(): + for i: uint256 in self.f.a: + self.f.b += i + """ + vyper_module = parse_to_ast(code) + validate_semantics(vyper_module, dummy_input_bundle) + + +def test_modify_subscript_barrier(dummy_input_bundle): + # test that Subscript nodes are a barrier for analysis + code = """ +struct Foo: + x: uint256[2] + y: uint256 + +struct Bar: + f: Foo[2] + +b: Bar + +@external +def foo(): + for i: uint256 in self.b.f[1].x: + self.b.f[0].y += i + """ + vyper_module = parse_to_ast(code) + with pytest.raises(ImmutableViolation) as e: + validate_semantics(vyper_module, dummy_input_bundle) + + assert e.value._message == "Cannot modify loop variable `b`" + + iterator_inference_codes = [ """ @external diff --git a/vyper/ast/nodes.pyi b/vyper/ast/nodes.pyi index 7f863a8db9..342c84876a 100644 --- a/vyper/ast/nodes.pyi +++ b/vyper/ast/nodes.pyi @@ -200,13 +200,13 @@ class Call(ExprNode): class keyword(VyperNode): ... -class Attribute(VyperNode): +class Attribute(ExprNode): attr: str = ... value: ExprNode = ... -class Subscript(VyperNode): - slice: VyperNode = ... - value: VyperNode = ... +class Subscript(ExprNode): + slice: ExprNode = ... + value: ExprNode = ... class Assign(VyperNode): ... diff --git a/vyper/codegen/expr.py b/vyper/codegen/expr.py index 335cfefb87..9c7f11dcb3 100644 --- a/vyper/codegen/expr.py +++ b/vyper/codegen/expr.py @@ -263,24 +263,6 @@ def parse_Attribute(self): if addr.value == "address": # for `self.code` return IRnode.from_list(["~selfcode"], typ=BytesT(0)) return IRnode.from_list(["~extcode", addr], typ=BytesT(0)) - # self.x: global attribute - elif (varinfo := self.expr._expr_info.var_info) is not None: - if varinfo.is_constant: - return Expr.parse_value_expr(varinfo.decl_node.value, self.context) - - location = data_location_to_address_space( - varinfo.location, self.context.is_ctor_context - ) - - ret = IRnode.from_list( - varinfo.position.position, - typ=varinfo.typ, - location=location, - annotation="self." + self.expr.attr, - ) - ret._referenced_variables = {varinfo} - - return ret # Reserved keywords elif ( @@ -336,17 +318,37 @@ def parse_Attribute(self): "chain.id is unavailable prior to istanbul ruleset", self.expr ) return IRnode.from_list(["chainid"], typ=UINT256_T) + # Other variables - else: - sub = Expr(self.expr.value, self.context).ir_node - # contract type - if isinstance(sub.typ, InterfaceT): - # MyInterface.address - assert self.expr.attr == "address" - sub.typ = typ - return sub - if isinstance(sub.typ, StructT) and self.expr.attr in sub.typ.member_types: - return get_element_ptr(sub, self.expr.attr) + + # self.x: global attribute + if (varinfo := self.expr._expr_info.var_info) is not None: + if varinfo.is_constant: + return Expr.parse_value_expr(varinfo.decl_node.value, self.context) + + location = data_location_to_address_space( + varinfo.location, self.context.is_ctor_context + ) + + ret = IRnode.from_list( + varinfo.position.position, + typ=varinfo.typ, + location=location, + annotation="self." + self.expr.attr, + ) + ret._referenced_variables = {varinfo} + + return ret + + sub = Expr(self.expr.value, self.context).ir_node + # contract type + if isinstance(sub.typ, InterfaceT): + # MyInterface.address + assert self.expr.attr == "address" + sub.typ = typ + return sub + if isinstance(sub.typ, StructT) and self.expr.attr in sub.typ.member_types: + return get_element_ptr(sub, self.expr.attr) def parse_Subscript(self): sub = Expr(self.expr.value, self.context).ir_node diff --git a/vyper/semantics/analysis/base.py b/vyper/semantics/analysis/base.py index 2086e5f9da..49b867aae5 100644 --- a/vyper/semantics/analysis/base.py +++ b/vyper/semantics/analysis/base.py @@ -1,5 +1,5 @@ import enum -from dataclasses import dataclass, field +from dataclasses import dataclass from typing import TYPE_CHECKING, Dict, Optional, Union from vyper import ast as vy_ast @@ -193,6 +193,17 @@ def is_constant(self): return res +@dataclass(frozen=True) +class VarAccess: + variable: VarInfo + attrs: tuple[str, ...] + + def contains(self, other): + # VarAccess("v", ("a")) `contains` VarAccess("v", ("a", "b", "c")) + sub_attrs = other.attrs[: len(self.attrs)] + return self.variable == other.variable and sub_attrs == self.attrs + + @dataclass class ExprInfo: """ @@ -204,9 +215,7 @@ class ExprInfo: module_info: Optional[ModuleInfo] = None location: DataLocation = DataLocation.UNSET modifiability: Modifiability = Modifiability.MODIFIABLE - - # the chain of attribute parents for this expr - attribute_chain: list["ExprInfo"] = field(default_factory=list) + attr: Optional[str] = None def __post_init__(self): should_match = ("typ", "location", "modifiability") @@ -215,48 +224,35 @@ def __post_init__(self): if getattr(self.var_info, attr) != getattr(self, attr): raise CompilerPanic("Bad analysis: non-matching {attr}: {self}") - self._writes: OrderedSet[VarInfo] = OrderedSet() - self._reads: OrderedSet[VarInfo] = OrderedSet() - - # find exprinfo in the attribute chain which has a varinfo - # e.x. `x` will return varinfo for `x` - # `module.foo` will return varinfo for `module.foo` - # `self.my_struct.x.y` will return varinfo for `self.my_struct` - def get_root_varinfo(self) -> Optional[VarInfo]: - for expr_info in self.attribute_chain + [self]: - if expr_info.var_info is not None: - return expr_info.var_info - return None + self._writes: OrderedSet[VarAccess] = OrderedSet() + self._reads: OrderedSet[VarAccess] = OrderedSet() @classmethod - def from_varinfo(cls, var_info: VarInfo, attribute_chain=None) -> "ExprInfo": + def from_varinfo(cls, var_info: VarInfo, **kwargs) -> "ExprInfo": return cls( var_info.typ, var_info=var_info, location=var_info.location, modifiability=var_info.modifiability, - attribute_chain=attribute_chain or [], + **kwargs, ) @classmethod - def from_moduleinfo(cls, module_info: ModuleInfo, attribute_chain=None) -> "ExprInfo": + def from_moduleinfo(cls, module_info: ModuleInfo, **kwargs) -> "ExprInfo": modifiability = Modifiability.RUNTIME_CONSTANT if module_info.ownership >= ModuleOwnership.USES: modifiability = Modifiability.MODIFIABLE return cls( - module_info.module_t, - module_info=module_info, - modifiability=modifiability, - attribute_chain=attribute_chain or [], + module_info.module_t, module_info=module_info, modifiability=modifiability, **kwargs ) - def copy_with_type(self, typ: VyperType, attribute_chain=None) -> "ExprInfo": + def copy_with_type(self, typ: VyperType, **kwargs) -> "ExprInfo": """ Return a copy of the ExprInfo but with the type set to something else """ to_copy = ("location", "modifiability") fields = {k: getattr(self, k) for k in to_copy} - if attribute_chain is not None: - fields["attribute_chain"] = attribute_chain - return self.__class__(typ=typ, **fields) + for t in to_copy: + assert t not in kwargs + return self.__class__(typ=typ, **fields, **kwargs) diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index d96215ede0..39a1c59290 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -1,5 +1,6 @@ # CMC 2024-02-03 TODO: split me into function.py and expr.py +import contextlib from typing import Optional from vyper import ast as vy_ast @@ -19,7 +20,13 @@ VariableDeclarationException, VyperException, ) -from vyper.semantics.analysis.base import Modifiability, ModuleOwnership, VarInfo +from vyper.semantics.analysis.base import ( + Modifiability, + ModuleInfo, + ModuleOwnership, + VarAccess, + VarInfo, +) from vyper.semantics.analysis.common import VyperNodeVisitorBase from vyper.semantics.analysis.utils import ( get_common_types, @@ -58,18 +65,33 @@ def validate_functions(vy_module: vy_ast.Module) -> None: """Analyzes a vyper ast and validates the function bodies""" err_list = ExceptionList() - namespace = get_namespace() + for node in vy_module.get_children(vy_ast.FunctionDef): - with namespace.enter_scope(): - try: - analyzer = FunctionAnalyzer(vy_module, node, namespace) - analyzer.analyze() - except VyperException as e: - err_list.append(e) + _validate_function_r(vy_module, node, err_list) err_list.raise_if_not_empty() +def _validate_function_r( + vy_module: vy_ast.Module, node: vy_ast.FunctionDef, err_list: ExceptionList +): + func_t = node._metadata["func_type"] + + for call_t in func_t.called_functions: + if isinstance(call_t, ContractFunctionT): + assert isinstance(call_t.ast_def, vy_ast.FunctionDef) # help mypy + _validate_function_r(vy_module, call_t.ast_def, err_list) + + namespace = get_namespace() + + try: + with namespace.enter_scope(): + analyzer = FunctionAnalyzer(vy_module, node, namespace) + analyzer.analyze() + except VyperException as e: + err_list.append(e) + + # finds the terminus node for a list of nodes. # raises an exception if any nodes are unreachable def find_terminating_node(node_list: list) -> Optional[vy_ast.VyperNode]: @@ -99,36 +121,6 @@ def find_terminating_node(node_list: list) -> Optional[vy_ast.VyperNode]: return ret -def _check_iterator_modification( - target_node: vy_ast.VyperNode, search_node: vy_ast.VyperNode -) -> Optional[vy_ast.VyperNode]: - similar_nodes = [ - n - for n in search_node.get_descendants(type(target_node)) - if vy_ast.compare_nodes(target_node, n) - ] - - for node in similar_nodes: - # raise if the node is the target of an assignment statement - assign_node = node.get_ancestor((vy_ast.Assign, vy_ast.AugAssign)) - # note the use of get_descendants() blocks statements like - # self.my_array[i] = x - if assign_node and node in assign_node.target.get_descendants(include_self=True): - return node - - attr_node = node.get_ancestor(vy_ast.Attribute) - # note the use of get_descendants() blocks statements like - # self.my_array[i].append(x) - if ( - attr_node is not None - and node in attr_node.value.get_descendants(include_self=True) - and attr_node.attr in ("append", "pop", "extend") - ): - return node - - return None - - # helpers def _validate_address_code(node: vy_ast.Attribute, value_type: VyperType) -> None: if isinstance(value_type, AddressT) and node.attr == "code": @@ -183,6 +175,62 @@ def _validate_self_reference(node: vy_ast.Name) -> None: raise StateAccessViolation("not allowed to query self in pure functions", node) +# analyse the variable access for the attribute chain for a node +# e.x. `x` will return varinfo for `x` +# `module.foo` will return VarAccess for `module.foo` +# `self.my_struct.x.y` will return VarAccess for `self.my_struct.x.y` +def _get_variable_access(node: vy_ast.ExprNode) -> Optional[VarAccess]: + attrs: list[str] = [] + info = get_expr_info(node) + + while info.var_info is None: + if not isinstance(node, (vy_ast.Subscript, vy_ast.Attribute)): + # it's something like a literal + return None + + if isinstance(node, vy_ast.Subscript): + # Subscript is an analysis barrier + # we cannot analyse if `x.y[ix1].z` overlaps with `x.y[ix2].z`. + attrs.clear() + + if (attr := info.attr) is not None: + attrs.append(attr) + + assert isinstance(node, (vy_ast.Subscript, vy_ast.Attribute)) # help mypy + node = node.value + info = get_expr_info(node) + + # ignore `self.` as it interferes with VarAccess comparison across modules + if len(attrs) > 0 and attrs[-1] == "self": + attrs.pop() + attrs.reverse() + + return VarAccess(info.var_info, tuple(attrs)) + + +# get the chain of modules, e.g. +# mod1.mod2.x.y -> [ModuleInfo(mod1), ModuleInfo(mod2)] +# CMC 2024-02-12 note that the Attribute/Subscript traversal in this and +# _get_variable_access() are a bit gross and could probably +# be refactored into data on ExprInfo. +def _get_module_chain(node: vy_ast.ExprNode) -> list[ModuleInfo]: + ret: list[ModuleInfo] = [] + info = get_expr_info(node) + + while True: + if info.module_info is not None: + ret.append(info.module_info) + + if not isinstance(node, (vy_ast.Subscript, vy_ast.Attribute)): + break + + node = node.value + info = get_expr_info(node) + + ret.reverse() + return ret + + class FunctionAnalyzer(VyperNodeVisitorBase): ignored_types = (vy_ast.Pass,) scope_name = "function" @@ -196,7 +244,16 @@ def __init__( self.func = fn_node._metadata["func_type"] self.expr_visitor = ExprVisitor(self) + self.loop_variables: list[Optional[VarAccess]] = [] + def analyze(self): + if self.func.analysed: + return + + # mark seen before analysing, if analysis throws an exception which + # gets caught, we don't want to analyse again. + self.func.mark_analysed() + # allow internal function params to be mutable if self.func.is_internal: location, modifiability = (DataLocation.MEMORY, Modifiability.MODIFIABLE) @@ -225,6 +282,14 @@ def analyze(self): for kwarg in self.func.keyword_args: self.expr_visitor.visit(kwarg.default_value, kwarg.typ) + @contextlib.contextmanager + def enter_for_loop(self, varaccess: Optional[VarAccess]): + self.loop_variables.append(varaccess) + try: + yield + finally: + self.loop_variables.pop() + def visit(self, node): super().visit(node) @@ -326,16 +391,13 @@ def _handle_modification(self, target: vy_ast.ExprNode): if info.modifiability == Modifiability.CONSTANT: raise ImmutableViolation("Constant value cannot be written to.") - var_info = info.get_root_varinfo() - assert var_info is not None + var_access = _get_variable_access(target) + assert var_access is not None - info._writes.add(var_info) + info._writes.add(var_access) def _check_module_use(self, target: vy_ast.ExprNode): - module_infos = [] - for t in get_expr_info(target).attribute_chain: - if t.module_info is not None: - module_infos.append(t.module_info) + module_infos = _get_module_chain(target) if len(module_infos) == 0: return @@ -352,7 +414,7 @@ def _check_module_use(self, target: vy_ast.ExprNode): root_module_info = module_infos[0] # log the access - self.func._used_modules.add(root_module_info) + self.func.mark_used_module(root_module_info) def visit_Assign(self, node): self._assign_helper(node) @@ -403,96 +465,68 @@ def visit_Expr(self, node): ) self.expr_visitor.visit(node.value, return_value) + def _analyse_range_iter(self, iter_node, target_type): + # iteration via range() + if iter_node.get("func.id") != "range": + raise IteratorException("Cannot iterate over the result of a function call", iter_node) + _validate_range_call(iter_node) + + args = iter_node.args + kwargs = [s.value for s in iter_node.keywords] + for arg in (*args, *kwargs): + self.expr_visitor.visit(arg, target_type) + + def _analyse_list_iter(self, iter_node, target_type): + # iteration over a variable or literal list + iter_val = iter_node + if iter_val.has_folded_value: + iter_val = iter_val.get_folded_value() + + if isinstance(iter_val, vy_ast.List): + len_ = len(iter_val.elements) + if len_ == 0: + raise StructureException("For loop must have at least 1 iteration", iter_node) + iter_type = SArrayT(target_type, len_) + else: + try: + iter_type = get_exact_type_from_node(iter_node) + except (InvalidType, StructureException): + raise InvalidType("Not an iterable type", iter_node) + + # CMC 2024-02-09 TODO: use validate_expected_type once we have DArrays + # with generic length. + if not isinstance(iter_type, (DArrayT, SArrayT)): + raise InvalidType("Not an iterable type", iter_node) + + self.expr_visitor.visit(iter_node, iter_type) + + # get the root varinfo from iter_val in case we need to peer + # through folded constants + return _get_variable_access(iter_val) + def visit_For(self, node): if not isinstance(node.target.target, vy_ast.Name): raise StructureException("Invalid syntax for loop iterator", node.target.target) target_type = type_from_annotation(node.target.annotation, DataLocation.MEMORY) + iter_var = None if isinstance(node.iter, vy_ast.Call): - # iteration via range() - if node.iter.get("func.id") != "range": - raise IteratorException( - "Cannot iterate over the result of a function call", node.iter - ) - _validate_range_call(node.iter) - + self._analyse_range_iter(node.iter, target_type) else: - # iteration over a variable or literal list - iter_val = node.iter.get_folded_value() if node.iter.has_folded_value else node.iter - if isinstance(iter_val, vy_ast.List) and len(iter_val.elements) == 0: - raise StructureException("For loop must have at least 1 iteration", node.iter) - - if not any( - isinstance(i, (DArrayT, SArrayT)) for i in get_possible_types_from_node(node.iter) - ): - raise InvalidType("Not an iterable type", node.iter) - - if isinstance(node.iter, (vy_ast.Name, vy_ast.Attribute)): - # check for references to the iterated value within the body of the loop - assign = _check_iterator_modification(node.iter, node) - if assign: - raise ImmutableViolation("Cannot modify array during iteration", assign) - - # Check if `iter` is a storage variable. get_descendants` is used to check for - # nested `self` (e.g. structs) - # NOTE: this analysis will be borked once stateful modules are allowed! - iter_is_storage_var = ( - isinstance(node.iter, vy_ast.Attribute) - and len(node.iter.get_descendants(vy_ast.Name, {"id": "self"})) > 0 - ) - - if iter_is_storage_var: - # check if iterated value may be modified by function calls inside the loop - iter_name = node.iter.attr - for call_node in node.get_descendants(vy_ast.Call, {"func.value.id": "self"}): - fn_name = call_node.func.attr - - fn_node = self.vyper_module.get_children(vy_ast.FunctionDef, {"name": fn_name})[0] - if _check_iterator_modification(node.iter, fn_node): - # check for direct modification - raise ImmutableViolation( - f"Cannot call '{fn_name}' inside for loop, it potentially " - f"modifies iterated storage variable '{iter_name}'", - call_node, - ) + iter_var = self._analyse_list_iter(node.iter, target_type) - for reachable_t in ( - self.namespace["self"].typ.members[fn_name].reachable_internal_functions - ): - # check for indirect modification - name = reachable_t.name - fn_node = self.vyper_module.get_children(vy_ast.FunctionDef, {"name": name})[0] - if _check_iterator_modification(node.iter, fn_node): - raise ImmutableViolation( - f"Cannot call '{fn_name}' inside for loop, it may call to '{name}' " - f"which potentially modifies iterated storage variable '{iter_name}'", - call_node, - ) - - target_name = node.target.target.id - with self.namespace.enter_scope(): + with self.namespace.enter_scope(), self.enter_for_loop(iter_var): + target_name = node.target.target.id + # maybe we should introduce a new Modifiability: LOOP_VARIABLE self.namespace[target_name] = VarInfo( target_type, modifiability=Modifiability.RUNTIME_CONSTANT ) + self.expr_visitor.visit(node.target.target, target_type) for stmt in node.body: self.visit(stmt) - self.expr_visitor.visit(node.target.target, target_type) - - if isinstance(node.iter, vy_ast.List): - len_ = len(node.iter.elements) - self.expr_visitor.visit(node.iter, SArrayT(target_type, len_)) - elif isinstance(node.iter, vy_ast.Call) and node.iter.func.id == "range": - args = node.iter.args - kwargs = [s.value for s in node.iter.keywords] - for arg in (*args, *kwargs): - self.expr_visitor.visit(arg, target_type) - else: - iter_type = get_exact_type_from_node(node.iter) - self.expr_visitor.visit(node.iter, iter_type) - def visit_If(self, node): self.expr_visitor.visit(node.test, BoolT()) with self.namespace.enter_scope(): @@ -577,18 +611,32 @@ def visit(self, node, typ): # log variable accesses. # (note writes will get logged as both read+write) - varinfo = info.var_info - if varinfo is not None: - info._reads.add(varinfo) + var_access = _get_variable_access(node) + if var_access is not None: + info._reads.add(var_access) + + if self.function_analyzer: + for s in self.function_analyzer.loop_variables: + if s is None: + continue + + for v in info._writes: + if not v.contains(s): + continue + + msg = "Cannot modify loop variable" + var = s.variable + if var.decl_node is not None: + msg += f" `{var.decl_node.target.id}`" + raise ImmutableViolation(msg, var.decl_node, node) - if self.func: variable_accesses = info._writes | info._reads for s in variable_accesses: - if s.is_module_variable(): + if s.variable.is_module_variable(): self.function_analyzer._check_module_use(node) - self.func._variable_writes.update(info._writes) - self.func._variable_reads.update(info._reads) + self.func.mark_variable_writes(info._writes) + self.func.mark_variable_reads(info._reads) # validate and annotate folded value if node.has_folded_value: @@ -641,24 +689,23 @@ def _check_call_mutability(self, call_mutability: StateMutability): def visit_Call(self, node: vy_ast.Call, typ: VyperType) -> None: func_info = get_expr_info(node.func, is_callable=True) func_type = func_info.typ - self.visit(node.func, func_type) if isinstance(func_type, ContractFunctionT): # function calls - func_info._writes.update(func_type._variable_writes) - func_info._reads.update(func_type._variable_reads) + if not func_type.from_interface: + for s in func_type.get_variable_writes(): + if s.variable.is_module_variable(): + func_info._writes.add(s) + for s in func_type.get_variable_reads(): + if s.variable.is_module_variable(): + func_info._reads.add(s) if self.function_analyzer: - if func_type.is_internal: - self.func.called_functions.add(func_type) - self._check_call_mutability(func_type.mutability) - # check that if the function accesses state, the defining - # module has been `used` or `initialized`. - for s in func_type._variable_accesses: - if s.is_module_variable(): + for s in func_type.get_variable_accesses(): + if s.variable.is_module_variable(): self.function_analyzer._check_module_use(node.func) if func_type.is_deploy and not self.func.is_deploy: @@ -689,7 +736,8 @@ def visit_Call(self, node: vy_ast.Call, typ: VyperType) -> None: elif isinstance(func_type, MemberFunctionT): if func_type.is_modifying and self.function_analyzer is not None: # TODO refactor this - self.function_analyzer._handle_modification(node.func) + assert isinstance(node.func, vy_ast.Attribute) # help mypy + self.function_analyzer._handle_modification(node.func.value) assert len(node.args) == len(func_type.arg_types) for arg, arg_type in zip(node.args, func_type.arg_types): self.visit(arg, arg_type) @@ -702,6 +750,8 @@ def visit_Call(self, node: vy_ast.Call, typ: VyperType) -> None: for kwarg in node.keywords: self.visit(kwarg.value, kwarg_types[kwarg.arg]) + self.visit(node.func, func_type) + def visit_Compare(self, node: vy_ast.Compare, typ: VyperType) -> None: if isinstance(node.op, (vy_ast.In, vy_ast.NotIn)): # membership in list literal - `x in [a, b, c]` diff --git a/vyper/semantics/analysis/module.py b/vyper/semantics/analysis/module.py index e50c3e6d6f..a7d8300083 100644 --- a/vyper/semantics/analysis/module.py +++ b/vyper/semantics/analysis/module.py @@ -244,7 +244,7 @@ def validate_used_modules(self): all_used_modules = OrderedSet() for f in module_t.functions.values(): - for u in f._used_modules: + for u in f.get_used_modules(): all_used_modules.add(u.module_t) for used_module in all_used_modules: diff --git a/vyper/semantics/analysis/utils.py b/vyper/semantics/analysis/utils.py index f1f0f48a86..034cd8c46e 100644 --- a/vyper/semantics/analysis/utils.py +++ b/vyper/semantics/analysis/utils.py @@ -84,28 +84,24 @@ def get_expr_info(self, node: vy_ast.VyperNode, is_callable: bool = False) -> Ex # propagate the parent exprinfo members down into the new expr # note: Attribute(expr value, identifier attr) - name = node.attr info = self.get_expr_info(node.value, is_callable=is_callable) + attr = node.attr - attribute_chain = info.attribute_chain + [info] - - t = info.typ.get_member(name, node) + t = info.typ.get_member(attr, node) # it's a top-level variable if isinstance(t, VarInfo): - return ExprInfo.from_varinfo(t, attribute_chain=attribute_chain) + return ExprInfo.from_varinfo(t, attr=attr) if isinstance(t, ModuleInfo): - return ExprInfo.from_moduleinfo(t, attribute_chain=attribute_chain) + return ExprInfo.from_moduleinfo(t, attr=attr) - # it's something else, like my_struct.foo - return info.copy_with_type(t, attribute_chain=attribute_chain) + return info.copy_with_type(t, attr=attr) # If it's a Subscript, propagate the subscriptable varinfo if isinstance(node, vy_ast.Subscript): info = self.get_expr_info(node.value) - attribute_chain = info.attribute_chain + [info] - return info.copy_with_type(t, attribute_chain=attribute_chain) + return info.copy_with_type(t) return ExprInfo(t) diff --git a/vyper/semantics/environment.py b/vyper/semantics/environment.py index 38bac0a63d..94a26157af 100644 --- a/vyper/semantics/environment.py +++ b/vyper/semantics/environment.py @@ -1,7 +1,7 @@ from typing import Dict from vyper.semantics.analysis.base import Modifiability, VarInfo -from vyper.semantics.types import AddressT, BytesT, VyperType +from vyper.semantics.types import AddressT, BytesT, SelfT, VyperType from vyper.semantics.types.shortcuts import BYTES32_T, UINT256_T @@ -57,7 +57,7 @@ def get_constant_vars() -> Dict: return result -MUTABLE_ENVIRONMENT_VARS: Dict[str, type] = {"self": AddressT} +MUTABLE_ENVIRONMENT_VARS: Dict[str, type] = {"self": SelfT} def get_mutable_vars() -> Dict: diff --git a/vyper/semantics/types/__init__.py b/vyper/semantics/types/__init__.py index a04632b96f..59a20dd99f 100644 --- a/vyper/semantics/types/__init__.py +++ b/vyper/semantics/types/__init__.py @@ -3,7 +3,7 @@ from .bytestrings import BytesT, StringT, _BytestringT from .function import MemberFunctionT from .module import InterfaceT -from .primitives import AddressT, BoolT, BytesM_T, DecimalT, IntegerT +from .primitives import AddressT, BoolT, BytesM_T, DecimalT, IntegerT, SelfT from .subscriptable import DArrayT, HashMapT, SArrayT, TupleT from .user import EventT, FlagT, StructT diff --git a/vyper/semantics/types/function.py b/vyper/semantics/types/function.py index 62f9c60585..705470a798 100644 --- a/vyper/semantics/types/function.py +++ b/vyper/semantics/types/function.py @@ -21,7 +21,7 @@ Modifiability, ModuleInfo, StateMutability, - VarInfo, + VarAccess, VarOffset, ) from vyper.semantics.analysis.utils import ( @@ -92,6 +92,7 @@ def __init__( return_type: Optional[VyperType], function_visibility: FunctionVisibility, state_mutability: StateMutability, + from_interface: bool = False, nonreentrant: Optional[str] = None, ast_def: Optional[vy_ast.VyperNode] = None, ) -> None: @@ -104,9 +105,12 @@ def __init__( self.visibility = function_visibility self.mutability = state_mutability self.nonreentrant = nonreentrant + self.from_interface = from_interface self.ast_def = ast_def + self._analysed = False + # a list of internal functions this function calls. # to be populated during analysis self.called_functions: OrderedSet[ContractFunctionT] = OrderedSet() @@ -115,10 +119,10 @@ def __init__( self.reachable_internal_functions: OrderedSet[ContractFunctionT] = OrderedSet() # writes to variables from this function - self._variable_writes: OrderedSet[VarInfo] = OrderedSet() + self._variable_writes: OrderedSet[VarAccess] = OrderedSet() # reads of variables from this function - self._variable_reads: OrderedSet[VarInfo] = OrderedSet() + self._variable_reads: OrderedSet[VarAccess] = OrderedSet() # list of modules used (accessed state) by this function self._used_modules: OrderedSet[ModuleInfo] = OrderedSet() @@ -127,10 +131,35 @@ def __init__( self._ir_info: Any = None self._function_id: Optional[int] = None + def mark_analysed(self): + assert not self._analysed + self._analysed = True + @property - def _variable_accesses(self): + def analysed(self): + return self._analysed + + def get_variable_reads(self): + return self._variable_reads + + def get_variable_writes(self): + return self._variable_writes + + def get_variable_accesses(self): return self._variable_reads | self._variable_writes + def get_used_modules(self): + return self._used_modules + + def mark_used_module(self, module_info): + self._used_modules.add(module_info) + + def mark_variable_writes(self, var_infos): + self._variable_writes.update(var_infos) + + def mark_variable_reads(self, var_infos): + self._variable_reads.update(var_infos) + @property def modifiability(self): return Modifiability.from_state_mutability(self.mutability) @@ -189,6 +218,7 @@ def from_abi(cls, abi: dict) -> "ContractFunctionT": positional_args, [], return_type, + from_interface=True, function_visibility=FunctionVisibility.EXTERNAL, state_mutability=StateMutability.from_abi(abi), ) @@ -248,6 +278,7 @@ def from_InterfaceDef(cls, funcdef: vy_ast.FunctionDef) -> "ContractFunctionT": return_type, function_visibility, state_mutability, + from_interface=True, nonreentrant=None, ast_def=funcdef, ) @@ -300,6 +331,7 @@ def from_vyi(cls, funcdef: vy_ast.FunctionDef) -> "ContractFunctionT": return_type, function_visibility, state_mutability, + from_interface=True, nonreentrant=nonreentrant_key, ast_def=funcdef, ) @@ -370,6 +402,7 @@ def from_FunctionDef(cls, funcdef: vy_ast.FunctionDef) -> "ContractFunctionT": return_type, function_visibility, state_mutability, + from_interface=False, nonreentrant=nonreentrant_key, ast_def=funcdef, ) @@ -410,6 +443,7 @@ def getter_from_VariableDecl(cls, node: vy_ast.VariableDecl) -> "ContractFunctio args, [], return_type, + from_interface=False, function_visibility=FunctionVisibility.EXTERNAL, state_mutability=StateMutability.VIEW, ast_def=node, diff --git a/vyper/semantics/types/primitives.py b/vyper/semantics/types/primitives.py index 07d1a21a94..d383f72ab2 100644 --- a/vyper/semantics/types/primitives.py +++ b/vyper/semantics/types/primitives.py @@ -340,3 +340,13 @@ def validate_literal(self, node: vy_ast.Constant) -> None: f"address, the correct checksummed form is: {checksum_encode(addr)}", node, ) + + +# type for "self" +# refactoring note: it might be best for this to be a ModuleT actually +class SelfT(AddressT): + _id = "self" + + def compare_type(self, other): + # compares true to AddressT + return isinstance(other, type(self)) or isinstance(self, type(other))