Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: builtin functions inherit from VyperType #3559

Merged
merged 12 commits into from
Nov 22, 2023
25 changes: 13 additions & 12 deletions vyper/builtins/_signatures.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import functools
from typing import Dict
from typing import Any, Optional

from vyper.ast import nodes as vy_ast
from vyper.ast.validation import validate_call_args
Expand Down Expand Up @@ -74,12 +74,14 @@ def decorator_fn(self, node, context):
return decorator_fn


class BuiltinFunction(VyperType):
class BuiltinFunctionT(VyperType):
_has_varargs = False
_kwargs: Dict[str, KwargSettings] = {}
_inputs: list[tuple[str, Any]] = []
_kwargs: dict[str, KwargSettings] = {}
_return_type: Optional[VyperType] = None

# helper function to deal with TYPE_DEFINITIONs
def _validate_single(self, arg, expected_type):
def _validate_single(self, arg: vy_ast.VyperNode, expected_type: VyperType) -> None:
# TODO using "TYPE_DEFINITION" is a kludge in derived classes,
# refactor me.
if expected_type == "TYPE_DEFINITION":
Expand All @@ -89,15 +91,15 @@ def _validate_single(self, arg, expected_type):
else:
validate_expected_type(arg, expected_type)

def _validate_arg_types(self, node):
def _validate_arg_types(self, node: vy_ast.Call) -> None:
num_args = len(self._inputs) # the number of args the signature indicates

expect_num_args = num_args
expect_num_args: Any = num_args
if self._has_varargs:
# note special meaning for -1 in validate_call_args API
expect_num_args = (num_args, -1)

validate_call_args(node, expect_num_args, self._kwargs)
validate_call_args(node, expect_num_args, list(self._kwargs.keys()))

for arg, (_, expected) in zip(node.args, self._inputs):
self._validate_single(arg, expected)
Expand All @@ -118,13 +120,12 @@ def _validate_arg_types(self, node):
# ensures the type can be inferred exactly.
get_exact_type_from_node(arg)

def fetch_call_return(self, node):
def fetch_call_return(self, node: vy_ast.Call) -> Optional[VyperType]:
self._validate_arg_types(node)

if self._return_type:
return self._return_type
return self._return_type

def infer_arg_types(self, node):
def infer_arg_types(self, node: vy_ast.Call) -> list[VyperType]:
self._validate_arg_types(node)
ret = [expected for (_, expected) in self._inputs]

Expand All @@ -136,7 +137,7 @@ def infer_arg_types(self, node):
ret.extend(get_exact_type_from_node(arg) for arg in varargs)
return ret

def infer_kwarg_types(self, node):
def infer_kwarg_types(self, node: vy_ast.Call) -> dict[str, VyperType]:
return {i.arg: self._kwargs[i.arg].typ for i in node.keywords}

def __repr__(self):
Expand Down
Loading
Loading