Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support singleton class definition #166

Merged
merged 3 commits into from
Jul 17, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
183 changes: 131 additions & 52 deletions lib/steep/type_construction.rb
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,39 @@ def for_new_method(method_name, node, args:, self_type:, definition:)
)
end

def implement_module(module_name:, super_name: nil, annotations:)
if (annotation = annotations.implement_module_annotation)
absolute_name(annotation.name.name).yield_self do |absolute_name|
if checker.factory.class_name?(absolute_name) || checker.factory.module_name?(absolute_name)
AST::Annotation::Implements::Module.new(
name: absolute_name,
args: annotation.name.args
)
else
Steep.logger.error "Unknown class name given to @implements: #{annotation.name.name}"
nil
end
end
else
name = nil
name ||= absolute_name(module_name).yield_self do |absolute_name|
absolute_name if checker.factory.class_name?(absolute_name) || checker.factory.module_name?(absolute_name)
end
name ||= super_name && absolute_name(super_name).yield_self do |absolute_name|
absolute_name if checker.factory.class_name?(absolute_name) || checker.factory.module_name?(absolute_name)
end

if name
absolute_name_ = checker.factory.type_name_1(name)
entry = checker.factory.env.class_decls[absolute_name_]
AST::Annotation::Implements::Module.new(
name: name,
args: entry.type_params.each.map(&:name)
)
end
end
end

def for_module(node)
new_module_name = Names::Module.from_node(node.children.first) or raise "Unexpected module name: #{node.children.first}"
new_namespace = nested_namespace_for_module(new_module_name)
Expand All @@ -240,28 +273,7 @@ def for_module(node)
annots = source.annotations(block: node, factory: checker.factory, current_module: new_namespace)
module_type = AST::Builtin::Module.instance_type

implement_module_name = yield_self do
if (annotation = annots.implement_module_annotation)
absolute_name(annotation.name.name).yield_self do |absolute_name|
if checker.factory.module_name?(absolute_name)
AST::Annotation::Implements::Module.new(name: absolute_name,
args: annotation.name.args)
else
Steep.logger.error "Unknown module name given to @implements: #{annotation.name.name}"
nil
end
end
else
absolute_name(new_module_name).yield_self do |absolute_name|
if checker.factory.module_name?(absolute_name)
absolute_name_ = checker.factory.type_name_1(absolute_name)
entry = checker.factory.env.class_decls[absolute_name_]
AST::Annotation::Implements::Module.new(name: absolute_name,
args: entry.type_params.each.map(&:name))
end
end
end
end
implement_module_name = implement_module(module_name: new_module_name, annotations: annots)

if implement_module_name
module_name = implement_module_name.name
Expand Down Expand Up @@ -352,36 +364,7 @@ def for_class(node)

annots = source.annotations(block: node, factory: checker.factory, current_module: new_namespace)

implement_module_name = yield_self do
if (annotation = annots.implement_module_annotation)
absolute_name(annotation.name.name).yield_self do |absolute_name|
if checker.factory.class_name?(absolute_name)
AST::Annotation::Implements::Module.new(name: absolute_name,
args: annotation.name.args)
else
Steep.logger.error "Unknown class name given to @implements: #{annotation.name.name}"
nil
end
end
else
name = nil
name ||= absolute_name(new_class_name).yield_self do |absolute_name|
absolute_name if checker.factory.class_name?(absolute_name)
end
name ||= super_class_name && absolute_name(super_class_name).yield_self do |absolute_name|
absolute_name if checker.factory.class_name?(absolute_name)
end

if name
absolute_name_ = checker.factory.type_name_1(name)
entry = checker.factory.env.class_decls[absolute_name_]
AST::Annotation::Implements::Module.new(
name: name,
args: entry.type_params.each.map(&:name)
)
end
end
end
implement_module_name = implement_module(module_name: new_class_name, super_name: super_class_name, annotations: annots)

if annots.implement_module_annotation
new_class_name = implement_module_name.name
Expand Down Expand Up @@ -450,6 +433,83 @@ def for_class(node)
)
end

def for_sclass(node, type)
annots = source.annotations(block: node, factory: checker.factory, current_module: current_namespace)

instance_type = if type.is_a?(AST::Types::Self)
context.self_type
else
type
end

module_type = case instance_type
when AST::Types::Name::Class
AST::Builtin::Class.instance_type
when AST::Types::Name::Module
AST::Builtin::Module.instance_type
when AST::Types::Name::Instance
instance_type.to_class(constructor: nil)
else
raise "Unexpected type for sclass node: #{type}"
end

instance_definition = case instance_type
when AST::Types::Name::Class, AST::Types::Name::Module
type_name = checker.factory.type_name_1(instance_type.name)
checker.factory.definition_builder.build_singleton(type_name)
when AST::Types::Name::Instance
type_name = checker.factory.type_name_1(instance_type.name)
checker.factory.definition_builder.build_instance(type_name)
end

module_definition = case module_type
when AST::Types::Name::Class, AST::Types::Name::Module
type_name = checker.factory.type_name_1(instance_type.name)
checker.factory.definition_builder.build_singleton(type_name)
else
nil
end

module_context = TypeInference::Context::ModuleContext.new(
instance_type: annots.instance_type || instance_type,
module_type: annots.self_type || annots.module_type || module_type,
implement_name: nil,
current_namespace: current_namespace,
const_env: self.module_context.const_env,
class_name: self.module_context.class_name,
module_definition: module_definition,
instance_definition: instance_definition
)

type_env = TypeInference::TypeEnv.build(annotations: annots,
subtyping: checker,
const_env: self.module_context.const_env,
signatures: checker.factory.env)

lvar_env = TypeInference::LocalVariableTypeEnv.empty(
subtyping: checker,
self_type: module_context.module_type
).annotate(annots)

body_context = TypeInference::Context.new(
method_context: nil,
block_context: nil,
module_context: module_context,
break_context: nil,
self_type: module_context.module_type,
type_env: type_env,
lvar_env: lvar_env
)

self.class.new(
checker: checker,
source: source,
annotations: annots,
typing: typing,
context: body_context
)
end

def for_branch(node, truthy_vars: Set.new, type_case_override: nil, break_context: context.break_context)
annots = source.annotations(block: node, factory: checker.factory, current_module: current_namespace)

Expand Down Expand Up @@ -1130,6 +1190,25 @@ def synthesize(node, hint: nil)
add_typing(node, type: AST::Builtin.nil_type)
end

when :sclass
yield_self do
type, constr = synthesize(node.children[0])
constructor = constr.for_sclass(node, type)

constructor.typing.add_context_for_node(node, context: constructor.context)
constructor.typing.add_context_for_body(node, context: constructor.context)

constructor.synthesize(node.children[1]) if node.children[1]

if constructor.module_context.instance_definition && module_context.module_definition
if constructor.module_context.instance_definition.type_name == module_context.module_definition.type_name
module_context.defined_module_methods.merge(constructor.module_context.defined_instance_methods)
end
end

add_typing(node, type: AST::Builtin.nil_type)
end

when :self
add_typing node, type: AST::Types::Self.new

Expand Down
6 changes: 6 additions & 0 deletions lib/steep/typing.rb
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,12 @@ def add_context_for_body(node, context:)
end_pos = node.loc.end.begin_pos
add_context(begin_pos..end_pos, context: context)

when :sclass
name_node = node.children[0]
begin_pos = name_node.loc.expression.end_pos
end_pos = node.loc.end.begin_pos
add_context(begin_pos..end_pos, context: context)

when :def, :defs
args_node = case node.type
when :def
Expand Down
120 changes: 120 additions & 0 deletions test/type_construction_test.rb
Original file line number Diff line number Diff line change
Expand Up @@ -4635,4 +4635,124 @@ class B
end
end
end

def test_singleton_class_in_class_decl
with_checker <<-RBS do |checker|
class WithSingleton
def self.open: [A] { () -> A } -> A
end
RBS
source = parse_ruby(<<'EOF')
class WithSingleton
class <<self
def open
yield new()
end
end
end
EOF

with_standard_construction(checker, source) do |construction, typing|
class_constr = construction.for_class(source.node)
type, _ = class_constr.synthesize(dig(source.node, 2, 0))
sclass_constr = class_constr.for_sclass(dig(source.node, 2), type)

module_context = sclass_constr.context.module_context

assert_equal parse_type("singleton(::WithSingleton)"), module_context.instance_type
assert_equal parse_type("::Class"), module_context.module_type
assert_equal "::WithSingleton", module_context.class_name.to_s
assert_nil module_context.implement_name
assert_nil module_context.module_definition
assert_equal "::WithSingleton", module_context.instance_definition.type_name.to_s

construction.synthesize(source.node)

assert_no_error typing
end
end
end

def test_singleton_class_in_class_decl_error
with_checker <<-RBS do |checker|
class WithSingleton
def self.open: [A] { (instance) -> A } -> A
end
RBS
source = parse_ruby(<<'EOF')
class WithSingleton
class <<self
def open
yield 30
end
end
end
EOF

with_standard_construction(checker, source) do |construction, typing|
construction.synthesize(source.node)

assert_equal 1, typing.errors.size
assert_instance_of Steep::Errors::IncompatibleAssignment, typing.errors[0]
end
end
end

def test_singleton_class_for_object_success
with_checker <<-'RBS' do |checker|
class WithSingleton
def open: [A] { () -> A } -> A
end
RBS
source = parse_ruby(<<-'RUBY')
class <<(WithSingleton.new)
def open
yield new()
end
end
RUBY

with_standard_construction(checker, source) do |construction, typing|
type, _ = construction.synthesize(dig(source.node, 0))
sclass_constr = construction.for_sclass(dig(source.node), type)

module_context = sclass_constr.context.module_context

assert_equal parse_type("::WithSingleton"), module_context.instance_type
assert_equal parse_type("singleton(::WithSingleton)"), module_context.module_type
assert_nil module_context.class_name
assert_nil module_context.implement_name
assert_equal "::WithSingleton", module_context.module_definition.type_name.to_s
assert_equal "::WithSingleton", module_context.instance_definition.type_name.to_s

construction.synthesize(source.node)

assert_no_error typing
end
end
end

def test_singleton_class_for_object_type_check
with_checker <<-'RBS' do |checker|
class WithSingleton
def open: [A] { () -> A } -> A
end
RBS
source = parse_ruby(<<-'RUBY')
class <<(WithSingleton.new)
def open(x)
x
end
end
RUBY

with_standard_construction(checker, source) do |construction, typing|
construction.synthesize(source.node)

assert_equal 2, typing.errors.size
assert_instance_of Steep::Errors::MethodArityMismatch, typing.errors[0]
assert_instance_of Steep::Errors::FallbackAny, typing.errors[1]
end
end
end
end