Skip to content

Commit

Permalink
fix[lang]: pure access analysis (#3895)
Browse files Browse the repository at this point in the history
this commit fixes the ability to access module variables in pure
functions. the `_validate_self_reference()` utility function was
hard-coded to check the "self" name; remove it and replace with an
analysis-based check.

this commit also fixes the pure access check for immutable variables,
and address members (e.g. `.codesize`)

misc/refactor:
* rename `VarInfo.is_module_variable()` to more fitting
  `VarInfo.is_state_variable()`.
* refactor pure decorator tests to be in line with
  recent best practices
  • Loading branch information
charles-cooper committed Apr 2, 2024
1 parent 4595938 commit 20432c5
Show file tree
Hide file tree
Showing 6 changed files with 146 additions and 76 deletions.
154 changes: 105 additions & 49 deletions tests/functional/codegen/features/decorators/test_pure.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,22 @@
import pytest

from vyper.compiler import compile_code
from vyper.exceptions import FunctionDeclarationException, StateAccessViolation


def test_pure_operation(get_contract_with_gas_estimation_for_constants):
c = get_contract_with_gas_estimation_for_constants(
"""
def test_pure_operation(get_contract):
code = """
@pure
@external
def foo() -> int128:
return 5
"""
)
c = get_contract(code)
assert c.foo() == 5


def test_pure_call(get_contract_with_gas_estimation_for_constants):
c = get_contract_with_gas_estimation_for_constants(
"""
def test_pure_call(get_contract):
code = """
@pure
@internal
def _foo() -> int128:
Expand All @@ -26,21 +27,18 @@ def _foo() -> int128:
def foo() -> int128:
return self._foo()
"""
)
c = get_contract(code)
assert c.foo() == 5


def test_pure_interface(get_contract_with_gas_estimation_for_constants):
c1 = get_contract_with_gas_estimation_for_constants(
"""
def test_pure_interface(get_contract):
code1 = """
@pure
@external
def foo() -> int128:
return 5
"""
)
c2 = get_contract_with_gas_estimation_for_constants(
"""
code2 = """
interface Foo:
def foo() -> int128: pure
Expand All @@ -49,58 +47,120 @@ def foo() -> int128: pure
def foo(a: address) -> int128:
return staticcall Foo(a).foo()
"""
)
c1 = get_contract(code1)
c2 = get_contract(code2)
assert c2.foo(c1.address) == 5


def test_invalid_envar_access(get_contract, assert_compile_failed):
assert_compile_failed(
lambda: get_contract(
"""
def test_invalid_envar_access(get_contract):
code = """
@pure
@external
def foo() -> uint256:
return chain.id
"""
),
StateAccessViolation,
)
with pytest.raises(StateAccessViolation):
compile_code(code)


def test_invalid_codesize_access(get_contract):
code = """
@pure
@external
def foo(s: address) -> uint256:
return s.codesize
"""
with pytest.raises(StateAccessViolation):
compile_code(code)


def test_invalid_state_access(get_contract, assert_compile_failed):
assert_compile_failed(
lambda: get_contract(
"""
code = """
x: uint256
@pure
@external
def foo() -> uint256:
return self.x
"""
),
StateAccessViolation,
)
with pytest.raises(StateAccessViolation):
compile_code(code)


def test_invalid_immutable_access():
code = """
COUNTER: immutable(uint256)
@deploy
def __init__():
COUNTER = 1234
@pure
@external
def foo() -> uint256:
return COUNTER
"""
with pytest.raises(StateAccessViolation):
compile_code(code)


def test_invalid_self_access(get_contract, assert_compile_failed):
assert_compile_failed(
lambda: get_contract(
"""
def test_invalid_self_access():
code = """
@pure
@external
def foo() -> address:
return self
"""
),
StateAccessViolation,
)
with pytest.raises(StateAccessViolation):
compile_code(code)


def test_invalid_module_variable_access(make_input_bundle):
lib1 = """
counter: uint256
"""
code = """
import lib1
initializes: lib1
@pure
@external
def foo() -> uint256:
return lib1.counter
"""
input_bundle = make_input_bundle({"lib1.vy": lib1})
with pytest.raises(StateAccessViolation):
compile_code(code, input_bundle=input_bundle)


def test_invalid_module_immutable_access(make_input_bundle):
lib1 = """
COUNTER: immutable(uint256)
@deploy
def __init__():
COUNTER = 123
"""
code = """
import lib1
initializes: lib1
@deploy
def __init__():
lib1.__init__()
@pure
@external
def foo() -> uint256:
return lib1.COUNTER
"""
input_bundle = make_input_bundle({"lib1.vy": lib1})
with pytest.raises(StateAccessViolation):
compile_code(code, input_bundle=input_bundle)


def test_invalid_call(get_contract, assert_compile_failed):
assert_compile_failed(
lambda: get_contract(
"""
def test_invalid_call():
code = """
@view
@internal
def _foo() -> uint256:
Expand All @@ -111,21 +171,17 @@ def _foo() -> uint256:
def foo() -> uint256:
return self._foo() # Fails because of calling non-pure fn
"""
),
StateAccessViolation,
)
with pytest.raises(StateAccessViolation):
compile_code(code)


def test_invalid_conflicting_decorators(get_contract, assert_compile_failed):
assert_compile_failed(
lambda: get_contract(
"""
def test_invalid_conflicting_decorators():
code = """
@pure
@external
@payable
def foo() -> uint256:
return 5
"""
),
FunctionDeclarationException,
)
with pytest.raises(FunctionDeclarationException):
compile_code(code)
2 changes: 1 addition & 1 deletion vyper/ast/nodes.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ class Dict(VyperNode):
keys: list = ...
values: list = ...

class Name(VyperNode):
class Name(ExprNode):
id: str = ...
_type: str = ...

Expand Down
8 changes: 6 additions & 2 deletions vyper/semantics/analysis/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from vyper.exceptions import CompilerPanic, StructureException
from vyper.semantics.data_locations import DataLocation
from vyper.semantics.types.base import VyperType
from vyper.semantics.types.primitives import SelfT
from vyper.utils import OrderedSet, StringEnum

if TYPE_CHECKING:
Expand Down Expand Up @@ -199,8 +200,11 @@ def set_position(self, position: VarOffset) -> None:
assert isinstance(position, VarOffset) # sanity check
self.position = position

def is_module_variable(self):
return self.location not in (DataLocation.UNSET, DataLocation.MEMORY)
def is_state_variable(self):
non_state_locations = (DataLocation.UNSET, DataLocation.MEMORY, DataLocation.CALLDATA)
# `self` gets a VarInfo, but it is not considered a state
# variable (it is magic), so we ignore it here.
return self.location not in non_state_locations and not isinstance(self.typ, SelfT)

def get_size(self) -> int:
return self.typ.get_size_in(self.location)
Expand Down
4 changes: 2 additions & 2 deletions vyper/semantics/analysis/data_positions.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,9 +282,9 @@ def _allocate_layout_r(
continue

assert isinstance(node, vy_ast.VariableDecl)
# skip non-storage variables
# skip non-state variables
varinfo = node.target._metadata["varinfo"]
if not varinfo.is_module_variable():
if not varinfo.is_state_variable():
continue
location = varinfo.location

Expand Down
52 changes: 31 additions & 21 deletions vyper/semantics/analysis/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
from vyper.semantics.data_locations import DataLocation

# TODO consolidate some of these imports
from vyper.semantics.environment import CONSTANT_ENVIRONMENT_VARS, MUTABLE_ENVIRONMENT_VARS
from vyper.semantics.environment import CONSTANT_ENVIRONMENT_VARS
from vyper.semantics.namespace import get_namespace
from vyper.semantics.types import (
TYPE_T,
Expand All @@ -51,6 +51,7 @@
HashMapT,
IntegerT,
SArrayT,
SelfT,
StringT,
StructT,
TupleT,
Expand Down Expand Up @@ -164,21 +165,32 @@ def _validate_msg_value_access(node: vy_ast.Attribute) -> None:
raise NonPayableViolation("msg.value is not allowed in non-payable functions", node)


def _validate_pure_access(node: vy_ast.Attribute, typ: VyperType) -> None:
env_vars = set(CONSTANT_ENVIRONMENT_VARS.keys()) | set(MUTABLE_ENVIRONMENT_VARS.keys())
if isinstance(node.value, vy_ast.Name) and node.value.id in env_vars:
if isinstance(typ, ContractFunctionT) and typ.mutability == StateMutability.PURE:
return
def _validate_pure_access(node: vy_ast.Attribute | vy_ast.Name) -> None:
info = get_expr_info(node)

raise StateAccessViolation(
"not allowed to query contract or environment variables in pure functions", node
)
env_vars = CONSTANT_ENVIRONMENT_VARS
# check env variable access like `block.number`
if isinstance(node, vy_ast.Attribute):
if node.get("value.id") in env_vars:
raise StateAccessViolation(
"not allowed to query environment variables in pure functions"
)
parent_info = get_expr_info(node.value)
if isinstance(parent_info.typ, AddressT) and node.attr in AddressT._type_members:
raise StateAccessViolation("not allowed to query address members in pure functions")

if (varinfo := info.var_info) is None:
return
# self is magic. we only need to check it if it is not the root of an Attribute
# node. (i.e. it is bare like `self`, not `self.foo`)
is_naked_self = isinstance(varinfo.typ, SelfT) and not isinstance(
node.get_ancestor(), vy_ast.Attribute
)
if is_naked_self:
raise StateAccessViolation("not allowed to query `self` in pure functions")

def _validate_self_reference(node: vy_ast.Name) -> None:
# CMC 2023-10-19 this detector seems sus, things like `a.b(self)` could slip through
if node.id == "self" and not isinstance(node.get_ancestor(), vy_ast.Attribute):
raise StateAccessViolation("not allowed to query self in pure functions", node)
if varinfo.is_state_variable() or is_naked_self:
raise StateAccessViolation("not allowed to query state variables in pure functions")


# analyse the variable access for the attribute chain for a node
Expand Down Expand Up @@ -429,7 +441,7 @@ def _handle_modification(self, target: vy_ast.ExprNode):
info._writes.add(var_access)

def _handle_module_access(self, var_access: VarAccess, target: vy_ast.ExprNode):
if not var_access.variable.is_module_variable():
if not var_access.variable.is_state_variable():
return

root_module_info = check_module_uses(target)
Expand Down Expand Up @@ -693,7 +705,7 @@ def visit_Attribute(self, node: vy_ast.Attribute, typ: VyperType) -> None:
_validate_msg_value_access(node)

if self.func and self.func.mutability == StateMutability.PURE:
_validate_pure_access(node, typ)
_validate_pure_access(node)

value_type = get_exact_type_from_node(node.value)

Expand Down Expand Up @@ -765,10 +777,10 @@ def visit_Call(self, node: vy_ast.Call, typ: VyperType) -> None:

if not func_type.from_interface:
for s in func_type.get_variable_writes():
if s.variable.is_module_variable():
if s.variable.is_state_variable():
func_info._writes.add(s)
for s in func_type.get_variable_reads():
if s.variable.is_module_variable():
if s.variable.is_state_variable():
func_info._reads.add(s)

if self.function_analyzer:
Expand Down Expand Up @@ -873,10 +885,8 @@ def visit_List(self, node: vy_ast.List, typ: VyperType) -> None:
self.visit(element, typ.value_type)

def visit_Name(self, node: vy_ast.Name, typ: VyperType) -> None:
if self.func:
# TODO: refactor to use expr_info mutability
if self.func.mutability == StateMutability.PURE:
_validate_self_reference(node)
if self.func and self.func.mutability == StateMutability.PURE:
_validate_pure_access(node)

def visit_Subscript(self, node: vy_ast.Subscript, typ: VyperType) -> None:
if isinstance(typ, TYPE_T):
Expand Down
2 changes: 1 addition & 1 deletion vyper/semantics/analysis/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,7 +536,7 @@ def visit_ExportsDecl(self, node):

# check module uses
var_accesses = func_t.get_variable_accesses()
if any(s.variable.is_module_variable() for s in var_accesses):
if any(s.variable.is_state_variable() for s in var_accesses):
module_info = check_module_uses(item)
assert module_info is not None # guaranteed by above checks
used_modules.add(module_info)
Expand Down

0 comments on commit 20432c5

Please sign in to comment.