Skip to content

Commit

Permalink
Support optional keys in record types
Browse files Browse the repository at this point in the history
  • Loading branch information
soutaro committed Sep 15, 2024
1 parent 805a91b commit db5a1f9
Show file tree
Hide file tree
Showing 7 changed files with 134 additions and 25 deletions.
23 changes: 14 additions & 9 deletions lib/steep/ast/types/factory.rb
Original file line number Diff line number Diff line change
Expand Up @@ -112,10 +112,15 @@ def type(type)
when RBS::Types::Tuple
Tuple.new(types: type.types.map {|ty| type(ty) })
when RBS::Types::Record
elements = type.fields.each.with_object({}) do |(key, value), hash|
hash[key] = type(value)
elements = {} #: Hash[Record::key, AST::Types::t]
required_keys = Set[] #: Set[Record::key]

type.all_fields.each do |key, (value, required)|
required_keys << key if required
elements[key] = type(value)
end
Record.new(elements: elements)

Record.new(elements: elements, required_keys: required_keys)
when RBS::Types::Proc
func = Interface::Function.new(
params: params(type.type),
Expand Down Expand Up @@ -204,10 +209,12 @@ def type_1(type)
location: nil
)
when Record
fields = type.elements.each.with_object({}) do |(key, value), hash|
hash[key] = type_1(value)
all_fields = {} #: Hash[Symbol, [RBS::Types::t, bool]]
type.elements.each do |key, value|
raise unless key.is_a?(Symbol)
all_fields[key] = [type_1(value), type.required?(key)]
end
RBS::Types::Record.new(fields: fields, location: nil)
RBS::Types::Record.new(all_fields: all_fields, location: nil)
when Proc
block = if type.block
RBS::Types::Block.new(
Expand Down Expand Up @@ -523,9 +530,7 @@ def normalize_type(type)
types: type.types.map {|type| normalize_type(type) }
)
when AST::Types::Record
AST::Types::Record.new(
elements: type.elements.transform_values {|type| normalize_type(type) }
)
type.map_type {|type| normalize_type(type) }
when AST::Types::Tuple
AST::Types::Tuple.new(
types: type.types.map {|type| normalize_type(type) }
Expand Down
31 changes: 24 additions & 7 deletions lib/steep/ast/types/record.rb
Original file line number Diff line number Diff line change
Expand Up @@ -2,29 +2,37 @@ module Steep
module AST
module Types
class Record
attr_reader :elements
attr_reader :elements, :required_keys

def initialize(elements:)
def initialize(elements:, required_keys:)
@elements = elements
@required_keys = required_keys
end

def ==(other)
other.is_a?(Record) && other.elements == elements
other.is_a?(Record) && other.elements == elements && other.required_keys == required_keys
end

def hash
self.class.hash ^ elements.hash
self.class.hash ^ elements.hash ^ required_keys.hash
end

alias eql? ==

def subst(s)
self.class.new(elements: elements.transform_values {|type| type.subst(s) })
self.class.new(
elements: elements.transform_values {|type| type.subst(s) },
required_keys: required_keys
)
end

def to_s
strings = elements.keys.sort.map do |key|
"#{key.inspect} => #{elements[key]}"
if optional?(key)
"?#{key.inspect} => #{elements[key]}"
else
"#{key.inspect} => #{elements[key]}"
end
end
"{ #{strings.join(", ")} }"
end
Expand All @@ -49,13 +57,22 @@ def each_child(&block)

def map_type(&block)
self.class.new(
elements: elements.transform_values(&block)
elements: elements.transform_values(&block),
required_keys: required_keys
)
end

def level
[0] + level_of_children(elements.values)
end

def required?(key)
required_keys.include?(key)
end

def optional?(key)
!required_keys.include?(key)
end
end
end
end
Expand Down
4 changes: 4 additions & 0 deletions lib/steep/interface/builder.rb
Original file line number Diff line number Diff line change
Expand Up @@ -599,6 +599,10 @@ def record_shape(record)
overloads: record.elements.map do |key_value, value_type|
key_type = AST::Types::Literal.new(value: key_value)

if record.optional?(key_value)
value_type = AST::Builtin.optional(value_type)
end

Shape::MethodOverload.new(
MethodType.new(
type_params: [],
Expand Down
18 changes: 12 additions & 6 deletions lib/steep/subtyping/check.rb
Original file line number Diff line number Diff line change
Expand Up @@ -497,13 +497,19 @@ def check_type0(relation)
when relation.sub_type.is_a?(AST::Types::Record) && relation.super_type.is_a?(AST::Types::Record)
All(relation) do |result|
relation.super_type.elements.each_key do |key|
rel = Relation.new(
sub_type: relation.sub_type.elements[key] || AST::Builtin.nil_type,
super_type: relation.super_type.elements[key]
)
super_element_type = relation.super_type.elements[key]

result.add(rel) do
check_type(rel)
if relation.sub_type.elements.key?(key)
sub_element_type = relation.sub_type.elements[key]
else
if relation.super_type.required?(key)
sub_element_type = AST::Builtin.nil_type
end
end

if sub_element_type
rel = Relation.new(sub_type: sub_element_type, super_type: super_element_type)
result.add(rel) { check_type(rel) }
end
end
end
Expand Down
4 changes: 2 additions & 2 deletions lib/steep/type_construction.rb
Original file line number Diff line number Diff line change
Expand Up @@ -4962,7 +4962,7 @@ def type_hash_record(hash_node, record_type)
each_child_node(hash_node) do |child|
if child.type == :pair
case child.children[0].type
when :sym, :str, :int
when :sym
key_node = child.children[0] #: Parser::AST::Node
value_node = child.children[1] #: Parser::AST::Node

Expand Down Expand Up @@ -4991,7 +4991,7 @@ def type_hash_record(hash_node, record_type)
end
end

type = AST::Types::Record.new(elements: elems)
type = AST::Types::Record.new(elements: elems, required_keys: record_type&.required_keys || Set.new(elems.keys))
constr.add_typing(hash_node, type: type)
end

Expand Down
8 changes: 7 additions & 1 deletion sig/steep/ast/types/record.rbs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@ module Steep

attr_reader elements: Hash[key, t]

def initialize: (elements: Hash[key, t]) -> void
attr_reader required_keys: Set[key]

def initialize: (elements: Hash[key, t], required_keys: Set[key]) -> void

def ==: (untyped other) -> bool

Expand All @@ -29,6 +31,10 @@ module Steep
def map_type: () { (t) -> t } -> Record

def level: () -> Array[Integer]

def required?: (key) -> bool

def optional?: (key) -> bool
end
end
end
Expand Down
71 changes: 71 additions & 0 deletions test/type_check_test.rb
Original file line number Diff line number Diff line change
Expand Up @@ -2223,4 +2223,75 @@ def foo(x)
YAML
)
end

def test_record__optional_key__assignment
run_type_check_test(
signatures: {},
code: {
"a.rb" => <<~RUBY
# @type var record: { id: Integer, ?name: String }
record = { id: 123, name: "Hello" }
record = { id: 123 }
record = { id: 123, name: 123 }
RUBY
},
expectations: <<~YAML
---
- file: a.rb
diagnostics:
- range:
start:
line: 6
character: 0
end:
line: 6
character: 31
severity: ERROR
message: |-
Cannot assign a value of type `{ :id => ::Integer, ?:name => ::Integer }` to a variable of type `{ :id => ::Integer, ?:name => ::String }`
{ :id => ::Integer, ?:name => ::Integer } <: { :id => ::Integer, ?:name => ::String }
::Integer <: ::String
::Numeric <: ::String
::Object <: ::String
::BasicObject <: ::String
code: Ruby::IncompatibleAssignment
YAML
)
end

def test_record__optional_key__get
run_type_check_test(
signatures: {},
code: {
"a.rb" => <<~RUBY
# @type var record: { id: Integer, ?name: String }
record = _ = nil
record[:id] + 1
record[:name] + ""
record.fetch(:id) + 1
record.fetch(:name) + ""
RUBY
},
expectations: <<~YAML
---
- file: a.rb
diagnostics:
- range:
start:
line: 6
character: 14
end:
line: 6
character: 15
severity: ERROR
message: Type `(::String | nil)` does not have method `+`
code: Ruby::NoMethod
YAML
)
end
end

0 comments on commit db5a1f9

Please sign in to comment.