diff --git a/spec/lang/code_gen/local_type_spec.lua b/spec/lang/code_gen/local_type_spec.lua index 1fa2eab6e..86036871f 100644 --- a/spec/lang/code_gen/local_type_spec.lua +++ b/spec/lang/code_gen/local_type_spec.lua @@ -239,4 +239,43 @@ describe("local type code generation", function() ]]) end) + it("always elides local type require used as a variable, even if incorrect use of interfaces or aliases", util.gen([[ + local interface IFoo + end + + local type Alias = IFoo + + local record Foo is IFoo + end + + local function register(_id:any, _value:any) + end + + local foo:Foo + + register(IFoo, foo) + + register(Alias, foo) + ]], [[ + + + + + + + + + local function register(_id, _value) + end + + local foo + + register(IFoo, foo) + + register(Alias, foo) + ]], nil, { + { y = 14, msg = "interfaces are abstract" }, + { y = 16, msg = "interfaces are abstract" }, + })) + end) diff --git a/spec/util.lua b/spec/util.lua index 95b3d1b62..3b3dc30aa 100644 --- a/spec/util.lua +++ b/spec/util.lua @@ -599,13 +599,22 @@ function util.check_types(code, types) end end -local function gen(lax, code, expected, gen_target) +local function gen(lax, code, expected, gen_target, type_errors) return function() local ast, syntax_errors = tl.parse(code, "foo.tl") assert.same({}, syntax_errors, "Code was not expected to have syntax errors") local gen_compat = gen_target == "5.4" and "off" or nil local result = tl.check(ast, "foo.tl", { feat_lax = lax and "on" or "off", gen_target = gen_target, gen_compat = gen_compat }) - assert.same({}, result.type_errors) + + if type_errors then + local batch = batch_assertions() + local result_type_errors = combine_result(result, "type_errors") + batch_compare(batch, "type errors", type_errors, result_type_errors) + batch:assert() + else + assert.same({}, result.type_errors) + end + local output_code = tl.pretty_print_ast(ast, gen_target) local expected_ast, expected_errors = tl.parse(expected, "foo.tl") @@ -616,11 +625,11 @@ local function gen(lax, code, expected, gen_target) end end -function util.gen(code, expected, gen_target) +function util.gen(code, expected, gen_target, type_errors) assert(type(code) == "string") assert(type(expected) == "string") - return gen(false, code, expected, gen_target) + return gen(false, code, expected, gen_target, type_errors) end function util.run_check_type_error(...) diff --git a/tl.lua b/tl.lua index c9ce6dea2..211e87312 100644 --- a/tl.lua +++ b/tl.lua @@ -6341,6 +6341,20 @@ function Errors:add_prefixing(w, src, prefix, dst) end end +local function ensure_not_abstract(t) + if t.typename == "function" and t.macroexp then + return nil, "macroexps are abstract; consider using a concrete function" + elseif t.typename == "typedecl" then + local def = t.def + if def.typename == "interface" then + return nil, "interfaces are abstract; consider using a concrete record" + elseif not (def.typename == "record") then + return nil, "cannot use a type definition as a concrete value" + end + end + return true +end + @@ -6368,7 +6382,9 @@ local function check_for_unused_vars(scope, is_global) end elseif var.used and t.typename == "typedecl" and var.aliasing then var.aliasing.used = true - var.aliasing.declared_at.elide_type = false + if ensure_not_abstract(t) then + var.aliasing.declared_at.elide_type = false + end end end if list then @@ -7298,20 +7314,6 @@ do end end - local function ensure_not_abstract(t) - if t.typename == "function" and t.macroexp then - return nil, "macroexps are abstract; consider using a concrete function" - elseif t.typename == "typedecl" then - local def = t.def - if def.typename == "interface" then - return nil, "interfaces are abstract; consider using a concrete record" - elseif not (def.typename == "record") then - return nil, "cannot use a type definition as a concrete value" - end - end - return true - end - local function ensure_not_method(t) if t.typename == "function" and t.is_method then @@ -7768,10 +7770,9 @@ do return ret end - function TypeChecker:add_to_scope(node, name, t, attribute, narrow, dont_check_redeclaration) - local scope = self.st[#self.st] - local var = scope.vars[name] - if narrow then + do + local function narrow_var(scope, node, name, t, attribute, narrow) + local var = scope.vars[name] if var then if var.is_narrowed then var.t = t @@ -7792,42 +7793,44 @@ do return var end - if not dont_check_redeclaration and - node and - name ~= "self" and - name ~= "..." and - name:sub(1, 1) ~= "@" then - - self:check_if_redeclaration(name, node, t) - end + function TypeChecker:add_var(node, name, t, attribute, narrow) + if self.feat_lax and node and is_unknown(t) and (name ~= "self" and name ~= "...") and not narrow then + self.errs:add_unknown(node, name) + end + if not attribute then + t = drop_constant_value(t) + end - if var and not var.used then + if self.collector and node then + self.collector.add_to_symbol_list(node, name, t) + end + local scope = self.st[#self.st] + if narrow then + return narrow_var(scope, node, name, t, attribute, narrow) + end - self.errs:unused_warning(name, var) - end + if node then + if name ~= "self" and name ~= "..." and name:sub(1, 1) ~= "@" then + self:check_if_redeclaration(name, node, t) + end + if not ensure_not_abstract(t) then + node.elide_type = true + end + end - var = { t = t, attribute = attribute, is_narrowed = nil, declared_at = node } - scope.vars[name] = var + local var = scope.vars[name] + if var and not var.used then - return var - end - function TypeChecker:add_var(node, name, t, attribute, narrow, dont_check_redeclaration) - if self.feat_lax and node and is_unknown(t) and (name ~= "self" and name ~= "...") and not narrow then - self.errs:add_unknown(node, name) - end - if not attribute then - t = drop_constant_value(t) - end + self.errs:unused_warning(name, var) + end - local var = self:add_to_scope(node, name, t, attribute, narrow, dont_check_redeclaration) + var = { t = t, attribute = attribute, declared_at = node } + scope.vars[name] = var - if self.collector and node then - self.collector.add_to_symbol_list(node, name, t) + return var end - - return var end diff --git a/tl.tl b/tl.tl index 46b8c5120..3af6c1ee2 100644 --- a/tl.tl +++ b/tl.tl @@ -6341,6 +6341,20 @@ function Errors:add_prefixing(w: Where, src: {Error}, prefix: string, dst?: {Err end end +local function ensure_not_abstract(t: Type): boolean, string + if t is FunctionType and t.macroexp then + return nil, "macroexps are abstract; consider using a concrete function" + elseif t is TypeDeclType then + local def = t.def + if def is InterfaceType then + return nil, "interfaces are abstract; consider using a concrete record" + elseif not def is RecordType then + return nil, "cannot use a type definition as a concrete value" + end + end + return true +end + local record Unused y: integer x: integer @@ -6368,7 +6382,9 @@ local function check_for_unused_vars(scope: Scope, is_global?: boolean): {Unused end elseif var.used and t is TypeDeclType and var.aliasing then var.aliasing.used = true - var.aliasing.declared_at.elide_type = false + if ensure_not_abstract(t) then + var.aliasing.declared_at.elide_type = false + end end end if list then @@ -7298,20 +7314,6 @@ do end end - local function ensure_not_abstract(t: Type): boolean, string - if t is FunctionType and t.macroexp then - return nil, "macroexps are abstract; consider using a concrete function" - elseif t is TypeDeclType then - local def = t.def - if def is InterfaceType then - return nil, "interfaces are abstract; consider using a concrete record" - elseif not def is RecordType then - return nil, "cannot use a type definition as a concrete value" - end - end - return true - end - local function ensure_not_method(t: Type): Type if t is FunctionType and t.is_method then @@ -7768,10 +7770,9 @@ do return ret end - function TypeChecker:add_to_scope(node: Node, name: string, t: Type, attribute: Attribute, narrow: Narrow, dont_check_redeclaration: boolean): Variable - local scope = self.st[#self.st] - local var = scope.vars[name] - if narrow then + do + local function narrow_var(scope: Scope, node: Node, name: string, t: Type, attribute: Attribute, narrow: Narrow): Variable + local var = scope.vars[name] if var then if var.is_narrowed then var.t = t @@ -7792,42 +7793,44 @@ do return var end - if not dont_check_redeclaration - and node - and name ~= "self" - and name ~= "..." - and name:sub(1, 1) ~= "@" - then - self:check_if_redeclaration(name, node, t) - end + function TypeChecker:add_var(node: Node, name: string, t: Type, attribute?: Attribute, narrow?: Narrow): Variable + if self.feat_lax and node and is_unknown(t) and (name ~= "self" and name ~= "...") and not narrow then + self.errs:add_unknown(node, name) + end + if not attribute then + t = drop_constant_value(t) + end - if var and not var.used then - -- the old var is removed from the scope and won't be checked when it closes, - -- so check it here - self.errs:unused_warning(name, var) - end + if self.collector and node then + self.collector.add_to_symbol_list(node, name, t) + end - var = { t = t, attribute = attribute, is_narrowed = nil, declared_at = node } - scope.vars[name] = var + local scope = self.st[#self.st] + if narrow then + return narrow_var(scope, node, name, t, attribute, narrow) + end - return var - end + if node then + if name ~= "self" and name ~= "..." and name:sub(1, 1) ~= "@" then + self:check_if_redeclaration(name, node, t) + end + if not ensure_not_abstract(t) then + node.elide_type = true + end + end - function TypeChecker:add_var(node: Node, name: string, t: Type, attribute?: Attribute, narrow?: Narrow, dont_check_redeclaration?: boolean): Variable - if self.feat_lax and node and is_unknown(t) and (name ~= "self" and name ~= "...") and not narrow then - self.errs:add_unknown(node, name) - end - if not attribute then - t = drop_constant_value(t) - end + local var = scope.vars[name] + if var and not var.used then + -- the old var is removed from the scope and won't be checked when it closes, + -- so check it here + self.errs:unused_warning(name, var) + end - local var = self:add_to_scope(node, name, t, attribute, narrow, dont_check_redeclaration) + var = { t = t, attribute = attribute, declared_at = node } + scope.vars[name] = var - if self.collector and node then - self.collector.add_to_symbol_list(node, name, t) + return var end - - return var end local type CompareTypes = function(TypeChecker, Type, Type): boolean, {Error}