Skip to content

Commit

Permalink
parser: enable record types with optional fields
Browse files Browse the repository at this point in the history
this enables richer definitions of record types, which can only be defined nowadays with Hash[Symbol, untyped].
  • Loading branch information
HoneyryderChuck committed Jan 12, 2024
1 parent 3272daf commit 864611b
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 24 deletions.
27 changes: 20 additions & 7 deletions ext/rbs_extension/parser.c
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,11 @@ typedef struct {
VALUE rest_keywords;
} method_params;

typedef struct {
VALUE fields;
VALUE optional_fields;
} record_fields;

// /**
// * Returns RBS::Location object of `current_token` of a parser state.
// *
Expand Down Expand Up @@ -698,15 +703,22 @@ static VALUE parse_proc_type(parserstate *state) {
| {} literal_type `=>` <type>
*/
VALUE parse_record_attributes(parserstate *state) {
VALUE hash = rb_hash_new();
VALUE fields = rb_hash_new();

if (state->next_token.type == pRBRACE) {
return hash;
return fields;
}

while (true) {
VALUE key;
VALUE type;
VALUE key, type,
value = rb_ary_new(),
required = Qtrue;

if (state->next_token.type == pQUESTION) {
// { ?foo: type } syntax
required = Qfalse;
parser_advance(state);
}

if (is_keyword(state)) {
// { foo: type } syntax
Expand Down Expand Up @@ -735,7 +747,9 @@ VALUE parse_record_attributes(parserstate *state) {
parser_advance_assert(state, pFATARROW);
}
type = parse_type(state);
rb_hash_aset(hash, key, type);
rb_ary_push(value, type);
rb_ary_push(value, required);
rb_hash_aset(fields, key, value);

if (parser_advance_if(state, pCOMMA)) {
if (state->next_token.type == pRBRACE) {
Expand All @@ -745,8 +759,7 @@ VALUE parse_record_attributes(parserstate *state) {
break;
}
}

return hash;
return fields;
}

/*
Expand Down
5 changes: 2 additions & 3 deletions ext/rbs_extension/ruby_objs.c
Original file line number Diff line number Diff line change
Expand Up @@ -234,10 +234,10 @@ VALUE rbs_literal(VALUE literal, VALUE location) {
);
}

VALUE rbs_record(VALUE fields, VALUE location) {
VALUE rbs_record(VALUE fields,VALUE location) {
VALUE args = rb_hash_new();
rb_hash_aset(args, ID2SYM(rb_intern("location")), location);
rb_hash_aset(args, ID2SYM(rb_intern("fields")), fields);
rb_hash_aset(args, ID2SYM(rb_intern("all_fields")), fields);

return CLASS_NEW_INSTANCE(
RBS_Types_Record,
Expand Down Expand Up @@ -588,4 +588,3 @@ VALUE rbs_ast_directives_use_wildcard_clause(VALUE namespace, VALUE location) {

return CLASS_NEW_INSTANCE(RBS_AST_Directives_Use_WildcardClause, 1, &kwargs);
}

57 changes: 44 additions & 13 deletions lib/rbs/types.rb
Original file line number Diff line number Diff line change
Expand Up @@ -515,73 +515,104 @@ def with_nonreturn_void?
end

class Record
attr_reader :fields
attr_reader :all_fields
attr_reader :location

def initialize(fields:, location:)
@fields = fields
def initialize(all_fields: nil, fields: nil, location:)
if (all_fields && fields) || (all_fields.nil? && fields.nil?)
raise ArgumentError, "only one of `:fields` or `:all_fields` is requireds"
end

if fields
@all_fields = fields.map { |k, v| [k, [v, true]] }.to_h
@fields = fields
else
@all_fields = all_fields
@fields = nil
end

@location = location
@optional_fields = nil
end

def fields
@fields ||= all_fields.filter_map { |k, (v, required)| [k, v] if required }.to_h
end

def optional_fields
return if all_fields.size == fields.size

@optional_fields ||= all_fields.filter_map { |k, (v, required)| [k, v] unless required }.to_h
end

def ==(other)
other.is_a?(Record) && other.fields == fields
other.is_a?(Record) && other.fields == fields && other.optional_fields == optional_fields
end

alias eql? ==

def hash
self.class.hash ^ fields.hash
self.class.hash ^ all_fields.hash
end

def free_variables(set = Set.new)
set.tap do
fields.each_value do |type|
type.free_variables set
end
optional_fields&.each_value do |type|
type.free_variables set
end
end
end

def to_json(state = _ = nil)
{ class: :record, fields: fields, location: location }.to_json(state)
{ class: :record, fields: fields, optional_fields: optional_fields, location: location }.to_json(state)
end

def sub(s)
self.class.new(fields: fields.transform_values {|ty| ty.sub(s) },
location: location)
self.class.new(
all_fields: all_fields.transform_values {|ty, required| [ty.sub(s), required] },
location: location
)
end

def to_s(level = 0)
return "{ }" if self.fields.empty?
return "{ }" if all_fields.empty?

fields = self.fields.map do |key, type|
if key.is_a?(Symbol) && key.match?(/\A[A-Za-z_][A-Za-z_0-9]*\z/)
fields = all_fields.map do |key, (type, required)|
field = if key.is_a?(Symbol) && key.match?(/\A[A-Za-z_][A-Za-z_0-9]*\z/)
"#{key}: #{type}"
else
"#{key.inspect} => #{type}"
end

field = "?#{field}" unless required
field
end
"{ #{fields.join(", ")} }"
end

def each_type(&block)
if block
fields.each_value(&block)
optional_fields&.each_value(&block)
else
enum_for :each_type
end
end

def map_type_name(&block)
Record.new(
fields: fields.transform_values {|ty| ty.map_type_name(&block) },
all_fields: all_fields.transform_values {|ty, required| [ty.map_type_name(&block), required] },
location: location
)
end

def map_type(&block)
if block
Record.new(
fields: fields.transform_values {|type| yield type },
all_fields: all_fields.transform_values {|type, _| yield type },
location: location
)
else
Expand Down
7 changes: 6 additions & 1 deletion sig/types.rbs
Original file line number Diff line number Diff line change
Expand Up @@ -301,16 +301,21 @@ module RBS
end

class Record
attr_reader fields: Hash[Symbol, t]
attr_reader all_fields: Hash[Symbol, [t, bool]]

type loc = Location[bot, bot]

def initialize: (fields: Hash[Symbol, t], location: loc?) -> void
| (all_fields: Hash[Symbol, [t, bool]], location: loc?) -> void

include _TypeBase

attr_reader location: loc?

def fields: () -> Hash[Symbol, t]

def optional_fields: () -> Hash[Symbol, t]?

def map_type: () { (t) -> t } -> Record
| () -> Enumerator[t, Record]
end
Expand Down
9 changes: 9 additions & 0 deletions test/rbs/type_parsing_test.rb
Original file line number Diff line number Diff line change
Expand Up @@ -670,6 +670,15 @@ def test_record
end

def test_record_with_optional_key
Parser.parse_type("{ ?foo: untyped }").yield_self do |type|
assert_instance_of Types::Record, type
assert_equal({}, type.fields)
assert_equal({
foo: Types::Bases::Any.new(location: nil),
}, type.optional_fields)
assert_equal "{ ?foo: untyped }", type.location.source
end

error = assert_raises(RBS::ParsingError) do
Parser.parse_type("{ 1?: untyped }")
end
Expand Down

0 comments on commit 864611b

Please sign in to comment.