Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Type refinement on method calls #600

Merged
merged 11 commits into from
Jul 9, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ruby.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ jobs:
- uses: actions/checkout@v3
- name: Reset bundler
run: |
rm Gemfile.lock
rm Gemfile.lock Gemfile.steep.lock
if: contains(matrix.container_tag, '2.6')
- name: Run test
run: |
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@
/lib/steep/parser.output
/lib/steep/parser.rb
/log
/.gem_rbs_collection
2 changes: 2 additions & 0 deletions bin/setup
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,5 @@ set -vx

bundle install

bundle install --gemfile=Gemfile.steep
bundle exec --gemfile=Gemfile.steep rbs --collection=rbs_collection.steep.yaml collection install
2 changes: 1 addition & 1 deletion lib/steep.rb
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

require "steep/equatable"
require "steep/method_name"
require "steep/node_helper"
require "steep/ast/types/helper"
require "steep/ast/types/any"
require "steep/ast/types/instance"
Expand All @@ -47,7 +48,6 @@
require "steep/ast/types/proc"
require "steep/ast/types/record"
require "steep/ast/types/logic"
require "steep/ast/type_params"
require "steep/ast/annotation"
require "steep/ast/annotation/collection"
require "steep/ast/builtin"
Expand Down
2 changes: 1 addition & 1 deletion lib/steep/method_name.rb
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def to_s
end
end

module ::Kernel
class ::Object
def MethodName(string)
case string
when /#/
Expand Down
49 changes: 49 additions & 0 deletions lib/steep/node_helper.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
module Steep
module NodeHelper
def each_child_node(node, &block)
if block
node.children.each do |child|
if child.is_a?(Parser::AST::Node)
yield child
end
end
else
enum_for :each_child_node, node
end
end

def each_descendant_node(node, &block)
if block
each_child_node(node) do |child|
yield child
each_descendant_node(child, &block)
end
else
enum_for :each_descendant_node, node
end
end

def value_node?(node)
case node.type
when :true, :false, :str, :sym, :int, :float, :nil
true
when :lvar
true
when :const
each_child_node(node).all? {|child| child.type == :cbase || value_node?(child) }
when :array
each_child_node(node).all? {|child| value_node?(child) }
when :hash
each_child_node(node).all? do |pair|
each_child_node(pair).all? {|child| value_node?(child) }
end
when :dstr
each_child_node(node).all? {|child| value_node?(child)}
when :begin
each_child_node(node).all? {|child| value_node?(node) }
else
false
end
end
end
end
4 changes: 4 additions & 0 deletions lib/steep/shims/symbol_start_with.rb
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@ module SymbolStartWith
def start_with?(*args)
to_s.start_with?(*args)
end

def end_with?(*args)
to_s.end_with?(*args)
end
end

unless Symbol.method_defined?(:start_with?)
Expand Down
117 changes: 58 additions & 59 deletions lib/steep/type_construction.rb
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ def to_ary
end
end

include NodeHelper

def inspect
s = "#<%s:%#018x " % [self.class, object_id]
s + ">"
Expand Down Expand Up @@ -200,10 +202,10 @@ def for_new_method(method_name, node, args:, self_type:, definition:)

local_variable_types = method_params.each_param.with_object({}) do |param, hash|
if param.name
hash[param.name] = context.type_env.assignment(param.name, param.var_type)
hash[param.name] = param.var_type
end
end
type_env = context.type_env.update(local_variable_types: local_variable_types)
type_env = context.type_env.assign_local_variables(local_variable_types)

type_env = TypeInference::TypeEnvBuilder.new(
TypeInference::TypeEnvBuilder::Command::ImportLocalVariableAnnotations.new(annots).merge!.on_duplicate! do |name, original, annotated|
Expand Down Expand Up @@ -716,9 +718,9 @@ def synthesize(node, hint: nil, condition: false)
rhs_type, rhs_constr, rhs_context = synthesize(rhs, hint: hint).to_ary

constr = rhs_constr.update_type_env do |type_env|
entry = type_env.assignment(name, rhs_type)
var_type = rhs_type

if enforced_type = entry[1]
if enforced_type = type_env.enforced_type(name)
if result = no_subtyping?(sub_type: rhs_type, super_type: enforced_type)
typing.add_error(
Diagnostic::Ruby::IncompatibleAssignment.new(
Expand All @@ -729,15 +731,15 @@ def synthesize(node, hint: nil, condition: false)
)
)

entry[0] = enforced_type
var_type = enforced_type
end

if rhs_type.is_a?(AST::Types::Any)
entry[0] = enforced_type
var_type = enforced_type
end
end

type_env.merge(local_variable_types: { name => entry })
type_env.assign_local_variable(name, var_type, enforced_type)
end

constr.add_typing(node, type: rhs_type)
Expand Down Expand Up @@ -1829,14 +1831,7 @@ def synthesize(node, hint: nil, condition: false)
branch_results = []

cond_type, constr = constr.synthesize(cond).to_ary
_, cond_vars = interpreter.decompose_value(cond)
unless cond_vars.empty?
first_var = cond_vars.to_a[0]
var_node = cond.updated(:lvar, [first_var])
else
first_var = nil
var_node = cond
end
cond_value_node, cond_vars = interpreter.decompose_value(cond)

when_constr = constr
whens.each do |clause|
Expand All @@ -1849,7 +1844,7 @@ def synthesize(node, hint: nil, condition: false)
test_envs = []

tests.each do |test|
test_node = test.updated(:send, [test, :===, var_node])
test_node = test.updated(:send, [test, :===, cond])
test_type, test_constr = test_constr.synthesize(test_node, condition: true).to_ary
truthy_env, falsy_env = interpreter.eval(type: test_type, node: test_node, env: test_constr.context.type_env)

Expand All @@ -1861,20 +1856,8 @@ def synthesize(node, hint: nil, condition: false)
body_constr = when_constr.update_type_env {|env| env.join(*test_envs) }

if body
# @type var assignments: Hash[Symbol, TypeInference::TypeEnv::local_variable_entry]

if first_var
var_type = body_constr.context.type_env[first_var] or raise
assignments = cond_vars.each_with_object({}) do |var, hash|
hash[var] = body_constr.context.type_env.assignment(var, var_type)
end
else
assignments = {}
end

branch_results <<
body_constr
.update_type_env {|env| env.merge(local_variable_types: assignments) }
.for_branch(body)
.tap {|constr| typing.add_context_for_node(body, context: constr.context) }
.synthesize(body, hint: hint)
Expand All @@ -1896,7 +1879,8 @@ def synthesize(node, hint: nil, condition: false)
types = branch_results.map(&:type)
constrs = branch_results.map(&:constr)

if first_var && when_constr.context.type_env[first_var].is_a?(AST::Types::Bot)
cond_type = when_constr.context.type_env[cond_value_node]
if cond_type.is_a?(AST::Types::Bot)
# Exhaustive
if els
typing.add_error Diagnostic::Ruby::ElseOnExhaustiveCase.new(node: els, type: cond_type)
Expand Down Expand Up @@ -1998,7 +1982,6 @@ def synthesize(node, hint: nil, condition: false)
end

resbody_construction = body_constr.for_branch(resbody).update_type_env do |env|
# @type var assignments: Hash[Symbol, TypeInference::TypeEnv::local_variable_entry]
assignments = {}

case
Expand All @@ -2013,12 +1996,12 @@ def synthesize(node, hint: nil, condition: false)
end
end

assignments[var_name] = env.assignment(var_name, AST::Types::Union.build(types: instance_types))
assignments[var_name] = AST::Types::Union.build(types: instance_types)
when var_name
assignments[var_name] = env.assignment(var_name, AST::Builtin.any_type)
assignments[var_name] = AST::Builtin.any_type
end

env.merge(local_variable_types: assignments)
env.assign_local_variables(assignments)
end

if body
Expand Down Expand Up @@ -2086,8 +2069,7 @@ def synthesize(node, hint: nil, condition: false)
if var_type
if body
body_constr = constr.update_type_env do |type_env|
assign = type_env.assignment(var_name, var_type)
type_env = type_env.merge(local_variable_types: { var_name => assign })
type_env = type_env.assign_local_variables({ var_name => var_type })
pins = type_env.pin_local_variables(nil)
type_env.merge(local_variable_types: pins)
end
Expand Down Expand Up @@ -2461,23 +2443,23 @@ def masgn_lhs?(lhs)

def lvasgn(node, type)
name = node.children[0]
assignment = context.type_env.assignment(name, type)
if enforced_type = assignment[1]
if result = no_subtyping?(sub_type: assignment[0], super_type: enforced_type)

if enforced_type = context.type_env.enforced_type(name)
if result = no_subtyping?(sub_type: type, super_type: enforced_type)
typing.add_error(
Diagnostic::Ruby::IncompatibleAssignment.new(
node: node,
lhs_type: assignment[1],
rhs_type: assignment[0],
lhs_type: enforced_type,
rhs_type: type,
result: result
)
)

assignment[0] = enforced_type
type = enforced_type
end
end

update_type_env {|env| env.merge(local_variable_types: { name => assignment }) }
update_type_env {|env| env.assign_local_variable(name, type, enforced_type) }
.add_typing(node, type: type)
end

Expand Down Expand Up @@ -2880,6 +2862,20 @@ def synthesize_children(node, skips: [])
constr
end

KNOWN_PURE_METHODS = Set[
MethodName("::Array#[]"),
MethodName("::Hash#[]")
]

def pure_send?(call, receiver, arguments)
return false unless call.node.type == :send || call.node.type == :csend
return false unless call.pure? || KNOWN_PURE_METHODS.intersect?(Set.new(call.method_decls.map(&:method_name)))

[receiver, *arguments].all? do |node|
!node || value_node?(node) || context.type_env[node]
end
end

def type_send_interface(node, interface:, receiver:, receiver_type:, method_name:, arguments:, block_params:, block_body:)
method = interface.methods[method_name]

Expand All @@ -2894,10 +2890,27 @@ def type_send_interface(node, interface:, receiver:, receiver_type:, method_name
topdown_hint: true)

if call && constr
if (pure_call, type = constr.context.type_env.pure_method_calls[node])
call = pure_call.with_return_type(type)
end

case method_name.to_s
when "[]=", /\w=\Z/
if typing.has_type?(arguments.last)
call = call.with_return_type(typing.type_of(node: arguments.last))
last_arg = arguments.last or raise
if typing.has_type?(last_arg)
call = call.with_return_type(typing.type_of(node: last_arg))
end
end

if call.is_a?(TypeInference::MethodCall::Typed)
if pure_send?(call, receiver, arguments)
constr = constr.update_type_env do |env|
env.add_pure_call(node, call, call.return_type)
end
else
constr = constr.update_type_env do |env|
env.invalidate_pure_node(receiver)
end
end
end

Expand Down Expand Up @@ -3428,9 +3441,7 @@ def try_method_type(node, receiver_type:, method_name:, method_type:, arguments:
node_type_hint: method_type.type.return_type
)
block_constr = block_constr.with_new_typing(
block_constr.typing.new_child(
range: block_constr.typing.block_range(node)
)
block_constr.typing.new_child(block_constr.typing.block_range(node))
)

block_constr.typing.add_context_for_body(node, context: block_constr.context)
Expand Down Expand Up @@ -3809,18 +3820,6 @@ def synthesize_block(node:, block_type_hint:, block_body:)
end
end

def each_child_node(node)
if block_given?
node.children.each do |child|
if child.is_a?(::AST::Node)
yield child
end
end
else
enum_for :each_child_node, node
end
end

def nesting
module_context&.nesting
end
Expand Down
Loading