diff --git a/lib/steep.rb b/lib/steep.rb index d05b9a201..92e828a9e 100644 --- a/lib/steep.rb +++ b/lib/steep.rb @@ -42,6 +42,8 @@ require "steep/ast/builtin" require "steep/ast/types/factory" +require "steep/interface/function" +require "steep/interface/block" require "steep/interface/method_type" require "steep/interface/substitution" require "steep/interface/interface" diff --git a/lib/steep/ast/types/factory.rb b/lib/steep/ast/types/factory.rb index e6338bfa0..375093fb9 100644 --- a/lib/steep/ast/types/factory.rb +++ b/lib/steep/ast/types/factory.rb @@ -75,9 +75,23 @@ def type(type) end Record.new(elements: elements, location: nil) when RBS::Types::Proc - params = params(type.type) - return_type = type(type.type.return_type) - Proc.new(params: params, return_type: return_type, location: nil) + func = Interface::Function.new( + params: params(type.type), + return_type: type(type.type.return_type), + location: type.location + ) + block = if type.block + Interface::Block.new( + type: Interface::Function.new( + params: params(type.block.type), + return_type: type(type.block.type.return_type), + location: type.location + ), + optional: !type.block.required + ) + end + + Proc.new(type: func, block: block) else raise "Unexpected type given: #{type}" end @@ -145,8 +159,15 @@ def type_1(type) end RBS::Types::Record.new(fields: fields, location: nil) when Proc + block = if type.block + RBS::Types::Block.new( + type: function_1(type.block.type), + required: !type.block.optional? + ) + end RBS::Types::Proc.new( - type: function_1(type.params, type.return_type), + type: function_1(type.type), + block: block, location: nil ) when Logic::Base @@ -156,7 +177,10 @@ def type_1(type) end end - def function_1(params, return_type) + def function_1(func) + params = func.params + return_type = func.return_type + RBS::Types::Function.new( required_positionals: params.required.map {|type| RBS::Types::Function::Param.new(name: nil, type: type_1(type)) }, optional_positionals: params.optional.map {|type| RBS::Types::Function::Param.new(name: nil, type: type_1(type)) }, @@ -170,7 +194,7 @@ def function_1(params, return_type) end def params(type) - Interface::Params.new( + Interface::Function::Params.new( required: type.required_positionals.map {|param| type(param.type) }, optional: type.optional_positionals.map {|param| type(param.type) }, rest: type.rest_positionals&.yield_self {|param| type(param.type) }, @@ -202,13 +226,19 @@ def method_type(method_type, self_type:, subst2: nil, method_decls:) type = Interface::MethodType.new( type_params: type_params, - return_type: type(method_type.type.return_type).subst(subst), - params: params(method_type.type).subst(subst), + type: Interface::Function.new( + params: params(method_type.type).subst(subst), + return_type: type(method_type.type.return_type).subst(subst), + location: method_type.location + ), block: method_type.block&.yield_self do |block| Interface::Block.new( optional: !block.required, - type: Proc.new(params: params(block.type).subst(subst), - return_type: type(block.type.return_type).subst(subst), location: nil) + type: Interface::Function.new( + params: params(block.type).subst(subst), + return_type: type(block.type.return_type).subst(subst), + location: nil + ) ) end, method_decls: method_decls @@ -242,12 +272,12 @@ def method_type_1(method_type, self_type:) type = RBS::MethodType.new( type_params: type_params, - type: function_1(method_type.params.subst(subst), method_type.return_type.subst(subst)), + type: function_1(method_type.type.subst(subst)), block: method_type.block&.yield_self do |block| block_type = block.type.subst(subst) - RBS::MethodType::Block.new( - type: function_1(block_type.params, block_type.return_type), + RBS::Types::Block.new( + type: function_1(block_type), required: !block.optional ) end, @@ -354,7 +384,9 @@ def setup_primitives(method_name, method_def, method_type) when :is_a?, :kind_of?, :instance_of? if defined_in == RBS::BuiltinNames::Object.name && member.instance? return method_type.with( - return_type: AST::Types::Logic::ReceiverIsArg.new(location: method_type.return_type.location) + type: method_type.type.with( + return_type: AST::Types::Logic::ReceiverIsArg.new(location: method_type.type.return_type.location) + ) ) end @@ -363,7 +395,9 @@ def setup_primitives(method_name, method_def, method_type) when RBS::BuiltinNames::Object.name, NilClassName return method_type.with( - return_type: AST::Types::Logic::ReceiverIsNil.new(location: method_type.return_type.location) + type: method_type.type.with( + return_type: AST::Types::Logic::ReceiverIsNil.new(location: method_type.type.return_type.location) + ) ) end @@ -373,7 +407,9 @@ def setup_primitives(method_name, method_def, method_type) RBS::BuiltinNames::TrueClass.name, RBS::BuiltinNames::FalseClass.name return method_type.with( - return_type: AST::Types::Logic::Not.new(location: method_type.return_type.location) + type: method_type.type.with( + return_type: AST::Types::Logic::Not.new(location: method_type.type.return_type.location) + ) ) end @@ -381,7 +417,9 @@ def setup_primitives(method_name, method_def, method_type) case defined_in when RBS::BuiltinNames::Module.name return method_type.with( - return_type: AST::Types::Logic::ArgIsReceiver.new(location: method_type.return_type.location) + type: method_type.type.with( + return_type: AST::Types::Logic::ArgIsReceiver.new(location: method_type.type.return_type.location) + ) ) end end @@ -572,14 +610,17 @@ def interface(type, private:, self_type: type) method_types: type.types.map.with_index {|elem_type, index| Interface::MethodType.new( type_params: [], - params: Interface::Params.new(required: [AST::Types::Literal.new(value: index)], - optional: [], - rest: nil, - required_keywords: {}, - optional_keywords: {}, - rest_keywords: nil), + type: Interface::Function.new( + params: Interface::Function::Params.new(required: [AST::Types::Literal.new(value: index)], + optional: [], + rest: nil, + required_keywords: {}, + optional_keywords: {}, + rest_keywords: nil), + return_type: elem_type, + location: nil + ), block: nil, - return_type: elem_type, method_decls: Set[] ) } + aref.method_types @@ -591,14 +632,17 @@ def interface(type, private:, self_type: type) method_types: type.types.map.with_index {|elem_type, index| Interface::MethodType.new( type_params: [], - params: Interface::Params.new(required: [AST::Types::Literal.new(value: index), elem_type], - optional: [], - rest: nil, - required_keywords: {}, - optional_keywords: {}, - rest_keywords: nil), + type: Interface::Function.new( + params: Interface::Function::Params.new(required: [AST::Types::Literal.new(value: index), elem_type], + optional: [], + rest: nil, + required_keywords: {}, + optional_keywords: {}, + rest_keywords: nil), + return_type: elem_type, + location: nil + ), block: nil, - return_type: elem_type, method_decls: Set[] ) } + update.method_types @@ -610,9 +654,12 @@ def interface(type, private:, self_type: type) method_types: [ Interface::MethodType.new( type_params: [], - params: Interface::Params.empty, + type: Interface::Function.new( + params: Interface::Function::Params.empty, + return_type: type.types[0] || AST::Builtin.nil_type, + location: nil + ), block: nil, - return_type: type.types[0] || AST::Builtin.nil_type, method_decls: Set[] ) ] @@ -624,9 +671,12 @@ def interface(type, private:, self_type: type) method_types: [ Interface::MethodType.new( type_params: [], - params: Interface::Params.empty, + type: Interface::Function.new( + params: Interface::Function::Params.empty, + return_type: type.types.last || AST::Builtin.nil_type, + location: nil + ), block: nil, - return_type: type.types.last || AST::Builtin.nil_type, method_decls: Set[] ) ] @@ -651,14 +701,17 @@ def interface(type, private:, self_type: type) Interface::MethodType.new( type_params: [], - params: Interface::Params.new(required: [key_type], - optional: [], - rest: nil, - required_keywords: {}, - optional_keywords: {}, - rest_keywords: nil), + type: Interface::Function.new( + params: Interface::Function::Params.new(required: [key_type], + optional: [], + rest: nil, + required_keywords: {}, + optional_keywords: {}, + rest_keywords: nil), + return_type: value_type, + location: nil + ), block: nil, - return_type: value_type, method_decls: Set[] ) } + ref.method_types @@ -671,14 +724,16 @@ def interface(type, private:, self_type: type) key_type = Literal.new(value: key_value, location: nil) Interface::MethodType.new( type_params: [], - params: Interface::Params.new(required: [key_type, value_type], - optional: [], - rest: nil, - required_keywords: {}, - optional_keywords: {}, - rest_keywords: nil), + type: Interface::Function.new( + params: Interface::Function::Params.new(required: [key_type, value_type], + optional: [], + rest: nil, + required_keywords: {}, + optional_keywords: {}, + rest_keywords: nil), + return_type: value_type, + location: nil), block: nil, - return_type: value_type, method_decls: Set[] ) } + update.method_types @@ -691,14 +746,18 @@ def interface(type, private:, self_type: type) interface(Builtin::Proc.instance_type, private: private, self_type: self_type).tap do |interface| method_type = Interface::MethodType.new( type_params: [], - params: type.params, - return_type: type.return_type, - block: nil, + type: type.type, + block: type.block, method_decls: Set[] ) - interface.methods[:[]] = Interface::Interface::Entry.new(method_types: [method_type]) interface.methods[:call] = Interface::Interface::Entry.new(method_types: [method_type]) + + if type.block_required? + interface.methods.delete(:[]) + else + interface.methods[:[]] = Interface::Interface::Entry.new(method_types: [method_type.with(block: nil)]) + end end when Logic::Base diff --git a/lib/steep/ast/types/proc.rb b/lib/steep/ast/types/proc.rb index e69d17aee..1812dae31 100644 --- a/lib/steep/ast/types/proc.rb +++ b/lib/steep/ast/types/proc.rb @@ -3,68 +3,76 @@ module AST module Types class Proc attr_reader :location - attr_reader :params - attr_reader :return_type + attr_reader :type + attr_reader :block - def initialize(params:, return_type:, location: nil) + def initialize(type:, block:, location: type.location) + @type = type + @block = block @location = location - @params = params - @return_type = return_type end def ==(other) - other.is_a?(self.class) && - other.params == params && - other.return_type == return_type + other.is_a?(self.class) && other.type == type && other.block == block end def hash - self.class.hash && params.hash && return_type.hash + self.class.hash ^ type.hash ^ block.hash end alias eql? == def subst(s) self.class.new( - params: params.subst(s), - return_type: return_type.subst(s), + type: type.subst(s), + block: block&.subst(s), location: location ) end def to_s - "^#{params} -> #{return_type}" + if block + "^#{type.params} #{block} -> #{type.return_type}" + else + "^#{type.params} -> #{type.return_type}" + end end def free_variables() - @fvs ||= Set.new.tap do |set| - set.merge(params.free_variables) - set.merge(return_type.free_variables) + @fvs ||= Set[].tap do |fvs| + fvs.merge(type.free_variables) + fvs.merge(block.free_variables) if block end end def level - children = params.each_type.to_a + [return_type] + children = type.params.each_type.to_a + [type.return_type] + if block + children.push(*block.type.params.each_type.to_a) + children.push(block.type.return_type) + end [0] + level_of_children(children) end def closed? - params.closed? && return_type.closed? + type.closed? && (block.nil? || block.closed?) end def with_location(new_location) - self.class.new(location: new_location, params: params, return_type: return_type) + self.class.new(location: new_location, block: block, type: type) end def map_type(&block) self.class.new( - params: params.map_type(&block), - return_type: yield(return_type), + type: type.map_type(&block), + block: self.block&.map_type(&block), location: location ) end def one_arg? + params = type.params + params.required.size == 1 && params.optional.empty? && !params.rest && @@ -78,6 +86,10 @@ def back_type args: [], location: location) end + + def block_required? + block && !block.optional? + end end end end diff --git a/lib/steep/interface/block.rb b/lib/steep/interface/block.rb new file mode 100644 index 000000000..7d3b29dfa --- /dev/null +++ b/lib/steep/interface/block.rb @@ -0,0 +1,79 @@ +module Steep + module Interface + class Block + attr_reader :type + attr_reader :optional + + def initialize(type:, optional:) + @type = type + @optional = optional + end + + def optional? + @optional + end + + def to_optional + self.class.new( + type: type, + optional: true + ) + end + + def ==(other) + other.is_a?(self.class) && other.type == type && other.optional == optional + end + + alias eql? == + + def hash + type.hash ^ optional.hash + end + + def closed? + type.closed? + end + + def subst(s) + ty = type.subst(s) + if ty == type + self + else + self.class.new( + type: ty, + optional: optional + ) + end + end + + def free_variables() + @fvs ||= type.free_variables + end + + def to_s + "#{optional? ? "?" : ""}{ #{type.params} -> #{type.return_type} }" + end + + def map_type(&block) + self.class.new( + type: type.map_type(&block), + optional: optional + ) + end + + def +(other) + optional = self.optional? || other.optional? + type = Function.new( + params: self.type.params + other.type.params, + return_type: AST::Types::Union.build(types: [self.type.return_type, other.type.return_type]), + location: nil + ) + + self.class.new( + type: type, + optional: optional + ) + end + end + end +end diff --git a/lib/steep/interface/function.rb b/lib/steep/interface/function.rb new file mode 100644 index 000000000..5b3b0686b --- /dev/null +++ b/lib/steep/interface/function.rb @@ -0,0 +1,770 @@ +module Steep + module Interface + class Function + class Params + attr_reader :required + attr_reader :optional + attr_reader :rest + attr_reader :required_keywords + attr_reader :optional_keywords + attr_reader :rest_keywords + + def initialize(required:, optional:, rest:, required_keywords:, optional_keywords:, rest_keywords:) + @required = required + @optional = optional + @rest = rest + @required_keywords = required_keywords + @optional_keywords = optional_keywords + @rest_keywords = rest_keywords + end + + def update(required: self.required, optional: self.optional, rest: self.rest, required_keywords: self.required_keywords, optional_keywords: self.optional_keywords, rest_keywords: self.rest_keywords) + self.class.new( + required: required, + optional: optional, + rest: rest, + required_keywords: required_keywords, + optional_keywords: optional_keywords, + rest_keywords: rest_keywords, + ) + end + + RequiredPositional = Struct.new(:type) + OptionalPositional = Struct.new(:type) + RestPositional = Struct.new(:type) + + def first_param + case + when !required.empty? + RequiredPositional.new(required[0]) + when !optional.empty? + OptionalPositional.new(optional[0]) + when rest + RestPositional.new(rest) + else + nil + end + end + + def with_first_param(param) + case param + when RequiredPositional + update(required: [param.type] + required) + when OptionalPositional + update(optional: [param.type] + required) + when RestPositional + update(rest: param.type) + else + self + end + end + + def has_positional? + first_param + end + + def self.empty + self.new( + required: [], + optional: [], + rest: nil, + required_keywords: {}, + optional_keywords: {}, + rest_keywords: nil + ) + end + + def ==(other) + other.is_a?(self.class) && + other.required == required && + other.optional == optional && + other.rest == rest && + other.required_keywords == required_keywords && + other.optional_keywords == optional_keywords && + other.rest_keywords == rest_keywords + end + + alias eql? == + + def hash + required.hash ^ optional.hash ^ rest.hash ^ required_keywords.hash ^ optional_keywords.hash ^ rest_keywords.hash + end + + def flat_unnamed_params + required.map {|p| [:required, p] } + optional.map {|p| [:optional, p] } + end + + def flat_keywords + required_keywords.merge optional_keywords + end + + def has_keywords? + !required_keywords.empty? || !optional_keywords.empty? || rest_keywords + end + + def without_keywords + self.class.new( + required: required, + optional: optional, + rest: rest, + required_keywords: {}, + optional_keywords: {}, + rest_keywords: nil + ) + end + + def drop_first + case + when required.any? || optional.any? || rest + self.class.new( + required: required.any? ? required.drop(1) : [], + optional: required.empty? && optional.any? ? optional.drop(1) : optional, + rest: required.empty? && optional.empty? ? nil : rest, + required_keywords: required_keywords, + optional_keywords: optional_keywords, + rest_keywords: rest_keywords + ) + when has_keywords? + without_keywords + else + raise "Cannot drop from empty params" + end + end + + def each_missing_argument(args) + required.size.times do |index| + if index >= args.size + yield index + end + end + end + + def each_extra_argument(args) + return if rest + + if has_keywords? + args = args.take(args.count - 1) if args.count > 0 + end + + args.size.times do |index| + if index >= required.count + optional.count + yield index + end + end + end + + def each_missing_keyword(args) + return unless has_keywords? + + keywords, rest = extract_keywords(args) + + return unless rest.empty? + + required_keywords.each do |keyword, _| + yield keyword unless keywords.key?(keyword) + end + end + + def each_extra_keyword(args) + return unless has_keywords? + return if rest_keywords + + keywords, rest = extract_keywords(args) + + return unless rest.empty? + + all_keywords = flat_keywords + keywords.each do |keyword, _| + yield keyword unless all_keywords.key?(keyword) + end + end + + def extract_keywords(args) + last_arg = args.last + + keywords = {} + rest = [] + + if last_arg&.type == :hash + last_arg.children.each do |element| + case element.type + when :pair + if element.children[0].type == :sym + name = element.children[0].children[0] + keywords[name] = element.children[1] + end + when :kwsplat + rest << element.children[0] + end + end + end + + [keywords, rest] + end + + def each_type() + if block_given? + flat_unnamed_params.each do |(_, type)| + yield type + end + flat_keywords.each do |_, type| + yield type + end + rest and yield rest + rest_keywords and yield rest_keywords + else + enum_for :each_type + end + end + + def free_variables() + @fvs ||= Set.new.tap do |set| + each_type do |type| + set.merge(type.free_variables) + end + end + end + + def closed? + required.all?(&:closed?) && optional.all?(&:closed?) && (!rest || rest.closed?) && required_keywords.values.all?(&:closed?) && optional_keywords.values.all?(&:closed?) && (!rest_keywords || rest_keywords.closed?) + end + + def subst(s) + return self if s.empty? + return self if empty? + return self if free_variables.disjoint?(s.domain) + + rs = required.map {|t| t.subst(s) } + os = optional.map {|t| t.subst(s) } + r = rest&.subst(s) + rk = required_keywords.transform_values {|t| t.subst(s) } + ok = optional_keywords.transform_values {|t| t.subst(s) } + k = rest_keywords&.subst(s) + + if rs == required && os == optional && r == rest && rk == required_keywords && ok == optional_keywords && k == rest_keywords + self + else + self.class.new( + required: required.map {|t| t.subst(s) }, + optional: optional.map {|t| t.subst(s) }, + rest: rest&.subst(s), + required_keywords: required_keywords.transform_values {|t| t.subst(s) }, + optional_keywords: optional_keywords.transform_values {|t| t.subst(s) }, + rest_keywords: rest_keywords&.subst(s) + ) + end + end + + def size + required.size + optional.size + (rest ? 1 : 0) + required_keywords.size + optional_keywords.size + (rest_keywords ? 1 : 0) + end + + def to_s + required = self.required.map {|ty| ty.to_s } + optional = self.optional.map {|ty| "?#{ty}" } + rest = self.rest ? ["*#{self.rest}"] : [] + required_keywords = self.required_keywords.map {|name, type| "#{name}: #{type}" } + optional_keywords = self.optional_keywords.map {|name, type| "?#{name}: #{type}"} + rest_keywords = self.rest_keywords ? ["**#{self.rest_keywords}"] : [] + "(#{(required + optional + rest + required_keywords + optional_keywords + rest_keywords).join(", ")})" + end + + def map_type(&block) + self.class.new( + required: required.map(&block), + optional: optional.map(&block), + rest: rest && yield(rest), + required_keywords: required_keywords.transform_values(&block), + optional_keywords: optional_keywords.transform_values(&block), + rest_keywords: rest_keywords && yield(rest_keywords) + ) + end + + def empty? + !has_positional? && !has_keywords? + end + + # self + params returns a new params for overloading. + # + def +(other) + a = first_param + b = other.first_param + + case + when a.is_a?(RequiredPositional) && b.is_a?(RequiredPositional) + AST::Types::Union.build(types: [a.type, b.type]).yield_self do |type| + (self.drop_first + other.drop_first).with_first_param(RequiredPositional.new(type)) + end + when a.is_a?(RequiredPositional) && b.is_a?(OptionalPositional) + AST::Types::Union.build(types: [a.type, b.type]).yield_self do |type| + (self.drop_first + other.drop_first).with_first_param(OptionalPositional.new(type)) + end + when a.is_a?(RequiredPositional) && b.is_a?(RestPositional) + AST::Types::Union.build(types: [a.type, b.type]).yield_self do |type| + (self.drop_first + other).with_first_param(OptionalPositional.new(type)) + end + when a.is_a?(RequiredPositional) && b.nil? + (self.drop_first + other).with_first_param(OptionalPositional.new(a.type)) + when a.is_a?(OptionalPositional) && b.is_a?(RequiredPositional) + AST::Types::Union.build(types: [a.type, b.type]).yield_self do |type| + (self.drop_first + other.drop_first).with_first_param(OptionalPositional.new(type)) + end + when a.is_a?(OptionalPositional) && b.is_a?(OptionalPositional) + AST::Types::Union.build(types: [a.type, b.type]).yield_self do |type| + (self.drop_first + other.drop_first).with_first_param(OptionalPositional.new(type)) + end + when a.is_a?(OptionalPositional) && b.is_a?(RestPositional) + AST::Types::Union.build(types: [a.type, b.type]).yield_self do |type| + (self.drop_first + other).with_first_param(OptionalPositional.new(type)) + end + when a.is_a?(OptionalPositional) && b.nil? + (self.drop_first + other).with_first_param(OptionalPositional.new(a.type)) + when a.is_a?(RestPositional) && b.is_a?(RequiredPositional) + AST::Types::Union.build(types: [a.type, b.type]).yield_self do |type| + (self + other.drop_first).with_first_param(OptionalPositional.new(type)) + end + when a.is_a?(RestPositional) && b.is_a?(OptionalPositional) + AST::Types::Union.build(types: [a.type, b.type]).yield_self do |type| + (self + other.drop_first).with_first_param(OptionalPositional.new(type)) + end + when a.is_a?(RestPositional) && b.is_a?(RestPositional) + AST::Types::Union.build(types: [a.type, b.type]).yield_self do |type| + (self.drop_first + other.drop_first).with_first_param(RestPositional.new(type)) + end + when a.is_a?(RestPositional) && b.nil? + (self.drop_first + other).with_first_param(RestPositional.new(a.type)) + when a.nil? && b.is_a?(RequiredPositional) + (self + other.drop_first).with_first_param(OptionalPositional.new(b.type)) + when a.nil? && b.is_a?(OptionalPositional) + (self + other.drop_first).with_first_param(OptionalPositional.new(b.type)) + when a.nil? && b.is_a?(RestPositional) + (self + other.drop_first).with_first_param(RestPositional.new(b.type)) + when a.nil? && b.nil? + required_keywords = {} + + (Set.new(self.required_keywords.keys) & Set.new(other.required_keywords.keys)).each do |keyword| + required_keywords[keyword] = AST::Types::Union.build( + types: [ + self.required_keywords[keyword], + other.required_keywords[keyword] + ] + ) + end + + optional_keywords = {} + self.required_keywords.each do |keyword, t| + unless required_keywords.key?(keyword) + case + when other.optional_keywords.key?(keyword) + optional_keywords[keyword] = AST::Types::Union.build(types: [t, other.optional_keywords[keyword]]) + when other.rest_keywords + optional_keywords[keyword] = AST::Types::Union.build(types: [t, other.rest_keywords]) + else + optional_keywords[keyword] = t + end + end + end + other.required_keywords.each do |keyword, t| + unless required_keywords.key?(keyword) + case + when self.optional_keywords.key?(keyword) + optional_keywords[keyword] = AST::Types::Union.build(types: [t, self.optional_keywords[keyword]]) + when self.rest_keywords + optional_keywords[keyword] = AST::Types::Union.build(types: [t, self.rest_keywords]) + else + optional_keywords[keyword] = t + end + end + end + self.optional_keywords.each do |keyword, t| + unless optional_keywords.key?(keyword) + case + when other.optional_keywords.key?(keyword) + optional_keywords[keyword] = AST::Types::Union.build(types: [t, other.optional_keywords[keyword]]) + when other.rest_keywords + optional_keywords[keyword] = AST::Types::Union.build(types: [t, other.rest_keywords]) + else + optional_keywords[keyword] = t + end + end + end + other.optional_keywords.each do |keyword, t| + unless optional_keywords.key?(keyword) + case + when self.optional_keywords.key?(keyword) + optional_keywords[keyword] = AST::Types::Union.build(types: [t, self.optional_keywords[keyword]]) + when self.rest_keywords + optional_keywords[keyword] = AST::Types::Union.build(types: [t, self.rest_keywords]) + else + optional_keywords[keyword] = t + end + end + end + + rest = case + when self.rest_keywords && other.rest_keywords + AST::Types::Union.build(types: [self.rest_keywords, other.rest_keywords]) + else + self.rest_keywords || other.rest_keywords + end + + Params.new( + required: [], + optional: [], + rest: nil, + required_keywords: required_keywords, + optional_keywords: optional_keywords, + rest_keywords: rest) + end + end + + # Returns the intersection between self and other. + # Returns nil if the intersection cannot be computed. + # + def &(other) + a = first_param + b = other.first_param + + case + when a.is_a?(RequiredPositional) && b.is_a?(RequiredPositional) + AST::Types::Intersection.build(types: [a.type, b.type]).yield_self do |type| + (self.drop_first & other.drop_first)&.with_first_param(RequiredPositional.new(type)) + end + when a.is_a?(RequiredPositional) && b.is_a?(OptionalPositional) + AST::Types::Intersection.build(types: [a.type, b.type]).yield_self do |type| + (self.drop_first & other.drop_first)&.with_first_param(RequiredPositional.new(type)) + end + when a.is_a?(RequiredPositional) && b.is_a?(RestPositional) + AST::Types::Intersection.build(types: [a.type, b.type]).yield_self do |type| + (self.drop_first & other)&.with_first_param(RequiredPositional.new(type)) + end + when a.is_a?(RequiredPositional) && b.nil? + nil + when a.is_a?(OptionalPositional) && b.is_a?(RequiredPositional) + AST::Types::Intersection.build(types: [a.type, b.type]).yield_self do |type| + (self.drop_first & other.drop_first)&.with_first_param(RequiredPositional.new(type)) + end + when a.is_a?(OptionalPositional) && b.is_a?(OptionalPositional) + AST::Types::Intersection.build(types: [a.type, b.type]).yield_self do |type| + (self.drop_first & other.drop_first)&.with_first_param(OptionalPositional.new(type)) + end + when a.is_a?(OptionalPositional) && b.is_a?(RestPositional) + AST::Types::Intersection.build(types: [a.type, b.type]).yield_self do |type| + (self.drop_first & other)&.with_first_param(OptionalPositional.new(type)) + end + when a.is_a?(OptionalPositional) && b.nil? + self.drop_first & other + when a.is_a?(RestPositional) && b.is_a?(RequiredPositional) + AST::Types::Intersection.build(types: [a.type, b.type]).yield_self do |type| + (self & other.drop_first)&.with_first_param(RequiredPositional.new(type)) + end + when a.is_a?(RestPositional) && b.is_a?(OptionalPositional) + AST::Types::Intersection.build(types: [a.type, b.type]).yield_self do |type| + (self & other.drop_first)&.with_first_param(OptionalPositional.new(type)) + end + when a.is_a?(RestPositional) && b.is_a?(RestPositional) + AST::Types::Intersection.build(types: [a.type, b.type]).yield_self do |type| + (self.drop_first & other.drop_first)&.with_first_param(RestPositional.new(type)) + end + when a.is_a?(RestPositional) && b.nil? + self.drop_first & other + when a.nil? && b.is_a?(RequiredPositional) + nil + when a.nil? && b.is_a?(OptionalPositional) + self & other.drop_first + when a.nil? && b.is_a?(RestPositional) + self & other.drop_first + when a.nil? && b.nil? + optional_keywords = {} + + (Set.new(self.optional_keywords.keys) & Set.new(other.optional_keywords.keys)).each do |keyword| + optional_keywords[keyword] = AST::Types::Intersection.build( + types: [ + self.optional_keywords[keyword], + other.optional_keywords[keyword] + ] + ) + end + + required_keywords = {} + self.optional_keywords.each do |keyword, t| + unless optional_keywords.key?(keyword) + case + when other.required_keywords.key?(keyword) + required_keywords[keyword] = AST::Types::Intersection.build(types: [t, other.required_keywords[keyword]]) + when other.rest_keywords + optional_keywords[keyword] = AST::Types::Intersection.build(types: [t, other.rest_keywords]) + end + end + end + other.optional_keywords.each do |keyword, t| + unless optional_keywords.key?(keyword) + case + when self.required_keywords.key?(keyword) + required_keywords[keyword] = AST::Types::Intersection.build(types: [t, self.required_keywords[keyword]]) + when self.rest_keywords + optional_keywords[keyword] = AST::Types::Intersection.build(types: [t, self.rest_keywords]) + end + end + end + self.required_keywords.each do |keyword, t| + unless required_keywords.key?(keyword) + case + when other.required_keywords.key?(keyword) + required_keywords[keyword] = AST::Types::Intersection.build(types: [t, other.required_keywords[keyword]]) + when other.rest_keywords + required_keywords[keyword] = AST::Types::Intersection.build(types: [t, other.rest_keywords]) + else + return + end + end + end + other.required_keywords.each do |keyword, t| + unless required_keywords.key?(keyword) + case + when self.required_keywords.key?(keyword) + required_keywords[keyword] = AST::Types::Intersection.build(types: [t, self.required_keywords[keyword]]) + when self.rest_keywords + required_keywords[keyword] = AST::Types::Intersection.build(types: [t, self.rest_keywords]) + else + return + end + end + end + + rest = case + when self.rest_keywords && other.rest_keywords + AST::Types::Intersection.build(types: [self.rest_keywords, other.rest_keywords]) + else + nil + end + + Params.new( + required: [], + optional: [], + rest: nil, + required_keywords: required_keywords, + optional_keywords: optional_keywords, + rest_keywords: rest) + end + end + + # Returns the union between self and other. + # + def |(other) + a = first_param + b = other.first_param + + case + when a.is_a?(RequiredPositional) && b.is_a?(RequiredPositional) + AST::Types::Union.build(types: [a.type, b.type]).yield_self do |type| + (self.drop_first | other.drop_first)&.with_first_param(RequiredPositional.new(type)) + end + when a.is_a?(RequiredPositional) && b.is_a?(OptionalPositional) + AST::Types::Union.build(types: [a.type, b.type]).yield_self do |type| + (self.drop_first | other.drop_first)&.with_first_param(OptionalPositional.new(type)) + end + when a.is_a?(RequiredPositional) && b.is_a?(RestPositional) + AST::Types::Union.build(types: [a.type, b.type]).yield_self do |type| + (self.drop_first | other.drop_first)&.with_first_param(OptionalPositional.new(type)) + end + when a.is_a?(RequiredPositional) && b.nil? + self.drop_first&.with_first_param(OptionalPositional.new(a.type)) + when a.is_a?(OptionalPositional) && b.is_a?(RequiredPositional) + AST::Types::Union.build(types: [a.type, b.type]).yield_self do |type| + (self.drop_first | other.drop_first)&.with_first_param(OptionalPositional.new(type)) + end + when a.is_a?(OptionalPositional) && b.is_a?(OptionalPositional) + AST::Types::Union.build(types: [a.type, b.type]).yield_self do |type| + (self.drop_first | other.drop_first)&.with_first_param(OptionalPositional.new(type)) + end + when a.is_a?(OptionalPositional) && b.is_a?(RestPositional) + AST::Types::Union.build(types: [a.type, b.type]).yield_self do |type| + (self.drop_first | other.drop_first)&.with_first_param(OptionalPositional.new(type)) + end + when a.is_a?(OptionalPositional) && b.nil? + (self.drop_first | other)&.with_first_param(a) + when a.is_a?(RestPositional) && b.is_a?(RequiredPositional) + AST::Types::Union.build(types: [a.type, b.type]).yield_self do |type| + (self.drop_first | other.drop_first)&.with_first_param(OptionalPositional.new(type)) + end + when a.is_a?(RestPositional) && b.is_a?(OptionalPositional) + AST::Types::Union.build(types: [a.type, b.type]).yield_self do |type| + (self | other.drop_first)&.with_first_param(OptionalPositional.new(type)) + end + when a.is_a?(RestPositional) && b.is_a?(RestPositional) + AST::Types::Union.build(types: [a.type, b.type]).yield_self do |type| + (self.drop_first | other.drop_first)&.with_first_param(RestPositional.new(type)) + end + when a.is_a?(RestPositional) && b.nil? + (self.drop_first | other)&.with_first_param(a) + when a.nil? && b.is_a?(RequiredPositional) + other.drop_first&.with_first_param(OptionalPositional.new(b.type)) + when a.nil? && b.is_a?(OptionalPositional) + (self | other.drop_first)&.with_first_param(b) + when a.nil? && b.is_a?(RestPositional) + (self | other.drop_first)&.with_first_param(b) + when a.nil? && b.nil? + required_keywords = {} + optional_keywords = {} + + (Set.new(self.required_keywords.keys) & Set.new(other.required_keywords.keys)).each do |keyword| + required_keywords[keyword] = AST::Types::Union.build( + types: [ + self.required_keywords[keyword], + other.required_keywords[keyword] + ] + ) + end + + self.optional_keywords.each do |keyword, t| + unless optional_keywords.key?(keyword) || required_keywords.key?(keyword) + case + when s = other.required_keywords[keyword] + optional_keywords[keyword] = AST::Types::Union.build(types: [t, s]) + when s = other.optional_keywords[keyword] + optional_keywords[keyword] = AST::Types::Union.build(types: [t, s]) + when r = other.rest_keywords + optional_keywords[keyword] = AST::Types::Union.build(types: [t, r]) + else + optional_keywords[keyword] = t + end + end + end + other.optional_keywords.each do |keyword, t| + unless optional_keywords.key?(keyword) || required_keywords.key?(keyword) + case + when s = self.required_keywords[keyword] + optional_keywords[keyword] = AST::Types::Union.build(types: [t, s]) + when s = self.optional_keywords[keyword] + optional_keywords[keyword] = AST::Types::Union.build(types: [t, s]) + when r = self.rest_keywords + optional_keywords[keyword] = AST::Types::Union.build(types: [t, r]) + else + optional_keywords[keyword] = t + end + end + end + self.required_keywords.each do |keyword, t| + unless optional_keywords.key?(keyword) || required_keywords.key?(keyword) + case + when s = other.optional_keywords[keyword] + optional_keywords[keyword] = AST::Types::Union.build(types: [t, s]) + when r = other.rest_keywords + optional_keywords[keyword] = AST::Types::Union.build(types: [t, r]) + else + optional_keywords[keyword] = t + end + end + end + other.required_keywords.each do |keyword, t| + unless optional_keywords.key?(keyword) || required_keywords.key?(keyword) + case + when s = self.optional_keywords[keyword] + optional_keywords[keyword] = AST::Types::Union.build(types: [t, s]) + when r = self.rest_keywords + optional_keywords[keyword] = AST::Types::Union.build(types: [t, r]) + else + optional_keywords[keyword] = t + end + end + end + + rest = case + when self.rest_keywords && other.rest_keywords + AST::Types::Union.build(types: [self.rest_keywords, other.rest_keywords]) + when self.rest_keywords + if required_keywords.empty? && optional_keywords.empty? + self.rest_keywords + end + when other.rest_keywords + if required_keywords.empty? && optional_keywords.empty? + other.rest_keywords + end + else + nil + end + + Params.new( + required: [], + optional: [], + rest: nil, + required_keywords: required_keywords, + optional_keywords: optional_keywords, + rest_keywords: rest) + end + end + end + + attr_reader :params + attr_reader :return_type + attr_reader :location + + def initialize(params:, return_type:, location:) + @params = params + @return_type = return_type + @location = location + end + + def ==(other) + other.is_a?(Function) && other.params == params && other.return_type == return_type + end + + alias eql? == + + def hash + self.class.hash ^ params.hash ^ return_type.hash + end + + def free_variables + @fvs ||= Set[].tap do |fvs| + fvs.merge(params.free_variables) + fvs.merge(return_type.free_variables) + end + end + + def subst(s) + return self if s.empty? + + Function.new( + params: params.subst(s), + return_type: return_type.subst(s), + location: location + ) + end + + def each_type(&block) + if block_given? + params.each_type(&block) + yield return_type + else + enum_for :each_type + end + end + + def map_type(&block) + Function.new( + params: params.map_type(&block), + return_type: yield(return_type), + location: location + ) + end + + def with(params: self.params, return_type: self.return_type) + Function.new( + params: params, + return_type: return_type, + location: location + ) + end + + def to_s + "#{params} -> #{return_type}" + end + + def closed? + params.closed? && return_type.closed? + end + end + end +end diff --git a/lib/steep/interface/method_type.rb b/lib/steep/interface/method_type.rb index acd6e6868..d0faf21a2 100644 --- a/lib/steep/interface/method_type.rb +++ b/lib/steep/interface/method_type.rb @@ -1,810 +1,37 @@ module Steep module Interface - class Params - attr_reader :required - attr_reader :optional - attr_reader :rest - attr_reader :required_keywords - attr_reader :optional_keywords - attr_reader :rest_keywords - - def initialize(required:, optional:, rest:, required_keywords:, optional_keywords:, rest_keywords:) - @required = required - @optional = optional - @rest = rest - @required_keywords = required_keywords - @optional_keywords = optional_keywords - @rest_keywords = rest_keywords - end - - def update(required: self.required, optional: self.optional, rest: self.rest, required_keywords: self.required_keywords, optional_keywords: self.optional_keywords, rest_keywords: self.rest_keywords) - self.class.new( - required: required, - optional: optional, - rest: rest, - required_keywords: required_keywords, - optional_keywords: optional_keywords, - rest_keywords: rest_keywords, - ) - end - - RequiredPositional = Struct.new(:type) - OptionalPositional = Struct.new(:type) - RestPositional = Struct.new(:type) - - def first_param - case - when !required.empty? - RequiredPositional.new(required[0]) - when !optional.empty? - OptionalPositional.new(optional[0]) - when rest - RestPositional.new(rest) - else - nil - end - end - - def with_first_param(param) - case param - when RequiredPositional - update(required: [param.type] + required) - when OptionalPositional - update(optional: [param.type] + required) - when RestPositional - update(rest: param.type) - else - self - end - end - - def has_positional? - first_param - end - - def self.empty - self.new( - required: [], - optional: [], - rest: nil, - required_keywords: {}, - optional_keywords: {}, - rest_keywords: nil - ) - end - - def ==(other) - other.is_a?(self.class) && - other.required == required && - other.optional == optional && - other.rest == rest && - other.required_keywords == required_keywords && - other.optional_keywords == optional_keywords && - other.rest_keywords == rest_keywords - end - - alias eql? == - - def hash - required.hash ^ optional.hash ^ rest.hash ^ required_keywords.hash ^ optional_keywords.hash ^ rest_keywords.hash - end - - def flat_unnamed_params - required.map {|p| [:required, p] } + optional.map {|p| [:optional, p] } - end - - def flat_keywords - required_keywords.merge optional_keywords - end - - def has_keywords? - !required_keywords.empty? || !optional_keywords.empty? || rest_keywords - end - - def without_keywords - self.class.new( - required: required, - optional: optional, - rest: rest, - required_keywords: {}, - optional_keywords: {}, - rest_keywords: nil - ) - end - - def drop_first - case - when required.any? || optional.any? || rest - self.class.new( - required: required.any? ? required.drop(1) : [], - optional: required.empty? && optional.any? ? optional.drop(1) : optional, - rest: required.empty? && optional.empty? ? nil : rest, - required_keywords: required_keywords, - optional_keywords: optional_keywords, - rest_keywords: rest_keywords - ) - when has_keywords? - without_keywords - else - raise "Cannot drop from empty params" - end - end - - def each_missing_argument(args) - required.size.times do |index| - if index >= args.size - yield index - end - end - end - - def each_extra_argument(args) - return if rest - - if has_keywords? - args = args.take(args.count - 1) if args.count > 0 - end - - args.size.times do |index| - if index >= required.count + optional.count - yield index - end - end - end - - def each_missing_keyword(args) - return unless has_keywords? - - keywords, rest = extract_keywords(args) - - return unless rest.empty? - - required_keywords.each do |keyword, _| - yield keyword unless keywords.key?(keyword) - end - end - - def each_extra_keyword(args) - return unless has_keywords? - return if rest_keywords - - keywords, rest = extract_keywords(args) - - return unless rest.empty? - - all_keywords = flat_keywords - keywords.each do |keyword, _| - yield keyword unless all_keywords.key?(keyword) - end - end - - def extract_keywords(args) - last_arg = args.last - - keywords = {} - rest = [] - - if last_arg&.type == :hash - last_arg.children.each do |element| - case element.type - when :pair - if element.children[0].type == :sym - name = element.children[0].children[0] - keywords[name] = element.children[1] - end - when :kwsplat - rest << element.children[0] - end - end - end - - [keywords, rest] - end - - def each_type() - if block_given? - flat_unnamed_params.each do |(_, type)| - yield type - end - flat_keywords.each do |_, type| - yield type - end - rest and yield rest - rest_keywords and yield rest_keywords - else - enum_for :each_type - end - end - - def free_variables() - @fvs ||= Set.new.tap do |set| - each_type do |type| - set.merge(type.free_variables) - end - end - end - - def closed? - required.all?(&:closed?) && optional.all?(&:closed?) && (!rest || rest.closed?) && required_keywords.values.all?(&:closed?) && optional_keywords.values.all?(&:closed?) && (!rest_keywords || rest_keywords.closed?) - end - - def subst(s) - return self if s.empty? - return self if empty? - return self if free_variables.disjoint?(s.domain) - - rs = required.map {|t| t.subst(s) } - os = optional.map {|t| t.subst(s) } - r = rest&.subst(s) - rk = required_keywords.transform_values {|t| t.subst(s) } - ok = optional_keywords.transform_values {|t| t.subst(s) } - k = rest_keywords&.subst(s) - - if rs == required && os == optional && r == rest && rk == required_keywords && ok == optional_keywords && k == rest_keywords - self - else - self.class.new( - required: required.map {|t| t.subst(s) }, - optional: optional.map {|t| t.subst(s) }, - rest: rest&.subst(s), - required_keywords: required_keywords.transform_values {|t| t.subst(s) }, - optional_keywords: optional_keywords.transform_values {|t| t.subst(s) }, - rest_keywords: rest_keywords&.subst(s) - ) - end - end - - def size - required.size + optional.size + (rest ? 1 : 0) + required_keywords.size + optional_keywords.size + (rest_keywords ? 1 : 0) - end - - def to_s - required = self.required.map {|ty| ty.to_s } - optional = self.optional.map {|ty| "?#{ty}" } - rest = self.rest ? ["*#{self.rest}"] : [] - required_keywords = self.required_keywords.map {|name, type| "#{name}: #{type}" } - optional_keywords = self.optional_keywords.map {|name, type| "?#{name}: #{type}"} - rest_keywords = self.rest_keywords ? ["**#{self.rest_keywords}"] : [] - "(#{(required + optional + rest + required_keywords + optional_keywords + rest_keywords).join(", ")})" - end - - def map_type(&block) - self.class.new( - required: required.map(&block), - optional: optional.map(&block), - rest: rest && yield(rest), - required_keywords: required_keywords.transform_values(&block), - optional_keywords: optional_keywords.transform_values(&block), - rest_keywords: rest_keywords && yield(rest_keywords) - ) - end - - def empty? - !has_positional? && !has_keywords? - end - - # self + params returns a new params for overloading. - # - def +(other) - a = first_param - b = other.first_param - - case - when a.is_a?(RequiredPositional) && b.is_a?(RequiredPositional) - AST::Types::Union.build(types: [a.type, b.type]).yield_self do |type| - (self.drop_first + other.drop_first).with_first_param(RequiredPositional.new(type)) - end - when a.is_a?(RequiredPositional) && b.is_a?(OptionalPositional) - AST::Types::Union.build(types: [a.type, b.type]).yield_self do |type| - (self.drop_first + other.drop_first).with_first_param(OptionalPositional.new(type)) - end - when a.is_a?(RequiredPositional) && b.is_a?(RestPositional) - AST::Types::Union.build(types: [a.type, b.type]).yield_self do |type| - (self.drop_first + other).with_first_param(OptionalPositional.new(type)) - end - when a.is_a?(RequiredPositional) && b.nil? - (self.drop_first + other).with_first_param(OptionalPositional.new(a.type)) - when a.is_a?(OptionalPositional) && b.is_a?(RequiredPositional) - AST::Types::Union.build(types: [a.type, b.type]).yield_self do |type| - (self.drop_first + other.drop_first).with_first_param(OptionalPositional.new(type)) - end - when a.is_a?(OptionalPositional) && b.is_a?(OptionalPositional) - AST::Types::Union.build(types: [a.type, b.type]).yield_self do |type| - (self.drop_first + other.drop_first).with_first_param(OptionalPositional.new(type)) - end - when a.is_a?(OptionalPositional) && b.is_a?(RestPositional) - AST::Types::Union.build(types: [a.type, b.type]).yield_self do |type| - (self.drop_first + other).with_first_param(OptionalPositional.new(type)) - end - when a.is_a?(OptionalPositional) && b.nil? - (self.drop_first + other).with_first_param(OptionalPositional.new(a.type)) - when a.is_a?(RestPositional) && b.is_a?(RequiredPositional) - AST::Types::Union.build(types: [a.type, b.type]).yield_self do |type| - (self + other.drop_first).with_first_param(OptionalPositional.new(type)) - end - when a.is_a?(RestPositional) && b.is_a?(OptionalPositional) - AST::Types::Union.build(types: [a.type, b.type]).yield_self do |type| - (self + other.drop_first).with_first_param(OptionalPositional.new(type)) - end - when a.is_a?(RestPositional) && b.is_a?(RestPositional) - AST::Types::Union.build(types: [a.type, b.type]).yield_self do |type| - (self.drop_first + other.drop_first).with_first_param(RestPositional.new(type)) - end - when a.is_a?(RestPositional) && b.nil? - (self.drop_first + other).with_first_param(RestPositional.new(a.type)) - when a.nil? && b.is_a?(RequiredPositional) - (self + other.drop_first).with_first_param(OptionalPositional.new(b.type)) - when a.nil? && b.is_a?(OptionalPositional) - (self + other.drop_first).with_first_param(OptionalPositional.new(b.type)) - when a.nil? && b.is_a?(RestPositional) - (self + other.drop_first).with_first_param(RestPositional.new(b.type)) - when a.nil? && b.nil? - required_keywords = {} - - (Set.new(self.required_keywords.keys) & Set.new(other.required_keywords.keys)).each do |keyword| - required_keywords[keyword] = AST::Types::Union.build( - types: [ - self.required_keywords[keyword], - other.required_keywords[keyword] - ] - ) - end - - optional_keywords = {} - self.required_keywords.each do |keyword, t| - unless required_keywords.key?(keyword) - case - when other.optional_keywords.key?(keyword) - optional_keywords[keyword] = AST::Types::Union.build(types: [t, other.optional_keywords[keyword]]) - when other.rest_keywords - optional_keywords[keyword] = AST::Types::Union.build(types: [t, other.rest_keywords]) - else - optional_keywords[keyword] = t - end - end - end - other.required_keywords.each do |keyword, t| - unless required_keywords.key?(keyword) - case - when self.optional_keywords.key?(keyword) - optional_keywords[keyword] = AST::Types::Union.build(types: [t, self.optional_keywords[keyword]]) - when self.rest_keywords - optional_keywords[keyword] = AST::Types::Union.build(types: [t, self.rest_keywords]) - else - optional_keywords[keyword] = t - end - end - end - self.optional_keywords.each do |keyword, t| - unless optional_keywords.key?(keyword) - case - when other.optional_keywords.key?(keyword) - optional_keywords[keyword] = AST::Types::Union.build(types: [t, other.optional_keywords[keyword]]) - when other.rest_keywords - optional_keywords[keyword] = AST::Types::Union.build(types: [t, other.rest_keywords]) - else - optional_keywords[keyword] = t - end - end - end - other.optional_keywords.each do |keyword, t| - unless optional_keywords.key?(keyword) - case - when self.optional_keywords.key?(keyword) - optional_keywords[keyword] = AST::Types::Union.build(types: [t, self.optional_keywords[keyword]]) - when self.rest_keywords - optional_keywords[keyword] = AST::Types::Union.build(types: [t, self.rest_keywords]) - else - optional_keywords[keyword] = t - end - end - end - - rest = case - when self.rest_keywords && other.rest_keywords - AST::Types::Union.build(types: [self.rest_keywords, other.rest_keywords]) - else - self.rest_keywords || other.rest_keywords - end - - Params.new( - required: [], - optional: [], - rest: nil, - required_keywords: required_keywords, - optional_keywords: optional_keywords, - rest_keywords: rest) - end - end - - # Returns the intersection between self and other. - # Returns nil if the intersection cannot be computed. - # - def &(other) - a = first_param - b = other.first_param - - case - when a.is_a?(RequiredPositional) && b.is_a?(RequiredPositional) - AST::Types::Intersection.build(types: [a.type, b.type]).yield_self do |type| - (self.drop_first & other.drop_first)&.with_first_param(RequiredPositional.new(type)) - end - when a.is_a?(RequiredPositional) && b.is_a?(OptionalPositional) - AST::Types::Intersection.build(types: [a.type, b.type]).yield_self do |type| - (self.drop_first & other.drop_first)&.with_first_param(RequiredPositional.new(type)) - end - when a.is_a?(RequiredPositional) && b.is_a?(RestPositional) - AST::Types::Intersection.build(types: [a.type, b.type]).yield_self do |type| - (self.drop_first & other)&.with_first_param(RequiredPositional.new(type)) - end - when a.is_a?(RequiredPositional) && b.nil? - nil - when a.is_a?(OptionalPositional) && b.is_a?(RequiredPositional) - AST::Types::Intersection.build(types: [a.type, b.type]).yield_self do |type| - (self.drop_first & other.drop_first)&.with_first_param(RequiredPositional.new(type)) - end - when a.is_a?(OptionalPositional) && b.is_a?(OptionalPositional) - AST::Types::Intersection.build(types: [a.type, b.type]).yield_self do |type| - (self.drop_first & other.drop_first)&.with_first_param(OptionalPositional.new(type)) - end - when a.is_a?(OptionalPositional) && b.is_a?(RestPositional) - AST::Types::Intersection.build(types: [a.type, b.type]).yield_self do |type| - (self.drop_first & other)&.with_first_param(OptionalPositional.new(type)) - end - when a.is_a?(OptionalPositional) && b.nil? - self.drop_first & other - when a.is_a?(RestPositional) && b.is_a?(RequiredPositional) - AST::Types::Intersection.build(types: [a.type, b.type]).yield_self do |type| - (self & other.drop_first)&.with_first_param(RequiredPositional.new(type)) - end - when a.is_a?(RestPositional) && b.is_a?(OptionalPositional) - AST::Types::Intersection.build(types: [a.type, b.type]).yield_self do |type| - (self & other.drop_first)&.with_first_param(OptionalPositional.new(type)) - end - when a.is_a?(RestPositional) && b.is_a?(RestPositional) - AST::Types::Intersection.build(types: [a.type, b.type]).yield_self do |type| - (self.drop_first & other.drop_first)&.with_first_param(RestPositional.new(type)) - end - when a.is_a?(RestPositional) && b.nil? - self.drop_first & other - when a.nil? && b.is_a?(RequiredPositional) - nil - when a.nil? && b.is_a?(OptionalPositional) - self & other.drop_first - when a.nil? && b.is_a?(RestPositional) - self & other.drop_first - when a.nil? && b.nil? - optional_keywords = {} - - (Set.new(self.optional_keywords.keys) & Set.new(other.optional_keywords.keys)).each do |keyword| - optional_keywords[keyword] = AST::Types::Intersection.build( - types: [ - self.optional_keywords[keyword], - other.optional_keywords[keyword] - ] - ) - end - - required_keywords = {} - self.optional_keywords.each do |keyword, t| - unless optional_keywords.key?(keyword) - case - when other.required_keywords.key?(keyword) - required_keywords[keyword] = AST::Types::Intersection.build(types: [t, other.required_keywords[keyword]]) - when other.rest_keywords - optional_keywords[keyword] = AST::Types::Intersection.build(types: [t, other.rest_keywords]) - end - end - end - other.optional_keywords.each do |keyword, t| - unless optional_keywords.key?(keyword) - case - when self.required_keywords.key?(keyword) - required_keywords[keyword] = AST::Types::Intersection.build(types: [t, self.required_keywords[keyword]]) - when self.rest_keywords - optional_keywords[keyword] = AST::Types::Intersection.build(types: [t, self.rest_keywords]) - end - end - end - self.required_keywords.each do |keyword, t| - unless required_keywords.key?(keyword) - case - when other.required_keywords.key?(keyword) - required_keywords[keyword] = AST::Types::Intersection.build(types: [t, other.required_keywords[keyword]]) - when other.rest_keywords - required_keywords[keyword] = AST::Types::Intersection.build(types: [t, other.rest_keywords]) - else - return - end - end - end - other.required_keywords.each do |keyword, t| - unless required_keywords.key?(keyword) - case - when self.required_keywords.key?(keyword) - required_keywords[keyword] = AST::Types::Intersection.build(types: [t, self.required_keywords[keyword]]) - when self.rest_keywords - required_keywords[keyword] = AST::Types::Intersection.build(types: [t, self.rest_keywords]) - else - return - end - end - end - - rest = case - when self.rest_keywords && other.rest_keywords - AST::Types::Intersection.build(types: [self.rest_keywords, other.rest_keywords]) - else - nil - end - - Params.new( - required: [], - optional: [], - rest: nil, - required_keywords: required_keywords, - optional_keywords: optional_keywords, - rest_keywords: rest) - end - end - - # Returns the union between self and other. - # - def |(other) - a = first_param - b = other.first_param - - case - when a.is_a?(RequiredPositional) && b.is_a?(RequiredPositional) - AST::Types::Union.build(types: [a.type, b.type]).yield_self do |type| - (self.drop_first | other.drop_first)&.with_first_param(RequiredPositional.new(type)) - end - when a.is_a?(RequiredPositional) && b.is_a?(OptionalPositional) - AST::Types::Union.build(types: [a.type, b.type]).yield_self do |type| - (self.drop_first | other.drop_first)&.with_first_param(OptionalPositional.new(type)) - end - when a.is_a?(RequiredPositional) && b.is_a?(RestPositional) - AST::Types::Union.build(types: [a.type, b.type]).yield_self do |type| - (self.drop_first | other.drop_first)&.with_first_param(OptionalPositional.new(type)) - end - when a.is_a?(RequiredPositional) && b.nil? - self.drop_first&.with_first_param(OptionalPositional.new(a.type)) - when a.is_a?(OptionalPositional) && b.is_a?(RequiredPositional) - AST::Types::Union.build(types: [a.type, b.type]).yield_self do |type| - (self.drop_first | other.drop_first)&.with_first_param(OptionalPositional.new(type)) - end - when a.is_a?(OptionalPositional) && b.is_a?(OptionalPositional) - AST::Types::Union.build(types: [a.type, b.type]).yield_self do |type| - (self.drop_first | other.drop_first)&.with_first_param(OptionalPositional.new(type)) - end - when a.is_a?(OptionalPositional) && b.is_a?(RestPositional) - AST::Types::Union.build(types: [a.type, b.type]).yield_self do |type| - (self.drop_first | other.drop_first)&.with_first_param(OptionalPositional.new(type)) - end - when a.is_a?(OptionalPositional) && b.nil? - (self.drop_first | other)&.with_first_param(a) - when a.is_a?(RestPositional) && b.is_a?(RequiredPositional) - AST::Types::Union.build(types: [a.type, b.type]).yield_self do |type| - (self.drop_first | other.drop_first)&.with_first_param(OptionalPositional.new(type)) - end - when a.is_a?(RestPositional) && b.is_a?(OptionalPositional) - AST::Types::Union.build(types: [a.type, b.type]).yield_self do |type| - (self | other.drop_first)&.with_first_param(OptionalPositional.new(type)) - end - when a.is_a?(RestPositional) && b.is_a?(RestPositional) - AST::Types::Union.build(types: [a.type, b.type]).yield_self do |type| - (self.drop_first | other.drop_first)&.with_first_param(RestPositional.new(type)) - end - when a.is_a?(RestPositional) && b.nil? - (self.drop_first | other)&.with_first_param(a) - when a.nil? && b.is_a?(RequiredPositional) - other.drop_first&.with_first_param(OptionalPositional.new(b.type)) - when a.nil? && b.is_a?(OptionalPositional) - (self | other.drop_first)&.with_first_param(b) - when a.nil? && b.is_a?(RestPositional) - (self | other.drop_first)&.with_first_param(b) - when a.nil? && b.nil? - required_keywords = {} - optional_keywords = {} - - (Set.new(self.required_keywords.keys) & Set.new(other.required_keywords.keys)).each do |keyword| - required_keywords[keyword] = AST::Types::Union.build( - types: [ - self.required_keywords[keyword], - other.required_keywords[keyword] - ] - ) - end - - self.optional_keywords.each do |keyword, t| - unless optional_keywords.key?(keyword) || required_keywords.key?(keyword) - case - when s = other.required_keywords[keyword] - optional_keywords[keyword] = AST::Types::Union.build(types: [t, s]) - when s = other.optional_keywords[keyword] - optional_keywords[keyword] = AST::Types::Union.build(types: [t, s]) - when r = other.rest_keywords - optional_keywords[keyword] = AST::Types::Union.build(types: [t, r]) - else - optional_keywords[keyword] = t - end - end - end - other.optional_keywords.each do |keyword, t| - unless optional_keywords.key?(keyword) || required_keywords.key?(keyword) - case - when s = self.required_keywords[keyword] - optional_keywords[keyword] = AST::Types::Union.build(types: [t, s]) - when s = self.optional_keywords[keyword] - optional_keywords[keyword] = AST::Types::Union.build(types: [t, s]) - when r = self.rest_keywords - optional_keywords[keyword] = AST::Types::Union.build(types: [t, r]) - else - optional_keywords[keyword] = t - end - end - end - self.required_keywords.each do |keyword, t| - unless optional_keywords.key?(keyword) || required_keywords.key?(keyword) - case - when s = other.optional_keywords[keyword] - optional_keywords[keyword] = AST::Types::Union.build(types: [t, s]) - when r = other.rest_keywords - optional_keywords[keyword] = AST::Types::Union.build(types: [t, r]) - else - optional_keywords[keyword] = t - end - end - end - other.required_keywords.each do |keyword, t| - unless optional_keywords.key?(keyword) || required_keywords.key?(keyword) - case - when s = self.optional_keywords[keyword] - optional_keywords[keyword] = AST::Types::Union.build(types: [t, s]) - when r = self.rest_keywords - optional_keywords[keyword] = AST::Types::Union.build(types: [t, r]) - else - optional_keywords[keyword] = t - end - end - end - - rest = case - when self.rest_keywords && other.rest_keywords - AST::Types::Union.build(types: [self.rest_keywords, other.rest_keywords]) - when self.rest_keywords - if required_keywords.empty? && optional_keywords.empty? - self.rest_keywords - end - when other.rest_keywords - if required_keywords.empty? && optional_keywords.empty? - other.rest_keywords - end - else - nil - end - - Params.new( - required: [], - optional: [], - rest: nil, - required_keywords: required_keywords, - optional_keywords: optional_keywords, - rest_keywords: rest) - end - end - end - - class Block - attr_reader :type - attr_reader :optional - - def initialize(type:, optional:) - @type = type - @optional = optional - end - - def optional? - @optional - end - - def to_optional - self.class.new( - type: type, - optional: true - ) - end - - def ==(other) - other.is_a?(self.class) && other.type == type && other.optional == optional - end - - alias eql? == - - def hash - type.hash ^ optional.hash - end - - def closed? - type.closed? - end - - def subst(s) - ty = type.subst(s) - if ty == type - self - else - self.class.new( - type: ty, - optional: optional - ) - end - end - - def free_variables() - @fvs ||= type.free_variables - end - - def to_s - "#{optional? ? "?" : ""}{ #{type.params} -> #{type.return_type} }" - end - - def map_type(&block) - self.class.new( - type: type.map_type(&block), - optional: optional - ) - end - - def +(other) - optional = self.optional? || other.optional? - type = AST::Types::Proc.new( - params: self.type.params + other.type.params, - return_type: AST::Types::Union.build(types: [self.type.return_type, other.type.return_type]) - ) - self.class.new( - type: type, - optional: optional - ) - end - end - class MethodType attr_reader :type_params - attr_reader :params + attr_reader :type attr_reader :block - attr_reader :return_type attr_reader :method_decls - def initialize(type_params:, params:, block:, return_type:, method_decls:) + def initialize(type_params:, type:, block:, method_decls:) @type_params = type_params - @params = params + @type = type @block = block - @return_type = return_type @method_decls = method_decls end def ==(other) other.is_a?(self.class) && other.type_params == type_params && - other.params == params && - other.block == block && - other.return_type == return_type + other.type == type && + other.block == block end alias eql? == def hash - type_params.hash ^ params.hash ^ block.hash ^ return_type.hash + type_params.hash ^ type.hash ^ block.hash end def free_variables @fvs ||= Set.new.tap do |set| - set.merge(params.free_variables) + set.merge(type.free_variables) if block set.merge(block.free_variables) end - set.merge(return_type.free_variables) set.subtract(type_params) end end @@ -817,21 +44,19 @@ def subst(s) self.class.new( type_params: type_params, - params: params.subst(s_), + type: type.subst(s_), block: block&.subst(s_), - return_type: return_type.subst(s_), method_decls: method_decls ) end def each_type(&block) if block_given? - params.each_type(&block) + type.each_type(&block) self.block&.tap do self.block.type.params.each_type(&block) yield(self.block.type.return_type) end - yield(return_type) else enum_for :each_type end @@ -839,23 +64,22 @@ def each_type(&block) def instantiate(s) self.class.new(type_params: [], - params: params.subst(s), + type: type.subst(s), block: block&.subst(s), - return_type: return_type.subst(s), method_decls: method_decls) end - def with(type_params: self.type_params, params: self.params, block: self.block, return_type: self.return_type, method_decls: self.method_decls) + def with(type_params: self.type_params, type: self.type, block: self.block, method_decls: self.method_decls) self.class.new(type_params: type_params, - params: params, + type: type, block: block, - return_type: return_type, method_decls: method_decls) end def to_s type_params = !self.type_params.empty? ? "[#{self.type_params.map{|x| "#{x}" }.join(", ")}] " : "" - params = self.params.to_s + params = type.params.to_s + return_type = type.return_type block = self.block ? " #{self.block}" : "" "#{type_params}#{params}#{block} -> #{return_type}" @@ -863,9 +87,8 @@ def to_s def map_type(&block) self.class.new(type_params: type_params, - params: params.map_type(&block), + type: type.map_type(&block), block: self.block&.yield_self {|blk| blk.map_type(&block) }, - return_type: yield(return_type), method_decls: method_decls) end @@ -889,11 +112,14 @@ def unify_overload(other) self.class.new( type_params: type_params, - params: params.subst(s1) + other.params.subst(s2), - block: block, - return_type: AST::Types::Union.build( - types: [return_type.subst(s1),other.return_type.subst(s2)] + type: Function.new( + params: type.params.subst(s1) + other.type.params.subst(s2), + return_type: AST::Types::Union.build( + types: [type.return_type.subst(s1), other.type.return_type.subst(s2)] + ), + location: nil ), + block: block, method_decls: method_decls + other.method_decls ) end @@ -921,14 +147,12 @@ def |(other) type_params = (self_type_params + other_type_params).to_a end - params = self.params & other.params or return + params = self.type.params & other.type.params or return block = case when self.block && other.block block_params = self.block.type.params | other.block.type.params block_return_type = AST::Types::Intersection.build(types: [self.block.type.return_type, other.block.type.return_type]) - block_type = AST::Types::Proc.new(params: block_params, - return_type: block_return_type, - location: nil) + block_type = Function.new(params: block_params, return_type: block_return_type, location: nil) Block.new( type: block_type, optional: self.block.optional && other.block.optional @@ -942,13 +166,12 @@ def |(other) else return end - return_type = AST::Types::Union.build(types: [self.return_type, other.return_type]) + return_type = AST::Types::Union.build(types: [self.type.return_type, other.type.return_type]) MethodType.new( - params: params, - block: block, - return_type: return_type, type_params: type_params, + type: Function.new(params: params, return_type: return_type, location: nil), + block: block, method_decls: method_decls + other.method_decls ) end @@ -972,14 +195,12 @@ def &(other) type_params = (self_type_params + other_type_params).to_a end - params = self.params | other.params + params = self.type.params | other.type.params block = case when self.block && other.block block_params = self.block.type.params & other.block.type.params or return block_return_type = AST::Types::Union.build(types: [self.block.type.return_type, other.block.type.return_type]) - block_type = AST::Types::Proc.new(params: block_params, - return_type: block_return_type, - location: nil) + block_type = Function.new(params: block_params, return_type: block_return_type, location: nil) Block.new( type: block_type, optional: self.block.optional || other.block.optional @@ -989,13 +210,12 @@ def &(other) self.block || other.block end - return_type = AST::Types::Intersection.build(types: [self.return_type, other.return_type]) + return_type = AST::Types::Intersection.build(types: [self.type.return_type, other.type.return_type]) MethodType.new( - params: params, - block: block, - return_type: return_type, type_params: type_params, + type: Function.new(params: params, return_type: return_type, location: nil), + block: block, method_decls: method_decls + other.method_decls ) end diff --git a/lib/steep/subtyping/check.rb b/lib/steep/subtyping/check.rb index 112f95103..d1def7301 100644 --- a/lib/steep/subtyping/check.rb +++ b/lib/steep/subtyping/check.rb @@ -319,17 +319,20 @@ def check0(relation, self_type:, assumption:, trace:, constraints:) end when relation.sub_type.is_a?(AST::Types::Proc) && relation.super_type.is_a?(AST::Types::Proc) - check_method_params(:__proc__, - relation.sub_type.params, relation.super_type.params, - self_type: self_type, - assumption: assumption, - trace: trace, - constraints: constraints).then do - check(Relation.new(sub_type: relation.sub_type.return_type, super_type: relation.super_type.return_type), - self_type: self_type, - assumption: assumption, - trace: trace, - constraints: constraints) + name = :__proc__ + + sub_type = relation.sub_type + super_type = relation.super_type + + check_method_params(name, sub_type.type.params, super_type.type.params, self_type: self_type, assumption: assumption, trace: trace, constraints: constraints).then do + check_block_given(name, sub_type.block, super_type.block, trace: trace, constraints: constraints).then do + check_block_params(name, sub_type.block, super_type.block, self_type: self_type, assumption: assumption, trace: trace, constraints: constraints).then do + check_block_return(sub_type.block, super_type.block, self_type: self_type, assumption: assumption, trace: trace, constraints:constraints).then do + relation = Relation.new(super_type: super_type.type.return_type, sub_type: sub_type.type.return_type) + check(relation, self_type: self_type, assumption: assumption, trace: trace, constraints: constraints) + end + end + end end when relation.sub_type.is_a?(AST::Types::Tuple) && relation.super_type.is_a?(AST::Types::Tuple) @@ -700,12 +703,12 @@ def check_generic_method_type(name, sub_type, super_type, self_type:, assumption def check_method_type(name, sub_type, super_type, self_type:, assumption:, trace:, constraints:) Steep.logger.tagged("#{name}: #{sub_type} <: #{super_type}") do - check_method_params(name, sub_type.params, super_type.params, self_type: self_type, assumption: assumption, trace: trace, constraints: constraints).then do + check_method_params(name, sub_type.type.params, super_type.type.params, self_type: self_type, assumption: assumption, trace: trace, constraints: constraints).then do check_block_given(name, sub_type.block, super_type.block, trace: trace, constraints: constraints).then do check_block_params(name, sub_type.block, super_type.block, self_type: self_type, assumption: assumption, trace: trace, constraints: constraints).then do check_block_return(sub_type.block, super_type.block, self_type: self_type, assumption: assumption, trace: trace, constraints:constraints).then do - relation = Relation.new(super_type: super_type.return_type, - sub_type: sub_type.return_type) + relation = Relation.new(super_type: super_type.type.return_type, + sub_type: sub_type.type.return_type) check(relation, self_type: self_type, assumption: assumption, trace: trace, constraints: constraints) end end @@ -750,10 +753,10 @@ def check_method_params(name, sub_params, super_params, self_type:, assumption:, def match_method_type(name, sub_type, super_type, trace:) [].tap do |pairs| - match_params(name, sub_type.params, super_type.params, trace: trace).yield_self do |result| + match_params(name, sub_type.type.params, super_type.type.params, trace: trace).yield_self do |result| return result unless result.is_a?(Array) pairs.push(*result) - pairs.push [sub_type.return_type, super_type.return_type] + pairs.push [sub_type.type.return_type, super_type.type.return_type] case when !super_type.block && !sub_type.block diff --git a/lib/steep/subtyping/variable_occurrence.rb b/lib/steep/subtyping/variable_occurrence.rb index b844a7d17..a9485c4ed 100644 --- a/lib/steep/subtyping/variable_occurrence.rb +++ b/lib/steep/subtyping/variable_occurrence.rb @@ -10,12 +10,12 @@ def initialize end def add_method_type(method_type) - method_type.params.each_type do |type| + method_type.type.params.each_type do |type| each_var(type) do |var| params << var end end - each_var(method_type.return_type) do |var| + each_var(method_type.type.return_type) do |var| returns << var end diff --git a/lib/steep/subtyping/variable_variance.rb b/lib/steep/subtyping/variable_variance.rb index f5b8d48aa..37b804110 100644 --- a/lib/steep/subtyping/variable_variance.rb +++ b/lib/steep/subtyping/variable_variance.rb @@ -25,8 +25,8 @@ def self.from_method_type(method_type) covariants = Set.new contravariants = Set.new - add_params(method_type.params, block: false, contravariants: contravariants, covariants: covariants) - add_type(method_type.return_type, variance: :covariant, covariants: covariants, contravariants: contravariants) + add_params(method_type.type.params, block: false, contravariants: contravariants, covariants: covariants) + add_type(method_type.type.return_type, variance: :covariant, covariants: covariants, contravariants: contravariants) method_type.block&.type&.yield_self do |proc| add_params(proc.params, block: true, contravariants: contravariants, covariants: covariants) diff --git a/lib/steep/type_construction.rb b/lib/steep/type_construction.rb index 33feda8f6..ff14c73c0 100644 --- a/lib/steep/type_construction.rb +++ b/lib/steep/type_construction.rb @@ -140,10 +140,10 @@ def for_new_method(method_name, node, args:, self_type:, definition:) method_type = annotation_method_type || definition_method_type - if annots&.return_type && method_type&.return_type - check_relation(sub_type: annots.return_type, super_type: method_type.return_type).else do |result| + if annots&.return_type && method_type&.type&.return_type + check_relation(sub_type: annots.return_type, super_type: method_type.type.return_type).else do |result| typing.add_error Errors::MethodReturnTypeAnnotationMismatch.new(node: node, - method_type: method_type.return_type, + method_type: method_type.type.return_type, annotation_type: annots.return_type, result: result) end @@ -152,19 +152,18 @@ def for_new_method(method_name, node, args:, self_type:, definition:) # constructor_method = method&.attributes&.include?(:constructor) if method_type - var_types = TypeConstruction.parameter_types(args, method_type) - unless TypeConstruction.valid_parameter_env?(var_types, args.reject {|arg| arg.type == :blockarg}, method_type.params) + var_types = TypeConstruction.parameter_types(args, method_type.type) + unless TypeConstruction.valid_parameter_env?(var_types, args.reject {|arg| arg.type == :blockarg}, method_type.type.params) typing.add_error Errors::MethodArityMismatch.new(node: node) end end if (block_arg = args.find {|arg| arg.type == :blockarg}) if method_type&.block - block_type = if method_type.block.optional? - AST::Types::Union.build(types: [method_type.block.type, AST::Builtin.nil_type]) - else - method_type.block.type - end + block_type = AST::Types::Proc.new(type: method_type.block.type, block: nil) + if method_type.block.optional? + block_type = AST::Types::Union.build(types: [block_type, AST::Builtin.nil_type]) + end var_types[block_arg.children[0]] = block_type end end @@ -183,7 +182,7 @@ def for_new_method(method_name, node, args:, self_type:, definition:) name: method_name, method: definition && definition.methods[method_name], method_type: method_type, - return_type: annots.return_type || method_type&.return_type || AST::Builtin.any_type, + return_type: annots.return_type || method_type&.type&.return_type || AST::Builtin.any_type, constructor: false, super_method: super_method ) @@ -1404,7 +1403,7 @@ def synthesize(node, hint: nil) if method_context&.method if method_context.super_method types = method_context.super_method.method_types.map {|method_type| - checker.factory.method_type(method_type, self_type: self_type, method_decls: Set[]).return_type + checker.factory.method_type(method_type, self_type: self_type, method_decls: Set[]).type.return_type } add_typing(node, type: union_type(*types)) else @@ -2022,17 +2021,25 @@ def synthesize(node, hint: nil) if hint.is_a?(AST::Types::Proc) && value.type == :sym if hint.one_arg? # Assumes Symbol#to_proc implementation - param_type = hint.params.required[0] + param_type = hint.type.params.required[0] interface = checker.factory.interface(param_type, private: true) method = interface.methods[value.children[0]] if method return_types = method.method_types.select {|method_type| - method_type.params.each_type.count == 0 - }.map(&:return_type) + method_type.type.params.empty? + }.map {|method_type| + method_type.type.return_type + } unless return_types.empty? - type = AST::Types::Proc.new(params: Interface::Params.empty.update(required: [param_type]), - return_type: AST::Types::Union.build(types: return_types)) + type = AST::Types::Proc.new( + type: Interface::Function.new( + params: Interface::Function::Params.empty.update(required: [param_type]), + return_type: AST::Types::Union.build(types: return_types), + location: nil + ), + block: nil + ) end end else @@ -2362,8 +2369,8 @@ def type_lambda(node, block_params:, block_body:, type_hint:) case type_hint when AST::Types::Proc - params_hint = type_hint.params - return_hint = type_hint.return_type + params_hint = type_hint.type.params + return_hint = type_hint.type.return_type end block_constr = for_block( @@ -2393,8 +2400,12 @@ def type_lambda(node, block_params:, block_body:, type_hint:) end block_type = AST::Types::Proc.new( - params: params_hint || params.params_type, - return_type: return_type + type: Interface::Function.new( + params: params_hint || params.params_type, + return_type: return_type, + location: nil + ), + block: nil ) add_typing node, type: block_type @@ -2600,7 +2611,7 @@ def type_method_call(node, method_name:, receiver_type:, method:, args:, block_p results = method.method_types.flat_map do |method_type| Steep.logger.tagged method_type.to_s do - zips = args.zips(method_type.params, method_type.block&.type) + zips = args.zips(method_type.type.params, method_type.block&.type) zips.map do |arg_pairs| typing.new_child(node_range) do |child_typing| @@ -2632,7 +2643,7 @@ def type_method_call(node, method_name:, receiver_type:, method:, args:, block_p context: context.method_context, method_name: method_name, receiver_type: receiver_type, - return_type: method_type.return_type, + return_type: method_type.type.return_type, errors: [error], method_decls: all_decls ) @@ -2655,7 +2666,7 @@ def type_method_call(node, method_name:, receiver_type:, method:, args:, block_p end def check_keyword_arg(receiver_type:, node:, method_type:, constraints:) - params = method_type.params + params = method_type.type.params case node.type when :hash @@ -2757,7 +2768,7 @@ def check_keyword_arg(receiver_type:, node:, method_type:, constraints:) ) else hash_elements = params.required_keywords.merge( - method_type.params.optional_keywords.transform_values do |type| + params.optional_keywords.transform_values do |type| AST::Types::Union.build(types: [type, AST::Builtin.nil_type]) end ) @@ -2840,7 +2851,7 @@ def try_method_type(node, receiver_type:, method_name:, method_type:, args:, arg block_params: block_params_, block_param_hint: method_type.block.type.params, block_annotations: block_annotations, - node_type_hint: method_type.return_type + node_type_hint: method_type.type.return_type ) block_constr = block_constr.with_new_typing( block_constr.typing.new_child( @@ -2870,7 +2881,7 @@ def try_method_type(node, receiver_type:, method_name:, method_type:, args:, arg checker, self_type: self_type, variance: variance, - variables: method_type.params.free_variables + method_type.block.type.params.free_variables + variables: method_type.type.params.free_variables + method_type.block.type.params.free_variables ) method_type = method_type.subst(s) block_constr = block_constr.update_lvar_env {|env| env.subst(s) } @@ -2896,22 +2907,36 @@ def try_method_type(node, receiver_type:, method_name:, method_type:, args:, arg s = constraints.solution(checker, self_type: self_type, variance: variance, variables: fresh_vars) method_type = method_type.subst(s) - return_type = method_type.return_type + return_type = method_type.type.return_type if break_type = block_annotations.break_type return_type = union_type(break_type, return_type) end when Subtyping::Result::Failure - block_type = AST::Types::Proc.new( - params: method_type.block.type.params || block_params.params_type, - return_type: block_body_type + given_block_type = AST::Types::Proc.new( + type: Interface::Function.new( + params: method_type.block.type.params || block_params.params_type, + return_type: block_body_type, + location: nil + ), + block: nil + ) + + method_block_type = AST::Types::Proc.new( + type: Interface::Function.new( + params: method_type.block.type.params, + return_type: method_type.block.type.return_type, + location: nil + ), + block: nil ) + errors << Errors::BlockTypeMismatch.new(node: node, - expected: method_type.block.type, - actual: block_type, + expected: method_block_type, + actual: given_block_type, result: result) - return_type = method_type.return_type + return_type = method_type.type.return_type end block_constr.typing.save! @@ -2985,23 +3010,24 @@ def try_method_type(node, receiver_type:, method_name:, method_type:, args:, arg else begin method_type = method_type.subst(constraints.solution(checker, self_type: self_type, variance: variance, variables: occurence.params)) - block_type, constr = constr.synthesize(args.block_pass_arg, hint: topdown_hint ? method_type.block.type : nil) - result = check_relation( - sub_type: block_type, - super_type: method_type.block.yield_self {|expected_block| - if expected_block.optional? - AST::Builtin.optional(expected_block.type) - else - expected_block.type - end - }, - constraints: constraints - ) + hint_type = if topdown_hint + AST::Types::Proc.new(type: method_type.block.type, block: nil) + end + given_block_type, constr = constr.synthesize(args.block_pass_arg, hint: hint_type) + method_block_type = method_type.block.yield_self {|expected_block| + proc_type = AST::Types::Proc.new(type: expected_block.type, block: nil) + if expected_block.optional? + AST::Builtin.optional(proc_type) + else + proc_type + end + } + result = check_relation(sub_type: given_block_type, super_type: method_block_type, constraints: constraints) result.else do |result| errors << Errors::BlockTypeMismatch.new(node: node, - expected: method_type.block.type, - actual: block_type, + expected: method_block_type, + actual: given_block_type, result: result) end @@ -3018,7 +3044,7 @@ def try_method_type(node, receiver_type:, method_name:, method_type:, args:, arg receiver_type: receiver_type, method_name: method_name, actual_method_type: method_type, - return_type: return_type || method_type.return_type, + return_type: return_type || method_type.type.return_type, method_decls: method_type.method_decls ) else @@ -3027,7 +3053,7 @@ def try_method_type(node, receiver_type:, method_name:, method_type:, args:, arg context: context.method_context, receiver_type: receiver_type, method_name: method_name, - return_type: return_type || method_type.return_type, + return_type: return_type || method_type.type.return_type, method_decls: method_type.method_decls, errors: errors ) diff --git a/lib/steep/type_inference/block_params.rb b/lib/steep/type_inference/block_params.rb index a8947732e..665220cac 100644 --- a/lib/steep/type_inference/block_params.rb +++ b/lib/steep/type_inference/block_params.rb @@ -134,7 +134,7 @@ def params_type0(hint:) rest = rest_param&.yield_self {|param| param.type.args[0] } end - Interface::Params.new( + Interface::Function::Params.new( required: leadings, optional: optionals, rest: rest, diff --git a/smoke/tsort/Steepfile b/smoke/tsort/Steepfile new file mode 100644 index 000000000..d12830583 --- /dev/null +++ b/smoke/tsort/Steepfile @@ -0,0 +1,6 @@ +target :test do + signature "." + check "." + + library "tsort" +end diff --git a/smoke/tsort/a.rb b/smoke/tsort/a.rb new file mode 100644 index 000000000..782fb51f6 --- /dev/null +++ b/smoke/tsort/a.rb @@ -0,0 +1,15 @@ +# ALLOW FAILURE + +require "tsort" + +# @type var g: Hash[Integer, Array[Integer]] +g = {1=>[2, 3], 2=>[4], 3=>[2, 4], 4=>[]} + +# @type var each_node: ^() { (Integer) -> void } -> void +each_node = -> (&b) { g.each_key(&b) } +# @type var each_child: ^(Integer) { (Integer) -> void } -> void +each_child = -> (n, &b) { g[n].each(&b) } + +# @type var xs: Array[String] +# !expects IncompatibleAssignment: lhs_type=::Array[::String], rhs_type=::Array[::Integer] +xs = TSort.tsort(each_node, each_child) diff --git a/steep.gemspec b/steep.gemspec index 3297873eb..a7d2abd71 100644 --- a/steep.gemspec +++ b/steep.gemspec @@ -34,5 +34,5 @@ Gem::Specification.new do |spec| spec.add_runtime_dependency "rainbow", ">= 2.2.2", "< 4.0" spec.add_runtime_dependency "listen", "~> 3.0" spec.add_runtime_dependency "language_server-protocol", "~> 3.15.0.1" - spec.add_runtime_dependency "rbs", "~> 0.17.0" + spec.add_runtime_dependency "rbs", ">= 0.20.0" end diff --git a/test/args_test.rb b/test/args_test.rb index 84c9192a1..3e95e125c 100644 --- a/test/args_test.rb +++ b/test/args_test.rb @@ -5,7 +5,7 @@ class ArgsTest < Minitest::Test include FactoryHelper SendArgs = Steep::TypeInference::SendArgs - Params = Steep::Interface::Params + Params = Steep::Interface::Function::Params Types = Steep::AST::Types AST = Steep::AST diff --git a/test/block_params_test.rb b/test/block_params_test.rb index 0246a7dbc..de64dc561 100644 --- a/test/block_params_test.rb +++ b/test/block_params_test.rb @@ -7,7 +7,7 @@ class BlockParamsTest < Minitest::Test BlockParams = Steep::TypeInference::BlockParams LabeledName = ASTUtils::Labeling::LabeledName - Params = Steep::Interface::Params + Params = Steep::Interface::Function::Params Types = Steep::AST::Types Namespace = RBS::Namespace @@ -349,7 +349,7 @@ def test_param_type_with_hint end def param_type(required: [], optional: [], rest: nil, required_keywords: {}, optional_keywords: {}, rest_keywords: nil) - Steep::Interface::Params.new( + Steep::Interface::Function::Params.new( required: required.map {|s| parse_type(s) }, optional: optional.map {|t| parse_type(t) }, rest: rest&.yield_self {|t| parse_type(t) }, diff --git a/test/interface_test.rb b/test/interface_test.rb index ff39318ac..95c75bae9 100644 --- a/test/interface_test.rb +++ b/test/interface_test.rb @@ -6,26 +6,26 @@ class InterfaceTest < Minitest::Test def test_method_type_params_plus with_factory do - assert_equal parse_method_type("(String | Integer) -> untyped").params, - parse_method_type("(String) -> untyped").params + parse_method_type("(Integer) -> untyped").params + assert_equal parse_method_type("(String | Integer) -> untyped").type.params, + parse_method_type("(String) -> untyped").type.params + parse_method_type("(Integer) -> untyped").type.params - assert_equal parse_method_type("(?String | Integer) -> untyped").params, - parse_method_type("(?String) -> untyped").params + parse_method_type("(Integer) -> untyped").params + assert_equal parse_method_type("(?String | Integer) -> untyped").type.params, + parse_method_type("(?String) -> untyped").type.params + parse_method_type("(Integer) -> untyped").type.params - assert_equal parse_method_type("(?String) -> untyped").params, - parse_method_type("(String) -> untyped").params + parse_method_type("() -> untyped").params + assert_equal parse_method_type("(?String) -> untyped").type.params, + parse_method_type("(String) -> untyped").type.params + parse_method_type("() -> untyped").type.params - assert_equal parse_method_type("(?String | Symbol, *Symbol) -> untyped").params, - parse_method_type("(String) -> untyped").params + parse_method_type("(*Symbol) -> untyped").params + assert_equal parse_method_type("(?String | Symbol, *Symbol) -> untyped").type.params, + parse_method_type("(String) -> untyped").type.params + parse_method_type("(*Symbol) -> untyped").type.params - assert_equal parse_method_type("(?String | Symbol, *Symbol) -> void").params, - parse_method_type("(String) -> params").params + parse_method_type("(*Symbol) -> void").params + assert_equal parse_method_type("(?String | Symbol, *Symbol) -> void").type.params, + parse_method_type("(String) -> params").type.params + parse_method_type("(*Symbol) -> void").type.params - assert_equal parse_method_type("(name: String | Symbol, ?email: String | Array, ?age: Integer | Object, **Array | Object) -> void").params, - parse_method_type("(name: String, email: String, **Object) -> void").params + parse_method_type("(name: Symbol, age: Integer, **Array) -> void").params + assert_equal parse_method_type("(name: String | Symbol, ?email: String | Array, ?age: Integer | Object, **Array | Object) -> void").type.params, + parse_method_type("(name: String, email: String, **Object) -> void").type.params + parse_method_type("(name: Symbol, age: Integer, **Array) -> void").type.params - assert_equal parse_method_type("() ?{ (String | Integer) -> (Array | Hash) } -> void").params, - parse_method_type("() ?{ (String) -> Array } -> void").params + parse_method_type("() { (Integer) -> Hash } -> void").params + assert_equal parse_method_type("() ?{ (String | Integer) -> (Array | Hash) } -> void").type.params, + parse_method_type("() ?{ (String) -> Array } -> void").type.params + parse_method_type("() { (Integer) -> Hash } -> void").type.params end end @@ -34,248 +34,248 @@ def test_method_type_params_intersection # req, none, opt, rest # required:required - assert_equal parse_method_type("(String & Integer) -> untyped").params, - parse_method_type("(String) -> untyped").params & parse_method_type("(Integer) -> untyped").params + assert_equal parse_method_type("(String & Integer) -> untyped").type.params, + parse_method_type("(String) -> untyped").type.params & parse_method_type("(Integer) -> untyped").type.params # required:none - assert_nil parse_method_type("(String) -> untyped").params & parse_method_type("() -> untyped").params + assert_nil parse_method_type("(String) -> untyped").type.params & parse_method_type("() -> untyped").type.params # required:optional - assert_equal parse_method_type("(String & Integer) -> untyped").params, - parse_method_type("(String) -> untyped").params & parse_method_type("(?Integer) -> untyped").params + assert_equal parse_method_type("(String & Integer) -> untyped").type.params, + parse_method_type("(String) -> untyped").type.params & parse_method_type("(?Integer) -> untyped").type.params # required:rest - assert_equal parse_method_type("(String & Integer) -> untyped").params, - parse_method_type("(String) -> untyped").params & parse_method_type("(*Integer) -> untyped").params + assert_equal parse_method_type("(String & Integer) -> untyped").type.params, + parse_method_type("(String) -> untyped").type.params & parse_method_type("(*Integer) -> untyped").type.params # none:required - assert_nil parse_method_type("() -> untyped").params & parse_method_type("(String) -> void").params + assert_nil parse_method_type("() -> untyped").type.params & parse_method_type("(String) -> void").type.params # none:optional - assert_equal parse_method_type("() -> untyped").params, - parse_method_type("() -> untyped").params & parse_method_type("(?Integer) -> untyped").params + assert_equal parse_method_type("() -> untyped").type.params, + parse_method_type("() -> untyped").type.params & parse_method_type("(?Integer) -> untyped").type.params # none:rest - assert_equal parse_method_type("() -> untyped").params, - parse_method_type("() -> untyped").params & parse_method_type("(*Integer) -> untyped").params + assert_equal parse_method_type("() -> untyped").type.params, + parse_method_type("() -> untyped").type.params & parse_method_type("(*Integer) -> untyped").type.params # opt:required - assert_equal parse_method_type("(String & Integer) -> untyped").params, - parse_method_type("(?String) -> untyped").params & parse_method_type("(Integer) -> untyped").params + assert_equal parse_method_type("(String & Integer) -> untyped").type.params, + parse_method_type("(?String) -> untyped").type.params & parse_method_type("(Integer) -> untyped").type.params # opt:none - assert_equal parse_method_type("() -> untyped").params, - parse_method_type("(?String) -> untyped").params & parse_method_type("() -> untyped").params + assert_equal parse_method_type("() -> untyped").type.params, + parse_method_type("(?String) -> untyped").type.params & parse_method_type("() -> untyped").type.params # opt:opt - assert_equal parse_method_type("(?String & Integer) -> untyped").params, - parse_method_type("(?String) -> untyped").params & parse_method_type("(?Integer) -> untyped").params + assert_equal parse_method_type("(?String & Integer) -> untyped").type.params, + parse_method_type("(?String) -> untyped").type.params & parse_method_type("(?Integer) -> untyped").type.params # opt:rest - assert_equal parse_method_type("(?String & Integer) -> untyped").params, - parse_method_type("(?String) -> untyped").params & parse_method_type("(*Integer) -> untyped").params + assert_equal parse_method_type("(?String & Integer) -> untyped").type.params, + parse_method_type("(?String) -> untyped").type.params & parse_method_type("(*Integer) -> untyped").type.params # rest:required - assert_equal parse_method_type("(String & Integer) -> untyped").params, - parse_method_type("(*String) -> untyped").params & parse_method_type("(Integer) -> untyped").params + assert_equal parse_method_type("(String & Integer) -> untyped").type.params, + parse_method_type("(*String) -> untyped").type.params & parse_method_type("(Integer) -> untyped").type.params # rest:none - assert_equal parse_method_type("() -> untyped").params, - parse_method_type("(*String) -> untyped").params & parse_method_type("() -> untyped").params + assert_equal parse_method_type("() -> untyped").type.params, + parse_method_type("(*String) -> untyped").type.params & parse_method_type("() -> untyped").type.params # rest:opt - assert_equal parse_method_type("(?String & Integer) -> untyped").params, - parse_method_type("(*String) -> untyped").params & parse_method_type("(?Integer) -> untyped").params + assert_equal parse_method_type("(?String & Integer) -> untyped").type.params, + parse_method_type("(*String) -> untyped").type.params & parse_method_type("(?Integer) -> untyped").type.params # rest:rest - assert_equal parse_method_type("(*String & Integer) -> untyped").params, - parse_method_type("(*String) -> untyped").params & parse_method_type("(*Integer) -> untyped").params + assert_equal parse_method_type("(*String & Integer) -> untyped").type.params, + parse_method_type("(*String) -> untyped").type.params & parse_method_type("(*Integer) -> untyped").type.params ## Keywords # req:req - assert_equal parse_method_type("(foo: String & Integer) -> untyped").params, - parse_method_type("(foo: String) -> untyped").params & parse_method_type("(foo: Integer) -> untyped").params + assert_equal parse_method_type("(foo: String & Integer) -> untyped").type.params, + parse_method_type("(foo: String) -> untyped").type.params & parse_method_type("(foo: Integer) -> untyped").type.params # req:opt - assert_equal parse_method_type("(foo: Integer & String) -> untyped").params, - parse_method_type("(foo: String) -> untyped").params & parse_method_type("(?foo: Integer) -> untyped").params + assert_equal parse_method_type("(foo: Integer & String) -> untyped").type.params, + parse_method_type("(foo: String) -> untyped").type.params & parse_method_type("(?foo: Integer) -> untyped").type.params # req:none - assert_nil parse_method_type("(foo: String) -> untyped").params & parse_method_type("() -> untyped").params + assert_nil parse_method_type("(foo: String) -> untyped").type.params & parse_method_type("() -> untyped").type.params # req:rest - assert_equal parse_method_type("(foo: String & Integer) -> untyped").params, - parse_method_type("(foo: String) -> untyped").params & parse_method_type("(**Integer) -> untyped").params + assert_equal parse_method_type("(foo: String & Integer) -> untyped").type.params, + parse_method_type("(foo: String) -> untyped").type.params & parse_method_type("(**Integer) -> untyped").type.params # opt:req - assert_equal parse_method_type("(foo: String & Integer) -> untyped").params, - parse_method_type("(?foo: String) -> untyped").params & parse_method_type("(foo: Integer) -> untyped").params + assert_equal parse_method_type("(foo: String & Integer) -> untyped").type.params, + parse_method_type("(?foo: String) -> untyped").type.params & parse_method_type("(foo: Integer) -> untyped").type.params # opt:opt - assert_equal parse_method_type("(?foo: String & Integer) -> untyped").params, - parse_method_type("(?foo: String) -> untyped").params & parse_method_type("(?foo: Integer) -> untyped").params + assert_equal parse_method_type("(?foo: String & Integer) -> untyped").type.params, + parse_method_type("(?foo: String) -> untyped").type.params & parse_method_type("(?foo: Integer) -> untyped").type.params # opt:none - assert_equal parse_method_type("() -> untyped").params, - parse_method_type("(?foo: String) -> untyped").params & parse_method_type("() -> untyped").params + assert_equal parse_method_type("() -> untyped").type.params, + parse_method_type("(?foo: String) -> untyped").type.params & parse_method_type("() -> untyped").type.params # opt:rest - assert_equal parse_method_type("(?foo: String & Integer) -> untyped").params, - parse_method_type("(?foo: String) -> untyped").params & parse_method_type("(**Integer) -> untyped").params + assert_equal parse_method_type("(?foo: String & Integer) -> untyped").type.params, + parse_method_type("(?foo: String) -> untyped").type.params & parse_method_type("(**Integer) -> untyped").type.params # none:req - assert_nil parse_method_type("() -> untyped").params & parse_method_type("(foo: String) -> untyped").params + assert_nil parse_method_type("() -> untyped").type.params & parse_method_type("(foo: String) -> untyped").type.params # none:opt - assert_equal parse_method_type("() -> untyped").params, - parse_method_type("() -> untyped").params & parse_method_type("(?foo: Integer) -> untyped").params + assert_equal parse_method_type("() -> untyped").type.params, + parse_method_type("() -> untyped").type.params & parse_method_type("(?foo: Integer) -> untyped").type.params # none:rest - assert_equal parse_method_type("() -> untyped").params, - parse_method_type("() -> untyped").params & parse_method_type("(**Integer) -> untyped").params + assert_equal parse_method_type("() -> untyped").type.params, + parse_method_type("() -> untyped").type.params & parse_method_type("(**Integer) -> untyped").type.params # rest:req - assert_equal parse_method_type("(foo: Integer & String) -> untyped").params, - parse_method_type("(**String) -> untyped").params & parse_method_type("(foo: Integer) -> untyped").params + assert_equal parse_method_type("(foo: Integer & String) -> untyped").type.params, + parse_method_type("(**String) -> untyped").type.params & parse_method_type("(foo: Integer) -> untyped").type.params # rest:opt - assert_equal parse_method_type("(?foo: Integer & String) -> untyped").params, - parse_method_type("(**String) -> untyped").params & parse_method_type("(?foo: Integer) -> untyped").params + assert_equal parse_method_type("(?foo: Integer & String) -> untyped").type.params, + parse_method_type("(**String) -> untyped").type.params & parse_method_type("(?foo: Integer) -> untyped").type.params # rest:none - assert_equal parse_method_type("() -> untyped").params, - parse_method_type("(**String) -> untyped").params & parse_method_type("() -> untyped").params + assert_equal parse_method_type("() -> untyped").type.params, + parse_method_type("(**String) -> untyped").type.params & parse_method_type("() -> untyped").type.params # rest:rest - assert_equal parse_method_type("(**String & Integer) -> untyped").params, - parse_method_type("(**String) -> untyped").params & parse_method_type("(**Integer) -> untyped").params + assert_equal parse_method_type("(**String & Integer) -> untyped").type.params, + parse_method_type("(**String) -> untyped").type.params & parse_method_type("(**Integer) -> untyped").type.params end end def test_method_type_params_union with_factory do # required:required - assert_equal parse_method_type("(String | Integer) -> untyped").params, - parse_method_type("(String) -> untyped").params | parse_method_type("(Integer) -> untyped").params + assert_equal parse_method_type("(String | Integer) -> untyped").type.params, + parse_method_type("(String) -> untyped").type.params | parse_method_type("(Integer) -> untyped").type.params # required:none - assert_equal parse_method_type("(?String) -> void").params, - parse_method_type("(String) -> untyped").params | parse_method_type("() -> untyped").params + assert_equal parse_method_type("(?String) -> void").type.params, + parse_method_type("(String) -> untyped").type.params | parse_method_type("() -> untyped").type.params # required:optional - assert_equal parse_method_type("(?String | Integer) -> untyped").params, - parse_method_type("(String) -> untyped").params | parse_method_type("(?Integer) -> untyped").params + assert_equal parse_method_type("(?String | Integer) -> untyped").type.params, + parse_method_type("(String) -> untyped").type.params | parse_method_type("(?Integer) -> untyped").type.params # required:rest - assert_equal parse_method_type("(?String | Integer) -> untyped").params, - parse_method_type("(String) -> untyped").params | parse_method_type("(*Integer) -> untyped").params + assert_equal parse_method_type("(?String | Integer) -> untyped").type.params, + parse_method_type("(String) -> untyped").type.params | parse_method_type("(*Integer) -> untyped").type.params # none:required - assert_equal parse_method_type("(?String) -> untyped").params, - parse_method_type("() -> untyped").params | parse_method_type("(String) -> untyped").params + assert_equal parse_method_type("(?String) -> untyped").type.params, + parse_method_type("() -> untyped").type.params | parse_method_type("(String) -> untyped").type.params # none:optional - assert_equal parse_method_type("(?Integer) -> untyped").params, - parse_method_type("() -> untyped").params | parse_method_type("(?Integer) -> untyped").params + assert_equal parse_method_type("(?Integer) -> untyped").type.params, + parse_method_type("() -> untyped").type.params | parse_method_type("(?Integer) -> untyped").type.params # none:rest - assert_equal parse_method_type("(*Integer) -> untyped").params, - parse_method_type("() -> untyped").params | parse_method_type("(*Integer) -> untyped").params + assert_equal parse_method_type("(*Integer) -> untyped").type.params, + parse_method_type("() -> untyped").type.params | parse_method_type("(*Integer) -> untyped").type.params # opt:required - assert_equal parse_method_type("(?String | Integer) -> untyped").params, - parse_method_type("(?String) -> untyped").params | parse_method_type("(Integer) -> untyped").params + assert_equal parse_method_type("(?String | Integer) -> untyped").type.params, + parse_method_type("(?String) -> untyped").type.params | parse_method_type("(Integer) -> untyped").type.params # opt:none - assert_equal parse_method_type("(?String) -> untyped").params, - parse_method_type("(?String) -> untyped").params | parse_method_type("() -> untyped").params + assert_equal parse_method_type("(?String) -> untyped").type.params, + parse_method_type("(?String) -> untyped").type.params | parse_method_type("() -> untyped").type.params # opt:opt - assert_equal parse_method_type("(?String | Integer) -> untyped").params, - parse_method_type("(?String) -> untyped").params | parse_method_type("(?Integer) -> untyped").params + assert_equal parse_method_type("(?String | Integer) -> untyped").type.params, + parse_method_type("(?String) -> untyped").type.params | parse_method_type("(?Integer) -> untyped").type.params # opt:rest - assert_equal parse_method_type("(?String | Integer) -> untyped").params, - parse_method_type("(?String) -> untyped").params | parse_method_type("(*Integer) -> untyped").params + assert_equal parse_method_type("(?String | Integer) -> untyped").type.params, + parse_method_type("(?String) -> untyped").type.params | parse_method_type("(*Integer) -> untyped").type.params # rest:required - assert_equal parse_method_type("(?String | Integer) -> untyped").params, - parse_method_type("(*String) -> untyped").params | parse_method_type("(Integer) -> untyped").params + assert_equal parse_method_type("(?String | Integer) -> untyped").type.params, + parse_method_type("(*String) -> untyped").type.params | parse_method_type("(Integer) -> untyped").type.params # rest:none - assert_equal parse_method_type("(*String) -> untyped").params, - parse_method_type("(*String) -> untyped").params | parse_method_type("() -> untyped").params + assert_equal parse_method_type("(*String) -> untyped").type.params, + parse_method_type("(*String) -> untyped").type.params | parse_method_type("() -> untyped").type.params # rest:opt - assert_equal parse_method_type("(?String | Integer, *String) -> untyped").params, - parse_method_type("(*String) -> untyped").params | parse_method_type("(?Integer) -> untyped").params + assert_equal parse_method_type("(?String | Integer, *String) -> untyped").type.params, + parse_method_type("(*String) -> untyped").type.params | parse_method_type("(?Integer) -> untyped").type.params # rest:rest - assert_equal parse_method_type("(*String | Integer) -> untyped").params, - parse_method_type("(*String) -> untyped").params | parse_method_type("(*Integer) -> untyped").params + assert_equal parse_method_type("(*String | Integer) -> untyped").type.params, + parse_method_type("(*String) -> untyped").type.params | parse_method_type("(*Integer) -> untyped").type.params ## Keywords # req:req - assert_equal parse_method_type("(foo: String | Integer) -> untyped").params, - parse_method_type("(foo: String) -> untyped").params | parse_method_type("(foo: Integer) -> untyped").params + assert_equal parse_method_type("(foo: String | Integer) -> untyped").type.params, + parse_method_type("(foo: String) -> untyped").type.params | parse_method_type("(foo: Integer) -> untyped").type.params # req:opt - assert_equal parse_method_type("(?foo: String | Integer) -> untyped").params, - parse_method_type("(foo: String) -> untyped").params | parse_method_type("(?foo: Integer) -> untyped").params + assert_equal parse_method_type("(?foo: String | Integer) -> untyped").type.params, + parse_method_type("(foo: String) -> untyped").type.params | parse_method_type("(?foo: Integer) -> untyped").type.params # req:none - assert_equal parse_method_type("(?foo: String) -> untyped").params, - parse_method_type("(foo: String) -> untyped").params | parse_method_type("() -> untyped").params + assert_equal parse_method_type("(?foo: String) -> untyped").type.params, + parse_method_type("(foo: String) -> untyped").type.params | parse_method_type("() -> untyped").type.params # req:rest - assert_equal parse_method_type("(?foo: String | Integer) -> untyped").params, - parse_method_type("(foo: String) -> untyped").params | parse_method_type("(**Integer) -> untyped").params + assert_equal parse_method_type("(?foo: String | Integer) -> untyped").type.params, + parse_method_type("(foo: String) -> untyped").type.params | parse_method_type("(**Integer) -> untyped").type.params # opt:req - assert_equal parse_method_type("(?foo: String | Integer) -> untyped").params, - parse_method_type("(?foo: String) -> untyped").params | parse_method_type("(foo: Integer) -> untyped").params + assert_equal parse_method_type("(?foo: String | Integer) -> untyped").type.params, + parse_method_type("(?foo: String) -> untyped").type.params | parse_method_type("(foo: Integer) -> untyped").type.params # opt:opt - assert_equal parse_method_type("(?foo: String | Integer) -> untyped").params, - parse_method_type("(?foo: String) -> untyped").params | parse_method_type("(?foo: Integer) -> untyped").params + assert_equal parse_method_type("(?foo: String | Integer) -> untyped").type.params, + parse_method_type("(?foo: String) -> untyped").type.params | parse_method_type("(?foo: Integer) -> untyped").type.params # opt:none - assert_equal parse_method_type("(?foo: String) -> untyped").params, - parse_method_type("(?foo: String) -> untyped").params | parse_method_type("() -> untyped").params + assert_equal parse_method_type("(?foo: String) -> untyped").type.params, + parse_method_type("(?foo: String) -> untyped").type.params | parse_method_type("() -> untyped").type.params # opt:rest - assert_equal parse_method_type("(?foo: String | Integer) -> untyped").params, - parse_method_type("(?foo: String) -> untyped").params | parse_method_type("(**Integer) -> untyped").params + assert_equal parse_method_type("(?foo: String | Integer) -> untyped").type.params, + parse_method_type("(?foo: String) -> untyped").type.params | parse_method_type("(**Integer) -> untyped").type.params # none:req - assert_equal parse_method_type("(?foo: String) -> untyped").params, - parse_method_type("() -> untyped").params | parse_method_type("(foo: String) -> untyped").params + assert_equal parse_method_type("(?foo: String) -> untyped").type.params, + parse_method_type("() -> untyped").type.params | parse_method_type("(foo: String) -> untyped").type.params # none:opt - assert_equal parse_method_type("(?foo: Integer) -> untyped").params, - parse_method_type("() -> untyped").params | parse_method_type("(?foo: Integer) -> untyped").params + assert_equal parse_method_type("(?foo: Integer) -> untyped").type.params, + parse_method_type("() -> untyped").type.params | parse_method_type("(?foo: Integer) -> untyped").type.params # none:rest - assert_equal parse_method_type("(**Integer) -> untyped").params, - parse_method_type("() -> untyped").params | parse_method_type("(**Integer) -> untyped").params + assert_equal parse_method_type("(**Integer) -> untyped").type.params, + parse_method_type("() -> untyped").type.params | parse_method_type("(**Integer) -> untyped").type.params # rest:req - assert_equal parse_method_type("(?foo: String | Integer) -> untyped").params, - parse_method_type("(**String) -> untyped").params | parse_method_type("(foo: Integer) -> untyped").params + assert_equal parse_method_type("(?foo: String | Integer) -> untyped").type.params, + parse_method_type("(**String) -> untyped").type.params | parse_method_type("(foo: Integer) -> untyped").type.params # rest:opt - assert_equal parse_method_type("(?foo: String | Integer) -> untyped").params, - parse_method_type("(**String) -> untyped").params | parse_method_type("(?foo: Integer) -> untyped").params + assert_equal parse_method_type("(?foo: String | Integer) -> untyped").type.params, + parse_method_type("(**String) -> untyped").type.params | parse_method_type("(?foo: Integer) -> untyped").type.params # rest:none - assert_equal parse_method_type("(**String) -> untyped").params, - parse_method_type("(**String) -> untyped").params | parse_method_type("() -> untyped").params + assert_equal parse_method_type("(**String) -> untyped").type.params, + parse_method_type("(**String) -> untyped").type.params | parse_method_type("() -> untyped").type.params # rest:rest - assert_equal parse_method_type("(**String | Integer) -> untyped").params, - parse_method_type("(**String) -> untyped").params | parse_method_type("(**Integer) -> untyped").params + assert_equal parse_method_type("(**String | Integer) -> untyped").type.params, + parse_method_type("(**String) -> untyped").type.params | parse_method_type("(**Integer) -> untyped").type.params end end diff --git a/test/subtyping_test.rb b/test/subtyping_test.rb index 13ecaedd6..82db53255 100644 --- a/test/subtyping_test.rb +++ b/test/subtyping_test.rb @@ -822,4 +822,14 @@ def test_logic_type assert_success_check checker, type, "::FalseClass" end end + + def test_proc_type + with_checker do |checker| + assert_success_check checker, "^() { () -> void } -> void", "^() { () -> void } -> void" + assert_success_check checker, "^() { (::String) -> ::Object } -> void", "^() { (::Object) -> ::String } -> void" + + assert_fail_check checker, "^() { (::Object) -> void } -> void", "^() { (::String) -> void } -> void" + assert_fail_check checker, "^() { () -> ::String } -> void", "^() { () -> ::Object } -> void" + end + end end diff --git a/test/type_construction_test.rb b/test/type_construction_test.rb index 4559318d0..30fdc3083 100644 --- a/test/type_construction_test.rb +++ b/test/type_construction_test.rb @@ -6649,4 +6649,40 @@ def test_logic_or2 end end end + + def test_self_attributes + with_checker(< void } -> Array[String] +f = -> (n, &b) { b["foo"]; ["bar"] } +RUBY + + with_standard_construction(checker, source) do |construction, typing| + construction.synthesize(source.node) + + assert_no_error typing + end + end + end end diff --git a/test/type_factory_test.rb b/test/type_factory_test.rb index 1ee793b33..7a79860a2 100644 --- a/test/type_factory_test.rb +++ b/test/type_factory_test.rb @@ -28,6 +28,7 @@ def assert_overload_including(c, *types) end Types = Steep::AST::Types + Interface = Steep::Interface include TestHelper include FactoryHelper @@ -120,8 +121,19 @@ def test_type factory.type(parse_type("^(a, ?b, *c, d, x: e, ?y: f, **g) -> void")).yield_self do |type| assert_instance_of Types::Proc, type - assert_equal "(a, ?b, *c, x: e, ?y: f, **g)", type.params.to_s - assert_instance_of Types::Void, type.return_type + assert_equal "(a, ?b, *c, x: e, ?y: f, **g)", type.type.params.to_s + assert_instance_of Types::Void, type.type.return_type + end + + factory.type(parse_type("^() ?{ (Integer) -> void } -> void")).yield_self do |type| + assert_instance_of Types::Proc, type + + assert_equal "()", type.type.params.to_s + assert_instance_of Types::Void, type.type.return_type + + assert_instance_of Interface::Block, type.block + assert_predicate type.block, :optional? + assert_equal "?{ (Integer) -> void }", type.block.to_s end factory.type(RBS::Types::Variable.new(name: :T, location: nil)) do |type| @@ -511,6 +523,32 @@ def test_proc_type end end end + + factory.type(parse_type("^(String) { (Object) -> Symbol } -> Integer")).yield_self do |type| + factory.interface(type, private: false).yield_self do |interface| + assert_instance_of Steep::Interface::Interface, interface + + interface.methods[:call].yield_self do |entry| + assert_overload_with entry, "(String) { (Object) -> Symbol } -> Integer" + end + + refute_operator interface.methods, :key?, :[] + end + end + + factory.type(parse_type("^(String) ?{ (Object) -> Symbol } -> Integer")).yield_self do |type| + factory.interface(type, private: false).yield_self do |interface| + assert_instance_of Steep::Interface::Interface, interface + + interface.methods[:call].yield_self do |entry| + assert_overload_with entry, "(String) ?{ (Object) -> Symbol } -> Integer" + end + + interface.methods[:[]].yield_self do |entry| + assert_overload_with entry, "(String) -> Integer" + end + end + end end end @@ -606,23 +644,23 @@ def !: () -> bool assert_instance_of Steep::Interface::Interface, interface interface.methods[:is_a?].tap do |is_a| - assert_instance_of Types::Logic::ReceiverIsArg, is_a.method_types[0].return_type + assert_instance_of Types::Logic::ReceiverIsArg, is_a.method_types[0].type.return_type end interface.methods[:kind_of?].tap do |kind_of| - assert_instance_of Types::Logic::ReceiverIsArg, kind_of.method_types[0].return_type + assert_instance_of Types::Logic::ReceiverIsArg, kind_of.method_types[0].type.return_type end interface.methods[:instance_of?].tap do |instance_of| - assert_instance_of Types::Logic::ReceiverIsArg, instance_of.method_types[0].return_type + assert_instance_of Types::Logic::ReceiverIsArg, instance_of.method_types[0].type.return_type end interface.methods[:nil?].tap do |nilp| - assert_instance_of Types::Logic::ReceiverIsNil, nilp.method_types[0].return_type + assert_instance_of Types::Logic::ReceiverIsNil, nilp.method_types[0].type.return_type end interface.methods[:!].tap do |unot| - assert_instance_of Types::Logic::Not, unot.method_types[0].return_type + assert_instance_of Types::Logic::Not, unot.method_types[0].type.return_type end end end @@ -632,19 +670,19 @@ def !: () -> bool assert_instance_of Steep::Interface::Interface, interface interface.methods[:is_a?].tap do |is_a| - assert_instance_of Types::Boolean, is_a.method_types[0].return_type + assert_instance_of Types::Boolean, is_a.method_types[0].type.return_type end interface.methods[:kind_of?].tap do |kind_of| - assert_instance_of Types::Boolean, kind_of.method_types[0].return_type + assert_instance_of Types::Boolean, kind_of.method_types[0].type.return_type end interface.methods[:nil?].tap do |nilp| - assert_instance_of Types::Boolean, nilp.method_types[0].return_type + assert_instance_of Types::Boolean, nilp.method_types[0].type.return_type end interface.methods[:!].tap do |unot| - assert_instance_of Types::Boolean, unot.method_types[0].return_type + assert_instance_of Types::Boolean, unot.method_types[0].type.return_type end end end diff --git a/test/typing_test.rb b/test/typing_test.rb index b2e13787a..ff775cfd7 100644 --- a/test/typing_test.rb +++ b/test/typing_test.rb @@ -33,7 +33,7 @@ def test_1 typing = Steep::Typing.new(source: source, root_context: context) - type = parse_method_type("() -> String").return_type + type = parse_type("::String") typing.add_typing(node, type, context) @@ -47,7 +47,7 @@ def test_new_child_with_save typing = Steep::Typing.new(source: source, root_context: context) - type = parse_method_type("() -> String").return_type + type = parse_type("::String") typing.add_typing(node, type, context) @@ -71,7 +71,7 @@ def test_new_child_without_save typing = Steep::Typing.new(source: source, root_context: context) - type = parse_method_type("() -> String").return_type + type = parse_type("::String") typing.add_typing(node, type, context) @@ -93,7 +93,7 @@ def test_new_child_check typing = Steep::Typing.new(source: source, root_context: context) - type = parse_method_type("() -> String").return_type + type = parse_type("::String") typing.add_typing(node, type, context) @@ -113,7 +113,7 @@ def test_new_child_check2 typing = Steep::Typing.new(source: source, root_context: context) - type = parse_method_type("() -> String").return_type + type = parse_type("::String") child1 = typing.new_child(typing.contexts.range) child1.add_typing(node.children[0], type, context)