From 58701af56f9b549868743e92387054f20cd46f16 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Sat, 21 Oct 2023 16:29:40 -0300 Subject: [PATCH 001/224] make type system more nominal, except for unions Instead of treating nominal records nominally and all other nominal types structurally, with this commit we treat all nominal types nominally except for unions, which are treated structurally. Give this branch a try in your codebase and let me know your impressions! --- spec/declaration/record_spec.lua | 2 +- tl.lua | 22 +++++++++++++--------- tl.tl | 22 +++++++++++++--------- 3 files changed, 27 insertions(+), 19 deletions(-) diff --git a/spec/declaration/record_spec.lua b/spec/declaration/record_spec.lua index 32f6d677a..3275fa440 100644 --- a/spec/declaration/record_spec.lua +++ b/spec/declaration/record_spec.lua @@ -702,7 +702,7 @@ for i, name in ipairs({"records", "arrayrecords"}) do end function Foo.new(): Foo - return setmetatable({}, Foo) -- typing of arguments is being very permissive here, may change in the future and require a cast + return setmetatable({}, Foo as metatable) end local foo = Foo.new() diff --git a/tl.lua b/tl.lua index 19cca66af..e4d69f53e 100644 --- a/tl.lua +++ b/tl.lua @@ -1432,7 +1432,7 @@ end local function new_node(tokens, i, kind) local t = tokens[i] - return { y = t.y, x = t.x, tk = t.tk, kind = kind or t.kind } + return { y = t.y, x = t.x, tk = t.tk, kind = kind or (t.kind) } end local function a_type(t) @@ -7219,17 +7219,13 @@ tl.type_check = function(ast, opts) end return false, terr(t1, "cannot match against any alternatives of the polymorphic type") elseif t1.typename == "nominal" and t2.typename == "nominal" then - local same, err = are_same_nominals(t1, t2) - if same then - return true - end local t1r = resolve_tuple_and_nominal(t1) local t2r = resolve_tuple_and_nominal(t2) - if is_record_type(t1r) and is_record_type(t2r) then - return same, err - else + if t1r.typename == "union" or t2r.typename == "union" then return is_a(t1r, t2r, for_equality) end + + return are_same_nominals(t1, t2) elseif t1.typename == "enum" and t2.typename == "string" then local ok if for_equality then @@ -9212,7 +9208,7 @@ tl.type_check = function(ast, opts) local infertype = infertypes[i] local rt = resolve_tuple_and_nominal(t) - if rt.typename ~= "enum" and not same_type(t, infertype) then + if rt.typename ~= "enum" and (t.typename ~= "nominal" or rt.typename == "union") and not same_type(t, infertype) then add_var(where, var.tk, infer_at(where, infertype), "const", "narrowed_declaration") end end @@ -10204,6 +10200,14 @@ tl.type_check = function(ast, opts) end end + if orig_a.typename == "nominal" and orig_b.typename == "nominal" and not meta_on_operator then + if is_a(orig_a, orig_b) then + node.type = resolve_tuple(orig_a) + else + node_error(node, "cannot use operator '" .. node.op.op:gsub("%%", "%%%%") .. "' for distinct nominal types %s and %s", resolve_tuple(orig_a), resolve_tuple(orig_b)) + end + end + if types_op == numeric_binop or node.op.op == ".." then node.known = FACT_TRUTHY end diff --git a/tl.tl b/tl.tl index 0ac586dff..44363090f 100644 --- a/tl.tl +++ b/tl.tl @@ -1432,7 +1432,7 @@ end local function new_node(tokens: {Token}, i: integer, kind: NodeKind): Node local t = tokens[i] - return { y = t.y, x = t.x, tk = t.tk, kind = kind or t.kind } + return { y = t.y, x = t.x, tk = t.tk, kind = kind or (t.kind as NodeKind) } end local function a_type(t: Type): Type @@ -7219,17 +7219,13 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end return false, terr(t1, "cannot match against any alternatives of the polymorphic type") elseif t1.typename == "nominal" and t2.typename == "nominal" then - local same, err = are_same_nominals(t1, t2) - if same then - return true - end local t1r = resolve_tuple_and_nominal(t1) local t2r = resolve_tuple_and_nominal(t2) - if is_record_type(t1r) and is_record_type(t2r) then - return same, err - else + if t1r.typename == "union" or t2r.typename == "union" then return is_a(t1r, t2r, for_equality) end + + return are_same_nominals(t1, t2) elseif t1.typename == "enum" and t2.typename == "string" then local ok: boolean if for_equality then @@ -9212,7 +9208,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local infertype = infertypes[i] local rt = resolve_tuple_and_nominal(t) - if rt.typename ~= "enum" and not same_type(t, infertype) then + if rt.typename ~= "enum" and (t.typename ~= "nominal" or rt.typename == "union") and not same_type(t, infertype) then add_var(where, var.tk, infer_at(where, infertype), "const", "narrowed_declaration") end end @@ -10204,6 +10200,14 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end + if orig_a.typename == "nominal" and orig_b.typename == "nominal" and not meta_on_operator then + if is_a(orig_a, orig_b) then + node.type = resolve_tuple(orig_a) + else + node_error(node, "cannot use operator '" .. node.op.op:gsub("%%", "%%%%") .. "' for distinct nominal types %s and %s", resolve_tuple(orig_a), resolve_tuple(orig_b)) + end + end + if types_op == numeric_binop or node.op.op == ".." then node.known = FACT_TRUTHY end From 79c9b09ad3670d0079db5f868a494db0e909e25a Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Fri, 15 Sep 2023 20:26:37 -0300 Subject: [PATCH 002/224] don't infer table literal as the first table type of a union Now that unions may contain multiple table types, only infer a table literal into the table type of a union if there's a single table type in a union. --- tl.lua | 20 +++++++++++++++++--- tl.tl | 20 +++++++++++++++++--- 2 files changed, 34 insertions(+), 6 deletions(-) diff --git a/tl.lua b/tl.lua index e4d69f53e..f4c55b6db 100644 --- a/tl.lua +++ b/tl.lua @@ -9597,14 +9597,28 @@ tl.type_check = function(ast, opts) local decltype = resolve_tuple_and_nominal(node.expected) if decltype.typename == "union" then + local single_table_type + local single_table_rt + for _, t in ipairs(decltype.types) do local rt = resolve_tuple_and_nominal(t) if is_lua_table_type(rt) then - node.expected = t - decltype = rt - break + if single_table_type then + + single_table_type = nil + single_table_rt = nil + break + end + + single_table_type = t + single_table_rt = rt end end + + if single_table_type then + node.expected = single_table_type + decltype = single_table_rt + end end if not is_lua_table_type(decltype) then diff --git a/tl.tl b/tl.tl index 44363090f..aa0e04ee4 100644 --- a/tl.tl +++ b/tl.tl @@ -9597,14 +9597,28 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local decltype = resolve_tuple_and_nominal(node.expected) if decltype.typename == "union" then + local single_table_type: Type + local single_table_rt: Type + for _, t in ipairs(decltype.types) do local rt = resolve_tuple_and_nominal(t) if is_lua_table_type(rt) then - node.expected = t - decltype = rt - break + if single_table_type then + -- multiple table types in union, give up + single_table_type = nil + single_table_rt = nil + break + end + + single_table_type = t + single_table_rt = rt end end + + if single_table_type then + node.expected = single_table_type + decltype = single_table_rt + end end if not is_lua_table_type(decltype) then From c46dee553ce35e9ee26b42dbdee2c964963d2037 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Fri, 18 Aug 2023 15:48:27 -0300 Subject: [PATCH 003/224] macroexp: initial commit --- tl.lua | 141 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++- tl.tl | 141 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 280 insertions(+), 2 deletions(-) diff --git a/tl.lua b/tl.lua index f4c55b6db..85372fa3f 100644 --- a/tl.lua +++ b/tl.lua @@ -1210,6 +1210,10 @@ local table_types = { + + + + @@ -1249,6 +1253,9 @@ local attributes = { } local is_attribute = attributes +local Node = {ExpectedContext = {}, } + + @@ -1460,6 +1467,15 @@ local function shallow_copy_type(t) return copy end + +local function shallow_copy_node(t) + local copy = {} + for k, v in pairs(t) do + copy[k] = v + end + return copy +end + local function verify_kind(ps, i, kind, node_kind) if ps.tokens[i].kind == kind then return i + 1, new_node(ps.tokens, i, node_kind) @@ -2733,6 +2749,22 @@ local metamethod_names = { ["__close"] = true, } +local function parse_macroexp(ps, i) + local istart = i - 1 + + + + + local node = new_node(ps.tokens, i, "macroexp") + i, node.args = parse_argument_list(ps, i) + i, node.rets = parse_return_types(ps, i) + i, node.exp = parse_expression(ps, i) + end_at(node, ps.tokens[i]) + i = verify_end(ps, i, istart, node) + assert(node.rets.typename == "tuple") + return i, node +end + parse_record_body = function(ps, i, def, node, name) local istart = i - 1 def.fields = {} @@ -2843,6 +2875,13 @@ parse_record_body = function(ps, i, def, node, name) end end + if ps.tokens[i].tk == "=" and ps.tokens[i + 1].tk == "macroexp" then + if t.typename ~= "function" then + fail(ps, i + 1, "macroexp must have a function type") + end + i, t.macroexp = parse_macroexp(ps, i + 2) + end + store_field_in_record(ps, iv, field_name, t, fields, field_order) elseif ps.tokens[i].tk == "=" then local next_word = ps.tokens[i + 1].tk @@ -7524,6 +7563,85 @@ tl.type_check = function(ast, opts) return func, is_method end + + + + local function traverse_macroexp(macroexp, on_arg_id, on_node) + local root = macroexp.exp + local argnames = {} + for i, a in ipairs(macroexp.args) do + argnames[a.tk] = i + end + + local visit_node = { + cbs = { + ["variable"] = { + after = function(node, _children) + local i = argnames[node.tk] + if not i then + return nil + end + + return on_arg_id(node, i) + end, + }, + }, + after = on_node, + } + + return recurse_node(root, visit_node, {}) + end + + local function expand_macroexp(orignode, args, macroexp) + local on_arg_id = function(_node, i) + return { Node, args[i] } + end + + local on_node = function(node, children, ret) + local orig = ret and ret[2] or node + + local out = shallow_copy_node(orig) + + local map = {} + for _, pair in pairs(children) do + if type(pair) == "table" then + map[pair[1]] = pair[2] + end + end + + for k, v in pairs(orig) do + if type(v) == "table" and map[v] then + (out)[k] = map[v] + end + end + + out.y = orignode.y + out.x = orignode.x + out.yend = nil + out.xend = nil + return { node, out } + end + + local p = traverse_macroexp(macroexp, on_arg_id, on_node) + orignode.expanded = p[2] + end + + local function apply_macroexp(orignode) + local expanded = orignode.expanded + local savetype = orignode.type + local saveknown = orignode.known + orignode.expanded = nil + + for k, _ in pairs(orignode) do + (orignode)[k] = nil + end + for k, v in pairs(expanded) do + (orignode)[k] = v + end + orignode.type = savetype + orignode.known = saveknown + end + local type_check_function_call do local function mark_invalid_typeargs(f) @@ -7760,6 +7878,11 @@ tl.type_check = function(ast, opts) if e1 then e1.type = f end + + if func.macroexp then + expand_macroexp(where, where_args, func.macroexp) + end + return ret end end @@ -10136,6 +10259,11 @@ tl.type_check = function(ast, opts) node.type.tk = nil elseif node.op.op == "==" or node.op.op == "~=" then node.type = BOOLEAN + + if is_lua_table_type(ra) and is_lua_table_type(rb) then + check_metamethod(node, node.op.op, ra, rb) + end + if is_a(b, a, true) or a.typename == "typevar" then if node.op.op == "==" and node.e1.kind == "variable" then node.known = Fact({ fact = "==", var = node.e1.tk, typ = b, where = node }) @@ -10378,6 +10506,10 @@ tl.type_check = function(ast, opts) visit_node.cbs["expression_list"] = visit_node.cbs["variable_list"] visit_node.after = function(node, _children) + if node.expanded then + apply_macroexp(node) + end + if type(node.type) ~= "table" then error(node.kind .. " did not produce a type") end @@ -10517,7 +10649,14 @@ tl.type_check = function(ast, opts) } if not opts.run_internal_compiler_checks then - visit_node.after = nil + visit_node.after = function(node, _children) + if node.expanded then + apply_macroexp(node) + end + + return node.type + end + visit_type.after = nil end diff --git a/tl.tl b/tl.tl index aa0e04ee4..bac5f168b 100644 --- a/tl.tl +++ b/tl.tl @@ -1146,6 +1146,9 @@ local record Type -- enum enumset: {string:boolean} + -- macroexp + macroexp: Node + -- unresolved items labels: {string:{Node}} nominals: {string:{Type}} @@ -1203,6 +1206,7 @@ local enum NodeKind "cast" "..." "paren" + "macroexp" "error_node" end @@ -1339,6 +1343,9 @@ local record Node -- variable is_lvalue: boolean + -- macroexp + expanded: Node + type: Type decltype: Type end @@ -1460,6 +1467,15 @@ local function shallow_copy_type(t: Type): Type return copy as Type end +-- Makes a shallow copy of the given type +local function shallow_copy_node(t: Node): Node + local copy: {any:any} = {} + for k, v in pairs(t as {any:any}) do + copy[k] = v + end + return copy as Node +end + local function verify_kind(ps: ParseState, i: integer, kind: TokenKind, node_kind: NodeKind): integer, Node if ps.tokens[i].kind == kind then return i + 1, new_node(ps.tokens, i, node_kind) @@ -2733,6 +2749,22 @@ local metamethod_names: {string:boolean} = { ["__close"] = true, } +local function parse_macroexp(ps: ParseState, i: integer): integer, Node + local istart = i - 1 +-- TODO: generic macroexp +-- if ps.tokens[i].tk == "<" then +-- i, node.typeargs = parse_anglebracket_list(ps, i, parse_typearg) +-- end + local node = new_node(ps.tokens, i, "macroexp") + i, node.args = parse_argument_list(ps, i) + i, node.rets = parse_return_types(ps, i) + i, node.exp = parse_expression(ps, i) + end_at(node, ps.tokens[i]) + i = verify_end(ps, i, istart, node) + assert(node.rets.typename == "tuple") + return i, node +end + parse_record_body = function(ps: ParseState, i: integer, def: Type, node: Node, name: string): integer, Node local istart = i - 1 def.fields = {} @@ -2843,6 +2875,13 @@ parse_record_body = function(ps: ParseState, i: integer, def: Type, node: Node, end end + if ps.tokens[i].tk == "=" and ps.tokens[i + 1].tk == "macroexp" then + if t.typename ~= "function" then + fail(ps, i + 1, "macroexp must have a function type") + end + i, t.macroexp = parse_macroexp(ps, i + 2) + end + store_field_in_record(ps, iv, field_name, t, fields, field_order) elseif ps.tokens[i].tk == "=" then local next_word = ps.tokens[i + 1].tk @@ -7524,6 +7563,85 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return func, is_method end + local type OnArgId = function(node: Node, i: integer): T + local type OnNode = function(node: Node, children: {T}, ret: T): T + + local function traverse_macroexp(macroexp: Node, on_arg_id: OnArgId, on_node: OnNode): T + local root = macroexp.exp + local argnames = {} + for i, a in ipairs(macroexp.args) do + argnames[a.tk] = i + end + + local visit_node: Visitor = { + cbs = { + ["variable"] = { + after = function(node: Node, _children: {T}): T + local i = argnames[node.tk] + if not i then + return nil + end + + return on_arg_id(node, i) + end + } + }, + after = on_node, + } + + return recurse_node(root, visit_node, {}) + end + + local function expand_macroexp(orignode: Node, args: {Node}, macroexp: Node) + local on_arg_id = function(_node: Node, i: integer): {Node, Node} + return { Node, args[i] } + end + + local on_node = function(node: Node, children: {{Node, Node}}, ret: {Node, Node}): {Node, Node} + local orig = ret and ret[2] or node + + local out = shallow_copy_node(orig) + + local map = {} + for _, pair in pairs(children as {integer:{Node, Node}}) do + if type(pair) == "table" then + map[pair[1]] = pair[2] + end + end + + for k, v in pairs(orig as {any:Node}) do + if type(v) == "table" and map[v] then + (out as {any:any})[k] = map[v] + end + end + + out.y = orignode.y + out.x = orignode.x + out.yend = nil + out.xend = nil + return { node, out } + end + + local p = traverse_macroexp(macroexp, on_arg_id, on_node) + orignode.expanded = p[2] + end + + local function apply_macroexp(orignode: Node) + local expanded = orignode.expanded + local savetype = orignode.type + local saveknown = orignode.known + orignode.expanded = nil + + for k, _ in pairs(orignode as {any:any}) do + (orignode as {any:any})[k] = nil + end + for k, v in pairs(expanded as {any:any}) do + (orignode as {any:any})[k] = v + end + orignode.type = savetype + orignode.known = saveknown + end + local type_check_function_call: function(Node, {Node}, Type, {Type}, Node, boolean, integer): Type do local function mark_invalid_typeargs(f: Type) @@ -7760,6 +7878,11 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if e1 then e1.type = f end + + if func.macroexp then + expand_macroexp(where, where_args, func.macroexp) + end + return ret end end @@ -10136,6 +10259,11 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string node.type.tk = nil elseif node.op.op == "==" or node.op.op == "~=" then node.type = BOOLEAN + + if is_lua_table_type(ra) and is_lua_table_type(rb) then + check_metamethod(node, node.op.op, ra, rb) + end + if is_a(b, a, true) or a.typename == "typevar" then if node.op.op == "==" and node.e1.kind == "variable" then node.known = Fact { fact = "==", var = node.e1.tk, typ = b, where = node } @@ -10378,6 +10506,10 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string visit_node.cbs["expression_list"] = visit_node.cbs["variable_list"] visit_node.after = function(node: Node, _children: {Type}): Type + if node.expanded then + apply_macroexp(node) + end + if type(node.type) ~= "table" then error(node.kind .. " did not produce a type") end @@ -10517,7 +10649,14 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string } if not opts.run_internal_compiler_checks then - visit_node.after = nil + visit_node.after = function(node: Node, _children: {Type}): Type + if node.expanded then + apply_macroexp(node) + end + + return node.type + end + visit_type.after = nil end From d36733c2bfed0a13d407c8c62364f55f4a7287c6 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Sat, 18 Nov 2023 20:06:53 -0300 Subject: [PATCH 004/224] macroexp: type check expression body --- spec/declaration/macroexp_spec.lua | 24 +++++++++++++ tl.lua | 54 +++++++++++++++++++++++++++--- tl.tl | 54 +++++++++++++++++++++++++++--- 3 files changed, 122 insertions(+), 10 deletions(-) create mode 100644 spec/declaration/macroexp_spec.lua diff --git a/spec/declaration/macroexp_spec.lua b/spec/declaration/macroexp_spec.lua new file mode 100644 index 000000000..a537ed9d2 --- /dev/null +++ b/spec/declaration/macroexp_spec.lua @@ -0,0 +1,24 @@ +local util = require("spec.util") + +describe("macroexp declaration", function() + it("checks unused arguments", util.check_warnings([[ + local record R1 + metamethod __is: function(self: R1): boolean = macroexp(self: R1): boolean + true + end + end + ]], { + { y = 2, msg = "unused argument self: R1" } + })) + + it("checks argument mismatch", util.check_type_error([[ + local record R1 + metamethod __call: function(self: R1, n: number): boolean = macroexp(self: R1, s: string): boolean + self.field == s + end + field: string + end + ]], { + { y = 2, x = 70, msg = "macroexp type does not match declaration" } + })) +end) diff --git a/tl.lua b/tl.lua index 85372fa3f..402fbe7cb 100644 --- a/tl.lua +++ b/tl.lua @@ -2750,12 +2750,13 @@ local metamethod_names = { } local function parse_macroexp(ps, i) - local istart = i - 1 + local istart = i - local node = new_node(ps.tokens, i, "macroexp") + local node = new_node(ps.tokens, istart, "macroexp") + i = i + 1 i, node.args = parse_argument_list(ps, i) i, node.rets = parse_return_types(ps, i) i, node.exp = parse_expression(ps, i) @@ -2879,7 +2880,7 @@ parse_record_body = function(ps, i, def, node, name) if t.typename ~= "function" then fail(ps, i + 1, "macroexp must have a function type") end - i, t.macroexp = parse_macroexp(ps, i + 2) + i, t.macroexp = parse_macroexp(ps, i + 1) end store_field_in_record(ps, iv, field_name, t, fields, field_order) @@ -3602,6 +3603,14 @@ local function recurse_node(root, xs[2] = recurse(ast.exp) end, + ["macroexp"] = function(ast, xs) + recurse_typeargs(ast, visit_type) + xs[1] = recurse(ast.args) + xs[2] = recurse_type(ast.rets, visit_type) + extra_callback("before_exp", ast, xs, visit_node) + xs[3] = recurse(ast.exp) + end, + ["function"] = function(ast, xs) recurse_typeargs(ast, visit_type) xs[1] = recurse(ast.args) @@ -10051,6 +10060,30 @@ tl.type_check = function(ast, opts) end_function_scope(node) + node.type = ensure_fresh_typeargs(a_type({ + y = node.y, + x = node.x, + typename = "function", + typeargs = node.typeargs, + args = children[1], + rets = children[2], + filename = filename, + })) + return node.type + end, + }, + ["macroexp"] = { + before = function(node) + widen_all_unions(node) + begin_scope(node) + end, + before_exp = function(node) + add_internal_function_variables(node) + end, + after = function(node, children) + end_function_scope(node) + + node.type = ensure_fresh_typeargs(a_type({ y = node.y, x = node.x, @@ -10519,7 +10552,8 @@ tl.type_check = function(ast, opts) return node.type end - local visit_type = { + local visit_type + visit_type = { cbs = { ["string"] = { after = function(typ, _children) @@ -10532,7 +10566,17 @@ tl.type_check = function(ast, opts) end, after = function(typ, _children) end_scope() - return ensure_fresh_typeargs(typ) + typ = ensure_fresh_typeargs(typ) + + if typ.macroexp then + recurse_node(typ.macroexp, visit_node, visit_type) + + if not is_a(typ.macroexp.type, typ) then + type_error(typ.macroexp.type, "macroexp type does not match declaration") + end + end + + return typ end, }, ["record"] = { diff --git a/tl.tl b/tl.tl index bac5f168b..215f8b4ea 100644 --- a/tl.tl +++ b/tl.tl @@ -2750,12 +2750,13 @@ local metamethod_names: {string:boolean} = { } local function parse_macroexp(ps: ParseState, i: integer): integer, Node - local istart = i - 1 + local istart = i -- TODO: generic macroexp -- if ps.tokens[i].tk == "<" then -- i, node.typeargs = parse_anglebracket_list(ps, i, parse_typearg) -- end - local node = new_node(ps.tokens, i, "macroexp") + local node = new_node(ps.tokens, istart, "macroexp") + i = i + 1 -- skip 'macroexp' i, node.args = parse_argument_list(ps, i) i, node.rets = parse_return_types(ps, i) i, node.exp = parse_expression(ps, i) @@ -2879,7 +2880,7 @@ parse_record_body = function(ps: ParseState, i: integer, def: Type, node: Node, if t.typename ~= "function" then fail(ps, i + 1, "macroexp must have a function type") end - i, t.macroexp = parse_macroexp(ps, i + 2) + i, t.macroexp = parse_macroexp(ps, i + 1) end store_field_in_record(ps, iv, field_name, t, fields, field_order) @@ -3602,6 +3603,14 @@ local function recurse_node(root: Node, xs[2] = recurse(ast.exp) end, + ["macroexp"] = function(ast: Node, xs: {T}) + recurse_typeargs(ast, visit_type) + xs[1] = recurse(ast.args) + xs[2] = recurse_type(ast.rets, visit_type) + extra_callback("before_exp", ast, xs, visit_node) + xs[3] = recurse(ast.exp) + end, + ["function"] = function(ast: Node, xs: {T}) recurse_typeargs(ast, visit_type) xs[1] = recurse(ast.args) @@ -10063,6 +10072,30 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return node.type end, }, + ["macroexp"] = { + before = function(node: Node) + widen_all_unions(node) + begin_scope(node) + end, + before_exp = function(node: Node) + add_internal_function_variables(node) + end, + after = function(node: Node, children: {Type}): Type + end_function_scope(node) + -- children[1] args + -- children[2] body + node.type = ensure_fresh_typeargs(a_type { + y = node.y, + x = node.x, + typename = "function", + typeargs = node.typeargs, + args = children[1], + rets = children[2], + filename = filename, + }) + return node.type + end, + }, ["cast"] = { after = function(node: Node, _children: {Type}): Type node.type = node.casttype @@ -10519,7 +10552,8 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return node.type end - local visit_type: Visitor = { + local visit_type: Visitor + visit_type = { cbs = { ["string"] = { after = function(typ: Type, _children: {Type}): Type @@ -10532,7 +10566,17 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end, after = function(typ: Type, _children: {Type}): Type end_scope() - return ensure_fresh_typeargs(typ) + typ = ensure_fresh_typeargs(typ) + + if typ.macroexp then + recurse_node(typ.macroexp, visit_node, visit_type) + + if not is_a(typ.macroexp.type, typ) then + type_error(typ.macroexp.type, "macroexp type does not match declaration") + end + end + + return typ end, }, ["record"] = { From 800cc79be3fd3857462d5ec95bb894c0d74b7065 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Fri, 15 Sep 2023 21:53:20 -0300 Subject: [PATCH 005/224] macroexp: check that arguments are used only once --- spec/declaration/macroexp_spec.lua | 14 ++++++++++++++ tl.lua | 16 ++++++++++++++++ tl.tl | 16 ++++++++++++++++ 3 files changed, 46 insertions(+) diff --git a/spec/declaration/macroexp_spec.lua b/spec/declaration/macroexp_spec.lua index a537ed9d2..ba5b517af 100644 --- a/spec/declaration/macroexp_spec.lua +++ b/spec/declaration/macroexp_spec.lua @@ -21,4 +21,18 @@ describe("macroexp declaration", function() ]], { { y = 2, x = 70, msg = "macroexp type does not match declaration" } })) + + it("checks multiple use of arguments", util.check_type_error([[ + global function f(a: string, b:string) + print(a, b) + end + + local record R1 + metamethod __call: function(self: R1, s: string): boolean = macroexp(self: R1, s: string): boolean + print(s, s) + end + end + ]], { + { y = 7, x = 22, msg = "cannot use argument 's' multiple times in macroexp" } + })) end) diff --git a/tl.lua b/tl.lua index 402fbe7cb..0f5e05cf8 100644 --- a/tl.lua +++ b/tl.lua @@ -7635,6 +7635,20 @@ tl.type_check = function(ast, opts) orignode.expanded = p[2] end + local function check_macroexp_arg_use(macroexp) + local used = {} + + local on_arg_id = function(node, _i) + if used[node.tk] then + node_error(node, "cannot use argument '" .. node.tk .. "' multiple times in macroexp") + else + used[node.tk] = true + end + end + + traverse_macroexp(macroexp, on_arg_id, nil) + end + local function apply_macroexp(orignode) local expanded = orignode.expanded local savetype = orignode.type @@ -10571,6 +10585,8 @@ tl.type_check = function(ast, opts) if typ.macroexp then recurse_node(typ.macroexp, visit_node, visit_type) + check_macroexp_arg_use(typ.macroexp) + if not is_a(typ.macroexp.type, typ) then type_error(typ.macroexp.type, "macroexp type does not match declaration") end diff --git a/tl.tl b/tl.tl index 215f8b4ea..77d26b9b8 100644 --- a/tl.tl +++ b/tl.tl @@ -7635,6 +7635,20 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string orignode.expanded = p[2] end + local function check_macroexp_arg_use(macroexp: Node) + local used: {string:boolean} = {} + + local on_arg_id = function(node: Node, _i: integer): {Node, Node} + if used[node.tk] then + node_error(node, "cannot use argument '" .. node.tk .. "' multiple times in macroexp") + else + used[node.tk] = true + end + end + + traverse_macroexp(macroexp, on_arg_id, nil) + end + local function apply_macroexp(orignode: Node) local expanded = orignode.expanded local savetype = orignode.type @@ -10571,6 +10585,8 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if typ.macroexp then recurse_node(typ.macroexp, visit_node, visit_type) + check_macroexp_arg_use(typ.macroexp) + if not is_a(typ.macroexp.type, typ) then type_error(typ.macroexp.type, "macroexp type does not match declaration") end From 8d11028d34451c286fae6b59b0b73b54c09ca13b Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Tue, 22 Aug 2023 08:51:43 -0300 Subject: [PATCH 006/224] __is: mark as a valid metamethod --- tl.lua | 2 ++ tl.tl | 2 ++ 2 files changed, 4 insertions(+) diff --git a/tl.lua b/tl.lua index 0f5e05cf8..6376e2b73 100644 --- a/tl.lua +++ b/tl.lua @@ -2747,6 +2747,7 @@ local metamethod_names = { ["__pairs"] = true, ["__gc"] = true, ["__close"] = true, + ["__is"] = true, } local function parse_macroexp(ps, i) @@ -4766,6 +4767,7 @@ local binop_to_metamethod = { ["<"] = "__lt", ["<="] = "__le", ["@index"] = "__index", + ["is"] = "__is", } local function is_unknown(t) diff --git a/tl.tl b/tl.tl index 77d26b9b8..d7520cb80 100644 --- a/tl.tl +++ b/tl.tl @@ -2747,6 +2747,7 @@ local metamethod_names: {string:boolean} = { ["__pairs"] = true, ["__gc"] = true, ["__close"] = true, + ["__is"] = true, } local function parse_macroexp(ps: ParseState, i: integer): integer, Node @@ -4766,6 +4767,7 @@ local binop_to_metamethod: {string:string} = { ["<"] = "__lt", ["<="] = "__le", ["@index"] = "__index", + ["is"] = "__is", } local function is_unknown(t: Type): boolean From 11acc34ddd7edd91c293d6b4fd57cb4468f2b61c Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Tue, 22 Aug 2023 09:00:12 -0300 Subject: [PATCH 007/224] __is: accept unions over multiple tables that declare `__is` metamethod Accordingly, accept `is` checks on table types that declare `__is`. We disallow unions mixing tables with and without `__is`. This is a limitation because to lift it we would have to implement code that transforms this ``` local u = R1 | R2 | {number} -- R1 and R2 have __is, {number} doesn't if u is {number} then -- ... end ``` into (effectively) ``` local u = R1 | R2 | {number} -- R1 and R2 have __is, {number} doesn't if not (u is R1 or u is R2) then -- ... end ``` In other words, "is" testing for the one table/userdata item without __is would have to be defined in terms of the negation of the disjunction of all the other cases. --- spec/metamethods/index_spec.lua | 2 +- spec/operator/is_spec.lua | 28 +++++++++ tl.lua | 103 +++++++++++++++++++------------ tl.tl | 105 ++++++++++++++++++++------------ 4 files changed, 158 insertions(+), 80 deletions(-) diff --git a/spec/metamethods/index_spec.lua b/spec/metamethods/index_spec.lua index de924ea0b..6e58ca63e 100644 --- a/spec/metamethods/index_spec.lua +++ b/spec/metamethods/index_spec.lua @@ -52,7 +52,7 @@ describe("metamethod __index", function() r.x = 12 print(r[r]) ]], { - { msg = "argument 2: got Rec, expected number" }, + { msg = "argument 1: got Rec, expected number" }, })) it("cannot be typechecked if the metamethod is not defined in the record", util.check_type_error([[ diff --git a/spec/operator/is_spec.lua b/spec/operator/is_spec.lua index 95a2b9ba4..00d634513 100644 --- a/spec/operator/is_spec.lua +++ b/spec/operator/is_spec.lua @@ -402,6 +402,34 @@ describe("flow analysis with is", function() { msg = "cannot index" }, })) + it("produces no errors or warnings for checks on unions of records", util.check_warnings([[ + local record R1 + metamethod __is: function(self: R1|R2): boolean = macroexp(_self: R1|R2): boolean + true + end + end + + local record R2 + metamethod __is: function(self: R1|R2): boolean = macroexp(_self: R1|R2): boolean + false + end + end + + local type RS = R1 | R2 + + local rs1 : RS + + if rs1 is R1 then + print("yes") + end + + local rs2 : R1 | R2 + + if rs2 is R2 then + print("yes") + end + ]], {}, {})) + it("gen cleaner checking codes for nil", util.gen([[ local record R f: function() diff --git a/tl.lua b/tl.lua index 6376e2b73..867fb7dbf 100644 --- a/tl.lua +++ b/tl.lua @@ -6039,9 +6039,9 @@ tl.type_check = function(ast, opts) local function union_type(t) if is_typetype(t) then - return union_type(t.def) + return union_type(t.def), t.def elseif t.typename == "tuple" then - return union_type(t[1]) + return union_type(t[1]), t[1] elseif t.typename == "nominal" then local typetype = t.found or find_type(t.names) if not typetype then @@ -6050,13 +6050,13 @@ tl.type_check = function(ast, opts) return union_type(typetype) elseif t.typename == "record" then if t.is_userdata then - return "userdata" + return "userdata", t end - return "table" + return "table", t elseif table_types[t.typename] then - return "table" + return "table", t else - return t.typename + return t.typename, t end end @@ -6068,21 +6068,43 @@ tl.type_check = function(ast, opts) local n_table_types = 0 + local n_table_is_types = 0 local n_function_types = 0 local n_userdata_types = 0 + local n_userdata_is_types = 0 local n_string_enum = 0 local has_primitive_string_type = false for _, t in ipairs(typ.types) do - local ut = union_type(t) + local ut, rt = union_type(t) if ut == "userdata" then - n_userdata_types = n_userdata_types + 1 - if n_userdata_types > 1 then - return false, "cannot discriminate a union between multiple userdata types: %s" + if rt.meta_fields and rt.meta_fields["__is"] then + n_userdata_is_types = n_userdata_is_types + 1 + if n_userdata_types > 0 then + return false, "cannot mix userdata types with and without __is metamethod: %s" + end + else + n_userdata_types = n_userdata_types + 1 + if n_userdata_types > 1 then + return false, "cannot discriminate a union between multiple userdata types: %s" + end + if n_userdata_is_types > 0 then + return false, "cannot mix userdata types with and without __is metamethod: %s" + end end elseif ut == "table" then - n_table_types = n_table_types + 1 - if n_table_types > 1 then - return false, "cannot discriminate a union between multiple table types: %s" + if rt.meta_fields and rt.meta_fields["__is"] then + n_table_is_types = n_table_is_types + 1 + if n_table_types > 0 then + return false, "cannot mix table types with and without __is metamethod: %s" + end + else + n_table_types = n_table_types + 1 + if n_table_types > 1 then + return false, "cannot discriminate a union between multiple table types: %s" + end + if n_table_is_types > 0 then + return false, "cannot mix table types with and without __is metamethod: %s" + end end elseif ut == "function" then n_function_types = n_function_types + 1 @@ -7748,7 +7770,7 @@ tl.type_check = function(ast, opts) rets_ok, rets_errs = check_func_type_list(where, nil, f.rets, rets, 1, 0, "return") end - args_ok, args_errs = check_func_type_list(where, where_args, args, f.args, 1, argdelta, "argument") + args_ok, args_errs = check_func_type_list(where, where_args, args, f.args, from, argdelta, "argument") if (not args_ok) or (not rets_ok) then return nil, args_errs or {} end @@ -7912,35 +7934,31 @@ tl.type_check = function(ast, opts) end end - local function check_metamethod(node, op, a, b, orig_a, orig_b) - local method_name - local where_args - local args - local meta_on_operator = 1 - + local function check_metamethod(node, method_name, a, b, orig_a, orig_b) if lax and ((a and is_unknown(a)) or (b and is_unknown(b))) then return UNKNOWN, nil elseif not a.meta_fields and not (b and b.meta_fields) then return nil, nil end - if a and b then - method_name = binop_to_metamethod[op] - where_args = { node.e1, node.e2 } - args = { typename = "tuple", orig_a, orig_b } - else - method_name = unop_to_metamethod[op] - where_args = { node.e1 } - args = { typename = "tuple", orig_a } + local meta_on_operator = 1 + local metamethod + if method_name ~= "__is" then + metamethod = a.meta_fields and a.meta_fields[method_name or ""] end - - local metamethod = a.meta_fields and a.meta_fields[method_name or ""] - if (not metamethod) and b and op ~= "@index" then + if (not metamethod) and b and method_name ~= "__index" then metamethod = b.meta_fields and b.meta_fields[method_name or ""] meta_on_operator = 2 end + if metamethod then - return resolve_tuple_and_nominal(type_check_function_call(node, where_args, metamethod, args, nil, false, 0)), meta_on_operator + local where_args = { node.e1 } + local args = { typename = "tuple", orig_a } + if b and method_name ~= "__is" then + where_args[2] = node.e2 + args[2] = orig_b + end + return resolve_tuple_and_nominal(type_check_function_call(node, where_args, metamethod, args, nil, true)), meta_on_operator else return nil, nil end @@ -7970,7 +7988,7 @@ tl.type_check = function(ast, opts) return tbl.fields[key] end - local meta_t = check_metamethod(rec, "@index", tbl, STRING, tbl, STRING) + local meta_t = check_metamethod(rec, "__index", tbl, STRING, tbl, STRING) if meta_t then return meta_t end @@ -8341,7 +8359,7 @@ tl.type_check = function(ast, opts) errm, erra, errb = "cannot index object of type %s with %s", orig_a, orig_b end - local meta_t = check_metamethod(anode, "@index", a, orig_b, orig_a, orig_b) + local meta_t = check_metamethod(anode, "__index", a, orig_b, orig_a, orig_b) if meta_t then return meta_t end @@ -10244,6 +10262,7 @@ tl.type_check = function(ast, opts) if ra.typename == "typetype" then node_error(node, "can only use 'is' on variables, not types") elseif node.e1.kind == "variable" then + check_metamethod(node, "__is", ra, resolve_typetype(rb), orig_a, orig_b) node.known = Fact({ fact = "is", var = node.e1.tk, typ = b, where = node }) else node_error(node, "can only use 'is' on variables") @@ -10309,9 +10328,9 @@ tl.type_check = function(ast, opts) elseif node.op.op == "==" or node.op.op == "~=" then node.type = BOOLEAN - if is_lua_table_type(ra) and is_lua_table_type(rb) then - check_metamethod(node, node.op.op, ra, rb) - end + + + if is_a(b, a, true) or a.typename == "typevar" then if node.op.op == "==" and node.e1.kind == "variable" then @@ -10336,7 +10355,10 @@ tl.type_check = function(ast, opts) node.type = types_op[a.typename] local meta_on_operator if not node.type then - node.type, meta_on_operator = check_metamethod(node, node.op.op, a, nil, orig_a, nil) + local mt_name = unop_to_metamethod[node.op.op] + if mt_name then + node.type, meta_on_operator = check_metamethod(node, mt_name, a, nil, orig_a, nil) + end if not node.type then node_error(node, "cannot use operator '" .. node.op.op:gsub("%%", "%%%%") .. "' on type %s", resolve_tuple(orig_a)) end @@ -10382,7 +10404,10 @@ tl.type_check = function(ast, opts) node.type = types_op[a.typename] and types_op[a.typename][b.typename] local meta_on_operator if not node.type then - node.type, meta_on_operator = check_metamethod(node, node.op.op, a, b, orig_a, orig_b) + local mt_name = binop_to_metamethod[node.op.op] + if mt_name then + node.type, meta_on_operator = check_metamethod(node, mt_name, a, b, orig_a, orig_b) + end if not node.type then node_error(node, "cannot use operator '" .. node.op.op:gsub("%%", "%%%%") .. "' for types %s and %s", resolve_tuple(orig_a), resolve_tuple(orig_b)) if node.op.op == "or" and is_valid_union(unite({ orig_a, orig_b })) then diff --git a/tl.tl b/tl.tl index d7520cb80..8f77bec7e 100644 --- a/tl.tl +++ b/tl.tl @@ -6037,11 +6037,11 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - local function union_type(t: Type): string + local function union_type(t: Type): string, Type if is_typetype(t) then - return union_type(t.def) + return union_type(t.def), t.def elseif t.typename == "tuple" then - return union_type(t[1]) + return union_type(t[1]), t[1] elseif t.typename == "nominal" then local typetype = t.found or find_type(t.names) if not typetype then @@ -6050,13 +6050,13 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return union_type(typetype) elseif t.typename == "record" then if t.is_userdata then - return "userdata" + return "userdata", t end - return "table" + return "table", t elseif table_types[t.typename] then - return "table" + return "table", t else - return t.typename + return t.typename, t end end @@ -6068,21 +6068,43 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string -- check for limitations in our union support -- due to codegen limitations (we only check with type() so far) local n_table_types = 0 + local n_table_is_types = 0 local n_function_types = 0 local n_userdata_types = 0 + local n_userdata_is_types = 0 local n_string_enum = 0 local has_primitive_string_type = false for _, t in ipairs(typ.types) do - local ut = union_type(t) + local ut, rt = union_type(t) if ut == "userdata" then -- must be tested before table_types - n_userdata_types = n_userdata_types + 1 - if n_userdata_types > 1 then - return false, "cannot discriminate a union between multiple userdata types: %s" + if rt.meta_fields and rt.meta_fields["__is"] then + n_userdata_is_types = n_userdata_is_types + 1 + if n_userdata_types > 0 then + return false, "cannot mix userdata types with and without __is metamethod: %s" + end + else + n_userdata_types = n_userdata_types + 1 + if n_userdata_types > 1 then + return false, "cannot discriminate a union between multiple userdata types: %s" + end + if n_userdata_is_types > 0 then + return false, "cannot mix userdata types with and without __is metamethod: %s" + end end elseif ut == "table" then - n_table_types = n_table_types + 1 - if n_table_types > 1 then - return false, "cannot discriminate a union between multiple table types: %s" + if rt.meta_fields and rt.meta_fields["__is"] then + n_table_is_types = n_table_is_types + 1 + if n_table_types > 0 then + return false, "cannot mix table types with and without __is metamethod: %s" + end + else + n_table_types = n_table_types + 1 + if n_table_types > 1 then + return false, "cannot discriminate a union between multiple table types: %s" + end + if n_table_is_types > 0 then + return false, "cannot mix table types with and without __is metamethod: %s" + end end elseif ut == "function" then n_function_types = n_function_types + 1 @@ -7748,7 +7770,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string rets_ok, rets_errs = check_func_type_list(where, nil, f.rets, rets, 1, 0, "return") end - args_ok, args_errs = check_func_type_list(where, where_args, args, f.args, 1, argdelta, "argument") + args_ok, args_errs = check_func_type_list(where, where_args, args, f.args, from, argdelta, "argument") if (not args_ok) or (not rets_ok) then return nil, args_errs or {} end @@ -7912,35 +7934,31 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - local function check_metamethod(node: Node, op: string, a: Type, b: Type, orig_a: Type, orig_b: Type): Type, integer - local method_name: string - local where_args: {Node} - local args: Type - local meta_on_operator = 1 - + local function check_metamethod(node: Node, method_name: string, a: Type, b: Type, orig_a: Type, orig_b: Type): Type, integer if lax and ((a and is_unknown(a)) or (b and is_unknown(b))) then return UNKNOWN, nil elseif not a.meta_fields and not (b and b.meta_fields) then return nil, nil end - if a and b then - method_name = binop_to_metamethod[op] - where_args = { node.e1, node.e2 } - args = { typename = "tuple", orig_a, orig_b } - else - method_name = unop_to_metamethod[op] - where_args = { node.e1 } - args = { typename = "tuple", orig_a } + local meta_on_operator = 1 + local metamethod: Type + if method_name ~= "__is" then + metamethod = a.meta_fields and a.meta_fields[method_name or ""] end - - local metamethod = a.meta_fields and a.meta_fields[method_name or ""] - if (not metamethod) and b and op ~= "@index" then + if (not metamethod) and b and method_name ~= "__index" then metamethod = b.meta_fields and b.meta_fields[method_name or ""] meta_on_operator = 2 end + if metamethod then - return resolve_tuple_and_nominal(type_check_function_call(node, where_args, metamethod, args, nil, false, 0)), meta_on_operator + local where_args = { node.e1 } + local args = { typename = "tuple", orig_a } + if b and method_name ~= "__is" then + where_args[2] = node.e2 + args[2] = orig_b + end + return resolve_tuple_and_nominal(type_check_function_call(node, where_args, metamethod, args, nil, true)), meta_on_operator else return nil, nil end @@ -7970,7 +7988,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return tbl.fields[key] end - local meta_t = check_metamethod(rec, "@index", tbl, STRING, tbl, STRING) + local meta_t = check_metamethod(rec, "__index", tbl, STRING, tbl, STRING) if meta_t then return meta_t end @@ -8341,7 +8359,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string errm, erra, errb = "cannot index object of type %s with %s", orig_a, orig_b end - local meta_t = check_metamethod(anode, "@index", a, orig_b, orig_a, orig_b) + local meta_t = check_metamethod(anode, "__index", a, orig_b, orig_a, orig_b) if meta_t then return meta_t end @@ -10244,6 +10262,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if ra.typename == "typetype" then node_error(node, "can only use 'is' on variables, not types") elseif node.e1.kind == "variable" then + check_metamethod(node, "__is", ra, resolve_typetype(rb), orig_a, orig_b) node.known = Fact { fact = "is", var = node.e1.tk, typ = b, where = node } else node_error(node, "can only use 'is' on variables") @@ -10309,9 +10328,9 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string elseif node.op.op == "==" or node.op.op == "~=" then node.type = BOOLEAN - if is_lua_table_type(ra) and is_lua_table_type(rb) then - check_metamethod(node, node.op.op, ra, rb) - end +-- if is_lua_table_type(ra) and is_lua_table_type(rb) then +-- check_metamethod(node, binop_to_metamethod[node.op.op], ra, rb) +-- end if is_a(b, a, true) or a.typename == "typevar" then if node.op.op == "==" and node.e1.kind == "variable" then @@ -10336,7 +10355,10 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string node.type = types_op[a.typename] local meta_on_operator: integer if not node.type then - node.type, meta_on_operator = check_metamethod(node, node.op.op, a, nil, orig_a, nil) + local mt_name = unop_to_metamethod[node.op.op] + if mt_name then + node.type, meta_on_operator = check_metamethod(node, mt_name, a, nil, orig_a, nil) + end if not node.type then node_error(node, "cannot use operator '" .. node.op.op:gsub("%%", "%%%%") .. "' on type %s", resolve_tuple(orig_a)) end @@ -10382,7 +10404,10 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string node.type = types_op[a.typename] and types_op[a.typename][b.typename] local meta_on_operator: integer if not node.type then - node.type, meta_on_operator = check_metamethod(node, node.op.op, a, b, orig_a, orig_b) + local mt_name = binop_to_metamethod[node.op.op] + if mt_name then + node.type, meta_on_operator = check_metamethod(node, mt_name, a, b, orig_a, orig_b) + end if not node.type then node_error(node, "cannot use operator '" .. node.op.op:gsub("%%", "%%%%") .. "' for types %s and %s", resolve_tuple(orig_a), resolve_tuple(orig_b)) if node.op.op == "or" and is_valid_union(unite({orig_a, orig_b})) then From 7decf46c67cd2e2d0824b2a8c0878d6b0520676a Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Sat, 21 Oct 2023 14:04:57 -0300 Subject: [PATCH 008/224] __is: add macroexp tests for is operator --- spec/macroexp/is_spec.lua | 119 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 119 insertions(+) create mode 100644 spec/macroexp/is_spec.lua diff --git a/spec/macroexp/is_spec.lua b/spec/macroexp/is_spec.lua new file mode 100644 index 000000000..b359a332b --- /dev/null +++ b/spec/macroexp/is_spec.lua @@ -0,0 +1,119 @@ +local util = require("spec.util") + +describe("__is with macroexp", function() + it("can expand a constant expression", util.gen([[ + local record R1 + metamethod __is: function(self: R1|R2): boolean = macroexp(_self: R1|R2): boolean + true + end + end + + local record R2 + metamethod __is: function(self: R1|R2): boolean = macroexp(_self: R1|R2): boolean + false + end + end + + local type RS = R1 | R2 + + local rs1 : RS + + if rs1 is R1 then + print("yes") + end + + local rs2 : R1 | R2 + + if rs2 is R2 then + print("yes") + end + ]], [[ + + + + + + + + + + + + + + + local rs1 + + if true then + print("yes") + end + + local rs2 + + if false then + print("yes") + end + ]])) + + it("can expand self in an expression", util.gen([[ + local record R1 + metamethod __is: function(self: R1|R2): boolean = macroexp(self: R1|R2): boolean + self.kind == "r1" + end + + kind: string + end + + local record R2 + metamethod __is: function(self: R1|R2): boolean = macroexp(self: R1|R2): boolean + self.kind == "r2" + end + + kind: string + end + + local type RS = R1 | R2 + + local rs1 : RS = { kind = "r1" } + + if rs1 is R1 then + print("yes") + end + + local rs2 : R1 | R2 = { kind = "r2" } + + if rs2 is R2 then + print("yes") + end + ]], [[ + + + + + + + + + + + + + + + + + + + local rs1 = { kind = "r1" } + + if rs1.kind == "r1" then + print("yes") + end + + local rs2 = { kind = "r2" } + + if rs2.kind == "r2" then + print("yes") + end + ]])) +end) From e3bf9fda2ec7b3c9953ef43827892d60879fcc8a Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Tue, 22 Aug 2023 09:46:28 -0300 Subject: [PATCH 009/224] accept indexing a union when index is the same is all entries --- tl.lua | 46 ++++++++++++++++++++++++++++++++++++++++++++++ tl.tl | 46 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 92 insertions(+) diff --git a/tl.lua b/tl.lua index 867fb7dbf..e0af83795 100644 --- a/tl.lua +++ b/tl.lua @@ -7573,6 +7573,34 @@ tl.type_check = function(ast, opts) end end + local function same_in_all_union_entries(u, check) + local t1, f = check(u.types[1]) + if not t1 then + return nil + end + for i = 2, #u.types do + local t2 = check(u.types[i]) + if not t2 or not same_type(t1, t2) then + return nil + end + end + return f or t1 + end + + local function same_call_mt_in_all_union_entries(tbl) + return same_in_all_union_entries(tbl, function(t) + t = resolve_tuple_and_nominal(t) + local call_mt = t.meta_fields and t.meta_fields["__call"] + if call_mt then + local args_tuple = a_type({ typename = "tuple" }) + for i = 2, #call_mt.args do + table.insert(args_tuple, call_mt.args[i]) + end + return args_tuple, call_mt + end + end) + end + local function resolve_for_call(func, args, is_method) if lax and is_unknown(func) then @@ -7582,6 +7610,14 @@ tl.type_check = function(ast, opts) func = resolve_tuple_and_nominal(func) if func.typename ~= "function" and func.typename ~= "poly" then + if func.typename == "union" then + local r = same_call_mt_in_all_union_entries(func) + if r then + table.insert(args, 1, func.types[1]) + return resolve_tuple_and_nominal(r), true + end + end + if is_typetype(func) and func.def.typename == "record" then func = func.def end @@ -7981,6 +8017,16 @@ tl.type_check = function(ast, opts) tbl = resolve_typetype(tbl) + if tbl.typename == "union" then + local t = same_in_all_union_entries(tbl, function(t) + return (match_record_key(t, rec, key)) + end) + + if t then + return t + end + end + if is_record_type(tbl) then assert(tbl.fields, "record has no fields!?") diff --git a/tl.tl b/tl.tl index 8f77bec7e..7f3536df6 100644 --- a/tl.tl +++ b/tl.tl @@ -7573,6 +7573,34 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end + local function same_in_all_union_entries(u: Type, check: function(Type): (Type, Type)): Type + local t1, f = check(u.types[1]) + if not t1 then + return nil + end + for i = 2, #u.types do + local t2 = check(u.types[i]) + if not t2 or not same_type(t1, t2) then + return nil + end + end + return f or t1 + end + + local function same_call_mt_in_all_union_entries(tbl: Type): Type + return same_in_all_union_entries(tbl, function(t: Type): (Type, Type) + t = resolve_tuple_and_nominal(t) + local call_mt = t.meta_fields and t.meta_fields["__call"] + if call_mt then + local args_tuple = a_type { typename = "tuple" } + for i = 2, #call_mt.args do + table.insert(args_tuple, call_mt.args[i]) + end + return args_tuple, call_mt + end + end) + end + local function resolve_for_call(func: Type, args: {Type}, is_method: boolean): Type, boolean -- resolve unknown in lax mode, produce a general unknown function if lax and is_unknown(func) then @@ -7581,6 +7609,14 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string -- unwrap if tuple, resolve if nominal func = resolve_tuple_and_nominal(func) if func.typename ~= "function" and func.typename ~= "poly" then + -- resolve if union + if func.typename == "union" then + local r = same_call_mt_in_all_union_entries(func) + if r then + table.insert(args, 1, func.types[1]) -- FIXME: is this right? + return resolve_tuple_and_nominal(r), true + end + end -- resolve if prototype if is_typetype(func) and func.def.typename == "record" then func = func.def @@ -7981,6 +8017,16 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string tbl = resolve_typetype(tbl) + if tbl.typename == "union" then + local t = same_in_all_union_entries(tbl, function(t: Type): (Type, Type) + return (match_record_key(t, rec, key)) + end) + + if t then + return t + end + end + if is_record_type(tbl) then assert(tbl.fields, "record has no fields!?") From bf791d1c6d90fd12939f89e8cb8dc01d60492a4f Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Sun, 27 Aug 2023 01:30:43 -0300 Subject: [PATCH 010/224] refactor: use __is and macroexps in Fact --- tl.lua | 261 +++++++++++++++++++++++++++++++++++---------- tl.tl | 327 +++++++++++++++++++++++++++++++++++++++++---------------- 2 files changed, 441 insertions(+), 147 deletions(-) diff --git a/tl.lua b/tl.lua index e0af83795..96036a8c2 100644 --- a/tl.lua +++ b/tl.lua @@ -1219,12 +1219,100 @@ local table_types = { -local Fact = {} + + + + + + + + + + + + + + + + + + +local TruthyFact = {} + + + + + + + + + + +local NotFact = {} + + + + + + + + + + + + +local AndFact = {} + + + + + + + + + + + + + +local OrFact = {} + + + + + + + + + + + + + +local EqFact = {} + + + + + + + + + + + + + +local IsFact = {} + + + + @@ -8000,7 +8088,9 @@ tl.type_check = function(ast, opts) end end - local function match_record_key(tbl, rec, key) + local match_record_key + + match_record_key = function(tbl, rec, key) assert(type(tbl) == "table") assert(type(rec) == "table") assert(type(key) == "string") @@ -8539,37 +8629,93 @@ tl.type_check = function(ast, opts) local apply_facts local FACT_TRUTHY do - setmetatable(Fact, { + local IsFact_mt = { + __tostring = function(f) + return ("(%s is %s)"):format(f.var, show_type(f.typ)) + end, + } + + setmetatable(IsFact, { __call = function(_, fact) - return setmetatable(fact, { - __tostring = function(f) - if f.fact == "is" then - return ("(%s is %s)"):format(f.var, show_type(f.typ)) - elseif f.fact == "==" then - return ("(%s == %s)"):format(f.var, show_type(f.typ)) - elseif f.fact == "truthy" then - return "*" - elseif f.fact == "not" then - return ("(not %s)"):format(tostring(f.f1)) - elseif f.fact == "or" then - return ("(%s or %s)"):format(tostring(f.f1), tostring(f.f2)) - elseif f.fact == "and" then - return ("(%s and %s)"):format(tostring(f.f1), tostring(f.f2)) - end - end, - }) + fact.fact = "is" + return setmetatable(fact, IsFact_mt) + end, + }) + + local EqFact_mt = { + __tostring = function(f) + return ("(%s == %s)"):format(f.var, show_type(f.typ)) + end, + } + + setmetatable(EqFact, { + __call = function(_, fact) + fact.fact = "==" + return setmetatable(fact, EqFact_mt) + end, + }) + + local TruthyFact_mt = { + __tostring = function(_f) + return "*" + end, + } + + setmetatable(TruthyFact, { + __call = function(_, fact) + fact.fact = "truthy" + return setmetatable(fact, TruthyFact_mt) + end, + }) + + local NotFact_mt = { + __tostring = function(f) + return ("(not %s)"):format(tostring(f.f1)) + end, + } + + setmetatable(NotFact, { + __call = function(_, fact) + fact.fact = "not" + return setmetatable(fact, NotFact_mt) + end, + }) + + local AndFact_mt = { + __tostring = function(f) + return ("(%s and %s)"):format(tostring(f.f1), tostring(f.f2)) + end, + } + + setmetatable(AndFact, { + __call = function(_, fact) + fact.fact = "and" + return setmetatable(fact, AndFact_mt) + end, + }) + + local OrFact_mt = { + __tostring = function(f) + return ("(%s or %s)"):format(tostring(f.f1), tostring(f.f2)) + end, + } + + setmetatable(OrFact, { + __call = function(_, fact) + fact.fact = "or" + return setmetatable(fact, OrFact_mt) end, }) - FACT_TRUTHY = Fact({ fact = "truthy" }) + FACT_TRUTHY = TruthyFact({}) facts_and = function(where, f1, f2) - return Fact({ fact = "and", f1 = f1, f2 = f2, where = where }) + return AndFact({ f1 = f1, f2 = f2, where = where }) end facts_or = function(where, f1, f2) if f1 and f2 then - return Fact({ fact = "or", f1 = f1, f2 = f2, where = where }) + return OrFact({ f1 = f1, f2 = f2, where = where }) else return nil end @@ -8577,7 +8723,7 @@ tl.type_check = function(ast, opts) facts_not = function(where, f1) if f1 then - return Fact({ fact = "not", f1 = f1, where = where }) + return NotFact({ f1 = f1, where = where }) else return nil end @@ -8661,35 +8807,31 @@ tl.type_check = function(ast, opts) local eval_fact local function invalid_from(f) - return Fact({ fact = "is", var = f.var, typ = INVALID, where = f.where }) + return IsFact({ fact = "is", var = f.var, typ = INVALID, where = f.where }) end not_facts = function(fs) local ret = {} for var, f in pairs(fs) do local typ = find_var_type(f.var, "check_only") - local fact = "==" - local where = f.where + if not typ then - typ = INVALID + ret[var] = EqFact({ var = var, typ = INVALID, where = f.where }) + elseif f.fact == "==" then + + ret[var] = EqFact({ var = var, typ = typ }) + elseif typ.typename == "typevar" then + assert(f.fact == "is") + + ret[var] = EqFact({ var = var, typ = typ }) + elseif not is_a(f.typ, typ) then + assert(f.fact == "is") + node_warning("branch", f.where, f.var .. " (of type %s) can never be a %s", show_type(typ), show_type(f.typ)) + ret[var] = EqFact({ var = var, typ = INVALID, where = f.where }) else - if f.fact == "is" then - if typ.typename == "typevar" then - - where = nil - elseif not is_a(f.typ, typ) then - node_warning("branch", f.where, f.var .. " (of type %s) can never be a %s", show_type(typ), show_type(f.typ)) - typ = INVALID - else - fact = "is" - typ = subtract_types(typ, f.typ) - end - elseif f.fact == "==" then - - where = nil - end + assert(f.fact == "is") + ret[var] = IsFact({ var = var, typ = subtract_types(typ, f.typ), where = f.where }) end - ret[var] = Fact({ fact = fact, var = var, typ = typ, where = where }) end return ret end @@ -8719,9 +8861,12 @@ tl.type_check = function(ast, opts) for var, f in pairs(fs2) do if fs1[var] then - local fact = (fs1[var].fact == "is" and f.fact == "is") and - "is" or "==" - ret[var] = Fact({ fact = fact, var = var, typ = unite_types(f.typ, fs1[var].typ), where = f.where }) + local united = unite_types(f.typ, fs1[var].typ) + if fs1[var].fact == "is" and f.fact == "is" then + ret[var] = IsFact({ var = var, typ = united, where = f.where }) + else + ret[var] = EqFact({ var = var, typ = united, where = f.where }) + end end end @@ -8734,21 +8879,23 @@ tl.type_check = function(ast, opts) for var, f in pairs(fs1) do local rt - local fact + local ctor = EqFact if fs2[var] then - fact = (fs2[var].fact == "is" and f.fact == "is") and "is" or "==" + if fs2[var].fact == "is" and f.fact == "is" then + ctor = IsFact + end rt = intersect_types(f.typ, fs2[var].typ) else - fact = "==" rt = f.typ end - ret[var] = Fact({ fact = fact, var = var, typ = rt, where = f.where }) - has[fact] = true + local ff = ctor({ var = var, typ = rt, where = f.where }) + ret[var] = ff + has[ff.fact] = true end for var, f in pairs(fs2) do if not fs1[var] then - ret[var] = Fact({ fact = "==", var = var, typ = f.typ, where = f.where }) + ret[var] = EqFact({ var = var, typ = f.typ, where = f.where }) has["=="] = true end end @@ -8793,7 +8940,7 @@ tl.type_check = function(ast, opts) return eval_not(f.f1) elseif f.fact == "and" then return and_facts(eval_fact(f.f1), eval_fact(f.f2)) - elseif f.fact == "or" then + else return or_facts(eval_fact(f.f1), eval_fact(f.f2)) end end @@ -10309,7 +10456,7 @@ tl.type_check = function(ast, opts) node_error(node, "can only use 'is' on variables, not types") elseif node.e1.kind == "variable" then check_metamethod(node, "__is", ra, resolve_typetype(rb), orig_a, orig_b) - node.known = Fact({ fact = "is", var = node.e1.tk, typ = b, where = node }) + node.known = IsFact({ var = node.e1.tk, typ = b, where = node }) else node_error(node, "can only use 'is' on variables") end @@ -10380,11 +10527,11 @@ tl.type_check = function(ast, opts) if is_a(b, a, true) or a.typename == "typevar" then if node.op.op == "==" and node.e1.kind == "variable" then - node.known = Fact({ fact = "==", var = node.e1.tk, typ = b, where = node }) + node.known = EqFact({ var = node.e1.tk, typ = b, where = node }) end elseif is_a(a, b, true) or b.typename == "typevar" then if node.op.op == "==" and node.e2.kind == "variable" then - node.known = Fact({ fact = "==", var = node.e2.tk, typ = a, where = node }) + node.known = EqFact({ var = node.e2.tk, typ = a, where = node }) end elseif lax and (is_unknown(a) or is_unknown(b)) then node.type = UNKNOWN diff --git a/tl.tl b/tl.tl index 7f3536df6..1bfcc8a51 100644 --- a/tl.tl +++ b/tl.tl @@ -1219,19 +1219,107 @@ local enum FactType "truthy" -- expression that is either truthy or a runtime error end -local record Fact +--local record Fact +-- fact: FactType +-- where: Node +-- +-- -- is +-- var: string +-- typ: Type +-- +-- -- not, and, or +-- f1: Fact +-- f2: Fact +-- +-- metamethod __call: function(Fact, Fact): Fact +--end + +local type Fact + = TruthyFact + | NotFact + | AndFact + | OrFact + | IsFact + | EqFact + +local record TruthyFact + metamethod __is: function(self: Fact): boolean = macroexp(self: TruthyFact): boolean + self.fact == "truthy" + end + fact: FactType where: Node - -- is - var: string - typ: Type + metamethod __call: function(Fact, Fact): TruthyFact +end + +local record NotFact + metamethod __is: function(self: Fact): boolean = macroexp(self: NotFact): boolean + self.fact == "not" + end + + fact: FactType + where: Node + + f1: Fact + + metamethod __call: function(Fact, Fact): NotFact +end + +local record AndFact + metamethod __is: function(self: Fact): boolean = macroexp(self: AndFact): boolean + self.fact == "and" + end + + fact: FactType + where: Node - -- not, and, or f1: Fact f2: Fact - metamethod __call: function(Fact, Fact): Fact + metamethod __call: function(Fact, Fact): AndFact +end + +local record OrFact + metamethod __is: function(self: Fact): boolean = macroexp(self: AndFact): boolean + self.fact == "or" + end + + fact: FactType + where: Node + + f1: Fact + f2: Fact + + metamethod __call: function(Fact, Fact): OrFact +end + +local record EqFact + metamethod __is: function(self: Fact): boolean = macroexp(self: AndFact): boolean + self.fact == "==" + end + + fact: FactType + where: Node + + var: string + typ: Type + + metamethod __call: function(Fact, Fact): EqFact +end + +local record IsFact + metamethod __is: function(self: Fact): boolean = macroexp(self: AndFact): boolean + self.fact == "is" + end + + fact: FactType + where: Node + + var: string + typ: Type + + metamethod __call: function(Fact, Fact): IsFact end local enum KeyParsed @@ -8000,7 +8088,9 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - local function match_record_key(tbl: Type, rec: Node, key: string): Type, string + local match_record_key: function(tbl: Type, rec: Node, key: string): Type, string + + match_record_key = function(tbl: Type, rec: Node, key: string): Type, string assert(type(tbl) == "table") assert(type(rec) == "table") assert(type(key) == "string") @@ -8539,37 +8629,93 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local apply_facts: function(where: Node, known: Fact) local FACT_TRUTHY: Fact do - setmetatable(Fact, { - __call = function(_: Fact, fact: Fact): Fact - return setmetatable(fact, { - __tostring = function(f: Fact): string - if f.fact == "is" then - return ("(%s is %s)"):format(f.var, show_type(f.typ)) - elseif f.fact == "==" then - return ("(%s == %s)"):format(f.var, show_type(f.typ)) - elseif f.fact == "truthy" then - return "*" - elseif f.fact == "not" then - return ("(not %s)"):format(tostring(f.f1)) - elseif f.fact == "or" then - return ("(%s or %s)"):format(tostring(f.f1), tostring(f.f2)) - elseif f.fact == "and" then - return ("(%s and %s)"):format(tostring(f.f1), tostring(f.f2)) - end - end - }) + local IsFact_mt: metatable = { + __tostring = function(f: IsFact): string + return ("(%s is %s)"):format(f.var, show_type(f.typ)) + end + } + + setmetatable(IsFact, { + __call = function(_: IsFact, fact: Fact): IsFact + fact.fact = "is" + return setmetatable(fact as IsFact, IsFact_mt) + end, + }) + + local EqFact_mt: metatable = { + __tostring = function(f: EqFact): string + return ("(%s == %s)"):format(f.var, show_type(f.typ)) + end + } + + setmetatable(EqFact, { + __call = function(_: EqFact, fact: Fact): EqFact + fact.fact = "==" + return setmetatable(fact as EqFact, EqFact_mt) + end, + }) + + local TruthyFact_mt: metatable = { + __tostring = function(_f: TruthyFact): string + return "*" + end + } + + setmetatable(TruthyFact, { + __call = function(_: TruthyFact, fact: Fact): TruthyFact + fact.fact = "truthy" + return setmetatable(fact as TruthyFact, TruthyFact_mt) + end, + }) + + local NotFact_mt: metatable = { + __tostring = function(f: NotFact): string + return ("(not %s)"):format(tostring(f.f1)) + end + } + + setmetatable(NotFact, { + __call = function(_: NotFact, fact: Fact): NotFact + fact.fact = "not" + return setmetatable(fact as NotFact, NotFact_mt) + end, + }) + + local AndFact_mt: metatable = { + __tostring = function(f: AndFact): string + return ("(%s and %s)"):format(tostring(f.f1), tostring(f.f2)) + end + } + + setmetatable(AndFact, { + __call = function(_: AndFact, fact: Fact): AndFact + fact.fact = "and" + return setmetatable(fact as AndFact, AndFact_mt) end, }) - FACT_TRUTHY = Fact { fact = "truthy" } + local OrFact_mt: metatable = { + __tostring = function(f: OrFact): string + return ("(%s or %s)"):format(tostring(f.f1), tostring(f.f2)) + end + } + + setmetatable(OrFact, { + __call = function(_: OrFact, fact: Fact): OrFact + fact.fact = "or" + return setmetatable(fact as OrFact, OrFact_mt) + end, + }) + + FACT_TRUTHY = TruthyFact {} facts_and = function(where: Node, f1: Fact, f2: Fact): Fact - return Fact({ fact = "and", f1 = f1, f2 = f2, where = where }) + return AndFact { f1 = f1, f2 = f2, where = where } end facts_or = function(where: Node, f1: Fact, f2: Fact): Fact if f1 and f2 then - return Fact { fact = "or", f1 = f1, f2 = f2, where = where } + return OrFact { f1 = f1, f2 = f2, where = where } else return nil end @@ -8577,7 +8723,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string facts_not = function(where: Node, f1: Fact): Fact if f1 then - return Fact { fact = "not", f1 = f1, where = where } + return NotFact { f1 = f1, where = where } else return nil end @@ -8654,101 +8800,102 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return unite(types) end - local eval_not: function(f: Fact): {string:Fact} - local not_facts: function(fs: {string:Fact}): {string:Fact} - local or_facts: function(fs1: {string:Fact}, fs2: {string:Fact}): {string:Fact} - local and_facts: function(fs1: {string:Fact}, fs2: {string:Fact}): {string:Fact} - local eval_fact: function(f: Fact): {string:Fact} + local eval_not: function(f: Fact): {string:IsFact|EqFact} + local not_facts: function(fs: {string:IsFact|EqFact}): {string:IsFact|EqFact} + local or_facts: function(fs1: {string:IsFact|EqFact}, fs2: {string:IsFact|EqFact}): {string:IsFact|EqFact} + local and_facts: function(fs1: {string:IsFact|EqFact}, fs2: {string:IsFact|EqFact}): {string:IsFact|EqFact} + local eval_fact: function(f: Fact): {string:IsFact|EqFact} - local function invalid_from(f: Fact): Fact - return Fact { fact = "is", var = f.var, typ = INVALID, where = f.where } + local function invalid_from(f: IsFact): IsFact + return IsFact { fact = "is", var = f.var, typ = INVALID, where = f.where } end - not_facts = function(fs: {string:Fact}): {string:Fact} - local ret: {string:Fact} = {} + not_facts = function(fs: {string:IsFact|EqFact}): {string:IsFact|EqFact} + local ret: {string:IsFact|EqFact} = {} for var, f in pairs(fs) do local typ = find_var_type(f.var, "check_only") - local fact: FactType = "==" - local where = f.where + if not typ then - typ = INVALID + ret[var] = EqFact { var = var, typ = INVALID, where = f.where } + elseif f is EqFact then + -- nothing is known from negation of equality; widen back + ret[var] = EqFact { var = var, typ = typ } + elseif typ.typename == "typevar" then + assert(f.fact == "is") + -- nothing is known from negation on typeargs; widen back (no 'where') + ret[var] = EqFact { var = var, typ = typ } + elseif not is_a(f.typ, typ) then + assert(f.fact == "is") + node_warning("branch", f.where, f.var .. " (of type %s) can never be a %s", show_type(typ), show_type(f.typ)) + ret[var] = EqFact { var = var, typ = INVALID, where = f.where } else - if f.fact == "is" then - if typ.typename == "typevar" then - -- nothing is known from negation on typeargs; widen back - where = nil - elseif not is_a(f.typ, typ) then - node_warning("branch", f.where, f.var .. " (of type %s) can never be a %s", show_type(typ), show_type(f.typ)) - typ = INVALID - else - fact = "is" - typ = subtract_types(typ, f.typ) - end - elseif f.fact == "==" then - -- nothing is known from negation of equality; widen back - where = nil - end + assert(f.fact == "is") + ret[var] = IsFact { var = var, typ = subtract_types(typ, f.typ), where = f.where } end - ret[var] = Fact { fact = fact, var = var, typ = typ, where = where } end return ret end - eval_not = function(f: Fact): {string:Fact} + eval_not = function(f: Fact): {string:IsFact|EqFact} if not f then return {} - elseif f.fact == "is" then + elseif f is IsFact then return not_facts({[f.var] = f}) - elseif f.fact == "not" then + elseif f is NotFact then return eval_fact(f.f1) - elseif f.fact == "and" and f.f2 and f.f2.fact == "truthy" then + elseif f is AndFact and f.f2 and f.f2.fact == "truthy" then return eval_not(f.f1) - elseif f.fact == "or" and f.f2 and f.f2.fact == "truthy" then + elseif f is OrFact and f.f2 and f.f2.fact == "truthy" then return eval_fact(f.f1) - elseif f.fact == "and" then + elseif f is AndFact then return or_facts(not_facts(eval_fact(f.f1)), not_facts(eval_fact(f.f2))) - elseif f.fact == "or" then + elseif f is OrFact then return and_facts(not_facts(eval_fact(f.f1)), not_facts(eval_fact(f.f2))) else return not_facts(eval_fact(f)) end end - or_facts = function(fs1: {string:Fact}, fs2: {string:Fact}): {string:Fact} - local ret: {string:Fact} = {} + or_facts = function(fs1: {string:IsFact|EqFact}, fs2: {string:IsFact|EqFact}): {string:IsFact|EqFact} + local ret: {string:IsFact|EqFact} = {} for var, f in pairs(fs2) do if fs1[var] then - local fact: FactType = (fs1[var].fact == "is" and f.fact == "is") - and "is" or "==" - ret[var] = Fact { fact = fact, var = var, typ = unite_types(f.typ, fs1[var].typ), where = f.where } + local united = unite_types(f.typ, fs1[var].typ) + if fs1[var].fact == "is" and f.fact == "is" then + ret[var] = IsFact { var = var, typ = united, where = f.where } + else + ret[var] = EqFact { var = var, typ = united, where = f.where } + end end end return ret end - and_facts = function(fs1: {string:Fact}, fs2: {string:Fact}): {string:Fact} - local ret: {string:Fact} = {} + and_facts = function(fs1: {string:IsFact|EqFact}, fs2: {string:IsFact|EqFact}): {string:IsFact|EqFact} + local ret: {string:IsFact|EqFact} = {} local has: {FactType:boolean} = {} for var, f in pairs(fs1) do local rt: Type - local fact: FactType + local ctor: IsFact | EqFact = EqFact if fs2[var] then - fact = (fs2[var].fact == "is" and f.fact == "is") and "is" or "==" + if fs2[var].fact == "is" and f.fact == "is" then + ctor = IsFact + end rt = intersect_types(f.typ, fs2[var].typ) else - fact = "==" rt = f.typ end - ret[var] = Fact { fact = fact, var = var, typ = rt, where = f.where } - has[fact] = true + local ff = ctor { var = var, typ = rt, where = f.where } + ret[var] = ff + has[ff.fact] = true end for var, f in pairs(fs2) do if not fs1[var] then - ret[var] = Fact { fact = "==", var = var, typ = f.typ, where = f.where } + ret[var] = EqFact { var = var, typ = f.typ, where = f.where } has["=="] = true end end @@ -8762,10 +8909,10 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return ret end - eval_fact = function(f: Fact): {string:Fact} + eval_fact = function(f: Fact): {string:IsFact|EqFact} if not f then return {} - elseif f.fact == "is" then + elseif f is IsFact then local typ = find_var_type(f.var, "check_only") if not typ then return { [f.var] = invalid_from(f) } @@ -8781,19 +8928,19 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end return { [f.var] = f } - elseif f.fact == "==" then + elseif f is EqFact then return { [f.var] = f } - elseif f.fact == "not" then + elseif f is NotFact then return eval_not(f.f1) - elseif f.fact == "truthy" then + elseif f is TruthyFact then return {} - elseif f.fact == "and" and f.f2 and f.f2.fact == "truthy" then + elseif f is AndFact and f.f2 and f.f2.fact == "truthy" then return eval_fact(f.f1) - elseif f.fact == "or" and f.f2 and f.f2.fact == "truthy" then + elseif f is OrFact and f.f2 and f.f2.fact == "truthy" then return eval_not(f.f1) - elseif f.fact == "and" then + elseif f is AndFact then return and_facts(eval_fact(f.f1), eval_fact(f.f2)) - elseif f.fact == "or" then + else -- f is OrFact return or_facts(eval_fact(f.f1), eval_fact(f.f2)) end end @@ -10309,7 +10456,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string node_error(node, "can only use 'is' on variables, not types") elseif node.e1.kind == "variable" then check_metamethod(node, "__is", ra, resolve_typetype(rb), orig_a, orig_b) - node.known = Fact { fact = "is", var = node.e1.tk, typ = b, where = node } + node.known = IsFact { var = node.e1.tk, typ = b, where = node } else node_error(node, "can only use 'is' on variables") end @@ -10380,11 +10527,11 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if is_a(b, a, true) or a.typename == "typevar" then if node.op.op == "==" and node.e1.kind == "variable" then - node.known = Fact { fact = "==", var = node.e1.tk, typ = b, where = node } + node.known = EqFact { var = node.e1.tk, typ = b, where = node } end elseif is_a(a, b, true) or b.typename == "typevar" then if node.op.op == "==" and node.e2.kind == "variable" then - node.known = Fact { fact = "==", var = node.e2.tk, typ = a, where = node } + node.known = EqFact { var = node.e2.tk, typ = a, where = node } end elseif lax and (is_unknown(a) or is_unknown(b)) then node.type = UNKNOWN From a40d062c700349e56f792fbae64ce6e5bba64c27 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Fri, 1 Sep 2023 22:22:36 -0300 Subject: [PATCH 011/224] refactor: use Where union, rename error functions --- tl.lua | 200 ++++++++++++++++++++-------------------- tl.tl | 282 +++++++++++++++++++++++++++++---------------------------- 2 files changed, 240 insertions(+), 242 deletions(-) diff --git a/tl.lua b/tl.lua index 96036a8c2..e8fa8aee0 100644 --- a/tl.lua +++ b/tl.lua @@ -1239,6 +1239,8 @@ local table_types = { + + @@ -5023,7 +5025,7 @@ local function show_type_base(t, short, seen) end local function inferred_msg(t) - return " (inferred at " .. t.inferred_at_file .. ":" .. t.inferred_at.y .. ":" .. t.inferred_at.x .. ")" + return " (inferred at " .. t.inferred_at.filename .. ":" .. t.inferred_at.y .. ":" .. t.inferred_at.x .. ")" end show_type = function(t, short, seen) @@ -6063,7 +6065,7 @@ tl.type_check = function(ast, opts) end end - local function error_in_type(where, msg, ...) + local function Err(where, msg, ...) local n = select("#", ...) if n > 0 then local showt = {} @@ -6087,8 +6089,8 @@ tl.type_check = function(ast, opts) } end - local function type_error(t, msg, ...) - local e = error_in_type(t, msg, ...) + local function error_at(w, msg, ...) + local e = Err(w, msg, ...) if e then table.insert(errors, e) return true @@ -6222,7 +6224,7 @@ tl.type_check = function(ast, opts) else errs = errors end - table.insert(errs, error_in_type(where, err, u)) + table.insert(errs, Err(where, err, u)) end if not valid then u = INVALID @@ -6430,28 +6432,24 @@ tl.type_check = function(ast, opts) return t end - local function node_warning(tag, node, fmt, ...) + local function add_warning(tag, where, fmt, ...) table.insert(warnings, { - y = node.y, - x = node.x, + y = where.y, + x = where.x, msg = fmt:format(...), - filename = filename, + filename = where.filename or filename, tag = tag, }) end local function node_error(node, msg, ...) - type_error(node, msg, ...) + error_at(node, msg, ...) node.type = INVALID return node.type end - local function terr(t, s, ...) - return { error_in_type(t, s, ...) } - end - local function add_unknown(node, name) - node_warning("unknown", node, "unknown variable: %s", name) + add_warning("unknown", node, "unknown variable: %s", name) end local function redeclaration_warning(node, old_var) @@ -6466,9 +6464,9 @@ tl.type_check = function(ast, opts) local short_error = "redeclaration of " .. var_kind .. " '%s'" if old_var and old_var.declared_at then - node_warning("redeclaration", node, short_error .. " (originally declared at %d:%d)", var_name, old_var.declared_at.y, old_var.declared_at.x) + add_warning("redeclaration", node, short_error .. " (originally declared at %d:%d)", var_name, old_var.declared_at.y, old_var.declared_at.x) else - node_warning("redeclaration", node, short_error, var_name) + add_warning("redeclaration", node, short_error, var_name) end end @@ -6487,9 +6485,9 @@ tl.type_check = function(ast, opts) prefix ~= "@" then if name:sub(1, 2) == "::" then - node_warning("unused", var.declared_at, "unused label %s", name) + add_warning("unused", var.declared_at, "unused label %s", name) else - node_warning( + add_warning( "unused", var.declared_at, "unused %s %s: %s", @@ -6543,7 +6541,7 @@ tl.type_check = function(ast, opts) end ret = (ret ~= t) and ret or shallow_copy_type(t) ret.inferred_at = where - ret.inferred_at_file = filename + ret.inferred_at.filename = filename return ret end @@ -6680,7 +6678,7 @@ tl.type_check = function(ast, opts) local t2k = t2(k) if t2k == nil then if (not lax) and invariant then - table.insert(fielderrs, error_in_type(f, "unknown field " .. k)) + table.insert(fielderrs, Err(f, "unknown field " .. k)) end else local ok, errs @@ -6702,7 +6700,7 @@ tl.type_check = function(ast, opts) local function match_fields_to_record(rec1, rec2, invariant) if rec1.is_userdata ~= rec2.is_userdata then - return false, { error_in_type(rec1, "userdata record doesn't match: %s", rec2) } + return false, { Err(rec1, "userdata record doesn't match: %s", rec2) } end local ok, fielderrs = match_record_fields(rec1, function(k) return rec2.fields[k] end, invariant) if not ok then @@ -6715,7 +6713,7 @@ tl.type_check = function(ast, opts) local function match_fields_to_map(rec1, map) if not match_record_fields(rec1, function(_) return map.values end) then - return false, { error_in_type(rec1, "record is not a valid map; not all fields have the same type") } + return false, { Err(rec1, "record is not a valid map; not all fields have the same type") } end return true end @@ -6895,7 +6893,7 @@ tl.type_check = function(ast, opts) local function match_typevals(t, def) if t.typevals and def.typeargs then if #t.typevals ~= #def.typeargs then - type_error(t, "mismatch in number of type arguments") + error_at(t, "mismatch in number of type arguments") return nil end @@ -6907,10 +6905,10 @@ tl.type_check = function(ast, opts) end_scope() return ret elseif t.typevals then - type_error(t, "spurious type arguments") + error_at(t, "spurious type arguments") return nil elseif def.typeargs then - type_error(t, "missing type arguments in %s", def) + error_at(t, "missing type arguments in %s", def) return nil else return def @@ -6926,7 +6924,7 @@ tl.type_check = function(ast, opts) local typetype = t.found or find_type(t.names) if not typetype then - type_error(t, "unknown type %s", t) + error_at(t, "unknown type %s", t) return INVALID elseif is_typetype(typetype) then if typetype.is_alias then @@ -6946,7 +6944,7 @@ tl.type_check = function(ast, opts) assert(typetype.def.typename ~= "nominal") resolved = match_typevals(t, typetype.def) else - type_error(t, table.concat(t.names, ".") .. " is not a type") + error_at(t, table.concat(t.names, ".") .. " is not a type") end if not resolved then @@ -6993,10 +6991,10 @@ tl.type_check = function(ast, opts) end if not ft1 then - type_error(t1, "unknown type %s", t1) + error_at(t1, "unknown type %s", t1) end if not ft2 then - type_error(t2, "unknown type %s", t2) + error_at(t2, "unknown type %s", t2) end return false, {} end @@ -7030,7 +7028,7 @@ tl.type_check = function(ast, opts) t2name = t2name .. " (defined in " .. t2r.filename .. ":" .. t2r.y .. ")" end end - return false, terr(t1, t1name .. " is not a " .. t2name) + return false, { Err(t1, t1name .. " is not a " .. t2name) } end end @@ -7075,7 +7073,7 @@ tl.type_check = function(ast, opts) end if t1.typename ~= t2.typename then - return false, terr(t1, "got %s, expected %s", t1, t2) + return false, { Err(t1, "got %s, expected %s", t1, t2) } end if t1.typename == "array" then return same_type(t1.elements, t2.elements) @@ -7104,7 +7102,7 @@ tl.type_check = function(ast, opts) has_all_types_of(t2.types, t1.types, same_type) then return true else - return false, terr(t1, "got %s, expected %s", t1, t2) + return false, { Err(t1, "got %s, expected %s", t1, t2) } end elseif t1.typename == "nominal" then return are_same_nominals(t1, t2) @@ -7114,12 +7112,12 @@ tl.type_check = function(ast, opts) local argdelta = t1.is_method and 1 or 0 if #t1.args ~= #t2.args then if t1.is_method ~= t2.is_method then - return false, terr(t1, "different number of input arguments: method and non-method are not the same type") + return false, { Err(t1, "different number of input arguments: method and non-method are not the same type") } end - return false, terr(t1, "different number of input arguments: got " .. #t1.args - argdelta .. ", expected " .. #t2.args - argdelta) + return false, { Err(t1, "different number of input arguments: got " .. #t1.args - argdelta .. ", expected " .. #t2.args - argdelta) } end if #t1.rets ~= #t2.rets then - return false, terr(t1, "different number of return values: got " .. #t1.rets .. ", expected " .. #t2.rets) + return false, { Err(t1, "different number of return values: got " .. #t1.rets .. ", expected " .. #t2.rets) } end local all_errs = {} for i = 1, #t1.args do @@ -7253,7 +7251,7 @@ tl.type_check = function(ast, opts) for i = 2, #tupletype.types do arr_type = expand_type(where, arr_type, a_type({ elements = tupletype.types[i], typename = "array" })) if not arr_type or not arr_type.elements then - return nil, terr(tupletype, "unable to convert tuple %s to array", tupletype) + return nil, { Err(tupletype, "unable to convert tuple %s to array", tupletype) } end end return arr_type @@ -7326,7 +7324,7 @@ tl.type_check = function(ast, opts) end end_scope() if not ok then - return false, terr(t1, "got %s, expected %s", t1, t2) + return false, { Err(t1, "got %s, expected %s", t1, t2) } end end @@ -7341,7 +7339,7 @@ tl.type_check = function(ast, opts) else for _, t in ipairs(t1.types) do if not is_a(t, t2, for_equality) then - return false, terr(t1, "got %s, expected %s", t1, t2) + return false, { Err(t1, "got %s, expected %s", t1, t2) } end end return true @@ -7363,7 +7361,7 @@ tl.type_check = function(ast, opts) elseif t2.typename == "poly" then for _, t in ipairs(t2.types) do if not is_a(t1, t, for_equality) then - return false, terr(t1, "cannot match against all alternatives of the polymorphic type") + return false, { Err(t1, "cannot match against all alternatives of the polymorphic type") } end end return true @@ -7377,7 +7375,7 @@ tl.type_check = function(ast, opts) return true end end - return false, terr(t1, "cannot match against any alternatives of the polymorphic type") + return false, { Err(t1, "cannot match against any alternatives of the polymorphic type") } elseif t1.typename == "nominal" and t2.typename == "nominal" then local t1r = resolve_tuple_and_nominal(t1) local t2r = resolve_tuple_and_nominal(t2) @@ -7396,7 +7394,7 @@ tl.type_check = function(ast, opts) if ok then return true else - return false, terr(t1, "enum is incompatible with %s", t2) + return false, { Err(t1, "enum is incompatible with %s", t2) } end elseif t1.typename == "integer" and t2.typename == "number" then return true @@ -7406,9 +7404,9 @@ tl.type_check = function(ast, opts) return true else if t1.tk then - return false, terr(t1, "%s is not a member of %s", t1, t2) + return false, { Err(t1, "%s is not a member of %s", t1, t2) } else - return false, terr(t1, "string is not a %s", t2) + return false, { Err(t1, "string is not a %s", t2) } end end elseif t1.typename == "nominal" or t2.typename == "nominal" then @@ -7419,7 +7417,7 @@ tl.type_check = function(ast, opts) if errs[1].msg:match("^got ") then - errs = terr(t1, "got %s, expected %s", t1, t2) + errs = { Err(t1, "got %s, expected %s", t1, t2) } end end return ok, errs @@ -7434,7 +7432,7 @@ tl.type_check = function(ast, opts) for i = 2, #t1.types do local t = t1.types[i] if not is_a(t, t2e) then - return false, terr(t, "%s is not a member of %s", t, t2e) + return false, { Err(t, "%s is not a member of %s", t, t2e) } end end end @@ -7442,14 +7440,14 @@ tl.type_check = function(ast, opts) end elseif t1.typename == "tupletable" then if t2.inferred_len and t2.inferred_len > #t1.types then - return false, terr(t1, "incompatible length, expected maximum length of " .. tostring(#t1.types) .. ", got " .. tostring(t2.inferred_len)) + return false, { Err(t1, "incompatible length, expected maximum length of " .. tostring(#t1.types) .. ", got " .. tostring(t2.inferred_len)) } end local t1a, err = arraytype_from_tuple(t1.inferred_at, t1) if not t1a then return false, err end if not is_a(t1a, t2) then - return false, terr(t2, "got %s (from %s), expected %s", t1a, t1, t2) + return false, { Err(t2, "got %s (from %s), expected %s", t1a, t1, t2) } end return true elseif t1.typename == "map" then @@ -7469,21 +7467,21 @@ tl.type_check = function(ast, opts) return is_a(t1.elements, t2.elements) elseif t1.typename == "tupletable" then if t2.inferred_len and t2.inferred_len > #t1.types then - return false, terr(t1, "incompatible length, expected maximum length of " .. tostring(#t1.types) .. ", got " .. tostring(t2.inferred_len)) + return false, { Err(t1, "incompatible length, expected maximum length of " .. tostring(#t1.types) .. ", got " .. tostring(t2.inferred_len)) } end local t1a, err = arraytype_from_tuple(t1.inferred_at, t1) if not t1a then return false, err end if not is_a(t1a, t2) then - return false, terr(t2, "got %s (from %s), expected %s", t1a, t1, t2) + return false, { Err(t2, "got %s (from %s), expected %s", t1a, t1, t2) } end return true elseif t1.typename == "record" then return match_fields_to_record(t1, t2) elseif t1.typename == "arrayrecord" then if not is_a(t1.elements, t2.elements) then - return false, terr(t1, "array parts have incompatible element types") + return false, { Err(t1, "array parts have incompatible element types") } end return match_fields_to_record(t1, t2) elseif is_typetype(t1) and is_record_type(t1.def) then @@ -7504,7 +7502,7 @@ tl.type_check = function(ast, opts) if t1.typename == "tupletable" then local arr_type = arraytype_from_tuple(t1.inferred_at, t1) if not arr_type then - return false, terr(t1, "Unable to convert tuple %s to map", t1) + return false, { Err(t1, "Unable to convert tuple %s to map", t1) } end elements = arr_type.elements else @@ -7516,12 +7514,12 @@ tl.type_check = function(ast, opts) return combine_map_errs(errs_keys, errs_values) elseif is_record_type(t1) then if not is_a(t2.keys, STRING) then - return false, terr(t1, "can't match a record to a map with non-string keys") + return false, { Err(t1, "can't match a record to a map with non-string keys") } end if t2.keys.typename == "enum" then for _, k in ipairs(t1.field_order) do if not t2.keys.enumset[k] then - return false, terr(t1, "key is not an enum value: " .. k) + return false, { Err(t1, "key is not an enum value: " .. k) } end end end @@ -7531,19 +7529,19 @@ tl.type_check = function(ast, opts) if t1.typename == "tupletable" then for i = 1, math.min(#t1.types, #t2.types) do if not is_a(t1.types[i], t2.types[i], for_equality) then - return false, terr(t1, "in tuple entry " .. tostring(i) .. ": got %s, expected %s", t1.types[i], t2.types[i]) + return false, { Err(t1, "in tuple entry " .. tostring(i) .. ": got %s, expected %s", t1.types[i], t2.types[i]) } end end if for_equality and #t1.types ~= #t2.types then - return false, terr(t1, "tuples are not the same size") + return false, { Err(t1, "tuples are not the same size") } end if #t1.types > #t2.types then - return false, terr(t1, "tuple %s is too big for tuple %s", t1, t2) + return false, { Err(t1, "tuple %s is too big for tuple %s", t1, t2) } end return true elseif is_array_type(t1) then if t1.inferred_len and t1.inferred_len > #t2.types then - return false, terr(t1, "incompatible length, expected maximum length of " .. tostring(#t2.types) .. ", got " .. tostring(t1.inferred_len)) + return false, { Err(t1, "incompatible length, expected maximum length of " .. tostring(#t2.types) .. ", got " .. tostring(t1.inferred_len)) } end @@ -7554,7 +7552,7 @@ tl.type_check = function(ast, opts) for i = 1, len do if not is_a(t1.elements, t2.types[i], for_equality) then - return false, terr(t1, "tuple entry " .. tostring(i) .. " of type %s does not match type of array elements, which is %s", t2.types[i], t1.elements) + return false, { Err(t1, "tuple entry " .. tostring(i) .. " of type %s does not match type of array elements, which is %s", t2.types[i], t1.elements) } end end return true @@ -7562,7 +7560,7 @@ tl.type_check = function(ast, opts) elseif t1.typename == "function" and t2.typename == "function" then local all_errs = {} if (not t2.args.is_va) and #t1.args > #t2.args then - table.insert(all_errs, error_in_type(t1, "incompatible number of arguments: got " .. #t1.args .. " %s, expected " .. #t2.args .. " %s", t1.args, t2.args)) + table.insert(all_errs, Err(t1, "incompatible number of arguments: got " .. #t1.args .. " %s, expected " .. #t2.args .. " %s", t1.args, t2.args)) else for i = ((t1.is_method or t2.is_method) and 2 or 1), #t1.args do arg_check(nil, is_a, t1.args[i], t2.args[i] or ANY, i, all_errs, "argument") @@ -7570,7 +7568,7 @@ tl.type_check = function(ast, opts) end local diff_by_va = #t2.rets - #t1.rets == 1 and t2.rets.is_va if #t1.rets < #t2.rets and not diff_by_va then - table.insert(all_errs, error_in_type(t1, "incompatible number of returns: got " .. #t1.rets .. " %s, expected " .. #t2.rets .. " %s", t1.rets, t2.rets)) + table.insert(all_errs, Err(t1, "incompatible number of returns: got " .. #t1.rets .. " %s, expected " .. #t2.rets .. " %s", t1.rets, t2.rets)) else local nrets = #t2.rets if diff_by_va then @@ -7593,7 +7591,7 @@ tl.type_check = function(ast, opts) return true end - return false, terr(t1, "got %s, expected %s", t1, t2) + return false, { Err(t1, "got %s, expected %s", t1, t2) } end local function assert_is_a(node, t1, t2, context, name) @@ -7958,7 +7956,7 @@ tl.type_check = function(ast, opts) return resolve_typevars_at(node, f.rets) end - local function check_call(where, where_args, func, args, is_method, argdelta) + local function check_call(node, where_args, func, args, is_method, argdelta) assert(type(func) == "table") assert(type(args) == "table") @@ -7971,7 +7969,7 @@ tl.type_check = function(ast, opts) local is_func = func.typename == "function" local is_poly = func.typename == "poly" if not (is_func or is_poly) then - return node_error(where, "not a function: %s", func) + return node_error(node, "not a function: %s", func) end local passes, n = 1, 1 @@ -7989,14 +7987,14 @@ tl.type_check = function(ast, opts) if f.is_method and not is_method then if args[1] and is_a(args[1], f.args[1]) then - if where.kind == "op" and where.op.op == "@funcall" then - local receiver_is_typetype = where.e1.e1 and where.e1.e1.type and where.e1.e1.type.resolved and where.e1.e1.type.resolved.typename == "typetype" + if node.kind == "op" and node.op.op == "@funcall" then + local receiver_is_typetype = node.e1.e1 and node.e1.e1.type and node.e1.e1.type.resolved and node.e1.e1.type.resolved.typename == "typetype" if not receiver_is_typetype then - node_warning("hint", where, "invoked method as a regular function: consider using ':' instead of '.'") + add_warning("hint", node, "invoked method as a regular function: consider using ':' instead of '.'") end end else - return node_error(where, "invoked method as a regular function: use ':' instead of '.'") + return node_error(node, "invoked method as a regular function: use ':' instead of '.'") end end local expected = #f.args @@ -8012,16 +8010,16 @@ tl.type_check = function(ast, opts) push_typeargs(f) - local matched, errs = check_args_rets(where, where_args, f, args, where.expected, argdelta) + local matched, errs = check_args_rets(node, where_args, f, args, node.expected, argdelta) if matched then return matched, f end first_errs = first_errs or errs - if where.expected then + if node.expected then - infer_emptytables(where, where_args, f.rets, f.rets, argdelta) + infer_emptytables(node, where_args, f.rets, f.rets, argdelta) end if is_poly then @@ -8034,24 +8032,24 @@ tl.type_check = function(ast, opts) end end - return fail_call(where, func, given, first_errs) + return fail_call(node, func, given, first_errs) end - type_check_function_call = function(where, where_args, func, args, e1, is_method, argdelta) - if where.expected and where.expected.typename ~= "tuple" then - where.expected = a_type({ typename = "tuple", where.expected }) + type_check_function_call = function(node, where_args, func, args, e1, is_method, argdelta) + if node.expected and node.expected.typename ~= "tuple" then + node.expected = a_type({ typename = "tuple", node.expected }) end begin_scope() - local ret, f = check_call(where, where_args, func, args, is_method, argdelta) - ret = resolve_typevars_at(where, ret) + local ret, f = check_call(node, where_args, func, args, is_method, argdelta) + ret = resolve_typevars_at(node, ret) end_scope() if e1 then e1.type = f end if func.macroexp then - expand_macroexp(where, where_args, func.macroexp) + expand_macroexp(node, where_args, func.macroexp) end return ret @@ -8288,7 +8286,7 @@ tl.type_check = function(ast, opts) for _, t in ipairs(node.typeargs) do local v = find_var(t.typearg, "check_only") if not v or not v.used_as_type then - type_error(t, "type argument '%s' is not used in function signature", t) + error_at(t, "type argument '%s' is not used in function signature", t) end end end @@ -8321,7 +8319,7 @@ tl.type_check = function(ast, opts) for _, typ in ipairs(types) do assert(typ.x) assert(typ.y) - type_error(typ, "unknown type %s", typ) + error_at(typ, "unknown type %s", typ) end end end @@ -8447,9 +8445,7 @@ tl.type_check = function(ast, opts) return a.elements elseif a.typename == "emptytable" then if a.keys == nil then - a.keys = resolve_tuple(orig_b) - a.keys_inferred_at = assert(anode) - a.keys_inferred_at_file = filename + a.keys = infer_at(anode, resolve_tuple(orig_b)) end if is_a(orig_b, a.keys) then @@ -8457,9 +8453,9 @@ tl.type_check = function(ast, opts) end errm, erra, errb = "inconsistent index type: got %s, expected %s (type of keys inferred at " .. - a.keys_inferred_at_file .. ":" .. - a.keys_inferred_at.y .. ":" .. - a.keys_inferred_at.x .. ": )", orig_b, a.keys + a.keys.inferred_at.filename .. ":" .. + a.keys.inferred_at.y .. ":" .. + a.keys.inferred_at.x .. ": )", orig_b, a.keys elseif a.typename == "map" then if is_a(orig_b, a.keys) then return a.values @@ -8514,7 +8510,7 @@ tl.type_check = function(ast, opts) old.values = expand_type(where, old.values, ftype) end else - node_error(where, "cannot determine table literal type") + error_at(where, "cannot determine table literal type") end elseif is_record_type(old) and is_record_type(new) then old.typename = "map" @@ -8826,7 +8822,7 @@ tl.type_check = function(ast, opts) ret[var] = EqFact({ var = var, typ = typ }) elseif not is_a(f.typ, typ) then assert(f.fact == "is") - node_warning("branch", f.where, f.var .. " (of type %s) can never be a %s", show_type(typ), show_type(f.typ)) + add_warning("branch", f.where, f.var .. " (of type %s) can never be a %s", show_type(typ), show_type(f.typ)) ret[var] = EqFact({ var = var, typ = INVALID, where = f.where }) else assert(f.fact == "is") @@ -8923,7 +8919,7 @@ tl.type_check = function(ast, opts) return { [f.var] = f } elseif not is_a(f.typ, typ) then - node_error(f.where, f.var .. " (of type %s) can never be a %s", typ, f.typ) + error_at(f.where, f.var .. " (of type %s) can never be a %s", typ, f.typ) return { [f.var] = invalid_from(f) } end end @@ -8954,7 +8950,7 @@ tl.type_check = function(ast, opts) for v, f in pairs(facts) do if f.typ.typename == "invalid" then - node_error(where, "cannot resolve a type for " .. v .. " here") + error_at(where, "cannot resolve a type for " .. v .. " here") end local t = infer_at(where, f.typ) if not f.where then @@ -9100,7 +9096,7 @@ tl.type_check = function(ast, opts) else resolved = find_type(names) if (not resolved) or (not is_typetype(resolved)) then - type_error(typetype, "%s is not a type", typetype) + error_at(typetype, "%s is not a type", typetype) resolved = a_type({ typename = "bad_nominal", names = names }) end end @@ -9639,7 +9635,7 @@ tl.type_check = function(ast, opts) local msg = #rets == 1 and "only 1 value is returned by the function" or ("only " .. #rets .. " values are returned by the function") - node_warning("hint", varnode, msg) + add_warning("hint", varnode, msg) end end end @@ -9766,7 +9762,7 @@ tl.type_check = function(ast, opts) if exp1.op and exp1.op.op == "@funcall" then local t = resolve_tuple_and_nominal(exp1.e2.type) if exp1.e1.tk == "pairs" and is_array_type(t) then - node_warning("hint", exp1, "hint: applying pairs on an array: did you intend to apply ipairs?") + add_warning("hint", exp1, "hint: applying pairs on an array: did you intend to apply ipairs?") end if exp1.e1.tk == "pairs" and t.typename ~= "map" then @@ -9775,7 +9771,7 @@ tl.type_check = function(ast, opts) match_all_record_field_names(exp1.e2, t, t.field_order, "attempting pairs loop on a record with attributes of different types") local ct = t.typename == "record" and "{string:any}" or "{any:any}" - node_warning("hint", exp1.e2, "hint: if you want to iterate over fields of a record, cast it to " .. ct) + add_warning("hint", exp1.e2, "hint: if you want to iterate over fields of a record, cast it to " .. ct) else node_error(exp1.e2, "cannot apply pairs on values of type: %s", exp1.e2.type) end @@ -9878,7 +9874,7 @@ tl.type_check = function(ast, opts) node.exps[1].kind == "op" and (node.exps[1].op.op == "and" or node.exps[1].op.op == "or") and #node.exps[1].e2.type > 1 then - node_warning("hint", node.exps[1].e2, "additional return values are being discarded due to '" .. node.exps[1].op.op .. "' expression; suggest parentheses if intentional") + add_warning("hint", node.exps[1].e2, "additional return values are being discarded due to '" .. node.exps[1].op.op .. "' expression; suggest parentheses if intentional") end for i = 1, #children[1] do @@ -10045,12 +10041,10 @@ tl.type_check = function(ast, opts) end if force_array then - node.type = a_type({ - inferred_at = node, - inferred_at_file = filename, + node.type = infer_at(node, a_type({ typename = "array", elements = force_array, - }) + })) else node.type = resolve_typevars_at(node, node.expected) if node.expected == node.type and node.type.typename == "nominal" then @@ -10558,7 +10552,7 @@ tl.type_check = function(ast, opts) end if a.typename == "map" then if a.keys.typename == "number" or a.keys.typename == "integer" then - node_warning("hint", node, "using the '#' operator on a map with numeric key type may produce unexpected results") + add_warning("hint", node, "using the '#' operator on a map with numeric key type may produce unexpected results") else node_error(node, "using the '#' operator on this map will always return 0") end @@ -10604,7 +10598,7 @@ tl.type_check = function(ast, opts) if not node.type then node_error(node, "cannot use operator '" .. node.op.op:gsub("%%", "%%%%") .. "' for types %s and %s", resolve_tuple(orig_a), resolve_tuple(orig_b)) if node.op.op == "or" and is_valid_union(unite({ orig_a, orig_b })) then - node_warning("hint", node, "if a union type was intended, consider declaring it explicitly") + add_warning("hint", node, "if a union type was intended, consider declaring it explicitly") end end end @@ -10808,7 +10802,7 @@ tl.type_check = function(ast, opts) check_macroexp_arg_use(typ.macroexp) if not is_a(typ.macroexp.type, typ) then - type_error(typ.macroexp.type, "macroexp type does not match declaration") + error_at(typ.macroexp.type, "macroexp type does not match declaration") end end @@ -10876,7 +10870,7 @@ tl.type_check = function(ast, opts) ["typevar"] = { after = function(typ, _children) if not find_var_type(typ.typevar) then - type_error(typ, "undefined type variable " .. typ.typevar) + error_at(typ, "undefined type variable " .. typ.typevar) end return typ end, diff --git a/tl.tl b/tl.tl index 1bfcc8a51..92579616b 100644 --- a/tl.tl +++ b/tl.tl @@ -1065,6 +1065,11 @@ local table_types : {TypeName:boolean} = { local record Type {Type} + + metamethod __is: function(self: Type) = macroexp(self: Type) + self.typename ~= nil + end + y: integer x: integer filename: string @@ -1137,10 +1142,7 @@ local record Type -- emptytable declared_at: Node assigned_to: string - keys_inferred_at: Node - keys_inferred_at_file: string - inferred_at: Node - inferred_at_file: string + inferred_at: Where emptytable_type: Type -- enum @@ -1248,7 +1250,7 @@ local record TruthyFact end fact: FactType - where: Node + where: Where metamethod __call: function(Fact, Fact): TruthyFact end @@ -1259,7 +1261,7 @@ local record NotFact end fact: FactType - where: Node + where: Where f1: Fact @@ -1272,7 +1274,7 @@ local record AndFact end fact: FactType - where: Node + where: Where f1: Fact f2: Fact @@ -1286,7 +1288,7 @@ local record OrFact end fact: FactType - where: Node + where: Where f1: Fact f2: Fact @@ -1300,7 +1302,7 @@ local record EqFact end fact: FactType - where: Node + where: Where var: string typ: Type @@ -1314,7 +1316,7 @@ local record IsFact end fact: FactType - where: Node + where: Where var: string typ: Type @@ -1344,6 +1346,10 @@ local is_attribute : {string:boolean} = attributes as {string:boolean} local record Node {Node} + metamethod __is: function(self: Node) = macroexp(self: Node) + self.kind ~= nil + end + record ExpectedContext kind: NodeKind name: string @@ -1351,6 +1357,8 @@ local record Node y: integer x: integer + filename: string + tk: string kind: NodeKind symbol_list_slot: integer @@ -1438,6 +1446,10 @@ local record Node decltype: Type end +local type Where + = Node + | Type + local function is_array_type(t:Type): boolean return t.typename == "array" or t.typename == "arrayrecord" end @@ -5023,7 +5035,7 @@ local function show_type_base(t: Type, short: boolean, seen: {Type:string}): str end local function inferred_msg(t: Type): string - return " (inferred at "..t.inferred_at_file..":"..t.inferred_at.y..":"..t.inferred_at.x..")" + return " (inferred at "..t.inferred_at.filename..":"..t.inferred_at.y..":"..t.inferred_at.x..")" end show_type = function(t: Type, short: boolean, seen: {Type:string}): string @@ -6063,7 +6075,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - local function error_in_type(where: Type, msg: string, ...: Type): Error + local function Err(where: Where, msg: string, ...: Type): Error local n = select("#", ...) if n > 0 then local showt = {} @@ -6087,8 +6099,8 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string } end - local function type_error(t: Type, msg: string, ...:Type): boolean - local e = error_in_type(t, msg, ...) + local function error_at(w: Where, msg: string, ...:Type): boolean + local e = Err(w, msg, ...) if e then table.insert(errors, e) return true @@ -6214,7 +6226,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return true end - local function validate_union(where: Node, u: Type, store_errs: boolean, errs: {Error}): Type, {Error} + local function validate_union(where: Where, u: Type, store_errs: boolean, errs: {Error}): Type, {Error} local valid, err = is_valid_union(u) if err then if store_errs then @@ -6222,7 +6234,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string else errs = errors end - table.insert(errs, error_in_type(where as Type, err, u)) + table.insert(errs, Err(where, err, u)) end if not valid then u = INVALID @@ -6378,7 +6390,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string copy.types[i], same = resolve(tf, same) end - copy, errs = validate_union(t as Node, copy, true, errs) + copy, errs = validate_union(t, copy, true, errs) elseif t.typename == "poly" or t.typename == "tupletable" then copy.types = {} for i, tf in ipairs(t.types) do @@ -6430,28 +6442,24 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return t end - local function node_warning(tag: tl.WarningKind, node: Node, fmt: string, ...: any) + local function add_warning(tag: tl.WarningKind, where: Where, fmt: string, ...: any) table.insert(warnings, { - y = node.y, - x = node.x, + y = where.y, + x = where.x, msg = fmt:format(...), - filename = filename, + filename = where.filename or filename, tag = tag, }) end local function node_error(node: Node, msg: string, ...:Type): Type - type_error(node as Type, msg, ...) + error_at(node, msg, ...) node.type = INVALID return node.type end - local function terr(t: Type, s: string, ...: Type): {Error} - return { error_in_type(t, s, ...) } - end - local function add_unknown(node: Node, name: string) - node_warning("unknown", node, "unknown variable: %s", name) + add_warning("unknown", node, "unknown variable: %s", name) end local function redeclaration_warning(node: Node, old_var: Variable) @@ -6466,9 +6474,9 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local short_error = "redeclaration of " .. var_kind .. " '%s'" if old_var and old_var.declared_at then - node_warning("redeclaration", node, short_error .. " (originally declared at %d:%d)", var_name, old_var.declared_at.y, old_var.declared_at.x) + add_warning("redeclaration", node, short_error .. " (originally declared at %d:%d)", var_name, old_var.declared_at.y, old_var.declared_at.x) else - node_warning("redeclaration", node, short_error, var_name) + add_warning("redeclaration", node, short_error, var_name) end end @@ -6487,9 +6495,9 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string and prefix ~= "@" then if name:sub(1, 2) == "::" then - node_warning("unused", var.declared_at, "unused label %s", name) + add_warning("unused", var.declared_at, "unused label %s", name) else - node_warning( + add_warning( "unused", var.declared_at, "unused %s %s: %s", @@ -6504,14 +6512,14 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - local function add_errs_prefixing(where: Node, src: {Error}, dst: {Error}, prefix: string) + local function add_errs_prefixing(where: Where, src: {Error}, dst: {Error}, prefix: string) if not src then return end for _, err in ipairs(src) do err.msg = prefix .. err.msg - -- where.y may be nil because of `typ as Node` casts and not all types have .y set + -- where.y may be nil because not all types have .y set if where and where.y and ( (err.filename ~= filename) or (not err.y) @@ -6526,7 +6534,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - local function resolve_typevars_at(where: Node, t: Type): Type + local function resolve_typevars_at(where: Where, t: Type): Type assert(where) local ok, typ, errs = resolve_typevars(t) if not ok then @@ -6536,14 +6544,14 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return typ end - local function infer_at(where: Node, t: Type): Type + local function infer_at(where: Where, t: Type): Type local ret = resolve_typevars_at(where, t) if ret.typename == "invalid" then ret = t -- errors are produced by resolve_typevars_at end ret = (ret ~= t) and ret or shallow_copy_type(t) ret.inferred_at = where - ret.inferred_at_file = filename + ret.inferred_at.filename = filename return ret end @@ -6680,7 +6688,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local t2k = t2(k) if t2k == nil then if (not lax) and invariant then - table.insert(fielderrs, error_in_type(f, "unknown field " .. k)) + table.insert(fielderrs, Err(f, "unknown field " .. k)) end else local ok, errs: boolean, {Error} @@ -6702,7 +6710,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local function match_fields_to_record(rec1: Type, rec2: Type, invariant: boolean): boolean, {Error} if rec1.is_userdata ~= rec2.is_userdata then - return false, { error_in_type(rec1, "userdata record doesn't match: %s", rec2) } + return false, { Err(rec1, "userdata record doesn't match: %s", rec2) } end local ok, fielderrs = match_record_fields(rec1, function(k: string): Type return rec2.fields[k] end, invariant) if not ok then @@ -6715,12 +6723,12 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local function match_fields_to_map(rec1: Type, map: Type): boolean, {Error} if not match_record_fields(rec1, function(_: string): Type return map.values end) then - return false, { error_in_type(rec1, "record is not a valid map; not all fields have the same type") } + return false, { Err(rec1, "record is not a valid map; not all fields have the same type") } end return true end - local function arg_check(where: Node, cmp: CompareTypes, a: Type, b: Type, n: integer, errs: {Error}, ctx: string): boolean + local function arg_check(where: Where, cmp: CompareTypes, a: Type, b: Type, n: integer, errs: {Error}, ctx: string): boolean local matches, match_errs = cmp(a, b) if not matches then add_errs_prefixing(where, match_errs, errs, ctx .. (n and " " .. n or "") .. ": ") @@ -6895,7 +6903,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local function match_typevals(t: Type, def: Type): Type if t.typevals and def.typeargs then if #t.typevals ~= #def.typeargs then - type_error(t, "mismatch in number of type arguments") + error_at(t, "mismatch in number of type arguments") return nil end @@ -6907,10 +6915,10 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end_scope() return ret elseif t.typevals then - type_error(t, "spurious type arguments") + error_at(t, "spurious type arguments") return nil elseif def.typeargs then - type_error(t, "missing type arguments in %s", def) + error_at(t, "missing type arguments in %s", def) return nil else return def @@ -6926,7 +6934,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local typetype = t.found or find_type(t.names) if not typetype then - type_error(t, "unknown type %s", t) + error_at(t, "unknown type %s", t) return INVALID elseif is_typetype(typetype) then if typetype.is_alias then @@ -6946,7 +6954,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string assert(typetype.def.typename ~= "nominal") resolved = match_typevals(t, typetype.def) else - type_error(t, table.concat(t.names, ".") .. " is not a type") + error_at(t, table.concat(t.names, ".") .. " is not a type") end if not resolved then @@ -6993,10 +7001,10 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end if not ft1 then - type_error(t1, "unknown type %s", t1) + error_at(t1, "unknown type %s", t1) end if not ft2 then - type_error(t2, "unknown type %s", t2) + error_at(t2, "unknown type %s", t2) end return false, {} -- errors were already produced end @@ -7030,7 +7038,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string t2name = t2name .. " (defined in " .. t2r.filename .. ":" .. t2r.y .. ")" end end - return false, terr(t1, t1name .. " is not a " .. t2name) + return false, { Err(t1, t1name .. " is not a " .. t2name) } end end @@ -7075,7 +7083,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end if t1.typename ~= t2.typename then - return false, terr(t1, "got %s, expected %s", t1, t2) + return false, { Err(t1, "got %s, expected %s", t1, t2) } end if t1.typename == "array" then return same_type(t1.elements, t2.elements) @@ -7104,7 +7112,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string and has_all_types_of(t2.types, t1.types, same_type) then return true else - return false, terr(t1, "got %s, expected %s", t1, t2) + return false, { Err(t1, "got %s, expected %s", t1, t2) } end elseif t1.typename == "nominal" then return are_same_nominals(t1, t2) @@ -7114,12 +7122,12 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local argdelta = t1.is_method and 1 or 0 if #t1.args ~= #t2.args then if t1.is_method ~= t2.is_method then - return false, terr(t1, "different number of input arguments: method and non-method are not the same type") + return false, { Err(t1, "different number of input arguments: method and non-method are not the same type") } end - return false, terr(t1, "different number of input arguments: got " .. #t1.args - argdelta .. ", expected " .. #t2.args - argdelta) + return false, { Err(t1, "different number of input arguments: got " .. #t1.args - argdelta .. ", expected " .. #t2.args - argdelta) } end if #t1.rets ~= #t2.rets then - return false, terr(t1, "different number of return values: got " .. #t1.rets .. ", expected " .. #t2.rets) + return false, { Err(t1, "different number of return values: got " .. #t1.rets .. ", expected " .. #t2.rets) } end local all_errs = {} for i = 1, #t1.args do @@ -7233,8 +7241,8 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return known_table_types[t.typename] and not t.is_userdata end - local expand_type: function(where: Node, old: Type, new: Type): Type - local function arraytype_from_tuple(where: Node, tupletype: Type): Type, {Error} + local expand_type: function(where: Where, old: Type, new: Type): Type + local function arraytype_from_tuple(where: Where, tupletype: Type): Type, {Error} -- first just try a basic union local element_type = unite(tupletype.types, true) local valid = element_type.typename ~= "union" and true or is_valid_union(element_type) @@ -7253,7 +7261,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string for i = 2, #tupletype.types do arr_type = expand_type(where, arr_type, a_type { elements = tupletype.types[i], typename = "array" }) if not arr_type or not arr_type.elements then - return nil, terr(tupletype, "unable to convert tuple %s to array", tupletype) + return nil, { Err(tupletype, "unable to convert tuple %s to array", tupletype) } end end return arr_type @@ -7326,7 +7334,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end_scope() -- don't preserve failed inferences if not ok then - return false, terr(t1, "got %s, expected %s", t1, t2) + return false, { Err(t1, "got %s, expected %s", t1, t2) } end end -- preserve all valid inferences @@ -7341,7 +7349,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string else for _, t in ipairs(t1.types) do if not is_a(t, t2, for_equality) then - return false, terr(t1, "got %s, expected %s", t1, t2) + return false, { Err(t1, "got %s, expected %s", t1, t2) } end end return true @@ -7363,7 +7371,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string elseif t2.typename == "poly" then for _, t in ipairs(t2.types) do if not is_a(t1, t, for_equality) then - return false, terr(t1, "cannot match against all alternatives of the polymorphic type") + return false, { Err(t1, "cannot match against all alternatives of the polymorphic type") } end end return true @@ -7377,7 +7385,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return true end end - return false, terr(t1, "cannot match against any alternatives of the polymorphic type") + return false, { Err(t1, "cannot match against any alternatives of the polymorphic type") } elseif t1.typename == "nominal" and t2.typename == "nominal" then local t1r = resolve_tuple_and_nominal(t1) local t2r = resolve_tuple_and_nominal(t2) @@ -7396,7 +7404,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if ok then return true else - return false, terr(t1, "enum is incompatible with %s", t2) + return false, { Err(t1, "enum is incompatible with %s", t2) } end elseif t1.typename == "integer" and t2.typename == "number" then return true @@ -7406,9 +7414,9 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return true else if t1.tk then - return false, terr(t1, "%s is not a member of %s", t1, t2) + return false, { Err(t1, "%s is not a member of %s", t1, t2) } else - return false, terr(t1, "string is not a %s", t2) + return false, { Err(t1, "string is not a %s", t2) } end end elseif t1.typename == "nominal" or t2.typename == "nominal" then @@ -7419,7 +7427,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if errs[1].msg:match("^got ") then --local got = t1.typename == "nominal" and t1.name or show_type(t1) --local expected = t2.typename == "nominal" and t2.name or show_type(t2) - errs = terr(t1, "got %s, expected %s", t1, t2) + errs = { Err(t1, "got %s, expected %s", t1, t2) } end end return ok, errs @@ -7434,7 +7442,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string for i = 2, #t1.types do local t = t1.types[i] if not is_a(t, t2e) then - return false, terr(t, "%s is not a member of %s", t, t2e) + return false, { Err(t, "%s is not a member of %s", t, t2e) } end end end @@ -7442,14 +7450,14 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end elseif t1.typename == "tupletable" then if t2.inferred_len and t2.inferred_len > #t1.types then - return false, terr(t1, "incompatible length, expected maximum length of " .. tostring(#t1.types) .. ", got " .. tostring(t2.inferred_len)) + return false, { Err(t1, "incompatible length, expected maximum length of " .. tostring(#t1.types) .. ", got " .. tostring(t2.inferred_len)) } end local t1a, err = arraytype_from_tuple(t1.inferred_at, t1) if not t1a then return false, err end if not is_a(t1a, t2) then - return false, terr(t2, "got %s (from %s), expected %s", t1a, t1, t2) + return false, { Err(t2, "got %s (from %s), expected %s", t1a, t1, t2) } end return true elseif t1.typename == "map" then @@ -7469,21 +7477,21 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return is_a(t1.elements, t2.elements) elseif t1.typename == "tupletable" then if t2.inferred_len and t2.inferred_len > #t1.types then - return false, terr(t1, "incompatible length, expected maximum length of " .. tostring(#t1.types) .. ", got " .. tostring(t2.inferred_len)) + return false, { Err(t1, "incompatible length, expected maximum length of " .. tostring(#t1.types) .. ", got " .. tostring(t2.inferred_len)) } end local t1a, err = arraytype_from_tuple(t1.inferred_at, t1) if not t1a then return false, err end if not is_a(t1a, t2) then - return false, terr(t2, "got %s (from %s), expected %s", t1a, t1, t2) + return false, { Err(t2, "got %s (from %s), expected %s", t1a, t1, t2) } end return true elseif t1.typename == "record" then return match_fields_to_record(t1, t2) elseif t1.typename == "arrayrecord" then if not is_a(t1.elements, t2.elements) then - return false, terr(t1, "array parts have incompatible element types") + return false, { Err(t1, "array parts have incompatible element types") } end return match_fields_to_record(t1, t2) elseif is_typetype(t1) and is_record_type(t1.def) then -- record as prototype @@ -7504,7 +7512,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if t1.typename == "tupletable" then local arr_type = arraytype_from_tuple(t1.inferred_at, t1) if not arr_type then - return false, terr(t1, "Unable to convert tuple %s to map", t1) + return false, { Err(t1, "Unable to convert tuple %s to map", t1) } end elements = arr_type.elements else @@ -7516,12 +7524,12 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return combine_map_errs(errs_keys, errs_values) elseif is_record_type(t1) then -- FIXME if not is_a(t2.keys, STRING) then - return false, terr(t1, "can't match a record to a map with non-string keys") + return false, { Err(t1, "can't match a record to a map with non-string keys") } end if t2.keys.typename == "enum" then for _, k in ipairs(t1.field_order) do if not t2.keys.enumset[k] then - return false, terr(t1, "key is not an enum value: " .. k) + return false, { Err(t1, "key is not an enum value: " .. k) } end end end @@ -7531,19 +7539,19 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if t1.typename == "tupletable" then for i = 1, math.min(#t1.types, #t2.types) do if not is_a(t1.types[i], t2.types[i], for_equality) then - return false, terr(t1, "in tuple entry " .. tostring(i) .. ": got %s, expected %s", t1.types[i], t2.types[i]) + return false, { Err(t1, "in tuple entry " .. tostring(i) .. ": got %s, expected %s", t1.types[i], t2.types[i]) } end end if for_equality and #t1.types ~= #t2.types then - return false, terr(t1, "tuples are not the same size") + return false, { Err(t1, "tuples are not the same size") } end if #t1.types > #t2.types then - return false, terr(t1, "tuple %s is too big for tuple %s", t1, t2) + return false, { Err(t1, "tuple %s is too big for tuple %s", t1, t2) } end return true elseif is_array_type(t1) then if t1.inferred_len and t1.inferred_len > #t2.types then - return false, terr(t1, "incompatible length, expected maximum length of " .. tostring(#t2.types) .. ", got " .. tostring(t1.inferred_len)) + return false, { Err(t1, "incompatible length, expected maximum length of " .. tostring(#t2.types) .. ", got " .. tostring(t1.inferred_len)) } end -- for array literals (which is the only case where inferred_len is defined), @@ -7554,7 +7562,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string for i = 1, len do if not is_a(t1.elements, t2.types[i], for_equality) then - return false, terr(t1, "tuple entry " .. tostring(i) .. " of type %s does not match type of array elements, which is %s", t2.types[i], t1.elements) + return false, { Err(t1, "tuple entry " .. tostring(i) .. " of type %s does not match type of array elements, which is %s", t2.types[i], t1.elements) } end end return true @@ -7562,7 +7570,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string elseif t1.typename == "function" and t2.typename == "function" then local all_errs = {} if (not t2.args.is_va) and #t1.args > #t2.args then - table.insert(all_errs, error_in_type(t1, "incompatible number of arguments: got " .. #t1.args .. " %s, expected " .. #t2.args .. " %s", t1.args, t2.args)) + table.insert(all_errs, Err(t1, "incompatible number of arguments: got " .. #t1.args .. " %s, expected " .. #t2.args .. " %s", t1.args, t2.args)) else for i = ((t1.is_method or t2.is_method) and 2 or 1), #t1.args do arg_check(nil, is_a, t1.args[i], t2.args[i] or ANY, i, all_errs, "argument") @@ -7570,7 +7578,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end local diff_by_va = #t2.rets - #t1.rets == 1 and t2.rets.is_va if #t1.rets < #t2.rets and not diff_by_va then - table.insert(all_errs, error_in_type(t1, "incompatible number of returns: got " .. #t1.rets .. " %s, expected " .. #t2.rets .. " %s", t1.rets, t2.rets)) + table.insert(all_errs, Err(t1, "incompatible number of returns: got " .. #t1.rets .. " %s, expected " .. #t2.rets .. " %s", t1.rets, t2.rets)) else local nrets = #t2.rets if diff_by_va then @@ -7593,7 +7601,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return true end - return false, terr(t1, "got %s, expected %s", t1, t2) + return false, { Err(t1, "got %s, expected %s", t1, t2) } end local function assert_is_a(node: Node, t1: Type, t2: Type, context: string, name: string): boolean @@ -7825,7 +7833,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - local function infer_emptytables(where: Node, wheres: {Node}, xs: Type, ys: Type, delta: integer) + local function infer_emptytables(where: Where, wheres: {Where}, xs: Type, ys: Type, delta: integer) assert(xs.typename == "tuple") assert(ys.typename == "tuple") @@ -7846,10 +7854,10 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - local check_args_rets: function(where: Node, where_args: {Node}, f: Type, args: {Type}, rets: {Type}, argdelta: integer): Type, {Error} + local check_args_rets: function(where: Where, where_args: {Node}, f: Type, args: {Type}, rets: {Type}, argdelta: integer): Type, {Error} do -- check if a tuple `xs` matches tuple `ys` - local function check_func_type_list(where: Node, wheres: {Node}, xs: Type, ys: Type, from: integer, delta: integer, mode: string): boolean, {Error} + local function check_func_type_list(where: Where, wheres: {Where}, xs: Type, ys: Type, from: integer, delta: integer, mode: string): boolean, {Error} assert(xs.typename == "tuple", xs.typename) assert(ys.typename == "tuple", ys.typename) @@ -7872,7 +7880,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return true end - check_args_rets = function(where: Node, where_args: {Node}, f: Type, args: {Type}, rets: {Type}, argdelta: integer): Type, {Error} + check_args_rets = function(where: Where, where_args: {Node}, f: Type, args: {Type}, rets: {Type}, argdelta: integer): Type, {Error} local rets_ok = true local rets_errs: {Error} local args_ok: boolean @@ -7958,7 +7966,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return resolve_typevars_at(node, f.rets) end - local function check_call(where: Node, where_args: {Node}, func: Type, args: {Type}, is_method: boolean, argdelta: integer): Type, Type + local function check_call(node: Node, where_args: {Node}, func: Type, args: {Type}, is_method: boolean, argdelta: integer): Type, Type assert(type(func) == "table") assert(type(args) == "table") @@ -7971,7 +7979,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local is_func = func.typename == "function" local is_poly = func.typename == "poly" if not (is_func or is_poly) then - return node_error(where, "not a function: %s", func) + return node_error(node, "not a function: %s", func) end local passes, n = 1, 1 @@ -7988,15 +7996,15 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local f = is_func and func or func.types[i] if f.is_method and not is_method then if args[1] and is_a(args[1], f.args[1]) then - -- a non-"@funcall" `where` means a synthesized call, e.g. from a metamethod - if where.kind == "op" and where.op.op == "@funcall" then - local receiver_is_typetype = where.e1.e1 and where.e1.e1.type and where.e1.e1.type.resolved and where.e1.e1.type.resolved.typename == "typetype" + -- a non-"@funcall" `node` means a synthesized call, e.g. from a metamethod + if node.kind == "op" and node.op.op == "@funcall" then + local receiver_is_typetype = node.e1.e1 and node.e1.e1.type and node.e1.e1.type.resolved and node.e1.e1.type.resolved.typename == "typetype" if not receiver_is_typetype then - node_warning("hint", where, "invoked method as a regular function: consider using ':' instead of '.'") + add_warning("hint", node, "invoked method as a regular function: consider using ':' instead of '.'") end end else - return node_error(where, "invoked method as a regular function: use ':' instead of '.'") + return node_error(node, "invoked method as a regular function: use ':' instead of '.'") end end local expected = #f.args @@ -8012,16 +8020,16 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string then push_typeargs(f) - local matched, errs = check_args_rets(where, where_args, f, args, where.expected, argdelta) + local matched, errs = check_args_rets(node, where_args, f, args, node.expected, argdelta) if matched then -- success! return matched, f end first_errs = first_errs or errs - if where.expected then + if node.expected then -- revert inferred returns - infer_emptytables(where, where_args, f.rets, f.rets, argdelta) + infer_emptytables(node, where_args, f.rets, f.rets, argdelta) end if is_poly then @@ -8034,24 +8042,24 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - return fail_call(where, func, given, first_errs) + return fail_call(node, func, given, first_errs) end - type_check_function_call = function(where: Node, where_args: {Node}, func: Type, args: {Type}, e1: Node, is_method: boolean, argdelta: integer): Type - if where.expected and where.expected.typename ~= "tuple" then - where.expected = a_type { typename = "tuple", where.expected } + type_check_function_call = function(node: Node, where_args: {Node}, func: Type, args: {Type}, e1: Node, is_method: boolean, argdelta: integer): Type + if node.expected and node.expected.typename ~= "tuple" then + node.expected = a_type { typename = "tuple", node.expected } end begin_scope() - local ret, f = check_call(where, where_args, func, args, is_method, argdelta) - ret = resolve_typevars_at(where, ret) + local ret, f = check_call(node, where_args, func, args, is_method, argdelta) + ret = resolve_typevars_at(node, ret) end_scope() if e1 then e1.type = f end if func.macroexp then - expand_macroexp(where, where_args, func.macroexp) + expand_macroexp(node, where_args, func.macroexp) end return ret @@ -8288,7 +8296,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string for _, t in ipairs(node.typeargs) do local v = find_var(t.typearg, "check_only") if not v or not v.used_as_type then - type_error(t, "type argument '%s' is not used in function signature", t) + error_at(t, "type argument '%s' is not used in function signature", t) end end end @@ -8321,7 +8329,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string for _, typ in ipairs(types) do assert(typ.x) assert(typ.y) - type_error(typ, "unknown type %s", typ) + error_at(typ, "unknown type %s", typ) end end end @@ -8447,9 +8455,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return a.elements elseif a.typename == "emptytable" then if a.keys == nil then - a.keys = resolve_tuple(orig_b) - a.keys_inferred_at = assert(anode) - a.keys_inferred_at_file = filename + a.keys = infer_at(anode, resolve_tuple(orig_b)) end if is_a(orig_b, a.keys) then @@ -8457,9 +8463,9 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end errm, erra, errb = "inconsistent index type: got %s, expected %s (type of keys inferred at " - .. a.keys_inferred_at_file .. ":" - .. a.keys_inferred_at.y .. ":" - .. a.keys_inferred_at.x .. ": )", orig_b, a.keys + .. a.keys.inferred_at.filename .. ":" + .. a.keys.inferred_at.y .. ":" + .. a.keys.inferred_at.x .. ": )", orig_b, a.keys elseif a.typename == "map" then if is_a(orig_b, a.keys) then return a.values @@ -8503,7 +8509,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return node_error(bnode, errm, erra, errb) end - expand_type = function(where: Node, old: Type, new: Type): Type + expand_type = function(where: Where, old: Type, new: Type): Type if not old or old.typename == "nil" then return new else @@ -8514,7 +8520,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string old.values = expand_type(where, old.values, ftype) end else - node_error(where, "cannot determine table literal type") + error_at(where, "cannot determine table literal type") end elseif is_record_type(old) and is_record_type(new) then old.typename = "map" @@ -8623,10 +8629,10 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end -- Inference engine for 'is' operator - local facts_and: function(where: Node, f1: Fact, f2: Fact): Fact - local facts_or: function(where: Node, f1: Fact, f2: Fact): Fact - local facts_not: function(where: Node, f1: Fact): Fact - local apply_facts: function(where: Node, known: Fact) + local facts_and: function(where: Where, f1: Fact, f2: Fact): Fact + local facts_or: function(where: Where, f1: Fact, f2: Fact): Fact + local facts_not: function(where: Where, f1: Fact): Fact + local apply_facts: function(where: Where, known: Fact) local FACT_TRUTHY: Fact do local IsFact_mt: metatable = { @@ -8709,11 +8715,11 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string FACT_TRUTHY = TruthyFact {} - facts_and = function(where: Node, f1: Fact, f2: Fact): Fact + facts_and = function(where: Where, f1: Fact, f2: Fact): Fact return AndFact { f1 = f1, f2 = f2, where = where } end - facts_or = function(where: Node, f1: Fact, f2: Fact): Fact + facts_or = function(where: Where, f1: Fact, f2: Fact): Fact if f1 and f2 then return OrFact { f1 = f1, f2 = f2, where = where } else @@ -8721,7 +8727,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - facts_not = function(where: Node, f1: Fact): Fact + facts_not = function(where: Where, f1: Fact): Fact if f1 then return NotFact { f1 = f1, where = where } else @@ -8826,7 +8832,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string ret[var] = EqFact { var = var, typ = typ } elseif not is_a(f.typ, typ) then assert(f.fact == "is") - node_warning("branch", f.where, f.var .. " (of type %s) can never be a %s", show_type(typ), show_type(f.typ)) + add_warning("branch", f.where, f.var .. " (of type %s) can never be a %s", show_type(typ), show_type(f.typ)) ret[var] = EqFact { var = var, typ = INVALID, where = f.where } else assert(f.fact == "is") @@ -8920,10 +8926,10 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if typ.typename ~= "typevar" then if is_a(typ, f.typ) then -- drop this warning because of implicit nil in all unions - -- node_warning("branch", f.where, f.var .. " (of type %s) is always a %s", show_type(typ), show_type(f.typ)) + -- add_warning("branch", f.where, f.var .. " (of type %s) is always a %s", show_type(typ), show_type(f.typ)) return { [f.var] = f } elseif not is_a(f.typ, typ) then - node_error(f.where, f.var .. " (of type %s) can never be a %s", typ, f.typ) + error_at(f.where, f.var .. " (of type %s) can never be a %s", typ, f.typ) return { [f.var] = invalid_from(f) } end end @@ -8945,7 +8951,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - apply_facts = function(where: Node, known: Fact) + apply_facts = function(where: Where, known: Fact) if not known then return end @@ -8954,7 +8960,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string for v, f in pairs(facts) do if f.typ.typename == "invalid" then - node_error(where, "cannot resolve a type for " .. v .. " here") + error_at(where, "cannot resolve a type for " .. v .. " here") end local t = infer_at(where, f.typ) if not f.where then @@ -9100,7 +9106,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string else resolved = find_type(names) if (not resolved) or (not is_typetype(resolved)) then - type_error(typetype, "%s is not a type", typetype) + error_at(typetype, "%s is not a type", typetype) resolved = a_type { typename = "bad_nominal", names = names } end end @@ -9327,7 +9333,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return typ end - local function infer_negation_of_if_blocks(where: Node, ifnode: Node, n: integer) + local function infer_negation_of_if_blocks(where: Where, ifnode: Node, n: integer) local f = facts_not(where, ifnode.if_blocks[1].exp.known) for e = 2, n do local b = ifnode.if_blocks[e] @@ -9639,7 +9645,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local msg = #rets == 1 and "only 1 value is returned by the function" or ("only " .. #rets .. " values are returned by the function") - node_warning("hint", varnode, msg) + add_warning("hint", varnode, msg) end end end @@ -9766,7 +9772,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if exp1.op and exp1.op.op == "@funcall" then local t = resolve_tuple_and_nominal(exp1.e2.type) if exp1.e1.tk == "pairs" and is_array_type(t) then - node_warning("hint", exp1, "hint: applying pairs on an array: did you intend to apply ipairs?") + add_warning("hint", exp1, "hint: applying pairs on an array: did you intend to apply ipairs?") end if exp1.e1.tk == "pairs" and t.typename ~= "map" then @@ -9775,7 +9781,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string match_all_record_field_names(exp1.e2, t, t.field_order, "attempting pairs loop on a record with attributes of different types") local ct = t.typename == "record" and "{string:any}" or "{any:any}" - node_warning("hint", exp1.e2, "hint: if you want to iterate over fields of a record, cast it to " .. ct) + add_warning("hint", exp1.e2, "hint: if you want to iterate over fields of a record, cast it to " .. ct) else node_error(exp1.e2, "cannot apply pairs on values of type: %s", exp1.e2.type) end @@ -9878,7 +9884,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string and node.exps[1].kind == "op" and (node.exps[1].op.op == "and" or node.exps[1].op.op == "or") and #node.exps[1].e2.type > 1 then - node_warning("hint", node.exps[1].e2, "additional return values are being discarded due to '" .. node.exps[1].op.op .. "' expression; suggest parentheses if intentional") + add_warning("hint", node.exps[1].e2, "additional return values are being discarded due to '" .. node.exps[1].op.op .. "' expression; suggest parentheses if intentional") end for i = 1, #children[1] do @@ -10045,12 +10051,10 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end if force_array then - node.type = a_type { - inferred_at = node, - inferred_at_file = filename, + node.type = infer_at(node, a_type { typename = "array", elements = force_array, - } + }) else node.type = resolve_typevars_at(node, node.expected) if node.expected == node.type and node.type.typename == "nominal" then @@ -10558,7 +10562,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end if a.typename == "map" then if a.keys.typename == "number" or a.keys.typename == "integer" then - node_warning("hint", node, "using the '#' operator on a map with numeric key type may produce unexpected results") + add_warning("hint", node, "using the '#' operator on a map with numeric key type may produce unexpected results") else node_error(node, "using the '#' operator on this map will always return 0") end @@ -10604,7 +10608,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if not node.type then node_error(node, "cannot use operator '" .. node.op.op:gsub("%%", "%%%%") .. "' for types %s and %s", resolve_tuple(orig_a), resolve_tuple(orig_b)) if node.op.op == "or" and is_valid_union(unite({orig_a, orig_b})) then - node_warning("hint", node, "if a union type was intended, consider declaring it explicitly") + add_warning("hint", node, "if a union type was intended, consider declaring it explicitly") end end end @@ -10808,7 +10812,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string check_macroexp_arg_use(typ.macroexp) if not is_a(typ.macroexp.type, typ) then - type_error(typ.macroexp.type, "macroexp type does not match declaration") + error_at(typ.macroexp.type, "macroexp type does not match declaration") end end @@ -10876,7 +10880,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string ["typevar"] = { after = function(typ: Type, _children: {Type}): Type if not find_var_type(typ.typevar) then - type_error(typ, "undefined type variable " .. typ.typevar) + error_at(typ, "undefined type variable " .. typ.typevar) end return typ end, @@ -10913,7 +10917,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string }, ["union"] = { after = function(typ: Type, _children: {Type}): Type - return (validate_union(typ as Node, typ)) + return (validate_union(typ, typ)) end }, }, From 621a3f0dfd43c3b3d7e1820f88efd80624ccdfd6 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Sat, 18 Nov 2023 22:01:58 -0300 Subject: [PATCH 012/224] add magic type @self --- tl.lua | 21 ++++++++++++++++++++- tl.tl | 21 ++++++++++++++++++++- 2 files changed, 40 insertions(+), 2 deletions(-) diff --git a/tl.lua b/tl.lua index e8fa8aee0..5c06197e8 100644 --- a/tl.lua +++ b/tl.lua @@ -4909,6 +4909,10 @@ local function show_type_base(t, short, seen) end if t.typename == "nominal" then + if #t.names == 1 and t.names[1] == "@self" then + return "self" + end + if t.typevals then local out = { table.concat(t.names, "."), "<" } local vals = {} @@ -7257,6 +7261,10 @@ tl.type_check = function(ast, opts) return arr_type end + local function is_self(t) + return t.typename == "nominal" and t.names[1] == "@self" + end + is_a = function(t1, t2, for_equality) assert(type(t1) == "table") @@ -7302,6 +7310,11 @@ tl.type_check = function(ast, opts) if t2.typename == "any" then return true + elseif is_self(t1) then + return is_a(resolve_tuple_and_nominal(t1), t2, for_equality) + + elseif is_self(t2) then + return is_a(t1, resolve_tuple_and_nominal(t2), for_equality) elseif t1.typename == "union" then @@ -7880,7 +7893,7 @@ tl.type_check = function(ast, opts) if argdelta == -1 then from = 2 local errs = {} - if not arg_check(where, is_a, args[1], f.args[1], nil, errs, "self") then + if (not is_self(f.args[1])) and not arg_check(where, is_a, args[1], f.args[1], nil, errs, "self") then return nil, errs end end @@ -7966,6 +7979,10 @@ tl.type_check = function(ast, opts) argdelta = is_method and -1 or argdelta or 0 + if is_method and args[1] then + add_var(nil, "@self", a_type({ typename = "typetype", y = node.y, x = node.x, def = args[1] })) + end + local is_func = func.typename == "function" local is_poly = func.typename == "poly" if not (is_func or is_poly) then @@ -10812,6 +10829,8 @@ tl.type_check = function(ast, opts) ["record"] = { before = function(typ, _children) begin_scope() + add_var(nil, "@self", a_type({ typename = "typetype", y = typ.y, x = typ.x, def = typ })) + for name, typ2 in fields_of(typ) do if typ2.typename == "typetype" then typ2.typename = "nestedtype" diff --git a/tl.tl b/tl.tl index 92579616b..316593544 100644 --- a/tl.tl +++ b/tl.tl @@ -4919,6 +4919,10 @@ local function show_type_base(t: Type, short: boolean, seen: {Type:string}): str end if t.typename == "nominal" then + if #t.names == 1 and t.names[1] == "@self" then + return "self" + end + if t.typevals then local out = { table.concat(t.names, "."), "<" } local vals: {string} = {} @@ -7267,6 +7271,10 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return arr_type end + local function is_self(t: Type): boolean + return t.typename == "nominal" and t.names[1] == "@self" + end + -- subtyping comparison is_a = function(t1: Type, t2: Type, for_equality: boolean): boolean, {Error} assert(type(t1) == "table") @@ -7312,6 +7320,11 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if t2.typename == "any" then return true + elseif is_self(t1) then + return is_a(resolve_tuple_and_nominal(t1), t2, for_equality) + + elseif is_self(t2) then + return is_a(t1, resolve_tuple_and_nominal(t2), for_equality) elseif t1.typename == "union" then @@ -7890,7 +7903,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if argdelta == -1 then from = 2 local errs = {} - if not arg_check(where, is_a, args[1], f.args[1], nil, errs, "self") then + if (not is_self(f.args[1])) and not arg_check(where, is_a, args[1], f.args[1], nil, errs, "self") then return nil, errs end end @@ -7976,6 +7989,10 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string argdelta = is_method and -1 or argdelta or 0 + if is_method and args[1] then + add_var(nil, "@self", a_type({ typename = "typetype", y = node.y, x = node.x, def = args[1] })) + end + local is_func = func.typename == "function" local is_poly = func.typename == "poly" if not (is_func or is_poly) then @@ -10822,6 +10839,8 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string ["record"] = { before = function(typ: Type, _children: {Type}) begin_scope() + add_var(nil, "@self", a_type { typename = "typetype", y = typ.y, x = typ.x, def = typ }) + for name, typ2 in fields_of(typ) do if typ2.typename == "typetype" then typ2.typename = "nestedtype" From db02a918fd8758d3e8aac961a3505c8b7ca1bf17 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Sun, 24 Sep 2023 00:08:35 -0300 Subject: [PATCH 013/224] interfaces: add "interface" syntax --- docs/grammar.md | 2 ++ tl.lua | 6 ++++++ tl.tl | 6 ++++++ 3 files changed, 14 insertions(+) diff --git a/docs/grammar.md b/docs/grammar.md index d2e5d0500..79c539651 100644 --- a/docs/grammar.md +++ b/docs/grammar.md @@ -31,12 +31,14 @@ precedence, see below. + ‘local’ attnamelist [‘:’ typelist] [‘=’ explist] | ‘local’ ‘function’ Name funcbody | * ‘local’ ‘record’ Name recordbody | +* ‘local’ ‘interface’ Name recordbody | * ‘local’ ‘enum’ Name enumbody | * ‘local’ ‘type’ Name ‘=’ newtype | * ‘global’ attnamelist ‘:’ typelist [‘=’ explist] | * ‘global’ attnamelist ‘=’ explist | * ‘global’ ‘function’ Name funcbody | * ‘global’ ‘record’ Name recordbody | +* ‘global’ ‘interface’ Name recordbody | * ‘global’ ‘enum’ Name enumbody | * ‘global’ ‘type’ Name [‘=’ newtype] diff --git a/tl.lua b/tl.lua index 5c06197e8..f04d205e3 100644 --- a/tl.lua +++ b/tl.lua @@ -191,6 +191,7 @@ tl.typecodes = { TUPLE = 0x00080008, EMPTY_TABLE = 0x00000008, ENUM = 0x00010004, + INTERFACE = 0x00100008, IS_ARRAY = 0x00010008, IS_RECORD = 0x00020008, @@ -1024,6 +1025,7 @@ end + local table_types = { @@ -1031,6 +1033,7 @@ local table_types = { ["map"] = true, ["arrayrecord"] = true, ["record"] = true, + ["interface"] = true, ["emptytable"] = true, ["tupletable"] = true, @@ -1242,6 +1245,7 @@ local table_types = { + local TruthyFact = {} @@ -2994,6 +2998,7 @@ parse_record_body = function(ps, i, def, node, name) end parse_type_body_fns = { + ["interface"] = parse_record_body, ["record"] = parse_record_body, ["enum"] = parse_enum_body, } @@ -11022,6 +11027,7 @@ local typename_to_typecode = { ["map"] = tl.typecodes.MAP, ["tupletable"] = tl.typecodes.TUPLE, ["arrayrecord"] = tl.typecodes.ARRAYRECORD, + ["interface"] = tl.typecodes.INTERFACE, ["record"] = tl.typecodes.RECORD, ["enum"] = tl.typecodes.ENUM, ["boolean"] = tl.typecodes.BOOLEAN, diff --git a/tl.tl b/tl.tl index 316593544..c2ca700ef 100644 --- a/tl.tl +++ b/tl.tl @@ -191,6 +191,7 @@ tl.typecodes = { TUPLE = 0x00080008, EMPTY_TABLE = 0x00000008, ENUM = 0x00010004, + INTERFACE = 0x00100008, -- Teal type masks IS_ARRAY = 0x00010008, IS_RECORD = 0x00020008, @@ -1001,6 +1002,7 @@ local enum TypeName "tupletable" "arrayrecord" "record" + "interface" "enum" "boolean" "string" @@ -1031,6 +1033,7 @@ local table_types : {TypeName:boolean} = { ["map"] = true, ["arrayrecord"] = true, ["record"] = true, + ["interface"] = true, ["emptytable"] = true, ["tupletable"] = true, @@ -1209,6 +1212,7 @@ local enum NodeKind "..." "paren" "macroexp" + "interface" "error_node" end @@ -3004,6 +3008,7 @@ parse_record_body = function(ps: ParseState, i: integer, def: Type, node: Node, end parse_type_body_fns = { + ["interface"] = parse_record_body, ["record"] = parse_record_body, ["enum"] = parse_enum_body, } @@ -11032,6 +11037,7 @@ local typename_to_typecode : {TypeName:integer} = { ["map"] = tl.typecodes.MAP, ["tupletable"] = tl.typecodes.TUPLE, ["arrayrecord"] = tl.typecodes.ARRAYRECORD, + ["interface"] = tl.typecodes.INTERFACE, ["record"] = tl.typecodes.RECORD, ["enum"] = tl.typecodes.ENUM, ["boolean"] = tl.typecodes.BOOLEAN, From d34d144ef3e1bf36ebbda3c45023a1cf506b603f Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Sun, 24 Sep 2023 00:27:43 -0300 Subject: [PATCH 014/224] interfaces: add "is " to record syntax Also check that types referenced in `is` exist. --- docs/grammar.md | 8 +++- spec/declaration/record_spec.lua | 21 +++++++++++ tl.lua | 65 ++++++++++++++++++++++---------- tl.tl | 65 ++++++++++++++++++++++---------- 4 files changed, 119 insertions(+), 40 deletions(-) diff --git a/docs/grammar.md b/docs/grammar.md index 79c539651..835610027 100644 --- a/docs/grammar.md +++ b/docs/grammar.md @@ -95,9 +95,11 @@ precedence, see below. * type ::= ‘(’ type ‘)’ | basetype {‘|’ basetype} +* nominal ::= Name {{‘.’ Name }} [typeargs] + * basetype ::= ‘string’ | ‘boolean’ | ‘nil’ | ‘number’ | * ‘{’ type {',' type} ‘}’ | ‘{’ type ‘:’ type ‘}’ | functiontype -* | Name {{‘.’ Name }} [typeargs] +* | nominal * typelist ::= type {‘,’ type} @@ -107,7 +109,9 @@ precedence, see below. * newtype ::= ‘record’ recordbody | ‘enum’ enumbody | type -* recordbody ::= [typeargs] {recordentry} ‘end’ +* interfacelist ::= nominal {‘,’ nominal} + +* recordbody ::= [typeargs] [‘is’ interfacelist] {recordentry} ‘end’ * recordentry ::= ‘userdata’ | ‘{’ type ‘}’ | * ‘type’ Name ‘=’ newtype | [‘metamethod’] recordkey ‘:’ type | diff --git a/spec/declaration/record_spec.lua b/spec/declaration/record_spec.lua index 3275fa440..c78939baf 100644 --- a/spec/declaration/record_spec.lua +++ b/spec/declaration/record_spec.lua @@ -728,10 +728,31 @@ for i, name in ipairs({"records", "arrayrecords"}) do { y = 9, msg = "argument 1: userdata record doesn't match: Foo" }, nil })) + + it("reports error on unknown interfaces", util.check_type_error([[ + local record Foo ]]..pick(i, "is Bongo, Bingo", "is {number}, Bongo, Bingo")..[[ + userdata + a: number + end + ]], { + { y = 1, msg = "unknown type Bongo" }, + { y = 1, msg = "unknown type Bingo" }, + })) end) end describe("arrayrecord", function() + it("can be declared with is", util.check([[ + local record R1 + is {string} + + x: number + end + + local v: R1 = { x = 10 } + v[1] = "hello" + ]])) + it("assigning to array produces no warnings", util.check_warnings([[ local record R1 {string} diff --git a/tl.lua b/tl.lua index f04d205e3..88aaa5cf8 100644 --- a/tl.lua +++ b/tl.lua @@ -1246,6 +1246,7 @@ local table_types = { + local TruthyFact = {} @@ -1854,30 +1855,35 @@ local simple_types = { ["integer"] = INTEGER, } -local function parse_base_type(ps, i) +local function parse_simple_type_or_nominal(ps, i) local tk = ps.tokens[i].tk - if ps.tokens[i].kind == "identifier" then - local st = simple_types[tk] - if st then - return i + 1, st - end - local typ = new_type(ps, i, "nominal") - typ.names = { tk } + local st = simple_types[tk] + if st then + return i + 1, st + end + local typ = new_type(ps, i, "nominal") + typ.names = { tk } + i = i + 1 + while ps.tokens[i].tk == "." do i = i + 1 - while ps.tokens[i].tk == "." do + if ps.tokens[i].kind == "identifier" then + table.insert(typ.names, ps.tokens[i].tk) i = i + 1 - if ps.tokens[i].kind == "identifier" then - table.insert(typ.names, ps.tokens[i].tk) - i = i + 1 - else - return fail(ps, i, "syntax error, expected identifier") - end + else + return fail(ps, i, "syntax error, expected identifier") end + end - if ps.tokens[i].tk == "<" then - i, typ.typevals = parse_anglebracket_list(ps, i, parse_type) - end - return i, typ + if ps.tokens[i].tk == "<" then + i, typ.typevals = parse_anglebracket_list(ps, i, parse_type) + end + return i, typ +end + +local function parse_base_type(ps, i) + local tk = ps.tokens[i].tk + if ps.tokens[i].kind == "identifier" then + return parse_simple_type_or_nominal(ps, i) elseif tk == "{" then i = i + 1 local decl = new_type(ps, i, "array") @@ -2861,10 +2867,26 @@ local function parse_macroexp(ps, i) return i, node end +local function parse_interface_name(ps, i) + local istart = i + local typ + i, typ = parse_simple_type_or_nominal(ps, i) + if typ.typename ~= "nominal" then + return fail(ps, istart, "expected an interface") + end + return i, typ +end + parse_record_body = function(ps, i, def, node, name) local istart = i - 1 def.fields = {} def.field_order = {} + + if ps.tokens[i].tk == "is" then + i = i + 1 + i, def.interface_list = parse_trying_list(ps, i, {}, parse_interface_name) + end + if ps.tokens[i].tk == "<" then i, def.typeargs = parse_anglebracket_list(ps, i, parse_typearg) end @@ -3487,6 +3509,11 @@ local function recurse_type(ast, visit) table.insert(xs, recurse_type(child, visit)) end end + if ast.interface_list then + for _, child in ipairs(ast.interface_list) do + recurse_type(child, visit) + end + end if ast.def then table.insert(xs, recurse_type(ast.def, visit)) end diff --git a/tl.tl b/tl.tl index c2ca700ef..edc59ef20 100644 --- a/tl.tl +++ b/tl.tl @@ -1103,6 +1103,7 @@ local record Type missing: {string} -- records + interface_list: {Type} typeargs: {Type} fields: {string: Type} field_order: {string} @@ -1864,30 +1865,35 @@ local simple_types: {string:Type} = { ["integer"] = INTEGER, } -local function parse_base_type(ps: ParseState, i: integer): integer, Type, integer +local function parse_simple_type_or_nominal(ps: ParseState, i: integer): integer, Type local tk = ps.tokens[i].tk - if ps.tokens[i].kind == "identifier" then - local st = simple_types[tk] - if st then - return i + 1, st - end - local typ = new_type(ps, i, "nominal") - typ.names = { tk } + local st = simple_types[tk] + if st then + return i + 1, st + end + local typ = new_type(ps, i, "nominal") + typ.names = { tk } + i = i + 1 + while ps.tokens[i].tk == "." do i = i + 1 - while ps.tokens[i].tk == "." do + if ps.tokens[i].kind == "identifier" then + table.insert(typ.names, ps.tokens[i].tk) i = i + 1 - if ps.tokens[i].kind == "identifier" then - table.insert(typ.names, ps.tokens[i].tk) - i = i + 1 - else - return fail(ps, i, "syntax error, expected identifier") - end + else + return fail(ps, i, "syntax error, expected identifier") end + end - if ps.tokens[i].tk == "<" then - i, typ.typevals = parse_anglebracket_list(ps, i, parse_type) - end - return i, typ + if ps.tokens[i].tk == "<" then + i, typ.typevals = parse_anglebracket_list(ps, i, parse_type) + end + return i, typ +end + +local function parse_base_type(ps: ParseState, i: integer): integer, Type, integer + local tk = ps.tokens[i].tk + if ps.tokens[i].kind == "identifier" then + return parse_simple_type_or_nominal(ps, i) elseif tk == "{" then i = i + 1 local decl = new_type(ps, i, "array") @@ -2871,10 +2877,26 @@ local function parse_macroexp(ps: ParseState, i: integer): integer, Node return i, node end +local function parse_interface_name(ps: ParseState, i: integer): integer, Type, integer + local istart = i + local typ: Type + i, typ = parse_simple_type_or_nominal(ps, i) + if typ.typename ~= "nominal" then + return fail(ps, istart, "expected an interface") + end + return i, typ +end + parse_record_body = function(ps: ParseState, i: integer, def: Type, node: Node, name: string): integer, Node local istart = i - 1 def.fields = {} def.field_order = {} + + if ps.tokens[i].tk == "is" then + i = i + 1 + i, def.interface_list = parse_trying_list(ps, i, {}, parse_interface_name) + end + if ps.tokens[i].tk == "<" then i, def.typeargs = parse_anglebracket_list(ps, i, parse_typearg) end @@ -3497,6 +3519,11 @@ local function recurse_type(ast: Type, visit: Visitor): T table.insert(xs, recurse_type(child, visit)) end end + if ast.interface_list then + for _, child in ipairs(ast.interface_list) do + recurse_type(child, visit) + end + end if ast.def then table.insert(xs, recurse_type(ast.def, visit)) end From 0b494bfdcb47589969bece7bc5e2dfad155954a1 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Sun, 24 Sep 2023 00:09:11 -0300 Subject: [PATCH 015/224] interfaces: is_record_type --- tl.lua | 2 +- tl.tl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tl.lua b/tl.lua index 88aaa5cf8..bb26dc433 100644 --- a/tl.lua +++ b/tl.lua @@ -1450,7 +1450,7 @@ local function is_array_type(t) end local function is_record_type(t) - return t.typename == "record" or t.typename == "arrayrecord" + return t.typename == "record" or t.typename == "arrayrecord" or t.typename == "interface" end local function is_number_type(t) diff --git a/tl.tl b/tl.tl index edc59ef20..9f7886899 100644 --- a/tl.tl +++ b/tl.tl @@ -1460,7 +1460,7 @@ local function is_array_type(t:Type): boolean end local function is_record_type(t:Type): boolean - return t.typename == "record" or t.typename == "arrayrecord" + return t.typename == "record" or t.typename == "arrayrecord" or t.typename == "interface" end local function is_number_type(t:Type): boolean From 8f20b6e7e56427889dbcbb20485e42432c05250e Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Sun, 24 Sep 2023 00:33:31 -0300 Subject: [PATCH 016/224] interfaces: implement show_type --- tl.lua | 2 ++ tl.tl | 2 ++ 2 files changed, 4 insertions(+) diff --git a/tl.lua b/tl.lua index bb26dc433..aa3525a4b 100644 --- a/tl.lua +++ b/tl.lua @@ -4989,6 +4989,8 @@ local function show_type_base(t, short, seen) return "{" .. show(t.elements) .. "}" elseif t.typename == "enum" then return t.names and table.concat(t.names, ".") or "enum" + elseif t.typename == "interface" then + return show_record_type("interface") elseif is_record_type(t) then return show_record_type("record") elseif t.typename == "function" then diff --git a/tl.tl b/tl.tl index 9f7886899..dbf334df8 100644 --- a/tl.tl +++ b/tl.tl @@ -4999,6 +4999,8 @@ local function show_type_base(t: Type, short: boolean, seen: {Type:string}): str return "{" .. show(t.elements) .. "}" elseif t.typename == "enum" then return t.names and table.concat(t.names, ".") or "enum" + elseif t.typename == "interface" then + return show_record_type("interface") elseif is_record_type(t) then return show_record_type("record") elseif t.typename == "function" then From 559d6c251569416a234cafd04a05274b12c099a1 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Sun, 24 Sep 2023 00:54:59 -0300 Subject: [PATCH 017/224] interfaces: add "where " syntax --- docs/grammar.md | 8 +++++--- tl.lua | 32 ++++++++++++++++++++++++++++++++ tl.tl | 32 ++++++++++++++++++++++++++++++++ 3 files changed, 69 insertions(+), 3 deletions(-) diff --git a/docs/grammar.md b/docs/grammar.md index 835610027..3e1a9c9cf 100644 --- a/docs/grammar.md +++ b/docs/grammar.md @@ -109,11 +109,13 @@ precedence, see below. * newtype ::= ‘record’ recordbody | ‘enum’ enumbody | type -* interfacelist ::= nominal {‘,’ nominal} +* interfacelist ::= nominal {‘,’ nominal} | +* ‘{’ type ‘}’ {‘,’ nominal} -* recordbody ::= [typeargs] [‘is’ interfacelist] {recordentry} ‘end’ +* recordbody ::= [typeargs] [‘is’ interfacelist] +* [‘where’ exp] {recordentry} ‘end’ -* recordentry ::= ‘userdata’ | ‘{’ type ‘}’ | +* recordentry ::= ‘userdata’ | * ‘type’ Name ‘=’ newtype | [‘metamethod’] recordkey ‘:’ type | * ‘record’ recordbody | ‘enum’ enumbody diff --git a/tl.lua b/tl.lua index aa3525a4b..cf496e749 100644 --- a/tl.lua +++ b/tl.lua @@ -2867,6 +2867,20 @@ local function parse_macroexp(ps, i) return i, node end +local function parse_where_clause(ps, i) + local node = new_node(ps.tokens, i, "macroexp") + node.args = new_node(ps.tokens, i, "argument_list") + node.args[1] = new_node(ps.tokens, i, "argument") + node.args[1].tk = "self" + node.args[1].decltype = new_type(ps, i, "nominal") + node.args[1].decltype.names = { "@self" } + node.rets = new_type(ps, i, "tuple") + node.rets[1] = BOOLEAN + i, node.exp = parse_expression(ps, i) + end_at(node, ps.tokens[i - 1]) + return i, node +end + local function parse_interface_name(ps, i) local istart = i local typ @@ -2887,6 +2901,24 @@ parse_record_body = function(ps, i, def, node, name) i, def.interface_list = parse_trying_list(ps, i, {}, parse_interface_name) end + if ps.tokens[i].tk == "where" then + local wstart = i + i = i + 1 + local where_macroexp + i, where_macroexp = parse_where_clause(ps, i) + + def.meta_fields = {} + def.meta_field_order = {} + + local typ = new_type(ps, wstart, "function") + typ.is_method = true + typ.args = a_type({ typename = "tuple", a_type({ typename = "nominal", y = typ.y, x = typ.x, names = { "@self" } }) }) + typ.rets = a_type({ typename = "tuple", a_type({ typename = "boolean" }) }) + typ.macroexp = where_macroexp + + store_field_in_record(ps, i, "__is", typ, def.meta_fields, def.meta_field_order) + end + if ps.tokens[i].tk == "<" then i, def.typeargs = parse_anglebracket_list(ps, i, parse_typearg) end diff --git a/tl.tl b/tl.tl index dbf334df8..325a343a2 100644 --- a/tl.tl +++ b/tl.tl @@ -2877,6 +2877,20 @@ local function parse_macroexp(ps: ParseState, i: integer): integer, Node return i, node end +local function parse_where_clause(ps: ParseState, i: integer): integer, Node + local node = new_node(ps.tokens, i, "macroexp") + node.args = new_node(ps.tokens, i, "argument_list") + node.args[1] = new_node(ps.tokens, i, "argument") + node.args[1].tk = "self" + node.args[1].decltype = new_type(ps, i, "nominal") + node.args[1].decltype.names = { "@self" } + node.rets = new_type(ps, i, "tuple") + node.rets[1] = BOOLEAN + i, node.exp = parse_expression(ps, i) + end_at(node, ps.tokens[i - 1]) + return i, node +end + local function parse_interface_name(ps: ParseState, i: integer): integer, Type, integer local istart = i local typ: Type @@ -2897,6 +2911,24 @@ parse_record_body = function(ps: ParseState, i: integer, def: Type, node: Node, i, def.interface_list = parse_trying_list(ps, i, {}, parse_interface_name) end + if ps.tokens[i].tk == "where" then + local wstart = i + i = i + 1 + local where_macroexp: Node + i, where_macroexp = parse_where_clause(ps, i) + + def.meta_fields = {} + def.meta_field_order = {} + + local typ = new_type(ps, wstart, "function") + typ.is_method = true + typ.args = a_type { typename = "tuple", a_type { typename = "nominal", y = typ.y, x = typ.x, names = { "@self" } } } + typ.rets = a_type { typename = "tuple", a_type { typename = "boolean" } } + typ.macroexp = where_macroexp + + store_field_in_record(ps, i, "__is", typ, def.meta_fields, def.meta_field_order) + end + if ps.tokens[i].tk == "<" then i, def.typeargs = parse_anglebracket_list(ps, i, parse_typearg) end From 9726088282464d479fabea6c0ce752113616c2dc Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Sun, 24 Sep 2023 01:10:12 -0300 Subject: [PATCH 018/224] interfaces/records: move arrayrecord to `is {...}` This avoids a grammar ambiguity in the `where` clause, where a following `{T}` would be mistaken as being part of the `where` expression. We still support a `{T}` at the top of the record declaration if there is no `is` clause, but not in the middle of the field list anymore. --- tl.lua | 55 +++++++++++++++++++++++++++++++++++++------------------ tl.tl | 55 +++++++++++++++++++++++++++++++++++++------------------ 2 files changed, 74 insertions(+), 36 deletions(-) diff --git a/tl.lua b/tl.lua index cf496e749..ba849d0ae 100644 --- a/tl.lua +++ b/tl.lua @@ -2891,14 +2891,49 @@ local function parse_interface_name(ps, i) return i, typ end +local function parse_arrayrecord_declaration(ps, i, def) + if def.typename == "arrayrecord" then + i = failskip(ps, i, "duplicated declaration of array element type in record", parse_type) + else + i = i + 1 + local t + i, t = parse_type(ps, i) + if ps.tokens[i].tk == "}" then + i = verify_tk(ps, i, "}") + else + return fail(ps, i, "expected an array declaration") + end + def.typename = "arrayrecord" + def.elements = t + end + return i +end + parse_record_body = function(ps, i, def, node, name) local istart = i - 1 def.fields = {} def.field_order = {} + if ps.tokens[i].tk == "<" then + i, def.typeargs = parse_anglebracket_list(ps, i, parse_typearg) + end + + if ps.tokens[i].tk == "{" then + i = parse_arrayrecord_declaration(ps, i, def) + end + if ps.tokens[i].tk == "is" then i = i + 1 - i, def.interface_list = parse_trying_list(ps, i, {}, parse_interface_name) + + if ps.tokens[i].tk == "{" then + i = parse_arrayrecord_declaration(ps, i, def) + if ps.tokens[i].tk == "," then + i = i + 1 + i, def.interface_list = parse_trying_list(ps, i, {}, parse_interface_name) + end + else + i, def.interface_list = parse_trying_list(ps, i, {}, parse_interface_name) + end end if ps.tokens[i].tk == "where" then @@ -2919,9 +2954,6 @@ parse_record_body = function(ps, i, def, node, name) store_field_in_record(ps, i, "__is", typ, def.meta_fields, def.meta_field_order) end - if ps.tokens[i].tk == "<" then - i, def.typeargs = parse_anglebracket_list(ps, i, parse_typearg) - end while not (ps.tokens[i].kind == "$EOF$" or ps.tokens[i].tk == "end") do local tn = ps.tokens[i].tk if ps.tokens[i].tk == "userdata" and ps.tokens[i + 1].tk ~= ":" then @@ -2932,20 +2964,7 @@ parse_record_body = function(ps, i, def, node, name) end i = i + 1 elseif ps.tokens[i].tk == "{" then - if def.typename == "arrayrecord" then - i = failskip(ps, i, "duplicated declaration of array element type in record", parse_type) - else - i = i + 1 - local t - i, t = parse_type(ps, i) - if ps.tokens[i].tk == "}" then - i = verify_tk(ps, i, "}") - else - return fail(ps, i, "expected an array declaration") - end - def.typename = "arrayrecord" - def.elements = t - end + return fail(ps, i, "syntax error: this syntax is no longer valid; declare arrayrecord at the top with 'is {...}'") elseif ps.tokens[i].tk == "type" and ps.tokens[i + 1].tk ~= ":" then i = i + 1 local iv = i diff --git a/tl.tl b/tl.tl index 325a343a2..62656af38 100644 --- a/tl.tl +++ b/tl.tl @@ -2901,14 +2901,49 @@ local function parse_interface_name(ps: ParseState, i: integer): integer, Type, return i, typ end +local function parse_arrayrecord_declaration(ps: ParseState, i: integer, def: Type): integer + if def.typename == "arrayrecord" then + i = failskip(ps, i, "duplicated declaration of array element type in record", parse_type as SkipFunction) + else + i = i + 1 + local t: Type + i, t = parse_type(ps, i) + if ps.tokens[i].tk == "}" then + i = verify_tk(ps, i, "}") + else + return fail(ps, i, "expected an array declaration") + end + def.typename = "arrayrecord" + def.elements = t + end + return i +end + parse_record_body = function(ps: ParseState, i: integer, def: Type, node: Node, name: string): integer, Node local istart = i - 1 def.fields = {} def.field_order = {} + if ps.tokens[i].tk == "<" then + i, def.typeargs = parse_anglebracket_list(ps, i, parse_typearg) + end + + if ps.tokens[i].tk == "{" then + i = parse_arrayrecord_declaration(ps, i, def) + end + if ps.tokens[i].tk == "is" then i = i + 1 - i, def.interface_list = parse_trying_list(ps, i, {}, parse_interface_name) + + if ps.tokens[i].tk == "{" then + i = parse_arrayrecord_declaration(ps, i, def) + if ps.tokens[i].tk == "," then + i = i + 1 + i, def.interface_list = parse_trying_list(ps, i, {}, parse_interface_name) + end + else + i, def.interface_list = parse_trying_list(ps, i, {}, parse_interface_name) + end end if ps.tokens[i].tk == "where" then @@ -2929,9 +2964,6 @@ parse_record_body = function(ps: ParseState, i: integer, def: Type, node: Node, store_field_in_record(ps, i, "__is", typ, def.meta_fields, def.meta_field_order) end - if ps.tokens[i].tk == "<" then - i, def.typeargs = parse_anglebracket_list(ps, i, parse_typearg) - end while not (ps.tokens[i].kind == "$EOF$" or ps.tokens[i].tk == "end") do local tn = ps.tokens[i].tk as TypeName if ps.tokens[i].tk == "userdata" and ps.tokens[i+1].tk ~= ":" then @@ -2942,20 +2974,7 @@ parse_record_body = function(ps: ParseState, i: integer, def: Type, node: Node, end i = i + 1 elseif ps.tokens[i].tk == "{" then - if def.typename == "arrayrecord" then - i = failskip(ps, i, "duplicated declaration of array element type in record", parse_type as SkipFunction) - else - i = i + 1 - local t: Type - i, t = parse_type(ps, i) - if ps.tokens[i].tk == "}" then - i = verify_tk(ps, i, "}") - else - return fail(ps, i, "expected an array declaration") - end - def.typename = "arrayrecord" - def.elements = t - end + return fail(ps, i, "syntax error: this syntax is no longer valid; declare arrayrecord at the top with 'is {...}'") elseif ps.tokens[i].tk == "type" and ps.tokens[i + 1].tk ~= ":" then i = i + 1 local iv = i From 27124cf45013576169386a3cec0d6b9db77eb3b8 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Sun, 1 Oct 2023 01:45:53 -0300 Subject: [PATCH 019/224] refactor: use `where` clause in fact interfaces --- tl.lua | 22 +++++++--------------- tl.tl | 38 ++++++++++---------------------------- 2 files changed, 17 insertions(+), 43 deletions(-) diff --git a/tl.lua b/tl.lua index ba849d0ae..1f248d086 100644 --- a/tl.lua +++ b/tl.lua @@ -1242,9 +1242,6 @@ local table_types = { - - - @@ -1258,8 +1255,6 @@ local TruthyFact = {} - - local NotFact = {} @@ -1271,8 +1266,6 @@ local NotFact = {} - - local AndFact = {} @@ -1285,8 +1278,6 @@ local AndFact = {} - - local OrFact = {} @@ -1299,8 +1290,6 @@ local OrFact = {} - - local EqFact = {} @@ -1313,8 +1302,6 @@ local EqFact = {} - - local IsFact = {} @@ -1336,8 +1323,6 @@ local IsFact = {} - - @@ -1437,6 +1422,13 @@ local Node = {ExpectedContext = {}, } + + + + + + + diff --git a/tl.tl b/tl.tl index 62656af38..3f447e4a1 100644 --- a/tl.tl +++ b/tl.tl @@ -1067,11 +1067,8 @@ local table_types : {TypeName:boolean} = { } local record Type - {Type} - - metamethod __is: function(self: Type) = macroexp(self: Type) - self.typename ~= nil - end + is {Type} + where self.typename ~= nil y: integer x: integer @@ -1250,9 +1247,7 @@ local type Fact | EqFact local record TruthyFact - metamethod __is: function(self: Fact): boolean = macroexp(self: TruthyFact): boolean - self.fact == "truthy" - end + where self.fact == "truthy" fact: FactType where: Where @@ -1261,9 +1256,7 @@ local record TruthyFact end local record NotFact - metamethod __is: function(self: Fact): boolean = macroexp(self: NotFact): boolean - self.fact == "not" - end + where self.fact == "not" fact: FactType where: Where @@ -1274,9 +1267,7 @@ local record NotFact end local record AndFact - metamethod __is: function(self: Fact): boolean = macroexp(self: AndFact): boolean - self.fact == "and" - end + where self.fact == "and" fact: FactType where: Where @@ -1288,9 +1279,7 @@ local record AndFact end local record OrFact - metamethod __is: function(self: Fact): boolean = macroexp(self: AndFact): boolean - self.fact == "or" - end + where self.fact == "or" fact: FactType where: Where @@ -1302,9 +1291,7 @@ local record OrFact end local record EqFact - metamethod __is: function(self: Fact): boolean = macroexp(self: AndFact): boolean - self.fact == "==" - end + where self.fact == "==" fact: FactType where: Where @@ -1316,9 +1303,7 @@ local record EqFact end local record IsFact - metamethod __is: function(self: Fact): boolean = macroexp(self: AndFact): boolean - self.fact == "is" - end + where self.fact == "is" fact: FactType where: Where @@ -1349,11 +1334,8 @@ local attributes : {Attribute: boolean} = { local is_attribute : {string:boolean} = attributes as {string:boolean} local record Node - {Node} - - metamethod __is: function(self: Node) = macroexp(self: Node) - self.kind ~= nil - end + is {Node} + where self.kind ~= nil record ExpectedContext kind: NodeKind From bdc937f939910724dd69114fea746b211c1b8264 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Mon, 20 Nov 2023 01:13:12 -0300 Subject: [PATCH 020/224] remove arrayrecord: now defined as array interface in records --- docs/tutorial.md | 15 ++-- tl.lua | 173 ++++++++++++++++++++++++++------------------- tl.tl | 179 +++++++++++++++++++++++++++-------------------- 3 files changed, 209 insertions(+), 158 deletions(-) diff --git a/docs/tutorial.md b/docs/tutorial.md index 57b836411..68863444b 100644 --- a/docs/tutorial.md +++ b/docs/tutorial.md @@ -155,7 +155,6 @@ Finally, there are types that must be declared and referred to using names: * enum * record * userdata - * arrayrecord Here is an example declaration of each. Again, we'll go into more detail below, but this should give you an overview: @@ -180,9 +179,8 @@ local record File close: function(File): boolean, string end --- an arrayrecord: a record which doubles as a record and an array -local record TreeNode - {TreeNode} +-- a record can doubles as a record and an array, by declaring an array interface +local record TreeNode is {TreeNode} item: T end ``` @@ -489,13 +487,12 @@ local record Obj end ``` -A record can also have an array part, making it an "arrayrecord". The -following is an arrayrecord. You can use it both as a record, accessing its -fields by name, and as an array, accessing its entries by number. +A record can also store array data, by declaring an array interface. You can +use it both as a record, accessing its fields by name, and as an array, +accessing its entries by number. A record can have only one array interface. ```lua -local record Node - {Node} +local record Node is {Node} weight: number name: string end diff --git a/tl.lua b/tl.lua index 1f248d086..b1f81730b 100644 --- a/tl.lua +++ b/tl.lua @@ -167,6 +167,7 @@ tl.warning_kinds = wk + tl.typecodes = { NIL = 0x00000001, @@ -186,7 +187,6 @@ tl.typecodes = { INTEGER = 0x00010002, ARRAY = 0x00010008, RECORD = 0x00020008, - ARRAYRECORD = 0x00030008, MAP = 0x00040008, TUPLE = 0x00080008, EMPTY_TABLE = 0x00000008, @@ -1025,13 +1025,11 @@ end - local table_types = { ["array"] = true, ["map"] = true, - ["arrayrecord"] = true, ["record"] = true, ["interface"] = true, ["emptytable"] = true, @@ -1438,11 +1436,12 @@ local Node = {ExpectedContext = {}, } local function is_array_type(t) - return t.typename == "array" or t.typename == "arrayrecord" + + return t.typename == "array" or t.elements ~= nil end local function is_record_type(t) - return t.typename == "record" or t.typename == "arrayrecord" or t.typename == "interface" + return t.typename == "record" or t.typename == "interface" end local function is_number_type(t) @@ -2883,22 +2882,21 @@ local function parse_interface_name(ps, i) return i, typ end -local function parse_arrayrecord_declaration(ps, i, def) - if def.typename == "arrayrecord" then - i = failskip(ps, i, "duplicated declaration of array element type in record", parse_type) - else - i = i + 1 - local t - i, t = parse_type(ps, i) - if ps.tokens[i].tk == "}" then - i = verify_tk(ps, i, "}") - else - return fail(ps, i, "expected an array declaration") - end - def.typename = "arrayrecord" - def.elements = t +local function parse_array_interface_type(ps, i, def) + if def.interface_list and def.interface_list[1].typename == "array" then + return failskip(ps, i, "duplicated declaration of array element type", parse_type) end - return i + local t + i, t = parse_base_type(ps, i) + if not t then + return i + end + if t.typename ~= "array" then + fail(ps, i, "expected an array declaration") + return i + end + def.elements = t.elements + return i, t end parse_record_body = function(ps, i, def, node, name) @@ -2911,17 +2909,27 @@ parse_record_body = function(ps, i, def, node, name) end if ps.tokens[i].tk == "{" then - i = parse_arrayrecord_declaration(ps, i, def) + local atype + i, atype = parse_array_interface_type(ps, i, def) + if atype then + def.interface_list = { atype } + end end if ps.tokens[i].tk == "is" then i = i + 1 if ps.tokens[i].tk == "{" then - i = parse_arrayrecord_declaration(ps, i, def) + local atype + i, atype = parse_array_interface_type(ps, i, def) if ps.tokens[i].tk == "," then i = i + 1 i, def.interface_list = parse_trying_list(ps, i, {}, parse_interface_name) + else + def.interface_list = {} + end + if atype then + table.insert(def.interface_list, 1, atype) end else i, def.interface_list = parse_trying_list(ps, i, {}, parse_interface_name) @@ -2956,7 +2964,7 @@ parse_record_body = function(ps, i, def, node, name) end i = i + 1 elseif ps.tokens[i].tk == "{" then - return fail(ps, i, "syntax error: this syntax is no longer valid; declare arrayrecord at the top with 'is {...}'") + return fail(ps, i, "syntax error: this syntax is no longer valid; declare array interface at the top with 'is {...}'") elseif ps.tokens[i].tk == "type" and ps.tokens[i + 1].tk ~= ":" then i = i + 1 local iv = i @@ -4616,7 +4624,6 @@ function tl.pretty_print_ast(ast, gen_target, mode) visit_type.cbs["array"] = visit_type.cbs["string"] visit_type.cbs["map"] = visit_type.cbs["string"] visit_type.cbs["tupletable"] = visit_type.cbs["string"] - visit_type.cbs["arrayrecord"] = visit_type.cbs["string"] visit_type.cbs["record"] = visit_type.cbs["string"] visit_type.cbs["enum"] = visit_type.cbs["string"] visit_type.cbs["boolean"] = visit_type.cbs["string"] @@ -4759,20 +4766,11 @@ local equality_binop = { }, ["record"] = { ["emptytable"] = BOOLEAN, - ["arrayrecord"] = BOOLEAN, ["record"] = BOOLEAN, ["nil"] = BOOLEAN, }, ["array"] = { ["emptytable"] = BOOLEAN, - ["arrayrecord"] = BOOLEAN, - ["array"] = BOOLEAN, - ["nil"] = BOOLEAN, - }, - ["arrayrecord"] = { - ["emptytable"] = BOOLEAN, - ["arrayrecord"] = BOOLEAN, - ["record"] = BOOLEAN, ["array"] = BOOLEAN, ["nil"] = BOOLEAN, }, @@ -4789,7 +4787,6 @@ local equality_binop = { local unop_types = { ["#"] = { - ["arrayrecord"] = INTEGER, ["string"] = INTEGER, ["array"] = INTEGER, ["tupletable"] = INTEGER, @@ -4810,7 +4807,6 @@ local unop_types = { ["integer"] = BOOLEAN, ["boolean"] = BOOLEAN, ["record"] = BOOLEAN, - ["arrayrecord"] = BOOLEAN, ["array"] = BOOLEAN, ["tupletable"] = BOOLEAN, ["map"] = BOOLEAN, @@ -4873,9 +4869,6 @@ local binop_types = { ["record"] = { ["boolean"] = BOOLEAN, }, - ["arrayrecord"] = { - ["boolean"] = BOOLEAN, - }, ["map"] = { ["boolean"] = BOOLEAN, }, @@ -6425,7 +6418,7 @@ tl.type_check = function(ast, opts) copy.is_method = t.is_method copy.args, same = resolve(t.args, same) copy.rets, same = resolve(t.rets, same) - elseif t.typename == "record" or t.typename == "arrayrecord" then + elseif is_record_type(t) then if t.typeargs then copy.typeargs = {} for i, tf in ipairs(t.typeargs) do @@ -6433,6 +6426,7 @@ tl.type_check = function(ast, opts) end end + if t.elements then copy.elements, same = resolve(t.elements, same) end @@ -7188,6 +7182,17 @@ tl.type_check = function(ast, opts) elseif t1.typename == "nominal" then return are_same_nominals(t1, t2) elseif t1.typename == "record" then + + if (t1.elements ~= nil) ~= (t2.elements ~= nil) then + return false, { Err(t1, "types do not have the same array interface") } + end + if t1.elements and t2.elements then + local ok, errs = same_type(t1.elements, t2.elements) + if not ok then + return ok, errs + end + end + return invariant_match_fields_to_record(t1, t2) elseif t1.typename == "function" then local argdelta = t1.is_method and 1 or 0 @@ -7209,12 +7214,6 @@ tl.type_check = function(ast, opts) add_errs_prefixing(t1, errs, all_errs, "return " .. i) end return any_errors(all_errs) - elseif t1.typename == "arrayrecord" then - local ok, errs = same_type(t1.elements, t2.elements) - if not ok then - return ok, errs - end - return invariant_match_fields_to_record(t1, t2) end return true end @@ -7303,7 +7302,6 @@ tl.type_check = function(ast, opts) array = true, map = true, record = true, - arrayrecord = true, tupletable = true, } @@ -7547,15 +7545,9 @@ tl.type_check = function(ast, opts) return combine_map_errs(errs_keys, errs_values) end elseif t2.typename == "record" then - if is_record_type(t1) then - return match_fields_to_record(t1, t2) - elseif is_typetype(t1) and is_record_type(t1.def) then - return is_a(t1.def, t2, for_equality) - end - elseif t2.typename == "arrayrecord" then - if t1.typename == "array" then - return is_a(t1.elements, t2.elements) - elseif t1.typename == "tupletable" then + + + if t1.typename == "tupletable" and t2.elements then if t2.inferred_len and t2.inferred_len > #t1.types then return false, { Err(t1, "incompatible length, expected maximum length of " .. tostring(#t1.types) .. ", got " .. tostring(t2.inferred_len)) } end @@ -7567,12 +7559,17 @@ tl.type_check = function(ast, opts) return false, { Err(t2, "got %s (from %s), expected %s", t1a, t1, t2) } end return true - elseif t1.typename == "record" then - return match_fields_to_record(t1, t2) - elseif t1.typename == "arrayrecord" then + end + if t1.elements and t2.elements then if not is_a(t1.elements, t2.elements) then return false, { Err(t1, "array parts have incompatible element types") } end + if t1.typename == "array" then + return true + end + end + + if is_record_type(t1) then return match_fields_to_record(t1, t2) elseif is_typetype(t1) and is_record_type(t1.def) then return is_a(t1.def, t2, for_equality) @@ -9366,7 +9363,17 @@ tl.type_check = function(ast, opts) typ.elements = nil node_error(node, "cannot determine type of table literal") elseif is_record and is_array then - typ.typename = "arrayrecord" + typ.typename = "record" + typ.interface_list = { + a_type({ + filename = filename, + y = node.y, + x = node.x, + typename = "array", + elements = typ.elements, + }), + } + elseif is_record and is_map then if typ.keys.typename == "string" then typ.typename = "map" @@ -9379,12 +9386,8 @@ tl.type_check = function(ast, opts) node_error(node, "cannot determine type of table literal") end elseif is_array then - if is_not_tuple then - typ.typename = "array" - typ.inferred_len = largest_array_idx - 1 - else - local pure_array = true - + local pure_array = true + if not is_not_tuple then local last_t for _, current_t in pairs(typ.types) do if last_t then @@ -9395,13 +9398,16 @@ tl.type_check = function(ast, opts) end last_t = current_t end + end + if pure_array then + typ.typename = "array" - if not pure_array then - typ.typename = "tupletable" - else - typ.typename = "array" - typ.inferred_len = largest_array_idx - 1 - end + assert(typ.elements) + typ.inferred_len = largest_array_idx - 1 + else + typ.typename = "tupletable" + typ.elements = nil + assert(typ.types) end elseif is_record then typ.typename = "record" @@ -9550,6 +9556,21 @@ tl.type_check = function(ast, opts) return is_total, missing end + local function find_in_interface_list(a, f) + if not a.interface_list then + return nil + end + + for _, t in ipairs(a.interface_list) do + local ret = f(t) + if ret then + return ret + end + end + + return nil + end + local visit_node = {} visit_node.cbs = { @@ -10634,6 +10655,13 @@ tl.type_check = function(ast, opts) local types_op = unop_types[node.op.op] node.type = types_op[a.typename] + + if not node.type then + node.type = find_in_interface_list(a, function(t) + return types_op[t.typename] + end) + end + local meta_on_operator if not node.type then local mt_name = unop_to_metamethod[node.op.op] @@ -10644,6 +10672,7 @@ tl.type_check = function(ast, opts) node_error(node, "cannot use operator '" .. node.op.op:gsub("%%", "%%%%") .. "' on type %s", resolve_tuple(orig_a)) end end + if a.typename == "map" then if a.keys.typename == "number" or a.keys.typename == "integer" then add_warning("hint", node, "using the '#' operator on a map with numeric key type may produce unexpected results") @@ -11035,7 +11064,6 @@ tl.type_check = function(ast, opts) visit_type.cbs["nestedtype"] = visit_type.cbs["string"] visit_type.cbs["array"] = visit_type.cbs["string"] visit_type.cbs["map"] = visit_type.cbs["string"] - visit_type.cbs["arrayrecord"] = visit_type.cbs["record"] visit_type.cbs["enum"] = visit_type.cbs["string"] visit_type.cbs["boolean"] = visit_type.cbs["string"] visit_type.cbs["nil"] = visit_type.cbs["string"] @@ -11098,7 +11126,6 @@ local typename_to_typecode = { ["array"] = tl.typecodes.ARRAY, ["map"] = tl.typecodes.MAP, ["tupletable"] = tl.typecodes.TUPLE, - ["arrayrecord"] = tl.typecodes.ARRAYRECORD, ["interface"] = tl.typecodes.INTERFACE, ["record"] = tl.typecodes.RECORD, ["enum"] = tl.typecodes.ENUM, diff --git a/tl.tl b/tl.tl index 3f447e4a1..eeaf749d3 100644 --- a/tl.tl +++ b/tl.tl @@ -104,7 +104,7 @@ local record tl x: integer y: integer ref: integer -- NOMINAL - fields: {string: integer} -- RECORD, ARRAYRECORD + fields: {string: integer} -- RECORD enums: {string} -- ENUM args: {{integer, string}} -- FUNCTION rets: {{integer, string}} -- FUNCTION @@ -162,7 +162,8 @@ tl.warning_kinds = wk -- * "any" satisfies all Lua masks -- * bits 30-27: if valid: other Teal types ("nominal", "poly", "union", "typevar") -- * bits 24-26: reserved --- * bits 16-19: if valid: Teal types ("array", "record", "arrayrecord", "map", "tuple", "enum") that map to a Lua type ("table", "string") +-- * bits 20-23: abstract types ("interface") +-- * bits 16-19: if valid: Teal types ("array", "record", "map", "tuple", "enum") that map to a Lua type ("table", "string") -- * bit 15: if not valid: value is unknown -- * bits 8-14: reserved -- * bits 0-7: (LSB) Lua types, one bit for each ("nil", "number", "boolean", "string", table, "function", "userdata", "thread") @@ -186,7 +187,6 @@ tl.typecodes = { INTEGER = 0x00010002, ARRAY = 0x00010008, RECORD = 0x00020008, - ARRAYRECORD = 0x00030008, MAP = 0x00040008, TUPLE = 0x00080008, EMPTY_TABLE = 0x00000008, @@ -1000,7 +1000,6 @@ local enum TypeName "array" "map" "tupletable" - "arrayrecord" "record" "interface" "enum" @@ -1031,7 +1030,6 @@ end local table_types : {TypeName:boolean} = { ["array"] = true, ["map"] = true, - ["arrayrecord"] = true, ["record"] = true, ["interface"] = true, ["emptytable"] = true, @@ -1438,11 +1436,12 @@ local type Where | Type local function is_array_type(t:Type): boolean - return t.typename == "array" or t.typename == "arrayrecord" + -- checking array interface + return t.typename == "array" or t.elements ~= nil end local function is_record_type(t:Type): boolean - return t.typename == "record" or t.typename == "arrayrecord" or t.typename == "interface" + return t.typename == "record" or t.typename == "interface" end local function is_number_type(t:Type): boolean @@ -2883,22 +2882,21 @@ local function parse_interface_name(ps: ParseState, i: integer): integer, Type, return i, typ end -local function parse_arrayrecord_declaration(ps: ParseState, i: integer, def: Type): integer - if def.typename == "arrayrecord" then - i = failskip(ps, i, "duplicated declaration of array element type in record", parse_type as SkipFunction) - else - i = i + 1 - local t: Type - i, t = parse_type(ps, i) - if ps.tokens[i].tk == "}" then - i = verify_tk(ps, i, "}") - else - return fail(ps, i, "expected an array declaration") - end - def.typename = "arrayrecord" - def.elements = t +local function parse_array_interface_type(ps: ParseState, i: integer, def: Type): integer, Type + if def.interface_list and def.interface_list[1].typename == "array" then + return failskip(ps, i, "duplicated declaration of array element type", parse_type as SkipFunction) end - return i + local t: Type + i, t = parse_base_type(ps, i) + if not t then + return i + end + if t.typename ~= "array" then + fail(ps, i, "expected an array declaration") + return i + end + def.elements = t.elements + return i, t end parse_record_body = function(ps: ParseState, i: integer, def: Type, node: Node, name: string): integer, Node @@ -2911,17 +2909,27 @@ parse_record_body = function(ps: ParseState, i: integer, def: Type, node: Node, end if ps.tokens[i].tk == "{" then - i = parse_arrayrecord_declaration(ps, i, def) + local atype: Type + i, atype = parse_array_interface_type(ps, i, def) + if atype then + def.interface_list = { atype } + end end if ps.tokens[i].tk == "is" then i = i + 1 if ps.tokens[i].tk == "{" then - i = parse_arrayrecord_declaration(ps, i, def) + local atype: Type + i, atype = parse_array_interface_type(ps, i, def) if ps.tokens[i].tk == "," then i = i + 1 i, def.interface_list = parse_trying_list(ps, i, {}, parse_interface_name) + else + def.interface_list = {} + end + if atype then + table.insert(def.interface_list, 1, atype) end else i, def.interface_list = parse_trying_list(ps, i, {}, parse_interface_name) @@ -2956,7 +2964,7 @@ parse_record_body = function(ps: ParseState, i: integer, def: Type, node: Node, end i = i + 1 elseif ps.tokens[i].tk == "{" then - return fail(ps, i, "syntax error: this syntax is no longer valid; declare arrayrecord at the top with 'is {...}'") + return fail(ps, i, "syntax error: this syntax is no longer valid; declare array interface at the top with 'is {...}'") elseif ps.tokens[i].tk == "type" and ps.tokens[i + 1].tk ~= ":" then i = i + 1 local iv = i @@ -4616,7 +4624,6 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | visit_type.cbs["array"] = visit_type.cbs["string"] visit_type.cbs["map"] = visit_type.cbs["string"] visit_type.cbs["tupletable"] = visit_type.cbs["string"] - visit_type.cbs["arrayrecord"] = visit_type.cbs["string"] visit_type.cbs["record"] = visit_type.cbs["string"] visit_type.cbs["enum"] = visit_type.cbs["string"] visit_type.cbs["boolean"] = visit_type.cbs["string"] @@ -4759,20 +4766,11 @@ local equality_binop = { }, ["record"] = { ["emptytable"] = BOOLEAN, - ["arrayrecord"] = BOOLEAN, ["record"] = BOOLEAN, ["nil"] = BOOLEAN, }, ["array"] = { ["emptytable"] = BOOLEAN, - ["arrayrecord"] = BOOLEAN, - ["array"] = BOOLEAN, - ["nil"] = BOOLEAN, - }, - ["arrayrecord"] = { - ["emptytable"] = BOOLEAN, - ["arrayrecord"] = BOOLEAN, - ["record"] = BOOLEAN, ["array"] = BOOLEAN, ["nil"] = BOOLEAN, }, @@ -4789,7 +4787,6 @@ local equality_binop = { local unop_types: {string:{string:Type}} = { ["#"] = { - ["arrayrecord"] = INTEGER, ["string"] = INTEGER, ["array"] = INTEGER, ["tupletable"] = INTEGER, @@ -4810,7 +4807,6 @@ local unop_types: {string:{string:Type}} = { ["integer"] = BOOLEAN, ["boolean"] = BOOLEAN, ["record"] = BOOLEAN, - ["arrayrecord"] = BOOLEAN, ["array"] = BOOLEAN, ["tupletable"] = BOOLEAN, ["map"] = BOOLEAN, @@ -4873,9 +4869,6 @@ local binop_types: {string:{TypeName:{TypeName:Type}}} = { ["record"] = { ["boolean"] = BOOLEAN, }, - ["arrayrecord"] = { - ["boolean"] = BOOLEAN, - }, ["map"] = { ["boolean"] = BOOLEAN, }, @@ -6425,7 +6418,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string copy.is_method = t.is_method copy.args, same = resolve(t.args, same) copy.rets, same = resolve(t.rets, same) - elseif t.typename == "record" or t.typename == "arrayrecord" then + elseif is_record_type(t) then if t.typeargs then copy.typeargs = {} for i, tf in ipairs(t.typeargs) do @@ -6433,6 +6426,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end + -- checking array interface if t.elements then copy.elements, same = resolve(t.elements, same) end @@ -7188,6 +7182,17 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string elseif t1.typename == "nominal" then return are_same_nominals(t1, t2) elseif t1.typename == "record" then + -- checking array interface + if (t1.elements ~= nil) ~= (t2.elements ~= nil) then + return false, { Err(t1, "types do not have the same array interface") } + end + if t1.elements and t2.elements then + local ok, errs = same_type(t1.elements, t2.elements) + if not ok then + return ok, errs + end + end + return invariant_match_fields_to_record(t1, t2) elseif t1.typename == "function" then local argdelta = t1.is_method and 1 or 0 @@ -7209,12 +7214,6 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string add_errs_prefixing(t1 as Node, errs, all_errs, "return " .. i) end return any_errors(all_errs) - elseif t1.typename == "arrayrecord" then - local ok, errs = same_type(t1.elements, t2.elements) - if not ok then - return ok, errs - end - return invariant_match_fields_to_record(t1, t2) end return true end @@ -7303,7 +7302,6 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string array = true, map = true, record = true, - arrayrecord = true, tupletable = true, } @@ -7547,15 +7545,9 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return combine_map_errs(errs_keys, errs_values) end elseif t2.typename == "record" then - if is_record_type(t1) then - return match_fields_to_record(t1, t2) - elseif is_typetype(t1) and is_record_type(t1.def) then -- record as prototype - return is_a(t1.def, t2, for_equality) - end - elseif t2.typename == "arrayrecord" then - if t1.typename == "array" then - return is_a(t1.elements, t2.elements) - elseif t1.typename == "tupletable" then + + -- checking array interface + if t1.typename == "tupletable" and t2.elements then if t2.inferred_len and t2.inferred_len > #t1.types then return false, { Err(t1, "incompatible length, expected maximum length of " .. tostring(#t1.types) .. ", got " .. tostring(t2.inferred_len)) } end @@ -7567,12 +7559,17 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return false, { Err(t2, "got %s (from %s), expected %s", t1a, t1, t2) } end return true - elseif t1.typename == "record" then - return match_fields_to_record(t1, t2) - elseif t1.typename == "arrayrecord" then + end + if t1.elements and t2.elements then if not is_a(t1.elements, t2.elements) then return false, { Err(t1, "array parts have incompatible element types") } end + if t1.typename == "array" then + return true + end + end + + if is_record_type(t1) then return match_fields_to_record(t1, t2) elseif is_typetype(t1) and is_record_type(t1.def) then -- record as prototype return is_a(t1.def, t2, for_equality) @@ -9366,7 +9363,17 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string typ.elements = nil node_error(node, "cannot determine type of table literal") elseif is_record and is_array then - typ.typename = "arrayrecord" + typ.typename = "record" + typ.interface_list = { + a_type { + filename = filename, + y = node.y, + x = node.x, + typename = "array", + elements = typ.elements, + } + } + -- TODO adopt logic from is_array below when we accept tupletable as an interface elseif is_record and is_map then if typ.keys.typename == "string" then typ.typename = "map" @@ -9379,12 +9386,8 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string node_error(node, "cannot determine type of table literal") end elseif is_array then - if is_not_tuple then - typ.typename = "array" - typ.inferred_len = largest_array_idx - 1 - else - local pure_array = true - + local pure_array = true + if not is_not_tuple then local last_t: Type for _, current_t in pairs(typ.types as {integer:Type}) do if last_t then @@ -9395,13 +9398,16 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end last_t = current_t end - - if not pure_array then - typ.typename = "tupletable" - else - typ.typename = "array" - typ.inferred_len = largest_array_idx - 1 - end + end + if pure_array then + typ.typename = "array" + -- typ.types = nil + assert(typ.elements) + typ.inferred_len = largest_array_idx - 1 + else + typ.typename = "tupletable" + typ.elements = nil + assert(typ.types) end elseif is_record then typ.typename = "record" @@ -9550,6 +9556,21 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return is_total, missing end + local function find_in_interface_list(a: Type, f: function(Type): T): T + if not a.interface_list then + return nil + end + + for _, t in ipairs(a.interface_list) do + local ret = f(t) + if ret then + return ret + end + end + + return nil + end + local visit_node: Visitor = {} visit_node.cbs = { @@ -10634,6 +10655,13 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local types_op = unop_types[node.op.op] node.type = types_op[a.typename] + + if not node.type then + node.type = find_in_interface_list(a, function(t: Type): Type + return types_op[t.typename] + end) + end + local meta_on_operator: integer if not node.type then local mt_name = unop_to_metamethod[node.op.op] @@ -10644,6 +10672,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string node_error(node, "cannot use operator '" .. node.op.op:gsub("%%", "%%%%") .. "' on type %s", resolve_tuple(orig_a)) end end + if a.typename == "map" then if a.keys.typename == "number" or a.keys.typename == "integer" then add_warning("hint", node, "using the '#' operator on a map with numeric key type may produce unexpected results") @@ -11035,7 +11064,6 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string visit_type.cbs["nestedtype"] = visit_type.cbs["string"] visit_type.cbs["array"] = visit_type.cbs["string"] visit_type.cbs["map"] = visit_type.cbs["string"] - visit_type.cbs["arrayrecord"] = visit_type.cbs["record"] visit_type.cbs["enum"] = visit_type.cbs["string"] visit_type.cbs["boolean"] = visit_type.cbs["string"] visit_type.cbs["nil"] = visit_type.cbs["string"] @@ -11098,7 +11126,6 @@ local typename_to_typecode : {TypeName:integer} = { ["array"] = tl.typecodes.ARRAY, ["map"] = tl.typecodes.MAP, ["tupletable"] = tl.typecodes.TUPLE, - ["arrayrecord"] = tl.typecodes.ARRAYRECORD, ["interface"] = tl.typecodes.INTERFACE, ["record"] = tl.typecodes.RECORD, ["enum"] = tl.typecodes.ENUM, From bf754030bf39529d79b333fd2422a0a27c87c4bb Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Mon, 20 Nov 2023 01:23:33 -0300 Subject: [PATCH 021/224] interfaces: prevent assignment to interfaces and types Interfaces are abstract, so they cannot be assigned to, unlike record fields, which are concrete values when records are used as prototypes. We already prevented assignment overwriting top-level records. This commit also extends the restriction to nested records, for consistency. --- spec/assignment/to_interface_spec.lua | 59 ++++ spec/declaration/record_spec.lua | 415 ++++++++++++++++---------- tl.lua | 55 ++-- tl.tl | 57 ++-- 4 files changed, 390 insertions(+), 196 deletions(-) create mode 100644 spec/assignment/to_interface_spec.lua diff --git a/spec/assignment/to_interface_spec.lua b/spec/assignment/to_interface_spec.lua new file mode 100644 index 000000000..9f384276d --- /dev/null +++ b/spec/assignment/to_interface_spec.lua @@ -0,0 +1,59 @@ +local util = require("spec.util") + +local scopes = {} +for _, fst in ipairs({"def", "var"}) do + for _, snd in ipairs({"def", "var"}) do + table.insert(scopes, "to outer " .. fst .. " with inner " .. snd) + table.insert(scopes, "to inner " .. fst .. " with outer " .. snd) + end +end + +local assignments = { + ["to outer def with inner var"] = "Outer = { field = { x = 42 } }", -- always fails (1) + ["to outer def with inner def"] = "Outer = { Inner = { x = 42 } }", -- always fails (1) + ["to outer var with inner var"] = "local v: Outer = { field = { x = 42 } }", -- always succeeds + ["to outer var with inner def"] = "local v: Outer = { Inner = { x = 42 } }", -- always fails (2) + ["to inner def with outer def"] = "Outer.Inner = { x = 42 }", -- always fails (3) + ["to inner def with outer var"] = "local v: Outer = {}; v.Inner = { x = 42 }", -- always fails (3) + ["to inner var with outer def"] = "Outer.field = { x = 42 }", -- succeeds in record only (4) + ["to inner var with outer var"] = "local v: Outer = {}; v.field = { x = 42 }", -- always succeeds +} + +describe("assignment", function() + for _, outer in ipairs({"record", "interface"}) do + for _, inner in ipairs({"record", "interface"}) do + for _, scope in ipairs(scopes) do + assert(assignments[scope]) + + local err + if scope:match("to outer def") then -- 1 + err = { { y = 6, msg = "cannot reassign a type" } } + elseif scope:match("with inner def") then -- 2 + err = { { y = 6, msg = "cannot reassign a type" } } + elseif scope:match("to inner def") then -- 3 + if outer == "interface" and scope:match("with outer def") then + err = { { y = 6, msg = "interfaces are abstract; consider using a concrete record" } } + else + err = { { y = 6, msg = "cannot reassign a type" } } + end + elseif outer == "interface" and scope == "to inner var with outer def" then -- 4 + err = { { y = 6, msg = "interfaces are abstract; consider using a concrete record" } } + else + err = nil + end + + it((err and "fails" or "succeeds") .. " with outer " .. outer .. " and inner " .. inner .. ", assignment " .. scope, + (err and util.check_type_error or util.check)([[ + local type Outer = ]] .. outer .. [[ + type Inner = ]] .. inner .. [[ + x: number + end + field: Inner + end + + ]] .. assignments[scope] .. [[ + ]], err)) + end + end + end +end) diff --git a/spec/declaration/record_spec.lua b/spec/declaration/record_spec.lua index c78939baf..067d94d9e 100644 --- a/spec/declaration/record_spec.lua +++ b/spec/declaration/record_spec.lua @@ -1,13 +1,18 @@ local util = require("spec.util") -local function pick(...) - return select(...) .. "\n" +local function array(i, arr, not_arr) + if i == 2 or i == 4 then + return arr .. "\n" + else + return (not_arr or "") .. "\n" + end end -for i, name in ipairs({"records", "arrayrecords"}) do - describe(name, function() +for i, name in ipairs({"records", "arrayrecords", "interfaces", "arrayinterfaces"}) do + local statement = select(i, "record", "record", "interface", "interface") + describe("#" .. name, function() it("can be declared with 'local type'", util.check([[ - local type Point = record ]]..pick(i, "", "{Point}")..[[ + local type Point = ]]..statement..[[ ]]..array(i, "{Point}")..[[ x: number y: number end @@ -17,8 +22,8 @@ for i, name in ipairs({"records", "arrayrecords"}) do p.y = 12 ]])) - it("can be declared with 'local record'", util.check([[ - local record Point ]]..pick(i, "", "{Point}")..[[ + it("can be declared with 'local "..statement.."'", util.check([[ + local ]]..statement..[[ Point ]]..array(i, "{Point}")..[[ x: number y: number end @@ -29,7 +34,7 @@ for i, name in ipairs({"records", "arrayrecords"}) do ]])) it("produces a nice error when declared with bare 'local'", util.check_syntax_error([[ - local Point = record ]]..pick(i, "", "{Point}")..[[ + local Point = ]]..statement..[[ ]]..array(i, "{Point}")..[[ x: number y: number end @@ -38,29 +43,31 @@ for i, name in ipairs({"records", "arrayrecords"}) do p.x = 12 p.y = 12 ]], { - { y = 1, msg = "syntax error: this syntax is no longer valid; use 'local record Point'" }, + { y = 1, msg = "syntax error: this syntax is no longer valid; use 'local "..statement.." Point'" }, })) it("produces a nice error when attempting to nest in a table", util.check_syntax_error([[ local t = { - Point = record ]]..pick(i, "", "{Point}")..[[ + Point = ]]..statement..[[ ]]..array(i, "{Point}")..[[ x: number y: number end } ]], { - { y = 2, msg = "syntax error: this syntax is no longer valid; declare nested record inside a record" }, + { y = 2, msg = (statement == "interface" + and "syntax error: cannot declare interface inside a table; use a statement" + or "syntax error: this syntax is no longer valid; declare nested record inside a record") }, })) - it("accepts record as soft keyword", util.check([[ - local record = 2 + it("accepts "..statement.." as soft keyword", util.check([[ + local ]]..statement..[[ = 2 local t = { - record = record, + ]]..statement..[[ = ]]..statement..[[, } ]])) it("can be declared with 'global type'", util.check([[ - global type Point = record ]]..pick(i, "", "{Point}")..[[ + global type Point = ]]..statement..[[ ]]..array(i, "{Point}")..[[ x: number y: number end @@ -70,8 +77,8 @@ for i, name in ipairs({"records", "arrayrecords"}) do p.y = 12 ]])) - it("can be declared with 'global record'", util.check([[ - global record Point ]]..pick(i, "", "{Point}")..[[ + it("can be declared with 'global "..statement.."'", util.check([[ + global ]]..statement..[[ Point ]]..array(i, "{Point}")..[[ x: number y: number end @@ -81,21 +88,33 @@ for i, name in ipairs({"records", "arrayrecords"}) do p.y = 12 ]])) - it("can have self-references", util.check([[ - local record SLAXML ]]..pick(i, "", "{SLAXML}")..[[ - parse: function(self: SLAXML, xml: string, anotherself: SLAXML) - end + if statement == "interface" then + it("can have self-references", util.check([[ + local interface SLAXML + parse: function(self: SLAXML, xml: string, anotherself: SLAXML) + end - local myxml = io.open('my.xml'):read('*all') - SLAXML:parse(myxml, SLAXML) - ]])) + local myxml = io.open('my.xml'):read('*all') + local slaxml: SLAXML = {} + slaxml:parse(myxml, slaxml) + ]])) + else + it("can have self-references", util.check([[ + local ]]..statement..[[ SLAXML ]]..array(i, "{SLAXML}")..[[ + parse: function(self: SLAXML, xml: string, anotherself: SLAXML) + end + + local myxml = io.open('my.xml'):read('*all') + SLAXML:parse(myxml, SLAXML) + ]])) + end it("can have circular type dependencies", util.check([[ - local type R = record ]]..pick(i, "", "{S}")..[[ + local type R = ]]..statement..[[ ]]..array(i, "{S}")..[[ foo: S end - local type S = record ]]..pick(i, "", "{R}")..[[ + local type S = ]]..statement..[[ ]]..array(i, "{R}")..[[ foo: R end @@ -106,29 +125,29 @@ for i, name in ipairs({"records", "arrayrecords"}) do it("recursive types don't trip up the resolver", util.check([[ local type EmptyString = enum "" end - local record ltn12 ]]..pick(i, "", "{ltn12}")..[[ + local ]]..statement..[[ ltn12 ]]..array(i, "{ltn12}")..[[ type FancySource = function(): T|EmptyString, string|FancySource end return ltn12 ]])) it("can overload functions", util.check([[ - global type love_graphics = record ]]..pick(i, "", "{love_graphics}")..[[ + global type love_graphics = ]]..statement..[[ ]]..array(i, "{love_graphics}")..[[ print: function(text: string, x: number, y: number, r: number, sx: number, sy: number, ox: number, oy: number, kx: number, ky:number) print: function(coloredtext: {any}, x: number, y: number, r: number, sx: number, sy: number, ox: number, oy: number, kx: number, ky:number) end - global type love = record ]]..pick(i, "", "{love}")..[[ + global type love = ]]..statement..[[ ]]..array(i, "{love}")..[[ graphics: love_graphics end - + ]] .. (statement ~= "interface" and [[ global function main() love.graphics.print("Hello world", 100, 100) end - ]])) + ]] or ""))) it("cannot overload other things", util.check_syntax_error([[ - global type love_graphics = record ]]..pick(i, "", "{love_graphics}")..[[ + global type love_graphics = ]]..statement..[[ ]]..array(i, "{love_graphics}")..[[ print: number print: string end @@ -142,7 +161,7 @@ for i, name in ipairs({"records", "arrayrecords"}) do "b" "c" end - local type R = record ]]..pick(i, "", "{number}")..[[ + local type R = ]]..statement..[[ ]]..array(i, "{number}")..[[ f: function(enums: {E}) f: function(tuple: {string, number}) end @@ -154,7 +173,7 @@ for i, name in ipairs({"records", "arrayrecords"}) do it("can report an error on unknown types in polymorphic definitions", util.check_type_error([[ -- this reports an error - local type R = record ]]..pick(i, "", "{R}")..[[ + local type R = ]]..statement..[[ ]]..array(i, "{R}")..[[ u: function(): UnknownType u: function(): string end @@ -168,7 +187,7 @@ for i, name in ipairs({"records", "arrayrecords"}) do it("can report an error on unknown types in polymorphic definitions in any order", util.check_type_error([[ -- this reports an error - local type R = record ]]..pick(i, "", "{R}")..[[ + local type R = ]]..statement..[[ ]]..array(i, "{R}")..[[ u: function(): string u: function(): UnknownType end @@ -181,14 +200,14 @@ for i, name in ipairs({"records", "arrayrecords"}) do })) it("can produce an intersection type for polymorphic functions", util.check([[ - local type requests = record ]]..pick(i, "", "{requests}")..[[ + local type requests = ]]..statement..[[ ]]..array(i, "{requests}")..[[ - type RequestOpts = record + type RequestOpts = ]]..statement..[[ {string} url: string end - type Response = record ]]..pick(i, "", "{Response}")..[[ + type Response = ]]..statement..[[ ]]..array(i, "{Response}")..[[ status_code: number end @@ -202,14 +221,14 @@ for i, name in ipairs({"records", "arrayrecords"}) do ]])) it("can check the arity of polymorphic functions", util.check_type_error([[ - local type requests = record ]]..pick(i, "", "{requests}")..[[ + local type requests = ]]..statement..[[ ]]..array(i, "{requests}")..[[ - type RequestOpts = record + type RequestOpts = ]]..statement..[[ -- {string} url: string end - type Response = record ]]..pick(i, "", "{Response}")..[[ + type Response = ]]..statement..[[ ]]..array(i, "{Response}")..[[ status_code: number end @@ -221,7 +240,7 @@ for i, name in ipairs({"records", "arrayrecords"}) do local r: requests = {} local resp = r.get("hello", 123, 123) ]], { - { y = 18, msg = "wrong number of arguments (given 3, expects 1 or 2)" } + { y = 18, x = 28, msg = "wrong number of arguments (given 3, expects 1 or 2)" } })) it("can be nested", function() @@ -229,12 +248,12 @@ for i, name in ipairs({"records", "arrayrecords"}) do ["req.d.tl"] = [[ local type requests = record - type RequestOpts = record + type RequestOpts = ]]..statement..[[ -- {string} url: string end - type Response = record ]]..pick(i, "", "{Response}")..[[ + type Response = ]]..statement..[[ ]]..array(i, "{Response}")..[[ status_code: number end @@ -253,7 +272,10 @@ for i, name in ipairs({"records", "arrayrecords"}) do print(r.status_code) print(r.status_coda) ]], { - { msg = "invalid key 'status_coda' in record 'r' of type Response" } + { msg = (statement == "interface") + and "invalid key 'status_coda' in 'r' of interface type Response" + or "invalid key 'status_coda' in record 'r' of type Response" + } }) end) @@ -262,12 +284,12 @@ for i, name in ipairs({"records", "arrayrecords"}) do ["req.d.tl"] = [[ local type requests = record - record RequestOpts + ]]..statement..[[ RequestOpts {string} url: string end - record Response ]]..pick(i, "", "{Response}")..[[ + ]]..statement..[[ Response ]]..array(i, "{Response}")..[[ status_code: number end @@ -286,37 +308,40 @@ for i, name in ipairs({"records", "arrayrecords"}) do print(r.status_code) print(r.status_coda) ]], { - { msg = "invalid key 'status_coda' in record 'r' of type Response" } + { msg = (statement == "interface") + and "invalid key 'status_coda' in 'r' of interface type Response" + or "invalid key 'status_coda' in record 'r' of type Response" + } }) end) - it("record and enum and not reserved words", util.check([[ - local type foo = record ]]..pick(i, "", "{foo}")..[[ - record: string + it(statement.." and enum and not reserved words", util.check([[ + local type foo = ]]..statement..[[ ]]..array(i, "{foo}")..[[ + ]]..statement..[[: string enum: number end local f: foo = {} - foo.record = "hello" - foo.enum = 123 + f.]]..statement..[[ = "hello" + f.enum = 123 ]])) it("can have nested generic " .. name, util.check([[ - local type foo = record ]]..pick(i, "", "{foo}")..[[ - type bar = record ]]..pick(i, "", "{bar}")..[[ + local type Foo = ]]..statement..[[ ]]..array(i, "{Foo}")..[[ + type Bar = ]]..statement..[[ ]]..array(i, "{Bar}")..[[ x: T end - example: bar + example: Bar end - local f: foo = {} + local f: Foo = {} - foo.example = { x = "hello" } + f.example = { x = "hello" } ]])) it("can have nested enums", util.check([[ - local type foo = record ]]..pick(i, "", "{foo}")..[[ + local type foo = ]]..statement..[[ ]]..array(i, "{foo}")..[[ enum Direction "north" "south" @@ -330,12 +355,12 @@ for i, name in ipairs({"records", "arrayrecords"}) do local f: foo = {} local dir: foo.Direction = "north" - foo.d = dir + f.d = dir ]])) it("can have nested generic " .. name .. " with shorthand syntax", util.check([[ - local type foo = record ]]..pick(i, "", "{foo}")..[[ - record bar ]]..pick(i, "", "{bar}")..[[ + local type foo = ]]..statement..[[ ]]..array(i, "{foo}")..[[ + ]]..statement..[[ bar ]]..array(i, "{bar}")..[[ x: T end example: bar @@ -343,13 +368,13 @@ for i, name in ipairs({"records", "arrayrecords"}) do local f: foo = {} - foo.example = { x = "hello" } + f.example = { x = "hello" } ]])) - it("can mix nested record syntax", util.check([[ - local type foo = record ]]..pick(i, "", "{foo}")..[[ - record mid ]]..pick(i, "", "{mid}")..[[ - type bar = record ]]..pick(i, "", "{bar}")..[[ + it("can mix nested "..statement.." syntax", util.check([[ + local type foo = ]]..statement..[[ ]]..array(i, "{foo}")..[[ + ]]..statement..[[ mid ]]..array(i, "{mid}")..[[ + type bar = ]]..statement..[[ ]]..array(i, "{bar}")..[[ x: T end z: bar @@ -359,47 +384,51 @@ for i, name in ipairs({"records", "arrayrecords"}) do local f: foo = {} - foo.example = { z = { x = "hello" } } + f.example = { z = { x = "hello" } } ]])) it("can have " .. name .. " in arrayrecords", util.check([[ - local record bar ]]..pick(i, "", "{bar}")..[[ + local ]]..statement..[[ bar ]]..array(i, "{bar}")..[[ end - local record foo + local ]]..statement..[[ foo { bar } end local f : foo = { { } } ]])) it("nested " .. name .. " in " .. name, util.check_type_error([[ - local record foo ]]..pick(i, "", "{foo}")..[[ - record bar ]]..pick(i, "", "{bar}")..[[ + local ]]..statement..[[ foo ]]..array(i, "{foo}")..[[ + ]]..statement..[[ bar ]]..array(i, "{bar}")..[[ end end local f : foo = { { } } ]], { - select(i, + ({ { msg = "in local declaration: f: got {{}} (inferred at foo.tl:5:26), expected foo" }, -- records - nil -- arrayrecords - ) + nil, -- arrayrecords + { msg = "in local declaration: f: got {{}} (inferred at foo.tl:5:26), expected foo" }, -- interfaces + nil, -- interfaces with arrays + })[i] })) it("can have nested enums in " .. name, util.check_type_error([[ - local record foo ]]..pick(i, "", "{bar}")..[[ + local ]]..statement..[[ foo ]]..array(i, "{bar}")..[[ enum bar "baz" end end local f : foo = { "baz" } ]], { - select(i, + ({ { msg = "in local declaration: f: got {string \"baz\"} (inferred at foo.tl:6:26), expected foo" }, -- records - nil -- arrayrecords - ) + nil, -- arrayrecords + { msg = "in local declaration: f: got {string \"baz\"} (inferred at foo.tl:6:26), expected foo" }, -- interfaces + nil, -- interfaces with arrays + })[i] })) it("can extend generic functions", util.check([[ - local type foo = record ]]..pick(i, "", "{foo}")..[[ + local type foo = ]]..statement..[[ ]]..array(i, "{foo}")..[[ type bar = function(T) example: bar end @@ -409,27 +438,45 @@ for i, name in ipairs({"records", "arrayrecords"}) do end ]])) - it("does not produce an esoteric type error (#167)", util.check_type_error([[ - local type foo = record ]]..pick(i, "", "{foo}")..[[ - type bar = function(T) - example: bar - end + if statement == "record" then + it("does not produce an esoteric type error (#167)", util.check_type_error([[ + local type foo = ]]..statement..[[ ]]..array(i, "{foo}")..[[ + type bar = function(T) + example: bar + end - foo.example = function(data: string) - print(data) - end as bar - ]], { - -- this is expected, because bar is local to foo - { y = 8, x = 17, msg = "unknown type bar" }, - })) + foo.example = function(data: string) + print(data) + end as bar + ]], { + -- this is expected, because bar is local to foo + { y = 8, x = 20, msg = "unknown type bar" }, + })) + else + it("does not produce an esoteric type error (#167)", util.check_type_error([[ + local type foo = ]]..statement..[[ ]]..array(i, "{foo}")..[[ + type bar = function(T) + example: bar + end + + local f: foo = {} + f.example = function(data: string) + print(data) + end as bar + ]], { + -- this is expected, because bar is local to foo + { y = 9, x = 20, msg = "unknown type bar" }, + })) + end it("can cast generic member using full path of type name", util.check([[ - local type foo = record ]]..pick(i, "", "{foo}")..[[ + local type foo = ]]..statement..[[ ]]..array(i, "{foo}")..[[ type bar = function(T) example: bar end - foo.example = function(data: string) + local f: foo = {} + f.example = function(data: string) print(data) end as foo.bar ]])) @@ -439,12 +486,12 @@ for i, name in ipairs({"records", "arrayrecords"}) do ["req.d.tl"] = [[ local record requests - type RequestOpts = record + type RequestOpts = ]]..statement..[[ {string} url: string end - type Response = record ]]..pick(i, "", "{Response}")..[[ + type Response = ]]..statement..[[ ]]..array(i, "{Response}")..[[ status_code: number end @@ -468,8 +515,8 @@ for i, name in ipairs({"records", "arrayrecords"}) do end) it("resolves aliasing of nested " .. name .. " (see #400)", util.check([[ - local record Foo ]]..pick(i, "", "{Foo}")..[[ - record Bar ]]..pick(i, "", "{Bar}")..[[ + local ]]..statement..[[ Foo ]]..array(i, "{Foo}")..[[ + ]]..statement..[[ Bar ]]..array(i, "{Bar}")..[[ end end local function func(_f: Foo.Bar) end @@ -482,7 +529,7 @@ for i, name in ipairs({"records", "arrayrecords"}) do it("resolves nested type aliases (see #416)", util.check([[ local type A = number - local record Foo ]]..pick(i, "", "{Foo}")..[[ + local ]]..statement..[[ Foo ]]..array(i, "{Foo}")..[[ type B = A end @@ -491,7 +538,7 @@ for i, name in ipairs({"records", "arrayrecords"}) do ]])) it("resolves nested type aliases to other aliases (see #527)", util.check([[ - local record M + local ]]..statement..[[ M type Type1 = number type Type2 = Type1 end @@ -504,14 +551,14 @@ for i, name in ipairs({"records", "arrayrecords"}) do ]])) it("can use nested type aliases as types (see #416)", util.check_type_error([[ - local record F1 ]]..pick(i, "", "{F1}")..[[ - record A ]]..pick(i, "", "{A}")..[[ + local ]]..statement..[[ F1 ]]..array(i, "{F1}")..[[ + ]]..statement..[[ A ]]..array(i, "{A}")..[[ x: number end type C1 = A - record F2 ]]..pick(i, "", "{F2}")..[[ + ]]..statement..[[ F2 ]]..array(i, "{F2}")..[[ type C2 = C1 - record F3 ]]..pick(i, "", "{F3}")..[[ + ]]..statement..[[ F3 ]]..array(i, "{F3}")..[[ type C3 = C2 end end @@ -526,60 +573,94 @@ for i, name in ipairs({"records", "arrayrecords"}) do { y = 18, msg = 'got string "hello", expected number' }, })) - it("cannot use nested type aliases as values (see #416)", util.check_type_error([[ - local record F1 ]]..pick(i, "", "{F1}")..[[ - record A ]]..pick(i, "", "{C1}")..[[ - x: number + if statement == "record" then + it("cannot use nested type aliases as values (see #416)", util.check_type_error([[ + local ]]..statement..[[ F1 ]]..array(i, "{F1}")..[[ + ]]..statement..[[ A ]]..array(i, "{C1}")..[[ + x: number + end + type C1 = A + ]]..statement..[[ F2 ]]..array(i, "{C2}")..[[ + type C2 = C1 + ]]..statement..[[ F3 ]]..array(i, "{C3}")..[[ + type C3 = C2 + end + end end - type C1 = A - record F2 ]]..pick(i, "", "{C2}")..[[ - type C2 = C1 - record F3 ]]..pick(i, "", "{C3}")..[[ - type C3 = C2 + + -- Let's use nested type aliases as prototypes + + F1.C1.x = 2 + + local proto = F1.F2.F3.C3 + + proto.x = 2 + ]], { + { y = 16, msg = "cannot use a nested type alias as a concrete value" }, + { y = 20, msg = "cannot use a nested type alias as a concrete value" }, + })) + else + it("cannot use nested type aliases as values (see #416)", util.check_type_error([[ + local ]]..statement..[[ F1 ]]..array(i, "{F1}")..[[ + ]]..statement..[[ A ]]..array(i, "{C1}")..[[ + x: number + end + type C1 = A + ]]..statement..[[ F2 ]]..array(i, "{C2}")..[[ + type C2 = C1 + ]]..statement..[[ F3 ]]..array(i, "{C3}")..[[ + type C3 = C2 + end end end - end - -- Let's use nested type aliases as prototypes + -- Let's use nested type aliases as prototypes - F1.C1.x = 2 + F1.C1.x = 2 - local proto = F1.F2.F3.C3 + local proto = F1.F2.F3.C3 - proto.x = 2 - ]], { - { y = 16, msg = "cannot use a nested type alias as a concrete value" }, - { y = 20, msg = "cannot use a nested type alias as a concrete value" }, - })) + proto.x = 2 + ]], { + { y = 16, msg = "interfaces are abstract" }, + { y = 16, msg = "cannot use a nested type alias as a concrete value" }, + { y = 18, msg = "interfaces are abstract" }, + { y = 18, msg = "interfaces are abstract" }, + { y = 18, msg = "interfaces are abstract" }, + { y = 20, msg = "cannot use a nested type alias as a concrete value" }, + })) + end it("can resolve generics partially (see #417)", function() local _, ast = util.check([[ - local record fun ]]..pick(i, "", "{fun}")..[[ - record iterator ]]..pick(i, "", "{iterator}")..[[ + local ]]..statement..[[ fun ]]..array(i, "{fun}")..[[ + ]]..statement..[[ iterator ]]..array(i, "{iterator}")..[[ reduce: function(iterator, (function(R, T): R), R): R end iter: function({T}): iterator end - local sum = fun.iter({ 1, 2, 3, 4 }):reduce(function(a:integer,x:integer): integer + local f: fun + + local sum = f.iter({ 1, 2, 3, 4 }):reduce(function(a:integer,x:integer): integer return a + x end, 0) ]])() - assert.same("integer", ast[2].exps[1].type[1].typename) + assert.same("integer", ast[3].exps[1].type[1].typename) end) it("can have circular type dependencies on nested types", util.check([[ - local type R = record ]]..pick(i, "", "{S}")..[[ - type R2 = record ]]..pick(i, "", "{S.S2}")..[[ + local type R = ]]..statement..[[ ]]..array(i, "{S}")..[[ + type R2 = ]]..statement..[[ ]]..array(i, "{S.S2}")..[[ foo: S.S2 end foo: S end - local type S = record ]]..pick(i, "", "{R}")..[[ - type S2 = record ]]..pick(i, "", "{R.R2}")..[[ + local type S = ]]..statement..[[ ]]..array(i, "{R}")..[[ + type S2 = ]]..statement..[[ ]]..array(i, "{R.R2}")..[[ foo: R.R2 end @@ -592,16 +673,16 @@ for i, name in ipairs({"records", "arrayrecords"}) do ]])) it("can detect errors in type dependencies on nested types", util.check_type_error([[ - local record R ]]..pick(i, "", "{R}")..[[ - record R2 ]]..pick(i, "", "{R2}")..[[ + local ]]..statement..[[ R ]]..array(i, "{R}")..[[ + ]]..statement..[[ R2 ]]..array(i, "{R2}")..[[ foo: S.S3 end foo: S end - local record S ]]..pick(i, "", "{S}")..[[ - record S2 ]]..pick(i, "", "{S2}")..[[ + local ]]..statement..[[ S ]]..array(i, "{S}")..[[ + ]]..statement..[[ S2 ]]..array(i, "{S2}")..[[ foo: R.R2 end @@ -616,7 +697,7 @@ for i, name in ipairs({"records", "arrayrecords"}) do })) it("can contain reserved words/arbitrary strings with ['table key syntax']", util.check([[ - local record A ]]..pick(i, "", "{A}")..[=[ + local ]]..statement..[[ A ]]..array(i, "{A}")..[=[ start: number ["end"]: number [" "]: string @@ -626,7 +707,7 @@ for i, name in ipairs({"records", "arrayrecords"}) do ]=])) it("can be declared as userdata", util.check([[ - local type foo = record ]]..pick(i, "", "{foo}")..[[ + local type foo = ]]..statement..[[ ]]..array(i, "{foo}")..[[ userdata x: number y: number @@ -634,18 +715,18 @@ for i, name in ipairs({"records", "arrayrecords"}) do ]])) it("cannot be declared as userdata twice", util.check_syntax_error([[ - local type foo = record ]]..pick(i, "", "{foo}")..[[ + local type foo = ]]..statement..[[ ]]..array(i, "{foo}")..[[ userdata userdata x: number y: number end ]], { - { msg = "duplicated 'userdata' declaration in record" }, + { msg = "duplicated 'userdata' declaration" }, })) it("untyped attributes are not accepted (#381)", util.check_syntax_error([[ - local record kons ]]..pick(i, "", "{kons}")..[[ + local ]]..statement..[[ kons ]]..array(i, "{kons}")..[[ any_identifier other_sequence aaa bbb end @@ -659,7 +740,7 @@ for i, name in ipairs({"records", "arrayrecords"}) do })) it("catches redeclaration of literal keys", util.check_type_error([[ - local record Foo ]]..pick(i, "", "{Foo}")..[[ + local ]]..statement..[[ Foo ]]..array(i, "{Foo}")..[[ foo: string bar: boolean end @@ -673,7 +754,7 @@ for i, name in ipairs({"records", "arrayrecords"}) do })) it("catches redeclaration of literal keys, bracket syntax", util.check_type_error([[ - local record Foo ]]..pick(i, "", "{Foo}")..[[ + local ]]..statement..[[ Foo ]]..array(i, "{Foo}")..[[ foo: string bar: boolean end @@ -686,30 +767,32 @@ for i, name in ipairs({"records", "arrayrecords"}) do { y = 8, msg = "redeclared key foo" } })) - it("can use itself in a constructor (regression test for #422)", util.check([[ - local record Foo ]]..pick(i, "", "{number}")..[[ - end + if statement ~= "interface" then + it("can use itself in a constructor (regression test for #422)", util.check([[ + local ]]..statement..[[ Foo ]]..array(i, "{number}")..[[ + end - function Foo:new(): Foo - return setmetatable({} as Foo, self as metatable) - end + function Foo:new(): Foo + return setmetatable({} as Foo, self as metatable) + end - local foo = Foo:new() - ]])) + local foo = Foo:new() + ]])) - it("can use itself in a constructor with dot notation (regression test for #422)", util.check([[ - local record Foo ]]..pick(i, "", "{number}")..[[ - end + it("can use itself in a constructor with dot notation (regression test for #422)", util.check([[ + local ]]..statement..[[ Foo ]]..array(i, "{number}")..[[ + end - function Foo.new(): Foo - return setmetatable({}, Foo as metatable) - end + function Foo.new(): Foo + return setmetatable({}, Foo as metatable) + end - local foo = Foo.new() - ]])) + local foo = Foo.new() + ]])) + end it("creation of userdata records should be disallowed (#460)", util.check_type_error([[ - local record Foo ]]..pick(i, "", "{number}")..[[ + local ]]..statement..[[ Foo ]]..array(i, "{number}")..[[ userdata a: number end @@ -723,14 +806,24 @@ for i, name in ipairs({"records", "arrayrecords"}) do f(bar) ]], { { y = 5, msg = "in local declaration: foo: got {}, expected Foo" }, - { y = 6, msg = "in assignment: userdata record doesn't match: Foo" }, + select(i, + { y = 6, msg = "in assignment: userdata "..statement.." doesn't match: Foo" }, + { y = 6, msg = "in assignment: userdata "..statement.." doesn't match: Foo" }, + { y = 6, msg = "in assignment: got record (a: integer), expected Foo" }, + { y = 6, msg = "in assignment: got record (a: integer), expected Foo" } + ), { y = 8, msg = "argument 1: got {}, expected Foo" }, - { y = 9, msg = "argument 1: userdata record doesn't match: Foo" }, + select(i, + { y = 9, msg = "argument 1: userdata "..statement.." doesn't match: Foo" }, + { y = 9, msg = "argument 1: userdata "..statement.." doesn't match: Foo" }, + { y = 9, msg = "argument 1: got record (a: integer), expected Foo" }, + { y = 9, msg = "argument 1: got record (a: integer), expected Foo" } + ), nil })) it("reports error on unknown interfaces", util.check_type_error([[ - local record Foo ]]..pick(i, "is Bongo, Bingo", "is {number}, Bongo, Bingo")..[[ + local ]]..statement..[[ Foo ]]..array(i, "is {number}, Bongo, Bingo", "is Bongo, Bingo")..[[ userdata a: number end diff --git a/tl.lua b/tl.lua index b1f81730b..f7c10f2e8 100644 --- a/tl.lua +++ b/tl.lua @@ -1596,10 +1596,12 @@ end local function parse_table_value(ps, i) local next_word = ps.tokens[i].tk - if next_word == "record" then + if next_word == "record" or next_word == "interface" then local skip_i, e = skip(ps, i, skip_type_body) if e then - fail(ps, i, "syntax error: this syntax is no longer valid; declare nested record inside a record") + fail(ps, i, next_word == "record" and + "syntax error: this syntax is no longer valid; declare nested record inside a record" or + "syntax error: cannot declare interface inside a table; use a statement") return skip_i, new_node(ps.tokens, i, "error_node") end elseif next_word == "enum" and ps.tokens[i + 1].kind == "string" then @@ -2958,7 +2960,7 @@ parse_record_body = function(ps, i, def, node, name) local tn = ps.tokens[i].tk if ps.tokens[i].tk == "userdata" and ps.tokens[i + 1].tk ~= ":" then if def.is_userdata then - fail(ps, i, "duplicated 'userdata' declaration in record") + fail(ps, i, "duplicated 'userdata' declaration") else def.is_userdata = true end @@ -6057,7 +6059,6 @@ tl.type_check = function(ast, opts) local function find_var(name, use) - for i = #st, 1, -1 do local scope = st[i] local var = scope[name] @@ -7298,16 +7299,19 @@ tl.type_check = function(ast, opts) return false, errs end - local known_table_types = { - array = true, - map = true, - record = true, - tupletable = true, - } + do + local known_table_types = { + array = true, + map = true, + record = true, + tupletable = true, + interface = true, + } - is_lua_table_type = function(t) - return known_table_types[t.typename] and not t.is_userdata + is_lua_table_type = function(t) + return known_table_types[t.typename] and not t.is_userdata + end end local expand_type @@ -8219,7 +8223,11 @@ tl.type_check = function(ast, opts) end if rec.kind == "variable" then - return nil, "invalid key '" .. key .. "' in record '" .. rec.tk .. "' of type %s" + if tbl.typename == "interface" then + return nil, "invalid key '" .. key .. "' in '" .. rec.tk .. "' of interface type %s" + else + return nil, "invalid key '" .. key .. "' in record '" .. rec.tk .. "' of type %s" + end else return nil, "invalid key '" .. key .. "' in type %s" end @@ -10120,7 +10128,11 @@ tl.type_check = function(ast, opts) if not df then node_error(node[i], in_context(node.expected_context, "unknown field " .. ck)) else - assert_is_a(node[i], cvtype, df, "in record field", ck) + if is_typetype(df) then + node_error(node[i], in_context(node.expected_context, "cannot reassign a type")) + else + assert_is_a(node[i], cvtype, df, "in record field", ck) + end end elseif is_tupletable and is_number_type(child.ktype) then local dt = decltype.types[n] @@ -10513,11 +10525,19 @@ tl.type_check = function(ast, opts) return node.type end - if is_typetype(ra) and ra.def.typename == "record" then - ra = ra.def + if is_typetype(ra) then + if ra.def.typename == "record" then + ra = ra.def + elseif ra.def.typename == "interface" then + node_error(node, "interfaces are abstract; consider using a concrete record") + end end if rb and is_typetype(rb) and rb.def.typename == "record" then - rb = rb.def + if rb.def.typename == "record" then + rb = rb.def + elseif rb.def.typename == "interface" then + node_error(node, "interfaces are abstract; consider using a concrete record") + end end if node.op.op == "@funcall" then @@ -11064,6 +11084,7 @@ tl.type_check = function(ast, opts) visit_type.cbs["nestedtype"] = visit_type.cbs["string"] visit_type.cbs["array"] = visit_type.cbs["string"] visit_type.cbs["map"] = visit_type.cbs["string"] + visit_type.cbs["interface"] = visit_type.cbs["record"] visit_type.cbs["enum"] = visit_type.cbs["string"] visit_type.cbs["boolean"] = visit_type.cbs["string"] visit_type.cbs["nil"] = visit_type.cbs["string"] diff --git a/tl.tl b/tl.tl index eeaf749d3..b082ad544 100644 --- a/tl.tl +++ b/tl.tl @@ -1596,10 +1596,12 @@ end local function parse_table_value(ps: ParseState, i: integer): integer, Node, integer local next_word = ps.tokens[i].tk - if next_word == "record" then + if next_word == "record" or next_word == "interface" then local skip_i, e = skip(ps, i, skip_type_body) if e then - fail(ps, i, "syntax error: this syntax is no longer valid; declare nested record inside a record") + fail(ps, i, next_word == "record" + and "syntax error: this syntax is no longer valid; declare nested record inside a record" + or "syntax error: cannot declare interface inside a table; use a statement") return skip_i, new_node(ps.tokens, i, "error_node") end elseif next_word == "enum" and ps.tokens[i + 1].kind == "string" then @@ -2958,7 +2960,7 @@ parse_record_body = function(ps: ParseState, i: integer, def: Type, node: Node, local tn = ps.tokens[i].tk as TypeName if ps.tokens[i].tk == "userdata" and ps.tokens[i+1].tk ~= ":" then if def.is_userdata then - fail(ps, i, "duplicated 'userdata' declaration in record") + fail(ps, i, "duplicated 'userdata' declaration") else def.is_userdata = true end @@ -6057,7 +6059,6 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end local function find_var(name: string, use: VarUse): Variable, integer, Attribute - for i = #st, 1, -1 do local scope = st[i] local var = scope[name] @@ -7298,16 +7299,19 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return false, errs end - local known_table_types: {TypeName:boolean} = { - array = true, - map = true, - record = true, - tupletable = true, - } + do + local known_table_types: {TypeName:boolean} = { + array = true, + map = true, + record = true, + tupletable = true, + interface = true, + } - -- Is the type represented concretely as a Lua table? - is_lua_table_type = function(t: Type): boolean - return known_table_types[t.typename] and not t.is_userdata + -- Is the type represented concretely as a Lua table? + is_lua_table_type = function(t: Type): boolean + return known_table_types[t.typename] and not t.is_userdata + end end local expand_type: function(where: Where, old: Type, new: Type): Type @@ -8219,7 +8223,11 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end if rec.kind == "variable" then - return nil, "invalid key '" .. key .. "' in record '" .. rec.tk .. "' of type %s" + if tbl.typename == "interface" then + return nil, "invalid key '" .. key .. "' in '" .. rec.tk .. "' of interface type %s" + else + return nil, "invalid key '" .. key .. "' in record '" .. rec.tk .. "' of type %s" + end else return nil, "invalid key '" .. key .. "' in type %s" end @@ -10120,7 +10128,11 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if not df then node_error(node[i], in_context(node.expected_context, "unknown field " .. ck)) else - assert_is_a(node[i], cvtype, df, "in record field", ck) + if is_typetype(df) then + node_error(node[i], in_context(node.expected_context, "cannot reassign a type")) + else + assert_is_a(node[i], cvtype, df, "in record field", ck) + end end elseif is_tupletable and is_number_type(child.ktype) then local dt = decltype.types[n as integer] @@ -10513,11 +10525,19 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return node.type end - if is_typetype(ra) and ra.def.typename == "record" then - ra = ra.def + if is_typetype(ra) then + if ra.def.typename == "record" then + ra = ra.def + elseif ra.def.typename == "interface" then + node_error(node, "interfaces are abstract; consider using a concrete record") + end end if rb and is_typetype(rb) and rb.def.typename == "record" then - rb = rb.def + if rb.def.typename == "record" then + rb = rb.def + elseif rb.def.typename == "interface" then + node_error(node, "interfaces are abstract; consider using a concrete record") + end end if node.op.op == "@funcall" then @@ -11064,6 +11084,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string visit_type.cbs["nestedtype"] = visit_type.cbs["string"] visit_type.cbs["array"] = visit_type.cbs["string"] visit_type.cbs["map"] = visit_type.cbs["string"] + visit_type.cbs["interface"] = visit_type.cbs["record"] visit_type.cbs["enum"] = visit_type.cbs["string"] visit_type.cbs["boolean"] = visit_type.cbs["string"] visit_type.cbs["nil"] = visit_type.cbs["string"] From 9a5aeb405c0257d13393eee6d3c567d057ad1552 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Tue, 26 Jan 2021 17:43:51 -0300 Subject: [PATCH 022/224] optional arity: add syntax for optional arity in function arguments --- docs/grammar.md | 4 ++-- tl.lua | 40 ++++++++++++++++++++++++++++++++++++---- tl.tl | 42 +++++++++++++++++++++++++++++++++++++----- 3 files changed, 75 insertions(+), 11 deletions(-) diff --git a/docs/grammar.md b/docs/grammar.md index 3e1a9c9cf..49b0682e9 100644 --- a/docs/grammar.md +++ b/docs/grammar.md @@ -127,11 +127,11 @@ precedence, see below. * partypelist ::= partype {‘,’ partype} -* partype ::= [Name ‘:’] type +* partype ::= Name [‘?’] ‘:’ type | [‘?’] type * parnamelist ::= parname {‘,’ parname} -* parname ::= Name [‘:’ type] +* parname ::= Name [‘?’] [‘:’ type] ``` ## Operator precedence diff --git a/tl.lua b/tl.lua index f7c10f2e8..0cba1b7bc 100644 --- a/tl.lua +++ b/tl.lua @@ -443,7 +443,7 @@ do end local lex_any_char_kinds = {} - local single_char_kinds = { "[", "]", "(", ")", "{", "}", ",", "#", ";" } + local single_char_kinds = { "[", "]", "(", ")", "{", "}", ",", "#", ";", "?" } for _, c in ipairs(single_char_kinds) do lex_any_char_kinds[c] = c end @@ -1239,6 +1239,10 @@ local table_types = { + + + + @@ -1433,6 +1437,7 @@ local Node = {ExpectedContext = {}, } + local function is_array_type(t) @@ -2436,6 +2441,10 @@ local function parse_argument(ps, i) if ps.tokens[i].tk == "..." then fail(ps, i, "'...' needs to be declared as a typed argument") end + if ps.tokens[i].tk == "?" then + i = i + 1 + node.opt = true + end if ps.tokens[i].tk == ":" then i = i + 1 local decltype @@ -2452,10 +2461,16 @@ end parse_argument_list = function(ps, i) local node = new_node(ps.tokens, i, "argument_list") i, node = parse_bracket_list(ps, i, node, "(", ")", "sep", parse_argument) + local opts = false for a, fnarg in ipairs(node) do if fnarg.tk == "..." and a ~= #node then fail(ps, i, "'...' can only be last argument") end + if fnarg.opt then + opts = true + elseif opts then + return fail(ps, i, "non-optional arguments cannot follow optional arguments") + end end return i, node end @@ -2469,9 +2484,21 @@ end local function parse_argument_type(ps, i) local is_va = false local argument_name = nil - if ps.tokens[i].kind == "identifier" and ps.tokens[i + 1].tk == ":" then + + local opt = false + if ps.tokens[i].kind == "identifier" then argument_name = ps.tokens[i].tk - i = i + 2 + if ps.tokens[i + 1].tk == "?" then + opt = true + if ps.tokens[i + 2].tk == ":" then + i = i + 3 + end + elseif ps.tokens[i + 1].tk == ":" then + i = i + 2 + end + elseif ps.tokens[i].kind == "?" then + opt = true + i = i + 1 elseif ps.tokens[i].tk == "..." then if ps.tokens[i + 1].tk == ":" then i = i + 2 @@ -2483,6 +2510,7 @@ local function parse_argument_type(ps, i) local typ; i, typ = parse_type(ps, i) if typ then + typ.opt = opt if not is_va and ps.tokens[i].tk == "..." then i = i + 1 is_va = true @@ -5049,7 +5077,9 @@ local function show_type_base(t, short, seen) end for i, v in ipairs(t.args) do if not t.is_method or i > 1 then - table.insert(args, (i == #t.args and t.args.is_va and "...: " or "") .. show(v)) + table.insert(args, ((i == #t.args and t.args.is_va) and "...: " or + v.opt and "? " or + "") .. show(v)) end end table.insert(out, table.concat(args, ", ")) @@ -6417,6 +6447,7 @@ tl.type_check = function(ast, opts) end copy.is_method = t.is_method + copy.min_arity = t.min_arity copy.args, same = resolve(t.args, same) copy.rets, same = resolve(t.rets, same) elseif is_record_type(t) then @@ -10839,6 +10870,7 @@ tl.type_check = function(ast, opts) if node.tk == "..." then t = a_type({ typename = "tuple", is_va = true, t }) end + t.opt = node.opt add_var(node, node.tk, t).is_func_arg = true return node.type end, diff --git a/tl.tl b/tl.tl index b082ad544..8d4060090 100644 --- a/tl.tl +++ b/tl.tl @@ -273,7 +273,7 @@ local enum TokenKind "keyword" "op" "string" - "[" "]" "(" ")" "{" "}" "," ":" "#" "." ";" + "[" "]" "(" ")" "{" "}" "," ":" "#" "." ";" "?" "::" "..." "identifier" @@ -443,7 +443,7 @@ do end local lex_any_char_kinds: {string:TokenKind} = {} - local single_char_kinds: {TokenKind} = {"[", "]", "(", ")", "{", "}", ",", "#", ";"} + local single_char_kinds: {TokenKind} = {"[", "]", "(", ")", "{", "}", ",", "#", ";", "?"} for _, c in ipairs(single_char_kinds) do lex_any_char_kinds[c] = c end @@ -1080,6 +1080,9 @@ local record Type -- Lua compatibilty needs_compat: boolean + -- arguments: optional arity + opt: boolean + -- tuple is_va: boolean @@ -1113,6 +1116,7 @@ local record Type -- function is_method: boolean + min_arity: number args: Type rets: Type @@ -1429,6 +1433,7 @@ local record Node type: Type decltype: Type + opt: boolean end local type Where @@ -2436,6 +2441,10 @@ local function parse_argument(ps: ParseState, i: integer): integer, Node, intege if ps.tokens[i].tk == "..." then fail(ps, i, "'...' needs to be declared as a typed argument") end + if ps.tokens[i].tk == "?" then + i = i + 1 + node.opt = true + end if ps.tokens[i].tk == ":" then i = i + 1 local decltype: Type @@ -2452,10 +2461,16 @@ end parse_argument_list = function(ps: ParseState, i: integer): integer, Node local node = new_node(ps.tokens, i, "argument_list") i, node = parse_bracket_list(ps, i, node, "(", ")", "sep", parse_argument) + local opts = false for a, fnarg in ipairs(node) do if fnarg.tk == "..." and a ~= #node then fail(ps, i, "'...' can only be last argument") end + if fnarg.opt then + opts = true + elseif opts then + return fail(ps, i, "non-optional arguments cannot follow optional arguments") + end end return i, node end @@ -2469,9 +2484,21 @@ end local function parse_argument_type(ps: ParseState, i: integer): integer, TypeAndVararg, integer local is_va = false local argument_name: string = nil - if ps.tokens[i].kind == "identifier" and ps.tokens[i + 1].tk == ":" then + + local opt = false + if ps.tokens[i].kind == "identifier" then argument_name = ps.tokens[i].tk - i = i + 2 + if ps.tokens[i + 1].tk == "?" then + opt = true + if ps.tokens[i + 2].tk == ":" then + i = i + 3 + end + elseif ps.tokens[i + 1].tk == ":" then + i = i + 2 + end + elseif ps.tokens[i].kind == "?" then + opt = true + i = i + 1 elseif ps.tokens[i].tk == "..." then if ps.tokens[i + 1].tk == ":" then i = i + 2 @@ -2483,6 +2510,7 @@ local function parse_argument_type(ps: ParseState, i: integer): integer, TypeAnd local typ: Type; i, typ = parse_type(ps, i) if typ then + typ.opt = opt if not is_va and ps.tokens[i].tk == "..." then i = i + 1 is_va = true @@ -5049,7 +5077,9 @@ local function show_type_base(t: Type, short: boolean, seen: {Type:string}): str end for i, v in ipairs(t.args) do if not t.is_method or i > 1 then - table.insert(args, (i == #t.args and t.args.is_va and "...: " or "") .. show(v)) + table.insert(args, ((i == #t.args and t.args.is_va) and "...: " + or v.opt and "? " + or "") .. show(v)) end end table.insert(out, table.concat(args, ", ")) @@ -6417,6 +6447,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end copy.is_method = t.is_method + copy.min_arity = t.min_arity copy.args, same = resolve(t.args, same) copy.rets, same = resolve(t.rets, same) elseif is_record_type(t) then @@ -10839,6 +10870,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if node.tk == "..." then t = a_type { typename = "tuple", is_va = true, t } end + t.opt = node.opt add_var(node, node.tk, t).is_func_arg = true return node.type end, From f034cf18c2ec8509227377213c3f8daa9c17e95f Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Tue, 26 Jan 2021 17:55:10 -0300 Subject: [PATCH 023/224] optional arity: calculate minimum arity for a function --- tl.lua | 18 ++++++++++++++++++ tl.tl | 18 ++++++++++++++++++ 2 files changed, 36 insertions(+) diff --git a/tl.lua b/tl.lua index 0cba1b7bc..e7e14965e 100644 --- a/tl.lua +++ b/tl.lua @@ -7935,6 +7935,23 @@ tl.type_check = function(ast, opts) local type_check_function_call do + local function set_min_arity(t) + if not t.args then + return + end + local min_arity = 0 + for i, fnarg in ipairs(t.args) do + if not fnarg.opt then + min_arity = i + end + end + if t.args.is_va then + min_arity = min_arity - 1 + end + t.min_arity = min_arity + return min_arity + end + local function mark_invalid_typeargs(f) if f.typeargs then for _, a in ipairs(f.typeargs) do @@ -8124,6 +8141,7 @@ tl.type_check = function(ast, opts) end end local expected = #f.args + local min_arity = f.min_arity or set_min_arity(f) if (is_func and (given <= expected or (f.args.is_va and given > expected))) or diff --git a/tl.tl b/tl.tl index 8d4060090..54ee043b4 100644 --- a/tl.tl +++ b/tl.tl @@ -7935,6 +7935,23 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local type_check_function_call: function(Node, {Node}, Type, {Type}, Node, boolean, integer): Type do + local function set_min_arity(t: Type): integer + if not t.args then + return + end + local min_arity = 0 + for i, fnarg in ipairs(t.args) do + if not fnarg.opt then + min_arity = i + end + end + if t.args.is_va then + min_arity = min_arity - 1 + end + t.min_arity = min_arity + return min_arity + end + local function mark_invalid_typeargs(f: Type) if f.typeargs then for _, a in ipairs(f.typeargs) do @@ -8124,6 +8141,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end local expected = #f.args + local min_arity = f.min_arity or set_min_arity(f) -- simple functions: if (is_func and (given <= expected or (f.args.is_va and given > expected))) From d62c0e1522361f169100fbb7cc2b1c3dc8a45a35 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Tue, 26 Jan 2021 18:04:28 -0300 Subject: [PATCH 024/224] optional arity: prepare error messages for optional arity --- tl.lua | 27 ++++++++++++++++++--------- tl.tl | 27 ++++++++++++++++++--------- 2 files changed, 36 insertions(+), 18 deletions(-) diff --git a/tl.lua b/tl.lua index e7e14965e..1fe6c9cb4 100644 --- a/tl.lua +++ b/tl.lua @@ -7935,21 +7935,24 @@ tl.type_check = function(ast, opts) local type_check_function_call do - local function set_min_arity(t) - if not t.args then + local function set_min_arity(f) + if f.min_arity then + return + end + if not f.args then + f.min_arity = 0 return end local min_arity = 0 - for i, fnarg in ipairs(t.args) do + for i, fnarg in ipairs(f.args) do if not fnarg.opt then min_arity = i end end - if t.args.is_va then + if f.args.is_va then min_arity = min_arity - 1 end - t.min_arity = min_arity - return min_arity + f.min_arity = min_arity end local function mark_invalid_typeargs(f) @@ -8065,6 +8068,12 @@ tl.type_check = function(ast, opts) end end + local function show_arity(f) + return f.min_arity < #f.args and + "at least " .. f.min_arity or + tostring(#f.args or 0) + end + local function fail_call(node, func, nargs, errs) if errs then @@ -8076,7 +8085,7 @@ tl.type_check = function(ast, opts) local expects = {} if func.typename == "poly" then for _, f in ipairs(func.types) do - table.insert(expects, tostring(#f.args or 0)) + table.insert(expects, show_arity(f)) end table.sort(expects) for i = #expects, 1, -1 do @@ -8085,7 +8094,7 @@ tl.type_check = function(ast, opts) end end else - table.insert(expects, tostring(#func.args or 0)) + table.insert(expects, show_arity(func)) end node_error(node, "wrong number of arguments (given " .. nargs .. ", expects " .. table.concat(expects, " or ") .. ")") end @@ -8141,7 +8150,7 @@ tl.type_check = function(ast, opts) end end local expected = #f.args - local min_arity = f.min_arity or set_min_arity(f) + set_min_arity(f) if (is_func and (given <= expected or (f.args.is_va and given > expected))) or diff --git a/tl.tl b/tl.tl index 54ee043b4..e180b5995 100644 --- a/tl.tl +++ b/tl.tl @@ -7935,21 +7935,24 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local type_check_function_call: function(Node, {Node}, Type, {Type}, Node, boolean, integer): Type do - local function set_min_arity(t: Type): integer - if not t.args then + local function set_min_arity(f: Type) + if f.min_arity then + return + end + if not f.args then + f.min_arity = 0 return end local min_arity = 0 - for i, fnarg in ipairs(t.args) do + for i, fnarg in ipairs(f.args) do if not fnarg.opt then min_arity = i end end - if t.args.is_va then + if f.args.is_va then min_arity = min_arity - 1 end - t.min_arity = min_arity - return min_arity + f.min_arity = min_arity end local function mark_invalid_typeargs(f: Type) @@ -8065,6 +8068,12 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end + local function show_arity(f: Type): string + return f.min_arity < #f.args + and "at least " .. f.min_arity + or tostring(#f.args or 0) + end + local function fail_call(node: Node, func: Type, nargs: integer, errs: {Error}): Type if errs then -- report the errors from the first match @@ -8076,7 +8085,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local expects: {string} = {} if func.typename == "poly" then for _, f in ipairs(func.types) do - table.insert(expects, tostring(#f.args or 0)) + table.insert(expects, show_arity(f)) end table.sort(expects) for i = #expects, 1, -1 do @@ -8085,7 +8094,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end else - table.insert(expects, tostring(#func.args or 0)) + table.insert(expects, show_arity(func)) end node_error(node, "wrong number of arguments (given " .. nargs .. ", expects " .. table.concat(expects, " or ") .. ")") end @@ -8141,7 +8150,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end local expected = #f.args - local min_arity = f.min_arity or set_min_arity(f) + set_min_arity(f) -- simple functions: if (is_func and (given <= expected or (f.args.is_va and given > expected))) From 997d28917b523be1cd09482459bed082660d1c54 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Fri, 29 Jan 2021 02:29:38 -0300 Subject: [PATCH 025/224] optional arity: adapt the compiler to specify optional arities --- spec/call/record_method_spec.lua | 12 +- spec/operator/is_spec.lua | 4 +- spec/subtyping/integer_spec.lua | 2 +- spec/subtyping/number_spec.lua | 2 +- tl.lua | 203 +++++++++++---------- tl.tl | 297 +++++++++++++++---------------- 6 files changed, 259 insertions(+), 261 deletions(-) diff --git a/spec/call/record_method_spec.lua b/spec/call/record_method_spec.lua index dcfbb9b47..0e87405b7 100644 --- a/spec/call/record_method_spec.lua +++ b/spec/call/record_method_spec.lua @@ -213,7 +213,7 @@ describe("record method call", function() end local record Foo x: integer - add: function(self: Bar, other: Bar) + add: function(self: Bar, other?: Bar) end local first: Foo = {} local second: Bar = {} @@ -223,7 +223,7 @@ describe("record method call", function() it("for function declared in method body with self as different generic type from receiver", util.check_warnings([[ local record Foo x: T - add: function(self: Foo, other: Foo) + add: function(self: Foo, other?: Foo) end local first: Foo = {} local second: Foo = {} @@ -234,7 +234,7 @@ describe("record method call", function() local record Foo x: integer end - function Foo:add(other: Foo) + function Foo:add(other?: Foo) self.x = other and (self.x + other.x) or self.x end local first: Foo = {} @@ -257,7 +257,7 @@ describe("record method call", function() local record Foo x: integer end - function Foo:add(other: integer) + function Foo:add(other?: integer) self.x = other and (self.x + other) or self.x end local first: Foo = {} @@ -282,7 +282,7 @@ describe("record method call", function() end local record Foo x: integer - add: function(self: Bar, other: Bar) + add: function(self: Bar, other?: Bar) end local first: Foo = {} first.add(first) @@ -294,7 +294,7 @@ describe("record method call", function() it("for function declared in record body with self as different generic type from receiver", util.check_type_error([[ local record Foo x: T - add: function(self: Foo, other: Foo) + add: function(self: Foo, other?: Foo) end local first: Foo = {} first.add(first) diff --git a/spec/operator/is_spec.lua b/spec/operator/is_spec.lua index 00d634513..433eff152 100644 --- a/spec/operator/is_spec.lua +++ b/spec/operator/is_spec.lua @@ -254,7 +254,7 @@ describe("flow analysis with is", function() local function f(d: any): any if d is string and func(d, "a") then - local d = func(d) + local d = func(d, d) elseif d is string and func(d, "b") then return d .. "???" else @@ -271,7 +271,7 @@ describe("flow analysis with is", function() local function f(d: any): any local d = d as (number | string | function) if d is string and func(d, "a") then - local d = func(d) + local d = func(d, d) elseif d is string and func(d, "b") then return d .. "???" else diff --git a/spec/subtyping/integer_spec.lua b/spec/subtyping/integer_spec.lua index 5e7b13108..5ce5226eb 100644 --- a/spec/subtyping/integer_spec.lua +++ b/spec/subtyping/integer_spec.lua @@ -47,7 +47,7 @@ describe("subtyping of integer:", function() })) it("integer <╱: thread", util.check_type_error([[ - local c = coroutine.create() + local c = coroutine.create(function() end) c = 42 ]], { { msg = "got integer, expected thread" } diff --git a/spec/subtyping/number_spec.lua b/spec/subtyping/number_spec.lua index 8cde26a80..4093a6508 100644 --- a/spec/subtyping/number_spec.lua +++ b/spec/subtyping/number_spec.lua @@ -49,7 +49,7 @@ describe("subtyping of number:", function() })) it("number <╱: thread", util.check_type_error([[ - local c = coroutine.create() + local c = coroutine.create(function() end) c = 1.5 ]], { { msg = "got number, expected thread" } diff --git a/tl.lua b/tl.lua index 1fe6c9cb4..40b1afc51 100644 --- a/tl.lua +++ b/tl.lua @@ -1853,6 +1853,19 @@ local simple_types = { ["integer"] = INTEGER, } +local memoize_opt_types = {} + +local function OPT(t) + if memoize_opt_types[t] then + return memoize_opt_types[t] + end + + local ot = shallow_copy_type(t) + ot.opt = true + memoize_opt_types[t] = ot + return ot +end + local function parse_simple_type_or_nominal(ps, i) local tk = ps.tokens[i].tk local st = simple_types[tk] @@ -2138,7 +2151,6 @@ do } local function new_operator(tk, arity, op) - op = op or tk.tk return { y = tk.y, x = tk.x, arity = arity, op = op, prec = precedences[arity][op] } end @@ -2178,8 +2190,8 @@ do end local e1 local t1 = ps.tokens[i] - if precedences[1][ps.tokens[i].tk] ~= nil then - local op = new_operator(ps.tokens[i], 1) + if precedences[1][t1.tk] ~= nil then + local op = new_operator(t1, 1, t1.tk) i = i + 1 local prev_i = i i, e1 = P(ps, i) @@ -2211,7 +2223,7 @@ do break end if tkop.tk == "." or tkop.tk == ":" then - local op = new_operator(tkop, 2) + local op = new_operator(tkop, 2, tkop.tk) local prev_i = i @@ -2337,7 +2349,7 @@ do local lookahead = ps.tokens[i].tk while precedences[2][lookahead] and precedences[2][lookahead] >= min_precedence do local t1 = ps.tokens[i] - local op = new_operator(t1, 2) + local op = new_operator(t1, 2, t1.tk) i = i + 1 local rhs i, rhs = P(ps, i) @@ -2435,6 +2447,7 @@ local function parse_argument(ps, i) local node if ps.tokens[i].tk == "..." then i, node = verify_kind(ps, i, "...", "argument") + node.opt = true else i, node = verify_kind(ps, i, "identifier", "argument") end @@ -2463,10 +2476,12 @@ parse_argument_list = function(ps, i) i, node = parse_bracket_list(ps, i, node, "(", ")", "sep", parse_argument) local opts = false for a, fnarg in ipairs(node) do - if fnarg.tk == "..." and a ~= #node then - fail(ps, i, "'...' can only be last argument") - end - if fnarg.opt then + if fnarg.tk == "..." then + if a ~= #node then + fail(ps, i, "'...' can only be last argument") + break + end + elseif fnarg.opt then opts = true elseif opts then return fail(ps, i, "non-optional arguments cannot follow optional arguments") @@ -2510,15 +2525,19 @@ local function parse_argument_type(ps, i) local typ; i, typ = parse_type(ps, i) if typ then - typ.opt = opt if not is_va and ps.tokens[i].tk == "..." then i = i + 1 is_va = true end - end - if argument_name == "self" then - typ.is_self = true + if opt then + typ = OPT(typ) + end + + if argument_name == "self" then + typ = shallow_copy_type(typ) + typ.is_self = true + end end return i, { i = i, type = typ, is_va = is_va }, 0 @@ -5232,12 +5251,6 @@ local function sorted_keys(m) return keys end -local function fill_field_order(t) - if t.typename == "record" then - t.field_order = sorted_keys(t.fields) - end -end - local function require_module(module_name, lax, env) local mod = env.modules[module_name] if mod then @@ -5392,6 +5405,13 @@ local function init_globals(lax) last_typeid = globals_typeid end + local function a_record(t) + t = a_type(t) + t.typename = "record" + t.field_order = sorted_keys(t.fields) + return t + end + local function a_gfunction(n, f) local typevars = {} local typeargs = {} @@ -5411,6 +5431,7 @@ local function init_globals(lax) local function a_grecord(n, f) local t = a_gfunction(n, f) t.typename = "record" + t.field_order = sorted_keys(t.fields) return t end @@ -5438,6 +5459,7 @@ local function init_globals(lax) { ctor = TUPLE, args = { an_enum({ "*n", "n" }) }, rets = { NUMBER, STRING } }, { ctor = VARARG, args = { UNION({ NUMBER, an_enum({ "*a", "a", "*l", "l", "*L", "L", "*n", "n" }) }) }, rets = { UNION({ STRING, NUMBER }) } }, { ctor = VARARG, args = { UNION({ NUMBER, STRING }) }, rets = { STRING } }, + { ctor = VARARG, args = {}, rets = { STRING } }, } local function a_file_reader(fn) @@ -5455,8 +5477,7 @@ local function init_globals(lax) local LOAD_FUNCTION = a_type({ typename = "function", args = {}, rets = TUPLE({ STRING }) }) - local OS_DATE_TABLE = a_type({ - typename = "record", + local OS_DATE_TABLE = a_record({ fields = { ["year"] = INTEGER, ["month"] = INTEGER, @@ -5470,8 +5491,7 @@ local function init_globals(lax) }, }) - local DEBUG_GETINFO_TABLE = a_type({ - typename = "record", + local DEBUG_GETINFO_TABLE = a_record({ fields = { ["name"] = STRING, ["namewhat"] = STRING, @@ -5523,11 +5543,6 @@ local function init_globals(lax) }) end - - local function OPT(x) - return x - end - local standard_library = { ["..."] = VARARG({ STRING }), ["any"] = a_type({ typename = "typetype", def = ANY }), @@ -5543,7 +5558,7 @@ local function init_globals(lax) }, }), ["dofile"] = a_type({ typename = "function", args = TUPLE({ OPT(STRING) }), rets = VARARG({ ANY }) }), - ["error"] = a_type({ typename = "function", args = TUPLE({ ANY, NUMBER }), rets = TUPLE({}) }), + ["error"] = a_type({ typename = "function", args = TUPLE({ ANY, OPT(NUMBER) }), rets = TUPLE({}) }), ["getmetatable"] = a_gfunction(1, function(a) return { args = TUPLE({ a }), rets = TUPLE({ METATABLE(a) }) } end), ["ipairs"] = a_gfunction(1, function(a) return { args = TUPLE({ ARRAY(a) }), rets = TUPLE({ a_type({ typename = "function", args = TUPLE({}), rets = TUPLE({ INTEGER, a }) }), @@ -5595,8 +5610,7 @@ local function init_globals(lax) ["type"] = a_type({ typename = "function", args = TUPLE({ ANY }), rets = TUPLE({ STRING }) }), ["FILE"] = a_type({ typename = "typetype", - def = a_type({ - typename = "record", + def = a_record({ is_userdata = true, fields = { ["close"] = a_type({ typename = "function", args = TUPLE({ NOMINAL_FILE }), rets = TUPLE({ BOOLEAN, STRING, INTEGER }) }), @@ -5660,8 +5674,7 @@ local function init_globals(lax) }, } end), }), - ["coroutine"] = a_type({ - typename = "record", + ["coroutine"] = a_record({ fields = { ["create"] = a_type({ typename = "function", args = TUPLE({ FUNCTION }), rets = TUPLE({ THREAD }) }), ["close"] = a_type({ typename = "function", args = TUPLE({ THREAD }), rets = TUPLE({ BOOLEAN, STRING }) }), @@ -5673,8 +5686,7 @@ local function init_globals(lax) ["yield"] = a_type({ typename = "function", args = VARARG({ ANY }), rets = VARARG({ ANY }) }), }, }), - ["debug"] = a_type({ - typename = "record", + ["debug"] = a_record({ fields = { ["Info"] = a_type({ typename = "typetype", @@ -5724,8 +5736,8 @@ local function init_globals(lax) ["traceback"] = a_type({ typename = "poly", types = { - a_type({ typename = "function", args = TUPLE({ THREAD, STRING, NUMBER }), rets = TUPLE({ STRING }) }), - a_type({ typename = "function", args = TUPLE({ STRING, NUMBER }), rets = TUPLE({ STRING }) }), + a_type({ typename = "function", args = TUPLE({ OPT(THREAD), OPT(STRING), OPT(NUMBER) }), rets = TUPLE({ STRING }) }), + a_type({ typename = "function", args = TUPLE({ OPT(STRING), OPT(NUMBER) }), rets = TUPLE({ STRING }) }), }, }), ["upvalueid"] = a_type({ typename = "function", args = TUPLE({ FUNCTION, NUMBER }), rets = TUPLE({ USERDATA }) }), @@ -5740,8 +5752,7 @@ local function init_globals(lax) }), }, }), - ["io"] = a_type({ - typename = "record", + ["io"] = a_record({ fields = { ["close"] = a_type({ typename = "function", args = TUPLE({ OPT(NOMINAL_FILE) }), rets = TUPLE({ BOOLEAN, STRING }) }), ["flush"] = a_type({ typename = "function", args = TUPLE({}), rets = TUPLE({}) }), @@ -5751,9 +5762,9 @@ local function init_globals(lax) a_type({ typename = "function", args = TUPLE({}), rets = ctor(rets) }), }), }) end), - ["open"] = a_type({ typename = "function", args = TUPLE({ STRING, STRING }), rets = TUPLE({ NOMINAL_FILE, STRING }) }), + ["open"] = a_type({ typename = "function", args = TUPLE({ STRING, OPT(STRING) }), rets = TUPLE({ NOMINAL_FILE, STRING }) }), ["output"] = a_type({ typename = "function", args = TUPLE({ OPT(UNION({ STRING, NOMINAL_FILE })) }), rets = TUPLE({ NOMINAL_FILE }) }), - ["popen"] = a_type({ typename = "function", args = TUPLE({ STRING, STRING }), rets = TUPLE({ NOMINAL_FILE, STRING }) }), + ["popen"] = a_type({ typename = "function", args = TUPLE({ STRING, OPT(STRING) }), rets = TUPLE({ NOMINAL_FILE, STRING }) }), ["read"] = a_file_reader(function(ctor, args, rets) return a_type({ typename = "function", args = ctor(args), rets = ctor(rets) }) end), @@ -5765,8 +5776,7 @@ local function init_globals(lax) ["write"] = a_type({ typename = "function", args = VARARG({ UNION({ STRING, NUMBER }) }), rets = TUPLE({ NOMINAL_FILE, STRING }) }), }, }), - ["math"] = a_type({ - typename = "record", + ["math"] = a_record({ fields = { ["abs"] = a_type({ typename = "poly", @@ -5795,7 +5805,7 @@ local function init_globals(lax) ["frexp"] = a_type({ typename = "function", args = TUPLE({ NUMBER }), rets = TUPLE({ NUMBER, NUMBER }) }), ["huge"] = NUMBER, ["ldexp"] = a_type({ typename = "function", args = TUPLE({ NUMBER, NUMBER }), rets = TUPLE({ NUMBER }) }), - ["log"] = a_type({ typename = "function", args = TUPLE({ NUMBER, NUMBER }), rets = TUPLE({ NUMBER }) }), + ["log"] = a_type({ typename = "function", args = TUPLE({ NUMBER, OPT(NUMBER) }), rets = TUPLE({ NUMBER }) }), ["log10"] = a_type({ typename = "function", args = TUPLE({ NUMBER }), rets = TUPLE({ NUMBER }) }), ["max"] = a_type({ typename = "poly", @@ -5824,7 +5834,7 @@ local function init_globals(lax) ["random"] = a_type({ typename = "poly", types = { - a_type({ typename = "function", args = TUPLE({ NUMBER, NUMBER }), rets = TUPLE({ INTEGER }) }), + a_type({ typename = "function", args = TUPLE({ NUMBER, OPT(NUMBER) }), rets = TUPLE({ INTEGER }) }), a_type({ typename = "function", args = TUPLE({}), rets = TUPLE({ NUMBER }) }), }, }), @@ -5839,21 +5849,20 @@ local function init_globals(lax) ["ult"] = a_type({ typename = "function", args = TUPLE({ NUMBER, NUMBER }), rets = TUPLE({ BOOLEAN }) }), }, }), - ["os"] = a_type({ - typename = "record", + ["os"] = a_record({ fields = { ["clock"] = a_type({ typename = "function", args = TUPLE({}), rets = TUPLE({ NUMBER }) }), ["date"] = a_type({ typename = "poly", types = { a_type({ typename = "function", args = TUPLE({}), rets = TUPLE({ STRING }) }), - a_type({ typename = "function", args = TUPLE({ an_enum({ "!*t", "*t" }), NUMBER }), rets = TUPLE({ OS_DATE_TABLE }) }), + a_type({ typename = "function", args = TUPLE({ an_enum({ "!*t", "*t" }), OPT(NUMBER) }), rets = TUPLE({ OS_DATE_TABLE }) }), a_type({ typename = "function", args = TUPLE({ OPT(STRING), OPT(NUMBER) }), rets = TUPLE({ STRING }) }), }, }), ["difftime"] = a_type({ typename = "function", args = TUPLE({ NUMBER, NUMBER }), rets = TUPLE({ NUMBER }) }), ["execute"] = a_type({ typename = "function", args = TUPLE({ STRING }), rets = TUPLE({ BOOLEAN, STRING, INTEGER }) }), - ["exit"] = a_type({ typename = "function", args = TUPLE({ UNION({ NUMBER, BOOLEAN }), BOOLEAN }), rets = TUPLE({}) }), + ["exit"] = a_type({ typename = "function", args = TUPLE({ OPT(UNION({ NUMBER, BOOLEAN })), OPT(BOOLEAN) }), rets = TUPLE({}) }), ["getenv"] = a_type({ typename = "function", args = TUPLE({ STRING }), rets = TUPLE({ STRING }) }), ["remove"] = a_type({ typename = "function", args = TUPLE({ STRING }), rets = TUPLE({ BOOLEAN, STRING }) }), ["rename"] = a_type({ typename = "function", args = TUPLE({ STRING, STRING }), rets = TUPLE({ BOOLEAN, STRING }) }), @@ -5862,8 +5871,7 @@ local function init_globals(lax) ["tmpname"] = a_type({ typename = "function", args = TUPLE({}), rets = TUPLE({ STRING }) }), }, }), - ["package"] = a_type({ - typename = "record", + ["package"] = a_record({ fields = { ["config"] = STRING, ["cpath"] = STRING, @@ -5886,8 +5894,7 @@ local function init_globals(lax) ["searchpath"] = a_type({ typename = "function", args = TUPLE({ STRING, STRING, OPT(STRING), OPT(STRING) }), rets = TUPLE({ STRING, STRING }) }), }, }), - ["string"] = a_type({ - typename = "record", + ["string"] = a_record({ fields = { ["byte"] = a_type({ typename = "poly", @@ -5906,29 +5913,28 @@ local function init_globals(lax) ["gsub"] = a_type({ typename = "poly", types = { - a_type({ typename = "function", args = TUPLE({ STRING, STRING, STRING, NUMBER }), rets = TUPLE({ STRING, INTEGER }) }), - a_type({ typename = "function", args = TUPLE({ STRING, STRING, a_type({ typename = "map", keys = STRING, values = STRING }), NUMBER }), rets = TUPLE({ STRING, INTEGER }) }), - a_type({ typename = "function", args = TUPLE({ STRING, STRING, a_type({ typename = "function", args = VARARG({ STRING }), rets = TUPLE({ STRING }) }) }), rets = TUPLE({ STRING, INTEGER }) }), - a_type({ typename = "function", args = TUPLE({ STRING, STRING, a_type({ typename = "function", args = VARARG({ STRING }), rets = TUPLE({ NUMBER }) }) }), rets = TUPLE({ STRING, INTEGER }) }), - a_type({ typename = "function", args = TUPLE({ STRING, STRING, a_type({ typename = "function", args = VARARG({ STRING }), rets = TUPLE({ BOOLEAN }) }) }), rets = TUPLE({ STRING, INTEGER }) }), - a_type({ typename = "function", args = TUPLE({ STRING, STRING, a_type({ typename = "function", args = VARARG({ STRING }), rets = TUPLE({}) }) }), rets = TUPLE({ STRING, INTEGER }) }), + a_type({ typename = "function", args = TUPLE({ STRING, STRING, a_type({ typename = "map", keys = STRING, values = STRING }), OPT(NUMBER) }), rets = TUPLE({ STRING, INTEGER }) }), + a_type({ typename = "function", args = TUPLE({ STRING, STRING, a_type({ typename = "function", args = VARARG({ STRING }), rets = TUPLE({ STRING }) }), OPT(NUMBER) }), rets = TUPLE({ STRING, INTEGER }) }), + a_type({ typename = "function", args = TUPLE({ STRING, STRING, a_type({ typename = "function", args = VARARG({ STRING }), rets = TUPLE({ NUMBER }) }), OPT(NUMBER) }), rets = TUPLE({ STRING, INTEGER }) }), + a_type({ typename = "function", args = TUPLE({ STRING, STRING, a_type({ typename = "function", args = VARARG({ STRING }), rets = TUPLE({ BOOLEAN }) }), OPT(NUMBER) }), rets = TUPLE({ STRING, INTEGER }) }), + a_type({ typename = "function", args = TUPLE({ STRING, STRING, a_type({ typename = "function", args = VARARG({ STRING }), rets = TUPLE({}) }), OPT(NUMBER) }), rets = TUPLE({ STRING, INTEGER }) }), + a_type({ typename = "function", args = TUPLE({ STRING, STRING, OPT(STRING), OPT(NUMBER) }), rets = TUPLE({ STRING, INTEGER }) }), }, }), ["len"] = a_type({ typename = "function", args = TUPLE({ STRING }), rets = TUPLE({ INTEGER }) }), ["lower"] = a_type({ typename = "function", args = TUPLE({ STRING }), rets = TUPLE({ STRING }) }), - ["match"] = a_type({ typename = "function", args = TUPLE({ STRING, STRING, NUMBER }), rets = VARARG({ STRING }) }), + ["match"] = a_type({ typename = "function", args = TUPLE({ STRING, OPT(STRING), OPT(NUMBER) }), rets = VARARG({ STRING }) }), ["pack"] = a_type({ typename = "function", args = VARARG({ STRING, ANY }), rets = TUPLE({ STRING }) }), ["packsize"] = a_type({ typename = "function", args = TUPLE({ STRING }), rets = TUPLE({ INTEGER }) }), ["rep"] = a_type({ typename = "function", args = TUPLE({ STRING, NUMBER, OPT(STRING) }), rets = TUPLE({ STRING }) }), ["reverse"] = a_type({ typename = "function", args = TUPLE({ STRING }), rets = TUPLE({ STRING }) }), - ["sub"] = a_type({ typename = "function", args = TUPLE({ STRING, NUMBER, NUMBER }), rets = TUPLE({ STRING }) }), + ["sub"] = a_type({ typename = "function", args = TUPLE({ STRING, NUMBER, OPT(NUMBER) }), rets = TUPLE({ STRING }) }), ["unpack"] = a_type({ typename = "function", args = TUPLE({ STRING, STRING, OPT(NUMBER) }), rets = VARARG({ ANY }) }), ["upper"] = a_type({ typename = "function", args = TUPLE({ STRING }), rets = TUPLE({ STRING }) }), }, }), - ["table"] = a_type({ - typename = "record", + ["table"] = a_record({ fields = { ["concat"] = a_type({ typename = "function", args = TUPLE({ ARRAY(UNION({ STRING, NUMBER })), OPT(STRING), OPT(NUMBER), OPT(NUMBER) }), rets = TUPLE({ STRING }) }), ["insert"] = a_type({ @@ -5948,11 +5954,10 @@ local function init_globals(lax) ["pack"] = a_type({ typename = "function", args = VARARG({ ANY }), rets = TUPLE({ TABLE }) }), ["remove"] = a_gfunction(1, function(a) return { args = TUPLE({ ARRAY(a), OPT(NUMBER) }), rets = TUPLE({ a }) } end), ["sort"] = a_gfunction(1, function(a) return { args = TUPLE({ ARRAY(a), OPT(TABLE_SORT_FUNCTION) }), rets = TUPLE({}) } end), - ["unpack"] = a_gfunction(1, function(a) return { needs_compat = true, args = TUPLE({ ARRAY(a), NUMBER, NUMBER }), rets = VARARG({ a }) } end), + ["unpack"] = a_gfunction(1, function(a) return { needs_compat = true, args = TUPLE({ ARRAY(a), OPT(NUMBER), OPT(NUMBER) }), rets = VARARG({ a }) } end), }, }), - ["utf8"] = a_type({ - typename = "record", + ["utf8"] = a_record({ fields = { ["char"] = a_type({ typename = "function", args = VARARG({ NUMBER }), rets = TUPLE({ STRING }) }), ["charpattern"] = STRING, @@ -5967,15 +5972,6 @@ local function init_globals(lax) ["_VERSION"] = STRING, } - for _, t in pairs(standard_library) do - fill_field_order(t) - if is_typetype(t) then - fill_field_order(t.def) - end - end - fill_field_order(OS_DATE_TABLE) - fill_field_order(DEBUG_GETINFO_TABLE) - NOMINAL_FILE.found = standard_library["FILE"] for _, m in ipairs(metatable_nominals) do m.found = standard_library["metatable"] @@ -6337,6 +6333,25 @@ tl.type_check = function(ast, opts) return u, store_errs and errs end + local function set_min_arity(f) + if f.min_arity then + return + end + local tuple = f.args.tuple + local n = #tuple + if f.args.is_va then + n = n - 1 + end + for i = n, 1, -1 do + if tuple[i].opt then + n = n - 1 + else + break + end + end + f.min_arity = n + end + local function resolve_typetype(t) if is_typetype(t) then return t.def @@ -6404,6 +6419,7 @@ tl.type_check = function(ast, opts) local copy = {} seen[orig_t] = copy + copy.opt = t.opt copy.is_userdata = t.is_userdata copy.typename = t.typename copy.filename = t.filename @@ -6446,8 +6462,9 @@ tl.type_check = function(ast, opts) end end - copy.is_method = t.is_method + set_min_arity(t) copy.min_arity = t.min_arity + copy.is_method = t.is_method copy.args, same = resolve(t.args, same) copy.rets, same = resolve(t.rets, same) elseif is_record_type(t) then @@ -6819,7 +6836,7 @@ tl.type_check = function(ast, opts) end local function match_fields_to_map(rec1, map) - if not match_record_fields(rec1, function(_) return map.values end) then + if not match_record_fields(rec1, function(_) return map.values end, false) then return false, { Err(rec1, "record is not a valid map; not all fields have the same type") } end return true @@ -7605,7 +7622,7 @@ tl.type_check = function(ast, opts) end if is_record_type(t1) then - return match_fields_to_record(t1, t2) + return match_fields_to_record(t1, t2, false) elseif is_typetype(t1) and is_record_type(t1.def) then return is_a(t1.def, t2, for_equality) end @@ -7935,26 +7952,6 @@ tl.type_check = function(ast, opts) local type_check_function_call do - local function set_min_arity(f) - if f.min_arity then - return - end - if not f.args then - f.min_arity = 0 - return - end - local min_arity = 0 - for i, fnarg in ipairs(f.args) do - if not fnarg.opt then - min_arity = i - end - end - if f.args.is_va then - min_arity = min_arity - 1 - end - f.min_arity = min_arity - end - local function mark_invalid_typeargs(f) if f.typeargs then for _, a in ipairs(f.typeargs) do @@ -9931,7 +9928,7 @@ tl.type_check = function(ast, opts) node.exps[2] and node.exps[2].type, node.exps[3] and node.exps[3].type, } - local exp1type = resolve_for_call(exp1.type, args) + local exp1type = resolve_for_call(exp1.type, args, false) if exp1type.typename == "poly" then type_check_function_call(exp1, { node.exps[2], node.exps[3] }, exp1type, args, exp1, false, 0) @@ -10897,7 +10894,9 @@ tl.type_check = function(ast, opts) if node.tk == "..." then t = a_type({ typename = "tuple", is_va = true, t }) end - t.opt = node.opt + if node.opt then + t = OPT(t) + end add_var(node, node.tk, t).is_func_arg = true return node.type end, @@ -10991,7 +10990,7 @@ tl.type_check = function(ast, opts) end, }, ["function"] = { - before = function(_typ, _children) + before = function(_typ) begin_scope() end, after = function(typ, _children) @@ -11012,7 +11011,7 @@ tl.type_check = function(ast, opts) end, }, ["record"] = { - before = function(typ, _children) + before = function(typ) begin_scope() add_var(nil, "@self", a_type({ typename = "typetype", y = typ.y, x = typ.x, def = typ })) diff --git a/tl.tl b/tl.tl index e180b5995..804aa0f3d 100644 --- a/tl.tl +++ b/tl.tl @@ -133,10 +133,10 @@ local record tl load: function(string, string, LoadMode, {any:any}): LoadFunction, string process: function(string, Env, string, FILE): (Result, string) - process_string: function(string, boolean, Env, string, string): Result + process_string: function(string, boolean, Env, ? string, ? string): Result gen: function(string, Env, PrettyPrintOptions): string, Result type_check: function(Node, TypeCheckOptions): Result, string - init_env: function(boolean, boolean | CompatMode, TargetMode, {string}): Env, string + init_env: function(? boolean, ? boolean | CompatMode, ? TargetMode, ? {string}): Env, string version: function(): string package_loader_env: Env @@ -1475,15 +1475,15 @@ end local parse_type_list: function(ParseState, integer, ParseTypeListMode): integer, Type local parse_expression: function(ParseState, integer): integer, Node, integer local parse_expression_and_tk: function(ps: ParseState, i: integer, tk: string): integer, Node -local parse_statements: function(ParseState, integer, boolean): integer, Node +local parse_statements: function(ParseState, integer, ? boolean): integer, Node local parse_argument_list: function(ParseState, integer): integer, Node local parse_argument_type_list: function(ParseState, integer): integer, Type local parse_type: function(ParseState, integer): integer, Type, integer local parse_newtype: function(ps: ParseState, i: integer, name: string): integer, Node -local type ParseBody = function(ps: ParseState, i: integer, def: Type, node: Node, name: string): integer, Node +local type ParseBody = function(ps: ParseState, i: integer, def: Type, node: Node, name?: string): integer, Node local parse_enum_body: function(ps: ParseState, i: integer, def: Type, node: Node): integer, Node -local parse_record_body: function(ps: ParseState, i: integer, def: Type, node: Node, name: string): integer, Node +local parse_record_body: function(ps: ParseState, i: integer, def: Type, node: Node, name?: string): integer, Node local parse_type_body_fns: {TypeName:ParseBody} local function fail(ps: ParseState, i: integer, msg: string): integer @@ -1528,7 +1528,7 @@ local function verify_end(ps: ParseState, i: integer, istart: integer, node: Nod return fail(ps, i, "syntax error, expected 'end' to close construct started at " .. ps.filename .. ":" .. ps.tokens[istart].y .. ":" .. ps.tokens[istart].x .. ":") end -local function new_node(tokens: {Token}, i: integer, kind: NodeKind): Node +local function new_node(tokens: {Token}, i: integer, kind?: NodeKind): Node local t = tokens[i] return { y = t.y, x = t.x, tk = t.tk, kind = kind or (t.kind as NodeKind) } end @@ -1558,7 +1558,7 @@ local function shallow_copy_type(t: Type): Type return copy as Type end --- Makes a shallow copy of the given type +-- Makes a shallow copy of the given node local function shallow_copy_node(t: Node): Node local copy: {any:any} = {} for k, v in pairs(t as {any:any}) do @@ -1567,7 +1567,7 @@ local function shallow_copy_node(t: Node): Node return copy as Node end -local function verify_kind(ps: ParseState, i: integer, kind: TokenKind, node_kind: NodeKind): integer, Node +local function verify_kind(ps: ParseState, i: integer, kind: TokenKind, node_kind?: NodeKind): integer, Node if ps.tokens[i].kind == kind then return i + 1, new_node(ps.tokens, i, node_kind) end @@ -1586,7 +1586,7 @@ local function skip(ps: ParseState, i: integer, skip_fn: SkipFunction): integer, return skip_fn(err_ps, i) end -local function failskip(ps: ParseState, i: integer, msg: string, skip_fn: SkipFunction, starti: integer): integer +local function failskip(ps: ParseState, i: integer, msg: string, skip_fn: SkipFunction, starti?: integer): integer local skip_i = skip(ps, starti or i, skip_fn) fail(ps, i, msg) return skip_i @@ -1685,7 +1685,7 @@ local function parse_table_item(ps: ParseState, i: integer, n: integer): integer return i, node, n + 1 end -local type ParseItem = function(ParseState, integer, integer): integer, T, integer +local type ParseItem = function(ParseState, integer, ? integer): integer, T, integer local enum SeparatorMode "sep" @@ -1853,6 +1853,19 @@ local simple_types: {string:Type} = { ["integer"] = INTEGER, } +local memoize_opt_types: {Type:Type} = {} + +local function OPT(t: Type): Type + if memoize_opt_types[t] then + return memoize_opt_types[t] + end + + local ot = shallow_copy_type(t) + ot.opt = true + memoize_opt_types[t] = ot + return ot +end + local function parse_simple_type_or_nominal(ps: ParseState, i: integer): integer, Type local tk = ps.tokens[i].tk local st = simple_types[tk] @@ -2138,7 +2151,6 @@ do } local function new_operator(tk: Token, arity: integer, op: string): Operator - op = op or tk.tk return { y = tk.y, x = tk.x, arity = arity, op = op, prec = precedences[arity][op] } end @@ -2178,8 +2190,8 @@ do end local e1: Node local t1 = ps.tokens[i] - if precedences[1][ps.tokens[i].tk] ~= nil then - local op: Operator = new_operator(ps.tokens[i], 1) + if precedences[1][t1.tk] ~= nil then + local op: Operator = new_operator(t1, 1, t1.tk) i = i + 1 local prev_i = i i, e1 = P(ps, i) @@ -2211,7 +2223,7 @@ do break end if tkop.tk == "." or tkop.tk == ":" then - local op: Operator = new_operator(tkop, 2) + local op: Operator = new_operator(tkop, 2, tkop.tk) local prev_i = i @@ -2337,7 +2349,7 @@ do local lookahead = ps.tokens[i].tk while precedences[2][lookahead] and precedences[2][lookahead] >= min_precedence do local t1 = ps.tokens[i] - local op: Operator = new_operator(t1, 2) + local op: Operator = new_operator(t1, 2, t1.tk) i = i + 1 local rhs: Node i, rhs = P(ps, i) @@ -2435,6 +2447,7 @@ local function parse_argument(ps: ParseState, i: integer): integer, Node, intege local node: Node if ps.tokens[i].tk == "..." then i, node = verify_kind(ps, i, "...", "argument") + node.opt = true else i, node = verify_kind(ps, i, "identifier", "argument") end @@ -2463,10 +2476,12 @@ parse_argument_list = function(ps: ParseState, i: integer): integer, Node i, node = parse_bracket_list(ps, i, node, "(", ")", "sep", parse_argument) local opts = false for a, fnarg in ipairs(node) do - if fnarg.tk == "..." and a ~= #node then - fail(ps, i, "'...' can only be last argument") - end - if fnarg.opt then + if fnarg.tk == "..." then + if a ~= #node then + fail(ps, i, "'...' can only be last argument") + break + end + elseif fnarg.opt then opts = true elseif opts then return fail(ps, i, "non-optional arguments cannot follow optional arguments") @@ -2510,15 +2525,19 @@ local function parse_argument_type(ps: ParseState, i: integer): integer, TypeAnd local typ: Type; i, typ = parse_type(ps, i) if typ then - typ.opt = opt if not is_va and ps.tokens[i].tk == "..." then i = i + 1 is_va = true end - end - if argument_name == "self" then - typ.is_self = true + if opt then + typ = OPT(typ) + end + + if argument_name == "self" then + typ = shallow_copy_type(typ) + typ.is_self = true + end end return i, { i = i, type = typ, is_va = is_va }, 0 @@ -2610,7 +2629,7 @@ local function parse_function(ps: ParseState, i: integer, ft: FunctionType): int return i, fn end -local function parse_if_block(ps: ParseState, i: integer, n: integer, node: Node, is_else: boolean): integer, Node +local function parse_if_block(ps: ParseState, i: integer, n: integer, node: Node, is_else?: boolean): integer, Node local block = new_node(ps.tokens, i, "if_block") i = i + 1 block.if_parent = node @@ -3345,7 +3364,7 @@ local needs_local_or_global: {string : function(ParseState, integer):(integer, N ["enum"] = type_needs_local_or_global, } -parse_statements = function(ps: ParseState, i: integer, toplevel: boolean): integer, Node +parse_statements = function(ps: ParseState, i: integer, toplevel?: boolean): integer, Node local node = new_node(ps.tokens, i, "statements") local item: Node while true do @@ -3451,12 +3470,12 @@ end -------------------------------------------------------------------------------- local record VisitorCallbacks - before: function(N, {T}) + before: function(N) before_exp: function({N}, {T}) before_arguments: function({N}, {T}) before_statements: function({N}, {T}) before_e2: function({N}, {T}) - after: function(N, {T}, T): T + after: function(N, {T}): T end local enum VisitorExtraCallback @@ -3476,7 +3495,7 @@ local enum MetaMode "meta" end -local function fields_of(t: Type, meta: MetaMode): (function(): string, Type) +local function fields_of(t: Type, meta?: MetaMode): (function(): string, Type) local i = 1 local field_order, fields: {string}, {string:Type} if meta then @@ -3498,7 +3517,7 @@ local function fields_of(t: Type, meta: MetaMode): (function(): string, Type) end end -local show_type: function(Type, boolean, {Type:string}): string +local show_type: function(Type, ? boolean, ? {Type:string}): string local tl_debug_indent = 0 local record DebugEntry @@ -3535,7 +3554,7 @@ local function tl_debug_indent_push(mark: string, y: integer, x: integer, fmt: s } end -local function tl_debug_indent_pop(mark: string, single: string, y: integer, x: integer, fmt: string, ...: any) +local function tl_debug_indent_pop(mark: string, single: string, y: integer, x: integer, fmt?: string, ...: any) if tl_debug_entry then local msg = tl_debug_entry.msg if fmt then @@ -4057,7 +4076,7 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | end end - local function add_child(out: Output, child: Output, space: string, current_indent: integer): integer + local function add_child(out: Output, child: Output, space?: string, current_indent?: integer): integer if #child == 0 then return end @@ -5134,7 +5153,7 @@ local function inferred_msg(t: Type): string return " (inferred at "..t.inferred_at.filename..":"..t.inferred_at.y..":"..t.inferred_at.x..")" end -show_type = function(t: Type, short: boolean, seen: {Type:string}): string +show_type = function(t: Type, short?: boolean, seen?: {Type:string}): string seen = seen or {} if seen[t] then return seen[t] @@ -5232,12 +5251,6 @@ local function sorted_keys(m: {A:B}):{A} return keys end -local function fill_field_order(t: Type) - if t.typename == "record" then - t.field_order = sorted_keys(t.fields) - end -end - local function require_module(module_name: string, lax: boolean, env: Env): Type, boolean local mod = env.modules[module_name] if mod then @@ -5351,7 +5364,7 @@ local bit_operators: {string:string} = { ["<<"] = "lshift", } -local function convert_node_to_compat_call(node: Node, mod_name: string, fn_name: string, e1: Node, e2: Node) +local function convert_node_to_compat_call(node: Node, mod_name: string, fn_name: string, e1: Node, e2?: Node) node.op.op = "@funcall" node.op.arity = 2 node.op.prec = 100 @@ -5363,7 +5376,7 @@ local function convert_node_to_compat_call(node: Node, mod_name: string, fn_name node.e2[2] = e2 end -local function convert_node_to_compat_mt_call(node: Node, mt_name: string, which_self: integer, e1: Node, e2: Node) +local function convert_node_to_compat_mt_call(node: Node, mt_name: string, which_self: integer, e1: Node, e2?: Node) node.op.op = "@funcall" node.op.arity = 2 node.op.prec = 100 @@ -5392,6 +5405,13 @@ local function init_globals(lax: boolean): {string:Variable}, {string:Type} last_typeid = globals_typeid end + local function a_record(t: Type): Type + t = a_type(t) + t.typename = "record" + t.field_order = sorted_keys(t.fields) + return t + end + local function a_gfunction(n: integer, f: function(...: Type): Type): Type local typevars = {} local typeargs = {} @@ -5411,6 +5431,7 @@ local function init_globals(lax: boolean): {string:Variable}, {string:Type} local function a_grecord(n: integer, f: function(...: Type): Type): Type local t = a_gfunction(n, f) t.typename = "record" + t.field_order = sorted_keys(t.fields) return t end @@ -5434,10 +5455,11 @@ local function init_globals(lax: boolean): {string:Variable}, {string:Type} end local file_reader_poly_types: {ArgsRets} = { - { ctor = VARARG, args = { UNION { NUMBER, an_enum { "*a", "a", "*l", "l", "*L", "L" } } }, rets = { STRING } }, + { ctor = VARARG, args = {UNION { NUMBER, an_enum { "*a", "a", "*l", "l", "*L", "L" } } }, rets = { STRING } }, { ctor = TUPLE, args = { an_enum { "*n", "n" } }, rets = { NUMBER, STRING } }, { ctor = VARARG, args = { UNION { NUMBER, an_enum { "*a", "a", "*l", "l", "*L", "L", "*n", "n" } } }, rets = { UNION { STRING, NUMBER } } }, { ctor = VARARG, args = { UNION { NUMBER, STRING } }, rets = { STRING } }, + { ctor = VARARG, args = { }, rets = { STRING } }, } local function a_file_reader(fn: (function(ctor: TypeConstructor, args: {Type}, rets: {Type}): Type)): Type @@ -5455,8 +5477,7 @@ local function init_globals(lax: boolean): {string:Variable}, {string:Type} local LOAD_FUNCTION = a_type { typename = "function", args = {}, rets = TUPLE { STRING } } - local OS_DATE_TABLE = a_type { - typename = "record", + local OS_DATE_TABLE = a_record { fields = { ["year"] = INTEGER, ["month"] = INTEGER, @@ -5470,8 +5491,7 @@ local function init_globals(lax: boolean): {string:Variable}, {string:Type} } } - local DEBUG_GETINFO_TABLE = a_type { - typename = "record", + local DEBUG_GETINFO_TABLE = a_record { fields = { ["name"] = STRING, ["namewhat"] = STRING, @@ -5523,11 +5543,6 @@ local function init_globals(lax: boolean): {string:Variable}, {string:Type} } end - -- placeholders for when we have optional arity annotations - local function OPT(x: Type): Type - return x - end - local standard_library: {string:Type} = { ["..."] = VARARG { STRING }, ["any"] = a_type { typename = "typetype", def = ANY }, @@ -5543,7 +5558,7 @@ local function init_globals(lax: boolean): {string:Variable}, {string:Type} } }, ["dofile"] = a_type { typename = "function", args = TUPLE { OPT(STRING) }, rets = VARARG { ANY } }, - ["error"] = a_type { typename = "function", args = TUPLE { ANY, NUMBER }, rets = TUPLE {} }, + ["error"] = a_type { typename = "function", args = TUPLE { ANY, OPT(NUMBER) }, rets = TUPLE {} }, ["getmetatable"] = a_gfunction(1, function(a: Type): Type return { args = TUPLE { a }, rets = TUPLE { METATABLE(a) } } end), ["ipairs"] = a_gfunction(1, function(a: Type): Type return { args = TUPLE { ARRAY(a) }, rets = TUPLE { a_type { typename = "function", args = TUPLE {}, rets = TUPLE { INTEGER, a } }, @@ -5595,8 +5610,7 @@ local function init_globals(lax: boolean): {string:Variable}, {string:Type} ["type"] = a_type { typename = "function", args = TUPLE { ANY }, rets = TUPLE { STRING } }, ["FILE"] = a_type { typename = "typetype", - def = a_type { - typename = "record", + def = a_record { is_userdata = true, fields = { ["close"] = a_type { typename = "function", args = TUPLE { NOMINAL_FILE }, rets = TUPLE { BOOLEAN, STRING, INTEGER } }, @@ -5660,8 +5674,7 @@ local function init_globals(lax: boolean): {string:Variable}, {string:Type} }, } end), }, - ["coroutine"] = a_type { - typename = "record", + ["coroutine"] = a_record { fields = { ["create"] = a_type { typename = "function", args = TUPLE { FUNCTION }, rets = TUPLE { THREAD } }, ["close"] = a_type { typename = "function", args = TUPLE { THREAD }, rets = TUPLE { BOOLEAN, STRING } }, @@ -5673,8 +5686,7 @@ local function init_globals(lax: boolean): {string:Variable}, {string:Type} ["yield"] = a_type { typename = "function", args = VARARG { ANY }, rets = VARARG { ANY } }, } }, - ["debug"] = a_type { - typename = "record", + ["debug"] = a_record { fields = { ["Info"] = a_type { typename = "typetype", @@ -5724,8 +5736,8 @@ local function init_globals(lax: boolean): {string:Variable}, {string:Type} ["traceback"] = a_type { typename = "poly", types = { - a_type { typename = "function", args = TUPLE { THREAD, STRING, NUMBER }, rets = TUPLE { STRING } }, - a_type { typename = "function", args = TUPLE { STRING, NUMBER }, rets = TUPLE { STRING } }, + a_type { typename = "function", args = TUPLE { OPT(THREAD), OPT(STRING), OPT(NUMBER) }, rets = TUPLE { STRING } }, + a_type { typename = "function", args = TUPLE { OPT(STRING), OPT(NUMBER) }, rets = TUPLE { STRING } }, }, }, ["upvalueid"] = a_type { typename = "function", args = TUPLE { FUNCTION, NUMBER }, rets = TUPLE { USERDATA } }, @@ -5740,8 +5752,7 @@ local function init_globals(lax: boolean): {string:Variable}, {string:Type} }, }, }, - ["io"] = a_type { - typename = "record", + ["io"] = a_record { fields = { ["close"] = a_type { typename = "function", args = TUPLE { OPT(NOMINAL_FILE) }, rets = TUPLE { BOOLEAN, STRING } }, ["flush"] = a_type { typename = "function", args = TUPLE {}, rets = TUPLE {} }, @@ -5751,9 +5762,9 @@ local function init_globals(lax: boolean): {string:Variable}, {string:Type} a_type { typename = "function", args = TUPLE {}, rets = ctor(rets) }, } } end), - ["open"] = a_type { typename = "function", args = TUPLE { STRING, STRING }, rets = TUPLE { NOMINAL_FILE, STRING } }, + ["open"] = a_type { typename = "function", args = TUPLE { STRING, OPT(STRING) }, rets = TUPLE { NOMINAL_FILE, STRING } }, ["output"] = a_type { typename = "function", args = TUPLE { OPT(UNION { STRING, NOMINAL_FILE }) }, rets = TUPLE { NOMINAL_FILE } }, - ["popen"] = a_type { typename = "function", args = TUPLE { STRING, STRING }, rets = TUPLE { NOMINAL_FILE, STRING } }, + ["popen"] = a_type { typename = "function", args = TUPLE { STRING, OPT(STRING) }, rets = TUPLE { NOMINAL_FILE, STRING } }, ["read"] = a_file_reader(function(ctor: TypeConstructor, args: {Type}, rets: {Type}): Type return a_type { typename = "function", args = ctor(args), rets = ctor(rets) } end), @@ -5765,8 +5776,7 @@ local function init_globals(lax: boolean): {string:Variable}, {string:Type} ["write"] = a_type { typename = "function", args = VARARG { UNION { STRING, NUMBER } }, rets = TUPLE { NOMINAL_FILE, STRING } }, }, }, - ["math"] = a_type { - typename = "record", + ["math"] = a_record { fields = { ["abs"] = a_type { typename = "poly", @@ -5795,7 +5805,7 @@ local function init_globals(lax: boolean): {string:Variable}, {string:Type} ["frexp"] = a_type { typename = "function", args = TUPLE { NUMBER }, rets = TUPLE { NUMBER, NUMBER } }, ["huge"] = NUMBER, ["ldexp"] = a_type { typename = "function", args = TUPLE { NUMBER, NUMBER }, rets = TUPLE { NUMBER } }, - ["log"] = a_type { typename = "function", args = TUPLE { NUMBER, NUMBER }, rets = TUPLE { NUMBER } }, + ["log"] = a_type { typename = "function", args = TUPLE { NUMBER, OPT(NUMBER) }, rets = TUPLE { NUMBER } }, ["log10"] = a_type { typename = "function", args = TUPLE { NUMBER }, rets = TUPLE { NUMBER } }, ["max"] = a_type { typename = "poly", @@ -5824,7 +5834,7 @@ local function init_globals(lax: boolean): {string:Variable}, {string:Type} ["random"] = a_type { typename = "poly", types = { - a_type { typename = "function", args = TUPLE { NUMBER, NUMBER }, rets = TUPLE { INTEGER } }, + a_type { typename = "function", args = TUPLE { NUMBER, OPT(NUMBER) }, rets = TUPLE { INTEGER } }, a_type { typename = "function", args = TUPLE {}, rets = TUPLE { NUMBER } }, } }, @@ -5839,21 +5849,20 @@ local function init_globals(lax: boolean): {string:Variable}, {string:Type} ["ult"] = a_type { typename = "function", args = TUPLE { NUMBER, NUMBER }, rets = TUPLE { BOOLEAN } }, }, }, - ["os"] = a_type { - typename = "record", + ["os"] = a_record { fields = { ["clock"] = a_type { typename = "function", args = TUPLE {}, rets = TUPLE { NUMBER } }, ["date"] = a_type { typename = "poly", types = { - a_type { typename = "function", args = TUPLE {}, rets = TUPLE { STRING } }, - a_type { typename = "function", args = TUPLE { an_enum { "!*t", "*t" }, NUMBER }, rets = TUPLE { OS_DATE_TABLE } }, + a_type { typename = "function", args = TUPLE { }, rets = TUPLE { STRING } }, + a_type { typename = "function", args = TUPLE { an_enum { "!*t", "*t" }, OPT(NUMBER) }, rets = TUPLE { OS_DATE_TABLE } }, a_type { typename = "function", args = TUPLE { OPT(STRING), OPT(NUMBER) }, rets = TUPLE { STRING } }, } }, ["difftime"] = a_type { typename = "function", args = TUPLE { NUMBER, NUMBER }, rets = TUPLE { NUMBER } }, ["execute"] = a_type { typename = "function", args = TUPLE { STRING }, rets = TUPLE { BOOLEAN, STRING, INTEGER } }, - ["exit"] = a_type { typename = "function", args = TUPLE { UNION { NUMBER, BOOLEAN }, BOOLEAN }, rets = TUPLE {} }, + ["exit"] = a_type { typename = "function", args = TUPLE { OPT(UNION { NUMBER, BOOLEAN }), OPT(BOOLEAN) }, rets = TUPLE {} }, ["getenv"] = a_type { typename = "function", args = TUPLE { STRING }, rets = TUPLE { STRING } }, ["remove"] = a_type { typename = "function", args = TUPLE { STRING }, rets = TUPLE { BOOLEAN, STRING } }, ["rename"] = a_type { typename = "function", args = TUPLE { STRING, STRING}, rets = TUPLE { BOOLEAN, STRING } }, @@ -5862,8 +5871,7 @@ local function init_globals(lax: boolean): {string:Variable}, {string:Type} ["tmpname"] = a_type { typename = "function", args = TUPLE {}, rets = TUPLE { STRING } }, }, }, - ["package"] = a_type { - typename = "record", + ["package"] = a_record { fields = { ["config"] = STRING, ["cpath"] = STRING, @@ -5886,8 +5894,7 @@ local function init_globals(lax: boolean): {string:Variable}, {string:Type} ["searchpath"] = a_type { typename = "function", args = TUPLE { STRING, STRING, OPT(STRING), OPT(STRING) }, rets = TUPLE { STRING, STRING } }, }, }, - ["string"] = a_type { - typename = "record", + ["string"] = a_record { fields = { ["byte"] = a_type { typename = "poly", @@ -5906,29 +5913,28 @@ local function init_globals(lax: boolean): {string:Variable}, {string:Type} ["gsub"] = a_type { typename = "poly", types = { - a_type { typename = "function", args = TUPLE { STRING, STRING, STRING, NUMBER }, rets = TUPLE { STRING, INTEGER } }, - a_type { typename = "function", args = TUPLE { STRING, STRING, a_type { typename = "map", keys = STRING, values = STRING }, NUMBER }, rets = TUPLE { STRING, INTEGER } }, - a_type { typename = "function", args = TUPLE { STRING, STRING, a_type { typename = "function", args = VARARG { STRING }, rets = TUPLE { STRING } } }, rets = TUPLE { STRING, INTEGER } }, - a_type { typename = "function", args = TUPLE { STRING, STRING, a_type { typename = "function", args = VARARG { STRING }, rets = TUPLE { NUMBER } } }, rets = TUPLE { STRING, INTEGER } }, - a_type { typename = "function", args = TUPLE { STRING, STRING, a_type { typename = "function", args = VARARG { STRING }, rets = TUPLE { BOOLEAN } } }, rets = TUPLE { STRING, INTEGER } }, - a_type { typename = "function", args = TUPLE { STRING, STRING, a_type { typename = "function", args = VARARG { STRING }, rets = TUPLE {} } }, rets = TUPLE { STRING, INTEGER } }, + a_type { typename = "function", args = TUPLE { STRING, STRING, a_type { typename = "map", keys = STRING, values = STRING }, OPT(NUMBER) }, rets = TUPLE { STRING, INTEGER } }, + a_type { typename = "function", args = TUPLE { STRING, STRING, a_type { typename = "function", args = VARARG { STRING }, rets = TUPLE { STRING } }, OPT(NUMBER) }, rets = TUPLE { STRING, INTEGER } }, + a_type { typename = "function", args = TUPLE { STRING, STRING, a_type { typename = "function", args = VARARG { STRING }, rets = TUPLE { NUMBER } }, OPT(NUMBER) }, rets = TUPLE { STRING, INTEGER } }, + a_type { typename = "function", args = TUPLE { STRING, STRING, a_type { typename = "function", args = VARARG { STRING }, rets = TUPLE { BOOLEAN } }, OPT(NUMBER) }, rets = TUPLE { STRING, INTEGER } }, + a_type { typename = "function", args = TUPLE { STRING, STRING, a_type { typename = "function", args = VARARG { STRING }, rets = TUPLE {} }, OPT(NUMBER) }, rets = TUPLE { STRING, INTEGER } }, + a_type { typename = "function", args = TUPLE { STRING, STRING, OPT(STRING), OPT(NUMBER) }, rets = TUPLE { STRING, INTEGER } }, -- FIXME any other modes } }, ["len"] = a_type { typename = "function", args = TUPLE { STRING }, rets = TUPLE { INTEGER } }, ["lower"] = a_type { typename = "function", args = TUPLE { STRING }, rets = TUPLE { STRING } }, - ["match"] = a_type { typename = "function", args = TUPLE { STRING, STRING, NUMBER }, rets = VARARG { STRING } }, + ["match"] = a_type { typename = "function", args = TUPLE { STRING, OPT(STRING), OPT(NUMBER) }, rets = VARARG { STRING } }, ["pack"] = a_type { typename = "function", args = VARARG { STRING, ANY }, rets = TUPLE { STRING } }, ["packsize"] = a_type { typename = "function", args = TUPLE { STRING }, rets = TUPLE { INTEGER } }, ["rep"] = a_type { typename = "function", args = TUPLE { STRING, NUMBER, OPT(STRING) }, rets = TUPLE { STRING } }, ["reverse"] = a_type { typename = "function", args = TUPLE { STRING }, rets = TUPLE { STRING } }, - ["sub"] = a_type { typename = "function", args = TUPLE { STRING, NUMBER, NUMBER }, rets = TUPLE { STRING } }, + ["sub"] = a_type { typename = "function", args = TUPLE { STRING, NUMBER, OPT(NUMBER) }, rets = TUPLE { STRING } }, ["unpack"] = a_type { typename = "function", args = TUPLE { STRING, STRING, OPT(NUMBER) }, rets = VARARG { ANY } }, ["upper"] = a_type { typename = "function", args = TUPLE { STRING }, rets = TUPLE { STRING } }, }, }, - ["table"] = a_type { - typename = "record", + ["table"] = a_record { fields = { ["concat"] = a_type { typename = "function", args = TUPLE { ARRAY(UNION {STRING, NUMBER }), OPT(STRING), OPT(NUMBER), OPT(NUMBER) }, rets = TUPLE { STRING } }, ["insert"] = a_type { @@ -5948,11 +5954,10 @@ local function init_globals(lax: boolean): {string:Variable}, {string:Type} ["pack"] = a_type { typename = "function", args = VARARG { ANY }, rets = TUPLE { TABLE } }, ["remove"] = a_gfunction(1, function(a: Type): Type return { args = TUPLE { ARRAY(a), OPT(NUMBER) }, rets = TUPLE { a } } end), ["sort"] = a_gfunction(1, function(a: Type): Type return { args = TUPLE { ARRAY(a), OPT(TABLE_SORT_FUNCTION) }, rets = TUPLE {} } end), - ["unpack"] = a_gfunction(1, function(a: Type): Type return { needs_compat = true, args = TUPLE { ARRAY(a), NUMBER, NUMBER }, rets = VARARG { a } } end), + ["unpack"] = a_gfunction(1, function(a: Type): Type return { needs_compat = true, args = TUPLE { ARRAY(a), OPT(NUMBER), OPT(NUMBER) }, rets = VARARG { a } } end), }, }, - ["utf8"] = a_type { - typename = "record", + ["utf8"] = a_record { fields = { ["char"] = a_type { typename = "function", args = VARARG { NUMBER }, rets = TUPLE { STRING } }, ["charpattern"] = STRING, @@ -5967,15 +5972,6 @@ local function init_globals(lax: boolean): {string:Variable}, {string:Type} ["_VERSION"] = STRING, } - for _, t in pairs(standard_library) do - fill_field_order(t) - if is_typetype(t) then - fill_field_order(t.def) - end - end - fill_field_order(OS_DATE_TABLE) - fill_field_order(DEBUG_GETINFO_TABLE) - NOMINAL_FILE.found = standard_library["FILE"] for _, m in ipairs(metatable_nominals) do m.found = standard_library["metatable"] @@ -5997,7 +5993,7 @@ local function init_globals(lax: boolean): {string:Variable}, {string:Type} return globals, standard_library end -tl.init_env = function(lax: boolean, gen_compat: boolean | CompatMode, gen_target: TargetMode, predefined: {string}): Env, string +tl.init_env = function(lax?: boolean, gen_compat?: boolean | CompatMode, gen_target?: TargetMode, predefined?: {string}): Env, string if gen_compat == true or gen_compat == nil then gen_compat = "optional" elseif gen_compat == false then @@ -6088,7 +6084,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string "check_only" end - local function find_var(name: string, use: VarUse): Variable, integer, Attribute + local function find_var(name: string, use?: VarUse): Variable, integer, Attribute for i = #st, 1, -1 do local scope = st[i] local var = scope[name] @@ -6130,7 +6126,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end local type ResolveType = function(Type): Type - local resolve_typevars: function (typ: Type, fn_var: ResolveType, fn_arg: ResolveType): boolean, Type, {Error} + local resolve_typevars: function (typ: Type, fn_var?: ResolveType, fn_arg?: ResolveType): boolean, Type, {Error} local function fresh_typevar(t: Type): Type, Type, boolean return a_type { @@ -6158,7 +6154,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return t end - local function find_var_type(name: string, use: VarUse): Type, Attribute + local function find_var_type(name: string, use?: VarUse): Type, Attribute local var = find_var(name, use) if var then local t = var.t @@ -6204,7 +6200,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - local function find_type(names: {string}, accept_typearg: boolean): Type + local function find_type(names: {string}, accept_typearg?: boolean): Type local typ = find_var_type(names[1], "use_type") if not typ then return nil @@ -6321,7 +6317,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return true end - local function validate_union(where: Where, u: Type, store_errs: boolean, errs: {Error}): Type, {Error} + local function validate_union(where: Where, u: Type, store_errs?: boolean, errs?: {Error}): Type, {Error} local valid, err = is_valid_union(u) if err then if store_errs then @@ -6337,6 +6333,25 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return u, store_errs and errs end + local function set_min_arity(f: Type) + if f.min_arity then + return + end + local tuple = f.args.tuple + local n = #tuple + if f.args.is_va then + n = n - 1 + end + for i = n, 1, -1 do + if tuple[i].opt then + n = n - 1 + else + break + end + end + f.min_arity = n + end + local function resolve_typetype(t: Type): Type if is_typetype(t) then return t.def @@ -6404,6 +6419,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local copy: Type = {} seen[orig_t] = copy + copy.opt = t.opt copy.is_userdata = t.is_userdata copy.typename = t.typename copy.filename = t.filename @@ -6446,8 +6462,9 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - copy.is_method = t.is_method + set_min_arity(t) copy.min_arity = t.min_arity + copy.is_method = t.is_method copy.args, same = resolve(t.args, same) copy.rets, same = resolve(t.rets, same) elseif is_record_type(t) then @@ -6559,7 +6576,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string add_warning("unknown", node, "unknown variable: %s", name) end - local function redeclaration_warning(node: Node, old_var: Variable) + local function redeclaration_warning(node: Node, old_var?: Variable) if node.tk:sub(1, 1) == "_" then return end local var_kind = "variable" @@ -6666,7 +6683,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string node.symbol_list_slot = symbol_list_n end - local get_unresolved: function(scope: Scope): Type + local get_unresolved: function(scope?: Scope): Type local function add_to_scope(node: Node, name: string, t: Type, attribute: Attribute, narrow: Narrow, dont_check_redeclaration: boolean): Variable local scope = st[#st] @@ -6713,7 +6730,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return var end - local function add_var(node: Node, name: string, t: Type, attribute: Attribute, narrow: Narrow, dont_check_redeclaration: boolean): Variable + local function add_var(node: Node, name: string, t: Type, attribute?: Attribute, narrow?: Narrow, dont_check_redeclaration?: boolean): Variable if lax and node and is_unknown(t) and (name ~= "self" and name ~= "...") and not narrow then add_unknown(node, name) end @@ -6738,7 +6755,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return var end - local type CompareTypes = function(Type, Type, boolean): boolean, {Error} + local type CompareTypes = function(Type, Type, ? boolean): boolean, {Error} local function compare_and_infer_typevars(t1: Type, t2: Type, comp: CompareTypes): boolean, {Error} -- if both are typevars and they are the same variable, nothing to do here @@ -6774,7 +6791,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end local same_type: function(t1: Type, t2: Type): boolean, {Error} - local is_a: function(Type, Type, boolean): boolean, {Error} + local is_a: function(Type, Type, ? boolean): boolean, {Error} local type TypeGetter = function(string): Type @@ -6819,7 +6836,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end local function match_fields_to_map(rec1: Type, map: Type): boolean, {Error} - if not match_record_fields(rec1, function(_: string): Type return map.values end) then + if not match_record_fields(rec1, function(_: string): Type return map.values end, false) then return false, { Err(rec1, "record is not a valid map; not all fields have the same type") } end return true @@ -6887,7 +6904,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string var: Variable end - local function check_for_unused_vars(vars: {string:Variable}, is_global: boolean) + local function check_for_unused_vars(vars: {string:Variable}, is_global?: boolean) if not next(vars) then return end @@ -6917,7 +6934,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - get_unresolved = function(scope: Scope): Type + get_unresolved = function(scope?: Scope): Type local unresolved: Type if scope then local unr = scope["@unresolved"] @@ -6938,7 +6955,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return unresolved end - local function begin_scope(node: Node) + local function begin_scope(node?: Node) table.insert(st, {}) if node then @@ -6947,7 +6964,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - local function end_scope(node: Node) + local function end_scope(node?: Node) local scope = st[#st] local unresolved = scope["@unresolved"] if unresolved then @@ -7250,7 +7267,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return true end - local function unite(types: {Type}, flatten_constants: boolean): Type + local function unite(types: {Type}, flatten_constants?: boolean): Type if #types == 1 then return types[1] end @@ -7605,7 +7622,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end if is_record_type(t1) then - return match_fields_to_record(t1, t2) + return match_fields_to_record(t1, t2, false) elseif is_typetype(t1) and is_record_type(t1.def) then -- record as prototype return is_a(t1.def, t2, for_equality) end @@ -7716,7 +7733,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return false, { Err(t1, "got %s, expected %s", t1, t2) } end - local function assert_is_a(node: Node, t1: Type, t2: Type, context: string, name: string): boolean + local function assert_is_a(node: Node, t1: Type, t2: Type, context: string, name?: string): boolean t1 = resolve_tuple(t1) t2 = resolve_tuple(t2) if lax and (is_unknown(t1) or is_unknown(t2)) then @@ -7933,28 +7950,8 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string orignode.known = saveknown end - local type_check_function_call: function(Node, {Node}, Type, {Type}, Node, boolean, integer): Type + local type_check_function_call: function(Node, {Node}, Type, {Type}, Node, boolean, ? integer): Type do - local function set_min_arity(f: Type) - if f.min_arity then - return - end - if not f.args then - f.min_arity = 0 - return - end - local min_arity = 0 - for i, fnarg in ipairs(f.args) do - if not fnarg.opt then - min_arity = i - end - end - if f.args.is_va then - min_arity = min_arity - 1 - end - f.min_arity = min_arity - end - local function mark_invalid_typeargs(f: Type) if f.typeargs then for _, a in ipairs(f.typeargs) do @@ -8188,7 +8185,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return fail_call(node, func, given, first_errs) end - type_check_function_call = function(node: Node, where_args: {Node}, func: Type, args: {Type}, e1: Node, is_method: boolean, argdelta: integer): Type + type_check_function_call = function(node: Node, where_args: {Node}, func: Type, args: {Type}, e1: Node, is_method: boolean, argdelta?: integer): Type if node.expected and node.expected.typename ~= "tuple" then node.expected = a_type { typename = "tuple", node.expected } end @@ -8371,7 +8368,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return recurse_node(root, visit_node, visit_type) end - local function widen_all_unions(node: Node) + local function widen_all_unions(node?: Node) for i = #st, 1, -1 do local scope = st[i] local unr = scope["@unresolved"] @@ -8385,7 +8382,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - local function add_global(node: Node, var: string, valtype: Type, is_assigning: boolean): Variable + local function add_global(node: Node, var: string, valtype: Type, is_assigning?: boolean): Variable if lax and is_unknown(valtype) and (var ~= "self" and var ~= "...") then add_unknown(node, var) end @@ -9133,7 +9130,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - local type_check_funcall: function(node: Node, a: Type, b: {Type}, argdelta: integer): Type + local type_check_funcall: function(node: Node, a: Type, b: {Type}, argdelta?: integer): Type local function special_pcall_xpcall(node: Node, _a: Type, b: {Type}, argdelta: integer): Type local base_nargs = (node.e1.tk == "xpcall") and 2 or 1 @@ -9218,7 +9215,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end, } - type_check_funcall = function(node: Node, a: Type, b: {Type}, argdelta: integer): Type + type_check_funcall = function(node: Node, a: Type, b: {Type}, argdelta?: integer): Type argdelta = argdelta or 0 if node.e1.kind == "variable" then local special = special_functions[node.e1.tk] @@ -9931,7 +9928,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string node.exps[2] and node.exps[2].type, node.exps[3] and node.exps[3].type } - local exp1type = resolve_for_call(exp1.type, args) + local exp1type = resolve_for_call(exp1.type, args, false) if exp1type.typename == "poly" then type_check_function_call(exp1, {node.exps[2], node.exps[3]}, exp1type, args, exp1, false, 0) @@ -10897,7 +10894,9 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if node.tk == "..." then t = a_type { typename = "tuple", is_va = true, t } end - t.opt = node.opt + if node.opt then + t = OPT(t) + end add_var(node, node.tk, t).is_func_arg = true return node.type end, @@ -10991,7 +10990,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end, }, ["function"] = { - before = function(_typ: Type, _children: {Type}) + before = function(_typ: Type) begin_scope() end, after = function(typ: Type, _children: {Type}): Type @@ -11012,7 +11011,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end, }, ["record"] = { - before = function(typ: Type, _children: {Type}) + before = function(typ: Type) begin_scope() add_var(nil, "@self", a_type { typename = "typetype", y = typ.y, x = typ.x, def = typ }) @@ -11533,7 +11532,7 @@ tl.process = function(filename: string, env: Env, module_name: string, fd: FILE) return tl.process_string(input, is_lua, env, filename, module_name) end -function tl.process_string(input: string, is_lua: boolean, env: Env, filename: string, module_name: string): Result +function tl.process_string(input: string, is_lua: boolean, env: Env, filename?: string, module_name?: string): Result if filename and not module_name then module_name = filename_to_module_name(filename) end From 51a0be2040c61e8f0fcdc85260e445b1f7f29d5e Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Wed, 27 Jan 2021 00:38:34 -0300 Subject: [PATCH 026/224] optional arity: fix arguments in some tests --- spec/call/record_method_spec.lua | 2 +- spec/operator/index_spec.lua | 28 ++++++++++++++-------------- spec/subtyping/any_spec.lua | 2 +- spec/subtyping/nil_spec.lua | 2 +- 4 files changed, 17 insertions(+), 17 deletions(-) diff --git a/spec/call/record_method_spec.lua b/spec/call/record_method_spec.lua index 0e87405b7..3e8713b04 100644 --- a/spec/call/record_method_spec.lua +++ b/spec/call/record_method_spec.lua @@ -192,7 +192,7 @@ describe("record method call", function() local record Foo x: integer end - function Foo:add(other: Foo) + function Foo:add(other?: Foo) self.x = other and (self.x + other.x) or self.x end local first: Foo = {} diff --git a/spec/operator/index_spec.lua b/spec/operator/index_spec.lua index 2b2ca792e..40d9009df 100644 --- a/spec/operator/index_spec.lua +++ b/spec/operator/index_spec.lua @@ -77,20 +77,20 @@ describe("[]", function() describe("on strings", function() it("works with relevant stdlib string functions", util.check([[ local s: string - s:byte() - s:find() + s:byte(1) + s:find("hello") s:format() - s:gmatch() - s:gsub() + s:gmatch("hello") + s:gsub("hello", "world") s:len() s:lower() s:match() s:pack() s:packsize() - s:rep() + s:rep(2) s:reverse() - s:sub() - s:unpack() + s:sub(2) + s:unpack("b") s:upper() ]])) end) @@ -101,20 +101,20 @@ describe("[]", function() "bar" end local s: foo - s:byte() - s:find() + s:byte(1) + s:find("hello") s:format() - s:gmatch() - s:gsub() + s:gmatch("hello") + s:gsub("hello", "world") s:len() s:lower() s:match() s:pack() s:packsize() - s:rep() + s:rep(2) s:reverse() - s:sub() - s:unpack() + s:sub(2) + s:unpack("b") s:upper() ]])) end) diff --git a/spec/subtyping/any_spec.lua b/spec/subtyping/any_spec.lua index 06f1be10a..1467e362b 100644 --- a/spec/subtyping/any_spec.lua +++ b/spec/subtyping/any_spec.lua @@ -44,7 +44,7 @@ describe("subtyping of any:", function() it("thread <: any", util.check([[ local a: any - a = coroutine.create() + a = coroutine.create(function() end) ]])) it("poly <: any", util.check([[ diff --git a/spec/subtyping/nil_spec.lua b/spec/subtyping/nil_spec.lua index bc8e368de..d79123f08 100644 --- a/spec/subtyping/nil_spec.lua +++ b/spec/subtyping/nil_spec.lua @@ -41,7 +41,7 @@ describe("subtyping of nil:", function() ]])) it("nil <: thread", util.check([[ - local c = coroutine.create() + local c = coroutine.create(function() end) c = nil ]])) From 44bdf19f7593dde0bf2ce94db3b657681d14ace5 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Fri, 29 Jan 2021 02:33:13 -0300 Subject: [PATCH 027/224] optional arity: add some optionals --- spec/call/record_method_spec.lua | 2 +- spec/declaration/record_spec.lua | 4 ++-- tl.tl | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/spec/call/record_method_spec.lua b/spec/call/record_method_spec.lua index 3e8713b04..0e91056c1 100644 --- a/spec/call/record_method_spec.lua +++ b/spec/call/record_method_spec.lua @@ -17,7 +17,7 @@ describe("record method call", function() ]])) it("method call with different call forms", util.check([[ - local foo = {bar = function(x: any, t: any) end} + local foo = {bar = function(x: any, t?: any) end} print(foo:bar()) print(foo:bar{}) print(foo:bar"hello") diff --git a/spec/declaration/record_spec.lua b/spec/declaration/record_spec.lua index 067d94d9e..7083f95d7 100644 --- a/spec/declaration/record_spec.lua +++ b/spec/declaration/record_spec.lua @@ -133,8 +133,8 @@ for i, name in ipairs({"records", "arrayrecords", "interfaces", "arrayinterfaces it("can overload functions", util.check([[ global type love_graphics = ]]..statement..[[ ]]..array(i, "{love_graphics}")..[[ - print: function(text: string, x: number, y: number, r: number, sx: number, sy: number, ox: number, oy: number, kx: number, ky:number) - print: function(coloredtext: {any}, x: number, y: number, r: number, sx: number, sy: number, ox: number, oy: number, kx: number, ky:number) + print: function(text: string, x: number, y: number, r?: number, sx?: number, sy?: number, ox?: number, oy?: number, kx?: number, ky?: number) + print: function(coloredtext: {any}, x: number, y: number, r?: number, sx?: number, sy?: number, ox?: number, oy?: number, kx?: number, ky?: number) end global type love = ]]..statement..[[ ]]..array(i, "{love}")..[[ diff --git a/tl.tl b/tl.tl index 804aa0f3d..53f99886a 100644 --- a/tl.tl +++ b/tl.tl @@ -1754,7 +1754,7 @@ local function parse_table_literal(ps: ParseState, i: integer): integer, Node return parse_bracket_list(ps, i, node, "{", "}", "term", parse_table_item) end -local function parse_trying_list(ps: ParseState, i: integer, list: {T}, parse_item: ParseItem, ret_lookahead: boolean): integer, {T} +local function parse_trying_list(ps: ParseState, i: integer, list: {T}, parse_item: ParseItem, ret_lookahead?: boolean): integer, {T} local try_ps: ParseState = { filename = ps.filename, tokens = ps.tokens, From 7f3fd80686b4db45f4100abaaef67c3481af0af4 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Wed, 27 Jan 2021 00:38:02 -0300 Subject: [PATCH 028/224] optional arity: enable optional arity checks! --- spec/call/function_spec.lua | 66 +++++++++++++++++++++++++++++++++++++ tl.lua | 30 ++++++++--------- tl.tl | 32 +++++++++--------- 3 files changed, 97 insertions(+), 31 deletions(-) diff --git a/spec/call/function_spec.lua b/spec/call/function_spec.lua index 31f26b9e5..1aaab0d0f 100644 --- a/spec/call/function_spec.lua +++ b/spec/call/function_spec.lua @@ -32,4 +32,70 @@ describe("function calls", function() return foo("hi there", select(2, ...)) end ]])) + + describe("check the arity of functions:", function() + it("when excessive", util.check_type_error([[ + local function f(n: number, m: number): number + return n + m + end + + local x = f(1, 2, 3) + ]], { + { y = 5, msg = "wrong number of arguments (given 3, expects 2)" }, + })) + + it("when insufficient", util.check_type_error([[ + local function f(n: number, m: number): number + return n + m + end + + local x = f(1) + ]], { + { y = 5, msg = "wrong number of arguments (given 1, expects 2)" }, + })) + + it("when using optional", util.check([[ + local function f(n: number, m?: number): number + return n + (m or 0) + end + + local x = f(1) + ]])) + + it("when insufficient with optionals", util.check_type_error([[ + local function f(n: number, m?: number): number + return n + (m or 0) + end + + local x = f() + ]], { + { y = 5, msg = "wrong number of arguments (given 0, expects at least 1 and at most 2)" }, + })) + + it("when using all optionals", util.check([[ + local function f(n?: number, m?: number): number + return (n or 0) + (m or 0) + end + + local x = f() + ]])) + + it("when using all optionals", util.check([[ + local function f(n?: number, m?: number): number + return (n or 0) + (m or 0) + end + + local x = f(1, 2) + ]])) + + it("when excessive with optionals", util.check_type_error([[ + local function f(n: number, m?: number): number + return (n or 0) + (m or 0) + end + + local x = f(1, 2, 3) + ]], { + { y = 5, msg = "wrong number of arguments (given 3, expects at least 1 and at most 2)" }, + })) + end) end) diff --git a/tl.lua b/tl.lua index 40b1afc51..452a7cc51 100644 --- a/tl.lua +++ b/tl.lua @@ -2835,7 +2835,7 @@ local function parse_nested_type(ps, i, def, typename, parse_body) local nt = new_node(ps.tokens, i - 2, "newtype") nt.newtype = new_type(ps, i, "typetype") local rdef = new_type(ps, i, typename) - local iok = parse_body(ps, i, rdef, nt, v.tk) + local iok = parse_body(ps, i, rdef, nt) if iok then i = iok nt.newtype.def = rdef @@ -3024,7 +3024,7 @@ parse_record_body = function(ps, i, def, node, name) end i = verify_tk(ps, i, "=") local nt - i, nt = parse_newtype(ps, i, v.tk) + i, nt = parse_newtype(ps, i) if not nt or not nt.newtype then return fail(ps, i, "expected a type definition") end @@ -3125,14 +3125,14 @@ parse_type_body_fns = { ["enum"] = parse_enum_body, } -parse_newtype = function(ps, i, name) +parse_newtype = function(ps, i) local node = new_node(ps.tokens, i, "newtype") node.newtype = new_type(ps, i, "typetype") local tn = ps.tokens[i].tk if parse_type_body_fns[tn] then local def = new_type(ps, i, tn) i = i + 1 - i = parse_type_body_fns[tn](ps, i, def, node, name) + i = parse_type_body_fns[tn](ps, i, def, node) node.newtype.def = def return i, node else @@ -3269,7 +3269,7 @@ local function parse_type_declaration(ps, i, node_name) return i, asgn end - i, asgn.value = parse_newtype(ps, i, asgn.var.tk) + i, asgn.value = parse_newtype(ps, i) if not asgn.value then return i end @@ -3296,7 +3296,7 @@ local function parse_type_constructor(ps, i, node_name, type_name, parse_body) end nt.newtype.def.names = { asgn.var.tk } - i = parse_body(ps, i, def, nt, asgn.var.tk) + i = parse_body(ps, i, def, nt) return i, asgn end @@ -6352,6 +6352,12 @@ tl.type_check = function(ast, opts) f.min_arity = n end + local function show_arity(f) + return f.min_arity < #f.args and + "at least " .. f.min_arity .. (f.args.is_va and "" or " and at most " .. #f.args) or + tostring(#f.args or 0) + end + local function resolve_typetype(t) if is_typetype(t) then return t.def @@ -8065,12 +8071,6 @@ tl.type_check = function(ast, opts) end end - local function show_arity(f) - return f.min_arity < #f.args and - "at least " .. f.min_arity or - tostring(#f.args or 0) - end - local function fail_call(node, func, nargs, errs) if errs then @@ -8150,11 +8150,11 @@ tl.type_check = function(ast, opts) set_min_arity(f) - if (is_func and (given <= expected or (f.args.is_va and given > expected))) or + if (is_func and ((given <= expected and given >= f.min_arity) or (f.args.is_va and given > expected) or (lax and given <= expected))) or (is_poly and ((pass == 1 and given == expected) or - (pass == 2 and given < expected) or + (pass == 2 and given < expected and (lax or given >= f.min_arity)) or (pass == 3 and f.args.is_va and given > expected))) then @@ -8230,7 +8230,7 @@ tl.type_check = function(ast, opts) where_args[2] = node.e2 args[2] = orig_b end - return resolve_tuple_and_nominal(type_check_function_call(node, where_args, metamethod, args, nil, true)), meta_on_operator + return resolve_tuple_and_nominal((type_check_function_call(node, where_args, metamethod, args, nil, true))), meta_on_operator else return nil, nil end diff --git a/tl.tl b/tl.tl index 53f99886a..2163ce089 100644 --- a/tl.tl +++ b/tl.tl @@ -1479,7 +1479,7 @@ local parse_statements: function(ParseState, integer, ? boolean): integer, Node local parse_argument_list: function(ParseState, integer): integer, Node local parse_argument_type_list: function(ParseState, integer): integer, Type local parse_type: function(ParseState, integer): integer, Type, integer -local parse_newtype: function(ps: ParseState, i: integer, name: string): integer, Node +local parse_newtype: function(ps: ParseState, i: integer): integer, Node local type ParseBody = function(ps: ParseState, i: integer, def: Type, node: Node, name?: string): integer, Node local parse_enum_body: function(ps: ParseState, i: integer, def: Type, node: Node): integer, Node @@ -2835,7 +2835,7 @@ local function parse_nested_type(ps: ParseState, i: integer, def: Type, typename local nt: Node = new_node(ps.tokens, i - 2, "newtype") nt.newtype = new_type(ps, i, "typetype") local rdef = new_type(ps, i, typename) - local iok = parse_body(ps, i, rdef, nt, v.tk) + local iok = parse_body(ps, i, rdef, nt) if iok then i = iok nt.newtype.def = rdef @@ -3024,7 +3024,7 @@ parse_record_body = function(ps: ParseState, i: integer, def: Type, node: Node, end i = verify_tk(ps, i, "=") local nt: Node - i, nt = parse_newtype(ps, i, v.tk) + i, nt = parse_newtype(ps, i) if not nt or not nt.newtype then return fail(ps, i, "expected a type definition") end @@ -3125,14 +3125,14 @@ parse_type_body_fns = { ["enum"] = parse_enum_body, } -parse_newtype = function(ps: ParseState, i: integer, name: string): integer, Node +parse_newtype = function(ps: ParseState, i: integer): integer, Node local node: Node = new_node(ps.tokens, i, "newtype") node.newtype = new_type(ps, i, "typetype") local tn = ps.tokens[i].tk as TypeName if parse_type_body_fns[tn] then local def = new_type(ps, i, tn) i = i + 1 - i = parse_type_body_fns[tn](ps, i, def, node, name) + i = parse_type_body_fns[tn](ps, i, def, node) node.newtype.def = def return i, node else @@ -3269,7 +3269,7 @@ local function parse_type_declaration(ps: ParseState, i: integer, node_name: Nod return i, asgn end - i, asgn.value = parse_newtype(ps, i, asgn.var.tk) + i, asgn.value = parse_newtype(ps, i) if not asgn.value then return i end @@ -3296,7 +3296,7 @@ local function parse_type_constructor(ps: ParseState, i: integer, node_name: Nod end nt.newtype.def.names = { asgn.var.tk } - i = parse_body(ps, i, def, nt, asgn.var.tk) + i = parse_body(ps, i, def, nt) return i, asgn end @@ -6352,6 +6352,12 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string f.min_arity = n end + local function show_arity(f: Type): string + return f.min_arity < #f.args + and "at least " .. f.min_arity .. (f.args.is_va and "" or " and at most " .. #f.args) + or tostring(#f.args or 0) + end + local function resolve_typetype(t: Type): Type if is_typetype(t) then return t.def @@ -8065,12 +8071,6 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - local function show_arity(f: Type): string - return f.min_arity < #f.args - and "at least " .. f.min_arity - or tostring(#f.args or 0) - end - local function fail_call(node: Node, func: Type, nargs: integer, errs: {Error}): Type if errs then -- report the errors from the first match @@ -8150,11 +8150,11 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string set_min_arity(f) -- simple functions: - if (is_func and (given <= expected or (f.args.is_va and given > expected))) + if (is_func and ((given <= expected and given >= f.min_arity) or (f.args.is_va and given > expected) or (lax and given <= expected))) -- poly, pass 1: try exact arity matches first or (is_poly and ((pass == 1 and given == expected) -- poly, pass 2: then try adjusting with nils to missing arguments or using '...' - or (pass == 2 and given < expected) + or (pass == 2 and given < expected and (lax or given >= f.min_arity)) -- poly, pass 3: then finally try vararg functions or (pass == 3 and f.args.is_va and given > expected))) then @@ -8230,7 +8230,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string where_args[2] = node.e2 args[2] = orig_b end - return resolve_tuple_and_nominal(type_check_function_call(node, where_args, metamethod, args, nil, true)), meta_on_operator + return resolve_tuple_and_nominal((type_check_function_call(node, where_args, metamethod, args, nil, true))), meta_on_operator else return nil, nil end From 3583627a61f24eaeaba50ced9fda7fd1eb99957a Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Thu, 23 Nov 2023 18:48:28 -0300 Subject: [PATCH 029/224] refactor: move method downgrade to type-checking time --- tl.lua | 38 +++++++++++++++++++++----------------- tl.tl | 42 +++++++++++++++++++++++------------------- 2 files changed, 44 insertions(+), 36 deletions(-) diff --git a/tl.lua b/tl.lua index 452a7cc51..13014d197 100644 --- a/tl.lua +++ b/tl.lua @@ -2948,7 +2948,7 @@ local function parse_array_interface_type(ps, i, def) return i, t end -parse_record_body = function(ps, i, def, node, name) +parse_record_body = function(ps, i, def, node) local istart = i - 1 def.fields = {} def.field_order = {} @@ -3077,22 +3077,6 @@ parse_record_body = function(ps, i, def, node, name) end end - if t.is_method and t.args and t.args[1] and t.args[1].is_self then - local selfarg = t.args[1] - if selfarg.tk ~= name or (def.typeargs and not selfarg.typevals) then - t.is_method = false - selfarg.is_self = false - elseif def.typeargs then - for j = 1, #def.typeargs do - if (not selfarg.typevals[j]) or selfarg.typevals[j].tk ~= def.typeargs[j].typearg then - t.is_method = false - selfarg.is_self = false - break - end - end - end - end - if ps.tokens[i].tk == "=" and ps.tokens[i + 1].tk == "macroexp" then if t.typename ~= "function" then fail(ps, i + 1, "macroexp must have a function type") @@ -11045,6 +11029,26 @@ tl.type_check = function(ast, opts) if ftype.typename == "nestedtype" then ftype.typename = "typetype" end + + if ftype.is_method and ftype.args and ftype.args[1] and ftype.args[1].is_self then + local record_name = typ.names and typ.names[1] + if record_name then + local selfarg = ftype.args[1] + if selfarg.tk ~= record_name or (typ.typeargs and not selfarg.typevals) then + ftype.is_method = false + selfarg.is_self = false + elseif typ.typeargs then + for j = 1, #typ.typeargs do + if (not selfarg.typevals[j]) or selfarg.typevals[j].tk ~= typ.typeargs[j].typearg then + ftype.is_method = false + selfarg.is_self = false + break + end + end + end + end + end + typ.fields[name] = ftype i = i + 1 end diff --git a/tl.tl b/tl.tl index 2163ce089..1144b8251 100644 --- a/tl.tl +++ b/tl.tl @@ -1481,9 +1481,9 @@ local parse_argument_type_list: function(ParseState, integer): integer, Type local parse_type: function(ParseState, integer): integer, Type, integer local parse_newtype: function(ps: ParseState, i: integer): integer, Node -local type ParseBody = function(ps: ParseState, i: integer, def: Type, node: Node, name?: string): integer, Node +local type ParseBody = function(ps: ParseState, i: integer, def: Type, node: Node): integer, Node local parse_enum_body: function(ps: ParseState, i: integer, def: Type, node: Node): integer, Node -local parse_record_body: function(ps: ParseState, i: integer, def: Type, node: Node, name?: string): integer, Node +local parse_record_body: function(ps: ParseState, i: integer, def: Type, node: Node): integer, Node local parse_type_body_fns: {TypeName:ParseBody} local function fail(ps: ParseState, i: integer, msg: string): integer @@ -2948,7 +2948,7 @@ local function parse_array_interface_type(ps: ParseState, i: integer, def: Type) return i, t end -parse_record_body = function(ps: ParseState, i: integer, def: Type, node: Node, name: string): integer, Node +parse_record_body = function(ps: ParseState, i: integer, def: Type, node: Node): integer, Node local istart = i - 1 def.fields = {} def.field_order = {} @@ -3077,22 +3077,6 @@ parse_record_body = function(ps: ParseState, i: integer, def: Type, node: Node, end end - if t.is_method and t.args and t.args[1] and t.args[1].is_self then - local selfarg = t.args[1] - if selfarg.tk ~= name or (def.typeargs and not selfarg.typevals) then - t.is_method = false - selfarg.is_self = false - elseif def.typeargs then - for j=1,#def.typeargs do - if (not selfarg.typevals[j]) or selfarg.typevals[j].tk ~= def.typeargs[j].typearg then - t.is_method = false - selfarg.is_self = false - break - end - end - end - end - if ps.tokens[i].tk == "=" and ps.tokens[i + 1].tk == "macroexp" then if t.typename ~= "function" then fail(ps, i + 1, "macroexp must have a function type") @@ -11045,6 +11029,26 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if ftype.typename == "nestedtype" then ftype.typename = "typetype" end + + if ftype.is_method and ftype.args and ftype.args[1] and ftype.args[1].is_self then + local record_name = typ.names and typ.names[1] + if record_name then + local selfarg = ftype.args[1] + if selfarg.tk ~= record_name or (typ.typeargs and not selfarg.typevals) then + ftype.is_method = false + selfarg.is_self = false + elseif typ.typeargs then + for j=1,#typ.typeargs do + if (not selfarg.typevals[j]) or selfarg.typevals[j].tk ~= typ.typeargs[j].typearg then + ftype.is_method = false + selfarg.is_self = false + break + end + end + end + end + end + typ.fields[name] = ftype i = i + 1 end From 3b25cdbb3b649b53f8c17f641c3975c44003b664 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Mon, 27 Nov 2023 17:37:20 -0300 Subject: [PATCH 030/224] refactor: assert_is_a uses Where, not Node --- tl.lua | 12 ++++++------ tl.tl | 12 ++++++------ 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/tl.lua b/tl.lua index 13014d197..0500bb921 100644 --- a/tl.lua +++ b/tl.lua @@ -7723,7 +7723,7 @@ tl.type_check = function(ast, opts) return false, { Err(t1, "got %s, expected %s", t1, t2) } end - local function assert_is_a(node, t1, t2, context, name) + local function assert_is_a(where, t1, t2, context, name) t1 = resolve_tuple(t1) t2 = resolve_tuple(t2) if lax and (is_unknown(t1) or is_unknown(t2)) then @@ -7735,23 +7735,23 @@ tl.type_check = function(ast, opts) return true elseif t2.typename == "unresolved_emptytable_value" then if is_number_type(t2.emptytable_type.keys) then - infer_emptytable(t2.emptytable_type, infer_at(node, a_type({ typename = "array", elements = t1 }))) + infer_emptytable(t2.emptytable_type, infer_at(where, a_type({ typename = "array", elements = t1 }))) else - infer_emptytable(t2.emptytable_type, infer_at(node, a_type({ typename = "map", keys = t2.emptytable_type.keys, values = t1 }))) + infer_emptytable(t2.emptytable_type, infer_at(where, a_type({ typename = "map", keys = t2.emptytable_type.keys, values = t1 }))) end return true elseif t2.typename == "emptytable" then if is_lua_table_type(t1) then - infer_emptytable(t2, infer_at(node, t1)) + infer_emptytable(t2, infer_at(where, t1)) elseif t1.typename ~= "emptytable" then - node_error(node, context .. ": " .. (name and (name .. ": ") or "") .. "assigning %s to a variable declared with {}", t1) + error_at(where, context .. ": " .. (name and (name .. ": ") or "") .. "assigning %s to a variable declared with {}", t1) return false end return true end local ok, match_errs = is_a(t1, t2) - add_errs_prefixing(node, match_errs, errors, context .. ": " .. (name and (name .. ": ") or "")) + add_errs_prefixing(where, match_errs, errors, context .. ": " .. (name and (name .. ": ") or "")) return ok end diff --git a/tl.tl b/tl.tl index 1144b8251..05c5f7ccc 100644 --- a/tl.tl +++ b/tl.tl @@ -7723,7 +7723,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return false, { Err(t1, "got %s, expected %s", t1, t2) } end - local function assert_is_a(node: Node, t1: Type, t2: Type, context: string, name?: string): boolean + local function assert_is_a(where: Where, t1: Type, t2: Type, context: string, name?: string): boolean t1 = resolve_tuple(t1) t2 = resolve_tuple(t2) if lax and (is_unknown(t1) or is_unknown(t2)) then @@ -7735,23 +7735,23 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return true elseif t2.typename == "unresolved_emptytable_value" then if is_number_type(t2.emptytable_type.keys) then -- ideally integer only - infer_emptytable(t2.emptytable_type, infer_at(node, a_type { typename = "array", elements = t1 })) + infer_emptytable(t2.emptytable_type, infer_at(where, a_type { typename = "array", elements = t1 })) else - infer_emptytable(t2.emptytable_type, infer_at(node, a_type { typename = "map", keys = t2.emptytable_type.keys, values = t1 })) + infer_emptytable(t2.emptytable_type, infer_at(where, a_type { typename = "map", keys = t2.emptytable_type.keys, values = t1 })) end return true elseif t2.typename == "emptytable" then if is_lua_table_type(t1) then - infer_emptytable(t2, infer_at(node, t1)) + infer_emptytable(t2, infer_at(where, t1)) elseif t1.typename ~= "emptytable" then - node_error(node, context .. ": " .. (name and (name .. ": ") or "") .. "assigning %s to a variable declared with {}", t1) + error_at(where, context .. ": " .. (name and (name .. ": ") or "") .. "assigning %s to a variable declared with {}", t1) return false end return true end local ok, match_errs = is_a(t1, t2) - add_errs_prefixing(node, match_errs, errors, context .. ": ".. (name and (name .. ": ") or "")) + add_errs_prefixing(where, match_errs, errors, context .. ": ".. (name and (name .. ": ") or "")) return ok end From 3a81bd0cd2395410e8184cf056412aed7b03a114 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Thu, 30 Nov 2023 03:13:33 -0300 Subject: [PATCH 031/224] remove "tl build" --- docs/build_file.md | 52 --- docs/compiler_options.md | 103 +----- spec/cli/build_dir_spec.lua | 66 ---- spec/cli/build_script_spec.lua | 312 ------------------ spec/cli/build_spec.lua | 39 --- spec/cli/source_dir_spec.lua | 87 ----- spec/config/files_spec.lua | 19 -- spec/config/glob_spec.lua | 300 ----------------- spec/config/interactions_spec.lua | 161 ---------- spec/config/type_check_spec.lua | 28 -- tl | 517 +----------------------------- tl-dev-1.rockspec | 4 - 12 files changed, 20 insertions(+), 1668 deletions(-) delete mode 100644 docs/build_file.md delete mode 100644 spec/cli/build_dir_spec.lua delete mode 100644 spec/cli/build_script_spec.lua delete mode 100644 spec/cli/build_spec.lua delete mode 100644 spec/cli/source_dir_spec.lua delete mode 100644 spec/config/files_spec.lua delete mode 100644 spec/config/glob_spec.lua delete mode 100644 spec/config/interactions_spec.lua delete mode 100644 spec/config/type_check_spec.lua diff --git a/docs/build_file.md b/docs/build_file.md deleted file mode 100644 index 9a93ded57..000000000 --- a/docs/build_file.md +++ /dev/null @@ -1,52 +0,0 @@ -# build.tl file -Teal has the ability to run a custom function at compile time. The file containing this function is by default called `build.tl` but the name can be changed by setting `build_file` in `tlconfig.lua`. - -This function can be used to automatically generate code/types without the need for external build tools like `make`. - -The results of executing this function are cached, and the cache is automatically made invalid if the build script changed since last execution. -## Layout - -A build.tl file needs at least the following layout: -```lua -return { - gen_code = function(path:string) - - end -} -``` -`gen_code` is the function that will get executed and `path` is the base path where it should store generated teal files. - -More keys are planned in the future, which is why the file returns a table rather than it being executed directly. - -## Output location - -The teal files get stored in a temporary directory cleaned up after compilation (`/tmp` on Unix, `%TEMP%` on Windows). The generated teal files will get compiled to lua as normal and will be part of the build output. - -You can configure where the lua files will be saved by setting `build_file_output_dir` in `tlconfig.lua`. This uses the directory set by `build_dir` as a base. The default value is `generated_code`. - -## Use case - -As mentioned earlier, this file can be used to generate types without the need for `make` or other external build tools. A reason why you might is if your teal code consumes an API that has schemas available. - -Then you could simply add these schemas to your repo and have the `build.tl` file create types based on these schemas for you. That way you only need to grab a new version of the schemas if they change. - -Another use case could be when using a teal version of a library like [pgtyped](https://github.com/adelsz/pgtyped) where you normally need to run a command manually to generate the types and code. Now you can just stick that in the `build.tl` file and forget about it. - -## Limitations - -Right now the `build.tl` file is mostly useful for programs and less for libraries. This is because `teal` does not have its own package manager able to run the `build.tl` files from required dependencies. - -## Example - -```lua -return { - gen_code = function(path:string) - local file = io.open(path .. "/generated.tl", "w") - file:write([[ -function add(a : number, b : number): number - return a + b -end -]]) - end -} -``` diff --git a/docs/compiler_options.md b/docs/compiler_options.md index 15b92698c..0f94e28fa 100644 --- a/docs/compiler_options.md +++ b/docs/compiler_options.md @@ -18,25 +18,17 @@ return { ## List of compiler options -| Command line option | Config key | Type | Relevant Commands | Description | -| -------------------- | -------------------------- | ---------- | --------------------------- | ----------- | -| `-l --require` | | `{string}` | `run` | Require a module prior to executing the script. This is similar in behavior to the `-l` flag in the Lua interpreter. | -| `-I --include-dir` | `include_dir` | `{string}` | `build` `check` `gen` `run` | Prepend this directory to the module search path. -| `--gen-compat` | `gen_compat` | `string` | `build` `gen` `run` | Generate compatibility code for targeting different Lua VM versions. See [below](#generated-code) for details. -| `--gen-target` | `gen_target` | `string` | `build` `gen` `run` | Minimum targeted Lua version for generated code. Options are `5.1`, `5.3` and `5.4`. See [below](#generated-code) for details. -| | `include` | `{string}` | `build` | The set of files to compile/check. See below for details on patterns. -| | `exclude` | `{string}` | `build` | The set of files to exclude. See below for details on patterns. -| `--keep-hashbang` | | | `gen` | Preserve hashbang line (`#!`) at the top of file if present. -| `-s --source-dir` | `source_dir` | `string` | `build` | Set the directory to be searched for files. `build` will compile every .tl file in every subdirectory by default. -| `-b --build-dir` | `build_dir` | `string` | `build` | Set the directory for generated files, mimicking the file structure of the source files. -| | `files` | `{string}` | `build` | The names of files to be compiled. Does not accept patterns like `include`. -| `-p --pretend` | | | `build` `gen` | Don't compile/write to any files, but type check and log what files would be written to. -| `--wdisable` | `disable_warnings` | `{string}` | `build` `check` `run` | Disable the given warnings. -| `--werror` | `warning_error` | `{string}` | `build` `check` `run` | Promote the given warnings to errors. -| `--run-build-script` | `run_build_script` | `boolean` | `run` `check` `gen` | Runs the build script as if `tl build` was being run -| | `build_file_output_dir` | `string` | `run` `check` `gen` `build` | Folder where the generated files from the build script will be accessible in -| | `internal_compiler_output` | `string` | `run` `check` `gen` `build` | Folder to store cache files for use by the compiler -| `--global-env-def` | `global_env_def` | `string` | `build` `check` `gen` `run` | Specify a definition module declaring any custom globals predefined in your Lua environment. See the [declaration files](declaration_files.md#global-environment-definition) page for details. | +| Command line option | Config key | Type | Relevant Commands | Description | +| -------------------- | -------------------------- | ---------- | -------------------- | ----------- | +| `-l --require` | | `{string}` | `run` | Require a module prior to executing the script. This is similar in behavior to the `-l` flag in the Lua interpreter. | +| `-I --include-dir` | `include_dir` | `{string}` | `check` `gen` `run` | Prepend this directory to the module search path. +| `--gen-compat` | `gen_compat` | `string` | `gen` `run` | Generate compatibility code for targeting different Lua VM versions. See [below](#generated-code) for details. +| `--gen-target` | `gen_target` | `string` | `gen` `run` | Minimum targeted Lua version for generated code. Options are `5.1`, `5.3` and `5.4`. See [below](#generated-code) for details. +| `--keep-hashbang` | | | `gen` | Preserve hashbang line (`#!`) at the top of file if present. +| `-p --pretend` | | | `gen` | Don't compile/write to any files, but type check and log what files would be written to. +| `--wdisable` | `disable_warnings` | `{string}` | `check` `run` | Disable the given warnings. +| `--werror` | `warning_error` | `{string}` | `check` `run` | Promote the given warnings to errors. +| `--global-env-def` | `global_env_def` | `string` | `check` `gen` `run` | Specify a definition module declaring any custom globals predefined in your Lua environment. See the [declaration files](declaration_files.md#global-environment-definition) page for details. | ### Generated code @@ -138,76 +130,3 @@ you may pass a declaration module to the compiler using the `--global-env-def` f in the CLI or the `global_env_def` string in `tlconfig.lua`. For more information, see the [declaration files](declaration_files.md#global-environment-definition) page. - -### Include/Exclude patterns - -The `include` and `exclude` fields can have glob-like patterns in them: -- `*`: Matches any number of characters (excluding directory separators) -- `**/`: Matches any number subdirectories - -In addition -- setting the `source_dir` has the effect of prepending `source_dir` to all patterns. -- currently, `include` will only include `.tl` files even if the extension isn't specified - -For example: -If our project was laid out as such: -``` -tlconfig.lua -src/ -| foo/ -| | bar.tl -| | baz.tl -| bar/ -| | a/ -| | | foo.tl -| | b/ -| | | foo.tl -``` - -and our tlconfig.lua contained the following: -```lua -return { - source_dir = "src", - build_dir = "build", - include = { - "foo/*.tl", - "bar/**/*.tl" - }, - exclude = { - "foo/bar.tl" - } -} -``` - -Running `tl build -p` will type check the `include`d files and show what would be written to. -Running `tl build` will produce the following files. -``` -tlconfig.lua -src/ -| foo/ -| | bar.tl -| | baz.tl -| bar/ -| | a/ -| | | foo.tl -| | b/ -| | | foo.tl -build/ -| foo/ -| | baz.lua -| bar/ -| | a/ -| | | foo.lua -| | b/ -| | | foo.lua -``` - -Additionally, complex patterns can be used for whatever convoluted file structure we need. -```lua -return { - include = { - "foo/**/bar/**/baz/**/*.tl" - } -} -``` -This will compile any `.tl` file with a sequential `foo`, `bar`, and `baz` directory in its path. diff --git a/spec/cli/build_dir_spec.lua b/spec/cli/build_dir_spec.lua deleted file mode 100644 index 8e356f553..000000000 --- a/spec/cli/build_dir_spec.lua +++ /dev/null @@ -1,66 +0,0 @@ -local util = require("spec.util") - -describe("-b --build-dir argument", function() - it("generates files in the given directory", function() - util.run_mock_project(finally, { - dir_structure = { - ["tlconfig.lua"] = [[return { - build_dir = "build", - include = { - "foo.tl", "bar.tl" - }, - }]], - ["foo.tl"] = [[print "foo"]], - ["bar.tl"] = [[print "bar"]], - }, - cmd = "build", - generated_files = { - ["build"] = { - "foo.lua", - "bar.lua", - } - }, - }) - end) - it("replicates the directory structure of the source", function() - util.run_mock_project(finally, { - dir_structure = { - ["tlconfig.lua"] = [[return { - build_dir = "build", - include = {]] .. util.os_path('"**/*.tl"') .. [[} - }]], - ["foo.tl"] = [[print "foo"]], - ["bar.tl"] = [[print "bar"]], - ["baz"] = { - ["foo.tl"] = [[print "foo"]], - ["bar"] = { - ["foo.tl"] = [[print "foo"]], - } - } - }, - cmd = "build", - generated_files = { - ["build"] = { - "foo.lua", - "bar.lua", - ["baz"] = { - "foo.lua", - ["bar"] = { - "foo.lua", - } - } - } - }, - }) - end) - it("dies when no config is found", function() - util.run_mock_project(finally, { - dir_structure = {}, - cmd = "build", - generated_files = {}, - exit_code = 1, - cmd_output = "Build error: tlconfig.lua not found\n" - }) - end) - -end) diff --git a/spec/cli/build_script_spec.lua b/spec/cli/build_script_spec.lua deleted file mode 100644 index eb0bcc3b3..000000000 --- a/spec/cli/build_script_spec.lua +++ /dev/null @@ -1,312 +0,0 @@ -local util = require("spec.util") - -describe("build.tl", function() - it("defaults to the generated_code folder", function() - util.run_mock_project(finally, { - dir_structure = { - ["tlconfig.lua"] = [[return { - build_dir = "build", - }]], - ["foo.tl"] = [[print(require("generated_code/generated"))]], - ["build.tl"] = [[ - return { - gen_code = function(path:string) - local file = io.open(path .. "/generated.tl", "w") - file:write('return "Hello from script generated by build.tl"') - file:close() - end - } - ]], - ["bar.tl"] = [[print "bar"]], - }, - cmd = "build", - generated_files = { - ["build"] = { - "build.lua", - "foo.lua", - "bar.lua", - ["generated_code"] = { - "generated.lua" - } - }, - ["internal_compiler_output"] = { - ["build_script_output"] = { - ["generated_code"] = { - "generated.tl" - } - }, - "last_build_script_time" - } - }, - }) - end) - it("can have the location it stores altered by setting build_file_output_dir", function() - util.run_mock_project(finally, { - dir_structure = { - ["tlconfig.lua"] = [[return { - build_dir = "build", - build_file_output_dir = "other_generated_code" - }]], - ["foo.tl"] = [[print(require("other_generated_code/generated"))]], - ["build.tl"] = [[ - return { - gen_code = function(path:string) - local file = io.open(path .. "/generated.tl", "w") - file:write('return "Hello from script generated by build.tl"') - file:close() - end - }]], - ["bar.tl"] = [[print "bar"]], - }, - cmd = "build", - generated_files = { - ["build"] = { - "build.lua", - "foo.lua", - "bar.lua", - ["other_generated_code"] = { - "generated.lua" - } - }, - ["internal_compiler_output"] = { - ["build_script_output"] = { - ["other_generated_code"] = { - "generated.tl" - } - }, - "last_build_script_time" - } - }, - }) - end) - it("can have a diffrent name by setting build_file", function() - util.run_mock_project(finally, { - dir_structure = { - ["tlconfig.lua"] = [[return { - build_dir = "build", - build_file = "other_name.tl" - }]], - ["foo.tl"] = [[print(require("generated_code/generated"))]], - ["other_name.tl"] = [[ - return { - gen_code = function(path:string) - local file = io.open(path .. "/generated.tl", "w") - file:write('return "Hello from script generated by build.tl"') - file:close() - end - } - ]], - ["bar.tl"] = [[print "bar"]], - }, - cmd = "build", - generated_files = { - ["build"] = { - "other_name.lua", - "foo.lua", - "bar.lua", - ["generated_code"] = { - "generated.lua" - } - }, - ["internal_compiler_output"] = { - ["build_script_output"] = { - ["generated_code"] = { - "generated.tl" - } - }, - "last_build_script_time" - } - } - }) - end) - it("Can have the location for cached output files changed", function() - util.run_mock_project(finally, { - dir_structure = { - ["tlconfig.lua"] = [[return { - build_dir = "build", - internal_compiler_output = "this_other_folder_to_store_cached_items" - }]], - ["foo.tl"] = [[print(require("generated_code/generated"))]], - ["build.tl"] = [[ - return { - gen_code = function(path:string) - local file = io.open(path .. "/generated.tl", "w") - file:write('return "Hello from script generated by build.tl"') - file:close() - end - - } - ]], - ["bar.tl"] = [[print "bar"]], - }, - cmd = "build", - generated_files = { - ["build"] = { - "build.lua", - "foo.lua", - "bar.lua", - ["generated_code"] = { - "generated.lua" - } - }, - ["this_other_folder_to_store_cached_items"] = { - ["build_script_output"] = { - ["generated_code"] = { - "generated.tl" - } - }, - "last_build_script_time" - } - }, - }) - - end) - it("Should not run when running something else than build", function() - util.run_mock_project(finally, { - dir_structure = { - ["tlconfig.lua"] = [[return { - build_dir = "build", - build_file_output_dir = "generated_code", - internal_compiler_output = "this_other_folder_to_store_cached_items" - }]], - ["foo.tl"] = [[print("build.tl did not run")]], - ["build.tl"] = [[ - { - gen_code = function(path:string) - local file = io.open(path .. "/generated.tl", "w") - file:write('return "Hello from script generated by build.tl"') - file:close() - end - - } - ]], - }, - cmd = "run", - args = { - "foo.tl" - }, - cmd_output = "build.tl did not run\n" - }) - - end) - it("Should run when running something else than build and --run-build-script is passed", function() - util.run_mock_project(finally, { - dir_structure = { - ["tlconfig.lua"] = [[return { - build_dir = "build", - build_file_output_dir = "generated_code", - internal_compiler_output = "this_other_folder_to_store_cached_items" - }]], - ["foo.tl"] = [[local x =require("generated_code/generated") print(x)]], - ["build.tl"] = [[ - return { - gen_code = function(path:string) - local file = io.open(path .. "/generated.tl", "w") - file:write('return "Hello from script generated by build.tl"') - file:close() - end - - } - ]], - }, - cmd = "run", - pre_args = {"--run-build-script"}, - args = { - "foo.tl", - }, - cmd_output = "Hello from script generated by build.tl\n" - }) - - end) - - it("It should only run the build script if it changed since last time", function() - local path = util.write_tmp_dir(finally, { - ["tlconfig.lua"] = [[return { - build_dir = "build", - build_file_output_dir = "generated_code", - internal_compiler_output = "test" - }]], - ["build.tl"] = [[ - return { - gen_code = function(_:string) - - print("This text should appear only once") - end - - } - ]], - }) - util.run_mock_project(finally, { - cmd = "build", - cmd_output = [[ -This text should appear only once -Wrote: build/build.lua -]] - }, - path - ) - util.run_mock_project(finally, { - cmd = "build", - cmd_output = [[ -Wrote: build/build.lua -]] - }, - path - ) - - end) - - - it("Should give an error if the build script contains invalid teal", function() - util.run_mock_project(finally, { - dir_structure = { - ["tlconfig.lua"] = [[return { - build_dir = "build", - build_file_output_dir = "generated_code", - internal_compiler_output = "this_other_folder_to_store_cached_items" - }]], - ["foo.tl"] = [[print(require("generated_code/generated"))]], - ["build.tl"] = [[ - { - gen_code = function(path:string) - local file = io.open(path .. "/generated.tl", "w") - file:write('return "Hello from script generated by build.tl"') - file:close() - end - - } - ]], - ["bar.tl"] = [[print "bar"]], - }, - cmd = "build", - cmd_output = -[[======================================== -1 syntax error: -./build.tl:8:17: syntax error -]] - }) - end) - it("Should give an error if the key gen_code exists, but it is not a function", function() - util.run_mock_project(finally, { - dir_structure = { - ["tlconfig.lua"] = [[return { - build_dir = "build", - build_file_output_dir = "generated_code", - internal_compiler_output = "this_other_folder_to_store_cached_items" - }]], - ["foo.tl"] = [[print(require("generated_code/generated"))]], - ["build.tl"] = [[ - return { - gen_code = "I am a string" - - } - ]], - ["bar.tl"] = [[print "bar"]], - }, - cmd = "build", - cmd_output = -[[the key "gen_code" exists in the build file, but it is not a function. Value: I am a string -]] - }) - end) -end) diff --git a/spec/cli/build_spec.lua b/spec/cli/build_spec.lua deleted file mode 100644 index fac2f8d7c..000000000 --- a/spec/cli/build_spec.lua +++ /dev/null @@ -1,39 +0,0 @@ -local util = require("spec.util") - -describe("build command", function() - it("should exit with non zero exit code when there is an error", function() - util.run_mock_project(finally, { - dir_structure = { - ["tlconfig.lua"] = [[return {}]], - ["foo.tl"] = [[print "a"]], - ["bar.tl"] = [[local x: string = 10]], - }, - cmd = "build", - exit_code = 1, - }) - end) - - it("should not error when tlconfig returns nil/nothing", function() - util.run_mock_project(finally, { - dir_structure = { - ["tlconfig.lua"] = [[]], - }, - cmd = "build", - exit_code = 0, - }) - end) - - it("should find tlconfig.lua in a parent directory", function() - util.do_in(util.write_tmp_dir(finally, { - ["tlconfig.lua"] = [[ - return { - source_dir = "src" - } - ]], - src = {}, - }), function() - local ph = io.popen("cd src && " .. util.tl_cmd("build"), "r") - util.assert_popen_close(0, ph:close()) - end) - end) -end) diff --git a/spec/cli/source_dir_spec.lua b/spec/cli/source_dir_spec.lua deleted file mode 100644 index 01536e833..000000000 --- a/spec/cli/source_dir_spec.lua +++ /dev/null @@ -1,87 +0,0 @@ -local util = require("spec.util") - -describe("-s --source-dir argument", function() - it("recursively traverses the directory by default", function() - util.run_mock_project(finally, { - dir_structure = { - ["tlconfig.lua"] = [[return { source_dir = "src" }]], - ["src"] = { - ["foo.tl"] = [[print "foo"]], - ["bar.tl"] = [[print "bar"]], - foo = { - ["bar.tl"] = [[print "bar"]], - baz = { - ["foo.tl"] = [[print "baz"]], - } - } - } - }, - cmd = "build", - generated_files = { - ["src"] = { - "foo.lua", - "bar.lua", - foo = { - "bar.lua", - baz = { - "foo.lua" - } - } - } - }, - }) - end) - it("should die when the given directory doesn't exist", function() - util.run_mock_project(finally, { - dir_structure = { - ["tlconfig.lua"] = [[return {source_dir="src"}]], - ["foo.tl"] = [[print 'hi']], - }, - cmd = "build", - generated_files = {}, - cmd_output = "Build error: source_dir 'src' doesn't exist\n", - }) - end) - it("should not include files from other directories", function() - util.run_mock_project(finally, { - dir_structure = { - ["tlconfig.lua"] = [[return { - source_dir = "foo", - }]], - ["foo"] = { - ["a.tl"] = [[return "hey"]], - }, - ["bar"] = { - ["b.tl"] = [[return "hi"]], - }, - }, - cmd = "build", - generated_files = { - ["foo"] = { - "a.lua" - }, - }, - }) - end) - it("should correctly match directory names", function() - util.run_mock_project(finally, { - dir_structure = { - ["tlconfig.lua"] = [[return { - source_dir = "foo", - }]], - ["foo"] = { - ["a.tl"] = [[return "hey"]], - }, - ["foobar"] = { - ["b.tl"] = [[return "hi"]], - }, - }, - cmd = "build", - generated_files = { - ["foo"] = { - "a.lua" - }, - }, - }) - end) -end) diff --git a/spec/config/files_spec.lua b/spec/config/files_spec.lua deleted file mode 100644 index c23e14bdc..000000000 --- a/spec/config/files_spec.lua +++ /dev/null @@ -1,19 +0,0 @@ -local util = require("spec.util") - -describe("files config option", function() - it("should compile the given list of files", function() - util.run_mock_project(finally, { - dir_structure = { - ["tlconfig.lua"] = [[return { files = { "foo.tl", "bar.tl" } }]], - ["foo.tl"] = [[print "a"]], - ["bar.tl"] = [[print "b"]], - ["baz.tl"] = [[print "c"]], - }, - cmd = "build", - generated_files = { - "foo.lua", - "bar.lua", - } - }) - end) -end) diff --git a/spec/config/glob_spec.lua b/spec/config/glob_spec.lua deleted file mode 100644 index dcf6ac2e5..000000000 --- a/spec/config/glob_spec.lua +++ /dev/null @@ -1,300 +0,0 @@ -local util = require("spec.util") - -describe("globs", function() - describe("*", function() - it("should match non directory separators", function() - util.run_mock_project(finally, { - dir_structure = { - ["tlconfig.lua"] = [[return { include = {"*"} }]], - ["a.tl"] = [[print "a"]], - ["b.tl"] = [[print "b"]], - ["c.tl"] = [[print "c"]], - }, - cmd = "build", - generated_files = { - "a.lua", - "b.lua", - "c.lua", - }, - cmd_output = "Wrote: a.lua\nWrote: b.lua\nWrote: c.lua\n" - }) - end) - it("should match when other characters are present in the pattern", function() - util.run_mock_project(finally, { - dir_structure = { - ["tlconfig.lua"] = [[return { include = { "ab*cd.tl" } }]], - ["abzcd.tl"] = [[print "a"]], - ["abcd.tl"] = [[print "b"]], - ["abfoocd.tl"] = [[print "c"]], - ["abbarcd.tl"] = [[print "d"]], - ["abbar.tl"] = [[print "e"]], - ["barcd.tl"] = [[print "f"]], - }, - cmd = "build", - generated_files = { - "abbarcd.lua", - "abcd.lua", - "abfoocd.lua", - "abzcd.lua", - }, - cmd_output = "Wrote: abbarcd.lua\nWrote: abcd.lua\nWrote: abfoocd.lua\nWrote: abzcd.lua\n" - }) - end) - it("should only match .tl by default", function() - util.run_mock_project(finally, { - dir_structure = { - ["tlconfig.lua"] = [[return { include = { "*" } }]], - ["foo.tl"] = [[print "a"]], - ["foo.py"] = [[print("b")]], - ["foo.hs"] = [[main = print "c"]], - ["foo.sh"] = [[echo "d"]], - }, - cmd = "build", - generated_files = { - "foo.lua" - }, - }) - end) - it("should not match .d.tl files", function() - util.run_mock_project(finally, { - dir_structure = { - ["tlconfig.lua"] = [[return { include = { "*" } }]], - ["foo.tl"] = [[print "a"]], - ["bar.d.tl"] = [[local Point = record x: number y: number end return Point]], - }, - cmd = "build", - generated_files = { - "foo.lua" - }, - }) - end) - it("should match directories in the middle of a path", function() - util.run_mock_project(finally, { - dir_structure = { - ["tlconfig.lua"] = [[return { include = { ]] .. util.os_path('"foo/*/baz.tl"') .. [[ } }]], - ["foo"] = { - ["bar"] = { - ["foo.tl"] = [[print "a"]], - ["baz.tl"] = [[print "b"]], - }, - ["bingo"] = { - ["foo.tl"] = [[print "c"]], - ["baz.tl"] = [[print "d"]], - }, - ["bongo"] = { - ["foo.tl"] = [[print "e"]], - }, - } - }, - cmd = "build", - generated_files = { - ["foo"] = { - ["bar"] = { - "baz.lua" - }, - ["bingo"] = { - "baz.lua" - }, - }, - }, - }) - end) - end) - describe("**/", function() - it("should match the current directory", function() - util.run_mock_project(finally, { - dir_structure = { - ["tlconfig.lua"] = [[return { include = { ]] .. util.os_path('"**/*"') .. [[ } }]], - ["foo.tl"] = [[print "a"]], - ["bar.tl"] = [[print "b"]], - ["baz.tl"] = [[print "c"]], - }, - cmd = "build", - generated_files = { - "foo.lua", - "bar.lua", - "baz.lua", - }, - }) - end) - it("should match any subdirectory", function() - util.run_mock_project(finally, { - dir_structure = { - ["tlconfig.lua"] = [[return { include = { ]] .. util.os_path('"**/*"') .. [[ } }]], - ["foo"] = { - ["foo.tl"] = [[print "a"]], - ["bar.tl"] = [[print "b"]], - ["baz.tl"] = [[print "c"]], - }, - ["bar"] = { - ["foo.tl"] = [[print "a"]], - ["baz"] = { - ["bar.tl"] = [[print "b"]], - ["baz.tl"] = [[print "c"]], - } - }, - ["a"] = {a={a={a={a={a={["a.tl"]=[[global a = "a"]]}}}}}} - }, - cmd = "build", - generated_files = { - ["foo"] = { - "foo.lua", - "bar.lua", - "baz.lua", - }, - ["bar"] = { - "foo.lua", - ["baz"] = { - "bar.lua", - "baz.lua", - } - }, - ["a"] = {a={a={a={a={a={"a.lua"}}}}}}, - }, - }) - end) - it("should not get the order of directories confused", function() - util.run_mock_project(finally, { - dir_structure = { - ["tlconfig.lua"] = [[return { include = { ]] .. util.os_path('"foo/**/bar/**/baz/a.tl"') .. [[ } }]], - ["foo"] = { - ["bar"] = { - ["baz"] = { - ["a.tl"] = [[print "a"]], - }, - }, - }, - ["baz"] = { - ["bar"] = { - ["foo"] = { - ["a.tl"] = [[print "a"]], - }, - }, - }, - ["bar"] = { - ["baz"] = { - ["foo"] = { - ["a.tl"] = [[print "a"]], - }, - }, - }, - }, - cmd = "build", - generated_files = { - ["foo"] = { - ["bar"] = { - ["baz"] = { - "a.lua", - } - } - }, - }, - }) - end) - end) - describe("* and **/", function() - it("should work together", function() - util.run_mock_project(finally, { - dir_structure = { - ["tlconfig.lua"] = [[return { include = { ]] .. util.os_path('"**/foo/*/bar/**/*"') .. [[ } }]], - ["foo"] = { - ["a"] = { - ["bar"] = { - ["baz"] = { - ["a"] = {["b"] = {["c.tl"] = [[print "c"]]}}, - }, - ["bat"] = { - ["a"] = {["b.tl"] = [[print "b"]]}, - }, - }, - }, - ["b"] = { - ["bar"] = { - ["a"] = {["b.tl"] = [[print "b"]]}, - }, - }, - ["c"] = { - ["d"] = { - ["bar"] = { - ["a.tl"] = [[print "not included"]] - }, - }, - }, - }, - ["a"] = { - ["b"] = { - ["foo"] = { - ["a"] = { - ["bar"] = { - ["baz"] = { - ["a"] = {["b"] = {["c.tl"] = [[print "c"]]}}, - }, - ["bat"] = { - ["a"] = {["b.tl"] = [[print "b"]]}, - }, - }, - }, - ["b"] = { - ["bar"] = { - ["baz"] = { - ["a"] = {["b"] = {["c.tl"] = [[print "c"]]}}, - }, - ["bat"] = { - ["a"] = {["b.tl"] = [[print "b"]]}, - }, - }, - }, - }, - }, - }, - }, - cmd = "build", - generated_files = { - ["foo"] = { - ["a"] = { - ["bar"] = { - ["baz"] = { - ["a"] = {["b"] = {"c.lua"}}, - }, - ["bat"] = { - ["a"] = {"b.lua"}, - }, - }, - }, - ["b"] = { - ["bar"] = { - ["a"] = {"b.lua"}, - }, - }, - }, - ["a"] = { - ["b"] = { - ["foo"] = { - ["a"] = { - ["bar"] = { - ["baz"] = { - ["a"] = {["b"] = {"c.lua"}}, - }, - ["bat"] = { - ["a"] = {"b.lua"}, - }, - }, - }, - ["b"] = { - ["bar"] = { - ["baz"] = { - ["a"] = {["b"] = {"c.lua"}}, - }, - ["bat"] = { - ["a"] = {"b.lua"}, - }, - }, - }, - }, - }, - }, - }, - }) - end) - end) -end) diff --git a/spec/config/interactions_spec.lua b/spec/config/interactions_spec.lua deleted file mode 100644 index b1e72d742..000000000 --- a/spec/config/interactions_spec.lua +++ /dev/null @@ -1,161 +0,0 @@ -local util = require("spec.util") - -describe("config option interactions", function() - describe("include+exclude", function() - it("exclude should have precedence over include", function() - util.run_mock_project(finally, { - dir_structure = { - ["tlconfig.lua"] = [[return { - include = { - ]] .. util.os_path('"**/*"') .. [[, - }, - exclude = { - "*", - }, - }]], - -- should include any .tl file not in the top directory - ["foo.tl"] = [[print "hey"]], - ["bar.tl"] = [[print "hi"]], - baz = { - foo = { - ["bar.tl"] = [[print "h"]], - }, - bar = { - ["baz.tl"] = [[print "hello"]], - }, - }, - }, - cmd = "build", - generated_files = { - baz = { - foo = { "bar.lua" }, - bar = { "baz.lua" }, - }, - }, - }) - end) - end) - describe("source_dir+build_dir", function() - it("Having source_dir inside of build_dir works", function() - util.run_mock_project(finally, { - dir_structure = { - ["tlconfig.lua"] = [[return { - source_dir = ]] .. util.os_path('"foo/bar"') .. [[, - build_dir = "foo", - }]], - foo = { - bar = { - ["a.tl"] = [[print "a"]], - ["b.tl"] = [[print "b"]], - } - } - }, - cmd = "build", - generated_files = { - foo = { - "a.lua", - "b.lua", - } - }, - }) - end) - it("Having build_dir inside of source_dir works", function() - util.run_mock_project(finally, { - dir_structure = { - ["tlconfig.lua"] = [[return { - source_dir = "foo", - build_dir = "foo/bar", - }]], - foo = { - ["a.tl"] = [[print "a"]], - ["b.tl"] = [[print "b"]], - } - }, - cmd = "build", - generated_files = { - foo = { - bar = { - "a.lua", - "b.lua", - } - } - }, - }) - end) - end) - describe("source_dir+include+exclude", function() - it("nothing outside of source_dir is included", function() - util.run_mock_project(finally, { - dir_structure = { - ["tlconfig.lua"] = [[return { - source_dir = "src", - include = { - ]] .. util.os_path('"**/*"') .. [[ - }, - }]], - ["src"] = { - ["foo"] = { - ["bar"] = { - ["a.tl"] = [[print "a"]], - ["b.tl"] = [[print "b"]], - }, - ["a.tl"] = [[print "a"]], - ["b.tl"] = [[print "b"]], - }, - }, - ["a.tl"] = [[print "a"]], - ["b.tl"] = [[print "b"]], - }, - cmd = "build", - generated_files = { - ["src"] = { - ["foo"] = { - ["bar"] = { - "a.lua", - "b.lua", - }, - "a.lua", - "b.lua", - }, - }, - }, - }) - end) - it("include and exclude work as expected", function() - util.run_mock_project(finally, { - dir_structure = { - ["tlconfig.lua"] = [[return { - source_dir = ".", - include = { - ]] .. util.os_path('"foo/*.tl"') .. [[, - }, - exclude = { - ]] .. util.os_path('"foo/a*.tl"') .. [[, - }, - }]], - foo = { - ["a.tl"] = [[print 'a']], - ["ab.tl"] = [[print 'a']], - ["ac.tl"] = [[print 'a']], - ["b.tl"] = [[print 'b']], - ["bc.tl"] = [[print 'b']], - ["bd.tl"] = [[print 'b']], - }, - bar = { - ["c.tl"] = [[print 'c']], - ["cd.tl"] = [[print 'c']], - ["ce.tl"] = [[print 'c']], - }, - }, - cmd = "build", - generated_files = { - foo = { - "b.lua", - "bc.lua", - "bd.lua", - }, - }, - }) - end) - end) -end) diff --git a/spec/config/type_check_spec.lua b/spec/config/type_check_spec.lua deleted file mode 100644 index dd8f844a1..000000000 --- a/spec/config/type_check_spec.lua +++ /dev/null @@ -1,28 +0,0 @@ -local util = require("spec.util") - -describe("config type checking", function() - it("should error out when config.include is not a {string}", function() - util.run_mock_project(finally, { - dir_structure = { - ["tlconfig.lua"] = [[return { include = "*.tl" }]], - ["foo.tl"] = [[print "a"]], - }, - cmd = "build", - generated_files = {}, - exit_code = 1, - cmd_output = "Error loading tlconfig.lua:\n* in key \"include\": expected a {string}, got string\n", - }) - end) - it("should error out when config.source_dir is not a string", function() - util.run_mock_project(finally, { - dir_structure = { - ["tlconfig.lua"] = [[return { source_dir = true }]], - ["foo.tl"] = [[print "a"]], - }, - cmd = "build", - generated_files = {}, - exit_code = 1, - cmd_output = "Error loading tlconfig.lua:\n* in key \"source_dir\": expected a string, got boolean\n", - }) - end) -end) diff --git a/tl b/tl index 093ac9cfa..5858f1698 100755 --- a/tl +++ b/tl @@ -255,483 +255,6 @@ local function write_out(tlconfig, result, output_file, pp_opts) end end --------------------------------------------------------------------------------- --- Build system (deprecated, use Cyan) --------------------------------------------------------------------------------- - -local build = {} -do - local lfs = require("lfs") - - local function str_split(str, delimiter) - local idx = 0 - return function() - if not idx then return end - idx = idx + 1 - local prev_idx = idx - local s_idx - s_idx, idx = str:find(delimiter, idx, true) - return str:sub(prev_idx, (s_idx or 0) - 1) - end - end - - --remove trailing and extra path separators, substitute './' for 'current_dir/' - local function cleanup_file_name(name) - return (name - :gsub("^(%.)(.?)", function(a, b) - assert(a == ".") - if b == "." then - die("Config error: .." .. PATH_SEPARATOR .. " not allowed, please use direct paths") - elseif b == PATH_SEPARATOR then - return "" - else - return b - end - end) - :gsub(PATH_SEPARATOR .. "+", PATH_SEPARATOR)) - :gsub(PATH_SEPARATOR .. "+$", "") - end - - local function path_concat(...) - local path = {} - for i = 1, select("#", ...) do - local fname = cleanup_file_name((select(i, ...))) - if #fname > 0 then - table.insert(path, fname) - end - end - return table.concat(path, PATH_SEPARATOR) - end - - function build.arg_parser(parser) - parser:flag("--run-build-script", "Run the build script if needed, even when not running the build comamnd.") - - local build_command = parser:command("build", - "Build your project according to tlconfig.lua by type checking and compiling each specified file.") - build_command:option("-b --build-dir", "Put all generated files in .") - :argname("") - build_command:option("-s --source-dir", "Compile all *.tl files in (and all subdirectories).") - :argname("") - end - - function build.tlconfig_not_found(cmd) - if cmd == "build" then - die("Build error: tlconfig.lua not found") - end - end - - function build.config_dir(cmd, config_path, config) - local config_dir = config_path:match("^(.+)" .. PATH_SEPARATOR .. "tlconfig.lua$") - if cmd == "build" and config_dir then - assert(lfs.chdir(config_dir)) - end - - if not config.build_file then - if lfs.attributes("./build.tl", "mode") == "file" then - config.build_file = "./build.tl" - end - else - if lfs.attributes(config.build_file, "mode") ~= "file" then - die("The configured build script is not a file") - end - end - end - - function build.merge_config(tlconfig, args) - local cmd = args["command"] - if cmd == "build" then - tlconfig["source_dir"] = args["source_dir"] or tlconfig["source_dir"] - tlconfig["build_dir"] = args["build_dir"] or tlconfig["build_dir"] - end - tlconfig["run_build_script"] = tlconfig["run_build_script"] or args["run_build_script"] or cmd == "build" - - if tlconfig["run_build_script"] then - if tlconfig["build_file"] and not tlconfig["build_file_output_dir"] then - print("A build file is detected, but build_file_output_dir is not set. Defaulting to ./generated_code") - tlconfig["build_file_output_dir"] = "generated_code" - end - - if tlconfig["build_file"] and not tlconfig["internal_compiler_output"] then - print("A build file is detected, but there is no place configured " .. - "to store temporary compiler output. Defaulting to ./internal_compiler_output") - tlconfig["internal_compiler_output"] = "internal_compiler_output" - end - end - end - - local internal_output - local build_path - - function build.run_build_script(tlconfig) - if not tlconfig["run_build_script"] then - return - end - - if tlconfig["internal_compiler_output"] then - - internal_output = path_concat(lfs.currentdir(), tlconfig["internal_compiler_output"]) - local mode = lfs.attributes(internal_output, "mode") - if not mode then - local parts = "" - local prefix = PATH_SEPARATOR == "\\" and "" or "/" - for v in string.gmatch(internal_output, "[^/]+") do - parts = parts .. prefix .. v - mode = lfs.attributes(parts, "mode") - if mode == nil then - local res, message = lfs.mkdir(parts) - if not res then - die("Could not create directory to store internal output. Error: " .. message) - end - elseif mode ~= "directory" then - die("Could not create directory to store the internal output. " .. - "Path: " .. parts .. " is not a directory") - end - end - end - end - - if tlconfig["build_file"] then - build_path = path_concat(internal_output, "build_script_output") - lfs.mkdir(build_path) - prepend_to_lua_paths(build_path) - - local script = {} - local chunk = type_check_and_load(tlconfig, tlconfig.build_file) - local success, res = pcall(chunk) - if success then - script = res - else - die("The build file could not be executed.") - end - - - local time_keeper_path = path_concat(internal_output, "last_build_script_time") - -- No need to read the file if we can just look up when it was last modified. - -- Should have about the same effect and is easier. - local last_run_time = lfs.attributes(time_keeper_path, "modification") - local last_edit_time = lfs.attributes(tlconfig.build_file, "modification") - local should_rerun = last_run_time == nil or last_run_time < last_edit_time - - local gen_code = script["gen_code"] - if should_rerun and gen_code then - if type(gen_code) == "function" then - local full_path = path_concat(build_path, tlconfig["build_file_output_dir"] ) - lfs.rmdir(full_path) - lfs.mkdir(full_path) - local pok, message = pcall(gen_code, full_path) - if not pok then - die("Something has gone wrong while executing the " .. - "\"gen_code\" part of the build file. Error : ".. tostring(message)) - end - local file = io.open(time_keeper_path, "wb") - file:write(last_edit_time) - file:flush() - file:close() - else - die("the key \"gen_code\" exists in the build file, " .. - "but it is not a function. Value: ".. tostring(gen_code)) - end - end - end - end - - function build.run(tlconfig) - local function remove_leading_path(leading_part, path) - local s, e = path:find("^" .. leading_part .. PATH_SEPARATOR .. "?") - if s then - return path:sub(e+1, -1) - end - return path - end - - local function traverse(dirname, emptyref, is_generated, generated_ref) - local files = {} - local paths = {} --lookup table for string paths to help - -- with pattern matching while iterating over a project - -- paths[files.foo.bar] -> "foo/bar" - emptyref = emptyref or {} - generated_ref = generated_ref or {} - for file in lfs.dir(dirname) do - if file ~= "." and file ~= ".." then - if lfs.attributes(path_concat(dirname, file), "mode") == "directory" then - local p - files[file], p = traverse(path_concat(dirname, file), emptyref, is_generated, generated_ref) - paths[files[file]] = file - for k, v in pairs(p) do - paths[k] = path_concat(file, v) - end - else - -- storing a special entry in this table to it mark as empty could - -- interfere with convoluted or maliciously constructed directory - -- names so we use a table with specific metatable to mark - -- something as the end of a traversal to have a property attached - -- to the table, without creating an entry in the table - local meta_table = {empty = emptyref} - if is_generated then - meta_table["generated"] = generated_ref - end - files[file] = setmetatable({}, meta_table) - paths[files[file]] = file - end - end - end - return files, paths, emptyref - end - - local function match(patt_arr, str) - for i, v in ipairs(patt_arr) do - if v(str) then - return i - end - end - return nil - end - local inc_patterns = {} - local exc_patterns = {} - - local function patt_match(patt, str) - local matches = true - local idx = 1 - local s_idx - for _, v in ipairs(patt) do - s_idx, idx = str:find(v, idx) - if not s_idx then - matches = false - break - end - end - return matches - end - local function matcher(str) - local chunks = {} - for piece in str_split(str, "**" .. PATH_SEPARATOR) do - table.insert(chunks, (piece:gsub("%*", "[^" .. PATH_SEPARATOR .. "]-"))) - end - chunks[1] = "^" .. chunks[1] - chunks[#chunks] = chunks[#chunks] .. "$" - return function(s) - return patt_match(chunks, s) - end - end - if internal_output then - table.insert(exc_patterns, matcher(path_concat(tlconfig["internal_compiler_output"], "**" .. PATH_SEPARATOR .. "*.*"))) - end - - -- prepare build and source dirs - - local project = {} - -- This will probably get exposed in the api if that happens - function project:file_with_is_build(inc_patt_arr, exc_patt_arr, dirname) - local iter_dir - if dirname then - iter_dir = project:find(dirname) - else - iter_dir = self.dir - end - if not iter_dir then - return function() end - end - inc_patt_arr = inc_patt_arr or {} - exc_patt_arr = exc_patt_arr or {} - local function iter(dirs) - for fname, file in pairs(dirs) do - local path = self.paths[file] - if dirname then - path = remove_leading_path(dirname, path) - end - local meta_table = getmetatable(file) - if meta_table and meta_table.empty == self.emptyref then - - local include = true - - if tlconfig["files"] then - include = false - end - if build_path and meta_table.generated and meta_table.generated == self.generatedref then - coroutine.yield(build_path .. PATH_SEPARATOR .. path, true) - else - - -- TODO: print out patterns that include/exclude paths to help - -- users debug tlconfig.lua (this is why match returns the array index) - if #inc_patt_arr > 0 then - local idx = match(inc_patt_arr, path) - if not idx then - include = false - end - end - if #exc_patt_arr > 0 then - local idx = match(exc_patt_arr, path) - if include and idx then - include = false - end - end - if include then - coroutine.yield(self.paths[file], false) - end - end - else - iter(file, fname) - end - end - end - return coroutine.wrap(iter), iter_dir - end - function project:find(path) -- allow for indexing with paths project:find("foo/bar") -> project.dir.foo.bar - if not path then return nil end - if path == "" then return self.dir end -- empty string is the current dir - local current_dir = self.dir - for dirname in str_split(path, PATH_SEPARATOR) do - current_dir = current_dir[dirname] - if not current_dir then - return nil - end - end - return current_dir - end - - project.dir, project.paths, project.emptyref = traverse(lfs.currentdir()) - local build_ref = {} - project.generatedref = build_ref - if build_path then - local build_dir, build_paths = traverse(build_path, project.emptyref, true, build_ref) - for k, v in pairs(build_dir) do - project.dir[k] = v - end - for k, v in pairs(build_paths) do - project.paths[k] = v - end - end - - project.source_file_map = {} - - if tlconfig["source_dir"] then - tlconfig["source_dir"] = cleanup_file_name(tlconfig["source_dir"]) - local project_source = project:find(tlconfig["source_dir"]) - local meta_table = getmetatable(project_source) - if not project_source then - die("Build error: source_dir '" .. tlconfig["source_dir"] .. "' doesn't exist") - elseif meta_table and meta_table.empty == project.emptyref then - die("Build error: source_dir '" .. tlconfig["source_dir"] .. "' is not a directory") - end - end - if tlconfig["build_dir"] then - tlconfig["build_dir"] = cleanup_file_name(tlconfig["build_dir"]) - end - - -- include/exclude pattern matching - -- create matchers for each pattern - if tlconfig["include"] then - for _, patt in ipairs(tlconfig["include"]) do - patt = cleanup_file_name(patt) - table.insert(inc_patterns, matcher(patt)) - end - end - if tlconfig["exclude"] then - for _, patt in ipairs(tlconfig["exclude"]) do - patt = cleanup_file_name(patt) - table.insert(exc_patterns, matcher(patt)) - end - end - - local dirs_to_be_mked = {} - local function check_parent_dirs(path) - local parent_dirs = {} - for dir in str_split(path, PATH_SEPARATOR) do - parent_dirs[#parent_dirs + 1] = #parent_dirs > 0 and path_concat(parent_dirs[#parent_dirs], dir) or dir - end - for i, v in ipairs(parent_dirs) do - if i < #parent_dirs then - local mode = lfs.attributes(v, "mode") - if not mode and not dirs_to_be_mked[v] then - table.insert(dirs_to_be_mked, v) - dirs_to_be_mked[v] = true - elseif mode and mode ~= "directory" then - die("Build error: expected " .. v .. " to be a directory") - end - end - end - end - - if tlconfig["files"] then - -- TODO: check if files are not relative - for _, fname in ipairs(tlconfig["files"]) do - if not project:find(fname) then - die("Build error: file \"" .. fname .. "\" not found") - end - project.source_file_map[fname] = fname:gsub("%.tl$", ".lua") - if tlconfig["build_dir"] then - project.source_file_map[fname] = path_concat(tlconfig["build_dir"], project.source_file_map[fname]) - end - check_parent_dirs(project.source_file_map[fname]) - end - end - local source_dir = tlconfig["source_dir"] - for path, is_build in project:file_with_is_build(inc_patterns, exc_patterns, source_dir) do - --TODO: make this better - local valid = true - if not (path:match("%.tl$") and not path:match("%.d%.tl$")) then - valid = false - end - if valid then - local work_on = path:gsub("%.tl$", ".lua") - if is_build then - work_on = remove_leading_path(build_path, work_on) - end - project.source_file_map[path] = work_on - if tlconfig["build_dir"] then - if source_dir then - project.source_file_map[path] = remove_leading_path(source_dir, project.source_file_map[path]) - end - project.source_file_map[path] = path_concat(tlconfig["build_dir"], project.source_file_map[path]) - end - - check_parent_dirs(project.source_file_map[path]) - end - end - for _, v in ipairs(dirs_to_be_mked) do - if not lfs.mkdir(v) then - die("Build error: unable to mkdir \"" .. v .. "\"") - end - end - - -- sort source map so that order is deterministic (helps for testing output) - local sorted_source_file_arr = {} - for input_file, output_file in pairs(project.source_file_map) do - table.insert(sorted_source_file_arr, {input_file, output_file}) - end - table.sort(sorted_source_file_arr, function(a, b) return a[1] < b[1] end) - - if #sorted_source_file_arr == 0 then - os.exit(0) - end - - turbo(true) - local env - for i, files in ipairs(sorted_source_file_arr) do - local input_file, output_file = files[1], files[2] - if not env then - env = setup_env(tlconfig, input_file) - end - - local result, err = tl.process(input_file, env) - if err then - die(err) - end - - filter_warnings(tlconfig, result) - if #result.syntax_errors == 0 and #result.type_errors == 0 then - write_out(tlconfig, result, output_file) - end - - check_collect(i) - end - - local ok = report_all_errors(tlconfig, env) - - os.exit(ok and 0 or 1) - end -end - -------------------------------------------------------------------------------- -- Driver utilities -------------------------------------------------------------------------------- @@ -769,17 +292,6 @@ local function validate_config(config) gen_target = { ["5.1"] = true, ["5.3"] = true, ["5.4"] = true }, disable_warnings = "{string}", warning_error = "{string}", - - -- build related keys - exclude = "{string}", - files = "{string}", - include = "{string}", - source_dir = "string", - build_dir = "string", - build_file = "string", - build_file_output_dir = "string", - internal_compiler_output = "string", - run_build_script = "boolean" } for k, v in pairs(config) do @@ -874,8 +386,6 @@ local function get_args_parser() :argname("") :count("*") - build.arg_parser(parser) - parser:command("warnings", "List each kind of warning the compiler can produce.") local types_command = parser:command("types", "Report all types found in one or more Teal files") @@ -896,12 +406,15 @@ local function get_config(cmd) local config_path = find_file_in_parent_dirs("tlconfig.lua") or "tlconfig.lua" - local conf, err = loadfile(config_path) - if not conf then - if err:match("No such file or directory$") then - build.tlconfig_not_found(cmd) - else - die("Error loading tlconfig.lua:\n" .. err) + local conf, err + local conf_fd = io.open(config_path, "r") + if conf_fd then + local conf_text = conf_fd:read("*a") + if conf_text then + conf, err = (loadstring or load)(conf_text) + if not conf then + die("Error loading tlconfig.lua:\n" .. err) + end end end @@ -920,8 +433,6 @@ local function get_config(cmd) end end - build.config_dir(cmd, config_path, config) - local errs, warnings = validate_config(config) if #errs > 0 then @@ -983,8 +494,6 @@ local function merge_config_and_args(tlconfig, args) for _, include in ipairs(tlconfig["include_dir"]) do prepend_to_lua_paths(include) end - - build.merge_config(tlconfig, args) end local function get_output_filename(file_name) @@ -1377,12 +886,6 @@ do end end --------------------------------------------------------------------------------- --- tl build --------------------------------------------------------------------------------- - -commands["build"] = build.run - -------------------------------------------------------------------------------- -- Main program -------------------------------------------------------------------------------- @@ -1412,6 +915,4 @@ if not args["quiet"] then end end -build.run_build_script(tlconfig) - commands[cmd](tlconfig, args) diff --git a/tl-dev-1.rockspec b/tl-dev-1.rockspec index 9e2516c76..4fd0b2ff1 100644 --- a/tl-dev-1.rockspec +++ b/tl-dev-1.rockspec @@ -17,10 +17,6 @@ dependencies = { -- needed for the cli tool "argparse", - - -- needed for build options - -- --build-dir, --source-dir, etc. - "luafilesystem", } test_dependencies = { "dkjson", From 4896aee708e084dcac7423f8cf9ca5b9b5714660 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Tue, 28 Nov 2023 11:43:10 -0300 Subject: [PATCH 032/224] remove node.type * Node objects no longer carry a `.type` attribute. This was redundant with the return values produced by the node visitors, which are then available in the traversal as the `children` array. The idea is that the logic of resolving types should follow the single-pass flow of traversal, instead of having to track the code of visitors that peek into the tree structure (or worse, modify it after the fact as they go). * Because of that, the collection of types for `tl types` has changed: instead of always collecting a symbol table during type checking, one needs to enable `report_types = true` in the `env`, which then performs the collection and summarization of the data reported by `tl types` in place. This makes the overall process more efficient for both `tl check` and `tl types`. --- spec/api/get_types_spec.lua | 8 +- spec/declaration/record_method_spec.lua | 10 +- spec/declaration/record_spec.lua | 30 +- spec/parser/parser_spec.lua | 3 - spec/statement/forin_spec.lua | 11 +- spec/stdlib/ipairs_spec.lua | 3 +- spec/stdlib/require_spec.lua | 1 + spec/util.lua | 47 + tl | 1 + tl.lua | 1571 ++++++++++++---------- tl.tl | 1611 ++++++++++++----------- 11 files changed, 1802 insertions(+), 1494 deletions(-) diff --git a/spec/api/get_types_spec.lua b/spec/api/get_types_spec.lua index f80878874..26bbf1d05 100644 --- a/spec/api/get_types_spec.lua +++ b/spec/api/get_types_spec.lua @@ -2,11 +2,13 @@ local tl = require("tl") describe("tl.get_types", function() it("skips over label nodes (#393)", function() + local env = tl.init_env() + env.report_types = true local result = assert(tl.process_string([[ local function a() ::continue:: end - ]])) + ]], false, env)) local tr, trenv = tl.get_types(result) assert(tr) @@ -14,6 +16,8 @@ describe("tl.get_types", function() end) it("reports resolved type on poly function calls", function() + local env = tl.init_env() + env.report_types = true local result = assert(tl.process_string([[ local record R f: function(string) @@ -21,7 +25,7 @@ describe("tl.get_types", function() end R.f("hello") - ]])) + ]], false, env)) local tr, trenv = tl.get_types(result) local y = 6 diff --git a/spec/declaration/record_method_spec.lua b/spec/declaration/record_method_spec.lua index b920c17a2..20cbde3dc 100644 --- a/spec/declaration/record_method_spec.lua +++ b/spec/declaration/record_method_spec.lua @@ -349,15 +349,15 @@ describe("record method", function() })) it("does not fail when declaring methods on untyped self (regression test for #427)", util.check([[ - local record T - method: function(T, function(A)): T + local record Rec + my_method: function(Rec, function(A)): Rec end local t = { } - function t.new(): T + function t.new(): Rec local self = { } - function self:method(callback: function(A)): T + function self:my_method(callback: function(A)): Rec return self end return self @@ -397,7 +397,7 @@ describe("record method", function() local record Point end - function Point:new(x: number, y: number): Point + function Point:new(x?: number, y?: number): Point end function Point:move(dx: number, dy: number) diff --git a/spec/declaration/record_spec.lua b/spec/declaration/record_spec.lua index 7083f95d7..da1aaa71b 100644 --- a/spec/declaration/record_spec.lua +++ b/spec/declaration/record_spec.lua @@ -631,24 +631,22 @@ for i, name in ipairs({"records", "arrayrecords", "interfaces", "arrayinterfaces })) end - it("can resolve generics partially (see #417)", function() - local _, ast = util.check([[ - local ]]..statement..[[ fun ]]..array(i, "{fun}")..[[ - ]]..statement..[[ iterator ]]..array(i, "{iterator}")..[[ - reduce: function(iterator, (function(R, T): R), R): R - end - iter: function({T}): iterator - end - - local f: fun + it("can resolve generics partially (see #417)", util.check_types([[ + local ]]..statement..[[ fun ]]..array(i, "{fun}")..[[ + ]]..statement..[[ iterator ]]..array(i, "{iterator}")..[[ + reduce: function(iterator, (function(R, T): R), R): R + end + iter: function({T}): iterator + end - local sum = f.iter({ 1, 2, 3, 4 }):reduce(function(a:integer,x:integer): integer - return a + x - end, 0) - ]])() + local f: fun - assert.same("integer", ast[3].exps[1].type[1].typename) - end) + local sum = f.iter({ 1, 2, 3, 4 }):reduce(function(a:integer,x:integer): integer + return a + x + end, 0) + ]], { + { y = 10, x = 16, type = "integer" }, + })) it("can have circular type dependencies on nested types", util.check([[ local type R = ]]..statement..[[ ]]..array(i, "{S}")..[[ diff --git a/spec/parser/parser_spec.lua b/spec/parser/parser_spec.lua index 07ae6cba5..8cbf40495 100644 --- a/spec/parser/parser_spec.lua +++ b/spec/parser/parser_spec.lua @@ -19,9 +19,6 @@ describe("parser", function() assert.same({ kind = "statements", tk = "$EOF$", - type = { - typename = "none", - }, x = 1, y = 1, xend = 5, diff --git a/spec/statement/forin_spec.lua b/spec/statement/forin_spec.lua index 11eb26dc9..9f93f131b 100644 --- a/spec/statement/forin_spec.lua +++ b/spec/statement/forin_spec.lua @@ -45,9 +45,12 @@ describe("forin", function() end end ]], { - { msg = "attempting ipairs loop" }, - { y = 3, msg = "argument 1: got A (unresolved generic), expected {A}" }, - { y = 4, msg = "cannot use operator '..' for types string \"value: \" and A (unresolved generic)" }, + { msg = "attempting ipairs" }, + { y = 3, msg = "expression in for loop does not return an iterator" }, + { y = 3, msg = "unknown variable: a" }, + { y = 4, msg = "unknown variable: i" }, + { y = 4, msg = "unknown variable: j" }, + { y = 4, msg = "unknown variable: b" }, })) end) @@ -66,7 +69,7 @@ describe("forin", function() end end ]], { - { msg = "attempting pairs loop" }, + { msg = "attempting pairs" }, { msg = "not all fields have the same type" }, { msg = "cannot index object of type Rec" }, })) diff --git a/spec/stdlib/ipairs_spec.lua b/spec/stdlib/ipairs_spec.lua index 4e0f75451..f92710666 100644 --- a/spec/stdlib/ipairs_spec.lua +++ b/spec/stdlib/ipairs_spec.lua @@ -13,7 +13,6 @@ describe("ipairs", function() for i, v in ipairs(my_tuple) do end ]], { - { msg = [[attempting ipairs loop on tuple that's not a valid array: ({{integer}, {string "a"}})]] }, - { msg = [[argument 1: unable to convert tuple {{integer}, {string "a"}} to array]] }, + { msg = [[attempting ipairs on tuple that's not a valid array: {{integer}, {string "a"}}]] }, })) end) diff --git a/spec/stdlib/require_spec.lua b/spec/stdlib/require_spec.lua index 6d523963c..2703b6d2d 100644 --- a/spec/stdlib/require_spec.lua +++ b/spec/stdlib/require_spec.lua @@ -547,6 +547,7 @@ describe("require", function() assert.same(nil, err) assert.same({}, result.syntax_errors) + assert.same(1, #result.type_errors) assert.match("cannot add undeclared function 'draws' outside of the scope where 'love' was originally declared", result.type_errors[1].msg) end) diff --git a/spec/util.lua b/spec/util.lua index e62565c0e..fab01dcbf 100644 --- a/spec/util.lua +++ b/spec/util.lua @@ -545,6 +545,53 @@ function util.check_warnings(code, warnings, type_errors) end end +local function show_keys(arr) + local out = {} + for k, _ in pairs(arr) do + table.insert(out, k) + end + table.sort(out) + return table.concat(out, ", ") +end + +function util.check_types(code, types) + assert(type(code) == "string") + assert(type(types) == "table") + + return function() + local ast, syntax_errors = tl.parse(code, "foo.tl") + assert.same({}, syntax_errors, "Code was not expected to have syntax errors") + local batch = batch_assertions() + local env = tl.init_env() + env.report_types = true + local result = tl.type_check(ast, { filename = "foo.tl", env = env, lax = false }) + batch:add(assert.same, {}, result.type_errors, "Code was not expected to have type errors") + + local tr = tl.get_types(result, env.trenv) + for i, e in ipairs(types) do + assert(e.x, "[" .. i .. "] missing 'x' key in test specification") + assert(e.y, "[" .. i .. "] missing 'y' key in test specification") + assert(e.type, "[" .. i .. "] missing 'type' key in test specification") + local info = tr.by_pos["foo.tl"] + if not info[e.y] then + batch:add(assert.True, false, "[" .. i .. "] No type info for line " .. e.y .. " (has lines " .. show_keys(info) .. ")") + end + info = info[e.y] + if not info[e.x] then + batch:add(assert.True, false, "[" .. i .. "] No type info for position " .. e.x .. " in line " .. e.y .. " (has positions " .. show_keys(info) .. ")") + end + info = info[e.x] + if info then + info = tr.types[info] + batch:add(assert.same, e.type, info.str, "[" .. i .. "] Evaluated type at position " .. e.y .. ":" .. e.x .. " does not match:") + end + end + + batch:assert() + return true + end +end + local function gen(lax, code, expected, gen_target) return function() local ast, syntax_errors = tl.parse(code, "foo.tl") diff --git a/tl b/tl index 5858f1698..318f93b66 100755 --- a/tl +++ b/tl @@ -851,6 +851,7 @@ do local filename = args["file"][1] local env = setup_env(tlconfig, filename) env.keep_going = true + env.report_types = true local tr, trenv for i, input_file in ipairs(args["file"]) do diff --git a/tl.lua b/tl.lua index 0500bb921..b221f7e7d 100644 --- a/tl.lua +++ b/tl.lua @@ -140,6 +140,8 @@ local tl = {PrettyPrintOptions = {}, TypeCheckOptions = {}, Env = {}, Symbol = { + + @@ -252,7 +254,8 @@ if TL_DEBUG then return end - io.stderr:write(info.name or "", info.currentline > 0 and "@" .. info.currentline or "", " :: ", event, "\n") + local name = info.name or "", info.currentline > 0 and "@" .. info.currentline or "" + io.stderr:write(name, " :: ", event, "\n") io.stderr:flush() else count = count + 100 @@ -1438,6 +1441,7 @@ local Node = {ExpectedContext = {}, } + local function is_array_type(t) @@ -3922,7 +3926,7 @@ local function recurse_node(root, end if TL_DEBUG then - tl_debug_indent_pop("}}}", "***", ast.y, ast.x, "[%s] = %s", kprint, ast.type and show_type(ast.type)) + tl_debug_indent_pop("}}}", "***", ast.y, ast.x, "[%s]", kprint) end return ret @@ -4699,6 +4703,198 @@ end +local typename_to_typecode = { + ["typevar"] = tl.typecodes.TYPE_VARIABLE, + ["typearg"] = tl.typecodes.TYPE_VARIABLE, + ["unresolved_typearg"] = tl.typecodes.TYPE_VARIABLE, + ["unresolvable_typearg"] = tl.typecodes.TYPE_VARIABLE, + ["function"] = tl.typecodes.FUNCTION, + ["array"] = tl.typecodes.ARRAY, + ["map"] = tl.typecodes.MAP, + ["tupletable"] = tl.typecodes.TUPLE, + ["interface"] = tl.typecodes.INTERFACE, + ["record"] = tl.typecodes.RECORD, + ["enum"] = tl.typecodes.ENUM, + ["boolean"] = tl.typecodes.BOOLEAN, + ["string"] = tl.typecodes.STRING, + ["nil"] = tl.typecodes.NIL, + ["thread"] = tl.typecodes.THREAD, + ["number"] = tl.typecodes.NUMBER, + ["integer"] = tl.typecodes.INTEGER, + ["union"] = tl.typecodes.IS_UNION, + ["nominal"] = tl.typecodes.NOMINAL, + ["bad_nominal"] = tl.typecodes.NOMINAL, + ["circular_require"] = tl.typecodes.NOMINAL, + ["emptytable"] = tl.typecodes.EMPTY_TABLE, + ["unresolved_emptytable_value"] = tl.typecodes.EMPTY_TABLE, + ["poly"] = tl.typecodes.IS_POLY, + ["any"] = tl.typecodes.ANY, + ["unknown"] = tl.typecodes.UNKNOWN, + ["invalid"] = tl.typecodes.INVALID, + + ["none"] = tl.typecodes.UNKNOWN, + ["tuple"] = tl.typecodes.UNKNOWN, + ["table_item"] = tl.typecodes.UNKNOWN, + ["unresolved"] = tl.typecodes.UNKNOWN, + ["typetype"] = tl.typecodes.UNKNOWN, + ["nestedtype"] = tl.typecodes.UNKNOWN, +} + +local skip_types = { + ["none"] = true, + ["tuple"] = true, + ["table_item"] = true, + ["unresolved"] = true, + ["typetype"] = true, + ["nestedtype"] = true, +} + +local get_typenum + + +local function sorted_keys(m) + local keys = {} + for k, _ in pairs(m) do + table.insert(keys, k) + end + table.sort(keys) + return keys +end + + +local function mark_array(x) + local arr = x + arr[0] = false + return x +end + +function tl.init_type_report() + return { + next_num = 1, + typeid_to_num = {}, + tr = { + by_pos = {}, + types = {}, + symbols_by_file = {}, + globals = {}, + }, + } +end + +local function store_function(trenv, ti, rt) + local args = {} + for _, fnarg in ipairs(rt.args) do + table.insert(args, mark_array({ get_typenum(trenv, fnarg), nil })) + end + ti.args = mark_array(args) + local rets = {} + for _, fnarg in ipairs(rt.rets) do + table.insert(rets, mark_array({ get_typenum(trenv, fnarg), nil })) + end + ti.rets = mark_array(rets) + ti.vararg = not not rt.is_va +end + +get_typenum = function(trenv, t) + assert(t.typeid) + + local n = trenv.typeid_to_num[t.typeid] + if n then + return n + end + + local tr = trenv.tr + + + n = trenv.next_num + + local rt = t + if is_typetype(rt) then + rt = rt.def + elseif rt.typename == "tuple" and #rt == 1 then + rt = rt[1] + end + + local ti = { + t = assert(typename_to_typecode[rt.typename]), + str = show_type(t, true), + file = t.filename, + y = t.y, + x = t.x, + } + tr.types[n] = ti + trenv.typeid_to_num[t.typeid] = n + trenv.next_num = trenv.next_num + 1 + + if t.found then + ti.ref = get_typenum(trenv, t.found) + end + if t.resolved then + rt = t + end + assert(not is_typetype(rt)) + + if is_record_type(rt) then + + local r = {} + for _, k in ipairs(rt.field_order) do + local v = rt.fields[k] + r[k] = get_typenum(trenv, v) + end + ti.fields = r + end + + if is_array_type(rt) then + ti.elements = get_typenum(trenv, rt.elements) + end + + if rt.typename == "map" then + ti.keys = get_typenum(trenv, rt.keys) + ti.values = get_typenum(trenv, rt.values) + elseif rt.typename == "enum" then + ti.enums = mark_array(sorted_keys(rt.enumset)) + elseif rt.typename == "function" then + store_function(trenv, ti, rt) + elseif rt.typename == "poly" or rt.typename == "union" or rt.typename == "tupletable" then + local tis = {} + + for _, pt in ipairs(rt.types) do + table.insert(tis, get_typenum(trenv, pt)) + end + + ti.types = mark_array(tis) + end + + return n +end + +local function make_type_reporter(filename, trenv) + + + local ft = {} + trenv.tr.by_pos[filename] = ft + + local function store_type(y, x, typ) + if not typ or skip_types[typ.typename] then + return + end + + local yt = ft[y] + if not yt then + yt = {} + ft[y] = yt + end + + yt[x] = get_typenum(trenv, typ) + end + + return store_type +end + + + + + local function VARARG(t) local tuple = t tuple.typename = "tuple" @@ -5225,15 +5421,6 @@ end - -local function sorted_keys(m) - local keys = {} - for k, _ in pairs(m) do - table.insert(keys, k) - end - table.sort(keys) - return keys -end local function require_module(module_name, lax, env) local mod = env.modules[module_name] @@ -6050,9 +6237,6 @@ tl.type_check = function(ast, opts) local st = { env.globals } - local symbol_list = {} - local symbol_list_n = 0 - local all_needs_compat = {} local dependencies = {} @@ -6061,6 +6245,15 @@ tl.type_check = function(ast, opts) local module_type + local symbol_list + local symbol_list_n = 0 + local store_type + if env.report_types then + symbol_list = {} + env.trenv = env.trenv or tl.init_type_report() + store_type = make_type_reporter(filename or "?", env.trenv) + end + @@ -6165,16 +6358,19 @@ tl.type_check = function(ast, opts) end msg = msg:format(_tl_table_unpack(showt)) end + local name = where.filename or filename return { y = where.y, x = where.x, msg = msg, - filename = where.filename or filename, + filename = name, } end local function error_at(w, msg, ...) + assert(w.y) + local e = Err(w, msg, ...) if e then table.insert(errors, e) @@ -6321,7 +6517,7 @@ tl.type_check = function(ast, opts) if f.min_arity then return end - local tuple = f.args.tuple + local tuple = f.args local n = #tuple if f.args.is_va then n = n - 1 @@ -6556,10 +6752,9 @@ tl.type_check = function(ast, opts) }) end - local function node_error(node, msg, ...) - error_at(node, msg, ...) - node.type = INVALID - return node.type + local function invalid_at(where, msg, ...) + error_at(where, msg, ...) + return INVALID end local function add_unknown(node, name) @@ -6730,8 +6925,7 @@ tl.type_check = function(ast, opts) local var = add_to_scope(node, name, t, attribute, narrow, dont_check_redeclaration) - if node and t.typename ~= "unresolved" and t.typename ~= "none" then - node.type = node.type or t + if symbol_list and node and t.typename ~= "unresolved" and t.typename ~= "none" then local slot if node.symbol_list_slot then slot = node.symbol_list_slot @@ -6948,7 +7142,7 @@ tl.type_check = function(ast, opts) local function begin_scope(node) table.insert(st, {}) - if node then + if symbol_list and node then symbol_list_n = symbol_list_n + 1 symbol_list[symbol_list_n] = { y = node.y, x = node.x, name = "@{" } end @@ -6985,7 +7179,7 @@ tl.type_check = function(ast, opts) check_for_unused_vars(scope) table.remove(st) - if node then + if symbol_list and node then if symbol_list[symbol_list_n].name == "@{" then symbol_list[symbol_list_n] = nil symbol_list_n = symbol_list_n - 1 @@ -6998,8 +7192,7 @@ tl.type_check = function(ast, opts) local end_scope_and_none_type = function(node, _children) end_scope(node) - node.type = NONE - return node.type + return NONE end local resolve_nominal @@ -7428,6 +7621,10 @@ tl.type_check = function(ast, opts) return true elseif is_self(t1) then + if is_self(t2) then + return true + end + return is_a(resolve_tuple_and_nominal(t1), t2, for_equality) elseif is_self(t2) then @@ -7915,7 +8112,7 @@ tl.type_check = function(ast, opts) local on_arg_id = function(node, _i) if used[node.tk] then - node_error(node, "cannot use argument '" .. node.tk .. "' multiple times in macroexp") + error_at(node, "cannot use argument '" .. node.tk .. "' multiple times in macroexp") else used[node.tk] = true end @@ -7926,7 +8123,6 @@ tl.type_check = function(ast, opts) local function apply_macroexp(orignode) local expanded = orignode.expanded - local savetype = orignode.type local saveknown = orignode.known orignode.expanded = nil @@ -7936,7 +8132,6 @@ tl.type_check = function(ast, opts) for k, v in pairs(expanded) do (orignode)[k] = v end - orignode.type = savetype orignode.known = saveknown end @@ -8055,7 +8250,7 @@ tl.type_check = function(ast, opts) end end - local function fail_call(node, func, nargs, errs) + local function fail_call(where, func, nargs, errs) if errs then for _, err in ipairs(errs) do @@ -8077,15 +8272,15 @@ tl.type_check = function(ast, opts) else table.insert(expects, show_arity(func)) end - node_error(node, "wrong number of arguments (given " .. nargs .. ", expects " .. table.concat(expects, " or ") .. ")") + error_at(where, "wrong number of arguments (given " .. nargs .. ", expects " .. table.concat(expects, " or ") .. ")") end local f = func.typename == "poly" and func.types[1] or func mark_invalid_typeargs(f) - return resolve_typevars_at(node, f.rets) + return resolve_typevars_at(where, f.rets) end - local function check_call(node, where_args, func, args, is_method, argdelta) + local function check_call(where, where_args, func, args, expected, typetype_funcall, is_method, argdelta) assert(type(func) == "table") assert(type(args) == "table") @@ -8096,13 +8291,13 @@ tl.type_check = function(ast, opts) argdelta = is_method and -1 or argdelta or 0 if is_method and args[1] then - add_var(nil, "@self", a_type({ typename = "typetype", y = node.y, x = node.x, def = args[1] })) + add_var(nil, "@self", a_type({ typename = "typetype", y = where.y, x = where.x, def = args[1] })) end local is_func = func.typename == "function" local is_poly = func.typename == "poly" if not (is_func or is_poly) then - return node_error(node, "not a function: %s", func) + return invalid_at(where, "not a function: %s", func) end local passes, n = 1, 1 @@ -8120,40 +8315,37 @@ tl.type_check = function(ast, opts) if f.is_method and not is_method then if args[1] and is_a(args[1], f.args[1]) then - if node.kind == "op" and node.op.op == "@funcall" then - local receiver_is_typetype = node.e1.e1 and node.e1.e1.type and node.e1.e1.type.resolved and node.e1.e1.type.resolved.typename == "typetype" - if not receiver_is_typetype then - add_warning("hint", node, "invoked method as a regular function: consider using ':' instead of '.'") - end + if not typetype_funcall then + add_warning("hint", where, "invoked method as a regular function: consider using ':' instead of '.'") end else - return node_error(node, "invoked method as a regular function: use ':' instead of '.'") + return invalid_at(where, "invoked method as a regular function: use ':' instead of '.'") end end - local expected = #f.args + local wanted = #f.args set_min_arity(f) - if (is_func and ((given <= expected and given >= f.min_arity) or (f.args.is_va and given > expected) or (lax and given <= expected))) or + if (is_func and ((given <= wanted and given >= f.min_arity) or (f.args.is_va and given > wanted) or (lax and given <= wanted))) or - (is_poly and ((pass == 1 and given == expected) or + (is_poly and ((pass == 1 and given == wanted) or - (pass == 2 and given < expected and (lax or given >= f.min_arity)) or + (pass == 2 and given < wanted and (lax or given >= f.min_arity)) or - (pass == 3 and f.args.is_va and given > expected))) then + (pass == 3 and f.args.is_va and given > wanted))) then push_typeargs(f) - local matched, errs = check_args_rets(node, where_args, f, args, node.expected, argdelta) + local matched, errs = check_args_rets(where, where_args, f, args, expected, argdelta) if matched then return matched, f end first_errs = first_errs or errs - if node.expected then + if expected then - infer_emptytables(node, where_args, f.rets, f.rets, argdelta) + infer_emptytables(where, where_args, f.rets, f.rets, argdelta) end if is_poly then @@ -8166,7 +8358,7 @@ tl.type_check = function(ast, opts) end end - return fail_call(node, func, given, first_errs) + return fail_call(where, func, given, first_errs) end type_check_function_call = function(node, where_args, func, args, e1, is_method, argdelta) @@ -8175,18 +8367,29 @@ tl.type_check = function(ast, opts) end begin_scope() - local ret, f = check_call(node, where_args, func, args, is_method, argdelta) + + local typetype_funcall = not not ( + node.kind == "op" and + node.op.op == "@funcall" and + node.e1 and + node.e1.receiver and + node.e1.receiver.resolved and + node.e1.receiver.resolved.typename == "typetype") + + + local ret, f = check_call(node, where_args, func, args, node.expected, typetype_funcall, is_method, argdelta) ret = resolve_typevars_at(node, ret) end_scope() - if e1 then - e1.type = f + + if store_type and e1 then + store_type(e1.y, e1.x, f) end if func.macroexp then expand_macroexp(node, where_args, func.macroexp) end - return ret + return ret, f end end @@ -8371,36 +8574,25 @@ tl.type_check = function(ast, opts) add_unknown(node, var) end - local existing, scope, existing_attr = find_var(var) - if existing and scope > 1 then - node_error(node, "cannot define a global when a local with the same name is in scope") - return nil - end - local is_const = node.attribute ~= nil - + local existing, scope, existing_attr = find_var(var) if existing then - if is_assigning and existing_attr then - node_error(node, "cannot reassign to <" .. existing_attr .. "> global: " .. var) - end - if existing_attr and not is_const then - node_error(node, "global was previously declared as <" .. existing_attr .. ">: " .. var) - end - if (not existing_attr) and is_const then - node_error(node, "global was previously declared as not <" .. node.attribute .. ">: " .. var) - end - if valtype and not same_type(existing.t, valtype) then - node_error(node, "cannot redeclare global with a different type: previous type of " .. var .. " is %s", existing.t) + if scope > 1 then + error_at(node, "cannot define a global when a local with the same name is in scope") + elseif is_assigning and existing_attr then + error_at(node, "cannot reassign to <" .. existing_attr .. "> global: " .. var) + elseif existing_attr and not is_const then + error_at(node, "global was previously declared as <" .. existing_attr .. ">: " .. var) + elseif (not existing_attr) and is_const then + error_at(node, "global was previously declared as not <" .. node.attribute .. ">: " .. var) + elseif valtype and not same_type(existing.t, valtype) then + error_at(node, "cannot redeclare global with a different type: previous type of " .. var .. " is %s", existing.t) end return nil end st[1][var] = { t = valtype, attribute = is_const and "const" or nil } - if node then - node.type = node.type or valtype - end - return st[1][var] end @@ -8416,8 +8608,10 @@ tl.type_check = function(ast, opts) return t end - local function add_internal_function_variables(node) - add_var(nil, "@is_va", node.args.type.is_va and ANY or NIL) + local function add_internal_function_variables(node, args) + assert(args.typename == "tuple") + + add_var(nil, "@is_va", args.is_va and ANY or NIL) add_var(nil, "@return", node.rets or a_type({ typename = "tuple" })) if node.typeargs then @@ -8430,10 +8624,13 @@ tl.type_check = function(ast, opts) end end - local function add_function_definition_for_recursion(node) - local args = a_type({ typename = "tuple", is_va = node.args.type.is_va }) - for _, fnarg in ipairs(node.args) do - table.insert(args, fnarg.type) + local function add_function_definition_for_recursion(node, fnargs) + assert(fnargs.typename == "tuple") + + local args = TUPLE({}) + args.is_va = fnargs.is_va + for _, fnarg in ipairs(fnargs) do + table.insert(args, fnarg) end add_var(nil, node.name.tk, a_type({ @@ -8449,7 +8646,7 @@ tl.type_check = function(ast, opts) st[#st]["@unresolved"] = nil for name, nodes in pairs(unresolved.t.labels) do for _, node in ipairs(nodes) do - node_error(node, "no visible label '" .. name .. "' for goto") + error_at(node, "no visible label '" .. name .. "' for goto") end end for name, types in pairs(unresolved.t.nominals) do @@ -8509,15 +8706,17 @@ tl.type_check = function(ast, opts) end local last = vals[#vals] - if last.typename == "tuple" then + if last then + if last.typename == "tuple" then - is_va = last.is_va - for _, v in ipairs(last) do - table.insert(ret, v) - end - else + is_va = last.is_va + for _, v in ipairs(last) do + table.insert(ret, v) + end + else - table.insert(ret, last) + table.insert(ret, last) + end end @@ -8546,7 +8745,7 @@ tl.type_check = function(ast, opts) if t then return t else - return node_error(node, errmsg) + return invalid_at(node, errmsg) end end @@ -8634,7 +8833,7 @@ tl.type_check = function(ast, opts) return meta_t end - return node_error(bnode, errm, erra, errb) + return invalid_at(bnode, errm, erra, errb) end expand_type = function(where, old, new) @@ -9119,7 +9318,7 @@ tl.type_check = function(ast, opts) local function special_pcall_xpcall(node, _a, b, argdelta) local base_nargs = (node.e1.tk == "xpcall") and 2 or 1 if #node.e2 < base_nargs then - node_error(node, "wrong number of arguments (given " .. #node.e2 .. ", expects at least " .. base_nargs .. ")") + error_at(node, "wrong number of arguments (given " .. #node.e2 .. ", expects at least " .. base_nargs .. ")") return TUPLE({ BOOLEAN }) end @@ -9154,34 +9353,79 @@ tl.type_check = function(ast, opts) end local special_functions = { + ["pairs"] = function(node, a, b, argdelta) + if not b[1] then + return invalid_at(node, "pairs requires an argument") + end + local t = resolve_tuple_and_nominal(b[1]) + if is_array_type(t) then + add_warning("hint", node, "hint: applying pairs on an array: did you intend to apply ipairs?") + end + + if t.typename ~= "map" then + if not (lax and is_unknown(t)) then + if is_record_type(t) then + match_all_record_field_names(node.e2, t, t.field_order, + "attempting pairs on a record with attributes of different types") + local ct = t.typename == "record" and "{string:any}" or "{any:any}" + add_warning("hint", node.e2, "hint: if you want to iterate over fields of a record, cast it to " .. ct) + else + error_at(node.e2, "cannot apply pairs on values of type: %s", t) + end + end + end + + return (type_check_function_call(node, node.e2, a, b, node, false, argdelta)) + end, + + ["ipairs"] = function(node, a, b, argdelta) + if not b[1] then + return invalid_at(node, "ipairs requires an argument") + end + local t = resolve_tuple_and_nominal(b[1]) + + if t.typename == "tupletable" then + local arr_type = arraytype_from_tuple(node.e2, t) + if not arr_type then + return invalid_at(node.e2, "attempting ipairs on tuple that's not a valid array: %s", t) + end + elseif not is_array_type(t) then + if not (lax and (is_unknown(t) or t.typename == "emptytable")) then + return invalid_at(node.e2, "attempting ipairs on something that's not an array: %s", t) + end + end + + return (type_check_function_call(node, node.e2, a, b, node, false, argdelta)) + end, + ["rawget"] = function(node, _a, b, _argdelta) if #b == 2 then return type_check_index(node.e2[1], node.e2[2], b[1], b[2]) else - return node_error(node, "rawget expects two arguments") + return invalid_at(node, "rawget expects two arguments") end end, ["require"] = function(node, _a, b, _argdelta) if #b ~= 1 then - return node_error(node, "require expects one literal argument") + return invalid_at(node, "require expects one literal argument") end if node.e2[1].kind ~= "string" then - return ANY + return a_type({ typename = "any" }) end local module_name = assert(node.e2[1].conststr) local t, found = require_module(module_name, lax, env) if not found then - return node_error(node, "module not found: '" .. module_name .. "'") + return invalid_at(node, "module not found: '" .. module_name .. "'") end if t.typename == "invalid" then if lax then return UNKNOWN end - return node_error(node, "no type information for required module: '" .. module_name .. "'") + return invalid_at(node, "no type information for required module: '" .. module_name .. "'") end dependencies[module_name] = t.filename @@ -9206,13 +9450,13 @@ tl.type_check = function(ast, opts) if special then return special(node, a, b, argdelta) else - return type_check_function_call(node, node.e2, a, b, node.e1, false, argdelta) + return (type_check_function_call(node, node.e2, a, b, node.e1, false, argdelta)) end elseif node.e1.op and node.e1.op.op == ":" then - table.insert(b, 1, node.e1.e1.type) - return type_check_function_call(node, node.e2, a, b, node.e1, true) + table.insert(b, 1, node.e1.receiver) + return (type_check_function_call(node, node.e2, a, b, node.e1, true)) else - return type_check_function_call(node, node.e2, a, b, node.e1, false, argdelta) + return (type_check_function_call(node, node.e2, a, b, node.e1, false, argdelta)) end end @@ -9248,9 +9492,9 @@ tl.type_check = function(ast, opts) return UNKNOWN else if node.exps then - return node_error(node.vars[i], "assignment in declaration did not produce an initial value for variable '" .. name .. "'") + return invalid_at(node.vars[i], "assignment in declaration did not produce an initial value for variable '" .. name .. "'") else - return node_error(node.vars[i], "variable '" .. name .. "' has no type or initial value") + return invalid_at(node.vars[i], "variable '" .. name .. "' has no type or initial value") end end end @@ -9301,13 +9545,13 @@ tl.type_check = function(ast, opts) - local function check_redeclared_key(node, ctx, seen_keys, key) + local function check_redeclared_key(where, ctx, seen_keys, key) if key ~= nil then local s = seen_keys[key] if s then - node_error(node, in_context(ctx, "redeclared key " .. tostring(key) .. " (previously declared at " .. filename .. ":" .. s.y .. ":" .. s.x .. ")")) + error_at(where, in_context(ctx, "redeclared key " .. tostring(key) .. " (previously declared at " .. filename .. ":" .. s.y .. ":" .. s.x .. ")")) else - seen_keys[key] = node + seen_keys[key] = where end end end @@ -9408,7 +9652,7 @@ tl.type_check = function(ast, opts) typ.keys = expand_type(node, typ.keys, INTEGER) typ.values = expand_type(node, typ.values, typ.elements) typ.elements = nil - node_error(node, "cannot determine type of table literal") + error_at(node, "cannot determine type of table literal") elseif is_record and is_array then typ.typename = "record" typ.interface_list = { @@ -9430,7 +9674,7 @@ tl.type_check = function(ast, opts) typ.fields = nil typ.field_order = nil else - node_error(node, "cannot determine type of table literal") + error_at(node, "cannot determine type of table literal") end elseif is_array then local pure_array = true @@ -9463,7 +9707,7 @@ tl.type_check = function(ast, opts) elseif is_tuple then typ.typename = "tupletable" if not typ.types or #typ.types == 0 then - node_error(node, "cannot determine type of tuple elements") + error_at(node, "cannot determine type of tuple elements") end end @@ -9500,7 +9744,7 @@ tl.type_check = function(ast, opts) end else if infertype and infertype.typename == "unresolvable_typearg" then - node_error(node.vars[i], "cannot infer declaration type; an explicit type annotation is necessary") + error_at(node.vars[i], "cannot infer declaration type; an explicit type annotation is necessary") ok = false infertype = INVALID elseif infertype and infertype.is_method then @@ -9514,15 +9758,15 @@ tl.type_check = function(ast, opts) if var.attribute == "total" then local rd = decltype and resolve_tuple_and_nominal(decltype) if rd and (rd.typename ~= "map" and rd.typename ~= "record") then - node_error(var, "attribute only applies to maps and records") + error_at(var, "attribute only applies to maps and records") ok = false elseif not infertype then - node_error(var, "variable declared does not declare an initialization value") + error_at(var, "variable declared does not declare an initialization value") ok = false elseif not (node.exps[i] and node.exps[i].attribute == "total") then local ri = resolve_tuple_and_nominal(infertype) if ri.typename ~= "map" and ri.typename ~= "record" then - node_error(var, "attribute only applies to maps and records") + error_at(var, "attribute only applies to maps and records") ok = false elseif not infertype.is_total then local missing = "" @@ -9530,10 +9774,10 @@ tl.type_check = function(ast, opts) missing = " (missing: " .. table.concat(infertype.missing, ", ") .. ")" end if ri.typename == "map" then - node_error(var, "map variable declared does not declare values for all possible keys" .. missing) + error_at(var, "map variable declared does not declare values for all possible keys" .. missing) ok = false elseif ri.typename == "record" then - node_error(var, "record variable declared does not declare values for all fields" .. missing) + error_at(var, "record variable declared does not declare values for all fields" .. missing) ok = false end end @@ -9618,6 +9862,50 @@ tl.type_check = function(ast, opts) return nil end + + + + + local function check_assignment(where, vartype, valtype, varname, attr) + if varname then + if widen_back_var(varname) then + vartype, attr = find_var_type(varname) + if not vartype then + error_at(where, "unknown variable") + return nil + end + end + end + if attr == "close" or attr == "const" or attr == "total" then + error_at(where, "cannot assign to <" .. attr .. "> variable") + return nil + end + + local var = resolve_tuple_and_nominal(vartype) + if is_typetype(var) then + error_at(where, "cannot reassign a type") + return nil + end + + if not valtype then + error_at(where, "variable is not being assigned a value") + return nil, nil, "missing" + end + + assert_is_a(where, valtype, vartype, "in assignment") + + local val = resolve_tuple_and_nominal(valtype) + + return var, val + end + + local function discard_tuple(node, t, b) + if b.typename == "tuple" then + node.discarded_tuple = true + end + return resolve_tuple(t) + end + local visit_node = {} visit_node.cbs = { @@ -9635,8 +9923,7 @@ tl.type_check = function(ast, opts) end_scope(node) end - node.type = NONE - return node.type + return NONE end, }, ["local_type"] = { @@ -9644,7 +9931,7 @@ tl.type_check = function(ast, opts) local name = node.var.tk local resolved, aliasing = get_type_declaration(node) local var = add_var(node.var, name, resolved, node.var.attribute) - node.value.type = resolved + if aliasing then var.aliasing = aliasing node.value.is_alias = true @@ -9652,8 +9939,7 @@ tl.type_check = function(ast, opts) end, after = function(node, _children) dismiss_unresolved(node.var.tk) - node.type = NONE - return node.type + return NONE end, }, ["global_type"] = { @@ -9680,14 +9966,15 @@ tl.type_check = function(ast, opts) end, after = function(node, _children) dismiss_unresolved(node.var.tk) - node.type = NONE - return node.type + return NONE end, }, ["local_declaration"] = { before = function(node) - for _, var in ipairs(node.vars) do - reserve_symbol_list_slot(var) + if symbol_list then + for _, var in ipairs(node.vars) do + reserve_symbol_list_slot(var) + end end end, before_exp = set_expected_types_to_decltypes, @@ -9698,12 +9985,12 @@ tl.type_check = function(ast, opts) if var.attribute == "close" then if opts.gen_target == "5.4" then if encountered_close then - node_error(var, "only one per declaration is allowed") + error_at(var, "only one per declaration is allowed") else encountered_close = true end else - node_error(var, " attribute is only valid for Lua 5.4 (current target is " .. tostring(opts.gen_target) .. ")") + error_at(var, " attribute is only valid for Lua 5.4 (current target is " .. tostring(opts.gen_target) .. ")") end end @@ -9711,9 +9998,9 @@ tl.type_check = function(ast, opts) if var.attribute == "close" then if not type_is_closable(t) then - node_error(var, "to-be-closed variable " .. var.tk .. " has a non-closable type %s", t) + error_at(var, "to-be-closed variable " .. var.tk .. " has a non-closable type %s", t) elseif node.exps and node.exps[i] and expr_is_definitely_not_closable(node.exps[i]) then - node_error(var, "to-be-closed variable " .. var.tk .. " assigned a non-closable value") + error_at(var, "to-be-closed variable " .. var.tk .. " assigned a non-closable value") end end @@ -9726,14 +10013,18 @@ tl.type_check = function(ast, opts) local rt = resolve_tuple_and_nominal(t) if rt.typename ~= "enum" and (t.typename ~= "nominal" or rt.typename == "union") and not same_type(t, infertype) then - add_var(where, var.tk, infer_at(where, infertype), "const", "narrowed_declaration") + t = infer_at(where, infertype) + add_var(where, var.tk, t, "const", "narrowed_declaration") end end + if store_type then + store_type(var.y, var.x, t) + end + dismiss_unresolved(var.tk) end - node.type = NONE - return node.type + return NONE end, }, ["global_declaration"] = { @@ -9744,70 +10035,55 @@ tl.type_check = function(ast, opts) local _, t, is_inferred = determine_declaration_type(var, node, infertypes, i) if var.attribute == "close" then - node_error(var, "globals may not be ") + error_at(var, "globals may not be ") end add_global(var, var.tk, t, is_inferred) - var.type = t dismiss_unresolved(var.tk) end - node.type = NONE - return node.type + return NONE end, }, ["assignment"] = { before_exp = set_expected_types_to_decltypes, after = function(node, children) local valtypes = get_assignment_values(children[3], #children[1]) - local exps = flatten_list(valtypes) + valtypes = flatten_list(valtypes) for i, vartype in ipairs(children[1]) do local varnode = node.vars[i] - local attr = varnode.attribute - if varnode.kind == "variable" then - if widen_back_var(varnode.tk) then - vartype, attr = find_var_type(varnode.tk) + local varname = varnode.tk + local rvar, rval, err = check_assignment(varnode, vartype, valtypes[i], varname, varnode.attribute) + if err == "missing" then + if #node.exps == 1 and node.exps[1].kind == "op" and node.exps[1].op.op == "@funcall" then + local rets = children[3] + if rets.typename == "tuple" then + local msg = #rets == 1 and + "only 1 value is returned by the function" or + ("only " .. #rets .. " values are returned by the function") + add_warning("hint", varnode, msg) + end end end - if attr then - node_error(varnode, "cannot assign to <" .. attr .. "> variable") - end - if vartype then - local val = exps[i] - if is_typetype(resolve_tuple_and_nominal(vartype)) then - node_error(varnode, "cannot reassign a type") - elseif val then - assert_is_a(varnode, val, vartype, "in assignment") + if rval and rvar then - local rval = resolve_tuple_and_nominal(val) - if rval.typename == "function" then - widen_all_unions() - end + if rval.typename == "function" then + widen_all_unions() + end - if varnode.kind == "variable" and vartype.typename == "union" then + if varname and rvar.typename == "union" then - add_var(varnode, varnode.tk, val, nil, "narrow") - end - else - node_error(varnode, "variable is not being assigned a value") - if #node.exps == 1 and node.exps[1].kind == "op" and node.exps[1].op.op == "@funcall" then - local rets = node.exps[1].type - if rets.typename == "tuple" then - local msg = #rets == 1 and - "only 1 value is returned by the function" or - ("only " .. #rets .. " values are returned by the function") - add_warning("hint", varnode, msg) - end - end + add_var(varnode, varname, rval, nil, "narrow") + end + + if store_type then + store_type(varnode.y, varnode.x, valtypes[i]) end - else - node_error(varnode, "unknown variable") end end - node.type = NONE - return node.type + return NONE end, }, ["if"] = { @@ -9824,8 +10100,7 @@ tl.type_check = function(ast, opts) infer_negation_of_if_blocks(node, node, #node.if_blocks) end - node.type = NONE - return node.type + return NONE end, }, ["if_block"] = { @@ -9847,8 +10122,7 @@ tl.type_check = function(ast, opts) node.block_returns = true end - node.type = NONE - return node.type + return NONE end, }, ["while"] = { @@ -9868,11 +10142,10 @@ tl.type_check = function(ast, opts) widen_all_unions() local label_id = "::" .. node.label .. "::" if st[#st][label_id] then - node_error(node, "label '" .. node.label .. "' already defined at " .. filename) + error_at(node, "label '" .. node.label .. "' already defined at " .. filename) end local unresolved = st[#st]["@unresolved"] - node.type = a_type({ y = node.y, x = node.x, typename = "none" }) - local var = add_var(node, label_id, node.type) + local var = add_var(node, label_id, a_type({ y = node.y, x = node.x, typename = "none" })) if unresolved then if unresolved.t.labels[node.label] then var.used = true @@ -9880,6 +10153,9 @@ tl.type_check = function(ast, opts) unresolved.t.labels[node.label] = nil end end, + after = function() + return NONE + end, }, ["goto"] = { after = function(node, _children) @@ -9888,8 +10164,8 @@ tl.type_check = function(ast, opts) unresolved.labels[node.label] = unresolved.labels[node.label] or {} table.insert(unresolved.labels[node.label], node) end - node.type = NONE - return node.type + + return NONE end, }, ["repeat"] = { @@ -9904,55 +10180,25 @@ tl.type_check = function(ast, opts) before = function(node) begin_scope(node) end, - before_statements = function(node) + before_statements = function(node, children) + local exptypes = children[2] + widen_all_unions(node) local exp1 = node.exps[1] local args = { typename = "tuple", - node.exps[2] and node.exps[2].type, - node.exps[3] and node.exps[3].type, + node.exps[2] and exptypes[2], + node.exps[3] and exptypes[3], } - local exp1type = resolve_for_call(exp1.type, args, false) + local exp1type = resolve_for_call(exptypes[1], args, false) if exp1type.typename == "poly" then - type_check_function_call(exp1, { node.exps[2], node.exps[3] }, exp1type, args, exp1, false, 0) - exp1type = exp1.type or exp1type + local _ + _, exp1type = type_check_function_call(exp1, { node.exps[2], node.exps[3] }, exp1type, args, exp1, false, 0) end if exp1type.typename == "function" then - if exp1.op and exp1.op.op == "@funcall" then - local t = resolve_tuple_and_nominal(exp1.e2.type) - if exp1.e1.tk == "pairs" and is_array_type(t) then - add_warning("hint", exp1, "hint: applying pairs on an array: did you intend to apply ipairs?") - end - - if exp1.e1.tk == "pairs" and t.typename ~= "map" then - if not (lax and is_unknown(t)) then - if is_record_type(t) then - match_all_record_field_names(exp1.e2, t, t.field_order, - "attempting pairs loop on a record with attributes of different types") - local ct = t.typename == "record" and "{string:any}" or "{any:any}" - add_warning("hint", exp1.e2, "hint: if you want to iterate over fields of a record, cast it to " .. ct) - else - node_error(exp1.e2, "cannot apply pairs on values of type: %s", exp1.e2.type) - end - end - elseif exp1.e1.tk == "ipairs" then - if t.typename == "tupletable" then - local arr_type = arraytype_from_tuple(exp1.e2, t) - if not arr_type then - node_error(exp1.e2, "attempting ipairs loop on tuple that's not a valid array: %s", exp1.e2.type) - end - elseif not is_array_type(t) then - if not (lax and (is_unknown(t) or t.typename == "emptytable")) then - node_error(exp1.e2, "attempting ipairs loop on something that's not an array: %s", exp1.e2.type) - end - end - end - end - - local last local rets = exp1type.rets for i, v in ipairs(node.vars) do @@ -9971,11 +10217,11 @@ tl.type_check = function(ast, opts) local nrets = #rets local at = node.vars[nrets + 1] local n_values = nrets == 1 and "1 value" or tostring(nrets) .. " values" - node_error(at, "too many variables for this iterator; it produces " .. n_values) + error_at(at, "too many variables for this iterator; it produces " .. n_values) end else if not (lax and is_unknown(exp1type)) then - node_error(exp1, "expression in for loop does not return an iterator") + error_at(exp1, "expression in for loop does not return an iterator") end end end, @@ -10028,14 +10274,14 @@ tl.type_check = function(ast, opts) end if #children[1] > nrets and (not lax) and not vatype then - node_error(node, what .. ": excess return values, expected " .. #rets .. " %s, got " .. #children[1] .. " %s", rets, children[1]) + error_at(node, what .. ": excess return values, expected " .. #rets .. " %s, got " .. #children[1] .. " %s", rets, children[1]) end if nrets > 1 and #node.exps == 1 and node.exps[1].kind == "op" and (node.exps[1].op.op == "and" or node.exps[1].op.op == "or") and - #node.exps[1].e2.type > 1 then + node.exps[1].discarded_tuple then add_warning("hint", node.exps[1].e2, "additional return values are being discarded due to '" .. node.exps[1].op.op .. "' expression; suggest parentheses if intentional") end @@ -10051,27 +10297,26 @@ tl.type_check = function(ast, opts) end end - node.type = NONE - return node.type + return NONE end, }, ["variable_list"] = { - after = function(node, children) - node.type = TUPLE(children) + after = function(_node, children) + local tuple = TUPLE(children) - local n = #children - if n > 0 and children[n].typename == "tuple" then - if children[n].is_va then - node.type.is_va = true + local n = #tuple + if n > 0 and tuple[n].typename == "tuple" then + local final_tuple = tuple[n] + if final_tuple.is_va then + tuple.is_va = true end - local tuple = children[n] - for i, c in ipairs(tuple) do - children[n + i - 1] = c + for i, c in ipairs(final_tuple) do + tuple[n + i - 1] = c end end - return node.type + return tuple end, }, ["table_literal"] = { @@ -10110,130 +10355,130 @@ tl.type_check = function(ast, opts) after = function(node, children) node.known = FACT_TRUTHY - if node.expected then - local decltype = resolve_tuple_and_nominal(node.expected) + if not node.expected then + return infer_table_literal(node, children) + end - if decltype.typename == "union" then - local single_table_type - local single_table_rt + local decltype = resolve_tuple_and_nominal(node.expected) - for _, t in ipairs(decltype.types) do - local rt = resolve_tuple_and_nominal(t) - if is_lua_table_type(rt) then - if single_table_type then + if decltype.typename == "union" then + local single_table_type + local single_table_rt - single_table_type = nil - single_table_rt = nil - break - end + for _, t in ipairs(decltype.types) do + local rt = resolve_tuple_and_nominal(t) + if is_lua_table_type(rt) then + if single_table_type then - single_table_type = t - single_table_rt = rt + single_table_type = nil + single_table_rt = nil + break end - end - if single_table_type then - node.expected = single_table_type - decltype = single_table_rt + single_table_type = t + single_table_rt = rt end end - if not is_lua_table_type(decltype) then - node.type = infer_table_literal(node, children) - return node.type + if single_table_type then + node.expected = single_table_type + decltype = single_table_rt end + end - local is_record = is_record_type(decltype) - local is_array = is_array_type(decltype) - local is_tupletable = decltype.typename == "tupletable" - local is_map = decltype.typename == "map" + if not is_lua_table_type(decltype) then + return infer_table_literal(node, children) + end - local force_array = nil + local is_record = is_record_type(decltype) + local is_array = is_array_type(decltype) + local is_tupletable = decltype.typename == "tupletable" + local is_map = decltype.typename == "map" - local seen_keys = {} + local force_array = nil - for i, child in ipairs(children) do - assert(child.typename == "table_item") - local cvtype = resolve_tuple(child.vtype) - local ck = child.kname - local n = node[i].key.constnum - local b = nil - if child.ktype.typename == "boolean" then - b = (node[i].key.tk == "true") - end - check_redeclared_key(node[i], node.expected_context, seen_keys, ck or n or b) - if is_record and ck then - local df = decltype.fields[ck] - if not df then - node_error(node[i], in_context(node.expected_context, "unknown field " .. ck)) + local seen_keys = {} + + for i, child in ipairs(children) do + assert(child.typename == "table_item") + local cvtype = resolve_tuple(child.vtype) + local ck = child.kname + local n = node[i].key.constnum + local b = nil + if child.ktype.typename == "boolean" then + b = (node[i].key.tk == "true") + end + check_redeclared_key(node[i], node.expected_context, seen_keys, ck or n or b) + if is_record and ck then + local df = decltype.fields[ck] + if not df then + error_at(node[i], in_context(node.expected_context, "unknown field " .. ck)) + else + if is_typetype(df) then + error_at(node[i], in_context(node.expected_context, "cannot reassign a type")) else - if is_typetype(df) then - node_error(node[i], in_context(node.expected_context, "cannot reassign a type")) - else - assert_is_a(node[i], cvtype, df, "in record field", ck) - end + assert_is_a(node[i], cvtype, df, "in record field", ck) end - elseif is_tupletable and is_number_type(child.ktype) then - local dt = decltype.types[n] - if not n then - node_error(node[i], in_context(node.expected_context, "unknown index in tuple %s"), decltype) - elseif not dt then - node_error(node[i], in_context(node.expected_context, "unexpected index " .. n .. " in tuple %s"), decltype) - else - assert_is_a(node[i], cvtype, dt, in_context(node.expected_context, "in tuple"), "at index " .. tostring(n)) - end - elseif is_array and is_number_type(child.ktype) then - if child.vtype.typename == "tuple" and i == #children and node[i].key_parsed == "implicit" then + end + elseif is_tupletable and is_number_type(child.ktype) then + local dt = decltype.types[n] + if not n then + error_at(node[i], in_context(node.expected_context, "unknown index in tuple %s"), decltype) + elseif not dt then + error_at(node[i], in_context(node.expected_context, "unexpected index " .. n .. " in tuple %s"), decltype) + else + assert_is_a(node[i], cvtype, dt, in_context(node.expected_context, "in tuple"), "at index " .. tostring(n)) + end + elseif is_array and is_number_type(child.ktype) then + if child.vtype.typename == "tuple" and i == #children and node[i].key_parsed == "implicit" then - for ti, tt in ipairs(child.vtype) do - assert_is_a(node[i], tt, decltype.elements, in_context(node.expected_context, "expected an array"), "at index " .. tostring(i + ti - 1)) - end - else - assert_is_a(node[i], cvtype, decltype.elements, in_context(node.expected_context, "expected an array"), "at index " .. tostring(n)) + for ti, tt in ipairs(child.vtype) do + assert_is_a(node[i], tt, decltype.elements, in_context(node.expected_context, "expected an array"), "at index " .. tostring(i + ti - 1)) end - elseif node[i].key_parsed == "implicit" then - if is_map then - assert_is_a(node[i], INTEGER, decltype.keys, in_context(node.expected_context, "in map key")) - assert_is_a(node[i], cvtype, decltype.values, in_context(node.expected_context, "in map value")) - end - force_array = expand_type(node[i], force_array, child.vtype) - elseif is_map then - force_array = nil - assert_is_a(node[i], child.ktype, decltype.keys, in_context(node.expected_context, "in map key")) - assert_is_a(node[i], cvtype, decltype.values, in_context(node.expected_context, "in map value")) else - node_error(node[i], in_context(node.expected_context, "unexpected key of type %s in table of type %s"), child.ktype, decltype) + assert_is_a(node[i], cvtype, decltype.elements, in_context(node.expected_context, "expected an array"), "at index " .. tostring(n)) end - end - - if force_array then - node.type = infer_at(node, a_type({ - typename = "array", - elements = force_array, - })) - else - node.type = resolve_typevars_at(node, node.expected) - if node.expected == node.type and node.type.typename == "nominal" then - node.type = { - typeid = node.type.typeid, - typename = "nominal", - names = node.type.names, - found = node.type.found, - resolved = node.type.resolved, - } + elseif node[i].key_parsed == "implicit" then + if is_map then + assert_is_a(node[i], INTEGER, decltype.keys, in_context(node.expected_context, "in map key")) + assert_is_a(node[i], cvtype, decltype.values, in_context(node.expected_context, "in map value")) end + force_array = expand_type(node[i], force_array, child.vtype) + elseif is_map then + force_array = nil + assert_is_a(node[i], child.ktype, decltype.keys, in_context(node.expected_context, "in map key")) + assert_is_a(node[i], cvtype, decltype.values, in_context(node.expected_context, "in map value")) + else + error_at(node[i], in_context(node.expected_context, "unexpected key of type %s in table of type %s"), child.ktype, decltype) end + end - if decltype.typename == "record" then - node.type.is_total, node.type.missing = total_record_check(decltype, seen_keys) - elseif decltype.typename == "map" then - node.type.is_total, node.type.missing = total_map_check(decltype, seen_keys) - end + local t + if force_array then + t = infer_at(node, a_type({ + typename = "array", + elements = force_array, + })) else - node.type = infer_table_literal(node, children) + t = resolve_typevars_at(node, node.expected) + if node.expected == t and t.typename == "nominal" then + t = { + typeid = t.typeid, + typename = "nominal", + names = t.names, + found = t.found, + resolved = t.resolved, + } + end + end + + if decltype.typename == "record" then + t.is_total, t.missing = total_record_check(decltype, seen_keys) + elseif decltype.typename == "map" then + t.is_total, t.missing = total_map_check(decltype, seen_keys) end - return node.type + return t end, }, ["table_item"] = { @@ -10251,7 +10496,7 @@ tl.type_check = function(ast, opts) vtype.typeid = new_typeid() vtype.is_method = false end - node.type = a_type({ + return a_type({ y = node.y, x = node.x, typename = "table_item", @@ -10259,24 +10504,26 @@ tl.type_check = function(ast, opts) ktype = ktype, vtype = vtype, }) - return node.type end, }, ["local_function"] = { before = function(node) widen_all_unions() - reserve_symbol_list_slot(node) + if symbol_list then + reserve_symbol_list_slot(node) + end begin_scope(node) end, - before_statements = function(node) - add_internal_function_variables(node) - add_function_definition_for_recursion(node) + before_statements = function(node, children) + local args = children[2] + add_internal_function_variables(node, args) + add_function_definition_for_recursion(node, args) end, after = function(node, children) end_function_scope(node) local rets = get_rets(children[3]) - add_var(node, node.name.tk, ensure_fresh_typeargs(a_type({ + local t = ensure_fresh_typeargs(a_type({ y = node.y, x = node.x, typename = "function", @@ -10284,8 +10531,10 @@ tl.type_check = function(ast, opts) args = children[2], rets = rets, filename = filename, - }))) - return node.type + })) + + add_var(node, node.name.tk, t) + return t end, }, ["global_function"] = { @@ -10298,22 +10547,24 @@ tl.type_check = function(ast, opts) if typ.typename == "function" then node.is_predeclared_local_function = true elseif not lax then - node_error(node, "cannot declare function: type of " .. node.name.tk .. " is %s", typ) + error_at(node, "cannot declare function: type of " .. node.name.tk .. " is %s", typ) end elseif not lax then - node_error(node, "functions need an explicit 'local' or 'global' annotation") + error_at(node, "functions need an explicit 'local' or 'global' annotation") end end end, - before_statements = function(node) - add_internal_function_variables(node) - add_function_definition_for_recursion(node) + before_statements = function(node, children) + local args = children[2] + add_internal_function_variables(node, args) + add_function_definition_for_recursion(node, args) end, after = function(node, children) end_function_scope(node) if node.is_predeclared_local_function then - return node.type + return NONE end + add_global(node, node.name.tk, ensure_fresh_typeargs(a_type({ y = node.y, x = node.x, @@ -10323,7 +10574,8 @@ tl.type_check = function(ast, opts) rets = get_rets(children[3]), filename = filename, }))) - return node.type + + return NONE end, }, ["record_function"] = { @@ -10347,7 +10599,7 @@ tl.type_check = function(ast, opts) end end, before_statements = function(node, children) - add_internal_function_variables(node) + local args = children[3] local rtype = node.rtype if rtype.typename == "emptytable" then @@ -10361,17 +10613,17 @@ tl.type_check = function(ast, opts) end if not is_record_type(rtype) then - node_error(node, "not a module: %s", rtype) + error_at(node, "not a module: %s", rtype) return end + local selftype = get_self_type(node.fn_owner) if node.is_method then - local selftype = get_self_type(node.fn_owner) if not selftype then - node_error(node, "could not resolve type of self") + error_at(node, "could not resolve type of self") return end - children[3][1] = selftype + args[1] = selftype add_var(nil, "self", selftype) end @@ -10381,7 +10633,7 @@ tl.type_check = function(ast, opts) typename = "function", is_method = node.is_method, typeargs = node.typeargs, - args = children[3], + args = args, rets = get_rets(children[4]), filename = filename, })) @@ -10403,9 +10655,7 @@ tl.type_check = function(ast, opts) return end - local shortname = node.fn_owner.type.typename == "nominal" and - show_type(node.fn_owner.type) or - owner_name + local shortname = selftype and show_type(selftype) or owner_name local msg = "type signature of '" .. node.name.tk .. "' does not match its declaration in " .. shortname .. ": " add_errs_prefixing(node, err, errors, msg) return @@ -10415,7 +10665,7 @@ tl.type_check = function(ast, opts) rtype.fields[node.name.tk] = fn_type table.insert(rtype.field_order, node.name.tk) else - node_error(node, "cannot add undeclared function '" .. node.name.tk .. "' outside of the scope where '" .. owner_name .. "' was originally declared") + error_at(node, "cannot add undeclared function '" .. node.name.tk .. "' outside of the scope where '" .. owner_name .. "' was originally declared") return end @@ -10427,12 +10677,12 @@ tl.type_check = function(ast, opts) end open_v.implemented[open_k] = true end - node.name.type = fn_type + + add_internal_function_variables(node, args) end, after = function(node, _children) end_function_scope(node) - node.type = NONE - return node.type + return NONE end, }, ["function"] = { @@ -10440,14 +10690,15 @@ tl.type_check = function(ast, opts) widen_all_unions(node) begin_scope(node) end, - before_statements = function(node) - add_internal_function_variables(node) + before_statements = function(node, children) + local args = children[1] + add_internal_function_variables(node, args) end, after = function(node, children) end_function_scope(node) - node.type = ensure_fresh_typeargs(a_type({ + return ensure_fresh_typeargs(a_type({ y = node.y, x = node.x, typename = "function", @@ -10456,7 +10707,6 @@ tl.type_check = function(ast, opts) rets = children[2], filename = filename, })) - return node.type end, }, ["macroexp"] = { @@ -10464,14 +10714,15 @@ tl.type_check = function(ast, opts) widen_all_unions(node) begin_scope(node) end, - before_exp = function(node) - add_internal_function_variables(node) + before_exp = function(node, children) + local args = children[1] + add_internal_function_variables(node, args) end, after = function(node, children) end_function_scope(node) - node.type = ensure_fresh_typeargs(a_type({ + return ensure_fresh_typeargs(a_type({ y = node.y, x = node.x, typename = "function", @@ -10480,13 +10731,11 @@ tl.type_check = function(ast, opts) rets = children[2], filename = filename, })) - return node.type end, }, ["cast"] = { after = function(node, _children) - node.type = node.casttype - return node.type + return node.casttype end, }, ["paren"] = { @@ -10495,8 +10744,7 @@ tl.type_check = function(ast, opts) end, after = function(node, children) node.known = node.e1 and node.e1.known - node.type = resolve_tuple(children[1]) - return node.type + return resolve_tuple(children[1]) end, }, ["op"] = { @@ -10513,18 +10761,20 @@ tl.type_check = function(ast, opts) end end end, - before_e2 = function(node) + before_e2 = function(node, children) + local e1type = children[1] + if node.op.op == "and" then apply_facts(node, node.e1.known) elseif node.op.op == "or" then apply_facts(node, facts_not(node, node.e1.known)) elseif node.op.op == "@funcall" then - if node.e1.type.typename == "function" then + if e1type.typename == "function" then local argdelta = (node.e1.op and node.e1.op.op == ":") and -1 or 0 if node.expected then - is_a(node.e1.type.rets, node.expected) + is_a(e1type.rets, node.expected) end - local e1args = node.e1.type.args + local e1args = e1type.args local at = argdelta for _, typ in ipairs(e1args) do at = at + 1 @@ -10540,8 +10790,8 @@ tl.type_check = function(ast, opts) end end elseif node.op.op == "@index" then - if node.e1.type.typename == "map" then - node.e2.expected = node.e1.type.keys + if e1type.typename == "map" then + node.e2.expected = e1type.keys end end end, @@ -10559,23 +10809,21 @@ tl.type_check = function(ast, opts) local expected = node.expected and resolve_tuple_and_nominal(node.expected) if ra.typename == "circular_require" or (ra.def and ra.def.typename == "circular_require") then - node_error(node, "cannot dereference a type from a circular require") - node.type = INVALID - return node.type + return invalid_at(node, "cannot dereference a type from a circular require") end if is_typetype(ra) then if ra.def.typename == "record" then ra = ra.def elseif ra.def.typename == "interface" then - node_error(node, "interfaces are abstract; consider using a concrete record") + error_at(node, "interfaces are abstract; consider using a concrete record") end end if rb and is_typetype(rb) and rb.def.typename == "record" then if rb.def.typename == "record" then rb = rb.def elseif rb.def.typename == "interface" then - node_error(node, "interfaces are abstract; consider using a concrete record") + error_at(node, "interfaces are abstract; consider using a concrete record") end end @@ -10585,8 +10833,12 @@ tl.type_check = function(ast, opts) add_unknown_dot(node, node.e1.e1.tk .. "." .. node.e1.e2.tk) end end - node.type = type_check_funcall(node, a, b) - elseif node.op.op == "." then + return type_check_funcall(node, a, b) + end + + if node.op.op == "." then + node.receiver = a + assert(node.e2.kind == "identifier") local bnode = { y = node.e2.y, @@ -10601,9 +10853,9 @@ tl.type_check = function(ast, opts) tk = '"' .. node.e2.tk .. '"', typename = "string", }) - node.type = type_check_index(node.e1, bnode, orig_a, btype) + local t = type_check_index(node.e1, bnode, orig_a, btype) - if node.type.needs_compat and opts.gen_compat ~= "off" then + if t.needs_compat and opts.gen_compat ~= "off" then if node.e1.kind == "variable" and node.e2.kind == "identifier" then local key = node.e1.tk .. "." .. node.e2.tk @@ -10612,83 +10864,112 @@ tl.type_check = function(ast, opts) all_needs_compat[key] = true end end - elseif node.op.op == "@index" then - node.type = type_check_index(node.e1, node.e2, a, b) - elseif node.op.op == "as" then - node.type = b - elseif node.op.op == "is" then + + return t + end + + if node.op.op == "@index" then + return type_check_index(node.e1, node.e2, a, b) + end + + if node.op.op == "as" then + return b + end + + if node.op.op == "is" then if rb.typename == "integer" then all_needs_compat["math"] = true end if ra.typename == "typetype" then - node_error(node, "can only use 'is' on variables, not types") + error_at(node, "can only use 'is' on variables, not types") elseif node.e1.kind == "variable" then check_metamethod(node, "__is", ra, resolve_typetype(rb), orig_a, orig_b) node.known = IsFact({ var = node.e1.tk, typ = b, where = node }) else - node_error(node, "can only use 'is' on variables") + error_at(node, "can only use 'is' on variables") end - node.type = BOOLEAN - elseif node.op.op == ":" then + return BOOLEAN + end + + if node.op.op == ":" then + node.receiver = a + if lax and (is_unknown(a) or a.typename == "typevar") then if node.e1.kind == "variable" then add_unknown_dot(node.e1, node.e1.tk .. "." .. node.e2.tk) end - node.type = UNKNOWN - else - local t, e = match_record_key(a, node.e1, node.e2.conststr or node.e2.tk) - if not t then - node.type = INVALID - return node_error(node.e2, e, a == INVALID and a or resolve_tuple(orig_a)) - end - node.type = t + return UNKNOWN + end + + local t, e = match_record_key(a, node.e1, node.e2.conststr or node.e2.tk) + if not t then + return invalid_at(node.e2, e, a == INVALID and a or resolve_tuple(orig_a)) end - elseif node.op.op == "not" then + + return t + end + + if node.op.op == "not" then node.known = facts_not(node, node.e1.known) - node.type = BOOLEAN - elseif node.op.op == "and" then + return BOOLEAN + end + + if node.op.op == "and" then node.known = facts_and(node, node.e1.known, node.e2.known) - node.type = resolve_tuple(b) - elseif node.op.op == "or" and b.typename == "nil" then - node.known = nil - node.type = resolve_tuple(a) - elseif node.op.op == "or" and is_lua_table_type(ra) and b.typename == "emptytable" then - node.known = nil - node.type = resolve_tuple(a) - elseif node.op.op == "or" and - ((ra.typename == "enum" and rb.typename == "string" and is_a(rb, ra)) or - (ra.typename == "string" and rb.typename == "enum" and is_a(ra, rb))) then - node.known = nil - node.type = (ra.typename == "enum" and ra or rb) - elseif node.op.op == "or" and expected and expected.typename == "union" then - - node.known = facts_or(node, node.e1.known, node.e2.known) - local u = unite({ ra, rb }, true) - if u.typename == "union" then - u = validate_union(node, u) - end - node.type = u - elseif node.op.op == "or" and is_a(rb, ra) then - node.known = facts_or(node, node.e1.known, node.e2.known) - if expected then - local a_is = is_a(a, node.expected) - local b_is = is_a(b, node.expected) - if a_is and b_is then - node.type = resolve_typevars_at(node, node.expected) - elseif a_is then - node.type = resolve_tuple(b) + return discard_tuple(node, b, b) + end + + if node.op.op == "or" then + local t + if b.typename == "nil" then + node.known = nil + t = a + + elseif is_lua_table_type(ra) and b.typename == "emptytable" then + node.known = nil + t = a + + elseif ((ra.typename == "enum" and rb.typename == "string" and is_a(rb, ra)) or + (ra.typename == "string" and rb.typename == "enum" and is_a(ra, rb))) then + node.known = nil + t = (ra.typename == "enum" and ra or rb) + + elseif expected and expected.typename == "union" then + + node.known = facts_or(node, node.e1.known, node.e2.known) + local u = unite({ ra, rb }, true) + if u.typename == "union" then + u = validate_union(node, u) + end + t = u + + elseif is_a(rb, ra) then + node.known = facts_or(node, node.e1.known, node.e2.known) + if expected then + local a_is = is_a(a, node.expected) + local b_is = is_a(b, node.expected) + if a_is and b_is then + t = resolve_typevars_at(node, node.expected) + elseif a_is then + t = resolve_tuple(b) + else + t = resolve_tuple(a) + end else - node.type = resolve_tuple(a) + t = resolve_tuple(a) end - else - node.type = resolve_tuple(a) + t.tk = nil end - node.type.tk = nil - elseif node.op.op == "==" or node.op.op == "~=" then - node.type = BOOLEAN + if t then + return discard_tuple(node, t, b) + end + + end + + if node.op.op == "==" or node.op.op == "~=" then @@ -10702,33 +10983,39 @@ tl.type_check = function(ast, opts) node.known = EqFact({ var = node.e2.tk, typ = a, where = node }) end elseif lax and (is_unknown(a) or is_unknown(b)) then - node.type = UNKNOWN + return UNKNOWN else - return node_error(node, "types are not comparable for equality: %s and %s", a, b) + return invalid_at(node, "types are not comparable for equality: %s and %s", a, b) end - elseif node.op.arity == 1 and unop_types[node.op.op] then + + return BOOLEAN + end + + if node.op.arity == 1 and unop_types[node.op.op] then a = ra if a.typename == "union" then a = unite(a.types, true) end local types_op = unop_types[node.op.op] - node.type = types_op[a.typename] - if not node.type then - node.type = find_in_interface_list(a, function(t) - return types_op[t.typename] + local t = types_op[a.typename] + + if not t then + t = find_in_interface_list(a, function(ty) + return types_op[ty.typename] end) end local meta_on_operator - if not node.type then + if not t then local mt_name = unop_to_metamethod[node.op.op] if mt_name then - node.type, meta_on_operator = check_metamethod(node, mt_name, a, nil, orig_a, nil) + t, meta_on_operator = check_metamethod(node, mt_name, a, nil, orig_a, nil) end - if not node.type then - node_error(node, "cannot use operator '" .. node.op.op:gsub("%%", "%%%%") .. "' on type %s", resolve_tuple(orig_a)) + if not t then + error_at(node, "cannot use operator '" .. node.op.op:gsub("%%", "%%%%") .. "' on type %s", resolve_tuple(orig_a)) + t = INVALID end end @@ -10736,11 +11023,11 @@ tl.type_check = function(ast, opts) if a.keys.typename == "number" or a.keys.typename == "integer" then add_warning("hint", node, "using the '#' operator on a map with numeric key type may produce unexpected results") else - node_error(node, "using the '#' operator on this map will always return 0") + error_at(node, "using the '#' operator on this map will always return 0") end end - if node.type.typename ~= "boolean" and not is_unknown(node.type) then + if t.typename ~= "boolean" and not is_unknown(t) then node.known = FACT_TRUTHY end @@ -10754,7 +11041,10 @@ tl.type_check = function(ast, opts) end end - elseif node.op.arity == 2 and binop_types[node.op.op] then + return t + end + + if node.op.arity == 2 and binop_types[node.op.op] then if node.op.op == "or" then node.known = facts_or(node, node.e1.known, node.e2.known) end @@ -10770,15 +11060,18 @@ tl.type_check = function(ast, opts) end local types_op = binop_types[node.op.op] - node.type = types_op[a.typename] and types_op[a.typename][b.typename] + + local t = types_op[a.typename] and types_op[a.typename][b.typename] + local meta_on_operator - if not node.type then + if not t then local mt_name = binop_to_metamethod[node.op.op] if mt_name then - node.type, meta_on_operator = check_metamethod(node, mt_name, a, b, orig_a, orig_b) + t, meta_on_operator = check_metamethod(node, mt_name, a, b, orig_a, orig_b) end - if not node.type then - node_error(node, "cannot use operator '" .. node.op.op:gsub("%%", "%%%%") .. "' for types %s and %s", resolve_tuple(orig_a), resolve_tuple(orig_b)) + if not t then + error_at(node, "cannot use operator '" .. node.op.op:gsub("%%", "%%%%") .. "' for types %s and %s", resolve_tuple(orig_a), resolve_tuple(orig_b)) + t = INVALID if node.op.op == "or" and is_valid_union(unite({ orig_a, orig_b })) then add_warning("hint", node, "if a union type was intended, consider declaring it explicitly") end @@ -10787,9 +11080,9 @@ tl.type_check = function(ast, opts) if orig_a.typename == "nominal" and orig_b.typename == "nominal" and not meta_on_operator then if is_a(orig_a, orig_b) then - node.type = resolve_tuple(orig_a) + t = resolve_tuple(orig_a) else - node_error(node, "cannot use operator '" .. node.op.op:gsub("%%", "%%%%") .. "' for distinct nominal types %s and %s", resolve_tuple(orig_a), resolve_tuple(orig_b)) + error_at(node, "cannot use operator '" .. node.op.op:gsub("%%", "%%%%") .. "' for distinct nominal types %s and %s", resolve_tuple(orig_a), resolve_tuple(orig_b)) end end @@ -10814,10 +11107,11 @@ tl.type_check = function(ast, opts) convert_node_to_compat_call(node, "bit32", bit_operators[node.op.op], node.e1, node.e2) end end - else - error("unknown node op " .. node.op.op) + + return t end - return node.type + + error("unknown node op " .. node.op.op) end, }, ["variable"] = { @@ -10825,48 +11119,54 @@ tl.type_check = function(ast, opts) if node.tk == "..." then local va_sentinel = find_var_type("@is_va") if not va_sentinel or va_sentinel.typename == "nil" then - return node_error(node, "cannot use '...' outside a vararg function") + return invalid_at(node, "cannot use '...' outside a vararg function") end end + + local t if node.tk == "_G" then - node.type, node.attribute = simulate_g() + t, node.attribute = simulate_g() else local use = node.is_lvalue and "lvalue" or "use" - node.type, node.attribute = find_var_type(node.tk, use) + t, node.attribute = find_var_type(node.tk, use) + end + if not t then + if lax then + add_unknown(node, node.tk) + return UNKNOWN + end + + return invalid_at(node, "unknown variable: " .. node.tk) end - if node.type and is_typetype(node.type) then - node.type = a_type({ + + if is_typetype(t) then + t = a_type({ y = node.y, x = node.x, typename = "nominal", names = { node.tk }, - found = node.type, - resolved = node.type, + found = t, + resolved = t, }) end - if node.type == nil then - node.type = a_type({ typename = "unknown" }) - if lax then - add_unknown(node, node.tk) - else - return node_error(node, "unknown variable: " .. node.tk) - end - end - return node.type + + return t end, }, ["type_identifier"] = { after = function(node, _children) - node.type, node.attribute = find_var_type(node.tk) - if node.type == nil then - if lax then - node.type = UNKNOWN - add_unknown(node, node.tk) - else - return node_error(node, "unknown variable: " .. node.tk) - end + local typ, attr = find_var_type(node.tk) + node.attribute = attr + if typ then + return typ + end + + if lax then + add_unknown(node, node.tk) + return UNKNOWN end - return node.type + + return invalid_at(node, "unknown variable: " .. node.tk) end, }, ["argument"] = { @@ -10882,57 +11182,52 @@ tl.type_check = function(ast, opts) t = OPT(t) end add_var(node, node.tk, t).is_func_arg = true - return node.type + return t end, }, ["identifier"] = { - after = function(node, _children) - node.type = node.type or NONE - return node.type + after = function(_node, _children) + return NONE end, }, ["newtype"] = { after = function(node, _children) - node.type = node.type or node.newtype - return node.type + return node.newtype end, }, ["error_node"] = { - after = function(node, _children) - node.type = INVALID - return node.type + after = function(_node, _children) + return INVALID end, }, } visit_node.cbs["break"] = { - after = function(node, _children) - node.type = NONE - return node.type + after = function(_node, _children) + return NONE end, } visit_node.cbs["do"] = visit_node.cbs["break"] local function after_literal(node) - node.type = a_type({ + node.known = FACT_TRUTHY + return a_type({ y = node.y, x = node.x, typename = node.kind, tk = node.tk, }) - node.known = FACT_TRUTHY - return node.type end visit_node.cbs["string"] = { after = function(node, _children) - after_literal(node) + local t = after_literal(node) if node.expected then - if node.expected.typename == "enum" and is_a(node.type, node.expected) then - node.type = node.expected + if node.expected.typename == "enum" and is_a(t, node.expected) then + t = node.expected end end - return node.type + return t end, } visit_node.cbs["number"] = { after = after_literal } @@ -10940,9 +11235,9 @@ tl.type_check = function(ast, opts) visit_node.cbs["boolean"] = { after = function(node, _children) - after_literal(node) + local t = after_literal(node) node.known = (node.tk == "true") and FACT_TRUTHY or nil - return node.type + return t end, } visit_node.cbs["nil"] = visit_node.cbs["boolean"] @@ -10951,18 +11246,12 @@ tl.type_check = function(ast, opts) visit_node.cbs["argument_list"] = visit_node.cbs["variable_list"] visit_node.cbs["expression_list"] = visit_node.cbs["variable_list"] - visit_node.after = function(node, _children) + visit_node.after = function(node, _children, t) if node.expanded then apply_macroexp(node) end - if type(node.type) ~= "table" then - error(node.kind .. " did not produce a type") - end - if type(node.type.typename) ~= "string" then - error(node.kind .. " type does not have a typename") - end - return node.type + return t end local visit_type @@ -10982,12 +11271,12 @@ tl.type_check = function(ast, opts) typ = ensure_fresh_typeargs(typ) if typ.macroexp then - recurse_node(typ.macroexp, visit_node, visit_type) + local macroexp_type = recurse_node(typ.macroexp, visit_node, visit_type) check_macroexp_arg_use(typ.macroexp) - if not is_a(typ.macroexp.type, typ) then - error_at(typ.macroexp.type, "macroexp type does not match declaration") + if not is_a(macroexp_type, typ) then + error_at(macroexp_type, "macroexp type does not match declaration") end end @@ -11118,27 +11407,45 @@ tl.type_check = function(ast, opts) end, }, }, - after = function(typ, _children, ret) - if type(ret) ~= "table" then - error(typ.typename .. " did not produce a type") + } + + local function internal_compiler_check(fn) + return function(w, children, t) + t = fn and fn(w, children, t) or t + + if type(t) ~= "table" then + error(((w).kind or (w).typename) .. " did not produce a type") end - if type(ret.typename) ~= "string" then - error("type node does not have a typename") + if type(t.typename) ~= "string" then + error(((w).kind or (w).typename) .. " type does not have a typename") end - return ret - end, - } - if not opts.run_internal_compiler_checks then - visit_node.after = function(node, _children) - if node.expanded then - apply_macroexp(node) + return t + end + end + + local function store_type_after(fn) + return function(w, children, t) + t = fn and fn(w, children, t) or t + + local where = w + + if where.y then + store_type(where.y, where.x, t) end - return node.type + return t end + end - visit_type.after = nil + if opts.run_internal_compiler_checks then + visit_node.after = internal_compiler_check(visit_node.after) + visit_type.after = internal_compiler_check(visit_type.after) + end + + if store_type then + visit_node.after = store_type_after(visit_node.after) + visit_type.after = store_type_after(visit_type.after) end visit_type.cbs["tupletable"] = visit_type.cbs["string"] @@ -11200,193 +11507,15 @@ end -local typename_to_typecode = { - ["typevar"] = tl.typecodes.TYPE_VARIABLE, - ["typearg"] = tl.typecodes.TYPE_VARIABLE, - ["unresolved_typearg"] = tl.typecodes.TYPE_VARIABLE, - ["unresolvable_typearg"] = tl.typecodes.TYPE_VARIABLE, - ["function"] = tl.typecodes.FUNCTION, - ["array"] = tl.typecodes.ARRAY, - ["map"] = tl.typecodes.MAP, - ["tupletable"] = tl.typecodes.TUPLE, - ["interface"] = tl.typecodes.INTERFACE, - ["record"] = tl.typecodes.RECORD, - ["enum"] = tl.typecodes.ENUM, - ["boolean"] = tl.typecodes.BOOLEAN, - ["string"] = tl.typecodes.STRING, - ["nil"] = tl.typecodes.NIL, - ["thread"] = tl.typecodes.THREAD, - ["number"] = tl.typecodes.NUMBER, - ["integer"] = tl.typecodes.INTEGER, - ["union"] = tl.typecodes.IS_UNION, - ["nominal"] = tl.typecodes.NOMINAL, - ["bad_nominal"] = tl.typecodes.NOMINAL, - ["circular_require"] = tl.typecodes.NOMINAL, - ["emptytable"] = tl.typecodes.EMPTY_TABLE, - ["unresolved_emptytable_value"] = tl.typecodes.EMPTY_TABLE, - ["poly"] = tl.typecodes.IS_POLY, - ["any"] = tl.typecodes.ANY, - ["unknown"] = tl.typecodes.UNKNOWN, - ["invalid"] = tl.typecodes.INVALID, - - ["none"] = tl.typecodes.UNKNOWN, - ["tuple"] = tl.typecodes.UNKNOWN, - ["table_item"] = tl.typecodes.UNKNOWN, - ["unresolved"] = tl.typecodes.UNKNOWN, - ["typetype"] = tl.typecodes.UNKNOWN, - ["nestedtype"] = tl.typecodes.UNKNOWN, -} - -local skip_types = { - ["none"] = true, - ["tuple"] = true, - ["table_item"] = true, - ["unresolved"] = true, - ["typetype"] = true, - ["nestedtype"] = true, -} - function tl.get_types(result, trenv) local filename = result.filename or "?" - - local function mark_array(x) - local arr = x - arr[0] = false - return x - end + trenv = trenv or result.env.trenv if not trenv then - trenv = { - next_num = 1, - typeid_to_num = {}, - tr = { - by_pos = {}, - types = {}, - symbols_by_file = {}, - globals = {}, - }, - } + error("result must have been generated with env.report_types = true", 2) end local tr = trenv.tr - local typeid_to_num = trenv.typeid_to_num - - local get_typenum - - local function store_function(ti, rt) - local args = {} - for _, fnarg in ipairs(rt.args) do - table.insert(args, mark_array({ get_typenum(fnarg), nil })) - end - ti.args = mark_array(args) - local rets = {} - for _, fnarg in ipairs(rt.rets) do - table.insert(rets, mark_array({ get_typenum(fnarg), nil })) - end - ti.rets = mark_array(rets) - ti.vararg = not not rt.is_va - end - - get_typenum = function(t) - assert(t.typeid) - - local n = typeid_to_num[t.typeid] - if n then - return n - end - - - n = trenv.next_num - - local rt = t - if is_typetype(rt) then - rt = rt.def - elseif rt.typename == "tuple" and #rt == 1 then - rt = rt[1] - end - - local ti = { - t = assert(typename_to_typecode[rt.typename]), - str = show_type(t, true), - file = t.filename, - y = t.y, - x = t.x, - } - tr.types[n] = ti - typeid_to_num[t.typeid] = n - trenv.next_num = trenv.next_num + 1 - - if t.found then - ti.ref = get_typenum(t.found) - end - if t.resolved then - rt = t - end - assert(not is_typetype(rt)) - - if is_record_type(rt) then - - local r = {} - for _, k in ipairs(rt.field_order) do - local v = rt.fields[k] - r[k] = get_typenum(v) - end - ti.fields = r - end - - if is_array_type(rt) then - ti.elements = get_typenum(rt.elements) - end - - if rt.typename == "map" then - ti.keys = get_typenum(rt.keys) - ti.values = get_typenum(rt.values) - elseif rt.typename == "enum" then - ti.enums = mark_array(sorted_keys(rt.enumset)) - elseif rt.typename == "function" then - store_function(ti, rt) - elseif rt.typename == "poly" or rt.typename == "union" or rt.typename == "tupletable" then - local tis = {} - - for _, pt in ipairs(rt.types) do - table.insert(tis, get_typenum(pt)) - end - - ti.types = mark_array(tis) - end - - return n - end - - local visit_node = { allow_missing_cbs = true } - local visit_type = { allow_missing_cbs = true } - - local ft = {} - tr.by_pos[filename] = ft - - local function store(y, x, typ) - if not typ or skip_types[typ.typename] then - return - end - - local yt = ft[y] - if not yt then - yt = {} - ft[y] = yt - end - - yt[x] = get_typenum(typ) - end - - visit_node.after = function(node) - store(node.y, node.x, node.type) - end - - visit_type.after = function(typ) - store(typ.y or 0, typ.x or 0, typ) - end - - recurse_node(result.ast, visit_node, visit_type) tr.by_pos[filename][0] = nil @@ -11427,7 +11556,7 @@ function tl.get_types(result, trenv) i = i + 1 local id if s.typ then - id = get_typenum(s.typ) + id = get_typenum(trenv, s.typ) elseif s.name == "@{" then level = level + 1 stack[level] = i @@ -11448,7 +11577,7 @@ function tl.get_types(result, trenv) for _, name in ipairs(gkeys) do if name:sub(1, 1) ~= "@" then local var = result.env.globals[name] - tr.globals[name] = get_typenum(var.t) + tr.globals[name] = get_typenum(trenv, var.t) end end diff --git a/tl.tl b/tl.tl index 05c5f7ccc..ab516b4e5 100644 --- a/tl.tl +++ b/tl.tl @@ -45,9 +45,11 @@ local record tl modules: {string:Type} loaded: {string:Result} loaded_order: {string} + trenv: TypeReportEnv gen_compat: CompatMode gen_target: TargetMode keep_going: boolean + report_types: boolean end record Symbol @@ -252,7 +254,8 @@ if TL_DEBUG then return end - io.stderr:write(info.name or "", info.currentline > 0 and "@" .. info.currentline or "", " :: ", event as string, "\n") + local name = info.name or "", info.currentline > 0 and "@" .. info.currentline or "" + io.stderr:write(name, " :: ", event as string, "\n") io.stderr:flush() else count = count + 100 @@ -1416,6 +1419,8 @@ local record Node constnum: number conststr: string failstore: boolean + discarded_tuple: boolean + receiver: Type -- table literal array_len: integer @@ -1431,7 +1436,6 @@ local record Node -- macroexp expanded: Node - type: Type decltype: Type opt: boolean end @@ -3922,7 +3926,7 @@ local function recurse_node(root: Node, end if TL_DEBUG then - tl_debug_indent_pop("}}}", "***", ast.y, ast.x, "[%s] = %s", kprint, ast.type and show_type(ast.type)) + tl_debug_indent_pop("}}}", "***", ast.y, ast.x, "[%s]", kprint) end return ret @@ -4695,6 +4699,198 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | return (concat_output(code):gsub(" *\n", "\n")) end +-------------------------------------------------------------------------------- +-- Type collection for report +-------------------------------------------------------------------------------- + +local typename_to_typecode : {TypeName:integer} = { + ["typevar"] = tl.typecodes.TYPE_VARIABLE, + ["typearg"] = tl.typecodes.TYPE_VARIABLE, + ["unresolved_typearg"] = tl.typecodes.TYPE_VARIABLE, + ["unresolvable_typearg"] = tl.typecodes.TYPE_VARIABLE, + ["function"] = tl.typecodes.FUNCTION, + ["array"] = tl.typecodes.ARRAY, + ["map"] = tl.typecodes.MAP, + ["tupletable"] = tl.typecodes.TUPLE, + ["interface"] = tl.typecodes.INTERFACE, + ["record"] = tl.typecodes.RECORD, + ["enum"] = tl.typecodes.ENUM, + ["boolean"] = tl.typecodes.BOOLEAN, + ["string"] = tl.typecodes.STRING, + ["nil"] = tl.typecodes.NIL, + ["thread"] = tl.typecodes.THREAD, + ["number"] = tl.typecodes.NUMBER, + ["integer"] = tl.typecodes.INTEGER, + ["union"] = tl.typecodes.IS_UNION, + ["nominal"] = tl.typecodes.NOMINAL, + ["bad_nominal"] = tl.typecodes.NOMINAL, + ["circular_require"] = tl.typecodes.NOMINAL, + ["emptytable"] = tl.typecodes.EMPTY_TABLE, + ["unresolved_emptytable_value"] = tl.typecodes.EMPTY_TABLE, + ["poly"] = tl.typecodes.IS_POLY, + ["any"] = tl.typecodes.ANY, + ["unknown"] = tl.typecodes.UNKNOWN, + ["invalid"] = tl.typecodes.INVALID, + -- types that should be skipped or not present: + ["none"] = tl.typecodes.UNKNOWN, + ["tuple"] = tl.typecodes.UNKNOWN, + ["table_item"] = tl.typecodes.UNKNOWN, + ["unresolved"] = tl.typecodes.UNKNOWN, + ["typetype"] = tl.typecodes.UNKNOWN, + ["nestedtype"] = tl.typecodes.UNKNOWN, +} + +local skip_types: {TypeName: boolean} = { + ["none"] = true, + ["tuple"] = true, + ["table_item"] = true, + ["unresolved"] = true, + ["typetype"] = true, + ["nestedtype"] = true, +} + +local get_typenum: function(trenv: TypeReportEnv, t: Type): integer +local type StoreType = function(y: integer, x: integer, typ: Type) + +local function sorted_keys(m: {A:B}):{A} + local keys = {} + for k, _ in pairs(m) do + table.insert(keys, k) + end + table.sort(keys) + return keys +end + +-- mark array for JSON-encoded reports in `tl types` +local function mark_array(x: T): T + local arr = x as {boolean} + arr[0] = false + return x +end + +function tl.init_type_report(): TypeReportEnv + return { + next_num = 1, + typeid_to_num = {}, + tr = { + by_pos = {}, + types = {}, + symbols_by_file = {}, + globals = {}, + }, + } +end + +local function store_function(trenv: TypeReportEnv, ti: TypeInfo, rt: Type) + local args: {{integer, string}} = {} + for _, fnarg in ipairs(rt.args) do + table.insert(args, mark_array { get_typenum(trenv, fnarg), nil }) + end + ti.args = mark_array(args) + local rets: {{integer, string}} = {} + for _, fnarg in ipairs(rt.rets) do + table.insert(rets, mark_array { get_typenum(trenv, fnarg), nil }) + end + ti.rets = mark_array(rets) + ti.vararg = not not rt.is_va +end + +get_typenum = function(trenv:TypeReportEnv, t: Type): integer + assert(t.typeid) + -- try by typeid + local n = trenv.typeid_to_num[t.typeid] + if n then + return n + end + + local tr = trenv.tr + + -- it's a new entry: store and increment + n = trenv.next_num + + local rt = t + if is_typetype(rt) then + rt = rt.def + elseif rt.typename == "tuple" and #rt == 1 then + rt = rt[1] + end + + local ti: TypeInfo = { + t = assert(typename_to_typecode[rt.typename]), + str = show_type(t, true), + file = t.filename, + y = t.y, + x = t.x, + } + tr.types[n] = ti + trenv.typeid_to_num[t.typeid] = n + trenv.next_num = trenv.next_num + 1 + + if t.found then + ti.ref = get_typenum(trenv, t.found) + end + if t.resolved then + rt = t + end + assert(not is_typetype(rt)) + + if is_record_type(rt) then + -- store record field info + local r = {} + for _, k in ipairs(rt.field_order) do + local v = rt.fields[k] + r[k] = get_typenum(trenv, v) + end + ti.fields = r + end + + if is_array_type(rt) then + ti.elements = get_typenum(trenv, rt.elements) + end + + if rt.typename == "map" then + ti.keys = get_typenum(trenv, rt.keys) + ti.values = get_typenum(trenv, rt.values) + elseif rt.typename == "enum" then + ti.enums = mark_array(sorted_keys(rt.enumset)) + elseif rt.typename == "function" then + store_function(trenv, ti, rt) + elseif rt.typename == "poly" or rt.typename == "union" or rt.typename == "tupletable" then + local tis = {} + + for _, pt in ipairs(rt.types) do + table.insert(tis, get_typenum(trenv, pt)) + end + + ti.types = mark_array(tis) + end + + return n +end + +local function make_type_reporter(filename: string, trenv: TypeReportEnv): StoreType +-- local filename = result.filename or "?" + + local ft: {integer:{integer:integer}} = {} + trenv.tr.by_pos[filename] = ft + + local function store_type(y: integer, x: integer, typ: Type) + if not typ or skip_types[typ.typename] then + return + end + + local yt = ft[y] + if not yt then + yt = {} + ft[y] = yt + end + + yt[x] = get_typenum(trenv, typ) + end + + return store_type +end + -------------------------------------------------------------------------------- -- Type check -------------------------------------------------------------------------------- @@ -5226,15 +5422,6 @@ local record Variable implemented: {string:boolean} end -local function sorted_keys(m: {A:B}):{A} - local keys = {} - for k, _ in pairs(m) do - table.insert(keys, k) - end - table.sort(keys) - return keys -end - local function require_module(module_name: string, lax: boolean, env: Env): Type, boolean local mod = env.modules[module_name] if mod then @@ -6050,9 +6237,6 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local type Scope = {string:Variable} local st: {Scope} = { env.globals } - local symbol_list: {Symbol} = {} - local symbol_list_n = 0 - local all_needs_compat = {} local dependencies: {string:string} = {} @@ -6061,6 +6245,15 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local module_type: Type + local symbol_list: {Symbol} + local symbol_list_n = 0 + local store_type: StoreType + if env.report_types then + symbol_list = {} + env.trenv = env.trenv or tl.init_type_report() + store_type = make_type_reporter(filename or "?", env.trenv) + end + local enum VarUse "use" "lvalue" @@ -6165,16 +6358,19 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end msg = msg:format(table.unpack(showt)) end + local name = where.filename or filename return { y = where.y, x = where.x, msg = msg, - filename = where.filename or filename, + filename = name, } end local function error_at(w: Where, msg: string, ...:Type): boolean + assert(w.y) + local e = Err(w, msg, ...) if e then table.insert(errors, e) @@ -6321,7 +6517,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if f.min_arity then return end - local tuple = f.args.tuple + local tuple = f.args local n = #tuple if f.args.is_va then n = n - 1 @@ -6556,10 +6752,9 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string }) end - local function node_error(node: Node, msg: string, ...:Type): Type - error_at(node, msg, ...) - node.type = INVALID - return node.type + local function invalid_at(where: Where, msg: string, ...:Type): Type + error_at(where, msg, ...) + return INVALID end local function add_unknown(node: Node, name: string) @@ -6730,8 +6925,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local var = add_to_scope(node, name, t, attribute, narrow, dont_check_redeclaration) - if node and t.typename ~= "unresolved" and t.typename ~= "none" then - node.type = node.type or t + if symbol_list and node and t.typename ~= "unresolved" and t.typename ~= "none" then local slot: integer if node.symbol_list_slot then slot = node.symbol_list_slot @@ -6948,7 +7142,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local function begin_scope(node?: Node) table.insert(st, {}) - if node then + if symbol_list and node then symbol_list_n = symbol_list_n + 1 symbol_list[symbol_list_n] = { y = node.y, x = node.x, name = "@{" } end @@ -6985,7 +7179,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string check_for_unused_vars(scope) table.remove(st) - if node then + if symbol_list and node then if symbol_list[symbol_list_n].name == "@{" then symbol_list[symbol_list_n] = nil symbol_list_n = symbol_list_n - 1 @@ -6998,8 +7192,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local end_scope_and_none_type = function(node: Node, _children: {Type}): Type end_scope(node) - node.type = NONE - return node.type + return NONE end local resolve_nominal: function(t: Type): Type @@ -7015,7 +7208,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string for i, tt in ipairs(t.typevals) do add_var(nil, def.typeargs[i].typearg, tt) end - local ret = resolve_typevars_at(t as Node, def) + local ret = resolve_typevars_at(t, def) end_scope() return ret elseif t.typevals then @@ -7121,7 +7314,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local all_errs = {} for i = 1, #t1.typevals do local _, errs = same_type(t1.typevals[i], t2.typevals[i]) - add_errs_prefixing(t1 as Node, errs, all_errs, "type parameter <" .. show_type(t2.typevals[i]) .. ">: ") + add_errs_prefixing(t1, errs, all_errs, "type parameter <" .. show_type(t2.typevals[i]) .. ">: ") end if #all_errs == 0 then return true @@ -7196,7 +7389,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string for i = 1, math.min(#t1.types, #t2.types) do local ok, err = same_type(t1.types[i], t2.types[i]) if not ok then - add_errs_prefixing(t1 as Node, err, all_errs, "values") + add_errs_prefixing(t1, err, all_errs, "values") end end return any_errors(all_errs) @@ -7204,11 +7397,11 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local all_errs = {} local k_ok, k_errs = same_type(t1.keys, t2.keys) if not k_ok then - add_errs_prefixing(t1 as Node, k_errs, all_errs, "keys") + add_errs_prefixing(t1, k_errs, all_errs, "keys") end local v_ok, v_errs = same_type(t1.values, t2.values) if not v_ok then - add_errs_prefixing(t1 as Node, v_errs, all_errs, "values") + add_errs_prefixing(t1, v_errs, all_errs, "values") end return any_errors(all_errs) elseif t1.typename == "union" then @@ -7250,7 +7443,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end for i = 1, #t1.rets do local _, errs = same_type(t1.rets[i], t2.rets[i]) - add_errs_prefixing(t1 as Node, errs, all_errs, "return " .. i) + add_errs_prefixing(t1, errs, all_errs, "return " .. i) end return any_errors(all_errs) end @@ -7428,6 +7621,10 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return true elseif is_self(t1) then + if is_self(t2) then + return true + end + return is_a(resolve_tuple_and_nominal(t1), t2, for_equality) elseif is_self(t2) then @@ -7915,7 +8112,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local on_arg_id = function(node: Node, _i: integer): {Node, Node} if used[node.tk] then - node_error(node, "cannot use argument '" .. node.tk .. "' multiple times in macroexp") + error_at(node, "cannot use argument '" .. node.tk .. "' multiple times in macroexp") else used[node.tk] = true end @@ -7926,7 +8123,6 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local function apply_macroexp(orignode: Node) local expanded = orignode.expanded - local savetype = orignode.type local saveknown = orignode.known orignode.expanded = nil @@ -7936,11 +8132,10 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string for k, v in pairs(expanded as {any:any}) do (orignode as {any:any})[k] = v end - orignode.type = savetype orignode.known = saveknown end - local type_check_function_call: function(Node, {Node}, Type, {Type}, Node, boolean, ? integer): Type + local type_check_function_call: function(Node, {Node}, Type, {Type}, Node, boolean, ? integer): Type, Type do local function mark_invalid_typeargs(f: Type) if f.typeargs then @@ -8055,7 +8250,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - local function fail_call(node: Node, func: Type, nargs: integer, errs: {Error}): Type + local function fail_call(where: Where, func: Type, nargs: integer, errs: {Error}): Type if errs then -- report the errors from the first match for _, err in ipairs(errs) do @@ -8077,15 +8272,15 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string else table.insert(expects, show_arity(func)) end - node_error(node, "wrong number of arguments (given " .. nargs .. ", expects " .. table.concat(expects, " or ") .. ")") + error_at(where, "wrong number of arguments (given " .. nargs .. ", expects " .. table.concat(expects, " or ") .. ")") end local f = func.typename == "poly" and func.types[1] or func mark_invalid_typeargs(f) - return resolve_typevars_at(node, f.rets) + return resolve_typevars_at(where, f.rets) end - local function check_call(node: Node, where_args: {Node}, func: Type, args: {Type}, is_method: boolean, argdelta: integer): Type, Type + local function check_call(where: Where, where_args: {Node}, func: Type, args: {Type}, expected: Type, typetype_funcall: boolean, is_method: boolean, argdelta: integer): Type, Type assert(type(func) == "table") assert(type(args) == "table") @@ -8096,13 +8291,13 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string argdelta = is_method and -1 or argdelta or 0 if is_method and args[1] then - add_var(nil, "@self", a_type({ typename = "typetype", y = node.y, x = node.x, def = args[1] })) + add_var(nil, "@self", a_type({ typename = "typetype", y = where.y, x = where.x, def = args[1] })) end local is_func = func.typename == "function" local is_poly = func.typename == "poly" if not (is_func or is_poly) then - return node_error(node, "not a function: %s", func) + return invalid_at(where, "not a function: %s", func) end local passes, n = 1, 1 @@ -8119,41 +8314,38 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local f = is_func and func or func.types[i] if f.is_method and not is_method then if args[1] and is_a(args[1], f.args[1]) then - -- a non-"@funcall" `node` means a synthesized call, e.g. from a metamethod - if node.kind == "op" and node.op.op == "@funcall" then - local receiver_is_typetype = node.e1.e1 and node.e1.e1.type and node.e1.e1.type.resolved and node.e1.e1.type.resolved.typename == "typetype" - if not receiver_is_typetype then - add_warning("hint", node, "invoked method as a regular function: consider using ':' instead of '.'") - end + -- a non-"@funcall" means a synthesized call, e.g. from a metamethod + if not typetype_funcall then + add_warning("hint", where, "invoked method as a regular function: consider using ':' instead of '.'") end else - return node_error(node, "invoked method as a regular function: use ':' instead of '.'") + return invalid_at(where, "invoked method as a regular function: use ':' instead of '.'") end end - local expected = #f.args + local wanted = #f.args set_min_arity(f) -- simple functions: - if (is_func and ((given <= expected and given >= f.min_arity) or (f.args.is_va and given > expected) or (lax and given <= expected))) + if (is_func and ((given <= wanted and given >= f.min_arity) or (f.args.is_va and given > wanted) or (lax and given <= wanted))) -- poly, pass 1: try exact arity matches first - or (is_poly and ((pass == 1 and given == expected) + or (is_poly and ((pass == 1 and given == wanted) -- poly, pass 2: then try adjusting with nils to missing arguments or using '...' - or (pass == 2 and given < expected and (lax or given >= f.min_arity)) + or (pass == 2 and given < wanted and (lax or given >= f.min_arity)) -- poly, pass 3: then finally try vararg functions - or (pass == 3 and f.args.is_va and given > expected))) + or (pass == 3 and f.args.is_va and given > wanted))) then push_typeargs(f) - local matched, errs = check_args_rets(node, where_args, f, args, node.expected, argdelta) + local matched, errs = check_args_rets(where, where_args, f, args, expected, argdelta) if matched then -- success! return matched, f end first_errs = first_errs or errs - if node.expected then + if expected then -- revert inferred returns - infer_emptytables(node, where_args, f.rets, f.rets, argdelta) + infer_emptytables(where, where_args, f.rets, f.rets, argdelta) end if is_poly then @@ -8166,27 +8358,38 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - return fail_call(node, func, given, first_errs) + return fail_call(where, func, given, first_errs) end - type_check_function_call = function(node: Node, where_args: {Node}, func: Type, args: {Type}, e1: Node, is_method: boolean, argdelta?: integer): Type + type_check_function_call = function(node: Node, where_args: {Node}, func: Type, args: {Type}, e1: Node, is_method: boolean, argdelta?: integer): Type, Type if node.expected and node.expected.typename ~= "tuple" then node.expected = a_type { typename = "tuple", node.expected } end begin_scope() - local ret, f = check_call(node, where_args, func, args, is_method, argdelta) + + local typetype_funcall = not not ( + node.kind == "op" + and node.op.op == "@funcall" + and node.e1 + and node.e1.receiver + and node.e1.receiver.resolved + and node.e1.receiver.resolved.typename == "typetype" + ) + + local ret, f = check_call(node, where_args, func, args, node.expected, typetype_funcall, is_method, argdelta) ret = resolve_typevars_at(node, ret) end_scope() - if e1 then - e1.type = f + + if store_type and e1 then + store_type(e1.y, e1.x, f) end if func.macroexp then expand_macroexp(node, where_args, func.macroexp) end - return ret + return ret, f end end @@ -8371,36 +8574,25 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string add_unknown(node, var) end - local existing, scope, existing_attr = find_var(var) - if existing and scope > 1 then - node_error(node, "cannot define a global when a local with the same name is in scope") - return nil - end - local is_const = node.attribute ~= nil - + local existing, scope, existing_attr = find_var(var) if existing then - if is_assigning and existing_attr then - node_error(node, "cannot reassign to <" .. existing_attr .. "> global: " .. var) - end - if existing_attr and not is_const then - node_error(node, "global was previously declared as <" .. existing_attr .. ">: " .. var) - end - if (not existing_attr) and is_const then - node_error(node, "global was previously declared as not <" .. node.attribute .. ">: " .. var) - end - if valtype and not same_type(existing.t, valtype) then - node_error(node, "cannot redeclare global with a different type: previous type of " .. var .. " is %s", existing.t) + if scope > 1 then + error_at(node, "cannot define a global when a local with the same name is in scope") + elseif is_assigning and existing_attr then + error_at(node, "cannot reassign to <" .. existing_attr .. "> global: " .. var) + elseif existing_attr and not is_const then + error_at(node, "global was previously declared as <" .. existing_attr .. ">: " .. var) + elseif (not existing_attr) and is_const then + error_at(node, "global was previously declared as not <" .. node.attribute .. ">: " .. var) + elseif valtype and not same_type(existing.t, valtype) then + error_at(node, "cannot redeclare global with a different type: previous type of " .. var .. " is %s", existing.t) end return nil end st[1][var] = { t = valtype, attribute = is_const and "const" or nil } - if node then - node.type = node.type or valtype - end - return st[1][var] end @@ -8416,8 +8608,10 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return t end - local function add_internal_function_variables(node: Node) - add_var(nil, "@is_va", node.args.type.is_va and ANY or NIL) + local function add_internal_function_variables(node: Node, args: Type) + assert(args.typename == "tuple") + + add_var(nil, "@is_va", args.is_va and ANY or NIL) add_var(nil, "@return", node.rets or a_type { typename = "tuple" }) if node.typeargs then @@ -8430,10 +8624,13 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - local function add_function_definition_for_recursion(node: Node) - local args: Type = a_type { typename = "tuple", is_va = node.args.type.is_va } - for _, fnarg in ipairs(node.args) do - table.insert(args, fnarg.type) + local function add_function_definition_for_recursion(node: Node, fnargs: Type) + assert(fnargs.typename == "tuple") + + local args: Type = TUPLE({}) + args.is_va = fnargs.is_va + for _, fnarg in ipairs(fnargs) do + table.insert(args, fnarg) end add_var(nil, node.name.tk, a_type { @@ -8449,7 +8646,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string st[#st]["@unresolved"] = nil for name, nodes in pairs(unresolved.t.labels) do for _, node in ipairs(nodes) do - node_error(node, "no visible label '" .. name .. "' for goto") + error_at(node, "no visible label '" .. name .. "' for goto") end end for name, types in pairs(unresolved.t.nominals) do @@ -8509,15 +8706,17 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end local last = vals[#vals] - if last.typename == "tuple" then - -- ...if the last is a tuple, unpack it - is_va = last.is_va - for _, v in ipairs(last) do - table.insert(ret, v) + if last then + if last.typename == "tuple" then + -- ...if the last is a tuple, unpack it + is_va = last.is_va + for _, v in ipairs(last) do + table.insert(ret, v) + end + else + -- ...otherwise simply get it + table.insert(ret, last) end - else - -- ...otherwise simply get it - table.insert(ret, last) end -- ...if the last is vararg, repeat its type until it matches the number of wanted args @@ -8546,7 +8745,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if t then return t else - return node_error(node, errmsg) + return invalid_at(node, errmsg) end end @@ -8634,7 +8833,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return meta_t end - return node_error(bnode, errm, erra, errb) + return invalid_at(bnode, errm, erra, errb) end expand_type = function(where: Where, old: Type, new: Type): Type @@ -9119,7 +9318,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local function special_pcall_xpcall(node: Node, _a: Type, b: {Type}, argdelta: integer): Type local base_nargs = (node.e1.tk == "xpcall") and 2 or 1 if #node.e2 < base_nargs then - node_error(node, "wrong number of arguments (given " .. #node.e2 .. ", expects at least " .. base_nargs .. ")") + error_at(node, "wrong number of arguments (given " .. #node.e2 .. ", expects at least " .. base_nargs .. ")") return TUPLE { BOOLEAN } end @@ -9154,34 +9353,79 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end local special_functions: {string : function(Node,Type,{Type},integer):Type } = { + ["pairs"] = function(node: Node, a: Type, b: {Type}, argdelta: integer): Type + if not b[1] then + return invalid_at(node, "pairs requires an argument") + end + local t = resolve_tuple_and_nominal(b[1]) + if is_array_type(t) then + add_warning("hint", node, "hint: applying pairs on an array: did you intend to apply ipairs?") + end + + if t.typename ~= "map" then + if not (lax and is_unknown(t)) then + if is_record_type(t) then + match_all_record_field_names(node.e2, t, t.field_order, + "attempting pairs on a record with attributes of different types") + local ct = t.typename == "record" and "{string:any}" or "{any:any}" + add_warning("hint", node.e2, "hint: if you want to iterate over fields of a record, cast it to " .. ct) + else + error_at(node.e2, "cannot apply pairs on values of type: %s", t) + end + end + end + + return (type_check_function_call(node, node.e2, a, b, node, false, argdelta)) + end, + + ["ipairs"] = function(node: Node, a: Type, b: {Type}, argdelta: integer): Type + if not b[1] then + return invalid_at(node, "ipairs requires an argument") + end + local t = resolve_tuple_and_nominal(b[1]) + + if t.typename == "tupletable" then + local arr_type = arraytype_from_tuple(node.e2, t) + if not arr_type then + return invalid_at(node.e2, "attempting ipairs on tuple that's not a valid array: %s", t) + end + elseif not is_array_type(t) then + if not (lax and (is_unknown(t) or t.typename == "emptytable")) then + return invalid_at(node.e2, "attempting ipairs on something that's not an array: %s", t) + end + end + + return (type_check_function_call(node, node.e2, a, b, node, false, argdelta)) + end, + ["rawget"] = function(node: Node, _a: Type, b: {Type}, _argdelta: integer): Type -- TODO should those offsets be fixed by _argdelta? if #b == 2 then return type_check_index(node.e2[1], node.e2[2], b[1], b[2]) else - return node_error(node, "rawget expects two arguments") + return invalid_at(node, "rawget expects two arguments") end end, ["require"] = function(node: Node, _a: Type, b: {Type}, _argdelta: integer): Type if #b ~= 1 then - return node_error(node, "require expects one literal argument") + return invalid_at(node, "require expects one literal argument") end if node.e2[1].kind ~= "string" then - return ANY + return a_type { typename = "any" } end local module_name = assert(node.e2[1].conststr) local t, found = require_module(module_name, lax, env) if not found then - return node_error(node, "module not found: '" .. module_name .. "'") + return invalid_at(node, "module not found: '" .. module_name .. "'") end if t.typename == "invalid" then if lax then return UNKNOWN end - return node_error(node, "no type information for required module: '" .. module_name .. "'") + return invalid_at(node, "no type information for required module: '" .. module_name .. "'") end dependencies[module_name] = t.filename @@ -9206,13 +9450,13 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if special then return special(node, a, b, argdelta) else - return type_check_function_call(node, node.e2, a, b, node.e1, false, argdelta) + return (type_check_function_call(node, node.e2, a, b, node.e1, false, argdelta)) end elseif node.e1.op and node.e1.op.op == ":" then - table.insert(b, 1, node.e1.e1.type) - return type_check_function_call(node, node.e2, a, b, node.e1, true) + table.insert(b, 1, node.e1.receiver) + return (type_check_function_call(node, node.e2, a, b, node.e1, true)) else - return type_check_function_call(node, node.e2, a, b, node.e1, false, argdelta) + return (type_check_function_call(node, node.e2, a, b, node.e1, false, argdelta)) end end @@ -9248,9 +9492,9 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return UNKNOWN else if node.exps then - return node_error(node.vars[i], "assignment in declaration did not produce an initial value for variable '" .. name .. "'") + return invalid_at(node.vars[i], "assignment in declaration did not produce an initial value for variable '" .. name .. "'") else - return node_error(node.vars[i], "variable '" .. name .. "' has no type or initial value") + return invalid_at(node.vars[i], "variable '" .. name .. "' has no type or initial value") end end end @@ -9301,13 +9545,13 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local type CheckableKey = string | number | boolean - local function check_redeclared_key(node: Node, ctx: Node.ExpectedContext, seen_keys: {CheckableKey:Node}, key: CheckableKey) + local function check_redeclared_key(where: Where, ctx: Node.ExpectedContext, seen_keys: {CheckableKey:Where}, key: CheckableKey) if key ~= nil then local s = seen_keys[key] if s then - node_error(node, in_context(ctx, "redeclared key " .. tostring(key) .. " (previously declared at " .. filename .. ":" .. s.y .. ":" .. s.x .. ")")) + error_at(where, in_context(ctx, "redeclared key " .. tostring(key) .. " (previously declared at " .. filename .. ":" .. s.y .. ":" .. s.x .. ")")) else - seen_keys[key] = node + seen_keys[key] = where end end end @@ -9330,7 +9574,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local last_array_idx = 1 local largest_array_idx = -1 - local seen_keys: {CheckableKey:Node} = {} + local seen_keys: {CheckableKey:Where} = {} for i, child in ipairs(children) do assert(child.typename == "table_item") @@ -9408,7 +9652,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string typ.keys = expand_type(node, typ.keys, INTEGER) typ.values = expand_type(node, typ.values, typ.elements) typ.elements = nil - node_error(node, "cannot determine type of table literal") + error_at(node, "cannot determine type of table literal") elseif is_record and is_array then typ.typename = "record" typ.interface_list = { @@ -9430,7 +9674,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string typ.fields = nil typ.field_order = nil else - node_error(node, "cannot determine type of table literal") + error_at(node, "cannot determine type of table literal") end elseif is_array then local pure_array = true @@ -9463,7 +9707,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string elseif is_tuple then typ.typename = "tupletable" if not typ.types or #typ.types == 0 then - node_error(node, "cannot determine type of tuple elements") + error_at(node, "cannot determine type of tuple elements") end end @@ -9500,7 +9744,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end else if infertype and infertype.typename == "unresolvable_typearg" then - node_error(node.vars[i], "cannot infer declaration type; an explicit type annotation is necessary") + error_at(node.vars[i], "cannot infer declaration type; an explicit type annotation is necessary") ok = false infertype = INVALID elseif infertype and infertype.is_method then @@ -9514,15 +9758,15 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if var.attribute == "total" then local rd = decltype and resolve_tuple_and_nominal(decltype) if rd and (rd.typename ~= "map" and rd.typename ~= "record") then - node_error(var, "attribute only applies to maps and records") + error_at(var, "attribute only applies to maps and records") ok = false elseif not infertype then - node_error(var, "variable declared does not declare an initialization value") + error_at(var, "variable declared does not declare an initialization value") ok = false elseif not (node.exps[i] and node.exps[i].attribute == "total") then local ri = resolve_tuple_and_nominal(infertype) if ri.typename ~= "map" and ri.typename ~= "record" then - node_error(var, "attribute only applies to maps and records") + error_at(var, "attribute only applies to maps and records") ok = false elseif not infertype.is_total then local missing = "" @@ -9530,10 +9774,10 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string missing = " (missing: " .. table.concat(infertype.missing, ", ") .. ")" end if ri.typename == "map" then - node_error(var, "map variable declared does not declare values for all possible keys" .. missing) + error_at(var, "map variable declared does not declare values for all possible keys" .. missing) ok = false elseif ri.typename == "record" then - node_error(var, "record variable declared does not declare values for all fields" .. missing) + error_at(var, "record variable declared does not declare values for all fields" .. missing) ok = false end end @@ -9561,7 +9805,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - local function total_check_key(key: CheckableKey, seen_keys: {CheckableKey:Node}, is_total: boolean, missing: {string}): boolean, {string} + local function total_check_key(key: CheckableKey, seen_keys: {CheckableKey:Where}, is_total: boolean, missing: {string}): boolean, {string} if not seen_keys[key] then missing = missing or {} table.insert(missing, tostring(key)) @@ -9570,7 +9814,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return is_total, missing end - local function total_record_check(t: Type, seen_keys: {CheckableKey:Node}): boolean, {string} + local function total_record_check(t: Type, seen_keys: {CheckableKey:Where}): boolean, {string} if t.meta_field_order then return false end @@ -9585,7 +9829,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return is_total, missing end - local function total_map_check(t: Type, seen_keys: {CheckableKey:Node}): boolean, {string} + local function total_map_check(t: Type, seen_keys: {CheckableKey:Where}): boolean, {string} local k = resolve_tuple_and_nominal(t.keys) local is_total = true local missing: {string} @@ -9618,6 +9862,50 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return nil end + local enum MissingError + "missing" + end + + local function check_assignment(where: Where, vartype: Type, valtype: Type, varname: string, attr: Attribute): Type, Type, MissingError + if varname then + if widen_back_var(varname) then + vartype, attr = find_var_type(varname) + if not vartype then + error_at(where, "unknown variable") + return nil + end + end + end + if attr == "close" or attr == "const" or attr == "total" then + error_at(where, "cannot assign to <" .. attr .. "> variable") + return nil + end + + local var = resolve_tuple_and_nominal(vartype) + if is_typetype(var) then + error_at(where, "cannot reassign a type") + return nil + end + + if not valtype then + error_at(where, "variable is not being assigned a value") + return nil, nil, "missing" + end + + assert_is_a(where, valtype, vartype, "in assignment") + + local val = resolve_tuple_and_nominal(valtype) + + return var, val + end + + local function discard_tuple(node: Node, t: Type, b: Type): Type + if b.typename == "tuple" then + node.discarded_tuple = true + end + return resolve_tuple(t) + end + local visit_node: Visitor = {} visit_node.cbs = { @@ -9635,8 +9923,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end_scope(node) end -- TODO extract node type from `return` - node.type = NONE - return node.type + return NONE end }, ["local_type"] = { @@ -9644,7 +9931,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local name = node.var.tk local resolved, aliasing = get_type_declaration(node) local var = add_var(node.var, name, resolved, node.var.attribute) - node.value.type = resolved +--@-- node.value.type = resolved if aliasing then var.aliasing = aliasing node.value.is_alias = true @@ -9652,8 +9939,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end, after = function(node: Node, _children: {Type}): Type dismiss_unresolved(node.var.tk) - node.type = NONE - return node.type + return NONE end, }, ["global_type"] = { @@ -9680,14 +9966,15 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end, after = function(node: Node, _children: {Type}): Type dismiss_unresolved(node.var.tk) - node.type = NONE - return node.type + return NONE end, }, ["local_declaration"] = { before = function(node: Node) - for _, var in ipairs(node.vars) do - reserve_symbol_list_slot(var) + if symbol_list then + for _, var in ipairs(node.vars) do + reserve_symbol_list_slot(var) + end end end, before_exp = set_expected_types_to_decltypes, @@ -9698,12 +9985,12 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if var.attribute == "close" then if opts.gen_target == "5.4" then if encountered_close then - node_error(var, "only one per declaration is allowed") + error_at(var, "only one per declaration is allowed") else encountered_close = true end else - node_error(var, " attribute is only valid for Lua 5.4 (current target is " .. tostring(opts.gen_target) .. ")") + error_at(var, " attribute is only valid for Lua 5.4 (current target is " .. tostring(opts.gen_target) .. ")") end end @@ -9711,9 +9998,9 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if var.attribute == "close" then if not type_is_closable(t) then - node_error(var, "to-be-closed variable " .. var.tk .. " has a non-closable type %s", t) + error_at(var, "to-be-closed variable " .. var.tk .. " has a non-closable type %s", t) elseif node.exps and node.exps[i] and expr_is_definitely_not_closable(node.exps[i]) then - node_error(var, "to-be-closed variable " .. var.tk .. " assigned a non-closable value") + error_at(var, "to-be-closed variable " .. var.tk .. " assigned a non-closable value") end end @@ -9726,14 +10013,18 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local rt = resolve_tuple_and_nominal(t) if rt.typename ~= "enum" and (t.typename ~= "nominal" or rt.typename == "union") and not same_type(t, infertype) then - add_var(where, var.tk, infer_at(where, infertype), "const", "narrowed_declaration") + t = infer_at(where, infertype) + add_var(where, var.tk, t, "const", "narrowed_declaration") end end + if store_type then + store_type(var.y, var.x, t) + end + dismiss_unresolved(var.tk) end - node.type = NONE - return node.type + return NONE end, }, ["global_declaration"] = { @@ -9744,70 +10035,55 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local _, t, is_inferred = determine_declaration_type(var, node, infertypes, i) if var.attribute == "close" then - node_error(var, "globals may not be ") + error_at(var, "globals may not be ") end add_global(var, var.tk, t, is_inferred) - var.type = t dismiss_unresolved(var.tk) end - node.type = NONE - return node.type + return NONE end, }, ["assignment"] = { before_exp = set_expected_types_to_decltypes, after = function(node: Node, children: {Type}): Type local valtypes: {Type} = get_assignment_values(children[3], #children[1]) - local exps = flatten_list(valtypes) + valtypes = flatten_list(valtypes) for i, vartype in ipairs(children[1]) do local varnode = node.vars[i] - local attr = varnode.attribute - if varnode.kind == "variable" then - if widen_back_var(varnode.tk) then - vartype, attr = find_var_type(varnode.tk) + local varname = varnode.tk + local rvar, rval, err = check_assignment(varnode, vartype, valtypes[i], varname, varnode.attribute) + if err == "missing" then + if #node.exps == 1 and node.exps[1].kind == "op" and node.exps[1].op.op == "@funcall" then + local rets = children[3] + if rets.typename == "tuple" then + local msg = #rets == 1 + and "only 1 value is returned by the function" + or ("only " .. #rets .. " values are returned by the function") + add_warning("hint", varnode, msg) + end end end - if attr then - node_error(varnode, "cannot assign to <" .. attr .. "> variable") - end - if vartype then - local val = exps[i] - if is_typetype(resolve_tuple_and_nominal(vartype)) then - node_error(varnode, "cannot reassign a type") - elseif val then - assert_is_a(varnode, val, vartype, "in assignment") - -- assigning a function - local rval = resolve_tuple_and_nominal(val) - if rval.typename == "function" then - widen_all_unions() - end + if rval and rvar then + -- assigning a function + if rval.typename == "function" then + widen_all_unions() + end - if varnode.kind == "variable" and vartype.typename == "union" then - -- narrow union - add_var(varnode, varnode.tk, val, nil, "narrow") - end - else - node_error(varnode, "variable is not being assigned a value") - if #node.exps == 1 and node.exps[1].kind == "op" and node.exps[1].op.op == "@funcall" then - local rets = node.exps[1].type - if rets.typename == "tuple" then - local msg = #rets == 1 - and "only 1 value is returned by the function" - or ("only " .. #rets .. " values are returned by the function") - add_warning("hint", varnode, msg) - end - end + if varname and rvar.typename == "union" then + -- narrow union + add_var(varnode, varname, rval, nil, "narrow") + end + + if store_type then + store_type(varnode.y, varnode.x, valtypes[i]) end - else - node_error(varnode, "unknown variable") end end - node.type = NONE - return node.type + return NONE end, }, ["if"] = { @@ -9824,8 +10100,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string infer_negation_of_if_blocks(node, node, #node.if_blocks) end - node.type = NONE - return node.type + return NONE end, }, ["if_block"] = { @@ -9847,8 +10122,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string node.block_returns = true end - node.type = NONE - return node.type + return NONE end }, ["while"] = { @@ -9868,11 +10142,10 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string widen_all_unions() local label_id = "::" .. node.label .. "::" if st[#st][label_id] then - node_error(node, "label '" .. node.label .. "' already defined at " .. filename ) + error_at(node, "label '" .. node.label .. "' already defined at " .. filename ) end local unresolved = st[#st]["@unresolved"] - node.type = a_type { y = node.y, x = node.x, typename = "none" } - local var = add_var(node, label_id, node.type) + local var = add_var(node, label_id, a_type { y = node.y, x = node.x, typename = "none" }) if unresolved then if unresolved.t.labels[node.label] then var.used = true @@ -9880,6 +10153,9 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string unresolved.t.labels[node.label] = nil end end, + after = function(): Type + return NONE + end }, ["goto"] = { after = function(node: Node, _children: {Type}): Type @@ -9888,8 +10164,8 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string unresolved.labels[node.label] = unresolved.labels[node.label] or {} table.insert(unresolved.labels[node.label], node) end - node.type = NONE - return node.type + + return NONE end, }, ["repeat"] = { @@ -9904,54 +10180,24 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string before = function(node: Node) begin_scope(node) end, - before_statements = function(node: Node) + before_statements = function(node: Node, children: {Type}) + local exptypes = children[2] + widen_all_unions(node) local exp1 = node.exps[1] local args = { typename = "tuple", - node.exps[2] and node.exps[2].type, - node.exps[3] and node.exps[3].type + node.exps[2] and exptypes[2], + node.exps[3] and exptypes[3] } - local exp1type = resolve_for_call(exp1.type, args, false) + local exp1type = resolve_for_call(exptypes[1], args, false) if exp1type.typename == "poly" then - type_check_function_call(exp1, {node.exps[2], node.exps[3]}, exp1type, args, exp1, false, 0) - exp1type = exp1.type or exp1type + local _: Type + _, exp1type = type_check_function_call(exp1, {node.exps[2], node.exps[3]}, exp1type, args, exp1, false, 0) end if exp1type.typename == "function" then - -- check common errors: - if exp1.op and exp1.op.op == "@funcall" then - local t = resolve_tuple_and_nominal(exp1.e2.type) - if exp1.e1.tk == "pairs" and is_array_type(t) then - add_warning("hint", exp1, "hint: applying pairs on an array: did you intend to apply ipairs?") - end - - if exp1.e1.tk == "pairs" and t.typename ~= "map" then - if not (lax and is_unknown(t)) then - if is_record_type(t) then - match_all_record_field_names(exp1.e2, t, t.field_order, - "attempting pairs loop on a record with attributes of different types") - local ct = t.typename == "record" and "{string:any}" or "{any:any}" - add_warning("hint", exp1.e2, "hint: if you want to iterate over fields of a record, cast it to " .. ct) - else - node_error(exp1.e2, "cannot apply pairs on values of type: %s", exp1.e2.type) - end - end - elseif exp1.e1.tk == "ipairs" then - if t.typename == "tupletable" then - local arr_type = arraytype_from_tuple(exp1.e2, t) - if not arr_type then - node_error(exp1.e2, "attempting ipairs loop on tuple that's not a valid array: %s", exp1.e2.type) - end - elseif not is_array_type(t) then - if not (lax and (is_unknown(t) or t.typename == "emptytable")) then - node_error(exp1.e2, "attempting ipairs loop on something that's not an array: %s", exp1.e2.type) - end - end - end - end - -- TODO: check that exp1's arguments match with (optional self, explicit iterator, state) local last: Type local rets = exp1type.rets @@ -9971,11 +10217,11 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local nrets = #rets local at = node.vars[nrets + 1] local n_values = nrets == 1 and "1 value" or tostring(nrets) .. " values" - node_error(at, "too many variables for this iterator; it produces " .. n_values) + error_at(at, "too many variables for this iterator; it produces " .. n_values) end else if not (lax and is_unknown(exp1type)) then - node_error(exp1, "expression in for loop does not return an iterator") + error_at(exp1, "expression in for loop does not return an iterator") end end end, @@ -10028,14 +10274,14 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end if #children[1] > nrets and (not lax) and not vatype then - node_error(node, what ..": excess return values, expected " .. #rets .. " %s, got " .. #children[1] .. " %s", rets, children[1]) + error_at(node, what ..": excess return values, expected " .. #rets .. " %s, got " .. #children[1] .. " %s", rets, children[1]) end if nrets > 1 and #node.exps == 1 and node.exps[1].kind == "op" and (node.exps[1].op.op == "and" or node.exps[1].op.op == "or") - and #node.exps[1].e2.type > 1 then + and node.exps[1].discarded_tuple then add_warning("hint", node.exps[1].e2, "additional return values are being discarded due to '" .. node.exps[1].op.op .. "' expression; suggest parentheses if intentional") end @@ -10051,27 +10297,26 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - node.type = NONE - return node.type + return NONE end, }, ["variable_list"] = { - after = function(node: Node, children: {Type}): Type - node.type = TUPLE(children) + after = function(_node: Node, children: {Type}): Type + local tuple = TUPLE(children) -- explode last tuple: (1, 2, (3, 4)) becomes (1, 2, 3, 4) - local n = #children - if n > 0 and children[n].typename == "tuple" then - if children[n].is_va then - node.type.is_va = true + local n = #tuple + if n > 0 and tuple[n].typename == "tuple" then + local final_tuple = tuple[n] + if final_tuple.is_va then + tuple.is_va = true end - local tuple = children[n] - for i, c in ipairs(tuple) do - children[n + i - 1] = c + for i, c in ipairs(final_tuple) do + tuple[n + i - 1] = c end end - return node.type + return tuple end, }, ["table_literal"] = { @@ -10110,130 +10355,130 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string after = function(node: Node, children: {Type}): Type node.known = FACT_TRUTHY - if node.expected then - local decltype = resolve_tuple_and_nominal(node.expected) + if not node.expected then + return infer_table_literal(node, children) + end - if decltype.typename == "union" then - local single_table_type: Type - local single_table_rt: Type - - for _, t in ipairs(decltype.types) do - local rt = resolve_tuple_and_nominal(t) - if is_lua_table_type(rt) then - if single_table_type then - -- multiple table types in union, give up - single_table_type = nil - single_table_rt = nil - break - end + local decltype = resolve_tuple_and_nominal(node.expected) + + if decltype.typename == "union" then + local single_table_type: Type + local single_table_rt: Type - single_table_type = t - single_table_rt = rt + for _, t in ipairs(decltype.types) do + local rt = resolve_tuple_and_nominal(t) + if is_lua_table_type(rt) then + if single_table_type then + -- multiple table types in union, give up + single_table_type = nil + single_table_rt = nil + break end - end - if single_table_type then - node.expected = single_table_type - decltype = single_table_rt + single_table_type = t + single_table_rt = rt end end - if not is_lua_table_type(decltype) then - node.type = infer_table_literal(node, children) - return node.type + if single_table_type then + node.expected = single_table_type + decltype = single_table_rt end + end - local is_record = is_record_type(decltype) - local is_array = is_array_type(decltype) - local is_tupletable = decltype.typename == "tupletable" - local is_map = decltype.typename == "map" + if not is_lua_table_type(decltype) then + return infer_table_literal(node, children) + end - local force_array: Type = nil + local is_record = is_record_type(decltype) + local is_array = is_array_type(decltype) + local is_tupletable = decltype.typename == "tupletable" + local is_map = decltype.typename == "map" - local seen_keys: {CheckableKey:Node} = {} + local force_array: Type = nil - for i, child in ipairs(children) do - assert(child.typename == "table_item") - local cvtype = resolve_tuple(child.vtype) - local ck = child.kname - local n = node[i].key.constnum - local b: boolean = nil - if child.ktype.typename == "boolean" then - b = (node[i].key.tk == "true") - end - check_redeclared_key(node[i], node.expected_context, seen_keys, ck or n or b) - if is_record and ck then - local df = decltype.fields[ck] - if not df then - node_error(node[i], in_context(node.expected_context, "unknown field " .. ck)) - else - if is_typetype(df) then - node_error(node[i], in_context(node.expected_context, "cannot reassign a type")) - else - assert_is_a(node[i], cvtype, df, "in record field", ck) - end - end - elseif is_tupletable and is_number_type(child.ktype) then - local dt = decltype.types[n as integer] - if not n then - node_error(node[i], in_context(node.expected_context, "unknown index in tuple %s"), decltype) - elseif not dt then - node_error(node[i], in_context(node.expected_context, "unexpected index " .. n .. " in tuple %s"), decltype) - else - assert_is_a(node[i], cvtype, dt, in_context(node.expected_context, "in tuple"), "at index " .. tostring(n)) - end - elseif is_array and is_number_type(child.ktype) then - if child.vtype.typename == "tuple" and i == #children and node[i].key_parsed == "implicit" then - -- need to expand last item in an array (e.g { 1, 2, 3, f() }) - for ti, tt in ipairs(child.vtype) do - assert_is_a(node[i], tt, decltype.elements, in_context(node.expected_context, "expected an array"), "at index " .. tostring(i + ti - 1)) - end + local seen_keys: {CheckableKey:Where} = {} + + for i, child in ipairs(children) do + assert(child.typename == "table_item") + local cvtype = resolve_tuple(child.vtype) + local ck = child.kname + local n = node[i].key.constnum + local b: boolean = nil + if child.ktype.typename == "boolean" then + b = (node[i].key.tk == "true") + end + check_redeclared_key(node[i], node.expected_context, seen_keys, ck or n or b) + if is_record and ck then + local df = decltype.fields[ck] + if not df then + error_at(node[i], in_context(node.expected_context, "unknown field " .. ck)) + else + if is_typetype(df) then + error_at(node[i], in_context(node.expected_context, "cannot reassign a type")) else - assert_is_a(node[i], cvtype, decltype.elements, in_context(node.expected_context, "expected an array"), "at index " .. tostring(n)) + assert_is_a(node[i], cvtype, df, "in record field", ck) end - elseif node[i].key_parsed == "implicit" then - if is_map then - assert_is_a(node[i], INTEGER, decltype.keys, in_context(node.expected_context, "in map key")) - assert_is_a(node[i], cvtype, decltype.values, in_context(node.expected_context, "in map value")) + end + elseif is_tupletable and is_number_type(child.ktype) then + local dt = decltype.types[n as integer] + if not n then + error_at(node[i], in_context(node.expected_context, "unknown index in tuple %s"), decltype) + elseif not dt then + error_at(node[i], in_context(node.expected_context, "unexpected index " .. n .. " in tuple %s"), decltype) + else + assert_is_a(node[i], cvtype, dt, in_context(node.expected_context, "in tuple"), "at index " .. tostring(n)) + end + elseif is_array and is_number_type(child.ktype) then + if child.vtype.typename == "tuple" and i == #children and node[i].key_parsed == "implicit" then + -- need to expand last item in an array (e.g { 1, 2, 3, f() }) + for ti, tt in ipairs(child.vtype) do + assert_is_a(node[i], tt, decltype.elements, in_context(node.expected_context, "expected an array"), "at index " .. tostring(i + ti - 1)) end - force_array = expand_type(node[i], force_array, child.vtype) - elseif is_map then - force_array = nil - assert_is_a(node[i], child.ktype, decltype.keys, in_context(node.expected_context, "in map key")) - assert_is_a(node[i], cvtype, decltype.values, in_context(node.expected_context, "in map value")) else - node_error(node[i], in_context(node.expected_context, "unexpected key of type %s in table of type %s"), child.ktype, decltype) + assert_is_a(node[i], cvtype, decltype.elements, in_context(node.expected_context, "expected an array"), "at index " .. tostring(n)) end - end - - if force_array then - node.type = infer_at(node, a_type { - typename = "array", - elements = force_array, - }) - else - node.type = resolve_typevars_at(node, node.expected) - if node.expected == node.type and node.type.typename == "nominal" then - node.type = { - typeid = node.type.typeid, - typename = "nominal", - names = node.type.names, - found = node.type.found, - resolved = node.type.resolved, - } + elseif node[i].key_parsed == "implicit" then + if is_map then + assert_is_a(node[i], INTEGER, decltype.keys, in_context(node.expected_context, "in map key")) + assert_is_a(node[i], cvtype, decltype.values, in_context(node.expected_context, "in map value")) end + force_array = expand_type(node[i], force_array, child.vtype) + elseif is_map then + force_array = nil + assert_is_a(node[i], child.ktype, decltype.keys, in_context(node.expected_context, "in map key")) + assert_is_a(node[i], cvtype, decltype.values, in_context(node.expected_context, "in map value")) + else + error_at(node[i], in_context(node.expected_context, "unexpected key of type %s in table of type %s"), child.ktype, decltype) end + end - if decltype.typename == "record" then - node.type.is_total, node.type.missing = total_record_check(decltype, seen_keys) - elseif decltype.typename == "map" then - node.type.is_total, node.type.missing = total_map_check(decltype, seen_keys) - end + local t: Type + if force_array then + t = infer_at(node, a_type { + typename = "array", + elements = force_array, + }) else - node.type = infer_table_literal(node, children) + t = resolve_typevars_at(node, node.expected) + if node.expected == t and t.typename == "nominal" then + t = { + typeid = t.typeid, + typename = "nominal", + names = t.names, + found = t.found, + resolved = t.resolved, + } + end end - return node.type + if decltype.typename == "record" then + t.is_total, t.missing = total_record_check(decltype, seen_keys) + elseif decltype.typename == "map" then + t.is_total, t.missing = total_map_check(decltype, seen_keys) + end + + return t end, }, ["table_item"] = { @@ -10251,7 +10496,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string vtype.typeid = new_typeid() vtype.is_method = false end - node.type = a_type { + return a_type { y = node.y, x = node.x, typename = "table_item", @@ -10259,24 +10504,26 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string ktype = ktype, vtype = vtype, } - return node.type end, }, ["local_function"] = { before = function(node: Node) widen_all_unions() - reserve_symbol_list_slot(node) + if symbol_list then + reserve_symbol_list_slot(node) + end begin_scope(node) end, - before_statements = function(node: Node) - add_internal_function_variables(node) - add_function_definition_for_recursion(node) + before_statements = function(node: Node, children: {Type}) + local args = children[2] + add_internal_function_variables(node, args) + add_function_definition_for_recursion(node, args) end, after = function(node: Node, children: {Type}): Type end_function_scope(node) local rets = get_rets(children[3]) - add_var(node, node.name.tk, ensure_fresh_typeargs(a_type { + local t = ensure_fresh_typeargs(a_type { y = node.y, x = node.x, typename = "function", @@ -10284,8 +10531,10 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string args = children[2], rets = rets, filename = filename, - })) - return node.type + }) + + add_var(node, node.name.tk, t) + return t end, }, ["global_function"] = { @@ -10298,22 +10547,24 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if typ.typename == "function" then node.is_predeclared_local_function = true elseif not lax then - node_error(node, "cannot declare function: type of " .. node.name.tk .. " is %s", typ) + error_at(node, "cannot declare function: type of " .. node.name.tk .. " is %s", typ) end elseif not lax then - node_error(node, "functions need an explicit 'local' or 'global' annotation") + error_at(node, "functions need an explicit 'local' or 'global' annotation") end end end, - before_statements = function(node: Node) - add_internal_function_variables(node) - add_function_definition_for_recursion(node) + before_statements = function(node: Node, children: {Type}) + local args = children[2] + add_internal_function_variables(node, args) + add_function_definition_for_recursion(node, args) end, after = function(node: Node, children: {Type}): Type end_function_scope(node) if node.is_predeclared_local_function then - return node.type + return NONE end + add_global(node, node.name.tk, ensure_fresh_typeargs(a_type { y = node.y, x = node.x, @@ -10323,7 +10574,8 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string rets = get_rets(children[3]), filename = filename, })) - return node.type + + return NONE end, }, ["record_function"] = { @@ -10347,7 +10599,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end, before_statements = function(node: Node, children: {Type}) - add_internal_function_variables(node) + local args = children[3] local rtype = node.rtype if rtype.typename == "emptytable" then @@ -10361,17 +10613,17 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end if not is_record_type(rtype) then - node_error(node, "not a module: %s", rtype) + error_at(node, "not a module: %s", rtype) return end + local selftype = get_self_type(node.fn_owner) if node.is_method then - local selftype = get_self_type(node.fn_owner) if not selftype then - node_error(node, "could not resolve type of self") + error_at(node, "could not resolve type of self") return end - children[3][1] = selftype + args[1] = selftype add_var(nil, "self", selftype) end @@ -10381,7 +10633,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string typename = "function", is_method = node.is_method, typeargs = node.typeargs, - args = children[3], + args = args, rets = get_rets(children[4]), filename = filename, }) @@ -10403,9 +10655,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return end - local shortname = node.fn_owner.type.typename == "nominal" - and show_type(node.fn_owner.type) - or owner_name + local shortname = selftype and show_type(selftype) or owner_name local msg = "type signature of '" .. node.name.tk .. "' does not match its declaration in " .. shortname .. ": " add_errs_prefixing(node, err, errors, msg) return @@ -10415,7 +10665,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string rtype.fields[node.name.tk] = fn_type table.insert(rtype.field_order, node.name.tk) else - node_error(node, "cannot add undeclared function '" .. node.name.tk .. "' outside of the scope where '" .. owner_name .. "' was originally declared") + error_at(node, "cannot add undeclared function '" .. node.name.tk .. "' outside of the scope where '" .. owner_name .. "' was originally declared") return end @@ -10427,12 +10677,12 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end open_v.implemented[open_k] = true end - node.name.type = fn_type + + add_internal_function_variables(node, args) end, after = function(node: Node, _children: {Type}): Type end_function_scope(node) - node.type = NONE - return node.type + return NONE end, }, ["function"] = { @@ -10440,14 +10690,15 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string widen_all_unions(node) begin_scope(node) end, - before_statements = function(node: Node) - add_internal_function_variables(node) + before_statements = function(node: Node, children: {Type}) + local args = children[1] + add_internal_function_variables(node, args) end, after = function(node: Node, children: {Type}): Type end_function_scope(node) -- children[1] args -- children[2] body - node.type = ensure_fresh_typeargs(a_type { + return ensure_fresh_typeargs(a_type { y = node.y, x = node.x, typename = "function", @@ -10456,7 +10707,6 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string rets = children[2], filename = filename, }) - return node.type end, }, ["macroexp"] = { @@ -10464,14 +10714,15 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string widen_all_unions(node) begin_scope(node) end, - before_exp = function(node: Node) - add_internal_function_variables(node) + before_exp = function(node: Node, children: {Type}) + local args = children[1] + add_internal_function_variables(node, args) end, after = function(node: Node, children: {Type}): Type end_function_scope(node) -- children[1] args -- children[2] body - node.type = ensure_fresh_typeargs(a_type { + return ensure_fresh_typeargs(a_type { y = node.y, x = node.x, typename = "function", @@ -10480,13 +10731,11 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string rets = children[2], filename = filename, }) - return node.type end, }, ["cast"] = { after = function(node: Node, _children: {Type}): Type - node.type = node.casttype - return node.type + return node.casttype end }, ["paren"] = { @@ -10495,8 +10744,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end, after = function(node: Node, children: {Type}): Type node.known = node.e1 and node.e1.known - node.type = resolve_tuple(children[1]) - return node.type + return resolve_tuple(children[1]) end, }, ["op"] = { @@ -10513,18 +10761,20 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end end, - before_e2 = function(node: Node) + before_e2 = function(node: Node, children: {Type}) + local e1type = children[1] + if node.op.op == "and" then apply_facts(node, node.e1.known) elseif node.op.op == "or" then apply_facts(node, facts_not(node, node.e1.known)) elseif node.op.op == "@funcall" then - if node.e1.type.typename == "function" then + if e1type.typename == "function" then local argdelta = (node.e1.op and node.e1.op.op == ":") and -1 or 0 if node.expected then - is_a(node.e1.type.rets, node.expected) + is_a(e1type.rets, node.expected) end - local e1args = node.e1.type.args + local e1args = e1type.args local at = argdelta for _, typ in ipairs(e1args) do at = at + 1 @@ -10540,8 +10790,8 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end elseif node.op.op == "@index" then - if node.e1.type.typename == "map" then - node.e2.expected = node.e1.type.keys + if e1type.typename == "map" then + node.e2.expected = e1type.keys end end end, @@ -10559,23 +10809,21 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local expected = node.expected and resolve_tuple_and_nominal(node.expected) if ra.typename == "circular_require" or (ra.def and ra.def.typename == "circular_require") then - node_error(node, "cannot dereference a type from a circular require") - node.type = INVALID - return node.type + return invalid_at(node, "cannot dereference a type from a circular require") end if is_typetype(ra) then if ra.def.typename == "record" then ra = ra.def elseif ra.def.typename == "interface" then - node_error(node, "interfaces are abstract; consider using a concrete record") + error_at(node, "interfaces are abstract; consider using a concrete record") end end if rb and is_typetype(rb) and rb.def.typename == "record" then if rb.def.typename == "record" then rb = rb.def elseif rb.def.typename == "interface" then - node_error(node, "interfaces are abstract; consider using a concrete record") + error_at(node, "interfaces are abstract; consider using a concrete record") end end @@ -10585,8 +10833,12 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string add_unknown_dot(node, node.e1.e1.tk .. "." .. node.e1.e2.tk) end end - node.type = type_check_funcall(node, a, b) - elseif node.op.op == "." then + return type_check_funcall(node, a, b) + end + + if node.op.op == "." then + node.receiver = a + assert(node.e2.kind == "identifier") local bnode: Node = { y = node.e2.y, @@ -10601,9 +10853,9 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string tk = '"' ..node.e2.tk .. '"', typename = "string", } - node.type = type_check_index(node.e1, bnode, orig_a, btype) + local t = type_check_index(node.e1, bnode, orig_a, btype) - if node.type.needs_compat and opts.gen_compat ~= "off" then + if t.needs_compat and opts.gen_compat ~= "off" then -- only apply to a literal use, not a propagated type if node.e1.kind == "variable" and node.e2.kind == "identifier" then local key = node.e1.tk .. "." .. node.e2.tk @@ -10612,83 +10864,112 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string all_needs_compat[key] = true end end - elseif node.op.op == "@index" then - node.type = type_check_index(node.e1, node.e2, a, b) - elseif node.op.op == "as" then - node.type = b - elseif node.op.op == "is" then + + return t + end + + if node.op.op == "@index" then + return type_check_index(node.e1, node.e2, a, b) + end + + if node.op.op == "as" then + return b + end + + if node.op.op == "is" then if rb.typename == "integer" then all_needs_compat["math"] = true end if ra.typename == "typetype" then - node_error(node, "can only use 'is' on variables, not types") + error_at(node, "can only use 'is' on variables, not types") elseif node.e1.kind == "variable" then check_metamethod(node, "__is", ra, resolve_typetype(rb), orig_a, orig_b) node.known = IsFact { var = node.e1.tk, typ = b, where = node } else - node_error(node, "can only use 'is' on variables") + error_at(node, "can only use 'is' on variables") end - node.type = BOOLEAN - elseif node.op.op == ":" then + return BOOLEAN + end + + if node.op.op == ":" then + node.receiver = a + -- we handle ':' separately from '.' because ':' is specific to records, -- so we produce different error messages if lax and (is_unknown(a) or a.typename == "typevar") then if node.e1.kind == "variable" then add_unknown_dot(node.e1, node.e1.tk .. "." .. node.e2.tk) end - node.type = UNKNOWN - else - local t, e = match_record_key(a, node.e1, node.e2.conststr or node.e2.tk) - if not t then - node.type = INVALID - return node_error(node.e2, e, a == INVALID and a or resolve_tuple(orig_a)) - end - node.type = t + return UNKNOWN end - elseif node.op.op == "not" then + + local t, e = match_record_key(a, node.e1, node.e2.conststr or node.e2.tk) + if not t then + return invalid_at(node.e2, e, a == INVALID and a or resolve_tuple(orig_a)) + end + + return t + end + + if node.op.op == "not" then node.known = facts_not(node, node.e1.known) - node.type = BOOLEAN - elseif node.op.op == "and" then + return BOOLEAN + end + + if node.op.op == "and" then node.known = facts_and(node, node.e1.known, node.e2.known) - node.type = resolve_tuple(b) - elseif node.op.op == "or" and b.typename == "nil" then - node.known = nil - node.type = resolve_tuple(a) - elseif node.op.op == "or" and is_lua_table_type(ra) and b.typename == "emptytable" then - node.known = nil - node.type = resolve_tuple(a) - elseif node.op.op == "or" - and ((ra.typename == "enum" and rb.typename == "string" and is_a(rb, ra)) - or (ra.typename == "string" and rb.typename == "enum" and is_a(ra, rb))) then - node.known = nil - node.type = (ra.typename == "enum" and ra or rb) - elseif node.op.op == "or" and expected and expected.typename == "union" then - -- must be checked after string/enum above - node.known = facts_or(node, node.e1.known, node.e2.known) - local u = unite({ra, rb}, true) - if u.typename == "union" then - u = validate_union(node, u) - end - node.type = u - elseif node.op.op == "or" and is_a(rb, ra) then - node.known = facts_or(node, node.e1.known, node.e2.known) - if expected then - local a_is = is_a(a, node.expected) - local b_is = is_a(b, node.expected) - if a_is and b_is then - node.type = resolve_typevars_at(node, node.expected) - elseif a_is then - node.type = resolve_tuple(b) + return discard_tuple(node, b, b) + end + + if node.op.op == "or" then + local t: Type + if b.typename == "nil" then + node.known = nil + t = a + + elseif is_lua_table_type(ra) and b.typename == "emptytable" then + node.known = nil + t = a + + elseif ((ra.typename == "enum" and rb.typename == "string" and is_a(rb, ra)) + or (ra.typename == "string" and rb.typename == "enum" and is_a(ra, rb))) then + node.known = nil + t = (ra.typename == "enum" and ra or rb) + + elseif expected and expected.typename == "union" then + -- must be checked after string/enum above + node.known = facts_or(node, node.e1.known, node.e2.known) + local u = unite({ra, rb}, true) + if u.typename == "union" then + u = validate_union(node, u) + end + t = u + + elseif is_a(rb, ra) then + node.known = facts_or(node, node.e1.known, node.e2.known) + if expected then + local a_is = is_a(a, node.expected) + local b_is = is_a(b, node.expected) + if a_is and b_is then + t = resolve_typevars_at(node, node.expected) + elseif a_is then + t = resolve_tuple(b) + else + t = resolve_tuple(a) + end else - node.type = resolve_tuple(a) + t = resolve_tuple(a) end - else - node.type = resolve_tuple(a) + t.tk = nil end - node.type.tk = nil - elseif node.op.op == "==" or node.op.op == "~=" then - node.type = BOOLEAN + if t then + return discard_tuple(node, t, b) + end + -- else fallthrough to general binop handler + end + + if node.op.op == "==" or node.op.op == "~=" then -- if is_lua_table_type(ra) and is_lua_table_type(rb) then -- check_metamethod(node, binop_to_metamethod[node.op.op], ra, rb) -- end @@ -10702,33 +10983,39 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string node.known = EqFact { var = node.e2.tk, typ = a, where = node } end elseif lax and (is_unknown(a) or is_unknown(b)) then - node.type = UNKNOWN + return UNKNOWN else - return node_error(node, "types are not comparable for equality: %s and %s", a, b) + return invalid_at(node, "types are not comparable for equality: %s and %s", a, b) end - elseif node.op.arity == 1 and unop_types[node.op.op] then + + return BOOLEAN + end + + if node.op.arity == 1 and unop_types[node.op.op] then a = ra if a.typename == "union" then a = unite(a.types, true) -- squash unions of string constants end local types_op = unop_types[node.op.op] - node.type = types_op[a.typename] - if not node.type then - node.type = find_in_interface_list(a, function(t: Type): Type - return types_op[t.typename] + local t = types_op[a.typename] + + if not t then + t = find_in_interface_list(a, function(ty: Type): Type + return types_op[ty.typename] end) end local meta_on_operator: integer - if not node.type then + if not t then local mt_name = unop_to_metamethod[node.op.op] if mt_name then - node.type, meta_on_operator = check_metamethod(node, mt_name, a, nil, orig_a, nil) + t, meta_on_operator = check_metamethod(node, mt_name, a, nil, orig_a, nil) end - if not node.type then - node_error(node, "cannot use operator '" .. node.op.op:gsub("%%", "%%%%") .. "' on type %s", resolve_tuple(orig_a)) + if not t then + error_at(node, "cannot use operator '" .. node.op.op:gsub("%%", "%%%%") .. "' on type %s", resolve_tuple(orig_a)) + t = INVALID end end @@ -10736,11 +11023,11 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if a.keys.typename == "number" or a.keys.typename == "integer" then add_warning("hint", node, "using the '#' operator on a map with numeric key type may produce unexpected results") else - node_error(node, "using the '#' operator on this map will always return 0") + error_at(node, "using the '#' operator on this map will always return 0") end end - if node.type.typename ~= "boolean" and not is_unknown(node.type) then + if t.typename ~= "boolean" and not is_unknown(t) then node.known = FACT_TRUTHY end @@ -10754,7 +11041,10 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - elseif node.op.arity == 2 and binop_types[node.op.op] then + return t + end + + if node.op.arity == 2 and binop_types[node.op.op] then if node.op.op == "or" then node.known = facts_or(node, node.e1.known, node.e2.known) end @@ -10770,15 +11060,18 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end local types_op = binop_types[node.op.op] - node.type = types_op[a.typename] and types_op[a.typename][b.typename] + + local t = types_op[a.typename] and types_op[a.typename][b.typename] + local meta_on_operator: integer - if not node.type then + if not t then local mt_name = binop_to_metamethod[node.op.op] if mt_name then - node.type, meta_on_operator = check_metamethod(node, mt_name, a, b, orig_a, orig_b) + t, meta_on_operator = check_metamethod(node, mt_name, a, b, orig_a, orig_b) end - if not node.type then - node_error(node, "cannot use operator '" .. node.op.op:gsub("%%", "%%%%") .. "' for types %s and %s", resolve_tuple(orig_a), resolve_tuple(orig_b)) + if not t then + error_at(node, "cannot use operator '" .. node.op.op:gsub("%%", "%%%%") .. "' for types %s and %s", resolve_tuple(orig_a), resolve_tuple(orig_b)) + t = INVALID if node.op.op == "or" and is_valid_union(unite({orig_a, orig_b})) then add_warning("hint", node, "if a union type was intended, consider declaring it explicitly") end @@ -10787,9 +11080,9 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if orig_a.typename == "nominal" and orig_b.typename == "nominal" and not meta_on_operator then if is_a(orig_a, orig_b) then - node.type = resolve_tuple(orig_a) + t = resolve_tuple(orig_a) else - node_error(node, "cannot use operator '" .. node.op.op:gsub("%%", "%%%%") .. "' for distinct nominal types %s and %s", resolve_tuple(orig_a), resolve_tuple(orig_b)) + error_at(node, "cannot use operator '" .. node.op.op:gsub("%%", "%%%%") .. "' for distinct nominal types %s and %s", resolve_tuple(orig_a), resolve_tuple(orig_b)) end end @@ -10814,10 +11107,11 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string convert_node_to_compat_call(node, "bit32", bit_operators[node.op.op], node.e1, node.e2) end end - else - error("unknown node op " .. node.op.op) + + return t end - return node.type + + error("unknown node op " .. node.op.op) end, }, ["variable"] = { @@ -10825,48 +11119,54 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if node.tk == "..." then local va_sentinel = find_var_type("@is_va") if not va_sentinel or va_sentinel.typename == "nil" then - return node_error(node, "cannot use '...' outside a vararg function") + return invalid_at(node, "cannot use '...' outside a vararg function") end end + + local t: Type if node.tk == "_G" then - node.type, node.attribute = simulate_g() + t, node.attribute = simulate_g() else local use: VarUse = node.is_lvalue and "lvalue" or "use" - node.type, node.attribute = find_var_type(node.tk, use) + t, node.attribute = find_var_type(node.tk, use) + end + if not t then + if lax then + add_unknown(node, node.tk) + return UNKNOWN + end + + return invalid_at(node, "unknown variable: " .. node.tk) end - if node.type and is_typetype(node.type) then - node.type = a_type { + + if is_typetype(t) then + t = a_type { y = node.y, x = node.x, typename = "nominal", names = { node.tk }, - found = node.type, - resolved = node.type, + found = t, + resolved = t, } end - if node.type == nil then - node.type = a_type { typename = "unknown" } - if lax then - add_unknown(node, node.tk) - else - return node_error(node, "unknown variable: " .. node.tk) - end - end - return node.type + + return t end, }, ["type_identifier"] = { after = function(node: Node, _children: {Type}): Type - node.type, node.attribute = find_var_type(node.tk) - if node.type == nil then - if lax then - node.type = UNKNOWN - add_unknown(node, node.tk) - else - return node_error(node, "unknown variable: " .. node.tk) - end + local typ, attr = find_var_type(node.tk) + node.attribute = attr + if typ then + return typ end - return node.type + + if lax then + add_unknown(node, node.tk) + return UNKNOWN + end + + return invalid_at(node, "unknown variable: " .. node.tk) end, }, ["argument"] = { @@ -10882,57 +11182,52 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string t = OPT(t) end add_var(node, node.tk, t).is_func_arg = true - return node.type + return t end, }, ["identifier"] = { - after = function(node: Node, _children: {Type}): Type - node.type = node.type or NONE -- type is resolved elsewhere - return node.type + after = function(_node: Node, _children: {Type}): Type + return NONE -- type is resolved elsewhere end, }, ["newtype"] = { after = function(node: Node, _children: {Type}): Type - node.type = node.type or node.newtype - return node.type + return node.newtype end, }, ["error_node"] = { - after = function(node: Node, _children: {Type}): Type - node.type = INVALID - return node.type + after = function(_node: Node, _children: {Type}): Type + return INVALID end, } } visit_node.cbs["break"] = { - after = function(node: Node, _children: {Type}): Type - node.type = NONE - return node.type + after = function(_node: Node, _children: {Type}): Type + return NONE end, } visit_node.cbs["do"] = visit_node.cbs["break"] local function after_literal(node: Node): Type - node.type = a_type { + node.known = FACT_TRUTHY + return a_type { y = node.y, x = node.x, typename = node.kind as TypeName, tk = node.tk, } - node.known = FACT_TRUTHY - return node.type end visit_node.cbs["string"] = { after = function(node: Node, _children: {Type}): Type - after_literal(node) + local t = after_literal(node) if node.expected then - if node.expected.typename == "enum" and is_a(node.type, node.expected) then - node.type = node.expected + if node.expected.typename == "enum" and is_a(t, node.expected) then + t = node.expected end end - return node.type + return t end, } visit_node.cbs["number"] = { after = after_literal } @@ -10940,9 +11235,9 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string visit_node.cbs["boolean"] = { after = function(node: Node, _children: {Type}): Type - after_literal(node) + local t = after_literal(node) node.known = (node.tk == "true") and FACT_TRUTHY or nil - return node.type + return t end, } visit_node.cbs["nil"] = visit_node.cbs["boolean"] @@ -10951,18 +11246,12 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string visit_node.cbs["argument_list"] = visit_node.cbs["variable_list"] visit_node.cbs["expression_list"] = visit_node.cbs["variable_list"] - visit_node.after = function(node: Node, _children: {Type}): Type + visit_node.after = function(node: Node, _children: {Type}, t: Type): Type if node.expanded then apply_macroexp(node) end - if type(node.type) ~= "table" then - error(node.kind .. " did not produce a type") - end - if type(node.type.typename) ~= "string" then - error(node.kind .. " type does not have a typename") - end - return node.type + return t end local visit_type: Visitor @@ -10982,12 +11271,12 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string typ = ensure_fresh_typeargs(typ) if typ.macroexp then - recurse_node(typ.macroexp, visit_node, visit_type) + local macroexp_type = recurse_node(typ.macroexp, visit_node, visit_type) check_macroexp_arg_use(typ.macroexp) - if not is_a(typ.macroexp.type, typ) then - error_at(typ.macroexp.type, "macroexp type does not match declaration") + if not is_a(macroexp_type, typ) then + error_at(macroexp_type, "macroexp type does not match declaration") end end @@ -10997,7 +11286,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string ["record"] = { before = function(typ: Type) begin_scope() - add_var(nil, "@self", a_type { typename = "typetype", y = typ.y, x = typ.x, def = typ }) + add_var(nil, "@self", a_type({ typename = "typetype", y = typ.y, x = typ.x, def = typ })) for name, typ2 in fields_of(typ) do if typ2.typename == "typetype" then @@ -11118,27 +11407,45 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end }, }, - after = function(typ: Type, _children: {Type}, ret: Type): Type - if type(ret) ~= "table" then - error(typ.typename .. " did not produce a type") + } + + local function internal_compiler_check(fn: function(W, {Type}, Type): (Type)): (function(W, {Type}, Type): (Type)) + return function(w: W, children: {Type}, t: Type): Type + t = fn and fn(w, children, t) or t + + if type(t) ~= "table" then + error(((w as Node).kind or (w as Type).typename) .. " did not produce a type") end - if type(ret.typename) ~= "string" then - error("type node does not have a typename") + if type(t.typename) ~= "string" then + error(((w as Node).kind or (w as Type).typename) .. " type does not have a typename") end - return ret + + return t end - } + end + + local function store_type_after(fn: function(W, {Type}, Type): (Type)): (function(W, {Type}, Type): (Type)) + return function(w: W, children: {Type}, t: Type): Type + t = fn and fn(w, children, t) or t + + local where = w as Where - if not opts.run_internal_compiler_checks then - visit_node.after = function(node: Node, _children: {Type}): Type - if node.expanded then - apply_macroexp(node) + if where.y then + store_type(where.y, where.x, t) end - return node.type + return t end + end + + if opts.run_internal_compiler_checks then + visit_node.after = internal_compiler_check(visit_node.after) + visit_type.after = internal_compiler_check(visit_type.after) + end - visit_type.after = nil + if store_type then + visit_node.after = store_type_after(visit_node.after) + visit_type.after = store_type_after(visit_type.after) end visit_type.cbs["tupletable"] = visit_type.cbs["string"] @@ -11200,193 +11507,15 @@ end -- Report types -------------------------------------------------------------------------------- -local typename_to_typecode : {TypeName:integer} = { - ["typevar"] = tl.typecodes.TYPE_VARIABLE, - ["typearg"] = tl.typecodes.TYPE_VARIABLE, - ["unresolved_typearg"] = tl.typecodes.TYPE_VARIABLE, - ["unresolvable_typearg"] = tl.typecodes.TYPE_VARIABLE, - ["function"] = tl.typecodes.FUNCTION, - ["array"] = tl.typecodes.ARRAY, - ["map"] = tl.typecodes.MAP, - ["tupletable"] = tl.typecodes.TUPLE, - ["interface"] = tl.typecodes.INTERFACE, - ["record"] = tl.typecodes.RECORD, - ["enum"] = tl.typecodes.ENUM, - ["boolean"] = tl.typecodes.BOOLEAN, - ["string"] = tl.typecodes.STRING, - ["nil"] = tl.typecodes.NIL, - ["thread"] = tl.typecodes.THREAD, - ["number"] = tl.typecodes.NUMBER, - ["integer"] = tl.typecodes.INTEGER, - ["union"] = tl.typecodes.IS_UNION, - ["nominal"] = tl.typecodes.NOMINAL, - ["bad_nominal"] = tl.typecodes.NOMINAL, - ["circular_require"] = tl.typecodes.NOMINAL, - ["emptytable"] = tl.typecodes.EMPTY_TABLE, - ["unresolved_emptytable_value"] = tl.typecodes.EMPTY_TABLE, - ["poly"] = tl.typecodes.IS_POLY, - ["any"] = tl.typecodes.ANY, - ["unknown"] = tl.typecodes.UNKNOWN, - ["invalid"] = tl.typecodes.INVALID, - -- types that should be skipped or not present: - ["none"] = tl.typecodes.UNKNOWN, - ["tuple"] = tl.typecodes.UNKNOWN, - ["table_item"] = tl.typecodes.UNKNOWN, - ["unresolved"] = tl.typecodes.UNKNOWN, - ["typetype"] = tl.typecodes.UNKNOWN, - ["nestedtype"] = tl.typecodes.UNKNOWN, -} - -local skip_types: {TypeName: boolean} = { - ["none"] = true, - ["tuple"] = true, - ["table_item"] = true, - ["unresolved"] = true, - ["typetype"] = true, - ["nestedtype"] = true, -} - function tl.get_types(result: Result, trenv: TypeReportEnv): TypeReport, TypeReportEnv local filename = result.filename or "?" - - local function mark_array(x: T): T - local arr = x as {boolean} - arr[0] = false - return x - end + trenv = trenv or result.env.trenv if not trenv then - trenv = { - next_num = 1, - typeid_to_num = {}, - tr = { - by_pos = {}, - types = {}, - symbols_by_file = {}, - globals = {}, - }, - } + error("result must have been generated with env.report_types = true", 2) end local tr = trenv.tr - local typeid_to_num = trenv.typeid_to_num - - local get_typenum: function(t: Type): integer - - local function store_function(ti: TypeInfo, rt: Type) - local args: {{integer, string}} = {} - for _, fnarg in ipairs(rt.args) do - table.insert(args, mark_array { get_typenum(fnarg), nil }) - end - ti.args = mark_array(args) - local rets: {{integer, string}} = {} - for _, fnarg in ipairs(rt.rets) do - table.insert(rets, mark_array { get_typenum(fnarg), nil }) - end - ti.rets = mark_array(rets) - ti.vararg = not not rt.is_va - end - - get_typenum = function(t: Type): integer - assert(t.typeid) - -- try by typeid - local n = typeid_to_num[t.typeid] - if n then - return n - end - - -- it's a new entry: store and increment - n = trenv.next_num - - local rt = t - if is_typetype(rt) then - rt = rt.def - elseif rt.typename == "tuple" and #rt == 1 then - rt = rt[1] - end - - local ti: TypeInfo = { - t = assert(typename_to_typecode[rt.typename]), - str = show_type(t, true), - file = t.filename, - y = t.y, - x = t.x, - } - tr.types[n] = ti - typeid_to_num[t.typeid] = n - trenv.next_num = trenv.next_num + 1 - - if t.found then - ti.ref = get_typenum(t.found) - end - if t.resolved then - rt = t - end - assert(not is_typetype(rt)) - - if is_record_type(rt) then - -- store record field info - local r = {} - for _, k in ipairs(rt.field_order) do - local v = rt.fields[k] - r[k] = get_typenum(v) - end - ti.fields = r - end - - if is_array_type(rt) then - ti.elements = get_typenum(rt.elements) - end - - if rt.typename == "map" then - ti.keys = get_typenum(rt.keys) - ti.values = get_typenum(rt.values) - elseif rt.typename == "enum" then - ti.enums = mark_array(sorted_keys(rt.enumset)) - elseif rt.typename == "function" then - store_function(ti, rt) - elseif rt.typename == "poly" or rt.typename == "union" or rt.typename == "tupletable" then - local tis = {} - - for _, pt in ipairs(rt.types) do - table.insert(tis, get_typenum(pt)) - end - - ti.types = mark_array(tis) - end - - return n - end - - local visit_node: Visitor = { allow_missing_cbs = true } - local visit_type: Visitor = { allow_missing_cbs = true } - - local ft: {integer:{integer:integer}} = {} - tr.by_pos[filename] = ft - - local function store(y: integer, x: integer, typ: Type) - if not typ or skip_types[typ.typename] then - return - end - - local yt = ft[y] - if not yt then - yt = {} - ft[y] = yt - end - - yt[x] = get_typenum(typ) - end - - visit_node.after = function(node: Node): nil - store(node.y, node.x, node.type) - end - - visit_type.after = function(typ: Type): nil - store(typ.y or 0, typ.x or 0, typ) - end - - recurse_node(result.ast, visit_node, visit_type) tr.by_pos[filename][0] = nil @@ -11427,7 +11556,7 @@ function tl.get_types(result: Result, trenv: TypeReportEnv): TypeReport, TypeRep i = i + 1 local id: integer if s.typ then - id = get_typenum(s.typ) + id = get_typenum(trenv, s.typ) elseif s.name == "@{" then level = level + 1 stack[level] = i @@ -11448,7 +11577,7 @@ function tl.get_types(result: Result, trenv: TypeReportEnv): TypeReport, TypeRep for _, name in ipairs(gkeys) do if name:sub(1, 1) ~= "@" then local var = result.env.globals[name] - tr.globals[name] = get_typenum(var.t) + tr.globals[name] = get_typenum(trenv, var.t) end end From 5658b9a5b6bc7694c829cc0ad0ba9b5880aa1030 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Wed, 6 Dec 2023 18:22:17 -0300 Subject: [PATCH 033/224] add node.debug_type for debugging only --- tl.lua | 21 ++++++++++++++++++++- tl.tl | 21 ++++++++++++++++++++- 2 files changed, 40 insertions(+), 2 deletions(-) diff --git a/tl.lua b/tl.lua index b221f7e7d..a9efc37eb 100644 --- a/tl.lua +++ b/tl.lua @@ -1441,6 +1441,8 @@ local Node = {ExpectedContext = {}, } + + @@ -3926,7 +3928,8 @@ local function recurse_node(root, end if TL_DEBUG then - tl_debug_indent_pop("}}}", "***", ast.y, ast.x, "[%s]", kprint) + local typ = ast.debug_type and " = " .. show_type(ast.debug_type) or "" + tl_debug_indent_pop("}}}", "***", ast.y, ast.x, "[%s]%s", kprint, typ) end return ret @@ -6360,6 +6363,10 @@ tl.type_check = function(ast, opts) end local name = where.filename or filename + if TL_DEBUG then + io.stderr:write("ERROR:" .. where.y .. ":" .. where.x .. ": " .. msg .. "\n") + end + return { y = where.y, x = where.x, @@ -11438,6 +11445,14 @@ tl.type_check = function(ast, opts) end end + local function debug_type_after(fn) + return function(node, children, t) + t = fn and fn(node, children, t) or t + node.debug_type = t + return t + end + end + if opts.run_internal_compiler_checks then visit_node.after = internal_compiler_check(visit_node.after) visit_type.after = internal_compiler_check(visit_type.after) @@ -11448,6 +11463,10 @@ tl.type_check = function(ast, opts) visit_type.after = store_type_after(visit_type.after) end + if TL_DEBUG then + visit_node.after = debug_type_after(visit_node.after) + end + visit_type.cbs["tupletable"] = visit_type.cbs["string"] visit_type.cbs["typetype"] = visit_type.cbs["string"] visit_type.cbs["nestedtype"] = visit_type.cbs["string"] diff --git a/tl.tl b/tl.tl index ab516b4e5..0530721b8 100644 --- a/tl.tl +++ b/tl.tl @@ -1438,6 +1438,8 @@ local record Node decltype: Type opt: boolean + + debug_type: Type end local type Where @@ -3926,7 +3928,8 @@ local function recurse_node(root: Node, end if TL_DEBUG then - tl_debug_indent_pop("}}}", "***", ast.y, ast.x, "[%s]", kprint) + local typ = ast.debug_type and " = " .. show_type(ast.debug_type) or "" + tl_debug_indent_pop("}}}", "***", ast.y, ast.x, "[%s]%s", kprint, typ) end return ret @@ -6360,6 +6363,10 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end local name = where.filename or filename + if TL_DEBUG then + io.stderr:write("ERROR:" .. where.y .. ":" .. where.x .. ": " .. msg .. "\n") + end + return { y = where.y, x = where.x, @@ -11438,6 +11445,14 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end + local function debug_type_after(fn: function(Node, {Type}, Type): (Type)): (function(Node, {Type}, Type): (Type)) + return function(node: Node, children: {Type}, t: Type): Type + t = fn and fn(node, children, t) or t + node.debug_type = t + return t + end + end + if opts.run_internal_compiler_checks then visit_node.after = internal_compiler_check(visit_node.after) visit_type.after = internal_compiler_check(visit_type.after) @@ -11448,6 +11463,10 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string visit_type.after = store_type_after(visit_type.after) end + if TL_DEBUG then + visit_node.after = debug_type_after(visit_node.after) + end + visit_type.cbs["tupletable"] = visit_type.cbs["string"] visit_type.cbs["typetype"] = visit_type.cbs["string"] visit_type.cbs["nestedtype"] = visit_type.cbs["string"] From c54e4f5369aefd39f97c5e54bfd9720371d9e3e8 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Fri, 8 Dec 2023 02:18:16 -0300 Subject: [PATCH 034/224] refactor type relation functions is_a and same_type --- spec/assignment/to_array_spec.lua | 7 + spec/declaration/record_spec.lua | 8 +- spec/operator/eq_spec.lua | 2 +- spec/stdlib/require_spec.lua | 69 +- tl.lua | 1127 +++++++++++++++------------- tl.tl | 1163 +++++++++++++++-------------- 6 files changed, 1305 insertions(+), 1071 deletions(-) diff --git a/spec/assignment/to_array_spec.lua b/spec/assignment/to_array_spec.lua index 540fddba7..7c6fbaa3e 100644 --- a/spec/assignment/to_array_spec.lua +++ b/spec/assignment/to_array_spec.lua @@ -64,4 +64,11 @@ describe("assignment to array", function() { msg = "unused variable b: Alias" }, })) + it("catches an incompatible tupletable", util.check_type_error([[ + local a: {string} + local t: {string, number} = { "hello", 123 } + a = t + ]], { + { y = 3, msg = "in assignment: got {string | number} (from {string, number}), expected {string}" }, + })) end) diff --git a/spec/declaration/record_spec.lua b/spec/declaration/record_spec.lua index da1aaa71b..5e1b34954 100644 --- a/spec/declaration/record_spec.lua +++ b/spec/declaration/record_spec.lua @@ -805,15 +805,15 @@ for i, name in ipairs({"records", "arrayrecords", "interfaces", "arrayinterfaces ]], { { y = 5, msg = "in local declaration: foo: got {}, expected Foo" }, select(i, - { y = 6, msg = "in assignment: userdata "..statement.." doesn't match: Foo" }, - { y = 6, msg = "in assignment: userdata "..statement.." doesn't match: Foo" }, + { y = 6, msg = "in assignment: record is not a userdata" }, + { y = 6, msg = "in assignment: record is not a userdata" }, { y = 6, msg = "in assignment: got record (a: integer), expected Foo" }, { y = 6, msg = "in assignment: got record (a: integer), expected Foo" } ), { y = 8, msg = "argument 1: got {}, expected Foo" }, select(i, - { y = 9, msg = "argument 1: userdata "..statement.." doesn't match: Foo" }, - { y = 9, msg = "argument 1: userdata "..statement.." doesn't match: Foo" }, + { y = 9, msg = "argument 1: record is not a userdata" }, + { y = 9, msg = "argument 1: record is not a userdata" }, { y = 9, msg = "argument 1: got record (a: integer), expected Foo" }, { y = 9, msg = "argument 1: got record (a: integer), expected Foo" } ), diff --git a/spec/operator/eq_spec.lua b/spec/operator/eq_spec.lua index af0ddf2db..7406d2501 100644 --- a/spec/operator/eq_spec.lua +++ b/spec/operator/eq_spec.lua @@ -27,7 +27,7 @@ describe("==", function() print("unreachable") end ]], { - { msg = "not comparable for equality" } + { msg = "string \"hello\" is not a member of MyEnum" } })) end) diff --git a/spec/stdlib/require_spec.lua b/spec/stdlib/require_spec.lua index 2703b6d2d..5fc70d40f 100644 --- a/spec/stdlib/require_spec.lua +++ b/spec/stdlib/require_spec.lua @@ -140,7 +140,7 @@ describe("require", function() Point = Point, } - function Point:move(x: number, y: number) + function Point:move(x: number, y: number): Point self.x = self.x + x self.y = self.y + y end @@ -203,7 +203,7 @@ describe("require", function() Point = Point, } - function Point:move(x: number, y: number) + function Point:move(x: number, y: number): Point self.x = self.x + x self.y = self.y + y end @@ -248,6 +248,71 @@ describe("require", function() assert.same({}, result.type_errors) end) + it("return types of exported functions are checked", function () + -- ok + util.mock_io(finally, { + ["point.tl"] = [[ + local type Point = record + x: number + y: number + end + + local point = { + Point = Point, + } + + function Point:move(x: number, y: number) + self.x = self.x + x + self.y = self.y + y + end + + return point + ]], + ["bar.tl"] = [[ + local mypoint = require "point" + + local type rec = record + xx: number + yy: number + end + + local function get_point(): mypoint.Point + return { x = 100, y = 100 } + end + + return { + get_point = get_point, + rec = rec, + } + ]], + ["foo.tl"] = [[ + local pnt = require "point" + local bar = require "bar" + + global function use_point(p: pnt.Point) + print(p.x, p.y) + end + + use_point(bar.get_point():move(5, 5)) + local r: bar.rec = { + xx = 10, + yy = 20, + } + ]], + }) + local result, err = tl.process("foo.tl") + + assert.same({}, result.syntax_errors) + assert.same({ + { + filename = "foo.tl", + msg = "wrong number of arguments (given 0, expects 1)", + x = 22, + y = 8, + }, + }, result.type_errors) + end) + it("equality of nominal types does not depend on module names", function () -- ok util.mock_io(finally, { diff --git a/tl.lua b/tl.lua index a9efc37eb..492412f75 100644 --- a/tl.lua +++ b/tl.lua @@ -1028,6 +1028,7 @@ end + local table_types = { @@ -1065,6 +1066,7 @@ local table_types = { ["invalid"] = false, ["unresolved"] = false, ["none"] = false, + ["*"] = false, } @@ -4741,6 +4743,7 @@ local typename_to_typecode = { ["unresolved"] = tl.typecodes.UNKNOWN, ["typetype"] = tl.typecodes.UNKNOWN, ["nestedtype"] = tl.typecodes.UNKNOWN, + ["*"] = tl.typecodes.UNKNOWN, } local skip_types = { @@ -6364,7 +6367,7 @@ tl.type_check = function(ast, opts) local name = where.filename or filename if TL_DEBUG then - io.stderr:write("ERROR:" .. where.y .. ":" .. where.x .. ": " .. msg .. "\n") + io.stderr:write("ERROR:" .. (where.y or -1) .. ":" .. (where.x or -1) .. ": " .. msg .. "\n") end return { @@ -6948,91 +6951,9 @@ tl.type_check = function(ast, opts) - local function compare_and_infer_typevars(t1, t2, comp) - - if t1.typevar == t2.typevar then - return true - end - - - local typevar = t2.typevar or t1.typevar - - - local vt = find_var_type(typevar) - if vt then - - if t2.typevar then - return comp(t1, vt) - else - return comp(vt, t2) - end - else - - local other = t2.typevar and t1 or t2 - local ok, resolved, errs = resolve_typevars(other) - if not ok then - return false, errs - end - if resolved.typename ~= "unknown" then - resolved = resolve_typetype(resolved) - add_var(nil, typevar, resolved) - end - return true - end - end - local same_type local is_a - - - local function match_record_fields(rec1, t2, invariant) - local fielderrs = {} - for _, k in ipairs(rec1.field_order) do - local f = rec1.fields[k] - local t2k = t2(k) - if t2k == nil then - if (not lax) and invariant then - table.insert(fielderrs, Err(f, "unknown field " .. k)) - end - else - local ok, errs - if invariant then - ok, errs = same_type(f, t2k) - else - ok, errs = is_a(f, t2k) - end - if not ok then - add_errs_prefixing(nil, errs, fielderrs, "record field doesn't match: " .. k .. ": ") - end - end - end - if #fielderrs > 0 then - return false, fielderrs - end - return true - end - - local function match_fields_to_record(rec1, rec2, invariant) - if rec1.is_userdata ~= rec2.is_userdata then - return false, { Err(rec1, "userdata record doesn't match: %s", rec2) } - end - local ok, fielderrs = match_record_fields(rec1, function(k) return rec2.fields[k] end, invariant) - if not ok then - local errs = {} - add_errs_prefixing(nil, fielderrs, errs, show_type(rec1) .. " is not a " .. show_type(rec2) .. ": ") - return false, errs - end - return true - end - - local function match_fields_to_map(rec1, map) - if not match_record_fields(rec1, function(_) return map.values end, false) then - return false, { Err(rec1, "record is not a valid map; not all fields have the same type") } - end - return true - end - local function arg_check(where, cmp, a, b, n, errs, ctx) local matches, match_errs = cmp(a, b) if not matches then @@ -7042,11 +6963,11 @@ tl.type_check = function(ast, opts) return true end - local function has_all_types_of(t1s, t2s, cmp) + local function has_all_types_of(t1s, t2s) for _, t1 in ipairs(t1s) do local found = false for _, t2 in ipairs(t2s) do - if cmp(t2, t1) then + if same_type(t2, t1) then found = true break end @@ -7290,6 +7211,22 @@ tl.type_check = function(ast, opts) return false end + local function fail_nominals(t1, t2) + local t1name = show_type(t1) + local t2name = show_type(t2) + if t1name == t2name then + local t1r = resolve_nominal(t1) + if t1r.filename then + t1name = t1name .. " (defined in " .. t1r.filename .. ":" .. t1r.y .. ")" + end + local t2r = resolve_nominal(t2) + if t2r.filename then + t2name = t2name .. " (defined in " .. t2r.filename .. ":" .. t2r.y .. ")" + end + end + return false, { Err(t1, t1name .. " is not a " .. t2name) } + end + local function are_same_nominals(t1, t2) local same_names if t1.found and t2.found then @@ -7314,149 +7251,24 @@ tl.type_check = function(ast, opts) end end - if same_names then - if t1.typevals == nil and t2.typevals == nil then - return true - elseif t1.typevals and t2.typevals and #t1.typevals == #t2.typevals then - local all_errs = {} - for i = 1, #t1.typevals do - local _, errs = same_type(t1.typevals[i], t2.typevals[i]) - add_errs_prefixing(t1, errs, all_errs, "type parameter <" .. show_type(t2.typevals[i]) .. ">: ") - end - if #all_errs == 0 then - return true - else - return false, all_errs - end - end - else - local t1name = show_type(t1) - local t2name = show_type(t2) - if t1name == t2name then - local t1r = resolve_nominal(t1) - if t1r.filename then - t1name = t1name .. " (defined in " .. t1r.filename .. ":" .. t1r.y .. ")" - end - local t2r = resolve_nominal(t2) - if t2r.filename then - t2name = t2name .. " (defined in " .. t2r.filename .. ":" .. t2r.y .. ")" - end - end - return false, { Err(t1, t1name .. " is not a " .. t2name) } - end - end - - local is_lua_table_type - local resolve_tuple_and_nominal = nil - - local function invariant_match_fields_to_record(t1, t2) - local ok, errs = match_fields_to_record(t1, t2, true) - if not ok then - return ok, errs - end - ok, errs = match_fields_to_record(t2, t1, true) - if not ok then - return ok, errs - end - return true - end - - - same_type = function(t1, t2) - assert(type(t1) == "table") - assert(type(t2) == "table") - - if t1.typeid == t2.typeid then - if TL_DEBUG then - local st1, st2 = show_type_base(t1, false, {}), show_type_base(t2, false, {}) - assert(st1 == st2, st1 .. " ~= " .. st2) - end + if not same_names then + return fail_nominals(t1, t2) + elseif t1.typevals == nil and t2.typevals == nil then return true - end - - if t1.typename == "typevar" or t2.typename == "typevar" then - return compare_and_infer_typevars(t1, t2, same_type) - end - - if t1.typename == "emptytable" and is_lua_table_type(resolve_tuple_and_nominal(t2)) then - return true - end - - if t2.typename == "emptytable" and is_lua_table_type(resolve_tuple_and_nominal(t1)) then - return true - end - - if t1.typename ~= t2.typename then - return false, { Err(t1, "got %s, expected %s", t1, t2) } - end - if t1.typename == "array" then - return same_type(t1.elements, t2.elements) - elseif t1.typename == "tupletable" then - local all_errs = {} - for i = 1, math.min(#t1.types, #t2.types) do - local ok, err = same_type(t1.types[i], t2.types[i]) - if not ok then - add_errs_prefixing(t1, err, all_errs, "values") - end - end - return any_errors(all_errs) - elseif t1.typename == "map" then - local all_errs = {} - local k_ok, k_errs = same_type(t1.keys, t2.keys) - if not k_ok then - add_errs_prefixing(t1, k_errs, all_errs, "keys") - end - local v_ok, v_errs = same_type(t1.values, t2.values) - if not v_ok then - add_errs_prefixing(t1, v_errs, all_errs, "values") - end - return any_errors(all_errs) - elseif t1.typename == "union" then - if has_all_types_of(t1.types, t2.types, same_type) and - has_all_types_of(t2.types, t1.types, same_type) then - return true - else - return false, { Err(t1, "got %s, expected %s", t1, t2) } - end - elseif t1.typename == "nominal" then - return are_same_nominals(t1, t2) - elseif t1.typename == "record" then - - if (t1.elements ~= nil) ~= (t2.elements ~= nil) then - return false, { Err(t1, "types do not have the same array interface") } - end - if t1.elements and t2.elements then - local ok, errs = same_type(t1.elements, t2.elements) - if not ok then - return ok, errs - end - end - - return invariant_match_fields_to_record(t1, t2) - elseif t1.typename == "function" then - local argdelta = t1.is_method and 1 or 0 - if #t1.args ~= #t2.args then - if t1.is_method ~= t2.is_method then - return false, { Err(t1, "different number of input arguments: method and non-method are not the same type") } - end - return false, { Err(t1, "different number of input arguments: got " .. #t1.args - argdelta .. ", expected " .. #t2.args - argdelta) } - end - if #t1.rets ~= #t2.rets then - return false, { Err(t1, "different number of return values: got " .. #t1.rets .. ", expected " .. #t2.rets) } - end - local all_errs = {} - for i = 1, #t1.args do - arg_check(t1, same_type, t1.args[i], t2.args[i], i - argdelta, all_errs, "argument") - end - for i = 1, #t1.rets do - local _, errs = same_type(t1.rets[i], t2.rets[i]) - add_errs_prefixing(t1, errs, all_errs, "return " .. i) + elseif t1.typevals and t2.typevals and #t1.typevals == #t2.typevals then + local errs = {} + for i = 1, #t1.typevals do + local _, typeval_errs = same_type(t1.typevals[i], t2.typevals[i]) + add_errs_prefixing(t1, typeval_errs, errs, "type parameter <" .. show_type(t2.typevals[i]) .. ">: ") end - return any_errors(all_errs) + return any_errors(errs) end return true end + local is_lua_table_type + local resolve_tuple_and_nominal + local function unite(types, flatten_constants) if #types == 1 then return types[1] @@ -7518,25 +7330,6 @@ tl.type_check = function(ast, opts) end end - local function add_map_errors(errs, ctx, ctx_errs) - if ctx_errs then - for _, err in ipairs(ctx_errs) do - err.msg = ctx .. err.msg - table.insert(errs, err) - end - end - end - - local function combine_map_errs(key_errs, value_errs) - if not key_errs and not value_errs then - return true - end - local errs = {} - add_map_errors(errs, "in map key: ", key_errs) - add_map_errors(errs, "in map value: ", value_errs) - return false, errs - end - do local known_table_types = { array = true, @@ -7571,7 +7364,7 @@ tl.type_check = function(ast, opts) }) for i = 2, #tupletype.types do arr_type = expand_type(where, arr_type, a_type({ elements = tupletype.types[i], typename = "array" })) - if not arr_type or not arr_type.elements then + if not arr_type.elements then return nil, { Err(tupletype, "unable to convert tuple %s to array", tupletype) } end end @@ -7582,349 +7375,630 @@ tl.type_check = function(ast, opts) return t.typename == "nominal" and t.names[1] == "@self" end + local function compare_false(_, _) + return false + end - is_a = function(t1, t2, for_equality) - assert(type(t1) == "table") - assert(type(t2) == "table") + local function compare_true(_, _) + return true + end - if lax and (is_unknown(t1) or is_unknown(t2)) then + local function subtype_nominal(a, b) + if is_self(a) and is_self(b) then return true end - if t1.typeid == t2.typeid then - if TL_DEBUG then - local st1, st2 = show_type_base(t1, false, {}), show_type_base(t2, false, {}) - assert(st1 == st2, st1 .. " ~= " .. st2) - end - return true + local ra = a.typename == "nominal" and resolve_nominal(a) or a + local rb = b.typename == "nominal" and resolve_nominal(b) or b + local ok, errs = is_a(ra, rb) + if errs and #errs == 1 and errs[1].msg:match("^got ") then + return false end + return ok, errs + end + + local function subtype_array(a, b) - if t1.typename == "bad_nominal" or t2.typename == "bad_nominal" then + if (not a.elements) or (not is_a(a.elements, b.elements)) then return false end + if a.types and #a.types > 1 then - if t2.typename ~= "tuple" then - t1 = resolve_tuple(t1) + for i = 1, #a.types do + local e = a.types[i] + if not is_a(e, b.elements) then + return false, { Err(a, "%s is not a member of %s", e, b.elements) } + end + end end + return true + end - if t2.typename == "tuple" and t1.typename ~= "tuple" then - t1 = a_type({ - typename = "tuple", - [1] = t1, - }) - end + local function subtype_record(a, b) - if t1.typename == "typevar" or t2.typename == "typevar" then - return compare_and_infer_typevars(t1, t2, is_a) + if a.elements and b.elements then + if not is_a(a.elements, b.elements) then + return false, { Err(a, "array parts have incompatible element types") } + end end + if a.is_userdata ~= b.is_userdata then + return false, { Err(a, a.is_userdata and "userdata is not a record" or +"record is not a userdata"), } + end - if t1.typename == "nil" then - return true + local errs = {} + for _, k in ipairs(a.field_order) do + local ak = a.fields[k] + local bk = b.fields[k] + if bk then + local ok, fielderrs = is_a(ak, bk) + if not ok then + add_errs_prefixing(nil, fielderrs, errs, "record field doesn't match: " .. k .. ": ") + end + end + end + if #errs > 0 then + for _, err in ipairs(errs) do + err.msg = show_type(a) .. " is not a " .. show_type(b) .. ": " .. err.msg + end + return false, errs end + return true + end - if t2.typename == "any" then - return true + local eqtype_record = function(a, b) - elseif is_self(t1) then - if is_self(t2) then - return true + if (a.elements ~= nil) ~= (b.elements ~= nil) then + return false, { Err(a, "types do not have the same array interface") } + end + if a.elements then + local ok, errs = same_type(a.elements, b.elements) + if not ok then + return ok, errs end + end + + local ok, errs = subtype_record(a, b) + if not ok then + return ok, errs + end + ok, errs = subtype_record(b, a) + if not ok then + return ok, errs + end + return true + end - return is_a(resolve_tuple_and_nominal(t1), t2, for_equality) + local function compare_map(ak, bk, av, bv, no_hack) + local ok1, errs_k = same_type(ak, bk) + local ok2, errs_v = same_type(av, bv) - elseif is_self(t2) then - return is_a(t1, resolve_tuple_and_nominal(t2), for_equality) - elseif t1.typename == "union" then + if bk.typename == "any" and not no_hack then + ok1, errs_k = true, nil + end + if bv.typename == "any" and not no_hack then + ok2, errs_v = true, nil + end + if ok1 and ok2 then + return true + end + for i = 1, errs_k and #errs_k or 0 do + errs_k[i].msg = "in map key: " .. errs_k[i].msg + end + for i = 1, errs_v and #errs_v or 0 do + errs_v[i].msg = "in map value: " .. errs_v[i].msg + end + if errs_k and errs_v then + for i = 1, #errs_v do + table.insert(errs_k, errs_v[i]) + end + return false, errs_k + end + return false, errs_k or errs_v + end - if t2.typename == "union" then - local used = {} - for _, t in ipairs(t1.types) do - local ok = false - begin_scope() - for _, u in ipairs(t2.types) do - if not used[u] then - if is_a(t, u, for_equality) then - used[u] = t - ok = true - break - end - end - end - end_scope() - if not ok then - return false, { Err(t1, "got %s, expected %s", t1, t2) } - end - end + local function compare_or_infer_typevar(typevar, a, b, cmp) - for u, t in pairs(used) do - is_a(t, u, for_equality) - end - return true + local vt = find_var_type(typevar) + if vt then + return cmp(a or vt, b or vt) + else - else - for _, t in ipairs(t1.types) do - if not is_a(t, t2, for_equality) then - return false, { Err(t1, "got %s, expected %s", t1, t2) } - end - end + local ok, r, errs = resolve_typevars(a or b) + if not ok then + return false, errs + end + if r.typevar == typevar then return true end + add_var(nil, typevar, r) + return true + end + end - - - elseif t2.typename == "union" then - for _, t in ipairs(t2.types) do - if is_a(t1, t, for_equality) then - return true - end + local function exists_supertype_in(t, xs) + for _, x in ipairs(xs.types) do + if is_a(t, x) then + return x end + end + end - - - elseif t2.typename == "poly" then - for _, t in ipairs(t2.types) do - if not is_a(t1, t, for_equality) then - return false, { Err(t1, "cannot match against all alternatives of the polymorphic type") } - end - end - return true - + local emptytable_relations = { + ["array"] = compare_true, + ["map"] = compare_true, + ["tupletable"] = compare_true, + ["interface"] = function(_a, b) + return not b.is_userdata + end, + ["record"] = function(_a, b) + return not b.is_userdata + end, + } - elseif t1.typename == "poly" then - for _, t in ipairs(t1.types) do - if is_a(t, t2, for_equality) then + local eqtype_relations + eqtype_relations = { + ["bad_nominal"] = { + ["*"] = compare_false, + }, + ["typevar"] = { + ["typevar"] = function(a, b) + if a.typevar == b.typevar then return true end - end - return false, { Err(t1, "cannot match against any alternatives of the polymorphic type") } - elseif t1.typename == "nominal" and t2.typename == "nominal" then - local t1r = resolve_tuple_and_nominal(t1) - local t2r = resolve_tuple_and_nominal(t2) - if t1r.typename == "union" or t2r.typename == "union" then - return is_a(t1r, t2r, for_equality) - end - return are_same_nominals(t1, t2) - elseif t1.typename == "enum" and t2.typename == "string" then - local ok - if for_equality then - ok = t2.tk and t1.enumset[unquote(t2.tk)] - else - ok = true - end - if ok then - return true - else - return false, { Err(t1, "enum is incompatible with %s", t2) } - end - elseif t1.typename == "integer" and t2.typename == "number" then - return true - elseif t1.typename == "string" and t2.typename == "enum" then - local ok = t1.tk and t2.enumset[unquote(t1.tk)] - if ok then + return compare_or_infer_typevar(b.typevar, a, nil, same_type) + end, + ["*"] = function(a, b) + return compare_or_infer_typevar(a.typevar, nil, b, same_type) + end, + }, + ["emptytable"] = emptytable_relations, + ["tupletable"] = { + ["tupletable"] = function(a, b) + for i = 1, math.min(#a.types, #b.types) do + if not same_type(a.types[i], b.types[i]) then + return false, { Err(a, "in tuple entry " .. tostring(i) .. ": got %s, expected %s", a.types[i], b.types[i]) } + end + end + if #a.types ~= #b.types then + return false, { Err(a, "tuples have different size", a, b) } + end return true - else - if t1.tk then - return false, { Err(t1, "%s is not a member of %s", t1, t2) } - else - return false, { Err(t1, "string is not a %s", t2) } + end, + }, + ["array"] = { + ["array"] = function(a, b) + return same_type(a.elements, b.elements) + end, + }, + ["map"] = { + ["map"] = function(a, b) + return compare_map(a.keys, b.keys, a.values, b.values, true) + end, + }, + ["union"] = { + ["union"] = function(a, b) + return (has_all_types_of(a.types, b.types) and + has_all_types_of(b.types, a.types)) + end, + }, + ["nominal"] = { + ["nominal"] = are_same_nominals, + }, + ["record"] = { + ["record"] = eqtype_record, + }, + ["function"] = { + ["function"] = function(a, b) + local argdelta = a.is_method and 1 or 0 + if #a.args ~= #b.args then + if a.is_method ~= b.is_method then + return false, { Err(a, "different number of input arguments: method and non-method are not the same type") } + end + return false, { Err(a, "different number of input arguments: got " .. #a.args - argdelta .. ", expected " .. #b.args - argdelta) } end - end - elseif t1.typename == "nominal" or t2.typename == "nominal" then - local t1r = resolve_tuple_and_nominal(t1) - local t2r = resolve_tuple_and_nominal(t2) - local ok, errs = is_a(t1r, t2r, for_equality) - if errs and #errs == 1 then - if errs[1].msg:match("^got ") then - + if #a.rets ~= #b.rets then + return false, { Err(a, "different number of return values: got " .. #a.rets .. ", expected " .. #b.rets) } + end + local errs = {} + for i = 1, #a.args do + arg_check(a, same_type, a.args[i], b.args[i], i - argdelta, errs, "argument") + end + for i = 1, #a.rets do + arg_check(a, same_type, a.rets[i], b.rets[i], i, errs, "return") + end + return any_errors(errs) + end, + }, + ["*"] = { + ["bad_nominal"] = compare_false, + ["typevar"] = function(a, b) + return compare_or_infer_typevar(b.typevar, a, nil, same_type) + end, + }, + } - errs = { Err(t1, "got %s, expected %s", t1, t2) } + local subtype_relations + subtype_relations = { + ["bad_nominal"] = { + ["*"] = compare_false, + }, + ["tuple"] = { + ["tuple"] = function(a, b) + if #a ~= #b then + return false end - end - return ok, errs - elseif t1.typename == "emptytable" and is_lua_table_type(t2) then - return true - elseif t2.typename == "array" then - if is_array_type(t1) then - if is_a(t1.elements, t2.elements) then - local t1e = resolve_tuple_and_nominal(t1.elements) - local t2e = resolve_tuple_and_nominal(t2.elements) - if t2e.typename == "enum" and t1e.typename == "string" and #t1.types > 1 then - for i = 2, #t1.types do - local t = t1.types[i] - if not is_a(t, t2e) then - return false, { Err(t, "%s is not a member of %s", t, t2e) } - end - end + for i = 1, #a do + if not is_a(a[i], b[i]) then + return false end + end + return true + end, + ["*"] = function(a, b) + return is_a(resolve_tuple(a), b) + end, + }, + ["typevar"] = { + ["typevar"] = function(a, b) + if a.typevar == b.typevar then return true end - elseif t1.typename == "tupletable" then - if t2.inferred_len and t2.inferred_len > #t1.types then - return false, { Err(t1, "incompatible length, expected maximum length of " .. tostring(#t1.types) .. ", got " .. tostring(t2.inferred_len)) } + + return compare_or_infer_typevar(b.typevar, a, nil, is_a) + end, + ["*"] = function(a, b) + return compare_or_infer_typevar(a.typevar, nil, b, is_a) + end, + }, + ["nil"] = { + ["*"] = compare_true, + }, + ["union"] = { + ["union"] = function(a, b) + local used = {} + for _, t in ipairs(a.types) do + begin_scope() + local u = exists_supertype_in(t, b) + end_scope() + if not u then + return false + end + if not used[u] then + used[u] = t + end end - local t1a, err = arraytype_from_tuple(t1.inferred_at, t1) - if not t1a then - return false, err + for u, t in pairs(used) do + is_a(t, u) end - if not is_a(t1a, t2) then - return false, { Err(t2, "got %s (from %s), expected %s", t1a, t1, t2) } + return true + end, + ["*"] = function(a, b) + for _, t in ipairs(a.types) do + if not is_a(t, b) then + return false + end end return true - elseif t1.typename == "map" then - local _, errs_keys, errs_values - _, errs_keys = is_a(t1.keys, INTEGER) - _, errs_values = is_a(t1.values, t2.elements) - return combine_map_errs(errs_keys, errs_values) - end - elseif t2.typename == "record" then + end, + }, + ["poly"] = { + ["*"] = function(a, b) + if exists_supertype_in(b, a) then + return true + end + return false, { Err(a, "cannot match against any alternatives of the polymorphic type") } + end, + }, + ["nominal"] = { + ["nominal"] = function(a, b) + local ra = resolve_nominal(a) + local rb = resolve_nominal(b) + if ra.typename == "union" or rb.typename == "union" then + return is_a(ra, rb) + end + + return are_same_nominals(a, b) + end, + ["*"] = subtype_nominal, + }, + ["enum"] = { + ["string"] = compare_true, + }, + ["string"] = { + ["enum"] = function(a, b) + if not a.tk then + return false, { Err(a, "string is not a %s", b) } + end - if t1.typename == "tupletable" and t2.elements then - if t2.inferred_len and t2.inferred_len > #t1.types then - return false, { Err(t1, "incompatible length, expected maximum length of " .. tostring(#t1.types) .. ", got " .. tostring(t2.inferred_len)) } + if b.enumset[unquote(a.tk)] then + return true end - local t1a, err = arraytype_from_tuple(t1.inferred_at, t1) - if not t1a then - return false, err + + return false, { Err(a, "%s is not a member of %s", a, b) } + end, + }, + ["integer"] = { + ["number"] = compare_true, + }, + ["interface"] = { + ["array"] = subtype_array, + ["record"] = subtype_record, + ["tupletable"] = function(a, b) + return subtype_relations["record"]["tupletable"](a, b) + end, + }, + ["emptytable"] = emptytable_relations, + ["tupletable"] = { + ["tupletable"] = function(a, b) + for i = 1, math.min(#a.types, #b.types) do + if not is_a(a.types[i], b.types[i]) then + return false, { Err(a, "in tuple entry " .. +tostring(i) .. ": got %s, expected %s", +a.types[i], b.types[i]), } + end end - if not is_a(t1a, t2) then - return false, { Err(t2, "got %s (from %s), expected %s", t1a, t1, t2) } + if #a.types > #b.types then + return false, { Err(a, "tuple %s is too big for tuple %s", a, b) } end return true - end - if t1.elements and t2.elements then - if not is_a(t1.elements, t2.elements) then - return false, { Err(t1, "array parts have incompatible element types") } + end, + ["record"] = function(a, b) + if b.elements then + return subtype_relations["tupletable"]["array"](a, b) end - if t1.typename == "array" then - return true + end, + ["array"] = function(a, b) + if b.inferred_len and b.inferred_len > #a.types then + return false, { Err(a, "incompatible length, expected maximum length of " .. tostring(#a.types) .. ", got " .. tostring(b.inferred_len)) } end - end - - if is_record_type(t1) then - return match_fields_to_record(t1, t2, false) - elseif is_typetype(t1) and is_record_type(t1.def) then - return is_a(t1.def, t2, for_equality) - end - elseif t2.typename == "map" then - if t1.typename == "map" then - local _, errs_keys, errs_values - if t2.keys.typename ~= "any" then - _, errs_keys = same_type(t2.keys, t1.keys) + local aa, err = arraytype_from_tuple(a.inferred_at, a) + if not aa then + return false, err end - if t2.values.typename ~= "any" then - _, errs_values = same_type(t1.values, t2.values) + if not is_a(aa, b) then + return false, { Err(a, "got %s (from %s), expected %s", aa, a, b) } end - return combine_map_errs(errs_keys, errs_values) - elseif t1.typename == "array" or t1.typename == "tupletable" then - local elements - if t1.typename == "tupletable" then - local arr_type = arraytype_from_tuple(t1.inferred_at, t1) - if not arr_type then - return false, { Err(t1, "Unable to convert tuple %s to map", t1) } - end - elements = arr_type.elements - else - elements = t1.elements - end - local _, errs_keys, errs_values - _, errs_keys = is_a(INTEGER, t2.keys) - _, errs_values = is_a(elements, t2.values) - return combine_map_errs(errs_keys, errs_values) - elseif is_record_type(t1) then - if not is_a(t2.keys, STRING) then - return false, { Err(t1, "can't match a record to a map with non-string keys") } - end - if t2.keys.typename == "enum" then - for _, k in ipairs(t1.field_order) do - if not t2.keys.enumset[k] then - return false, { Err(t1, "key is not an enum value: " .. k) } - end - end + return true + end, + ["map"] = function(a, b) + local aa = arraytype_from_tuple(a.inferred_at, a) + if not aa then + return false, { Err(a, "Unable to convert tuple %s to map", a) } end - return match_fields_to_map(t1, t2) - end - elseif t2.typename == "tupletable" then - if t1.typename == "tupletable" then - for i = 1, math.min(#t1.types, #t2.types) do - if not is_a(t1.types[i], t2.types[i], for_equality) then - return false, { Err(t1, "in tuple entry " .. tostring(i) .. ": got %s, expected %s", t1.types[i], t2.types[i]) } + + return compare_map(INTEGER, b.keys, aa.elements, b.values) + end, + }, + ["record"] = { + ["record"] = subtype_record, + ["array"] = subtype_array, + ["map"] = function(a, b) + if not is_a(b.keys, STRING) then + return false, { Err(a, "can't match a record to a map with non-string keys") } + end + + for _, k in ipairs(a.field_order) do + if b.keys.typename == "enum" and not b.keys.enumset[k] then + return false, { Err(a, "key is not an enum value: " .. k) } + end + if not is_a(a.fields[k], b.values) then + return false, { Err(a, "record is not a valid map; not all fields have the same type") } end end - if for_equality and #t1.types ~= #t2.types then - return false, { Err(t1, "tuples are not the same size") } + + return true + end, + ["tupletable"] = function(a, b) + if a.elements then + return subtype_relations["array"]["tupletable"](a, b) end - if #t1.types > #t2.types then - return false, { Err(t1, "tuple %s is too big for tuple %s", t1, t2) } + end, + }, + ["array"] = { + ["array"] = subtype_array, + ["record"] = function(a, b) + if b.elements then + return subtype_array(a, b) end - return true - elseif is_array_type(t1) then - if t1.inferred_len and t1.inferred_len > #t2.types then - return false, { Err(t1, "incompatible length, expected maximum length of " .. tostring(#t2.types) .. ", got " .. tostring(t1.inferred_len)) } + end, + ["map"] = function(a, b) + return compare_map(INTEGER, b.keys, a.elements, b.values) + end, + ["tupletable"] = function(a, b) + local alen = a.inferred_len or 0 + if alen > #b.types then + return false, { Err(a, "incompatible length, expected maximum length of " .. tostring(#b.types) .. ", got " .. tostring(alen)) } end - local len = (t1.inferred_len and t1.inferred_len > 0) and - t1.inferred_len or - #t2.types - - for i = 1, len do - if not is_a(t1.elements, t2.types[i], for_equality) then - return false, { Err(t1, "tuple entry " .. tostring(i) .. " of type %s does not match type of array elements, which is %s", t2.types[i], t1.elements) } + for i = 1, (alen > 0) and alen or #b.types do + if not is_a(a.elements, b.types[i]) then + return false, { Err(a, "tuple entry " .. i .. " of type %s does not match type of array elements, which is %s", b.types[i], a.elements) } end end return true - end - elseif t1.typename == "function" and t2.typename == "function" then - local all_errs = {} - if (not t2.args.is_va) and #t1.args > #t2.args then - table.insert(all_errs, Err(t1, "incompatible number of arguments: got " .. #t1.args .. " %s, expected " .. #t2.args .. " %s", t1.args, t2.args)) - else - for i = ((t1.is_method or t2.is_method) and 2 or 1), #t1.args do - arg_check(nil, is_a, t1.args[i], t2.args[i] or ANY, i, all_errs, "argument") + end, + }, + ["map"] = { + ["map"] = function(a, b) + return compare_map(a.keys, b.keys, a.values, b.values) + end, + ["array"] = function(a, b) + return compare_map(a.keys, INTEGER, a.values, b.elements) + end, + }, + ["typetype"] = { + ["record"] = function(a, b) + return subtype_record(a.def, b) + end, + }, + ["function"] = { + ["function"] = function(a, b) + local errs = {} + + local aa, ba = a.args, b.args + set_min_arity(a) + set_min_arity(b) + if (not ba.is_va) and a.min_arity > b.min_arity then + table.insert(errs, Err(a, "incompatible number of arguments: got " .. show_arity(a) .. " %s, expected " .. show_arity(b) .. " %s", aa, ba)) + else + for i = ((a.is_method or b.is_method) and 2 or 1), #aa do + arg_check(nil, is_a, aa[i], ba[i] or ANY, i, errs, "argument") + end end - end - local diff_by_va = #t2.rets - #t1.rets == 1 and t2.rets.is_va - if #t1.rets < #t2.rets and not diff_by_va then - table.insert(all_errs, Err(t1, "incompatible number of returns: got " .. #t1.rets .. " %s, expected " .. #t2.rets .. " %s", t1.rets, t2.rets)) - else - local nrets = #t2.rets - if diff_by_va then - nrets = nrets - 1 + + local ar, br = a.rets, b.rets + local diff_by_va = #br - #ar == 1 and br.is_va + if #ar < #br and not diff_by_va then + table.insert(errs, Err(a, "incompatible number of returns: got " .. #ar .. " %s, expected " .. #br .. " %s", ar, br)) + else + local nrets = #br + if diff_by_va then + nrets = nrets - 1 + end + for i = 1, nrets do + arg_check(nil, is_a, ar[i], br[i], i, errs, "return") + end end - for i = 1, nrets do - local _, errs = is_a(t1.rets[i], t2.rets[i]) - add_errs_prefixing(nil, errs, all_errs, "return " .. i .. ": ") + + return any_errors(errs) + end, + }, + ["*"] = { + ["bad_nominal"] = compare_false, + ["any"] = compare_true, + ["tuple"] = function(a, b) + return is_a(TUPLE({ a }), b) + end, + ["typevar"] = function(a, b) + return compare_or_infer_typevar(b.typevar, a, nil, is_a) + end, + ["union"] = exists_supertype_in, + + + ["nominal"] = subtype_nominal, + ["poly"] = function(a, b) + for _, t in ipairs(b.types) do + if not is_a(a, t) then + return false, { Err(a, "cannot match against all alternatives of the polymorphic type") } + end end - end - if #all_errs == 0 then return true - else - return false, all_errs - end - elseif lax and ((not for_equality) and t2.typename == "boolean") then + end, + }, + } + + local type_priorities = { + + ["bad_nominal"] = 1, + ["tuple"] = 2, + ["typevar"] = 3, + ["nil"] = 4, + ["any"] = 5, + ["union"] = 6, + ["poly"] = 7, + ["nominal"] = 8, + + ["enum"] = 9, + ["string"] = 9, + ["integer"] = 9, + ["boolean"] = 9, + + ["interface"] = 10, + + ["emptytable"] = 11, + ["tupletable"] = 12, + + ["record"] = 13, + ["array"] = 13, + ["map"] = 13, + ["function"] = 13, + } + + if lax then + type_priorities["unknown"] = 0 + + subtype_relations["unknown"] = {} + subtype_relations["unknown"]["*"] = compare_true + subtype_relations["*"]["unknown"] = compare_true + + subtype_relations["boolean"] = {} + subtype_relations["boolean"]["boolean"] = compare_true + subtype_relations["*"]["boolean"] = compare_true + end + + local function compare_types(relations, t1, t2) + if t1.typeid == t2.typeid then return true - elseif t1.typename == t2.typename then - return true end - return false, { Err(t1, "got %s, expected %s", t1, t2) } + local s1 = relations[t1.typename] + local fn = s1 and s1[t2.typename] + if not fn then + local p1 = type_priorities[t1.typename] or 999 + local p2 = type_priorities[t2.typename] or 999 + fn = (p1 < p2 and (s1 and s1["*"]) or (relations["*"][t2.typename])) + end + + local ok, err + if fn then + if fn == compare_true then + return true + end + ok, err = fn(t1, t2) + else + ok = t1.typename == t2.typename + end + if (not ok) and not err then + return false, { Err(t1, "got %s, expected %s", t1, t2) } + end + return ok, err + end + + + is_a = function(t1, t2) + return compare_types(subtype_relations, t1, t2) + end + + + same_type = function(t1, t2) + + + return compare_types(eqtype_relations, t1, t2) + end + + if TL_DEBUG then + local orig_is_a = is_a + is_a = function(t1, t2) + assert(type(t1) == "table") + assert(type(t2) == "table") + + if t1.typeid == t2.typeid then + local st1, st2 = show_type_base(t1, false, {}), show_type_base(t2, false, {}) + assert(st1 == st2, st1 .. " ~= " .. st2) + return true + end + + return orig_is_a(t1, t2) + end end local function assert_is_a(where, t1, t2, context, name) @@ -8916,6 +8990,27 @@ tl.type_check = function(ast, opts) end end + local function typetype_to_nominal(where, name, t, resolved) + assert(t.typename == "typetype") + + local typevals + if t.def.typeargs then + typevals = {} + for _, a in ipairs(t.def.typeargs) do + table.insert(typevals, a_type({ typename = "typevar", typevar = a.typearg })) + end + end + return a_type({ + y = where.y, + x = where.x, + typename = "nominal", + typevals = typevals, + names = { name }, + found = t, + resolved = resolved, + }) + end + local function get_self_type(exp) if exp.kind == "type_identifier" then @@ -8925,21 +9020,7 @@ tl.type_check = function(ast, opts) end if t.typename == "typetype" then - local typevals - if t.def.typeargs then - typevals = {} - for _, a in ipairs(t.def.typeargs) do - table.insert(typevals, a_type({ typename = "typevar", typevar = a.typearg })) - end - end - return a_type({ - y = exp.y, - x = exp.x, - typename = "nominal", - typevals = typevals, - names = { exp.tk }, - found = t, - }) + return typetype_to_nominal(exp, exp.tk, t) else return t end @@ -10318,6 +10399,7 @@ tl.type_check = function(ast, opts) if final_tuple.is_va then tuple.is_va = true end + tuple[n] = nil for i, c in ipairs(final_tuple) do tuple[n + i - 1] = c end @@ -10981,11 +11063,17 @@ tl.type_check = function(ast, opts) - if is_a(b, a, true) or a.typename == "typevar" then + if ra.typename == "enum" and rb.typename == "string" then + if not (rb.tk and ra.enumset[unquote(rb.tk)]) then + return invalid_at(node, "%s is not a member of %s", b, a) + end + elseif ra.typename == "tupletable" and rb.typename == "tupletable" and #ra.types ~= #rb.types then + return invalid_at(node, "tuples are not the same size") + elseif is_a(b, a) or a.typename == "typevar" then if node.op.op == "==" and node.e1.kind == "variable" then node.known = EqFact({ var = node.e1.tk, typ = b, where = node }) end - elseif is_a(a, b, true) or b.typename == "typevar" then + elseif is_a(a, b) or b.typename == "typevar" then if node.op.op == "==" and node.e2.kind == "variable" then node.known = EqFact({ var = node.e2.tk, typ = a, where = node }) end @@ -11147,14 +11235,7 @@ tl.type_check = function(ast, opts) end if is_typetype(t) then - t = a_type({ - y = node.y, - x = node.x, - typename = "nominal", - names = { node.tk }, - found = t, - resolved = t, - }) + t = typetype_to_nominal(node, node.tk, t, t) end return t diff --git a/tl.tl b/tl.tl index 0530721b8..18825c20c 100644 --- a/tl.tl +++ b/tl.tl @@ -1028,6 +1028,7 @@ local enum TypeName "invalid" -- producing a new value of this type (not propagating) must always produce a type error "unresolved" "none" + "*" end local table_types : {TypeName:boolean} = { @@ -1065,6 +1066,7 @@ local table_types : {TypeName:boolean} = { ["invalid"] = false, ["unresolved"] = false, ["none"] = false, + ["*"] = false, } local record Type @@ -1628,7 +1630,7 @@ local function parse_table_value(ps: ParseState, i: integer): integer, Node, int return i, e end -local function parse_table_item(ps: ParseState, i: integer, n: integer): integer, Node, integer +local function parse_table_item(ps: ParseState, i: integer, n?: integer): integer, Node, integer local node = new_node(ps.tokens, i, "table_item") if ps.tokens[i].kind == "$EOF$" then return fail(ps, i, "unexpected eof") @@ -4741,6 +4743,7 @@ local typename_to_typecode : {TypeName:integer} = { ["unresolved"] = tl.typecodes.UNKNOWN, ["typetype"] = tl.typecodes.UNKNOWN, ["nestedtype"] = tl.typecodes.UNKNOWN, + ["*"] = tl.typecodes.UNKNOWN, } local skip_types: {TypeName: boolean} = { @@ -6364,7 +6367,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local name = where.filename or filename if TL_DEBUG then - io.stderr:write("ERROR:" .. where.y .. ":" .. where.x .. ": " .. msg .. "\n") + io.stderr:write("ERROR:" .. (where.y or -1) .. ":" .. (where.x or -1) .. ": " .. msg .. "\n") end return { @@ -6576,7 +6579,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return rt end - resolve_typevars = function(typ: Type, fn_var: ResolveType, fn_arg: ResolveType): boolean, Type, {Error} + resolve_typevars = function(typ: Type, fn_var?: ResolveType, fn_arg?: ResolveType): boolean, Type, {Error} local errs: {Error} local seen: {Type:Type} = {} local resolved: {string:boolean} = {} @@ -6946,92 +6949,10 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return var end - local type CompareTypes = function(Type, Type, ? boolean): boolean, {Error} - - local function compare_and_infer_typevars(t1: Type, t2: Type, comp: CompareTypes): boolean, {Error} - -- if both are typevars and they are the same variable, nothing to do here - if t1.typevar == t2.typevar then - return true - end - - -- we have one typevar to compare to or infer to the other term - local typevar = t2.typevar or t1.typevar - - -- does the typevar currently match to a type? - local vt = find_var_type(typevar) - if vt then - -- If so, compare it to the other type - if t2.typevar then - return comp(t1, vt) - else - return comp(vt, t2) - end - else - -- otherwise, infer it to the other type - local other = t2.typevar and t1 or t2 - local ok, resolved, errs = resolve_typevars(other) - if not ok then - return false, errs - end - if resolved.typename ~= "unknown" then - resolved = resolve_typetype(resolved) - add_var(nil, typevar, resolved) - end - return true - end - end + local type CompareTypes = function(Type, Type): boolean, {Error} local same_type: function(t1: Type, t2: Type): boolean, {Error} - local is_a: function(Type, Type, ? boolean): boolean, {Error} - - local type TypeGetter = function(string): Type - - local function match_record_fields(rec1: Type, t2: TypeGetter, invariant: boolean): boolean, {Error} - local fielderrs: {Error} = {} - for _, k in ipairs(rec1.field_order) do - local f = rec1.fields[k] - local t2k = t2(k) - if t2k == nil then - if (not lax) and invariant then - table.insert(fielderrs, Err(f, "unknown field " .. k)) - end - else - local ok, errs: boolean, {Error} - if invariant then - ok, errs = same_type(f, t2k) - else - ok, errs = is_a(f, t2k) - end - if not ok then - add_errs_prefixing(nil, errs, fielderrs, "record field doesn't match: " .. k .. ": ") - end - end - end - if #fielderrs > 0 then - return false, fielderrs - end - return true - end - - local function match_fields_to_record(rec1: Type, rec2: Type, invariant: boolean): boolean, {Error} - if rec1.is_userdata ~= rec2.is_userdata then - return false, { Err(rec1, "userdata record doesn't match: %s", rec2) } - end - local ok, fielderrs = match_record_fields(rec1, function(k: string): Type return rec2.fields[k] end, invariant) - if not ok then - local errs = {} - add_errs_prefixing(nil, fielderrs, errs, show_type(rec1) .. " is not a " .. show_type(rec2) .. ": ") - return false, errs - end - return true - end - - local function match_fields_to_map(rec1: Type, map: Type): boolean, {Error} - if not match_record_fields(rec1, function(_: string): Type return map.values end, false) then - return false, { Err(rec1, "record is not a valid map; not all fields have the same type") } - end - return true - end + local is_a: function(Type, Type): boolean, {Error} local function arg_check(where: Where, cmp: CompareTypes, a: Type, b: Type, n: integer, errs: {Error}, ctx: string): boolean local matches, match_errs = cmp(a, b) @@ -7042,11 +6963,11 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return true end - local function has_all_types_of(t1s: {Type}, t2s: {Type}, cmp: CompareTypes): boolean + local function has_all_types_of(t1s: {Type}, t2s: {Type}): boolean for _, t1 in ipairs(t1s) do local found = false for _, t2 in ipairs(t2s) do - if cmp(t2, t1) then + if same_type(t2, t1) then found = true break end @@ -7290,6 +7211,22 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return false end + local function fail_nominals(t1: Type, t2: Type): boolean, {Error} + local t1name = show_type(t1) + local t2name = show_type(t2) + if t1name == t2name then + local t1r = resolve_nominal(t1) + if t1r.filename then + t1name = t1name .. " (defined in " .. t1r.filename .. ":" .. t1r.y .. ")" + end + local t2r = resolve_nominal(t2) + if t2r.filename then + t2name = t2name .. " (defined in " .. t2r.filename .. ":" .. t2r.y .. ")" + end + end + return false, { Err(t1, t1name .. " is not a " .. t2name) } + end + local function are_same_nominals(t1: Type, t2: Type): boolean, {Error} local same_names: boolean if t1.found and t2.found then @@ -7314,149 +7251,24 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - if same_names then - if t1.typevals == nil and t2.typevals == nil then - return true - elseif t1.typevals and t2.typevals and #t1.typevals == #t2.typevals then - local all_errs = {} - for i = 1, #t1.typevals do - local _, errs = same_type(t1.typevals[i], t2.typevals[i]) - add_errs_prefixing(t1, errs, all_errs, "type parameter <" .. show_type(t2.typevals[i]) .. ">: ") - end - if #all_errs == 0 then - return true - else - return false, all_errs - end - end - else - local t1name = show_type(t1) - local t2name = show_type(t2) - if t1name == t2name then - local t1r = resolve_nominal(t1) - if t1r.filename then - t1name = t1name .. " (defined in " .. t1r.filename .. ":" .. t1r.y .. ")" - end - local t2r = resolve_nominal(t2) - if t2r.filename then - t2name = t2name .. " (defined in " .. t2r.filename .. ":" .. t2r.y .. ")" - end - end - return false, { Err(t1, t1name .. " is not a " .. t2name) } - end - end - - local is_lua_table_type: function(t: Type): boolean - local resolve_tuple_and_nominal: function(t: Type): Type = nil - - local function invariant_match_fields_to_record(t1: Type, t2: Type): boolean, {Error} - local ok, errs = match_fields_to_record(t1, t2, true) - if not ok then - return ok, errs - end - ok, errs = match_fields_to_record(t2, t1, true) - if not ok then - return ok, errs - end - return true - end - - -- invariant type comparison - same_type = function(t1: Type, t2: Type): boolean, {Error} - assert(type(t1) == "table") - assert(type(t2) == "table") - - if t1.typeid == t2.typeid then - if TL_DEBUG then - local st1, st2 = show_type_base(t1, false, {}), show_type_base(t2, false, {}) - assert(st1 == st2, st1 .. " ~= " .. st2) - end - return true - end - - if t1.typename == "typevar" or t2.typename == "typevar" then - return compare_and_infer_typevars(t1, t2, same_type) - end - - if t1.typename == "emptytable" and is_lua_table_type(resolve_tuple_and_nominal(t2)) then + if not same_names then + return fail_nominals(t1, t2) + elseif t1.typevals == nil and t2.typevals == nil then return true - end - - if t2.typename == "emptytable" and is_lua_table_type(resolve_tuple_and_nominal(t1)) then - return true - end - - if t1.typename ~= t2.typename then - return false, { Err(t1, "got %s, expected %s", t1, t2) } - end - if t1.typename == "array" then - return same_type(t1.elements, t2.elements) - elseif t1.typename == "tupletable" then - local all_errs = {} - for i = 1, math.min(#t1.types, #t2.types) do - local ok, err = same_type(t1.types[i], t2.types[i]) - if not ok then - add_errs_prefixing(t1, err, all_errs, "values") - end - end - return any_errors(all_errs) - elseif t1.typename == "map" then - local all_errs = {} - local k_ok, k_errs = same_type(t1.keys, t2.keys) - if not k_ok then - add_errs_prefixing(t1, k_errs, all_errs, "keys") - end - local v_ok, v_errs = same_type(t1.values, t2.values) - if not v_ok then - add_errs_prefixing(t1, v_errs, all_errs, "values") - end - return any_errors(all_errs) - elseif t1.typename == "union" then - if has_all_types_of(t1.types, t2.types, same_type) - and has_all_types_of(t2.types, t1.types, same_type) then - return true - else - return false, { Err(t1, "got %s, expected %s", t1, t2) } - end - elseif t1.typename == "nominal" then - return are_same_nominals(t1, t2) - elseif t1.typename == "record" then - -- checking array interface - if (t1.elements ~= nil) ~= (t2.elements ~= nil) then - return false, { Err(t1, "types do not have the same array interface") } - end - if t1.elements and t2.elements then - local ok, errs = same_type(t1.elements, t2.elements) - if not ok then - return ok, errs - end - end - - return invariant_match_fields_to_record(t1, t2) - elseif t1.typename == "function" then - local argdelta = t1.is_method and 1 or 0 - if #t1.args ~= #t2.args then - if t1.is_method ~= t2.is_method then - return false, { Err(t1, "different number of input arguments: method and non-method are not the same type") } - end - return false, { Err(t1, "different number of input arguments: got " .. #t1.args - argdelta .. ", expected " .. #t2.args - argdelta) } - end - if #t1.rets ~= #t2.rets then - return false, { Err(t1, "different number of return values: got " .. #t1.rets .. ", expected " .. #t2.rets) } - end - local all_errs = {} - for i = 1, #t1.args do - arg_check(t1 as Node, same_type, t1.args[i], t2.args[i], i - argdelta, all_errs, "argument") - end - for i = 1, #t1.rets do - local _, errs = same_type(t1.rets[i], t2.rets[i]) - add_errs_prefixing(t1, errs, all_errs, "return " .. i) + elseif t1.typevals and t2.typevals and #t1.typevals == #t2.typevals then + local errs = {} + for i = 1, #t1.typevals do + local _, typeval_errs = same_type(t1.typevals[i], t2.typevals[i]) + add_errs_prefixing(t1, typeval_errs, errs, "type parameter <" .. show_type(t2.typevals[i]) .. ">: ") end - return any_errors(all_errs) + return any_errors(errs) end return true end + local is_lua_table_type: function(t: Type): boolean + local resolve_tuple_and_nominal: function(t: Type): Type + local function unite(types: {Type}, flatten_constants?: boolean): Type if #types == 1 then return types[1] @@ -7518,25 +7330,6 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - local function add_map_errors(errs: {Error}, ctx: string, ctx_errs: {Error}) - if ctx_errs then - for _, err in ipairs(ctx_errs) do - err.msg = ctx .. err.msg - table.insert(errs, err) - end - end - end - - local function combine_map_errs(key_errs: {Error}, value_errs: {Error}): boolean, {Error} - if not key_errs and not value_errs then - return true - end - local errs: {Error} = {} - add_map_errors(errs, "in map key: ", key_errs) - add_map_errors(errs, "in map value: ", value_errs) - return false, errs - end - do local known_table_types: {TypeName:boolean} = { array = true, @@ -7571,7 +7364,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string } for i = 2, #tupletype.types do arr_type = expand_type(where, arr_type, a_type { elements = tupletype.types[i], typename = "array" }) - if not arr_type or not arr_type.elements then + if not arr_type.elements then return nil, { Err(tupletype, "unable to convert tuple %s to array", tupletype) } end end @@ -7582,349 +7375,630 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return t.typename == "nominal" and t.names[1] == "@self" end - -- subtyping comparison - is_a = function(t1: Type, t2: Type, for_equality: boolean): boolean, {Error} - assert(type(t1) == "table") - assert(type(t2) == "table") + local function compare_false(_: Type, _: Type): boolean, {Error} + return false + end - if lax and (is_unknown(t1) or is_unknown(t2)) then + local function compare_true(_: Type, _: Type): boolean, {Error} + return true + end + + local function subtype_nominal(a: Type, b: Type): boolean, {Error} + if is_self(a) and is_self(b) then return true end - if t1.typeid == t2.typeid then - if TL_DEBUG then - local st1, st2 = show_type_base(t1, false, {}), show_type_base(t2, false, {}) - assert(st1 == st2, st1 .. " ~= " .. st2) + local ra = a.typename == "nominal" and resolve_nominal(a) or a + local rb = b.typename == "nominal" and resolve_nominal(b) or b + local ok, errs = is_a(ra, rb) + if errs and #errs == 1 and errs[1].msg:match("^got ") then + return false -- translate to got-expected error with unresolved types + end + return ok, errs + end + + local function subtype_array(a: Type, b: Type): boolean, {Error} + -- assert(b.typename == "array") + if (not a.elements) or (not is_a(a.elements, b.elements)) then + return false + end + if a.types and #a.types > 1 then + -- constant array, check elements (useful for array of enums) + for i = 1, #a.types do + local e = a.types[i] + if not is_a(e, b.elements) then + return false, { Err(a, "%s is not a member of %s", e, b.elements) } + end end - return true end + return true + end - if t1.typename == "bad_nominal" or t2.typename == "bad_nominal" then - return false -- an error has been generated elsewhere + local function subtype_record(a: Type, b: Type): boolean, {Error} + -- assert(b.typename == "record") + if a.elements and b.elements then + if not is_a(a.elements, b.elements) then + return false, { Err(a, "array parts have incompatible element types") } + end end - if t2.typename ~= "tuple" then - t1 = resolve_tuple(t1) + if a.is_userdata ~= b.is_userdata then + return false, { Err(a, a.is_userdata and "userdata is not a record" + or "record is not a userdata") } end - if t2.typename == "tuple" and t1.typename ~= "tuple" then - t1 = a_type { - typename = "tuple", - [1] = t1, - } + local errs: {Error} = {} + for _, k in ipairs(a.field_order) do + local ak = a.fields[k] + local bk = b.fields[k] + if bk then + local ok, fielderrs = is_a(ak, bk) + if not ok then + add_errs_prefixing(nil, fielderrs, errs, "record field doesn't match: " .. k .. ": ") + end + end + end + if #errs > 0 then + for _, err in ipairs(errs) do + err.msg = show_type(a) .. " is not a " .. show_type(b) .. ": " .. err.msg + end + return false, errs + end + + return true + end + + local eqtype_record = function(a: Type, b: Type): boolean, {Error} + -- checking array interface + if (a.elements ~= nil) ~= (b.elements ~= nil) then + return false, { Err(a, "types do not have the same array interface") } + end + if a.elements then + local ok, errs = same_type(a.elements, b.elements) + if not ok then + return ok, errs + end end - if t1.typename == "typevar" or t2.typename == "typevar" then - return compare_and_infer_typevars(t1, t2, is_a) + local ok, errs = subtype_record(a, b) + if not ok then + return ok, errs + end + ok, errs = subtype_record(b, a) + if not ok then + return ok, errs end + return true + end - -- ∀ t, nil <: t - if t1.typename == "nil" then -- TODO nilable - return true + local function compare_map(ak: Type, bk: Type, av: Type, bv: Type, no_hack?: boolean): boolean, {Error} + local ok1, errs_k = same_type(ak, bk) + local ok2, errs_v = same_type(av, bv) + + -- FIXME hack for {any:any} + if bk.typename == "any" and not no_hack then + ok1, errs_k = true, nil + end + if bv.typename == "any" and not no_hack then + ok2, errs_v = true, nil end - -- ∀ t, t <: any - if t2.typename == "any" then + if ok1 and ok2 then return true + end - elseif is_self(t1) then - if is_self(t2) then + -- combine errs_k and errs_v, prefixing errors + for i = 1, errs_k and #errs_k or 0 do + errs_k[i].msg = "in map key: " .. errs_k[i].msg + end + for i = 1, errs_v and #errs_v or 0 do + errs_v[i].msg = "in map value: " .. errs_v[i].msg + end + if errs_k and errs_v then + for i = 1, #errs_v do + table.insert(errs_k, errs_v[i]) + end + return false, errs_k + end + return false, errs_k or errs_v + end + + local function compare_or_infer_typevar(typevar: string, a: Type, b: Type, cmp: CompareTypes): boolean, {Error} + -- assert((a == nil and b ~= nil) or (a ~= nil and b == nil)) + + -- does the typevar currently match to a type? + local vt = find_var_type(typevar) + if vt then + -- If so, compare it to the other type + return cmp(a or vt, b or vt) + else + -- otherwise, infer it to the other type + local ok, r, errs = resolve_typevars(a or b) + if not ok then + return false, errs + end + if r.typevar == typevar then return true end + add_var(nil, typevar, r) + return true + end + end - return is_a(resolve_tuple_and_nominal(t1), t2, for_equality) + -- ∃ x ∈ xs. t <: x + local function exists_supertype_in(t: Type, xs: Type): Type + for _, x in ipairs(xs.types) do + if is_a(t, x) then + return x + end + end + end - elseif is_self(t2) then - return is_a(t1, resolve_tuple_and_nominal(t2), for_equality) + -- emptytable rules are the same in eqtype_relations and subtype_relations + local emptytable_relations: {TypeName:CompareTypes} = { + ["array"] = compare_true, + ["map"] = compare_true, + ["tupletable"] = compare_true, + ["interface"] = function(_a: Type, b: Type): boolean, {Error} + return not b.is_userdata + end, + ["record"] = function(_a: Type, b: Type): boolean, {Error} + return not b.is_userdata + end, + } - elseif t1.typename == "union" then + local type TypeRelations = {TypeName:{TypeName:CompareTypes}} - -- ∀ t in t1, ∃ u in t2 t <: u - -- ─────────────────────────── -- a union t1 is a union t2 - -- t1 union <: t2 union -- if all of t1's types can be satisfied in t2 - if t2.typename == "union" then - local used = {} - for _, t in ipairs(t1.types) do - local ok = false - begin_scope() - for _, u in ipairs(t2.types) do - if not used[u] then - if is_a(t, u, for_equality) then - used[u] = t - ok = true - break - end - end - end - end_scope() -- don't preserve failed inferences - if not ok then - return false, { Err(t1, "got %s, expected %s", t1, t2) } + local eqtype_relations: TypeRelations + eqtype_relations = { + ["bad_nominal"] = { + ["*"] = compare_false, + }, + ["typevar"] = { + ["typevar"] = function(a: Type, b: Type): boolean, {Error} + if a.typevar == b.typevar then + return true + end + + return compare_or_infer_typevar(b.typevar, a, nil, same_type) + end, + ["*"] = function(a: Type, b: Type): boolean, {Error} + return compare_or_infer_typevar(a.typevar, nil, b, same_type) + end, + }, + ["emptytable"] = emptytable_relations, + ["tupletable"] = { + ["tupletable"] = function(a: Type, b: Type): boolean, {Error} + for i = 1, math.min(#a.types, #b.types) do + if not same_type(a.types[i], b.types[i]) then + return false, { Err(a, "in tuple entry " .. tostring(i) .. ": got %s, expected %s", a.types[i], b.types[i]) } end end - -- preserve all valid inferences - for u, t in pairs(used) do - is_a(t, u, for_equality) + if #a.types ~= #b.types then + return false, { Err(a, "tuples have different size", a, b) } end return true + end, + }, + ["array"] = { + ["array"] = function(a: Type, b: Type): boolean, {Error} + return same_type(a.elements, b.elements) + end, + }, + ["map"] = { + ["map"] = function(a: Type, b: Type): boolean, {Error} + return compare_map(a.keys, b.keys, a.values, b.values, true) + end, + }, + ["union"] = { + ["union"] = function(a: Type, b: Type): boolean, {Error} + return (has_all_types_of(a.types, b.types) + and has_all_types_of(b.types, a.types)) + end, + }, + ["nominal"] = { + ["nominal"] = are_same_nominals, + }, + ["record"] = { + ["record"] = eqtype_record, + }, + ["function"] = { + ["function"] = function(a: Type, b: Type): boolean, {Error} + local argdelta = a.is_method and 1 or 0 + if #a.args ~= #b.args then + if a.is_method ~= b.is_method then + return false, { Err(a, "different number of input arguments: method and non-method are not the same type") } + end + return false, { Err(a, "different number of input arguments: got " .. #a.args - argdelta .. ", expected " .. #b.args - argdelta) } + end + if #a.rets ~= #b.rets then + return false, { Err(a, "different number of return values: got " .. #a.rets .. ", expected " .. #b.rets) } + end + local errs = {} + for i = 1, #a.args do + arg_check(a, same_type, a.args[i], b.args[i], i - argdelta, errs, "argument") + end + for i = 1, #a.rets do + arg_check(a, same_type, a.rets[i], b.rets[i], i, errs, "return") + end + return any_errors(errs) + end, + }, + ["*"] = { + ["bad_nominal"] = compare_false, + ["typevar"] = function(a: Type, b: Type): boolean, {Error} + return compare_or_infer_typevar(b.typevar, a, nil, same_type) + end, + }, + } - -- ∀ t in t1, t <: t2 - -- ────────────────── -- a union type t1 is a t2 - -- t1 union <: t2 -- if all of t1's types satisfy t2 - else - for _, t in ipairs(t1.types) do - if not is_a(t, t2, for_equality) then - return false, { Err(t1, "got %s, expected %s", t1, t2) } + local subtype_relations: TypeRelations + subtype_relations = { + ["bad_nominal"] = { + ["*"] = compare_false, + }, + ["tuple"] = { + ["tuple"] = function(a: Type, b: Type): boolean, {Error} -- ∀ a[i] ∈ a, b[i] ∈ b. a[i] <: b[i] + if #a ~= #b then -- ────────────────────────────────── + return false -- a tuple <: b tuple + end + for i = 1, #a do + if not is_a(a[i], b[i]) then + return false end end return true - end - - -- ∃ t in t2, t1 <: t - -- ────────────────── -- a value of type t1 is a member of union type t2 - -- t1 <: t2 union -- if it is a member of some of t2's types - elseif t2.typename == "union" then - for _, t in ipairs(t2.types) do - if is_a(t1, t, for_equality) then + end, + ["*"] = function(a: Type, b: Type): boolean, {Error} + return is_a(resolve_tuple(a), b) + end, + }, + ["typevar"] = { + ["typevar"] = function(a: Type, b: Type): boolean, {Error} + if a.typevar == b.typevar then return true end - end - -- ∀ t in t2, t1 <: t - -- ────────────────── -- a type t1 is a poly type t2 - -- t1 <: t2 poly -- if all of t2's poly types are satisfied by t1 - elseif t2.typename == "poly" then - for _, t in ipairs(t2.types) do - if not is_a(t1, t, for_equality) then - return false, { Err(t1, "cannot match against all alternatives of the polymorphic type") } + return compare_or_infer_typevar(b.typevar, a, nil, is_a) + end, + ["*"] = function(a: Type, b: Type): boolean, {Error} + return compare_or_infer_typevar(a.typevar, nil, b, is_a) + end, + }, + ["nil"] = { + ["*"] = compare_true, + }, + ["union"] = { + ["union"] = function(a: Type, b: Type): boolean, {Error} -- ∀ t ∈ a. ∃ u ∈ b. t <: u + local used = {} -- ──────────────────────── + for _, t in ipairs(a.types) do -- a union <: b union + begin_scope() + local u = exists_supertype_in(t, b) + end_scope() -- don't preserve failed inferences + if not u then + return false + end + if not used[u] then -- FIXME the order of declared union items affects inference behavior + used[u] = t + end end - end - return true - - -- ∃ t in t1, t <: t2 - -- ────────────────── -- a poly type t1 is a t2 if - -- t1 poly <: t2 -- t2 is some of of the poly's types - elseif t1.typename == "poly" then - for _, t in ipairs(t1.types) do - if is_a(t, t2, for_equality) then - return true + for u, t in pairs(used) do + is_a(t, u) -- preserve valid inferences end - end - return false, { Err(t1, "cannot match against any alternatives of the polymorphic type") } - elseif t1.typename == "nominal" and t2.typename == "nominal" then - local t1r = resolve_tuple_and_nominal(t1) - local t2r = resolve_tuple_and_nominal(t2) - if t1r.typename == "union" or t2r.typename == "union" then - return is_a(t1r, t2r, for_equality) - end - - return are_same_nominals(t1, t2) - elseif t1.typename == "enum" and t2.typename == "string" then - local ok: boolean - if for_equality then - ok = t2.tk and t1.enumset[unquote(t2.tk)] - else - ok = true - end - if ok then return true - else - return false, { Err(t1, "enum is incompatible with %s", t2) } - end - elseif t1.typename == "integer" and t2.typename == "number" then - return true - elseif t1.typename == "string" and t2.typename == "enum" then - local ok = t1.tk and t2.enumset[unquote(t1.tk)] - if ok then + end, + ["*"] = function(a: Type, b: Type): boolean, {Error} -- ∀ t ∈ a, t <: b + for _, t in ipairs(a.types) do -- ──────────────── + if not is_a(t, b) then -- a union <: b + return false + end + end return true - else - if t1.tk then - return false, { Err(t1, "%s is not a member of %s", t1, t2) } - else - return false, { Err(t1, "string is not a %s", t2) } + end, + }, + ["poly"] = { + ["*"] = function(a: Type, b: Type): boolean, {Error} -- ∃ t ∈ a, t <: b + if exists_supertype_in(b, a) then -- ─────────────── + return true -- a poly <: b end - end - elseif t1.typename == "nominal" or t2.typename == "nominal" then - local t1r = resolve_tuple_and_nominal(t1) - local t2r = resolve_tuple_and_nominal(t2) - local ok, errs = is_a(t1r, t2r, for_equality) - if errs and #errs == 1 then - if errs[1].msg:match("^got ") then - --local got = t1.typename == "nominal" and t1.name or show_type(t1) - --local expected = t2.typename == "nominal" and t2.name or show_type(t2) - errs = { Err(t1, "got %s, expected %s", t1, t2) } + return false, { Err(a, "cannot match against any alternatives of the polymorphic type") } + end, + }, + ["nominal"] = { + ["nominal"] = function(a: Type, b: Type): boolean, {Error} + local ra = resolve_nominal(a) + local rb = resolve_nominal(b) + -- match unions structurally + if ra.typename == "union" or rb.typename == "union" then + return is_a(ra, rb) + end + -- all other types nominally + return are_same_nominals(a, b) + end, + ["*"] = subtype_nominal, + }, + ["enum"] = { + ["string"] = compare_true, + }, + ["string"] = { + ["enum"] = function(a: Type, b: Type): boolean, {Error} + if not a.tk then + return false, { Err(a, "string is not a %s", b) } end - end - return ok, errs - elseif t1.typename == "emptytable" and is_lua_table_type(t2) then - return true - elseif t2.typename == "array" then - if is_array_type(t1) then - if is_a(t1.elements, t2.elements) then - local t1e = resolve_tuple_and_nominal(t1.elements) - local t2e = resolve_tuple_and_nominal(t2.elements) - if t2e.typename == "enum" and t1e.typename == "string" and #t1.types > 1 then - for i = 2, #t1.types do - local t = t1.types[i] - if not is_a(t, t2e) then - return false, { Err(t, "%s is not a member of %s", t, t2e) } - end - end - end + + if b.enumset[unquote(a.tk)] then return true end - elseif t1.typename == "tupletable" then - if t2.inferred_len and t2.inferred_len > #t1.types then - return false, { Err(t1, "incompatible length, expected maximum length of " .. tostring(#t1.types) .. ", got " .. tostring(t2.inferred_len)) } - end - local t1a, err = arraytype_from_tuple(t1.inferred_at, t1) - if not t1a then - return false, err + + return false, { Err(a, "%s is not a member of %s", a, b) } + end, + }, + ["integer"] = { + ["number"] = compare_true, + }, + ["interface"] = { + ["array"] = subtype_array, + ["record"] = subtype_record, + ["tupletable"] = function(a: Type, b: Type): boolean, {Error} + return subtype_relations["record"]["tupletable"](a, b) + end, + }, + ["emptytable"] = emptytable_relations, + ["tupletable"] = { + ["tupletable"] = function(a: Type, b: Type): boolean, {Error} + for i = 1, math.min(#a.types, #b.types) do + if not is_a(a.types[i], b.types[i]) then + return false, { Err(a, "in tuple entry " + .. tostring(i) .. ": got %s, expected %s", + a.types[i], b.types[i]) } + end end - if not is_a(t1a, t2) then - return false, { Err(t2, "got %s (from %s), expected %s", t1a, t1, t2) } + if #a.types > #b.types then + return false, { Err(a, "tuple %s is too big for tuple %s", a, b) } end return true - elseif t1.typename == "map" then - local _, errs_keys, errs_values: any, {Error}, {Error} - _, errs_keys = is_a(t1.keys, INTEGER) - _, errs_values = is_a(t1.values, t2.elements) - return combine_map_errs(errs_keys, errs_values) - end - elseif t2.typename == "record" then - - -- checking array interface - if t1.typename == "tupletable" and t2.elements then - if t2.inferred_len and t2.inferred_len > #t1.types then - return false, { Err(t1, "incompatible length, expected maximum length of " .. tostring(#t1.types) .. ", got " .. tostring(t2.inferred_len)) } + end, + ["record"] = function(a: Type, b: Type): boolean, {Error} + if b.elements then + return subtype_relations["tupletable"]["array"](a, b) + end + end, + ["array"] = function(a: Type, b: Type): boolean, {Error} + if b.inferred_len and b.inferred_len > #a.types then + return false, { Err(a, "incompatible length, expected maximum length of " .. tostring(#a.types) .. ", got " .. tostring(b.inferred_len)) } end - local t1a, err = arraytype_from_tuple(t1.inferred_at, t1) - if not t1a then + local aa, err = arraytype_from_tuple(a.inferred_at, a) + if not aa then return false, err end - if not is_a(t1a, t2) then - return false, { Err(t2, "got %s (from %s), expected %s", t1a, t1, t2) } + if not is_a(aa, b) then + return false, { Err(a, "got %s (from %s), expected %s", aa, a, b) } end return true - end - if t1.elements and t2.elements then - if not is_a(t1.elements, t2.elements) then - return false, { Err(t1, "array parts have incompatible element types") } - end - if t1.typename == "array" then - return true + end, + ["map"] = function(a: Type, b: Type): boolean, {Error} + local aa = arraytype_from_tuple(a.inferred_at, a) + if not aa then + return false, { Err(a, "Unable to convert tuple %s to map", a) } end - end - if is_record_type(t1) then - return match_fields_to_record(t1, t2, false) - elseif is_typetype(t1) and is_record_type(t1.def) then -- record as prototype - return is_a(t1.def, t2, for_equality) - end - elseif t2.typename == "map" then - if t1.typename == "map" then - local _, errs_keys, errs_values: any, {Error}, {Error} - if t2.keys.typename ~= "any" then -- FIXME hack for {any:any} - _, errs_keys = same_type(t2.keys, t1.keys) - end - if t2.values.typename ~= "any" then -- FIXME hack for {any:any} - _, errs_values = same_type(t1.values, t2.values) + return compare_map(INTEGER, b.keys, aa.elements, b.values) + end, + }, + ["record"] = { + ["record"] = subtype_record, + ["array"] = subtype_array, + ["map"] = function(a: Type, b: Type): boolean, {Error} + if not is_a(b.keys, STRING) then + return false, { Err(a, "can't match a record to a map with non-string keys") } end - return combine_map_errs(errs_keys, errs_values) - elseif t1.typename == "array" or t1.typename == "tupletable" then - local elements: Type - if t1.typename == "tupletable" then - local arr_type = arraytype_from_tuple(t1.inferred_at, t1) - if not arr_type then - return false, { Err(t1, "Unable to convert tuple %s to map", t1) } - end - elements = arr_type.elements - else - elements = t1.elements - end - local _, errs_keys, errs_values: any, {Error}, {Error} - _, errs_keys = is_a(INTEGER, t2.keys) - _, errs_values = is_a(elements, t2.values) - return combine_map_errs(errs_keys, errs_values) - elseif is_record_type(t1) then -- FIXME - if not is_a(t2.keys, STRING) then - return false, { Err(t1, "can't match a record to a map with non-string keys") } - end - if t2.keys.typename == "enum" then - for _, k in ipairs(t1.field_order) do - if not t2.keys.enumset[k] then - return false, { Err(t1, "key is not an enum value: " .. k) } - end + + for _, k in ipairs(a.field_order) do + if b.keys.typename == "enum" and not b.keys.enumset[k] then + return false, { Err(a, "key is not an enum value: " .. k) } end - end - return match_fields_to_map(t1, t2) - end - elseif t2.typename == "tupletable" then - if t1.typename == "tupletable" then - for i = 1, math.min(#t1.types, #t2.types) do - if not is_a(t1.types[i], t2.types[i], for_equality) then - return false, { Err(t1, "in tuple entry " .. tostring(i) .. ": got %s, expected %s", t1.types[i], t2.types[i]) } + if not is_a(a.fields[k], b.values) then + return false, { Err(a, "record is not a valid map; not all fields have the same type") } end end - if for_equality and #t1.types ~= #t2.types then - return false, { Err(t1, "tuples are not the same size") } + + return true + end, + ["tupletable"] = function(a: Type, b: Type): boolean, {Error} + if a.elements then + return subtype_relations["array"]["tupletable"](a, b) end - if #t1.types > #t2.types then - return false, { Err(t1, "tuple %s is too big for tuple %s", t1, t2) } + end, + }, + ["array"] = { + ["array"] = subtype_array, + ["record"] = function(a: Type, b: Type): boolean, {Error} + if b.elements then + return subtype_array(a, b) end - return true - elseif is_array_type(t1) then - if t1.inferred_len and t1.inferred_len > #t2.types then - return false, { Err(t1, "incompatible length, expected maximum length of " .. tostring(#t2.types) .. ", got " .. tostring(t1.inferred_len)) } + end, + ["map"] = function(a: Type, b: Type): boolean, {Error} + return compare_map(INTEGER, b.keys, a.elements, b.values) + end, + ["tupletable"] = function(a: Type, b: Type): boolean, {Error} + local alen = a.inferred_len or 0 + if alen > #b.types then + return false, { Err(a, "incompatible length, expected maximum length of " .. tostring(#b.types) .. ", got " .. tostring(alen)) } end -- for array literals (which is the only case where inferred_len is defined), - -- only check the entries present - local len = (t1.inferred_len and t1.inferred_len > 0) - and t1.inferred_len - or #t2.types - - for i = 1, len do - if not is_a(t1.elements, t2.types[i], for_equality) then - return false, { Err(t1, "tuple entry " .. tostring(i) .. " of type %s does not match type of array elements, which is %s", t2.types[i], t1.elements) } + -- only check the entries that are present + for i = 1, (alen > 0) and alen or #b.types do + if not is_a(a.elements, b.types[i]) then + return false, { Err(a, "tuple entry " .. i .. " of type %s does not match type of array elements, which is %s", b.types[i], a.elements) } end end return true - end - elseif t1.typename == "function" and t2.typename == "function" then - local all_errs = {} - if (not t2.args.is_va) and #t1.args > #t2.args then - table.insert(all_errs, Err(t1, "incompatible number of arguments: got " .. #t1.args .. " %s, expected " .. #t2.args .. " %s", t1.args, t2.args)) - else - for i = ((t1.is_method or t2.is_method) and 2 or 1), #t1.args do - arg_check(nil, is_a, t1.args[i], t2.args[i] or ANY, i, all_errs, "argument") + end, + }, + ["map"] = { + ["map"] = function(a: Type, b: Type): boolean, {Error} + return compare_map(a.keys, b.keys, a.values, b.values) + end, + ["array"] = function(a: Type, b: Type): boolean, {Error} + return compare_map(a.keys, INTEGER, a.values, b.elements) + end, + }, + ["typetype"] = { + ["record"] = function(a: Type, b: Type): boolean, {Error} + return subtype_record(a.def, b) -- record as prototype + end, + }, + ["function"] = { + ["function"] = function(a: Type, b: Type): boolean, {Error} + local errs = {} + + local aa, ba = a.args, b.args + set_min_arity(a) + set_min_arity(b) + if (not ba.is_va) and a.min_arity > b.min_arity then + table.insert(errs, Err(a, "incompatible number of arguments: got " .. show_arity(a) .. " %s, expected " .. show_arity(b) .. " %s", aa, ba)) + else + for i = ((a.is_method or b.is_method) and 2 or 1), #aa do + arg_check(nil, is_a, aa[i], ba[i] or ANY, i, errs, "argument") + end end - end - local diff_by_va = #t2.rets - #t1.rets == 1 and t2.rets.is_va - if #t1.rets < #t2.rets and not diff_by_va then - table.insert(all_errs, Err(t1, "incompatible number of returns: got " .. #t1.rets .. " %s, expected " .. #t2.rets .. " %s", t1.rets, t2.rets)) - else - local nrets = #t2.rets - if diff_by_va then - nrets = nrets - 1 + + local ar, br = a.rets, b.rets + local diff_by_va = #br - #ar == 1 and br.is_va + if #ar < #br and not diff_by_va then + table.insert(errs, Err(a, "incompatible number of returns: got " .. #ar .. " %s, expected " .. #br .. " %s", ar, br)) + else + local nrets = #br + if diff_by_va then + nrets = nrets - 1 + end + for i = 1, nrets do + arg_check(nil, is_a, ar[i], br[i], i, errs, "return") + end end - for i = 1, nrets do - local _, errs = is_a(t1.rets[i], t2.rets[i]) - add_errs_prefixing(nil, errs, all_errs, "return " .. i .. ": ") + + return any_errors(errs) + end, + }, + ["*"] = { + ["bad_nominal"] = compare_false, + ["any"] = compare_true, + ["tuple"] = function(a: Type, b: Type): boolean, {Error} + return is_a(TUPLE({a}), b) + end, + ["typevar"] = function(a: Type, b: Type): boolean, {Error} + return compare_or_infer_typevar(b.typevar, a, nil, is_a) + end, + ["union"] = exists_supertype_in as CompareTypes, -- ∃ t ∈ b, a <: t + -- ─────────────── + -- a <: b union + ["nominal"] = subtype_nominal, + ["poly"] = function(a: Type, b: Type): boolean, {Error} -- ∀ t ∈ b, a <: t + for _, t in ipairs(b.types) do -- ─────────────── + if not is_a(a, t) then -- a <: b poly + return false, { Err(a, "cannot match against all alternatives of the polymorphic type") } + end end - end - if #all_errs == 0 then return true - else - return false, all_errs - end - elseif lax and ((not for_equality) and t2.typename == "boolean") then - -- in .lua files, all values can be used in a boolean context (but not in == or ~=) - return true - elseif t1.typename == t2.typename then + end, + }, + } + + -- evaluation strategy + local type_priorities: {TypeName:integer} = { + -- types that have catch-all rules evaluate first + ["bad_nominal"] = 1, + ["tuple"] = 2, + ["typevar"] = 3, + ["nil"] = 4, + ["any"] = 5, + ["union"] = 6, + ["poly"] = 7, + ["nominal"] = 8, + -- then base types + ["enum"] = 9, + ["string"] = 9, + ["integer"] = 9, + ["boolean"] = 9, + -- then interfaces + ["interface"] = 10, + -- then special cases of tables + ["emptytable"] = 11, + ["tupletable"] = 12, + -- then other recursive types + ["record"] = 13, + ["array"] = 13, + ["map"] = 13, + ["function"] = 13, + } + + if lax then + type_priorities["unknown"] = 0 + + subtype_relations["unknown"] = {} + subtype_relations["unknown"]["*"] = compare_true + subtype_relations["*"]["unknown"] = compare_true + -- in .lua files, all values can be used in a boolean context + subtype_relations["boolean"] = {} + subtype_relations["boolean"]["boolean"] = compare_true + subtype_relations["*"]["boolean"] = compare_true + end + + local function compare_types(relations: TypeRelations, t1: Type, t2: Type): boolean, {Error} + if t1.typeid == t2.typeid then return true end - return false, { Err(t1, "got %s, expected %s", t1, t2) } + local s1 = relations[t1.typename] + local fn = s1 and s1[t2.typename] + if not fn then + local p1 = type_priorities[t1.typename] or 999 + local p2 = type_priorities[t2.typename] or 999 + fn = (p1 < p2 and (s1 and s1["*"]) or (relations["*"][t2.typename])) + end + + local ok, err: boolean, {Error} + if fn then + if fn == compare_true then + return true + end + ok, err = fn(t1, t2) + else + ok = t1.typename == t2.typename + end + if (not ok) and not err then + return false, { Err(t1, "got %s, expected %s", t1, t2) } + end + return ok, err + end + + -- subtyping comparison + is_a = function(t1: Type, t2: Type): boolean, {Error} + return compare_types(subtype_relations, t1, t2) + end + + -- invariant type comparison + same_type = function(t1: Type, t2: Type): boolean, {Error} + -- except for error messages, behavior is the same as + -- `return (is_a(t1, t2) and is_a(t2, t1))` + return compare_types(eqtype_relations, t1, t2) + end + + if TL_DEBUG then + local orig_is_a = is_a + is_a = function(t1: Type, t2: Type): boolean, {Error} + assert(type(t1) == "table") + assert(type(t2) == "table") + + if t1.typeid == t2.typeid then + local st1, st2 = show_type_base(t1, false, {}), show_type_base(t2, false, {}) + assert(st1 == st2, st1 .. " ~= " .. st2) + return true + end + + return orig_is_a(t1, t2) + end end local function assert_is_a(where: Where, t1: Type, t2: Type, context: string, name?: string): boolean @@ -8916,6 +8990,27 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end + local function typetype_to_nominal(where: Where, name: string, t: Type, resolved?: Type): Type + assert(t.typename == "typetype") + + local typevals: Type + if t.def.typeargs then + typevals = {} + for _, a in ipairs(t.def.typeargs) do + table.insert(typevals, a_type { typename = "typevar", typevar = a.typearg }) + end + end + return a_type { + y = where.y, + x = where.x, + typename = "nominal", + typevals = typevals, + names = { name }, + found = t, + resolved = resolved, + } + end + local function get_self_type(exp: Node): Type -- base if exp.kind == "type_identifier" then @@ -8925,21 +9020,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end if t.typename == "typetype" then - local typevals: Type - if t.def.typeargs then - typevals = {} - for _, a in ipairs(t.def.typeargs) do - table.insert(typevals, a_type { typename = "typevar", typevar = a.typearg }) - end - end - return a_type { - y = exp.y, - x = exp.x, - typename = "nominal", - typevals = typevals, - names = { exp.tk }, - found = t, - } + return typetype_to_nominal(exp, exp.tk, t) else return t end @@ -10318,6 +10399,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if final_tuple.is_va then tuple.is_va = true end + tuple[n] = nil for i, c in ipairs(final_tuple) do tuple[n + i - 1] = c end @@ -10981,11 +11063,17 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string -- check_metamethod(node, binop_to_metamethod[node.op.op], ra, rb) -- end - if is_a(b, a, true) or a.typename == "typevar" then + if ra.typename == "enum" and rb.typename == "string" then + if not (rb.tk and ra.enumset[unquote(rb.tk)]) then + return invalid_at(node, "%s is not a member of %s", b, a) + end + elseif ra.typename == "tupletable" and rb.typename == "tupletable" and #ra.types ~= #rb.types then + return invalid_at(node, "tuples are not the same size") + elseif is_a(b, a) or a.typename == "typevar" then if node.op.op == "==" and node.e1.kind == "variable" then node.known = EqFact { var = node.e1.tk, typ = b, where = node } end - elseif is_a(a, b, true) or b.typename == "typevar" then + elseif is_a(a, b) or b.typename == "typevar" then if node.op.op == "==" and node.e2.kind == "variable" then node.known = EqFact { var = node.e2.tk, typ = a, where = node } end @@ -11147,14 +11235,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end if is_typetype(t) then - t = a_type { - y = node.y, - x = node.x, - typename = "nominal", - names = { node.tk }, - found = t, - resolved = t, - } + t = typetype_to_nominal(node, node.tk, t, t) end return t From 8a49055f6afe80520266103b3aeb64b2608b7575 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Sat, 9 Dec 2023 18:01:17 -0300 Subject: [PATCH 035/224] remove "nestedtype" Apparently we don't need it anymore! --- tl.lua | 14 +------------- tl.tl | 14 +------------- 2 files changed, 2 insertions(+), 26 deletions(-) diff --git a/tl.lua b/tl.lua index 492412f75..f49ca92e2 100644 --- a/tl.lua +++ b/tl.lua @@ -1028,7 +1028,6 @@ end - local table_types = { @@ -1040,7 +1039,6 @@ local table_types = { ["tupletable"] = true, ["typetype"] = false, - ["nestedtype"] = false, ["typevar"] = false, ["typearg"] = false, ["function"] = false, @@ -1462,7 +1460,7 @@ local function is_number_type(t) end local function is_typetype(t) - return t.typename == "typetype" or t.typename == "nestedtype" + return t.typename == "typetype" end @@ -4742,7 +4740,6 @@ local typename_to_typecode = { ["table_item"] = tl.typecodes.UNKNOWN, ["unresolved"] = tl.typecodes.UNKNOWN, ["typetype"] = tl.typecodes.UNKNOWN, - ["nestedtype"] = tl.typecodes.UNKNOWN, ["*"] = tl.typecodes.UNKNOWN, } @@ -4752,7 +4749,6 @@ local skip_types = { ["table_item"] = true, ["unresolved"] = true, ["typetype"] = true, - ["nestedtype"] = true, } local get_typenum @@ -11378,7 +11374,6 @@ a.types[i], b.types[i]), } for name, typ2 in fields_of(typ) do if typ2.typename == "typetype" then - typ2.typename = "nestedtype" local resolved, is_alias = resolve_nominal_typetype(typ2) if is_alias then typ2.is_alias = true @@ -11403,9 +11398,6 @@ a.types[i], b.types[i]), } end for name, _ in fields_of(typ) do local ftype = children[i] - if ftype.typename == "nestedtype" then - ftype.typename = "typetype" - end if ftype.is_method and ftype.args and ftype.args[1] and ftype.args[1].is_self then local record_name = typ.names and typ.names[1] @@ -11431,9 +11423,6 @@ a.types[i], b.types[i]), } end for name, _ in fields_of(typ, "meta") do local ftype = children[i] - if ftype.typename == "nestedtype" then - ftype.typename = "typetype" - end typ.meta_fields[name] = ftype i = i + 1 end @@ -11550,7 +11539,6 @@ a.types[i], b.types[i]), } visit_type.cbs["tupletable"] = visit_type.cbs["string"] visit_type.cbs["typetype"] = visit_type.cbs["string"] - visit_type.cbs["nestedtype"] = visit_type.cbs["string"] visit_type.cbs["array"] = visit_type.cbs["string"] visit_type.cbs["map"] = visit_type.cbs["string"] visit_type.cbs["interface"] = visit_type.cbs["record"] diff --git a/tl.tl b/tl.tl index 18825c20c..cd2e0bf18 100644 --- a/tl.tl +++ b/tl.tl @@ -996,7 +996,6 @@ end local enum TypeName "typetype" - "nestedtype" "typevar" "typearg" "function" @@ -1040,7 +1039,6 @@ local table_types : {TypeName:boolean} = { ["tupletable"] = true, ["typetype"] = false, - ["nestedtype"] = false, ["typevar"] = false, ["typearg"] = false, ["function"] = false, @@ -1462,7 +1460,7 @@ local function is_number_type(t:Type): boolean end local function is_typetype(t:Type): boolean - return t.typename == "typetype" or t.typename == "nestedtype" + return t.typename == "typetype" end local record ParseState @@ -4742,7 +4740,6 @@ local typename_to_typecode : {TypeName:integer} = { ["table_item"] = tl.typecodes.UNKNOWN, ["unresolved"] = tl.typecodes.UNKNOWN, ["typetype"] = tl.typecodes.UNKNOWN, - ["nestedtype"] = tl.typecodes.UNKNOWN, ["*"] = tl.typecodes.UNKNOWN, } @@ -4752,7 +4749,6 @@ local skip_types: {TypeName: boolean} = { ["table_item"] = true, ["unresolved"] = true, ["typetype"] = true, - ["nestedtype"] = true, } local get_typenum: function(trenv: TypeReportEnv, t: Type): integer @@ -11378,7 +11374,6 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string for name, typ2 in fields_of(typ) do if typ2.typename == "typetype" then - typ2.typename = "nestedtype" local resolved, is_alias = resolve_nominal_typetype(typ2) if is_alias then typ2.is_alias = true @@ -11403,9 +11398,6 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end for name, _ in fields_of(typ) do local ftype = children[i] - if ftype.typename == "nestedtype" then - ftype.typename = "typetype" - end if ftype.is_method and ftype.args and ftype.args[1] and ftype.args[1].is_self then local record_name = typ.names and typ.names[1] @@ -11431,9 +11423,6 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end for name, _ in fields_of(typ, "meta") do local ftype = children[i] - if ftype.typename == "nestedtype" then - ftype.typename = "typetype" - end typ.meta_fields[name] = ftype i = i + 1 end @@ -11550,7 +11539,6 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string visit_type.cbs["tupletable"] = visit_type.cbs["string"] visit_type.cbs["typetype"] = visit_type.cbs["string"] - visit_type.cbs["nestedtype"] = visit_type.cbs["string"] visit_type.cbs["array"] = visit_type.cbs["string"] visit_type.cbs["map"] = visit_type.cbs["string"] visit_type.cbs["interface"] = visit_type.cbs["record"] From 988e45c1bfe3a686ac392070e358dd4c5995ebca Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Sun, 10 Dec 2023 02:53:15 -0300 Subject: [PATCH 036/224] no need to clear tk when uniting flattening constants --- tl.lua | 4 +--- tl.tl | 4 +--- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/tl.lua b/tl.lua index f49ca92e2..46cc40c87 100644 --- a/tl.lua +++ b/tl.lua @@ -8949,9 +8949,7 @@ a.types[i], b.types[i]), } new.tk = nil table.insert(old.types, new) else - old.tk = nil - new.tk = nil - return unite({ old, new }) + return unite({ old, new }, true) end end end diff --git a/tl.tl b/tl.tl index cd2e0bf18..700241d5d 100644 --- a/tl.tl +++ b/tl.tl @@ -8949,9 +8949,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string new.tk = nil table.insert(old.types, new) else - old.tk = nil - new.tk = nil - return unite({ old, new }) + return unite({ old, new }, true) end end end From fe33c75ae66f49f0fd0010497d7403897f63cc6b Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Mon, 11 Dec 2023 01:40:09 -0300 Subject: [PATCH 037/224] avoid setting (or updating!) typename directly This is to ensure that there is always a consistent correspondence between a typeid and its typename (that is, if the typename changes, we update the typeid as well). With consistency between typeids and typenames, we would _almost_ be able to cache type comparison judgements based on typeids alone. We're not able to do it just because type comparisons such as is_a and same_type have one possible side-effect: inferring typevars. From my experiments, in the current state of the code there's no real performance gain by caching type comparisons, but it would be nice to be able to do it for better understandability -- intuitively, type comparisons should be a pure operation. In any case, the changes in this commit should make it easier to reason about and debug the code in the long term. --- tl.lua | 910 +++++++++++++++++++++++++------------------------------ tl.tl | 932 ++++++++++++++++++++++++++------------------------------- 2 files changed, 823 insertions(+), 1019 deletions(-) diff --git a/tl.lua b/tl.lua index 46cc40c87..7a06b85b2 100644 --- a/tl.lua +++ b/tl.lua @@ -1539,15 +1539,21 @@ local function new_node(tokens, i, kind) return { y = t.y, x = t.x, tk = t.tk, kind = kind or (t.kind) } end -local function a_type(t) +local function a_type(typename, t) t.typeid = new_typeid() + t.typename = typename + return t +end + +local function edit_type(t, typename) + t.typeid = new_typeid() + t.typename = typename return t end local function new_type(ps, i, typename) local token = ps.tokens[i] - return a_type({ - typename = assert(typename), + return a_type(typename, { filename = ps.filename, y = token.y, x = token.x, @@ -1555,6 +1561,54 @@ local function new_type(ps, i, typename) }) end +local function a_tuple(t) + return a_type("tuple", t) +end + +local function a_union(t) + return a_type("union", { types = t }) +end + +local function a_poly(t) + return a_type("poly", { types = t }) +end + +local function a_function(t) + return a_type("function", t) +end + +local function a_typetype(t) + return a_type("typetype", t) +end + +local function a_vararg(t) + local tuple = t + tuple.is_va = true + return a_tuple(t) +end + +local function an_array(t) + return a_type("array", { + elements = t, + }) +end + +local function a_map(k, v) + return a_type("map", { + keys = k, + values = v, + }) +end + +local NIL = a_type("nil", {}) +local ANY = a_type("any", {}) +local TABLE = a_type("map", { keys = ANY, values = ANY }) +local NUMBER = a_type("number", {}) +local STRING = a_type("string", {}) +local THREAD = a_type("thread", {}) +local BOOLEAN = a_type("boolean", {}) +local INTEGER = a_type("integer", {}) + local function shallow_copy_type(t) local copy = {} @@ -1808,10 +1862,9 @@ end local function parse_typearg(ps, i) i = verify_kind(ps, i, "identifier") - return i, a_type({ + return i, a_type("typearg", { y = ps.tokens[i - 2].y, x = ps.tokens[i - 2].x, - typename = "typearg", typearg = ps.tokens[i - 1].tk, }) end @@ -1830,8 +1883,8 @@ local function parse_function_type(ps, i) i, typ.args = parse_argument_type_list(ps, i) i, typ.rets = parse_return_types(ps, i) else - typ.args = a_type({ typename = "tuple", is_va = true, a_type({ typename = "any" }) }) - typ.rets = a_type({ typename = "tuple", is_va = true, a_type({ typename = "any" }) }) + typ.args = a_vararg({ ANY }) + typ.rets = a_vararg({ ANY }) end if typ.args[1] and typ.args[1].is_self then typ.is_method = true @@ -1839,15 +1892,6 @@ local function parse_function_type(ps, i) return i, typ end -local NIL = a_type({ typename = "nil" }) -local ANY = a_type({ typename = "any" }) -local TABLE = a_type({ typename = "map", keys = ANY, values = ANY }) -local NUMBER = a_type({ typename = "number" }) -local STRING = a_type({ typename = "string" }) -local THREAD = a_type({ typename = "thread" }) -local BOOLEAN = a_type({ typename = "boolean" }) -local INTEGER = a_type({ typename = "integer" }) - local simple_types = { ["nil"] = NIL, ["any"] = ANY, @@ -1902,18 +1946,21 @@ local function parse_base_type(ps, i) if ps.tokens[i].kind == "identifier" then return parse_simple_type_or_nominal(ps, i) elseif tk == "{" then + local istart = i i = i + 1 - local decl = new_type(ps, i, "array") local t i, t = parse_type(ps, i) if not t then return i end if ps.tokens[i].tk == "}" then + local decl = new_type(ps, istart, "array") decl.elements = t end_at(decl, ps.tokens[i]) i = verify_tk(ps, i, "}") + return i, decl elseif ps.tokens[i].tk == "," then + local decl = new_type(ps, istart, "tupletable") decl.typename = "tupletable" decl.types = { t } local n = 2 @@ -1927,8 +1974,9 @@ local function parse_base_type(ps, i) until ps.tokens[i].tk ~= "," end_at(decl, ps.tokens[i]) i = verify_tk(ps, i, "}") + return i, decl elseif ps.tokens[i].tk == ":" then - decl.typename = "map" + local decl = new_type(ps, istart, "map") i = i + 1 decl.keys = t i, decl.values = parse_type(ps, i) @@ -1937,18 +1985,17 @@ local function parse_base_type(ps, i) end end_at(decl, ps.tokens[i]) i = verify_tk(ps, i, "}") - else - return fail(ps, i, "syntax error; did you forget a '}'?") + return i, decl end - return i, decl + return fail(ps, i, "syntax error; did you forget a '}'?") elseif tk == "function" then return parse_function_type(ps, i) elseif tk == "nil" then return i + 1, simple_types["nil"] elseif tk == "table" then local typ = new_type(ps, i, "map") - typ.keys = a_type({ typename = "any" }) - typ.values = a_type({ typename = "any" }) + typ.keys = ANY + typ.values = ANY return i + 1, typ end return fail(ps, i, "expected a type") @@ -3002,8 +3049,8 @@ parse_record_body = function(ps, i, def, node) local typ = new_type(ps, wstart, "function") typ.is_method = true - typ.args = a_type({ typename = "tuple", a_type({ typename = "nominal", y = typ.y, x = typ.x, names = { "@self" } }) }) - typ.rets = a_type({ typename = "tuple", a_type({ typename = "boolean" }) }) + typ.args = a_tuple({ a_type("nominal", { y = typ.y, x = typ.x, names = { "@self" } }) }) + typ.rets = a_tuple({ BOOLEAN }) typ.macroexp = where_macroexp store_field_in_record(ps, i, "__is", typ, def.meta_fields, def.meta_field_order) @@ -4897,32 +4944,15 @@ end -local function VARARG(t) - local tuple = t - tuple.typename = "tuple" - tuple.is_va = true - return a_type(t) -end - -local function TUPLE(t) - local tuple = t - tuple.typename = "tuple" - return a_type(t) -end - -local function UNION(t) - return a_type({ typename = "union", types = t }) -end - -local NONE = a_type({ typename = "none" }) -local INVALID = a_type({ typename = "invalid" }) -local UNKNOWN = a_type({ typename = "unknown" }) -local CIRCULAR_REQUIRE = a_type({ typename = "circular_require" }) +local NONE = a_type("none", {}) +local INVALID = a_type("invalid", {}) +local UNKNOWN = a_type("unknown", {}) +local CIRCULAR_REQUIRE = a_type("circular_require", {}) -local FUNCTION = a_type({ typename = "function", args = VARARG({ ANY }), rets = VARARG({ ANY }) }) +local FUNCTION = a_function({ args = a_vararg({ ANY }), rets = a_vararg({ ANY }) }) -local NOMINAL_FILE = a_type({ typename = "nominal", names = { "FILE" } }) -local XPCALL_MSGH_FUNCTION = a_type({ typename = "function", args = TUPLE({ ANY }), rets = TUPLE({}) }) +local NOMINAL_FILE = a_type("nominal", { names = { "FILE" } }) +local XPCALL_MSGH_FUNCTION = a_function({ args = a_tuple({ ANY }), rets = a_tuple({}) }) local USERDATA = ANY @@ -5579,40 +5609,34 @@ local function init_globals(lax) end local function a_record(t) - t = a_type(t) - t.typename = "record" + t = a_type("record", t) t.field_order = sorted_keys(t.fields) return t end - local function a_gfunction(n, f) + local function a_gfunction(n, f, typename) local typevars = {} local typeargs = {} local c = string.byte("A") - 1 fresh_typevar_ctr = fresh_typevar_ctr + 1 for i = 1, n do local name = string.char(c + i) .. "@" .. fresh_typevar_ctr - typevars[i] = a_type({ typename = "typevar", typevar = name }) - typeargs[i] = a_type({ typename = "typearg", typearg = name }) + typevars[i] = a_type("typevar", { typevar = name }) + typeargs[i] = a_type("typearg", { typearg = name }) end local t = f(_tl_table_unpack(typevars)) - t.typename = "function" t.typeargs = typeargs - return a_type(t) + return a_type(typename or "function", t) end local function a_grecord(n, f) - local t = a_gfunction(n, f) - t.typename = "record" + local t = a_gfunction(n, f, "record") t.field_order = sorted_keys(t.fields) return t end local function an_enum(keys) - local t = a_type({ - typename = "enum", - enumset = {}, - }) + local t = a_type("enum", { enumset = {} }) for _, k in ipairs(keys) do t.enumset[k] = true end @@ -5628,18 +5652,15 @@ local function init_globals(lax) local file_reader_poly_types = { - { ctor = VARARG, args = { UNION({ NUMBER, an_enum({ "*a", "a", "*l", "l", "*L", "L" }) }) }, rets = { STRING } }, - { ctor = TUPLE, args = { an_enum({ "*n", "n" }) }, rets = { NUMBER, STRING } }, - { ctor = VARARG, args = { UNION({ NUMBER, an_enum({ "*a", "a", "*l", "l", "*L", "L", "*n", "n" }) }) }, rets = { UNION({ STRING, NUMBER }) } }, - { ctor = VARARG, args = { UNION({ NUMBER, STRING }) }, rets = { STRING } }, - { ctor = VARARG, args = {}, rets = { STRING } }, + { ctor = a_vararg, args = { a_union({ NUMBER, an_enum({ "*a", "a", "*l", "l", "*L", "L" }) }) }, rets = { STRING } }, + { ctor = a_tuple, args = { an_enum({ "*n", "n" }) }, rets = { NUMBER, STRING } }, + { ctor = a_vararg, args = { a_union({ NUMBER, an_enum({ "*a", "a", "*l", "l", "*L", "L", "*n", "n" }) }) }, rets = { a_union({ STRING, NUMBER }) } }, + { ctor = a_vararg, args = { a_union({ NUMBER, STRING }) }, rets = { STRING } }, + { ctor = a_vararg, args = {}, rets = { STRING } }, } local function a_file_reader(fn) - local t = a_type({ - typename = "poly", - types = {}, - }) + local t = a_poly({}) for _, entry in ipairs(file_reader_poly_types) do local args = shallow_copy_type(entry.args) local rets = shallow_copy_type(entry.rets) @@ -5648,7 +5669,7 @@ local function init_globals(lax) return t end - local LOAD_FUNCTION = a_type({ typename = "function", args = {}, rets = TUPLE({ STRING }) }) + local LOAD_FUNCTION = a_function({ args = {}, rets = a_tuple({ STRING }) }) local OS_DATE_TABLE = a_record({ fields = { @@ -5679,467 +5700,370 @@ local function init_globals(lax) ["nparams"] = INTEGER, ["isvararg"] = BOOLEAN, ["func"] = ANY, - ["activelines"] = a_type({ typename = "map", keys = INTEGER, values = BOOLEAN }), + ["activelines"] = a_type("map", { keys = INTEGER, values = BOOLEAN }), }, }) local DEBUG_HOOK_EVENT = an_enum({ "call", "tail call", "return", "line", "count" }) - local DEBUG_HOOK_FUNCTION = a_type({ - typename = "function", - args = TUPLE({ DEBUG_HOOK_EVENT, INTEGER }), - rets = TUPLE({}), + local DEBUG_HOOK_FUNCTION = a_function({ + args = a_tuple({ DEBUG_HOOK_EVENT, INTEGER }), + rets = a_tuple({}), }) - local TABLE_SORT_FUNCTION = a_gfunction(1, function(a) return { args = TUPLE({ a, a }), rets = TUPLE({ BOOLEAN }) } end) + local TABLE_SORT_FUNCTION = a_gfunction(1, function(a) return { args = a_tuple({ a, a }), rets = a_tuple({ BOOLEAN }) } end) local metatable_nominals = {} local function METATABLE(a) - local t = a_type({ typename = "nominal", names = { "metatable" }, typevals = { a } }) + local t = a_type("nominal", { names = { "metatable" }, typevals = { a } }) table.insert(metatable_nominals, t) return t end - local function ARRAY(t) - return a_type({ - typename = "array", - elements = t, - }) - end - - local function MAP(k, v) - return a_type({ - typename = "map", - keys = k, - values = v, - }) - end - local standard_library = { - ["..."] = VARARG({ STRING }), - ["any"] = a_type({ typename = "typetype", def = ANY }), - ["arg"] = ARRAY(STRING), - ["assert"] = a_gfunction(2, function(a, b) return { args = TUPLE({ a, OPT(b) }), rets = TUPLE({ a }) } end), - ["collectgarbage"] = a_type({ - typename = "poly", - types = { - a_type({ typename = "function", args = TUPLE({ an_enum({ "collect", "count", "stop", "restart" }) }), rets = TUPLE({ NUMBER }) }), - a_type({ typename = "function", args = TUPLE({ an_enum({ "step", "setpause", "setstepmul" }), NUMBER }), rets = TUPLE({ NUMBER }) }), - a_type({ typename = "function", args = TUPLE({ an_enum({ "isrunning" }) }), rets = TUPLE({ BOOLEAN }) }), - a_type({ typename = "function", args = TUPLE({ STRING, OPT(NUMBER) }), rets = TUPLE({ a_type({ typename = "union", types = { BOOLEAN, NUMBER } }) }) }), - }, + ["..."] = a_vararg({ STRING }), + ["any"] = a_type("typetype", { def = ANY }), + ["arg"] = an_array(STRING), + ["assert"] = a_gfunction(2, function(a, b) return { args = a_tuple({ a, OPT(b) }), rets = a_tuple({ a }) } end), + ["collectgarbage"] = a_poly({ + a_function({ args = a_tuple({ an_enum({ "collect", "count", "stop", "restart" }) }), rets = a_tuple({ NUMBER }) }), + a_function({ args = a_tuple({ an_enum({ "step", "setpause", "setstepmul" }), NUMBER }), rets = a_tuple({ NUMBER }) }), + a_function({ args = a_tuple({ an_enum({ "isrunning" }) }), rets = a_tuple({ BOOLEAN }) }), + a_function({ args = a_tuple({ STRING, OPT(NUMBER) }), rets = a_tuple({ a_union({ BOOLEAN, NUMBER }) }) }), }), - ["dofile"] = a_type({ typename = "function", args = TUPLE({ OPT(STRING) }), rets = VARARG({ ANY }) }), - ["error"] = a_type({ typename = "function", args = TUPLE({ ANY, OPT(NUMBER) }), rets = TUPLE({}) }), - ["getmetatable"] = a_gfunction(1, function(a) return { args = TUPLE({ a }), rets = TUPLE({ METATABLE(a) }) } end), - ["ipairs"] = a_gfunction(1, function(a) return { args = TUPLE({ ARRAY(a) }), rets = TUPLE({ - a_type({ typename = "function", args = TUPLE({}), rets = TUPLE({ INTEGER, a }) }), + ["dofile"] = a_function({ args = a_tuple({ OPT(STRING) }), rets = a_vararg({ ANY }) }), + ["error"] = a_function({ args = a_tuple({ ANY, OPT(NUMBER) }), rets = a_tuple({}) }), + ["getmetatable"] = a_gfunction(1, function(a) return { args = a_tuple({ a }), rets = a_tuple({ METATABLE(a) }) } end), + ["ipairs"] = a_gfunction(1, function(a) return { args = a_tuple({ an_array(a) }), rets = a_tuple({ + a_function({ args = a_tuple({}), rets = a_tuple({ INTEGER, a }) }), }), } end), - ["load"] = a_type({ typename = "function", args = TUPLE({ UNION({ STRING, LOAD_FUNCTION }), OPT(STRING), OPT(STRING), OPT(TABLE) }), rets = TUPLE({ FUNCTION, STRING }) }), - ["loadfile"] = a_type({ typename = "function", args = TUPLE({ OPT(STRING), OPT(STRING), OPT(TABLE) }), rets = TUPLE({ FUNCTION, STRING }) }), - ["next"] = a_type({ - typename = "poly", - types = { - a_gfunction(2, function(a, b) return { args = TUPLE({ MAP(a, b), OPT(a) }), rets = TUPLE({ a, b }) } end), - a_gfunction(1, function(a) return { args = TUPLE({ ARRAY(a), OPT(a) }), rets = TUPLE({ INTEGER, a }) } end), - }, + ["load"] = a_function({ args = a_tuple({ a_union({ STRING, LOAD_FUNCTION }), OPT(STRING), OPT(STRING), OPT(TABLE) }), rets = a_tuple({ FUNCTION, STRING }) }), + ["loadfile"] = a_function({ args = a_tuple({ OPT(STRING), OPT(STRING), OPT(TABLE) }), rets = a_tuple({ FUNCTION, STRING }) }), + ["next"] = a_poly({ + a_gfunction(2, function(a, b) return { args = a_tuple({ a_map(a, b), OPT(a) }), rets = a_tuple({ a, b }) } end), + a_gfunction(1, function(a) return { args = a_tuple({ an_array(a), OPT(a) }), rets = a_tuple({ INTEGER, a }) } end), }), - ["pairs"] = a_gfunction(2, function(a, b) return { args = TUPLE({ a_type({ typename = "map", keys = a, values = b }) }), rets = TUPLE({ - a_type({ typename = "function", args = TUPLE({}), rets = TUPLE({ a, b }) }), + ["pairs"] = a_gfunction(2, function(a, b) return { args = a_tuple({ a_map(a, b) }), rets = a_tuple({ + a_function({ args = a_tuple({}), rets = a_tuple({ a, b }) }), }), } end), - ["pcall"] = a_type({ typename = "function", args = VARARG({ FUNCTION, ANY }), rets = VARARG({ BOOLEAN, ANY }) }), - ["xpcall"] = a_type({ typename = "function", args = VARARG({ FUNCTION, XPCALL_MSGH_FUNCTION, ANY }), rets = VARARG({ BOOLEAN, ANY }) }), - ["print"] = a_type({ typename = "function", args = VARARG({ ANY }), rets = TUPLE({}) }), - ["rawequal"] = a_type({ typename = "function", args = TUPLE({ ANY, ANY }), rets = TUPLE({ BOOLEAN }) }), - ["rawget"] = a_type({ typename = "function", args = TUPLE({ TABLE, ANY }), rets = TUPLE({ ANY }) }), - ["rawlen"] = a_type({ typename = "function", args = TUPLE({ UNION({ TABLE, STRING }) }), rets = TUPLE({ INTEGER }) }), - ["rawset"] = a_type({ - typename = "poly", - types = { - a_gfunction(2, function(a, b) return { args = TUPLE({ MAP(a, b), a, b }), rets = TUPLE({}) } end), - a_gfunction(1, function(a) return { args = TUPLE({ ARRAY(a), NUMBER, a }), rets = TUPLE({}) } end), - a_type({ typename = "function", args = TUPLE({ TABLE, ANY, ANY }), rets = TUPLE({}) }), - }, + ["pcall"] = a_function({ args = a_vararg({ FUNCTION, ANY }), rets = a_vararg({ BOOLEAN, ANY }) }), + ["xpcall"] = a_function({ args = a_vararg({ FUNCTION, XPCALL_MSGH_FUNCTION, ANY }), rets = a_vararg({ BOOLEAN, ANY }) }), + ["print"] = a_function({ args = a_vararg({ ANY }), rets = a_tuple({}) }), + ["rawequal"] = a_function({ args = a_tuple({ ANY, ANY }), rets = a_tuple({ BOOLEAN }) }), + ["rawget"] = a_function({ args = a_tuple({ TABLE, ANY }), rets = a_tuple({ ANY }) }), + ["rawlen"] = a_function({ args = a_tuple({ a_union({ TABLE, STRING }) }), rets = a_tuple({ INTEGER }) }), + ["rawset"] = a_poly({ + a_gfunction(2, function(a, b) return { args = a_tuple({ a_map(a, b), a, b }), rets = a_tuple({}) } end), + a_gfunction(1, function(a) return { args = a_tuple({ an_array(a), NUMBER, a }), rets = a_tuple({}) } end), + a_function({ args = a_tuple({ TABLE, ANY, ANY }), rets = a_tuple({}) }), }), - ["require"] = a_type({ typename = "function", args = TUPLE({ STRING }), rets = TUPLE({}) }), - ["select"] = a_type({ - typename = "poly", - types = { - a_gfunction(1, function(a) return { args = VARARG({ NUMBER, a }), rets = TUPLE({ a }) } end), - a_type({ typename = "function", args = VARARG({ NUMBER, ANY }), rets = TUPLE({ ANY }) }), - a_type({ typename = "function", args = VARARG({ STRING, ANY }), rets = TUPLE({ INTEGER }) }), - }, + ["require"] = a_function({ args = a_tuple({ STRING }), rets = a_tuple({}) }), + ["select"] = a_poly({ + a_gfunction(1, function(a) return { args = a_vararg({ NUMBER, a }), rets = a_tuple({ a }) } end), + a_function({ args = a_vararg({ NUMBER, ANY }), rets = a_tuple({ ANY }) }), + a_function({ args = a_vararg({ STRING, ANY }), rets = a_tuple({ INTEGER }) }), }), - ["setmetatable"] = a_gfunction(1, function(a) return { args = TUPLE({ a, METATABLE(a) }), rets = TUPLE({ a }) } end), - ["tonumber"] = a_type({ - typename = "poly", - types = { - a_type({ typename = "function", args = TUPLE({ ANY }), rets = TUPLE({ NUMBER }) }), - a_type({ typename = "function", args = TUPLE({ ANY, NUMBER }), rets = TUPLE({ INTEGER }) }), - }, + ["setmetatable"] = a_gfunction(1, function(a) return { args = a_tuple({ a, METATABLE(a) }), rets = a_tuple({ a }) } end), + ["tonumber"] = a_poly({ + a_function({ args = a_tuple({ ANY }), rets = a_tuple({ NUMBER }) }), + a_function({ args = a_tuple({ ANY, NUMBER }), rets = a_tuple({ INTEGER }) }), }), - ["tostring"] = a_type({ typename = "function", args = TUPLE({ ANY }), rets = TUPLE({ STRING }) }), - ["type"] = a_type({ typename = "function", args = TUPLE({ ANY }), rets = TUPLE({ STRING }) }), - ["FILE"] = a_type({ - typename = "typetype", + ["tostring"] = a_function({ args = a_tuple({ ANY }), rets = a_tuple({ STRING }) }), + ["type"] = a_function({ args = a_tuple({ ANY }), rets = a_tuple({ STRING }) }), + ["FILE"] = a_typetype({ def = a_record({ is_userdata = true, fields = { - ["close"] = a_type({ typename = "function", args = TUPLE({ NOMINAL_FILE }), rets = TUPLE({ BOOLEAN, STRING, INTEGER }) }), - ["flush"] = a_type({ typename = "function", args = TUPLE({ NOMINAL_FILE }), rets = TUPLE({}) }), + ["close"] = a_function({ args = a_tuple({ NOMINAL_FILE }), rets = a_tuple({ BOOLEAN, STRING, INTEGER }) }), + ["flush"] = a_function({ args = a_tuple({ NOMINAL_FILE }), rets = a_tuple({}) }), ["lines"] = a_file_reader(function(ctor, args, rets) table.insert(args, 1, NOMINAL_FILE) - return a_type({ typename = "function", args = ctor(args), rets = TUPLE({ - a_type({ typename = "function", args = TUPLE({}), rets = ctor(rets) }), + return a_function({ args = ctor(args), rets = a_tuple({ + a_function({ args = a_tuple({}), rets = ctor(rets) }), }), }) end), ["read"] = a_file_reader(function(ctor, args, rets) table.insert(args, 1, NOMINAL_FILE) - return a_type({ typename = "function", args = ctor(args), rets = ctor(rets) }) + return a_function({ args = ctor(args), rets = ctor(rets) }) end), - ["seek"] = a_type({ typename = "function", args = TUPLE({ NOMINAL_FILE, OPT(STRING), OPT(NUMBER) }), rets = TUPLE({ INTEGER, STRING }) }), - ["setvbuf"] = a_type({ typename = "function", args = TUPLE({ NOMINAL_FILE, STRING, OPT(NUMBER) }), rets = TUPLE({}) }), - ["write"] = a_type({ typename = "function", args = VARARG({ NOMINAL_FILE, UNION({ STRING, NUMBER }) }), rets = TUPLE({ NOMINAL_FILE, STRING }) }), + ["seek"] = a_function({ args = a_tuple({ NOMINAL_FILE, OPT(STRING), OPT(NUMBER) }), rets = a_tuple({ INTEGER, STRING }) }), + ["setvbuf"] = a_function({ args = a_tuple({ NOMINAL_FILE, STRING, OPT(NUMBER) }), rets = a_tuple({}) }), + ["write"] = a_function({ args = a_vararg({ NOMINAL_FILE, a_union({ STRING, NUMBER }) }), rets = a_tuple({ NOMINAL_FILE, STRING }) }), }, meta_fields = { ["__close"] = FUNCTION }, meta_field_order = { "__close" }, }), }), - ["metatable"] = a_type({ - typename = "typetype", + ["metatable"] = a_typetype({ def = a_grecord(1, function(a) return { fields = { - ["__call"] = a_type({ typename = "function", args = VARARG({ a, ANY }), rets = VARARG({ ANY }) }), - ["__gc"] = a_type({ typename = "function", args = TUPLE({ a }), rets = TUPLE({}) }), + ["__call"] = a_function({ args = a_vararg({ a, ANY }), rets = a_vararg({ ANY }) }), + ["__gc"] = a_function({ args = a_tuple({ a }), rets = a_tuple({}) }), ["__index"] = ANY, - ["__len"] = a_type({ typename = "function", args = TUPLE({ a }), rets = TUPLE({ ANY }) }), + ["__len"] = a_function({ args = a_tuple({ a }), rets = a_tuple({ ANY }) }), ["__mode"] = an_enum({ "k", "v", "kv" }), ["__newindex"] = ANY, ["__pairs"] = a_gfunction(2, function(k, v) return { - args = TUPLE({ a }), - rets = TUPLE({ a_type({ typename = "function", args = TUPLE({}), rets = TUPLE({ k, v }) }) }), + args = a_tuple({ a }), + rets = a_tuple({ a_function({ args = a_tuple({}), rets = a_tuple({ k, v }) }) }), } end), - ["__tostring"] = a_type({ typename = "function", args = TUPLE({ a }), rets = TUPLE({ STRING }) }), + ["__tostring"] = a_function({ args = a_tuple({ a }), rets = a_tuple({ STRING }) }), ["__name"] = STRING, - ["__add"] = a_type({ typename = "function", args = TUPLE({ ANY, ANY }), rets = TUPLE({ ANY }) }), - ["__sub"] = a_type({ typename = "function", args = TUPLE({ ANY, ANY }), rets = TUPLE({ ANY }) }), - ["__mul"] = a_type({ typename = "function", args = TUPLE({ ANY, ANY }), rets = TUPLE({ ANY }) }), - ["__div"] = a_type({ typename = "function", args = TUPLE({ ANY, ANY }), rets = TUPLE({ ANY }) }), - ["__idiv"] = a_type({ typename = "function", args = TUPLE({ ANY, ANY }), rets = TUPLE({ ANY }) }), - ["__mod"] = a_type({ typename = "function", args = TUPLE({ ANY, ANY }), rets = TUPLE({ ANY }) }), - ["__pow"] = a_type({ typename = "function", args = TUPLE({ ANY, ANY }), rets = TUPLE({ ANY }) }), - ["__unm"] = a_type({ typename = "function", args = TUPLE({ ANY }), rets = TUPLE({ ANY }) }), - ["__band"] = a_type({ typename = "function", args = TUPLE({ ANY, ANY }), rets = TUPLE({ ANY }) }), - ["__bor"] = a_type({ typename = "function", args = TUPLE({ ANY, ANY }), rets = TUPLE({ ANY }) }), - ["__bxor"] = a_type({ typename = "function", args = TUPLE({ ANY, ANY }), rets = TUPLE({ ANY }) }), - ["__bnot"] = a_type({ typename = "function", args = TUPLE({ ANY }), rets = TUPLE({ ANY }) }), - ["__shl"] = a_type({ typename = "function", args = TUPLE({ ANY, ANY }), rets = TUPLE({ ANY }) }), - ["__shr"] = a_type({ typename = "function", args = TUPLE({ ANY, ANY }), rets = TUPLE({ ANY }) }), - ["__concat"] = a_type({ typename = "function", args = TUPLE({ ANY, ANY }), rets = TUPLE({ ANY }) }), - ["__eq"] = a_type({ typename = "function", args = TUPLE({ ANY, ANY }), rets = TUPLE({ BOOLEAN }) }), - ["__lt"] = a_type({ typename = "function", args = TUPLE({ ANY, ANY }), rets = TUPLE({ BOOLEAN }) }), - ["__le"] = a_type({ typename = "function", args = TUPLE({ ANY, ANY }), rets = TUPLE({ BOOLEAN }) }), - ["__close"] = a_type({ typename = "function", args = TUPLE({ a }), rets = TUPLE({}) }), + ["__add"] = a_function({ args = a_tuple({ ANY, ANY }), rets = a_tuple({ ANY }) }), + ["__sub"] = a_function({ args = a_tuple({ ANY, ANY }), rets = a_tuple({ ANY }) }), + ["__mul"] = a_function({ args = a_tuple({ ANY, ANY }), rets = a_tuple({ ANY }) }), + ["__div"] = a_function({ args = a_tuple({ ANY, ANY }), rets = a_tuple({ ANY }) }), + ["__idiv"] = a_function({ args = a_tuple({ ANY, ANY }), rets = a_tuple({ ANY }) }), + ["__mod"] = a_function({ args = a_tuple({ ANY, ANY }), rets = a_tuple({ ANY }) }), + ["__pow"] = a_function({ args = a_tuple({ ANY, ANY }), rets = a_tuple({ ANY }) }), + ["__unm"] = a_function({ args = a_tuple({ ANY }), rets = a_tuple({ ANY }) }), + ["__band"] = a_function({ args = a_tuple({ ANY, ANY }), rets = a_tuple({ ANY }) }), + ["__bor"] = a_function({ args = a_tuple({ ANY, ANY }), rets = a_tuple({ ANY }) }), + ["__bxor"] = a_function({ args = a_tuple({ ANY, ANY }), rets = a_tuple({ ANY }) }), + ["__bnot"] = a_function({ args = a_tuple({ ANY }), rets = a_tuple({ ANY }) }), + ["__shl"] = a_function({ args = a_tuple({ ANY, ANY }), rets = a_tuple({ ANY }) }), + ["__shr"] = a_function({ args = a_tuple({ ANY, ANY }), rets = a_tuple({ ANY }) }), + ["__concat"] = a_function({ args = a_tuple({ ANY, ANY }), rets = a_tuple({ ANY }) }), + ["__eq"] = a_function({ args = a_tuple({ ANY, ANY }), rets = a_tuple({ BOOLEAN }) }), + ["__lt"] = a_function({ args = a_tuple({ ANY, ANY }), rets = a_tuple({ BOOLEAN }) }), + ["__le"] = a_function({ args = a_tuple({ ANY, ANY }), rets = a_tuple({ BOOLEAN }) }), + ["__close"] = a_function({ args = a_tuple({ a }), rets = a_tuple({}) }), }, } end), }), ["coroutine"] = a_record({ fields = { - ["create"] = a_type({ typename = "function", args = TUPLE({ FUNCTION }), rets = TUPLE({ THREAD }) }), - ["close"] = a_type({ typename = "function", args = TUPLE({ THREAD }), rets = TUPLE({ BOOLEAN, STRING }) }), - ["isyieldable"] = a_type({ typename = "function", args = TUPLE({}), rets = TUPLE({ BOOLEAN }) }), - ["resume"] = a_type({ typename = "function", args = VARARG({ THREAD, ANY }), rets = VARARG({ BOOLEAN, ANY }) }), - ["running"] = a_type({ typename = "function", args = TUPLE({}), rets = TUPLE({ THREAD, BOOLEAN }) }), - ["status"] = a_type({ typename = "function", args = TUPLE({ THREAD }), rets = TUPLE({ STRING }) }), - ["wrap"] = a_type({ typename = "function", args = TUPLE({ FUNCTION }), rets = TUPLE({ FUNCTION }) }), - ["yield"] = a_type({ typename = "function", args = VARARG({ ANY }), rets = VARARG({ ANY }) }), + ["create"] = a_function({ args = a_tuple({ FUNCTION }), rets = a_tuple({ THREAD }) }), + ["close"] = a_function({ args = a_tuple({ THREAD }), rets = a_tuple({ BOOLEAN, STRING }) }), + ["isyieldable"] = a_function({ args = a_tuple({}), rets = a_tuple({ BOOLEAN }) }), + ["resume"] = a_function({ args = a_vararg({ THREAD, ANY }), rets = a_vararg({ BOOLEAN, ANY }) }), + ["running"] = a_function({ args = a_tuple({}), rets = a_tuple({ THREAD, BOOLEAN }) }), + ["status"] = a_function({ args = a_tuple({ THREAD }), rets = a_tuple({ STRING }) }), + ["wrap"] = a_function({ args = a_tuple({ FUNCTION }), rets = a_tuple({ FUNCTION }) }), + ["yield"] = a_function({ args = a_vararg({ ANY }), rets = a_vararg({ ANY }) }), }, }), ["debug"] = a_record({ fields = { - ["Info"] = a_type({ - typename = "typetype", - def = DEBUG_GETINFO_TABLE, + ["Info"] = a_typetype({ def = DEBUG_GETINFO_TABLE }), + ["Hook"] = a_typetype({ def = DEBUG_HOOK_FUNCTION }), + ["HookEvent"] = a_typetype({ def = DEBUG_HOOK_EVENT }), + + ["debug"] = a_function({ args = a_tuple({}), rets = a_tuple({}) }), + ["gethook"] = a_function({ args = a_tuple({ OPT(THREAD) }), rets = a_tuple({ DEBUG_HOOK_FUNCTION, INTEGER }) }), + ["getlocal"] = a_poly({ + a_function({ args = a_tuple({ THREAD, FUNCTION, NUMBER }), rets = STRING }), + a_function({ args = a_tuple({ THREAD, NUMBER, NUMBER }), rets = a_tuple({ STRING, ANY }) }), + a_function({ args = a_tuple({ FUNCTION, NUMBER }), rets = STRING }), + a_function({ args = a_tuple({ NUMBER, NUMBER }), rets = a_tuple({ STRING, ANY }) }), }), - ["Hook"] = a_type({ - typename = "typetype", - def = DEBUG_HOOK_FUNCTION, + ["getmetatable"] = a_gfunction(1, function(a) return { args = a_tuple({ a }), rets = a_tuple({ METATABLE(a) }) } end), + ["getregistry"] = a_function({ args = a_tuple({}), rets = a_tuple({ TABLE }) }), + ["getupvalue"] = a_function({ args = a_tuple({ FUNCTION, NUMBER }), rets = a_tuple({ ANY }) }), + ["getuservalue"] = a_function({ args = a_tuple({ USERDATA, NUMBER }), rets = a_tuple({ ANY }) }), + ["sethook"] = a_poly({ + a_function({ args = a_tuple({ THREAD, DEBUG_HOOK_FUNCTION, STRING, NUMBER }), rets = a_tuple({}) }), + a_function({ args = a_tuple({ DEBUG_HOOK_FUNCTION, STRING, NUMBER }), rets = a_tuple({}) }), }), - ["HookEvent"] = a_type({ - typename = "typetype", - def = DEBUG_HOOK_EVENT, + ["setlocal"] = a_poly({ + a_function({ args = a_tuple({ THREAD, NUMBER, NUMBER, ANY }), rets = a_tuple({ STRING }) }), + a_function({ args = a_tuple({ NUMBER, NUMBER, ANY }), rets = a_tuple({ STRING }) }), }), - - ["debug"] = a_type({ typename = "function", args = TUPLE({}), rets = TUPLE({}) }), - ["gethook"] = a_type({ typename = "function", args = TUPLE({ OPT(THREAD) }), rets = TUPLE({ DEBUG_HOOK_FUNCTION, INTEGER }) }), - ["getlocal"] = a_type({ - typename = "poly", - types = { - a_type({ typename = "function", args = TUPLE({ THREAD, FUNCTION, NUMBER }), rets = STRING }), - a_type({ typename = "function", args = TUPLE({ THREAD, NUMBER, NUMBER }), rets = TUPLE({ STRING, ANY }) }), - a_type({ typename = "function", args = TUPLE({ FUNCTION, NUMBER }), rets = STRING }), - a_type({ typename = "function", args = TUPLE({ NUMBER, NUMBER }), rets = TUPLE({ STRING, ANY }) }), - }, + ["setmetatable"] = a_gfunction(1, function(a) return { args = a_tuple({ a, METATABLE(a) }), rets = a_tuple({ a }) } end), + ["setupvalue"] = a_function({ args = a_tuple({ FUNCTION, NUMBER, ANY }), rets = a_tuple({ STRING }) }), + ["setuservalue"] = a_function({ args = a_tuple({ USERDATA, ANY, NUMBER }), rets = a_tuple({ USERDATA }) }), + ["traceback"] = a_poly({ + a_function({ args = a_tuple({ OPT(THREAD), OPT(STRING), OPT(NUMBER) }), rets = a_tuple({ STRING }) }), + a_function({ args = a_tuple({ OPT(STRING), OPT(NUMBER) }), rets = a_tuple({ STRING }) }), }), - ["getmetatable"] = a_gfunction(1, function(a) return { args = TUPLE({ a }), rets = TUPLE({ METATABLE(a) }) } end), - ["getregistry"] = a_type({ typename = "function", args = TUPLE({}), rets = TUPLE({ TABLE }) }), - ["getupvalue"] = a_type({ typename = "function", args = TUPLE({ FUNCTION, NUMBER }), rets = TUPLE({ ANY }) }), - ["getuservalue"] = a_type({ typename = "function", args = TUPLE({ USERDATA, NUMBER }), rets = TUPLE({ ANY }) }), - ["sethook"] = a_type({ - typename = "poly", - types = { - a_type({ typename = "function", args = TUPLE({ THREAD, DEBUG_HOOK_FUNCTION, STRING, NUMBER }), rets = TUPLE({}) }), - a_type({ typename = "function", args = TUPLE({ DEBUG_HOOK_FUNCTION, STRING, NUMBER }), rets = TUPLE({}) }), - }, - }), - ["setlocal"] = a_type({ - typename = "poly", - types = { - a_type({ typename = "function", args = TUPLE({ THREAD, NUMBER, NUMBER, ANY }), rets = TUPLE({ STRING }) }), - a_type({ typename = "function", args = TUPLE({ NUMBER, NUMBER, ANY }), rets = TUPLE({ STRING }) }), - }, - }), - ["setmetatable"] = a_gfunction(1, function(a) return { args = TUPLE({ a, METATABLE(a) }), rets = TUPLE({ a }) } end), - ["setupvalue"] = a_type({ typename = "function", args = TUPLE({ FUNCTION, NUMBER, ANY }), rets = TUPLE({ STRING }) }), - ["setuservalue"] = a_type({ typename = "function", args = TUPLE({ USERDATA, ANY, NUMBER }), rets = TUPLE({ USERDATA }) }), - ["traceback"] = a_type({ - typename = "poly", - types = { - a_type({ typename = "function", args = TUPLE({ OPT(THREAD), OPT(STRING), OPT(NUMBER) }), rets = TUPLE({ STRING }) }), - a_type({ typename = "function", args = TUPLE({ OPT(STRING), OPT(NUMBER) }), rets = TUPLE({ STRING }) }), - }, - }), - ["upvalueid"] = a_type({ typename = "function", args = TUPLE({ FUNCTION, NUMBER }), rets = TUPLE({ USERDATA }) }), - ["upvaluejoin"] = a_type({ typename = "function", args = TUPLE({ FUNCTION, NUMBER, FUNCTION, NUMBER }), rets = TUPLE({}) }), - ["getinfo"] = a_type({ - typename = "poly", - types = { - a_type({ typename = "function", args = TUPLE({ ANY }), rets = TUPLE({ DEBUG_GETINFO_TABLE }) }), - a_type({ typename = "function", args = TUPLE({ ANY, STRING }), rets = TUPLE({ DEBUG_GETINFO_TABLE }) }), - a_type({ typename = "function", args = TUPLE({ ANY, ANY, STRING }), rets = TUPLE({ DEBUG_GETINFO_TABLE }) }), - }, + ["upvalueid"] = a_function({ args = a_tuple({ FUNCTION, NUMBER }), rets = a_tuple({ USERDATA }) }), + ["upvaluejoin"] = a_function({ args = a_tuple({ FUNCTION, NUMBER, FUNCTION, NUMBER }), rets = a_tuple({}) }), + ["getinfo"] = a_poly({ + a_function({ args = a_tuple({ ANY }), rets = a_tuple({ DEBUG_GETINFO_TABLE }) }), + a_function({ args = a_tuple({ ANY, STRING }), rets = a_tuple({ DEBUG_GETINFO_TABLE }) }), + a_function({ args = a_tuple({ ANY, ANY, STRING }), rets = a_tuple({ DEBUG_GETINFO_TABLE }) }), }), }, }), ["io"] = a_record({ fields = { - ["close"] = a_type({ typename = "function", args = TUPLE({ OPT(NOMINAL_FILE) }), rets = TUPLE({ BOOLEAN, STRING }) }), - ["flush"] = a_type({ typename = "function", args = TUPLE({}), rets = TUPLE({}) }), - ["input"] = a_type({ typename = "function", args = TUPLE({ OPT(UNION({ STRING, NOMINAL_FILE })) }), rets = TUPLE({ NOMINAL_FILE }) }), + ["close"] = a_function({ args = a_tuple({ OPT(NOMINAL_FILE) }), rets = a_tuple({ BOOLEAN, STRING }) }), + ["flush"] = a_function({ args = a_tuple({}), rets = a_tuple({}) }), + ["input"] = a_function({ args = a_tuple({ OPT(a_union({ STRING, NOMINAL_FILE })) }), rets = a_tuple({ NOMINAL_FILE }) }), ["lines"] = a_file_reader(function(ctor, args, rets) - return a_type({ typename = "function", args = ctor(args), rets = TUPLE({ - a_type({ typename = "function", args = TUPLE({}), rets = ctor(rets) }), + return a_function({ args = ctor(args), rets = a_tuple({ + a_function({ args = a_tuple({}), rets = ctor(rets) }), }), }) end), - ["open"] = a_type({ typename = "function", args = TUPLE({ STRING, OPT(STRING) }), rets = TUPLE({ NOMINAL_FILE, STRING }) }), - ["output"] = a_type({ typename = "function", args = TUPLE({ OPT(UNION({ STRING, NOMINAL_FILE })) }), rets = TUPLE({ NOMINAL_FILE }) }), - ["popen"] = a_type({ typename = "function", args = TUPLE({ STRING, OPT(STRING) }), rets = TUPLE({ NOMINAL_FILE, STRING }) }), + ["open"] = a_function({ args = a_tuple({ STRING, OPT(STRING) }), rets = a_tuple({ NOMINAL_FILE, STRING }) }), + ["output"] = a_function({ args = a_tuple({ OPT(a_union({ STRING, NOMINAL_FILE })) }), rets = a_tuple({ NOMINAL_FILE }) }), + ["popen"] = a_function({ args = a_tuple({ STRING, OPT(STRING) }), rets = a_tuple({ NOMINAL_FILE, STRING }) }), ["read"] = a_file_reader(function(ctor, args, rets) - return a_type({ typename = "function", args = ctor(args), rets = ctor(rets) }) + return a_function({ args = ctor(args), rets = ctor(rets) }) end), ["stderr"] = NOMINAL_FILE, ["stdin"] = NOMINAL_FILE, ["stdout"] = NOMINAL_FILE, - ["tmpfile"] = a_type({ typename = "function", args = TUPLE({}), rets = TUPLE({ NOMINAL_FILE }) }), - ["type"] = a_type({ typename = "function", args = TUPLE({ ANY }), rets = TUPLE({ STRING }) }), - ["write"] = a_type({ typename = "function", args = VARARG({ UNION({ STRING, NUMBER }) }), rets = TUPLE({ NOMINAL_FILE, STRING }) }), + ["tmpfile"] = a_function({ args = a_tuple({}), rets = a_tuple({ NOMINAL_FILE }) }), + ["type"] = a_function({ args = a_tuple({ ANY }), rets = a_tuple({ STRING }) }), + ["write"] = a_function({ args = a_vararg({ a_union({ STRING, NUMBER }) }), rets = a_tuple({ NOMINAL_FILE, STRING }) }), }, }), ["math"] = a_record({ fields = { - ["abs"] = a_type({ - typename = "poly", - types = { - a_type({ typename = "function", args = TUPLE({ INTEGER }), rets = TUPLE({ INTEGER }) }), - a_type({ typename = "function", args = TUPLE({ NUMBER }), rets = TUPLE({ NUMBER }) }), - }, + ["abs"] = a_poly({ + a_function({ args = a_tuple({ INTEGER }), rets = a_tuple({ INTEGER }) }), + a_function({ args = a_tuple({ NUMBER }), rets = a_tuple({ NUMBER }) }), }), - ["acos"] = a_type({ typename = "function", args = TUPLE({ NUMBER }), rets = TUPLE({ NUMBER }) }), - ["asin"] = a_type({ typename = "function", args = TUPLE({ NUMBER }), rets = TUPLE({ NUMBER }) }), - ["atan"] = a_type({ typename = "function", args = TUPLE({ NUMBER, OPT(NUMBER) }), rets = TUPLE({ NUMBER }) }), - ["atan2"] = a_type({ typename = "function", args = TUPLE({ NUMBER, NUMBER }), rets = TUPLE({ NUMBER }) }), - ["ceil"] = a_type({ typename = "function", args = TUPLE({ NUMBER }), rets = TUPLE({ INTEGER }) }), - ["cos"] = a_type({ typename = "function", args = TUPLE({ NUMBER }), rets = TUPLE({ NUMBER }) }), - ["cosh"] = a_type({ typename = "function", args = TUPLE({ NUMBER }), rets = TUPLE({ NUMBER }) }), - ["deg"] = a_type({ typename = "function", args = TUPLE({ NUMBER }), rets = TUPLE({ NUMBER }) }), - ["exp"] = a_type({ typename = "function", args = TUPLE({ NUMBER }), rets = TUPLE({ NUMBER }) }), - ["floor"] = a_type({ typename = "function", args = TUPLE({ NUMBER }), rets = TUPLE({ INTEGER }) }), - ["fmod"] = a_type({ - typename = "poly", - types = { - a_type({ typename = "function", args = TUPLE({ INTEGER, INTEGER }), rets = TUPLE({ INTEGER }) }), - a_type({ typename = "function", args = TUPLE({ NUMBER, NUMBER }), rets = TUPLE({ NUMBER }) }), - }, + ["acos"] = a_function({ args = a_tuple({ NUMBER }), rets = a_tuple({ NUMBER }) }), + ["asin"] = a_function({ args = a_tuple({ NUMBER }), rets = a_tuple({ NUMBER }) }), + ["atan"] = a_function({ args = a_tuple({ NUMBER, OPT(NUMBER) }), rets = a_tuple({ NUMBER }) }), + ["atan2"] = a_function({ args = a_tuple({ NUMBER, NUMBER }), rets = a_tuple({ NUMBER }) }), + ["ceil"] = a_function({ args = a_tuple({ NUMBER }), rets = a_tuple({ INTEGER }) }), + ["cos"] = a_function({ args = a_tuple({ NUMBER }), rets = a_tuple({ NUMBER }) }), + ["cosh"] = a_function({ args = a_tuple({ NUMBER }), rets = a_tuple({ NUMBER }) }), + ["deg"] = a_function({ args = a_tuple({ NUMBER }), rets = a_tuple({ NUMBER }) }), + ["exp"] = a_function({ args = a_tuple({ NUMBER }), rets = a_tuple({ NUMBER }) }), + ["floor"] = a_function({ args = a_tuple({ NUMBER }), rets = a_tuple({ INTEGER }) }), + ["fmod"] = a_poly({ + a_function({ args = a_tuple({ INTEGER, INTEGER }), rets = a_tuple({ INTEGER }) }), + a_function({ args = a_tuple({ NUMBER, NUMBER }), rets = a_tuple({ NUMBER }) }), }), - ["frexp"] = a_type({ typename = "function", args = TUPLE({ NUMBER }), rets = TUPLE({ NUMBER, NUMBER }) }), + ["frexp"] = a_function({ args = a_tuple({ NUMBER }), rets = a_tuple({ NUMBER, NUMBER }) }), ["huge"] = NUMBER, - ["ldexp"] = a_type({ typename = "function", args = TUPLE({ NUMBER, NUMBER }), rets = TUPLE({ NUMBER }) }), - ["log"] = a_type({ typename = "function", args = TUPLE({ NUMBER, OPT(NUMBER) }), rets = TUPLE({ NUMBER }) }), - ["log10"] = a_type({ typename = "function", args = TUPLE({ NUMBER }), rets = TUPLE({ NUMBER }) }), - ["max"] = a_type({ - typename = "poly", - types = { - a_type({ typename = "function", args = VARARG({ INTEGER }), rets = TUPLE({ INTEGER }) }), - a_gfunction(1, function(a) return { args = VARARG({ a }), rets = TUPLE({ a }) } end), - a_type({ typename = "function", args = VARARG({ a_type({ typename = "union", types = { NUMBER, INTEGER } }) }), rets = TUPLE({ NUMBER }) }), - a_type({ typename = "function", args = VARARG({ ANY }), rets = TUPLE({ ANY }) }), - }, + ["ldexp"] = a_function({ args = a_tuple({ NUMBER, NUMBER }), rets = a_tuple({ NUMBER }) }), + ["log"] = a_function({ args = a_tuple({ NUMBER, OPT(NUMBER) }), rets = a_tuple({ NUMBER }) }), + ["log10"] = a_function({ args = a_tuple({ NUMBER }), rets = a_tuple({ NUMBER }) }), + ["max"] = a_poly({ + a_function({ args = a_vararg({ INTEGER }), rets = a_tuple({ INTEGER }) }), + a_gfunction(1, function(a) return { args = a_vararg({ a }), rets = a_tuple({ a }) } end), + a_function({ args = a_vararg({ a_union({ NUMBER, INTEGER }) }), rets = a_tuple({ NUMBER }) }), + a_function({ args = a_vararg({ ANY }), rets = a_tuple({ ANY }) }), }), - ["maxinteger"] = a_type({ typename = "integer", needs_compat = true }), - ["min"] = a_type({ - typename = "poly", - types = { - a_type({ typename = "function", args = VARARG({ INTEGER }), rets = TUPLE({ INTEGER }) }), - a_gfunction(1, function(a) return { args = VARARG({ a }), rets = TUPLE({ a }) } end), - a_type({ typename = "function", args = VARARG({ a_type({ typename = "union", types = { NUMBER, INTEGER } }) }), rets = TUPLE({ NUMBER }) }), - a_type({ typename = "function", args = VARARG({ ANY }), rets = TUPLE({ ANY }) }), - }, + ["maxinteger"] = a_type("integer", { needs_compat = true }), + ["min"] = a_poly({ + a_function({ args = a_vararg({ INTEGER }), rets = a_tuple({ INTEGER }) }), + a_gfunction(1, function(a) return { args = a_vararg({ a }), rets = a_tuple({ a }) } end), + a_function({ args = a_vararg({ a_union({ NUMBER, INTEGER }) }), rets = a_tuple({ NUMBER }) }), + a_function({ args = a_vararg({ ANY }), rets = a_tuple({ ANY }) }), }), - ["mininteger"] = a_type({ typename = "integer", needs_compat = true }), - ["modf"] = a_type({ typename = "function", args = TUPLE({ NUMBER }), rets = TUPLE({ INTEGER, NUMBER }) }), + ["mininteger"] = a_type("integer", { needs_compat = true }), + ["modf"] = a_function({ args = a_tuple({ NUMBER }), rets = a_tuple({ INTEGER, NUMBER }) }), ["pi"] = NUMBER, - ["pow"] = a_type({ typename = "function", args = TUPLE({ NUMBER, NUMBER }), rets = TUPLE({ NUMBER }) }), - ["rad"] = a_type({ typename = "function", args = TUPLE({ NUMBER }), rets = TUPLE({ NUMBER }) }), - ["random"] = a_type({ - typename = "poly", - types = { - a_type({ typename = "function", args = TUPLE({ NUMBER, OPT(NUMBER) }), rets = TUPLE({ INTEGER }) }), - a_type({ typename = "function", args = TUPLE({}), rets = TUPLE({ NUMBER }) }), - }, + ["pow"] = a_function({ args = a_tuple({ NUMBER, NUMBER }), rets = a_tuple({ NUMBER }) }), + ["rad"] = a_function({ args = a_tuple({ NUMBER }), rets = a_tuple({ NUMBER }) }), + ["random"] = a_poly({ + a_function({ args = a_tuple({ NUMBER, OPT(NUMBER) }), rets = a_tuple({ INTEGER }) }), + a_function({ args = a_tuple({}), rets = a_tuple({ NUMBER }) }), }), - ["randomseed"] = a_type({ typename = "function", args = TUPLE({ NUMBER, NUMBER }), rets = TUPLE({ INTEGER, INTEGER }) }), - ["sin"] = a_type({ typename = "function", args = TUPLE({ NUMBER }), rets = TUPLE({ NUMBER }) }), - ["sinh"] = a_type({ typename = "function", args = TUPLE({ NUMBER }), rets = TUPLE({ NUMBER }) }), - ["sqrt"] = a_type({ typename = "function", args = TUPLE({ NUMBER }), rets = TUPLE({ NUMBER }) }), - ["tan"] = a_type({ typename = "function", args = TUPLE({ NUMBER }), rets = TUPLE({ NUMBER }) }), - ["tanh"] = a_type({ typename = "function", args = TUPLE({ NUMBER }), rets = TUPLE({ NUMBER }) }), - ["tointeger"] = a_type({ typename = "function", args = TUPLE({ ANY }), rets = TUPLE({ INTEGER }) }), - ["type"] = a_type({ typename = "function", args = TUPLE({ ANY }), rets = TUPLE({ STRING }) }), - ["ult"] = a_type({ typename = "function", args = TUPLE({ NUMBER, NUMBER }), rets = TUPLE({ BOOLEAN }) }), + ["randomseed"] = a_function({ args = a_tuple({ NUMBER, NUMBER }), rets = a_tuple({ INTEGER, INTEGER }) }), + ["sin"] = a_function({ args = a_tuple({ NUMBER }), rets = a_tuple({ NUMBER }) }), + ["sinh"] = a_function({ args = a_tuple({ NUMBER }), rets = a_tuple({ NUMBER }) }), + ["sqrt"] = a_function({ args = a_tuple({ NUMBER }), rets = a_tuple({ NUMBER }) }), + ["tan"] = a_function({ args = a_tuple({ NUMBER }), rets = a_tuple({ NUMBER }) }), + ["tanh"] = a_function({ args = a_tuple({ NUMBER }), rets = a_tuple({ NUMBER }) }), + ["tointeger"] = a_function({ args = a_tuple({ ANY }), rets = a_tuple({ INTEGER }) }), + ["type"] = a_function({ args = a_tuple({ ANY }), rets = a_tuple({ STRING }) }), + ["ult"] = a_function({ args = a_tuple({ NUMBER, NUMBER }), rets = a_tuple({ BOOLEAN }) }), }, }), ["os"] = a_record({ fields = { - ["clock"] = a_type({ typename = "function", args = TUPLE({}), rets = TUPLE({ NUMBER }) }), - ["date"] = a_type({ - typename = "poly", - types = { - a_type({ typename = "function", args = TUPLE({}), rets = TUPLE({ STRING }) }), - a_type({ typename = "function", args = TUPLE({ an_enum({ "!*t", "*t" }), OPT(NUMBER) }), rets = TUPLE({ OS_DATE_TABLE }) }), - a_type({ typename = "function", args = TUPLE({ OPT(STRING), OPT(NUMBER) }), rets = TUPLE({ STRING }) }), - }, + ["clock"] = a_function({ args = a_tuple({}), rets = a_tuple({ NUMBER }) }), + ["date"] = a_poly({ + a_function({ args = a_tuple({}), rets = a_tuple({ STRING }) }), + a_function({ args = a_tuple({ an_enum({ "!*t", "*t" }), OPT(NUMBER) }), rets = a_tuple({ OS_DATE_TABLE }) }), + a_function({ args = a_tuple({ OPT(STRING), OPT(NUMBER) }), rets = a_tuple({ STRING }) }), }), - ["difftime"] = a_type({ typename = "function", args = TUPLE({ NUMBER, NUMBER }), rets = TUPLE({ NUMBER }) }), - ["execute"] = a_type({ typename = "function", args = TUPLE({ STRING }), rets = TUPLE({ BOOLEAN, STRING, INTEGER }) }), - ["exit"] = a_type({ typename = "function", args = TUPLE({ OPT(UNION({ NUMBER, BOOLEAN })), OPT(BOOLEAN) }), rets = TUPLE({}) }), - ["getenv"] = a_type({ typename = "function", args = TUPLE({ STRING }), rets = TUPLE({ STRING }) }), - ["remove"] = a_type({ typename = "function", args = TUPLE({ STRING }), rets = TUPLE({ BOOLEAN, STRING }) }), - ["rename"] = a_type({ typename = "function", args = TUPLE({ STRING, STRING }), rets = TUPLE({ BOOLEAN, STRING }) }), - ["setlocale"] = a_type({ typename = "function", args = TUPLE({ STRING, OPT(STRING) }), rets = TUPLE({ STRING }) }), - ["time"] = a_type({ typename = "function", args = TUPLE({ OPT(OS_DATE_TABLE) }), rets = TUPLE({ INTEGER }) }), - ["tmpname"] = a_type({ typename = "function", args = TUPLE({}), rets = TUPLE({ STRING }) }), + ["difftime"] = a_function({ args = a_tuple({ NUMBER, NUMBER }), rets = a_tuple({ NUMBER }) }), + ["execute"] = a_function({ args = a_tuple({ STRING }), rets = a_tuple({ BOOLEAN, STRING, INTEGER }) }), + ["exit"] = a_function({ args = a_tuple({ OPT(a_union({ NUMBER, BOOLEAN })), OPT(BOOLEAN) }), rets = a_tuple({}) }), + ["getenv"] = a_function({ args = a_tuple({ STRING }), rets = a_tuple({ STRING }) }), + ["remove"] = a_function({ args = a_tuple({ STRING }), rets = a_tuple({ BOOLEAN, STRING }) }), + ["rename"] = a_function({ args = a_tuple({ STRING, STRING }), rets = a_tuple({ BOOLEAN, STRING }) }), + ["setlocale"] = a_function({ args = a_tuple({ STRING, OPT(STRING) }), rets = a_tuple({ STRING }) }), + ["time"] = a_function({ args = a_tuple({ OPT(OS_DATE_TABLE) }), rets = a_tuple({ INTEGER }) }), + ["tmpname"] = a_function({ args = a_tuple({}), rets = a_tuple({ STRING }) }), }, }), ["package"] = a_record({ fields = { ["config"] = STRING, ["cpath"] = STRING, - ["loaded"] = a_type({ - typename = "map", - keys = STRING, - values = ANY, - }), - ["loaders"] = a_type({ - typename = "array", - elements = a_type({ typename = "function", args = TUPLE({ STRING }), rets = TUPLE({ ANY, ANY }) }), - }), - ["loadlib"] = a_type({ typename = "function", args = TUPLE({ STRING, STRING }), rets = TUPLE({ FUNCTION }) }), + ["loaded"] = a_map(STRING, ANY), + ["loaders"] = an_array(a_function({ args = a_tuple({ STRING }), rets = a_tuple({ ANY, ANY }) })), + ["loadlib"] = a_function({ args = a_tuple({ STRING, STRING }), rets = a_tuple({ FUNCTION }) }), ["path"] = STRING, ["preload"] = TABLE, - ["searchers"] = a_type({ - typename = "array", - elements = a_type({ typename = "function", args = TUPLE({ STRING }), rets = TUPLE({ ANY, ANY }) }), - }), - ["searchpath"] = a_type({ typename = "function", args = TUPLE({ STRING, STRING, OPT(STRING), OPT(STRING) }), rets = TUPLE({ STRING, STRING }) }), + ["searchers"] = an_array(a_function({ args = a_tuple({ STRING }), rets = a_tuple({ ANY, ANY }) })), + ["searchpath"] = a_function({ args = a_tuple({ STRING, STRING, OPT(STRING), OPT(STRING) }), rets = a_tuple({ STRING, STRING }) }), }, }), ["string"] = a_record({ fields = { - ["byte"] = a_type({ - typename = "poly", - types = { - a_type({ typename = "function", args = TUPLE({ STRING, OPT(NUMBER) }), rets = TUPLE({ INTEGER }) }), - a_type({ typename = "function", args = TUPLE({ STRING, NUMBER, NUMBER }), rets = VARARG({ INTEGER }) }), - }, + ["byte"] = a_poly({ + a_function({ args = a_tuple({ STRING, OPT(NUMBER) }), rets = a_tuple({ INTEGER }) }), + a_function({ args = a_tuple({ STRING, NUMBER, NUMBER }), rets = a_vararg({ INTEGER }) }), }), - ["char"] = a_type({ typename = "function", args = VARARG({ NUMBER }), rets = TUPLE({ STRING }) }), - ["dump"] = a_type({ typename = "function", args = TUPLE({ FUNCTION, OPT(BOOLEAN) }), rets = TUPLE({ STRING }) }), - ["find"] = a_type({ typename = "function", args = TUPLE({ STRING, STRING, OPT(NUMBER), OPT(BOOLEAN) }), rets = VARARG({ INTEGER, INTEGER, STRING }) }), - ["format"] = a_type({ typename = "function", args = VARARG({ STRING, ANY }), rets = TUPLE({ STRING }) }), - ["gmatch"] = a_type({ typename = "function", args = TUPLE({ STRING, STRING }), rets = TUPLE({ - a_type({ typename = "function", args = TUPLE({}), rets = VARARG({ STRING }) }), + ["char"] = a_function({ args = a_vararg({ NUMBER }), rets = a_tuple({ STRING }) }), + ["dump"] = a_function({ args = a_tuple({ FUNCTION, OPT(BOOLEAN) }), rets = a_tuple({ STRING }) }), + ["find"] = a_function({ args = a_tuple({ STRING, STRING, OPT(NUMBER), OPT(BOOLEAN) }), rets = a_vararg({ INTEGER, INTEGER, STRING }) }), + ["format"] = a_function({ args = a_vararg({ STRING, ANY }), rets = a_tuple({ STRING }) }), + ["gmatch"] = a_function({ args = a_tuple({ STRING, STRING }), rets = a_tuple({ + a_function({ args = a_tuple({}), rets = a_vararg({ STRING }) }), }), }), - ["gsub"] = a_type({ - typename = "poly", - types = { - a_type({ typename = "function", args = TUPLE({ STRING, STRING, a_type({ typename = "map", keys = STRING, values = STRING }), OPT(NUMBER) }), rets = TUPLE({ STRING, INTEGER }) }), - a_type({ typename = "function", args = TUPLE({ STRING, STRING, a_type({ typename = "function", args = VARARG({ STRING }), rets = TUPLE({ STRING }) }), OPT(NUMBER) }), rets = TUPLE({ STRING, INTEGER }) }), - a_type({ typename = "function", args = TUPLE({ STRING, STRING, a_type({ typename = "function", args = VARARG({ STRING }), rets = TUPLE({ NUMBER }) }), OPT(NUMBER) }), rets = TUPLE({ STRING, INTEGER }) }), - a_type({ typename = "function", args = TUPLE({ STRING, STRING, a_type({ typename = "function", args = VARARG({ STRING }), rets = TUPLE({ BOOLEAN }) }), OPT(NUMBER) }), rets = TUPLE({ STRING, INTEGER }) }), - a_type({ typename = "function", args = TUPLE({ STRING, STRING, a_type({ typename = "function", args = VARARG({ STRING }), rets = TUPLE({}) }), OPT(NUMBER) }), rets = TUPLE({ STRING, INTEGER }) }), - a_type({ typename = "function", args = TUPLE({ STRING, STRING, OPT(STRING), OPT(NUMBER) }), rets = TUPLE({ STRING, INTEGER }) }), - - }, + ["gsub"] = a_poly({ + a_function({ args = a_tuple({ STRING, STRING, a_map(STRING, STRING), OPT(NUMBER) }), rets = a_tuple({ STRING, INTEGER }) }), + a_function({ args = a_tuple({ STRING, STRING, a_function({ args = a_vararg({ STRING }), rets = a_tuple({ STRING }) }), OPT(NUMBER) }), rets = a_tuple({ STRING, INTEGER }) }), + a_function({ args = a_tuple({ STRING, STRING, a_function({ args = a_vararg({ STRING }), rets = a_tuple({ NUMBER }) }), OPT(NUMBER) }), rets = a_tuple({ STRING, INTEGER }) }), + a_function({ args = a_tuple({ STRING, STRING, a_function({ args = a_vararg({ STRING }), rets = a_tuple({ BOOLEAN }) }), OPT(NUMBER) }), rets = a_tuple({ STRING, INTEGER }) }), + a_function({ args = a_tuple({ STRING, STRING, a_function({ args = a_vararg({ STRING }), rets = a_tuple({}) }), OPT(NUMBER) }), rets = a_tuple({ STRING, INTEGER }) }), + a_function({ args = a_tuple({ STRING, STRING, OPT(STRING), OPT(NUMBER) }), rets = a_tuple({ STRING, INTEGER }) }), + }), - ["len"] = a_type({ typename = "function", args = TUPLE({ STRING }), rets = TUPLE({ INTEGER }) }), - ["lower"] = a_type({ typename = "function", args = TUPLE({ STRING }), rets = TUPLE({ STRING }) }), - ["match"] = a_type({ typename = "function", args = TUPLE({ STRING, OPT(STRING), OPT(NUMBER) }), rets = VARARG({ STRING }) }), - ["pack"] = a_type({ typename = "function", args = VARARG({ STRING, ANY }), rets = TUPLE({ STRING }) }), - ["packsize"] = a_type({ typename = "function", args = TUPLE({ STRING }), rets = TUPLE({ INTEGER }) }), - ["rep"] = a_type({ typename = "function", args = TUPLE({ STRING, NUMBER, OPT(STRING) }), rets = TUPLE({ STRING }) }), - ["reverse"] = a_type({ typename = "function", args = TUPLE({ STRING }), rets = TUPLE({ STRING }) }), - ["sub"] = a_type({ typename = "function", args = TUPLE({ STRING, NUMBER, OPT(NUMBER) }), rets = TUPLE({ STRING }) }), - ["unpack"] = a_type({ typename = "function", args = TUPLE({ STRING, STRING, OPT(NUMBER) }), rets = VARARG({ ANY }) }), - ["upper"] = a_type({ typename = "function", args = TUPLE({ STRING }), rets = TUPLE({ STRING }) }), + ["len"] = a_function({ args = a_tuple({ STRING }), rets = a_tuple({ INTEGER }) }), + ["lower"] = a_function({ args = a_tuple({ STRING }), rets = a_tuple({ STRING }) }), + ["match"] = a_function({ args = a_tuple({ STRING, OPT(STRING), OPT(NUMBER) }), rets = a_vararg({ STRING }) }), + ["pack"] = a_function({ args = a_vararg({ STRING, ANY }), rets = a_tuple({ STRING }) }), + ["packsize"] = a_function({ args = a_tuple({ STRING }), rets = a_tuple({ INTEGER }) }), + ["rep"] = a_function({ args = a_tuple({ STRING, NUMBER, OPT(STRING) }), rets = a_tuple({ STRING }) }), + ["reverse"] = a_function({ args = a_tuple({ STRING }), rets = a_tuple({ STRING }) }), + ["sub"] = a_function({ args = a_tuple({ STRING, NUMBER, OPT(NUMBER) }), rets = a_tuple({ STRING }) }), + ["unpack"] = a_function({ args = a_tuple({ STRING, STRING, OPT(NUMBER) }), rets = a_vararg({ ANY }) }), + ["upper"] = a_function({ args = a_tuple({ STRING }), rets = a_tuple({ STRING }) }), }, }), ["table"] = a_record({ fields = { - ["concat"] = a_type({ typename = "function", args = TUPLE({ ARRAY(UNION({ STRING, NUMBER })), OPT(STRING), OPT(NUMBER), OPT(NUMBER) }), rets = TUPLE({ STRING }) }), - ["insert"] = a_type({ - typename = "poly", - types = { - a_gfunction(1, function(a) return { args = TUPLE({ ARRAY(a), NUMBER, a }), rets = TUPLE({}) } end), - a_gfunction(1, function(a) return { args = TUPLE({ ARRAY(a), a }), rets = TUPLE({}) } end), - }, + ["concat"] = a_function({ args = a_tuple({ an_array(a_union({ STRING, NUMBER })), OPT(STRING), OPT(NUMBER), OPT(NUMBER) }), rets = a_tuple({ STRING }) }), + ["insert"] = a_poly({ + a_gfunction(1, function(a) return { args = a_tuple({ an_array(a), NUMBER, a }), rets = a_tuple({}) } end), + a_gfunction(1, function(a) return { args = a_tuple({ an_array(a), a }), rets = a_tuple({}) } end), }), - ["move"] = a_type({ - typename = "poly", - types = { - a_gfunction(1, function(a) return { args = TUPLE({ ARRAY(a), NUMBER, NUMBER, NUMBER }), rets = TUPLE({ ARRAY(a) }) } end), - a_gfunction(1, function(a) return { args = TUPLE({ ARRAY(a), NUMBER, NUMBER, NUMBER, ARRAY(a) }), rets = TUPLE({ ARRAY(a) }) } end), - }, + ["move"] = a_poly({ + a_gfunction(1, function(a) return { args = a_tuple({ an_array(a), NUMBER, NUMBER, NUMBER }), rets = a_tuple({ an_array(a) }) } end), + a_gfunction(1, function(a) return { args = a_tuple({ an_array(a), NUMBER, NUMBER, NUMBER, an_array(a) }), rets = a_tuple({ an_array(a) }) } end), }), - ["pack"] = a_type({ typename = "function", args = VARARG({ ANY }), rets = TUPLE({ TABLE }) }), - ["remove"] = a_gfunction(1, function(a) return { args = TUPLE({ ARRAY(a), OPT(NUMBER) }), rets = TUPLE({ a }) } end), - ["sort"] = a_gfunction(1, function(a) return { args = TUPLE({ ARRAY(a), OPT(TABLE_SORT_FUNCTION) }), rets = TUPLE({}) } end), - ["unpack"] = a_gfunction(1, function(a) return { needs_compat = true, args = TUPLE({ ARRAY(a), OPT(NUMBER), OPT(NUMBER) }), rets = VARARG({ a }) } end), + ["pack"] = a_function({ args = a_vararg({ ANY }), rets = a_tuple({ TABLE }) }), + ["remove"] = a_gfunction(1, function(a) return { args = a_tuple({ an_array(a), OPT(NUMBER) }), rets = a_tuple({ a }) } end), + ["sort"] = a_gfunction(1, function(a) return { args = a_tuple({ an_array(a), OPT(TABLE_SORT_FUNCTION) }), rets = a_tuple({}) } end), + ["unpack"] = a_gfunction(1, function(a) return { needs_compat = true, args = a_tuple({ an_array(a), OPT(NUMBER), OPT(NUMBER) }), rets = a_vararg({ a }) } end), }, }), ["utf8"] = a_record({ fields = { - ["char"] = a_type({ typename = "function", args = VARARG({ NUMBER }), rets = TUPLE({ STRING }) }), + ["char"] = a_function({ args = a_vararg({ NUMBER }), rets = a_tuple({ STRING }) }), ["charpattern"] = STRING, - ["codepoint"] = a_type({ typename = "function", args = TUPLE({ STRING, OPT(NUMBER), OPT(NUMBER) }), rets = VARARG({ INTEGER }) }), - ["codes"] = a_type({ typename = "function", args = TUPLE({ STRING }), rets = TUPLE({ - a_type({ typename = "function", args = TUPLE({ STRING, OPT(NUMBER) }), rets = TUPLE({ NUMBER, NUMBER }) }), + ["codepoint"] = a_function({ args = a_tuple({ STRING, OPT(NUMBER), OPT(NUMBER) }), rets = a_vararg({ INTEGER }) }), + ["codes"] = a_function({ args = a_tuple({ STRING }), rets = a_tuple({ + a_function({ args = a_tuple({ STRING, OPT(NUMBER) }), rets = a_tuple({ NUMBER, NUMBER }) }), }), }), - ["len"] = a_type({ typename = "function", args = TUPLE({ STRING, NUMBER, NUMBER }), rets = TUPLE({ INTEGER }) }), - ["offset"] = a_type({ typename = "function", args = TUPLE({ STRING, NUMBER, NUMBER }), rets = TUPLE({ INTEGER }) }), + ["len"] = a_function({ args = a_tuple({ STRING, NUMBER, NUMBER }), rets = a_tuple({ INTEGER }) }), + ["offset"] = a_function({ args = a_tuple({ STRING, NUMBER, NUMBER }), rets = a_tuple({ INTEGER }) }), }, }), ["_VERSION"] = STRING, @@ -6230,7 +6154,7 @@ tl.type_check = function(ast, opts) end if opts.module_name then - env.modules[opts.module_name] = a_type({ typename = "typetype", def = CIRCULAR_REQUIRE }) + env.modules[opts.module_name] = a_typetype({ def = CIRCULAR_REQUIRE }) end local lax = opts.lax @@ -6308,15 +6232,13 @@ tl.type_check = function(ast, opts) local resolve_typevars local function fresh_typevar(t) - return a_type({ - typename = "typevar", + return a_type("typevar", { typevar = (t.typevar:gsub("@.*", "")) .. "@" .. fresh_typevar_ctr, }) end local function fresh_typearg(t) - return a_type({ - typename = "typearg", + return a_type("typearg", { typearg = (t.typearg:gsub("@.*", "")) .. "@" .. fresh_typevar_ctr, }) end @@ -7051,13 +6973,12 @@ tl.type_check = function(ast, opts) unresolved = find_var_type("@unresolved") end if not unresolved then - unresolved = { - typename = "unresolved", + unresolved = a_type("unresolved", { labels = {}, nominals = {}, global_types = {}, narrows = {}, - } + }) add_var(nil, "@unresolved", unresolved) end return unresolved @@ -7179,7 +7100,7 @@ tl.type_check = function(ast, opts) end if not resolved then - resolved = a_type({ typename = "bad_nominal", names = t.names }) + resolved = a_type("bad_nominal", { names = t.names }) end if not t.filename then @@ -7319,10 +7240,7 @@ tl.type_check = function(ast, opts) if #ts == 1 then return ts[1] else - return a_type({ - typename = "union", - types = ts, - }) + return a_union(ts) end end @@ -7347,19 +7265,13 @@ tl.type_check = function(ast, opts) local element_type = unite(tupletype.types, true) local valid = element_type.typename ~= "union" and true or is_valid_union(element_type) if valid then - return a_type({ - elements = element_type, - typename = "array", - }) + return an_array(element_type) end - local arr_type = a_type({ - elements = tupletype.types[1], - typename = "array", - }) + local arr_type = an_array(tupletype.types[1]) for i = 2, #tupletype.types do - arr_type = expand_type(where, arr_type, a_type({ elements = tupletype.types[i], typename = "array" })) + arr_type = expand_type(where, arr_type, an_array(tupletype.types[i])) if not arr_type.elements then return nil, { Err(tupletype, "unable to convert tuple %s to array", tupletype) } end @@ -7881,7 +7793,7 @@ a.types[i], b.types[i]), } ["bad_nominal"] = compare_false, ["any"] = compare_true, ["tuple"] = function(a, b) - return is_a(TUPLE({ a }), b) + return is_a(a_tuple({ a }), b) end, ["typevar"] = function(a, b) return compare_or_infer_typevar(b.typevar, a, nil, is_a) @@ -7963,6 +7875,7 @@ a.types[i], b.types[i]), } else ok = t1.typename == t2.typename end + if (not ok) and not err then return false, { Err(t1, "got %s, expected %s", t1, t2) } end @@ -8009,9 +7922,9 @@ a.types[i], b.types[i]), } return true elseif t2.typename == "unresolved_emptytable_value" then if is_number_type(t2.emptytable_type.keys) then - infer_emptytable(t2.emptytable_type, infer_at(where, a_type({ typename = "array", elements = t1 }))) + infer_emptytable(t2.emptytable_type, infer_at(where, an_array(t1))) else - infer_emptytable(t2.emptytable_type, infer_at(where, a_type({ typename = "map", keys = t2.emptytable_type.keys, values = t1 }))) + infer_emptytable(t2.emptytable_type, infer_at(where, a_map(t2.emptytable_type.keys, t1))) end return true elseif t2.typename == "emptytable" then @@ -8081,7 +7994,7 @@ a.types[i], b.types[i]), } t = resolve_tuple_and_nominal(t) local call_mt = t.meta_fields and t.meta_fields["__call"] if call_mt then - local args_tuple = a_type({ typename = "tuple" }) + local args_tuple = a_tuple({}) for i = 2, #call_mt.args do table.insert(args_tuple, call_mt.args[i]) end @@ -8093,7 +8006,7 @@ a.types[i], b.types[i]), } local function resolve_for_call(func, args, is_method) if lax and is_unknown(func) then - func = a_type({ typename = "function", args = VARARG({ UNKNOWN }), rets = VARARG({ UNKNOWN }) }) + func = a_function({ args = a_vararg({ UNKNOWN }), rets = a_vararg({ UNKNOWN }) }) end func = resolve_tuple_and_nominal(func) @@ -8218,7 +8131,7 @@ a.types[i], b.types[i]), } if f.typeargs then for _, a in ipairs(f.typeargs) do if not find_var_type(a.typearg) then - add_var(nil, a.typearg, lax and UNKNOWN or { typename = "unresolvable_typearg", typearg = a.typearg }) + add_var(nil, a.typearg, lax and UNKNOWN or a_type("unresolvable_typearg", { typearg = a.typearg })) end end end @@ -8312,7 +8225,7 @@ a.types[i], b.types[i]), } local function push_typeargs(func) if func.typeargs then for _, fnarg in ipairs(func.typeargs) do - add_var(nil, fnarg.typearg, { typename = "unresolved_typearg" }) + add_var(nil, fnarg.typearg, a_type("unresolved_typearg", {})) end end end @@ -8353,7 +8266,9 @@ a.types[i], b.types[i]), } end local f = func.typename == "poly" and func.types[1] or func + mark_invalid_typeargs(f) + return resolve_typevars_at(where, f.rets) end @@ -8368,7 +8283,7 @@ a.types[i], b.types[i]), } argdelta = is_method and -1 or argdelta or 0 if is_method and args[1] then - add_var(nil, "@self", a_type({ typename = "typetype", y = where.y, x = where.x, def = args[1] })) + add_var(nil, "@self", a_typetype({ y = where.y, x = where.x, def = args[1] })) end local is_func = func.typename == "function" @@ -8440,7 +8355,7 @@ a.types[i], b.types[i]), } type_check_function_call = function(node, where_args, func, args, e1, is_method, argdelta) if node.expected and node.expected.typename ~= "tuple" then - node.expected = a_type({ typename = "tuple", node.expected }) + node.expected = a_tuple({ node.expected }) end begin_scope() @@ -8489,7 +8404,7 @@ a.types[i], b.types[i]), } if metamethod then local where_args = { node.e1 } - local args = { typename = "tuple", orig_a } + local args = a_tuple({ orig_a }) if b and method_name ~= "__is" then where_args[2] = node.e2 args[2] = orig_b @@ -8675,11 +8590,12 @@ a.types[i], b.types[i]), } local function get_rets(rets) if lax and (#rets == 0) then - return VARARG({ UNKNOWN }) + return a_vararg({ UNKNOWN }) end local t = rets if not t.typename then - t = TUPLE(t) + + t = a_tuple(t) end assert(t.typeid) return t @@ -8689,7 +8605,7 @@ a.types[i], b.types[i]), } assert(args.typename == "tuple") add_var(nil, "@is_va", args.is_va and ANY or NIL) - add_var(nil, "@return", node.rets or a_type({ typename = "tuple" })) + add_var(nil, "@return", node.rets or a_tuple({})) if node.typeargs then for _, t in ipairs(node.typeargs) do @@ -8704,14 +8620,13 @@ a.types[i], b.types[i]), } local function add_function_definition_for_recursion(node, fnargs) assert(fnargs.typename == "tuple") - local args = TUPLE({}) + local args = a_tuple({}) args.is_va = fnargs.is_va for _, fnarg in ipairs(fnargs) do table.insert(args, fnarg) end - add_var(nil, node.name.tk, a_type({ - typename = "function", + add_var(nil, node.name.tk, a_function({ args = args, rets = get_rets(node.rets), })) @@ -8863,7 +8778,11 @@ a.types[i], b.types[i]), } end if is_a(orig_b, a.keys) then - return a_type({ y = anode.y, x = anode.x, typename = "unresolved_emptytable_value", emptytable_type = a }) + return a_type("unresolved_emptytable_value", { + y = anode.y, + x = anode.x, + emptytable_type = a, + }) end errm, erra, errb = "inconsistent index type: got %s, expected %s (type of keys inferred at " .. @@ -8923,11 +8842,12 @@ a.types[i], b.types[i]), } for _, ftype in fields_of(new) do old.values = expand_type(where, old.values, ftype) end + edit_type(old, "map") else error_at(where, "cannot determine table literal type") end elseif is_record_type(old) and is_record_type(new) then - old.typename = "map" + edit_type(old, "map") old.keys = STRING for _, ftype in fields_of(old) do if not old.values then @@ -8946,6 +8866,7 @@ a.types[i], b.types[i]), } old.fields = nil old.field_order = nil elseif old.typename == "union" then + edit_type(old, "union") new.tk = nil table.insert(old.types, new) else @@ -8991,13 +8912,12 @@ a.types[i], b.types[i]), } if t.def.typeargs then typevals = {} for _, a in ipairs(t.def.typeargs) do - table.insert(typevals, a_type({ typename = "typevar", typevar = a.typearg })) + table.insert(typevals, a_type("typevar", { typevar = a.typearg })) end end - return a_type({ + return a_type("nominal", { y = where.y, x = where.x, - typename = "nominal", typevals = typevals, names = { name }, found = t, @@ -9401,7 +9321,7 @@ a.types[i], b.types[i]), } local base_nargs = (node.e1.tk == "xpcall") and 2 or 1 if #node.e2 < base_nargs then error_at(node, "wrong number of arguments (given " .. #node.e2 .. ", expects at least " .. base_nargs .. ")") - return TUPLE({ BOOLEAN }) + return a_tuple({ BOOLEAN }) end @@ -9428,7 +9348,8 @@ a.types[i], b.types[i]), } } local rets = type_check_funcall(fnode, ftype, b, argdelta + base_nargs) if rets.typename ~= "tuple" then - rets = a_type({ typename = "tuple", rets }) + + rets = a_tuple({ rets }) end table.insert(rets, 1, BOOLEAN) return rets @@ -9561,7 +9482,7 @@ a.types[i], b.types[i]), } resolved = find_type(names) if (not resolved) or (not is_typetype(resolved)) then error_at(typetype, "%s is not a type", typetype) - resolved = a_type({ typename = "bad_nominal", names = names }) + resolved = a_type("bad_nominal", { names = names }) end end return resolved, aliasing @@ -9591,7 +9512,7 @@ a.types[i], b.types[i]), } typ = decls[i] if typ then if i == nexps and ndecl > nexps then - typ = a_type({ y = node.y, x = node.x, filename = filename, typename = "tuple", types = {} }) + typ = a_type("tuple", { y = node.y, x = node.x, filename = filename }) for a = i, ndecl do table.insert(typ, decls[a]) end @@ -9639,11 +9560,10 @@ a.types[i], b.types[i]), } end local function infer_table_literal(node, children) - local typ = a_type({ + local typ = a_type("emptytable", { filename = filename, y = node.y, x = node.x, - typename = "emptytable", }) local is_record = false @@ -9738,11 +9658,10 @@ a.types[i], b.types[i]), } elseif is_record and is_array then typ.typename = "record" typ.interface_list = { - a_type({ + a_type("array", { filename = filename, y = node.y, x = node.x, - typename = "array", elements = typ.elements, }), } @@ -10013,7 +9932,6 @@ a.types[i], b.types[i]), } local name = node.var.tk local resolved, aliasing = get_type_declaration(node) local var = add_var(node.var, name, resolved, node.var.attribute) - if aliasing then var.aliasing = aliasing node.value.is_alias = true @@ -10227,7 +10145,7 @@ a.types[i], b.types[i]), } error_at(node, "label '" .. node.label .. "' already defined at " .. filename) end local unresolved = st[#st]["@unresolved"] - local var = add_var(node, label_id, a_type({ y = node.y, x = node.x, typename = "none" })) + local var = add_var(node, label_id, a_type("none", { y = node.y, x = node.x })) if unresolved then if unresolved.t.labels[node.label] then var.used = true @@ -10267,11 +10185,10 @@ a.types[i], b.types[i]), } widen_all_unions(node) local exp1 = node.exps[1] - local args = { - typename = "tuple", + local args = a_tuple({ node.exps[2] and exptypes[2], node.exps[3] and exptypes[3], - } + }) local exp1type = resolve_for_call(exptypes[1], args, false) if exp1type.typename == "poly" then @@ -10384,7 +10301,7 @@ a.types[i], b.types[i]), } }, ["variable_list"] = { after = function(_node, children) - local tuple = TUPLE(children) + local tuple = a_tuple(children) local n = #tuple @@ -10538,20 +10455,15 @@ a.types[i], b.types[i]), } local t if force_array then - t = infer_at(node, a_type({ - typename = "array", - elements = force_array, - })) + t = infer_at(node, an_array(force_array)) else t = resolve_typevars_at(node, node.expected) if node.expected == t and t.typename == "nominal" then - t = { - typeid = t.typeid, - typename = "nominal", + t = a_type("nominal", { names = t.names, found = t.found, resolved = t.resolved, - } + }) end end @@ -10579,10 +10491,9 @@ a.types[i], b.types[i]), } vtype.typeid = new_typeid() vtype.is_method = false end - return a_type({ + return a_type("table_item", { y = node.y, x = node.x, - typename = "table_item", kname = kname, ktype = ktype, vtype = vtype, @@ -10606,10 +10517,9 @@ a.types[i], b.types[i]), } end_function_scope(node) local rets = get_rets(children[3]) - local t = ensure_fresh_typeargs(a_type({ + local t = ensure_fresh_typeargs(a_function({ y = node.y, x = node.x, - typename = "function", typeargs = node.typeargs, args = children[2], rets = rets, @@ -10648,10 +10558,9 @@ a.types[i], b.types[i]), } return NONE end - add_global(node, node.name.tk, ensure_fresh_typeargs(a_type({ + add_global(node, node.name.tk, ensure_fresh_typeargs(a_function({ y = node.y, x = node.x, - typename = "function", typeargs = node.typeargs, args = children[2], rets = get_rets(children[3]), @@ -10672,10 +10581,9 @@ a.types[i], b.types[i]), } if node.rtype.typeargs then for _, typ in ipairs(node.rtype.typeargs) do - add_var(nil, typ.typearg, a_type({ + add_var(nil, typ.typearg, a_type("typearg", { y = typ.y, x = typ.x, - typename = "typearg", typearg = typ.typearg, })) end @@ -10686,7 +10594,7 @@ a.types[i], b.types[i]), } local rtype = node.rtype if rtype.typename == "emptytable" then - rtype.typename = "record" + edit_type(rtype, "record") rtype.fields = {} rtype.field_order = {} end @@ -10710,10 +10618,9 @@ a.types[i], b.types[i]), } add_var(nil, "self", selftype) end - local fn_type = ensure_fresh_typeargs(a_type({ + local fn_type = ensure_fresh_typeargs(a_function({ y = node.y, x = node.x, - typename = "function", is_method = node.is_method, typeargs = node.typeargs, args = args, @@ -10781,10 +10688,9 @@ a.types[i], b.types[i]), } end_function_scope(node) - return ensure_fresh_typeargs(a_type({ + return ensure_fresh_typeargs(a_function({ y = node.y, x = node.x, - typename = "function", typeargs = node.typeargs, args = children[1], rets = children[2], @@ -10805,10 +10711,9 @@ a.types[i], b.types[i]), } end_function_scope(node) - return ensure_fresh_typeargs(a_type({ + return ensure_fresh_typeargs(a_function({ y = node.y, x = node.x, - typename = "function", typeargs = node.typeargs, args = children[1], rets = children[2], @@ -10930,11 +10835,10 @@ a.types[i], b.types[i]), } kind = "string", conststr = node.e2.tk, } - local btype = a_type({ + local btype = a_type("string", { y = node.e2.y, x = node.e2.x, tk = '"' .. node.e2.tk .. '"', - typename = "string", }) local t = type_check_index(node.e1, bnode, orig_a, btype) @@ -11258,7 +11162,7 @@ a.types[i], b.types[i]), } t = UNKNOWN end if node.tk == "..." then - t = a_type({ typename = "tuple", is_va = true, t }) + t = a_vararg({ t }) end if node.opt then t = OPT(t) @@ -11293,10 +11197,9 @@ a.types[i], b.types[i]), } local function after_literal(node) node.known = FACT_TRUTHY - return a_type({ + return a_type(node.kind, { y = node.y, x = node.x, - typename = node.kind, tk = node.tk, }) end @@ -11368,7 +11271,7 @@ a.types[i], b.types[i]), } ["record"] = { before = function(typ) begin_scope() - add_var(nil, "@self", a_type({ typename = "typetype", y = typ.y, x = typ.x, def = typ })) + add_var(nil, "@self", a_typetype({ y = typ.y, x = typ.x, def = typ })) for name, typ2 in fields_of(typ) do if typ2.typename == "typetype" then @@ -11429,10 +11332,9 @@ a.types[i], b.types[i]), } }, ["typearg"] = { after = function(typ, _children) - add_var(nil, typ.typearg, a_type({ + add_var(nil, typ.typearg, a_type("typearg", { y = typ.y, x = typ.x, - typename = "typearg", typearg = typ.typearg, })) return typ @@ -11456,8 +11358,8 @@ a.types[i], b.types[i]), } if t then if t.typename == "typearg" then + edit_type(typ, "typevar") typ.names = nil - typ.typename = "typevar" typ.typevar = t.typearg else if t.is_alias then diff --git a/tl.tl b/tl.tl index 700241d5d..2b2b948a9 100644 --- a/tl.tl +++ b/tl.tl @@ -1539,22 +1539,76 @@ local function new_node(tokens: {Token}, i: integer, kind?: NodeKind): Node return { y = t.y, x = t.x, tk = t.tk, kind = kind or (t.kind as NodeKind) } end -local function a_type(t: Type): Type +local function a_type(typename: TypeName, t: Type): Type t.typeid = new_typeid() + t.typename = typename + return t +end + +local function edit_type(t: Type, typename: TypeName): Type + t.typeid = new_typeid() + t.typename = typename return t end local function new_type(ps: ParseState, i: integer, typename: TypeName): Type local token = ps.tokens[i] - return a_type { - typename = assert(typename), + return a_type(typename, { filename = ps.filename, y = token.y, x = token.x, tk = token.tk - } + }) +end + +local function a_tuple(t: {Type}): Type + return a_type("tuple", t) +end + +local function a_union(t: {Type}): Type + return a_type("union", { types = t }) +end + +local function a_poly(t: {Type}): Type + return a_type("poly", { types = t }) +end + +local function a_function(t: Type): Type + return a_type("function", t) +end + +local function a_typetype(t: Type): Type + return a_type("typetype", t) +end + +local function a_vararg(t: {Type}): Type + local tuple = t as Type + tuple.is_va = true + return a_tuple(t) +end + +local function an_array(t: Type): Type + return a_type("array", { + elements = t, + }) +end + +local function a_map(k: Type, v: Type): Type + return a_type("map", { + keys = k, + values = v, + }) end +local NIL = a_type("nil", {}) +local ANY = a_type("any", {}) +local TABLE = a_type("map", { keys = ANY, values = ANY }) +local NUMBER = a_type("number", {}) +local STRING = a_type("string", {}) +local THREAD = a_type("thread", {}) +local BOOLEAN = a_type("boolean", {}) +local INTEGER = a_type("integer", {}) + -- Makes a shallow copy of the given type local function shallow_copy_type(t: Type): Type local copy: {any:any} = {} @@ -1808,12 +1862,11 @@ end local function parse_typearg(ps: ParseState, i: integer): integer, Type, integer i = verify_kind(ps, i, "identifier") - return i, a_type { + return i, a_type("typearg", { y = ps.tokens[i - 2].y, x = ps.tokens[i - 2].x, - typename = "typearg", typearg = ps.tokens[i-1].tk, - } + }) end local function parse_return_types(ps: ParseState, i: integer): integer, Type @@ -1830,8 +1883,8 @@ local function parse_function_type(ps: ParseState, i: integer): integer, Type i, typ.args = parse_argument_type_list(ps, i) i, typ.rets = parse_return_types(ps, i) else - typ.args = a_type { typename = "tuple", is_va = true, a_type { typename = "any" } } - typ.rets = a_type { typename = "tuple", is_va = true, a_type { typename = "any" } } + typ.args = a_vararg { ANY } + typ.rets = a_vararg { ANY } end if typ.args[1] and typ.args[1].is_self then typ.is_method = true @@ -1839,15 +1892,6 @@ local function parse_function_type(ps: ParseState, i: integer): integer, Type return i, typ end -local NIL = a_type { typename = "nil" } -local ANY = a_type { typename = "any" } -local TABLE = a_type { typename = "map", keys = ANY, values = ANY } -local NUMBER = a_type { typename = "number" } -local STRING = a_type { typename = "string" } -local THREAD = a_type { typename = "thread" } -local BOOLEAN = a_type { typename = "boolean" } -local INTEGER = a_type { typename = "integer" } - local simple_types: {string:Type} = { ["nil"] = NIL, ["any"] = ANY, @@ -1902,18 +1946,21 @@ local function parse_base_type(ps: ParseState, i: integer): integer, Type, integ if ps.tokens[i].kind == "identifier" then return parse_simple_type_or_nominal(ps, i) elseif tk == "{" then + local istart = i i = i + 1 - local decl = new_type(ps, i, "array") local t: Type i, t = parse_type(ps, i) if not t then return i end if ps.tokens[i].tk == "}" then + local decl = new_type(ps, istart, "array") decl.elements = t end_at(decl as Node, ps.tokens[i]) i = verify_tk(ps, i, "}") + return i, decl elseif ps.tokens[i].tk == "," then + local decl = new_type(ps, istart, "tupletable") decl.typename = "tupletable" decl.types = { t } local n = 2 @@ -1927,8 +1974,9 @@ local function parse_base_type(ps: ParseState, i: integer): integer, Type, integ until ps.tokens[i].tk ~= "," end_at(decl as Node, ps.tokens[i]) i = verify_tk(ps, i, "}") + return i, decl elseif ps.tokens[i].tk == ":" then - decl.typename = "map" + local decl = new_type(ps, istart, "map") i = i + 1 decl.keys = t i, decl.values = parse_type(ps, i) @@ -1937,18 +1985,17 @@ local function parse_base_type(ps: ParseState, i: integer): integer, Type, integ end end_at(decl as Node, ps.tokens[i]) i = verify_tk(ps, i, "}") - else - return fail(ps, i, "syntax error; did you forget a '}'?") + return i, decl end - return i, decl + return fail(ps, i, "syntax error; did you forget a '}'?") elseif tk == "function" then return parse_function_type(ps, i) elseif tk == "nil" then return i + 1, simple_types["nil"] elseif tk == "table" then local typ = new_type(ps, i, "map") - typ.keys = a_type { typename = "any" } - typ.values = a_type { typename = "any" } + typ.keys = ANY + typ.values = ANY return i + 1, typ end return fail(ps, i, "expected a type") @@ -3002,8 +3049,8 @@ parse_record_body = function(ps: ParseState, i: integer, def: Type, node: Node): local typ = new_type(ps, wstart, "function") typ.is_method = true - typ.args = a_type { typename = "tuple", a_type { typename = "nominal", y = typ.y, x = typ.x, names = { "@self" } } } - typ.rets = a_type { typename = "tuple", a_type { typename = "boolean" } } + typ.args = a_tuple { a_type("nominal", { y = typ.y, x = typ.x, names = { "@self" } }) } + typ.rets = a_tuple { BOOLEAN } typ.macroexp = where_macroexp store_field_in_record(ps, i, "__is", typ, def.meta_fields, def.meta_field_order) @@ -4897,32 +4944,15 @@ end -- Type check -------------------------------------------------------------------------------- -local function VARARG(t: {Type}): Type - local tuple = t as Type - tuple.typename = "tuple" - tuple.is_va = true - return a_type(t) -end - -local function TUPLE(t: {Type}): Type - local tuple = t as Type - tuple.typename = "tuple" - return a_type(t) -end +local NONE = a_type("none", {}) +local INVALID = a_type("invalid", {}) +local UNKNOWN = a_type("unknown", {}) +local CIRCULAR_REQUIRE = a_type("circular_require", {}) -local function UNION(t: {Type}): Type - return a_type { typename = "union", types = t } -end +local FUNCTION = a_function { args = a_vararg { ANY }, rets = a_vararg { ANY } } -local NONE = a_type { typename = "none" } -local INVALID = a_type { typename = "invalid" } -local UNKNOWN = a_type { typename = "unknown" } -local CIRCULAR_REQUIRE = a_type { typename = "circular_require" } - -local FUNCTION = a_type { typename = "function", args = VARARG { ANY }, rets = VARARG { ANY } } - -local NOMINAL_FILE = a_type { typename = "nominal", names = {"FILE"} } -local XPCALL_MSGH_FUNCTION = a_type { typename = "function", args = TUPLE { ANY }, rets = TUPLE { } } +local NOMINAL_FILE = a_type("nominal", { names = {"FILE"} }) +local XPCALL_MSGH_FUNCTION = a_function { args = a_tuple { ANY }, rets = a_tuple { } } local USERDATA = ANY -- Placeholder for maybe having a userdata "primitive" type @@ -5579,40 +5609,34 @@ local function init_globals(lax: boolean): {string:Variable}, {string:Type} end local function a_record(t: Type): Type - t = a_type(t) - t.typename = "record" + t = a_type("record", t) t.field_order = sorted_keys(t.fields) return t end - local function a_gfunction(n: integer, f: function(...: Type): Type): Type + local function a_gfunction(n: integer, f: function(...: Type): (Type), typename?: TypeName): Type local typevars = {} local typeargs = {} local c = string.byte("A") - 1 fresh_typevar_ctr = fresh_typevar_ctr + 1 for i = 1, n do local name = string.char(c + i) .. "@" .. fresh_typevar_ctr - typevars[i] = a_type { typename = "typevar", typevar = name } - typeargs[i] = a_type { typename = "typearg", typearg = name } + typevars[i] = a_type("typevar", { typevar = name }) + typeargs[i] = a_type("typearg", { typearg = name }) end local t = f(table.unpack(typevars)) - t.typename = "function" t.typeargs = typeargs - return a_type(t) + return a_type(typename or "function", t) end local function a_grecord(n: integer, f: function(...: Type): Type): Type - local t = a_gfunction(n, f) - t.typename = "record" + local t = a_gfunction(n, f, "record") t.field_order = sorted_keys(t.fields) return t end local function an_enum(keys: {string}): Type - local t = a_type { - typename = "enum", - enumset = {} - } + local t = a_type("enum", { enumset = {} }) for _, k in ipairs(keys) do t.enumset[k] = true end @@ -5628,18 +5652,15 @@ local function init_globals(lax: boolean): {string:Variable}, {string:Type} end local file_reader_poly_types: {ArgsRets} = { - { ctor = VARARG, args = {UNION { NUMBER, an_enum { "*a", "a", "*l", "l", "*L", "L" } } }, rets = { STRING } }, - { ctor = TUPLE, args = { an_enum { "*n", "n" } }, rets = { NUMBER, STRING } }, - { ctor = VARARG, args = { UNION { NUMBER, an_enum { "*a", "a", "*l", "l", "*L", "L", "*n", "n" } } }, rets = { UNION { STRING, NUMBER } } }, - { ctor = VARARG, args = { UNION { NUMBER, STRING } }, rets = { STRING } }, - { ctor = VARARG, args = { }, rets = { STRING } }, + { ctor = a_vararg, args = { a_union { NUMBER, an_enum { "*a", "a", "*l", "l", "*L", "L" } } }, rets = { STRING } }, + { ctor = a_tuple, args = { an_enum { "*n", "n" } }, rets = { NUMBER, STRING } }, + { ctor = a_vararg, args = { a_union { NUMBER, an_enum { "*a", "a", "*l", "l", "*L", "L", "*n", "n" } } }, rets = { a_union { STRING, NUMBER } } }, + { ctor = a_vararg, args = { a_union { NUMBER, STRING } }, rets = { STRING } }, + { ctor = a_vararg, args = { }, rets = { STRING } }, } local function a_file_reader(fn: (function(ctor: TypeConstructor, args: {Type}, rets: {Type}): Type)): Type - local t = a_type { - typename = "poly", - types = {} - } + local t = a_poly {} for _, entry in ipairs(file_reader_poly_types) do local args = shallow_copy_type(entry.args as Type) as {Type} local rets = shallow_copy_type(entry.rets as Type) as {Type} @@ -5648,7 +5669,7 @@ local function init_globals(lax: boolean): {string:Variable}, {string:Type} return t end - local LOAD_FUNCTION = a_type { typename = "function", args = {}, rets = TUPLE { STRING } } + local LOAD_FUNCTION = a_function { args = {}, rets = a_tuple { STRING } } local OS_DATE_TABLE = a_record { fields = { @@ -5679,467 +5700,370 @@ local function init_globals(lax: boolean): {string:Variable}, {string:Type} ["nparams"] = INTEGER, ["isvararg"] = BOOLEAN, ["func"] = ANY, - ["activelines"] = a_type { typename = "map", keys = INTEGER, values = BOOLEAN }, + ["activelines"] = a_type("map", { keys = INTEGER, values = BOOLEAN }), } } local DEBUG_HOOK_EVENT = an_enum { "call", "tail call", "return", "line", "count" } - local DEBUG_HOOK_FUNCTION = a_type { - typename = "function", - args = TUPLE { DEBUG_HOOK_EVENT, INTEGER }, - rets = TUPLE {}, + local DEBUG_HOOK_FUNCTION = a_function { + args = a_tuple { DEBUG_HOOK_EVENT, INTEGER }, + rets = a_tuple {}, } - local TABLE_SORT_FUNCTION = a_gfunction(1, function(a: Type):Type return { args = TUPLE { a, a }, rets = TUPLE { BOOLEAN } } end) + local TABLE_SORT_FUNCTION = a_gfunction(1, function(a: Type):Type return { args = a_tuple { a, a }, rets = a_tuple { BOOLEAN } } end) local metatable_nominals = {} local function METATABLE(a: Type): Type - local t = a_type { typename = "nominal", names = {"metatable"}, typevals = { a } } + local t = a_type("nominal", { names = {"metatable"}, typevals = { a } }) table.insert(metatable_nominals, t) return t end - local function ARRAY(t: Type): Type - return a_type { - typename = "array", - elements = t, - } - end - - local function MAP(k: Type, v: Type): Type - return a_type { - typename = "map", - keys = k, - values = v, - } - end - local standard_library: {string:Type} = { - ["..."] = VARARG { STRING }, - ["any"] = a_type { typename = "typetype", def = ANY }, - ["arg"] = ARRAY(STRING), - ["assert"] = a_gfunction(2, function(a: Type, b: Type): Type return { args = TUPLE { a, OPT(b) }, rets = TUPLE { a } } end), - ["collectgarbage"] = a_type { - typename = "poly", - types = { - a_type { typename = "function", args = TUPLE { an_enum { "collect", "count", "stop", "restart" } }, rets = TUPLE { NUMBER } }, - a_type { typename = "function", args = TUPLE { an_enum { "step", "setpause", "setstepmul" }, NUMBER }, rets = TUPLE { NUMBER } }, - a_type { typename = "function", args = TUPLE { an_enum { "isrunning" } }, rets = TUPLE { BOOLEAN } }, - a_type { typename = "function", args = TUPLE { STRING, OPT(NUMBER) }, rets = TUPLE { a_type { typename = "union", types = { BOOLEAN, NUMBER } } } }, - } + ["..."] = a_vararg { STRING }, + ["any"] = a_type("typetype", { def = ANY }), + ["arg"] = an_array(STRING), + ["assert"] = a_gfunction(2, function(a: Type, b: Type): Type return { args = a_tuple { a, OPT(b) }, rets = a_tuple { a } } end), + ["collectgarbage"] = a_poly { + a_function { args = a_tuple { an_enum { "collect", "count", "stop", "restart" } }, rets = a_tuple { NUMBER } }, + a_function { args = a_tuple { an_enum { "step", "setpause", "setstepmul" }, NUMBER }, rets = a_tuple { NUMBER } }, + a_function { args = a_tuple { an_enum { "isrunning" } }, rets = a_tuple { BOOLEAN } }, + a_function { args = a_tuple { STRING, OPT(NUMBER) }, rets = a_tuple { a_union { BOOLEAN, NUMBER } } }, }, - ["dofile"] = a_type { typename = "function", args = TUPLE { OPT(STRING) }, rets = VARARG { ANY } }, - ["error"] = a_type { typename = "function", args = TUPLE { ANY, OPT(NUMBER) }, rets = TUPLE {} }, - ["getmetatable"] = a_gfunction(1, function(a: Type): Type return { args = TUPLE { a }, rets = TUPLE { METATABLE(a) } } end), - ["ipairs"] = a_gfunction(1, function(a: Type): Type return { args = TUPLE { ARRAY(a) }, rets = TUPLE { - a_type { typename = "function", args = TUPLE {}, rets = TUPLE { INTEGER, a } }, + ["dofile"] = a_function { args = a_tuple { OPT(STRING) }, rets = a_vararg { ANY } }, + ["error"] = a_function { args = a_tuple { ANY, OPT(NUMBER) }, rets = a_tuple {} }, + ["getmetatable"] = a_gfunction(1, function(a: Type): Type return { args = a_tuple { a }, rets = a_tuple { METATABLE(a) } } end), + ["ipairs"] = a_gfunction(1, function(a: Type): Type return { args = a_tuple { an_array(a) }, rets = a_tuple { + a_function { args = a_tuple {}, rets = a_tuple { INTEGER, a } }, } } end), - ["load"] = a_type { typename = "function", args = TUPLE { UNION { STRING, LOAD_FUNCTION }, OPT(STRING), OPT(STRING), OPT(TABLE) }, rets = TUPLE { FUNCTION, STRING } }, - ["loadfile"] = a_type { typename = "function", args = TUPLE { OPT(STRING), OPT(STRING), OPT(TABLE) }, rets = TUPLE { FUNCTION, STRING } }, - ["next"] = a_type { - typename = "poly", - types = { - a_gfunction(2, function(a: Type, b: Type): Type return { args = TUPLE { MAP(a, b), OPT(a) }, rets = TUPLE { a, b } } end), - a_gfunction(1, function(a: Type): Type return { args = TUPLE { ARRAY(a), OPT(a) }, rets = TUPLE { INTEGER, a } } end), - }, + ["load"] = a_function { args = a_tuple { a_union { STRING, LOAD_FUNCTION }, OPT(STRING), OPT(STRING), OPT(TABLE) }, rets = a_tuple { FUNCTION, STRING } }, + ["loadfile"] = a_function { args = a_tuple { OPT(STRING), OPT(STRING), OPT(TABLE) }, rets = a_tuple { FUNCTION, STRING } }, + ["next"] = a_poly { + a_gfunction(2, function(a: Type, b: Type): Type return { args = a_tuple { a_map(a, b), OPT(a) }, rets = a_tuple { a, b } } end), + a_gfunction(1, function(a: Type): Type return { args = a_tuple { an_array(a), OPT(a) }, rets = a_tuple { INTEGER, a } } end), }, - ["pairs"] = a_gfunction(2, function(a: Type, b: Type): Type return { args = TUPLE { a_type { typename = "map", keys = a, values = b } }, rets = TUPLE { - a_type { typename = "function", args = TUPLE {}, rets = TUPLE { a, b } }, + ["pairs"] = a_gfunction(2, function(a: Type, b: Type): Type return { args = a_tuple { a_map(a, b) }, rets = a_tuple { + a_function { args = a_tuple {}, rets = a_tuple { a, b } }, } } end), - ["pcall"] = a_type { typename = "function", args = VARARG { FUNCTION, ANY }, rets = VARARG { BOOLEAN, ANY } }, - ["xpcall"] = a_type { typename = "function", args = VARARG { FUNCTION, XPCALL_MSGH_FUNCTION, ANY }, rets = VARARG { BOOLEAN, ANY } }, - ["print"] = a_type { typename = "function", args = VARARG { ANY }, rets = TUPLE {} }, - ["rawequal"] = a_type { typename = "function", args = TUPLE { ANY, ANY }, rets = TUPLE { BOOLEAN } }, - ["rawget"] = a_type { typename = "function", args = TUPLE { TABLE, ANY }, rets = TUPLE { ANY } }, - ["rawlen"] = a_type { typename = "function", args = TUPLE { UNION { TABLE, STRING } }, rets = TUPLE { INTEGER } }, - ["rawset"] = a_type { - typename = "poly", - types = { - a_gfunction(2, function(a: Type, b: Type): Type return { args = TUPLE { MAP(a, b), a, b }, rets = TUPLE {} } end), - a_gfunction(1, function(a: Type): Type return { args = TUPLE { ARRAY(a), NUMBER, a }, rets = TUPLE {} } end), - a_type { typename = "function", args = TUPLE { TABLE, ANY, ANY }, rets = TUPLE {} }, - } + ["pcall"] = a_function { args = a_vararg { FUNCTION, ANY }, rets = a_vararg { BOOLEAN, ANY } }, + ["xpcall"] = a_function { args = a_vararg { FUNCTION, XPCALL_MSGH_FUNCTION, ANY }, rets = a_vararg { BOOLEAN, ANY } }, + ["print"] = a_function { args = a_vararg { ANY }, rets = a_tuple {} }, + ["rawequal"] = a_function { args = a_tuple { ANY, ANY }, rets = a_tuple { BOOLEAN } }, + ["rawget"] = a_function { args = a_tuple { TABLE, ANY }, rets = a_tuple { ANY } }, + ["rawlen"] = a_function { args = a_tuple { a_union { TABLE, STRING } }, rets = a_tuple { INTEGER } }, + ["rawset"] = a_poly { + a_gfunction(2, function(a: Type, b: Type): Type return { args = a_tuple { a_map(a, b), a, b }, rets = a_tuple {} } end), + a_gfunction(1, function(a: Type): Type return { args = a_tuple { an_array(a), NUMBER, a }, rets = a_tuple {} } end), + a_function { args = a_tuple { TABLE, ANY, ANY }, rets = a_tuple {} }, }, - ["require"] = a_type { typename = "function", args = TUPLE { STRING }, rets = TUPLE {} }, - ["select"] = a_type { - typename = "poly", - types = { - a_gfunction(1, function(a: Type): Type return { args = VARARG { NUMBER, a }, rets = TUPLE { a } } end), - a_type { typename = "function", args = VARARG { NUMBER, ANY }, rets = TUPLE { ANY } }, - a_type { typename = "function", args = VARARG { STRING, ANY }, rets = TUPLE { INTEGER } }, - } + ["require"] = a_function { args = a_tuple { STRING }, rets = a_tuple {} }, + ["select"] = a_poly { + a_gfunction(1, function(a: Type): Type return { args = a_vararg { NUMBER, a }, rets = a_tuple { a } } end), + a_function { args = a_vararg { NUMBER, ANY }, rets = a_tuple { ANY } }, + a_function { args = a_vararg { STRING, ANY }, rets = a_tuple { INTEGER } }, }, - ["setmetatable"] = a_gfunction(1, function(a: Type): Type return { args = TUPLE { a, METATABLE(a) }, rets = TUPLE { a } } end), - ["tonumber"] = a_type { - typename = "poly", - types = { - a_type { typename = "function", args = TUPLE { ANY }, rets = TUPLE { NUMBER } }, - a_type { typename = "function", args = TUPLE { ANY, NUMBER }, rets = TUPLE { INTEGER } }, - } + ["setmetatable"] = a_gfunction(1, function(a: Type): Type return { args = a_tuple { a, METATABLE(a) }, rets = a_tuple { a } } end), + ["tonumber"] = a_poly { + a_function { args = a_tuple { ANY }, rets = a_tuple { NUMBER } }, + a_function { args = a_tuple { ANY, NUMBER }, rets = a_tuple { INTEGER } }, }, - ["tostring"] = a_type { typename = "function", args = TUPLE { ANY }, rets = TUPLE { STRING } }, - ["type"] = a_type { typename = "function", args = TUPLE { ANY }, rets = TUPLE { STRING } }, - ["FILE"] = a_type { - typename = "typetype", + ["tostring"] = a_function { args = a_tuple { ANY }, rets = a_tuple { STRING } }, + ["type"] = a_function { args = a_tuple { ANY }, rets = a_tuple { STRING } }, + ["FILE"] = a_typetype { def = a_record { is_userdata = true, fields = { - ["close"] = a_type { typename = "function", args = TUPLE { NOMINAL_FILE }, rets = TUPLE { BOOLEAN, STRING, INTEGER } }, - ["flush"] = a_type { typename = "function", args = TUPLE { NOMINAL_FILE }, rets = TUPLE {} }, + ["close"] = a_function { args = a_tuple { NOMINAL_FILE }, rets = a_tuple { BOOLEAN, STRING, INTEGER } }, + ["flush"] = a_function { args = a_tuple { NOMINAL_FILE }, rets = a_tuple {} }, ["lines"] = a_file_reader(function(ctor: (function({Type}):Type), args: {Type}, rets: {Type}): Type table.insert(args, 1, NOMINAL_FILE) - return a_type { typename = "function", args = ctor(args), rets = TUPLE { - a_type { typename = "function", args = TUPLE {}, rets = ctor(rets) }, + return a_function { args = ctor(args), rets = a_tuple { + a_function { args = a_tuple {}, rets = ctor(rets) }, } } end), ["read"] = a_file_reader(function(ctor: (function({Type}):Type), args: {Type}, rets: {Type}): Type table.insert(args, 1, NOMINAL_FILE) - return a_type { typename = "function", args = ctor(args), rets = ctor(rets) } + return a_function { args = ctor(args), rets = ctor(rets) } end), - ["seek"] = a_type { typename = "function", args = TUPLE { NOMINAL_FILE, OPT(STRING), OPT(NUMBER) }, rets = TUPLE { INTEGER, STRING } }, - ["setvbuf"] = a_type { typename = "function", args = TUPLE { NOMINAL_FILE, STRING, OPT(NUMBER) }, rets = TUPLE {} }, - ["write"] = a_type { typename = "function", args = VARARG { NOMINAL_FILE, UNION { STRING, NUMBER } }, rets = TUPLE { NOMINAL_FILE, STRING } }, + ["seek"] = a_function { args = a_tuple { NOMINAL_FILE, OPT(STRING), OPT(NUMBER) }, rets = a_tuple { INTEGER, STRING } }, + ["setvbuf"] = a_function { args = a_tuple { NOMINAL_FILE, STRING, OPT(NUMBER) }, rets = a_tuple {} }, + ["write"] = a_function { args = a_vararg { NOMINAL_FILE, a_union { STRING, NUMBER } }, rets = a_tuple { NOMINAL_FILE, STRING } }, -- TODO complete... }, meta_fields = { ["__close"] = FUNCTION }, meta_field_order = { "__close" }, }, }, - ["metatable"] = a_type { - typename = "typetype", + ["metatable"] = a_typetype { def = a_grecord(1, function(a: Type): Type return { fields = { - ["__call"] = a_type { typename = "function", args = VARARG { a, ANY }, rets = VARARG { ANY } }, - ["__gc"] = a_type { typename = "function", args = TUPLE { a }, rets = TUPLE {} }, + ["__call"] = a_function { args = a_vararg { a, ANY }, rets = a_vararg { ANY } }, + ["__gc"] = a_function { args = a_tuple { a }, rets = a_tuple {} }, ["__index"] = ANY, -- FIXME: function | table | anything with an __index metamethod - ["__len"] = a_type { typename = "function", args = TUPLE { a }, rets = TUPLE { ANY } }, + ["__len"] = a_function { args = a_tuple { a }, rets = a_tuple { ANY } }, ["__mode"] = an_enum { "k", "v", "kv" }, ["__newindex"] = ANY, -- FIXME: function | table | anything with a __newindex metamethod ["__pairs"] = a_gfunction(2, function(k: Type, v: Type): Type return { - args = TUPLE { a }, - rets = TUPLE { a_type { typename = "function", args = TUPLE {}, rets = TUPLE { k, v } } } + args = a_tuple { a }, + rets = a_tuple { a_function { args = a_tuple {}, rets = a_tuple { k, v } } } } end), - ["__tostring"] = a_type { typename = "function", args = TUPLE { a }, rets = TUPLE { STRING } }, + ["__tostring"] = a_function { args = a_tuple { a }, rets = a_tuple { STRING } }, ["__name"] = STRING, - ["__add"] = a_type { typename = "function", args = TUPLE { ANY, ANY }, rets = TUPLE { ANY } }, - ["__sub"] = a_type { typename = "function", args = TUPLE { ANY, ANY }, rets = TUPLE { ANY } }, - ["__mul"] = a_type { typename = "function", args = TUPLE { ANY, ANY }, rets = TUPLE { ANY } }, - ["__div"] = a_type { typename = "function", args = TUPLE { ANY, ANY }, rets = TUPLE { ANY } }, - ["__idiv"] = a_type { typename = "function", args = TUPLE { ANY, ANY }, rets = TUPLE { ANY } }, - ["__mod"] = a_type { typename = "function", args = TUPLE { ANY, ANY }, rets = TUPLE { ANY } }, - ["__pow"] = a_type { typename = "function", args = TUPLE { ANY, ANY }, rets = TUPLE { ANY } }, - ["__unm"] = a_type { typename = "function", args = TUPLE { ANY }, rets = TUPLE { ANY } }, - ["__band"] = a_type { typename = "function", args = TUPLE { ANY, ANY }, rets = TUPLE { ANY } }, - ["__bor"] = a_type { typename = "function", args = TUPLE { ANY, ANY }, rets = TUPLE { ANY } }, - ["__bxor"] = a_type { typename = "function", args = TUPLE { ANY, ANY }, rets = TUPLE { ANY } }, - ["__bnot"] = a_type { typename = "function", args = TUPLE { ANY }, rets = TUPLE { ANY } }, - ["__shl"] = a_type { typename = "function", args = TUPLE { ANY, ANY }, rets = TUPLE { ANY } }, - ["__shr"] = a_type { typename = "function", args = TUPLE { ANY, ANY }, rets = TUPLE { ANY } }, - ["__concat"] = a_type { typename = "function", args = TUPLE { ANY, ANY }, rets = TUPLE { ANY } }, - ["__eq"] = a_type { typename = "function", args = TUPLE { ANY, ANY }, rets = TUPLE { BOOLEAN } }, - ["__lt"] = a_type { typename = "function", args = TUPLE { ANY, ANY }, rets = TUPLE { BOOLEAN } }, - ["__le"] = a_type { typename = "function", args = TUPLE { ANY, ANY }, rets = TUPLE { BOOLEAN } }, - ["__close"] = a_type { typename = "function", args = TUPLE { a }, rets = TUPLE { } }, + ["__add"] = a_function { args = a_tuple { ANY, ANY }, rets = a_tuple { ANY } }, + ["__sub"] = a_function { args = a_tuple { ANY, ANY }, rets = a_tuple { ANY } }, + ["__mul"] = a_function { args = a_tuple { ANY, ANY }, rets = a_tuple { ANY } }, + ["__div"] = a_function { args = a_tuple { ANY, ANY }, rets = a_tuple { ANY } }, + ["__idiv"] = a_function { args = a_tuple { ANY, ANY }, rets = a_tuple { ANY } }, + ["__mod"] = a_function { args = a_tuple { ANY, ANY }, rets = a_tuple { ANY } }, + ["__pow"] = a_function { args = a_tuple { ANY, ANY }, rets = a_tuple { ANY } }, + ["__unm"] = a_function { args = a_tuple { ANY }, rets = a_tuple { ANY } }, + ["__band"] = a_function { args = a_tuple { ANY, ANY }, rets = a_tuple { ANY } }, + ["__bor"] = a_function { args = a_tuple { ANY, ANY }, rets = a_tuple { ANY } }, + ["__bxor"] = a_function { args = a_tuple { ANY, ANY }, rets = a_tuple { ANY } }, + ["__bnot"] = a_function { args = a_tuple { ANY }, rets = a_tuple { ANY } }, + ["__shl"] = a_function { args = a_tuple { ANY, ANY }, rets = a_tuple { ANY } }, + ["__shr"] = a_function { args = a_tuple { ANY, ANY }, rets = a_tuple { ANY } }, + ["__concat"] = a_function { args = a_tuple { ANY, ANY }, rets = a_tuple { ANY } }, + ["__eq"] = a_function { args = a_tuple { ANY, ANY }, rets = a_tuple { BOOLEAN } }, + ["__lt"] = a_function { args = a_tuple { ANY, ANY }, rets = a_tuple { BOOLEAN } }, + ["__le"] = a_function { args = a_tuple { ANY, ANY }, rets = a_tuple { BOOLEAN } }, + ["__close"] = a_function { args = a_tuple { a }, rets = a_tuple { } }, }, } end), }, ["coroutine"] = a_record { fields = { - ["create"] = a_type { typename = "function", args = TUPLE { FUNCTION }, rets = TUPLE { THREAD } }, - ["close"] = a_type { typename = "function", args = TUPLE { THREAD }, rets = TUPLE { BOOLEAN, STRING } }, - ["isyieldable"] = a_type { typename = "function", args = TUPLE {}, rets = TUPLE { BOOLEAN } }, - ["resume"] = a_type { typename = "function", args = VARARG { THREAD, ANY }, rets = VARARG { BOOLEAN, ANY } }, - ["running"] = a_type { typename = "function", args = TUPLE {}, rets = TUPLE { THREAD, BOOLEAN } }, - ["status"] = a_type { typename = "function", args = TUPLE { THREAD }, rets = TUPLE { STRING } }, - ["wrap"] = a_type { typename = "function", args = TUPLE { FUNCTION }, rets = TUPLE { FUNCTION } }, - ["yield"] = a_type { typename = "function", args = VARARG { ANY }, rets = VARARG { ANY } }, + ["create"] = a_function { args = a_tuple { FUNCTION }, rets = a_tuple { THREAD } }, + ["close"] = a_function { args = a_tuple { THREAD }, rets = a_tuple { BOOLEAN, STRING } }, + ["isyieldable"] = a_function { args = a_tuple {}, rets = a_tuple { BOOLEAN } }, + ["resume"] = a_function { args = a_vararg { THREAD, ANY }, rets = a_vararg { BOOLEAN, ANY } }, + ["running"] = a_function { args = a_tuple {}, rets = a_tuple { THREAD, BOOLEAN } }, + ["status"] = a_function { args = a_tuple { THREAD }, rets = a_tuple { STRING } }, + ["wrap"] = a_function { args = a_tuple { FUNCTION }, rets = a_tuple { FUNCTION } }, + ["yield"] = a_function { args = a_vararg { ANY }, rets = a_vararg { ANY } }, } }, ["debug"] = a_record { fields = { - ["Info"] = a_type { - typename = "typetype", - def = DEBUG_GETINFO_TABLE, + ["Info"] = a_typetype { def = DEBUG_GETINFO_TABLE }, + ["Hook"] = a_typetype { def = DEBUG_HOOK_FUNCTION }, + ["HookEvent"] = a_typetype { def = DEBUG_HOOK_EVENT }, + + ["debug"] = a_function { args = a_tuple {}, rets = a_tuple {} }, + ["gethook"] = a_function { args = a_tuple { OPT(THREAD) }, rets = a_tuple { DEBUG_HOOK_FUNCTION, INTEGER } }, + ["getlocal"] = a_poly { + a_function { args = a_tuple { THREAD, FUNCTION, NUMBER }, rets = STRING }, + a_function { args = a_tuple { THREAD, NUMBER, NUMBER }, rets = a_tuple { STRING, ANY } }, + a_function { args = a_tuple { FUNCTION, NUMBER }, rets = STRING }, + a_function { args = a_tuple { NUMBER, NUMBER }, rets = a_tuple { STRING, ANY } }, }, - ["Hook"] = a_type { - typename = "typetype", - def = DEBUG_HOOK_FUNCTION, + ["getmetatable"] = a_gfunction(1, function(a: Type): Type return { args = a_tuple { a }, rets = a_tuple { METATABLE(a) } } end), + ["getregistry"] = a_function { args = a_tuple {}, rets = a_tuple { TABLE } }, + ["getupvalue"] = a_function { args = a_tuple { FUNCTION, NUMBER }, rets = a_tuple { ANY } }, + ["getuservalue"] = a_function { args = a_tuple { USERDATA, NUMBER }, rets = a_tuple { ANY } }, + ["sethook"] = a_poly { + a_function { args = a_tuple { THREAD, DEBUG_HOOK_FUNCTION, STRING, NUMBER }, rets = a_tuple {} }, + a_function { args = a_tuple { DEBUG_HOOK_FUNCTION, STRING, NUMBER }, rets = a_tuple {} }, }, - ["HookEvent"] = a_type { - typename = "typetype", - def = DEBUG_HOOK_EVENT, + ["setlocal"] = a_poly { + a_function { args = a_tuple { THREAD, NUMBER, NUMBER, ANY }, rets = a_tuple { STRING } }, + a_function { args = a_tuple { NUMBER, NUMBER, ANY }, rets = a_tuple { STRING } }, }, - - ["debug"] = a_type { typename = "function", args = TUPLE {}, rets = TUPLE {} }, - ["gethook"] = a_type { typename = "function", args = TUPLE { OPT(THREAD) }, rets = TUPLE { DEBUG_HOOK_FUNCTION, INTEGER } }, - ["getlocal"] = a_type { - typename = "poly", - types = { - a_type { typename = "function", args = TUPLE { THREAD, FUNCTION, NUMBER }, rets = STRING }, - a_type { typename = "function", args = TUPLE { THREAD, NUMBER, NUMBER }, rets = TUPLE { STRING, ANY } }, - a_type { typename = "function", args = TUPLE { FUNCTION, NUMBER }, rets = STRING }, - a_type { typename = "function", args = TUPLE { NUMBER, NUMBER }, rets = TUPLE { STRING, ANY } }, - }, + ["setmetatable"] = a_gfunction(1, function(a: Type): Type return { args = a_tuple { a, METATABLE(a) }, rets = a_tuple { a } } end), + ["setupvalue"] = a_function { args = a_tuple { FUNCTION, NUMBER, ANY }, rets = a_tuple { STRING } }, + ["setuservalue"] = a_function { args = a_tuple { USERDATA, ANY, NUMBER }, rets = a_tuple { USERDATA } }, + ["traceback"] = a_poly { + a_function { args = a_tuple { OPT(THREAD), OPT(STRING), OPT(NUMBER) }, rets = a_tuple { STRING } }, + a_function { args = a_tuple { OPT(STRING), OPT(NUMBER) }, rets = a_tuple { STRING } }, }, - ["getmetatable"] = a_gfunction(1, function(a: Type): Type return { args = TUPLE { a }, rets = TUPLE { METATABLE(a) } } end), - ["getregistry"] = a_type { typename = "function", args = TUPLE {}, rets = TUPLE { TABLE } }, - ["getupvalue"] = a_type { typename = "function", args = TUPLE { FUNCTION, NUMBER }, rets = TUPLE { ANY } }, - ["getuservalue"] = a_type { typename = "function", args = TUPLE { USERDATA, NUMBER }, rets = TUPLE { ANY } }, - ["sethook"] = a_type { - typename = "poly", - types = { - a_type { typename = "function", args = TUPLE { THREAD, DEBUG_HOOK_FUNCTION, STRING, NUMBER }, rets = TUPLE {} }, - a_type { typename = "function", args = TUPLE { DEBUG_HOOK_FUNCTION, STRING, NUMBER }, rets = TUPLE {} }, - } - }, - ["setlocal"] = a_type { - typename = "poly", - types = { - a_type { typename = "function", args = TUPLE { THREAD, NUMBER, NUMBER, ANY }, rets = TUPLE { STRING } }, - a_type { typename = "function", args = TUPLE { NUMBER, NUMBER, ANY }, rets = TUPLE { STRING } }, - } - }, - ["setmetatable"] = a_gfunction(1, function(a: Type): Type return { args = TUPLE { a, METATABLE(a) }, rets = TUPLE { a } } end), - ["setupvalue"] = a_type { typename = "function", args = TUPLE { FUNCTION, NUMBER, ANY }, rets = TUPLE { STRING } }, - ["setuservalue"] = a_type { typename = "function", args = TUPLE { USERDATA, ANY, NUMBER }, rets = TUPLE { USERDATA } }, - ["traceback"] = a_type { - typename = "poly", - types = { - a_type { typename = "function", args = TUPLE { OPT(THREAD), OPT(STRING), OPT(NUMBER) }, rets = TUPLE { STRING } }, - a_type { typename = "function", args = TUPLE { OPT(STRING), OPT(NUMBER) }, rets = TUPLE { STRING } }, - }, - }, - ["upvalueid"] = a_type { typename = "function", args = TUPLE { FUNCTION, NUMBER }, rets = TUPLE { USERDATA } }, - ["upvaluejoin"] = a_type { typename = "function", args = TUPLE { FUNCTION, NUMBER, FUNCTION, NUMBER }, rets = TUPLE {} }, - ["getinfo"] = a_type { - typename = "poly", - types = { - a_type { typename = "function", args = TUPLE { ANY }, rets = TUPLE { DEBUG_GETINFO_TABLE } }, - a_type { typename = "function", args = TUPLE { ANY, STRING }, rets = TUPLE { DEBUG_GETINFO_TABLE } }, - a_type { typename = "function", args = TUPLE { ANY, ANY, STRING }, rets = TUPLE { DEBUG_GETINFO_TABLE } }, - }, + ["upvalueid"] = a_function { args = a_tuple { FUNCTION, NUMBER }, rets = a_tuple { USERDATA } }, + ["upvaluejoin"] = a_function { args = a_tuple { FUNCTION, NUMBER, FUNCTION, NUMBER }, rets = a_tuple {} }, + ["getinfo"] = a_poly { + a_function { args = a_tuple { ANY }, rets = a_tuple { DEBUG_GETINFO_TABLE } }, + a_function { args = a_tuple { ANY, STRING }, rets = a_tuple { DEBUG_GETINFO_TABLE } }, + a_function { args = a_tuple { ANY, ANY, STRING }, rets = a_tuple { DEBUG_GETINFO_TABLE } }, }, }, }, ["io"] = a_record { fields = { - ["close"] = a_type { typename = "function", args = TUPLE { OPT(NOMINAL_FILE) }, rets = TUPLE { BOOLEAN, STRING } }, - ["flush"] = a_type { typename = "function", args = TUPLE {}, rets = TUPLE {} }, - ["input"] = a_type { typename = "function", args = TUPLE { OPT(UNION { STRING, NOMINAL_FILE }) }, rets = TUPLE { NOMINAL_FILE } }, + ["close"] = a_function { args = a_tuple { OPT(NOMINAL_FILE) }, rets = a_tuple { BOOLEAN, STRING } }, + ["flush"] = a_function { args = a_tuple {}, rets = a_tuple {} }, + ["input"] = a_function { args = a_tuple { OPT(a_union { STRING, NOMINAL_FILE }) }, rets = a_tuple { NOMINAL_FILE } }, ["lines"] = a_file_reader(function(ctor: TypeConstructor, args: {Type}, rets: {Type}): Type - return a_type { typename = "function", args = ctor(args), rets = TUPLE { - a_type { typename = "function", args = TUPLE {}, rets = ctor(rets) }, + return a_function { args = ctor(args), rets = a_tuple { + a_function { args = a_tuple {}, rets = ctor(rets) }, } } end), - ["open"] = a_type { typename = "function", args = TUPLE { STRING, OPT(STRING) }, rets = TUPLE { NOMINAL_FILE, STRING } }, - ["output"] = a_type { typename = "function", args = TUPLE { OPT(UNION { STRING, NOMINAL_FILE }) }, rets = TUPLE { NOMINAL_FILE } }, - ["popen"] = a_type { typename = "function", args = TUPLE { STRING, OPT(STRING) }, rets = TUPLE { NOMINAL_FILE, STRING } }, + ["open"] = a_function { args = a_tuple { STRING, OPT(STRING) }, rets = a_tuple { NOMINAL_FILE, STRING } }, + ["output"] = a_function { args = a_tuple { OPT(a_union { STRING, NOMINAL_FILE }) }, rets = a_tuple { NOMINAL_FILE } }, + ["popen"] = a_function { args = a_tuple { STRING, OPT(STRING) }, rets = a_tuple { NOMINAL_FILE, STRING } }, ["read"] = a_file_reader(function(ctor: TypeConstructor, args: {Type}, rets: {Type}): Type - return a_type { typename = "function", args = ctor(args), rets = ctor(rets) } + return a_function { args = ctor(args), rets = ctor(rets) } end), ["stderr"] = NOMINAL_FILE, ["stdin"] = NOMINAL_FILE, ["stdout"] = NOMINAL_FILE, - ["tmpfile"] = a_type { typename = "function", args = TUPLE {}, rets = TUPLE { NOMINAL_FILE } }, - ["type"] = a_type { typename = "function", args = TUPLE { ANY }, rets = TUPLE { STRING } }, - ["write"] = a_type { typename = "function", args = VARARG { UNION { STRING, NUMBER } }, rets = TUPLE { NOMINAL_FILE, STRING } }, + ["tmpfile"] = a_function { args = a_tuple {}, rets = a_tuple { NOMINAL_FILE } }, + ["type"] = a_function { args = a_tuple { ANY }, rets = a_tuple { STRING } }, + ["write"] = a_function { args = a_vararg { a_union { STRING, NUMBER } }, rets = a_tuple { NOMINAL_FILE, STRING } }, }, }, ["math"] = a_record { fields = { - ["abs"] = a_type { - typename = "poly", - types = { - a_type { typename = "function", args = TUPLE { INTEGER }, rets = TUPLE { INTEGER } }, - a_type { typename = "function", args = TUPLE { NUMBER }, rets = TUPLE { NUMBER } }, - } + ["abs"] = a_poly { + a_function { args = a_tuple { INTEGER }, rets = a_tuple { INTEGER } }, + a_function { args = a_tuple { NUMBER }, rets = a_tuple { NUMBER } }, }, - ["acos"] = a_type { typename = "function", args = TUPLE { NUMBER }, rets = TUPLE { NUMBER } }, - ["asin"] = a_type { typename = "function", args = TUPLE { NUMBER }, rets = TUPLE { NUMBER } }, - ["atan"] = a_type { typename = "function", args = TUPLE { NUMBER, OPT(NUMBER) }, rets = TUPLE { NUMBER } }, - ["atan2"] = a_type { typename = "function", args = TUPLE { NUMBER, NUMBER }, rets = TUPLE { NUMBER } }, - ["ceil"] = a_type { typename = "function", args = TUPLE { NUMBER }, rets = TUPLE { INTEGER } }, - ["cos"] = a_type { typename = "function", args = TUPLE { NUMBER }, rets = TUPLE { NUMBER } }, - ["cosh"] = a_type { typename = "function", args = TUPLE { NUMBER }, rets = TUPLE { NUMBER } }, - ["deg"] = a_type { typename = "function", args = TUPLE { NUMBER }, rets = TUPLE { NUMBER } }, - ["exp"] = a_type { typename = "function", args = TUPLE { NUMBER }, rets = TUPLE { NUMBER } }, - ["floor"] = a_type { typename = "function", args = TUPLE { NUMBER }, rets = TUPLE { INTEGER } }, - ["fmod"] = a_type { - typename = "poly", - types = { - a_type { typename = "function", args = TUPLE { INTEGER, INTEGER }, rets = TUPLE { INTEGER } }, - a_type { typename = "function", args = TUPLE { NUMBER, NUMBER }, rets = TUPLE { NUMBER } }, - } + ["acos"] = a_function { args = a_tuple { NUMBER }, rets = a_tuple { NUMBER } }, + ["asin"] = a_function { args = a_tuple { NUMBER }, rets = a_tuple { NUMBER } }, + ["atan"] = a_function { args = a_tuple { NUMBER, OPT(NUMBER) }, rets = a_tuple { NUMBER } }, + ["atan2"] = a_function { args = a_tuple { NUMBER, NUMBER }, rets = a_tuple { NUMBER } }, + ["ceil"] = a_function { args = a_tuple { NUMBER }, rets = a_tuple { INTEGER } }, + ["cos"] = a_function { args = a_tuple { NUMBER }, rets = a_tuple { NUMBER } }, + ["cosh"] = a_function { args = a_tuple { NUMBER }, rets = a_tuple { NUMBER } }, + ["deg"] = a_function { args = a_tuple { NUMBER }, rets = a_tuple { NUMBER } }, + ["exp"] = a_function { args = a_tuple { NUMBER }, rets = a_tuple { NUMBER } }, + ["floor"] = a_function { args = a_tuple { NUMBER }, rets = a_tuple { INTEGER } }, + ["fmod"] = a_poly { + a_function { args = a_tuple { INTEGER, INTEGER }, rets = a_tuple { INTEGER } }, + a_function { args = a_tuple { NUMBER, NUMBER }, rets = a_tuple { NUMBER } }, }, - ["frexp"] = a_type { typename = "function", args = TUPLE { NUMBER }, rets = TUPLE { NUMBER, NUMBER } }, + ["frexp"] = a_function { args = a_tuple { NUMBER }, rets = a_tuple { NUMBER, NUMBER } }, ["huge"] = NUMBER, - ["ldexp"] = a_type { typename = "function", args = TUPLE { NUMBER, NUMBER }, rets = TUPLE { NUMBER } }, - ["log"] = a_type { typename = "function", args = TUPLE { NUMBER, OPT(NUMBER) }, rets = TUPLE { NUMBER } }, - ["log10"] = a_type { typename = "function", args = TUPLE { NUMBER }, rets = TUPLE { NUMBER } }, - ["max"] = a_type { - typename = "poly", - types = { - a_type { typename = "function", args = VARARG { INTEGER }, rets = TUPLE { INTEGER } }, - a_gfunction(1, function(a: Type): Type return { args = VARARG { a }, rets = TUPLE { a } } end), - a_type { typename = "function", args = VARARG { a_type { typename = "union", types = { NUMBER, INTEGER } } }, rets = TUPLE { NUMBER } }, - a_type { typename = "function", args = VARARG { ANY }, rets = TUPLE { ANY } }, - } + ["ldexp"] = a_function { args = a_tuple { NUMBER, NUMBER }, rets = a_tuple { NUMBER } }, + ["log"] = a_function { args = a_tuple { NUMBER, OPT(NUMBER) }, rets = a_tuple { NUMBER } }, + ["log10"] = a_function { args = a_tuple { NUMBER }, rets = a_tuple { NUMBER } }, + ["max"] = a_poly { + a_function { args = a_vararg { INTEGER }, rets = a_tuple { INTEGER } }, + a_gfunction(1, function(a: Type): Type return { args = a_vararg { a }, rets = a_tuple { a } } end), + a_function { args = a_vararg { a_union { NUMBER, INTEGER } }, rets = a_tuple { NUMBER } }, + a_function { args = a_vararg { ANY }, rets = a_tuple { ANY } }, }, - ["maxinteger"] = a_type { typename = "integer", needs_compat = true }, - ["min"] = a_type { - typename = "poly", - types = { - a_type { typename = "function", args = VARARG { INTEGER }, rets = TUPLE { INTEGER } }, - a_gfunction(1, function(a: Type): Type return { args = VARARG { a }, rets = TUPLE { a } } end), - a_type { typename = "function", args = VARARG { a_type { typename = "union", types = { NUMBER, INTEGER } } }, rets = TUPLE { NUMBER } }, - a_type { typename = "function", args = VARARG { ANY }, rets = TUPLE { ANY } }, - } + ["maxinteger"] = a_type("integer", { needs_compat = true }), + ["min"] = a_poly { + a_function { args = a_vararg { INTEGER }, rets = a_tuple { INTEGER } }, + a_gfunction(1, function(a: Type): Type return { args = a_vararg { a }, rets = a_tuple { a } } end), + a_function { args = a_vararg { a_union { NUMBER, INTEGER } }, rets = a_tuple { NUMBER } }, + a_function { args = a_vararg { ANY }, rets = a_tuple { ANY } }, }, - ["mininteger"] = a_type { typename = "integer", needs_compat = true }, - ["modf"] = a_type { typename = "function", args = TUPLE { NUMBER }, rets = TUPLE { INTEGER, NUMBER } }, + ["mininteger"] = a_type("integer", { needs_compat = true }), + ["modf"] = a_function { args = a_tuple { NUMBER }, rets = a_tuple { INTEGER, NUMBER } }, ["pi"] = NUMBER, - ["pow"] = a_type { typename = "function", args = TUPLE { NUMBER, NUMBER }, rets = TUPLE { NUMBER } }, - ["rad"] = a_type { typename = "function", args = TUPLE { NUMBER }, rets = TUPLE { NUMBER } }, - ["random"] = a_type { - typename = "poly", - types = { - a_type { typename = "function", args = TUPLE { NUMBER, OPT(NUMBER) }, rets = TUPLE { INTEGER } }, - a_type { typename = "function", args = TUPLE {}, rets = TUPLE { NUMBER } }, - } + ["pow"] = a_function { args = a_tuple { NUMBER, NUMBER }, rets = a_tuple { NUMBER } }, + ["rad"] = a_function { args = a_tuple { NUMBER }, rets = a_tuple { NUMBER } }, + ["random"] = a_poly { + a_function { args = a_tuple { NUMBER, OPT(NUMBER) }, rets = a_tuple { INTEGER } }, + a_function { args = a_tuple {}, rets = a_tuple { NUMBER } }, }, - ["randomseed"] = a_type { typename = "function", args = TUPLE { NUMBER, NUMBER }, rets = TUPLE { INTEGER, INTEGER } }, - ["sin"] = a_type { typename = "function", args = TUPLE { NUMBER }, rets = TUPLE { NUMBER } }, - ["sinh"] = a_type { typename = "function", args = TUPLE { NUMBER }, rets = TUPLE { NUMBER } }, - ["sqrt"] = a_type { typename = "function", args = TUPLE { NUMBER }, rets = TUPLE { NUMBER } }, - ["tan"] = a_type { typename = "function", args = TUPLE { NUMBER }, rets = TUPLE { NUMBER } }, - ["tanh"] = a_type { typename = "function", args = TUPLE { NUMBER }, rets = TUPLE { NUMBER } }, - ["tointeger"] = a_type { typename = "function", args = TUPLE { ANY }, rets = TUPLE { INTEGER } }, - ["type"] = a_type { typename = "function", args = TUPLE { ANY }, rets = TUPLE { STRING } }, - ["ult"] = a_type { typename = "function", args = TUPLE { NUMBER, NUMBER }, rets = TUPLE { BOOLEAN } }, + ["randomseed"] = a_function { args = a_tuple { NUMBER, NUMBER }, rets = a_tuple { INTEGER, INTEGER } }, + ["sin"] = a_function { args = a_tuple { NUMBER }, rets = a_tuple { NUMBER } }, + ["sinh"] = a_function { args = a_tuple { NUMBER }, rets = a_tuple { NUMBER } }, + ["sqrt"] = a_function { args = a_tuple { NUMBER }, rets = a_tuple { NUMBER } }, + ["tan"] = a_function { args = a_tuple { NUMBER }, rets = a_tuple { NUMBER } }, + ["tanh"] = a_function { args = a_tuple { NUMBER }, rets = a_tuple { NUMBER } }, + ["tointeger"] = a_function { args = a_tuple { ANY }, rets = a_tuple { INTEGER } }, + ["type"] = a_function { args = a_tuple { ANY }, rets = a_tuple { STRING } }, + ["ult"] = a_function { args = a_tuple { NUMBER, NUMBER }, rets = a_tuple { BOOLEAN } }, }, }, ["os"] = a_record { fields = { - ["clock"] = a_type { typename = "function", args = TUPLE {}, rets = TUPLE { NUMBER } }, - ["date"] = a_type { - typename = "poly", - types = { - a_type { typename = "function", args = TUPLE { }, rets = TUPLE { STRING } }, - a_type { typename = "function", args = TUPLE { an_enum { "!*t", "*t" }, OPT(NUMBER) }, rets = TUPLE { OS_DATE_TABLE } }, - a_type { typename = "function", args = TUPLE { OPT(STRING), OPT(NUMBER) }, rets = TUPLE { STRING } }, - } + ["clock"] = a_function { args = a_tuple {}, rets = a_tuple { NUMBER } }, + ["date"] = a_poly { + a_function { args = a_tuple { }, rets = a_tuple { STRING } }, + a_function { args = a_tuple { an_enum { "!*t", "*t" }, OPT(NUMBER) }, rets = a_tuple { OS_DATE_TABLE } }, + a_function { args = a_tuple { OPT(STRING), OPT(NUMBER) }, rets = a_tuple { STRING } }, }, - ["difftime"] = a_type { typename = "function", args = TUPLE { NUMBER, NUMBER }, rets = TUPLE { NUMBER } }, - ["execute"] = a_type { typename = "function", args = TUPLE { STRING }, rets = TUPLE { BOOLEAN, STRING, INTEGER } }, - ["exit"] = a_type { typename = "function", args = TUPLE { OPT(UNION { NUMBER, BOOLEAN }), OPT(BOOLEAN) }, rets = TUPLE {} }, - ["getenv"] = a_type { typename = "function", args = TUPLE { STRING }, rets = TUPLE { STRING } }, - ["remove"] = a_type { typename = "function", args = TUPLE { STRING }, rets = TUPLE { BOOLEAN, STRING } }, - ["rename"] = a_type { typename = "function", args = TUPLE { STRING, STRING}, rets = TUPLE { BOOLEAN, STRING } }, - ["setlocale"] = a_type { typename = "function", args = TUPLE { STRING, OPT(STRING) }, rets = TUPLE { STRING } }, - ["time"] = a_type { typename = "function", args = TUPLE { OPT(OS_DATE_TABLE) }, rets = TUPLE { INTEGER } }, - ["tmpname"] = a_type { typename = "function", args = TUPLE {}, rets = TUPLE { STRING } }, + ["difftime"] = a_function { args = a_tuple { NUMBER, NUMBER }, rets = a_tuple { NUMBER } }, + ["execute"] = a_function { args = a_tuple { STRING }, rets = a_tuple { BOOLEAN, STRING, INTEGER } }, + ["exit"] = a_function { args = a_tuple { OPT(a_union { NUMBER, BOOLEAN }), OPT(BOOLEAN) }, rets = a_tuple {} }, + ["getenv"] = a_function { args = a_tuple { STRING }, rets = a_tuple { STRING } }, + ["remove"] = a_function { args = a_tuple { STRING }, rets = a_tuple { BOOLEAN, STRING } }, + ["rename"] = a_function { args = a_tuple { STRING, STRING}, rets = a_tuple { BOOLEAN, STRING } }, + ["setlocale"] = a_function { args = a_tuple { STRING, OPT(STRING) }, rets = a_tuple { STRING } }, + ["time"] = a_function { args = a_tuple { OPT(OS_DATE_TABLE) }, rets = a_tuple { INTEGER } }, + ["tmpname"] = a_function { args = a_tuple {}, rets = a_tuple { STRING } }, }, }, ["package"] = a_record { fields = { ["config"] = STRING, ["cpath"] = STRING, - ["loaded"] = a_type { - typename = "map", - keys = STRING, - values = ANY, - }, - ["loaders"] = a_type { - typename = "array", - elements = a_type { typename = "function", args = TUPLE { STRING }, rets = TUPLE { ANY, ANY } } - }, - ["loadlib"] = a_type { typename = "function", args = TUPLE { STRING, STRING }, rets = TUPLE { FUNCTION } }, + ["loaded"] = a_map(STRING, ANY), + ["loaders"] = an_array(a_function { args = a_tuple { STRING }, rets = a_tuple { ANY, ANY } }), + ["loadlib"] = a_function { args = a_tuple { STRING, STRING }, rets = a_tuple { FUNCTION } }, ["path"] = STRING, ["preload"] = TABLE, - ["searchers"] = a_type { - typename = "array", - elements = a_type { typename = "function", args = TUPLE { STRING }, rets = TUPLE { ANY, ANY } } - }, - ["searchpath"] = a_type { typename = "function", args = TUPLE { STRING, STRING, OPT(STRING), OPT(STRING) }, rets = TUPLE { STRING, STRING } }, + ["searchers"] = an_array(a_function { args = a_tuple { STRING }, rets = a_tuple { ANY, ANY } }), + ["searchpath"] = a_function { args = a_tuple { STRING, STRING, OPT(STRING), OPT(STRING) }, rets = a_tuple { STRING, STRING } }, }, }, ["string"] = a_record { fields = { - ["byte"] = a_type { - typename = "poly", - types = { - a_type { typename = "function", args = TUPLE { STRING, OPT(NUMBER) }, rets = TUPLE { INTEGER } }, - a_type { typename = "function", args = TUPLE { STRING, NUMBER, NUMBER }, rets = VARARG { INTEGER } }, - }, + ["byte"] = a_poly { + a_function { args = a_tuple { STRING, OPT(NUMBER) }, rets = a_tuple { INTEGER } }, + a_function { args = a_tuple { STRING, NUMBER, NUMBER }, rets = a_vararg { INTEGER } }, }, - ["char"] = a_type { typename = "function", args = VARARG { NUMBER }, rets = TUPLE { STRING } }, - ["dump"] = a_type({ typename = "function", args = TUPLE { FUNCTION, OPT(BOOLEAN) }, rets = TUPLE { STRING } }), - ["find"] = a_type { typename = "function", args = TUPLE { STRING, STRING, OPT(NUMBER), OPT(BOOLEAN) }, rets = VARARG { INTEGER, INTEGER, STRING } }, - ["format"] = a_type { typename = "function", args = VARARG { STRING, ANY }, rets = TUPLE { STRING } }, - ["gmatch"] = a_type { typename = "function", args = TUPLE { STRING, STRING }, rets = TUPLE { - a_type { typename = "function", args = TUPLE {}, rets = VARARG { STRING } }, + ["char"] = a_function { args = a_vararg { NUMBER }, rets = a_tuple { STRING } }, + ["dump"] = a_function { args = a_tuple { FUNCTION, OPT(BOOLEAN) }, rets = a_tuple { STRING } }, + ["find"] = a_function { args = a_tuple { STRING, STRING, OPT(NUMBER), OPT(BOOLEAN) }, rets = a_vararg { INTEGER, INTEGER, STRING } }, + ["format"] = a_function { args = a_vararg { STRING, ANY }, rets = a_tuple { STRING } }, + ["gmatch"] = a_function { args = a_tuple { STRING, STRING }, rets = a_tuple { + a_function { args = a_tuple {}, rets = a_vararg { STRING } }, } }, - ["gsub"] = a_type { - typename = "poly", - types = { - a_type { typename = "function", args = TUPLE { STRING, STRING, a_type { typename = "map", keys = STRING, values = STRING }, OPT(NUMBER) }, rets = TUPLE { STRING, INTEGER } }, - a_type { typename = "function", args = TUPLE { STRING, STRING, a_type { typename = "function", args = VARARG { STRING }, rets = TUPLE { STRING } }, OPT(NUMBER) }, rets = TUPLE { STRING, INTEGER } }, - a_type { typename = "function", args = TUPLE { STRING, STRING, a_type { typename = "function", args = VARARG { STRING }, rets = TUPLE { NUMBER } }, OPT(NUMBER) }, rets = TUPLE { STRING, INTEGER } }, - a_type { typename = "function", args = TUPLE { STRING, STRING, a_type { typename = "function", args = VARARG { STRING }, rets = TUPLE { BOOLEAN } }, OPT(NUMBER) }, rets = TUPLE { STRING, INTEGER } }, - a_type { typename = "function", args = TUPLE { STRING, STRING, a_type { typename = "function", args = VARARG { STRING }, rets = TUPLE {} }, OPT(NUMBER) }, rets = TUPLE { STRING, INTEGER } }, - a_type { typename = "function", args = TUPLE { STRING, STRING, OPT(STRING), OPT(NUMBER) }, rets = TUPLE { STRING, INTEGER } }, - -- FIXME any other modes - } + ["gsub"] = a_poly { + a_function { args = a_tuple { STRING, STRING, a_map(STRING, STRING), OPT(NUMBER) }, rets = a_tuple { STRING, INTEGER } }, + a_function { args = a_tuple { STRING, STRING, a_function { args = a_vararg { STRING }, rets = a_tuple { STRING } }, OPT(NUMBER) }, rets = a_tuple { STRING, INTEGER } }, + a_function { args = a_tuple { STRING, STRING, a_function { args = a_vararg { STRING }, rets = a_tuple { NUMBER } }, OPT(NUMBER) }, rets = a_tuple { STRING, INTEGER } }, + a_function { args = a_tuple { STRING, STRING, a_function { args = a_vararg { STRING }, rets = a_tuple { BOOLEAN } }, OPT(NUMBER) }, rets = a_tuple { STRING, INTEGER } }, + a_function { args = a_tuple { STRING, STRING, a_function { args = a_vararg { STRING }, rets = a_tuple {} }, OPT(NUMBER) }, rets = a_tuple { STRING, INTEGER } }, + a_function { args = a_tuple { STRING, STRING, OPT(STRING), OPT(NUMBER) }, rets = a_tuple { STRING, INTEGER } }, + -- FIXME any other modes }, - ["len"] = a_type { typename = "function", args = TUPLE { STRING }, rets = TUPLE { INTEGER } }, - ["lower"] = a_type { typename = "function", args = TUPLE { STRING }, rets = TUPLE { STRING } }, - ["match"] = a_type { typename = "function", args = TUPLE { STRING, OPT(STRING), OPT(NUMBER) }, rets = VARARG { STRING } }, - ["pack"] = a_type { typename = "function", args = VARARG { STRING, ANY }, rets = TUPLE { STRING } }, - ["packsize"] = a_type { typename = "function", args = TUPLE { STRING }, rets = TUPLE { INTEGER } }, - ["rep"] = a_type { typename = "function", args = TUPLE { STRING, NUMBER, OPT(STRING) }, rets = TUPLE { STRING } }, - ["reverse"] = a_type { typename = "function", args = TUPLE { STRING }, rets = TUPLE { STRING } }, - ["sub"] = a_type { typename = "function", args = TUPLE { STRING, NUMBER, OPT(NUMBER) }, rets = TUPLE { STRING } }, - ["unpack"] = a_type { typename = "function", args = TUPLE { STRING, STRING, OPT(NUMBER) }, rets = VARARG { ANY } }, - ["upper"] = a_type { typename = "function", args = TUPLE { STRING }, rets = TUPLE { STRING } }, + ["len"] = a_function { args = a_tuple { STRING }, rets = a_tuple { INTEGER } }, + ["lower"] = a_function { args = a_tuple { STRING }, rets = a_tuple { STRING } }, + ["match"] = a_function { args = a_tuple { STRING, OPT(STRING), OPT(NUMBER) }, rets = a_vararg { STRING } }, + ["pack"] = a_function { args = a_vararg { STRING, ANY }, rets = a_tuple { STRING } }, + ["packsize"] = a_function { args = a_tuple { STRING }, rets = a_tuple { INTEGER } }, + ["rep"] = a_function { args = a_tuple { STRING, NUMBER, OPT(STRING) }, rets = a_tuple { STRING } }, + ["reverse"] = a_function { args = a_tuple { STRING }, rets = a_tuple { STRING } }, + ["sub"] = a_function { args = a_tuple { STRING, NUMBER, OPT(NUMBER) }, rets = a_tuple { STRING } }, + ["unpack"] = a_function { args = a_tuple { STRING, STRING, OPT(NUMBER) }, rets = a_vararg { ANY } }, + ["upper"] = a_function { args = a_tuple { STRING }, rets = a_tuple { STRING } }, }, }, ["table"] = a_record { fields = { - ["concat"] = a_type { typename = "function", args = TUPLE { ARRAY(UNION {STRING, NUMBER }), OPT(STRING), OPT(NUMBER), OPT(NUMBER) }, rets = TUPLE { STRING } }, - ["insert"] = a_type { - typename = "poly", - types = { - a_gfunction(1, function(a: Type): Type return { args = TUPLE { ARRAY(a), NUMBER, a }, rets = TUPLE {} } end), - a_gfunction(1, function(a: Type): Type return { args = TUPLE { ARRAY(a), a }, rets = TUPLE {} } end), - } + ["concat"] = a_function { args = a_tuple { an_array(a_union {STRING, NUMBER }), OPT(STRING), OPT(NUMBER), OPT(NUMBER) }, rets = a_tuple { STRING } }, + ["insert"] = a_poly { + a_gfunction(1, function(a: Type): Type return { args = a_tuple { an_array(a), NUMBER, a }, rets = a_tuple {} } end), + a_gfunction(1, function(a: Type): Type return { args = a_tuple { an_array(a), a }, rets = a_tuple {} } end), }, - ["move"] = a_type { - typename = "poly", - types = { - a_gfunction(1, function(a: Type): Type return { args = TUPLE { ARRAY(a), NUMBER, NUMBER, NUMBER }, rets = TUPLE { ARRAY(a) } }end ), - a_gfunction(1, function(a: Type): Type return { args = TUPLE { ARRAY(a), NUMBER, NUMBER, NUMBER, ARRAY(a) }, rets = TUPLE { ARRAY(a) } } end), - } + ["move"] = a_poly { + a_gfunction(1, function(a: Type): Type return { args = a_tuple { an_array(a), NUMBER, NUMBER, NUMBER }, rets = a_tuple { an_array(a) } }end ), + a_gfunction(1, function(a: Type): Type return { args = a_tuple { an_array(a), NUMBER, NUMBER, NUMBER, an_array(a) }, rets = a_tuple { an_array(a) } } end), }, - ["pack"] = a_type { typename = "function", args = VARARG { ANY }, rets = TUPLE { TABLE } }, - ["remove"] = a_gfunction(1, function(a: Type): Type return { args = TUPLE { ARRAY(a), OPT(NUMBER) }, rets = TUPLE { a } } end), - ["sort"] = a_gfunction(1, function(a: Type): Type return { args = TUPLE { ARRAY(a), OPT(TABLE_SORT_FUNCTION) }, rets = TUPLE {} } end), - ["unpack"] = a_gfunction(1, function(a: Type): Type return { needs_compat = true, args = TUPLE { ARRAY(a), OPT(NUMBER), OPT(NUMBER) }, rets = VARARG { a } } end), + ["pack"] = a_function { args = a_vararg { ANY }, rets = a_tuple { TABLE } }, + ["remove"] = a_gfunction(1, function(a: Type): Type return { args = a_tuple { an_array(a), OPT(NUMBER) }, rets = a_tuple { a } } end), + ["sort"] = a_gfunction(1, function(a: Type): Type return { args = a_tuple { an_array(a), OPT(TABLE_SORT_FUNCTION) }, rets = a_tuple {} } end), + ["unpack"] = a_gfunction(1, function(a: Type): Type return { needs_compat = true, args = a_tuple { an_array(a), OPT(NUMBER), OPT(NUMBER) }, rets = a_vararg { a } } end), }, }, ["utf8"] = a_record { fields = { - ["char"] = a_type { typename = "function", args = VARARG { NUMBER }, rets = TUPLE { STRING } }, + ["char"] = a_function { args = a_vararg { NUMBER }, rets = a_tuple { STRING } }, ["charpattern"] = STRING, - ["codepoint"] = a_type { typename = "function", args = TUPLE { STRING, OPT(NUMBER), OPT(NUMBER) }, rets = VARARG { INTEGER } }, - ["codes"] = a_type { typename = "function", args = TUPLE { STRING }, rets = TUPLE { - a_type { typename = "function", args = TUPLE { STRING, OPT(NUMBER) }, rets = TUPLE { NUMBER, NUMBER } }, + ["codepoint"] = a_function { args = a_tuple { STRING, OPT(NUMBER), OPT(NUMBER) }, rets = a_vararg { INTEGER } }, + ["codes"] = a_function { args = a_tuple { STRING }, rets = a_tuple { + a_function { args = a_tuple { STRING, OPT(NUMBER) }, rets = a_tuple { NUMBER, NUMBER } }, }, }, - ["len"] = a_type { typename = "function", args = TUPLE { STRING, NUMBER, NUMBER }, rets = TUPLE { INTEGER } }, - ["offset"] = a_type { typename = "function", args = TUPLE { STRING, NUMBER, NUMBER }, rets = TUPLE { INTEGER } }, + ["len"] = a_function { args = a_tuple { STRING, NUMBER, NUMBER }, rets = a_tuple { INTEGER } }, + ["offset"] = a_function { args = a_tuple { STRING, NUMBER, NUMBER }, rets = a_tuple { INTEGER } }, }, }, ["_VERSION"] = STRING, @@ -6230,7 +6154,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end if opts.module_name then - env.modules[opts.module_name] = a_type { typename = "typetype", def = CIRCULAR_REQUIRE } + env.modules[opts.module_name] = a_typetype { def = CIRCULAR_REQUIRE } end local lax = opts.lax @@ -6308,17 +6232,15 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local resolve_typevars: function (typ: Type, fn_var?: ResolveType, fn_arg?: ResolveType): boolean, Type, {Error} local function fresh_typevar(t: Type): Type, Type, boolean - return a_type { - typename = "typevar", + return a_type("typevar", { typevar = (t.typevar:gsub("@.*", "")) .. "@" .. fresh_typevar_ctr - } + }) end local function fresh_typearg(t: Type): Type - return a_type { - typename = "typearg", + return a_type("typearg", { typearg = (t.typearg:gsub("@.*", "")) .. "@" .. fresh_typevar_ctr - } + }) end local function ensure_fresh_typeargs(t: Type): Type @@ -7051,13 +6973,12 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string unresolved = find_var_type("@unresolved") end if not unresolved then - unresolved = { - typename = "unresolved", + unresolved = a_type("unresolved", { labels = {}, nominals = {}, global_types = {}, narrows = {}, - } + }) add_var(nil, "@unresolved", unresolved) end return unresolved @@ -7179,7 +7100,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end if not resolved then - resolved = a_type { typename = "bad_nominal", names = t.names } + resolved = a_type("bad_nominal", { names = t.names }) end if not t.filename then @@ -7319,10 +7240,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if #ts == 1 then return ts[1] else - return a_type { - typename = "union", - types = ts, - } + return a_union(ts) end end @@ -7347,19 +7265,13 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local element_type = unite(tupletype.types, true) local valid = element_type.typename ~= "union" and true or is_valid_union(element_type) if valid then - return a_type { - elements = element_type, - typename = "array", - } + return an_array(element_type) end -- failing a basic union, expand the types - local arr_type = a_type { - elements = tupletype.types[1], - typename = "array", - } + local arr_type = an_array(tupletype.types[1]) for i = 2, #tupletype.types do - arr_type = expand_type(where, arr_type, a_type { elements = tupletype.types[i], typename = "array" }) + arr_type = expand_type(where, arr_type, an_array(tupletype.types[i])) if not arr_type.elements then return nil, { Err(tupletype, "unable to convert tuple %s to array", tupletype) } end @@ -7881,7 +7793,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string ["bad_nominal"] = compare_false, ["any"] = compare_true, ["tuple"] = function(a: Type, b: Type): boolean, {Error} - return is_a(TUPLE({a}), b) + return is_a(a_tuple({a}), b) end, ["typevar"] = function(a: Type, b: Type): boolean, {Error} return compare_or_infer_typevar(b.typevar, a, nil, is_a) @@ -7963,6 +7875,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string else ok = t1.typename == t2.typename end + if (not ok) and not err then return false, { Err(t1, "got %s, expected %s", t1, t2) } end @@ -8009,9 +7922,9 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return true elseif t2.typename == "unresolved_emptytable_value" then if is_number_type(t2.emptytable_type.keys) then -- ideally integer only - infer_emptytable(t2.emptytable_type, infer_at(where, a_type { typename = "array", elements = t1 })) + infer_emptytable(t2.emptytable_type, infer_at(where, an_array(t1))) else - infer_emptytable(t2.emptytable_type, infer_at(where, a_type { typename = "map", keys = t2.emptytable_type.keys, values = t1 })) + infer_emptytable(t2.emptytable_type, infer_at(where, a_map(t2.emptytable_type.keys, t1))) end return true elseif t2.typename == "emptytable" then @@ -8081,7 +7994,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string t = resolve_tuple_and_nominal(t) local call_mt = t.meta_fields and t.meta_fields["__call"] if call_mt then - local args_tuple = a_type { typename = "tuple" } + local args_tuple = a_tuple({}) for i = 2, #call_mt.args do table.insert(args_tuple, call_mt.args[i]) end @@ -8093,7 +8006,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local function resolve_for_call(func: Type, args: {Type}, is_method: boolean): Type, boolean -- resolve unknown in lax mode, produce a general unknown function if lax and is_unknown(func) then - func = a_type { typename = "function", args = VARARG { UNKNOWN }, rets = VARARG { UNKNOWN } } + func = a_function { args = a_vararg { UNKNOWN }, rets = a_vararg { UNKNOWN } } end -- unwrap if tuple, resolve if nominal func = resolve_tuple_and_nominal(func) @@ -8218,7 +8131,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if f.typeargs then for _, a in ipairs(f.typeargs) do if not find_var_type(a.typearg) then - add_var(nil, a.typearg, lax and UNKNOWN or { typename = "unresolvable_typearg", typearg = a.typearg }) + add_var(nil, a.typearg, lax and UNKNOWN or a_type("unresolvable_typearg", { typearg = a.typearg })) end end end @@ -8312,7 +8225,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local function push_typeargs(func: Type) if func.typeargs then for _, fnarg in ipairs(func.typeargs) do - add_var(nil, fnarg.typearg, { typename = "unresolved_typearg" }) + add_var(nil, fnarg.typearg, a_type("unresolved_typearg", {})) end end end @@ -8353,7 +8266,9 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end local f = func.typename == "poly" and func.types[1] or func + mark_invalid_typeargs(f) + return resolve_typevars_at(where, f.rets) end @@ -8368,7 +8283,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string argdelta = is_method and -1 or argdelta or 0 if is_method and args[1] then - add_var(nil, "@self", a_type({ typename = "typetype", y = where.y, x = where.x, def = args[1] })) + add_var(nil, "@self", a_typetype({ y = where.y, x = where.x, def = args[1] })) end local is_func = func.typename == "function" @@ -8440,7 +8355,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string type_check_function_call = function(node: Node, where_args: {Node}, func: Type, args: {Type}, e1: Node, is_method: boolean, argdelta?: integer): Type, Type if node.expected and node.expected.typename ~= "tuple" then - node.expected = a_type { typename = "tuple", node.expected } + node.expected = a_tuple { node.expected } end begin_scope() @@ -8489,7 +8404,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if metamethod then local where_args = { node.e1 } - local args = { typename = "tuple", orig_a } + local args = a_tuple { orig_a } if b and method_name ~= "__is" then where_args[2] = node.e2 args[2] = orig_b @@ -8675,11 +8590,12 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local function get_rets(rets: {Type}): Type if lax and (#rets == 0) then - return VARARG { UNKNOWN } + return a_vararg { UNKNOWN } end local t: Type = rets as Type if not t.typename then - t = TUPLE(t) + -- what type is this? + t = a_tuple(t) end assert(t.typeid) return t @@ -8689,7 +8605,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string assert(args.typename == "tuple") add_var(nil, "@is_va", args.is_va and ANY or NIL) - add_var(nil, "@return", node.rets or a_type { typename = "tuple" }) + add_var(nil, "@return", node.rets or a_tuple({})) if node.typeargs then for _, t in ipairs(node.typeargs) do @@ -8704,14 +8620,13 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local function add_function_definition_for_recursion(node: Node, fnargs: Type) assert(fnargs.typename == "tuple") - local args: Type = TUPLE({}) + local args = a_tuple({}) args.is_va = fnargs.is_va for _, fnarg in ipairs(fnargs) do table.insert(args, fnarg) end - add_var(nil, node.name.tk, a_type { - typename = "function", + add_var(nil, node.name.tk, a_function { args = args, rets = get_rets(node.rets), }) @@ -8863,7 +8778,11 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end if is_a(orig_b, a.keys) then - return a_type { y = anode.y, x = anode.x, typename = "unresolved_emptytable_value", emptytable_type = a } + return a_type("unresolved_emptytable_value", { + y = anode.y, + x = anode.x, + emptytable_type = a + }) end errm, erra, errb = "inconsistent index type: got %s, expected %s (type of keys inferred at " @@ -8923,11 +8842,12 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string for _, ftype in fields_of(new) do old.values = expand_type(where, old.values, ftype) end + edit_type(old, "map") -- map changed, refresh typeid else error_at(where, "cannot determine table literal type") end elseif is_record_type(old) and is_record_type(new) then - old.typename = "map" + edit_type(old, "map") old.keys = STRING for _, ftype in fields_of(old) do if not old.values then @@ -8946,6 +8866,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string old.fields = nil old.field_order = nil elseif old.typename == "union" then + edit_type(old, "union") new.tk = nil table.insert(old.types, new) else @@ -8991,18 +8912,17 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if t.def.typeargs then typevals = {} for _, a in ipairs(t.def.typeargs) do - table.insert(typevals, a_type { typename = "typevar", typevar = a.typearg }) + table.insert(typevals, a_type("typevar", { typevar = a.typearg })) end end - return a_type { + return a_type("nominal", { y = where.y, x = where.x, - typename = "nominal", typevals = typevals, names = { name }, found = t, resolved = resolved, - } + }) end local function get_self_type(exp: Node): Type @@ -9401,7 +9321,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local base_nargs = (node.e1.tk == "xpcall") and 2 or 1 if #node.e2 < base_nargs then error_at(node, "wrong number of arguments (given " .. #node.e2 .. ", expects at least " .. base_nargs .. ")") - return TUPLE { BOOLEAN } + return a_tuple { BOOLEAN } end -- The function called by pcall/xpcall is invoked as a regular function, so we wish to avoid incorrect error messages / unnecessary warning messages associated with calling methods as functions @@ -9428,7 +9348,8 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string } local rets = type_check_funcall(fnode, ftype, b, argdelta + base_nargs) if rets.typename ~= "tuple" then - rets = a_type { typename = "tuple", rets } + -- TODO what type is this?... + rets = a_tuple({ rets }) end table.insert(rets, 1, BOOLEAN) return rets @@ -9561,7 +9482,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string resolved = find_type(names) if (not resolved) or (not is_typetype(resolved)) then error_at(typetype, "%s is not a type", typetype) - resolved = a_type { typename = "bad_nominal", names = names } + resolved = a_type("bad_nominal", { names = names }) end end return resolved, aliasing @@ -9591,7 +9512,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string typ = decls[i] if typ then if i == nexps and ndecl > nexps then - typ = a_type { y = node.y, x = node.x, filename = filename, typename = "tuple", types = {} } + typ = a_type("tuple", { y = node.y, x = node.x, filename = filename }) for a = i, ndecl do table.insert(typ, decls[a]) end @@ -9639,12 +9560,11 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end local function infer_table_literal(node: Node, children: {Type}): Type - local typ = a_type { + local typ = a_type("emptytable", { filename = filename, y = node.y, x = node.x, - typename = "emptytable", - } + }) local is_record = false local is_array = false @@ -9738,13 +9658,12 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string elseif is_record and is_array then typ.typename = "record" typ.interface_list = { - a_type { + a_type("array", { filename = filename, y = node.y, x = node.x, - typename = "array", elements = typ.elements, - } + }) } -- TODO adopt logic from is_array below when we accept tupletable as an interface elseif is_record and is_map then @@ -10013,7 +9932,6 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local name = node.var.tk local resolved, aliasing = get_type_declaration(node) local var = add_var(node.var, name, resolved, node.var.attribute) ---@-- node.value.type = resolved if aliasing then var.aliasing = aliasing node.value.is_alias = true @@ -10227,7 +10145,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string error_at(node, "label '" .. node.label .. "' already defined at " .. filename ) end local unresolved = st[#st]["@unresolved"] - local var = add_var(node, label_id, a_type { y = node.y, x = node.x, typename = "none" }) + local var = add_var(node, label_id, a_type("none", { y = node.y, x = node.x })) if unresolved then if unresolved.t.labels[node.label] then var.used = true @@ -10267,8 +10185,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string widen_all_unions(node) local exp1 = node.exps[1] - local args = { - typename = "tuple", + local args = a_tuple { node.exps[2] and exptypes[2], node.exps[3] and exptypes[3] } @@ -10384,7 +10301,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string }, ["variable_list"] = { after = function(_node: Node, children: {Type}): Type - local tuple = TUPLE(children) + local tuple = a_tuple(children) -- explode last tuple: (1, 2, (3, 4)) becomes (1, 2, 3, 4) local n = #tuple @@ -10538,20 +10455,15 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local t: Type if force_array then - t = infer_at(node, a_type { - typename = "array", - elements = force_array, - }) + t = infer_at(node, an_array(force_array)) else t = resolve_typevars_at(node, node.expected) if node.expected == t and t.typename == "nominal" then - t = { - typeid = t.typeid, - typename = "nominal", + t = a_type("nominal", { names = t.names, found = t.found, resolved = t.resolved, - } + }) end end @@ -10579,14 +10491,13 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string vtype.typeid = new_typeid() vtype.is_method = false end - return a_type { + return a_type("table_item", { y = node.y, x = node.x, - typename = "table_item", kname = kname, ktype = ktype, vtype = vtype, - } + }) end, }, ["local_function"] = { @@ -10606,10 +10517,9 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end_function_scope(node) local rets = get_rets(children[3]) - local t = ensure_fresh_typeargs(a_type { + local t = ensure_fresh_typeargs(a_function { y = node.y, x = node.x, - typename = "function", typeargs = node.typeargs, args = children[2], rets = rets, @@ -10648,10 +10558,9 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return NONE end - add_global(node, node.name.tk, ensure_fresh_typeargs(a_type { + add_global(node, node.name.tk, ensure_fresh_typeargs(a_function { y = node.y, x = node.x, - typename = "function", typeargs = node.typeargs, args = children[2], rets = get_rets(children[3]), @@ -10672,12 +10581,11 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string -- add type arguments from the record implicitly if node.rtype.typeargs then for _, typ in ipairs(node.rtype.typeargs) do - add_var(nil, typ.typearg, a_type { + add_var(nil, typ.typearg, a_type("typearg", { y = typ.y, x = typ.x, - typename = "typearg", typearg = typ.typearg, - }) + })) end end end, @@ -10686,7 +10594,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local rtype = node.rtype if rtype.typename == "emptytable" then - rtype.typename = "record" + edit_type(rtype, "record") rtype.fields = {} rtype.field_order = {} end @@ -10710,10 +10618,9 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string add_var(nil, "self", selftype) end - local fn_type = ensure_fresh_typeargs(a_type { + local fn_type = ensure_fresh_typeargs(a_function { y = node.y, x = node.x, - typename = "function", is_method = node.is_method, typeargs = node.typeargs, args = args, @@ -10781,10 +10688,9 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end_function_scope(node) -- children[1] args -- children[2] body - return ensure_fresh_typeargs(a_type { + return ensure_fresh_typeargs(a_function { y = node.y, x = node.x, - typename = "function", typeargs = node.typeargs, args = children[1], rets = children[2], @@ -10805,10 +10711,9 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end_function_scope(node) -- children[1] args -- children[2] body - return ensure_fresh_typeargs(a_type { + return ensure_fresh_typeargs(a_function { y = node.y, x = node.x, - typename = "function", typeargs = node.typeargs, args = children[1], rets = children[2], @@ -10930,12 +10835,11 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string kind = "string", conststr = node.e2.tk, } - local btype = a_type { + local btype = a_type("string", { y = node.e2.y, x = node.e2.x, tk = '"' ..node.e2.tk .. '"', - typename = "string", - } + }) local t = type_check_index(node.e1, bnode, orig_a, btype) if t.needs_compat and opts.gen_compat ~= "off" then @@ -11258,7 +11162,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string t = UNKNOWN end if node.tk == "..." then - t = a_type { typename = "tuple", is_va = true, t } + t = a_vararg { t } end if node.opt then t = OPT(t) @@ -11293,12 +11197,11 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local function after_literal(node: Node): Type node.known = FACT_TRUTHY - return a_type { + return a_type(node.kind as TypeName, { y = node.y, x = node.x, - typename = node.kind as TypeName, tk = node.tk, - } + }) end visit_node.cbs["string"] = { @@ -11368,7 +11271,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string ["record"] = { before = function(typ: Type) begin_scope() - add_var(nil, "@self", a_type({ typename = "typetype", y = typ.y, x = typ.x, def = typ })) + add_var(nil, "@self", a_typetype({ y = typ.y, x = typ.x, def = typ })) for name, typ2 in fields_of(typ) do if typ2.typename == "typetype" then @@ -11429,12 +11332,11 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string }, ["typearg"] = { after = function(typ: Type, _children: {Type}): Type - add_var(nil, typ.typearg, a_type { + add_var(nil, typ.typearg, a_type("typearg", { y = typ.y, x = typ.x, - typename = "typearg", typearg = typ.typearg, - }) + })) return typ end, }, @@ -11456,8 +11358,8 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if t then if t.typename == "typearg" then -- convert nominal into a typevar + edit_type(typ, "typevar") typ.names = nil - typ.typename = "typevar" typ.typevar = t.typearg else if t.is_alias then From be6a19e0ebb18dcaa9ae7ae8c0a0ee2dbf72c389 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Mon, 11 Dec 2023 11:32:50 -0300 Subject: [PATCH 038/224] interfaces: fix error location --- spec/assignment/to_interface_spec.lua | 5 ++++- tl.lua | 4 ++-- tl.tl | 4 ++-- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/spec/assignment/to_interface_spec.lua b/spec/assignment/to_interface_spec.lua index 9f384276d..b57d75ba1 100644 --- a/spec/assignment/to_interface_spec.lua +++ b/spec/assignment/to_interface_spec.lua @@ -32,7 +32,10 @@ describe("assignment", function() err = { { y = 6, msg = "cannot reassign a type" } } elseif scope:match("to inner def") then -- 3 if outer == "interface" and scope:match("with outer def") then - err = { { y = 6, msg = "interfaces are abstract; consider using a concrete record" } } + err = { + { y = 6, msg = "interfaces are abstract; consider using a concrete record" }, + { y = 6, msg = "cannot reassign a type" }, + } else err = { { y = 6, msg = "cannot reassign a type" } } end diff --git a/tl.lua b/tl.lua index 7a06b85b2..772aa6c95 100644 --- a/tl.lua +++ b/tl.lua @@ -10804,14 +10804,14 @@ a.types[i], b.types[i]), } if ra.def.typename == "record" then ra = ra.def elseif ra.def.typename == "interface" then - error_at(node, "interfaces are abstract; consider using a concrete record") + error_at(node.e1, "interfaces are abstract; consider using a concrete record") end end if rb and is_typetype(rb) and rb.def.typename == "record" then if rb.def.typename == "record" then rb = rb.def elseif rb.def.typename == "interface" then - error_at(node, "interfaces are abstract; consider using a concrete record") + error_at(node.e2, "interfaces are abstract; consider using a concrete record") end end diff --git a/tl.tl b/tl.tl index 2b2b948a9..af700c018 100644 --- a/tl.tl +++ b/tl.tl @@ -10804,14 +10804,14 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if ra.def.typename == "record" then ra = ra.def elseif ra.def.typename == "interface" then - error_at(node, "interfaces are abstract; consider using a concrete record") + error_at(node.e1, "interfaces are abstract; consider using a concrete record") end end if rb and is_typetype(rb) and rb.def.typename == "record" then if rb.def.typename == "record" then rb = rb.def elseif rb.def.typename == "interface" then - error_at(node, "interfaces are abstract; consider using a concrete record") + error_at(node.e2, "interfaces are abstract; consider using a concrete record") end end From 8aaf3de4a3a60229d4ca8b558555865e2e9c93ba Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Mon, 11 Dec 2023 11:34:43 -0300 Subject: [PATCH 039/224] macroexp: require `return` in statement --- spec/declaration/macroexp_spec.lua | 8 ++++---- spec/macroexp/is_spec.lua | 8 ++++---- spec/operator/is_spec.lua | 4 ++-- tl.lua | 1 + tl.tl | 1 + 5 files changed, 12 insertions(+), 10 deletions(-) diff --git a/spec/declaration/macroexp_spec.lua b/spec/declaration/macroexp_spec.lua index ba5b517af..f1aa77bbc 100644 --- a/spec/declaration/macroexp_spec.lua +++ b/spec/declaration/macroexp_spec.lua @@ -4,7 +4,7 @@ describe("macroexp declaration", function() it("checks unused arguments", util.check_warnings([[ local record R1 metamethod __is: function(self: R1): boolean = macroexp(self: R1): boolean - true + return true end end ]], { @@ -14,7 +14,7 @@ describe("macroexp declaration", function() it("checks argument mismatch", util.check_type_error([[ local record R1 metamethod __call: function(self: R1, n: number): boolean = macroexp(self: R1, s: string): boolean - self.field == s + return self.field == s end field: string end @@ -29,10 +29,10 @@ describe("macroexp declaration", function() local record R1 metamethod __call: function(self: R1, s: string): boolean = macroexp(self: R1, s: string): boolean - print(s, s) + return print(s, s) end end ]], { - { y = 7, x = 22, msg = "cannot use argument 's' multiple times in macroexp" } + { y = 7, x = 29, msg = "cannot use argument 's' multiple times in macroexp" } })) end) diff --git a/spec/macroexp/is_spec.lua b/spec/macroexp/is_spec.lua index b359a332b..f5604225a 100644 --- a/spec/macroexp/is_spec.lua +++ b/spec/macroexp/is_spec.lua @@ -4,13 +4,13 @@ describe("__is with macroexp", function() it("can expand a constant expression", util.gen([[ local record R1 metamethod __is: function(self: R1|R2): boolean = macroexp(_self: R1|R2): boolean - true + return true end end local record R2 metamethod __is: function(self: R1|R2): boolean = macroexp(_self: R1|R2): boolean - false + return false end end @@ -58,7 +58,7 @@ describe("__is with macroexp", function() it("can expand self in an expression", util.gen([[ local record R1 metamethod __is: function(self: R1|R2): boolean = macroexp(self: R1|R2): boolean - self.kind == "r1" + return self.kind == "r1" end kind: string @@ -66,7 +66,7 @@ describe("__is with macroexp", function() local record R2 metamethod __is: function(self: R1|R2): boolean = macroexp(self: R1|R2): boolean - self.kind == "r2" + return self.kind == "r2" end kind: string diff --git a/spec/operator/is_spec.lua b/spec/operator/is_spec.lua index 433eff152..9e31199d2 100644 --- a/spec/operator/is_spec.lua +++ b/spec/operator/is_spec.lua @@ -405,13 +405,13 @@ describe("flow analysis with is", function() it("produces no errors or warnings for checks on unions of records", util.check_warnings([[ local record R1 metamethod __is: function(self: R1|R2): boolean = macroexp(_self: R1|R2): boolean - true + return true end end local record R2 metamethod __is: function(self: R1|R2): boolean = macroexp(_self: R1|R2): boolean - false + return false end end diff --git a/tl.lua b/tl.lua index 772aa6c95..4e90d0dfb 100644 --- a/tl.lua +++ b/tl.lua @@ -2953,6 +2953,7 @@ local function parse_macroexp(ps, i) i = i + 1 i, node.args = parse_argument_list(ps, i) i, node.rets = parse_return_types(ps, i) + i = verify_tk(ps, i, "return") i, node.exp = parse_expression(ps, i) end_at(node, ps.tokens[i]) i = verify_end(ps, i, istart, node) diff --git a/tl.tl b/tl.tl index af700c018..43891a42b 100644 --- a/tl.tl +++ b/tl.tl @@ -2953,6 +2953,7 @@ local function parse_macroexp(ps: ParseState, i: integer): integer, Node i = i + 1 -- skip 'macroexp' i, node.args = parse_argument_list(ps, i) i, node.rets = parse_return_types(ps, i) + i = verify_tk(ps, i, "return") i, node.exp = parse_expression(ps, i) end_at(node, ps.tokens[i]) i = verify_end(ps, i, istart, node) From 94a8368cb091a23bd05c9546b5264019fd30744a Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Thu, 4 Jan 2024 00:39:52 -0300 Subject: [PATCH 040/224] fix: define typeargs in function definition for recursion --- tl.lua | 1 + tl.tl | 1 + 2 files changed, 2 insertions(+) diff --git a/tl.lua b/tl.lua index 4e90d0dfb..c715b0c1a 100644 --- a/tl.lua +++ b/tl.lua @@ -8628,6 +8628,7 @@ a.types[i], b.types[i]), } end add_var(nil, node.name.tk, a_function({ + typeargs = node.typeargs, args = args, rets = get_rets(node.rets), })) diff --git a/tl.tl b/tl.tl index 43891a42b..4d50eb652 100644 --- a/tl.tl +++ b/tl.tl @@ -8628,6 +8628,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end add_var(nil, node.name.tk, a_function { + typeargs = node.typeargs, args = args, rets = get_rets(node.rets), }) From 5aebc641f59c4554974a6f92bcf68dc55b25b3cf Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Mon, 11 Dec 2023 11:37:01 -0300 Subject: [PATCH 041/224] local macroexp: new statement --- tl.lua | 717 +++++++++++++++++++++++++++++++-------------------------- tl.tl | 103 +++++++-- 2 files changed, 471 insertions(+), 349 deletions(-) diff --git a/tl.lua b/tl.lua index c715b0c1a..d12ed0ee2 100644 --- a/tl.lua +++ b/tl.lua @@ -1249,6 +1249,7 @@ local table_types = { + local TruthyFact = {} @@ -1444,6 +1445,7 @@ local Node = {ExpectedContext = {}, } + local function is_array_type(t) @@ -1561,44 +1563,43 @@ local function new_type(ps, i, typename) }) end -local function a_tuple(t) + + + + +local function c_tuple(t) return a_type("tuple", t) end -local function a_union(t) - return a_type("union", { types = t }) -end -local function a_poly(t) - return a_type("poly", { types = t }) -end -local function a_function(t) - return a_type("function", t) -end -local function a_typetype(t) - return a_type("typetype", t) -end + + + + + + + + + + + + local function a_vararg(t) local tuple = t tuple.is_va = true - return a_tuple(t) + return a_type("tuple", t) end -local function an_array(t) - return a_type("array", { - elements = t, - }) -end -local function a_map(k, v) - return a_type("map", { - keys = k, - values = v, - }) -end + + + + + + local NIL = a_type("nil", {}) local ANY = a_type("any", {}) @@ -2943,15 +2944,14 @@ local metamethod_names = { ["__is"] = true, } -local function parse_macroexp(ps, i) - local istart = i +local function parse_macroexp(ps, istart, iargs) local node = new_node(ps.tokens, istart, "macroexp") - i = i + 1 - i, node.args = parse_argument_list(ps, i) + local i + i, node.args = parse_argument_list(ps, iargs) i, node.rets = parse_return_types(ps, i) i = verify_tk(ps, i, "return") i, node.exp = parse_expression(ps, i) @@ -3050,8 +3050,8 @@ parse_record_body = function(ps, i, def, node) local typ = new_type(ps, wstart, "function") typ.is_method = true - typ.args = a_tuple({ a_type("nominal", { y = typ.y, x = typ.x, names = { "@self" } }) }) - typ.rets = a_tuple({ BOOLEAN }) + typ.args = a_type("tuple", { a_type("nominal", { y = typ.y, x = typ.x, names = { "@self" } }) }) + typ.rets = a_type("tuple", { BOOLEAN }) typ.macroexp = where_macroexp store_field_in_record(ps, i, "__is", typ, def.meta_fields, def.meta_field_order) @@ -3135,7 +3135,7 @@ parse_record_body = function(ps, i, def, node) if t.typename ~= "function" then fail(ps, i + 1, "macroexp must have a function type") end - i, t.macroexp = parse_macroexp(ps, i + 1) + i, t.macroexp = parse_macroexp(ps, i + 1, i + 2) end store_field_in_record(ps, iv, field_name, t, fields, field_order) @@ -3342,6 +3342,15 @@ local function skip_type_declaration(ps, i) return parse_type_declaration(ps, i - 1, "local_type") end +local function parse_local_macroexp(ps, i) + local istart = i + i = i + 2 + local node = new_node(ps.tokens, i, "local_macroexp") + i, node.name = parse_identifier(ps, i) + i, node.macrodef = parse_macroexp(ps, istart, i) + return i, node +end + local function parse_local(ps, i) local ntk = ps.tokens[i + 1].tk local tn = ntk @@ -3349,6 +3358,8 @@ local function parse_local(ps, i) return parse_local_function(ps, i) elseif ntk == "type" and ps.tokens[i + 2].kind == "identifier" then return parse_type_declaration(ps, i, "local_type") + elseif ntk == "macroexp" and ps.tokens[i + 2].kind == "identifier" then + return parse_local_macroexp(ps, i) elseif parse_type_body_fns[tn] and ps.tokens[i + 2].kind == "identifier" then return parse_type_constructor(ps, i, "local_type", tn, parse_type_body_fns[tn]) end @@ -3891,6 +3902,14 @@ local function recurse_node(root, extra_callback("before_statements", ast, xs, visit_node) xs[5] = recurse(ast.body) end, + ["local_macroexp"] = function(ast, xs) + + xs[1] = recurse(ast.name) + xs[2] = recurse(ast.macrodef.args) + xs[3] = recurse_type(ast.macrodef.rets, visit_type) + extra_callback("before_exp", ast, xs, visit_node) + xs[4] = recurse(ast.macrodef.exp) + end, ["forin"] = function(ast, xs) xs[1] = recurse(ast.vars) @@ -4450,6 +4469,12 @@ function tl.pretty_print_ast(ast, gen_target, mode) return out end, }, + ["local_macroexp"] = { + before = increment_indent, + after = function(node, _children) + return { y = node.y, h = 0 } + end, + }, ["local_function"] = { before = increment_indent, after = function(node, children) @@ -4950,10 +4975,10 @@ local INVALID = a_type("invalid", {}) local UNKNOWN = a_type("unknown", {}) local CIRCULAR_REQUIRE = a_type("circular_require", {}) -local FUNCTION = a_function({ args = a_vararg({ ANY }), rets = a_vararg({ ANY }) }) +local FUNCTION = a_type("function", { args = a_vararg({ ANY }), rets = a_vararg({ ANY }) }) local NOMINAL_FILE = a_type("nominal", { names = { "FILE" } }) -local XPCALL_MSGH_FUNCTION = a_function({ args = a_tuple({ ANY }), rets = a_tuple({}) }) +local XPCALL_MSGH_FUNCTION = a_type("function", { args = a_type("tuple", { ANY }), rets = a_type("tuple", {}) }) local USERDATA = ANY @@ -5653,15 +5678,15 @@ local function init_globals(lax) local file_reader_poly_types = { - { ctor = a_vararg, args = { a_union({ NUMBER, an_enum({ "*a", "a", "*l", "l", "*L", "L" }) }) }, rets = { STRING } }, - { ctor = a_tuple, args = { an_enum({ "*n", "n" }) }, rets = { NUMBER, STRING } }, - { ctor = a_vararg, args = { a_union({ NUMBER, an_enum({ "*a", "a", "*l", "l", "*L", "L", "*n", "n" }) }) }, rets = { a_union({ STRING, NUMBER }) } }, - { ctor = a_vararg, args = { a_union({ NUMBER, STRING }) }, rets = { STRING } }, + { ctor = a_vararg, args = { a_type("union", { types = { NUMBER, an_enum({ "*a", "a", "*l", "l", "*L", "L" }) } }) }, rets = { STRING } }, + { ctor = c_tuple, args = { an_enum({ "*n", "n" }) }, rets = { NUMBER, STRING } }, + { ctor = a_vararg, args = { a_type("union", { types = { NUMBER, an_enum({ "*a", "a", "*l", "l", "*L", "L", "*n", "n" }) } }) }, rets = { a_type("union", { types = { STRING, NUMBER } }) } }, + { ctor = a_vararg, args = { a_type("union", { types = { NUMBER, STRING } }) }, rets = { STRING } }, { ctor = a_vararg, args = {}, rets = { STRING } }, } local function a_file_reader(fn) - local t = a_poly({}) + local t = a_type("poly", { types = {} }) for _, entry in ipairs(file_reader_poly_types) do local args = shallow_copy_type(entry.args) local rets = shallow_copy_type(entry.rets) @@ -5670,7 +5695,7 @@ local function init_globals(lax) return t end - local LOAD_FUNCTION = a_function({ args = {}, rets = a_tuple({ STRING }) }) + local LOAD_FUNCTION = a_type("function", { args = {}, rets = a_type("tuple", { STRING }) }) local OS_DATE_TABLE = a_record({ fields = { @@ -5707,12 +5732,12 @@ local function init_globals(lax) local DEBUG_HOOK_EVENT = an_enum({ "call", "tail call", "return", "line", "count" }) - local DEBUG_HOOK_FUNCTION = a_function({ - args = a_tuple({ DEBUG_HOOK_EVENT, INTEGER }), - rets = a_tuple({}), + local DEBUG_HOOK_FUNCTION = a_type("function", { + args = a_type("tuple", { DEBUG_HOOK_EVENT, INTEGER }), + rets = a_type("tuple", {}), }) - local TABLE_SORT_FUNCTION = a_gfunction(1, function(a) return { args = a_tuple({ a, a }), rets = a_tuple({ BOOLEAN }) } end) + local TABLE_SORT_FUNCTION = a_gfunction(1, function(a) return { args = a_type("tuple", { a, a }), rets = a_type("tuple", { BOOLEAN }) } end) local metatable_nominals = {} @@ -5725,346 +5750,346 @@ local function init_globals(lax) local standard_library = { ["..."] = a_vararg({ STRING }), ["any"] = a_type("typetype", { def = ANY }), - ["arg"] = an_array(STRING), - ["assert"] = a_gfunction(2, function(a, b) return { args = a_tuple({ a, OPT(b) }), rets = a_tuple({ a }) } end), - ["collectgarbage"] = a_poly({ - a_function({ args = a_tuple({ an_enum({ "collect", "count", "stop", "restart" }) }), rets = a_tuple({ NUMBER }) }), - a_function({ args = a_tuple({ an_enum({ "step", "setpause", "setstepmul" }), NUMBER }), rets = a_tuple({ NUMBER }) }), - a_function({ args = a_tuple({ an_enum({ "isrunning" }) }), rets = a_tuple({ BOOLEAN }) }), - a_function({ args = a_tuple({ STRING, OPT(NUMBER) }), rets = a_tuple({ a_union({ BOOLEAN, NUMBER }) }) }), - }), - ["dofile"] = a_function({ args = a_tuple({ OPT(STRING) }), rets = a_vararg({ ANY }) }), - ["error"] = a_function({ args = a_tuple({ ANY, OPT(NUMBER) }), rets = a_tuple({}) }), - ["getmetatable"] = a_gfunction(1, function(a) return { args = a_tuple({ a }), rets = a_tuple({ METATABLE(a) }) } end), - ["ipairs"] = a_gfunction(1, function(a) return { args = a_tuple({ an_array(a) }), rets = a_tuple({ - a_function({ args = a_tuple({}), rets = a_tuple({ INTEGER, a }) }), + ["arg"] = a_type("array", { elements = STRING }), + ["assert"] = a_gfunction(2, function(a, b) return { args = a_type("tuple", { a, OPT(b) }), rets = a_type("tuple", { a }) } end), + ["collectgarbage"] = a_type("poly", { types = { + a_type("function", { args = a_type("tuple", { an_enum({ "collect", "count", "stop", "restart" }) }), rets = a_type("tuple", { NUMBER }) }), + a_type("function", { args = a_type("tuple", { an_enum({ "step", "setpause", "setstepmul" }), NUMBER }), rets = a_type("tuple", { NUMBER }) }), + a_type("function", { args = a_type("tuple", { an_enum({ "isrunning" }) }), rets = a_type("tuple", { BOOLEAN }) }), + a_type("function", { args = a_type("tuple", { STRING, OPT(NUMBER) }), rets = a_type("tuple", { a_type("union", { types = { BOOLEAN, NUMBER } }) }) }), + } }), + ["dofile"] = a_type("function", { args = a_type("tuple", { OPT(STRING) }), rets = a_vararg({ ANY }) }), + ["error"] = a_type("function", { args = a_type("tuple", { ANY, OPT(NUMBER) }), rets = a_type("tuple", {}) }), + ["getmetatable"] = a_gfunction(1, function(a) return { args = a_type("tuple", { a }), rets = a_type("tuple", { METATABLE(a) }) } end), + ["ipairs"] = a_gfunction(1, function(a) return { args = a_type("tuple", { a_type("array", { elements = a }) }), rets = a_type("tuple", { + a_type("function", { args = a_type("tuple", {}), rets = a_type("tuple", { INTEGER, a }) }), }), } end), - ["load"] = a_function({ args = a_tuple({ a_union({ STRING, LOAD_FUNCTION }), OPT(STRING), OPT(STRING), OPT(TABLE) }), rets = a_tuple({ FUNCTION, STRING }) }), - ["loadfile"] = a_function({ args = a_tuple({ OPT(STRING), OPT(STRING), OPT(TABLE) }), rets = a_tuple({ FUNCTION, STRING }) }), - ["next"] = a_poly({ - a_gfunction(2, function(a, b) return { args = a_tuple({ a_map(a, b), OPT(a) }), rets = a_tuple({ a, b }) } end), - a_gfunction(1, function(a) return { args = a_tuple({ an_array(a), OPT(a) }), rets = a_tuple({ INTEGER, a }) } end), - }), - ["pairs"] = a_gfunction(2, function(a, b) return { args = a_tuple({ a_map(a, b) }), rets = a_tuple({ - a_function({ args = a_tuple({}), rets = a_tuple({ a, b }) }), + ["load"] = a_type("function", { args = a_type("tuple", { a_type("union", { types = { STRING, LOAD_FUNCTION } }), OPT(STRING), OPT(STRING), OPT(TABLE) }), rets = a_type("tuple", { FUNCTION, STRING }) }), + ["loadfile"] = a_type("function", { args = a_type("tuple", { OPT(STRING), OPT(STRING), OPT(TABLE) }), rets = a_type("tuple", { FUNCTION, STRING }) }), + ["next"] = a_type("poly", { types = { + a_gfunction(2, function(a, b) return { args = a_type("tuple", { a_type("map", { keys = a, values = b }), OPT(a) }), rets = a_type("tuple", { a, b }) } end), + a_gfunction(1, function(a) return { args = a_type("tuple", { a_type("array", { elements = a }), OPT(a) }), rets = a_type("tuple", { INTEGER, a }) } end), + } }), + ["pairs"] = a_gfunction(2, function(a, b) return { args = a_type("tuple", { a_type("map", { keys = a, values = b }) }), rets = a_type("tuple", { + a_type("function", { args = a_type("tuple", {}), rets = a_type("tuple", { a, b }) }), }), } end), - ["pcall"] = a_function({ args = a_vararg({ FUNCTION, ANY }), rets = a_vararg({ BOOLEAN, ANY }) }), - ["xpcall"] = a_function({ args = a_vararg({ FUNCTION, XPCALL_MSGH_FUNCTION, ANY }), rets = a_vararg({ BOOLEAN, ANY }) }), - ["print"] = a_function({ args = a_vararg({ ANY }), rets = a_tuple({}) }), - ["rawequal"] = a_function({ args = a_tuple({ ANY, ANY }), rets = a_tuple({ BOOLEAN }) }), - ["rawget"] = a_function({ args = a_tuple({ TABLE, ANY }), rets = a_tuple({ ANY }) }), - ["rawlen"] = a_function({ args = a_tuple({ a_union({ TABLE, STRING }) }), rets = a_tuple({ INTEGER }) }), - ["rawset"] = a_poly({ - a_gfunction(2, function(a, b) return { args = a_tuple({ a_map(a, b), a, b }), rets = a_tuple({}) } end), - a_gfunction(1, function(a) return { args = a_tuple({ an_array(a), NUMBER, a }), rets = a_tuple({}) } end), - a_function({ args = a_tuple({ TABLE, ANY, ANY }), rets = a_tuple({}) }), - }), - ["require"] = a_function({ args = a_tuple({ STRING }), rets = a_tuple({}) }), - ["select"] = a_poly({ - a_gfunction(1, function(a) return { args = a_vararg({ NUMBER, a }), rets = a_tuple({ a }) } end), - a_function({ args = a_vararg({ NUMBER, ANY }), rets = a_tuple({ ANY }) }), - a_function({ args = a_vararg({ STRING, ANY }), rets = a_tuple({ INTEGER }) }), - }), - ["setmetatable"] = a_gfunction(1, function(a) return { args = a_tuple({ a, METATABLE(a) }), rets = a_tuple({ a }) } end), - ["tonumber"] = a_poly({ - a_function({ args = a_tuple({ ANY }), rets = a_tuple({ NUMBER }) }), - a_function({ args = a_tuple({ ANY, NUMBER }), rets = a_tuple({ INTEGER }) }), - }), - ["tostring"] = a_function({ args = a_tuple({ ANY }), rets = a_tuple({ STRING }) }), - ["type"] = a_function({ args = a_tuple({ ANY }), rets = a_tuple({ STRING }) }), - ["FILE"] = a_typetype({ + ["pcall"] = a_type("function", { args = a_vararg({ FUNCTION, ANY }), rets = a_vararg({ BOOLEAN, ANY }) }), + ["xpcall"] = a_type("function", { args = a_vararg({ FUNCTION, XPCALL_MSGH_FUNCTION, ANY }), rets = a_vararg({ BOOLEAN, ANY }) }), + ["print"] = a_type("function", { args = a_vararg({ ANY }), rets = a_type("tuple", {}) }), + ["rawequal"] = a_type("function", { args = a_type("tuple", { ANY, ANY }), rets = a_type("tuple", { BOOLEAN }) }), + ["rawget"] = a_type("function", { args = a_type("tuple", { TABLE, ANY }), rets = a_type("tuple", { ANY }) }), + ["rawlen"] = a_type("function", { args = a_type("tuple", { a_type("union", { types = { TABLE, STRING } }) }), rets = a_type("tuple", { INTEGER }) }), + ["rawset"] = a_type("poly", { types = { + a_gfunction(2, function(a, b) return { args = a_type("tuple", { a_type("map", { keys = a, values = b }), a, b }), rets = a_type("tuple", {}) } end), + a_gfunction(1, function(a) return { args = a_type("tuple", { a_type("array", { elements = a }), NUMBER, a }), rets = a_type("tuple", {}) } end), + a_type("function", { args = a_type("tuple", { TABLE, ANY, ANY }), rets = a_type("tuple", {}) }), + } }), + ["require"] = a_type("function", { args = a_type("tuple", { STRING }), rets = a_type("tuple", {}) }), + ["select"] = a_type("poly", { types = { + a_gfunction(1, function(a) return { args = a_vararg({ NUMBER, a }), rets = a_type("tuple", { a }) } end), + a_type("function", { args = a_vararg({ NUMBER, ANY }), rets = a_type("tuple", { ANY }) }), + a_type("function", { args = a_vararg({ STRING, ANY }), rets = a_type("tuple", { INTEGER }) }), + } }), + ["setmetatable"] = a_gfunction(1, function(a) return { args = a_type("tuple", { a, METATABLE(a) }), rets = a_type("tuple", { a }) } end), + ["tonumber"] = a_type("poly", { types = { + a_type("function", { args = a_type("tuple", { ANY }), rets = a_type("tuple", { NUMBER }) }), + a_type("function", { args = a_type("tuple", { ANY, NUMBER }), rets = a_type("tuple", { INTEGER }) }), + } }), + ["tostring"] = a_type("function", { args = a_type("tuple", { ANY }), rets = a_type("tuple", { STRING }) }), + ["type"] = a_type("function", { args = a_type("tuple", { ANY }), rets = a_type("tuple", { STRING }) }), + ["FILE"] = a_type("typetype", { def = a_record({ is_userdata = true, fields = { - ["close"] = a_function({ args = a_tuple({ NOMINAL_FILE }), rets = a_tuple({ BOOLEAN, STRING, INTEGER }) }), - ["flush"] = a_function({ args = a_tuple({ NOMINAL_FILE }), rets = a_tuple({}) }), + ["close"] = a_type("function", { args = a_type("tuple", { NOMINAL_FILE }), rets = a_type("tuple", { BOOLEAN, STRING, INTEGER }) }), + ["flush"] = a_type("function", { args = a_type("tuple", { NOMINAL_FILE }), rets = a_type("tuple", {}) }), ["lines"] = a_file_reader(function(ctor, args, rets) table.insert(args, 1, NOMINAL_FILE) - return a_function({ args = ctor(args), rets = a_tuple({ - a_function({ args = a_tuple({}), rets = ctor(rets) }), + return a_type("function", { args = ctor(args), rets = a_type("tuple", { + a_type("function", { args = a_type("tuple", {}), rets = ctor(rets) }), }), }) end), ["read"] = a_file_reader(function(ctor, args, rets) table.insert(args, 1, NOMINAL_FILE) - return a_function({ args = ctor(args), rets = ctor(rets) }) + return a_type("function", { args = ctor(args), rets = ctor(rets) }) end), - ["seek"] = a_function({ args = a_tuple({ NOMINAL_FILE, OPT(STRING), OPT(NUMBER) }), rets = a_tuple({ INTEGER, STRING }) }), - ["setvbuf"] = a_function({ args = a_tuple({ NOMINAL_FILE, STRING, OPT(NUMBER) }), rets = a_tuple({}) }), - ["write"] = a_function({ args = a_vararg({ NOMINAL_FILE, a_union({ STRING, NUMBER }) }), rets = a_tuple({ NOMINAL_FILE, STRING }) }), + ["seek"] = a_type("function", { args = a_type("tuple", { NOMINAL_FILE, OPT(STRING), OPT(NUMBER) }), rets = a_type("tuple", { INTEGER, STRING }) }), + ["setvbuf"] = a_type("function", { args = a_type("tuple", { NOMINAL_FILE, STRING, OPT(NUMBER) }), rets = a_type("tuple", {}) }), + ["write"] = a_type("function", { args = a_vararg({ NOMINAL_FILE, a_type("union", { types = { STRING, NUMBER } }) }), rets = a_type("tuple", { NOMINAL_FILE, STRING }) }), }, meta_fields = { ["__close"] = FUNCTION }, meta_field_order = { "__close" }, }), }), - ["metatable"] = a_typetype({ + ["metatable"] = a_type("typetype", { def = a_grecord(1, function(a) return { fields = { - ["__call"] = a_function({ args = a_vararg({ a, ANY }), rets = a_vararg({ ANY }) }), - ["__gc"] = a_function({ args = a_tuple({ a }), rets = a_tuple({}) }), + ["__call"] = a_type("function", { args = a_vararg({ a, ANY }), rets = a_vararg({ ANY }) }), + ["__gc"] = a_type("function", { args = a_type("tuple", { a }), rets = a_type("tuple", {}) }), ["__index"] = ANY, - ["__len"] = a_function({ args = a_tuple({ a }), rets = a_tuple({ ANY }) }), + ["__len"] = a_type("function", { args = a_type("tuple", { a }), rets = a_type("tuple", { ANY }) }), ["__mode"] = an_enum({ "k", "v", "kv" }), ["__newindex"] = ANY, ["__pairs"] = a_gfunction(2, function(k, v) return { - args = a_tuple({ a }), - rets = a_tuple({ a_function({ args = a_tuple({}), rets = a_tuple({ k, v }) }) }), + args = a_type("tuple", { a }), + rets = a_type("tuple", { a_type("function", { args = a_type("tuple", {}), rets = a_type("tuple", { k, v }) }) }), } end), - ["__tostring"] = a_function({ args = a_tuple({ a }), rets = a_tuple({ STRING }) }), + ["__tostring"] = a_type("function", { args = a_type("tuple", { a }), rets = a_type("tuple", { STRING }) }), ["__name"] = STRING, - ["__add"] = a_function({ args = a_tuple({ ANY, ANY }), rets = a_tuple({ ANY }) }), - ["__sub"] = a_function({ args = a_tuple({ ANY, ANY }), rets = a_tuple({ ANY }) }), - ["__mul"] = a_function({ args = a_tuple({ ANY, ANY }), rets = a_tuple({ ANY }) }), - ["__div"] = a_function({ args = a_tuple({ ANY, ANY }), rets = a_tuple({ ANY }) }), - ["__idiv"] = a_function({ args = a_tuple({ ANY, ANY }), rets = a_tuple({ ANY }) }), - ["__mod"] = a_function({ args = a_tuple({ ANY, ANY }), rets = a_tuple({ ANY }) }), - ["__pow"] = a_function({ args = a_tuple({ ANY, ANY }), rets = a_tuple({ ANY }) }), - ["__unm"] = a_function({ args = a_tuple({ ANY }), rets = a_tuple({ ANY }) }), - ["__band"] = a_function({ args = a_tuple({ ANY, ANY }), rets = a_tuple({ ANY }) }), - ["__bor"] = a_function({ args = a_tuple({ ANY, ANY }), rets = a_tuple({ ANY }) }), - ["__bxor"] = a_function({ args = a_tuple({ ANY, ANY }), rets = a_tuple({ ANY }) }), - ["__bnot"] = a_function({ args = a_tuple({ ANY }), rets = a_tuple({ ANY }) }), - ["__shl"] = a_function({ args = a_tuple({ ANY, ANY }), rets = a_tuple({ ANY }) }), - ["__shr"] = a_function({ args = a_tuple({ ANY, ANY }), rets = a_tuple({ ANY }) }), - ["__concat"] = a_function({ args = a_tuple({ ANY, ANY }), rets = a_tuple({ ANY }) }), - ["__eq"] = a_function({ args = a_tuple({ ANY, ANY }), rets = a_tuple({ BOOLEAN }) }), - ["__lt"] = a_function({ args = a_tuple({ ANY, ANY }), rets = a_tuple({ BOOLEAN }) }), - ["__le"] = a_function({ args = a_tuple({ ANY, ANY }), rets = a_tuple({ BOOLEAN }) }), - ["__close"] = a_function({ args = a_tuple({ a }), rets = a_tuple({}) }), + ["__add"] = a_type("function", { args = a_type("tuple", { ANY, ANY }), rets = a_type("tuple", { ANY }) }), + ["__sub"] = a_type("function", { args = a_type("tuple", { ANY, ANY }), rets = a_type("tuple", { ANY }) }), + ["__mul"] = a_type("function", { args = a_type("tuple", { ANY, ANY }), rets = a_type("tuple", { ANY }) }), + ["__div"] = a_type("function", { args = a_type("tuple", { ANY, ANY }), rets = a_type("tuple", { ANY }) }), + ["__idiv"] = a_type("function", { args = a_type("tuple", { ANY, ANY }), rets = a_type("tuple", { ANY }) }), + ["__mod"] = a_type("function", { args = a_type("tuple", { ANY, ANY }), rets = a_type("tuple", { ANY }) }), + ["__pow"] = a_type("function", { args = a_type("tuple", { ANY, ANY }), rets = a_type("tuple", { ANY }) }), + ["__unm"] = a_type("function", { args = a_type("tuple", { ANY }), rets = a_type("tuple", { ANY }) }), + ["__band"] = a_type("function", { args = a_type("tuple", { ANY, ANY }), rets = a_type("tuple", { ANY }) }), + ["__bor"] = a_type("function", { args = a_type("tuple", { ANY, ANY }), rets = a_type("tuple", { ANY }) }), + ["__bxor"] = a_type("function", { args = a_type("tuple", { ANY, ANY }), rets = a_type("tuple", { ANY }) }), + ["__bnot"] = a_type("function", { args = a_type("tuple", { ANY }), rets = a_type("tuple", { ANY }) }), + ["__shl"] = a_type("function", { args = a_type("tuple", { ANY, ANY }), rets = a_type("tuple", { ANY }) }), + ["__shr"] = a_type("function", { args = a_type("tuple", { ANY, ANY }), rets = a_type("tuple", { ANY }) }), + ["__concat"] = a_type("function", { args = a_type("tuple", { ANY, ANY }), rets = a_type("tuple", { ANY }) }), + ["__eq"] = a_type("function", { args = a_type("tuple", { ANY, ANY }), rets = a_type("tuple", { BOOLEAN }) }), + ["__lt"] = a_type("function", { args = a_type("tuple", { ANY, ANY }), rets = a_type("tuple", { BOOLEAN }) }), + ["__le"] = a_type("function", { args = a_type("tuple", { ANY, ANY }), rets = a_type("tuple", { BOOLEAN }) }), + ["__close"] = a_type("function", { args = a_type("tuple", { a }), rets = a_type("tuple", {}) }), }, } end), }), ["coroutine"] = a_record({ fields = { - ["create"] = a_function({ args = a_tuple({ FUNCTION }), rets = a_tuple({ THREAD }) }), - ["close"] = a_function({ args = a_tuple({ THREAD }), rets = a_tuple({ BOOLEAN, STRING }) }), - ["isyieldable"] = a_function({ args = a_tuple({}), rets = a_tuple({ BOOLEAN }) }), - ["resume"] = a_function({ args = a_vararg({ THREAD, ANY }), rets = a_vararg({ BOOLEAN, ANY }) }), - ["running"] = a_function({ args = a_tuple({}), rets = a_tuple({ THREAD, BOOLEAN }) }), - ["status"] = a_function({ args = a_tuple({ THREAD }), rets = a_tuple({ STRING }) }), - ["wrap"] = a_function({ args = a_tuple({ FUNCTION }), rets = a_tuple({ FUNCTION }) }), - ["yield"] = a_function({ args = a_vararg({ ANY }), rets = a_vararg({ ANY }) }), + ["create"] = a_type("function", { args = a_type("tuple", { FUNCTION }), rets = a_type("tuple", { THREAD }) }), + ["close"] = a_type("function", { args = a_type("tuple", { THREAD }), rets = a_type("tuple", { BOOLEAN, STRING }) }), + ["isyieldable"] = a_type("function", { args = a_type("tuple", {}), rets = a_type("tuple", { BOOLEAN }) }), + ["resume"] = a_type("function", { args = a_vararg({ THREAD, ANY }), rets = a_vararg({ BOOLEAN, ANY }) }), + ["running"] = a_type("function", { args = a_type("tuple", {}), rets = a_type("tuple", { THREAD, BOOLEAN }) }), + ["status"] = a_type("function", { args = a_type("tuple", { THREAD }), rets = a_type("tuple", { STRING }) }), + ["wrap"] = a_type("function", { args = a_type("tuple", { FUNCTION }), rets = a_type("tuple", { FUNCTION }) }), + ["yield"] = a_type("function", { args = a_vararg({ ANY }), rets = a_vararg({ ANY }) }), }, }), ["debug"] = a_record({ fields = { - ["Info"] = a_typetype({ def = DEBUG_GETINFO_TABLE }), - ["Hook"] = a_typetype({ def = DEBUG_HOOK_FUNCTION }), - ["HookEvent"] = a_typetype({ def = DEBUG_HOOK_EVENT }), - - ["debug"] = a_function({ args = a_tuple({}), rets = a_tuple({}) }), - ["gethook"] = a_function({ args = a_tuple({ OPT(THREAD) }), rets = a_tuple({ DEBUG_HOOK_FUNCTION, INTEGER }) }), - ["getlocal"] = a_poly({ - a_function({ args = a_tuple({ THREAD, FUNCTION, NUMBER }), rets = STRING }), - a_function({ args = a_tuple({ THREAD, NUMBER, NUMBER }), rets = a_tuple({ STRING, ANY }) }), - a_function({ args = a_tuple({ FUNCTION, NUMBER }), rets = STRING }), - a_function({ args = a_tuple({ NUMBER, NUMBER }), rets = a_tuple({ STRING, ANY }) }), - }), - ["getmetatable"] = a_gfunction(1, function(a) return { args = a_tuple({ a }), rets = a_tuple({ METATABLE(a) }) } end), - ["getregistry"] = a_function({ args = a_tuple({}), rets = a_tuple({ TABLE }) }), - ["getupvalue"] = a_function({ args = a_tuple({ FUNCTION, NUMBER }), rets = a_tuple({ ANY }) }), - ["getuservalue"] = a_function({ args = a_tuple({ USERDATA, NUMBER }), rets = a_tuple({ ANY }) }), - ["sethook"] = a_poly({ - a_function({ args = a_tuple({ THREAD, DEBUG_HOOK_FUNCTION, STRING, NUMBER }), rets = a_tuple({}) }), - a_function({ args = a_tuple({ DEBUG_HOOK_FUNCTION, STRING, NUMBER }), rets = a_tuple({}) }), - }), - ["setlocal"] = a_poly({ - a_function({ args = a_tuple({ THREAD, NUMBER, NUMBER, ANY }), rets = a_tuple({ STRING }) }), - a_function({ args = a_tuple({ NUMBER, NUMBER, ANY }), rets = a_tuple({ STRING }) }), - }), - ["setmetatable"] = a_gfunction(1, function(a) return { args = a_tuple({ a, METATABLE(a) }), rets = a_tuple({ a }) } end), - ["setupvalue"] = a_function({ args = a_tuple({ FUNCTION, NUMBER, ANY }), rets = a_tuple({ STRING }) }), - ["setuservalue"] = a_function({ args = a_tuple({ USERDATA, ANY, NUMBER }), rets = a_tuple({ USERDATA }) }), - ["traceback"] = a_poly({ - a_function({ args = a_tuple({ OPT(THREAD), OPT(STRING), OPT(NUMBER) }), rets = a_tuple({ STRING }) }), - a_function({ args = a_tuple({ OPT(STRING), OPT(NUMBER) }), rets = a_tuple({ STRING }) }), - }), - ["upvalueid"] = a_function({ args = a_tuple({ FUNCTION, NUMBER }), rets = a_tuple({ USERDATA }) }), - ["upvaluejoin"] = a_function({ args = a_tuple({ FUNCTION, NUMBER, FUNCTION, NUMBER }), rets = a_tuple({}) }), - ["getinfo"] = a_poly({ - a_function({ args = a_tuple({ ANY }), rets = a_tuple({ DEBUG_GETINFO_TABLE }) }), - a_function({ args = a_tuple({ ANY, STRING }), rets = a_tuple({ DEBUG_GETINFO_TABLE }) }), - a_function({ args = a_tuple({ ANY, ANY, STRING }), rets = a_tuple({ DEBUG_GETINFO_TABLE }) }), - }), + ["Info"] = a_type("typetype", { def = DEBUG_GETINFO_TABLE }), + ["Hook"] = a_type("typetype", { def = DEBUG_HOOK_FUNCTION }), + ["HookEvent"] = a_type("typetype", { def = DEBUG_HOOK_EVENT }), + + ["debug"] = a_type("function", { args = a_type("tuple", {}), rets = a_type("tuple", {}) }), + ["gethook"] = a_type("function", { args = a_type("tuple", { OPT(THREAD) }), rets = a_type("tuple", { DEBUG_HOOK_FUNCTION, INTEGER }) }), + ["getlocal"] = a_type("poly", { types = { + a_type("function", { args = a_type("tuple", { THREAD, FUNCTION, NUMBER }), rets = STRING }), + a_type("function", { args = a_type("tuple", { THREAD, NUMBER, NUMBER }), rets = a_type("tuple", { STRING, ANY }) }), + a_type("function", { args = a_type("tuple", { FUNCTION, NUMBER }), rets = STRING }), + a_type("function", { args = a_type("tuple", { NUMBER, NUMBER }), rets = a_type("tuple", { STRING, ANY }) }), + } }), + ["getmetatable"] = a_gfunction(1, function(a) return { args = a_type("tuple", { a }), rets = a_type("tuple", { METATABLE(a) }) } end), + ["getregistry"] = a_type("function", { args = a_type("tuple", {}), rets = a_type("tuple", { TABLE }) }), + ["getupvalue"] = a_type("function", { args = a_type("tuple", { FUNCTION, NUMBER }), rets = a_type("tuple", { ANY }) }), + ["getuservalue"] = a_type("function", { args = a_type("tuple", { USERDATA, NUMBER }), rets = a_type("tuple", { ANY }) }), + ["sethook"] = a_type("poly", { types = { + a_type("function", { args = a_type("tuple", { THREAD, DEBUG_HOOK_FUNCTION, STRING, NUMBER }), rets = a_type("tuple", {}) }), + a_type("function", { args = a_type("tuple", { DEBUG_HOOK_FUNCTION, STRING, NUMBER }), rets = a_type("tuple", {}) }), + } }), + ["setlocal"] = a_type("poly", { types = { + a_type("function", { args = a_type("tuple", { THREAD, NUMBER, NUMBER, ANY }), rets = a_type("tuple", { STRING }) }), + a_type("function", { args = a_type("tuple", { NUMBER, NUMBER, ANY }), rets = a_type("tuple", { STRING }) }), + } }), + ["setmetatable"] = a_gfunction(1, function(a) return { args = a_type("tuple", { a, METATABLE(a) }), rets = a_type("tuple", { a }) } end), + ["setupvalue"] = a_type("function", { args = a_type("tuple", { FUNCTION, NUMBER, ANY }), rets = a_type("tuple", { STRING }) }), + ["setuservalue"] = a_type("function", { args = a_type("tuple", { USERDATA, ANY, NUMBER }), rets = a_type("tuple", { USERDATA }) }), + ["traceback"] = a_type("poly", { types = { + a_type("function", { args = a_type("tuple", { OPT(THREAD), OPT(STRING), OPT(NUMBER) }), rets = a_type("tuple", { STRING }) }), + a_type("function", { args = a_type("tuple", { OPT(STRING), OPT(NUMBER) }), rets = a_type("tuple", { STRING }) }), + } }), + ["upvalueid"] = a_type("function", { args = a_type("tuple", { FUNCTION, NUMBER }), rets = a_type("tuple", { USERDATA }) }), + ["upvaluejoin"] = a_type("function", { args = a_type("tuple", { FUNCTION, NUMBER, FUNCTION, NUMBER }), rets = a_type("tuple", {}) }), + ["getinfo"] = a_type("poly", { types = { + a_type("function", { args = a_type("tuple", { ANY }), rets = a_type("tuple", { DEBUG_GETINFO_TABLE }) }), + a_type("function", { args = a_type("tuple", { ANY, STRING }), rets = a_type("tuple", { DEBUG_GETINFO_TABLE }) }), + a_type("function", { args = a_type("tuple", { ANY, ANY, STRING }), rets = a_type("tuple", { DEBUG_GETINFO_TABLE }) }), + } }), }, }), ["io"] = a_record({ fields = { - ["close"] = a_function({ args = a_tuple({ OPT(NOMINAL_FILE) }), rets = a_tuple({ BOOLEAN, STRING }) }), - ["flush"] = a_function({ args = a_tuple({}), rets = a_tuple({}) }), - ["input"] = a_function({ args = a_tuple({ OPT(a_union({ STRING, NOMINAL_FILE })) }), rets = a_tuple({ NOMINAL_FILE }) }), + ["close"] = a_type("function", { args = a_type("tuple", { OPT(NOMINAL_FILE) }), rets = a_type("tuple", { BOOLEAN, STRING }) }), + ["flush"] = a_type("function", { args = a_type("tuple", {}), rets = a_type("tuple", {}) }), + ["input"] = a_type("function", { args = a_type("tuple", { OPT(a_type("union", { types = { STRING, NOMINAL_FILE } })) }), rets = a_type("tuple", { NOMINAL_FILE }) }), ["lines"] = a_file_reader(function(ctor, args, rets) - return a_function({ args = ctor(args), rets = a_tuple({ - a_function({ args = a_tuple({}), rets = ctor(rets) }), + return a_type("function", { args = ctor(args), rets = a_type("tuple", { + a_type("function", { args = a_type("tuple", {}), rets = ctor(rets) }), }), }) end), - ["open"] = a_function({ args = a_tuple({ STRING, OPT(STRING) }), rets = a_tuple({ NOMINAL_FILE, STRING }) }), - ["output"] = a_function({ args = a_tuple({ OPT(a_union({ STRING, NOMINAL_FILE })) }), rets = a_tuple({ NOMINAL_FILE }) }), - ["popen"] = a_function({ args = a_tuple({ STRING, OPT(STRING) }), rets = a_tuple({ NOMINAL_FILE, STRING }) }), + ["open"] = a_type("function", { args = a_type("tuple", { STRING, OPT(STRING) }), rets = a_type("tuple", { NOMINAL_FILE, STRING }) }), + ["output"] = a_type("function", { args = a_type("tuple", { OPT(a_type("union", { types = { STRING, NOMINAL_FILE } })) }), rets = a_type("tuple", { NOMINAL_FILE }) }), + ["popen"] = a_type("function", { args = a_type("tuple", { STRING, OPT(STRING) }), rets = a_type("tuple", { NOMINAL_FILE, STRING }) }), ["read"] = a_file_reader(function(ctor, args, rets) - return a_function({ args = ctor(args), rets = ctor(rets) }) + return a_type("function", { args = ctor(args), rets = ctor(rets) }) end), ["stderr"] = NOMINAL_FILE, ["stdin"] = NOMINAL_FILE, ["stdout"] = NOMINAL_FILE, - ["tmpfile"] = a_function({ args = a_tuple({}), rets = a_tuple({ NOMINAL_FILE }) }), - ["type"] = a_function({ args = a_tuple({ ANY }), rets = a_tuple({ STRING }) }), - ["write"] = a_function({ args = a_vararg({ a_union({ STRING, NUMBER }) }), rets = a_tuple({ NOMINAL_FILE, STRING }) }), + ["tmpfile"] = a_type("function", { args = a_type("tuple", {}), rets = a_type("tuple", { NOMINAL_FILE }) }), + ["type"] = a_type("function", { args = a_type("tuple", { ANY }), rets = a_type("tuple", { STRING }) }), + ["write"] = a_type("function", { args = a_vararg({ a_type("union", { types = { STRING, NUMBER } }) }), rets = a_type("tuple", { NOMINAL_FILE, STRING }) }), }, }), ["math"] = a_record({ fields = { - ["abs"] = a_poly({ - a_function({ args = a_tuple({ INTEGER }), rets = a_tuple({ INTEGER }) }), - a_function({ args = a_tuple({ NUMBER }), rets = a_tuple({ NUMBER }) }), - }), - ["acos"] = a_function({ args = a_tuple({ NUMBER }), rets = a_tuple({ NUMBER }) }), - ["asin"] = a_function({ args = a_tuple({ NUMBER }), rets = a_tuple({ NUMBER }) }), - ["atan"] = a_function({ args = a_tuple({ NUMBER, OPT(NUMBER) }), rets = a_tuple({ NUMBER }) }), - ["atan2"] = a_function({ args = a_tuple({ NUMBER, NUMBER }), rets = a_tuple({ NUMBER }) }), - ["ceil"] = a_function({ args = a_tuple({ NUMBER }), rets = a_tuple({ INTEGER }) }), - ["cos"] = a_function({ args = a_tuple({ NUMBER }), rets = a_tuple({ NUMBER }) }), - ["cosh"] = a_function({ args = a_tuple({ NUMBER }), rets = a_tuple({ NUMBER }) }), - ["deg"] = a_function({ args = a_tuple({ NUMBER }), rets = a_tuple({ NUMBER }) }), - ["exp"] = a_function({ args = a_tuple({ NUMBER }), rets = a_tuple({ NUMBER }) }), - ["floor"] = a_function({ args = a_tuple({ NUMBER }), rets = a_tuple({ INTEGER }) }), - ["fmod"] = a_poly({ - a_function({ args = a_tuple({ INTEGER, INTEGER }), rets = a_tuple({ INTEGER }) }), - a_function({ args = a_tuple({ NUMBER, NUMBER }), rets = a_tuple({ NUMBER }) }), - }), - ["frexp"] = a_function({ args = a_tuple({ NUMBER }), rets = a_tuple({ NUMBER, NUMBER }) }), + ["abs"] = a_type("poly", { types = { + a_type("function", { args = a_type("tuple", { INTEGER }), rets = a_type("tuple", { INTEGER }) }), + a_type("function", { args = a_type("tuple", { NUMBER }), rets = a_type("tuple", { NUMBER }) }), + } }), + ["acos"] = a_type("function", { args = a_type("tuple", { NUMBER }), rets = a_type("tuple", { NUMBER }) }), + ["asin"] = a_type("function", { args = a_type("tuple", { NUMBER }), rets = a_type("tuple", { NUMBER }) }), + ["atan"] = a_type("function", { args = a_type("tuple", { NUMBER, OPT(NUMBER) }), rets = a_type("tuple", { NUMBER }) }), + ["atan2"] = a_type("function", { args = a_type("tuple", { NUMBER, NUMBER }), rets = a_type("tuple", { NUMBER }) }), + ["ceil"] = a_type("function", { args = a_type("tuple", { NUMBER }), rets = a_type("tuple", { INTEGER }) }), + ["cos"] = a_type("function", { args = a_type("tuple", { NUMBER }), rets = a_type("tuple", { NUMBER }) }), + ["cosh"] = a_type("function", { args = a_type("tuple", { NUMBER }), rets = a_type("tuple", { NUMBER }) }), + ["deg"] = a_type("function", { args = a_type("tuple", { NUMBER }), rets = a_type("tuple", { NUMBER }) }), + ["exp"] = a_type("function", { args = a_type("tuple", { NUMBER }), rets = a_type("tuple", { NUMBER }) }), + ["floor"] = a_type("function", { args = a_type("tuple", { NUMBER }), rets = a_type("tuple", { INTEGER }) }), + ["fmod"] = a_type("poly", { types = { + a_type("function", { args = a_type("tuple", { INTEGER, INTEGER }), rets = a_type("tuple", { INTEGER }) }), + a_type("function", { args = a_type("tuple", { NUMBER, NUMBER }), rets = a_type("tuple", { NUMBER }) }), + } }), + ["frexp"] = a_type("function", { args = a_type("tuple", { NUMBER }), rets = a_type("tuple", { NUMBER, NUMBER }) }), ["huge"] = NUMBER, - ["ldexp"] = a_function({ args = a_tuple({ NUMBER, NUMBER }), rets = a_tuple({ NUMBER }) }), - ["log"] = a_function({ args = a_tuple({ NUMBER, OPT(NUMBER) }), rets = a_tuple({ NUMBER }) }), - ["log10"] = a_function({ args = a_tuple({ NUMBER }), rets = a_tuple({ NUMBER }) }), - ["max"] = a_poly({ - a_function({ args = a_vararg({ INTEGER }), rets = a_tuple({ INTEGER }) }), - a_gfunction(1, function(a) return { args = a_vararg({ a }), rets = a_tuple({ a }) } end), - a_function({ args = a_vararg({ a_union({ NUMBER, INTEGER }) }), rets = a_tuple({ NUMBER }) }), - a_function({ args = a_vararg({ ANY }), rets = a_tuple({ ANY }) }), - }), + ["ldexp"] = a_type("function", { args = a_type("tuple", { NUMBER, NUMBER }), rets = a_type("tuple", { NUMBER }) }), + ["log"] = a_type("function", { args = a_type("tuple", { NUMBER, OPT(NUMBER) }), rets = a_type("tuple", { NUMBER }) }), + ["log10"] = a_type("function", { args = a_type("tuple", { NUMBER }), rets = a_type("tuple", { NUMBER }) }), + ["max"] = a_type("poly", { types = { + a_type("function", { args = a_vararg({ INTEGER }), rets = a_type("tuple", { INTEGER }) }), + a_gfunction(1, function(a) return { args = a_vararg({ a }), rets = a_type("tuple", { a }) } end), + a_type("function", { args = a_vararg({ a_type("union", { types = { NUMBER, INTEGER } }) }), rets = a_type("tuple", { NUMBER }) }), + a_type("function", { args = a_vararg({ ANY }), rets = a_type("tuple", { ANY }) }), + } }), ["maxinteger"] = a_type("integer", { needs_compat = true }), - ["min"] = a_poly({ - a_function({ args = a_vararg({ INTEGER }), rets = a_tuple({ INTEGER }) }), - a_gfunction(1, function(a) return { args = a_vararg({ a }), rets = a_tuple({ a }) } end), - a_function({ args = a_vararg({ a_union({ NUMBER, INTEGER }) }), rets = a_tuple({ NUMBER }) }), - a_function({ args = a_vararg({ ANY }), rets = a_tuple({ ANY }) }), - }), + ["min"] = a_type("poly", { types = { + a_type("function", { args = a_vararg({ INTEGER }), rets = a_type("tuple", { INTEGER }) }), + a_gfunction(1, function(a) return { args = a_vararg({ a }), rets = a_type("tuple", { a }) } end), + a_type("function", { args = a_vararg({ a_type("union", { types = { NUMBER, INTEGER } }) }), rets = a_type("tuple", { NUMBER }) }), + a_type("function", { args = a_vararg({ ANY }), rets = a_type("tuple", { ANY }) }), + } }), ["mininteger"] = a_type("integer", { needs_compat = true }), - ["modf"] = a_function({ args = a_tuple({ NUMBER }), rets = a_tuple({ INTEGER, NUMBER }) }), + ["modf"] = a_type("function", { args = a_type("tuple", { NUMBER }), rets = a_type("tuple", { INTEGER, NUMBER }) }), ["pi"] = NUMBER, - ["pow"] = a_function({ args = a_tuple({ NUMBER, NUMBER }), rets = a_tuple({ NUMBER }) }), - ["rad"] = a_function({ args = a_tuple({ NUMBER }), rets = a_tuple({ NUMBER }) }), - ["random"] = a_poly({ - a_function({ args = a_tuple({ NUMBER, OPT(NUMBER) }), rets = a_tuple({ INTEGER }) }), - a_function({ args = a_tuple({}), rets = a_tuple({ NUMBER }) }), - }), - ["randomseed"] = a_function({ args = a_tuple({ NUMBER, NUMBER }), rets = a_tuple({ INTEGER, INTEGER }) }), - ["sin"] = a_function({ args = a_tuple({ NUMBER }), rets = a_tuple({ NUMBER }) }), - ["sinh"] = a_function({ args = a_tuple({ NUMBER }), rets = a_tuple({ NUMBER }) }), - ["sqrt"] = a_function({ args = a_tuple({ NUMBER }), rets = a_tuple({ NUMBER }) }), - ["tan"] = a_function({ args = a_tuple({ NUMBER }), rets = a_tuple({ NUMBER }) }), - ["tanh"] = a_function({ args = a_tuple({ NUMBER }), rets = a_tuple({ NUMBER }) }), - ["tointeger"] = a_function({ args = a_tuple({ ANY }), rets = a_tuple({ INTEGER }) }), - ["type"] = a_function({ args = a_tuple({ ANY }), rets = a_tuple({ STRING }) }), - ["ult"] = a_function({ args = a_tuple({ NUMBER, NUMBER }), rets = a_tuple({ BOOLEAN }) }), + ["pow"] = a_type("function", { args = a_type("tuple", { NUMBER, NUMBER }), rets = a_type("tuple", { NUMBER }) }), + ["rad"] = a_type("function", { args = a_type("tuple", { NUMBER }), rets = a_type("tuple", { NUMBER }) }), + ["random"] = a_type("poly", { types = { + a_type("function", { args = a_type("tuple", { NUMBER, OPT(NUMBER) }), rets = a_type("tuple", { INTEGER }) }), + a_type("function", { args = a_type("tuple", {}), rets = a_type("tuple", { NUMBER }) }), + } }), + ["randomseed"] = a_type("function", { args = a_type("tuple", { NUMBER, NUMBER }), rets = a_type("tuple", { INTEGER, INTEGER }) }), + ["sin"] = a_type("function", { args = a_type("tuple", { NUMBER }), rets = a_type("tuple", { NUMBER }) }), + ["sinh"] = a_type("function", { args = a_type("tuple", { NUMBER }), rets = a_type("tuple", { NUMBER }) }), + ["sqrt"] = a_type("function", { args = a_type("tuple", { NUMBER }), rets = a_type("tuple", { NUMBER }) }), + ["tan"] = a_type("function", { args = a_type("tuple", { NUMBER }), rets = a_type("tuple", { NUMBER }) }), + ["tanh"] = a_type("function", { args = a_type("tuple", { NUMBER }), rets = a_type("tuple", { NUMBER }) }), + ["tointeger"] = a_type("function", { args = a_type("tuple", { ANY }), rets = a_type("tuple", { INTEGER }) }), + ["type"] = a_type("function", { args = a_type("tuple", { ANY }), rets = a_type("tuple", { STRING }) }), + ["ult"] = a_type("function", { args = a_type("tuple", { NUMBER, NUMBER }), rets = a_type("tuple", { BOOLEAN }) }), }, }), ["os"] = a_record({ fields = { - ["clock"] = a_function({ args = a_tuple({}), rets = a_tuple({ NUMBER }) }), - ["date"] = a_poly({ - a_function({ args = a_tuple({}), rets = a_tuple({ STRING }) }), - a_function({ args = a_tuple({ an_enum({ "!*t", "*t" }), OPT(NUMBER) }), rets = a_tuple({ OS_DATE_TABLE }) }), - a_function({ args = a_tuple({ OPT(STRING), OPT(NUMBER) }), rets = a_tuple({ STRING }) }), - }), - ["difftime"] = a_function({ args = a_tuple({ NUMBER, NUMBER }), rets = a_tuple({ NUMBER }) }), - ["execute"] = a_function({ args = a_tuple({ STRING }), rets = a_tuple({ BOOLEAN, STRING, INTEGER }) }), - ["exit"] = a_function({ args = a_tuple({ OPT(a_union({ NUMBER, BOOLEAN })), OPT(BOOLEAN) }), rets = a_tuple({}) }), - ["getenv"] = a_function({ args = a_tuple({ STRING }), rets = a_tuple({ STRING }) }), - ["remove"] = a_function({ args = a_tuple({ STRING }), rets = a_tuple({ BOOLEAN, STRING }) }), - ["rename"] = a_function({ args = a_tuple({ STRING, STRING }), rets = a_tuple({ BOOLEAN, STRING }) }), - ["setlocale"] = a_function({ args = a_tuple({ STRING, OPT(STRING) }), rets = a_tuple({ STRING }) }), - ["time"] = a_function({ args = a_tuple({ OPT(OS_DATE_TABLE) }), rets = a_tuple({ INTEGER }) }), - ["tmpname"] = a_function({ args = a_tuple({}), rets = a_tuple({ STRING }) }), + ["clock"] = a_type("function", { args = a_type("tuple", {}), rets = a_type("tuple", { NUMBER }) }), + ["date"] = a_type("poly", { types = { + a_type("function", { args = a_type("tuple", {}), rets = a_type("tuple", { STRING }) }), + a_type("function", { args = a_type("tuple", { an_enum({ "!*t", "*t" }), OPT(NUMBER) }), rets = a_type("tuple", { OS_DATE_TABLE }) }), + a_type("function", { args = a_type("tuple", { OPT(STRING), OPT(NUMBER) }), rets = a_type("tuple", { STRING }) }), + } }), + ["difftime"] = a_type("function", { args = a_type("tuple", { NUMBER, NUMBER }), rets = a_type("tuple", { NUMBER }) }), + ["execute"] = a_type("function", { args = a_type("tuple", { STRING }), rets = a_type("tuple", { BOOLEAN, STRING, INTEGER }) }), + ["exit"] = a_type("function", { args = a_type("tuple", { OPT(a_type("union", { types = { NUMBER, BOOLEAN } })), OPT(BOOLEAN) }), rets = a_type("tuple", {}) }), + ["getenv"] = a_type("function", { args = a_type("tuple", { STRING }), rets = a_type("tuple", { STRING }) }), + ["remove"] = a_type("function", { args = a_type("tuple", { STRING }), rets = a_type("tuple", { BOOLEAN, STRING }) }), + ["rename"] = a_type("function", { args = a_type("tuple", { STRING, STRING }), rets = a_type("tuple", { BOOLEAN, STRING }) }), + ["setlocale"] = a_type("function", { args = a_type("tuple", { STRING, OPT(STRING) }), rets = a_type("tuple", { STRING }) }), + ["time"] = a_type("function", { args = a_type("tuple", { OPT(OS_DATE_TABLE) }), rets = a_type("tuple", { INTEGER }) }), + ["tmpname"] = a_type("function", { args = a_type("tuple", {}), rets = a_type("tuple", { STRING }) }), }, }), ["package"] = a_record({ fields = { ["config"] = STRING, ["cpath"] = STRING, - ["loaded"] = a_map(STRING, ANY), - ["loaders"] = an_array(a_function({ args = a_tuple({ STRING }), rets = a_tuple({ ANY, ANY }) })), - ["loadlib"] = a_function({ args = a_tuple({ STRING, STRING }), rets = a_tuple({ FUNCTION }) }), + ["loaded"] = a_type("map", { keys = STRING, values = ANY }), + ["loaders"] = a_type("array", { elements = a_type("function", { args = a_type("tuple", { STRING }), rets = a_type("tuple", { ANY, ANY }) }) }), + ["loadlib"] = a_type("function", { args = a_type("tuple", { STRING, STRING }), rets = a_type("tuple", { FUNCTION }) }), ["path"] = STRING, ["preload"] = TABLE, - ["searchers"] = an_array(a_function({ args = a_tuple({ STRING }), rets = a_tuple({ ANY, ANY }) })), - ["searchpath"] = a_function({ args = a_tuple({ STRING, STRING, OPT(STRING), OPT(STRING) }), rets = a_tuple({ STRING, STRING }) }), + ["searchers"] = a_type("array", { elements = a_type("function", { args = a_type("tuple", { STRING }), rets = a_type("tuple", { ANY, ANY }) }) }), + ["searchpath"] = a_type("function", { args = a_type("tuple", { STRING, STRING, OPT(STRING), OPT(STRING) }), rets = a_type("tuple", { STRING, STRING }) }), }, }), ["string"] = a_record({ fields = { - ["byte"] = a_poly({ - a_function({ args = a_tuple({ STRING, OPT(NUMBER) }), rets = a_tuple({ INTEGER }) }), - a_function({ args = a_tuple({ STRING, NUMBER, NUMBER }), rets = a_vararg({ INTEGER }) }), - }), - ["char"] = a_function({ args = a_vararg({ NUMBER }), rets = a_tuple({ STRING }) }), - ["dump"] = a_function({ args = a_tuple({ FUNCTION, OPT(BOOLEAN) }), rets = a_tuple({ STRING }) }), - ["find"] = a_function({ args = a_tuple({ STRING, STRING, OPT(NUMBER), OPT(BOOLEAN) }), rets = a_vararg({ INTEGER, INTEGER, STRING }) }), - ["format"] = a_function({ args = a_vararg({ STRING, ANY }), rets = a_tuple({ STRING }) }), - ["gmatch"] = a_function({ args = a_tuple({ STRING, STRING }), rets = a_tuple({ - a_function({ args = a_tuple({}), rets = a_vararg({ STRING }) }), + ["byte"] = a_type("poly", { types = { + a_type("function", { args = a_type("tuple", { STRING, OPT(NUMBER) }), rets = a_type("tuple", { INTEGER }) }), + a_type("function", { args = a_type("tuple", { STRING, NUMBER, NUMBER }), rets = a_vararg({ INTEGER }) }), + } }), + ["char"] = a_type("function", { args = a_vararg({ NUMBER }), rets = a_type("tuple", { STRING }) }), + ["dump"] = a_type("function", { args = a_type("tuple", { FUNCTION, OPT(BOOLEAN) }), rets = a_type("tuple", { STRING }) }), + ["find"] = a_type("function", { args = a_type("tuple", { STRING, STRING, OPT(NUMBER), OPT(BOOLEAN) }), rets = a_vararg({ INTEGER, INTEGER, STRING }) }), + ["format"] = a_type("function", { args = a_vararg({ STRING, ANY }), rets = a_type("tuple", { STRING }) }), + ["gmatch"] = a_type("function", { args = a_type("tuple", { STRING, STRING }), rets = a_type("tuple", { + a_type("function", { args = a_type("tuple", {}), rets = a_vararg({ STRING }) }), }), }), - ["gsub"] = a_poly({ - a_function({ args = a_tuple({ STRING, STRING, a_map(STRING, STRING), OPT(NUMBER) }), rets = a_tuple({ STRING, INTEGER }) }), - a_function({ args = a_tuple({ STRING, STRING, a_function({ args = a_vararg({ STRING }), rets = a_tuple({ STRING }) }), OPT(NUMBER) }), rets = a_tuple({ STRING, INTEGER }) }), - a_function({ args = a_tuple({ STRING, STRING, a_function({ args = a_vararg({ STRING }), rets = a_tuple({ NUMBER }) }), OPT(NUMBER) }), rets = a_tuple({ STRING, INTEGER }) }), - a_function({ args = a_tuple({ STRING, STRING, a_function({ args = a_vararg({ STRING }), rets = a_tuple({ BOOLEAN }) }), OPT(NUMBER) }), rets = a_tuple({ STRING, INTEGER }) }), - a_function({ args = a_tuple({ STRING, STRING, a_function({ args = a_vararg({ STRING }), rets = a_tuple({}) }), OPT(NUMBER) }), rets = a_tuple({ STRING, INTEGER }) }), - a_function({ args = a_tuple({ STRING, STRING, OPT(STRING), OPT(NUMBER) }), rets = a_tuple({ STRING, INTEGER }) }), - - }), - ["len"] = a_function({ args = a_tuple({ STRING }), rets = a_tuple({ INTEGER }) }), - ["lower"] = a_function({ args = a_tuple({ STRING }), rets = a_tuple({ STRING }) }), - ["match"] = a_function({ args = a_tuple({ STRING, OPT(STRING), OPT(NUMBER) }), rets = a_vararg({ STRING }) }), - ["pack"] = a_function({ args = a_vararg({ STRING, ANY }), rets = a_tuple({ STRING }) }), - ["packsize"] = a_function({ args = a_tuple({ STRING }), rets = a_tuple({ INTEGER }) }), - ["rep"] = a_function({ args = a_tuple({ STRING, NUMBER, OPT(STRING) }), rets = a_tuple({ STRING }) }), - ["reverse"] = a_function({ args = a_tuple({ STRING }), rets = a_tuple({ STRING }) }), - ["sub"] = a_function({ args = a_tuple({ STRING, NUMBER, OPT(NUMBER) }), rets = a_tuple({ STRING }) }), - ["unpack"] = a_function({ args = a_tuple({ STRING, STRING, OPT(NUMBER) }), rets = a_vararg({ ANY }) }), - ["upper"] = a_function({ args = a_tuple({ STRING }), rets = a_tuple({ STRING }) }), + ["gsub"] = a_type("poly", { types = { + a_type("function", { args = a_type("tuple", { STRING, STRING, a_type("map", { keys = STRING, values = STRING }), OPT(NUMBER) }), rets = a_type("tuple", { STRING, INTEGER }) }), + a_type("function", { args = a_type("tuple", { STRING, STRING, a_type("function", { args = a_vararg({ STRING }), rets = a_type("tuple", { STRING }) }), OPT(NUMBER) }), rets = a_type("tuple", { STRING, INTEGER }) }), + a_type("function", { args = a_type("tuple", { STRING, STRING, a_type("function", { args = a_vararg({ STRING }), rets = a_type("tuple", { NUMBER }) }), OPT(NUMBER) }), rets = a_type("tuple", { STRING, INTEGER }) }), + a_type("function", { args = a_type("tuple", { STRING, STRING, a_type("function", { args = a_vararg({ STRING }), rets = a_type("tuple", { BOOLEAN }) }), OPT(NUMBER) }), rets = a_type("tuple", { STRING, INTEGER }) }), + a_type("function", { args = a_type("tuple", { STRING, STRING, a_type("function", { args = a_vararg({ STRING }), rets = a_type("tuple", {}) }), OPT(NUMBER) }), rets = a_type("tuple", { STRING, INTEGER }) }), + a_type("function", { args = a_type("tuple", { STRING, STRING, OPT(STRING), OPT(NUMBER) }), rets = a_type("tuple", { STRING, INTEGER }) }), + + } }), + ["len"] = a_type("function", { args = a_type("tuple", { STRING }), rets = a_type("tuple", { INTEGER }) }), + ["lower"] = a_type("function", { args = a_type("tuple", { STRING }), rets = a_type("tuple", { STRING }) }), + ["match"] = a_type("function", { args = a_type("tuple", { STRING, OPT(STRING), OPT(NUMBER) }), rets = a_vararg({ STRING }) }), + ["pack"] = a_type("function", { args = a_vararg({ STRING, ANY }), rets = a_type("tuple", { STRING }) }), + ["packsize"] = a_type("function", { args = a_type("tuple", { STRING }), rets = a_type("tuple", { INTEGER }) }), + ["rep"] = a_type("function", { args = a_type("tuple", { STRING, NUMBER, OPT(STRING) }), rets = a_type("tuple", { STRING }) }), + ["reverse"] = a_type("function", { args = a_type("tuple", { STRING }), rets = a_type("tuple", { STRING }) }), + ["sub"] = a_type("function", { args = a_type("tuple", { STRING, NUMBER, OPT(NUMBER) }), rets = a_type("tuple", { STRING }) }), + ["unpack"] = a_type("function", { args = a_type("tuple", { STRING, STRING, OPT(NUMBER) }), rets = a_vararg({ ANY }) }), + ["upper"] = a_type("function", { args = a_type("tuple", { STRING }), rets = a_type("tuple", { STRING }) }), }, }), ["table"] = a_record({ fields = { - ["concat"] = a_function({ args = a_tuple({ an_array(a_union({ STRING, NUMBER })), OPT(STRING), OPT(NUMBER), OPT(NUMBER) }), rets = a_tuple({ STRING }) }), - ["insert"] = a_poly({ - a_gfunction(1, function(a) return { args = a_tuple({ an_array(a), NUMBER, a }), rets = a_tuple({}) } end), - a_gfunction(1, function(a) return { args = a_tuple({ an_array(a), a }), rets = a_tuple({}) } end), - }), - ["move"] = a_poly({ - a_gfunction(1, function(a) return { args = a_tuple({ an_array(a), NUMBER, NUMBER, NUMBER }), rets = a_tuple({ an_array(a) }) } end), - a_gfunction(1, function(a) return { args = a_tuple({ an_array(a), NUMBER, NUMBER, NUMBER, an_array(a) }), rets = a_tuple({ an_array(a) }) } end), - }), - ["pack"] = a_function({ args = a_vararg({ ANY }), rets = a_tuple({ TABLE }) }), - ["remove"] = a_gfunction(1, function(a) return { args = a_tuple({ an_array(a), OPT(NUMBER) }), rets = a_tuple({ a }) } end), - ["sort"] = a_gfunction(1, function(a) return { args = a_tuple({ an_array(a), OPT(TABLE_SORT_FUNCTION) }), rets = a_tuple({}) } end), - ["unpack"] = a_gfunction(1, function(a) return { needs_compat = true, args = a_tuple({ an_array(a), OPT(NUMBER), OPT(NUMBER) }), rets = a_vararg({ a }) } end), + ["concat"] = a_type("function", { args = a_type("tuple", { a_type("array", { elements = a_type("union", { types = { STRING, NUMBER } }) }), OPT(STRING), OPT(NUMBER), OPT(NUMBER) }), rets = a_type("tuple", { STRING }) }), + ["insert"] = a_type("poly", { types = { + a_gfunction(1, function(a) return { args = a_type("tuple", { a_type("array", { elements = a }), NUMBER, a }), rets = a_type("tuple", {}) } end), + a_gfunction(1, function(a) return { args = a_type("tuple", { a_type("array", { elements = a }), a }), rets = a_type("tuple", {}) } end), + } }), + ["move"] = a_type("poly", { types = { + a_gfunction(1, function(a) return { args = a_type("tuple", { a_type("array", { elements = a }), NUMBER, NUMBER, NUMBER }), rets = a_type("tuple", { a_type("array", { elements = a }) }) } end), + a_gfunction(1, function(a) return { args = a_type("tuple", { a_type("array", { elements = a }), NUMBER, NUMBER, NUMBER, a_type("array", { elements = a }) }), rets = a_type("tuple", { a_type("array", { elements = a }) }) } end), + } }), + ["pack"] = a_type("function", { args = a_vararg({ ANY }), rets = a_type("tuple", { TABLE }) }), + ["remove"] = a_gfunction(1, function(a) return { args = a_type("tuple", { a_type("array", { elements = a }), OPT(NUMBER) }), rets = a_type("tuple", { a }) } end), + ["sort"] = a_gfunction(1, function(a) return { args = a_type("tuple", { a_type("array", { elements = a }), OPT(TABLE_SORT_FUNCTION) }), rets = a_type("tuple", {}) } end), + ["unpack"] = a_gfunction(1, function(a) return { needs_compat = true, args = a_type("tuple", { a_type("array", { elements = a }), OPT(NUMBER), OPT(NUMBER) }), rets = a_vararg({ a }) } end), }, }), ["utf8"] = a_record({ fields = { - ["char"] = a_function({ args = a_vararg({ NUMBER }), rets = a_tuple({ STRING }) }), + ["char"] = a_type("function", { args = a_vararg({ NUMBER }), rets = a_type("tuple", { STRING }) }), ["charpattern"] = STRING, - ["codepoint"] = a_function({ args = a_tuple({ STRING, OPT(NUMBER), OPT(NUMBER) }), rets = a_vararg({ INTEGER }) }), - ["codes"] = a_function({ args = a_tuple({ STRING }), rets = a_tuple({ - a_function({ args = a_tuple({ STRING, OPT(NUMBER) }), rets = a_tuple({ NUMBER, NUMBER }) }), + ["codepoint"] = a_type("function", { args = a_type("tuple", { STRING, OPT(NUMBER), OPT(NUMBER) }), rets = a_vararg({ INTEGER }) }), + ["codes"] = a_type("function", { args = a_type("tuple", { STRING }), rets = a_type("tuple", { + a_type("function", { args = a_type("tuple", { STRING, OPT(NUMBER) }), rets = a_type("tuple", { NUMBER, NUMBER }) }), }), }), - ["len"] = a_function({ args = a_tuple({ STRING, NUMBER, NUMBER }), rets = a_tuple({ INTEGER }) }), - ["offset"] = a_function({ args = a_tuple({ STRING, NUMBER, NUMBER }), rets = a_tuple({ INTEGER }) }), + ["len"] = a_type("function", { args = a_type("tuple", { STRING, NUMBER, NUMBER }), rets = a_type("tuple", { INTEGER }) }), + ["offset"] = a_type("function", { args = a_type("tuple", { STRING, NUMBER, NUMBER }), rets = a_type("tuple", { INTEGER }) }), }, }), ["_VERSION"] = STRING, @@ -6155,7 +6180,7 @@ tl.type_check = function(ast, opts) end if opts.module_name then - env.modules[opts.module_name] = a_typetype({ def = CIRCULAR_REQUIRE }) + env.modules[opts.module_name] = a_type("typetype", { def = CIRCULAR_REQUIRE }) end local lax = opts.lax @@ -7241,7 +7266,7 @@ tl.type_check = function(ast, opts) if #ts == 1 then return ts[1] else - return a_union(ts) + return a_type("union", { types = ts }) end end @@ -7266,13 +7291,13 @@ tl.type_check = function(ast, opts) local element_type = unite(tupletype.types, true) local valid = element_type.typename ~= "union" and true or is_valid_union(element_type) if valid then - return an_array(element_type) + return a_type("array", { elements = element_type }) end - local arr_type = an_array(tupletype.types[1]) + local arr_type = a_type("array", { elements = tupletype.types[1] }) for i = 2, #tupletype.types do - arr_type = expand_type(where, arr_type, an_array(tupletype.types[i])) + arr_type = expand_type(where, arr_type, a_type("array", { elements = tupletype.types[i] })) if not arr_type.elements then return nil, { Err(tupletype, "unable to convert tuple %s to array", tupletype) } end @@ -7794,7 +7819,7 @@ a.types[i], b.types[i]), } ["bad_nominal"] = compare_false, ["any"] = compare_true, ["tuple"] = function(a, b) - return is_a(a_tuple({ a }), b) + return is_a(a_type("tuple", { a }), b) end, ["typevar"] = function(a, b) return compare_or_infer_typevar(b.typevar, a, nil, is_a) @@ -7923,9 +7948,9 @@ a.types[i], b.types[i]), } return true elseif t2.typename == "unresolved_emptytable_value" then if is_number_type(t2.emptytable_type.keys) then - infer_emptytable(t2.emptytable_type, infer_at(where, an_array(t1))) + infer_emptytable(t2.emptytable_type, infer_at(where, a_type("array", { elements = t1 }))) else - infer_emptytable(t2.emptytable_type, infer_at(where, a_map(t2.emptytable_type.keys, t1))) + infer_emptytable(t2.emptytable_type, infer_at(where, a_type("map", { keys = t2.emptytable_type.keys, values = t1 }))) end return true elseif t2.typename == "emptytable" then @@ -7995,7 +8020,7 @@ a.types[i], b.types[i]), } t = resolve_tuple_and_nominal(t) local call_mt = t.meta_fields and t.meta_fields["__call"] if call_mt then - local args_tuple = a_tuple({}) + local args_tuple = a_type("tuple", {}) for i = 2, #call_mt.args do table.insert(args_tuple, call_mt.args[i]) end @@ -8007,7 +8032,7 @@ a.types[i], b.types[i]), } local function resolve_for_call(func, args, is_method) if lax and is_unknown(func) then - func = a_function({ args = a_vararg({ UNKNOWN }), rets = a_vararg({ UNKNOWN }) }) + func = a_type("function", { args = a_vararg({ UNKNOWN }), rets = a_vararg({ UNKNOWN }) }) end func = resolve_tuple_and_nominal(func) @@ -8284,7 +8309,7 @@ a.types[i], b.types[i]), } argdelta = is_method and -1 or argdelta or 0 if is_method and args[1] then - add_var(nil, "@self", a_typetype({ y = where.y, x = where.x, def = args[1] })) + add_var(nil, "@self", a_type("typetype", { y = where.y, x = where.x, def = args[1] })) end local is_func = func.typename == "function" @@ -8356,7 +8381,7 @@ a.types[i], b.types[i]), } type_check_function_call = function(node, where_args, func, args, e1, is_method, argdelta) if node.expected and node.expected.typename ~= "tuple" then - node.expected = a_tuple({ node.expected }) + node.expected = a_type("tuple", { node.expected }) end begin_scope() @@ -8405,7 +8430,7 @@ a.types[i], b.types[i]), } if metamethod then local where_args = { node.e1 } - local args = a_tuple({ orig_a }) + local args = a_type("tuple", { orig_a }) if b and method_name ~= "__is" then where_args[2] = node.e2 args[2] = orig_b @@ -8596,7 +8621,7 @@ a.types[i], b.types[i]), } local t = rets if not t.typename then - t = a_tuple(t) + t = a_type("tuple", t) end assert(t.typeid) return t @@ -8606,7 +8631,7 @@ a.types[i], b.types[i]), } assert(args.typename == "tuple") add_var(nil, "@is_va", args.is_va and ANY or NIL) - add_var(nil, "@return", node.rets or a_tuple({})) + add_var(nil, "@return", node.rets or a_type("tuple", {})) if node.typeargs then for _, t in ipairs(node.typeargs) do @@ -8621,13 +8646,13 @@ a.types[i], b.types[i]), } local function add_function_definition_for_recursion(node, fnargs) assert(fnargs.typename == "tuple") - local args = a_tuple({}) + local args = a_type("tuple", {}) args.is_va = fnargs.is_va for _, fnarg in ipairs(fnargs) do table.insert(args, fnarg) end - add_var(nil, node.name.tk, a_function({ + add_var(nil, node.name.tk, a_type("function", { typeargs = node.typeargs, args = args, rets = get_rets(node.rets), @@ -9323,7 +9348,7 @@ a.types[i], b.types[i]), } local base_nargs = (node.e1.tk == "xpcall") and 2 or 1 if #node.e2 < base_nargs then error_at(node, "wrong number of arguments (given " .. #node.e2 .. ", expects at least " .. base_nargs .. ")") - return a_tuple({ BOOLEAN }) + return a_type("tuple", { BOOLEAN }) end @@ -9351,7 +9376,7 @@ a.types[i], b.types[i]), } local rets = type_check_funcall(fnode, ftype, b, argdelta + base_nargs) if rets.typename ~= "tuple" then - rets = a_tuple({ rets }) + rets = a_type("tuple", { rets }) end table.insert(rets, 1, BOOLEAN) return rets @@ -10187,7 +10212,7 @@ a.types[i], b.types[i]), } widen_all_unions(node) local exp1 = node.exps[1] - local args = a_tuple({ + local args = a_type("tuple", { node.exps[2] and exptypes[2], node.exps[3] and exptypes[3], }) @@ -10303,7 +10328,7 @@ a.types[i], b.types[i]), } }, ["variable_list"] = { after = function(_node, children) - local tuple = a_tuple(children) + local tuple = a_type("tuple", children) local n = #tuple @@ -10457,7 +10482,7 @@ a.types[i], b.types[i]), } local t if force_array then - t = infer_at(node, an_array(force_array)) + t = infer_at(node, a_type("array", { elements = force_array })) else t = resolve_typevars_at(node, node.expected) if node.expected == t and t.typename == "nominal" then @@ -10519,7 +10544,7 @@ a.types[i], b.types[i]), } end_function_scope(node) local rets = get_rets(children[3]) - local t = ensure_fresh_typeargs(a_function({ + local t = ensure_fresh_typeargs(a_type("function", { y = node.y, x = node.x, typeargs = node.typeargs, @@ -10532,6 +10557,34 @@ a.types[i], b.types[i]), } return t end, }, + ["local_macroexp"] = { + before = function(node) + widen_all_unions() + if symbol_list then + reserve_symbol_list_slot(node) + end + begin_scope(node) + end, + after = function(node, children) + end_function_scope(node) + local rets = get_rets(children[3]) + + check_macroexp_arg_use(node.macrodef) + + local t = ensure_fresh_typeargs(a_type("function", { + y = node.y, + x = node.x, + typeargs = node.typeargs, + args = children[2], + rets = rets, + filename = filename, + macroexp = node.macrodef, + })) + + add_var(node, node.name.tk, t) + return t + end, + }, ["global_function"] = { before = function(node) widen_all_unions() @@ -10560,7 +10613,7 @@ a.types[i], b.types[i]), } return NONE end - add_global(node, node.name.tk, ensure_fresh_typeargs(a_function({ + add_global(node, node.name.tk, ensure_fresh_typeargs(a_type("function", { y = node.y, x = node.x, typeargs = node.typeargs, @@ -10620,7 +10673,7 @@ a.types[i], b.types[i]), } add_var(nil, "self", selftype) end - local fn_type = ensure_fresh_typeargs(a_function({ + local fn_type = ensure_fresh_typeargs(a_type("function", { y = node.y, x = node.x, is_method = node.is_method, @@ -10690,7 +10743,7 @@ a.types[i], b.types[i]), } end_function_scope(node) - return ensure_fresh_typeargs(a_function({ + return ensure_fresh_typeargs(a_type("function", { y = node.y, x = node.x, typeargs = node.typeargs, @@ -10713,7 +10766,7 @@ a.types[i], b.types[i]), } end_function_scope(node) - return ensure_fresh_typeargs(a_function({ + return ensure_fresh_typeargs(a_type("function", { y = node.y, x = node.x, typeargs = node.typeargs, @@ -10826,6 +10879,14 @@ a.types[i], b.types[i]), } return type_check_funcall(node, a, b) end + if ra.macroexp then + error_at(node.e1, "macroexps are abstract; consider using a concrete function") + end + + if rb and rb.macroexp then + error_at(node.e2, "macroexps are abstract; consider using a concrete function") + end + if node.op.op == "." then node.receiver = a @@ -11273,7 +11334,7 @@ a.types[i], b.types[i]), } ["record"] = { before = function(typ) begin_scope() - add_var(nil, "@self", a_typetype({ y = typ.y, x = typ.x, def = typ })) + add_var(nil, "@self", a_type("typetype", { y = typ.y, x = typ.x, def = typ })) for name, typ2 in fields_of(typ) do if typ2.typename == "typetype" then diff --git a/tl.tl b/tl.tl index 4d50eb652..b0d13a468 100644 --- a/tl.tl +++ b/tl.tl @@ -1215,6 +1215,7 @@ local enum NodeKind "..." "paren" "macroexp" + "local_macroexp" "interface" "error_node" end @@ -1434,6 +1435,7 @@ local record Node is_lvalue: boolean -- macroexp + macrodef: Node expanded: Node decltype: Type @@ -1561,23 +1563,27 @@ local function new_type(ps: ParseState, i: integer, typename: TypeName): Type }) end -local function a_tuple(t: {Type}): Type +local macroexp a_tuple(t: {Type}): Type return a_type("tuple", t) end -local function a_union(t: {Type}): Type +local function c_tuple(t: {Type}): Type + return a_type("tuple", t) +end + +local macroexp a_union(t: {Type}): Type return a_type("union", { types = t }) end -local function a_poly(t: {Type}): Type +local macroexp a_poly(t: {Type}): Type return a_type("poly", { types = t }) end -local function a_function(t: Type): Type +local macroexp a_function(t: Type): Type return a_type("function", t) end -local function a_typetype(t: Type): Type +local macroexp a_typetype(t: Type): Type return a_type("typetype", t) end @@ -1587,17 +1593,12 @@ local function a_vararg(t: {Type}): Type return a_tuple(t) end -local function an_array(t: Type): Type - return a_type("array", { - elements = t, - }) +local macroexp an_array(t: Type): Type + return a_type("array", { elements = t }) end -local function a_map(k: Type, v: Type): Type - return a_type("map", { - keys = k, - values = v, - }) +local macroexp a_map(k: Type, v: Type): Type + return a_type("map", { keys = k, values = v }) end local NIL = a_type("nil", {}) @@ -2943,15 +2944,14 @@ local metamethod_names: {string:boolean} = { ["__is"] = true, } -local function parse_macroexp(ps: ParseState, i: integer): integer, Node - local istart = i +local function parse_macroexp(ps: ParseState, istart: integer, iargs: integer): integer, Node -- TODO: generic macroexp -- if ps.tokens[i].tk == "<" then -- i, node.typeargs = parse_anglebracket_list(ps, i, parse_typearg) -- end local node = new_node(ps.tokens, istart, "macroexp") - i = i + 1 -- skip 'macroexp' - i, node.args = parse_argument_list(ps, i) + local i: integer + i, node.args = parse_argument_list(ps, iargs) i, node.rets = parse_return_types(ps, i) i = verify_tk(ps, i, "return") i, node.exp = parse_expression(ps, i) @@ -3135,7 +3135,7 @@ parse_record_body = function(ps: ParseState, i: integer, def: Type, node: Node): if t.typename ~= "function" then fail(ps, i + 1, "macroexp must have a function type") end - i, t.macroexp = parse_macroexp(ps, i + 1) + i, t.macroexp = parse_macroexp(ps, i + 1, i + 2) end store_field_in_record(ps, iv, field_name, t, fields, field_order) @@ -3342,6 +3342,15 @@ local function skip_type_declaration(ps: ParseState, i: integer): integer, Node return parse_type_declaration(ps, i - 1, "local_type") end +local function parse_local_macroexp(ps: ParseState, i: integer): integer, Node + local istart = i + i = i + 2 -- skip `local` + local node = new_node(ps.tokens, i, "local_macroexp") + i, node.name = parse_identifier(ps, i) + i, node.macrodef = parse_macroexp(ps, istart, i) + return i, node +end + local function parse_local(ps: ParseState, i: integer): integer, Node local ntk = ps.tokens[i + 1].tk local tn = ntk as TypeName @@ -3349,6 +3358,8 @@ local function parse_local(ps: ParseState, i: integer): integer, Node return parse_local_function(ps, i) elseif ntk == "type" and ps.tokens[i+2].kind == "identifier" then return parse_type_declaration(ps, i, "local_type") + elseif ntk == "macroexp" and ps.tokens[i+2].kind == "identifier" then + return parse_local_macroexp(ps, i) elseif parse_type_body_fns[tn] and ps.tokens[i+2].kind == "identifier" then return parse_type_constructor(ps, i, "local_type", tn, parse_type_body_fns[tn]) end @@ -3891,6 +3902,14 @@ local function recurse_node(root: Node, extra_callback("before_statements", ast, xs, visit_node) xs[5] = recurse(ast.body) end, + ["local_macroexp"] = function(ast: Node, xs: {T}) + -- TODO: generic macroexp + xs[1] = recurse(ast.name) + xs[2] = recurse(ast.macrodef.args) + xs[3] = recurse_type(ast.macrodef.rets, visit_type) + extra_callback("before_exp", ast, xs, visit_node) + xs[4] = recurse(ast.macrodef.exp) + end, ["forin"] = function(ast: Node, xs: {T}) xs[1] = recurse(ast.vars) @@ -4450,6 +4469,12 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | return out end, }, + ["local_macroexp"] = { + before = increment_indent, + after = function(node: Node, _children: {Output}): Output + return { y = node.y, h = 0 } + end, + }, ["local_function"] = { before = increment_indent, after = function(node: Node, children: {Output}): Output @@ -5654,7 +5679,7 @@ local function init_globals(lax: boolean): {string:Variable}, {string:Type} local file_reader_poly_types: {ArgsRets} = { { ctor = a_vararg, args = { a_union { NUMBER, an_enum { "*a", "a", "*l", "l", "*L", "L" } } }, rets = { STRING } }, - { ctor = a_tuple, args = { an_enum { "*n", "n" } }, rets = { NUMBER, STRING } }, + { ctor = c_tuple, args = { an_enum { "*n", "n" } }, rets = { NUMBER, STRING } }, { ctor = a_vararg, args = { a_union { NUMBER, an_enum { "*a", "a", "*l", "l", "*L", "L", "*n", "n" } } }, rets = { a_union { STRING, NUMBER } } }, { ctor = a_vararg, args = { a_union { NUMBER, STRING } }, rets = { STRING } }, { ctor = a_vararg, args = { }, rets = { STRING } }, @@ -8621,7 +8646,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local function add_function_definition_for_recursion(node: Node, fnargs: Type) assert(fnargs.typename == "tuple") - local args = a_tuple({}) + local args = a_type("tuple", {}) args.is_va = fnargs.is_va for _, fnarg in ipairs(fnargs) do table.insert(args, fnarg) @@ -10532,6 +10557,34 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return t end, }, + ["local_macroexp"] = { + before = function(node: Node) + widen_all_unions() + if symbol_list then + reserve_symbol_list_slot(node) + end + begin_scope(node) + end, + after = function(node: Node, children: {Type}): Type + end_function_scope(node) + local rets = get_rets(children[3]) + + check_macroexp_arg_use(node.macrodef) + + local t = ensure_fresh_typeargs(a_function { + y = node.y, + x = node.x, + typeargs = node.typeargs, + args = children[2], + rets = rets, + filename = filename, + macroexp = node.macrodef, + }) + + add_var(node, node.name.tk, t) + return t + end, + }, ["global_function"] = { before = function(node: Node) widen_all_unions() @@ -10826,6 +10879,14 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return type_check_funcall(node, a, b) end + if ra.macroexp then + error_at(node.e1, "macroexps are abstract; consider using a concrete function") + end + + if rb and rb.macroexp then + error_at(node.e2, "macroexps are abstract; consider using a concrete function") + end + if node.op.op == "." then node.receiver = a From b55294fbaf8d9d70f8637f5306422b245a0e929b Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Mon, 11 Dec 2023 11:01:43 -0300 Subject: [PATCH 042/224] macroexp: adjust yend --- tl.lua | 4 ++-- tl.tl | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tl.lua b/tl.lua index d12ed0ee2..d69a3c358 100644 --- a/tl.lua +++ b/tl.lua @@ -8112,10 +8112,10 @@ a.types[i], b.types[i]), } end end + out.yend = out.yend and (orignode.y + (out.yend - out.y)) or nil + out.xend = nil out.y = orignode.y out.x = orignode.x - out.yend = nil - out.xend = nil return { node, out } end diff --git a/tl.tl b/tl.tl index b0d13a468..28c3667bc 100644 --- a/tl.tl +++ b/tl.tl @@ -8112,10 +8112,10 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end + out.yend = out.yend and (orignode.y + (out.yend - out.y)) or nil + out.xend = nil out.y = orignode.y out.x = orignode.x - out.yend = nil - out.xend = nil return { node, out } end From 180394b9780fdb680e943142f95820a2ac25dacd Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Wed, 13 Dec 2023 19:41:07 -0300 Subject: [PATCH 043/224] remove node.rtype --- tl.lua | 12 ++++++------ tl.tl | 12 ++++++------ 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/tl.lua b/tl.lua index d69a3c358..a8bafd5e5 100644 --- a/tl.lua +++ b/tl.lua @@ -1445,7 +1445,6 @@ local Node = {ExpectedContext = {}, } - local function is_array_type(t) @@ -10630,12 +10629,12 @@ a.types[i], b.types[i]), } widen_all_unions() begin_scope(node) end, - before_arguments = function(node, children) - node.rtype = resolve_tuple_and_nominal(resolve_typetype(children[1])) + before_arguments = function(_node, children) + local rtype = resolve_tuple_and_nominal(resolve_typetype(children[1])) - if node.rtype.typeargs then - for _, typ in ipairs(node.rtype.typeargs) do + if rtype.typeargs then + for _, typ in ipairs(rtype.typeargs) do add_var(nil, typ.typearg, a_type("typearg", { y = typ.y, x = typ.x, @@ -10647,7 +10646,8 @@ a.types[i], b.types[i]), } before_statements = function(node, children) local args = children[3] - local rtype = node.rtype + local rtype = resolve_tuple_and_nominal(resolve_typetype(children[1])) + if rtype.typename == "emptytable" then edit_type(rtype, "record") rtype.fields = {} diff --git a/tl.tl b/tl.tl index 28c3667bc..91b43c16f 100644 --- a/tl.tl +++ b/tl.tl @@ -1379,7 +1379,6 @@ local record Node body: Node implicit_global_function: boolean is_predeclared_local_function: boolean - rtype: Type name: Node @@ -10630,12 +10629,12 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string widen_all_unions() begin_scope(node) end, - before_arguments = function(node: Node, children: {Type}) - node.rtype = resolve_tuple_and_nominal(resolve_typetype(children[1])) + before_arguments = function(_node: Node, children: {Type}) + local rtype = resolve_tuple_and_nominal(resolve_typetype(children[1])) -- add type arguments from the record implicitly - if node.rtype.typeargs then - for _, typ in ipairs(node.rtype.typeargs) do + if rtype.typeargs then + for _, typ in ipairs(rtype.typeargs) do add_var(nil, typ.typearg, a_type("typearg", { y = typ.y, x = typ.x, @@ -10647,7 +10646,8 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string before_statements = function(node: Node, children: {Type}) local args = children[3] - local rtype = node.rtype + local rtype = resolve_tuple_and_nominal(resolve_typetype(children[1])) + if rtype.typename == "emptytable" then edit_type(rtype, "record") rtype.fields = {} From 80e41bd20c7f89bb7385a1454832634c57b3c202 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Wed, 13 Dec 2023 22:27:12 -0300 Subject: [PATCH 044/224] improve ipairs error message --- spec/stdlib/ipairs_spec.lua | 11 +++++++++++ tl.lua | 7 ++++--- tl.tl | 7 ++++--- 3 files changed, 19 insertions(+), 6 deletions(-) diff --git a/spec/stdlib/ipairs_spec.lua b/spec/stdlib/ipairs_spec.lua index f92710666..b3712d7af 100644 --- a/spec/stdlib/ipairs_spec.lua +++ b/spec/stdlib/ipairs_spec.lua @@ -15,4 +15,15 @@ describe("ipairs", function() ]], { { msg = [[attempting ipairs on tuple that's not a valid array: {{integer}, {string "a"}}]] }, })) + + it("reports a nominal type in error message", util.check_type_error([[ + local record Rec + x: integer + end + local r: Rec + for i, v in ipairs(r) do + end + ]], { + { msg = [[attempting ipairs on something that's not an array: Rec]] }, + })) end) diff --git a/tl.lua b/tl.lua index a8bafd5e5..9b70b8361 100644 --- a/tl.lua +++ b/tl.lua @@ -9411,16 +9411,17 @@ a.types[i], b.types[i]), } if not b[1] then return invalid_at(node, "ipairs requires an argument") end - local t = resolve_tuple_and_nominal(b[1]) + local orig_t = b[1] + local t = resolve_tuple_and_nominal(orig_t) if t.typename == "tupletable" then local arr_type = arraytype_from_tuple(node.e2, t) if not arr_type then - return invalid_at(node.e2, "attempting ipairs on tuple that's not a valid array: %s", t) + return invalid_at(node.e2, "attempting ipairs on tuple that's not a valid array: %s", orig_t) end elseif not is_array_type(t) then if not (lax and (is_unknown(t) or t.typename == "emptytable")) then - return invalid_at(node.e2, "attempting ipairs on something that's not an array: %s", t) + return invalid_at(node.e2, "attempting ipairs on something that's not an array: %s", orig_t) end end diff --git a/tl.tl b/tl.tl index 91b43c16f..407d0bea8 100644 --- a/tl.tl +++ b/tl.tl @@ -9411,16 +9411,17 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if not b[1] then return invalid_at(node, "ipairs requires an argument") end - local t = resolve_tuple_and_nominal(b[1]) + local orig_t = b[1] + local t = resolve_tuple_and_nominal(orig_t) if t.typename == "tupletable" then local arr_type = arraytype_from_tuple(node.e2, t) if not arr_type then - return invalid_at(node.e2, "attempting ipairs on tuple that's not a valid array: %s", t) + return invalid_at(node.e2, "attempting ipairs on tuple that's not a valid array: %s", orig_t) end elseif not is_array_type(t) then if not (lax and (is_unknown(t) or t.typename == "emptytable")) then - return invalid_at(node.e2, "attempting ipairs on something that's not an array: %s", t) + return invalid_at(node.e2, "attempting ipairs on something that's not an array: %s", orig_t) end end From b428197aaf4332c0f1a6f7ee762d16a343939f42 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Wed, 13 Dec 2023 20:42:52 -0300 Subject: [PATCH 045/224] be more careful with typeid when shallow-copying types --- tl.lua | 29 +++++++++++++---------------- tl.tl | 31 ++++++++++++++----------------- 2 files changed, 27 insertions(+), 33 deletions(-) diff --git a/tl.lua b/tl.lua index 9b70b8361..3a78a389d 100644 --- a/tl.lua +++ b/tl.lua @@ -1609,17 +1609,16 @@ local THREAD = a_type("thread", {}) local BOOLEAN = a_type("boolean", {}) local INTEGER = a_type("integer", {}) - -local function shallow_copy_type(t) +local function shallow_copy_new_type(t) local copy = {} for k, v in pairs(t) do copy[k] = v end + copy.typeid = new_typeid() return copy end - -local function shallow_copy_node(t) +local function shallow_copy_table(t) local copy = {} for k, v in pairs(t) do copy[k] = v @@ -1910,7 +1909,7 @@ local function OPT(t) return memoize_opt_types[t] end - local ot = shallow_copy_type(t) + local ot = shallow_copy_new_type(t) ot.opt = true memoize_opt_types[t] = ot return ot @@ -2588,7 +2587,7 @@ local function parse_argument_type(ps, i) end if argument_name == "self" then - typ = shallow_copy_type(typ) + typ = shallow_copy_new_type(typ) typ.is_self = true end end @@ -5687,8 +5686,8 @@ local function init_globals(lax) local function a_file_reader(fn) local t = a_type("poly", { types = {} }) for _, entry in ipairs(file_reader_poly_types) do - local args = shallow_copy_type(entry.args) - local rets = shallow_copy_type(entry.rets) + local args = shallow_copy_table(entry.args) + local rets = shallow_copy_table(entry.rets) table.insert(t.types, fn(entry.ctor, args, rets)) end return t @@ -6801,7 +6800,7 @@ tl.type_check = function(ast, opts) if ret.typename == "invalid" then ret = t end - ret = (ret ~= t) and ret or shallow_copy_type(t) + ret = (ret ~= t) and ret or shallow_copy_table(t) ret.inferred_at = where ret.inferred_at.filename = filename return ret @@ -6811,7 +6810,7 @@ tl.type_check = function(ast, opts) if not t.tk then return t end - local ret = shallow_copy_type(t) + local ret = shallow_copy_table(t) ret.tk = nil return ret end @@ -8096,7 +8095,7 @@ a.types[i], b.types[i]), } local on_node = function(node, children, ret) local orig = ret and ret[2] or node - local out = shallow_copy_node(orig) + local out = shallow_copy_table(orig) local map = {} for _, pair in pairs(children) do @@ -9352,7 +9351,7 @@ a.types[i], b.types[i]), } local ftype = table.remove(b, 1) - ftype = shallow_copy_type(ftype) + ftype = shallow_copy_new_type(ftype) ftype.is_method = false local fe2 = {} @@ -9777,8 +9776,7 @@ a.types[i], b.types[i]), } infertype = INVALID elseif infertype and infertype.is_method then - infertype = shallow_copy_type(infertype) - infertype.typeid = new_typeid() + infertype = shallow_copy_new_type(infertype) infertype.is_method = false end end @@ -10514,8 +10512,7 @@ a.types[i], b.types[i]), } end if vtype.is_method then - vtype = shallow_copy_type(vtype) - vtype.typeid = new_typeid() + vtype = shallow_copy_new_type(vtype) vtype.is_method = false end return a_type("table_item", { diff --git a/tl.tl b/tl.tl index 407d0bea8..d83ca8b6a 100644 --- a/tl.tl +++ b/tl.tl @@ -1609,22 +1609,21 @@ local THREAD = a_type("thread", {}) local BOOLEAN = a_type("boolean", {}) local INTEGER = a_type("integer", {}) --- Makes a shallow copy of the given type -local function shallow_copy_type(t: Type): Type +local function shallow_copy_new_type(t: Type): Type local copy: {any:any} = {} for k, v in pairs(t as {any:any}) do copy[k] = v end + copy.typeid = new_typeid() return copy as Type end --- Makes a shallow copy of the given node -local function shallow_copy_node(t: Node): Node +local function shallow_copy_table(t: T): T local copy: {any:any} = {} for k, v in pairs(t as {any:any}) do copy[k] = v end - return copy as Node + return copy as T end local function verify_kind(ps: ParseState, i: integer, kind: TokenKind, node_kind?: NodeKind): integer, Node @@ -1910,7 +1909,7 @@ local function OPT(t: Type): Type return memoize_opt_types[t] end - local ot = shallow_copy_type(t) + local ot = shallow_copy_new_type(t) ot.opt = true memoize_opt_types[t] = ot return ot @@ -2588,7 +2587,7 @@ local function parse_argument_type(ps: ParseState, i: integer): integer, TypeAnd end if argument_name == "self" then - typ = shallow_copy_type(typ) + typ = shallow_copy_new_type(typ) typ.is_self = true end end @@ -5687,8 +5686,8 @@ local function init_globals(lax: boolean): {string:Variable}, {string:Type} local function a_file_reader(fn: (function(ctor: TypeConstructor, args: {Type}, rets: {Type}): Type)): Type local t = a_poly {} for _, entry in ipairs(file_reader_poly_types) do - local args = shallow_copy_type(entry.args as Type) as {Type} - local rets = shallow_copy_type(entry.rets as Type) as {Type} + local args = shallow_copy_table(entry.args) + local rets = shallow_copy_table(entry.rets) table.insert(t.types, fn(entry.ctor, args, rets)) end return t @@ -6801,7 +6800,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if ret.typename == "invalid" then ret = t -- errors are produced by resolve_typevars_at end - ret = (ret ~= t) and ret or shallow_copy_type(t) + ret = (ret ~= t) and ret or shallow_copy_table(t) ret.inferred_at = where ret.inferred_at.filename = filename return ret @@ -6811,7 +6810,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if not t.tk then return t end - local ret = shallow_copy_type(t) + local ret = shallow_copy_table(t) ret.tk = nil return ret end @@ -8096,7 +8095,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local on_node = function(node: Node, children: {{Node, Node}}, ret: {Node, Node}): {Node, Node} local orig = ret and ret[2] or node - local out = shallow_copy_node(orig) + local out = shallow_copy_table(orig) local map = {} for _, pair in pairs(children as {integer:{Node, Node}}) do @@ -9352,7 +9351,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string -- The function called by pcall/xpcall is invoked as a regular function, so we wish to avoid incorrect error messages / unnecessary warning messages associated with calling methods as functions local ftype = table.remove(b, 1) - ftype = shallow_copy_type(ftype) + ftype = shallow_copy_new_type(ftype) ftype.is_method = false local fe2: Node = {} @@ -9777,8 +9776,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string infertype = INVALID elseif infertype and infertype.is_method then -- If we assign a method to a variable, e.g local myfunc = myobj.dothing, the variable should not be treated as a method - infertype = shallow_copy_type(infertype) - infertype.typeid = new_typeid() + infertype = shallow_copy_new_type(infertype) infertype.is_method = false end end @@ -10514,8 +10512,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end if vtype.is_method then -- If we assign a method to a table item, e.g local a = { myfunc = myobj.dothing }, the table item should not be treated as a method - vtype = shallow_copy_type(vtype) - vtype.typeid = new_typeid() + vtype = shallow_copy_new_type(vtype) vtype.is_method = false end return a_type("table_item", { From 660f741efb607b69df757ba42d39ad2ac9bf2d9d Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Wed, 13 Dec 2023 21:28:58 -0300 Subject: [PATCH 046/224] refactor: flatten tuple --- .../assignment/to_multiple_variables_spec.lua | 8 ++ tl.lua | 85 ++++++++---------- tl.tl | 87 ++++++++----------- 3 files changed, 81 insertions(+), 99 deletions(-) diff --git a/spec/assignment/to_multiple_variables_spec.lua b/spec/assignment/to_multiple_variables_spec.lua index 42b73ca79..8a0459d34 100644 --- a/spec/assignment/to_multiple_variables_spec.lua +++ b/spec/assignment/to_multiple_variables_spec.lua @@ -9,6 +9,14 @@ describe("assignment to multiple variables", function() print(b .. " right!") ]])) + it("adjusts arity of tuple", util.check([[ + local function foo(): boolean, string + return true, "yeah!" + end + local a, b, c = 2, foo() + print(c .. " right!") + ]])) + it("reports unsufficient rvalues as an error, simple", util.check_type_error([[ local a, b = 1, 2 a, b = 3 diff --git a/tl.lua b/tl.lua index 3a78a389d..765b319d7 100644 --- a/tl.lua +++ b/tl.lua @@ -8692,54 +8692,53 @@ a.types[i], b.types[i]), } return t end - local function flatten_list(list) - local exps = {} - for i = 1, #list - 1 do - table.insert(exps, resolve_tuple_and_nominal(list[i])) - end - if #list > 0 then - local last = list[#list] - if last.typename == "tuple" then - for _, val in ipairs(last) do - table.insert(exps, val) - end - else - table.insert(exps, last) - end - end - return exps - end + local function flatten_tuple(vals) + local vt = vals + local n_vals = #vt + local ret = a_type("tuple", {}) + local rt = ret - local function get_assignment_values(vals, wanted) - local ret = {} - if vals == nil then + if n_vals == 0 then return ret end - local is_va = vals.is_va - for i = 1, #vals - 1 do - ret[i] = resolve_tuple(vals[i]) + for i = 1, n_vals - 1 do + rt[i] = resolve_tuple(vt[i]) end - local last = vals[#vals] - if last then - if last.typename == "tuple" then - - is_va = last.is_va - for _, v in ipairs(last) do - table.insert(ret, v) - end - else + local last = vt[n_vals] + if last.typename == "tuple" then - table.insert(ret, last) + local lt = last + for _, v in ipairs(lt) do + table.insert(rt, v) end + ret.is_va = last.is_va + else + rt[n_vals] = vt[n_vals] + ret.is_va = vals.is_va end + return ret + end + + local function get_assignment_values(vals, wanted) + if vals == nil then + return a_type("tuple", {}) + end + + local ret = flatten_tuple(vals) - if is_va and last and #ret < wanted then - while #ret < wanted do - table.insert(ret, last) + + if ret.is_va then + local n_ret = #ret + local rt = ret + if n_ret > 0 and n_ret < wanted then + local last = rt[n_ret] + for _ = n_ret + 1, wanted do + table.insert(rt, last) + end end end return ret @@ -10074,7 +10073,6 @@ a.types[i], b.types[i]), } before_exp = set_expected_types_to_decltypes, after = function(node, children) local valtypes = get_assignment_values(children[3], #children[1]) - valtypes = flatten_list(valtypes) for i, vartype in ipairs(children[1]) do local varnode = node.vars[i] local varname = varnode.tk @@ -10328,18 +10326,7 @@ a.types[i], b.types[i]), } after = function(_node, children) local tuple = a_type("tuple", children) - - local n = #tuple - if n > 0 and tuple[n].typename == "tuple" then - local final_tuple = tuple[n] - if final_tuple.is_va then - tuple.is_va = true - end - tuple[n] = nil - for i, c in ipairs(final_tuple) do - tuple[n + i - 1] = c - end - end + tuple = flatten_tuple(tuple) return tuple end, diff --git a/tl.tl b/tl.tl index d83ca8b6a..80090552e 100644 --- a/tl.tl +++ b/tl.tl @@ -8692,54 +8692,53 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return t end - local function flatten_list(list: {Type}): {Type} - local exps = {} - for i = 1, #list - 1 do - table.insert(exps, resolve_tuple_and_nominal(list[i])) - end - if #list > 0 then - local last = list[#list] - if last.typename == "tuple" then - for _, val in ipairs(last) do - table.insert(exps, val) - end - else - table.insert(exps, last) - end - end - return exps - end + local function flatten_tuple(vals: Type): Type + local vt = vals + local n_vals = #vt + local ret = a_tuple {} + local rt = ret - local function get_assignment_values(vals: Type, wanted: integer): {Type} - local ret: {Type} = {} - if vals == nil then + if n_vals == 0 then return ret end -- get all arguments except the last... - local is_va = vals.is_va - for i = 1, #vals - 1 do - ret[i] = resolve_tuple(vals[i]) + for i = 1, n_vals - 1 do + rt[i] = resolve_tuple(vt[i]) end - local last = vals[#vals] - if last then - if last.typename == "tuple" then - -- ...if the last is a tuple, unpack it - is_va = last.is_va - for _, v in ipairs(last) do - table.insert(ret, v) - end - else - -- ...otherwise simply get it - table.insert(ret, last) + local last = vt[n_vals] + if last.typename == "tuple" then + -- ...then unpack the last tuple + local lt = last + for _, v in ipairs(lt) do + table.insert(rt, v) end + ret.is_va = last.is_va + else + rt[n_vals] = vt[n_vals] + ret.is_va = vals.is_va end + return ret + end + + local function get_assignment_values(vals: Type, wanted: integer): {Type} + if vals == nil then + return a_tuple {} + end + + local ret = flatten_tuple(vals) + -- ...if the last is vararg, repeat its type until it matches the number of wanted args - if is_va and last and #ret < wanted then - while #ret < wanted do - table.insert(ret, last) + if ret.is_va then + local n_ret = #ret + local rt = ret + if n_ret > 0 and n_ret < wanted then + local last = rt[n_ret] + for _ = n_ret + 1, wanted do + table.insert(rt, last) + end end end return ret @@ -10074,7 +10073,6 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string before_exp = set_expected_types_to_decltypes, after = function(node: Node, children: {Type}): Type local valtypes: {Type} = get_assignment_values(children[3], #children[1]) - valtypes = flatten_list(valtypes) for i, vartype in ipairs(children[1]) do local varnode = node.vars[i] local varname = varnode.tk @@ -10328,18 +10326,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string after = function(_node: Node, children: {Type}): Type local tuple = a_tuple(children) - -- explode last tuple: (1, 2, (3, 4)) becomes (1, 2, 3, 4) - local n = #tuple - if n > 0 and tuple[n].typename == "tuple" then - local final_tuple = tuple[n] - if final_tuple.is_va then - tuple.is_va = true - end - tuple[n] = nil - for i, c in ipairs(final_tuple) do - tuple[n + i - 1] = c - end - end + tuple = flatten_tuple(tuple) return tuple end, From f651782a21557851d62888f2f7422c254af61e18 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Mon, 11 Dec 2023 13:21:07 -0300 Subject: [PATCH 047/224] abstract checks for macroexp and interfaces Ensure is_abstract is always set for interface typetypes by using new_typetype constructor. --- spec/assignment/to_interface_spec.lua | 11 +-- spec/declaration/record_spec.lua | 14 ++++ tl.lua | 101 ++++++++++++++++--------- tl.tl | 103 ++++++++++++++++---------- 4 files changed, 149 insertions(+), 80 deletions(-) diff --git a/spec/assignment/to_interface_spec.lua b/spec/assignment/to_interface_spec.lua index b57d75ba1..2c3f4e49e 100644 --- a/spec/assignment/to_interface_spec.lua +++ b/spec/assignment/to_interface_spec.lua @@ -31,13 +31,14 @@ describe("assignment", function() elseif scope:match("with inner def") then -- 2 err = { { y = 6, msg = "cannot reassign a type" } } elseif scope:match("to inner def") then -- 3 + err = {} if outer == "interface" and scope:match("with outer def") then - err = { - { y = 6, msg = "interfaces are abstract; consider using a concrete record" }, - { y = 6, msg = "cannot reassign a type" }, - } + table.insert(err, { y = 6, msg = "interfaces are abstract; consider using a concrete record" }) + end + if inner == "record" then + table.insert(err, { y = 6, msg = "cannot reassign a type" }) else - err = { { y = 6, msg = "cannot reassign a type" } } + table.insert(err, { y = 6, msg = "interfaces are abstract; consider using a concrete record" }) end elseif outer == "interface" and scope == "to inner var with outer def" then -- 4 err = { { y = 6, msg = "interfaces are abstract; consider using a concrete record" } } diff --git a/spec/declaration/record_spec.lua b/spec/declaration/record_spec.lua index 5e1b34954..60787d933 100644 --- a/spec/declaration/record_spec.lua +++ b/spec/declaration/record_spec.lua @@ -858,3 +858,17 @@ describe("arrayrecord", function() print(a) ]], {})) end) + +describe("abstract check", function() + it("for nested interface", util.check_type_error([[ + local record F1 + interface F2 + z: number + end + end + + F1.F2.z = 9 + ]], { + { y = 7, x = 9, msg = "interfaces are abstract", } + })) +end) diff --git a/tl.lua b/tl.lua index 765b319d7..194ad2fc8 100644 --- a/tl.lua +++ b/tl.lua @@ -1250,6 +1250,7 @@ local table_types = { + local TruthyFact = {} @@ -1562,13 +1563,19 @@ local function new_type(ps, i, typename) }) end +local function new_typetype(ps, i, def) + local t = new_type(ps, i, "typetype") + t.def = def + if def.typename == "interface" then + + t.is_abstract = true + end + return t +end -local function c_tuple(t) - return a_type("tuple", t) -end @@ -1579,6 +1586,15 @@ end +local function c_tuple(t) + return a_type("tuple", t) +end + + + + + + @@ -2885,12 +2901,11 @@ local function parse_nested_type(ps, i, def, typename, parse_body) end local nt = new_node(ps.tokens, i - 2, "newtype") - nt.newtype = new_type(ps, i, "typetype") - local rdef = new_type(ps, i, typename) - local iok = parse_body(ps, i, rdef, nt) + local ndef = new_type(ps, i, typename) + local iok = parse_body(ps, i, ndef, nt) if iok then i = iok - nt.newtype.def = rdef + nt.newtype = new_typetype(ps, i, ndef) end store_field_in_record(ps, iv, v.tk, nt.newtype, def.fields, def.field_order) @@ -3051,6 +3066,7 @@ parse_record_body = function(ps, i, def, node) typ.args = a_type("tuple", { a_type("nominal", { y = typ.y, x = typ.x, names = { "@self" } }) }) typ.rets = a_type("tuple", { BOOLEAN }) typ.macroexp = where_macroexp + typ.is_abstract = true store_field_in_record(ps, i, "__is", typ, def.meta_fields, def.meta_field_order) end @@ -3134,6 +3150,7 @@ parse_record_body = function(ps, i, def, node) fail(ps, i + 1, "macroexp must have a function type") end i, t.macroexp = parse_macroexp(ps, i + 1, i + 2) + t.is_abstract = true end store_field_in_record(ps, iv, field_name, t, fields, field_order) @@ -3163,19 +3180,22 @@ parse_type_body_fns = { parse_newtype = function(ps, i) local node = new_node(ps.tokens, i, "newtype") - node.newtype = new_type(ps, i, "typetype") + local def local tn = ps.tokens[i].tk + local itype = i if parse_type_body_fns[tn] then - local def = new_type(ps, i, tn) + def = new_type(ps, i, tn) i = i + 1 i = parse_type_body_fns[tn](ps, i, def, node) - node.newtype.def = def - return i, node else - i, node.newtype.def = parse_type(ps, i) - if not node.newtype.def then + i, def = parse_type(ps, i) + if not def then return i end + end + + if def then + node.newtype = new_typetype(ps, itype, def) return i, node end return fail(ps, i, "expected a type") @@ -3320,9 +3340,8 @@ local function parse_type_constructor(ps, i, node_name, type_name, parse_body) local asgn = new_node(ps.tokens, i, node_name) local nt = new_node(ps.tokens, i, "newtype") asgn.value = nt - nt.newtype = new_type(ps, i, "typetype") + local itype = i local def = new_type(ps, i, type_name) - nt.newtype.def = def i = i + 2 @@ -3330,9 +3349,12 @@ local function parse_type_constructor(ps, i, node_name, type_name, parse_body) if not asgn.var then return fail(ps, i, "expected a type name") end - nt.newtype.def.names = { asgn.var.tk } + def.names = { asgn.var.tk } i = parse_body(ps, i, def, nt) + + nt.newtype = new_typetype(ps, itype, def) + return i, asgn end @@ -6332,6 +6354,18 @@ tl.type_check = function(ast, opts) end end + local function ensure_not_abstract(where, t) + if not t.is_abstract then + return + end + + if t.macroexp then + error_at(where, "macroexps are abstract; consider using a concrete function") + else + error_at(where, "interfaces are abstract; consider using a concrete record") + end + end + local function find_type(names, accept_typearg) local typ = find_var_type(names[1], "use_type") if not typ then @@ -6559,6 +6593,7 @@ tl.type_check = function(ast, opts) copy.opt = t.opt copy.is_userdata = t.is_userdata + copy.is_abstract = t.is_abstract copy.typename = t.typename copy.filename = t.filename copy.x = t.x @@ -10323,11 +10358,15 @@ a.types[i], b.types[i]), } end, }, ["variable_list"] = { - after = function(_node, children) + after = function(node, children) local tuple = a_type("tuple", children) tuple = flatten_tuple(tuple) + for i, t in ipairs(tuple) do + ensure_not_abstract(node[i], t) + end + return tuple end, }, @@ -10840,21 +10879,6 @@ a.types[i], b.types[i]), } return invalid_at(node, "cannot dereference a type from a circular require") end - if is_typetype(ra) then - if ra.def.typename == "record" then - ra = ra.def - elseif ra.def.typename == "interface" then - error_at(node.e1, "interfaces are abstract; consider using a concrete record") - end - end - if rb and is_typetype(rb) and rb.def.typename == "record" then - if rb.def.typename == "record" then - rb = rb.def - elseif rb.def.typename == "interface" then - error_at(node.e2, "interfaces are abstract; consider using a concrete record") - end - end - if node.op.op == "@funcall" then if lax and is_unknown(a) then if node.e1.op and node.e1.op.op == ":" and node.e1.e1.kind == "variable" then @@ -10864,12 +10888,15 @@ a.types[i], b.types[i]), } return type_check_funcall(node, a, b) end - if ra.macroexp then - error_at(node.e1, "macroexps are abstract; consider using a concrete function") + ensure_not_abstract(node.e1, ra) + if ra.typename == "typetype" and ra.def.typename == "record" then + ra = ra.def end - - if rb and rb.macroexp then - error_at(node.e2, "macroexps are abstract; consider using a concrete function") + if rb then + ensure_not_abstract(node.e2, rb) + if rb.typename == "typetype" and rb.def.typename == "record" then + rb = rb.def + end end if node.op.op == "." then diff --git a/tl.tl b/tl.tl index 80090552e..1254191f6 100644 --- a/tl.tl +++ b/tl.tl @@ -1096,6 +1096,7 @@ local record Type def: Type is_alias: boolean closed: boolean + is_abstract: boolean -- map keys: Type @@ -1562,6 +1563,25 @@ local function new_type(ps: ParseState, i: integer, typename: TypeName): Type }) end +local function new_typetype(ps: ParseState, i: integer, def: Type): Type + local t = new_type(ps, i, "typetype") + t.def = def + if def.typename == "interface" then + -- ...or should this be set on traversal, to account for nominal type aliases? + t.is_abstract = true + end + return t +end + +local macroexp a_typetype(t: Type): Type +-- FIXME set is_abstract here once standard_library defines interfaces +-- if t.def.typename == "interface" then +-- t.is_abstract = true +-- end +-- return t + return a_type("typetype", t) +end + local macroexp a_tuple(t: {Type}): Type return a_type("tuple", t) end @@ -1582,10 +1602,6 @@ local macroexp a_function(t: Type): Type return a_type("function", t) end -local macroexp a_typetype(t: Type): Type - return a_type("typetype", t) -end - local function a_vararg(t: {Type}): Type local tuple = t as Type tuple.is_va = true @@ -2885,12 +2901,11 @@ local function parse_nested_type(ps: ParseState, i: integer, def: Type, typename end local nt: Node = new_node(ps.tokens, i - 2, "newtype") - nt.newtype = new_type(ps, i, "typetype") - local rdef = new_type(ps, i, typename) - local iok = parse_body(ps, i, rdef, nt) + local ndef = new_type(ps, i, typename) + local iok = parse_body(ps, i, ndef, nt) if iok then i = iok - nt.newtype.def = rdef + nt.newtype = new_typetype(ps, i, ndef) end store_field_in_record(ps, iv, v.tk, nt.newtype, def.fields, def.field_order) @@ -3051,6 +3066,7 @@ parse_record_body = function(ps: ParseState, i: integer, def: Type, node: Node): typ.args = a_tuple { a_type("nominal", { y = typ.y, x = typ.x, names = { "@self" } }) } typ.rets = a_tuple { BOOLEAN } typ.macroexp = where_macroexp + typ.is_abstract = true store_field_in_record(ps, i, "__is", typ, def.meta_fields, def.meta_field_order) end @@ -3134,6 +3150,7 @@ parse_record_body = function(ps: ParseState, i: integer, def: Type, node: Node): fail(ps, i + 1, "macroexp must have a function type") end i, t.macroexp = parse_macroexp(ps, i + 1, i + 2) + t.is_abstract = true end store_field_in_record(ps, iv, field_name, t, fields, field_order) @@ -3163,19 +3180,22 @@ parse_type_body_fns = { parse_newtype = function(ps: ParseState, i: integer): integer, Node local node: Node = new_node(ps.tokens, i, "newtype") - node.newtype = new_type(ps, i, "typetype") + local def: Type local tn = ps.tokens[i].tk as TypeName + local itype = i if parse_type_body_fns[tn] then - local def = new_type(ps, i, tn) + def = new_type(ps, i, tn) i = i + 1 i = parse_type_body_fns[tn](ps, i, def, node) - node.newtype.def = def - return i, node else - i, node.newtype.def = parse_type(ps, i) - if not node.newtype.def then + i, def = parse_type(ps, i) + if not def then return i end + end + + if def then + node.newtype = new_typetype(ps, itype, def) return i, node end return fail(ps, i, "expected a type") @@ -3320,9 +3340,8 @@ local function parse_type_constructor(ps: ParseState, i: integer, node_name: Nod local asgn: Node = new_node(ps.tokens, i, node_name) local nt: Node = new_node(ps.tokens, i, "newtype") asgn.value = nt - nt.newtype = new_type(ps, i, "typetype") + local itype = i local def = new_type(ps, i, type_name) - nt.newtype.def = def i = i + 2 -- skip `local` or `global`, and the constructor name @@ -3330,9 +3349,12 @@ local function parse_type_constructor(ps: ParseState, i: integer, node_name: Nod if not asgn.var then return fail(ps, i, "expected a type name") end - nt.newtype.def.names = { asgn.var.tk } + def.names = { asgn.var.tk } i = parse_body(ps, i, def, nt) + + nt.newtype = new_typetype(ps, itype, def) + return i, asgn end @@ -6332,6 +6354,18 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end + local function ensure_not_abstract(where: Where, t: Type) + if not t.is_abstract then + return + end + + if t.macroexp then + error_at(where, "macroexps are abstract; consider using a concrete function") + else + error_at(where, "interfaces are abstract; consider using a concrete record") + end + end + local function find_type(names: {string}, accept_typearg?: boolean): Type local typ = find_var_type(names[1], "use_type") if not typ then @@ -6559,6 +6593,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string copy.opt = t.opt copy.is_userdata = t.is_userdata + copy.is_abstract = t.is_abstract copy.typename = t.typename copy.filename = t.filename copy.x = t.x @@ -10323,11 +10358,15 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end, }, ["variable_list"] = { - after = function(_node: Node, children: {Type}): Type + after = function(node: Node, children: {Type}): Type local tuple = a_tuple(children) tuple = flatten_tuple(tuple) + for i, t in ipairs(tuple) do + ensure_not_abstract(node[i], t) + end + return tuple end, }, @@ -10840,21 +10879,6 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return invalid_at(node, "cannot dereference a type from a circular require") end - if is_typetype(ra) then - if ra.def.typename == "record" then - ra = ra.def - elseif ra.def.typename == "interface" then - error_at(node.e1, "interfaces are abstract; consider using a concrete record") - end - end - if rb and is_typetype(rb) and rb.def.typename == "record" then - if rb.def.typename == "record" then - rb = rb.def - elseif rb.def.typename == "interface" then - error_at(node.e2, "interfaces are abstract; consider using a concrete record") - end - end - if node.op.op == "@funcall" then if lax and is_unknown(a) then if node.e1.op and node.e1.op.op == ":" and node.e1.e1.kind == "variable" then @@ -10864,12 +10888,15 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return type_check_funcall(node, a, b) end - if ra.macroexp then - error_at(node.e1, "macroexps are abstract; consider using a concrete function") + ensure_not_abstract(node.e1, ra) + if ra.typename == "typetype" and ra.def.typename == "record" then + ra = ra.def end - - if rb and rb.macroexp then - error_at(node.e2, "macroexps are abstract; consider using a concrete function") + if rb then + ensure_not_abstract(node.e2, rb) + if rb.typename == "typetype" and rb.def.typename == "record" then + rb = rb.def + end end if node.op.op == "." then From e3a22d80f3c1357dc57ab816d2401f5075d982c6 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Wed, 13 Dec 2023 23:10:20 -0300 Subject: [PATCH 048/224] fixup use macroexp --- tl.lua | 1 + tl.tl | 1 + 2 files changed, 2 insertions(+) diff --git a/tl.lua b/tl.lua index 194ad2fc8..625d8d49c 100644 --- a/tl.lua +++ b/tl.lua @@ -3368,6 +3368,7 @@ local function parse_local_macroexp(ps, i) local node = new_node(ps.tokens, i, "local_macroexp") i, node.name = parse_identifier(ps, i) i, node.macrodef = parse_macroexp(ps, istart, i) + end_at(node, ps.tokens[i - 1]) return i, node end diff --git a/tl.tl b/tl.tl index 1254191f6..ae0acd038 100644 --- a/tl.tl +++ b/tl.tl @@ -3368,6 +3368,7 @@ local function parse_local_macroexp(ps: ParseState, i: integer): integer, Node local node = new_node(ps.tokens, i, "local_macroexp") i, node.name = parse_identifier(ps, i) i, node.macrodef = parse_macroexp(ps, istart, i) + end_at(node, ps.tokens[i - 1]) return i, node end From 2780503c9bd7e76e72f7000626266e66829cdd10 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Thu, 14 Dec 2023 03:09:07 -0300 Subject: [PATCH 049/224] refactor: reduce confusion in Type instances * `Type` is no longer a `{Type}` - this was being used only for "tuple" types; now those have an explicit `.tuple` attribute. * Split `Node`'s `.decltype` into three different attributes: `.argtype` for "argument", `.itemtype` for "table_item" and `.decltuple` for variable lists. * Add `TupleType` alias; this is not enforced as a subtype at this point. It is only being used for documentation. * Type object constructors such as `a_function` are now stricter about expecting "tuple" entries. This caught a few mistakes. * More consistent tuple usage: function calls always return a "tuple" or an "invalid". Arguments and expected returns are always passed in as a "tuple". Special function handlers are adjusted accordingly. * Explicit assignments of `.y` and `.x` are replaced by `type_at` wrapper, which uses a `Where`. --- spec/cli/types_spec.lua | 1 - tl.lua | 1019 ++++++++++++++++++++------------------- tl.tl | 599 ++++++++++++----------- 3 files changed, 836 insertions(+), 783 deletions(-) diff --git a/spec/cli/types_spec.lua b/spec/cli/types_spec.lua index 5a8fe58f2..d196002d2 100644 --- a/spec/cli/types_spec.lua +++ b/spec/cli/types_spec.lua @@ -303,7 +303,6 @@ describe("tl types works like check", function() assert(by_pos["1"]["20"]) -- ( assert(by_pos["1"]["21"]) -- "os" assert(by_pos["1"]["26"]) -- . - assert(by_pos["1"]["20"] == by_pos["1"]["26"]) end) end) end) diff --git a/tl.lua b/tl.lua index 625d8d49c..5b86c0658 100644 --- a/tl.lua +++ b/tl.lua @@ -1250,6 +1250,8 @@ local table_types = { + + @@ -1444,6 +1446,9 @@ local Node = {ExpectedContext = {}, } + + + @@ -1587,7 +1592,7 @@ end local function c_tuple(t) - return a_type("tuple", t) + return a_type("tuple", { tuple = t }) end @@ -1598,14 +1603,16 @@ end - - - +local function a_function(t) + assert(t.args.typename == "tuple") + assert(t.rets.typename == "tuple") + return a_type("function", t) +end local function a_vararg(t) - local tuple = t - tuple.is_va = true - return a_type("tuple", t) + local typ = a_type("tuple", { tuple = t }) + typ.is_va = true + return typ end @@ -1732,8 +1739,8 @@ local function parse_table_item(ps, i, n) node.key.conststr = node.key.tk node.key.tk = '"' .. node.key.tk .. '"' i = verify_tk(try_ps, i, ":") - i, node.decltype = parse_type(try_ps, i) - if node.decltype and ps.tokens[i].tk == "=" then + i, node.itemtype = parse_type(try_ps, i) + if node.itemtype and ps.tokens[i].tk == "=" then i = verify_tk(try_ps, i, "=") i, node.value = parse_table_value(try_ps, i) if node.value then @@ -1744,7 +1751,7 @@ local function parse_table_item(ps, i, n) end end - node.decltype = nil + node.itemtype = nil i = orig_i end end @@ -1861,9 +1868,9 @@ local function parse_anglebracket_list(ps, i, parse_item) if ps.tokens[i + 1].tk == ">" then return fail(ps, i + 1, "type argument list cannot be empty") end - local typ = new_type(ps, i, "tuple") + local types = {} i = verify_tk(ps, i, "<") - i = parse_list(ps, i, typ, { [">"] = true, [">>"] = true }, "sep", parse_item) + i = parse_list(ps, i, types, { [">"] = true, [">>"] = true }, "sep", parse_item) if ps.tokens[i].tk == ">" then i = i + 1 elseif ps.tokens[i].tk == ">>" then @@ -1872,7 +1879,7 @@ local function parse_anglebracket_list(ps, i, parse_item) else return fail(ps, i, "syntax error, expected '>'") end - return i, typ + return i, types end local function parse_typearg(ps, i) @@ -1901,7 +1908,7 @@ local function parse_function_type(ps, i) typ.args = a_vararg({ ANY }) typ.rets = a_vararg({ ANY }) end - if typ.args[1] and typ.args[1].is_self then + if typ.args.tuple[1] and typ.args.tuple[1].is_self then typ.is_method = true end return i, typ @@ -2047,15 +2054,21 @@ parse_type = function(ps, i) return i, bt end +local function new_tuple(ps, i) + local t = new_type(ps, i, "tuple") + t.tuple = {} + return t, t.tuple +end + parse_type_list = function(ps, i, mode) - local list = new_type(ps, i, "tuple") + local t, list = new_tuple(ps, i) local first_token = ps.tokens[i].tk - if mode == "rets" or mode == "decltype" then + if mode == "rets" or mode == "decltuple" then if first_token == ":" then i = i + 1 else - return i, list + return i, t end end @@ -2075,7 +2088,7 @@ parse_type_list = function(ps, i, mode) i = i + 1 local nrets = #list if nrets > 0 then - list.is_va = true + t.is_va = true else fail(ps, i, "unexpected '...'") end @@ -2085,7 +2098,7 @@ parse_type_list = function(ps, i, mode) i = verify_tk(ps, i, ")") end - return i, list + return i, t end local function parse_function_args_rets_body(ps, i, node) @@ -2528,12 +2541,12 @@ local function parse_argument(ps, i) end if ps.tokens[i].tk == ":" then i = i + 1 - local decltype + local argtype - i, decltype = parse_type(ps, i) + i, argtype = parse_type(ps, i) if node then - node.decltype = decltype + node.argtype = argtype end end return i, node, 0 @@ -2614,7 +2627,7 @@ end parse_argument_type_list = function(ps, i) local tvs = {} i = parse_bracket_list(ps, i, tvs, "(", ")", "sep", parse_argument_type) - local list = new_type(ps, i, "tuple") + local t, list = new_tuple(ps, i) local n = #tvs for l, tv in ipairs(tvs) do list[l] = tv.type @@ -2623,9 +2636,9 @@ parse_argument_type_list = function(ps, i) end end if tvs[n] and tvs[n].is_va then - list.is_va = true + t.is_va = true end - return i, list + return i, t end local function parse_identifier(ps, i) @@ -2979,10 +2992,10 @@ local function parse_where_clause(ps, i) node.args = new_node(ps.tokens, i, "argument_list") node.args[1] = new_node(ps.tokens, i, "argument") node.args[1].tk = "self" - node.args[1].decltype = new_type(ps, i, "nominal") - node.args[1].decltype.names = { "@self" } - node.rets = new_type(ps, i, "tuple") - node.rets[1] = BOOLEAN + node.args[1].argtype = new_type(ps, i, "nominal") + node.args[1].argtype.names = { "@self" } + node.rets = new_tuple(ps, i) + node.rets.tuple[1] = BOOLEAN i, node.exp = parse_expression(ps, i) end_at(node, ps.tokens[i - 1]) return i, node @@ -3063,8 +3076,8 @@ parse_record_body = function(ps, i, def, node) local typ = new_type(ps, wstart, "function") typ.is_method = true - typ.args = a_type("tuple", { a_type("nominal", { y = typ.y, x = typ.x, names = { "@self" } }) }) - typ.rets = a_type("tuple", { BOOLEAN }) + typ.args = a_type("tuple", { tuple = { a_type("nominal", { y = typ.y, x = typ.x, filename = ps.filename, names = { "@self" } }) } }) + typ.rets = a_type("tuple", { tuple = { BOOLEAN } }) typ.macroexp = where_macroexp typ.is_abstract = true @@ -3282,7 +3295,7 @@ local function parse_variable_declarations(ps, i, node_name) return fail(ps, i, "expected a local variable definition") end - i, asgn.decltype = parse_type_list(ps, i, "decltype") + i, asgn.decltuple = parse_type_list(ps, i, "decltuple") if ps.tokens[i].tk == "=" then @@ -3670,10 +3683,11 @@ local function recurse_type(ast, visit) end end - for i, child in ipairs(ast) do - xs[i] = recurse_type(child, visit) + if ast.tuple then + for i, child in ipairs(ast.tuple) do + xs[i] = recurse_type(child, visit) + end end - if ast.types then for _, child in ipairs(ast.types) do table.insert(xs, recurse_type(child, visit)) @@ -3707,14 +3721,14 @@ local function recurse_type(ast, visit) end end if ast.args then - for i, child in ipairs(ast.args) do + for i, child in ipairs(ast.args.tuple) do if i > 1 or not ast.is_method or child.is_self then table.insert(xs, recurse_type(child, visit)) end end end if ast.rets then - for _, child in ipairs(ast.rets) do + for _, child in ipairs(ast.rets.tuple) do table.insert(xs, recurse_type(child, visit)) end end @@ -3803,8 +3817,8 @@ local function recurse_node(root, local function walk_vars_exps(ast, xs) xs[1] = recurse(ast.vars) - if ast.decltype then - xs[2] = recurse_type(ast.decltype, visit_type) + if ast.decltuple then + xs[2] = recurse_type(ast.decltuple, visit_type) end extra_callback("before_exp", ast, xs, visit_node) if ast.exps then @@ -3854,8 +3868,8 @@ local function recurse_node(root, ["table_item"] = function(ast, xs) xs[1] = recurse(ast.key) xs[2] = recurse(ast.value) - if ast.decltype then - xs[3] = recurse_type(ast.decltype, visit_type) + if ast.itemtype then + xs[3] = recurse_type(ast.itemtype, visit_type) end end, @@ -3965,8 +3979,8 @@ local function recurse_node(root, end, ["argument"] = function(ast, xs) - if ast.decltype then - xs[1] = recurse_type(ast.decltype, visit_type) + if ast.argtype then + xs[1] = recurse_type(ast.argtype, visit_type) end end, } @@ -4839,7 +4853,6 @@ local typename_to_typecode = { local skip_types = { ["none"] = true, - ["tuple"] = true, ["table_item"] = true, ["unresolved"] = true, ["typetype"] = true, @@ -4879,12 +4892,12 @@ end local function store_function(trenv, ti, rt) local args = {} - for _, fnarg in ipairs(rt.args) do + for _, fnarg in ipairs(rt.args.tuple) do table.insert(args, mark_array({ get_typenum(trenv, fnarg), nil })) end ti.args = mark_array(args) local rets = {} - for _, fnarg in ipairs(rt.rets) do + for _, fnarg in ipairs(rt.rets.tuple) do table.insert(rets, mark_array({ get_typenum(trenv, fnarg), nil })) end ti.rets = mark_array(rets) @@ -4907,8 +4920,8 @@ get_typenum = function(trenv, t) local rt = t if is_typetype(rt) then rt = rt.def - elseif rt.typename == "tuple" and #rt == 1 then - rt = rt[1] + elseif rt.typename == "tuple" and #rt.tuple == 1 then + rt = rt.tuple[1] end local ti = { @@ -4996,10 +5009,10 @@ local INVALID = a_type("invalid", {}) local UNKNOWN = a_type("unknown", {}) local CIRCULAR_REQUIRE = a_type("circular_require", {}) -local FUNCTION = a_type("function", { args = a_vararg({ ANY }), rets = a_vararg({ ANY }) }) +local FUNCTION = a_function({ args = a_vararg({ ANY }), rets = a_vararg({ ANY }) }) local NOMINAL_FILE = a_type("nominal", { names = { "FILE" } }) -local XPCALL_MSGH_FUNCTION = a_type("function", { args = a_type("tuple", { ANY }), rets = a_type("tuple", {}) }) +local XPCALL_MSGH_FUNCTION = a_function({ args = a_type("tuple", { tuple = { ANY } }), rets = a_type("tuple", { tuple = {} }) }) local USERDATA = ANY @@ -5303,10 +5316,14 @@ local function show_type_base(t, short, seen) end elseif t.typename == "tuple" then local out = {} - for _, v in ipairs(t) do + for _, v in ipairs(t.tuple) do table.insert(out, show(v)) end - return "(" .. table.concat(out, ", ") .. ")" + local list = table.concat(out, ", ") + if short then + return list + end + return "(" .. list .. ")" elseif t.typename == "tupletable" then local out = {} for _, v in ipairs(t.types) do @@ -5353,20 +5370,20 @@ local function show_type_base(t, short, seen) if t.is_method then table.insert(args, "self") end - for i, v in ipairs(t.args) do + for i, v in ipairs(t.args.tuple) do if not t.is_method or i > 1 then - table.insert(args, ((i == #t.args and t.args.is_va) and "...: " or + table.insert(args, ((i == #t.args.tuple and t.args.is_va) and "...: " or v.opt and "? " or "") .. show(v)) end end table.insert(out, table.concat(args, ", ")) table.insert(out, ")") - if #t.rets > 0 then + if t.rets.tuple and #t.rets.tuple > 0 then table.insert(out, ": ") local rets = {} - for i, v in ipairs(t.rets) do - table.insert(rets, show(v) .. (i == #t.rets and t.rets.is_va and "..." or "")) + for i, v in ipairs(t.rets.tuple) do + table.insert(rets, show(v) .. (i == #t.rets.tuple and t.rets.is_va and "..." or "")) end table.insert(out, table.concat(rets, ", ")) end @@ -5716,7 +5733,7 @@ local function init_globals(lax) return t end - local LOAD_FUNCTION = a_type("function", { args = {}, rets = a_type("tuple", { STRING }) }) + local LOAD_FUNCTION = a_function({ args = a_type("tuple", { tuple = {} }), rets = a_type("tuple", { tuple = { STRING } }) }) local OS_DATE_TABLE = a_record({ fields = { @@ -5753,12 +5770,12 @@ local function init_globals(lax) local DEBUG_HOOK_EVENT = an_enum({ "call", "tail call", "return", "line", "count" }) - local DEBUG_HOOK_FUNCTION = a_type("function", { - args = a_type("tuple", { DEBUG_HOOK_EVENT, INTEGER }), - rets = a_type("tuple", {}), + local DEBUG_HOOK_FUNCTION = a_function({ + args = a_type("tuple", { tuple = { DEBUG_HOOK_EVENT, INTEGER } }), + rets = a_type("tuple", { tuple = {} }), }) - local TABLE_SORT_FUNCTION = a_gfunction(1, function(a) return { args = a_type("tuple", { a, a }), rets = a_type("tuple", { BOOLEAN }) } end) + local TABLE_SORT_FUNCTION = a_gfunction(1, function(a) return { args = a_type("tuple", { tuple = { a, a } }), rets = a_type("tuple", { tuple = { BOOLEAN } }) } end) local metatable_nominals = {} @@ -5772,71 +5789,71 @@ local function init_globals(lax) ["..."] = a_vararg({ STRING }), ["any"] = a_type("typetype", { def = ANY }), ["arg"] = a_type("array", { elements = STRING }), - ["assert"] = a_gfunction(2, function(a, b) return { args = a_type("tuple", { a, OPT(b) }), rets = a_type("tuple", { a }) } end), + ["assert"] = a_gfunction(2, function(a, b) return { args = a_type("tuple", { tuple = { a, OPT(b) } }), rets = a_type("tuple", { tuple = { a } }) } end), ["collectgarbage"] = a_type("poly", { types = { - a_type("function", { args = a_type("tuple", { an_enum({ "collect", "count", "stop", "restart" }) }), rets = a_type("tuple", { NUMBER }) }), - a_type("function", { args = a_type("tuple", { an_enum({ "step", "setpause", "setstepmul" }), NUMBER }), rets = a_type("tuple", { NUMBER }) }), - a_type("function", { args = a_type("tuple", { an_enum({ "isrunning" }) }), rets = a_type("tuple", { BOOLEAN }) }), - a_type("function", { args = a_type("tuple", { STRING, OPT(NUMBER) }), rets = a_type("tuple", { a_type("union", { types = { BOOLEAN, NUMBER } }) }) }), + a_function({ args = a_type("tuple", { tuple = { an_enum({ "collect", "count", "stop", "restart" }) } }), rets = a_type("tuple", { tuple = { NUMBER } }) }), + a_function({ args = a_type("tuple", { tuple = { an_enum({ "step", "setpause", "setstepmul" }), NUMBER } }), rets = a_type("tuple", { tuple = { NUMBER } }) }), + a_function({ args = a_type("tuple", { tuple = { an_enum({ "isrunning" }) } }), rets = a_type("tuple", { tuple = { BOOLEAN } }) }), + a_function({ args = a_type("tuple", { tuple = { STRING, OPT(NUMBER) } }), rets = a_type("tuple", { tuple = { a_type("union", { types = { BOOLEAN, NUMBER } }) } }) }), } }), - ["dofile"] = a_type("function", { args = a_type("tuple", { OPT(STRING) }), rets = a_vararg({ ANY }) }), - ["error"] = a_type("function", { args = a_type("tuple", { ANY, OPT(NUMBER) }), rets = a_type("tuple", {}) }), - ["getmetatable"] = a_gfunction(1, function(a) return { args = a_type("tuple", { a }), rets = a_type("tuple", { METATABLE(a) }) } end), - ["ipairs"] = a_gfunction(1, function(a) return { args = a_type("tuple", { a_type("array", { elements = a }) }), rets = a_type("tuple", { - a_type("function", { args = a_type("tuple", {}), rets = a_type("tuple", { INTEGER, a }) }), -}), } end), - ["load"] = a_type("function", { args = a_type("tuple", { a_type("union", { types = { STRING, LOAD_FUNCTION } }), OPT(STRING), OPT(STRING), OPT(TABLE) }), rets = a_type("tuple", { FUNCTION, STRING }) }), - ["loadfile"] = a_type("function", { args = a_type("tuple", { OPT(STRING), OPT(STRING), OPT(TABLE) }), rets = a_type("tuple", { FUNCTION, STRING }) }), + ["dofile"] = a_function({ args = a_type("tuple", { tuple = { OPT(STRING) } }), rets = a_vararg({ ANY }) }), + ["error"] = a_function({ args = a_type("tuple", { tuple = { ANY, OPT(NUMBER) } }), rets = a_type("tuple", { tuple = {} }) }), + ["getmetatable"] = a_gfunction(1, function(a) return { args = a_type("tuple", { tuple = { a } }), rets = a_type("tuple", { tuple = { METATABLE(a) } }) } end), + ["ipairs"] = a_gfunction(1, function(a) return { args = a_type("tuple", { tuple = { a_type("array", { elements = a }) } }), rets = a_type("tuple", { tuple = { + a_function({ args = a_type("tuple", { tuple = {} }), rets = a_type("tuple", { tuple = { INTEGER, a } }) }), +} }), } end), + ["load"] = a_function({ args = a_type("tuple", { tuple = { a_type("union", { types = { STRING, LOAD_FUNCTION } }), OPT(STRING), OPT(STRING), OPT(TABLE) } }), rets = a_type("tuple", { tuple = { FUNCTION, STRING } }) }), + ["loadfile"] = a_function({ args = a_type("tuple", { tuple = { OPT(STRING), OPT(STRING), OPT(TABLE) } }), rets = a_type("tuple", { tuple = { FUNCTION, STRING } }) }), ["next"] = a_type("poly", { types = { - a_gfunction(2, function(a, b) return { args = a_type("tuple", { a_type("map", { keys = a, values = b }), OPT(a) }), rets = a_type("tuple", { a, b }) } end), - a_gfunction(1, function(a) return { args = a_type("tuple", { a_type("array", { elements = a }), OPT(a) }), rets = a_type("tuple", { INTEGER, a }) } end), + a_gfunction(2, function(a, b) return { args = a_type("tuple", { tuple = { a_type("map", { keys = a, values = b }), OPT(a) } }), rets = a_type("tuple", { tuple = { a, b } }) } end), + a_gfunction(1, function(a) return { args = a_type("tuple", { tuple = { a_type("array", { elements = a }), OPT(a) } }), rets = a_type("tuple", { tuple = { INTEGER, a } }) } end), } }), - ["pairs"] = a_gfunction(2, function(a, b) return { args = a_type("tuple", { a_type("map", { keys = a, values = b }) }), rets = a_type("tuple", { - a_type("function", { args = a_type("tuple", {}), rets = a_type("tuple", { a, b }) }), -}), } end), - ["pcall"] = a_type("function", { args = a_vararg({ FUNCTION, ANY }), rets = a_vararg({ BOOLEAN, ANY }) }), - ["xpcall"] = a_type("function", { args = a_vararg({ FUNCTION, XPCALL_MSGH_FUNCTION, ANY }), rets = a_vararg({ BOOLEAN, ANY }) }), - ["print"] = a_type("function", { args = a_vararg({ ANY }), rets = a_type("tuple", {}) }), - ["rawequal"] = a_type("function", { args = a_type("tuple", { ANY, ANY }), rets = a_type("tuple", { BOOLEAN }) }), - ["rawget"] = a_type("function", { args = a_type("tuple", { TABLE, ANY }), rets = a_type("tuple", { ANY }) }), - ["rawlen"] = a_type("function", { args = a_type("tuple", { a_type("union", { types = { TABLE, STRING } }) }), rets = a_type("tuple", { INTEGER }) }), + ["pairs"] = a_gfunction(2, function(a, b) return { args = a_type("tuple", { tuple = { a_type("map", { keys = a, values = b }) } }), rets = a_type("tuple", { tuple = { + a_function({ args = a_type("tuple", { tuple = {} }), rets = a_type("tuple", { tuple = { a, b } }) }), +} }), } end), + ["pcall"] = a_function({ args = a_vararg({ FUNCTION, ANY }), rets = a_vararg({ BOOLEAN, ANY }) }), + ["xpcall"] = a_function({ args = a_vararg({ FUNCTION, XPCALL_MSGH_FUNCTION, ANY }), rets = a_vararg({ BOOLEAN, ANY }) }), + ["print"] = a_function({ args = a_vararg({ ANY }), rets = a_type("tuple", { tuple = {} }) }), + ["rawequal"] = a_function({ args = a_type("tuple", { tuple = { ANY, ANY } }), rets = a_type("tuple", { tuple = { BOOLEAN } }) }), + ["rawget"] = a_function({ args = a_type("tuple", { tuple = { TABLE, ANY } }), rets = a_type("tuple", { tuple = { ANY } }) }), + ["rawlen"] = a_function({ args = a_type("tuple", { tuple = { a_type("union", { types = { TABLE, STRING } }) } }), rets = a_type("tuple", { tuple = { INTEGER } }) }), ["rawset"] = a_type("poly", { types = { - a_gfunction(2, function(a, b) return { args = a_type("tuple", { a_type("map", { keys = a, values = b }), a, b }), rets = a_type("tuple", {}) } end), - a_gfunction(1, function(a) return { args = a_type("tuple", { a_type("array", { elements = a }), NUMBER, a }), rets = a_type("tuple", {}) } end), - a_type("function", { args = a_type("tuple", { TABLE, ANY, ANY }), rets = a_type("tuple", {}) }), + a_gfunction(2, function(a, b) return { args = a_type("tuple", { tuple = { a_type("map", { keys = a, values = b }), a, b } }), rets = a_type("tuple", { tuple = {} }) } end), + a_gfunction(1, function(a) return { args = a_type("tuple", { tuple = { a_type("array", { elements = a }), NUMBER, a } }), rets = a_type("tuple", { tuple = {} }) } end), + a_function({ args = a_type("tuple", { tuple = { TABLE, ANY, ANY } }), rets = a_type("tuple", { tuple = {} }) }), } }), - ["require"] = a_type("function", { args = a_type("tuple", { STRING }), rets = a_type("tuple", {}) }), + ["require"] = a_function({ args = a_type("tuple", { tuple = { STRING } }), rets = a_type("tuple", { tuple = {} }) }), ["select"] = a_type("poly", { types = { - a_gfunction(1, function(a) return { args = a_vararg({ NUMBER, a }), rets = a_type("tuple", { a }) } end), - a_type("function", { args = a_vararg({ NUMBER, ANY }), rets = a_type("tuple", { ANY }) }), - a_type("function", { args = a_vararg({ STRING, ANY }), rets = a_type("tuple", { INTEGER }) }), + a_gfunction(1, function(a) return { args = a_vararg({ NUMBER, a }), rets = a_type("tuple", { tuple = { a } }) } end), + a_function({ args = a_vararg({ NUMBER, ANY }), rets = a_type("tuple", { tuple = { ANY } }) }), + a_function({ args = a_vararg({ STRING, ANY }), rets = a_type("tuple", { tuple = { INTEGER } }) }), } }), - ["setmetatable"] = a_gfunction(1, function(a) return { args = a_type("tuple", { a, METATABLE(a) }), rets = a_type("tuple", { a }) } end), + ["setmetatable"] = a_gfunction(1, function(a) return { args = a_type("tuple", { tuple = { a, METATABLE(a) } }), rets = a_type("tuple", { tuple = { a } }) } end), ["tonumber"] = a_type("poly", { types = { - a_type("function", { args = a_type("tuple", { ANY }), rets = a_type("tuple", { NUMBER }) }), - a_type("function", { args = a_type("tuple", { ANY, NUMBER }), rets = a_type("tuple", { INTEGER }) }), + a_function({ args = a_type("tuple", { tuple = { ANY } }), rets = a_type("tuple", { tuple = { NUMBER } }) }), + a_function({ args = a_type("tuple", { tuple = { ANY, NUMBER } }), rets = a_type("tuple", { tuple = { INTEGER } }) }), } }), - ["tostring"] = a_type("function", { args = a_type("tuple", { ANY }), rets = a_type("tuple", { STRING }) }), - ["type"] = a_type("function", { args = a_type("tuple", { ANY }), rets = a_type("tuple", { STRING }) }), + ["tostring"] = a_function({ args = a_type("tuple", { tuple = { ANY } }), rets = a_type("tuple", { tuple = { STRING } }) }), + ["type"] = a_function({ args = a_type("tuple", { tuple = { ANY } }), rets = a_type("tuple", { tuple = { STRING } }) }), ["FILE"] = a_type("typetype", { def = a_record({ is_userdata = true, fields = { - ["close"] = a_type("function", { args = a_type("tuple", { NOMINAL_FILE }), rets = a_type("tuple", { BOOLEAN, STRING, INTEGER }) }), - ["flush"] = a_type("function", { args = a_type("tuple", { NOMINAL_FILE }), rets = a_type("tuple", {}) }), + ["close"] = a_function({ args = a_type("tuple", { tuple = { NOMINAL_FILE } }), rets = a_type("tuple", { tuple = { BOOLEAN, STRING, INTEGER } }) }), + ["flush"] = a_function({ args = a_type("tuple", { tuple = { NOMINAL_FILE } }), rets = a_type("tuple", { tuple = {} }) }), ["lines"] = a_file_reader(function(ctor, args, rets) table.insert(args, 1, NOMINAL_FILE) - return a_type("function", { args = ctor(args), rets = a_type("tuple", { - a_type("function", { args = a_type("tuple", {}), rets = ctor(rets) }), - }), }) + return a_function({ args = ctor(args), rets = a_type("tuple", { tuple = { + a_function({ args = a_type("tuple", { tuple = {} }), rets = ctor(rets) }), +} }), }) end), ["read"] = a_file_reader(function(ctor, args, rets) table.insert(args, 1, NOMINAL_FILE) - return a_type("function", { args = ctor(args), rets = ctor(rets) }) + return a_function({ args = ctor(args), rets = ctor(rets) }) end), - ["seek"] = a_type("function", { args = a_type("tuple", { NOMINAL_FILE, OPT(STRING), OPT(NUMBER) }), rets = a_type("tuple", { INTEGER, STRING }) }), - ["setvbuf"] = a_type("function", { args = a_type("tuple", { NOMINAL_FILE, STRING, OPT(NUMBER) }), rets = a_type("tuple", {}) }), - ["write"] = a_type("function", { args = a_vararg({ NOMINAL_FILE, a_type("union", { types = { STRING, NUMBER } }) }), rets = a_type("tuple", { NOMINAL_FILE, STRING }) }), + ["seek"] = a_function({ args = a_type("tuple", { tuple = { NOMINAL_FILE, OPT(STRING), OPT(NUMBER) } }), rets = a_type("tuple", { tuple = { INTEGER, STRING } }) }), + ["setvbuf"] = a_function({ args = a_type("tuple", { tuple = { NOMINAL_FILE, STRING, OPT(NUMBER) } }), rets = a_type("tuple", { tuple = {} }) }), + ["write"] = a_function({ args = a_vararg({ NOMINAL_FILE, a_type("union", { types = { STRING, NUMBER } }) }), rets = a_type("tuple", { tuple = { NOMINAL_FILE, STRING } }) }), }, meta_fields = { ["__close"] = FUNCTION }, @@ -5846,52 +5863,52 @@ local function init_globals(lax) ["metatable"] = a_type("typetype", { def = a_grecord(1, function(a) return { fields = { - ["__call"] = a_type("function", { args = a_vararg({ a, ANY }), rets = a_vararg({ ANY }) }), - ["__gc"] = a_type("function", { args = a_type("tuple", { a }), rets = a_type("tuple", {}) }), + ["__call"] = a_function({ args = a_vararg({ a, ANY }), rets = a_vararg({ ANY }) }), + ["__gc"] = a_function({ args = a_type("tuple", { tuple = { a } }), rets = a_type("tuple", { tuple = {} }) }), ["__index"] = ANY, - ["__len"] = a_type("function", { args = a_type("tuple", { a }), rets = a_type("tuple", { ANY }) }), + ["__len"] = a_function({ args = a_type("tuple", { tuple = { a } }), rets = a_type("tuple", { tuple = { ANY } }) }), ["__mode"] = an_enum({ "k", "v", "kv" }), ["__newindex"] = ANY, ["__pairs"] = a_gfunction(2, function(k, v) return { - args = a_type("tuple", { a }), - rets = a_type("tuple", { a_type("function", { args = a_type("tuple", {}), rets = a_type("tuple", { k, v }) }) }), + args = a_type("tuple", { tuple = { a } }), + rets = a_type("tuple", { tuple = { a_function({ args = a_type("tuple", { tuple = {} }), rets = a_type("tuple", { tuple = { k, v } }) }) } }), } end), - ["__tostring"] = a_type("function", { args = a_type("tuple", { a }), rets = a_type("tuple", { STRING }) }), + ["__tostring"] = a_function({ args = a_type("tuple", { tuple = { a } }), rets = a_type("tuple", { tuple = { STRING } }) }), ["__name"] = STRING, - ["__add"] = a_type("function", { args = a_type("tuple", { ANY, ANY }), rets = a_type("tuple", { ANY }) }), - ["__sub"] = a_type("function", { args = a_type("tuple", { ANY, ANY }), rets = a_type("tuple", { ANY }) }), - ["__mul"] = a_type("function", { args = a_type("tuple", { ANY, ANY }), rets = a_type("tuple", { ANY }) }), - ["__div"] = a_type("function", { args = a_type("tuple", { ANY, ANY }), rets = a_type("tuple", { ANY }) }), - ["__idiv"] = a_type("function", { args = a_type("tuple", { ANY, ANY }), rets = a_type("tuple", { ANY }) }), - ["__mod"] = a_type("function", { args = a_type("tuple", { ANY, ANY }), rets = a_type("tuple", { ANY }) }), - ["__pow"] = a_type("function", { args = a_type("tuple", { ANY, ANY }), rets = a_type("tuple", { ANY }) }), - ["__unm"] = a_type("function", { args = a_type("tuple", { ANY }), rets = a_type("tuple", { ANY }) }), - ["__band"] = a_type("function", { args = a_type("tuple", { ANY, ANY }), rets = a_type("tuple", { ANY }) }), - ["__bor"] = a_type("function", { args = a_type("tuple", { ANY, ANY }), rets = a_type("tuple", { ANY }) }), - ["__bxor"] = a_type("function", { args = a_type("tuple", { ANY, ANY }), rets = a_type("tuple", { ANY }) }), - ["__bnot"] = a_type("function", { args = a_type("tuple", { ANY }), rets = a_type("tuple", { ANY }) }), - ["__shl"] = a_type("function", { args = a_type("tuple", { ANY, ANY }), rets = a_type("tuple", { ANY }) }), - ["__shr"] = a_type("function", { args = a_type("tuple", { ANY, ANY }), rets = a_type("tuple", { ANY }) }), - ["__concat"] = a_type("function", { args = a_type("tuple", { ANY, ANY }), rets = a_type("tuple", { ANY }) }), - ["__eq"] = a_type("function", { args = a_type("tuple", { ANY, ANY }), rets = a_type("tuple", { BOOLEAN }) }), - ["__lt"] = a_type("function", { args = a_type("tuple", { ANY, ANY }), rets = a_type("tuple", { BOOLEAN }) }), - ["__le"] = a_type("function", { args = a_type("tuple", { ANY, ANY }), rets = a_type("tuple", { BOOLEAN }) }), - ["__close"] = a_type("function", { args = a_type("tuple", { a }), rets = a_type("tuple", {}) }), + ["__add"] = a_function({ args = a_type("tuple", { tuple = { ANY, ANY } }), rets = a_type("tuple", { tuple = { ANY } }) }), + ["__sub"] = a_function({ args = a_type("tuple", { tuple = { ANY, ANY } }), rets = a_type("tuple", { tuple = { ANY } }) }), + ["__mul"] = a_function({ args = a_type("tuple", { tuple = { ANY, ANY } }), rets = a_type("tuple", { tuple = { ANY } }) }), + ["__div"] = a_function({ args = a_type("tuple", { tuple = { ANY, ANY } }), rets = a_type("tuple", { tuple = { ANY } }) }), + ["__idiv"] = a_function({ args = a_type("tuple", { tuple = { ANY, ANY } }), rets = a_type("tuple", { tuple = { ANY } }) }), + ["__mod"] = a_function({ args = a_type("tuple", { tuple = { ANY, ANY } }), rets = a_type("tuple", { tuple = { ANY } }) }), + ["__pow"] = a_function({ args = a_type("tuple", { tuple = { ANY, ANY } }), rets = a_type("tuple", { tuple = { ANY } }) }), + ["__unm"] = a_function({ args = a_type("tuple", { tuple = { ANY } }), rets = a_type("tuple", { tuple = { ANY } }) }), + ["__band"] = a_function({ args = a_type("tuple", { tuple = { ANY, ANY } }), rets = a_type("tuple", { tuple = { ANY } }) }), + ["__bor"] = a_function({ args = a_type("tuple", { tuple = { ANY, ANY } }), rets = a_type("tuple", { tuple = { ANY } }) }), + ["__bxor"] = a_function({ args = a_type("tuple", { tuple = { ANY, ANY } }), rets = a_type("tuple", { tuple = { ANY } }) }), + ["__bnot"] = a_function({ args = a_type("tuple", { tuple = { ANY } }), rets = a_type("tuple", { tuple = { ANY } }) }), + ["__shl"] = a_function({ args = a_type("tuple", { tuple = { ANY, ANY } }), rets = a_type("tuple", { tuple = { ANY } }) }), + ["__shr"] = a_function({ args = a_type("tuple", { tuple = { ANY, ANY } }), rets = a_type("tuple", { tuple = { ANY } }) }), + ["__concat"] = a_function({ args = a_type("tuple", { tuple = { ANY, ANY } }), rets = a_type("tuple", { tuple = { ANY } }) }), + ["__eq"] = a_function({ args = a_type("tuple", { tuple = { ANY, ANY } }), rets = a_type("tuple", { tuple = { BOOLEAN } }) }), + ["__lt"] = a_function({ args = a_type("tuple", { tuple = { ANY, ANY } }), rets = a_type("tuple", { tuple = { BOOLEAN } }) }), + ["__le"] = a_function({ args = a_type("tuple", { tuple = { ANY, ANY } }), rets = a_type("tuple", { tuple = { BOOLEAN } }) }), + ["__close"] = a_function({ args = a_type("tuple", { tuple = { a } }), rets = a_type("tuple", { tuple = {} }) }), }, } end), }), ["coroutine"] = a_record({ fields = { - ["create"] = a_type("function", { args = a_type("tuple", { FUNCTION }), rets = a_type("tuple", { THREAD }) }), - ["close"] = a_type("function", { args = a_type("tuple", { THREAD }), rets = a_type("tuple", { BOOLEAN, STRING }) }), - ["isyieldable"] = a_type("function", { args = a_type("tuple", {}), rets = a_type("tuple", { BOOLEAN }) }), - ["resume"] = a_type("function", { args = a_vararg({ THREAD, ANY }), rets = a_vararg({ BOOLEAN, ANY }) }), - ["running"] = a_type("function", { args = a_type("tuple", {}), rets = a_type("tuple", { THREAD, BOOLEAN }) }), - ["status"] = a_type("function", { args = a_type("tuple", { THREAD }), rets = a_type("tuple", { STRING }) }), - ["wrap"] = a_type("function", { args = a_type("tuple", { FUNCTION }), rets = a_type("tuple", { FUNCTION }) }), - ["yield"] = a_type("function", { args = a_vararg({ ANY }), rets = a_vararg({ ANY }) }), + ["create"] = a_function({ args = a_type("tuple", { tuple = { FUNCTION } }), rets = a_type("tuple", { tuple = { THREAD } }) }), + ["close"] = a_function({ args = a_type("tuple", { tuple = { THREAD } }), rets = a_type("tuple", { tuple = { BOOLEAN, STRING } }) }), + ["isyieldable"] = a_function({ args = a_type("tuple", { tuple = {} }), rets = a_type("tuple", { tuple = { BOOLEAN } }) }), + ["resume"] = a_function({ args = a_vararg({ THREAD, ANY }), rets = a_vararg({ BOOLEAN, ANY }) }), + ["running"] = a_function({ args = a_type("tuple", { tuple = {} }), rets = a_type("tuple", { tuple = { THREAD, BOOLEAN } }) }), + ["status"] = a_function({ args = a_type("tuple", { tuple = { THREAD } }), rets = a_type("tuple", { tuple = { STRING } }) }), + ["wrap"] = a_function({ args = a_type("tuple", { tuple = { FUNCTION } }), rets = a_type("tuple", { tuple = { FUNCTION } }) }), + ["yield"] = a_function({ args = a_vararg({ ANY }), rets = a_vararg({ ANY }) }), }, }), ["debug"] = a_record({ @@ -5900,141 +5917,141 @@ local function init_globals(lax) ["Hook"] = a_type("typetype", { def = DEBUG_HOOK_FUNCTION }), ["HookEvent"] = a_type("typetype", { def = DEBUG_HOOK_EVENT }), - ["debug"] = a_type("function", { args = a_type("tuple", {}), rets = a_type("tuple", {}) }), - ["gethook"] = a_type("function", { args = a_type("tuple", { OPT(THREAD) }), rets = a_type("tuple", { DEBUG_HOOK_FUNCTION, INTEGER }) }), + ["debug"] = a_function({ args = a_type("tuple", { tuple = {} }), rets = a_type("tuple", { tuple = {} }) }), + ["gethook"] = a_function({ args = a_type("tuple", { tuple = { OPT(THREAD) } }), rets = a_type("tuple", { tuple = { DEBUG_HOOK_FUNCTION, INTEGER } }) }), ["getlocal"] = a_type("poly", { types = { - a_type("function", { args = a_type("tuple", { THREAD, FUNCTION, NUMBER }), rets = STRING }), - a_type("function", { args = a_type("tuple", { THREAD, NUMBER, NUMBER }), rets = a_type("tuple", { STRING, ANY }) }), - a_type("function", { args = a_type("tuple", { FUNCTION, NUMBER }), rets = STRING }), - a_type("function", { args = a_type("tuple", { NUMBER, NUMBER }), rets = a_type("tuple", { STRING, ANY }) }), + a_function({ args = a_type("tuple", { tuple = { THREAD, FUNCTION, NUMBER } }), rets = a_type("tuple", { tuple = { STRING } }) }), + a_function({ args = a_type("tuple", { tuple = { THREAD, NUMBER, NUMBER } }), rets = a_type("tuple", { tuple = { STRING, ANY } }) }), + a_function({ args = a_type("tuple", { tuple = { FUNCTION, NUMBER } }), rets = a_type("tuple", { tuple = { STRING } }) }), + a_function({ args = a_type("tuple", { tuple = { NUMBER, NUMBER } }), rets = a_type("tuple", { tuple = { STRING, ANY } }) }), } }), - ["getmetatable"] = a_gfunction(1, function(a) return { args = a_type("tuple", { a }), rets = a_type("tuple", { METATABLE(a) }) } end), - ["getregistry"] = a_type("function", { args = a_type("tuple", {}), rets = a_type("tuple", { TABLE }) }), - ["getupvalue"] = a_type("function", { args = a_type("tuple", { FUNCTION, NUMBER }), rets = a_type("tuple", { ANY }) }), - ["getuservalue"] = a_type("function", { args = a_type("tuple", { USERDATA, NUMBER }), rets = a_type("tuple", { ANY }) }), + ["getmetatable"] = a_gfunction(1, function(a) return { args = a_type("tuple", { tuple = { a } }), rets = a_type("tuple", { tuple = { METATABLE(a) } }) } end), + ["getregistry"] = a_function({ args = a_type("tuple", { tuple = {} }), rets = a_type("tuple", { tuple = { TABLE } }) }), + ["getupvalue"] = a_function({ args = a_type("tuple", { tuple = { FUNCTION, NUMBER } }), rets = a_type("tuple", { tuple = { ANY } }) }), + ["getuservalue"] = a_function({ args = a_type("tuple", { tuple = { USERDATA, NUMBER } }), rets = a_type("tuple", { tuple = { ANY } }) }), ["sethook"] = a_type("poly", { types = { - a_type("function", { args = a_type("tuple", { THREAD, DEBUG_HOOK_FUNCTION, STRING, NUMBER }), rets = a_type("tuple", {}) }), - a_type("function", { args = a_type("tuple", { DEBUG_HOOK_FUNCTION, STRING, NUMBER }), rets = a_type("tuple", {}) }), + a_function({ args = a_type("tuple", { tuple = { THREAD, DEBUG_HOOK_FUNCTION, STRING, NUMBER } }), rets = a_type("tuple", { tuple = {} }) }), + a_function({ args = a_type("tuple", { tuple = { DEBUG_HOOK_FUNCTION, STRING, NUMBER } }), rets = a_type("tuple", { tuple = {} }) }), } }), ["setlocal"] = a_type("poly", { types = { - a_type("function", { args = a_type("tuple", { THREAD, NUMBER, NUMBER, ANY }), rets = a_type("tuple", { STRING }) }), - a_type("function", { args = a_type("tuple", { NUMBER, NUMBER, ANY }), rets = a_type("tuple", { STRING }) }), + a_function({ args = a_type("tuple", { tuple = { THREAD, NUMBER, NUMBER, ANY } }), rets = a_type("tuple", { tuple = { STRING } }) }), + a_function({ args = a_type("tuple", { tuple = { NUMBER, NUMBER, ANY } }), rets = a_type("tuple", { tuple = { STRING } }) }), } }), - ["setmetatable"] = a_gfunction(1, function(a) return { args = a_type("tuple", { a, METATABLE(a) }), rets = a_type("tuple", { a }) } end), - ["setupvalue"] = a_type("function", { args = a_type("tuple", { FUNCTION, NUMBER, ANY }), rets = a_type("tuple", { STRING }) }), - ["setuservalue"] = a_type("function", { args = a_type("tuple", { USERDATA, ANY, NUMBER }), rets = a_type("tuple", { USERDATA }) }), + ["setmetatable"] = a_gfunction(1, function(a) return { args = a_type("tuple", { tuple = { a, METATABLE(a) } }), rets = a_type("tuple", { tuple = { a } }) } end), + ["setupvalue"] = a_function({ args = a_type("tuple", { tuple = { FUNCTION, NUMBER, ANY } }), rets = a_type("tuple", { tuple = { STRING } }) }), + ["setuservalue"] = a_function({ args = a_type("tuple", { tuple = { USERDATA, ANY, NUMBER } }), rets = a_type("tuple", { tuple = { USERDATA } }) }), ["traceback"] = a_type("poly", { types = { - a_type("function", { args = a_type("tuple", { OPT(THREAD), OPT(STRING), OPT(NUMBER) }), rets = a_type("tuple", { STRING }) }), - a_type("function", { args = a_type("tuple", { OPT(STRING), OPT(NUMBER) }), rets = a_type("tuple", { STRING }) }), + a_function({ args = a_type("tuple", { tuple = { OPT(THREAD), OPT(STRING), OPT(NUMBER) } }), rets = a_type("tuple", { tuple = { STRING } }) }), + a_function({ args = a_type("tuple", { tuple = { OPT(STRING), OPT(NUMBER) } }), rets = a_type("tuple", { tuple = { STRING } }) }), } }), - ["upvalueid"] = a_type("function", { args = a_type("tuple", { FUNCTION, NUMBER }), rets = a_type("tuple", { USERDATA }) }), - ["upvaluejoin"] = a_type("function", { args = a_type("tuple", { FUNCTION, NUMBER, FUNCTION, NUMBER }), rets = a_type("tuple", {}) }), + ["upvalueid"] = a_function({ args = a_type("tuple", { tuple = { FUNCTION, NUMBER } }), rets = a_type("tuple", { tuple = { USERDATA } }) }), + ["upvaluejoin"] = a_function({ args = a_type("tuple", { tuple = { FUNCTION, NUMBER, FUNCTION, NUMBER } }), rets = a_type("tuple", { tuple = {} }) }), ["getinfo"] = a_type("poly", { types = { - a_type("function", { args = a_type("tuple", { ANY }), rets = a_type("tuple", { DEBUG_GETINFO_TABLE }) }), - a_type("function", { args = a_type("tuple", { ANY, STRING }), rets = a_type("tuple", { DEBUG_GETINFO_TABLE }) }), - a_type("function", { args = a_type("tuple", { ANY, ANY, STRING }), rets = a_type("tuple", { DEBUG_GETINFO_TABLE }) }), + a_function({ args = a_type("tuple", { tuple = { ANY } }), rets = a_type("tuple", { tuple = { DEBUG_GETINFO_TABLE } }) }), + a_function({ args = a_type("tuple", { tuple = { ANY, STRING } }), rets = a_type("tuple", { tuple = { DEBUG_GETINFO_TABLE } }) }), + a_function({ args = a_type("tuple", { tuple = { ANY, ANY, STRING } }), rets = a_type("tuple", { tuple = { DEBUG_GETINFO_TABLE } }) }), } }), }, }), ["io"] = a_record({ fields = { - ["close"] = a_type("function", { args = a_type("tuple", { OPT(NOMINAL_FILE) }), rets = a_type("tuple", { BOOLEAN, STRING }) }), - ["flush"] = a_type("function", { args = a_type("tuple", {}), rets = a_type("tuple", {}) }), - ["input"] = a_type("function", { args = a_type("tuple", { OPT(a_type("union", { types = { STRING, NOMINAL_FILE } })) }), rets = a_type("tuple", { NOMINAL_FILE }) }), + ["close"] = a_function({ args = a_type("tuple", { tuple = { OPT(NOMINAL_FILE) } }), rets = a_type("tuple", { tuple = { BOOLEAN, STRING } }) }), + ["flush"] = a_function({ args = a_type("tuple", { tuple = {} }), rets = a_type("tuple", { tuple = {} }) }), + ["input"] = a_function({ args = a_type("tuple", { tuple = { OPT(a_type("union", { types = { STRING, NOMINAL_FILE } })) } }), rets = a_type("tuple", { tuple = { NOMINAL_FILE } }) }), ["lines"] = a_file_reader(function(ctor, args, rets) - return a_type("function", { args = ctor(args), rets = a_type("tuple", { - a_type("function", { args = a_type("tuple", {}), rets = ctor(rets) }), - }), }) + return a_function({ args = ctor(args), rets = a_type("tuple", { tuple = { + a_function({ args = a_type("tuple", { tuple = {} }), rets = ctor(rets) }), +} }), }) end), - ["open"] = a_type("function", { args = a_type("tuple", { STRING, OPT(STRING) }), rets = a_type("tuple", { NOMINAL_FILE, STRING }) }), - ["output"] = a_type("function", { args = a_type("tuple", { OPT(a_type("union", { types = { STRING, NOMINAL_FILE } })) }), rets = a_type("tuple", { NOMINAL_FILE }) }), - ["popen"] = a_type("function", { args = a_type("tuple", { STRING, OPT(STRING) }), rets = a_type("tuple", { NOMINAL_FILE, STRING }) }), + ["open"] = a_function({ args = a_type("tuple", { tuple = { STRING, OPT(STRING) } }), rets = a_type("tuple", { tuple = { NOMINAL_FILE, STRING } }) }), + ["output"] = a_function({ args = a_type("tuple", { tuple = { OPT(a_type("union", { types = { STRING, NOMINAL_FILE } })) } }), rets = a_type("tuple", { tuple = { NOMINAL_FILE } }) }), + ["popen"] = a_function({ args = a_type("tuple", { tuple = { STRING, OPT(STRING) } }), rets = a_type("tuple", { tuple = { NOMINAL_FILE, STRING } }) }), ["read"] = a_file_reader(function(ctor, args, rets) - return a_type("function", { args = ctor(args), rets = ctor(rets) }) + return a_function({ args = ctor(args), rets = ctor(rets) }) end), ["stderr"] = NOMINAL_FILE, ["stdin"] = NOMINAL_FILE, ["stdout"] = NOMINAL_FILE, - ["tmpfile"] = a_type("function", { args = a_type("tuple", {}), rets = a_type("tuple", { NOMINAL_FILE }) }), - ["type"] = a_type("function", { args = a_type("tuple", { ANY }), rets = a_type("tuple", { STRING }) }), - ["write"] = a_type("function", { args = a_vararg({ a_type("union", { types = { STRING, NUMBER } }) }), rets = a_type("tuple", { NOMINAL_FILE, STRING }) }), + ["tmpfile"] = a_function({ args = a_type("tuple", { tuple = {} }), rets = a_type("tuple", { tuple = { NOMINAL_FILE } }) }), + ["type"] = a_function({ args = a_type("tuple", { tuple = { ANY } }), rets = a_type("tuple", { tuple = { STRING } }) }), + ["write"] = a_function({ args = a_vararg({ a_type("union", { types = { STRING, NUMBER } }) }), rets = a_type("tuple", { tuple = { NOMINAL_FILE, STRING } }) }), }, }), ["math"] = a_record({ fields = { ["abs"] = a_type("poly", { types = { - a_type("function", { args = a_type("tuple", { INTEGER }), rets = a_type("tuple", { INTEGER }) }), - a_type("function", { args = a_type("tuple", { NUMBER }), rets = a_type("tuple", { NUMBER }) }), + a_function({ args = a_type("tuple", { tuple = { INTEGER } }), rets = a_type("tuple", { tuple = { INTEGER } }) }), + a_function({ args = a_type("tuple", { tuple = { NUMBER } }), rets = a_type("tuple", { tuple = { NUMBER } }) }), } }), - ["acos"] = a_type("function", { args = a_type("tuple", { NUMBER }), rets = a_type("tuple", { NUMBER }) }), - ["asin"] = a_type("function", { args = a_type("tuple", { NUMBER }), rets = a_type("tuple", { NUMBER }) }), - ["atan"] = a_type("function", { args = a_type("tuple", { NUMBER, OPT(NUMBER) }), rets = a_type("tuple", { NUMBER }) }), - ["atan2"] = a_type("function", { args = a_type("tuple", { NUMBER, NUMBER }), rets = a_type("tuple", { NUMBER }) }), - ["ceil"] = a_type("function", { args = a_type("tuple", { NUMBER }), rets = a_type("tuple", { INTEGER }) }), - ["cos"] = a_type("function", { args = a_type("tuple", { NUMBER }), rets = a_type("tuple", { NUMBER }) }), - ["cosh"] = a_type("function", { args = a_type("tuple", { NUMBER }), rets = a_type("tuple", { NUMBER }) }), - ["deg"] = a_type("function", { args = a_type("tuple", { NUMBER }), rets = a_type("tuple", { NUMBER }) }), - ["exp"] = a_type("function", { args = a_type("tuple", { NUMBER }), rets = a_type("tuple", { NUMBER }) }), - ["floor"] = a_type("function", { args = a_type("tuple", { NUMBER }), rets = a_type("tuple", { INTEGER }) }), + ["acos"] = a_function({ args = a_type("tuple", { tuple = { NUMBER } }), rets = a_type("tuple", { tuple = { NUMBER } }) }), + ["asin"] = a_function({ args = a_type("tuple", { tuple = { NUMBER } }), rets = a_type("tuple", { tuple = { NUMBER } }) }), + ["atan"] = a_function({ args = a_type("tuple", { tuple = { NUMBER, OPT(NUMBER) } }), rets = a_type("tuple", { tuple = { NUMBER } }) }), + ["atan2"] = a_function({ args = a_type("tuple", { tuple = { NUMBER, NUMBER } }), rets = a_type("tuple", { tuple = { NUMBER } }) }), + ["ceil"] = a_function({ args = a_type("tuple", { tuple = { NUMBER } }), rets = a_type("tuple", { tuple = { INTEGER } }) }), + ["cos"] = a_function({ args = a_type("tuple", { tuple = { NUMBER } }), rets = a_type("tuple", { tuple = { NUMBER } }) }), + ["cosh"] = a_function({ args = a_type("tuple", { tuple = { NUMBER } }), rets = a_type("tuple", { tuple = { NUMBER } }) }), + ["deg"] = a_function({ args = a_type("tuple", { tuple = { NUMBER } }), rets = a_type("tuple", { tuple = { NUMBER } }) }), + ["exp"] = a_function({ args = a_type("tuple", { tuple = { NUMBER } }), rets = a_type("tuple", { tuple = { NUMBER } }) }), + ["floor"] = a_function({ args = a_type("tuple", { tuple = { NUMBER } }), rets = a_type("tuple", { tuple = { INTEGER } }) }), ["fmod"] = a_type("poly", { types = { - a_type("function", { args = a_type("tuple", { INTEGER, INTEGER }), rets = a_type("tuple", { INTEGER }) }), - a_type("function", { args = a_type("tuple", { NUMBER, NUMBER }), rets = a_type("tuple", { NUMBER }) }), + a_function({ args = a_type("tuple", { tuple = { INTEGER, INTEGER } }), rets = a_type("tuple", { tuple = { INTEGER } }) }), + a_function({ args = a_type("tuple", { tuple = { NUMBER, NUMBER } }), rets = a_type("tuple", { tuple = { NUMBER } }) }), } }), - ["frexp"] = a_type("function", { args = a_type("tuple", { NUMBER }), rets = a_type("tuple", { NUMBER, NUMBER }) }), + ["frexp"] = a_function({ args = a_type("tuple", { tuple = { NUMBER } }), rets = a_type("tuple", { tuple = { NUMBER, NUMBER } }) }), ["huge"] = NUMBER, - ["ldexp"] = a_type("function", { args = a_type("tuple", { NUMBER, NUMBER }), rets = a_type("tuple", { NUMBER }) }), - ["log"] = a_type("function", { args = a_type("tuple", { NUMBER, OPT(NUMBER) }), rets = a_type("tuple", { NUMBER }) }), - ["log10"] = a_type("function", { args = a_type("tuple", { NUMBER }), rets = a_type("tuple", { NUMBER }) }), + ["ldexp"] = a_function({ args = a_type("tuple", { tuple = { NUMBER, NUMBER } }), rets = a_type("tuple", { tuple = { NUMBER } }) }), + ["log"] = a_function({ args = a_type("tuple", { tuple = { NUMBER, OPT(NUMBER) } }), rets = a_type("tuple", { tuple = { NUMBER } }) }), + ["log10"] = a_function({ args = a_type("tuple", { tuple = { NUMBER } }), rets = a_type("tuple", { tuple = { NUMBER } }) }), ["max"] = a_type("poly", { types = { - a_type("function", { args = a_vararg({ INTEGER }), rets = a_type("tuple", { INTEGER }) }), - a_gfunction(1, function(a) return { args = a_vararg({ a }), rets = a_type("tuple", { a }) } end), - a_type("function", { args = a_vararg({ a_type("union", { types = { NUMBER, INTEGER } }) }), rets = a_type("tuple", { NUMBER }) }), - a_type("function", { args = a_vararg({ ANY }), rets = a_type("tuple", { ANY }) }), + a_function({ args = a_vararg({ INTEGER }), rets = a_type("tuple", { tuple = { INTEGER } }) }), + a_gfunction(1, function(a) return { args = a_vararg({ a }), rets = a_type("tuple", { tuple = { a } }) } end), + a_function({ args = a_vararg({ a_type("union", { types = { NUMBER, INTEGER } }) }), rets = a_type("tuple", { tuple = { NUMBER } }) }), + a_function({ args = a_vararg({ ANY }), rets = a_type("tuple", { tuple = { ANY } }) }), } }), ["maxinteger"] = a_type("integer", { needs_compat = true }), ["min"] = a_type("poly", { types = { - a_type("function", { args = a_vararg({ INTEGER }), rets = a_type("tuple", { INTEGER }) }), - a_gfunction(1, function(a) return { args = a_vararg({ a }), rets = a_type("tuple", { a }) } end), - a_type("function", { args = a_vararg({ a_type("union", { types = { NUMBER, INTEGER } }) }), rets = a_type("tuple", { NUMBER }) }), - a_type("function", { args = a_vararg({ ANY }), rets = a_type("tuple", { ANY }) }), + a_function({ args = a_vararg({ INTEGER }), rets = a_type("tuple", { tuple = { INTEGER } }) }), + a_gfunction(1, function(a) return { args = a_vararg({ a }), rets = a_type("tuple", { tuple = { a } }) } end), + a_function({ args = a_vararg({ a_type("union", { types = { NUMBER, INTEGER } }) }), rets = a_type("tuple", { tuple = { NUMBER } }) }), + a_function({ args = a_vararg({ ANY }), rets = a_type("tuple", { tuple = { ANY } }) }), } }), ["mininteger"] = a_type("integer", { needs_compat = true }), - ["modf"] = a_type("function", { args = a_type("tuple", { NUMBER }), rets = a_type("tuple", { INTEGER, NUMBER }) }), + ["modf"] = a_function({ args = a_type("tuple", { tuple = { NUMBER } }), rets = a_type("tuple", { tuple = { INTEGER, NUMBER } }) }), ["pi"] = NUMBER, - ["pow"] = a_type("function", { args = a_type("tuple", { NUMBER, NUMBER }), rets = a_type("tuple", { NUMBER }) }), - ["rad"] = a_type("function", { args = a_type("tuple", { NUMBER }), rets = a_type("tuple", { NUMBER }) }), + ["pow"] = a_function({ args = a_type("tuple", { tuple = { NUMBER, NUMBER } }), rets = a_type("tuple", { tuple = { NUMBER } }) }), + ["rad"] = a_function({ args = a_type("tuple", { tuple = { NUMBER } }), rets = a_type("tuple", { tuple = { NUMBER } }) }), ["random"] = a_type("poly", { types = { - a_type("function", { args = a_type("tuple", { NUMBER, OPT(NUMBER) }), rets = a_type("tuple", { INTEGER }) }), - a_type("function", { args = a_type("tuple", {}), rets = a_type("tuple", { NUMBER }) }), + a_function({ args = a_type("tuple", { tuple = { NUMBER, OPT(NUMBER) } }), rets = a_type("tuple", { tuple = { INTEGER } }) }), + a_function({ args = a_type("tuple", { tuple = {} }), rets = a_type("tuple", { tuple = { NUMBER } }) }), } }), - ["randomseed"] = a_type("function", { args = a_type("tuple", { NUMBER, NUMBER }), rets = a_type("tuple", { INTEGER, INTEGER }) }), - ["sin"] = a_type("function", { args = a_type("tuple", { NUMBER }), rets = a_type("tuple", { NUMBER }) }), - ["sinh"] = a_type("function", { args = a_type("tuple", { NUMBER }), rets = a_type("tuple", { NUMBER }) }), - ["sqrt"] = a_type("function", { args = a_type("tuple", { NUMBER }), rets = a_type("tuple", { NUMBER }) }), - ["tan"] = a_type("function", { args = a_type("tuple", { NUMBER }), rets = a_type("tuple", { NUMBER }) }), - ["tanh"] = a_type("function", { args = a_type("tuple", { NUMBER }), rets = a_type("tuple", { NUMBER }) }), - ["tointeger"] = a_type("function", { args = a_type("tuple", { ANY }), rets = a_type("tuple", { INTEGER }) }), - ["type"] = a_type("function", { args = a_type("tuple", { ANY }), rets = a_type("tuple", { STRING }) }), - ["ult"] = a_type("function", { args = a_type("tuple", { NUMBER, NUMBER }), rets = a_type("tuple", { BOOLEAN }) }), + ["randomseed"] = a_function({ args = a_type("tuple", { tuple = { NUMBER, NUMBER } }), rets = a_type("tuple", { tuple = { INTEGER, INTEGER } }) }), + ["sin"] = a_function({ args = a_type("tuple", { tuple = { NUMBER } }), rets = a_type("tuple", { tuple = { NUMBER } }) }), + ["sinh"] = a_function({ args = a_type("tuple", { tuple = { NUMBER } }), rets = a_type("tuple", { tuple = { NUMBER } }) }), + ["sqrt"] = a_function({ args = a_type("tuple", { tuple = { NUMBER } }), rets = a_type("tuple", { tuple = { NUMBER } }) }), + ["tan"] = a_function({ args = a_type("tuple", { tuple = { NUMBER } }), rets = a_type("tuple", { tuple = { NUMBER } }) }), + ["tanh"] = a_function({ args = a_type("tuple", { tuple = { NUMBER } }), rets = a_type("tuple", { tuple = { NUMBER } }) }), + ["tointeger"] = a_function({ args = a_type("tuple", { tuple = { ANY } }), rets = a_type("tuple", { tuple = { INTEGER } }) }), + ["type"] = a_function({ args = a_type("tuple", { tuple = { ANY } }), rets = a_type("tuple", { tuple = { STRING } }) }), + ["ult"] = a_function({ args = a_type("tuple", { tuple = { NUMBER, NUMBER } }), rets = a_type("tuple", { tuple = { BOOLEAN } }) }), }, }), ["os"] = a_record({ fields = { - ["clock"] = a_type("function", { args = a_type("tuple", {}), rets = a_type("tuple", { NUMBER }) }), + ["clock"] = a_function({ args = a_type("tuple", { tuple = {} }), rets = a_type("tuple", { tuple = { NUMBER } }) }), ["date"] = a_type("poly", { types = { - a_type("function", { args = a_type("tuple", {}), rets = a_type("tuple", { STRING }) }), - a_type("function", { args = a_type("tuple", { an_enum({ "!*t", "*t" }), OPT(NUMBER) }), rets = a_type("tuple", { OS_DATE_TABLE }) }), - a_type("function", { args = a_type("tuple", { OPT(STRING), OPT(NUMBER) }), rets = a_type("tuple", { STRING }) }), + a_function({ args = a_type("tuple", { tuple = {} }), rets = a_type("tuple", { tuple = { STRING } }) }), + a_function({ args = a_type("tuple", { tuple = { an_enum({ "!*t", "*t" }), OPT(NUMBER) } }), rets = a_type("tuple", { tuple = { OS_DATE_TABLE } }) }), + a_function({ args = a_type("tuple", { tuple = { OPT(STRING), OPT(NUMBER) } }), rets = a_type("tuple", { tuple = { STRING } }) }), } }), - ["difftime"] = a_type("function", { args = a_type("tuple", { NUMBER, NUMBER }), rets = a_type("tuple", { NUMBER }) }), - ["execute"] = a_type("function", { args = a_type("tuple", { STRING }), rets = a_type("tuple", { BOOLEAN, STRING, INTEGER }) }), - ["exit"] = a_type("function", { args = a_type("tuple", { OPT(a_type("union", { types = { NUMBER, BOOLEAN } })), OPT(BOOLEAN) }), rets = a_type("tuple", {}) }), - ["getenv"] = a_type("function", { args = a_type("tuple", { STRING }), rets = a_type("tuple", { STRING }) }), - ["remove"] = a_type("function", { args = a_type("tuple", { STRING }), rets = a_type("tuple", { BOOLEAN, STRING }) }), - ["rename"] = a_type("function", { args = a_type("tuple", { STRING, STRING }), rets = a_type("tuple", { BOOLEAN, STRING }) }), - ["setlocale"] = a_type("function", { args = a_type("tuple", { STRING, OPT(STRING) }), rets = a_type("tuple", { STRING }) }), - ["time"] = a_type("function", { args = a_type("tuple", { OPT(OS_DATE_TABLE) }), rets = a_type("tuple", { INTEGER }) }), - ["tmpname"] = a_type("function", { args = a_type("tuple", {}), rets = a_type("tuple", { STRING }) }), + ["difftime"] = a_function({ args = a_type("tuple", { tuple = { NUMBER, NUMBER } }), rets = a_type("tuple", { tuple = { NUMBER } }) }), + ["execute"] = a_function({ args = a_type("tuple", { tuple = { STRING } }), rets = a_type("tuple", { tuple = { BOOLEAN, STRING, INTEGER } }) }), + ["exit"] = a_function({ args = a_type("tuple", { tuple = { OPT(a_type("union", { types = { NUMBER, BOOLEAN } })), OPT(BOOLEAN) } }), rets = a_type("tuple", { tuple = {} }) }), + ["getenv"] = a_function({ args = a_type("tuple", { tuple = { STRING } }), rets = a_type("tuple", { tuple = { STRING } }) }), + ["remove"] = a_function({ args = a_type("tuple", { tuple = { STRING } }), rets = a_type("tuple", { tuple = { BOOLEAN, STRING } }) }), + ["rename"] = a_function({ args = a_type("tuple", { tuple = { STRING, STRING } }), rets = a_type("tuple", { tuple = { BOOLEAN, STRING } }) }), + ["setlocale"] = a_function({ args = a_type("tuple", { tuple = { STRING, OPT(STRING) } }), rets = a_type("tuple", { tuple = { STRING } }) }), + ["time"] = a_function({ args = a_type("tuple", { tuple = { OPT(OS_DATE_TABLE) } }), rets = a_type("tuple", { tuple = { INTEGER } }) }), + ["tmpname"] = a_function({ args = a_type("tuple", { tuple = {} }), rets = a_type("tuple", { tuple = { STRING } }) }), }, }), ["package"] = a_record({ @@ -6042,75 +6059,75 @@ local function init_globals(lax) ["config"] = STRING, ["cpath"] = STRING, ["loaded"] = a_type("map", { keys = STRING, values = ANY }), - ["loaders"] = a_type("array", { elements = a_type("function", { args = a_type("tuple", { STRING }), rets = a_type("tuple", { ANY, ANY }) }) }), - ["loadlib"] = a_type("function", { args = a_type("tuple", { STRING, STRING }), rets = a_type("tuple", { FUNCTION }) }), + ["loaders"] = a_type("array", { elements = a_function({ args = a_type("tuple", { tuple = { STRING } }), rets = a_type("tuple", { tuple = { ANY, ANY } }) }) }), + ["loadlib"] = a_function({ args = a_type("tuple", { tuple = { STRING, STRING } }), rets = a_type("tuple", { tuple = { FUNCTION } }) }), ["path"] = STRING, ["preload"] = TABLE, - ["searchers"] = a_type("array", { elements = a_type("function", { args = a_type("tuple", { STRING }), rets = a_type("tuple", { ANY, ANY }) }) }), - ["searchpath"] = a_type("function", { args = a_type("tuple", { STRING, STRING, OPT(STRING), OPT(STRING) }), rets = a_type("tuple", { STRING, STRING }) }), + ["searchers"] = a_type("array", { elements = a_function({ args = a_type("tuple", { tuple = { STRING } }), rets = a_type("tuple", { tuple = { ANY, ANY } }) }) }), + ["searchpath"] = a_function({ args = a_type("tuple", { tuple = { STRING, STRING, OPT(STRING), OPT(STRING) } }), rets = a_type("tuple", { tuple = { STRING, STRING } }) }), }, }), ["string"] = a_record({ fields = { ["byte"] = a_type("poly", { types = { - a_type("function", { args = a_type("tuple", { STRING, OPT(NUMBER) }), rets = a_type("tuple", { INTEGER }) }), - a_type("function", { args = a_type("tuple", { STRING, NUMBER, NUMBER }), rets = a_vararg({ INTEGER }) }), + a_function({ args = a_type("tuple", { tuple = { STRING, OPT(NUMBER) } }), rets = a_type("tuple", { tuple = { INTEGER } }) }), + a_function({ args = a_type("tuple", { tuple = { STRING, NUMBER, NUMBER } }), rets = a_vararg({ INTEGER }) }), } }), - ["char"] = a_type("function", { args = a_vararg({ NUMBER }), rets = a_type("tuple", { STRING }) }), - ["dump"] = a_type("function", { args = a_type("tuple", { FUNCTION, OPT(BOOLEAN) }), rets = a_type("tuple", { STRING }) }), - ["find"] = a_type("function", { args = a_type("tuple", { STRING, STRING, OPT(NUMBER), OPT(BOOLEAN) }), rets = a_vararg({ INTEGER, INTEGER, STRING }) }), - ["format"] = a_type("function", { args = a_vararg({ STRING, ANY }), rets = a_type("tuple", { STRING }) }), - ["gmatch"] = a_type("function", { args = a_type("tuple", { STRING, STRING }), rets = a_type("tuple", { - a_type("function", { args = a_type("tuple", {}), rets = a_vararg({ STRING }) }), - }), }), + ["char"] = a_function({ args = a_vararg({ NUMBER }), rets = a_type("tuple", { tuple = { STRING } }) }), + ["dump"] = a_function({ args = a_type("tuple", { tuple = { FUNCTION, OPT(BOOLEAN) } }), rets = a_type("tuple", { tuple = { STRING } }) }), + ["find"] = a_function({ args = a_type("tuple", { tuple = { STRING, STRING, OPT(NUMBER), OPT(BOOLEAN) } }), rets = a_vararg({ INTEGER, INTEGER, STRING }) }), + ["format"] = a_function({ args = a_vararg({ STRING, ANY }), rets = a_type("tuple", { tuple = { STRING } }) }), + ["gmatch"] = a_function({ args = a_type("tuple", { tuple = { STRING, STRING } }), rets = a_type("tuple", { tuple = { + a_function({ args = a_type("tuple", { tuple = {} }), rets = a_vararg({ STRING }) }), +} }), }), ["gsub"] = a_type("poly", { types = { - a_type("function", { args = a_type("tuple", { STRING, STRING, a_type("map", { keys = STRING, values = STRING }), OPT(NUMBER) }), rets = a_type("tuple", { STRING, INTEGER }) }), - a_type("function", { args = a_type("tuple", { STRING, STRING, a_type("function", { args = a_vararg({ STRING }), rets = a_type("tuple", { STRING }) }), OPT(NUMBER) }), rets = a_type("tuple", { STRING, INTEGER }) }), - a_type("function", { args = a_type("tuple", { STRING, STRING, a_type("function", { args = a_vararg({ STRING }), rets = a_type("tuple", { NUMBER }) }), OPT(NUMBER) }), rets = a_type("tuple", { STRING, INTEGER }) }), - a_type("function", { args = a_type("tuple", { STRING, STRING, a_type("function", { args = a_vararg({ STRING }), rets = a_type("tuple", { BOOLEAN }) }), OPT(NUMBER) }), rets = a_type("tuple", { STRING, INTEGER }) }), - a_type("function", { args = a_type("tuple", { STRING, STRING, a_type("function", { args = a_vararg({ STRING }), rets = a_type("tuple", {}) }), OPT(NUMBER) }), rets = a_type("tuple", { STRING, INTEGER }) }), - a_type("function", { args = a_type("tuple", { STRING, STRING, OPT(STRING), OPT(NUMBER) }), rets = a_type("tuple", { STRING, INTEGER }) }), + a_function({ args = a_type("tuple", { tuple = { STRING, STRING, a_type("map", { keys = STRING, values = STRING }), OPT(NUMBER) } }), rets = a_type("tuple", { tuple = { STRING, INTEGER } }) }), + a_function({ args = a_type("tuple", { tuple = { STRING, STRING, a_function({ args = a_vararg({ STRING }), rets = a_type("tuple", { tuple = { STRING } }) }), OPT(NUMBER) } }), rets = a_type("tuple", { tuple = { STRING, INTEGER } }) }), + a_function({ args = a_type("tuple", { tuple = { STRING, STRING, a_function({ args = a_vararg({ STRING }), rets = a_type("tuple", { tuple = { NUMBER } }) }), OPT(NUMBER) } }), rets = a_type("tuple", { tuple = { STRING, INTEGER } }) }), + a_function({ args = a_type("tuple", { tuple = { STRING, STRING, a_function({ args = a_vararg({ STRING }), rets = a_type("tuple", { tuple = { BOOLEAN } }) }), OPT(NUMBER) } }), rets = a_type("tuple", { tuple = { STRING, INTEGER } }) }), + a_function({ args = a_type("tuple", { tuple = { STRING, STRING, a_function({ args = a_vararg({ STRING }), rets = a_type("tuple", { tuple = {} }) }), OPT(NUMBER) } }), rets = a_type("tuple", { tuple = { STRING, INTEGER } }) }), + a_function({ args = a_type("tuple", { tuple = { STRING, STRING, OPT(STRING), OPT(NUMBER) } }), rets = a_type("tuple", { tuple = { STRING, INTEGER } }) }), } }), - ["len"] = a_type("function", { args = a_type("tuple", { STRING }), rets = a_type("tuple", { INTEGER }) }), - ["lower"] = a_type("function", { args = a_type("tuple", { STRING }), rets = a_type("tuple", { STRING }) }), - ["match"] = a_type("function", { args = a_type("tuple", { STRING, OPT(STRING), OPT(NUMBER) }), rets = a_vararg({ STRING }) }), - ["pack"] = a_type("function", { args = a_vararg({ STRING, ANY }), rets = a_type("tuple", { STRING }) }), - ["packsize"] = a_type("function", { args = a_type("tuple", { STRING }), rets = a_type("tuple", { INTEGER }) }), - ["rep"] = a_type("function", { args = a_type("tuple", { STRING, NUMBER, OPT(STRING) }), rets = a_type("tuple", { STRING }) }), - ["reverse"] = a_type("function", { args = a_type("tuple", { STRING }), rets = a_type("tuple", { STRING }) }), - ["sub"] = a_type("function", { args = a_type("tuple", { STRING, NUMBER, OPT(NUMBER) }), rets = a_type("tuple", { STRING }) }), - ["unpack"] = a_type("function", { args = a_type("tuple", { STRING, STRING, OPT(NUMBER) }), rets = a_vararg({ ANY }) }), - ["upper"] = a_type("function", { args = a_type("tuple", { STRING }), rets = a_type("tuple", { STRING }) }), + ["len"] = a_function({ args = a_type("tuple", { tuple = { STRING } }), rets = a_type("tuple", { tuple = { INTEGER } }) }), + ["lower"] = a_function({ args = a_type("tuple", { tuple = { STRING } }), rets = a_type("tuple", { tuple = { STRING } }) }), + ["match"] = a_function({ args = a_type("tuple", { tuple = { STRING, OPT(STRING), OPT(NUMBER) } }), rets = a_vararg({ STRING }) }), + ["pack"] = a_function({ args = a_vararg({ STRING, ANY }), rets = a_type("tuple", { tuple = { STRING } }) }), + ["packsize"] = a_function({ args = a_type("tuple", { tuple = { STRING } }), rets = a_type("tuple", { tuple = { INTEGER } }) }), + ["rep"] = a_function({ args = a_type("tuple", { tuple = { STRING, NUMBER, OPT(STRING) } }), rets = a_type("tuple", { tuple = { STRING } }) }), + ["reverse"] = a_function({ args = a_type("tuple", { tuple = { STRING } }), rets = a_type("tuple", { tuple = { STRING } }) }), + ["sub"] = a_function({ args = a_type("tuple", { tuple = { STRING, NUMBER, OPT(NUMBER) } }), rets = a_type("tuple", { tuple = { STRING } }) }), + ["unpack"] = a_function({ args = a_type("tuple", { tuple = { STRING, STRING, OPT(NUMBER) } }), rets = a_vararg({ ANY }) }), + ["upper"] = a_function({ args = a_type("tuple", { tuple = { STRING } }), rets = a_type("tuple", { tuple = { STRING } }) }), }, }), ["table"] = a_record({ fields = { - ["concat"] = a_type("function", { args = a_type("tuple", { a_type("array", { elements = a_type("union", { types = { STRING, NUMBER } }) }), OPT(STRING), OPT(NUMBER), OPT(NUMBER) }), rets = a_type("tuple", { STRING }) }), + ["concat"] = a_function({ args = a_type("tuple", { tuple = { a_type("array", { elements = a_type("union", { types = { STRING, NUMBER } }) }), OPT(STRING), OPT(NUMBER), OPT(NUMBER) } }), rets = a_type("tuple", { tuple = { STRING } }) }), ["insert"] = a_type("poly", { types = { - a_gfunction(1, function(a) return { args = a_type("tuple", { a_type("array", { elements = a }), NUMBER, a }), rets = a_type("tuple", {}) } end), - a_gfunction(1, function(a) return { args = a_type("tuple", { a_type("array", { elements = a }), a }), rets = a_type("tuple", {}) } end), + a_gfunction(1, function(a) return { args = a_type("tuple", { tuple = { a_type("array", { elements = a }), NUMBER, a } }), rets = a_type("tuple", { tuple = {} }) } end), + a_gfunction(1, function(a) return { args = a_type("tuple", { tuple = { a_type("array", { elements = a }), a } }), rets = a_type("tuple", { tuple = {} }) } end), } }), ["move"] = a_type("poly", { types = { - a_gfunction(1, function(a) return { args = a_type("tuple", { a_type("array", { elements = a }), NUMBER, NUMBER, NUMBER }), rets = a_type("tuple", { a_type("array", { elements = a }) }) } end), - a_gfunction(1, function(a) return { args = a_type("tuple", { a_type("array", { elements = a }), NUMBER, NUMBER, NUMBER, a_type("array", { elements = a }) }), rets = a_type("tuple", { a_type("array", { elements = a }) }) } end), + a_gfunction(1, function(a) return { args = a_type("tuple", { tuple = { a_type("array", { elements = a }), NUMBER, NUMBER, NUMBER } }), rets = a_type("tuple", { tuple = { a_type("array", { elements = a }) } }) } end), + a_gfunction(1, function(a) return { args = a_type("tuple", { tuple = { a_type("array", { elements = a }), NUMBER, NUMBER, NUMBER, a_type("array", { elements = a }) } }), rets = a_type("tuple", { tuple = { a_type("array", { elements = a }) } }) } end), } }), - ["pack"] = a_type("function", { args = a_vararg({ ANY }), rets = a_type("tuple", { TABLE }) }), - ["remove"] = a_gfunction(1, function(a) return { args = a_type("tuple", { a_type("array", { elements = a }), OPT(NUMBER) }), rets = a_type("tuple", { a }) } end), - ["sort"] = a_gfunction(1, function(a) return { args = a_type("tuple", { a_type("array", { elements = a }), OPT(TABLE_SORT_FUNCTION) }), rets = a_type("tuple", {}) } end), - ["unpack"] = a_gfunction(1, function(a) return { needs_compat = true, args = a_type("tuple", { a_type("array", { elements = a }), OPT(NUMBER), OPT(NUMBER) }), rets = a_vararg({ a }) } end), + ["pack"] = a_function({ args = a_vararg({ ANY }), rets = a_type("tuple", { tuple = { TABLE } }) }), + ["remove"] = a_gfunction(1, function(a) return { args = a_type("tuple", { tuple = { a_type("array", { elements = a }), OPT(NUMBER) } }), rets = a_type("tuple", { tuple = { a } }) } end), + ["sort"] = a_gfunction(1, function(a) return { args = a_type("tuple", { tuple = { a_type("array", { elements = a }), OPT(TABLE_SORT_FUNCTION) } }), rets = a_type("tuple", { tuple = {} }) } end), + ["unpack"] = a_gfunction(1, function(a) return { needs_compat = true, args = a_type("tuple", { tuple = { a_type("array", { elements = a }), OPT(NUMBER), OPT(NUMBER) } }), rets = a_vararg({ a }) } end), }, }), ["utf8"] = a_record({ fields = { - ["char"] = a_type("function", { args = a_vararg({ NUMBER }), rets = a_type("tuple", { STRING }) }), + ["char"] = a_function({ args = a_vararg({ NUMBER }), rets = a_type("tuple", { tuple = { STRING } }) }), ["charpattern"] = STRING, - ["codepoint"] = a_type("function", { args = a_type("tuple", { STRING, OPT(NUMBER), OPT(NUMBER) }), rets = a_vararg({ INTEGER }) }), - ["codes"] = a_type("function", { args = a_type("tuple", { STRING }), rets = a_type("tuple", { - a_type("function", { args = a_type("tuple", { STRING, OPT(NUMBER) }), rets = a_type("tuple", { NUMBER, NUMBER }) }), - }), }), - ["len"] = a_type("function", { args = a_type("tuple", { STRING, NUMBER, NUMBER }), rets = a_type("tuple", { INTEGER }) }), - ["offset"] = a_type("function", { args = a_type("tuple", { STRING, NUMBER, NUMBER }), rets = a_type("tuple", { INTEGER }) }), + ["codepoint"] = a_function({ args = a_type("tuple", { tuple = { STRING, OPT(NUMBER), OPT(NUMBER) } }), rets = a_vararg({ INTEGER }) }), + ["codes"] = a_function({ args = a_type("tuple", { tuple = { STRING } }), rets = a_type("tuple", { tuple = { + a_function({ args = a_type("tuple", { tuple = { STRING, OPT(NUMBER) } }), rets = a_type("tuple", { tuple = { NUMBER, NUMBER } }) }), +} }), }), + ["len"] = a_function({ args = a_type("tuple", { tuple = { STRING, NUMBER, NUMBER } }), rets = a_type("tuple", { tuple = { INTEGER } }) }), + ["offset"] = a_function({ args = a_type("tuple", { tuple = { STRING, NUMBER, NUMBER } }), rets = a_type("tuple", { tuple = { INTEGER } }) }), }, }), ["_VERSION"] = STRING, @@ -6399,7 +6416,7 @@ tl.type_check = function(ast, opts) if is_typetype(t) then return union_type(t.def), t.def elseif t.typename == "tuple" then - return union_type(t[1]), t[1] + return union_type(t.tuple[1]), t.tuple[1] elseif t.typename == "nominal" then local typetype = t.found or find_type(t.names) if not typetype then @@ -6504,7 +6521,7 @@ tl.type_check = function(ast, opts) if f.min_arity then return end - local tuple = f.args + local tuple = f.args.tuple local n = #tuple if f.args.is_va then n = n - 1 @@ -6520,9 +6537,10 @@ tl.type_check = function(ast, opts) end local function show_arity(f) - return f.min_arity < #f.args and - "at least " .. f.min_arity .. (f.args.is_va and "" or " and at most " .. #f.args) or - tostring(#f.args or 0) + local nfargs = #f.args.tuple + return f.min_arity < nfargs and + "at least " .. f.min_arity .. (f.args.is_va and "" or " and at most " .. nfargs) or + tostring(nfargs or 0) end local function resolve_typetype(t) @@ -6603,10 +6621,6 @@ tl.type_check = function(ast, opts) copy.xend = t.xend copy.names = t.names - for i, tf in ipairs(t) do - copy[i], same = resolve(tf, same) - end - if t.typename == "array" then copy.elements, same = resolve(t.elements, same) @@ -6686,6 +6700,10 @@ tl.type_check = function(ast, opts) end elseif t.typename == "tuple" then copy.is_va = t.is_va + copy.tuple = {} + for i, tf in ipairs(t.tuple) do + copy.tuple[i], same = resolve(tf, same) + end end copy.typeid = same and orig_t.typeid or new_typeid() @@ -6722,7 +6740,7 @@ tl.type_check = function(ast, opts) local function resolve_tuple(t) if t.typename == "tuple" then - t = t[1] + t = t.tuple[1] end if t == nil then return NIL @@ -6821,14 +6839,24 @@ tl.type_check = function(ast, opts) end end + local function type_at(w, t) + t.x = w.x + t.y = w.y + t.filename = filename + return t + end + local function resolve_typevars_at(where, t) assert(where) - local ok, typ, errs = resolve_typevars(t) + local ok, ret, errs = resolve_typevars(t) if not ok then assert(where.y) add_errs_prefixing(where, errs, errors, "") end - return typ + if ret == t or t.typename == "typevar" then + ret = shallow_copy_table(ret) + end + return type_at(where, ret) end local function infer_at(where, t) @@ -6836,7 +6864,9 @@ tl.type_check = function(ast, opts) if ret.typename == "invalid" then ret = t end - ret = (ret ~= t) and ret or shallow_copy_table(t) + if ret == t or t.typename == "typevar" then + ret = shallow_copy_table(ret) + end ret.inferred_at = where ret.inferred_at.filename = filename return ret @@ -7573,21 +7603,23 @@ tl.type_check = function(ast, opts) ["function"] = { ["function"] = function(a, b) local argdelta = a.is_method and 1 or 0 - if #a.args ~= #b.args then + local naargs, nbargs = #a.args.tuple, #b.args.tuple + if naargs ~= nbargs then if a.is_method ~= b.is_method then return false, { Err(a, "different number of input arguments: method and non-method are not the same type") } end - return false, { Err(a, "different number of input arguments: got " .. #a.args - argdelta .. ", expected " .. #b.args - argdelta) } + return false, { Err(a, "different number of input arguments: got " .. naargs - argdelta .. ", expected " .. nbargs - argdelta) } end - if #a.rets ~= #b.rets then - return false, { Err(a, "different number of return values: got " .. #a.rets .. ", expected " .. #b.rets) } + local narets, nbrets = #a.rets.tuple, #b.rets.tuple + if narets ~= nbrets then + return false, { Err(a, "different number of return values: got " .. narets .. ", expected " .. nbrets) } end local errs = {} - for i = 1, #a.args do - arg_check(a, same_type, a.args[i], b.args[i], i - argdelta, errs, "argument") + for i = 1, naargs do + arg_check(a, same_type, a.args.tuple[i], b.args.tuple[i], i - argdelta, errs, "argument") end - for i = 1, #a.rets do - arg_check(a, same_type, a.rets[i], b.rets[i], i, errs, "return") + for i = 1, narets do + arg_check(a, same_type, a.rets.tuple[i], b.rets.tuple[i], i, errs, "return") end return any_errors(errs) end, @@ -7607,11 +7639,12 @@ tl.type_check = function(ast, opts) }, ["tuple"] = { ["tuple"] = function(a, b) - if #a ~= #b then + local at, bt = a.tuple, b.tuple + if #at ~= #bt then return false end - for i = 1, #a do - if not is_a(a[i], b[i]) then + for i = 1, #at do + if not is_a(at[i], bt[i]) then return false end end @@ -7821,21 +7854,21 @@ a.types[i], b.types[i]), } ["function"] = function(a, b) local errs = {} - local aa, ba = a.args, b.args + local aa, ba = a.args.tuple, b.args.tuple set_min_arity(a) set_min_arity(b) - if (not ba.is_va) and a.min_arity > b.min_arity then - table.insert(errs, Err(a, "incompatible number of arguments: got " .. show_arity(a) .. " %s, expected " .. show_arity(b) .. " %s", aa, ba)) + if (not b.args.is_va) and a.min_arity > b.min_arity then + table.insert(errs, Err(a, "incompatible number of arguments: got " .. show_arity(a) .. " %s, expected " .. show_arity(b) .. " %s", a.args, b.args)) else for i = ((a.is_method or b.is_method) and 2 or 1), #aa do arg_check(nil, is_a, aa[i], ba[i] or ANY, i, errs, "argument") end end - local ar, br = a.rets, b.rets - local diff_by_va = #br - #ar == 1 and br.is_va + local ar, br = a.rets.tuple, b.rets.tuple + local diff_by_va = #br - #ar == 1 and b.rets.is_va if #ar < #br and not diff_by_va then - table.insert(errs, Err(a, "incompatible number of returns: got " .. #ar .. " %s, expected " .. #br .. " %s", ar, br)) + table.insert(errs, Err(a, "incompatible number of returns: got " .. #ar .. " %s, expected " .. #br .. " %s", a.rets, b.rets)) else local nrets = #br if diff_by_va then @@ -7853,7 +7886,7 @@ a.types[i], b.types[i]), } ["bad_nominal"] = compare_false, ["any"] = compare_true, ["tuple"] = function(a, b) - return is_a(a_type("tuple", { a }), b) + return is_a(a_type("tuple", { tuple = { a } }), b) end, ["typevar"] = function(a, b) return compare_or_infer_typevar(b.typevar, a, nil, is_a) @@ -8054,9 +8087,9 @@ a.types[i], b.types[i]), } t = resolve_tuple_and_nominal(t) local call_mt = t.meta_fields and t.meta_fields["__call"] if call_mt then - local args_tuple = a_type("tuple", {}) - for i = 2, #call_mt.args do - table.insert(args_tuple, call_mt.args[i]) + local args_tuple = a_type("tuple", { tuple = {} }) + for i = 2, #call_mt.args.tuple do + table.insert(args_tuple.tuple, call_mt.args.tuple[i]) end return args_tuple, call_mt end @@ -8066,7 +8099,7 @@ a.types[i], b.types[i]), } local function resolve_for_call(func, args, is_method) if lax and is_unknown(func) then - func = a_type("function", { args = a_vararg({ UNKNOWN }), rets = a_vararg({ UNKNOWN }) }) + func = a_function({ args = a_vararg({ UNKNOWN }), rets = a_vararg({ UNKNOWN }) }) end func = resolve_tuple_and_nominal(func) @@ -8075,7 +8108,7 @@ a.types[i], b.types[i]), } if func.typename == "union" then local r = same_call_mt_in_all_union_entries(func) if r then - table.insert(args, 1, func.types[1]) + table.insert(args.tuple, 1, func.types[1]) return resolve_tuple_and_nominal(r), true end end @@ -8085,7 +8118,7 @@ a.types[i], b.types[i]), } end if func.meta_fields and func.meta_fields["__call"] then - table.insert(args, 1, func) + table.insert(args.tuple, 1, func) func = func.meta_fields["__call"] func = resolve_tuple_and_nominal(func) is_method = true @@ -8201,18 +8234,19 @@ a.types[i], b.types[i]), } assert(xs.typename == "tuple") assert(ys.typename == "tuple") - local n_xs = #xs - local n_ys = #ys + local xt, yt = xs.tuple, ys.tuple + local n_xs = #xt + local n_ys = #yt for i = 1, n_xs do - local x = xs[i] + local x = xt[i] if x.typename == "emptytable" or x.typename == "unresolved_emptytable_value" then - local y = ys[i] or (ys.is_va and ys[n_ys]) + local y = yt[i] or (ys.is_va and yt[n_ys]) if y then local w = wheres and wheres[i + delta] or where local inferred_y = infer_at(w, y) infer_emptytable(x, inferred_y) - xs[i] = inferred_y + xt[i] = inferred_y end end end @@ -8226,13 +8260,14 @@ a.types[i], b.types[i]), } assert(ys.typename == "tuple", ys.typename) local errs = {} - local n_xs = #xs - local n_ys = #ys + local xt, yt = xs.tuple, ys.tuple + local n_xs = #xt + local n_ys = #yt for i = from, math.max(n_xs, n_ys) do local pos = i + delta - local x = xs[i] or (xs.is_va and xs[n_xs]) or NIL - local y = ys[i] or (ys.is_va and ys[n_ys]) + local x = xt[i] or (xs.is_va and xt[n_xs]) or NIL + local y = yt[i] or (ys.is_va and yt[n_ys]) if y then local w = wheres and wheres[pos] or where if not arg_check(w, is_a, x, y, pos, errs, mode) then @@ -8244,26 +8279,27 @@ a.types[i], b.types[i]), } return true end - check_args_rets = function(where, where_args, f, args, rets, argdelta) + check_args_rets = function(where, where_args, f, args, expected_rets, argdelta) local rets_ok = true local rets_errs local args_ok local args_errs + local fargs = f.args.tuple local from = 1 if argdelta == -1 then from = 2 local errs = {} - if (not is_self(f.args[1])) and not arg_check(where, is_a, args[1], f.args[1], nil, errs, "self") then + if (not is_self(fargs[1])) and not arg_check(where, is_a, args.tuple[1], fargs[1], nil, errs, "self") then return nil, errs end end - if rets then - rets = infer_at(where, rets) - infer_emptytables(where, nil, rets, f.rets, 0) + if expected_rets then + expected_rets = infer_at(where, expected_rets) + infer_emptytables(where, nil, expected_rets, f.rets, 0) - rets_ok, rets_errs = check_func_type_list(where, nil, f.rets, rets, 1, 0, "return") + rets_ok, rets_errs = check_func_type_list(where, nil, f.rets, expected_rets, 1, 0, "return") end args_ok, args_errs = check_func_type_list(where, where_args, args, f.args, from, argdelta, "argument") @@ -8332,7 +8368,7 @@ a.types[i], b.types[i]), } return resolve_typevars_at(where, f.rets) end - local function check_call(where, where_args, func, args, expected, typetype_funcall, is_method, argdelta) + local function check_call(where, where_args, func, args, expected_rets, typetype_funcall, is_method, argdelta) assert(type(func) == "table") assert(type(args) == "table") @@ -8342,8 +8378,8 @@ a.types[i], b.types[i]), } argdelta = is_method and -1 or argdelta or 0 - if is_method and args[1] then - add_var(nil, "@self", a_type("typetype", { y = where.y, x = where.x, def = args[1] })) + if is_method and args.tuple[1] then + add_var(nil, "@self", type_at(where, a_type("typetype", { def = args.tuple[1] }))) end local is_func = func.typename == "function" @@ -8357,15 +8393,16 @@ a.types[i], b.types[i]), } passes, n = 3, #func.types end - local given = #args + local given = #args.tuple local tried local first_errs for pass = 1, passes do for i = 1, n do if (not tried) or not tried[i] then local f = is_func and func or func.types[i] + local fargs = f.args.tuple if f.is_method and not is_method then - if args[1] and is_a(args[1], f.args[1]) then + if args.tuple[1] and is_a(args.tuple[1], fargs[1]) then if not typetype_funcall then add_warning("hint", where, "invoked method as a regular function: consider using ':' instead of '.'") @@ -8374,7 +8411,7 @@ a.types[i], b.types[i]), } return invalid_at(where, "invoked method as a regular function: use ':' instead of '.'") end end - local wanted = #f.args + local wanted = #fargs set_min_arity(f) @@ -8388,14 +8425,14 @@ a.types[i], b.types[i]), } push_typeargs(f) - local matched, errs = check_args_rets(where, where_args, f, args, expected, argdelta) + local matched, errs = check_args_rets(where, where_args, f, args, expected_rets, argdelta) if matched then return matched, f end first_errs = first_errs or errs - if expected then + if expected_rets then infer_emptytables(where, where_args, f.rets, f.rets, argdelta) end @@ -8415,7 +8452,7 @@ a.types[i], b.types[i]), } type_check_function_call = function(node, where_args, func, args, e1, is_method, argdelta) if node.expected and node.expected.typename ~= "tuple" then - node.expected = a_type("tuple", { node.expected }) + node.expected = a_type("tuple", { tuple = { node.expected } }) end begin_scope() @@ -8464,10 +8501,10 @@ a.types[i], b.types[i]), } if metamethod then local where_args = { node.e1 } - local args = a_type("tuple", { orig_a }) + local args = a_type("tuple", { tuple = { orig_a } }) if b and method_name ~= "__is" then where_args[2] = node.e2 - args[2] = orig_b + args.tuple[2] = orig_b end return resolve_tuple_and_nominal((type_check_function_call(node, where_args, metamethod, args, nil, true))), meta_on_operator else @@ -8648,24 +8685,25 @@ a.types[i], b.types[i]), } return st[1][var] end - local function get_rets(rets) - if lax and (#rets == 0) then - return a_vararg({ UNKNOWN }) + local get_rets + if lax then + get_rets = function(rets) + if #rets.tuple == 0 then + return a_vararg({ UNKNOWN }) + end + return rets end - local t = rets - if not t.typename then - - t = a_type("tuple", t) + else + get_rets = function(rets) + return rets end - assert(t.typeid) - return t end local function add_internal_function_variables(node, args) assert(args.typename == "tuple") add_var(nil, "@is_va", args.is_va and ANY or NIL) - add_var(nil, "@return", node.rets or a_type("tuple", {})) + add_var(nil, "@return", node.rets or a_type("tuple", { tuple = {} })) if node.typeargs then for _, t in ipairs(node.typeargs) do @@ -8680,13 +8718,14 @@ a.types[i], b.types[i]), } local function add_function_definition_for_recursion(node, fnargs) assert(fnargs.typename == "tuple") - local args = a_type("tuple", {}) + + local args = a_type("tuple", { tuple = {} }) args.is_va = fnargs.is_va - for _, fnarg in ipairs(fnargs) do - table.insert(args, fnarg) + for _, fnarg in ipairs(fnargs.tuple) do + table.insert(args.tuple, fnarg) end - add_var(nil, node.name.tk, a_type("function", { + add_var(nil, node.name.tk, a_function({ typeargs = node.typeargs, args = args, rets = get_rets(node.rets), @@ -8729,10 +8768,10 @@ a.types[i], b.types[i]), } end local function flatten_tuple(vals) - local vt = vals + local vt = vals.tuple local n_vals = #vt - local ret = a_type("tuple", {}) - local rt = ret + local ret = a_type("tuple", { tuple = {} }) + local rt = ret.tuple if n_vals == 0 then return ret @@ -8746,7 +8785,7 @@ a.types[i], b.types[i]), } local last = vt[n_vals] if last.typename == "tuple" then - local lt = last + local lt = last.tuple for _, v in ipairs(lt) do table.insert(rt, v) end @@ -8761,15 +8800,15 @@ a.types[i], b.types[i]), } local function get_assignment_values(vals, wanted) if vals == nil then - return a_type("tuple", {}) + return a_type("tuple", { tuple = {} }) end local ret = flatten_tuple(vals) if ret.is_va then - local n_ret = #ret - local rt = ret + local rt = ret.tuple + local n_ret = #rt if n_ret > 0 and n_ret < wanted then local last = rt[n_ret] for _ = n_ret + 1, wanted do @@ -8838,11 +8877,7 @@ a.types[i], b.types[i]), } end if is_a(orig_b, a.keys) then - return a_type("unresolved_emptytable_value", { - y = anode.y, - x = anode.x, - emptytable_type = a, - }) + return type_at(anode, a_type("unresolved_emptytable_value", { emptytable_type = a })) end errm, erra, errb = "inconsistent index type: got %s, expected %s (type of keys inferred at " .. @@ -8975,14 +9010,12 @@ a.types[i], b.types[i]), } table.insert(typevals, a_type("typevar", { typevar = a.typearg })) end end - return a_type("nominal", { - y = where.y, - x = where.x, + return type_at(where, a_type("nominal", { typevals = typevals, names = { name }, found = t, resolved = resolved, - }) + })) end local function get_self_type(exp) @@ -9381,18 +9414,18 @@ a.types[i], b.types[i]), } local base_nargs = (node.e1.tk == "xpcall") and 2 or 1 if #node.e2 < base_nargs then error_at(node, "wrong number of arguments (given " .. #node.e2 .. ", expects at least " .. base_nargs .. ")") - return a_type("tuple", { BOOLEAN }) + return a_type("tuple", { tuple = { BOOLEAN } }) end - local ftype = table.remove(b, 1) + local ftype = table.remove(b.tuple, 1) ftype = shallow_copy_new_type(ftype) ftype.is_method = false local fe2 = {} if node.e1.tk == "xpcall" then base_nargs = 2 - local msgh = table.remove(b, 1) + local msgh = table.remove(b.tuple, 1) assert_is_a(node.e2[2], msgh, XPCALL_MSGH_FUNCTION, "in message handler") end for i = base_nargs + 1, #node.e2 do @@ -9407,20 +9440,20 @@ a.types[i], b.types[i]), } e2 = fe2, } local rets = type_check_funcall(fnode, ftype, b, argdelta + base_nargs) - if rets.typename ~= "tuple" then - - rets = a_type("tuple", { rets }) + if rets == INVALID then + return rets end - table.insert(rets, 1, BOOLEAN) + assert(rets and rets.typename == "tuple", show_type(rets)) + table.insert(rets.tuple, 1, BOOLEAN) return rets end local special_functions = { ["pairs"] = function(node, a, b, argdelta) - if not b[1] then + if not b.tuple[1] then return invalid_at(node, "pairs requires an argument") end - local t = resolve_tuple_and_nominal(b[1]) + local t = resolve_tuple_and_nominal(b.tuple[1]) if is_array_type(t) then add_warning("hint", node, "hint: applying pairs on an array: did you intend to apply ipairs?") end @@ -9442,10 +9475,10 @@ a.types[i], b.types[i]), } end, ["ipairs"] = function(node, a, b, argdelta) - if not b[1] then + if not b.tuple[1] then return invalid_at(node, "ipairs requires an argument") end - local orig_t = b[1] + local orig_t = b.tuple[1] local t = resolve_tuple_and_nominal(orig_t) if t.typename == "tupletable" then @@ -9464,15 +9497,15 @@ a.types[i], b.types[i]), } ["rawget"] = function(node, _a, b, _argdelta) - if #b == 2 then - return type_check_index(node.e2[1], node.e2[2], b[1], b[2]) + if #b.tuple == 2 then + return a_type("tuple", { tuple = { type_check_index(node.e2[1], node.e2[2], b.tuple[1], b.tuple[2]) } }) else return invalid_at(node, "rawget expects two arguments") end end, ["require"] = function(node, _a, b, _argdelta) - if #b ~= 1 then + if #b.tuple ~= 1 then return invalid_at(node, "require expects one literal argument") end if node.e2[1].kind ~= "string" then @@ -9493,7 +9526,7 @@ a.types[i], b.types[i]), } end dependencies[module_name] = t.filename - return t + return type_at(node, a_type("tuple", { tuple = { t } })) end, ["pcall"] = special_pcall_xpcall, @@ -9517,7 +9550,7 @@ a.types[i], b.types[i]), } return (type_check_function_call(node, node.e2, a, b, node.e1, false, argdelta)) end elseif node.e1.op and node.e1.op.op == ":" then - table.insert(b, 1, node.e1.receiver) + table.insert(b.tuple, 1, node.e1.receiver) return (type_check_function_call(node, node.e2, a, b, node.e1, true)) else return (type_check_function_call(node, node.e2, a, b, node.e1, false, argdelta)) @@ -9563,8 +9596,10 @@ a.types[i], b.types[i]), } end end - local function set_expected_types_to_decltypes(node, children) - local decls = node.kind == "assignment" and children[1] or node.decltype + local function set_expected_types_to_decltuple(node, children) + local decltuple = node.kind == "assignment" and children[1] or node.decltuple + assert(decltuple.typename == "tuple") + local decls = decltuple.tuple if decls and node.exps then local ndecl = #decls local nexps = #node.exps @@ -9573,9 +9608,9 @@ a.types[i], b.types[i]), } typ = decls[i] if typ then if i == nexps and ndecl > nexps then - typ = a_type("tuple", { y = node.y, x = node.x, filename = filename }) + typ = type_at(node, a_type("tuple", { tuple = {} })) for a = i, ndecl do - table.insert(typ, decls[a]) + table.insert(typ.tuple, decls[a]) end end node.exps[i].expected = typ @@ -9621,11 +9656,7 @@ a.types[i], b.types[i]), } end local function infer_table_literal(node, children) - local typ = a_type("emptytable", { - filename = filename, - y = node.y, - x = node.x, - }) + local typ = type_at(node, a_type("emptytable", {})) local is_record = false local is_array = false @@ -9673,7 +9704,7 @@ a.types[i], b.types[i]), } if node[i].key_parsed == "implicit" then if i == #children and child.vtype.typename == "tuple" then - for _, c in ipairs(child.vtype) do + for _, c in ipairs(child.vtype.tuple) do typ.elements = expand_type(node, typ.elements, c) typ.types[last_array_idx] = resolve_tuple(c) last_array_idx = last_array_idx + 1 @@ -9719,12 +9750,7 @@ a.types[i], b.types[i]), } elseif is_record and is_array then typ.typename = "record" typ.interface_list = { - a_type("array", { - filename = filename, - y = node.y, - x = node.x, - elements = typ.elements, - }), + type_at(node, a_type("array", { elements = typ.elements })), } elseif is_record and is_map then @@ -9790,12 +9816,12 @@ a.types[i], b.types[i]), } local function determine_declaration_type(var, node, infertypes, i) local ok = true local name = var.tk - local infertype = infertypes and infertypes[i] + local infertype = infertypes and infertypes.tuple[i] if lax and infertype and infertype.typename == "nil" then infertype = nil end - local decltype = node.decltype and node.decltype[i] + local decltype = node.decltuple and node.decltuple.tuple[i] if decltype then if resolve_tuple_and_nominal(decltype) == INVALID then decltype = INVALID @@ -9859,8 +9885,15 @@ a.types[i], b.types[i]), } end local function get_type_declaration(node) - if node.value.kind == "op" and node.value.op.op == "@funcall" then - return special_functions["require"](node.value, find_var_type("require"), { STRING }, 0) + if node.value.kind == "op" and + node.value.op.op == "@funcall" and + node.value.e1.kind == "variable" and + node.value.e1.tk == "require" then + + local t = special_functions["require"](node.value, find_var_type("require"), a_type("tuple", { tuple = { STRING } }), 0) + if t ~= INVALID then + return t.tuple[1] + end else return resolve_nominal_typetype(node.value.newtype) end @@ -10037,7 +10070,7 @@ a.types[i], b.types[i]), } end end end, - before_exp = set_expected_types_to_decltypes, + before_exp = set_expected_types_to_decltuple, after = function(node, children) local encountered_close = false local infertypes = get_assignment_values(children[3], #node.vars) @@ -10067,9 +10100,9 @@ a.types[i], b.types[i]), } assert(var) add_var(var, var.tk, t, var.attribute, is_localizing_a_variable(node, i) and "declaration") - if ok and infertypes and infertypes[i] then + local infertype = infertypes.tuple[i] + if ok and infertype then local where = node.exps[i] or node.exps - local infertype = infertypes[i] local rt = resolve_tuple_and_nominal(t) if rt.typename ~= "enum" and (t.typename ~= "nominal" or rt.typename == "union") and not same_type(t, infertype) then @@ -10088,7 +10121,7 @@ a.types[i], b.types[i]), } end, }, ["global_declaration"] = { - before_exp = set_expected_types_to_decltypes, + before_exp = set_expected_types_to_decltuple, after = function(node, children) local infertypes = get_assignment_values(children[3], #node.vars) for i, var in ipairs(node.vars) do @@ -10106,20 +10139,22 @@ a.types[i], b.types[i]), } end, }, ["assignment"] = { - before_exp = set_expected_types_to_decltypes, + before_exp = set_expected_types_to_decltuple, after = function(node, children) - local valtypes = get_assignment_values(children[3], #children[1]) - for i, vartype in ipairs(children[1]) do + local vartypes = children[1].tuple + local valtypes = get_assignment_values(children[3], #vartypes) + for i, vartype in ipairs(vartypes) do local varnode = node.vars[i] local varname = varnode.tk - local rvar, rval, err = check_assignment(varnode, vartype, valtypes[i], varname, varnode.attribute) + local valtype = valtypes.tuple[i] + local rvar, rval, err = check_assignment(varnode, vartype, valtype, varname, varnode.attribute) if err == "missing" then if #node.exps == 1 and node.exps[1].kind == "op" and node.exps[1].op.op == "@funcall" then local rets = children[3] if rets.typename == "tuple" then - local msg = #rets == 1 and + local msg = #rets.tuple == 1 and "only 1 value is returned by the function" or - ("only " .. #rets .. " values are returned by the function") + ("only " .. #rets.tuple .. " values are returned by the function") add_warning("hint", varnode, msg) end end @@ -10137,7 +10172,7 @@ a.types[i], b.types[i]), } end if store_type then - store_type(varnode.y, varnode.x, valtypes[i]) + store_type(varnode.y, varnode.x, valtype) end end end @@ -10204,7 +10239,7 @@ a.types[i], b.types[i]), } error_at(node, "label '" .. node.label .. "' already defined at " .. filename) end local unresolved = st[#st]["@unresolved"] - local var = add_var(node, label_id, a_type("none", { y = node.y, x = node.x })) + local var = add_var(node, label_id, type_at(node, a_type("none", {}))) if unresolved then if unresolved.t.labels[node.label] then var.used = true @@ -10240,14 +10275,14 @@ a.types[i], b.types[i]), } begin_scope(node) end, before_statements = function(node, children) - local exptypes = children[2] + local exptypes = children[2].tuple widen_all_unions(node) local exp1 = node.exps[1] - local args = a_type("tuple", { + local args = a_type("tuple", { tuple = { node.exps[2] and exptypes[2], node.exps[3] and exptypes[3], - }) + } }) local exp1type = resolve_for_call(exptypes[1], args, false) if exp1type.typename == "poly" then @@ -10260,7 +10295,7 @@ a.types[i], b.types[i]), } local last local rets = exp1type.rets for i, v in ipairs(node.vars) do - local r = rets[i] + local r = rets.tuple[i] if not r then if rets.is_va then r = last @@ -10271,8 +10306,8 @@ a.types[i], b.types[i]), } add_var(v, v.tk, r) last = r end - if (not lax) and (not rets.is_va and #node.vars > #rets) then - local nrets = #rets + local nrets = #rets.tuple + if (not lax) and (not rets.is_va and #node.vars > nrets) then local at = node.vars[nrets + 1] local n_values = nrets == 1 and "1 value" or tostring(nrets) .. " values" error_at(at, "too many variables for this iterator; it produces " .. n_values) @@ -10306,36 +10341,42 @@ a.types[i], b.types[i]), } local rets = find_var_type("@return") if rets then for i, exp in ipairs(node.exps) do - exp.expected = rets[i] + exp.expected = rets.tuple[i] end end end, after = function(node, children) + local got = children[1] + local got_t = got.tuple + local n_got = #got_t + node.block_returns = true - local rets = find_var_type("@return") - if not rets then + local expected = find_var_type("@return") + if not expected then - rets = infer_at(node, children[1]) - module_type = resolve_tuple_and_nominal(rets) + expected = infer_at(node, got) + module_type = resolve_tuple_and_nominal(expected) module_type.tk = nil - st[2]["@return"] = { t = rets } + st[2]["@return"] = { t = expected } end + local expected_t = expected.tuple + local what = "in return value" - if rets.inferred_at then - what = what .. inferred_msg(rets) + if expected.inferred_at then + what = what .. inferred_msg(expected) end - local nrets = #rets + local n_expected = #expected_t local vatype - if nrets > 0 then - vatype = rets.is_va and rets[nrets] + if n_expected > 0 then + vatype = expected.is_va and expected.tuple[n_expected] end - if #children[1] > nrets and (not lax) and not vatype then - error_at(node, what .. ": excess return values, expected " .. #rets .. " %s, got " .. #children[1] .. " %s", rets, children[1]) + if n_got > n_expected and (not lax) and not vatype then + error_at(node, what .. ": excess return values, expected " .. n_expected .. " %s, got " .. n_got .. " %s", expected, got) end - if nrets > 1 and + if n_expected > 1 and #node.exps == 1 and node.exps[1].kind == "op" and (node.exps[1].op.op == "and" or node.exps[1].op.op == "or") and @@ -10343,15 +10384,15 @@ a.types[i], b.types[i]), } add_warning("hint", node.exps[1].e2, "additional return values are being discarded due to '" .. node.exps[1].op.op .. "' expression; suggest parentheses if intentional") end - for i = 1, #children[1] do - local expected = rets[i] or vatype - if expected then - expected = resolve_tuple(expected) + for i = 1, n_got do + local e = expected_t[i] or vatype + if e then + e = resolve_tuple(e) local where = (node.exps[i] and node.exps[i].x) and node.exps[i] or node.exps assert(where and where.x) - assert_is_a(where, children[1][i], expected, what) + assert_is_a(where, got_t[i], e, what) end end @@ -10360,11 +10401,11 @@ a.types[i], b.types[i]), } }, ["variable_list"] = { after = function(node, children) - local tuple = a_type("tuple", children) + local tuple = a_type("tuple", { tuple = children }) tuple = flatten_tuple(tuple) - for i, t in ipairs(tuple) do + for i, t in ipairs(tuple.tuple) do ensure_not_abstract(node[i], t) end @@ -10484,7 +10525,7 @@ a.types[i], b.types[i]), } elseif is_array and is_number_type(child.ktype) then if child.vtype.typename == "tuple" and i == #children and node[i].key_parsed == "implicit" then - for ti, tt in ipairs(child.vtype) do + for ti, tt in ipairs(child.vtype.tuple) do assert_is_a(node[i], tt, decltype.elements, in_context(node.expected_context, "expected an array"), "at index " .. tostring(i + ti - 1)) end else @@ -10533,22 +10574,20 @@ a.types[i], b.types[i]), } local kname = node.key.conststr local ktype = children[1] local vtype = children[2] - if node.decltype then - vtype = node.decltype - assert_is_a(node.value, children[2], node.decltype, "in table item") + if node.itemtype then + vtype = node.itemtype + assert_is_a(node.value, children[2], node.itemtype, "in table item") end if vtype.is_method then vtype = shallow_copy_new_type(vtype) vtype.is_method = false end - return a_type("table_item", { - y = node.y, - x = node.x, + return type_at(node, a_type("table_item", { kname = kname, ktype = ktype, vtype = vtype, - }) + })) end, }, ["local_function"] = { @@ -10568,7 +10607,7 @@ a.types[i], b.types[i]), } end_function_scope(node) local rets = get_rets(children[3]) - local t = ensure_fresh_typeargs(a_type("function", { + local t = ensure_fresh_typeargs(a_function({ y = node.y, x = node.x, typeargs = node.typeargs, @@ -10595,7 +10634,7 @@ a.types[i], b.types[i]), } check_macroexp_arg_use(node.macrodef) - local t = ensure_fresh_typeargs(a_type("function", { + local t = ensure_fresh_typeargs(a_function({ y = node.y, x = node.x, typeargs = node.typeargs, @@ -10637,7 +10676,7 @@ a.types[i], b.types[i]), } return NONE end - add_global(node, node.name.tk, ensure_fresh_typeargs(a_type("function", { + add_global(node, node.name.tk, ensure_fresh_typeargs(a_function({ y = node.y, x = node.x, typeargs = node.typeargs, @@ -10660,16 +10699,13 @@ a.types[i], b.types[i]), } if rtype.typeargs then for _, typ in ipairs(rtype.typeargs) do - add_var(nil, typ.typearg, a_type("typearg", { - y = typ.y, - x = typ.x, - typearg = typ.typearg, - })) + add_var(nil, typ.typearg, type_at(typ, a_type("typearg", { typearg = typ.typearg }))) end end end, before_statements = function(node, children) local args = children[3] + assert(args.typename == "tuple") local rtype = resolve_tuple_and_nominal(resolve_typetype(children[1])) @@ -10694,11 +10730,11 @@ a.types[i], b.types[i]), } error_at(node, "could not resolve type of self") return end - args[1] = selftype + args.tuple[1] = selftype add_var(nil, "self", selftype) end - local fn_type = ensure_fresh_typeargs(a_type("function", { + local fn_type = ensure_fresh_typeargs(a_function({ y = node.y, x = node.x, is_method = node.is_method, @@ -10768,7 +10804,7 @@ a.types[i], b.types[i]), } end_function_scope(node) - return ensure_fresh_typeargs(a_type("function", { + return ensure_fresh_typeargs(a_function({ y = node.y, x = node.x, typeargs = node.typeargs, @@ -10791,7 +10827,7 @@ a.types[i], b.types[i]), } end_function_scope(node) - return ensure_fresh_typeargs(a_type("function", { + return ensure_fresh_typeargs(a_function({ y = node.y, x = node.x, typeargs = node.typeargs, @@ -10842,7 +10878,7 @@ a.types[i], b.types[i]), } if node.expected then is_a(e1type.rets, node.expected) end - local e1args = e1type.args + local e1args = e1type.args.tuple local at = argdelta for _, typ in ipairs(e1args) do at = at + 1 @@ -10850,7 +10886,7 @@ a.types[i], b.types[i]), } node.e2[at].expected = typ end end - if e1args.is_va then + if e1type.args.is_va then local typ = e1args[#e1args] for i = at + 1, #node.e2 do node.e2[i].expected = typ @@ -10886,7 +10922,8 @@ a.types[i], b.types[i]), } add_unknown_dot(node, node.e1.e1.tk .. "." .. node.e1.e2.tk) end end - return type_check_funcall(node, a, b) + local t = type_check_funcall(node, a, b) + return t end ensure_not_abstract(node.e1, ra) @@ -10911,11 +10948,7 @@ a.types[i], b.types[i]), } kind = "string", conststr = node.e2.tk, } - local btype = a_type("string", { - y = node.e2.y, - x = node.e2.x, - tk = '"' .. node.e2.tk .. '"', - }) + local btype = type_at(node.e2, a_type("string", { tk = '"' .. node.e2.tk .. '"' })) local t = type_check_index(node.e1, bnode, orig_a, btype) if t.needs_compat and opts.gen_compat ~= "off" then @@ -11273,11 +11306,7 @@ a.types[i], b.types[i]), } local function after_literal(node) node.known = FACT_TRUTHY - return a_type(node.kind, { - y = node.y, - x = node.x, - tk = node.tk, - }) + return type_at(node, a_type(node.kind, { tk = node.tk })) end visit_node.cbs["string"] = { @@ -11347,7 +11376,7 @@ a.types[i], b.types[i]), } ["record"] = { before = function(typ) begin_scope() - add_var(nil, "@self", a_type("typetype", { y = typ.y, x = typ.x, def = typ })) + add_var(nil, "@self", type_at(typ, a_type("typetype", { def = typ }))) for name, typ2 in fields_of(typ) do if typ2.typename == "typetype" then @@ -11375,20 +11404,22 @@ a.types[i], b.types[i]), } end for name, _ in fields_of(typ) do local ftype = children[i] - - if ftype.is_method and ftype.args and ftype.args[1] and ftype.args[1].is_self then - local record_name = typ.names and typ.names[1] - if record_name then - local selfarg = ftype.args[1] - if selfarg.tk ~= record_name or (typ.typeargs and not selfarg.typevals) then - ftype.is_method = false - selfarg.is_self = false - elseif typ.typeargs then - for j = 1, #typ.typeargs do - if (not selfarg.typevals[j]) or selfarg.typevals[j].tk ~= typ.typeargs[j].typearg then - ftype.is_method = false - selfarg.is_self = false - break + if ftype.typename == "function" and ftype.is_method then + local fargs = ftype.args.tuple + if fargs[1] and fargs[1].is_self then + local record_name = typ.names and typ.names[1] + if record_name then + local selfarg = fargs[1] + if selfarg.tk ~= record_name or (typ.typeargs and not selfarg.typevals) then + ftype.is_method = false + selfarg.is_self = false + elseif typ.typeargs then + for j = 1, #typ.typeargs do + if (not selfarg.typevals[j]) or selfarg.typevals[j].tk ~= typ.typeargs[j].typearg then + ftype.is_method = false + selfarg.is_self = false + break + end end end end @@ -11408,11 +11439,7 @@ a.types[i], b.types[i]), } }, ["typearg"] = { after = function(typ, _children) - add_var(nil, typ.typearg, a_type("typearg", { - y = typ.y, - x = typ.x, - typearg = typ.typearg, - })) + add_var(nil, typ.typearg, type_at(typ, a_type("typearg", { typearg = typ.typearg }))) return typ end, }, diff --git a/tl.tl b/tl.tl index ae0acd038..05f88f183 100644 --- a/tl.tl +++ b/tl.tl @@ -1068,7 +1068,6 @@ local table_types : {TypeName:boolean} = { } local record Type - is {Type} where self.typename ~= nil y: integer @@ -1088,6 +1087,7 @@ local record Type -- tuple is_va: boolean + tuple: {Type} -- poly, union, tupletable types: {Type} @@ -1121,8 +1121,8 @@ local record Type -- function is_method: boolean min_arity: number - args: Type - rets: Type + args: TupleType + rets: TupleType typeid: integer @@ -1165,6 +1165,8 @@ local record Type narrows: {string:boolean} end +local type TupleType = Type + local record Operator y: integer x: integer @@ -1374,7 +1376,7 @@ local record Node value: Node key_parsed: KeyParsed - typeargs: Type + typeargs: {Type} args: Node rets: Type body: Node @@ -1438,7 +1440,10 @@ local record Node macrodef: Node expanded: Node - decltype: Type + argtype: Type + itemtype: Type + decltuple: TupleType + opt: boolean debug_type: Type @@ -1476,7 +1481,7 @@ end local enum ParseTypeListMode "rets" - "decltype" + "decltuple" "casttype" end @@ -1583,11 +1588,11 @@ local macroexp a_typetype(t: Type): Type end local macroexp a_tuple(t: {Type}): Type - return a_type("tuple", t) + return a_type("tuple", { tuple = t }) end local function c_tuple(t: {Type}): Type - return a_type("tuple", t) + return a_type("tuple", { tuple = t }) end local macroexp a_union(t: {Type}): Type @@ -1598,14 +1603,16 @@ local macroexp a_poly(t: {Type}): Type return a_type("poly", { types = t }) end -local macroexp a_function(t: Type): Type +local function a_function(t: Type): Type + assert(t.args.typename == "tuple") + assert(t.rets.typename == "tuple") return a_type("function", t) end local function a_vararg(t: {Type}): Type - local tuple = t as Type - tuple.is_va = true - return a_tuple(t) + local typ = a_tuple(t) + typ.is_va = true + return typ end local macroexp an_array(t: Type): Type @@ -1732,8 +1739,8 @@ local function parse_table_item(ps: ParseState, i: integer, n?: integer): intege node.key.conststr = node.key.tk node.key.tk = '"' .. node.key.tk .. '"' i = verify_tk(try_ps, i, ":") - i, node.decltype = parse_type(try_ps, i) - if node.decltype and ps.tokens[i].tk == "=" then + i, node.itemtype = parse_type(try_ps, i) + if node.itemtype and ps.tokens[i].tk == "=" then i = verify_tk(try_ps, i, "=") i, node.value = parse_table_value(try_ps, i) if node.value then @@ -1744,7 +1751,7 @@ local function parse_table_item(ps: ParseState, i: integer, n?: integer): intege end end -- backtrack: - node.decltype = nil + node.itemtype = nil i = orig_i end end @@ -1861,9 +1868,9 @@ local function parse_anglebracket_list(ps: ParseState, i: integer, parse_item: P if ps.tokens[i+1].tk == ">" then return fail(ps, i+1, "type argument list cannot be empty") end - local typ = new_type(ps, i, "tuple") + local types: {Type} = {} i = verify_tk(ps, i, "<") - i = parse_list(ps, i, typ, { [">"] = true, [">>"] = true, }, "sep", parse_item) + i = parse_list(ps, i, types, { [">"] = true, [">>"] = true, }, "sep", parse_item) if ps.tokens[i].tk == ">" then i = i + 1 elseif ps.tokens[i].tk == ">>" then @@ -1872,7 +1879,7 @@ local function parse_anglebracket_list(ps: ParseState, i: integer, parse_item: P else return fail(ps, i, "syntax error, expected '>'") end - return i, typ + return i, types end local function parse_typearg(ps: ParseState, i: integer): integer, Type, integer @@ -1901,7 +1908,7 @@ local function parse_function_type(ps: ParseState, i: integer): integer, Type typ.args = a_vararg { ANY } typ.rets = a_vararg { ANY } end - if typ.args[1] and typ.args[1].is_self then + if typ.args.tuple[1] and typ.args.tuple[1].is_self then typ.is_method = true end return i, typ @@ -2047,15 +2054,21 @@ parse_type = function(ps: ParseState, i: integer): integer, Type, integer return i, bt end +local function new_tuple(ps: ParseState, i: integer): Type, {Type} + local t = new_type(ps, i, "tuple") + t.tuple = {} + return t, t.tuple +end + parse_type_list = function(ps: ParseState, i: integer, mode: ParseTypeListMode): integer, Type - local list = new_type(ps, i, "tuple") + local t, list = new_tuple(ps, i) local first_token = ps.tokens[i].tk - if mode == "rets" or mode == "decltype" then + if mode == "rets" or mode == "decltuple" then if first_token == ":" then i = i + 1 else - return i, list + return i, t end end @@ -2075,7 +2088,7 @@ parse_type_list = function(ps: ParseState, i: integer, mode: ParseTypeListMode): i = i + 1 local nrets = #list if nrets > 0 then - list.is_va = true + t.is_va = true else fail(ps, i, "unexpected '...'") end @@ -2085,7 +2098,7 @@ parse_type_list = function(ps: ParseState, i: integer, mode: ParseTypeListMode): i = verify_tk(ps, i, ")") end - return i, list + return i, t end local function parse_function_args_rets_body(ps: ParseState, i: integer, node: Node): integer, Node @@ -2528,12 +2541,12 @@ local function parse_argument(ps: ParseState, i: integer): integer, Node, intege end if ps.tokens[i].tk == ":" then i = i + 1 - local decltype: Type + local argtype: Type - i, decltype = parse_type(ps, i) + i, argtype = parse_type(ps, i) if node then - node.decltype = decltype + node.argtype = argtype end end return i, node, 0 @@ -2614,7 +2627,7 @@ end parse_argument_type_list = function(ps: ParseState, i: integer): integer, Type local tvs: {TypeAndVararg} = {} i = parse_bracket_list(ps, i, tvs, "(", ")", "sep", parse_argument_type) - local list = new_type(ps, i, "tuple") + local t, list = new_tuple(ps, i) local n = #tvs for l, tv in ipairs(tvs) do list[l] = tv.type @@ -2623,9 +2636,9 @@ parse_argument_type_list = function(ps: ParseState, i: integer): integer, Type end end if tvs[n] and tvs[n].is_va then - list.is_va = true + t.is_va = true end - return i, list + return i, t end local function parse_identifier(ps: ParseState, i: integer): integer, Node, integer @@ -2979,10 +2992,10 @@ local function parse_where_clause(ps: ParseState, i: integer): integer, Node node.args = new_node(ps.tokens, i, "argument_list") node.args[1] = new_node(ps.tokens, i, "argument") node.args[1].tk = "self" - node.args[1].decltype = new_type(ps, i, "nominal") - node.args[1].decltype.names = { "@self" } - node.rets = new_type(ps, i, "tuple") - node.rets[1] = BOOLEAN + node.args[1].argtype = new_type(ps, i, "nominal") + node.args[1].argtype.names = { "@self" } + node.rets = new_tuple(ps, i) + node.rets.tuple[1] = BOOLEAN i, node.exp = parse_expression(ps, i) end_at(node, ps.tokens[i - 1]) return i, node @@ -3063,7 +3076,7 @@ parse_record_body = function(ps: ParseState, i: integer, def: Type, node: Node): local typ = new_type(ps, wstart, "function") typ.is_method = true - typ.args = a_tuple { a_type("nominal", { y = typ.y, x = typ.x, names = { "@self" } }) } + typ.args = a_tuple { a_type("nominal", { y = typ.y, x = typ.x, filename = ps.filename, names = { "@self" } }) } typ.rets = a_tuple { BOOLEAN } typ.macroexp = where_macroexp typ.is_abstract = true @@ -3282,7 +3295,7 @@ local function parse_variable_declarations(ps: ParseState, i: integer, node_name return fail(ps, i, "expected a local variable definition") end - i, asgn.decltype = parse_type_list(ps, i, "decltype") + i, asgn.decltuple = parse_type_list(ps, i, "decltuple") if ps.tokens[i].tk == "=" then -- produce nice error message when using <= 0.7.1 syntax @@ -3670,10 +3683,11 @@ local function recurse_type(ast: Type, visit: Visitor): T end end - for i, child in ipairs(ast) do - xs[i] = recurse_type(child, visit) + if ast.tuple then + for i, child in ipairs(ast.tuple) do + xs[i] = recurse_type(child, visit) + end end - if ast.types then for _, child in ipairs(ast.types) do table.insert(xs, recurse_type(child, visit)) @@ -3707,14 +3721,14 @@ local function recurse_type(ast: Type, visit: Visitor): T end end if ast.args then - for i, child in ipairs(ast.args) do + for i, child in ipairs(ast.args.tuple) do if i > 1 or not ast.is_method or child.is_self then table.insert(xs, recurse_type(child, visit)) end end end if ast.rets then - for _, child in ipairs(ast.rets) do + for _, child in ipairs(ast.rets.tuple) do table.insert(xs, recurse_type(child, visit)) end end @@ -3803,8 +3817,8 @@ local function recurse_node(root: Node, local function walk_vars_exps(ast: Node, xs: {T}) xs[1] = recurse(ast.vars) - if ast.decltype then - xs[2] = recurse_type(ast.decltype, visit_type) + if ast.decltuple then + xs[2] = recurse_type(ast.decltuple, visit_type) end extra_callback("before_exp", ast, xs, visit_node) if ast.exps then @@ -3854,8 +3868,8 @@ local function recurse_node(root: Node, ["table_item"] = function(ast: Node, xs: {T}) xs[1] = recurse(ast.key) xs[2] = recurse(ast.value) - if ast.decltype then - xs[3] = recurse_type(ast.decltype, visit_type) + if ast.itemtype then + xs[3] = recurse_type(ast.itemtype, visit_type) end end, @@ -3965,8 +3979,8 @@ local function recurse_node(root: Node, end, ["argument"] = function(ast: Node, xs:{T}) - if ast.decltype then - xs[1] = recurse_type(ast.decltype, visit_type) + if ast.argtype then + xs[1] = recurse_type(ast.argtype, visit_type) end end, } @@ -4839,7 +4853,6 @@ local typename_to_typecode : {TypeName:integer} = { local skip_types: {TypeName: boolean} = { ["none"] = true, - ["tuple"] = true, ["table_item"] = true, ["unresolved"] = true, ["typetype"] = true, @@ -4879,12 +4892,12 @@ end local function store_function(trenv: TypeReportEnv, ti: TypeInfo, rt: Type) local args: {{integer, string}} = {} - for _, fnarg in ipairs(rt.args) do + for _, fnarg in ipairs(rt.args.tuple) do table.insert(args, mark_array { get_typenum(trenv, fnarg), nil }) end ti.args = mark_array(args) local rets: {{integer, string}} = {} - for _, fnarg in ipairs(rt.rets) do + for _, fnarg in ipairs(rt.rets.tuple) do table.insert(rets, mark_array { get_typenum(trenv, fnarg), nil }) end ti.rets = mark_array(rets) @@ -4907,8 +4920,8 @@ get_typenum = function(trenv:TypeReportEnv, t: Type): integer local rt = t if is_typetype(rt) then rt = rt.def - elseif rt.typename == "tuple" and #rt == 1 then - rt = rt[1] + elseif rt.typename == "tuple" and #rt.tuple == 1 then + rt = rt.tuple[1] end local ti: TypeInfo = { @@ -5303,10 +5316,14 @@ local function show_type_base(t: Type, short: boolean, seen: {Type:string}): str end elseif t.typename == "tuple" then local out: {string} = {} - for _, v in ipairs(t) do + for _, v in ipairs(t.tuple) do table.insert(out, show(v)) end - return "(" .. table.concat(out, ", ") .. ")" + local list = table.concat(out, ", ") + if short then + return list + end + return "(" .. list .. ")" elseif t.typename == "tupletable" then local out: {string} = {} for _, v in ipairs(t.types) do @@ -5353,20 +5370,20 @@ local function show_type_base(t: Type, short: boolean, seen: {Type:string}): str if t.is_method then table.insert(args, "self") end - for i, v in ipairs(t.args) do + for i, v in ipairs(t.args.tuple) do if not t.is_method or i > 1 then - table.insert(args, ((i == #t.args and t.args.is_va) and "...: " + table.insert(args, ((i == #t.args.tuple and t.args.is_va) and "...: " or v.opt and "? " or "") .. show(v)) end end table.insert(out, table.concat(args, ", ")) table.insert(out, ")") - if #t.rets > 0 then + if t.rets.tuple and #t.rets.tuple > 0 then table.insert(out, ": ") local rets = {} - for i, v in ipairs(t.rets) do - table.insert(rets, show(v) .. (i == #t.rets and t.rets.is_va and "..." or "")) + for i, v in ipairs(t.rets.tuple) do + table.insert(rets, show(v) .. (i == #t.rets.tuple and t.rets.is_va and "..." or "")) end table.insert(out, table.concat(rets, ", ")) end @@ -5716,7 +5733,7 @@ local function init_globals(lax: boolean): {string:Variable}, {string:Type} return t end - local LOAD_FUNCTION = a_function { args = {}, rets = a_tuple { STRING } } + local LOAD_FUNCTION = a_function { args = a_tuple {}, rets = a_tuple { STRING } } local OS_DATE_TABLE = a_record { fields = { @@ -5747,7 +5764,7 @@ local function init_globals(lax: boolean): {string:Variable}, {string:Type} ["nparams"] = INTEGER, ["isvararg"] = BOOLEAN, ["func"] = ANY, - ["activelines"] = a_type("map", { keys = INTEGER, values = BOOLEAN }), + ["activelines"] = a_map(INTEGER, BOOLEAN), } } @@ -5903,9 +5920,9 @@ local function init_globals(lax: boolean): {string:Variable}, {string:Type} ["debug"] = a_function { args = a_tuple {}, rets = a_tuple {} }, ["gethook"] = a_function { args = a_tuple { OPT(THREAD) }, rets = a_tuple { DEBUG_HOOK_FUNCTION, INTEGER } }, ["getlocal"] = a_poly { - a_function { args = a_tuple { THREAD, FUNCTION, NUMBER }, rets = STRING }, + a_function { args = a_tuple { THREAD, FUNCTION, NUMBER }, rets = a_tuple { STRING } }, a_function { args = a_tuple { THREAD, NUMBER, NUMBER }, rets = a_tuple { STRING, ANY } }, - a_function { args = a_tuple { FUNCTION, NUMBER }, rets = STRING }, + a_function { args = a_tuple { FUNCTION, NUMBER }, rets = a_tuple { STRING } }, a_function { args = a_tuple { NUMBER, NUMBER }, rets = a_tuple { STRING, ANY } }, }, ["getmetatable"] = a_gfunction(1, function(a: Type): Type return { args = a_tuple { a }, rets = a_tuple { METATABLE(a) } } end), @@ -6355,7 +6372,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - local function ensure_not_abstract(where: Where, t: Type) + local function ensure_not_abstract(where: Where, t: Type) if not t.is_abstract then return end @@ -6399,7 +6416,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if is_typetype(t) then return union_type(t.def), t.def elseif t.typename == "tuple" then - return union_type(t[1]), t[1] + return union_type(t.tuple[1]), t.tuple[1] elseif t.typename == "nominal" then local typetype = t.found or find_type(t.names) if not typetype then @@ -6504,7 +6521,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if f.min_arity then return end - local tuple = f.args + local tuple = f.args.tuple local n = #tuple if f.args.is_va then n = n - 1 @@ -6520,9 +6537,10 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end local function show_arity(f: Type): string - return f.min_arity < #f.args - and "at least " .. f.min_arity .. (f.args.is_va and "" or " and at most " .. #f.args) - or tostring(#f.args or 0) + local nfargs = #f.args.tuple + return f.min_arity < nfargs + and "at least " .. f.min_arity .. (f.args.is_va and "" or " and at most " .. nfargs) + or tostring(nfargs or 0) end local function resolve_typetype(t: Type): Type @@ -6603,10 +6621,6 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string copy.xend = t.xend copy.names = t.names -- which types have this, exactly? - for i, tf in ipairs(t) do - copy[i], same = resolve(tf, same) - end - if t.typename == "array" then copy.elements, same = resolve(t.elements, same) -- inferred_len is not propagated @@ -6686,6 +6700,10 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end elseif t.typename == "tuple" then copy.is_va = t.is_va + copy.tuple = {} + for i, tf in ipairs(t.tuple) do + copy.tuple[i], same = resolve(tf, same) + end end copy.typeid = same and orig_t.typeid or new_typeid() @@ -6722,7 +6740,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local function resolve_tuple(t: Type): Type if t.typename == "tuple" then - t = t[1] + t = t.tuple[1] end if t == nil then return NIL @@ -6821,14 +6839,24 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end + local function type_at(w: Where, t: Type): Type + t.x = w.x + t.y = w.y + t.filename = filename + return t + end + local function resolve_typevars_at(where: Where, t: Type): Type assert(where) - local ok, typ, errs = resolve_typevars(t) + local ok, ret, errs = resolve_typevars(t) if not ok then assert(where.y) add_errs_prefixing(where, errs, errors, "") end - return typ + if ret == t or t.typename == "typevar" then + ret = shallow_copy_table(ret) + end + return type_at(where, ret) end local function infer_at(where: Where, t: Type): Type @@ -6836,7 +6864,9 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if ret.typename == "invalid" then ret = t -- errors are produced by resolve_typevars_at end - ret = (ret ~= t) and ret or shallow_copy_table(t) + if ret == t or t.typename == "typevar" then + ret = shallow_copy_table(ret) + end ret.inferred_at = where ret.inferred_at.filename = filename return ret @@ -7573,21 +7603,23 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string ["function"] = { ["function"] = function(a: Type, b: Type): boolean, {Error} local argdelta = a.is_method and 1 or 0 - if #a.args ~= #b.args then + local naargs, nbargs = #a.args.tuple, #b.args.tuple + if naargs ~= nbargs then if a.is_method ~= b.is_method then return false, { Err(a, "different number of input arguments: method and non-method are not the same type") } end - return false, { Err(a, "different number of input arguments: got " .. #a.args - argdelta .. ", expected " .. #b.args - argdelta) } + return false, { Err(a, "different number of input arguments: got " .. naargs - argdelta .. ", expected " .. nbargs - argdelta) } end - if #a.rets ~= #b.rets then - return false, { Err(a, "different number of return values: got " .. #a.rets .. ", expected " .. #b.rets) } + local narets, nbrets = #a.rets.tuple, #b.rets.tuple + if narets ~= nbrets then + return false, { Err(a, "different number of return values: got " .. narets .. ", expected " .. nbrets) } end local errs = {} - for i = 1, #a.args do - arg_check(a, same_type, a.args[i], b.args[i], i - argdelta, errs, "argument") + for i = 1, naargs do + arg_check(a, same_type, a.args.tuple[i], b.args.tuple[i], i - argdelta, errs, "argument") end - for i = 1, #a.rets do - arg_check(a, same_type, a.rets[i], b.rets[i], i, errs, "return") + for i = 1, narets do + arg_check(a, same_type, a.rets.tuple[i], b.rets.tuple[i], i, errs, "return") end return any_errors(errs) end, @@ -7607,11 +7639,12 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string }, ["tuple"] = { ["tuple"] = function(a: Type, b: Type): boolean, {Error} -- ∀ a[i] ∈ a, b[i] ∈ b. a[i] <: b[i] - if #a ~= #b then -- ────────────────────────────────── - return false -- a tuple <: b tuple + local at, bt = a.tuple, b.tuple -- ────────────────────────────────── + if #at ~= #bt then -- a tuple <: b tuple + return false end - for i = 1, #a do - if not is_a(a[i], b[i]) then + for i = 1, #at do + if not is_a(at[i], bt[i]) then return false end end @@ -7821,21 +7854,21 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string ["function"] = function(a: Type, b: Type): boolean, {Error} local errs = {} - local aa, ba = a.args, b.args + local aa, ba = a.args.tuple, b.args.tuple set_min_arity(a) set_min_arity(b) - if (not ba.is_va) and a.min_arity > b.min_arity then - table.insert(errs, Err(a, "incompatible number of arguments: got " .. show_arity(a) .. " %s, expected " .. show_arity(b) .. " %s", aa, ba)) + if (not b.args.is_va) and a.min_arity > b.min_arity then + table.insert(errs, Err(a, "incompatible number of arguments: got " .. show_arity(a) .. " %s, expected " .. show_arity(b) .. " %s", a.args, b.args)) else for i = ((a.is_method or b.is_method) and 2 or 1), #aa do arg_check(nil, is_a, aa[i], ba[i] or ANY, i, errs, "argument") end end - local ar, br = a.rets, b.rets - local diff_by_va = #br - #ar == 1 and br.is_va + local ar, br = a.rets.tuple, b.rets.tuple + local diff_by_va = #br - #ar == 1 and b.rets.is_va if #ar < #br and not diff_by_va then - table.insert(errs, Err(a, "incompatible number of returns: got " .. #ar .. " %s, expected " .. #br .. " %s", ar, br)) + table.insert(errs, Err(a, "incompatible number of returns: got " .. #ar .. " %s, expected " .. #br .. " %s", a.rets, b.rets)) else local nrets = #br if diff_by_va then @@ -8055,15 +8088,15 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local call_mt = t.meta_fields and t.meta_fields["__call"] if call_mt then local args_tuple = a_tuple({}) - for i = 2, #call_mt.args do - table.insert(args_tuple, call_mt.args[i]) + for i = 2, #call_mt.args.tuple do + table.insert(args_tuple.tuple, call_mt.args.tuple[i]) end return args_tuple, call_mt end end) end - local function resolve_for_call(func: Type, args: {Type}, is_method: boolean): Type, boolean + local function resolve_for_call(func: Type, args: TupleType, is_method: boolean): Type, boolean -- resolve unknown in lax mode, produce a general unknown function if lax and is_unknown(func) then func = a_function { args = a_vararg { UNKNOWN }, rets = a_vararg { UNKNOWN } } @@ -8075,7 +8108,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if func.typename == "union" then local r = same_call_mt_in_all_union_entries(func) if r then - table.insert(args, 1, func.types[1]) -- FIXME: is this right? + table.insert(args.tuple, 1, func.types[1]) -- FIXME: is this right? return resolve_tuple_and_nominal(r), true end end @@ -8085,7 +8118,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end -- resolve if metatable if func.meta_fields and func.meta_fields["__call"] then - table.insert(args, 1, func) + table.insert(args.tuple, 1, func) func = func.meta_fields["__call"] func = resolve_tuple_and_nominal(func) is_method = true @@ -8185,7 +8218,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string orignode.known = saveknown end - local type_check_function_call: function(Node, {Node}, Type, {Type}, Node, boolean, ? integer): Type, Type + local type_check_function_call: function(Node, {Node}, Type, TupleType, Node, boolean, ? integer): Type, Type do local function mark_invalid_typeargs(f: Type) if f.typeargs then @@ -8201,38 +8234,40 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string assert(xs.typename == "tuple") assert(ys.typename == "tuple") - local n_xs = #xs - local n_ys = #ys + local xt, yt = xs.tuple, ys.tuple + local n_xs = #xt + local n_ys = #yt -- resolve inference of emptytables used as arguments or returns for i = 1, n_xs do - local x = xs[i] + local x = xt[i] if x.typename == "emptytable" or x.typename == "unresolved_emptytable_value" then - local y = ys[i] or (ys.is_va and ys[n_ys]) + local y = yt[i] or (ys.is_va and yt[n_ys]) if y then -- y may not be present when inferring returns local w = wheres and wheres[i + delta] or where -- for self, a + argdelta is 0 local inferred_y = infer_at(w, y) infer_emptytable(x, inferred_y) - xs[i] = inferred_y + xt[i] = inferred_y end end end end - local check_args_rets: function(where: Where, where_args: {Node}, f: Type, args: {Type}, rets: {Type}, argdelta: integer): Type, {Error} + local check_args_rets: function(where: Where, where_args: {Node}, f: Type, args: TupleType, expected_rets: TupleType, argdelta: integer): Type, {Error} do -- check if a tuple `xs` matches tuple `ys` - local function check_func_type_list(where: Where, wheres: {Where}, xs: Type, ys: Type, from: integer, delta: integer, mode: string): boolean, {Error} + local function check_func_type_list(where: Where, wheres: {Where}, xs: TupleType, ys: TupleType, from: integer, delta: integer, mode: string): boolean, {Error} assert(xs.typename == "tuple", xs.typename) assert(ys.typename == "tuple", ys.typename) local errs = {} - local n_xs = #xs - local n_ys = #ys + local xt, yt = xs.tuple, ys.tuple + local n_xs = #xt + local n_ys = #yt for i = from, math.max(n_xs, n_ys) do local pos = i + delta - local x = xs[i] or (xs.is_va and xs[n_xs]) or NIL - local y = ys[i] or (ys.is_va and ys[n_ys]) + local x = xt[i] or (xs.is_va and xt[n_xs]) or NIL + local y = yt[i] or (ys.is_va and yt[n_ys]) if y then local w = wheres and wheres[pos] or where if not arg_check(w, is_a, x, y, pos, errs, mode) then @@ -8244,26 +8279,27 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return true end - check_args_rets = function(where: Where, where_args: {Node}, f: Type, args: {Type}, rets: {Type}, argdelta: integer): Type, {Error} + check_args_rets = function(where: Where, where_args: {Node}, f: Type, args: TupleType, expected_rets: TupleType, argdelta: integer): Type, {Error} local rets_ok = true local rets_errs: {Error} local args_ok: boolean local args_errs: {Error} + local fargs = f.args.tuple local from = 1 if argdelta == -1 then from = 2 local errs = {} - if (not is_self(f.args[1])) and not arg_check(where, is_a, args[1], f.args[1], nil, errs, "self") then + if (not is_self(fargs[1])) and not arg_check(where, is_a, args.tuple[1], fargs[1], nil, errs, "self") then return nil, errs end end - if rets then - rets = infer_at(where, rets) - infer_emptytables(where, nil, rets, f.rets, 0) + if expected_rets then + expected_rets = infer_at(where, expected_rets) + infer_emptytables(where, nil, expected_rets, f.rets, 0) - rets_ok, rets_errs = check_func_type_list(where, nil, f.rets, rets, 1, 0, "return") + rets_ok, rets_errs = check_func_type_list(where, nil, f.rets, expected_rets, 1, 0, "return") end args_ok, args_errs = check_func_type_list(where, where_args, args, f.args, from, argdelta, "argument") @@ -8332,7 +8368,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return resolve_typevars_at(where, f.rets) end - local function check_call(where: Where, where_args: {Node}, func: Type, args: {Type}, expected: Type, typetype_funcall: boolean, is_method: boolean, argdelta: integer): Type, Type + local function check_call(where: Where, where_args: {Node}, func: Type, args: TupleType, expected_rets: TupleType, typetype_funcall: boolean, is_method: boolean, argdelta: integer): Type, Type assert(type(func) == "table") assert(type(args) == "table") @@ -8342,8 +8378,8 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string argdelta = is_method and -1 or argdelta or 0 - if is_method and args[1] then - add_var(nil, "@self", a_typetype({ y = where.y, x = where.x, def = args[1] })) + if is_method and args.tuple[1] then + add_var(nil, "@self", type_at(where, a_typetype { def = args.tuple[1] })) end local is_func = func.typename == "function" @@ -8357,15 +8393,16 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string passes, n = 3, #func.types end - local given = #args + local given = #args.tuple local tried: {integer:boolean} local first_errs: {Error} for pass = 1, passes do for i = 1, n do if (not tried) or not tried[i] then local f = is_func and func or func.types[i] + local fargs = f.args.tuple if f.is_method and not is_method then - if args[1] and is_a(args[1], f.args[1]) then + if args.tuple[1] and is_a(args.tuple[1], fargs[1]) then -- a non-"@funcall" means a synthesized call, e.g. from a metamethod if not typetype_funcall then add_warning("hint", where, "invoked method as a regular function: consider using ':' instead of '.'") @@ -8374,7 +8411,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return invalid_at(where, "invoked method as a regular function: use ':' instead of '.'") end end - local wanted = #f.args + local wanted = #fargs set_min_arity(f) -- simple functions: @@ -8388,14 +8425,14 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string then push_typeargs(f) - local matched, errs = check_args_rets(where, where_args, f, args, expected, argdelta) + local matched, errs = check_args_rets(where, where_args, f, args, expected_rets, argdelta) if matched then -- success! return matched, f end first_errs = first_errs or errs - if expected then + if expected_rets then -- revert inferred returns infer_emptytables(where, where_args, f.rets, f.rets, argdelta) end @@ -8413,7 +8450,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return fail_call(where, func, given, first_errs) end - type_check_function_call = function(node: Node, where_args: {Node}, func: Type, args: {Type}, e1: Node, is_method: boolean, argdelta?: integer): Type, Type + type_check_function_call = function(node: Node, where_args: {Node}, func: Type, args: TupleType, e1: Node, is_method: boolean, argdelta?: integer): Type, Type if node.expected and node.expected.typename ~= "tuple" then node.expected = a_tuple { node.expected } end @@ -8467,7 +8504,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local args = a_tuple { orig_a } if b and method_name ~= "__is" then where_args[2] = node.e2 - args[2] = orig_b + args.tuple[2] = orig_b end return resolve_tuple_and_nominal((type_check_function_call(node, where_args, metamethod, args, nil, true))), meta_on_operator else @@ -8648,17 +8685,18 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return st[1][var] end - local function get_rets(rets: {Type}): Type - if lax and (#rets == 0) then - return a_vararg { UNKNOWN } + local get_rets: function(Type): Type + if lax then + get_rets = function(rets: Type): Type + if #rets.tuple == 0 then + return a_vararg { UNKNOWN } + end + return rets end - local t: Type = rets as Type - if not t.typename then - -- what type is this? - t = a_tuple(t) + else + get_rets = function(rets: Type): Type + return rets end - assert(t.typeid) - return t end local function add_internal_function_variables(node: Node, args: Type) @@ -8680,10 +8718,11 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local function add_function_definition_for_recursion(node: Node, fnargs: Type) assert(fnargs.typename == "tuple") - local args = a_type("tuple", {}) + -- FIXME needs this copy? + local args: Type = a_tuple({}) args.is_va = fnargs.is_va - for _, fnarg in ipairs(fnargs) do - table.insert(args, fnarg) + for _, fnarg in ipairs(fnargs.tuple) do + table.insert(args.tuple, fnarg) end add_var(nil, node.name.tk, a_function { @@ -8729,10 +8768,10 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end local function flatten_tuple(vals: Type): Type - local vt = vals + local vt = vals.tuple local n_vals = #vt local ret = a_tuple {} - local rt = ret + local rt = ret.tuple if n_vals == 0 then return ret @@ -8746,7 +8785,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local last = vt[n_vals] if last.typename == "tuple" then -- ...then unpack the last tuple - local lt = last + local lt = last.tuple for _, v in ipairs(lt) do table.insert(rt, v) end @@ -8759,7 +8798,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return ret end - local function get_assignment_values(vals: Type, wanted: integer): {Type} + local function get_assignment_values(vals: Type, wanted: integer): Type if vals == nil then return a_tuple {} end @@ -8768,8 +8807,8 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string -- ...if the last is vararg, repeat its type until it matches the number of wanted args if ret.is_va then - local n_ret = #ret - local rt = ret + local rt = ret.tuple + local n_ret = #rt if n_ret > 0 and n_ret < wanted then local last = rt[n_ret] for _ = n_ret + 1, wanted do @@ -8838,11 +8877,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end if is_a(orig_b, a.keys) then - return a_type("unresolved_emptytable_value", { - y = anode.y, - x = anode.x, - emptytable_type = a - }) + return type_at(anode, a_type("unresolved_emptytable_value", { emptytable_type = a })) end errm, erra, errb = "inconsistent index type: got %s, expected %s (type of keys inferred at " @@ -8968,21 +9003,19 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local function typetype_to_nominal(where: Where, name: string, t: Type, resolved?: Type): Type assert(t.typename == "typetype") - local typevals: Type + local typevals: {Type} if t.def.typeargs then typevals = {} for _, a in ipairs(t.def.typeargs) do table.insert(typevals, a_type("typevar", { typevar = a.typearg })) end end - return a_type("nominal", { - y = where.y, - x = where.x, + return type_at(where, a_type("nominal", { typevals = typevals, names = { name }, found = t, resolved = resolved, - }) + })) end local function get_self_type(exp: Node): Type @@ -9375,9 +9408,9 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - local type_check_funcall: function(node: Node, a: Type, b: {Type}, argdelta?: integer): Type + local type_check_funcall: function(node: Node, a: Type, b: Type, argdelta?: integer): Type - local function special_pcall_xpcall(node: Node, _a: Type, b: {Type}, argdelta: integer): Type + local function special_pcall_xpcall(node: Node, _a: Type, b: Type, argdelta: integer): Type local base_nargs = (node.e1.tk == "xpcall") and 2 or 1 if #node.e2 < base_nargs then error_at(node, "wrong number of arguments (given " .. #node.e2 .. ", expects at least " .. base_nargs .. ")") @@ -9385,14 +9418,14 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end -- The function called by pcall/xpcall is invoked as a regular function, so we wish to avoid incorrect error messages / unnecessary warning messages associated with calling methods as functions - local ftype = table.remove(b, 1) + local ftype = table.remove(b.tuple, 1) ftype = shallow_copy_new_type(ftype) ftype.is_method = false local fe2: Node = {} if node.e1.tk == "xpcall" then base_nargs = 2 - local msgh = table.remove(b, 1) + local msgh = table.remove(b.tuple, 1) assert_is_a(node.e2[2], msgh, XPCALL_MSGH_FUNCTION, "in message handler") end for i = base_nargs + 1, #node.e2 do @@ -9407,20 +9440,20 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string e2 = fe2, } local rets = type_check_funcall(fnode, ftype, b, argdelta + base_nargs) - if rets.typename ~= "tuple" then - -- TODO what type is this?... - rets = a_tuple({ rets }) + if rets == INVALID then + return rets end - table.insert(rets, 1, BOOLEAN) + assert(rets and rets.typename == "tuple", show_type(rets)) + table.insert(rets.tuple, 1, BOOLEAN) return rets end - local special_functions: {string : function(Node,Type,{Type},integer):Type } = { - ["pairs"] = function(node: Node, a: Type, b: {Type}, argdelta: integer): Type - if not b[1] then + local special_functions: {string : function(Node,Type,TupleType,integer):TupleType } = { + ["pairs"] = function(node: Node, a: Type, b: TupleType, argdelta: integer): TupleType + if not b.tuple[1] then return invalid_at(node, "pairs requires an argument") end - local t = resolve_tuple_and_nominal(b[1]) + local t = resolve_tuple_and_nominal(b.tuple[1]) if is_array_type(t) then add_warning("hint", node, "hint: applying pairs on an array: did you intend to apply ipairs?") end @@ -9441,11 +9474,11 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return (type_check_function_call(node, node.e2, a, b, node, false, argdelta)) end, - ["ipairs"] = function(node: Node, a: Type, b: {Type}, argdelta: integer): Type - if not b[1] then + ["ipairs"] = function(node: Node, a: Type, b: TupleType, argdelta: integer): TupleType + if not b.tuple[1] then return invalid_at(node, "ipairs requires an argument") end - local orig_t = b[1] + local orig_t = b.tuple[1] local t = resolve_tuple_and_nominal(orig_t) if t.typename == "tupletable" then @@ -9462,17 +9495,17 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return (type_check_function_call(node, node.e2, a, b, node, false, argdelta)) end, - ["rawget"] = function(node: Node, _a: Type, b: {Type}, _argdelta: integer): Type + ["rawget"] = function(node: Node, _a: Type, b: TupleType, _argdelta: integer): TupleType -- TODO should those offsets be fixed by _argdelta? - if #b == 2 then - return type_check_index(node.e2[1], node.e2[2], b[1], b[2]) + if #b.tuple == 2 then + return a_tuple({ type_check_index(node.e2[1], node.e2[2], b.tuple[1], b.tuple[2]) }) else return invalid_at(node, "rawget expects two arguments") end end, - ["require"] = function(node: Node, _a: Type, b: {Type}, _argdelta: integer): Type - if #b ~= 1 then + ["require"] = function(node: Node, _a: Type, b: Type, _argdelta: integer): TupleType + if #b.tuple ~= 1 then return invalid_at(node, "require expects one literal argument") end if node.e2[1].kind ~= "string" then @@ -9493,13 +9526,13 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end dependencies[module_name] = t.filename - return t + return type_at(node, a_tuple({ t })) end, ["pcall"] = special_pcall_xpcall, ["xpcall"] = special_pcall_xpcall, - ["assert"] = function(node: Node, a: Type, b: {Type}, argdelta: integer): Type + ["assert"] = function(node: Node, a: Type, b: Type, argdelta: integer): TupleType node.known = FACT_TRUTHY local r = type_check_function_call(node, node.e2, a, b, node, false, argdelta) apply_facts(node, node.e2[1].known) @@ -9507,7 +9540,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end, } - type_check_funcall = function(node: Node, a: Type, b: {Type}, argdelta?: integer): Type + type_check_funcall = function(node: Node, a: Type, b: Type, argdelta?: integer): TupleType argdelta = argdelta or 0 if node.e1.kind == "variable" then local special = special_functions[node.e1.tk] @@ -9517,7 +9550,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return (type_check_function_call(node, node.e2, a, b, node.e1, false, argdelta)) end elseif node.e1.op and node.e1.op.op == ":" then - table.insert(b, 1, node.e1.receiver) + table.insert(b.tuple, 1, node.e1.receiver) return (type_check_function_call(node, node.e2, a, b, node.e1, true)) else return (type_check_function_call(node, node.e2, a, b, node.e1, false, argdelta)) @@ -9563,8 +9596,10 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - local function set_expected_types_to_decltypes(node: Node, children: {Type}) - local decls = node.kind == "assignment" and children[1] or node.decltype + local function set_expected_types_to_decltuple(node: Node, children: {Type}) + local decltuple = node.kind == "assignment" and children[1] or node.decltuple + assert(decltuple.typename == "tuple") + local decls = decltuple.tuple if decls and node.exps then local ndecl = #decls local nexps = #node.exps @@ -9573,9 +9608,9 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string typ = decls[i] if typ then if i == nexps and ndecl > nexps then - typ = a_type("tuple", { y = node.y, x = node.x, filename = filename }) + typ = type_at(node, a_tuple {}) for a = i, ndecl do - table.insert(typ, decls[a]) + table.insert(typ.tuple, decls[a]) end end node.exps[i].expected = typ @@ -9621,11 +9656,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end local function infer_table_literal(node: Node, children: {Type}): Type - local typ = a_type("emptytable", { - filename = filename, - y = node.y, - x = node.x, - }) + local typ = type_at(node, a_type("emptytable", {})) local is_record = false local is_array = false @@ -9673,7 +9704,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if node[i].key_parsed == "implicit" then if i == #children and child.vtype.typename == "tuple" then -- need to expand last item in an array (e.g { 1, 2, 3, f() }) - for _, c in ipairs(child.vtype) do + for _, c in ipairs(child.vtype.tuple) do typ.elements = expand_type(node, typ.elements, c) typ.types[last_array_idx] = resolve_tuple(c) last_array_idx = last_array_idx + 1 @@ -9719,12 +9750,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string elseif is_record and is_array then typ.typename = "record" typ.interface_list = { - a_type("array", { - filename = filename, - y = node.y, - x = node.x, - elements = typ.elements, - }) + type_at(node, an_array(typ.elements)) } -- TODO adopt logic from is_array below when we accept tupletable as an interface elseif is_record and is_map then @@ -9787,15 +9813,15 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string apply_facts(where, f) end - local function determine_declaration_type(var: Node, node: Node, infertypes: {Type}, i: integer): boolean, Type, boolean + local function determine_declaration_type(var: Node, node: Node, infertypes: TupleType, i: integer): boolean, Type, boolean local ok = true local name = var.tk - local infertype = infertypes and infertypes[i] + local infertype = infertypes and infertypes.tuple[i] if lax and infertype and infertype.typename == "nil" then infertype = nil end - local decltype = node.decltype and node.decltype[i] + local decltype = node.decltuple and node.decltuple.tuple[i] if decltype then if resolve_tuple_and_nominal(decltype) == INVALID then decltype = INVALID @@ -9859,8 +9885,15 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end local function get_type_declaration(node: Node): Type, Variable - if node.value.kind == "op" and node.value.op.op == "@funcall" then - return special_functions["require"](node.value, find_var_type("require"), { STRING }, 0) + if node.value.kind == "op" + and node.value.op.op == "@funcall" + and node.value.e1.kind == "variable" + and node.value.e1.tk == "require" + then + local t = special_functions["require"](node.value, find_var_type("require"), a_tuple { STRING }, 0) + if t ~= INVALID then + return t.tuple[1] + end else return resolve_nominal_typetype(node.value.newtype) end @@ -10037,10 +10070,10 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end end, - before_exp = set_expected_types_to_decltypes, + before_exp = set_expected_types_to_decltuple, after = function(node: Node, children: {Type}): Type local encountered_close = false - local infertypes: {Type} = get_assignment_values(children[3], #node.vars) + local infertypes = get_assignment_values(children[3], #node.vars) for i, var in ipairs(node.vars) do if var.attribute == "close" then if opts.gen_target == "5.4" then @@ -10067,9 +10100,9 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string assert(var) add_var(var, var.tk, t, var.attribute, is_localizing_a_variable(node, i) and "declaration") - if ok and infertypes and infertypes[i] then + local infertype = infertypes.tuple[i] + if ok and infertype then local where = node.exps[i] or node.exps - local infertype = infertypes[i] local rt = resolve_tuple_and_nominal(t) if rt.typename ~= "enum" and (t.typename ~= "nominal" or rt.typename == "union") and not same_type(t, infertype) then @@ -10088,9 +10121,9 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end, }, ["global_declaration"] = { - before_exp = set_expected_types_to_decltypes, + before_exp = set_expected_types_to_decltuple, after = function(node: Node, children: {Type}): Type - local infertypes: {Type} = get_assignment_values(children[3], #node.vars) + local infertypes = get_assignment_values(children[3], #node.vars) for i, var in ipairs(node.vars) do local _, t, is_inferred = determine_declaration_type(var, node, infertypes, i) @@ -10106,20 +10139,22 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end, }, ["assignment"] = { - before_exp = set_expected_types_to_decltypes, + before_exp = set_expected_types_to_decltuple, after = function(node: Node, children: {Type}): Type - local valtypes: {Type} = get_assignment_values(children[3], #children[1]) - for i, vartype in ipairs(children[1]) do + local vartypes = children[1].tuple + local valtypes = get_assignment_values(children[3], #vartypes) + for i, vartype in ipairs(vartypes) do local varnode = node.vars[i] local varname = varnode.tk - local rvar, rval, err = check_assignment(varnode, vartype, valtypes[i], varname, varnode.attribute) + local valtype = valtypes.tuple[i] + local rvar, rval, err = check_assignment(varnode, vartype, valtype, varname, varnode.attribute) if err == "missing" then if #node.exps == 1 and node.exps[1].kind == "op" and node.exps[1].op.op == "@funcall" then local rets = children[3] if rets.typename == "tuple" then - local msg = #rets == 1 + local msg = #rets.tuple == 1 and "only 1 value is returned by the function" - or ("only " .. #rets .. " values are returned by the function") + or ("only " .. #rets.tuple .. " values are returned by the function") add_warning("hint", varnode, msg) end end @@ -10137,7 +10172,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end if store_type then - store_type(varnode.y, varnode.x, valtypes[i]) + store_type(varnode.y, varnode.x, valtype) end end end @@ -10204,7 +10239,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string error_at(node, "label '" .. node.label .. "' already defined at " .. filename ) end local unresolved = st[#st]["@unresolved"] - local var = add_var(node, label_id, a_type("none", { y = node.y, x = node.x })) + local var = add_var(node, label_id, type_at(node, a_type("none", {}))) if unresolved then if unresolved.t.labels[node.label] then var.used = true @@ -10240,7 +10275,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string begin_scope(node) end, before_statements = function(node: Node, children: {Type}) - local exptypes = children[2] + local exptypes = children[2].tuple widen_all_unions(node) local exp1 = node.exps[1] @@ -10260,7 +10295,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local last: Type local rets = exp1type.rets for i, v in ipairs(node.vars) do - local r = rets[i] + local r = rets.tuple[i] if not r then if rets.is_va then r = last @@ -10271,8 +10306,8 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string add_var(v, v.tk, r) last = r end - if (not lax) and (not rets.is_va and #node.vars > #rets) then - local nrets = #rets + local nrets = #rets.tuple + if (not lax) and (not rets.is_va and #node.vars > nrets) then local at = node.vars[nrets + 1] local n_values = nrets == 1 and "1 value" or tostring(nrets) .. " values" error_at(at, "too many variables for this iterator; it produces " .. n_values) @@ -10306,36 +10341,42 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local rets = find_var_type("@return") if rets then for i, exp in ipairs(node.exps) do - exp.expected = rets[i] + exp.expected = rets.tuple[i] end end end, after = function(node: Node, children: {Type}): Type + local got = children[1] + local got_t = got.tuple + local n_got = #got_t + node.block_returns = true - local rets = find_var_type("@return") - if not rets then + local expected = find_var_type("@return") + if not expected then -- if at the toplevel - rets = infer_at(node, children[1]) - module_type = resolve_tuple_and_nominal(rets) + expected = infer_at(node, got) + module_type = resolve_tuple_and_nominal(expected) module_type.tk = nil - st[2]["@return"] = { t = rets } + st[2]["@return"] = { t = expected } end + local expected_t = expected.tuple + local what = "in return value" - if rets.inferred_at then - what = what .. inferred_msg(rets) + if expected.inferred_at then + what = what .. inferred_msg(expected) end - local nrets = #rets + local n_expected = #expected_t local vatype: Type - if nrets > 0 then - vatype = rets.is_va and rets[nrets] + if n_expected > 0 then + vatype = expected.is_va and expected.tuple[n_expected] end - if #children[1] > nrets and (not lax) and not vatype then - error_at(node, what ..": excess return values, expected " .. #rets .. " %s, got " .. #children[1] .. " %s", rets, children[1]) + if n_got > n_expected and (not lax) and not vatype then + error_at(node, what ..": excess return values, expected " .. n_expected .. " %s, got " .. n_got .. " %s", expected, got) end - if nrets > 1 + if n_expected > 1 and #node.exps == 1 and node.exps[1].kind == "op" and (node.exps[1].op.op == "and" or node.exps[1].op.op == "or") @@ -10343,15 +10384,15 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string add_warning("hint", node.exps[1].e2, "additional return values are being discarded due to '" .. node.exps[1].op.op .. "' expression; suggest parentheses if intentional") end - for i = 1, #children[1] do - local expected = rets[i] or vatype - if expected then - expected = resolve_tuple(expected) + for i = 1, n_got do + local e = expected_t[i] or vatype + if e then + e = resolve_tuple(e) local where = (node.exps[i] and node.exps[i].x) and node.exps[i] or node.exps assert(where and where.x) - assert_is_a(where, children[1][i], expected, what) + assert_is_a(where, got_t[i], e, what) end end @@ -10364,7 +10405,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string tuple = flatten_tuple(tuple) - for i, t in ipairs(tuple) do + for i, t in ipairs(tuple.tuple) do ensure_not_abstract(node[i], t) end @@ -10484,7 +10525,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string elseif is_array and is_number_type(child.ktype) then if child.vtype.typename == "tuple" and i == #children and node[i].key_parsed == "implicit" then -- need to expand last item in an array (e.g { 1, 2, 3, f() }) - for ti, tt in ipairs(child.vtype) do + for ti, tt in ipairs(child.vtype.tuple) do assert_is_a(node[i], tt, decltype.elements, in_context(node.expected_context, "expected an array"), "at index " .. tostring(i + ti - 1)) end else @@ -10533,22 +10574,20 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local kname = node.key.conststr local ktype = children[1] local vtype = children[2] - if node.decltype then - vtype = node.decltype - assert_is_a(node.value, children[2], node.decltype, "in table item") + if node.itemtype then + vtype = node.itemtype + assert_is_a(node.value, children[2], node.itemtype, "in table item") end if vtype.is_method then -- If we assign a method to a table item, e.g local a = { myfunc = myobj.dothing }, the table item should not be treated as a method vtype = shallow_copy_new_type(vtype) vtype.is_method = false end - return a_type("table_item", { - y = node.y, - x = node.x, + return type_at(node, a_type("table_item", { kname = kname, ktype = ktype, vtype = vtype, - }) + })) end, }, ["local_function"] = { @@ -10660,16 +10699,13 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string -- add type arguments from the record implicitly if rtype.typeargs then for _, typ in ipairs(rtype.typeargs) do - add_var(nil, typ.typearg, a_type("typearg", { - y = typ.y, - x = typ.x, - typearg = typ.typearg, - })) + add_var(nil, typ.typearg, type_at(typ, a_type("typearg", { typearg = typ.typearg }))) end end end, before_statements = function(node: Node, children: {Type}) local args = children[3] + assert(args.typename == "tuple") local rtype = resolve_tuple_and_nominal(resolve_typetype(children[1])) @@ -10694,7 +10730,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string error_at(node, "could not resolve type of self") return end - args[1] = selftype + args.tuple[1] = selftype add_var(nil, "self", selftype) end @@ -10842,7 +10878,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if node.expected then is_a(e1type.rets, node.expected) end - local e1args = e1type.args + local e1args = e1type.args.tuple local at = argdelta for _, typ in ipairs(e1args) do at = at + 1 @@ -10850,7 +10886,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string node.e2[at].expected = typ end end - if e1args.is_va then + if e1type.args.is_va then local typ = e1args[#e1args] for i = at + 1, #node.e2 do node.e2[i].expected = typ @@ -10886,7 +10922,8 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string add_unknown_dot(node, node.e1.e1.tk .. "." .. node.e1.e2.tk) end end - return type_check_funcall(node, a, b) + local t = type_check_funcall(node, a, b) + return t end ensure_not_abstract(node.e1, ra) @@ -10911,11 +10948,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string kind = "string", conststr = node.e2.tk, } - local btype = a_type("string", { - y = node.e2.y, - x = node.e2.x, - tk = '"' ..node.e2.tk .. '"', - }) + local btype = type_at(node.e2, a_type("string", { tk = '"' ..node.e2.tk .. '"' })) local t = type_check_index(node.e1, bnode, orig_a, btype) if t.needs_compat and opts.gen_compat ~= "off" then @@ -11273,11 +11306,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local function after_literal(node: Node): Type node.known = FACT_TRUTHY - return a_type(node.kind as TypeName, { - y = node.y, - x = node.x, - tk = node.tk, - }) + return type_at(node, a_type(node.kind as TypeName, { tk = node.tk })) end visit_node.cbs["string"] = { @@ -11347,7 +11376,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string ["record"] = { before = function(typ: Type) begin_scope() - add_var(nil, "@self", a_typetype({ y = typ.y, x = typ.x, def = typ })) + add_var(nil, "@self", type_at(typ, a_typetype({ def = typ }))) for name, typ2 in fields_of(typ) do if typ2.typename == "typetype" then @@ -11375,20 +11404,22 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end for name, _ in fields_of(typ) do local ftype = children[i] - - if ftype.is_method and ftype.args and ftype.args[1] and ftype.args[1].is_self then - local record_name = typ.names and typ.names[1] - if record_name then - local selfarg = ftype.args[1] - if selfarg.tk ~= record_name or (typ.typeargs and not selfarg.typevals) then - ftype.is_method = false - selfarg.is_self = false - elseif typ.typeargs then - for j=1,#typ.typeargs do - if (not selfarg.typevals[j]) or selfarg.typevals[j].tk ~= typ.typeargs[j].typearg then - ftype.is_method = false - selfarg.is_self = false - break + if ftype.typename == "function" and ftype.is_method then + local fargs = ftype.args.tuple + if fargs[1] and fargs[1].is_self then + local record_name = typ.names and typ.names[1] + if record_name then + local selfarg = fargs[1] + if selfarg.tk ~= record_name or (typ.typeargs and not selfarg.typevals) then + ftype.is_method = false + selfarg.is_self = false + elseif typ.typeargs then + for j=1,#typ.typeargs do + if (not selfarg.typevals[j]) or selfarg.typevals[j].tk ~= typ.typeargs[j].typearg then + ftype.is_method = false + selfarg.is_self = false + break + end end end end @@ -11408,11 +11439,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string }, ["typearg"] = { after = function(typ: Type, _children: {Type}): Type - add_var(nil, typ.typearg, a_type("typearg", { - y = typ.y, - x = typ.x, - typearg = typ.typearg, - })) + add_var(nil, typ.typearg, type_at(typ, a_type("typearg", { typearg = typ.typearg }))) return typ end, }, From 305ef38a9da0801f8aa877ecfc960f65c44aefdb Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Thu, 14 Dec 2023 03:07:13 -0300 Subject: [PATCH 050/224] tl types: produce types for for-in variables --- spec/cli/types_spec.lua | 28 ++++++++++++++++++++++++++++ tl.lua | 5 +++++ tl.tl | 5 +++++ 3 files changed, 38 insertions(+) diff --git a/spec/cli/types_spec.lua b/spec/cli/types_spec.lua index d196002d2..d36aafd9a 100644 --- a/spec/cli/types_spec.lua +++ b/spec/cli/types_spec.lua @@ -304,5 +304,33 @@ describe("tl types works like check", function() assert(by_pos["1"]["21"]) -- "os" assert(by_pos["1"]["26"]) -- . end) + + it("produce values for forin variables", function() + local name = util.write_tmp_file(finally, [[ + local x: {string:boolean} = {} + for k, v in pairs(x) do + end + ]]) + local pd = io.popen(util.tl_cmd("types", name) .. " 2>" .. util.os_null, "r") + local output = pd:read("*a") + util.assert_popen_close(0, pd:close()) + local types = json.decode(output) + assert(types.by_pos) + local by_pos = types.by_pos[next(types.by_pos)] + assert.same({ + ["19"] = 2, + ["20"] = 5, + ["22"] = 2, + ["39"] = 6, + ["41"] = 2, + }, by_pos["1"]) + assert.same({ + ["17"] = 3, + ["20"] = 4, + ["25"] = 15, + ["30"] = 14, + ["31"] = 2, + }, by_pos["2"]) + end) end) end) diff --git a/tl.lua b/tl.lua index 5b86c0658..67849f5b4 100644 --- a/tl.lua +++ b/tl.lua @@ -10304,6 +10304,11 @@ a.types[i], b.types[i]), } end end add_var(v, v.tk, r) + + if store_type then + store_type(v.y, v.x, r) + end + last = r end local nrets = #rets.tuple diff --git a/tl.tl b/tl.tl index 05f88f183..5fa28ac7e 100644 --- a/tl.tl +++ b/tl.tl @@ -10304,6 +10304,11 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end add_var(v, v.tk, r) + + if store_type then + store_type(v.y, v.x, r) + end + last = r end local nrets = #rets.tuple From 0f8311cd70be08200c87cd6fe21026f8b3e9f5e9 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Thu, 14 Dec 2023 17:07:22 -0300 Subject: [PATCH 051/224] visit_type: minor refactor --- tl.lua | 134 +++++++++++++++++++++++++++++---------------------------- tl.tl | 134 +++++++++++++++++++++++++++++---------------------------- 2 files changed, 138 insertions(+), 130 deletions(-) diff --git a/tl.lua b/tl.lua index 67849f5b4..321ede2ad 100644 --- a/tl.lua +++ b/tl.lua @@ -4751,46 +4751,47 @@ function tl.pretty_print_ast(ast, gen_target, mode) } local visit_type = {} - visit_type.cbs = { - ["string"] = { - after = function(typ, _children) - local out = { y = typ.y or -1, h = 0 } - local r = typ.resolved or typ - local lua_type = primitive[r.typename] or - (r.is_userdata and "userdata") or - "table" - table.insert(out, lua_type) - return out - end, - }, + visit_type.cbs = {} + local default_type_visitor = { + after = function(typ, _children) + local out = { y = typ.y or -1, h = 0 } + local r = typ.resolved or typ + local lua_type = primitive[r.typename] or + (r.is_userdata and "userdata") or + "table" + table.insert(out, lua_type) + return out + end, } - visit_type.cbs["typetype"] = visit_type.cbs["string"] - visit_type.cbs["typevar"] = visit_type.cbs["string"] - visit_type.cbs["typearg"] = visit_type.cbs["string"] - visit_type.cbs["function"] = visit_type.cbs["string"] - visit_type.cbs["thread"] = visit_type.cbs["string"] - visit_type.cbs["array"] = visit_type.cbs["string"] - visit_type.cbs["map"] = visit_type.cbs["string"] - visit_type.cbs["tupletable"] = visit_type.cbs["string"] - visit_type.cbs["record"] = visit_type.cbs["string"] - visit_type.cbs["enum"] = visit_type.cbs["string"] - visit_type.cbs["boolean"] = visit_type.cbs["string"] - visit_type.cbs["nil"] = visit_type.cbs["string"] - visit_type.cbs["number"] = visit_type.cbs["string"] - visit_type.cbs["integer"] = visit_type.cbs["string"] - visit_type.cbs["union"] = visit_type.cbs["string"] - visit_type.cbs["nominal"] = visit_type.cbs["string"] - visit_type.cbs["bad_nominal"] = visit_type.cbs["string"] - visit_type.cbs["emptytable"] = visit_type.cbs["string"] - visit_type.cbs["table_item"] = visit_type.cbs["string"] - visit_type.cbs["unresolved_emptytable_value"] = visit_type.cbs["string"] - visit_type.cbs["tuple"] = visit_type.cbs["string"] - visit_type.cbs["poly"] = visit_type.cbs["string"] - visit_type.cbs["any"] = visit_type.cbs["string"] - visit_type.cbs["unknown"] = visit_type.cbs["string"] - visit_type.cbs["invalid"] = visit_type.cbs["string"] - visit_type.cbs["unresolved"] = visit_type.cbs["string"] - visit_type.cbs["none"] = visit_type.cbs["string"] + + visit_type.cbs["string"] = default_type_visitor + visit_type.cbs["typetype"] = default_type_visitor + visit_type.cbs["typevar"] = default_type_visitor + visit_type.cbs["typearg"] = default_type_visitor + visit_type.cbs["function"] = default_type_visitor + visit_type.cbs["thread"] = default_type_visitor + visit_type.cbs["array"] = default_type_visitor + visit_type.cbs["map"] = default_type_visitor + visit_type.cbs["tupletable"] = default_type_visitor + visit_type.cbs["record"] = default_type_visitor + visit_type.cbs["enum"] = default_type_visitor + visit_type.cbs["boolean"] = default_type_visitor + visit_type.cbs["nil"] = default_type_visitor + visit_type.cbs["number"] = default_type_visitor + visit_type.cbs["integer"] = default_type_visitor + visit_type.cbs["union"] = default_type_visitor + visit_type.cbs["nominal"] = default_type_visitor + visit_type.cbs["bad_nominal"] = default_type_visitor + visit_type.cbs["emptytable"] = default_type_visitor + visit_type.cbs["table_item"] = default_type_visitor + visit_type.cbs["unresolved_emptytable_value"] = default_type_visitor + visit_type.cbs["tuple"] = default_type_visitor + visit_type.cbs["poly"] = default_type_visitor + visit_type.cbs["any"] = default_type_visitor + visit_type.cbs["unknown"] = default_type_visitor + visit_type.cbs["invalid"] = default_type_visitor + visit_type.cbs["unresolved"] = default_type_visitor + visit_type.cbs["none"] = default_type_visitor visit_node.cbs["expression_list"] = visit_node.cbs["variable_list"] visit_node.cbs["argument_list"] = visit_node.cbs["variable_list"] @@ -11352,11 +11353,6 @@ a.types[i], b.types[i]), } local visit_type visit_type = { cbs = { - ["string"] = { - after = function(typ, _children) - return typ - end, - }, ["function"] = { before = function(_typ) begin_scope() @@ -11545,28 +11541,36 @@ a.types[i], b.types[i]), } visit_node.after = debug_type_after(visit_node.after) end - visit_type.cbs["tupletable"] = visit_type.cbs["string"] - visit_type.cbs["typetype"] = visit_type.cbs["string"] - visit_type.cbs["array"] = visit_type.cbs["string"] - visit_type.cbs["map"] = visit_type.cbs["string"] + local default_type_visitor = { + after = function(typ, _children) + return typ + end, + } + visit_type.cbs["interface"] = visit_type.cbs["record"] - visit_type.cbs["enum"] = visit_type.cbs["string"] - visit_type.cbs["boolean"] = visit_type.cbs["string"] - visit_type.cbs["nil"] = visit_type.cbs["string"] - visit_type.cbs["number"] = visit_type.cbs["string"] - visit_type.cbs["integer"] = visit_type.cbs["string"] - visit_type.cbs["thread"] = visit_type.cbs["string"] - visit_type.cbs["bad_nominal"] = visit_type.cbs["string"] - visit_type.cbs["emptytable"] = visit_type.cbs["string"] - visit_type.cbs["table_item"] = visit_type.cbs["string"] - visit_type.cbs["unresolved_emptytable_value"] = visit_type.cbs["string"] - visit_type.cbs["tuple"] = visit_type.cbs["string"] - visit_type.cbs["poly"] = visit_type.cbs["string"] - visit_type.cbs["any"] = visit_type.cbs["string"] - visit_type.cbs["unknown"] = visit_type.cbs["string"] - visit_type.cbs["invalid"] = visit_type.cbs["string"] - visit_type.cbs["unresolved"] = visit_type.cbs["string"] - visit_type.cbs["none"] = visit_type.cbs["string"] + + visit_type.cbs["string"] = default_type_visitor + visit_type.cbs["tupletable"] = default_type_visitor + visit_type.cbs["typetype"] = default_type_visitor + visit_type.cbs["array"] = default_type_visitor + visit_type.cbs["map"] = default_type_visitor + visit_type.cbs["enum"] = default_type_visitor + visit_type.cbs["boolean"] = default_type_visitor + visit_type.cbs["nil"] = default_type_visitor + visit_type.cbs["number"] = default_type_visitor + visit_type.cbs["integer"] = default_type_visitor + visit_type.cbs["thread"] = default_type_visitor + visit_type.cbs["bad_nominal"] = default_type_visitor + visit_type.cbs["emptytable"] = default_type_visitor + visit_type.cbs["table_item"] = default_type_visitor + visit_type.cbs["unresolved_emptytable_value"] = default_type_visitor + visit_type.cbs["tuple"] = default_type_visitor + visit_type.cbs["poly"] = default_type_visitor + visit_type.cbs["any"] = default_type_visitor + visit_type.cbs["unknown"] = default_type_visitor + visit_type.cbs["invalid"] = default_type_visitor + visit_type.cbs["unresolved"] = default_type_visitor + visit_type.cbs["none"] = default_type_visitor assert(ast.kind == "statements") recurse_node(ast, visit_node, visit_type) diff --git a/tl.tl b/tl.tl index 5fa28ac7e..38ee09716 100644 --- a/tl.tl +++ b/tl.tl @@ -4751,46 +4751,47 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | } local visit_type: Visitor = {} - visit_type.cbs = { - ["string"] = { - after = function(typ: Type, _children: {Output}): Output - local out: Output = { y = typ.y or -1, h = 0 } - local r = typ.resolved or typ - local lua_type = primitive[r.typename] - or (r.is_userdata and "userdata") - or "table" - table.insert(out, lua_type) - return out - end, - }, + visit_type.cbs = {} + local default_type_visitor = { + after = function(typ: Type, _children: {Output}): Output + local out: Output = { y = typ.y or -1, h = 0 } + local r = typ.resolved or typ + local lua_type = primitive[r.typename] + or (r.is_userdata and "userdata") + or "table" + table.insert(out, lua_type) + return out + end, } - visit_type.cbs["typetype"] = visit_type.cbs["string"] - visit_type.cbs["typevar"] = visit_type.cbs["string"] - visit_type.cbs["typearg"] = visit_type.cbs["string"] - visit_type.cbs["function"] = visit_type.cbs["string"] - visit_type.cbs["thread"] = visit_type.cbs["string"] - visit_type.cbs["array"] = visit_type.cbs["string"] - visit_type.cbs["map"] = visit_type.cbs["string"] - visit_type.cbs["tupletable"] = visit_type.cbs["string"] - visit_type.cbs["record"] = visit_type.cbs["string"] - visit_type.cbs["enum"] = visit_type.cbs["string"] - visit_type.cbs["boolean"] = visit_type.cbs["string"] - visit_type.cbs["nil"] = visit_type.cbs["string"] - visit_type.cbs["number"] = visit_type.cbs["string"] - visit_type.cbs["integer"] = visit_type.cbs["string"] - visit_type.cbs["union"] = visit_type.cbs["string"] - visit_type.cbs["nominal"] = visit_type.cbs["string"] - visit_type.cbs["bad_nominal"] = visit_type.cbs["string"] - visit_type.cbs["emptytable"] = visit_type.cbs["string"] - visit_type.cbs["table_item"] = visit_type.cbs["string"] - visit_type.cbs["unresolved_emptytable_value"] = visit_type.cbs["string"] - visit_type.cbs["tuple"] = visit_type.cbs["string"] - visit_type.cbs["poly"] = visit_type.cbs["string"] - visit_type.cbs["any"] = visit_type.cbs["string"] - visit_type.cbs["unknown"] = visit_type.cbs["string"] - visit_type.cbs["invalid"] = visit_type.cbs["string"] - visit_type.cbs["unresolved"] = visit_type.cbs["string"] - visit_type.cbs["none"] = visit_type.cbs["string"] + + visit_type.cbs["string"] = default_type_visitor + visit_type.cbs["typetype"] = default_type_visitor + visit_type.cbs["typevar"] = default_type_visitor + visit_type.cbs["typearg"] = default_type_visitor + visit_type.cbs["function"] = default_type_visitor + visit_type.cbs["thread"] = default_type_visitor + visit_type.cbs["array"] = default_type_visitor + visit_type.cbs["map"] = default_type_visitor + visit_type.cbs["tupletable"] = default_type_visitor + visit_type.cbs["record"] = default_type_visitor + visit_type.cbs["enum"] = default_type_visitor + visit_type.cbs["boolean"] = default_type_visitor + visit_type.cbs["nil"] = default_type_visitor + visit_type.cbs["number"] = default_type_visitor + visit_type.cbs["integer"] = default_type_visitor + visit_type.cbs["union"] = default_type_visitor + visit_type.cbs["nominal"] = default_type_visitor + visit_type.cbs["bad_nominal"] = default_type_visitor + visit_type.cbs["emptytable"] = default_type_visitor + visit_type.cbs["table_item"] = default_type_visitor + visit_type.cbs["unresolved_emptytable_value"] = default_type_visitor + visit_type.cbs["tuple"] = default_type_visitor + visit_type.cbs["poly"] = default_type_visitor + visit_type.cbs["any"] = default_type_visitor + visit_type.cbs["unknown"] = default_type_visitor + visit_type.cbs["invalid"] = default_type_visitor + visit_type.cbs["unresolved"] = default_type_visitor + visit_type.cbs["none"] = default_type_visitor visit_node.cbs["expression_list"] = visit_node.cbs["variable_list"] visit_node.cbs["argument_list"] = visit_node.cbs["variable_list"] @@ -11352,11 +11353,6 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local visit_type: Visitor visit_type = { cbs = { - ["string"] = { - after = function(typ: Type, _children: {Type}): Type - return typ - end, - }, ["function"] = { before = function(_typ: Type) begin_scope() @@ -11545,28 +11541,36 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string visit_node.after = debug_type_after(visit_node.after) end - visit_type.cbs["tupletable"] = visit_type.cbs["string"] - visit_type.cbs["typetype"] = visit_type.cbs["string"] - visit_type.cbs["array"] = visit_type.cbs["string"] - visit_type.cbs["map"] = visit_type.cbs["string"] + local default_type_visitor = { + after = function(typ: Type, _children: {Type}): Type + return typ + end, + } + visit_type.cbs["interface"] = visit_type.cbs["record"] - visit_type.cbs["enum"] = visit_type.cbs["string"] - visit_type.cbs["boolean"] = visit_type.cbs["string"] - visit_type.cbs["nil"] = visit_type.cbs["string"] - visit_type.cbs["number"] = visit_type.cbs["string"] - visit_type.cbs["integer"] = visit_type.cbs["string"] - visit_type.cbs["thread"] = visit_type.cbs["string"] - visit_type.cbs["bad_nominal"] = visit_type.cbs["string"] - visit_type.cbs["emptytable"] = visit_type.cbs["string"] - visit_type.cbs["table_item"] = visit_type.cbs["string"] - visit_type.cbs["unresolved_emptytable_value"] = visit_type.cbs["string"] - visit_type.cbs["tuple"] = visit_type.cbs["string"] - visit_type.cbs["poly"] = visit_type.cbs["string"] - visit_type.cbs["any"] = visit_type.cbs["string"] - visit_type.cbs["unknown"] = visit_type.cbs["string"] - visit_type.cbs["invalid"] = visit_type.cbs["string"] - visit_type.cbs["unresolved"] = visit_type.cbs["string"] - visit_type.cbs["none"] = visit_type.cbs["string"] + + visit_type.cbs["string"] = default_type_visitor + visit_type.cbs["tupletable"] = default_type_visitor + visit_type.cbs["typetype"] = default_type_visitor + visit_type.cbs["array"] = default_type_visitor + visit_type.cbs["map"] = default_type_visitor + visit_type.cbs["enum"] = default_type_visitor + visit_type.cbs["boolean"] = default_type_visitor + visit_type.cbs["nil"] = default_type_visitor + visit_type.cbs["number"] = default_type_visitor + visit_type.cbs["integer"] = default_type_visitor + visit_type.cbs["thread"] = default_type_visitor + visit_type.cbs["bad_nominal"] = default_type_visitor + visit_type.cbs["emptytable"] = default_type_visitor + visit_type.cbs["table_item"] = default_type_visitor + visit_type.cbs["unresolved_emptytable_value"] = default_type_visitor + visit_type.cbs["tuple"] = default_type_visitor + visit_type.cbs["poly"] = default_type_visitor + visit_type.cbs["any"] = default_type_visitor + visit_type.cbs["unknown"] = default_type_visitor + visit_type.cbs["invalid"] = default_type_visitor + visit_type.cbs["unresolved"] = default_type_visitor + visit_type.cbs["none"] = default_type_visitor assert(ast.kind == "statements") recurse_node(ast, visit_node, visit_type) From d9117f087d5608b08f89ef076a42d28a8642ac30 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Thu, 28 Dec 2023 14:30:28 -0500 Subject: [PATCH 052/224] interfaces: expand field list --- tl.lua | 107 +++++++++++++++++++++++++++++++++++++++++++++++++-------- tl.tl | 107 +++++++++++++++++++++++++++++++++++++++++++++++++-------- 2 files changed, 184 insertions(+), 30 deletions(-) diff --git a/tl.lua b/tl.lua index 321ede2ad..dc19a7013 100644 --- a/tl.lua +++ b/tl.lua @@ -1253,6 +1253,7 @@ local table_types = { + local TruthyFact = {} @@ -3695,7 +3696,7 @@ local function recurse_type(ast, visit) end if ast.interface_list then for _, child in ipairs(ast.interface_list) do - recurse_type(child, visit) + table.insert(xs, recurse_type(child, visit)) end end if ast.def then @@ -11350,6 +11351,61 @@ a.types[i], b.types[i]), } return t end + local expand_interfaces + do + local function add_interface_fields(what, fields, field_order, iface, orig_iface, list) + for fname, ftype in fields_of(iface, list) do + if fields[fname] then + if not is_a(fields[fname], ftype) then + error_at(fields[fname], what .. " '" .. fname .. "' does not match definition in interface %s", orig_iface) + end + else + table.insert(field_order, fname) + fields[fname] = ftype + end + end + end + + local function expand(t, seen) + if t.interfaces_expanded then + return t + end + t.interfaces_expanded = true + if seen[t] then + return + end + seen[t] = true + + t.fields = t.fields or {} + t.meta_fields = t.meta_fields or {} + t.field_order = t.field_order or {} + t.meta_field_order = t.meta_field_order or {} + + + for _, iface in ipairs(t.interface_list) do + local orig_iface = iface + + if iface.typename == "nominal" then + iface = resolve_nominal(iface) + end + + if iface.typename == "interface" then + if iface.interface_list then + iface = expand(iface, seen) + end + + add_interface_fields("field", t.fields, t.field_order, iface, orig_iface) + add_interface_fields("metamethod", t.meta_fields, t.meta_field_order, iface, orig_iface, "meta") + end + end + return t + end + + expand_interfaces = function(t) + return expand(t, {}) + end + end + local visit_type visit_type = { cbs = { @@ -11359,19 +11415,7 @@ a.types[i], b.types[i]), } end, after = function(typ, _children) end_scope() - typ = ensure_fresh_typeargs(typ) - - if typ.macroexp then - local macroexp_type = recurse_node(typ.macroexp, visit_node, visit_type) - - check_macroexp_arg_use(typ.macroexp) - - if not is_a(macroexp_type, typ) then - error_at(macroexp_type, "macroexp type does not match declaration") - end - end - - return typ + return ensure_fresh_typeargs(typ) end, }, ["record"] = { @@ -11391,7 +11435,6 @@ a.types[i], b.types[i]), } end end, after = function(typ, children) - end_scope() local i = 1 if typ.typeargs then for _, _ in ipairs(typ.typeargs) do @@ -11399,12 +11442,23 @@ a.types[i], b.types[i]), } i = i + 1 end end + if typ.interface_list then + for j, _ in ipairs(typ.interface_list) do + typ.interface_list[j] = children[i] + i = i + 1 + end + end if typ.elements then typ.elements = children[i] i = i + 1 end + local fmacros for name, _ in fields_of(typ) do local ftype = children[i] + if ftype.macroexp then + fmacros = fmacros or {} + table.insert(fmacros, ftype) + end if ftype.typename == "function" and ftype.is_method then local fargs = ftype.args.tuple if fargs[1] and fargs[1].is_self then @@ -11432,9 +11486,32 @@ a.types[i], b.types[i]), } end for name, _ in fields_of(typ, "meta") do local ftype = children[i] + if ftype.macroexp then + fmacros = fmacros or {} + table.insert(fmacros, ftype) + end typ.meta_fields[name] = ftype i = i + 1 end + + if typ.interface_list then + expand_interfaces(typ) + end + + if fmacros then + for _, t in ipairs(fmacros) do + local macroexp_type = recurse_node(t.macroexp, visit_node, visit_type) + + check_macroexp_arg_use(t.macroexp) + + if not is_a(macroexp_type, t) then + error_at(macroexp_type, "macroexp type does not match declaration") + end + end + end + + end_scope() + return typ end, }, diff --git a/tl.tl b/tl.tl index 38ee09716..aea6ad02e 100644 --- a/tl.tl +++ b/tl.tl @@ -1106,6 +1106,7 @@ local record Type -- records interface_list: {Type} + interfaces_expanded: boolean typeargs: {Type} fields: {string: Type} field_order: {string} @@ -3695,7 +3696,7 @@ local function recurse_type(ast: Type, visit: Visitor): T end if ast.interface_list then for _, child in ipairs(ast.interface_list) do - recurse_type(child, visit) + table.insert(xs, recurse_type(child, visit)) end end if ast.def then @@ -11350,6 +11351,61 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return t end + local expand_interfaces: function(Type): Type + do + local function add_interface_fields(what: string, fields: {string:Type}, field_order: {string}, iface: Type, orig_iface: Type, list?: MetaMode) + for fname, ftype in fields_of(iface, list) do + if fields[fname] then + if not is_a(fields[fname], ftype) then + error_at(fields[fname], what .." '" .. fname .. "' does not match definition in interface %s", orig_iface) + end + else + table.insert(field_order, fname) + fields[fname] = ftype + end + end + end + + local function expand(t: Type, seen: {Type:boolean}): Type + if t.interfaces_expanded then + return t + end + t.interfaces_expanded = true + if seen[t] then + return + end + seen[t] = true + + t.fields = t.fields or {} + t.meta_fields = t.meta_fields or {} + t.field_order = t.field_order or {} + t.meta_field_order = t.meta_field_order or {} + + -- FIXME expand and collect interface_list recursively, THEN add fields + for _, iface in ipairs(t.interface_list) do + local orig_iface = iface + + if iface.typename == "nominal" then + iface = resolve_nominal(iface) + end + + if iface.typename == "interface" then + if iface.interface_list then + iface = expand(iface, seen) + end + + add_interface_fields("field", t.fields, t.field_order, iface, orig_iface) + add_interface_fields("metamethod", t.meta_fields, t.meta_field_order, iface, orig_iface, "meta") + end + end + return t + end + + expand_interfaces = function(t: Type): Type + return expand(t, {}) + end + end + local visit_type: Visitor visit_type = { cbs = { @@ -11359,19 +11415,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end, after = function(typ: Type, _children: {Type}): Type end_scope() - typ = ensure_fresh_typeargs(typ) - - if typ.macroexp then - local macroexp_type = recurse_node(typ.macroexp, visit_node, visit_type) - - check_macroexp_arg_use(typ.macroexp) - - if not is_a(macroexp_type, typ) then - error_at(macroexp_type, "macroexp type does not match declaration") - end - end - - return typ + return ensure_fresh_typeargs(typ) end, }, ["record"] = { @@ -11391,7 +11435,6 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end, after = function(typ: Type, children: {Type}): Type - end_scope() local i = 1 if typ.typeargs then for _, _ in ipairs(typ.typeargs) do @@ -11399,12 +11442,23 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string i = i + 1 end end + if typ.interface_list then + for j, _ in ipairs(typ.interface_list) do + typ.interface_list[j] = children[i] + i = i + 1 + end + end if typ.elements then typ.elements = children[i] i = i + 1 end + local fmacros: {Type} for name, _ in fields_of(typ) do local ftype = children[i] + if ftype.macroexp then + fmacros = fmacros or {} + table.insert(fmacros, ftype) + end if ftype.typename == "function" and ftype.is_method then local fargs = ftype.args.tuple if fargs[1] and fargs[1].is_self then @@ -11432,9 +11486,32 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end for name, _ in fields_of(typ, "meta") do local ftype = children[i] + if ftype.macroexp then + fmacros = fmacros or {} + table.insert(fmacros, ftype) + end typ.meta_fields[name] = ftype i = i + 1 end + + if typ.interface_list then + expand_interfaces(typ) + end + + if fmacros then + for _, t in ipairs(fmacros) do + local macroexp_type = recurse_node(t.macroexp, visit_node, visit_type) + + check_macroexp_arg_use(t.macroexp) + + if not is_a(macroexp_type, t) then + error_at(macroexp_type, "macroexp type does not match declaration") + end + end + end + + end_scope() + return typ end, }, From da234d5bc2df556e95269747291fa4954e13b176 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Thu, 28 Dec 2023 15:04:58 -0500 Subject: [PATCH 053/224] interfaces: subtype checking --- tl.lua | 47 +++++++++++++++++++++++++++++++---------------- tl.tl | 49 ++++++++++++++++++++++++++++++++----------------- 2 files changed, 63 insertions(+), 33 deletions(-) diff --git a/tl.lua b/tl.lua index dc19a7013..acfb320ca 100644 --- a/tl.lua +++ b/tl.lua @@ -7414,6 +7414,29 @@ tl.type_check = function(ast, opts) return true end + local function find_in_interface_list(a, f) + if not a.interface_list then + return nil + end + + for _, t in ipairs(a.interface_list) do + local ret = f(t) + if ret then + return ret + end + end + + return nil + end + + local function subtype_interface(a, b) + if find_in_interface_list(a, function(t) return (is_a(t, b)) end) then + return true + end + + return same_type(a, b) + end + local function subtype_record(a, b) if a.elements and b.elements then @@ -7712,10 +7735,15 @@ tl.type_check = function(ast, opts) local ra = resolve_nominal(a) local rb = resolve_nominal(b) - if ra.typename == "union" or rb.typename == "union" then + if rb.typename == "interface" then + + return is_a(a, rb) + elseif ra.typename == "union" or rb.typename == "union" then + return is_a(ra, rb) end + return are_same_nominals(a, b) end, ["*"] = subtype_nominal, @@ -7740,6 +7768,7 @@ tl.type_check = function(ast, opts) ["number"] = compare_true, }, ["interface"] = { + ["interface"] = subtype_interface, ["array"] = subtype_array, ["record"] = subtype_record, ["tupletable"] = function(a, b) @@ -7790,6 +7819,7 @@ a.types[i], b.types[i]), } }, ["record"] = { ["record"] = subtype_record, + ["interface"] = subtype_interface, ["array"] = subtype_array, ["map"] = function(a, b) if not is_a(b.keys, STRING) then @@ -9943,21 +9973,6 @@ a.types[i], b.types[i]), } return is_total, missing end - local function find_in_interface_list(a, f) - if not a.interface_list then - return nil - end - - for _, t in ipairs(a.interface_list) do - local ret = f(t) - if ret then - return ret - end - end - - return nil - end - diff --git a/tl.tl b/tl.tl index aea6ad02e..41e32212d 100644 --- a/tl.tl +++ b/tl.tl @@ -7414,6 +7414,29 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return true end + local function find_in_interface_list(a: Type, f: function(Type): T): T + if not a.interface_list then + return nil + end + + for _, t in ipairs(a.interface_list) do + local ret = f(t) + if ret then + return ret + end + end + + return nil + end + + local function subtype_interface(a: Type, b: Type): boolean, {Error} + if find_in_interface_list(a, function(t: Type): boolean return (is_a(t, b)) end) then + return true + end + + return same_type(a, b) + end + local function subtype_record(a: Type, b: Type): boolean, {Error} -- assert(b.typename == "record") if a.elements and b.elements then @@ -7711,10 +7734,15 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string ["nominal"] = function(a: Type, b: Type): boolean, {Error} local ra = resolve_nominal(a) local rb = resolve_nominal(b) - -- match unions structurally - if ra.typename == "union" or rb.typename == "union" then + + if rb.typename == "interface" then + -- match interface subtyping + return is_a(a, rb) + elseif ra.typename == "union" or rb.typename == "union" then + -- match unions structurally return is_a(ra, rb) end + -- all other types nominally return are_same_nominals(a, b) end, @@ -7740,6 +7768,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string ["number"] = compare_true, }, ["interface"] = { + ["interface"] = subtype_interface, ["array"] = subtype_array, ["record"] = subtype_record, ["tupletable"] = function(a: Type, b: Type): boolean, {Error} @@ -7790,6 +7819,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string }, ["record"] = { ["record"] = subtype_record, + ["interface"] = subtype_interface, ["array"] = subtype_array, ["map"] = function(a: Type, b: Type): boolean, {Error} if not is_a(b.keys, STRING) then @@ -9943,21 +9973,6 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return is_total, missing end - local function find_in_interface_list(a: Type, f: function(Type): T): T - if not a.interface_list then - return nil - end - - for _, t in ipairs(a.interface_list) do - local ret = f(t) - if ret then - return ret - end - end - - return nil - end - local enum MissingError "missing" end From f7ac7de3df241defa750b0c4ad5c30dd7cbd7b98 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Thu, 28 Dec 2023 15:59:03 -0500 Subject: [PATCH 054/224] interface constraints in type arguments --- tl.lua | 56 +++++++++++++++++++++++++++++++++++++++++++++------- tl.tl | 62 ++++++++++++++++++++++++++++++++++++++++++++++++---------- 2 files changed, 101 insertions(+), 17 deletions(-) diff --git a/tl.lua b/tl.lua index acfb320ca..fafccde90 100644 --- a/tl.lua +++ b/tl.lua @@ -1253,6 +1253,8 @@ local table_types = { + + @@ -1884,11 +1886,19 @@ local function parse_anglebracket_list(ps, i, parse_item) end local function parse_typearg(ps, i) + local name = ps.tokens[i].tk + local interface_constraint i = verify_kind(ps, i, "identifier") + if ps.tokens[i].tk == "is" then + i = i + 1 + interface_constraint = ps.tokens[i].tk + i = verify_kind(ps, i, "identifier") + end return i, a_type("typearg", { y = ps.tokens[i - 2].y, x = ps.tokens[i - 2].x, - typearg = ps.tokens[i - 1].tk, + typearg = name, + interface_name = interface_constraint, }) end @@ -6300,12 +6310,16 @@ tl.type_check = function(ast, opts) local function fresh_typevar(t) return a_type("typevar", { typevar = (t.typevar:gsub("@.*", "")) .. "@" .. fresh_typevar_ctr, + interface_name = t.interface_name, + interface_type = t.interface_type, }) end local function fresh_typearg(t) return a_type("typearg", { typearg = (t.typearg:gsub("@.*", "")) .. "@" .. fresh_typevar_ctr, + interface_name = t.interface_name, + interface_type = t.interface_type, }) end @@ -6326,7 +6340,7 @@ tl.type_check = function(ast, opts) if var then local t = var.t if t.typename == "unresolved_typearg" then - return nil + return nil, nil, t.interface_type end t = ensure_fresh_typeargs(t) return t, var.attribute @@ -6631,11 +6645,19 @@ tl.type_check = function(ast, opts) copy = fn_arg(t) else copy.typearg = t.typearg + copy.interface_name = t.interface_name + if t.interface_type then + copy.interface_type, same = resolve(t.interface_type, same) + end end elseif t.typename == "unresolvable_typearg" then copy.typearg = t.typearg elseif t.typename == "typevar" then copy.typevar = t.typevar + copy.interface_name = t.interface_name + if t.interface_type then + copy.interface_type, same = resolve(t.interface_type, same) + end elseif is_typetype(t) then copy.def, same = resolve(t.def, same) elseif t.typename == "nominal" then @@ -7530,13 +7552,22 @@ tl.type_check = function(ast, opts) - local vt = find_var_type(typevar) + local vt, _, interface_type = find_var_type(typevar) if vt then return cmp(a or vt, b or vt) else - local ok, r, errs = resolve_typevars(a or b) + local other = a or b + + + if interface_type then + if not is_a(other, interface_type) then + return false, { Err(other, "given type %s does not satisfy %s constraint in type variable " .. typevar, other, interface_type) } + end + end + + local ok, r, errs = resolve_typevars(other) if not ok then return false, errs end @@ -8353,7 +8384,9 @@ a.types[i], b.types[i]), } local function push_typeargs(func) if func.typeargs then for _, fnarg in ipairs(func.typeargs) do - add_var(nil, fnarg.typearg, a_type("unresolved_typearg", {})) + add_var(nil, fnarg.typearg, a_type("unresolved_typearg", { + interface_type = func.interface_name and find_var_type(func.interface_name), + })) end end end @@ -9039,7 +9072,10 @@ a.types[i], b.types[i]), } if t.def.typeargs then typevals = {} for _, a in ipairs(t.def.typeargs) do - table.insert(typevals, a_type("typevar", { typevar = a.typearg })) + table.insert(typevals, a_type("typevar", { + typevar = a.typearg, + interface_name = a.interface_name, + })) end end return type_at(where, a_type("nominal", { @@ -11532,7 +11568,11 @@ a.types[i], b.types[i]), } }, ["typearg"] = { after = function(typ, _children) - add_var(nil, typ.typearg, type_at(typ, a_type("typearg", { typearg = typ.typearg }))) + add_var(nil, typ.typearg, type_at(typ, a_type("typearg", { + typearg = typ.typearg, + interface_name = typ.interface_name, + interface_type = typ.interface_name and find_var_type(typ.interface_name), + }))) return typ end, }, @@ -11557,6 +11597,8 @@ a.types[i], b.types[i]), } edit_type(typ, "typevar") typ.names = nil typ.typevar = t.typearg + typ.interface_name = t.interface_name + typ.interface_type = t.interface_type else if t.is_alias then t = t.def.resolved diff --git a/tl.tl b/tl.tl index 41e32212d..e481fa90e 100644 --- a/tl.tl +++ b/tl.tl @@ -1141,6 +1141,8 @@ local record Type -- typearg typearg: string + interface_name: string + interface_type: Type -- table items kname: string @@ -1884,11 +1886,19 @@ local function parse_anglebracket_list(ps: ParseState, i: integer, parse_item: P end local function parse_typearg(ps: ParseState, i: integer): integer, Type, integer + local name = ps.tokens[i].tk + local interface_constraint: string i = verify_kind(ps, i, "identifier") + if ps.tokens[i].tk == "is" then + i = i + 1 + interface_constraint = ps.tokens[i].tk + i = verify_kind(ps, i, "identifier") -- FIXME generic interfaces... + end return i, a_type("typearg", { y = ps.tokens[i - 2].y, x = ps.tokens[i - 2].x, - typearg = ps.tokens[i-1].tk, + typearg = name, + interface_name = interface_constraint, }) end @@ -6299,13 +6309,17 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local function fresh_typevar(t: Type): Type, Type, boolean return a_type("typevar", { - typevar = (t.typevar:gsub("@.*", "")) .. "@" .. fresh_typevar_ctr + typevar = (t.typevar:gsub("@.*", "")) .. "@" .. fresh_typevar_ctr, + interface_name = t.interface_name, + interface_type = t.interface_type, }) end local function fresh_typearg(t: Type): Type return a_type("typearg", { - typearg = (t.typearg:gsub("@.*", "")) .. "@" .. fresh_typevar_ctr + typearg = (t.typearg:gsub("@.*", "")) .. "@" .. fresh_typevar_ctr, + interface_name = t.interface_name, + interface_type = t.interface_type, }) end @@ -6321,12 +6335,12 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return t end - local function find_var_type(name: string, use?: VarUse): Type, Attribute + local function find_var_type(name: string, use?: VarUse): Type, Attribute, Type local var = find_var(name, use) if var then local t = var.t if t.typename == "unresolved_typearg" then - return nil + return nil, nil, t.interface_type end t = ensure_fresh_typeargs(t) return t, var.attribute @@ -6631,11 +6645,19 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string copy = fn_arg(t) else copy.typearg = t.typearg + copy.interface_name = t.interface_name + if t.interface_type then + copy.interface_type, same = resolve(t.interface_type, same) + end end elseif t.typename == "unresolvable_typearg" then copy.typearg = t.typearg elseif t.typename == "typevar" then copy.typevar = t.typevar + copy.interface_name = t.interface_name + if t.interface_type then + copy.interface_type, same = resolve(t.interface_type, same) + end elseif is_typetype(t) then copy.def, same = resolve(t.def, same) elseif t.typename == "nominal" then @@ -7530,13 +7552,22 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string -- assert((a == nil and b ~= nil) or (a ~= nil and b == nil)) -- does the typevar currently match to a type? - local vt = find_var_type(typevar) + local vt, _, interface_type = find_var_type(typevar) if vt then -- If so, compare it to the other type return cmp(a or vt, b or vt) else -- otherwise, infer it to the other type - local ok, r, errs = resolve_typevars(a or b) + local other = a or b + + -- but check interface constraint first if present + if interface_type then + if not is_a(other, interface_type) then + return false, { Err(other, "given type %s does not satisfy %s constraint in type variable " .. typevar, other, interface_type) } + end + end + + local ok, r, errs = resolve_typevars(other) if not ok then return false, errs end @@ -8353,7 +8384,9 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local function push_typeargs(func: Type) if func.typeargs then for _, fnarg in ipairs(func.typeargs) do - add_var(nil, fnarg.typearg, a_type("unresolved_typearg", {})) + add_var(nil, fnarg.typearg, a_type("unresolved_typearg", { + interface_type = func.interface_name and find_var_type(func.interface_name), + })) end end end @@ -9039,7 +9072,10 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if t.def.typeargs then typevals = {} for _, a in ipairs(t.def.typeargs) do - table.insert(typevals, a_type("typevar", { typevar = a.typearg })) + table.insert(typevals, a_type("typevar", { + typevar = a.typearg, + interface_name = a.interface_name, + })) end end return type_at(where, a_type("nominal", { @@ -11532,7 +11568,11 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string }, ["typearg"] = { after = function(typ: Type, _children: {Type}): Type - add_var(nil, typ.typearg, type_at(typ, a_type("typearg", { typearg = typ.typearg }))) + add_var(nil, typ.typearg, type_at(typ, a_type("typearg", { + typearg = typ.typearg, + interface_name = typ.interface_name, + interface_type = typ.interface_name and find_var_type(typ.interface_name), + }))) return typ end, }, @@ -11557,6 +11597,8 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string edit_type(typ, "typevar") typ.names = nil typ.typevar = t.typearg + typ.interface_name = t.interface_name + typ.interface_type = t.interface_type else if t.is_alias then t = t.def.resolved From f5ea5e5ffb04fb20acb8edddc12eea06afc4fd78 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Thu, 28 Dec 2023 16:04:58 -0500 Subject: [PATCH 055/224] match record keys in constrained type variables --- tl.lua | 15 ++++++++++++--- tl.tl | 15 ++++++++++++--- 2 files changed, 24 insertions(+), 6 deletions(-) diff --git a/tl.lua b/tl.lua index fafccde90..f38c2e86e 100644 --- a/tl.lua +++ b/tl.lua @@ -8577,9 +8577,7 @@ a.types[i], b.types[i]), } end end - local match_record_key - - match_record_key = function(tbl, rec, key) + local function match_record_key(tbl, rec, key) assert(type(tbl) == "table") assert(type(rec) == "table") assert(type(key) == "string") @@ -8606,6 +8604,14 @@ a.types[i], b.types[i]), } end end + if tbl.typename == "typevar" and tbl.interface_type then + local t = match_record_key(tbl.interface_type, rec, key) + + if t then + return t + end + end + if is_record_type(tbl) then assert(tbl.fields, "record has no fields!?") @@ -11581,6 +11587,9 @@ a.types[i], b.types[i]), } if not find_var_type(typ.typevar) then error_at(typ, "undefined type variable " .. typ.typevar) end + if typ.interface_name then + typ.interface_type = find_var_type(typ.interface_name) + end return typ end, }, diff --git a/tl.tl b/tl.tl index e481fa90e..f582d5296 100644 --- a/tl.tl +++ b/tl.tl @@ -8577,9 +8577,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - local match_record_key: function(tbl: Type, rec: Node, key: string): Type, string - - match_record_key = function(tbl: Type, rec: Node, key: string): Type, string + local function match_record_key(tbl: Type, rec: Node, key: string): Type, string assert(type(tbl) == "table") assert(type(rec) == "table") assert(type(key) == "string") @@ -8606,6 +8604,14 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end + if tbl.typename == "typevar" and tbl.interface_type then + local t = match_record_key(tbl.interface_type, rec, key) + + if t then + return t + end + end + if is_record_type(tbl) then assert(tbl.fields, "record has no fields!?") @@ -11581,6 +11587,9 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if not find_var_type(typ.typevar) then error_at(typ, "undefined type variable " .. typ.typevar) end + if typ.interface_name then + typ.interface_type = find_var_type(typ.interface_name) + end return typ end, }, From 179e669d996aac47f0005da16c3f1e7239bbeeec Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Thu, 28 Dec 2023 18:43:54 -0500 Subject: [PATCH 056/224] arg_check: explicit variance rules --- tl.lua | 57 ++++++++++++++++++++++++++++++++++++++++++++------------- tl.tl | 57 ++++++++++++++++++++++++++++++++++++++++++++------------- 2 files changed, 88 insertions(+), 26 deletions(-) diff --git a/tl.lua b/tl.lua index f38c2e86e..e3eae1517 100644 --- a/tl.lua +++ b/tl.lua @@ -6986,10 +6986,41 @@ tl.type_check = function(ast, opts) local same_type local is_a - local function arg_check(where, cmp, a, b, n, errs, ctx) - local matches, match_errs = cmp(a, b) - if not matches then - add_errs_prefixing(where, match_errs, errs, ctx .. (n and " " .. n or "") .. ": ") + + + + + + + + + + + + + + local function arg_check(where, all_errs, a, b, v, mode, n) + local ok, errs + + if v == "covariant" then + ok, errs = is_a(a, b) + elseif v == "contravariant" then + ok, errs = is_a(b, a) + elseif v == "bivariant" then + ok, errs = is_a(a, b) + if ok then + return true + end + ok = is_a(b, a) + if ok then + return true + end + elseif v == "invariant" then + ok, errs = same_type(a, b) + end + + if not ok then + add_errs_prefixing(where, errs, all_errs, mode .. (n and " " .. n or "") .. ": ") return false end return true @@ -7672,10 +7703,10 @@ tl.type_check = function(ast, opts) end local errs = {} for i = 1, naargs do - arg_check(a, same_type, a.args.tuple[i], b.args.tuple[i], i - argdelta, errs, "argument") + arg_check(a, errs, a.args.tuple[i], b.args.tuple[i], "invariant", "argument", i - argdelta) end for i = 1, narets do - arg_check(a, same_type, a.rets.tuple[i], b.rets.tuple[i], i, errs, "return") + arg_check(a, errs, a.rets.tuple[i], b.rets.tuple[i], "invariant", "return", i) end return any_errors(errs) end, @@ -7924,7 +7955,7 @@ a.types[i], b.types[i]), } table.insert(errs, Err(a, "incompatible number of arguments: got " .. show_arity(a) .. " %s, expected " .. show_arity(b) .. " %s", a.args, b.args)) else for i = ((a.is_method or b.is_method) and 2 or 1), #aa do - arg_check(nil, is_a, aa[i], ba[i] or ANY, i, errs, "argument") + arg_check(nil, errs, aa[i], ba[i] or ba[#ba], "bivariant", "argument", i) end end @@ -7938,7 +7969,7 @@ a.types[i], b.types[i]), } nrets = nrets - 1 end for i = 1, nrets do - arg_check(nil, is_a, ar[i], br[i], i, errs, "return") + arg_check(nil, errs, ar[i], br[i], "bivariant", "return", i) end end @@ -8318,7 +8349,7 @@ a.types[i], b.types[i]), } local check_args_rets do - local function check_func_type_list(where, wheres, xs, ys, from, delta, mode) + local function check_func_type_list(where, wheres, xs, ys, from, delta, v, mode) assert(xs.typename == "tuple", xs.typename) assert(ys.typename == "tuple", ys.typename) @@ -8333,7 +8364,7 @@ a.types[i], b.types[i]), } local y = yt[i] or (ys.is_va and yt[n_ys]) if y then local w = wheres and wheres[pos] or where - if not arg_check(w, is_a, x, y, pos, errs, mode) then + if not arg_check(w, errs, x, y, v, mode, pos) then return nil, errs end end @@ -8353,7 +8384,7 @@ a.types[i], b.types[i]), } if argdelta == -1 then from = 2 local errs = {} - if (not is_self(fargs[1])) and not arg_check(where, is_a, args.tuple[1], fargs[1], nil, errs, "self") then + if (not is_self(fargs[1])) and not arg_check(where, errs, fargs[1], args.tuple[1], "contravariant", "self") then return nil, errs end end @@ -8362,10 +8393,10 @@ a.types[i], b.types[i]), } expected_rets = infer_at(where, expected_rets) infer_emptytables(where, nil, expected_rets, f.rets, 0) - rets_ok, rets_errs = check_func_type_list(where, nil, f.rets, expected_rets, 1, 0, "return") + rets_ok, rets_errs = check_func_type_list(where, nil, f.rets, expected_rets, 1, 0, "covariant", "return") end - args_ok, args_errs = check_func_type_list(where, where_args, args, f.args, from, argdelta, "argument") + args_ok, args_errs = check_func_type_list(where, where_args, f.args, args, from, argdelta, "contravariant", "argument") if (not args_ok) or (not rets_ok) then return nil, args_errs or {} end diff --git a/tl.tl b/tl.tl index f582d5296..56d904e3f 100644 --- a/tl.tl +++ b/tl.tl @@ -6986,10 +6986,41 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local same_type: function(t1: Type, t2: Type): boolean, {Error} local is_a: function(Type, Type): boolean, {Error} - local function arg_check(where: Where, cmp: CompareTypes, a: Type, b: Type, n: integer, errs: {Error}, ctx: string): boolean - local matches, match_errs = cmp(a, b) - if not matches then - add_errs_prefixing(where, match_errs, errs, ctx .. (n and " " .. n or "") .. ": ") + local enum ArgCheckMode + "argument" + "return" + "self" + end + + local enum VarianceMode + "covariant" + "contravariant" + "bivariant" + "invariant" + end + + local function arg_check(where: Where, all_errs: {Error}, a: Type, b: Type, v: VarianceMode, mode: ArgCheckMode, n?: integer): boolean + local ok, errs: boolean, {Error} + + if v == "covariant" then + ok, errs = is_a(a, b) + elseif v == "contravariant" then + ok, errs = is_a(b, a) + elseif v == "bivariant" then + ok, errs = is_a(a, b) + if ok then + return true + end + ok = is_a(b, a) + if ok then + return true + end + elseif v == "invariant" then + ok, errs = same_type(a, b) + end + + if not ok then + add_errs_prefixing(where, errs, all_errs, mode .. (n and " " .. n or "") .. ": ") return false end return true @@ -7672,10 +7703,10 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end local errs = {} for i = 1, naargs do - arg_check(a, same_type, a.args.tuple[i], b.args.tuple[i], i - argdelta, errs, "argument") + arg_check(a, errs, a.args.tuple[i], b.args.tuple[i], "invariant", "argument", i - argdelta) end for i = 1, narets do - arg_check(a, same_type, a.rets.tuple[i], b.rets.tuple[i], i, errs, "return") + arg_check(a, errs, a.rets.tuple[i], b.rets.tuple[i], "invariant", "return", i) end return any_errors(errs) end, @@ -7924,7 +7955,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string table.insert(errs, Err(a, "incompatible number of arguments: got " .. show_arity(a) .. " %s, expected " .. show_arity(b) .. " %s", a.args, b.args)) else for i = ((a.is_method or b.is_method) and 2 or 1), #aa do - arg_check(nil, is_a, aa[i], ba[i] or ANY, i, errs, "argument") + arg_check(nil, errs, aa[i], ba[i] or ba[#ba], "bivariant", "argument", i) end end @@ -7938,7 +7969,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string nrets = nrets - 1 end for i = 1, nrets do - arg_check(nil, is_a, ar[i], br[i], i, errs, "return") + arg_check(nil, errs, ar[i], br[i], "bivariant", "return", i) end end @@ -8318,7 +8349,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local check_args_rets: function(where: Where, where_args: {Node}, f: Type, args: TupleType, expected_rets: TupleType, argdelta: integer): Type, {Error} do -- check if a tuple `xs` matches tuple `ys` - local function check_func_type_list(where: Where, wheres: {Where}, xs: TupleType, ys: TupleType, from: integer, delta: integer, mode: string): boolean, {Error} + local function check_func_type_list(where: Where, wheres: {Where}, xs: TupleType, ys: TupleType, from: integer, delta: integer, v: VarianceMode, mode: ArgCheckMode): boolean, {Error} assert(xs.typename == "tuple", xs.typename) assert(ys.typename == "tuple", ys.typename) @@ -8333,7 +8364,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local y = yt[i] or (ys.is_va and yt[n_ys]) if y then local w = wheres and wheres[pos] or where - if not arg_check(w, is_a, x, y, pos, errs, mode) then + if not arg_check(w, errs, x, y, v, mode, pos) then return nil, errs end end @@ -8353,7 +8384,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if argdelta == -1 then from = 2 local errs = {} - if (not is_self(fargs[1])) and not arg_check(where, is_a, args.tuple[1], fargs[1], nil, errs, "self") then + if (not is_self(fargs[1])) and not arg_check(where, errs, fargs[1], args.tuple[1], "contravariant", "self") then return nil, errs end end @@ -8362,10 +8393,10 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string expected_rets = infer_at(where, expected_rets) infer_emptytables(where, nil, expected_rets, f.rets, 0) - rets_ok, rets_errs = check_func_type_list(where, nil, f.rets, expected_rets, 1, 0, "return") + rets_ok, rets_errs = check_func_type_list(where, nil, f.rets, expected_rets, 1, 0, "covariant", "return") end - args_ok, args_errs = check_func_type_list(where, where_args, args, f.args, from, argdelta, "argument") + args_ok, args_errs = check_func_type_list(where, where_args, f.args, args, from, argdelta, "contravariant", "argument") if (not args_ok) or (not rets_ok) then return nil, args_errs or {} end From 9e996701ea1bb2ef4d42601de4c5b19644b8203b Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Thu, 28 Dec 2023 18:50:30 -0500 Subject: [PATCH 057/224] wip subtype interface rework --- spec/declaration/record_spec.lua | 14 ++------------ tl.lua | 25 +++++++++++++++---------- tl.tl | 25 +++++++++++++++---------- 3 files changed, 32 insertions(+), 32 deletions(-) diff --git a/spec/declaration/record_spec.lua b/spec/declaration/record_spec.lua index 60787d933..7dc5b018d 100644 --- a/spec/declaration/record_spec.lua +++ b/spec/declaration/record_spec.lua @@ -804,19 +804,9 @@ for i, name in ipairs({"records", "arrayrecords", "interfaces", "arrayinterfaces f(bar) ]], { { y = 5, msg = "in local declaration: foo: got {}, expected Foo" }, - select(i, - { y = 6, msg = "in assignment: record is not a userdata" }, - { y = 6, msg = "in assignment: record is not a userdata" }, - { y = 6, msg = "in assignment: got record (a: integer), expected Foo" }, - { y = 6, msg = "in assignment: got record (a: integer), expected Foo" } - ), + { y = 6, msg = "in assignment: record is not a userdata" }, { y = 8, msg = "argument 1: got {}, expected Foo" }, - select(i, - { y = 9, msg = "argument 1: record is not a userdata" }, - { y = 9, msg = "argument 1: record is not a userdata" }, - { y = 9, msg = "argument 1: got record (a: integer), expected Foo" }, - { y = 9, msg = "argument 1: got record (a: integer), expected Foo" } - ), + { y = 9, msg = "argument 1: record is not a userdata" }, nil })) diff --git a/tl.lua b/tl.lua index e3eae1517..7071d1a0f 100644 --- a/tl.lua +++ b/tl.lua @@ -7482,14 +7482,6 @@ tl.type_check = function(ast, opts) return nil end - local function subtype_interface(a, b) - if find_in_interface_list(a, function(t) return (is_a(t, b)) end) then - return true - end - - return same_type(a, b) - end - local function subtype_record(a, b) if a.elements and b.elements then @@ -7830,7 +7822,12 @@ tl.type_check = function(ast, opts) ["number"] = compare_true, }, ["interface"] = { - ["interface"] = subtype_interface, + ["interface"] = function(a, b) + if find_in_interface_list(a, function(t) return (is_a(t, b)) end) then + return true + end + return same_type(a, b) + end, ["array"] = subtype_array, ["record"] = subtype_record, ["tupletable"] = function(a, b) @@ -7881,7 +7878,15 @@ a.types[i], b.types[i]), } }, ["record"] = { ["record"] = subtype_record, - ["interface"] = subtype_interface, + ["interface"] = function(a, b) + if find_in_interface_list(a, function(t) return (is_a(t, b)) end) then + return true + end + if not a.names then + + return subtype_record(a, b) + end + end, ["array"] = subtype_array, ["map"] = function(a, b) if not is_a(b.keys, STRING) then diff --git a/tl.tl b/tl.tl index 56d904e3f..19adcc670 100644 --- a/tl.tl +++ b/tl.tl @@ -7482,14 +7482,6 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return nil end - local function subtype_interface(a: Type, b: Type): boolean, {Error} - if find_in_interface_list(a, function(t: Type): boolean return (is_a(t, b)) end) then - return true - end - - return same_type(a, b) - end - local function subtype_record(a: Type, b: Type): boolean, {Error} -- assert(b.typename == "record") if a.elements and b.elements then @@ -7830,7 +7822,12 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string ["number"] = compare_true, }, ["interface"] = { - ["interface"] = subtype_interface, + ["interface"] = function(a: Type, b: Type): boolean, {Error} + if find_in_interface_list(a, function(t: Type): boolean return (is_a(t, b)) end) then + return true + end + return same_type(a, b) + end, ["array"] = subtype_array, ["record"] = subtype_record, ["tupletable"] = function(a: Type, b: Type): boolean, {Error} @@ -7881,7 +7878,15 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string }, ["record"] = { ["record"] = subtype_record, - ["interface"] = subtype_interface, + ["interface"] = function(a: Type, b: Type): boolean, {Error} + if find_in_interface_list(a, function(t: Type): boolean return (is_a(t, b)) end) then + return true + end + if not a.names then + -- match inferred table (anonymous record) structurally to interface + return subtype_record(a, b) + end + end, ["array"] = subtype_array, ["map"] = function(a: Type, b: Type): boolean, {Error} if not is_a(b.keys, STRING) then From a35728216d604489afdbb29a94b440d51a83b100 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Fri, 29 Dec 2023 11:05:12 -0500 Subject: [PATCH 058/224] refactor: display_typevar --- tl.lua | 12 ++++++++---- tl.tl | 12 ++++++++---- 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/tl.lua b/tl.lua index 7071d1a0f..32c5f7769 100644 --- a/tl.lua +++ b/tl.lua @@ -5266,6 +5266,10 @@ local function is_unknown(t) t.typename == "unresolved_emptytable_value" end +local function display_typevar(typevar) + return TL_DEBUG and typevar or (typevar:gsub("@.*", "")) +end + local function show_type_base(t, short, seen) if seen[t] then @@ -5413,11 +5417,11 @@ local function show_type_base(t, short, seen) (t.tk and " " .. t.tk or "") end elseif t.typename == "typevar" then - return TL_DEBUG and t.typevar or (t.typevar:gsub("@.*", "")) + return display_typevar(t.typevar) elseif t.typename == "typearg" then - return TL_DEBUG and t.typearg or (t.typearg:gsub("@.*", "")) + return display_typevar(t.typearg) elseif t.typename == "unresolvable_typearg" then - return (TL_DEBUG and t.typearg or (t.typearg:gsub("@.*", ""))) .. " (unresolved generic)" + return display_typevar(t.typearg) .. " (unresolved generic)" elseif is_unknown(t) then return "" elseif t.typename == "invalid" then @@ -7586,7 +7590,7 @@ tl.type_check = function(ast, opts) if interface_type then if not is_a(other, interface_type) then - return false, { Err(other, "given type %s does not satisfy %s constraint in type variable " .. typevar, other, interface_type) } + return false, { Err(other, "given type %s does not satisfy %s constraint in type variable " .. display_typevar(typevar), other, interface_type) } end end diff --git a/tl.tl b/tl.tl index 19adcc670..61a9eb4e0 100644 --- a/tl.tl +++ b/tl.tl @@ -5266,6 +5266,10 @@ local function is_unknown(t: Type): boolean or t.typename == "unresolved_emptytable_value" end +local function display_typevar(typevar: string): string + return TL_DEBUG and typevar or (typevar:gsub("@.*", "")) +end + local function show_type_base(t: Type, short: boolean, seen: {Type:string}): string -- FIXME this is a control for recursively built types, which should in principle not exist if seen[t] then @@ -5413,11 +5417,11 @@ local function show_type_base(t: Type, short: boolean, seen: {Type:string}): str (t.tk and " " .. t.tk or "") end elseif t.typename == "typevar" then - return TL_DEBUG and t.typevar or (t.typevar:gsub("@.*", "")) + return display_typevar(t.typevar) elseif t.typename == "typearg" then - return TL_DEBUG and t.typearg or (t.typearg:gsub("@.*", "")) + return display_typevar(t.typearg) elseif t.typename == "unresolvable_typearg" then - return (TL_DEBUG and t.typearg or (t.typearg:gsub("@.*", ""))) .. " (unresolved generic)" + return display_typevar(t.typearg) .. " (unresolved generic)" elseif is_unknown(t) then return "" elseif t.typename == "invalid" then @@ -7586,7 +7590,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string -- but check interface constraint first if present if interface_type then if not is_a(other, interface_type) then - return false, { Err(other, "given type %s does not satisfy %s constraint in type variable " .. typevar, other, interface_type) } + return false, { Err(other, "given type %s does not satisfy %s constraint in type variable " .. display_typevar(typevar), other, interface_type) } end end From 2dbdefab29da950be51e9d5933830521bdf69695 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Fri, 29 Dec 2023 12:35:53 -0500 Subject: [PATCH 059/224] resolve_interface_type --- tl.lua | 30 +++++++++++++++++++++++++----- tl.tl | 30 +++++++++++++++++++++++++----- 2 files changed, 50 insertions(+), 10 deletions(-) diff --git a/tl.lua b/tl.lua index 32c5f7769..0ba8c80dc 100644 --- a/tl.lua +++ b/tl.lua @@ -8321,6 +8321,13 @@ a.types[i], b.types[i]), } orignode.known = saveknown end + local function resolve_interface_type(interface_name) + if not interface_name then + return nil + end + return resolve_typetype((find_var_type(interface_name, "use_type"))) + end + local type_check_function_call do local function mark_invalid_typeargs(f) @@ -8425,7 +8432,8 @@ a.types[i], b.types[i]), } if func.typeargs then for _, fnarg in ipairs(func.typeargs) do add_var(nil, fnarg.typearg, a_type("unresolved_typearg", { - interface_type = func.interface_name and find_var_type(func.interface_name), + interface_name = fnarg.interface_name, + interface_type = resolve_interface_type(fnarg.interface_name), })) end end @@ -10520,6 +10528,11 @@ a.types[i], b.types[i]), } before = function(node) if node.expected then local decltype = resolve_tuple_and_nominal(node.expected) + + if decltype.typename == "typevar" and decltype.interface_type then + decltype = decltype.interface_type + end + if decltype.typename == "tupletable" then for _, child in ipairs(node) do local n = child.key.constnum @@ -10557,6 +10570,11 @@ a.types[i], b.types[i]), } end local decltype = resolve_tuple_and_nominal(node.expected) + local interface_type = decltype.typename == "typevar" and decltype.interface_type + + if interface_type then + decltype = interface_type + end if decltype.typename == "union" then local single_table_type @@ -10670,6 +10688,10 @@ a.types[i], b.types[i]), } t.is_total, t.missing = total_map_check(decltype, seen_keys) end + if interface_type then + return interface_type + end + return t end, }, @@ -11617,7 +11639,7 @@ a.types[i], b.types[i]), } add_var(nil, typ.typearg, type_at(typ, a_type("typearg", { typearg = typ.typearg, interface_name = typ.interface_name, - interface_type = typ.interface_name and find_var_type(typ.interface_name), + interface_type = resolve_interface_type(typ.interface_name), }))) return typ end, @@ -11627,9 +11649,7 @@ a.types[i], b.types[i]), } if not find_var_type(typ.typevar) then error_at(typ, "undefined type variable " .. typ.typevar) end - if typ.interface_name then - typ.interface_type = find_var_type(typ.interface_name) - end + typ.interface_type = resolve_interface_type(typ.interface_name) return typ end, }, diff --git a/tl.tl b/tl.tl index 61a9eb4e0..bfaf3f9f4 100644 --- a/tl.tl +++ b/tl.tl @@ -8321,6 +8321,13 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string orignode.known = saveknown end + local function resolve_interface_type(interface_name: string): Type + if not interface_name then + return nil + end + return resolve_typetype((find_var_type(interface_name, "use_type"))) + end + local type_check_function_call: function(Node, {Node}, Type, TupleType, Node, boolean, ? integer): Type, Type do local function mark_invalid_typeargs(f: Type) @@ -8425,7 +8432,8 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if func.typeargs then for _, fnarg in ipairs(func.typeargs) do add_var(nil, fnarg.typearg, a_type("unresolved_typearg", { - interface_type = func.interface_name and find_var_type(func.interface_name), + interface_name = fnarg.interface_name, + interface_type = resolve_interface_type(fnarg.interface_name), })) end end @@ -10520,6 +10528,11 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string before = function(node: Node) if node.expected then local decltype = resolve_tuple_and_nominal(node.expected) + + if decltype.typename == "typevar" and decltype.interface_type then + decltype = decltype.interface_type + end + if decltype.typename == "tupletable" then for _, child in ipairs(node) do local n = child.key.constnum @@ -10557,6 +10570,11 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end local decltype = resolve_tuple_and_nominal(node.expected) + local interface_type = decltype.typename == "typevar" and decltype.interface_type + + if interface_type then + decltype = interface_type + end if decltype.typename == "union" then local single_table_type: Type @@ -10670,6 +10688,10 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string t.is_total, t.missing = total_map_check(decltype, seen_keys) end + if interface_type then + return interface_type + end + return t end, }, @@ -11617,7 +11639,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string add_var(nil, typ.typearg, type_at(typ, a_type("typearg", { typearg = typ.typearg, interface_name = typ.interface_name, - interface_type = typ.interface_name and find_var_type(typ.interface_name), + interface_type = resolve_interface_type(typ.interface_name), }))) return typ end, @@ -11627,9 +11649,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if not find_var_type(typ.typevar) then error_at(typ, "undefined type variable " .. typ.typevar) end - if typ.interface_name then - typ.interface_type = find_var_type(typ.interface_name) - end + typ.interface_type = resolve_interface_type(typ.interface_name) return typ end, }, From 5c0f475791c5befd107d4b66bba70a8a009eb2f3 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Fri, 29 Dec 2023 12:54:10 -0500 Subject: [PATCH 060/224] EnumType: begin use interface subtyping! --- tl.lua | 13 ++++++++----- tl.tl | 51 +++++++++++++++++++++++++++------------------------ 2 files changed, 35 insertions(+), 29 deletions(-) diff --git a/tl.lua b/tl.lua index 0ba8c80dc..3d3f00026 100644 --- a/tl.lua +++ b/tl.lua @@ -1256,6 +1256,7 @@ local table_types = { + local TruthyFact = {} @@ -7898,7 +7899,8 @@ a.types[i], b.types[i]), } end for _, k in ipairs(a.field_order) do - if b.keys.typename == "enum" and not b.keys.enumset[k] then + local bk = b.keys + if bk.typename == "enum" and not bk.enumset[k] then return false, { Err(a, "key is not an enum value: " .. k) } end if not is_a(a.fields[k], b.values) then @@ -11438,11 +11440,12 @@ a.types[i], b.types[i]), } visit_node.cbs["string"] = { after = function(node, _children) local t = after_literal(node) - if node.expected then - if node.expected.typename == "enum" and is_a(t, node.expected) then - t = node.expected - end + + local expected = node.expected + if expected and expected.typename == "enum" and is_a(t, expected) then + return node.expected end + return t end, } diff --git a/tl.tl b/tl.tl index bfaf3f9f4..023bdf3f4 100644 --- a/tl.tl +++ b/tl.tl @@ -1067,7 +1067,7 @@ local table_types : {TypeName:boolean} = { ["*"] = false, } -local record Type +local interface Type where self.typename ~= nil y: integer @@ -1155,9 +1155,6 @@ local record Type inferred_at: Where emptytable_type: Type - -- enum - enumset: {string:boolean} - -- macroexp macroexp: Node @@ -1168,6 +1165,10 @@ local record Type narrows: {string:boolean} end +local record EnumType is Type where self.typename == "enum" + enumset: {string:boolean} +end + local type TupleType = Type local record Operator @@ -1498,7 +1499,7 @@ local parse_type: function(ParseState, integer): integer, Type, integer local parse_newtype: function(ps: ParseState, i: integer): integer, Node local type ParseBody = function(ps: ParseState, i: integer, def: Type, node: Node): integer, Node -local parse_enum_body: function(ps: ParseState, i: integer, def: Type, node: Node): integer, Node +local parse_enum_body: function(ps: ParseState, i: integer, def: EnumType, node: Node): integer, Node local parse_record_body: function(ps: ParseState, i: integer, def: Type, node: Node): integer, Node local parse_type_body_fns: {TypeName:ParseBody} @@ -1549,7 +1550,7 @@ local function new_node(tokens: {Token}, i: integer, kind?: NodeKind): Node return { y = t.y, x = t.x, tk = t.tk, kind = kind or (t.kind as NodeKind) } end -local function a_type(typename: TypeName, t: Type): Type +local function a_type(typename: TypeName, t: T): T t.typeid = new_typeid() t.typename = typename return t @@ -2936,7 +2937,7 @@ local function parse_nested_type(ps: ParseState, i: integer, def: Type, typename return i end -parse_enum_body = function(ps: ParseState, i: integer, def: Type, node: Node): integer, Node +parse_enum_body = function(ps: ParseState, i: integer, def: EnumType, node: Node): integer, Node local istart = i - 1 def.enumset = {} while ps.tokens[i].tk ~= "$EOF$" and ps.tokens[i].tk ~= "end" do @@ -4972,7 +4973,7 @@ get_typenum = function(trenv:TypeReportEnv, t: Type): integer if rt.typename == "map" then ti.keys = get_typenum(trenv, rt.keys) ti.values = get_typenum(trenv, rt.values) - elseif rt.typename == "enum" then + elseif rt is EnumType then ti.enums = mark_array(sorted_keys(rt.enumset)) elseif rt.typename == "function" then store_function(trenv, ti, rt) @@ -5710,13 +5711,13 @@ local function init_globals(lax: boolean): {string:Variable}, {string:Type} end local function a_grecord(n: integer, f: function(...: Type): Type): Type - local t = a_gfunction(n, f, "record") + local t = a_gfunction(n, f, "record") as Type -- FIXME t.field_order = sorted_keys(t.fields) return t end - local function an_enum(keys: {string}): Type - local t = a_type("enum", { enumset = {} }) + local function an_enum(keys: {string}): EnumType + local t = a_type("enum", { enumset = {} } as EnumType) for _, k in ipairs(keys) do t.enumset[k] = true end @@ -7810,7 +7811,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string ["string"] = compare_true, }, ["string"] = { - ["enum"] = function(a: Type, b: Type): boolean, {Error} + ["enum"] = function(a: Type, b: EnumType): boolean, {Error} if not a.tk then return false, { Err(a, "string is not a %s", b) } end @@ -7898,7 +7899,8 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end for _, k in ipairs(a.field_order) do - if b.keys.typename == "enum" and not b.keys.enumset[k] then + local bk = b.keys + if bk is EnumType and not bk.enumset[k] then return false, { Err(a, "key is not an enum value: " .. k) } end if not is_a(a.fields[k], b.values) then @@ -8632,7 +8634,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string tbl = resolve_tuple_and_nominal(tbl) - if tbl.typename == "string" or tbl.typename == "enum" then + if tbl.typename == "string" or tbl is EnumType then tbl = find_var_type("string") -- simulate string metatable end @@ -9017,7 +9019,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string errm, erra = e, orig_a elseif is_record_type(a) then - if b.typename == "enum" then + if b is EnumType then local field_names: {string} = sorted_keys(b.enumset) for _, k in ipairs(field_names) do if not a.fields[k] then @@ -10049,7 +10051,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local k = resolve_tuple_and_nominal(t.keys) local is_total = true local missing: {string} - if k.typename == "enum" then + if k is EnumType then for _, key in ipairs(sorted_keys(k.enumset)) do is_total, missing = total_check_key(key, seen_keys, is_total, missing) end @@ -11153,10 +11155,10 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string node.known = nil t = a - elseif ((ra.typename == "enum" and rb.typename == "string" and is_a(rb, ra)) - or (ra.typename == "string" and rb.typename == "enum" and is_a(ra, rb))) then + elseif ((ra is EnumType and rb.typename == "string" and is_a(rb, ra)) + or (ra.typename == "string" and rb is EnumType and is_a(ra, rb))) then node.known = nil - t = (ra.typename == "enum" and ra or rb) + t = (ra is EnumType and ra or rb) elseif expected and expected.typename == "union" then -- must be checked after string/enum above @@ -11196,7 +11198,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string -- check_metamethod(node, binop_to_metamethod[node.op.op], ra, rb) -- end - if ra.typename == "enum" and rb.typename == "string" then + if ra is EnumType and rb.typename == "string" then if not (rb.tk and ra.enumset[unquote(rb.tk)]) then return invalid_at(node, "%s is not a member of %s", b, a) end @@ -11438,11 +11440,12 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string visit_node.cbs["string"] = { after = function(node: Node, _children: {Type}): Type local t = after_literal(node) - if node.expected then - if node.expected.typename == "enum" and is_a(t, node.expected) then - t = node.expected - end + + local expected = node.expected + if expected and expected is EnumType and is_a(t, expected) then + return node.expected end + return t end, } From 7d9434aae4f9b8df96f9ac15865bacc71332dd21 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Fri, 29 Dec 2023 14:23:43 -0500 Subject: [PATCH 061/224] narrow interfaces on assignment --- tl.lua | 2 +- tl.tl | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tl.lua b/tl.lua index 3d3f00026..2f10b990b 100644 --- a/tl.lua +++ b/tl.lua @@ -10275,7 +10275,7 @@ a.types[i], b.types[i]), } widen_all_unions() end - if varname and rvar.typename == "union" then + if varname and (rvar.typename == "union" or rvar.typename == "interface") then add_var(varnode, varname, rval, nil, "narrow") end diff --git a/tl.tl b/tl.tl index 023bdf3f4..cf1143df2 100644 --- a/tl.tl +++ b/tl.tl @@ -10275,8 +10275,8 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string widen_all_unions() end - if varname and rvar.typename == "union" then - -- narrow union + if varname and (rvar.typename == "union" or rvar.typename == "interface") then + -- narrow unions and interfaces add_var(varnode, varname, rval, nil, "narrow") end From 69a5393a75e5ece7c8f363a7c30cb6f01189f880 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Fri, 29 Dec 2023 16:25:25 -0500 Subject: [PATCH 062/224] refactor: simplify interface_type to constraint --- tl.lua | 78 ++++++++++++++++++++++++-------------------------------- tl.tl | 80 +++++++++++++++++++++++++--------------------------------- 2 files changed, 67 insertions(+), 91 deletions(-) diff --git a/tl.lua b/tl.lua index 2f10b990b..4c3033f0f 100644 --- a/tl.lua +++ b/tl.lua @@ -1256,7 +1256,6 @@ local table_types = { - local TruthyFact = {} @@ -1497,6 +1496,7 @@ local parse_argument_list local parse_argument_type_list local parse_type local parse_newtype +local parse_interface_name local parse_enum_body @@ -1888,18 +1888,17 @@ end local function parse_typearg(ps, i) local name = ps.tokens[i].tk - local interface_constraint + local constraint i = verify_kind(ps, i, "identifier") if ps.tokens[i].tk == "is" then i = i + 1 - interface_constraint = ps.tokens[i].tk - i = verify_kind(ps, i, "identifier") + i, constraint = parse_interface_name(ps, i) end return i, a_type("typearg", { y = ps.tokens[i - 2].y, x = ps.tokens[i - 2].x, typearg = name, - interface_name = interface_constraint, + constraint = constraint, }) end @@ -3013,7 +3012,7 @@ local function parse_where_clause(ps, i) return i, node end -local function parse_interface_name(ps, i) +parse_interface_name = function(ps, i) local istart = i local typ i, typ = parse_simple_type_or_nominal(ps, i) @@ -3755,6 +3754,9 @@ local function recurse_type(ast, visit) if ast.vtype then table.insert(xs, recurse_type(ast.vtype, visit)) end + if ast.constraint then + table.insert(xs, recurse_type(ast.constraint, visit)) + end local ret local cbkind_after = cbkind and cbkind.after @@ -6315,16 +6317,14 @@ tl.type_check = function(ast, opts) local function fresh_typevar(t) return a_type("typevar", { typevar = (t.typevar:gsub("@.*", "")) .. "@" .. fresh_typevar_ctr, - interface_name = t.interface_name, - interface_type = t.interface_type, + constraint = t.constraint, }) end local function fresh_typearg(t) return a_type("typearg", { typearg = (t.typearg:gsub("@.*", "")) .. "@" .. fresh_typevar_ctr, - interface_name = t.interface_name, - interface_type = t.interface_type, + constraint = t.constraint, }) end @@ -6345,7 +6345,7 @@ tl.type_check = function(ast, opts) if var then local t = var.t if t.typename == "unresolved_typearg" then - return nil, nil, t.interface_type + return nil, nil, t.constraint end t = ensure_fresh_typeargs(t) return t, var.attribute @@ -6650,18 +6650,16 @@ tl.type_check = function(ast, opts) copy = fn_arg(t) else copy.typearg = t.typearg - copy.interface_name = t.interface_name - if t.interface_type then - copy.interface_type, same = resolve(t.interface_type, same) + if t.constraint then + copy.constraint, same = resolve(t.constraint, same) end end elseif t.typename == "unresolvable_typearg" then copy.typearg = t.typearg elseif t.typename == "typevar" then copy.typevar = t.typevar - copy.interface_name = t.interface_name - if t.interface_type then - copy.interface_type, same = resolve(t.interface_type, same) + if t.constraint then + copy.constraint, same = resolve(t.constraint, same) end elseif is_typetype(t) then copy.def, same = resolve(t.def, same) @@ -7580,7 +7578,7 @@ tl.type_check = function(ast, opts) - local vt, _, interface_type = find_var_type(typevar) + local vt, _, constraint = find_var_type(typevar) if vt then return cmp(a or vt, b or vt) @@ -7589,9 +7587,9 @@ tl.type_check = function(ast, opts) local other = a or b - if interface_type then - if not is_a(other, interface_type) then - return false, { Err(other, "given type %s does not satisfy %s constraint in type variable " .. display_typevar(typevar), other, interface_type) } + if constraint then + if not is_a(other, constraint) then + return false, { Err(other, "given type %s does not satisfy %s constraint in type variable " .. display_typevar(typevar), other, constraint) } end end @@ -8323,13 +8321,6 @@ a.types[i], b.types[i]), } orignode.known = saveknown end - local function resolve_interface_type(interface_name) - if not interface_name then - return nil - end - return resolve_typetype((find_var_type(interface_name, "use_type"))) - end - local type_check_function_call do local function mark_invalid_typeargs(f) @@ -8434,8 +8425,7 @@ a.types[i], b.types[i]), } if func.typeargs then for _, fnarg in ipairs(func.typeargs) do add_var(nil, fnarg.typearg, a_type("unresolved_typearg", { - interface_name = fnarg.interface_name, - interface_type = resolve_interface_type(fnarg.interface_name), + constraint = fnarg.constraint, })) end end @@ -8654,8 +8644,8 @@ a.types[i], b.types[i]), } end end - if tbl.typename == "typevar" and tbl.interface_type then - local t = match_record_key(tbl.interface_type, rec, key) + if (tbl.typename == "typevar" or tbl.typename == "typearg") and tbl.constraint then + local t = match_record_key(tbl.constraint, rec, key) if t then return t @@ -9130,7 +9120,7 @@ a.types[i], b.types[i]), } for _, a in ipairs(t.def.typeargs) do table.insert(typevals, a_type("typevar", { typevar = a.typearg, - interface_name = a.interface_name, + constraint = a.constraint, })) end end @@ -10531,8 +10521,8 @@ a.types[i], b.types[i]), } if node.expected then local decltype = resolve_tuple_and_nominal(node.expected) - if decltype.typename == "typevar" and decltype.interface_type then - decltype = decltype.interface_type + if decltype.typename == "typevar" and decltype.constraint then + decltype = resolve_typetype(resolve_tuple_and_nominal(decltype.constraint)) end if decltype.typename == "tupletable" then @@ -10572,10 +10562,11 @@ a.types[i], b.types[i]), } end local decltype = resolve_tuple_and_nominal(node.expected) - local interface_type = decltype.typename == "typevar" and decltype.interface_type - if interface_type then - decltype = interface_type + local constraint + if decltype.typename == "typevar" and decltype.constraint then + constraint = resolve_typetype(decltype.constraint) + decltype = resolve_tuple_and_nominal(constraint) end if decltype.typename == "union" then @@ -10690,8 +10681,8 @@ a.types[i], b.types[i]), } t.is_total, t.missing = total_map_check(decltype, seen_keys) end - if interface_type then - return interface_type + if constraint then + return constraint end return t @@ -11641,8 +11632,7 @@ a.types[i], b.types[i]), } after = function(typ, _children) add_var(nil, typ.typearg, type_at(typ, a_type("typearg", { typearg = typ.typearg, - interface_name = typ.interface_name, - interface_type = resolve_interface_type(typ.interface_name), + constraint = typ.constraint, }))) return typ end, @@ -11652,7 +11642,6 @@ a.types[i], b.types[i]), } if not find_var_type(typ.typevar) then error_at(typ, "undefined type variable " .. typ.typevar) end - typ.interface_type = resolve_interface_type(typ.interface_name) return typ end, }, @@ -11669,8 +11658,7 @@ a.types[i], b.types[i]), } edit_type(typ, "typevar") typ.names = nil typ.typevar = t.typearg - typ.interface_name = t.interface_name - typ.interface_type = t.interface_type + typ.constraint = t.constraint else if t.is_alias then t = t.def.resolved diff --git a/tl.tl b/tl.tl index cf1143df2..9b3ef96c7 100644 --- a/tl.tl +++ b/tl.tl @@ -1141,8 +1141,7 @@ local interface Type -- typearg typearg: string - interface_name: string - interface_type: Type + constraint: Type -- table items kname: string @@ -1497,6 +1496,7 @@ local parse_argument_list: function(ParseState, integer): integer, Node local parse_argument_type_list: function(ParseState, integer): integer, Type local parse_type: function(ParseState, integer): integer, Type, integer local parse_newtype: function(ps: ParseState, i: integer): integer, Node +local parse_interface_name: function(ps: ParseState, i: integer): integer, Type, integer local type ParseBody = function(ps: ParseState, i: integer, def: Type, node: Node): integer, Node local parse_enum_body: function(ps: ParseState, i: integer, def: EnumType, node: Node): integer, Node @@ -1888,18 +1888,17 @@ end local function parse_typearg(ps: ParseState, i: integer): integer, Type, integer local name = ps.tokens[i].tk - local interface_constraint: string + local constraint: Type i = verify_kind(ps, i, "identifier") if ps.tokens[i].tk == "is" then i = i + 1 - interface_constraint = ps.tokens[i].tk - i = verify_kind(ps, i, "identifier") -- FIXME generic interfaces... + i, constraint = parse_interface_name(ps, i) -- FIXME what about generic interfaces end return i, a_type("typearg", { y = ps.tokens[i - 2].y, x = ps.tokens[i - 2].x, typearg = name, - interface_name = interface_constraint, + constraint = constraint, }) end @@ -3013,7 +3012,7 @@ local function parse_where_clause(ps: ParseState, i: integer): integer, Node return i, node end -local function parse_interface_name(ps: ParseState, i: integer): integer, Type, integer +parse_interface_name = function(ps: ParseState, i: integer): integer, Type, integer local istart = i local typ: Type i, typ = parse_simple_type_or_nominal(ps, i) @@ -3755,6 +3754,9 @@ local function recurse_type(ast: Type, visit: Visitor): T if ast.vtype then table.insert(xs, recurse_type(ast.vtype, visit)) end + if ast.constraint then + table.insert(xs, recurse_type(ast.constraint, visit)) + end local ret: T local cbkind_after = cbkind and cbkind.after @@ -6315,16 +6317,14 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local function fresh_typevar(t: Type): Type, Type, boolean return a_type("typevar", { typevar = (t.typevar:gsub("@.*", "")) .. "@" .. fresh_typevar_ctr, - interface_name = t.interface_name, - interface_type = t.interface_type, + constraint = t.constraint, }) end local function fresh_typearg(t: Type): Type return a_type("typearg", { typearg = (t.typearg:gsub("@.*", "")) .. "@" .. fresh_typevar_ctr, - interface_name = t.interface_name, - interface_type = t.interface_type, + constraint = t.constraint, }) end @@ -6345,7 +6345,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if var then local t = var.t if t.typename == "unresolved_typearg" then - return nil, nil, t.interface_type + return nil, nil, t.constraint end t = ensure_fresh_typeargs(t) return t, var.attribute @@ -6650,18 +6650,16 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string copy = fn_arg(t) else copy.typearg = t.typearg - copy.interface_name = t.interface_name - if t.interface_type then - copy.interface_type, same = resolve(t.interface_type, same) + if t.constraint then + copy.constraint, same = resolve(t.constraint, same) end end elseif t.typename == "unresolvable_typearg" then copy.typearg = t.typearg elseif t.typename == "typevar" then copy.typevar = t.typevar - copy.interface_name = t.interface_name - if t.interface_type then - copy.interface_type, same = resolve(t.interface_type, same) + if t.constraint then + copy.constraint, same = resolve(t.constraint, same) end elseif is_typetype(t) then copy.def, same = resolve(t.def, same) @@ -7580,7 +7578,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string -- assert((a == nil and b ~= nil) or (a ~= nil and b == nil)) -- does the typevar currently match to a type? - local vt, _, interface_type = find_var_type(typevar) + local vt, _, constraint = find_var_type(typevar) if vt then -- If so, compare it to the other type return cmp(a or vt, b or vt) @@ -7589,9 +7587,9 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local other = a or b -- but check interface constraint first if present - if interface_type then - if not is_a(other, interface_type) then - return false, { Err(other, "given type %s does not satisfy %s constraint in type variable " .. display_typevar(typevar), other, interface_type) } + if constraint then + if not is_a(other, constraint) then + return false, { Err(other, "given type %s does not satisfy %s constraint in type variable " .. display_typevar(typevar), other, constraint) } end end @@ -8323,13 +8321,6 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string orignode.known = saveknown end - local function resolve_interface_type(interface_name: string): Type - if not interface_name then - return nil - end - return resolve_typetype((find_var_type(interface_name, "use_type"))) - end - local type_check_function_call: function(Node, {Node}, Type, TupleType, Node, boolean, ? integer): Type, Type do local function mark_invalid_typeargs(f: Type) @@ -8434,8 +8425,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if func.typeargs then for _, fnarg in ipairs(func.typeargs) do add_var(nil, fnarg.typearg, a_type("unresolved_typearg", { - interface_name = fnarg.interface_name, - interface_type = resolve_interface_type(fnarg.interface_name), + constraint = fnarg.constraint, })) end end @@ -8654,8 +8644,8 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - if tbl.typename == "typevar" and tbl.interface_type then - local t = match_record_key(tbl.interface_type, rec, key) + if (tbl.typename == "typevar" or tbl.typename == "typearg") and tbl.constraint then + local t = match_record_key(tbl.constraint, rec, key) if t then return t @@ -9130,7 +9120,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string for _, a in ipairs(t.def.typeargs) do table.insert(typevals, a_type("typevar", { typevar = a.typearg, - interface_name = a.interface_name, + constraint = a.constraint, })) end end @@ -10531,8 +10521,8 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if node.expected then local decltype = resolve_tuple_and_nominal(node.expected) - if decltype.typename == "typevar" and decltype.interface_type then - decltype = decltype.interface_type + if decltype.typename == "typevar" and decltype.constraint then + decltype = resolve_typetype(resolve_tuple_and_nominal(decltype.constraint)) end if decltype.typename == "tupletable" then @@ -10572,10 +10562,11 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end local decltype = resolve_tuple_and_nominal(node.expected) - local interface_type = decltype.typename == "typevar" and decltype.interface_type - if interface_type then - decltype = interface_type + local constraint: Type + if decltype.typename == "typevar" and decltype.constraint then + constraint = resolve_typetype(decltype.constraint) + decltype = resolve_tuple_and_nominal(constraint) end if decltype.typename == "union" then @@ -10690,8 +10681,8 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string t.is_total, t.missing = total_map_check(decltype, seen_keys) end - if interface_type then - return interface_type + if constraint then + return constraint end return t @@ -11641,8 +11632,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string after = function(typ: Type, _children: {Type}): Type add_var(nil, typ.typearg, type_at(typ, a_type("typearg", { typearg = typ.typearg, - interface_name = typ.interface_name, - interface_type = resolve_interface_type(typ.interface_name), + constraint = typ.constraint, }))) return typ end, @@ -11652,7 +11642,6 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if not find_var_type(typ.typevar) then error_at(typ, "undefined type variable " .. typ.typevar) end - typ.interface_type = resolve_interface_type(typ.interface_name) return typ end, }, @@ -11669,8 +11658,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string edit_type(typ, "typevar") typ.names = nil typ.typevar = t.typearg - typ.interface_name = t.interface_name - typ.interface_type = t.interface_type + typ.constraint = t.constraint else if t.is_alias then t = t.def.resolved From f0ad3835fd9472b0e1c6d9fa53b6e82e87024813 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Fri, 29 Dec 2023 17:22:46 -0500 Subject: [PATCH 063/224] do not resolve constraint too early --- tl.lua | 13 ++++++++++++- tl.tl | 13 ++++++++++++- 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/tl.lua b/tl.lua index 4c3033f0f..30dd506ca 100644 --- a/tl.lua +++ b/tl.lua @@ -7591,6 +7591,13 @@ tl.type_check = function(ast, opts) if not is_a(other, constraint) then return false, { Err(other, "given type %s does not satisfy %s constraint in type variable " .. display_typevar(typevar), other, constraint) } end + + if same_type(other, constraint) then + + + + return true + end end local ok, r, errs = resolve_typevars(other) @@ -8327,7 +8334,11 @@ a.types[i], b.types[i]), } if f.typeargs then for _, a in ipairs(f.typeargs) do if not find_var_type(a.typearg) then - add_var(nil, a.typearg, lax and UNKNOWN or a_type("unresolvable_typearg", { typearg = a.typearg })) + if a.constraint then + add_var(nil, a.typearg, a.constraint) + else + add_var(nil, a.typearg, lax and UNKNOWN or a_type("unresolvable_typearg", { typearg = a.typearg })) + end end end end diff --git a/tl.tl b/tl.tl index 9b3ef96c7..6a9fc0a45 100644 --- a/tl.tl +++ b/tl.tl @@ -7591,6 +7591,13 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if not is_a(other, constraint) then return false, { Err(other, "given type %s does not satisfy %s constraint in type variable " .. display_typevar(typevar), other, constraint) } end + + if same_type(other, constraint) then + -- do not infer to some type as constraint right away, + -- to give a chance to more specific inferences + -- in other arguments/returns + return true + end end local ok, r, errs = resolve_typevars(other) @@ -8327,7 +8334,11 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if f.typeargs then for _, a in ipairs(f.typeargs) do if not find_var_type(a.typearg) then - add_var(nil, a.typearg, lax and UNKNOWN or a_type("unresolvable_typearg", { typearg = a.typearg })) + if a.constraint then + add_var(nil, a.typearg, a.constraint) + else + add_var(nil, a.typearg, lax and UNKNOWN or a_type("unresolvable_typearg", { typearg = a.typearg })) + end end end end From 45d20e9a7660d7d169b94e9b8a091af6b6ed7e69 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Fri, 29 Dec 2023 17:29:42 -0500 Subject: [PATCH 064/224] array: rename constant types to consttypes --- tl.lua | 27 +++++++++++++++------------ tl.tl | 27 +++++++++++++++------------ 2 files changed, 30 insertions(+), 24 deletions(-) diff --git a/tl.lua b/tl.lua index 30dd506ca..741ec8cfb 100644 --- a/tl.lua +++ b/tl.lua @@ -1256,6 +1256,7 @@ local table_types = { + local TruthyFact = {} @@ -7458,10 +7459,9 @@ tl.type_check = function(ast, opts) if (not a.elements) or (not is_a(a.elements, b.elements)) then return false end - if a.types and #a.types > 1 then + if a.consttypes and #a.consttypes > 1 then - for i = 1, #a.types do - local e = a.types[i] + for _, e in ipairs(a.consttypes) do if not is_a(e, b.elements) then return false, { Err(a, "%s is not a member of %s", e, b.elements) } end @@ -9795,6 +9795,8 @@ a.types[i], b.types[i]), } local seen_keys = {} + local types + for i, child in ipairs(children) do assert(child.typename == "table_item") @@ -9822,8 +9824,8 @@ a.types[i], b.types[i]), } if not is_not_tuple then is_tuple = true end - if not typ.types then - typ.types = {} + if not types then + types = {} end if node[i].key_parsed == "implicit" then @@ -9831,11 +9833,11 @@ a.types[i], b.types[i]), } for _, c in ipairs(child.vtype.tuple) do typ.elements = expand_type(node, typ.elements, c) - typ.types[last_array_idx] = resolve_tuple(c) + types[last_array_idx] = resolve_tuple(c) last_array_idx = last_array_idx + 1 end else - typ.types[last_array_idx] = uvtype + types[last_array_idx] = uvtype last_array_idx = last_array_idx + 1 typ.elements = expand_type(node, typ.elements, uvtype) end @@ -9844,7 +9846,7 @@ a.types[i], b.types[i]), } typ.elements = expand_type(node, typ.elements, uvtype) is_not_tuple = true elseif n then - typ.types[n] = uvtype + types[n] = uvtype if n > largest_array_idx then largest_array_idx = n end @@ -9893,7 +9895,7 @@ a.types[i], b.types[i]), } local pure_array = true if not is_not_tuple then local last_t - for _, current_t in pairs(typ.types) do + for _, current_t in pairs(types) do if last_t then if not same_type(last_t, current_t) then pure_array = false @@ -9905,13 +9907,13 @@ a.types[i], b.types[i]), } end if pure_array then typ.typename = "array" - + typ.consttypes = types assert(typ.elements) typ.inferred_len = largest_array_idx - 1 else typ.typename = "tupletable" typ.elements = nil - assert(typ.types) + typ.types = types end elseif is_record then typ.typename = "record" @@ -9919,7 +9921,8 @@ a.types[i], b.types[i]), } typ.typename = "map" elseif is_tuple then typ.typename = "tupletable" - if not typ.types or #typ.types == 0 then + typ.types = types + if not types or #types == 0 then error_at(node, "cannot determine type of tuple elements") end end diff --git a/tl.tl b/tl.tl index 6a9fc0a45..8aa8d07c0 100644 --- a/tl.tl +++ b/tl.tl @@ -1116,6 +1116,7 @@ local interface Type -- array elements: Type + consttypes: {Type} -- tupletable/array inferred_len: integer @@ -7458,10 +7459,9 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if (not a.elements) or (not is_a(a.elements, b.elements)) then return false end - if a.types and #a.types > 1 then + if a.consttypes and #a.consttypes > 1 then -- constant array, check elements (useful for array of enums) - for i = 1, #a.types do - local e = a.types[i] + for _, e in ipairs(a.consttypes) do if not is_a(e, b.elements) then return false, { Err(a, "%s is not a member of %s", e, b.elements) } end @@ -9795,6 +9795,8 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local seen_keys: {CheckableKey:Where} = {} + local types: {Type} + for i, child in ipairs(children) do assert(child.typename == "table_item") @@ -9822,8 +9824,8 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if not is_not_tuple then is_tuple = true end - if not typ.types then - typ.types = {} + if not types then + types = {} end if node[i].key_parsed == "implicit" then @@ -9831,11 +9833,11 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string -- need to expand last item in an array (e.g { 1, 2, 3, f() }) for _, c in ipairs(child.vtype.tuple) do typ.elements = expand_type(node, typ.elements, c) - typ.types[last_array_idx] = resolve_tuple(c) + types[last_array_idx] = resolve_tuple(c) last_array_idx = last_array_idx + 1 end else - typ.types[last_array_idx] = uvtype + types[last_array_idx] = uvtype last_array_idx = last_array_idx + 1 typ.elements = expand_type(node, typ.elements, uvtype) end @@ -9844,7 +9846,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string typ.elements = expand_type(node, typ.elements, uvtype) is_not_tuple = true elseif n then - typ.types[n as integer] = uvtype + types[n as integer] = uvtype if n > largest_array_idx then largest_array_idx = n as integer end @@ -9893,7 +9895,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local pure_array = true if not is_not_tuple then local last_t: Type - for _, current_t in pairs(typ.types as {integer:Type}) do + for _, current_t in pairs(types as {integer:Type}) do if last_t then if not same_type(last_t, current_t) then pure_array = false @@ -9905,13 +9907,13 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end if pure_array then typ.typename = "array" - -- typ.types = nil + typ.consttypes = types assert(typ.elements) typ.inferred_len = largest_array_idx - 1 else typ.typename = "tupletable" typ.elements = nil - assert(typ.types) + typ.types = types end elseif is_record then typ.typename = "record" @@ -9919,7 +9921,8 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string typ.typename = "map" elseif is_tuple then typ.typename = "tupletable" - if not typ.types or #typ.types == 0 then + typ.types = types + if not types or #types == 0 then error_at(node, "cannot determine type of tuple elements") end end From 0fbb2d771e225219e8ac207bde5a477899168a84 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Fri, 29 Dec 2023 17:37:23 -0500 Subject: [PATCH 065/224] minor tweaks --- tl.lua | 8 ++++---- tl.tl | 24 ++++++++++++------------ 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/tl.lua b/tl.lua index 741ec8cfb..83339698d 100644 --- a/tl.lua +++ b/tl.lua @@ -2674,7 +2674,7 @@ end -local function parse_function(ps, i, ft) +local function parse_function(ps, i, fk) local orig_i = i i = verify_tk(ps, i, "function") local fn = new_node(ps.tokens, i - 1, "global_function") @@ -2713,9 +2713,9 @@ local function parse_function(ps, i, ft) return orig_i + 1 end - if fn.kind == "record_function" and ft == "global" then + if fn.kind == "record_function" and fk == "global" then fail(ps, orig_i, "record functions cannot be annotated as 'global'") - elseif fn.kind == "global_function" and ft == "record" then + elseif fn.kind == "global_function" and fk == "record" then fn.implicit_global_function = true end @@ -6534,7 +6534,7 @@ tl.type_check = function(ast, opts) table.insert(errs, Err(where, err, u)) end if not valid then - u = INVALID + return INVALID, store_errs and errs end return u, store_errs and errs end diff --git a/tl.tl b/tl.tl index 8aa8d07c0..9aff1df98 100644 --- a/tl.tl +++ b/tl.tl @@ -1470,7 +1470,7 @@ local function is_number_type(t:Type): boolean return t.typename == "number" or t.typename == "integer" end -local function is_typetype(t:Type): boolean +local function is_typetype(t: Type): boolean return t.typename == "typetype" end @@ -1494,7 +1494,7 @@ local parse_expression: function(ParseState, integer): integer, Node, integer local parse_expression_and_tk: function(ps: ParseState, i: integer, tk: string): integer, Node local parse_statements: function(ParseState, integer, ? boolean): integer, Node local parse_argument_list: function(ParseState, integer): integer, Node -local parse_argument_type_list: function(ParseState, integer): integer, Type +local parse_argument_type_list: function(ParseState, integer): integer, TupleType local parse_type: function(ParseState, integer): integer, Type, integer local parse_newtype: function(ps: ParseState, i: integer): integer, Node local parse_interface_name: function(ps: ParseState, i: integer): integer, Type, integer @@ -1592,11 +1592,11 @@ local macroexp a_typetype(t: Type): Type return a_type("typetype", t) end -local macroexp a_tuple(t: {Type}): Type +local macroexp a_tuple(t: {Type}): TupleType return a_type("tuple", { tuple = t }) end -local function c_tuple(t: {Type}): Type +local function c_tuple(t: {Type}): TupleType return a_type("tuple", { tuple = t }) end @@ -1637,13 +1637,13 @@ local THREAD = a_type("thread", {}) local BOOLEAN = a_type("boolean", {}) local INTEGER = a_type("integer", {}) -local function shallow_copy_new_type(t: Type): Type +local function shallow_copy_new_type(t: T): T local copy: {any:any} = {} for k, v in pairs(t as {any:any}) do copy[k] = v end copy.typeid = new_typeid() - return copy as Type + return copy as T end local function shallow_copy_table(t: T): T @@ -2669,12 +2669,12 @@ local function parse_local_function(ps: ParseState, i: integer): integer, Node return parse_function_args_rets_body(ps, i, node) end -local enum FunctionType +local enum FunctionKind "global" "record" end -local function parse_function(ps: ParseState, i: integer, ft: FunctionType): integer, Node +local function parse_function(ps: ParseState, i: integer, fk: FunctionKind): integer, Node local orig_i = i i = verify_tk(ps, i, "function") local fn = new_node(ps.tokens, i - 1, "global_function") @@ -2713,9 +2713,9 @@ local function parse_function(ps: ParseState, i: integer, ft: FunctionType): int return orig_i + 1 end - if fn.kind == "record_function" and ft == "global" then + if fn.kind == "record_function" and fk == "global" then fail(ps, orig_i, "record functions cannot be annotated as 'global'") - elseif fn.kind == "global_function" and ft == "record" then + elseif fn.kind == "global_function" and fk == "record" then fn.implicit_global_function = true end @@ -4920,7 +4920,7 @@ local function store_function(trenv: TypeReportEnv, ti: TypeInfo, rt: Type) ti.vararg = not not rt.is_va end -get_typenum = function(trenv:TypeReportEnv, t: Type): integer +get_typenum = function(trenv: TypeReportEnv, t: Type): integer assert(t.typeid) -- try by typeid local n = trenv.typeid_to_num[t.typeid] @@ -6534,7 +6534,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string table.insert(errs, Err(where, err, u)) end if not valid then - u = INVALID + return INVALID, store_errs and errs end return u, store_errs and errs end From 042856059422237806d2c46899ec016b6b90c117 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Fri, 29 Dec 2023 17:39:29 -0500 Subject: [PATCH 066/224] FunctionType, PolyType, UnionType, TupleTableType --- tl.lua | 306 +++++++++++++++++++------------- tl.tl | 546 +++++++++++++++++++++++++++++++-------------------------- 2 files changed, 482 insertions(+), 370 deletions(-) diff --git a/tl.lua b/tl.lua index 83339698d..5575ee3e5 100644 --- a/tl.lua +++ b/tl.lua @@ -1238,6 +1238,29 @@ local table_types = { + + + + + + + + + + + + + + + + + + + + + + + @@ -1609,8 +1632,6 @@ end local function a_function(t) - assert(t.args.typename == "tuple") - assert(t.rets.typename == "tuple") return a_type("function", t) end @@ -1995,7 +2016,6 @@ local function parse_base_type(ps, i) return i, decl elseif ps.tokens[i].tk == "," then local decl = new_type(ps, istart, "tupletable") - decl.typename = "tupletable" decl.types = { t } local n = 2 repeat @@ -2123,7 +2143,6 @@ local function parse_function_args_rets_body(ps, i, node) i, node.body = parse_statements(ps, i) end_at(node, ps.tokens[i]) i = verify_end(ps, i, istart, node) - assert(node.rets.typename == "tuple") return i, node end @@ -2903,8 +2922,9 @@ local function store_field_in_record(ps, i, field_name, t, fields, field_order) else local prev_t = fields[field_name] if t.typename == "function" and prev_t.typename == "function" then - fields[field_name] = new_type(ps, i, "poly") - fields[field_name].types = { prev_t, t } + local p = new_type(ps, i, "poly") + p.types = { prev_t, t } + fields[field_name] = p elseif t.typename == "function" and prev_t.typename == "poly" then table.insert(prev_t.types, t) else @@ -2995,7 +3015,6 @@ local function parse_macroexp(ps, istart, iargs) i, node.exp = parse_expression(ps, i) end_at(node, ps.tokens[i]) i = verify_end(ps, i, istart, node) - assert(node.rets.typename == "tuple") return i, node end @@ -3171,11 +3190,12 @@ parse_record_body = function(ps, i, def, node) end if ps.tokens[i].tk == "=" and ps.tokens[i + 1].tk == "macroexp" then - if t.typename ~= "function" then + if not (t.typename == "function") then fail(ps, i + 1, "macroexp must have a function type") + else + i, t.macroexp = parse_macroexp(ps, i + 1, i + 2) + t.is_abstract = true end - i, t.macroexp = parse_macroexp(ps, i + 1, i + 2) - t.is_abstract = true end store_field_in_record(ps, iv, field_name, t, fields, field_order) @@ -3732,16 +3752,18 @@ local function recurse_type(ast, visit) table.insert(xs, recurse_type(child, visit)) end end - if ast.args then - for i, child in ipairs(ast.args.tuple) do - if i > 1 or not ast.is_method or child.is_self then - table.insert(xs, recurse_type(child, visit)) + if ast.typename == "function" then + if ast.args then + for i, child in ipairs(ast.args.tuple) do + if i > 1 or not ast.is_method or child.is_self then + table.insert(xs, recurse_type(child, visit)) + end end end - end - if ast.rets then - for _, child in ipairs(ast.rets.tuple) do - table.insert(xs, recurse_type(child, visit)) + if ast.rets then + for _, child in ipairs(ast.rets.tuple) do + table.insert(xs, recurse_type(child, visit)) + end end end if ast.typevals then @@ -4980,13 +5002,11 @@ get_typenum = function(trenv, t) ti.enums = mark_array(sorted_keys(rt.enumset)) elseif rt.typename == "function" then store_function(trenv, ti, rt) - elseif rt.typename == "poly" or rt.typename == "union" or rt.typename == "tupletable" then + elseif rt.types then local tis = {} - for _, pt in ipairs(rt.types) do table.insert(tis, get_typenum(trenv, pt)) end - ti.types = mark_array(tis) end @@ -6399,7 +6419,7 @@ tl.type_check = function(ast, opts) return end - if t.macroexp then + if t.typename == "function" and t.macroexp then error_at(where, "macroexps are abstract; consider using a concrete function") else error_at(where, "interfaces are abstract; consider using a concrete record") @@ -6458,10 +6478,6 @@ tl.type_check = function(ast, opts) end local function is_valid_union(typ) - if typ.typename ~= "union" then - return false, nil - end - local n_table_types = 0 @@ -6679,6 +6695,7 @@ tl.type_check = function(ast, opts) end set_min_arity(t) + assert(copy.typename == "function") copy.min_arity = t.min_arity copy.is_method = t.is_method copy.args, same = resolve(t.args, same) @@ -6715,13 +6732,21 @@ tl.type_check = function(ast, opts) copy.keys, same = resolve(t.keys, same) copy.values, same = resolve(t.values, same) elseif t.typename == "union" then + assert(copy.typename == "union") copy.types = {} for i, tf in ipairs(t.types) do copy.types[i], same = resolve(tf, same) end copy, errs = validate_union(t, copy, true, errs) - elseif t.typename == "poly" or t.typename == "tupletable" then + elseif t.typename == "poly" then + assert(copy.typename == "poly") + copy.types = {} + for i, tf in ipairs(t.types) do + copy.types[i], same = resolve(tf, same) + end + elseif t.typename == "tupletable" then + assert(copy.typename == "tupletable") copy.types = {} for i, tf in ipairs(t.types) do copy.types[i], same = resolve(tf, same) @@ -6830,12 +6855,13 @@ tl.type_check = function(ast, opts) if name:sub(1, 2) == "::" then add_warning("unused", var.declared_at, "unused label %s", name) else + local t = var.t add_warning( "unused", var.declared_at, "unused %s %s: %s", var.is_func_arg and "argument" or - var.t.typename == "function" and "function" or + t.typename == "function" and "function" or is_typetype(var.t) and "type" or "variable", name, @@ -7412,7 +7438,7 @@ tl.type_check = function(ast, opts) local function arraytype_from_tuple(where, tupletype) local element_type = unite(tupletype.types, true) - local valid = element_type.typename ~= "union" and true or is_valid_union(element_type) + local valid = (not (element_type.typename == "union")) and true or is_valid_union(element_type) if valid then return a_type("array", { elements = element_type }) end @@ -7796,13 +7822,14 @@ tl.type_check = function(ast, opts) }, ["nominal"] = { ["nominal"] = function(a, b) - local ra = resolve_nominal(a) local rb = resolve_nominal(b) - if rb.typename == "interface" then return is_a(a, rb) - elseif ra.typename == "union" or rb.typename == "union" then + end + + local ra = resolve_nominal(a) + if ra.typename == "union" or rb.typename == "union" then return is_a(ra, rb) end @@ -8192,11 +8219,11 @@ a.types[i], b.types[i]), } return f or t1 end - local function same_call_mt_in_all_union_entries(tbl) - return same_in_all_union_entries(tbl, function(t) + local function same_call_mt_in_all_union_entries(u) + return same_in_all_union_entries(u, function(t) t = resolve_tuple_and_nominal(t) local call_mt = t.meta_fields and t.meta_fields["__call"] - if call_mt then + if call_mt.typename == "function" then local args_tuple = a_type("tuple", { tuple = {} }) for i = 2, #call_mt.args.tuple do table.insert(args_tuple.tuple, call_mt.args.tuple[i]) @@ -8490,6 +8517,9 @@ a.types[i], b.types[i]), } if not (func.typename == "function" or func.typename == "poly") then func, is_method = resolve_for_call(func, args, is_method) + if not (func.typename == "function" or func.typename == "poly") then + return invalid_at(where, "not a function: %s", func) + end end argdelta = is_method and -1 or argdelta or 0 @@ -8498,14 +8528,8 @@ a.types[i], b.types[i]), } add_var(nil, "@self", type_at(where, a_type("typetype", { def = args.tuple[1] }))) end - local is_func = func.typename == "function" - local is_poly = func.typename == "poly" - if not (is_func or is_poly) then - return invalid_at(where, "not a function: %s", func) - end - local passes, n = 1, 1 - if is_poly then + if func.typename == "poly" then passes, n = 3, #func.types end @@ -8515,7 +8539,7 @@ a.types[i], b.types[i]), } for pass = 1, passes do for i = 1, n do if (not tried) or not tried[i] then - local f = is_func and func or func.types[i] + local f = func.typename == "poly" and func.types[i] or func local fargs = f.args.tuple if f.is_method and not is_method then if args.tuple[1] and is_a(args.tuple[1], fargs[1]) then @@ -8531,9 +8555,9 @@ a.types[i], b.types[i]), } set_min_arity(f) - if (is_func and ((given <= wanted and given >= f.min_arity) or (f.args.is_va and given > wanted) or (lax and given <= wanted))) or + if (passes == 1 and ((given <= wanted and given >= f.min_arity) or (f.args.is_va and given > wanted) or (lax and given <= wanted))) or - (is_poly and ((pass == 1 and given == wanted) or + (passes == 3 and ((pass == 1 and given == wanted) or (pass == 2 and given < wanted and (lax or given >= f.min_arity)) or @@ -8553,7 +8577,7 @@ a.types[i], b.types[i]), } infer_emptytables(where, where_args, f.rets, f.rets, argdelta) end - if is_poly then + if passes == 3 then tried = tried or {} tried[i] = true pop_typeargs(f) @@ -8590,8 +8614,8 @@ a.types[i], b.types[i]), } store_type(e1.y, e1.x, f) end - if func.macroexp then - expand_macroexp(node, where_args, func.macroexp) + if f and f.macroexp then + expand_macroexp(node, where_args, f.macroexp) end return ret, f @@ -9326,12 +9350,12 @@ a.types[i], b.types[i]), } t1 = resolve_if_union(t1) - if t1.typename ~= "union" then + if not (t1.typename == "union") then return t1 end t2 = resolve_if_union(t2) - local t2types = t2.types or { t2 } + local t2types = t2.typename == "union" and t2.types or { t2 } for _, at in ipairs(t1.types) do local not_present = true @@ -9543,9 +9567,13 @@ a.types[i], b.types[i]), } end + + local ftype = table.remove(b.tuple, 1) ftype = shallow_copy_new_type(ftype) - ftype.is_method = false + if ftype.typename == "function" then + ftype.is_method = false + end local fe2 = {} if node.e1.tk == "xpcall" then @@ -9781,8 +9809,6 @@ a.types[i], b.types[i]), } end local function infer_table_literal(node, children) - local typ = type_at(node, a_type("emptytable", {})) - local is_record = false local is_array = false local is_map = false @@ -9795,8 +9821,16 @@ a.types[i], b.types[i]), } local seen_keys = {} + local types + local fields + local field_order + + local elements + + local keys, values + for i, child in ipairs(children) do assert(child.typename == "table_item") @@ -9813,12 +9847,12 @@ a.types[i], b.types[i]), } local uvtype = resolve_tuple(child.vtype) if ck then is_record = true - if not typ.fields then - typ.fields = {} - typ.field_order = {} + if not fields then + fields = {} + field_order = {} end - typ.fields[ck] = uvtype - table.insert(typ.field_order, ck) + fields[ck] = uvtype + table.insert(field_order, ck) elseif is_number_type(child.ktype) then is_array = true if not is_not_tuple then @@ -9832,62 +9866,66 @@ a.types[i], b.types[i]), } if i == #children and child.vtype.typename == "tuple" then for _, c in ipairs(child.vtype.tuple) do - typ.elements = expand_type(node, typ.elements, c) + elements = expand_type(node, elements, c) types[last_array_idx] = resolve_tuple(c) last_array_idx = last_array_idx + 1 end else types[last_array_idx] = uvtype last_array_idx = last_array_idx + 1 - typ.elements = expand_type(node, typ.elements, uvtype) + elements = expand_type(node, elements, uvtype) end else if not is_positive_int(n) then - typ.elements = expand_type(node, typ.elements, uvtype) + elements = expand_type(node, elements, uvtype) is_not_tuple = true elseif n then types[n] = uvtype if n > largest_array_idx then largest_array_idx = n end - typ.elements = expand_type(node, typ.elements, uvtype) + elements = expand_type(node, elements, uvtype) end end if last_array_idx > largest_array_idx then largest_array_idx = last_array_idx end - if not typ.elements then + if not elements then is_array = false end else is_map = true child.ktype.tk = nil - typ.keys = expand_type(node, typ.keys, child.ktype) - typ.values = expand_type(node, typ.values, uvtype) + keys = expand_type(node, keys, child.ktype) + values = expand_type(node, values, uvtype) end end + local t + if is_array and is_map then - typ.typename = "map" - typ.keys = expand_type(node, typ.keys, INTEGER) - typ.values = expand_type(node, typ.values, typ.elements) - typ.elements = nil error_at(node, "cannot determine type of table literal") + t = a_type("map", { keys = +expand_type(node, keys, INTEGER), values = + +expand_type(node, values, elements) }) elseif is_record and is_array then - typ.typename = "record" - typ.interface_list = { - type_at(node, a_type("array", { elements = typ.elements })), - } + t = a_type("record", { + fields = fields, + field_order = field_order, + elements = elements, + interface_list = { + type_at(node, a_type("array", { elements = elements })), + }, + }) elseif is_record and is_map then - if typ.keys.typename == "string" then - typ.typename = "map" - for _, ftype in fields_of(typ) do - typ.values = expand_type(node, typ.values, ftype) + if keys.typename == "string" then + for _, fname in ipairs(field_order) do + values = expand_type(node, values, fields[fname]) end - typ.fields = nil - typ.field_order = nil + t = a_type("map", { keys = keys, values = values }) else error_at(node, "cannot determine type of table literal") end @@ -9906,28 +9944,33 @@ a.types[i], b.types[i]), } end end if pure_array then - typ.typename = "array" - typ.consttypes = types - assert(typ.elements) - typ.inferred_len = largest_array_idx - 1 + t = a_type("array", { elements = elements }) + t.consttypes = types + t.inferred_len = largest_array_idx - 1 else - typ.typename = "tupletable" - typ.elements = nil - typ.types = types + t = a_type("tupletable", {}) + t.types = types end elseif is_record then - typ.typename = "record" + t = a_type("record", { + fields = fields, + field_order = field_order, + }) elseif is_map then - typ.typename = "map" + t = a_type("map", { keys = keys, values = values }) elseif is_tuple then - typ.typename = "tupletable" - typ.types = types + t = a_type("tupletable", {}) + t.types = types if not types or #types == 0 then error_at(node, "cannot determine type of tuple elements") end end - return typ + if not t then + t = a_type("emptytable", {}) + end + + return type_at(node, t) end local function infer_negation_of_if_blocks(where, ifnode, n) @@ -9959,14 +10002,18 @@ a.types[i], b.types[i]), } ok = assert_is_a(node.vars[i], infertype, decltype, context_name[node.kind], name) end else - if infertype and infertype.typename == "unresolvable_typearg" then - error_at(node.vars[i], "cannot infer declaration type; an explicit type annotation is necessary") - ok = false - infertype = INVALID - elseif infertype and infertype.is_method then + if infertype then + if infertype.typename == "unresolvable_typearg" then + error_at(node.vars[i], "cannot infer declaration type; an explicit type annotation is necessary") + ok = false + infertype = INVALID + elseif infertype.typename == "function" and infertype.is_method then + - infertype = shallow_copy_new_type(infertype) - infertype.is_method = false + + infertype = shallow_copy_new_type(infertype) + infertype.is_method = false + end end end @@ -10614,7 +10661,6 @@ a.types[i], b.types[i]), } local is_record = is_record_type(decltype) local is_array = is_array_type(decltype) - local is_tupletable = decltype.typename == "tupletable" local is_map = decltype.typename == "map" local force_array = nil @@ -10642,7 +10688,7 @@ a.types[i], b.types[i]), } assert_is_a(node[i], cvtype, df, "in record field", ck) end end - elseif is_tupletable and is_number_type(child.ktype) then + elseif decltype.typename == "tupletable" and is_number_type(child.ktype) then local dt = decltype.types[n] if not n then error_at(node[i], in_context(node.expected_context, "unknown index in tuple %s"), decltype) @@ -10711,7 +10757,9 @@ a.types[i], b.types[i]), } vtype = node.itemtype assert_is_a(node.value, children[2], node.itemtype, "in table item") end - if vtype.is_method then + if vtype.typename == "function" and vtype.is_method then + + vtype = shallow_copy_new_type(vtype) vtype.is_method = false @@ -11307,8 +11355,11 @@ a.types[i], b.types[i]), } if not t then error_at(node, "cannot use operator '" .. node.op.op:gsub("%%", "%%%%") .. "' for types %s and %s", resolve_tuple(orig_a), resolve_tuple(orig_b)) t = INVALID - if node.op.op == "or" and is_valid_union(unite({ orig_a, orig_b })) then - add_warning("hint", node, "if a union type was intended, consider declaring it explicitly") + if node.op.op == "or" then + local u = unite({ orig_a, orig_b }) + if u.typename == "union" and is_valid_union(u) then + add_warning("hint", node, "if a union type was intended, consider declaring it explicitly") + end end end end @@ -11582,25 +11633,28 @@ a.types[i], b.types[i]), } local fmacros for name, _ in fields_of(typ) do local ftype = children[i] - if ftype.macroexp then - fmacros = fmacros or {} - table.insert(fmacros, ftype) - end - if ftype.typename == "function" and ftype.is_method then - local fargs = ftype.args.tuple - if fargs[1] and fargs[1].is_self then - local record_name = typ.names and typ.names[1] - if record_name then - local selfarg = fargs[1] - if selfarg.tk ~= record_name or (typ.typeargs and not selfarg.typevals) then - ftype.is_method = false - selfarg.is_self = false - elseif typ.typeargs then - for j = 1, #typ.typeargs do - if (not selfarg.typevals[j]) or selfarg.typevals[j].tk ~= typ.typeargs[j].typearg then - ftype.is_method = false - selfarg.is_self = false - break + if ftype.typename == "function" then + if ftype.macroexp then + fmacros = fmacros or {} + table.insert(fmacros, ftype) + end + + if ftype.is_method then + local fargs = ftype.args.tuple + if fargs[1] and fargs[1].is_self then + local record_name = typ.names and typ.names[1] + if record_name then + local selfarg = fargs[1] + if selfarg.tk ~= record_name or (typ.typeargs and not selfarg.typevals) then + ftype.is_method = false + selfarg.is_self = false + elseif typ.typeargs then + for j = 1, #typ.typeargs do + if (not selfarg.typevals[j]) or selfarg.typevals[j].tk ~= typ.typeargs[j].typearg then + ftype.is_method = false + selfarg.is_self = false + break + end end end end @@ -11613,9 +11667,11 @@ a.types[i], b.types[i]), } end for name, _ in fields_of(typ, "meta") do local ftype = children[i] - if ftype.macroexp then - fmacros = fmacros or {} - table.insert(fmacros, ftype) + if ftype.typename == "function" then + if ftype.macroexp then + fmacros = fmacros or {} + table.insert(fmacros, ftype) + end end typ.meta_fields[name] = ftype i = i + 1 diff --git a/tl.tl b/tl.tl index 9aff1df98..cda7a300e 100644 --- a/tl.tl +++ b/tl.tl @@ -1089,9 +1089,6 @@ local interface Type is_va: boolean tuple: {Type} - -- poly, union, tupletable - types: {Type} - -- typetype def: Type is_alias: boolean @@ -1120,12 +1117,6 @@ local interface Type -- tupletable/array inferred_len: integer - -- function - is_method: boolean - min_arity: number - args: TupleType - rets: TupleType - typeid: integer -- function argument @@ -1155,9 +1146,6 @@ local interface Type inferred_at: Where emptytable_type: Type - -- macroexp - macroexp: Node - -- unresolved items labels: {string:{Node}} nominals: {string:{Type}} @@ -1165,6 +1153,41 @@ local interface Type narrows: {string:boolean} end +local record FunctionType + is Type + where self.typename == "function" + + is_method: boolean + min_arity: number + args: TupleType + rets: TupleType + macroexp: Node +end + +local interface AggregateType + is Type + where self.types + + types: {Type} +end + +local record UnionType + is AggregateType + where self.typename == "union" +end + +local record TupleTableType + is AggregateType + where self.typename == "tupletable" +end + +local record PolyType + is AggregateType + where self.typename == "poly" + + types: {FunctionType} +end + local record EnumType is Type where self.typename == "enum" enumset: {string:boolean} end @@ -1600,17 +1623,15 @@ local function c_tuple(t: {Type}): TupleType return a_type("tuple", { tuple = t }) end -local macroexp a_union(t: {Type}): Type - return a_type("union", { types = t }) +local macroexp a_union(t: {Type}): UnionType + return a_type("union", { types = t } as UnionType) end -local macroexp a_poly(t: {Type}): Type - return a_type("poly", { types = t }) +local macroexp a_poly(t: {FunctionType}): PolyType + return a_type("poly", { types = t } as PolyType) end -local function a_function(t: Type): Type - assert(t.args.typename == "tuple") - assert(t.rets.typename == "tuple") +local function a_function(t: FunctionType): FunctionType return a_type("function", t) end @@ -1907,8 +1928,8 @@ local function parse_return_types(ps: ParseState, i: integer): integer, Type return parse_type_list(ps, i, "rets") end -local function parse_function_type(ps: ParseState, i: integer): integer, Type - local typ = new_type(ps, i, "function") +local function parse_function_type(ps: ParseState, i: integer): integer, FunctionType + local typ = new_type(ps, i, "function") as FunctionType i = i + 1 if ps.tokens[i].tk == "<" then i, typ.typeargs = parse_anglebracket_list(ps, i, parse_typearg) @@ -1994,8 +2015,7 @@ local function parse_base_type(ps: ParseState, i: integer): integer, Type, integ i = verify_tk(ps, i, "}") return i, decl elseif ps.tokens[i].tk == "," then - local decl = new_type(ps, istart, "tupletable") - decl.typename = "tupletable" + local decl = new_type(ps, istart, "tupletable") as TupleTableType decl.types = { t } local n = 2 repeat @@ -2051,7 +2071,7 @@ parse_type = function(ps: ParseState, i: integer): integer, Type, integer return i end if ps.tokens[i].tk == "|" then - local u = new_type(ps, istart, "union") + local u = new_type(ps, istart, "union") as UnionType u.types = { bt } while ps.tokens[i].tk == "|" do i = i + 1 @@ -2123,7 +2143,6 @@ local function parse_function_args_rets_body(ps: ParseState, i: integer, node: N i, node.body = parse_statements(ps, i) end_at(node, ps.tokens[i]) i = verify_end(ps, i, istart, node) - assert(node.rets.typename == "tuple") return i, node end @@ -2902,10 +2921,11 @@ local function store_field_in_record(ps: ParseState, i: integer, field_name: str table.insert(field_order, field_name) else local prev_t = fields[field_name] - if t.typename == "function" and prev_t.typename == "function" then - fields[field_name] = new_type(ps, i, "poly") - fields[field_name].types = { prev_t, t } - elseif t.typename == "function" and prev_t.typename == "poly" then + if t is FunctionType and prev_t is FunctionType then + local p = new_type(ps, i, "poly") as PolyType + p.types = { prev_t, t } + fields[field_name] = p + elseif t is FunctionType and prev_t is PolyType then table.insert(prev_t.types, t) else fail(ps, i, "attempt to redeclare field '" .. field_name .. "' (only functions can be overloaded)") @@ -2995,7 +3015,6 @@ local function parse_macroexp(ps: ParseState, istart: integer, iargs: integer): i, node.exp = parse_expression(ps, i) end_at(node, ps.tokens[i]) i = verify_end(ps, i, istart, node) - assert(node.rets.typename == "tuple") return i, node end @@ -3086,7 +3105,7 @@ parse_record_body = function(ps: ParseState, i: integer, def: Type, node: Node): def.meta_fields = {} def.meta_field_order = {} - local typ = new_type(ps, wstart, "function") + local typ = new_type(ps, wstart, "function") as FunctionType typ.is_method = true typ.args = a_tuple { a_type("nominal", { y = typ.y, x = typ.x, filename = ps.filename, names = { "@self" } }) } typ.rets = a_tuple { BOOLEAN } @@ -3171,11 +3190,12 @@ parse_record_body = function(ps: ParseState, i: integer, def: Type, node: Node): end if ps.tokens[i].tk == "=" and ps.tokens[i + 1].tk == "macroexp" then - if t.typename ~= "function" then + if not t is FunctionType then fail(ps, i + 1, "macroexp must have a function type") + else + i, t.macroexp = parse_macroexp(ps, i + 1, i + 2) + t.is_abstract = true end - i, t.macroexp = parse_macroexp(ps, i + 1, i + 2) - t.is_abstract = true end store_field_in_record(ps, iv, field_name, t, fields, field_order) @@ -3700,7 +3720,7 @@ local function recurse_type(ast: Type, visit: Visitor): T xs[i] = recurse_type(child, visit) end end - if ast.types then + if ast is AggregateType then for _, child in ipairs(ast.types) do table.insert(xs, recurse_type(child, visit)) end @@ -3732,16 +3752,18 @@ local function recurse_type(ast: Type, visit: Visitor): T table.insert(xs, recurse_type(child, visit)) end end - if ast.args then - for i, child in ipairs(ast.args.tuple) do - if i > 1 or not ast.is_method or child.is_self then - table.insert(xs, recurse_type(child, visit)) + if ast is FunctionType then + if ast.args then + for i, child in ipairs(ast.args.tuple) do + if i > 1 or not ast.is_method or child.is_self then + table.insert(xs, recurse_type(child, visit)) + end end end - end - if ast.rets then - for _, child in ipairs(ast.rets.tuple) do - table.insert(xs, recurse_type(child, visit)) + if ast.rets then + for _, child in ipairs(ast.rets.tuple) do + table.insert(xs, recurse_type(child, visit)) + end end end if ast.typevals then @@ -4906,7 +4928,7 @@ function tl.init_type_report(): TypeReportEnv } end -local function store_function(trenv: TypeReportEnv, ti: TypeInfo, rt: Type) +local function store_function(trenv: TypeReportEnv, ti: TypeInfo, rt: FunctionType) local args: {{integer, string}} = {} for _, fnarg in ipairs(rt.args.tuple) do table.insert(args, mark_array { get_typenum(trenv, fnarg), nil }) @@ -4978,15 +5000,13 @@ get_typenum = function(trenv: TypeReportEnv, t: Type): integer ti.values = get_typenum(trenv, rt.values) elseif rt is EnumType then ti.enums = mark_array(sorted_keys(rt.enumset)) - elseif rt.typename == "function" then + elseif rt is FunctionType then store_function(trenv, ti, rt) - elseif rt.typename == "poly" or rt.typename == "union" or rt.typename == "tupletable" then + elseif rt is AggregateType then local tis = {} - for _, pt in ipairs(rt.types) do table.insert(tis, get_typenum(trenv, pt)) end - ti.types = mark_array(tis) end @@ -5344,19 +5364,19 @@ local function show_type_base(t: Type, short: boolean, seen: {Type:string}): str return list end return "(" .. list .. ")" - elseif t.typename == "tupletable" then + elseif t is TupleTableType then local out: {string} = {} for _, v in ipairs(t.types) do table.insert(out, show(v)) end return "{" .. table.concat(out, ", ") .. "}" - elseif t.typename == "poly" then + elseif t is PolyType then local out: {string} = {} for _, v in ipairs(t.types) do table.insert(out, show(v)) end return "polymorphic function (with types " .. table.concat(out, " and ") .. ")" - elseif t.typename == "union" then + elseif t is UnionType then local out: {string} = {} for _, v in ipairs(t.types) do table.insert(out, show(v)) @@ -5374,7 +5394,7 @@ local function show_type_base(t: Type, short: boolean, seen: {Type:string}): str return show_record_type("interface") elseif is_record_type(t) then return show_record_type("record") - elseif t.typename == "function" then + elseif t is FunctionType then local out: {string} = {"function"} if t.typeargs then table.insert(out, "<") @@ -5698,7 +5718,7 @@ local function init_globals(lax: boolean): {string:Variable}, {string:Type} return t end - local function a_gfunction(n: integer, f: function(...: Type): (Type), typename?: TypeName): Type + local function a_gfunction(n: integer, f: function(...: Type): (FunctionType), typename?: TypeName): FunctionType local typevars = {} local typeargs = {} local c = string.byte("A") - 1 @@ -5743,7 +5763,7 @@ local function init_globals(lax: boolean): {string:Variable}, {string:Type} { ctor = a_vararg, args = { }, rets = { STRING } }, } - local function a_file_reader(fn: (function(ctor: TypeConstructor, args: {Type}, rets: {Type}): Type)): Type + local function a_file_reader(fn: (function(ctor: TypeConstructor, args: {Type}, rets: {Type}): FunctionType)): Type local t = a_poly {} for _, entry in ipairs(file_reader_poly_types) do local args = shallow_copy_table(entry.args) @@ -5795,7 +5815,7 @@ local function init_globals(lax: boolean): {string:Variable}, {string:Type} rets = a_tuple {}, } - local TABLE_SORT_FUNCTION = a_gfunction(1, function(a: Type):Type return { args = a_tuple { a, a }, rets = a_tuple { BOOLEAN } } end) + local TABLE_SORT_FUNCTION = a_gfunction(1, function(a: Type): FunctionType return { args = a_tuple { a, a }, rets = a_tuple { BOOLEAN } } end) local metatable_nominals = {} @@ -5809,7 +5829,7 @@ local function init_globals(lax: boolean): {string:Variable}, {string:Type} ["..."] = a_vararg { STRING }, ["any"] = a_type("typetype", { def = ANY }), ["arg"] = an_array(STRING), - ["assert"] = a_gfunction(2, function(a: Type, b: Type): Type return { args = a_tuple { a, OPT(b) }, rets = a_tuple { a } } end), + ["assert"] = a_gfunction(2, function(a: Type, b: Type): FunctionType return { args = a_tuple { a, OPT(b) }, rets = a_tuple { a } } end), ["collectgarbage"] = a_poly { a_function { args = a_tuple { an_enum { "collect", "count", "stop", "restart" } }, rets = a_tuple { NUMBER } }, a_function { args = a_tuple { an_enum { "step", "setpause", "setstepmul" }, NUMBER }, rets = a_tuple { NUMBER } }, @@ -5818,17 +5838,17 @@ local function init_globals(lax: boolean): {string:Variable}, {string:Type} }, ["dofile"] = a_function { args = a_tuple { OPT(STRING) }, rets = a_vararg { ANY } }, ["error"] = a_function { args = a_tuple { ANY, OPT(NUMBER) }, rets = a_tuple {} }, - ["getmetatable"] = a_gfunction(1, function(a: Type): Type return { args = a_tuple { a }, rets = a_tuple { METATABLE(a) } } end), - ["ipairs"] = a_gfunction(1, function(a: Type): Type return { args = a_tuple { an_array(a) }, rets = a_tuple { + ["getmetatable"] = a_gfunction(1, function(a: Type): FunctionType return { args = a_tuple { a }, rets = a_tuple { METATABLE(a) } } end), + ["ipairs"] = a_gfunction(1, function(a: Type): FunctionType return { args = a_tuple { an_array(a) }, rets = a_tuple { a_function { args = a_tuple {}, rets = a_tuple { INTEGER, a } }, } } end), ["load"] = a_function { args = a_tuple { a_union { STRING, LOAD_FUNCTION }, OPT(STRING), OPT(STRING), OPT(TABLE) }, rets = a_tuple { FUNCTION, STRING } }, ["loadfile"] = a_function { args = a_tuple { OPT(STRING), OPT(STRING), OPT(TABLE) }, rets = a_tuple { FUNCTION, STRING } }, ["next"] = a_poly { - a_gfunction(2, function(a: Type, b: Type): Type return { args = a_tuple { a_map(a, b), OPT(a) }, rets = a_tuple { a, b } } end), - a_gfunction(1, function(a: Type): Type return { args = a_tuple { an_array(a), OPT(a) }, rets = a_tuple { INTEGER, a } } end), + a_gfunction(2, function(a: Type, b: Type): FunctionType return { args = a_tuple { a_map(a, b), OPT(a) }, rets = a_tuple { a, b } } end), + a_gfunction(1, function(a: Type): FunctionType return { args = a_tuple { an_array(a), OPT(a) }, rets = a_tuple { INTEGER, a } } end), }, - ["pairs"] = a_gfunction(2, function(a: Type, b: Type): Type return { args = a_tuple { a_map(a, b) }, rets = a_tuple { + ["pairs"] = a_gfunction(2, function(a: Type, b: Type): FunctionType return { args = a_tuple { a_map(a, b) }, rets = a_tuple { a_function { args = a_tuple {}, rets = a_tuple { a, b } }, } } end), ["pcall"] = a_function { args = a_vararg { FUNCTION, ANY }, rets = a_vararg { BOOLEAN, ANY } }, @@ -5838,17 +5858,17 @@ local function init_globals(lax: boolean): {string:Variable}, {string:Type} ["rawget"] = a_function { args = a_tuple { TABLE, ANY }, rets = a_tuple { ANY } }, ["rawlen"] = a_function { args = a_tuple { a_union { TABLE, STRING } }, rets = a_tuple { INTEGER } }, ["rawset"] = a_poly { - a_gfunction(2, function(a: Type, b: Type): Type return { args = a_tuple { a_map(a, b), a, b }, rets = a_tuple {} } end), - a_gfunction(1, function(a: Type): Type return { args = a_tuple { an_array(a), NUMBER, a }, rets = a_tuple {} } end), + a_gfunction(2, function(a: Type, b: Type): FunctionType return { args = a_tuple { a_map(a, b), a, b }, rets = a_tuple {} } end), + a_gfunction(1, function(a: Type): FunctionType return { args = a_tuple { an_array(a), NUMBER, a }, rets = a_tuple {} } end), a_function { args = a_tuple { TABLE, ANY, ANY }, rets = a_tuple {} }, }, ["require"] = a_function { args = a_tuple { STRING }, rets = a_tuple {} }, ["select"] = a_poly { - a_gfunction(1, function(a: Type): Type return { args = a_vararg { NUMBER, a }, rets = a_tuple { a } } end), + a_gfunction(1, function(a: Type): FunctionType return { args = a_vararg { NUMBER, a }, rets = a_tuple { a } } end), a_function { args = a_vararg { NUMBER, ANY }, rets = a_tuple { ANY } }, a_function { args = a_vararg { STRING, ANY }, rets = a_tuple { INTEGER } }, }, - ["setmetatable"] = a_gfunction(1, function(a: Type): Type return { args = a_tuple { a, METATABLE(a) }, rets = a_tuple { a } } end), + ["setmetatable"] = a_gfunction(1, function(a: Type): FunctionType return { args = a_tuple { a, METATABLE(a) }, rets = a_tuple { a } } end), ["tonumber"] = a_poly { a_function { args = a_tuple { ANY }, rets = a_tuple { NUMBER } }, a_function { args = a_tuple { ANY, NUMBER }, rets = a_tuple { INTEGER } }, @@ -5889,7 +5909,7 @@ local function init_globals(lax: boolean): {string:Variable}, {string:Type} ["__len"] = a_function { args = a_tuple { a }, rets = a_tuple { ANY } }, ["__mode"] = an_enum { "k", "v", "kv" }, ["__newindex"] = ANY, -- FIXME: function | table | anything with a __newindex metamethod - ["__pairs"] = a_gfunction(2, function(k: Type, v: Type): Type + ["__pairs"] = a_gfunction(2, function(k: Type, v: Type): FunctionType return { args = a_tuple { a }, rets = a_tuple { a_function { args = a_tuple {}, rets = a_tuple { k, v } } } @@ -5945,7 +5965,7 @@ local function init_globals(lax: boolean): {string:Variable}, {string:Type} a_function { args = a_tuple { FUNCTION, NUMBER }, rets = a_tuple { STRING } }, a_function { args = a_tuple { NUMBER, NUMBER }, rets = a_tuple { STRING, ANY } }, }, - ["getmetatable"] = a_gfunction(1, function(a: Type): Type return { args = a_tuple { a }, rets = a_tuple { METATABLE(a) } } end), + ["getmetatable"] = a_gfunction(1, function(a: Type): FunctionType return { args = a_tuple { a }, rets = a_tuple { METATABLE(a) } } end), ["getregistry"] = a_function { args = a_tuple {}, rets = a_tuple { TABLE } }, ["getupvalue"] = a_function { args = a_tuple { FUNCTION, NUMBER }, rets = a_tuple { ANY } }, ["getuservalue"] = a_function { args = a_tuple { USERDATA, NUMBER }, rets = a_tuple { ANY } }, @@ -5957,7 +5977,7 @@ local function init_globals(lax: boolean): {string:Variable}, {string:Type} a_function { args = a_tuple { THREAD, NUMBER, NUMBER, ANY }, rets = a_tuple { STRING } }, a_function { args = a_tuple { NUMBER, NUMBER, ANY }, rets = a_tuple { STRING } }, }, - ["setmetatable"] = a_gfunction(1, function(a: Type): Type return { args = a_tuple { a, METATABLE(a) }, rets = a_tuple { a } } end), + ["setmetatable"] = a_gfunction(1, function(a: Type): FunctionType return { args = a_tuple { a, METATABLE(a) }, rets = a_tuple { a } } end), ["setupvalue"] = a_function { args = a_tuple { FUNCTION, NUMBER, ANY }, rets = a_tuple { STRING } }, ["setuservalue"] = a_function { args = a_tuple { USERDATA, ANY, NUMBER }, rets = a_tuple { USERDATA } }, ["traceback"] = a_poly { @@ -6024,14 +6044,14 @@ local function init_globals(lax: boolean): {string:Variable}, {string:Type} ["log10"] = a_function { args = a_tuple { NUMBER }, rets = a_tuple { NUMBER } }, ["max"] = a_poly { a_function { args = a_vararg { INTEGER }, rets = a_tuple { INTEGER } }, - a_gfunction(1, function(a: Type): Type return { args = a_vararg { a }, rets = a_tuple { a } } end), + a_gfunction(1, function(a: Type): FunctionType return { args = a_vararg { a }, rets = a_tuple { a } } end), a_function { args = a_vararg { a_union { NUMBER, INTEGER } }, rets = a_tuple { NUMBER } }, a_function { args = a_vararg { ANY }, rets = a_tuple { ANY } }, }, ["maxinteger"] = a_type("integer", { needs_compat = true }), ["min"] = a_poly { a_function { args = a_vararg { INTEGER }, rets = a_tuple { INTEGER } }, - a_gfunction(1, function(a: Type): Type return { args = a_vararg { a }, rets = a_tuple { a } } end), + a_gfunction(1, function(a: Type): FunctionType return { args = a_vararg { a }, rets = a_tuple { a } } end), a_function { args = a_vararg { a_union { NUMBER, INTEGER } }, rets = a_tuple { NUMBER } }, a_function { args = a_vararg { ANY }, rets = a_tuple { ANY } }, }, @@ -6125,17 +6145,17 @@ local function init_globals(lax: boolean): {string:Variable}, {string:Type} fields = { ["concat"] = a_function { args = a_tuple { an_array(a_union {STRING, NUMBER }), OPT(STRING), OPT(NUMBER), OPT(NUMBER) }, rets = a_tuple { STRING } }, ["insert"] = a_poly { - a_gfunction(1, function(a: Type): Type return { args = a_tuple { an_array(a), NUMBER, a }, rets = a_tuple {} } end), - a_gfunction(1, function(a: Type): Type return { args = a_tuple { an_array(a), a }, rets = a_tuple {} } end), + a_gfunction(1, function(a: Type): FunctionType return { args = a_tuple { an_array(a), NUMBER, a }, rets = a_tuple {} } end), + a_gfunction(1, function(a: Type): FunctionType return { args = a_tuple { an_array(a), a }, rets = a_tuple {} } end), }, ["move"] = a_poly { - a_gfunction(1, function(a: Type): Type return { args = a_tuple { an_array(a), NUMBER, NUMBER, NUMBER }, rets = a_tuple { an_array(a) } }end ), - a_gfunction(1, function(a: Type): Type return { args = a_tuple { an_array(a), NUMBER, NUMBER, NUMBER, an_array(a) }, rets = a_tuple { an_array(a) } } end), + a_gfunction(1, function(a: Type): FunctionType return { args = a_tuple { an_array(a), NUMBER, NUMBER, NUMBER }, rets = a_tuple { an_array(a) } }end ), + a_gfunction(1, function(a: Type): FunctionType return { args = a_tuple { an_array(a), NUMBER, NUMBER, NUMBER, an_array(a) }, rets = a_tuple { an_array(a) } } end), }, ["pack"] = a_function { args = a_vararg { ANY }, rets = a_tuple { TABLE } }, - ["remove"] = a_gfunction(1, function(a: Type): Type return { args = a_tuple { an_array(a), OPT(NUMBER) }, rets = a_tuple { a } } end), - ["sort"] = a_gfunction(1, function(a: Type): Type return { args = a_tuple { an_array(a), OPT(TABLE_SORT_FUNCTION) }, rets = a_tuple {} } end), - ["unpack"] = a_gfunction(1, function(a: Type): Type return { needs_compat = true, args = a_tuple { an_array(a), OPT(NUMBER), OPT(NUMBER) }, rets = a_vararg { a } } end), + ["remove"] = a_gfunction(1, function(a: Type): FunctionType return { args = a_tuple { an_array(a), OPT(NUMBER) }, rets = a_tuple { a } } end), + ["sort"] = a_gfunction(1, function(a: Type): FunctionType return { args = a_tuple { an_array(a), OPT(TABLE_SORT_FUNCTION) }, rets = a_tuple {} } end), + ["unpack"] = a_gfunction(1, function(a: Type): FunctionType return { needs_compat = true, args = a_tuple { an_array(a), OPT(NUMBER), OPT(NUMBER) }, rets = a_vararg { a } } end), }, }, ["utf8"] = a_record { @@ -6399,7 +6419,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return end - if t.macroexp then + if t is FunctionType and t.macroexp then error_at(where, "macroexps are abstract; consider using a concrete function") else error_at(where, "interfaces are abstract; consider using a concrete record") @@ -6457,11 +6477,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - local function is_valid_union(typ: Type): boolean, string - if typ.typename ~= "union" then - return false, nil - end - + local function is_valid_union(typ: UnionType): boolean, string -- check for limitations in our union support -- due to codegen limitations (we only check with type() so far) local n_table_types = 0 @@ -6523,7 +6539,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return true end - local function validate_union(where: Where, u: Type, store_errs?: boolean, errs?: {Error}): Type, {Error} + local function validate_union(where: Where, u: UnionType, store_errs?: boolean, errs?: {Error}): Type, {Error} local valid, err = is_valid_union(u) if err then if store_errs then @@ -6539,7 +6555,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return u, store_errs and errs end - local function set_min_arity(f: Type) + local function set_min_arity(f: FunctionType) if f.min_arity then return end @@ -6558,7 +6574,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string f.min_arity = n end - local function show_arity(f: Type): string + local function show_arity(f: FunctionType): string local nfargs = #f.args.tuple return f.min_arity < nfargs and "at least " .. f.min_arity .. (f.args.is_va and "" or " and at most " .. nfargs) @@ -6670,7 +6686,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string copy.typevals[i], same = resolve(tf, same) end copy.found = t.found - elseif t.typename == "function" then + elseif t is FunctionType then if t.typeargs then copy.typeargs = {} for i, tf in ipairs(t.typeargs) do @@ -6679,6 +6695,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end set_min_arity(t) + assert(copy is FunctionType) copy.min_arity = t.min_arity copy.is_method = t.is_method copy.args, same = resolve(t.args, same) @@ -6714,14 +6731,22 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string elseif t.typename == "map" then copy.keys, same = resolve(t.keys, same) copy.values, same = resolve(t.values, same) - elseif t.typename == "union" then + elseif t is UnionType then + assert(copy is UnionType) copy.types = {} for i, tf in ipairs(t.types) do copy.types[i], same = resolve(tf, same) end copy, errs = validate_union(t, copy, true, errs) - elseif t.typename == "poly" or t.typename == "tupletable" then + elseif t is PolyType then + assert(copy is PolyType) + copy.types = {} + for i, tf in ipairs(t.types) do + copy.types[i], same = resolve(tf, same) as (FunctionType, boolean) + end + elseif t is TupleTableType then + assert(copy is TupleTableType) copy.types = {} for i, tf in ipairs(t.types) do copy.types[i], same = resolve(tf, same) @@ -6830,12 +6855,13 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if name:sub(1, 2) == "::" then add_warning("unused", var.declared_at, "unused label %s", name) else + local t = var.t add_warning( "unused", var.declared_at, "unused %s %s: %s", var.is_func_arg and "argument" - or var.t.typename == "function" and "function" + or t is FunctionType and "function" or is_typetype(var.t) and "type" or "variable", name, @@ -7359,7 +7385,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string i = i + 1 end t = resolve_tuple(t) - if t.typename == "union" then + if t is UnionType then for _, s in ipairs(t.types) do table.insert(stack, s) end @@ -7409,10 +7435,10 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end local expand_type: function(where: Where, old: Type, new: Type): Type - local function arraytype_from_tuple(where: Where, tupletype: Type): Type, {Error} + local function arraytype_from_tuple(where: Where, tupletype: TupleTableType): Type, {Error} -- first just try a basic union local element_type = unite(tupletype.types, true) - local valid = element_type.typename ~= "union" and true or is_valid_union(element_type) + local valid = (not element_type is UnionType) and true or is_valid_union(element_type) if valid then return an_array(element_type) end @@ -7613,7 +7639,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end -- ∃ x ∈ xs. t <: x - local function exists_supertype_in(t: Type, xs: Type): Type + local function exists_supertype_in(t: Type, xs: AggregateType): Type for _, x in ipairs(xs.types) do if is_a(t, x) then return x @@ -7655,7 +7681,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string }, ["emptytable"] = emptytable_relations, ["tupletable"] = { - ["tupletable"] = function(a: Type, b: Type): boolean, {Error} + ["tupletable"] = function(a: TupleTableType, b: TupleTableType): boolean, {Error} for i = 1, math.min(#a.types, #b.types) do if not same_type(a.types[i], b.types[i]) then return false, { Err(a, "in tuple entry " .. tostring(i) .. ": got %s, expected %s", a.types[i], b.types[i]) } @@ -7678,7 +7704,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end, }, ["union"] = { - ["union"] = function(a: Type, b: Type): boolean, {Error} + ["union"] = function(a: UnionType, b: UnionType): boolean, {Error} return (has_all_types_of(a.types, b.types) and has_all_types_of(b.types, a.types)) end, @@ -7690,7 +7716,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string ["record"] = eqtype_record, }, ["function"] = { - ["function"] = function(a: Type, b: Type): boolean, {Error} + ["function"] = function(a: FunctionType, b: FunctionType): boolean, {Error} local argdelta = a.is_method and 1 or 0 local naargs, nbargs = #a.args.tuple, #b.args.tuple if naargs ~= nbargs then @@ -7759,9 +7785,9 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string ["*"] = compare_true, }, ["union"] = { - ["union"] = function(a: Type, b: Type): boolean, {Error} -- ∀ t ∈ a. ∃ u ∈ b. t <: u - local used = {} -- ──────────────────────── - for _, t in ipairs(a.types) do -- a union <: b union + ["union"] = function(a: UnionType, b: UnionType): boolean, {Error} -- ∀ t ∈ a. ∃ u ∈ b. t <: u + local used = {} -- ──────────────────────── + for _, t in ipairs(a.types) do -- a union <: b union begin_scope() local u = exists_supertype_in(t, b) end_scope() -- don't preserve failed inferences @@ -7777,7 +7803,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end return true end, - ["*"] = function(a: Type, b: Type): boolean, {Error} -- ∀ t ∈ a, t <: b + ["*"] = function(a: UnionType, b: Type): boolean, {Error} -- ∀ t ∈ a, t <: b for _, t in ipairs(a.types) do -- ──────────────── if not is_a(t, b) then -- a union <: b return false @@ -7796,13 +7822,14 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string }, ["nominal"] = { ["nominal"] = function(a: Type, b: Type): boolean, {Error} - local ra = resolve_nominal(a) local rb = resolve_nominal(b) - if rb.typename == "interface" then -- match interface subtyping return is_a(a, rb) - elseif ra.typename == "union" or rb.typename == "union" then + end + + local ra = resolve_nominal(a) + if ra is UnionType or rb is UnionType then -- match unions structurally return is_a(ra, rb) end @@ -7846,7 +7873,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string }, ["emptytable"] = emptytable_relations, ["tupletable"] = { - ["tupletable"] = function(a: Type, b: Type): boolean, {Error} + ["tupletable"] = function(a: TupleTableType, b: TupleTableType): boolean, {Error} for i = 1, math.min(#a.types, #b.types) do if not is_a(a.types[i], b.types[i]) then return false, { Err(a, "in tuple entry " @@ -7864,7 +7891,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return subtype_relations["tupletable"]["array"](a, b) end end, - ["array"] = function(a: Type, b: Type): boolean, {Error} + ["array"] = function(a: TupleTableType, b: Type): boolean, {Error} if b.inferred_len and b.inferred_len > #a.types then return false, { Err(a, "incompatible length, expected maximum length of " .. tostring(#a.types) .. ", got " .. tostring(b.inferred_len)) } end @@ -7877,7 +7904,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end return true end, - ["map"] = function(a: Type, b: Type): boolean, {Error} + ["map"] = function(a: TupleTableType, b: Type): boolean, {Error} local aa = arraytype_from_tuple(a.inferred_at, a) if not aa then return false, { Err(a, "Unable to convert tuple %s to map", a) } @@ -7931,7 +7958,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string ["map"] = function(a: Type, b: Type): boolean, {Error} return compare_map(INTEGER, b.keys, a.elements, b.values) end, - ["tupletable"] = function(a: Type, b: Type): boolean, {Error} + ["tupletable"] = function(a: Type, b: TupleTableType): boolean, {Error} local alen = a.inferred_len or 0 if alen > #b.types then return false, { Err(a, "incompatible length, expected maximum length of " .. tostring(#b.types) .. ", got " .. tostring(alen)) } @@ -7961,7 +7988,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end, }, ["function"] = { - ["function"] = function(a: Type, b: Type): boolean, {Error} + ["function"] = function(a: FunctionType, b: FunctionType): boolean, {Error} local errs = {} local aa, ba = a.args.tuple, b.args.tuple @@ -8005,7 +8032,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string -- ─────────────── -- a <: b union ["nominal"] = subtype_nominal, - ["poly"] = function(a: Type, b: Type): boolean, {Error} -- ∀ t ∈ b, a <: t + ["poly"] = function(a: Type, b: PolyType): boolean, {Error} -- ∀ t ∈ b, a <: t for _, t in ipairs(b.types) do -- ─────────────── if not is_a(a, t) then -- a <: b poly return false, { Err(a, "cannot match against all alternatives of the polymorphic type") } @@ -8178,7 +8205,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - local function same_in_all_union_entries(u: Type, check: function(Type): (Type, Type)): Type + local function same_in_all_union_entries(u: UnionType, check: function(Type): (Type, Type)): Type local t1, f = check(u.types[1]) if not t1 then return nil @@ -8192,11 +8219,11 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return f or t1 end - local function same_call_mt_in_all_union_entries(tbl: Type): Type - return same_in_all_union_entries(tbl, function(t: Type): (Type, Type) + local function same_call_mt_in_all_union_entries(u: UnionType): Type + return same_in_all_union_entries(u, function(t: Type): (Type, Type) t = resolve_tuple_and_nominal(t) local call_mt = t.meta_fields and t.meta_fields["__call"] - if call_mt then + if call_mt is FunctionType then local args_tuple = a_tuple({}) for i = 2, #call_mt.args.tuple do table.insert(args_tuple.tuple, call_mt.args.tuple[i]) @@ -8215,7 +8242,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string func = resolve_tuple_and_nominal(func) if func.typename ~= "function" and func.typename ~= "poly" then -- resolve if union - if func.typename == "union" then + if func is UnionType then local r = same_call_mt_in_all_union_entries(func) if r then table.insert(args.tuple, 1, func.types[1]) -- FIXME: is this right? @@ -8328,9 +8355,9 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string orignode.known = saveknown end - local type_check_function_call: function(Node, {Node}, Type, TupleType, Node, boolean, ? integer): Type, Type + local type_check_function_call: function(Node, {Node}, Type, TupleType, Node, boolean, ? integer): TupleType, Type do - local function mark_invalid_typeargs(f: Type) + local function mark_invalid_typeargs(f: FunctionType) if f.typeargs then for _, a in ipairs(f.typeargs) do if not find_var_type(a.typearg) then @@ -8393,7 +8420,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return true end - check_args_rets = function(where: Where, where_args: {Node}, f: Type, args: TupleType, expected_rets: TupleType, argdelta: integer): Type, {Error} + check_args_rets = function(where: Where, where_args: {Node}, f: FunctionType, args: TupleType, expected_rets: TupleType, argdelta: integer): Type, {Error} local rets_ok = true local rets_errs: {Error} local args_ok: boolean @@ -8432,7 +8459,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - local function push_typeargs(func: Type) + local function push_typeargs(func: FunctionType) if func.typeargs then for _, fnarg in ipairs(func.typeargs) do add_var(nil, fnarg.typearg, a_type("unresolved_typearg", { @@ -8442,7 +8469,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - local function pop_typeargs(func: Type) + local function pop_typeargs(func: FunctionType) if func.typeargs then for _, fnarg in ipairs(func.typeargs) do if st[#st][fnarg.typearg] then @@ -8452,7 +8479,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - local function fail_call(where: Where, func: Type, nargs: integer, errs: {Error}): Type + local function fail_call(where: Where, func: FunctionType | PolyType, nargs: integer, errs: {Error}): Type if errs then -- report the errors from the first match for _, err in ipairs(errs) do @@ -8461,7 +8488,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string else -- found no arity match to try local expects: {string} = {} - if func.typename == "poly" then + if func is PolyType then for _, f in ipairs(func.types) do table.insert(expects, show_arity(f)) end @@ -8477,19 +8504,22 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string error_at(where, "wrong number of arguments (given " .. nargs .. ", expects " .. table.concat(expects, " or ") .. ")") end - local f = func.typename == "poly" and func.types[1] or func + local f = func is PolyType and func.types[1] or func mark_invalid_typeargs(f) return resolve_typevars_at(where, f.rets) end - local function check_call(where: Where, where_args: {Node}, func: Type, args: TupleType, expected_rets: TupleType, typetype_funcall: boolean, is_method: boolean, argdelta: integer): Type, Type + local function check_call(where: Where, where_args: {Node}, func: Type, args: TupleType, expected_rets: TupleType, typetype_funcall: boolean, is_method: boolean, argdelta: integer): Type, FunctionType assert(type(func) == "table") assert(type(args) == "table") - if not (func.typename == "function" or func.typename == "poly") then + if not (func is FunctionType or func is PolyType) then func, is_method = resolve_for_call(func, args, is_method) + if not (func is FunctionType or func is PolyType) then + return invalid_at(where, "not a function: %s", func) + end end argdelta = is_method and -1 or argdelta or 0 @@ -8498,14 +8528,8 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string add_var(nil, "@self", type_at(where, a_typetype { def = args.tuple[1] })) end - local is_func = func.typename == "function" - local is_poly = func.typename == "poly" - if not (is_func or is_poly) then - return invalid_at(where, "not a function: %s", func) - end - local passes, n = 1, 1 - if is_poly then + if func is PolyType then passes, n = 3, #func.types end @@ -8515,7 +8539,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string for pass = 1, passes do for i = 1, n do if (not tried) or not tried[i] then - local f = is_func and func or func.types[i] + local f = func is PolyType and func.types[i] or func local fargs = f.args.tuple if f.is_method and not is_method then if args.tuple[1] and is_a(args.tuple[1], fargs[1]) then @@ -8531,13 +8555,13 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string set_min_arity(f) -- simple functions: - if (is_func and ((given <= wanted and given >= f.min_arity) or (f.args.is_va and given > wanted) or (lax and given <= wanted))) + if (passes == 1 and ((given <= wanted and given >= f.min_arity) or (f.args.is_va and given > wanted) or (lax and given <= wanted))) -- poly, pass 1: try exact arity matches first - or (is_poly and ((pass == 1 and given == wanted) + or (passes == 3 and ((pass == 1 and given == wanted) -- poly, pass 2: then try adjusting with nils to missing arguments or using '...' - or (pass == 2 and given < wanted and (lax or given >= f.min_arity)) + or (pass == 2 and given < wanted and (lax or given >= f.min_arity)) -- poly, pass 3: then finally try vararg functions - or (pass == 3 and f.args.is_va and given > wanted))) + or (pass == 3 and f.args.is_va and given > wanted))) then push_typeargs(f) @@ -8553,7 +8577,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string infer_emptytables(where, where_args, f.rets, f.rets, argdelta) end - if is_poly then + if passes == 3 then tried = tried or {} tried[i] = true pop_typeargs(f) @@ -8566,7 +8590,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return fail_call(where, func, given, first_errs) end - type_check_function_call = function(node: Node, where_args: {Node}, func: Type, args: TupleType, e1: Node, is_method: boolean, argdelta?: integer): Type, Type + type_check_function_call = function(node: Node, where_args: {Node}, func: Type, args: TupleType, e1: Node, is_method: boolean, argdelta?: integer): TupleType, Type if node.expected and node.expected.typename ~= "tuple" then node.expected = a_tuple { node.expected } end @@ -8590,8 +8614,8 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string store_type(e1.y, e1.x, f) end - if func.macroexp then - expand_macroexp(node, where_args, func.macroexp) + if f and f.macroexp then + expand_macroexp(node, where_args, f.macroexp) end return ret, f @@ -8976,7 +9000,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local erra: Type local errb: Type - if a.typename == "tupletable" and is_a(b, INTEGER) then + if a is TupleTableType and is_a(b, INTEGER) then if bnode.constnum then if bnode.constnum >= 1 and bnode.constnum <= #a.types and bnode.constnum == math.floor(bnode.constnum) then return a.types[bnode.constnum as integer] @@ -9082,7 +9106,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end old.fields = nil old.field_order = nil - elseif old.typename == "union" then + elseif old is UnionType then edit_type(old, "union") new.tk = nil table.insert(old.types, new) @@ -9289,10 +9313,10 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string -- t1 ∩ t2 local function intersect_types(t1: Type, t2: Type): Type, string - if t2.typename == "union" then + if t2 is UnionType then t1, t2 = t2, t1 end - if t1.typename == "union" then + if t1 is UnionType then local out = {} for _, t in ipairs(t1.types) do if is_a(t, t2) then @@ -9313,7 +9337,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local function resolve_if_union(t: Type): Type local rt = resolve_tuple_and_nominal(t) - if rt.typename == "union" then + if rt is UnionType then return rt end return t @@ -9326,12 +9350,12 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string t1 = resolve_if_union(t1) -- poly are not first-class, so we don't handle them here - if t1.typename ~= "union" then + if not t1 is UnionType then return t1 end t2 = resolve_if_union(t2) - local t2types = t2.types or { t2 } + local t2types = t2 is UnionType and t2.types or { t2 } for _, at in ipairs(t1.types) do local not_present = true @@ -9542,10 +9566,14 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return a_tuple { BOOLEAN } end - -- The function called by pcall/xpcall is invoked as a regular function, so we wish to avoid incorrect error messages / unnecessary warning messages associated with calling methods as functions + -- The function called by pcall/xpcall is invoked as a regular function, + -- so we wish to avoid incorrect error messages / unnecessary warning messages + -- associated with calling methods as functions local ftype = table.remove(b.tuple, 1) ftype = shallow_copy_new_type(ftype) - ftype.is_method = false + if ftype is FunctionType then + ftype.is_method = false + end local fe2: Node = {} if node.e1.tk == "xpcall" then @@ -9606,7 +9634,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local orig_t = b.tuple[1] local t = resolve_tuple_and_nominal(orig_t) - if t.typename == "tupletable" then + if t is TupleTableType then local arr_type = arraytype_from_tuple(node.e2, t) if not arr_type then return invalid_at(node.e2, "attempting ipairs on tuple that's not a valid array: %s", orig_t) @@ -9781,8 +9809,6 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end local function infer_table_literal(node: Node, children: {Type}): Type - local typ = type_at(node, a_type("emptytable", {})) - local is_record = false local is_array = false local is_map = false @@ -9795,7 +9821,15 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local seen_keys: {CheckableKey:Where} = {} + -- array, tupletable local types: {Type} + -- record + local fields: {string:Type} + local field_order: {string} + -- array, record + local elements: Type + -- map + local keys, values: Type, Type for i, child in ipairs(children) do assert(child.typename == "table_item") @@ -9813,12 +9847,12 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local uvtype = resolve_tuple(child.vtype) if ck then is_record = true - if not typ.fields then - typ.fields = {} - typ.field_order = {} + if not fields then + fields = {} + field_order = {} end - typ.fields[ck] = uvtype - table.insert(typ.field_order, ck) + fields[ck] = uvtype + table.insert(field_order, ck) elseif is_number_type(child.ktype) then is_array = true if not is_not_tuple then @@ -9832,62 +9866,66 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if i == #children and child.vtype.typename == "tuple" then -- need to expand last item in an array (e.g { 1, 2, 3, f() }) for _, c in ipairs(child.vtype.tuple) do - typ.elements = expand_type(node, typ.elements, c) + elements = expand_type(node, elements, c) types[last_array_idx] = resolve_tuple(c) last_array_idx = last_array_idx + 1 end else types[last_array_idx] = uvtype last_array_idx = last_array_idx + 1 - typ.elements = expand_type(node, typ.elements, uvtype) + elements = expand_type(node, elements, uvtype) end else -- explicit if not is_positive_int(n) then - typ.elements = expand_type(node, typ.elements, uvtype) + elements = expand_type(node, elements, uvtype) is_not_tuple = true elseif n then types[n as integer] = uvtype if n > largest_array_idx then largest_array_idx = n as integer end - typ.elements = expand_type(node, typ.elements, uvtype) + elements = expand_type(node, elements, uvtype) end end if last_array_idx > largest_array_idx then largest_array_idx = last_array_idx end - if not typ.elements then + if not elements then is_array = false end else is_map = true child.ktype.tk = nil - typ.keys = expand_type(node, typ.keys, child.ktype) - typ.values = expand_type(node, typ.values, uvtype) + keys = expand_type(node, keys, child.ktype) + values = expand_type(node, values, uvtype) end end + local t: Type + if is_array and is_map then - typ.typename = "map" - typ.keys = expand_type(node, typ.keys, INTEGER) - typ.values = expand_type(node, typ.values, typ.elements) - typ.elements = nil error_at(node, "cannot determine type of table literal") + t = a_map( + expand_type(node, keys, INTEGER), + expand_type(node, values, elements) + ) elseif is_record and is_array then - typ.typename = "record" - typ.interface_list = { - type_at(node, an_array(typ.elements)) - } + t = a_type("record", { + fields = fields, + field_order = field_order, + elements = elements, + interface_list = { + type_at(node, an_array(elements)) + } + }) -- TODO adopt logic from is_array below when we accept tupletable as an interface elseif is_record and is_map then - if typ.keys.typename == "string" then - typ.typename = "map" - for _, ftype in fields_of(typ) do - typ.values = expand_type(node, typ.values, ftype) + if keys.typename == "string" then + for _, fname in ipairs(field_order) do + values = expand_type(node, values, fields[fname]) end - typ.fields = nil - typ.field_order = nil + t = a_map(keys, values) else error_at(node, "cannot determine type of table literal") end @@ -9906,28 +9944,33 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end if pure_array then - typ.typename = "array" - typ.consttypes = types - assert(typ.elements) - typ.inferred_len = largest_array_idx - 1 + t = an_array(elements) + t.consttypes = types + t.inferred_len = largest_array_idx - 1 else - typ.typename = "tupletable" - typ.elements = nil - typ.types = types + t = a_type("tupletable", {}) as TupleTableType + t.types = types end elseif is_record then - typ.typename = "record" + t = a_type("record", { + fields = fields, + field_order = field_order, + }) elseif is_map then - typ.typename = "map" + t = a_map(keys, values) elseif is_tuple then - typ.typename = "tupletable" - typ.types = types + t = a_type("tupletable", {}) as TupleTableType + t.types = types if not types or #types == 0 then error_at(node, "cannot determine type of tuple elements") end end - return typ + if not t then + t = a_type("emptytable", {}) + end + + return type_at(node, t) end local function infer_negation_of_if_blocks(where: Where, ifnode: Node, n: integer) @@ -9959,14 +10002,18 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string ok = assert_is_a(node.vars[i], infertype, decltype, context_name[node.kind], name) end else - if infertype and infertype.typename == "unresolvable_typearg" then - error_at(node.vars[i], "cannot infer declaration type; an explicit type annotation is necessary") - ok = false - infertype = INVALID - elseif infertype and infertype.is_method then - -- If we assign a method to a variable, e.g local myfunc = myobj.dothing, the variable should not be treated as a method - infertype = shallow_copy_new_type(infertype) - infertype.is_method = false + if infertype then + if infertype.typename == "unresolvable_typearg" then + error_at(node.vars[i], "cannot infer declaration type; an explicit type annotation is necessary") + ok = false + infertype = INVALID + elseif infertype is FunctionType and infertype.is_method then + -- If we assign a method to a variable, e.g: + -- `local myfunc = myobj.dothing`, + -- the variable should not be treated as a method + infertype = shallow_copy_new_type(infertype) + infertype.is_method = false + end end end @@ -10218,7 +10265,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local where = node.exps[i] or node.exps local rt = resolve_tuple_and_nominal(t) - if rt.typename ~= "enum" and (t.typename ~= "nominal" or rt.typename == "union") and not same_type(t, infertype) then + if rt.typename ~= "enum" and (t.typename ~= "nominal" or rt is UnionType) and not same_type(t, infertype) then t = infer_at(where, infertype) add_var(where, var.tk, t, "const", "narrowed_declaration") end @@ -10275,11 +10322,11 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if rval and rvar then -- assigning a function - if rval.typename == "function" then + if rval is FunctionType then widen_all_unions() end - if varname and (rvar.typename == "union" or rvar.typename == "interface") then + if varname and (rvar is UnionType or rvar.typename == "interface") then -- narrow unions and interfaces add_var(varnode, varname, rval, nil, "narrow") end @@ -10398,12 +10445,12 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string } local exp1type = resolve_for_call(exptypes[1], args, false) - if exp1type.typename == "poly" then + if exp1type is PolyType then local _: Type _, exp1type = type_check_function_call(exp1, {node.exps[2], node.exps[3]}, exp1type, args, exp1, false, 0) end - if exp1type.typename == "function" then + if exp1type is FunctionType then -- TODO: check that exp1's arguments match with (optional self, explicit iterator, state) local last: Type local rets = exp1type.rets @@ -10539,7 +10586,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string decltype = resolve_typetype(resolve_tuple_and_nominal(decltype.constraint)) end - if decltype.typename == "tupletable" then + if decltype is TupleTableType then for _, child in ipairs(node) do local n = child.key.constnum if n and is_positive_int(n) then @@ -10583,7 +10630,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string decltype = resolve_tuple_and_nominal(constraint) end - if decltype.typename == "union" then + if decltype is UnionType then local single_table_type: Type local single_table_rt: Type @@ -10614,7 +10661,6 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local is_record = is_record_type(decltype) local is_array = is_array_type(decltype) - local is_tupletable = decltype.typename == "tupletable" local is_map = decltype.typename == "map" local force_array: Type = nil @@ -10642,7 +10688,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string assert_is_a(node[i], cvtype, df, "in record field", ck) end end - elseif is_tupletable and is_number_type(child.ktype) then + elseif decltype is TupleTableType and is_number_type(child.ktype) then local dt = decltype.types[n as integer] if not n then error_at(node[i], in_context(node.expected_context, "unknown index in tuple %s"), decltype) @@ -10711,8 +10757,10 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string vtype = node.itemtype assert_is_a(node.value, children[2], node.itemtype, "in table item") end - if vtype.is_method then - -- If we assign a method to a table item, e.g local a = { myfunc = myobj.dothing }, the table item should not be treated as a method + if vtype is FunctionType and vtype.is_method then + -- If we assign a method to a table item, e.g. + -- `local a = { myfunc = myobj.dothing }` + -- the table item should not be treated as a method vtype = shallow_copy_new_type(vtype) vtype.is_method = false end @@ -10788,7 +10836,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if node.implicit_global_function then local typ = find_var_type(node.name.tk) if typ then - if typ.typename == "function" then + if typ is FunctionType then node.is_predeclared_local_function = true elseif not lax then error_at(node, "cannot declare function: type of " .. node.name.tk .. " is %s", typ) @@ -10889,7 +10937,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local ok, err = same_type(fn_type, rfieldtype) if not ok then - if rfieldtype.typename == "poly" then + if rfieldtype is PolyType then add_errs_prefixing(node, err, errors, "type signature does not match declaration: field has multiple function definitions (such polymorphic declarations are intended for Lua module interoperability)") return end @@ -11006,7 +11054,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string elseif node.op.op == "or" then apply_facts(node, facts_not(node, node.e1.known)) elseif node.op.op == "@funcall" then - if e1type.typename == "function" then + if e1type is FunctionType then local argdelta = (node.e1.op and node.e1.op.op == ":") and -1 or 0 if node.expected then is_a(e1type.rets, node.expected) @@ -11165,11 +11213,11 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string node.known = nil t = (ra is EnumType and ra or rb) - elseif expected and expected.typename == "union" then + elseif expected and expected is UnionType then -- must be checked after string/enum above node.known = facts_or(node, node.e1.known, node.e2.known) local u = unite({ra, rb}, true) - if u.typename == "union" then + if u is UnionType then u = validate_union(node, u) end t = u @@ -11207,7 +11255,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if not (rb.tk and ra.enumset[unquote(rb.tk)]) then return invalid_at(node, "%s is not a member of %s", b, a) end - elseif ra.typename == "tupletable" and rb.typename == "tupletable" and #ra.types ~= #rb.types then + elseif ra is TupleTableType and rb is TupleTableType and #ra.types ~= #rb.types then return invalid_at(node, "tuples are not the same size") elseif is_a(b, a) or a.typename == "typevar" then if node.op.op == "==" and node.e1.kind == "variable" then @@ -11228,7 +11276,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if node.op.arity == 1 and unop_types[node.op.op] then a = ra - if a.typename == "union" then + if a is UnionType then a = unite(a.types, true) -- squash unions of string constants end @@ -11287,10 +11335,10 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string a = ra b = rb - if a.typename == "union" then + if a is UnionType then a = unite(a.types, true) -- squash unions of string constants end - if b.typename == "union" then + if b is UnionType then b = unite(b.types, true) -- squash unions of string constants end @@ -11307,8 +11355,11 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if not t then error_at(node, "cannot use operator '" .. node.op.op:gsub("%%", "%%%%") .. "' for types %s and %s", resolve_tuple(orig_a), resolve_tuple(orig_b)) t = INVALID - if node.op.op == "or" and is_valid_union(unite({orig_a, orig_b})) then - add_warning("hint", node, "if a union type was intended, consider declaring it explicitly") + if node.op.op == "or" then + local u = unite({orig_a, orig_b}) + if u is UnionType and is_valid_union(u) then + add_warning("hint", node, "if a union type was intended, consider declaring it explicitly") + end end end end @@ -11579,28 +11630,31 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string typ.elements = children[i] i = i + 1 end - local fmacros: {Type} + local fmacros: {FunctionType} for name, _ in fields_of(typ) do local ftype = children[i] - if ftype.macroexp then - fmacros = fmacros or {} - table.insert(fmacros, ftype) - end - if ftype.typename == "function" and ftype.is_method then - local fargs = ftype.args.tuple - if fargs[1] and fargs[1].is_self then - local record_name = typ.names and typ.names[1] - if record_name then - local selfarg = fargs[1] - if selfarg.tk ~= record_name or (typ.typeargs and not selfarg.typevals) then - ftype.is_method = false - selfarg.is_self = false - elseif typ.typeargs then - for j=1,#typ.typeargs do - if (not selfarg.typevals[j]) or selfarg.typevals[j].tk ~= typ.typeargs[j].typearg then - ftype.is_method = false - selfarg.is_self = false - break + if ftype is FunctionType then + if ftype.macroexp then + fmacros = fmacros or {} + table.insert(fmacros, ftype) + end + + if ftype.is_method then + local fargs = ftype.args.tuple + if fargs[1] and fargs[1].is_self then + local record_name = typ.names and typ.names[1] + if record_name then + local selfarg = fargs[1] + if selfarg.tk ~= record_name or (typ.typeargs and not selfarg.typevals) then + ftype.is_method = false + selfarg.is_self = false + elseif typ.typeargs then + for j=1,#typ.typeargs do + if (not selfarg.typevals[j]) or selfarg.typevals[j].tk ~= typ.typeargs[j].typearg then + ftype.is_method = false + selfarg.is_self = false + break + end end end end @@ -11613,9 +11667,11 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end for name, _ in fields_of(typ, "meta") do local ftype = children[i] - if ftype.macroexp then - fmacros = fmacros or {} - table.insert(fmacros, ftype) + if ftype is FunctionType then + if ftype.macroexp then + fmacros = fmacros or {} + table.insert(fmacros, ftype) + end end typ.meta_fields[name] = ftype i = i + 1 @@ -11691,7 +11747,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end, }, ["union"] = { - after = function(typ: Type, _children: {Type}): Type + after = function(typ: UnionType, _children: {Type}): Type return (validate_union(typ, typ)) end }, From 318c7b7c85f123b5856525726971482ea321548c Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Fri, 29 Dec 2023 17:50:52 -0500 Subject: [PATCH 067/224] EmptyTableType, UnresolvedEmptyTableValueType --- tl.lua | 20 +++++++++++++++++--- tl.tl | 52 +++++++++++++++++++++++++++++++++------------------- 2 files changed, 50 insertions(+), 22 deletions(-) diff --git a/tl.lua b/tl.lua index 5575ee3e5..65e522550 100644 --- a/tl.lua +++ b/tl.lua @@ -1269,6 +1269,18 @@ local table_types = { + + + + + + + + + + + + @@ -8160,7 +8172,7 @@ a.types[i], b.types[i]), } elseif t2.typename == "emptytable" then if is_lua_table_type(t1) then infer_emptytable(t2, infer_at(where, t1)) - elseif t1.typename ~= "emptytable" then + elseif not (t1.typename == "emptytable") then error_at(where, context .. ": " .. (name and (name .. ": ") or "") .. "assigning %s to a variable declared with {}", t1) return false end @@ -8381,7 +8393,7 @@ a.types[i], b.types[i]), } for i = 1, n_xs do local x = xt[i] - if x.typename == "emptytable" or x.typename == "unresolved_emptytable_value" then + if x.typename == "emptytable" then local y = yt[i] or (ys.is_va and yt[n_ys]) if y then local w = wheres and wheres[i + delta] or where @@ -9023,7 +9035,9 @@ a.types[i], b.types[i]), } end if is_a(orig_b, a.keys) then - return type_at(anode, a_type("unresolved_emptytable_value", { emptytable_type = a })) + return type_at(anode, a_type("unresolved_emptytable_value", { + emptytable_type = a, + })) end errm, erra, errb = "inconsistent index type: got %s, expected %s (type of keys inferred at " .. diff --git a/tl.tl b/tl.tl index cda7a300e..9347c91d3 100644 --- a/tl.tl +++ b/tl.tl @@ -1079,6 +1079,8 @@ local interface Type yend: integer xend: integer + inferred_at: Where + -- Lua compatibilty needs_compat: boolean @@ -1140,12 +1142,6 @@ local interface Type ktype: Type vtype: Type - -- emptytable - declared_at: Node - assigned_to: string - inferred_at: Where - emptytable_type: Type - -- unresolved items labels: {string:{Node}} nominals: {string:{Type}} @@ -1153,6 +1149,22 @@ local interface Type narrows: {string:boolean} end +local record EmptyTableType + is Type + where self.typename == "emptytable" + + declared_at: Node + assigned_to: string + keys: Type +end + +local record UnresolvedEmptyTableValueType + is Type + where self.typename == "unresolved_emptytable_value" + + emptytable_type: EmptyTableType +end + local record FunctionType is Type where self.typename == "function" @@ -5382,7 +5394,7 @@ local function show_type_base(t: Type, short: boolean, seen: {Type:string}): str table.insert(out, show(v)) end return table.concat(out, " | ") - elseif t.typename == "emptytable" then + elseif t is EmptyTableType then return "{}" elseif t.typename == "map" then return "{" .. show(t.keys) .. " : " .. show(t.values) .. "}" @@ -6780,7 +6792,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return true, copy end - local function infer_emptytable(emptytable: Type, fresh_t: Type) + local function infer_emptytable(emptytable: EmptyTableType, fresh_t: Type) local is_global = (emptytable.declared_at and emptytable.declared_at.kind == "global_declaration") local nst = is_global and 1 or #st for i = nst, 1, -1 do @@ -8150,17 +8162,17 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string -- some flow-based inference if t1.typename == "nil" then return true - elseif t2.typename == "unresolved_emptytable_value" then + elseif t2 is UnresolvedEmptyTableValueType then if is_number_type(t2.emptytable_type.keys) then -- ideally integer only infer_emptytable(t2.emptytable_type, infer_at(where, an_array(t1))) else infer_emptytable(t2.emptytable_type, infer_at(where, a_map(t2.emptytable_type.keys, t1))) end return true - elseif t2.typename == "emptytable" then + elseif t2 is EmptyTableType then if is_lua_table_type(t1) then infer_emptytable(t2, infer_at(where, t1)) - elseif t1.typename ~= "emptytable" then + elseif not t1 is EmptyTableType then error_at(where, context .. ": " .. (name and (name .. ": ") or "") .. "assigning %s to a variable declared with {}", t1) return false end @@ -8381,7 +8393,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string -- resolve inference of emptytables used as arguments or returns for i = 1, n_xs do local x = xt[i] - if x.typename == "emptytable" or x.typename == "unresolved_emptytable_value" then + if x is EmptyTableType then local y = yt[i] or (ys.is_va and yt[n_ys]) if y then -- y may not be present when inferring returns local w = wheres and wheres[i + delta] or where -- for self, a + argdelta is 0 @@ -8708,7 +8720,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string else return nil, "invalid key '" .. key .. "' in type %s" end - elseif tbl.typename == "emptytable" or is_unknown(tbl) then + elseif tbl is EmptyTableType or is_unknown(tbl) then if lax then return INVALID end @@ -9017,13 +9029,15 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end elseif is_array_type(a) and is_a(b, INTEGER) then return a.elements - elseif a.typename == "emptytable" then + elseif a is EmptyTableType then if a.keys == nil then a.keys = infer_at(anode, resolve_tuple(orig_b)) end if is_a(orig_b, a.keys) then - return type_at(anode, a_type("unresolved_emptytable_value", { emptytable_type = a })) + return type_at(anode, a_type("unresolved_emptytable_value", { + emptytable_type = a + } as UnresolvedEmptyTableValueType)) end errm, erra, errb = "inconsistent index type: got %s, expected %s (type of keys inferred at " @@ -9640,7 +9654,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return invalid_at(node.e2, "attempting ipairs on tuple that's not a valid array: %s", orig_t) end elseif not is_array_type(t) then - if not (lax and (is_unknown(t) or t.typename == "emptytable")) then + if not (lax and (is_unknown(t) or t is EmptyTableType)) then return invalid_at(node.e2, "attempting ipairs on something that's not an array: %s", orig_t) end end @@ -10050,7 +10064,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local t = decltype or infertype if t == nil then t = missing_initializer(node, i, name) - elseif t.typename == "emptytable" then + elseif t is EmptyTableType then t.declared_at = node t.assigned_to = name end @@ -10890,7 +10904,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local rtype = resolve_tuple_and_nominal(resolve_typetype(children[1])) - if rtype.typename == "emptytable" then + if rtype is EmptyTableType then edit_type(rtype, "record") rtype.fields = {} rtype.field_order = {} @@ -11204,7 +11218,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string node.known = nil t = a - elseif is_lua_table_type(ra) and b.typename == "emptytable" then + elseif is_lua_table_type(ra) and b is EmptyTableType then node.known = nil t = a From 30650c3a125e5f4d4fc128607078ac6c0dd8b7e6 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Fri, 29 Dec 2023 18:15:13 -0500 Subject: [PATCH 068/224] MapType (includes fix in expand_type) --- tl.lua | 20 +++++++++------- tl.tl | 74 +++++++++++++++++++++++++++++++--------------------------- 2 files changed, 51 insertions(+), 43 deletions(-) diff --git a/tl.lua b/tl.lua index 65e522550..1b9592303 100644 --- a/tl.lua +++ b/tl.lua @@ -1288,6 +1288,11 @@ local table_types = { + + + + + @@ -3745,10 +3750,8 @@ local function recurse_type(ast, visit) if ast.def then table.insert(xs, recurse_type(ast.def, visit)) end - if ast.keys then + if ast.typename == "map" then table.insert(xs, recurse_type(ast.keys, visit)) - end - if ast.values then table.insert(xs, recurse_type(ast.values, visit)) end if ast.elements then @@ -6741,6 +6744,7 @@ tl.type_check = function(ast, opts) end end elseif t.typename == "map" then + assert(copy.typename == "map") copy.keys, same = resolve(t.keys, same) copy.values, same = resolve(t.values, same) elseif t.typename == "union" then @@ -9103,6 +9107,7 @@ a.types[i], b.types[i]), } end elseif is_record_type(old) and is_record_type(new) then edit_type(old, "map") + assert(old.typename == "map") old.keys = STRING for _, ftype in fields_of(old) do if not old.values then @@ -9113,9 +9118,9 @@ a.types[i], b.types[i]), } end for _, ftype in fields_of(new) do if not old.values then - new.values = ftype + old.values = ftype else - new.values = expand_type(where, old.values, ftype) + old.values = expand_type(where, old.values, ftype) end end old.fields = nil @@ -10675,7 +10680,6 @@ expand_type(node, values, elements) }) local is_record = is_record_type(decltype) local is_array = is_array_type(decltype) - local is_map = decltype.typename == "map" local force_array = nil @@ -10721,12 +10725,12 @@ expand_type(node, values, elements) }) assert_is_a(node[i], cvtype, decltype.elements, in_context(node.expected_context, "expected an array"), "at index " .. tostring(n)) end elseif node[i].key_parsed == "implicit" then - if is_map then + if decltype.typename == "map" then assert_is_a(node[i], INTEGER, decltype.keys, in_context(node.expected_context, "in map key")) assert_is_a(node[i], cvtype, decltype.values, in_context(node.expected_context, "in map value")) end force_array = expand_type(node[i], force_array, child.vtype) - elseif is_map then + elseif decltype.typename == "map" then force_array = nil assert_is_a(node[i], child.ktype, decltype.keys, in_context(node.expected_context, "in map key")) assert_is_a(node[i], cvtype, decltype.values, in_context(node.expected_context, "in map value")) diff --git a/tl.tl b/tl.tl index 9347c91d3..298b28ac2 100644 --- a/tl.tl +++ b/tl.tl @@ -1081,6 +1081,9 @@ local interface Type inferred_at: Where + is_total: boolean + missing: {string} + -- Lua compatibilty needs_compat: boolean @@ -1097,12 +1100,6 @@ local interface Type closed: boolean is_abstract: boolean - -- map - keys: Type - values: Type - is_total: boolean - missing: {string} - -- records interface_list: {Type} interfaces_expanded: boolean @@ -1149,6 +1146,14 @@ local interface Type narrows: {string:boolean} end +local record MapType + is Type -- TODO TotalType + where self.typename == "map" + + keys: Type + values: Type +end + local record EmptyTableType is Type where self.typename == "emptytable" @@ -1657,13 +1662,13 @@ local macroexp an_array(t: Type): Type return a_type("array", { elements = t }) end -local macroexp a_map(k: Type, v: Type): Type - return a_type("map", { keys = k, values = v }) +local macroexp a_map(k: Type, v: Type): MapType + return a_type("map", { keys = k, values = v } as MapType) end local NIL = a_type("nil", {}) local ANY = a_type("any", {}) -local TABLE = a_type("map", { keys = ANY, values = ANY }) +local TABLE = a_map(ANY, ANY) local NUMBER = a_type("number", {}) local STRING = a_type("string", {}) local THREAD = a_type("thread", {}) @@ -2042,7 +2047,7 @@ local function parse_base_type(ps: ParseState, i: integer): integer, Type, integ i = verify_tk(ps, i, "}") return i, decl elseif ps.tokens[i].tk == ":" then - local decl = new_type(ps, istart, "map") + local decl = new_type(ps, istart, "map") as MapType i = i + 1 decl.keys = t i, decl.values = parse_type(ps, i) @@ -2059,7 +2064,7 @@ local function parse_base_type(ps: ParseState, i: integer): integer, Type, integ elseif tk == "nil" then return i + 1, simple_types["nil"] elseif tk == "table" then - local typ = new_type(ps, i, "map") + local typ = new_type(ps, i, "map") as MapType typ.keys = ANY typ.values = ANY return i + 1, typ @@ -3745,10 +3750,8 @@ local function recurse_type(ast: Type, visit: Visitor): T if ast.def then table.insert(xs, recurse_type(ast.def, visit)) end - if ast.keys then + if ast is MapType then table.insert(xs, recurse_type(ast.keys, visit)) - end - if ast.values then table.insert(xs, recurse_type(ast.values, visit)) end if ast.elements then @@ -5007,7 +5010,7 @@ get_typenum = function(trenv: TypeReportEnv, t: Type): integer ti.elements = get_typenum(trenv, rt.elements) end - if rt.typename == "map" then + if rt is MapType then ti.keys = get_typenum(trenv, rt.keys) ti.values = get_typenum(trenv, rt.values) elseif rt is EnumType then @@ -5396,7 +5399,7 @@ local function show_type_base(t: Type, short: boolean, seen: {Type:string}): str return table.concat(out, " | ") elseif t is EmptyTableType then return "{}" - elseif t.typename == "map" then + elseif t is MapType then return "{" .. show(t.keys) .. " : " .. show(t.values) .. "}" elseif t.typename == "array" then return "{" .. show(t.elements) .. "}" @@ -6740,7 +6743,8 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string copy.meta_fields[k], same = resolve(t.meta_fields[k], same) end end - elseif t.typename == "map" then + elseif t is MapType then + assert(copy is MapType) copy.keys, same = resolve(t.keys, same) copy.values, same = resolve(t.values, same) elseif t is UnionType then @@ -7711,7 +7715,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end, }, ["map"] = { - ["map"] = function(a: Type, b: Type): boolean, {Error} + ["map"] = function(a: MapType, b: MapType): boolean, {Error} return compare_map(a.keys, b.keys, a.values, b.values, true) end, }, @@ -7916,7 +7920,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end return true end, - ["map"] = function(a: TupleTableType, b: Type): boolean, {Error} + ["map"] = function(a: TupleTableType, b: MapType): boolean, {Error} local aa = arraytype_from_tuple(a.inferred_at, a) if not aa then return false, { Err(a, "Unable to convert tuple %s to map", a) } @@ -7937,7 +7941,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end, ["array"] = subtype_array, - ["map"] = function(a: Type, b: Type): boolean, {Error} + ["map"] = function(a: Type, b: MapType): boolean, {Error} if not is_a(b.keys, STRING) then return false, { Err(a, "can't match a record to a map with non-string keys") } end @@ -7967,7 +7971,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return subtype_array(a, b) end end, - ["map"] = function(a: Type, b: Type): boolean, {Error} + ["map"] = function(a: Type, b: MapType): boolean, {Error} return compare_map(INTEGER, b.keys, a.elements, b.values) end, ["tupletable"] = function(a: Type, b: TupleTableType): boolean, {Error} @@ -7987,10 +7991,10 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end, }, ["map"] = { - ["map"] = function(a: Type, b: Type): boolean, {Error} + ["map"] = function(a: MapType, b: MapType): boolean, {Error} return compare_map(a.keys, b.keys, a.values, b.values) end, - ["array"] = function(a: Type, b: Type): boolean, {Error} + ["array"] = function(a: MapType, b: Type): boolean, {Error} return compare_map(a.keys, INTEGER, a.values, b.elements) end, }, @@ -9044,7 +9048,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string .. a.keys.inferred_at.filename .. ":" .. a.keys.inferred_at.y .. ":" .. a.keys.inferred_at.x .. ": )", orig_b, a.keys - elseif a.typename == "map" then + elseif a is MapType then if is_a(orig_b, a.keys) then return a.values end @@ -9092,7 +9096,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return new else if not is_a(new, old) then - if old.typename == "map" and is_record_type(new) then + if old is MapType and is_record_type(new) then if old.keys.typename == "string" then for _, ftype in fields_of(new) do old.values = expand_type(where, old.values, ftype) @@ -9103,6 +9107,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end elseif is_record_type(old) and is_record_type(new) then edit_type(old, "map") + assert(old is MapType) old.keys = STRING for _, ftype in fields_of(old) do if not old.values then @@ -9113,9 +9118,9 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end for _, ftype in fields_of(new) do if not old.values then - new.values = ftype + old.values = ftype else - new.values = expand_type(where, old.values, ftype) + old.values = expand_type(where, old.values, ftype) end end old.fields = nil @@ -10112,7 +10117,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return is_total, missing end - local function total_map_check(t: Type, seen_keys: {CheckableKey:Where}): boolean, {string} + local function total_map_check(t: MapType, seen_keys: {CheckableKey:Where}): boolean, {string} local k = resolve_tuple_and_nominal(t.keys) local is_total = true local missing: {string} @@ -10613,7 +10618,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string child.value.expected = decltype.elements end end - elseif decltype.typename == "map" then + elseif decltype is MapType then for _, child in ipairs(node) do child.key.expected = decltype.keys child.value.expected = decltype.values @@ -10675,7 +10680,6 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local is_record = is_record_type(decltype) local is_array = is_array_type(decltype) - local is_map = decltype.typename == "map" local force_array: Type = nil @@ -10721,12 +10725,12 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string assert_is_a(node[i], cvtype, decltype.elements, in_context(node.expected_context, "expected an array"), "at index " .. tostring(n)) end elseif node[i].key_parsed == "implicit" then - if is_map then + if decltype is MapType then assert_is_a(node[i], INTEGER, decltype.keys, in_context(node.expected_context, "in map key")) assert_is_a(node[i], cvtype, decltype.values, in_context(node.expected_context, "in map value")) end force_array = expand_type(node[i], force_array, child.vtype) - elseif is_map then + elseif decltype is MapType then force_array = nil assert_is_a(node[i], child.ktype, decltype.keys, in_context(node.expected_context, "in map key")) assert_is_a(node[i], cvtype, decltype.values, in_context(node.expected_context, "in map value")) @@ -10751,7 +10755,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if decltype.typename == "record" then t.is_total, t.missing = total_record_check(decltype, seen_keys) - elseif decltype.typename == "map" then + elseif decltype is MapType then t.is_total, t.missing = total_map_check(decltype, seen_keys) end @@ -11089,7 +11093,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end elseif node.op.op == "@index" then - if e1type.typename == "map" then + if e1type is MapType then node.e2.expected = e1type.keys end end @@ -11316,7 +11320,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - if a.typename == "map" then + if a is MapType then if a.keys.typename == "number" or a.keys.typename == "integer" then add_warning("hint", node, "using the '#' operator on a map with numeric key type may produce unexpected results") else From b6534e5cf6c64b9b5bd743e3ed26f6682a11bf9b Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Fri, 29 Dec 2023 18:39:19 -0500 Subject: [PATCH 069/224] TypeVarType, TypeArgType, UnresolvedTypeArgType, UnresolvableTypeArgType --- tl.lua | 48 ++++++++++++++++--- tl.tl | 144 +++++++++++++++++++++++++++++++++++---------------------- 2 files changed, 130 insertions(+), 62 deletions(-) diff --git a/tl.lua b/tl.lua index 1b9592303..034f7deae 100644 --- a/tl.lua +++ b/tl.lua @@ -1278,6 +1278,30 @@ local table_types = { + + + + + + + + + + + + + + + + + + + + + + + + @@ -3792,8 +3816,10 @@ local function recurse_type(ast, visit) if ast.vtype then table.insert(xs, recurse_type(ast.vtype, visit)) end - if ast.constraint then - table.insert(xs, recurse_type(ast.constraint, visit)) + if ast.typename == "typearg" then + if ast.constraint then + table.insert(xs, recurse_type(ast.constraint, visit)) + end end local ret @@ -6650,7 +6676,7 @@ tl.type_check = function(ast, opts) if t.typename == "typevar" then local rt = fn_var(t) if rt then - resolved[orig_t.typevar] = true + resolved[t.typevar] = true if no_nested_types[rt.typename] or (rt.typename == "nominal" and not rt.typevals) then seen[orig_t] = rt return rt, false @@ -6681,14 +6707,17 @@ tl.type_check = function(ast, opts) if fn_arg then copy = fn_arg(t) else + assert(copy.typename == "typearg") copy.typearg = t.typearg if t.constraint then copy.constraint, same = resolve(t.constraint, same) end end elseif t.typename == "unresolvable_typearg" then + assert(copy.typename == "unresolvable_typearg") copy.typearg = t.typearg elseif t.typename == "typevar" then + assert(copy.typename == "typevar") copy.typevar = t.typevar if t.constraint then copy.constraint, same = resolve(t.constraint, same) @@ -7646,7 +7675,7 @@ tl.type_check = function(ast, opts) if not ok then return false, errs end - if r.typevar == typevar then + if r.typename == "typevar" and r.typevar == typevar then return true end add_var(nil, typevar, r) @@ -8380,7 +8409,9 @@ a.types[i], b.types[i]), } if a.constraint then add_var(nil, a.typearg, a.constraint) else - add_var(nil, a.typearg, lax and UNKNOWN or a_type("unresolvable_typearg", { typearg = a.typearg })) + add_var(nil, a.typearg, lax and UNKNOWN or a_type("unresolvable_typearg", { + typearg = a.typearg, + })) end end end @@ -10898,7 +10929,9 @@ expand_type(node, values, elements) }) if rtype.typeargs then for _, typ in ipairs(rtype.typeargs) do - add_var(nil, typ.typearg, type_at(typ, a_type("typearg", { typearg = typ.typearg }))) + add_var(nil, typ.typearg, type_at(typ, a_type("typearg", { + typearg = typ.typearg, + }))) end end end, @@ -11743,8 +11776,9 @@ expand_type(node, values, elements) }) if t then if t.typename == "typearg" then - edit_type(typ, "typevar") typ.names = nil + edit_type(typ, "typevar") + assert(typ.typename == "typevar") typ.typevar = t.typearg typ.constraint = t.constraint else diff --git a/tl.tl b/tl.tl index 298b28ac2..ee2d94c8e 100644 --- a/tl.tl +++ b/tl.tl @@ -1103,7 +1103,7 @@ local interface Type -- records interface_list: {Type} interfaces_expanded: boolean - typeargs: {Type} + typeargs: {TypeArgType} fields: {string: Type} field_order: {string} meta_fields: {string: Type} @@ -1127,13 +1127,6 @@ local interface Type found: Type -- type is found but typeargs are not resolved resolved: Type -- type is found and typeargs are resolved - -- typevar - typevar: string - - -- typearg - typearg: string - constraint: Type - -- table items kname: string ktype: Type @@ -1146,6 +1139,37 @@ local interface Type narrows: {string:boolean} end +local record TypeArgType + is Type + where self.typename == "typearg" + + typearg: string + constraint: Type +end + +local record UnresolvedTypeArgType + is Type + where self.typename == "unresolved_typearg" + + typearg: string + constraint: Type +end + +local record UnresolvableTypeArgType + is Type + where self.typename == "unresolvable_typearg" + + typearg: string +end + +local record TypeVarType + is Type + where self.typename == "typevar" + + typevar: string + constraint: Type +end + local record MapType is Type -- TODO TotalType where self.typename == "map" @@ -1420,7 +1444,7 @@ local record Node value: Node key_parsed: KeyParsed - typeargs: {Type} + typeargs: {TypeArgType} args: Node rets: Type body: Node @@ -1907,11 +1931,11 @@ local function parse_trying_list(ps: ParseState, i: integer, list: {T}, parse return i, list end -local function parse_anglebracket_list(ps: ParseState, i: integer, parse_item: ParseItem): integer, {Type} +local function parse_anglebracket_list(ps: ParseState, i: integer, parse_item: ParseItem): integer, {T} if ps.tokens[i+1].tk == ">" then return fail(ps, i+1, "type argument list cannot be empty") end - local types: {Type} = {} + local types: {T} = {} i = verify_tk(ps, i, "<") i = parse_list(ps, i, types, { [">"] = true, [">>"] = true, }, "sep", parse_item) if ps.tokens[i].tk == ">" then @@ -1925,7 +1949,7 @@ local function parse_anglebracket_list(ps: ParseState, i: integer, parse_item: P return i, types end -local function parse_typearg(ps: ParseState, i: integer): integer, Type, integer +local function parse_typearg(ps: ParseState, i: integer): integer, TypeArgType, integer local name = ps.tokens[i].tk local constraint: Type i = verify_kind(ps, i, "identifier") @@ -1938,7 +1962,7 @@ local function parse_typearg(ps: ParseState, i: integer): integer, Type, integer x = ps.tokens[i - 2].x, typearg = name, constraint = constraint, - }) + } as TypeArgType) end local function parse_return_types(ps: ParseState, i: integer): integer, Type @@ -3792,8 +3816,10 @@ local function recurse_type(ast: Type, visit: Visitor): T if ast.vtype then table.insert(xs, recurse_type(ast.vtype, visit)) end - if ast.constraint then - table.insert(xs, recurse_type(ast.constraint, visit)) + if ast is TypeArgType then + if ast.constraint then + table.insert(xs, recurse_type(ast.constraint, visit)) + end end local ret: T @@ -5455,11 +5481,11 @@ local function show_type_base(t: Type, short: boolean, seen: {Type:string}): str return t.typename .. (t.tk and " " .. t.tk or "") end - elseif t.typename == "typevar" then + elseif t is TypeVarType then return display_typevar(t.typevar) - elseif t.typename == "typearg" then + elseif t is TypeArgType then return display_typevar(t.typearg) - elseif t.typename == "unresolvable_typearg" then + elseif t is UnresolvableTypeArgType then return display_typevar(t.typearg) .. " (unresolved generic)" elseif is_unknown(t) then return "" @@ -5733,15 +5759,15 @@ local function init_globals(lax: boolean): {string:Variable}, {string:Type} return t end - local function a_gfunction(n: integer, f: function(...: Type): (FunctionType), typename?: TypeName): FunctionType + local function a_gfunction(n: integer, f: function(...: TypeVarType): (FunctionType), typename?: TypeName): FunctionType local typevars = {} local typeargs = {} local c = string.byte("A") - 1 fresh_typevar_ctr = fresh_typevar_ctr + 1 for i = 1, n do local name = string.char(c + i) .. "@" .. fresh_typevar_ctr - typevars[i] = a_type("typevar", { typevar = name }) - typeargs[i] = a_type("typearg", { typearg = name }) + typevars[i] = a_type("typevar", { typevar = name } as TypeVarType) + typeargs[i] = a_type("typearg", { typearg = name } as TypeArgType) end local t = f(table.unpack(typevars)) t.typeargs = typeargs @@ -6350,18 +6376,18 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local type ResolveType = function(Type): Type local resolve_typevars: function (typ: Type, fn_var?: ResolveType, fn_arg?: ResolveType): boolean, Type, {Error} - local function fresh_typevar(t: Type): Type, Type, boolean + local function fresh_typevar(t: TypeVarType): Type, Type, boolean return a_type("typevar", { typevar = (t.typevar:gsub("@.*", "")) .. "@" .. fresh_typevar_ctr, constraint = t.constraint, - }) + } as TypeVarType) end - local function fresh_typearg(t: Type): Type + local function fresh_typearg(t: TypeArgType): Type return a_type("typearg", { typearg = (t.typearg:gsub("@.*", "")) .. "@" .. fresh_typevar_ctr, constraint = t.constraint, - }) + } as TypeArgType) end local function ensure_fresh_typeargs(t: Type): Type @@ -6380,7 +6406,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local var = find_var(name, use) if var then local t = var.t - if t.typename == "unresolved_typearg" then + if t is UnresolvedTypeArgType then return nil, nil, t.constraint end t = ensure_fresh_typeargs(t) @@ -6464,7 +6490,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return nil end end - if is_typetype(typ) or (accept_typearg and typ.typename == "typearg") then + if is_typetype(typ) or (accept_typearg and typ is TypeArgType) then return typ end end @@ -6616,7 +6642,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string ["unknown"] = true, } - local function default_resolve_typevars_callback(t: Type): Type + local function default_resolve_typevars_callback(t: TypeVarType): Type local rt = find_var_type(t.typevar) if not rt then return nil @@ -6647,10 +6673,10 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end local orig_t = t - if t.typename == "typevar" then + if t is TypeVarType then local rt = fn_var(t) if rt then - resolved[orig_t.typevar] = true + resolved[t.typevar] = true if no_nested_types[rt.typename] or (rt.typename == "nominal" and not rt.typevals) then seen[orig_t] = rt return rt, false @@ -6677,18 +6703,21 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if t.typename == "array" then copy.elements, same = resolve(t.elements, same) -- inferred_len is not propagated - elseif t.typename == "typearg" then + elseif t is TypeArgType then if fn_arg then copy = fn_arg(t) else + assert(copy is TypeArgType) copy.typearg = t.typearg if t.constraint then copy.constraint, same = resolve(t.constraint, same) end end - elseif t.typename == "unresolvable_typearg" then + elseif t is UnresolvableTypeArgType then + assert(copy is UnresolvableTypeArgType) copy.typearg = t.typearg - elseif t.typename == "typevar" then + elseif t is TypeVarType then + assert(copy is TypeVarType) copy.typevar = t.typevar if t.constraint then copy.constraint, same = resolve(t.constraint, same) @@ -6705,7 +6734,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if t.typeargs then copy.typeargs = {} for i, tf in ipairs(t.typeargs) do - copy.typeargs[i], same = resolve(tf, same) + copy.typeargs[i], same = resolve(tf, same) as (TypeArgType, boolean) end end @@ -6719,7 +6748,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if t.typeargs then copy.typeargs = {} for i, tf in ipairs(t.typeargs) do - copy.typeargs[i], same = resolve(tf, same) + copy.typeargs[i], same = resolve(tf, same) as (TypeArgType, boolean) end end @@ -7646,7 +7675,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if not ok then return false, errs end - if r.typevar == typevar then + if r is TypeVarType and r.typevar == typevar then return true end add_var(nil, typevar, r) @@ -7684,14 +7713,14 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string ["*"] = compare_false, }, ["typevar"] = { - ["typevar"] = function(a: Type, b: Type): boolean, {Error} + ["typevar"] = function(a: TypeVarType, b: TypeVarType): boolean, {Error} if a.typevar == b.typevar then return true end return compare_or_infer_typevar(b.typevar, a, nil, same_type) end, - ["*"] = function(a: Type, b: Type): boolean, {Error} + ["*"] = function(a: TypeVarType, b: Type): boolean, {Error} return compare_or_infer_typevar(a.typevar, nil, b, same_type) end, }, @@ -7757,7 +7786,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string }, ["*"] = { ["bad_nominal"] = compare_false, - ["typevar"] = function(a: Type, b: Type): boolean, {Error} + ["typevar"] = function(a: Type, b: TypeVarType): boolean, {Error} return compare_or_infer_typevar(b.typevar, a, nil, same_type) end, }, @@ -7786,14 +7815,14 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end, }, ["typevar"] = { - ["typevar"] = function(a: Type, b: Type): boolean, {Error} + ["typevar"] = function(a: TypeVarType, b: TypeVarType): boolean, {Error} if a.typevar == b.typevar then return true end return compare_or_infer_typevar(b.typevar, a, nil, is_a) end, - ["*"] = function(a: Type, b: Type): boolean, {Error} + ["*"] = function(a: TypeVarType, b: Type): boolean, {Error} return compare_or_infer_typevar(a.typevar, nil, b, is_a) end, }, @@ -8041,7 +8070,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string ["tuple"] = function(a: Type, b: Type): boolean, {Error} return is_a(a_tuple({a}), b) end, - ["typevar"] = function(a: Type, b: Type): boolean, {Error} + ["typevar"] = function(a: Type, b: TypeVarType): boolean, {Error} return compare_or_infer_typevar(b.typevar, a, nil, is_a) end, ["union"] = exists_supertype_in as CompareTypes, -- ∃ t ∈ b, a <: t @@ -8380,7 +8409,9 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if a.constraint then add_var(nil, a.typearg, a.constraint) else - add_var(nil, a.typearg, lax and UNKNOWN or a_type("unresolvable_typearg", { typearg = a.typearg })) + add_var(nil, a.typearg, lax and UNKNOWN or a_type("unresolvable_typearg", { + typearg = a.typearg + } as UnresolvableTypeArgType)) end end end @@ -8480,7 +8511,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string for _, fnarg in ipairs(func.typeargs) do add_var(nil, fnarg.typearg, a_type("unresolved_typearg", { constraint = fnarg.constraint, - })) + } as UnresolvedTypeArgType)) end end end @@ -8695,7 +8726,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - if (tbl.typename == "typevar" or tbl.typename == "typearg") and tbl.constraint then + if (tbl is TypeVarType or tbl is TypeArgType) and tbl.constraint then local t = match_record_key(tbl.constraint, rec, key) if t then @@ -9175,7 +9206,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string table.insert(typevals, a_type("typevar", { typevar = a.typearg, constraint = a.constraint, - })) + } as TypeVarType)) end end return type_at(where, a_type("nominal", { @@ -10601,7 +10632,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if node.expected then local decltype = resolve_tuple_and_nominal(node.expected) - if decltype.typename == "typevar" and decltype.constraint then + if decltype is TypeVarType and decltype.constraint then decltype = resolve_typetype(resolve_tuple_and_nominal(decltype.constraint)) end @@ -10644,7 +10675,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local decltype = resolve_tuple_and_nominal(node.expected) local constraint: Type - if decltype.typename == "typevar" and decltype.constraint then + if decltype is TypeVarType and decltype.constraint then constraint = resolve_typetype(decltype.constraint) decltype = resolve_tuple_and_nominal(constraint) end @@ -10898,7 +10929,9 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string -- add type arguments from the record implicitly if rtype.typeargs then for _, typ in ipairs(rtype.typeargs) do - add_var(nil, typ.typearg, type_at(typ, a_type("typearg", { typearg = typ.typearg }))) + add_var(nil, typ.typearg, type_at(typ, a_type("typearg", { + typearg = typ.typearg + } as TypeArgType))) end end end, @@ -11634,7 +11667,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local i = 1 if typ.typeargs then for _, _ in ipairs(typ.typeargs) do - typ.typeargs[i] = children[i] + typ.typeargs[i] = children[i] as TypeArgType i = i + 1 end end @@ -11717,16 +11750,16 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end, }, ["typearg"] = { - after = function(typ: Type, _children: {Type}): Type + after = function(typ: TypeArgType, _children: {Type}): Type add_var(nil, typ.typearg, type_at(typ, a_type("typearg", { typearg = typ.typearg, constraint = typ.constraint, - }))) + } as TypeArgType))) return typ end, }, ["typevar"] = { - after = function(typ: Type, _children: {Type}): Type + after = function(typ: TypeVarType, _children: {Type}): Type if not find_var_type(typ.typevar) then error_at(typ, "undefined type variable " .. typ.typevar) end @@ -11741,10 +11774,11 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local t = find_type(typ.names, true) if t then - if t.typename == "typearg" then + if t is TypeArgType then -- convert nominal into a typevar - edit_type(typ, "typevar") typ.names = nil + edit_type(typ, "typevar") + assert(typ is TypeVarType) typ.typevar = t.typearg typ.constraint = t.constraint else From 36e8ea3e12042bbaf8ef7895709806ccf49fff58 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Sat, 30 Dec 2023 13:05:50 -0500 Subject: [PATCH 070/224] subtype-match a typearg using its constraint --- spec/call/generic_function_spec.lua | 22 +++++++++++- tl.lua | 51 ++++++++++++++++++++------- tl.tl | 53 +++++++++++++++++++++-------- 3 files changed, 98 insertions(+), 28 deletions(-) diff --git a/spec/call/generic_function_spec.lua b/spec/call/generic_function_spec.lua index 89d45d09d..2fb8cf4d6 100644 --- a/spec/call/generic_function_spec.lua +++ b/spec/call/generic_function_spec.lua @@ -73,6 +73,24 @@ describe("generic function", function() print(use_conv("123", convert_str_num) + 123.0) ]])) + it("accepts correct typevars, does not mix up multiple uses", util.check([[ + local type Convert = function(a): b + + local function convert_num_str(n: number): string + return tostring(n) + end + + local function convert_str_num(s: string): number + return tonumber(s) + end + + local function use_conv(x: X, cvt: Convert, tvc: Convert): Y -- tvc is flipped! + return cvt(tvc(cvt(x))) + end + + print(use_conv(122.0, convert_num_str, convert_str_num) .. "!") + ]])) + it("catches incorrect typevars, does not mix up multiple uses", util.check_type_error([[ local type Convert = function(a): b @@ -90,7 +108,9 @@ describe("generic function", function() print(use_conv(122.0, convert_num_str, convert_str_num) .. "!") ]], { - { y = 15, x = 46, msg = "argument 3: argument 1: got string, expected number" } + { y = 12, x = 24, msg = "argument 1: got Y, expected X" }, + { y = 12, x = 28, msg = "argument 1: got Y, expected X" }, + { y = 15, x = 46, msg = "argument 3: argument 1: got string, expected number" }, })) it("will catch if resolved typevar does not match", util.check_type_error([[ diff --git a/tl.lua b/tl.lua index 034f7deae..f49985026 100644 --- a/tl.lua +++ b/tl.lua @@ -7879,8 +7879,14 @@ tl.type_check = function(ast, opts) return is_a(ra, rb) end + local ok, errs = are_same_nominals(a, b) + if ok then + return true + end + + - return are_same_nominals(a, b) + return ok, errs end, ["*"] = subtype_nominal, }, @@ -8064,6 +8070,16 @@ a.types[i], b.types[i]), } return any_errors(errs) end, }, + ["typearg"] = { + ["typearg"] = function(a, b) + return a.typearg == b.typearg + end, + ["*"] = function(a, b) + if a.constraint then + return is_a(a.constraint, b) + end + end, + }, ["*"] = { ["bad_nominal"] = compare_false, ["any"] = compare_true, @@ -8073,6 +8089,11 @@ a.types[i], b.types[i]), } ["typevar"] = function(a, b) return compare_or_infer_typevar(b.typevar, a, nil, is_a) end, + ["typearg"] = function(a, b) + if b.constraint then + return is_a(a, b.constraint) + end + end, ["union"] = exists_supertype_in, @@ -8098,22 +8119,25 @@ a.types[i], b.types[i]), } ["any"] = 5, ["union"] = 6, ["poly"] = 7, - ["nominal"] = 8, - ["enum"] = 9, - ["string"] = 9, - ["integer"] = 9, - ["boolean"] = 9, + ["typearg"] = 8, + + ["nominal"] = 9, + + ["enum"] = 10, + ["string"] = 10, + ["integer"] = 10, + ["boolean"] = 10, - ["interface"] = 10, + ["interface"] = 11, - ["emptytable"] = 11, - ["tupletable"] = 12, + ["emptytable"] = 12, + ["tupletable"] = 13, - ["record"] = 13, - ["array"] = 13, - ["map"] = 13, - ["function"] = 13, + ["record"] = 14, + ["array"] = 14, + ["map"] = 14, + ["function"] = 14, } if lax then @@ -10931,6 +10955,7 @@ expand_type(node, values, elements) }) for _, typ in ipairs(rtype.typeargs) do add_var(nil, typ.typearg, type_at(typ, a_type("typearg", { typearg = typ.typearg, + constraint = typ.constraint, }))) end end diff --git a/tl.tl b/tl.tl index ee2d94c8e..d239f28c3 100644 --- a/tl.tl +++ b/tl.tl @@ -7879,8 +7879,14 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return is_a(ra, rb) end + local ok, errs = are_same_nominals(a, b) + if ok then + return true + end + + -- all other types nominally - return are_same_nominals(a, b) + return ok, errs end, ["*"] = subtype_nominal, }, @@ -8064,6 +8070,16 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return any_errors(errs) end, }, + ["typearg"] = { + ["typearg"] = function(a: TypeArgType, b: TypeArgType): boolean, {Error} + return a.typearg == b.typearg + end, + ["*"] = function(a: TypeArgType, b: Type): boolean, {Error} + if a.constraint then + return is_a(a.constraint, b) + end + end, + }, ["*"] = { ["bad_nominal"] = compare_false, ["any"] = compare_true, @@ -8073,6 +8089,11 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string ["typevar"] = function(a: Type, b: TypeVarType): boolean, {Error} return compare_or_infer_typevar(b.typevar, a, nil, is_a) end, + ["typearg"] = function(a: Type, b: TypeArgType): boolean, {Error} + if b.constraint then + return is_a(a, b.constraint) + end + end, ["union"] = exists_supertype_in as CompareTypes, -- ∃ t ∈ b, a <: t -- ─────────────── -- a <: b union @@ -8098,22 +8119,25 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string ["any"] = 5, ["union"] = 6, ["poly"] = 7, - ["nominal"] = 8, + -- then typeargs + ["typearg"] = 8, + -- then nominals + ["nominal"] = 9, -- then base types - ["enum"] = 9, - ["string"] = 9, - ["integer"] = 9, - ["boolean"] = 9, + ["enum"] = 10, + ["string"] = 10, + ["integer"] = 10, + ["boolean"] = 10, -- then interfaces - ["interface"] = 10, + ["interface"] = 11, -- then special cases of tables - ["emptytable"] = 11, - ["tupletable"] = 12, + ["emptytable"] = 12, + ["tupletable"] = 13, -- then other recursive types - ["record"] = 13, - ["array"] = 13, - ["map"] = 13, - ["function"] = 13, + ["record"] = 14, + ["array"] = 14, + ["map"] = 14, + ["function"] = 14, } if lax then @@ -10930,7 +10954,8 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if rtype.typeargs then for _, typ in ipairs(rtype.typeargs) do add_var(nil, typ.typearg, type_at(typ, a_type("typearg", { - typearg = typ.typearg + typearg = typ.typearg, + constraint = typ.constraint, } as TypeArgType))) end end From c54a112e49191a9f08da79516d0c11e57b0b0c4c Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Sat, 30 Dec 2023 13:24:04 -0500 Subject: [PATCH 071/224] more efficient nominal check --- tl.lua | 11 +++++------ tl.tl | 11 +++++------ 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/tl.lua b/tl.lua index f49985026..ee83dcda1 100644 --- a/tl.lua +++ b/tl.lua @@ -7867,6 +7867,11 @@ tl.type_check = function(ast, opts) }, ["nominal"] = { ["nominal"] = function(a, b) + local ok, errs = are_same_nominals(a, b) + if ok then + return true + end + local rb = resolve_nominal(b) if rb.typename == "interface" then @@ -7879,12 +7884,6 @@ tl.type_check = function(ast, opts) return is_a(ra, rb) end - local ok, errs = are_same_nominals(a, b) - if ok then - return true - end - - return ok, errs end, diff --git a/tl.tl b/tl.tl index d239f28c3..7f334ad04 100644 --- a/tl.tl +++ b/tl.tl @@ -7867,6 +7867,11 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string }, ["nominal"] = { ["nominal"] = function(a: Type, b: Type): boolean, {Error} + local ok, errs = are_same_nominals(a, b) + if ok then + return true + end + local rb = resolve_nominal(b) if rb.typename == "interface" then -- match interface subtyping @@ -7879,12 +7884,6 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return is_a(ra, rb) end - local ok, errs = are_same_nominals(a, b) - if ok then - return true - end - - -- all other types nominally return ok, errs end, From 0954829854c961bc759580b8d151897c06adafb9 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Sun, 31 Dec 2023 03:42:47 -0500 Subject: [PATCH 072/224] TupleType, InvalidType --- tl.lua | 154 +++++++++++++++++++++----------- tl.tl | 274 ++++++++++++++++++++++++++++++++++----------------------- 2 files changed, 264 insertions(+), 164 deletions(-) diff --git a/tl.lua b/tl.lua index ee83dcda1..739e40c88 100644 --- a/tl.lua +++ b/tl.lua @@ -143,6 +143,7 @@ local tl = {PrettyPrintOptions = {}, TypeCheckOptions = {}, Env = {}, Symbol = { + tl.version = function() @@ -1314,6 +1315,14 @@ local table_types = { + + + + + + + + @@ -3756,7 +3765,7 @@ local function recurse_type(ast, visit) end end - if ast.tuple then + if ast.typename == "tuple" then for i, child in ipairs(ast.tuple) do xs[i] = recurse_type(child, visit) end @@ -4980,7 +4989,8 @@ local function store_function(trenv, ti, rt) table.insert(rets, mark_array({ get_typenum(trenv, fnarg), nil })) end ti.rets = mark_array(rets) - ti.vararg = not not rt.is_va + ti.vararg = not not rt.args.is_va + ti.varret = not not rt.rets.is_va end get_typenum = function(trenv, t) @@ -6797,6 +6807,7 @@ tl.type_check = function(ast, opts) copy.types[i], same = resolve(tf, same) end elseif t.typename == "tuple" then + assert(copy.typename == "tuple") copy.is_va = t.is_va copy.tuple = {} for i, tf in ipairs(t.tuple) do @@ -8423,6 +8434,8 @@ a.types[i], b.types[i]), } orignode.known = saveknown end + + local type_check_function_call do local function mark_invalid_typeargs(f) @@ -8442,9 +8455,6 @@ a.types[i], b.types[i]), } end local function infer_emptytables(where, wheres, xs, ys, delta) - assert(xs.typename == "tuple") - assert(ys.typename == "tuple") - local xt, yt = xs.tuple, ys.tuple local n_xs = #xt local n_ys = #yt @@ -8661,8 +8671,12 @@ a.types[i], b.types[i]), } end type_check_function_call = function(node, where_args, func, args, e1, is_method, argdelta) - if node.expected and node.expected.typename ~= "tuple" then - node.expected = a_type("tuple", { tuple = { node.expected } }) + local expected = node.expected + local expected_rets + if expected and expected.typename == "tuple" then + expected_rets = expected + else + expected_rets = a_type("tuple", { tuple = { node.expected } }) end begin_scope() @@ -8676,7 +8690,7 @@ a.types[i], b.types[i]), } node.e1.receiver.resolved.typename == "typetype") - local ret, f = check_call(node, where_args, func, args, node.expected, typetype_funcall, is_method, argdelta) + local ret, f = check_call(node, where_args, func, args, expected_rets, typetype_funcall, is_method, argdelta) ret = resolve_typevars_at(node, ret) end_scope() @@ -8916,8 +8930,6 @@ a.types[i], b.types[i]), } end local function add_internal_function_variables(node, args) - assert(args.typename == "tuple") - add_var(nil, "@is_va", args.is_va and ANY or NIL) add_var(nil, "@return", node.rets or a_type("tuple", { tuple = {} })) @@ -9666,10 +9678,9 @@ a.types[i], b.types[i]), } e2 = fe2, } local rets = type_check_funcall(fnode, ftype, b, argdelta + base_nargs) - if rets == INVALID then + if rets.typename == "invalid" then return rets end - assert(rets and rets.typename == "tuple", show_type(rets)) table.insert(rets.tuple, 1, BOOLEAN) return rets end @@ -9735,7 +9746,7 @@ a.types[i], b.types[i]), } return invalid_at(node, "require expects one literal argument") end if node.e2[1].kind ~= "string" then - return a_type({ typename = "any" }) + return a_type("tuple", { tuple = { a_type("any", {}) } }) end local module_name = assert(node.e2[1].conststr) @@ -9746,7 +9757,7 @@ a.types[i], b.types[i]), } if t.typename == "invalid" then if lax then - return UNKNOWN + return a_type("tuple", { tuple = { UNKNOWN } }) end return invalid_at(node, "no type information for required module: '" .. module_name .. "'") end @@ -9936,9 +9947,10 @@ a.types[i], b.types[i]), } end if node[i].key_parsed == "implicit" then - if i == #children and child.vtype.typename == "tuple" then + local cv = child.vtype + if i == #children and cv.typename == "tuple" then - for _, c in ipairs(child.vtype.tuple) do + for _, c in ipairs(cv.tuple) do elements = expand_type(node, elements, c) types[last_array_idx] = resolve_tuple(c) last_array_idx = last_array_idx + 1 @@ -9979,8 +9991,8 @@ a.types[i], b.types[i]), } if is_array and is_map then error_at(node, "cannot determine type of table literal") - t = a_type("map", { keys = -expand_type(node, keys, INTEGER), values = + t = a_type("map", { keys = +expand_type(node, keys, INTEGER), values = expand_type(node, values, elements) }) elseif is_record and is_array then @@ -10139,7 +10151,7 @@ expand_type(node, values, elements) }) node.value.e1.tk == "require" then local t = special_functions["require"](node.value, find_var_type("require"), a_type("tuple", { tuple = { STRING } }), 0) - if t ~= INVALID then + if not (t.typename == "invalid") then return t.tuple[1] end else @@ -10305,8 +10317,10 @@ expand_type(node, values, elements) }) end, before_exp = set_expected_types_to_decltuple, after = function(node, children) + local valtuple = children[3] + local encountered_close = false - local infertypes = get_assignment_values(children[3], #node.vars) + local infertypes = get_assignment_values(valtuple, #node.vars) for i, var in ipairs(node.vars) do if var.attribute == "close" then if opts.gen_target == "5.4" then @@ -10356,7 +10370,9 @@ expand_type(node, values, elements) }) ["global_declaration"] = { before_exp = set_expected_types_to_decltuple, after = function(node, children) - local infertypes = get_assignment_values(children[3], #node.vars) + local valtuple = children[3] + + local infertypes = get_assignment_values(valtuple, #node.vars) for i, var in ipairs(node.vars) do local _, t, is_inferred = determine_declaration_type(var, node, infertypes, i) @@ -10374,8 +10390,12 @@ expand_type(node, values, elements) }) ["assignment"] = { before_exp = set_expected_types_to_decltuple, after = function(node, children) - local vartypes = children[1].tuple - local valtypes = get_assignment_values(children[3], #vartypes) + local vartuple = children[1] + assert(vartuple.typename == "tuple") + local vartypes = vartuple.tuple + local valtuple = children[3] + assert(valtuple.typename == "tuple") + local valtypes = get_assignment_values(valtuple, #vartypes) for i, vartype in ipairs(vartypes) do local varnode = node.vars[i] local varname = varnode.tk @@ -10383,13 +10403,10 @@ expand_type(node, values, elements) }) local rvar, rval, err = check_assignment(varnode, vartype, valtype, varname, varnode.attribute) if err == "missing" then if #node.exps == 1 and node.exps[1].kind == "op" and node.exps[1].op.op == "@funcall" then - local rets = children[3] - if rets.typename == "tuple" then - local msg = #rets.tuple == 1 and - "only 1 value is returned by the function" or - ("only " .. #rets.tuple .. " values are returned by the function") - add_warning("hint", varnode, msg) - end + local msg = #valtuple.tuple == 1 and + "only 1 value is returned by the function" or + ("only " .. #valtuple.tuple .. " values are returned by the function") + add_warning("hint", varnode, msg) end end @@ -10508,7 +10525,9 @@ expand_type(node, values, elements) }) begin_scope(node) end, before_statements = function(node, children) - local exptypes = children[2].tuple + local exptuple = children[2] + assert(exptuple.typename == "tuple") + local exptypes = exptuple.tuple widen_all_unions(node) local exp1 = node.exps[1] @@ -10577,7 +10596,7 @@ expand_type(node, values, elements) }) ["return"] = { before = function(node) local rets = find_var_type("@return") - if rets then + if rets and rets.typename == "tuple" then for i, exp in ipairs(node.exps) do exp.expected = rets.tuple[i] end @@ -10585,6 +10604,7 @@ expand_type(node, values, elements) }) end, after = function(node, children) local got = children[1] + assert(got.typename == "tuple") local got_t = got.tuple local n_got = #got_t @@ -10770,9 +10790,10 @@ expand_type(node, values, elements) }) assert_is_a(node[i], cvtype, dt, in_context(node.expected_context, "in tuple"), "at index " .. tostring(n)) end elseif is_array and is_number_type(child.ktype) then - if child.vtype.typename == "tuple" and i == #children and node[i].key_parsed == "implicit" then + local cv = child.vtype + if cv.typename == "tuple" and i == #children and node[i].key_parsed == "implicit" then - for ti, tt in ipairs(child.vtype.tuple) do + for ti, tt in ipairs(cv.tuple) do assert_is_a(node[i], tt, decltype.elements, in_context(node.expected_context, "expected an array"), "at index " .. tostring(i + ti - 1)) end else @@ -10853,19 +10874,25 @@ expand_type(node, values, elements) }) end, before_statements = function(node, children) local args = children[2] + assert(args.typename == "tuple") + add_internal_function_variables(node, args) add_function_definition_for_recursion(node, args) end, after = function(node, children) + local args = children[2] + assert(args.typename == "tuple") + local rets = children[3] + assert(rets.typename == "tuple") + end_function_scope(node) - local rets = get_rets(children[3]) local t = ensure_fresh_typeargs(a_function({ y = node.y, x = node.x, typeargs = node.typeargs, - args = children[2], - rets = rets, + args = args, + rets = get_rets(rets), filename = filename, })) @@ -10882,8 +10909,12 @@ expand_type(node, values, elements) }) begin_scope(node) end, after = function(node, children) + local args = children[2] + assert(args.typename == "tuple") + local rets = children[3] + assert(rets.typename == "tuple") + end_function_scope(node) - local rets = get_rets(children[3]) check_macroexp_arg_use(node.macrodef) @@ -10891,8 +10922,8 @@ expand_type(node, values, elements) }) y = node.y, x = node.x, typeargs = node.typeargs, - args = children[2], - rets = rets, + args = args, + rets = get_rets(rets), filename = filename, macroexp = node.macrodef, })) @@ -10920,10 +10951,17 @@ expand_type(node, values, elements) }) end, before_statements = function(node, children) local args = children[2] + assert(args.typename == "tuple") + add_internal_function_variables(node, args) add_function_definition_for_recursion(node, args) end, after = function(node, children) + local args = children[2] + assert(args.typename == "tuple") + local rets = children[3] + assert(rets.typename == "tuple") + end_function_scope(node) if node.is_predeclared_local_function then return NONE @@ -10933,8 +10971,8 @@ expand_type(node, values, elements) }) y = node.y, x = node.x, typeargs = node.typeargs, - args = children[2], - rets = get_rets(children[3]), + args = args, + rets = get_rets(rets), filename = filename, }))) @@ -10962,6 +11000,8 @@ expand_type(node, values, elements) }) before_statements = function(node, children) local args = children[3] assert(args.typename == "tuple") + local rets = children[4] + assert(rets.typename == "tuple") local rtype = resolve_tuple_and_nominal(resolve_typetype(children[1])) @@ -10996,7 +11036,7 @@ expand_type(node, values, elements) }) is_method = node.is_method, typeargs = node.typeargs, args = args, - rets = get_rets(children[4]), + rets = get_rets(rets), filename = filename, })) @@ -11054,18 +11094,23 @@ expand_type(node, values, elements) }) end, before_statements = function(node, children) local args = children[1] + assert(args.typename == "tuple") + add_internal_function_variables(node, args) end, after = function(node, children) - end_function_scope(node) - + local args = children[1] + assert(args.typename == "tuple") + local rets = children[2] + assert(rets.typename == "tuple") + end_function_scope(node) return ensure_fresh_typeargs(a_function({ y = node.y, x = node.x, typeargs = node.typeargs, - args = children[1], - rets = children[2], + args = args, + rets = rets, filename = filename, })) end, @@ -11077,18 +11122,23 @@ expand_type(node, values, elements) }) end, before_exp = function(node, children) local args = children[1] + assert(args.typename == "tuple") + add_internal_function_variables(node, args) end, after = function(node, children) - end_function_scope(node) - + local args = children[1] + assert(args.typename == "tuple") + local rets = children[2] + assert(rets.typename == "tuple") + end_function_scope(node) return ensure_fresh_typeargs(a_function({ y = node.y, x = node.x, typeargs = node.typeargs, - args = children[1], - rets = children[2], + args = args, + rets = rets, filename = filename, })) end, @@ -11257,7 +11307,7 @@ expand_type(node, values, elements) }) local t, e = match_record_key(a, node.e1, node.e2.conststr or node.e2.tk) if not t then - return invalid_at(node.e2, e, a == INVALID and a or resolve_tuple(orig_a)) + return invalid_at(node.e2, e, resolve_tuple(orig_a)) end return t diff --git a/tl.tl b/tl.tl index 7f334ad04..c18a931e1 100644 --- a/tl.tl +++ b/tl.tl @@ -111,6 +111,7 @@ local record tl args: {{integer, string}} -- FUNCTION rets: {{integer, string}} -- FUNCTION vararg: boolean -- FUNCTION + varret: boolean -- FUNCTION types: {integer} -- UNION, POLY, TUPLE keys: integer -- MAP values: integer -- MAP @@ -1090,10 +1091,6 @@ local interface Type -- arguments: optional arity opt: boolean - -- tuple - is_va: boolean - tuple: {Type} - -- typetype def: Type is_alias: boolean @@ -1139,6 +1136,20 @@ local interface Type narrows: {string:boolean} end +local record InvalidType + is Type + where self.typename == "invalid" +end + +local record TupleType + is Type + where self.typename == "tuple" + + -- tuple + is_va: boolean + tuple: {Type} +end + local record TypeArgType is Type where self.typename == "typearg" @@ -1233,8 +1244,6 @@ local record EnumType is Type where self.typename == "enum" enumset: {string:boolean} end -local type TupleType = Type - local record Operator y: integer x: integer @@ -1446,7 +1455,7 @@ local record Node typeargs: {TypeArgType} args: Node - rets: Type + rets: TupleType body: Node implicit_global_function: boolean is_predeclared_local_function: boolean @@ -1553,7 +1562,7 @@ local enum ParseTypeListMode "casttype" end -local parse_type_list: function(ParseState, integer, ParseTypeListMode): integer, Type +local parse_type_list: function(ParseState, integer, ParseTypeListMode): integer, TupleType local parse_expression: function(ParseState, integer): integer, Node, integer local parse_expression_and_tk: function(ps: ParseState, i: integer, tk: string): integer, Node local parse_statements: function(ParseState, integer, ? boolean): integer, Node @@ -1657,11 +1666,11 @@ local macroexp a_typetype(t: Type): Type end local macroexp a_tuple(t: {Type}): TupleType - return a_type("tuple", { tuple = t }) + return a_type("tuple", { tuple = t } as TupleType) end local function c_tuple(t: {Type}): TupleType - return a_type("tuple", { tuple = t }) + return a_type("tuple", { tuple = t } as TupleType) end local macroexp a_union(t: {Type}): UnionType @@ -1676,7 +1685,7 @@ local function a_function(t: FunctionType): FunctionType return a_type("function", t) end -local function a_vararg(t: {Type}): Type +local function a_vararg(t: {Type}): TupleType local typ = a_tuple(t) typ.is_va = true return typ @@ -1965,7 +1974,7 @@ local function parse_typearg(ps: ParseState, i: integer): integer, TypeArgType, } as TypeArgType) end -local function parse_return_types(ps: ParseState, i: integer): integer, Type +local function parse_return_types(ps: ParseState, i: integer): integer, TupleType return parse_type_list(ps, i, "rets") end @@ -2127,13 +2136,13 @@ parse_type = function(ps: ParseState, i: integer): integer, Type, integer return i, bt end -local function new_tuple(ps: ParseState, i: integer): Type, {Type} - local t = new_type(ps, i, "tuple") +local function new_tuple(ps: ParseState, i: integer): TupleType, {Type} + local t = new_type(ps, i, "tuple") as TupleType t.tuple = {} return t, t.tuple end -parse_type_list = function(ps: ParseState, i: integer, mode: ParseTypeListMode): integer, Type +parse_type_list = function(ps: ParseState, i: integer, mode: ParseTypeListMode): integer, TupleType local t, list = new_tuple(ps, i) local first_token = ps.tokens[i].tk @@ -3756,7 +3765,7 @@ local function recurse_type(ast: Type, visit: Visitor): T end end - if ast.tuple then + if ast is TupleType then for i, child in ipairs(ast.tuple) do xs[i] = recurse_type(child, visit) end @@ -4980,7 +4989,8 @@ local function store_function(trenv: TypeReportEnv, ti: TypeInfo, rt: FunctionTy table.insert(rets, mark_array { get_typenum(trenv, fnarg), nil }) end ti.rets = mark_array(rets) - ti.vararg = not not rt.is_va + ti.vararg = not not rt.args.is_va + ti.varret = not not rt.rets.is_va end get_typenum = function(trenv: TypeReportEnv, t: Type): integer @@ -4999,7 +5009,7 @@ get_typenum = function(trenv: TypeReportEnv, t: Type): integer local rt = t if is_typetype(rt) then rt = rt.def - elseif rt.typename == "tuple" and #rt.tuple == 1 then + elseif rt is TupleType and #rt.tuple == 1 then rt = rt.tuple[1] end @@ -5082,7 +5092,7 @@ end -------------------------------------------------------------------------------- local NONE = a_type("none", {}) -local INVALID = a_type("invalid", {}) +local INVALID = a_type("invalid", {} as InvalidType) local UNKNOWN = a_type("unknown", {}) local CIRCULAR_REQUIRE = a_type("circular_require", {}) @@ -5395,7 +5405,7 @@ local function show_type_base(t: Type, short: boolean, seen: {Type:string}): str else return table.concat(t.names, ".") end - elseif t.typename == "tuple" then + elseif t is TupleType then local out: {string} = {} for _, v in ipairs(t.tuple) do table.insert(out, show(v)) @@ -5788,7 +5798,7 @@ local function init_globals(lax: boolean): {string:Variable}, {string:Type} return t end - local type TypeConstructor = function({Type}):Type + local type TypeConstructor = function({Type}):TupleType local record ArgsRets ctor: TypeConstructor @@ -5922,13 +5932,13 @@ local function init_globals(lax: boolean): {string:Variable}, {string:Type} fields = { ["close"] = a_function { args = a_tuple { NOMINAL_FILE }, rets = a_tuple { BOOLEAN, STRING, INTEGER } }, ["flush"] = a_function { args = a_tuple { NOMINAL_FILE }, rets = a_tuple {} }, - ["lines"] = a_file_reader(function(ctor: (function({Type}):Type), args: {Type}, rets: {Type}): Type + ["lines"] = a_file_reader(function(ctor: TypeConstructor, args: {Type}, rets: {Type}): Type table.insert(args, 1, NOMINAL_FILE) return a_function { args = ctor(args), rets = a_tuple { a_function { args = a_tuple {}, rets = ctor(rets) }, } } end), - ["read"] = a_file_reader(function(ctor: (function({Type}):Type), args: {Type}, rets: {Type}): Type + ["read"] = a_file_reader(function(ctor: TypeConstructor, args: {Type}, rets: {Type}): Type table.insert(args, 1, NOMINAL_FILE) return a_function { args = ctor(args), rets = ctor(rets) } end), @@ -6498,7 +6508,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local function union_type(t: Type): string, Type if is_typetype(t) then return union_type(t.def), t.def - elseif t.typename == "tuple" then + elseif t is TupleType then return union_type(t.tuple[1]), t.tuple[1] elseif t.typename == "nominal" then local typetype = t.found or find_type(t.names) @@ -6742,8 +6752,8 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string assert(copy is FunctionType) copy.min_arity = t.min_arity copy.is_method = t.is_method - copy.args, same = resolve(t.args, same) - copy.rets, same = resolve(t.rets, same) + copy.args, same = resolve(t.args, same) as (TupleType, boolean) + copy.rets, same = resolve(t.rets, same) as (TupleType, boolean) elseif is_record_type(t) then if t.typeargs then copy.typeargs = {} @@ -6796,7 +6806,8 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string for i, tf in ipairs(t.types) do copy.types[i], same = resolve(tf, same) end - elseif t.typename == "tuple" then + elseif t is TupleType then + assert(copy is TupleType) copy.is_va = t.is_va copy.tuple = {} for i, tf in ipairs(t.tuple) do @@ -6837,7 +6848,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end local function resolve_tuple(t: Type): Type - if t.typename == "tuple" then + if t is TupleType then t = t.tuple[1] end if t == nil then @@ -6856,7 +6867,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string }) end - local function invalid_at(where: Where, msg: string, ...:Type): Type + local function invalid_at(where: Where, msg: string, ...:Type): InvalidType error_at(where, msg, ...) return INVALID end @@ -6938,14 +6949,14 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - local function type_at(w: Where, t: Type): Type + local function type_at(w: Where, t: T): T t.x = w.x t.y = w.y t.filename = filename return t end - local function resolve_typevars_at(where: Where, t: Type): Type + local function resolve_typevars_at(where: Where, t: T): T assert(where) local ok, ret, errs = resolve_typevars(t) if not ok then @@ -6958,7 +6969,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return type_at(where, ret) end - local function infer_at(where: Where, t: Type): Type + local function infer_at(where: Where, t: T): T local ret = resolve_typevars_at(where, t) if ret.typename == "invalid" then ret = t -- errors are produced by resolve_typevars_at @@ -7798,9 +7809,9 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string ["*"] = compare_false, }, ["tuple"] = { - ["tuple"] = function(a: Type, b: Type): boolean, {Error} -- ∀ a[i] ∈ a, b[i] ∈ b. a[i] <: b[i] - local at, bt = a.tuple, b.tuple -- ────────────────────────────────── - if #at ~= #bt then -- a tuple <: b tuple + ["tuple"] = function(a: TupleType, b: TupleType): boolean, {Error} -- ∀ a[i] ∈ a, b[i] ∈ b. a[i] <: b[i] + local at, bt = a.tuple, b.tuple -- ────────────────────────────────── + if #at ~= #bt then -- a tuple <: b tuple return false end for i = 1, #at do @@ -8423,7 +8434,9 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string orignode.known = saveknown end - local type_check_function_call: function(Node, {Node}, Type, TupleType, Node, boolean, ? integer): TupleType, Type + local type InvalidOrTupleType = InvalidType | TupleType + + local type_check_function_call: function(Node, {Node}, Type, TupleType, Node, boolean, ? integer): InvalidOrTupleType, FunctionType do local function mark_invalid_typeargs(f: FunctionType) if f.typeargs then @@ -8441,10 +8454,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - local function infer_emptytables(where: Where, wheres: {Where}, xs: Type, ys: Type, delta: integer) - assert(xs.typename == "tuple") - assert(ys.typename == "tuple") - + local function infer_emptytables(where: Where, wheres: {Where}, xs: TupleType, ys: TupleType, delta: integer) local xt, yt = xs.tuple, ys.tuple local n_xs = #xt local n_ys = #yt @@ -8463,7 +8473,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - local check_args_rets: function(where: Where, where_args: {Node}, f: Type, args: TupleType, expected_rets: TupleType, argdelta: integer): Type, {Error} + local check_args_rets: function(where: Where, where_args: {Node}, f: Type, args: TupleType, expected_rets: TupleType, argdelta: integer): TupleType, {Error} do -- check if a tuple `xs` matches tuple `ys` local function check_func_type_list(where: Where, wheres: {Where}, xs: TupleType, ys: TupleType, from: integer, delta: integer, v: VarianceMode, mode: ArgCheckMode): boolean, {Error} @@ -8490,7 +8500,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return true end - check_args_rets = function(where: Where, where_args: {Node}, f: FunctionType, args: TupleType, expected_rets: TupleType, argdelta: integer): Type, {Error} + check_args_rets = function(where: Where, where_args: {Node}, f: FunctionType, args: TupleType, expected_rets: TupleType, argdelta: integer): TupleType, {Error} local rets_ok = true local rets_errs: {Error} local args_ok: boolean @@ -8549,7 +8559,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - local function fail_call(where: Where, func: FunctionType | PolyType, nargs: integer, errs: {Error}): Type + local function fail_call(where: Where, func: FunctionType | PolyType, nargs: integer, errs: {Error}): TupleType if errs then -- report the errors from the first match for _, err in ipairs(errs) do @@ -8581,7 +8591,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return resolve_typevars_at(where, f.rets) end - local function check_call(where: Where, where_args: {Node}, func: Type, args: TupleType, expected_rets: TupleType, typetype_funcall: boolean, is_method: boolean, argdelta: integer): Type, FunctionType + local function check_call(where: Where, where_args: {Node}, func: Type, args: TupleType, expected_rets: TupleType, typetype_funcall: boolean, is_method: boolean, argdelta: integer): InvalidOrTupleType, FunctionType assert(type(func) == "table") assert(type(args) == "table") @@ -8660,9 +8670,13 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return fail_call(where, func, given, first_errs) end - type_check_function_call = function(node: Node, where_args: {Node}, func: Type, args: TupleType, e1: Node, is_method: boolean, argdelta?: integer): TupleType, Type - if node.expected and node.expected.typename ~= "tuple" then - node.expected = a_tuple { node.expected } + type_check_function_call = function(node: Node, where_args: {Node}, func: Type, args: TupleType, e1: Node, is_method: boolean, argdelta?: integer): InvalidOrTupleType, FunctionType + local expected = node.expected + local expected_rets: TupleType + if expected and expected is TupleType then + expected_rets = expected + else + expected_rets = a_tuple { node.expected } end begin_scope() @@ -8676,7 +8690,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string and node.e1.receiver.resolved.typename == "typetype" ) - local ret, f = check_call(node, where_args, func, args, node.expected, typetype_funcall, is_method, argdelta) + local ret, f = check_call(node, where_args, func, args, expected_rets, typetype_funcall, is_method, argdelta) ret = resolve_typevars_at(node, ret) end_scope() @@ -8739,7 +8753,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string tbl = resolve_typetype(tbl) - if tbl.typename == "union" then + if tbl is UnionType then local t = same_in_all_union_entries(tbl, function(t: Type): (Type, Type) return (match_record_key(t, rec, key)) end) @@ -8901,23 +8915,21 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return st[1][var] end - local get_rets: function(Type): Type + local get_rets: function(TupleType): TupleType if lax then - get_rets = function(rets: Type): Type + get_rets = function(rets: TupleType): TupleType if #rets.tuple == 0 then return a_vararg { UNKNOWN } end return rets end else - get_rets = function(rets: Type): Type + get_rets = function(rets: TupleType): TupleType return rets end end - local function add_internal_function_variables(node: Node, args: Type) - assert(args.typename == "tuple") - + local function add_internal_function_variables(node: Node, args: TupleType) add_var(nil, "@is_va", args.is_va and ANY or NIL) add_var(nil, "@return", node.rets or a_tuple({})) @@ -8931,11 +8943,11 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - local function add_function_definition_for_recursion(node: Node, fnargs: Type) + local function add_function_definition_for_recursion(node: Node, fnargs: TupleType) assert(fnargs.typename == "tuple") -- FIXME needs this copy? - local args: Type = a_tuple({}) + local args = a_tuple({}) args.is_va = fnargs.is_va for _, fnarg in ipairs(fnargs.tuple) do table.insert(args.tuple, fnarg) @@ -8983,7 +8995,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return t end - local function flatten_tuple(vals: Type): Type + local function flatten_tuple(vals: TupleType): TupleType local vt = vals.tuple local n_vals = #vt local ret = a_tuple {} @@ -8999,7 +9011,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end local last = vt[n_vals] - if last.typename == "tuple" then + if last is TupleType then -- ...then unpack the last tuple local lt = last.tuple for _, v in ipairs(lt) do @@ -9014,7 +9026,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return ret end - local function get_assignment_values(vals: Type, wanted: integer): Type + local function get_assignment_values(vals: TupleType, wanted: integer): TupleType if vals == nil then return a_tuple {} end @@ -9630,9 +9642,9 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - local type_check_funcall: function(node: Node, a: Type, b: Type, argdelta?: integer): Type + local type_check_funcall: function(node: Node, a: Type, b: Type, argdelta?: integer): InvalidOrTupleType - local function special_pcall_xpcall(node: Node, _a: Type, b: Type, argdelta: integer): Type + local function special_pcall_xpcall(node: Node, _a: Type, b: TupleType, argdelta: integer): Type local base_nargs = (node.e1.tk == "xpcall") and 2 or 1 if #node.e2 < base_nargs then error_at(node, "wrong number of arguments (given " .. #node.e2 .. ", expects at least " .. base_nargs .. ")") @@ -9666,16 +9678,15 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string e2 = fe2, } local rets = type_check_funcall(fnode, ftype, b, argdelta + base_nargs) - if rets == INVALID then + if rets is InvalidType then return rets end - assert(rets and rets.typename == "tuple", show_type(rets)) table.insert(rets.tuple, 1, BOOLEAN) return rets end - local special_functions: {string : function(Node,Type,TupleType,integer):TupleType } = { - ["pairs"] = function(node: Node, a: Type, b: TupleType, argdelta: integer): TupleType + local special_functions: {string : function(Node,Type,TupleType,integer):InvalidOrTupleType } = { + ["pairs"] = function(node: Node, a: Type, b: TupleType, argdelta: integer): InvalidOrTupleType if not b.tuple[1] then return invalid_at(node, "pairs requires an argument") end @@ -9700,7 +9711,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return (type_check_function_call(node, node.e2, a, b, node, false, argdelta)) end, - ["ipairs"] = function(node: Node, a: Type, b: TupleType, argdelta: integer): TupleType + ["ipairs"] = function(node: Node, a: Type, b: TupleType, argdelta: integer): InvalidOrTupleType if not b.tuple[1] then return invalid_at(node, "ipairs requires an argument") end @@ -9721,7 +9732,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return (type_check_function_call(node, node.e2, a, b, node, false, argdelta)) end, - ["rawget"] = function(node: Node, _a: Type, b: TupleType, _argdelta: integer): TupleType + ["rawget"] = function(node: Node, _a: Type, b: TupleType, _argdelta: integer): InvalidOrTupleType -- TODO should those offsets be fixed by _argdelta? if #b.tuple == 2 then return a_tuple({ type_check_index(node.e2[1], node.e2[2], b.tuple[1], b.tuple[2]) }) @@ -9730,12 +9741,12 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end, - ["require"] = function(node: Node, _a: Type, b: Type, _argdelta: integer): TupleType + ["require"] = function(node: Node, _a: Type, b: TupleType, _argdelta: integer): InvalidOrTupleType if #b.tuple ~= 1 then return invalid_at(node, "require expects one literal argument") end if node.e2[1].kind ~= "string" then - return a_type { typename = "any" } + return a_tuple({ a_type("any", {}) }) end local module_name = assert(node.e2[1].conststr) @@ -9746,7 +9757,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if t.typename == "invalid" then if lax then - return UNKNOWN + return a_tuple({ UNKNOWN }) end return invalid_at(node, "no type information for required module: '" .. module_name .. "'") end @@ -9758,7 +9769,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string ["pcall"] = special_pcall_xpcall, ["xpcall"] = special_pcall_xpcall, - ["assert"] = function(node: Node, a: Type, b: Type, argdelta: integer): TupleType + ["assert"] = function(node: Node, a: Type, b: TupleType, argdelta: integer): InvalidOrTupleType node.known = FACT_TRUTHY local r = type_check_function_call(node, node.e2, a, b, node, false, argdelta) apply_facts(node, node.e2[1].known) @@ -9766,7 +9777,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end, } - type_check_funcall = function(node: Node, a: Type, b: Type, argdelta?: integer): TupleType + type_check_funcall = function(node: Node, a: Type, b: TupleType, argdelta?: integer): InvalidOrTupleType argdelta = argdelta or 0 if node.e1.kind == "variable" then local special = special_functions[node.e1.tk] @@ -9824,7 +9835,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local function set_expected_types_to_decltuple(node: Node, children: {Type}) local decltuple = node.kind == "assignment" and children[1] or node.decltuple - assert(decltuple.typename == "tuple") + assert(decltuple is TupleType) local decls = decltuple.tuple if decls and node.exps then local ndecl = #decls @@ -9936,9 +9947,10 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end if node[i].key_parsed == "implicit" then - if i == #children and child.vtype.typename == "tuple" then + local cv = child.vtype + if i == #children and cv is TupleType then -- need to expand last item in an array (e.g { 1, 2, 3, f() }) - for _, c in ipairs(child.vtype.tuple) do + for _, c in ipairs(cv.tuple) do elements = expand_type(node, elements, c) types[last_array_idx] = resolve_tuple(c) last_array_idx = last_array_idx + 1 @@ -10139,7 +10151,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string and node.value.e1.tk == "require" then local t = special_functions["require"](node.value, find_var_type("require"), a_tuple { STRING }, 0) - if t ~= INVALID then + if not t is InvalidType then return t.tuple[1] end else @@ -10305,8 +10317,10 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end, before_exp = set_expected_types_to_decltuple, after = function(node: Node, children: {Type}): Type + local valtuple = children[3] as TupleType -- may be nil + local encountered_close = false - local infertypes = get_assignment_values(children[3], #node.vars) + local infertypes = get_assignment_values(valtuple, #node.vars) for i, var in ipairs(node.vars) do if var.attribute == "close" then if opts.gen_target == "5.4" then @@ -10356,7 +10370,9 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string ["global_declaration"] = { before_exp = set_expected_types_to_decltuple, after = function(node: Node, children: {Type}): Type - local infertypes = get_assignment_values(children[3], #node.vars) + local valtuple = children[3] as TupleType -- may be nil + + local infertypes = get_assignment_values(valtuple, #node.vars) for i, var in ipairs(node.vars) do local _, t, is_inferred = determine_declaration_type(var, node, infertypes, i) @@ -10374,8 +10390,12 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string ["assignment"] = { before_exp = set_expected_types_to_decltuple, after = function(node: Node, children: {Type}): Type - local vartypes = children[1].tuple - local valtypes = get_assignment_values(children[3], #vartypes) + local vartuple = children[1] + assert(vartuple is TupleType) + local vartypes = vartuple.tuple + local valtuple = children[3] + assert(valtuple is TupleType) + local valtypes = get_assignment_values(valtuple, #vartypes) for i, vartype in ipairs(vartypes) do local varnode = node.vars[i] local varname = varnode.tk @@ -10383,13 +10403,10 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local rvar, rval, err = check_assignment(varnode, vartype, valtype, varname, varnode.attribute) if err == "missing" then if #node.exps == 1 and node.exps[1].kind == "op" and node.exps[1].op.op == "@funcall" then - local rets = children[3] - if rets.typename == "tuple" then - local msg = #rets.tuple == 1 - and "only 1 value is returned by the function" - or ("only " .. #rets.tuple .. " values are returned by the function") - add_warning("hint", varnode, msg) - end + local msg = #valtuple.tuple == 1 + and "only 1 value is returned by the function" + or ("only " .. #valtuple.tuple .. " values are returned by the function") + add_warning("hint", varnode, msg) end end @@ -10508,7 +10525,9 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string begin_scope(node) end, before_statements = function(node: Node, children: {Type}) - local exptypes = children[2].tuple + local exptuple = children[2] + assert(exptuple is TupleType) + local exptypes = exptuple.tuple widen_all_unions(node) local exp1 = node.exps[1] @@ -10577,7 +10596,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string ["return"] = { before = function(node: Node) local rets = find_var_type("@return") - if rets then + if rets and rets is TupleType then for i, exp in ipairs(node.exps) do exp.expected = rets.tuple[i] end @@ -10585,11 +10604,12 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end, after = function(node: Node, children: {Type}): Type local got = children[1] + assert(got is TupleType) local got_t = got.tuple local n_got = #got_t node.block_returns = true - local expected = find_var_type("@return") + local expected = find_var_type("@return") as TupleType if not expected then -- if at the toplevel expected = infer_at(node, got) @@ -10770,9 +10790,10 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string assert_is_a(node[i], cvtype, dt, in_context(node.expected_context, "in tuple"), "at index " .. tostring(n)) end elseif is_array and is_number_type(child.ktype) then - if child.vtype.typename == "tuple" and i == #children and node[i].key_parsed == "implicit" then + local cv = child.vtype + if cv is TupleType and i == #children and node[i].key_parsed == "implicit" then -- need to expand last item in an array (e.g { 1, 2, 3, f() }) - for ti, tt in ipairs(child.vtype.tuple) do + for ti, tt in ipairs(cv.tuple) do assert_is_a(node[i], tt, decltype.elements, in_context(node.expected_context, "expected an array"), "at index " .. tostring(i + ti - 1)) end else @@ -10853,19 +10874,25 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end, before_statements = function(node: Node, children: {Type}) local args = children[2] + assert(args is TupleType) + add_internal_function_variables(node, args) add_function_definition_for_recursion(node, args) end, after = function(node: Node, children: {Type}): Type + local args = children[2] + assert(args is TupleType) + local rets = children[3] + assert(rets is TupleType) + end_function_scope(node) - local rets = get_rets(children[3]) local t = ensure_fresh_typeargs(a_function { y = node.y, x = node.x, typeargs = node.typeargs, - args = children[2], - rets = rets, + args = args, + rets = get_rets(rets), filename = filename, }) @@ -10882,8 +10909,12 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string begin_scope(node) end, after = function(node: Node, children: {Type}): Type + local args = children[2] + assert(args is TupleType) + local rets = children[3] + assert(rets is TupleType) + end_function_scope(node) - local rets = get_rets(children[3]) check_macroexp_arg_use(node.macrodef) @@ -10891,8 +10922,8 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string y = node.y, x = node.x, typeargs = node.typeargs, - args = children[2], - rets = rets, + args = args, + rets = get_rets(rets), filename = filename, macroexp = node.macrodef, }) @@ -10920,10 +10951,17 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end, before_statements = function(node: Node, children: {Type}) local args = children[2] + assert(args is TupleType) + add_internal_function_variables(node, args) add_function_definition_for_recursion(node, args) end, after = function(node: Node, children: {Type}): Type + local args = children[2] + assert(args is TupleType) + local rets = children[3] + assert(rets is TupleType) + end_function_scope(node) if node.is_predeclared_local_function then return NONE @@ -10933,8 +10971,8 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string y = node.y, x = node.x, typeargs = node.typeargs, - args = children[2], - rets = get_rets(children[3]), + args = args, + rets = get_rets(rets), filename = filename, })) @@ -10961,7 +10999,9 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end, before_statements = function(node: Node, children: {Type}) local args = children[3] - assert(args.typename == "tuple") + assert(args is TupleType) + local rets = children[4] + assert(rets is TupleType) local rtype = resolve_tuple_and_nominal(resolve_typetype(children[1])) @@ -10996,7 +11036,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string is_method = node.is_method, typeargs = node.typeargs, args = args, - rets = get_rets(children[4]), + rets = get_rets(rets), filename = filename, }) @@ -11054,18 +11094,23 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end, before_statements = function(node: Node, children: {Type}) local args = children[1] + assert(args is TupleType) + add_internal_function_variables(node, args) end, after = function(node: Node, children: {Type}): Type + local args = children[1] + assert(args is TupleType) + local rets = children[2] + assert(rets is TupleType) + end_function_scope(node) - -- children[1] args - -- children[2] body return ensure_fresh_typeargs(a_function { y = node.y, x = node.x, typeargs = node.typeargs, - args = children[1], - rets = children[2], + args = args, + rets = rets, filename = filename, }) end, @@ -11077,18 +11122,23 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end, before_exp = function(node: Node, children: {Type}) local args = children[1] + assert(args is TupleType) + add_internal_function_variables(node, args) end, after = function(node: Node, children: {Type}): Type + local args = children[1] + assert(args is TupleType) + local rets = children[2] + assert(rets is TupleType) + end_function_scope(node) - -- children[1] args - -- children[2] body return ensure_fresh_typeargs(a_function { y = node.y, x = node.x, typeargs = node.typeargs, - args = children[1], - rets = children[2], + args = args, + rets = rets, filename = filename, }) end, @@ -11257,7 +11307,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local t, e = match_record_key(a, node.e1, node.e2.conststr or node.e2.tk) if not t then - return invalid_at(node.e2, e, a == INVALID and a or resolve_tuple(orig_a)) + return invalid_at(node.e2, e, resolve_tuple(orig_a)) end return t From af519c6f163b5e97400a36fde14be63ea71f48ff Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Mon, 1 Jan 2024 15:10:05 -0500 Subject: [PATCH 073/224] simplify is_total check a bit The implementation of the attribute remains messy, but at least this restrict the is_total field to actual records and maps (and not nominals). The feature for records should go away once we have nilable/non-nilable fields. --- spec/declaration/local_spec.lua | 1 + tl.lua | 18 ++++++++++++------ tl.tl | 18 ++++++++++++------ 3 files changed, 25 insertions(+), 12 deletions(-) diff --git a/spec/declaration/local_spec.lua b/spec/declaration/local_spec.lua index 0dd540f6b..8eb9906a2 100644 --- a/spec/declaration/local_spec.lua +++ b/spec/declaration/local_spec.lua @@ -330,6 +330,7 @@ describe("local", function() it("rejects direct declaration from non-total to total", util.check_type_error([[ local record Point x: number + y: number end local p: Point = { diff --git a/tl.lua b/tl.lua index 739e40c88..47d32271c 100644 --- a/tl.lua +++ b/tl.lua @@ -10115,10 +10115,10 @@ expand_type(node, values, elements) }) if ri.typename ~= "map" and ri.typename ~= "record" then error_at(var, "attribute only applies to maps and records") ok = false - elseif not infertype.is_total then + elseif not ri.is_total then local missing = "" - if infertype.missing then - missing = " (missing: " .. table.concat(infertype.missing, ", ") .. ")" + if ri.missing then + missing = " (missing: " .. table.concat(ri.missing, ", ") .. ")" end if ri.typename == "map" then error_at(var, "map variable declared does not declare values for all possible keys" .. missing) @@ -10128,7 +10128,7 @@ expand_type(node, values, elements) }) ok = false end end - infertype.is_total = nil + ri.is_total = nil end end @@ -10829,9 +10829,15 @@ expand_type(node, values, elements) }) end if decltype.typename == "record" then - t.is_total, t.missing = total_record_check(decltype, seen_keys) + local rt = resolve_tuple_and_nominal(t) + if rt.typename == "record" then + rt.is_total, rt.missing = total_record_check(decltype, seen_keys) + end elseif decltype.typename == "map" then - t.is_total, t.missing = total_map_check(decltype, seen_keys) + local rt = resolve_tuple_and_nominal(t) + if rt.typename == "map" then + rt.is_total, rt.missing = total_map_check(decltype, seen_keys) + end end if constraint then diff --git a/tl.tl b/tl.tl index c18a931e1..999b3a730 100644 --- a/tl.tl +++ b/tl.tl @@ -10115,10 +10115,10 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if ri.typename ~= "map" and ri.typename ~= "record" then error_at(var, "attribute only applies to maps and records") ok = false - elseif not infertype.is_total then + elseif not ri.is_total then local missing = "" - if infertype.missing then - missing = " (missing: " .. table.concat(infertype.missing, ", ") .. ")" + if ri.missing then + missing = " (missing: " .. table.concat(ri.missing, ", ") .. ")" end if ri.typename == "map" then error_at(var, "map variable declared does not declare values for all possible keys" .. missing) @@ -10128,7 +10128,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string ok = false end end - infertype.is_total = nil + ri.is_total = nil end end @@ -10829,9 +10829,15 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end if decltype.typename == "record" then - t.is_total, t.missing = total_record_check(decltype, seen_keys) + local rt = resolve_tuple_and_nominal(t) + if rt.typename == "record" then + rt.is_total, rt.missing = total_record_check(decltype, seen_keys) + end elseif decltype is MapType then - t.is_total, t.missing = total_map_check(decltype, seen_keys) + local rt = resolve_tuple_and_nominal(t) + if rt is MapType then + rt.is_total, rt.missing = total_map_check(decltype, seen_keys) + end end if constraint then From 5768abfbfeaf9a64d8464407b8cfba8e35e5ac58 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Mon, 1 Jan 2024 15:19:49 -0500 Subject: [PATCH 074/224] remove bad_nominal type, just use invalid --- tl.lua | 59 ++++++++++++++++++++------------------------------------ tl.tl | 61 +++++++++++++++++++++------------------------------------- 2 files changed, 43 insertions(+), 77 deletions(-) diff --git a/tl.lua b/tl.lua index 47d32271c..0479576ee 100644 --- a/tl.lua +++ b/tl.lua @@ -1028,7 +1028,6 @@ end - local table_types = { @@ -1052,7 +1051,6 @@ local table_types = { ["integer"] = false, ["union"] = false, ["nominal"] = false, - ["bad_nominal"] = false, ["table_item"] = false, ["unresolved_emptytable_value"] = false, ["unresolved_typearg"] = false, @@ -4868,7 +4866,6 @@ function tl.pretty_print_ast(ast, gen_target, mode) visit_type.cbs["integer"] = default_type_visitor visit_type.cbs["union"] = default_type_visitor visit_type.cbs["nominal"] = default_type_visitor - visit_type.cbs["bad_nominal"] = default_type_visitor visit_type.cbs["emptytable"] = default_type_visitor visit_type.cbs["table_item"] = default_type_visitor visit_type.cbs["unresolved_emptytable_value"] = default_type_visitor @@ -4922,7 +4919,6 @@ local typename_to_typecode = { ["integer"] = tl.typecodes.INTEGER, ["union"] = tl.typecodes.IS_UNION, ["nominal"] = tl.typecodes.NOMINAL, - ["bad_nominal"] = tl.typecodes.NOMINAL, ["circular_require"] = tl.typecodes.NOMINAL, ["emptytable"] = tl.typecodes.EMPTY_TABLE, ["unresolved_emptytable_value"] = tl.typecodes.EMPTY_TABLE, @@ -5509,8 +5505,6 @@ local function show_type_base(t, short, seen) return "" elseif is_typetype(t) then return "type " .. show(t.def) .. (t.is_alias and " (alias)" or "") - elseif t.typename == "bad_nominal" then - return table.concat(t.names, ".") .. " (an unknown type)" else return "<" .. t.typename .. " " .. tostring(t) .. ">" end @@ -7309,29 +7303,32 @@ tl.type_check = function(ast, opts) if not typetype then error_at(t, "unknown type %s", t) return INVALID - elseif is_typetype(typetype) then - if typetype.is_alias then - typetype = typetype.def.found - assert(is_typetype(typetype)) - end + end - if typetype.def.typename == "circular_require" then + if not is_typetype(typetype) then + error_at(t, table.concat(t.names, ".") .. " is not a type") + return INVALID + end - return typetype.def - end + if typetype.is_alias then + typetype = typetype.def.found + assert(is_typetype(typetype)) + end - if typetype.def.typename == "nominal" then - typetype = typetype.def.found - assert(is_typetype(typetype)) - end - assert(typetype.def.typename ~= "nominal") - resolved = match_typevals(t, typetype.def) - else - error_at(t, table.concat(t.names, ".") .. " is not a type") + if typetype.def.typename == "circular_require" then + + return typetype.def end + if typetype.def.typename == "nominal" then + typetype = typetype.def.found + assert(is_typetype(typetype)) + end + assert(typetype.def.typename ~= "nominal") + resolved = match_typevals(t, typetype.def) if not resolved then - resolved = a_type("bad_nominal", { names = t.names }) + error_at(t, table.concat(t.names, ".") .. " cannot be resolved in scope") + return INVALID end if not t.filename then @@ -7514,10 +7511,6 @@ tl.type_check = function(ast, opts) return t.typename == "nominal" and t.names[1] == "@self" end - local function compare_false(_, _) - return false - end - local function compare_true(_, _) return true end @@ -7720,9 +7713,6 @@ tl.type_check = function(ast, opts) local eqtype_relations eqtype_relations = { - ["bad_nominal"] = { - ["*"] = compare_false, - }, ["typevar"] = { ["typevar"] = function(a, b) if a.typevar == b.typevar then @@ -7796,7 +7786,6 @@ tl.type_check = function(ast, opts) end, }, ["*"] = { - ["bad_nominal"] = compare_false, ["typevar"] = function(a, b) return compare_or_infer_typevar(b.typevar, a, nil, same_type) end, @@ -7805,9 +7794,6 @@ tl.type_check = function(ast, opts) local subtype_relations subtype_relations = { - ["bad_nominal"] = { - ["*"] = compare_false, - }, ["tuple"] = { ["tuple"] = function(a, b) local at, bt = a.tuple, b.tuple @@ -8091,7 +8077,6 @@ a.types[i], b.types[i]), } end, }, ["*"] = { - ["bad_nominal"] = compare_false, ["any"] = compare_true, ["tuple"] = function(a, b) return is_a(a_type("tuple", { tuple = { a } }), b) @@ -8122,7 +8107,6 @@ a.types[i], b.types[i]), } local type_priorities = { - ["bad_nominal"] = 1, ["tuple"] = 2, ["typevar"] = 3, ["nil"] = 4, @@ -9813,7 +9797,7 @@ a.types[i], b.types[i]), } resolved = find_type(names) if (not resolved) or (not is_typetype(resolved)) then error_at(typetype, "%s is not a type", typetype) - resolved = a_type("bad_nominal", { names = names }) + resolved = INVALID end end return resolved, aliasing @@ -11956,7 +11940,6 @@ expand_type(node, values, elements) }) visit_type.cbs["number"] = default_type_visitor visit_type.cbs["integer"] = default_type_visitor visit_type.cbs["thread"] = default_type_visitor - visit_type.cbs["bad_nominal"] = default_type_visitor visit_type.cbs["emptytable"] = default_type_visitor visit_type.cbs["table_item"] = default_type_visitor visit_type.cbs["unresolved_emptytable_value"] = default_type_visitor diff --git a/tl.tl b/tl.tl index 999b3a730..9f5b6e80e 100644 --- a/tl.tl +++ b/tl.tl @@ -1014,7 +1014,6 @@ local enum TypeName "integer" "union" "nominal" - "bad_nominal" "emptytable" "table_item" "unresolved_emptytable_value" @@ -1052,7 +1051,6 @@ local table_types : {TypeName:boolean} = { ["integer"] = false, ["union"] = false, ["nominal"] = false, - ["bad_nominal"] = false, ["table_item"] = false, ["unresolved_emptytable_value"] = false, ["unresolved_typearg"] = false, @@ -4868,7 +4866,6 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | visit_type.cbs["integer"] = default_type_visitor visit_type.cbs["union"] = default_type_visitor visit_type.cbs["nominal"] = default_type_visitor - visit_type.cbs["bad_nominal"] = default_type_visitor visit_type.cbs["emptytable"] = default_type_visitor visit_type.cbs["table_item"] = default_type_visitor visit_type.cbs["unresolved_emptytable_value"] = default_type_visitor @@ -4922,7 +4919,6 @@ local typename_to_typecode : {TypeName:integer} = { ["integer"] = tl.typecodes.INTEGER, ["union"] = tl.typecodes.IS_UNION, ["nominal"] = tl.typecodes.NOMINAL, - ["bad_nominal"] = tl.typecodes.NOMINAL, ["circular_require"] = tl.typecodes.NOMINAL, ["emptytable"] = tl.typecodes.EMPTY_TABLE, ["unresolved_emptytable_value"] = tl.typecodes.EMPTY_TABLE, @@ -5509,8 +5505,6 @@ local function show_type_base(t: Type, short: boolean, seen: {Type:string}): str return "" elseif is_typetype(t) then return "type " .. show(t.def) .. (t.is_alias and " (alias)" or "") - elseif t.typename == "bad_nominal" then - return table.concat(t.names, ".") .. " (an unknown type)" else return "<" .. t.typename .. " " .. tostring(t) .. ">" end @@ -7309,29 +7303,32 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if not typetype then error_at(t, "unknown type %s", t) return INVALID - elseif is_typetype(typetype) then - if typetype.is_alias then - typetype = typetype.def.found - assert(is_typetype(typetype)) - end - - if typetype.def.typename == "circular_require" then - -- return, but do not store resolution - return typetype.def - end + end - if typetype.def.typename == "nominal" then - typetype = typetype.def.found - assert(is_typetype(typetype)) - end - assert(typetype.def.typename ~= "nominal") - resolved = match_typevals(t, typetype.def) - else + if not is_typetype(typetype) then error_at(t, table.concat(t.names, ".") .. " is not a type") + return INVALID + end + + if typetype.is_alias then + typetype = typetype.def.found + assert(is_typetype(typetype)) end + if typetype.def.typename == "circular_require" then + -- return, but do not store resolution + return typetype.def + end + + if typetype.def.typename == "nominal" then + typetype = typetype.def.found + assert(is_typetype(typetype)) + end + assert(typetype.def.typename ~= "nominal") + resolved = match_typevals(t, typetype.def) if not resolved then - resolved = a_type("bad_nominal", { names = t.names }) + error_at(t, table.concat(t.names, ".") .. " cannot be resolved in scope") + return INVALID end if not t.filename then @@ -7514,10 +7511,6 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return t.typename == "nominal" and t.names[1] == "@self" end - local function compare_false(_: Type, _: Type): boolean, {Error} - return false - end - local function compare_true(_: Type, _: Type): boolean, {Error} return true end @@ -7720,9 +7713,6 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local eqtype_relations: TypeRelations eqtype_relations = { - ["bad_nominal"] = { - ["*"] = compare_false, - }, ["typevar"] = { ["typevar"] = function(a: TypeVarType, b: TypeVarType): boolean, {Error} if a.typevar == b.typevar then @@ -7796,7 +7786,6 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end, }, ["*"] = { - ["bad_nominal"] = compare_false, ["typevar"] = function(a: Type, b: TypeVarType): boolean, {Error} return compare_or_infer_typevar(b.typevar, a, nil, same_type) end, @@ -7805,9 +7794,6 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local subtype_relations: TypeRelations subtype_relations = { - ["bad_nominal"] = { - ["*"] = compare_false, - }, ["tuple"] = { ["tuple"] = function(a: TupleType, b: TupleType): boolean, {Error} -- ∀ a[i] ∈ a, b[i] ∈ b. a[i] <: b[i] local at, bt = a.tuple, b.tuple -- ────────────────────────────────── @@ -8091,7 +8077,6 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end, }, ["*"] = { - ["bad_nominal"] = compare_false, ["any"] = compare_true, ["tuple"] = function(a: Type, b: Type): boolean, {Error} return is_a(a_tuple({a}), b) @@ -8122,7 +8107,6 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string -- evaluation strategy local type_priorities: {TypeName:integer} = { -- types that have catch-all rules evaluate first - ["bad_nominal"] = 1, ["tuple"] = 2, ["typevar"] = 3, ["nil"] = 4, @@ -9813,7 +9797,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string resolved = find_type(names) if (not resolved) or (not is_typetype(resolved)) then error_at(typetype, "%s is not a type", typetype) - resolved = a_type("bad_nominal", { names = names }) + resolved = INVALID end end return resolved, aliasing @@ -11956,7 +11940,6 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string visit_type.cbs["number"] = default_type_visitor visit_type.cbs["integer"] = default_type_visitor visit_type.cbs["thread"] = default_type_visitor - visit_type.cbs["bad_nominal"] = default_type_visitor visit_type.cbs["emptytable"] = default_type_visitor visit_type.cbs["table_item"] = default_type_visitor visit_type.cbs["unresolved_emptytable_value"] = default_type_visitor From fd021953cbd6c129c7f8a221874fe35c581b1770 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Mon, 1 Jan 2024 15:28:02 -0500 Subject: [PATCH 075/224] refactor: move show_fields out --- tl.lua | 65 ++++++++++++++++++++++++++++------------------------------ tl.tl | 65 ++++++++++++++++++++++++++++------------------------------ 2 files changed, 62 insertions(+), 68 deletions(-) diff --git a/tl.lua b/tl.lua index 0479576ee..5405788c5 100644 --- a/tl.lua +++ b/tl.lua @@ -5341,6 +5341,35 @@ local function display_typevar(typevar) return TL_DEBUG and typevar or (typevar:gsub("@.*", "")) end +local function show_fields(t, show) + if t.names then + return t.names[1] + end + + local out = {} + if t.typeargs then + table.insert(out, "<") + local typeargs = {} + for _, v in ipairs(t.typeargs) do + table.insert(typeargs, show(v)) + end + table.insert(out, table.concat(typeargs, ", ")) + table.insert(out, ">") + end + table.insert(out, " (") + if t.elements then + table.insert(out, "{" .. show(t.elements) .. "}") + end + local fs = {} + for _, k in ipairs(t.field_order) do + local v = t.fields[k] + table.insert(fs, k .. ": " .. show(v)) + end + table.insert(out, table.concat(fs, "; ")) + table.insert(out, ")") + return table.concat(out) +end + local function show_type_base(t, short, seen) if seen[t] then @@ -5352,38 +5381,6 @@ local function show_type_base(t, short, seen) return show_type(typ, short, seen) end - local function show_record_type(name) - if t.names then - return t.names[1] - end - if short then - return name - else - local out = { name } - if t.typeargs then - table.insert(out, "<") - local typeargs = {} - for _, v in ipairs(t.typeargs) do - table.insert(typeargs, show(v)) - end - table.insert(out, table.concat(typeargs, ", ")) - table.insert(out, ">") - end - table.insert(out, " (") - if t.elements then - table.insert(out, "{" .. show(t.elements) .. "}") - end - local fs = {} - for _, k in ipairs(t.field_order) do - local v = t.fields[k] - table.insert(fs, k .. ": " .. show(v)) - end - table.insert(out, table.concat(fs, "; ")) - table.insert(out, ")") - return table.concat(out) - end - end - if t.typename == "nominal" then if #t.names == 1 and t.names[1] == "@self" then return "self" @@ -5438,9 +5435,9 @@ local function show_type_base(t, short, seen) elseif t.typename == "enum" then return t.names and table.concat(t.names, ".") or "enum" elseif t.typename == "interface" then - return show_record_type("interface") + return short and "interface" or "interface" .. show_fields(t, show) elseif is_record_type(t) then - return show_record_type("record") + return short and "record" or "record" .. show_fields(t, show) elseif t.typename == "function" then local out = { "function" } if t.typeargs then diff --git a/tl.tl b/tl.tl index 9f5b6e80e..d5bbbaab0 100644 --- a/tl.tl +++ b/tl.tl @@ -5341,6 +5341,35 @@ local function display_typevar(typevar: string): string return TL_DEBUG and typevar or (typevar:gsub("@.*", "")) end +local function show_fields(t: Type, show: function(Type):(string)): string + if t.names then + return t.names[1] + end + + local out: {string} = {} + if t.typeargs then + table.insert(out, "<") + local typeargs = {} + for _, v in ipairs(t.typeargs) do + table.insert(typeargs, show(v)) + end + table.insert(out, table.concat(typeargs, ", ")) + table.insert(out, ">") + end + table.insert(out, " (") + if t.elements then + table.insert(out, "{" .. show(t.elements) .. "}") + end + local fs = {} + for _, k in ipairs(t.field_order) do + local v = t.fields[k] + table.insert(fs, k .. ": " .. show(v)) + end + table.insert(out, table.concat(fs, "; ")) + table.insert(out, ")") + return table.concat(out) +end + local function show_type_base(t: Type, short: boolean, seen: {Type:string}): string -- FIXME this is a control for recursively built types, which should in principle not exist if seen[t] then @@ -5352,38 +5381,6 @@ local function show_type_base(t: Type, short: boolean, seen: {Type:string}): str return show_type(typ, short, seen) end - local function show_record_type(name: string): string - if t.names then - return t.names[1] - end - if short then - return name - else - local out: {string} = {name} - if t.typeargs then - table.insert(out, "<") - local typeargs = {} - for _, v in ipairs(t.typeargs) do - table.insert(typeargs, show(v)) - end - table.insert(out, table.concat(typeargs, ", ")) - table.insert(out, ">") - end - table.insert(out, " (") - if t.elements then - table.insert(out, "{" .. show(t.elements) .. "}") - end - local fs = {} - for _, k in ipairs(t.field_order) do - local v = t.fields[k] - table.insert(fs, k .. ": " .. show(v)) - end - table.insert(out, table.concat(fs, "; ")) - table.insert(out, ")") - return table.concat(out) - end - end - if t.typename == "nominal" then if #t.names == 1 and t.names[1] == "@self" then return "self" @@ -5438,9 +5435,9 @@ local function show_type_base(t: Type, short: boolean, seen: {Type:string}): str elseif t.typename == "enum" then return t.names and table.concat(t.names, ".") or "enum" elseif t.typename == "interface" then - return show_record_type("interface") + return short and "interface" or "interface" .. show_fields(t, show) elseif is_record_type(t) then - return show_record_type("record") + return short and "record" or "record" .. show_fields(t, show) elseif t is FunctionType then local out: {string} = {"function"} if t.typeargs then From cc9c0295fe1c6f3d25acab23cafd094f0d8eb1a7 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Wed, 3 Jan 2024 15:18:48 -0300 Subject: [PATCH 076/224] parse_newtype: preparatory refactor --- tl.lua | 19 ++++++++++++++----- tl.tl | 19 ++++++++++++++----- 2 files changed, 28 insertions(+), 10 deletions(-) diff --git a/tl.lua b/tl.lua index 5405788c5..568595703 100644 --- a/tl.lua +++ b/tl.lua @@ -3280,18 +3280,27 @@ parse_newtype = function(ps, i) def = new_type(ps, i, tn) i = i + 1 i = parse_type_body_fns[tn](ps, i, def, node) + if not def then + return fail(ps, i, "expected a type") + end + + node.newtype = new_typetype(ps, itype, def) + return i, node else i, def = parse_type(ps, i) if not def then - return i + return fail(ps, i, "expected a type") + end + + if def.typename == "nominal" then + node.newtype = new_type(ps, itype, "typetype") + node.newtype.def = def + else + node.newtype = new_typetype(ps, itype, def) end - end - if def then - node.newtype = new_typetype(ps, itype, def) return i, node end - return fail(ps, i, "expected a type") end local function parse_assignment_expression_list(ps, i, asgn) diff --git a/tl.tl b/tl.tl index d5bbbaab0..aa882b6ae 100644 --- a/tl.tl +++ b/tl.tl @@ -3280,18 +3280,27 @@ parse_newtype = function(ps: ParseState, i: integer): integer, Node def = new_type(ps, i, tn) i = i + 1 i = parse_type_body_fns[tn](ps, i, def, node) + if not def then + return fail(ps, i, "expected a type") + end + + node.newtype = new_typetype(ps, itype, def) + return i, node else i, def = parse_type(ps, i) if not def then - return i + return fail(ps, i, "expected a type") + end + + if def.typename == "nominal" then + node.newtype = new_type(ps, itype, "typetype") -- TODO "typealias" + node.newtype.def = def -- todo alias_to + else + node.newtype = new_typetype(ps, itype, def) end - end - if def then - node.newtype = new_typetype(ps, itype, def) return i, node end - return fail(ps, i, "expected a type") end local function parse_assignment_expression_list(ps: ParseState, i: integer, asgn: Node): integer, Node From f55650113d28f8ce6428c833b4c845437a540989 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Wed, 3 Jan 2024 15:23:49 -0300 Subject: [PATCH 077/224] refactor: declname and is_alias_node Rename non-nominal-type use of names to declname, and rename Node.is_alias to Node.is_alias_node --- tl.lua | 42 +++++++++++++++++++++++++++--------------- tl.tl | 44 ++++++++++++++++++++++++++++---------------- 2 files changed, 55 insertions(+), 31 deletions(-) diff --git a/tl.lua b/tl.lua index 568595703..f9966e8b5 100644 --- a/tl.lua +++ b/tl.lua @@ -1327,6 +1327,8 @@ local table_types = { + + @@ -3431,8 +3433,12 @@ local function parse_type_declaration(ps, i, node_name) if not asgn.value then return i end - if not asgn.value.newtype.def.names then - asgn.value.newtype.def.names = { asgn.var.tk } + + local nt = asgn.value.newtype + if nt.typename == "typetype" then + if not nt.def.declname then + nt.def.declname = asgn.var.tk + end end return i, asgn @@ -3451,7 +3457,7 @@ local function parse_type_constructor(ps, i, node_name, type_name, parse_body) if not asgn.var then return fail(ps, i, "expected a type name") end - def.names = { asgn.var.tk } + def.declname = asgn.var.tk i = parse_body(ps, i, def, nt) @@ -4747,8 +4753,14 @@ function tl.pretty_print_ast(ast, gen_target, mode) ["newtype"] = { after = function(node, _children) local out = { y = node.y, h = 0 } - if node.is_alias then - table.insert(out, table.concat(node.newtype.def.names, ".")) + if node.is_alias_node then + local def = node.newtype.def + if def.names then + table.insert(out, table.concat(def.names, ".")) + else + assert(def.declname) + table.insert(out, def.declname) + end elseif is_record_type(node.newtype.def) then table.insert(out, print_record_def(node.newtype.def)) else @@ -5351,8 +5363,8 @@ local function display_typevar(typevar) end local function show_fields(t, show) - if t.names then - return t.names[1] + if t.declname then + return " " .. t.declname end local out = {} @@ -5442,7 +5454,7 @@ local function show_type_base(t, short, seen) elseif t.typename == "array" then return "{" .. show(t.elements) .. "}" elseif t.typename == "enum" then - return t.names and table.concat(t.names, ".") or "enum" + return t.declname or "enum" elseif t.typename == "interface" then return short and "interface" or "interface" .. show_fields(t, show) elseif is_record_type(t) then @@ -6708,7 +6720,7 @@ tl.type_check = function(ast, opts) copy.y = t.y copy.yend = t.yend copy.xend = t.xend - copy.names = t.names + copy.declname = t.declname if t.typename == "array" then copy.elements, same = resolve(t.elements, same) @@ -6735,6 +6747,7 @@ tl.type_check = function(ast, opts) elseif is_typetype(t) then copy.def, same = resolve(t.def, same) elseif t.typename == "nominal" then + copy.names = t.names copy.typevals = {} for i, tf in ipairs(t.typevals) do copy.typevals[i], same = resolve(tf, same) @@ -7351,8 +7364,7 @@ tl.type_check = function(ast, opts) end local function are_same_unresolved_global_type(t1, t2) - if #t1.names == 1 and #t2.names == 1 and - t1.names[1] == t2.names[1] then + if t1.names[1] == t2.names[1] then local unresolved = get_unresolved() if unresolved.global_types[t1.names[1]] then @@ -7972,7 +7984,7 @@ a.types[i], b.types[i]), } if find_in_interface_list(a, function(t) return (is_a(t, b)) end) then return true end - if not a.names then + if not a.declname then return subtype_record(a, b) end @@ -10262,7 +10274,7 @@ expand_type(node, values, elements) }) local var = add_var(node.var, name, resolved, node.var.attribute) if aliasing then var.aliasing = aliasing - node.value.is_alias = true + node.value.is_alias_node = true end end, after = function(node, _children) @@ -10280,7 +10292,7 @@ expand_type(node, values, elements) }) node.value.newtype = resolved if aliasing then added.aliasing = aliasing - node.value.is_alias = true + node.value.is_alias_node = true end if added and unresolved.global_types[name] then @@ -11763,7 +11775,7 @@ expand_type(node, values, elements) }) if ftype.is_method then local fargs = ftype.args.tuple if fargs[1] and fargs[1].is_self then - local record_name = typ.names and typ.names[1] + local record_name = typ.declname if record_name then local selfarg = fargs[1] if selfarg.tk ~= record_name or (typ.typeargs and not selfarg.typevals) then diff --git a/tl.tl b/tl.tl index aa882b6ae..8e4e7be84 100644 --- a/tl.tl +++ b/tl.tl @@ -1089,6 +1089,8 @@ local interface Type -- arguments: optional arity opt: boolean + declname: string + -- typetype def: Type is_alias: boolean @@ -1487,7 +1489,7 @@ local record Node -- newtype newtype: Type - is_alias: boolean + is_alias_node: boolean elide_type: boolean -- expressions @@ -3431,8 +3433,12 @@ local function parse_type_declaration(ps: ParseState, i: integer, node_name: Nod if not asgn.value then return i end - if not asgn.value.newtype.def.names then - asgn.value.newtype.def.names = { asgn.var.tk } + + local nt = asgn.value.newtype + if nt.typename == "typetype" then + if not nt.def.declname then + nt.def.declname = asgn.var.tk + end end return i, asgn @@ -3451,7 +3457,7 @@ local function parse_type_constructor(ps: ParseState, i: integer, node_name: Nod if not asgn.var then return fail(ps, i, "expected a type name") end - def.names = { asgn.var.tk } + def.declname = asgn.var.tk i = parse_body(ps, i, def, nt) @@ -4747,8 +4753,14 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | ["newtype"] = { after = function(node: Node, _children: {Output}): Output local out: Output = { y = node.y, h = 0 } - if node.is_alias then - table.insert(out, table.concat(node.newtype.def.names, ".")) + if node.is_alias_node then + local def = node.newtype.def + if def.names then + table.insert(out, table.concat(def.names, ".")) + else + assert(def.declname) + table.insert(out, def.declname) + end elseif is_record_type(node.newtype.def) then table.insert(out, print_record_def(node.newtype.def)) else @@ -5351,8 +5363,8 @@ local function display_typevar(typevar: string): string end local function show_fields(t: Type, show: function(Type):(string)): string - if t.names then - return t.names[1] + if t.declname then + return " " .. t.declname end local out: {string} = {} @@ -5442,7 +5454,7 @@ local function show_type_base(t: Type, short: boolean, seen: {Type:string}): str elseif t.typename == "array" then return "{" .. show(t.elements) .. "}" elseif t.typename == "enum" then - return t.names and table.concat(t.names, ".") or "enum" + return t.declname or "enum" elseif t.typename == "interface" then return short and "interface" or "interface" .. show_fields(t, show) elseif is_record_type(t) then @@ -6708,7 +6720,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string copy.y = t.y copy.yend = t.yend copy.xend = t.xend - copy.names = t.names -- which types have this, exactly? + copy.declname = t.declname -- which types have this, exactly? if t.typename == "array" then copy.elements, same = resolve(t.elements, same) @@ -6735,6 +6747,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string elseif is_typetype(t) then copy.def, same = resolve(t.def, same) elseif t.typename == "nominal" then + copy.names = t.names copy.typevals = {} for i, tf in ipairs(t.typevals) do copy.typevals[i], same = resolve(tf, same) @@ -7351,8 +7364,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end local function are_same_unresolved_global_type(t1: Type, t2: Type): boolean - if #t1.names == 1 and #t2.names == 1 - and t1.names[1] == t2.names[1] + if t1.names[1] == t2.names[1] then local unresolved = get_unresolved() if unresolved.global_types[t1.names[1]] then @@ -7972,7 +7984,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if find_in_interface_list(a, function(t: Type): boolean return (is_a(t, b)) end) then return true end - if not a.names then + if not a.declname then -- match inferred table (anonymous record) structurally to interface return subtype_record(a, b) end @@ -10262,7 +10274,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local var = add_var(node.var, name, resolved, node.var.attribute) if aliasing then var.aliasing = aliasing - node.value.is_alias = true + node.value.is_alias_node = true end end, after = function(node: Node, _children: {Type}): Type @@ -10280,7 +10292,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string node.value.newtype = resolved if aliasing then added.aliasing = aliasing - node.value.is_alias = true + node.value.is_alias_node = true end if added and unresolved.global_types[name] then @@ -11763,7 +11775,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if ftype.is_method then local fargs = ftype.args.tuple if fargs[1] and fargs[1].is_self then - local record_name = typ.names and typ.names[1] + local record_name = typ.declname if record_name then local selfarg = fargs[1] if selfarg.tk ~= record_name or (typ.typeargs and not selfarg.typevals) then From 9845b79d0f8509efa5bc5ee0e3b4d2f7e3116207 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Wed, 3 Jan 2024 17:28:41 -0300 Subject: [PATCH 078/224] add typealias type, split from typetype --- spec/declaration/local_spec.lua | 2 +- tl.lua | 235 ++++++++++++++++++------------- tl.tl | 239 +++++++++++++++++++------------- 3 files changed, 283 insertions(+), 193 deletions(-) diff --git a/spec/declaration/local_spec.lua b/spec/declaration/local_spec.lua index 8eb9906a2..d2936d67d 100644 --- a/spec/declaration/local_spec.lua +++ b/spec/declaration/local_spec.lua @@ -127,7 +127,7 @@ describe("local", function() it("catches unknown types", util.check_type_error([[ local type MyType = UnknownType ]], { - { msg = "UnknownType is not a type" } + { msg = "unknown type UnknownType" } })) it("nominal types can take type arguments", util.check([[ diff --git a/tl.lua b/tl.lua index f9966e8b5..66d5c456c 100644 --- a/tl.lua +++ b/tl.lua @@ -1028,6 +1028,7 @@ end + local table_types = { @@ -1039,6 +1040,7 @@ local table_types = { ["tupletable"] = true, ["typetype"] = false, + ["typealias"] = false, ["typevar"] = false, ["typearg"] = false, ["function"] = false, @@ -1328,6 +1330,9 @@ local table_types = { + + + @@ -1543,10 +1548,6 @@ local function is_number_type(t) return t.typename == "number" or t.typename == "integer" end -local function is_typetype(t) - return t.typename == "typetype" -end - @@ -3191,6 +3192,10 @@ parse_record_body = function(ps, i, def, node) return fail(ps, i, "expected a type definition") end + if nt.newtype.typename == "typealias" then + nt.newtype.is_nested_alias = true + end + store_field_in_record(ps, iv, v.tk, nt.newtype, def.fields, def.field_order) elseif parse_type_body_fns[tn] and ps.tokens[i + 1].tk ~= ":" then i = parse_nested_type(ps, i, def, tn, parse_type_body_fns[tn]) @@ -3295,8 +3300,8 @@ parse_newtype = function(ps, i) end if def.typename == "nominal" then - node.newtype = new_type(ps, itype, "typetype") - node.newtype.def = def + node.newtype = new_type(ps, itype, "typealias") + node.newtype.alias_to = def else node.newtype = new_typetype(ps, itype, def) end @@ -3796,6 +3801,9 @@ local function recurse_type(ast, visit) if ast.def then table.insert(xs, recurse_type(ast.def, visit)) end + if ast.alias_to then + table.insert(xs, recurse_type(ast.alias_to, visit)) + end if ast.typename == "map" then table.insert(xs, recurse_type(ast.keys, visit)) table.insert(xs, recurse_type(ast.values, visit)) @@ -4309,7 +4317,7 @@ function tl.pretty_print_ast(ast, gen_target, mode) local function print_record_def(typ) local out = { "{" } for _, name in ipairs(typ.field_order) do - if is_typetype(typ.fields[name]) and is_record_type(typ.fields[name].def) then + if typ.fields[name].typename == "typetype" and is_record_type(typ.fields[name].def) then table.insert(out, name) table.insert(out, " = ") table.insert(out, print_record_def(typ.fields[name].def)) @@ -4754,12 +4762,12 @@ function tl.pretty_print_ast(ast, gen_target, mode) after = function(node, _children) local out = { y = node.y, h = 0 } if node.is_alias_node then - local def = node.newtype.def - if def.names then - table.insert(out, table.concat(def.names, ".")) + local nt = node.newtype + if nt.typename == "typealias" then + table.insert(out, table.concat(nt.alias_to.names, ".")) else - assert(def.declname) - table.insert(out, def.declname) + assert(nt.typename == "typetype") + table.insert(out, nt.def.declname) end elseif is_record_type(node.newtype.def) then table.insert(out, print_record_def(node.newtype.def)) @@ -4872,6 +4880,7 @@ function tl.pretty_print_ast(ast, gen_target, mode) visit_type.cbs["string"] = default_type_visitor visit_type.cbs["typetype"] = default_type_visitor + visit_type.cbs["typealias"] = default_type_visitor visit_type.cbs["typevar"] = default_type_visitor visit_type.cbs["typearg"] = default_type_visitor visit_type.cbs["function"] = default_type_visitor @@ -4953,6 +4962,7 @@ local typename_to_typecode = { ["table_item"] = tl.typecodes.UNKNOWN, ["unresolved"] = tl.typecodes.UNKNOWN, ["typetype"] = tl.typecodes.UNKNOWN, + ["typealias"] = tl.typecodes.UNKNOWN, ["*"] = tl.typecodes.UNKNOWN, } @@ -5024,8 +5034,10 @@ get_typenum = function(trenv, t) n = trenv.next_num local rt = t - if is_typetype(rt) then + if rt.typename == "typetype" then rt = rt.def + elseif rt.typename == "typealias" then + rt = rt.alias_to elseif rt.typename == "tuple" and #rt.tuple == 1 then rt = rt.tuple[1] end @@ -5047,7 +5059,7 @@ get_typenum = function(trenv, t) if t.resolved then rt = t end - assert(not is_typetype(rt)) + assert(not (rt.typename == "typetype" or rt.typename == "typealias")) if is_record_type(rt) then @@ -5521,8 +5533,10 @@ local function show_type_base(t, short, seen) return "nil" elseif t.typename == "none" then return "" - elseif is_typetype(t) then - return "type " .. show(t.def) .. (t.is_alias and " (alias)" or "") + elseif t.typename == "typealias" then + return "type " .. show(t.alias_to) + elseif t.typename == "typetype" then + return "type " .. show(t.def) else return "<" .. t.typename .. " " .. tostring(t) .. ">" end @@ -6512,14 +6526,18 @@ tl.type_check = function(ast, opts) return nil end end - if is_typetype(typ) or (accept_typearg and typ.typename == "typearg") then + if typ.typename == "typetype" or typ.typename == "typealias" then + return typ + elseif accept_typearg and typ.typename == "typearg" then return typ end end local function union_type(t) - if is_typetype(t) then + if t.typename == "typetype" then return union_type(t.def), t.def + elseif t.typename == "typealias" then + return union_type(t.alias_to), t.alias_to elseif t.typename == "tuple" then return union_type(t.tuple[1]), t.tuple[1] elseif t.typename == "nominal" then @@ -6645,8 +6663,10 @@ tl.type_check = function(ast, opts) end local function resolve_typetype(t) - if is_typetype(t) then + if t.typename == "typetype" then return t.def + elseif t.typename == "typealias" then + return t.alias_to else return t end @@ -6744,8 +6764,11 @@ tl.type_check = function(ast, opts) if t.constraint then copy.constraint, same = resolve(t.constraint, same) end - elseif is_typetype(t) then + elseif t.typename == "typetype" then copy.def, same = resolve(t.def, same) + elseif t.typename == "typealias" then + copy.alias_to, same = resolve(t.alias_to, same) + copy.is_nested_alias = t.is_nested_alias elseif t.typename == "nominal" then copy.names = t.names copy.typevals = {} @@ -6931,7 +6954,8 @@ tl.type_check = function(ast, opts) "unused %s %s: %s", var.is_func_arg and "argument" or t.typename == "function" and "function" or - is_typetype(var.t) and "type" or + t.typename == "typetype" and "type" or + t.typename == "typealias" and "type" or "variable", name, show_type(var.t)) @@ -7151,7 +7175,7 @@ tl.type_check = function(ast, opts) local function close_nested_records(t) for _, ft in pairs(t.fields) do - if is_typetype(ft) then + if ft.typename == "typetype" then ft.closed = true if is_record_type(ft.def) then close_nested_records(ft.def) @@ -7162,10 +7186,11 @@ tl.type_check = function(ast, opts) local function close_types(vars) for _, var in pairs(vars) do - if is_typetype(var.t) then - var.t.closed = true - if is_record_type(var.t.def) then - close_nested_records(var.t.def) + local t = var.t + if t.typename == "typetype" then + t.closed = true + if is_record_type(t.def) then + close_nested_records(t.def) end end end @@ -7184,16 +7209,17 @@ tl.type_check = function(ast, opts) end local list = {} for name, var in pairs(vars) do + local t = var.t if var.declared_at and not var.used then if var.used_as_type then var.declared_at.elide_type = true else - if is_typetype(var.t) and not is_global then + if (t.typename == "typetype" or t.typename == "typealias") and not is_global then var.declared_at.elide_type = true end table.insert(list, { y = var.declared_at.y, x = var.declared_at.x, name = name, var = var }) end - elseif var.used and is_typetype(var.t) and var.aliasing then + elseif var.used and (t.typename == "typetype" or t.typename == "typealias") and var.aliasing then var.aliasing.used = true var.aliasing.declared_at.elide_type = false end @@ -7285,6 +7311,7 @@ tl.type_check = function(ast, opts) end local resolve_nominal + local resolve_typealias do local function match_typevals(t, def) if t.typevals and def.typeargs then @@ -7316,35 +7343,36 @@ tl.type_check = function(ast, opts) return t.resolved end - local resolved - local typetype = t.found or find_type(t.names) if not typetype then error_at(t, "unknown type %s", t) return INVALID end - if not is_typetype(typetype) then - error_at(t, table.concat(t.names, ".") .. " is not a type") - return INVALID - end + local resolved - if typetype.is_alias then - typetype = typetype.def.found - assert(is_typetype(typetype)) + if typetype.typename == "typealias" then + typetype = typetype.alias_to.found end - if typetype.def.typename == "circular_require" then + if typetype.typename == "typetype" then + if typetype.def.typename == "circular_require" then - return typetype.def - end + return typetype.def + end + + if typetype.def.typename == "nominal" then + typetype = typetype.def.found + assert(typetype.typename == "typetype") + end + assert(typetype.def.typename ~= "nominal") - if typetype.def.typename == "nominal" then - typetype = typetype.def.found - assert(is_typetype(typetype)) + resolved = match_typevals(t, typetype.def) + else + error_at(t, table.concat(t.names, ".") .. " is not a type") + return INVALID end - assert(typetype.def.typename ~= "nominal") - resolved = match_typevals(t, typetype.def) + if not resolved then error_at(t, table.concat(t.names, ".") .. " cannot be resolved in scope") return INVALID @@ -7361,6 +7389,36 @@ tl.type_check = function(ast, opts) t.resolved = resolved return resolved end + + resolve_typealias = function(typealias) + local names = typealias.alias_to.names + local aliasing = find_var(names[1], "use_type") + if not aliasing then + return INVALID + end + + local t = typealias.alias_to + if t.resolved then + return t.resolved, aliasing + end + + local typetype = t.found or find_type(t.names) + if not typetype then + error_at(t, "unknown type %s", t) + return INVALID + end + + if t.typevals then + local resolved = match_typevals(t, typetype.def) + t.resolved = resolved + t.found = typetype + typetype = a_type("typetype", { def = resolved }) + else + t.resolved = t + end + + return typetype, aliasing + end end local function are_same_unresolved_global_type(t1, t2) @@ -8331,7 +8389,7 @@ a.types[i], b.types[i]), } end end - if is_typetype(func) and func.def.typename == "record" then + if func.typename == "typetype" and func.def.typename == "record" then func = func.def end @@ -8749,12 +8807,16 @@ a.types[i], b.types[i]), } tbl = find_var_type("string") end - if tbl.is_alias then - return nil, "cannot use a nested type alias as a concrete value" + if tbl.typename == "typetype" then + tbl = tbl.def + elseif tbl.typename == "typealias" then + if tbl.is_nested_alias then + return nil, "cannot use a nested type alias as a concrete value" + else + tbl = resolve_nominal(tbl.alias_to) + end end - tbl = resolve_typetype(tbl) - if tbl.typename == "union" then local t = same_in_all_union_entries(tbl, function(t) return (match_record_key(t, rec, key)) @@ -9804,25 +9866,6 @@ a.types[i], b.types[i]), } node.exps[i].tk == node.vars[i].tk end - local function resolve_nominal_typetype(typetype) - if typetype.def.typename == "nominal" then - local names = typetype.def.names - local aliasing = find_var(names[1], "use_type") - local resolved = typetype - if typetype.def.typevals then - typetype.def = resolve_nominal(typetype.def) - else - resolved = find_type(names) - if (not resolved) or (not is_typetype(resolved)) then - error_at(typetype, "%s is not a type", typetype) - resolved = INVALID - end - end - return resolved, aliasing - end - return typetype, nil - end - local function missing_initializer(node, i, name) if lax then return UNKNOWN @@ -10146,18 +10189,23 @@ expand_type(node, values, elements) }) return ok, t, infertype ~= nil end - local function get_type_declaration(node) - if node.value.kind == "op" and - node.value.op.op == "@funcall" and - node.value.e1.kind == "variable" and - node.value.e1.tk == "require" then + local function get_type_declaration(value) + if value.kind == "op" and + value.op.op == "@funcall" and + value.e1.kind == "variable" and + value.e1.tk == "require" then - local t = special_functions["require"](node.value, find_var_type("require"), a_type("tuple", { tuple = { STRING } }), 0) + local t = special_functions["require"](value, find_var_type("require"), a_type("tuple", { tuple = { STRING } }), 0) if not (t.typename == "invalid") then return t.tuple[1] end else - return resolve_nominal_typetype(node.value.newtype) + local newtype = value.newtype + if newtype.typename == "typealias" then + return resolve_typealias(value.newtype) + else + return value.newtype, nil + end end end @@ -10178,7 +10226,8 @@ expand_type(node, values, elements) }) local is_total = true local missing for _, key in ipairs(t.field_order) do - if not is_typetype(t.fields[key]) then + local ftype = t.fields[key] + if not (ftype.typename == "typetype" or ftype.typename == "typealias") then is_total, missing = total_check_key(key, seen_keys, is_total, missing) end end @@ -10223,7 +10272,7 @@ expand_type(node, values, elements) }) end local var = resolve_tuple_and_nominal(vartype) - if is_typetype(var) then + if var.typename == "typetype" or var.typename == "typealias" then error_at(where, "cannot reassign a type") return nil end @@ -10270,7 +10319,7 @@ expand_type(node, values, elements) }) ["local_type"] = { before = function(node) local name = node.var.tk - local resolved, aliasing = get_type_declaration(node) + local resolved, aliasing = get_type_declaration(node.value) local var = add_var(node.var, name, resolved, node.var.attribute) if aliasing then var.aliasing = aliasing @@ -10287,7 +10336,7 @@ expand_type(node, values, elements) }) local name = node.var.tk local unresolved = get_unresolved() if node.value then - local resolved, aliasing = get_type_declaration(node) + local resolved, aliasing = get_type_declaration(node.value) local added = add_global(node.var, name, resolved) node.value.newtype = resolved if aliasing then @@ -10776,7 +10825,7 @@ expand_type(node, values, elements) }) if not df then error_at(node[i], in_context(node.expected_context, "unknown field " .. ck)) else - if is_typetype(df) then + if df.typename == "typetype" or df.typename == "typealias" then error_at(node[i], in_context(node.expected_context, "cannot reassign a type")) else assert_is_a(node[i], cvtype, df, "in record field", ck) @@ -11558,7 +11607,7 @@ expand_type(node, values, elements) }) return invalid_at(node, "unknown variable: " .. node.tk) end - if is_typetype(t) then + if t.typename == "typetype" then t = typetype_to_nominal(node, node.tk, t, t) end @@ -11734,14 +11783,12 @@ expand_type(node, values, elements) }) begin_scope() add_var(nil, "@self", type_at(typ, a_type("typetype", { def = typ }))) - for name, typ2 in fields_of(typ) do - if typ2.typename == "typetype" then - local resolved, is_alias = resolve_nominal_typetype(typ2) - if is_alias then - typ2.is_alias = true - typ2.def.resolved = resolved - end - add_var(nil, name, resolved) + for fname, ftype in fields_of(typ) do + if ftype.typename == "typealias" then + resolve_nominal(ftype.alias_to) + add_var(nil, fname, ftype) + elseif ftype.typename == "typetype" then + add_var(nil, fname, ftype) end end end, @@ -11863,11 +11910,8 @@ expand_type(node, values, elements) }) assert(typ.typename == "typevar") typ.typevar = t.typearg typ.constraint = t.constraint - else - if t.is_alias then - t = t.def.resolved - end - if not (t.def and t.def.typename == "circular_require") then + elseif t.typename == "typetype" then + if t.def.typename ~= "circular_require" then typ.found = t end end @@ -11950,6 +11994,7 @@ expand_type(node, values, elements) }) visit_type.cbs["string"] = default_type_visitor visit_type.cbs["tupletable"] = default_type_visitor visit_type.cbs["typetype"] = default_type_visitor + visit_type.cbs["typealias"] = default_type_visitor visit_type.cbs["array"] = default_type_visitor visit_type.cbs["map"] = default_type_visitor visit_type.cbs["enum"] = default_type_visitor diff --git a/tl.tl b/tl.tl index 8e4e7be84..8ba217687 100644 --- a/tl.tl +++ b/tl.tl @@ -997,6 +997,7 @@ end local enum TypeName "typetype" + "typealias" "typevar" "typearg" "function" @@ -1039,6 +1040,7 @@ local table_types : {TypeName:boolean} = { ["tupletable"] = true, ["typetype"] = false, + ["typealias"] = false, ["typevar"] = false, ["typearg"] = false, ["function"] = false, @@ -1093,10 +1095,13 @@ local interface Type -- typetype def: Type - is_alias: boolean closed: boolean is_abstract: boolean + -- typealias + alias_to: Type + is_nested_alias: boolean + -- records interface_list: {Type} interfaces_expanded: boolean @@ -1543,10 +1548,6 @@ local function is_number_type(t:Type): boolean return t.typename == "number" or t.typename == "integer" end -local function is_typetype(t: Type): boolean - return t.typename == "typetype" -end - local record ParseState tokens: {Token} errs: {Error} @@ -3191,6 +3192,10 @@ parse_record_body = function(ps: ParseState, i: integer, def: Type, node: Node): return fail(ps, i, "expected a type definition") end + if nt.newtype.typename == "typealias" then + nt.newtype.is_nested_alias = true + end + store_field_in_record(ps, iv, v.tk, nt.newtype, def.fields, def.field_order) elseif parse_type_body_fns[tn] and ps.tokens[i+1].tk ~= ":" then i = parse_nested_type(ps, i, def, tn, parse_type_body_fns[tn]) @@ -3295,8 +3300,8 @@ parse_newtype = function(ps: ParseState, i: integer): integer, Node end if def.typename == "nominal" then - node.newtype = new_type(ps, itype, "typetype") -- TODO "typealias" - node.newtype.def = def -- todo alias_to + node.newtype = new_type(ps, itype, "typealias") + node.newtype.alias_to = def else node.newtype = new_typetype(ps, itype, def) end @@ -3796,6 +3801,9 @@ local function recurse_type(ast: Type, visit: Visitor): T if ast.def then table.insert(xs, recurse_type(ast.def, visit)) end + if ast.alias_to then + table.insert(xs, recurse_type(ast.alias_to, visit)) + end if ast is MapType then table.insert(xs, recurse_type(ast.keys, visit)) table.insert(xs, recurse_type(ast.values, visit)) @@ -4309,7 +4317,7 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | local function print_record_def(typ: Type): string local out: {string} = { "{" } for _, name in ipairs(typ.field_order) do - if is_typetype(typ.fields[name]) and is_record_type(typ.fields[name].def) then + if typ.fields[name].typename == "typetype" and is_record_type(typ.fields[name].def) then table.insert(out, name) table.insert(out, " = ") table.insert(out, print_record_def(typ.fields[name].def)) @@ -4754,12 +4762,12 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | after = function(node: Node, _children: {Output}): Output local out: Output = { y = node.y, h = 0 } if node.is_alias_node then - local def = node.newtype.def - if def.names then - table.insert(out, table.concat(def.names, ".")) + local nt = node.newtype + if nt.typename == "typealias" then + table.insert(out, table.concat(nt.alias_to.names, ".")) else - assert(def.declname) - table.insert(out, def.declname) + assert(nt.typename == "typetype") + table.insert(out, nt.def.declname) end elseif is_record_type(node.newtype.def) then table.insert(out, print_record_def(node.newtype.def)) @@ -4872,6 +4880,7 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | visit_type.cbs["string"] = default_type_visitor visit_type.cbs["typetype"] = default_type_visitor + visit_type.cbs["typealias"] = default_type_visitor visit_type.cbs["typevar"] = default_type_visitor visit_type.cbs["typearg"] = default_type_visitor visit_type.cbs["function"] = default_type_visitor @@ -4953,6 +4962,7 @@ local typename_to_typecode : {TypeName:integer} = { ["table_item"] = tl.typecodes.UNKNOWN, ["unresolved"] = tl.typecodes.UNKNOWN, ["typetype"] = tl.typecodes.UNKNOWN, + ["typealias"] = tl.typecodes.UNKNOWN, ["*"] = tl.typecodes.UNKNOWN, } @@ -5024,8 +5034,10 @@ get_typenum = function(trenv: TypeReportEnv, t: Type): integer n = trenv.next_num local rt = t - if is_typetype(rt) then + if rt.typename == "typetype" then rt = rt.def + elseif rt.typename == "typealias" then + rt = rt.alias_to elseif rt is TupleType and #rt.tuple == 1 then rt = rt.tuple[1] end @@ -5047,7 +5059,7 @@ get_typenum = function(trenv: TypeReportEnv, t: Type): integer if t.resolved then rt = t end - assert(not is_typetype(rt)) + assert(not (rt.typename == "typetype" or rt.typename == "typealias")) if is_record_type(rt) then -- store record field info @@ -5521,8 +5533,10 @@ local function show_type_base(t: Type, short: boolean, seen: {Type:string}): str return "nil" elseif t.typename == "none" then return "" - elseif is_typetype(t) then - return "type " .. show(t.def) .. (t.is_alias and " (alias)" or "") + elseif t.typename == "typealias" then + return "type " .. show(t.alias_to) + elseif t.typename == "typetype" then + return "type " .. show(t.def) else return "<" .. t.typename .. " " .. tostring(t) .. ">" end @@ -6512,14 +6526,18 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return nil end end - if is_typetype(typ) or (accept_typearg and typ is TypeArgType) then + if typ.typename == "typetype" or typ.typename == "typealias" then + return typ + elseif accept_typearg and typ is TypeArgType then return typ end end local function union_type(t: Type): string, Type - if is_typetype(t) then + if t.typename == "typetype" then return union_type(t.def), t.def + elseif t.typename == "typealias" then + return union_type(t.alias_to), t.alias_to elseif t is TupleType then return union_type(t.tuple[1]), t.tuple[1] elseif t.typename == "nominal" then @@ -6645,8 +6663,10 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end local function resolve_typetype(t: Type): Type - if is_typetype(t) then + if t.typename == "typetype" then return t.def + elseif t.typename == "typealias" then + return t.alias_to else return t end @@ -6744,8 +6764,11 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if t.constraint then copy.constraint, same = resolve(t.constraint, same) end - elseif is_typetype(t) then + elseif t.typename == "typetype" then copy.def, same = resolve(t.def, same) + elseif t.typename == "typealias" then + copy.alias_to, same = resolve(t.alias_to, same) + copy.is_nested_alias = t.is_nested_alias elseif t.typename == "nominal" then copy.names = t.names copy.typevals = {} @@ -6931,7 +6954,8 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string "unused %s %s: %s", var.is_func_arg and "argument" or t is FunctionType and "function" - or is_typetype(var.t) and "type" + or t.typename == "typetype" and "type" + or t.typename == "typealias" and "type" or "variable", name, show_type(var.t) @@ -7151,7 +7175,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local function close_nested_records(t: Type) for _, ft in pairs(t.fields) do - if is_typetype(ft) then + if ft.typename == "typetype" then ft.closed = true if is_record_type(ft.def) then close_nested_records(ft.def) @@ -7162,10 +7186,11 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local function close_types(vars: {string:Variable}) for _, var in pairs(vars) do - if is_typetype(var.t) then - var.t.closed = true - if is_record_type(var.t.def) then - close_nested_records(var.t.def) + local t = var.t + if t.typename == "typetype" then + t.closed = true + if is_record_type(t.def) then + close_nested_records(t.def) end end end @@ -7184,16 +7209,17 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end local list: {Unused} = {} for name, var in pairs(vars) do + local t = var.t if var.declared_at and not var.used then if var.used_as_type then var.declared_at.elide_type = true else - if is_typetype(var.t) and not is_global then + if (t.typename == "typetype" or t.typename == "typealias") and not is_global then var.declared_at.elide_type = true end table.insert(list, { y = var.declared_at.y, x = var.declared_at.x, name = name, var = var }) end - elseif var.used and is_typetype(var.t) and var.aliasing then + elseif var.used and (t.typename == "typetype" or t.typename == "typealias") and var.aliasing then var.aliasing.used = true var.aliasing.declared_at.elide_type = false end @@ -7285,6 +7311,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end local resolve_nominal: function(t: Type): Type + local resolve_typealias: function(t: Type): Type, Variable do local function match_typevals(t: Type, def: Type): Type if t.typevals and def.typeargs then @@ -7316,35 +7343,36 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return t.resolved end - local resolved: Type - local typetype = t.found or find_type(t.names) if not typetype then error_at(t, "unknown type %s", t) return INVALID end - if not is_typetype(typetype) then - error_at(t, table.concat(t.names, ".") .. " is not a type") - return INVALID - end + local resolved: Type - if typetype.is_alias then - typetype = typetype.def.found - assert(is_typetype(typetype)) + if typetype.typename == "typealias" then + typetype = typetype.alias_to.found end - if typetype.def.typename == "circular_require" then - -- return, but do not store resolution - return typetype.def - end + if typetype.typename == "typetype" then + if typetype.def.typename == "circular_require" then + -- return, but do not store resolution + return typetype.def + end - if typetype.def.typename == "nominal" then - typetype = typetype.def.found - assert(is_typetype(typetype)) + if typetype.def.typename == "nominal" then + typetype = typetype.def.found + assert(typetype.typename == "typetype") + end + assert(typetype.def.typename ~= "nominal") + + resolved = match_typevals(t, typetype.def) + else + error_at(t, table.concat(t.names, ".") .. " is not a type") + return INVALID end - assert(typetype.def.typename ~= "nominal") - resolved = match_typevals(t, typetype.def) + if not resolved then error_at(t, table.concat(t.names, ".") .. " cannot be resolved in scope") return INVALID @@ -7361,6 +7389,36 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string t.resolved = resolved return resolved end + + resolve_typealias = function(typealias: Type): Type, Variable + local names = typealias.alias_to.names + local aliasing = find_var(names[1], "use_type") + if not aliasing then + return INVALID + end + + local t = typealias.alias_to + if t.resolved then + return t.resolved, aliasing + end + + local typetype = t.found or find_type(t.names) + if not typetype then + error_at(t, "unknown type %s", t) + return INVALID + end + + if t.typevals then + local resolved = match_typevals(t, typetype.def) + t.resolved = resolved + t.found = typetype + typetype = a_typetype { def = resolved } + else + t.resolved = t + end + + return typetype, aliasing + end end local function are_same_unresolved_global_type(t1: Type, t2: Type): boolean @@ -8331,7 +8389,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end -- resolve if prototype - if is_typetype(func) and func.def.typename == "record" then + if func.typename == "typetype" and func.def.typename == "record" then func = func.def end -- resolve if metatable @@ -8749,12 +8807,16 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string tbl = find_var_type("string") -- simulate string metatable end - if tbl.is_alias then - return nil, "cannot use a nested type alias as a concrete value" + if tbl.typename == "typetype" then + tbl = tbl.def + elseif tbl.typename == "typealias" then + if tbl.is_nested_alias then + return nil, "cannot use a nested type alias as a concrete value" + else + tbl = resolve_nominal(tbl.alias_to) + end end - tbl = resolve_typetype(tbl) - if tbl is UnionType then local t = same_in_all_union_entries(tbl, function(t: Type): (Type, Type) return (match_record_key(t, rec, key)) @@ -9804,25 +9866,6 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string and node.exps[i].tk == node.vars[i].tk end - local function resolve_nominal_typetype(typetype: Type): Type, Variable - if typetype.def.typename == "nominal" then - local names = typetype.def.names - local aliasing = find_var(names[1], "use_type") - local resolved = typetype - if typetype.def.typevals then - typetype.def = resolve_nominal(typetype.def) - else - resolved = find_type(names) - if (not resolved) or (not is_typetype(resolved)) then - error_at(typetype, "%s is not a type", typetype) - resolved = INVALID - end - end - return resolved, aliasing - end - return typetype, nil - end - local function missing_initializer(node: Node, i: integer, name: string): Type if lax then return UNKNOWN @@ -10146,18 +10189,23 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return ok, t, infertype ~= nil end - local function get_type_declaration(node: Node): Type, Variable - if node.value.kind == "op" - and node.value.op.op == "@funcall" - and node.value.e1.kind == "variable" - and node.value.e1.tk == "require" + local function get_type_declaration(value: Node): Type, Variable + if value.kind == "op" + and value.op.op == "@funcall" + and value.e1.kind == "variable" + and value.e1.tk == "require" then - local t = special_functions["require"](node.value, find_var_type("require"), a_tuple { STRING }, 0) + local t = special_functions["require"](value, find_var_type("require"), a_tuple { STRING }, 0) if not t is InvalidType then return t.tuple[1] end else - return resolve_nominal_typetype(node.value.newtype) + local newtype = value.newtype + if newtype.typename == "typealias" then + return resolve_typealias(value.newtype) + else + return value.newtype, nil + end end end @@ -10178,7 +10226,8 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local is_total = true local missing: {string} for _, key in ipairs(t.field_order) do - if not is_typetype(t.fields[key]) then + local ftype = t.fields[key] + if not (ftype.typename == "typetype" or ftype.typename == "typealias") then is_total, missing = total_check_key(key, seen_keys, is_total, missing) end end @@ -10223,7 +10272,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end local var = resolve_tuple_and_nominal(vartype) - if is_typetype(var) then + if var.typename == "typetype" or var.typename == "typealias" then error_at(where, "cannot reassign a type") return nil end @@ -10270,7 +10319,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string ["local_type"] = { before = function(node: Node) local name = node.var.tk - local resolved, aliasing = get_type_declaration(node) + local resolved, aliasing = get_type_declaration(node.value) local var = add_var(node.var, name, resolved, node.var.attribute) if aliasing then var.aliasing = aliasing @@ -10287,7 +10336,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local name = node.var.tk local unresolved = get_unresolved() if node.value then - local resolved, aliasing = get_type_declaration(node) + local resolved, aliasing = get_type_declaration(node.value) local added = add_global(node.var, name, resolved) node.value.newtype = resolved if aliasing then @@ -10776,7 +10825,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if not df then error_at(node[i], in_context(node.expected_context, "unknown field " .. ck)) else - if is_typetype(df) then + if df.typename == "typetype" or df.typename == "typealias" then error_at(node[i], in_context(node.expected_context, "cannot reassign a type")) else assert_is_a(node[i], cvtype, df, "in record field", ck) @@ -11558,7 +11607,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return invalid_at(node, "unknown variable: " .. node.tk) end - if is_typetype(t) then + if t.typename == "typetype" then t = typetype_to_nominal(node, node.tk, t, t) end @@ -11734,14 +11783,12 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string begin_scope() add_var(nil, "@self", type_at(typ, a_typetype({ def = typ }))) - for name, typ2 in fields_of(typ) do - if typ2.typename == "typetype" then - local resolved, is_alias = resolve_nominal_typetype(typ2) - if is_alias then - typ2.is_alias = true - typ2.def.resolved = resolved - end - add_var(nil, name, resolved) + for fname, ftype in fields_of(typ) do + if ftype.typename == "typealias" then + resolve_nominal(ftype.alias_to) + add_var(nil, fname, ftype) + elseif ftype.typename == "typetype" then + add_var(nil, fname, ftype) end end end, @@ -11863,11 +11910,8 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string assert(typ is TypeVarType) typ.typevar = t.typearg typ.constraint = t.constraint - else - if t.is_alias then - t = t.def.resolved - end - if not (t.def and t.def.typename == "circular_require") then + elseif t.typename == "typetype" then + if t.def.typename ~= "circular_require" then typ.found = t end end @@ -11950,6 +11994,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string visit_type.cbs["string"] = default_type_visitor visit_type.cbs["tupletable"] = default_type_visitor visit_type.cbs["typetype"] = default_type_visitor + visit_type.cbs["typealias"] = default_type_visitor visit_type.cbs["array"] = default_type_visitor visit_type.cbs["map"] = default_type_visitor visit_type.cbs["enum"] = default_type_visitor From 19a3a0445a7e3abf8a4be3fa14406cfbdde524f3 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Thu, 4 Jan 2024 00:19:27 -0300 Subject: [PATCH 079/224] remove Node.is_alias_node --- tl.lua | 14 +++----------- tl.tl | 14 +++----------- 2 files changed, 6 insertions(+), 22 deletions(-) diff --git a/tl.lua b/tl.lua index 66d5c456c..a71b658ea 100644 --- a/tl.lua +++ b/tl.lua @@ -1532,7 +1532,6 @@ local Node = {ExpectedContext = {}, } - local function is_array_type(t) @@ -4761,14 +4760,9 @@ function tl.pretty_print_ast(ast, gen_target, mode) ["newtype"] = { after = function(node, _children) local out = { y = node.y, h = 0 } - if node.is_alias_node then - local nt = node.newtype - if nt.typename == "typealias" then - table.insert(out, table.concat(nt.alias_to.names, ".")) - else - assert(nt.typename == "typetype") - table.insert(out, nt.def.declname) - end + local nt = node.newtype + if nt.typename == "typealias" then + table.insert(out, table.concat(nt.alias_to.names, ".")) elseif is_record_type(node.newtype.def) then table.insert(out, print_record_def(node.newtype.def)) else @@ -10323,7 +10317,6 @@ expand_type(node, values, elements) }) local var = add_var(node.var, name, resolved, node.var.attribute) if aliasing then var.aliasing = aliasing - node.value.is_alias_node = true end end, after = function(node, _children) @@ -10341,7 +10334,6 @@ expand_type(node, values, elements) }) node.value.newtype = resolved if aliasing then added.aliasing = aliasing - node.value.is_alias_node = true end if added and unresolved.global_types[name] then diff --git a/tl.tl b/tl.tl index 8ba217687..b099c2501 100644 --- a/tl.tl +++ b/tl.tl @@ -1494,7 +1494,6 @@ local record Node -- newtype newtype: Type - is_alias_node: boolean elide_type: boolean -- expressions @@ -4761,14 +4760,9 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | ["newtype"] = { after = function(node: Node, _children: {Output}): Output local out: Output = { y = node.y, h = 0 } - if node.is_alias_node then - local nt = node.newtype - if nt.typename == "typealias" then - table.insert(out, table.concat(nt.alias_to.names, ".")) - else - assert(nt.typename == "typetype") - table.insert(out, nt.def.declname) - end + local nt = node.newtype + if nt.typename == "typealias" then + table.insert(out, table.concat(nt.alias_to.names, ".")) elseif is_record_type(node.newtype.def) then table.insert(out, print_record_def(node.newtype.def)) else @@ -10323,7 +10317,6 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local var = add_var(node.var, name, resolved, node.var.attribute) if aliasing then var.aliasing = aliasing - node.value.is_alias_node = true end end, after = function(node: Node, _children: {Type}): Type @@ -10341,7 +10334,6 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string node.value.newtype = resolved if aliasing then added.aliasing = aliasing - node.value.is_alias_node = true end if added and unresolved.global_types[name] then From 30739c35e5321bc8d4379ef470bd07abf4eead0f Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Fri, 5 Jan 2024 21:03:48 -0300 Subject: [PATCH 080/224] NominalType, InterfaceType, RecordType, ArrayType --- tl.lua | 583 ++++++++++++++++++++++++---------------- tl.tl | 817 +++++++++++++++++++++++++++++++++------------------------ 2 files changed, 835 insertions(+), 565 deletions(-) diff --git a/tl.lua b/tl.lua index a71b658ea..63149fa76 100644 --- a/tl.lua +++ b/tl.lua @@ -1316,6 +1316,44 @@ local table_types = { + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + @@ -1533,15 +1571,6 @@ local Node = {ExpectedContext = {}, } - -local function is_array_type(t) - - return t.typename == "array" or t.elements ~= nil -end - -local function is_record_type(t) - return t.typename == "record" or t.typename == "interface" -end local function is_number_type(t) return t.typename == "number" or t.typename == "integer" @@ -1669,6 +1698,7 @@ end + local function c_tuple(t) return a_type("tuple", { tuple = t }) end @@ -3070,11 +3100,14 @@ end local function parse_where_clause(ps, i) local node = new_node(ps.tokens, i, "macroexp") + + local selftype = new_type(ps, i, "nominal") + selftype.names = { "@self" } + node.args = new_node(ps.tokens, i, "argument_list") node.args[1] = new_node(ps.tokens, i, "argument") node.args[1].tk = "self" - node.args[1].argtype = new_type(ps, i, "nominal") - node.args[1].argtype.names = { "@self" } + node.args[1].argtype = selftype node.rets = new_tuple(ps, i) node.rets.tuple[1] = BOOLEAN i, node.exp = parse_expression(ps, i) @@ -3086,22 +3119,25 @@ parse_interface_name = function(ps, i) local istart = i local typ i, typ = parse_simple_type_or_nominal(ps, i) - if typ.typename ~= "nominal" then + if not (typ.typename == "nominal") then return fail(ps, istart, "expected an interface") end return i, typ end local function parse_array_interface_type(ps, i, def) - if def.interface_list and def.interface_list[1].typename == "array" then - return failskip(ps, i, "duplicated declaration of array element type", parse_type) + if def.interface_list then + local first = def.interface_list[1] + if first.typename == "array" then + return failskip(ps, i, "duplicated declaration of array element type", parse_type) + end end local t i, t = parse_base_type(ps, i) if not t then return i end - if t.typename ~= "array" then + if not (t.typename == "array") then fail(ps, i, "expected an array declaration") return i end @@ -3152,16 +3188,22 @@ parse_record_body = function(ps, i, def, node) local where_macroexp i, where_macroexp = parse_where_clause(ps, i) - def.meta_fields = {} - def.meta_field_order = {} - local typ = new_type(ps, wstart, "function") typ.is_method = true - typ.args = a_type("tuple", { tuple = { a_type("nominal", { y = typ.y, x = typ.x, filename = ps.filename, names = { "@self" } }) } }) + typ.args = a_type("tuple", { tuple = { + a_type("nominal", { + y = typ.y, + x = typ.x, + filename = ps.filename, + names = { "@self" }, + }), + } }) typ.rets = a_type("tuple", { tuple = { BOOLEAN } }) typ.macroexp = where_macroexp typ.is_abstract = true + def.meta_fields = {} + def.meta_field_order = {} store_field_in_record(ps, i, "__is", typ, def.meta_fields, def.meta_field_order) end @@ -3440,8 +3482,11 @@ local function parse_type_declaration(ps, i, node_name) local nt = asgn.value.newtype if nt.typename == "typetype" then - if not nt.def.declname then - nt.def.declname = asgn.var.tk + local def = nt.def + if def.fields or def.typename == "enum" then + if not def.declname then + def.declname = asgn.var.tk + end end end @@ -3461,6 +3506,8 @@ local function parse_type_constructor(ps, i, node_name, type_name, parse_body) if not asgn.var then return fail(ps, i, "expected a type name") end + + assert(def.typename == "record" or def.typename == "interface" or def.typename == "enum") def.declname = asgn.var.tk i = parse_body(ps, i, def, nt) @@ -3776,51 +3823,47 @@ local function recurse_type(ast, visit) local xs = {} - if ast.typeargs then - for _, child in ipairs(ast.typeargs) do - table.insert(xs, recurse_type(child, visit)) - end - end - if ast.typename == "tuple" then for i, child in ipairs(ast.tuple) do xs[i] = recurse_type(child, visit) end - end - if ast.types then + elseif ast.types then for _, child in ipairs(ast.types) do table.insert(xs, recurse_type(child, visit)) end - end - if ast.interface_list then - for _, child in ipairs(ast.interface_list) do - table.insert(xs, recurse_type(child, visit)) - end - end - if ast.def then - table.insert(xs, recurse_type(ast.def, visit)) - end - if ast.alias_to then - table.insert(xs, recurse_type(ast.alias_to, visit)) - end - if ast.typename == "map" then + elseif ast.typename == "map" then table.insert(xs, recurse_type(ast.keys, visit)) table.insert(xs, recurse_type(ast.values, visit)) - end - if ast.elements then - table.insert(xs, recurse_type(ast.elements, visit)) - end - if ast.fields then - for _, child in fields_of(ast) do - table.insert(xs, recurse_type(child, visit)) + elseif ast.fields then + if ast.typeargs then + for _, child in ipairs(ast.typeargs) do + table.insert(xs, recurse_type(child, visit)) + end end - end - if ast.meta_fields then - for _, child in fields_of(ast, "meta") do - table.insert(xs, recurse_type(child, visit)) + if ast.interface_list then + for _, child in ipairs(ast.interface_list) do + table.insert(xs, recurse_type(child, visit)) + end + end + if ast.elements then + table.insert(xs, recurse_type(ast.elements, visit)) + end + if ast.fields then + for _, child in fields_of(ast) do + table.insert(xs, recurse_type(child, visit)) + end + end + if ast.meta_fields then + for _, child in fields_of(ast, "meta") do + table.insert(xs, recurse_type(child, visit)) + end + end + elseif ast.typename == "function" then + if ast.typeargs then + for _, child in ipairs(ast.typeargs) do + table.insert(xs, recurse_type(child, visit)) + end end - end - if ast.typename == "function" then if ast.args then for i, child in ipairs(ast.args.tuple) do if i > 1 or not ast.is_method or child.is_self then @@ -3833,22 +3876,33 @@ local function recurse_type(ast, visit) table.insert(xs, recurse_type(child, visit)) end end - end - if ast.typevals then - for _, child in ipairs(ast.typevals) do - table.insert(xs, recurse_type(child, visit)) + elseif ast.typename == "nominal" then + if ast.typevals then + for _, child in ipairs(ast.typevals) do + table.insert(xs, recurse_type(child, visit)) + end end - end - if ast.ktype then - table.insert(xs, recurse_type(ast.ktype, visit)) - end - if ast.vtype then - table.insert(xs, recurse_type(ast.vtype, visit)) - end - if ast.typename == "typearg" then + elseif ast.typename == "typearg" then if ast.constraint then table.insert(xs, recurse_type(ast.constraint, visit)) end + elseif ast.typename == "array" then + if ast.elements then + table.insert(xs, recurse_type(ast.elements, visit)) + end + else + if ast.def then + table.insert(xs, recurse_type(ast.def, visit)) + end + if ast.alias_to then + table.insert(xs, recurse_type(ast.alias_to, visit)) + end + if ast.ktype then + table.insert(xs, recurse_type(ast.ktype, visit)) + end + if ast.vtype then + table.insert(xs, recurse_type(ast.vtype, visit)) + end end local ret @@ -4316,7 +4370,8 @@ function tl.pretty_print_ast(ast, gen_target, mode) local function print_record_def(typ) local out = { "{" } for _, name in ipairs(typ.field_order) do - if typ.fields[name].typename == "typetype" and is_record_type(typ.fields[name].def) then + local def = typ.fields[name].def + if typ.fields[name].typename == "typetype" and def.fields then table.insert(out, name) table.insert(out, " = ") table.insert(out, print_record_def(typ.fields[name].def)) @@ -4763,10 +4818,13 @@ function tl.pretty_print_ast(ast, gen_target, mode) local nt = node.newtype if nt.typename == "typealias" then table.insert(out, table.concat(nt.alias_to.names, ".")) - elseif is_record_type(node.newtype.def) then - table.insert(out, print_record_def(node.newtype.def)) - else - table.insert(out, "{}") + elseif nt.typename == "typetype" then + local def = nt.def + if def.fields then + table.insert(out, print_record_def(def)) + else + table.insert(out, "{}") + end end return out end, @@ -4863,10 +4921,11 @@ function tl.pretty_print_ast(ast, gen_target, mode) local default_type_visitor = { after = function(typ, _children) local out = { y = typ.y or -1, h = 0 } - local r = typ.resolved or typ - local lua_type = primitive[r.typename] or - (r.is_userdata and "userdata") or - "table" + local r = typ.typename == "nominal" and typ.resolved or typ + local lua_type = primitive[r.typename] or "table" + if r.fields and r.is_userdata then + lua_type = "userdata" + end table.insert(out, lua_type) return out end, @@ -5047,15 +5106,17 @@ get_typenum = function(trenv, t) trenv.typeid_to_num[t.typeid] = n trenv.next_num = trenv.next_num + 1 - if t.found then - ti.ref = get_typenum(trenv, t.found) - end - if t.resolved then - rt = t + if t.typename == "nominal" then + if t.found then + ti.ref = get_typenum(trenv, t.found) + end + if t.resolved then + rt = t + end end assert(not (rt.typename == "typetype" or rt.typename == "typealias")) - if is_record_type(rt) then + if rt.fields then local r = {} for _, k in ipairs(rt.field_order) do @@ -5065,7 +5126,7 @@ get_typenum = function(trenv, t) ti.fields = r end - if is_array_type(rt) then + if rt.elements then ti.elements = get_typenum(trenv, rt.elements) end @@ -5461,10 +5522,8 @@ local function show_type_base(t, short, seen) return "{" .. show(t.elements) .. "}" elseif t.typename == "enum" then return t.declname or "enum" - elseif t.typename == "interface" then - return short and "interface" or "interface" .. show_fields(t, show) - elseif is_record_type(t) then - return short and "record" or "record" .. show_fields(t, show) + elseif t.fields then + return short and t.typename or t.typename .. show_fields(t, show) elseif t.typename == "function" then local out = { "function" } if t.typeargs then @@ -5785,6 +5844,7 @@ local function init_globals(lax) local function a_record(t) t = a_type("record", t) + assert(t.fields) t.field_order = sorted_keys(t.fields) return t end @@ -6502,23 +6562,28 @@ tl.type_check = function(ast, opts) if not typ then return nil end - if typ.found then + if typ.typename == "nominal" and typ.found then typ = typ.found end for i = 2, #names do - local fields = typ.fields or (typ.def and typ.def.fields) - if fields then - typ = fields[names[i]] - if typ == nil then - return nil - end - typ = ensure_fresh_typeargs(typ) - if typ.found then - typ = typ.found - end - else + if typ.typename == "typetype" then + typ = typ.def + end + + local fields = typ.fields and typ.fields + if not fields then + return nil + end + + typ = fields[names[i]] + if typ == nil then return nil end + + typ = ensure_fresh_typeargs(typ) + if typ.typename == "nominal" and typ.found then + typ = typ.found + end end if typ.typename == "typetype" or typ.typename == "typealias" then return typ @@ -6540,7 +6605,7 @@ tl.type_check = function(ast, opts) return "invalid" end return union_type(typetype) - elseif t.typename == "record" then + elseif t.fields then if t.is_userdata then return "userdata", t end @@ -6565,6 +6630,7 @@ tl.type_check = function(ast, opts) for _, t in ipairs(typ.types) do local ut, rt = union_type(t) if ut == "userdata" then + assert(rt.fields) if rt.meta_fields and rt.meta_fields["__is"] then n_userdata_is_types = n_userdata_is_types + 1 if n_userdata_types > 0 then @@ -6580,7 +6646,7 @@ tl.type_check = function(ast, opts) end end elseif ut == "table" then - if rt.meta_fields and rt.meta_fields["__is"] then + if rt.fields and rt.meta_fields and rt.meta_fields["__is"] then n_table_is_types = n_table_is_types + 1 if n_table_types > 0 then return false, "cannot mix table types with and without __is metamethod: %s" @@ -6726,7 +6792,6 @@ tl.type_check = function(ast, opts) seen[orig_t] = copy copy.opt = t.opt - copy.is_userdata = t.is_userdata copy.is_abstract = t.is_abstract copy.typename = t.typename copy.filename = t.filename @@ -6734,9 +6799,10 @@ tl.type_check = function(ast, opts) copy.y = t.y copy.yend = t.yend copy.xend = t.xend - copy.declname = t.declname if t.typename == "array" then + assert(copy.typename == "array") + copy.elements, same = resolve(t.elements, same) elseif t.typename == "typearg" then @@ -6764,6 +6830,7 @@ tl.type_check = function(ast, opts) copy.alias_to, same = resolve(t.alias_to, same) copy.is_nested_alias = t.is_nested_alias elseif t.typename == "nominal" then + assert(copy.typename == "nominal") copy.names = t.names copy.typevals = {} for i, tf in ipairs(t.typevals) do @@ -6771,6 +6838,8 @@ tl.type_check = function(ast, opts) end copy.found = t.found elseif t.typename == "function" then + assert(copy.typename == "function") + if t.typeargs then copy.typeargs = {} for i, tf in ipairs(t.typeargs) do @@ -6779,12 +6848,14 @@ tl.type_check = function(ast, opts) end set_min_arity(t) - assert(copy.typename == "function") copy.min_arity = t.min_arity copy.is_method = t.is_method copy.args, same = resolve(t.args, same) copy.rets, same = resolve(t.rets, same) - elseif is_record_type(t) then + elseif t.fields then + assert(copy.typename == "record" or copy.typename == "interface") + copy.declname = t.declname + if t.typeargs then copy.typeargs = {} for i, tf in ipairs(t.typeargs) do @@ -6797,6 +6868,8 @@ tl.type_check = function(ast, opts) copy.elements, same = resolve(t.elements, same) end + copy.is_userdata = t.is_userdata + copy.fields = {} copy.field_order = {} for i, k in ipairs(t.field_order) do @@ -6853,7 +6926,11 @@ tl.type_check = function(ast, opts) if errs then return false, INVALID, errs end - if copy.typeargs and not same then + + if (not same) and + (copy.typename == "function" or copy.fields) and + copy.typeargs then + for i = #copy.typeargs, 1, -1 do if resolved[copy.typeargs[i].typearg] then table.remove(copy.typeargs, i) @@ -6994,6 +7071,7 @@ tl.type_check = function(ast, opts) assert(where.y) add_errs_prefixing(where, errs, errors, "") end + if ret == t or t.typename == "typevar" then ret = shallow_copy_table(ret) end @@ -7005,6 +7083,7 @@ tl.type_check = function(ast, opts) if ret.typename == "invalid" then ret = t end + if ret == t or t.typename == "typevar" then ret = shallow_copy_table(ret) end @@ -7171,8 +7250,9 @@ tl.type_check = function(ast, opts) for _, ft in pairs(t.fields) do if ft.typename == "typetype" then ft.closed = true - if is_record_type(ft.def) then - close_nested_records(ft.def) + local def = ft.def + if def.fields then + close_nested_records(def) end end end @@ -7183,8 +7263,9 @@ tl.type_check = function(ast, opts) local t = var.t if t.typename == "typetype" then t.closed = true - if is_record_type(t.def) then - close_nested_records(t.def) + local def = t.def + if def.fields then + close_nested_records(def) end end end @@ -7350,18 +7431,21 @@ tl.type_check = function(ast, opts) end if typetype.typename == "typetype" then - if typetype.def.typename == "circular_require" then + local def = typetype.def + if def.typename == "circular_require" then return typetype.def end - if typetype.def.typename == "nominal" then - typetype = typetype.def.found + + if def.typename == "nominal" then + typetype = def.found assert(typetype.typename == "typetype") + def = typetype.def end - assert(typetype.def.typename ~= "nominal") + assert(not (def.typename == "nominal")) - resolved = match_typevals(t, typetype.def) + resolved = match_typevals(t, def) else error_at(t, table.concat(t.names, ".") .. " is not a type") return INVALID @@ -7417,7 +7501,6 @@ tl.type_check = function(ast, opts) local function are_same_unresolved_global_type(t1, t2) if t1.names[1] == t2.names[1] then - local unresolved = get_unresolved() if unresolved.global_types[t1.names[1]] then return true @@ -7553,7 +7636,8 @@ tl.type_check = function(ast, opts) is_lua_table_type = function(t) - return known_table_types[t.typename] and not t.is_userdata + return known_table_types[t.typename] and + not (t.fields and t.is_userdata) end end @@ -7569,10 +7653,11 @@ tl.type_check = function(ast, opts) local arr_type = a_type("array", { elements = tupletype.types[1] }) for i = 2, #tupletype.types do - arr_type = expand_type(where, arr_type, a_type("array", { elements = tupletype.types[i] })) - if not arr_type.elements then + local expanded = expand_type(where, arr_type, a_type("array", { elements = tupletype.types[i] })) + if not (expanded.typename == "array") then return nil, { Err(tupletype, "unable to convert tuple %s to array", tupletype) } end + arr_type = expanded end return arr_type end @@ -7600,7 +7685,6 @@ tl.type_check = function(ast, opts) end local function subtype_array(a, b) - if (not a.elements) or (not is_a(a.elements, b.elements)) then return false end @@ -8101,7 +8185,10 @@ a.types[i], b.types[i]), } }, ["typetype"] = { ["record"] = function(a, b) - return subtype_record(a.def, b) + local def = a.def + if def.fields then + return subtype_record(a.def, b) + end end, }, ["function"] = { @@ -8312,10 +8399,12 @@ a.types[i], b.types[i]), } if same_type(t, NIL) then return true end - if t.typename ~= "function" then + if t.typename == "nominal" then t = resolve_nominal(t) end - return t.meta_fields and t.meta_fields["__close"] ~= nil + if t.fields then + return t.meta_fields and t.meta_fields["__close"] ~= nil + end end local definitely_not_closable_exprs = { @@ -8355,13 +8444,15 @@ a.types[i], b.types[i]), } local function same_call_mt_in_all_union_entries(u) return same_in_all_union_entries(u, function(t) t = resolve_tuple_and_nominal(t) - local call_mt = t.meta_fields and t.meta_fields["__call"] - if call_mt.typename == "function" then - local args_tuple = a_type("tuple", { tuple = {} }) - for i = 2, #call_mt.args.tuple do - table.insert(args_tuple.tuple, call_mt.args.tuple[i]) + if t.fields then + local call_mt = t.meta_fields and t.meta_fields["__call"] + if call_mt.typename == "function" then + local args_tuple = a_type("tuple", { tuple = {} }) + for i = 2, #call_mt.args.tuple do + table.insert(args_tuple.tuple, call_mt.args.tuple[i]) + end + return args_tuple, call_mt end - return args_tuple, call_mt end end) end @@ -8383,11 +8474,12 @@ a.types[i], b.types[i]), } end end - if func.typename == "typetype" and func.def.typename == "record" then + local funcdef = func.def + if func.typename == "typetype" and funcdef.typename == "record" then func = func.def end - if func.meta_fields and func.meta_fields["__call"] then + if func.fields and func.meta_fields and func.meta_fields["__call"] then table.insert(args.tuple, 1, func) func = func.meta_fields["__call"] func = resolve_tuple_and_nominal(func) @@ -8645,7 +8737,7 @@ a.types[i], b.types[i]), } return resolve_typevars_at(where, f.rets) end - local function check_call(where, where_args, func, args, expected_rets, typetype_funcall, is_method, argdelta) + local function check_call(where, where_args, func, args, expected_rets, is_typetype_funcall, is_method, argdelta) assert(type(func) == "table") assert(type(args) == "table") @@ -8678,7 +8770,7 @@ a.types[i], b.types[i]), } if f.is_method and not is_method then if args.tuple[1] and is_a(args.tuple[1], fargs[1]) then - if not typetype_funcall then + if not is_typetype_funcall then add_warning("hint", where, "invoked method as a regular function: consider using ':' instead of '.'") end else @@ -8735,16 +8827,15 @@ a.types[i], b.types[i]), } begin_scope() - local typetype_funcall = not not ( - node.kind == "op" and - node.op.op == "@funcall" and - node.e1 and - node.e1.receiver and - node.e1.receiver.resolved and - node.e1.receiver.resolved.typename == "typetype") - + local is_typetype_funcall + if node.kind == "op" and node.op.op == "@funcall" and node.e1 and node.e1.receiver then + local receiver = node.e1.receiver + if receiver.typename == "nominal" and receiver.resolved and receiver.resolved.typename == "typetype" then + is_typetype_funcall = true + end + end - local ret, f = check_call(node, where_args, func, args, expected_rets, typetype_funcall, is_method, argdelta) + local ret, f = check_call(node, where_args, func, args, expected_rets, is_typetype_funcall, is_method, argdelta) ret = resolve_typevars_at(node, ret) end_scope() @@ -8763,17 +8854,21 @@ a.types[i], b.types[i]), } local function check_metamethod(node, method_name, a, b, orig_a, orig_b) if lax and ((a and is_unknown(a)) or (b and is_unknown(b))) then return UNKNOWN, nil - elseif not a.meta_fields and not (b and b.meta_fields) then + end + local ameta = a.fields and a.meta_fields + local bmeta = b and b.fields and b.meta_fields + + if not ameta and not bmeta then return nil, nil end local meta_on_operator = 1 local metamethod if method_name ~= "__is" then - metamethod = a.meta_fields and a.meta_fields[method_name or ""] + metamethod = ameta and ameta[method_name or ""] end if (not metamethod) and b and method_name ~= "__index" then - metamethod = b.meta_fields and b.meta_fields[method_name or ""] + metamethod = bmeta and bmeta[method_name or ""] meta_on_operator = 2 end @@ -8829,7 +8924,7 @@ a.types[i], b.types[i]), } end end - if is_record_type(tbl) then + if tbl.fields then assert(tbl.fields, "record has no fields!?") if tbl.fields[key] then @@ -9049,7 +9144,7 @@ a.types[i], b.types[i]), } if t.typename == "nominal" then t = resolve_nominal(t) end - assert(t.typename ~= "nominal") + assert(not (t.typename == "nominal")) return t end @@ -9155,7 +9250,7 @@ a.types[i], b.types[i]), } errm = "cannot index this tuple with a variable because it would produce a union type that cannot be discriminated at runtime" end - elseif is_array_type(a) and is_a(b, INTEGER) then + elseif a.elements and is_a(b, INTEGER) then return a.elements elseif a.typename == "emptytable" then if a.keys == nil then @@ -9185,7 +9280,7 @@ a.types[i], b.types[i]), } end errm, erra = e, orig_a - elseif is_record_type(a) then + elseif a.fields then if b.typename == "enum" then local field_names = sorted_keys(b.enumset) for _, k in ipairs(field_names) do @@ -9220,7 +9315,7 @@ a.types[i], b.types[i]), } return new else if not is_a(new, old) then - if old.typename == "map" and is_record_type(new) then + if old.typename == "map" and new.fields then if old.keys.typename == "string" then for _, ftype in fields_of(new) do old.values = expand_type(where, old.values, ftype) @@ -9229,26 +9324,32 @@ a.types[i], b.types[i]), } else error_at(where, "cannot determine table literal type") end - elseif is_record_type(old) and is_record_type(new) then - edit_type(old, "map") - assert(old.typename == "map") - old.keys = STRING + elseif old.fields and new.fields then + local values for _, ftype in fields_of(old) do - if not old.values then - old.values = ftype + if not values then + values = ftype else - old.values = expand_type(where, old.values, ftype) + values = expand_type(where, values, ftype) end end for _, ftype in fields_of(new) do - if not old.values then - old.values = ftype + if not values then + values = ftype else - old.values = expand_type(where, old.values, ftype) + values = expand_type(where, values, ftype) end end old.fields = nil old.field_order = nil + old.meta_fields = nil + old.meta_fields = nil + + + edit_type(old, "map") + assert(old.typename == "map") + old.keys = STRING + old.values = values elseif old.typename == "union" then edit_type(old, "union") new.tk = nil @@ -9283,9 +9384,18 @@ a.types[i], b.types[i]), } if not t then return nil, nil, dname end - t = t and t.fields and t.fields[fname] + if not t.fields then + return nil, nil, dname + end + t = t.fields[fname] - return t.def or t, v, dname + if t.typename == "typetype" then + t = t.def + elseif t.typename == "typealias" then + t = t.alias_to.resolved + end + + return t, v, dname end end @@ -9293,9 +9403,10 @@ a.types[i], b.types[i]), } assert(t.typename == "typetype") local typevals - if t.def.typeargs then + local def = t.def + if def.typeargs then typevals = {} - for _, a in ipairs(t.def.typeargs) do + for _, a in ipairs(def.typeargs) do table.insert(typevals, a_type("typevar", { typevar = a.typearg, constraint = a.constraint, @@ -9331,11 +9442,12 @@ a.types[i], b.types[i]), } end if t.typename == "nominal" then - if t.found and t.found.def and t.found.def.fields and t.found.def.fields[exp.e2.tk] then + local def = t.found and t.found.def + if def.fields and def.fields[exp.e2.tk] then table.insert(t.names, exp.e2.tk) - t.found = t.found.def.fields[exp.e2.tk] + t.found = def.fields[exp.e2.tk] end - else + elseif t.fields then return t.fields and t.fields[exp.e2.tk] end return t @@ -9749,13 +9861,13 @@ a.types[i], b.types[i]), } return invalid_at(node, "pairs requires an argument") end local t = resolve_tuple_and_nominal(b.tuple[1]) - if is_array_type(t) then + if t.elements then add_warning("hint", node, "hint: applying pairs on an array: did you intend to apply ipairs?") end if t.typename ~= "map" then if not (lax and is_unknown(t)) then - if is_record_type(t) then + if t.fields then match_all_record_field_names(node.e2, t, t.field_order, "attempting pairs on a record with attributes of different types") local ct = t.typename == "record" and "{string:any}" or "{any:any}" @@ -9781,7 +9893,7 @@ a.types[i], b.types[i]), } if not arr_type then return invalid_at(node.e2, "attempting ipairs on tuple that's not a valid array: %s", orig_t) end - elseif not is_array_type(t) then + elseif not t.elements then if not (lax and (is_unknown(t) or t.typename == "emptytable")) then return invalid_at(node.e2, "attempting ipairs on something that's not an array: %s", orig_t) end @@ -10151,7 +10263,7 @@ expand_type(node, values, elements) }) ok = false elseif not (node.exps[i] and node.exps[i].attribute == "total") then local ri = resolve_tuple_and_nominal(infertype) - if ri.typename ~= "map" and ri.typename ~= "record" then + if not (ri.typename == "map" or ri.typename == "record") then error_at(var, "attribute only applies to maps and records") ok = false elseif not ri.is_total then @@ -10166,8 +10278,8 @@ expand_type(node, values, elements) }) error_at(var, "record variable declared does not declare values for all fields" .. missing) ok = false end + ri.is_total = nil end - ri.is_total = nil end end @@ -10177,8 +10289,9 @@ expand_type(node, values, elements) }) elseif t.typename == "emptytable" then t.declared_at = node t.assigned_to = name + elseif t.elements then + t.inferred_len = nil end - t.inferred_len = nil return ok, t, infertype ~= nil end @@ -10395,7 +10508,10 @@ expand_type(node, values, elements) }) local where = node.exps[i] or node.exps local rt = resolve_tuple_and_nominal(t) - if rt.typename ~= "enum" and (t.typename ~= "nominal" or rt.typename == "union") and not same_type(t, infertype) then + if (not (rt.typename == "enum")) and + ((not (t.typename == "nominal")) or (rt.typename == "union")) and + not same_type(t, infertype) then + t = infer_at(where, infertype) add_var(where, var.tk, t, "const", "narrowed_declaration") end @@ -10729,7 +10845,7 @@ expand_type(node, values, elements) }) child.value.expected = decltype.types[n] end end - elseif is_array_type(decltype) then + elseif decltype.elements then for _, child in ipairs(node) do if child.key.constnum then child.value.expected = decltype.elements @@ -10742,7 +10858,7 @@ expand_type(node, values, elements) }) end end - if is_record_type(decltype) then + if decltype.fields then for _, child in ipairs(node) do if child.key.conststr then child.value.expected = decltype.fields[child.key.conststr] @@ -10795,9 +10911,6 @@ expand_type(node, values, elements) }) return infer_table_literal(node, children) end - local is_record = is_record_type(decltype) - local is_array = is_array_type(decltype) - local force_array = nil local seen_keys = {} @@ -10812,7 +10925,7 @@ expand_type(node, values, elements) }) b = (node[i].key.tk == "true") end check_redeclared_key(node[i], node.expected_context, seen_keys, ck or n or b) - if is_record and ck then + if decltype.fields and ck then local df = decltype.fields[ck] if not df then error_at(node[i], in_context(node.expected_context, "unknown field " .. ck)) @@ -10832,7 +10945,7 @@ expand_type(node, values, elements) }) else assert_is_a(node[i], cvtype, dt, in_context(node.expected_context, "in tuple"), "at index " .. tostring(n)) end - elseif is_array and is_number_type(child.ktype) then + elseif decltype.elements and is_number_type(child.ktype) then local cv = child.vtype if cv.typename == "tuple" and i == #children and node[i].key_parsed == "implicit" then @@ -11037,7 +11150,7 @@ expand_type(node, values, elements) }) local rtype = resolve_tuple_and_nominal(resolve_typetype(children[1])) - if rtype.typeargs then + if rtype.fields and rtype.typeargs then for _, typ in ipairs(rtype.typeargs) do add_var(nil, typ.typearg, type_at(typ, a_type("typearg", { typearg = typ.typearg, @@ -11054,18 +11167,19 @@ expand_type(node, values, elements) }) local rtype = resolve_tuple_and_nominal(resolve_typetype(children[1])) - if rtype.typename == "emptytable" then - edit_type(rtype, "record") - rtype.fields = {} - rtype.field_order = {} - end - if lax and rtype.typename == "unknown" then return end - if not is_record_type(rtype) then - error_at(node, "not a module: %s", rtype) + if rtype.typename == "emptytable" then + edit_type(rtype, "record") + local r = rtype + r.fields = {} + r.field_order = {} + end + + if not rtype.fields then + error_at(node, "not a record: %s", rtype) return end @@ -11705,11 +11819,11 @@ expand_type(node, values, elements) }) local expand_interfaces do - local function add_interface_fields(what, fields, field_order, iface, orig_iface, list) - for fname, ftype in fields_of(iface, list) do + local function add_interface_fields(what, fields, field_order, resolved, named, list) + for fname, ftype in fields_of(resolved, list) do if fields[fname] then if not is_a(fields[fname], ftype) then - error_at(fields[fname], what .. " '" .. fname .. "' does not match definition in interface %s", orig_iface) + error_at(fields[fname], what .. " '" .. fname .. "' does not match definition in interface %s", named) end else table.insert(field_order, fname) @@ -11718,43 +11832,54 @@ expand_type(node, values, elements) }) end end - local function expand(t, seen) - if t.interfaces_expanded then - return t + local function collect_interfaces(list, t, seen) + if t.interface_list then + for _, iface in ipairs(t.interface_list) do + if iface.typename == "nominal" then + local ri = resolve_nominal(iface) + if not (ri.typename == "invalid") then + assert(ri.typename == "interface", "nominal resolved to " .. ri.typename) + if not ri.interfaces_expanded and not seen[ri] then + seen[ri] = true + collect_interfaces(list, ri, seen) + end + table.insert(list, iface) + end + else + if not seen[iface] then + seen[iface] = true + table.insert(list, iface) + end + end + end end - t.interfaces_expanded = true - if seen[t] then + return list + end + + expand_interfaces = function(t) + if t.interfaces_expanded then return end - seen[t] = true - - t.fields = t.fields or {} - t.meta_fields = t.meta_fields or {} - t.field_order = t.field_order or {} - t.meta_field_order = t.meta_field_order or {} + t.interfaces_expanded = true + t.interface_list = collect_interfaces({}, t, {}) for _, iface in ipairs(t.interface_list) do - local orig_iface = iface - if iface.typename == "nominal" then - iface = resolve_nominal(iface) - end - - if iface.typename == "interface" then - if iface.interface_list then - iface = expand(iface, seen) + local ri = resolve_nominal(iface) + assert(ri.typename == "interface") + add_interface_fields("field", t.fields, t.field_order, ri, iface) + add_interface_fields("metamethod", t.meta_fields, t.meta_field_order, ri, iface, "meta") + else + if not t.elements then + t.elements = iface + else + if not same_type(iface.elements, t.elements) then + error_at(t, "incompatible array interfaces") + end end - - add_interface_fields("field", t.fields, t.field_order, iface, orig_iface) - add_interface_fields("metamethod", t.meta_fields, t.meta_field_order, iface, orig_iface, "meta") end end - return t - end - - expand_interfaces = function(t) - return expand(t, {}) end end @@ -11794,7 +11919,17 @@ expand_type(node, values, elements) }) end if typ.interface_list then for j, _ in ipairs(typ.interface_list) do - typ.interface_list[j] = children[i] + local iface = children[i] + if iface.typename == "array" then + typ.interface_list[j] = iface + elseif iface.typename == "nominal" then + local ri = resolve_nominal(iface) + if ri.typename == "interface" then + typ.interface_list[j] = iface + else + error_at(children[i], "%s is not an interface", children[i]) + end + end i = i + 1 end end @@ -11899,9 +12034,9 @@ expand_type(node, values, elements) }) typ.names = nil edit_type(typ, "typevar") - assert(typ.typename == "typevar") - typ.typevar = t.typearg - typ.constraint = t.constraint + local tv = typ + tv.typevar = t.typearg + tv.constraint = t.constraint elseif t.typename == "typetype" then if t.def.typename ~= "circular_require" then typ.found = t diff --git a/tl.tl b/tl.tl index b099c2501..6fb1fce0a 100644 --- a/tl.tl +++ b/tl.tl @@ -1069,7 +1069,7 @@ local table_types : {TypeName:boolean} = { } local interface Type - where self.typename ~= nil + where self.typename y: integer x: integer @@ -1082,53 +1082,26 @@ local interface Type inferred_at: Where - is_total: boolean - missing: {string} - -- Lua compatibilty needs_compat: boolean -- arguments: optional arity opt: boolean - declname: string - -- typetype def: Type closed: boolean is_abstract: boolean -- typealias - alias_to: Type + alias_to: NominalType is_nested_alias: boolean - -- records - interface_list: {Type} - interfaces_expanded: boolean - typeargs: {TypeArgType} - fields: {string: Type} - field_order: {string} - meta_fields: {string: Type} - meta_field_order: {string} - is_userdata: boolean - - -- array - elements: Type - consttypes: {Type} - -- tupletable/array - inferred_len: integer - typeid: integer -- function argument is_self: boolean - -- nominal - names: {string} - typevals: {Type} - found: Type -- type is found but typeargs are not resolved - resolved: Type -- type is found and typeargs are resolved - -- table items kname: string ktype: Type @@ -1141,6 +1114,68 @@ local interface Type narrows: {string:boolean} end +local interface HasTypeArgs + where self.typeargs + + typeargs: {TypeArgType} +end + +local interface HasDeclName + declname: string +end + +local interface HasIsTotal + is_total: boolean + missing: {string} +end + +local record NominalType + is Type + where self.typename == "nominal" + + names: {string} + typevals: {Type} + found: Type -- type is found but typeargs are not resolved + resolved: Type -- type is found and typeargs are resolved +end + +local interface ArrayLikeType + is Type + where self.elements + + elements: Type + consttypes: {Type} + inferred_len: integer +end + +local interface RecordLikeType + is Type, HasTypeArgs, HasDeclName, ArrayLikeType + where self.fields + + interface_list: {ArrayType | NominalType} + interfaces_expanded: boolean + fields: {string: Type} + field_order: {string} + meta_fields: {string: Type} + meta_field_order: {string} + is_userdata: boolean +end + +local record ArrayType + is ArrayLikeType + where self.typename == "array" +end + +local record RecordType + is RecordLikeType, HasIsTotal + where self.typename == "record" +end + +local record InterfaceType + is RecordLikeType + where self.typename == "interface" +end + local record InvalidType is Type where self.typename == "invalid" @@ -1187,7 +1222,7 @@ local record TypeVarType end local record MapType - is Type -- TODO TotalType + is Type, HasIsTotal where self.typename == "map" keys: Type @@ -1211,7 +1246,7 @@ local record UnresolvedEmptyTableValueType end local record FunctionType - is Type + is Type, HasTypeArgs where self.typename == "function" is_method: boolean @@ -1245,7 +1280,10 @@ local record PolyType types: {FunctionType} end -local record EnumType is Type where self.typename == "enum" +local record EnumType + is Type, HasDeclName + where self.typename == "enum" + enumset: {string:boolean} end @@ -1534,15 +1572,6 @@ local type Where = Node | Type -local function is_array_type(t:Type): boolean - -- checking array interface - return t.typename == "array" or t.elements ~= nil -end - -local function is_record_type(t:Type): boolean - return t.typename == "record" or t.typename == "interface" -end - local function is_number_type(t:Type): boolean return t.typename == "number" or t.typename == "integer" end @@ -1649,7 +1678,7 @@ end local function new_typetype(ps: ParseState, i: integer, def: Type): Type local t = new_type(ps, i, "typetype") t.def = def - if def.typename == "interface" then + if def is InterfaceType then -- ...or should this be set on traversal, to account for nominal type aliases? t.is_abstract = true end @@ -1658,7 +1687,8 @@ end local macroexp a_typetype(t: Type): Type -- FIXME set is_abstract here once standard_library defines interfaces --- if t.def.typename == "interface" then +-- local def = t.def +-- if def is InterfaceType then -- t.is_abstract = true -- end -- return t @@ -1691,8 +1721,8 @@ local function a_vararg(t: {Type}): TupleType return typ end -local macroexp an_array(t: Type): Type - return a_type("array", { elements = t }) +local macroexp an_array(t: Type): ArrayType + return a_type("array", { elements = t } as ArrayType) end local macroexp a_map(k: Type, v: Type): MapType @@ -2027,7 +2057,7 @@ local function parse_simple_type_or_nominal(ps: ParseState, i: integer): integer if st then return i + 1, st end - local typ = new_type(ps, i, "nominal") + local typ = new_type(ps, i, "nominal") as NominalType typ.names = { tk } i = i + 1 while ps.tokens[i].tk == "." do @@ -2059,7 +2089,7 @@ local function parse_base_type(ps: ParseState, i: integer): integer, Type, integ return i end if ps.tokens[i].tk == "}" then - local decl = new_type(ps, istart, "array") + local decl = new_type(ps, istart, "array") as ArrayType decl.elements = t end_at(decl as Node, ps.tokens[i]) i = verify_tk(ps, i, "}") @@ -2985,7 +3015,7 @@ local function store_field_in_record(ps: ParseState, i: integer, field_name: str return true end -local function parse_nested_type(ps: ParseState, i: integer, def: Type, typename: TypeName, parse_body: ParseBody): integer, boolean +local function parse_nested_type(ps: ParseState, i: integer, def: RecordLikeType, typename: TypeName, parse_body: ParseBody): integer, boolean i = i + 1 -- skip 'record' or 'enum' local iv = i @@ -3070,11 +3100,14 @@ end local function parse_where_clause(ps: ParseState, i: integer): integer, Node local node = new_node(ps.tokens, i, "macroexp") + + local selftype = new_type(ps, i, "nominal") as NominalType + selftype.names = { "@self" } + node.args = new_node(ps.tokens, i, "argument_list") node.args[1] = new_node(ps.tokens, i, "argument") node.args[1].tk = "self" - node.args[1].argtype = new_type(ps, i, "nominal") - node.args[1].argtype.names = { "@self" } + node.args[1].argtype = selftype node.rets = new_tuple(ps, i) node.rets.tuple[1] = BOOLEAN i, node.exp = parse_expression(ps, i) @@ -3086,22 +3119,25 @@ parse_interface_name = function(ps: ParseState, i: integer): integer, Type, inte local istart = i local typ: Type i, typ = parse_simple_type_or_nominal(ps, i) - if typ.typename ~= "nominal" then + if not typ is NominalType then return fail(ps, istart, "expected an interface") end return i, typ end -local function parse_array_interface_type(ps: ParseState, i: integer, def: Type): integer, Type - if def.interface_list and def.interface_list[1].typename == "array" then - return failskip(ps, i, "duplicated declaration of array element type", parse_type as SkipFunction) +local function parse_array_interface_type(ps: ParseState, i: integer, def: RecordLikeType): integer, Type + if def.interface_list then + local first = def.interface_list[1] + if first is ArrayType then + return failskip(ps, i, "duplicated declaration of array element type", parse_type as SkipFunction) + end end local t: Type i, t = parse_base_type(ps, i) if not t then return i end - if t.typename ~= "array" then + if not t is ArrayType then fail(ps, i, "expected an array declaration") return i end @@ -3109,7 +3145,7 @@ local function parse_array_interface_type(ps: ParseState, i: integer, def: Type) return i, t end -parse_record_body = function(ps: ParseState, i: integer, def: Type, node: Node): integer, Node +parse_record_body = function(ps: ParseState, i: integer, def: RecordLikeType, node: Node): integer, Node local istart = i - 1 def.fields = {} def.field_order = {} @@ -3152,16 +3188,22 @@ parse_record_body = function(ps: ParseState, i: integer, def: Type, node: Node): local where_macroexp: Node i, where_macroexp = parse_where_clause(ps, i) - def.meta_fields = {} - def.meta_field_order = {} - local typ = new_type(ps, wstart, "function") as FunctionType typ.is_method = true - typ.args = a_tuple { a_type("nominal", { y = typ.y, x = typ.x, filename = ps.filename, names = { "@self" } }) } + typ.args = a_tuple { + a_type("nominal", { + y = typ.y, + x = typ.x, + filename = ps.filename, + names = { "@self" } + } as NominalType) + } typ.rets = a_tuple { BOOLEAN } typ.macroexp = where_macroexp typ.is_abstract = true + def.meta_fields = {} + def.meta_field_order = {} store_field_in_record(ps, i, "__is", typ, def.meta_fields, def.meta_field_order) end @@ -3298,7 +3340,7 @@ parse_newtype = function(ps: ParseState, i: integer): integer, Node return fail(ps, i, "expected a type") end - if def.typename == "nominal" then + if def is NominalType then node.newtype = new_type(ps, itype, "typealias") node.newtype.alias_to = def else @@ -3440,8 +3482,11 @@ local function parse_type_declaration(ps: ParseState, i: integer, node_name: Nod local nt = asgn.value.newtype if nt.typename == "typetype" then - if not nt.def.declname then - nt.def.declname = asgn.var.tk + local def = nt.def + if def is RecordLikeType or def is EnumType then + if not def.declname then + def.declname = asgn.var.tk + end end end @@ -3461,6 +3506,8 @@ local function parse_type_constructor(ps: ParseState, i: integer, node_name: Nod if not asgn.var then return fail(ps, i, "expected a type name") end + + assert(def is RecordType or def is InterfaceType or def is EnumType) def.declname = asgn.var.tk i = parse_body(ps, i, def, nt) @@ -3677,7 +3724,7 @@ local enum MetaMode "meta" end -local function fields_of(t: Type, meta?: MetaMode): (function(): string, Type) +local function fields_of(t: RecordLikeType, meta?: MetaMode): (function(): string, Type) local i = 1 local field_order, fields: {string}, {string:Type} if meta then @@ -3776,51 +3823,47 @@ local function recurse_type(ast: Type, visit: Visitor): T local xs: {T} = {} - if ast.typeargs then - for _, child in ipairs(ast.typeargs) do - table.insert(xs, recurse_type(child, visit)) - end - end - if ast is TupleType then for i, child in ipairs(ast.tuple) do xs[i] = recurse_type(child, visit) end - end - if ast is AggregateType then + elseif ast is AggregateType then for _, child in ipairs(ast.types) do table.insert(xs, recurse_type(child, visit)) end - end - if ast.interface_list then - for _, child in ipairs(ast.interface_list) do - table.insert(xs, recurse_type(child, visit)) - end - end - if ast.def then - table.insert(xs, recurse_type(ast.def, visit)) - end - if ast.alias_to then - table.insert(xs, recurse_type(ast.alias_to, visit)) - end - if ast is MapType then + elseif ast is MapType then table.insert(xs, recurse_type(ast.keys, visit)) table.insert(xs, recurse_type(ast.values, visit)) - end - if ast.elements then - table.insert(xs, recurse_type(ast.elements, visit)) - end - if ast.fields then - for _, child in fields_of(ast) do - table.insert(xs, recurse_type(child, visit)) + elseif ast is RecordLikeType then + if ast.typeargs then + for _, child in ipairs(ast.typeargs) do + table.insert(xs, recurse_type(child, visit)) + end end - end - if ast.meta_fields then - for _, child in fields_of(ast, "meta") do - table.insert(xs, recurse_type(child, visit)) + if ast.interface_list then + for _, child in ipairs(ast.interface_list) do + table.insert(xs, recurse_type(child, visit)) + end + end + if ast.elements then + table.insert(xs, recurse_type(ast.elements, visit)) + end + if ast.fields then + for _, child in fields_of(ast) do + table.insert(xs, recurse_type(child, visit)) + end + end + if ast.meta_fields then + for _, child in fields_of(ast, "meta") do + table.insert(xs, recurse_type(child, visit)) + end + end + elseif ast is FunctionType then + if ast.typeargs then + for _, child in ipairs(ast.typeargs) do + table.insert(xs, recurse_type(child, visit)) + end end - end - if ast is FunctionType then if ast.args then for i, child in ipairs(ast.args.tuple) do if i > 1 or not ast.is_method or child.is_self then @@ -3833,22 +3876,33 @@ local function recurse_type(ast: Type, visit: Visitor): T table.insert(xs, recurse_type(child, visit)) end end - end - if ast.typevals then - for _, child in ipairs(ast.typevals) do - table.insert(xs, recurse_type(child, visit)) + elseif ast is NominalType then + if ast.typevals then + for _, child in ipairs(ast.typevals) do + table.insert(xs, recurse_type(child, visit)) + end end - end - if ast.ktype then - table.insert(xs, recurse_type(ast.ktype, visit)) - end - if ast.vtype then - table.insert(xs, recurse_type(ast.vtype, visit)) - end - if ast is TypeArgType then + elseif ast is TypeArgType then if ast.constraint then table.insert(xs, recurse_type(ast.constraint, visit)) end + elseif ast is ArrayType then + if ast.elements then + table.insert(xs, recurse_type(ast.elements, visit)) + end + else + if ast.def then + table.insert(xs, recurse_type(ast.def, visit)) + end + if ast.alias_to then + table.insert(xs, recurse_type(ast.alias_to, visit)) + end + if ast.ktype then + table.insert(xs, recurse_type(ast.ktype, visit)) + end + if ast.vtype then + table.insert(xs, recurse_type(ast.vtype, visit)) + end end local ret: T @@ -4313,10 +4367,11 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | return table.concat(out) end - local function print_record_def(typ: Type): string + local function print_record_def(typ: RecordLikeType): string local out: {string} = { "{" } for _, name in ipairs(typ.field_order) do - if typ.fields[name].typename == "typetype" and is_record_type(typ.fields[name].def) then + local def = typ.fields[name].def + if typ.fields[name].typename == "typetype" and def is RecordLikeType then table.insert(out, name) table.insert(out, " = ") table.insert(out, print_record_def(typ.fields[name].def)) @@ -4763,10 +4818,13 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | local nt = node.newtype if nt.typename == "typealias" then table.insert(out, table.concat(nt.alias_to.names, ".")) - elseif is_record_type(node.newtype.def) then - table.insert(out, print_record_def(node.newtype.def)) - else - table.insert(out, "{}") + elseif nt.typename == "typetype" then + local def = nt.def + if def is RecordLikeType then + table.insert(out, print_record_def(def)) + else + table.insert(out, "{}") + end end return out end, @@ -4863,10 +4921,11 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | local default_type_visitor = { after = function(typ: Type, _children: {Output}): Output local out: Output = { y = typ.y or -1, h = 0 } - local r = typ.resolved or typ - local lua_type = primitive[r.typename] - or (r.is_userdata and "userdata") - or "table" + local r = typ is NominalType and typ.resolved or typ + local lua_type = primitive[r.typename] or "table" + if r is RecordLikeType and r.is_userdata then + lua_type = "userdata" + end table.insert(out, lua_type) return out end, @@ -5047,15 +5106,17 @@ get_typenum = function(trenv: TypeReportEnv, t: Type): integer trenv.typeid_to_num[t.typeid] = n trenv.next_num = trenv.next_num + 1 - if t.found then - ti.ref = get_typenum(trenv, t.found) - end - if t.resolved then - rt = t + if t is NominalType then + if t.found then + ti.ref = get_typenum(trenv, t.found) + end + if t.resolved then + rt = t + end end assert(not (rt.typename == "typetype" or rt.typename == "typealias")) - if is_record_type(rt) then + if rt is RecordLikeType then -- store record field info local r = {} for _, k in ipairs(rt.field_order) do @@ -5065,7 +5126,7 @@ get_typenum = function(trenv: TypeReportEnv, t: Type): integer ti.fields = r end - if is_array_type(rt) then + if rt is ArrayLikeType then ti.elements = get_typenum(trenv, rt.elements) end @@ -5121,7 +5182,7 @@ local CIRCULAR_REQUIRE = a_type("circular_require", {}) local FUNCTION = a_function { args = a_vararg { ANY }, rets = a_vararg { ANY } } -local NOMINAL_FILE = a_type("nominal", { names = {"FILE"} }) +local NOMINAL_FILE = a_type("nominal", { names = {"FILE"} } as NominalType) local XPCALL_MSGH_FUNCTION = a_function { args = a_tuple { ANY }, rets = a_tuple { } } local USERDATA = ANY -- Placeholder for maybe having a userdata "primitive" type @@ -5368,7 +5429,7 @@ local function display_typevar(typevar: string): string return TL_DEBUG and typevar or (typevar:gsub("@.*", "")) end -local function show_fields(t: Type, show: function(Type):(string)): string +local function show_fields(t: RecordLikeType, show: function(Type):(string)): string if t.declname then return " " .. t.declname end @@ -5408,7 +5469,7 @@ local function show_type_base(t: Type, short: boolean, seen: {Type:string}): str return show_type(typ, short, seen) end - if t.typename == "nominal" then + if t is NominalType then if #t.names == 1 and t.names[1] == "@self" then return "self" end @@ -5457,14 +5518,12 @@ local function show_type_base(t: Type, short: boolean, seen: {Type:string}): str return "{}" elseif t is MapType then return "{" .. show(t.keys) .. " : " .. show(t.values) .. "}" - elseif t.typename == "array" then + elseif t is ArrayType then return "{" .. show(t.elements) .. "}" - elseif t.typename == "enum" then + elseif t is EnumType then return t.declname or "enum" - elseif t.typename == "interface" then - return short and "interface" or "interface" .. show_fields(t, show) - elseif is_record_type(t) then - return short and "record" or "record" .. show_fields(t, show) + elseif t is RecordLikeType then + return short and t.typename or t.typename .. show_fields(t, show) elseif t is FunctionType then local out: {string} = {"function"} if t.typeargs then @@ -5783,8 +5842,9 @@ local function init_globals(lax: boolean): {string:Variable}, {string:Type} last_typeid = globals_typeid end - local function a_record(t: Type): Type + local function a_record(t: RecordType): Type t = a_type("record", t) + assert(t.fields) t.field_order = sorted_keys(t.fields) return t end @@ -5805,7 +5865,7 @@ local function init_globals(lax: boolean): {string:Variable}, {string:Type} end local function a_grecord(n: integer, f: function(...: Type): Type): Type - local t = a_gfunction(n, f, "record") as Type -- FIXME + local t = a_gfunction(n, f, "record") as RecordType -- FIXME t.field_order = sorted_keys(t.fields) return t end @@ -5888,10 +5948,10 @@ local function init_globals(lax: boolean): {string:Variable}, {string:Type} local TABLE_SORT_FUNCTION = a_gfunction(1, function(a: Type): FunctionType return { args = a_tuple { a, a }, rets = a_tuple { BOOLEAN } } end) - local metatable_nominals = {} + local metatable_nominals: {NominalType} = {} local function METATABLE(a: Type): Type - local t = a_type("nominal", { names = {"metatable"}, typevals = { a } }) + local t = a_type("nominal", { names = {"metatable"}, typevals = { a } } as NominalType) table.insert(metatable_nominals, t) return t end @@ -5972,7 +6032,7 @@ local function init_globals(lax: boolean): {string:Variable}, {string:Type} }, }, ["metatable"] = a_typetype { - def = a_grecord(1, function(a: Type): Type return { + def = a_grecord(1, function(a: Type): RecordType return { fields = { ["__call"] = a_function { args = a_vararg { a, ANY }, rets = a_vararg { ANY } }, ["__gc"] = a_function { args = a_tuple { a }, rets = a_tuple {} }, @@ -6299,7 +6359,7 @@ tl.init_env = function(lax?: boolean, gen_compat?: boolean | CompatMode, gen_tar -- make standard library tables available as modules for require() for name, var in pairs(standard_library) do - if var.typename == "record" then + if var is RecordType then env.modules[name] = var end end @@ -6387,7 +6447,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - local function simulate_g(): Type, Attribute + local function simulate_g(): RecordType, Attribute -- this is a static approximation of _G local globals: {string:Type} = {} for k, v in pairs(st[1]) do @@ -6420,8 +6480,8 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string } as TypeArgType) end - local function ensure_fresh_typeargs(t: Type): Type - if not t.typeargs then + local function ensure_fresh_typeargs(t: T): T + if not t is HasTypeArgs then return t end @@ -6502,23 +6562,28 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if not typ then return nil end - if typ.found then + if typ is NominalType and typ.found then typ = typ.found end for i = 2, #names do - local fields = typ.fields or (typ.def and typ.def.fields) - if fields then - typ = fields[names[i]] - if typ == nil then - return nil - end - typ = ensure_fresh_typeargs(typ) - if typ.found then - typ = typ.found - end - else + if typ.typename == "typetype" then + typ = typ.def + end + + local fields = typ is RecordLikeType and typ.fields + if not fields then return nil end + + typ = fields[names[i]] + if typ == nil then + return nil + end + + typ = ensure_fresh_typeargs(typ) + if typ is NominalType and typ.found then + typ = typ.found + end end if typ.typename == "typetype" or typ.typename == "typealias" then return typ @@ -6534,13 +6599,13 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return union_type(t.alias_to), t.alias_to elseif t is TupleType then return union_type(t.tuple[1]), t.tuple[1] - elseif t.typename == "nominal" then + elseif t is NominalType then local typetype = t.found or find_type(t.names) if not typetype then return "invalid" end return union_type(typetype) - elseif t.typename == "record" then + elseif t is RecordLikeType then if t.is_userdata then return "userdata", t end @@ -6565,6 +6630,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string for _, t in ipairs(typ.types) do local ut, rt = union_type(t) if ut == "userdata" then -- must be tested before table_types + assert(rt is RecordLikeType) if rt.meta_fields and rt.meta_fields["__is"] then n_userdata_is_types = n_userdata_is_types + 1 if n_userdata_types > 0 then @@ -6580,7 +6646,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end elseif ut == "table" then - if rt.meta_fields and rt.meta_fields["__is"] then + if rt is RecordLikeType and rt.meta_fields and rt.meta_fields["__is"] then n_table_is_types = n_table_is_types + 1 if n_table_types > 0 then return false, "cannot mix table types with and without __is metamethod: %s" @@ -6696,11 +6762,11 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string fn_var = fn_var or default_resolve_typevars_callback - local function resolve(t: Type, all_same: boolean): Type, boolean + local function resolve(t: T, all_same: boolean): T, boolean local same = true -- avoid copies of types that do not contain type variables - if no_nested_types[t.typename] or (t.typename == "nominal" and not t.typevals) then + if no_nested_types[t.typename] or (t is NominalType and not t.typevals) then return t, all_same end @@ -6713,7 +6779,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local rt = fn_var(t) if rt then resolved[t.typevar] = true - if no_nested_types[rt.typename] or (rt.typename == "nominal" and not rt.typevals) then + if no_nested_types[rt.typename] or (rt is NominalType and not rt.typevals) then seen[orig_t] = rt return rt, false end @@ -6726,7 +6792,6 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string seen[orig_t] = copy copy.opt = t.opt - copy.is_userdata = t.is_userdata copy.is_abstract = t.is_abstract copy.typename = t.typename copy.filename = t.filename @@ -6734,9 +6799,10 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string copy.y = t.y copy.yend = t.yend copy.xend = t.xend - copy.declname = t.declname -- which types have this, exactly? - if t.typename == "array" then + if t is ArrayType then + assert(copy is ArrayType) + copy.elements, same = resolve(t.elements, same) -- inferred_len is not propagated elseif t is TypeArgType then @@ -6763,7 +6829,8 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string elseif t.typename == "typealias" then copy.alias_to, same = resolve(t.alias_to, same) copy.is_nested_alias = t.is_nested_alias - elseif t.typename == "nominal" then + elseif t is NominalType then + assert(copy is NominalType) copy.names = t.names copy.typevals = {} for i, tf in ipairs(t.typevals) do @@ -6771,6 +6838,8 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end copy.found = t.found elseif t is FunctionType then + assert(copy is FunctionType) + if t.typeargs then copy.typeargs = {} for i, tf in ipairs(t.typeargs) do @@ -6779,12 +6848,14 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end set_min_arity(t) - assert(copy is FunctionType) copy.min_arity = t.min_arity copy.is_method = t.is_method copy.args, same = resolve(t.args, same) as (TupleType, boolean) copy.rets, same = resolve(t.rets, same) as (TupleType, boolean) - elseif is_record_type(t) then + elseif t is RecordLikeType then + assert(copy is RecordType or copy is InterfaceType) + copy.declname = t.declname + if t.typeargs then copy.typeargs = {} for i, tf in ipairs(t.typeargs) do @@ -6797,6 +6868,8 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string copy.elements, same = resolve(t.elements, same) end + copy.is_userdata = t.is_userdata + copy.fields = {} copy.field_order = {} for i, k in ipairs(t.field_order) do @@ -6853,7 +6926,11 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if errs then return false, INVALID, errs end - if copy.typeargs and not same then + + if (not same) and + (copy is FunctionType or copy is RecordLikeType) and + copy.typeargs + then for i = #copy.typeargs, 1, -1 do if resolved[copy.typeargs[i].typearg] then table.remove(copy.typeargs, i) @@ -6994,6 +7071,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string assert(where.y) add_errs_prefixing(where, errs, errors, "") end + if ret == t or t.typename == "typevar" then ret = shallow_copy_table(ret) end @@ -7005,6 +7083,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if ret.typename == "invalid" then ret = t -- errors are produced by resolve_typevars_at end + if ret == t or t.typename == "typevar" then ret = shallow_copy_table(ret) end @@ -7167,12 +7246,13 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - local function close_nested_records(t: Type) + local function close_nested_records(t: RecordLikeType) for _, ft in pairs(t.fields) do if ft.typename == "typetype" then ft.closed = true - if is_record_type(ft.def) then - close_nested_records(ft.def) + local def = ft.def + if def is RecordLikeType then + close_nested_records(def) end end end @@ -7183,8 +7263,9 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local t = var.t if t.typename == "typetype" then t.closed = true - if is_record_type(t.def) then - close_nested_records(t.def) + local def = t.def + if def is RecordLikeType then + close_nested_records(def) end end end @@ -7307,7 +7388,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local resolve_nominal: function(t: Type): Type local resolve_typealias: function(t: Type): Type, Variable do - local function match_typevals(t: Type, def: Type): Type + local function match_typevals(t: NominalType, def: RecordLikeType | FunctionType): Type if t.typevals and def.typeargs then if #t.typevals ~= #def.typeargs then error_at(t, "mismatch in number of type arguments") @@ -7332,7 +7413,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - resolve_nominal = function(t: Type): Type + resolve_nominal = function(t: NominalType): Type if t.resolved then return t.resolved end @@ -7350,18 +7431,21 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end if typetype.typename == "typetype" then - if typetype.def.typename == "circular_require" then + local def = typetype.def + if def.typename == "circular_require" then -- return, but do not store resolution return typetype.def end - if typetype.def.typename == "nominal" then - typetype = typetype.def.found + -- FIXME is this block still needed? + if def is NominalType then + typetype = def.found assert(typetype.typename == "typetype") + def = typetype.def end - assert(typetype.def.typename ~= "nominal") + assert(not def is NominalType) - resolved = match_typevals(t, typetype.def) + resolved = match_typevals(t, def) else error_at(t, table.concat(t.names, ".") .. " is not a type") return INVALID @@ -7415,9 +7499,8 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - local function are_same_unresolved_global_type(t1: Type, t2: Type): boolean - if t1.names[1] == t2.names[1] - then + local function are_same_unresolved_global_type(t1: NominalType, t2: NominalType): boolean + if t1.names[1] == t2.names[1] then local unresolved = get_unresolved() if unresolved.global_types[t1.names[1]] then return true @@ -7442,7 +7525,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return false, { Err(t1, t1name .. " is not a " .. t2name) } end - local function are_same_nominals(t1: Type, t2: Type): boolean, {Error} + local function are_same_nominals(t1: NominalType, t2: NominalType): boolean, {Error} local same_names: boolean if t1.found and t2.found then same_names = t1.found.typeid == t2.found.typeid @@ -7520,7 +7603,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end else local typeid = t.typeid - if t.typename == "nominal" then + if t is NominalType then typeid = resolve_nominal(t).typeid end if not types_seen[typeid] then @@ -7553,12 +7636,13 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string -- Is the type represented concretely as a Lua table? is_lua_table_type = function(t: Type): boolean - return known_table_types[t.typename] and not t.is_userdata + return known_table_types[t.typename] + and not (t is RecordLikeType and t.is_userdata) end end local expand_type: function(where: Where, old: Type, new: Type): Type - local function arraytype_from_tuple(where: Where, tupletype: TupleTableType): Type, {Error} + local function arraytype_from_tuple(where: Where, tupletype: TupleTableType): ArrayType, {Error} -- first just try a basic union local element_type = unite(tupletype.types, true) local valid = (not element_type is UnionType) and true or is_valid_union(element_type) @@ -7569,16 +7653,17 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string -- failing a basic union, expand the types local arr_type = an_array(tupletype.types[1]) for i = 2, #tupletype.types do - arr_type = expand_type(where, arr_type, an_array(tupletype.types[i])) - if not arr_type.elements then + local expanded = expand_type(where, arr_type, an_array(tupletype.types[i])) + if not expanded is ArrayType then return nil, { Err(tupletype, "unable to convert tuple %s to array", tupletype) } end + arr_type = expanded end return arr_type end local function is_self(t: Type): boolean - return t.typename == "nominal" and t.names[1] == "@self" + return t is NominalType and t.names[1] == "@self" end local function compare_true(_: Type, _: Type): boolean, {Error} @@ -7590,8 +7675,8 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return true end - local ra = a.typename == "nominal" and resolve_nominal(a) or a - local rb = b.typename == "nominal" and resolve_nominal(b) or b + local ra = a is NominalType and resolve_nominal(a) or a + local rb = b is NominalType and resolve_nominal(b) or b local ok, errs = is_a(ra, rb) if errs and #errs == 1 and errs[1].msg:match("^got ") then return false -- translate to got-expected error with unresolved types @@ -7599,8 +7684,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return ok, errs end - local function subtype_array(a: Type, b: Type): boolean, {Error} - -- assert(b.typename == "array") + local function subtype_array(a: ArrayLikeType, b: ArrayLikeType): boolean, {Error} if (not a.elements) or (not is_a(a.elements, b.elements)) then return false end @@ -7615,7 +7699,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return true end - local function find_in_interface_list(a: Type, f: function(Type): T): T + local function find_in_interface_list(a: RecordLikeType, f: function(Type): T): T if not a.interface_list then return nil end @@ -7630,7 +7714,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return nil end - local function subtype_record(a: Type, b: Type): boolean, {Error} + local function subtype_record(a: RecordLikeType, b: RecordLikeType): boolean, {Error} -- assert(b.typename == "record") if a.elements and b.elements then if not is_a(a.elements, b.elements) then @@ -7664,7 +7748,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return true end - local eqtype_record = function(a: Type, b: Type): boolean, {Error} + local eqtype_record = function(a: RecordType, b: RecordType): boolean, {Error} -- checking array interface if (a.elements ~= nil) ~= (b.elements ~= nil) then return false, { Err(a, "types do not have the same array interface") } @@ -7771,10 +7855,10 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string ["array"] = compare_true, ["map"] = compare_true, ["tupletable"] = compare_true, - ["interface"] = function(_a: Type, b: Type): boolean, {Error} + ["interface"] = function(_a: Type, b: InterfaceType): boolean, {Error} return not b.is_userdata end, - ["record"] = function(_a: Type, b: Type): boolean, {Error} + ["record"] = function(_a: Type, b: RecordType): boolean, {Error} return not b.is_userdata end, } @@ -7810,7 +7894,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end, }, ["array"] = { - ["array"] = function(a: Type, b: Type): boolean, {Error} + ["array"] = function(a: ArrayType, b: ArrayType): boolean, {Error} return same_type(a.elements, b.elements) end, }, @@ -7933,14 +8017,14 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end, }, ["nominal"] = { - ["nominal"] = function(a: Type, b: Type): boolean, {Error} + ["nominal"] = function(a: NominalType, b: NominalType): boolean, {Error} local ok, errs = are_same_nominals(a, b) if ok then return true end local rb = resolve_nominal(b) - if rb.typename == "interface" then + if rb is InterfaceType then -- match interface subtyping return is_a(a, rb) end @@ -8003,12 +8087,12 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end return true end, - ["record"] = function(a: Type, b: Type): boolean, {Error} + ["record"] = function(a: Type, b: RecordType): boolean, {Error} if b.elements then return subtype_relations["tupletable"]["array"](a, b) end end, - ["array"] = function(a: TupleTableType, b: Type): boolean, {Error} + ["array"] = function(a: TupleTableType, b: ArrayType): boolean, {Error} if b.inferred_len and b.inferred_len > #a.types then return false, { Err(a, "incompatible length, expected maximum length of " .. tostring(#a.types) .. ", got " .. tostring(b.inferred_len)) } end @@ -8032,7 +8116,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string }, ["record"] = { ["record"] = subtype_record, - ["interface"] = function(a: Type, b: Type): boolean, {Error} + ["interface"] = function(a: RecordType, b: InterfaceType): boolean, {Error} if find_in_interface_list(a, function(t: Type): boolean return (is_a(t, b)) end) then return true end @@ -8042,7 +8126,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end, ["array"] = subtype_array, - ["map"] = function(a: Type, b: MapType): boolean, {Error} + ["map"] = function(a: RecordType, b: MapType): boolean, {Error} if not is_a(b.keys, STRING) then return false, { Err(a, "can't match a record to a map with non-string keys") } end @@ -8059,7 +8143,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return true end, - ["tupletable"] = function(a: Type, b: Type): boolean, {Error} + ["tupletable"] = function(a: RecordType, b: Type): boolean, {Error} if a.elements then return subtype_relations["array"]["tupletable"](a, b) end @@ -8067,15 +8151,15 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string }, ["array"] = { ["array"] = subtype_array, - ["record"] = function(a: Type, b: Type): boolean, {Error} + ["record"] = function(a: ArrayType, b: RecordType): boolean, {Error} if b.elements then return subtype_array(a, b) end end, - ["map"] = function(a: Type, b: MapType): boolean, {Error} + ["map"] = function(a: ArrayType, b: MapType): boolean, {Error} return compare_map(INTEGER, b.keys, a.elements, b.values) end, - ["tupletable"] = function(a: Type, b: TupleTableType): boolean, {Error} + ["tupletable"] = function(a: ArrayType, b: TupleTableType): boolean, {Error} local alen = a.inferred_len or 0 if alen > #b.types then return false, { Err(a, "incompatible length, expected maximum length of " .. tostring(#b.types) .. ", got " .. tostring(alen)) } @@ -8095,13 +8179,16 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string ["map"] = function(a: MapType, b: MapType): boolean, {Error} return compare_map(a.keys, b.keys, a.values, b.values) end, - ["array"] = function(a: MapType, b: Type): boolean, {Error} + ["array"] = function(a: MapType, b: ArrayType): boolean, {Error} return compare_map(a.keys, INTEGER, a.values, b.elements) end, }, ["typetype"] = { - ["record"] = function(a: Type, b: Type): boolean, {Error} - return subtype_record(a.def, b) -- record as prototype + ["record"] = function(a: Type, b: RecordType): boolean, {Error} + local def = a.def + if def is RecordLikeType then + return subtype_record(a.def, b) -- record as prototype + end end, }, ["function"] = { @@ -8306,16 +8393,18 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end local function type_is_closable(t: Type): boolean - if t.typename == "invalid" then + if t is InvalidType then return false end if same_type(t, NIL) then return true end - if t.typename ~= "function" then + if t is NominalType then t = resolve_nominal(t) end - return t.meta_fields and t.meta_fields["__close"] ~= nil + if t is RecordLikeType then + return t.meta_fields and t.meta_fields["__close"] ~= nil + end end local definitely_not_closable_exprs : {NodeKind:boolean} = { @@ -8355,13 +8444,15 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local function same_call_mt_in_all_union_entries(u: UnionType): Type return same_in_all_union_entries(u, function(t: Type): (Type, Type) t = resolve_tuple_and_nominal(t) - local call_mt = t.meta_fields and t.meta_fields["__call"] - if call_mt is FunctionType then - local args_tuple = a_tuple({}) - for i = 2, #call_mt.args.tuple do - table.insert(args_tuple.tuple, call_mt.args.tuple[i]) + if t is RecordLikeType then + local call_mt = t.meta_fields and t.meta_fields["__call"] + if call_mt is FunctionType then + local args_tuple = a_tuple({}) + for i = 2, #call_mt.args.tuple do + table.insert(args_tuple.tuple, call_mt.args.tuple[i]) + end + return args_tuple, call_mt end - return args_tuple, call_mt end end) end @@ -8383,11 +8474,12 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end -- resolve if prototype - if func.typename == "typetype" and func.def.typename == "record" then + local funcdef = func.def + if func.typename == "typetype" and funcdef is RecordType then func = func.def end -- resolve if metatable - if func.meta_fields and func.meta_fields["__call"] then + if func is RecordLikeType and func.meta_fields and func.meta_fields["__call"] then table.insert(args.tuple, 1, func) func = func.meta_fields["__call"] func = resolve_tuple_and_nominal(func) @@ -8645,7 +8737,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return resolve_typevars_at(where, f.rets) end - local function check_call(where: Where, where_args: {Node}, func: Type, args: TupleType, expected_rets: TupleType, typetype_funcall: boolean, is_method: boolean, argdelta: integer): InvalidOrTupleType, FunctionType + local function check_call(where: Where, where_args: {Node}, func: Type, args: TupleType, expected_rets: TupleType, is_typetype_funcall: boolean, is_method: boolean, argdelta: integer): InvalidOrTupleType, FunctionType assert(type(func) == "table") assert(type(args) == "table") @@ -8678,7 +8770,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if f.is_method and not is_method then if args.tuple[1] and is_a(args.tuple[1], fargs[1]) then -- a non-"@funcall" means a synthesized call, e.g. from a metamethod - if not typetype_funcall then + if not is_typetype_funcall then add_warning("hint", where, "invoked method as a regular function: consider using ':' instead of '.'") end else @@ -8735,16 +8827,15 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string begin_scope() - local typetype_funcall = not not ( - node.kind == "op" - and node.op.op == "@funcall" - and node.e1 - and node.e1.receiver - and node.e1.receiver.resolved - and node.e1.receiver.resolved.typename == "typetype" - ) + local is_typetype_funcall: boolean + if node.kind == "op" and node.op.op == "@funcall" and node.e1 and node.e1.receiver then + local receiver = node.e1.receiver + if receiver is NominalType and receiver.resolved and receiver.resolved.typename == "typetype" then + is_typetype_funcall = true + end + end - local ret, f = check_call(node, where_args, func, args, expected_rets, typetype_funcall, is_method, argdelta) + local ret, f = check_call(node, where_args, func, args, expected_rets, is_typetype_funcall, is_method, argdelta) ret = resolve_typevars_at(node, ret) end_scope() @@ -8763,17 +8854,21 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local function check_metamethod(node: Node, method_name: string, a: Type, b: Type, orig_a: Type, orig_b: Type): Type, integer if lax and ((a and is_unknown(a)) or (b and is_unknown(b))) then return UNKNOWN, nil - elseif not a.meta_fields and not (b and b.meta_fields) then + end + local ameta = a is RecordLikeType and a.meta_fields + local bmeta = b and b is RecordLikeType and b.meta_fields + + if not ameta and not bmeta then return nil, nil end local meta_on_operator = 1 local metamethod: Type if method_name ~= "__is" then - metamethod = a.meta_fields and a.meta_fields[method_name or ""] + metamethod = ameta and ameta[method_name or ""] end if (not metamethod) and b and method_name ~= "__index" then - metamethod = b.meta_fields and b.meta_fields[method_name or ""] + metamethod = bmeta and bmeta[method_name or ""] meta_on_operator = 2 end @@ -8829,7 +8924,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - if is_record_type(tbl) then + if tbl is RecordLikeType then assert(tbl.fields, "record has no fields!?") if tbl.fields[key] then @@ -8842,7 +8937,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end if rec.kind == "variable" then - if tbl.typename == "interface" then + if tbl is InterfaceType then return nil, "invalid key '" .. key .. "' in '" .. rec.tk .. "' of interface type %s" else return nil, "invalid key '" .. key .. "' in record '" .. rec.tk .. "' of type %s" @@ -9046,10 +9141,10 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string resolve_tuple_and_nominal = function(t: Type): Type t = resolve_tuple(t) - if t.typename == "nominal" then + if t is NominalType then t = resolve_nominal(t) end - assert(t.typename ~= "nominal") + assert(not t is NominalType) return t end @@ -9105,7 +9200,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return ret end - local function match_all_record_field_names(node: Node, a: Type, field_names: {string}, errmsg: string): Type + local function match_all_record_field_names(node: Node, a: RecordLikeType, field_names: {string}, errmsg: string): Type local t: Type for _, k in ipairs(field_names) do local f = a.fields[k] @@ -9155,7 +9250,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string errm = "cannot index this tuple with a variable because it would produce a union type that cannot be discriminated at runtime" end - elseif is_array_type(a) and is_a(b, INTEGER) then + elseif a is ArrayLikeType and is_a(b, INTEGER) then return a.elements elseif a is EmptyTableType then if a.keys == nil then @@ -9185,7 +9280,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end errm, erra = e, orig_a - elseif is_record_type(a) then + elseif a is RecordLikeType then if b is EnumType then local field_names: {string} = sorted_keys(b.enumset) for _, k in ipairs(field_names) do @@ -9220,7 +9315,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return new else if not is_a(new, old) then - if old is MapType and is_record_type(new) then + if old is MapType and new is RecordLikeType then if old.keys.typename == "string" then for _, ftype in fields_of(new) do old.values = expand_type(where, old.values, ftype) @@ -9229,26 +9324,32 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string else error_at(where, "cannot determine table literal type") end - elseif is_record_type(old) and is_record_type(new) then - edit_type(old, "map") - assert(old is MapType) - old.keys = STRING + elseif old is RecordLikeType and new is RecordLikeType then + local values: Type for _, ftype in fields_of(old) do - if not old.values then - old.values = ftype + if not values then + values = ftype else - old.values = expand_type(where, old.values, ftype) + values = expand_type(where, values, ftype) end end for _, ftype in fields_of(new) do - if not old.values then - old.values = ftype + if not values then + values = ftype else - old.values = expand_type(where, old.values, ftype) + values = expand_type(where, values, ftype) end end old.fields = nil old.field_order = nil + old.meta_fields = nil + old.meta_fields = nil + -- FIXME what about meta_fields + + edit_type(old, "map") + assert(old is MapType) + old.keys = STRING + old.values = values elseif old is UnionType then edit_type(old, "union") new.tk = nil @@ -9283,9 +9384,18 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if not t then return nil, nil, dname end - t = t and t.fields and t.fields[fname] + if not t is RecordLikeType then + return nil, nil, dname + end + t = t.fields[fname] + + if t.typename == "typetype" then + t = t.def + elseif t.typename == "typealias" then + t = t.alias_to.resolved + end - return t.def or t, v, dname + return t, v, dname end end @@ -9293,9 +9403,10 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string assert(t.typename == "typetype") local typevals: {Type} - if t.def.typeargs then + local def = t.def + if def is HasTypeArgs then typevals = {} - for _, a in ipairs(t.def.typeargs) do + for _, a in ipairs(def.typeargs) do table.insert(typevals, a_type("typevar", { typevar = a.typearg, constraint = a.constraint, @@ -9307,7 +9418,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string names = { name }, found = t, resolved = resolved, - })) + } as NominalType)) end local function get_self_type(exp: Node): Type @@ -9330,12 +9441,13 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return nil end - if t.typename == "nominal" then - if t.found and t.found.def and t.found.def.fields and t.found.def.fields[exp.e2.tk] then + if t is NominalType then + local def = t.found and t.found.def + if def is RecordLikeType and def.fields[exp.e2.tk] then table.insert(t.names, exp.e2.tk) - t.found = t.found.def.fields[exp.e2.tk] + t.found = def.fields[exp.e2.tk] end - else + elseif t is RecordLikeType then return t.fields and t.fields[exp.e2.tk] end return t @@ -9749,13 +9861,13 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return invalid_at(node, "pairs requires an argument") end local t = resolve_tuple_and_nominal(b.tuple[1]) - if is_array_type(t) then + if t is ArrayLikeType then add_warning("hint", node, "hint: applying pairs on an array: did you intend to apply ipairs?") end if t.typename ~= "map" then if not (lax and is_unknown(t)) then - if is_record_type(t) then + if t is RecordLikeType then match_all_record_field_names(node.e2, t, t.field_order, "attempting pairs on a record with attributes of different types") local ct = t.typename == "record" and "{string:any}" or "{any:any}" @@ -9781,7 +9893,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if not arr_type then return invalid_at(node.e2, "attempting ipairs on tuple that's not a valid array: %s", orig_t) end - elseif not is_array_type(t) then + elseif not t is ArrayLikeType then if not (lax and (is_unknown(t) or t is EmptyTableType)) then return invalid_at(node.e2, "attempting ipairs on something that's not an array: %s", orig_t) end @@ -10042,7 +10154,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string interface_list = { type_at(node, an_array(elements)) } - }) + } as RecordType) -- TODO adopt logic from is_array below when we accept tupletable as an interface elseif is_record and is_map then if keys.typename == "string" then @@ -10079,7 +10191,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string t = a_type("record", { fields = fields, field_order = field_order, - }) + } as RecordType) elseif is_map then t = a_map(keys, values) elseif is_tuple then @@ -10127,7 +10239,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end else if infertype then - if infertype.typename == "unresolvable_typearg" then + if infertype is UnresolvableTypeArgType then error_at(node.vars[i], "cannot infer declaration type; an explicit type annotation is necessary") ok = false infertype = INVALID @@ -10151,7 +10263,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string ok = false elseif not (node.exps[i] and node.exps[i].attribute == "total") then local ri = resolve_tuple_and_nominal(infertype) - if ri.typename ~= "map" and ri.typename ~= "record" then + if not (ri is MapType or ri is RecordType) then error_at(var, "attribute only applies to maps and records") ok = false elseif not ri.is_total then @@ -10159,15 +10271,15 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if ri.missing then missing = " (missing: " .. table.concat(ri.missing, ", ") .. ")" end - if ri.typename == "map" then + if ri is MapType then error_at(var, "map variable declared does not declare values for all possible keys" .. missing) ok = false - elseif ri.typename == "record" then + elseif ri is RecordType then error_at(var, "record variable declared does not declare values for all fields" .. missing) ok = false end + ri.is_total = nil end - ri.is_total = nil end end @@ -10177,8 +10289,9 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string elseif t is EmptyTableType then t.declared_at = node t.assigned_to = name + elseif t is ArrayLikeType then + t.inferred_len = nil end - t.inferred_len = nil return ok, t, infertype ~= nil end @@ -10212,7 +10325,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return is_total, missing end - local function total_record_check(t: Type, seen_keys: {CheckableKey:Where}): boolean, {string} + local function total_record_check(t: RecordLikeType, seen_keys: {CheckableKey:Where}): boolean, {string} if t.meta_field_order then return false end @@ -10395,7 +10508,10 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local where = node.exps[i] or node.exps local rt = resolve_tuple_and_nominal(t) - if rt.typename ~= "enum" and (t.typename ~= "nominal" or rt is UnionType) and not same_type(t, infertype) then + if (not rt is EnumType) + and ((not t is NominalType) or (rt is UnionType)) + and not same_type(t, infertype) + then t = infer_at(where, infertype) add_var(where, var.tk, t, "const", "narrowed_declaration") end @@ -10459,7 +10575,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string widen_all_unions() end - if varname and (rvar is UnionType or rvar.typename == "interface") then + if varname and (rvar is UnionType or rvar is InterfaceType) then -- narrow unions and interfaces add_var(varnode, varname, rval, nil, "narrow") end @@ -10729,7 +10845,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string child.value.expected = decltype.types[n as integer] end end - elseif is_array_type(decltype) then + elseif decltype is ArrayLikeType then for _, child in ipairs(node) do if child.key.constnum then child.value.expected = decltype.elements @@ -10742,7 +10858,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - if is_record_type(decltype) then + if decltype is RecordLikeType then for _, child in ipairs(node) do if child.key.conststr then child.value.expected = decltype.fields[child.key.conststr] @@ -10795,9 +10911,6 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return infer_table_literal(node, children) end - local is_record = is_record_type(decltype) - local is_array = is_array_type(decltype) - local force_array: Type = nil local seen_keys: {CheckableKey:Where} = {} @@ -10812,7 +10925,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string b = (node[i].key.tk == "true") end check_redeclared_key(node[i], node.expected_context, seen_keys, ck or n or b) - if is_record and ck then + if decltype is RecordLikeType and ck then local df = decltype.fields[ck] if not df then error_at(node[i], in_context(node.expected_context, "unknown field " .. ck)) @@ -10832,7 +10945,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string else assert_is_a(node[i], cvtype, dt, in_context(node.expected_context, "in tuple"), "at index " .. tostring(n)) end - elseif is_array and is_number_type(child.ktype) then + elseif decltype is ArrayLikeType and is_number_type(child.ktype) then local cv = child.vtype if cv is TupleType and i == #children and node[i].key_parsed == "implicit" then -- need to expand last item in an array (e.g { 1, 2, 3, f() }) @@ -10862,18 +10975,18 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string t = infer_at(node, an_array(force_array)) else t = resolve_typevars_at(node, node.expected) - if node.expected == t and t.typename == "nominal" then + if node.expected == t and t is NominalType then t = a_type("nominal", { names = t.names, found = t.found, resolved = t.resolved, - }) + } as NominalType) end end - if decltype.typename == "record" then + if decltype is RecordType then local rt = resolve_tuple_and_nominal(t) - if rt.typename == "record" then + if rt is RecordType then rt.is_total, rt.missing = total_record_check(decltype, seen_keys) end elseif decltype is MapType then @@ -11037,7 +11150,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local rtype = resolve_tuple_and_nominal(resolve_typetype(children[1])) -- add type arguments from the record implicitly - if rtype.typeargs then + if rtype is RecordLikeType and rtype.typeargs then for _, typ in ipairs(rtype.typeargs) do add_var(nil, typ.typearg, type_at(typ, a_type("typearg", { typearg = typ.typearg, @@ -11054,18 +11167,19 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local rtype = resolve_tuple_and_nominal(resolve_typetype(children[1])) - if rtype is EmptyTableType then - edit_type(rtype, "record") - rtype.fields = {} - rtype.field_order = {} - end - if lax and rtype.typename == "unknown" then return end - if not is_record_type(rtype) then - error_at(node, "not a module: %s", rtype) + if rtype is EmptyTableType then + edit_type(rtype, "record") + local r = rtype as RecordType + r.fields = {} + r.field_order = {} + end + + if not rtype is RecordLikeType then + error_at(node, "not a record: %s", rtype) return end @@ -11538,7 +11652,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - if orig_a.typename == "nominal" and orig_b.typename == "nominal" and not meta_on_operator then + if orig_a is NominalType and orig_b is NominalType and not meta_on_operator then if is_a(orig_a, orig_b) then t = resolve_tuple(orig_a) else @@ -11703,13 +11817,13 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return t end - local expand_interfaces: function(Type): Type + local expand_interfaces: function(Type) do - local function add_interface_fields(what: string, fields: {string:Type}, field_order: {string}, iface: Type, orig_iface: Type, list?: MetaMode) - for fname, ftype in fields_of(iface, list) do + local function add_interface_fields(what: string, fields: {string:Type}, field_order: {string}, resolved: RecordLikeType, named: NominalType, list?: MetaMode) + for fname, ftype in fields_of(resolved, list) do if fields[fname] then if not is_a(fields[fname], ftype) then - error_at(fields[fname], what .." '" .. fname .. "' does not match definition in interface %s", orig_iface) + error_at(fields[fname], what .." '" .. fname .. "' does not match definition in interface %s", named) end else table.insert(field_order, fname) @@ -11718,43 +11832,54 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - local function expand(t: Type, seen: {Type:boolean}): Type - if t.interfaces_expanded then - return t + local function collect_interfaces(list: {ArrayType | NominalType}, t: RecordLikeType, seen:{Type:boolean}): {ArrayType | NominalType} + if t.interface_list then + for _, iface in ipairs(t.interface_list) do + if iface is NominalType then + local ri = resolve_nominal(iface) + if not (ri.typename == "invalid") then + assert(ri is InterfaceType, "nominal resolved to " .. ri.typename) + if not ri.interfaces_expanded and not seen[ri] then + seen[ri] = true + collect_interfaces(list, ri, seen) + end + table.insert(list, iface) + end + else + if not seen[iface] then + seen[iface] = true + table.insert(list, iface) + end + end + end end - t.interfaces_expanded = true - if seen[t] then + return list + end + + expand_interfaces = function(t: RecordLikeType) + if t.interfaces_expanded then return end - seen[t] = true + t.interfaces_expanded = true - t.fields = t.fields or {} - t.meta_fields = t.meta_fields or {} - t.field_order = t.field_order or {} - t.meta_field_order = t.meta_field_order or {} + t.interface_list = collect_interfaces({}, t, {}) - -- FIXME expand and collect interface_list recursively, THEN add fields for _, iface in ipairs(t.interface_list) do - local orig_iface = iface - - if iface.typename == "nominal" then - iface = resolve_nominal(iface) - end - - if iface.typename == "interface" then - if iface.interface_list then - iface = expand(iface, seen) + if iface is NominalType then + local ri = resolve_nominal(iface) + assert(ri is InterfaceType) + add_interface_fields("field", t.fields, t.field_order, ri, iface) + add_interface_fields("metamethod", t.meta_fields, t.meta_field_order, ri, iface, "meta") + else + if not t.elements then + t.elements = iface + else + if not same_type(iface.elements, t.elements) then + error_at(t, "incompatible array interfaces") + end end - - add_interface_fields("field", t.fields, t.field_order, iface, orig_iface) - add_interface_fields("metamethod", t.meta_fields, t.meta_field_order, iface, orig_iface, "meta") end end - return t - end - - expand_interfaces = function(t: Type): Type - return expand(t, {}) end end @@ -11784,7 +11909,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end end, - after = function(typ: Type, children: {Type}): Type + after = function(typ: RecordType, children: {Type}): Type local i = 1 if typ.typeargs then for _, _ in ipairs(typ.typeargs) do @@ -11794,7 +11919,17 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end if typ.interface_list then for j, _ in ipairs(typ.interface_list) do - typ.interface_list[j] = children[i] + local iface = children[i] + if iface is ArrayType then + typ.interface_list[j] = iface + elseif iface is NominalType then + local ri = resolve_nominal(iface) + if ri is InterfaceType then + typ.interface_list[j] = iface + else + error_at(children[i], "%s is not an interface", children[i]) + end + end i = i + 1 end end @@ -11816,7 +11951,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if fargs[1] and fargs[1].is_self then local record_name = typ.declname if record_name then - local selfarg = fargs[1] + local selfarg = fargs[1] as NominalType if selfarg.tk ~= record_name or (typ.typeargs and not selfarg.typevals) then ftype.is_method = false selfarg.is_self = false @@ -11888,7 +12023,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end, }, ["nominal"] = { - after = function(typ: Type, _children: {Type}): Type + after = function(typ: NominalType, _children: {Type}): Type if typ.found then return typ end @@ -11899,9 +12034,9 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string -- convert nominal into a typevar typ.names = nil edit_type(typ, "typevar") - assert(typ is TypeVarType) - typ.typevar = t.typearg - typ.constraint = t.constraint + local tv = typ as TypeVarType + tv.typevar = t.typearg + tv.constraint = t.constraint elseif t.typename == "typetype" then if t.def.typename ~= "circular_require" then typ.found = t From 94059b9efa07b16e26d223a4c9f1380a6facf554 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Fri, 5 Jan 2024 21:33:40 -0300 Subject: [PATCH 081/224] UnresolvedType --- tl.lua | 67 +++++++++++++++++++++++++++++---------------- tl.tl | 85 ++++++++++++++++++++++++++++++++++++---------------------- 2 files changed, 97 insertions(+), 55 deletions(-) diff --git a/tl.lua b/tl.lua index 63149fa76..d2408bf15 100644 --- a/tl.lua +++ b/tl.lua @@ -1369,6 +1369,11 @@ local table_types = { + + + + + @@ -7107,6 +7112,7 @@ tl.type_check = function(ast, opts) end local get_unresolved + local find_unresolved local function add_to_scope(node, name, t, attribute, narrow, dont_check_redeclaration) local scope = st[#st] @@ -7163,7 +7169,11 @@ tl.type_check = function(ast, opts) local var = add_to_scope(node, name, t, attribute, narrow, dont_check_redeclaration) - if symbol_list and node and t.typename ~= "unresolved" and t.typename ~= "none" then + if t.typename == "unresolved" or t.typename == "none" then + return var + end + + if symbol_list and node then local slot if node.symbol_list_slot then slot = node.symbol_list_slot @@ -7329,6 +7339,13 @@ tl.type_check = function(ast, opts) return unresolved end + find_unresolved = function(level) + local u = st[level or #st]["@unresolved"] + if u then + return u.t + end + end + local function begin_scope(node) table.insert(st, {}) @@ -7342,27 +7359,29 @@ tl.type_check = function(ast, opts) local scope = st[#st] local unresolved = scope["@unresolved"] if unresolved then + local unrt = unresolved.t local next_scope = st[#st - 1] local upper = next_scope["@unresolved"] if upper then - for name, nodes in pairs(unresolved.t.labels) do + local uppert = upper.t + for name, nodes in pairs(unrt.labels) do for _, n in ipairs(nodes) do - upper.t.labels[name] = upper.t.labels[name] or {} - table.insert(upper.t.labels[name], n) + uppert.labels[name] = uppert.labels[name] or {} + table.insert(uppert.labels[name], n) end end - for name, types in pairs(unresolved.t.nominals) do + for name, types in pairs(unrt.nominals) do for _, typ in ipairs(types) do - upper.t.nominals[name] = upper.t.nominals[name] or {} - table.insert(upper.t.nominals[name], typ) + uppert.nominals[name] = uppert.nominals[name] or {} + table.insert(uppert.nominals[name], typ) end end - for name, _ in pairs(unresolved.t.global_types) do - upper.t.global_types[name] = true + for name, _ in pairs(unrt.global_types) do + uppert.global_types[name] = true end else next_scope["@unresolved"] = unresolved - unresolved.t.narrows = {} + unrt.narrows = {} end end close_types(scope) @@ -9030,9 +9049,9 @@ a.types[i], b.types[i]), } local function widen_all_unions(node) for i = #st, 1, -1 do local scope = st[i] - local unr = scope["@unresolved"] - if unr and unr.t.narrows then - for name, _ in pairs(unr.t.narrows) do + local unresolved = find_unresolved(i) + if unresolved and unresolved.narrows then + for name, _ in pairs(unresolved.narrows) do if not node or assigned_anywhere(name, node) then widen_in_scope(scope, name) end @@ -9117,13 +9136,14 @@ a.types[i], b.types[i]), } local unresolved = st[#st]["@unresolved"] if unresolved then st[#st]["@unresolved"] = nil - for name, nodes in pairs(unresolved.t.labels) do + local unrt = unresolved.t + for name, nodes in pairs(unrt.labels) do for _, node in ipairs(nodes) do error_at(node, "no visible label '" .. name .. "' for goto") end end - for name, types in pairs(unresolved.t.nominals) do - if not unresolved.t.global_types[name] then + for name, types in pairs(unrt.nominals) do + if not unrt.global_types[name] then for _, typ in ipairs(types) do assert(typ.x) assert(typ.y) @@ -9799,13 +9819,14 @@ a.types[i], b.types[i]), } local function dismiss_unresolved(name) for i = #st, 1, -1 do - local unresolved = st[i]["@unresolved"] + local unresolved = find_unresolved(i) if unresolved then - if unresolved.t.nominals[name] then - for _, t in ipairs(unresolved.t.nominals[name]) do + local uses = unresolved.nominals[name] + if uses then + for _, t in ipairs(uses) do resolve_nominal(t) end - unresolved.t.nominals[name] = nil + unresolved.nominals[name] = nil return end end @@ -10647,13 +10668,13 @@ expand_type(node, values, elements) }) if st[#st][label_id] then error_at(node, "label '" .. node.label .. "' already defined at " .. filename) end - local unresolved = st[#st]["@unresolved"] + local unresolved = find_unresolved() local var = add_var(node, label_id, type_at(node, a_type("none", {}))) if unresolved then - if unresolved.t.labels[node.label] then + if unresolved.labels[node.label] then var.used = true end - unresolved.t.labels[node.label] = nil + unresolved.labels[node.label] = nil end end, after = function() diff --git a/tl.tl b/tl.tl index 6fb1fce0a..48dc7c58d 100644 --- a/tl.tl +++ b/tl.tl @@ -1077,6 +1077,8 @@ local interface Type typename: TypeName tk: string + typeid: integer + yend: integer xend: integer @@ -1097,8 +1099,6 @@ local interface Type alias_to: NominalType is_nested_alias: boolean - typeid: integer - -- function argument is_self: boolean @@ -1107,7 +1107,12 @@ local interface Type ktype: Type vtype: Type - -- unresolved items +end + +local record UnresolvedType + is Type + where self.typename == "unresolved" + labels: {string:{Node}} nominals: {string:{Type}} global_types: {string:boolean} @@ -7106,7 +7111,8 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string node.symbol_list_slot = symbol_list_n end - local get_unresolved: function(scope?: Scope): Type + local get_unresolved: function(scope?: Scope): UnresolvedType + local find_unresolved: function(level?: integer): UnresolvedType local function add_to_scope(node: Node, name: string, t: Type, attribute: Attribute, narrow: Narrow, dont_check_redeclaration: boolean): Variable local scope = st[#st] @@ -7163,7 +7169,11 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local var = add_to_scope(node, name, t, attribute, narrow, dont_check_redeclaration) - if symbol_list and node and t.typename ~= "unresolved" and t.typename ~= "none" then + if t is UnresolvedType or t.typename == "none" then + return var + end + + if symbol_list and node then local slot: integer if node.symbol_list_slot then slot = node.symbol_list_slot @@ -7309,13 +7319,13 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - get_unresolved = function(scope?: Scope): Type - local unresolved: Type + get_unresolved = function(scope?: Scope): UnresolvedType + local unresolved: UnresolvedType if scope then local unr = scope["@unresolved"] - unresolved = unr and unr.t + unresolved = unr and unr.t as UnresolvedType else - unresolved = find_var_type("@unresolved") + unresolved = find_var_type("@unresolved") as UnresolvedType end if not unresolved then unresolved = a_type("unresolved", { @@ -7323,12 +7333,19 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string nominals = {}, global_types = {}, narrows = {}, - }) + } as UnresolvedType) add_var(nil, "@unresolved", unresolved) end return unresolved end + find_unresolved = function(level?: integer): UnresolvedType + local u = st[level or #st]["@unresolved"] + if u then + return u.t as UnresolvedType + end + end + local function begin_scope(node?: Node) table.insert(st, {}) @@ -7342,27 +7359,29 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local scope = st[#st] local unresolved = scope["@unresolved"] if unresolved then + local unrt = unresolved.t as UnresolvedType local next_scope = st[#st - 1] local upper = next_scope["@unresolved"] if upper then - for name, nodes in pairs(unresolved.t.labels) do + local uppert = upper.t as UnresolvedType + for name, nodes in pairs(unrt.labels) do for _, n in ipairs(nodes) do - upper.t.labels[name] = upper.t.labels[name] or {} - table.insert(upper.t.labels[name], n) + uppert.labels[name] = uppert.labels[name] or {} + table.insert(uppert.labels[name], n) end end - for name, types in pairs(unresolved.t.nominals) do + for name, types in pairs(unrt.nominals) do for _, typ in ipairs(types) do - upper.t.nominals[name] = upper.t.nominals[name] or {} - table.insert(upper.t.nominals[name], typ) + uppert.nominals[name] = uppert.nominals[name] or {} + table.insert(uppert.nominals[name], typ) end end - for name, _ in pairs(unresolved.t.global_types) do - upper.t.global_types[name] = true + for name, _ in pairs(unrt.global_types) do + uppert.global_types[name] = true end else next_scope["@unresolved"] = unresolved - unresolved.t.narrows = {} + unrt.narrows = {} end end close_types(scope) @@ -9030,9 +9049,9 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local function widen_all_unions(node?: Node) for i = #st, 1, -1 do local scope = st[i] - local unr = scope["@unresolved"] - if unr and unr.t.narrows then - for name, _ in pairs(unr.t.narrows) do + local unresolved = find_unresolved(i) + if unresolved and unresolved.narrows then + for name, _ in pairs(unresolved.narrows) do if not node or assigned_anywhere(name, node) then widen_in_scope(scope, name) end @@ -9117,13 +9136,14 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local unresolved = st[#st]["@unresolved"] if unresolved then st[#st]["@unresolved"] = nil - for name, nodes in pairs(unresolved.t.labels) do + local unrt = unresolved.t as UnresolvedType + for name, nodes in pairs(unrt.labels) do for _, node in ipairs(nodes) do error_at(node, "no visible label '" .. name .. "' for goto") end end - for name, types in pairs(unresolved.t.nominals) do - if not unresolved.t.global_types[name] then + for name, types in pairs(unrt.nominals) do + if not unrt.global_types[name] then for _, typ in ipairs(types) do assert(typ.x) assert(typ.y) @@ -9799,13 +9819,14 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local function dismiss_unresolved(name: string) for i = #st, 1, -1 do - local unresolved = st[i]["@unresolved"] + local unresolved = find_unresolved(i) if unresolved then - if unresolved.t.nominals[name] then - for _, t in ipairs(unresolved.t.nominals[name]) do + local uses = unresolved.nominals[name] + if uses then + for _, t in ipairs(uses) do resolve_nominal(t) end - unresolved.t.nominals[name] = nil + unresolved.nominals[name] = nil return end end @@ -10647,13 +10668,13 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if st[#st][label_id] then error_at(node, "label '" .. node.label .. "' already defined at " .. filename ) end - local unresolved = st[#st]["@unresolved"] + local unresolved = find_unresolved() local var = add_var(node, label_id, type_at(node, a_type("none", {}))) if unresolved then - if unresolved.t.labels[node.label] then + if unresolved.labels[node.label] then var.used = true end - unresolved.t.labels[node.label] = nil + unresolved.labels[node.label] = nil end end, after = function(): Type From 476573bc47ed7bfb18fb7b4ef2e6b792c92afadb Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Fri, 5 Jan 2024 21:38:38 -0300 Subject: [PATCH 082/224] LiteralTableItemType --- spec/parser/parser_spec.lua | 4 +-- tl.lua | 54 +++++++++++++++-------------- tl.tl | 68 ++++++++++++++++++++----------------- 3 files changed, 67 insertions(+), 59 deletions(-) diff --git a/spec/parser/parser_spec.lua b/spec/parser/parser_spec.lua index 8cbf40495..d1e66fb38 100644 --- a/spec/parser/parser_spec.lua +++ b/spec/parser/parser_spec.lua @@ -60,8 +60,8 @@ describe("parser", function() assert.same({}, result.syntax_errors) assert.same("statements", result.ast.kind) assert.same("return", result.ast[1].kind) - assert.same("table_literal", result.ast[1].exps[1].kind) - assert.same("table_item", result.ast[1].exps[1][1].kind) + assert.same("literal_table", result.ast[1].exps[1].kind) + assert.same("literal_table_item", result.ast[1].exps[1][1].kind) assert.same("implicit", result.ast[1].exps[1][1].key_parsed) end) diff --git a/tl.lua b/tl.lua index d2408bf15..19cc49ca0 100644 --- a/tl.lua +++ b/tl.lua @@ -1053,7 +1053,7 @@ local table_types = { ["integer"] = false, ["union"] = false, ["nominal"] = false, - ["table_item"] = false, + ["literal_table_item"] = false, ["unresolved_emptytable_value"] = false, ["unresolved_typearg"] = false, ["unresolvable_typearg"] = false, @@ -1373,6 +1373,12 @@ local table_types = { + + + + + + @@ -1816,7 +1822,7 @@ local function parse_table_value(ps, i) end local function parse_table_item(ps, i, n) - local node = new_node(ps.tokens, i, "table_item") + local node = new_node(ps.tokens, i, "literal_table_item") if ps.tokens[i].kind == "$EOF$" then return fail(ps, i, "unexpected eof") end @@ -1943,7 +1949,7 @@ local function parse_bracket_list(ps, i, list, open, close, sep, parse_item) end local function parse_table_literal(ps, i) - local node = new_node(ps.tokens, i, "table_literal") + local node = new_node(ps.tokens, i, "literal_table") return parse_bracket_list(ps, i, node, "{", "}", "term", parse_table_item) end @@ -3895,6 +3901,13 @@ local function recurse_type(ast, visit) if ast.elements then table.insert(xs, recurse_type(ast.elements, visit)) end + elseif ast.typename == "literal_table_item" then + if ast.ktype then + table.insert(xs, recurse_type(ast.ktype, visit)) + end + if ast.vtype then + table.insert(xs, recurse_type(ast.vtype, visit)) + end else if ast.def then table.insert(xs, recurse_type(ast.def, visit)) @@ -3902,12 +3915,6 @@ local function recurse_type(ast, visit) if ast.alias_to then table.insert(xs, recurse_type(ast.alias_to, visit)) end - if ast.ktype then - table.insert(xs, recurse_type(ast.ktype, visit)) - end - if ast.vtype then - table.insert(xs, recurse_type(ast.vtype, visit)) - end end local ret @@ -4027,11 +4034,11 @@ local function recurse_node(root, ["statements"] = walk_children, ["argument_list"] = walk_children, - ["table_literal"] = walk_children, + ["literal_table"] = walk_children, ["variable_list"] = walk_children, ["expression_list"] = walk_children, - ["table_item"] = function(ast, xs) + ["literal_table_item"] = function(ast, xs) xs[1] = recurse(ast.key) xs[2] = recurse(ast.value) if ast.itemtype then @@ -4628,7 +4635,7 @@ function tl.pretty_print_ast(ast, gen_target, mode) return out end, }, - ["table_literal"] = { + ["literal_table"] = { before = increment_indent, after = function(node, children) local out = { y = node.y, h = 0 } @@ -4649,7 +4656,7 @@ function tl.pretty_print_ast(ast, gen_target, mode) return out end, }, - ["table_item"] = { + ["literal_table_item"] = { after = function(node, children) local out = { y = node.y, h = 0 } if node.key_parsed ~= "implicit" then @@ -4955,7 +4962,7 @@ function tl.pretty_print_ast(ast, gen_target, mode) visit_type.cbs["union"] = default_type_visitor visit_type.cbs["nominal"] = default_type_visitor visit_type.cbs["emptytable"] = default_type_visitor - visit_type.cbs["table_item"] = default_type_visitor + visit_type.cbs["literal_table_item"] = default_type_visitor visit_type.cbs["unresolved_emptytable_value"] = default_type_visitor visit_type.cbs["tuple"] = default_type_visitor visit_type.cbs["poly"] = default_type_visitor @@ -5017,7 +5024,7 @@ local typename_to_typecode = { ["none"] = tl.typecodes.UNKNOWN, ["tuple"] = tl.typecodes.UNKNOWN, - ["table_item"] = tl.typecodes.UNKNOWN, + ["literal_table_item"] = tl.typecodes.UNKNOWN, ["unresolved"] = tl.typecodes.UNKNOWN, ["typetype"] = tl.typecodes.UNKNOWN, ["typealias"] = tl.typecodes.UNKNOWN, @@ -5026,7 +5033,7 @@ local typename_to_typecode = { local skip_types = { ["none"] = true, - ["table_item"] = true, + ["literal_table_item"] = true, ["unresolved"] = true, ["typetype"] = true, } @@ -8431,7 +8438,7 @@ a.types[i], b.types[i]), } ["number"] = true, ["integer"] = true, ["boolean"] = true, - ["table_literal"] = true, + ["literal_table"] = true, } local function expr_is_definitely_not_closable(e) return definitely_not_closable_exprs[e.kind] @@ -10088,8 +10095,6 @@ a.types[i], b.types[i]), } local keys, values for i, child in ipairs(children) do - assert(child.typename == "table_item") - local ck = child.kname local n = node[i].key.constnum local b = nil @@ -10850,7 +10855,7 @@ expand_type(node, values, elements) }) return tuple end, }, - ["table_literal"] = { + ["literal_table"] = { before = function(node) if node.expected then local decltype = resolve_tuple_and_nominal(node.expected) @@ -10937,7 +10942,6 @@ expand_type(node, values, elements) }) local seen_keys = {} for i, child in ipairs(children) do - assert(child.typename == "table_item") local cvtype = resolve_tuple(child.vtype) local ck = child.kname local n = node[i].key.constnum @@ -11024,7 +11028,7 @@ expand_type(node, values, elements) }) return t end, }, - ["table_item"] = { + ["literal_table_item"] = { after = function(node, children) local kname = node.key.conststr local ktype = children[1] @@ -11040,7 +11044,7 @@ expand_type(node, values, elements) }) vtype = shallow_copy_new_type(vtype) vtype.is_method = false end - return type_at(node, a_type("table_item", { + return type_at(node, a_type("literal_table_item", { kname = kname, ktype = ktype, vtype = vtype, @@ -11349,7 +11353,7 @@ expand_type(node, values, elements) }) node.e2.expected = node.expected elseif node.op.op == "or" then node.e1.expected = node.expected - if not (node.e2.kind == "table_literal" and #node.e2 == 0) then + if not (node.e2.kind == "literal_table" and #node.e2 == 0) then node.e2.expected = node.expected end end @@ -12152,7 +12156,7 @@ expand_type(node, values, elements) }) visit_type.cbs["integer"] = default_type_visitor visit_type.cbs["thread"] = default_type_visitor visit_type.cbs["emptytable"] = default_type_visitor - visit_type.cbs["table_item"] = default_type_visitor + visit_type.cbs["literal_table_item"] = default_type_visitor visit_type.cbs["unresolved_emptytable_value"] = default_type_visitor visit_type.cbs["tuple"] = default_type_visitor visit_type.cbs["poly"] = default_type_visitor diff --git a/tl.tl b/tl.tl index 48dc7c58d..338a96fe9 100644 --- a/tl.tl +++ b/tl.tl @@ -1016,7 +1016,7 @@ local enum TypeName "union" "nominal" "emptytable" - "table_item" + "literal_table_item" "unresolved_emptytable_value" "unresolved_typearg" "unresolvable_typearg" @@ -1053,7 +1053,7 @@ local table_types : {TypeName:boolean} = { ["integer"] = false, ["union"] = false, ["nominal"] = false, - ["table_item"] = false, + ["literal_table_item"] = false, ["unresolved_emptytable_value"] = false, ["unresolved_typearg"] = false, ["unresolvable_typearg"] = false, @@ -1102,11 +1102,17 @@ local interface Type -- function argument is_self: boolean + +end + +local record LiteralTableItemType + is Type + where self.typename == "literal_table_item" + -- table items kname: string ktype: Type vtype: Type - end local record UnresolvedType @@ -1307,8 +1313,8 @@ local enum NodeKind "number" "integer" "boolean" - "table_literal" - "table_item" + "literal_table" + "literal_table_item" "function" "expression_list" "enum_item" @@ -1816,7 +1822,7 @@ local function parse_table_value(ps: ParseState, i: integer): integer, Node, int end local function parse_table_item(ps: ParseState, i: integer, n?: integer): integer, Node, integer - local node = new_node(ps.tokens, i, "table_item") + local node = new_node(ps.tokens, i, "literal_table_item") if ps.tokens[i].kind == "$EOF$" then return fail(ps, i, "unexpected eof") end @@ -1943,7 +1949,7 @@ local function parse_bracket_list(ps: ParseState, i: integer, list: {T}, open end local function parse_table_literal(ps: ParseState, i: integer): integer, Node - local node = new_node(ps.tokens, i, "table_literal") + local node = new_node(ps.tokens, i, "literal_table") return parse_bracket_list(ps, i, node, "{", "}", "term", parse_table_item) end @@ -3895,6 +3901,13 @@ local function recurse_type(ast: Type, visit: Visitor): T if ast.elements then table.insert(xs, recurse_type(ast.elements, visit)) end + elseif ast is LiteralTableItemType then + if ast.ktype then + table.insert(xs, recurse_type(ast.ktype, visit)) + end + if ast.vtype then + table.insert(xs, recurse_type(ast.vtype, visit)) + end else if ast.def then table.insert(xs, recurse_type(ast.def, visit)) @@ -3902,12 +3915,6 @@ local function recurse_type(ast: Type, visit: Visitor): T if ast.alias_to then table.insert(xs, recurse_type(ast.alias_to, visit)) end - if ast.ktype then - table.insert(xs, recurse_type(ast.ktype, visit)) - end - if ast.vtype then - table.insert(xs, recurse_type(ast.vtype, visit)) - end end local ret: T @@ -4027,11 +4034,11 @@ local function recurse_node(root: Node, ["statements"] = walk_children, ["argument_list"] = walk_children, - ["table_literal"] = walk_children, + ["literal_table"] = walk_children, ["variable_list"] = walk_children, ["expression_list"] = walk_children, - ["table_item"] = function(ast: Node, xs: {T}) + ["literal_table_item"] = function(ast: Node, xs: {T}) xs[1] = recurse(ast.key) xs[2] = recurse(ast.value) if ast.itemtype then @@ -4628,7 +4635,7 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | return out end, }, - ["table_literal"] = { + ["literal_table"] = { before = increment_indent, after = function(node: Node, children: {Output}): Output local out: Output = { y = node.y, h = 0 } @@ -4649,7 +4656,7 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | return out end, }, - ["table_item"] = { + ["literal_table_item"] = { after = function(node: Node, children: {Output}): Output local out: Output = { y = node.y, h = 0 } if node.key_parsed ~= "implicit" then @@ -4955,7 +4962,7 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | visit_type.cbs["union"] = default_type_visitor visit_type.cbs["nominal"] = default_type_visitor visit_type.cbs["emptytable"] = default_type_visitor - visit_type.cbs["table_item"] = default_type_visitor + visit_type.cbs["literal_table_item"] = default_type_visitor visit_type.cbs["unresolved_emptytable_value"] = default_type_visitor visit_type.cbs["tuple"] = default_type_visitor visit_type.cbs["poly"] = default_type_visitor @@ -5017,7 +5024,7 @@ local typename_to_typecode : {TypeName:integer} = { -- types that should be skipped or not present: ["none"] = tl.typecodes.UNKNOWN, ["tuple"] = tl.typecodes.UNKNOWN, - ["table_item"] = tl.typecodes.UNKNOWN, + ["literal_table_item"] = tl.typecodes.UNKNOWN, ["unresolved"] = tl.typecodes.UNKNOWN, ["typetype"] = tl.typecodes.UNKNOWN, ["typealias"] = tl.typecodes.UNKNOWN, @@ -5026,7 +5033,7 @@ local typename_to_typecode : {TypeName:integer} = { local skip_types: {TypeName: boolean} = { ["none"] = true, - ["table_item"] = true, + ["literal_table_item"] = true, ["unresolved"] = true, ["typetype"] = true, } @@ -8431,7 +8438,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string ["number"] = true, ["integer"] = true, ["boolean"] = true, - ["table_literal"] = true, + ["literal_table"] = true, } local function expr_is_definitely_not_closable(e: Node): boolean return definitely_not_closable_exprs[e.kind] @@ -10064,7 +10071,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - local function infer_table_literal(node: Node, children: {Type}): Type + local function infer_table_literal(node: Node, children: {LiteralTableItemType}): Type local is_record = false local is_array = false local is_map = false @@ -10088,8 +10095,6 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local keys, values: Type, Type for i, child in ipairs(children) do - assert(child.typename == "table_item") - local ck = child.kname local n = node[i].key.constnum local b: boolean = nil @@ -10850,7 +10855,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return tuple end, }, - ["table_literal"] = { + ["literal_table"] = { before = function(node: Node) if node.expected then local decltype = resolve_tuple_and_nominal(node.expected) @@ -10888,7 +10893,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end end, - after = function(node: Node, children: {Type}): Type + after = function(node: Node, children: {LiteralTableItemType}): Type node.known = FACT_TRUTHY if not node.expected then @@ -10937,7 +10942,6 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local seen_keys: {CheckableKey:Where} = {} for i, child in ipairs(children) do - assert(child.typename == "table_item") local cvtype = resolve_tuple(child.vtype) local ck = child.kname local n = node[i].key.constnum @@ -11024,7 +11028,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return t end, }, - ["table_item"] = { + ["literal_table_item"] = { after = function(node: Node, children: {Type}): Type local kname = node.key.conststr local ktype = children[1] @@ -11040,11 +11044,11 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string vtype = shallow_copy_new_type(vtype) vtype.is_method = false end - return type_at(node, a_type("table_item", { + return type_at(node, a_type("literal_table_item", { kname = kname, ktype = ktype, vtype = vtype, - })) + } as LiteralTableItemType)) end, }, ["local_function"] = { @@ -11349,7 +11353,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string node.e2.expected = node.expected elseif node.op.op == "or" then node.e1.expected = node.expected - if not (node.e2.kind == "table_literal" and #node.e2 == 0) then + if not (node.e2.kind == "literal_table" and #node.e2 == 0) then node.e2.expected = node.expected end end @@ -12152,7 +12156,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string visit_type.cbs["integer"] = default_type_visitor visit_type.cbs["thread"] = default_type_visitor visit_type.cbs["emptytable"] = default_type_visitor - visit_type.cbs["table_item"] = default_type_visitor + visit_type.cbs["literal_table_item"] = default_type_visitor visit_type.cbs["unresolved_emptytable_value"] = default_type_visitor visit_type.cbs["tuple"] = default_type_visitor visit_type.cbs["poly"] = default_type_visitor From 25e667ea7f779687dbc8117fdc5ffefe1acc294d Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Sat, 6 Jan 2024 22:40:15 -0300 Subject: [PATCH 083/224] simplify function argument traversal --- tl.lua | 6 ++---- tl.tl | 6 ++---- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/tl.lua b/tl.lua index 19cc49ca0..132ceb845 100644 --- a/tl.lua +++ b/tl.lua @@ -3876,10 +3876,8 @@ local function recurse_type(ast, visit) end end if ast.args then - for i, child in ipairs(ast.args.tuple) do - if i > 1 or not ast.is_method or child.is_self then - table.insert(xs, recurse_type(child, visit)) - end + for _, child in ipairs(ast.args.tuple) do + table.insert(xs, recurse_type(child, visit)) end end if ast.rets then diff --git a/tl.tl b/tl.tl index 338a96fe9..2536687e3 100644 --- a/tl.tl +++ b/tl.tl @@ -3876,10 +3876,8 @@ local function recurse_type(ast: Type, visit: Visitor): T end end if ast.args then - for i, child in ipairs(ast.args.tuple) do - if i > 1 or not ast.is_method or child.is_self then - table.insert(xs, recurse_type(child, visit)) - end + for _, child in ipairs(ast.args.tuple) do + table.insert(xs, recurse_type(child, visit)) end end if ast.rets then From 0938c212cf3707900b49c2d174808b06e28d2c60 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Sat, 6 Jan 2024 22:44:33 -0300 Subject: [PATCH 084/224] TypeAliasType --- tl.lua | 21 ++++++++++++------ tl.tl | 67 ++++++++++++++++++++++++++++++++-------------------------- 2 files changed, 51 insertions(+), 37 deletions(-) diff --git a/tl.lua b/tl.lua index 132ceb845..fee449990 100644 --- a/tl.lua +++ b/tl.lua @@ -1380,6 +1380,11 @@ local table_types = { + + + + + @@ -3244,8 +3249,9 @@ parse_record_body = function(ps, i, def, node) return fail(ps, i, "expected a type definition") end - if nt.newtype.typename == "typealias" then - nt.newtype.is_nested_alias = true + local ntt = nt.newtype + if ntt.typename == "typealias" then + ntt.is_nested_alias = true end store_field_in_record(ps, iv, v.tk, nt.newtype, def.fields, def.field_order) @@ -3352,8 +3358,9 @@ parse_newtype = function(ps, i) end if def.typename == "nominal" then - node.newtype = new_type(ps, itype, "typealias") - node.newtype.alias_to = def + local typealias = new_type(ps, itype, "typealias") + typealias.alias_to = def + node.newtype = typealias else node.newtype = new_typetype(ps, itype, def) end @@ -3906,13 +3913,12 @@ local function recurse_type(ast, visit) if ast.vtype then table.insert(xs, recurse_type(ast.vtype, visit)) end + elseif ast.typename == "typealias" then + table.insert(xs, recurse_type(ast.alias_to, visit)) else if ast.def then table.insert(xs, recurse_type(ast.def, visit)) end - if ast.alias_to then - table.insert(xs, recurse_type(ast.alias_to, visit)) - end end local ret @@ -6837,6 +6843,7 @@ tl.type_check = function(ast, opts) elseif t.typename == "typetype" then copy.def, same = resolve(t.def, same) elseif t.typename == "typealias" then + assert(copy.typename == "typealias") copy.alias_to, same = resolve(t.alias_to, same) copy.is_nested_alias = t.is_nested_alias elseif t.typename == "nominal" then diff --git a/tl.tl b/tl.tl index 2536687e3..721217be8 100644 --- a/tl.tl +++ b/tl.tl @@ -1095,9 +1095,6 @@ local interface Type closed: boolean is_abstract: boolean - -- typealias - alias_to: NominalType - is_nested_alias: boolean -- function argument is_self: boolean @@ -1105,6 +1102,14 @@ local interface Type end +local record TypeAliasType + is Type + where self.typename == "typealias" + + alias_to: NominalType + is_nested_alias: boolean +end + local record LiteralTableItemType is Type where self.typename == "literal_table_item" @@ -3244,8 +3249,9 @@ parse_record_body = function(ps: ParseState, i: integer, def: RecordLikeType, no return fail(ps, i, "expected a type definition") end - if nt.newtype.typename == "typealias" then - nt.newtype.is_nested_alias = true + local ntt = nt.newtype + if ntt is TypeAliasType then + ntt.is_nested_alias = true end store_field_in_record(ps, iv, v.tk, nt.newtype, def.fields, def.field_order) @@ -3352,8 +3358,9 @@ parse_newtype = function(ps: ParseState, i: integer): integer, Node end if def is NominalType then - node.newtype = new_type(ps, itype, "typealias") - node.newtype.alias_to = def + local typealias = new_type(ps, itype, "typealias") as TypeAliasType + typealias.alias_to = def + node.newtype = typealias else node.newtype = new_typetype(ps, itype, def) end @@ -3906,13 +3913,12 @@ local function recurse_type(ast: Type, visit: Visitor): T if ast.vtype then table.insert(xs, recurse_type(ast.vtype, visit)) end + elseif ast is TypeAliasType then + table.insert(xs, recurse_type(ast.alias_to, visit)) else if ast.def then table.insert(xs, recurse_type(ast.def, visit)) end - if ast.alias_to then - table.insert(xs, recurse_type(ast.alias_to, visit)) - end end local ret: T @@ -4826,7 +4832,7 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | after = function(node: Node, _children: {Output}): Output local out: Output = { y = node.y, h = 0 } local nt = node.newtype - if nt.typename == "typealias" then + if nt is TypeAliasType then table.insert(out, table.concat(nt.alias_to.names, ".")) elseif nt.typename == "typetype" then local def = nt.def @@ -5099,7 +5105,7 @@ get_typenum = function(trenv: TypeReportEnv, t: Type): integer local rt = t if rt.typename == "typetype" then rt = rt.def - elseif rt.typename == "typealias" then + elseif rt is TypeAliasType then rt = rt.alias_to elseif rt is TupleType and #rt.tuple == 1 then rt = rt.tuple[1] @@ -5124,7 +5130,7 @@ get_typenum = function(trenv: TypeReportEnv, t: Type): integer rt = t end end - assert(not (rt.typename == "typetype" or rt.typename == "typealias")) + assert(not (rt.typename == "typetype" or rt is TypeAliasType)) if rt is RecordLikeType then -- store record field info @@ -5596,7 +5602,7 @@ local function show_type_base(t: Type, short: boolean, seen: {Type:string}): str return "nil" elseif t.typename == "none" then return "" - elseif t.typename == "typealias" then + elseif t is TypeAliasType then return "type " .. show(t.alias_to) elseif t.typename == "typetype" then return "type " .. show(t.def) @@ -6595,7 +6601,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string typ = typ.found end end - if typ.typename == "typetype" or typ.typename == "typealias" then + if typ.typename == "typetype" or typ is TypeAliasType then return typ elseif accept_typearg and typ is TypeArgType then return typ @@ -6605,7 +6611,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local function union_type(t: Type): string, Type if t.typename == "typetype" then return union_type(t.def), t.def - elseif t.typename == "typealias" then + elseif t is TypeAliasType then return union_type(t.alias_to), t.alias_to elseif t is TupleType then return union_type(t.tuple[1]), t.tuple[1] @@ -6735,7 +6741,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local function resolve_typetype(t: Type): Type if t.typename == "typetype" then return t.def - elseif t.typename == "typealias" then + elseif t is TypeAliasType then return t.alias_to else return t @@ -6836,7 +6842,8 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end elseif t.typename == "typetype" then copy.def, same = resolve(t.def, same) - elseif t.typename == "typealias" then + elseif t is TypeAliasType then + assert(copy is TypeAliasType) copy.alias_to, same = resolve(t.alias_to, same) copy.is_nested_alias = t.is_nested_alias elseif t is NominalType then @@ -7036,7 +7043,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string var.is_func_arg and "argument" or t is FunctionType and "function" or t.typename == "typetype" and "type" - or t.typename == "typealias" and "type" + or t is TypeAliasType and "type" or "variable", name, show_type(var.t) @@ -7304,12 +7311,12 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if var.used_as_type then var.declared_at.elide_type = true else - if (t.typename == "typetype" or t.typename == "typealias") and not is_global then + if (t.typename == "typetype" or t is TypeAliasType) and not is_global then var.declared_at.elide_type = true end table.insert(list, { y = var.declared_at.y, x = var.declared_at.x, name = name, var = var }) end - elseif var.used and (t.typename == "typetype" or t.typename == "typealias") and var.aliasing then + elseif var.used and (t.typename == "typetype" or t is TypeAliasType) and var.aliasing then var.aliasing.used = true var.aliasing.declared_at.elide_type = false end @@ -7450,7 +7457,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local resolved: Type - if typetype.typename == "typealias" then + if typetype is TypeAliasType then typetype = typetype.alias_to.found end @@ -7492,7 +7499,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return resolved end - resolve_typealias = function(typealias: Type): Type, Variable + resolve_typealias = function(typealias: TypeAliasType): Type, Variable local names = typealias.alias_to.names local aliasing = find_var(names[1], "use_type") if not aliasing then @@ -8922,7 +8929,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if tbl.typename == "typetype" then tbl = tbl.def - elseif tbl.typename == "typealias" then + elseif tbl is TypeAliasType then if tbl.is_nested_alias then return nil, "cannot use a nested type alias as a concrete value" else @@ -9416,7 +9423,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if t.typename == "typetype" then t = t.def - elseif t.typename == "typealias" then + elseif t is TypeAliasType then t = t.alias_to.resolved end @@ -10332,7 +10339,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end else local newtype = value.newtype - if newtype.typename == "typealias" then + if newtype is TypeAliasType then return resolve_typealias(value.newtype) else return value.newtype, nil @@ -10358,7 +10365,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local missing: {string} for _, key in ipairs(t.field_order) do local ftype = t.fields[key] - if not (ftype.typename == "typetype" or ftype.typename == "typealias") then + if not (ftype.typename == "typetype" or ftype is TypeAliasType) then is_total, missing = total_check_key(key, seen_keys, is_total, missing) end end @@ -10403,7 +10410,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end local var = resolve_tuple_and_nominal(vartype) - if var.typename == "typetype" or var.typename == "typealias" then + if var.typename == "typetype" or var is TypeAliasType then error_at(where, "cannot reassign a type") return nil end @@ -10953,7 +10960,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if not df then error_at(node[i], in_context(node.expected_context, "unknown field " .. ck)) else - if df.typename == "typetype" or df.typename == "typealias" then + if df.typename == "typetype" or df is TypeAliasType then error_at(node[i], in_context(node.expected_context, "cannot reassign a type")) else assert_is_a(node[i], cvtype, df, "in record field", ck) @@ -11924,7 +11931,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string add_var(nil, "@self", type_at(typ, a_typetype({ def = typ }))) for fname, ftype in fields_of(typ) do - if ftype.typename == "typealias" then + if ftype is TypeAliasType then resolve_nominal(ftype.alias_to) add_var(nil, fname, ftype) elseif ftype.typename == "typetype" then From 4293ee1d3ab8fb6d7c7e5effc713b6f02d1d411f Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Sat, 6 Jan 2024 23:41:08 -0300 Subject: [PATCH 085/224] TypeDeclType --- tl.lua | 373 +++++++++++++++++++++++++++++---------------------------- tl.tl | 287 ++++++++++++++++++++++---------------------- 2 files changed, 331 insertions(+), 329 deletions(-) diff --git a/tl.lua b/tl.lua index fee449990..f249b1774 100644 --- a/tl.lua +++ b/tl.lua @@ -1039,7 +1039,7 @@ local table_types = { ["emptytable"] = true, ["tupletable"] = true, - ["typetype"] = false, + ["typedecl"] = false, ["typealias"] = false, ["typevar"] = false, ["typearg"] = false, @@ -1691,13 +1691,9 @@ local function new_type(ps, i, typename) }) end -local function new_typetype(ps, i, def) - local t = new_type(ps, i, "typetype") +local function new_typedecl(ps, i, def) + local t = new_type(ps, i, "typedecl") t.def = def - if def.typename == "interface" then - - t.is_abstract = true - end return t end @@ -1709,12 +1705,6 @@ end - - - - - - local function c_tuple(t) return a_type("tuple", { tuple = t }) end @@ -3046,7 +3036,7 @@ local function parse_nested_type(ps, i, def, typename, parse_body) local iok = parse_body(ps, i, ndef, nt) if iok then i = iok - nt.newtype = new_typetype(ps, i, ndef) + nt.newtype = new_typedecl(ps, i, ndef) end store_field_in_record(ps, iv, v.tk, nt.newtype, def.fields, def.field_order) @@ -3216,7 +3206,6 @@ parse_record_body = function(ps, i, def, node) } }) typ.rets = a_type("tuple", { tuple = { BOOLEAN } }) typ.macroexp = where_macroexp - typ.is_abstract = true def.meta_fields = {} def.meta_field_order = {} @@ -3307,7 +3296,6 @@ parse_record_body = function(ps, i, def, node) fail(ps, i + 1, "macroexp must have a function type") else i, t.macroexp = parse_macroexp(ps, i + 1, i + 2) - t.is_abstract = true end end @@ -3349,7 +3337,7 @@ parse_newtype = function(ps, i) return fail(ps, i, "expected a type") end - node.newtype = new_typetype(ps, itype, def) + node.newtype = new_typedecl(ps, itype, def) return i, node else i, def = parse_type(ps, i) @@ -3362,7 +3350,7 @@ parse_newtype = function(ps, i) typealias.alias_to = def node.newtype = typealias else - node.newtype = new_typetype(ps, itype, def) + node.newtype = new_typedecl(ps, itype, def) end return i, node @@ -3499,7 +3487,7 @@ local function parse_type_declaration(ps, i, node_name) end local nt = asgn.value.newtype - if nt.typename == "typetype" then + if nt.typename == "typedecl" then local def = nt.def if def.fields or def.typename == "enum" then if not def.declname then @@ -3530,7 +3518,7 @@ local function parse_type_constructor(ps, i, node_name, type_name, parse_body) i = parse_body(ps, i, def, nt) - nt.newtype = new_typetype(ps, itype, def) + nt.newtype = new_typedecl(ps, itype, def) return i, asgn end @@ -3915,10 +3903,8 @@ local function recurse_type(ast, visit) end elseif ast.typename == "typealias" then table.insert(xs, recurse_type(ast.alias_to, visit)) - else - if ast.def then - table.insert(xs, recurse_type(ast.def, visit)) - end + elseif ast.typename == "typedecl" then + table.insert(xs, recurse_type(ast.def, visit)) end local ret @@ -4385,13 +4371,15 @@ function tl.pretty_print_ast(ast, gen_target, mode) local function print_record_def(typ) local out = { "{" } - for _, name in ipairs(typ.field_order) do - local def = typ.fields[name].def - if typ.fields[name].typename == "typetype" and def.fields then - table.insert(out, name) - table.insert(out, " = ") - table.insert(out, print_record_def(typ.fields[name].def)) - table.insert(out, ", ") + for fname, ftype in fields_of(typ) do + if ftype.typename == "typedecl" then + local def = ftype.def + if def.fields then + table.insert(out, fname) + table.insert(out, " = ") + table.insert(out, print_record_def(def)) + table.insert(out, ", ") + end end end table.insert(out, "}") @@ -4834,7 +4822,7 @@ function tl.pretty_print_ast(ast, gen_target, mode) local nt = node.newtype if nt.typename == "typealias" then table.insert(out, table.concat(nt.alias_to.names, ".")) - elseif nt.typename == "typetype" then + elseif nt.typename == "typedecl" then local def = nt.def if def.fields then table.insert(out, print_record_def(def)) @@ -4948,7 +4936,7 @@ function tl.pretty_print_ast(ast, gen_target, mode) } visit_type.cbs["string"] = default_type_visitor - visit_type.cbs["typetype"] = default_type_visitor + visit_type.cbs["typedecl"] = default_type_visitor visit_type.cbs["typealias"] = default_type_visitor visit_type.cbs["typevar"] = default_type_visitor visit_type.cbs["typearg"] = default_type_visitor @@ -5030,7 +5018,7 @@ local typename_to_typecode = { ["tuple"] = tl.typecodes.UNKNOWN, ["literal_table_item"] = tl.typecodes.UNKNOWN, ["unresolved"] = tl.typecodes.UNKNOWN, - ["typetype"] = tl.typecodes.UNKNOWN, + ["typedecl"] = tl.typecodes.UNKNOWN, ["typealias"] = tl.typecodes.UNKNOWN, ["*"] = tl.typecodes.UNKNOWN, } @@ -5039,7 +5027,7 @@ local skip_types = { ["none"] = true, ["literal_table_item"] = true, ["unresolved"] = true, - ["typetype"] = true, + ["typedecl"] = true, } local get_typenum @@ -5103,7 +5091,7 @@ get_typenum = function(trenv, t) n = trenv.next_num local rt = t - if rt.typename == "typetype" then + if rt.typename == "typedecl" then rt = rt.def elseif rt.typename == "typealias" then rt = rt.alias_to @@ -5130,7 +5118,7 @@ get_typenum = function(trenv, t) rt = t end end - assert(not (rt.typename == "typetype" or rt.typename == "typealias")) + assert(not (rt.typename == "typedecl" or rt.typename == "typealias")) if rt.fields then @@ -5604,7 +5592,7 @@ local function show_type_base(t, short, seen) return "" elseif t.typename == "typealias" then return "type " .. show(t.alias_to) - elseif t.typename == "typetype" then + elseif t.typename == "typedecl" then return "type " .. show(t.def) else return "<" .. t.typename .. " " .. tostring(t) .. ">" @@ -5974,7 +5962,7 @@ local function init_globals(lax) local standard_library = { ["..."] = a_vararg({ STRING }), - ["any"] = a_type("typetype", { def = ANY }), + ["any"] = a_type("typedecl", { def = ANY }), ["arg"] = a_type("array", { elements = STRING }), ["assert"] = a_gfunction(2, function(a, b) return { args = a_type("tuple", { tuple = { a, OPT(b) } }), rets = a_type("tuple", { tuple = { a } }) } end), ["collectgarbage"] = a_type("poly", { types = { @@ -6022,70 +6010,70 @@ local function init_globals(lax) } }), ["tostring"] = a_function({ args = a_type("tuple", { tuple = { ANY } }), rets = a_type("tuple", { tuple = { STRING } }) }), ["type"] = a_function({ args = a_type("tuple", { tuple = { ANY } }), rets = a_type("tuple", { tuple = { STRING } }) }), - ["FILE"] = a_type("typetype", { - def = a_record({ - is_userdata = true, - fields = { - ["close"] = a_function({ args = a_type("tuple", { tuple = { NOMINAL_FILE } }), rets = a_type("tuple", { tuple = { BOOLEAN, STRING, INTEGER } }) }), - ["flush"] = a_function({ args = a_type("tuple", { tuple = { NOMINAL_FILE } }), rets = a_type("tuple", { tuple = {} }) }), - ["lines"] = a_file_reader(function(ctor, args, rets) - table.insert(args, 1, NOMINAL_FILE) - return a_function({ args = ctor(args), rets = a_type("tuple", { tuple = { + ["FILE"] = a_type("typedecl", { def = + a_record({ + is_userdata = true, + fields = { + ["close"] = a_function({ args = a_type("tuple", { tuple = { NOMINAL_FILE } }), rets = a_type("tuple", { tuple = { BOOLEAN, STRING, INTEGER } }) }), + ["flush"] = a_function({ args = a_type("tuple", { tuple = { NOMINAL_FILE } }), rets = a_type("tuple", { tuple = {} }) }), + ["lines"] = a_file_reader(function(ctor, args, rets) + table.insert(args, 1, NOMINAL_FILE) + return a_function({ args = ctor(args), rets = a_type("tuple", { tuple = { a_function({ args = a_type("tuple", { tuple = {} }), rets = ctor(rets) }), } }), }) - end), - ["read"] = a_file_reader(function(ctor, args, rets) - table.insert(args, 1, NOMINAL_FILE) - return a_function({ args = ctor(args), rets = ctor(rets) }) - end), - ["seek"] = a_function({ args = a_type("tuple", { tuple = { NOMINAL_FILE, OPT(STRING), OPT(NUMBER) } }), rets = a_type("tuple", { tuple = { INTEGER, STRING } }) }), - ["setvbuf"] = a_function({ args = a_type("tuple", { tuple = { NOMINAL_FILE, STRING, OPT(NUMBER) } }), rets = a_type("tuple", { tuple = {} }) }), - ["write"] = a_function({ args = a_vararg({ NOMINAL_FILE, a_type("union", { types = { STRING, NUMBER } }) }), rets = a_type("tuple", { tuple = { NOMINAL_FILE, STRING } }) }), + end), + ["read"] = a_file_reader(function(ctor, args, rets) + table.insert(args, 1, NOMINAL_FILE) + return a_function({ args = ctor(args), rets = ctor(rets) }) + end), + ["seek"] = a_function({ args = a_type("tuple", { tuple = { NOMINAL_FILE, OPT(STRING), OPT(NUMBER) } }), rets = a_type("tuple", { tuple = { INTEGER, STRING } }) }), + ["setvbuf"] = a_function({ args = a_type("tuple", { tuple = { NOMINAL_FILE, STRING, OPT(NUMBER) } }), rets = a_type("tuple", { tuple = {} }) }), + ["write"] = a_function({ args = a_vararg({ NOMINAL_FILE, a_type("union", { types = { STRING, NUMBER } }) }), rets = a_type("tuple", { tuple = { NOMINAL_FILE, STRING } }) }), + + }, + meta_fields = { ["__close"] = FUNCTION }, + meta_field_order = { "__close" }, + }) }), + + ["metatable"] = a_type("typedecl", { def = +a_grecord(1, function(a) return { + fields = { + ["__call"] = a_function({ args = a_vararg({ a, ANY }), rets = a_vararg({ ANY }) }), + ["__gc"] = a_function({ args = a_type("tuple", { tuple = { a } }), rets = a_type("tuple", { tuple = {} }) }), + ["__index"] = ANY, + ["__len"] = a_function({ args = a_type("tuple", { tuple = { a } }), rets = a_type("tuple", { tuple = { ANY } }) }), + ["__mode"] = an_enum({ "k", "v", "kv" }), + ["__newindex"] = ANY, + ["__pairs"] = a_gfunction(2, function(k, v) + return { + args = a_type("tuple", { tuple = { a } }), + rets = a_type("tuple", { tuple = { a_function({ args = a_type("tuple", { tuple = {} }), rets = a_type("tuple", { tuple = { k, v } }) }) } }), + } + end), + ["__tostring"] = a_function({ args = a_type("tuple", { tuple = { a } }), rets = a_type("tuple", { tuple = { STRING } }) }), + ["__name"] = STRING, + ["__add"] = a_function({ args = a_type("tuple", { tuple = { ANY, ANY } }), rets = a_type("tuple", { tuple = { ANY } }) }), + ["__sub"] = a_function({ args = a_type("tuple", { tuple = { ANY, ANY } }), rets = a_type("tuple", { tuple = { ANY } }) }), + ["__mul"] = a_function({ args = a_type("tuple", { tuple = { ANY, ANY } }), rets = a_type("tuple", { tuple = { ANY } }) }), + ["__div"] = a_function({ args = a_type("tuple", { tuple = { ANY, ANY } }), rets = a_type("tuple", { tuple = { ANY } }) }), + ["__idiv"] = a_function({ args = a_type("tuple", { tuple = { ANY, ANY } }), rets = a_type("tuple", { tuple = { ANY } }) }), + ["__mod"] = a_function({ args = a_type("tuple", { tuple = { ANY, ANY } }), rets = a_type("tuple", { tuple = { ANY } }) }), + ["__pow"] = a_function({ args = a_type("tuple", { tuple = { ANY, ANY } }), rets = a_type("tuple", { tuple = { ANY } }) }), + ["__unm"] = a_function({ args = a_type("tuple", { tuple = { ANY } }), rets = a_type("tuple", { tuple = { ANY } }) }), + ["__band"] = a_function({ args = a_type("tuple", { tuple = { ANY, ANY } }), rets = a_type("tuple", { tuple = { ANY } }) }), + ["__bor"] = a_function({ args = a_type("tuple", { tuple = { ANY, ANY } }), rets = a_type("tuple", { tuple = { ANY } }) }), + ["__bxor"] = a_function({ args = a_type("tuple", { tuple = { ANY, ANY } }), rets = a_type("tuple", { tuple = { ANY } }) }), + ["__bnot"] = a_function({ args = a_type("tuple", { tuple = { ANY } }), rets = a_type("tuple", { tuple = { ANY } }) }), + ["__shl"] = a_function({ args = a_type("tuple", { tuple = { ANY, ANY } }), rets = a_type("tuple", { tuple = { ANY } }) }), + ["__shr"] = a_function({ args = a_type("tuple", { tuple = { ANY, ANY } }), rets = a_type("tuple", { tuple = { ANY } }) }), + ["__concat"] = a_function({ args = a_type("tuple", { tuple = { ANY, ANY } }), rets = a_type("tuple", { tuple = { ANY } }) }), + ["__eq"] = a_function({ args = a_type("tuple", { tuple = { ANY, ANY } }), rets = a_type("tuple", { tuple = { BOOLEAN } }) }), + ["__lt"] = a_function({ args = a_type("tuple", { tuple = { ANY, ANY } }), rets = a_type("tuple", { tuple = { BOOLEAN } }) }), + ["__le"] = a_function({ args = a_type("tuple", { tuple = { ANY, ANY } }), rets = a_type("tuple", { tuple = { BOOLEAN } }) }), + ["__close"] = a_function({ args = a_type("tuple", { tuple = { a } }), rets = a_type("tuple", { tuple = {} }) }), + }, +} end) }), - }, - meta_fields = { ["__close"] = FUNCTION }, - meta_field_order = { "__close" }, - }), - }), - ["metatable"] = a_type("typetype", { - def = a_grecord(1, function(a) return { - fields = { - ["__call"] = a_function({ args = a_vararg({ a, ANY }), rets = a_vararg({ ANY }) }), - ["__gc"] = a_function({ args = a_type("tuple", { tuple = { a } }), rets = a_type("tuple", { tuple = {} }) }), - ["__index"] = ANY, - ["__len"] = a_function({ args = a_type("tuple", { tuple = { a } }), rets = a_type("tuple", { tuple = { ANY } }) }), - ["__mode"] = an_enum({ "k", "v", "kv" }), - ["__newindex"] = ANY, - ["__pairs"] = a_gfunction(2, function(k, v) - return { - args = a_type("tuple", { tuple = { a } }), - rets = a_type("tuple", { tuple = { a_function({ args = a_type("tuple", { tuple = {} }), rets = a_type("tuple", { tuple = { k, v } }) }) } }), - } - end), - ["__tostring"] = a_function({ args = a_type("tuple", { tuple = { a } }), rets = a_type("tuple", { tuple = { STRING } }) }), - ["__name"] = STRING, - ["__add"] = a_function({ args = a_type("tuple", { tuple = { ANY, ANY } }), rets = a_type("tuple", { tuple = { ANY } }) }), - ["__sub"] = a_function({ args = a_type("tuple", { tuple = { ANY, ANY } }), rets = a_type("tuple", { tuple = { ANY } }) }), - ["__mul"] = a_function({ args = a_type("tuple", { tuple = { ANY, ANY } }), rets = a_type("tuple", { tuple = { ANY } }) }), - ["__div"] = a_function({ args = a_type("tuple", { tuple = { ANY, ANY } }), rets = a_type("tuple", { tuple = { ANY } }) }), - ["__idiv"] = a_function({ args = a_type("tuple", { tuple = { ANY, ANY } }), rets = a_type("tuple", { tuple = { ANY } }) }), - ["__mod"] = a_function({ args = a_type("tuple", { tuple = { ANY, ANY } }), rets = a_type("tuple", { tuple = { ANY } }) }), - ["__pow"] = a_function({ args = a_type("tuple", { tuple = { ANY, ANY } }), rets = a_type("tuple", { tuple = { ANY } }) }), - ["__unm"] = a_function({ args = a_type("tuple", { tuple = { ANY } }), rets = a_type("tuple", { tuple = { ANY } }) }), - ["__band"] = a_function({ args = a_type("tuple", { tuple = { ANY, ANY } }), rets = a_type("tuple", { tuple = { ANY } }) }), - ["__bor"] = a_function({ args = a_type("tuple", { tuple = { ANY, ANY } }), rets = a_type("tuple", { tuple = { ANY } }) }), - ["__bxor"] = a_function({ args = a_type("tuple", { tuple = { ANY, ANY } }), rets = a_type("tuple", { tuple = { ANY } }) }), - ["__bnot"] = a_function({ args = a_type("tuple", { tuple = { ANY } }), rets = a_type("tuple", { tuple = { ANY } }) }), - ["__shl"] = a_function({ args = a_type("tuple", { tuple = { ANY, ANY } }), rets = a_type("tuple", { tuple = { ANY } }) }), - ["__shr"] = a_function({ args = a_type("tuple", { tuple = { ANY, ANY } }), rets = a_type("tuple", { tuple = { ANY } }) }), - ["__concat"] = a_function({ args = a_type("tuple", { tuple = { ANY, ANY } }), rets = a_type("tuple", { tuple = { ANY } }) }), - ["__eq"] = a_function({ args = a_type("tuple", { tuple = { ANY, ANY } }), rets = a_type("tuple", { tuple = { BOOLEAN } }) }), - ["__lt"] = a_function({ args = a_type("tuple", { tuple = { ANY, ANY } }), rets = a_type("tuple", { tuple = { BOOLEAN } }) }), - ["__le"] = a_function({ args = a_type("tuple", { tuple = { ANY, ANY } }), rets = a_type("tuple", { tuple = { BOOLEAN } }) }), - ["__close"] = a_function({ args = a_type("tuple", { tuple = { a } }), rets = a_type("tuple", { tuple = {} }) }), - }, - } end), - }), ["coroutine"] = a_record({ fields = { ["create"] = a_function({ args = a_type("tuple", { tuple = { FUNCTION } }), rets = a_type("tuple", { tuple = { THREAD } }) }), @@ -6100,9 +6088,9 @@ local function init_globals(lax) }), ["debug"] = a_record({ fields = { - ["Info"] = a_type("typetype", { def = DEBUG_GETINFO_TABLE }), - ["Hook"] = a_type("typetype", { def = DEBUG_HOOK_FUNCTION }), - ["HookEvent"] = a_type("typetype", { def = DEBUG_HOOK_EVENT }), + ["Info"] = a_type("typedecl", { def = DEBUG_GETINFO_TABLE }), + ["Hook"] = a_type("typedecl", { def = DEBUG_HOOK_FUNCTION }), + ["HookEvent"] = a_type("typedecl", { def = DEBUG_HOOK_EVENT }), ["debug"] = a_function({ args = a_type("tuple", { tuple = {} }), rets = a_type("tuple", { tuple = {} }) }), ["gethook"] = a_function({ args = a_type("tuple", { tuple = { OPT(THREAD) } }), rets = a_type("tuple", { tuple = { DEBUG_HOOK_FUNCTION, INTEGER } }) }), @@ -6405,7 +6393,7 @@ tl.type_check = function(ast, opts) end if opts.module_name then - env.modules[opts.module_name] = a_type("typetype", { def = CIRCULAR_REQUIRE }) + env.modules[opts.module_name] = a_type("typedecl", { def = CIRCULAR_REQUIRE }) end local lax = opts.lax @@ -6562,14 +6550,13 @@ tl.type_check = function(ast, opts) end local function ensure_not_abstract(where, t) - if not t.is_abstract then - return - end - if t.typename == "function" and t.macroexp then error_at(where, "macroexps are abstract; consider using a concrete function") - else - error_at(where, "interfaces are abstract; consider using a concrete record") + elseif t.typename == "typedecl" then + local def = t.def + if def.typename == "interface" then + error_at(where, "interfaces are abstract; consider using a concrete record") + end end end @@ -6582,7 +6569,7 @@ tl.type_check = function(ast, opts) typ = typ.found end for i = 2, #names do - if typ.typename == "typetype" then + if typ.typename == "typedecl" then typ = typ.def end @@ -6601,7 +6588,7 @@ tl.type_check = function(ast, opts) typ = typ.found end end - if typ.typename == "typetype" or typ.typename == "typealias" then + if typ.typename == "typedecl" or typ.typename == "typealias" then return typ elseif accept_typearg and typ.typename == "typearg" then return typ @@ -6609,18 +6596,18 @@ tl.type_check = function(ast, opts) end local function union_type(t) - if t.typename == "typetype" then + if t.typename == "typedecl" then return union_type(t.def), t.def elseif t.typename == "typealias" then return union_type(t.alias_to), t.alias_to elseif t.typename == "tuple" then return union_type(t.tuple[1]), t.tuple[1] elseif t.typename == "nominal" then - local typetype = t.found or find_type(t.names) - if not typetype then + local typedecl = t.found or find_type(t.names) + if not typedecl then return "invalid" end - return union_type(typetype) + return union_type(typedecl) elseif t.fields then if t.is_userdata then return "userdata", t @@ -6738,8 +6725,8 @@ tl.type_check = function(ast, opts) tostring(nfargs or 0) end - local function resolve_typetype(t) - if t.typename == "typetype" then + local function resolve_typedecl(t) + if t.typename == "typedecl" then return t.def elseif t.typename == "typealias" then return t.alias_to @@ -6808,7 +6795,6 @@ tl.type_check = function(ast, opts) seen[orig_t] = copy copy.opt = t.opt - copy.is_abstract = t.is_abstract copy.typename = t.typename copy.filename = t.filename copy.x = t.x @@ -6840,7 +6826,8 @@ tl.type_check = function(ast, opts) if t.constraint then copy.constraint, same = resolve(t.constraint, same) end - elseif t.typename == "typetype" then + elseif t.typename == "typedecl" then + assert(copy.typename == "typedecl") copy.def, same = resolve(t.def, same) elseif t.typename == "typealias" then assert(copy.typename == "typealias") @@ -7042,7 +7029,7 @@ tl.type_check = function(ast, opts) "unused %s %s: %s", var.is_func_arg and "argument" or t.typename == "function" and "function" or - t.typename == "typetype" and "type" or + t.typename == "typedecl" and "type" or t.typename == "typealias" and "type" or "variable", name, @@ -7270,7 +7257,7 @@ tl.type_check = function(ast, opts) local function close_nested_records(t) for _, ft in pairs(t.fields) do - if ft.typename == "typetype" then + if ft.typename == "typedecl" then ft.closed = true local def = ft.def if def.fields then @@ -7283,7 +7270,7 @@ tl.type_check = function(ast, opts) local function close_types(vars) for _, var in pairs(vars) do local t = var.t - if t.typename == "typetype" then + if t.typename == "typedecl" then t.closed = true local def = t.def if def.fields then @@ -7311,12 +7298,12 @@ tl.type_check = function(ast, opts) if var.used_as_type then var.declared_at.elide_type = true else - if (t.typename == "typetype" or t.typename == "typealias") and not is_global then + if (t.typename == "typedecl" or t.typename == "typealias") and not is_global then var.declared_at.elide_type = true end table.insert(list, { y = var.declared_at.y, x = var.declared_at.x, name = name, var = var }) end - elseif var.used and (t.typename == "typetype" or t.typename == "typealias") and var.aliasing then + elseif var.used and (t.typename == "typedecl" or t.typename == "typealias") and var.aliasing then var.aliasing.used = true var.aliasing.declared_at.elide_type = false end @@ -7449,30 +7436,30 @@ tl.type_check = function(ast, opts) return t.resolved end - local typetype = t.found or find_type(t.names) - if not typetype then + local found = t.found or find_type(t.names) + if not found then error_at(t, "unknown type %s", t) return INVALID end local resolved - if typetype.typename == "typealias" then - typetype = typetype.alias_to.found + if found.typename == "typealias" then + found = found.alias_to.found end - if typetype.typename == "typetype" then - local def = typetype.def + if found.typename == "typedecl" then + local def = found.def if def.typename == "circular_require" then - return typetype.def + return def end if def.typename == "nominal" then - typetype = def.found - assert(typetype.typename == "typetype") - def = typetype.def + found = def.found + assert(found.typename == "typedecl") + def = found.def end assert(not (def.typename == "nominal")) @@ -7494,7 +7481,7 @@ tl.type_check = function(ast, opts) t.y = resolved.y end end - t.found = typetype + t.found = found t.resolved = resolved return resolved end @@ -7511,22 +7498,24 @@ tl.type_check = function(ast, opts) return t.resolved, aliasing end - local typetype = t.found or find_type(t.names) - if not typetype then + local found = t.found or find_type(t.names) + if not found then error_at(t, "unknown type %s", t) return INVALID end + assert(found.typename == "typedecl") + if t.typevals then - local resolved = match_typevals(t, typetype.def) + local resolved = match_typevals(t, found.def) t.resolved = resolved - t.found = typetype - typetype = a_type("typetype", { def = resolved }) + t.found = found + found = a_type("typedecl", { def = resolved }) else t.resolved = t end - return typetype, aliasing + return found, aliasing end end @@ -8214,7 +8203,7 @@ a.types[i], b.types[i]), } return compare_map(a.keys, INTEGER, a.values, b.elements) end, }, - ["typetype"] = { + ["typedecl"] = { ["record"] = function(a, b) local def = a.def if def.fields then @@ -8505,9 +8494,11 @@ a.types[i], b.types[i]), } end end - local funcdef = func.def - if func.typename == "typetype" and funcdef.typename == "record" then - func = func.def + if func.typename == "typedecl" then + local funcdef = func.def + if funcdef.typename == "record" then + func = func.def + end end if func.fields and func.meta_fields and func.meta_fields["__call"] then @@ -8768,7 +8759,7 @@ a.types[i], b.types[i]), } return resolve_typevars_at(where, f.rets) end - local function check_call(where, where_args, func, args, expected_rets, is_typetype_funcall, is_method, argdelta) + local function check_call(where, where_args, func, args, expected_rets, is_typedecl_funcall, is_method, argdelta) assert(type(func) == "table") assert(type(args) == "table") @@ -8782,7 +8773,7 @@ a.types[i], b.types[i]), } argdelta = is_method and -1 or argdelta or 0 if is_method and args.tuple[1] then - add_var(nil, "@self", type_at(where, a_type("typetype", { def = args.tuple[1] }))) + add_var(nil, "@self", type_at(where, a_type("typedecl", { def = args.tuple[1] }))) end local passes, n = 1, 1 @@ -8801,7 +8792,7 @@ a.types[i], b.types[i]), } if f.is_method and not is_method then if args.tuple[1] and is_a(args.tuple[1], fargs[1]) then - if not is_typetype_funcall then + if not is_typedecl_funcall then add_warning("hint", where, "invoked method as a regular function: consider using ':' instead of '.'") end else @@ -8858,15 +8849,18 @@ a.types[i], b.types[i]), } begin_scope() - local is_typetype_funcall + local is_typedecl_funcall if node.kind == "op" and node.op.op == "@funcall" and node.e1 and node.e1.receiver then local receiver = node.e1.receiver - if receiver.typename == "nominal" and receiver.resolved and receiver.resolved.typename == "typetype" then - is_typetype_funcall = true + if receiver.typename == "nominal" then + local resolved = receiver.resolved + if resolved and resolved.typename == "typedecl" then + is_typedecl_funcall = true + end end end - local ret, f = check_call(node, where_args, func, args, expected_rets, is_typetype_funcall, is_method, argdelta) + local ret, f = check_call(node, where_args, func, args, expected_rets, is_typedecl_funcall, is_method, argdelta) ret = resolve_typevars_at(node, ret) end_scope() @@ -8927,7 +8921,7 @@ a.types[i], b.types[i]), } tbl = find_var_type("string") end - if tbl.typename == "typetype" then + if tbl.typename == "typedecl" then tbl = tbl.def elseif tbl.typename == "typealias" then if tbl.is_nested_alias then @@ -9256,7 +9250,7 @@ a.types[i], b.types[i]), } local function type_check_index(anode, bnode, a, b) local orig_a = a local orig_b = b - a = resolve_typetype(resolve_tuple_and_nominal(a)) + a = resolve_typedecl(resolve_tuple_and_nominal(a)) b = resolve_tuple_and_nominal(b) if lax and is_unknown(a) then @@ -9403,11 +9397,15 @@ a.types[i], b.types[i]), } end local t = v.t - if t.closed then - return nil, nil, exp.tk + if t.typename == "typedecl" then + if t.closed then + return nil, nil, exp.tk + end + + return t.def, v, exp.tk end - return t.def or t, v, exp.tk + return t, v, exp.tk elseif exp.kind == "op" then local t, v, rname = find_record_to_extend(exp.e1) @@ -9421,7 +9419,7 @@ a.types[i], b.types[i]), } end t = t.fields[fname] - if t.typename == "typetype" then + if t.typename == "typedecl" then t = t.def elseif t.typename == "typealias" then t = t.alias_to.resolved @@ -9431,9 +9429,7 @@ a.types[i], b.types[i]), } end end - local function typetype_to_nominal(where, name, t, resolved) - assert(t.typename == "typetype") - + local function typedecl_to_nominal(where, name, t, resolved) local typevals local def = t.def if def.typeargs then @@ -9461,8 +9457,8 @@ a.types[i], b.types[i]), } return nil end - if t.typename == "typetype" then - return typetype_to_nominal(exp, exp.tk, t) + if t.typename == "typedecl" then + return typedecl_to_nominal(exp, exp.tk, t) else return t end @@ -9474,10 +9470,15 @@ a.types[i], b.types[i]), } end if t.typename == "nominal" then - local def = t.found and t.found.def - if def.fields and def.fields[exp.e2.tk] then - table.insert(t.names, exp.e2.tk) - t.found = def.fields[exp.e2.tk] + local found = t.found + if found then + if found.typename == "typedecl" then + local def = found.def + if def.fields and def.fields[exp.e2.tk] then + table.insert(t.names, exp.e2.tk) + t.found = def.fields[exp.e2.tk] + end + end end elseif t.fields then return t.fields and t.fields[exp.e2.tk] @@ -10365,7 +10366,7 @@ expand_type(node, values, elements) }) local missing for _, key in ipairs(t.field_order) do local ftype = t.fields[key] - if not (ftype.typename == "typetype" or ftype.typename == "typealias") then + if not (ftype.typename == "typedecl" or ftype.typename == "typealias") then is_total, missing = total_check_key(key, seen_keys, is_total, missing) end end @@ -10410,7 +10411,7 @@ expand_type(node, values, elements) }) end local var = resolve_tuple_and_nominal(vartype) - if var.typename == "typetype" or var.typename == "typealias" then + if var.typename == "typedecl" or var.typename == "typealias" then error_at(where, "cannot reassign a type") return nil end @@ -10866,7 +10867,7 @@ expand_type(node, values, elements) }) local decltype = resolve_tuple_and_nominal(node.expected) if decltype.typename == "typevar" and decltype.constraint then - decltype = resolve_typetype(resolve_tuple_and_nominal(decltype.constraint)) + decltype = resolve_typedecl(resolve_tuple_and_nominal(decltype.constraint)) end if decltype.typename == "tupletable" then @@ -10909,7 +10910,7 @@ expand_type(node, values, elements) }) local constraint if decltype.typename == "typevar" and decltype.constraint then - constraint = resolve_typetype(decltype.constraint) + constraint = resolve_typedecl(decltype.constraint) decltype = resolve_tuple_and_nominal(constraint) end @@ -10960,7 +10961,7 @@ expand_type(node, values, elements) }) if not df then error_at(node[i], in_context(node.expected_context, "unknown field " .. ck)) else - if df.typename == "typetype" or df.typename == "typealias" then + if df.typename == "typedecl" or df.typename == "typealias" then error_at(node[i], in_context(node.expected_context, "cannot reassign a type")) else assert_is_a(node[i], cvtype, df, "in record field", ck) @@ -11177,7 +11178,7 @@ expand_type(node, values, elements) }) begin_scope(node) end, before_arguments = function(_node, children) - local rtype = resolve_tuple_and_nominal(resolve_typetype(children[1])) + local rtype = resolve_tuple_and_nominal(resolve_typedecl(children[1])) if rtype.fields and rtype.typeargs then @@ -11195,7 +11196,7 @@ expand_type(node, values, elements) }) local rets = children[4] assert(rets.typename == "tuple") - local rtype = resolve_tuple_and_nominal(resolve_typetype(children[1])) + local rtype = resolve_tuple_and_nominal(resolve_typedecl(children[1])) if lax and rtype.typename == "unknown" then return @@ -11411,7 +11412,7 @@ expand_type(node, values, elements) }) local expected = node.expected and resolve_tuple_and_nominal(node.expected) - if ra.typename == "circular_require" or (ra.def and ra.def.typename == "circular_require") then + if ra.typename == "circular_require" or (ra.typename == "typedecl" and ra.def and ra.def.typename == "circular_require") then return invalid_at(node, "cannot dereference a type from a circular require") end @@ -11426,12 +11427,12 @@ expand_type(node, values, elements) }) end ensure_not_abstract(node.e1, ra) - if ra.typename == "typetype" and ra.def.typename == "record" then + if ra.typename == "typedecl" and ra.def.typename == "record" then ra = ra.def end if rb then ensure_not_abstract(node.e2, rb) - if rb.typename == "typetype" and rb.def.typename == "record" then + if rb.typename == "typedecl" and rb.def.typename == "record" then rb = rb.def end end @@ -11475,10 +11476,10 @@ expand_type(node, values, elements) }) if rb.typename == "integer" then all_needs_compat["math"] = true end - if ra.typename == "typetype" then + if ra.typename == "typedecl" then error_at(node, "can only use 'is' on variables, not types") elseif node.e1.kind == "variable" then - check_metamethod(node, "__is", ra, resolve_typetype(rb), orig_a, orig_b) + check_metamethod(node, "__is", ra, resolve_typedecl(rb), orig_a, orig_b) node.known = IsFact({ var = node.e1.tk, typ = b, where = node }) else error_at(node, "can only use 'is' on variables") @@ -11743,8 +11744,8 @@ expand_type(node, values, elements) }) return invalid_at(node, "unknown variable: " .. node.tk) end - if t.typename == "typetype" then - t = typetype_to_nominal(node, node.tk, t, t) + if t.typename == "typedecl" then + t = typedecl_to_nominal(node, node.tk, t, t) end return t @@ -11928,13 +11929,13 @@ expand_type(node, values, elements) }) ["record"] = { before = function(typ) begin_scope() - add_var(nil, "@self", type_at(typ, a_type("typetype", { def = typ }))) + add_var(nil, "@self", type_at(typ, a_type("typedecl", { def = typ }))) for fname, ftype in fields_of(typ) do if ftype.typename == "typealias" then resolve_nominal(ftype.alias_to) add_var(nil, fname, ftype) - elseif ftype.typename == "typetype" then + elseif ftype.typename == "typedecl" then add_var(nil, fname, ftype) end end @@ -12067,7 +12068,7 @@ expand_type(node, values, elements) }) local tv = typ tv.typevar = t.typearg tv.constraint = t.constraint - elseif t.typename == "typetype" then + elseif t.typename == "typedecl" then if t.def.typename ~= "circular_require" then typ.found = t end @@ -12150,7 +12151,7 @@ expand_type(node, values, elements) }) visit_type.cbs["string"] = default_type_visitor visit_type.cbs["tupletable"] = default_type_visitor - visit_type.cbs["typetype"] = default_type_visitor + visit_type.cbs["typedecl"] = default_type_visitor visit_type.cbs["typealias"] = default_type_visitor visit_type.cbs["array"] = default_type_visitor visit_type.cbs["map"] = default_type_visitor diff --git a/tl.tl b/tl.tl index 721217be8..ed58a1773 100644 --- a/tl.tl +++ b/tl.tl @@ -996,7 +996,7 @@ local function new_typeid(): integer end local enum TypeName - "typetype" + "typedecl" "typealias" "typevar" "typearg" @@ -1039,7 +1039,7 @@ local table_types : {TypeName:boolean} = { ["emptytable"] = true, ["tupletable"] = true, - ["typetype"] = false, + ["typedecl"] = false, ["typealias"] = false, ["typevar"] = false, ["typearg"] = false, @@ -1090,16 +1090,16 @@ local interface Type -- arguments: optional arity opt: boolean - -- typetype - def: Type - closed: boolean - is_abstract: boolean - - -- function argument is_self: boolean +end +local record TypeDeclType + is Type + where self.typename == "typedecl" + def: Type + closed: boolean end local record TypeAliasType @@ -1691,24 +1691,14 @@ local function new_type(ps: ParseState, i: integer, typename: TypeName): Type }) end -local function new_typetype(ps: ParseState, i: integer, def: Type): Type - local t = new_type(ps, i, "typetype") +local function new_typedecl(ps: ParseState, i: integer, def: Type): Type + local t = new_type(ps, i, "typedecl") as TypeDeclType t.def = def - if def is InterfaceType then - -- ...or should this be set on traversal, to account for nominal type aliases? - t.is_abstract = true - end return t end -local macroexp a_typetype(t: Type): Type --- FIXME set is_abstract here once standard_library defines interfaces --- local def = t.def --- if def is InterfaceType then --- t.is_abstract = true --- end --- return t - return a_type("typetype", t) +local macroexp a_typedecl(def: Type): TypeDeclType + return a_type("typedecl", { def = def } as TypeDeclType) end local macroexp a_tuple(t: {Type}): TupleType @@ -3046,7 +3036,7 @@ local function parse_nested_type(ps: ParseState, i: integer, def: RecordLikeType local iok = parse_body(ps, i, ndef, nt) if iok then i = iok - nt.newtype = new_typetype(ps, i, ndef) + nt.newtype = new_typedecl(ps, i, ndef) end store_field_in_record(ps, iv, v.tk, nt.newtype, def.fields, def.field_order) @@ -3216,7 +3206,6 @@ parse_record_body = function(ps: ParseState, i: integer, def: RecordLikeType, no } typ.rets = a_tuple { BOOLEAN } typ.macroexp = where_macroexp - typ.is_abstract = true def.meta_fields = {} def.meta_field_order = {} @@ -3307,7 +3296,6 @@ parse_record_body = function(ps: ParseState, i: integer, def: RecordLikeType, no fail(ps, i + 1, "macroexp must have a function type") else i, t.macroexp = parse_macroexp(ps, i + 1, i + 2) - t.is_abstract = true end end @@ -3349,7 +3337,7 @@ parse_newtype = function(ps: ParseState, i: integer): integer, Node return fail(ps, i, "expected a type") end - node.newtype = new_typetype(ps, itype, def) + node.newtype = new_typedecl(ps, itype, def) return i, node else i, def = parse_type(ps, i) @@ -3362,7 +3350,7 @@ parse_newtype = function(ps: ParseState, i: integer): integer, Node typealias.alias_to = def node.newtype = typealias else - node.newtype = new_typetype(ps, itype, def) + node.newtype = new_typedecl(ps, itype, def) end return i, node @@ -3499,7 +3487,7 @@ local function parse_type_declaration(ps: ParseState, i: integer, node_name: Nod end local nt = asgn.value.newtype - if nt.typename == "typetype" then + if nt is TypeDeclType then local def = nt.def if def is RecordLikeType or def is EnumType then if not def.declname then @@ -3530,7 +3518,7 @@ local function parse_type_constructor(ps: ParseState, i: integer, node_name: Nod i = parse_body(ps, i, def, nt) - nt.newtype = new_typetype(ps, itype, def) + nt.newtype = new_typedecl(ps, itype, def) return i, asgn end @@ -3915,10 +3903,8 @@ local function recurse_type(ast: Type, visit: Visitor): T end elseif ast is TypeAliasType then table.insert(xs, recurse_type(ast.alias_to, visit)) - else - if ast.def then - table.insert(xs, recurse_type(ast.def, visit)) - end + elseif ast is TypeDeclType then + table.insert(xs, recurse_type(ast.def, visit)) end local ret: T @@ -4385,13 +4371,15 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | local function print_record_def(typ: RecordLikeType): string local out: {string} = { "{" } - for _, name in ipairs(typ.field_order) do - local def = typ.fields[name].def - if typ.fields[name].typename == "typetype" and def is RecordLikeType then - table.insert(out, name) - table.insert(out, " = ") - table.insert(out, print_record_def(typ.fields[name].def)) - table.insert(out, ", ") + for fname, ftype in fields_of(typ) do + if ftype is TypeDeclType then + local def = ftype.def + if def is RecordLikeType then + table.insert(out, fname) + table.insert(out, " = ") + table.insert(out, print_record_def(def)) + table.insert(out, ", ") + end end end table.insert(out, "}") @@ -4834,7 +4822,7 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | local nt = node.newtype if nt is TypeAliasType then table.insert(out, table.concat(nt.alias_to.names, ".")) - elseif nt.typename == "typetype" then + elseif nt is TypeDeclType then local def = nt.def if def is RecordLikeType then table.insert(out, print_record_def(def)) @@ -4948,7 +4936,7 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | } visit_type.cbs["string"] = default_type_visitor - visit_type.cbs["typetype"] = default_type_visitor + visit_type.cbs["typedecl"] = default_type_visitor visit_type.cbs["typealias"] = default_type_visitor visit_type.cbs["typevar"] = default_type_visitor visit_type.cbs["typearg"] = default_type_visitor @@ -5030,7 +5018,7 @@ local typename_to_typecode : {TypeName:integer} = { ["tuple"] = tl.typecodes.UNKNOWN, ["literal_table_item"] = tl.typecodes.UNKNOWN, ["unresolved"] = tl.typecodes.UNKNOWN, - ["typetype"] = tl.typecodes.UNKNOWN, + ["typedecl"] = tl.typecodes.UNKNOWN, ["typealias"] = tl.typecodes.UNKNOWN, ["*"] = tl.typecodes.UNKNOWN, } @@ -5039,7 +5027,7 @@ local skip_types: {TypeName: boolean} = { ["none"] = true, ["literal_table_item"] = true, ["unresolved"] = true, - ["typetype"] = true, + ["typedecl"] = true, } local get_typenum: function(trenv: TypeReportEnv, t: Type): integer @@ -5103,7 +5091,7 @@ get_typenum = function(trenv: TypeReportEnv, t: Type): integer n = trenv.next_num local rt = t - if rt.typename == "typetype" then + if rt is TypeDeclType then rt = rt.def elseif rt is TypeAliasType then rt = rt.alias_to @@ -5130,7 +5118,7 @@ get_typenum = function(trenv: TypeReportEnv, t: Type): integer rt = t end end - assert(not (rt.typename == "typetype" or rt is TypeAliasType)) + assert(not (rt is TypeDeclType or rt is TypeAliasType)) if rt is RecordLikeType then -- store record field info @@ -5604,7 +5592,7 @@ local function show_type_base(t: Type, short: boolean, seen: {Type:string}): str return "" elseif t is TypeAliasType then return "type " .. show(t.alias_to) - elseif t.typename == "typetype" then + elseif t is TypeDeclType then return "type " .. show(t.def) else return "<" .. t.typename .. " " .. tostring(t) .. ">" @@ -5974,7 +5962,7 @@ local function init_globals(lax: boolean): {string:Variable}, {string:Type} local standard_library: {string:Type} = { ["..."] = a_vararg { STRING }, - ["any"] = a_type("typetype", { def = ANY }), + ["any"] = a_typedecl(ANY), ["arg"] = an_array(STRING), ["assert"] = a_gfunction(2, function(a: Type, b: Type): FunctionType return { args = a_tuple { a, OPT(b) }, rets = a_tuple { a } } end), ["collectgarbage"] = a_poly { @@ -6022,8 +6010,8 @@ local function init_globals(lax: boolean): {string:Variable}, {string:Type} }, ["tostring"] = a_function { args = a_tuple { ANY }, rets = a_tuple { STRING } }, ["type"] = a_function { args = a_tuple { ANY }, rets = a_tuple { STRING } }, - ["FILE"] = a_typetype { - def = a_record { + ["FILE"] = a_typedecl( + a_record { is_userdata = true, fields = { ["close"] = a_function { args = a_tuple { NOMINAL_FILE }, rets = a_tuple { BOOLEAN, STRING, INTEGER } }, @@ -6045,10 +6033,10 @@ local function init_globals(lax: boolean): {string:Variable}, {string:Type} }, meta_fields = { ["__close"] = FUNCTION }, meta_field_order = { "__close" }, - }, - }, - ["metatable"] = a_typetype { - def = a_grecord(1, function(a: Type): RecordType return { + } + ), + ["metatable"] = a_typedecl( + a_grecord(1, function(a: Type): RecordType return { fields = { ["__call"] = a_function { args = a_vararg { a, ANY }, rets = a_vararg { ANY } }, ["__gc"] = a_function { args = a_tuple { a }, rets = a_tuple {} }, @@ -6084,8 +6072,8 @@ local function init_globals(lax: boolean): {string:Variable}, {string:Type} ["__le"] = a_function { args = a_tuple { ANY, ANY }, rets = a_tuple { BOOLEAN } }, ["__close"] = a_function { args = a_tuple { a }, rets = a_tuple { } }, }, - } end), - }, + } end) + ), ["coroutine"] = a_record { fields = { ["create"] = a_function { args = a_tuple { FUNCTION }, rets = a_tuple { THREAD } }, @@ -6100,9 +6088,9 @@ local function init_globals(lax: boolean): {string:Variable}, {string:Type} }, ["debug"] = a_record { fields = { - ["Info"] = a_typetype { def = DEBUG_GETINFO_TABLE }, - ["Hook"] = a_typetype { def = DEBUG_HOOK_FUNCTION }, - ["HookEvent"] = a_typetype { def = DEBUG_HOOK_EVENT }, + ["Info"] = a_typedecl(DEBUG_GETINFO_TABLE), + ["Hook"] = a_typedecl(DEBUG_HOOK_FUNCTION), + ["HookEvent"] = a_typedecl(DEBUG_HOOK_EVENT), ["debug"] = a_function { args = a_tuple {}, rets = a_tuple {} }, ["gethook"] = a_function { args = a_tuple { OPT(THREAD) }, rets = a_tuple { DEBUG_HOOK_FUNCTION, INTEGER } }, @@ -6405,7 +6393,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end if opts.module_name then - env.modules[opts.module_name] = a_typetype { def = CIRCULAR_REQUIRE } + env.modules[opts.module_name] = a_typedecl(CIRCULAR_REQUIRE) end local lax = opts.lax @@ -6562,14 +6550,13 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end local function ensure_not_abstract(where: Where, t: Type) - if not t.is_abstract then - return - end - if t is FunctionType and t.macroexp then error_at(where, "macroexps are abstract; consider using a concrete function") - else - error_at(where, "interfaces are abstract; consider using a concrete record") + elseif t is TypeDeclType then + local def = t.def + if def is InterfaceType then + error_at(where, "interfaces are abstract; consider using a concrete record") + end end end @@ -6582,7 +6569,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string typ = typ.found end for i = 2, #names do - if typ.typename == "typetype" then + if typ is TypeDeclType then typ = typ.def end @@ -6601,7 +6588,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string typ = typ.found end end - if typ.typename == "typetype" or typ is TypeAliasType then + if typ is TypeDeclType or typ is TypeAliasType then return typ elseif accept_typearg and typ is TypeArgType then return typ @@ -6609,18 +6596,18 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end local function union_type(t: Type): string, Type - if t.typename == "typetype" then + if t is TypeDeclType then return union_type(t.def), t.def elseif t is TypeAliasType then return union_type(t.alias_to), t.alias_to elseif t is TupleType then return union_type(t.tuple[1]), t.tuple[1] elseif t is NominalType then - local typetype = t.found or find_type(t.names) - if not typetype then + local typedecl = t.found or find_type(t.names) + if not typedecl then return "invalid" end - return union_type(typetype) + return union_type(typedecl) elseif t is RecordLikeType then if t.is_userdata then return "userdata", t @@ -6738,8 +6725,8 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string or tostring(nfargs or 0) end - local function resolve_typetype(t: Type): Type - if t.typename == "typetype" then + local function resolve_typedecl(t: Type): Type + if t is TypeDeclType then return t.def elseif t is TypeAliasType then return t.alias_to @@ -6808,7 +6795,6 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string seen[orig_t] = copy copy.opt = t.opt - copy.is_abstract = t.is_abstract copy.typename = t.typename copy.filename = t.filename copy.x = t.x @@ -6840,7 +6826,8 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if t.constraint then copy.constraint, same = resolve(t.constraint, same) end - elseif t.typename == "typetype" then + elseif t is TypeDeclType then + assert(copy is TypeDeclType) copy.def, same = resolve(t.def, same) elseif t is TypeAliasType then assert(copy is TypeAliasType) @@ -7042,7 +7029,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string "unused %s %s: %s", var.is_func_arg and "argument" or t is FunctionType and "function" - or t.typename == "typetype" and "type" + or t is TypeDeclType and "type" or t is TypeAliasType and "type" or "variable", name, @@ -7270,7 +7257,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local function close_nested_records(t: RecordLikeType) for _, ft in pairs(t.fields) do - if ft.typename == "typetype" then + if ft is TypeDeclType then ft.closed = true local def = ft.def if def is RecordLikeType then @@ -7283,7 +7270,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local function close_types(vars: {string:Variable}) for _, var in pairs(vars) do local t = var.t - if t.typename == "typetype" then + if t is TypeDeclType then t.closed = true local def = t.def if def is RecordLikeType then @@ -7311,12 +7298,12 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if var.used_as_type then var.declared_at.elide_type = true else - if (t.typename == "typetype" or t is TypeAliasType) and not is_global then + if (t is TypeDeclType or t is TypeAliasType) and not is_global then var.declared_at.elide_type = true end table.insert(list, { y = var.declared_at.y, x = var.declared_at.x, name = name, var = var }) end - elseif var.used and (t.typename == "typetype" or t is TypeAliasType) and var.aliasing then + elseif var.used and (t is TypeDeclType or t is TypeAliasType) and var.aliasing then var.aliasing.used = true var.aliasing.declared_at.elide_type = false end @@ -7449,30 +7436,30 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return t.resolved end - local typetype = t.found or find_type(t.names) - if not typetype then + local found = t.found or find_type(t.names) + if not found then error_at(t, "unknown type %s", t) return INVALID end local resolved: Type - if typetype is TypeAliasType then - typetype = typetype.alias_to.found + if found is TypeAliasType then + found = found.alias_to.found end - if typetype.typename == "typetype" then - local def = typetype.def + if found is TypeDeclType then + local def = found.def if def.typename == "circular_require" then -- return, but do not store resolution - return typetype.def + return def end -- FIXME is this block still needed? if def is NominalType then - typetype = def.found - assert(typetype.typename == "typetype") - def = typetype.def + found = def.found + assert(found is TypeDeclType) + def = found.def end assert(not def is NominalType) @@ -7494,7 +7481,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string t.y = resolved.y end end - t.found = typetype + t.found = found t.resolved = resolved return resolved end @@ -7511,22 +7498,24 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return t.resolved, aliasing end - local typetype = t.found or find_type(t.names) - if not typetype then + local found = t.found or find_type(t.names) + if not found then error_at(t, "unknown type %s", t) return INVALID end + assert(found is TypeDeclType) + if t.typevals then - local resolved = match_typevals(t, typetype.def) + local resolved = match_typevals(t, found.def) t.resolved = resolved - t.found = typetype - typetype = a_typetype { def = resolved } + t.found = found + found = a_typedecl(resolved) else t.resolved = t end - return typetype, aliasing + return found, aliasing end end @@ -8214,8 +8203,8 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return compare_map(a.keys, INTEGER, a.values, b.elements) end, }, - ["typetype"] = { - ["record"] = function(a: Type, b: RecordType): boolean, {Error} + ["typedecl"] = { + ["record"] = function(a: TypeDeclType, b: RecordType): boolean, {Error} local def = a.def if def is RecordLikeType then return subtype_record(a.def, b) -- record as prototype @@ -8505,9 +8494,11 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end -- resolve if prototype - local funcdef = func.def - if func.typename == "typetype" and funcdef is RecordType then - func = func.def + if func is TypeDeclType then + local funcdef = func.def + if funcdef is RecordType then + func = func.def + end end -- resolve if metatable if func is RecordLikeType and func.meta_fields and func.meta_fields["__call"] then @@ -8768,7 +8759,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return resolve_typevars_at(where, f.rets) end - local function check_call(where: Where, where_args: {Node}, func: Type, args: TupleType, expected_rets: TupleType, is_typetype_funcall: boolean, is_method: boolean, argdelta: integer): InvalidOrTupleType, FunctionType + local function check_call(where: Where, where_args: {Node}, func: Type, args: TupleType, expected_rets: TupleType, is_typedecl_funcall: boolean, is_method: boolean, argdelta: integer): InvalidOrTupleType, FunctionType assert(type(func) == "table") assert(type(args) == "table") @@ -8782,7 +8773,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string argdelta = is_method and -1 or argdelta or 0 if is_method and args.tuple[1] then - add_var(nil, "@self", type_at(where, a_typetype { def = args.tuple[1] })) + add_var(nil, "@self", type_at(where, a_typedecl(args.tuple[1]))) end local passes, n = 1, 1 @@ -8801,7 +8792,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if f.is_method and not is_method then if args.tuple[1] and is_a(args.tuple[1], fargs[1]) then -- a non-"@funcall" means a synthesized call, e.g. from a metamethod - if not is_typetype_funcall then + if not is_typedecl_funcall then add_warning("hint", where, "invoked method as a regular function: consider using ':' instead of '.'") end else @@ -8858,15 +8849,18 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string begin_scope() - local is_typetype_funcall: boolean + local is_typedecl_funcall: boolean if node.kind == "op" and node.op.op == "@funcall" and node.e1 and node.e1.receiver then local receiver = node.e1.receiver - if receiver is NominalType and receiver.resolved and receiver.resolved.typename == "typetype" then - is_typetype_funcall = true + if receiver is NominalType then + local resolved = receiver.resolved + if resolved and resolved is TypeDeclType then + is_typedecl_funcall = true + end end end - local ret, f = check_call(node, where_args, func, args, expected_rets, is_typetype_funcall, is_method, argdelta) + local ret, f = check_call(node, where_args, func, args, expected_rets, is_typedecl_funcall, is_method, argdelta) ret = resolve_typevars_at(node, ret) end_scope() @@ -8927,7 +8921,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string tbl = find_var_type("string") -- simulate string metatable end - if tbl.typename == "typetype" then + if tbl is TypeDeclType then tbl = tbl.def elseif tbl is TypeAliasType then if tbl.is_nested_alias then @@ -9256,7 +9250,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local function type_check_index(anode: Node, bnode: Node, a: Type, b: Type): Type local orig_a = a local orig_b = b - a = resolve_typetype(resolve_tuple_and_nominal(a)) + a = resolve_typedecl(resolve_tuple_and_nominal(a)) b = resolve_tuple_and_nominal(b) if lax and is_unknown(a) then @@ -9403,11 +9397,15 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end local t = v.t - if t.closed then - return nil, nil, exp.tk + if t is TypeDeclType then + if t.closed then + return nil, nil, exp.tk + end + + return t.def, v, exp.tk end - return t.def or t, v, exp.tk + return t, v, exp.tk -- recurse elseif exp.kind == "op" then -- assert(exp.op.op == ".") local t, v, rname = find_record_to_extend(exp.e1) @@ -9421,7 +9419,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end t = t.fields[fname] - if t.typename == "typetype" then + if t is TypeDeclType then t = t.def elseif t is TypeAliasType then t = t.alias_to.resolved @@ -9431,9 +9429,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - local function typetype_to_nominal(where: Where, name: string, t: Type, resolved?: Type): Type - assert(t.typename == "typetype") - + local function typedecl_to_nominal(where: Where, name: string, t: TypeDeclType, resolved?: Type): Type local typevals: {Type} local def = t.def if def is HasTypeArgs then @@ -9461,8 +9457,8 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return nil end - if t.typename == "typetype" then - return typetype_to_nominal(exp, exp.tk, t) + if t is TypeDeclType then + return typedecl_to_nominal(exp, exp.tk, t) else return t end @@ -9474,10 +9470,15 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end if t is NominalType then - local def = t.found and t.found.def - if def is RecordLikeType and def.fields[exp.e2.tk] then - table.insert(t.names, exp.e2.tk) - t.found = def.fields[exp.e2.tk] + local found = t.found + if found then + if found is TypeDeclType then + local def = found.def + if def is RecordLikeType and def.fields[exp.e2.tk] then + table.insert(t.names, exp.e2.tk) + t.found = def.fields[exp.e2.tk] + end + end end elseif t is RecordLikeType then return t.fields and t.fields[exp.e2.tk] @@ -10365,7 +10366,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local missing: {string} for _, key in ipairs(t.field_order) do local ftype = t.fields[key] - if not (ftype.typename == "typetype" or ftype is TypeAliasType) then + if not (ftype is TypeDeclType or ftype is TypeAliasType) then is_total, missing = total_check_key(key, seen_keys, is_total, missing) end end @@ -10410,7 +10411,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end local var = resolve_tuple_and_nominal(vartype) - if var.typename == "typetype" or var is TypeAliasType then + if var is TypeDeclType or var is TypeAliasType then error_at(where, "cannot reassign a type") return nil end @@ -10866,7 +10867,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local decltype = resolve_tuple_and_nominal(node.expected) if decltype is TypeVarType and decltype.constraint then - decltype = resolve_typetype(resolve_tuple_and_nominal(decltype.constraint)) + decltype = resolve_typedecl(resolve_tuple_and_nominal(decltype.constraint)) end if decltype is TupleTableType then @@ -10909,7 +10910,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local constraint: Type if decltype is TypeVarType and decltype.constraint then - constraint = resolve_typetype(decltype.constraint) + constraint = resolve_typedecl(decltype.constraint) decltype = resolve_tuple_and_nominal(constraint) end @@ -10960,7 +10961,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if not df then error_at(node[i], in_context(node.expected_context, "unknown field " .. ck)) else - if df.typename == "typetype" or df is TypeAliasType then + if df is TypeDeclType or df is TypeAliasType then error_at(node[i], in_context(node.expected_context, "cannot reassign a type")) else assert_is_a(node[i], cvtype, df, "in record field", ck) @@ -11177,7 +11178,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string begin_scope(node) end, before_arguments = function(_node: Node, children: {Type}) - local rtype = resolve_tuple_and_nominal(resolve_typetype(children[1])) + local rtype = resolve_tuple_and_nominal(resolve_typedecl(children[1])) -- add type arguments from the record implicitly if rtype is RecordLikeType and rtype.typeargs then @@ -11195,7 +11196,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local rets = children[4] assert(rets is TupleType) - local rtype = resolve_tuple_and_nominal(resolve_typetype(children[1])) + local rtype = resolve_tuple_and_nominal(resolve_typedecl(children[1])) if lax and rtype.typename == "unknown" then return @@ -11411,7 +11412,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local expected = node.expected and resolve_tuple_and_nominal(node.expected) - if ra.typename == "circular_require" or (ra.def and ra.def.typename == "circular_require") then + if ra.typename == "circular_require" or (ra is TypeDeclType and ra.def and ra.def.typename == "circular_require") then return invalid_at(node, "cannot dereference a type from a circular require") end @@ -11426,12 +11427,12 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end ensure_not_abstract(node.e1, ra) - if ra.typename == "typetype" and ra.def.typename == "record" then + if ra is TypeDeclType and ra.def.typename == "record" then ra = ra.def end if rb then ensure_not_abstract(node.e2, rb) - if rb.typename == "typetype" and rb.def.typename == "record" then + if rb is TypeDeclType and rb.def.typename == "record" then rb = rb.def end end @@ -11475,10 +11476,10 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if rb.typename == "integer" then all_needs_compat["math"] = true end - if ra.typename == "typetype" then + if ra is TypeDeclType then error_at(node, "can only use 'is' on variables, not types") elseif node.e1.kind == "variable" then - check_metamethod(node, "__is", ra, resolve_typetype(rb), orig_a, orig_b) + check_metamethod(node, "__is", ra, resolve_typedecl(rb), orig_a, orig_b) node.known = IsFact { var = node.e1.tk, typ = b, where = node } else error_at(node, "can only use 'is' on variables") @@ -11743,8 +11744,8 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return invalid_at(node, "unknown variable: " .. node.tk) end - if t.typename == "typetype" then - t = typetype_to_nominal(node, node.tk, t, t) + if t is TypeDeclType then + t = typedecl_to_nominal(node, node.tk, t, t) end return t @@ -11928,13 +11929,13 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string ["record"] = { before = function(typ: Type) begin_scope() - add_var(nil, "@self", type_at(typ, a_typetype({ def = typ }))) + add_var(nil, "@self", type_at(typ, a_typedecl(typ))) for fname, ftype in fields_of(typ) do if ftype is TypeAliasType then resolve_nominal(ftype.alias_to) add_var(nil, fname, ftype) - elseif ftype.typename == "typetype" then + elseif ftype is TypeDeclType then add_var(nil, fname, ftype) end end @@ -12067,7 +12068,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local tv = typ as TypeVarType tv.typevar = t.typearg tv.constraint = t.constraint - elseif t.typename == "typetype" then + elseif t is TypeDeclType then if t.def.typename ~= "circular_require" then typ.found = t end @@ -12150,7 +12151,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string visit_type.cbs["string"] = default_type_visitor visit_type.cbs["tupletable"] = default_type_visitor - visit_type.cbs["typetype"] = default_type_visitor + visit_type.cbs["typedecl"] = default_type_visitor visit_type.cbs["typealias"] = default_type_visitor visit_type.cbs["array"] = default_type_visitor visit_type.cbs["map"] = default_type_visitor From 7031282a2428657aaab7a48a2479736700d95db5 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Sun, 7 Jan 2024 00:29:44 -0300 Subject: [PATCH 086/224] StringType --- spec/inference/table_literal_spec.lua | 2 +- tl.lua | 55 ++++++++-------- tl.tl | 93 ++++++++++++++------------- 3 files changed, 76 insertions(+), 74 deletions(-) diff --git a/spec/inference/table_literal_spec.lua b/spec/inference/table_literal_spec.lua index 48c7720ab..71d5d91c7 100644 --- a/spec/inference/table_literal_spec.lua +++ b/spec/inference/table_literal_spec.lua @@ -15,7 +15,7 @@ describe("bidirectional inference for table literals", function() } print(x) ]], { - { msg = "in record field: type: string 'who' is not a member of enum" }, + { msg = "in record field: type: string \"who\" is not a member of enum" }, })) it("directed inference produces correct results for incomplete records (regression test for #348)", util.check([[ diff --git a/tl.lua b/tl.lua index f249b1774..ae98ce651 100644 --- a/tl.lua +++ b/tl.lua @@ -1387,6 +1387,9 @@ local table_types = { + + + @@ -1687,7 +1690,7 @@ local function new_type(ps, i, typename) filename = ps.filename, y = token.y, x = token.x, - tk = token.tk, + }) end @@ -5572,7 +5575,7 @@ local function show_type_base(t, short, seen) return "string" else return t.typename .. - (t.tk and " " .. t.tk or "") + (t.literal and string.format(" %q", t.literal) or "") end elseif t.typename == "typevar" then return display_typevar(t.typevar) @@ -7097,12 +7100,12 @@ tl.type_check = function(ast, opts) end local function drop_constant_value(t) - if not t.tk then - return t + if t.typename == "string" and t.literal then + local ret = shallow_copy_table(t) + ret.literal = nil + return ret end - local ret = shallow_copy_table(t) - ret.tk = nil - return ret + return t end local function reserve_symbol_list_slot(node) @@ -7616,7 +7619,7 @@ tl.type_check = function(ast, opts) table.insert(stack, s) end else - if primitive[t.typename] and (flatten_constants or not t.tk) then + if primitive[t.typename] and (flatten_constants or (t.typename == "string" and not t.literal)) then if not types_seen[t.typename] then types_seen[t.typename] = true table.insert(ts, t) @@ -8065,11 +8068,11 @@ tl.type_check = function(ast, opts) }, ["string"] = { ["enum"] = function(a, b) - if not a.tk then + if not a.literal then return false, { Err(a, "string is not a %s", b) } end - if b.enumset[unquote(a.tk)] then + if b.enumset[a.literal] then return true end @@ -9299,8 +9302,8 @@ a.types[i], b.types[i]), } end errm, erra, errb = "wrong index type: got %s, expected %s", orig_b, a.keys - elseif bnode.kind == "string" then - local t, e = match_record_key(orig_a, anode, bnode.conststr) + elseif b.typename == "string" and b.literal then + local t, e = match_record_key(orig_a, anode, b.literal) if t then return t end @@ -9342,7 +9345,8 @@ a.types[i], b.types[i]), } else if not is_a(new, old) then if old.typename == "map" and new.fields then - if old.keys.typename == "string" then + local old_keys = old.keys + if old_keys.typename == "string" then for _, ftype in fields_of(new) do old.values = expand_type(where, old.values, ftype) end @@ -9371,15 +9375,13 @@ a.types[i], b.types[i]), } old.meta_fields = nil old.meta_fields = nil - edit_type(old, "map") assert(old.typename == "map") old.keys = STRING old.values = values elseif old.typename == "union" then edit_type(old, "union") - new.tk = nil - table.insert(old.types, new) + table.insert(old.types, drop_constant_value(new)) else return unite({ old, new }, true) end @@ -10164,8 +10166,7 @@ a.types[i], b.types[i]), } end else is_map = true - child.ktype.tk = nil - keys = expand_type(node, keys, child.ktype) + keys = expand_type(node, keys, drop_constant_value(child.ktype)) values = expand_type(node, values, uvtype) end end @@ -10804,8 +10805,7 @@ expand_type(node, values, elements) }) if not expected then expected = infer_at(node, got) - module_type = resolve_tuple_and_nominal(expected) - module_type.tk = nil + module_type = drop_constant_value(resolve_tuple_and_nominal(expected)) st[2]["@return"] = { t = expected } end local expected_t = expected.tuple @@ -11446,9 +11446,8 @@ expand_type(node, values, elements) }) x = node.e2.x, tk = node.e2.tk, kind = "string", - conststr = node.e2.tk, } - local btype = type_at(node.e2, a_type("string", { tk = '"' .. node.e2.tk .. '"' })) + local btype = type_at(node.e2, a_type("string", { literal = node.e2.tk })) local t = type_check_index(node.e1, bnode, orig_a, btype) if t.needs_compat and opts.gen_compat ~= "off" then @@ -11556,7 +11555,7 @@ expand_type(node, values, elements) }) else t = resolve_tuple(a) end - t.tk = nil + t = drop_constant_value(t) end if t then @@ -11571,7 +11570,7 @@ expand_type(node, values, elements) }) if ra.typename == "enum" and rb.typename == "string" then - if not (rb.tk and ra.enumset[unquote(rb.tk)]) then + if not (rb.literal and ra.enumset[rb.literal]) then return invalid_at(node, "%s is not a member of %s", b, a) end elseif ra.typename == "tupletable" and rb.typename == "tupletable" and #ra.types ~= #rb.types then @@ -11809,12 +11808,13 @@ expand_type(node, values, elements) }) local function after_literal(node) node.known = FACT_TRUTHY - return type_at(node, a_type(node.kind, { tk = node.tk })) + return type_at(node, a_type(node.kind, {})) end visit_node.cbs["string"] = { after = function(node, _children) local t = after_literal(node) + t.literal = node.conststr local expected = node.expected if expected and expected.typename == "enum" and is_a(t, expected) then @@ -11983,12 +11983,13 @@ expand_type(node, values, elements) }) local record_name = typ.declname if record_name then local selfarg = fargs[1] - if selfarg.tk ~= record_name or (typ.typeargs and not selfarg.typevals) then + if selfarg.names[1] ~= record_name or (typ.typeargs and not selfarg.typevals) then ftype.is_method = false selfarg.is_self = false elseif typ.typeargs then for j = 1, #typ.typeargs do - if (not selfarg.typevals[j]) or selfarg.typevals[j].tk ~= typ.typeargs[j].typearg then + local tv = selfarg.typevals[j] + if not (tv and tv.typename == "typevar" and tv.typevar == typ.typeargs[j].typearg) then ftype.is_method = false selfarg.is_self = false break diff --git a/tl.tl b/tl.tl index ed58a1773..6bea87745 100644 --- a/tl.tl +++ b/tl.tl @@ -1071,27 +1071,30 @@ local table_types : {TypeName:boolean} = { local interface Type where self.typename + typename: TypeName -- discriminator + typeid: integer -- unique identifier + y: integer x: integer - filename: string - typename: TypeName - tk: string - - typeid: integer - yend: integer xend: integer + filename: string inferred_at: Where -- Lua compatibilty needs_compat: boolean - -- arguments: optional arity - opt: boolean + -- markers for arguments: + opt: boolean -- optional arity + is_self: boolean -- used as self +end - -- function argument - is_self: boolean +local record StringType + is Type + where self.typename == "string" + + literal: string end local record TypeDeclType @@ -1687,7 +1690,7 @@ local function new_type(ps: ParseState, i: integer, typename: TypeName): Type filename = ps.filename, y = token.y, x = token.x, - tk = token.tk + --tk = token.tk }) end @@ -5567,12 +5570,12 @@ local function show_type_base(t: Type, short: boolean, seen: {Type:string}): str or t.typename == "boolean" or t.typename == "thread" then return t.typename - elseif t.typename == "string" then + elseif t is StringType then if short then return "string" else return t.typename .. - (t.tk and " " .. t.tk or "") + (t.literal and string.format(" %q", t.literal) or "") end elseif t is TypeVarType then return display_typevar(t.typevar) @@ -6751,7 +6754,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local rt = find_var_type(t.typevar) if not rt then return nil - elseif rt.typename == "string" then + elseif rt is StringType then -- tk is not propagated return STRING end @@ -7097,12 +7100,12 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end local function drop_constant_value(t: Type): Type - if not t.tk then - return t + if t is StringType and t.literal then + local ret = shallow_copy_table(t) + ret.literal = nil + return ret end - local ret = shallow_copy_table(t) - ret.tk = nil - return ret + return t end local function reserve_symbol_list_slot(node: Node) @@ -7616,7 +7619,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string table.insert(stack, s) end else - if primitive[t.typename] and (flatten_constants or not t.tk) then + if primitive[t.typename] and (flatten_constants or (t is StringType and not t.literal)) then if not types_seen[t.typename] then types_seen[t.typename] = true table.insert(ts, t) @@ -8064,12 +8067,12 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string ["string"] = compare_true, }, ["string"] = { - ["enum"] = function(a: Type, b: EnumType): boolean, {Error} - if not a.tk then + ["enum"] = function(a: StringType, b: EnumType): boolean, {Error} + if not a.literal then return false, { Err(a, "string is not a %s", b) } end - if b.enumset[unquote(a.tk)] then + if b.enumset[a.literal] then return true end @@ -8917,7 +8920,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string tbl = resolve_tuple_and_nominal(tbl) - if tbl.typename == "string" or tbl is EnumType then + if tbl is StringType or tbl is EnumType then tbl = find_var_type("string") -- simulate string metatable end @@ -9299,8 +9302,8 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end errm, erra, errb = "wrong index type: got %s, expected %s", orig_b, a.keys - elseif bnode.kind == "string" then - local t, e = match_record_key(orig_a, anode, bnode.conststr) + elseif b is StringType and b.literal then + local t, e = match_record_key(orig_a, anode, b.literal) if t then return t end @@ -9342,7 +9345,8 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string else if not is_a(new, old) then if old is MapType and new is RecordLikeType then - if old.keys.typename == "string" then + local old_keys = old.keys + if old_keys is StringType then for _, ftype in fields_of(new) do old.values = expand_type(where, old.values, ftype) end @@ -9370,7 +9374,6 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string old.field_order = nil old.meta_fields = nil old.meta_fields = nil - -- FIXME what about meta_fields edit_type(old, "map") assert(old is MapType) @@ -9378,8 +9381,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string old.values = values elseif old is UnionType then edit_type(old, "union") - new.tk = nil - table.insert(old.types, new) + table.insert(old.types, drop_constant_value(new)) else return unite({ old, new }, true) end @@ -10164,8 +10166,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end else is_map = true - child.ktype.tk = nil - keys = expand_type(node, keys, child.ktype) + keys = expand_type(node, keys, drop_constant_value(child.ktype)) values = expand_type(node, values, uvtype) end end @@ -10189,7 +10190,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string } as RecordType) -- TODO adopt logic from is_array below when we accept tupletable as an interface elseif is_record and is_map then - if keys.typename == "string" then + if keys is StringType then for _, fname in ipairs(field_order) do values = expand_type(node, values, fields[fname]) end @@ -10804,8 +10805,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if not expected then -- if at the toplevel expected = infer_at(node, got) - module_type = resolve_tuple_and_nominal(expected) - module_type.tk = nil + module_type = drop_constant_value(resolve_tuple_and_nominal(expected)) st[2]["@return"] = { t = expected } end local expected_t = expected.tuple @@ -11446,9 +11446,8 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string x = node.e2.x, tk = node.e2.tk, kind = "string", - conststr = node.e2.tk, } - local btype = type_at(node.e2, a_type("string", { tk = '"' ..node.e2.tk .. '"' })) + local btype = type_at(node.e2, a_type("string", { literal = node.e2.tk } as StringType)) local t = type_check_index(node.e1, bnode, orig_a, btype) if t.needs_compat and opts.gen_compat ~= "off" then @@ -11527,8 +11526,8 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string node.known = nil t = a - elseif ((ra is EnumType and rb.typename == "string" and is_a(rb, ra)) - or (ra.typename == "string" and rb is EnumType and is_a(ra, rb))) then + elseif ((ra is EnumType and rb is StringType and is_a(rb, ra)) + or (ra is StringType and rb is EnumType and is_a(ra, rb))) then node.known = nil t = (ra is EnumType and ra or rb) @@ -11556,7 +11555,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string else t = resolve_tuple(a) end - t.tk = nil + t = drop_constant_value(t) end if t then @@ -11570,8 +11569,8 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string -- check_metamethod(node, binop_to_metamethod[node.op.op], ra, rb) -- end - if ra is EnumType and rb.typename == "string" then - if not (rb.tk and ra.enumset[unquote(rb.tk)]) then + if ra is EnumType and rb is StringType then + if not (rb.literal and ra.enumset[rb.literal]) then return invalid_at(node, "%s is not a member of %s", b, a) end elseif ra is TupleTableType and rb is TupleTableType and #ra.types ~= #rb.types then @@ -11809,12 +11808,13 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local function after_literal(node: Node): Type node.known = FACT_TRUTHY - return type_at(node, a_type(node.kind as TypeName, { tk = node.tk })) + return type_at(node, a_type(node.kind as TypeName, {})) end visit_node.cbs["string"] = { after = function(node: Node, _children: {Type}): Type - local t = after_literal(node) + local t = after_literal(node) as StringType + t.literal = node.conststr local expected = node.expected if expected and expected is EnumType and is_a(t, expected) then @@ -11983,12 +11983,13 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local record_name = typ.declname if record_name then local selfarg = fargs[1] as NominalType - if selfarg.tk ~= record_name or (typ.typeargs and not selfarg.typevals) then + if selfarg.names[1] ~= record_name or (typ.typeargs and not selfarg.typevals) then ftype.is_method = false selfarg.is_self = false elseif typ.typeargs then for j=1,#typ.typeargs do - if (not selfarg.typevals[j]) or selfarg.typevals[j].tk ~= typ.typeargs[j].typearg then + local tv = selfarg.typevals[j] + if not (tv and tv is TypeVarType and tv.typevar == typ.typeargs[j].typearg) then ftype.is_method = false selfarg.is_self = false break From 69f7a4b61c560a7b2d4334102d3a2b9fe71d1109 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Sun, 7 Jan 2024 01:12:11 -0300 Subject: [PATCH 087/224] use interface subtyping for Fact --- tl.lua | 30 +----------------------------- tl.tl | 54 +++++++++++++----------------------------------------- 2 files changed, 14 insertions(+), 70 deletions(-) diff --git a/tl.lua b/tl.lua index ae98ce651..57b4828b1 100644 --- a/tl.lua +++ b/tl.lua @@ -1361,22 +1361,6 @@ local table_types = { - - - - - - - - - - - - - - - - @@ -1401,8 +1385,6 @@ local TruthyFact = {} - - local NotFact = {} @@ -1412,8 +1394,6 @@ local NotFact = {} - - local AndFact = {} @@ -1424,8 +1404,6 @@ local AndFact = {} - - local OrFact = {} @@ -1436,8 +1414,6 @@ local OrFact = {} - - local EqFact = {} @@ -1448,8 +1424,6 @@ local EqFact = {} - - local IsFact = {} @@ -1469,8 +1443,6 @@ local IsFact = {} - - @@ -9807,7 +9779,7 @@ a.types[i], b.types[i]), } return eval_not(f.f1) elseif f.fact == "and" then return and_facts(eval_fact(f.f1), eval_fact(f.f2)) - else + elseif f.fact == "or" then return or_facts(eval_fact(f.f1), eval_fact(f.f2)) end end diff --git a/tl.tl b/tl.tl index 6bea87745..330e78666 100644 --- a/tl.tl +++ b/tl.tl @@ -1371,55 +1371,33 @@ local enum FactType "truthy" -- expression that is either truthy or a runtime error end ---local record Fact --- fact: FactType --- where: Node --- --- -- is --- var: string --- typ: Type --- --- -- not, and, or --- f1: Fact --- f2: Fact --- --- metamethod __call: function(Fact, Fact): Fact ---end - -local type Fact - = TruthyFact - | NotFact - | AndFact - | OrFact - | IsFact - | EqFact - -local record TruthyFact - where self.fact == "truthy" +local interface Fact + where self.fact fact: FactType where: Where +end + +local record TruthyFact + is Fact + where self.fact == "truthy" metamethod __call: function(Fact, Fact): TruthyFact end local record NotFact + is Fact where self.fact == "not" - fact: FactType - where: Where - f1: Fact metamethod __call: function(Fact, Fact): NotFact end local record AndFact + is Fact where self.fact == "and" - fact: FactType - where: Where - f1: Fact f2: Fact @@ -1427,11 +1405,9 @@ local record AndFact end local record OrFact + is Fact where self.fact == "or" - fact: FactType - where: Where - f1: Fact f2: Fact @@ -1439,11 +1415,9 @@ local record OrFact end local record EqFact + is Fact where self.fact == "==" - fact: FactType - where: Where - var: string typ: Type @@ -1451,11 +1425,9 @@ local record EqFact end local record IsFact + is Fact where self.fact == "is" - fact: FactType - where: Where - var: string typ: Type @@ -9807,7 +9779,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return eval_not(f.f1) elseif f is AndFact then return and_facts(eval_fact(f.f1), eval_fact(f.f2)) - else -- f is OrFact + elseif f is OrFact then return or_facts(eval_fact(f.f1), eval_fact(f.f2)) end end From 55be2434a4af8f244ed2bfdd662ec2aa69b9240e Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Sun, 7 Jan 2024 01:43:37 -0300 Subject: [PATCH 088/224] remove is_self from Type --- tl.lua | 40 +++++++++++++++++----------------------- tl.tl | 48 +++++++++++++++++++++--------------------------- 2 files changed, 38 insertions(+), 50 deletions(-) diff --git a/tl.lua b/tl.lua index 57b4828b1..0e414c913 100644 --- a/tl.lua +++ b/tl.lua @@ -1373,8 +1373,6 @@ local table_types = { - - @@ -1996,15 +1994,12 @@ local function parse_function_type(ps, i) i, typ.typeargs = parse_anglebracket_list(ps, i, parse_typearg) end if ps.tokens[i].tk == "(" then - i, typ.args = parse_argument_type_list(ps, i) + i, typ.args, typ.is_method = parse_argument_type_list(ps, i) i, typ.rets = parse_return_types(ps, i) else typ.args = a_vararg({ ANY }) typ.rets = a_vararg({ ANY }) end - if typ.args.tuple[1] and typ.args.tuple[1].is_self then - typ.is_method = true - end return i, typ end @@ -2669,11 +2664,13 @@ end + local function parse_argument_type(ps, i) + local opt = false local is_va = false + local is_self = false local argument_name = nil - local opt = false if ps.tokens[i].kind == "identifier" then argument_name = ps.tokens[i].tk if ps.tokens[i + 1].tk == "?" then @@ -2708,29 +2705,28 @@ local function parse_argument_type(ps, i) end if argument_name == "self" then - typ = shallow_copy_new_type(typ) - typ.is_self = true + is_self = true end end - return i, { i = i, type = typ, is_va = is_va }, 0 + return i, { i = i, type = typ, is_va = is_va, is_self = is_self }, 0 end parse_argument_type_list = function(ps, i) - local tvs = {} - i = parse_bracket_list(ps, i, tvs, "(", ")", "sep", parse_argument_type) + local ars = {} + i = parse_bracket_list(ps, i, ars, "(", ")", "sep", parse_argument_type) local t, list = new_tuple(ps, i) - local n = #tvs - for l, tv in ipairs(tvs) do - list[l] = tv.type - if tv.is_va and l < n then - fail(ps, tv.i, "'...' can only be last argument") + local n = #ars + for l, ar in ipairs(ars) do + list[l] = ar.type + if ar.is_va and l < n then + fail(ps, ar.i, "'...' can only be last argument") end end - if tvs[n] and tvs[n].is_va then + if n > 0 and ars[n].is_va then t.is_va = true end - return i, t + return i, t, (n > 0 and ars[1].is_self) end local function parse_identifier(ps, i) @@ -7915,7 +7911,7 @@ tl.type_check = function(ast, opts) local argdelta = a.is_method and 1 or 0 local naargs, nbargs = #a.args.tuple, #b.args.tuple if naargs ~= nbargs then - if a.is_method ~= b.is_method then + if (not not a.is_method) ~= (not not b.is_method) then return false, { Err(a, "different number of input arguments: method and non-method are not the same type") } end return false, { Err(a, "different number of input arguments: got " .. naargs - argdelta .. ", expected " .. nbargs - argdelta) } @@ -11951,19 +11947,17 @@ expand_type(node, values, elements) }) if ftype.is_method then local fargs = ftype.args.tuple - if fargs[1] and fargs[1].is_self then + if fargs[1] then local record_name = typ.declname if record_name then local selfarg = fargs[1] if selfarg.names[1] ~= record_name or (typ.typeargs and not selfarg.typevals) then ftype.is_method = false - selfarg.is_self = false elseif typ.typeargs then for j = 1, #typ.typeargs do local tv = selfarg.typevals[j] if not (tv and tv.typename == "typevar" and tv.typevar == typ.typeargs[j].typearg) then ftype.is_method = false - selfarg.is_self = false break end end diff --git a/tl.tl b/tl.tl index 330e78666..01f775cd8 100644 --- a/tl.tl +++ b/tl.tl @@ -1085,9 +1085,7 @@ local interface Type -- Lua compatibilty needs_compat: boolean - -- markers for arguments: opt: boolean -- optional arity - is_self: boolean -- used as self end local record StringType @@ -1587,7 +1585,7 @@ local parse_expression: function(ParseState, integer): integer, Node, integer local parse_expression_and_tk: function(ps: ParseState, i: integer, tk: string): integer, Node local parse_statements: function(ParseState, integer, ? boolean): integer, Node local parse_argument_list: function(ParseState, integer): integer, Node -local parse_argument_type_list: function(ParseState, integer): integer, TupleType +local parse_argument_type_list: function(ParseState, integer): integer, TupleType, boolean local parse_type: function(ParseState, integer): integer, Type, integer local parse_newtype: function(ps: ParseState, i: integer): integer, Node local parse_interface_name: function(ps: ParseState, i: integer): integer, Type, integer @@ -1996,15 +1994,12 @@ local function parse_function_type(ps: ParseState, i: integer): integer, Functio i, typ.typeargs = parse_anglebracket_list(ps, i, parse_typearg) end if ps.tokens[i].tk == "(" then - i, typ.args = parse_argument_type_list(ps, i) + i, typ.args, typ.is_method = parse_argument_type_list(ps, i) i, typ.rets = parse_return_types(ps, i) else typ.args = a_vararg { ANY } typ.rets = a_vararg { ANY } end - if typ.args.tuple[1] and typ.args.tuple[1].is_self then - typ.is_method = true - end return i, typ end @@ -2663,17 +2658,19 @@ parse_argument_list = function(ps: ParseState, i: integer): integer, Node return i, node end -local record TypeAndVararg +local record ArgumentInfo i: integer type: Type is_va: boolean + is_self: boolean end -local function parse_argument_type(ps: ParseState, i: integer): integer, TypeAndVararg, integer +local function parse_argument_type(ps: ParseState, i: integer): integer, ArgumentInfo, integer + local opt = false local is_va = false + local is_self = false local argument_name: string = nil - local opt = false if ps.tokens[i].kind == "identifier" then argument_name = ps.tokens[i].tk if ps.tokens[i + 1].tk == "?" then @@ -2708,29 +2705,28 @@ local function parse_argument_type(ps: ParseState, i: integer): integer, TypeAnd end if argument_name == "self" then - typ = shallow_copy_new_type(typ) - typ.is_self = true + is_self = true end end - return i, { i = i, type = typ, is_va = is_va }, 0 + return i, { i = i, type = typ, is_va = is_va, is_self = is_self }, 0 end -parse_argument_type_list = function(ps: ParseState, i: integer): integer, Type - local tvs: {TypeAndVararg} = {} - i = parse_bracket_list(ps, i, tvs, "(", ")", "sep", parse_argument_type) +parse_argument_type_list = function(ps: ParseState, i: integer): integer, Type, boolean + local ars: {ArgumentInfo} = {} + i = parse_bracket_list(ps, i, ars, "(", ")", "sep", parse_argument_type) local t, list = new_tuple(ps, i) - local n = #tvs - for l, tv in ipairs(tvs) do - list[l] = tv.type - if tv.is_va and l < n then - fail(ps, tv.i, "'...' can only be last argument") + local n = #ars + for l, ar in ipairs(ars) do + list[l] = ar.type + if ar.is_va and l < n then + fail(ps, ar.i, "'...' can only be last argument") end end - if tvs[n] and tvs[n].is_va then + if n > 0 and ars[n].is_va then t.is_va = true end - return i, t + return i, t, (n > 0 and ars[1].is_self) end local function parse_identifier(ps: ParseState, i: integer): integer, Node, integer @@ -7915,7 +7911,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local argdelta = a.is_method and 1 or 0 local naargs, nbargs = #a.args.tuple, #b.args.tuple if naargs ~= nbargs then - if a.is_method ~= b.is_method then + if (not not a.is_method) ~= (not not b.is_method) then return false, { Err(a, "different number of input arguments: method and non-method are not the same type") } end return false, { Err(a, "different number of input arguments: got " .. naargs - argdelta .. ", expected " .. nbargs - argdelta) } @@ -11951,19 +11947,17 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if ftype.is_method then local fargs = ftype.args.tuple - if fargs[1] and fargs[1].is_self then + if fargs[1] then local record_name = typ.declname if record_name then local selfarg = fargs[1] as NominalType if selfarg.names[1] ~= record_name or (typ.typeargs and not selfarg.typevals) then ftype.is_method = false - selfarg.is_self = false elseif typ.typeargs then for j=1,#typ.typeargs do local tv = selfarg.typevals[j] if not (tv and tv is TypeVarType and tv.typevar == typ.typeargs[j].typearg) then ftype.is_method = false - selfarg.is_self = false break end end From b2fe2c957ad041c4ffda351943255436707d7c30 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Sun, 7 Jan 2024 03:24:44 -0300 Subject: [PATCH 089/224] fix: may add meta_fields when expanding interface --- tl.lua | 6 +++++- tl.tl | 6 +++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/tl.lua b/tl.lua index 0e414c913..fee9c6a44 100644 --- a/tl.lua +++ b/tl.lua @@ -11868,7 +11868,11 @@ expand_type(node, values, elements) }) local ri = resolve_nominal(iface) assert(ri.typename == "interface") add_interface_fields("field", t.fields, t.field_order, ri, iface) - add_interface_fields("metamethod", t.meta_fields, t.meta_field_order, ri, iface, "meta") + if ri.meta_fields then + t.meta_fields = t.meta_fields or {} + t.meta_field_order = t.meta_field_order or {} + add_interface_fields("metamethod", t.meta_fields, t.meta_field_order, ri, iface, "meta") + end else if not t.elements then t.elements = iface diff --git a/tl.tl b/tl.tl index 01f775cd8..5073e4e03 100644 --- a/tl.tl +++ b/tl.tl @@ -11868,7 +11868,11 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local ri = resolve_nominal(iface) assert(ri is InterfaceType) add_interface_fields("field", t.fields, t.field_order, ri, iface) - add_interface_fields("metamethod", t.meta_fields, t.meta_field_order, ri, iface, "meta") + if ri.meta_fields then + t.meta_fields = t.meta_fields or {} + t.meta_field_order = t.meta_field_order or {} + add_interface_fields("metamethod", t.meta_fields, t.meta_field_order, ri, iface, "meta") + end else if not t.elements then t.elements = iface From fd7538f6615670899a100ac8a12bab968a29052a Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Sun, 7 Jan 2024 03:39:26 -0300 Subject: [PATCH 090/224] compute min_arity during parsing, not type checking --- tl.lua | 676 ++++++++++++++++++++++++++++---------------------------- tl.tl | 680 +++++++++++++++++++++++++++++---------------------------- 2 files changed, 690 insertions(+), 666 deletions(-) diff --git a/tl.lua b/tl.lua index fee9c6a44..67dee5ec4 100644 --- a/tl.lua +++ b/tl.lua @@ -1369,9 +1369,6 @@ local table_types = { - - - @@ -1559,6 +1556,7 @@ local Node = {ExpectedContext = {}, } + local function is_number_type(t) @@ -1678,10 +1676,6 @@ end -local function c_tuple(t) - return a_type("tuple", { tuple = t }) -end - @@ -1691,9 +1685,70 @@ end local function a_function(t) + assert(t.min_arity) return a_type("function", t) end + + + + + + +local function OPT(t) + return { opttype = t } +end + + + + + + + +local function va_args(args) + args.is_va = true + return args +end + + + + + + + + + +local function a_fn(f) + local args_t = a_type("tuple", { tuple = {} }) + local tup = args_t.tuple + args_t.is_va = f.args.is_va + local min_arity = f.args.is_va and -1 or 0 + for _, a in ipairs(f.args) do + if a.opttype then + table.insert(tup, a.opttype) + else + table.insert(tup, a) + min_arity = min_arity + 1 + end + end + + local rets_t = a_type("tuple", { tuple = {} }) + tup = rets_t.tuple + rets_t.is_va = f.rets.is_va + for _, a in ipairs(f.rets) do + assert(a.typename) + table.insert(tup, a) + end + + return a_type("function", { + args = args_t, + rets = rets_t, + min_arity = min_arity, + needs_compat = f.needs_compat, + typeargs = f.typeargs, + }) +end + local function a_vararg(t) local typ = a_type("tuple", { tuple = t }) typ.is_va = true @@ -1994,7 +2049,7 @@ local function parse_function_type(ps, i) i, typ.typeargs = parse_anglebracket_list(ps, i, parse_typearg) end if ps.tokens[i].tk == "(" then - i, typ.args, typ.is_method = parse_argument_type_list(ps, i) + i, typ.args, typ.is_method, typ.min_arity = parse_argument_type_list(ps, i) i, typ.rets = parse_return_types(ps, i) else typ.args = a_vararg({ ANY }) @@ -2014,19 +2069,6 @@ local simple_types = { ["integer"] = INTEGER, } -local memoize_opt_types = {} - -local function OPT(t) - if memoize_opt_types[t] then - return memoize_opt_types[t] - end - - local ot = shallow_copy_new_type(t) - ot.opt = true - memoize_opt_types[t] = ot - return ot -end - local function parse_simple_type_or_nominal(ps, i) local tk = ps.tokens[i].tk local st = simple_types[tk] @@ -2194,7 +2236,7 @@ local function parse_function_args_rets_body(ps, i, node) if ps.tokens[i].tk == "<" then i, node.typeargs = parse_anglebracket_list(ps, i, parse_typearg) end - i, node.args = parse_argument_list(ps, i) + i, node.args, node.min_arity = parse_argument_list(ps, i) i, node.rets = parse_return_types(ps, i) i, node.body = parse_statements(ps, i) end_at(node, ps.tokens[i]) @@ -2643,6 +2685,7 @@ parse_argument_list = function(ps, i) local node = new_node(ps.tokens, i, "argument_list") i, node = parse_bracket_list(ps, i, node, "(", ")", "sep", parse_argument) local opts = false + local min_arity = 0 for a, fnarg in ipairs(node) do if fnarg.tk == "..." then if a ~= #node then @@ -2653,9 +2696,11 @@ parse_argument_list = function(ps, i) opts = true elseif opts then return fail(ps, i, "non-optional arguments cannot follow optional arguments") + else + min_arity = min_arity + 1 end end - return i, node + return i, node, min_arity end @@ -2665,6 +2710,7 @@ end + local function parse_argument_type(ps, i) local opt = false local is_va = false @@ -2700,16 +2746,12 @@ local function parse_argument_type(ps, i) is_va = true end - if opt then - typ = OPT(typ) - end - if argument_name == "self" then is_self = true end end - return i, { i = i, type = typ, is_va = is_va, is_self = is_self }, 0 + return i, { i = i, type = typ, is_va = is_va, is_self = is_self, opt = opt or is_va }, 0 end parse_argument_type_list = function(ps, i) @@ -2717,16 +2759,20 @@ parse_argument_type_list = function(ps, i) i = parse_bracket_list(ps, i, ars, "(", ")", "sep", parse_argument_type) local t, list = new_tuple(ps, i) local n = #ars + local min_arity = 0 for l, ar in ipairs(ars) do list[l] = ar.type if ar.is_va and l < n then fail(ps, ar.i, "'...' can only be last argument") end + if not ar.opt then + min_arity = min_arity + 1 + end end if n > 0 and ars[n].is_va then t.is_va = true end - return i, t, (n > 0 and ars[1].is_self) + return i, t, (n > 0 and ars[1].is_self), min_arity end local function parse_identifier(ps, i) @@ -2783,6 +2829,7 @@ local function parse_function(ps, i, fk) i = parse_function_args_rets_body(ps, i, fn) if fn.is_method then table.insert(fn.args, 1, { x = selfx, y = selfy, tk = "self", kind = "identifier", is_self = true }) + fn.min_arity = fn.min_arity + 1 end if not fn.name then @@ -3066,7 +3113,7 @@ local function parse_macroexp(ps, istart, iargs) local node = new_node(ps.tokens, istart, "macroexp") local i - i, node.args = parse_argument_list(ps, iargs) + i, node.args, node.min_arity = parse_argument_list(ps, iargs) i, node.rets = parse_return_types(ps, i) i = verify_tk(ps, i, "return") i, node.exp = parse_expression(ps, i) @@ -3085,6 +3132,7 @@ local function parse_where_clause(ps, i) node.args[1] = new_node(ps.tokens, i, "argument") node.args[1].tk = "self" node.args[1].argtype = selftype + node.min_arity = 1 node.rets = new_tuple(ps, i) node.rets.tuple[1] = BOOLEAN i, node.exp = parse_expression(ps, i) @@ -3167,6 +3215,7 @@ parse_record_body = function(ps, i, def, node) local typ = new_type(ps, wstart, "function") typ.is_method = true + typ.min_arity = 1 typ.args = a_type("tuple", { tuple = { a_type("nominal", { y = typ.y, @@ -5155,10 +5204,10 @@ local INVALID = a_type("invalid", {}) local UNKNOWN = a_type("unknown", {}) local CIRCULAR_REQUIRE = a_type("circular_require", {}) -local FUNCTION = a_function({ args = a_vararg({ ANY }), rets = a_vararg({ ANY }) }) +local FUNCTION = a_fn({ args = va_args({ ANY }), rets = va_args({ ANY }) }) local NOMINAL_FILE = a_type("nominal", { names = { "FILE" } }) -local XPCALL_MSGH_FUNCTION = a_function({ args = a_type("tuple", { tuple = { ANY } }), rets = a_type("tuple", { tuple = {} }) }) +local XPCALL_MSGH_FUNCTION = a_fn({ args = { ANY }, rets = {} }) local USERDATA = ANY @@ -5518,7 +5567,7 @@ local function show_type_base(t, short, seen) for i, v in ipairs(t.args.tuple) do if not t.is_method or i > 1 then table.insert(args, ((i == #t.args.tuple and t.args.is_va) and "...: " or - v.opt and "? " or + (i > t.min_arity) and "? " or "") .. show(v)) end end @@ -5824,7 +5873,7 @@ local function init_globals(lax) return t end - local function a_gfunction(n, f, typename) + local function a_generic(n, f) local typevars = {} local typeargs = {} local c = string.byte("A") - 1 @@ -5835,12 +5884,18 @@ local function init_globals(lax) typeargs[i] = a_type("typearg", { typearg = name }) end local t = f(_tl_table_unpack(typevars)) - t.typeargs = typeargs - return a_type(typename or "function", t) + if t.typename == "function" or t.typename == "record" then + t.typeargs = typeargs + end + return t + end + + local function a_gfunction(n, f) + return a_generic(n, function(...) return a_fn(f(...)) end) end local function a_grecord(n, f) - local t = a_gfunction(n, f, "record") + local t = a_generic(n, f) t.field_order = sorted_keys(t.fields) return t end @@ -5861,12 +5916,16 @@ local function init_globals(lax) + local function id(x) + return x + end + local file_reader_poly_types = { - { ctor = a_vararg, args = { a_type("union", { types = { NUMBER, an_enum({ "*a", "a", "*l", "l", "*L", "L" }) } }) }, rets = { STRING } }, - { ctor = c_tuple, args = { an_enum({ "*n", "n" }) }, rets = { NUMBER, STRING } }, - { ctor = a_vararg, args = { a_type("union", { types = { NUMBER, an_enum({ "*a", "a", "*l", "l", "*L", "L", "*n", "n" }) } }) }, rets = { a_type("union", { types = { STRING, NUMBER } }) } }, - { ctor = a_vararg, args = { a_type("union", { types = { NUMBER, STRING } }) }, rets = { STRING } }, - { ctor = a_vararg, args = {}, rets = { STRING } }, + { ctor = va_args, args = { a_type("union", { types = { NUMBER, an_enum({ "*a", "a", "*l", "l", "*L", "L" }) } }) }, rets = { STRING } }, + { ctor = id, args = { an_enum({ "*n", "n" }) }, rets = { NUMBER, STRING } }, + { ctor = va_args, args = { a_type("union", { types = { NUMBER, an_enum({ "*a", "a", "*l", "l", "*L", "L", "*n", "n" }) } }) }, rets = { a_type("union", { types = { STRING, NUMBER } }) } }, + { ctor = va_args, args = { a_type("union", { types = { NUMBER, STRING } }) }, rets = { STRING } }, + { ctor = va_args, args = {}, rets = { STRING } }, } local function a_file_reader(fn) @@ -5879,7 +5938,7 @@ local function init_globals(lax) return t end - local LOAD_FUNCTION = a_function({ args = a_type("tuple", { tuple = {} }), rets = a_type("tuple", { tuple = { STRING } }) }) + local LOAD_FUNCTION = a_fn({ args = {}, rets = { STRING } }) local OS_DATE_TABLE = a_record({ fields = { @@ -5916,12 +5975,12 @@ local function init_globals(lax) local DEBUG_HOOK_EVENT = an_enum({ "call", "tail call", "return", "line", "count" }) - local DEBUG_HOOK_FUNCTION = a_function({ - args = a_type("tuple", { tuple = { DEBUG_HOOK_EVENT, INTEGER } }), - rets = a_type("tuple", { tuple = {} }), + local DEBUG_HOOK_FUNCTION = a_fn({ + args = { DEBUG_HOOK_EVENT, INTEGER }, + rets = {}, }) - local TABLE_SORT_FUNCTION = a_gfunction(1, function(a) return { args = a_type("tuple", { tuple = { a, a } }), rets = a_type("tuple", { tuple = { BOOLEAN } }) } end) + local TABLE_SORT_FUNCTION = a_gfunction(1, function(a) return { args = { a, a }, rets = { BOOLEAN } } end) local metatable_nominals = {} @@ -5935,71 +5994,71 @@ local function init_globals(lax) ["..."] = a_vararg({ STRING }), ["any"] = a_type("typedecl", { def = ANY }), ["arg"] = a_type("array", { elements = STRING }), - ["assert"] = a_gfunction(2, function(a, b) return { args = a_type("tuple", { tuple = { a, OPT(b) } }), rets = a_type("tuple", { tuple = { a } }) } end), + ["assert"] = a_gfunction(2, function(a, b) return { args = { a, OPT(b) }, rets = { a } } end), ["collectgarbage"] = a_type("poly", { types = { - a_function({ args = a_type("tuple", { tuple = { an_enum({ "collect", "count", "stop", "restart" }) } }), rets = a_type("tuple", { tuple = { NUMBER } }) }), - a_function({ args = a_type("tuple", { tuple = { an_enum({ "step", "setpause", "setstepmul" }), NUMBER } }), rets = a_type("tuple", { tuple = { NUMBER } }) }), - a_function({ args = a_type("tuple", { tuple = { an_enum({ "isrunning" }) } }), rets = a_type("tuple", { tuple = { BOOLEAN } }) }), - a_function({ args = a_type("tuple", { tuple = { STRING, OPT(NUMBER) } }), rets = a_type("tuple", { tuple = { a_type("union", { types = { BOOLEAN, NUMBER } }) } }) }), + a_fn({ args = { an_enum({ "collect", "count", "stop", "restart" }) }, rets = { NUMBER } }), + a_fn({ args = { an_enum({ "step", "setpause", "setstepmul" }), NUMBER }, rets = { NUMBER } }), + a_fn({ args = { an_enum({ "isrunning" }) }, rets = { BOOLEAN } }), + a_fn({ args = { STRING, OPT(NUMBER) }, rets = { a_type("union", { types = { BOOLEAN, NUMBER } }) } }), } }), - ["dofile"] = a_function({ args = a_type("tuple", { tuple = { OPT(STRING) } }), rets = a_vararg({ ANY }) }), - ["error"] = a_function({ args = a_type("tuple", { tuple = { ANY, OPT(NUMBER) } }), rets = a_type("tuple", { tuple = {} }) }), - ["getmetatable"] = a_gfunction(1, function(a) return { args = a_type("tuple", { tuple = { a } }), rets = a_type("tuple", { tuple = { METATABLE(a) } }) } end), - ["ipairs"] = a_gfunction(1, function(a) return { args = a_type("tuple", { tuple = { a_type("array", { elements = a }) } }), rets = a_type("tuple", { tuple = { - a_function({ args = a_type("tuple", { tuple = {} }), rets = a_type("tuple", { tuple = { INTEGER, a } }) }), -} }), } end), - ["load"] = a_function({ args = a_type("tuple", { tuple = { a_type("union", { types = { STRING, LOAD_FUNCTION } }), OPT(STRING), OPT(STRING), OPT(TABLE) } }), rets = a_type("tuple", { tuple = { FUNCTION, STRING } }) }), - ["loadfile"] = a_function({ args = a_type("tuple", { tuple = { OPT(STRING), OPT(STRING), OPT(TABLE) } }), rets = a_type("tuple", { tuple = { FUNCTION, STRING } }) }), + ["dofile"] = a_fn({ args = { OPT(STRING) }, rets = va_args({ ANY }) }), + ["error"] = a_fn({ args = { ANY, OPT(NUMBER) }, rets = {} }), + ["getmetatable"] = a_gfunction(1, function(a) return { args = { a }, rets = { METATABLE(a) } } end), + ["ipairs"] = a_gfunction(1, function(a) return { args = { a_type("array", { elements = a }) }, rets = { + a_fn({ args = {}, rets = { INTEGER, a } }), +}, } end), + ["load"] = a_fn({ args = { a_type("union", { types = { STRING, LOAD_FUNCTION } }), OPT(STRING), OPT(STRING), OPT(TABLE) }, rets = { FUNCTION, STRING } }), + ["loadfile"] = a_fn({ args = { OPT(STRING), OPT(STRING), OPT(TABLE) }, rets = { FUNCTION, STRING } }), ["next"] = a_type("poly", { types = { - a_gfunction(2, function(a, b) return { args = a_type("tuple", { tuple = { a_type("map", { keys = a, values = b }), OPT(a) } }), rets = a_type("tuple", { tuple = { a, b } }) } end), - a_gfunction(1, function(a) return { args = a_type("tuple", { tuple = { a_type("array", { elements = a }), OPT(a) } }), rets = a_type("tuple", { tuple = { INTEGER, a } }) } end), + a_gfunction(2, function(a, b) return { args = { a_type("map", { keys = a, values = b }), OPT(a) }, rets = { a, b } } end), + a_gfunction(1, function(a) return { args = { a_type("array", { elements = a }), OPT(a) }, rets = { INTEGER, a } } end), } }), - ["pairs"] = a_gfunction(2, function(a, b) return { args = a_type("tuple", { tuple = { a_type("map", { keys = a, values = b }) } }), rets = a_type("tuple", { tuple = { - a_function({ args = a_type("tuple", { tuple = {} }), rets = a_type("tuple", { tuple = { a, b } }) }), -} }), } end), - ["pcall"] = a_function({ args = a_vararg({ FUNCTION, ANY }), rets = a_vararg({ BOOLEAN, ANY }) }), - ["xpcall"] = a_function({ args = a_vararg({ FUNCTION, XPCALL_MSGH_FUNCTION, ANY }), rets = a_vararg({ BOOLEAN, ANY }) }), - ["print"] = a_function({ args = a_vararg({ ANY }), rets = a_type("tuple", { tuple = {} }) }), - ["rawequal"] = a_function({ args = a_type("tuple", { tuple = { ANY, ANY } }), rets = a_type("tuple", { tuple = { BOOLEAN } }) }), - ["rawget"] = a_function({ args = a_type("tuple", { tuple = { TABLE, ANY } }), rets = a_type("tuple", { tuple = { ANY } }) }), - ["rawlen"] = a_function({ args = a_type("tuple", { tuple = { a_type("union", { types = { TABLE, STRING } }) } }), rets = a_type("tuple", { tuple = { INTEGER } }) }), + ["pairs"] = a_gfunction(2, function(a, b) return { args = { a_type("map", { keys = a, values = b }) }, rets = { + a_fn({ args = {}, rets = { a, b } }), +}, } end), + ["pcall"] = a_fn({ args = va_args({ FUNCTION, ANY }), rets = va_args({ BOOLEAN, ANY }) }), + ["xpcall"] = a_fn({ args = va_args({ FUNCTION, XPCALL_MSGH_FUNCTION, ANY }), rets = va_args({ BOOLEAN, ANY }) }), + ["print"] = a_fn({ args = va_args({ ANY }), rets = {} }), + ["rawequal"] = a_fn({ args = { ANY, ANY }, rets = { BOOLEAN } }), + ["rawget"] = a_fn({ args = { TABLE, ANY }, rets = { ANY } }), + ["rawlen"] = a_fn({ args = { a_type("union", { types = { TABLE, STRING } }) }, rets = { INTEGER } }), ["rawset"] = a_type("poly", { types = { - a_gfunction(2, function(a, b) return { args = a_type("tuple", { tuple = { a_type("map", { keys = a, values = b }), a, b } }), rets = a_type("tuple", { tuple = {} }) } end), - a_gfunction(1, function(a) return { args = a_type("tuple", { tuple = { a_type("array", { elements = a }), NUMBER, a } }), rets = a_type("tuple", { tuple = {} }) } end), - a_function({ args = a_type("tuple", { tuple = { TABLE, ANY, ANY } }), rets = a_type("tuple", { tuple = {} }) }), + a_gfunction(2, function(a, b) return { args = { a_type("map", { keys = a, values = b }), a, b }, rets = {} } end), + a_gfunction(1, function(a) return { args = { a_type("array", { elements = a }), NUMBER, a }, rets = {} } end), + a_fn({ args = { TABLE, ANY, ANY }, rets = {} }), } }), - ["require"] = a_function({ args = a_type("tuple", { tuple = { STRING } }), rets = a_type("tuple", { tuple = {} }) }), + ["require"] = a_fn({ args = { STRING }, rets = {} }), ["select"] = a_type("poly", { types = { - a_gfunction(1, function(a) return { args = a_vararg({ NUMBER, a }), rets = a_type("tuple", { tuple = { a } }) } end), - a_function({ args = a_vararg({ NUMBER, ANY }), rets = a_type("tuple", { tuple = { ANY } }) }), - a_function({ args = a_vararg({ STRING, ANY }), rets = a_type("tuple", { tuple = { INTEGER } }) }), + a_gfunction(1, function(a) return { args = va_args({ NUMBER, a }), rets = { a } } end), + a_fn({ args = va_args({ NUMBER, ANY }), rets = { ANY } }), + a_fn({ args = va_args({ STRING, ANY }), rets = { INTEGER } }), } }), - ["setmetatable"] = a_gfunction(1, function(a) return { args = a_type("tuple", { tuple = { a, METATABLE(a) } }), rets = a_type("tuple", { tuple = { a } }) } end), + ["setmetatable"] = a_gfunction(1, function(a) return { args = { a, METATABLE(a) }, rets = { a } } end), ["tonumber"] = a_type("poly", { types = { - a_function({ args = a_type("tuple", { tuple = { ANY } }), rets = a_type("tuple", { tuple = { NUMBER } }) }), - a_function({ args = a_type("tuple", { tuple = { ANY, NUMBER } }), rets = a_type("tuple", { tuple = { INTEGER } }) }), + a_fn({ args = { ANY }, rets = { NUMBER } }), + a_fn({ args = { ANY, NUMBER }, rets = { INTEGER } }), } }), - ["tostring"] = a_function({ args = a_type("tuple", { tuple = { ANY } }), rets = a_type("tuple", { tuple = { STRING } }) }), - ["type"] = a_function({ args = a_type("tuple", { tuple = { ANY } }), rets = a_type("tuple", { tuple = { STRING } }) }), + ["tostring"] = a_fn({ args = { ANY }, rets = { STRING } }), + ["type"] = a_fn({ args = { ANY }, rets = { STRING } }), ["FILE"] = a_type("typedecl", { def = a_record({ is_userdata = true, fields = { - ["close"] = a_function({ args = a_type("tuple", { tuple = { NOMINAL_FILE } }), rets = a_type("tuple", { tuple = { BOOLEAN, STRING, INTEGER } }) }), - ["flush"] = a_function({ args = a_type("tuple", { tuple = { NOMINAL_FILE } }), rets = a_type("tuple", { tuple = {} }) }), + ["close"] = a_fn({ args = { NOMINAL_FILE }, rets = { BOOLEAN, STRING, INTEGER } }), + ["flush"] = a_fn({ args = { NOMINAL_FILE }, rets = {} }), ["lines"] = a_file_reader(function(ctor, args, rets) table.insert(args, 1, NOMINAL_FILE) - return a_function({ args = ctor(args), rets = a_type("tuple", { tuple = { - a_function({ args = a_type("tuple", { tuple = {} }), rets = ctor(rets) }), -} }), }) + return a_fn({ args = ctor(args), rets = { + a_fn({ args = {}, rets = ctor(rets) }), + }, }) end), ["read"] = a_file_reader(function(ctor, args, rets) table.insert(args, 1, NOMINAL_FILE) - return a_function({ args = ctor(args), rets = ctor(rets) }) + return a_fn({ args = ctor(args), rets = ctor(rets) }) end), - ["seek"] = a_function({ args = a_type("tuple", { tuple = { NOMINAL_FILE, OPT(STRING), OPT(NUMBER) } }), rets = a_type("tuple", { tuple = { INTEGER, STRING } }) }), - ["setvbuf"] = a_function({ args = a_type("tuple", { tuple = { NOMINAL_FILE, STRING, OPT(NUMBER) } }), rets = a_type("tuple", { tuple = {} }) }), - ["write"] = a_function({ args = a_vararg({ NOMINAL_FILE, a_type("union", { types = { STRING, NUMBER } }) }), rets = a_type("tuple", { tuple = { NOMINAL_FILE, STRING } }) }), + ["seek"] = a_fn({ args = { NOMINAL_FILE, OPT(STRING), OPT(NUMBER) }, rets = { INTEGER, STRING } }), + ["setvbuf"] = a_fn({ args = { NOMINAL_FILE, STRING, OPT(NUMBER) }, rets = {} }), + ["write"] = a_fn({ args = va_args({ NOMINAL_FILE, a_type("union", { types = { STRING, NUMBER } }) }), rets = { NOMINAL_FILE, STRING } }), }, meta_fields = { ["__close"] = FUNCTION }, @@ -6007,54 +6066,54 @@ local function init_globals(lax) }) }), ["metatable"] = a_type("typedecl", { def = -a_grecord(1, function(a) return { +a_grecord(1, function(a) return a_record({ fields = { - ["__call"] = a_function({ args = a_vararg({ a, ANY }), rets = a_vararg({ ANY }) }), - ["__gc"] = a_function({ args = a_type("tuple", { tuple = { a } }), rets = a_type("tuple", { tuple = {} }) }), + ["__call"] = a_fn({ args = va_args({ a, ANY }), rets = va_args({ ANY }) }), + ["__gc"] = a_fn({ args = { a }, rets = {} }), ["__index"] = ANY, - ["__len"] = a_function({ args = a_type("tuple", { tuple = { a } }), rets = a_type("tuple", { tuple = { ANY } }) }), + ["__len"] = a_fn({ args = { a }, rets = { ANY } }), ["__mode"] = an_enum({ "k", "v", "kv" }), ["__newindex"] = ANY, ["__pairs"] = a_gfunction(2, function(k, v) return { - args = a_type("tuple", { tuple = { a } }), - rets = a_type("tuple", { tuple = { a_function({ args = a_type("tuple", { tuple = {} }), rets = a_type("tuple", { tuple = { k, v } }) }) } }), + args = { a }, + rets = { a_fn({ args = {}, rets = { k, v } }) }, } end), - ["__tostring"] = a_function({ args = a_type("tuple", { tuple = { a } }), rets = a_type("tuple", { tuple = { STRING } }) }), + ["__tostring"] = a_fn({ args = { a }, rets = { STRING } }), ["__name"] = STRING, - ["__add"] = a_function({ args = a_type("tuple", { tuple = { ANY, ANY } }), rets = a_type("tuple", { tuple = { ANY } }) }), - ["__sub"] = a_function({ args = a_type("tuple", { tuple = { ANY, ANY } }), rets = a_type("tuple", { tuple = { ANY } }) }), - ["__mul"] = a_function({ args = a_type("tuple", { tuple = { ANY, ANY } }), rets = a_type("tuple", { tuple = { ANY } }) }), - ["__div"] = a_function({ args = a_type("tuple", { tuple = { ANY, ANY } }), rets = a_type("tuple", { tuple = { ANY } }) }), - ["__idiv"] = a_function({ args = a_type("tuple", { tuple = { ANY, ANY } }), rets = a_type("tuple", { tuple = { ANY } }) }), - ["__mod"] = a_function({ args = a_type("tuple", { tuple = { ANY, ANY } }), rets = a_type("tuple", { tuple = { ANY } }) }), - ["__pow"] = a_function({ args = a_type("tuple", { tuple = { ANY, ANY } }), rets = a_type("tuple", { tuple = { ANY } }) }), - ["__unm"] = a_function({ args = a_type("tuple", { tuple = { ANY } }), rets = a_type("tuple", { tuple = { ANY } }) }), - ["__band"] = a_function({ args = a_type("tuple", { tuple = { ANY, ANY } }), rets = a_type("tuple", { tuple = { ANY } }) }), - ["__bor"] = a_function({ args = a_type("tuple", { tuple = { ANY, ANY } }), rets = a_type("tuple", { tuple = { ANY } }) }), - ["__bxor"] = a_function({ args = a_type("tuple", { tuple = { ANY, ANY } }), rets = a_type("tuple", { tuple = { ANY } }) }), - ["__bnot"] = a_function({ args = a_type("tuple", { tuple = { ANY } }), rets = a_type("tuple", { tuple = { ANY } }) }), - ["__shl"] = a_function({ args = a_type("tuple", { tuple = { ANY, ANY } }), rets = a_type("tuple", { tuple = { ANY } }) }), - ["__shr"] = a_function({ args = a_type("tuple", { tuple = { ANY, ANY } }), rets = a_type("tuple", { tuple = { ANY } }) }), - ["__concat"] = a_function({ args = a_type("tuple", { tuple = { ANY, ANY } }), rets = a_type("tuple", { tuple = { ANY } }) }), - ["__eq"] = a_function({ args = a_type("tuple", { tuple = { ANY, ANY } }), rets = a_type("tuple", { tuple = { BOOLEAN } }) }), - ["__lt"] = a_function({ args = a_type("tuple", { tuple = { ANY, ANY } }), rets = a_type("tuple", { tuple = { BOOLEAN } }) }), - ["__le"] = a_function({ args = a_type("tuple", { tuple = { ANY, ANY } }), rets = a_type("tuple", { tuple = { BOOLEAN } }) }), - ["__close"] = a_function({ args = a_type("tuple", { tuple = { a } }), rets = a_type("tuple", { tuple = {} }) }), + ["__add"] = a_fn({ args = { ANY, ANY }, rets = { ANY } }), + ["__sub"] = a_fn({ args = { ANY, ANY }, rets = { ANY } }), + ["__mul"] = a_fn({ args = { ANY, ANY }, rets = { ANY } }), + ["__div"] = a_fn({ args = { ANY, ANY }, rets = { ANY } }), + ["__idiv"] = a_fn({ args = { ANY, ANY }, rets = { ANY } }), + ["__mod"] = a_fn({ args = { ANY, ANY }, rets = { ANY } }), + ["__pow"] = a_fn({ args = { ANY, ANY }, rets = { ANY } }), + ["__unm"] = a_fn({ args = { ANY }, rets = { ANY } }), + ["__band"] = a_fn({ args = { ANY, ANY }, rets = { ANY } }), + ["__bor"] = a_fn({ args = { ANY, ANY }, rets = { ANY } }), + ["__bxor"] = a_fn({ args = { ANY, ANY }, rets = { ANY } }), + ["__bnot"] = a_fn({ args = { ANY }, rets = { ANY } }), + ["__shl"] = a_fn({ args = { ANY, ANY }, rets = { ANY } }), + ["__shr"] = a_fn({ args = { ANY, ANY }, rets = { ANY } }), + ["__concat"] = a_fn({ args = { ANY, ANY }, rets = { ANY } }), + ["__eq"] = a_fn({ args = { ANY, ANY }, rets = { BOOLEAN } }), + ["__lt"] = a_fn({ args = { ANY, ANY }, rets = { BOOLEAN } }), + ["__le"] = a_fn({ args = { ANY, ANY }, rets = { BOOLEAN } }), + ["__close"] = a_fn({ args = { a }, rets = {} }), }, -} end) }), +}) end) }), ["coroutine"] = a_record({ fields = { - ["create"] = a_function({ args = a_type("tuple", { tuple = { FUNCTION } }), rets = a_type("tuple", { tuple = { THREAD } }) }), - ["close"] = a_function({ args = a_type("tuple", { tuple = { THREAD } }), rets = a_type("tuple", { tuple = { BOOLEAN, STRING } }) }), - ["isyieldable"] = a_function({ args = a_type("tuple", { tuple = {} }), rets = a_type("tuple", { tuple = { BOOLEAN } }) }), - ["resume"] = a_function({ args = a_vararg({ THREAD, ANY }), rets = a_vararg({ BOOLEAN, ANY }) }), - ["running"] = a_function({ args = a_type("tuple", { tuple = {} }), rets = a_type("tuple", { tuple = { THREAD, BOOLEAN } }) }), - ["status"] = a_function({ args = a_type("tuple", { tuple = { THREAD } }), rets = a_type("tuple", { tuple = { STRING } }) }), - ["wrap"] = a_function({ args = a_type("tuple", { tuple = { FUNCTION } }), rets = a_type("tuple", { tuple = { FUNCTION } }) }), - ["yield"] = a_function({ args = a_vararg({ ANY }), rets = a_vararg({ ANY }) }), + ["create"] = a_fn({ args = { FUNCTION }, rets = { THREAD } }), + ["close"] = a_fn({ args = { THREAD }, rets = { BOOLEAN, STRING } }), + ["isyieldable"] = a_fn({ args = {}, rets = { BOOLEAN } }), + ["resume"] = a_fn({ args = va_args({ THREAD, ANY }), rets = va_args({ BOOLEAN, ANY }) }), + ["running"] = a_fn({ args = {}, rets = { THREAD, BOOLEAN } }), + ["status"] = a_fn({ args = { THREAD }, rets = { STRING } }), + ["wrap"] = a_fn({ args = { FUNCTION }, rets = { FUNCTION } }), + ["yield"] = a_fn({ args = va_args({ ANY }), rets = va_args({ ANY }) }), }, }), ["debug"] = a_record({ @@ -6063,141 +6122,141 @@ a_grecord(1, function(a) return { ["Hook"] = a_type("typedecl", { def = DEBUG_HOOK_FUNCTION }), ["HookEvent"] = a_type("typedecl", { def = DEBUG_HOOK_EVENT }), - ["debug"] = a_function({ args = a_type("tuple", { tuple = {} }), rets = a_type("tuple", { tuple = {} }) }), - ["gethook"] = a_function({ args = a_type("tuple", { tuple = { OPT(THREAD) } }), rets = a_type("tuple", { tuple = { DEBUG_HOOK_FUNCTION, INTEGER } }) }), + ["debug"] = a_fn({ args = {}, rets = {} }), + ["gethook"] = a_fn({ args = { OPT(THREAD) }, rets = { DEBUG_HOOK_FUNCTION, INTEGER } }), ["getlocal"] = a_type("poly", { types = { - a_function({ args = a_type("tuple", { tuple = { THREAD, FUNCTION, NUMBER } }), rets = a_type("tuple", { tuple = { STRING } }) }), - a_function({ args = a_type("tuple", { tuple = { THREAD, NUMBER, NUMBER } }), rets = a_type("tuple", { tuple = { STRING, ANY } }) }), - a_function({ args = a_type("tuple", { tuple = { FUNCTION, NUMBER } }), rets = a_type("tuple", { tuple = { STRING } }) }), - a_function({ args = a_type("tuple", { tuple = { NUMBER, NUMBER } }), rets = a_type("tuple", { tuple = { STRING, ANY } }) }), + a_fn({ args = { THREAD, FUNCTION, NUMBER }, rets = { STRING } }), + a_fn({ args = { THREAD, NUMBER, NUMBER }, rets = { STRING, ANY } }), + a_fn({ args = { FUNCTION, NUMBER }, rets = { STRING } }), + a_fn({ args = { NUMBER, NUMBER }, rets = { STRING, ANY } }), } }), - ["getmetatable"] = a_gfunction(1, function(a) return { args = a_type("tuple", { tuple = { a } }), rets = a_type("tuple", { tuple = { METATABLE(a) } }) } end), - ["getregistry"] = a_function({ args = a_type("tuple", { tuple = {} }), rets = a_type("tuple", { tuple = { TABLE } }) }), - ["getupvalue"] = a_function({ args = a_type("tuple", { tuple = { FUNCTION, NUMBER } }), rets = a_type("tuple", { tuple = { ANY } }) }), - ["getuservalue"] = a_function({ args = a_type("tuple", { tuple = { USERDATA, NUMBER } }), rets = a_type("tuple", { tuple = { ANY } }) }), + ["getmetatable"] = a_gfunction(1, function(a) return { args = { a }, rets = { METATABLE(a) } } end), + ["getregistry"] = a_fn({ args = {}, rets = { TABLE } }), + ["getupvalue"] = a_fn({ args = { FUNCTION, NUMBER }, rets = { ANY } }), + ["getuservalue"] = a_fn({ args = { USERDATA, NUMBER }, rets = { ANY } }), ["sethook"] = a_type("poly", { types = { - a_function({ args = a_type("tuple", { tuple = { THREAD, DEBUG_HOOK_FUNCTION, STRING, NUMBER } }), rets = a_type("tuple", { tuple = {} }) }), - a_function({ args = a_type("tuple", { tuple = { DEBUG_HOOK_FUNCTION, STRING, NUMBER } }), rets = a_type("tuple", { tuple = {} }) }), + a_fn({ args = { THREAD, DEBUG_HOOK_FUNCTION, STRING, NUMBER }, rets = {} }), + a_fn({ args = { DEBUG_HOOK_FUNCTION, STRING, NUMBER }, rets = {} }), } }), ["setlocal"] = a_type("poly", { types = { - a_function({ args = a_type("tuple", { tuple = { THREAD, NUMBER, NUMBER, ANY } }), rets = a_type("tuple", { tuple = { STRING } }) }), - a_function({ args = a_type("tuple", { tuple = { NUMBER, NUMBER, ANY } }), rets = a_type("tuple", { tuple = { STRING } }) }), + a_fn({ args = { THREAD, NUMBER, NUMBER, ANY }, rets = { STRING } }), + a_fn({ args = { NUMBER, NUMBER, ANY }, rets = { STRING } }), } }), - ["setmetatable"] = a_gfunction(1, function(a) return { args = a_type("tuple", { tuple = { a, METATABLE(a) } }), rets = a_type("tuple", { tuple = { a } }) } end), - ["setupvalue"] = a_function({ args = a_type("tuple", { tuple = { FUNCTION, NUMBER, ANY } }), rets = a_type("tuple", { tuple = { STRING } }) }), - ["setuservalue"] = a_function({ args = a_type("tuple", { tuple = { USERDATA, ANY, NUMBER } }), rets = a_type("tuple", { tuple = { USERDATA } }) }), + ["setmetatable"] = a_gfunction(1, function(a) return { args = { a, METATABLE(a) }, rets = { a } } end), + ["setupvalue"] = a_fn({ args = { FUNCTION, NUMBER, ANY }, rets = { STRING } }), + ["setuservalue"] = a_fn({ args = { USERDATA, ANY, NUMBER }, rets = { USERDATA } }), ["traceback"] = a_type("poly", { types = { - a_function({ args = a_type("tuple", { tuple = { OPT(THREAD), OPT(STRING), OPT(NUMBER) } }), rets = a_type("tuple", { tuple = { STRING } }) }), - a_function({ args = a_type("tuple", { tuple = { OPT(STRING), OPT(NUMBER) } }), rets = a_type("tuple", { tuple = { STRING } }) }), + a_fn({ args = { OPT(THREAD), OPT(STRING), OPT(NUMBER) }, rets = { STRING } }), + a_fn({ args = { OPT(STRING), OPT(NUMBER) }, rets = { STRING } }), } }), - ["upvalueid"] = a_function({ args = a_type("tuple", { tuple = { FUNCTION, NUMBER } }), rets = a_type("tuple", { tuple = { USERDATA } }) }), - ["upvaluejoin"] = a_function({ args = a_type("tuple", { tuple = { FUNCTION, NUMBER, FUNCTION, NUMBER } }), rets = a_type("tuple", { tuple = {} }) }), + ["upvalueid"] = a_fn({ args = { FUNCTION, NUMBER }, rets = { USERDATA } }), + ["upvaluejoin"] = a_fn({ args = { FUNCTION, NUMBER, FUNCTION, NUMBER }, rets = {} }), ["getinfo"] = a_type("poly", { types = { - a_function({ args = a_type("tuple", { tuple = { ANY } }), rets = a_type("tuple", { tuple = { DEBUG_GETINFO_TABLE } }) }), - a_function({ args = a_type("tuple", { tuple = { ANY, STRING } }), rets = a_type("tuple", { tuple = { DEBUG_GETINFO_TABLE } }) }), - a_function({ args = a_type("tuple", { tuple = { ANY, ANY, STRING } }), rets = a_type("tuple", { tuple = { DEBUG_GETINFO_TABLE } }) }), + a_fn({ args = { ANY }, rets = { DEBUG_GETINFO_TABLE } }), + a_fn({ args = { ANY, STRING }, rets = { DEBUG_GETINFO_TABLE } }), + a_fn({ args = { ANY, ANY, STRING }, rets = { DEBUG_GETINFO_TABLE } }), } }), }, }), ["io"] = a_record({ fields = { - ["close"] = a_function({ args = a_type("tuple", { tuple = { OPT(NOMINAL_FILE) } }), rets = a_type("tuple", { tuple = { BOOLEAN, STRING } }) }), - ["flush"] = a_function({ args = a_type("tuple", { tuple = {} }), rets = a_type("tuple", { tuple = {} }) }), - ["input"] = a_function({ args = a_type("tuple", { tuple = { OPT(a_type("union", { types = { STRING, NOMINAL_FILE } })) } }), rets = a_type("tuple", { tuple = { NOMINAL_FILE } }) }), + ["close"] = a_fn({ args = { OPT(NOMINAL_FILE) }, rets = { BOOLEAN, STRING } }), + ["flush"] = a_fn({ args = {}, rets = {} }), + ["input"] = a_fn({ args = { OPT(a_type("union", { types = { STRING, NOMINAL_FILE } })) }, rets = { NOMINAL_FILE } }), ["lines"] = a_file_reader(function(ctor, args, rets) - return a_function({ args = ctor(args), rets = a_type("tuple", { tuple = { - a_function({ args = a_type("tuple", { tuple = {} }), rets = ctor(rets) }), -} }), }) + return a_fn({ args = ctor(args), rets = { + a_fn({ args = {}, rets = ctor(rets) }), + }, }) end), - ["open"] = a_function({ args = a_type("tuple", { tuple = { STRING, OPT(STRING) } }), rets = a_type("tuple", { tuple = { NOMINAL_FILE, STRING } }) }), - ["output"] = a_function({ args = a_type("tuple", { tuple = { OPT(a_type("union", { types = { STRING, NOMINAL_FILE } })) } }), rets = a_type("tuple", { tuple = { NOMINAL_FILE } }) }), - ["popen"] = a_function({ args = a_type("tuple", { tuple = { STRING, OPT(STRING) } }), rets = a_type("tuple", { tuple = { NOMINAL_FILE, STRING } }) }), + ["open"] = a_fn({ args = { STRING, OPT(STRING) }, rets = { NOMINAL_FILE, STRING } }), + ["output"] = a_fn({ args = { OPT(a_type("union", { types = { STRING, NOMINAL_FILE } })) }, rets = { NOMINAL_FILE } }), + ["popen"] = a_fn({ args = { STRING, OPT(STRING) }, rets = { NOMINAL_FILE, STRING } }), ["read"] = a_file_reader(function(ctor, args, rets) - return a_function({ args = ctor(args), rets = ctor(rets) }) + return a_fn({ args = ctor(args), rets = ctor(rets) }) end), ["stderr"] = NOMINAL_FILE, ["stdin"] = NOMINAL_FILE, ["stdout"] = NOMINAL_FILE, - ["tmpfile"] = a_function({ args = a_type("tuple", { tuple = {} }), rets = a_type("tuple", { tuple = { NOMINAL_FILE } }) }), - ["type"] = a_function({ args = a_type("tuple", { tuple = { ANY } }), rets = a_type("tuple", { tuple = { STRING } }) }), - ["write"] = a_function({ args = a_vararg({ a_type("union", { types = { STRING, NUMBER } }) }), rets = a_type("tuple", { tuple = { NOMINAL_FILE, STRING } }) }), + ["tmpfile"] = a_fn({ args = {}, rets = { NOMINAL_FILE } }), + ["type"] = a_fn({ args = { ANY }, rets = { STRING } }), + ["write"] = a_fn({ args = va_args({ a_type("union", { types = { STRING, NUMBER } }) }), rets = { NOMINAL_FILE, STRING } }), }, }), ["math"] = a_record({ fields = { ["abs"] = a_type("poly", { types = { - a_function({ args = a_type("tuple", { tuple = { INTEGER } }), rets = a_type("tuple", { tuple = { INTEGER } }) }), - a_function({ args = a_type("tuple", { tuple = { NUMBER } }), rets = a_type("tuple", { tuple = { NUMBER } }) }), + a_fn({ args = { INTEGER }, rets = { INTEGER } }), + a_fn({ args = { NUMBER }, rets = { NUMBER } }), } }), - ["acos"] = a_function({ args = a_type("tuple", { tuple = { NUMBER } }), rets = a_type("tuple", { tuple = { NUMBER } }) }), - ["asin"] = a_function({ args = a_type("tuple", { tuple = { NUMBER } }), rets = a_type("tuple", { tuple = { NUMBER } }) }), - ["atan"] = a_function({ args = a_type("tuple", { tuple = { NUMBER, OPT(NUMBER) } }), rets = a_type("tuple", { tuple = { NUMBER } }) }), - ["atan2"] = a_function({ args = a_type("tuple", { tuple = { NUMBER, NUMBER } }), rets = a_type("tuple", { tuple = { NUMBER } }) }), - ["ceil"] = a_function({ args = a_type("tuple", { tuple = { NUMBER } }), rets = a_type("tuple", { tuple = { INTEGER } }) }), - ["cos"] = a_function({ args = a_type("tuple", { tuple = { NUMBER } }), rets = a_type("tuple", { tuple = { NUMBER } }) }), - ["cosh"] = a_function({ args = a_type("tuple", { tuple = { NUMBER } }), rets = a_type("tuple", { tuple = { NUMBER } }) }), - ["deg"] = a_function({ args = a_type("tuple", { tuple = { NUMBER } }), rets = a_type("tuple", { tuple = { NUMBER } }) }), - ["exp"] = a_function({ args = a_type("tuple", { tuple = { NUMBER } }), rets = a_type("tuple", { tuple = { NUMBER } }) }), - ["floor"] = a_function({ args = a_type("tuple", { tuple = { NUMBER } }), rets = a_type("tuple", { tuple = { INTEGER } }) }), + ["acos"] = a_fn({ args = { NUMBER }, rets = { NUMBER } }), + ["asin"] = a_fn({ args = { NUMBER }, rets = { NUMBER } }), + ["atan"] = a_fn({ args = { NUMBER, OPT(NUMBER) }, rets = { NUMBER } }), + ["atan2"] = a_fn({ args = { NUMBER, NUMBER }, rets = { NUMBER } }), + ["ceil"] = a_fn({ args = { NUMBER }, rets = { INTEGER } }), + ["cos"] = a_fn({ args = { NUMBER }, rets = { NUMBER } }), + ["cosh"] = a_fn({ args = { NUMBER }, rets = { NUMBER } }), + ["deg"] = a_fn({ args = { NUMBER }, rets = { NUMBER } }), + ["exp"] = a_fn({ args = { NUMBER }, rets = { NUMBER } }), + ["floor"] = a_fn({ args = { NUMBER }, rets = { INTEGER } }), ["fmod"] = a_type("poly", { types = { - a_function({ args = a_type("tuple", { tuple = { INTEGER, INTEGER } }), rets = a_type("tuple", { tuple = { INTEGER } }) }), - a_function({ args = a_type("tuple", { tuple = { NUMBER, NUMBER } }), rets = a_type("tuple", { tuple = { NUMBER } }) }), + a_fn({ args = { INTEGER, INTEGER }, rets = { INTEGER } }), + a_fn({ args = { NUMBER, NUMBER }, rets = { NUMBER } }), } }), - ["frexp"] = a_function({ args = a_type("tuple", { tuple = { NUMBER } }), rets = a_type("tuple", { tuple = { NUMBER, NUMBER } }) }), + ["frexp"] = a_fn({ args = { NUMBER }, rets = { NUMBER, NUMBER } }), ["huge"] = NUMBER, - ["ldexp"] = a_function({ args = a_type("tuple", { tuple = { NUMBER, NUMBER } }), rets = a_type("tuple", { tuple = { NUMBER } }) }), - ["log"] = a_function({ args = a_type("tuple", { tuple = { NUMBER, OPT(NUMBER) } }), rets = a_type("tuple", { tuple = { NUMBER } }) }), - ["log10"] = a_function({ args = a_type("tuple", { tuple = { NUMBER } }), rets = a_type("tuple", { tuple = { NUMBER } }) }), + ["ldexp"] = a_fn({ args = { NUMBER, NUMBER }, rets = { NUMBER } }), + ["log"] = a_fn({ args = { NUMBER, OPT(NUMBER) }, rets = { NUMBER } }), + ["log10"] = a_fn({ args = { NUMBER }, rets = { NUMBER } }), ["max"] = a_type("poly", { types = { - a_function({ args = a_vararg({ INTEGER }), rets = a_type("tuple", { tuple = { INTEGER } }) }), - a_gfunction(1, function(a) return { args = a_vararg({ a }), rets = a_type("tuple", { tuple = { a } }) } end), - a_function({ args = a_vararg({ a_type("union", { types = { NUMBER, INTEGER } }) }), rets = a_type("tuple", { tuple = { NUMBER } }) }), - a_function({ args = a_vararg({ ANY }), rets = a_type("tuple", { tuple = { ANY } }) }), + a_fn({ args = va_args({ INTEGER }), rets = { INTEGER } }), + a_gfunction(1, function(a) return { args = va_args({ a }), rets = { a } } end), + a_fn({ args = va_args({ a_type("union", { types = { NUMBER, INTEGER } }) }), rets = { NUMBER } }), + a_fn({ args = va_args({ ANY }), rets = { ANY } }), } }), ["maxinteger"] = a_type("integer", { needs_compat = true }), ["min"] = a_type("poly", { types = { - a_function({ args = a_vararg({ INTEGER }), rets = a_type("tuple", { tuple = { INTEGER } }) }), - a_gfunction(1, function(a) return { args = a_vararg({ a }), rets = a_type("tuple", { tuple = { a } }) } end), - a_function({ args = a_vararg({ a_type("union", { types = { NUMBER, INTEGER } }) }), rets = a_type("tuple", { tuple = { NUMBER } }) }), - a_function({ args = a_vararg({ ANY }), rets = a_type("tuple", { tuple = { ANY } }) }), + a_fn({ args = va_args({ INTEGER }), rets = { INTEGER } }), + a_gfunction(1, function(a) return { args = va_args({ a }), rets = { a } } end), + a_fn({ args = va_args({ a_type("union", { types = { NUMBER, INTEGER } }) }), rets = { NUMBER } }), + a_fn({ args = va_args({ ANY }), rets = { ANY } }), } }), ["mininteger"] = a_type("integer", { needs_compat = true }), - ["modf"] = a_function({ args = a_type("tuple", { tuple = { NUMBER } }), rets = a_type("tuple", { tuple = { INTEGER, NUMBER } }) }), + ["modf"] = a_fn({ args = { NUMBER }, rets = { INTEGER, NUMBER } }), ["pi"] = NUMBER, - ["pow"] = a_function({ args = a_type("tuple", { tuple = { NUMBER, NUMBER } }), rets = a_type("tuple", { tuple = { NUMBER } }) }), - ["rad"] = a_function({ args = a_type("tuple", { tuple = { NUMBER } }), rets = a_type("tuple", { tuple = { NUMBER } }) }), + ["pow"] = a_fn({ args = { NUMBER, NUMBER }, rets = { NUMBER } }), + ["rad"] = a_fn({ args = { NUMBER }, rets = { NUMBER } }), ["random"] = a_type("poly", { types = { - a_function({ args = a_type("tuple", { tuple = { NUMBER, OPT(NUMBER) } }), rets = a_type("tuple", { tuple = { INTEGER } }) }), - a_function({ args = a_type("tuple", { tuple = {} }), rets = a_type("tuple", { tuple = { NUMBER } }) }), + a_fn({ args = { NUMBER, OPT(NUMBER) }, rets = { INTEGER } }), + a_fn({ args = {}, rets = { NUMBER } }), } }), - ["randomseed"] = a_function({ args = a_type("tuple", { tuple = { NUMBER, NUMBER } }), rets = a_type("tuple", { tuple = { INTEGER, INTEGER } }) }), - ["sin"] = a_function({ args = a_type("tuple", { tuple = { NUMBER } }), rets = a_type("tuple", { tuple = { NUMBER } }) }), - ["sinh"] = a_function({ args = a_type("tuple", { tuple = { NUMBER } }), rets = a_type("tuple", { tuple = { NUMBER } }) }), - ["sqrt"] = a_function({ args = a_type("tuple", { tuple = { NUMBER } }), rets = a_type("tuple", { tuple = { NUMBER } }) }), - ["tan"] = a_function({ args = a_type("tuple", { tuple = { NUMBER } }), rets = a_type("tuple", { tuple = { NUMBER } }) }), - ["tanh"] = a_function({ args = a_type("tuple", { tuple = { NUMBER } }), rets = a_type("tuple", { tuple = { NUMBER } }) }), - ["tointeger"] = a_function({ args = a_type("tuple", { tuple = { ANY } }), rets = a_type("tuple", { tuple = { INTEGER } }) }), - ["type"] = a_function({ args = a_type("tuple", { tuple = { ANY } }), rets = a_type("tuple", { tuple = { STRING } }) }), - ["ult"] = a_function({ args = a_type("tuple", { tuple = { NUMBER, NUMBER } }), rets = a_type("tuple", { tuple = { BOOLEAN } }) }), + ["randomseed"] = a_fn({ args = { NUMBER, NUMBER }, rets = { INTEGER, INTEGER } }), + ["sin"] = a_fn({ args = { NUMBER }, rets = { NUMBER } }), + ["sinh"] = a_fn({ args = { NUMBER }, rets = { NUMBER } }), + ["sqrt"] = a_fn({ args = { NUMBER }, rets = { NUMBER } }), + ["tan"] = a_fn({ args = { NUMBER }, rets = { NUMBER } }), + ["tanh"] = a_fn({ args = { NUMBER }, rets = { NUMBER } }), + ["tointeger"] = a_fn({ args = { ANY }, rets = { INTEGER } }), + ["type"] = a_fn({ args = { ANY }, rets = { STRING } }), + ["ult"] = a_fn({ args = { NUMBER, NUMBER }, rets = { BOOLEAN } }), }, }), ["os"] = a_record({ fields = { - ["clock"] = a_function({ args = a_type("tuple", { tuple = {} }), rets = a_type("tuple", { tuple = { NUMBER } }) }), + ["clock"] = a_fn({ args = {}, rets = { NUMBER } }), ["date"] = a_type("poly", { types = { - a_function({ args = a_type("tuple", { tuple = {} }), rets = a_type("tuple", { tuple = { STRING } }) }), - a_function({ args = a_type("tuple", { tuple = { an_enum({ "!*t", "*t" }), OPT(NUMBER) } }), rets = a_type("tuple", { tuple = { OS_DATE_TABLE } }) }), - a_function({ args = a_type("tuple", { tuple = { OPT(STRING), OPT(NUMBER) } }), rets = a_type("tuple", { tuple = { STRING } }) }), + a_fn({ args = {}, rets = { STRING } }), + a_fn({ args = { an_enum({ "!*t", "*t" }), OPT(NUMBER) }, rets = { OS_DATE_TABLE } }), + a_fn({ args = { OPT(STRING), OPT(NUMBER) }, rets = { STRING } }), } }), - ["difftime"] = a_function({ args = a_type("tuple", { tuple = { NUMBER, NUMBER } }), rets = a_type("tuple", { tuple = { NUMBER } }) }), - ["execute"] = a_function({ args = a_type("tuple", { tuple = { STRING } }), rets = a_type("tuple", { tuple = { BOOLEAN, STRING, INTEGER } }) }), - ["exit"] = a_function({ args = a_type("tuple", { tuple = { OPT(a_type("union", { types = { NUMBER, BOOLEAN } })), OPT(BOOLEAN) } }), rets = a_type("tuple", { tuple = {} }) }), - ["getenv"] = a_function({ args = a_type("tuple", { tuple = { STRING } }), rets = a_type("tuple", { tuple = { STRING } }) }), - ["remove"] = a_function({ args = a_type("tuple", { tuple = { STRING } }), rets = a_type("tuple", { tuple = { BOOLEAN, STRING } }) }), - ["rename"] = a_function({ args = a_type("tuple", { tuple = { STRING, STRING } }), rets = a_type("tuple", { tuple = { BOOLEAN, STRING } }) }), - ["setlocale"] = a_function({ args = a_type("tuple", { tuple = { STRING, OPT(STRING) } }), rets = a_type("tuple", { tuple = { STRING } }) }), - ["time"] = a_function({ args = a_type("tuple", { tuple = { OPT(OS_DATE_TABLE) } }), rets = a_type("tuple", { tuple = { INTEGER } }) }), - ["tmpname"] = a_function({ args = a_type("tuple", { tuple = {} }), rets = a_type("tuple", { tuple = { STRING } }) }), + ["difftime"] = a_fn({ args = { NUMBER, NUMBER }, rets = { NUMBER } }), + ["execute"] = a_fn({ args = { STRING }, rets = { BOOLEAN, STRING, INTEGER } }), + ["exit"] = a_fn({ args = { OPT(a_type("union", { types = { NUMBER, BOOLEAN } })), OPT(BOOLEAN) }, rets = {} }), + ["getenv"] = a_fn({ args = { STRING }, rets = { STRING } }), + ["remove"] = a_fn({ args = { STRING }, rets = { BOOLEAN, STRING } }), + ["rename"] = a_fn({ args = { STRING, STRING }, rets = { BOOLEAN, STRING } }), + ["setlocale"] = a_fn({ args = { STRING, OPT(STRING) }, rets = { STRING } }), + ["time"] = a_fn({ args = { OPT(OS_DATE_TABLE) }, rets = { INTEGER } }), + ["tmpname"] = a_fn({ args = {}, rets = { STRING } }), }, }), ["package"] = a_record({ @@ -6205,75 +6264,75 @@ a_grecord(1, function(a) return { ["config"] = STRING, ["cpath"] = STRING, ["loaded"] = a_type("map", { keys = STRING, values = ANY }), - ["loaders"] = a_type("array", { elements = a_function({ args = a_type("tuple", { tuple = { STRING } }), rets = a_type("tuple", { tuple = { ANY, ANY } }) }) }), - ["loadlib"] = a_function({ args = a_type("tuple", { tuple = { STRING, STRING } }), rets = a_type("tuple", { tuple = { FUNCTION } }) }), + ["loaders"] = a_type("array", { elements = a_fn({ args = { STRING }, rets = { ANY, ANY } }) }), + ["loadlib"] = a_fn({ args = { STRING, STRING }, rets = { FUNCTION } }), ["path"] = STRING, ["preload"] = TABLE, - ["searchers"] = a_type("array", { elements = a_function({ args = a_type("tuple", { tuple = { STRING } }), rets = a_type("tuple", { tuple = { ANY, ANY } }) }) }), - ["searchpath"] = a_function({ args = a_type("tuple", { tuple = { STRING, STRING, OPT(STRING), OPT(STRING) } }), rets = a_type("tuple", { tuple = { STRING, STRING } }) }), + ["searchers"] = a_type("array", { elements = a_fn({ args = { STRING }, rets = { ANY, ANY } }) }), + ["searchpath"] = a_fn({ args = { STRING, STRING, OPT(STRING), OPT(STRING) }, rets = { STRING, STRING } }), }, }), ["string"] = a_record({ fields = { ["byte"] = a_type("poly", { types = { - a_function({ args = a_type("tuple", { tuple = { STRING, OPT(NUMBER) } }), rets = a_type("tuple", { tuple = { INTEGER } }) }), - a_function({ args = a_type("tuple", { tuple = { STRING, NUMBER, NUMBER } }), rets = a_vararg({ INTEGER }) }), + a_fn({ args = { STRING, OPT(NUMBER) }, rets = { INTEGER } }), + a_fn({ args = { STRING, NUMBER, NUMBER }, rets = va_args({ INTEGER }) }), } }), - ["char"] = a_function({ args = a_vararg({ NUMBER }), rets = a_type("tuple", { tuple = { STRING } }) }), - ["dump"] = a_function({ args = a_type("tuple", { tuple = { FUNCTION, OPT(BOOLEAN) } }), rets = a_type("tuple", { tuple = { STRING } }) }), - ["find"] = a_function({ args = a_type("tuple", { tuple = { STRING, STRING, OPT(NUMBER), OPT(BOOLEAN) } }), rets = a_vararg({ INTEGER, INTEGER, STRING }) }), - ["format"] = a_function({ args = a_vararg({ STRING, ANY }), rets = a_type("tuple", { tuple = { STRING } }) }), - ["gmatch"] = a_function({ args = a_type("tuple", { tuple = { STRING, STRING } }), rets = a_type("tuple", { tuple = { - a_function({ args = a_type("tuple", { tuple = {} }), rets = a_vararg({ STRING }) }), -} }), }), + ["char"] = a_fn({ args = va_args({ NUMBER }), rets = { STRING } }), + ["dump"] = a_fn({ args = { FUNCTION, OPT(BOOLEAN) }, rets = { STRING } }), + ["find"] = a_fn({ args = { STRING, STRING, OPT(NUMBER), OPT(BOOLEAN) }, rets = va_args({ INTEGER, INTEGER, STRING }) }), + ["format"] = a_fn({ args = va_args({ STRING, ANY }), rets = { STRING } }), + ["gmatch"] = a_fn({ args = { STRING, STRING }, rets = { + a_fn({ args = {}, rets = va_args({ STRING }) }), + }, }), ["gsub"] = a_type("poly", { types = { - a_function({ args = a_type("tuple", { tuple = { STRING, STRING, a_type("map", { keys = STRING, values = STRING }), OPT(NUMBER) } }), rets = a_type("tuple", { tuple = { STRING, INTEGER } }) }), - a_function({ args = a_type("tuple", { tuple = { STRING, STRING, a_function({ args = a_vararg({ STRING }), rets = a_type("tuple", { tuple = { STRING } }) }), OPT(NUMBER) } }), rets = a_type("tuple", { tuple = { STRING, INTEGER } }) }), - a_function({ args = a_type("tuple", { tuple = { STRING, STRING, a_function({ args = a_vararg({ STRING }), rets = a_type("tuple", { tuple = { NUMBER } }) }), OPT(NUMBER) } }), rets = a_type("tuple", { tuple = { STRING, INTEGER } }) }), - a_function({ args = a_type("tuple", { tuple = { STRING, STRING, a_function({ args = a_vararg({ STRING }), rets = a_type("tuple", { tuple = { BOOLEAN } }) }), OPT(NUMBER) } }), rets = a_type("tuple", { tuple = { STRING, INTEGER } }) }), - a_function({ args = a_type("tuple", { tuple = { STRING, STRING, a_function({ args = a_vararg({ STRING }), rets = a_type("tuple", { tuple = {} }) }), OPT(NUMBER) } }), rets = a_type("tuple", { tuple = { STRING, INTEGER } }) }), - a_function({ args = a_type("tuple", { tuple = { STRING, STRING, OPT(STRING), OPT(NUMBER) } }), rets = a_type("tuple", { tuple = { STRING, INTEGER } }) }), + a_fn({ args = { STRING, STRING, a_type("map", { keys = STRING, values = STRING }), OPT(NUMBER) }, rets = { STRING, INTEGER } }), + a_fn({ args = { STRING, STRING, a_fn({ args = va_args({ STRING }), rets = { STRING } }), OPT(NUMBER) }, rets = { STRING, INTEGER } }), + a_fn({ args = { STRING, STRING, a_fn({ args = va_args({ STRING }), rets = { NUMBER } }), OPT(NUMBER) }, rets = { STRING, INTEGER } }), + a_fn({ args = { STRING, STRING, a_fn({ args = va_args({ STRING }), rets = { BOOLEAN } }), OPT(NUMBER) }, rets = { STRING, INTEGER } }), + a_fn({ args = { STRING, STRING, a_fn({ args = va_args({ STRING }), rets = {} }), OPT(NUMBER) }, rets = { STRING, INTEGER } }), + a_fn({ args = { STRING, STRING, OPT(STRING), OPT(NUMBER) }, rets = { STRING, INTEGER } }), } }), - ["len"] = a_function({ args = a_type("tuple", { tuple = { STRING } }), rets = a_type("tuple", { tuple = { INTEGER } }) }), - ["lower"] = a_function({ args = a_type("tuple", { tuple = { STRING } }), rets = a_type("tuple", { tuple = { STRING } }) }), - ["match"] = a_function({ args = a_type("tuple", { tuple = { STRING, OPT(STRING), OPT(NUMBER) } }), rets = a_vararg({ STRING }) }), - ["pack"] = a_function({ args = a_vararg({ STRING, ANY }), rets = a_type("tuple", { tuple = { STRING } }) }), - ["packsize"] = a_function({ args = a_type("tuple", { tuple = { STRING } }), rets = a_type("tuple", { tuple = { INTEGER } }) }), - ["rep"] = a_function({ args = a_type("tuple", { tuple = { STRING, NUMBER, OPT(STRING) } }), rets = a_type("tuple", { tuple = { STRING } }) }), - ["reverse"] = a_function({ args = a_type("tuple", { tuple = { STRING } }), rets = a_type("tuple", { tuple = { STRING } }) }), - ["sub"] = a_function({ args = a_type("tuple", { tuple = { STRING, NUMBER, OPT(NUMBER) } }), rets = a_type("tuple", { tuple = { STRING } }) }), - ["unpack"] = a_function({ args = a_type("tuple", { tuple = { STRING, STRING, OPT(NUMBER) } }), rets = a_vararg({ ANY }) }), - ["upper"] = a_function({ args = a_type("tuple", { tuple = { STRING } }), rets = a_type("tuple", { tuple = { STRING } }) }), + ["len"] = a_fn({ args = { STRING }, rets = { INTEGER } }), + ["lower"] = a_fn({ args = { STRING }, rets = { STRING } }), + ["match"] = a_fn({ args = { STRING, OPT(STRING), OPT(NUMBER) }, rets = va_args({ STRING }) }), + ["pack"] = a_fn({ args = va_args({ STRING, ANY }), rets = { STRING } }), + ["packsize"] = a_fn({ args = { STRING }, rets = { INTEGER } }), + ["rep"] = a_fn({ args = { STRING, NUMBER, OPT(STRING) }, rets = { STRING } }), + ["reverse"] = a_fn({ args = { STRING }, rets = { STRING } }), + ["sub"] = a_fn({ args = { STRING, NUMBER, OPT(NUMBER) }, rets = { STRING } }), + ["unpack"] = a_fn({ args = { STRING, STRING, OPT(NUMBER) }, rets = va_args({ ANY }) }), + ["upper"] = a_fn({ args = { STRING }, rets = { STRING } }), }, }), ["table"] = a_record({ fields = { - ["concat"] = a_function({ args = a_type("tuple", { tuple = { a_type("array", { elements = a_type("union", { types = { STRING, NUMBER } }) }), OPT(STRING), OPT(NUMBER), OPT(NUMBER) } }), rets = a_type("tuple", { tuple = { STRING } }) }), + ["concat"] = a_fn({ args = { a_type("array", { elements = a_type("union", { types = { STRING, NUMBER } }) }), OPT(STRING), OPT(NUMBER), OPT(NUMBER) }, rets = { STRING } }), ["insert"] = a_type("poly", { types = { - a_gfunction(1, function(a) return { args = a_type("tuple", { tuple = { a_type("array", { elements = a }), NUMBER, a } }), rets = a_type("tuple", { tuple = {} }) } end), - a_gfunction(1, function(a) return { args = a_type("tuple", { tuple = { a_type("array", { elements = a }), a } }), rets = a_type("tuple", { tuple = {} }) } end), + a_gfunction(1, function(a) return { args = { a_type("array", { elements = a }), NUMBER, a }, rets = {} } end), + a_gfunction(1, function(a) return { args = { a_type("array", { elements = a }), a }, rets = {} } end), } }), ["move"] = a_type("poly", { types = { - a_gfunction(1, function(a) return { args = a_type("tuple", { tuple = { a_type("array", { elements = a }), NUMBER, NUMBER, NUMBER } }), rets = a_type("tuple", { tuple = { a_type("array", { elements = a }) } }) } end), - a_gfunction(1, function(a) return { args = a_type("tuple", { tuple = { a_type("array", { elements = a }), NUMBER, NUMBER, NUMBER, a_type("array", { elements = a }) } }), rets = a_type("tuple", { tuple = { a_type("array", { elements = a }) } }) } end), + a_gfunction(1, function(a) return { args = { a_type("array", { elements = a }), NUMBER, NUMBER, NUMBER }, rets = { a_type("array", { elements = a }) } } end), + a_gfunction(1, function(a) return { args = { a_type("array", { elements = a }), NUMBER, NUMBER, NUMBER, a_type("array", { elements = a }) }, rets = { a_type("array", { elements = a }) } } end), } }), - ["pack"] = a_function({ args = a_vararg({ ANY }), rets = a_type("tuple", { tuple = { TABLE } }) }), - ["remove"] = a_gfunction(1, function(a) return { args = a_type("tuple", { tuple = { a_type("array", { elements = a }), OPT(NUMBER) } }), rets = a_type("tuple", { tuple = { a } }) } end), - ["sort"] = a_gfunction(1, function(a) return { args = a_type("tuple", { tuple = { a_type("array", { elements = a }), OPT(TABLE_SORT_FUNCTION) } }), rets = a_type("tuple", { tuple = {} }) } end), - ["unpack"] = a_gfunction(1, function(a) return { needs_compat = true, args = a_type("tuple", { tuple = { a_type("array", { elements = a }), OPT(NUMBER), OPT(NUMBER) } }), rets = a_vararg({ a }) } end), + ["pack"] = a_fn({ args = va_args({ ANY }), rets = { TABLE } }), + ["remove"] = a_gfunction(1, function(a) return { args = { a_type("array", { elements = a }), OPT(NUMBER) }, rets = { a } } end), + ["sort"] = a_gfunction(1, function(a) return { args = { a_type("array", { elements = a }), OPT(TABLE_SORT_FUNCTION) }, rets = {} } end), + ["unpack"] = a_gfunction(1, function(a) return { needs_compat = true, args = { a_type("array", { elements = a }), OPT(NUMBER), OPT(NUMBER) }, rets = va_args({ a }) } end), }, }), ["utf8"] = a_record({ fields = { - ["char"] = a_function({ args = a_vararg({ NUMBER }), rets = a_type("tuple", { tuple = { STRING } }) }), + ["char"] = a_fn({ args = va_args({ NUMBER }), rets = { STRING } }), ["charpattern"] = STRING, - ["codepoint"] = a_function({ args = a_type("tuple", { tuple = { STRING, OPT(NUMBER), OPT(NUMBER) } }), rets = a_vararg({ INTEGER }) }), - ["codes"] = a_function({ args = a_type("tuple", { tuple = { STRING } }), rets = a_type("tuple", { tuple = { - a_function({ args = a_type("tuple", { tuple = { STRING, OPT(NUMBER) } }), rets = a_type("tuple", { tuple = { NUMBER, NUMBER } }) }), -} }), }), - ["len"] = a_function({ args = a_type("tuple", { tuple = { STRING, NUMBER, NUMBER } }), rets = a_type("tuple", { tuple = { INTEGER } }) }), - ["offset"] = a_function({ args = a_type("tuple", { tuple = { STRING, NUMBER, NUMBER } }), rets = a_type("tuple", { tuple = { INTEGER } }) }), + ["codepoint"] = a_fn({ args = { STRING, OPT(NUMBER), OPT(NUMBER) }, rets = va_args({ INTEGER }) }), + ["codes"] = a_fn({ args = { STRING }, rets = { + a_fn({ args = { STRING, OPT(NUMBER) }, rets = { NUMBER, NUMBER } }), + }, }), + ["len"] = a_fn({ args = { STRING, NUMBER, NUMBER }, rets = { INTEGER } }), + ["offset"] = a_fn({ args = { STRING, NUMBER, NUMBER }, rets = { INTEGER } }), }, }), ["_VERSION"] = STRING, @@ -6670,25 +6729,6 @@ tl.type_check = function(ast, opts) return u, store_errs and errs end - local function set_min_arity(f) - if f.min_arity then - return - end - local tuple = f.args.tuple - local n = #tuple - if f.args.is_va then - n = n - 1 - end - for i = n, 1, -1 do - if tuple[i].opt then - n = n - 1 - else - break - end - end - f.min_arity = n - end - local function show_arity(f) local nfargs = #f.args.tuple return f.min_arity < nfargs and @@ -6765,7 +6805,6 @@ tl.type_check = function(ast, opts) local copy = {} seen[orig_t] = copy - copy.opt = t.opt copy.typename = t.typename copy.filename = t.filename copy.x = t.x @@ -6822,7 +6861,6 @@ tl.type_check = function(ast, opts) end end - set_min_arity(t) copy.min_arity = t.min_arity copy.is_method = t.is_method copy.args, same = resolve(t.args, same) @@ -8187,8 +8225,6 @@ a.types[i], b.types[i]), } local errs = {} local aa, ba = a.args.tuple, b.args.tuple - set_min_arity(a) - set_min_arity(b) if (not b.args.is_va) and a.min_arity > b.min_arity then table.insert(errs, Err(a, "incompatible number of arguments: got " .. show_arity(a) .. " %s, expected " .. show_arity(b) .. " %s", a.args, b.args)) else @@ -8451,7 +8487,7 @@ a.types[i], b.types[i]), } local function resolve_for_call(func, args, is_method) if lax and is_unknown(func) then - func = a_function({ args = a_vararg({ UNKNOWN }), rets = a_vararg({ UNKNOWN }) }) + func = a_fn({ args = va_args({ UNKNOWN }), rets = va_args({ UNKNOWN }) }) end func = resolve_tuple_and_nominal(func) @@ -8771,7 +8807,6 @@ a.types[i], b.types[i]), } end end local wanted = #fargs - set_min_arity(f) if (passes == 1 and ((given <= wanted and given >= f.min_arity) or (f.args.is_va and given > wanted) or (lax and given <= wanted))) or @@ -9093,20 +9128,12 @@ a.types[i], b.types[i]), } end local function add_function_definition_for_recursion(node, fnargs) - assert(fnargs.typename == "tuple") - - - local args = a_type("tuple", { tuple = {} }) - args.is_va = fnargs.is_va - for _, fnarg in ipairs(fnargs.tuple) do - table.insert(args.tuple, fnarg) - end - - add_var(nil, node.name.tk, a_function({ + add_var(nil, node.name.tk, type_at(node, a_function({ + min_arity = node.min_arity, typeargs = node.typeargs, - args = args, + args = fnargs, rets = get_rets(node.rets), - })) + }))) end local function fail_unresolved() @@ -11048,14 +11075,12 @@ expand_type(node, values, elements) }) end_function_scope(node) - local t = ensure_fresh_typeargs(a_function({ - y = node.y, - x = node.x, + local t = type_at(node, ensure_fresh_typeargs(a_function({ + min_arity = node.min_arity, typeargs = node.typeargs, args = args, rets = get_rets(rets), - filename = filename, - })) + }))) add_var(node, node.name.tk, t) return t @@ -11079,15 +11104,13 @@ expand_type(node, values, elements) }) check_macroexp_arg_use(node.macrodef) - local t = ensure_fresh_typeargs(a_function({ - y = node.y, - x = node.x, + local t = type_at(node, ensure_fresh_typeargs(a_function({ + min_arity = node.macrodef.min_arity, typeargs = node.typeargs, args = args, rets = get_rets(rets), - filename = filename, macroexp = node.macrodef, - })) + }))) add_var(node, node.name.tk, t) return t @@ -11128,14 +11151,12 @@ expand_type(node, values, elements) }) return NONE end - add_global(node, node.name.tk, ensure_fresh_typeargs(a_function({ - y = node.y, - x = node.x, + add_global(node, node.name.tk, type_at(node, ensure_fresh_typeargs(a_function({ + min_arity = node.min_arity, typeargs = node.typeargs, args = args, rets = get_rets(rets), - filename = filename, - }))) + })))) return NONE end, @@ -11192,15 +11213,13 @@ expand_type(node, values, elements) }) add_var(nil, "self", selftype) end - local fn_type = ensure_fresh_typeargs(a_function({ - y = node.y, - x = node.x, + local fn_type = type_at(node, ensure_fresh_typeargs(a_function({ + min_arity = node.min_arity, is_method = node.is_method, typeargs = node.typeargs, args = args, rets = get_rets(rets), - filename = filename, - })) + }))) local open_t, open_v, owner_name = find_record_to_extend(node.fn_owner) local open_k = owner_name .. "." .. node.name.tk @@ -11267,14 +11286,12 @@ expand_type(node, values, elements) }) assert(rets.typename == "tuple") end_function_scope(node) - return ensure_fresh_typeargs(a_function({ - y = node.y, - x = node.x, + return type_at(node, ensure_fresh_typeargs(a_function({ + min_arity = node.min_arity, typeargs = node.typeargs, args = args, rets = rets, - filename = filename, - })) + }))) end, }, ["macroexp"] = { @@ -11295,14 +11312,12 @@ expand_type(node, values, elements) }) assert(rets.typename == "tuple") end_function_scope(node) - return ensure_fresh_typeargs(a_function({ - y = node.y, - x = node.x, + return type_at(node, ensure_fresh_typeargs(a_function({ + min_arity = node.min_arity, typeargs = node.typeargs, args = args, rets = rets, - filename = filename, - })) + }))) end, }, ["cast"] = { @@ -11743,9 +11758,6 @@ expand_type(node, values, elements) }) if node.tk == "..." then t = a_vararg({ t }) end - if node.opt then - t = OPT(t) - end add_var(node, node.tk, t).is_func_arg = true return t end, diff --git a/tl.tl b/tl.tl index 5073e4e03..7d7288457 100644 --- a/tl.tl +++ b/tl.tl @@ -1084,8 +1084,6 @@ local interface Type -- Lua compatibilty needs_compat: boolean - - opt: boolean -- optional arity end local record StringType @@ -1202,7 +1200,6 @@ local record TupleType is Type where self.typename == "tuple" - -- tuple is_va: boolean tuple: {Type} end @@ -1267,7 +1264,7 @@ local record FunctionType where self.typename == "function" is_method: boolean - min_arity: number + min_arity: integer args: TupleType rets: TupleType macroexp: Node @@ -1486,6 +1483,7 @@ local record Node key_parsed: KeyParsed typeargs: {TypeArgType} + min_arity: integer args: Node rets: TupleType body: Node @@ -1584,8 +1582,8 @@ local parse_type_list: function(ParseState, integer, ParseTypeListMode): integer local parse_expression: function(ParseState, integer): integer, Node, integer local parse_expression_and_tk: function(ps: ParseState, i: integer, tk: string): integer, Node local parse_statements: function(ParseState, integer, ? boolean): integer, Node -local parse_argument_list: function(ParseState, integer): integer, Node -local parse_argument_type_list: function(ParseState, integer): integer, TupleType, boolean +local parse_argument_list: function(ParseState, integer): integer, Node, integer +local parse_argument_type_list: function(ParseState, integer): integer, TupleType, boolean, integer local parse_type: function(ParseState, integer): integer, Type, integer local parse_newtype: function(ps: ParseState, i: integer): integer, Node local parse_interface_name: function(ps: ParseState, i: integer): integer, Type, integer @@ -1678,10 +1676,6 @@ local macroexp a_tuple(t: {Type}): TupleType return a_type("tuple", { tuple = t } as TupleType) end -local function c_tuple(t: {Type}): TupleType - return a_type("tuple", { tuple = t } as TupleType) -end - local macroexp a_union(t: {Type}): UnionType return a_type("union", { types = t } as UnionType) end @@ -1691,9 +1685,70 @@ local macroexp a_poly(t: {FunctionType}): PolyType end local function a_function(t: FunctionType): FunctionType + assert(t.min_arity) return a_type("function", t) end +local record Opt + where self.opttype + + opttype: Type +end + +local function OPT(t: Type): Opt + return { opttype = t } +end + +local record Args + is {Type|Opt} + + is_va: boolean +end + +local function va_args(args: Args): Args + args.is_va = true + return args +end + +local record FuncArgs + is HasTypeArgs + + args: Args + rets: Args + needs_compat: boolean +end + +local function a_fn(f: FuncArgs): FunctionType + local args_t = a_tuple {} + local tup = args_t.tuple + args_t.is_va = f.args.is_va + local min_arity = f.args.is_va and -1 or 0 + for _, a in ipairs(f.args) do + if a is Opt then + table.insert(tup, a.opttype) + else + table.insert(tup, a) + min_arity = min_arity + 1 + end + end + + local rets_t = a_tuple {} + tup = rets_t.tuple + rets_t.is_va = f.rets.is_va + for _, a in ipairs(f.rets) do + assert(a is Type) + table.insert(tup, a) + end + + return a_type("function", { + args = args_t, + rets = rets_t, + min_arity = min_arity, + needs_compat = f.needs_compat, + typeargs = f.typeargs, + } as FunctionType) +end + local function a_vararg(t: {Type}): TupleType local typ = a_tuple(t) typ.is_va = true @@ -1994,7 +2049,7 @@ local function parse_function_type(ps: ParseState, i: integer): integer, Functio i, typ.typeargs = parse_anglebracket_list(ps, i, parse_typearg) end if ps.tokens[i].tk == "(" then - i, typ.args, typ.is_method = parse_argument_type_list(ps, i) + i, typ.args, typ.is_method, typ.min_arity = parse_argument_type_list(ps, i) i, typ.rets = parse_return_types(ps, i) else typ.args = a_vararg { ANY } @@ -2014,19 +2069,6 @@ local simple_types: {string:Type} = { ["integer"] = INTEGER, } -local memoize_opt_types: {Type:Type} = {} - -local function OPT(t: Type): Type - if memoize_opt_types[t] then - return memoize_opt_types[t] - end - - local ot = shallow_copy_new_type(t) - ot.opt = true - memoize_opt_types[t] = ot - return ot -end - local function parse_simple_type_or_nominal(ps: ParseState, i: integer): integer, Type local tk = ps.tokens[i].tk local st = simple_types[tk] @@ -2194,7 +2236,7 @@ local function parse_function_args_rets_body(ps: ParseState, i: integer, node: N if ps.tokens[i].tk == "<" then i, node.typeargs = parse_anglebracket_list(ps, i, parse_typearg) end - i, node.args = parse_argument_list(ps, i) + i, node.args, node.min_arity = parse_argument_list(ps, i) i, node.rets = parse_return_types(ps, i) i, node.body = parse_statements(ps, i) end_at(node, ps.tokens[i]) @@ -2639,10 +2681,11 @@ local function parse_argument(ps: ParseState, i: integer): integer, Node, intege return i, node, 0 end -parse_argument_list = function(ps: ParseState, i: integer): integer, Node +parse_argument_list = function(ps: ParseState, i: integer): integer, Node, integer local node = new_node(ps.tokens, i, "argument_list") i, node = parse_bracket_list(ps, i, node, "(", ")", "sep", parse_argument) local opts = false + local min_arity = 0 for a, fnarg in ipairs(node) do if fnarg.tk == "..." then if a ~= #node then @@ -2653,9 +2696,11 @@ parse_argument_list = function(ps: ParseState, i: integer): integer, Node opts = true elseif opts then return fail(ps, i, "non-optional arguments cannot follow optional arguments") + else + min_arity = min_arity + 1 end end - return i, node + return i, node, min_arity end local record ArgumentInfo @@ -2663,6 +2708,7 @@ local record ArgumentInfo type: Type is_va: boolean is_self: boolean + opt: boolean end local function parse_argument_type(ps: ParseState, i: integer): integer, ArgumentInfo, integer @@ -2700,33 +2746,33 @@ local function parse_argument_type(ps: ParseState, i: integer): integer, Argumen is_va = true end - if opt then - typ = OPT(typ) - end - if argument_name == "self" then is_self = true end end - return i, { i = i, type = typ, is_va = is_va, is_self = is_self }, 0 + return i, { i = i, type = typ, is_va = is_va, is_self = is_self, opt = opt or is_va }, 0 end -parse_argument_type_list = function(ps: ParseState, i: integer): integer, Type, boolean +parse_argument_type_list = function(ps: ParseState, i: integer): integer, Type, boolean, integer local ars: {ArgumentInfo} = {} i = parse_bracket_list(ps, i, ars, "(", ")", "sep", parse_argument_type) local t, list = new_tuple(ps, i) local n = #ars + local min_arity = 0 for l, ar in ipairs(ars) do list[l] = ar.type if ar.is_va and l < n then fail(ps, ar.i, "'...' can only be last argument") end + if not ar.opt then + min_arity = min_arity + 1 + end end if n > 0 and ars[n].is_va then t.is_va = true end - return i, t, (n > 0 and ars[1].is_self) + return i, t, (n > 0 and ars[1].is_self), min_arity end local function parse_identifier(ps: ParseState, i: integer): integer, Node, integer @@ -2783,6 +2829,7 @@ local function parse_function(ps: ParseState, i: integer, fk: FunctionKind): int i = parse_function_args_rets_body(ps, i, fn) if fn.is_method then table.insert(fn.args, 1, { x = selfx, y = selfy, tk = "self", kind = "identifier", is_self = true }) + fn.min_arity = fn.min_arity + 1 end if not fn.name then @@ -3066,7 +3113,7 @@ local function parse_macroexp(ps: ParseState, istart: integer, iargs: integer): -- end local node = new_node(ps.tokens, istart, "macroexp") local i: integer - i, node.args = parse_argument_list(ps, iargs) + i, node.args, node.min_arity = parse_argument_list(ps, iargs) i, node.rets = parse_return_types(ps, i) i = verify_tk(ps, i, "return") i, node.exp = parse_expression(ps, i) @@ -3085,6 +3132,7 @@ local function parse_where_clause(ps: ParseState, i: integer): integer, Node node.args[1] = new_node(ps.tokens, i, "argument") node.args[1].tk = "self" node.args[1].argtype = selftype + node.min_arity = 1 node.rets = new_tuple(ps, i) node.rets.tuple[1] = BOOLEAN i, node.exp = parse_expression(ps, i) @@ -3167,6 +3215,7 @@ parse_record_body = function(ps: ParseState, i: integer, def: RecordLikeType, no local typ = new_type(ps, wstart, "function") as FunctionType typ.is_method = true + typ.min_arity = 1 typ.args = a_tuple { a_type("nominal", { y = typ.y, @@ -5155,10 +5204,10 @@ local INVALID = a_type("invalid", {} as InvalidType) local UNKNOWN = a_type("unknown", {}) local CIRCULAR_REQUIRE = a_type("circular_require", {}) -local FUNCTION = a_function { args = a_vararg { ANY }, rets = a_vararg { ANY } } +local FUNCTION = a_fn { args = va_args { ANY }, rets = va_args { ANY } } local NOMINAL_FILE = a_type("nominal", { names = {"FILE"} } as NominalType) -local XPCALL_MSGH_FUNCTION = a_function { args = a_tuple { ANY }, rets = a_tuple { } } +local XPCALL_MSGH_FUNCTION = a_fn { args = { ANY }, rets = { } } local USERDATA = ANY -- Placeholder for maybe having a userdata "primitive" type @@ -5518,7 +5567,7 @@ local function show_type_base(t: Type, short: boolean, seen: {Type:string}): str for i, v in ipairs(t.args.tuple) do if not t.is_method or i > 1 then table.insert(args, ((i == #t.args.tuple and t.args.is_va) and "...: " - or v.opt and "? " + or (i > t.min_arity) and "? " or "") .. show(v)) end end @@ -5817,14 +5866,14 @@ local function init_globals(lax: boolean): {string:Variable}, {string:Type} last_typeid = globals_typeid end - local function a_record(t: RecordType): Type + local function a_record(t: RecordType): RecordType t = a_type("record", t) assert(t.fields) t.field_order = sorted_keys(t.fields) return t end - local function a_gfunction(n: integer, f: function(...: TypeVarType): (FunctionType), typename?: TypeName): FunctionType + local function a_generic(n: integer, f: function(...: TypeVarType): (T)): T local typevars = {} local typeargs = {} local c = string.byte("A") - 1 @@ -5835,12 +5884,18 @@ local function init_globals(lax: boolean): {string:Variable}, {string:Type} typeargs[i] = a_type("typearg", { typearg = name } as TypeArgType) end local t = f(table.unpack(typevars)) - t.typeargs = typeargs - return a_type(typename or "function", t) + if t is FunctionType or t is RecordType then + t.typeargs = typeargs + end + return t + end + + local function a_gfunction(n: integer, f: function(...: TypeVarType): FuncArgs): FunctionType + return a_generic(n, function(...: TypeVarType): FunctionType return a_fn(f(...)) end) end - local function a_grecord(n: integer, f: function(...: Type): Type): Type - local t = a_gfunction(n, f, "record") as RecordType -- FIXME + local function a_grecord(n: integer, f: function(...: TypeVarType): RecordType): RecordType + local t = a_generic(n, f) t.field_order = sorted_keys(t.fields) return t end @@ -5853,7 +5908,7 @@ local function init_globals(lax: boolean): {string:Variable}, {string:Type} return t end - local type TypeConstructor = function({Type}):TupleType + local type TypeConstructor = function({Type}): Args local record ArgsRets ctor: TypeConstructor @@ -5861,12 +5916,16 @@ local function init_globals(lax: boolean): {string:Variable}, {string:Type} rets: {Type} end + local function id(x: T): T + return x + end + local file_reader_poly_types: {ArgsRets} = { - { ctor = a_vararg, args = { a_union { NUMBER, an_enum { "*a", "a", "*l", "l", "*L", "L" } } }, rets = { STRING } }, - { ctor = c_tuple, args = { an_enum { "*n", "n" } }, rets = { NUMBER, STRING } }, - { ctor = a_vararg, args = { a_union { NUMBER, an_enum { "*a", "a", "*l", "l", "*L", "L", "*n", "n" } } }, rets = { a_union { STRING, NUMBER } } }, - { ctor = a_vararg, args = { a_union { NUMBER, STRING } }, rets = { STRING } }, - { ctor = a_vararg, args = { }, rets = { STRING } }, + { ctor = va_args, args = { a_union { NUMBER, an_enum { "*a", "a", "*l", "l", "*L", "L" } } }, rets = { STRING } }, + { ctor = id, args = { an_enum { "*n", "n" } }, rets = { NUMBER, STRING } }, + { ctor = va_args, args = { a_union { NUMBER, an_enum { "*a", "a", "*l", "l", "*L", "L", "*n", "n" } } }, rets = { a_union { STRING, NUMBER } } }, + { ctor = va_args, args = { a_union { NUMBER, STRING } }, rets = { STRING } }, + { ctor = va_args, args = { }, rets = { STRING } }, } local function a_file_reader(fn: (function(ctor: TypeConstructor, args: {Type}, rets: {Type}): FunctionType)): Type @@ -5879,7 +5938,7 @@ local function init_globals(lax: boolean): {string:Variable}, {string:Type} return t end - local LOAD_FUNCTION = a_function { args = a_tuple {}, rets = a_tuple { STRING } } + local LOAD_FUNCTION = a_fn { args = {}, rets = { STRING } } local OS_DATE_TABLE = a_record { fields = { @@ -5916,12 +5975,12 @@ local function init_globals(lax: boolean): {string:Variable}, {string:Type} local DEBUG_HOOK_EVENT = an_enum { "call", "tail call", "return", "line", "count" } - local DEBUG_HOOK_FUNCTION = a_function { - args = a_tuple { DEBUG_HOOK_EVENT, INTEGER }, - rets = a_tuple {}, + local DEBUG_HOOK_FUNCTION = a_fn { + args = { DEBUG_HOOK_EVENT, INTEGER }, + rets = {}, } - local TABLE_SORT_FUNCTION = a_gfunction(1, function(a: Type): FunctionType return { args = a_tuple { a, a }, rets = a_tuple { BOOLEAN } } end) + local TABLE_SORT_FUNCTION = a_gfunction(1, function(a: Type): FuncArgs return { args = { a, a }, rets = { BOOLEAN } } end) local metatable_nominals: {NominalType} = {} @@ -5935,71 +5994,71 @@ local function init_globals(lax: boolean): {string:Variable}, {string:Type} ["..."] = a_vararg { STRING }, ["any"] = a_typedecl(ANY), ["arg"] = an_array(STRING), - ["assert"] = a_gfunction(2, function(a: Type, b: Type): FunctionType return { args = a_tuple { a, OPT(b) }, rets = a_tuple { a } } end), + ["assert"] = a_gfunction(2, function(a: Type, b: Type): FuncArgs return { args = { a, OPT(b) }, rets = { a } } end), ["collectgarbage"] = a_poly { - a_function { args = a_tuple { an_enum { "collect", "count", "stop", "restart" } }, rets = a_tuple { NUMBER } }, - a_function { args = a_tuple { an_enum { "step", "setpause", "setstepmul" }, NUMBER }, rets = a_tuple { NUMBER } }, - a_function { args = a_tuple { an_enum { "isrunning" } }, rets = a_tuple { BOOLEAN } }, - a_function { args = a_tuple { STRING, OPT(NUMBER) }, rets = a_tuple { a_union { BOOLEAN, NUMBER } } }, + a_fn { args = { an_enum { "collect", "count", "stop", "restart" } }, rets = { NUMBER } }, + a_fn { args = { an_enum { "step", "setpause", "setstepmul" }, NUMBER }, rets = { NUMBER } }, + a_fn { args = { an_enum { "isrunning" } }, rets = { BOOLEAN } }, + a_fn { args = { STRING, OPT(NUMBER) }, rets = { a_union { BOOLEAN, NUMBER } } }, }, - ["dofile"] = a_function { args = a_tuple { OPT(STRING) }, rets = a_vararg { ANY } }, - ["error"] = a_function { args = a_tuple { ANY, OPT(NUMBER) }, rets = a_tuple {} }, - ["getmetatable"] = a_gfunction(1, function(a: Type): FunctionType return { args = a_tuple { a }, rets = a_tuple { METATABLE(a) } } end), - ["ipairs"] = a_gfunction(1, function(a: Type): FunctionType return { args = a_tuple { an_array(a) }, rets = a_tuple { - a_function { args = a_tuple {}, rets = a_tuple { INTEGER, a } }, + ["dofile"] = a_fn { args = { OPT(STRING) }, rets = va_args { ANY } }, + ["error"] = a_fn { args = { ANY, OPT(NUMBER) }, rets = {} }, + ["getmetatable"] = a_gfunction(1, function(a: Type): FuncArgs return { args = { a }, rets = { METATABLE(a) } } end), + ["ipairs"] = a_gfunction(1, function(a: Type): FuncArgs return { args = { an_array(a) }, rets = { + a_fn { args = {}, rets = { INTEGER, a } }, } } end), - ["load"] = a_function { args = a_tuple { a_union { STRING, LOAD_FUNCTION }, OPT(STRING), OPT(STRING), OPT(TABLE) }, rets = a_tuple { FUNCTION, STRING } }, - ["loadfile"] = a_function { args = a_tuple { OPT(STRING), OPT(STRING), OPT(TABLE) }, rets = a_tuple { FUNCTION, STRING } }, + ["load"] = a_fn { args = { a_union { STRING, LOAD_FUNCTION }, OPT(STRING), OPT(STRING), OPT(TABLE) }, rets = { FUNCTION, STRING } }, + ["loadfile"] = a_fn { args = { OPT(STRING), OPT(STRING), OPT(TABLE) }, rets = { FUNCTION, STRING } }, ["next"] = a_poly { - a_gfunction(2, function(a: Type, b: Type): FunctionType return { args = a_tuple { a_map(a, b), OPT(a) }, rets = a_tuple { a, b } } end), - a_gfunction(1, function(a: Type): FunctionType return { args = a_tuple { an_array(a), OPT(a) }, rets = a_tuple { INTEGER, a } } end), + a_gfunction(2, function(a: Type, b: Type): FuncArgs return { args = { a_map(a, b), OPT(a) }, rets = { a, b } } end), + a_gfunction(1, function(a: Type): FuncArgs return { args = { an_array(a), OPT(a) }, rets = { INTEGER, a } } end), }, - ["pairs"] = a_gfunction(2, function(a: Type, b: Type): FunctionType return { args = a_tuple { a_map(a, b) }, rets = a_tuple { - a_function { args = a_tuple {}, rets = a_tuple { a, b } }, + ["pairs"] = a_gfunction(2, function(a: Type, b: Type): FuncArgs return { args = { a_map(a, b) }, rets = { + a_fn { args = {}, rets = { a, b } }, } } end), - ["pcall"] = a_function { args = a_vararg { FUNCTION, ANY }, rets = a_vararg { BOOLEAN, ANY } }, - ["xpcall"] = a_function { args = a_vararg { FUNCTION, XPCALL_MSGH_FUNCTION, ANY }, rets = a_vararg { BOOLEAN, ANY } }, - ["print"] = a_function { args = a_vararg { ANY }, rets = a_tuple {} }, - ["rawequal"] = a_function { args = a_tuple { ANY, ANY }, rets = a_tuple { BOOLEAN } }, - ["rawget"] = a_function { args = a_tuple { TABLE, ANY }, rets = a_tuple { ANY } }, - ["rawlen"] = a_function { args = a_tuple { a_union { TABLE, STRING } }, rets = a_tuple { INTEGER } }, + ["pcall"] = a_fn { args = va_args { FUNCTION, ANY }, rets = va_args { BOOLEAN, ANY } }, + ["xpcall"] = a_fn { args = va_args { FUNCTION, XPCALL_MSGH_FUNCTION, ANY }, rets = va_args { BOOLEAN, ANY } }, + ["print"] = a_fn { args = va_args { ANY }, rets = {} }, + ["rawequal"] = a_fn { args = { ANY, ANY }, rets = { BOOLEAN } }, + ["rawget"] = a_fn { args = { TABLE, ANY }, rets = { ANY } }, + ["rawlen"] = a_fn { args = { a_union { TABLE, STRING } }, rets = { INTEGER } }, ["rawset"] = a_poly { - a_gfunction(2, function(a: Type, b: Type): FunctionType return { args = a_tuple { a_map(a, b), a, b }, rets = a_tuple {} } end), - a_gfunction(1, function(a: Type): FunctionType return { args = a_tuple { an_array(a), NUMBER, a }, rets = a_tuple {} } end), - a_function { args = a_tuple { TABLE, ANY, ANY }, rets = a_tuple {} }, + a_gfunction(2, function(a: Type, b: Type): FuncArgs return { args = { a_map(a, b), a, b }, rets = {} } end), + a_gfunction(1, function(a: Type): FuncArgs return { args = { an_array(a), NUMBER, a }, rets = {} } end), + a_fn { args = { TABLE, ANY, ANY }, rets = {} }, }, - ["require"] = a_function { args = a_tuple { STRING }, rets = a_tuple {} }, + ["require"] = a_fn { args = { STRING }, rets = {} }, ["select"] = a_poly { - a_gfunction(1, function(a: Type): FunctionType return { args = a_vararg { NUMBER, a }, rets = a_tuple { a } } end), - a_function { args = a_vararg { NUMBER, ANY }, rets = a_tuple { ANY } }, - a_function { args = a_vararg { STRING, ANY }, rets = a_tuple { INTEGER } }, + a_gfunction(1, function(a: Type): FuncArgs return { args = va_args { NUMBER, a }, rets = { a } } end), + a_fn { args = va_args { NUMBER, ANY }, rets = { ANY } }, + a_fn { args = va_args { STRING, ANY }, rets = { INTEGER } }, }, - ["setmetatable"] = a_gfunction(1, function(a: Type): FunctionType return { args = a_tuple { a, METATABLE(a) }, rets = a_tuple { a } } end), + ["setmetatable"] = a_gfunction(1, function(a: Type): FuncArgs return { args = { a, METATABLE(a) }, rets = { a } } end), ["tonumber"] = a_poly { - a_function { args = a_tuple { ANY }, rets = a_tuple { NUMBER } }, - a_function { args = a_tuple { ANY, NUMBER }, rets = a_tuple { INTEGER } }, + a_fn { args = { ANY }, rets = { NUMBER } }, + a_fn { args = { ANY, NUMBER }, rets = { INTEGER } }, }, - ["tostring"] = a_function { args = a_tuple { ANY }, rets = a_tuple { STRING } }, - ["type"] = a_function { args = a_tuple { ANY }, rets = a_tuple { STRING } }, + ["tostring"] = a_fn { args = { ANY }, rets = { STRING } }, + ["type"] = a_fn { args = { ANY }, rets = { STRING } }, ["FILE"] = a_typedecl( a_record { is_userdata = true, fields = { - ["close"] = a_function { args = a_tuple { NOMINAL_FILE }, rets = a_tuple { BOOLEAN, STRING, INTEGER } }, - ["flush"] = a_function { args = a_tuple { NOMINAL_FILE }, rets = a_tuple {} }, + ["close"] = a_fn { args = { NOMINAL_FILE }, rets = { BOOLEAN, STRING, INTEGER } }, + ["flush"] = a_fn { args = { NOMINAL_FILE }, rets = {} }, ["lines"] = a_file_reader(function(ctor: TypeConstructor, args: {Type}, rets: {Type}): Type table.insert(args, 1, NOMINAL_FILE) - return a_function { args = ctor(args), rets = a_tuple { - a_function { args = a_tuple {}, rets = ctor(rets) }, + return a_fn { args = ctor(args), rets = { + a_fn { args = {}, rets = ctor(rets) }, } } end), ["read"] = a_file_reader(function(ctor: TypeConstructor, args: {Type}, rets: {Type}): Type table.insert(args, 1, NOMINAL_FILE) - return a_function { args = ctor(args), rets = ctor(rets) } + return a_fn { args = ctor(args), rets = ctor(rets) } end), - ["seek"] = a_function { args = a_tuple { NOMINAL_FILE, OPT(STRING), OPT(NUMBER) }, rets = a_tuple { INTEGER, STRING } }, - ["setvbuf"] = a_function { args = a_tuple { NOMINAL_FILE, STRING, OPT(NUMBER) }, rets = a_tuple {} }, - ["write"] = a_function { args = a_vararg { NOMINAL_FILE, a_union { STRING, NUMBER } }, rets = a_tuple { NOMINAL_FILE, STRING } }, + ["seek"] = a_fn { args = { NOMINAL_FILE, OPT(STRING), OPT(NUMBER) }, rets = { INTEGER, STRING } }, + ["setvbuf"] = a_fn { args = { NOMINAL_FILE, STRING, OPT(NUMBER) }, rets = {} }, + ["write"] = a_fn { args = va_args { NOMINAL_FILE, a_union { STRING, NUMBER } }, rets = { NOMINAL_FILE, STRING } }, -- TODO complete... }, meta_fields = { ["__close"] = FUNCTION }, @@ -6007,54 +6066,54 @@ local function init_globals(lax: boolean): {string:Variable}, {string:Type} } ), ["metatable"] = a_typedecl( - a_grecord(1, function(a: Type): RecordType return { + a_grecord(1, function(a: Type): RecordType return a_record { fields = { - ["__call"] = a_function { args = a_vararg { a, ANY }, rets = a_vararg { ANY } }, - ["__gc"] = a_function { args = a_tuple { a }, rets = a_tuple {} }, + ["__call"] = a_fn { args = va_args { a, ANY }, rets = va_args { ANY } }, + ["__gc"] = a_fn { args = { a }, rets = {} }, ["__index"] = ANY, -- FIXME: function | table | anything with an __index metamethod - ["__len"] = a_function { args = a_tuple { a }, rets = a_tuple { ANY } }, + ["__len"] = a_fn { args = { a }, rets = { ANY } }, ["__mode"] = an_enum { "k", "v", "kv" }, ["__newindex"] = ANY, -- FIXME: function | table | anything with a __newindex metamethod - ["__pairs"] = a_gfunction(2, function(k: Type, v: Type): FunctionType + ["__pairs"] = a_gfunction(2, function(k: Type, v: Type): FuncArgs return { - args = a_tuple { a }, - rets = a_tuple { a_function { args = a_tuple {}, rets = a_tuple { k, v } } } + args = { a }, + rets = { a_fn { args = {}, rets = { k, v } } } } end), - ["__tostring"] = a_function { args = a_tuple { a }, rets = a_tuple { STRING } }, + ["__tostring"] = a_fn { args = { a }, rets = { STRING } }, ["__name"] = STRING, - ["__add"] = a_function { args = a_tuple { ANY, ANY }, rets = a_tuple { ANY } }, - ["__sub"] = a_function { args = a_tuple { ANY, ANY }, rets = a_tuple { ANY } }, - ["__mul"] = a_function { args = a_tuple { ANY, ANY }, rets = a_tuple { ANY } }, - ["__div"] = a_function { args = a_tuple { ANY, ANY }, rets = a_tuple { ANY } }, - ["__idiv"] = a_function { args = a_tuple { ANY, ANY }, rets = a_tuple { ANY } }, - ["__mod"] = a_function { args = a_tuple { ANY, ANY }, rets = a_tuple { ANY } }, - ["__pow"] = a_function { args = a_tuple { ANY, ANY }, rets = a_tuple { ANY } }, - ["__unm"] = a_function { args = a_tuple { ANY }, rets = a_tuple { ANY } }, - ["__band"] = a_function { args = a_tuple { ANY, ANY }, rets = a_tuple { ANY } }, - ["__bor"] = a_function { args = a_tuple { ANY, ANY }, rets = a_tuple { ANY } }, - ["__bxor"] = a_function { args = a_tuple { ANY, ANY }, rets = a_tuple { ANY } }, - ["__bnot"] = a_function { args = a_tuple { ANY }, rets = a_tuple { ANY } }, - ["__shl"] = a_function { args = a_tuple { ANY, ANY }, rets = a_tuple { ANY } }, - ["__shr"] = a_function { args = a_tuple { ANY, ANY }, rets = a_tuple { ANY } }, - ["__concat"] = a_function { args = a_tuple { ANY, ANY }, rets = a_tuple { ANY } }, - ["__eq"] = a_function { args = a_tuple { ANY, ANY }, rets = a_tuple { BOOLEAN } }, - ["__lt"] = a_function { args = a_tuple { ANY, ANY }, rets = a_tuple { BOOLEAN } }, - ["__le"] = a_function { args = a_tuple { ANY, ANY }, rets = a_tuple { BOOLEAN } }, - ["__close"] = a_function { args = a_tuple { a }, rets = a_tuple { } }, + ["__add"] = a_fn { args = { ANY, ANY }, rets = { ANY } }, + ["__sub"] = a_fn { args = { ANY, ANY }, rets = { ANY } }, + ["__mul"] = a_fn { args = { ANY, ANY }, rets = { ANY } }, + ["__div"] = a_fn { args = { ANY, ANY }, rets = { ANY } }, + ["__idiv"] = a_fn { args = { ANY, ANY }, rets = { ANY } }, + ["__mod"] = a_fn { args = { ANY, ANY }, rets = { ANY } }, + ["__pow"] = a_fn { args = { ANY, ANY }, rets = { ANY } }, + ["__unm"] = a_fn { args = { ANY }, rets = { ANY } }, + ["__band"] = a_fn { args = { ANY, ANY }, rets = { ANY } }, + ["__bor"] = a_fn { args = { ANY, ANY }, rets = { ANY } }, + ["__bxor"] = a_fn { args = { ANY, ANY }, rets = { ANY } }, + ["__bnot"] = a_fn { args = { ANY }, rets = { ANY } }, + ["__shl"] = a_fn { args = { ANY, ANY }, rets = { ANY } }, + ["__shr"] = a_fn { args = { ANY, ANY }, rets = { ANY } }, + ["__concat"] = a_fn { args = { ANY, ANY }, rets = { ANY } }, + ["__eq"] = a_fn { args = { ANY, ANY }, rets = { BOOLEAN } }, + ["__lt"] = a_fn { args = { ANY, ANY }, rets = { BOOLEAN } }, + ["__le"] = a_fn { args = { ANY, ANY }, rets = { BOOLEAN } }, + ["__close"] = a_fn { args = { a }, rets = { } }, }, } end) ), ["coroutine"] = a_record { fields = { - ["create"] = a_function { args = a_tuple { FUNCTION }, rets = a_tuple { THREAD } }, - ["close"] = a_function { args = a_tuple { THREAD }, rets = a_tuple { BOOLEAN, STRING } }, - ["isyieldable"] = a_function { args = a_tuple {}, rets = a_tuple { BOOLEAN } }, - ["resume"] = a_function { args = a_vararg { THREAD, ANY }, rets = a_vararg { BOOLEAN, ANY } }, - ["running"] = a_function { args = a_tuple {}, rets = a_tuple { THREAD, BOOLEAN } }, - ["status"] = a_function { args = a_tuple { THREAD }, rets = a_tuple { STRING } }, - ["wrap"] = a_function { args = a_tuple { FUNCTION }, rets = a_tuple { FUNCTION } }, - ["yield"] = a_function { args = a_vararg { ANY }, rets = a_vararg { ANY } }, + ["create"] = a_fn { args = { FUNCTION }, rets = { THREAD } }, + ["close"] = a_fn { args = { THREAD }, rets = { BOOLEAN, STRING } }, + ["isyieldable"] = a_fn { args = {}, rets = { BOOLEAN } }, + ["resume"] = a_fn { args = va_args { THREAD, ANY }, rets = va_args { BOOLEAN, ANY } }, + ["running"] = a_fn { args = {}, rets = { THREAD, BOOLEAN } }, + ["status"] = a_fn { args = { THREAD }, rets = { STRING } }, + ["wrap"] = a_fn { args = { FUNCTION }, rets = { FUNCTION } }, + ["yield"] = a_fn { args = va_args { ANY }, rets = va_args { ANY } }, } }, ["debug"] = a_record { @@ -6063,141 +6122,141 @@ local function init_globals(lax: boolean): {string:Variable}, {string:Type} ["Hook"] = a_typedecl(DEBUG_HOOK_FUNCTION), ["HookEvent"] = a_typedecl(DEBUG_HOOK_EVENT), - ["debug"] = a_function { args = a_tuple {}, rets = a_tuple {} }, - ["gethook"] = a_function { args = a_tuple { OPT(THREAD) }, rets = a_tuple { DEBUG_HOOK_FUNCTION, INTEGER } }, + ["debug"] = a_fn { args = {}, rets = {} }, + ["gethook"] = a_fn { args = { OPT(THREAD) }, rets = { DEBUG_HOOK_FUNCTION, INTEGER } }, ["getlocal"] = a_poly { - a_function { args = a_tuple { THREAD, FUNCTION, NUMBER }, rets = a_tuple { STRING } }, - a_function { args = a_tuple { THREAD, NUMBER, NUMBER }, rets = a_tuple { STRING, ANY } }, - a_function { args = a_tuple { FUNCTION, NUMBER }, rets = a_tuple { STRING } }, - a_function { args = a_tuple { NUMBER, NUMBER }, rets = a_tuple { STRING, ANY } }, + a_fn { args = { THREAD, FUNCTION, NUMBER }, rets = { STRING } }, + a_fn { args = { THREAD, NUMBER, NUMBER }, rets = { STRING, ANY } }, + a_fn { args = { FUNCTION, NUMBER }, rets = { STRING } }, + a_fn { args = { NUMBER, NUMBER }, rets = { STRING, ANY } }, }, - ["getmetatable"] = a_gfunction(1, function(a: Type): FunctionType return { args = a_tuple { a }, rets = a_tuple { METATABLE(a) } } end), - ["getregistry"] = a_function { args = a_tuple {}, rets = a_tuple { TABLE } }, - ["getupvalue"] = a_function { args = a_tuple { FUNCTION, NUMBER }, rets = a_tuple { ANY } }, - ["getuservalue"] = a_function { args = a_tuple { USERDATA, NUMBER }, rets = a_tuple { ANY } }, + ["getmetatable"] = a_gfunction(1, function(a: Type): FuncArgs return { args = { a }, rets = { METATABLE(a) } } end), + ["getregistry"] = a_fn { args = {}, rets = { TABLE } }, + ["getupvalue"] = a_fn { args = { FUNCTION, NUMBER }, rets = { ANY } }, + ["getuservalue"] = a_fn { args = { USERDATA, NUMBER }, rets = { ANY } }, ["sethook"] = a_poly { - a_function { args = a_tuple { THREAD, DEBUG_HOOK_FUNCTION, STRING, NUMBER }, rets = a_tuple {} }, - a_function { args = a_tuple { DEBUG_HOOK_FUNCTION, STRING, NUMBER }, rets = a_tuple {} }, + a_fn { args = { THREAD, DEBUG_HOOK_FUNCTION, STRING, NUMBER }, rets = {} }, + a_fn { args = { DEBUG_HOOK_FUNCTION, STRING, NUMBER }, rets = {} }, }, ["setlocal"] = a_poly { - a_function { args = a_tuple { THREAD, NUMBER, NUMBER, ANY }, rets = a_tuple { STRING } }, - a_function { args = a_tuple { NUMBER, NUMBER, ANY }, rets = a_tuple { STRING } }, + a_fn { args = { THREAD, NUMBER, NUMBER, ANY }, rets = { STRING } }, + a_fn { args = { NUMBER, NUMBER, ANY }, rets = { STRING } }, }, - ["setmetatable"] = a_gfunction(1, function(a: Type): FunctionType return { args = a_tuple { a, METATABLE(a) }, rets = a_tuple { a } } end), - ["setupvalue"] = a_function { args = a_tuple { FUNCTION, NUMBER, ANY }, rets = a_tuple { STRING } }, - ["setuservalue"] = a_function { args = a_tuple { USERDATA, ANY, NUMBER }, rets = a_tuple { USERDATA } }, + ["setmetatable"] = a_gfunction(1, function(a: Type): FuncArgs return { args = { a, METATABLE(a) }, rets = { a } } end), + ["setupvalue"] = a_fn { args = { FUNCTION, NUMBER, ANY }, rets = { STRING } }, + ["setuservalue"] = a_fn { args = { USERDATA, ANY, NUMBER }, rets = { USERDATA } }, ["traceback"] = a_poly { - a_function { args = a_tuple { OPT(THREAD), OPT(STRING), OPT(NUMBER) }, rets = a_tuple { STRING } }, - a_function { args = a_tuple { OPT(STRING), OPT(NUMBER) }, rets = a_tuple { STRING } }, + a_fn { args = { OPT(THREAD), OPT(STRING), OPT(NUMBER) }, rets = { STRING } }, + a_fn { args = { OPT(STRING), OPT(NUMBER) }, rets = { STRING } }, }, - ["upvalueid"] = a_function { args = a_tuple { FUNCTION, NUMBER }, rets = a_tuple { USERDATA } }, - ["upvaluejoin"] = a_function { args = a_tuple { FUNCTION, NUMBER, FUNCTION, NUMBER }, rets = a_tuple {} }, + ["upvalueid"] = a_fn { args = { FUNCTION, NUMBER }, rets = { USERDATA } }, + ["upvaluejoin"] = a_fn { args = { FUNCTION, NUMBER, FUNCTION, NUMBER }, rets = {} }, ["getinfo"] = a_poly { - a_function { args = a_tuple { ANY }, rets = a_tuple { DEBUG_GETINFO_TABLE } }, - a_function { args = a_tuple { ANY, STRING }, rets = a_tuple { DEBUG_GETINFO_TABLE } }, - a_function { args = a_tuple { ANY, ANY, STRING }, rets = a_tuple { DEBUG_GETINFO_TABLE } }, + a_fn { args = { ANY }, rets = { DEBUG_GETINFO_TABLE } }, + a_fn { args = { ANY, STRING }, rets = { DEBUG_GETINFO_TABLE } }, + a_fn { args = { ANY, ANY, STRING }, rets = { DEBUG_GETINFO_TABLE } }, }, }, }, ["io"] = a_record { fields = { - ["close"] = a_function { args = a_tuple { OPT(NOMINAL_FILE) }, rets = a_tuple { BOOLEAN, STRING } }, - ["flush"] = a_function { args = a_tuple {}, rets = a_tuple {} }, - ["input"] = a_function { args = a_tuple { OPT(a_union { STRING, NOMINAL_FILE }) }, rets = a_tuple { NOMINAL_FILE } }, + ["close"] = a_fn { args = { OPT(NOMINAL_FILE) }, rets = { BOOLEAN, STRING } }, + ["flush"] = a_fn { args = {}, rets = {} }, + ["input"] = a_fn { args = { OPT(a_union { STRING, NOMINAL_FILE }) }, rets = { NOMINAL_FILE } }, ["lines"] = a_file_reader(function(ctor: TypeConstructor, args: {Type}, rets: {Type}): Type - return a_function { args = ctor(args), rets = a_tuple { - a_function { args = a_tuple {}, rets = ctor(rets) }, + return a_fn { args = ctor(args), rets = { + a_fn { args = {}, rets = ctor(rets) }, } } end), - ["open"] = a_function { args = a_tuple { STRING, OPT(STRING) }, rets = a_tuple { NOMINAL_FILE, STRING } }, - ["output"] = a_function { args = a_tuple { OPT(a_union { STRING, NOMINAL_FILE }) }, rets = a_tuple { NOMINAL_FILE } }, - ["popen"] = a_function { args = a_tuple { STRING, OPT(STRING) }, rets = a_tuple { NOMINAL_FILE, STRING } }, + ["open"] = a_fn { args = { STRING, OPT(STRING) }, rets = { NOMINAL_FILE, STRING } }, + ["output"] = a_fn { args = { OPT(a_union { STRING, NOMINAL_FILE }) }, rets = { NOMINAL_FILE } }, + ["popen"] = a_fn { args = { STRING, OPT(STRING) }, rets = { NOMINAL_FILE, STRING } }, ["read"] = a_file_reader(function(ctor: TypeConstructor, args: {Type}, rets: {Type}): Type - return a_function { args = ctor(args), rets = ctor(rets) } + return a_fn { args = ctor(args), rets = ctor(rets) } end), ["stderr"] = NOMINAL_FILE, ["stdin"] = NOMINAL_FILE, ["stdout"] = NOMINAL_FILE, - ["tmpfile"] = a_function { args = a_tuple {}, rets = a_tuple { NOMINAL_FILE } }, - ["type"] = a_function { args = a_tuple { ANY }, rets = a_tuple { STRING } }, - ["write"] = a_function { args = a_vararg { a_union { STRING, NUMBER } }, rets = a_tuple { NOMINAL_FILE, STRING } }, + ["tmpfile"] = a_fn { args = {}, rets = { NOMINAL_FILE } }, + ["type"] = a_fn { args = { ANY }, rets = { STRING } }, + ["write"] = a_fn { args = va_args { a_union { STRING, NUMBER } }, rets = { NOMINAL_FILE, STRING } }, }, }, ["math"] = a_record { fields = { ["abs"] = a_poly { - a_function { args = a_tuple { INTEGER }, rets = a_tuple { INTEGER } }, - a_function { args = a_tuple { NUMBER }, rets = a_tuple { NUMBER } }, + a_fn { args = { INTEGER }, rets = { INTEGER } }, + a_fn { args = { NUMBER }, rets = { NUMBER } }, }, - ["acos"] = a_function { args = a_tuple { NUMBER }, rets = a_tuple { NUMBER } }, - ["asin"] = a_function { args = a_tuple { NUMBER }, rets = a_tuple { NUMBER } }, - ["atan"] = a_function { args = a_tuple { NUMBER, OPT(NUMBER) }, rets = a_tuple { NUMBER } }, - ["atan2"] = a_function { args = a_tuple { NUMBER, NUMBER }, rets = a_tuple { NUMBER } }, - ["ceil"] = a_function { args = a_tuple { NUMBER }, rets = a_tuple { INTEGER } }, - ["cos"] = a_function { args = a_tuple { NUMBER }, rets = a_tuple { NUMBER } }, - ["cosh"] = a_function { args = a_tuple { NUMBER }, rets = a_tuple { NUMBER } }, - ["deg"] = a_function { args = a_tuple { NUMBER }, rets = a_tuple { NUMBER } }, - ["exp"] = a_function { args = a_tuple { NUMBER }, rets = a_tuple { NUMBER } }, - ["floor"] = a_function { args = a_tuple { NUMBER }, rets = a_tuple { INTEGER } }, + ["acos"] = a_fn { args = { NUMBER }, rets = { NUMBER } }, + ["asin"] = a_fn { args = { NUMBER }, rets = { NUMBER } }, + ["atan"] = a_fn { args = { NUMBER, OPT(NUMBER) }, rets = { NUMBER } }, + ["atan2"] = a_fn { args = { NUMBER, NUMBER }, rets = { NUMBER } }, + ["ceil"] = a_fn { args = { NUMBER }, rets = { INTEGER } }, + ["cos"] = a_fn { args = { NUMBER }, rets = { NUMBER } }, + ["cosh"] = a_fn { args = { NUMBER }, rets = { NUMBER } }, + ["deg"] = a_fn { args = { NUMBER }, rets = { NUMBER } }, + ["exp"] = a_fn { args = { NUMBER }, rets = { NUMBER } }, + ["floor"] = a_fn { args = { NUMBER }, rets = { INTEGER } }, ["fmod"] = a_poly { - a_function { args = a_tuple { INTEGER, INTEGER }, rets = a_tuple { INTEGER } }, - a_function { args = a_tuple { NUMBER, NUMBER }, rets = a_tuple { NUMBER } }, + a_fn { args = { INTEGER, INTEGER }, rets = { INTEGER } }, + a_fn { args = { NUMBER, NUMBER }, rets = { NUMBER } }, }, - ["frexp"] = a_function { args = a_tuple { NUMBER }, rets = a_tuple { NUMBER, NUMBER } }, + ["frexp"] = a_fn { args = { NUMBER }, rets = { NUMBER, NUMBER } }, ["huge"] = NUMBER, - ["ldexp"] = a_function { args = a_tuple { NUMBER, NUMBER }, rets = a_tuple { NUMBER } }, - ["log"] = a_function { args = a_tuple { NUMBER, OPT(NUMBER) }, rets = a_tuple { NUMBER } }, - ["log10"] = a_function { args = a_tuple { NUMBER }, rets = a_tuple { NUMBER } }, + ["ldexp"] = a_fn { args = { NUMBER, NUMBER }, rets = { NUMBER } }, + ["log"] = a_fn { args = { NUMBER, OPT(NUMBER) }, rets = { NUMBER } }, + ["log10"] = a_fn { args = { NUMBER }, rets = { NUMBER } }, ["max"] = a_poly { - a_function { args = a_vararg { INTEGER }, rets = a_tuple { INTEGER } }, - a_gfunction(1, function(a: Type): FunctionType return { args = a_vararg { a }, rets = a_tuple { a } } end), - a_function { args = a_vararg { a_union { NUMBER, INTEGER } }, rets = a_tuple { NUMBER } }, - a_function { args = a_vararg { ANY }, rets = a_tuple { ANY } }, + a_fn { args = va_args { INTEGER }, rets = { INTEGER } }, + a_gfunction(1, function(a: Type): FuncArgs return { args = va_args { a }, rets = { a } } end), + a_fn { args = va_args { a_union { NUMBER, INTEGER } }, rets = { NUMBER } }, + a_fn { args = va_args { ANY }, rets = { ANY } }, }, ["maxinteger"] = a_type("integer", { needs_compat = true }), ["min"] = a_poly { - a_function { args = a_vararg { INTEGER }, rets = a_tuple { INTEGER } }, - a_gfunction(1, function(a: Type): FunctionType return { args = a_vararg { a }, rets = a_tuple { a } } end), - a_function { args = a_vararg { a_union { NUMBER, INTEGER } }, rets = a_tuple { NUMBER } }, - a_function { args = a_vararg { ANY }, rets = a_tuple { ANY } }, + a_fn { args = va_args { INTEGER }, rets = { INTEGER } }, + a_gfunction(1, function(a: Type): FuncArgs return { args = va_args { a }, rets = { a } } end), + a_fn { args = va_args { a_union { NUMBER, INTEGER } }, rets = { NUMBER } }, + a_fn { args = va_args { ANY }, rets = { ANY } }, }, ["mininteger"] = a_type("integer", { needs_compat = true }), - ["modf"] = a_function { args = a_tuple { NUMBER }, rets = a_tuple { INTEGER, NUMBER } }, + ["modf"] = a_fn { args = { NUMBER }, rets = { INTEGER, NUMBER } }, ["pi"] = NUMBER, - ["pow"] = a_function { args = a_tuple { NUMBER, NUMBER }, rets = a_tuple { NUMBER } }, - ["rad"] = a_function { args = a_tuple { NUMBER }, rets = a_tuple { NUMBER } }, + ["pow"] = a_fn { args = { NUMBER, NUMBER }, rets = { NUMBER } }, + ["rad"] = a_fn { args = { NUMBER }, rets = { NUMBER } }, ["random"] = a_poly { - a_function { args = a_tuple { NUMBER, OPT(NUMBER) }, rets = a_tuple { INTEGER } }, - a_function { args = a_tuple {}, rets = a_tuple { NUMBER } }, + a_fn { args = { NUMBER, OPT(NUMBER) }, rets = { INTEGER } }, + a_fn { args = {}, rets = { NUMBER } }, }, - ["randomseed"] = a_function { args = a_tuple { NUMBER, NUMBER }, rets = a_tuple { INTEGER, INTEGER } }, - ["sin"] = a_function { args = a_tuple { NUMBER }, rets = a_tuple { NUMBER } }, - ["sinh"] = a_function { args = a_tuple { NUMBER }, rets = a_tuple { NUMBER } }, - ["sqrt"] = a_function { args = a_tuple { NUMBER }, rets = a_tuple { NUMBER } }, - ["tan"] = a_function { args = a_tuple { NUMBER }, rets = a_tuple { NUMBER } }, - ["tanh"] = a_function { args = a_tuple { NUMBER }, rets = a_tuple { NUMBER } }, - ["tointeger"] = a_function { args = a_tuple { ANY }, rets = a_tuple { INTEGER } }, - ["type"] = a_function { args = a_tuple { ANY }, rets = a_tuple { STRING } }, - ["ult"] = a_function { args = a_tuple { NUMBER, NUMBER }, rets = a_tuple { BOOLEAN } }, + ["randomseed"] = a_fn { args = { NUMBER, NUMBER }, rets = { INTEGER, INTEGER } }, + ["sin"] = a_fn { args = { NUMBER }, rets = { NUMBER } }, + ["sinh"] = a_fn { args = { NUMBER }, rets = { NUMBER } }, + ["sqrt"] = a_fn { args = { NUMBER }, rets = { NUMBER } }, + ["tan"] = a_fn { args = { NUMBER }, rets = { NUMBER } }, + ["tanh"] = a_fn { args = { NUMBER }, rets = { NUMBER } }, + ["tointeger"] = a_fn { args = { ANY }, rets = { INTEGER } }, + ["type"] = a_fn { args = { ANY }, rets = { STRING } }, + ["ult"] = a_fn { args = { NUMBER, NUMBER }, rets = { BOOLEAN } }, }, }, ["os"] = a_record { fields = { - ["clock"] = a_function { args = a_tuple {}, rets = a_tuple { NUMBER } }, + ["clock"] = a_fn { args = {}, rets = { NUMBER } }, ["date"] = a_poly { - a_function { args = a_tuple { }, rets = a_tuple { STRING } }, - a_function { args = a_tuple { an_enum { "!*t", "*t" }, OPT(NUMBER) }, rets = a_tuple { OS_DATE_TABLE } }, - a_function { args = a_tuple { OPT(STRING), OPT(NUMBER) }, rets = a_tuple { STRING } }, + a_fn { args = { }, rets = { STRING } }, + a_fn { args = { an_enum { "!*t", "*t" }, OPT(NUMBER) }, rets = { OS_DATE_TABLE } }, + a_fn { args = { OPT(STRING), OPT(NUMBER) }, rets = { STRING } }, }, - ["difftime"] = a_function { args = a_tuple { NUMBER, NUMBER }, rets = a_tuple { NUMBER } }, - ["execute"] = a_function { args = a_tuple { STRING }, rets = a_tuple { BOOLEAN, STRING, INTEGER } }, - ["exit"] = a_function { args = a_tuple { OPT(a_union { NUMBER, BOOLEAN }), OPT(BOOLEAN) }, rets = a_tuple {} }, - ["getenv"] = a_function { args = a_tuple { STRING }, rets = a_tuple { STRING } }, - ["remove"] = a_function { args = a_tuple { STRING }, rets = a_tuple { BOOLEAN, STRING } }, - ["rename"] = a_function { args = a_tuple { STRING, STRING}, rets = a_tuple { BOOLEAN, STRING } }, - ["setlocale"] = a_function { args = a_tuple { STRING, OPT(STRING) }, rets = a_tuple { STRING } }, - ["time"] = a_function { args = a_tuple { OPT(OS_DATE_TABLE) }, rets = a_tuple { INTEGER } }, - ["tmpname"] = a_function { args = a_tuple {}, rets = a_tuple { STRING } }, + ["difftime"] = a_fn { args = { NUMBER, NUMBER }, rets = { NUMBER } }, + ["execute"] = a_fn { args = { STRING }, rets = { BOOLEAN, STRING, INTEGER } }, + ["exit"] = a_fn { args = { OPT(a_union { NUMBER, BOOLEAN }), OPT(BOOLEAN) }, rets = {} }, + ["getenv"] = a_fn { args = { STRING }, rets = { STRING } }, + ["remove"] = a_fn { args = { STRING }, rets = { BOOLEAN, STRING } }, + ["rename"] = a_fn { args = { STRING, STRING}, rets = { BOOLEAN, STRING } }, + ["setlocale"] = a_fn { args = { STRING, OPT(STRING) }, rets = { STRING } }, + ["time"] = a_fn { args = { OPT(OS_DATE_TABLE) }, rets = { INTEGER } }, + ["tmpname"] = a_fn { args = {}, rets = { STRING } }, }, }, ["package"] = a_record { @@ -6205,75 +6264,75 @@ local function init_globals(lax: boolean): {string:Variable}, {string:Type} ["config"] = STRING, ["cpath"] = STRING, ["loaded"] = a_map(STRING, ANY), - ["loaders"] = an_array(a_function { args = a_tuple { STRING }, rets = a_tuple { ANY, ANY } }), - ["loadlib"] = a_function { args = a_tuple { STRING, STRING }, rets = a_tuple { FUNCTION } }, + ["loaders"] = an_array(a_fn { args = { STRING }, rets = { ANY, ANY } }), + ["loadlib"] = a_fn { args = { STRING, STRING }, rets = { FUNCTION } }, ["path"] = STRING, ["preload"] = TABLE, - ["searchers"] = an_array(a_function { args = a_tuple { STRING }, rets = a_tuple { ANY, ANY } }), - ["searchpath"] = a_function { args = a_tuple { STRING, STRING, OPT(STRING), OPT(STRING) }, rets = a_tuple { STRING, STRING } }, + ["searchers"] = an_array(a_fn { args = { STRING }, rets = { ANY, ANY } }), + ["searchpath"] = a_fn { args = { STRING, STRING, OPT(STRING), OPT(STRING) }, rets = { STRING, STRING } }, }, }, ["string"] = a_record { fields = { ["byte"] = a_poly { - a_function { args = a_tuple { STRING, OPT(NUMBER) }, rets = a_tuple { INTEGER } }, - a_function { args = a_tuple { STRING, NUMBER, NUMBER }, rets = a_vararg { INTEGER } }, + a_fn { args = { STRING, OPT(NUMBER) }, rets = { INTEGER } }, + a_fn { args = { STRING, NUMBER, NUMBER }, rets = va_args { INTEGER } }, }, - ["char"] = a_function { args = a_vararg { NUMBER }, rets = a_tuple { STRING } }, - ["dump"] = a_function { args = a_tuple { FUNCTION, OPT(BOOLEAN) }, rets = a_tuple { STRING } }, - ["find"] = a_function { args = a_tuple { STRING, STRING, OPT(NUMBER), OPT(BOOLEAN) }, rets = a_vararg { INTEGER, INTEGER, STRING } }, - ["format"] = a_function { args = a_vararg { STRING, ANY }, rets = a_tuple { STRING } }, - ["gmatch"] = a_function { args = a_tuple { STRING, STRING }, rets = a_tuple { - a_function { args = a_tuple {}, rets = a_vararg { STRING } }, + ["char"] = a_fn { args = va_args { NUMBER }, rets = { STRING } }, + ["dump"] = a_fn { args = { FUNCTION, OPT(BOOLEAN) }, rets = { STRING } }, + ["find"] = a_fn { args = { STRING, STRING, OPT(NUMBER), OPT(BOOLEAN) }, rets = va_args { INTEGER, INTEGER, STRING } }, + ["format"] = a_fn { args = va_args { STRING, ANY }, rets = { STRING } }, + ["gmatch"] = a_fn { args = { STRING, STRING }, rets = { + a_fn { args = {}, rets = va_args { STRING } }, } }, ["gsub"] = a_poly { - a_function { args = a_tuple { STRING, STRING, a_map(STRING, STRING), OPT(NUMBER) }, rets = a_tuple { STRING, INTEGER } }, - a_function { args = a_tuple { STRING, STRING, a_function { args = a_vararg { STRING }, rets = a_tuple { STRING } }, OPT(NUMBER) }, rets = a_tuple { STRING, INTEGER } }, - a_function { args = a_tuple { STRING, STRING, a_function { args = a_vararg { STRING }, rets = a_tuple { NUMBER } }, OPT(NUMBER) }, rets = a_tuple { STRING, INTEGER } }, - a_function { args = a_tuple { STRING, STRING, a_function { args = a_vararg { STRING }, rets = a_tuple { BOOLEAN } }, OPT(NUMBER) }, rets = a_tuple { STRING, INTEGER } }, - a_function { args = a_tuple { STRING, STRING, a_function { args = a_vararg { STRING }, rets = a_tuple {} }, OPT(NUMBER) }, rets = a_tuple { STRING, INTEGER } }, - a_function { args = a_tuple { STRING, STRING, OPT(STRING), OPT(NUMBER) }, rets = a_tuple { STRING, INTEGER } }, + a_fn { args = { STRING, STRING, a_map(STRING, STRING), OPT(NUMBER) }, rets = { STRING, INTEGER } }, + a_fn { args = { STRING, STRING, a_fn { args = va_args { STRING }, rets = { STRING } }, OPT(NUMBER) }, rets = { STRING, INTEGER } }, + a_fn { args = { STRING, STRING, a_fn { args = va_args { STRING }, rets = { NUMBER } }, OPT(NUMBER) }, rets = { STRING, INTEGER } }, + a_fn { args = { STRING, STRING, a_fn { args = va_args { STRING }, rets = { BOOLEAN } }, OPT(NUMBER) }, rets = { STRING, INTEGER } }, + a_fn { args = { STRING, STRING, a_fn { args = va_args { STRING }, rets = {} }, OPT(NUMBER) }, rets = { STRING, INTEGER } }, + a_fn { args = { STRING, STRING, OPT(STRING), OPT(NUMBER) }, rets = { STRING, INTEGER } }, -- FIXME any other modes }, - ["len"] = a_function { args = a_tuple { STRING }, rets = a_tuple { INTEGER } }, - ["lower"] = a_function { args = a_tuple { STRING }, rets = a_tuple { STRING } }, - ["match"] = a_function { args = a_tuple { STRING, OPT(STRING), OPT(NUMBER) }, rets = a_vararg { STRING } }, - ["pack"] = a_function { args = a_vararg { STRING, ANY }, rets = a_tuple { STRING } }, - ["packsize"] = a_function { args = a_tuple { STRING }, rets = a_tuple { INTEGER } }, - ["rep"] = a_function { args = a_tuple { STRING, NUMBER, OPT(STRING) }, rets = a_tuple { STRING } }, - ["reverse"] = a_function { args = a_tuple { STRING }, rets = a_tuple { STRING } }, - ["sub"] = a_function { args = a_tuple { STRING, NUMBER, OPT(NUMBER) }, rets = a_tuple { STRING } }, - ["unpack"] = a_function { args = a_tuple { STRING, STRING, OPT(NUMBER) }, rets = a_vararg { ANY } }, - ["upper"] = a_function { args = a_tuple { STRING }, rets = a_tuple { STRING } }, + ["len"] = a_fn { args = { STRING }, rets = { INTEGER } }, + ["lower"] = a_fn { args = { STRING }, rets = { STRING } }, + ["match"] = a_fn { args = { STRING, OPT(STRING), OPT(NUMBER) }, rets = va_args { STRING } }, + ["pack"] = a_fn { args = va_args { STRING, ANY }, rets = { STRING } }, + ["packsize"] = a_fn { args = { STRING }, rets = { INTEGER } }, + ["rep"] = a_fn { args = { STRING, NUMBER, OPT(STRING) }, rets = { STRING } }, + ["reverse"] = a_fn { args = { STRING }, rets = { STRING } }, + ["sub"] = a_fn { args = { STRING, NUMBER, OPT(NUMBER) }, rets = { STRING } }, + ["unpack"] = a_fn { args = { STRING, STRING, OPT(NUMBER) }, rets = va_args { ANY } }, + ["upper"] = a_fn { args = { STRING }, rets = { STRING } }, }, }, ["table"] = a_record { fields = { - ["concat"] = a_function { args = a_tuple { an_array(a_union {STRING, NUMBER }), OPT(STRING), OPT(NUMBER), OPT(NUMBER) }, rets = a_tuple { STRING } }, + ["concat"] = a_fn { args = { an_array(a_union {STRING, NUMBER }), OPT(STRING), OPT(NUMBER), OPT(NUMBER) }, rets = { STRING } }, ["insert"] = a_poly { - a_gfunction(1, function(a: Type): FunctionType return { args = a_tuple { an_array(a), NUMBER, a }, rets = a_tuple {} } end), - a_gfunction(1, function(a: Type): FunctionType return { args = a_tuple { an_array(a), a }, rets = a_tuple {} } end), + a_gfunction(1, function(a: Type): FuncArgs return { args = { an_array(a), NUMBER, a }, rets = {} } end), + a_gfunction(1, function(a: Type): FuncArgs return { args = { an_array(a), a }, rets = {} } end), }, ["move"] = a_poly { - a_gfunction(1, function(a: Type): FunctionType return { args = a_tuple { an_array(a), NUMBER, NUMBER, NUMBER }, rets = a_tuple { an_array(a) } }end ), - a_gfunction(1, function(a: Type): FunctionType return { args = a_tuple { an_array(a), NUMBER, NUMBER, NUMBER, an_array(a) }, rets = a_tuple { an_array(a) } } end), + a_gfunction(1, function(a: Type): FuncArgs return { args = { an_array(a), NUMBER, NUMBER, NUMBER }, rets = { an_array(a) } }end ), + a_gfunction(1, function(a: Type): FuncArgs return { args = { an_array(a), NUMBER, NUMBER, NUMBER, an_array(a) }, rets = { an_array(a) } } end), }, - ["pack"] = a_function { args = a_vararg { ANY }, rets = a_tuple { TABLE } }, - ["remove"] = a_gfunction(1, function(a: Type): FunctionType return { args = a_tuple { an_array(a), OPT(NUMBER) }, rets = a_tuple { a } } end), - ["sort"] = a_gfunction(1, function(a: Type): FunctionType return { args = a_tuple { an_array(a), OPT(TABLE_SORT_FUNCTION) }, rets = a_tuple {} } end), - ["unpack"] = a_gfunction(1, function(a: Type): FunctionType return { needs_compat = true, args = a_tuple { an_array(a), OPT(NUMBER), OPT(NUMBER) }, rets = a_vararg { a } } end), + ["pack"] = a_fn { args = va_args { ANY }, rets = { TABLE } }, + ["remove"] = a_gfunction(1, function(a: Type): FuncArgs return { args = { an_array(a), OPT(NUMBER) }, rets = { a } } end), + ["sort"] = a_gfunction(1, function(a: Type): FuncArgs return { args = { an_array(a), OPT(TABLE_SORT_FUNCTION) }, rets = {} } end), + ["unpack"] = a_gfunction(1, function(a: Type): FuncArgs return { needs_compat = true, args = { an_array(a), OPT(NUMBER), OPT(NUMBER) }, rets = va_args { a } } end), }, }, ["utf8"] = a_record { fields = { - ["char"] = a_function { args = a_vararg { NUMBER }, rets = a_tuple { STRING } }, + ["char"] = a_fn { args = va_args { NUMBER }, rets = { STRING } }, ["charpattern"] = STRING, - ["codepoint"] = a_function { args = a_tuple { STRING, OPT(NUMBER), OPT(NUMBER) }, rets = a_vararg { INTEGER } }, - ["codes"] = a_function { args = a_tuple { STRING }, rets = a_tuple { - a_function { args = a_tuple { STRING, OPT(NUMBER) }, rets = a_tuple { NUMBER, NUMBER } }, + ["codepoint"] = a_fn { args = { STRING, OPT(NUMBER), OPT(NUMBER) }, rets = va_args { INTEGER } }, + ["codes"] = a_fn { args = { STRING }, rets = { + a_fn { args = { STRING, OPT(NUMBER) }, rets = { NUMBER, NUMBER } }, }, }, - ["len"] = a_function { args = a_tuple { STRING, NUMBER, NUMBER }, rets = a_tuple { INTEGER } }, - ["offset"] = a_function { args = a_tuple { STRING, NUMBER, NUMBER }, rets = a_tuple { INTEGER } }, + ["len"] = a_fn { args = { STRING, NUMBER, NUMBER }, rets = { INTEGER } }, + ["offset"] = a_fn { args = { STRING, NUMBER, NUMBER }, rets = { INTEGER } }, }, }, ["_VERSION"] = STRING, @@ -6670,25 +6729,6 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return u, store_errs and errs end - local function set_min_arity(f: FunctionType) - if f.min_arity then - return - end - local tuple = f.args.tuple - local n = #tuple - if f.args.is_va then - n = n - 1 - end - for i = n, 1, -1 do - if tuple[i].opt then - n = n - 1 - else - break - end - end - f.min_arity = n - end - local function show_arity(f: FunctionType): string local nfargs = #f.args.tuple return f.min_arity < nfargs @@ -6765,7 +6805,6 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local copy: Type = {} seen[orig_t] = copy - copy.opt = t.opt copy.typename = t.typename copy.filename = t.filename copy.x = t.x @@ -6822,7 +6861,6 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - set_min_arity(t) copy.min_arity = t.min_arity copy.is_method = t.is_method copy.args, same = resolve(t.args, same) as (TupleType, boolean) @@ -8187,8 +8225,6 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local errs = {} local aa, ba = a.args.tuple, b.args.tuple - set_min_arity(a) - set_min_arity(b) if (not b.args.is_va) and a.min_arity > b.min_arity then table.insert(errs, Err(a, "incompatible number of arguments: got " .. show_arity(a) .. " %s, expected " .. show_arity(b) .. " %s", a.args, b.args)) else @@ -8451,7 +8487,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local function resolve_for_call(func: Type, args: TupleType, is_method: boolean): Type, boolean -- resolve unknown in lax mode, produce a general unknown function if lax and is_unknown(func) then - func = a_function { args = a_vararg { UNKNOWN }, rets = a_vararg { UNKNOWN } } + func = a_fn { args = va_args { UNKNOWN }, rets = va_args { UNKNOWN } } end -- unwrap if tuple, resolve if nominal func = resolve_tuple_and_nominal(func) @@ -8771,7 +8807,6 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end local wanted = #fargs - set_min_arity(f) -- simple functions: if (passes == 1 and ((given <= wanted and given >= f.min_arity) or (f.args.is_va and given > wanted) or (lax and given <= wanted))) @@ -9093,20 +9128,12 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end local function add_function_definition_for_recursion(node: Node, fnargs: TupleType) - assert(fnargs.typename == "tuple") - - -- FIXME needs this copy? - local args = a_tuple({}) - args.is_va = fnargs.is_va - for _, fnarg in ipairs(fnargs.tuple) do - table.insert(args.tuple, fnarg) - end - - add_var(nil, node.name.tk, a_function { + add_var(nil, node.name.tk, type_at(node, a_function { + min_arity = node.min_arity, typeargs = node.typeargs, - args = args, + args = fnargs, rets = get_rets(node.rets), - }) + })) end local function fail_unresolved() @@ -11048,14 +11075,12 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end_function_scope(node) - local t = ensure_fresh_typeargs(a_function { - y = node.y, - x = node.x, + local t = type_at(node, ensure_fresh_typeargs(a_function { + min_arity = node.min_arity, typeargs = node.typeargs, args = args, rets = get_rets(rets), - filename = filename, - }) + })) add_var(node, node.name.tk, t) return t @@ -11079,15 +11104,13 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string check_macroexp_arg_use(node.macrodef) - local t = ensure_fresh_typeargs(a_function { - y = node.y, - x = node.x, + local t = type_at(node, ensure_fresh_typeargs(a_function { + min_arity = node.macrodef.min_arity, typeargs = node.typeargs, args = args, rets = get_rets(rets), - filename = filename, macroexp = node.macrodef, - }) + })) add_var(node, node.name.tk, t) return t @@ -11128,14 +11151,12 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return NONE end - add_global(node, node.name.tk, ensure_fresh_typeargs(a_function { - y = node.y, - x = node.x, + add_global(node, node.name.tk, type_at(node, ensure_fresh_typeargs(a_function { + min_arity = node.min_arity, typeargs = node.typeargs, args = args, rets = get_rets(rets), - filename = filename, - })) + }))) return NONE end, @@ -11192,15 +11213,13 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string add_var(nil, "self", selftype) end - local fn_type = ensure_fresh_typeargs(a_function { - y = node.y, - x = node.x, + local fn_type = type_at(node, ensure_fresh_typeargs(a_function { + min_arity = node.min_arity, is_method = node.is_method, typeargs = node.typeargs, args = args, rets = get_rets(rets), - filename = filename, - }) + })) local open_t, open_v, owner_name = find_record_to_extend(node.fn_owner) local open_k = owner_name .. "." .. node.name.tk @@ -11267,14 +11286,12 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string assert(rets is TupleType) end_function_scope(node) - return ensure_fresh_typeargs(a_function { - y = node.y, - x = node.x, + return type_at(node, ensure_fresh_typeargs(a_function { + min_arity = node.min_arity, typeargs = node.typeargs, args = args, rets = rets, - filename = filename, - }) + })) end, }, ["macroexp"] = { @@ -11295,14 +11312,12 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string assert(rets is TupleType) end_function_scope(node) - return ensure_fresh_typeargs(a_function { - y = node.y, - x = node.x, + return type_at(node, ensure_fresh_typeargs(a_function { + min_arity = node.min_arity, typeargs = node.typeargs, args = args, rets = rets, - filename = filename, - }) + })) end, }, ["cast"] = { @@ -11743,9 +11758,6 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if node.tk == "..." then t = a_vararg { t } end - if node.opt then - t = OPT(t) - end add_var(node, node.tk, t).is_func_arg = true return t end, From 20fa3d870f7d73ac64765a81e1cfa7e091b85534 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Mon, 8 Jan 2024 01:28:13 -0300 Subject: [PATCH 091/224] tl.new_env: simpler, more extensible API --- tl.lua | 20 ++++++++++++++++++-- tl.tl | 20 ++++++++++++++++++-- 2 files changed, 36 insertions(+), 4 deletions(-) diff --git a/tl.lua b/tl.lua index 67dee5ec4..985d62deb 100644 --- a/tl.lua +++ b/tl.lua @@ -1,7 +1,15 @@ local _tl_compat; if (tonumber((_VERSION or ''):match('[%d.]*$')) or 0) < 5.3 then local p, m = pcall(require, 'compat53.module'); if p then _tl_compat = m end end; local assert = _tl_compat and _tl_compat.assert or assert; local debug = _tl_compat and _tl_compat.debug or debug; local io = _tl_compat and _tl_compat.io or io; local ipairs = _tl_compat and _tl_compat.ipairs or ipairs; local load = _tl_compat and _tl_compat.load or load; local math = _tl_compat and _tl_compat.math or math; local _tl_math_maxinteger = math.maxinteger or math.pow(2, 53); local os = _tl_compat and _tl_compat.os or os; local package = _tl_compat and _tl_compat.package or package; local pairs = _tl_compat and _tl_compat.pairs or pairs; local string = _tl_compat and _tl_compat.string or string; local table = _tl_compat and _tl_compat.table or table; local _tl_table_unpack = unpack or table.unpack; local utf8 = _tl_compat and _tl_compat.utf8 or utf8 local VERSION = "0.15.3+dev" -local tl = {PrettyPrintOptions = {}, TypeCheckOptions = {}, Env = {}, Symbol = {}, Result = {}, Error = {}, TypeInfo = {}, TypeReport = {}, TypeReportEnv = {}, } +local tl = {PrettyPrintOptions = {}, TypeCheckOptions = {}, Env = {}, Symbol = {}, Result = {}, Error = {}, TypeInfo = {}, TypeReport = {}, TypeReportEnv = {}, EnvOptions = {}, } + + + + + + + + @@ -6359,6 +6367,15 @@ a_grecord(1, function(a) return a_record({ return globals, standard_library end +tl.new_env = function(opts) + local env, err = tl.init_env(opts.lax_mode, opts.gen_compat, opts.gen_target, opts.predefined_modules) + if not env then + return nil, err + end + + return env +end + tl.init_env = function(lax, gen_compat, gen_target, predefined) if gen_compat == true or gen_compat == nil then gen_compat = "optional" @@ -6382,7 +6399,6 @@ tl.init_env = function(lax, gen_compat, gen_target, predefined) local globals, standard_library = init_globals(lax) local env = { - ok = true, modules = {}, loaded = {}, loaded_order = {}, diff --git a/tl.tl b/tl.tl index 7d7288457..4a08655ad 100644 --- a/tl.tl +++ b/tl.tl @@ -134,11 +134,19 @@ local record tl tr: TypeReport end + record EnvOptions + lax_mode: boolean + gen_compat: CompatMode + gen_target: TargetMode + predefined_modules: {string} + end + load: function(string, string, LoadMode, {any:any}): LoadFunction, string process: function(string, Env, string, FILE): (Result, string) process_string: function(string, boolean, Env, ? string, ? string): Result gen: function(string, Env, PrettyPrintOptions): string, Result type_check: function(Node, TypeCheckOptions): Result, string + new_env: function(EnvOptions): Env, string init_env: function(? boolean, ? boolean | CompatMode, ? TargetMode, ? {string}): Env, string version: function(): string @@ -6359,6 +6367,15 @@ local function init_globals(lax: boolean): {string:Variable}, {string:Type} return globals, standard_library end +tl.new_env = function(opts: tl.EnvOptions): Env, string + local env, err = tl.init_env(opts.lax_mode, opts.gen_compat, opts.gen_target, opts.predefined_modules) + if not env then + return nil, err + end + + return env +end + tl.init_env = function(lax?: boolean, gen_compat?: boolean | CompatMode, gen_target?: TargetMode, predefined?: {string}): Env, string if gen_compat == true or gen_compat == nil then gen_compat = "optional" @@ -6381,8 +6398,7 @@ tl.init_env = function(lax?: boolean, gen_compat?: boolean | CompatMode, gen_tar local globals, standard_library = init_globals(lax) - local env = { - ok = true, + local env: Env = { modules = {}, loaded = {}, loaded_order = {}, From 59e5fe1c1aaddabc1496c885cf4d5fb9d6722cc2 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Mon, 8 Jan 2024 01:28:48 -0300 Subject: [PATCH 092/224] tl: use new_env --- tl | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/tl b/tl index 318f93b66..114eb56e1 100755 --- a/tl +++ b/tl @@ -137,15 +137,19 @@ local function setup_env(tlconfig, filename) lax_mode = false end - local gen_compat = tlconfig["gen_compat"] - local gen_target = tlconfig["gen_target"] - tlconfig._init_env_modules = tlconfig._init_env_modules or {} if tlconfig.global_env_def then table.insert(tlconfig._init_env_modules, 1, tlconfig.global_env_def) end - local env, err = tl.init_env(lax_mode, gen_compat, gen_target, tlconfig._init_env_modules) + local opts = { + lax_mode = lax_mode, + gen_compat = tlconfig["gen_compat"], + gen_target = tlconfig["gen_target"], + predefined_modules = tlconfig._init_env_modules, + } + + local env, err = tl.new_env(opts) if not env then die(err) end From 94cf23c2538b42fe4a9a7c7c69c428466afde318 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Mon, 8 Jan 2024 01:57:37 -0300 Subject: [PATCH 093/224] language flag: feat-arity=on/off Enable or disable the behavior of '?' argument annotations, providing backwards compatibility with existing codebases. --- spec/cli/feat_spec.lua | 73 ++++++++++++++++++++++++++++++++++++++++++ tl | 7 ++++ tl.lua | 25 +++++++++++++-- tl.tl | 25 +++++++++++++-- 4 files changed, 126 insertions(+), 4 deletions(-) create mode 100644 spec/cli/feat_spec.lua diff --git a/spec/cli/feat_spec.lua b/spec/cli/feat_spec.lua new file mode 100644 index 000000000..286033387 --- /dev/null +++ b/spec/cli/feat_spec.lua @@ -0,0 +1,73 @@ +local assert = require("luassert") +local util = require("spec.util") + +local test_cases = { + ["feat-arity"] = { + { + code = [[ + -- without '?' annotation + local function add(a: number, b: number): number + return a + (b or 0) + end + + print(add()) + print(add(10)) + print(add(10, 20)) + print(add(10, 20, 30)) + + -- with '?' annotation + local function sub(a: number, b?: number): number + return a - (b or 0) + end + + print(sub()) + print(sub(10)) + print(sub(10, 20)) + print(sub(10, 20, 30)) + ]], + values = { + on = { + describe = "allows defining minimum arities for functions based on optional argument annotations", + status = 1, + match = { + "5 errors:", + ":6:22: wrong number of arguments (given 0, expects 2)", + ":7:22: wrong number of arguments (given 1, expects 2)", + ":9:22: wrong number of arguments (given 3, expects 2)", + ":16:22: wrong number of arguments (given 0, expects at least 1 and at most 2)", + ":19:22: wrong number of arguments (given 3, expects at least 1 and at most 2)", + }, + }, + off = { + describe = "ignores missing arguments", + status = 1, + match = { + "2 errors:", + ":9:22: wrong number of arguments (given 3, expects 2)", + ":19:22: wrong number of arguments (given 3, expects at least 1 and at most 2)", + } + } + } + } + } +} + +describe("feat flags", function() + for flag, tests in pairs(test_cases) do + describe(flag, function() + for _, case in ipairs(tests) do + for value, data in pairs(case.values) do + it("--" .. flag .. "=" .. value .. " " .. data.describe, function() + local name = util.write_tmp_file(finally, case.code) + local pd = io.popen(util.tl_cmd("check --" .. flag .. "=" .. value, name) .. "2>&1 1>" .. util.os_null, "r") + local output = pd:read("*a") + util.assert_popen_close(data.status, pd:close()) + for _, s in ipairs(data.match) do + assert.match(s, output, 1, true) + end + end) + end + end + end) + end +end) diff --git a/tl b/tl index 114eb56e1..f3362f686 100755 --- a/tl +++ b/tl @@ -144,6 +144,7 @@ local function setup_env(tlconfig, filename) local opts = { lax_mode = lax_mode, + feat_arity = tlconfig["feat_arity"], gen_compat = tlconfig["gen_compat"], gen_target = tlconfig["gen_target"], predefined_modules = tlconfig._init_env_modules, @@ -292,6 +293,7 @@ local function validate_config(config) global_env_def = "string", quiet = "boolean", skip_compat53 = "boolean", + feat_arity = { ["off"] = true, ["on"] = true }, gen_compat = { ["off"] = true, ["optional"] = true, ["required"] = true }, gen_target = { ["5.1"] = true, ["5.3"] = true, ["5.4"] = true }, disable_warnings = "{string}", @@ -352,6 +354,9 @@ local function get_args_parser() :argname("") :count("*") + parser:option("--feat-arity", "Define minimum arities for functions based on optional argument annotations.") + :choices({ "off", "on" }) + parser:option("--gen-compat", "Generate compatibility code for targeting different Lua VM versions.") :choices({ "off", "optional", "required" }) :default("optional") @@ -492,6 +497,8 @@ local function merge_config_and_args(tlconfig, args) tlconfig["pretend"] = true end + tlconfig["feat_arity"] = args["feat_arity"] or tlconfig["feat_arity"] + tlconfig["gen_target"] = args["gen_target"] or tlconfig["gen_target"] tlconfig["gen_compat"] = args["gen_compat"] or tlconfig["gen_compat"] or (tlconfig["skip_compat53"] and "off") diff --git a/tl.lua b/tl.lua index 985d62deb..aadac8ff0 100644 --- a/tl.lua +++ b/tl.lua @@ -146,6 +146,13 @@ local tl = {PrettyPrintOptions = {}, TypeCheckOptions = {}, Env = {}, Symbol = { + + + + + + + @@ -6367,12 +6374,22 @@ a_grecord(1, function(a) return a_record({ return globals, standard_library end +local function set_feat(feat, default) + if feat then + return (feat == "on") + else + return default + end +end + tl.new_env = function(opts) local env, err = tl.init_env(opts.lax_mode, opts.gen_compat, opts.gen_target, opts.predefined_modules) if not env then return nil, err end + env.feat_arity = set_feat(opts.feat_arity, true) + return env end @@ -6424,6 +6441,8 @@ tl.init_env = function(lax, gen_compat, gen_target, predefined) end end + env.feat_arity = true + return env end @@ -6443,6 +6462,7 @@ tl.type_check = function(ast, opts) end local lax = opts.lax + local feat_arity = env.feat_arity local filename = opts.filename @@ -8823,13 +8843,14 @@ a.types[i], b.types[i]), } end end local wanted = #fargs + local min_arity = feat_arity and f.min_arity or 0 - if (passes == 1 and ((given <= wanted and given >= f.min_arity) or (f.args.is_va and given > wanted) or (lax and given <= wanted))) or + if (passes == 1 and ((given <= wanted and given >= min_arity) or (f.args.is_va and given > wanted) or (lax and given <= wanted))) or (passes == 3 and ((pass == 1 and given == wanted) or - (pass == 2 and given < wanted and (lax or given >= f.min_arity)) or + (pass == 2 and given < wanted and (lax or given >= min_arity)) or (pass == 3 and f.args.is_va and given > wanted))) then diff --git a/tl.tl b/tl.tl index 4a08655ad..15d14bbd0 100644 --- a/tl.tl +++ b/tl.tl @@ -24,6 +24,11 @@ local record tl "5.4" end + enum Feat + "on" + "off" + end + record PrettyPrintOptions preserve_indent: boolean preserve_newlines: boolean @@ -50,6 +55,7 @@ local record tl gen_target: TargetMode keep_going: boolean report_types: boolean + feat_arity: boolean end record Symbol @@ -138,6 +144,7 @@ local record tl lax_mode: boolean gen_compat: CompatMode gen_target: TargetMode + feat_arity: Feat predefined_modules: {string} end @@ -6367,12 +6374,22 @@ local function init_globals(lax: boolean): {string:Variable}, {string:Type} return globals, standard_library end +local function set_feat(feat: tl.Feat, default: boolean): boolean + if feat then + return (feat == "on") + else + return default + end +end + tl.new_env = function(opts: tl.EnvOptions): Env, string local env, err = tl.init_env(opts.lax_mode, opts.gen_compat, opts.gen_target, opts.predefined_modules) if not env then return nil, err end + env.feat_arity = set_feat(opts.feat_arity, true) + return env end @@ -6424,6 +6441,8 @@ tl.init_env = function(lax?: boolean, gen_compat?: boolean | CompatMode, gen_tar end end + env.feat_arity = true + return env end @@ -6443,6 +6462,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end local lax = opts.lax + local feat_arity = env.feat_arity local filename = opts.filename local type Scope = {string:Variable} @@ -8823,13 +8843,14 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end local wanted = #fargs + local min_arity = feat_arity and f.min_arity or 0 -- simple functions: - if (passes == 1 and ((given <= wanted and given >= f.min_arity) or (f.args.is_va and given > wanted) or (lax and given <= wanted))) + if (passes == 1 and ((given <= wanted and given >= min_arity) or (f.args.is_va and given > wanted) or (lax and given <= wanted))) -- poly, pass 1: try exact arity matches first or (passes == 3 and ((pass == 1 and given == wanted) -- poly, pass 2: then try adjusting with nils to missing arguments or using '...' - or (pass == 2 and given < wanted and (lax or given >= f.min_arity)) + or (pass == 2 and given < wanted and (lax or given >= min_arity)) -- poly, pass 3: then finally try vararg functions or (pass == 3 and f.args.is_va and given > wanted))) then From a5e1415e150fb929201254cad7dfee73020c6495 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Sun, 7 Jan 2024 16:07:59 -0300 Subject: [PATCH 094/224] standard library definition written in Teal --- spec/assignment/to_const_spec.lua | 2 +- spec/call/string_method_spec.lua | 2 +- spec/cli/types_spec.lua | 4 +- spec/declaration/global_spec.lua | 2 +- spec/operator/index_spec.lua | 4 +- spec/stdlib/io_spec.lua | 10 +- tl | 5 + tl.lua | 1012 ++++++++++++++-------------- tl.tl | 1024 ++++++++++++++--------------- 9 files changed, 1005 insertions(+), 1060 deletions(-) diff --git a/spec/assignment/to_const_spec.lua b/spec/assignment/to_const_spec.lua index 9783bc1ca..83a450419 100644 --- a/spec/assignment/to_const_spec.lua +++ b/spec/assignment/to_const_spec.lua @@ -35,7 +35,7 @@ describe("assignment to const", function() local b = setmetatable ]], { { y = 2, x = 13, msg = "to-be-closed variable a has a non-closable type function()" }, - { y = 3, x = 13, msg = "to-be-closed variable b has a non-closable type function(A, metatable): A" }, + { y = 3, x = 13, msg = "to-be-closed variable b has a non-closable type function(T, metatable): T" }, }, "5.4")) end) diff --git a/spec/call/string_method_spec.lua b/spec/call/string_method_spec.lua index 1816c443b..56c088833 100644 --- a/spec/call/string_method_spec.lua +++ b/spec/call/string_method_spec.lua @@ -29,7 +29,7 @@ describe("string method call", function() local s = "a" s = s:gsub(function() end) .. "!" ]], { - { msg = "argument 1: got function" }, + { msg = "wrong number of arguments" }, })) end) describe("chained", function() diff --git a/spec/cli/types_spec.lua b/spec/cli/types_spec.lua index d36aafd9a..2f1180e71 100644 --- a/spec/cli/types_spec.lua +++ b/spec/cli/types_spec.lua @@ -327,8 +327,8 @@ describe("tl types works like check", function() assert.same({ ["17"] = 3, ["20"] = 4, - ["25"] = 15, - ["30"] = 14, + ["25"] = 17, + ["30"] = 16, ["31"] = 2, }, by_pos["2"]) end) diff --git a/spec/declaration/global_spec.lua b/spec/declaration/global_spec.lua index a88881fc0..1df30e985 100644 --- a/spec/declaration/global_spec.lua +++ b/spec/declaration/global_spec.lua @@ -101,7 +101,7 @@ describe("global", function() it("fails if removing const", util.check_type_error([[ global string: number ]], { - { msg = "global was previously declared as " }, + { msg = "cannot redeclare global with a different type" }, })) it("fails if removing const", util.check_type_error([[ diff --git a/spec/operator/index_spec.lua b/spec/operator/index_spec.lua index 40d9009df..4bbd3582c 100644 --- a/spec/operator/index_spec.lua +++ b/spec/operator/index_spec.lua @@ -84,7 +84,7 @@ describe("[]", function() s:gsub("hello", "world") s:len() s:lower() - s:match() + s:match("hello") s:pack() s:packsize() s:rep(2) @@ -108,7 +108,7 @@ describe("[]", function() s:gsub("hello", "world") s:len() s:lower() - s:match() + s:match("hello") s:pack() s:packsize() s:rep(2) diff --git a/spec/stdlib/io_spec.lua b/spec/stdlib/io_spec.lua index fffc39d0f..7b785a5b6 100644 --- a/spec/stdlib/io_spec.lua +++ b/spec/stdlib/io_spec.lua @@ -76,17 +76,17 @@ describe("io", function() ]])) it("with a numeric format", util.check([[ - for a in io.lines("n") do + for a in io.lines("filename.txt", "n") do print(a * 2) end - for a in io.lines("*n") do + for a in io.lines("filename.txt", "*n") do print(a * 2) end ]])) it("resolves the type of mixed numeric/string formats as unions for now", util.check([[ - for a, b in io.lines("n", 12) do + for a, b in io.lines("filename.txt", "n", 12) do if a is number then print(a * 2) end @@ -174,7 +174,7 @@ describe("io", function() ]])) it("with a bytes format argument", util.check([[ - for c in io.popen("ls"):lines("filename.txt", 1) do + for c in io.popen("ls"):lines(1) do print(c:upper()) end ]])) @@ -196,7 +196,7 @@ describe("io", function() ]])) it("with multiple formats", util.check([[ - for a, b, c in io.popen("ls"):lines("filename.txt", "l", 12, 13) do + for a, b, c in io.popen("ls"):lines("l", 12, 13) do print(a:upper()) print(b:upper()) print(c:upper()) diff --git a/tl b/tl index f3362f686..89e4990fc 100755 --- a/tl +++ b/tl @@ -186,6 +186,8 @@ do for _, err in ipairs(errors) do printerr(err.filename .. ":" .. err.y .. ":" .. err.x .. ": " .. (err.msg or "")) end + printerr("----------------------------------------") + printerr(n .. " " .. category .. (n ~= 1 and "s" or "")) return true end return false @@ -867,6 +869,9 @@ do local tr, trenv for i, input_file in ipairs(args["file"]) do local pok, result = pcall(tl.process, input_file, env) + if not pok then + die("Internal Compiler Error: " .. result) + end if pok then if result and result.ast then tr, trenv = tl.get_types(result, trenv) diff --git a/tl.lua b/tl.lua index aadac8ff0..3d2062a69 100644 --- a/tl.lua +++ b/tl.lua @@ -1,6 +1,452 @@ local _tl_compat; if (tonumber((_VERSION or ''):match('[%d.]*$')) or 0) < 5.3 then local p, m = pcall(require, 'compat53.module'); if p then _tl_compat = m end end; local assert = _tl_compat and _tl_compat.assert or assert; local debug = _tl_compat and _tl_compat.debug or debug; local io = _tl_compat and _tl_compat.io or io; local ipairs = _tl_compat and _tl_compat.ipairs or ipairs; local load = _tl_compat and _tl_compat.load or load; local math = _tl_compat and _tl_compat.math or math; local _tl_math_maxinteger = math.maxinteger or math.pow(2, 53); local os = _tl_compat and _tl_compat.os or os; local package = _tl_compat and _tl_compat.package or package; local pairs = _tl_compat and _tl_compat.pairs or pairs; local string = _tl_compat and _tl_compat.string or string; local table = _tl_compat and _tl_compat.table or table; local _tl_table_unpack = unpack or table.unpack; local utf8 = _tl_compat and _tl_compat.utf8 or utf8 local VERSION = "0.15.3+dev" +local stdlib = [=====[ + +do + global type any + global type thread + global type userdata + + local enum FileStringMode + "a" "l" "L" "*a" "*l" "*L" + end + + local enum FileNumberMode + "n" "*n" + end + + local enum FileMode + "a" "l" "L" "*a" "*l" "*L" "n" "*n" + end + + global record FILE + userdata + + enum SeekWhence + "set" "cur" "end" + end + + enum SetVBufMode + "no" "full" "line" + end + + close: function(FILE): boolean, string, number + flush: function(FILE) + + lines: function(FILE): (function(): (string)) + lines: function(FILE, FileNumberMode...): (function(): (number...)) + lines: function(FILE, (number | FileStringMode)...): (function(): (string...)) + lines: function(FILE, (number | FileMode)...): (function(): ((string | number)...)) + lines: function(FILE, (number | string)...): (function(): (string...)) + + read: function(FILE): string + read: function(FILE, FileNumberMode...): number... + read: function(FILE, (number | FileStringMode)...): string... + read: function(FILE, (number | FileMode)...): ((string | number)...) + read: function(FILE, (number | string)...): (string...) + + seek: function(FILE, ? SeekWhence, ? number): integer, string + setvbuf: function(FILE, SetVBufMode, ? number) + + write: function(FILE, (string | number)...): FILE, string + + metamethod __close: function(FILE) + end + + global record coroutine + type Function = function(any...): any... + + create: function(Function): thread + close: function(thread): boolean, string + isyieldable: function(): boolean + resume: function(thread, any...): boolean, any... + running: function(): thread, boolean + status: function(thread): string + wrap: function(F): F + yield: function(any...): any... + end + + global record debug + record GetInfoTable + name: string + namewhat: string + source: string + short_src: string + linedefined: integer + lastlinedefined: integer + what: string + currentline: integer + istailcall: boolean + nups: integer + nparams: integer + isvararg: boolean + func: any + activelines: {integer:boolean} + end + + enum HookEvent + "call" "tail call" "return" "line" "count" + end + + type HookFunction = function(HookEvent, integer) + + type AnyFunction = function(any...):any... + + debug: function() + gethook: function(? thread): HookFunction, integer + + getinfo: function(AnyFunction | number): GetInfoTable + getinfo: function(AnyFunction | number, string): GetInfoTable + getinfo: function(thread, AnyFunction | number, string): GetInfoTable + + getlocal: function(thread, AnyFunction, number): string + getlocal: function(thread, number, number): string, any + getlocal: function(AnyFunction, number): string + getlocal: function(number, number): string, any + + getmetatable: function(T): metatable + getregistry: function(): {any:any} + getupvalue: function(AnyFunction, number): any + getuservalue: function(userdata, number): any + + sethook: function(thread, HookFunction, string, ? number) + sethook: function(HookFunction, string, ? number) + + setlocal: function(thread, number, number, any): string + setlocal: function(number, number, any): string + + setmetatable: function(T, metatable): T + setupvalue: function(AnyFunction, number, any): string + setuservalue: function(U, any, number): U --[[U is userdata]] + + traceback: function(thread, ? string, ? number): string + traceback: function(? string, ? number): string + + upvalueid: function(AnyFunction, number): userdata + upvaluejoin: function(AnyFunction, number, AnyFunction, number) + end + + global record io + enum OpenMode + "r" "w" "a" "r+" "w+" "a+" + "rb" "wb" "ab" "r+b" "w+b" "a+b" + "*r" "*w" "*a" "*r+" "*w+" "*a+" + "*rb" "*wb" "*ab" "*r+b" "*w+b" "*a+b" + end + + close: function(? FILE) + input: function(? FILE): FILE + flush: function() + + lines: function(? string): (function(): (string)) + lines: function(? string, FileNumberMode...): (function(): (number...)) + lines: function(? string, (number | FileStringMode)...): (function(): (string...)) + lines: function(? string, (number | FileMode)...): (function(): ((string | number)...)) + lines: function(? string, (number | string)...): (function(): (string...)) + + open: function(string, ? OpenMode): FILE, string + output: function(? FILE): FILE + popen: function(string, ? OpenMode): FILE, string + + read: function(): string + read: function(FileNumberMode...): number... + read: function((number | FileStringMode)...): string... + read: function((number | FileMode)...): ((string | number)...) + read: function((number | string)...): (string...) + + stderr: FILE + stdin: FILE + stdout: FILE + tmpfile: function(): FILE + type: function(any): string + write: function((string | number)...): FILE, string + end + + global record math + abs: function(integer): integer + abs: function(number): number + + acos: function(number): number + asin: function(number): number + atan: function(number, ? number): number + atan2: function(number, number): number + ceil: function(number): integer + cos: function(number): number + cosh: function(number): number + deg: function(number): number + exp: function(number): number + floor: function(number): integer + + fmod: function(integer, integer): integer + fmod: function(number, number): number + + frexp: function(number): number, number + huge: number + ldexp: function(number, number): number + log: function(number, ? number): number + log10: function(number): number + + max: function(integer...): integer + max: function((number | integer)...): number + max: function(T...): T + max: function(any...): any + + maxinteger: integer --[[needs_compat]] + + min: function(integer...): integer + min: function((number | integer)...): number + min: function(T...): T + min: function(any...): any + + mininteger: integer --[[needs_compat]] + + modf: function(number): integer, number + pi: number + pow: function(number, number): number + rad: function(number): number + + random: function(number, ? number): integer + random: function(): number + + randomseed: function(number, number): integer, integer + sin: function(number): number + sinh: function(number): number + sqrt: function(number): number + tan: function(number): number + tanh: function(number): number + tointeger: function(any): integer + type: function(any): string + ult: function(number, number): boolean + end + + global record metatable + enum Mode + "k" "v" "kv" + end + + __call: function(T, any...): any... + __mode: Mode + __name: string + __tostring: function(T): string + __pairs: function(T): (function(): (K, V)) + + __index: any --[[FIXME: function | table | anything with an __index metamethod]] + __newindex: any --[[FIXME: function | table | anything with an __index metamethod]] + + __gc: function(T) + __close: function(T) + + __add: function(any, any): any + __sub: function(any, any): any + __mul: function(any, any): any + __div: function(any, any): any + __idiv: function(any, any): any + __mod: function(any, any): any + __pow: function(any, any): any + __band: function(any, any): any + __bor: function(any, any): any + __bxor: function(any, any): any + __shl: function(any, any): any + __shr: function(any, any): any + __concat: function(any, any): any + + __len: function(T): any + __unm: function(T): any + __bnot: function(T): any + + __eq: function(any, any): boolean + __lt: function(any, any): boolean + __le: function(any, any): boolean + end + + global record os + record DateTable + year: integer + month: integer + day: integer + hour: integer + min: integer + sec: integer + wday: integer + yday: integer + isdst: boolean + end + + enum DateMode + "!*t" "*t" + end + + clock: function(): number + + date: function(DateMode, ? number): DateTable + date: function(? string, ? number): string + + difftime: function(number, number): number + execute: function(string): boolean, string, integer + exit: function(? (number | boolean), ? boolean) + getenv: function(string): string + remove: function(string): boolean, string + rename: function(string, string): boolean, string + setlocale: function(string, ? string): string + time: function(? DateTable): integer + tmpname: function(): string + end + + global record package + config: string + cpath: string + loaded: {string:any} + loaders: { function(string): any, any } + path: string + preload: {any:any} + searchers: { function(string): any } + end + + global record string + char: function(number...): string + + byte: function(string, ? number): integer + byte: function(string, number, number): integer... + + dump: function(function(any...): (any), ? boolean): string + find: function(string, string, ? number, ? boolean): integer, integer, string + format: function(string, any...): string + gmatch: function(string, string): (function(): string...) + + gsub: function(string, string, string, ? number): string, integer + gsub: function(string, string, {string:string}, ? number): string, integer + gsub: function(string, string, function(string...): (string | number | boolean), ? number): string, integer + + len: function(string): integer + lower: function(string): string + match: function(string, string, ? number): string... + pack: function(string, any...): string + packsize: function(string): integer + rep: function(string, number, ? string): string + reverse: function(string): string + sub: function(string, number, ? number): string + unpack: function(string, string, ? number): any... + upper: function(string): string + end + + global record table + type SortFunction = function(A, A): boolean + + record PackTable + is {A} + + n: integer + end + + concat: function({(string | number)}, ? string, ? number, ? number): string + + insert: function({A}, number, A) + insert: function({A}, A) + + pack: function(T...): PackTable + pack: function(any...): {any:any} + + remove: function({A}, ? number): A + sort: function({A}, ? SortFunction) + + unpack: function({A}, ? number, ? number): A... --[[needs_compat]] + end + + global record utf8 + char: function(number...): string + charpattern: string + codepoint: function(string, ? number, ? number, ? boolean): number... + codes: function(string, ? boolean): (function(string, ? number): (number, number)) + len: function(string, ? number, ? number, ? boolean): number + offset: function(string, number, ? number): number + end + + local record StandardLibrary + enum CollectGarbageCommand + "collect" + "count" + "stop" + "restart" + end + + enum CollectGarbageSetValue + "step" + "setpause" + "setstepmul" + end + + enum CollectGarbageIsRunning + "isrunning" + end + + type LoadFunction = function(): string + + enum LoadMode + "b" "t" "bt" + end + + type XpcallMsghFunction = function(...: any): () + + arg: {string} + assert: function(A, ? B): A + + collectgarbage: function(CollectGarbageCommand): number + collectgarbage: function(CollectGarbageSetValue, number): number + collectgarbage: function(CollectGarbageIsRunning): boolean + collectgarbage: function(string, ? number): (boolean | number) + + error: function(? any, ? number) + ipairs: function({A}): (function():(integer, A)) + + load: function((string | LoadFunction), ? string, ? LoadMode, ? table): (function, string) + load: function((string | LoadFunction), ? string, ? string, ? table): (function, string) + + next: function({K:V}, ? K): (K, V) + next: function({A}, ? integer): (integer, A) + + pairs: function({K:V}): (function():(K, V)) + pcall: function(function(any...):(any...), any...): boolean, any... + print: function(any...) + require: function(string): any + + select: function(number, T...): T + select: function(number, any...): any + select: function(string, any...): integer + + setmetatable: function(T, metatable): T + + tonumber: function(any): number + tonumber: function(any, number): integer + + tostring: function(any): string + type: function(any): string + xpcall: function(function(any...):(any...), XpcallMsghFunction, any...): boolean, any... + _VERSION: string + end + + global arg = StandardLibrary.arg + global assert = StandardLibrary.assert + global collectgarbage = StandardLibrary.collectgarbage + global error = StandardLibrary.error + global load = StandardLibrary.load + global next = StandardLibrary.next + global pairs = StandardLibrary.pairs + global pcall = StandardLibrary.pcall + global print = StandardLibrary.print + global require = StandardLibrary.require + global select = StandardLibrary.select + global setmetatable = StandardLibrary.setmetatable + global tostring = StandardLibrary.tostring + global tonumber = StandardLibrary.tonumber + global ipairs = StandardLibrary.ipairs + global type = StandardLibrary.type + global xpcall = StandardLibrary.xpcall + global _VERSION = StandardLibrary._VERSION +end + +]=====] + local tl = {PrettyPrintOptions = {}, TypeCheckOptions = {}, Env = {}, Symbol = {}, Result = {}, Error = {}, TypeInfo = {}, TypeReport = {}, TypeReportEnv = {}, EnvOptions = {}, } @@ -1710,9 +2156,9 @@ end -local function OPT(t) - return { opttype = t } -end + + + @@ -5062,7 +5508,6 @@ local skip_types = { ["none"] = true, ["literal_table_item"] = true, ["unresolved"] = true, - ["typedecl"] = true, } local get_typenum @@ -5126,12 +5571,14 @@ get_typenum = function(trenv, t) n = trenv.next_num local rt = t + if rt.typename == "tuple" and #rt.tuple == 1 then + rt = rt.tuple[1] + end + if rt.typename == "typedecl" then rt = rt.def elseif rt.typename == "typealias" then rt = rt.alias_to - elseif rt.typename == "tuple" and #rt.tuple == 1 then - rt = rt.tuple[1] end local ti = { @@ -5221,10 +5668,10 @@ local CIRCULAR_REQUIRE = a_type("circular_require", {}) local FUNCTION = a_fn({ args = va_args({ ANY }), rets = va_args({ ANY }) }) -local NOMINAL_FILE = a_type("nominal", { names = { "FILE" } }) + local XPCALL_MSGH_FUNCTION = a_fn({ args = { ANY }, rets = {} }) -local USERDATA = ANY + local numeric_binop = { ["number"] = { @@ -5864,516 +6311,10 @@ local function convert_node_to_compat_mt_call(node, mt_name, which_self, e1, e2) node.e2[4] = e2 end +local stdlib_globals = nil local globals_typeid local fresh_typevar_ctr = 1 -local function init_globals(lax) - local globals = {} - local stdlib_compat = get_stdlib_compat(lax) - - - local is_first_init = globals_typeid == nil - - local save_typeid = last_typeid - if is_first_init then - globals_typeid = new_typeid() - else - last_typeid = globals_typeid - end - - local function a_record(t) - t = a_type("record", t) - assert(t.fields) - t.field_order = sorted_keys(t.fields) - return t - end - - local function a_generic(n, f) - local typevars = {} - local typeargs = {} - local c = string.byte("A") - 1 - fresh_typevar_ctr = fresh_typevar_ctr + 1 - for i = 1, n do - local name = string.char(c + i) .. "@" .. fresh_typevar_ctr - typevars[i] = a_type("typevar", { typevar = name }) - typeargs[i] = a_type("typearg", { typearg = name }) - end - local t = f(_tl_table_unpack(typevars)) - if t.typename == "function" or t.typename == "record" then - t.typeargs = typeargs - end - return t - end - - local function a_gfunction(n, f) - return a_generic(n, function(...) return a_fn(f(...)) end) - end - - local function a_grecord(n, f) - local t = a_generic(n, f) - t.field_order = sorted_keys(t.fields) - return t - end - - local function an_enum(keys) - local t = a_type("enum", { enumset = {} }) - for _, k in ipairs(keys) do - t.enumset[k] = true - end - return t - end - - - - - - - - - - local function id(x) - return x - end - - local file_reader_poly_types = { - { ctor = va_args, args = { a_type("union", { types = { NUMBER, an_enum({ "*a", "a", "*l", "l", "*L", "L" }) } }) }, rets = { STRING } }, - { ctor = id, args = { an_enum({ "*n", "n" }) }, rets = { NUMBER, STRING } }, - { ctor = va_args, args = { a_type("union", { types = { NUMBER, an_enum({ "*a", "a", "*l", "l", "*L", "L", "*n", "n" }) } }) }, rets = { a_type("union", { types = { STRING, NUMBER } }) } }, - { ctor = va_args, args = { a_type("union", { types = { NUMBER, STRING } }) }, rets = { STRING } }, - { ctor = va_args, args = {}, rets = { STRING } }, - } - - local function a_file_reader(fn) - local t = a_type("poly", { types = {} }) - for _, entry in ipairs(file_reader_poly_types) do - local args = shallow_copy_table(entry.args) - local rets = shallow_copy_table(entry.rets) - table.insert(t.types, fn(entry.ctor, args, rets)) - end - return t - end - - local LOAD_FUNCTION = a_fn({ args = {}, rets = { STRING } }) - - local OS_DATE_TABLE = a_record({ - fields = { - ["year"] = INTEGER, - ["month"] = INTEGER, - ["day"] = INTEGER, - ["hour"] = INTEGER, - ["min"] = INTEGER, - ["sec"] = INTEGER, - ["wday"] = INTEGER, - ["yday"] = INTEGER, - ["isdst"] = BOOLEAN, - }, - }) - - local DEBUG_GETINFO_TABLE = a_record({ - fields = { - ["name"] = STRING, - ["namewhat"] = STRING, - ["source"] = STRING, - ["short_src"] = STRING, - ["linedefined"] = INTEGER, - ["lastlinedefined"] = INTEGER, - ["what"] = STRING, - ["currentline"] = INTEGER, - ["istailcall"] = BOOLEAN, - ["nups"] = INTEGER, - ["nparams"] = INTEGER, - ["isvararg"] = BOOLEAN, - ["func"] = ANY, - ["activelines"] = a_type("map", { keys = INTEGER, values = BOOLEAN }), - }, - }) - - local DEBUG_HOOK_EVENT = an_enum({ "call", "tail call", "return", "line", "count" }) - - local DEBUG_HOOK_FUNCTION = a_fn({ - args = { DEBUG_HOOK_EVENT, INTEGER }, - rets = {}, - }) - - local TABLE_SORT_FUNCTION = a_gfunction(1, function(a) return { args = { a, a }, rets = { BOOLEAN } } end) - - local metatable_nominals = {} - - local function METATABLE(a) - local t = a_type("nominal", { names = { "metatable" }, typevals = { a } }) - table.insert(metatable_nominals, t) - return t - end - - local standard_library = { - ["..."] = a_vararg({ STRING }), - ["any"] = a_type("typedecl", { def = ANY }), - ["arg"] = a_type("array", { elements = STRING }), - ["assert"] = a_gfunction(2, function(a, b) return { args = { a, OPT(b) }, rets = { a } } end), - ["collectgarbage"] = a_type("poly", { types = { - a_fn({ args = { an_enum({ "collect", "count", "stop", "restart" }) }, rets = { NUMBER } }), - a_fn({ args = { an_enum({ "step", "setpause", "setstepmul" }), NUMBER }, rets = { NUMBER } }), - a_fn({ args = { an_enum({ "isrunning" }) }, rets = { BOOLEAN } }), - a_fn({ args = { STRING, OPT(NUMBER) }, rets = { a_type("union", { types = { BOOLEAN, NUMBER } }) } }), - } }), - ["dofile"] = a_fn({ args = { OPT(STRING) }, rets = va_args({ ANY }) }), - ["error"] = a_fn({ args = { ANY, OPT(NUMBER) }, rets = {} }), - ["getmetatable"] = a_gfunction(1, function(a) return { args = { a }, rets = { METATABLE(a) } } end), - ["ipairs"] = a_gfunction(1, function(a) return { args = { a_type("array", { elements = a }) }, rets = { - a_fn({ args = {}, rets = { INTEGER, a } }), -}, } end), - ["load"] = a_fn({ args = { a_type("union", { types = { STRING, LOAD_FUNCTION } }), OPT(STRING), OPT(STRING), OPT(TABLE) }, rets = { FUNCTION, STRING } }), - ["loadfile"] = a_fn({ args = { OPT(STRING), OPT(STRING), OPT(TABLE) }, rets = { FUNCTION, STRING } }), - ["next"] = a_type("poly", { types = { - a_gfunction(2, function(a, b) return { args = { a_type("map", { keys = a, values = b }), OPT(a) }, rets = { a, b } } end), - a_gfunction(1, function(a) return { args = { a_type("array", { elements = a }), OPT(a) }, rets = { INTEGER, a } } end), - } }), - ["pairs"] = a_gfunction(2, function(a, b) return { args = { a_type("map", { keys = a, values = b }) }, rets = { - a_fn({ args = {}, rets = { a, b } }), -}, } end), - ["pcall"] = a_fn({ args = va_args({ FUNCTION, ANY }), rets = va_args({ BOOLEAN, ANY }) }), - ["xpcall"] = a_fn({ args = va_args({ FUNCTION, XPCALL_MSGH_FUNCTION, ANY }), rets = va_args({ BOOLEAN, ANY }) }), - ["print"] = a_fn({ args = va_args({ ANY }), rets = {} }), - ["rawequal"] = a_fn({ args = { ANY, ANY }, rets = { BOOLEAN } }), - ["rawget"] = a_fn({ args = { TABLE, ANY }, rets = { ANY } }), - ["rawlen"] = a_fn({ args = { a_type("union", { types = { TABLE, STRING } }) }, rets = { INTEGER } }), - ["rawset"] = a_type("poly", { types = { - a_gfunction(2, function(a, b) return { args = { a_type("map", { keys = a, values = b }), a, b }, rets = {} } end), - a_gfunction(1, function(a) return { args = { a_type("array", { elements = a }), NUMBER, a }, rets = {} } end), - a_fn({ args = { TABLE, ANY, ANY }, rets = {} }), - } }), - ["require"] = a_fn({ args = { STRING }, rets = {} }), - ["select"] = a_type("poly", { types = { - a_gfunction(1, function(a) return { args = va_args({ NUMBER, a }), rets = { a } } end), - a_fn({ args = va_args({ NUMBER, ANY }), rets = { ANY } }), - a_fn({ args = va_args({ STRING, ANY }), rets = { INTEGER } }), - } }), - ["setmetatable"] = a_gfunction(1, function(a) return { args = { a, METATABLE(a) }, rets = { a } } end), - ["tonumber"] = a_type("poly", { types = { - a_fn({ args = { ANY }, rets = { NUMBER } }), - a_fn({ args = { ANY, NUMBER }, rets = { INTEGER } }), - } }), - ["tostring"] = a_fn({ args = { ANY }, rets = { STRING } }), - ["type"] = a_fn({ args = { ANY }, rets = { STRING } }), - ["FILE"] = a_type("typedecl", { def = - a_record({ - is_userdata = true, - fields = { - ["close"] = a_fn({ args = { NOMINAL_FILE }, rets = { BOOLEAN, STRING, INTEGER } }), - ["flush"] = a_fn({ args = { NOMINAL_FILE }, rets = {} }), - ["lines"] = a_file_reader(function(ctor, args, rets) - table.insert(args, 1, NOMINAL_FILE) - return a_fn({ args = ctor(args), rets = { - a_fn({ args = {}, rets = ctor(rets) }), - }, }) - end), - ["read"] = a_file_reader(function(ctor, args, rets) - table.insert(args, 1, NOMINAL_FILE) - return a_fn({ args = ctor(args), rets = ctor(rets) }) - end), - ["seek"] = a_fn({ args = { NOMINAL_FILE, OPT(STRING), OPT(NUMBER) }, rets = { INTEGER, STRING } }), - ["setvbuf"] = a_fn({ args = { NOMINAL_FILE, STRING, OPT(NUMBER) }, rets = {} }), - ["write"] = a_fn({ args = va_args({ NOMINAL_FILE, a_type("union", { types = { STRING, NUMBER } }) }), rets = { NOMINAL_FILE, STRING } }), - - }, - meta_fields = { ["__close"] = FUNCTION }, - meta_field_order = { "__close" }, - }) }), - - ["metatable"] = a_type("typedecl", { def = -a_grecord(1, function(a) return a_record({ - fields = { - ["__call"] = a_fn({ args = va_args({ a, ANY }), rets = va_args({ ANY }) }), - ["__gc"] = a_fn({ args = { a }, rets = {} }), - ["__index"] = ANY, - ["__len"] = a_fn({ args = { a }, rets = { ANY } }), - ["__mode"] = an_enum({ "k", "v", "kv" }), - ["__newindex"] = ANY, - ["__pairs"] = a_gfunction(2, function(k, v) - return { - args = { a }, - rets = { a_fn({ args = {}, rets = { k, v } }) }, - } - end), - ["__tostring"] = a_fn({ args = { a }, rets = { STRING } }), - ["__name"] = STRING, - ["__add"] = a_fn({ args = { ANY, ANY }, rets = { ANY } }), - ["__sub"] = a_fn({ args = { ANY, ANY }, rets = { ANY } }), - ["__mul"] = a_fn({ args = { ANY, ANY }, rets = { ANY } }), - ["__div"] = a_fn({ args = { ANY, ANY }, rets = { ANY } }), - ["__idiv"] = a_fn({ args = { ANY, ANY }, rets = { ANY } }), - ["__mod"] = a_fn({ args = { ANY, ANY }, rets = { ANY } }), - ["__pow"] = a_fn({ args = { ANY, ANY }, rets = { ANY } }), - ["__unm"] = a_fn({ args = { ANY }, rets = { ANY } }), - ["__band"] = a_fn({ args = { ANY, ANY }, rets = { ANY } }), - ["__bor"] = a_fn({ args = { ANY, ANY }, rets = { ANY } }), - ["__bxor"] = a_fn({ args = { ANY, ANY }, rets = { ANY } }), - ["__bnot"] = a_fn({ args = { ANY }, rets = { ANY } }), - ["__shl"] = a_fn({ args = { ANY, ANY }, rets = { ANY } }), - ["__shr"] = a_fn({ args = { ANY, ANY }, rets = { ANY } }), - ["__concat"] = a_fn({ args = { ANY, ANY }, rets = { ANY } }), - ["__eq"] = a_fn({ args = { ANY, ANY }, rets = { BOOLEAN } }), - ["__lt"] = a_fn({ args = { ANY, ANY }, rets = { BOOLEAN } }), - ["__le"] = a_fn({ args = { ANY, ANY }, rets = { BOOLEAN } }), - ["__close"] = a_fn({ args = { a }, rets = {} }), - }, -}) end) }), - - ["coroutine"] = a_record({ - fields = { - ["create"] = a_fn({ args = { FUNCTION }, rets = { THREAD } }), - ["close"] = a_fn({ args = { THREAD }, rets = { BOOLEAN, STRING } }), - ["isyieldable"] = a_fn({ args = {}, rets = { BOOLEAN } }), - ["resume"] = a_fn({ args = va_args({ THREAD, ANY }), rets = va_args({ BOOLEAN, ANY }) }), - ["running"] = a_fn({ args = {}, rets = { THREAD, BOOLEAN } }), - ["status"] = a_fn({ args = { THREAD }, rets = { STRING } }), - ["wrap"] = a_fn({ args = { FUNCTION }, rets = { FUNCTION } }), - ["yield"] = a_fn({ args = va_args({ ANY }), rets = va_args({ ANY }) }), - }, - }), - ["debug"] = a_record({ - fields = { - ["Info"] = a_type("typedecl", { def = DEBUG_GETINFO_TABLE }), - ["Hook"] = a_type("typedecl", { def = DEBUG_HOOK_FUNCTION }), - ["HookEvent"] = a_type("typedecl", { def = DEBUG_HOOK_EVENT }), - - ["debug"] = a_fn({ args = {}, rets = {} }), - ["gethook"] = a_fn({ args = { OPT(THREAD) }, rets = { DEBUG_HOOK_FUNCTION, INTEGER } }), - ["getlocal"] = a_type("poly", { types = { - a_fn({ args = { THREAD, FUNCTION, NUMBER }, rets = { STRING } }), - a_fn({ args = { THREAD, NUMBER, NUMBER }, rets = { STRING, ANY } }), - a_fn({ args = { FUNCTION, NUMBER }, rets = { STRING } }), - a_fn({ args = { NUMBER, NUMBER }, rets = { STRING, ANY } }), - } }), - ["getmetatable"] = a_gfunction(1, function(a) return { args = { a }, rets = { METATABLE(a) } } end), - ["getregistry"] = a_fn({ args = {}, rets = { TABLE } }), - ["getupvalue"] = a_fn({ args = { FUNCTION, NUMBER }, rets = { ANY } }), - ["getuservalue"] = a_fn({ args = { USERDATA, NUMBER }, rets = { ANY } }), - ["sethook"] = a_type("poly", { types = { - a_fn({ args = { THREAD, DEBUG_HOOK_FUNCTION, STRING, NUMBER }, rets = {} }), - a_fn({ args = { DEBUG_HOOK_FUNCTION, STRING, NUMBER }, rets = {} }), - } }), - ["setlocal"] = a_type("poly", { types = { - a_fn({ args = { THREAD, NUMBER, NUMBER, ANY }, rets = { STRING } }), - a_fn({ args = { NUMBER, NUMBER, ANY }, rets = { STRING } }), - } }), - ["setmetatable"] = a_gfunction(1, function(a) return { args = { a, METATABLE(a) }, rets = { a } } end), - ["setupvalue"] = a_fn({ args = { FUNCTION, NUMBER, ANY }, rets = { STRING } }), - ["setuservalue"] = a_fn({ args = { USERDATA, ANY, NUMBER }, rets = { USERDATA } }), - ["traceback"] = a_type("poly", { types = { - a_fn({ args = { OPT(THREAD), OPT(STRING), OPT(NUMBER) }, rets = { STRING } }), - a_fn({ args = { OPT(STRING), OPT(NUMBER) }, rets = { STRING } }), - } }), - ["upvalueid"] = a_fn({ args = { FUNCTION, NUMBER }, rets = { USERDATA } }), - ["upvaluejoin"] = a_fn({ args = { FUNCTION, NUMBER, FUNCTION, NUMBER }, rets = {} }), - ["getinfo"] = a_type("poly", { types = { - a_fn({ args = { ANY }, rets = { DEBUG_GETINFO_TABLE } }), - a_fn({ args = { ANY, STRING }, rets = { DEBUG_GETINFO_TABLE } }), - a_fn({ args = { ANY, ANY, STRING }, rets = { DEBUG_GETINFO_TABLE } }), - } }), - }, - }), - ["io"] = a_record({ - fields = { - ["close"] = a_fn({ args = { OPT(NOMINAL_FILE) }, rets = { BOOLEAN, STRING } }), - ["flush"] = a_fn({ args = {}, rets = {} }), - ["input"] = a_fn({ args = { OPT(a_type("union", { types = { STRING, NOMINAL_FILE } })) }, rets = { NOMINAL_FILE } }), - ["lines"] = a_file_reader(function(ctor, args, rets) - return a_fn({ args = ctor(args), rets = { - a_fn({ args = {}, rets = ctor(rets) }), - }, }) - end), - ["open"] = a_fn({ args = { STRING, OPT(STRING) }, rets = { NOMINAL_FILE, STRING } }), - ["output"] = a_fn({ args = { OPT(a_type("union", { types = { STRING, NOMINAL_FILE } })) }, rets = { NOMINAL_FILE } }), - ["popen"] = a_fn({ args = { STRING, OPT(STRING) }, rets = { NOMINAL_FILE, STRING } }), - ["read"] = a_file_reader(function(ctor, args, rets) - return a_fn({ args = ctor(args), rets = ctor(rets) }) - end), - ["stderr"] = NOMINAL_FILE, - ["stdin"] = NOMINAL_FILE, - ["stdout"] = NOMINAL_FILE, - ["tmpfile"] = a_fn({ args = {}, rets = { NOMINAL_FILE } }), - ["type"] = a_fn({ args = { ANY }, rets = { STRING } }), - ["write"] = a_fn({ args = va_args({ a_type("union", { types = { STRING, NUMBER } }) }), rets = { NOMINAL_FILE, STRING } }), - }, - }), - ["math"] = a_record({ - fields = { - ["abs"] = a_type("poly", { types = { - a_fn({ args = { INTEGER }, rets = { INTEGER } }), - a_fn({ args = { NUMBER }, rets = { NUMBER } }), - } }), - ["acos"] = a_fn({ args = { NUMBER }, rets = { NUMBER } }), - ["asin"] = a_fn({ args = { NUMBER }, rets = { NUMBER } }), - ["atan"] = a_fn({ args = { NUMBER, OPT(NUMBER) }, rets = { NUMBER } }), - ["atan2"] = a_fn({ args = { NUMBER, NUMBER }, rets = { NUMBER } }), - ["ceil"] = a_fn({ args = { NUMBER }, rets = { INTEGER } }), - ["cos"] = a_fn({ args = { NUMBER }, rets = { NUMBER } }), - ["cosh"] = a_fn({ args = { NUMBER }, rets = { NUMBER } }), - ["deg"] = a_fn({ args = { NUMBER }, rets = { NUMBER } }), - ["exp"] = a_fn({ args = { NUMBER }, rets = { NUMBER } }), - ["floor"] = a_fn({ args = { NUMBER }, rets = { INTEGER } }), - ["fmod"] = a_type("poly", { types = { - a_fn({ args = { INTEGER, INTEGER }, rets = { INTEGER } }), - a_fn({ args = { NUMBER, NUMBER }, rets = { NUMBER } }), - } }), - ["frexp"] = a_fn({ args = { NUMBER }, rets = { NUMBER, NUMBER } }), - ["huge"] = NUMBER, - ["ldexp"] = a_fn({ args = { NUMBER, NUMBER }, rets = { NUMBER } }), - ["log"] = a_fn({ args = { NUMBER, OPT(NUMBER) }, rets = { NUMBER } }), - ["log10"] = a_fn({ args = { NUMBER }, rets = { NUMBER } }), - ["max"] = a_type("poly", { types = { - a_fn({ args = va_args({ INTEGER }), rets = { INTEGER } }), - a_gfunction(1, function(a) return { args = va_args({ a }), rets = { a } } end), - a_fn({ args = va_args({ a_type("union", { types = { NUMBER, INTEGER } }) }), rets = { NUMBER } }), - a_fn({ args = va_args({ ANY }), rets = { ANY } }), - } }), - ["maxinteger"] = a_type("integer", { needs_compat = true }), - ["min"] = a_type("poly", { types = { - a_fn({ args = va_args({ INTEGER }), rets = { INTEGER } }), - a_gfunction(1, function(a) return { args = va_args({ a }), rets = { a } } end), - a_fn({ args = va_args({ a_type("union", { types = { NUMBER, INTEGER } }) }), rets = { NUMBER } }), - a_fn({ args = va_args({ ANY }), rets = { ANY } }), - } }), - ["mininteger"] = a_type("integer", { needs_compat = true }), - ["modf"] = a_fn({ args = { NUMBER }, rets = { INTEGER, NUMBER } }), - ["pi"] = NUMBER, - ["pow"] = a_fn({ args = { NUMBER, NUMBER }, rets = { NUMBER } }), - ["rad"] = a_fn({ args = { NUMBER }, rets = { NUMBER } }), - ["random"] = a_type("poly", { types = { - a_fn({ args = { NUMBER, OPT(NUMBER) }, rets = { INTEGER } }), - a_fn({ args = {}, rets = { NUMBER } }), - } }), - ["randomseed"] = a_fn({ args = { NUMBER, NUMBER }, rets = { INTEGER, INTEGER } }), - ["sin"] = a_fn({ args = { NUMBER }, rets = { NUMBER } }), - ["sinh"] = a_fn({ args = { NUMBER }, rets = { NUMBER } }), - ["sqrt"] = a_fn({ args = { NUMBER }, rets = { NUMBER } }), - ["tan"] = a_fn({ args = { NUMBER }, rets = { NUMBER } }), - ["tanh"] = a_fn({ args = { NUMBER }, rets = { NUMBER } }), - ["tointeger"] = a_fn({ args = { ANY }, rets = { INTEGER } }), - ["type"] = a_fn({ args = { ANY }, rets = { STRING } }), - ["ult"] = a_fn({ args = { NUMBER, NUMBER }, rets = { BOOLEAN } }), - }, - }), - ["os"] = a_record({ - fields = { - ["clock"] = a_fn({ args = {}, rets = { NUMBER } }), - ["date"] = a_type("poly", { types = { - a_fn({ args = {}, rets = { STRING } }), - a_fn({ args = { an_enum({ "!*t", "*t" }), OPT(NUMBER) }, rets = { OS_DATE_TABLE } }), - a_fn({ args = { OPT(STRING), OPT(NUMBER) }, rets = { STRING } }), - } }), - ["difftime"] = a_fn({ args = { NUMBER, NUMBER }, rets = { NUMBER } }), - ["execute"] = a_fn({ args = { STRING }, rets = { BOOLEAN, STRING, INTEGER } }), - ["exit"] = a_fn({ args = { OPT(a_type("union", { types = { NUMBER, BOOLEAN } })), OPT(BOOLEAN) }, rets = {} }), - ["getenv"] = a_fn({ args = { STRING }, rets = { STRING } }), - ["remove"] = a_fn({ args = { STRING }, rets = { BOOLEAN, STRING } }), - ["rename"] = a_fn({ args = { STRING, STRING }, rets = { BOOLEAN, STRING } }), - ["setlocale"] = a_fn({ args = { STRING, OPT(STRING) }, rets = { STRING } }), - ["time"] = a_fn({ args = { OPT(OS_DATE_TABLE) }, rets = { INTEGER } }), - ["tmpname"] = a_fn({ args = {}, rets = { STRING } }), - }, - }), - ["package"] = a_record({ - fields = { - ["config"] = STRING, - ["cpath"] = STRING, - ["loaded"] = a_type("map", { keys = STRING, values = ANY }), - ["loaders"] = a_type("array", { elements = a_fn({ args = { STRING }, rets = { ANY, ANY } }) }), - ["loadlib"] = a_fn({ args = { STRING, STRING }, rets = { FUNCTION } }), - ["path"] = STRING, - ["preload"] = TABLE, - ["searchers"] = a_type("array", { elements = a_fn({ args = { STRING }, rets = { ANY, ANY } }) }), - ["searchpath"] = a_fn({ args = { STRING, STRING, OPT(STRING), OPT(STRING) }, rets = { STRING, STRING } }), - }, - }), - ["string"] = a_record({ - fields = { - ["byte"] = a_type("poly", { types = { - a_fn({ args = { STRING, OPT(NUMBER) }, rets = { INTEGER } }), - a_fn({ args = { STRING, NUMBER, NUMBER }, rets = va_args({ INTEGER }) }), - } }), - ["char"] = a_fn({ args = va_args({ NUMBER }), rets = { STRING } }), - ["dump"] = a_fn({ args = { FUNCTION, OPT(BOOLEAN) }, rets = { STRING } }), - ["find"] = a_fn({ args = { STRING, STRING, OPT(NUMBER), OPT(BOOLEAN) }, rets = va_args({ INTEGER, INTEGER, STRING }) }), - ["format"] = a_fn({ args = va_args({ STRING, ANY }), rets = { STRING } }), - ["gmatch"] = a_fn({ args = { STRING, STRING }, rets = { - a_fn({ args = {}, rets = va_args({ STRING }) }), - }, }), - ["gsub"] = a_type("poly", { types = { - a_fn({ args = { STRING, STRING, a_type("map", { keys = STRING, values = STRING }), OPT(NUMBER) }, rets = { STRING, INTEGER } }), - a_fn({ args = { STRING, STRING, a_fn({ args = va_args({ STRING }), rets = { STRING } }), OPT(NUMBER) }, rets = { STRING, INTEGER } }), - a_fn({ args = { STRING, STRING, a_fn({ args = va_args({ STRING }), rets = { NUMBER } }), OPT(NUMBER) }, rets = { STRING, INTEGER } }), - a_fn({ args = { STRING, STRING, a_fn({ args = va_args({ STRING }), rets = { BOOLEAN } }), OPT(NUMBER) }, rets = { STRING, INTEGER } }), - a_fn({ args = { STRING, STRING, a_fn({ args = va_args({ STRING }), rets = {} }), OPT(NUMBER) }, rets = { STRING, INTEGER } }), - a_fn({ args = { STRING, STRING, OPT(STRING), OPT(NUMBER) }, rets = { STRING, INTEGER } }), - - } }), - ["len"] = a_fn({ args = { STRING }, rets = { INTEGER } }), - ["lower"] = a_fn({ args = { STRING }, rets = { STRING } }), - ["match"] = a_fn({ args = { STRING, OPT(STRING), OPT(NUMBER) }, rets = va_args({ STRING }) }), - ["pack"] = a_fn({ args = va_args({ STRING, ANY }), rets = { STRING } }), - ["packsize"] = a_fn({ args = { STRING }, rets = { INTEGER } }), - ["rep"] = a_fn({ args = { STRING, NUMBER, OPT(STRING) }, rets = { STRING } }), - ["reverse"] = a_fn({ args = { STRING }, rets = { STRING } }), - ["sub"] = a_fn({ args = { STRING, NUMBER, OPT(NUMBER) }, rets = { STRING } }), - ["unpack"] = a_fn({ args = { STRING, STRING, OPT(NUMBER) }, rets = va_args({ ANY }) }), - ["upper"] = a_fn({ args = { STRING }, rets = { STRING } }), - }, - }), - ["table"] = a_record({ - fields = { - ["concat"] = a_fn({ args = { a_type("array", { elements = a_type("union", { types = { STRING, NUMBER } }) }), OPT(STRING), OPT(NUMBER), OPT(NUMBER) }, rets = { STRING } }), - ["insert"] = a_type("poly", { types = { - a_gfunction(1, function(a) return { args = { a_type("array", { elements = a }), NUMBER, a }, rets = {} } end), - a_gfunction(1, function(a) return { args = { a_type("array", { elements = a }), a }, rets = {} } end), - } }), - ["move"] = a_type("poly", { types = { - a_gfunction(1, function(a) return { args = { a_type("array", { elements = a }), NUMBER, NUMBER, NUMBER }, rets = { a_type("array", { elements = a }) } } end), - a_gfunction(1, function(a) return { args = { a_type("array", { elements = a }), NUMBER, NUMBER, NUMBER, a_type("array", { elements = a }) }, rets = { a_type("array", { elements = a }) } } end), - } }), - ["pack"] = a_fn({ args = va_args({ ANY }), rets = { TABLE } }), - ["remove"] = a_gfunction(1, function(a) return { args = { a_type("array", { elements = a }), OPT(NUMBER) }, rets = { a } } end), - ["sort"] = a_gfunction(1, function(a) return { args = { a_type("array", { elements = a }), OPT(TABLE_SORT_FUNCTION) }, rets = {} } end), - ["unpack"] = a_gfunction(1, function(a) return { needs_compat = true, args = { a_type("array", { elements = a }), OPT(NUMBER), OPT(NUMBER) }, rets = va_args({ a }) } end), - }, - }), - ["utf8"] = a_record({ - fields = { - ["char"] = a_fn({ args = va_args({ NUMBER }), rets = { STRING } }), - ["charpattern"] = STRING, - ["codepoint"] = a_fn({ args = { STRING, OPT(NUMBER), OPT(NUMBER) }, rets = va_args({ INTEGER }) }), - ["codes"] = a_fn({ args = { STRING }, rets = { - a_fn({ args = { STRING, OPT(NUMBER) }, rets = { NUMBER, NUMBER } }), - }, }), - ["len"] = a_fn({ args = { STRING, NUMBER, NUMBER }, rets = { INTEGER } }), - ["offset"] = a_fn({ args = { STRING, NUMBER, NUMBER }, rets = { INTEGER } }), - }, - }), - ["_VERSION"] = STRING, - } - - NOMINAL_FILE.found = standard_library["FILE"] - for _, m in ipairs(metatable_nominals) do - m.found = standard_library["metatable"] - end - - for name, typ in pairs(standard_library) do - globals[name] = { t = typ, needs_compat = stdlib_compat[name], attribute = "const" } - end - - - - - globals["@is_va"] = { t = ANY } - - if not is_first_init then - last_typeid = save_typeid - end - - return globals, standard_library -end - local function set_feat(feat, default) if feat then return (feat == "on") @@ -6413,21 +6354,50 @@ tl.init_env = function(lax, gen_compat, gen_target, predefined) return nil, "gen-compat must be explicitly 'off' when gen-target is '5.4'" end - local globals, standard_library = init_globals(lax) - local env = { modules = {}, loaded = {}, loaded_order = {}, - globals = globals, + globals = {}, gen_compat = gen_compat, gen_target = gen_target, } + if not stdlib_globals then + local program, syntax_errors = tl.parse(stdlib, "stdlib.d.tl") + assert(#syntax_errors == 0) + local result = tl.type_check(program, { + filename = "@stdlib", + env = env, + }) + assert(#result.type_errors == 0) + stdlib_globals = env.globals; + + + local math_t = (stdlib_globals["math"].t).def + local table_t = (stdlib_globals["table"].t).def + local integer_compat = a_type("integer", { needs_compat = true }) + math_t.fields["maxinteger"] = integer_compat + math_t.fields["mininteger"] = integer_compat + table_t.fields["unpack"].needs_compat = true + + + + + stdlib_globals["..."] = { t = a_vararg({ STRING }) } + stdlib_globals["@is_va"] = { t = ANY } + + env.globals = {} + end + + local stdlib_compat = get_stdlib_compat(lax) + for name, var in pairs(stdlib_globals) do + env.globals[name] = var + var.needs_compat = stdlib_compat[name] + local t = var.t + if t.typename == "typedecl" then - for name, var in pairs(standard_library) do - if var.typename == "record" then - env.modules[name] = var + env.modules[name] = t end end diff --git a/tl.tl b/tl.tl index 15d14bbd0..fbaa13909 100644 --- a/tl.tl +++ b/tl.tl @@ -1,6 +1,452 @@ local VERSION = "0.15.3+dev" +local stdlib = [=====[ + +do + global type any + global type thread + global type userdata + + local enum FileStringMode + "a" "l" "L" "*a" "*l" "*L" + end + + local enum FileNumberMode + "n" "*n" + end + + local enum FileMode + "a" "l" "L" "*a" "*l" "*L" "n" "*n" + end + + global record FILE + userdata + + enum SeekWhence + "set" "cur" "end" + end + + enum SetVBufMode + "no" "full" "line" + end + + close: function(FILE): boolean, string, number + flush: function(FILE) + + lines: function(FILE): (function(): (string)) + lines: function(FILE, FileNumberMode...): (function(): (number...)) + lines: function(FILE, (number | FileStringMode)...): (function(): (string...)) + lines: function(FILE, (number | FileMode)...): (function(): ((string | number)...)) + lines: function(FILE, (number | string)...): (function(): (string...)) + + read: function(FILE): string + read: function(FILE, FileNumberMode...): number... + read: function(FILE, (number | FileStringMode)...): string... + read: function(FILE, (number | FileMode)...): ((string | number)...) + read: function(FILE, (number | string)...): (string...) + + seek: function(FILE, ? SeekWhence, ? number): integer, string + setvbuf: function(FILE, SetVBufMode, ? number) + + write: function(FILE, (string | number)...): FILE, string + + metamethod __close: function(FILE) + end + + global record coroutine + type Function = function(any...): any... + + create: function(Function): thread + close: function(thread): boolean, string + isyieldable: function(): boolean + resume: function(thread, any...): boolean, any... + running: function(): thread, boolean + status: function(thread): string + wrap: function(F): F + yield: function(any...): any... + end + + global record debug + record GetInfoTable + name: string + namewhat: string + source: string + short_src: string + linedefined: integer + lastlinedefined: integer + what: string + currentline: integer + istailcall: boolean + nups: integer + nparams: integer + isvararg: boolean + func: any + activelines: {integer:boolean} + end + + enum HookEvent + "call" "tail call" "return" "line" "count" + end + + type HookFunction = function(HookEvent, integer) + + type AnyFunction = function(any...):any... + + debug: function() + gethook: function(? thread): HookFunction, integer + + getinfo: function(AnyFunction | number): GetInfoTable + getinfo: function(AnyFunction | number, string): GetInfoTable + getinfo: function(thread, AnyFunction | number, string): GetInfoTable + + getlocal: function(thread, AnyFunction, number): string + getlocal: function(thread, number, number): string, any + getlocal: function(AnyFunction, number): string + getlocal: function(number, number): string, any + + getmetatable: function(T): metatable + getregistry: function(): {any:any} + getupvalue: function(AnyFunction, number): any + getuservalue: function(userdata, number): any + + sethook: function(thread, HookFunction, string, ? number) + sethook: function(HookFunction, string, ? number) + + setlocal: function(thread, number, number, any): string + setlocal: function(number, number, any): string + + setmetatable: function(T, metatable): T + setupvalue: function(AnyFunction, number, any): string + setuservalue: function(U, any, number): U --[[U is userdata]] + + traceback: function(thread, ? string, ? number): string + traceback: function(? string, ? number): string + + upvalueid: function(AnyFunction, number): userdata + upvaluejoin: function(AnyFunction, number, AnyFunction, number) + end + + global record io + enum OpenMode + "r" "w" "a" "r+" "w+" "a+" + "rb" "wb" "ab" "r+b" "w+b" "a+b" + "*r" "*w" "*a" "*r+" "*w+" "*a+" + "*rb" "*wb" "*ab" "*r+b" "*w+b" "*a+b" + end + + close: function(? FILE) + input: function(? FILE): FILE + flush: function() + + lines: function(? string): (function(): (string)) + lines: function(? string, FileNumberMode...): (function(): (number...)) + lines: function(? string, (number | FileStringMode)...): (function(): (string...)) + lines: function(? string, (number | FileMode)...): (function(): ((string | number)...)) + lines: function(? string, (number | string)...): (function(): (string...)) + + open: function(string, ? OpenMode): FILE, string + output: function(? FILE): FILE + popen: function(string, ? OpenMode): FILE, string + + read: function(): string + read: function(FileNumberMode...): number... + read: function((number | FileStringMode)...): string... + read: function((number | FileMode)...): ((string | number)...) + read: function((number | string)...): (string...) + + stderr: FILE + stdin: FILE + stdout: FILE + tmpfile: function(): FILE + type: function(any): string + write: function((string | number)...): FILE, string + end + + global record math + abs: function(integer): integer + abs: function(number): number + + acos: function(number): number + asin: function(number): number + atan: function(number, ? number): number + atan2: function(number, number): number + ceil: function(number): integer + cos: function(number): number + cosh: function(number): number + deg: function(number): number + exp: function(number): number + floor: function(number): integer + + fmod: function(integer, integer): integer + fmod: function(number, number): number + + frexp: function(number): number, number + huge: number + ldexp: function(number, number): number + log: function(number, ? number): number + log10: function(number): number + + max: function(integer...): integer + max: function((number | integer)...): number + max: function(T...): T + max: function(any...): any + + maxinteger: integer --[[needs_compat]] + + min: function(integer...): integer + min: function((number | integer)...): number + min: function(T...): T + min: function(any...): any + + mininteger: integer --[[needs_compat]] + + modf: function(number): integer, number + pi: number + pow: function(number, number): number + rad: function(number): number + + random: function(number, ? number): integer + random: function(): number + + randomseed: function(number, number): integer, integer + sin: function(number): number + sinh: function(number): number + sqrt: function(number): number + tan: function(number): number + tanh: function(number): number + tointeger: function(any): integer + type: function(any): string + ult: function(number, number): boolean + end + + global record metatable + enum Mode + "k" "v" "kv" + end + + __call: function(T, any...): any... + __mode: Mode + __name: string + __tostring: function(T): string + __pairs: function(T): (function(): (K, V)) + + __index: any --[[FIXME: function | table | anything with an __index metamethod]] + __newindex: any --[[FIXME: function | table | anything with an __index metamethod]] + + __gc: function(T) + __close: function(T) + + __add: function(any, any): any + __sub: function(any, any): any + __mul: function(any, any): any + __div: function(any, any): any + __idiv: function(any, any): any + __mod: function(any, any): any + __pow: function(any, any): any + __band: function(any, any): any + __bor: function(any, any): any + __bxor: function(any, any): any + __shl: function(any, any): any + __shr: function(any, any): any + __concat: function(any, any): any + + __len: function(T): any + __unm: function(T): any + __bnot: function(T): any + + __eq: function(any, any): boolean + __lt: function(any, any): boolean + __le: function(any, any): boolean + end + + global record os + record DateTable + year: integer + month: integer + day: integer + hour: integer + min: integer + sec: integer + wday: integer + yday: integer + isdst: boolean + end + + enum DateMode + "!*t" "*t" + end + + clock: function(): number + + date: function(DateMode, ? number): DateTable + date: function(? string, ? number): string + + difftime: function(number, number): number + execute: function(string): boolean, string, integer + exit: function(? (number | boolean), ? boolean) + getenv: function(string): string + remove: function(string): boolean, string + rename: function(string, string): boolean, string + setlocale: function(string, ? string): string + time: function(? DateTable): integer + tmpname: function(): string + end + + global record package + config: string + cpath: string + loaded: {string:any} + loaders: { function(string): any, any } + path: string + preload: {any:any} + searchers: { function(string): any } + end + + global record string + char: function(number...): string + + byte: function(string, ? number): integer + byte: function(string, number, number): integer... + + dump: function(function(any...): (any), ? boolean): string + find: function(string, string, ? number, ? boolean): integer, integer, string + format: function(string, any...): string + gmatch: function(string, string): (function(): string...) + + gsub: function(string, string, string, ? number): string, integer + gsub: function(string, string, {string:string}, ? number): string, integer + gsub: function(string, string, function(string...): (string | number | boolean), ? number): string, integer + + len: function(string): integer + lower: function(string): string + match: function(string, string, ? number): string... + pack: function(string, any...): string + packsize: function(string): integer + rep: function(string, number, ? string): string + reverse: function(string): string + sub: function(string, number, ? number): string + unpack: function(string, string, ? number): any... + upper: function(string): string + end + + global record table + type SortFunction = function(A, A): boolean + + record PackTable + is {A} + + n: integer + end + + concat: function({(string | number)}, ? string, ? number, ? number): string + + insert: function({A}, number, A) + insert: function({A}, A) + + pack: function(T...): PackTable + pack: function(any...): {any:any} + + remove: function({A}, ? number): A + sort: function({A}, ? SortFunction) + + unpack: function({A}, ? number, ? number): A... --[[needs_compat]] + end + + global record utf8 + char: function(number...): string + charpattern: string + codepoint: function(string, ? number, ? number, ? boolean): number... + codes: function(string, ? boolean): (function(string, ? number): (number, number)) + len: function(string, ? number, ? number, ? boolean): number + offset: function(string, number, ? number): number + end + + local record StandardLibrary + enum CollectGarbageCommand + "collect" + "count" + "stop" + "restart" + end + + enum CollectGarbageSetValue + "step" + "setpause" + "setstepmul" + end + + enum CollectGarbageIsRunning + "isrunning" + end + + type LoadFunction = function(): string + + enum LoadMode + "b" "t" "bt" + end + + type XpcallMsghFunction = function(...: any): () + + arg: {string} + assert: function(A, ? B): A + + collectgarbage: function(CollectGarbageCommand): number + collectgarbage: function(CollectGarbageSetValue, number): number + collectgarbage: function(CollectGarbageIsRunning): boolean + collectgarbage: function(string, ? number): (boolean | number) + + error: function(? any, ? number) + ipairs: function({A}): (function():(integer, A)) + + load: function((string | LoadFunction), ? string, ? LoadMode, ? table): (function, string) + load: function((string | LoadFunction), ? string, ? string, ? table): (function, string) + + next: function({K:V}, ? K): (K, V) + next: function({A}, ? integer): (integer, A) + + pairs: function({K:V}): (function():(K, V)) + pcall: function(function(any...):(any...), any...): boolean, any... + print: function(any...) + require: function(string): any + + select: function(number, T...): T + select: function(number, any...): any + select: function(string, any...): integer + + setmetatable: function(T, metatable): T + + tonumber: function(any): number + tonumber: function(any, number): integer + + tostring: function(any): string + type: function(any): string + xpcall: function(function(any...):(any...), XpcallMsghFunction, any...): boolean, any... + _VERSION: string + end + + global arg = StandardLibrary.arg + global assert = StandardLibrary.assert + global collectgarbage = StandardLibrary.collectgarbage + global error = StandardLibrary.error + global load = StandardLibrary.load + global next = StandardLibrary.next + global pairs = StandardLibrary.pairs + global pcall = StandardLibrary.pcall + global print = StandardLibrary.print + global require = StandardLibrary.require + global select = StandardLibrary.select + global setmetatable = StandardLibrary.setmetatable + global tostring = StandardLibrary.tostring + global tonumber = StandardLibrary.tonumber + global ipairs = StandardLibrary.ipairs + global type = StandardLibrary.type + global xpcall = StandardLibrary.xpcall + global _VERSION = StandardLibrary._VERSION +end + +]=====] + local record tl enum LoadMode "b" @@ -1695,10 +2141,10 @@ local macroexp a_union(t: {Type}): UnionType return a_type("union", { types = t } as UnionType) end -local macroexp a_poly(t: {FunctionType}): PolyType - return a_type("poly", { types = t } as PolyType) -end - +--local macroexp a_poly(t: {FunctionType}): PolyType +-- return a_type("poly", { types = t } as PolyType) +--end +-- local function a_function(t: FunctionType): FunctionType assert(t.min_arity) return a_type("function", t) @@ -1710,10 +2156,10 @@ local record Opt opttype: Type end -local function OPT(t: Type): Opt - return { opttype = t } -end - +--local function OPT(t: Type): Opt +-- return { opttype = t } +--end +-- local record Args is {Type|Opt} @@ -5062,7 +5508,6 @@ local skip_types: {TypeName: boolean} = { ["none"] = true, ["literal_table_item"] = true, ["unresolved"] = true, - ["typedecl"] = true, } local get_typenum: function(trenv: TypeReportEnv, t: Type): integer @@ -5126,12 +5571,14 @@ get_typenum = function(trenv: TypeReportEnv, t: Type): integer n = trenv.next_num local rt = t + if rt is TupleType and #rt.tuple == 1 then + rt = rt.tuple[1] + end + if rt is TypeDeclType then rt = rt.def elseif rt is TypeAliasType then rt = rt.alias_to - elseif rt is TupleType and #rt.tuple == 1 then - rt = rt.tuple[1] end local ti: TypeInfo = { @@ -5221,10 +5668,10 @@ local CIRCULAR_REQUIRE = a_type("circular_require", {}) local FUNCTION = a_fn { args = va_args { ANY }, rets = va_args { ANY } } -local NOMINAL_FILE = a_type("nominal", { names = {"FILE"} } as NominalType) +--local NOMINAL_FILE = a_type("nominal", { names = {"FILE"} } as NominalType) local XPCALL_MSGH_FUNCTION = a_fn { args = { ANY }, rets = { } } -local USERDATA = ANY -- Placeholder for maybe having a userdata "primitive" type +--local USERDATA = ANY -- Placeholder for maybe having a userdata "primitive" type local numeric_binop = { ["number"] = { @@ -5864,516 +6311,10 @@ local function convert_node_to_compat_mt_call(node: Node, mt_name: string, which node.e2[4] = e2 end +local stdlib_globals: {string:Variable} = nil local globals_typeid: integer local fresh_typevar_ctr = 1 -local function init_globals(lax: boolean): {string:Variable}, {string:Type} - local globals: {string:Variable} = {} - local stdlib_compat = get_stdlib_compat(lax) - - -- ensure globals are always initialized with the same typeids - local is_first_init = globals_typeid == nil - - local save_typeid = last_typeid - if is_first_init then - globals_typeid = new_typeid() - else - last_typeid = globals_typeid - end - - local function a_record(t: RecordType): RecordType - t = a_type("record", t) - assert(t.fields) - t.field_order = sorted_keys(t.fields) - return t - end - - local function a_generic(n: integer, f: function(...: TypeVarType): (T)): T - local typevars = {} - local typeargs = {} - local c = string.byte("A") - 1 - fresh_typevar_ctr = fresh_typevar_ctr + 1 - for i = 1, n do - local name = string.char(c + i) .. "@" .. fresh_typevar_ctr - typevars[i] = a_type("typevar", { typevar = name } as TypeVarType) - typeargs[i] = a_type("typearg", { typearg = name } as TypeArgType) - end - local t = f(table.unpack(typevars)) - if t is FunctionType or t is RecordType then - t.typeargs = typeargs - end - return t - end - - local function a_gfunction(n: integer, f: function(...: TypeVarType): FuncArgs): FunctionType - return a_generic(n, function(...: TypeVarType): FunctionType return a_fn(f(...)) end) - end - - local function a_grecord(n: integer, f: function(...: TypeVarType): RecordType): RecordType - local t = a_generic(n, f) - t.field_order = sorted_keys(t.fields) - return t - end - - local function an_enum(keys: {string}): EnumType - local t = a_type("enum", { enumset = {} } as EnumType) - for _, k in ipairs(keys) do - t.enumset[k] = true - end - return t - end - - local type TypeConstructor = function({Type}): Args - - local record ArgsRets - ctor: TypeConstructor - args: {Type} - rets: {Type} - end - - local function id(x: T): T - return x - end - - local file_reader_poly_types: {ArgsRets} = { - { ctor = va_args, args = { a_union { NUMBER, an_enum { "*a", "a", "*l", "l", "*L", "L" } } }, rets = { STRING } }, - { ctor = id, args = { an_enum { "*n", "n" } }, rets = { NUMBER, STRING } }, - { ctor = va_args, args = { a_union { NUMBER, an_enum { "*a", "a", "*l", "l", "*L", "L", "*n", "n" } } }, rets = { a_union { STRING, NUMBER } } }, - { ctor = va_args, args = { a_union { NUMBER, STRING } }, rets = { STRING } }, - { ctor = va_args, args = { }, rets = { STRING } }, - } - - local function a_file_reader(fn: (function(ctor: TypeConstructor, args: {Type}, rets: {Type}): FunctionType)): Type - local t = a_poly {} - for _, entry in ipairs(file_reader_poly_types) do - local args = shallow_copy_table(entry.args) - local rets = shallow_copy_table(entry.rets) - table.insert(t.types, fn(entry.ctor, args, rets)) - end - return t - end - - local LOAD_FUNCTION = a_fn { args = {}, rets = { STRING } } - - local OS_DATE_TABLE = a_record { - fields = { - ["year"] = INTEGER, - ["month"] = INTEGER, - ["day"] = INTEGER, - ["hour"] = INTEGER, - ["min"] = INTEGER, - ["sec"] = INTEGER, - ["wday"] = INTEGER, - ["yday"] = INTEGER, - ["isdst"] = BOOLEAN, - } - } - - local DEBUG_GETINFO_TABLE = a_record { - fields = { - ["name"] = STRING, - ["namewhat"] = STRING, - ["source"] = STRING, - ["short_src"] = STRING, - ["linedefined"] = INTEGER, - ["lastlinedefined"] = INTEGER, - ["what"] = STRING, - ["currentline"] = INTEGER, - ["istailcall"] = BOOLEAN, - ["nups"] = INTEGER, - ["nparams"] = INTEGER, - ["isvararg"] = BOOLEAN, - ["func"] = ANY, - ["activelines"] = a_map(INTEGER, BOOLEAN), - } - } - - local DEBUG_HOOK_EVENT = an_enum { "call", "tail call", "return", "line", "count" } - - local DEBUG_HOOK_FUNCTION = a_fn { - args = { DEBUG_HOOK_EVENT, INTEGER }, - rets = {}, - } - - local TABLE_SORT_FUNCTION = a_gfunction(1, function(a: Type): FuncArgs return { args = { a, a }, rets = { BOOLEAN } } end) - - local metatable_nominals: {NominalType} = {} - - local function METATABLE(a: Type): Type - local t = a_type("nominal", { names = {"metatable"}, typevals = { a } } as NominalType) - table.insert(metatable_nominals, t) - return t - end - - local standard_library: {string:Type} = { - ["..."] = a_vararg { STRING }, - ["any"] = a_typedecl(ANY), - ["arg"] = an_array(STRING), - ["assert"] = a_gfunction(2, function(a: Type, b: Type): FuncArgs return { args = { a, OPT(b) }, rets = { a } } end), - ["collectgarbage"] = a_poly { - a_fn { args = { an_enum { "collect", "count", "stop", "restart" } }, rets = { NUMBER } }, - a_fn { args = { an_enum { "step", "setpause", "setstepmul" }, NUMBER }, rets = { NUMBER } }, - a_fn { args = { an_enum { "isrunning" } }, rets = { BOOLEAN } }, - a_fn { args = { STRING, OPT(NUMBER) }, rets = { a_union { BOOLEAN, NUMBER } } }, - }, - ["dofile"] = a_fn { args = { OPT(STRING) }, rets = va_args { ANY } }, - ["error"] = a_fn { args = { ANY, OPT(NUMBER) }, rets = {} }, - ["getmetatable"] = a_gfunction(1, function(a: Type): FuncArgs return { args = { a }, rets = { METATABLE(a) } } end), - ["ipairs"] = a_gfunction(1, function(a: Type): FuncArgs return { args = { an_array(a) }, rets = { - a_fn { args = {}, rets = { INTEGER, a } }, - } } end), - ["load"] = a_fn { args = { a_union { STRING, LOAD_FUNCTION }, OPT(STRING), OPT(STRING), OPT(TABLE) }, rets = { FUNCTION, STRING } }, - ["loadfile"] = a_fn { args = { OPT(STRING), OPT(STRING), OPT(TABLE) }, rets = { FUNCTION, STRING } }, - ["next"] = a_poly { - a_gfunction(2, function(a: Type, b: Type): FuncArgs return { args = { a_map(a, b), OPT(a) }, rets = { a, b } } end), - a_gfunction(1, function(a: Type): FuncArgs return { args = { an_array(a), OPT(a) }, rets = { INTEGER, a } } end), - }, - ["pairs"] = a_gfunction(2, function(a: Type, b: Type): FuncArgs return { args = { a_map(a, b) }, rets = { - a_fn { args = {}, rets = { a, b } }, - } } end), - ["pcall"] = a_fn { args = va_args { FUNCTION, ANY }, rets = va_args { BOOLEAN, ANY } }, - ["xpcall"] = a_fn { args = va_args { FUNCTION, XPCALL_MSGH_FUNCTION, ANY }, rets = va_args { BOOLEAN, ANY } }, - ["print"] = a_fn { args = va_args { ANY }, rets = {} }, - ["rawequal"] = a_fn { args = { ANY, ANY }, rets = { BOOLEAN } }, - ["rawget"] = a_fn { args = { TABLE, ANY }, rets = { ANY } }, - ["rawlen"] = a_fn { args = { a_union { TABLE, STRING } }, rets = { INTEGER } }, - ["rawset"] = a_poly { - a_gfunction(2, function(a: Type, b: Type): FuncArgs return { args = { a_map(a, b), a, b }, rets = {} } end), - a_gfunction(1, function(a: Type): FuncArgs return { args = { an_array(a), NUMBER, a }, rets = {} } end), - a_fn { args = { TABLE, ANY, ANY }, rets = {} }, - }, - ["require"] = a_fn { args = { STRING }, rets = {} }, - ["select"] = a_poly { - a_gfunction(1, function(a: Type): FuncArgs return { args = va_args { NUMBER, a }, rets = { a } } end), - a_fn { args = va_args { NUMBER, ANY }, rets = { ANY } }, - a_fn { args = va_args { STRING, ANY }, rets = { INTEGER } }, - }, - ["setmetatable"] = a_gfunction(1, function(a: Type): FuncArgs return { args = { a, METATABLE(a) }, rets = { a } } end), - ["tonumber"] = a_poly { - a_fn { args = { ANY }, rets = { NUMBER } }, - a_fn { args = { ANY, NUMBER }, rets = { INTEGER } }, - }, - ["tostring"] = a_fn { args = { ANY }, rets = { STRING } }, - ["type"] = a_fn { args = { ANY }, rets = { STRING } }, - ["FILE"] = a_typedecl( - a_record { - is_userdata = true, - fields = { - ["close"] = a_fn { args = { NOMINAL_FILE }, rets = { BOOLEAN, STRING, INTEGER } }, - ["flush"] = a_fn { args = { NOMINAL_FILE }, rets = {} }, - ["lines"] = a_file_reader(function(ctor: TypeConstructor, args: {Type}, rets: {Type}): Type - table.insert(args, 1, NOMINAL_FILE) - return a_fn { args = ctor(args), rets = { - a_fn { args = {}, rets = ctor(rets) }, - } } - end), - ["read"] = a_file_reader(function(ctor: TypeConstructor, args: {Type}, rets: {Type}): Type - table.insert(args, 1, NOMINAL_FILE) - return a_fn { args = ctor(args), rets = ctor(rets) } - end), - ["seek"] = a_fn { args = { NOMINAL_FILE, OPT(STRING), OPT(NUMBER) }, rets = { INTEGER, STRING } }, - ["setvbuf"] = a_fn { args = { NOMINAL_FILE, STRING, OPT(NUMBER) }, rets = {} }, - ["write"] = a_fn { args = va_args { NOMINAL_FILE, a_union { STRING, NUMBER } }, rets = { NOMINAL_FILE, STRING } }, - -- TODO complete... - }, - meta_fields = { ["__close"] = FUNCTION }, - meta_field_order = { "__close" }, - } - ), - ["metatable"] = a_typedecl( - a_grecord(1, function(a: Type): RecordType return a_record { - fields = { - ["__call"] = a_fn { args = va_args { a, ANY }, rets = va_args { ANY } }, - ["__gc"] = a_fn { args = { a }, rets = {} }, - ["__index"] = ANY, -- FIXME: function | table | anything with an __index metamethod - ["__len"] = a_fn { args = { a }, rets = { ANY } }, - ["__mode"] = an_enum { "k", "v", "kv" }, - ["__newindex"] = ANY, -- FIXME: function | table | anything with a __newindex metamethod - ["__pairs"] = a_gfunction(2, function(k: Type, v: Type): FuncArgs - return { - args = { a }, - rets = { a_fn { args = {}, rets = { k, v } } } - } - end), - ["__tostring"] = a_fn { args = { a }, rets = { STRING } }, - ["__name"] = STRING, - ["__add"] = a_fn { args = { ANY, ANY }, rets = { ANY } }, - ["__sub"] = a_fn { args = { ANY, ANY }, rets = { ANY } }, - ["__mul"] = a_fn { args = { ANY, ANY }, rets = { ANY } }, - ["__div"] = a_fn { args = { ANY, ANY }, rets = { ANY } }, - ["__idiv"] = a_fn { args = { ANY, ANY }, rets = { ANY } }, - ["__mod"] = a_fn { args = { ANY, ANY }, rets = { ANY } }, - ["__pow"] = a_fn { args = { ANY, ANY }, rets = { ANY } }, - ["__unm"] = a_fn { args = { ANY }, rets = { ANY } }, - ["__band"] = a_fn { args = { ANY, ANY }, rets = { ANY } }, - ["__bor"] = a_fn { args = { ANY, ANY }, rets = { ANY } }, - ["__bxor"] = a_fn { args = { ANY, ANY }, rets = { ANY } }, - ["__bnot"] = a_fn { args = { ANY }, rets = { ANY } }, - ["__shl"] = a_fn { args = { ANY, ANY }, rets = { ANY } }, - ["__shr"] = a_fn { args = { ANY, ANY }, rets = { ANY } }, - ["__concat"] = a_fn { args = { ANY, ANY }, rets = { ANY } }, - ["__eq"] = a_fn { args = { ANY, ANY }, rets = { BOOLEAN } }, - ["__lt"] = a_fn { args = { ANY, ANY }, rets = { BOOLEAN } }, - ["__le"] = a_fn { args = { ANY, ANY }, rets = { BOOLEAN } }, - ["__close"] = a_fn { args = { a }, rets = { } }, - }, - } end) - ), - ["coroutine"] = a_record { - fields = { - ["create"] = a_fn { args = { FUNCTION }, rets = { THREAD } }, - ["close"] = a_fn { args = { THREAD }, rets = { BOOLEAN, STRING } }, - ["isyieldable"] = a_fn { args = {}, rets = { BOOLEAN } }, - ["resume"] = a_fn { args = va_args { THREAD, ANY }, rets = va_args { BOOLEAN, ANY } }, - ["running"] = a_fn { args = {}, rets = { THREAD, BOOLEAN } }, - ["status"] = a_fn { args = { THREAD }, rets = { STRING } }, - ["wrap"] = a_fn { args = { FUNCTION }, rets = { FUNCTION } }, - ["yield"] = a_fn { args = va_args { ANY }, rets = va_args { ANY } }, - } - }, - ["debug"] = a_record { - fields = { - ["Info"] = a_typedecl(DEBUG_GETINFO_TABLE), - ["Hook"] = a_typedecl(DEBUG_HOOK_FUNCTION), - ["HookEvent"] = a_typedecl(DEBUG_HOOK_EVENT), - - ["debug"] = a_fn { args = {}, rets = {} }, - ["gethook"] = a_fn { args = { OPT(THREAD) }, rets = { DEBUG_HOOK_FUNCTION, INTEGER } }, - ["getlocal"] = a_poly { - a_fn { args = { THREAD, FUNCTION, NUMBER }, rets = { STRING } }, - a_fn { args = { THREAD, NUMBER, NUMBER }, rets = { STRING, ANY } }, - a_fn { args = { FUNCTION, NUMBER }, rets = { STRING } }, - a_fn { args = { NUMBER, NUMBER }, rets = { STRING, ANY } }, - }, - ["getmetatable"] = a_gfunction(1, function(a: Type): FuncArgs return { args = { a }, rets = { METATABLE(a) } } end), - ["getregistry"] = a_fn { args = {}, rets = { TABLE } }, - ["getupvalue"] = a_fn { args = { FUNCTION, NUMBER }, rets = { ANY } }, - ["getuservalue"] = a_fn { args = { USERDATA, NUMBER }, rets = { ANY } }, - ["sethook"] = a_poly { - a_fn { args = { THREAD, DEBUG_HOOK_FUNCTION, STRING, NUMBER }, rets = {} }, - a_fn { args = { DEBUG_HOOK_FUNCTION, STRING, NUMBER }, rets = {} }, - }, - ["setlocal"] = a_poly { - a_fn { args = { THREAD, NUMBER, NUMBER, ANY }, rets = { STRING } }, - a_fn { args = { NUMBER, NUMBER, ANY }, rets = { STRING } }, - }, - ["setmetatable"] = a_gfunction(1, function(a: Type): FuncArgs return { args = { a, METATABLE(a) }, rets = { a } } end), - ["setupvalue"] = a_fn { args = { FUNCTION, NUMBER, ANY }, rets = { STRING } }, - ["setuservalue"] = a_fn { args = { USERDATA, ANY, NUMBER }, rets = { USERDATA } }, - ["traceback"] = a_poly { - a_fn { args = { OPT(THREAD), OPT(STRING), OPT(NUMBER) }, rets = { STRING } }, - a_fn { args = { OPT(STRING), OPT(NUMBER) }, rets = { STRING } }, - }, - ["upvalueid"] = a_fn { args = { FUNCTION, NUMBER }, rets = { USERDATA } }, - ["upvaluejoin"] = a_fn { args = { FUNCTION, NUMBER, FUNCTION, NUMBER }, rets = {} }, - ["getinfo"] = a_poly { - a_fn { args = { ANY }, rets = { DEBUG_GETINFO_TABLE } }, - a_fn { args = { ANY, STRING }, rets = { DEBUG_GETINFO_TABLE } }, - a_fn { args = { ANY, ANY, STRING }, rets = { DEBUG_GETINFO_TABLE } }, - }, - }, - }, - ["io"] = a_record { - fields = { - ["close"] = a_fn { args = { OPT(NOMINAL_FILE) }, rets = { BOOLEAN, STRING } }, - ["flush"] = a_fn { args = {}, rets = {} }, - ["input"] = a_fn { args = { OPT(a_union { STRING, NOMINAL_FILE }) }, rets = { NOMINAL_FILE } }, - ["lines"] = a_file_reader(function(ctor: TypeConstructor, args: {Type}, rets: {Type}): Type - return a_fn { args = ctor(args), rets = { - a_fn { args = {}, rets = ctor(rets) }, - } } - end), - ["open"] = a_fn { args = { STRING, OPT(STRING) }, rets = { NOMINAL_FILE, STRING } }, - ["output"] = a_fn { args = { OPT(a_union { STRING, NOMINAL_FILE }) }, rets = { NOMINAL_FILE } }, - ["popen"] = a_fn { args = { STRING, OPT(STRING) }, rets = { NOMINAL_FILE, STRING } }, - ["read"] = a_file_reader(function(ctor: TypeConstructor, args: {Type}, rets: {Type}): Type - return a_fn { args = ctor(args), rets = ctor(rets) } - end), - ["stderr"] = NOMINAL_FILE, - ["stdin"] = NOMINAL_FILE, - ["stdout"] = NOMINAL_FILE, - ["tmpfile"] = a_fn { args = {}, rets = { NOMINAL_FILE } }, - ["type"] = a_fn { args = { ANY }, rets = { STRING } }, - ["write"] = a_fn { args = va_args { a_union { STRING, NUMBER } }, rets = { NOMINAL_FILE, STRING } }, - }, - }, - ["math"] = a_record { - fields = { - ["abs"] = a_poly { - a_fn { args = { INTEGER }, rets = { INTEGER } }, - a_fn { args = { NUMBER }, rets = { NUMBER } }, - }, - ["acos"] = a_fn { args = { NUMBER }, rets = { NUMBER } }, - ["asin"] = a_fn { args = { NUMBER }, rets = { NUMBER } }, - ["atan"] = a_fn { args = { NUMBER, OPT(NUMBER) }, rets = { NUMBER } }, - ["atan2"] = a_fn { args = { NUMBER, NUMBER }, rets = { NUMBER } }, - ["ceil"] = a_fn { args = { NUMBER }, rets = { INTEGER } }, - ["cos"] = a_fn { args = { NUMBER }, rets = { NUMBER } }, - ["cosh"] = a_fn { args = { NUMBER }, rets = { NUMBER } }, - ["deg"] = a_fn { args = { NUMBER }, rets = { NUMBER } }, - ["exp"] = a_fn { args = { NUMBER }, rets = { NUMBER } }, - ["floor"] = a_fn { args = { NUMBER }, rets = { INTEGER } }, - ["fmod"] = a_poly { - a_fn { args = { INTEGER, INTEGER }, rets = { INTEGER } }, - a_fn { args = { NUMBER, NUMBER }, rets = { NUMBER } }, - }, - ["frexp"] = a_fn { args = { NUMBER }, rets = { NUMBER, NUMBER } }, - ["huge"] = NUMBER, - ["ldexp"] = a_fn { args = { NUMBER, NUMBER }, rets = { NUMBER } }, - ["log"] = a_fn { args = { NUMBER, OPT(NUMBER) }, rets = { NUMBER } }, - ["log10"] = a_fn { args = { NUMBER }, rets = { NUMBER } }, - ["max"] = a_poly { - a_fn { args = va_args { INTEGER }, rets = { INTEGER } }, - a_gfunction(1, function(a: Type): FuncArgs return { args = va_args { a }, rets = { a } } end), - a_fn { args = va_args { a_union { NUMBER, INTEGER } }, rets = { NUMBER } }, - a_fn { args = va_args { ANY }, rets = { ANY } }, - }, - ["maxinteger"] = a_type("integer", { needs_compat = true }), - ["min"] = a_poly { - a_fn { args = va_args { INTEGER }, rets = { INTEGER } }, - a_gfunction(1, function(a: Type): FuncArgs return { args = va_args { a }, rets = { a } } end), - a_fn { args = va_args { a_union { NUMBER, INTEGER } }, rets = { NUMBER } }, - a_fn { args = va_args { ANY }, rets = { ANY } }, - }, - ["mininteger"] = a_type("integer", { needs_compat = true }), - ["modf"] = a_fn { args = { NUMBER }, rets = { INTEGER, NUMBER } }, - ["pi"] = NUMBER, - ["pow"] = a_fn { args = { NUMBER, NUMBER }, rets = { NUMBER } }, - ["rad"] = a_fn { args = { NUMBER }, rets = { NUMBER } }, - ["random"] = a_poly { - a_fn { args = { NUMBER, OPT(NUMBER) }, rets = { INTEGER } }, - a_fn { args = {}, rets = { NUMBER } }, - }, - ["randomseed"] = a_fn { args = { NUMBER, NUMBER }, rets = { INTEGER, INTEGER } }, - ["sin"] = a_fn { args = { NUMBER }, rets = { NUMBER } }, - ["sinh"] = a_fn { args = { NUMBER }, rets = { NUMBER } }, - ["sqrt"] = a_fn { args = { NUMBER }, rets = { NUMBER } }, - ["tan"] = a_fn { args = { NUMBER }, rets = { NUMBER } }, - ["tanh"] = a_fn { args = { NUMBER }, rets = { NUMBER } }, - ["tointeger"] = a_fn { args = { ANY }, rets = { INTEGER } }, - ["type"] = a_fn { args = { ANY }, rets = { STRING } }, - ["ult"] = a_fn { args = { NUMBER, NUMBER }, rets = { BOOLEAN } }, - }, - }, - ["os"] = a_record { - fields = { - ["clock"] = a_fn { args = {}, rets = { NUMBER } }, - ["date"] = a_poly { - a_fn { args = { }, rets = { STRING } }, - a_fn { args = { an_enum { "!*t", "*t" }, OPT(NUMBER) }, rets = { OS_DATE_TABLE } }, - a_fn { args = { OPT(STRING), OPT(NUMBER) }, rets = { STRING } }, - }, - ["difftime"] = a_fn { args = { NUMBER, NUMBER }, rets = { NUMBER } }, - ["execute"] = a_fn { args = { STRING }, rets = { BOOLEAN, STRING, INTEGER } }, - ["exit"] = a_fn { args = { OPT(a_union { NUMBER, BOOLEAN }), OPT(BOOLEAN) }, rets = {} }, - ["getenv"] = a_fn { args = { STRING }, rets = { STRING } }, - ["remove"] = a_fn { args = { STRING }, rets = { BOOLEAN, STRING } }, - ["rename"] = a_fn { args = { STRING, STRING}, rets = { BOOLEAN, STRING } }, - ["setlocale"] = a_fn { args = { STRING, OPT(STRING) }, rets = { STRING } }, - ["time"] = a_fn { args = { OPT(OS_DATE_TABLE) }, rets = { INTEGER } }, - ["tmpname"] = a_fn { args = {}, rets = { STRING } }, - }, - }, - ["package"] = a_record { - fields = { - ["config"] = STRING, - ["cpath"] = STRING, - ["loaded"] = a_map(STRING, ANY), - ["loaders"] = an_array(a_fn { args = { STRING }, rets = { ANY, ANY } }), - ["loadlib"] = a_fn { args = { STRING, STRING }, rets = { FUNCTION } }, - ["path"] = STRING, - ["preload"] = TABLE, - ["searchers"] = an_array(a_fn { args = { STRING }, rets = { ANY, ANY } }), - ["searchpath"] = a_fn { args = { STRING, STRING, OPT(STRING), OPT(STRING) }, rets = { STRING, STRING } }, - }, - }, - ["string"] = a_record { - fields = { - ["byte"] = a_poly { - a_fn { args = { STRING, OPT(NUMBER) }, rets = { INTEGER } }, - a_fn { args = { STRING, NUMBER, NUMBER }, rets = va_args { INTEGER } }, - }, - ["char"] = a_fn { args = va_args { NUMBER }, rets = { STRING } }, - ["dump"] = a_fn { args = { FUNCTION, OPT(BOOLEAN) }, rets = { STRING } }, - ["find"] = a_fn { args = { STRING, STRING, OPT(NUMBER), OPT(BOOLEAN) }, rets = va_args { INTEGER, INTEGER, STRING } }, - ["format"] = a_fn { args = va_args { STRING, ANY }, rets = { STRING } }, - ["gmatch"] = a_fn { args = { STRING, STRING }, rets = { - a_fn { args = {}, rets = va_args { STRING } }, - } }, - ["gsub"] = a_poly { - a_fn { args = { STRING, STRING, a_map(STRING, STRING), OPT(NUMBER) }, rets = { STRING, INTEGER } }, - a_fn { args = { STRING, STRING, a_fn { args = va_args { STRING }, rets = { STRING } }, OPT(NUMBER) }, rets = { STRING, INTEGER } }, - a_fn { args = { STRING, STRING, a_fn { args = va_args { STRING }, rets = { NUMBER } }, OPT(NUMBER) }, rets = { STRING, INTEGER } }, - a_fn { args = { STRING, STRING, a_fn { args = va_args { STRING }, rets = { BOOLEAN } }, OPT(NUMBER) }, rets = { STRING, INTEGER } }, - a_fn { args = { STRING, STRING, a_fn { args = va_args { STRING }, rets = {} }, OPT(NUMBER) }, rets = { STRING, INTEGER } }, - a_fn { args = { STRING, STRING, OPT(STRING), OPT(NUMBER) }, rets = { STRING, INTEGER } }, - -- FIXME any other modes - }, - ["len"] = a_fn { args = { STRING }, rets = { INTEGER } }, - ["lower"] = a_fn { args = { STRING }, rets = { STRING } }, - ["match"] = a_fn { args = { STRING, OPT(STRING), OPT(NUMBER) }, rets = va_args { STRING } }, - ["pack"] = a_fn { args = va_args { STRING, ANY }, rets = { STRING } }, - ["packsize"] = a_fn { args = { STRING }, rets = { INTEGER } }, - ["rep"] = a_fn { args = { STRING, NUMBER, OPT(STRING) }, rets = { STRING } }, - ["reverse"] = a_fn { args = { STRING }, rets = { STRING } }, - ["sub"] = a_fn { args = { STRING, NUMBER, OPT(NUMBER) }, rets = { STRING } }, - ["unpack"] = a_fn { args = { STRING, STRING, OPT(NUMBER) }, rets = va_args { ANY } }, - ["upper"] = a_fn { args = { STRING }, rets = { STRING } }, - }, - }, - ["table"] = a_record { - fields = { - ["concat"] = a_fn { args = { an_array(a_union {STRING, NUMBER }), OPT(STRING), OPT(NUMBER), OPT(NUMBER) }, rets = { STRING } }, - ["insert"] = a_poly { - a_gfunction(1, function(a: Type): FuncArgs return { args = { an_array(a), NUMBER, a }, rets = {} } end), - a_gfunction(1, function(a: Type): FuncArgs return { args = { an_array(a), a }, rets = {} } end), - }, - ["move"] = a_poly { - a_gfunction(1, function(a: Type): FuncArgs return { args = { an_array(a), NUMBER, NUMBER, NUMBER }, rets = { an_array(a) } }end ), - a_gfunction(1, function(a: Type): FuncArgs return { args = { an_array(a), NUMBER, NUMBER, NUMBER, an_array(a) }, rets = { an_array(a) } } end), - }, - ["pack"] = a_fn { args = va_args { ANY }, rets = { TABLE } }, - ["remove"] = a_gfunction(1, function(a: Type): FuncArgs return { args = { an_array(a), OPT(NUMBER) }, rets = { a } } end), - ["sort"] = a_gfunction(1, function(a: Type): FuncArgs return { args = { an_array(a), OPT(TABLE_SORT_FUNCTION) }, rets = {} } end), - ["unpack"] = a_gfunction(1, function(a: Type): FuncArgs return { needs_compat = true, args = { an_array(a), OPT(NUMBER), OPT(NUMBER) }, rets = va_args { a } } end), - }, - }, - ["utf8"] = a_record { - fields = { - ["char"] = a_fn { args = va_args { NUMBER }, rets = { STRING } }, - ["charpattern"] = STRING, - ["codepoint"] = a_fn { args = { STRING, OPT(NUMBER), OPT(NUMBER) }, rets = va_args { INTEGER } }, - ["codes"] = a_fn { args = { STRING }, rets = { - a_fn { args = { STRING, OPT(NUMBER) }, rets = { NUMBER, NUMBER } }, - }, }, - ["len"] = a_fn { args = { STRING, NUMBER, NUMBER }, rets = { INTEGER } }, - ["offset"] = a_fn { args = { STRING, NUMBER, NUMBER }, rets = { INTEGER } }, - }, - }, - ["_VERSION"] = STRING, - } - - NOMINAL_FILE.found = standard_library["FILE"] - for _, m in ipairs(metatable_nominals) do - m.found = standard_library["metatable"] - end - - for name, typ in pairs(standard_library) do - globals[name] = { t = typ, needs_compat = stdlib_compat[name], attribute = "const" } - end - - -- only global scope and vararg functions accept `...`: - -- `@is_va` is an internal sentinel value which is - -- `any` if `...` is accepted in this scope or `nil` if it isn't. - globals["@is_va"] = { t = ANY } - - if not is_first_init then - last_typeid = save_typeid - end - - return globals, standard_library -end - local function set_feat(feat: tl.Feat, default: boolean): boolean if feat then return (feat == "on") @@ -6413,21 +6354,50 @@ tl.init_env = function(lax?: boolean, gen_compat?: boolean | CompatMode, gen_tar return nil, "gen-compat must be explicitly 'off' when gen-target is '5.4'" end - local globals, standard_library = init_globals(lax) - local env: Env = { modules = {}, loaded = {}, loaded_order = {}, - globals = globals, + globals = {}, gen_compat = gen_compat, gen_target = gen_target, } - -- make standard library tables available as modules for require() - for name, var in pairs(standard_library) do - if var is RecordType then - env.modules[name] = var + if not stdlib_globals then + local program, syntax_errors = tl.parse(stdlib, "stdlib.d.tl") + assert(#syntax_errors == 0) + local result = tl.type_check(program, { + filename = "@stdlib", + env = env + }) + assert(#result.type_errors == 0) + stdlib_globals = env.globals; + + -- special cases for compatibility + local math_t = (stdlib_globals["math"].t as TypeDeclType).def as RecordType + local table_t = (stdlib_globals["table"].t as TypeDeclType).def as RecordType + local integer_compat = a_type("integer", { needs_compat = true }) + math_t.fields["maxinteger"] = integer_compat + math_t.fields["mininteger"] = integer_compat + table_t.fields["unpack"].needs_compat = true + + -- only global scope and vararg functions accept `...`: + -- `@is_va` is an internal sentinel value which is + -- `any` if `...` is accepted in this scope or `nil` if it isn't. + stdlib_globals["..."] = { t = a_vararg { STRING } } + stdlib_globals["@is_va"] = { t = ANY } + + env.globals = {} + end + + local stdlib_compat = get_stdlib_compat(lax) + for name, var in pairs(stdlib_globals) do + env.globals[name] = var + var.needs_compat = stdlib_compat[name] + local t = var.t + if t is TypeDeclType then + -- make standard library tables available as modules for require() + env.modules[name] = t end end From 8f8408eba75d4b57edb5773ff891af57d115882a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Perrad?= Date: Mon, 8 Jan 2024 16:37:29 -0300 Subject: [PATCH 095/224] fix select signature --- tl.lua | 4 ++-- tl.tl | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tl.lua b/tl.lua index 3d2062a69..f6036ee0c 100644 --- a/tl.lua +++ b/tl.lua @@ -410,8 +410,8 @@ do print: function(any...) require: function(string): any - select: function(number, T...): T - select: function(number, any...): any + select: function(number, T...): T... + select: function(number, any...): any... select: function(string, any...): integer setmetatable: function(T, metatable): T diff --git a/tl.tl b/tl.tl index fbaa13909..9011df00c 100644 --- a/tl.tl +++ b/tl.tl @@ -410,8 +410,8 @@ do print: function(any...) require: function(string): any - select: function(number, T...): T - select: function(number, any...): any + select: function(number, T...): T... + select: function(number, any...): any... select: function(string, any...): integer setmetatable: function(T, metatable): T From 93126f4dbda0459eecb4e3fb9f4b0c4a0c59cf45 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Tue, 9 Jan 2024 02:28:45 -0300 Subject: [PATCH 096/224] Makefile: allow testing with different interpreter values This will allow me to more easily run local tests with my hacked `lua-no-tailcalls` interpreter binary. --- Makefile | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/Makefile b/Makefile index d32150d5a..e30d4f009 100644 --- a/Makefile +++ b/Makefile @@ -1,3 +1,5 @@ +LUA ?= + ifeq ($(OS), Windows_NT) BUSTED = busted.bat --suppress-pending --exclude-tags=unix else @@ -8,8 +10,8 @@ all: selfbuild suite selfbuild: cp tl.lua tl.lua.bak - ./tl gen --check tl.tl && cp tl.lua tl.lua.1 || { cp tl.lua tl.lua.1; cp tl.lua.bak tl.lua; exit 1; } - ./tl gen --check tl.tl && cp tl.lua tl.lua.2 || { cp tl.lua tl.lua.2; cp tl.lua.bak tl.lua; exit 1; } + $(LUA) ./tl gen --check tl.tl && cp tl.lua tl.lua.1 || { cp tl.lua tl.lua.1; cp tl.lua.bak tl.lua; exit 1; } + $(LUA) ./tl gen --check tl.tl && cp tl.lua tl.lua.2 || { cp tl.lua tl.lua.2; cp tl.lua.bak tl.lua; exit 1; } diff tl.lua.1 tl.lua.2 suite: From aed4035aac80a9f4506c35461243288c17a2f971 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Tue, 9 Jan 2024 14:20:26 -0300 Subject: [PATCH 097/224] typecodes: drop mask entries See #733. --- tl.lua | 24 ++++++------------------ tl.tl | 24 ++++++------------------ 2 files changed, 12 insertions(+), 36 deletions(-) diff --git a/tl.lua b/tl.lua index f6036ee0c..714faedac 100644 --- a/tl.lua +++ b/tl.lua @@ -643,35 +643,23 @@ tl.typecodes = { USERDATA = 0x00000040, THREAD = 0x00000080, - IS_TABLE = 0x00000008, - IS_NUMBER = 0x00000002, - IS_STRING = 0x00000004, - LUA_MASK = 0x00000fff, - INTEGER = 0x00010002, + ENUM = 0x00010004, + EMPTY_TABLE = 0x00000008, ARRAY = 0x00010008, RECORD = 0x00020008, MAP = 0x00040008, TUPLE = 0x00080008, - EMPTY_TABLE = 0x00000008, - ENUM = 0x00010004, INTERFACE = 0x00100008, - - IS_ARRAY = 0x00010008, - IS_RECORD = 0x00020008, + POLY = 0x20000020, + UNION = 0x40000000, NOMINAL = 0x10000000, TYPE_VARIABLE = 0x08000000, - IS_UNION = 0x40000000, - IS_POLY = 0x20000020, - ANY = 0xffffffff, UNKNOWN = 0x80008000, INVALID = 0x80000000, - - IS_SPECIAL = 0x80000000, - IS_VALID = 0x00000fff, } @@ -5485,12 +5473,12 @@ local typename_to_typecode = { ["thread"] = tl.typecodes.THREAD, ["number"] = tl.typecodes.NUMBER, ["integer"] = tl.typecodes.INTEGER, - ["union"] = tl.typecodes.IS_UNION, + ["union"] = tl.typecodes.UNION, ["nominal"] = tl.typecodes.NOMINAL, ["circular_require"] = tl.typecodes.NOMINAL, ["emptytable"] = tl.typecodes.EMPTY_TABLE, ["unresolved_emptytable_value"] = tl.typecodes.EMPTY_TABLE, - ["poly"] = tl.typecodes.IS_POLY, + ["poly"] = tl.typecodes.POLY, ["any"] = tl.typecodes.ANY, ["unknown"] = tl.typecodes.UNKNOWN, ["invalid"] = tl.typecodes.INVALID, diff --git a/tl.tl b/tl.tl index 9011df00c..52e85e11b 100644 --- a/tl.tl +++ b/tl.tl @@ -642,36 +642,24 @@ tl.typecodes = { FUNCTION = 0x00000020, USERDATA = 0x00000040, THREAD = 0x00000080, - -- Lua type masks - IS_TABLE = 0x00000008, - IS_NUMBER = 0x00000002, - IS_STRING = 0x00000004, - LUA_MASK = 0x00000fff, -- Teal types INTEGER = 0x00010002, + ENUM = 0x00010004, + EMPTY_TABLE = 0x00000008, ARRAY = 0x00010008, RECORD = 0x00020008, MAP = 0x00040008, TUPLE = 0x00080008, - EMPTY_TABLE = 0x00000008, - ENUM = 0x00010004, INTERFACE = 0x00100008, - -- Teal type masks - IS_ARRAY = 0x00010008, - IS_RECORD = 0x00020008, + POLY = 0x20000020, + UNION = 0x40000000, -- Indirect types NOMINAL = 0x10000000, TYPE_VARIABLE = 0x08000000, - -- Indirect type masks - IS_UNION = 0x40000000, - IS_POLY = 0x20000020, -- Special types ANY = 0xffffffff, UNKNOWN = 0x80008000, INVALID = 0x80000000, - -- Special type masks - IS_SPECIAL = 0x80000000, - IS_VALID = 0x00000fff, } local type Result = tl.Result @@ -5485,12 +5473,12 @@ local typename_to_typecode : {TypeName:integer} = { ["thread"] = tl.typecodes.THREAD, ["number"] = tl.typecodes.NUMBER, ["integer"] = tl.typecodes.INTEGER, - ["union"] = tl.typecodes.IS_UNION, + ["union"] = tl.typecodes.UNION, ["nominal"] = tl.typecodes.NOMINAL, ["circular_require"] = tl.typecodes.NOMINAL, ["emptytable"] = tl.typecodes.EMPTY_TABLE, ["unresolved_emptytable_value"] = tl.typecodes.EMPTY_TABLE, - ["poly"] = tl.typecodes.IS_POLY, + ["poly"] = tl.typecodes.POLY, ["any"] = tl.typecodes.ANY, ["unknown"] = tl.typecodes.UNKNOWN, ["invalid"] = tl.typecodes.INVALID, From c17be97a9245d1cd52fefbc2badeb5612ebe854c Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Wed, 21 Aug 2024 18:32:54 -0300 Subject: [PATCH 098/224] TypeReporter and TypeCollector objects --- spec/util.lua | 2 +- tl | 11 +- tl.lua | 402 +++++++++++++++++++++++++---------------------- tl.tl | 422 +++++++++++++++++++++++++++----------------------- 4 files changed, 448 insertions(+), 389 deletions(-) diff --git a/spec/util.lua b/spec/util.lua index fab01dcbf..1976f4c66 100644 --- a/spec/util.lua +++ b/spec/util.lua @@ -567,7 +567,7 @@ function util.check_types(code, types) local result = tl.type_check(ast, { filename = "foo.tl", env = env, lax = false }) batch:add(assert.same, {}, result.type_errors, "Code was not expected to have type errors") - local tr = tl.get_types(result, env.trenv) + local tr = env.reporter:get_report() for i, e in ipairs(types) do assert(e.x, "[" .. i .. "] missing 'x' key in test specification") assert(e.y, "[" .. i .. "] missing 'y' key in test specification") diff --git a/tl b/tl index 89e4990fc..68a329fbc 100755 --- a/tl +++ b/tl @@ -866,22 +866,17 @@ do env.keep_going = true env.report_types = true - local tr, trenv for i, input_file in ipairs(args["file"]) do - local pok, result = pcall(tl.process, input_file, env) + local pok, err = pcall(tl.process, input_file, env) if not pok then - die("Internal Compiler Error: " .. result) - end - if pok then - if result and result.ast then - tr, trenv = tl.get_types(result, trenv) - end + die("Internal Compiler Error: " .. err) end check_collect(i) end local ok, _, _, w = report_all_errors(tlconfig, env) + local tr = env.reporter:get_report() if tr then if w or not ok then printerr("") diff --git a/tl.lua b/tl.lua index 714faedac..d1a4a3f4e 100644 --- a/tl.lua +++ b/tl.lua @@ -447,16 +447,7 @@ end ]=====] -local tl = {PrettyPrintOptions = {}, TypeCheckOptions = {}, Env = {}, Symbol = {}, Result = {}, Error = {}, TypeInfo = {}, TypeReport = {}, TypeReportEnv = {}, EnvOptions = {}, } - - - - - - - - - +local tl = {PrettyPrintOptions = {}, TypeCheckOptions = {}, Env = {}, Result = {}, Error = {}, TypeInfo = {}, TypeReport = {}, EnvOptions = {}, } @@ -600,6 +591,7 @@ local tl = {PrettyPrintOptions = {}, TypeCheckOptions = {}, Env = {}, Symbol = { +local TypeReporter = {} @@ -677,6 +669,24 @@ tl.typecodes = { + + + + + + + + + + + + + + + + + + @@ -5498,9 +5508,6 @@ local skip_types = { ["unresolved"] = true, } -local get_typenum - - local function sorted_keys(m) local keys = {} for k, _ in pairs(m) do @@ -5517,8 +5524,8 @@ local function mark_array(x) return x end -function tl.init_type_report() - return { +function tl.new_type_reporter() + local self = { next_num = 1, typeid_to_num = {}, tr = { @@ -5528,35 +5535,36 @@ function tl.init_type_report() globals = {}, }, } + return setmetatable(self, { __index = TypeReporter }) end -local function store_function(trenv, ti, rt) +function TypeReporter:store_function(ti, rt) local args = {} for _, fnarg in ipairs(rt.args.tuple) do - table.insert(args, mark_array({ get_typenum(trenv, fnarg), nil })) + table.insert(args, mark_array({ self:get_typenum(fnarg), nil })) end ti.args = mark_array(args) local rets = {} for _, fnarg in ipairs(rt.rets.tuple) do - table.insert(rets, mark_array({ get_typenum(trenv, fnarg), nil })) + table.insert(rets, mark_array({ self:get_typenum(fnarg), nil })) end ti.rets = mark_array(rets) ti.vararg = not not rt.args.is_va ti.varret = not not rt.rets.is_va end -get_typenum = function(trenv, t) +function TypeReporter:get_typenum(t) assert(t.typeid) - local n = trenv.typeid_to_num[t.typeid] + local n = self.typeid_to_num[t.typeid] if n then return n end - local tr = trenv.tr + local tr = self.tr - n = trenv.next_num + n = self.next_num local rt = t if rt.typename == "tuple" and #rt.tuple == 1 then @@ -5577,12 +5585,12 @@ get_typenum = function(trenv, t) x = t.x, } tr.types[n] = ti - trenv.typeid_to_num[t.typeid] = n - trenv.next_num = trenv.next_num + 1 + self.typeid_to_num[t.typeid] = n + self.next_num = self.next_num + 1 if t.typename == "nominal" then if t.found then - ti.ref = get_typenum(trenv, t.found) + ti.ref = self:get_typenum(t.found) end if t.resolved then rt = t @@ -5595,26 +5603,26 @@ get_typenum = function(trenv, t) local r = {} for _, k in ipairs(rt.field_order) do local v = rt.fields[k] - r[k] = get_typenum(trenv, v) + r[k] = self:get_typenum(v) end ti.fields = r end if rt.elements then - ti.elements = get_typenum(trenv, rt.elements) + ti.elements = self:get_typenum(rt.elements) end if rt.typename == "map" then - ti.keys = get_typenum(trenv, rt.keys) - ti.values = get_typenum(trenv, rt.values) + ti.keys = self:get_typenum(rt.keys) + ti.values = self:get_typenum(rt.values) elseif rt.typename == "enum" then ti.enums = mark_array(sorted_keys(rt.enumset)) elseif rt.typename == "function" then - store_function(trenv, ti, rt) + self:store_function(ti, rt) elseif rt.types then local tis = {} for _, pt in ipairs(rt.types) do - table.insert(tis, get_typenum(trenv, pt)) + table.insert(tis, self:get_typenum(pt)) end ti.types = mark_array(tis) end @@ -5622,13 +5630,38 @@ get_typenum = function(trenv, t) return n end -local function make_type_reporter(filename, trenv) + + + + + + + + + + + + + + + + + +function TypeReporter:get_collector(filename) + local tc = { + filename = filename, + symbol_list = {}, + } + local ft = {} - trenv.tr.by_pos[filename] = ft + self.tr.by_pos[filename] = ft - local function store_type(y, x, typ) + local symbol_list = tc.symbol_list + local symbol_list_n = 0 + + tc.store_type = function(y, x, typ) if not typ or skip_types[typ.typename] then return end @@ -5639,10 +5672,128 @@ local function make_type_reporter(filename, trenv) ft[y] = yt end - yt[x] = get_typenum(trenv, typ) + yt[x] = self:get_typenum(typ) + end + + tc.reserve_symbol_list_slot = function(node) + symbol_list_n = symbol_list_n + 1 + node.symbol_list_slot = symbol_list_n + end + + tc.add_to_symbol_list = function(node, name, t) + if not node then + return + end + local slot + if node.symbol_list_slot then + slot = node.symbol_list_slot + else + symbol_list_n = symbol_list_n + 1 + slot = symbol_list_n + end + symbol_list[slot] = { y = node.y, x = node.x, name = name, typ = t } + end + + tc.begin_symbol_list_scope = function(node) + symbol_list_n = symbol_list_n + 1 + symbol_list[symbol_list_n] = { y = node.y, x = node.x, name = "@{" } + end + + tc.end_symbol_list_scope = function(node) + if symbol_list[symbol_list_n].name == "@{" then + symbol_list[symbol_list_n] = nil + symbol_list_n = symbol_list_n - 1 + else + symbol_list_n = symbol_list_n + 1 + symbol_list[symbol_list_n] = { y = assert(node.yend), x = assert(node.xend), name = "@}" } + end + end + + return tc +end + +function TypeReporter:store_result(tc, globals) + local tr = self.tr + + local filename = tc.filename + local symbol_list = tc.symbol_list + + tr.by_pos[filename][0] = nil + + + do + local n = 0 + local p = 0 + local n_stack, p_stack = {}, {} + local level = 0 + for i, s in ipairs(symbol_list) do + if s.typ then + n = n + 1 + elseif s.name == "@{" then + level = level + 1 + n_stack[level], p_stack[level] = n, p + n, p = 0, i + else + if n == 0 then + symbol_list[p].skip = true + s.skip = true + end + n, p = n_stack[level], p_stack[level] + level = level - 1 + end + end + end + + local symbols = mark_array({}) + tr.symbols_by_file[filename] = symbols + + + do + local stack = {} + local level = 0 + local i = 0 + for _, s in ipairs(symbol_list) do + if not s.skip then + i = i + 1 + local id + if s.typ then + id = self:get_typenum(s.typ) + elseif s.name == "@{" then + level = level + 1 + stack[level] = i + id = -1 + else + local other = stack[level] + level = level - 1 + symbols[other][4] = i + id = other - 1 + end + local sym = mark_array({ s.y, s.x, s.name, id }) + table.insert(symbols, sym) + end + end + end + + local gkeys = sorted_keys(globals) + for _, name in ipairs(gkeys) do + if name:sub(1, 1) ~= "@" then + local var = globals[name] + tr.globals[name] = self:get_typenum(var.t) + end + end + + if not tr.symbols then + tr.symbols = tr.symbols_by_file[filename] end +end + +function TypeReporter:get_report() + return self.tr +end - return store_type + +function tl.get_types(result) + return result.env.reporter:get_report(), result.env.reporter end @@ -6141,26 +6292,6 @@ function tl.search_module(module_name, search_dtl) end return nil, nil, tried end - - - - - - - - - - - - - - - - - - - - local function require_module(module_name, lax, env) local mod = env.modules[module_name] @@ -6405,6 +6536,7 @@ tl.init_env = function(lax, gen_compat, gen_target, predefined) end tl.type_check = function(ast, opts) + opts = opts or {} local env = opts.env if not env then @@ -6434,13 +6566,10 @@ tl.type_check = function(ast, opts) local module_type - local symbol_list - local symbol_list_n = 0 - local store_type + local tc if env.report_types then - symbol_list = {} - env.trenv = env.trenv or tl.init_type_report() - store_type = make_type_reporter(filename or "?", env.trenv) + env.reporter = env.reporter or tl.new_type_reporter() + tc = env.reporter:get_collector(filename or "?") end @@ -7108,11 +7237,6 @@ tl.type_check = function(ast, opts) return t end - local function reserve_symbol_list_slot(node) - symbol_list_n = symbol_list_n + 1 - node.symbol_list_slot = symbol_list_n - end - local get_unresolved local find_unresolved @@ -7175,15 +7299,8 @@ tl.type_check = function(ast, opts) return var end - if symbol_list and node then - local slot - if node.symbol_list_slot then - slot = node.symbol_list_slot - else - symbol_list_n = symbol_list_n + 1 - slot = symbol_list_n - end - symbol_list[slot] = { y = node.y, x = node.x, name = name, typ = t } + if tc and node then + tc.add_to_symbol_list(node, name, t) end return var @@ -7351,9 +7468,8 @@ tl.type_check = function(ast, opts) local function begin_scope(node) table.insert(st, {}) - if symbol_list and node then - symbol_list_n = symbol_list_n + 1 - symbol_list[symbol_list_n] = { y = node.y, x = node.x, name = "@{" } + if tc and node then + tc.begin_symbol_list_scope(node) end end @@ -7390,14 +7506,8 @@ tl.type_check = function(ast, opts) check_for_unused_vars(scope) table.remove(st) - if symbol_list and node then - if symbol_list[symbol_list_n].name == "@{" then - symbol_list[symbol_list_n] = nil - symbol_list_n = symbol_list_n - 1 - else - symbol_list_n = symbol_list_n + 1 - symbol_list[symbol_list_n] = { y = assert(node.yend), x = assert(node.xend), name = "@}" } - end + if tc and node then + tc.end_symbol_list_scope(node) end end @@ -8865,8 +8975,8 @@ a.types[i], b.types[i]), } ret = resolve_typevars_at(node, ret) end_scope() - if store_type and e1 then - store_type(e1.y, e1.x, f) + if tc and e1 then + tc.store_type(e1.y, e1.x, f) end if f and f.macroexp then @@ -10488,9 +10598,9 @@ expand_type(node, values, elements) }) }, ["local_declaration"] = { before = function(node) - if symbol_list then + if tc then for _, var in ipairs(node.vars) do - reserve_symbol_list_slot(var) + tc.reserve_symbol_list_slot(var) end end end, @@ -10540,8 +10650,8 @@ expand_type(node, values, elements) }) end end - if store_type then - store_type(var.y, var.x, t) + if tc then + tc.store_type(var.y, var.x, t) end dismiss_unresolved(var.tk) @@ -10603,8 +10713,8 @@ expand_type(node, values, elements) }) add_var(varnode, varname, rval, nil, "narrow") end - if store_type then - store_type(varnode.y, varnode.x, valtype) + if tc then + tc.store_type(varnode.y, varnode.x, valtype) end end end @@ -10739,8 +10849,8 @@ expand_type(node, values, elements) }) end add_var(v, v.tk, r) - if store_type then - store_type(v.y, v.x, r) + if tc then + tc.store_type(v.y, v.x, r) end last = r @@ -11050,8 +11160,8 @@ expand_type(node, values, elements) }) ["local_function"] = { before = function(node) widen_all_unions() - if symbol_list then - reserve_symbol_list_slot(node) + if tc then + tc.reserve_symbol_list_slot(node) end begin_scope(node) end, @@ -11084,8 +11194,8 @@ expand_type(node, values, elements) }) ["local_macroexp"] = { before = function(node) widen_all_unions() - if symbol_list then - reserve_symbol_list_slot(node) + if tc then + tc.reserve_symbol_list_slot(node) end begin_scope(node) end, @@ -12090,7 +12200,7 @@ expand_type(node, values, elements) }) local where = w if where.y then - store_type(where.y, where.x, t) + tc.store_type(where.y, where.x, t) end return t @@ -12110,7 +12220,7 @@ expand_type(node, values, elements) }) visit_type.after = internal_compiler_check(visit_type.after) end - if store_type then + if tc then visit_node.after = store_type_after(visit_node.after) visit_type.after = store_type_after(visit_type.after) end @@ -12167,7 +12277,6 @@ expand_type(node, values, elements) }) filename = filename, warnings = warnings, type_errors = errors, - symbol_list = symbol_list, dependencies = dependencies, } @@ -12178,93 +12287,16 @@ expand_type(node, values, elements) }) env.modules[opts.module_name] = result.type end - return result -end - - - - - -function tl.get_types(result, trenv) - local filename = result.filename or "?" - trenv = trenv or result.env.trenv - - if not trenv then - error("result must have been generated with env.report_types = true", 2) - end - - local tr = trenv.tr - - tr.by_pos[filename][0] = nil - - - do - local n = 0 - local p = 0 - local n_stack, p_stack = {}, {} - local level = 0 - for i, s in ipairs(result.symbol_list) do - if s.typ then - n = n + 1 - elseif s.name == "@{" then - level = level + 1 - n_stack[level], p_stack[level] = n, p - n, p = 0, i - else - if n == 0 then - result.symbol_list[p].skip = true - s.skip = true - end - n, p = n_stack[level], p_stack[level] - level = level - 1 - end - end + if tc then + env.reporter:store_result(tc, env.globals) end - local symbols = mark_array({}) - tr.symbols_by_file[filename] = symbols - + return result +end - do - local stack = {} - local level = 0 - local i = 0 - for _, s in ipairs(result.symbol_list) do - if not s.skip then - i = i + 1 - local id - if s.typ then - id = get_typenum(trenv, s.typ) - elseif s.name == "@{" then - level = level + 1 - stack[level] = i - id = -1 - else - local other = stack[level] - level = level - 1 - symbols[other][4] = i - id = other - 1 - end - local sym = mark_array({ s.y, s.x, s.name, id }) - table.insert(symbols, sym) - end - end - end - local gkeys = sorted_keys(result.env.globals) - for _, name in ipairs(gkeys) do - if name:sub(1, 1) ~= "@" then - local var = result.env.globals[name] - tr.globals[name] = get_typenum(trenv, var.t) - end - end - if not tr.symbols then - tr.symbols = tr.symbols_by_file[filename] - end - return tr, trenv -end function tl.symbols_in_scope(tr, y, x) local function find(symbols, at_y, at_x) diff --git a/tl.tl b/tl.tl index 52e85e11b..ca89311be 100644 --- a/tl.tl +++ b/tl.tl @@ -496,7 +496,7 @@ local record tl modules: {string:Type} loaded: {string:Result} loaded_order: {string} - trenv: TypeReportEnv + reporter: TypeReporter gen_compat: CompatMode gen_target: TargetMode keep_going: boolean @@ -504,15 +504,6 @@ local record tl feat_arity: boolean end - record Symbol - x: integer - y: integer - name: string - typ: Type - other: integer - skip: boolean - end - record Result filename: string ast: Node @@ -521,7 +512,6 @@ local record tl type_errors: {Error} gen_error: string warnings: {Error} - symbol_list: {Symbol} env: Env dependencies: {string:string} -- module name, file found end @@ -580,12 +570,6 @@ local record tl globals: {string: integer} end - record TypeReportEnv - typeid_to_num: {integer: integer} - next_num: integer - tr: TypeReport - end - record EnvOptions lax_mode: boolean gen_compat: CompatMode @@ -607,6 +591,14 @@ local record tl load_envs: { {any:any} : Env } end +local record TypeReporter + typeid_to_num: {integer: integer} + next_num: integer + tr: TypeReport + + get_typenum: function(TypeReporter, Type): integer +end + tl.version = function(): string return VERSION end @@ -673,8 +665,26 @@ local type LoadFunction = tl.LoadFunction local type TargetMode = tl.TargetMode local type TypeInfo = tl.TypeInfo local type TypeReport = tl.TypeReport -local type TypeReportEnv = tl.TypeReportEnv -local type Symbol = tl.Symbol + +local enum Narrow + "narrow" + "narrowed_declaration" + "declaration" +end + +local record Variable + t: Type + attribute: Attribute + needs_compat: boolean + narrowed_from: Type + is_narrowed: Narrow + declared_at: Node + is_func_arg: boolean + used: boolean + used_as_type: boolean + aliasing: Variable + implemented: {string:boolean} +end -------------------------------------------------------------------------------- -- Compiler debugging @@ -5498,9 +5508,6 @@ local skip_types: {TypeName: boolean} = { ["unresolved"] = true, } -local get_typenum: function(trenv: TypeReportEnv, t: Type): integer -local type StoreType = function(y: integer, x: integer, typ: Type) - local function sorted_keys(m: {A:B}):{A} local keys = {} for k, _ in pairs(m) do @@ -5517,8 +5524,8 @@ local function mark_array(x: T): T return x end -function tl.init_type_report(): TypeReportEnv - return { +function tl.new_type_reporter(): TypeReporter + local self: TypeReporter = { next_num = 1, typeid_to_num = {}, tr = { @@ -5528,35 +5535,36 @@ function tl.init_type_report(): TypeReportEnv globals = {}, }, } + return setmetatable(self, { __index = TypeReporter }) end -local function store_function(trenv: TypeReportEnv, ti: TypeInfo, rt: FunctionType) +function TypeReporter:store_function(ti: TypeInfo, rt: FunctionType) local args: {{integer, string}} = {} for _, fnarg in ipairs(rt.args.tuple) do - table.insert(args, mark_array { get_typenum(trenv, fnarg), nil }) + table.insert(args, mark_array { self:get_typenum(fnarg), nil }) end ti.args = mark_array(args) local rets: {{integer, string}} = {} for _, fnarg in ipairs(rt.rets.tuple) do - table.insert(rets, mark_array { get_typenum(trenv, fnarg), nil }) + table.insert(rets, mark_array { self:get_typenum(fnarg), nil }) end ti.rets = mark_array(rets) ti.vararg = not not rt.args.is_va ti.varret = not not rt.rets.is_va end -get_typenum = function(trenv: TypeReportEnv, t: Type): integer +function TypeReporter:get_typenum(t: Type): integer assert(t.typeid) -- try by typeid - local n = trenv.typeid_to_num[t.typeid] + local n = self.typeid_to_num[t.typeid] if n then return n end - local tr = trenv.tr + local tr = self.tr -- it's a new entry: store and increment - n = trenv.next_num + n = self.next_num local rt = t if rt is TupleType and #rt.tuple == 1 then @@ -5577,12 +5585,12 @@ get_typenum = function(trenv: TypeReportEnv, t: Type): integer x = t.x, } tr.types[n] = ti - trenv.typeid_to_num[t.typeid] = n - trenv.next_num = trenv.next_num + 1 + self.typeid_to_num[t.typeid] = n + self.next_num = self.next_num + 1 if t is NominalType then if t.found then - ti.ref = get_typenum(trenv, t.found) + ti.ref = self:get_typenum(t.found) end if t.resolved then rt = t @@ -5595,26 +5603,26 @@ get_typenum = function(trenv: TypeReportEnv, t: Type): integer local r = {} for _, k in ipairs(rt.field_order) do local v = rt.fields[k] - r[k] = get_typenum(trenv, v) + r[k] = self:get_typenum(v) end ti.fields = r end if rt is ArrayLikeType then - ti.elements = get_typenum(trenv, rt.elements) + ti.elements = self:get_typenum(rt.elements) end if rt is MapType then - ti.keys = get_typenum(trenv, rt.keys) - ti.values = get_typenum(trenv, rt.values) + ti.keys = self:get_typenum(rt.keys) + ti.values = self:get_typenum(rt.values) elseif rt is EnumType then ti.enums = mark_array(sorted_keys(rt.enumset)) elseif rt is FunctionType then - store_function(trenv, ti, rt) + self:store_function(ti, rt) elseif rt is AggregateType then local tis = {} for _, pt in ipairs(rt.types) do - table.insert(tis, get_typenum(trenv, pt)) + table.insert(tis, self:get_typenum(pt)) end ti.types = mark_array(tis) end @@ -5622,13 +5630,38 @@ get_typenum = function(trenv: TypeReportEnv, t: Type): integer return n end -local function make_type_reporter(filename: string, trenv: TypeReportEnv): StoreType --- local filename = result.filename or "?" +local record TypeCollector + record Symbol + x: integer + y: integer + name: string + typ: Type + skip: boolean + end + + filename: string + symbol_list: {Symbol} + + store_type: function(y: integer, x: integer, typ: Type) + reserve_symbol_list_slot: function(Node) + add_to_symbol_list: function(node: Node, name: string, t: Type) + begin_symbol_list_scope: function(node: Node) + end_symbol_list_scope: function(node: Node) +end + +function TypeReporter:get_collector(filename: string): TypeCollector + local tc: TypeCollector = { + filename = filename, + symbol_list = {}, + } local ft: {integer:{integer:integer}} = {} - trenv.tr.by_pos[filename] = ft + self.tr.by_pos[filename] = ft + + local symbol_list = tc.symbol_list + local symbol_list_n = 0 - local function store_type(y: integer, x: integer, typ: Type) + tc.store_type = function(y: integer, x: integer, typ: Type) if not typ or skip_types[typ.typename] then return end @@ -5639,10 +5672,128 @@ local function make_type_reporter(filename: string, trenv: TypeReportEnv): Store ft[y] = yt end - yt[x] = get_typenum(trenv, typ) + yt[x] = self:get_typenum(typ) + end + + tc.reserve_symbol_list_slot = function(node: Node) + symbol_list_n = symbol_list_n + 1 + node.symbol_list_slot = symbol_list_n + end + + tc.add_to_symbol_list = function(node: Node, name: string, t: Type) + if not node then + return + end + local slot: integer + if node.symbol_list_slot then + slot = node.symbol_list_slot + else + symbol_list_n = symbol_list_n + 1 + slot = symbol_list_n + end + symbol_list[slot] = { y = node.y, x = node.x, name = name, typ = t } + end + + tc.begin_symbol_list_scope = function(node: Node) + symbol_list_n = symbol_list_n + 1 + symbol_list[symbol_list_n] = { y = node.y, x = node.x, name = "@{" } + end + + tc.end_symbol_list_scope = function(node: Node) + if symbol_list[symbol_list_n].name == "@{" then + symbol_list[symbol_list_n] = nil + symbol_list_n = symbol_list_n - 1 + else + symbol_list_n = symbol_list_n + 1 + symbol_list[symbol_list_n] = { y = assert(node.yend), x = assert(node.xend), name = "@}" } + end + end + + return tc +end + +function TypeReporter:store_result(tc: TypeCollector, globals: {string:Variable}) + local tr = self.tr + + local filename = tc.filename + local symbol_list = tc.symbol_list + + tr.by_pos[filename][0] = nil + + -- mark unneeded scope blocks to be skipped + do + local n = 0 -- number of symbols in current scope + local p = 0 -- opening position of current scope block + local n_stack, p_stack = {}, {} + local level = 0 + for i, s in ipairs(symbol_list) do + if s.typ then + n = n + 1 + elseif s.name == "@{" then + level = level + 1 + n_stack[level], p_stack[level] = n, p -- push current scope + n, p = 0, i -- begin new scope + else + if n == 0 then -- nothing declared in this scope + symbol_list[p].skip = true -- skip @{ + s.skip = true -- skip @} + end + n, p = n_stack[level], p_stack[level] -- pop previous scope + level = level - 1 + end + end + end + + local symbols: {TypeReport.Symbol} = mark_array {} + tr.symbols_by_file[filename] = symbols + + -- resolve scope cross references, skipping unneeded scope blocks + do + local stack = {} + local level = 0 + local i = 0 + for _, s in ipairs(symbol_list) do + if not s.skip then + i = i + 1 + local id: integer + if s.typ then + id = self:get_typenum(s.typ) + elseif s.name == "@{" then + level = level + 1 + stack[level] = i + id = -1 -- will be overwritten + else + local other = stack[level] + level = level - 1 + symbols[other][4] = i -- overwrite id from @{ + id = other - 1 + end + local sym = mark_array({ s.y, s.x, s.name, id }) + table.insert(symbols, sym) + end + end + end + + local gkeys = sorted_keys(globals) + for _, name in ipairs(gkeys) do + if name:sub(1, 1) ~= "@" then + local var = globals[name] + tr.globals[name] = self:get_typenum(var.t) + end + end + + if not tr.symbols then + tr.symbols = tr.symbols_by_file[filename] end +end + +function TypeReporter:get_report(): TypeReport + return self.tr +end - return store_type +-- backwards compatibility +function tl.get_types(result: Result): TypeReport, TypeReporter + return result.env.reporter:get_report(), result.env.reporter end -------------------------------------------------------------------------------- @@ -6142,26 +6293,6 @@ function tl.search_module(module_name: string, search_dtl: boolean): string, FIL return nil, nil, tried end -local enum Narrow - "narrow" - "narrowed_declaration" - "declaration" -end - -local record Variable - t: Type - attribute: Attribute - needs_compat: boolean - narrowed_from: Type - is_narrowed: Narrow - declared_at: Node - is_func_arg: boolean - used: boolean - used_as_type: boolean - aliasing: Variable - implemented: {string:boolean} -end - local function require_module(module_name: string, lax: boolean, env: Env): Type, boolean local mod = env.modules[module_name] if mod then @@ -6405,6 +6536,7 @@ tl.init_env = function(lax?: boolean, gen_compat?: boolean | CompatMode, gen_tar end tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string + opts = opts or {} local env = opts.env if not env then @@ -6434,13 +6566,10 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local module_type: Type - local symbol_list: {Symbol} - local symbol_list_n = 0 - local store_type: StoreType + local tc: TypeCollector if env.report_types then - symbol_list = {} - env.trenv = env.trenv or tl.init_type_report() - store_type = make_type_reporter(filename or "?", env.trenv) + env.reporter = env.reporter or tl.new_type_reporter() + tc = env.reporter:get_collector(filename or "?") end local enum VarUse @@ -7108,11 +7237,6 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return t end - local function reserve_symbol_list_slot(node: Node) - symbol_list_n = symbol_list_n + 1 - node.symbol_list_slot = symbol_list_n - end - local get_unresolved: function(scope?: Scope): UnresolvedType local find_unresolved: function(level?: integer): UnresolvedType @@ -7175,15 +7299,8 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return var end - if symbol_list and node then - local slot: integer - if node.symbol_list_slot then - slot = node.symbol_list_slot - else - symbol_list_n = symbol_list_n + 1 - slot = symbol_list_n - end - symbol_list[slot] = { y = node.y, x = node.x, name = name, typ = t } + if tc and node then + tc.add_to_symbol_list(node, name, t) end return var @@ -7351,9 +7468,8 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local function begin_scope(node?: Node) table.insert(st, {}) - if symbol_list and node then - symbol_list_n = symbol_list_n + 1 - symbol_list[symbol_list_n] = { y = node.y, x = node.x, name = "@{" } + if tc and node then + tc.begin_symbol_list_scope(node) end end @@ -7390,14 +7506,8 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string check_for_unused_vars(scope) table.remove(st) - if symbol_list and node then - if symbol_list[symbol_list_n].name == "@{" then - symbol_list[symbol_list_n] = nil - symbol_list_n = symbol_list_n - 1 - else - symbol_list_n = symbol_list_n + 1 - symbol_list[symbol_list_n] = { y = assert(node.yend), x = assert(node.xend), name = "@}" } - end + if tc and node then + tc.end_symbol_list_scope(node) end end @@ -8865,8 +8975,8 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string ret = resolve_typevars_at(node, ret) end_scope() - if store_type and e1 then - store_type(e1.y, e1.x, f) + if tc and e1 then + tc.store_type(e1.y, e1.x, f) end if f and f.macroexp then @@ -10488,9 +10598,9 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string }, ["local_declaration"] = { before = function(node: Node) - if symbol_list then + if tc then for _, var in ipairs(node.vars) do - reserve_symbol_list_slot(var) + tc.reserve_symbol_list_slot(var) end end end, @@ -10540,8 +10650,8 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - if store_type then - store_type(var.y, var.x, t) + if tc then + tc.store_type(var.y, var.x, t) end dismiss_unresolved(var.tk) @@ -10603,8 +10713,8 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string add_var(varnode, varname, rval, nil, "narrow") end - if store_type then - store_type(varnode.y, varnode.x, valtype) + if tc then + tc.store_type(varnode.y, varnode.x, valtype) end end end @@ -10739,8 +10849,8 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end add_var(v, v.tk, r) - if store_type then - store_type(v.y, v.x, r) + if tc then + tc.store_type(v.y, v.x, r) end last = r @@ -11050,8 +11160,8 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string ["local_function"] = { before = function(node: Node) widen_all_unions() - if symbol_list then - reserve_symbol_list_slot(node) + if tc then + tc.reserve_symbol_list_slot(node) end begin_scope(node) end, @@ -11084,8 +11194,8 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string ["local_macroexp"] = { before = function(node: Node) widen_all_unions() - if symbol_list then - reserve_symbol_list_slot(node) + if tc then + tc.reserve_symbol_list_slot(node) end begin_scope(node) end, @@ -12090,7 +12200,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local where = w as Where if where.y then - store_type(where.y, where.x, t) + tc.store_type(where.y, where.x, t) end return t @@ -12110,7 +12220,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string visit_type.after = internal_compiler_check(visit_type.after) end - if store_type then + if tc then visit_node.after = store_type_after(visit_node.after) visit_type.after = store_type_after(visit_type.after) end @@ -12167,7 +12277,6 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string filename = filename, warnings = warnings, type_errors = errors, - symbol_list = symbol_list, dependencies = dependencies, } @@ -12178,6 +12287,10 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string env.modules[opts.module_name] = result.type end + if tc then + env.reporter:store_result(tc, env.globals) + end + return result end @@ -12185,87 +12298,6 @@ end -- Report types -------------------------------------------------------------------------------- -function tl.get_types(result: Result, trenv: TypeReportEnv): TypeReport, TypeReportEnv - local filename = result.filename or "?" - trenv = trenv or result.env.trenv - - if not trenv then - error("result must have been generated with env.report_types = true", 2) - end - - local tr = trenv.tr - - tr.by_pos[filename][0] = nil - - -- mark unneeded scope blocks to be skipped - do - local n = 0 -- number of symbols in current scope - local p = 0 -- opening position of current scope block - local n_stack, p_stack = {}, {} - local level = 0 - for i, s in ipairs(result.symbol_list) do - if s.typ then - n = n + 1 - elseif s.name == "@{" then - level = level + 1 - n_stack[level], p_stack[level] = n, p -- push current scope - n, p = 0, i -- begin new scope - else - if n == 0 then -- nothing declared in this scope - result.symbol_list[p].skip = true -- skip @{ - s.skip = true -- skip @} - end - n, p = n_stack[level], p_stack[level] -- pop previous scope - level = level - 1 - end - end - end - - local symbols: {TypeReport.Symbol} = mark_array {} - tr.symbols_by_file[filename] = symbols - - -- resolve scope cross references, skipping unneeded scope blocks - do - local stack = {} - local level = 0 - local i = 0 - for _, s in ipairs(result.symbol_list) do - if not s.skip then - i = i + 1 - local id: integer - if s.typ then - id = get_typenum(trenv, s.typ) - elseif s.name == "@{" then - level = level + 1 - stack[level] = i - id = -1 -- will be overwritten - else - local other = stack[level] - level = level - 1 - symbols[other][4] = i -- overwrite id from @{ - id = other - 1 - end - local sym = mark_array({ s.y, s.x, s.name, id }) - table.insert(symbols, sym) - end - end - end - - local gkeys = sorted_keys(result.env.globals) - for _, name in ipairs(gkeys) do - if name:sub(1, 1) ~= "@" then - local var = result.env.globals[name] - tr.globals[name] = get_typenum(trenv, var.t) - end - end - - if not tr.symbols then - tr.symbols = tr.symbols_by_file[filename] - end - - return tr, trenv -end - function tl.symbols_in_scope(tr: TypeReport, y: integer, x: integer): {string:integer} local function find(symbols: {{integer, integer, string, integer}}, at_y: integer, at_x: integer): integer local function le(a: {integer, integer}, b: {integer, integer}): boolean From e7a1966416af6f46226fa2718ad834ef7873d0d6 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Wed, 10 Jan 2024 01:28:00 -0300 Subject: [PATCH 099/224] Where: make it an interface --- tl.lua | 16 ++++++---------- tl.tl | 20 ++++++++------------ 2 files changed, 14 insertions(+), 22 deletions(-) diff --git a/tl.lua b/tl.lua index d1a4a3f4e..eab8eee2b 100644 --- a/tl.lua +++ b/tl.lua @@ -447,6 +447,12 @@ end ]=====] + + + + + + local tl = {PrettyPrintOptions = {}, TypeCheckOptions = {}, Env = {}, Result = {}, Error = {}, TypeInfo = {}, TypeReport = {}, EnvOptions = {}, } @@ -1827,8 +1833,6 @@ local table_types = { - - @@ -2001,14 +2005,6 @@ local Node = {ExpectedContext = {}, } - - - - - - - - diff --git a/tl.tl b/tl.tl index ca89311be..59270631f 100644 --- a/tl.tl +++ b/tl.tl @@ -447,6 +447,12 @@ end ]=====] +local interface Where + y: integer + x: integer + filename: string +end + local record tl enum LoadMode "b" @@ -1528,16 +1534,14 @@ local table_types : {TypeName:boolean} = { } local interface Type + is Where where self.typename typename: TypeName -- discriminator typeid: integer -- unique identifier - y: integer - x: integer yend: integer xend: integer - filename: string inferred_at: Where @@ -1908,7 +1912,7 @@ local attributes : {Attribute: boolean} = { local is_attribute : {string:boolean} = attributes as {string:boolean} local record Node - is {Node} + is {Node}, Where where self.kind ~= nil record ExpectedContext @@ -1916,10 +1920,6 @@ local record Node name: string end - y: integer - x: integer - filename: string - tk: string kind: NodeKind symbol_list_slot: integer @@ -2014,10 +2014,6 @@ local record Node debug_type: Type end -local type Where - = Node - | Type - local function is_number_type(t:Type): boolean return t.typename == "number" or t.typename == "integer" end From ea7d8413c02443f3a7209010b3d6c59244cdb8b2 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Wed, 10 Jan 2024 01:37:17 -0300 Subject: [PATCH 100/224] Type: remove xend and yend --- tl.lua | 8 -------- tl.tl | 16 ++++------------ 2 files changed, 4 insertions(+), 20 deletions(-) diff --git a/tl.lua b/tl.lua index eab8eee2b..56fde1759 100644 --- a/tl.lua +++ b/tl.lua @@ -1823,12 +1823,6 @@ local table_types = { - - - - - - @@ -6928,8 +6922,6 @@ tl.type_check = function(ast, opts) copy.filename = t.filename copy.x = t.x copy.y = t.y - copy.yend = t.yend - copy.xend = t.xend if t.typename == "array" then assert(copy.typename == "array") diff --git a/tl.tl b/tl.tl index 59270631f..51140248b 100644 --- a/tl.tl +++ b/tl.tl @@ -1537,16 +1537,10 @@ local interface Type is Where where self.typename - typename: TypeName -- discriminator - typeid: integer -- unique identifier - - yend: integer - xend: integer - - inferred_at: Where - - -- Lua compatibilty - needs_compat: boolean + typename: TypeName -- discriminator + typeid: integer -- unique identifier + inferred_at: Where -- for error messages + needs_compat: boolean -- for Lua compatibilty end local record StringType @@ -6928,8 +6922,6 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string copy.filename = t.filename copy.x = t.x copy.y = t.y - copy.yend = t.yend - copy.xend = t.xend if t is ArrayType then assert(copy is ArrayType) From cbbb8f1c44d153d00d2e7553732c853408eeed59 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Sun, 14 Jan 2024 19:58:31 -0300 Subject: [PATCH 101/224] fix: resolves 'or' to the larger type --- spec/operator/is_spec.lua | 2 +- spec/operator/or_spec.lua | 27 +++++++++++++++++++++ tl.lua | 51 ++++++++++++++++++++++++++------------- tl.tl | 51 ++++++++++++++++++++++++++------------- 4 files changed, 96 insertions(+), 35 deletions(-) diff --git a/spec/operator/is_spec.lua b/spec/operator/is_spec.lua index 9e31199d2..aa5c31749 100644 --- a/spec/operator/is_spec.lua +++ b/spec/operator/is_spec.lua @@ -81,7 +81,7 @@ describe("flow analysis with is", function() local a: string | Foo local _b: Foo = a is string and makeFoo(a) or a ]], { - { msg = "cannot use operator 'or' for types Foo and string | Foo" }, + { msg = "got string | Foo, expected Foo" }, })) end) diff --git a/spec/operator/or_spec.lua b/spec/operator/or_spec.lua index eea54753e..ab48cbb70 100644 --- a/spec/operator/or_spec.lua +++ b/spec/operator/or_spec.lua @@ -135,4 +135,31 @@ describe("or", function() print(v * v) end ]])) + + it("resolves 'or' to the larger type", util.check_type_error([[ + local record A + where self.tag == "a" + b: B + tag: string + end + + local record B + where self.tag == "b" + tag: string + end + + local function wants_b(my_b: B) + print(my_b) + end + + local ab: A | B + + local b = ab is A and ab.b or ab -- ab.b may be nil, causing the value of type A to be returned via 'or' + + wants_b(b) + ]], { + { y = 20, x = 15, msg = "got A | B, expected B" }, + })) + + end) diff --git a/tl.lua b/tl.lua index 56fde1759..a8815d6f1 100644 --- a/tl.lua +++ b/tl.lua @@ -7688,6 +7688,13 @@ tl.type_check = function(ast, opts) local is_lua_table_type local resolve_tuple_and_nominal + local function to_structural(t) + if t.typename == "nominal" then + return resolve_nominal(t) + end + return t + end + local function unite(types, flatten_constants) if #types == 1 then return types[1] @@ -8826,6 +8833,14 @@ a.types[i], b.types[i]), } end end + local function resolve_function_type(func, i) + if func.typename == "poly" then + return func.types[i] + else + return func + end + end + local function fail_call(where, func, nargs, errs) if errs then @@ -8851,7 +8866,7 @@ a.types[i], b.types[i]), } error_at(where, "wrong number of arguments (given " .. nargs .. ", expects " .. table.concat(expects, " or ") .. ")") end - local f = func.typename == "poly" and func.types[1] or func + local f = resolve_function_type(func, 1) mark_invalid_typeargs(f) @@ -8886,7 +8901,7 @@ a.types[i], b.types[i]), } for pass = 1, passes do for i = 1, n do if (not tried) or not tried[i] then - local f = func.typename == "poly" and func.types[i] or func + local f = resolve_function_type(func, i) local fargs = f.args.tuple if f.is_method and not is_method then if args.tuple[1] and is_a(args.tuple[1], fargs[1]) then @@ -11616,22 +11631,24 @@ expand_type(node, values, elements) }) end t = u - elseif is_a(rb, ra) then - node.known = facts_or(node, node.e1.known, node.e2.known) - if expected then - local a_is = is_a(a, node.expected) - local b_is = is_a(b, node.expected) - if a_is and b_is then - t = resolve_typevars_at(node, node.expected) - elseif a_is then - t = resolve_tuple(b) - else - t = resolve_tuple(a) + else + local a_ge_b = is_a(rb, ra) + local b_ge_a = is_a(ra, rb) + if a_ge_b or b_ge_a then + node.known = facts_or(node, node.e1.known, node.e2.known) + if expected then + local a_is = is_a(a, node.expected) + local b_is = is_a(b, node.expected) + if a_is and b_is then + t = resolve_typevars_at(node, node.expected) + end end - else - t = resolve_tuple(a) + if not t then + local larger_type = b_ge_a and b or a + t = resolve_tuple(larger_type) + end + t = drop_constant_value(t) end - t = drop_constant_value(t) end if t then @@ -11889,7 +11906,7 @@ expand_type(node, values, elements) }) local t = after_literal(node) t.literal = node.conststr - local expected = node.expected + local expected = node.expected and to_structural(node.expected) if expected and expected.typename == "enum" and is_a(t, expected) then return node.expected end diff --git a/tl.tl b/tl.tl index 51140248b..f1f44173a 100644 --- a/tl.tl +++ b/tl.tl @@ -7688,6 +7688,13 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local is_lua_table_type: function(t: Type): boolean local resolve_tuple_and_nominal: function(t: Type): Type + local function to_structural(t: Type): Type + if t is NominalType then + return resolve_nominal(t) + end + return t + end + local function unite(types: {Type}, flatten_constants?: boolean): Type if #types == 1 then return types[1] @@ -8826,6 +8833,14 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end + local function resolve_function_type(func: FunctionType | PolyType, i: integer): FunctionType + if func is PolyType then + return func.types[i] + else + return func + end + end + local function fail_call(where: Where, func: FunctionType | PolyType, nargs: integer, errs: {Error}): TupleType if errs then -- report the errors from the first match @@ -8851,7 +8866,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string error_at(where, "wrong number of arguments (given " .. nargs .. ", expects " .. table.concat(expects, " or ") .. ")") end - local f = func is PolyType and func.types[1] or func + local f = resolve_function_type(func, 1) mark_invalid_typeargs(f) @@ -8886,7 +8901,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string for pass = 1, passes do for i = 1, n do if (not tried) or not tried[i] then - local f = func is PolyType and func.types[i] or func + local f = resolve_function_type(func, i) local fargs = f.args.tuple if f.is_method and not is_method then if args.tuple[1] and is_a(args.tuple[1], fargs[1]) then @@ -11616,22 +11631,24 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end t = u - elseif is_a(rb, ra) then - node.known = facts_or(node, node.e1.known, node.e2.known) - if expected then - local a_is = is_a(a, node.expected) - local b_is = is_a(b, node.expected) - if a_is and b_is then - t = resolve_typevars_at(node, node.expected) - elseif a_is then - t = resolve_tuple(b) - else - t = resolve_tuple(a) + else + local a_ge_b = is_a(rb, ra) + local b_ge_a = is_a(ra, rb) + if a_ge_b or b_ge_a then + node.known = facts_or(node, node.e1.known, node.e2.known) + if expected then + local a_is = is_a(a, node.expected) + local b_is = is_a(b, node.expected) + if a_is and b_is then + t = resolve_typevars_at(node, node.expected) + end end - else - t = resolve_tuple(a) + if not t then + local larger_type = b_ge_a and b or a + t = resolve_tuple(larger_type) + end + t = drop_constant_value(t) end - t = drop_constant_value(t) end if t then @@ -11889,7 +11906,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local t = after_literal(node) as StringType t.literal = node.conststr - local expected = node.expected + local expected = node.expected and to_structural(node.expected) if expected and expected is EnumType and is_a(t, expected) then return node.expected end From d79e47f388af12dfabd66214eae44dd000a8c8b8 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Sun, 14 Jan 2024 21:12:12 -0300 Subject: [PATCH 102/224] refactor: use to_structural to resolve nominals ...instead of the overkill of resolve_tuple_and_nominal everywhere. Only resolve tuples where needed (and we're still possibly overdoing it.) --- tl.lua | 280 +++++++++++++++++++++++++++---------------------------- tl.tl | 290 ++++++++++++++++++++++++++++----------------------------- 2 files changed, 281 insertions(+), 289 deletions(-) diff --git a/tl.lua b/tl.lua index a8815d6f1..58e445254 100644 --- a/tl.lua +++ b/tl.lua @@ -7686,9 +7686,9 @@ tl.type_check = function(ast, opts) end local is_lua_table_type - local resolve_tuple_and_nominal local function to_structural(t) + assert(not (t.typename == "tuple")) if t.typename == "nominal" then return resolve_nominal(t) end @@ -8569,7 +8569,7 @@ a.types[i], b.types[i]), } local function same_call_mt_in_all_union_entries(u) return same_in_all_union_entries(u, function(t) - t = resolve_tuple_and_nominal(t) + t = to_structural(t) if t.fields then local call_mt = t.meta_fields and t.meta_fields["__call"] if call_mt.typename == "function" then @@ -8589,14 +8589,14 @@ a.types[i], b.types[i]), } func = a_fn({ args = va_args({ UNKNOWN }), rets = va_args({ UNKNOWN }) }) end - func = resolve_tuple_and_nominal(func) + func = to_structural(func) if func.typename ~= "function" and func.typename ~= "poly" then if func.typename == "union" then local r = same_call_mt_in_all_union_entries(func) if r then table.insert(args.tuple, 1, func.types[1]) - return resolve_tuple_and_nominal(r), true + return to_structural(r), true end end @@ -8610,7 +8610,7 @@ a.types[i], b.types[i]), } if func.fields and func.meta_fields and func.meta_fields["__call"] then table.insert(args.tuple, 1, func) func = func.meta_fields["__call"] - func = resolve_tuple_and_nominal(func) + func = to_structural(func) is_method = true end end @@ -9018,7 +9018,7 @@ a.types[i], b.types[i]), } where_args[2] = node.e2 args.tuple[2] = orig_b end - return resolve_tuple_and_nominal((type_check_function_call(node, where_args, metamethod, args, nil, true))), meta_on_operator + return to_structural(resolve_tuple((type_check_function_call(node, where_args, metamethod, args, nil, true)))), meta_on_operator else return nil, nil end @@ -9029,7 +9029,7 @@ a.types[i], b.types[i]), } assert(type(rec) == "table") assert(type(key) == "string") - tbl = resolve_tuple_and_nominal(tbl) + tbl = to_structural(tbl) if tbl.typename == "string" or tbl.typename == "enum" then tbl = find_var_type("string") @@ -9271,15 +9271,6 @@ a.types[i], b.types[i]), } end_scope(node) end - resolve_tuple_and_nominal = function(t) - t = resolve_tuple(t) - if t.typename == "nominal" then - t = resolve_nominal(t) - end - assert(not (t.typename == "nominal")) - return t - end - local function flatten_tuple(vals) local vt = vals.tuple local n_vals = #vt @@ -9354,10 +9345,11 @@ a.types[i], b.types[i]), } end local function type_check_index(anode, bnode, a, b) - local orig_a = a - local orig_b = b - a = resolve_typedecl(resolve_tuple_and_nominal(a)) - b = resolve_tuple_and_nominal(b) + assert(not (a.typename == "tuple")) + assert(not (b.typename == "tuple")) + + local ra = resolve_typedecl(to_structural(a)) + local rb = to_structural(b) if lax and is_unknown(a) then return UNKNOWN @@ -9367,74 +9359,74 @@ a.types[i], b.types[i]), } local erra local errb - if a.typename == "tupletable" and is_a(b, INTEGER) then + if ra.typename == "tupletable" and is_a(rb, INTEGER) then if bnode.constnum then - if bnode.constnum >= 1 and bnode.constnum <= #a.types and bnode.constnum == math.floor(bnode.constnum) then - return a.types[bnode.constnum] + if bnode.constnum >= 1 and bnode.constnum <= #ra.types and bnode.constnum == math.floor(bnode.constnum) then + return ra.types[bnode.constnum] end - errm, erra = "index " .. tostring(bnode.constnum) .. " out of range for tuple %s", a + errm, erra = "index " .. tostring(bnode.constnum) .. " out of range for tuple %s", ra else - local array_type = arraytype_from_tuple(bnode, a) + local array_type = arraytype_from_tuple(bnode, ra) if array_type then return array_type.elements end errm = "cannot index this tuple with a variable because it would produce a union type that cannot be discriminated at runtime" end - elseif a.elements and is_a(b, INTEGER) then - return a.elements - elseif a.typename == "emptytable" then - if a.keys == nil then - a.keys = infer_at(anode, resolve_tuple(orig_b)) + elseif ra.elements and is_a(rb, INTEGER) then + return ra.elements + elseif ra.typename == "emptytable" then + if ra.keys == nil then + ra.keys = infer_at(anode, b) end - if is_a(orig_b, a.keys) then + if is_a(b, ra.keys) then return type_at(anode, a_type("unresolved_emptytable_value", { - emptytable_type = a, + emptytable_type = ra, })) end errm, erra, errb = "inconsistent index type: got %s, expected %s (type of keys inferred at " .. - a.keys.inferred_at.filename .. ":" .. - a.keys.inferred_at.y .. ":" .. - a.keys.inferred_at.x .. ": )", orig_b, a.keys - elseif a.typename == "map" then - if is_a(orig_b, a.keys) then - return a.values + ra.keys.inferred_at.filename .. ":" .. + ra.keys.inferred_at.y .. ":" .. + ra.keys.inferred_at.x .. ": )", b, ra.keys + elseif ra.typename == "map" then + if is_a(b, ra.keys) then + return ra.values end - errm, erra, errb = "wrong index type: got %s, expected %s", orig_b, a.keys - elseif b.typename == "string" and b.literal then - local t, e = match_record_key(orig_a, anode, b.literal) + errm, erra, errb = "wrong index type: got %s, expected %s", b, ra.keys + elseif rb.typename == "string" and rb.literal then + local t, e = match_record_key(a, anode, rb.literal) if t then return t end - errm, erra = e, orig_a - elseif a.fields then - if b.typename == "enum" then - local field_names = sorted_keys(b.enumset) + errm, erra = e, a + elseif ra.fields then + if rb.typename == "enum" then + local field_names = sorted_keys(rb.enumset) for _, k in ipairs(field_names) do - if not a.fields[k] then - errm, erra = "enum value '" .. k .. "' is not a field in %s", a + if not ra.fields[k] then + errm, erra = "enum value '" .. k .. "' is not a field in %s", ra break end end if not errm then - return match_all_record_field_names(bnode, a, field_names, + return match_all_record_field_names(bnode, ra, field_names, "cannot index, not all enum values map to record fields of the same type") end - elseif is_a(b, STRING) then - errm, erra = "cannot index object of type %s with a string, consider using an enum", orig_a + elseif is_a(rb, STRING) then + errm, erra = "cannot index object of type %s with a string, consider using an enum", a else - errm, erra, errb = "cannot index object of type %s with %s", orig_a, orig_b + errm, erra, errb = "cannot index object of type %s with %s", a, b end else - errm, erra, errb = "cannot index object of type %s with %s", orig_a, orig_b + errm, erra, errb = "cannot index object of type %s with %s", a, b end - local meta_t = check_metamethod(anode, "__index", a, orig_b, orig_a, orig_b) + local meta_t = check_metamethod(anode, "__index", ra, b, a, b) if meta_t then return meta_t end @@ -9729,7 +9721,7 @@ a.types[i], b.types[i]), } end local function resolve_if_union(t) - local rt = resolve_tuple_and_nominal(t) + local rt = to_structural(t) if rt.typename == "union" then return rt end @@ -9999,7 +9991,7 @@ a.types[i], b.types[i]), } if not b.tuple[1] then return invalid_at(node, "pairs requires an argument") end - local t = resolve_tuple_and_nominal(b.tuple[1]) + local t = to_structural(b.tuple[1]) if t.elements then add_warning("hint", node, "hint: applying pairs on an array: did you intend to apply ipairs?") end @@ -10025,7 +10017,7 @@ a.types[i], b.types[i]), } return invalid_at(node, "ipairs requires an argument") end local orig_t = b.tuple[1] - local t = resolve_tuple_and_nominal(orig_t) + local t = to_structural(orig_t) if t.typename == "tupletable" then local arr_type = arraytype_from_tuple(node.e2, t) @@ -10366,7 +10358,7 @@ expand_type(node, values, elements) }) local decltype = node.decltuple and node.decltuple.tuple[i] if decltype then - if resolve_tuple_and_nominal(decltype) == INVALID then + if to_structural(decltype) == INVALID then decltype = INVALID end @@ -10390,7 +10382,7 @@ expand_type(node, values, elements) }) end if var.attribute == "total" then - local rd = decltype and resolve_tuple_and_nominal(decltype) + local rd = decltype and to_structural(decltype) if rd and (rd.typename ~= "map" and rd.typename ~= "record") then error_at(var, "attribute only applies to maps and records") ok = false @@ -10398,7 +10390,7 @@ expand_type(node, values, elements) }) error_at(var, "variable declared does not declare an initialization value") ok = false elseif not (node.exps[i] and node.exps[i].attribute == "total") then - local ri = resolve_tuple_and_nominal(infertype) + local ri = to_structural(infertype) if not (ri.typename == "map" or ri.typename == "record") then error_at(var, "attribute only applies to maps and records") ok = false @@ -10478,7 +10470,7 @@ expand_type(node, values, elements) }) end local function total_map_check(t, seen_keys) - local k = resolve_tuple_and_nominal(t.keys) + local k = to_structural(t.keys) local is_total = true local missing if k.typename == "enum" then @@ -10514,7 +10506,7 @@ expand_type(node, values, elements) }) return nil end - local var = resolve_tuple_and_nominal(vartype) + local var = to_structural(vartype) if var.typename == "typedecl" or var.typename == "typealias" then error_at(where, "cannot reassign a type") return nil @@ -10527,7 +10519,7 @@ expand_type(node, values, elements) }) assert_is_a(where, valtype, vartype, "in assignment") - local val = resolve_tuple_and_nominal(valtype) + local val = to_structural(valtype) return var, val end @@ -10643,7 +10635,7 @@ expand_type(node, values, elements) }) if ok and infertype then local where = node.exps[i] or node.exps - local rt = resolve_tuple_and_nominal(t) + local rt = to_structural(t) if (not (rt.typename == "enum")) and ((not (t.typename == "nominal")) or (rt.typename == "union")) and not same_type(t, infertype) then @@ -10876,9 +10868,9 @@ expand_type(node, values, elements) }) before_statements = function(node, children) widen_all_unions(node) begin_scope(node) - local from_t = resolve_tuple_and_nominal(children[2]) - local to_t = resolve_tuple_and_nominal(children[3]) - local step_t = children[4] and resolve_tuple_and_nominal(children[4]) + local from_t = to_structural(resolve_tuple(children[2])) + local to_t = to_structural(resolve_tuple(children[3])) + local step_t = children[4] and to_structural(children[4]) local t = (from_t.typename == "integer" and to_t.typename == "integer" and (not step_t or step_t.typename == "integer")) and @@ -10908,7 +10900,7 @@ expand_type(node, values, elements) }) if not expected then expected = infer_at(node, got) - module_type = drop_constant_value(resolve_tuple_and_nominal(expected)) + module_type = drop_constant_value(to_structural(resolve_tuple(expected))) st[2]["@return"] = { t = expected } end local expected_t = expected.tuple @@ -10967,10 +10959,10 @@ expand_type(node, values, elements) }) ["literal_table"] = { before = function(node) if node.expected then - local decltype = resolve_tuple_and_nominal(node.expected) + local decltype = to_structural(node.expected) if decltype.typename == "typevar" and decltype.constraint then - decltype = resolve_typedecl(resolve_tuple_and_nominal(decltype.constraint)) + decltype = resolve_typedecl(to_structural(decltype.constraint)) end if decltype.typename == "tupletable" then @@ -11009,12 +11001,12 @@ expand_type(node, values, elements) }) return infer_table_literal(node, children) end - local decltype = resolve_tuple_and_nominal(node.expected) + local decltype = to_structural(node.expected) local constraint if decltype.typename == "typevar" and decltype.constraint then constraint = resolve_typedecl(decltype.constraint) - decltype = resolve_tuple_and_nominal(constraint) + decltype = to_structural(constraint) end if decltype.typename == "union" then @@ -11022,7 +11014,7 @@ expand_type(node, values, elements) }) local single_table_rt for _, t in ipairs(decltype.types) do - local rt = resolve_tuple_and_nominal(t) + local rt = to_structural(t) if is_lua_table_type(rt) then if single_table_type then @@ -11119,12 +11111,12 @@ expand_type(node, values, elements) }) end if decltype.typename == "record" then - local rt = resolve_tuple_and_nominal(t) + local rt = to_structural(t) if rt.typename == "record" then rt.is_total, rt.missing = total_record_check(decltype, seen_keys) end elseif decltype.typename == "map" then - local rt = resolve_tuple_and_nominal(t) + local rt = to_structural(t) if rt.typename == "map" then rt.is_total, rt.missing = total_map_check(decltype, seen_keys) end @@ -11275,7 +11267,7 @@ expand_type(node, values, elements) }) begin_scope(node) end, before_arguments = function(_node, children) - local rtype = resolve_tuple_and_nominal(resolve_typedecl(children[1])) + local rtype = to_structural(resolve_typedecl(children[1])) if rtype.fields and rtype.typeargs then @@ -11293,7 +11285,7 @@ expand_type(node, values, elements) }) local rets = children[4] assert(rets.typename == "tuple") - local rtype = resolve_tuple_and_nominal(resolve_typedecl(children[1])) + local rtype = to_structural(resolve_typedecl(children[1])) if lax and rtype.typename == "unknown" then return @@ -11333,7 +11325,7 @@ expand_type(node, values, elements) }) local open_k = owner_name .. "." .. node.name.tk local rfieldtype = rtype.fields[node.name.tk] if rfieldtype then - rfieldtype = resolve_tuple_and_nominal(rfieldtype) + rfieldtype = to_structural(rfieldtype) if open_v and open_v.implemented and open_v.implemented[open_k] then redeclaration_warning(node) @@ -11493,35 +11485,47 @@ expand_type(node, values, elements) }) after = function(node, children) end_scope() - local a = children[1] - local b = children[3] - local orig_a = a - local orig_b = b - local ra = a and resolve_tuple_and_nominal(a) - local rb = b and resolve_tuple_and_nominal(b) + local ga = children[1] + local gb = children[3] + + + local ua = resolve_tuple(ga) + local ub + - local expected = node.expected and resolve_tuple_and_nominal(node.expected) + local ra = to_structural(ua) + local rb if ra.typename == "circular_require" or (ra.typename == "typedecl" and ra.def and ra.def.typename == "circular_require") then return invalid_at(node, "cannot dereference a type from a circular require") end if node.op.op == "@funcall" then - if lax and is_unknown(a) then + if lax and is_unknown(ua) then if node.e1.op and node.e1.op.op == ":" and node.e1.e1.kind == "variable" then add_unknown_dot(node, node.e1.e1.tk .. "." .. node.e1.e2.tk) end end - local t = type_check_funcall(node, a, b) + local t = type_check_funcall(node, ua, gb) return t + + elseif node.op.op == "as" then + return gb end + local expected = node.expected and to_structural(resolve_tuple(node.expected)) + ensure_not_abstract(node.e1, ra) if ra.typename == "typedecl" and ra.def.typename == "record" then ra = ra.def end - if rb then + + + + if gb then + ub = resolve_tuple(gb) + rb = to_structural(ub) ensure_not_abstract(node.e2, rb) if rb.typename == "typedecl" and rb.def.typename == "record" then rb = rb.def @@ -11529,7 +11533,7 @@ expand_type(node, values, elements) }) end if node.op.op == "." then - node.receiver = a + node.receiver = ua assert(node.e2.kind == "identifier") local bnode = { @@ -11539,7 +11543,7 @@ expand_type(node, values, elements) }) kind = "string", } local btype = type_at(node.e2, a_type("string", { literal = node.e2.tk })) - local t = type_check_index(node.e1, bnode, orig_a, btype) + local t = type_check_index(node.e1, bnode, ua, btype) if t.needs_compat and opts.gen_compat ~= "off" then @@ -11555,11 +11559,7 @@ expand_type(node, values, elements) }) end if node.op.op == "@index" then - return type_check_index(node.e1, node.e2, a, b) - end - - if node.op.op == "as" then - return b + return type_check_index(node.e1, node.e2, ua, ub) end if node.op.op == "is" then @@ -11569,8 +11569,8 @@ expand_type(node, values, elements) }) if ra.typename == "typedecl" then error_at(node, "can only use 'is' on variables, not types") elseif node.e1.kind == "variable" then - check_metamethod(node, "__is", ra, resolve_typedecl(rb), orig_a, orig_b) - node.known = IsFact({ var = node.e1.tk, typ = b, where = node }) + check_metamethod(node, "__is", ra, resolve_typedecl(rb), ua, ub) + node.known = IsFact({ var = node.e1.tk, typ = ub, where = node }) else error_at(node, "can only use 'is' on variables") end @@ -11578,20 +11578,20 @@ expand_type(node, values, elements) }) end if node.op.op == ":" then - node.receiver = a + node.receiver = ua - if lax and (is_unknown(a) or a.typename == "typevar") then + if lax and (is_unknown(ua) or ua.typename == "typevar") then if node.e1.kind == "variable" then add_unknown_dot(node.e1, node.e1.tk .. "." .. node.e2.tk) end return UNKNOWN end - local t, e = match_record_key(a, node.e1, node.e2.conststr or node.e2.tk) + local t, e = match_record_key(ra, node.e1, node.e2.conststr or node.e2.tk) if not t then - return invalid_at(node.e2, e, resolve_tuple(orig_a)) + return invalid_at(node.e2, e, ua) end return t @@ -11604,18 +11604,18 @@ expand_type(node, values, elements) }) if node.op.op == "and" then node.known = facts_and(node, node.e1.known, node.e2.known) - return discard_tuple(node, b, b) + return discard_tuple(node, ub, gb) end if node.op.op == "or" then local t - if b.typename == "nil" then + if ub.typename == "nil" then node.known = nil - t = a + t = ua - elseif is_lua_table_type(ra) and b.typename == "emptytable" then + elseif is_lua_table_type(ra) and rb.typename == "emptytable" then node.known = nil - t = a + t = ua elseif ((ra.typename == "enum" and rb.typename == "string" and is_a(rb, ra)) or (ra.typename == "string" and rb.typename == "enum" and is_a(ra, rb))) then @@ -11637,22 +11637,22 @@ expand_type(node, values, elements) }) if a_ge_b or b_ge_a then node.known = facts_or(node, node.e1.known, node.e2.known) if expected then - local a_is = is_a(a, node.expected) - local b_is = is_a(b, node.expected) + local a_is = is_a(ua, expected) + local b_is = is_a(ub, expected) if a_is and b_is then - t = resolve_typevars_at(node, node.expected) + t = resolve_typevars_at(node, expected) end end if not t then - local larger_type = b_ge_a and b or a - t = resolve_tuple(larger_type) + local larger_type = b_ge_a and ub or ua + t = larger_type end t = drop_constant_value(t) end end if t then - return discard_tuple(node, t, b) + return discard_tuple(node, t, gb) end end @@ -11664,39 +11664,38 @@ expand_type(node, values, elements) }) if ra.typename == "enum" and rb.typename == "string" then if not (rb.literal and ra.enumset[rb.literal]) then - return invalid_at(node, "%s is not a member of %s", b, a) + return invalid_at(node, "%s is not a member of %s", ub, ua) end elseif ra.typename == "tupletable" and rb.typename == "tupletable" and #ra.types ~= #rb.types then return invalid_at(node, "tuples are not the same size") - elseif is_a(b, a) or a.typename == "typevar" then + elseif is_a(ub, ua) or ua.typename == "typevar" then if node.op.op == "==" and node.e1.kind == "variable" then - node.known = EqFact({ var = node.e1.tk, typ = b, where = node }) + node.known = EqFact({ var = node.e1.tk, typ = ub, where = node }) end - elseif is_a(a, b) or b.typename == "typevar" then + elseif is_a(ua, ub) or ub.typename == "typevar" then if node.op.op == "==" and node.e2.kind == "variable" then - node.known = EqFact({ var = node.e2.tk, typ = a, where = node }) + node.known = EqFact({ var = node.e2.tk, typ = ua, where = node }) end - elseif lax and (is_unknown(a) or is_unknown(b)) then + elseif lax and (is_unknown(ua) or is_unknown(ub)) then return UNKNOWN else - return invalid_at(node, "types are not comparable for equality: %s and %s", a, b) + return invalid_at(node, "types are not comparable for equality: %s and %s", ua, ub) end return BOOLEAN end if node.op.arity == 1 and unop_types[node.op.op] then - a = ra - if a.typename == "union" then - a = unite(a.types, true) + if ra.typename == "union" then + ra = unite(ra.types, true) end local types_op = unop_types[node.op.op] - local t = types_op[a.typename] + local t = types_op[ra.typename] if not t then - t = find_in_interface_list(a, function(ty) + t = find_in_interface_list(ra, function(ty) return types_op[ty.typename] end) end @@ -11705,16 +11704,16 @@ expand_type(node, values, elements) }) if not t then local mt_name = unop_to_metamethod[node.op.op] if mt_name then - t, meta_on_operator = check_metamethod(node, mt_name, a, nil, orig_a, nil) + t, meta_on_operator = check_metamethod(node, mt_name, ra, nil, ua, nil) end if not t then - error_at(node, "cannot use operator '" .. node.op.op:gsub("%%", "%%%%") .. "' on type %s", resolve_tuple(orig_a)) + error_at(node, "cannot use operator '" .. node.op.op:gsub("%%", "%%%%") .. "' on type %s", ua) t = INVALID end end - if a.typename == "map" then - if a.keys.typename == "number" or a.keys.typename == "integer" then + if ra.typename == "map" then + if ra.keys.typename == "number" or ra.keys.typename == "integer" then add_warning("hint", node, "using the '#' operator on a map with numeric key type may produce unexpected results") else error_at(node, "using the '#' operator on this map will always return 0") @@ -11743,31 +11742,28 @@ expand_type(node, values, elements) }) node.known = facts_or(node, node.e1.known, node.e2.known) end - a = ra - b = rb - - if a.typename == "union" then - a = unite(a.types, true) + if ra.typename == "union" then + ra = unite(ra.types, true) end - if b.typename == "union" then - b = unite(b.types, true) + if rb.typename == "union" then + rb = unite(rb.types, true) end local types_op = binop_types[node.op.op] - local t = types_op[a.typename] and types_op[a.typename][b.typename] + local t = types_op[ra.typename] and types_op[ra.typename][rb.typename] local meta_on_operator if not t then local mt_name = binop_to_metamethod[node.op.op] if mt_name then - t, meta_on_operator = check_metamethod(node, mt_name, a, b, orig_a, orig_b) + t, meta_on_operator = check_metamethod(node, mt_name, ra, rb, ua, ub) end if not t then - error_at(node, "cannot use operator '" .. node.op.op:gsub("%%", "%%%%") .. "' for types %s and %s", resolve_tuple(orig_a), resolve_tuple(orig_b)) + error_at(node, "cannot use operator '" .. node.op.op:gsub("%%", "%%%%") .. "' for types %s and %s", ua, ub) t = INVALID if node.op.op == "or" then - local u = unite({ orig_a, orig_b }) + local u = unite({ ua, ub }) if u.typename == "union" and is_valid_union(u) then add_warning("hint", node, "if a union type was intended, consider declaring it explicitly") end @@ -11775,11 +11771,11 @@ expand_type(node, values, elements) }) end end - if orig_a.typename == "nominal" and orig_b.typename == "nominal" and not meta_on_operator then - if is_a(orig_a, orig_b) then - t = resolve_tuple(orig_a) + if ua.typename == "nominal" and ub.typename == "nominal" and not meta_on_operator then + if is_a(ua, ub) then + t = ua else - error_at(node, "cannot use operator '" .. node.op.op:gsub("%%", "%%%%") .. "' for distinct nominal types %s and %s", resolve_tuple(orig_a), resolve_tuple(orig_b)) + error_at(node, "cannot use operator '" .. node.op.op:gsub("%%", "%%%%") .. "' for distinct nominal types %s and %s", ua, ub) end end diff --git a/tl.tl b/tl.tl index f1f44173a..06d8626a8 100644 --- a/tl.tl +++ b/tl.tl @@ -1581,7 +1581,7 @@ local record UnresolvedType where self.typename == "unresolved" labels: {string:{Node}} - nominals: {string:{Type}} + nominals: {string:{NominalType}} global_types: {string:boolean} narrows: {string:boolean} end @@ -1607,8 +1607,8 @@ local record NominalType names: {string} typevals: {Type} - found: Type -- type is found but typeargs are not resolved - resolved: Type -- type is found and typeargs are resolved + found: Type -- type is found but typeargs are not resolved + resolved: Type -- type is found and typeargs are resolved end local interface ArrayLikeType @@ -7504,7 +7504,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return NONE end - local resolve_nominal: function(t: Type): Type + local resolve_nominal: function(t: NominalType): Type local resolve_typealias: function(t: Type): Type, Variable do local function match_typevals(t: NominalType, def: RecordLikeType | FunctionType): Type @@ -7630,7 +7630,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return false end - local function fail_nominals(t1: Type, t2: Type): boolean, {Error} + local function fail_nominals(t1: NominalType, t2: NominalType): boolean, {Error} local t1name = show_type(t1) local t2name = show_type(t2) if t1name == t2name then @@ -7686,9 +7686,9 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end local is_lua_table_type: function(t: Type): boolean - local resolve_tuple_and_nominal: function(t: Type): Type local function to_structural(t: Type): Type + assert(not t is TupleType) if t is NominalType then return resolve_nominal(t) end @@ -8569,7 +8569,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local function same_call_mt_in_all_union_entries(u: UnionType): Type return same_in_all_union_entries(u, function(t: Type): (Type, Type) - t = resolve_tuple_and_nominal(t) + t = to_structural(t) if t is RecordLikeType then local call_mt = t.meta_fields and t.meta_fields["__call"] if call_mt is FunctionType then @@ -8589,14 +8589,14 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string func = a_fn { args = va_args { UNKNOWN }, rets = va_args { UNKNOWN } } end -- unwrap if tuple, resolve if nominal - func = resolve_tuple_and_nominal(func) + func = to_structural(func) if func.typename ~= "function" and func.typename ~= "poly" then -- resolve if union if func is UnionType then local r = same_call_mt_in_all_union_entries(func) if r then table.insert(args.tuple, 1, func.types[1]) -- FIXME: is this right? - return resolve_tuple_and_nominal(r), true + return to_structural(r), true end end -- resolve if prototype @@ -8610,7 +8610,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if func is RecordLikeType and func.meta_fields and func.meta_fields["__call"] then table.insert(args.tuple, 1, func) func = func.meta_fields["__call"] - func = resolve_tuple_and_nominal(func) + func = to_structural(func) is_method = true end end @@ -9018,7 +9018,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string where_args[2] = node.e2 args.tuple[2] = orig_b end - return resolve_tuple_and_nominal((type_check_function_call(node, where_args, metamethod, args, nil, true))), meta_on_operator + return to_structural(resolve_tuple((type_check_function_call(node, where_args, metamethod, args, nil, true)))), meta_on_operator else return nil, nil end @@ -9029,7 +9029,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string assert(type(rec) == "table") assert(type(key) == "string") - tbl = resolve_tuple_and_nominal(tbl) + tbl = to_structural(tbl) if tbl is StringType or tbl is EnumType then tbl = find_var_type("string") -- simulate string metatable @@ -9271,15 +9271,6 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end_scope(node) end - resolve_tuple_and_nominal = function(t: Type): Type - t = resolve_tuple(t) - if t is NominalType then - t = resolve_nominal(t) - end - assert(not t is NominalType) - return t - end - local function flatten_tuple(vals: TupleType): TupleType local vt = vals.tuple local n_vals = #vt @@ -9354,10 +9345,11 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end local function type_check_index(anode: Node, bnode: Node, a: Type, b: Type): Type - local orig_a = a - local orig_b = b - a = resolve_typedecl(resolve_tuple_and_nominal(a)) - b = resolve_tuple_and_nominal(b) + assert(not a is TupleType) + assert(not b is TupleType) + + local ra = resolve_typedecl(to_structural(a)) + local rb = to_structural(b) if lax and is_unknown(a) then return UNKNOWN @@ -9367,74 +9359,74 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local erra: Type local errb: Type - if a is TupleTableType and is_a(b, INTEGER) then + if ra is TupleTableType and is_a(rb, INTEGER) then if bnode.constnum then - if bnode.constnum >= 1 and bnode.constnum <= #a.types and bnode.constnum == math.floor(bnode.constnum) then - return a.types[bnode.constnum as integer] + if bnode.constnum >= 1 and bnode.constnum <= #ra.types and bnode.constnum == math.floor(bnode.constnum) then + return ra.types[bnode.constnum as integer] end - errm, erra = "index " .. tostring(bnode.constnum) .. " out of range for tuple %s", a + errm, erra = "index " .. tostring(bnode.constnum) .. " out of range for tuple %s", ra else - local array_type = arraytype_from_tuple(bnode, a) + local array_type = arraytype_from_tuple(bnode, ra) if array_type then return array_type.elements end errm = "cannot index this tuple with a variable because it would produce a union type that cannot be discriminated at runtime" end - elseif a is ArrayLikeType and is_a(b, INTEGER) then - return a.elements - elseif a is EmptyTableType then - if a.keys == nil then - a.keys = infer_at(anode, resolve_tuple(orig_b)) + elseif ra is ArrayLikeType and is_a(rb, INTEGER) then + return ra.elements + elseif ra is EmptyTableType then + if ra.keys == nil then + ra.keys = infer_at(anode, b) end - if is_a(orig_b, a.keys) then + if is_a(b, ra.keys) then return type_at(anode, a_type("unresolved_emptytable_value", { - emptytable_type = a + emptytable_type = ra } as UnresolvedEmptyTableValueType)) end errm, erra, errb = "inconsistent index type: got %s, expected %s (type of keys inferred at " - .. a.keys.inferred_at.filename .. ":" - .. a.keys.inferred_at.y .. ":" - .. a.keys.inferred_at.x .. ": )", orig_b, a.keys - elseif a is MapType then - if is_a(orig_b, a.keys) then - return a.values + .. ra.keys.inferred_at.filename .. ":" + .. ra.keys.inferred_at.y .. ":" + .. ra.keys.inferred_at.x .. ": )", b, ra.keys + elseif ra is MapType then + if is_a(b, ra.keys) then + return ra.values end - errm, erra, errb = "wrong index type: got %s, expected %s", orig_b, a.keys - elseif b is StringType and b.literal then - local t, e = match_record_key(orig_a, anode, b.literal) + errm, erra, errb = "wrong index type: got %s, expected %s", b, ra.keys + elseif rb is StringType and rb.literal then + local t, e = match_record_key(a, anode, rb.literal) if t then return t end - errm, erra = e, orig_a - elseif a is RecordLikeType then - if b is EnumType then - local field_names: {string} = sorted_keys(b.enumset) + errm, erra = e, a + elseif ra is RecordLikeType then + if rb is EnumType then + local field_names: {string} = sorted_keys(rb.enumset) for _, k in ipairs(field_names) do - if not a.fields[k] then - errm, erra = "enum value '" .. k .. "' is not a field in %s", a + if not ra.fields[k] then + errm, erra = "enum value '" .. k .. "' is not a field in %s", ra break end end if not errm then - return match_all_record_field_names(bnode, a, field_names, + return match_all_record_field_names(bnode, ra, field_names, "cannot index, not all enum values map to record fields of the same type") end - elseif is_a(b, STRING) then - errm, erra = "cannot index object of type %s with a string, consider using an enum", orig_a + elseif is_a(rb, STRING) then + errm, erra = "cannot index object of type %s with a string, consider using an enum", a else - errm, erra, errb = "cannot index object of type %s with %s", orig_a, orig_b + errm, erra, errb = "cannot index object of type %s with %s", a, b end else - errm, erra, errb = "cannot index object of type %s with %s", orig_a, orig_b + errm, erra, errb = "cannot index object of type %s with %s", a, b end - local meta_t = check_metamethod(anode, "__index", a, orig_b, orig_a, orig_b) + local meta_t = check_metamethod(anode, "__index", ra, b, a, b) if meta_t then return meta_t end @@ -9729,7 +9721,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end local function resolve_if_union(t: Type): Type - local rt = resolve_tuple_and_nominal(t) + local rt = to_structural(t) if rt is UnionType then return rt end @@ -9999,7 +9991,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if not b.tuple[1] then return invalid_at(node, "pairs requires an argument") end - local t = resolve_tuple_and_nominal(b.tuple[1]) + local t = to_structural(b.tuple[1]) if t is ArrayLikeType then add_warning("hint", node, "hint: applying pairs on an array: did you intend to apply ipairs?") end @@ -10025,7 +10017,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return invalid_at(node, "ipairs requires an argument") end local orig_t = b.tuple[1] - local t = resolve_tuple_and_nominal(orig_t) + local t = to_structural(orig_t) if t is TupleTableType then local arr_type = arraytype_from_tuple(node.e2, t) @@ -10366,7 +10358,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local decltype = node.decltuple and node.decltuple.tuple[i] if decltype then - if resolve_tuple_and_nominal(decltype) == INVALID then + if to_structural(decltype) == INVALID then decltype = INVALID end @@ -10390,7 +10382,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end if var.attribute == "total" then - local rd = decltype and resolve_tuple_and_nominal(decltype) + local rd = decltype and to_structural(decltype) if rd and (rd.typename ~= "map" and rd.typename ~= "record") then error_at(var, "attribute only applies to maps and records") ok = false @@ -10398,7 +10390,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string error_at(var, "variable declared does not declare an initialization value") ok = false elseif not (node.exps[i] and node.exps[i].attribute == "total") then - local ri = resolve_tuple_and_nominal(infertype) + local ri = to_structural(infertype) if not (ri is MapType or ri is RecordType) then error_at(var, "attribute only applies to maps and records") ok = false @@ -10478,7 +10470,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end local function total_map_check(t: MapType, seen_keys: {CheckableKey:Where}): boolean, {string} - local k = resolve_tuple_and_nominal(t.keys) + local k = to_structural(t.keys) local is_total = true local missing: {string} if k is EnumType then @@ -10514,7 +10506,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return nil end - local var = resolve_tuple_and_nominal(vartype) + local var = to_structural(vartype) if var is TypeDeclType or var is TypeAliasType then error_at(where, "cannot reassign a type") return nil @@ -10527,7 +10519,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string assert_is_a(where, valtype, vartype, "in assignment") - local val = resolve_tuple_and_nominal(valtype) + local val = to_structural(valtype) return var, val end @@ -10643,7 +10635,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if ok and infertype then local where = node.exps[i] or node.exps - local rt = resolve_tuple_and_nominal(t) + local rt = to_structural(t) if (not rt is EnumType) and ((not t is NominalType) or (rt is UnionType)) and not same_type(t, infertype) @@ -10876,9 +10868,9 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string before_statements = function(node: Node, children: {Type}) widen_all_unions(node) begin_scope(node) - local from_t = resolve_tuple_and_nominal(children[2]) - local to_t = resolve_tuple_and_nominal(children[3]) - local step_t = children[4] and resolve_tuple_and_nominal(children[4]) + local from_t = to_structural(resolve_tuple(children[2])) + local to_t = to_structural(resolve_tuple(children[3])) + local step_t = children[4] and to_structural(children[4]) local t = (from_t.typename == "integer" and to_t.typename == "integer" and (not step_t or step_t.typename == "integer")) @@ -10908,7 +10900,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if not expected then -- if at the toplevel expected = infer_at(node, got) - module_type = drop_constant_value(resolve_tuple_and_nominal(expected)) + module_type = drop_constant_value(to_structural(resolve_tuple(expected))) st[2]["@return"] = { t = expected } end local expected_t = expected.tuple @@ -10967,10 +10959,10 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string ["literal_table"] = { before = function(node: Node) if node.expected then - local decltype = resolve_tuple_and_nominal(node.expected) + local decltype = to_structural(node.expected) if decltype is TypeVarType and decltype.constraint then - decltype = resolve_typedecl(resolve_tuple_and_nominal(decltype.constraint)) + decltype = resolve_typedecl(to_structural(decltype.constraint)) end if decltype is TupleTableType then @@ -11009,12 +11001,12 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return infer_table_literal(node, children) end - local decltype = resolve_tuple_and_nominal(node.expected) + local decltype = to_structural(node.expected) local constraint: Type if decltype is TypeVarType and decltype.constraint then constraint = resolve_typedecl(decltype.constraint) - decltype = resolve_tuple_and_nominal(constraint) + decltype = to_structural(constraint) end if decltype is UnionType then @@ -11022,7 +11014,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local single_table_rt: Type for _, t in ipairs(decltype.types) do - local rt = resolve_tuple_and_nominal(t) + local rt = to_structural(t) if is_lua_table_type(rt) then if single_table_type then -- multiple table types in union, give up @@ -11119,12 +11111,12 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end if decltype is RecordType then - local rt = resolve_tuple_and_nominal(t) + local rt = to_structural(t) if rt is RecordType then rt.is_total, rt.missing = total_record_check(decltype, seen_keys) end elseif decltype is MapType then - local rt = resolve_tuple_and_nominal(t) + local rt = to_structural(t) if rt is MapType then rt.is_total, rt.missing = total_map_check(decltype, seen_keys) end @@ -11275,7 +11267,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string begin_scope(node) end, before_arguments = function(_node: Node, children: {Type}) - local rtype = resolve_tuple_and_nominal(resolve_typedecl(children[1])) + local rtype = to_structural(resolve_typedecl(children[1])) -- add type arguments from the record implicitly if rtype is RecordLikeType and rtype.typeargs then @@ -11293,7 +11285,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local rets = children[4] assert(rets is TupleType) - local rtype = resolve_tuple_and_nominal(resolve_typedecl(children[1])) + local rtype = to_structural(resolve_typedecl(children[1])) if lax and rtype.typename == "unknown" then return @@ -11333,7 +11325,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local open_k = owner_name .. "." .. node.name.tk local rfieldtype = rtype.fields[node.name.tk] if rfieldtype then - rfieldtype = resolve_tuple_and_nominal(rfieldtype) + rfieldtype = to_structural(rfieldtype) if open_v and open_v.implemented and open_v.implemented[open_k] then redeclaration_warning(node) @@ -11493,35 +11485,47 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string after = function(node: Node, children: {Type}): Type end_scope() - local a: Type = children[1] - local b: Type = children[3] + -- given a and b: may be TupleType + local ga: Type = children[1] + local gb: Type = children[3] - local orig_a = a - local orig_b = b - local ra = a and resolve_tuple_and_nominal(a) - local rb = b and resolve_tuple_and_nominal(b) + -- unary a and b: not TupleType + local ua = resolve_tuple(ga) + local ub: Type - local expected = node.expected and resolve_tuple_and_nominal(node.expected) + -- resolved a and b: not NominalType + local ra: Type = to_structural(ua) + local rb: Type if ra.typename == "circular_require" or (ra is TypeDeclType and ra.def and ra.def.typename == "circular_require") then return invalid_at(node, "cannot dereference a type from a circular require") end if node.op.op == "@funcall" then - if lax and is_unknown(a) then + if lax and is_unknown(ua) then if node.e1.op and node.e1.op.op == ":" and node.e1.e1.kind == "variable" then add_unknown_dot(node, node.e1.e1.tk .. "." .. node.e1.e2.tk) end end - local t = type_check_funcall(node, a, b) + local t = type_check_funcall(node, ua, gb) return t + + elseif node.op.op == "as" then + return gb end + local expected = node.expected and to_structural(resolve_tuple(node.expected)) + ensure_not_abstract(node.e1, ra) if ra is TypeDeclType and ra.def.typename == "record" then ra = ra.def end - if rb then + + -- "@funcall" and "as" are the only operators that use tuples, and always in the b position; + -- after they are handled above, we can resolve b's tuple and only use that instead. + if gb then + ub = resolve_tuple(gb) + rb = to_structural(ub) ensure_not_abstract(node.e2, rb) if rb is TypeDeclType and rb.def.typename == "record" then rb = rb.def @@ -11529,7 +11533,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end if node.op.op == "." then - node.receiver = a + node.receiver = ua assert(node.e2.kind == "identifier") local bnode: Node = { @@ -11539,7 +11543,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string kind = "string", } local btype = type_at(node.e2, a_type("string", { literal = node.e2.tk } as StringType)) - local t = type_check_index(node.e1, bnode, orig_a, btype) + local t = type_check_index(node.e1, bnode, ua, btype) if t.needs_compat and opts.gen_compat ~= "off" then -- only apply to a literal use, not a propagated type @@ -11555,11 +11559,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end if node.op.op == "@index" then - return type_check_index(node.e1, node.e2, a, b) - end - - if node.op.op == "as" then - return b + return type_check_index(node.e1, node.e2, ua, ub) end if node.op.op == "is" then @@ -11569,8 +11569,8 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if ra is TypeDeclType then error_at(node, "can only use 'is' on variables, not types") elseif node.e1.kind == "variable" then - check_metamethod(node, "__is", ra, resolve_typedecl(rb), orig_a, orig_b) - node.known = IsFact { var = node.e1.tk, typ = b, where = node } + check_metamethod(node, "__is", ra, resolve_typedecl(rb), ua, ub) + node.known = IsFact { var = node.e1.tk, typ = ub, where = node } else error_at(node, "can only use 'is' on variables") end @@ -11578,20 +11578,20 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end if node.op.op == ":" then - node.receiver = a + node.receiver = ua -- we handle ':' separately from '.' because ':' is specific to records, -- so we produce different error messages - if lax and (is_unknown(a) or a.typename == "typevar") then + if lax and (is_unknown(ua) or ua.typename == "typevar") then if node.e1.kind == "variable" then add_unknown_dot(node.e1, node.e1.tk .. "." .. node.e2.tk) end return UNKNOWN end - local t, e = match_record_key(a, node.e1, node.e2.conststr or node.e2.tk) + local t, e = match_record_key(ra, node.e1, node.e2.conststr or node.e2.tk) if not t then - return invalid_at(node.e2, e, resolve_tuple(orig_a)) + return invalid_at(node.e2, e, ua) end return t @@ -11604,18 +11604,18 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if node.op.op == "and" then node.known = facts_and(node, node.e1.known, node.e2.known) - return discard_tuple(node, b, b) + return discard_tuple(node, ub, gb) end if node.op.op == "or" then local t: Type - if b.typename == "nil" then + if ub.typename == "nil" then node.known = nil - t = a + t = ua - elseif is_lua_table_type(ra) and b is EmptyTableType then + elseif is_lua_table_type(ra) and rb is EmptyTableType then node.known = nil - t = a + t = ua elseif ((ra is EnumType and rb is StringType and is_a(rb, ra)) or (ra is StringType and rb is EnumType and is_a(ra, rb))) then @@ -11637,22 +11637,22 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if a_ge_b or b_ge_a then node.known = facts_or(node, node.e1.known, node.e2.known) if expected then - local a_is = is_a(a, node.expected) - local b_is = is_a(b, node.expected) + local a_is = is_a(ua, expected) + local b_is = is_a(ub, expected) if a_is and b_is then - t = resolve_typevars_at(node, node.expected) + t = resolve_typevars_at(node, expected) end end if not t then - local larger_type = b_ge_a and b or a - t = resolve_tuple(larger_type) + local larger_type = b_ge_a and ub or ua + t = larger_type end t = drop_constant_value(t) end end if t then - return discard_tuple(node, t, b) + return discard_tuple(node, t, gb) end -- else fallthrough to general binop handler end @@ -11664,39 +11664,38 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if ra is EnumType and rb is StringType then if not (rb.literal and ra.enumset[rb.literal]) then - return invalid_at(node, "%s is not a member of %s", b, a) + return invalid_at(node, "%s is not a member of %s", ub, ua) end elseif ra is TupleTableType and rb is TupleTableType and #ra.types ~= #rb.types then return invalid_at(node, "tuples are not the same size") - elseif is_a(b, a) or a.typename == "typevar" then + elseif is_a(ub, ua) or ua.typename == "typevar" then if node.op.op == "==" and node.e1.kind == "variable" then - node.known = EqFact { var = node.e1.tk, typ = b, where = node } + node.known = EqFact { var = node.e1.tk, typ = ub, where = node } end - elseif is_a(a, b) or b.typename == "typevar" then + elseif is_a(ua, ub) or ub.typename == "typevar" then if node.op.op == "==" and node.e2.kind == "variable" then - node.known = EqFact { var = node.e2.tk, typ = a, where = node } + node.known = EqFact { var = node.e2.tk, typ = ua, where = node } end - elseif lax and (is_unknown(a) or is_unknown(b)) then + elseif lax and (is_unknown(ua) or is_unknown(ub)) then return UNKNOWN else - return invalid_at(node, "types are not comparable for equality: %s and %s", a, b) + return invalid_at(node, "types are not comparable for equality: %s and %s", ua, ub) end return BOOLEAN end if node.op.arity == 1 and unop_types[node.op.op] then - a = ra - if a is UnionType then - a = unite(a.types, true) -- squash unions of string constants + if ra is UnionType then + ra = unite(ra.types, true) -- squash unions of string constants end local types_op = unop_types[node.op.op] - local t = types_op[a.typename] + local t = types_op[ra.typename] if not t then - t = find_in_interface_list(a, function(ty: Type): Type + t = find_in_interface_list(ra, function(ty: Type): Type return types_op[ty.typename] end) end @@ -11705,16 +11704,16 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if not t then local mt_name = unop_to_metamethod[node.op.op] if mt_name then - t, meta_on_operator = check_metamethod(node, mt_name, a, nil, orig_a, nil) + t, meta_on_operator = check_metamethod(node, mt_name, ra, nil, ua, nil) end if not t then - error_at(node, "cannot use operator '" .. node.op.op:gsub("%%", "%%%%") .. "' on type %s", resolve_tuple(orig_a)) + error_at(node, "cannot use operator '" .. node.op.op:gsub("%%", "%%%%") .. "' on type %s", ua) t = INVALID end end - if a is MapType then - if a.keys.typename == "number" or a.keys.typename == "integer" then + if ra is MapType then + if ra.keys.typename == "number" or ra.keys.typename == "integer" then add_warning("hint", node, "using the '#' operator on a map with numeric key type may produce unexpected results") else error_at(node, "using the '#' operator on this map will always return 0") @@ -11743,31 +11742,28 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string node.known = facts_or(node, node.e1.known, node.e2.known) end - a = ra - b = rb - - if a is UnionType then - a = unite(a.types, true) -- squash unions of string constants + if ra is UnionType then + ra = unite(ra.types, true) -- squash unions of string constants end - if b is UnionType then - b = unite(b.types, true) -- squash unions of string constants + if rb is UnionType then + rb = unite(rb.types, true) -- squash unions of string constants end local types_op = binop_types[node.op.op] - local t = types_op[a.typename] and types_op[a.typename][b.typename] + local t = types_op[ra.typename] and types_op[ra.typename][rb.typename] local meta_on_operator: integer if not t then local mt_name = binop_to_metamethod[node.op.op] if mt_name then - t, meta_on_operator = check_metamethod(node, mt_name, a, b, orig_a, orig_b) + t, meta_on_operator = check_metamethod(node, mt_name, ra, rb, ua, ub) end if not t then - error_at(node, "cannot use operator '" .. node.op.op:gsub("%%", "%%%%") .. "' for types %s and %s", resolve_tuple(orig_a), resolve_tuple(orig_b)) + error_at(node, "cannot use operator '" .. node.op.op:gsub("%%", "%%%%") .. "' for types %s and %s", ua, ub) t = INVALID if node.op.op == "or" then - local u = unite({orig_a, orig_b}) + local u = unite({ua, ub}) if u is UnionType and is_valid_union(u) then add_warning("hint", node, "if a union type was intended, consider declaring it explicitly") end @@ -11775,11 +11771,11 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - if orig_a is NominalType and orig_b is NominalType and not meta_on_operator then - if is_a(orig_a, orig_b) then - t = resolve_tuple(orig_a) + if ua is NominalType and ub is NominalType and not meta_on_operator then + if is_a(ua, ub) then + t = ua else - error_at(node, "cannot use operator '" .. node.op.op:gsub("%%", "%%%%") .. "' for distinct nominal types %s and %s", resolve_tuple(orig_a), resolve_tuple(orig_b)) + error_at(node, "cannot use operator '" .. node.op.op:gsub("%%", "%%%%") .. "' for distinct nominal types %s and %s", ua, ub) end end From 4487f807cea3fcb686e76fb38d633418157c7b23 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Sun, 14 Jan 2024 22:19:56 -0300 Subject: [PATCH 103/224] this block is no longer needed --- tl.lua | 6 ------ tl.tl | 6 ------ 2 files changed, 12 deletions(-) diff --git a/tl.lua b/tl.lua index 58e445254..babfa9ff0 100644 --- a/tl.lua +++ b/tl.lua @@ -7556,12 +7556,6 @@ tl.type_check = function(ast, opts) return def end - - if def.typename == "nominal" then - found = def.found - assert(found.typename == "typedecl") - def = found.def - end assert(not (def.typename == "nominal")) resolved = match_typevals(t, def) diff --git a/tl.tl b/tl.tl index 06d8626a8..47d88bd1b 100644 --- a/tl.tl +++ b/tl.tl @@ -7556,12 +7556,6 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return def end - -- FIXME is this block still needed? - if def is NominalType then - found = def.found - assert(found is TypeDeclType) - def = found.def - end assert(not def is NominalType) resolved = match_typevals(t, def) From 84c5b8d5960d2306148ef84bc3f331f490966c01 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Mon, 15 Jan 2024 01:32:56 -0300 Subject: [PATCH 104/224] show typeargs when displaying record names --- tl.lua | 16 ++++++++++++++-- tl.tl | 16 ++++++++++++++-- 2 files changed, 28 insertions(+), 4 deletions(-) diff --git a/tl.lua b/tl.lua index babfa9ff0..b27e687c5 100644 --- a/tl.lua +++ b/tl.lua @@ -6045,11 +6045,14 @@ local function display_typevar(typevar) end local function show_fields(t, show) - if t.declname then + if t.declname and not t.typeargs then return " " .. t.declname end local out = {} + if t.declname and not t.typeargs then + table.insert(out, " " .. t.declname) + end if t.typeargs then table.insert(out, "<") local typeargs = {} @@ -6059,6 +6062,10 @@ local function show_fields(t, show) table.insert(out, table.concat(typeargs, ", ")) table.insert(out, ">") end + if t.declname then + return table.concat(out) + end + table.insert(out, " (") if t.elements then table.insert(out, "{" .. show(t.elements) .. "}") @@ -6473,6 +6480,9 @@ tl.init_env = function(lax, gen_compat, gen_target, predefined) } if not stdlib_globals then + local tl_debug = TL_DEBUG + TL_DEBUG = nil + local program, syntax_errors = tl.parse(stdlib, "stdlib.d.tl") assert(#syntax_errors == 0) local result = tl.type_check(program, { @@ -6480,7 +6490,9 @@ tl.init_env = function(lax, gen_compat, gen_target, predefined) env = env, }) assert(#result.type_errors == 0) - stdlib_globals = env.globals; + stdlib_globals = env.globals + + TL_DEBUG = tl_debug local math_t = (stdlib_globals["math"].t).def diff --git a/tl.tl b/tl.tl index 47d88bd1b..670273a7b 100644 --- a/tl.tl +++ b/tl.tl @@ -6045,11 +6045,14 @@ local function display_typevar(typevar: string): string end local function show_fields(t: RecordLikeType, show: function(Type):(string)): string - if t.declname then + if t.declname and not t.typeargs then return " " .. t.declname end local out: {string} = {} + if t.declname and not t.typeargs then + table.insert(out, " " .. t.declname) + end if t.typeargs then table.insert(out, "<") local typeargs = {} @@ -6059,6 +6062,10 @@ local function show_fields(t: RecordLikeType, show: function(Type):(string)): st table.insert(out, table.concat(typeargs, ", ")) table.insert(out, ">") end + if t.declname then + return table.concat(out) + end + table.insert(out, " (") if t.elements then table.insert(out, "{" .. show(t.elements) .. "}") @@ -6473,6 +6480,9 @@ tl.init_env = function(lax?: boolean, gen_compat?: boolean | CompatMode, gen_tar } if not stdlib_globals then + local tl_debug = TL_DEBUG + TL_DEBUG = nil + local program, syntax_errors = tl.parse(stdlib, "stdlib.d.tl") assert(#syntax_errors == 0) local result = tl.type_check(program, { @@ -6480,7 +6490,9 @@ tl.init_env = function(lax?: boolean, gen_compat?: boolean | CompatMode, gen_tar env = env }) assert(#result.type_errors == 0) - stdlib_globals = env.globals; + stdlib_globals = env.globals + + TL_DEBUG = tl_debug -- special cases for compatibility local math_t = (stdlib_globals["math"].t as TypeDeclType).def as RecordType From 867a65903cc5d69f6356216f8d085b23a176c3f1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Perrad?= Date: Mon, 8 Jan 2024 18:52:40 -0300 Subject: [PATCH 105/224] more integer in stdlib signatures Co-Authored-By: Hisham Muhammad --- spec/stdlib/require_spec.lua | 2 +- spec/stdlib/select_spec.lua | 3 +- tl.lua | 95 ++++++++++++++++++------------------ tl.tl | 93 ++++++++++++++++++----------------- 4 files changed, 97 insertions(+), 96 deletions(-) diff --git a/spec/stdlib/require_spec.lua b/spec/stdlib/require_spec.lua index 5fc70d40f..d2196ab27 100644 --- a/spec/stdlib/require_spec.lua +++ b/spec/stdlib/require_spec.lua @@ -703,7 +703,7 @@ describe("require", function() ["luaunit.d.tl"] = [[ global type luaunit_runner_t = record setOutputType: function(luaunit_runner_t, string) - runSuite: function(luaunit_runner_t, any): number + runSuite: function(luaunit_runner_t, any): integer end global type luaunit_t = record diff --git a/spec/stdlib/select_spec.lua b/spec/stdlib/select_spec.lua index 39354fc91..a229e5e69 100644 --- a/spec/stdlib/select_spec.lua +++ b/spec/stdlib/select_spec.lua @@ -13,7 +13,6 @@ describe("select", function() it("rejects an invalid first argument", util.check_type_error([[ select({}, "hi", "hello") ]], { - -- FIXME not ideal message, but it fails on failure cases - { msg = "got {}, expected number" }, + { msg = "got {}, expected integer" }, })) end) diff --git a/tl.lua b/tl.lua index b27e687c5..a567435e0 100644 --- a/tl.lua +++ b/tl.lua @@ -46,8 +46,8 @@ do read: function(FILE, (number | FileMode)...): ((string | number)...) read: function(FILE, (number | string)...): (string...) - seek: function(FILE, ? SeekWhence, ? number): integer, string - setvbuf: function(FILE, SetVBufMode, ? number) + seek: function(FILE, ? SeekWhence, ? integer): integer, string + setvbuf: function(FILE, SetVBufMode, ? integer) write: function(FILE, (string | number)...): FILE, string @@ -96,35 +96,35 @@ do debug: function() gethook: function(? thread): HookFunction, integer - getinfo: function(AnyFunction | number): GetInfoTable - getinfo: function(AnyFunction | number, string): GetInfoTable - getinfo: function(thread, AnyFunction | number, string): GetInfoTable + getinfo: function(AnyFunction | integer): GetInfoTable + getinfo: function(AnyFunction | integer, string): GetInfoTable + getinfo: function(thread, AnyFunction | integer, string): GetInfoTable - getlocal: function(thread, AnyFunction, number): string - getlocal: function(thread, number, number): string, any - getlocal: function(AnyFunction, number): string - getlocal: function(number, number): string, any + getlocal: function(thread, AnyFunction, integer): string + getlocal: function(thread, integer, integer): string, any + getlocal: function(AnyFunction, integer): string + getlocal: function(integer, integer): string, any getmetatable: function(T): metatable getregistry: function(): {any:any} - getupvalue: function(AnyFunction, number): any - getuservalue: function(userdata, number): any + getupvalue: function(AnyFunction, integer): any + getuservalue: function(userdata, integer): any - sethook: function(thread, HookFunction, string, ? number) - sethook: function(HookFunction, string, ? number) + sethook: function(thread, HookFunction, string, ? integer) + sethook: function(HookFunction, string, ? integer) - setlocal: function(thread, number, number, any): string - setlocal: function(number, number, any): string + setlocal: function(thread, integer, integer, any): string + setlocal: function(integer, integer, any): string setmetatable: function(T, metatable): T - setupvalue: function(AnyFunction, number, any): string - setuservalue: function(U, any, number): U --[[U is userdata]] + setupvalue: function(AnyFunction, integer, any): string + setuservalue: function(U, any, integer): U --[[U is userdata]] - traceback: function(thread, ? string, ? number): string - traceback: function(? string, ? number): string + traceback: function(thread, ? string, ? integer): string + traceback: function(? string, ? integer): string - upvalueid: function(AnyFunction, number): userdata - upvaluejoin: function(AnyFunction, number, AnyFunction, number) + upvalueid: function(AnyFunction, integer): userdata + upvaluejoin: function(AnyFunction, integer, AnyFunction, integer) end global record io @@ -206,10 +206,10 @@ do pow: function(number, number): number rad: function(number): number - random: function(number, ? number): integer + random: function(integer, ? integer): integer random: function(): number - randomseed: function(number, number): integer, integer + randomseed: function(? integer, ? integer): integer, integer sin: function(number): number sinh: function(number): number sqrt: function(number): number @@ -282,9 +282,9 @@ do date: function(DateMode, ? number): DateTable date: function(? string, ? number): string - difftime: function(number, number): number + difftime: function(integer, integer): number execute: function(string): boolean, string, integer - exit: function(? (number | boolean), ? boolean) + exit: function(? (integer | boolean), ? boolean) getenv: function(string): string remove: function(string): boolean, string rename: function(string, string): boolean, string @@ -297,36 +297,37 @@ do config: string cpath: string loaded: {string:any} + loadlib: function(string, string): (function) loaders: { function(string): any, any } path: string preload: {any:any} searchers: { function(string): any } + searchpath: function(string, string, ? string, ? string): string, string end global record string - char: function(number...): string - - byte: function(string, ? number): integer - byte: function(string, number, number): integer... + byte: function(string, ? integer): integer + byte: function(string, integer, ? integer): integer... + char: function(integer...): string dump: function(function(any...): (any), ? boolean): string - find: function(string, string, ? number, ? boolean): integer, integer, string + find: function(string, string, ? integer, ? boolean): integer, integer, string format: function(string, any...): string - gmatch: function(string, string): (function(): string...) + gmatch: function(string, string, ? integer): (function(): string...) - gsub: function(string, string, string, ? number): string, integer - gsub: function(string, string, {string:string}, ? number): string, integer - gsub: function(string, string, function(string...): (string | number | boolean), ? number): string, integer + gsub: function(string, string, string, ? integer): string, integer + gsub: function(string, string, {string:string}, ? integer): string, integer + gsub: function(string, string, function(string...): (string | integer | boolean), ? integer): string, integer len: function(string): integer lower: function(string): string - match: function(string, string, ? number): string... + match: function(string, string, ? integer): string... pack: function(string, any...): string packsize: function(string): integer - rep: function(string, number, ? string): string + rep: function(string, integer, ? string): string reverse: function(string): string - sub: function(string, number, ? number): string - unpack: function(string, string, ? number): any... + sub: function(string, integer, ? integer): string + unpack: function(string, string, ? integer): any... upper: function(string): string end @@ -339,15 +340,15 @@ do n: integer end - concat: function({(string | number)}, ? string, ? number, ? number): string + concat: function({(string | number)}, ? string, ? integer, ? integer): string - insert: function({A}, number, A) + insert: function({A}, integer, A) insert: function({A}, A) pack: function(T...): PackTable pack: function(any...): {any:any} - remove: function({A}, ? number): A + remove: function({A}, ? integer): A sort: function({A}, ? SortFunction) unpack: function({A}, ? number, ? number): A... --[[needs_compat]] @@ -391,12 +392,12 @@ do arg: {string} assert: function(A, ? B): A - collectgarbage: function(CollectGarbageCommand): number - collectgarbage: function(CollectGarbageSetValue, number): number + collectgarbage: function(? CollectGarbageCommand): number + collectgarbage: function(CollectGarbageSetValue, integer): number collectgarbage: function(CollectGarbageIsRunning): boolean collectgarbage: function(string, ? number): (boolean | number) - error: function(? any, ? number) + error: function(? any, ? integer) ipairs: function({A}): (function():(integer, A)) load: function((string | LoadFunction), ? string, ? LoadMode, ? table): (function, string) @@ -410,14 +411,14 @@ do print: function(any...) require: function(string): any - select: function(number, T...): T... - select: function(number, any...): any... + select: function(integer, T...): T... + select: function(integer, any...): any... select: function(string, any...): integer setmetatable: function(T, metatable): T tonumber: function(any): number - tonumber: function(any, number): integer + tonumber: function(any, integer): integer tostring: function(any): string type: function(any): string diff --git a/tl.tl b/tl.tl index 670273a7b..c52ff05a0 100644 --- a/tl.tl +++ b/tl.tl @@ -46,8 +46,8 @@ do read: function(FILE, (number | FileMode)...): ((string | number)...) read: function(FILE, (number | string)...): (string...) - seek: function(FILE, ? SeekWhence, ? number): integer, string - setvbuf: function(FILE, SetVBufMode, ? number) + seek: function(FILE, ? SeekWhence, ? integer): integer, string + setvbuf: function(FILE, SetVBufMode, ? integer) write: function(FILE, (string | number)...): FILE, string @@ -96,35 +96,35 @@ do debug: function() gethook: function(? thread): HookFunction, integer - getinfo: function(AnyFunction | number): GetInfoTable - getinfo: function(AnyFunction | number, string): GetInfoTable - getinfo: function(thread, AnyFunction | number, string): GetInfoTable + getinfo: function(AnyFunction | integer): GetInfoTable + getinfo: function(AnyFunction | integer, string): GetInfoTable + getinfo: function(thread, AnyFunction | integer, string): GetInfoTable - getlocal: function(thread, AnyFunction, number): string - getlocal: function(thread, number, number): string, any - getlocal: function(AnyFunction, number): string - getlocal: function(number, number): string, any + getlocal: function(thread, AnyFunction, integer): string + getlocal: function(thread, integer, integer): string, any + getlocal: function(AnyFunction, integer): string + getlocal: function(integer, integer): string, any getmetatable: function(T): metatable getregistry: function(): {any:any} - getupvalue: function(AnyFunction, number): any - getuservalue: function(userdata, number): any + getupvalue: function(AnyFunction, integer): any + getuservalue: function(userdata, integer): any - sethook: function(thread, HookFunction, string, ? number) - sethook: function(HookFunction, string, ? number) + sethook: function(thread, HookFunction, string, ? integer) + sethook: function(HookFunction, string, ? integer) - setlocal: function(thread, number, number, any): string - setlocal: function(number, number, any): string + setlocal: function(thread, integer, integer, any): string + setlocal: function(integer, integer, any): string setmetatable: function(T, metatable): T - setupvalue: function(AnyFunction, number, any): string - setuservalue: function(U, any, number): U --[[U is userdata]] + setupvalue: function(AnyFunction, integer, any): string + setuservalue: function(U, any, integer): U --[[U is userdata]] - traceback: function(thread, ? string, ? number): string - traceback: function(? string, ? number): string + traceback: function(thread, ? string, ? integer): string + traceback: function(? string, ? integer): string - upvalueid: function(AnyFunction, number): userdata - upvaluejoin: function(AnyFunction, number, AnyFunction, number) + upvalueid: function(AnyFunction, integer): userdata + upvaluejoin: function(AnyFunction, integer, AnyFunction, integer) end global record io @@ -206,10 +206,10 @@ do pow: function(number, number): number rad: function(number): number - random: function(number, ? number): integer + random: function(integer, ? integer): integer random: function(): number - randomseed: function(number, number): integer, integer + randomseed: function(? integer, ? integer): integer, integer sin: function(number): number sinh: function(number): number sqrt: function(number): number @@ -282,9 +282,9 @@ do date: function(DateMode, ? number): DateTable date: function(? string, ? number): string - difftime: function(number, number): number + difftime: function(integer, integer): number execute: function(string): boolean, string, integer - exit: function(? (number | boolean), ? boolean) + exit: function(? (integer | boolean), ? boolean) getenv: function(string): string remove: function(string): boolean, string rename: function(string, string): boolean, string @@ -297,36 +297,37 @@ do config: string cpath: string loaded: {string:any} + loadlib: function(string, string): (function) loaders: { function(string): any, any } path: string preload: {any:any} searchers: { function(string): any } + searchpath: function(string, string, ? string, ? string): string, string end global record string - char: function(number...): string - - byte: function(string, ? number): integer - byte: function(string, number, number): integer... + byte: function(string, ? integer): integer + byte: function(string, integer, ? integer): integer... + char: function(integer...): string dump: function(function(any...): (any), ? boolean): string - find: function(string, string, ? number, ? boolean): integer, integer, string + find: function(string, string, ? integer, ? boolean): integer, integer, string format: function(string, any...): string - gmatch: function(string, string): (function(): string...) + gmatch: function(string, string, ? integer): (function(): string...) - gsub: function(string, string, string, ? number): string, integer - gsub: function(string, string, {string:string}, ? number): string, integer - gsub: function(string, string, function(string...): (string | number | boolean), ? number): string, integer + gsub: function(string, string, string, ? integer): string, integer + gsub: function(string, string, {string:string}, ? integer): string, integer + gsub: function(string, string, function(string...): (string | integer | boolean), ? integer): string, integer len: function(string): integer lower: function(string): string - match: function(string, string, ? number): string... + match: function(string, string, ? integer): string... pack: function(string, any...): string packsize: function(string): integer - rep: function(string, number, ? string): string + rep: function(string, integer, ? string): string reverse: function(string): string - sub: function(string, number, ? number): string - unpack: function(string, string, ? number): any... + sub: function(string, integer, ? integer): string + unpack: function(string, string, ? integer): any... upper: function(string): string end @@ -339,15 +340,15 @@ do n: integer end - concat: function({(string | number)}, ? string, ? number, ? number): string + concat: function({(string | number)}, ? string, ? integer, ? integer): string - insert: function({A}, number, A) + insert: function({A}, integer, A) insert: function({A}, A) pack: function(T...): PackTable pack: function(any...): {any:any} - remove: function({A}, ? number): A + remove: function({A}, ? integer): A sort: function({A}, ? SortFunction) unpack: function({A}, ? number, ? number): A... --[[needs_compat]] @@ -391,8 +392,8 @@ do arg: {string} assert: function(A, ? B): A - collectgarbage: function(CollectGarbageCommand): number - collectgarbage: function(CollectGarbageSetValue, number): number + collectgarbage: function(? CollectGarbageCommand): number + collectgarbage: function(CollectGarbageSetValue, integer): number collectgarbage: function(CollectGarbageIsRunning): boolean collectgarbage: function(string, ? number): (boolean | number) @@ -410,14 +411,14 @@ do print: function(any...) require: function(string): any - select: function(number, T...): T... - select: function(number, any...): any... + select: function(integer, T...): T... + select: function(integer, any...): any... select: function(string, any...): integer setmetatable: function(T, metatable): T tonumber: function(any): number - tonumber: function(any, number): integer + tonumber: function(any, integer): integer tostring: function(any): string type: function(any): string From 123635b9c4c0d80ee9991fee3e73679a997be23b Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Mon, 8 Jan 2024 18:57:34 -0300 Subject: [PATCH 106/224] standard library: add missing entries --- tl.lua | 48 +++++++++++++++++++++++++++++++++++++++++++----- tl.tl | 50 ++++++++++++++++++++++++++++++++++++++++++++------ 2 files changed, 87 insertions(+), 11 deletions(-) diff --git a/tl.lua b/tl.lua index a567435e0..37c156dbb 100644 --- a/tl.lua +++ b/tl.lua @@ -57,8 +57,8 @@ do global record coroutine type Function = function(any...): any... - create: function(Function): thread close: function(thread): boolean, string + create: function(Function): thread isyieldable: function(): boolean resume: function(thread, any...): boolean, any... running: function(): thread, boolean @@ -164,9 +164,9 @@ do end global record math - abs: function(integer): integer - abs: function(number): number + type Numeric = number | integer + abs: function(N): N acos: function(number): number asin: function(number): number atan: function(number, ? number): number @@ -345,6 +345,8 @@ do insert: function({A}, integer, A) insert: function({A}, A) + move: function({A}, integer, integer, integer, ? {A}): {A} + pack: function(T...): PackTable pack: function(any...): {any:any} @@ -397,18 +399,35 @@ do collectgarbage: function(CollectGarbageIsRunning): boolean collectgarbage: function(string, ? number): (boolean | number) + dofile: function(? string): any... + error: function(? any, ? integer) + getmetatable: function(T): metatable ipairs: function({A}): (function():(integer, A)) load: function((string | LoadFunction), ? string, ? LoadMode, ? table): (function, string) load: function((string | LoadFunction), ? string, ? string, ? table): (function, string) + loadfile: function(? string, ? string, ? {any:any}): (function, string) + next: function({K:V}, ? K): (K, V) next: function({A}, ? integer): (integer, A) pairs: function({K:V}): (function():(K, V)) pcall: function(function(any...):(any...), any...): boolean, any... print: function(any...) + rawequal: function(any, any): boolean + + rawget: function({K:V}, K): V + rawget: function({any:any}, any): any + rawget: function(any, any): any + + rawlen: function({A}): integer + + rawset: function({K:V}, K, V): {K:V} + rawset: function({any:any}, any, any): {any:any} + rawset: function(any, any, any): any + require: function(string): any select: function(integer, T...): T... @@ -422,6 +441,7 @@ do tostring: function(any): string type: function(any): string + warn: function(string, string...) xpcall: function(function(any...):(any...), XpcallMsghFunction, any...): boolean, any... _VERSION: string end @@ -429,12 +449,19 @@ do global arg = StandardLibrary.arg global assert = StandardLibrary.assert global collectgarbage = StandardLibrary.collectgarbage + global dofile = StandardLibrary.dofile global error = StandardLibrary.error + global getmetatable = StandardLibrary.getmetatable global load = StandardLibrary.load + global loadfile = StandardLibrary.loadfile global next = StandardLibrary.next global pairs = StandardLibrary.pairs global pcall = StandardLibrary.pcall global print = StandardLibrary.print + global rawequal = StandardLibrary.rawequal + global rawget = StandardLibrary.rawget + global rawlen = StandardLibrary.rawlen + global rawset = StandardLibrary.rawset global require = StandardLibrary.require global select = StandardLibrary.select global setmetatable = StandardLibrary.setmetatable @@ -6451,6 +6478,16 @@ tl.new_env = function(opts) return env end +local function assert_no_stdlib_errors(errors, name) + if #errors ~= 0 then + local out = {} + for _, err in ipairs(errors) do + table.insert(out, err.y .. ":" .. err.x .. " " .. err.msg .. "\n") + end + error("Internal Compiler Error: standard library contains " .. name .. ":\n" .. table.concat(out), 2) + end +end + tl.init_env = function(lax, gen_compat, gen_target, predefined) if gen_compat == true or gen_compat == nil then gen_compat = "optional" @@ -6485,12 +6522,13 @@ tl.init_env = function(lax, gen_compat, gen_target, predefined) TL_DEBUG = nil local program, syntax_errors = tl.parse(stdlib, "stdlib.d.tl") - assert(#syntax_errors == 0) + assert_no_stdlib_errors(syntax_errors, "syntax errors") + local result = tl.type_check(program, { filename = "@stdlib", env = env, }) - assert(#result.type_errors == 0) + assert_no_stdlib_errors(result.type_errors, "type errors") stdlib_globals = env.globals TL_DEBUG = tl_debug diff --git a/tl.tl b/tl.tl index c52ff05a0..e19c254d4 100644 --- a/tl.tl +++ b/tl.tl @@ -57,8 +57,8 @@ do global record coroutine type Function = function(any...): any... - create: function(Function): thread close: function(thread): boolean, string + create: function(Function): thread isyieldable: function(): boolean resume: function(thread, any...): boolean, any... running: function(): thread, boolean @@ -164,9 +164,9 @@ do end global record math - abs: function(integer): integer - abs: function(number): number + type Numeric = number | integer + abs: function(N): N acos: function(number): number asin: function(number): number atan: function(number, ? number): number @@ -345,6 +345,8 @@ do insert: function({A}, integer, A) insert: function({A}, A) + move: function({A}, integer, integer, integer, ? {A}): {A} + pack: function(T...): PackTable pack: function(any...): {any:any} @@ -397,18 +399,35 @@ do collectgarbage: function(CollectGarbageIsRunning): boolean collectgarbage: function(string, ? number): (boolean | number) - error: function(? any, ? number) + dofile: function(? string): any... + + error: function(? any, ? integer) + getmetatable: function(T): metatable ipairs: function({A}): (function():(integer, A)) load: function((string | LoadFunction), ? string, ? LoadMode, ? table): (function, string) load: function((string | LoadFunction), ? string, ? string, ? table): (function, string) + loadfile: function(? string, ? string, ? {any:any}): (function, string) + next: function({K:V}, ? K): (K, V) next: function({A}, ? integer): (integer, A) pairs: function({K:V}): (function():(K, V)) pcall: function(function(any...):(any...), any...): boolean, any... print: function(any...) + rawequal: function(any, any): boolean + + rawget: function({K:V}, K): V + rawget: function({any:any}, any): any + rawget: function(any, any): any + + rawlen: function({A}): integer + + rawset: function({K:V}, K, V): {K:V} + rawset: function({any:any}, any, any): {any:any} + rawset: function(any, any, any): any + require: function(string): any select: function(integer, T...): T... @@ -422,6 +441,7 @@ do tostring: function(any): string type: function(any): string + warn: function(string, string...) xpcall: function(function(any...):(any...), XpcallMsghFunction, any...): boolean, any... _VERSION: string end @@ -429,12 +449,19 @@ do global arg = StandardLibrary.arg global assert = StandardLibrary.assert global collectgarbage = StandardLibrary.collectgarbage + global dofile = StandardLibrary.dofile global error = StandardLibrary.error + global getmetatable = StandardLibrary.getmetatable global load = StandardLibrary.load + global loadfile = StandardLibrary.loadfile global next = StandardLibrary.next global pairs = StandardLibrary.pairs global pcall = StandardLibrary.pcall global print = StandardLibrary.print + global rawequal = StandardLibrary.rawequal + global rawget = StandardLibrary.rawget + global rawlen = StandardLibrary.rawlen + global rawset = StandardLibrary.rawset global require = StandardLibrary.require global select = StandardLibrary.select global setmetatable = StandardLibrary.setmetatable @@ -6451,6 +6478,16 @@ tl.new_env = function(opts: tl.EnvOptions): Env, string return env end +local function assert_no_stdlib_errors(errors: {Error}, name: string) + if #errors ~= 0 then + local out = {} + for _, err in ipairs(errors) do + table.insert(out, err.y .. ":" .. err.x .. " " .. err.msg .. "\n") + end + error("Internal Compiler Error: standard library contains " .. name .. ":\n" .. table.concat(out), 2) + end +end + tl.init_env = function(lax?: boolean, gen_compat?: boolean | CompatMode, gen_target?: TargetMode, predefined?: {string}): Env, string if gen_compat == true or gen_compat == nil then gen_compat = "optional" @@ -6485,12 +6522,13 @@ tl.init_env = function(lax?: boolean, gen_compat?: boolean | CompatMode, gen_tar TL_DEBUG = nil local program, syntax_errors = tl.parse(stdlib, "stdlib.d.tl") - assert(#syntax_errors == 0) + assert_no_stdlib_errors(syntax_errors, "syntax errors") + local result = tl.type_check(program, { filename = "@stdlib", env = env }) - assert(#result.type_errors == 0) + assert_no_stdlib_errors(result.type_errors, "type errors") stdlib_globals = env.globals TL_DEBUG = tl_debug From 25f64b59a409040788af58b06cbc58d276cddf5f Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Mon, 15 Jan 2024 01:39:01 -0300 Subject: [PATCH 107/224] _G: reserve a typeid --- tl.lua | 2 +- tl.tl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tl.lua b/tl.lua index 37c156dbb..987373fcb 100644 --- a/tl.lua +++ b/tl.lua @@ -6456,7 +6456,7 @@ local function convert_node_to_compat_mt_call(node, mt_name, which_self, e1, e2) end local stdlib_globals = nil -local globals_typeid +local globals_typeid = new_typeid() local fresh_typevar_ctr = 1 local function set_feat(feat, default) diff --git a/tl.tl b/tl.tl index e19c254d4..e4d50bfd5 100644 --- a/tl.tl +++ b/tl.tl @@ -6456,7 +6456,7 @@ local function convert_node_to_compat_mt_call(node: Node, mt_name: string, which end local stdlib_globals: {string:Variable} = nil -local globals_typeid: integer +local globals_typeid = new_typeid() local fresh_typevar_ctr = 1 local function set_feat(feat: tl.Feat, default: boolean): boolean From 7ecb0bcef59a4c88b1133c9441e98fb54d50aff1 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Mon, 15 Jan 2024 02:27:48 -0300 Subject: [PATCH 108/224] fix typealias forwarding logic --- spec/stdlib/require_spec.lua | 116 +++++++++++++++++++++++++++++++++++ tl.lua | 94 ++++++++++++++-------------- tl.tl | 100 +++++++++++++++--------------- 3 files changed, 217 insertions(+), 93 deletions(-) diff --git a/spec/stdlib/require_spec.lua b/spec/stdlib/require_spec.lua index d2196ab27..b43c9b937 100644 --- a/spec/stdlib/require_spec.lua +++ b/spec/stdlib/require_spec.lua @@ -778,6 +778,122 @@ describe("require", function() assert.same({}, result.type_errors) end) + it("does not crash when localizing alias types from other records, a is forwarding b", function () + -- ok + util.mock_io(finally, { + ["main.tl"] = [[ + local a = require("a") + + local type MyA = a.AliasA + local type MyB = a.AliasB + + local w: a.AliasA + local z: a.AliasB + + local ww: MyA + local zz: MyB + + print(w.x) + print(z.y) + + print(ww.x) + print(zz.y) + ]], + ["a.tl"] = [[ + local b = require("b") + local type BA_in_A = b.AliasA + local type BB_in_A = b.AliasB + + return b + ]], + ["b.tl"] = [[ + local interface TypeA + x: number + end + + local interface TypeB + y: string + end + + local record b + type AliasA = TypeA + type AliasB = TypeB + end + + return b + ]], + }) + local result, err = tl.process("main.tl") + + assert.same({}, result.syntax_errors) + assert.same({}, result.type_errors) + end) + + it("does not crash when localizing alias types from other records, a is aliasing b aliases", function () + -- ok + util.mock_io(finally, { + ["main.tl"] = [[ + local a = require("a") + local b = require("b") + + local type MyA = a.AAlias + local type MyB = a.BAlias + + local w: a.AAlias + local z: a.BAlias + + local ww: MyA + local zz: MyB + + print(w.x) + print(z.y) + + print(ww.x) + print(zz.y) + + local www: b.AliasA + local zzz: b.AliasB + + www = ww + ww = www + w = www + www = w + ww = w + w = ww + ]], + ["a.tl"] = [[ + local b = require("b") + + local record a + type AAlias = b.AliasA + type BAlias = b.AliasB + end + + return a + ]], + ["b.tl"] = [[ + local interface TypeA + x: number + end + + local interface TypeB + y: string + end + + local record b + type AliasA = TypeA + type AliasB = TypeB + end + + return b + ]], + }) + local result, err = tl.process("main.tl") + + assert.same({}, result.syntax_errors) + assert.same({}, result.type_errors) + end) + describe("circular requires", function() it("can be made using type-requires in order", function () util.mock_io(finally, { diff --git a/tl.lua b/tl.lua index 987373fcb..0c45e4e95 100644 --- a/tl.lua +++ b/tl.lua @@ -7555,6 +7555,8 @@ tl.type_check = function(ast, opts) return NONE end + + local resolve_nominal local resolve_typealias do @@ -7583,7 +7585,7 @@ tl.type_check = function(ast, opts) end end - resolve_nominal = function(t) + local function find_nominal_type_decl(t) if t.resolved then return t.resolved end @@ -7594,27 +7596,29 @@ tl.type_check = function(ast, opts) return INVALID end - local resolved - if found.typename == "typealias" then found = found.alias_to.found end - if found.typename == "typedecl" then - local def = found.def - if def.typename == "circular_require" then - - return def - end - - assert(not (def.typename == "nominal")) - - resolved = match_typevals(t, def) - else + if not (found.typename == "typedecl") then error_at(t, table.concat(t.names, ".") .. " is not a type") return INVALID end + local def = found.def + if def.typename == "circular_require" then + + return def + end + + assert(not (def.typename == "nominal")) + + t.found = found + return nil, found + end + + local function resolve_decl_into_nominal(t, found) + local resolved = match_typevals(t, found.def) if not resolved then error_at(t, table.concat(t.names, ".") .. " cannot be resolved in scope") return INVALID @@ -7627,41 +7631,37 @@ tl.type_check = function(ast, opts) t.y = resolved.y end end - t.found = found + t.resolved = resolved return resolved end - resolve_typealias = function(typealias) - local names = typealias.alias_to.names - local aliasing = find_var(names[1], "use_type") - if not aliasing then - return INVALID + resolve_nominal = function(t) + local immediate, found = find_nominal_type_decl(t) + if immediate then + return immediate end + return resolve_decl_into_nominal(t, found) + end + + resolve_typealias = function(typealias) local t = typealias.alias_to - if t.resolved then - return t.resolved, aliasing - end - local found = t.found or find_type(t.names) - if not found then - error_at(t, "unknown type %s", t) - return INVALID + local immediate, found = find_nominal_type_decl(t) + if immediate then + return immediate end - assert(found.typename == "typedecl") - - if t.typevals then - local resolved = match_typevals(t, found.def) - t.resolved = resolved - t.found = found - found = a_type("typedecl", { def = resolved }) - else - t.resolved = t + if not t.typevals then + return found end - return found, aliasing + local resolved = resolve_decl_into_nominal(t, found) + + local typedecl = a_type("typedecl", { def = resolved }) + t.resolved = typedecl + return typedecl end end @@ -10469,22 +10469,24 @@ expand_type(node, values, elements) }) return ok, t, infertype ~= nil end - local function get_type_declaration(value) + local function get_typedecl(value) if value.kind == "op" and value.op.op == "@funcall" and value.e1.kind == "variable" and value.e1.tk == "require" then local t = special_functions["require"](value, find_var_type("require"), a_type("tuple", { tuple = { STRING } }), 0) - if not (t.typename == "invalid") then - return t.tuple[1] - end + local ty = t.typename == "tuple" and t.tuple[1] or t + ty = (ty.typename == "typealias") and resolve_typealias(ty) or ty + local td = (ty.typename == "typedecl") and ty or a_type("typedecl", { def = ty }) + return td else local newtype = value.newtype if newtype.typename == "typealias" then - return resolve_typealias(value.newtype) + local aliasing = find_var(newtype.alias_to.names[1], "use_type") + return resolve_typealias(newtype), aliasing else - return value.newtype, nil + return newtype, nil end end end @@ -10599,7 +10601,7 @@ expand_type(node, values, elements) }) ["local_type"] = { before = function(node) local name = node.var.tk - local resolved, aliasing = get_type_declaration(node.value) + local resolved, aliasing = get_typedecl(node.value) local var = add_var(node.var, name, resolved, node.var.attribute) if aliasing then var.aliasing = aliasing @@ -10615,7 +10617,7 @@ expand_type(node, values, elements) }) local name = node.var.tk local unresolved = get_unresolved() if node.value then - local resolved, aliasing = get_type_declaration(node.value) + local resolved, aliasing = get_typedecl(node.value) local added = add_global(node.var, name, resolved) node.value.newtype = resolved if aliasing then @@ -12132,6 +12134,8 @@ expand_type(node, values, elements) }) end end end + elseif ftype.typename == "typealias" then + resolve_typealias(ftype) end typ.fields[name] = ftype diff --git a/tl.tl b/tl.tl index e4d50bfd5..7e8a25078 100644 --- a/tl.tl +++ b/tl.tl @@ -1999,7 +1999,7 @@ local record Node exps: Node -- newtype - newtype: Type + newtype: TypeAliasType | TypeDeclType elide_type: boolean -- expressions @@ -2139,7 +2139,7 @@ local function new_type(ps: ParseState, i: integer, typename: TypeName): Type }) end -local function new_typedecl(ps: ParseState, i: integer, def: Type): Type +local function new_typedecl(ps: ParseState, i: integer, def: Type): TypeDeclType local t = new_type(ps, i, "typedecl") as TypeDeclType t.def = def return t @@ -7555,8 +7555,10 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return NONE end + local type InvalidOrTypeDeclType = InvalidType | TypeDeclType + local resolve_nominal: function(t: NominalType): Type - local resolve_typealias: function(t: Type): Type, Variable + local resolve_typealias: function(t: TypeAliasType): InvalidOrTypeDeclType do local function match_typevals(t: NominalType, def: RecordLikeType | FunctionType): Type if t.typevals and def.typeargs then @@ -7583,7 +7585,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - resolve_nominal = function(t: NominalType): Type + local function find_nominal_type_decl(t: NominalType): Type, TypeDeclType if t.resolved then return t.resolved end @@ -7594,27 +7596,29 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return INVALID end - local resolved: Type - if found is TypeAliasType then found = found.alias_to.found end - if found is TypeDeclType then - local def = found.def - if def.typename == "circular_require" then - -- return, but do not store resolution - return def - end - - assert(not def is NominalType) - - resolved = match_typevals(t, def) - else + if not found is TypeDeclType then error_at(t, table.concat(t.names, ".") .. " is not a type") return INVALID end + local def = found.def + if def.typename == "circular_require" then + -- return, but do not store resolution + return def + end + + assert(not def is NominalType) + + t.found = found + return nil, found + end + + local function resolve_decl_into_nominal(t: NominalType, found: TypeDeclType): Type + local resolved = match_typevals(t, found.def) if not resolved then error_at(t, table.concat(t.names, ".") .. " cannot be resolved in scope") return INVALID @@ -7627,41 +7631,37 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string t.y = resolved.y end end - t.found = found + t.resolved = resolved return resolved end - resolve_typealias = function(typealias: TypeAliasType): Type, Variable - local names = typealias.alias_to.names - local aliasing = find_var(names[1], "use_type") - if not aliasing then - return INVALID + resolve_nominal = function(t: NominalType): Type + local immediate, found = find_nominal_type_decl(t) + if immediate then + return immediate end + return resolve_decl_into_nominal(t, found) + end + + resolve_typealias = function(typealias: TypeAliasType): InvalidOrTypeDeclType local t = typealias.alias_to - if t.resolved then - return t.resolved, aliasing - end - local found = t.found or find_type(t.names) - if not found then - error_at(t, "unknown type %s", t) - return INVALID + local immediate, found = find_nominal_type_decl(t) + if immediate then + return immediate end - assert(found is TypeDeclType) - - if t.typevals then - local resolved = match_typevals(t, found.def) - t.resolved = resolved - t.found = found - found = a_typedecl(resolved) - else - t.resolved = t + if not t.typevals then + return found end - return found, aliasing + local resolved = resolve_decl_into_nominal(t, found) + + local typedecl = a_type("typedecl", { def = resolved } as TypeDeclType) + t.resolved = typedecl + return typedecl end end @@ -10469,22 +10469,24 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return ok, t, infertype ~= nil end - local function get_type_declaration(value: Node): Type, Variable + local function get_typedecl(value: Node): TypeDeclType, Variable if value.kind == "op" and value.op.op == "@funcall" and value.e1.kind == "variable" and value.e1.tk == "require" then local t = special_functions["require"](value, find_var_type("require"), a_tuple { STRING }, 0) - if not t is InvalidType then - return t.tuple[1] - end + local ty = t is TupleType and t.tuple[1] or t + ty = (ty is TypeAliasType) and resolve_typealias(ty) or ty + local td = (ty is TypeDeclType) and ty or a_type("typedecl", { def = ty } as TypeDeclType) + return td else local newtype = value.newtype if newtype is TypeAliasType then - return resolve_typealias(value.newtype) + local aliasing = find_var(newtype.alias_to.names[1], "use_type") + return resolve_typealias(newtype), aliasing else - return value.newtype, nil + return newtype, nil end end end @@ -10599,7 +10601,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string ["local_type"] = { before = function(node: Node) local name = node.var.tk - local resolved, aliasing = get_type_declaration(node.value) + local resolved, aliasing = get_typedecl(node.value) local var = add_var(node.var, name, resolved, node.var.attribute) if aliasing then var.aliasing = aliasing @@ -10615,7 +10617,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local name = node.var.tk local unresolved = get_unresolved() if node.value then - local resolved, aliasing = get_type_declaration(node.value) + local resolved, aliasing = get_typedecl(node.value) local added = add_global(node.var, name, resolved) node.value.newtype = resolved if aliasing then @@ -12132,6 +12134,8 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end end + elseif ftype is TypeAliasType then + resolve_typealias(ftype) end typ.fields[name] = ftype From 946f2b793975987c44195763d1cc6df0b0ba9ce2 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Wed, 10 Jan 2024 23:56:29 -0300 Subject: [PATCH 109/224] union_type: nominals should be already resolved here --- tl.lua | 2 +- tl.tl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tl.lua b/tl.lua index 0c45e4e95..8517d4b58 100644 --- a/tl.lua +++ b/tl.lua @@ -6797,7 +6797,7 @@ tl.type_check = function(ast, opts) elseif t.typename == "tuple" then return union_type(t.tuple[1]), t.tuple[1] elseif t.typename == "nominal" then - local typedecl = t.found or find_type(t.names) + local typedecl = t.found if not typedecl then return "invalid" end diff --git a/tl.tl b/tl.tl index 7e8a25078..24393fe3f 100644 --- a/tl.tl +++ b/tl.tl @@ -6797,7 +6797,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string elseif t is TupleType then return union_type(t.tuple[1]), t.tuple[1] elseif t is NominalType then - local typedecl = t.found or find_type(t.names) + local typedecl = t.found if not typedecl then return "invalid" end From 775cc2a13a705b48f695d43f5aac50143efe8759 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Thu, 11 Jan 2024 00:01:52 -0300 Subject: [PATCH 110/224] drop validate_union, just use is_valid_union --- tl.lua | 33 ++++++++++++++------------------- tl.tl | 33 ++++++++++++++------------------- 2 files changed, 28 insertions(+), 38 deletions(-) diff --git a/tl.lua b/tl.lua index 8517d4b58..95ab3d728 100644 --- a/tl.lua +++ b/tl.lua @@ -6877,22 +6877,6 @@ tl.type_check = function(ast, opts) return true end - local function validate_union(where, u, store_errs, errs) - local valid, err = is_valid_union(u) - if err then - if store_errs then - errs = errs or {} - else - errs = errors - end - table.insert(errs, Err(where, err, u)) - end - if not valid then - return INVALID, store_errs and errs - end - return u, store_errs and errs - end - local function show_arity(f) local nfargs = #f.args.tuple return f.min_arity < nfargs and @@ -7071,7 +7055,11 @@ tl.type_check = function(ast, opts) copy.types[i], same = resolve(tf, same) end - copy, errs = validate_union(t, copy, true, errs) + local _, err = is_valid_union(copy) + if err then + errs = errs or {} + table.insert(errs, Err(t, err, copy)) + end elseif t.typename == "poly" then assert(copy.typename == "poly") copy.types = {} @@ -11674,7 +11662,10 @@ expand_type(node, values, elements) }) node.known = facts_or(node, node.e1.known, node.e2.known) local u = unite({ ra, rb }, true) if u.typename == "union" then - u = validate_union(node, u) + local ok, err = is_valid_union(u) + if not ok then + u = err and invalid_at(node, err, u) or INVALID + end end t = u @@ -12222,7 +12213,11 @@ expand_type(node, values, elements) }) }, ["union"] = { after = function(typ, _children) - return (validate_union(typ, typ)) + local ok, err = is_valid_union(typ) + if not ok then + return err and invalid_at(typ, err, typ) or INVALID + end + return typ end, }, }, diff --git a/tl.tl b/tl.tl index 24393fe3f..64352b18d 100644 --- a/tl.tl +++ b/tl.tl @@ -6877,22 +6877,6 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return true end - local function validate_union(where: Where, u: UnionType, store_errs?: boolean, errs?: {Error}): Type, {Error} - local valid, err = is_valid_union(u) - if err then - if store_errs then - errs = errs or {} - else - errs = errors - end - table.insert(errs, Err(where, err, u)) - end - if not valid then - return INVALID, store_errs and errs - end - return u, store_errs and errs - end - local function show_arity(f: FunctionType): string local nfargs = #f.args.tuple return f.min_arity < nfargs @@ -7071,7 +7055,11 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string copy.types[i], same = resolve(tf, same) end - copy, errs = validate_union(t, copy, true, errs) + local _, err = is_valid_union(copy) + if err then + errs = errs or {} + table.insert(errs, Err(t, err, copy)) + end elseif t is PolyType then assert(copy is PolyType) copy.types = {} @@ -11674,7 +11662,10 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string node.known = facts_or(node, node.e1.known, node.e2.known) local u = unite({ra, rb}, true) if u is UnionType then - u = validate_union(node, u) + local ok, err = is_valid_union(u) + if not ok then + u = err and invalid_at(node, err, u) or INVALID + end end t = u @@ -12222,7 +12213,11 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string }, ["union"] = { after = function(typ: UnionType, _children: {Type}): Type - return (validate_union(typ, typ)) + local ok, err = is_valid_union(typ) + if not ok then + return err and invalid_at(typ, err, typ) or INVALID + end + return typ end }, }, From bba12c4bdee6c2902548a094977e5836e7ce9980 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Thu, 11 Jan 2024 01:02:12 -0300 Subject: [PATCH 111/224] unite: removing duplicates via t.found should be enough This avoids resolve_nominal which needs state lookup. With this change, unite() should work looking at the given Type objects only. --- tl.lua | 4 ++-- tl.tl | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tl.lua b/tl.lua index 95ab3d728..60e3eb5df 100644 --- a/tl.lua +++ b/tl.lua @@ -7764,8 +7764,8 @@ tl.type_check = function(ast, opts) end else local typeid = t.typeid - if t.typename == "nominal" then - typeid = resolve_nominal(t).typeid + if t.typename == "nominal" and t.found then + typeid = t.found.typeid end if not types_seen[typeid] then types_seen[typeid] = true diff --git a/tl.tl b/tl.tl index 64352b18d..85cc49678 100644 --- a/tl.tl +++ b/tl.tl @@ -7764,8 +7764,8 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end else local typeid = t.typeid - if t is NominalType then - typeid = resolve_nominal(t).typeid + if t is NominalType and t.found then + typeid = t.found.typeid end if not types_seen[typeid] then types_seen[typeid] = true From c2dc259c316ccf3dad6548922ad756b3038e6f0d Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Fri, 12 Jan 2024 20:35:41 -0300 Subject: [PATCH 112/224] refactor: simplify signature of type_check_function_call a bit --- tl.lua | 38 ++++++++++++++++++++++---------------- tl.tl | 40 +++++++++++++++++++++++----------------- 2 files changed, 45 insertions(+), 33 deletions(-) diff --git a/tl.lua b/tl.lua index 60e3eb5df..7a02f29a4 100644 --- a/tl.lua +++ b/tl.lua @@ -8906,19 +8906,22 @@ a.types[i], b.types[i]), } return resolve_typevars_at(where, f.rets) end - local function check_call(where, where_args, func, args, expected_rets, is_typedecl_funcall, is_method, argdelta) + local function check_call(where, where_args, func, args, expected_rets, is_typedecl_funcall, argdelta) assert(type(func) == "table") assert(type(args) == "table") + local is_method = (argdelta == -1) + if not (func.typename == "function" or func.typename == "poly") then func, is_method = resolve_for_call(func, args, is_method) + if is_method then + argdelta = -1 + end if not (func.typename == "function" or func.typename == "poly") then return invalid_at(where, "not a function: %s", func) end end - argdelta = is_method and -1 or argdelta or 0 - if is_method and args.tuple[1] then add_var(nil, "@self", type_at(where, a_type("typedecl", { def = args.tuple[1] }))) end @@ -8985,7 +8988,10 @@ a.types[i], b.types[i]), } return fail_call(where, func, given, first_errs) end - type_check_function_call = function(node, where_args, func, args, e1, is_method, argdelta) + type_check_function_call = function(node, func, args, argdelta, e1, e2) + e1 = e1 or node.e1 + e2 = e2 or node.e2 + local expected = node.expected local expected_rets if expected and expected.typename == "tuple" then @@ -9007,7 +9013,7 @@ a.types[i], b.types[i]), } end end - local ret, f = check_call(node, where_args, func, args, expected_rets, is_typedecl_funcall, is_method, argdelta) + local ret, f = check_call(node, e2, func, args, expected_rets, is_typedecl_funcall, argdelta or 0) ret = resolve_typevars_at(node, ret) end_scope() @@ -9016,7 +9022,7 @@ a.types[i], b.types[i]), } end if f and f.macroexp then - expand_macroexp(node, where_args, f.macroexp) + expand_macroexp(node, e2, f.macroexp) end return ret, f @@ -9045,13 +9051,13 @@ a.types[i], b.types[i]), } end if metamethod then - local where_args = { node.e1 } + local e2 = { node.e1 } local args = a_type("tuple", { tuple = { orig_a } }) if b and method_name ~= "__is" then - where_args[2] = node.e2 + e2[2] = node.e2 args.tuple[2] = orig_b end - return to_structural(resolve_tuple((type_check_function_call(node, where_args, metamethod, args, nil, true)))), meta_on_operator + return to_structural(resolve_tuple((type_check_function_call(node, metamethod, args, -1, node, e2)))), meta_on_operator else return nil, nil end @@ -10042,7 +10048,7 @@ a.types[i], b.types[i]), } end end - return (type_check_function_call(node, node.e2, a, b, node, false, argdelta)) + return (type_check_function_call(node, a, b, argdelta)) end, ["ipairs"] = function(node, a, b, argdelta) @@ -10063,7 +10069,7 @@ a.types[i], b.types[i]), } end end - return (type_check_function_call(node, node.e2, a, b, node, false, argdelta)) + return (type_check_function_call(node, a, b, argdelta)) end, ["rawget"] = function(node, _a, b, _argdelta) @@ -10105,7 +10111,7 @@ a.types[i], b.types[i]), } ["assert"] = function(node, a, b, argdelta) node.known = FACT_TRUTHY - local r = type_check_function_call(node, node.e2, a, b, node, false, argdelta) + local r = type_check_function_call(node, a, b, argdelta) apply_facts(node, node.e2[1].known) return r end, @@ -10118,13 +10124,13 @@ a.types[i], b.types[i]), } if special then return special(node, a, b, argdelta) else - return (type_check_function_call(node, node.e2, a, b, node.e1, false, argdelta)) + return (type_check_function_call(node, a, b, argdelta)) end elseif node.e1.op and node.e1.op.op == ":" then table.insert(b.tuple, 1, node.e1.receiver) - return (type_check_function_call(node, node.e2, a, b, node.e1, true)) + return (type_check_function_call(node, a, b, -1)) else - return (type_check_function_call(node, node.e2, a, b, node.e1, false, argdelta)) + return (type_check_function_call(node, a, b, argdelta)) end end @@ -10861,7 +10867,7 @@ expand_type(node, values, elements) }) if exp1type.typename == "poly" then local _ - _, exp1type = type_check_function_call(exp1, { node.exps[2], node.exps[3] }, exp1type, args, exp1, false, 0) + _, exp1type = type_check_function_call(exp1, exp1type, args, 0, exp1, { node.exps[2], node.exps[3] }) end if exp1type.typename == "function" then diff --git a/tl.tl b/tl.tl index 85cc49678..72dc63672 100644 --- a/tl.tl +++ b/tl.tl @@ -8743,7 +8743,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local type InvalidOrTupleType = InvalidType | TupleType - local type_check_function_call: function(Node, {Node}, Type, TupleType, Node, boolean, ? integer): InvalidOrTupleType, FunctionType + local type_check_function_call: function(Node, Type, TupleType, ? integer, ? Node, ? {Node}): InvalidOrTupleType, FunctionType do local function mark_invalid_typeargs(f: FunctionType) if f.typeargs then @@ -8906,19 +8906,22 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return resolve_typevars_at(where, f.rets) end - local function check_call(where: Where, where_args: {Node}, func: Type, args: TupleType, expected_rets: TupleType, is_typedecl_funcall: boolean, is_method: boolean, argdelta: integer): InvalidOrTupleType, FunctionType + local function check_call(where: Where, where_args: {Node}, func: Type, args: TupleType, expected_rets: TupleType, is_typedecl_funcall: boolean, argdelta: integer): InvalidOrTupleType, FunctionType assert(type(func) == "table") assert(type(args) == "table") + local is_method = (argdelta == -1) + if not (func is FunctionType or func is PolyType) then func, is_method = resolve_for_call(func, args, is_method) + if is_method then + argdelta = -1 + end if not (func is FunctionType or func is PolyType) then return invalid_at(where, "not a function: %s", func) end end - argdelta = is_method and -1 or argdelta or 0 - if is_method and args.tuple[1] then add_var(nil, "@self", type_at(where, a_typedecl(args.tuple[1]))) end @@ -8985,7 +8988,10 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return fail_call(where, func, given, first_errs) end - type_check_function_call = function(node: Node, where_args: {Node}, func: Type, args: TupleType, e1: Node, is_method: boolean, argdelta?: integer): InvalidOrTupleType, FunctionType + type_check_function_call = function(node: Node, func: Type, args: TupleType, argdelta?: integer, e1?: Node, e2?: {Node}): InvalidOrTupleType, FunctionType + e1 = e1 or node.e1 + e2 = e2 or node.e2 + local expected = node.expected local expected_rets: TupleType if expected and expected is TupleType then @@ -9007,7 +9013,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - local ret, f = check_call(node, where_args, func, args, expected_rets, is_typedecl_funcall, is_method, argdelta) + local ret, f = check_call(node, e2, func, args, expected_rets, is_typedecl_funcall, argdelta or 0) ret = resolve_typevars_at(node, ret) end_scope() @@ -9016,7 +9022,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end if f and f.macroexp then - expand_macroexp(node, where_args, f.macroexp) + expand_macroexp(node, e2, f.macroexp) end return ret, f @@ -9045,13 +9051,13 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end if metamethod then - local where_args = { node.e1 } + local e2 = { node.e1 } local args = a_tuple { orig_a } if b and method_name ~= "__is" then - where_args[2] = node.e2 + e2[2] = node.e2 args.tuple[2] = orig_b end - return to_structural(resolve_tuple((type_check_function_call(node, where_args, metamethod, args, nil, true)))), meta_on_operator + return to_structural(resolve_tuple((type_check_function_call(node, metamethod, args, -1, node, e2)))), meta_on_operator else return nil, nil end @@ -10042,7 +10048,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - return (type_check_function_call(node, node.e2, a, b, node, false, argdelta)) + return (type_check_function_call(node, a, b, argdelta)) end, ["ipairs"] = function(node: Node, a: Type, b: TupleType, argdelta: integer): InvalidOrTupleType @@ -10063,7 +10069,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - return (type_check_function_call(node, node.e2, a, b, node, false, argdelta)) + return (type_check_function_call(node, a, b, argdelta)) end, ["rawget"] = function(node: Node, _a: Type, b: TupleType, _argdelta: integer): InvalidOrTupleType @@ -10105,7 +10111,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string ["assert"] = function(node: Node, a: Type, b: TupleType, argdelta: integer): InvalidOrTupleType node.known = FACT_TRUTHY - local r = type_check_function_call(node, node.e2, a, b, node, false, argdelta) + local r = type_check_function_call(node, a, b, argdelta) apply_facts(node, node.e2[1].known) return r end, @@ -10118,13 +10124,13 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if special then return special(node, a, b, argdelta) else - return (type_check_function_call(node, node.e2, a, b, node.e1, false, argdelta)) + return (type_check_function_call(node, a, b, argdelta)) end elseif node.e1.op and node.e1.op.op == ":" then table.insert(b.tuple, 1, node.e1.receiver) - return (type_check_function_call(node, node.e2, a, b, node.e1, true)) + return (type_check_function_call(node, a, b, -1)) else - return (type_check_function_call(node, node.e2, a, b, node.e1, false, argdelta)) + return (type_check_function_call(node, a, b, argdelta)) end end @@ -10861,7 +10867,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if exp1type is PolyType then local _: Type - _, exp1type = type_check_function_call(exp1, {node.exps[2], node.exps[3]}, exp1type, args, exp1, false, 0) + _, exp1type = type_check_function_call(exp1, exp1type, args, 0, exp1, {node.exps[2], node.exps[3]}) end if exp1type is FunctionType then From 5a6fb5f8acc91be01f5d22dcd0f32e906d038328 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Mon, 15 Jan 2024 13:00:55 -0300 Subject: [PATCH 113/224] move module_name logic out of type_check it is now handled by the parts that deal with module names: `require` and the package loader. --- spec/cli/check_spec.lua | 8 ++++++ tl | 37 ++++++++++++++++++++++--- tl.lua | 57 +++++++++++--------------------------- tl.tl | 61 ++++++++++++----------------------------- 4 files changed, 75 insertions(+), 88 deletions(-) diff --git a/spec/cli/check_spec.lua b/spec/cli/check_spec.lua index 725e196c9..347c369aa 100644 --- a/spec/cli/check_spec.lua +++ b/spec/cli/check_spec.lua @@ -3,6 +3,14 @@ local util = require("spec.util") describe("tl check", function() describe("on .tl files", function() + it("reports if file does not exist", function() + local pd = io.popen(util.tl_cmd("check", "file_that_does_not_exist.tl") .. " 2>&1", "r") + local output = pd:read("*a") + util.assert_popen_close(1, pd:close()) + assert.match("could not open file_that_does_not_exist.tl", output, 1, true) + end) + + it("works on empty files", function() local name = util.write_tmp_file(finally, [[]]) local pd = io.popen(util.tl_cmd("check", name), "r") diff --git a/tl b/tl index 68a329fbc..70248553a 100755 --- a/tl +++ b/tl @@ -119,6 +119,26 @@ local function find_file_in_parent_dirs(fname) end end +local function filename_to_module_name(filename) + local path = os.getenv("TL_PATH") or package.path + for entry in path:gmatch("[^;]+") do + entry = entry:gsub("%.", "%%.") + local lua_pat = "^" .. entry:gsub("%?", ".+") .. "$" + local d_tl_pat = lua_pat:gsub("%%.lua%$", "%%.d%%.tl$") + local tl_pat = lua_pat:gsub("%%.lua%$", "%%.tl$") + + for _, pat in ipairs({ tl_pat, d_tl_pat, lua_pat }) do + local cap = filename:match(pat) + if cap then + return (cap:gsub("[/\\]", ".")) + end + end + end + + -- fallback: + return (filename:gsub("%.lua$", ""):gsub("%.d%.tl$", ""):gsub("%.tl$", ""):gsub("[/\\]", ".")) +end + -------------------------------------------------------------------------------- -- Common driver backend -------------------------------------------------------------------------------- @@ -212,9 +232,18 @@ do end end +local function process_module(filename, env) + local module_name = filename_to_module_name(filename) + local result, err = tl.process(filename, env) + if result then + env.modules[module_name] = result.type + end + return result, err +end + local function type_check_and_load(tlconfig, filename) local env = setup_env(tlconfig, filename) - local result, err = tl.process(filename, env) + local result, err = process_module(filename, env) if err then die(err) end @@ -705,7 +734,7 @@ commands["check"] = function(tlconfig, args) env = setup_env(tlconfig, input_file) end if not already_loaded(env, input_file) then - local _, err = tl.process(input_file, env) + local _, err = process_module(input_file, env) if err then die(err) end @@ -766,7 +795,7 @@ commands["gen"] = function(tlconfig, args) output_file = get_output_filename(input_file) } - res.tl_result, err = tl.process(input_file, env) + res.tl_result, err = process_module(input_file, env) if err then die(err) end @@ -867,7 +896,7 @@ do env.report_types = true for i, input_file in ipairs(args["file"]) do - local pok, err = pcall(tl.process, input_file, env) + local pok, err = pcall(process_module, input_file, env) if not pok then die("Internal Compiler Error: " .. err) end diff --git a/tl.lua b/tl.lua index 7a02f29a4..ef6c34126 100644 --- a/tl.lua +++ b/tl.lua @@ -622,7 +622,6 @@ local tl = {PrettyPrintOptions = {}, TypeCheckOptions = {}, Env = {}, Result = { - local TypeReporter = {} @@ -6276,26 +6275,6 @@ local function search_for(module_name, suffix, path, tried) return nil, nil, tried end -local function filename_to_module_name(filename) - local path = os.getenv("TL_PATH") or package.path - for entry in path:gmatch("[^;]+") do - entry = entry:gsub("%.", "%%.") - local lua_pat = "^" .. entry:gsub("%?", ".+") .. "$" - local d_tl_pat = lua_pat:gsub("%%.lua%$", "%%.d%%.tl$") - local tl_pat = lua_pat:gsub("%%.lua%$", "%%.tl$") - - for _, pat in ipairs({ tl_pat, d_tl_pat, lua_pat }) do - local cap = filename:match(pat) - if cap then - return (cap:gsub("[/\\]", ".")) - end - end - end - - - return (filename:gsub("%.lua$", ""):gsub("%.d%.tl$", ""):gsub("%.tl$", ""):gsub("[/\\]", ".")) -end - function tl.search_module(module_name, search_dtl) local found local fd @@ -6326,9 +6305,14 @@ local function require_module(module_name, lax, env) local found, fd = tl.search_module(module_name, true) if found and (lax or found:match("tl$")) then - local found_result, err = tl.process(found, env, module_name, fd) + + env.modules[module_name] = a_type("typedecl", { def = CIRCULAR_REQUIRE }) + + local found_result, err = tl.process(found, env, fd) assert(found_result, err) + env.modules[module_name] = found_result.type + return found_result.type, true elseif fd then fd:close() @@ -6588,10 +6572,6 @@ tl.type_check = function(ast, opts) end end - if opts.module_name then - env.modules[opts.module_name] = a_type("typedecl", { def = CIRCULAR_REQUIRE }) - end - local lax = opts.lax local feat_arity = env.feat_arity local filename = opts.filename @@ -12334,10 +12314,6 @@ expand_type(node, values, elements) }) env.loaded[filename] = result table.insert(env.loaded_order, filename) - if opts.module_name then - env.modules[opts.module_name] = result.type - end - if tc then env.reporter:store_result(tc, env.globals) end @@ -12391,7 +12367,9 @@ local function read_full_file(fd) return content, err end -tl.process = function(filename, env, module_name, fd) +tl.process = function(filename, env, fd) + assert((not fd or type(fd) ~= "string"), "fd must be a file") + if env and env.loaded and env.loaded[filename] then return env.loaded[filename] end @@ -12423,14 +12401,10 @@ tl.process = function(filename, env, module_name, fd) is_lua = input:match("^#![^\n]*lua[^\n]*\n") end - return tl.process_string(input, is_lua, env, filename, module_name) + return tl.process_string(input, is_lua, env, filename) end -function tl.process_string(input, is_lua, env, filename, module_name) - if filename and not module_name then - module_name = filename_to_module_name(filename) - end - +function tl.process_string(input, is_lua, env, filename) env = env or tl.init_env(is_lua) if env.loaded and env.loaded[filename] then return env.loaded[filename] @@ -12443,7 +12417,6 @@ function tl.process_string(input, is_lua, env, filename, module_name) local result = { ok = false, filename = filename, - module_name = module_name, type = BOOLEAN, type_errors = {}, syntax_errors = syntax_errors, @@ -12456,7 +12429,6 @@ function tl.process_string(input, is_lua, env, filename, module_name) local opts = { filename = filename, - module_name = module_name, lax = is_lua, gen_compat = env.gen_compat, gen_target = env.gen_target, @@ -12502,14 +12474,17 @@ local function tl_package_loader(module_name) env = tl.package_loader_env end - tl.type_check(program, { + env.modules[module_name] = a_type("typedecl", { def = CIRCULAR_REQUIRE }) + + local result = tl.type_check(program, { lax = lax, filename = found_filename, - module_name = module_name, env = env, run_internal_compiler_checks = false, }) + env.modules[module_name] = result.type + local code = assert(tl.pretty_print_ast(program, env.gen_target, true)) diff --git a/tl.tl b/tl.tl index 72dc63672..b8ed9743b 100644 --- a/tl.tl +++ b/tl.tl @@ -518,7 +518,6 @@ local record tl record TypeCheckOptions lax: boolean filename: string - module_name: string gen_compat: CompatMode gen_target: TargetMode env: Env @@ -613,8 +612,8 @@ local record tl end load: function(string, string, LoadMode, {any:any}): LoadFunction, string - process: function(string, Env, string, FILE): (Result, string) - process_string: function(string, boolean, Env, ? string, ? string): Result + process: function(string, Env, ? FILE): (Result, string) + process_string: function(string, boolean, Env, ? string): Result gen: function(string, Env, PrettyPrintOptions): string, Result type_check: function(Node, TypeCheckOptions): Result, string new_env: function(EnvOptions): Env, string @@ -6276,26 +6275,6 @@ local function search_for(module_name: string, suffix: string, path: string, tri return nil, nil, tried end -local function filename_to_module_name(filename: string): string - local path = os.getenv("TL_PATH") or package.path - for entry in path:gmatch("[^;]+") do - entry = entry:gsub("%.", "%%.") - local lua_pat = "^" .. entry:gsub("%?", ".+") .. "$" - local d_tl_pat = lua_pat:gsub("%%.lua%$", "%%.d%%.tl$") - local tl_pat = lua_pat:gsub("%%.lua%$", "%%.tl$") - - for _, pat in ipairs({ tl_pat, d_tl_pat, lua_pat }) do - local cap = filename:match(pat) - if cap then - return (cap:gsub("[/\\]", ".")) - end - end - end - - -- fallback: - return (filename:gsub("%.lua$", ""):gsub("%.d%.tl$", ""):gsub("%.tl$", ""):gsub("[/\\]", ".")) -end - function tl.search_module(module_name: string, search_dtl: boolean): string, FILE, {string} local found: string local fd: FILE @@ -6326,9 +6305,14 @@ local function require_module(module_name: string, lax: boolean, env: Env): Type local found, fd = tl.search_module(module_name, true) if found and (lax or found:match("tl$") as boolean) then - local found_result, err: Result, string = tl.process(found, env, module_name, fd) + + env.modules[module_name] = a_typedecl(CIRCULAR_REQUIRE) + + local found_result, err: Result, string = tl.process(found, env, fd) assert(found_result, err) + env.modules[module_name] = found_result.type + return found_result.type, true elseif fd then fd:close() @@ -6588,10 +6572,6 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - if opts.module_name then - env.modules[opts.module_name] = a_typedecl(CIRCULAR_REQUIRE) - end - local lax = opts.lax local feat_arity = env.feat_arity local filename = opts.filename @@ -12334,10 +12314,6 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string env.loaded[filename] = result table.insert(env.loaded_order, filename) - if opts.module_name then - env.modules[opts.module_name] = result.type - end - if tc then env.reporter:store_result(tc, env.globals) end @@ -12391,7 +12367,9 @@ local function read_full_file(fd: FILE): string, string return content, err end -tl.process = function(filename: string, env: Env, module_name: string, fd: FILE): Result, string +tl.process = function(filename: string, env: Env, fd?: FILE): Result, string + assert((not fd or type(fd) ~= "string"), "fd must be a file") + if env and env.loaded and env.loaded[filename] then return env.loaded[filename] end @@ -12423,14 +12401,10 @@ tl.process = function(filename: string, env: Env, module_name: string, fd: FILE) is_lua = input:match("^#![^\n]*lua[^\n]*\n") as boolean end - return tl.process_string(input, is_lua, env, filename, module_name) + return tl.process_string(input, is_lua, env, filename) end -function tl.process_string(input: string, is_lua: boolean, env: Env, filename?: string, module_name?: string): Result - if filename and not module_name then - module_name = filename_to_module_name(filename) - end - +function tl.process_string(input: string, is_lua: boolean, env: Env, filename?: string): Result env = env or tl.init_env(is_lua) if env.loaded and env.loaded[filename] then return env.loaded[filename] @@ -12443,7 +12417,6 @@ function tl.process_string(input: string, is_lua: boolean, env: Env, filename?: local result = { ok = false, filename = filename, - module_name = module_name, type = BOOLEAN, type_errors = {}, syntax_errors = syntax_errors, @@ -12456,7 +12429,6 @@ function tl.process_string(input: string, is_lua: boolean, env: Env, filename?: local opts: TypeCheckOptions = { filename = filename, - module_name = module_name, lax = is_lua, gen_compat = env.gen_compat, gen_target = env.gen_target, @@ -12502,14 +12474,17 @@ local function tl_package_loader(module_name: string): any, any env = tl.package_loader_env end - tl.type_check(program, { + env.modules[module_name] = a_typedecl(CIRCULAR_REQUIRE) + + local result = tl.type_check(program, { lax = lax, filename = found_filename, - module_name = module_name, env = env, run_internal_compiler_checks = false, }) + env.modules[module_name] = result.type + -- TODO: should this be a hard error? this seems analogous to -- finding a lua file with a syntax error in it local code = assert(tl.pretty_print_ast(program, env.gen_target, true)) From 52aa5f66fec65c5980ad1c4931831a590138244b Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Tue, 16 Jan 2024 13:56:36 -0300 Subject: [PATCH 114/224] all prefixed errors have known locations --- tl.lua | 5 +++-- tl.tl | 5 +++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/tl.lua b/tl.lua index ef6c34126..1fcf92de9 100644 --- a/tl.lua +++ b/tl.lua @@ -7179,14 +7179,15 @@ tl.type_check = function(ast, opts) end local function add_errs_prefixing(where, src, dst, prefix) + assert(where == nil or where.y ~= nil) + if not src then return end for _, err in ipairs(src) do err.msg = prefix .. err.msg - - if where and where.y and ( + if where and ( (err.filename ~= filename) or (not err.y) or (where.y > err.y or (where.y == err.y and where.x > err.x))) then diff --git a/tl.tl b/tl.tl index b8ed9743b..42f18119e 100644 --- a/tl.tl +++ b/tl.tl @@ -7179,14 +7179,15 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end local function add_errs_prefixing(where: Where, src: {Error}, dst: {Error}, prefix: string) + assert(where == nil or where.y ~= nil) + if not src then return end for _, err in ipairs(src) do err.msg = prefix .. err.msg - -- where.y may be nil because not all types have .y set - if where and where.y and ( + if where and ( (err.filename ~= filename) or (not err.y) or (where.y > err.y or (where.y == err.y and where.x > err.x)) From 8793fe5e7c1aad2cf24b5cd5c5f0fd81012958d8 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Tue, 16 Jan 2024 14:24:11 -0300 Subject: [PATCH 115/224] attribute now can only apply to literal tables --- docs/tutorial.md | 3 +- spec/declaration/local_spec.lua | 8 +++-- tl.lua | 52 +++++++++++++----------------- tl.tl | 56 ++++++++++++++------------------- 4 files changed, 53 insertions(+), 66 deletions(-) diff --git a/docs/tutorial.md b/docs/tutorial.md index 68863444b..1451d5e6f 100644 --- a/docs/tutorial.md +++ b/docs/tutorial.md @@ -867,7 +867,8 @@ end The `` annotation is specific to Teal. It declares a const variable assigned to a table value in which all possible keys need to be explicitly -declared. +declared. Note that you can only use `` when assigning to a literal +table value, that is, when you are spelling out a table using a `{}` block. Of course, not all types allow you to enumerate all possible keys: there is an infinite number (well, not infinite because we're talking about computers, but diff --git a/spec/declaration/local_spec.lua b/spec/declaration/local_spec.lua index d2936d67d..ea3819081 100644 --- a/spec/declaration/local_spec.lua +++ b/spec/declaration/local_spec.lua @@ -315,7 +315,7 @@ describe("local", function() } ]])) - it("accepts direct declaration from total to total", util.check([[ + it("does not accept direct declaration from total to total", util.check_type_error([[ local record Point x: number end @@ -325,7 +325,9 @@ describe("local", function() } local p2 : Point = p - ]])) + ]], { + { y = 9, msg = "attribute only applies to literal tables" }, + })) it("rejects direct declaration from non-total to total", util.check_type_error([[ local record Point @@ -339,7 +341,7 @@ describe("local", function() local p2 : Point = p ]], { - { msg = "record variable declared does not declare values for all fields" }, + { y = 10, msg = "attribute only applies to literal tables" }, })) it("cannot reassign a total", util.check_type_error([[ diff --git a/tl.lua b/tl.lua index 1fcf92de9..45d61d72e 100644 --- a/tl.lua +++ b/tl.lua @@ -1846,11 +1846,6 @@ local table_types = { - - - - - @@ -2032,6 +2027,8 @@ local Node = {ExpectedContext = {}, } + + @@ -10409,24 +10406,26 @@ expand_type(node, values, elements) }) elseif not infertype then error_at(var, "variable declared does not declare an initialization value") ok = false - elseif not (node.exps[i] and node.exps[i].attribute == "total") then - local ri = to_structural(infertype) - if not (ri.typename == "map" or ri.typename == "record") then - error_at(var, "attribute only applies to maps and records") + else + local valnode = node.exps[i] + if not valnode or valnode.kind ~= "literal_table" then + error_at(var, "attribute only applies to literal tables") ok = false - elseif not ri.is_total then - local missing = "" - if ri.missing then - missing = " (missing: " .. table.concat(ri.missing, ", ") .. ")" - end - if ri.typename == "map" then - error_at(var, "map variable declared does not declare values for all possible keys" .. missing) - ok = false - elseif ri.typename == "record" then - error_at(var, "record variable declared does not declare values for all fields" .. missing) - ok = false + else + if not valnode.is_total then + local missing = "" + if valnode.missing then + missing = " (missing: " .. table.concat(valnode.missing, ", ") .. ")" + end + local ri = to_structural(infertype) + if ri.typename == "map" then + error_at(var, "map variable declared does not declare values for all possible keys" .. missing) + ok = false + elseif ri.typename == "record" then + error_at(var, "record variable declared does not declare values for all fields" .. missing) + ok = false + end end - ri.is_total = nil end end end @@ -11123,24 +11122,17 @@ expand_type(node, values, elements) }) t = infer_at(node, a_type("array", { elements = force_array })) else t = resolve_typevars_at(node, node.expected) - if node.expected == t and t.typename == "nominal" then - t = a_type("nominal", { - names = t.names, - found = t.found, - resolved = t.resolved, - }) - end end if decltype.typename == "record" then local rt = to_structural(t) if rt.typename == "record" then - rt.is_total, rt.missing = total_record_check(decltype, seen_keys) + node.is_total, node.missing = total_record_check(decltype, seen_keys) end elseif decltype.typename == "map" then local rt = to_structural(t) if rt.typename == "map" then - rt.is_total, rt.missing = total_map_check(decltype, seen_keys) + node.is_total, node.missing = total_map_check(decltype, seen_keys) end end diff --git a/tl.tl b/tl.tl index 42f18119e..54a140d56 100644 --- a/tl.tl +++ b/tl.tl @@ -1623,11 +1623,6 @@ local interface HasDeclName declname: string end -local interface HasIsTotal - is_total: boolean - missing: {string} -end - local record NominalType is Type where self.typename == "nominal" @@ -1666,7 +1661,7 @@ local record ArrayType end local record RecordType - is RecordLikeType, HasIsTotal + is RecordLikeType where self.typename == "record" end @@ -1720,7 +1715,7 @@ local record TypeVarType end local record MapType - is Type, HasIsTotal + is Type where self.typename == "map" keys: Type @@ -2013,6 +2008,8 @@ local record Node -- table literal array_len: integer + is_total: boolean + missing: {string} -- goto label: string @@ -10409,24 +10406,26 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string elseif not infertype then error_at(var, "variable declared does not declare an initialization value") ok = false - elseif not (node.exps[i] and node.exps[i].attribute == "total") then - local ri = to_structural(infertype) - if not (ri is MapType or ri is RecordType) then - error_at(var, "attribute only applies to maps and records") + else + local valnode = node.exps[i] + if not valnode or valnode.kind ~= "literal_table" then + error_at(var, "attribute only applies to literal tables") ok = false - elseif not ri.is_total then - local missing = "" - if ri.missing then - missing = " (missing: " .. table.concat(ri.missing, ", ") .. ")" - end - if ri is MapType then - error_at(var, "map variable declared does not declare values for all possible keys" .. missing) - ok = false - elseif ri is RecordType then - error_at(var, "record variable declared does not declare values for all fields" .. missing) - ok = false + else + if not valnode.is_total then + local missing = "" + if valnode.missing then + missing = " (missing: " .. table.concat(valnode.missing, ", ") .. ")" + end + local ri = to_structural(infertype) + if ri is MapType then + error_at(var, "map variable declared does not declare values for all possible keys" .. missing) + ok = false + elseif ri is RecordType then + error_at(var, "record variable declared does not declare values for all fields" .. missing) + ok = false + end end - ri.is_total = nil end end end @@ -11123,24 +11122,17 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string t = infer_at(node, an_array(force_array)) else t = resolve_typevars_at(node, node.expected) - if node.expected == t and t is NominalType then - t = a_type("nominal", { - names = t.names, - found = t.found, - resolved = t.resolved, - } as NominalType) - end end if decltype is RecordType then local rt = to_structural(t) if rt is RecordType then - rt.is_total, rt.missing = total_record_check(decltype, seen_keys) + node.is_total, node.missing = total_record_check(decltype, seen_keys) end elseif decltype is MapType then local rt = to_structural(t) if rt is MapType then - rt.is_total, rt.missing = total_map_check(decltype, seen_keys) + node.is_total, node.missing = total_map_check(decltype, seen_keys) end end From c6b03025f74b9d5fec862cfa6a5124376ebf679e Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Thu, 18 Jan 2024 02:22:50 -0300 Subject: [PATCH 116/224] interfaces: actually check interfaces! --- tl.lua | 43 +++++++++++++++++++++++++++++++------------ tl.tl | 53 ++++++++++++++++++++++++++++++++++++----------------- 2 files changed, 67 insertions(+), 29 deletions(-) diff --git a/tl.lua b/tl.lua index 45d61d72e..241c4ebd0 100644 --- a/tl.lua +++ b/tl.lua @@ -1848,6 +1848,9 @@ local table_types = { + + + @@ -6922,7 +6925,7 @@ tl.type_check = function(ast, opts) seen[orig_t] = rt return rt, false end - same = false + all_same = false t = rt end end @@ -7058,7 +7061,7 @@ tl.type_check = function(ast, opts) end end - copy.typeid = same and orig_t.typeid or new_typeid() + copy.typeid = same and t.typeid or new_typeid() return copy, same and all_same end @@ -7584,10 +7587,16 @@ tl.type_check = function(ast, opts) end local function resolve_decl_into_nominal(t, found) - local resolved = match_typevals(t, found.def) - if not resolved then - error_at(t, table.concat(t.names, ".") .. " cannot be resolved in scope") - return INVALID + local def = found.def + local resolved + if def.typename == "record" or def.typename == "function" then + resolved = match_typevals(t, def) + if not resolved then + error_at(t, table.concat(t.names, ".") .. " cannot be resolved in scope") + return INVALID + end + else + resolved = def end if not t.filename then @@ -8054,6 +8063,11 @@ tl.type_check = function(ast, opts) ["record"] = { ["record"] = eqtype_record, }, + ["interface"] = { + ["interface"] = function(a, b) + return a.typeid == b.typeid + end, + }, ["function"] = { ["function"] = function(a, b) local argdelta = a.is_method and 1 or 0 @@ -8326,7 +8340,7 @@ a.types[i], b.types[i]), } ["record"] = function(a, b) local def = a.def if def.fields then - return subtype_record(a.def, b) + return subtype_record(def, b) end end, }, @@ -9488,9 +9502,9 @@ a.types[i], b.types[i]), } old.meta_fields = nil edit_type(old, "map") - assert(old.typename == "map") - old.keys = STRING - old.values = values + local map = old + map.keys = STRING + map.values = values elseif old.typename == "union" then edit_type(old, "union") table.insert(old.types, drop_constant_value(new)) @@ -9590,7 +9604,12 @@ a.types[i], b.types[i]), } local def = found.def if def.fields and def.fields[exp.e2.tk] then table.insert(t.names, exp.e2.tk) - t.found = def.fields[exp.e2.tk] + local ft = def.fields[exp.e2.tk] + if type(ft) == "table" then + t.found = ft + else + return nil + end end end end @@ -11711,7 +11730,7 @@ expand_type(node, values, elements) }) local t = types_op[ra.typename] - if not t then + if not t and ra.fields then t = find_in_interface_list(ra, function(ty) return types_op[ty.typename] end) diff --git a/tl.tl b/tl.tl index 54a140d56..ecc5af5d4 100644 --- a/tl.tl +++ b/tl.tl @@ -1577,6 +1577,8 @@ local record StringType literal: string end +local type TypeType = TypeAliasType | TypeDeclType + local record TypeDeclType is Type where self.typename == "typedecl" @@ -1614,6 +1616,7 @@ local record UnresolvedType end local interface HasTypeArgs + is Type where self.typeargs typeargs: {TypeArgType} @@ -1629,7 +1632,7 @@ local record NominalType names: {string} typevals: {Type} - found: Type -- type is found but typeargs are not resolved + found: TypeType -- type is found but typeargs are not resolved resolved: Type -- type is found and typeargs are resolved end @@ -1993,7 +1996,7 @@ local record Node exps: Node -- newtype - newtype: TypeAliasType | TypeDeclType + newtype: TypeType elide_type: boolean -- expressions @@ -6922,7 +6925,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string seen[orig_t] = rt return rt, false end - same = false + all_same = false t = rt end end @@ -7058,7 +7061,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - copy.typeid = same and orig_t.typeid or new_typeid() + copy.typeid = same and t.typeid or new_typeid() return copy, same and all_same end @@ -7584,10 +7587,16 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end local function resolve_decl_into_nominal(t: NominalType, found: TypeDeclType): Type - local resolved = match_typevals(t, found.def) - if not resolved then - error_at(t, table.concat(t.names, ".") .. " cannot be resolved in scope") - return INVALID + local def = found.def + local resolved: Type + if def is RecordType or def is FunctionType then + resolved = match_typevals(t, def) + if not resolved then + error_at(t, table.concat(t.names, ".") .. " cannot be resolved in scope") + return INVALID + end + else + resolved = def end if not t.filename then @@ -8054,6 +8063,11 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string ["record"] = { ["record"] = eqtype_record, }, + ["interface"] = { + ["interface"] = function(a: InterfaceType, b: InterfaceType): boolean, {Error} + return a.typeid == b.typeid + end, + }, ["function"] = { ["function"] = function(a: FunctionType, b: FunctionType): boolean, {Error} local argdelta = a.is_method and 1 or 0 @@ -8148,7 +8162,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end, }, ["poly"] = { - ["*"] = function(a: Type, b: Type): boolean, {Error} -- ∃ t ∈ a, t <: b + ["*"] = function(a: PolyType, b: Type): boolean, {Error} -- ∃ t ∈ a, t <: b if exists_supertype_in(b, a) then -- ─────────────── return true -- a poly <: b end @@ -8199,7 +8213,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string ["number"] = compare_true, }, ["interface"] = { - ["interface"] = function(a: Type, b: Type): boolean, {Error} + ["interface"] = function(a: InterfaceType, b: InterfaceType): boolean, {Error} if find_in_interface_list(a, function(t: Type): boolean return (is_a(t, b)) end) then return true end @@ -8326,7 +8340,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string ["record"] = function(a: TypeDeclType, b: RecordType): boolean, {Error} local def = a.def if def is RecordLikeType then - return subtype_record(a.def, b) -- record as prototype + return subtype_record(def, b) -- record as prototype end end, }, @@ -9488,9 +9502,9 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string old.meta_fields = nil edit_type(old, "map") - assert(old is MapType) - old.keys = STRING - old.values = values + local map = old as MapType + map.keys = STRING + map.values = values elseif old is UnionType then edit_type(old, "union") table.insert(old.types, drop_constant_value(new)) @@ -9590,7 +9604,12 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local def = found.def if def is RecordLikeType and def.fields[exp.e2.tk] then table.insert(t.names, exp.e2.tk) - t.found = def.fields[exp.e2.tk] + local ft = def.fields[exp.e2.tk] + if ft is TypeType then + t.found = ft + else + return nil + end end end end @@ -11711,7 +11730,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local t = types_op[ra.typename] - if not t then + if not t and ra is RecordLikeType then t = find_in_interface_list(ra, function(ty: Type): Type return types_op[ty.typename] end) @@ -12034,7 +12053,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end, }, ["record"] = { - before = function(typ: Type) + before = function(typ: RecordType) begin_scope() add_var(nil, "@self", type_at(typ, a_typedecl(typ))) From 0d3c6823b9eb2b111330b77c28cbd50f582a94cd Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Tue, 23 Jan 2024 21:41:23 -0300 Subject: [PATCH 117/224] big code reorganization: TypeChecker record Several big changes, that were done in tandem, and which would be too troublesome to break into separate commits. The goal here is to ultimately be able to break tl.tl into multiple files (because its size started hitting limits in both Lua 5.1 (number of upvalues) and Lua 5.4 (number of locals). Here's a high-level summary of the changes: * new Errors record, encapsulating error-reporting concerns; * all Type occurrences have unique objects reporting their locations (no more singletons for base types such as BOOLEAN and INVALID); * some enums renamed for more consistency across Gen and Feat options; * TypeCheckOptions and EnvOptions tables reorganized for easier forwarding of options across them; * simplifications in the various function signatures of the public API; * all Types and Nodes store filename, line and column location (`f`, `y`, `x`); * Scope is now a record containing the variables map and unresolved items -- no more "@unresolved" pseudo-variable and `unresolved` pseudo-type for storing this data in the symbols table; * `type_check` now uses a TypeChecker object for storing all state, instead of relying on closures and function nesting (that's a bit sad is it ended up spreading `self:` and extra function arguments everywhere, but I guess state management will be more explicit for others reading the code now...); * all Fact objects have a Where location as well, and supressions of inference data in error messages for widened-back types is marked explicitly with `no_infer` instead of missing a `w` field; * general simplification of the sourcing of error locations (though I would still like to improve that further); --- spec/api/gen_spec.lua | 4 +- spec/api/get_types_spec.lua | 4 +- spec/api/pretty_print_ast.lua | 2 +- spec/call/generic_function_spec.lua | 2 +- spec/cli/types_spec.lua | 20 +- spec/declaration/record_method_spec.lua | 4 +- spec/parser/parser_error_spec.lua | 4 +- spec/parser/parser_spec.lua | 1 + spec/stdlib/require_spec.lua | 2 +- spec/stdlib/xpcall_spec.lua | 2 +- spec/util.lua | 11 +- tl | 12 +- tl.lua | 7635 ++++++++++++----------- tl.tl | 4969 ++++++++------- 14 files changed, 6488 insertions(+), 6184 deletions(-) diff --git a/spec/api/gen_spec.lua b/spec/api/gen_spec.lua index a53538d46..baee93bcc 100644 --- a/spec/api/gen_spec.lua +++ b/spec/api/gen_spec.lua @@ -69,7 +69,7 @@ describe("tl.gen", function() print(math.floor(2)) ]] - local env = tl.init_env(true, true) + local env = tl.init_env(true, false) local output, result = tl.gen(input, env) assert.equal('print(math.floor(2))', output) @@ -83,7 +83,7 @@ describe("tl.gen", function() print(math.floor(2))]] - local env = tl.init_env(true, true) + local env = tl.init_env(true, false) local output, result = tl.gen(input, env) assert.equal(input, output) diff --git a/spec/api/get_types_spec.lua b/spec/api/get_types_spec.lua index 26bbf1d05..b6f55ec83 100644 --- a/spec/api/get_types_spec.lua +++ b/spec/api/get_types_spec.lua @@ -8,7 +8,7 @@ describe("tl.get_types", function() local function a() ::continue:: end - ]], false, env)) + ]], env)) local tr, trenv = tl.get_types(result) assert(tr) @@ -25,7 +25,7 @@ describe("tl.get_types", function() end R.f("hello") - ]], false, env)) + ]], env)) local tr, trenv = tl.get_types(result) local y = 6 diff --git a/spec/api/pretty_print_ast.lua b/spec/api/pretty_print_ast.lua index d87d1ea86..d1d149786 100644 --- a/spec/api/pretty_print_ast.lua +++ b/spec/api/pretty_print_ast.lua @@ -4,7 +4,7 @@ local util = require("spec.util") describe("tl.pretty_print_ast", function() it("returns error for attribute on non 5.4 target", function() local input = [[local x = io.open("foobar", "r")]] - local result = tl.process_string(input, false, tl.init_env(false, "off", "5.4"), "foo.tl") + local result = tl.process_string(input, tl.init_env(false, "off", "5.4"), "foo.tl") local output, err = tl.pretty_print_ast(result.ast, "5.3") assert.is_nil(output) diff --git a/spec/call/generic_function_spec.lua b/spec/call/generic_function_spec.lua index 2fb8cf4d6..ec68bb3ff 100644 --- a/spec/call/generic_function_spec.lua +++ b/spec/call/generic_function_spec.lua @@ -370,7 +370,7 @@ describe("generic function", function() recurse_node(ast, visit_node, visit_type) end ]], { - { x = 40, msg = "argument 3: in map value: type parameter : got number, expected string" } + { y = 16, x = 40, msg = "argument 3: in map value: type parameter : got number, expected string" } })) it("inference trickles down to function arguments, pass", util.check([[ diff --git a/spec/cli/types_spec.lua b/spec/cli/types_spec.lua index 2f1180e71..a1334e71f 100644 --- a/spec/cli/types_spec.lua +++ b/spec/cli/types_spec.lua @@ -300,7 +300,6 @@ describe("tl types works like check", function() local by_pos = types.by_pos[next(types.by_pos)] assert(by_pos["1"]) assert(by_pos["1"]["13"]) -- require - assert(by_pos["1"]["20"]) -- ( assert(by_pos["1"]["21"]) -- "os" assert(by_pos["1"]["26"]) -- . end) @@ -318,18 +317,17 @@ describe("tl types works like check", function() assert(types.by_pos) local by_pos = types.by_pos[next(types.by_pos)] assert.same({ - ["19"] = 2, - ["20"] = 5, - ["22"] = 2, - ["39"] = 6, - ["41"] = 2, + ["19"] = 8, + ["22"] = 8, + ["23"] = 6, + ["30"] = 2, + ["41"] = 8, }, by_pos["1"]) assert.same({ - ["17"] = 3, - ["20"] = 4, - ["25"] = 17, - ["30"] = 16, - ["31"] = 2, + ["17"] = 6, + ["20"] = 2, + ["25"] = 9, + ["31"] = 8, }, by_pos["2"]) end) end) diff --git a/spec/declaration/record_method_spec.lua b/spec/declaration/record_method_spec.lua index 20cbde3dc..7f8cf2db6 100644 --- a/spec/declaration/record_method_spec.lua +++ b/spec/declaration/record_method_spec.lua @@ -239,8 +239,8 @@ describe("record method", function() return "hello" end ]], { - { msg = "in assignment: incompatible number of returns: got 0 (), expected 1 (string)" }, - { msg = "excess return values, expected 0 (), got 1 (string \"hello\")" }, + { y = 5, msg = "in assignment: incompatible number of returns: got 0 (), expected 1 (string)" }, + { y = 6, msg = "excess return values, expected 0 (), got 1 (string \"hello\")" }, })) it("allows functions declared on method tables (#27)", function() diff --git a/spec/parser/parser_error_spec.lua b/spec/parser/parser_error_spec.lua index ed50e80c9..cfd2e077c 100644 --- a/spec/parser/parser_error_spec.lua +++ b/spec/parser/parser_error_spec.lua @@ -2,7 +2,7 @@ local tl = require("tl") describe("parser errors", function() it("parse errors include filename", function () - local result = tl.process_string("local x 1", false, nil, "foo.tl") + local result = tl.process_string("local x 1", nil, "foo.tl") assert.same("foo.tl", result.syntax_errors[1].filename, "parse errors should contain .filename property") end) @@ -30,7 +30,7 @@ describe("parser errors", function() local code = [[ local bar = require "bar" ]] - local result = tl.process_string(code, true, nil, "foo.tl") + local result = tl.process_string(code, nil, "foo.tl") assert.is_not_nil(string.match(result.env.loaded["./bar.tl"].syntax_errors[1].filename, "bar.tl$"), "errors should contain .filename property") end) end) diff --git a/spec/parser/parser_spec.lua b/spec/parser/parser_spec.lua index d1e66fb38..870260f90 100644 --- a/spec/parser/parser_spec.lua +++ b/spec/parser/parser_spec.lua @@ -19,6 +19,7 @@ describe("parser", function() assert.same({ kind = "statements", tk = "$EOF$", + f = "", x = 1, y = 1, xend = 5, diff --git a/spec/stdlib/require_spec.lua b/spec/stdlib/require_spec.lua index b43c9b937..ec1fb26ba 100644 --- a/spec/stdlib/require_spec.lua +++ b/spec/stdlib/require_spec.lua @@ -401,7 +401,7 @@ describe("require", function() local result, err = tl.process("foo.tl") assert.same(0, #result.syntax_errors) - assert.same(0, #result.env.loaded["foo.tl"].type_errors) + assert.same({}, result.env.loaded["foo.tl"].type_errors) assert.same(1, #result.env.loaded["./box.tl"].type_errors) assert.match("cannot use operator ..", result.env.loaded["./box.tl"].type_errors[1].msg) end) diff --git a/spec/stdlib/xpcall_spec.lua b/spec/stdlib/xpcall_spec.lua index 87089f162..16e911a7f 100644 --- a/spec/stdlib/xpcall_spec.lua +++ b/spec/stdlib/xpcall_spec.lua @@ -105,7 +105,7 @@ describe("xpcall", function() { msg = "xyz: got boolean, expected number" } })) - it("type checks the message handler", util.check_type_error([[ + it("#only type checks the message handler", util.check_type_error([[ local function f(a: string, b: number) end diff --git a/spec/util.lua b/spec/util.lua index 1976f4c66..25cbdae50 100644 --- a/spec/util.lua +++ b/spec/util.lua @@ -435,7 +435,7 @@ local function check(lax, code, unknowns, gen_target) if gen_target == "5.4" then gen_compat = "off" end - local result = tl.type_check(ast, { filename = "foo.lua", lax = lax, gen_target = gen_target, gen_compat = gen_compat }) + local result = tl.type_check(ast, "foo.lua", { feat_lax = lax and "on" or "off", gen_target = gen_target, gen_compat = gen_compat }) batch:add(assert.same, {}, result.type_errors) if unknowns then @@ -456,7 +456,7 @@ local function check_type_error(lax, code, type_errors, gen_target) if gen_target == "5.4" then gen_compat = "off" end - local result = tl.type_check(ast, { filename = "foo.tl", lax = lax, gen_target = gen_target, gen_compat = gen_compat }) + local result = tl.type_check(ast, "foo.tl", { feat_lax = lax and "on" or "off", gen_target = gen_target, gen_compat = gen_compat }) local result_type_errors = combine_result(result, "type_errors") batch_compare(batch, "type errors", type_errors, result_type_errors) @@ -525,7 +525,7 @@ function util.check_syntax_error(code, syntax_errors) local batch = batch_assertions() batch_compare(batch, "syntax errors", syntax_errors, errors) batch:assert() - tl.type_check(ast, { filename = "foo.tl", lax = false }) + tl.type_check(ast, "foo.tl", { feat_lax = "off" }) end end @@ -564,7 +564,7 @@ function util.check_types(code, types) local batch = batch_assertions() local env = tl.init_env() env.report_types = true - local result = tl.type_check(ast, { filename = "foo.tl", env = env, lax = false }) + local result = tl.type_check(ast, "foo.tl", { feat_lax = "off" }, env) batch:add(assert.same, {}, result.type_errors, "Code was not expected to have type errors") local tr = env.reporter:get_report() @@ -596,7 +596,8 @@ local function gen(lax, code, expected, gen_target) return function() local ast, syntax_errors = tl.parse(code, "foo.tl") assert.same({}, syntax_errors, "Code was not expected to have syntax errors") - local result = assert(tl.type_check(ast, { filename = "foo.tl", lax = lax, gen_target = gen_target, gen_compat = gen_target == "5.4" and "off" or nil })) + local gen_compat = gen_target == "5.4" and "off" or nil + local result = tl.type_check(ast, "foo.tl", { feat_lax = lax and "on" or "off", gen_target = gen_target, gen_compat = gen_compat }) assert.same({}, result.type_errors) local output_code = tl.pretty_print_ast(ast, gen_target) diff --git a/tl b/tl index 70248553a..fc1fa315e 100755 --- a/tl +++ b/tl @@ -163,10 +163,12 @@ local function setup_env(tlconfig, filename) end local opts = { - lax_mode = lax_mode, - feat_arity = tlconfig["feat_arity"], - gen_compat = tlconfig["gen_compat"], - gen_target = tlconfig["gen_target"], + defaults = { + feat_lax = lax_mode and "on" or "off", + feat_arity = tlconfig["feat_arity"], + gen_compat = tlconfig["gen_compat"], + gen_target = tlconfig["gen_target"], + }, predefined_modules = tlconfig._init_env_modules, } @@ -916,7 +918,7 @@ do local y, x = pos:match("^(%d+):?(%d*)") y = tonumber(y) or 1 x = tonumber(x) or 1 - json_out_table(io.stdout, tl.symbols_in_scope(tr, y, x)) + json_out_table(io.stdout, tl.symbols_in_scope(tr, y, x, filename)) else tr.symbols = tr.symbols_by_file[filename] json_out_table(io.stdout, tr) diff --git a/tl.lua b/tl.lua index 241c4ebd0..dba438acf 100644 --- a/tl.lua +++ b/tl.lua @@ -1,4 +1,4 @@ -local _tl_compat; if (tonumber((_VERSION or ''):match('[%d.]*$')) or 0) < 5.3 then local p, m = pcall(require, 'compat53.module'); if p then _tl_compat = m end end; local assert = _tl_compat and _tl_compat.assert or assert; local debug = _tl_compat and _tl_compat.debug or debug; local io = _tl_compat and _tl_compat.io or io; local ipairs = _tl_compat and _tl_compat.ipairs or ipairs; local load = _tl_compat and _tl_compat.load or load; local math = _tl_compat and _tl_compat.math or math; local _tl_math_maxinteger = math.maxinteger or math.pow(2, 53); local os = _tl_compat and _tl_compat.os or os; local package = _tl_compat and _tl_compat.package or package; local pairs = _tl_compat and _tl_compat.pairs or pairs; local string = _tl_compat and _tl_compat.string or string; local table = _tl_compat and _tl_compat.table or table; local _tl_table_unpack = unpack or table.unpack; local utf8 = _tl_compat and _tl_compat.utf8 or utf8 +local _tl_compat; if (tonumber((_VERSION or ''):match('[%d.]*$')) or 0) < 5.3 then local p, m = pcall(require, 'compat53.module'); if p then _tl_compat = m end end; local assert = _tl_compat and _tl_compat.assert or assert; local debug = _tl_compat and _tl_compat.debug or debug; local io = _tl_compat and _tl_compat.io or io; local ipairs = _tl_compat and _tl_compat.ipairs or ipairs; local load = _tl_compat and _tl_compat.load or load; local math = _tl_compat and _tl_compat.math or math; local _tl_math_maxinteger = math.maxinteger or math.pow(2, 53); local os = _tl_compat and _tl_compat.os or os; local package = _tl_compat and _tl_compat.package or package; local pairs = _tl_compat and _tl_compat.pairs or pairs; local string = _tl_compat and _tl_compat.string or string; local table = _tl_compat and _tl_compat.table or table; local utf8 = _tl_compat and _tl_compat.utf8 or utf8 local VERSION = "0.15.3+dev" local stdlib = [=====[ @@ -481,10 +481,16 @@ end -local tl = {PrettyPrintOptions = {}, TypeCheckOptions = {}, Env = {}, Result = {}, Error = {}, TypeInfo = {}, TypeReport = {}, EnvOptions = {}, } +local Errors = {} + + + +local tl = {PrettyPrintOptions = {}, TypeCheckOptions = {}, Env = {}, Result = {}, Error = {}, TypeInfo = {}, TypeReport = {}, EnvOptions = {}, } + + @@ -632,6 +638,7 @@ local TypeReporter = {} + tl.version = function() return VERSION end @@ -702,6 +709,12 @@ tl.typecodes = { +local DEFAULT_GEN_COMPAT = "optional" +local DEFAULT_GEN_TARGET = "5.3" + + + + @@ -1520,7 +1533,6 @@ end - local table_types = { @@ -1555,7 +1567,6 @@ local table_types = { ["any"] = false, ["unknown"] = false, ["invalid"] = false, - ["unresolved"] = false, ["none"] = false, ["*"] = false, } @@ -1580,6 +1591,9 @@ local table_types = { +local function is_numeric_type(t) + return t.typename == "number" or t.typename == "integer" +end @@ -1855,14 +1869,12 @@ local table_types = { -local TruthyFact = {} -local NotFact = {} @@ -1871,7 +1883,6 @@ local NotFact = {} -local AndFact = {} @@ -1881,33 +1892,34 @@ local AndFact = {} -local OrFact = {} +local TruthyFact = {} +local NotFact = {} -local EqFact = {} +local AndFact = {} -local IsFact = {} +local OrFact = {} @@ -1917,22 +1929,17 @@ local IsFact = {} +local EqFact = {} -local attributes = { - ["const"] = true, - ["close"] = true, - ["total"] = true, -} -local is_attribute = attributes -local Node = {ExpectedContext = {}, } +local IsFact = {} @@ -1954,6 +1961,15 @@ local Node = {ExpectedContext = {}, } +local attributes = { + ["const"] = true, + ["close"] = true, + ["total"] = true, +} +local is_attribute = attributes + +local Node = {ExpectedContext = {}, } + @@ -2035,9 +2051,6 @@ local Node = {ExpectedContext = {}, } -local function is_number_type(t) - return t.typename == "number" or t.typename == "integer" -end @@ -2054,95 +2067,34 @@ end -local parse_type_list -local parse_expression -local parse_expression_and_tk -local parse_statements -local parse_argument_list -local parse_argument_type_list -local parse_type -local parse_newtype -local parse_interface_name -local parse_enum_body -local parse_record_body -local parse_type_body_fns -local function fail(ps, i, msg) - if not ps.tokens[i] then - local eof = ps.tokens[#ps.tokens] - table.insert(ps.errs, { filename = ps.filename, y = eof.y, x = eof.x, msg = msg or "unexpected end of file" }) - return #ps.tokens - end - table.insert(ps.errs, { filename = ps.filename, y = ps.tokens[i].y, x = ps.tokens[i].x, msg = assert(msg, "syntax error, but no error message provided") }) - return math.min(#ps.tokens, i + 1) -end -local function end_at(node, tk) - node.yend = tk.y - node.xend = tk.x + #tk.tk - 1 -end -local function verify_tk(ps, i, tk) - if ps.tokens[i].tk == tk then - return i + 1 - end - return fail(ps, i, "syntax error, expected '" .. tk .. "'") -end -local function verify_end(ps, i, istart, node) - if ps.tokens[i].tk == "end" then - local endy, endx = ps.tokens[i].y, ps.tokens[i].x - node.yend = endy - node.xend = endx + 2 - if node.kind ~= "function" and endy ~= node.y and endx ~= node.x then - if not ps.end_alignment_hint then - ps.end_alignment_hint = { filename = ps.filename, y = node.y, x = node.x, msg = "syntax error hint: construct starting here is not aligned with its 'end' at " .. ps.filename .. ":" .. endy .. ":" .. endx .. ":" } - end - end - return i + 1 - end - end_at(node, ps.tokens[i]) - if ps.end_alignment_hint then - table.insert(ps.errs, ps.end_alignment_hint) - ps.end_alignment_hint = nil - end - return fail(ps, i, "syntax error, expected 'end' to close construct started at " .. ps.filename .. ":" .. ps.tokens[istart].y .. ":" .. ps.tokens[istart].x .. ":") -end -local function new_node(tokens, i, kind) - local t = tokens[i] - return { y = t.y, x = t.x, tk = t.tk, kind = kind or (t.kind) } -end -local function a_type(typename, t) + +local function a_type(w, typename, t) t.typeid = new_typeid() + t.f = w.f + t.x = w.x + t.y = w.y t.typename = typename return t end -local function edit_type(t, typename) +local function edit_type(w, t, typename) t.typeid = new_typeid() + t.f = w.f + t.x = w.x + t.y = w.y t.typename = typename return t end -local function new_type(ps, i, typename) - local token = ps.tokens[i] - return a_type(typename, { - filename = ps.filename, - y = token.y, - x = token.x, - - }) -end -local function new_typedecl(ps, i, def) - local t = new_type(ps, i, "typedecl") - t.def = def - return t -end @@ -2154,20 +2106,28 @@ end +local function a_function(w, t) + assert(t.min_arity) + return a_type(w, "function", t) +end +local function a_vararg(w, t) + local typ = a_type(w, "tuple", { tuple = t }) + typ.is_va = true + return typ +end -local function a_function(t) - assert(t.min_arity) - return a_type("function", t) -end +local function a_nominal(n, names) + return a_type(n, "nominal", { names = names }) +end @@ -2177,16 +2137,63 @@ end +local an_operator +local function shallow_copy_new_type(t) + local copy = {} + for k, v in pairs(t) do + copy[k] = v + end + copy.typeid = new_typeid() + return copy +end +local function shallow_copy_table(t) + local copy = {} + for k, v in pairs(t) do + copy[k] = v + end + return copy +end -local function va_args(args) - args.is_va = true - return args +local function clear_redundant_errors(errors) + local redundant = {} + local lastx, lasty = 0, 0 + for i, err in ipairs(errors) do + err.i = i + end + table.sort(errors, function(a, b) + local af = assert(a.filename) + local bf = assert(b.filename) + return af < bf or + (af == bf and (a.y < b.y or + (a.y == b.y and (a.x < b.x or + (a.x == b.x and (a.i < b.i)))))) + end) + for i, err in ipairs(errors) do + err.i = nil + if err.x == lastx and err.y == lasty then + table.insert(redundant, i) + end + lastx, lasty = err.x, err.y + end + for i = #redundant, 1, -1 do + table.remove(errors, redundant[i]) + end end +local simple_types = { + ["nil"] = true, + ["any"] = true, + ["number"] = true, + ["string"] = true, + ["thread"] = true, + ["boolean"] = true, + ["integer"] = true, +} +do @@ -2194,194 +2201,232 @@ end -local function a_fn(f) - local args_t = a_type("tuple", { tuple = {} }) - local tup = args_t.tuple - args_t.is_va = f.args.is_va - local min_arity = f.args.is_va and -1 or 0 - for _, a in ipairs(f.args) do - if a.opttype then - table.insert(tup, a.opttype) - else - table.insert(tup, a) - min_arity = min_arity + 1 - end - end - local rets_t = a_type("tuple", { tuple = {} }) - tup = rets_t.tuple - rets_t.is_va = f.rets.is_va - for _, a in ipairs(f.rets) do - assert(a.typename) - table.insert(tup, a) - end - return a_type("function", { - args = args_t, - rets = rets_t, - min_arity = min_arity, - needs_compat = f.needs_compat, - typeargs = f.typeargs, - }) -end -local function a_vararg(t) - local typ = a_type("tuple", { tuple = t }) - typ.is_va = true - return typ -end + local parse_type_list + local parse_expression + local parse_expression_and_tk + local parse_statements + local parse_argument_list + local parse_argument_type_list + local parse_type + local parse_newtype + local parse_interface_name + local parse_enum_body + local parse_record_body + local parse_type_body_fns -local NIL = a_type("nil", {}) -local ANY = a_type("any", {}) -local TABLE = a_type("map", { keys = ANY, values = ANY }) -local NUMBER = a_type("number", {}) -local STRING = a_type("string", {}) -local THREAD = a_type("thread", {}) -local BOOLEAN = a_type("boolean", {}) -local INTEGER = a_type("integer", {}) + local function fail(ps, i, msg) + if not ps.tokens[i] then + local eof = ps.tokens[#ps.tokens] + table.insert(ps.errs, { filename = ps.filename, y = eof.y, x = eof.x, msg = msg or "unexpected end of file" }) + return #ps.tokens + end + table.insert(ps.errs, { filename = ps.filename, y = ps.tokens[i].y, x = ps.tokens[i].x, msg = assert(msg, "syntax error, but no error message provided") }) + return math.min(#ps.tokens, i + 1) + end -local function shallow_copy_new_type(t) - local copy = {} - for k, v in pairs(t) do - copy[k] = v + local function end_at(node, tk) + node.yend = tk.y + node.xend = tk.x + #tk.tk - 1 end - copy.typeid = new_typeid() - return copy -end -local function shallow_copy_table(t) - local copy = {} - for k, v in pairs(t) do - copy[k] = v + local function verify_tk(ps, i, tk) + if ps.tokens[i].tk == tk then + return i + 1 + end + return fail(ps, i, "syntax error, expected '" .. tk .. "'") + end + + local function verify_end(ps, i, istart, node) + if ps.tokens[i].tk == "end" then + local endy, endx = ps.tokens[i].y, ps.tokens[i].x + node.yend = endy + node.xend = endx + 2 + if node.kind ~= "function" and endy ~= node.y and endx ~= node.x then + if not ps.end_alignment_hint then + ps.end_alignment_hint = { filename = ps.filename, y = node.y, x = node.x, msg = "syntax error hint: construct starting here is not aligned with its 'end' at " .. ps.filename .. ":" .. endy .. ":" .. endx .. ":" } + end + end + return i + 1 + end + end_at(node, ps.tokens[i]) + if ps.end_alignment_hint then + table.insert(ps.errs, ps.end_alignment_hint) + ps.end_alignment_hint = nil + end + return fail(ps, i, "syntax error, expected 'end' to close construct started at " .. ps.filename .. ":" .. ps.tokens[istart].y .. ":" .. ps.tokens[istart].x .. ":") end - return copy -end -local function verify_kind(ps, i, kind, node_kind) - if ps.tokens[i].kind == kind then - return i + 1, new_node(ps.tokens, i, node_kind) + local function new_node(ps, i, kind) + local t = ps.tokens[i] + return { f = ps.filename, y = t.y, x = t.x, tk = t.tk, kind = kind or (t.kind) } end - return fail(ps, i, "syntax error, expected " .. kind) -end + local function new_type(ps, i, typename) + local token = ps.tokens[i] + local t = {} + t.typeid = new_typeid() + t.f = ps.filename + t.x = token.x + t.y = token.y + t.typename = typename + return t + end + local function new_typedecl(ps, i, def) + local t = new_type(ps, i, "typedecl") + t.def = def + return t + end -local function skip(ps, i, skip_fn) - local err_ps = { - filename = ps.filename, - tokens = ps.tokens, - errs = {}, - required_modules = {}, - } - return skip_fn(err_ps, i) -end + local function new_tuple(ps, i, types, is_va) + local t = new_type(ps, i, "tuple") + t.is_va = is_va + t.tuple = types or {} + return t, t.tuple + end -local function failskip(ps, i, msg, skip_fn, starti) - local skip_i = skip(ps, starti or i, skip_fn) - fail(ps, i, msg) - return skip_i -end + local function new_typealias(ps, i, alias_to) + local t = new_type(ps, i, "typealias") + t.alias_to = alias_to + return t + end -local function skip_type_body(ps, i) - local tn = ps.tokens[i].tk - i = i + 1 - assert(parse_type_body_fns[tn], tn .. " has no parse body function") - return parse_type_body_fns[tn](ps, i, {}, { kind = "function" }) -end + local function new_nominal(ps, i, name) + local t = new_type(ps, i, "nominal") + if name then + t.names = { name } + end + return t + end -local function parse_table_value(ps, i) - local next_word = ps.tokens[i].tk - if next_word == "record" or next_word == "interface" then - local skip_i, e = skip(ps, i, skip_type_body) - if e then - fail(ps, i, next_word == "record" and - "syntax error: this syntax is no longer valid; declare nested record inside a record" or - "syntax error: cannot declare interface inside a table; use a statement") - return skip_i, new_node(ps.tokens, i, "error_node") + local function verify_kind(ps, i, kind, node_kind) + if ps.tokens[i].kind == kind then + return i + 1, new_node(ps, i, node_kind) end - elseif next_word == "enum" and ps.tokens[i + 1].kind == "string" then - i = failskip(ps, i, "syntax error: this syntax is no longer valid; declare nested enum inside a record", skip_type_body) - return i, new_node(ps.tokens, i - 1, "error_node") + return fail(ps, i, "syntax error, expected " .. kind) end - local e - i, e = parse_expression(ps, i) - if not e then - e = new_node(ps.tokens, i - 1, "error_node") + + + local function skip(ps, i, skip_fn) + local err_ps = { + filename = ps.filename, + tokens = ps.tokens, + errs = {}, + required_modules = {}, + } + return skip_fn(err_ps, i) end - return i, e -end -local function parse_table_item(ps, i, n) - local node = new_node(ps.tokens, i, "literal_table_item") - if ps.tokens[i].kind == "$EOF$" then - return fail(ps, i, "unexpected eof") + local function failskip(ps, i, msg, skip_fn, starti) + local skip_i = skip(ps, starti or i, skip_fn) + fail(ps, i, msg) + return skip_i end - if ps.tokens[i].tk == "[" then - node.key_parsed = "long" + local function skip_type_body(ps, i) + local tn = ps.tokens[i].tk i = i + 1 - i, node.key = parse_expression_and_tk(ps, i, "]") - i = verify_tk(ps, i, "=") - i, node.value = parse_table_value(ps, i) - return i, node, n - elseif ps.tokens[i].kind == "identifier" then - if ps.tokens[i + 1].tk == "=" then - node.key_parsed = "short" - i, node.key = verify_kind(ps, i, "identifier", "string") - node.key.conststr = node.key.tk - node.key.tk = '"' .. node.key.tk .. '"' + assert(parse_type_body_fns[tn], tn .. " has no parse body function") + return parse_type_body_fns[tn](ps, i, {}, { kind = "function" }) + end + + local function parse_table_value(ps, i) + local next_word = ps.tokens[i].tk + if next_word == "record" or next_word == "interface" then + local skip_i, e = skip(ps, i, skip_type_body) + if e then + fail(ps, i, next_word == "record" and + "syntax error: this syntax is no longer valid; declare nested record inside a record" or + "syntax error: cannot declare interface inside a table; use a statement") + return skip_i, new_node(ps, i, "error_node") + end + elseif next_word == "enum" and ps.tokens[i + 1].kind == "string" then + i = failskip(ps, i, "syntax error: this syntax is no longer valid; declare nested enum inside a record", skip_type_body) + return i, new_node(ps, i - 1, "error_node") + end + + local e + i, e = parse_expression(ps, i) + if not e then + e = new_node(ps, i - 1, "error_node") + end + return i, e + end + + local function parse_table_item(ps, i, n) + local node = new_node(ps, i, "literal_table_item") + if ps.tokens[i].kind == "$EOF$" then + return fail(ps, i, "unexpected eof") + end + + if ps.tokens[i].tk == "[" then + node.key_parsed = "long" + i = i + 1 + i, node.key = parse_expression_and_tk(ps, i, "]") i = verify_tk(ps, i, "=") i, node.value = parse_table_value(ps, i) return i, node, n - elseif ps.tokens[i + 1].tk == ":" then - node.key_parsed = "short" - local orig_i = i - local try_ps = { - filename = ps.filename, - tokens = ps.tokens, - errs = {}, - required_modules = ps.required_modules, - } - i, node.key = verify_kind(try_ps, i, "identifier", "string") - node.key.conststr = node.key.tk - node.key.tk = '"' .. node.key.tk .. '"' - i = verify_tk(try_ps, i, ":") - i, node.itemtype = parse_type(try_ps, i) - if node.itemtype and ps.tokens[i].tk == "=" then - i = verify_tk(try_ps, i, "=") - i, node.value = parse_table_value(try_ps, i) - if node.value then - for _, e in ipairs(try_ps.errs) do - table.insert(ps.errs, e) + elseif ps.tokens[i].kind == "identifier" then + if ps.tokens[i + 1].tk == "=" then + node.key_parsed = "short" + i, node.key = verify_kind(ps, i, "identifier", "string") + node.key.conststr = node.key.tk + node.key.tk = '"' .. node.key.tk .. '"' + i = verify_tk(ps, i, "=") + i, node.value = parse_table_value(ps, i) + return i, node, n + elseif ps.tokens[i + 1].tk == ":" then + node.key_parsed = "short" + local orig_i = i + local try_ps = { + filename = ps.filename, + tokens = ps.tokens, + errs = {}, + required_modules = ps.required_modules, + } + i, node.key = verify_kind(try_ps, i, "identifier", "string") + node.key.conststr = node.key.tk + node.key.tk = '"' .. node.key.tk .. '"' + i = verify_tk(try_ps, i, ":") + i, node.itemtype = parse_type(try_ps, i) + if node.itemtype and ps.tokens[i].tk == "=" then + i = verify_tk(try_ps, i, "=") + i, node.value = parse_table_value(try_ps, i) + if node.value then + for _, e in ipairs(try_ps.errs) do + table.insert(ps.errs, e) + end + return i, node, n end - return i, node, n end - end - node.itemtype = nil - i = orig_i + node.itemtype = nil + i = orig_i + end end - end - node.key = new_node(ps.tokens, i, "integer") - node.key_parsed = "implicit" - node.key.constnum = n - node.key.tk = tostring(n) - i, node.value = parse_expression(ps, i) - if not node.value then - return fail(ps, i, "expected an expression") + node.key = new_node(ps, i, "integer") + node.key_parsed = "implicit" + node.key.constnum = n + node.key.tk = tostring(n) + i, node.value = parse_expression(ps, i) + if not node.value then + return fail(ps, i, "expected an expression") + end + return i, node, n + 1 end - return i, node, n + 1 -end @@ -2390,794 +2435,780 @@ end -local function parse_list(ps, i, list, close, sep, parse_item) - local n = 1 - while ps.tokens[i].kind ~= "$EOF$" do - if close[ps.tokens[i].tk] then - end_at(list, ps.tokens[i]) - break - end - local item - local oldn = n - i, item, n = parse_item(ps, i, n) - n = n or oldn - table.insert(list, item) - if ps.tokens[i].tk == "," then - i = i + 1 - if sep == "sep" and close[ps.tokens[i].tk] then - fail(ps, i, "unexpected '" .. ps.tokens[i].tk .. "'") - return i, list - end - elseif sep == "term" and ps.tokens[i].tk == ";" then - i = i + 1 - elseif not close[ps.tokens[i].tk] then - local options = {} - for k, _ in pairs(close) do - table.insert(options, "'" .. k .. "'") - end - table.sort(options) - local first = options[1]:sub(2, -2) - local msg - - if first == ")" and ps.tokens[i].tk == "=" then - msg = "syntax error, cannot perform an assignment here (did you mean '=='?)" - i = failskip(ps, i, msg, parse_expression, i + 1) - else - table.insert(options, "','") - msg = "syntax error, expected one of: " .. table.concat(options, ", ") - fail(ps, i, msg) + local function parse_list(ps, i, list, close, sep, parse_item) + local n = 1 + while ps.tokens[i].kind ~= "$EOF$" do + if close[ps.tokens[i].tk] then + end_at(list, ps.tokens[i]) + break end + local item + local oldn = n + i, item, n = parse_item(ps, i, n) + n = n or oldn + table.insert(list, item) + if ps.tokens[i].tk == "," then + i = i + 1 + if sep == "sep" and close[ps.tokens[i].tk] then + fail(ps, i, "unexpected '" .. ps.tokens[i].tk .. "'") + return i, list + end + elseif sep == "term" and ps.tokens[i].tk == ";" then + i = i + 1 + elseif not close[ps.tokens[i].tk] then + local options = {} + for k, _ in pairs(close) do + table.insert(options, "'" .. k .. "'") + end + table.sort(options) + local first = options[1]:sub(2, -2) + local msg + + if first == ")" and ps.tokens[i].tk == "=" then + msg = "syntax error, cannot perform an assignment here (did you mean '=='?)" + i = failskip(ps, i, msg, parse_expression, i + 1) + else + table.insert(options, "','") + msg = "syntax error, expected one of: " .. table.concat(options, ", ") + fail(ps, i, msg) + end - if first ~= "}" and ps.tokens[i].y ~= ps.tokens[i - 1].y then + if first ~= "}" and ps.tokens[i].y ~= ps.tokens[i - 1].y then - table.insert(ps.tokens, i, { tk = first, y = ps.tokens[i - 1].y, x = ps.tokens[i - 1].x + 1, kind = "keyword" }) - return i, list + table.insert(ps.tokens, i, { tk = first, y = ps.tokens[i - 1].y, x = ps.tokens[i - 1].x + 1, kind = "keyword" }) + return i, list + end end end + return i, list end - return i, list -end - -local function parse_bracket_list(ps, i, list, open, close, sep, parse_item) - i = verify_tk(ps, i, open) - i = parse_list(ps, i, list, { [close] = true }, sep, parse_item) - i = verify_tk(ps, i, close) - return i, list -end - -local function parse_table_literal(ps, i) - local node = new_node(ps.tokens, i, "literal_table") - return parse_bracket_list(ps, i, node, "{", "}", "term", parse_table_item) -end -local function parse_trying_list(ps, i, list, parse_item, ret_lookahead) - local try_ps = { - filename = ps.filename, - tokens = ps.tokens, - errs = {}, - required_modules = ps.required_modules, - } - local tryi, item = parse_item(try_ps, i) - if not item then + local function parse_bracket_list(ps, i, list, open, close, sep, parse_item) + i = verify_tk(ps, i, open) + i = parse_list(ps, i, list, { [close] = true }, sep, parse_item) + i = verify_tk(ps, i, close) return i, list end - for _, e in ipairs(try_ps.errs) do - table.insert(ps.errs, e) + + local function parse_table_literal(ps, i) + local node = new_node(ps, i, "literal_table") + return parse_bracket_list(ps, i, node, "{", "}", "term", parse_table_item) end - i = tryi - table.insert(list, item) - while ps.tokens[i].tk == "," and - (not ret_lookahead or - (not (ps.tokens[i + 1].kind == "identifier" and - ps.tokens[i + 2] and ps.tokens[i + 2].tk == ":"))) do - i = i + 1 - i, item = parse_item(ps, i) + local function parse_trying_list(ps, i, list, parse_item, ret_lookahead) + local try_ps = { + filename = ps.filename, + tokens = ps.tokens, + errs = {}, + required_modules = ps.required_modules, + } + local tryi, item = parse_item(try_ps, i) + if not item then + return i, list + end + for _, e in ipairs(try_ps.errs) do + table.insert(ps.errs, e) + end + i = tryi table.insert(list, item) - end - return i, list -end + while ps.tokens[i].tk == "," and + (not ret_lookahead or + (not (ps.tokens[i + 1].kind == "identifier" and + ps.tokens[i + 2] and ps.tokens[i + 2].tk == ":"))) do -local function parse_anglebracket_list(ps, i, parse_item) - if ps.tokens[i + 1].tk == ">" then - return fail(ps, i + 1, "type argument list cannot be empty") + i = i + 1 + i, item = parse_item(ps, i) + table.insert(list, item) + end + return i, list end - local types = {} - i = verify_tk(ps, i, "<") - i = parse_list(ps, i, types, { [">"] = true, [">>"] = true }, "sep", parse_item) - if ps.tokens[i].tk == ">" then - i = i + 1 - elseif ps.tokens[i].tk == ">>" then - ps.tokens[i].tk = ">" - else - return fail(ps, i, "syntax error, expected '>'") - end - return i, types -end + local function parse_anglebracket_list(ps, i, parse_item) + if ps.tokens[i + 1].tk == ">" then + return fail(ps, i + 1, "type argument list cannot be empty") + end + local types = {} + i = verify_tk(ps, i, "<") + i = parse_list(ps, i, types, { [">"] = true, [">>"] = true }, "sep", parse_item) + if ps.tokens[i].tk == ">" then + i = i + 1 + elseif ps.tokens[i].tk == ">>" then -local function parse_typearg(ps, i) - local name = ps.tokens[i].tk - local constraint - i = verify_kind(ps, i, "identifier") - if ps.tokens[i].tk == "is" then - i = i + 1 - i, constraint = parse_interface_name(ps, i) + ps.tokens[i].tk = ">" + else + return fail(ps, i, "syntax error, expected '>'") + end + return i, types end - return i, a_type("typearg", { - y = ps.tokens[i - 2].y, - x = ps.tokens[i - 2].x, - typearg = name, - constraint = constraint, - }) -end - -local function parse_return_types(ps, i) - return parse_type_list(ps, i, "rets") -end -local function parse_function_type(ps, i) - local typ = new_type(ps, i, "function") - i = i + 1 - if ps.tokens[i].tk == "<" then - i, typ.typeargs = parse_anglebracket_list(ps, i, parse_typearg) - end - if ps.tokens[i].tk == "(" then - i, typ.args, typ.is_method, typ.min_arity = parse_argument_type_list(ps, i) - i, typ.rets = parse_return_types(ps, i) - else - typ.args = a_vararg({ ANY }) - typ.rets = a_vararg({ ANY }) + local function parse_typearg(ps, i) + local name = ps.tokens[i].tk + local constraint + i = verify_kind(ps, i, "identifier") + if ps.tokens[i].tk == "is" then + i = i + 1 + i, constraint = parse_interface_name(ps, i) + end + local t = new_type(ps, i, "typearg") + t.typearg = name + t.constraint = constraint + return i, t end - return i, typ -end - -local simple_types = { - ["nil"] = NIL, - ["any"] = ANY, - ["table"] = TABLE, - ["number"] = NUMBER, - ["string"] = STRING, - ["thread"] = THREAD, - ["boolean"] = BOOLEAN, - ["integer"] = INTEGER, -} -local function parse_simple_type_or_nominal(ps, i) - local tk = ps.tokens[i].tk - local st = simple_types[tk] - if st then - return i + 1, st + local function parse_return_types(ps, i) + local iprev = i - 1 + local t + i, t = parse_type_list(ps, i, "rets") + if #t.tuple == 0 then + t.x = ps.tokens[iprev].x + t.y = ps.tokens[iprev].y + end + return i, t end - local typ = new_type(ps, i, "nominal") - typ.names = { tk } - i = i + 1 - while ps.tokens[i].tk == "." do + + local function parse_function_type(ps, i) + local typ = new_type(ps, i, "function") i = i + 1 - if ps.tokens[i].kind == "identifier" then - table.insert(typ.names, ps.tokens[i].tk) - i = i + 1 + if ps.tokens[i].tk == "<" then + i, typ.typeargs = parse_anglebracket_list(ps, i, parse_typearg) + end + if ps.tokens[i].tk == "(" then + i, typ.args, typ.is_method, typ.min_arity = parse_argument_type_list(ps, i) + i, typ.rets = parse_return_types(ps, i) else - return fail(ps, i, "syntax error, expected identifier") + typ.args = new_tuple(ps, i, { new_type(ps, i, "any") }, true) + typ.rets = new_tuple(ps, i, { new_type(ps, i, "any") }, true) end + return i, typ end - if ps.tokens[i].tk == "<" then - i, typ.typevals = parse_anglebracket_list(ps, i, parse_type) - end - return i, typ -end + local function parse_simple_type_or_nominal(ps, i) + local tk = ps.tokens[i].tk + local st = simple_types[tk] + if st then + return i + 1, new_type(ps, i, tk) + elseif tk == "table" then + local typ = new_type(ps, i, "map") + typ.keys = new_type(ps, i, "any") + typ.values = new_type(ps, i, "any") + return i + 1, typ + end -local function parse_base_type(ps, i) - local tk = ps.tokens[i].tk - if ps.tokens[i].kind == "identifier" then - return parse_simple_type_or_nominal(ps, i) - elseif tk == "{" then - local istart = i + local typ = new_nominal(ps, i, tk) i = i + 1 - local t - i, t = parse_type(ps, i) - if not t then - return i - end - if ps.tokens[i].tk == "}" then - local decl = new_type(ps, istart, "array") - decl.elements = t - end_at(decl, ps.tokens[i]) - i = verify_tk(ps, i, "}") - return i, decl - elseif ps.tokens[i].tk == "," then - local decl = new_type(ps, istart, "tupletable") - decl.types = { t } - local n = 2 - repeat - i = i + 1 - i, decl.types[n] = parse_type(ps, i) - if not decl.types[n] then - break - end - n = n + 1 - until ps.tokens[i].tk ~= "," - end_at(decl, ps.tokens[i]) - i = verify_tk(ps, i, "}") - return i, decl - elseif ps.tokens[i].tk == ":" then - local decl = new_type(ps, istart, "map") + while ps.tokens[i].tk == "." do i = i + 1 - decl.keys = t - i, decl.values = parse_type(ps, i) - if not decl.values then - return i + if ps.tokens[i].kind == "identifier" then + table.insert(typ.names, ps.tokens[i].tk) + i = i + 1 + else + return fail(ps, i, "syntax error, expected identifier") end - end_at(decl, ps.tokens[i]) - i = verify_tk(ps, i, "}") - return i, decl - end - return fail(ps, i, "syntax error; did you forget a '}'?") - elseif tk == "function" then - return parse_function_type(ps, i) - elseif tk == "nil" then - return i + 1, simple_types["nil"] - elseif tk == "table" then - local typ = new_type(ps, i, "map") - typ.keys = ANY - typ.values = ANY - return i + 1, typ - end - return fail(ps, i, "expected a type") -end + end -parse_type = function(ps, i) - if ps.tokens[i].tk == "(" then - i = i + 1 - local t - i, t = parse_type(ps, i) - i = verify_tk(ps, i, ")") - return i, t + if ps.tokens[i].tk == "<" then + i, typ.typevals = parse_anglebracket_list(ps, i, parse_type) + end + return i, typ end - local bt - local istart = i - i, bt = parse_base_type(ps, i) - if not bt then - return i - end - if ps.tokens[i].tk == "|" then - local u = new_type(ps, istart, "union") - u.types = { bt } - while ps.tokens[i].tk == "|" do + local function parse_base_type(ps, i) + local tk = ps.tokens[i].tk + if ps.tokens[i].kind == "identifier" then + return parse_simple_type_or_nominal(ps, i) + elseif tk == "{" then + local istart = i i = i + 1 - i, bt = parse_base_type(ps, i) - if not bt then + local t + i, t = parse_type(ps, i) + if not t then return i end - table.insert(u.types, bt) + if ps.tokens[i].tk == "}" then + local decl = new_type(ps, istart, "array") + decl.elements = t + end_at(decl, ps.tokens[i]) + i = verify_tk(ps, i, "}") + return i, decl + elseif ps.tokens[i].tk == "," then + local decl = new_type(ps, istart, "tupletable") + decl.types = { t } + local n = 2 + repeat + i = i + 1 + i, decl.types[n] = parse_type(ps, i) + if not decl.types[n] then + break + end + n = n + 1 + until ps.tokens[i].tk ~= "," + end_at(decl, ps.tokens[i]) + i = verify_tk(ps, i, "}") + return i, decl + elseif ps.tokens[i].tk == ":" then + local decl = new_type(ps, istart, "map") + i = i + 1 + decl.keys = t + i, decl.values = parse_type(ps, i) + if not decl.values then + return i + end + end_at(decl, ps.tokens[i]) + i = verify_tk(ps, i, "}") + return i, decl + end + return fail(ps, i, "syntax error; did you forget a '}'?") + elseif tk == "function" then + return parse_function_type(ps, i) + elseif tk == "nil" then + return i + 1, new_type(ps, i, "nil") end - bt = u + return fail(ps, i, "expected a type") end - return i, bt -end - -local function new_tuple(ps, i) - local t = new_type(ps, i, "tuple") - t.tuple = {} - return t, t.tuple -end -parse_type_list = function(ps, i, mode) - local t, list = new_tuple(ps, i) - - local first_token = ps.tokens[i].tk - if mode == "rets" or mode == "decltuple" then - if first_token == ":" then + parse_type = function(ps, i) + if ps.tokens[i].tk == "(" then i = i + 1 - else + local t + i, t = parse_type(ps, i) + i = verify_tk(ps, i, ")") return i, t end - end - local optional_paren = false - if ps.tokens[i].tk == "(" then - optional_paren = true - i = i + 1 + local bt + local istart = i + i, bt = parse_base_type(ps, i) + if not bt then + return i + end + if ps.tokens[i].tk == "|" then + local u = new_type(ps, istart, "union") + u.types = { bt } + while ps.tokens[i].tk == "|" do + i = i + 1 + i, bt = parse_base_type(ps, i) + if not bt then + return i + end + table.insert(u.types, bt) + end + bt = u + end + return i, bt end - local prev_i = i - i = parse_trying_list(ps, i, list, parse_type, mode == "rets") - if i == prev_i and ps.tokens[i].tk ~= ")" then - fail(ps, i - 1, "expected a type list") - end + parse_type_list = function(ps, i, mode) + local t, list = new_tuple(ps, i) - if mode == "rets" and ps.tokens[i].tk == "..." then - i = i + 1 - local nrets = #list - if nrets > 0 then - t.is_va = true - else - fail(ps, i, "unexpected '...'") + local first_token = ps.tokens[i].tk + if mode == "rets" or mode == "decltuple" then + if first_token == ":" then + i = i + 1 + else + return i, t + end end - end - if optional_paren then - i = verify_tk(ps, i, ")") - end + local optional_paren = false + if ps.tokens[i].tk == "(" then + optional_paren = true + i = i + 1 + end - return i, t -end + local prev_i = i + i = parse_trying_list(ps, i, list, parse_type, mode == "rets") + if i == prev_i and ps.tokens[i].tk ~= ")" then + fail(ps, i - 1, "expected a type list") + end -local function parse_function_args_rets_body(ps, i, node) - local istart = i - 1 - if ps.tokens[i].tk == "<" then - i, node.typeargs = parse_anglebracket_list(ps, i, parse_typearg) - end - i, node.args, node.min_arity = parse_argument_list(ps, i) - i, node.rets = parse_return_types(ps, i) - i, node.body = parse_statements(ps, i) - end_at(node, ps.tokens[i]) - i = verify_end(ps, i, istart, node) - return i, node -end + if mode == "rets" and ps.tokens[i].tk == "..." then + i = i + 1 + local nrets = #list + if nrets > 0 then + t.is_va = true + else + fail(ps, i, "unexpected '...'") + end + end -local function parse_function_value(ps, i) - local node = new_node(ps.tokens, i, "function") - i = verify_tk(ps, i, "function") - return parse_function_args_rets_body(ps, i, node) -end + if optional_paren then + i = verify_tk(ps, i, ")") + end -local function unquote(str) - local f = str:sub(1, 1) - if f == '"' or f == "'" then - return str:sub(2, -2), false + return i, t end - f = str:match("^%[=*%[") - local l = #f + 1 - return str:sub(l, -l), true -end -local function parse_literal(ps, i) - local tk = ps.tokens[i].tk - local kind = ps.tokens[i].kind - if kind == "identifier" then - return verify_kind(ps, i, "identifier", "variable") - elseif kind == "string" then - local node = new_node(ps.tokens, i, "string") - node.conststr, node.is_longstring = unquote(tk) - return i + 1, node - elseif kind == "number" or kind == "integer" then - local n = tonumber(tk) - local node - i, node = verify_kind(ps, i, kind) - node.constnum = n + local function parse_function_args_rets_body(ps, i, node) + local istart = i - 1 + if ps.tokens[i].tk == "<" then + i, node.typeargs = parse_anglebracket_list(ps, i, parse_typearg) + end + i, node.args, node.min_arity = parse_argument_list(ps, i) + i, node.rets = parse_return_types(ps, i) + i, node.body = parse_statements(ps, i) + end_at(node, ps.tokens[i]) + i = verify_end(ps, i, istart, node) return i, node - elseif tk == "true" then - return verify_kind(ps, i, "keyword", "boolean") - elseif tk == "false" then - return verify_kind(ps, i, "keyword", "boolean") - elseif tk == "nil" then - return verify_kind(ps, i, "keyword", "nil") - elseif tk == "function" then - return parse_function_value(ps, i) - elseif tk == "{" then - return parse_table_literal(ps, i) - elseif kind == "..." then - return verify_kind(ps, i, "...") - elseif kind == "$ERR invalid_string$" then - return fail(ps, i, "malformed string") - elseif kind == "$ERR invalid_number$" then - return fail(ps, i, "malformed number") - end - return fail(ps, i, "syntax error") -end - -local function node_is_require_call(n) - if n.e1 and n.e2 and - n.e1.kind == "variable" and n.e1.tk == "require" and - n.e2.kind == "expression_list" and #n.e2 == 1 and - n.e2[1].kind == "string" then - - return n.e2[1].conststr - elseif n.op and n.op.op == "@funcall" and - n.e1 and n.e1.tk == "pcall" and - n.e2 and #n.e2 == 2 and - n.e2[1].kind == "variable" and n.e2[1].tk == "require" and - n.e2[2].kind == "string" and n.e2[2].conststr then + end - return n.e2[2].conststr - else - return nil + local function parse_function_value(ps, i) + local node = new_node(ps, i, "function") + i = verify_tk(ps, i, "function") + return parse_function_args_rets_body(ps, i, node) end -end -local an_operator + local function unquote(str) + local f = str:sub(1, 1) + if f == '"' or f == "'" then + return str:sub(2, -2), false + end + f = str:match("^%[=*%[") + local l = #f + 1 + return str:sub(l, -l), true + end -do - local precedences = { - [1] = { - ["not"] = 11, - ["#"] = 11, - ["-"] = 11, - ["~"] = 11, - }, - [2] = { - ["or"] = 1, - ["and"] = 2, - ["is"] = 3, - ["<"] = 3, - [">"] = 3, - ["<="] = 3, - [">="] = 3, - ["~="] = 3, - ["=="] = 3, - ["|"] = 4, - ["~"] = 5, - ["&"] = 6, - ["<<"] = 7, - [">>"] = 7, - [".."] = 8, - ["+"] = 9, - ["-"] = 9, - ["*"] = 10, - ["/"] = 10, - ["//"] = 10, - ["%"] = 10, - ["^"] = 12, - ["as"] = 50, - ["@funcall"] = 100, - ["@index"] = 100, - ["."] = 100, - [":"] = 100, - }, - } + local function parse_literal(ps, i) + local tk = ps.tokens[i].tk + local kind = ps.tokens[i].kind + if kind == "identifier" then + return verify_kind(ps, i, "identifier", "variable") + elseif kind == "string" then + local node = new_node(ps, i, "string") + node.conststr, node.is_longstring = unquote(tk) + return i + 1, node + elseif kind == "number" or kind == "integer" then + local n = tonumber(tk) + local node + i, node = verify_kind(ps, i, kind) + node.constnum = n + return i, node + elseif tk == "true" then + return verify_kind(ps, i, "keyword", "boolean") + elseif tk == "false" then + return verify_kind(ps, i, "keyword", "boolean") + elseif tk == "nil" then + return verify_kind(ps, i, "keyword", "nil") + elseif tk == "function" then + return parse_function_value(ps, i) + elseif tk == "{" then + return parse_table_literal(ps, i) + elseif kind == "..." then + return verify_kind(ps, i, "...") + elseif kind == "$ERR invalid_string$" then + return fail(ps, i, "malformed string") + elseif kind == "$ERR invalid_number$" then + return fail(ps, i, "malformed number") + end + return fail(ps, i, "syntax error") + end + + local function node_is_require_call(n) + if n.e1 and n.e2 and + n.e1.kind == "variable" and n.e1.tk == "require" and + n.e2.kind == "expression_list" and #n.e2 == 1 and + n.e2[1].kind == "string" then + + return n.e2[1].conststr + elseif n.op and n.op.op == "@funcall" and + n.e1 and n.e1.tk == "pcall" and + n.e2 and #n.e2 == 2 and + n.e2[1].kind == "variable" and n.e2[1].tk == "require" and + n.e2[2].kind == "string" and n.e2[2].conststr then + + return n.e2[2].conststr + else + return nil + end + end - local is_right_assoc = { - ["^"] = true, - [".."] = true, - } + do + local precedences = { + [1] = { + ["not"] = 11, + ["#"] = 11, + ["-"] = 11, + ["~"] = 11, + }, + [2] = { + ["or"] = 1, + ["and"] = 2, + ["is"] = 3, + ["<"] = 3, + [">"] = 3, + ["<="] = 3, + [">="] = 3, + ["~="] = 3, + ["=="] = 3, + ["|"] = 4, + ["~"] = 5, + ["&"] = 6, + ["<<"] = 7, + [">>"] = 7, + [".."] = 8, + ["+"] = 9, + ["-"] = 9, + ["*"] = 10, + ["/"] = 10, + ["//"] = 10, + ["%"] = 10, + ["^"] = 12, + ["as"] = 50, + ["@funcall"] = 100, + ["@index"] = 100, + ["."] = 100, + [":"] = 100, + }, + } - local function new_operator(tk, arity, op) - return { y = tk.y, x = tk.x, arity = arity, op = op, prec = precedences[arity][op] } - end + local is_right_assoc = { + ["^"] = true, + [".."] = true, + } - an_operator = function(node, arity, op) - return { y = node.y, x = node.x, arity = arity, op = op, prec = precedences[arity][op] } - end + local function new_operator(tk, arity, op) + return { y = tk.y, x = tk.x, arity = arity, op = op, prec = precedences[arity][op] } + end - local args_starters = { - ["("] = true, - ["{"] = true, - ["string"] = true, - } + an_operator = function(node, arity, op) + return { y = node.y, x = node.x, arity = arity, op = op, prec = precedences[arity][op] } + end - local E + local args_starters = { + ["("] = true, + ["{"] = true, + ["string"] = true, + } - local function after_valid_prefixexp(ps, prevnode, i) - return ps.tokens[i - 1].kind == ")" or - (prevnode.kind == "op" and - (prevnode.op.op == "@funcall" or - prevnode.op.op == "@index" or - prevnode.op.op == "." or - prevnode.op.op == ":")) or + local E - prevnode.kind == "identifier" or - prevnode.kind == "variable" - end + local function after_valid_prefixexp(ps, prevnode, i) + return ps.tokens[i - 1].kind == ")" or + (prevnode.kind == "op" and + (prevnode.op.op == "@funcall" or + prevnode.op.op == "@index" or + prevnode.op.op == "." or + prevnode.op.op == ":")) or + prevnode.kind == "identifier" or + prevnode.kind == "variable" + end - local function failstore(tkop, e1) - return { y = tkop.y, x = tkop.x, kind = "paren", e1 = e1, failstore = true } - end - local function P(ps, i) - if ps.tokens[i].kind == "$EOF$" then - return i + local function failstore(ps, tkop, e1) + return { f = ps.filename, y = tkop.y, x = tkop.x, kind = "paren", e1 = e1, failstore = true } end - local e1 - local t1 = ps.tokens[i] - if precedences[1][t1.tk] ~= nil then - local op = new_operator(t1, 1, t1.tk) - i = i + 1 - local prev_i = i - i, e1 = P(ps, i) - if not e1 then - fail(ps, prev_i, "expected an expression") + + local function P(ps, i) + if ps.tokens[i].kind == "$EOF$" then return i end - e1 = { y = t1.y, x = t1.x, kind = "op", op = op, e1 = e1 } - elseif ps.tokens[i].tk == "(" then - i = i + 1 - local prev_i = i - i, e1 = parse_expression_and_tk(ps, i, ")") + local e1 + local t1 = ps.tokens[i] + if precedences[1][t1.tk] ~= nil then + local op = new_operator(t1, 1, t1.tk) + i = i + 1 + local prev_i = i + i, e1 = P(ps, i) + if not e1 then + fail(ps, prev_i, "expected an expression") + return i + end + e1 = { f = ps.filename, y = t1.y, x = t1.x, kind = "op", op = op, e1 = e1 } + elseif ps.tokens[i].tk == "(" then + i = i + 1 + local prev_i = i + i, e1 = parse_expression_and_tk(ps, i, ")") + if not e1 then + fail(ps, prev_i, "expected an expression") + return i + end + e1 = { f = ps.filename, y = t1.y, x = t1.x, kind = "paren", e1 = e1 } + else + i, e1 = parse_literal(ps, i) + end + if not e1 then - fail(ps, prev_i, "expected an expression") return i end - e1 = { y = t1.y, x = t1.x, kind = "paren", e1 = e1 } - else - i, e1 = parse_literal(ps, i) - end - if not e1 then - return i - end - - while true do - local tkop = ps.tokens[i] - if tkop.kind == "," or tkop.kind == ")" then - break - end - if tkop.tk == "." or tkop.tk == ":" then - local op = new_operator(tkop, 2, tkop.tk) + while true do + local tkop = ps.tokens[i] + if tkop.kind == "," or tkop.kind == ")" then + break + end + if tkop.tk == "." or tkop.tk == ":" then + local op = new_operator(tkop, 2, tkop.tk) - local prev_i = i + local prev_i = i - local key - i = i + 1 - if ps.tokens[i].kind ~= "identifier" then - local skipped = skip(ps, i, parse_type) - if skipped > i + 1 then - fail(ps, i, "syntax error, cannot declare a type here (missing 'local' or 'global'?)") - return skipped, failstore(tkop, e1) + local key + i = i + 1 + if ps.tokens[i].kind ~= "identifier" then + local skipped = skip(ps, i, parse_type) + if skipped > i + 1 then + fail(ps, i, "syntax error, cannot declare a type here (missing 'local' or 'global'?)") + return skipped, failstore(ps, tkop, e1) + end + end + i, key = verify_kind(ps, i, "identifier") + if not key then + return i, failstore(ps, tkop, e1) end - end - i, key = verify_kind(ps, i, "identifier") - if not key then - return i, failstore(tkop, e1) - end - if op.op == ":" then - if not args_starters[ps.tokens[i].kind] then - if ps.tokens[i].tk == "=" then - fail(ps, i, "syntax error, cannot perform an assignment here (missing 'local' or 'global'?)") - else - fail(ps, i, "expected a function call for a method") + if op.op == ":" then + if not args_starters[ps.tokens[i].kind] then + if ps.tokens[i].tk == "=" then + fail(ps, i, "syntax error, cannot perform an assignment here (missing 'local' or 'global'?)") + else + fail(ps, i, "expected a function call for a method") + end + return i, failstore(ps, tkop, e1) end - return i, failstore(tkop, e1) - end - if not after_valid_prefixexp(ps, e1, prev_i) then - fail(ps, prev_i, "cannot call a method on this expression") - return i, failstore(tkop, e1) + if not after_valid_prefixexp(ps, e1, prev_i) then + fail(ps, prev_i, "cannot call a method on this expression") + return i, failstore(ps, tkop, e1) + end end - end - - e1 = { y = tkop.y, x = tkop.x, kind = "op", op = op, e1 = e1, e2 = key } - elseif tkop.tk == "(" then - local prev_tk = ps.tokens[i - 1] - if tkop.y > prev_tk.y then - table.insert(ps.tokens, i, { y = prev_tk.y, x = prev_tk.x + #prev_tk.tk, tk = ";", kind = ";" }) - break - end - local op = new_operator(tkop, 2, "@funcall") + e1 = { f = ps.filename, y = tkop.y, x = tkop.x, kind = "op", op = op, e1 = e1, e2 = key } + elseif tkop.tk == "(" then + local prev_tk = ps.tokens[i - 1] + if tkop.y > prev_tk.y then + table.insert(ps.tokens, i, { y = prev_tk.y, x = prev_tk.x + #prev_tk.tk, tk = ";", kind = ";" }) + break + end - local prev_i = i + local op = new_operator(tkop, 2, "@funcall") - local args = new_node(ps.tokens, i, "expression_list") - i, args = parse_bracket_list(ps, i, args, "(", ")", "sep", parse_expression) + local prev_i = i - if not after_valid_prefixexp(ps, e1, prev_i) then - fail(ps, prev_i, "cannot call this expression") - return i, failstore(tkop, e1) - end + local args = new_node(ps, i, "expression_list") + i, args = parse_bracket_list(ps, i, args, "(", ")", "sep", parse_expression) - e1 = { y = args.y, x = args.x, kind = "op", op = op, e1 = e1, e2 = args } + if not after_valid_prefixexp(ps, e1, prev_i) then + fail(ps, prev_i, "cannot call this expression") + return i, failstore(ps, tkop, e1) + end - table.insert(ps.required_modules, node_is_require_call(e1)) - elseif tkop.tk == "[" then - local op = new_operator(tkop, 2, "@index") + e1 = { f = ps.filename, y = args.y, x = args.x, kind = "op", op = op, e1 = e1, e2 = args } - local prev_i = i + table.insert(ps.required_modules, node_is_require_call(e1)) + elseif tkop.tk == "[" then + local op = new_operator(tkop, 2, "@index") - local idx - i = i + 1 - i, idx = parse_expression_and_tk(ps, i, "]") + local prev_i = i - if not after_valid_prefixexp(ps, e1, prev_i) then - fail(ps, prev_i, "cannot index this expression") - return i, failstore(tkop, e1) - end + local idx + i = i + 1 + i, idx = parse_expression_and_tk(ps, i, "]") - e1 = { y = tkop.y, x = tkop.x, kind = "op", op = op, e1 = e1, e2 = idx } - elseif tkop.kind == "string" or tkop.kind == "{" then - local op = new_operator(tkop, 2, "@funcall") + if not after_valid_prefixexp(ps, e1, prev_i) then + fail(ps, prev_i, "cannot index this expression") + return i, failstore(ps, tkop, e1) + end - local prev_i = i + e1 = { f = ps.filename, y = tkop.y, x = tkop.x, kind = "op", op = op, e1 = e1, e2 = idx } + elseif tkop.kind == "string" or tkop.kind == "{" then + local op = new_operator(tkop, 2, "@funcall") - local args = new_node(ps.tokens, i, "expression_list") - local argument - if tkop.kind == "string" then - argument = new_node(ps.tokens, i) - argument.conststr = unquote(tkop.tk) - i = i + 1 - else - i, argument = parse_table_literal(ps, i) - end + local prev_i = i - if not after_valid_prefixexp(ps, e1, prev_i) then + local args = new_node(ps, i, "expression_list") + local argument if tkop.kind == "string" then - fail(ps, prev_i, "cannot use a string here; if you're trying to call the previous expression, wrap it in parentheses") + argument = new_node(ps, i) + argument.conststr = unquote(tkop.tk) + i = i + 1 else - fail(ps, prev_i, "cannot use a table here; if you're trying to call the previous expression, wrap it in parentheses") + i, argument = parse_table_literal(ps, i) end - return i, failstore(tkop, e1) - end - table.insert(args, argument) - e1 = { y = args.y, x = args.x, kind = "op", op = op, e1 = e1, e2 = args } + if not after_valid_prefixexp(ps, e1, prev_i) then + if tkop.kind == "string" then + fail(ps, prev_i, "cannot use a string here; if you're trying to call the previous expression, wrap it in parentheses") + else + fail(ps, prev_i, "cannot use a table here; if you're trying to call the previous expression, wrap it in parentheses") + end + return i, failstore(ps, tkop, e1) + end - table.insert(ps.required_modules, node_is_require_call(e1)) - elseif tkop.tk == "as" or tkop.tk == "is" then - local op = new_operator(tkop, 2, tkop.tk) + table.insert(args, argument) + e1 = { f = ps.filename, y = args.y, x = args.x, kind = "op", op = op, e1 = e1, e2 = args } - i = i + 1 - local cast = new_node(ps.tokens, i, "cast") - if ps.tokens[i].tk == "(" then - i, cast.casttype = parse_type_list(ps, i, "casttype") + table.insert(ps.required_modules, node_is_require_call(e1)) + elseif tkop.tk == "as" or tkop.tk == "is" then + local op = new_operator(tkop, 2, tkop.tk) + + i = i + 1 + local cast = new_node(ps, i, "cast") + if ps.tokens[i].tk == "(" then + i, cast.casttype = parse_type_list(ps, i, "casttype") + else + i, cast.casttype = parse_type(ps, i) + end + if not cast.casttype then + return i, failstore(ps, tkop, e1) + end + e1 = { f = ps.filename, y = tkop.y, x = tkop.x, kind = "op", op = op, e1 = e1, e2 = cast, conststr = e1.conststr } else - i, cast.casttype = parse_type(ps, i) - end - if not cast.casttype then - return i, failstore(tkop, e1) + break end - e1 = { y = tkop.y, x = tkop.x, kind = "op", op = op, e1 = e1, e2 = cast, conststr = e1.conststr } - else - break end - end - return i, e1 - end + return i, e1 + end - E = function(ps, i, lhs, min_precedence) - local lookahead = ps.tokens[i].tk - while precedences[2][lookahead] and precedences[2][lookahead] >= min_precedence do - local t1 = ps.tokens[i] - local op = new_operator(t1, 2, t1.tk) - i = i + 1 - local rhs - i, rhs = P(ps, i) - if not rhs then - fail(ps, i, "expected an expression") - return i - end - lookahead = ps.tokens[i].tk - while precedences[2][lookahead] and ((precedences[2][lookahead] > (precedences[2][op.op])) or - (is_right_assoc[lookahead] and (precedences[2][lookahead] == precedences[2][op.op]))) do - i, rhs = E(ps, i, rhs, precedences[2][lookahead]) + E = function(ps, i, lhs, min_precedence) + local lookahead = ps.tokens[i].tk + while precedences[2][lookahead] and precedences[2][lookahead] >= min_precedence do + local t1 = ps.tokens[i] + local op = new_operator(t1, 2, t1.tk) + i = i + 1 + local rhs + i, rhs = P(ps, i) if not rhs then fail(ps, i, "expected an expression") return i end lookahead = ps.tokens[i].tk + while precedences[2][lookahead] and ((precedences[2][lookahead] > (precedences[2][op.op])) or + (is_right_assoc[lookahead] and (precedences[2][lookahead] == precedences[2][op.op]))) do + i, rhs = E(ps, i, rhs, precedences[2][lookahead]) + if not rhs then + fail(ps, i, "expected an expression") + return i + end + lookahead = ps.tokens[i].tk + end + lhs = { f = ps.filename, y = t1.y, x = t1.x, kind = "op", op = op, e1 = lhs, e2 = rhs } end - lhs = { y = t1.y, x = t1.x, kind = "op", op = op, e1 = lhs, e2 = rhs } + return i, lhs end - return i, lhs - end - parse_expression = function(ps, i) - local lhs - local istart = i - i, lhs = P(ps, i) - if lhs then - i, lhs = E(ps, i, lhs, 0) - end - if lhs then - return i, lhs, 0 - end + parse_expression = function(ps, i) + local lhs + local istart = i + i, lhs = P(ps, i) + if lhs then + i, lhs = E(ps, i, lhs, 0) + end + if lhs then + return i, lhs, 0 + end - if i == istart then - i = fail(ps, i, "expected an expression") + if i == istart then + i = fail(ps, i, "expected an expression") + end + return i end - return i end -end -parse_expression_and_tk = function(ps, i, tk) - local e - i, e = parse_expression(ps, i) - if not e then - e = new_node(ps.tokens, i - 1, "error_node") - end - if ps.tokens[i].tk == tk then - i = i + 1 - else - local msg = "syntax error, expected '" .. tk .. "'" - if ps.tokens[i].tk == "=" then - msg = "syntax error, cannot perform an assignment here (did you mean '=='?)" + parse_expression_and_tk = function(ps, i, tk) + local e + i, e = parse_expression(ps, i) + if not e then + e = new_node(ps, i - 1, "error_node") end + if ps.tokens[i].tk == tk then + i = i + 1 + else + local msg = "syntax error, expected '" .. tk .. "'" + if ps.tokens[i].tk == "=" then + msg = "syntax error, cannot perform an assignment here (did you mean '=='?)" + end - for n = 0, 19 do - local t = ps.tokens[i + n] - if t.kind == "$EOF$" then - break - end - if t.tk == tk then - fail(ps, i, msg) - return i + n + 1, e + for n = 0, 19 do + local t = ps.tokens[i + n] + if t.kind == "$EOF$" then + break + end + if t.tk == tk then + fail(ps, i, msg) + return i + n + 1, e + end end + i = fail(ps, i, msg) end - i = fail(ps, i, msg) + return i, e end - return i, e -end -local function parse_variable_name(ps, i) - local node - i, node = verify_kind(ps, i, "identifier") - if not node then - return i - end - if ps.tokens[i].tk == "<" then - i = i + 1 - local annotation - i, annotation = verify_kind(ps, i, "identifier") - if annotation then - if not is_attribute[annotation.tk] then - fail(ps, i, "unknown variable annotation: " .. annotation.tk) + local function parse_variable_name(ps, i) + local node + i, node = verify_kind(ps, i, "identifier") + if not node then + return i + end + if ps.tokens[i].tk == "<" then + i = i + 1 + local annotation + i, annotation = verify_kind(ps, i, "identifier") + if annotation then + if not is_attribute[annotation.tk] then + fail(ps, i, "unknown variable annotation: " .. annotation.tk) + end + node.attribute = annotation.tk + else + fail(ps, i, "expected a variable annotation") end - node.attribute = annotation.tk - else - fail(ps, i, "expected a variable annotation") + i = verify_tk(ps, i, ">") end - i = verify_tk(ps, i, ">") + return i, node end - return i, node -end -local function parse_argument(ps, i) - local node - if ps.tokens[i].tk == "..." then - i, node = verify_kind(ps, i, "...", "argument") - node.opt = true - else - i, node = verify_kind(ps, i, "identifier", "argument") - end - if ps.tokens[i].tk == "..." then - fail(ps, i, "'...' needs to be declared as a typed argument") - end - if ps.tokens[i].tk == "?" then - i = i + 1 - node.opt = true - end - if ps.tokens[i].tk == ":" then - i = i + 1 - local argtype + local function parse_argument(ps, i) + local node + if ps.tokens[i].tk == "..." then + i, node = verify_kind(ps, i, "...", "argument") + node.opt = true + else + i, node = verify_kind(ps, i, "identifier", "argument") + end + if ps.tokens[i].tk == "..." then + fail(ps, i, "'...' needs to be declared as a typed argument") + end + if ps.tokens[i].tk == "?" then + i = i + 1 + node.opt = true + end + if ps.tokens[i].tk == ":" then + i = i + 1 + local argtype - i, argtype = parse_type(ps, i) + i, argtype = parse_type(ps, i) - if node then - node.argtype = argtype + if node then + node.argtype = argtype + end end + return i, node, 0 end - return i, node, 0 -end -parse_argument_list = function(ps, i) - local node = new_node(ps.tokens, i, "argument_list") - i, node = parse_bracket_list(ps, i, node, "(", ")", "sep", parse_argument) - local opts = false - local min_arity = 0 - for a, fnarg in ipairs(node) do - if fnarg.tk == "..." then - if a ~= #node then - fail(ps, i, "'...' can only be last argument") - break + parse_argument_list = function(ps, i) + local node = new_node(ps, i, "argument_list") + i, node = parse_bracket_list(ps, i, node, "(", ")", "sep", parse_argument) + local opts = false + local min_arity = 0 + for a, fnarg in ipairs(node) do + if fnarg.tk == "..." then + if a ~= #node then + fail(ps, i, "'...' can only be last argument") + break + end + elseif fnarg.opt then + opts = true + elseif opts then + return fail(ps, i, "non-optional arguments cannot follow optional arguments") + else + min_arity = min_arity + 1 end - elseif fnarg.opt then - opts = true - elseif opts then - return fail(ps, i, "non-optional arguments cannot follow optional arguments") - else - min_arity = min_arity + 1 end + return i, node, min_arity end - return i, node, min_arity -end @@ -3187,1014 +3218,982 @@ end -local function parse_argument_type(ps, i) - local opt = false - local is_va = false - local is_self = false - local argument_name = nil + local function parse_argument_type(ps, i) + local opt = false + local is_va = false + local is_self = false + local argument_name = nil - if ps.tokens[i].kind == "identifier" then - argument_name = ps.tokens[i].tk - if ps.tokens[i + 1].tk == "?" then + if ps.tokens[i].kind == "identifier" then + argument_name = ps.tokens[i].tk + if ps.tokens[i + 1].tk == "?" then + opt = true + if ps.tokens[i + 2].tk == ":" then + i = i + 3 + end + elseif ps.tokens[i + 1].tk == ":" then + i = i + 2 + end + elseif ps.tokens[i].kind == "?" then opt = true - if ps.tokens[i + 2].tk == ":" then - i = i + 3 + i = i + 1 + elseif ps.tokens[i].tk == "..." then + if ps.tokens[i + 1].tk == ":" then + i = i + 2 + is_va = true + else + return fail(ps, i, "cannot have untyped '...' when declaring the type of an argument") end - elseif ps.tokens[i + 1].tk == ":" then - i = i + 2 end - elseif ps.tokens[i].kind == "?" then - opt = true - i = i + 1 - elseif ps.tokens[i].tk == "..." then - if ps.tokens[i + 1].tk == ":" then - i = i + 2 - is_va = true - else - return fail(ps, i, "cannot have untyped '...' when declaring the type of an argument") - end - end - local typ; i, typ = parse_type(ps, i) - if typ then - if not is_va and ps.tokens[i].tk == "..." then - i = i + 1 - is_va = true - end + local typ; i, typ = parse_type(ps, i) + if typ then + if not is_va and ps.tokens[i].tk == "..." then + i = i + 1 + is_va = true + end - if argument_name == "self" then - is_self = true + if argument_name == "self" then + is_self = true + end end - end - return i, { i = i, type = typ, is_va = is_va, is_self = is_self, opt = opt or is_va }, 0 -end + return i, { i = i, type = typ, is_va = is_va, is_self = is_self, opt = opt or is_va }, 0 + end -parse_argument_type_list = function(ps, i) - local ars = {} - i = parse_bracket_list(ps, i, ars, "(", ")", "sep", parse_argument_type) - local t, list = new_tuple(ps, i) - local n = #ars - local min_arity = 0 - for l, ar in ipairs(ars) do - list[l] = ar.type - if ar.is_va and l < n then - fail(ps, ar.i, "'...' can only be last argument") + parse_argument_type_list = function(ps, i) + local ars = {} + i = parse_bracket_list(ps, i, ars, "(", ")", "sep", parse_argument_type) + local t, list = new_tuple(ps, i) + local n = #ars + local min_arity = 0 + for l, ar in ipairs(ars) do + list[l] = ar.type + if ar.is_va and l < n then + fail(ps, ar.i, "'...' can only be last argument") + end + if not ar.opt then + min_arity = min_arity + 1 + end end - if not ar.opt then - min_arity = min_arity + 1 + if n > 0 and ars[n].is_va then + t.is_va = true end + return i, t, (n > 0 and ars[1].is_self), min_arity end - if n > 0 and ars[n].is_va then - t.is_va = true + + local function parse_identifier(ps, i) + if ps.tokens[i].kind == "identifier" then + return i + 1, new_node(ps, i, "identifier") + end + i = fail(ps, i, "syntax error, expected identifier") + return i, new_node(ps, i, "error_node") end - return i, t, (n > 0 and ars[1].is_self), min_arity -end -local function parse_identifier(ps, i) - if ps.tokens[i].kind == "identifier" then - return i + 1, new_node(ps.tokens, i, "identifier") + local function parse_local_function(ps, i) + i = verify_tk(ps, i, "local") + i = verify_tk(ps, i, "function") + local node = new_node(ps, i - 2, "local_function") + i, node.name = parse_identifier(ps, i) + return parse_function_args_rets_body(ps, i, node) end - i = fail(ps, i, "syntax error, expected identifier") - return i, new_node(ps.tokens, i, "error_node") -end -local function parse_local_function(ps, i) - i = verify_tk(ps, i, "local") - i = verify_tk(ps, i, "function") - local node = new_node(ps.tokens, i - 2, "local_function") - i, node.name = parse_identifier(ps, i) - return parse_function_args_rets_body(ps, i, node) -end + local function parse_function(ps, i, fk) + local orig_i = i + i = verify_tk(ps, i, "function") + local fn = new_node(ps, i - 1, "global_function") + local names = {} + i, names[1] = parse_identifier(ps, i) + while ps.tokens[i].tk == "." do + i = i + 1 + i, names[#names + 1] = parse_identifier(ps, i) + end + if ps.tokens[i].tk == ":" then + i = i + 1 + i, names[#names + 1] = parse_identifier(ps, i) + fn.is_method = true + end -local function parse_function(ps, i, fk) - local orig_i = i - i = verify_tk(ps, i, "function") - local fn = new_node(ps.tokens, i - 1, "global_function") - local names = {} - i, names[1] = parse_identifier(ps, i) - while ps.tokens[i].tk == "." do - i = i + 1 - i, names[#names + 1] = parse_identifier(ps, i) - end - if ps.tokens[i].tk == ":" then - i = i + 1 - i, names[#names + 1] = parse_identifier(ps, i) - fn.is_method = true - end + if #names > 1 then + fn.kind = "record_function" + local owner = names[1] + owner.kind = "type_identifier" + for i2 = 2, #names - 1 do + local dot = an_operator(names[i2], 2, ".") + names[i2].kind = "identifier" + owner = { f = ps.filename, y = names[i2].y, x = names[i2].x, kind = "op", op = dot, e1 = owner, e2 = names[i2] } + end + fn.fn_owner = owner + end + fn.name = names[#names] - if #names > 1 then - fn.kind = "record_function" - local owner = names[1] - owner.kind = "type_identifier" - for i2 = 2, #names - 1 do - local dot = an_operator(names[i2], 2, ".") - names[i2].kind = "identifier" - owner = { y = names[i2].y, x = names[i2].x, kind = "op", op = dot, e1 = owner, e2 = names[i2] } + local selfx, selfy = ps.tokens[i].x, ps.tokens[i].y + i = parse_function_args_rets_body(ps, i, fn) + if fn.is_method and fn.args then + table.insert(fn.args, 1, { f = ps.filename, x = selfx, y = selfy, tk = "self", kind = "identifier", is_self = true }) + fn.min_arity = fn.min_arity + 1 end - fn.fn_owner = owner - end - fn.name = names[#names] - local selfx, selfy = ps.tokens[i].x, ps.tokens[i].y - i = parse_function_args_rets_body(ps, i, fn) - if fn.is_method then - table.insert(fn.args, 1, { x = selfx, y = selfy, tk = "self", kind = "identifier", is_self = true }) - fn.min_arity = fn.min_arity + 1 - end + if not fn.name then + return orig_i + 1 + end - if not fn.name then - return orig_i + 1 - end + if fn.kind == "record_function" and fk == "global" then + fail(ps, orig_i, "record functions cannot be annotated as 'global'") + elseif fn.kind == "global_function" and fk == "record" then + fn.implicit_global_function = true + end - if fn.kind == "record_function" and fk == "global" then - fail(ps, orig_i, "record functions cannot be annotated as 'global'") - elseif fn.kind == "global_function" and fk == "record" then - fn.implicit_global_function = true + return i, fn end - return i, fn -end - -local function parse_if_block(ps, i, n, node, is_else) - local block = new_node(ps.tokens, i, "if_block") - i = i + 1 - block.if_parent = node - block.if_block_n = n - if not is_else then - i, block.exp = parse_expression_and_tk(ps, i, "then") - if not block.exp then + local function parse_if_block(ps, i, n, node, is_else) + local block = new_node(ps, i, "if_block") + i = i + 1 + block.if_parent = node + block.if_block_n = n + if not is_else then + i, block.exp = parse_expression_and_tk(ps, i, "then") + if not block.exp then + return i + end + end + i, block.body = parse_statements(ps, i) + if not block.body then return i end + end_at(block.body, ps.tokens[i - 1]) + block.yend, block.xend = block.body.yend, block.body.xend + table.insert(node.if_blocks, block) + return i, node end - i, block.body = parse_statements(ps, i) - if not block.body then - return i - end - end_at(block.body, ps.tokens[i - 1]) - block.yend, block.xend = block.body.yend, block.body.xend - table.insert(node.if_blocks, block) - return i, node -end -local function parse_if(ps, i) - local istart = i - local node = new_node(ps.tokens, i, "if") - node.if_blocks = {} - i, node = parse_if_block(ps, i, 1, node) - if not node then - return i - end - local n = 2 - while ps.tokens[i].tk == "elseif" do - i, node = parse_if_block(ps, i, n, node) + local function parse_if(ps, i) + local istart = i + local node = new_node(ps, i, "if") + node.if_blocks = {} + i, node = parse_if_block(ps, i, 1, node) if not node then return i end - n = n + 1 + local n = 2 + while ps.tokens[i].tk == "elseif" do + i, node = parse_if_block(ps, i, n, node) + if not node then + return i + end + n = n + 1 + end + if ps.tokens[i].tk == "else" then + i, node = parse_if_block(ps, i, n, node, true) + if not node then + return i + end + end + i = verify_end(ps, i, istart, node) + return i, node end - if ps.tokens[i].tk == "else" then - i, node = parse_if_block(ps, i, n, node, true) - if not node then - return i + + local function parse_while(ps, i) + local istart = i + local node = new_node(ps, i, "while") + i = verify_tk(ps, i, "while") + i, node.exp = parse_expression_and_tk(ps, i, "do") + i, node.body = parse_statements(ps, i) + i = verify_end(ps, i, istart, node) + return i, node + end + + local function parse_fornum(ps, i) + local istart = i + local node = new_node(ps, i, "fornum") + i = i + 1 + i, node.var = parse_identifier(ps, i) + i = verify_tk(ps, i, "=") + i, node.from = parse_expression_and_tk(ps, i, ",") + i, node.to = parse_expression(ps, i) + if ps.tokens[i].tk == "," then + i = i + 1 + i, node.step = parse_expression_and_tk(ps, i, "do") + else + i = verify_tk(ps, i, "do") end + i, node.body = parse_statements(ps, i) + i = verify_end(ps, i, istart, node) + return i, node end - i = verify_end(ps, i, istart, node) - return i, node -end - -local function parse_while(ps, i) - local istart = i - local node = new_node(ps.tokens, i, "while") - i = verify_tk(ps, i, "while") - i, node.exp = parse_expression_and_tk(ps, i, "do") - i, node.body = parse_statements(ps, i) - i = verify_end(ps, i, istart, node) - return i, node -end -local function parse_fornum(ps, i) - local istart = i - local node = new_node(ps.tokens, i, "fornum") - i = i + 1 - i, node.var = parse_identifier(ps, i) - i = verify_tk(ps, i, "=") - i, node.from = parse_expression_and_tk(ps, i, ",") - i, node.to = parse_expression(ps, i) - if ps.tokens[i].tk == "," then + local function parse_forin(ps, i) + local istart = i + local node = new_node(ps, i, "forin") i = i + 1 - i, node.step = parse_expression_and_tk(ps, i, "do") - else + node.vars = new_node(ps, i, "variable_list") + i, node.vars = parse_list(ps, i, node.vars, { ["in"] = true }, "sep", parse_identifier) + i = verify_tk(ps, i, "in") + node.exps = new_node(ps, i, "expression_list") + i = parse_list(ps, i, node.exps, { ["do"] = true }, "sep", parse_expression) + if #node.exps < 1 then + return fail(ps, i, "missing iterator expression in generic for") + elseif #node.exps > 3 then + return fail(ps, i, "too many expressions in generic for") + end i = verify_tk(ps, i, "do") + i, node.body = parse_statements(ps, i) + i = verify_end(ps, i, istart, node) + return i, node end - i, node.body = parse_statements(ps, i) - i = verify_end(ps, i, istart, node) - return i, node -end - -local function parse_forin(ps, i) - local istart = i - local node = new_node(ps.tokens, i, "forin") - i = i + 1 - node.vars = new_node(ps.tokens, i, "variable_list") - i, node.vars = parse_list(ps, i, node.vars, { ["in"] = true }, "sep", parse_identifier) - i = verify_tk(ps, i, "in") - node.exps = new_node(ps.tokens, i, "expression_list") - i = parse_list(ps, i, node.exps, { ["do"] = true }, "sep", parse_expression) - if #node.exps < 1 then - return fail(ps, i, "missing iterator expression in generic for") - elseif #node.exps > 3 then - return fail(ps, i, "too many expressions in generic for") - end - i = verify_tk(ps, i, "do") - i, node.body = parse_statements(ps, i) - i = verify_end(ps, i, istart, node) - return i, node -end -local function parse_for(ps, i) - if ps.tokens[i + 1].kind == "identifier" and ps.tokens[i + 2].tk == "=" then - return parse_fornum(ps, i) - else - return parse_forin(ps, i) + local function parse_for(ps, i) + if ps.tokens[i + 1].kind == "identifier" and ps.tokens[i + 2].tk == "=" then + return parse_fornum(ps, i) + else + return parse_forin(ps, i) + end end -end -local function parse_repeat(ps, i) - local node = new_node(ps.tokens, i, "repeat") - i = verify_tk(ps, i, "repeat") - i, node.body = parse_statements(ps, i) - node.body.is_repeat = true - i = verify_tk(ps, i, "until") - i, node.exp = parse_expression(ps, i) - end_at(node, ps.tokens[i - 1]) - return i, node -end + local function parse_repeat(ps, i) + local node = new_node(ps, i, "repeat") + i = verify_tk(ps, i, "repeat") + i, node.body = parse_statements(ps, i) + node.body.is_repeat = true + i = verify_tk(ps, i, "until") + i, node.exp = parse_expression(ps, i) + end_at(node, ps.tokens[i - 1]) + return i, node + end -local function parse_do(ps, i) - local istart = i - local node = new_node(ps.tokens, i, "do") - i = verify_tk(ps, i, "do") - i, node.body = parse_statements(ps, i) - i = verify_end(ps, i, istart, node) - return i, node -end + local function parse_do(ps, i) + local istart = i + local node = new_node(ps, i, "do") + i = verify_tk(ps, i, "do") + i, node.body = parse_statements(ps, i) + i = verify_end(ps, i, istart, node) + return i, node + end -local function parse_break(ps, i) - local node = new_node(ps.tokens, i, "break") - i = verify_tk(ps, i, "break") - return i, node -end + local function parse_break(ps, i) + local node = new_node(ps, i, "break") + i = verify_tk(ps, i, "break") + return i, node + end -local function parse_goto(ps, i) - local node = new_node(ps.tokens, i, "goto") - i = verify_tk(ps, i, "goto") - node.label = ps.tokens[i].tk - i = verify_kind(ps, i, "identifier") - return i, node -end + local function parse_goto(ps, i) + local node = new_node(ps, i, "goto") + i = verify_tk(ps, i, "goto") + node.label = ps.tokens[i].tk + i = verify_kind(ps, i, "identifier") + return i, node + end -local function parse_label(ps, i) - local node = new_node(ps.tokens, i, "label") - i = verify_tk(ps, i, "::") - node.label = ps.tokens[i].tk - i = verify_kind(ps, i, "identifier") - i = verify_tk(ps, i, "::") - return i, node -end + local function parse_label(ps, i) + local node = new_node(ps, i, "label") + i = verify_tk(ps, i, "::") + node.label = ps.tokens[i].tk + i = verify_kind(ps, i, "identifier") + i = verify_tk(ps, i, "::") + return i, node + end -local stop_statement_list = { - ["end"] = true, - ["else"] = true, - ["elseif"] = true, - ["until"] = true, -} + local stop_statement_list = { + ["end"] = true, + ["else"] = true, + ["elseif"] = true, + ["until"] = true, + } -local stop_return_list = { - [";"] = true, - ["$EOF$"] = true, -} + local stop_return_list = { + [";"] = true, + ["$EOF$"] = true, + } -for k, v in pairs(stop_statement_list) do - stop_return_list[k] = v -end + for k, v in pairs(stop_statement_list) do + stop_return_list[k] = v + end -local function parse_return(ps, i) - local node = new_node(ps.tokens, i, "return") - i = verify_tk(ps, i, "return") - node.exps = new_node(ps.tokens, i, "expression_list") - i = parse_list(ps, i, node.exps, stop_return_list, "sep", parse_expression) - if ps.tokens[i].kind == ";" then - i = i + 1 + local function parse_return(ps, i) + local node = new_node(ps, i, "return") + i = verify_tk(ps, i, "return") + node.exps = new_node(ps, i, "expression_list") + i = parse_list(ps, i, node.exps, stop_return_list, "sep", parse_expression) + if ps.tokens[i].kind == ";" then + i = i + 1 + end + return i, node end - return i, node -end -local function store_field_in_record(ps, i, field_name, t, fields, field_order) - if not fields[field_name] then - fields[field_name] = t - table.insert(field_order, field_name) - else - local prev_t = fields[field_name] - if t.typename == "function" and prev_t.typename == "function" then - local p = new_type(ps, i, "poly") - p.types = { prev_t, t } - fields[field_name] = p - elseif t.typename == "function" and prev_t.typename == "poly" then - table.insert(prev_t.types, t) + local function store_field_in_record(ps, i, field_name, t, fields, field_order) + if not fields[field_name] then + fields[field_name] = t + table.insert(field_order, field_name) else - fail(ps, i, "attempt to redeclare field '" .. field_name .. "' (only functions can be overloaded)") - return false + local prev_t = fields[field_name] + if t.typename == "function" and prev_t.typename == "function" then + local p = new_type(ps, i, "poly") + p.types = { prev_t, t } + fields[field_name] = p + elseif t.typename == "function" and prev_t.typename == "poly" then + table.insert(prev_t.types, t) + else + fail(ps, i, "attempt to redeclare field '" .. field_name .. "' (only functions can be overloaded)") + return false + end end + return true end - return true -end -local function parse_nested_type(ps, i, def, typename, parse_body) - i = i + 1 - local iv = i + local function parse_nested_type(ps, i, def, typename, parse_body) + i = i + 1 + local iv = i - local v - i, v = verify_kind(ps, i, "identifier", "type_identifier") - if not v then - return fail(ps, i, "expected a variable name") - end + local v + i, v = verify_kind(ps, i, "identifier", "type_identifier") + if not v then + return fail(ps, i, "expected a variable name") + end - local nt = new_node(ps.tokens, i - 2, "newtype") - local ndef = new_type(ps, i, typename) - local iok = parse_body(ps, i, ndef, nt) - if iok then - i = iok - nt.newtype = new_typedecl(ps, i, ndef) - end + local nt = new_node(ps, i - 2, "newtype") + local ndef = new_type(ps, i, typename) + local itype = i + local iok = parse_body(ps, i, ndef, nt) + if iok then + i = iok + nt.newtype = new_typedecl(ps, itype, ndef) + end - store_field_in_record(ps, iv, v.tk, nt.newtype, def.fields, def.field_order) - return i -end + store_field_in_record(ps, iv, v.tk, nt.newtype, def.fields, def.field_order) + return i + end -parse_enum_body = function(ps, i, def, node) - local istart = i - 1 - def.enumset = {} - while ps.tokens[i].tk ~= "$EOF$" and ps.tokens[i].tk ~= "end" do - local item - i, item = verify_kind(ps, i, "string", "enum_item") - if item then - table.insert(node, item) - def.enumset[unquote(item.tk)] = true + parse_enum_body = function(ps, i, def, node) + local istart = i - 1 + def.enumset = {} + while ps.tokens[i].tk ~= "$EOF$" and ps.tokens[i].tk ~= "end" do + local item + i, item = verify_kind(ps, i, "string", "enum_item") + if item then + table.insert(node, item) + def.enumset[unquote(item.tk)] = true + end end + i = verify_end(ps, i, istart, node) + return i, node end - i = verify_end(ps, i, istart, node) - return i, node -end - -local metamethod_names = { - ["__add"] = true, - ["__sub"] = true, - ["__mul"] = true, - ["__div"] = true, - ["__mod"] = true, - ["__pow"] = true, - ["__unm"] = true, - ["__idiv"] = true, - ["__band"] = true, - ["__bor"] = true, - ["__bxor"] = true, - ["__bnot"] = true, - ["__shl"] = true, - ["__shr"] = true, - ["__concat"] = true, - ["__len"] = true, - ["__eq"] = true, - ["__lt"] = true, - ["__le"] = true, - ["__index"] = true, - ["__newindex"] = true, - ["__call"] = true, - ["__tostring"] = true, - ["__pairs"] = true, - ["__gc"] = true, - ["__close"] = true, - ["__is"] = true, -} - -local function parse_macroexp(ps, istart, iargs) + local metamethod_names = { + ["__add"] = true, + ["__sub"] = true, + ["__mul"] = true, + ["__div"] = true, + ["__mod"] = true, + ["__pow"] = true, + ["__unm"] = true, + ["__idiv"] = true, + ["__band"] = true, + ["__bor"] = true, + ["__bxor"] = true, + ["__bnot"] = true, + ["__shl"] = true, + ["__shr"] = true, + ["__concat"] = true, + ["__len"] = true, + ["__eq"] = true, + ["__lt"] = true, + ["__le"] = true, + ["__index"] = true, + ["__newindex"] = true, + ["__call"] = true, + ["__tostring"] = true, + ["__pairs"] = true, + ["__gc"] = true, + ["__close"] = true, + ["__is"] = true, + } + local function parse_macroexp(ps, istart, iargs) - local node = new_node(ps.tokens, istart, "macroexp") - local i - i, node.args, node.min_arity = parse_argument_list(ps, iargs) - i, node.rets = parse_return_types(ps, i) - i = verify_tk(ps, i, "return") - i, node.exp = parse_expression(ps, i) - end_at(node, ps.tokens[i]) - i = verify_end(ps, i, istart, node) - return i, node -end -local function parse_where_clause(ps, i) - local node = new_node(ps.tokens, i, "macroexp") - - local selftype = new_type(ps, i, "nominal") - selftype.names = { "@self" } - - node.args = new_node(ps.tokens, i, "argument_list") - node.args[1] = new_node(ps.tokens, i, "argument") - node.args[1].tk = "self" - node.args[1].argtype = selftype - node.min_arity = 1 - node.rets = new_tuple(ps, i) - node.rets.tuple[1] = BOOLEAN - i, node.exp = parse_expression(ps, i) - end_at(node, ps.tokens[i - 1]) - return i, node -end -parse_interface_name = function(ps, i) - local istart = i - local typ - i, typ = parse_simple_type_or_nominal(ps, i) - if not (typ.typename == "nominal") then - return fail(ps, istart, "expected an interface") + local node = new_node(ps, istart, "macroexp") + local i + i, node.args, node.min_arity = parse_argument_list(ps, iargs) + i, node.rets = parse_return_types(ps, i) + i = verify_tk(ps, i, "return") + i, node.exp = parse_expression(ps, i) + end_at(node, ps.tokens[i]) + i = verify_end(ps, i, istart, node) + return i, node end - return i, typ -end -local function parse_array_interface_type(ps, i, def) - if def.interface_list then - local first = def.interface_list[1] - if first.typename == "array" then - return failskip(ps, i, "duplicated declaration of array element type", parse_type) - end - end - local t - i, t = parse_base_type(ps, i) - if not t then - return i - end - if not (t.typename == "array") then - fail(ps, i, "expected an array declaration") - return i + local function parse_where_clause(ps, i) + local node = new_node(ps, i, "macroexp") + node.args = new_node(ps, i, "argument_list") + node.args[1] = new_node(ps, i, "argument") + node.args[1].tk = "self" + node.args[1].argtype = new_nominal(ps, i, "@self") + node.min_arity = 1 + node.rets = new_tuple(ps, i) + node.rets.tuple[1] = new_type(ps, i, "boolean") + i, node.exp = parse_expression(ps, i) + end_at(node, ps.tokens[i - 1]) + return i, node end - def.elements = t.elements - return i, t -end -parse_record_body = function(ps, i, def, node) - local istart = i - 1 - def.fields = {} - def.field_order = {} - - if ps.tokens[i].tk == "<" then - i, def.typeargs = parse_anglebracket_list(ps, i, parse_typearg) + parse_interface_name = function(ps, i) + local istart = i + local typ + i, typ = parse_simple_type_or_nominal(ps, i) + if not (typ.typename == "nominal") then + return fail(ps, istart, "expected an interface") + end + return i, typ end - if ps.tokens[i].tk == "{" then - local atype - i, atype = parse_array_interface_type(ps, i, def) - if atype then - def.interface_list = { atype } + local function parse_array_interface_type(ps, i, def) + if def.interface_list then + local first = def.interface_list[1] + if first.typename == "array" then + return failskip(ps, i, "duplicated declaration of array element type", parse_type) + end + end + local t + i, t = parse_base_type(ps, i) + if not t then + return i + end + if not (t.typename == "array") then + fail(ps, i, "expected an array declaration") + return i end + def.elements = t.elements + return i, t end - if ps.tokens[i].tk == "is" then - i = i + 1 + parse_record_body = function(ps, i, def, node) + local istart = i - 1 + def.fields = {} + def.field_order = {} + + if ps.tokens[i].tk == "<" then + i, def.typeargs = parse_anglebracket_list(ps, i, parse_typearg) + end if ps.tokens[i].tk == "{" then local atype i, atype = parse_array_interface_type(ps, i, def) - if ps.tokens[i].tk == "," then - i = i + 1 - i, def.interface_list = parse_trying_list(ps, i, {}, parse_interface_name) - else - def.interface_list = {} - end if atype then - table.insert(def.interface_list, 1, atype) + def.interface_list = { atype } end - else - i, def.interface_list = parse_trying_list(ps, i, {}, parse_interface_name) end - end - if ps.tokens[i].tk == "where" then - local wstart = i - i = i + 1 - local where_macroexp - i, where_macroexp = parse_where_clause(ps, i) - - local typ = new_type(ps, wstart, "function") - typ.is_method = true - typ.min_arity = 1 - typ.args = a_type("tuple", { tuple = { - a_type("nominal", { - y = typ.y, - x = typ.x, - filename = ps.filename, - names = { "@self" }, - }), - } }) - typ.rets = a_type("tuple", { tuple = { BOOLEAN } }) - typ.macroexp = where_macroexp - - def.meta_fields = {} - def.meta_field_order = {} - store_field_in_record(ps, i, "__is", typ, def.meta_fields, def.meta_field_order) - end - - while not (ps.tokens[i].kind == "$EOF$" or ps.tokens[i].tk == "end") do - local tn = ps.tokens[i].tk - if ps.tokens[i].tk == "userdata" and ps.tokens[i + 1].tk ~= ":" then - if def.is_userdata then - fail(ps, i, "duplicated 'userdata' declaration") + if ps.tokens[i].tk == "is" then + i = i + 1 + + if ps.tokens[i].tk == "{" then + local atype + i, atype = parse_array_interface_type(ps, i, def) + if ps.tokens[i].tk == "," then + i = i + 1 + i, def.interface_list = parse_trying_list(ps, i, {}, parse_interface_name) + else + def.interface_list = {} + end + if atype then + table.insert(def.interface_list, 1, atype) + end else - def.is_userdata = true + i, def.interface_list = parse_trying_list(ps, i, {}, parse_interface_name) end + end + + if ps.tokens[i].tk == "where" then + local wstart = i i = i + 1 - elseif ps.tokens[i].tk == "{" then - return fail(ps, i, "syntax error: this syntax is no longer valid; declare array interface at the top with 'is {...}'") - elseif ps.tokens[i].tk == "type" and ps.tokens[i + 1].tk ~= ":" then - i = i + 1 - local iv = i - local v - i, v = verify_kind(ps, i, "identifier", "type_identifier") - if not v then - return fail(ps, i, "expected a variable name") - end - i = verify_tk(ps, i, "=") - local nt - i, nt = parse_newtype(ps, i) - if not nt or not nt.newtype then - return fail(ps, i, "expected a type definition") - end + local where_macroexp + i, where_macroexp = parse_where_clause(ps, i) + + local typ = new_type(ps, wstart, "function") + typ.is_method = true + typ.min_arity = 1 + typ.args = new_tuple(ps, wstart, { + a_nominal(where_macroexp, { "@self" }), + }) + typ.rets = new_tuple(ps, wstart, { new_type(ps, wstart, "boolean") }) + typ.macroexp = where_macroexp - local ntt = nt.newtype - if ntt.typename == "typealias" then - ntt.is_nested_alias = true - end + def.meta_fields = {} + def.meta_field_order = {} + store_field_in_record(ps, i, "__is", typ, def.meta_fields, def.meta_field_order) + end - store_field_in_record(ps, iv, v.tk, nt.newtype, def.fields, def.field_order) - elseif parse_type_body_fns[tn] and ps.tokens[i + 1].tk ~= ":" then - i = parse_nested_type(ps, i, def, tn, parse_type_body_fns[tn]) - else - local is_metamethod = false - if ps.tokens[i].tk == "metamethod" and ps.tokens[i + 1].tk ~= ":" then - is_metamethod = true + while not (ps.tokens[i].kind == "$EOF$" or ps.tokens[i].tk == "end") do + local tn = ps.tokens[i].tk + if ps.tokens[i].tk == "userdata" and ps.tokens[i + 1].tk ~= ":" then + if def.is_userdata then + fail(ps, i, "duplicated 'userdata' declaration") + else + def.is_userdata = true + end i = i + 1 - end + elseif ps.tokens[i].tk == "{" then + return fail(ps, i, "syntax error: this syntax is no longer valid; declare array interface at the top with 'is {...}'") + elseif ps.tokens[i].tk == "type" and ps.tokens[i + 1].tk ~= ":" then + i = i + 1 + local iv = i + local v + i, v = verify_kind(ps, i, "identifier", "type_identifier") + if not v then + return fail(ps, i, "expected a variable name") + end + i = verify_tk(ps, i, "=") + local nt + i, nt = parse_newtype(ps, i) + if not nt or not nt.newtype then + return fail(ps, i, "expected a type definition") + end - local v - if ps.tokens[i].tk == "[" then - i, v = parse_literal(ps, i + 1) - if v and not v.conststr then - return fail(ps, i, "expected a string literal") + local ntt = nt.newtype + if ntt.typename == "typealias" then + ntt.is_nested_alias = true end - i = verify_tk(ps, i, "]") + + store_field_in_record(ps, iv, v.tk, nt.newtype, def.fields, def.field_order) + elseif parse_type_body_fns[tn] and ps.tokens[i + 1].tk ~= ":" then + i = parse_nested_type(ps, i, def, tn, parse_type_body_fns[tn]) else - i, v = verify_kind(ps, i, "identifier", "variable") - end - local iv = i - if not v then - return fail(ps, i, "expected a variable name") - end + local is_metamethod = false + if ps.tokens[i].tk == "metamethod" and ps.tokens[i + 1].tk ~= ":" then + is_metamethod = true + i = i + 1 + end - if ps.tokens[i].tk == ":" then - i = i + 1 - local t - i, t = parse_type(ps, i) - if not t then - return fail(ps, i, "expected a type") + local v + if ps.tokens[i].tk == "[" then + i, v = parse_literal(ps, i + 1) + if v and not v.conststr then + return fail(ps, i, "expected a string literal") + end + i = verify_tk(ps, i, "]") + else + i, v = verify_kind(ps, i, "identifier", "variable") end + local iv = i + if not v then + return fail(ps, i, "expected a variable name") + end + + if ps.tokens[i].tk == ":" then + i = i + 1 + local t + i, t = parse_type(ps, i) + if not t then + return fail(ps, i, "expected a type") + end - local field_name = v.conststr or v.tk - local fields = def.fields - local field_order = def.field_order - if is_metamethod then - if not def.meta_fields then - def.meta_fields = {} - def.meta_field_order = {} + local field_name = v.conststr or v.tk + local fields = def.fields + local field_order = def.field_order + if is_metamethod then + if not def.meta_fields then + def.meta_fields = {} + def.meta_field_order = {} + end + fields = def.meta_fields + field_order = def.meta_field_order + if not metamethod_names[field_name] then + fail(ps, i - 1, "not a valid metamethod: " .. field_name) + end end - fields = def.meta_fields - field_order = def.meta_field_order - if not metamethod_names[field_name] then - fail(ps, i - 1, "not a valid metamethod: " .. field_name) + + if ps.tokens[i].tk == "=" and ps.tokens[i + 1].tk == "macroexp" then + if not (t.typename == "function") then + fail(ps, i + 1, "macroexp must have a function type") + else + i, t.macroexp = parse_macroexp(ps, i + 1, i + 2) + end end - end - if ps.tokens[i].tk == "=" and ps.tokens[i + 1].tk == "macroexp" then - if not (t.typename == "function") then - fail(ps, i + 1, "macroexp must have a function type") + store_field_in_record(ps, iv, field_name, t, fields, field_order) + elseif ps.tokens[i].tk == "=" then + local next_word = ps.tokens[i + 1].tk + if next_word == "record" or next_word == "enum" then + return fail(ps, i, "syntax error: this syntax is no longer valid; use '" .. next_word .. " " .. v.tk .. "'") + elseif next_word == "functiontype" then + return fail(ps, i, "syntax error: this syntax is no longer valid; use 'type " .. v.tk .. " = function('...") else - i, t.macroexp = parse_macroexp(ps, i + 1, i + 2) + return fail(ps, i, "syntax error: this syntax is no longer valid; use 'type " .. v.tk .. " = '...") end - end - - store_field_in_record(ps, iv, field_name, t, fields, field_order) - elseif ps.tokens[i].tk == "=" then - local next_word = ps.tokens[i + 1].tk - if next_word == "record" or next_word == "enum" then - return fail(ps, i, "syntax error: this syntax is no longer valid; use '" .. next_word .. " " .. v.tk .. "'") - elseif next_word == "functiontype" then - return fail(ps, i, "syntax error: this syntax is no longer valid; use 'type " .. v.tk .. " = function('...") else - return fail(ps, i, "syntax error: this syntax is no longer valid; use 'type " .. v.tk .. " = '...") + fail(ps, i, "syntax error: expected ':' for an attribute or '=' for a nested type") end - else - fail(ps, i, "syntax error: expected ':' for an attribute or '=' for a nested type") end end + i = verify_end(ps, i, istart, node) + return i, node end - i = verify_end(ps, i, istart, node) - return i, node -end - -parse_type_body_fns = { - ["interface"] = parse_record_body, - ["record"] = parse_record_body, - ["enum"] = parse_enum_body, -} -parse_newtype = function(ps, i) - local node = new_node(ps.tokens, i, "newtype") - local def - local tn = ps.tokens[i].tk - local itype = i - if parse_type_body_fns[tn] then - def = new_type(ps, i, tn) - i = i + 1 - i = parse_type_body_fns[tn](ps, i, def, node) - if not def then - return fail(ps, i, "expected a type") - end + parse_type_body_fns = { + ["interface"] = parse_record_body, + ["record"] = parse_record_body, + ["enum"] = parse_enum_body, + } - node.newtype = new_typedecl(ps, itype, def) - return i, node - else - i, def = parse_type(ps, i) - if not def then - return fail(ps, i, "expected a type") - end + parse_newtype = function(ps, i) + local node = new_node(ps, i, "newtype") + local def + local tn = ps.tokens[i].tk + local itype = i + if parse_type_body_fns[tn] then + def = new_type(ps, i, tn) + i = i + 1 + i = parse_type_body_fns[tn](ps, i, def, node) + if not def then + return fail(ps, i, "expected a type") + end - if def.typename == "nominal" then - local typealias = new_type(ps, itype, "typealias") - typealias.alias_to = def - node.newtype = typealias - else node.newtype = new_typedecl(ps, itype, def) - end - - return i, node - end -end + return i, node + else + i, def = parse_type(ps, i) + if not def then + return fail(ps, i, "expected a type") + end -local function parse_assignment_expression_list(ps, i, asgn) - asgn.exps = new_node(ps.tokens, i, "expression_list") - repeat - i = i + 1 - local val - i, val = parse_expression(ps, i) - if not val then - if #asgn.exps == 0 then - asgn.exps = nil + if def.typename == "nominal" then + node.newtype = new_typealias(ps, itype, def) + else + node.newtype = new_typedecl(ps, itype, def) end - return i - end - table.insert(asgn.exps, val) - until ps.tokens[i].tk ~= "," - return i, asgn -end -local parse_call_or_assignment -do - local function is_lvalue(node) - node.is_lvalue = node.kind == "variable" or - (node.kind == "op" and - (node.op.op == "@index" or node.op.op == ".")) - return node.is_lvalue + return i, node + end end - local function parse_variable(ps, i) - local node - i, node = parse_expression(ps, i) - if not (node and is_lvalue(node)) then - return fail(ps, i, "expected a variable") - end - return i, node + local function parse_assignment_expression_list(ps, i, asgn) + asgn.exps = new_node(ps, i, "expression_list") + repeat + i = i + 1 + local val + i, val = parse_expression(ps, i) + if not val then + if #asgn.exps == 0 then + asgn.exps = nil + end + return i + end + table.insert(asgn.exps, val) + until ps.tokens[i].tk ~= "," + return i, asgn end - parse_call_or_assignment = function(ps, i) - local exp - local istart = i - i, exp = parse_expression(ps, i) - if not exp then - return i + local parse_call_or_assignment + do + local function is_lvalue(node) + node.is_lvalue = node.kind == "variable" or + (node.kind == "op" and + (node.op.op == "@index" or node.op.op == ".")) + return node.is_lvalue end - if (exp.op and exp.op.op == "@funcall") or exp.failstore then - return i, exp + local function parse_variable(ps, i) + local node + i, node = parse_expression(ps, i) + if not (node and is_lvalue(node)) then + return fail(ps, i, "expected a variable") + end + return i, node end - if not is_lvalue(exp) then - return fail(ps, i, "syntax error") - end + parse_call_or_assignment = function(ps, i) + local exp + local istart = i + i, exp = parse_expression(ps, i) + if not exp then + return i + end - local asgn = new_node(ps.tokens, istart, "assignment") - asgn.vars = new_node(ps.tokens, istart, "variable_list") - asgn.vars[1] = exp - if ps.tokens[i].tk == "," then - i = i + 1 - i = parse_trying_list(ps, i, asgn.vars, parse_variable) - if #asgn.vars < 2 then - return fail(ps, i, "syntax error") + if (exp.op and exp.op.op == "@funcall") or exp.failstore then + return i, exp end - end - if ps.tokens[i].tk ~= "=" then - verify_tk(ps, i, "=") - return i - end + if not is_lvalue(exp) then + return fail(ps, i, "syntax error") + end - i, asgn = parse_assignment_expression_list(ps, i, asgn) - return i, asgn - end -end + local asgn = new_node(ps, istart, "assignment") + asgn.vars = new_node(ps, istart, "variable_list") + asgn.vars[1] = exp + if ps.tokens[i].tk == "," then + i = i + 1 + i = parse_trying_list(ps, i, asgn.vars, parse_variable) + if #asgn.vars < 2 then + return fail(ps, i, "syntax error") + end + end -local function parse_variable_declarations(ps, i, node_name) - local asgn = new_node(ps.tokens, i, node_name) + if ps.tokens[i].tk ~= "=" then + verify_tk(ps, i, "=") + return i + end - asgn.vars = new_node(ps.tokens, i, "variable_list") - i = parse_trying_list(ps, i, asgn.vars, parse_variable_name) - if #asgn.vars == 0 then - return fail(ps, i, "expected a local variable definition") + i, asgn = parse_assignment_expression_list(ps, i, asgn) + return i, asgn + end end - i, asgn.decltuple = parse_type_list(ps, i, "decltuple") - - if ps.tokens[i].tk == "=" then + local function parse_variable_declarations(ps, i, node_name) + local asgn = new_node(ps, i, node_name) - local next_word = ps.tokens[i + 1].tk - local tn = next_word - if parse_type_body_fns[tn] then - local scope = node_name == "local_declaration" and "local" or "global" - return failskip(ps, i + 1, "syntax error: this syntax is no longer valid; use '" .. scope .. " " .. next_word .. " " .. asgn.vars[1].tk .. "'", skip_type_body) - elseif next_word == "functiontype" then - local scope = node_name == "local_declaration" and "local" or "global" - return failskip(ps, i + 1, "syntax error: this syntax is no longer valid; use '" .. scope .. " type " .. asgn.vars[1].tk .. " = function('...", parse_function_type) + asgn.vars = new_node(ps, i, "variable_list") + i = parse_trying_list(ps, i, asgn.vars, parse_variable_name) + if #asgn.vars == 0 then + return fail(ps, i, "expected a local variable definition") end - i, asgn = parse_assignment_expression_list(ps, i, asgn) - end - return i, asgn -end + i, asgn.decltuple = parse_type_list(ps, i, "decltuple") -local function parse_type_declaration(ps, i, node_name) - i = i + 2 + if ps.tokens[i].tk == "=" then - local asgn = new_node(ps.tokens, i, node_name) - i, asgn.var = parse_variable_name(ps, i) - if not asgn.var then - return fail(ps, i, "expected a type name") - end + local next_word = ps.tokens[i + 1].tk + local tn = next_word + if parse_type_body_fns[tn] then + local scope = node_name == "local_declaration" and "local" or "global" + return failskip(ps, i + 1, "syntax error: this syntax is no longer valid; use '" .. scope .. " " .. next_word .. " " .. asgn.vars[1].tk .. "'", skip_type_body) + elseif next_word == "functiontype" then + local scope = node_name == "local_declaration" and "local" or "global" + return failskip(ps, i + 1, "syntax error: this syntax is no longer valid; use '" .. scope .. " type " .. asgn.vars[1].tk .. " = function('...", parse_function_type) + end - if node_name == "global_type" and ps.tokens[i].tk ~= "=" then + i, asgn = parse_assignment_expression_list(ps, i, asgn) + end return i, asgn end - i = verify_tk(ps, i, "=") + local function parse_type_declaration(ps, i, node_name) + i = i + 2 - if ps.tokens[i].kind == "identifier" and ps.tokens[i].tk == "require" then - local istart = i - i, asgn.value = parse_call_or_assignment(ps, i) - if asgn.value and not node_is_require_call(asgn.value) then - fail(ps, istart, "require() for type declarations must have a literal argument") + local asgn = new_node(ps, i, node_name) + i, asgn.var = parse_variable_name(ps, i) + if not asgn.var then + return fail(ps, i, "expected a type name") end - return i, asgn - end - i, asgn.value = parse_newtype(ps, i) - if not asgn.value then - return i - end + if node_name == "global_type" and ps.tokens[i].tk ~= "=" then + return i, asgn + end + + i = verify_tk(ps, i, "=") - local nt = asgn.value.newtype - if nt.typename == "typedecl" then - local def = nt.def - if def.fields or def.typename == "enum" then - if not def.declname then - def.declname = asgn.var.tk + if ps.tokens[i].kind == "identifier" and ps.tokens[i].tk == "require" then + local istart = i + i, asgn.value = parse_call_or_assignment(ps, i) + if asgn.value and not node_is_require_call(asgn.value) then + fail(ps, istart, "require() for type declarations must have a literal argument") end + return i, asgn end - end - - return i, asgn -end -local function parse_type_constructor(ps, i, node_name, type_name, parse_body) - local asgn = new_node(ps.tokens, i, node_name) - local nt = new_node(ps.tokens, i, "newtype") - asgn.value = nt - local itype = i - local def = new_type(ps, i, type_name) + i, asgn.value = parse_newtype(ps, i) + if not asgn.value then + return i + end - i = i + 2 + local nt = asgn.value.newtype + if nt.typename == "typedecl" then + local def = nt.def + if def.fields or def.typename == "enum" then + if not def.declname then + def.declname = asgn.var.tk + end + end + end - i, asgn.var = verify_kind(ps, i, "identifier") - if not asgn.var then - return fail(ps, i, "expected a type name") + return i, asgn end - assert(def.typename == "record" or def.typename == "interface" or def.typename == "enum") - def.declname = asgn.var.tk + local function parse_type_constructor(ps, i, node_name, type_name, parse_body) + local asgn = new_node(ps, i, node_name) + local nt = new_node(ps, i, "newtype") + asgn.value = nt + local itype = i + local def = new_type(ps, i, type_name) - i = parse_body(ps, i, def, nt) + i = i + 2 - nt.newtype = new_typedecl(ps, itype, def) + i, asgn.var = verify_kind(ps, i, "identifier") + if not asgn.var then + return fail(ps, i, "expected a type name") + end - return i, asgn -end + assert(def.typename == "record" or def.typename == "interface" or def.typename == "enum") + def.declname = asgn.var.tk -local function skip_type_declaration(ps, i) - return parse_type_declaration(ps, i - 1, "local_type") -end + i = parse_body(ps, i, def, nt) -local function parse_local_macroexp(ps, i) - local istart = i - i = i + 2 - local node = new_node(ps.tokens, i, "local_macroexp") - i, node.name = parse_identifier(ps, i) - i, node.macrodef = parse_macroexp(ps, istart, i) - end_at(node, ps.tokens[i - 1]) - return i, node -end + nt.newtype = new_typedecl(ps, itype, def) -local function parse_local(ps, i) - local ntk = ps.tokens[i + 1].tk - local tn = ntk - if ntk == "function" then - return parse_local_function(ps, i) - elseif ntk == "type" and ps.tokens[i + 2].kind == "identifier" then - return parse_type_declaration(ps, i, "local_type") - elseif ntk == "macroexp" and ps.tokens[i + 2].kind == "identifier" then - return parse_local_macroexp(ps, i) - elseif parse_type_body_fns[tn] and ps.tokens[i + 2].kind == "identifier" then - return parse_type_constructor(ps, i, "local_type", tn, parse_type_body_fns[tn]) - end - return parse_variable_declarations(ps, i + 1, "local_declaration") -end + return i, asgn + end -local function parse_global(ps, i) - local ntk = ps.tokens[i + 1].tk - local tn = ntk - if ntk == "function" then - return parse_function(ps, i + 1, "global") - elseif ntk == "type" and ps.tokens[i + 2].kind == "identifier" then - return parse_type_declaration(ps, i, "global_type") - elseif parse_type_body_fns[tn] and ps.tokens[i + 2].kind == "identifier" then - return parse_type_constructor(ps, i, "global_type", tn, parse_type_body_fns[tn]) - elseif ps.tokens[i + 1].kind == "identifier" then - return parse_variable_declarations(ps, i + 1, "global_declaration") - end - return parse_call_or_assignment(ps, i) -end + local function skip_type_declaration(ps, i) + return parse_type_declaration(ps, i - 1, "local_type") + end -local function parse_record_function(ps, i) - return parse_function(ps, i, "record") -end + local function parse_local_macroexp(ps, i) + local istart = i + i = i + 2 + local node = new_node(ps, i, "local_macroexp") + i, node.name = parse_identifier(ps, i) + i, node.macrodef = parse_macroexp(ps, istart, i) + end_at(node, ps.tokens[i - 1]) + return i, node + end -local parse_statement_fns = { - ["::"] = parse_label, - ["do"] = parse_do, - ["if"] = parse_if, - ["for"] = parse_for, - ["goto"] = parse_goto, - ["local"] = parse_local, - ["while"] = parse_while, - ["break"] = parse_break, - ["global"] = parse_global, - ["repeat"] = parse_repeat, - ["return"] = parse_return, - ["function"] = parse_record_function, -} + local function parse_local(ps, i) + local ntk = ps.tokens[i + 1].tk + local tn = ntk + if ntk == "function" then + return parse_local_function(ps, i) + elseif ntk == "type" and ps.tokens[i + 2].kind == "identifier" then + return parse_type_declaration(ps, i, "local_type") + elseif ntk == "macroexp" and ps.tokens[i + 2].kind == "identifier" then + return parse_local_macroexp(ps, i) + elseif parse_type_body_fns[tn] and ps.tokens[i + 2].kind == "identifier" then + return parse_type_constructor(ps, i, "local_type", tn, parse_type_body_fns[tn]) + end + return parse_variable_declarations(ps, i + 1, "local_declaration") + end + + local function parse_global(ps, i) + local ntk = ps.tokens[i + 1].tk + local tn = ntk + if ntk == "function" then + return parse_function(ps, i + 1, "global") + elseif ntk == "type" and ps.tokens[i + 2].kind == "identifier" then + return parse_type_declaration(ps, i, "global_type") + elseif parse_type_body_fns[tn] and ps.tokens[i + 2].kind == "identifier" then + return parse_type_constructor(ps, i, "global_type", tn, parse_type_body_fns[tn]) + elseif ps.tokens[i + 1].kind == "identifier" then + return parse_variable_declarations(ps, i + 1, "global_declaration") + end + return parse_call_or_assignment(ps, i) + end + + local function parse_record_function(ps, i) + return parse_function(ps, i, "record") + end + + local parse_statement_fns = { + ["::"] = parse_label, + ["do"] = parse_do, + ["if"] = parse_if, + ["for"] = parse_for, + ["goto"] = parse_goto, + ["local"] = parse_local, + ["while"] = parse_while, + ["break"] = parse_break, + ["global"] = parse_global, + ["repeat"] = parse_repeat, + ["return"] = parse_return, + ["function"] = parse_record_function, + } -local function type_needs_local_or_global(ps, i) - local tk = ps.tokens[i].tk - return failskip(ps, i, ("%s needs to be declared with 'local %s' or 'global %s'"):format(tk, tk, tk), skip_type_body) -end + local function type_needs_local_or_global(ps, i) + local tk = ps.tokens[i].tk + return failskip(ps, i, ("%s needs to be declared with 'local %s' or 'global %s'"):format(tk, tk, tk), skip_type_body) + end -local needs_local_or_global = { - ["type"] = function(ps, i) - return failskip(ps, i, "types need to be declared with 'local type' or 'global type'", skip_type_declaration) - end, - ["record"] = type_needs_local_or_global, - ["enum"] = type_needs_local_or_global, -} + local needs_local_or_global = { + ["type"] = function(ps, i) + return failskip(ps, i, "types need to be declared with 'local type' or 'global type'", skip_type_declaration) + end, + ["record"] = type_needs_local_or_global, + ["enum"] = type_needs_local_or_global, + } -parse_statements = function(ps, i, toplevel) - local node = new_node(ps.tokens, i, "statements") - local item - while true do - while ps.tokens[i].kind == ";" do - i = i + 1 - if item then - item.semicolon = true + parse_statements = function(ps, i, toplevel) + local node = new_node(ps, i, "statements") + local item + while true do + while ps.tokens[i].kind == ";" do + i = i + 1 + if item then + item.semicolon = true + end end - end - if ps.tokens[i].kind == "$EOF$" then - break - end - local tk = ps.tokens[i].tk - if (not toplevel) and stop_statement_list[tk] then - break - end + if ps.tokens[i].kind == "$EOF$" then + break + end + local tk = ps.tokens[i].tk + if (not toplevel) and stop_statement_list[tk] then + break + end - local fn = parse_statement_fns[tk] - if not fn then - local skip_fn = needs_local_or_global[tk] - if skip_fn and ps.tokens[i + 1].kind == "identifier" then - fn = skip_fn - else - fn = parse_call_or_assignment + local fn = parse_statement_fns[tk] + if not fn then + local skip_fn = needs_local_or_global[tk] + if skip_fn and ps.tokens[i + 1].kind == "identifier" then + fn = skip_fn + else + fn = parse_call_or_assignment + end end - end - i, item = fn(ps, i) + i, item = fn(ps, i) - if item then - table.insert(node, item) - elseif i > 1 then + if item then + table.insert(node, item) + elseif i > 1 then - local lasty = ps.tokens[i - 1].y - while ps.tokens[i].kind ~= "$EOF$" and ps.tokens[i].y == lasty do - i = i + 1 + local lasty = ps.tokens[i - 1].y + while ps.tokens[i].kind ~= "$EOF$" and ps.tokens[i].y == lasty do + i = i + 1 + end end end - end - end_at(node, ps.tokens[i]) - return i, node -end - -local function clear_redundant_errors(errors) - local redundant = {} - local lastx, lasty = 0, 0 - for i, err in ipairs(errors) do - err.i = i + end_at(node, ps.tokens[i]) + return i, node end - table.sort(errors, function(a, b) - local af = a.filename or "" - local bf = b.filename or "" - return af < bf or - (af == bf and (a.y < b.y or - (a.y == b.y and (a.x < b.x or - (a.x == b.x and (a.i < b.i)))))) - end) - for i, err in ipairs(errors) do - err.i = nil - if err.x == lastx and err.y == lasty then - table.insert(redundant, i) + + function tl.parse_program(tokens, errs, filename) + errs = errs or {} + local ps = { + tokens = tokens, + errs = errs, + filename = filename or "", + required_modules = {}, + } + local i = 1 + local hashbang + if ps.tokens[i].kind == "hashbang" then + hashbang = ps.tokens[i].tk + i = i + 1 + end + local _, node = parse_statements(ps, i, true) + if hashbang then + node.hashbang = hashbang end - lastx, lasty = err.x, err.y - end - for i = #redundant, 1, -1 do - table.remove(errors, redundant[i]) - end -end -function tl.parse_program(tokens, errs, filename) - errs = errs or {} - local ps = { - tokens = tokens, - errs = errs, - filename = filename or "", - required_modules = {}, - } - local i = 1 - local hashbang - if ps.tokens[i].kind == "hashbang" then - hashbang = ps.tokens[i].tk - i = i + 1 + clear_redundant_errors(errs) + return node, ps.required_modules end - local _, node = parse_statements(ps, i, true) - if hashbang then - node.hashbang = hashbang + + function tl.parse(input, filename) + local tokens, errs = tl.lex(input, filename) + local node, required_modules = tl.parse_program(tokens, errs, filename) + return node, errs, required_modules end - clear_redundant_errors(errs) - return node, ps.required_modules end -function tl.parse(input, filename) - local tokens, errs = tl.lex(input, filename) - local node, required_modules = tl.parse_program(tokens, errs, filename) - return node, errs, required_modules -end + @@ -4307,7 +4306,7 @@ local function tl_debug_indent_pop(mark, single, y, x, fmt, ...) end end -local function recurse_type(ast, visit) +local function recurse_type(s, ast, visit) local kind = ast.typename if TL_DEBUG then @@ -4319,7 +4318,7 @@ local function recurse_type(ast, visit) if cbkind then local cbkind_before = cbkind.before if cbkind_before then - cbkind_before(ast) + cbkind_before(s, ast) end end @@ -4327,90 +4326,90 @@ local function recurse_type(ast, visit) if ast.typename == "tuple" then for i, child in ipairs(ast.tuple) do - xs[i] = recurse_type(child, visit) + xs[i] = recurse_type(s, child, visit) end elseif ast.types then for _, child in ipairs(ast.types) do - table.insert(xs, recurse_type(child, visit)) + table.insert(xs, recurse_type(s, child, visit)) end elseif ast.typename == "map" then - table.insert(xs, recurse_type(ast.keys, visit)) - table.insert(xs, recurse_type(ast.values, visit)) + table.insert(xs, recurse_type(s, ast.keys, visit)) + table.insert(xs, recurse_type(s, ast.values, visit)) elseif ast.fields then if ast.typeargs then for _, child in ipairs(ast.typeargs) do - table.insert(xs, recurse_type(child, visit)) + table.insert(xs, recurse_type(s, child, visit)) end end if ast.interface_list then for _, child in ipairs(ast.interface_list) do - table.insert(xs, recurse_type(child, visit)) + table.insert(xs, recurse_type(s, child, visit)) end end if ast.elements then - table.insert(xs, recurse_type(ast.elements, visit)) + table.insert(xs, recurse_type(s, ast.elements, visit)) end if ast.fields then for _, child in fields_of(ast) do - table.insert(xs, recurse_type(child, visit)) + table.insert(xs, recurse_type(s, child, visit)) end end if ast.meta_fields then for _, child in fields_of(ast, "meta") do - table.insert(xs, recurse_type(child, visit)) + table.insert(xs, recurse_type(s, child, visit)) end end elseif ast.typename == "function" then if ast.typeargs then for _, child in ipairs(ast.typeargs) do - table.insert(xs, recurse_type(child, visit)) + table.insert(xs, recurse_type(s, child, visit)) end end if ast.args then for _, child in ipairs(ast.args.tuple) do - table.insert(xs, recurse_type(child, visit)) + table.insert(xs, recurse_type(s, child, visit)) end end if ast.rets then for _, child in ipairs(ast.rets.tuple) do - table.insert(xs, recurse_type(child, visit)) + table.insert(xs, recurse_type(s, child, visit)) end end elseif ast.typename == "nominal" then if ast.typevals then for _, child in ipairs(ast.typevals) do - table.insert(xs, recurse_type(child, visit)) + table.insert(xs, recurse_type(s, child, visit)) end end elseif ast.typename == "typearg" then if ast.constraint then - table.insert(xs, recurse_type(ast.constraint, visit)) + table.insert(xs, recurse_type(s, ast.constraint, visit)) end elseif ast.typename == "array" then if ast.elements then - table.insert(xs, recurse_type(ast.elements, visit)) + table.insert(xs, recurse_type(s, ast.elements, visit)) end elseif ast.typename == "literal_table_item" then if ast.ktype then - table.insert(xs, recurse_type(ast.ktype, visit)) + table.insert(xs, recurse_type(s, ast.ktype, visit)) end if ast.vtype then - table.insert(xs, recurse_type(ast.vtype, visit)) + table.insert(xs, recurse_type(s, ast.vtype, visit)) end elseif ast.typename == "typealias" then - table.insert(xs, recurse_type(ast.alias_to, visit)) + table.insert(xs, recurse_type(s, ast.alias_to, visit)) elseif ast.typename == "typedecl" then - table.insert(xs, recurse_type(ast.def, visit)) + table.insert(xs, recurse_type(s, ast.def, visit)) end local ret local cbkind_after = cbkind and cbkind.after if cbkind_after then - ret = cbkind_after(ast, xs) + ret = cbkind_after(s, ast, xs) end local visit_after = visit.after if visit_after then - ret = visit_after(ast, xs, ret) + ret = visit_after(s, ast, xs, ret) end if TL_DEBUG then @@ -4420,15 +4419,16 @@ local function recurse_type(ast, visit) return ret end -local function recurse_typeargs(ast, visit_type) +local function recurse_typeargs(s, ast, visit_type) if ast.typeargs then for _, typearg in ipairs(ast.typeargs) do - recurse_type(typearg, visit_type) + recurse_type(s, typearg, visit_type) end end end local function extra_callback(name, + s, ast, xs, visit_node) @@ -4438,7 +4438,7 @@ local function extra_callback(name, if not nbs then return end local bs = nbs[name] if not bs then return end - bs(ast, xs) + bs(s, ast, xs) end local no_recurse_node = { @@ -4458,7 +4458,7 @@ local no_recurse_node = { ["type_identifier"] = true, } -local function recurse_node(root, +local function recurse_node(s, root, visit_node, visit_type) if not root then @@ -4477,9 +4477,9 @@ local function recurse_node(root, local function walk_vars_exps(ast, xs) xs[1] = recurse(ast.vars) if ast.decltuple then - xs[2] = recurse_type(ast.decltuple, visit_type) + xs[2] = recurse_type(s, ast.decltuple, visit_type) end - extra_callback("before_exp", ast, xs, visit_node) + extra_callback("before_exp", s, ast, xs, visit_node) if ast.exps then xs[3] = recurse(ast.exps) end @@ -4491,11 +4491,11 @@ local function recurse_node(root, end local function walk_named_function(ast, xs) - recurse_typeargs(ast, visit_type) + recurse_typeargs(s, ast, visit_type) xs[1] = recurse(ast.name) xs[2] = recurse(ast.args) - xs[3] = recurse_type(ast.rets, visit_type) - extra_callback("before_statements", ast, xs, visit_node) + xs[3] = recurse_type(s, ast.rets, visit_type) + extra_callback("before_statements", s, ast, xs, visit_node) xs[4] = recurse(ast.body) end @@ -4508,9 +4508,9 @@ local function recurse_node(root, end xs[2] = p1 if ast.op.arity == 2 then - extra_callback("before_e2", ast, xs, visit_node) + extra_callback("before_e2", s, ast, xs, visit_node) if ast.op.op == "is" or ast.op.op == "as" then - xs[3] = recurse_type(ast.e2.casttype, visit_type) + xs[3] = recurse_type(s, ast.e2.casttype, visit_type) else xs[3] = recurse(ast.e2) end @@ -4528,7 +4528,7 @@ local function recurse_node(root, xs[1] = recurse(ast.key) xs[2] = recurse(ast.value) if ast.itemtype then - xs[3] = recurse_type(ast.itemtype, visit_type) + xs[3] = recurse_type(s, ast.itemtype, visit_type) end end, @@ -4554,13 +4554,13 @@ local function recurse_node(root, if ast.exp then xs[1] = recurse(ast.exp) end - extra_callback("before_statements", ast, xs, visit_node) + extra_callback("before_statements", s, ast, xs, visit_node) xs[2] = recurse(ast.body) end, ["while"] = function(ast, xs) xs[1] = recurse(ast.exp) - extra_callback("before_statements", ast, xs, visit_node) + extra_callback("before_statements", s, ast, xs, visit_node) xs[2] = recurse(ast.body) end, @@ -4570,45 +4570,45 @@ local function recurse_node(root, end, ["macroexp"] = function(ast, xs) - recurse_typeargs(ast, visit_type) + recurse_typeargs(s, ast, visit_type) xs[1] = recurse(ast.args) - xs[2] = recurse_type(ast.rets, visit_type) - extra_callback("before_exp", ast, xs, visit_node) + xs[2] = recurse_type(s, ast.rets, visit_type) + extra_callback("before_exp", s, ast, xs, visit_node) xs[3] = recurse(ast.exp) end, ["function"] = function(ast, xs) - recurse_typeargs(ast, visit_type) + recurse_typeargs(s, ast, visit_type) xs[1] = recurse(ast.args) - xs[2] = recurse_type(ast.rets, visit_type) - extra_callback("before_statements", ast, xs, visit_node) + xs[2] = recurse_type(s, ast.rets, visit_type) + extra_callback("before_statements", s, ast, xs, visit_node) xs[3] = recurse(ast.body) end, ["local_function"] = walk_named_function, ["global_function"] = walk_named_function, ["record_function"] = function(ast, xs) - recurse_typeargs(ast, visit_type) + recurse_typeargs(s, ast, visit_type) xs[1] = recurse(ast.fn_owner) xs[2] = recurse(ast.name) - extra_callback("before_arguments", ast, xs, visit_node) + extra_callback("before_arguments", s, ast, xs, visit_node) xs[3] = recurse(ast.args) - xs[4] = recurse_type(ast.rets, visit_type) - extra_callback("before_statements", ast, xs, visit_node) + xs[4] = recurse_type(s, ast.rets, visit_type) + extra_callback("before_statements", s, ast, xs, visit_node) xs[5] = recurse(ast.body) end, ["local_macroexp"] = function(ast, xs) xs[1] = recurse(ast.name) xs[2] = recurse(ast.macrodef.args) - xs[3] = recurse_type(ast.macrodef.rets, visit_type) - extra_callback("before_exp", ast, xs, visit_node) + xs[3] = recurse_type(s, ast.macrodef.rets, visit_type) + extra_callback("before_exp", s, ast, xs, visit_node) xs[4] = recurse(ast.macrodef.exp) end, ["forin"] = function(ast, xs) xs[1] = recurse(ast.vars) xs[2] = recurse(ast.exps) - extra_callback("before_statements", ast, xs, visit_node) + extra_callback("before_statements", s, ast, xs, visit_node) xs[3] = recurse(ast.body) end, @@ -4617,7 +4617,7 @@ local function recurse_node(root, xs[2] = recurse(ast.from) xs[3] = recurse(ast.to) xs[4] = ast.step and recurse(ast.step) - extra_callback("before_statements", ast, xs, visit_node) + extra_callback("before_statements", s, ast, xs, visit_node) xs[5] = recurse(ast.body) end, @@ -4634,12 +4634,12 @@ local function recurse_node(root, end, ["newtype"] = function(ast, xs) - xs[1] = recurse_type(ast.newtype, visit_type) + xs[1] = recurse_type(s, ast.newtype, visit_type) end, ["argument"] = function(ast, xs) if ast.argtype then - xs[1] = recurse_type(ast.argtype, visit_type) + xs[1] = recurse_type(s, ast.argtype, visit_type) end end, } @@ -4658,7 +4658,7 @@ local function recurse_node(root, local cbkind = cbs and cbs[kind] if cbkind then if cbkind.before then - cbkind.before(ast) + cbkind.before(s, ast) end end @@ -4682,10 +4682,10 @@ local function recurse_node(root, local ret local cbkind_after = cbkind and cbkind.after if cbkind_after then - ret = cbkind_after(ast, xs) + ret = cbkind_after(s, ast, xs) end if visit_after then - ret = visit_after(ast, xs, ret) + ret = visit_after(s, ast, xs, ret) end if TL_DEBUG then @@ -4789,7 +4789,7 @@ function tl.pretty_print_ast(ast, gen_target, mode) local save_indent = {} - local function increment_indent(node) + local function increment_indent(_, node) local child = node.body or node[1] if not child then return @@ -4890,7 +4890,7 @@ function tl.pretty_print_ast(ast, gen_target, mode) ["total"] = " ", } - local function emit_exactly(node, _children) + local function emit_exactly(_, node, _children) local out = { y = node.y, h = 0 } add_string(out, node.tk) return out @@ -4900,7 +4900,7 @@ function tl.pretty_print_ast(ast, gen_target, mode) visit_node.cbs = { ["statements"] = { - after = function(node, children) + after = function(_, node, children) local out if opts.preserve_hashbang and node.hashbang then out = { y = 1, h = 0 } @@ -4922,7 +4922,7 @@ function tl.pretty_print_ast(ast, gen_target, mode) end, }, ["local_declaration"] = { - after = function(node, children) + after = function(_, node, children) local out = { y = node.y, h = 0 } table.insert(out, "local ") for i, var in ipairs(node.vars) do @@ -4948,7 +4948,7 @@ function tl.pretty_print_ast(ast, gen_target, mode) end, }, ["local_type"] = { - after = function(node, children) + after = function(_, node, children) local out = { y = node.y, h = 0 } if not node.var.elide_type then table.insert(out, "local") @@ -4960,7 +4960,7 @@ function tl.pretty_print_ast(ast, gen_target, mode) end, }, ["global_type"] = { - after = function(node, children) + after = function(_, node, children) local out = { y = node.y, h = 0 } if children[2] then add_child(out, children[1]) @@ -4971,7 +4971,7 @@ function tl.pretty_print_ast(ast, gen_target, mode) end, }, ["global_declaration"] = { - after = function(node, children) + after = function(_, node, children) local out = { y = node.y, h = 0 } if children[3] then add_child(out, children[1]) @@ -4982,7 +4982,7 @@ function tl.pretty_print_ast(ast, gen_target, mode) end, }, ["assignment"] = { - after = function(node, children) + after = function(_, node, children) local out = { y = node.y, h = 0 } add_child(out, children[1]) table.insert(out, " =") @@ -4991,7 +4991,7 @@ function tl.pretty_print_ast(ast, gen_target, mode) end, }, ["if"] = { - after = function(node, children) + after = function(_, node, children) local out = { y = node.y, h = 0 } for i, child in ipairs(children) do add_child(out, child, i > 1 and " ", child.y ~= node.y and indent) @@ -5002,7 +5002,7 @@ function tl.pretty_print_ast(ast, gen_target, mode) }, ["if_block"] = { before = increment_indent, - after = function(node, children) + after = function(_, node, children) local out = { y = node.y, h = 0 } if node.if_block_n == 1 then table.insert(out, "if") @@ -5022,7 +5022,7 @@ function tl.pretty_print_ast(ast, gen_target, mode) }, ["while"] = { before = increment_indent, - after = function(node, children) + after = function(_, node, children) local out = { y = node.y, h = 0 } table.insert(out, "while") add_child(out, children[1], " ") @@ -5035,7 +5035,7 @@ function tl.pretty_print_ast(ast, gen_target, mode) }, ["repeat"] = { before = increment_indent, - after = function(node, children) + after = function(_, node, children) local out = { y = node.y, h = 0 } table.insert(out, "repeat") add_child(out, children[1], " ") @@ -5047,7 +5047,7 @@ function tl.pretty_print_ast(ast, gen_target, mode) }, ["do"] = { before = increment_indent, - after = function(node, children) + after = function(_, node, children) local out = { y = node.y, h = 0 } table.insert(out, "do") add_child(out, children[1], " ") @@ -5058,7 +5058,7 @@ function tl.pretty_print_ast(ast, gen_target, mode) }, ["forin"] = { before = increment_indent, - after = function(node, children) + after = function(_, node, children) local out = { y = node.y, h = 0 } table.insert(out, "for") add_child(out, children[1], " ") @@ -5073,7 +5073,7 @@ function tl.pretty_print_ast(ast, gen_target, mode) }, ["fornum"] = { before = increment_indent, - after = function(node, children) + after = function(_, node, children) local out = { y = node.y, h = 0 } table.insert(out, "for") add_child(out, children[1], " ") @@ -5093,7 +5093,7 @@ function tl.pretty_print_ast(ast, gen_target, mode) end, }, ["return"] = { - after = function(node, children) + after = function(_, node, children) local out = { y = node.y, h = 0 } table.insert(out, "return") if #children[1] > 0 then @@ -5103,14 +5103,14 @@ function tl.pretty_print_ast(ast, gen_target, mode) end, }, ["break"] = { - after = function(node, _children) + after = function(_, node, _children) local out = { y = node.y, h = 0 } table.insert(out, "break") return out end, }, ["variable_list"] = { - after = function(node, children) + after = function(_, node, children) local out = { y = node.y, h = 0 } local space for i, child in ipairs(children) do @@ -5125,7 +5125,7 @@ function tl.pretty_print_ast(ast, gen_target, mode) }, ["literal_table"] = { before = increment_indent, - after = function(node, children) + after = function(_, node, children) local out = { y = node.y, h = 0 } if #children == 0 then table.insert(out, "{}") @@ -5145,7 +5145,7 @@ function tl.pretty_print_ast(ast, gen_target, mode) end, }, ["literal_table_item"] = { - after = function(node, children) + after = function(_, node, children) local out = { y = node.y, h = 0 } if node.key_parsed ~= "implicit" then if node.key_parsed == "short" then @@ -5168,13 +5168,13 @@ function tl.pretty_print_ast(ast, gen_target, mode) }, ["local_macroexp"] = { before = increment_indent, - after = function(node, _children) + after = function(_, node, _children) return { y = node.y, h = 0 } end, }, ["local_function"] = { before = increment_indent, - after = function(node, children) + after = function(_, node, children) local out = { y = node.y, h = 0 } table.insert(out, "local function") add_child(out, children[1], " ") @@ -5189,7 +5189,7 @@ function tl.pretty_print_ast(ast, gen_target, mode) }, ["global_function"] = { before = increment_indent, - after = function(node, children) + after = function(_, node, children) local out = { y = node.y, h = 0 } table.insert(out, "function") add_child(out, children[1], " ") @@ -5204,7 +5204,7 @@ function tl.pretty_print_ast(ast, gen_target, mode) }, ["record_function"] = { before = increment_indent, - after = function(node, children) + after = function(_, node, children) local out = { y = node.y, h = 0 } table.insert(out, "function") add_child(out, children[1], " ") @@ -5229,7 +5229,7 @@ function tl.pretty_print_ast(ast, gen_target, mode) }, ["function"] = { before = increment_indent, - after = function(node, children) + after = function(_, node, children) local out = { y = node.y, h = 0 } table.insert(out, "function(") add_child(out, children[1]) @@ -5243,7 +5243,7 @@ function tl.pretty_print_ast(ast, gen_target, mode) ["cast"] = {}, ["paren"] = { - after = function(node, children) + after = function(_, node, children) local out = { y = node.y, h = 0 } table.insert(out, "(") add_child(out, children[1], "", indent) @@ -5252,7 +5252,7 @@ function tl.pretty_print_ast(ast, gen_target, mode) end, }, ["op"] = { - after = function(node, children) + after = function(_, node, children) local out = { y = node.y, h = 0 } if node.op.op == "@funcall" then add_child(out, children[1], "", indent) @@ -5313,7 +5313,7 @@ function tl.pretty_print_ast(ast, gen_target, mode) end, }, ["newtype"] = { - after = function(node, _children) + after = function(_, node, _children) local out = { y = node.y, h = 0 } local nt = node.newtype if nt.typename == "typealias" then @@ -5330,7 +5330,7 @@ function tl.pretty_print_ast(ast, gen_target, mode) end, }, ["goto"] = { - after = function(node, _children) + after = function(_, node, _children) local out = { y = node.y, h = 0 } table.insert(out, "goto ") table.insert(out, node.label) @@ -5338,7 +5338,7 @@ function tl.pretty_print_ast(ast, gen_target, mode) end, }, ["label"] = { - after = function(node, _children) + after = function(_, node, _children) local out = { y = node.y, h = 0 } table.insert(out, "::") table.insert(out, node.label) @@ -5347,7 +5347,7 @@ function tl.pretty_print_ast(ast, gen_target, mode) end, }, ["string"] = { - after = function(node, children) + after = function(_, node, children) @@ -5355,7 +5355,7 @@ function tl.pretty_print_ast(ast, gen_target, mode) if node.tk:sub(1, 1) == "[" or gen_target ~= "5.1" or not node.tk:find("\\", 1, true) then - return emit_exactly(node, children) + return emit_exactly(nil, node, children) end local out = { y = node.y, h = 0 } @@ -5419,7 +5419,7 @@ function tl.pretty_print_ast(ast, gen_target, mode) local visit_type = {} visit_type.cbs = {} local default_type_visitor = { - after = function(typ, _children) + after = function(_, typ, _children) local out = { y = typ.y or -1, h = 0 } local r = typ.typename == "nominal" and typ.resolved or typ local lua_type = primitive[r.typename] or "table" @@ -5457,13 +5457,12 @@ function tl.pretty_print_ast(ast, gen_target, mode) visit_type.cbs["any"] = default_type_visitor visit_type.cbs["unknown"] = default_type_visitor visit_type.cbs["invalid"] = default_type_visitor - visit_type.cbs["unresolved"] = default_type_visitor visit_type.cbs["none"] = default_type_visitor visit_node.cbs["expression_list"] = visit_node.cbs["variable_list"] visit_node.cbs["argument_list"] = visit_node.cbs["variable_list"] - local out = recurse_node(ast, visit_node, visit_type) + local out = recurse_node(nil, ast, visit_node, visit_type) if err then return nil, err end @@ -5513,7 +5512,6 @@ local typename_to_typecode = { ["none"] = tl.typecodes.UNKNOWN, ["tuple"] = tl.typecodes.UNKNOWN, ["literal_table_item"] = tl.typecodes.UNKNOWN, - ["unresolved"] = tl.typecodes.UNKNOWN, ["typedecl"] = tl.typecodes.UNKNOWN, ["typealias"] = tl.typecodes.UNKNOWN, ["*"] = tl.typecodes.UNKNOWN, @@ -5521,8 +5519,8 @@ local typename_to_typecode = { local skip_types = { ["none"] = true, + ["tuple"] = true, ["literal_table_item"] = true, - ["unresolved"] = true, } local function sorted_keys(m) @@ -5545,6 +5543,7 @@ function tl.new_type_reporter() local self = { next_num = 1, typeid_to_num = {}, + typename_to_num = {}, tr = { by_pos = {}, types = {}, @@ -5552,6 +5551,24 @@ function tl.new_type_reporter() globals = {}, }, } + + local names = {} + for name, _ in pairs(simple_types) do + table.insert(names, name) + end + table.sort(names) + + for _, name in ipairs(names) do + local ti = { + t = assert(typename_to_typecode[name]), + str = name, + } + local n = self.next_num + self.typename_to_num[name] = n + self.tr.types[n] = ti + self.next_num = self.next_num + 1 + end + return setmetatable(self, { __index = TypeReporter }) end @@ -5571,9 +5588,15 @@ function TypeReporter:store_function(ti, rt) end function TypeReporter:get_typenum(t) + + local n = self.typename_to_num[t.typename] + if n then + return n + end + assert(t.typeid) - local n = self.typeid_to_num[t.typeid] + n = self.typeid_to_num[t.typeid] if n then return n end @@ -5597,7 +5620,7 @@ function TypeReporter:get_typenum(t) local ti = { t = assert(typename_to_typecode[rt.typename]), str = show_type(t, true), - file = t.filename, + file = t.f, y = t.y, x = t.x, } @@ -5667,7 +5690,7 @@ end function TypeReporter:get_collector(filename) - local tc = { + local collector = { filename = filename, symbol_list = {}, } @@ -5675,10 +5698,10 @@ function TypeReporter:get_collector(filename) local ft = {} self.tr.by_pos[filename] = ft - local symbol_list = tc.symbol_list + local symbol_list = collector.symbol_list local symbol_list_n = 0 - tc.store_type = function(y, x, typ) + collector.store_type = function(y, x, typ) if not typ or skip_types[typ.typename] then return end @@ -5692,12 +5715,12 @@ function TypeReporter:get_collector(filename) yt[x] = self:get_typenum(typ) end - tc.reserve_symbol_list_slot = function(node) + collector.reserve_symbol_list_slot = function(node) symbol_list_n = symbol_list_n + 1 node.symbol_list_slot = symbol_list_n end - tc.add_to_symbol_list = function(node, name, t) + collector.add_to_symbol_list = function(node, name, t) if not node then return end @@ -5711,12 +5734,12 @@ function TypeReporter:get_collector(filename) symbol_list[slot] = { y = node.y, x = node.x, name = name, typ = t } end - tc.begin_symbol_list_scope = function(node) + collector.begin_symbol_list_scope = function(node) symbol_list_n = symbol_list_n + 1 symbol_list[symbol_list_n] = { y = node.y, x = node.x, name = "@{" } end - tc.end_symbol_list_scope = function(node) + collector.end_symbol_list_scope = function(node) if symbol_list[symbol_list_n].name == "@{" then symbol_list[symbol_list_n] = nil symbol_list_n = symbol_list_n - 1 @@ -5726,14 +5749,14 @@ function TypeReporter:get_collector(filename) end end - return tc + return collector end -function TypeReporter:store_result(tc, globals) +function TypeReporter:store_result(collector, globals) local tr = self.tr - local filename = tc.filename - local symbol_list = tc.symbol_list + local filename = collector.filename + local symbol_list = collector.symbol_list tr.by_pos[filename][0] = nil @@ -5809,143 +5832,449 @@ function TypeReporter:get_report() end -function tl.get_types(result) - return result.env.reporter:get_report(), result.env.reporter + + + + +function tl.symbols_in_scope(tr, y, x, filename) + local function find(symbols, at_y, at_x) + local function le(a, b) + return a[1] < b[1] or + (a[1] == b[1] and a[2] <= b[2]) + end + return binary_search(symbols, { at_y, at_x }, le) or 0 + end + + local ret = {} + + local symbols = tr.symbols_by_file[filename] + if not symbols then + return ret + end + + local n = find(symbols, y, x) + + while n >= 1 do + local s = symbols[n] + if s[3] == "@{" then + n = n - 1 + elseif s[3] == "@}" then + n = s[4] + else + ret[s[3]] = s[4] + n = n - 1 + end + end + + return ret +end + + + + + +function Errors.new(filename) + local self = { + errors = {}, + warnings = {}, + unknown_dots = {}, + filename = filename, + } + return setmetatable(self, { __index = Errors }) +end + +local function Err(msg, t1, t2, t3) + if t1 then + local s1, s2, s3 + if t1.typename == "invalid" then + return nil + end + s1 = show_type(t1) + if t2 then + if t2.typename == "invalid" then + return nil + end + s2 = show_type(t2) + end + if t3 then + if t3.typename == "invalid" then + return nil + end + s3 = show_type(t3) + end + msg = msg:format(s1, s2, s3) + return { + msg = msg, + x = t1.x, + y = t1.y, + filename = t1.f, + } + end + + return { + msg = msg, + } +end + +local function insert_error(self, y, x, err) + err.y = assert(y) + err.x = assert(x) + err.filename = self.filename + + if TL_DEBUG then + io.stderr:write("ERROR:" .. err.y .. ":" .. err.x .. ": " .. err.msg .. "\n") + end + + table.insert(self.errors, err) +end + +function Errors:add(w, msg, ...) + local e = Err(msg, ...) + if e then + insert_error(self, w.y, w.x, e) + end +end + +local context_name = { + ["local_declaration"] = "in local declaration", + ["global_declaration"] = "in global declaration", + ["assignment"] = "in assignment", + ["literal_table_item"] = "in table item", +} + +function Errors:get_context(ctx, name) + if not ctx then + return "" + end + local ec = (ctx.kind ~= nil) and ctx.expected_context + local cn = (type(ctx) == "string") and ctx or + (ctx.kind ~= nil) and context_name[ec and ec.kind or ctx.kind] + return (cn and cn .. ": " or "") .. (ec and ec.name and ec.name .. ": " or "") .. (name and name .. ": " or "") +end + +function Errors:add_in_context(w, ctx, msg, ...) + local prefix = self:get_context(ctx) + msg = prefix .. msg + + local e = Err(msg, ...) + if e then + insert_error(self, w.y, w.x, e) + end +end + + +function Errors:collect(errs) + for _, e in ipairs(errs) do + insert_error(self, e.y, e.x, e) + end +end + +function Errors:add_warning(tag, w, fmt, ...) + assert(w.y) + table.insert(self.warnings, { + y = w.y, + x = w.x, + msg = fmt:format(...), + filename = self.filename, + tag = tag, + }) +end + +function Errors:invalid_at(w, msg, ...) + self:add(w, msg, ...) + return a_type(w, "invalid", {}) +end + +function Errors:add_unknown(node, name) + self:add_warning("unknown", node, "unknown variable: %s", name) +end + +function Errors:redeclaration_warning(node, old_var) + if node.tk:sub(1, 1) == "_" then return end + + local var_kind = "variable" + local var_name = node.tk + if node.kind == "local_function" or node.kind == "record_function" then + var_kind = "function" + var_name = node.name.tk + end + + local short_error = "redeclaration of " .. var_kind .. " '%s'" + if old_var and old_var.declared_at then + self:add_warning("redeclaration", node, short_error .. " (originally declared at %d:%d)", var_name, old_var.declared_at.y, old_var.declared_at.x) + else + self:add_warning("redeclaration", node, short_error, var_name) + end +end + +function Errors:unused_warning(name, var) + local prefix = name:sub(1, 1) + if var.declared_at and + var.is_narrowed ~= "narrow" and + prefix ~= "_" and + prefix ~= "@" then + + local t = var.t + self:add_warning( + "unused", + var.declared_at, + "unused %s %s: %s", + var.is_func_arg and "argument" or + t.typename == "function" and "function" or + t.typename == "typedecl" and "type" or + t.typename == "typealias" and "type" or + "variable", + name, + show_type(var.t)) + + end +end + +function Errors:add_prefixing(w, src, prefix, dst) + if not src then + return + end + + for _, err in ipairs(src) do + err.msg = prefix .. err.msg + if w and ( + (err.filename ~= w.f) or + (not err.y) or + (w.y > err.y or (w.y == err.y and w.x > err.x))) then + + err.y = w.y + err.x = w.x + err.filename = w.f + end + + if dst then + table.insert(dst, err) + else + insert_error(self, err.y, err.x, err) + end + end +end + + + + + + + + +local function check_for_unused_vars(scope, is_global) + local vars = scope.vars + if not next(vars) then + return + end + local list + for name, var in pairs(vars) do + local t = var.t + if var.declared_at and not var.used then + if var.used_as_type then + var.declared_at.elide_type = true + else + if (t.typename == "typedecl" or t.typename == "typealias") and not is_global then + var.declared_at.elide_type = true + end + list = list or {} + table.insert(list, { y = var.declared_at.y, x = var.declared_at.x, name = name, var = var }) + end + elseif var.used and (t.typename == "typedecl" or t.typename == "typealias") and var.aliasing then + var.aliasing.used = true + var.aliasing.declared_at.elide_type = false + end + end + if list then + table.sort(list, function(a, b) + return a.y < b.y or (a.y == b.y and a.x < b.x) + end) + end + return list +end + +function Errors:warn_unused_vars(scope, is_global) + local unused = check_for_unused_vars(scope, is_global) + if unused then + for _, u in ipairs(unused) do + self:unused_warning(u.name, u.var) + end + end + + if scope.labels then + for name, node in pairs(scope.labels) do + if not node.used_label then + self:add_warning("unused", node, "unused label ::%s::", name) + end + end + end end +function Errors:add_unknown_dot(node, name) + if not self.unknown_dots[name] then + self.unknown_dots[name] = true + self:add_unknown(node, name) + end +end +function Errors:fail_unresolved_labels(scope) + if scope.pending_labels then + for name, nodes in pairs(scope.pending_labels) do + for _, node in ipairs(nodes) do + self:add(node, "no visible label '" .. name .. "' for goto") + end + end + end +end +function Errors:fail_unresolved_nominals(scope, global_scope) + if global_scope and scope.pending_nominals then + for name, types in pairs(scope.pending_nominals) do + if not global_scope.pending_global_types[name] then + for _, typ in ipairs(types) do + assert(typ.x) + assert(typ.y) + self:add(typ, "unknown type %s", typ) + end + end + end + end +end -local NONE = a_type("none", {}) -local INVALID = a_type("invalid", {}) -local UNKNOWN = a_type("unknown", {}) -local CIRCULAR_REQUIRE = a_type("circular_require", {}) -local FUNCTION = a_fn({ args = va_args({ ANY }), rets = va_args({ ANY }) }) +function Errors:check_redeclared_key(w, ctx, seen_keys, key) + if key ~= nil then + local s = seen_keys[key] + if s then + self:add_in_context(w, ctx, "redeclared key " .. tostring(key) .. " (previously declared at " .. self.filename .. ":" .. s.y .. ":" .. s.x .. ")") + else + seen_keys[key] = w + end + end +end -local XPCALL_MSGH_FUNCTION = a_fn({ args = { ANY }, rets = {} }) local numeric_binop = { ["number"] = { - ["number"] = NUMBER, - ["integer"] = NUMBER, + ["number"] = "number", + ["integer"] = "number", }, ["integer"] = { - ["integer"] = INTEGER, - ["number"] = NUMBER, + ["integer"] = "integer", + ["number"] = "number", }, } local float_binop = { ["number"] = { - ["number"] = NUMBER, - ["integer"] = NUMBER, + ["number"] = "number", + ["integer"] = "number", }, ["integer"] = { - ["integer"] = NUMBER, - ["number"] = NUMBER, + ["integer"] = "number", + ["number"] = "number", }, } local integer_binop = { ["number"] = { - ["number"] = INTEGER, - ["integer"] = INTEGER, + ["number"] = "integer", + ["integer"] = "integer", }, ["integer"] = { - ["integer"] = INTEGER, - ["number"] = INTEGER, + ["integer"] = "integer", + ["number"] = "integer", }, } local relational_binop = { ["number"] = { - ["integer"] = BOOLEAN, - ["number"] = BOOLEAN, + ["integer"] = "boolean", + ["number"] = "boolean", }, ["integer"] = { - ["number"] = BOOLEAN, - ["integer"] = BOOLEAN, + ["number"] = "boolean", + ["integer"] = "boolean", }, ["string"] = { - ["string"] = BOOLEAN, + ["string"] = "boolean", }, ["boolean"] = { - ["boolean"] = BOOLEAN, + ["boolean"] = "boolean", }, } local equality_binop = { ["number"] = { - ["number"] = BOOLEAN, - ["integer"] = BOOLEAN, - ["nil"] = BOOLEAN, + ["number"] = "boolean", + ["integer"] = "boolean", + ["nil"] = "boolean", }, ["integer"] = { - ["number"] = BOOLEAN, - ["integer"] = BOOLEAN, - ["nil"] = BOOLEAN, + ["number"] = "boolean", + ["integer"] = "boolean", + ["nil"] = "boolean", }, ["string"] = { - ["string"] = BOOLEAN, - ["nil"] = BOOLEAN, + ["string"] = "boolean", + ["nil"] = "boolean", }, ["boolean"] = { - ["boolean"] = BOOLEAN, - ["nil"] = BOOLEAN, + ["boolean"] = "boolean", + ["nil"] = "boolean", }, ["record"] = { - ["emptytable"] = BOOLEAN, - ["record"] = BOOLEAN, - ["nil"] = BOOLEAN, + ["emptytable"] = "boolean", + ["record"] = "boolean", + ["nil"] = "boolean", }, ["array"] = { - ["emptytable"] = BOOLEAN, - ["array"] = BOOLEAN, - ["nil"] = BOOLEAN, + ["emptytable"] = "boolean", + ["array"] = "boolean", + ["nil"] = "boolean", }, ["map"] = { - ["emptytable"] = BOOLEAN, - ["map"] = BOOLEAN, - ["nil"] = BOOLEAN, + ["emptytable"] = "boolean", + ["map"] = "boolean", + ["nil"] = "boolean", }, ["thread"] = { - ["thread"] = BOOLEAN, - ["nil"] = BOOLEAN, + ["thread"] = "boolean", + ["nil"] = "boolean", }, } local unop_types = { ["#"] = { - ["string"] = INTEGER, - ["array"] = INTEGER, - ["tupletable"] = INTEGER, - ["map"] = INTEGER, - ["emptytable"] = INTEGER, + ["string"] = "integer", + ["array"] = "integer", + ["tupletable"] = "integer", + ["map"] = "integer", + ["emptytable"] = "integer", }, ["-"] = { - ["number"] = NUMBER, - ["integer"] = INTEGER, + ["number"] = "number", + ["integer"] = "integer", }, ["~"] = { - ["number"] = INTEGER, - ["integer"] = INTEGER, + ["number"] = "integer", + ["integer"] = "integer", }, ["not"] = { - ["string"] = BOOLEAN, - ["number"] = BOOLEAN, - ["integer"] = BOOLEAN, - ["boolean"] = BOOLEAN, - ["record"] = BOOLEAN, - ["array"] = BOOLEAN, - ["tupletable"] = BOOLEAN, - ["map"] = BOOLEAN, - ["emptytable"] = BOOLEAN, - ["thread"] = BOOLEAN, + ["string"] = "boolean", + ["number"] = "boolean", + ["integer"] = "boolean", + ["boolean"] = "boolean", + ["record"] = "boolean", + ["array"] = "boolean", + ["tupletable"] = "boolean", + ["map"] = "boolean", + ["emptytable"] = "boolean", + ["thread"] = "boolean", }, } @@ -5976,67 +6305,66 @@ local binop_types = { [">"] = relational_binop, ["or"] = { ["boolean"] = { - ["boolean"] = BOOLEAN, - ["function"] = FUNCTION, + ["boolean"] = "boolean", }, ["number"] = { - ["integer"] = NUMBER, - ["number"] = NUMBER, - ["boolean"] = BOOLEAN, + ["integer"] = "number", + ["number"] = "number", + ["boolean"] = "boolean", }, ["integer"] = { - ["integer"] = INTEGER, - ["number"] = NUMBER, - ["boolean"] = BOOLEAN, + ["integer"] = "integer", + ["number"] = "number", + ["boolean"] = "boolean", }, ["string"] = { - ["string"] = STRING, - ["boolean"] = BOOLEAN, - ["enum"] = STRING, + ["string"] = "string", + ["boolean"] = "boolean", + ["enum"] = "string", }, ["function"] = { - ["boolean"] = BOOLEAN, + ["boolean"] = "boolean", }, ["array"] = { - ["boolean"] = BOOLEAN, + ["boolean"] = "boolean", }, ["record"] = { - ["boolean"] = BOOLEAN, + ["boolean"] = "boolean", }, ["map"] = { - ["boolean"] = BOOLEAN, + ["boolean"] = "boolean", }, ["enum"] = { - ["string"] = STRING, + ["string"] = "string", }, ["thread"] = { - ["boolean"] = BOOLEAN, + ["boolean"] = "boolean", }, }, [".."] = { ["string"] = { - ["string"] = STRING, - ["enum"] = STRING, - ["number"] = STRING, - ["integer"] = STRING, + ["string"] = "string", + ["enum"] = "string", + ["number"] = "string", + ["integer"] = "string", }, ["number"] = { - ["integer"] = STRING, - ["number"] = STRING, - ["string"] = STRING, - ["enum"] = STRING, + ["integer"] = "string", + ["number"] = "string", + ["string"] = "string", + ["enum"] = "string", }, ["integer"] = { - ["integer"] = STRING, - ["number"] = STRING, - ["string"] = STRING, - ["enum"] = STRING, + ["integer"] = "string", + ["number"] = "string", + ["string"] = "string", + ["enum"] = "string", }, ["enum"] = { - ["number"] = STRING, - ["integer"] = STRING, - ["string"] = STRING, - ["enum"] = STRING, + ["number"] = "string", + ["integer"] = "string", + ["string"] = "string", + ["enum"] = "string", }, }, } @@ -6244,8 +6572,8 @@ local function show_type_base(t, short, seen) end end -local function inferred_msg(t) - return " (inferred at " .. t.inferred_at.filename .. ":" .. t.inferred_at.y .. ":" .. t.inferred_at.x .. ")" +local function inferred_msg(t, prefix) + return " (" .. (prefix or "") .. "inferred at " .. t.inferred_at.f .. ":" .. t.inferred_at.y .. ":" .. t.inferred_at.x .. ")" end show_type = function(t, short, seen) @@ -6297,28 +6625,29 @@ function tl.search_module(module_name, search_dtl) return nil, nil, tried end -local function require_module(module_name, lax, env) +local function require_module(w, module_name, feat_lax, env) local mod = env.modules[module_name] if mod then - return mod, true + return mod, env.module_filenames[module_name] end local found, fd = tl.search_module(module_name, true) - if found and (lax or found:match("tl$")) then + if found and (feat_lax or found:match("tl$")) then - env.modules[module_name] = a_type("typedecl", { def = CIRCULAR_REQUIRE }) + env.module_filenames[module_name] = found + env.modules[module_name] = a_type(w, "typedecl", { def = a_type(w, "circular_require", {}) }) local found_result, err = tl.process(found, env, fd) assert(found_result, err) env.modules[module_name] = found_result.type - return found_result.type, true + return found_result.type, found elseif fd then fd:close() end - return INVALID, found ~= nil + return a_type(w, "invalid", {}), found end local compat_code_cache = {} @@ -6340,7 +6669,7 @@ local function add_compat_entries(program, used_set, gen_compat) local code = compat_code_cache[name] if not code then code = tl.parse(text, "@internal") - tl.type_check(code, { filename = "", lax = false, gen_compat = "off" }) + tl.type_check(code, "@internal", { feat_lax = "off", gen_compat = "off" }) compat_code_cache[name] = code end for _, c in ipairs(code) do @@ -6379,32 +6708,26 @@ local function add_compat_entries(program, used_set, gen_compat) TL_DEBUG = tl_debug end -local function get_stdlib_compat(lax) - if lax then - return { - ["utf8"] = true, - } - else - return { - ["io"] = true, - ["math"] = true, - ["string"] = true, - ["table"] = true, - ["utf8"] = true, - ["coroutine"] = true, - ["os"] = true, - ["package"] = true, - ["debug"] = true, - ["load"] = true, - ["loadfile"] = true, - ["assert"] = true, - ["pairs"] = true, - ["ipairs"] = true, - ["pcall"] = true, - ["xpcall"] = true, - ["rawlen"] = true, - } - end +local function get_stdlib_compat() + return { + ["io"] = true, + ["math"] = true, + ["string"] = true, + ["table"] = true, + ["utf8"] = true, + ["coroutine"] = true, + ["os"] = true, + ["package"] = true, + ["debug"] = true, + ["load"] = true, + ["loadfile"] = true, + ["assert"] = true, + ["pairs"] = true, + ["ipairs"] = true, + ["pcall"] = true, + ["xpcall"] = true, + ["rawlen"] = true, + } end local bit_operators = { @@ -6415,14 +6738,21 @@ local bit_operators = { ["<<"] = "lshift", } +local function node_at(w, n) + n.f = assert(w.f) + n.x = w.x + n.y = w.y + return n +end + local function convert_node_to_compat_call(node, mod_name, fn_name, e1, e2) node.op.op = "@funcall" node.op.arity = 2 node.op.prec = 100 - node.e1 = { y = node.y, x = node.x, kind = "op", op = an_operator(node, 2, ".") } - node.e1.e1 = { y = node.y, x = node.x, kind = "identifier", tk = mod_name } - node.e1.e2 = { y = node.y, x = node.x, kind = "identifier", tk = fn_name } - node.e2 = { y = node.y, x = node.x, kind = "expression_list" } + node.e1 = node_at(node, { kind = "op", op = an_operator(node, 2, ".") }) + node.e1.e1 = node_at(node, { kind = "identifier", tk = mod_name }) + node.e1.e2 = node_at(node, { kind = "identifier", tk = fn_name }) + node.e2 = node_at(node, { kind = "expression_list" }) node.e2[1] = e1 node.e2[2] = e2 end @@ -6431,36 +6761,17 @@ local function convert_node_to_compat_mt_call(node, mt_name, which_self, e1, e2) node.op.op = "@funcall" node.op.arity = 2 node.op.prec = 100 - node.e1 = { y = node.y, x = node.x, kind = "identifier", tk = "_tl_mt" } - node.e2 = { y = node.y, x = node.x, kind = "expression_list" } - node.e2[1] = { y = node.y, x = node.x, kind = "string", tk = "\"" .. mt_name .. "\"" } - node.e2[2] = { y = node.y, x = node.x, kind = "integer", tk = tostring(which_self) } + node.e1 = node_at(node, { kind = "identifier", tk = "_tl_mt" }) + node.e2 = node_at(node, { kind = "expression_list" }) + node.e2[1] = node_at(node, { kind = "string", tk = "\"" .. mt_name .. "\"" }) + node.e2[2] = node_at(node, { kind = "integer", tk = tostring(which_self) }) node.e2[3] = e1 node.e2[4] = e2 end -local stdlib_globals = nil -local globals_typeid = new_typeid() -local fresh_typevar_ctr = 1 - -local function set_feat(feat, default) - if feat then - return (feat == "on") - else - return default - end -end - -tl.new_env = function(opts) - local env, err = tl.init_env(opts.lax_mode, opts.gen_compat, opts.gen_target, opts.predefined_modules) - if not env then - return nil, err - end - - env.feat_arity = set_feat(opts.feat_arity, true) - - return env -end +local stdlib_globals = nil +local globals_typeid = new_typeid() +local fresh_typevar_ctr = 1 local function assert_no_stdlib_errors(errors, name) if #errors ~= 0 then @@ -6472,46 +6783,31 @@ local function assert_no_stdlib_errors(errors, name) end end -tl.init_env = function(lax, gen_compat, gen_target, predefined) - if gen_compat == true or gen_compat == nil then - gen_compat = "optional" - elseif gen_compat == false then - gen_compat = "off" - end - gen_compat = gen_compat - - if not gen_target then - if _VERSION == "Lua 5.1" or _VERSION == "Lua 5.2" then - gen_target = "5.1" - else - gen_target = "5.3" - end - end - - if gen_target == "5.4" and gen_compat ~= "off" then - return nil, "gen-compat must be explicitly 'off' when gen-target is '5.4'" - end +tl.new_env = function(opts) + opts = opts or {} local env = { modules = {}, + module_filenames = {}, loaded = {}, loaded_order = {}, globals = {}, - gen_compat = gen_compat, - gen_target = gen_target, + defaults = opts.defaults or {}, } + if env.defaults.gen_target == "5.4" and env.defaults.gen_compat ~= "off" then + return nil, "gen-compat must be explicitly 'off' when gen-target is '5.4'" + end + + local w = { f = "@stdlib", x = 1, y = 1 } + if not stdlib_globals then local tl_debug = TL_DEBUG TL_DEBUG = nil local program, syntax_errors = tl.parse(stdlib, "stdlib.d.tl") assert_no_stdlib_errors(syntax_errors, "syntax errors") - - local result = tl.type_check(program, { - filename = "@stdlib", - env = env, - }) + local result = tl.type_check(program, "@stdlib", {}, env) assert_no_stdlib_errors(result.type_errors, "type errors") stdlib_globals = env.globals @@ -6520,21 +6816,20 @@ tl.init_env = function(lax, gen_compat, gen_target, predefined) local math_t = (stdlib_globals["math"].t).def local table_t = (stdlib_globals["table"].t).def - local integer_compat = a_type("integer", { needs_compat = true }) - math_t.fields["maxinteger"] = integer_compat - math_t.fields["mininteger"] = integer_compat + math_t.fields["maxinteger"].needs_compat = true + math_t.fields["mininteger"].needs_compat = true table_t.fields["unpack"].needs_compat = true - stdlib_globals["..."] = { t = a_vararg({ STRING }) } - stdlib_globals["@is_va"] = { t = ANY } + stdlib_globals["..."] = { t = a_vararg(w, { a_type(w, "string", {}) }) } + stdlib_globals["@is_va"] = { t = a_type(w, "any", {}) } env.globals = {} end - local stdlib_compat = get_stdlib_compat(lax) + local stdlib_compat = get_stdlib_compat() for name, var in pairs(stdlib_globals) do env.globals[name] = var var.needs_compat = stdlib_compat[name] @@ -6545,53 +6840,43 @@ tl.init_env = function(lax, gen_compat, gen_target, predefined) end end - if predefined then - for _, name in ipairs(predefined) do - local module_type = require_module(name, lax, env) + if opts.predefined_modules then + for _, name in ipairs(opts.predefined_modules) do + local module_type = require_module(w, name, env.defaults.feat_lax == "on", env) - if module_type == INVALID then + if module_type.typename == "invalid" then return nil, string.format("Error: could not predefine module '%s'", name) end end end - env.feat_arity = true - return env end -tl.type_check = function(ast, opts) - opts = opts or {} - local env = opts.env - if not env then - local err - env, err = tl.init_env(opts.lax, opts.gen_compat, opts.gen_target) - if err then - return nil, err - end - end +do + + + + local TypeChecker = {} + + + + + + + + + + + - local lax = opts.lax - local feat_arity = env.feat_arity - local filename = opts.filename - local st = { env.globals } - local all_needs_compat = {} - local dependencies = {} - local warnings = {} - local errors = {} - local module_type - local tc - if env.report_types then - env.reporter = env.reporter or tl.new_type_reporter() - tc = env.reporter:get_collector(filename or "?") - end @@ -6600,10 +6885,21 @@ tl.type_check = function(ast, opts) - local function find_var(name, use) - for i = #st, 1, -1 do - local scope = st[i] - local var = scope[name] + + + + + + + + + + + + function TypeChecker:find_var(name, use) + for i = #self.st, 1, -1 do + local scope = self.st[i] + local var = scope.vars[name] if var then if use == "lvalue" and var.is_narrowed then if var.narrowed_from then @@ -6612,7 +6908,7 @@ tl.type_check = function(ast, opts) end else if i == 1 and var.needs_compat then - all_needs_compat[name] = true + self.all_needs_compat[name] = true end if use == "use_type" then var.used_as_type = true @@ -6625,10 +6921,10 @@ tl.type_check = function(ast, opts) end end - local function simulate_g() + function TypeChecker:simulate_g() local globals = {} - for k, v in pairs(st[1]) do + for k, v in pairs(self.st[1].vars) do if k:sub(1, 1) ~= "@" then globals[k] = v.t end @@ -6642,100 +6938,60 @@ tl.type_check = function(ast, opts) end - local resolve_typevars + local typevar_resolver - local function fresh_typevar(t) - return a_type("typevar", { + local function fresh_typevar(_, t) + return a_type(t, "typevar", { typevar = (t.typevar:gsub("@.*", "")) .. "@" .. fresh_typevar_ctr, constraint = t.constraint, }) end - local function fresh_typearg(t) - return a_type("typearg", { + local function fresh_typearg(_, t) + return a_type(t, "typearg", { typearg = (t.typearg:gsub("@.*", "")) .. "@" .. fresh_typevar_ctr, constraint = t.constraint, }) end - local function ensure_fresh_typeargs(t) + function TypeChecker:ensure_fresh_typeargs(t) if not t.typeargs then return t end fresh_typevar_ctr = fresh_typevar_ctr + 1 local ok - ok, t = resolve_typevars(t, fresh_typevar, fresh_typearg) + ok, t = typevar_resolver(nil, t, fresh_typevar, fresh_typearg) assert(ok, "Internal Compiler Error: error creating fresh type variables") return t end - local function find_var_type(name, use) - local var = find_var(name, use) + function TypeChecker:find_var_type(name, use) + local var = self:find_var(name, use) if var then local t = var.t if t.typename == "unresolved_typearg" then return nil, nil, t.constraint end - t = ensure_fresh_typeargs(t) + t = self:ensure_fresh_typeargs(t) return t, var.attribute end end - local function Err(where, msg, ...) - local n = select("#", ...) - if n > 0 then - local showt = {} - for i = 1, n do - local t = select(i, ...) - if t then - if t.typename == "invalid" then - return nil - end - showt[i] = show_type(t) - end - end - msg = msg:format(_tl_table_unpack(showt)) - end - local name = where.filename or filename - - if TL_DEBUG then - io.stderr:write("ERROR:" .. (where.y or -1) .. ":" .. (where.x or -1) .. ": " .. msg .. "\n") - end - - return { - y = where.y, - x = where.x, - msg = msg, - filename = name, - } - end - - local function error_at(w, msg, ...) - assert(w.y) - - local e = Err(w, msg, ...) - if e then - table.insert(errors, e) - return true - else - return false - end - end - - local function ensure_not_abstract(where, t) + local function ensure_not_abstract(t) if t.typename == "function" and t.macroexp then - error_at(where, "macroexps are abstract; consider using a concrete function") + return nil, "macroexps are abstract; consider using a concrete function" elseif t.typename == "typedecl" then local def = t.def if def.typename == "interface" then - error_at(where, "interfaces are abstract; consider using a concrete record") + return nil, "interfaces are abstract; consider using a concrete record" end end + return true end - local function find_type(names, accept_typearg) - local typ = find_var_type(names[1], "use_type") + function TypeChecker:find_type(names, accept_typearg) + local typ = self:find_var_type(names[1], "use_type") if not typ then return nil end @@ -6757,7 +7013,7 @@ tl.type_check = function(ast, opts) return nil end - typ = ensure_fresh_typeargs(typ) + typ = self:ensure_fresh_typeargs(typ) if typ.typename == "nominal" and typ.found then typ = typ.found end @@ -6769,19 +7025,19 @@ tl.type_check = function(ast, opts) end end - local function union_type(t) + local function type_for_union(t) if t.typename == "typedecl" then - return union_type(t.def), t.def + return type_for_union(t.def), t.def elseif t.typename == "typealias" then - return union_type(t.alias_to), t.alias_to + return type_for_union(t.alias_to), t.alias_to elseif t.typename == "tuple" then - return union_type(t.tuple[1]), t.tuple[1] + return type_for_union(t.tuple[1]), t.tuple[1] elseif t.typename == "nominal" then local typedecl = t.found if not typedecl then return "invalid" end - return union_type(typedecl) + return type_for_union(typedecl) elseif t.fields then if t.is_userdata then return "userdata", t @@ -6805,7 +7061,7 @@ tl.type_check = function(ast, opts) local n_string_enum = 0 local has_primitive_string_type = false for _, t in ipairs(typ.types) do - local ut, rt = union_type(t) + local ut, rt = type_for_union(t) if ut == "userdata" then assert(rt.fields) if rt.meta_fields and rt.meta_fields["__is"] then @@ -6886,24 +7142,11 @@ tl.type_check = function(ast, opts) ["unknown"] = true, } - local function default_resolve_typevars_callback(t) - local rt = find_var_type(t.typevar) - if not rt then - return nil - elseif rt.typename == "string" then - - return STRING - end - return rt - end - - resolve_typevars = function(typ, fn_var, fn_arg) + typevar_resolver = function(self, typ, fn_var, fn_arg) local errs local seen = {} local resolved = {} - fn_var = fn_var or default_resolve_typevars_callback - local function resolve(t, all_same) local same = true @@ -6918,7 +7161,7 @@ tl.type_check = function(ast, opts) local orig_t = t if t.typename == "typevar" then - local rt = fn_var(t) + local rt = fn_var(self, t) if rt then resolved[t.typevar] = true if no_nested_types[rt.typename] or (rt.typename == "nominal" and not rt.typevals) then @@ -6934,7 +7177,7 @@ tl.type_check = function(ast, opts) seen[orig_t] = copy copy.typename = t.typename - copy.filename = t.filename + copy.f = t.f copy.x = t.x copy.y = t.y @@ -6945,7 +7188,7 @@ tl.type_check = function(ast, opts) elseif t.typename == "typearg" then if fn_arg then - copy = fn_arg(t) + copy = fn_arg(self, t) else assert(copy.typename == "typearg") copy.typearg = t.typearg @@ -7038,7 +7281,7 @@ tl.type_check = function(ast, opts) local _, err = is_valid_union(copy) if err then errs = errs or {} - table.insert(errs, Err(t, err, copy)) + table.insert(errs, Err(err, copy)) end elseif t.typename == "poly" then assert(copy.typename == "poly") @@ -7048,6 +7291,7 @@ tl.type_check = function(ast, opts) end elseif t.typename == "tupletable" then assert(copy.typename == "tupletable") + copy.inferred_at = t.inferred_at copy.types = {} for i, tf in ipairs(t.types) do copy.types[i], same = resolve(tf, same) @@ -7067,7 +7311,7 @@ tl.type_check = function(ast, opts) local copy, same = resolve(typ, true) if errs then - return false, INVALID, errs + return false, a_type(typ, "invalid", {}), errs end if (not same) and @@ -7086,144 +7330,72 @@ tl.type_check = function(ast, opts) return true, copy end - local function infer_emptytable(emptytable, fresh_t) - local is_global = (emptytable.declared_at and emptytable.declared_at.kind == "global_declaration") - local nst = is_global and 1 or #st - for i = nst, 1, -1 do - local scope = st[i] - if scope[emptytable.assigned_to] then - scope[emptytable.assigned_to] = { t = fresh_t } - end - end - end + local function resolve_typevar(tc, t) + local rt = tc:find_var_type(t.typevar) + if not rt then + return nil + elseif rt.typename == "string" then - local function resolve_tuple(t) - if t.typename == "tuple" then - t = t.tuple[1] - end - if t == nil then - return NIL + return a_type(rt, "string", {}) end - return t - end - - local function add_warning(tag, where, fmt, ...) - table.insert(warnings, { - y = where.y, - x = where.x, - msg = fmt:format(...), - filename = where.filename or filename, - tag = tag, - }) - end - - local function invalid_at(where, msg, ...) - error_at(where, msg, ...) - return INVALID - end - - local function add_unknown(node, name) - add_warning("unknown", node, "unknown variable: %s", name) + return rt end - local function redeclaration_warning(node, old_var) - if node.tk:sub(1, 1) == "_" then return end - local var_kind = "variable" - local var_name = node.tk - if node.kind == "local_function" or node.kind == "record_function" then - var_kind = "function" - var_name = node.name.tk - end - local short_error = "redeclaration of " .. var_kind .. " '%s'" - if old_var and old_var.declared_at then - add_warning("redeclaration", node, short_error .. " (originally declared at %d:%d)", var_name, old_var.declared_at.y, old_var.declared_at.x) - else - add_warning("redeclaration", node, short_error, var_name) + function TypeChecker:infer_emptytable(emptytable, fresh_t) + local is_global = (emptytable.declared_at and emptytable.declared_at.kind == "global_declaration") + local nst = is_global and 1 or #self.st + for i = nst, 1, -1 do + local scope = self.st[i] + if scope.vars[emptytable.assigned_to] then + scope.vars[emptytable.assigned_to] = { t = fresh_t } + end end end - local function check_if_redeclaration(new_name, at) - local old = find_var(new_name, "check_only") - if old then - redeclaration_warning(at, old) + local function resolve_tuple(t) + local rt = t + if rt.typename == "tuple" then + rt = rt.tuple[1] end - end - - local function unused_warning(name, var) - local prefix = name:sub(1, 1) - if var.declared_at and - var.is_narrowed ~= "narrow" and - prefix ~= "_" and - prefix ~= "@" then - - if name:sub(1, 2) == "::" then - add_warning("unused", var.declared_at, "unused label %s", name) - else - local t = var.t - add_warning( - "unused", - var.declared_at, - "unused %s %s: %s", - var.is_func_arg and "argument" or - t.typename == "function" and "function" or - t.typename == "typedecl" and "type" or - t.typename == "typealias" and "type" or - "variable", - name, - show_type(var.t)) - - end + if rt == nil then + return a_type(t, "nil", {}) end + return rt end - local function add_errs_prefixing(where, src, dst, prefix) - assert(where == nil or where.y ~= nil) - - if not src then - return - end - for _, err in ipairs(src) do - err.msg = prefix .. err.msg - - if where and ( - (err.filename ~= filename) or - (not err.y) or - (where.y > err.y or (where.y == err.y and where.x > err.x))) then - - err.y = where.y - err.x = where.x - err.filename = filename - end - table.insert(dst, err) + function TypeChecker:check_if_redeclaration(new_name, at) + local old = self:find_var(new_name, "check_only") + if old then + self.errs:redeclaration_warning(at, old) end end + local function type_at(w, t) t.x = w.x t.y = w.y - t.filename = filename return t end - local function resolve_typevars_at(where, t) - assert(where) - local ok, ret, errs = resolve_typevars(t) + function TypeChecker:resolve_typevars_at(w, t) + assert(w) + local ok, ret, errs = typevar_resolver(self, t, resolve_typevar) if not ok then - assert(where.y) - add_errs_prefixing(where, errs, errors, "") + assert(w.y) + self.errs:add_prefixing(w, errs, "") end if ret == t or t.typename == "typevar" then ret = shallow_copy_table(ret) end - return type_at(where, ret) + return type_at(w, ret) end - local function infer_at(where, t) - local ret = resolve_typevars_at(where, t) + function TypeChecker:infer_at(w, t) + local ret = self:resolve_typevars_at(w, t) if ret.typename == "invalid" then ret = t end @@ -7231,8 +7403,8 @@ tl.type_check = function(ast, opts) if ret == t or t.typename == "typevar" then ret = shallow_copy_table(ret) end - ret.inferred_at = where - ret.inferred_at.filename = filename + assert(w.f) + ret.inferred_at = w return ret end @@ -7245,12 +7417,9 @@ tl.type_check = function(ast, opts) return t end - local get_unresolved - local find_unresolved - - local function add_to_scope(node, name, t, attribute, narrow, dont_check_redeclaration) - local scope = st[#st] - local var = scope[name] + function TypeChecker:add_to_scope(node, name, t, attribute, narrow, dont_check_redeclaration) + local scope = self.st[#self.st] + local var = scope.vars[name] if narrow then if var then if var.is_narrowed then @@ -7263,11 +7432,11 @@ tl.type_check = function(ast, opts) var.t = t else var = { t = t, attribute = attribute, is_narrowed = narrow, declared_at = node } - scope[name] = var + scope.vars[name] = var end - local unresolved = get_unresolved(scope) - unresolved.narrows[name] = true + scope.narrows = scope.narrows or {} + scope.narrows[name] = true return var end @@ -7278,37 +7447,33 @@ tl.type_check = function(ast, opts) name ~= "..." and name:sub(1, 1) ~= "@" then - check_if_redeclaration(name, node) + self:check_if_redeclaration(name, node) end if var and not var.used then - unused_warning(name, var) + self.errs:unused_warning(name, var) end var = { t = t, attribute = attribute, is_narrowed = nil, declared_at = node } - scope[name] = var + scope.vars[name] = var return var end - local function add_var(node, name, t, attribute, narrow, dont_check_redeclaration) - if lax and node and is_unknown(t) and (name ~= "self" and name ~= "...") and not narrow then - add_unknown(node, name) + function TypeChecker:add_var(node, name, t, attribute, narrow, dont_check_redeclaration) + if self.feat_lax and node and is_unknown(t) and (name ~= "self" and name ~= "...") and not narrow then + self.errs:add_unknown(node, name) end if not attribute then t = drop_constant_value(t) end - local var = add_to_scope(node, name, t, attribute, narrow, dont_check_redeclaration) - - if t.typename == "unresolved" or t.typename == "none" then - return var - end + local var = self:add_to_scope(node, name, t, attribute, narrow, dont_check_redeclaration) - if tc and node then - tc.add_to_symbol_list(node, name, t) + if self.collector and node then + self.collector.add_to_symbol_list(node, name, t) end return var @@ -7316,9 +7481,6 @@ tl.type_check = function(ast, opts) - local same_type - local is_a - @@ -7332,38 +7494,38 @@ tl.type_check = function(ast, opts) - local function arg_check(where, all_errs, a, b, v, mode, n) + function TypeChecker:arg_check(w, all_errs, a, b, v, mode, n) local ok, errs if v == "covariant" then - ok, errs = is_a(a, b) + ok, errs = self:is_a(a, b) elseif v == "contravariant" then - ok, errs = is_a(b, a) + ok, errs = self:is_a(b, a) elseif v == "bivariant" then - ok, errs = is_a(a, b) + ok, errs = self:is_a(a, b) if ok then return true end - ok = is_a(b, a) + ok = self:is_a(b, a) if ok then return true end elseif v == "invariant" then - ok, errs = same_type(a, b) + ok, errs = self:same_type(a, b) end if not ok then - add_errs_prefixing(where, errs, all_errs, mode .. (n and " " .. n or "") .. ": ") + self.errs:add_prefixing(w, errs, mode .. (n and " " .. n or "") .. ": ", all_errs) return false end return true end - local function has_all_types_of(t1s, t2s) + function TypeChecker:has_all_types_of(t1s, t2s) for _, t1 in ipairs(t1s) do local found = false for _, t2 in ipairs(t2s) do - if same_type(t2, t1) then + if self:same_type(t2, t1) then found = true break end @@ -7395,8 +7557,8 @@ tl.type_check = function(ast, opts) end end - local function close_types(vars) - for _, var in pairs(vars) do + local function close_types(scope) + for _, var in pairs(scope.vars) do local t = var.t if t.typename == "typedecl" then t.closed = true @@ -7408,161 +7570,96 @@ tl.type_check = function(ast, opts) end end + function TypeChecker:begin_scope(node) + table.insert(self.st, { vars = {} }) - - - - - - - local function check_for_unused_vars(vars, is_global) - if not next(vars) then - return - end - local list = {} - for name, var in pairs(vars) do - local t = var.t - if var.declared_at and not var.used then - if var.used_as_type then - var.declared_at.elide_type = true - else - if (t.typename == "typedecl" or t.typename == "typealias") and not is_global then - var.declared_at.elide_type = true - end - table.insert(list, { y = var.declared_at.y, x = var.declared_at.x, name = name, var = var }) - end - elseif var.used and (t.typename == "typedecl" or t.typename == "typealias") and var.aliasing then - var.aliasing.used = true - var.aliasing.declared_at.elide_type = false - end - end - if list[1] then - table.sort(list, function(a, b) - return a.y < b.y or (a.y == b.y and a.x < b.x) - end) - for _, u in ipairs(list) do - unused_warning(u.name, u.var) - end - end - end - - get_unresolved = function(scope) - local unresolved - if scope then - local unr = scope["@unresolved"] - unresolved = unr and unr.t - else - unresolved = find_var_type("@unresolved") - end - if not unresolved then - unresolved = a_type("unresolved", { - labels = {}, - nominals = {}, - global_types = {}, - narrows = {}, - }) - add_var(nil, "@unresolved", unresolved) - end - return unresolved - end - - find_unresolved = function(level) - local u = st[level or #st]["@unresolved"] - if u then - return u.t - end - end - - local function begin_scope(node) - table.insert(st, {}) - - if tc and node then - tc.begin_symbol_list_scope(node) + if self.collector and node then + self.collector.begin_symbol_list_scope(node) end end - local function end_scope(node) + function TypeChecker:end_scope(node) + local st = self.st local scope = st[#st] - local unresolved = scope["@unresolved"] - if unresolved then - local unrt = unresolved.t - local next_scope = st[#st - 1] - local upper = next_scope["@unresolved"] - if upper then - local uppert = upper.t - for name, nodes in pairs(unrt.labels) do + local next_scope = st[#st - 1] + + if next_scope then + if scope.pending_labels then + next_scope.pending_labels = next_scope.pending_labels or {} + for name, nodes in pairs(scope.pending_labels) do for _, n in ipairs(nodes) do - uppert.labels[name] = uppert.labels[name] or {} - table.insert(uppert.labels[name], n) + next_scope.pending_labels[name] = next_scope.pending_labels[name] or {} + table.insert(next_scope.pending_labels[name], n) end end - for name, types in pairs(unrt.nominals) do + scope.pending_labels = nil + end + if scope.pending_nominals then + next_scope.pending_nominals = next_scope.pending_nominals or {} + for name, types in pairs(scope.pending_nominals) do for _, typ in ipairs(types) do - uppert.nominals[name] = uppert.nominals[name] or {} - table.insert(uppert.nominals[name], typ) + next_scope.pending_nominals[name] = next_scope.pending_nominals[name] or {} + table.insert(next_scope.pending_nominals[name], typ) end end - for name, _ in pairs(unrt.global_types) do - uppert.global_types[name] = true - end - else - next_scope["@unresolved"] = unresolved - unrt.narrows = {} + scope.pending_nominals = nil end end + close_types(scope) - check_for_unused_vars(scope) + self.errs:warn_unused_vars(scope) + table.remove(st) - if tc and node then - tc.end_symbol_list_scope(node) + if self.collector and node then + self.collector.end_symbol_list_scope(node) end end - local end_scope_and_none_type = function(node, _children) - end_scope(node) + + local NONE = a_type({ f = "@none", x = -1, y = -1 }, "none", {}) + + local function end_scope_and_none_type(self, node, _children) + self:end_scope(node) return NONE end - local resolve_nominal - local resolve_typealias do - local function match_typevals(t, def) + local function match_typevals(self, t, def) if t.typevals and def.typeargs then if #t.typevals ~= #def.typeargs then - error_at(t, "mismatch in number of type arguments") + self.errs:add(t, "mismatch in number of type arguments") return nil end - begin_scope() + self:begin_scope() for i, tt in ipairs(t.typevals) do - add_var(nil, def.typeargs[i].typearg, tt) + self:add_var(nil, def.typeargs[i].typearg, tt) end - local ret = resolve_typevars_at(t, def) - end_scope() + local ret = self:resolve_typevars_at(t, def) + self:end_scope() return ret elseif t.typevals then - error_at(t, "spurious type arguments") + self.errs:add(t, "spurious type arguments") return nil elseif def.typeargs then - error_at(t, "missing type arguments in %s", def) + self.errs:add(t, "missing type arguments in %s", def) return nil else return def end end - local function find_nominal_type_decl(t) + local function find_nominal_type_decl(self, t) if t.resolved then return t.resolved end - local found = t.found or find_type(t.names) + local found = t.found or self:find_type(t.names) if not found then - error_at(t, "unknown type %s", t) - return INVALID + return self.errs:invalid_at(t, "unknown type %s", t) end if found.typename == "typealias" then @@ -7570,8 +7667,7 @@ tl.type_check = function(ast, opts) end if not (found.typename == "typedecl") then - error_at(t, table.concat(t.names, ".") .. " is not a type") - return INVALID + return self.errs:invalid_at(t, table.concat(t.names, ".") .. " is not a type") end local def = found.def @@ -7586,44 +7682,35 @@ tl.type_check = function(ast, opts) return nil, found end - local function resolve_decl_into_nominal(t, found) + local function resolve_decl_into_nominal(self, t, found) local def = found.def local resolved if def.typename == "record" or def.typename == "function" then - resolved = match_typevals(t, def) + resolved = match_typevals(self, t, def) if not resolved then - error_at(t, table.concat(t.names, ".") .. " cannot be resolved in scope") - return INVALID + return self.errs:invalid_at(t, table.concat(t.names, ".") .. " cannot be resolved in scope") end else resolved = def end - if not t.filename then - t.filename = resolved.filename - if t.x == nil and t.y == nil then - t.x = resolved.x - t.y = resolved.y - end - end - t.resolved = resolved return resolved end - resolve_nominal = function(t) - local immediate, found = find_nominal_type_decl(t) + function TypeChecker:resolve_nominal(t) + local immediate, found = find_nominal_type_decl(self, t) if immediate then return immediate end - return resolve_decl_into_nominal(t, found) + return resolve_decl_into_nominal(self, t, found) end - resolve_typealias = function(typealias) + function TypeChecker:resolve_typealias(typealias) local t = typealias.alias_to - local immediate, found = find_nominal_type_decl(t) + local immediate, found = find_nominal_type_decl(self, t) if immediate then return immediate end @@ -7632,90 +7719,92 @@ tl.type_check = function(ast, opts) return found end - local resolved = resolve_decl_into_nominal(t, found) + local resolved = resolve_decl_into_nominal(self, t, found) - local typedecl = a_type("typedecl", { def = resolved }) + local typedecl = a_type(typealias, "typedecl", { def = resolved }) t.resolved = typedecl return typedecl end end - local function are_same_unresolved_global_type(t1, t2) - if t1.names[1] == t2.names[1] then - local unresolved = get_unresolved() - if unresolved.global_types[t1.names[1]] then - return true + do + local function are_same_unresolved_global_type(self, t1, t2) + if t1.names[1] == t2.names[1] then + local global_scope = self.st[1] + if global_scope.pending_global_types[t1.names[1]] then + return true + end end + return false end - return false - end - local function fail_nominals(t1, t2) - local t1name = show_type(t1) - local t2name = show_type(t2) - if t1name == t2name then - local t1r = resolve_nominal(t1) - if t1r.filename then - t1name = t1name .. " (defined in " .. t1r.filename .. ":" .. t1r.y .. ")" - end - local t2r = resolve_nominal(t2) - if t2r.filename then - t2name = t2name .. " (defined in " .. t2r.filename .. ":" .. t2r.y .. ")" + local function fail_nominals(self, t1, t2) + local t1name = show_type(t1) + local t2name = show_type(t2) + if t1name == t2name then + self:resolve_nominal(t1) + if t1.found then + t1name = t1name .. " (defined in " .. t1.found.f .. ":" .. t1.found.y .. ")" + end + self:resolve_nominal(t2) + if t2.found then + t2name = t2name .. " (defined in " .. t2.found.f .. ":" .. t2.found.y .. ")" + end end + return false, { Err(t1name .. " is not a " .. t2name) } end - return false, { Err(t1, t1name .. " is not a " .. t2name) } - end - local function are_same_nominals(t1, t2) - local same_names - if t1.found and t2.found then - same_names = t1.found.typeid == t2.found.typeid - else - local ft1 = t1.found or find_type(t1.names) - local ft2 = t2.found or find_type(t2.names) - if ft1 and ft2 then - same_names = ft1.typeid == ft2.typeid + function TypeChecker:are_same_nominals(t1, t2) + local same_names + if t1.found and t2.found then + same_names = t1.found.typeid == t2.found.typeid else - if are_same_unresolved_global_type(t1, t2) then - return true - end + local ft1 = t1.found or self:find_type(t1.names) + local ft2 = t2.found or self:find_type(t2.names) + if ft1 and ft2 then + same_names = ft1.typeid == ft2.typeid + else + if are_same_unresolved_global_type(self, t1, t2) then + return true + end - if not ft1 then - error_at(t1, "unknown type %s", t1) - end - if not ft2 then - error_at(t2, "unknown type %s", t2) + if not ft1 then + self.errs:add(t1, "unknown type %s", t1) + end + if not ft2 then + self.errs:add(t2, "unknown type %s", t2) + end + return false, {} end - return false, {} end - end - if not same_names then - return fail_nominals(t1, t2) - elseif t1.typevals == nil and t2.typevals == nil then - return true - elseif t1.typevals and t2.typevals and #t1.typevals == #t2.typevals then - local errs = {} - for i = 1, #t1.typevals do - local _, typeval_errs = same_type(t1.typevals[i], t2.typevals[i]) - add_errs_prefixing(t1, typeval_errs, errs, "type parameter <" .. show_type(t2.typevals[i]) .. ">: ") + if not same_names then + return fail_nominals(self, t1, t2) + elseif t1.typevals == nil and t2.typevals == nil then + return true + elseif t1.typevals and t2.typevals and #t1.typevals == #t2.typevals then + local errs = {} + for i = 1, #t1.typevals do + local _, typeval_errs = self:same_type(t1.typevals[i], t2.typevals[i]) + self.errs:add_prefixing(nil, typeval_errs, "type parameter <" .. show_type(t2.typevals[i]) .. ">: ", errs) + end + return any_errors(errs) end - return any_errors(errs) + return true end - return true end local is_lua_table_type - local function to_structural(t) + function TypeChecker:to_structural(t) assert(not (t.typename == "tuple")) if t.typename == "nominal" then - return resolve_nominal(t) + return self:resolve_nominal(t) end return t end - local function unite(types, flatten_constants) + local function unite(w, types, flatten_constants) if #types == 1 then return types[1] end @@ -7726,7 +7815,6 @@ tl.type_check = function(ast, opts) local types_seen = {} - types_seen[NIL.typeid] = true types_seen["nil"] = true local i = 1 @@ -7762,14 +7850,14 @@ tl.type_check = function(ast, opts) end end - if types_seen[INVALID.typeid] then - return INVALID + if types_seen["invalid"] then + return a_type(w, "invalid", {}) end if #ts == 1 then return ts[1] else - return a_type("union", { types = ts }) + return a_type(w, "union", { types = ts }) end end @@ -7789,21 +7877,20 @@ tl.type_check = function(ast, opts) end end - local expand_type - local function arraytype_from_tuple(where, tupletype) + function TypeChecker:arraytype_from_tuple(w, tupletype) - local element_type = unite(tupletype.types, true) + local element_type = unite(w, tupletype.types, true) local valid = (not (element_type.typename == "union")) and true or is_valid_union(element_type) if valid then - return a_type("array", { elements = element_type }) + return a_type(w, "array", { elements = element_type }) end - local arr_type = a_type("array", { elements = tupletype.types[1] }) + local arr_type = a_type(w, "array", { elements = tupletype.types[1] }) for i = 2, #tupletype.types do - local expanded = expand_type(where, arr_type, a_type("array", { elements = tupletype.types[i] })) + local expanded = self:expand_type(w, arr_type, a_type(w, "array", { elements = tupletype.types[i] })) if not (expanded.typename == "array") then - return nil, { Err(tupletype, "unable to convert tuple %s to array", tupletype) } + return nil, { Err("unable to convert tuple %s to array", tupletype) } end arr_type = expanded end @@ -7814,33 +7901,33 @@ tl.type_check = function(ast, opts) return t.typename == "nominal" and t.names[1] == "@self" end - local function compare_true(_, _) + local function compare_true(_, _, _) return true end - local function subtype_nominal(a, b) + function TypeChecker:subtype_nominal(a, b) if is_self(a) and is_self(b) then return true end - local ra = a.typename == "nominal" and resolve_nominal(a) or a - local rb = b.typename == "nominal" and resolve_nominal(b) or b - local ok, errs = is_a(ra, rb) + local ra = a.typename == "nominal" and self:resolve_nominal(a) or a + local rb = b.typename == "nominal" and self:resolve_nominal(b) or b + local ok, errs = self:is_a(ra, rb) if errs and #errs == 1 and errs[1].msg:match("^got ") then return false end return ok, errs end - local function subtype_array(a, b) - if (not a.elements) or (not is_a(a.elements, b.elements)) then + function TypeChecker:subtype_array(a, b) + if (not a.elements) or (not self:is_a(a.elements, b.elements)) then return false end if a.consttypes and #a.consttypes > 1 then for _, e in ipairs(a.consttypes) do - if not is_a(e, b.elements) then - return false, { Err(a, "%s is not a member of %s", e, b.elements) } + if not self:is_a(e, b.elements) then + return false, { Err("%s is not a member of %s", e, b.elements) } end end end @@ -7862,16 +7949,16 @@ tl.type_check = function(ast, opts) return nil end - local function subtype_record(a, b) + function TypeChecker:subtype_record(a, b) if a.elements and b.elements then - if not is_a(a.elements, b.elements) then - return false, { Err(a, "array parts have incompatible element types") } + if not self:is_a(a.elements, b.elements) then + return false, { Err("array parts have incompatible element types") } end end if a.is_userdata ~= b.is_userdata then - return false, { Err(a, a.is_userdata and "userdata is not a record" or + return false, { Err(a.is_userdata and "userdata is not a record" or "record is not a userdata"), } end @@ -7880,9 +7967,9 @@ tl.type_check = function(ast, opts) local ak = a.fields[k] local bk = b.fields[k] if bk then - local ok, fielderrs = is_a(ak, bk) + local ok, fielderrs = self:is_a(ak, bk) if not ok then - add_errs_prefixing(nil, fielderrs, errs, "record field doesn't match: " .. k .. ": ") + self.errs:add_prefixing(nil, fielderrs, "record field doesn't match: " .. k .. ": ", errs) end end end @@ -7896,32 +7983,32 @@ tl.type_check = function(ast, opts) return true end - local eqtype_record = function(a, b) + function TypeChecker:eqtype_record(a, b) if (a.elements ~= nil) ~= (b.elements ~= nil) then - return false, { Err(a, "types do not have the same array interface") } + return false, { Err("types do not have the same array interface") } end if a.elements then - local ok, errs = same_type(a.elements, b.elements) + local ok, errs = self:same_type(a.elements, b.elements) if not ok then return ok, errs end end - local ok, errs = subtype_record(a, b) + local ok, errs = self:subtype_record(a, b) if not ok then return ok, errs end - ok, errs = subtype_record(b, a) + ok, errs = self:subtype_record(b, a) if not ok then return ok, errs end return true end - local function compare_map(ak, bk, av, bv, no_hack) - local ok1, errs_k = same_type(ak, bk) - local ok2, errs_v = same_type(av, bv) + local function compare_map(self, ak, bk, av, bv, no_hack) + local ok1, errs_k = self:same_type(ak, bk) + local ok2, errs_v = self:same_type(av, bv) if bk.typename == "any" and not no_hack then @@ -7951,25 +8038,25 @@ tl.type_check = function(ast, opts) return false, errs_k or errs_v end - local function compare_or_infer_typevar(typevar, a, b, cmp) + function TypeChecker:compare_or_infer_typevar(typevar, a, b, cmp) - local vt, _, constraint = find_var_type(typevar) + local vt, _, constraint = self:find_var_type(typevar) if vt then - return cmp(a or vt, b or vt) + return cmp(self, a or vt, b or vt) else local other = a or b if constraint then - if not is_a(other, constraint) then - return false, { Err(other, "given type %s does not satisfy %s constraint in type variable " .. display_typevar(typevar), other, constraint) } + if not self:is_a(other, constraint) then + return false, { Err("given type %s does not satisfy %s constraint in type variable " .. display_typevar(typevar), other, constraint) } end - if same_type(other, constraint) then + if self:same_type(other, constraint) then @@ -7977,22 +8064,22 @@ tl.type_check = function(ast, opts) end end - local ok, r, errs = resolve_typevars(other) + local ok, r, errs = typevar_resolver(self, other, resolve_typevar) if not ok then return false, errs end if r.typename == "typevar" and r.typevar == typevar then return true end - add_var(nil, typevar, r) + self:add_var(nil, typevar, r) return true end end - local function exists_supertype_in(t, xs) + function TypeChecker:exists_supertype_in(t, xs) for _, x in ipairs(xs.types) do - if is_a(t, x) then + if self:is_a(t, x) then return x end end @@ -8003,143 +8090,139 @@ tl.type_check = function(ast, opts) ["array"] = compare_true, ["map"] = compare_true, ["tupletable"] = compare_true, - ["interface"] = function(_a, b) + ["interface"] = function(_self, _a, b) return not b.is_userdata end, - ["record"] = function(_a, b) + ["record"] = function(_self, _a, b) return not b.is_userdata end, } - - - local eqtype_relations - eqtype_relations = { + TypeChecker.eqtype_relations = { ["typevar"] = { - ["typevar"] = function(a, b) + ["typevar"] = function(self, a, b) if a.typevar == b.typevar then return true end - return compare_or_infer_typevar(b.typevar, a, nil, same_type) + return self:compare_or_infer_typevar(b.typevar, a, nil, self.same_type) end, - ["*"] = function(a, b) - return compare_or_infer_typevar(a.typevar, nil, b, same_type) + ["*"] = function(self, a, b) + return self:compare_or_infer_typevar(a.typevar, nil, b, self.same_type) end, }, ["emptytable"] = emptytable_relations, ["tupletable"] = { - ["tupletable"] = function(a, b) + ["tupletable"] = function(self, a, b) for i = 1, math.min(#a.types, #b.types) do - if not same_type(a.types[i], b.types[i]) then - return false, { Err(a, "in tuple entry " .. tostring(i) .. ": got %s, expected %s", a.types[i], b.types[i]) } + if not self:same_type(a.types[i], b.types[i]) then + return false, { Err("in tuple entry " .. tostring(i) .. ": got %s, expected %s", a.types[i], b.types[i]) } end end if #a.types ~= #b.types then - return false, { Err(a, "tuples have different size", a, b) } + return false, { Err("tuples have different size", a, b) } end return true end, }, ["array"] = { - ["array"] = function(a, b) - return same_type(a.elements, b.elements) + ["array"] = function(self, a, b) + return self:same_type(a.elements, b.elements) end, }, ["map"] = { - ["map"] = function(a, b) - return compare_map(a.keys, b.keys, a.values, b.values, true) + ["map"] = function(self, a, b) + return compare_map(self, a.keys, b.keys, a.values, b.values, true) end, }, ["union"] = { - ["union"] = function(a, b) - return (has_all_types_of(a.types, b.types) and - has_all_types_of(b.types, a.types)) + ["union"] = function(self, a, b) + return (self:has_all_types_of(a.types, b.types) and + self:has_all_types_of(b.types, a.types)) end, }, ["nominal"] = { - ["nominal"] = are_same_nominals, + ["nominal"] = TypeChecker.are_same_nominals, }, ["record"] = { - ["record"] = eqtype_record, + ["record"] = TypeChecker.eqtype_record, }, ["interface"] = { - ["interface"] = function(a, b) + ["interface"] = function(_self, a, b) return a.typeid == b.typeid end, }, ["function"] = { - ["function"] = function(a, b) + ["function"] = function(self, a, b) local argdelta = a.is_method and 1 or 0 local naargs, nbargs = #a.args.tuple, #b.args.tuple if naargs ~= nbargs then if (not not a.is_method) ~= (not not b.is_method) then - return false, { Err(a, "different number of input arguments: method and non-method are not the same type") } + return false, { Err("different number of input arguments: method and non-method are not the same type") } end - return false, { Err(a, "different number of input arguments: got " .. naargs - argdelta .. ", expected " .. nbargs - argdelta) } + return false, { Err("different number of input arguments: got " .. naargs - argdelta .. ", expected " .. nbargs - argdelta) } end local narets, nbrets = #a.rets.tuple, #b.rets.tuple if narets ~= nbrets then - return false, { Err(a, "different number of return values: got " .. narets .. ", expected " .. nbrets) } + return false, { Err("different number of return values: got " .. narets .. ", expected " .. nbrets) } end local errs = {} for i = 1, naargs do - arg_check(a, errs, a.args.tuple[i], b.args.tuple[i], "invariant", "argument", i - argdelta) + self:arg_check(a, errs, a.args.tuple[i], b.args.tuple[i], "invariant", "argument", i - argdelta) end for i = 1, narets do - arg_check(a, errs, a.rets.tuple[i], b.rets.tuple[i], "invariant", "return", i) + self:arg_check(a, errs, a.rets.tuple[i], b.rets.tuple[i], "invariant", "return", i) end return any_errors(errs) end, }, ["*"] = { - ["typevar"] = function(a, b) - return compare_or_infer_typevar(b.typevar, a, nil, same_type) + ["typevar"] = function(self, a, b) + return self:compare_or_infer_typevar(b.typevar, a, nil, self.same_type) end, }, } - local subtype_relations - subtype_relations = { + TypeChecker.subtype_relations = { ["tuple"] = { - ["tuple"] = function(a, b) + ["tuple"] = function(self, a, b) local at, bt = a.tuple, b.tuple if #at ~= #bt then return false end for i = 1, #at do - if not is_a(at[i], bt[i]) then + if not self:is_a(at[i], bt[i]) then return false end end return true end, - ["*"] = function(a, b) - return is_a(resolve_tuple(a), b) + ["*"] = function(self, a, b) + return self:is_a(resolve_tuple(a), b) end, }, ["typevar"] = { - ["typevar"] = function(a, b) + ["typevar"] = function(self, a, b) if a.typevar == b.typevar then return true end - return compare_or_infer_typevar(b.typevar, a, nil, is_a) + return self:compare_or_infer_typevar(b.typevar, a, nil, self.is_a) end, - ["*"] = function(a, b) - return compare_or_infer_typevar(a.typevar, nil, b, is_a) + ["*"] = function(self, a, b) + return self:compare_or_infer_typevar(a.typevar, nil, b, self.is_a) end, }, ["nil"] = { ["*"] = compare_true, }, ["union"] = { - ["union"] = function(a, b) + ["union"] = function(self, a, b) local used = {} for _, t in ipairs(a.types) do - begin_scope() - local u = exists_supertype_in(t, b) - end_scope() + self:begin_scope() + local u = self:exists_supertype_in(t, b) + self:end_scope() if not u then return false end @@ -8148,13 +8231,13 @@ tl.type_check = function(ast, opts) end end for u, t in pairs(used) do - is_a(t, u) + self:is_a(t, u) end return true end, - ["*"] = function(a, b) + ["*"] = function(self, a, b) for _, t in ipairs(a.types) do - if not is_a(t, b) then + if not self:is_a(t, b) then return false end end @@ -8162,212 +8245,212 @@ tl.type_check = function(ast, opts) end, }, ["poly"] = { - ["*"] = function(a, b) - if exists_supertype_in(b, a) then + ["*"] = function(self, a, b) + if self:exists_supertype_in(b, a) then return true end - return false, { Err(a, "cannot match against any alternatives of the polymorphic type") } + return false, { Err("cannot match against any alternatives of the polymorphic type") } end, }, ["nominal"] = { - ["nominal"] = function(a, b) - local ok, errs = are_same_nominals(a, b) + ["nominal"] = function(self, a, b) + local ok, errs = self:are_same_nominals(a, b) if ok then return true end - local rb = resolve_nominal(b) + local rb = self:resolve_nominal(b) if rb.typename == "interface" then - return is_a(a, rb) + return self:is_a(a, rb) end - local ra = resolve_nominal(a) + local ra = self:resolve_nominal(a) if ra.typename == "union" or rb.typename == "union" then - return is_a(ra, rb) + return self:is_a(ra, rb) end return ok, errs end, - ["*"] = subtype_nominal, + ["*"] = TypeChecker.subtype_nominal, }, ["enum"] = { ["string"] = compare_true, }, ["string"] = { - ["enum"] = function(a, b) + ["enum"] = function(_self, a, b) if not a.literal then - return false, { Err(a, "string is not a %s", b) } + return false, { Err("%s is not a %s", a, b) } end if b.enumset[a.literal] then return true end - return false, { Err(a, "%s is not a member of %s", a, b) } + return false, { Err("%s is not a member of %s", a, b) } end, }, ["integer"] = { ["number"] = compare_true, }, ["interface"] = { - ["interface"] = function(a, b) - if find_in_interface_list(a, function(t) return (is_a(t, b)) end) then + ["interface"] = function(self, a, b) + if find_in_interface_list(a, function(t) return (self:is_a(t, b)) end) then return true end - return same_type(a, b) + return self:same_type(a, b) end, - ["array"] = subtype_array, - ["record"] = subtype_record, - ["tupletable"] = function(a, b) - return subtype_relations["record"]["tupletable"](a, b) + ["array"] = TypeChecker.subtype_array, + ["record"] = TypeChecker.subtype_record, + ["tupletable"] = function(self, a, b) + return self.subtype_relations["record"]["tupletable"](self, a, b) end, }, ["emptytable"] = emptytable_relations, ["tupletable"] = { - ["tupletable"] = function(a, b) + ["tupletable"] = function(self, a, b) for i = 1, math.min(#a.types, #b.types) do - if not is_a(a.types[i], b.types[i]) then - return false, { Err(a, "in tuple entry " .. + if not self:is_a(a.types[i], b.types[i]) then + return false, { Err("in tuple entry " .. tostring(i) .. ": got %s, expected %s", a.types[i], b.types[i]), } end end if #a.types > #b.types then - return false, { Err(a, "tuple %s is too big for tuple %s", a, b) } + return false, { Err("tuple %s is too big for tuple %s", a, b) } end return true end, - ["record"] = function(a, b) + ["record"] = function(self, a, b) if b.elements then - return subtype_relations["tupletable"]["array"](a, b) + return self.subtype_relations["tupletable"]["array"](self, a, b) end end, - ["array"] = function(a, b) + ["array"] = function(self, a, b) if b.inferred_len and b.inferred_len > #a.types then - return false, { Err(a, "incompatible length, expected maximum length of " .. tostring(#a.types) .. ", got " .. tostring(b.inferred_len)) } + return false, { Err("incompatible length, expected maximum length of " .. tostring(#a.types) .. ", got " .. tostring(b.inferred_len)) } end - local aa, err = arraytype_from_tuple(a.inferred_at, a) + local aa, err = self:arraytype_from_tuple(a.inferred_at or a, a) if not aa then return false, err end - if not is_a(aa, b) then - return false, { Err(a, "got %s (from %s), expected %s", aa, a, b) } + if not self:is_a(aa, b) then + return false, { Err("got %s (from %s), expected %s", aa, a, b) } end return true end, - ["map"] = function(a, b) - local aa = arraytype_from_tuple(a.inferred_at, a) + ["map"] = function(self, a, b) + local aa = self:arraytype_from_tuple(a.inferred_at or a, a) if not aa then - return false, { Err(a, "Unable to convert tuple %s to map", a) } + return false, { Err("Unable to convert tuple %s to map", a) } end - return compare_map(INTEGER, b.keys, aa.elements, b.values) + return compare_map(self, a_type(a, "integer", {}), b.keys, aa.elements, b.values) end, }, ["record"] = { - ["record"] = subtype_record, - ["interface"] = function(a, b) - if find_in_interface_list(a, function(t) return (is_a(t, b)) end) then + ["record"] = TypeChecker.subtype_record, + ["interface"] = function(self, a, b) + if find_in_interface_list(a, function(t) return (self:is_a(t, b)) end) then return true end if not a.declname then - return subtype_record(a, b) + return self:subtype_record(a, b) end end, - ["array"] = subtype_array, - ["map"] = function(a, b) - if not is_a(b.keys, STRING) then - return false, { Err(a, "can't match a record to a map with non-string keys") } + ["array"] = TypeChecker.subtype_array, + ["map"] = function(self, a, b) + if not self:is_a(b.keys, a_type(b, "string", {})) then + return false, { Err("can't match a record to a map with non-string keys") } end for _, k in ipairs(a.field_order) do local bk = b.keys if bk.typename == "enum" and not bk.enumset[k] then - return false, { Err(a, "key is not an enum value: " .. k) } + return false, { Err("key is not an enum value: " .. k) } end - if not is_a(a.fields[k], b.values) then - return false, { Err(a, "record is not a valid map; not all fields have the same type") } + if not self:is_a(a.fields[k], b.values) then + return false, { Err("record is not a valid map; not all fields have the same type") } end end return true end, - ["tupletable"] = function(a, b) + ["tupletable"] = function(self, a, b) if a.elements then - return subtype_relations["array"]["tupletable"](a, b) + return self.subtype_relations["array"]["tupletable"](self, a, b) end end, }, ["array"] = { - ["array"] = subtype_array, - ["record"] = function(a, b) + ["array"] = TypeChecker.subtype_array, + ["record"] = function(self, a, b) if b.elements then - return subtype_array(a, b) + return self:subtype_array(a, b) end end, - ["map"] = function(a, b) - return compare_map(INTEGER, b.keys, a.elements, b.values) + ["map"] = function(self, a, b) + return compare_map(self, a_type(a, "integer", {}), b.keys, a.elements, b.values) end, - ["tupletable"] = function(a, b) + ["tupletable"] = function(self, a, b) local alen = a.inferred_len or 0 if alen > #b.types then - return false, { Err(a, "incompatible length, expected maximum length of " .. tostring(#b.types) .. ", got " .. tostring(alen)) } + return false, { Err("incompatible length, expected maximum length of " .. tostring(#b.types) .. ", got " .. tostring(alen)) } end for i = 1, (alen > 0) and alen or #b.types do - if not is_a(a.elements, b.types[i]) then - return false, { Err(a, "tuple entry " .. i .. " of type %s does not match type of array elements, which is %s", b.types[i], a.elements) } + if not self:is_a(a.elements, b.types[i]) then + return false, { Err("tuple entry " .. i .. " of type %s does not match type of array elements, which is %s", b.types[i], a.elements) } end end return true end, }, ["map"] = { - ["map"] = function(a, b) - return compare_map(a.keys, b.keys, a.values, b.values) + ["map"] = function(self, a, b) + return compare_map(self, a.keys, b.keys, a.values, b.values) end, - ["array"] = function(a, b) - return compare_map(a.keys, INTEGER, a.values, b.elements) + ["array"] = function(self, a, b) + return compare_map(self, a.keys, a_type(b, "integer", {}), a.values, b.elements) end, }, ["typedecl"] = { - ["record"] = function(a, b) + ["record"] = function(self, a, b) local def = a.def if def.fields then - return subtype_record(def, b) + return self:subtype_record(def, b) end end, }, ["function"] = { - ["function"] = function(a, b) + ["function"] = function(self, a, b) local errs = {} local aa, ba = a.args.tuple, b.args.tuple if (not b.args.is_va) and a.min_arity > b.min_arity then - table.insert(errs, Err(a, "incompatible number of arguments: got " .. show_arity(a) .. " %s, expected " .. show_arity(b) .. " %s", a.args, b.args)) + table.insert(errs, Err("incompatible number of arguments: got " .. show_arity(a) .. " %s, expected " .. show_arity(b) .. " %s", a.args, b.args)) else for i = ((a.is_method or b.is_method) and 2 or 1), #aa do - arg_check(nil, errs, aa[i], ba[i] or ba[#ba], "bivariant", "argument", i) + self:arg_check(nil, errs, aa[i], ba[i] or ba[#ba], "bivariant", "argument", i) end end local ar, br = a.rets.tuple, b.rets.tuple local diff_by_va = #br - #ar == 1 and b.rets.is_va if #ar < #br and not diff_by_va then - table.insert(errs, Err(a, "incompatible number of returns: got " .. #ar .. " %s, expected " .. #br .. " %s", a.rets, b.rets)) + table.insert(errs, Err("incompatible number of returns: got " .. #ar .. " %s, expected " .. #br .. " %s", a.rets, b.rets)) else local nrets = #br if diff_by_va then nrets = nrets - 1 end for i = 1, nrets do - arg_check(nil, errs, ar[i], br[i], "bivariant", "return", i) + self:arg_check(nil, errs, ar[i], br[i], "bivariant", "return", i) end end @@ -8375,36 +8458,36 @@ a.types[i], b.types[i]), } end, }, ["typearg"] = { - ["typearg"] = function(a, b) + ["typearg"] = function(_self, a, b) return a.typearg == b.typearg end, - ["*"] = function(a, b) + ["*"] = function(self, a, b) if a.constraint then - return is_a(a.constraint, b) + return self:is_a(a.constraint, b) end end, }, ["*"] = { ["any"] = compare_true, - ["tuple"] = function(a, b) - return is_a(a_type("tuple", { tuple = { a } }), b) + ["tuple"] = function(self, a, b) + return self:is_a(a_type(a, "tuple", { tuple = { a } }), b) end, - ["typevar"] = function(a, b) - return compare_or_infer_typevar(b.typevar, a, nil, is_a) + ["typevar"] = function(self, a, b) + return self:compare_or_infer_typevar(b.typevar, a, nil, self.is_a) end, - ["typearg"] = function(a, b) + ["typearg"] = function(self, a, b) if b.constraint then - return is_a(a, b.constraint) + return self:is_a(a, b.constraint) end end, - ["union"] = exists_supertype_in, + ["union"] = TypeChecker.exists_supertype_in, - ["nominal"] = subtype_nominal, - ["poly"] = function(a, b) + ["nominal"] = TypeChecker.subtype_nominal, + ["poly"] = function(self, a, b) for _, t in ipairs(b.types) do - if not is_a(a, t) then - return false, { Err(a, "cannot match against all alternatives of the polymorphic type") } + if not self:is_a(a, t) then + return false, { Err("cannot match against all alternatives of the polymorphic type") } end end return true @@ -8413,7 +8496,7 @@ a.types[i], b.types[i]), } } - local type_priorities = { + TypeChecker.type_priorities = { ["tuple"] = 2, ["typevar"] = 3, @@ -8442,19 +8525,7 @@ a.types[i], b.types[i]), } ["function"] = 14, } - if lax then - type_priorities["unknown"] = 0 - - subtype_relations["unknown"] = {} - subtype_relations["unknown"]["*"] = compare_true - subtype_relations["*"]["unknown"] = compare_true - - subtype_relations["boolean"] = {} - subtype_relations["boolean"]["boolean"] = compare_true - subtype_relations["*"]["boolean"] = compare_true - end - - local function compare_types(relations, t1, t2) + local function compare_types(self, relations, t1, t2) if t1.typeid == t2.typeid then return true end @@ -8462,8 +8533,8 @@ a.types[i], b.types[i]), } local s1 = relations[t1.typename] local fn = s1 and s1[t2.typename] if not fn then - local p1 = type_priorities[t1.typename] or 999 - local p2 = type_priorities[t2.typename] or 999 + local p1 = self.type_priorities[t1.typename] or 999 + local p2 = self.type_priorities[t2.typename] or 999 fn = (p1 < p2 and (s1 and s1["*"]) or (relations["*"][t2.typename])) end @@ -8472,32 +8543,32 @@ a.types[i], b.types[i]), } if fn == compare_true then return true end - ok, err = fn(t1, t2) + ok, err = fn(self, t1, t2) else ok = t1.typename == t2.typename end if (not ok) and not err then - return false, { Err(t1, "got %s, expected %s", t1, t2) } + return false, { Err("got %s, expected %s", t1, t2) } end return ok, err end - is_a = function(t1, t2) - return compare_types(subtype_relations, t1, t2) + function TypeChecker:is_a(t1, t2) + return compare_types(self, self.subtype_relations, t1, t2) end - same_type = function(t1, t2) + function TypeChecker:same_type(t1, t2) - return compare_types(eqtype_relations, t1, t2) + return compare_types(self, self.eqtype_relations, t1, t2) end if TL_DEBUG then - local orig_is_a = is_a - is_a = function(t1, t2) + local orig_is_a = TypeChecker.is_a + TypeChecker.is_a = function(self, t1, t2) assert(type(t1) == "table") assert(type(t2) == "table") @@ -8507,14 +8578,14 @@ a.types[i], b.types[i]), } return true end - return orig_is_a(t1, t2) + return orig_is_a(self, t1, t2) end end - local function assert_is_a(where, t1, t2, context, name) + function TypeChecker:assert_is_a(w, t1, t2, ctx, name) t1 = resolve_tuple(t1) t2 = resolve_tuple(t2) - if lax and (is_unknown(t1) or is_unknown(t2)) then + if self.feat_lax and (is_unknown(t1) or is_unknown(t2)) then return true end @@ -8522,24 +8593,27 @@ a.types[i], b.types[i]), } if t1.typename == "nil" then return true elseif t2.typename == "unresolved_emptytable_value" then - if is_number_type(t2.emptytable_type.keys) then - infer_emptytable(t2.emptytable_type, infer_at(where, a_type("array", { elements = t1 }))) + local t2keys = t2.emptytable_type.keys + if is_numeric_type(t2keys) then + self:infer_emptytable(t2.emptytable_type, self:infer_at(w, a_type(w, "array", { elements = t1 }))) else - infer_emptytable(t2.emptytable_type, infer_at(where, a_type("map", { keys = t2.emptytable_type.keys, values = t1 }))) + self:infer_emptytable(t2.emptytable_type, self:infer_at(w, a_type(w, "map", { keys = t2keys, values = t1 }))) end return true elseif t2.typename == "emptytable" then if is_lua_table_type(t1) then - infer_emptytable(t2, infer_at(where, t1)) + self:infer_emptytable(t2, self:infer_at(w, t1)) elseif not (t1.typename == "emptytable") then - error_at(where, context .. ": " .. (name and (name .. ": ") or "") .. "assigning %s to a variable declared with {}", t1) + self.errs:add(w, self.errs:get_context(ctx, name) .. "assigning %s to a variable declared with {}", t1) return false end return true end - local ok, match_errs = is_a(t1, t2) - add_errs_prefixing(where, match_errs, errors, context .. ": " .. (name and (name .. ": ") or "")) + local ok, match_errs = self:is_a(t1, t2) + if not ok then + self.errs:add_prefixing(w, match_errs, self.errs:get_context(ctx, name)) + end return ok end @@ -8547,11 +8621,11 @@ a.types[i], b.types[i]), } if t.typename == "invalid" then return false end - if same_type(t, NIL) then + if t.typename == "nil" then return true end if t.typename == "nominal" then - t = resolve_nominal(t) + t = assert(t.resolved) end if t.fields then return t.meta_fields and t.meta_fields["__close"] ~= nil @@ -8569,36 +8643,27 @@ a.types[i], b.types[i]), } return definitely_not_closable_exprs[e.kind] end - local unknown_dots = {} - - local function add_unknown_dot(node, name) - if not unknown_dots[name] then - unknown_dots[name] = true - add_unknown(node, name) - end - end - - local function same_in_all_union_entries(u, check) + function TypeChecker:same_in_all_union_entries(u, check) local t1, f = check(u.types[1]) if not t1 then return nil end for i = 2, #u.types do local t2 = check(u.types[i]) - if not t2 or not same_type(t1, t2) then + if not t2 or not self:same_type(t1, t2) then return nil end end return f or t1 end - local function same_call_mt_in_all_union_entries(u) - return same_in_all_union_entries(u, function(t) - t = to_structural(t) + function TypeChecker:same_call_mt_in_all_union_entries(u) + return self:same_in_all_union_entries(u, function(t) + t = self:to_structural(t) if t.fields then local call_mt = t.meta_fields and t.meta_fields["__call"] if call_mt.typename == "function" then - local args_tuple = a_type("tuple", { tuple = {} }) + local args_tuple = a_type(u, "tuple", { tuple = {} }) for i = 2, #call_mt.args.tuple do table.insert(args_tuple.tuple, call_mt.args.tuple[i]) end @@ -8608,20 +8673,21 @@ a.types[i], b.types[i]), } end) end - local function resolve_for_call(func, args, is_method) + function TypeChecker:resolve_for_call(func, args, is_method) - if lax and is_unknown(func) then - func = a_fn({ args = va_args({ UNKNOWN }), rets = va_args({ UNKNOWN }) }) + if self.feat_lax and is_unknown(func) then + local unk = func + func = a_function(func, { min_arity = 0, args = a_vararg(func, { unk }), rets = a_vararg(func, { unk }) }) end - func = to_structural(func) + func = self:to_structural(func) if func.typename ~= "function" and func.typename ~= "poly" then if func.typename == "union" then - local r = same_call_mt_in_all_union_entries(func) + local r = self:same_call_mt_in_all_union_entries(func) if r then table.insert(args.tuple, 1, func.types[1]) - return to_structural(r), true + return self:to_structural(r), true end end @@ -8635,7 +8701,7 @@ a.types[i], b.types[i]), } if func.fields and func.meta_fields and func.meta_fields["__call"] then table.insert(args.tuple, 1, func) func = func.meta_fields["__call"] - func = to_structural(func) + func = self:to_structural(func) is_method = true end end @@ -8655,7 +8721,7 @@ a.types[i], b.types[i]), } local visit_node = { cbs = { ["variable"] = { - after = function(node, _children) + after = function(_, node, _children) local i = argnames[node.tk] if not i then return nil @@ -8668,7 +8734,7 @@ a.types[i], b.types[i]), } after = on_node, } - return recurse_node(root, visit_node, {}) + return recurse_node(nil, root, visit_node, {}) end local function expand_macroexp(orignode, args, macroexp) @@ -8676,7 +8742,7 @@ a.types[i], b.types[i]), } return { Node, args[i] } end - local on_node = function(node, children, ret) + local on_node = function(_, node, children, ret) local orig = ret and ret[2] or node local out = shallow_copy_table(orig) @@ -8705,12 +8771,12 @@ a.types[i], b.types[i]), } orignode.expanded = p[2] end - local function check_macroexp_arg_use(macroexp) + function TypeChecker:check_macroexp_arg_use(macroexp) local used = {} local on_arg_id = function(node, _i) if used[node.tk] then - error_at(node, "cannot use argument '" .. node.tk .. "' multiple times in macroexp") + self.errs:add(node, "cannot use argument '" .. node.tk .. "' multiple times in macroexp") else used[node.tk] = true end @@ -8733,18 +8799,15 @@ a.types[i], b.types[i]), } orignode.known = saveknown end - - - local type_check_function_call do - local function mark_invalid_typeargs(f) + local function mark_invalid_typeargs(self, f) if f.typeargs then for _, a in ipairs(f.typeargs) do - if not find_var_type(a.typearg) then + if not self:find_var_type(a.typearg) then if a.constraint then - add_var(nil, a.typearg, a.constraint) + self:add_var(nil, a.typearg, a.constraint) else - add_var(nil, a.typearg, lax and UNKNOWN or a_type("unresolvable_typearg", { + self:add_var(nil, a.typearg, self.feat_lax and a_type(a, "unknown", {}) or a_type(a, "unresolvable_typearg", { typearg = a.typearg, })) end @@ -8753,7 +8816,7 @@ a.types[i], b.types[i]), } end end - local function infer_emptytables(where, wheres, xs, ys, delta) + local function infer_emptytables(self, w, wheres, xs, ys, delta) local xt, yt = xs.tuple, ys.tuple local n_xs = #xt local n_ys = #yt @@ -8763,9 +8826,9 @@ a.types[i], b.types[i]), } if x.typename == "emptytable" then local y = yt[i] or (ys.is_va and yt[n_ys]) if y then - local w = wheres and wheres[i + delta] or where - local inferred_y = infer_at(w, y) - infer_emptytable(x, inferred_y) + local iw = wheres and wheres[i + delta] or w + local inferred_y = self:infer_at(iw, y) + self:infer_emptytable(x, inferred_y) xt[i] = inferred_y end end @@ -8775,7 +8838,7 @@ a.types[i], b.types[i]), } local check_args_rets do - local function check_func_type_list(where, wheres, xs, ys, from, delta, v, mode) + local function check_func_type_list(self, w, wheres, xs, ys, from, delta, v, mode) assert(xs.typename == "tuple", xs.typename) assert(ys.typename == "tuple", ys.typename) @@ -8786,11 +8849,11 @@ a.types[i], b.types[i]), } for i = from, math.max(n_xs, n_ys) do local pos = i + delta - local x = xt[i] or (xs.is_va and xt[n_xs]) or NIL + local x = xt[i] or (xs.is_va and xt[n_xs]) or a_type(w, "nil", {}) local y = yt[i] or (ys.is_va and yt[n_ys]) if y then - local w = wheres and wheres[pos] or where - if not arg_check(w, errs, x, y, v, mode, pos) then + local iw = wheres and wheres[pos] or w + if not self:arg_check(iw, errs, x, y, v, mode, pos) then return nil, errs end end @@ -8799,7 +8862,7 @@ a.types[i], b.types[i]), } return true end - check_args_rets = function(where, where_args, f, args, expected_rets, argdelta) + check_args_rets = function(self, w, where_args, f, args, expected_rets, argdelta) local rets_ok = true local rets_errs local args_ok @@ -8810,19 +8873,19 @@ a.types[i], b.types[i]), } if argdelta == -1 then from = 2 local errs = {} - if (not is_self(fargs[1])) and not arg_check(where, errs, fargs[1], args.tuple[1], "contravariant", "self") then + if (not is_self(fargs[1])) and not self:arg_check(w, errs, fargs[1], args.tuple[1], "contravariant", "self") then return nil, errs end end if expected_rets then - expected_rets = infer_at(where, expected_rets) - infer_emptytables(where, nil, expected_rets, f.rets, 0) + expected_rets = self:infer_at(w, expected_rets) + infer_emptytables(self, w, nil, expected_rets, f.rets, 0) - rets_ok, rets_errs = check_func_type_list(where, nil, f.rets, expected_rets, 1, 0, "covariant", "return") + rets_ok, rets_errs = check_func_type_list(self, w, nil, f.rets, expected_rets, 1, 0, "covariant", "return") end - args_ok, args_errs = check_func_type_list(where, where_args, f.args, args, from, argdelta, "contravariant", "argument") + args_ok, args_errs = check_func_type_list(self, w, where_args, f.args, args, from, argdelta, "contravariant", "argument") if (not args_ok) or (not rets_ok) then return nil, args_errs or {} end @@ -8830,29 +8893,29 @@ a.types[i], b.types[i]), } - infer_emptytables(where, where_args, args, f.args, argdelta) + infer_emptytables(self, w, where_args, args, f.args, argdelta) - mark_invalid_typeargs(f) + mark_invalid_typeargs(self, f) - return resolve_typevars_at(where, f.rets) + return self:resolve_typevars_at(w, f.rets) end end - local function push_typeargs(func) + local function push_typeargs(self, func) if func.typeargs then for _, fnarg in ipairs(func.typeargs) do - add_var(nil, fnarg.typearg, a_type("unresolved_typearg", { + self:add_var(nil, fnarg.typearg, a_type(fnarg, "unresolved_typearg", { constraint = fnarg.constraint, })) end end end - local function pop_typeargs(func) + local function pop_typeargs(self, func) if func.typeargs then for _, fnarg in ipairs(func.typeargs) do - if st[#st][fnarg.typearg] then - st[#st][fnarg.typearg] = nil + if self.st[#self.st].vars[fnarg.typearg] then + self.st[#self.st].vars[fnarg.typearg] = nil end end end @@ -8866,12 +8929,9 @@ a.types[i], b.types[i]), } end end - local function fail_call(where, func, nargs, errs) + local function fail_call(self, w, func, nargs, errs) if errs then - - for _, err in ipairs(errs) do - table.insert(errors, err) - end + self.errs:collect(errs) else local expects = {} @@ -8888,34 +8948,34 @@ a.types[i], b.types[i]), } else table.insert(expects, show_arity(func)) end - error_at(where, "wrong number of arguments (given " .. nargs .. ", expects " .. table.concat(expects, " or ") .. ")") + self.errs:add(w, "wrong number of arguments (given " .. nargs .. ", expects " .. table.concat(expects, " or ") .. ")") end local f = resolve_function_type(func, 1) - mark_invalid_typeargs(f) + mark_invalid_typeargs(self, f) - return resolve_typevars_at(where, f.rets) + return self:resolve_typevars_at(w, f.rets) end - local function check_call(where, where_args, func, args, expected_rets, is_typedecl_funcall, argdelta) + local function check_call(self, w, where_args, func, args, expected_rets, is_typedecl_funcall, argdelta) assert(type(func) == "table") assert(type(args) == "table") local is_method = (argdelta == -1) if not (func.typename == "function" or func.typename == "poly") then - func, is_method = resolve_for_call(func, args, is_method) + func, is_method = self:resolve_for_call(func, args, is_method) if is_method then argdelta = -1 end if not (func.typename == "function" or func.typename == "poly") then - return invalid_at(where, "not a function: %s", func) + return self.errs:invalid_at(w, "not a function: %s", func) end end if is_method and args.tuple[1] then - add_var(nil, "@self", type_at(where, a_type("typedecl", { def = args.tuple[1] }))) + self:add_var(nil, "@self", a_type(w, "typedecl", { def = args.tuple[1] })) end local passes, n = 1, 1 @@ -8932,30 +8992,30 @@ a.types[i], b.types[i]), } local f = resolve_function_type(func, i) local fargs = f.args.tuple if f.is_method and not is_method then - if args.tuple[1] and is_a(args.tuple[1], fargs[1]) then + if args.tuple[1] and self:is_a(args.tuple[1], fargs[1]) then if not is_typedecl_funcall then - add_warning("hint", where, "invoked method as a regular function: consider using ':' instead of '.'") + self.errs:add_warning("hint", w, "invoked method as a regular function: consider using ':' instead of '.'") end else - return invalid_at(where, "invoked method as a regular function: use ':' instead of '.'") + return self.errs:invalid_at(w, "invoked method as a regular function: use ':' instead of '.'") end end local wanted = #fargs - local min_arity = feat_arity and f.min_arity or 0 + local min_arity = self.feat_arity and f.min_arity or 0 - if (passes == 1 and ((given <= wanted and given >= min_arity) or (f.args.is_va and given > wanted) or (lax and given <= wanted))) or + if (passes == 1 and ((given <= wanted and given >= min_arity) or (f.args.is_va and given > wanted) or (self.feat_lax and given <= wanted))) or (passes == 3 and ((pass == 1 and given == wanted) or - (pass == 2 and given < wanted and (lax or given >= min_arity)) or + (pass == 2 and given < wanted and (self.feat_lax or given >= min_arity)) or (pass == 3 and f.args.is_va and given > wanted))) then - push_typeargs(f) + push_typeargs(self, f) - local matched, errs = check_args_rets(where, where_args, f, args, expected_rets, argdelta) + local matched, errs = check_args_rets(self, w, where_args, f, args, expected_rets, argdelta) if matched then return matched, f @@ -8964,23 +9024,23 @@ a.types[i], b.types[i]), } if expected_rets then - infer_emptytables(where, where_args, f.rets, f.rets, argdelta) + infer_emptytables(self, w, where_args, f.rets, f.rets, argdelta) end if passes == 3 then tried = tried or {} tried[i] = true - pop_typeargs(f) + pop_typeargs(self, f) end end end end end - return fail_call(where, func, given, first_errs) + return fail_call(self, w, func, given, first_errs) end - type_check_function_call = function(node, func, args, argdelta, e1, e2) + function TypeChecker:type_check_function_call(node, func, args, argdelta, e1, e2) e1 = e1 or node.e1 e2 = e2 or node.e2 @@ -8989,14 +9049,14 @@ a.types[i], b.types[i]), } if expected and expected.typename == "tuple" then expected_rets = expected else - expected_rets = a_type("tuple", { tuple = { node.expected } }) + expected_rets = a_type(node, "tuple", { tuple = { node.expected } }) end - begin_scope() + self:begin_scope() local is_typedecl_funcall - if node.kind == "op" and node.op.op == "@funcall" and node.e1 and node.e1.receiver then - local receiver = node.e1.receiver + if node.kind == "op" and node.op.op == "@funcall" and e1 and e1.receiver then + local receiver = e1.receiver if receiver.typename == "nominal" then local resolved = receiver.resolved if resolved and resolved.typename == "typedecl" then @@ -9005,12 +9065,12 @@ a.types[i], b.types[i]), } end end - local ret, f = check_call(node, e2, func, args, expected_rets, is_typedecl_funcall, argdelta or 0) - ret = resolve_typevars_at(node, ret) - end_scope() + local ret, f = check_call(self, node, e2, func, args, expected_rets, is_typedecl_funcall, argdelta or 0) + ret = self:resolve_typevars_at(node, ret) + self:end_scope() - if tc and e1 then - tc.store_type(e1.y, e1.x, f) + if self.collector then + self.collector.store_type(e1.y, e1.x, f) end if f and f.macroexp then @@ -9021,9 +9081,9 @@ a.types[i], b.types[i]), } end end - local function check_metamethod(node, method_name, a, b, orig_a, orig_b) - if lax and ((a and is_unknown(a)) or (b and is_unknown(b))) then - return UNKNOWN, nil + function TypeChecker:check_metamethod(node, method_name, a, b, orig_a, orig_b) + if self.feat_lax and ((a and is_unknown(a)) or (b and is_unknown(b))) then + return a_type(node, "unknown", {}), nil end local ameta = a.fields and a.meta_fields local bmeta = b and b.fields and b.meta_fields @@ -9044,26 +9104,26 @@ a.types[i], b.types[i]), } if metamethod then local e2 = { node.e1 } - local args = a_type("tuple", { tuple = { orig_a } }) + local args = a_type(node, "tuple", { tuple = { orig_a } }) if b and method_name ~= "__is" then e2[2] = node.e2 args.tuple[2] = orig_b end - return to_structural(resolve_tuple((type_check_function_call(node, metamethod, args, -1, node, e2)))), meta_on_operator + return self:to_structural(resolve_tuple((self:type_check_function_call(node, metamethod, args, -1, node, e2)))), meta_on_operator else return nil, nil end end - local function match_record_key(tbl, rec, key) + function TypeChecker:match_record_key(tbl, rec, key) assert(type(tbl) == "table") assert(type(rec) == "table") assert(type(key) == "string") - tbl = to_structural(tbl) + tbl = self:to_structural(tbl) if tbl.typename == "string" or tbl.typename == "enum" then - tbl = find_var_type("string") + tbl = self:find_var_type("string") end if tbl.typename == "typedecl" then @@ -9072,13 +9132,13 @@ a.types[i], b.types[i]), } if tbl.is_nested_alias then return nil, "cannot use a nested type alias as a concrete value" else - tbl = resolve_nominal(tbl.alias_to) + tbl = self:resolve_nominal(tbl.alias_to) end end if tbl.typename == "union" then - local t = same_in_all_union_entries(tbl, function(t) - return (match_record_key(t, rec, key)) + local t = self:same_in_all_union_entries(tbl, function(t) + return (self:match_record_key(t, rec, key)) end) if t then @@ -9087,7 +9147,7 @@ a.types[i], b.types[i]), } end if (tbl.typename == "typevar" or tbl.typename == "typearg") and tbl.constraint then - local t = match_record_key(tbl.constraint, rec, key) + local t = self:match_record_key(tbl.constraint, rec, key) if t then return t @@ -9101,7 +9161,8 @@ a.types[i], b.types[i]), } return tbl.fields[key] end - local meta_t = check_metamethod(rec, "__index", tbl, STRING, tbl, STRING) + local str = a_type(rec, "string", {}) + local meta_t = self:check_metamethod(rec, "__index", tbl, str, tbl, str) if meta_t then return meta_t end @@ -9116,8 +9177,8 @@ a.types[i], b.types[i]), } return nil, "invalid key '" .. key .. "' in type %s" end elseif tbl.typename == "emptytable" or is_unknown(tbl) then - if lax then - return INVALID + if self.feat_lax then + return a_type(rec, "unknown", {}) end return nil, "cannot index a value of unknown type" end @@ -9129,30 +9190,35 @@ a.types[i], b.types[i]), } end end - local function widen_in_scope(scope, var) - assert(scope[var], "no " .. var .. " in scope") - local narrow_mode = scope[var].is_narrowed - if narrow_mode and narrow_mode ~= "declaration" then - if scope[var].narrowed_from then - scope[var].t = scope[var].narrowed_from - scope[var].narrowed_from = nil - scope[var].is_narrowed = nil - else - scope[var] = nil - end + function TypeChecker:widen_in_scope(scope, var) + local v = scope.vars[var] + assert(v, "no " .. var .. " in scope") + local narrow_mode = scope.vars[var].is_narrowed + if (not narrow_mode) or narrow_mode == "declaration" then + return false + end - local unresolved = get_unresolved(scope) - unresolved.narrows[var] = nil - return true + if v.narrowed_from then + v.t = v.narrowed_from + v.narrowed_from = nil + v.is_narrowed = nil + else + scope.vars[var] = nil + end + + if scope.narrows then + scope.narrows[var] = nil end - return false + + return true end - local function widen_back_var(name) + function TypeChecker:widen_back_var(name) local widened = false - for i = #st, 1, -1 do - if st[i][name] then - if widen_in_scope(st[i], name) then + for i = #self.st, 1, -1 do + local scope = self.st[i] + if scope.vars[name] then + if self:widen_in_scope(scope, name) then widened = true else break @@ -9166,7 +9232,7 @@ a.types[i], b.types[i]), } local visit_node = { cbs = { ["assignment"] = { - after = function(node, _children) + after = function(_, node, _children) for _, v in ipairs(node.vars) do if v.kind == "variable" and v.tk == name then return true @@ -9176,7 +9242,7 @@ a.types[i], b.types[i]), } end, }, }, - after = function(_node, children, ret) + after = function(_, _node, children, ret) ret = ret or false for _, c in ipairs(children) do local ca = c @@ -9194,118 +9260,82 @@ a.types[i], b.types[i]), } end, } - return recurse_node(root, visit_node, visit_type) + return recurse_node(nil, root, visit_node, visit_type) end - local function widen_all_unions(node) - for i = #st, 1, -1 do - local scope = st[i] - local unresolved = find_unresolved(i) - if unresolved and unresolved.narrows then - for name, _ in pairs(unresolved.narrows) do + function TypeChecker:widen_all_unions(node) + for i = #self.st, 1, -1 do + local scope = self.st[i] + if scope.narrows then + for name, _ in pairs(scope.narrows) do if not node or assigned_anywhere(name, node) then - widen_in_scope(scope, name) + self:widen_in_scope(scope, name) end end end end end - local function add_global(node, var, valtype, is_assigning) - if lax and is_unknown(valtype) and (var ~= "self" and var ~= "...") then - add_unknown(node, var) + function TypeChecker:add_global(node, varname, valtype, is_assigning) + if self.feat_lax and is_unknown(valtype) and (varname ~= "self" and varname ~= "...") then + self.errs:add_unknown(node, varname) end local is_const = node.attribute ~= nil - local existing, scope, existing_attr = find_var(var) + local existing, scope, existing_attr = self:find_var(varname) if existing then if scope > 1 then - error_at(node, "cannot define a global when a local with the same name is in scope") + self.errs:add(node, "cannot define a global when a local with the same name is in scope") elseif is_assigning and existing_attr then - error_at(node, "cannot reassign to <" .. existing_attr .. "> global: " .. var) + self.errs:add(node, "cannot reassign to <" .. existing_attr .. "> global: " .. varname) elseif existing_attr and not is_const then - error_at(node, "global was previously declared as <" .. existing_attr .. ">: " .. var) + self.errs:add(node, "global was previously declared as <" .. existing_attr .. ">: " .. varname) elseif (not existing_attr) and is_const then - error_at(node, "global was previously declared as not <" .. node.attribute .. ">: " .. var) - elseif valtype and not same_type(existing.t, valtype) then - error_at(node, "cannot redeclare global with a different type: previous type of " .. var .. " is %s", existing.t) + self.errs:add(node, "global was previously declared as not <" .. node.attribute .. ">: " .. varname) + elseif valtype and not self:same_type(existing.t, valtype) then + self.errs:add(node, "cannot redeclare global with a different type: previous type of " .. varname .. " is %s", existing.t) end return nil end - st[1][var] = { t = valtype, attribute = is_const and "const" or nil } - - return st[1][var] - end - - local get_rets - if lax then - get_rets = function(rets) - if #rets.tuple == 0 then - return a_vararg({ UNKNOWN }) - end - return rets - end - else - get_rets = function(rets) - return rets - end - end - - local function add_internal_function_variables(node, args) - add_var(nil, "@is_va", args.is_va and ANY or NIL) - add_var(nil, "@return", node.rets or a_type("tuple", { tuple = {} })) + local var = { t = valtype, attribute = is_const and "const" or nil } + self.st[1].vars[varname] = var - if node.typeargs then - for _, t in ipairs(node.typeargs) do - local v = find_var(t.typearg, "check_only") - if not v or not v.used_as_type then - error_at(t, "type argument '%s' is not used in function signature", t) - end - end - end + return var end - local function add_function_definition_for_recursion(node, fnargs) - add_var(nil, node.name.tk, type_at(node, a_function({ - min_arity = node.min_arity, - typeargs = node.typeargs, - args = fnargs, - rets = get_rets(node.rets), - }))) - end + function TypeChecker:add_internal_function_variables(node, args) + self:add_var(nil, "@is_va", a_type(node, args.is_va and "any" or "nil", {})) + self:add_var(nil, "@return", node.rets or a_type(node, "tuple", { tuple = {} })) - local function fail_unresolved() - local unresolved = st[#st]["@unresolved"] - if unresolved then - st[#st]["@unresolved"] = nil - local unrt = unresolved.t - for name, nodes in pairs(unrt.labels) do - for _, node in ipairs(nodes) do - error_at(node, "no visible label '" .. name .. "' for goto") - end - end - for name, types in pairs(unrt.nominals) do - if not unrt.global_types[name] then - for _, typ in ipairs(types) do - assert(typ.x) - assert(typ.y) - error_at(typ, "unknown type %s", typ) - end + if node.typeargs then + for _, t in ipairs(node.typeargs) do + local v = self:find_var(t.typearg, "check_only") + if not v or not v.used_as_type then + self.errs:add(t, "type argument '%s' is not used in function signature", t) end end end end - local function end_function_scope(node) - fail_unresolved() - end_scope(node) + function TypeChecker:add_function_definition_for_recursion(node, fnargs) + self:add_var(nil, node.name.tk, a_function(node, { + min_arity = node.min_arity, + typeargs = node.typeargs, + args = fnargs, + rets = self.get_rets(node.rets), + })) + end + + function TypeChecker:end_function_scope(node) + self.errs:fail_unresolved_labels(self.st[#self.st]) + self:end_scope(node) end local function flatten_tuple(vals) local vt = vals.tuple local n_vals = #vt - local ret = a_type("tuple", { tuple = {} }) + local ret = a_type(vals, "tuple", { tuple = {} }) local rt = ret.tuple if n_vals == 0 then @@ -9333,9 +9363,9 @@ a.types[i], b.types[i]), } return ret end - local function get_assignment_values(vals, wanted) + local function get_assignment_values(w, vals, wanted) if vals == nil then - return a_type("tuple", { tuple = {} }) + return a_type(w, "tuple", { tuple = {} }) end local ret = flatten_tuple(vals) @@ -9354,14 +9384,14 @@ a.types[i], b.types[i]), } return ret end - local function match_all_record_field_names(node, a, field_names, errmsg) + function TypeChecker:match_all_record_field_names(node, a, field_names, errmsg) local t for _, k in ipairs(field_names) do local f = a.fields[k] if not t then t = f else - if not same_type(f, t) then + if not self:same_type(f, t) then errmsg = errmsg .. string.format(" (types of fields '%s' and '%s' do not match)", field_names[1], k) t = nil break @@ -9371,26 +9401,26 @@ a.types[i], b.types[i]), } if t then return t else - return invalid_at(node, errmsg) + return self.errs:invalid_at(node, errmsg) end end - local function type_check_index(anode, bnode, a, b) + function TypeChecker:type_check_index(anode, bnode, a, b) assert(not (a.typename == "tuple")) assert(not (b.typename == "tuple")) - local ra = resolve_typedecl(to_structural(a)) - local rb = to_structural(b) + local ra = resolve_typedecl(self:to_structural(a)) + local rb = self:to_structural(b) - if lax and is_unknown(a) then - return UNKNOWN + if self.feat_lax and is_unknown(a) then + return a end local errm local erra local errb - if ra.typename == "tupletable" and is_a(rb, INTEGER) then + if ra.typename == "tupletable" and rb.typename == "integer" then if bnode.constnum then if bnode.constnum >= 1 and bnode.constnum <= #ra.types and bnode.constnum == math.floor(bnode.constnum) then return ra.types[bnode.constnum] @@ -9398,38 +9428,35 @@ a.types[i], b.types[i]), } errm, erra = "index " .. tostring(bnode.constnum) .. " out of range for tuple %s", ra else - local array_type = arraytype_from_tuple(bnode, ra) + local array_type = self:arraytype_from_tuple(bnode, ra) if array_type then return array_type.elements end errm = "cannot index this tuple with a variable because it would produce a union type that cannot be discriminated at runtime" end - elseif ra.elements and is_a(rb, INTEGER) then + elseif ra.elements and rb.typename == "integer" then return ra.elements elseif ra.typename == "emptytable" then if ra.keys == nil then - ra.keys = infer_at(anode, b) + ra.keys = self:infer_at(bnode, b) end - if is_a(b, ra.keys) then - return type_at(anode, a_type("unresolved_emptytable_value", { + if self:is_a(b, ra.keys) then + return a_type(anode, "unresolved_emptytable_value", { emptytable_type = ra, - })) + }) end - errm, erra, errb = "inconsistent index type: got %s, expected %s (type of keys inferred at " .. - ra.keys.inferred_at.filename .. ":" .. - ra.keys.inferred_at.y .. ":" .. - ra.keys.inferred_at.x .. ": )", b, ra.keys + errm, erra, errb = "inconsistent index type: got %s, expected %s" .. inferred_msg(ra.keys, "type of keys "), b, ra.keys elseif ra.typename == "map" then - if is_a(b, ra.keys) then + if self:is_a(b, ra.keys) then return ra.values end errm, erra, errb = "wrong index type: got %s, expected %s", b, ra.keys elseif rb.typename == "string" and rb.literal then - local t, e = match_record_key(a, anode, rb.literal) + local t, e = self:match_record_key(a, anode, rb.literal) if t then return t end @@ -9445,10 +9472,10 @@ a.types[i], b.types[i]), } end end if not errm then - return match_all_record_field_names(bnode, ra, field_names, + return self:match_all_record_field_names(bnode, ra, field_names, "cannot index, not all enum values map to record fields of the same type") end - elseif is_a(rb, STRING) then + elseif rb.typename == "string" then errm, erra = "cannot index object of type %s with a string, consider using an enum", a else errm, erra, errb = "cannot index object of type %s with %s", a, b @@ -9457,28 +9484,28 @@ a.types[i], b.types[i]), } errm, erra, errb = "cannot index object of type %s with %s", a, b end - local meta_t = check_metamethod(anode, "__index", ra, b, a, b) + local meta_t = self:check_metamethod(anode, "__index", ra, b, a, b) if meta_t then return meta_t end - return invalid_at(bnode, errm, erra, errb) + return self.errs:invalid_at(bnode, errm, erra, errb) end - expand_type = function(where, old, new) + function TypeChecker:expand_type(w, old, new) if not old or old.typename == "nil" then return new else - if not is_a(new, old) then + if not self:is_a(new, old) then if old.typename == "map" and new.fields then local old_keys = old.keys if old_keys.typename == "string" then for _, ftype in fields_of(new) do - old.values = expand_type(where, old.values, ftype) + old.values = self:expand_type(w, old.values, ftype) end - edit_type(old, "map") + edit_type(w, old, "map") else - error_at(where, "cannot determine table literal type") + self.errs:add(w, "cannot determine table literal type") end elseif old.fields and new.fields then local values @@ -9486,14 +9513,14 @@ a.types[i], b.types[i]), } if not values then values = ftype else - values = expand_type(where, values, ftype) + values = self:expand_type(w, values, ftype) end end for _, ftype in fields_of(new) do if not values then values = ftype else - values = expand_type(where, values, ftype) + values = self:expand_type(w, values, ftype) end end old.fields = nil @@ -9501,25 +9528,25 @@ a.types[i], b.types[i]), } old.meta_fields = nil old.meta_fields = nil - edit_type(old, "map") + edit_type(w, old, "map") local map = old - map.keys = STRING + map.keys = a_type(w, "string", {}) map.values = values elseif old.typename == "union" then - edit_type(old, "union") + edit_type(w, old, "union") table.insert(old.types, drop_constant_value(new)) else - return unite({ old, new }, true) + return unite(w, { old, new }, true) end end end return old end - local function find_record_to_extend(exp) + function TypeChecker:find_record_to_extend(exp) if exp.kind == "type_identifier" then - local v = find_var(exp.tk) + local v = self:find_var(exp.tk) if not v then return nil, nil, exp.tk end @@ -9536,7 +9563,7 @@ a.types[i], b.types[i]), } return t, v, exp.tk elseif exp.kind == "op" then - local t, v, rname = find_record_to_extend(exp.e1) + local t, v, rname = self:find_record_to_extend(exp.e1) local fname = exp.e2.tk local dname = rname .. "." .. fname if not t then @@ -9557,30 +9584,29 @@ a.types[i], b.types[i]), } end end - local function typedecl_to_nominal(where, name, t, resolved) + local function typedecl_to_nominal(node, name, t, resolved) local typevals local def = t.def if def.typeargs then typevals = {} for _, a in ipairs(def.typeargs) do - table.insert(typevals, a_type("typevar", { + table.insert(typevals, a_type(a, "typevar", { typevar = a.typearg, constraint = a.constraint, })) end end - return type_at(where, a_type("nominal", { - typevals = typevals, - names = { name }, - found = t, - resolved = resolved, - })) + local nom = a_nominal(node, { name }) + nom.typevals = typevals + nom.found = t + nom.resolved = resolved + return nom end - local function get_self_type(exp) + function TypeChecker:get_self_type(exp) if exp.kind == "type_identifier" then - local t = find_var_type(exp.tk) + local t = self:find_var_type(exp.tk) if not t then return nil end @@ -9592,7 +9618,7 @@ a.types[i], b.types[i]), } end elseif exp.kind == "op" then - local t = get_self_type(exp.e1) + local t = self:get_self_type(exp.e1) if not t then return nil end @@ -9624,7 +9650,6 @@ a.types[i], b.types[i]), } local facts_and local facts_or local facts_not - local apply_facts local FACT_TRUTHY do local IsFact_mt = { @@ -9636,6 +9661,7 @@ a.types[i], b.types[i]), } setmetatable(IsFact, { __call = function(_, fact) fact.fact = "is" + assert(fact.w) return setmetatable(fact, IsFact_mt) end, }) @@ -9649,6 +9675,7 @@ a.types[i], b.types[i]), } setmetatable(EqFact, { __call = function(_, fact) fact.fact = "==" + assert(fact.w) return setmetatable(fact, EqFact_mt) end, }) @@ -9707,57 +9734,57 @@ a.types[i], b.types[i]), } FACT_TRUTHY = TruthyFact({}) - facts_and = function(where, f1, f2) - return AndFact({ f1 = f1, f2 = f2, where = where }) + facts_and = function(w, f1, f2) + return AndFact({ f1 = f1, f2 = f2, w = w }) end - facts_or = function(where, f1, f2) + facts_or = function(w, f1, f2) if f1 and f2 then - return OrFact({ f1 = f1, f2 = f2, where = where }) + return OrFact({ f1 = f1, f2 = f2, w = w }) else return nil end end - facts_not = function(where, f1) + facts_not = function(w, f1) if f1 then - return NotFact({ f1 = f1, where = where }) + return NotFact({ f1 = f1, w = w }) else return nil end end - local function unite_types(t1, t2) - return unite({ t2, t1 }) + local function unite_types(w, t1, t2) + return unite(w, { t2, t1 }) end - local function intersect_types(t1, t2) + local function intersect_types(self, w, t1, t2) if t2.typename == "union" then t1, t2 = t2, t1 end if t1.typename == "union" then local out = {} for _, t in ipairs(t1.types) do - if is_a(t, t2) then + if self:is_a(t, t2) then table.insert(out, t) end end - return unite(out) + return unite(w, out) else - if is_a(t1, t2) then + if self:is_a(t1, t2) then return t1 - elseif is_a(t2, t1) then + elseif self:is_a(t2, t1) then return t2 else - return NIL + return a_type(w, "nil", {}) end end end - local function resolve_if_union(t) - local rt = to_structural(t) + function TypeChecker:resolve_if_union(t) + local rt = self:to_structural(t) if rt.typename == "union" then return rt end @@ -9765,23 +9792,23 @@ a.types[i], b.types[i]), } end - local function subtract_types(t1, t2) + local function subtract_types(self, w, t1, t2) local types = {} - t1 = resolve_if_union(t1) + t1 = self:resolve_if_union(t1) if not (t1.typename == "union") then return t1 end - t2 = resolve_if_union(t2) + t2 = self:resolve_if_union(t2) local t2types = t2.typename == "union" and t2.types or { t2 } for _, at in ipairs(t1.types) do local not_present = true for _, bt in ipairs(t2types) do - if same_type(at, bt) then + if self:same_type(at, bt) then not_present = false break end @@ -9792,10 +9819,10 @@ a.types[i], b.types[i]), } end if #types == 0 then - return NIL + return a_type(w, "nil", {}) end - return unite(types) + return unite(w, types) end local eval_not @@ -9805,65 +9832,65 @@ a.types[i], b.types[i]), } local eval_fact local function invalid_from(f) - return IsFact({ fact = "is", var = f.var, typ = INVALID, where = f.where }) + return IsFact({ fact = "is", var = f.var, typ = a_type(f.w, "invalid", {}), w = f.w }) end - not_facts = function(fs) + not_facts = function(self, fs) local ret = {} for var, f in pairs(fs) do - local typ = find_var_type(f.var, "check_only") + local typ = self:find_var_type(f.var, "check_only") if not typ then - ret[var] = EqFact({ var = var, typ = INVALID, where = f.where }) + ret[var] = EqFact({ var = var, typ = a_type(f.w, "invalid", {}), w = f.w, no_infer = f.no_infer }) elseif f.fact == "==" then - ret[var] = EqFact({ var = var, typ = typ }) + ret[var] = EqFact({ var = var, typ = typ, w = f.w, no_infer = true }) elseif typ.typename == "typevar" then assert(f.fact == "is") - ret[var] = EqFact({ var = var, typ = typ }) - elseif not is_a(f.typ, typ) then + ret[var] = EqFact({ var = var, typ = typ, w = f.w, no_infer = true }) + elseif not self:is_a(f.typ, typ) then assert(f.fact == "is") - add_warning("branch", f.where, f.var .. " (of type %s) can never be a %s", show_type(typ), show_type(f.typ)) - ret[var] = EqFact({ var = var, typ = INVALID, where = f.where }) + self.errs:add_warning("branch", f.w, f.var .. " (of type %s) can never be a %s", show_type(typ), show_type(f.typ)) + ret[var] = EqFact({ var = var, typ = a_type(f.w, "invalid", {}), w = f.w, no_infer = f.no_infer }) else assert(f.fact == "is") - ret[var] = IsFact({ var = var, typ = subtract_types(typ, f.typ), where = f.where }) + ret[var] = IsFact({ var = var, typ = subtract_types(self, f.w, typ, f.typ), w = f.w, no_infer = f.no_infer }) end end return ret end - eval_not = function(f) + eval_not = function(self, f) if not f then return {} elseif f.fact == "is" then - return not_facts({ [f.var] = f }) + return not_facts(self, { [f.var] = f }) elseif f.fact == "not" then - return eval_fact(f.f1) + return eval_fact(self, f.f1) elseif f.fact == "and" and f.f2 and f.f2.fact == "truthy" then - return eval_not(f.f1) + return eval_not(self, f.f1) elseif f.fact == "or" and f.f2 and f.f2.fact == "truthy" then - return eval_fact(f.f1) + return eval_fact(self, f.f1) elseif f.fact == "and" then - return or_facts(not_facts(eval_fact(f.f1)), not_facts(eval_fact(f.f2))) + return or_facts(self, not_facts(self, eval_fact(self, f.f1)), not_facts(self, eval_fact(self, f.f2))) elseif f.fact == "or" then - return and_facts(not_facts(eval_fact(f.f1)), not_facts(eval_fact(f.f2))) + return and_facts(self, not_facts(self, eval_fact(self, f.f1)), not_facts(self, eval_fact(self, f.f2))) else - return not_facts(eval_fact(f)) + return not_facts(self, eval_fact(self, f)) end end - or_facts = function(fs1, fs2) + or_facts = function(_self, fs1, fs2) local ret = {} for var, f in pairs(fs2) do if fs1[var] then - local united = unite_types(f.typ, fs1[var].typ) + local united = unite_types(f.w, f.typ, fs1[var].typ) if fs1[var].fact == "is" and f.fact == "is" then - ret[var] = IsFact({ var = var, typ = united, where = f.where }) + ret[var] = IsFact({ var = var, typ = united, w = f.w }) else - ret[var] = EqFact({ var = var, typ = united, where = f.where }) + ret[var] = EqFact({ var = var, typ = united, w = f.w }) end end end @@ -9871,7 +9898,7 @@ a.types[i], b.types[i]), } return ret end - and_facts = function(fs1, fs2) + and_facts = function(self, fs1, fs2) local ret = {} local has = {} @@ -9882,18 +9909,18 @@ a.types[i], b.types[i]), } if fs2[var].fact == "is" and f.fact == "is" then ctor = IsFact end - rt = intersect_types(f.typ, fs2[var].typ) + rt = intersect_types(self, f.w, f.typ, fs2[var].typ) else rt = f.typ end - local ff = ctor({ var = var, typ = rt, where = f.where }) + local ff = ctor({ var = var, typ = rt, w = f.w, no_infer = f.no_infer }) ret[var] = ff has[ff.fact] = true end for var, f in pairs(fs2) do if not fs1[var] then - ret[var] = EqFact({ var = var, typ = f.typ, where = f.where }) + ret[var] = EqFact({ var = var, typ = f.typ, w = f.w, no_infer = f.no_infer }) has["=="] = true end end @@ -9907,21 +9934,21 @@ a.types[i], b.types[i]), } return ret end - eval_fact = function(f) + eval_fact = function(self, f) if not f then return {} elseif f.fact == "is" then - local typ = find_var_type(f.var, "check_only") + local typ = self:find_var_type(f.var, "check_only") if not typ then return { [f.var] = invalid_from(f) } end if typ.typename ~= "typevar" then - if is_a(typ, f.typ) then + if self:is_a(typ, f.typ) then return { [f.var] = f } - elseif not is_a(f.typ, typ) then - error_at(f.where, f.var .. " (of type %s) can never be a %s", typ, f.typ) + elseif not self:is_a(f.typ, typ) then + self.errs:add(f.w, f.var .. " (of type %s) can never be a %s", typ, f.typ) return { [f.var] = invalid_from(f) } end end @@ -9929,63 +9956,60 @@ a.types[i], b.types[i]), } elseif f.fact == "==" then return { [f.var] = f } elseif f.fact == "not" then - return eval_not(f.f1) + return eval_not(self, f.f1) elseif f.fact == "truthy" then return {} elseif f.fact == "and" and f.f2 and f.f2.fact == "truthy" then - return eval_fact(f.f1) + return eval_fact(self, f.f1) elseif f.fact == "or" and f.f2 and f.f2.fact == "truthy" then - return eval_not(f.f1) + return eval_not(self, f.f1) elseif f.fact == "and" then - return and_facts(eval_fact(f.f1), eval_fact(f.f2)) + return and_facts(self, eval_fact(self, f.f1), eval_fact(self, f.f2)) elseif f.fact == "or" then - return or_facts(eval_fact(f.f1), eval_fact(f.f2)) + return or_facts(self, eval_fact(self, f.f1), eval_fact(self, f.f2)) end end - apply_facts = function(where, known) + function TypeChecker:apply_facts(w, known) if not known then return end - local facts = eval_fact(known) + local facts = eval_fact(self, known) for v, f in pairs(facts) do if f.typ.typename == "invalid" then - error_at(where, "cannot resolve a type for " .. v .. " here") + self.errs:add(w, "cannot resolve a type for " .. v .. " here") end - local t = infer_at(where, f.typ) - if not f.where then + local t = f.no_infer and f.typ or self:infer_at(w, f.typ) + if f.no_infer then t.inferred_at = nil end - add_var(nil, v, t, "const", "narrow") + self:add_var(nil, v, t, "const", "narrow") end end end - local function dismiss_unresolved(name) - for i = #st, 1, -1 do - local unresolved = find_unresolved(i) - if unresolved then - local uses = unresolved.nominals[name] - if uses then - for _, t in ipairs(uses) do - resolve_nominal(t) - end - unresolved.nominals[name] = nil - return + function TypeChecker:dismiss_unresolved(name) + for i = #self.st, 1, -1 do + local scope = self.st[i] + local uses = scope.pending_nominals and scope.pending_nominals[name] + if uses then + for _, t in ipairs(uses) do + self:resolve_nominal(t) end + scope.pending_nominals[name] = nil + return end end end - local type_check_funcall - - local function special_pcall_xpcall(node, _a, b, argdelta) + local function special_pcall_xpcall(self, node, _a, b, argdelta) local base_nargs = (node.e1.tk == "xpcall") and 2 or 1 + local bool = a_type(node, "boolean", {}) if #node.e2 < base_nargs then - error_at(node, "wrong number of arguments (given " .. #node.e2 .. ", expects at least " .. base_nargs .. ")") - return a_type("tuple", { tuple = { BOOLEAN } }) + self.errs:add(node, "wrong number of arguments (given " .. #node.e2 .. ", expects at least " .. base_nargs .. ")") + return a_type(node, "tuple", { tuple = { bool } }) end @@ -9997,137 +10021,142 @@ a.types[i], b.types[i]), } ftype.is_method = false end - local fe2 = {} + local fe2 = node_at(node.e2, {}) if node.e1.tk == "xpcall" then base_nargs = 2 + local arg2 = node.e2[2] local msgh = table.remove(b.tuple, 1) - assert_is_a(node.e2[2], msgh, XPCALL_MSGH_FUNCTION, "in message handler") + local msgh_type = a_function(arg2, { + min_arity = 1, + args = a_type(arg2, "tuple", { tuple = { a_type(arg2, "any", {}) } }), + rets = a_type(arg2, "tuple", { tuple = {} }), + }) + self:assert_is_a(arg2, msgh, msgh_type, "in message handler") end for i = base_nargs + 1, #node.e2 do table.insert(fe2, node.e2[i]) end - local fnode = { - y = node.y, - x = node.x, + local fnode = node_at(node, { kind = "op", op = { op = "@funcall" }, e1 = node.e2[1], e2 = fe2, - } - local rets = type_check_funcall(fnode, ftype, b, argdelta + base_nargs) + }) + local rets = self:type_check_funcall(fnode, ftype, b, argdelta + base_nargs) if rets.typename == "invalid" then return rets end - table.insert(rets.tuple, 1, BOOLEAN) + table.insert(rets.tuple, 1, bool) return rets end local special_functions = { - ["pairs"] = function(node, a, b, argdelta) + ["pairs"] = function(self, node, a, b, argdelta) if not b.tuple[1] then - return invalid_at(node, "pairs requires an argument") + return self.errs:invalid_at(node, "pairs requires an argument") end - local t = to_structural(b.tuple[1]) + local t = self:to_structural(b.tuple[1]) if t.elements then - add_warning("hint", node, "hint: applying pairs on an array: did you intend to apply ipairs?") + self.errs:add_warning("hint", node, "hint: applying pairs on an array: did you intend to apply ipairs?") end if t.typename ~= "map" then - if not (lax and is_unknown(t)) then + if not (self.feat_lax and is_unknown(t)) then if t.fields then - match_all_record_field_names(node.e2, t, t.field_order, + self:match_all_record_field_names(node.e2, t, t.field_order, "attempting pairs on a record with attributes of different types") local ct = t.typename == "record" and "{string:any}" or "{any:any}" - add_warning("hint", node.e2, "hint: if you want to iterate over fields of a record, cast it to " .. ct) + self.errs:add_warning("hint", node.e2, "hint: if you want to iterate over fields of a record, cast it to " .. ct) else - error_at(node.e2, "cannot apply pairs on values of type: %s", t) + self.errs:add(node.e2, "cannot apply pairs on values of type: %s", t) end end end - return (type_check_function_call(node, a, b, argdelta)) + return (self:type_check_function_call(node, a, b, argdelta)) end, - ["ipairs"] = function(node, a, b, argdelta) + ["ipairs"] = function(self, node, a, b, argdelta) if not b.tuple[1] then - return invalid_at(node, "ipairs requires an argument") + return self.errs:invalid_at(node, "ipairs requires an argument") end local orig_t = b.tuple[1] - local t = to_structural(orig_t) + local t = self:to_structural(orig_t) if t.typename == "tupletable" then - local arr_type = arraytype_from_tuple(node.e2, t) + local arr_type = self:arraytype_from_tuple(node.e2, t) if not arr_type then - return invalid_at(node.e2, "attempting ipairs on tuple that's not a valid array: %s", orig_t) + return self.errs:invalid_at(node.e2, "attempting ipairs on tuple that's not a valid array: %s", orig_t) end elseif not t.elements then - if not (lax and (is_unknown(t) or t.typename == "emptytable")) then - return invalid_at(node.e2, "attempting ipairs on something that's not an array: %s", orig_t) + if not (self.feat_lax and (is_unknown(t) or t.typename == "emptytable")) then + return self.errs:invalid_at(node.e2, "attempting ipairs on something that's not an array: %s", orig_t) end end - return (type_check_function_call(node, a, b, argdelta)) + return (self:type_check_function_call(node, a, b, argdelta)) end, - ["rawget"] = function(node, _a, b, _argdelta) + ["rawget"] = function(self, node, _a, b, _argdelta) if #b.tuple == 2 then - return a_type("tuple", { tuple = { type_check_index(node.e2[1], node.e2[2], b.tuple[1], b.tuple[2]) } }) + return a_type(node, "tuple", { tuple = { self:type_check_index(node.e2[1], node.e2[2], b.tuple[1], b.tuple[2]) } }) else - return invalid_at(node, "rawget expects two arguments") + return self.errs:invalid_at(node, "rawget expects two arguments") end end, - ["require"] = function(node, _a, b, _argdelta) + ["require"] = function(self, node, _a, b, _argdelta) if #b.tuple ~= 1 then - return invalid_at(node, "require expects one literal argument") + return self.errs:invalid_at(node, "require expects one literal argument") end if node.e2[1].kind ~= "string" then - return a_type("tuple", { tuple = { a_type("any", {}) } }) + return a_type(node, "tuple", { tuple = { a_type(node, "any", {}) } }) end local module_name = assert(node.e2[1].conststr) - local t, found = require_module(module_name, lax, env) - if not found then - return invalid_at(node, "module not found: '" .. module_name .. "'") - end + local t, module_filename = require_module(node, module_name, self.feat_lax, self.env) if t.typename == "invalid" then - if lax then - return a_type("tuple", { tuple = { UNKNOWN } }) + if not module_filename then + return self.errs:invalid_at(node, "module not found: '" .. module_name .. "'") end - return invalid_at(node, "no type information for required module: '" .. module_name .. "'") + + if self.feat_lax then + return a_type(node, "tuple", { tuple = { a_type(node, "unknown", {}) } }) + end + return self.errs:invalid_at(node, "no type information for required module: '" .. module_name .. "'") end - dependencies[module_name] = t.filename - return type_at(node, a_type("tuple", { tuple = { t } })) + self.dependencies[module_name] = module_filename + return a_type(node, "tuple", { tuple = { t } }) end, ["pcall"] = special_pcall_xpcall, ["xpcall"] = special_pcall_xpcall, - ["assert"] = function(node, a, b, argdelta) + ["assert"] = function(self, node, a, b, argdelta) node.known = FACT_TRUTHY - local r = type_check_function_call(node, a, b, argdelta) - apply_facts(node, node.e2[1].known) + local r = self:type_check_function_call(node, a, b, argdelta) + self:apply_facts(node, node.e2[1].known) return r end, } - type_check_funcall = function(node, a, b, argdelta) + function TypeChecker:type_check_funcall(node, a, b, argdelta) argdelta = argdelta or 0 if node.e1.kind == "variable" then local special = special_functions[node.e1.tk] if special then - return special(node, a, b, argdelta) + return special(self, node, a, b, argdelta) else - return (type_check_function_call(node, a, b, argdelta)) + return (self:type_check_function_call(node, a, b, argdelta)) end elseif node.e1.op and node.e1.op.op == ":" then table.insert(b.tuple, 1, node.e1.receiver) - return (type_check_function_call(node, a, b, -1)) + return (self:type_check_function_call(node, a, b, -1)) else - return (type_check_function_call(node, a, b, argdelta)) + return (self:type_check_function_call(node, a, b, argdelta)) end end @@ -10139,19 +10168,19 @@ a.types[i], b.types[i]), } node.exps[i].tk == node.vars[i].tk end - local function missing_initializer(node, i, name) - if lax then - return UNKNOWN + function TypeChecker:missing_initializer(node, i, name) + if self.feat_lax then + return a_type(node, "unknown", {}) else if node.exps then - return invalid_at(node.vars[i], "assignment in declaration did not produce an initial value for variable '" .. name .. "'") + return self.errs:invalid_at(node.vars[i], "assignment in declaration did not produce an initial value for variable '" .. name .. "'") else - return invalid_at(node.vars[i], "variable '" .. name .. "' has no type or initial value") + return self.errs:invalid_at(node.vars[i], "variable '" .. name .. "' has no type or initial value") end end end - local function set_expected_types_to_decltuple(node, children) + local function set_expected_types_to_decltuple(_, node, children) local decltuple = node.kind == "assignment" and children[1] or node.decltuple assert(decltuple.typename == "tuple") local decls = decltuple.tuple @@ -10163,7 +10192,7 @@ a.types[i], b.types[i]), } typ = decls[i] if typ then if i == nexps and ndecl > nexps then - typ = type_at(node, a_type("tuple", { tuple = {} })) + typ = a_type(node, "tuple", { tuple = {} }) for a = i, ndecl do table.insert(typ.tuple, decls[a]) end @@ -10179,38 +10208,7 @@ a.types[i], b.types[i]), } return n and n >= 1 and math.floor(n) == n end - local context_name = { - ["local_declaration"] = "in local declaration", - ["global_declaration"] = "in global declaration", - ["assignment"] = "in assignment", - } - - local function in_context(ctx, msg) - if not ctx then - return msg - end - local where = context_name[ctx.kind] - if where then - return where .. ": " .. (ctx.name and ctx.name .. ": " or "") .. msg - else - return msg - end - end - - - - local function check_redeclared_key(where, ctx, seen_keys, key) - if key ~= nil then - local s = seen_keys[key] - if s then - error_at(where, in_context(ctx, "redeclared key " .. tostring(key) .. " (previously declared at " .. filename .. ":" .. s.y .. ":" .. s.x .. ")")) - else - seen_keys[key] = where - end - end - end - - local function infer_table_literal(node, children) + local function infer_table_literal(self, node, children) local is_record = false local is_array = false local is_map = false @@ -10235,14 +10233,15 @@ a.types[i], b.types[i]), } for i, child in ipairs(children) do local ck = child.kname + local cktype = child.ktype local n = node[i].key.constnum local b = nil - if child.ktype.typename == "boolean" then + if cktype.typename == "boolean" then b = (node[i].key.tk == "true") end local key = ck or n or b - check_redeclared_key(node[i], nil, seen_keys, key) + self.errs:check_redeclared_key(node[i], nil, seen_keys, key) local uvtype = resolve_tuple(child.vtype) if ck then @@ -10253,7 +10252,7 @@ a.types[i], b.types[i]), } end fields[ck] = uvtype table.insert(field_order, ck) - elseif is_number_type(child.ktype) then + elseif is_numeric_type(cktype) then is_array = true if not is_not_tuple then is_tuple = true @@ -10267,25 +10266,25 @@ a.types[i], b.types[i]), } if i == #children and cv.typename == "tuple" then for _, c in ipairs(cv.tuple) do - elements = expand_type(node, elements, c) + elements = self:expand_type(node, elements, c) types[last_array_idx] = resolve_tuple(c) last_array_idx = last_array_idx + 1 end else types[last_array_idx] = uvtype last_array_idx = last_array_idx + 1 - elements = expand_type(node, elements, uvtype) + elements = self:expand_type(node, elements, uvtype) end else if not is_positive_int(n) then - elements = expand_type(node, elements, uvtype) + elements = self:expand_type(node, elements, uvtype) is_not_tuple = true elseif n then types[n] = uvtype if n > largest_array_idx then largest_array_idx = n end - elements = expand_type(node, elements, uvtype) + elements = self:expand_type(node, elements, uvtype) end end @@ -10297,37 +10296,37 @@ a.types[i], b.types[i]), } end else is_map = true - keys = expand_type(node, keys, drop_constant_value(child.ktype)) - values = expand_type(node, values, uvtype) + keys = self:expand_type(node, keys, drop_constant_value(cktype)) + values = self:expand_type(node, values, uvtype) end end local t if is_array and is_map then - error_at(node, "cannot determine type of table literal") - t = a_type("map", { keys = -expand_type(node, keys, INTEGER), values = + self.errs:add(node, "cannot determine type of table literal") + t = a_type(node, "map", { keys = +self:expand_type(node, keys, a_type(node, "integer", {})), values = -expand_type(node, values, elements) }) +self:expand_type(node, values, elements) }) elseif is_record and is_array then - t = a_type("record", { + t = a_type(node, "record", { fields = fields, field_order = field_order, elements = elements, interface_list = { - type_at(node, a_type("array", { elements = elements })), + a_type(node, "array", { elements = elements }), }, }) elseif is_record and is_map then if keys.typename == "string" then for _, fname in ipairs(field_order) do - values = expand_type(node, values, fields[fname]) + values = self:expand_type(node, values, fields[fname]) end - t = a_type("map", { keys = keys, values = values }) + t = a_type(node, "map", { keys = keys, values = values }) else - error_at(node, "cannot determine type of table literal") + self.errs:add(node, "cannot determine type of table literal") end elseif is_array then local pure_array = true @@ -10335,7 +10334,7 @@ expand_type(node, values, elements) }) local last_t for _, current_t in pairs(types) do if last_t then - if not same_type(last_t, current_t) then + if not self:same_type(last_t, current_t) then pure_array = false break end @@ -10344,69 +10343,70 @@ expand_type(node, values, elements) }) end end if pure_array then - t = a_type("array", { elements = elements }) + t = a_type(node, "array", { elements = elements }) t.consttypes = types t.inferred_len = largest_array_idx - 1 else - t = a_type("tupletable", {}) + t = a_type(node, "tupletable", { inferred_at = node }) t.types = types end elseif is_record then - t = a_type("record", { + t = a_type(node, "record", { fields = fields, field_order = field_order, }) elseif is_map then - t = a_type("map", { keys = keys, values = values }) + t = a_type(node, "map", { keys = keys, values = values }) elseif is_tuple then - t = a_type("tupletable", {}) + t = a_type(node, "tupletable", { inferred_at = node }) t.types = types if not types or #types == 0 then - error_at(node, "cannot determine type of tuple elements") + self.errs:add(node, "cannot determine type of tuple elements") end end if not t then - t = a_type("emptytable", {}) + t = a_type(node, "emptytable", {}) end return type_at(node, t) end - local function infer_negation_of_if_blocks(where, ifnode, n) - local f = facts_not(where, ifnode.if_blocks[1].exp.known) + function TypeChecker:infer_negation_of_if_blocks(w, ifnode, n) + local f = facts_not(w, ifnode.if_blocks[1].exp.known) for e = 2, n do local b = ifnode.if_blocks[e] if b.exp then - f = facts_and(where, f, facts_not(where, b.exp.known)) + f = facts_and(w, f, facts_not(w, b.exp.known)) end end - apply_facts(where, f) + self:apply_facts(w, f) end - local function determine_declaration_type(var, node, infertypes, i) + function TypeChecker:determine_declaration_type(var, node, infertypes, i) local ok = true local name = var.tk local infertype = infertypes and infertypes.tuple[i] - if lax and infertype and infertype.typename == "nil" then + if self.feat_lax and infertype and infertype.typename == "nil" then infertype = nil end local decltype = node.decltuple and node.decltuple.tuple[i] if decltype then - if to_structural(decltype) == INVALID then - decltype = INVALID + local rdecltype = self:to_structural(decltype) + if rdecltype.typename == "invalid" then + decltype = rdecltype end if infertype then - ok = assert_is_a(node.vars[i], infertype, decltype, context_name[node.kind], name) + local w = node.exps and node.exps[i] or node.vars[i] + ok = self:assert_is_a(w, infertype, decltype, context_name[node.kind], name) end else if infertype then if infertype.typename == "unresolvable_typearg" then - error_at(node.vars[i], "cannot infer declaration type; an explicit type annotation is necessary") ok = false - infertype = INVALID + infertype = self.errs:invalid_at(node.vars[i], "cannot infer declaration type; an explicit type annotation is necessary") elseif infertype.typename == "function" and infertype.is_method then @@ -10418,17 +10418,17 @@ expand_type(node, values, elements) }) end if var.attribute == "total" then - local rd = decltype and to_structural(decltype) + local rd = decltype and self:to_structural(decltype) if rd and (rd.typename ~= "map" and rd.typename ~= "record") then - error_at(var, "attribute only applies to maps and records") + self.errs:add(var, "attribute only applies to maps and records") ok = false elseif not infertype then - error_at(var, "variable declared does not declare an initialization value") + self.errs:add(var, "variable declared does not declare an initialization value") ok = false else local valnode = node.exps[i] if not valnode or valnode.kind ~= "literal_table" then - error_at(var, "attribute only applies to literal tables") + self.errs:add(var, "attribute only applies to literal tables") ok = false else if not valnode.is_total then @@ -10436,12 +10436,12 @@ expand_type(node, values, elements) }) if valnode.missing then missing = " (missing: " .. table.concat(valnode.missing, ", ") .. ")" end - local ri = to_structural(infertype) + local ri = self:to_structural(infertype) if ri.typename == "map" then - error_at(var, "map variable declared does not declare values for all possible keys" .. missing) + self.errs:add(var, "map variable declared does not declare values for all possible keys" .. missing) ok = false elseif ri.typename == "record" then - error_at(var, "record variable declared does not declare values for all fields" .. missing) + self.errs:add(var, "record variable declared does not declare values for all fields" .. missing) ok = false end end @@ -10451,34 +10451,36 @@ expand_type(node, values, elements) }) local t = decltype or infertype if t == nil then - t = missing_initializer(node, i, name) + t = self:missing_initializer(node, i, name) elseif t.typename == "emptytable" then t.declared_at = node t.assigned_to = name elseif t.elements then t.inferred_len = nil + elseif t.typename == "nominal" then + self:resolve_nominal(t) end return ok, t, infertype ~= nil end - local function get_typedecl(value) + function TypeChecker:get_typedecl(value) if value.kind == "op" and value.op.op == "@funcall" and value.e1.kind == "variable" and value.e1.tk == "require" then - local t = special_functions["require"](value, find_var_type("require"), a_type("tuple", { tuple = { STRING } }), 0) + local t = special_functions["require"](self, value, self:find_var_type("require"), a_type(value.e2, "tuple", { tuple = { a_type(value.e2[1], "string", {}) } }), 0) local ty = t.typename == "tuple" and t.tuple[1] or t - ty = (ty.typename == "typealias") and resolve_typealias(ty) or ty - local td = (ty.typename == "typedecl") and ty or a_type("typedecl", { def = ty }) + ty = (ty.typename == "typealias") and self:resolve_typealias(ty) or ty + local td = (ty.typename == "typedecl") and ty or a_type(value, "typedecl", { def = ty }) return td else local newtype = value.newtype if newtype.typename == "typealias" then - local aliasing = find_var(newtype.alias_to.names[1], "use_type") - return resolve_typealias(newtype), aliasing - else + local aliasing = self:find_var(newtype.alias_to.names[1], "use_type") + return self:resolve_typealias(newtype), aliasing + elseif newtype.typename == "typedecl" then return newtype, nil end end @@ -10509,15 +10511,14 @@ expand_type(node, values, elements) }) return is_total, missing end - local function total_map_check(t, seen_keys) - local k = to_structural(t.keys) + local function total_map_check(keys, seen_keys) local is_total = true local missing - if k.typename == "enum" then - for _, key in ipairs(sorted_keys(k.enumset)) do + if keys.typename == "enum" then + for _, key in ipairs(sorted_keys(keys.enumset)) do is_total, missing = total_check_key(key, seen_keys, is_total, missing) end - elseif k.typename == "boolean" then + elseif keys.typename == "boolean" then for _, key in ipairs({ true, false }) do is_total, missing = total_check_key(key, seen_keys, is_total, missing) end @@ -10531,35 +10532,38 @@ expand_type(node, values, elements) }) - local function check_assignment(where, vartype, valtype, varname, attr) + function TypeChecker:check_assignment(varnode, vartype, valtype) + local varname = varnode.tk + local attr = varnode.attribute + if varname then - if widen_back_var(varname) then - vartype, attr = find_var_type(varname) + if self:widen_back_var(varname) then + vartype, attr = self:find_var_type(varname) if not vartype then - error_at(where, "unknown variable") + self.errs:add(varnode, "unknown variable") return nil end end end if attr == "close" or attr == "const" or attr == "total" then - error_at(where, "cannot assign to <" .. attr .. "> variable") + self.errs:add(varnode, "cannot assign to <" .. attr .. "> variable") return nil end - local var = to_structural(vartype) + local var = self:to_structural(vartype) if var.typename == "typedecl" or var.typename == "typealias" then - error_at(where, "cannot reassign a type") + self.errs:add(varnode, "cannot reassign a type") return nil end if not valtype then - error_at(where, "variable is not being assigned a value") + self.errs:add(varnode, "variable is not being assigned a value") return nil, nil, "missing" end - assert_is_a(where, valtype, vartype, "in assignment") + self:assert_is_a(varnode, valtype, vartype, "in assignment") - local val = to_structural(valtype) + local val = self:to_structural(valtype) return var, val end @@ -10575,181 +10579,182 @@ expand_type(node, values, elements) }) visit_node.cbs = { ["statements"] = { - before = function(node) - begin_scope(node) + before = function(self, node) + self:begin_scope(node) end, - after = function(node, _children) + after = function(self, node, _children) - if #st == 2 then - fail_unresolved() + if #self.st == 2 then + self.errs:fail_unresolved_labels(self.st[2]) + self.errs:fail_unresolved_nominals(self.st[2], self.st[1]) end if not node.is_repeat then - end_scope(node) + self:end_scope(node) end return NONE end, }, ["local_type"] = { - before = function(node) + before = function(self, node) local name = node.var.tk - local resolved, aliasing = get_typedecl(node.value) - local var = add_var(node.var, name, resolved, node.var.attribute) + local resolved, aliasing = self:get_typedecl(node.value) + local var = self:add_var(node.var, name, resolved, node.var.attribute) if aliasing then var.aliasing = aliasing end end, - after = function(node, _children) - dismiss_unresolved(node.var.tk) + after = function(self, node, _children) + self:dismiss_unresolved(node.var.tk) return NONE end, }, ["global_type"] = { - before = function(node) + before = function(self, node) + local global_scope = self.st[1] local name = node.var.tk - local unresolved = get_unresolved() if node.value then - local resolved, aliasing = get_typedecl(node.value) - local added = add_global(node.var, name, resolved) + local resolved, aliasing = self:get_typedecl(node.value) + local added = self:add_global(node.var, name, resolved) node.value.newtype = resolved if aliasing then added.aliasing = aliasing end - if added and unresolved.global_types[name] then - unresolved.global_types[name] = nil + if global_scope.pending_global_types[name] then + global_scope.pending_global_types[name] = nil end else - if not st[1][name] then - unresolved.global_types[name] = true + if not self.st[1].vars[name] then + global_scope.pending_global_types[name] = true end end end, - after = function(node, _children) - dismiss_unresolved(node.var.tk) + after = function(self, node, _children) + self:dismiss_unresolved(node.var.tk) return NONE end, }, ["local_declaration"] = { - before = function(node) - if tc then + before = function(self, node) + if self.collector then for _, var in ipairs(node.vars) do - tc.reserve_symbol_list_slot(var) + self.collector.reserve_symbol_list_slot(var) end end end, before_exp = set_expected_types_to_decltuple, - after = function(node, children) + after = function(self, node, children) local valtuple = children[3] local encountered_close = false - local infertypes = get_assignment_values(valtuple, #node.vars) + local infertypes = get_assignment_values(node, valtuple, #node.vars) for i, var in ipairs(node.vars) do if var.attribute == "close" then - if opts.gen_target == "5.4" then + if self.gen_target == "5.4" then if encountered_close then - error_at(var, "only one per declaration is allowed") + self.errs:add(var, "only one per declaration is allowed") else encountered_close = true end else - error_at(var, " attribute is only valid for Lua 5.4 (current target is " .. tostring(opts.gen_target) .. ")") + self.errs:add(var, " attribute is only valid for Lua 5.4 (current target is " .. tostring(self.gen_target) .. ")") end end - local ok, t = determine_declaration_type(var, node, infertypes, i) + local ok, t = self:determine_declaration_type(var, node, infertypes, i) if var.attribute == "close" then if not type_is_closable(t) then - error_at(var, "to-be-closed variable " .. var.tk .. " has a non-closable type %s", t) + self.errs:add(var, "to-be-closed variable " .. var.tk .. " has a non-closable type %s", t) elseif node.exps and node.exps[i] and expr_is_definitely_not_closable(node.exps[i]) then - error_at(var, "to-be-closed variable " .. var.tk .. " assigned a non-closable value") + self.errs:add(var, "to-be-closed variable " .. var.tk .. " assigned a non-closable value") end end assert(var) - add_var(var, var.tk, t, var.attribute, is_localizing_a_variable(node, i) and "declaration") + self:add_var(var, var.tk, t, var.attribute, is_localizing_a_variable(node, i) and "declaration") local infertype = infertypes.tuple[i] if ok and infertype then - local where = node.exps[i] or node.exps + local w = node.exps[i] or node.exps - local rt = to_structural(t) + local rt = self:to_structural(t) if (not (rt.typename == "enum")) and ((not (t.typename == "nominal")) or (rt.typename == "union")) and - not same_type(t, infertype) then + not self:same_type(t, infertype) then - t = infer_at(where, infertype) - add_var(where, var.tk, t, "const", "narrowed_declaration") + t = self:infer_at(w, infertype) + self:add_var(w, var.tk, t, "const", "narrowed_declaration") end end - if tc then - tc.store_type(var.y, var.x, t) + if self.collector then + self.collector.store_type(var.y, var.x, t) end - dismiss_unresolved(var.tk) + self:dismiss_unresolved(var.tk) end return NONE end, }, ["global_declaration"] = { before_exp = set_expected_types_to_decltuple, - after = function(node, children) + after = function(self, node, children) local valtuple = children[3] - local infertypes = get_assignment_values(valtuple, #node.vars) + local infertypes = get_assignment_values(node, valtuple, #node.vars) for i, var in ipairs(node.vars) do - local _, t, is_inferred = determine_declaration_type(var, node, infertypes, i) + local _, t, is_inferred = self:determine_declaration_type(var, node, infertypes, i) if var.attribute == "close" then - error_at(var, "globals may not be ") + self.errs:add(var, "globals may not be ") end - add_global(var, var.tk, t, is_inferred) + self:add_global(var, var.tk, t, is_inferred) - dismiss_unresolved(var.tk) + self:dismiss_unresolved(var.tk) end return NONE end, }, ["assignment"] = { before_exp = set_expected_types_to_decltuple, - after = function(node, children) + after = function(self, node, children) local vartuple = children[1] assert(vartuple.typename == "tuple") local vartypes = vartuple.tuple local valtuple = children[3] assert(valtuple.typename == "tuple") - local valtypes = get_assignment_values(valtuple, #vartypes) + local valtypes = get_assignment_values(node, valtuple, #vartypes) for i, vartype in ipairs(vartypes) do local varnode = node.vars[i] local varname = varnode.tk local valtype = valtypes.tuple[i] - local rvar, rval, err = check_assignment(varnode, vartype, valtype, varname, varnode.attribute) + local rvar, rval, err = self:check_assignment(varnode, vartype, valtype) if err == "missing" then if #node.exps == 1 and node.exps[1].kind == "op" and node.exps[1].op.op == "@funcall" then local msg = #valtuple.tuple == 1 and "only 1 value is returned by the function" or ("only " .. #valtuple.tuple .. " values are returned by the function") - add_warning("hint", varnode, msg) + self.errs:add_warning("hint", varnode, msg) end end if rval and rvar then if rval.typename == "function" then - widen_all_unions() + self:widen_all_unions() end if varname and (rvar.typename == "union" or rvar.typename == "interface") then - add_var(varnode, varname, rval, nil, "narrow") + self:add_var(varnode, varname, rval, nil, "narrow") end - if tc then - tc.store_type(varnode.y, varnode.x, valtype) + if self.collector then + self.collector.store_type(varnode.y, varnode.x, valtype) end end end @@ -10758,7 +10763,7 @@ expand_type(node, values, elements) }) end, }, ["if"] = { - after = function(node, _children) + after = function(self, node, _children) local all_return = true for _, b in ipairs(node.if_blocks) do if not b.block_returns then @@ -10768,26 +10773,26 @@ expand_type(node, values, elements) }) end if all_return then node.block_returns = true - infer_negation_of_if_blocks(node, node, #node.if_blocks) + self:infer_negation_of_if_blocks(node, node, #node.if_blocks) end return NONE end, }, ["if_block"] = { - before = function(node) - begin_scope(node) + before = function(self, node) + self:begin_scope(node) if node.if_block_n > 1 then - infer_negation_of_if_blocks(node, node.if_parent, node.if_block_n - 1) + self:infer_negation_of_if_blocks(node, node.if_parent, node.if_block_n - 1) end end, - before_statements = function(node) + before_statements = function(self, node) if node.exp then - apply_facts(node.exp, node.exp.known) + self:apply_facts(node.exp, node.exp.known) end end, - after = function(node, _children) - end_scope(node) + after = function(self, node, _children) + self:end_scope(node) if #node.body > 0 and node.body[#node.body].block_returns then node.block_returns = true @@ -10797,76 +10802,96 @@ expand_type(node, values, elements) }) end, }, ["while"] = { - before = function(node) + before = function(self, node) - widen_all_unions(node) + self:widen_all_unions(node) end, - before_statements = function(node) - begin_scope(node) - apply_facts(node.exp, node.exp.known) + before_statements = function(self, node) + self:begin_scope(node) + self:apply_facts(node.exp, node.exp.known) end, after = end_scope_and_none_type, }, ["label"] = { - before = function(node) - - widen_all_unions() - local label_id = "::" .. node.label .. "::" - if st[#st][label_id] then - error_at(node, "label '" .. node.label .. "' already defined at " .. filename) - end - local unresolved = find_unresolved() - local var = add_var(node, label_id, type_at(node, a_type("none", {}))) - if unresolved then - if unresolved.labels[node.label] then - var.used = true + before = function(self, node) + + self:widen_all_unions() + local label_id = node.label + do + local scope = self.st[#self.st] + scope.labels = scope.labels or {} + if scope.labels[label_id] then + self.errs:add(node, "label '" .. node.label .. "' already defined") + else + scope.labels[label_id] = node end - unresolved.labels[node.label] = nil end + + + local scope = self.st[#self.st] + if scope.pending_labels and scope.pending_labels[label_id] then + node.used_label = true + scope.pending_labels[label_id] = nil + + end + end, after = function() return NONE end, }, ["goto"] = { - after = function(node, _children) - if not find_var_type("::" .. node.label .. "::") then - local unresolved = get_unresolved(st[#st]) - unresolved.labels[node.label] = unresolved.labels[node.label] or {} - table.insert(unresolved.labels[node.label], node) + after = function(self, node, _children) + local label_id = node.label + local found_label + for i = #self.st, 1, -1 do + local scope = self.st[i] + if scope.labels and scope.labels[label_id] then + found_label = scope.labels[label_id] + break + end + end + + if found_label then + found_label.used_label = true + else + local scope = self.st[#self.st] + scope.pending_labels = scope.pending_labels or {} + scope.pending_labels[label_id] = scope.pending_labels[label_id] or {} + table.insert(scope.pending_labels[label_id], node) end return NONE end, }, ["repeat"] = { - before = function(node) + before = function(self, node) - widen_all_unions(node) + self:widen_all_unions(node) end, after = end_scope_and_none_type, }, ["forin"] = { - before = function(node) - begin_scope(node) + before = function(self, node) + self:begin_scope(node) end, - before_statements = function(node, children) + before_statements = function(self, node, children) local exptuple = children[2] assert(exptuple.typename == "tuple") local exptypes = exptuple.tuple - widen_all_unions(node) + self:widen_all_unions(node) local exp1 = node.exps[1] - local args = a_type("tuple", { tuple = { + local args = a_type(node.exps, "tuple", { tuple = { node.exps[2] and exptypes[2], node.exps[3] and exptypes[3], } }) - local exp1type = resolve_for_call(exptypes[1], args, false) + local exp1type = self:resolve_for_call(exptypes[1], args, false) if exp1type.typename == "poly" then local _ - _, exp1type = type_check_function_call(exp1, exp1type, args, 0, exp1, { node.exps[2], node.exps[3] }) + _, exp1type = self:type_check_function_call(exp1, exp1type, args, 0, exp1, { node.exps[2], node.exps[3] }) end if exp1type.typename == "function" then @@ -10879,69 +10904,69 @@ expand_type(node, values, elements) }) if rets.is_va then r = last else - r = lax and UNKNOWN or INVALID + r = self.feat_lax and a_type(v, "unknown", {}) or a_type(v, "invalid", {}) end end - add_var(v, v.tk, r) + self:add_var(v, v.tk, r) - if tc then - tc.store_type(v.y, v.x, r) + if self.collector then + self.collector.store_type(v.y, v.x, r) end last = r end local nrets = #rets.tuple - if (not lax) and (not rets.is_va and #node.vars > nrets) then + if (not self.feat_lax) and (not rets.is_va and #node.vars > nrets) then local at = node.vars[nrets + 1] local n_values = nrets == 1 and "1 value" or tostring(nrets) .. " values" - error_at(at, "too many variables for this iterator; it produces " .. n_values) + self.errs:add(at, "too many variables for this iterator; it produces " .. n_values) end else - if not (lax and is_unknown(exp1type)) then - error_at(exp1, "expression in for loop does not return an iterator") + if not (self.feat_lax and is_unknown(exp1type)) then + self.errs:add(exp1, "expression in for loop does not return an iterator") end end end, after = end_scope_and_none_type, }, ["fornum"] = { - before_statements = function(node, children) - widen_all_unions(node) - begin_scope(node) - local from_t = to_structural(resolve_tuple(children[2])) - local to_t = to_structural(resolve_tuple(children[3])) - local step_t = children[4] and to_structural(children[4]) - local t = (from_t.typename == "integer" and + before_statements = function(self, node, children) + self:widen_all_unions(node) + self:begin_scope(node) + local from_t = self:to_structural(resolve_tuple(children[2])) + local to_t = self:to_structural(resolve_tuple(children[3])) + local step_t = children[4] and self:to_structural(children[4]) + local typename = (from_t.typename == "integer" and to_t.typename == "integer" and (not step_t or step_t.typename == "integer")) and - INTEGER or - NUMBER - add_var(node.var, node.var.tk, t) + "integer" or + "number" + self:add_var(node.var, node.var.tk, a_type(node.var, typename, {})) end, after = end_scope_and_none_type, }, ["return"] = { - before = function(node) - local rets = find_var_type("@return") + before = function(self, node) + local rets = self:find_var_type("@return") if rets and rets.typename == "tuple" then for i, exp in ipairs(node.exps) do exp.expected = rets.tuple[i] end end end, - after = function(node, children) + after = function(self, node, children) local got = children[1] assert(got.typename == "tuple") local got_t = got.tuple local n_got = #got_t node.block_returns = true - local expected = find_var_type("@return") + local expected = self:find_var_type("@return") if not expected then - expected = infer_at(node, got) - module_type = drop_constant_value(to_structural(resolve_tuple(expected))) - st[2]["@return"] = { t = expected } + expected = self:infer_at(node, got) + self.module_type = drop_constant_value(self:to_structural(resolve_tuple(expected))) + self.st[2].vars["@return"] = { t = expected } end local expected_t = expected.tuple @@ -10956,8 +10981,8 @@ expand_type(node, values, elements) }) vatype = expected.is_va and expected.tuple[n_expected] end - if n_got > n_expected and (not lax) and not vatype then - error_at(node, what .. ": excess return values, expected " .. n_expected .. " %s, got " .. n_got .. " %s", expected, got) + if n_got > n_expected and (not self.feat_lax) and not vatype then + self.errs:add(node, what .. ": excess return values, expected " .. n_expected .. " %s, got " .. n_got .. " %s", expected, got) end if n_expected > 1 and @@ -10965,18 +10990,18 @@ expand_type(node, values, elements) }) node.exps[1].kind == "op" and (node.exps[1].op.op == "and" or node.exps[1].op.op == "or") and node.exps[1].discarded_tuple then - add_warning("hint", node.exps[1].e2, "additional return values are being discarded due to '" .. node.exps[1].op.op .. "' expression; suggest parentheses if intentional") + self.errs:add_warning("hint", node.exps[1].e2, "additional return values are being discarded due to '" .. node.exps[1].op.op .. "' expression; suggest parentheses if intentional") end for i = 1, n_got do local e = expected_t[i] or vatype if e then e = resolve_tuple(e) - local where = (node.exps[i] and node.exps[i].x) and + local w = (node.exps[i] and node.exps[i].x) and node.exps[i] or node.exps - assert(where and where.x) - assert_is_a(where, got_t[i], e, what) + assert(w and w.x) + self:assert_is_a(w, got_t[i], e, what) end end @@ -10984,25 +11009,28 @@ expand_type(node, values, elements) }) end, }, ["variable_list"] = { - after = function(node, children) - local tuple = a_type("tuple", { tuple = children }) + after = function(self, node, children) + local tuple = a_type(node, "tuple", { tuple = children }) tuple = flatten_tuple(tuple) for i, t in ipairs(tuple.tuple) do - ensure_not_abstract(node[i], t) + local ok, err = ensure_not_abstract(t) + if not ok then + self.errs:add(node[i], err) + end end return tuple end, }, ["literal_table"] = { - before = function(node) + before = function(self, node) if node.expected then - local decltype = to_structural(node.expected) + local decltype = self:to_structural(node.expected) if decltype.typename == "typevar" and decltype.constraint then - decltype = resolve_typedecl(to_structural(decltype.constraint)) + decltype = resolve_typedecl(self:to_structural(decltype.constraint)) end if decltype.typename == "tupletable" then @@ -11034,19 +11062,19 @@ expand_type(node, values, elements) }) end end end, - after = function(node, children) + after = function(self, node, children) node.known = FACT_TRUTHY if not node.expected then - return infer_table_literal(node, children) + return infer_table_literal(self, node, children) end - local decltype = to_structural(node.expected) + local decltype = self:to_structural(node.expected) local constraint if decltype.typename == "typevar" and decltype.constraint then constraint = resolve_typedecl(decltype.constraint) - decltype = to_structural(constraint) + decltype = self:to_structural(constraint) end if decltype.typename == "union" then @@ -11054,7 +11082,7 @@ expand_type(node, values, elements) }) local single_table_rt for _, t in ipairs(decltype.types) do - local rt = to_structural(t) + local rt = self:to_structural(t) if is_lua_table_type(rt) then if single_table_type then @@ -11075,7 +11103,7 @@ expand_type(node, values, elements) }) end if not is_lua_table_type(decltype) then - return infer_table_literal(node, children) + return infer_table_literal(self, node, children) end local force_array = nil @@ -11085,73 +11113,75 @@ expand_type(node, values, elements) }) for i, child in ipairs(children) do local cvtype = resolve_tuple(child.vtype) local ck = child.kname + local cktype = child.ktype local n = node[i].key.constnum local b = nil - if child.ktype.typename == "boolean" then + if cktype.typename == "boolean" then b = (node[i].key.tk == "true") end - check_redeclared_key(node[i], node.expected_context, seen_keys, ck or n or b) + self.errs:check_redeclared_key(node[i], node, seen_keys, ck or n or b) if decltype.fields and ck then local df = decltype.fields[ck] if not df then - error_at(node[i], in_context(node.expected_context, "unknown field " .. ck)) + self.errs:add_in_context(node[i], node, "unknown field " .. ck) else if df.typename == "typedecl" or df.typename == "typealias" then - error_at(node[i], in_context(node.expected_context, "cannot reassign a type")) + self.errs:add_in_context(node[i], node, "cannot reassign a type") else - assert_is_a(node[i], cvtype, df, "in record field", ck) + self:assert_is_a(node[i], cvtype, df, "in record field", ck) end end - elseif decltype.typename == "tupletable" and is_number_type(child.ktype) then + elseif decltype.typename == "tupletable" and is_numeric_type(cktype) then local dt = decltype.types[n] if not n then - error_at(node[i], in_context(node.expected_context, "unknown index in tuple %s"), decltype) + self.errs:add_in_context(node[i], node, "unknown index in tuple %s", decltype) elseif not dt then - error_at(node[i], in_context(node.expected_context, "unexpected index " .. n .. " in tuple %s"), decltype) + self.errs:add_in_context(node[i], node, "unexpected index " .. n .. " in tuple %s", decltype) else - assert_is_a(node[i], cvtype, dt, in_context(node.expected_context, "in tuple"), "at index " .. tostring(n)) + self:assert_is_a(node[i], cvtype, dt, node, "in tuple: at index " .. tostring(n)) end - elseif decltype.elements and is_number_type(child.ktype) then + elseif decltype.elements and is_numeric_type(cktype) then local cv = child.vtype if cv.typename == "tuple" and i == #children and node[i].key_parsed == "implicit" then for ti, tt in ipairs(cv.tuple) do - assert_is_a(node[i], tt, decltype.elements, in_context(node.expected_context, "expected an array"), "at index " .. tostring(i + ti - 1)) + self:assert_is_a(node[i], tt, decltype.elements, node, "expected an array: at index " .. tostring(i + ti - 1)) end else - assert_is_a(node[i], cvtype, decltype.elements, in_context(node.expected_context, "expected an array"), "at index " .. tostring(n)) + self:assert_is_a(node[i], cvtype, decltype.elements, node, "expected an array: at index " .. tostring(n)) end elseif node[i].key_parsed == "implicit" then if decltype.typename == "map" then - assert_is_a(node[i], INTEGER, decltype.keys, in_context(node.expected_context, "in map key")) - assert_is_a(node[i], cvtype, decltype.values, in_context(node.expected_context, "in map value")) + self:assert_is_a(node[i].key, a_type(node[i].key, "integer", {}), decltype.keys, node, "in map key") + self:assert_is_a(node[i].value, cvtype, decltype.values, node, "in map value") end - force_array = expand_type(node[i], force_array, child.vtype) + force_array = self:expand_type(node[i], force_array, child.vtype) elseif decltype.typename == "map" then force_array = nil - assert_is_a(node[i], child.ktype, decltype.keys, in_context(node.expected_context, "in map key")) - assert_is_a(node[i], cvtype, decltype.values, in_context(node.expected_context, "in map value")) + self:assert_is_a(node[i].key, cktype, decltype.keys, node, "in map key") + self:assert_is_a(node[i].value, cvtype, decltype.values, node, "in map value") else - error_at(node[i], in_context(node.expected_context, "unexpected key of type %s in table of type %s"), child.ktype, decltype) + self.errs:add_in_context(node[i], node, "unexpected key of type %s in table of type %s", cktype, decltype) end end local t if force_array then - t = infer_at(node, a_type("array", { elements = force_array })) + t = self:infer_at(node, a_type(node, "array", { elements = force_array })) else - t = resolve_typevars_at(node, node.expected) + t = self:resolve_typevars_at(node, node.expected) end if decltype.typename == "record" then - local rt = to_structural(t) + local rt = self:to_structural(t) if rt.typename == "record" then node.is_total, node.missing = total_record_check(decltype, seen_keys) end elseif decltype.typename == "map" then - local rt = to_structural(t) + local rt = self:to_structural(t) if rt.typename == "map" then - node.is_total, node.missing = total_map_check(decltype, seen_keys) + local rk = self:to_structural(rt.keys) + node.is_total, node.missing = total_map_check(rk, seen_keys) end end @@ -11163,13 +11193,13 @@ expand_type(node, values, elements) }) end, }, ["literal_table_item"] = { - after = function(node, children) + after = function(self, node, children) local kname = node.key.conststr local ktype = children[1] local vtype = children[2] if node.itemtype then vtype = node.itemtype - assert_is_a(node.value, children[2], node.itemtype, "in table item") + self:assert_is_a(node.value, children[2], node.itemtype, node) end if vtype.typename == "function" and vtype.is_method then @@ -11178,210 +11208,210 @@ expand_type(node, values, elements) }) vtype = shallow_copy_new_type(vtype) vtype.is_method = false end - return type_at(node, a_type("literal_table_item", { + return a_type(node, "literal_table_item", { kname = kname, ktype = ktype, vtype = vtype, - })) + }) end, }, ["local_function"] = { - before = function(node) - widen_all_unions() - if tc then - tc.reserve_symbol_list_slot(node) + before = function(self, node) + self:widen_all_unions() + if self.collector then + self.collector.reserve_symbol_list_slot(node) end - begin_scope(node) + self:begin_scope(node) end, - before_statements = function(node, children) + before_statements = function(self, node, children) local args = children[2] assert(args.typename == "tuple") - add_internal_function_variables(node, args) - add_function_definition_for_recursion(node, args) + self:add_internal_function_variables(node, args) + self:add_function_definition_for_recursion(node, args) end, - after = function(node, children) + after = function(self, node, children) local args = children[2] assert(args.typename == "tuple") local rets = children[3] assert(rets.typename == "tuple") - end_function_scope(node) + self:end_function_scope(node) - local t = type_at(node, ensure_fresh_typeargs(a_function({ + local t = self:ensure_fresh_typeargs(a_function(node, { min_arity = node.min_arity, typeargs = node.typeargs, args = args, - rets = get_rets(rets), - }))) + rets = self.get_rets(rets), + })) - add_var(node, node.name.tk, t) + self:add_var(node, node.name.tk, t) return t end, }, ["local_macroexp"] = { - before = function(node) - widen_all_unions() - if tc then - tc.reserve_symbol_list_slot(node) + before = function(self, node) + self:widen_all_unions() + if self.collector then + self.collector.reserve_symbol_list_slot(node) end - begin_scope(node) + self:begin_scope(node) end, - after = function(node, children) + after = function(self, node, children) local args = children[2] assert(args.typename == "tuple") local rets = children[3] assert(rets.typename == "tuple") - end_function_scope(node) + self:end_function_scope(node) - check_macroexp_arg_use(node.macrodef) + self:check_macroexp_arg_use(node.macrodef) - local t = type_at(node, ensure_fresh_typeargs(a_function({ + local t = self:ensure_fresh_typeargs(a_function(node, { min_arity = node.macrodef.min_arity, typeargs = node.typeargs, args = args, - rets = get_rets(rets), + rets = self.get_rets(rets), macroexp = node.macrodef, - }))) + })) - add_var(node, node.name.tk, t) + self:add_var(node, node.name.tk, t) return t end, }, ["global_function"] = { - before = function(node) - widen_all_unions() - begin_scope(node) + before = function(self, node) + self:widen_all_unions() + self:begin_scope(node) if node.implicit_global_function then - local typ = find_var_type(node.name.tk) + local typ = self:find_var_type(node.name.tk) if typ then if typ.typename == "function" then node.is_predeclared_local_function = true - elseif not lax then - error_at(node, "cannot declare function: type of " .. node.name.tk .. " is %s", typ) + elseif not self.feat_lax then + self.errs:add(node, "cannot declare function: type of " .. node.name.tk .. " is %s", typ) end - elseif not lax then - error_at(node, "functions need an explicit 'local' or 'global' annotation") + elseif not self.feat_lax then + self.errs:add(node, "functions need an explicit 'local' or 'global' annotation") end end end, - before_statements = function(node, children) + before_statements = function(self, node, children) local args = children[2] assert(args.typename == "tuple") - add_internal_function_variables(node, args) - add_function_definition_for_recursion(node, args) + self:add_internal_function_variables(node, args) + self:add_function_definition_for_recursion(node, args) end, - after = function(node, children) + after = function(self, node, children) local args = children[2] assert(args.typename == "tuple") local rets = children[3] assert(rets.typename == "tuple") - end_function_scope(node) + self:end_function_scope(node) if node.is_predeclared_local_function then return NONE end - add_global(node, node.name.tk, type_at(node, ensure_fresh_typeargs(a_function({ + self:add_global(node, node.name.tk, self:ensure_fresh_typeargs(a_function(node, { min_arity = node.min_arity, typeargs = node.typeargs, args = args, - rets = get_rets(rets), - })))) + rets = self.get_rets(rets), + }))) return NONE end, }, ["record_function"] = { - before = function(node) - widen_all_unions() - begin_scope(node) + before = function(self, node) + self:widen_all_unions() + self:begin_scope(node) end, - before_arguments = function(_node, children) - local rtype = to_structural(resolve_typedecl(children[1])) + before_arguments = function(self, _node, children) + local rtype = self:to_structural(resolve_typedecl(children[1])) if rtype.fields and rtype.typeargs then for _, typ in ipairs(rtype.typeargs) do - add_var(nil, typ.typearg, type_at(typ, a_type("typearg", { + self:add_var(nil, typ.typearg, a_type(typ, "typearg", { typearg = typ.typearg, constraint = typ.constraint, - }))) + })) end end end, - before_statements = function(node, children) + before_statements = function(self, node, children) local args = children[3] assert(args.typename == "tuple") local rets = children[4] assert(rets.typename == "tuple") - local rtype = to_structural(resolve_typedecl(children[1])) + local rtype = self:to_structural(resolve_typedecl(children[1])) - if lax and rtype.typename == "unknown" then + if self.feat_lax and rtype.typename == "unknown" then return end if rtype.typename == "emptytable" then - edit_type(rtype, "record") + edit_type(rtype, rtype, "record") local r = rtype r.fields = {} r.field_order = {} end if not rtype.fields then - error_at(node, "not a record: %s", rtype) + self.errs:add(node, "not a record: %s", rtype) return end - local selftype = get_self_type(node.fn_owner) + local selftype = self:get_self_type(node.fn_owner) if node.is_method then if not selftype then - error_at(node, "could not resolve type of self") + self.errs:add(node, "could not resolve type of self") return end args.tuple[1] = selftype - add_var(nil, "self", selftype) + self:add_var(nil, "self", selftype) end - local fn_type = type_at(node, ensure_fresh_typeargs(a_function({ + local fn_type = self:ensure_fresh_typeargs(a_function(node, { min_arity = node.min_arity, is_method = node.is_method, typeargs = node.typeargs, args = args, - rets = get_rets(rets), - }))) + rets = self.get_rets(rets), + })) - local open_t, open_v, owner_name = find_record_to_extend(node.fn_owner) + local open_t, open_v, owner_name = self:find_record_to_extend(node.fn_owner) local open_k = owner_name .. "." .. node.name.tk local rfieldtype = rtype.fields[node.name.tk] if rfieldtype then - rfieldtype = to_structural(rfieldtype) + rfieldtype = self:to_structural(rfieldtype) if open_v and open_v.implemented and open_v.implemented[open_k] then - redeclaration_warning(node) + self.errs:redeclaration_warning(node) end - local ok, err = same_type(fn_type, rfieldtype) + local ok, err = self:same_type(fn_type, rfieldtype) if not ok then if rfieldtype.typename == "poly" then - add_errs_prefixing(node, err, errors, "type signature does not match declaration: field has multiple function definitions (such polymorphic declarations are intended for Lua module interoperability)") + self.errs:add_prefixing(node, err, "type signature does not match declaration: field has multiple function definitions (such polymorphic declarations are intended for Lua module interoperability): ") return end local shortname = selftype and show_type(selftype) or owner_name local msg = "type signature of '" .. node.name.tk .. "' does not match its declaration in " .. shortname .. ": " - add_errs_prefixing(node, err, errors, msg) + self.errs:add_prefixing(node, err, msg) return end else - if lax or rtype == open_t then + if self.feat_lax or rtype == open_t then rtype.fields[node.name.tk] = fn_type table.insert(rtype.field_order, node.name.tk) else - error_at(node, "cannot add undeclared function '" .. node.name.tk .. "' outside of the scope where '" .. owner_name .. "' was originally declared") + self.errs:add(node, "cannot add undeclared function '" .. node.name.tk .. "' outside of the scope where '" .. owner_name .. "' was originally declared") return end @@ -11394,82 +11424,82 @@ expand_type(node, values, elements) }) open_v.implemented[open_k] = true end - add_internal_function_variables(node, args) + self:add_internal_function_variables(node, args) end, - after = function(node, _children) - end_function_scope(node) + after = function(self, node, _children) + self:end_function_scope(node) return NONE end, }, ["function"] = { - before = function(node) - widen_all_unions(node) - begin_scope(node) + before = function(self, node) + self:widen_all_unions(node) + self:begin_scope(node) end, - before_statements = function(node, children) + before_statements = function(self, node, children) local args = children[1] assert(args.typename == "tuple") - add_internal_function_variables(node, args) + self:add_internal_function_variables(node, args) end, - after = function(node, children) + after = function(self, node, children) local args = children[1] assert(args.typename == "tuple") local rets = children[2] assert(rets.typename == "tuple") - end_function_scope(node) - return type_at(node, ensure_fresh_typeargs(a_function({ + self:end_function_scope(node) + return self:ensure_fresh_typeargs(a_function(node, { min_arity = node.min_arity, typeargs = node.typeargs, args = args, rets = rets, - }))) + })) end, }, ["macroexp"] = { - before = function(node) - widen_all_unions(node) - begin_scope(node) + before = function(self, node) + self:widen_all_unions(node) + self:begin_scope(node) end, - before_exp = function(node, children) + before_exp = function(self, node, children) local args = children[1] assert(args.typename == "tuple") - add_internal_function_variables(node, args) + self:add_internal_function_variables(node, args) end, - after = function(node, children) + after = function(self, node, children) local args = children[1] assert(args.typename == "tuple") local rets = children[2] assert(rets.typename == "tuple") - end_function_scope(node) - return type_at(node, ensure_fresh_typeargs(a_function({ + self:end_function_scope(node) + return self:ensure_fresh_typeargs(a_function(node, { min_arity = node.min_arity, typeargs = node.typeargs, args = args, rets = rets, - }))) + })) end, }, ["cast"] = { - after = function(node, _children) + after = function(_self, node, _children) return node.casttype end, }, ["paren"] = { - before = function(node) + before = function(_self, node) node.e1.expected = node.expected end, - after = function(node, children) + after = function(_self, node, children) node.known = node.e1 and node.e1.known return resolve_tuple(children[1]) end, }, ["op"] = { - before = function(node) - begin_scope() + before = function(self, node) + self:begin_scope() if node.expected then if node.op.op == "and" then node.e2.expected = node.expected @@ -11481,18 +11511,19 @@ expand_type(node, values, elements) }) end end end, - before_e2 = function(node, children) + before_e2 = function(self, node, children) local e1type = children[1] if node.op.op == "and" then - apply_facts(node, node.e1.known) + self:apply_facts(node, node.e1.known) elseif node.op.op == "or" then - apply_facts(node, facts_not(node, node.e1.known)) + self:apply_facts(node, facts_not(node, node.e1.known)) elseif node.op.op == "@funcall" then if e1type.typename == "function" then local argdelta = (node.e1.op and node.e1.op.op == ":") and -1 or 0 if node.expected then - is_a(e1type.rets, node.expected) + + self:is_a(e1type.rets, node.expected) end local e1args = e1type.args.tuple local at = argdelta @@ -11515,8 +11546,8 @@ expand_type(node, values, elements) }) end end end, - after = function(node, children) - end_scope() + after = function(self, node, children) + self:end_scope() local ga = children[1] @@ -11527,29 +11558,33 @@ expand_type(node, values, elements) }) local ub - local ra = to_structural(ua) + local ra = self:to_structural(ua) local rb if ra.typename == "circular_require" or (ra.typename == "typedecl" and ra.def and ra.def.typename == "circular_require") then - return invalid_at(node, "cannot dereference a type from a circular require") + return self.errs:invalid_at(node, "cannot dereference a type from a circular require") end if node.op.op == "@funcall" then - if lax and is_unknown(ua) then + if self.feat_lax and is_unknown(ua) then if node.e1.op and node.e1.op.op == ":" and node.e1.e1.kind == "variable" then - add_unknown_dot(node, node.e1.e1.tk .. "." .. node.e1.e2.tk) + self.errs:add_unknown_dot(node, node.e1.e1.tk .. "." .. node.e1.e2.tk) end end - local t = type_check_funcall(node, ua, gb) + assert(gb.typename == "tuple") + local t = self:type_check_funcall(node, ua, gb) return t elseif node.op.op == "as" then return gb end - local expected = node.expected and to_structural(resolve_tuple(node.expected)) + local expected = node.expected and self:to_structural(resolve_tuple(node.expected)) - ensure_not_abstract(node.e1, ra) + local ok, err = ensure_not_abstract(ra) + if not ok then + self.errs:add(node.e1, err) + end if ra.typename == "typedecl" and ra.def.typename == "record" then ra = ra.def end @@ -11558,8 +11593,11 @@ expand_type(node, values, elements) }) if gb then ub = resolve_tuple(gb) - rb = to_structural(ub) - ensure_not_abstract(node.e2, rb) + rb = self:to_structural(ub) + ok, err = ensure_not_abstract(rb) + if not ok then + self.errs:add(node.e2, err) + end if rb.typename == "typedecl" and rb.def.typename == "record" then rb = rb.def end @@ -11569,22 +11607,20 @@ expand_type(node, values, elements) }) node.receiver = ua assert(node.e2.kind == "identifier") - local bnode = { - y = node.e2.y, - x = node.e2.x, + local bnode = node_at(node.e2, { tk = node.e2.tk, kind = "string", - } - local btype = type_at(node.e2, a_type("string", { literal = node.e2.tk })) - local t = type_check_index(node.e1, bnode, ua, btype) + }) + local btype = a_type(node.e2, "string", { literal = node.e2.tk }) + local t = self:type_check_index(node.e1, bnode, ua, btype) - if t.needs_compat and opts.gen_compat ~= "off" then + if t.needs_compat and self.gen_compat ~= "off" then if node.e1.kind == "variable" and node.e2.kind == "identifier" then local key = node.e1.tk .. "." .. node.e2.tk node.kind = "variable" node.tk = "_tl_" .. node.e1.tk .. "_" .. node.e2.tk - all_needs_compat[key] = true + self.all_needs_compat[key] = true end end @@ -11592,22 +11628,22 @@ expand_type(node, values, elements) }) end if node.op.op == "@index" then - return type_check_index(node.e1, node.e2, ua, ub) + return self:type_check_index(node.e1, node.e2, ua, ub) end if node.op.op == "is" then if rb.typename == "integer" then - all_needs_compat["math"] = true + self.all_needs_compat["math"] = true end if ra.typename == "typedecl" then - error_at(node, "can only use 'is' on variables, not types") + self.errs:add(node, "can only use 'is' on variables, not types") elseif node.e1.kind == "variable" then - check_metamethod(node, "__is", ra, resolve_typedecl(rb), ua, ub) - node.known = IsFact({ var = node.e1.tk, typ = ub, where = node }) + self:check_metamethod(node, "__is", ra, resolve_typedecl(rb), ua, ub) + node.known = IsFact({ var = node.e1.tk, typ = ub, w = node }) else - error_at(node, "can only use 'is' on variables") + self.errs:add(node, "can only use 'is' on variables") end - return BOOLEAN + return a_type(node, "boolean", {}) end if node.op.op == ":" then @@ -11615,16 +11651,16 @@ expand_type(node, values, elements) }) - if lax and (is_unknown(ua) or ua.typename == "typevar") then + if self.feat_lax and (is_unknown(ua) or ua.typename == "typevar") then if node.e1.kind == "variable" then - add_unknown_dot(node.e1, node.e1.tk .. "." .. node.e2.tk) + self.errs:add_unknown_dot(node.e1, node.e1.tk .. "." .. node.e2.tk) end - return UNKNOWN + return a_type(node, "unknown", {}) end - local t, e = match_record_key(ra, node.e1, node.e2.conststr or node.e2.tk) + local t, e = self:match_record_key(ra, node.e1, node.e2.conststr or node.e2.tk) if not t then - return invalid_at(node.e2, e, ua) + return self.errs:invalid_at(node.e2, e, ua) end return t @@ -11632,7 +11668,7 @@ expand_type(node, values, elements) }) if node.op.op == "not" then node.known = facts_not(node, node.e1.known) - return BOOLEAN + return a_type(node, "boolean", {}) end if node.op.op == "and" then @@ -11650,33 +11686,33 @@ expand_type(node, values, elements) }) node.known = nil t = ua - elseif ((ra.typename == "enum" and rb.typename == "string" and is_a(rb, ra)) or - (ra.typename == "string" and rb.typename == "enum" and is_a(ra, rb))) then + elseif ((ra.typename == "enum" and rb.typename == "string" and self:is_a(rb, ra)) or + (ra.typename == "string" and rb.typename == "enum" and self:is_a(ra, rb))) then node.known = nil t = (ra.typename == "enum" and ra or rb) elseif expected and expected.typename == "union" then node.known = facts_or(node, node.e1.known, node.e2.known) - local u = unite({ ra, rb }, true) + local u = unite(node, { ra, rb }, true) if u.typename == "union" then - local ok, err = is_valid_union(u) + ok, err = is_valid_union(u) if not ok then - u = err and invalid_at(node, err, u) or INVALID + u = err and self.errs:invalid_at(node, err, u) or a_type(node, "invalid", {}) end end t = u else - local a_ge_b = is_a(rb, ra) - local b_ge_a = is_a(ra, rb) + local a_ge_b = self:is_a(rb, ra) + local b_ge_a = self:is_a(ra, rb) if a_ge_b or b_ge_a then node.known = facts_or(node, node.e1.known, node.e2.known) if expected then - local a_is = is_a(ua, expected) - local b_is = is_a(ub, expected) + local a_is = self:is_a(ua, expected) + local b_is = self:is_a(ub, expected) if a_is and b_is then - t = resolve_typevars_at(node, expected) + t = self:resolve_typevars_at(node, expected) end end if not t then @@ -11700,39 +11736,41 @@ expand_type(node, values, elements) }) if ra.typename == "enum" and rb.typename == "string" then if not (rb.literal and ra.enumset[rb.literal]) then - return invalid_at(node, "%s is not a member of %s", ub, ua) + return self.errs:invalid_at(node, "%s is not a member of %s", ub, ua) end elseif ra.typename == "tupletable" and rb.typename == "tupletable" and #ra.types ~= #rb.types then - return invalid_at(node, "tuples are not the same size") - elseif is_a(ub, ua) or ua.typename == "typevar" then + return self.errs:invalid_at(node, "tuples are not the same size") + elseif self:is_a(ub, ua) or ua.typename == "typevar" then if node.op.op == "==" and node.e1.kind == "variable" then - node.known = EqFact({ var = node.e1.tk, typ = ub, where = node }) + node.known = EqFact({ var = node.e1.tk, typ = ub, w = node }) end - elseif is_a(ua, ub) or ub.typename == "typevar" then + elseif self:is_a(ua, ub) or ub.typename == "typevar" then if node.op.op == "==" and node.e2.kind == "variable" then - node.known = EqFact({ var = node.e2.tk, typ = ua, where = node }) + node.known = EqFact({ var = node.e2.tk, typ = ua, w = node }) end - elseif lax and (is_unknown(ua) or is_unknown(ub)) then - return UNKNOWN + elseif self.feat_lax and (is_unknown(ua) or is_unknown(ub)) then + return a_type(node, "unknown", {}) else - return invalid_at(node, "types are not comparable for equality: %s and %s", ua, ub) + return self.errs:invalid_at(node, "types are not comparable for equality: %s and %s", ua, ub) end - return BOOLEAN + return a_type(node, "boolean", {}) end if node.op.arity == 1 and unop_types[node.op.op] then if ra.typename == "union" then - ra = unite(ra.types, true) + ra = unite(node, ra.types, true) end local types_op = unop_types[node.op.op] - local t = types_op[ra.typename] + local tn = types_op[ra.typename] + local t = tn and a_type(node, tn, {}) if not t and ra.fields then t = find_in_interface_list(ra, function(ty) - return types_op[ty.typename] + local tname = types_op[ty.typename] + return tname and a_type(node, tname, {}) end) end @@ -11740,19 +11778,18 @@ expand_type(node, values, elements) }) if not t then local mt_name = unop_to_metamethod[node.op.op] if mt_name then - t, meta_on_operator = check_metamethod(node, mt_name, ra, nil, ua, nil) + t, meta_on_operator = self:check_metamethod(node, mt_name, ra, nil, ua, nil) end if not t then - error_at(node, "cannot use operator '" .. node.op.op:gsub("%%", "%%%%") .. "' on type %s", ua) - t = INVALID + t = self.errs:invalid_at(node, "cannot use operator '" .. node.op.op:gsub("%%", "%%%%") .. "' on type %s", ua) end end if ra.typename == "map" then if ra.keys.typename == "number" or ra.keys.typename == "integer" then - add_warning("hint", node, "using the '#' operator on a map with numeric key type may produce unexpected results") + self.errs:add_warning("hint", node, "using the '#' operator on a map with numeric key type may produce unexpected results") else - error_at(node, "using the '#' operator on this map will always return 0") + self.errs:add(node, "using the '#' operator on this map will always return 0") end end @@ -11760,12 +11797,12 @@ expand_type(node, values, elements) }) node.known = FACT_TRUTHY end - if node.op.op == "~" and env.gen_target == "5.1" then + if node.op.op == "~" and self.gen_target == "5.1" then if meta_on_operator then - all_needs_compat["mt"] = true + self.all_needs_compat["mt"] = true convert_node_to_compat_mt_call(node, unop_to_metamethod[node.op.op], 1, node.e1) else - all_needs_compat["bit32"] = true + self.all_needs_compat["bit32"] = true convert_node_to_compat_call(node, "bit32", "bnot", node.e1) end end @@ -11779,39 +11816,39 @@ expand_type(node, values, elements) }) end if ra.typename == "union" then - ra = unite(ra.types, true) + ra = unite(ra, ra.types, true) end if rb.typename == "union" then - rb = unite(rb.types, true) + rb = unite(rb, rb.types, true) end local types_op = binop_types[node.op.op] - local t = types_op[ra.typename] and types_op[ra.typename][rb.typename] + local tn = types_op[ra.typename] and types_op[ra.typename][rb.typename] + local t = tn and a_type(node, tn, {}) local meta_on_operator if not t then local mt_name = binop_to_metamethod[node.op.op] if mt_name then - t, meta_on_operator = check_metamethod(node, mt_name, ra, rb, ua, ub) + t, meta_on_operator = self:check_metamethod(node, mt_name, ra, rb, ua, ub) end if not t then - error_at(node, "cannot use operator '" .. node.op.op:gsub("%%", "%%%%") .. "' for types %s and %s", ua, ub) - t = INVALID + t = self.errs:invalid_at(node, "cannot use operator '" .. node.op.op:gsub("%%", "%%%%") .. "' for types %s and %s", ua, ub) if node.op.op == "or" then - local u = unite({ ua, ub }) + local u = unite(node, { ua, ub }) if u.typename == "union" and is_valid_union(u) then - add_warning("hint", node, "if a union type was intended, consider declaring it explicitly") + self.errs:add_warning("hint", node, "if a union type was intended, consider declaring it explicitly") end end end end if ua.typename == "nominal" and ub.typename == "nominal" and not meta_on_operator then - if is_a(ua, ub) then + if self:is_a(ua, ub) then t = ua else - error_at(node, "cannot use operator '" .. node.op.op:gsub("%%", "%%%%") .. "' for distinct nominal types %s and %s", ua, ub) + self.errs:add(node, "cannot use operator '" .. node.op.op:gsub("%%", "%%%%") .. "' for distinct nominal types %s and %s", ua, ub) end end @@ -11819,20 +11856,20 @@ expand_type(node, values, elements) }) node.known = FACT_TRUTHY end - if node.op.op == "//" and env.gen_target == "5.1" then + if node.op.op == "//" and self.gen_target == "5.1" then if meta_on_operator then - all_needs_compat["mt"] = true + self.all_needs_compat["mt"] = true convert_node_to_compat_mt_call(node, "__idiv", meta_on_operator, node.e1, node.e2) else - local div = { y = node.y, x = node.x, kind = "op", op = an_operator(node, 2, "/"), e1 = node.e1, e2 = node.e2 } + local div = node_at(node, { kind = "op", op = an_operator(node, 2, "/"), e1 = node.e1, e2 = node.e2 }) convert_node_to_compat_call(node, "math", "floor", div) end - elseif bit_operators[node.op.op] and env.gen_target == "5.1" then + elseif bit_operators[node.op.op] and self.gen_target == "5.1" then if meta_on_operator then - all_needs_compat["mt"] = true + self.all_needs_compat["mt"] = true convert_node_to_compat_mt_call(node, binop_to_metamethod[node.op.op], meta_on_operator, node.e1, node.e2) else - all_needs_compat["bit32"] = true + self.all_needs_compat["bit32"] = true convert_node_to_compat_call(node, "bit32", bit_operators[node.op.op], node.e1, node.e2) end end @@ -11844,28 +11881,28 @@ expand_type(node, values, elements) }) end, }, ["variable"] = { - after = function(node, _children) + after = function(self, node, _children) if node.tk == "..." then - local va_sentinel = find_var_type("@is_va") + local va_sentinel = self:find_var_type("@is_va") if not va_sentinel or va_sentinel.typename == "nil" then - return invalid_at(node, "cannot use '...' outside a vararg function") + return self.errs:invalid_at(node, "cannot use '...' outside a vararg function") end end local t if node.tk == "_G" then - t, node.attribute = simulate_g() + t, node.attribute = self:simulate_g() else local use = node.is_lvalue and "lvalue" or "use" - t, node.attribute = find_var_type(node.tk, use) + t, node.attribute = self:find_var_type(node.tk, use) end if not t then - if lax then - add_unknown(node, node.tk) - return UNKNOWN + if self.feat_lax then + self.errs:add_unknown(node, node.tk) + return a_type(node, "unknown", {}) end - return invalid_at(node, "unknown variable: " .. node.tk) + return self.errs:invalid_at(node, "unknown variable: " .. node.tk) end if t.typename == "typedecl" then @@ -11876,70 +11913,70 @@ expand_type(node, values, elements) }) end, }, ["type_identifier"] = { - after = function(node, _children) - local typ, attr = find_var_type(node.tk) + after = function(self, node, _children) + local typ, attr = self:find_var_type(node.tk) node.attribute = attr if typ then return typ end - if lax then - add_unknown(node, node.tk) - return UNKNOWN + if self.feat_lax then + self.errs:add_unknown(node, node.tk) + return a_type(node, "unknown", {}) end - return invalid_at(node, "unknown variable: " .. node.tk) + return self.errs:invalid_at(node, "unknown variable: " .. node.tk) end, }, ["argument"] = { - after = function(node, children) + after = function(self, node, children) local t = children[1] if not t then - t = UNKNOWN + t = a_type(node, "unknown", {}) end if node.tk == "..." then - t = a_vararg({ t }) + t = a_vararg(node, { t }) end - add_var(node, node.tk, t).is_func_arg = true + self:add_var(node, node.tk, t).is_func_arg = true return t end, }, ["identifier"] = { - after = function(_node, _children) + after = function(_self, _node, _children) return NONE end, }, ["newtype"] = { - after = function(node, _children) + after = function(_self, node, _children) return node.newtype end, }, ["error_node"] = { - after = function(_node, _children) - return INVALID + after = function(_self, node, _children) + return a_type(node, "invalid", {}) end, }, } visit_node.cbs["break"] = { - after = function(_node, _children) + after = function(_self, _node, _children) return NONE end, } visit_node.cbs["do"] = visit_node.cbs["break"] - local function after_literal(node) + local function after_literal(_self, node) node.known = FACT_TRUTHY - return type_at(node, a_type(node.kind, {})) + return a_type(node, node.kind, {}) end visit_node.cbs["string"] = { - after = function(node, _children) - local t = after_literal(node) + after = function(self, node, _children) + local t = after_literal(self, node) t.literal = node.conststr - local expected = node.expected and to_structural(node.expected) - if expected and expected.typename == "enum" and is_a(t, expected) then + local expected = node.expected and self:to_structural(node.expected) + if expected and expected.typename == "enum" and self:is_a(t, expected) then return node.expected end @@ -11950,8 +11987,8 @@ expand_type(node, values, elements) }) visit_node.cbs["integer"] = { after = after_literal } visit_node.cbs["boolean"] = { - after = function(node, _children) - local t = after_literal(node) + after = function(self, node, _children) + local t = after_literal(self, node) node.known = (node.tk == "true") and FACT_TRUTHY or nil return t end, @@ -11962,7 +11999,7 @@ expand_type(node, values, elements) }) visit_node.cbs["argument_list"] = visit_node.cbs["variable_list"] visit_node.cbs["expression_list"] = visit_node.cbs["variable_list"] - visit_node.after = function(node, _children, t) + visit_node.after = function(_self, node, _children, t) if node.expanded then apply_macroexp(node) end @@ -11970,13 +12007,12 @@ expand_type(node, values, elements) }) return t end - local expand_interfaces do - local function add_interface_fields(what, fields, field_order, resolved, named, list) + local function add_interface_fields(self, what, fields, field_order, resolved, named, list) for fname, ftype in fields_of(resolved, list) do if fields[fname] then - if not is_a(fields[fname], ftype) then - error_at(fields[fname], what .. " '" .. fname .. "' does not match definition in interface %s", named) + if not self:is_a(fields[fname], ftype) then + self.errs:add(fields[fname], what .. " '" .. fname .. "' does not match definition in interface %s", named) end else table.insert(field_order, fname) @@ -11985,18 +12021,21 @@ expand_type(node, values, elements) }) end end - local function collect_interfaces(list, t, seen) + local function collect_interfaces(self, list, t, seen) if t.interface_list then for _, iface in ipairs(t.interface_list) do if iface.typename == "nominal" then - local ri = resolve_nominal(iface) + local ri = self:resolve_nominal(iface) if not (ri.typename == "invalid") then - assert(ri.typename == "interface", "nominal resolved to " .. ri.typename) - if not ri.interfaces_expanded and not seen[ri] then - seen[ri] = true - collect_interfaces(list, ri, seen) + if ri.typename == "interface" then + if not ri.interfaces_expanded and not seen[ri] then + seen[ri] = true + collect_interfaces(self, list, ri, seen) + end + table.insert(list, iface) + else + self.errs:add(iface, "attempted to use %s as interface, but its type is %s", iface, ri) end - table.insert(list, iface) end else if not seen[iface] then @@ -12009,30 +12048,30 @@ expand_type(node, values, elements) }) return list end - expand_interfaces = function(t) + function TypeChecker:expand_interfaces(t) if t.interfaces_expanded then return end t.interfaces_expanded = true - t.interface_list = collect_interfaces({}, t, {}) + t.interface_list = collect_interfaces(self, {}, t, {}) for _, iface in ipairs(t.interface_list) do if iface.typename == "nominal" then - local ri = resolve_nominal(iface) + local ri = self:resolve_nominal(iface) assert(ri.typename == "interface") - add_interface_fields("field", t.fields, t.field_order, ri, iface) + add_interface_fields(self, "field", t.fields, t.field_order, ri, iface) if ri.meta_fields then t.meta_fields = t.meta_fields or {} t.meta_field_order = t.meta_field_order or {} - add_interface_fields("metamethod", t.meta_fields, t.meta_field_order, ri, iface, "meta") + add_interface_fields(self, "metamethod", t.meta_fields, t.meta_field_order, ri, iface, "meta") end else if not t.elements then t.elements = iface else - if not same_type(iface.elements, t.elements) then - error_at(t, "incompatible array interfaces") + if not self:same_type(iface.elements, t.elements) then + self.errs:add(t, "incompatible array interfaces") end end end @@ -12044,29 +12083,29 @@ expand_type(node, values, elements) }) visit_type = { cbs = { ["function"] = { - before = function(_typ) - begin_scope() + before = function(self, _typ) + self:begin_scope() end, - after = function(typ, _children) - end_scope() - return ensure_fresh_typeargs(typ) + after = function(self, typ, _children) + self:end_scope() + return self:ensure_fresh_typeargs(typ) end, }, ["record"] = { - before = function(typ) - begin_scope() - add_var(nil, "@self", type_at(typ, a_type("typedecl", { def = typ }))) + before = function(self, typ) + self:begin_scope() + self:add_var(nil, "@self", type_at(typ, a_type(typ, "typedecl", { def = typ }))) for fname, ftype in fields_of(typ) do if ftype.typename == "typealias" then - resolve_nominal(ftype.alias_to) - add_var(nil, fname, ftype) + self:resolve_nominal(ftype.alias_to) + self:add_var(nil, fname, ftype) elseif ftype.typename == "typedecl" then - add_var(nil, fname, ftype) + self:add_var(nil, fname, ftype) end end end, - after = function(typ, children) + after = function(self, typ, children) local i = 1 if typ.typeargs then for _, _ in ipairs(typ.typeargs) do @@ -12080,11 +12119,11 @@ expand_type(node, values, elements) }) if iface.typename == "array" then typ.interface_list[j] = iface elseif iface.typename == "nominal" then - local ri = resolve_nominal(iface) + local ri = self:resolve_nominal(iface) if ri.typename == "interface" then typ.interface_list[j] = iface else - error_at(children[i], "%s is not an interface", children[i]) + self.errs:add(children[i], "%s is not an interface", children[i]) end end i = i + 1 @@ -12124,7 +12163,7 @@ expand_type(node, values, elements) }) end end elseif ftype.typename == "typealias" then - resolve_typealias(ftype) + self:resolve_typealias(ftype) end typ.fields[name] = ftype @@ -12143,55 +12182,55 @@ expand_type(node, values, elements) }) end if typ.interface_list then - expand_interfaces(typ) + self:expand_interfaces(typ) end if fmacros then for _, t in ipairs(fmacros) do - local macroexp_type = recurse_node(t.macroexp, visit_node, visit_type) + local macroexp_type = recurse_node(self, t.macroexp, visit_node, visit_type) - check_macroexp_arg_use(t.macroexp) + self:check_macroexp_arg_use(t.macroexp) - if not is_a(macroexp_type, t) then - error_at(macroexp_type, "macroexp type does not match declaration") + if not self:is_a(macroexp_type, t) then + self.errs:add(macroexp_type, "macroexp type does not match declaration") end end end - end_scope() + self:end_scope() return typ end, }, ["typearg"] = { - after = function(typ, _children) - add_var(nil, typ.typearg, type_at(typ, a_type("typearg", { + after = function(self, typ, _children) + self:add_var(nil, typ.typearg, a_type(typ, "typearg", { typearg = typ.typearg, constraint = typ.constraint, - }))) + })) return typ end, }, ["typevar"] = { - after = function(typ, _children) - if not find_var_type(typ.typevar) then - error_at(typ, "undefined type variable " .. typ.typevar) + after = function(self, typ, _children) + if not self:find_var_type(typ.typevar) then + self.errs:add(typ, "undefined type variable " .. typ.typevar) end return typ end, }, ["nominal"] = { - after = function(typ, _children) + after = function(self, typ, _children) if typ.found then return typ end - local t = find_type(typ.names, true) + local t = self:find_type(typ.names, true) if t then if t.typename == "typearg" then typ.names = nil - edit_type(typ, "typevar") + edit_type(typ, typ, "typevar") local tv = typ tv.typevar = t.typearg tv.constraint = t.constraint @@ -12202,18 +12241,19 @@ expand_type(node, values, elements) }) end else local name = typ.names[1] - local unresolved = get_unresolved() - unresolved.nominals[name] = unresolved.nominals[name] or {} - table.insert(unresolved.nominals[name], typ) + local scope = self.st[#self.st] + scope.pending_nominals = scope.pending_nominals or {} + scope.pending_nominals[name] = scope.pending_nominals[name] or {} + table.insert(scope.pending_nominals[name], typ) end return typ end, }, ["union"] = { - after = function(typ, _children) + after = function(self, typ, _children) local ok, err = is_valid_union(typ) if not ok then - return err and invalid_at(typ, err, typ) or INVALID + return err and self.errs:invalid_at(typ, err, typ) or a_type(typ, "invalid", {}) end return typ end, @@ -12221,15 +12261,47 @@ expand_type(node, values, elements) }) }, } + local default_type_visitor = { + after = function(_self, typ, _children) + return typ + end, + } + + visit_type.cbs["interface"] = visit_type.cbs["record"] + + visit_type.cbs["string"] = default_type_visitor + visit_type.cbs["tupletable"] = default_type_visitor + visit_type.cbs["typedecl"] = default_type_visitor + visit_type.cbs["typealias"] = default_type_visitor + visit_type.cbs["array"] = default_type_visitor + visit_type.cbs["map"] = default_type_visitor + visit_type.cbs["enum"] = default_type_visitor + visit_type.cbs["boolean"] = default_type_visitor + visit_type.cbs["nil"] = default_type_visitor + visit_type.cbs["number"] = default_type_visitor + visit_type.cbs["integer"] = default_type_visitor + visit_type.cbs["thread"] = default_type_visitor + visit_type.cbs["emptytable"] = default_type_visitor + visit_type.cbs["literal_table_item"] = default_type_visitor + visit_type.cbs["unresolved_emptytable_value"] = default_type_visitor + visit_type.cbs["tuple"] = default_type_visitor + visit_type.cbs["poly"] = default_type_visitor + visit_type.cbs["any"] = default_type_visitor + visit_type.cbs["unknown"] = default_type_visitor + visit_type.cbs["invalid"] = default_type_visitor + visit_type.cbs["none"] = default_type_visitor + + + local function internal_compiler_check(fn) - return function(w, children, t) - t = fn and fn(w, children, t) or t + return function(s, n, children, t) + t = fn and fn(s, n, children, t) or t if type(t) ~= "table" then - error(((w).kind or (w).typename) .. " did not produce a type") + error(((n).kind or (n).typename) .. " did not produce a type") end if type(t.typename) ~= "string" then - error(((w).kind or (w).typename) .. " type does not have a typename") + error(((n).kind or (n).typename) .. " type does not have a typename") end return t @@ -12237,13 +12309,13 @@ expand_type(node, values, elements) }) end local function store_type_after(fn) - return function(w, children, t) - t = fn and fn(w, children, t) or t + return function(self, n, children, t) + t = fn and fn(self, n, children, t) or t - local where = w + local w = n - if where.y then - tc.store_type(where.y, where.x, t) + if w.y then + self.collector.store_type(w.y, w.x, t) end return t @@ -12251,119 +12323,167 @@ expand_type(node, values, elements) }) end local function debug_type_after(fn) - return function(node, children, t) - t = fn and fn(node, children, t) or t + return function(s, node, children, t) + t = fn and fn(s, node, children, t) or t + node.debug_type = t return t end end - if opts.run_internal_compiler_checks then - visit_node.after = internal_compiler_check(visit_node.after) - visit_type.after = internal_compiler_check(visit_type.after) - end + local function patch_visitors(my_visit_node, + after_node, + my_visit_type, + after_type) + - if tc then - visit_node.after = store_type_after(visit_node.after) - visit_type.after = store_type_after(visit_type.after) + if my_visit_node == visit_node then + my_visit_node = shallow_copy_table(my_visit_node) + end + my_visit_node.after = after_node(my_visit_node.after) + if my_visit_type then + if my_visit_type == visit_type then + my_visit_type = shallow_copy_table(my_visit_type) + end + my_visit_type.after = after_type(my_visit_type.after) + else + my_visit_type = visit_type + end + return my_visit_node, my_visit_type end - if TL_DEBUG then - visit_node.after = debug_type_after(visit_node.after) + local function set_feat(feat, default) + if feat then + return (feat == "on") + else + return default + end end - local default_type_visitor = { - after = function(typ, _children) - return typ - end, - } + tl.type_check = function(ast, filename, opts, env) + assert(type(filename) == "string", "tl.type_check signature has changed, pass filename separately") + assert((not opts) or (not (opts).env), "tl.type_check signature has changed, pass env separately") - visit_type.cbs["interface"] = visit_type.cbs["record"] + filename = filename or "?" - visit_type.cbs["string"] = default_type_visitor - visit_type.cbs["tupletable"] = default_type_visitor - visit_type.cbs["typedecl"] = default_type_visitor - visit_type.cbs["typealias"] = default_type_visitor - visit_type.cbs["array"] = default_type_visitor - visit_type.cbs["map"] = default_type_visitor - visit_type.cbs["enum"] = default_type_visitor - visit_type.cbs["boolean"] = default_type_visitor - visit_type.cbs["nil"] = default_type_visitor - visit_type.cbs["number"] = default_type_visitor - visit_type.cbs["integer"] = default_type_visitor - visit_type.cbs["thread"] = default_type_visitor - visit_type.cbs["emptytable"] = default_type_visitor - visit_type.cbs["literal_table_item"] = default_type_visitor - visit_type.cbs["unresolved_emptytable_value"] = default_type_visitor - visit_type.cbs["tuple"] = default_type_visitor - visit_type.cbs["poly"] = default_type_visitor - visit_type.cbs["any"] = default_type_visitor - visit_type.cbs["unknown"] = default_type_visitor - visit_type.cbs["invalid"] = default_type_visitor - visit_type.cbs["unresolved"] = default_type_visitor - visit_type.cbs["none"] = default_type_visitor + opts = opts or {} + + if not env then + local err + env, err = tl.new_env({ defaults = opts }) + if err then + return nil, err + end + end + + local self = { + filename = filename, + env = env, + st = { + { + vars = env.globals, + pending_global_types = {}, + }, + }, + errs = Errors.new(filename), + all_needs_compat = {}, + dependencies = {}, + subtype_relations = TypeChecker.subtype_relations, + eqtype_relations = TypeChecker.eqtype_relations, + type_priorities = TypeChecker.type_priorities, + } - assert(ast.kind == "statements") - recurse_node(ast, visit_node, visit_type) + setmetatable(self, { __index = TypeChecker }) - close_types(st[1]) - check_for_unused_vars(st[1], true) + self.feat_lax = set_feat(opts.feat_lax or env.defaults.feat_lax, false) + self.feat_arity = set_feat(opts.feat_arity or env.defaults.feat_arity, true) + self.gen_compat = opts.gen_compat or env.defaults.gen_compat or DEFAULT_GEN_COMPAT + self.gen_target = opts.gen_target or env.defaults.gen_target or DEFAULT_GEN_TARGET - clear_redundant_errors(errors) + if self.gen_target == "5.4" and self.gen_compat ~= "off" then + return nil, "gen-compat must be explicitly 'off' when gen-target is '5.4'" + end - add_compat_entries(ast, all_needs_compat, env.gen_compat) + if self.feat_lax then + self.type_priorities = shallow_copy_table(self.type_priorities) + self.type_priorities["unknown"] = 0 - local result = { - ast = ast, - env = env, - type = module_type or BOOLEAN, - filename = filename, - warnings = warnings, - type_errors = errors, - dependencies = dependencies, - } + self.subtype_relations = shallow_copy_table(self.subtype_relations) - env.loaded[filename] = result - table.insert(env.loaded_order, filename) + self.subtype_relations["unknown"] = {} + self.subtype_relations["unknown"]["*"] = compare_true - if tc then - env.reporter:store_result(tc, env.globals) - end + self.subtype_relations["*"] = shallow_copy_table(self.subtype_relations["*"]) + self.subtype_relations["*"]["unknown"] = compare_true - return result -end + self.subtype_relations["*"]["boolean"] = compare_true + + self.get_rets = function(rets) + if #rets.tuple == 0 then + return a_vararg(rets, { a_type(rets, "unknown", {}) }) + end + return rets + end + else + self.get_rets = function(rets) + return rets + end + end + if env.report_types then + env.reporter = env.reporter or tl.new_type_reporter() + self.collector = env.reporter:get_collector(filename) + end + local visit_node, visit_type = visit_node, visit_type + if opts.run_internal_compiler_checks then + visit_node, visit_type = patch_visitors( + visit_node, internal_compiler_check, + visit_type, internal_compiler_check) + end + if self.collector then + visit_node, visit_type = patch_visitors( + visit_node, store_type_after, + visit_type, store_type_after) + end + if TL_DEBUG then + visit_node, visit_type = patch_visitors( + visit_node, debug_type_after) -function tl.symbols_in_scope(tr, y, x) - local function find(symbols, at_y, at_x) - local function le(a, b) - return a[1] < b[1] or - (a[1] == b[1] and a[2] <= b[2]) end - return binary_search(symbols, { at_y, at_x }, le) or 0 - end - local ret = {} + assert(ast.kind == "statements") + recurse_node(self, ast, visit_node, visit_type) - local n = find(tr.symbols, y, x) + local global_scope = self.st[1] + close_types(global_scope) + self.errs:warn_unused_vars(global_scope, true) - local symbols = tr.symbols - while n >= 1 do - local s = symbols[n] - if s[3] == "@{" then - n = n - 1 - elseif s[3] == "@}" then - n = s[4] - else - ret[s[3]] = s[4] - n = n - 1 + clear_redundant_errors(self.errs.errors) + + add_compat_entries(ast, self.all_needs_compat, self.gen_compat) + + local result = { + ast = ast, + env = env, + type = self.module_type or a_type(ast, "boolean", {}), + filename = filename, + warnings = self.errs.warnings, + type_errors = self.errs.errors, + dependencies = self.dependencies, + } + + env.loaded[filename] = result + table.insert(env.loaded_order, filename or "") + + if self.collector then + env.reporter:store_result(self.collector, env.globals) end - end - return ret + return result + end end @@ -12379,9 +12499,24 @@ local function read_full_file(fd) return content, err end -tl.process = function(filename, env, fd) - assert((not fd or type(fd) ~= "string"), "fd must be a file") +local function feat_lax_heuristic(filename, input) + if filename then + local _, extension = filename:match("(.*)%.([a-z]+)$") + extension = extension and extension:lower() + + if extension == "tl" then + return "off" + elseif extension == "lua" then + return "on" + end + end + if input then + return (input:match("^#![^\n]*lua[^\n]*\n")) and "on" or "off" + end + return "off" +end +tl.process = function(filename, env, fd) if env and env.loaded and env.loaded[filename] then return env.loaded[filename] end @@ -12401,23 +12536,38 @@ tl.process = function(filename, env, fd) return nil, "could not read " .. filename .. ": " .. err end - local _, extension = filename:match("(.*)%.([a-z]+)$") - extension = extension and extension:lower() + return tl.process_string(input, env, filename) +end - local is_lua - if extension == "tl" then - is_lua = false - elseif extension == "lua" then - is_lua = true - else - is_lua = input:match("^#![^\n]*lua[^\n]*\n") +function tl.target_from_lua_version(str) + if str == "Lua 5.1" or + str == "Lua 5.2" then + return "5.1" + elseif str == "Lua 5.3" then + return "5.3" + elseif str == "Lua 5.4" then + return "5.4" end +end - return tl.process_string(input, is_lua, env, filename) +local function default_env_opts(runtime, filename, input) + local gen_target = runtime and tl.target_from_lua_version(_VERSION) or DEFAULT_GEN_TARGET + local gen_compat = (gen_target == "5.4") and "off" or DEFAULT_GEN_COMPAT + return { + defaults = { + feat_lax = feat_lax_heuristic(filename, input), + gen_target = gen_target, + gen_compat = gen_compat, + run_internal_compiler_checks = false, + }, + } end -function tl.process_string(input, is_lua, env, filename) - env = env or tl.init_env(is_lua) +function tl.process_string(input, env, filename) + assert(type(env) ~= "boolean", "tl.process_string signature has changed") + + env = env or tl.new_env(default_env_opts(false, filename, input)) + if env.loaded and env.loaded[filename] then return env.loaded[filename] end @@ -12429,7 +12579,7 @@ function tl.process_string(input, is_lua, env, filename) local result = { ok = false, filename = filename, - type = BOOLEAN, + type = a_type({ f = filename, y = 1, x = 1 }, "boolean", {}), type_errors = {}, syntax_errors = syntax_errors, env = env, @@ -12439,14 +12589,7 @@ function tl.process_string(input, is_lua, env, filename) return result end - local opts = { - filename = filename, - lax = is_lua, - gen_compat = env.gen_compat, - gen_target = env.gen_target, - env = env, - } - local result = tl.type_check(program, opts) + local result = tl.type_check(program, filename, env.defaults, env) result.syntax_errors = syntax_errors @@ -12454,15 +12597,15 @@ function tl.process_string(input, is_lua, env, filename) end tl.gen = function(input, env, pp) - env = env or assert(tl.init_env(), "Default environment initialization failed") - local result = tl.process_string(input, false, env) + env = env or assert(tl.new_env(default_env_opts(false, nil, input)), "Default environment initialization failed") + local result = tl.process_string(input, env) if (not result.ast) or #result.syntax_errors > 0 then return nil, result end local code - code, result.gen_error = tl.pretty_print_ast(result.ast, env.gen_target, pp) + code, result.gen_error = tl.pretty_print_ast(result.ast, env.defaults.gen_target, pp) return code, result end @@ -12478,28 +12621,25 @@ local function tl_package_loader(module_name) if #errs > 0 then error(found_filename .. ":" .. errs[1].y .. ":" .. errs[1].x .. ": " .. errs[1].msg) end - local lax = not not found_filename:match("lua$") local env = tl.package_loader_env if not env then - tl.package_loader_env = tl.init_env(lax) + tl.package_loader_env = assert(tl.new_env(), "Default environment initialization failed") env = tl.package_loader_env end - env.modules[module_name] = a_type("typedecl", { def = CIRCULAR_REQUIRE }) + local opts = default_env_opts(true, found_filename) - local result = tl.type_check(program, { - lax = lax, - filename = found_filename, - env = env, - run_internal_compiler_checks = false, - }) + local w = { f = found_filename, x = 1, y = 1 } + env.modules[module_name] = a_type(w, "typedecl", { def = a_type(w, "circular_require", {}) }) + + local result = tl.type_check(program, found_filename, opts.defaults, env) env.modules[module_name] = result.type - local code = assert(tl.pretty_print_ast(program, env.gen_target, true)) + local code = assert(tl.pretty_print_ast(program, opts.defaults.gen_target, true)) local chunk, err = load(code, "@" .. found_filename, "t") if chunk then return function(modname, loader_data) @@ -12525,21 +12665,10 @@ function tl.loader() end end -function tl.target_from_lua_version(str) - if str == "Lua 5.1" or - str == "Lua 5.2" then - return "5.1" - elseif str == "Lua 5.3" then - return "5.3" - elseif str == "Lua 5.4" then - return "5.4" - end -end - -local function env_for(lax, env_tbl) +local function env_for(opts, env_tbl) if not env_tbl then if not tl.package_loader_env then - tl.package_loader_env = tl.init_env(lax) + tl.package_loader_env = tl.new_env(opts) end return tl.package_loader_env end @@ -12548,7 +12677,7 @@ local function env_for(lax, env_tbl) tl.load_envs = setmetatable({}, { __mode = "k" }) end - tl.load_envs[env_tbl] = tl.load_envs[env_tbl] or tl.init_env(lax) + tl.load_envs[env_tbl] = tl.load_envs[env_tbl] or tl.new_env(opts) return tl.load_envs[env_tbl] end @@ -12558,17 +12687,14 @@ tl.load = function(input, chunkname, mode, ...) return nil, (chunkname or "") .. ":" .. errs[1].y .. ":" .. errs[1].x .. ": " .. errs[1].msg end - local lax = chunkname and not not chunkname:match("lua$") + local opts = default_env_opts(true, chunkname) + if not tl.package_loader_env then - tl.package_loader_env = tl.init_env(lax) + tl.package_loader_env = tl.new_env(opts) end - local result = tl.type_check(program, { - lax = lax, - filename = chunkname or ("string \"" .. input:sub(45) .. (#input > 45 and "..." or "") .. "\""), - env = env_for(lax, ...), - run_internal_compiler_checks = false, - }) + local filename = chunkname or ("string \"" .. input:sub(45) .. (#input > 45 and "..." or "") .. "\"") + local result = tl.type_check(program, filename, opts.defaults, env_for(opts, ...)) if mode and mode:match("c") then if #result.type_errors > 0 then @@ -12582,7 +12708,7 @@ tl.load = function(input, chunkname, mode, ...) mode = mode:gsub("c", "") end - local code, err = tl.pretty_print_ast(program, tl.target_from_lua_version(_VERSION), true) + local code, err = tl.pretty_print_ast(program, opts.defaults.gen_target, true) if not code then return nil, err end @@ -12590,4 +12716,29 @@ tl.load = function(input, chunkname, mode, ...) return load(code, chunkname, mode, ...) end + + + + +function tl.get_types(result) + return result.env.reporter:get_report(), result.env.reporter +end + +tl.init_env = function(lax, gen_compat, gen_target, predefined) + local opts = { + defaults = { + feat_lax = (lax and "on" or "off"), + gen_compat = ((type(gen_compat) == "string") and gen_compat) or + (gen_compat == false and "off") or + (gen_compat == true or gen_compat == nil) and "optional", + gen_target = gen_target or + ((_VERSION == "Lua 5.1" or _VERSION == "Lua 5.2") and "5.1") or + "5.3", + }, + predefined_modules = predefined, + } + + return tl.new_env(opts) +end + return tl diff --git a/tl.tl b/tl.tl index ecc5af5d4..fea76409c 100644 --- a/tl.tl +++ b/tl.tl @@ -476,9 +476,16 @@ end ]=====] local interface Where + f: string y: integer x: integer +end + +local record Errors filename: string + errors: {Error} + warnings: {Error} + unknown_dots: {string:boolean} end local record tl @@ -492,13 +499,13 @@ local record tl end type LoadFunction = function(...:any): any... - enum CompatMode + enum GenCompat "off" "optional" "required" end - enum TargetMode + enum GenTarget "5.1" "5.3" "5.4" @@ -516,25 +523,23 @@ local record tl end record TypeCheckOptions - lax: boolean - filename: string - gen_compat: CompatMode - gen_target: TargetMode - env: Env + feat_lax: Feat + feat_arity: Feat + gen_compat: GenCompat + gen_target: GenTarget run_internal_compiler_checks: boolean end record Env globals: {string:Variable} modules: {string:Type} + module_filenames: {string:string} loaded: {string:Result} loaded_order: {string} reporter: TypeReporter - gen_compat: CompatMode - gen_target: TargetMode keep_going: boolean report_types: boolean - feat_arity: boolean + defaults: TypeCheckOptions end record Result @@ -571,6 +576,8 @@ local record tl i: integer end + type errors = Errors + typecodes: {string:integer} record TypeInfo @@ -604,28 +611,28 @@ local record tl end record EnvOptions - lax_mode: boolean - gen_compat: CompatMode - gen_target: TargetMode - feat_arity: Feat + defaults: TypeCheckOptions predefined_modules: {string} end load: function(string, string, LoadMode, {any:any}): LoadFunction, string process: function(string, Env, ? FILE): (Result, string) - process_string: function(string, boolean, Env, ? string): Result + process_string: function(string, Env, ? string): Result gen: function(string, Env, PrettyPrintOptions): string, Result - type_check: function(Node, TypeCheckOptions): Result, string - new_env: function(EnvOptions): Env, string - init_env: function(? boolean, ? boolean | CompatMode, ? TargetMode, ? {string}): Env, string + type_check: function(Node, string, TypeCheckOptions, ? Env): Result, string + new_env: function(? EnvOptions): Env, string version: function(): string + -- Backwards compatibility + init_env: function(? boolean, ? boolean | GenCompat, ? GenTarget, ? {string}): Env, string + package_loader_env: Env load_envs: { {any:any} : Env } end local record TypeReporter typeid_to_num: {integer: integer} + typename_to_num: {TypeName: integer} next_num: integer tr: TypeReport @@ -687,17 +694,23 @@ tl.typecodes = { INVALID = 0x80000000, } -local type Result = tl.Result local type Env = tl.Env +local type EnvOptions = tl.EnvOptions local type Error = tl.Error -local type CompatMode = tl.CompatMode +local type Feat = tl.Feat +local type GenCompat = tl.GenCompat +local type GenTarget = tl.GenTarget +local type LoadFunction = tl.LoadFunction +local type LoadMode = tl.LoadMode local type PrettyPrintOptions = tl.PrettyPrintOptions +local type Result = tl.Result local type TypeCheckOptions = tl.TypeCheckOptions -local type LoadMode = tl.LoadMode -local type LoadFunction = tl.LoadFunction -local type TargetMode = tl.TargetMode local type TypeInfo = tl.TypeInfo local type TypeReport = tl.TypeReport +local type WarningKind = tl.WarningKind + +local DEFAULT_GEN_COMPAT : GenCompat = "optional" +local DEFAULT_GEN_TARGET : GenTarget = "5.3" local enum Narrow "narrow" @@ -1518,7 +1531,6 @@ local enum TypeName "any" "unknown" -- to be used in lax mode only "invalid" -- producing a new value of this type (not propagating) must always produce a type error - "unresolved" "none" "*" end @@ -1555,7 +1567,6 @@ local table_types : {TypeName:boolean} = { ["any"] = false, ["unknown"] = false, ["invalid"] = false, - ["unresolved"] = false, ["none"] = false, ["*"] = false, } @@ -1564,6 +1575,9 @@ local interface Type is Where where self.typename + y: integer + x: integer + typename: TypeName -- discriminator typeid: integer -- unique identifier inferred_at: Where -- for error messages @@ -1577,7 +1591,24 @@ local record StringType literal: string end -local type TypeType = TypeAliasType | TypeDeclType +local function is_numeric_type(t:Type): boolean + return t.typename == "number" or t.typename == "integer" +end + +local interface NumericType + is Type + where is_numeric_type(self) +end + +local record IntegerType + is NumericType + where self.typename == "integer" +end + +local record BooleanType + is Type + where self.typename == "boolean" +end local record TypeDeclType is Type @@ -1595,6 +1626,8 @@ local record TypeAliasType is_nested_alias: boolean end +local type TypeType = TypeDeclType | TypeAliasType + local record LiteralTableItemType is Type where self.typename == "literal_table_item" @@ -1605,13 +1638,12 @@ local record LiteralTableItemType vtype: Type end -local record UnresolvedType - is Type - where self.typename == "unresolved" - - labels: {string:{Node}} - nominals: {string:{NominalType}} - global_types: {string:boolean} +local record Scope + vars: {string:Variable} + labels: {string:Node} + pending_labels: {string:{Node}} + pending_nominals: {string:{NominalType}} + pending_global_types: {string:boolean} narrows: {string:boolean} end @@ -1678,6 +1710,11 @@ local record InvalidType where self.typename == "invalid" end +local record UnknownType + is Type + where self.typename == "unknown" +end + local record TupleType is Type where self.typename == "tuple" @@ -1852,7 +1889,8 @@ local interface Fact where self.fact fact: FactType - where: Where + w: Where + no_infer: boolean end local record TruthyFact @@ -2017,6 +2055,9 @@ local record Node -- goto label: string + -- label + used_label: boolean + casttype: Type -- variable @@ -2035,10 +2076,125 @@ local record Node debug_type: Type end -local function is_number_type(t:Type): boolean - return t.typename == "number" or t.typename == "integer" +local function a_type(w: Where, typename: TypeName, t: T): T + t.typeid = new_typeid() + t.f = w.f + t.x = w.x + t.y = w.y + t.typename = typename + return t +end + +local function edit_type(w: Where, t: Type, typename: TypeName): Type + t.typeid = new_typeid() + t.f = w.f + t.x = w.x + t.y = w.y + t.typename = typename + return t +end + +local macroexp a_typedecl(w: Where, def: Type): TypeDeclType + return a_type(w, "typedecl", { def = def } as TypeDeclType) +end + +local macroexp a_tuple(w: Where, t: {Type}): TupleType + return a_type(w, "tuple", { tuple = t } as TupleType) +end + +local macroexp a_union(w: Where, t: {Type}): UnionType + return a_type(w, "union", { types = t } as UnionType) +end + +local function a_function(w: Where, t: FunctionType): FunctionType + assert(t.min_arity) + return a_type(w, "function", t) +end + +local function a_vararg(w: Where, t: {Type}): TupleType + local typ = a_tuple(w, t) + typ.is_va = true + return typ +end + +local macroexp an_array(w: Where, t: Type): ArrayType + return a_type(w, "array", { elements = t } as ArrayType) +end + +local macroexp a_map(w: Where, k: Type, v: Type): MapType + return a_type(w, "map", { keys = k, values = v } as MapType) +end + +local function a_nominal(n: Node, names: {string}): NominalType + return a_type(n, "nominal", { names = names } as NominalType) end +local macroexp an_invalid(w: Where): InvalidType + return a_type(w, "invalid", {} as InvalidType) +end + +local macroexp an_unknown(w: Where): UnknownType + return a_type(w, "unknown", {} as UnknownType) +end + +local an_operator: function(Node, integer, string): Operator + +local function shallow_copy_new_type(t: T): T + local copy: {any:any} = {} + for k, v in pairs(t as {any:any}) do + copy[k] = v + end + copy.typeid = new_typeid() + return copy as T +end + +local function shallow_copy_table(t: T): T + local copy: {any:any} = {} + for k, v in pairs(t as {any:any}) do + copy[k] = v + end + return copy as T +end + +-- TODO move to Errors module +local function clear_redundant_errors(errors: {Error}) + local redundant: {integer} = {} + local lastx, lasty = 0, 0 + for i, err in ipairs(errors) do + err.i = i + end + table.sort(errors, function(a: Error, b: Error): boolean + local af = assert(a.filename) + local bf = assert(b.filename) + return af < bf + or (af == bf and (a.y < b.y + or (a.y == b.y and (a.x < b.x + or (a.x == b.x and (a.i < b.i)))))) + end) + for i, err in ipairs(errors) do + err.i = nil + if err.x == lastx and err.y == lasty then + table.insert(redundant, i) + end + lastx, lasty = err.x, err.y + end + for i = #redundant, 1, -1 do + table.remove(errors, redundant[i]) + end +end + +local simple_types: {TypeName:boolean} = { + ["nil"] = true, + ["any"] = true, + ["number"] = true, + ["string"] = true, + ["thread"] = true, + ["boolean"] = true, + ["integer"] = true, +} + +do ----------------------------------------------------------------------------- + local record ParseState tokens: {Token} errs: {Error} @@ -2111,163 +2267,52 @@ local function verify_end(ps: ParseState, i: integer, istart: integer, node: Nod return fail(ps, i, "syntax error, expected 'end' to close construct started at " .. ps.filename .. ":" .. ps.tokens[istart].y .. ":" .. ps.tokens[istart].x .. ":") end -local function new_node(tokens: {Token}, i: integer, kind?: NodeKind): Node - local t = tokens[i] - return { y = t.y, x = t.x, tk = t.tk, kind = kind or (t.kind as NodeKind) } -end - -local function a_type(typename: TypeName, t: T): T - t.typeid = new_typeid() - t.typename = typename - return t +local function new_node(ps: ParseState, i: integer, kind?: NodeKind): Node + local t = ps.tokens[i] + return { f = ps.filename, y = t.y, x = t.x, tk = t.tk, kind = kind or (t.kind as NodeKind) } end -local function edit_type(t: Type, typename: TypeName): Type +local function new_type(ps: ParseState, i: integer, typename: TypeName): Type + local token = ps.tokens[i] + local t: Type = {} t.typeid = new_typeid() + t.f = ps.filename + t.x = token.x + t.y = token.y t.typename = typename return t end -local function new_type(ps: ParseState, i: integer, typename: TypeName): Type - local token = ps.tokens[i] - return a_type(typename, { - filename = ps.filename, - y = token.y, - x = token.x, - --tk = token.tk - }) -end - local function new_typedecl(ps: ParseState, i: integer, def: Type): TypeDeclType local t = new_type(ps, i, "typedecl") as TypeDeclType t.def = def return t end -local macroexp a_typedecl(def: Type): TypeDeclType - return a_type("typedecl", { def = def } as TypeDeclType) -end - -local macroexp a_tuple(t: {Type}): TupleType - return a_type("tuple", { tuple = t } as TupleType) -end - -local macroexp a_union(t: {Type}): UnionType - return a_type("union", { types = t } as UnionType) -end - ---local macroexp a_poly(t: {FunctionType}): PolyType --- return a_type("poly", { types = t } as PolyType) ---end --- -local function a_function(t: FunctionType): FunctionType - assert(t.min_arity) - return a_type("function", t) -end - -local record Opt - where self.opttype - - opttype: Type -end - ---local function OPT(t: Type): Opt --- return { opttype = t } ---end --- -local record Args - is {Type|Opt} - - is_va: boolean -end - -local function va_args(args: Args): Args - args.is_va = true - return args -end - -local record FuncArgs - is HasTypeArgs - - args: Args - rets: Args - needs_compat: boolean -end - -local function a_fn(f: FuncArgs): FunctionType - local args_t = a_tuple {} - local tup = args_t.tuple - args_t.is_va = f.args.is_va - local min_arity = f.args.is_va and -1 or 0 - for _, a in ipairs(f.args) do - if a is Opt then - table.insert(tup, a.opttype) - else - table.insert(tup, a) - min_arity = min_arity + 1 - end - end - - local rets_t = a_tuple {} - tup = rets_t.tuple - rets_t.is_va = f.rets.is_va - for _, a in ipairs(f.rets) do - assert(a is Type) - table.insert(tup, a) - end - - return a_type("function", { - args = args_t, - rets = rets_t, - min_arity = min_arity, - needs_compat = f.needs_compat, - typeargs = f.typeargs, - } as FunctionType) -end - -local function a_vararg(t: {Type}): TupleType - local typ = a_tuple(t) - typ.is_va = true - return typ -end - -local macroexp an_array(t: Type): ArrayType - return a_type("array", { elements = t } as ArrayType) -end - -local macroexp a_map(k: Type, v: Type): MapType - return a_type("map", { keys = k, values = v } as MapType) +local function new_tuple(ps: ParseState, i: integer, types?: {Type}, is_va?: boolean): TupleType, {Type} + local t = new_type(ps, i, "tuple") as TupleType + t.is_va = is_va + t.tuple = types or {} + return t, t.tuple end -local NIL = a_type("nil", {}) -local ANY = a_type("any", {}) -local TABLE = a_map(ANY, ANY) -local NUMBER = a_type("number", {}) -local STRING = a_type("string", {}) -local THREAD = a_type("thread", {}) -local BOOLEAN = a_type("boolean", {}) -local INTEGER = a_type("integer", {}) - -local function shallow_copy_new_type(t: T): T - local copy: {any:any} = {} - for k, v in pairs(t as {any:any}) do - copy[k] = v - end - copy.typeid = new_typeid() - return copy as T +local function new_typealias(ps: ParseState, i: integer, alias_to: NominalType): TypeAliasType + local t = new_type(ps, i, "typealias") as TypeAliasType + t.alias_to = alias_to + return t end -local function shallow_copy_table(t: T): T - local copy: {any:any} = {} - for k, v in pairs(t as {any:any}) do - copy[k] = v +local function new_nominal(ps: ParseState, i: integer, name?: string): NominalType + local t = new_type(ps, i, "nominal") as NominalType + if name then + t.names = { name } end - return copy as T + return t end local function verify_kind(ps: ParseState, i: integer, kind: TokenKind, node_kind?: NodeKind): integer, Node if ps.tokens[i].kind == kind then - return i + 1, new_node(ps.tokens, i, node_kind) + return i + 1, new_node(ps, i, node_kind) end return fail(ps, i, "syntax error, expected " .. kind) end @@ -2305,23 +2350,23 @@ local function parse_table_value(ps: ParseState, i: integer): integer, Node, int fail(ps, i, next_word == "record" and "syntax error: this syntax is no longer valid; declare nested record inside a record" or "syntax error: cannot declare interface inside a table; use a statement") - return skip_i, new_node(ps.tokens, i, "error_node") + return skip_i, new_node(ps, i, "error_node") end elseif next_word == "enum" and ps.tokens[i + 1].kind == "string" then i = failskip(ps, i, "syntax error: this syntax is no longer valid; declare nested enum inside a record", skip_type_body) - return i, new_node(ps.tokens, i - 1, "error_node") + return i, new_node(ps, i - 1, "error_node") end local e: Node i, e = parse_expression(ps, i) if not e then - e = new_node(ps.tokens, i - 1, "error_node") + e = new_node(ps, i - 1, "error_node") end return i, e end local function parse_table_item(ps: ParseState, i: integer, n?: integer): integer, Node, integer - local node = new_node(ps.tokens, i, "literal_table_item") + local node = new_node(ps, i, "literal_table_item") if ps.tokens[i].kind == "$EOF$" then return fail(ps, i, "unexpected eof") end @@ -2372,7 +2417,7 @@ local function parse_table_item(ps: ParseState, i: integer, n?: integer): intege end end - node.key = new_node(ps.tokens, i, "integer") + node.key = new_node(ps, i, "integer") node.key_parsed = "implicit" node.key.constnum = n node.key.tk = tostring(n) @@ -2448,7 +2493,7 @@ local function parse_bracket_list(ps: ParseState, i: integer, list: {T}, open end local function parse_table_literal(ps: ParseState, i: integer): integer, Node - local node = new_node(ps.tokens, i, "literal_table") + local node = new_node(ps, i, "literal_table") return parse_bracket_list(ps, i, node, "{", "}", "term", parse_table_item) end @@ -2506,16 +2551,21 @@ local function parse_typearg(ps: ParseState, i: integer): integer, TypeArgType, i = i + 1 i, constraint = parse_interface_name(ps, i) -- FIXME what about generic interfaces end - return i, a_type("typearg", { - y = ps.tokens[i - 2].y, - x = ps.tokens[i - 2].x, - typearg = name, - constraint = constraint, - } as TypeArgType) + local t = new_type(ps, i, "typearg") as TypeArgType + t.typearg = name + t.constraint = constraint + return i, t end local function parse_return_types(ps: ParseState, i: integer): integer, TupleType - return parse_type_list(ps, i, "rets") + local iprev = i - 1 + local t: TupleType + i, t = parse_type_list(ps, i, "rets") + if #t.tuple == 0 then + t.x = ps.tokens[iprev].x + t.y = ps.tokens[iprev].y + end + return i, t end local function parse_function_type(ps: ParseState, i: integer): integer, FunctionType @@ -2528,31 +2578,25 @@ local function parse_function_type(ps: ParseState, i: integer): integer, Functio i, typ.args, typ.is_method, typ.min_arity = parse_argument_type_list(ps, i) i, typ.rets = parse_return_types(ps, i) else - typ.args = a_vararg { ANY } - typ.rets = a_vararg { ANY } + typ.args = new_tuple(ps, i, { new_type(ps, i, "any") }, true) + typ.rets = new_tuple(ps, i, { new_type(ps, i, "any") }, true) end return i, typ end -local simple_types: {string:Type} = { - ["nil"] = NIL, - ["any"] = ANY, - ["table"] = TABLE, - ["number"] = NUMBER, - ["string"] = STRING, - ["thread"] = THREAD, - ["boolean"] = BOOLEAN, - ["integer"] = INTEGER, -} - local function parse_simple_type_or_nominal(ps: ParseState, i: integer): integer, Type local tk = ps.tokens[i].tk - local st = simple_types[tk] + local st = simple_types[tk as TypeName] if st then - return i + 1, st + return i + 1, new_type(ps, i, tk as TypeName) + elseif tk == "table" then + local typ = new_type(ps, i, "map") as MapType + typ.keys = new_type(ps, i, "any") + typ.values = new_type(ps, i, "any") + return i + 1, typ end - local typ = new_type(ps, i, "nominal") as NominalType - typ.names = { tk } + + local typ = new_nominal(ps, i, tk) i = i + 1 while ps.tokens[i].tk == "." do i = i + 1 @@ -2619,12 +2663,7 @@ local function parse_base_type(ps: ParseState, i: integer): integer, Type, integ elseif tk == "function" then return parse_function_type(ps, i) elseif tk == "nil" then - return i + 1, simple_types["nil"] - elseif tk == "table" then - local typ = new_type(ps, i, "map") as MapType - typ.keys = ANY - typ.values = ANY - return i + 1, typ + return i + 1, new_type(ps, i, "nil") end return fail(ps, i, "expected a type") end @@ -2660,12 +2699,6 @@ parse_type = function(ps: ParseState, i: integer): integer, Type, integer return i, bt end -local function new_tuple(ps: ParseState, i: integer): TupleType, {Type} - local t = new_type(ps, i, "tuple") as TupleType - t.tuple = {} - return t, t.tuple -end - parse_type_list = function(ps: ParseState, i: integer, mode: ParseTypeListMode): integer, TupleType local t, list = new_tuple(ps, i) @@ -2721,7 +2754,7 @@ local function parse_function_args_rets_body(ps: ParseState, i: integer, node: N end local function parse_function_value(ps: ParseState, i: integer): integer, Node - local node = new_node(ps.tokens, i, "function") + local node = new_node(ps, i, "function") i = verify_tk(ps, i, "function") return parse_function_args_rets_body(ps, i, node) end @@ -2742,7 +2775,7 @@ local function parse_literal(ps: ParseState, i: integer): integer, Node if kind == "identifier" then return verify_kind(ps, i, "identifier", "variable") elseif kind == "string" then - local node = new_node(ps.tokens, i, "string") + local node = new_node(ps, i, "string") node.conststr, node.is_longstring = unquote(tk) return i + 1, node elseif kind == "number" or kind == "integer" then @@ -2790,8 +2823,6 @@ local function node_is_require_call(n: Node): string end end -local an_operator: function(Node, integer, string): Operator - do local precedences: {integer:{string:integer}} = { [1] = { @@ -2866,8 +2897,8 @@ do -- small hack: for the sake of `tl types`, parse an invalid binary exp -- as a paren to produce a unary indirection on e1 and save its location. - local function failstore(tkop: Token, e1: Node): Node - return { y = tkop.y, x = tkop.x, kind = "paren", e1 = e1, failstore = true } + local function failstore(ps: ParseState, tkop: Token, e1: Node): Node + return { f = ps.filename, y = tkop.y, x = tkop.x, kind = "paren", e1 = e1, failstore = true } end local function P(ps: ParseState, i: integer): integer, Node @@ -2885,7 +2916,7 @@ do fail(ps, prev_i, "expected an expression") return i end - e1 = { y = t1.y, x = t1.x, kind = "op", op = op, e1 = e1 } + e1 = { f = ps.filename, y = t1.y, x = t1.x, kind = "op", op = op, e1 = e1 } elseif ps.tokens[i].tk == "(" then i = i + 1 local prev_i = i @@ -2894,7 +2925,7 @@ do fail(ps, prev_i, "expected an expression") return i end - e1 = { y = t1.y, x = t1.x, kind = "paren", e1 = e1 } + e1 = { f = ps.filename, y = t1.y, x = t1.x, kind = "paren", e1 = e1 } else i, e1 = parse_literal(ps, i) end @@ -2919,12 +2950,12 @@ do local skipped = skip(ps, i, parse_type as SkipFunction) if skipped > i + 1 then fail(ps, i, "syntax error, cannot declare a type here (missing 'local' or 'global'?)") - return skipped, failstore(tkop, e1) + return skipped, failstore(ps, tkop, e1) end end i, key = verify_kind(ps, i, "identifier") if not key then - return i, failstore(tkop, e1) + return i, failstore(ps, tkop, e1) end if op.op == ":" then @@ -2934,16 +2965,16 @@ do else fail(ps, i, "expected a function call for a method") end - return i, failstore(tkop, e1) + return i, failstore(ps, tkop, e1) end if not after_valid_prefixexp(ps, e1, prev_i) then fail(ps, prev_i, "cannot call a method on this expression") - return i, failstore(tkop, e1) + return i, failstore(ps, tkop, e1) end end - e1 = { y = tkop.y, x = tkop.x, kind = "op", op = op, e1 = e1, e2 = key } + e1 = { f = ps.filename, y = tkop.y, x = tkop.x, kind = "op", op = op, e1 = e1, e2 = key } elseif tkop.tk == "(" then local prev_tk = ps.tokens[i - 1] if tkop.y > prev_tk.y then @@ -2955,15 +2986,15 @@ do local prev_i = i - local args = new_node(ps.tokens, i, "expression_list") + local args = new_node(ps, i, "expression_list") i, args = parse_bracket_list(ps, i, args, "(", ")", "sep", parse_expression) if not after_valid_prefixexp(ps, e1, prev_i) then fail(ps, prev_i, "cannot call this expression") - return i, failstore(tkop, e1) + return i, failstore(ps, tkop, e1) end - e1 = { y = args.y, x = args.x, kind = "op", op = op, e1 = e1, e2 = args } + e1 = { f = ps.filename, y = args.y, x = args.x, kind = "op", op = op, e1 = e1, e2 = args } table.insert(ps.required_modules, node_is_require_call(e1)) elseif tkop.tk == "[" then @@ -2977,19 +3008,19 @@ do if not after_valid_prefixexp(ps, e1, prev_i) then fail(ps, prev_i, "cannot index this expression") - return i, failstore(tkop, e1) + return i, failstore(ps, tkop, e1) end - e1 = { y = tkop.y, x = tkop.x, kind = "op", op = op, e1 = e1, e2 = idx } + e1 = { f = ps.filename, y = tkop.y, x = tkop.x, kind = "op", op = op, e1 = e1, e2 = idx } elseif tkop.kind == "string" or tkop.kind == "{" then local op: Operator = new_operator(tkop, 2, "@funcall") local prev_i = i - local args = new_node(ps.tokens, i, "expression_list") + local args = new_node(ps, i, "expression_list") local argument: Node if tkop.kind == "string" then - argument = new_node(ps.tokens, i) + argument = new_node(ps, i) argument.conststr = unquote(tkop.tk) i = i + 1 else @@ -3002,27 +3033,27 @@ do else fail(ps, prev_i, "cannot use a table here; if you're trying to call the previous expression, wrap it in parentheses") end - return i, failstore(tkop, e1) + return i, failstore(ps, tkop, e1) end table.insert(args, argument) - e1 = { y = args.y, x = args.x, kind = "op", op = op, e1 = e1, e2 = args } + e1 = { f = ps.filename, y = args.y, x = args.x, kind = "op", op = op, e1 = e1, e2 = args } table.insert(ps.required_modules, node_is_require_call(e1)) elseif tkop.tk == "as" or tkop.tk == "is" then local op: Operator = new_operator(tkop, 2, tkop.tk) i = i + 1 - local cast = new_node(ps.tokens, i, "cast") + local cast = new_node(ps, i, "cast") if ps.tokens[i].tk == "(" then i, cast.casttype = parse_type_list(ps, i, "casttype") else i, cast.casttype = parse_type(ps, i) end if not cast.casttype then - return i, failstore(tkop, e1) + return i, failstore(ps, tkop, e1) end - e1 = { y = tkop.y, x = tkop.x, kind = "op", op = op, e1 = e1, e2 = cast, conststr = e1.conststr } + e1 = { f = ps.filename, y = tkop.y, x = tkop.x, kind = "op", op = op, e1 = e1, e2 = cast, conststr = e1.conststr } else break end @@ -3053,7 +3084,7 @@ do end lookahead = ps.tokens[i].tk end - lhs = { y = t1.y, x = t1.x, kind = "op", op = op, e1 = lhs, e2 = rhs, } + lhs = { f = ps.filename, y = t1.y, x = t1.x, kind = "op", op = op, e1 = lhs, e2 = rhs, } end return i, lhs end @@ -3080,7 +3111,7 @@ parse_expression_and_tk = function(ps: ParseState, i: integer, tk: string): inte local e: Node i, e = parse_expression(ps, i) if not e then - e = new_node(ps.tokens, i - 1, "error_node") + e = new_node(ps, i - 1, "error_node") end if ps.tokens[i].tk == tk then i = i + 1 @@ -3158,7 +3189,7 @@ local function parse_argument(ps: ParseState, i: integer): integer, Node, intege end parse_argument_list = function(ps: ParseState, i: integer): integer, Node, integer - local node = new_node(ps.tokens, i, "argument_list") + local node = new_node(ps, i, "argument_list") i, node = parse_bracket_list(ps, i, node, "(", ")", "sep", parse_argument) local opts = false local min_arity = 0 @@ -3253,16 +3284,16 @@ end local function parse_identifier(ps: ParseState, i: integer): integer, Node, integer if ps.tokens[i].kind == "identifier" then - return i + 1, new_node(ps.tokens, i, "identifier") + return i + 1, new_node(ps, i, "identifier") end i = fail(ps, i, "syntax error, expected identifier") - return i, new_node(ps.tokens, i, "error_node") + return i, new_node(ps, i, "error_node") end local function parse_local_function(ps: ParseState, i: integer): integer, Node i = verify_tk(ps, i, "local") i = verify_tk(ps, i, "function") - local node = new_node(ps.tokens, i - 2, "local_function") + local node = new_node(ps, i - 2, "local_function") i, node.name = parse_identifier(ps, i) return parse_function_args_rets_body(ps, i, node) end @@ -3275,7 +3306,7 @@ end local function parse_function(ps: ParseState, i: integer, fk: FunctionKind): integer, Node local orig_i = i i = verify_tk(ps, i, "function") - local fn = new_node(ps.tokens, i - 1, "global_function") + local fn = new_node(ps, i - 1, "global_function") local names: {Node} = {} i, names[1] = parse_identifier(ps, i) while ps.tokens[i].tk == "." do @@ -3295,7 +3326,7 @@ local function parse_function(ps: ParseState, i: integer, fk: FunctionKind): int for i2 = 2, #names - 1 do local dot = an_operator(names[i2], 2, ".") names[i2].kind = "identifier" - owner = { y = names[i2].y, x = names[i2].x, kind = "op", op = dot, e1 = owner, e2 = names[i2] } + owner = { f = ps.filename, y = names[i2].y, x = names[i2].x, kind = "op", op = dot, e1 = owner, e2 = names[i2] } end fn.fn_owner = owner end @@ -3303,8 +3334,8 @@ local function parse_function(ps: ParseState, i: integer, fk: FunctionKind): int local selfx, selfy = ps.tokens[i].x, ps.tokens[i].y i = parse_function_args_rets_body(ps, i, fn) - if fn.is_method then - table.insert(fn.args, 1, { x = selfx, y = selfy, tk = "self", kind = "identifier", is_self = true }) + if fn.is_method and fn.args then + table.insert(fn.args, 1, { f = ps.filename, x = selfx, y = selfy, tk = "self", kind = "identifier", is_self = true }) fn.min_arity = fn.min_arity + 1 end @@ -3322,7 +3353,7 @@ local function parse_function(ps: ParseState, i: integer, fk: FunctionKind): int end local function parse_if_block(ps: ParseState, i: integer, n: integer, node: Node, is_else?: boolean): integer, Node - local block = new_node(ps.tokens, i, "if_block") + local block = new_node(ps, i, "if_block") i = i + 1 block.if_parent = node block.if_block_n = n @@ -3344,7 +3375,7 @@ end local function parse_if(ps: ParseState, i: integer): integer, Node local istart = i - local node = new_node(ps.tokens, i, "if") + local node = new_node(ps, i, "if") node.if_blocks = {} i, node = parse_if_block(ps, i, 1, node) if not node then @@ -3370,7 +3401,7 @@ end local function parse_while(ps: ParseState, i: integer): integer, Node local istart = i - local node = new_node(ps.tokens, i, "while") + local node = new_node(ps, i, "while") i = verify_tk(ps, i, "while") i, node.exp = parse_expression_and_tk(ps, i, "do") i, node.body = parse_statements(ps, i) @@ -3380,7 +3411,7 @@ end local function parse_fornum(ps: ParseState, i: integer): integer, Node local istart = i - local node = new_node(ps.tokens, i, "fornum") + local node = new_node(ps, i, "fornum") i = i + 1 i, node.var = parse_identifier(ps, i) i = verify_tk(ps, i, "=") @@ -3399,12 +3430,12 @@ end local function parse_forin(ps: ParseState, i: integer): integer, Node local istart = i - local node = new_node(ps.tokens, i, "forin") + local node = new_node(ps, i, "forin") i = i + 1 - node.vars = new_node(ps.tokens, i, "variable_list") + node.vars = new_node(ps, i, "variable_list") i, node.vars = parse_list(ps, i, node.vars, { ["in"] = true }, "sep", parse_identifier) i = verify_tk(ps, i, "in") - node.exps = new_node(ps.tokens, i, "expression_list") + node.exps = new_node(ps, i, "expression_list") i = parse_list(ps, i, node.exps, { ["do"] = true }, "sep", parse_expression) if #node.exps < 1 then return fail(ps, i, "missing iterator expression in generic for") @@ -3426,7 +3457,7 @@ local function parse_for(ps: ParseState, i: integer): integer, Node end local function parse_repeat(ps: ParseState, i: integer): integer, Node - local node = new_node(ps.tokens, i, "repeat") + local node = new_node(ps, i, "repeat") i = verify_tk(ps, i, "repeat") i, node.body = parse_statements(ps, i) node.body.is_repeat = true @@ -3438,7 +3469,7 @@ end local function parse_do(ps: ParseState, i: integer): integer, Node local istart = i - local node = new_node(ps.tokens, i, "do") + local node = new_node(ps, i, "do") i = verify_tk(ps, i, "do") i, node.body = parse_statements(ps, i) i = verify_end(ps, i, istart, node) @@ -3446,13 +3477,13 @@ local function parse_do(ps: ParseState, i: integer): integer, Node end local function parse_break(ps: ParseState, i: integer): integer, Node - local node = new_node(ps.tokens, i, "break") + local node = new_node(ps, i, "break") i = verify_tk(ps, i, "break") return i, node end local function parse_goto(ps: ParseState, i: integer): integer, Node - local node = new_node(ps.tokens, i, "goto") + local node = new_node(ps, i, "goto") i = verify_tk(ps, i, "goto") node.label = ps.tokens[i].tk i = verify_kind(ps, i, "identifier") @@ -3460,7 +3491,7 @@ local function parse_goto(ps: ParseState, i: integer): integer, Node end local function parse_label(ps: ParseState, i: integer): integer, Node - local node = new_node(ps.tokens, i, "label") + local node = new_node(ps, i, "label") i = verify_tk(ps, i, "::") node.label = ps.tokens[i].tk i = verify_kind(ps, i, "identifier") @@ -3485,9 +3516,9 @@ for k, v in pairs(stop_statement_list) do end local function parse_return(ps: ParseState, i: integer): integer, Node - local node = new_node(ps.tokens, i, "return") + local node = new_node(ps, i, "return") i = verify_tk(ps, i, "return") - node.exps = new_node(ps.tokens, i, "expression_list") + node.exps = new_node(ps, i, "expression_list") i = parse_list(ps, i, node.exps, stop_return_list, "sep", parse_expression) if ps.tokens[i].kind == ";" then i = i + 1 @@ -3525,12 +3556,13 @@ local function parse_nested_type(ps: ParseState, i: integer, def: RecordLikeType return fail(ps, i, "expected a variable name") end - local nt: Node = new_node(ps.tokens, i - 2, "newtype") + local nt: Node = new_node(ps, i - 2, "newtype") local ndef = new_type(ps, i, typename) + local itype = i local iok = parse_body(ps, i, ndef, nt) if iok then i = iok - nt.newtype = new_typedecl(ps, i, ndef) + nt.newtype = new_typedecl(ps, itype, ndef) end store_field_in_record(ps, iv, v.tk, nt.newtype, def.fields, def.field_order) @@ -3587,7 +3619,7 @@ local function parse_macroexp(ps: ParseState, istart: integer, iargs: integer): -- if ps.tokens[i].tk == "<" then -- i, node.typeargs = parse_anglebracket_list(ps, i, parse_typearg) -- end - local node = new_node(ps.tokens, istart, "macroexp") + local node = new_node(ps, istart, "macroexp") local i: integer i, node.args, node.min_arity = parse_argument_list(ps, iargs) i, node.rets = parse_return_types(ps, i) @@ -3599,18 +3631,14 @@ local function parse_macroexp(ps: ParseState, istart: integer, iargs: integer): end local function parse_where_clause(ps: ParseState, i: integer): integer, Node - local node = new_node(ps.tokens, i, "macroexp") - - local selftype = new_type(ps, i, "nominal") as NominalType - selftype.names = { "@self" } - - node.args = new_node(ps.tokens, i, "argument_list") - node.args[1] = new_node(ps.tokens, i, "argument") + local node = new_node(ps, i, "macroexp") + node.args = new_node(ps, i, "argument_list") + node.args[1] = new_node(ps, i, "argument") node.args[1].tk = "self" - node.args[1].argtype = selftype + node.args[1].argtype = new_nominal(ps, i, "@self") node.min_arity = 1 node.rets = new_tuple(ps, i) - node.rets.tuple[1] = BOOLEAN + node.rets.tuple[1] = new_type(ps, i, "boolean") i, node.exp = parse_expression(ps, i) end_at(node, ps.tokens[i - 1]) return i, node @@ -3692,15 +3720,10 @@ parse_record_body = function(ps: ParseState, i: integer, def: RecordLikeType, no local typ = new_type(ps, wstart, "function") as FunctionType typ.is_method = true typ.min_arity = 1 - typ.args = a_tuple { - a_type("nominal", { - y = typ.y, - x = typ.x, - filename = ps.filename, - names = { "@self" } - } as NominalType) - } - typ.rets = a_tuple { BOOLEAN } + typ.args = new_tuple(ps, wstart, { + a_nominal(where_macroexp, { "@self" }) + }) + typ.rets = new_tuple(ps, wstart, { new_type(ps, wstart, "boolean") }) typ.macroexp = where_macroexp def.meta_fields = {} @@ -3821,7 +3844,7 @@ parse_type_body_fns = { } parse_newtype = function(ps: ParseState, i: integer): integer, Node - local node: Node = new_node(ps.tokens, i, "newtype") + local node: Node = new_node(ps, i, "newtype") local def: Type local tn = ps.tokens[i].tk as TypeName local itype = i @@ -3842,9 +3865,7 @@ parse_newtype = function(ps: ParseState, i: integer): integer, Node end if def is NominalType then - local typealias = new_type(ps, itype, "typealias") as TypeAliasType - typealias.alias_to = def - node.newtype = typealias + node.newtype = new_typealias(ps, itype, def) else node.newtype = new_typedecl(ps, itype, def) end @@ -3854,7 +3875,7 @@ parse_newtype = function(ps: ParseState, i: integer): integer, Node end local function parse_assignment_expression_list(ps: ParseState, i: integer, asgn: Node): integer, Node - asgn.exps = new_node(ps.tokens, i, "expression_list") + asgn.exps = new_node(ps, i, "expression_list") repeat i = i + 1 local val: Node @@ -3904,8 +3925,8 @@ do return fail(ps, i, "syntax error") end - local asgn: Node = new_node(ps.tokens, istart, "assignment") - asgn.vars = new_node(ps.tokens, istart, "variable_list") + local asgn: Node = new_node(ps, istart, "assignment") + asgn.vars = new_node(ps, istart, "variable_list") asgn.vars[1] = exp if ps.tokens[i].tk == "," then i = i + 1 @@ -3926,9 +3947,9 @@ do end local function parse_variable_declarations(ps: ParseState, i: integer, node_name: NodeKind): integer, Node - local asgn: Node = new_node(ps.tokens, i, node_name) + local asgn: Node = new_node(ps, i, node_name) - asgn.vars = new_node(ps.tokens, i, "variable_list") + asgn.vars = new_node(ps, i, "variable_list") i = parse_trying_list(ps, i, asgn.vars, parse_variable_name) if #asgn.vars == 0 then return fail(ps, i, "expected a local variable definition") @@ -3956,7 +3977,7 @@ end local function parse_type_declaration(ps: ParseState, i: integer, node_name: NodeKind): integer, Node i = i + 2 -- skip `local` or `global`, and `type` - local asgn: Node = new_node(ps.tokens, i, node_name) + local asgn: Node = new_node(ps, i, node_name) i, asgn.var = parse_variable_name(ps, i) if not asgn.var then return fail(ps, i, "expected a type name") @@ -3996,8 +4017,8 @@ local function parse_type_declaration(ps: ParseState, i: integer, node_name: Nod end local function parse_type_constructor(ps: ParseState, i: integer, node_name: NodeKind, type_name: TypeName, parse_body: ParseBody): integer, Node - local asgn: Node = new_node(ps.tokens, i, node_name) - local nt: Node = new_node(ps.tokens, i, "newtype") + local asgn: Node = new_node(ps, i, node_name) + local nt: Node = new_node(ps, i, "newtype") asgn.value = nt local itype = i local def = new_type(ps, i, type_name) @@ -4026,7 +4047,7 @@ end local function parse_local_macroexp(ps: ParseState, i: integer): integer, Node local istart = i i = i + 2 -- skip `local` - local node = new_node(ps.tokens, i, "local_macroexp") + local node = new_node(ps, i, "local_macroexp") i, node.name = parse_identifier(ps, i) i, node.macrodef = parse_macroexp(ps, istart, i) end_at(node, ps.tokens[i - 1]) @@ -4096,7 +4117,7 @@ local needs_local_or_global: {string : function(ParseState, integer):(integer, N } parse_statements = function(ps: ParseState, i: integer, toplevel?: boolean): integer, Node - local node = new_node(ps.tokens, i, "statements") + local node = new_node(ps, i, "statements") local item: Node while true do while ps.tokens[i].kind == ";" do @@ -4141,32 +4162,6 @@ parse_statements = function(ps: ParseState, i: integer, toplevel?: boolean): int return i, node end -local function clear_redundant_errors(errors: {Error}) - local redundant: {integer} = {} - local lastx, lasty = 0, 0 - for i, err in ipairs(errors) do - err.i = i - end - table.sort(errors, function(a: Error, b: Error): boolean - local af = a.filename or "" - local bf = b.filename or "" - return af < bf - or (af == bf and (a.y < b.y - or (a.y == b.y and (a.x < b.x - or (a.x == b.x and (a.i < b.i)))))) - end) - for i, err in ipairs(errors) do - err.i = nil - if err.x == lastx and err.y == lasty then - table.insert(redundant, i) - end - lastx, lasty = err.x, err.y - end - for i = #redundant, 1, -1 do - table.remove(errors, redundant[i]) - end -end - function tl.parse_program(tokens: {Token}, errs: {Error}, filename: string): Node, {string} errs = errs or {} local ps: ParseState = { @@ -4196,17 +4191,19 @@ function tl.parse(input: string, filename: string): Node, {Error}, {string} return node, errs, required_modules end +end ---------------------------------------------------------------------------- + -------------------------------------------------------------------------------- -- AST traversal -------------------------------------------------------------------------------- -local record VisitorCallbacks - before: function(N) - before_exp: function({N}, {T}) - before_arguments: function({N}, {T}) - before_statements: function({N}, {T}) - before_e2: function({N}, {T}) - after: function(N, {T}): T +local record VisitorCallbacks + before: function(S, N) + before_exp: function(S, {N}, {T}) + before_arguments: function(S, {N}, {T}) + before_statements: function(S, {N}, {T}) + before_e2: function(S, {N}, {T}) + after: function(S, N, {T}): T end local enum VisitorExtraCallback @@ -4216,9 +4213,11 @@ local enum VisitorExtraCallback "before_e2" end -local record Visitor - cbs: {K:VisitorCallbacks} - after: function(N, {T}, T): T +local type VisitorAfter = function(S, N, {T}, T): T + +local record Visitor + cbs: {K:VisitorCallbacks} + after: VisitorAfter allow_missing_cbs: boolean end @@ -4307,7 +4306,7 @@ local function tl_debug_indent_pop(mark: string, single: string, y: integer, x: end end -local function recurse_type(ast: Type, visit: Visitor): T +local function recurse_type(s: S, ast: Type, visit: Visitor): T local kind = ast.typename if TL_DEBUG then @@ -4319,7 +4318,7 @@ local function recurse_type(ast: Type, visit: Visitor): T if cbkind then local cbkind_before = cbkind.before if cbkind_before then - cbkind_before(ast) + cbkind_before(s, ast) end end @@ -4327,90 +4326,90 @@ local function recurse_type(ast: Type, visit: Visitor): T if ast is TupleType then for i, child in ipairs(ast.tuple) do - xs[i] = recurse_type(child, visit) + xs[i] = recurse_type(s, child, visit) end elseif ast is AggregateType then for _, child in ipairs(ast.types) do - table.insert(xs, recurse_type(child, visit)) + table.insert(xs, recurse_type(s, child, visit)) end elseif ast is MapType then - table.insert(xs, recurse_type(ast.keys, visit)) - table.insert(xs, recurse_type(ast.values, visit)) + table.insert(xs, recurse_type(s, ast.keys, visit)) + table.insert(xs, recurse_type(s, ast.values, visit)) elseif ast is RecordLikeType then if ast.typeargs then for _, child in ipairs(ast.typeargs) do - table.insert(xs, recurse_type(child, visit)) + table.insert(xs, recurse_type(s, child, visit)) end end if ast.interface_list then for _, child in ipairs(ast.interface_list) do - table.insert(xs, recurse_type(child, visit)) + table.insert(xs, recurse_type(s, child, visit)) end end if ast.elements then - table.insert(xs, recurse_type(ast.elements, visit)) + table.insert(xs, recurse_type(s, ast.elements, visit)) end if ast.fields then for _, child in fields_of(ast) do - table.insert(xs, recurse_type(child, visit)) + table.insert(xs, recurse_type(s, child, visit)) end end if ast.meta_fields then for _, child in fields_of(ast, "meta") do - table.insert(xs, recurse_type(child, visit)) + table.insert(xs, recurse_type(s, child, visit)) end end elseif ast is FunctionType then if ast.typeargs then for _, child in ipairs(ast.typeargs) do - table.insert(xs, recurse_type(child, visit)) + table.insert(xs, recurse_type(s, child, visit)) end end if ast.args then for _, child in ipairs(ast.args.tuple) do - table.insert(xs, recurse_type(child, visit)) + table.insert(xs, recurse_type(s, child, visit)) end end if ast.rets then for _, child in ipairs(ast.rets.tuple) do - table.insert(xs, recurse_type(child, visit)) + table.insert(xs, recurse_type(s, child, visit)) end end elseif ast is NominalType then if ast.typevals then for _, child in ipairs(ast.typevals) do - table.insert(xs, recurse_type(child, visit)) + table.insert(xs, recurse_type(s, child, visit)) end end elseif ast is TypeArgType then if ast.constraint then - table.insert(xs, recurse_type(ast.constraint, visit)) + table.insert(xs, recurse_type(s, ast.constraint, visit)) end elseif ast is ArrayType then if ast.elements then - table.insert(xs, recurse_type(ast.elements, visit)) + table.insert(xs, recurse_type(s, ast.elements, visit)) end elseif ast is LiteralTableItemType then if ast.ktype then - table.insert(xs, recurse_type(ast.ktype, visit)) + table.insert(xs, recurse_type(s, ast.ktype, visit)) end if ast.vtype then - table.insert(xs, recurse_type(ast.vtype, visit)) + table.insert(xs, recurse_type(s, ast.vtype, visit)) end elseif ast is TypeAliasType then - table.insert(xs, recurse_type(ast.alias_to, visit)) + table.insert(xs, recurse_type(s, ast.alias_to, visit)) elseif ast is TypeDeclType then - table.insert(xs, recurse_type(ast.def, visit)) + table.insert(xs, recurse_type(s, ast.def, visit)) end local ret: T local cbkind_after = cbkind and cbkind.after if cbkind_after then - ret = cbkind_after(ast, xs) + ret = cbkind_after(s, ast, xs) end local visit_after = visit.after if visit_after then - ret = visit_after(ast, xs, ret) + ret = visit_after(s, ast, xs, ret) end if TL_DEBUG then @@ -4420,25 +4419,26 @@ local function recurse_type(ast: Type, visit: Visitor): T return ret end -local function recurse_typeargs(ast: Node, visit_type: Visitor) +local function recurse_typeargs(s: S, ast: Node, visit_type: Visitor) if ast.typeargs then for _, typearg in ipairs(ast.typeargs) do - recurse_type(typearg, visit_type) + recurse_type(s, typearg, visit_type) end end end -local function extra_callback(name: VisitorExtraCallback, - ast: Node, - xs: {T}, - visit_node: Visitor) +local function extra_callback(name: VisitorExtraCallback, + s: S, + ast: Node, + xs: {T}, + visit_node: Visitor) local cbs = visit_node.cbs if not cbs then return end local nbs = cbs[ast.kind] if not nbs then return end local bs = nbs[name] if not bs then return end - bs(ast, xs) + bs(s, ast, xs) end local no_recurse_node: {NodeKind : boolean} = { @@ -4458,9 +4458,9 @@ local no_recurse_node: {NodeKind : boolean} = { ["type_identifier"] = true, } -local function recurse_node(root: Node, - visit_node: Visitor, - visit_type: Visitor): T +local function recurse_node(s: S, root: Node, + visit_node: Visitor, + visit_type: Visitor): T if not root then -- parse error return @@ -4477,9 +4477,9 @@ local function recurse_node(root: Node, local function walk_vars_exps(ast: Node, xs: {T}) xs[1] = recurse(ast.vars) if ast.decltuple then - xs[2] = recurse_type(ast.decltuple, visit_type) + xs[2] = recurse_type(s, ast.decltuple, visit_type) end - extra_callback("before_exp", ast, xs, visit_node) + extra_callback("before_exp", s, ast, xs, visit_node) if ast.exps then xs[3] = recurse(ast.exps) end @@ -4491,11 +4491,11 @@ local function recurse_node(root: Node, end local function walk_named_function(ast: Node, xs: {T}) - recurse_typeargs(ast, visit_type) + recurse_typeargs(s, ast, visit_type) xs[1] = recurse(ast.name) xs[2] = recurse(ast.args) - xs[3] = recurse_type(ast.rets, visit_type) - extra_callback("before_statements", ast, xs, visit_node) + xs[3] = recurse_type(s, ast.rets, visit_type) + extra_callback("before_statements", s, ast, xs, visit_node) xs[4] = recurse(ast.body) end @@ -4508,9 +4508,9 @@ local function recurse_node(root: Node, end xs[2] = p1 as T if ast.op.arity == 2 then - extra_callback("before_e2", ast, xs, visit_node) + extra_callback("before_e2", s, ast, xs, visit_node) if ast.op.op == "is" or ast.op.op == "as" then - xs[3] = recurse_type(ast.e2.casttype, visit_type) + xs[3] = recurse_type(s, ast.e2.casttype, visit_type) else xs[3] = recurse(ast.e2) end @@ -4528,7 +4528,7 @@ local function recurse_node(root: Node, xs[1] = recurse(ast.key) xs[2] = recurse(ast.value) if ast.itemtype then - xs[3] = recurse_type(ast.itemtype, visit_type) + xs[3] = recurse_type(s, ast.itemtype, visit_type) end end, @@ -4554,13 +4554,13 @@ local function recurse_node(root: Node, if ast.exp then xs[1] = recurse(ast.exp) end - extra_callback("before_statements", ast, xs, visit_node) + extra_callback("before_statements", s, ast, xs, visit_node) xs[2] = recurse(ast.body) end, ["while"] = function(ast: Node, xs: {T}) xs[1] = recurse(ast.exp) - extra_callback("before_statements", ast, xs, visit_node) + extra_callback("before_statements", s, ast, xs, visit_node) xs[2] = recurse(ast.body) end, @@ -4570,45 +4570,45 @@ local function recurse_node(root: Node, end, ["macroexp"] = function(ast: Node, xs: {T}) - recurse_typeargs(ast, visit_type) + recurse_typeargs(s, ast, visit_type) xs[1] = recurse(ast.args) - xs[2] = recurse_type(ast.rets, visit_type) - extra_callback("before_exp", ast, xs, visit_node) + xs[2] = recurse_type(s, ast.rets, visit_type) + extra_callback("before_exp", s, ast, xs, visit_node) xs[3] = recurse(ast.exp) end, ["function"] = function(ast: Node, xs: {T}) - recurse_typeargs(ast, visit_type) + recurse_typeargs(s, ast, visit_type) xs[1] = recurse(ast.args) - xs[2] = recurse_type(ast.rets, visit_type) - extra_callback("before_statements", ast, xs, visit_node) + xs[2] = recurse_type(s, ast.rets, visit_type) + extra_callback("before_statements", s, ast, xs, visit_node) xs[3] = recurse(ast.body) end, ["local_function"] = walk_named_function, ["global_function"] = walk_named_function, ["record_function"] = function(ast: Node, xs: {T}) - recurse_typeargs(ast, visit_type) + recurse_typeargs(s, ast, visit_type) xs[1] = recurse(ast.fn_owner) xs[2] = recurse(ast.name) - extra_callback("before_arguments", ast, xs, visit_node) + extra_callback("before_arguments", s, ast, xs, visit_node) xs[3] = recurse(ast.args) - xs[4] = recurse_type(ast.rets, visit_type) - extra_callback("before_statements", ast, xs, visit_node) + xs[4] = recurse_type(s, ast.rets, visit_type) + extra_callback("before_statements", s, ast, xs, visit_node) xs[5] = recurse(ast.body) end, ["local_macroexp"] = function(ast: Node, xs: {T}) -- TODO: generic macroexp xs[1] = recurse(ast.name) xs[2] = recurse(ast.macrodef.args) - xs[3] = recurse_type(ast.macrodef.rets, visit_type) - extra_callback("before_exp", ast, xs, visit_node) + xs[3] = recurse_type(s, ast.macrodef.rets, visit_type) + extra_callback("before_exp", s, ast, xs, visit_node) xs[4] = recurse(ast.macrodef.exp) end, ["forin"] = function(ast: Node, xs: {T}) xs[1] = recurse(ast.vars) xs[2] = recurse(ast.exps) - extra_callback("before_statements", ast, xs, visit_node) + extra_callback("before_statements", s, ast, xs, visit_node) xs[3] = recurse(ast.body) end, @@ -4617,7 +4617,7 @@ local function recurse_node(root: Node, xs[2] = recurse(ast.from) xs[3] = recurse(ast.to) xs[4] = ast.step and recurse(ast.step) - extra_callback("before_statements", ast, xs, visit_node) + extra_callback("before_statements", s, ast, xs, visit_node) xs[5] = recurse(ast.body) end, @@ -4634,12 +4634,12 @@ local function recurse_node(root: Node, end, ["newtype"] = function(ast: Node, xs:{T}) - xs[1] = recurse_type(ast.newtype, visit_type) + xs[1] = recurse_type(s, ast.newtype, visit_type) end, ["argument"] = function(ast: Node, xs:{T}) if ast.argtype then - xs[1] = recurse_type(ast.argtype, visit_type) + xs[1] = recurse_type(s, ast.argtype, visit_type) end end, } @@ -4658,7 +4658,7 @@ local function recurse_node(root: Node, local cbkind = cbs and cbs[kind] if cbkind then if cbkind.before then - cbkind.before(ast) + cbkind.before(s, ast) end end @@ -4682,10 +4682,10 @@ local function recurse_node(root: Node, local ret: T local cbkind_after = cbkind and cbkind.after if cbkind_after then - ret = cbkind_after(ast, xs) + ret = cbkind_after(s, ast, xs) end if visit_after then - ret = visit_after(ast, xs, ret) + ret = visit_after(s, ast, xs, ret) end if TL_DEBUG then @@ -4768,7 +4768,7 @@ local primitive: {TypeName:string} = { ["thread"] = "thread", } -function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | PrettyPrintOptions): string, string +function tl.pretty_print_ast(ast: Node, gen_target: GenTarget, mode: boolean | PrettyPrintOptions): string, string local err: string local indent = 0 @@ -4789,7 +4789,7 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | local save_indent: {integer} = {} - local function increment_indent(node: Node) + local function increment_indent(_: nil, node: Node) local child = node.body or node[1] if not child then return @@ -4882,7 +4882,7 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | return table.concat(out) end - local visit_node: Visitor = {} + local visit_node: Visitor = {} local lua_54_attribute : {Attribute:string} = { ["const"] = " ", @@ -4890,17 +4890,17 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | ["total"] = " ", } - local function emit_exactly(node: Node, _children: {Output}): Output + local function emit_exactly(_: nil, node: Node, _children: {Output}): Output local out: Output = { y = node.y, h = 0 } add_string(out, node.tk) return out end - local emit_exactly_visitor_cbs : VisitorCallbacks = { after = emit_exactly } + local emit_exactly_visitor_cbs : VisitorCallbacks = { after = emit_exactly } visit_node.cbs = { ["statements"] = { - after = function(node: Node, children: {Output}): Output + after = function(_: nil, node: Node, children: {Output}): Output local out: Output if opts.preserve_hashbang and node.hashbang then out = { y = 1, h = 0 } @@ -4922,7 +4922,7 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | end }, ["local_declaration"] = { - after = function(node: Node, children: {Output}): Output + after = function(_: nil, node: Node, children: {Output}): Output local out: Output = { y = node.y, h = 0 } table.insert(out, "local ") for i, var in ipairs(node.vars) do @@ -4948,7 +4948,7 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | end, }, ["local_type"] = { - after = function(node: Node, children: {Output}): Output + after = function(_: nil, node: Node, children: {Output}): Output local out: Output = { y = node.y, h = 0 } if not node.var.elide_type then table.insert(out, "local") @@ -4960,7 +4960,7 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | end, }, ["global_type"] = { - after = function(node: Node, children: {Output}): Output + after = function(_: nil, node: Node, children: {Output}): Output local out: Output = { y = node.y, h = 0 } if children[2] then add_child(out, children[1]) @@ -4971,7 +4971,7 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | end, }, ["global_declaration"] = { - after = function(node: Node, children: {Output}): Output + after = function(_: nil, node: Node, children: {Output}): Output local out: Output = { y = node.y, h = 0 } if children[3] then add_child(out, children[1]) @@ -4982,7 +4982,7 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | end, }, ["assignment"] = { - after = function(node: Node, children: {Output}): Output + after = function(_: nil, node: Node, children: {Output}): Output local out: Output = { y = node.y, h = 0 } add_child(out, children[1]) table.insert(out, " =") @@ -4991,7 +4991,7 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | end, }, ["if"] = { - after = function(node: Node, children: {Output}): Output + after = function(_: nil, node: Node, children: {Output}): Output local out: Output = { y = node.y, h = 0 } for i, child in ipairs(children) do add_child(out, child, i > 1 and " ", child.y ~= node.y and indent) @@ -5002,7 +5002,7 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | }, ["if_block"] = { before = increment_indent, - after = function(node: Node, children: {Output}): Output + after = function(_: nil, node: Node, children: {Output}): Output local out: Output = { y = node.y, h = 0 } if node.if_block_n == 1 then table.insert(out, "if") @@ -5022,7 +5022,7 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | }, ["while"] = { before = increment_indent, - after = function(node: Node, children: {Output}): Output + after = function(_: nil, node: Node, children: {Output}): Output local out: Output = { y = node.y, h = 0 } table.insert(out, "while") add_child(out, children[1], " ") @@ -5035,7 +5035,7 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | }, ["repeat"] = { before = increment_indent, - after = function(node: Node, children: {Output}): Output + after = function(_: nil, node: Node, children: {Output}): Output local out: Output = { y = node.y, h = 0 } table.insert(out, "repeat") add_child(out, children[1], " ") @@ -5047,7 +5047,7 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | }, ["do"] = { before = increment_indent, - after = function(node: Node, children: {Output}): Output + after = function(_: nil, node: Node, children: {Output}): Output local out: Output = { y = node.y, h = 0 } table.insert(out, "do") add_child(out, children[1], " ") @@ -5058,7 +5058,7 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | }, ["forin"] = { before = increment_indent, - after = function(node: Node, children: {Output}): Output + after = function(_: nil, node: Node, children: {Output}): Output local out: Output = { y = node.y, h = 0 } table.insert(out, "for") add_child(out, children[1], " ") @@ -5073,7 +5073,7 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | }, ["fornum"] = { before = increment_indent, - after = function(node: Node, children: {Output}): Output + after = function(_: nil, node: Node, children: {Output}): Output local out: Output = { y = node.y, h = 0 } table.insert(out, "for") add_child(out, children[1], " ") @@ -5093,7 +5093,7 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | end, }, ["return"] = { - after = function(node: Node, children: {Output}): Output + after = function(_: nil, node: Node, children: {Output}): Output local out: Output = { y = node.y, h = 0 } table.insert(out, "return") if #children[1] > 0 then @@ -5103,14 +5103,14 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | end, }, ["break"] = { - after = function(node: Node, _children: {Output}): Output + after = function(_: nil, node: Node, _children: {Output}): Output local out: Output = { y = node.y, h = 0 } table.insert(out, "break") return out end, }, ["variable_list"] = { - after = function(node: Node, children: {Output}): Output + after = function(_: nil, node: Node, children: {Output}): Output local out: Output = { y = node.y, h = 0 } local space: string for i, child in ipairs(children) do @@ -5125,7 +5125,7 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | }, ["literal_table"] = { before = increment_indent, - after = function(node: Node, children: {Output}): Output + after = function(_: nil, node: Node, children: {Output}): Output local out: Output = { y = node.y, h = 0 } if #children == 0 then table.insert(out, "{}") @@ -5145,7 +5145,7 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | end, }, ["literal_table_item"] = { - after = function(node: Node, children: {Output}): Output + after = function(_: nil, node: Node, children: {Output}): Output local out: Output = { y = node.y, h = 0 } if node.key_parsed ~= "implicit" then if node.key_parsed == "short" then @@ -5168,13 +5168,13 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | }, ["local_macroexp"] = { before = increment_indent, - after = function(node: Node, _children: {Output}): Output + after = function(_: nil, node: Node, _children: {Output}): Output return { y = node.y, h = 0 } end, }, ["local_function"] = { before = increment_indent, - after = function(node: Node, children: {Output}): Output + after = function(_: nil, node: Node, children: {Output}): Output local out: Output = { y = node.y, h = 0 } table.insert(out, "local function") add_child(out, children[1], " ") @@ -5189,7 +5189,7 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | }, ["global_function"] = { before = increment_indent, - after = function(node: Node, children: {Output}): Output + after = function(_: nil, node: Node, children: {Output}): Output local out: Output = { y = node.y, h = 0 } table.insert(out, "function") add_child(out, children[1], " ") @@ -5204,7 +5204,7 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | }, ["record_function"] = { before = increment_indent, - after = function(node: Node, children: {Output}): Output + after = function(_: nil, node: Node, children: {Output}): Output local out: Output = { y = node.y, h = 0 } table.insert(out, "function") add_child(out, children[1], " ") @@ -5229,7 +5229,7 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | }, ["function"] = { before = increment_indent, - after = function(node: Node, children: {Output}): Output + after = function(_: nil, node: Node, children: {Output}): Output local out: Output = { y = node.y, h = 0 } table.insert(out, "function(") add_child(out, children[1]) @@ -5243,7 +5243,7 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | ["cast"] = { }, ["paren"] = { - after = function(node: Node, children: {Output}): Output + after = function(_: nil, node: Node, children: {Output}): Output local out: Output = { y = node.y, h = 0 } table.insert(out, "(") add_child(out, children[1], "", indent) @@ -5252,7 +5252,7 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | end, }, ["op"] = { - after = function(node: Node, children: {Output}): Output + after = function(_: nil, node: Node, children: {Output}): Output local out: Output = { y = node.y, h = 0 } if node.op.op == "@funcall" then add_child(out, children[1], "", indent) @@ -5313,7 +5313,7 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | end, }, ["newtype"] = { - after = function(node: Node, _children: {Output}): Output + after = function(_: nil, node: Node, _children: {Output}): Output local out: Output = { y = node.y, h = 0 } local nt = node.newtype if nt is TypeAliasType then @@ -5330,7 +5330,7 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | end, }, ["goto"] = { - after = function(node: Node, _children: {Output}): Output + after = function(_: nil, node: Node, _children: {Output}): Output local out: Output = { y = node.y, h = 0 } table.insert(out, "goto ") table.insert(out, node.label) @@ -5338,7 +5338,7 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | end, }, ["label"] = { - after = function(node: Node, _children: {Output}): Output + after = function(_: nil, node: Node, _children: {Output}): Output local out: Output = { y = node.y, h = 0 } table.insert(out, "::") table.insert(out, node.label) @@ -5347,7 +5347,7 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | end, }, ["string"] = { - after = function(node: Node, children: {Output}): Output + after = function(_: nil, node: Node, children: {Output}): Output -- translate escape sequences not supported by Lua 5.1 -- in particular: -- - \z : removes trailing whitespace @@ -5355,7 +5355,7 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | -- - \u{} : unicode if node.tk:sub(1, 1) == "[" or gen_target ~= "5.1" or not node.tk:find("\\", 1, true) then - return emit_exactly(node, children) + return emit_exactly(nil, node, children) end local out : Output = { y = node.y, h = 0 } @@ -5416,10 +5416,10 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | ["type_identifier"] = emit_exactly_visitor_cbs, } - local visit_type: Visitor = {} + local visit_type: Visitor = {} visit_type.cbs = {} local default_type_visitor = { - after = function(typ: Type, _children: {Output}): Output + after = function(_: nil, typ: Type, _children: {Output}): Output local out: Output = { y = typ.y or -1, h = 0 } local r = typ is NominalType and typ.resolved or typ local lua_type = primitive[r.typename] or "table" @@ -5457,13 +5457,12 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | visit_type.cbs["any"] = default_type_visitor visit_type.cbs["unknown"] = default_type_visitor visit_type.cbs["invalid"] = default_type_visitor - visit_type.cbs["unresolved"] = default_type_visitor visit_type.cbs["none"] = default_type_visitor visit_node.cbs["expression_list"] = visit_node.cbs["variable_list"] visit_node.cbs["argument_list"] = visit_node.cbs["variable_list"] - local out = recurse_node(ast, visit_node, visit_type) + local out = recurse_node(nil, ast, visit_node, visit_type) if err then return nil, err end @@ -5513,7 +5512,6 @@ local typename_to_typecode : {TypeName:integer} = { ["none"] = tl.typecodes.UNKNOWN, ["tuple"] = tl.typecodes.UNKNOWN, ["literal_table_item"] = tl.typecodes.UNKNOWN, - ["unresolved"] = tl.typecodes.UNKNOWN, ["typedecl"] = tl.typecodes.UNKNOWN, ["typealias"] = tl.typecodes.UNKNOWN, ["*"] = tl.typecodes.UNKNOWN, @@ -5521,8 +5519,8 @@ local typename_to_typecode : {TypeName:integer} = { local skip_types: {TypeName: boolean} = { ["none"] = true, + ["tuple"] = true, ["literal_table_item"] = true, - ["unresolved"] = true, } local function sorted_keys(m: {A:B}):{A} @@ -5545,6 +5543,7 @@ function tl.new_type_reporter(): TypeReporter local self: TypeReporter = { next_num = 1, typeid_to_num = {}, + typename_to_num = {}, tr = { by_pos = {}, types = {}, @@ -5552,6 +5551,24 @@ function tl.new_type_reporter(): TypeReporter globals = {}, }, } + + local names = {} + for name, _ in pairs(simple_types) do + table.insert(names, name) + end + table.sort(names) + + for _, name in ipairs(names) do + local ti: TypeInfo = { + t = assert(typename_to_typecode[name]), + str = name, + } + local n = self.next_num + self.typename_to_num[name] = n + self.tr.types[n] = ti + self.next_num = self.next_num + 1 + end + return setmetatable(self, { __index = TypeReporter }) end @@ -5571,9 +5588,15 @@ function TypeReporter:store_function(ti: TypeInfo, rt: FunctionType) end function TypeReporter:get_typenum(t: Type): integer + -- try simple types first + local n = self.typename_to_num[t.typename] + if n then + return n + end + assert(t.typeid) -- try by typeid - local n = self.typeid_to_num[t.typeid] + n = self.typeid_to_num[t.typeid] if n then return n end @@ -5597,7 +5620,7 @@ function TypeReporter:get_typenum(t: Type): integer local ti: TypeInfo = { t = assert(typename_to_typecode[rt.typename]), str = show_type(t, true), - file = t.filename, + file = t.f, y = t.y, x = t.x, } @@ -5667,7 +5690,7 @@ local record TypeCollector end function TypeReporter:get_collector(filename: string): TypeCollector - local tc: TypeCollector = { + local collector: TypeCollector = { filename = filename, symbol_list = {}, } @@ -5675,10 +5698,10 @@ function TypeReporter:get_collector(filename: string): TypeCollector local ft: {integer:{integer:integer}} = {} self.tr.by_pos[filename] = ft - local symbol_list = tc.symbol_list + local symbol_list = collector.symbol_list local symbol_list_n = 0 - tc.store_type = function(y: integer, x: integer, typ: Type) + collector.store_type = function(y: integer, x: integer, typ: Type) if not typ or skip_types[typ.typename] then return end @@ -5692,12 +5715,12 @@ function TypeReporter:get_collector(filename: string): TypeCollector yt[x] = self:get_typenum(typ) end - tc.reserve_symbol_list_slot = function(node: Node) + collector.reserve_symbol_list_slot = function(node: Node) symbol_list_n = symbol_list_n + 1 node.symbol_list_slot = symbol_list_n end - tc.add_to_symbol_list = function(node: Node, name: string, t: Type) + collector.add_to_symbol_list = function(node: Node, name: string, t: Type) if not node then return end @@ -5711,12 +5734,12 @@ function TypeReporter:get_collector(filename: string): TypeCollector symbol_list[slot] = { y = node.y, x = node.x, name = name, typ = t } end - tc.begin_symbol_list_scope = function(node: Node) + collector.begin_symbol_list_scope = function(node: Node) symbol_list_n = symbol_list_n + 1 symbol_list[symbol_list_n] = { y = node.y, x = node.x, name = "@{" } end - tc.end_symbol_list_scope = function(node: Node) + collector.end_symbol_list_scope = function(node: Node) if symbol_list[symbol_list_n].name == "@{" then symbol_list[symbol_list_n] = nil symbol_list_n = symbol_list_n - 1 @@ -5726,14 +5749,14 @@ function TypeReporter:get_collector(filename: string): TypeCollector end end - return tc + return collector end -function TypeReporter:store_result(tc: TypeCollector, globals: {string:Variable}) +function TypeReporter:store_result(collector: TypeCollector, globals: {string:Variable}) local tr = self.tr - local filename = tc.filename - local symbol_list = tc.symbol_list + local filename = collector.filename + local symbol_list = collector.symbol_list tr.by_pos[filename][0] = nil @@ -5808,144 +5831,450 @@ function TypeReporter:get_report(): TypeReport return self.tr end --- backwards compatibility -function tl.get_types(result: Result): TypeReport, TypeReporter - return result.env.reporter:get_report(), result.env.reporter -end -------------------------------------------------------------------------------- --- Type check +-- Report types -------------------------------------------------------------------------------- -local NONE = a_type("none", {}) -local INVALID = a_type("invalid", {} as InvalidType) -local UNKNOWN = a_type("unknown", {}) -local CIRCULAR_REQUIRE = a_type("circular_require", {}) - -local FUNCTION = a_fn { args = va_args { ANY }, rets = va_args { ANY } } - ---local NOMINAL_FILE = a_type("nominal", { names = {"FILE"} } as NominalType) -local XPCALL_MSGH_FUNCTION = a_fn { args = { ANY }, rets = { } } - ---local USERDATA = ANY -- Placeholder for maybe having a userdata "primitive" type - -local numeric_binop = { +function tl.symbols_in_scope(tr: TypeReport, y: integer, x: integer, filename: string): {string:integer} + local function find(symbols: {TypeReport.Symbol}, at_y: integer, at_x: integer): integer + local function le(a: {integer, integer}, b: {integer, integer}): boolean + return a[1] < b[1] + or (a[1] == b[1] and a[2] <= b[2]) + end + return binary_search(symbols, {at_y, at_x}, le) or 0 + end + + local ret: {string:integer} = {} + + local symbols = tr.symbols_by_file[filename] + if not symbols then + return ret + end + + local n = find(symbols, y, x) + + while n >= 1 do + local s = symbols[n] + if s[3] == "@{" then + n = n - 1 + elseif s[3] == "@}" then + n = s[4] + else + ret[s[3]] = s[4] + n = n - 1 + end + end + + return ret +end + +-------------------------------------------------------------------------------- +-- Errors +-------------------------------------------------------------------------------- + +function Errors.new(filename: string): Errors + local self = { + errors = {}, + warnings = {}, + unknown_dots = {}, + filename = filename, + } + return setmetatable(self, { __index = Errors }) +end + +local function Err(msg: string, t1?: Type, t2?: Type, t3?: Type): Error + if t1 then + local s1, s2, s3: string, string, string + if t1 is InvalidType then + return nil + end + s1 = show_type(t1) + if t2 then + if t2 is InvalidType then + return nil + end + s2 = show_type(t2) + end + if t3 then + if t3 is InvalidType then + return nil + end + s3 = show_type(t3) + end + msg = msg:format(s1, s2, s3) + return { + msg = msg, + x = t1.x, + y = t1.y, + filename = t1.f, + } + end + + return { + msg = msg, + } +end + +local function insert_error(self: Errors, y: integer, x: integer, err: Error) + err.y = assert(y) + err.x = assert(x) + err.filename = self.filename + + if TL_DEBUG then + io.stderr:write("ERROR:" .. err.y .. ":" .. err.x .. ": " .. err.msg .. "\n") + end + + table.insert(self.errors, err) +end + +function Errors:add(w: Where, msg: string, ...:Type) + local e = Err(msg, ...) + if e then + insert_error(self, w.y, w.x, e) + end +end + +local context_name: {NodeKind: string} = { + ["local_declaration"] = "in local declaration", + ["global_declaration"] = "in global declaration", + ["assignment"] = "in assignment", + ["literal_table_item"] = "in table item", +} + +function Errors:get_context(ctx: Node|string, name?: string): string + if not ctx then + return "" + end + local ec = (ctx is Node) and ctx.expected_context + local cn = (ctx is string) and ctx or + (ctx is Node) and context_name[ec and ec.kind or ctx.kind] + return (cn and cn .. ": " or "") .. (ec and ec.name and ec.name .. ": " or "") .. (name and name .. ": " or "") +end + +function Errors:add_in_context(w: Where, ctx: Node, msg: string, ...:Type) + local prefix = self:get_context(ctx) + msg = prefix .. msg + + local e = Err(msg, ...) + if e then + insert_error(self, w.y, w.x, e) + end +end + + +function Errors:collect(errs: {Error}) + for _, e in ipairs(errs) do + insert_error(self, e.y, e.x, e) + end +end + +function Errors:add_warning(tag: WarningKind, w: Where, fmt: string, ...: any) + assert(w.y) + table.insert(self.warnings, { + y = w.y, + x = w.x, + msg = fmt:format(...), + filename = self.filename, + tag = tag, + }) +end + +function Errors:invalid_at(w: Where, msg: string, ...:Type): InvalidType + self:add(w, msg, ...) + return an_invalid(w) +end + +function Errors:add_unknown(node: Node, name: string) + self:add_warning("unknown", node, "unknown variable: %s", name) +end + +function Errors:redeclaration_warning(node: Node, old_var?: Variable) + if node.tk:sub(1, 1) == "_" then return end + + local var_kind = "variable" + local var_name = node.tk + if node.kind == "local_function" or node.kind == "record_function" then + var_kind = "function" + var_name = node.name.tk + end + + local short_error = "redeclaration of " .. var_kind .. " '%s'" + if old_var and old_var.declared_at then + self:add_warning("redeclaration", node, short_error .. " (originally declared at %d:%d)", var_name, old_var.declared_at.y, old_var.declared_at.x) + else + self:add_warning("redeclaration", node, short_error, var_name) + end +end + +function Errors:unused_warning(name: string, var: Variable) + local prefix = name:sub(1,1) + if var.declared_at + and var.is_narrowed ~= "narrow" + and prefix ~= "_" + and prefix ~= "@" + then + local t = var.t + self:add_warning( + "unused", + var.declared_at, + "unused %s %s: %s", + var.is_func_arg and "argument" + or t is FunctionType and "function" + or t is TypeDeclType and "type" + or t is TypeAliasType and "type" + or "variable", + name, + show_type(var.t) + ) + end +end + +function Errors:add_prefixing(w: Where, src: {Error}, prefix: string, dst?: {Error}) + if not src then + return + end + + for _, err in ipairs(src) do + err.msg = prefix .. err.msg + if w and ( + (err.filename ~= w.f) + or (not err.y) + or (w.y > err.y or (w.y == err.y and w.x > err.x)) + ) then + err.y = w.y + err.x = w.x + err.filename = w.f + end + + if dst then + table.insert(dst, err) + else + insert_error(self, err.y, err.x, err) + end + end +end + +local record Unused + y: integer + x: integer + name: string + var: Variable +end + +local function check_for_unused_vars(scope: Scope, is_global?: boolean): {Unused} + local vars = scope.vars + if not next(vars) then + return + end + local list: {Unused} + for name, var in pairs(vars) do + local t = var.t + if var.declared_at and not var.used then + if var.used_as_type then + var.declared_at.elide_type = true + else + if (t is TypeDeclType or t is TypeAliasType) and not is_global then + var.declared_at.elide_type = true + end + list = list or {} + table.insert(list, { y = var.declared_at.y, x = var.declared_at.x, name = name, var = var }) + end + elseif var.used and (t is TypeDeclType or t is TypeAliasType) and var.aliasing then + var.aliasing.used = true + var.aliasing.declared_at.elide_type = false + end + end + if list then + table.sort(list, function(a: Unused, b: Unused): boolean + return a.y < b.y or (a.y == b.y and a.x < b.x) + end) + end + return list +end + +function Errors:warn_unused_vars(scope: Scope, is_global?: boolean) + local unused = check_for_unused_vars(scope, is_global) + if unused then + for _, u in ipairs(unused) do + self:unused_warning(u.name, u.var) + end + end + + if scope.labels then + for name, node in pairs(scope.labels) do + if not node.used_label then + self:add_warning("unused", node, "unused label ::%s::", name) + end + end + end +end + +function Errors:add_unknown_dot(node: Node, name: string) + if not self.unknown_dots[name] then + self.unknown_dots[name] = true + self:add_unknown(node, name) + end +end + +function Errors:fail_unresolved_labels(scope: Scope) + if scope.pending_labels then + for name, nodes in pairs(scope.pending_labels) do + for _, node in ipairs(nodes) do + self:add(node, "no visible label '" .. name .. "' for goto") + end + end + end +end + +function Errors:fail_unresolved_nominals(scope: Scope, global_scope: Scope) + if global_scope and scope.pending_nominals then + for name, types in pairs(scope.pending_nominals) do + if not global_scope.pending_global_types[name] then + for _, typ in ipairs(types) do + assert(typ.x) + assert(typ.y) + self:add(typ, "unknown type %s", typ) + end + end + end + end +end + +local type CheckableKey = string | number | boolean + +function Errors:check_redeclared_key(w: Where, ctx: Node, seen_keys: {CheckableKey:Where}, key: CheckableKey) + if key ~= nil then + local s = seen_keys[key] + if s then + self:add_in_context(w, ctx, "redeclared key " .. tostring(key) .. " (previously declared at " .. self.filename .. ":" .. s.y .. ":" .. s.x .. ")") + else + seen_keys[key] = w + end + end +end + +-------------------------------------------------------------------------------- +-- Type check +-------------------------------------------------------------------------------- + +local numeric_binop = { ["number"] = { - ["number"] = NUMBER, - ["integer"] = NUMBER, + ["number"] = "number", + ["integer"] = "number", }, ["integer"] = { - ["integer"] = INTEGER, - ["number"] = NUMBER, + ["integer"] = "integer", + ["number"] = "number", }, } local float_binop = { ["number"] = { - ["number"] = NUMBER, - ["integer"] = NUMBER, + ["number"] = "number", + ["integer"] = "number", }, ["integer"] = { - ["integer"] = NUMBER, - ["number"] = NUMBER, + ["integer"] = "number", + ["number"] = "number", }, } local integer_binop = { ["number"] = { - ["number"] = INTEGER, - ["integer"] = INTEGER, + ["number"] = "integer", + ["integer"] = "integer", }, ["integer"] = { - ["integer"] = INTEGER, - ["number"] = INTEGER, + ["integer"] = "integer", + ["number"] = "integer", }, } local relational_binop = { ["number"] = { - ["integer"] = BOOLEAN, - ["number"] = BOOLEAN, + ["integer"] = "boolean", + ["number"] = "boolean", }, ["integer"] = { - ["number"] = BOOLEAN, - ["integer"] = BOOLEAN, + ["number"] = "boolean", + ["integer"] = "boolean", }, ["string"] = { - ["string"] = BOOLEAN, + ["string"] = "boolean", }, ["boolean"] = { - ["boolean"] = BOOLEAN, + ["boolean"] = "boolean", }, } local equality_binop = { ["number"] = { - ["number"] = BOOLEAN, - ["integer"] = BOOLEAN, - ["nil"] = BOOLEAN, + ["number"] = "boolean", + ["integer"] = "boolean", + ["nil"] = "boolean", }, ["integer"] = { - ["number"] = BOOLEAN, - ["integer"] = BOOLEAN, - ["nil"] = BOOLEAN, + ["number"] = "boolean", + ["integer"] = "boolean", + ["nil"] = "boolean", }, ["string"] = { - ["string"] = BOOLEAN, - ["nil"] = BOOLEAN, + ["string"] = "boolean", + ["nil"] = "boolean", }, ["boolean"] = { - ["boolean"] = BOOLEAN, - ["nil"] = BOOLEAN, + ["boolean"] = "boolean", + ["nil"] = "boolean", }, ["record"] = { - ["emptytable"] = BOOLEAN, - ["record"] = BOOLEAN, - ["nil"] = BOOLEAN, + ["emptytable"] = "boolean", + ["record"] = "boolean", + ["nil"] = "boolean", }, ["array"] = { - ["emptytable"] = BOOLEAN, - ["array"] = BOOLEAN, - ["nil"] = BOOLEAN, + ["emptytable"] = "boolean", + ["array"] = "boolean", + ["nil"] = "boolean", }, ["map"] = { - ["emptytable"] = BOOLEAN, - ["map"] = BOOLEAN, - ["nil"] = BOOLEAN, + ["emptytable"] = "boolean", + ["map"] = "boolean", + ["nil"] = "boolean", }, ["thread"] = { - ["thread"] = BOOLEAN, - ["nil"] = BOOLEAN, + ["thread"] = "boolean", + ["nil"] = "boolean", } } -local unop_types: {string:{string:Type}} = { +local unop_types: {string:{TypeName:TypeName}} = { ["#"] = { - ["string"] = INTEGER, - ["array"] = INTEGER, - ["tupletable"] = INTEGER, - ["map"] = INTEGER, - ["emptytable"] = INTEGER, + ["string"] = "integer", + ["array"] = "integer", + ["tupletable"] = "integer", + ["map"] = "integer", + ["emptytable"] = "integer", }, ["-"] = { - ["number"] = NUMBER, - ["integer"] = INTEGER, + ["number"] = "number", + ["integer"] = "integer", }, ["~"] = { - ["number"] = INTEGER, - ["integer"] = INTEGER, + ["number"] = "integer", + ["integer"] = "integer", }, ["not"] = { - ["string"] = BOOLEAN, - ["number"] = BOOLEAN, - ["integer"] = BOOLEAN, - ["boolean"] = BOOLEAN, - ["record"] = BOOLEAN, - ["array"] = BOOLEAN, - ["tupletable"] = BOOLEAN, - ["map"] = BOOLEAN, - ["emptytable"] = BOOLEAN, - ["thread"] = BOOLEAN, + ["string"] = "boolean", + ["number"] = "boolean", + ["integer"] = "boolean", + ["boolean"] = "boolean", + ["record"] = "boolean", + ["array"] = "boolean", + ["tupletable"] = "boolean", + ["map"] = "boolean", + ["emptytable"] = "boolean", + ["thread"] = "boolean", }, } @@ -5955,7 +6284,7 @@ local unop_to_metamethod: {string:string} = { ["~"] = "__bnot", } -local binop_types: {string:{TypeName:{TypeName:Type}}} = { +local binop_types: {string:{TypeName:{TypeName:TypeName}}} = { ["+"] = numeric_binop, ["-"] = numeric_binop, ["*"] = numeric_binop, @@ -5976,67 +6305,66 @@ local binop_types: {string:{TypeName:{TypeName:Type}}} = { [">"] = relational_binop, ["or"] = { ["boolean"] = { - ["boolean"] = BOOLEAN, - ["function"] = FUNCTION, -- HACK + ["boolean"] = "boolean", }, ["number"] = { - ["integer"] = NUMBER, - ["number"] = NUMBER, - ["boolean"] = BOOLEAN, + ["integer"] = "number", + ["number"] = "number", + ["boolean"] = "boolean", }, ["integer"] = { - ["integer"] = INTEGER, - ["number"] = NUMBER, - ["boolean"] = BOOLEAN, + ["integer"] = "integer", + ["number"] = "number", + ["boolean"] = "boolean", }, ["string"] = { - ["string"] = STRING, - ["boolean"] = BOOLEAN, - ["enum"] = STRING, + ["string"] = "string", + ["boolean"] = "boolean", + ["enum"] = "string", }, ["function"] = { - ["boolean"] = BOOLEAN, + ["boolean"] = "boolean", }, ["array"] = { - ["boolean"] = BOOLEAN, + ["boolean"] = "boolean", }, ["record"] = { - ["boolean"] = BOOLEAN, + ["boolean"] = "boolean", }, ["map"] = { - ["boolean"] = BOOLEAN, + ["boolean"] = "boolean", }, ["enum"] = { - ["string"] = STRING, + ["string"] = "string", }, ["thread"] = { - ["boolean"] = BOOLEAN, + ["boolean"] = "boolean", } }, [".."] = { ["string"] = { - ["string"] = STRING, - ["enum"] = STRING, - ["number"] = STRING, - ["integer"] = STRING, + ["string"] = "string", + ["enum"] = "string", + ["number"] = "string", + ["integer"] = "string", }, ["number"] = { - ["integer"] = STRING, - ["number"] = STRING, - ["string"] = STRING, - ["enum"] = STRING, + ["integer"] = "string", + ["number"] = "string", + ["string"] = "string", + ["enum"] = "string", }, ["integer"] = { - ["integer"] = STRING, - ["number"] = STRING, - ["string"] = STRING, - ["enum"] = STRING, + ["integer"] = "string", + ["number"] = "string", + ["string"] = "string", + ["enum"] = "string", }, ["enum"] = { - ["number"] = STRING, - ["integer"] = STRING, - ["string"] = STRING, - ["enum"] = STRING, + ["number"] = "string", + ["integer"] = "string", + ["string"] = "string", + ["enum"] = "string", } }, } @@ -6244,8 +6572,8 @@ local function show_type_base(t: Type, short: boolean, seen: {Type:string}): str end end -local function inferred_msg(t: Type): string - return " (inferred at "..t.inferred_at.filename..":"..t.inferred_at.y..":"..t.inferred_at.x..")" +local function inferred_msg(t: Type, prefix?: string): string + return " (" .. (prefix or "") .. "inferred at "..t.inferred_at.f..":"..t.inferred_at.y..":"..t.inferred_at.x..")" end show_type = function(t: Type, short?: boolean, seen?: {Type:string}): string @@ -6297,33 +6625,34 @@ function tl.search_module(module_name: string, search_dtl: boolean): string, FIL return nil, nil, tried end -local function require_module(module_name: string, lax: boolean, env: Env): Type, boolean +local function require_module(w: Where, module_name: string, feat_lax: boolean, env: Env): Type, string local mod = env.modules[module_name] if mod then - return mod, true + return mod, env.module_filenames[module_name] end local found, fd = tl.search_module(module_name, true) - if found and (lax or found:match("tl$") as boolean) then + if found and (feat_lax or found:match("tl$") as boolean) then - env.modules[module_name] = a_typedecl(CIRCULAR_REQUIRE) + env.module_filenames[module_name] = found + env.modules[module_name] = a_typedecl(w, a_type(w, "circular_require", {})) local found_result, err: Result, string = tl.process(found, env, fd) assert(found_result, err) env.modules[module_name] = found_result.type - return found_result.type, true + return found_result.type, found elseif fd then fd:close() end - return INVALID, found ~= nil + return an_invalid(w), found end local compat_code_cache: {string:Node} = {} -local function add_compat_entries(program: Node, used_set: {string: boolean}, gen_compat: CompatMode) +local function add_compat_entries(program: Node, used_set: {string: boolean}, gen_compat: GenCompat) if gen_compat == "off" or not next(used_set) then return end @@ -6340,7 +6669,7 @@ local function add_compat_entries(program: Node, used_set: {string: boolean}, ge local code: Node = compat_code_cache[name] if not code then code = tl.parse(text, "@internal") - tl.type_check(code, { filename = "", lax = false, gen_compat = "off" }) + tl.type_check(code, "@internal", { feat_lax = "off", gen_compat = "off" }) compat_code_cache[name] = code end for _, c in ipairs(code) do @@ -6379,32 +6708,26 @@ local function add_compat_entries(program: Node, used_set: {string: boolean}, ge TL_DEBUG = tl_debug end -local function get_stdlib_compat(lax: boolean): {string:boolean} - if lax then - return { - ["utf8"] = true, - } - else - return { - ["io"] = true, - ["math"] = true, - ["string"] = true, - ["table"] = true, - ["utf8"] = true, - ["coroutine"] = true, - ["os"] = true, - ["package"] = true, - ["debug"] = true, - ["load"] = true, - ["loadfile"] = true, - ["assert"] = true, - ["pairs"] = true, - ["ipairs"] = true, - ["pcall"] = true, - ["xpcall"] = true, - ["rawlen"] = true, - } - end +local function get_stdlib_compat(): {string:boolean} + return { + ["io"] = true, + ["math"] = true, + ["string"] = true, + ["table"] = true, + ["utf8"] = true, + ["coroutine"] = true, + ["os"] = true, + ["package"] = true, + ["debug"] = true, + ["load"] = true, + ["loadfile"] = true, + ["assert"] = true, + ["pairs"] = true, + ["ipairs"] = true, + ["pcall"] = true, + ["xpcall"] = true, + ["rawlen"] = true, + } end local bit_operators: {string:string} = { @@ -6415,14 +6738,21 @@ local bit_operators: {string:string} = { ["<<"] = "lshift", } +local function node_at(w: Where, n: Node): Node + n.f = assert(w.f) + n.x = w.x + n.y = w.y + return n +end + local function convert_node_to_compat_call(node: Node, mod_name: string, fn_name: string, e1: Node, e2?: Node) node.op.op = "@funcall" node.op.arity = 2 node.op.prec = 100 - node.e1 = { y = node.y, x = node.x, kind = "op", op = an_operator(node, 2, ".") } - node.e1.e1 = { y = node.y, x = node.x, kind = "identifier", tk = mod_name } - node.e1.e2 = { y = node.y, x = node.x, kind = "identifier", tk = fn_name } - node.e2 = { y = node.y, x = node.x, kind = "expression_list" } + node.e1 = node_at(node, { kind = "op", op = an_operator(node, 2, ".") }) + node.e1.e1 = node_at(node, { kind = "identifier", tk = mod_name }) + node.e1.e2 = node_at(node, { kind = "identifier", tk = fn_name }) + node.e2 = node_at(node, { kind = "expression_list" }) node.e2[1] = e1 node.e2[2] = e2 end @@ -6431,10 +6761,10 @@ local function convert_node_to_compat_mt_call(node: Node, mt_name: string, which node.op.op = "@funcall" node.op.arity = 2 node.op.prec = 100 - node.e1 = { y = node.y, x = node.x, kind = "identifier", tk = "_tl_mt" } - node.e2 = { y = node.y, x = node.x, kind = "expression_list" } - node.e2[1] = { y = node.y, x = node.x, kind = "string", tk = "\"" .. mt_name .. "\"" } - node.e2[2] = { y = node.y, x = node.x, kind = "integer", tk = tostring(which_self) } + node.e1 = node_at(node, { kind = "identifier", tk = "_tl_mt" }) + node.e2 = node_at(node, { kind = "expression_list" }) + node.e2[1] = node_at(node, { kind = "string", tk = "\"" .. mt_name .. "\"" }) + node.e2[2] = node_at(node, { kind = "integer", tk = tostring(which_self) }) node.e2[3] = e1 node.e2[4] = e2 end @@ -6443,25 +6773,6 @@ local stdlib_globals: {string:Variable} = nil local globals_typeid = new_typeid() local fresh_typevar_ctr = 1 -local function set_feat(feat: tl.Feat, default: boolean): boolean - if feat then - return (feat == "on") - else - return default - end -end - -tl.new_env = function(opts: tl.EnvOptions): Env, string - local env, err = tl.init_env(opts.lax_mode, opts.gen_compat, opts.gen_target, opts.predefined_modules) - if not env then - return nil, err - end - - env.feat_arity = set_feat(opts.feat_arity, true) - - return env -end - local function assert_no_stdlib_errors(errors: {Error}, name: string) if #errors ~= 0 then local out = {} @@ -6472,46 +6783,31 @@ local function assert_no_stdlib_errors(errors: {Error}, name: string) end end -tl.init_env = function(lax?: boolean, gen_compat?: boolean | CompatMode, gen_target?: TargetMode, predefined?: {string}): Env, string - if gen_compat == true or gen_compat == nil then - gen_compat = "optional" - elseif gen_compat == false then - gen_compat = "off" - end - gen_compat = gen_compat as CompatMode - - if not gen_target then - if _VERSION == "Lua 5.1" or _VERSION == "Lua 5.2" then - gen_target = "5.1" - else - gen_target = "5.3" - end - end - - if gen_target == "5.4" and gen_compat ~= "off" then - return nil, "gen-compat must be explicitly 'off' when gen-target is '5.4'" - end +tl.new_env = function(opts?: EnvOptions): Env, string + opts = opts or {} local env: Env = { modules = {}, + module_filenames = {}, loaded = {}, loaded_order = {}, globals = {}, - gen_compat = gen_compat, - gen_target = gen_target, + defaults = opts.defaults or {}, } + if env.defaults.gen_target == "5.4" and env.defaults.gen_compat ~= "off" then + return nil, "gen-compat must be explicitly 'off' when gen-target is '5.4'" + end + + local w: Where = { f = "@stdlib", x = 1, y = 1 } + if not stdlib_globals then local tl_debug = TL_DEBUG TL_DEBUG = nil local program, syntax_errors = tl.parse(stdlib, "stdlib.d.tl") assert_no_stdlib_errors(syntax_errors, "syntax errors") - - local result = tl.type_check(program, { - filename = "@stdlib", - env = env - }) + local result = tl.type_check(program, "@stdlib", {}, env) assert_no_stdlib_errors(result.type_errors, "type errors") stdlib_globals = env.globals @@ -6520,21 +6816,20 @@ tl.init_env = function(lax?: boolean, gen_compat?: boolean | CompatMode, gen_tar -- special cases for compatibility local math_t = (stdlib_globals["math"].t as TypeDeclType).def as RecordType local table_t = (stdlib_globals["table"].t as TypeDeclType).def as RecordType - local integer_compat = a_type("integer", { needs_compat = true }) - math_t.fields["maxinteger"] = integer_compat - math_t.fields["mininteger"] = integer_compat + math_t.fields["maxinteger"].needs_compat = true + math_t.fields["mininteger"].needs_compat = true table_t.fields["unpack"].needs_compat = true -- only global scope and vararg functions accept `...`: -- `@is_va` is an internal sentinel value which is -- `any` if `...` is accepted in this scope or `nil` if it isn't. - stdlib_globals["..."] = { t = a_vararg { STRING } } - stdlib_globals["@is_va"] = { t = ANY } + stdlib_globals["..."] = { t = a_vararg(w, { a_type(w, "string", {}) }) } + stdlib_globals["@is_va"] = { t = a_type(w, "any", {}) } env.globals = {} end - local stdlib_compat = get_stdlib_compat(lax) + local stdlib_compat = get_stdlib_compat() for name, var in pairs(stdlib_globals) do env.globals[name] = var var.needs_compat = stdlib_compat[name] @@ -6545,52 +6840,53 @@ tl.init_env = function(lax?: boolean, gen_compat?: boolean | CompatMode, gen_tar end end - if predefined then - for _, name in ipairs(predefined) do - local module_type = require_module(name, lax, env) + if opts.predefined_modules then + for _, name in ipairs(opts.predefined_modules) do + local module_type = require_module(w, name, env.defaults.feat_lax == "on", env) - if module_type == INVALID then + if module_type is InvalidType then return nil, string.format("Error: could not predefine module '%s'", name) end end end - env.feat_arity = true - return env end -tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string - opts = opts or {} - local env = opts.env - if not env then - local err: string - env, err = tl.init_env(opts.lax, opts.gen_compat, opts.gen_target) - if err then - return nil, err - end - end +do + local type TypeRelations = {TypeName:{TypeName:CompareTypes}} + local type InvalidOrTupleType = InvalidType | TupleType - local lax = opts.lax - local feat_arity = env.feat_arity - local filename = opts.filename + local record TypeChecker + env: Env + st: {Scope} + + filename: string + errs: Errors + module_type: Type - local type Scope = {string:Variable} - local st: {Scope} = { env.globals } + subtype_relations: TypeRelations + eqtype_relations: TypeRelations + type_priorities: {TypeName:integer} + + all_needs_compat: {string:boolean} + dependencies: {string:string} + collector: TypeCollector + + gen_compat: GenCompat + gen_target: GenTarget + feat_arity: boolean + feat_lax: boolean - local all_needs_compat = {} + same_type: function(TypeChecker, Type, Type): boolean, {Error} + is_a: function(TypeChecker, Type, Type): boolean, {Error} - local dependencies: {string:string} = {} - local warnings: {Error} = {} - local errors: {Error} = {} + type_check_funcall: function(TypeChecker, node: Node, a: Type, b: TupleType, argdelta?: integer): InvalidOrTupleType - local module_type: Type + expand_type: function(TypeChecker, w: Where, old: Type, new: Type): Type - local tc: TypeCollector - if env.report_types then - env.reporter = env.reporter or tl.new_type_reporter() - tc = env.reporter:get_collector(filename or "?") + get_rets: function(TupleType): TupleType end local enum VarUse @@ -6600,10 +6896,10 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string "check_only" end - local function find_var(name: string, use?: VarUse): Variable, integer, Attribute - for i = #st, 1, -1 do - local scope = st[i] - local var = scope[name] + function TypeChecker:find_var(name: string, use?: VarUse): Variable, integer, Attribute + for i = #self.st, 1, -1 do + local scope = self.st[i] + local var = scope.vars[name] if var then if use == "lvalue" and var.is_narrowed then if var.narrowed_from then @@ -6612,7 +6908,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end else if i == 1 and var.needs_compat then - all_needs_compat[name] = true + self.all_needs_compat[name] = true end if use == "use_type" then var.used_as_type = true @@ -6625,10 +6921,10 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - local function simulate_g(): RecordType, Attribute + function TypeChecker:simulate_g(): RecordType, Attribute -- this is a static approximation of _G local globals: {string:Type} = {} - for k, v in pairs(st[1]) do + for k, v in pairs(self.st[1].vars) do if k:sub(1,1) ~= "@" then globals[k] = v.t end @@ -6641,101 +6937,61 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string }, nil end - local type ResolveType = function(Type): Type - local resolve_typevars: function (typ: Type, fn_var?: ResolveType, fn_arg?: ResolveType): boolean, Type, {Error} + local type ResolveType = function(S, Type): Type + local typevar_resolver: function(s: S, typ: Type, fn_var?: ResolveType, fn_arg?: ResolveType): boolean, Type, {Error} - local function fresh_typevar(t: TypeVarType): Type, Type, boolean - return a_type("typevar", { + local function fresh_typevar(_: nil, t: TypeVarType): Type, Type, boolean + return a_type(t, "typevar", { typevar = (t.typevar:gsub("@.*", "")) .. "@" .. fresh_typevar_ctr, constraint = t.constraint, } as TypeVarType) end - local function fresh_typearg(t: TypeArgType): Type - return a_type("typearg", { + local function fresh_typearg(_: nil, t: TypeArgType): Type + return a_type(t, "typearg", { typearg = (t.typearg:gsub("@.*", "")) .. "@" .. fresh_typevar_ctr, constraint = t.constraint, } as TypeArgType) end - local function ensure_fresh_typeargs(t: T): T + function TypeChecker:ensure_fresh_typeargs(t: T): T if not t is HasTypeArgs then return t end fresh_typevar_ctr = fresh_typevar_ctr + 1 local ok: boolean - ok, t = resolve_typevars(t, fresh_typevar, fresh_typearg) + ok, t = typevar_resolver(nil, t, fresh_typevar, fresh_typearg) assert(ok, "Internal Compiler Error: error creating fresh type variables") return t end - local function find_var_type(name: string, use?: VarUse): Type, Attribute, Type - local var = find_var(name, use) + function TypeChecker:find_var_type(name: string, use?: VarUse): Type, Attribute, Type + local var = self:find_var(name, use) if var then local t = var.t if t is UnresolvedTypeArgType then return nil, nil, t.constraint end - t = ensure_fresh_typeargs(t) + t = self:ensure_fresh_typeargs(t) return t, var.attribute end end - local function Err(where: Where, msg: string, ...: Type): Error - local n = select("#", ...) - if n > 0 then - local showt = {} - for i = 1, n do - local t = select(i, ...) - if t then - if t.typename == "invalid" then - return nil - end - showt[i] = show_type(t) - end - end - msg = msg:format(table.unpack(showt)) - end - local name = where.filename or filename - - if TL_DEBUG then - io.stderr:write("ERROR:" .. (where.y or -1) .. ":" .. (where.x or -1) .. ": " .. msg .. "\n") - end - - return { - y = where.y, - x = where.x, - msg = msg, - filename = name, - } - end - - local function error_at(w: Where, msg: string, ...:Type): boolean - assert(w.y) - - local e = Err(w, msg, ...) - if e then - table.insert(errors, e) - return true - else - return false - end - end - - local function ensure_not_abstract(where: Where, t: Type) + local function ensure_not_abstract(t: Type): boolean, string if t is FunctionType and t.macroexp then - error_at(where, "macroexps are abstract; consider using a concrete function") + return nil, "macroexps are abstract; consider using a concrete function" elseif t is TypeDeclType then local def = t.def if def is InterfaceType then - error_at(where, "interfaces are abstract; consider using a concrete record") + return nil, "interfaces are abstract; consider using a concrete record" end end + return true end - local function find_type(names: {string}, accept_typearg?: boolean): Type - local typ = find_var_type(names[1], "use_type") + function TypeChecker:find_type(names: {string}, accept_typearg?: boolean): Type + local typ = self:find_var_type(names[1], "use_type") if not typ then return nil end @@ -6757,7 +7013,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return nil end - typ = ensure_fresh_typeargs(typ) + typ = self:ensure_fresh_typeargs(typ) if typ is NominalType and typ.found then typ = typ.found end @@ -6769,19 +7025,19 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - local function union_type(t: Type): string, Type + local function type_for_union(t: Type): string, Type if t is TypeDeclType then - return union_type(t.def), t.def + return type_for_union(t.def), t.def elseif t is TypeAliasType then - return union_type(t.alias_to), t.alias_to + return type_for_union(t.alias_to), t.alias_to elseif t is TupleType then - return union_type(t.tuple[1]), t.tuple[1] + return type_for_union(t.tuple[1]), t.tuple[1] elseif t is NominalType then local typedecl = t.found if not typedecl then return "invalid" end - return union_type(typedecl) + return type_for_union(typedecl) elseif t is RecordLikeType then if t.is_userdata then return "userdata", t @@ -6805,7 +7061,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local n_string_enum = 0 local has_primitive_string_type = false for _, t in ipairs(typ.types) do - local ut, rt = union_type(t) + local ut, rt = type_for_union(t) if ut == "userdata" then -- must be tested before table_types assert(rt is RecordLikeType) if rt.meta_fields and rt.meta_fields["__is"] then @@ -6886,24 +7142,11 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string ["unknown"] = true, } - local function default_resolve_typevars_callback(t: TypeVarType): Type - local rt = find_var_type(t.typevar) - if not rt then - return nil - elseif rt is StringType then - -- tk is not propagated - return STRING - end - return rt - end - - resolve_typevars = function(typ: Type, fn_var?: ResolveType, fn_arg?: ResolveType): boolean, Type, {Error} + typevar_resolver = function(self: S, typ: Type, fn_var?: ResolveType, fn_arg?: ResolveType): boolean, Type, {Error} local errs: {Error} local seen: {Type:Type} = {} local resolved: {string:boolean} = {} - fn_var = fn_var or default_resolve_typevars_callback - local function resolve(t: T, all_same: boolean): T, boolean local same = true @@ -6918,7 +7161,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local orig_t = t if t is TypeVarType then - local rt = fn_var(t) + local rt = fn_var(self, t) if rt then resolved[t.typevar] = true if no_nested_types[rt.typename] or (rt is NominalType and not rt.typevals) then @@ -6934,7 +7177,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string seen[orig_t] = copy copy.typename = t.typename - copy.filename = t.filename + copy.f = t.f copy.x = t.x copy.y = t.y @@ -6945,7 +7188,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string -- inferred_len is not propagated elseif t is TypeArgType then if fn_arg then - copy = fn_arg(t) + copy = fn_arg(self, t) else assert(copy is TypeArgType) copy.typearg = t.typearg @@ -7038,7 +7281,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local _, err = is_valid_union(copy) if err then errs = errs or {} - table.insert(errs, Err(t, err, copy)) + table.insert(errs, Err(err, copy)) end elseif t is PolyType then assert(copy is PolyType) @@ -7048,6 +7291,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end elseif t is TupleTableType then assert(copy is TupleTableType) + copy.inferred_at = t.inferred_at copy.types = {} for i, tf in ipairs(t.types) do copy.types[i], same = resolve(tf, same) @@ -7067,7 +7311,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local copy, same = resolve(typ, true) if errs then - return false, INVALID, errs + return false, an_invalid(typ), errs end if (not same) and @@ -7086,153 +7330,81 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return true, copy end - local function infer_emptytable(emptytable: EmptyTableType, fresh_t: Type) + local function resolve_typevar(tc: TypeChecker, t: TypeVarType): Type + local rt = tc:find_var_type(t.typevar) + if not rt then + return nil + elseif rt is StringType then + -- tk is not propagated + return a_type(rt, "string", {}) + end + return rt + end + + + + function TypeChecker:infer_emptytable(emptytable: EmptyTableType, fresh_t: Type) local is_global = (emptytable.declared_at and emptytable.declared_at.kind == "global_declaration") - local nst = is_global and 1 or #st + local nst = is_global and 1 or #self.st for i = nst, 1, -1 do - local scope = st[i] - if scope[emptytable.assigned_to] then - scope[emptytable.assigned_to] = { t = fresh_t } + local scope = self.st[i] + if scope.vars[emptytable.assigned_to] then + scope.vars[emptytable.assigned_to] = { t = fresh_t } end end end local function resolve_tuple(t: Type): Type - if t is TupleType then - t = t.tuple[1] + local rt = t + if rt is TupleType then + rt = rt.tuple[1] end - if t == nil then - return NIL + if rt == nil then + return a_type(t, "nil", {}) end - return t - end - - local function add_warning(tag: tl.WarningKind, where: Where, fmt: string, ...: any) - table.insert(warnings, { - y = where.y, - x = where.x, - msg = fmt:format(...), - filename = where.filename or filename, - tag = tag, - }) - end - - local function invalid_at(where: Where, msg: string, ...:Type): InvalidType - error_at(where, msg, ...) - return INVALID - end - - local function add_unknown(node: Node, name: string) - add_warning("unknown", node, "unknown variable: %s", name) + return rt end - local function redeclaration_warning(node: Node, old_var?: Variable) - if node.tk:sub(1, 1) == "_" then return end - - local var_kind = "variable" - local var_name = node.tk - if node.kind == "local_function" or node.kind == "record_function" then - var_kind = "function" - var_name = node.name.tk - end - - local short_error = "redeclaration of " .. var_kind .. " '%s'" - if old_var and old_var.declared_at then - add_warning("redeclaration", node, short_error .. " (originally declared at %d:%d)", var_name, old_var.declared_at.y, old_var.declared_at.x) - else - add_warning("redeclaration", node, short_error, var_name) - end - end - local function check_if_redeclaration(new_name: string, at: Node) - local old = find_var(new_name, "check_only") + function TypeChecker:check_if_redeclaration(new_name: string, at: Node) + local old = self:find_var(new_name, "check_only") if old then - redeclaration_warning(at, old) + self.errs:redeclaration_warning(at, old) end end - local function unused_warning(name: string, var: Variable) - local prefix = name:sub(1,1) - if var.declared_at - and var.is_narrowed ~= "narrow" - and prefix ~= "_" - and prefix ~= "@" - then - if name:sub(1, 2) == "::" then - add_warning("unused", var.declared_at, "unused label %s", name) - else - local t = var.t - add_warning( - "unused", - var.declared_at, - "unused %s %s: %s", - var.is_func_arg and "argument" - or t is FunctionType and "function" - or t is TypeDeclType and "type" - or t is TypeAliasType and "type" - or "variable", - name, - show_type(var.t) - ) - end - end - end - - local function add_errs_prefixing(where: Where, src: {Error}, dst: {Error}, prefix: string) - assert(where == nil or where.y ~= nil) - - if not src then - return - end - for _, err in ipairs(src) do - err.msg = prefix .. err.msg - - if where and ( - (err.filename ~= filename) - or (not err.y) - or (where.y > err.y or (where.y == err.y and where.x > err.x)) - ) then - err.y = where.y - err.x = where.x - err.filename = filename - end - - table.insert(dst, err) - end - end local function type_at(w: Where, t: T): T t.x = w.x t.y = w.y - t.filename = filename return t end - local function resolve_typevars_at(where: Where, t: T): T - assert(where) - local ok, ret, errs = resolve_typevars(t) + function TypeChecker:resolve_typevars_at(w: Where, t: T): T + assert(w) + local ok, ret, errs = typevar_resolver(self, t, resolve_typevar) if not ok then - assert(where.y) - add_errs_prefixing(where, errs, errors, "") + assert(w.y) + self.errs:add_prefixing(w, errs, "") end - if ret == t or t.typename == "typevar" then + if ret == t or t is TypeVarType then ret = shallow_copy_table(ret) end - return type_at(where, ret) + return type_at(w, ret) end - local function infer_at(where: Where, t: T): T - local ret = resolve_typevars_at(where, t) - if ret.typename == "invalid" then + function TypeChecker:infer_at(w: Where, t: T): T + local ret = self:resolve_typevars_at(w, t) + if ret is InvalidType then ret = t -- errors are produced by resolve_typevars_at end - if ret == t or t.typename == "typevar" then + if ret == t or t is TypeVarType then ret = shallow_copy_table(ret) end - ret.inferred_at = where - ret.inferred_at.filename = filename + assert(w.f) + ret.inferred_at = w return ret end @@ -7245,12 +7417,9 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return t end - local get_unresolved: function(scope?: Scope): UnresolvedType - local find_unresolved: function(level?: integer): UnresolvedType - - local function add_to_scope(node: Node, name: string, t: Type, attribute: Attribute, narrow: Narrow, dont_check_redeclaration: boolean): Variable - local scope = st[#st] - local var = scope[name] + function TypeChecker:add_to_scope(node: Node, name: string, t: Type, attribute: Attribute, narrow: Narrow, dont_check_redeclaration: boolean): Variable + local scope = self.st[#self.st] + local var = scope.vars[name] if narrow then if var then if var.is_narrowed then @@ -7263,11 +7432,11 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string var.t = t else var = { t = t, attribute = attribute, is_narrowed = narrow, declared_at = node } - scope[name] = var + scope.vars[name] = var end - local unresolved = get_unresolved(scope) - unresolved.narrows[name] = true + scope.narrows = scope.narrows or {} + scope.narrows[name] = true return var end @@ -7278,46 +7447,39 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string and name ~= "..." and name:sub(1, 1) ~= "@" then - check_if_redeclaration(name, node) + self:check_if_redeclaration(name, node) end if var and not var.used then -- the old var is removed from the scope and won't be checked when it closes, -- so check it here - unused_warning(name, var) + self.errs:unused_warning(name, var) end var = { t = t, attribute = attribute, is_narrowed = nil, declared_at = node } - scope[name] = var + scope.vars[name] = var return var end - local function add_var(node: Node, name: string, t: Type, attribute?: Attribute, narrow?: Narrow, dont_check_redeclaration?: boolean): Variable - if lax and node and is_unknown(t) and (name ~= "self" and name ~= "...") and not narrow then - add_unknown(node, name) + function TypeChecker:add_var(node: Node, name: string, t: Type, attribute?: Attribute, narrow?: Narrow, dont_check_redeclaration?: boolean): Variable + if self.feat_lax and node and is_unknown(t) and (name ~= "self" and name ~= "...") and not narrow then + self.errs:add_unknown(node, name) end if not attribute then t = drop_constant_value(t) end - local var = add_to_scope(node, name, t, attribute, narrow, dont_check_redeclaration) - - if t is UnresolvedType or t.typename == "none" then - return var - end + local var = self:add_to_scope(node, name, t, attribute, narrow, dont_check_redeclaration) - if tc and node then - tc.add_to_symbol_list(node, name, t) + if self.collector and node then + self.collector.add_to_symbol_list(node, name, t) end return var end - local type CompareTypes = function(Type, Type): boolean, {Error} - - local same_type: function(t1: Type, t2: Type): boolean, {Error} - local is_a: function(Type, Type): boolean, {Error} + local type CompareTypes = function(TypeChecker, Type, Type): boolean, {Error} local enum ArgCheckMode "argument" @@ -7332,38 +7494,38 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string "invariant" end - local function arg_check(where: Where, all_errs: {Error}, a: Type, b: Type, v: VarianceMode, mode: ArgCheckMode, n?: integer): boolean + function TypeChecker:arg_check(w: Where, all_errs: {Error}, a: Type, b: Type, v: VarianceMode, mode: ArgCheckMode, n?: integer): boolean local ok, errs: boolean, {Error} if v == "covariant" then - ok, errs = is_a(a, b) + ok, errs = self:is_a(a, b) elseif v == "contravariant" then - ok, errs = is_a(b, a) + ok, errs = self:is_a(b, a) elseif v == "bivariant" then - ok, errs = is_a(a, b) + ok, errs = self:is_a(a, b) if ok then return true end - ok = is_a(b, a) + ok = self:is_a(b, a) if ok then return true end elseif v == "invariant" then - ok, errs = same_type(a, b) + ok, errs = self:same_type(a, b) end if not ok then - add_errs_prefixing(where, errs, all_errs, mode .. (n and " " .. n or "") .. ": ") + self.errs:add_prefixing(w, errs, mode .. (n and " " .. n or "") .. ": ", all_errs) return false end return true end - local function has_all_types_of(t1s: {Type}, t2s: {Type}): boolean + function TypeChecker:has_all_types_of(t1s: {Type}, t2s: {Type}): boolean for _, t1 in ipairs(t1s) do local found = false for _, t2 in ipairs(t2s) do - if same_type(t2, t1) then + if self:same_type(t2, t1) then found = true break end @@ -7395,8 +7557,8 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - local function close_types(vars: {string:Variable}) - for _, var in pairs(vars) do + local function close_types(scope: Scope) + for _, var in pairs(scope.vars) do local t = var.t if t is TypeDeclType then t.closed = true @@ -7408,161 +7570,96 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - local record Unused - y: integer - x: integer - name: string - var: Variable - end - - local function check_for_unused_vars(vars: {string:Variable}, is_global?: boolean) - if not next(vars) then - return - end - local list: {Unused} = {} - for name, var in pairs(vars) do - local t = var.t - if var.declared_at and not var.used then - if var.used_as_type then - var.declared_at.elide_type = true - else - if (t is TypeDeclType or t is TypeAliasType) and not is_global then - var.declared_at.elide_type = true - end - table.insert(list, { y = var.declared_at.y, x = var.declared_at.x, name = name, var = var }) - end - elseif var.used and (t is TypeDeclType or t is TypeAliasType) and var.aliasing then - var.aliasing.used = true - var.aliasing.declared_at.elide_type = false - end - end - if list[1] then - table.sort(list, function(a: Unused, b: Unused): boolean - return a.y < b.y or (a.y == b.y and a.x < b.x) - end) - for _, u in ipairs(list) do - unused_warning(u.name, u.var) - end - end - end - - get_unresolved = function(scope?: Scope): UnresolvedType - local unresolved: UnresolvedType - if scope then - local unr = scope["@unresolved"] - unresolved = unr and unr.t as UnresolvedType - else - unresolved = find_var_type("@unresolved") as UnresolvedType - end - if not unresolved then - unresolved = a_type("unresolved", { - labels = {}, - nominals = {}, - global_types = {}, - narrows = {}, - } as UnresolvedType) - add_var(nil, "@unresolved", unresolved) - end - return unresolved - end - - find_unresolved = function(level?: integer): UnresolvedType - local u = st[level or #st]["@unresolved"] - if u then - return u.t as UnresolvedType - end - end - - local function begin_scope(node?: Node) - table.insert(st, {}) + function TypeChecker:begin_scope(node?: Node) + table.insert(self.st, { vars = {} }) - if tc and node then - tc.begin_symbol_list_scope(node) + if self.collector and node then + self.collector.begin_symbol_list_scope(node) end end - local function end_scope(node?: Node) + function TypeChecker:end_scope(node?: Node) + local st = self.st local scope = st[#st] - local unresolved = scope["@unresolved"] - if unresolved then - local unrt = unresolved.t as UnresolvedType - local next_scope = st[#st - 1] - local upper = next_scope["@unresolved"] - if upper then - local uppert = upper.t as UnresolvedType - for name, nodes in pairs(unrt.labels) do + local next_scope = st[#st - 1] + + if next_scope then + if scope.pending_labels then + next_scope.pending_labels = next_scope.pending_labels or {} + for name, nodes in pairs(scope.pending_labels) do for _, n in ipairs(nodes) do - uppert.labels[name] = uppert.labels[name] or {} - table.insert(uppert.labels[name], n) + next_scope.pending_labels[name] = next_scope.pending_labels[name] or {} + table.insert(next_scope.pending_labels[name], n) end end - for name, types in pairs(unrt.nominals) do + scope.pending_labels = nil + end + if scope.pending_nominals then + next_scope.pending_nominals = next_scope.pending_nominals or {} + for name, types in pairs(scope.pending_nominals) do for _, typ in ipairs(types) do - uppert.nominals[name] = uppert.nominals[name] or {} - table.insert(uppert.nominals[name], typ) + next_scope.pending_nominals[name] = next_scope.pending_nominals[name] or {} + table.insert(next_scope.pending_nominals[name], typ) end end - for name, _ in pairs(unrt.global_types) do - uppert.global_types[name] = true - end - else - next_scope["@unresolved"] = unresolved - unrt.narrows = {} + scope.pending_nominals = nil end end + close_types(scope) - check_for_unused_vars(scope) + self.errs:warn_unused_vars(scope) + table.remove(st) - if tc and node then - tc.end_symbol_list_scope(node) + if self.collector and node then + self.collector.end_symbol_list_scope(node) end end - local end_scope_and_none_type = function(node: Node, _children: {Type}): Type - end_scope(node) + -- This type must never be used for any values + local NONE = a_type({ f = "@none", x = -1, y = -1 }, "none", {}) + + local function end_scope_and_none_type(self: TypeChecker, node: Node, _children: {Type}): Type + self:end_scope(node) return NONE end local type InvalidOrTypeDeclType = InvalidType | TypeDeclType - local resolve_nominal: function(t: NominalType): Type - local resolve_typealias: function(t: TypeAliasType): InvalidOrTypeDeclType do - local function match_typevals(t: NominalType, def: RecordLikeType | FunctionType): Type + local function match_typevals(self: TypeChecker, t: NominalType, def: RecordLikeType | FunctionType): Type if t.typevals and def.typeargs then if #t.typevals ~= #def.typeargs then - error_at(t, "mismatch in number of type arguments") + self.errs:add(t, "mismatch in number of type arguments") return nil end - begin_scope() + self:begin_scope() for i, tt in ipairs(t.typevals) do - add_var(nil, def.typeargs[i].typearg, tt) + self:add_var(nil, def.typeargs[i].typearg, tt) end - local ret = resolve_typevars_at(t, def) - end_scope() + local ret = self:resolve_typevars_at(t, def) + self:end_scope() return ret elseif t.typevals then - error_at(t, "spurious type arguments") + self.errs:add(t, "spurious type arguments") return nil elseif def.typeargs then - error_at(t, "missing type arguments in %s", def) + self.errs:add(t, "missing type arguments in %s", def) return nil else return def end end - local function find_nominal_type_decl(t: NominalType): Type, TypeDeclType + local function find_nominal_type_decl(self: TypeChecker, t: NominalType): Type, TypeDeclType if t.resolved then return t.resolved end - local found = t.found or find_type(t.names) + local found = t.found or self:find_type(t.names) if not found then - error_at(t, "unknown type %s", t) - return INVALID + return self.errs:invalid_at(t, "unknown type %s", t) end if found is TypeAliasType then @@ -7570,8 +7667,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end if not found is TypeDeclType then - error_at(t, table.concat(t.names, ".") .. " is not a type") - return INVALID + return self.errs:invalid_at(t, table.concat(t.names, ".") .. " is not a type") end local def = found.def @@ -7586,44 +7682,35 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return nil, found end - local function resolve_decl_into_nominal(t: NominalType, found: TypeDeclType): Type + local function resolve_decl_into_nominal(self: TypeChecker, t: NominalType, found: TypeDeclType): Type local def = found.def local resolved: Type if def is RecordType or def is FunctionType then - resolved = match_typevals(t, def) + resolved = match_typevals(self, t, def) if not resolved then - error_at(t, table.concat(t.names, ".") .. " cannot be resolved in scope") - return INVALID + return self.errs:invalid_at(t, table.concat(t.names, ".") .. " cannot be resolved in scope") end else resolved = def end - if not t.filename then - t.filename = resolved.filename - if t.x == nil and t.y == nil then - t.x = resolved.x - t.y = resolved.y - end - end - t.resolved = resolved return resolved end - resolve_nominal = function(t: NominalType): Type - local immediate, found = find_nominal_type_decl(t) + function TypeChecker:resolve_nominal(t: NominalType): Type + local immediate, found = find_nominal_type_decl(self, t) if immediate then return immediate end - return resolve_decl_into_nominal(t, found) + return resolve_decl_into_nominal(self, t, found) end - resolve_typealias = function(typealias: TypeAliasType): InvalidOrTypeDeclType + function TypeChecker:resolve_typealias(typealias: TypeAliasType): InvalidOrTypeDeclType local t = typealias.alias_to - local immediate, found = find_nominal_type_decl(t) + local immediate, found = find_nominal_type_decl(self, t) if immediate then return immediate end @@ -7632,90 +7719,92 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return found end - local resolved = resolve_decl_into_nominal(t, found) + local resolved = resolve_decl_into_nominal(self, t, found) - local typedecl = a_type("typedecl", { def = resolved } as TypeDeclType) + local typedecl = a_type(typealias, "typedecl", { def = resolved } as TypeDeclType) t.resolved = typedecl return typedecl end end - local function are_same_unresolved_global_type(t1: NominalType, t2: NominalType): boolean - if t1.names[1] == t2.names[1] then - local unresolved = get_unresolved() - if unresolved.global_types[t1.names[1]] then - return true + do + local function are_same_unresolved_global_type(self: TypeChecker, t1: NominalType, t2: NominalType): boolean + if t1.names[1] == t2.names[1] then + local global_scope = self.st[1] + if global_scope.pending_global_types[t1.names[1]] then + return true + end end + return false end - return false - end - local function fail_nominals(t1: NominalType, t2: NominalType): boolean, {Error} - local t1name = show_type(t1) - local t2name = show_type(t2) - if t1name == t2name then - local t1r = resolve_nominal(t1) - if t1r.filename then - t1name = t1name .. " (defined in " .. t1r.filename .. ":" .. t1r.y .. ")" - end - local t2r = resolve_nominal(t2) - if t2r.filename then - t2name = t2name .. " (defined in " .. t2r.filename .. ":" .. t2r.y .. ")" + local function fail_nominals(self: TypeChecker, t1: NominalType, t2: NominalType): boolean, {Error} + local t1name = show_type(t1) + local t2name = show_type(t2) + if t1name == t2name then + self:resolve_nominal(t1) + if t1.found then + t1name = t1name .. " (defined in " .. t1.found.f .. ":" .. t1.found.y .. ")" + end + self:resolve_nominal(t2) + if t2.found then + t2name = t2name .. " (defined in " .. t2.found.f .. ":" .. t2.found.y .. ")" + end end + return false, { Err(t1name .. " is not a " .. t2name) } end - return false, { Err(t1, t1name .. " is not a " .. t2name) } - end - local function are_same_nominals(t1: NominalType, t2: NominalType): boolean, {Error} - local same_names: boolean - if t1.found and t2.found then - same_names = t1.found.typeid == t2.found.typeid - else - local ft1 = t1.found or find_type(t1.names) - local ft2 = t2.found or find_type(t2.names) - if ft1 and ft2 then - same_names = ft1.typeid == ft2.typeid + function TypeChecker:are_same_nominals(t1: NominalType, t2: NominalType): boolean, {Error} + local same_names: boolean + if t1.found and t2.found then + same_names = t1.found.typeid == t2.found.typeid else - if are_same_unresolved_global_type(t1, t2) then - return true - end + local ft1 = t1.found or self:find_type(t1.names) + local ft2 = t2.found or self:find_type(t2.names) + if ft1 and ft2 then + same_names = ft1.typeid == ft2.typeid + else + if are_same_unresolved_global_type(self, t1, t2) then + return true + end - if not ft1 then - error_at(t1, "unknown type %s", t1) - end - if not ft2 then - error_at(t2, "unknown type %s", t2) + if not ft1 then + self.errs:add(t1, "unknown type %s", t1) + end + if not ft2 then + self.errs:add(t2, "unknown type %s", t2) + end + return false, {} -- errors were already produced end - return false, {} -- errors were already produced end - end - if not same_names then - return fail_nominals(t1, t2) - elseif t1.typevals == nil and t2.typevals == nil then - return true - elseif t1.typevals and t2.typevals and #t1.typevals == #t2.typevals then - local errs = {} - for i = 1, #t1.typevals do - local _, typeval_errs = same_type(t1.typevals[i], t2.typevals[i]) - add_errs_prefixing(t1, typeval_errs, errs, "type parameter <" .. show_type(t2.typevals[i]) .. ">: ") + if not same_names then + return fail_nominals(self, t1, t2) + elseif t1.typevals == nil and t2.typevals == nil then + return true + elseif t1.typevals and t2.typevals and #t1.typevals == #t2.typevals then + local errs = {} + for i = 1, #t1.typevals do + local _, typeval_errs = self:same_type(t1.typevals[i], t2.typevals[i]) + self.errs:add_prefixing(nil, typeval_errs, "type parameter <" .. show_type(t2.typevals[i]) .. ">: ", errs) + end + return any_errors(errs) end - return any_errors(errs) + return true end - return true end local is_lua_table_type: function(t: Type): boolean - local function to_structural(t: Type): Type + function TypeChecker:to_structural(t: Type): Type assert(not t is TupleType) if t is NominalType then - return resolve_nominal(t) + return self:resolve_nominal(t) end return t end - local function unite(types: {Type}, flatten_constants?: boolean): Type + local function unite(w: Where, types: {Type}, flatten_constants?: boolean): Type if #types == 1 then return types[1] end @@ -7726,7 +7815,6 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string -- Make things like number | number resolve to number local types_seen: {(integer|string):boolean} = {} -- but never add nil as a type in the union - types_seen[NIL.typeid] = true types_seen["nil"] = true local i = 1 @@ -7762,14 +7850,14 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - if types_seen[INVALID.typeid] then - return INVALID + if types_seen["invalid"] then + return a_type(w, "invalid", {}) end if #ts == 1 then return ts[1] else - return a_union(ts) + return a_union(w, ts) end end @@ -7789,21 +7877,20 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - local expand_type: function(where: Where, old: Type, new: Type): Type - local function arraytype_from_tuple(where: Where, tupletype: TupleTableType): ArrayType, {Error} + function TypeChecker:arraytype_from_tuple(w: Where, tupletype: TupleTableType): ArrayType, {Error} -- first just try a basic union - local element_type = unite(tupletype.types, true) + local element_type = unite(w, tupletype.types, true) local valid = (not element_type is UnionType) and true or is_valid_union(element_type) if valid then - return an_array(element_type) + return an_array(w, element_type) end -- failing a basic union, expand the types - local arr_type = an_array(tupletype.types[1]) + local arr_type = an_array(w, tupletype.types[1]) for i = 2, #tupletype.types do - local expanded = expand_type(where, arr_type, an_array(tupletype.types[i])) + local expanded = self:expand_type(w, arr_type, an_array(w, tupletype.types[i])) if not expanded is ArrayType then - return nil, { Err(tupletype, "unable to convert tuple %s to array", tupletype) } + return nil, { Err("unable to convert tuple %s to array", tupletype) } end arr_type = expanded end @@ -7814,33 +7901,33 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return t is NominalType and t.names[1] == "@self" end - local function compare_true(_: Type, _: Type): boolean, {Error} + local function compare_true(_: TypeChecker, _: Type, _: Type): boolean, {Error} return true end - local function subtype_nominal(a: Type, b: Type): boolean, {Error} + function TypeChecker:subtype_nominal(a: Type, b: Type): boolean, {Error} if is_self(a) and is_self(b) then return true end - local ra = a is NominalType and resolve_nominal(a) or a - local rb = b is NominalType and resolve_nominal(b) or b - local ok, errs = is_a(ra, rb) + local ra = a is NominalType and self:resolve_nominal(a) or a + local rb = b is NominalType and self:resolve_nominal(b) or b + local ok, errs = self:is_a(ra, rb) if errs and #errs == 1 and errs[1].msg:match("^got ") then return false -- translate to got-expected error with unresolved types end return ok, errs end - local function subtype_array(a: ArrayLikeType, b: ArrayLikeType): boolean, {Error} - if (not a.elements) or (not is_a(a.elements, b.elements)) then + function TypeChecker:subtype_array(a: ArrayLikeType, b: ArrayLikeType): boolean, {Error} + if (not a.elements) or (not self:is_a(a.elements, b.elements)) then return false end if a.consttypes and #a.consttypes > 1 then -- constant array, check elements (useful for array of enums) for _, e in ipairs(a.consttypes) do - if not is_a(e, b.elements) then - return false, { Err(a, "%s is not a member of %s", e, b.elements) } + if not self:is_a(e, b.elements) then + return false, { Err("%s is not a member of %s", e, b.elements) } end end end @@ -7862,16 +7949,16 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return nil end - local function subtype_record(a: RecordLikeType, b: RecordLikeType): boolean, {Error} + function TypeChecker:subtype_record(a: RecordLikeType, b: RecordLikeType): boolean, {Error} -- assert(b.typename == "record") if a.elements and b.elements then - if not is_a(a.elements, b.elements) then - return false, { Err(a, "array parts have incompatible element types") } + if not self:is_a(a.elements, b.elements) then + return false, { Err("array parts have incompatible element types") } end end if a.is_userdata ~= b.is_userdata then - return false, { Err(a, a.is_userdata and "userdata is not a record" + return false, { Err(a.is_userdata and "userdata is not a record" or "record is not a userdata") } end @@ -7880,9 +7967,9 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local ak = a.fields[k] local bk = b.fields[k] if bk then - local ok, fielderrs = is_a(ak, bk) + local ok, fielderrs = self:is_a(ak, bk) if not ok then - add_errs_prefixing(nil, fielderrs, errs, "record field doesn't match: " .. k .. ": ") + self.errs:add_prefixing(nil, fielderrs, "record field doesn't match: " .. k .. ": ", errs) end end end @@ -7896,32 +7983,32 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return true end - local eqtype_record = function(a: RecordType, b: RecordType): boolean, {Error} + function TypeChecker:eqtype_record(a: RecordType, b: RecordType): boolean, {Error} -- checking array interface if (a.elements ~= nil) ~= (b.elements ~= nil) then - return false, { Err(a, "types do not have the same array interface") } + return false, { Err("types do not have the same array interface") } end if a.elements then - local ok, errs = same_type(a.elements, b.elements) + local ok, errs = self:same_type(a.elements, b.elements) if not ok then return ok, errs end end - local ok, errs = subtype_record(a, b) + local ok, errs = self:subtype_record(a, b) if not ok then return ok, errs end - ok, errs = subtype_record(b, a) + ok, errs = self:subtype_record(b, a) if not ok then return ok, errs end return true end - local function compare_map(ak: Type, bk: Type, av: Type, bv: Type, no_hack?: boolean): boolean, {Error} - local ok1, errs_k = same_type(ak, bk) - local ok2, errs_v = same_type(av, bv) + local function compare_map(self: TypeChecker, ak: Type, bk: Type, av: Type, bv: Type, no_hack?: boolean): boolean, {Error} + local ok1, errs_k = self:same_type(ak, bk) + local ok2, errs_v = self:same_type(av, bv) -- FIXME hack for {any:any} if bk.typename == "any" and not no_hack then @@ -7951,25 +8038,25 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return false, errs_k or errs_v end - local function compare_or_infer_typevar(typevar: string, a: Type, b: Type, cmp: CompareTypes): boolean, {Error} + function TypeChecker:compare_or_infer_typevar(typevar: string, a: Type, b: Type, cmp: CompareTypes): boolean, {Error} -- assert((a == nil and b ~= nil) or (a ~= nil and b == nil)) -- does the typevar currently match to a type? - local vt, _, constraint = find_var_type(typevar) + local vt, _, constraint = self:find_var_type(typevar) if vt then -- If so, compare it to the other type - return cmp(a or vt, b or vt) + return cmp(self, a or vt, b or vt) else -- otherwise, infer it to the other type local other = a or b -- but check interface constraint first if present if constraint then - if not is_a(other, constraint) then - return false, { Err(other, "given type %s does not satisfy %s constraint in type variable " .. display_typevar(typevar), other, constraint) } + if not self:is_a(other, constraint) then + return false, { Err("given type %s does not satisfy %s constraint in type variable " .. display_typevar(typevar), other, constraint) } end - if same_type(other, constraint) then + if self:same_type(other, constraint) then -- do not infer to some type as constraint right away, -- to give a chance to more specific inferences -- in other arguments/returns @@ -7977,22 +8064,22 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - local ok, r, errs = resolve_typevars(other) + local ok, r, errs = typevar_resolver(self, other, resolve_typevar) if not ok then return false, errs end if r is TypeVarType and r.typevar == typevar then return true end - add_var(nil, typevar, r) + self:add_var(nil, typevar, r) return true end end -- ∃ x ∈ xs. t <: x - local function exists_supertype_in(t: Type, xs: AggregateType): Type + function TypeChecker:exists_supertype_in(t: Type, xs: AggregateType): Type for _, x in ipairs(xs.types) do - if is_a(t, x) then + if self:is_a(t, x) then return x end end @@ -8003,143 +8090,139 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string ["array"] = compare_true, ["map"] = compare_true, ["tupletable"] = compare_true, - ["interface"] = function(_a: Type, b: InterfaceType): boolean, {Error} + ["interface"] = function(_self: TypeChecker, _a: Type, b: InterfaceType): boolean, {Error} return not b.is_userdata end, - ["record"] = function(_a: Type, b: RecordType): boolean, {Error} + ["record"] = function(_self: TypeChecker, _a: Type, b: RecordType): boolean, {Error} return not b.is_userdata end, } - local type TypeRelations = {TypeName:{TypeName:CompareTypes}} - - local eqtype_relations: TypeRelations - eqtype_relations = { + TypeChecker.eqtype_relations = { ["typevar"] = { - ["typevar"] = function(a: TypeVarType, b: TypeVarType): boolean, {Error} + ["typevar"] = function(self: TypeChecker, a: TypeVarType, b: TypeVarType): boolean, {Error} if a.typevar == b.typevar then return true end - return compare_or_infer_typevar(b.typevar, a, nil, same_type) + return self:compare_or_infer_typevar(b.typevar, a, nil, self.same_type) end, - ["*"] = function(a: TypeVarType, b: Type): boolean, {Error} - return compare_or_infer_typevar(a.typevar, nil, b, same_type) + ["*"] = function(self: TypeChecker, a: TypeVarType, b: Type): boolean, {Error} + return self:compare_or_infer_typevar(a.typevar, nil, b, self.same_type) end, }, ["emptytable"] = emptytable_relations, ["tupletable"] = { - ["tupletable"] = function(a: TupleTableType, b: TupleTableType): boolean, {Error} + ["tupletable"] = function(self: TypeChecker, a: TupleTableType, b: TupleTableType): boolean, {Error} for i = 1, math.min(#a.types, #b.types) do - if not same_type(a.types[i], b.types[i]) then - return false, { Err(a, "in tuple entry " .. tostring(i) .. ": got %s, expected %s", a.types[i], b.types[i]) } + if not self:same_type(a.types[i], b.types[i]) then + return false, { Err("in tuple entry " .. tostring(i) .. ": got %s, expected %s", a.types[i], b.types[i]) } end end if #a.types ~= #b.types then - return false, { Err(a, "tuples have different size", a, b) } + return false, { Err("tuples have different size", a, b) } end return true end, }, ["array"] = { - ["array"] = function(a: ArrayType, b: ArrayType): boolean, {Error} - return same_type(a.elements, b.elements) + ["array"] = function(self: TypeChecker, a: ArrayType, b: ArrayType): boolean, {Error} + return self:same_type(a.elements, b.elements) end, }, ["map"] = { - ["map"] = function(a: MapType, b: MapType): boolean, {Error} - return compare_map(a.keys, b.keys, a.values, b.values, true) + ["map"] = function(self: TypeChecker, a: MapType, b: MapType): boolean, {Error} + return compare_map(self, a.keys, b.keys, a.values, b.values, true) end, }, ["union"] = { - ["union"] = function(a: UnionType, b: UnionType): boolean, {Error} - return (has_all_types_of(a.types, b.types) - and has_all_types_of(b.types, a.types)) + ["union"] = function(self: TypeChecker, a: UnionType, b: UnionType): boolean, {Error} + return (self:has_all_types_of(a.types, b.types) + and self:has_all_types_of(b.types, a.types)) end, }, ["nominal"] = { - ["nominal"] = are_same_nominals, + ["nominal"] = TypeChecker.are_same_nominals, }, ["record"] = { - ["record"] = eqtype_record, + ["record"] = TypeChecker.eqtype_record, }, ["interface"] = { - ["interface"] = function(a: InterfaceType, b: InterfaceType): boolean, {Error} + ["interface"] = function(_self:TypeChecker, a: InterfaceType, b: InterfaceType): boolean, {Error} return a.typeid == b.typeid end, }, ["function"] = { - ["function"] = function(a: FunctionType, b: FunctionType): boolean, {Error} + ["function"] = function(self:TypeChecker, a: FunctionType, b: FunctionType): boolean, {Error} local argdelta = a.is_method and 1 or 0 local naargs, nbargs = #a.args.tuple, #b.args.tuple if naargs ~= nbargs then if (not not a.is_method) ~= (not not b.is_method) then - return false, { Err(a, "different number of input arguments: method and non-method are not the same type") } + return false, { Err("different number of input arguments: method and non-method are not the same type") } end - return false, { Err(a, "different number of input arguments: got " .. naargs - argdelta .. ", expected " .. nbargs - argdelta) } + return false, { Err("different number of input arguments: got " .. naargs - argdelta .. ", expected " .. nbargs - argdelta) } end local narets, nbrets = #a.rets.tuple, #b.rets.tuple if narets ~= nbrets then - return false, { Err(a, "different number of return values: got " .. narets .. ", expected " .. nbrets) } + return false, { Err("different number of return values: got " .. narets .. ", expected " .. nbrets) } end local errs = {} for i = 1, naargs do - arg_check(a, errs, a.args.tuple[i], b.args.tuple[i], "invariant", "argument", i - argdelta) + self:arg_check(a, errs, a.args.tuple[i], b.args.tuple[i], "invariant", "argument", i - argdelta) end for i = 1, narets do - arg_check(a, errs, a.rets.tuple[i], b.rets.tuple[i], "invariant", "return", i) + self:arg_check(a, errs, a.rets.tuple[i], b.rets.tuple[i], "invariant", "return", i) end return any_errors(errs) end, }, ["*"] = { - ["typevar"] = function(a: Type, b: TypeVarType): boolean, {Error} - return compare_or_infer_typevar(b.typevar, a, nil, same_type) + ["typevar"] = function(self: TypeChecker, a: Type, b: TypeVarType): boolean, {Error} + return self:compare_or_infer_typevar(b.typevar, a, nil, self.same_type) end, }, } - local subtype_relations: TypeRelations - subtype_relations = { + TypeChecker.subtype_relations = { ["tuple"] = { - ["tuple"] = function(a: TupleType, b: TupleType): boolean, {Error} -- ∀ a[i] ∈ a, b[i] ∈ b. a[i] <: b[i] + ["tuple"] = function(self: TypeChecker, a: TupleType, b: TupleType): boolean, {Error} -- ∀ a[i] ∈ a, b[i] ∈ b. a[i] <: b[i] local at, bt = a.tuple, b.tuple -- ────────────────────────────────── if #at ~= #bt then -- a tuple <: b tuple return false end for i = 1, #at do - if not is_a(at[i], bt[i]) then + if not self:is_a(at[i], bt[i]) then return false end end return true end, - ["*"] = function(a: Type, b: Type): boolean, {Error} - return is_a(resolve_tuple(a), b) + ["*"] = function(self: TypeChecker, a: Type, b: Type): boolean, {Error} + return self:is_a(resolve_tuple(a), b) end, }, ["typevar"] = { - ["typevar"] = function(a: TypeVarType, b: TypeVarType): boolean, {Error} + ["typevar"] = function(self: TypeChecker, a: TypeVarType, b: TypeVarType): boolean, {Error} if a.typevar == b.typevar then return true end - return compare_or_infer_typevar(b.typevar, a, nil, is_a) + return self:compare_or_infer_typevar(b.typevar, a, nil, self.is_a) end, - ["*"] = function(a: TypeVarType, b: Type): boolean, {Error} - return compare_or_infer_typevar(a.typevar, nil, b, is_a) + ["*"] = function(self: TypeChecker, a: TypeVarType, b: Type): boolean, {Error} + return self:compare_or_infer_typevar(a.typevar, nil, b, self.is_a) end, }, ["nil"] = { ["*"] = compare_true, }, ["union"] = { - ["union"] = function(a: UnionType, b: UnionType): boolean, {Error} -- ∀ t ∈ a. ∃ u ∈ b. t <: u + ["union"] = function(self: TypeChecker, a: UnionType, b: UnionType): boolean, {Error} -- ∀ t ∈ a. ∃ u ∈ b. t <: u local used = {} -- ──────────────────────── for _, t in ipairs(a.types) do -- a union <: b union - begin_scope() - local u = exists_supertype_in(t, b) - end_scope() -- don't preserve failed inferences + self:begin_scope() + local u = self:exists_supertype_in(t, b) + self:end_scope() -- don't preserve failed inferences if not u then return false end @@ -8148,13 +8231,13 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end for u, t in pairs(used) do - is_a(t, u) -- preserve valid inferences + self:is_a(t, u) -- preserve valid inferences end return true end, - ["*"] = function(a: UnionType, b: Type): boolean, {Error} -- ∀ t ∈ a, t <: b - for _, t in ipairs(a.types) do -- ──────────────── - if not is_a(t, b) then -- a union <: b + ["*"] = function(self: TypeChecker, a: UnionType, b: Type): boolean, {Error} -- ∀ t ∈ a, t <: b + for _, t in ipairs(a.types) do -- ──────────────── + if not self:is_a(t, b) then -- a union <: b return false end end @@ -8162,212 +8245,212 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end, }, ["poly"] = { - ["*"] = function(a: PolyType, b: Type): boolean, {Error} -- ∃ t ∈ a, t <: b - if exists_supertype_in(b, a) then -- ─────────────── - return true -- a poly <: b + ["*"] = function(self: TypeChecker, a: PolyType, b: Type): boolean, {Error} -- ∃ t ∈ a, t <: b + if self:exists_supertype_in(b, a) then -- ─────────────── + return true -- a poly <: b end - return false, { Err(a, "cannot match against any alternatives of the polymorphic type") } + return false, { Err("cannot match against any alternatives of the polymorphic type") } end, }, ["nominal"] = { - ["nominal"] = function(a: NominalType, b: NominalType): boolean, {Error} - local ok, errs = are_same_nominals(a, b) + ["nominal"] = function(self: TypeChecker, a: NominalType, b: NominalType): boolean, {Error} + local ok, errs = self:are_same_nominals(a, b) if ok then return true end - local rb = resolve_nominal(b) + local rb = self:resolve_nominal(b) if rb is InterfaceType then -- match interface subtyping - return is_a(a, rb) + return self:is_a(a, rb) end - local ra = resolve_nominal(a) + local ra = self:resolve_nominal(a) if ra is UnionType or rb is UnionType then -- match unions structurally - return is_a(ra, rb) + return self:is_a(ra, rb) end -- all other types nominally return ok, errs end, - ["*"] = subtype_nominal, + ["*"] = TypeChecker.subtype_nominal, }, ["enum"] = { ["string"] = compare_true, }, ["string"] = { - ["enum"] = function(a: StringType, b: EnumType): boolean, {Error} + ["enum"] = function(_self: TypeChecker, a: StringType, b: EnumType): boolean, {Error} if not a.literal then - return false, { Err(a, "string is not a %s", b) } + return false, { Err("%s is not a %s", a, b) } end if b.enumset[a.literal] then return true end - return false, { Err(a, "%s is not a member of %s", a, b) } + return false, { Err("%s is not a member of %s", a, b) } end, }, ["integer"] = { ["number"] = compare_true, }, ["interface"] = { - ["interface"] = function(a: InterfaceType, b: InterfaceType): boolean, {Error} - if find_in_interface_list(a, function(t: Type): boolean return (is_a(t, b)) end) then + ["interface"] = function(self: TypeChecker, a: InterfaceType, b: InterfaceType): boolean, {Error} + if find_in_interface_list(a, function(t: Type): boolean return (self:is_a(t, b)) end) then return true end - return same_type(a, b) + return self:same_type(a, b) end, - ["array"] = subtype_array, - ["record"] = subtype_record, - ["tupletable"] = function(a: Type, b: Type): boolean, {Error} - return subtype_relations["record"]["tupletable"](a, b) + ["array"] = TypeChecker.subtype_array, + ["record"] = TypeChecker.subtype_record, + ["tupletable"] = function(self: TypeChecker, a: Type, b: Type): boolean, {Error} + return self.subtype_relations["record"]["tupletable"](self, a, b) end, }, ["emptytable"] = emptytable_relations, ["tupletable"] = { - ["tupletable"] = function(a: TupleTableType, b: TupleTableType): boolean, {Error} + ["tupletable"] = function(self: TypeChecker, a: TupleTableType, b: TupleTableType): boolean, {Error} for i = 1, math.min(#a.types, #b.types) do - if not is_a(a.types[i], b.types[i]) then - return false, { Err(a, "in tuple entry " + if not self:is_a(a.types[i], b.types[i]) then + return false, { Err("in tuple entry " .. tostring(i) .. ": got %s, expected %s", a.types[i], b.types[i]) } end end if #a.types > #b.types then - return false, { Err(a, "tuple %s is too big for tuple %s", a, b) } + return false, { Err("tuple %s is too big for tuple %s", a, b) } end return true end, - ["record"] = function(a: Type, b: RecordType): boolean, {Error} + ["record"] = function(self: TypeChecker, a: Type, b: RecordType): boolean, {Error} if b.elements then - return subtype_relations["tupletable"]["array"](a, b) + return self.subtype_relations["tupletable"]["array"](self, a, b) end end, - ["array"] = function(a: TupleTableType, b: ArrayType): boolean, {Error} + ["array"] = function(self: TypeChecker, a: TupleTableType, b: ArrayType): boolean, {Error} if b.inferred_len and b.inferred_len > #a.types then - return false, { Err(a, "incompatible length, expected maximum length of " .. tostring(#a.types) .. ", got " .. tostring(b.inferred_len)) } + return false, { Err("incompatible length, expected maximum length of " .. tostring(#a.types) .. ", got " .. tostring(b.inferred_len)) } end - local aa, err = arraytype_from_tuple(a.inferred_at, a) + local aa, err = self:arraytype_from_tuple(a.inferred_at or a, a) if not aa then return false, err end - if not is_a(aa, b) then - return false, { Err(a, "got %s (from %s), expected %s", aa, a, b) } + if not self:is_a(aa, b) then + return false, { Err("got %s (from %s), expected %s", aa, a, b) } end return true end, - ["map"] = function(a: TupleTableType, b: MapType): boolean, {Error} - local aa = arraytype_from_tuple(a.inferred_at, a) + ["map"] = function(self: TypeChecker, a: TupleTableType, b: MapType): boolean, {Error} + local aa = self:arraytype_from_tuple(a.inferred_at or a, a) if not aa then - return false, { Err(a, "Unable to convert tuple %s to map", a) } + return false, { Err("Unable to convert tuple %s to map", a) } end - return compare_map(INTEGER, b.keys, aa.elements, b.values) + return compare_map(self, a_type(a, "integer", {}), b.keys, aa.elements, b.values) end, }, ["record"] = { - ["record"] = subtype_record, - ["interface"] = function(a: RecordType, b: InterfaceType): boolean, {Error} - if find_in_interface_list(a, function(t: Type): boolean return (is_a(t, b)) end) then + ["record"] = TypeChecker.subtype_record, + ["interface"] = function(self: TypeChecker, a: RecordType, b: InterfaceType): boolean, {Error} + if find_in_interface_list(a, function(t: Type): boolean return (self:is_a(t, b)) end) then return true end if not a.declname then -- match inferred table (anonymous record) structurally to interface - return subtype_record(a, b) + return self:subtype_record(a, b) end end, - ["array"] = subtype_array, - ["map"] = function(a: RecordType, b: MapType): boolean, {Error} - if not is_a(b.keys, STRING) then - return false, { Err(a, "can't match a record to a map with non-string keys") } + ["array"] = TypeChecker.subtype_array, + ["map"] = function(self: TypeChecker, a: RecordType, b: MapType): boolean, {Error} + if not self:is_a(b.keys, a_type(b, "string", {})) then + return false, { Err("can't match a record to a map with non-string keys") } end for _, k in ipairs(a.field_order) do local bk = b.keys if bk is EnumType and not bk.enumset[k] then - return false, { Err(a, "key is not an enum value: " .. k) } + return false, { Err("key is not an enum value: " .. k) } end - if not is_a(a.fields[k], b.values) then - return false, { Err(a, "record is not a valid map; not all fields have the same type") } + if not self:is_a(a.fields[k], b.values) then + return false, { Err("record is not a valid map; not all fields have the same type") } end end return true end, - ["tupletable"] = function(a: RecordType, b: Type): boolean, {Error} + ["tupletable"] = function(self: TypeChecker, a: RecordType, b: Type): boolean, {Error} if a.elements then - return subtype_relations["array"]["tupletable"](a, b) + return self.subtype_relations["array"]["tupletable"](self, a, b) end end, }, ["array"] = { - ["array"] = subtype_array, - ["record"] = function(a: ArrayType, b: RecordType): boolean, {Error} + ["array"] = TypeChecker.subtype_array, + ["record"] = function(self: TypeChecker, a: ArrayType, b: RecordType): boolean, {Error} if b.elements then - return subtype_array(a, b) + return self:subtype_array(a, b) end end, - ["map"] = function(a: ArrayType, b: MapType): boolean, {Error} - return compare_map(INTEGER, b.keys, a.elements, b.values) + ["map"] = function(self: TypeChecker, a: ArrayType, b: MapType): boolean, {Error} + return compare_map(self, a_type(a, "integer", {}), b.keys, a.elements, b.values) end, - ["tupletable"] = function(a: ArrayType, b: TupleTableType): boolean, {Error} + ["tupletable"] = function(self: TypeChecker, a: ArrayType, b: TupleTableType): boolean, {Error} local alen = a.inferred_len or 0 if alen > #b.types then - return false, { Err(a, "incompatible length, expected maximum length of " .. tostring(#b.types) .. ", got " .. tostring(alen)) } + return false, { Err("incompatible length, expected maximum length of " .. tostring(#b.types) .. ", got " .. tostring(alen)) } end -- for array literals (which is the only case where inferred_len is defined), -- only check the entries that are present for i = 1, (alen > 0) and alen or #b.types do - if not is_a(a.elements, b.types[i]) then - return false, { Err(a, "tuple entry " .. i .. " of type %s does not match type of array elements, which is %s", b.types[i], a.elements) } + if not self:is_a(a.elements, b.types[i]) then + return false, { Err("tuple entry " .. i .. " of type %s does not match type of array elements, which is %s", b.types[i], a.elements) } end end return true end, }, ["map"] = { - ["map"] = function(a: MapType, b: MapType): boolean, {Error} - return compare_map(a.keys, b.keys, a.values, b.values) + ["map"] = function(self: TypeChecker, a: MapType, b: MapType): boolean, {Error} + return compare_map(self, a.keys, b.keys, a.values, b.values) end, - ["array"] = function(a: MapType, b: ArrayType): boolean, {Error} - return compare_map(a.keys, INTEGER, a.values, b.elements) + ["array"] = function(self: TypeChecker, a: MapType, b: ArrayType): boolean, {Error} + return compare_map(self, a.keys, a_type(b, "integer", {}), a.values, b.elements) end, }, ["typedecl"] = { - ["record"] = function(a: TypeDeclType, b: RecordType): boolean, {Error} + ["record"] = function(self: TypeChecker, a: TypeDeclType, b: RecordType): boolean, {Error} local def = a.def if def is RecordLikeType then - return subtype_record(def, b) -- record as prototype + return self:subtype_record(def, b) -- record as prototype end end, }, ["function"] = { - ["function"] = function(a: FunctionType, b: FunctionType): boolean, {Error} + ["function"] = function(self: TypeChecker, a: FunctionType, b: FunctionType): boolean, {Error} local errs = {} local aa, ba = a.args.tuple, b.args.tuple if (not b.args.is_va) and a.min_arity > b.min_arity then - table.insert(errs, Err(a, "incompatible number of arguments: got " .. show_arity(a) .. " %s, expected " .. show_arity(b) .. " %s", a.args, b.args)) + table.insert(errs, Err("incompatible number of arguments: got " .. show_arity(a) .. " %s, expected " .. show_arity(b) .. " %s", a.args, b.args)) else for i = ((a.is_method or b.is_method) and 2 or 1), #aa do - arg_check(nil, errs, aa[i], ba[i] or ba[#ba], "bivariant", "argument", i) + self:arg_check(nil, errs, aa[i], ba[i] or ba[#ba], "bivariant", "argument", i) end end local ar, br = a.rets.tuple, b.rets.tuple local diff_by_va = #br - #ar == 1 and b.rets.is_va if #ar < #br and not diff_by_va then - table.insert(errs, Err(a, "incompatible number of returns: got " .. #ar .. " %s, expected " .. #br .. " %s", a.rets, b.rets)) + table.insert(errs, Err("incompatible number of returns: got " .. #ar .. " %s, expected " .. #br .. " %s", a.rets, b.rets)) else local nrets = #br if diff_by_va then nrets = nrets - 1 end for i = 1, nrets do - arg_check(nil, errs, ar[i], br[i], "bivariant", "return", i) + self:arg_check(nil, errs, ar[i], br[i], "bivariant", "return", i) end end @@ -8375,36 +8458,36 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end, }, ["typearg"] = { - ["typearg"] = function(a: TypeArgType, b: TypeArgType): boolean, {Error} + ["typearg"] = function(_self: TypeChecker, a: TypeArgType, b: TypeArgType): boolean, {Error} return a.typearg == b.typearg end, - ["*"] = function(a: TypeArgType, b: Type): boolean, {Error} + ["*"] = function(self: TypeChecker, a: TypeArgType, b: Type): boolean, {Error} if a.constraint then - return is_a(a.constraint, b) + return self:is_a(a.constraint, b) end end, }, ["*"] = { ["any"] = compare_true, - ["tuple"] = function(a: Type, b: Type): boolean, {Error} - return is_a(a_tuple({a}), b) + ["tuple"] = function(self: TypeChecker, a: Type, b: Type): boolean, {Error} + return self:is_a(a_tuple(a, {a}), b) end, - ["typevar"] = function(a: Type, b: TypeVarType): boolean, {Error} - return compare_or_infer_typevar(b.typevar, a, nil, is_a) + ["typevar"] = function(self: TypeChecker, a: Type, b: TypeVarType): boolean, {Error} + return self:compare_or_infer_typevar(b.typevar, a, nil, self.is_a) end, - ["typearg"] = function(a: Type, b: TypeArgType): boolean, {Error} + ["typearg"] = function(self: TypeChecker, a: Type, b: TypeArgType): boolean, {Error} if b.constraint then - return is_a(a, b.constraint) + return self:is_a(a, b.constraint) end end, - ["union"] = exists_supertype_in as CompareTypes, -- ∃ t ∈ b, a <: t - -- ─────────────── - -- a <: b union - ["nominal"] = subtype_nominal, - ["poly"] = function(a: Type, b: PolyType): boolean, {Error} -- ∀ t ∈ b, a <: t - for _, t in ipairs(b.types) do -- ─────────────── - if not is_a(a, t) then -- a <: b poly - return false, { Err(a, "cannot match against all alternatives of the polymorphic type") } + ["union"] = TypeChecker.exists_supertype_in as CompareTypes, -- ∃ t ∈ b, a <: t + -- ─────────────── + -- a <: b union + ["nominal"] = TypeChecker.subtype_nominal, + ["poly"] = function(self: TypeChecker, a: Type, b: PolyType): boolean, {Error} -- ∀ t ∈ b, a <: t + for _, t in ipairs(b.types) do -- ─────────────── + if not self:is_a(a, t) then -- a <: b poly + return false, { Err("cannot match against all alternatives of the polymorphic type") } end end return true @@ -8413,7 +8496,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string } -- evaluation strategy - local type_priorities: {TypeName:integer} = { + TypeChecker.type_priorities = { -- types that have catch-all rules evaluate first ["tuple"] = 2, ["typevar"] = 3, @@ -8442,19 +8525,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string ["function"] = 14, } - if lax then - type_priorities["unknown"] = 0 - - subtype_relations["unknown"] = {} - subtype_relations["unknown"]["*"] = compare_true - subtype_relations["*"]["unknown"] = compare_true - -- in .lua files, all values can be used in a boolean context - subtype_relations["boolean"] = {} - subtype_relations["boolean"]["boolean"] = compare_true - subtype_relations["*"]["boolean"] = compare_true - end - - local function compare_types(relations: TypeRelations, t1: Type, t2: Type): boolean, {Error} + local function compare_types(self: TypeChecker, relations: TypeRelations, t1: Type, t2: Type): boolean, {Error} if t1.typeid == t2.typeid then return true end @@ -8462,8 +8533,8 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local s1 = relations[t1.typename] local fn = s1 and s1[t2.typename] if not fn then - local p1 = type_priorities[t1.typename] or 999 - local p2 = type_priorities[t2.typename] or 999 + local p1 = self.type_priorities[t1.typename] or 999 + local p2 = self.type_priorities[t2.typename] or 999 fn = (p1 < p2 and (s1 and s1["*"]) or (relations["*"][t2.typename])) end @@ -8472,32 +8543,32 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if fn == compare_true then return true end - ok, err = fn(t1, t2) + ok, err = fn(self, t1, t2) else ok = t1.typename == t2.typename end if (not ok) and not err then - return false, { Err(t1, "got %s, expected %s", t1, t2) } + return false, { Err("got %s, expected %s", t1, t2) } end return ok, err end -- subtyping comparison - is_a = function(t1: Type, t2: Type): boolean, {Error} - return compare_types(subtype_relations, t1, t2) + function TypeChecker:is_a(t1: Type, t2: Type): boolean, {Error} + return compare_types(self, self.subtype_relations, t1, t2) end -- invariant type comparison - same_type = function(t1: Type, t2: Type): boolean, {Error} + function TypeChecker:same_type(t1: Type, t2: Type): boolean, {Error} -- except for error messages, behavior is the same as - -- `return (is_a(t1, t2) and is_a(t2, t1))` - return compare_types(eqtype_relations, t1, t2) + -- `return (is_a(t1, t2) and self:is_a(t2, t1))` + return compare_types(self, self.eqtype_relations, t1, t2) end if TL_DEBUG then - local orig_is_a = is_a - is_a = function(t1: Type, t2: Type): boolean, {Error} + local orig_is_a = TypeChecker.is_a + TypeChecker.is_a = function(self: TypeChecker, t1: Type, t2: Type): boolean, {Error} assert(type(t1) == "table") assert(type(t2) == "table") @@ -8507,14 +8578,14 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return true end - return orig_is_a(t1, t2) + return orig_is_a(self, t1, t2) end end - local function assert_is_a(where: Where, t1: Type, t2: Type, context: string, name?: string): boolean + function TypeChecker:assert_is_a(w: Where, t1: Type, t2: Type, ctx?: string | Node, name?: string): boolean t1 = resolve_tuple(t1) t2 = resolve_tuple(t2) - if lax and (is_unknown(t1) or is_unknown(t2)) then + if self.feat_lax and (is_unknown(t1) or is_unknown(t2)) then return true end @@ -8522,24 +8593,27 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if t1.typename == "nil" then return true elseif t2 is UnresolvedEmptyTableValueType then - if is_number_type(t2.emptytable_type.keys) then -- ideally integer only - infer_emptytable(t2.emptytable_type, infer_at(where, an_array(t1))) + local t2keys = t2.emptytable_type.keys + if t2keys is NumericType then -- ideally integer only + self:infer_emptytable(t2.emptytable_type, self:infer_at(w, an_array(w, t1))) else - infer_emptytable(t2.emptytable_type, infer_at(where, a_map(t2.emptytable_type.keys, t1))) + self:infer_emptytable(t2.emptytable_type, self:infer_at(w, a_map(w, t2keys, t1))) end return true elseif t2 is EmptyTableType then if is_lua_table_type(t1) then - infer_emptytable(t2, infer_at(where, t1)) + self:infer_emptytable(t2, self:infer_at(w, t1)) elseif not t1 is EmptyTableType then - error_at(where, context .. ": " .. (name and (name .. ": ") or "") .. "assigning %s to a variable declared with {}", t1) + self.errs:add(w, self.errs:get_context(ctx, name) .. "assigning %s to a variable declared with {}", t1) return false end return true end - local ok, match_errs = is_a(t1, t2) - add_errs_prefixing(where, match_errs, errors, context .. ": ".. (name and (name .. ": ") or "")) + local ok, match_errs = self:is_a(t1, t2) + if not ok then + self.errs:add_prefixing(w, match_errs, self.errs:get_context(ctx, name)) + end return ok end @@ -8547,11 +8621,11 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if t is InvalidType then return false end - if same_type(t, NIL) then + if t.typename == "nil" then return true end if t is NominalType then - t = resolve_nominal(t) + t = assert(t.resolved) end if t is RecordLikeType then return t.meta_fields and t.meta_fields["__close"] ~= nil @@ -8569,36 +8643,27 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return definitely_not_closable_exprs[e.kind] end - local unknown_dots: {string:boolean} = {} - - local function add_unknown_dot(node: Node, name: string) - if not unknown_dots[name] then - unknown_dots[name] = true - add_unknown(node, name) - end - end - - local function same_in_all_union_entries(u: UnionType, check: function(Type): (Type, Type)): Type + function TypeChecker:same_in_all_union_entries(u: UnionType, check: function(Type): (Type, Type)): Type local t1, f = check(u.types[1]) if not t1 then return nil end for i = 2, #u.types do local t2 = check(u.types[i]) - if not t2 or not same_type(t1, t2) then + if not t2 or not self:same_type(t1, t2) then return nil end end return f or t1 end - local function same_call_mt_in_all_union_entries(u: UnionType): Type - return same_in_all_union_entries(u, function(t: Type): (Type, Type) - t = to_structural(t) + function TypeChecker:same_call_mt_in_all_union_entries(u: UnionType): Type + return self:same_in_all_union_entries(u, function(t: Type): (Type, Type) + t = self:to_structural(t) if t is RecordLikeType then local call_mt = t.meta_fields and t.meta_fields["__call"] if call_mt is FunctionType then - local args_tuple = a_tuple({}) + local args_tuple = a_tuple(u, {}) for i = 2, #call_mt.args.tuple do table.insert(args_tuple.tuple, call_mt.args.tuple[i]) end @@ -8608,20 +8673,21 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end) end - local function resolve_for_call(func: Type, args: TupleType, is_method: boolean): Type, boolean + function TypeChecker:resolve_for_call(func: Type, args: TupleType, is_method: boolean): Type, boolean -- resolve unknown in lax mode, produce a general unknown function - if lax and is_unknown(func) then - func = a_fn { args = va_args { UNKNOWN }, rets = va_args { UNKNOWN } } + if self.feat_lax and is_unknown(func) then + local unk = func + func = a_function(func, { min_arity = 0, args = a_vararg(func, { unk }), rets = a_vararg(func, { unk }) }) end -- unwrap if tuple, resolve if nominal - func = to_structural(func) + func = self:to_structural(func) if func.typename ~= "function" and func.typename ~= "poly" then -- resolve if union if func is UnionType then - local r = same_call_mt_in_all_union_entries(func) + local r = self:same_call_mt_in_all_union_entries(func) if r then table.insert(args.tuple, 1, func.types[1]) -- FIXME: is this right? - return to_structural(r), true + return self:to_structural(r), true end end -- resolve if prototype @@ -8635,7 +8701,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if func is RecordLikeType and func.meta_fields and func.meta_fields["__call"] then table.insert(args.tuple, 1, func) func = func.meta_fields["__call"] - func = to_structural(func) + func = self:to_structural(func) is_method = true end end @@ -8643,19 +8709,19 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end local type OnArgId = function(node: Node, i: integer): T - local type OnNode = function(node: Node, children: {T}, ret: T): T + local type OnNode = function(s: S, node: Node, children: {T}, ret: T): T - local function traverse_macroexp(macroexp: Node, on_arg_id: OnArgId, on_node: OnNode): T + local function traverse_macroexp(macroexp: Node, on_arg_id: OnArgId, on_node: OnNode): T local root = macroexp.exp local argnames = {} for i, a in ipairs(macroexp.args) do argnames[a.tk] = i end - local visit_node: Visitor = { + local visit_node: Visitor = { cbs = { ["variable"] = { - after = function(node: Node, _children: {T}): T + after = function(_: nil, node: Node, _children: {T}): T local i = argnames[node.tk] if not i then return nil @@ -8665,10 +8731,10 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end } }, - after = on_node, + after = on_node as VisitorAfter, } - return recurse_node(root, visit_node, {}) + return recurse_node(nil, root, visit_node, {}) end local function expand_macroexp(orignode: Node, args: {Node}, macroexp: Node) @@ -8676,7 +8742,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return { Node, args[i] } end - local on_node = function(node: Node, children: {{Node, Node}}, ret: {Node, Node}): {Node, Node} + local on_node = function(_: nil, node: Node, children: {{Node, Node}}, ret: {Node, Node}): {Node, Node} local orig = ret and ret[2] or node local out = shallow_copy_table(orig) @@ -8705,12 +8771,12 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string orignode.expanded = p[2] end - local function check_macroexp_arg_use(macroexp: Node) + function TypeChecker:check_macroexp_arg_use(macroexp: Node) local used: {string:boolean} = {} local on_arg_id = function(node: Node, _i: integer): {Node, Node} if used[node.tk] then - error_at(node, "cannot use argument '" .. node.tk .. "' multiple times in macroexp") + self.errs:add(node, "cannot use argument '" .. node.tk .. "' multiple times in macroexp") else used[node.tk] = true end @@ -8733,18 +8799,15 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string orignode.known = saveknown end - local type InvalidOrTupleType = InvalidType | TupleType - - local type_check_function_call: function(Node, Type, TupleType, ? integer, ? Node, ? {Node}): InvalidOrTupleType, FunctionType do - local function mark_invalid_typeargs(f: FunctionType) + local function mark_invalid_typeargs(self: TypeChecker, f: FunctionType) if f.typeargs then for _, a in ipairs(f.typeargs) do - if not find_var_type(a.typearg) then + if not self:find_var_type(a.typearg) then if a.constraint then - add_var(nil, a.typearg, a.constraint) + self:add_var(nil, a.typearg, a.constraint) else - add_var(nil, a.typearg, lax and UNKNOWN or a_type("unresolvable_typearg", { + self:add_var(nil, a.typearg, self.feat_lax and an_unknown(a) or a_type(a, "unresolvable_typearg", { typearg = a.typearg } as UnresolvableTypeArgType)) end @@ -8753,7 +8816,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - local function infer_emptytables(where: Where, wheres: {Where}, xs: TupleType, ys: TupleType, delta: integer) + local function infer_emptytables(self: TypeChecker, w: Where, wheres: {Where}, xs: TupleType, ys: TupleType, delta: integer) local xt, yt = xs.tuple, ys.tuple local n_xs = #xt local n_ys = #yt @@ -8763,19 +8826,19 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if x is EmptyTableType then local y = yt[i] or (ys.is_va and yt[n_ys]) if y then -- y may not be present when inferring returns - local w = wheres and wheres[i + delta] or where -- for self, a + argdelta is 0 - local inferred_y = infer_at(w, y) - infer_emptytable(x, inferred_y) + local iw = wheres and wheres[i + delta] or w -- for self, a + argdelta is 0 + local inferred_y = self:infer_at(iw, y) + self:infer_emptytable(x, inferred_y) xt[i] = inferred_y end end end end - local check_args_rets: function(where: Where, where_args: {Node}, f: Type, args: TupleType, expected_rets: TupleType, argdelta: integer): TupleType, {Error} + local check_args_rets: function(TypeChecker, w: Where, where_args: {Node}, f: FunctionType, args: TupleType, expected_rets: TupleType, argdelta: integer): TupleType, {Error} do -- check if a tuple `xs` matches tuple `ys` - local function check_func_type_list(where: Where, wheres: {Where}, xs: TupleType, ys: TupleType, from: integer, delta: integer, v: VarianceMode, mode: ArgCheckMode): boolean, {Error} + local function check_func_type_list(self: TypeChecker, w: Where, wheres: {Where}, xs: TupleType, ys: TupleType, from: integer, delta: integer, v: VarianceMode, mode: ArgCheckMode): boolean, {Error} assert(xs.typename == "tuple", xs.typename) assert(ys.typename == "tuple", ys.typename) @@ -8786,11 +8849,11 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string for i = from, math.max(n_xs, n_ys) do local pos = i + delta - local x = xt[i] or (xs.is_va and xt[n_xs]) or NIL + local x = xt[i] or (xs.is_va and xt[n_xs]) or a_type(w, "nil", {}) local y = yt[i] or (ys.is_va and yt[n_ys]) if y then - local w = wheres and wheres[pos] or where - if not arg_check(w, errs, x, y, v, mode, pos) then + local iw = wheres and wheres[pos] or w + if not self:arg_check(iw, errs, x, y, v, mode, pos) then return nil, errs end end @@ -8799,7 +8862,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return true end - check_args_rets = function(where: Where, where_args: {Node}, f: FunctionType, args: TupleType, expected_rets: TupleType, argdelta: integer): TupleType, {Error} + check_args_rets = function(self: TypeChecker, w: Where, where_args: {Node}, f: FunctionType, args: TupleType, expected_rets: TupleType, argdelta: integer): TupleType, {Error} local rets_ok = true local rets_errs: {Error} local args_ok: boolean @@ -8810,19 +8873,19 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if argdelta == -1 then from = 2 local errs = {} - if (not is_self(fargs[1])) and not arg_check(where, errs, fargs[1], args.tuple[1], "contravariant", "self") then + if (not is_self(fargs[1])) and not self:arg_check(w, errs, fargs[1], args.tuple[1], "contravariant", "self") then return nil, errs end end if expected_rets then - expected_rets = infer_at(where, expected_rets) - infer_emptytables(where, nil, expected_rets, f.rets, 0) + expected_rets = self:infer_at(w, expected_rets) + infer_emptytables(self, w, nil, expected_rets, f.rets, 0) - rets_ok, rets_errs = check_func_type_list(where, nil, f.rets, expected_rets, 1, 0, "covariant", "return") + rets_ok, rets_errs = check_func_type_list(self, w, nil, f.rets, expected_rets, 1, 0, "covariant", "return") end - args_ok, args_errs = check_func_type_list(where, where_args, f.args, args, from, argdelta, "contravariant", "argument") + args_ok, args_errs = check_func_type_list(self, w, where_args, f.args, args, from, argdelta, "contravariant", "argument") if (not args_ok) or (not rets_ok) then return nil, args_errs or {} end @@ -8830,29 +8893,29 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string -- if we got to this point without returning, -- we got a valid function match - infer_emptytables(where, where_args, args, f.args, argdelta) + infer_emptytables(self, w, where_args, args, f.args, argdelta) - mark_invalid_typeargs(f) + mark_invalid_typeargs(self, f) - return resolve_typevars_at(where, f.rets) + return self:resolve_typevars_at(w, f.rets) end end - local function push_typeargs(func: FunctionType) + local function push_typeargs(self: TypeChecker, func: FunctionType) if func.typeargs then for _, fnarg in ipairs(func.typeargs) do - add_var(nil, fnarg.typearg, a_type("unresolved_typearg", { + self:add_var(nil, fnarg.typearg, a_type(fnarg, "unresolved_typearg", { constraint = fnarg.constraint, } as UnresolvedTypeArgType)) end end end - local function pop_typeargs(func: FunctionType) + local function pop_typeargs(self: TypeChecker, func: FunctionType) if func.typeargs then for _, fnarg in ipairs(func.typeargs) do - if st[#st][fnarg.typearg] then - st[#st][fnarg.typearg] = nil + if self.st[#self.st].vars[fnarg.typearg] then + self.st[#self.st].vars[fnarg.typearg] = nil end end end @@ -8866,12 +8929,9 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - local function fail_call(where: Where, func: FunctionType | PolyType, nargs: integer, errs: {Error}): TupleType + local function fail_call(self: TypeChecker, w: Where, func: FunctionType | PolyType, nargs: integer, errs: {Error}): TupleType if errs then - -- report the errors from the first match - for _, err in ipairs(errs) do - table.insert(errors, err) - end + self.errs:collect(errs) else -- found no arity match to try local expects: {string} = {} @@ -8888,34 +8948,34 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string else table.insert(expects, show_arity(func)) end - error_at(where, "wrong number of arguments (given " .. nargs .. ", expects " .. table.concat(expects, " or ") .. ")") + self.errs:add(w, "wrong number of arguments (given " .. nargs .. ", expects " .. table.concat(expects, " or ") .. ")") end local f = resolve_function_type(func, 1) - mark_invalid_typeargs(f) + mark_invalid_typeargs(self, f) - return resolve_typevars_at(where, f.rets) + return self:resolve_typevars_at(w, f.rets) end - local function check_call(where: Where, where_args: {Node}, func: Type, args: TupleType, expected_rets: TupleType, is_typedecl_funcall: boolean, argdelta: integer): InvalidOrTupleType, FunctionType + local function check_call(self: TypeChecker, w: Where, where_args: {Node}, func: Type, args: TupleType, expected_rets: TupleType, is_typedecl_funcall: boolean, argdelta: integer): InvalidOrTupleType, FunctionType assert(type(func) == "table") assert(type(args) == "table") local is_method = (argdelta == -1) if not (func is FunctionType or func is PolyType) then - func, is_method = resolve_for_call(func, args, is_method) + func, is_method = self:resolve_for_call(func, args, is_method) if is_method then argdelta = -1 end if not (func is FunctionType or func is PolyType) then - return invalid_at(where, "not a function: %s", func) + return self.errs:invalid_at(w, "not a function: %s", func) end end if is_method and args.tuple[1] then - add_var(nil, "@self", type_at(where, a_typedecl(args.tuple[1]))) + self:add_var(nil, "@self", a_typedecl(w, args.tuple[1])) end local passes, n = 1, 1 @@ -8932,30 +8992,30 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local f = resolve_function_type(func, i) local fargs = f.args.tuple if f.is_method and not is_method then - if args.tuple[1] and is_a(args.tuple[1], fargs[1]) then + if args.tuple[1] and self:is_a(args.tuple[1], fargs[1]) then -- a non-"@funcall" means a synthesized call, e.g. from a metamethod if not is_typedecl_funcall then - add_warning("hint", where, "invoked method as a regular function: consider using ':' instead of '.'") + self.errs:add_warning("hint", w, "invoked method as a regular function: consider using ':' instead of '.'") end else - return invalid_at(where, "invoked method as a regular function: use ':' instead of '.'") + return self.errs:invalid_at(w, "invoked method as a regular function: use ':' instead of '.'") end end local wanted = #fargs - local min_arity = feat_arity and f.min_arity or 0 + local min_arity = self.feat_arity and f.min_arity or 0 -- simple functions: - if (passes == 1 and ((given <= wanted and given >= min_arity) or (f.args.is_va and given > wanted) or (lax and given <= wanted))) + if (passes == 1 and ((given <= wanted and given >= min_arity) or (f.args.is_va and given > wanted) or (self.feat_lax and given <= wanted))) -- poly, pass 1: try exact arity matches first or (passes == 3 and ((pass == 1 and given == wanted) -- poly, pass 2: then try adjusting with nils to missing arguments or using '...' - or (pass == 2 and given < wanted and (lax or given >= min_arity)) + or (pass == 2 and given < wanted and (self.feat_lax or given >= min_arity)) -- poly, pass 3: then finally try vararg functions or (pass == 3 and f.args.is_va and given > wanted))) then - push_typeargs(f) + push_typeargs(self, f) - local matched, errs = check_args_rets(where, where_args, f, args, expected_rets, argdelta) + local matched, errs = check_args_rets(self, w, where_args, f, args, expected_rets, argdelta) if matched then -- success! return matched, f @@ -8964,23 +9024,23 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if expected_rets then -- revert inferred returns - infer_emptytables(where, where_args, f.rets, f.rets, argdelta) + infer_emptytables(self, w, where_args, f.rets, f.rets, argdelta) end if passes == 3 then tried = tried or {} tried[i] = true - pop_typeargs(f) + pop_typeargs(self, f) end end end end end - return fail_call(where, func, given, first_errs) + return fail_call(self, w, func, given, first_errs) end - type_check_function_call = function(node: Node, func: Type, args: TupleType, argdelta?: integer, e1?: Node, e2?: {Node}): InvalidOrTupleType, FunctionType + function TypeChecker:type_check_function_call(node: Node, func: Type, args: TupleType, argdelta?: integer, e1?: Node, e2?: {Node}): InvalidOrTupleType, FunctionType e1 = e1 or node.e1 e2 = e2 or node.e2 @@ -8989,14 +9049,14 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if expected and expected is TupleType then expected_rets = expected else - expected_rets = a_tuple { node.expected } + expected_rets = a_tuple(node, { node.expected }) end - begin_scope() + self:begin_scope() local is_typedecl_funcall: boolean - if node.kind == "op" and node.op.op == "@funcall" and node.e1 and node.e1.receiver then - local receiver = node.e1.receiver + if node.kind == "op" and node.op.op == "@funcall" and e1 and e1.receiver then + local receiver = e1.receiver if receiver is NominalType then local resolved = receiver.resolved if resolved and resolved is TypeDeclType then @@ -9005,12 +9065,12 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - local ret, f = check_call(node, e2, func, args, expected_rets, is_typedecl_funcall, argdelta or 0) - ret = resolve_typevars_at(node, ret) - end_scope() + local ret, f = check_call(self, node, e2, func, args, expected_rets, is_typedecl_funcall, argdelta or 0) + ret = self:resolve_typevars_at(node, ret) + self:end_scope() - if tc and e1 then - tc.store_type(e1.y, e1.x, f) + if self.collector then + self.collector.store_type(e1.y, e1.x, f) end if f and f.macroexp then @@ -9021,9 +9081,9 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - local function check_metamethod(node: Node, method_name: string, a: Type, b: Type, orig_a: Type, orig_b: Type): Type, integer - if lax and ((a and is_unknown(a)) or (b and is_unknown(b))) then - return UNKNOWN, nil + function TypeChecker:check_metamethod(node: Node, method_name: string, a: Type, b: Type, orig_a: Type, orig_b: Type): Type, integer + if self.feat_lax and ((a and is_unknown(a)) or (b and is_unknown(b))) then + return an_unknown(node), nil end local ameta = a is RecordLikeType and a.meta_fields local bmeta = b and b is RecordLikeType and b.meta_fields @@ -9044,26 +9104,26 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if metamethod then local e2 = { node.e1 } - local args = a_tuple { orig_a } + local args = a_tuple(node, { orig_a }) if b and method_name ~= "__is" then e2[2] = node.e2 args.tuple[2] = orig_b end - return to_structural(resolve_tuple((type_check_function_call(node, metamethod, args, -1, node, e2)))), meta_on_operator + return self:to_structural(resolve_tuple((self:type_check_function_call(node, metamethod, args, -1, node, e2)))), meta_on_operator else return nil, nil end end - local function match_record_key(tbl: Type, rec: Node, key: string): Type, string + function TypeChecker:match_record_key(tbl: Type, rec: Node, key: string): Type, string assert(type(tbl) == "table") assert(type(rec) == "table") assert(type(key) == "string") - tbl = to_structural(tbl) + tbl = self:to_structural(tbl) if tbl is StringType or tbl is EnumType then - tbl = find_var_type("string") -- simulate string metatable + tbl = self:find_var_type("string") -- simulate string metatable end if tbl is TypeDeclType then @@ -9072,13 +9132,13 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if tbl.is_nested_alias then return nil, "cannot use a nested type alias as a concrete value" else - tbl = resolve_nominal(tbl.alias_to) + tbl = self:resolve_nominal(tbl.alias_to) end end if tbl is UnionType then - local t = same_in_all_union_entries(tbl, function(t: Type): (Type, Type) - return (match_record_key(t, rec, key)) + local t = self:same_in_all_union_entries(tbl, function(t: Type): (Type, Type) + return (self:match_record_key(t, rec, key)) end) if t then @@ -9087,7 +9147,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end if (tbl is TypeVarType or tbl is TypeArgType) and tbl.constraint then - local t = match_record_key(tbl.constraint, rec, key) + local t = self:match_record_key(tbl.constraint, rec, key) if t then return t @@ -9101,7 +9161,8 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return tbl.fields[key] end - local meta_t = check_metamethod(rec, "__index", tbl, STRING, tbl, STRING) + local str = a_type(rec, "string", {}) + local meta_t = self:check_metamethod(rec, "__index", tbl, str, tbl, str) if meta_t then return meta_t end @@ -9116,8 +9177,8 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return nil, "invalid key '" .. key .. "' in type %s" end elseif tbl is EmptyTableType or is_unknown(tbl) then - if lax then - return INVALID + if self.feat_lax then + return an_unknown(rec) end return nil, "cannot index a value of unknown type" end @@ -9129,30 +9190,35 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - local function widen_in_scope(scope: Scope, var: string): boolean - assert(scope[var], "no " .. var .. " in scope") - local narrow_mode = scope[var].is_narrowed - if narrow_mode and narrow_mode ~= "declaration" then - if scope[var].narrowed_from then - scope[var].t = scope[var].narrowed_from - scope[var].narrowed_from = nil - scope[var].is_narrowed = nil - else - scope[var] = nil - end + function TypeChecker:widen_in_scope(scope: Scope, var: string): boolean + local v = scope.vars[var] + assert(v, "no " .. var .. " in scope") + local narrow_mode = scope.vars[var].is_narrowed + if (not narrow_mode) or narrow_mode == "declaration" then + return false + end - local unresolved = get_unresolved(scope) - unresolved.narrows[var] = nil - return true + if v.narrowed_from then + v.t = v.narrowed_from + v.narrowed_from = nil + v.is_narrowed = nil + else + scope.vars[var] = nil + end + + if scope.narrows then + scope.narrows[var] = nil end - return false + + return true end - local function widen_back_var(name: string): boolean + function TypeChecker:widen_back_var(name: string): boolean local widened = false - for i = #st, 1, -1 do - if st[i][name] then - if widen_in_scope(st[i], name) then + for i = #self.st, 1, -1 do + local scope = self.st[i] + if scope.vars[name] then + if self:widen_in_scope(scope, name) then widened = true else break @@ -9163,10 +9229,10 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end local function assigned_anywhere(name: string, root: Node): boolean - local visit_node: Visitor = { + local visit_node: Visitor = { cbs = { ["assignment"] = { - after = function(node: Node, _children: {boolean}): boolean + after = function(_: nil, node: Node, _children: {boolean}): boolean for _, v in ipairs(node.vars) do if v.kind == "variable" and v.tk == name then return true @@ -9176,7 +9242,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end } }, - after = function(_node: Node, children: {boolean}, ret: boolean): boolean + after = function(_: nil, _node: Node, children: {boolean}, ret: boolean): boolean ret = ret or false for _, c in ipairs(children) do local ca = c as any @@ -9188,124 +9254,88 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end } - local visit_type: Visitor = { + local visit_type: Visitor = { after = function(): boolean return false end } - return recurse_node(root, visit_node, visit_type) + return recurse_node(nil, root, visit_node, visit_type) end - local function widen_all_unions(node?: Node) - for i = #st, 1, -1 do - local scope = st[i] - local unresolved = find_unresolved(i) - if unresolved and unresolved.narrows then - for name, _ in pairs(unresolved.narrows) do + function TypeChecker:widen_all_unions(node?: Node) + for i = #self.st, 1, -1 do + local scope = self.st[i] + if scope.narrows then + for name, _ in pairs(scope.narrows) do if not node or assigned_anywhere(name, node) then - widen_in_scope(scope, name) + self:widen_in_scope(scope, name) end end end end end - local function add_global(node: Node, var: string, valtype: Type, is_assigning?: boolean): Variable - if lax and is_unknown(valtype) and (var ~= "self" and var ~= "...") then - add_unknown(node, var) + function TypeChecker:add_global(node: Node, varname: string, valtype: Type, is_assigning?: boolean): Variable + if self.feat_lax and is_unknown(valtype) and (varname ~= "self" and varname ~= "...") then + self.errs:add_unknown(node, varname) end local is_const = node.attribute ~= nil - local existing, scope, existing_attr = find_var(var) + local existing, scope, existing_attr = self:find_var(varname) if existing then if scope > 1 then - error_at(node, "cannot define a global when a local with the same name is in scope") + self.errs:add(node, "cannot define a global when a local with the same name is in scope") elseif is_assigning and existing_attr then - error_at(node, "cannot reassign to <" .. existing_attr .. "> global: " .. var) + self.errs:add(node, "cannot reassign to <" .. existing_attr .. "> global: " .. varname) elseif existing_attr and not is_const then - error_at(node, "global was previously declared as <" .. existing_attr .. ">: " .. var) + self.errs:add(node, "global was previously declared as <" .. existing_attr .. ">: " .. varname) elseif (not existing_attr) and is_const then - error_at(node, "global was previously declared as not <" .. node.attribute .. ">: " .. var) - elseif valtype and not same_type(existing.t, valtype) then - error_at(node, "cannot redeclare global with a different type: previous type of " .. var .. " is %s", existing.t) + self.errs:add(node, "global was previously declared as not <" .. node.attribute .. ">: " .. varname) + elseif valtype and not self:same_type(existing.t, valtype) then + self.errs:add(node, "cannot redeclare global with a different type: previous type of " .. varname .. " is %s", existing.t) end return nil end - st[1][var] = { t = valtype, attribute = is_const and "const" or nil } - - return st[1][var] - end + local var = { t = valtype, attribute = is_const and "const" or nil } + self.st[1].vars[varname] = var - local get_rets: function(TupleType): TupleType - if lax then - get_rets = function(rets: TupleType): TupleType - if #rets.tuple == 0 then - return a_vararg { UNKNOWN } - end - return rets - end - else - get_rets = function(rets: TupleType): TupleType - return rets - end + return var end - local function add_internal_function_variables(node: Node, args: TupleType) - add_var(nil, "@is_va", args.is_va and ANY or NIL) - add_var(nil, "@return", node.rets or a_tuple({})) + function TypeChecker:add_internal_function_variables(node: Node, args: TupleType) + self:add_var(nil, "@is_va", a_type(node, args.is_va and "any" or "nil", {})) + self:add_var(nil, "@return", node.rets or a_tuple(node, {})) if node.typeargs then for _, t in ipairs(node.typeargs) do - local v = find_var(t.typearg, "check_only") + local v = self:find_var(t.typearg, "check_only") if not v or not v.used_as_type then - error_at(t, "type argument '%s' is not used in function signature", t) + self.errs:add(t, "type argument '%s' is not used in function signature", t) end end end end - local function add_function_definition_for_recursion(node: Node, fnargs: TupleType) - add_var(nil, node.name.tk, type_at(node, a_function { + function TypeChecker:add_function_definition_for_recursion(node: Node, fnargs: TupleType) + self:add_var(nil, node.name.tk, a_function(node, { min_arity = node.min_arity, typeargs = node.typeargs, args = fnargs, - rets = get_rets(node.rets), + rets = self.get_rets(node.rets), })) end - local function fail_unresolved() - local unresolved = st[#st]["@unresolved"] - if unresolved then - st[#st]["@unresolved"] = nil - local unrt = unresolved.t as UnresolvedType - for name, nodes in pairs(unrt.labels) do - for _, node in ipairs(nodes) do - error_at(node, "no visible label '" .. name .. "' for goto") - end - end - for name, types in pairs(unrt.nominals) do - if not unrt.global_types[name] then - for _, typ in ipairs(types) do - assert(typ.x) - assert(typ.y) - error_at(typ, "unknown type %s", typ) - end - end - end - end - end - - local function end_function_scope(node: Node) - fail_unresolved() - end_scope(node) + function TypeChecker:end_function_scope(node: Node) + self.errs:fail_unresolved_labels(self.st[#self.st]) + self:end_scope(node) end local function flatten_tuple(vals: TupleType): TupleType local vt = vals.tuple local n_vals = #vt - local ret = a_tuple {} + local ret = a_tuple(vals, {}) local rt = ret.tuple if n_vals == 0 then @@ -9333,9 +9363,9 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return ret end - local function get_assignment_values(vals: TupleType, wanted: integer): TupleType + local function get_assignment_values(w: Where, vals: TupleType, wanted: integer): TupleType if vals == nil then - return a_tuple {} + return a_tuple(w, {}) end local ret = flatten_tuple(vals) @@ -9354,14 +9384,14 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return ret end - local function match_all_record_field_names(node: Node, a: RecordLikeType, field_names: {string}, errmsg: string): Type + function TypeChecker:match_all_record_field_names(node: Node, a: RecordLikeType, field_names: {string}, errmsg: string): Type local t: Type for _, k in ipairs(field_names) do local f = a.fields[k] if not t then t = f else - if not same_type(f, t) then + if not self:same_type(f, t) then errmsg = errmsg .. string.format(" (types of fields '%s' and '%s' do not match)", field_names[1], k) t = nil break @@ -9371,26 +9401,26 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if t then return t else - return invalid_at(node, errmsg) + return self.errs:invalid_at(node, errmsg) end end - local function type_check_index(anode: Node, bnode: Node, a: Type, b: Type): Type + function TypeChecker:type_check_index(anode: Node, bnode: Node, a: Type, b: Type): Type assert(not a is TupleType) assert(not b is TupleType) - local ra = resolve_typedecl(to_structural(a)) - local rb = to_structural(b) + local ra = resolve_typedecl(self:to_structural(a)) + local rb = self:to_structural(b) - if lax and is_unknown(a) then - return UNKNOWN + if self.feat_lax and is_unknown(a) then + return a end local errm: string local erra: Type local errb: Type - if ra is TupleTableType and is_a(rb, INTEGER) then + if ra is TupleTableType and rb is IntegerType then if bnode.constnum then if bnode.constnum >= 1 and bnode.constnum <= #ra.types and bnode.constnum == math.floor(bnode.constnum) then return ra.types[bnode.constnum as integer] @@ -9398,38 +9428,35 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string errm, erra = "index " .. tostring(bnode.constnum) .. " out of range for tuple %s", ra else - local array_type = arraytype_from_tuple(bnode, ra) + local array_type = self:arraytype_from_tuple(bnode, ra) if array_type then return array_type.elements end errm = "cannot index this tuple with a variable because it would produce a union type that cannot be discriminated at runtime" end - elseif ra is ArrayLikeType and is_a(rb, INTEGER) then + elseif ra is ArrayLikeType and rb is IntegerType then return ra.elements elseif ra is EmptyTableType then if ra.keys == nil then - ra.keys = infer_at(anode, b) + ra.keys = self:infer_at(bnode, b) end - if is_a(b, ra.keys) then - return type_at(anode, a_type("unresolved_emptytable_value", { + if self:is_a(b, ra.keys) then + return a_type(anode, "unresolved_emptytable_value", { emptytable_type = ra - } as UnresolvedEmptyTableValueType)) + } as UnresolvedEmptyTableValueType) end - errm, erra, errb = "inconsistent index type: got %s, expected %s (type of keys inferred at " - .. ra.keys.inferred_at.filename .. ":" - .. ra.keys.inferred_at.y .. ":" - .. ra.keys.inferred_at.x .. ": )", b, ra.keys + errm, erra, errb = "inconsistent index type: got %s, expected %s" .. inferred_msg(ra.keys, "type of keys "), b, ra.keys elseif ra is MapType then - if is_a(b, ra.keys) then + if self:is_a(b, ra.keys) then return ra.values end errm, erra, errb = "wrong index type: got %s, expected %s", b, ra.keys elseif rb is StringType and rb.literal then - local t, e = match_record_key(a, anode, rb.literal) + local t, e = self:match_record_key(a, anode, rb.literal) if t then return t end @@ -9445,10 +9472,10 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end if not errm then - return match_all_record_field_names(bnode, ra, field_names, + return self:match_all_record_field_names(bnode, ra, field_names, "cannot index, not all enum values map to record fields of the same type") end - elseif is_a(rb, STRING) then + elseif rb is StringType then errm, erra = "cannot index object of type %s with a string, consider using an enum", a else errm, erra, errb = "cannot index object of type %s with %s", a, b @@ -9457,28 +9484,28 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string errm, erra, errb = "cannot index object of type %s with %s", a, b end - local meta_t = check_metamethod(anode, "__index", ra, b, a, b) + local meta_t = self:check_metamethod(anode, "__index", ra, b, a, b) if meta_t then return meta_t end - return invalid_at(bnode, errm, erra, errb) + return self.errs:invalid_at(bnode, errm, erra, errb) end - expand_type = function(where: Where, old: Type, new: Type): Type + function TypeChecker:expand_type(w: Where, old: Type, new: Type): Type if not old or old.typename == "nil" then return new else - if not is_a(new, old) then + if not self:is_a(new, old) then if old is MapType and new is RecordLikeType then local old_keys = old.keys if old_keys is StringType then for _, ftype in fields_of(new) do - old.values = expand_type(where, old.values, ftype) + old.values = self:expand_type(w, old.values, ftype) end - edit_type(old, "map") -- map changed, refresh typeid + edit_type(w, old, "map") -- map changed, refresh typeid else - error_at(where, "cannot determine table literal type") + self.errs:add(w, "cannot determine table literal type") end elseif old is RecordLikeType and new is RecordLikeType then local values: Type @@ -9486,14 +9513,14 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if not values then values = ftype else - values = expand_type(where, values, ftype) + values = self:expand_type(w, values, ftype) end end for _, ftype in fields_of(new) do if not values then values = ftype else - values = expand_type(where, values, ftype) + values = self:expand_type(w, values, ftype) end end old.fields = nil @@ -9501,25 +9528,25 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string old.meta_fields = nil old.meta_fields = nil - edit_type(old, "map") + edit_type(w, old, "map") local map = old as MapType - map.keys = STRING + map.keys = a_type(w, "string", {}) map.values = values elseif old is UnionType then - edit_type(old, "union") + edit_type(w, old, "union") table.insert(old.types, drop_constant_value(new)) else - return unite({ old, new }, true) + return unite(w, { old, new }, true) end end end return old end - local function find_record_to_extend(exp: Node): Type, Variable, string + function TypeChecker:find_record_to_extend(exp: Node): Type, Variable, string -- base if exp.kind == "type_identifier" then - local v = find_var(exp.tk) + local v = self:find_var(exp.tk) if not v then return nil, nil, exp.tk end @@ -9536,7 +9563,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return t, v, exp.tk -- recurse elseif exp.kind == "op" then -- assert(exp.op.op == ".") - local t, v, rname = find_record_to_extend(exp.e1) + local t, v, rname = self:find_record_to_extend(exp.e1) local fname = exp.e2.tk local dname = rname .. "." .. fname if not t then @@ -9557,30 +9584,29 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - local function typedecl_to_nominal(where: Where, name: string, t: TypeDeclType, resolved?: Type): Type + local function typedecl_to_nominal(node: Node, name: string, t: TypeDeclType, resolved?: Type): Type local typevals: {Type} local def = t.def if def is HasTypeArgs then typevals = {} for _, a in ipairs(def.typeargs) do - table.insert(typevals, a_type("typevar", { + table.insert(typevals, a_type(a, "typevar", { typevar = a.typearg, constraint = a.constraint, } as TypeVarType)) end end - return type_at(where, a_type("nominal", { - typevals = typevals, - names = { name }, - found = t, - resolved = resolved, - } as NominalType)) + local nom = a_nominal(node, { name }) + nom.typevals = typevals + nom.found = t + nom.resolved = resolved + return nom end - local function get_self_type(exp: Node): Type + function TypeChecker:get_self_type(exp: Node): Type -- base if exp.kind == "type_identifier" then - local t = find_var_type(exp.tk) + local t = self:find_var_type(exp.tk) if not t then return nil end @@ -9592,7 +9618,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end -- recurse elseif exp.kind == "op" then -- assert(exp.op.op == ".") - local t = get_self_type(exp.e1) + local t = self:get_self_type(exp.e1) if not t then return nil end @@ -9621,10 +9647,9 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end -- Inference engine for 'is' operator - local facts_and: function(where: Where, f1: Fact, f2: Fact): Fact - local facts_or: function(where: Where, f1: Fact, f2: Fact): Fact - local facts_not: function(where: Where, f1: Fact): Fact - local apply_facts: function(where: Where, known: Fact) + local facts_and: function(w: Where, f1: Fact, f2: Fact): Fact + local facts_or: function(w: Where, f1: Fact, f2: Fact): Fact + local facts_not: function(w: Where, f1: Fact): Fact local FACT_TRUTHY: Fact do local IsFact_mt: metatable = { @@ -9636,6 +9661,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string setmetatable(IsFact, { __call = function(_: IsFact, fact: Fact): IsFact fact.fact = "is" + assert(fact.w) return setmetatable(fact as IsFact, IsFact_mt) end, }) @@ -9649,6 +9675,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string setmetatable(EqFact, { __call = function(_: EqFact, fact: Fact): EqFact fact.fact = "==" + assert(fact.w) return setmetatable(fact as EqFact, EqFact_mt) end, }) @@ -9707,57 +9734,57 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string FACT_TRUTHY = TruthyFact {} - facts_and = function(where: Where, f1: Fact, f2: Fact): Fact - return AndFact { f1 = f1, f2 = f2, where = where } + facts_and = function(w: Where, f1: Fact, f2: Fact): Fact + return AndFact { f1 = f1, f2 = f2, w = w } end - facts_or = function(where: Where, f1: Fact, f2: Fact): Fact + facts_or = function(w: Where, f1: Fact, f2: Fact): Fact if f1 and f2 then - return OrFact { f1 = f1, f2 = f2, where = where } + return OrFact { f1 = f1, f2 = f2, w = w } else return nil end end - facts_not = function(where: Where, f1: Fact): Fact + facts_not = function(w: Where, f1: Fact): Fact if f1 then - return NotFact { f1 = f1, where = where } + return NotFact { f1 = f1, w = w } else return nil end end -- t1 ∪ t2 - local function unite_types(t1: Type, t2: Type): Type, string - return unite({t2, t1}) + local function unite_types(w: Where, t1: Type, t2: Type): Type, string + return unite(w, {t2, t1}) end -- t1 ∩ t2 - local function intersect_types(t1: Type, t2: Type): Type, string + local function intersect_types(self: TypeChecker, w: Where, t1: Type, t2: Type): Type, string if t2 is UnionType then t1, t2 = t2, t1 end if t1 is UnionType then local out = {} for _, t in ipairs(t1.types) do - if is_a(t, t2) then + if self:is_a(t, t2) then table.insert(out, t) end end - return unite(out) + return unite(w, out) else - if is_a(t1, t2) then + if self:is_a(t1, t2) then return t1 - elseif is_a(t2, t1) then + elseif self:is_a(t2, t1) then return t2 else - return NIL -- because of implicit nil in all unions + return a_type(w, "nil", {}) -- because of implicit nil in all unions end end end - local function resolve_if_union(t: Type): Type - local rt = to_structural(t) + function TypeChecker:resolve_if_union(t: Type): Type + local rt = self:to_structural(t) if rt is UnionType then return rt end @@ -9765,23 +9792,23 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end -- t1 - t2 - local function subtract_types(t1: Type, t2: Type): Type + local function subtract_types(self: TypeChecker, w: Where, t1: Type, t2: Type): Type local types: {Type} = {} - t1 = resolve_if_union(t1) + t1 = self:resolve_if_union(t1) -- poly are not first-class, so we don't handle them here if not t1 is UnionType then return t1 end - t2 = resolve_if_union(t2) + t2 = self:resolve_if_union(t2) local t2types = t2 is UnionType and t2.types or { t2 } for _, at in ipairs(t1.types) do local not_present = true for _, bt in ipairs(t2types) do - if same_type(at, bt) then + if self:same_type(at, bt) then not_present = false break end @@ -9792,78 +9819,78 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end if #types == 0 then - return NIL -- because of implicit nil in all unions + return a_type(w, "nil", {}) -- because of implicit nil in all unions end - return unite(types) + return unite(w, types) end - local eval_not: function(f: Fact): {string:IsFact|EqFact} - local not_facts: function(fs: {string:IsFact|EqFact}): {string:IsFact|EqFact} - local or_facts: function(fs1: {string:IsFact|EqFact}, fs2: {string:IsFact|EqFact}): {string:IsFact|EqFact} - local and_facts: function(fs1: {string:IsFact|EqFact}, fs2: {string:IsFact|EqFact}): {string:IsFact|EqFact} - local eval_fact: function(f: Fact): {string:IsFact|EqFact} + local eval_not: function(TypeChecker, f: Fact): {string:IsFact|EqFact} + local not_facts: function(TypeChecker, fs: {string:IsFact|EqFact}): {string:IsFact|EqFact} + local or_facts: function(TypeChecker, fs1: {string:IsFact|EqFact}, fs2: {string:IsFact|EqFact}): {string:IsFact|EqFact} + local and_facts: function(TypeChecker, fs1: {string:IsFact|EqFact}, fs2: {string:IsFact|EqFact}): {string:IsFact|EqFact} + local eval_fact: function(TypeChecker, f: Fact): {string:IsFact|EqFact} local function invalid_from(f: IsFact): IsFact - return IsFact { fact = "is", var = f.var, typ = INVALID, where = f.where } + return IsFact { fact = "is", var = f.var, typ = a_type(f.w, "invalid", {}), w = f.w } end - not_facts = function(fs: {string:IsFact|EqFact}): {string:IsFact|EqFact} + not_facts = function(self: TypeChecker, fs: {string:IsFact|EqFact}): {string:IsFact|EqFact} local ret: {string:IsFact|EqFact} = {} for var, f in pairs(fs) do - local typ = find_var_type(f.var, "check_only") + local typ = self:find_var_type(f.var, "check_only") if not typ then - ret[var] = EqFact { var = var, typ = INVALID, where = f.where } + ret[var] = EqFact { var = var, typ = an_invalid(f.w), w = f.w, no_infer = f.no_infer } elseif f is EqFact then -- nothing is known from negation of equality; widen back - ret[var] = EqFact { var = var, typ = typ } - elseif typ.typename == "typevar" then + ret[var] = EqFact { var = var, typ = typ, w = f.w, no_infer = true } + elseif typ is TypeVarType then assert(f.fact == "is") - -- nothing is known from negation on typeargs; widen back (no 'where') - ret[var] = EqFact { var = var, typ = typ } - elseif not is_a(f.typ, typ) then + -- nothing is known from negation on typeargs; widen back + ret[var] = EqFact { var = var, typ = typ, w = f.w, no_infer = true } + elseif not self:is_a(f.typ, typ) then assert(f.fact == "is") - add_warning("branch", f.where, f.var .. " (of type %s) can never be a %s", show_type(typ), show_type(f.typ)) - ret[var] = EqFact { var = var, typ = INVALID, where = f.where } + self.errs:add_warning("branch", f.w, f.var .. " (of type %s) can never be a %s", show_type(typ), show_type(f.typ)) + ret[var] = EqFact { var = var, typ = an_invalid(f.w), w = f.w, no_infer = f.no_infer } else assert(f.fact == "is") - ret[var] = IsFact { var = var, typ = subtract_types(typ, f.typ), where = f.where } + ret[var] = IsFact { var = var, typ = subtract_types(self, f.w, typ, f.typ), w = f.w, no_infer = f.no_infer } end end return ret end - eval_not = function(f: Fact): {string:IsFact|EqFact} + eval_not = function(self: TypeChecker, f: Fact): {string:IsFact|EqFact} if not f then return {} elseif f is IsFact then - return not_facts({[f.var] = f}) + return not_facts(self, {[f.var] = f}) elseif f is NotFact then - return eval_fact(f.f1) + return eval_fact(self, f.f1) elseif f is AndFact and f.f2 and f.f2.fact == "truthy" then - return eval_not(f.f1) + return eval_not(self, f.f1) elseif f is OrFact and f.f2 and f.f2.fact == "truthy" then - return eval_fact(f.f1) + return eval_fact(self, f.f1) elseif f is AndFact then - return or_facts(not_facts(eval_fact(f.f1)), not_facts(eval_fact(f.f2))) + return or_facts(self, not_facts(self, eval_fact(self, f.f1)), not_facts(self, eval_fact(self, f.f2))) elseif f is OrFact then - return and_facts(not_facts(eval_fact(f.f1)), not_facts(eval_fact(f.f2))) + return and_facts(self, not_facts(self, eval_fact(self, f.f1)), not_facts(self, eval_fact(self, f.f2))) else - return not_facts(eval_fact(f)) + return not_facts(self, eval_fact(self, f)) end end - or_facts = function(fs1: {string:IsFact|EqFact}, fs2: {string:IsFact|EqFact}): {string:IsFact|EqFact} + or_facts = function(_self: TypeChecker, fs1: {string:IsFact|EqFact}, fs2: {string:IsFact|EqFact}): {string:IsFact|EqFact} local ret: {string:IsFact|EqFact} = {} for var, f in pairs(fs2) do if fs1[var] then - local united = unite_types(f.typ, fs1[var].typ) + local united = unite_types(f.w, f.typ, fs1[var].typ) if fs1[var].fact == "is" and f.fact == "is" then - ret[var] = IsFact { var = var, typ = united, where = f.where } + ret[var] = IsFact { var = var, typ = united, w = f.w } else - ret[var] = EqFact { var = var, typ = united, where = f.where } + ret[var] = EqFact { var = var, typ = united, w = f.w } end end end @@ -9871,7 +9898,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return ret end - and_facts = function(fs1: {string:IsFact|EqFact}, fs2: {string:IsFact|EqFact}): {string:IsFact|EqFact} + and_facts = function(self: TypeChecker, fs1: {string:IsFact|EqFact}, fs2: {string:IsFact|EqFact}): {string:IsFact|EqFact} local ret: {string:IsFact|EqFact} = {} local has: {FactType:boolean} = {} @@ -9882,18 +9909,18 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if fs2[var].fact == "is" and f.fact == "is" then ctor = IsFact end - rt = intersect_types(f.typ, fs2[var].typ) + rt = intersect_types(self, f.w, f.typ, fs2[var].typ) else rt = f.typ end - local ff = ctor { var = var, typ = rt, where = f.where } + local ff = ctor { var = var, typ = rt, w = f.w, no_infer = f.no_infer } ret[var] = ff has[ff.fact] = true end for var, f in pairs(fs2) do if not fs1[var] then - ret[var] = EqFact { var = var, typ = f.typ, where = f.where } + ret[var] = EqFact { var = var, typ = f.typ, w = f.w, no_infer = f.no_infer } has["=="] = true end end @@ -9907,21 +9934,21 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return ret end - eval_fact = function(f: Fact): {string:IsFact|EqFact} + eval_fact = function(self: TypeChecker, f: Fact): {string:IsFact|EqFact} if not f then return {} elseif f is IsFact then - local typ = find_var_type(f.var, "check_only") + local typ = self:find_var_type(f.var, "check_only") if not typ then return { [f.var] = invalid_from(f) } end if typ.typename ~= "typevar" then - if is_a(typ, f.typ) then + if self:is_a(typ, f.typ) then -- drop this warning because of implicit nil in all unions - -- add_warning("branch", f.where, f.var .. " (of type %s) is always a %s", show_type(typ), show_type(f.typ)) + -- self.errs:add_warning("branch", f.w, f.var .. " (of type %s) is always a %s", show_type(typ), show_type(f.typ)) return { [f.var] = f } - elseif not is_a(f.typ, typ) then - error_at(f.where, f.var .. " (of type %s) can never be a %s", typ, f.typ) + elseif not self:is_a(f.typ, typ) then + self.errs:add(f.w, f.var .. " (of type %s) can never be a %s", typ, f.typ) return { [f.var] = invalid_from(f) } end end @@ -9929,63 +9956,60 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string elseif f is EqFact then return { [f.var] = f } elseif f is NotFact then - return eval_not(f.f1) + return eval_not(self, f.f1) elseif f is TruthyFact then return {} elseif f is AndFact and f.f2 and f.f2.fact == "truthy" then - return eval_fact(f.f1) + return eval_fact(self, f.f1) elseif f is OrFact and f.f2 and f.f2.fact == "truthy" then - return eval_not(f.f1) + return eval_not(self, f.f1) elseif f is AndFact then - return and_facts(eval_fact(f.f1), eval_fact(f.f2)) + return and_facts(self, eval_fact(self, f.f1), eval_fact(self, f.f2)) elseif f is OrFact then - return or_facts(eval_fact(f.f1), eval_fact(f.f2)) + return or_facts(self, eval_fact(self, f.f1), eval_fact(self, f.f2)) end end - apply_facts = function(where: Where, known: Fact) + function TypeChecker:apply_facts(w: Where, known: Fact) if not known then return end - local facts = eval_fact(known) + local facts = eval_fact(self, known) for v, f in pairs(facts) do if f.typ.typename == "invalid" then - error_at(where, "cannot resolve a type for " .. v .. " here") + self.errs:add(w, "cannot resolve a type for " .. v .. " here") end - local t = infer_at(where, f.typ) - if not f.where then + local t = f.no_infer and f.typ or self:infer_at(w, f.typ) + if f.no_infer then t.inferred_at = nil end - add_var(nil, v, t, "const", "narrow") + self:add_var(nil, v, t, "const", "narrow") end end end - local function dismiss_unresolved(name: string) - for i = #st, 1, -1 do - local unresolved = find_unresolved(i) - if unresolved then - local uses = unresolved.nominals[name] - if uses then - for _, t in ipairs(uses) do - resolve_nominal(t) - end - unresolved.nominals[name] = nil - return + function TypeChecker:dismiss_unresolved(name: string) + for i = #self.st, 1, -1 do + local scope = self.st[i] + local uses = scope.pending_nominals and scope.pending_nominals[name] + if uses then + for _, t in ipairs(uses) do + self:resolve_nominal(t) end + scope.pending_nominals[name] = nil + return end end end - local type_check_funcall: function(node: Node, a: Type, b: Type, argdelta?: integer): InvalidOrTupleType - - local function special_pcall_xpcall(node: Node, _a: Type, b: TupleType, argdelta: integer): Type + local function special_pcall_xpcall(self: TypeChecker, node: Node, _a: Type, b: TupleType, argdelta: integer): Type local base_nargs = (node.e1.tk == "xpcall") and 2 or 1 + local bool = a_type(node, "boolean", {}) if #node.e2 < base_nargs then - error_at(node, "wrong number of arguments (given " .. #node.e2 .. ", expects at least " .. base_nargs .. ")") - return a_tuple { BOOLEAN } + self.errs:add(node, "wrong number of arguments (given " .. #node.e2 .. ", expects at least " .. base_nargs .. ")") + return a_tuple(node, { bool }) end -- The function called by pcall/xpcall is invoked as a regular function, @@ -9997,137 +10021,142 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string ftype.is_method = false end - local fe2: Node = {} + local fe2: Node = node_at(node.e2, {}) if node.e1.tk == "xpcall" then base_nargs = 2 + local arg2 = node.e2[2] local msgh = table.remove(b.tuple, 1) - assert_is_a(node.e2[2], msgh, XPCALL_MSGH_FUNCTION, "in message handler") + local msgh_type = a_function(arg2, { + min_arity = 1, + args = a_tuple(arg2, { a_type(arg2, "any", {}) }), + rets = a_tuple(arg2, {}) + }) + self:assert_is_a(arg2, msgh, msgh_type, "in message handler") end for i = base_nargs + 1, #node.e2 do table.insert(fe2, node.e2[i]) end - local fnode: Node = { - y = node.y, - x = node.x, + local fnode: Node = node_at(node, { kind = "op", op = { op = "@funcall" }, e1 = node.e2[1], e2 = fe2, - } - local rets = type_check_funcall(fnode, ftype, b, argdelta + base_nargs) + }) + local rets = self:type_check_funcall(fnode, ftype, b, argdelta + base_nargs) if rets is InvalidType then return rets end - table.insert(rets.tuple, 1, BOOLEAN) + table.insert(rets.tuple, 1, bool) return rets end - local special_functions: {string : function(Node,Type,TupleType,integer):InvalidOrTupleType } = { - ["pairs"] = function(node: Node, a: Type, b: TupleType, argdelta: integer): InvalidOrTupleType + local special_functions: {string : function(TypeChecker, Node,Type,TupleType,integer):InvalidOrTupleType } = { + ["pairs"] = function(self: TypeChecker, node: Node, a: Type, b: TupleType, argdelta: integer): InvalidOrTupleType if not b.tuple[1] then - return invalid_at(node, "pairs requires an argument") + return self.errs:invalid_at(node, "pairs requires an argument") end - local t = to_structural(b.tuple[1]) + local t = self:to_structural(b.tuple[1]) if t is ArrayLikeType then - add_warning("hint", node, "hint: applying pairs on an array: did you intend to apply ipairs?") + self.errs:add_warning("hint", node, "hint: applying pairs on an array: did you intend to apply ipairs?") end if t.typename ~= "map" then - if not (lax and is_unknown(t)) then + if not (self.feat_lax and is_unknown(t)) then if t is RecordLikeType then - match_all_record_field_names(node.e2, t, t.field_order, + self:match_all_record_field_names(node.e2, t, t.field_order, "attempting pairs on a record with attributes of different types") local ct = t.typename == "record" and "{string:any}" or "{any:any}" - add_warning("hint", node.e2, "hint: if you want to iterate over fields of a record, cast it to " .. ct) + self.errs:add_warning("hint", node.e2, "hint: if you want to iterate over fields of a record, cast it to " .. ct) else - error_at(node.e2, "cannot apply pairs on values of type: %s", t) + self.errs:add(node.e2, "cannot apply pairs on values of type: %s", t) end end end - return (type_check_function_call(node, a, b, argdelta)) + return (self:type_check_function_call(node, a, b, argdelta)) end, - ["ipairs"] = function(node: Node, a: Type, b: TupleType, argdelta: integer): InvalidOrTupleType + ["ipairs"] = function(self: TypeChecker, node: Node, a: Type, b: TupleType, argdelta: integer): InvalidOrTupleType if not b.tuple[1] then - return invalid_at(node, "ipairs requires an argument") + return self.errs:invalid_at(node, "ipairs requires an argument") end local orig_t = b.tuple[1] - local t = to_structural(orig_t) + local t = self:to_structural(orig_t) if t is TupleTableType then - local arr_type = arraytype_from_tuple(node.e2, t) + local arr_type = self:arraytype_from_tuple(node.e2, t) if not arr_type then - return invalid_at(node.e2, "attempting ipairs on tuple that's not a valid array: %s", orig_t) + return self.errs:invalid_at(node.e2, "attempting ipairs on tuple that's not a valid array: %s", orig_t) end elseif not t is ArrayLikeType then - if not (lax and (is_unknown(t) or t is EmptyTableType)) then - return invalid_at(node.e2, "attempting ipairs on something that's not an array: %s", orig_t) + if not (self.feat_lax and (is_unknown(t) or t is EmptyTableType)) then + return self.errs:invalid_at(node.e2, "attempting ipairs on something that's not an array: %s", orig_t) end end - return (type_check_function_call(node, a, b, argdelta)) + return (self:type_check_function_call(node, a, b, argdelta)) end, - ["rawget"] = function(node: Node, _a: Type, b: TupleType, _argdelta: integer): InvalidOrTupleType + ["rawget"] = function(self: TypeChecker, node: Node, _a: Type, b: TupleType, _argdelta: integer): InvalidOrTupleType -- TODO should those offsets be fixed by _argdelta? if #b.tuple == 2 then - return a_tuple({ type_check_index(node.e2[1], node.e2[2], b.tuple[1], b.tuple[2]) }) + return a_tuple(node, { self:type_check_index(node.e2[1], node.e2[2], b.tuple[1], b.tuple[2]) }) else - return invalid_at(node, "rawget expects two arguments") + return self.errs:invalid_at(node, "rawget expects two arguments") end end, - ["require"] = function(node: Node, _a: Type, b: TupleType, _argdelta: integer): InvalidOrTupleType + ["require"] = function(self: TypeChecker, node: Node, _a: Type, b: TupleType, _argdelta: integer): InvalidOrTupleType if #b.tuple ~= 1 then - return invalid_at(node, "require expects one literal argument") + return self.errs:invalid_at(node, "require expects one literal argument") end if node.e2[1].kind ~= "string" then - return a_tuple({ a_type("any", {}) }) + return a_tuple(node, { a_type(node, "any", {}) }) end local module_name = assert(node.e2[1].conststr) - local t, found = require_module(module_name, lax, env) - if not found then - return invalid_at(node, "module not found: '" .. module_name .. "'") - end + local t, module_filename = require_module(node, module_name, self.feat_lax, self.env) if t.typename == "invalid" then - if lax then - return a_tuple({ UNKNOWN }) + if not module_filename then + return self.errs:invalid_at(node, "module not found: '" .. module_name .. "'") + end + + if self.feat_lax then + return a_tuple(node, { an_unknown(node) }) end - return invalid_at(node, "no type information for required module: '" .. module_name .. "'") + return self.errs:invalid_at(node, "no type information for required module: '" .. module_name .. "'") end - dependencies[module_name] = t.filename - return type_at(node, a_tuple({ t })) + self.dependencies[module_name] = module_filename + return a_tuple(node, { t }) end, ["pcall"] = special_pcall_xpcall, ["xpcall"] = special_pcall_xpcall, - ["assert"] = function(node: Node, a: Type, b: TupleType, argdelta: integer): InvalidOrTupleType + ["assert"] = function(self: TypeChecker, node: Node, a: Type, b: TupleType, argdelta: integer): InvalidOrTupleType node.known = FACT_TRUTHY - local r = type_check_function_call(node, a, b, argdelta) - apply_facts(node, node.e2[1].known) + local r = self:type_check_function_call(node, a, b, argdelta) + self:apply_facts(node, node.e2[1].known) return r end, } - type_check_funcall = function(node: Node, a: Type, b: TupleType, argdelta?: integer): InvalidOrTupleType + function TypeChecker:type_check_funcall(node: Node, a: Type, b: TupleType, argdelta?: integer): InvalidOrTupleType argdelta = argdelta or 0 if node.e1.kind == "variable" then local special = special_functions[node.e1.tk] if special then - return special(node, a, b, argdelta) + return special(self, node, a, b, argdelta) else - return (type_check_function_call(node, a, b, argdelta)) + return (self:type_check_function_call(node, a, b, argdelta)) end elseif node.e1.op and node.e1.op.op == ":" then table.insert(b.tuple, 1, node.e1.receiver) - return (type_check_function_call(node, a, b, -1)) + return (self:type_check_function_call(node, a, b, -1)) else - return (type_check_function_call(node, a, b, argdelta)) + return (self:type_check_function_call(node, a, b, argdelta)) end end @@ -10139,19 +10168,19 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string and node.exps[i].tk == node.vars[i].tk end - local function missing_initializer(node: Node, i: integer, name: string): Type - if lax then - return UNKNOWN + function TypeChecker:missing_initializer(node: Node, i: integer, name: string): (InvalidType | UnknownType) + if self.feat_lax then + return an_unknown(node) else if node.exps then - return invalid_at(node.vars[i], "assignment in declaration did not produce an initial value for variable '" .. name .. "'") + return self.errs:invalid_at(node.vars[i], "assignment in declaration did not produce an initial value for variable '" .. name .. "'") else - return invalid_at(node.vars[i], "variable '" .. name .. "' has no type or initial value") + return self.errs:invalid_at(node.vars[i], "variable '" .. name .. "' has no type or initial value") end end end - local function set_expected_types_to_decltuple(node: Node, children: {Type}) + local function set_expected_types_to_decltuple(_: TypeChecker, node: Node, children: {Type}) local decltuple = node.kind == "assignment" and children[1] or node.decltuple assert(decltuple is TupleType) local decls = decltuple.tuple @@ -10163,7 +10192,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string typ = decls[i] if typ then if i == nexps and ndecl > nexps then - typ = type_at(node, a_tuple {}) + typ = a_tuple(node, {}) for a = i, ndecl do table.insert(typ.tuple, decls[a]) end @@ -10179,38 +10208,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return n and n >= 1 and math.floor(n) == n end - local context_name: {NodeKind: string} = { - ["local_declaration"] = "in local declaration", - ["global_declaration"] = "in global declaration", - ["assignment"] = "in assignment", - } - - local function in_context(ctx: Node.ExpectedContext, msg: string): string - if not ctx then - return msg - end - local where = context_name[ctx.kind] - if where then - return where .. ": " .. (ctx.name and ctx.name .. ": " or "") .. msg - else - return msg - end - end - - local type CheckableKey = string | number | boolean - - local function check_redeclared_key(where: Where, ctx: Node.ExpectedContext, seen_keys: {CheckableKey:Where}, key: CheckableKey) - if key ~= nil then - local s = seen_keys[key] - if s then - error_at(where, in_context(ctx, "redeclared key " .. tostring(key) .. " (previously declared at " .. filename .. ":" .. s.y .. ":" .. s.x .. ")")) - else - seen_keys[key] = where - end - end - end - - local function infer_table_literal(node: Node, children: {LiteralTableItemType}): Type + local function infer_table_literal(self: TypeChecker, node: Node, children: {LiteralTableItemType}): Type local is_record = false local is_array = false local is_map = false @@ -10235,14 +10233,15 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string for i, child in ipairs(children) do local ck = child.kname + local cktype = child.ktype local n = node[i].key.constnum local b: boolean = nil - if child.ktype.typename == "boolean" then + if cktype is BooleanType then b = (node[i].key.tk == "true") end local key: CheckableKey = ck or n or b - check_redeclared_key(node[i], nil, seen_keys, key) + self.errs:check_redeclared_key(node[i], nil, seen_keys, key) local uvtype = resolve_tuple(child.vtype) if ck then @@ -10253,7 +10252,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end fields[ck] = uvtype table.insert(field_order, ck) - elseif is_number_type(child.ktype) then + elseif cktype is NumericType then is_array = true if not is_not_tuple then is_tuple = true @@ -10267,25 +10266,25 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if i == #children and cv is TupleType then -- need to expand last item in an array (e.g { 1, 2, 3, f() }) for _, c in ipairs(cv.tuple) do - elements = expand_type(node, elements, c) + elements = self:expand_type(node, elements, c) types[last_array_idx] = resolve_tuple(c) last_array_idx = last_array_idx + 1 end else types[last_array_idx] = uvtype last_array_idx = last_array_idx + 1 - elements = expand_type(node, elements, uvtype) + elements = self:expand_type(node, elements, uvtype) end else -- explicit if not is_positive_int(n) then - elements = expand_type(node, elements, uvtype) + elements = self:expand_type(node, elements, uvtype) is_not_tuple = true elseif n then types[n as integer] = uvtype if n > largest_array_idx then largest_array_idx = n as integer end - elements = expand_type(node, elements, uvtype) + elements = self:expand_type(node, elements, uvtype) end end @@ -10297,37 +10296,37 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end else is_map = true - keys = expand_type(node, keys, drop_constant_value(child.ktype)) - values = expand_type(node, values, uvtype) + keys = self:expand_type(node, keys, drop_constant_value(cktype)) + values = self:expand_type(node, values, uvtype) end end local t: Type if is_array and is_map then - error_at(node, "cannot determine type of table literal") - t = a_map( - expand_type(node, keys, INTEGER), - expand_type(node, values, elements) + self.errs:add(node, "cannot determine type of table literal") + t = a_map(node, + self:expand_type(node, keys, a_type(node, "integer", {})), + self:expand_type(node, values, elements) ) elseif is_record and is_array then - t = a_type("record", { + t = a_type(node, "record", { fields = fields, field_order = field_order, elements = elements, interface_list = { - type_at(node, an_array(elements)) + an_array(node, elements) } } as RecordType) - -- TODO adopt logic from is_array below when we accept tupletable as an interface + -- TODO adopt logic from self:is_array below when we accept tupletable as an interface elseif is_record and is_map then if keys is StringType then for _, fname in ipairs(field_order) do - values = expand_type(node, values, fields[fname]) + values = self:expand_type(node, values, fields[fname]) end - t = a_map(keys, values) + t = a_map(node, keys, values) else - error_at(node, "cannot determine type of table literal") + self.errs:add(node, "cannot determine type of table literal") end elseif is_array then local pure_array = true @@ -10335,7 +10334,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local last_t: Type for _, current_t in pairs(types as {integer:Type}) do if last_t then - if not same_type(last_t, current_t) then + if not self:same_type(last_t, current_t) then pure_array = false break end @@ -10344,69 +10343,70 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end if pure_array then - t = an_array(elements) + t = an_array(node, elements) t.consttypes = types t.inferred_len = largest_array_idx - 1 else - t = a_type("tupletable", {}) as TupleTableType + t = a_type(node, "tupletable", { inferred_at = node }) as TupleTableType t.types = types end elseif is_record then - t = a_type("record", { + t = a_type(node, "record", { fields = fields, field_order = field_order, } as RecordType) elseif is_map then - t = a_map(keys, values) + t = a_map(node, keys, values) elseif is_tuple then - t = a_type("tupletable", {}) as TupleTableType + t = a_type(node, "tupletable", { inferred_at = node }) as TupleTableType t.types = types if not types or #types == 0 then - error_at(node, "cannot determine type of tuple elements") + self.errs:add(node, "cannot determine type of tuple elements") end end if not t then - t = a_type("emptytable", {}) + t = a_type(node, "emptytable", {}) end return type_at(node, t) end - local function infer_negation_of_if_blocks(where: Where, ifnode: Node, n: integer) - local f = facts_not(where, ifnode.if_blocks[1].exp.known) + function TypeChecker:infer_negation_of_if_blocks(w: Where, ifnode: Node, n: integer) + local f = facts_not(w, ifnode.if_blocks[1].exp.known) for e = 2, n do local b = ifnode.if_blocks[e] if b.exp then - f = facts_and(where, f, facts_not(where, b.exp.known)) + f = facts_and(w, f, facts_not(w, b.exp.known)) end end - apply_facts(where, f) + self:apply_facts(w, f) end - local function determine_declaration_type(var: Node, node: Node, infertypes: TupleType, i: integer): boolean, Type, boolean + function TypeChecker:determine_declaration_type(var: Node, node: Node, infertypes: TupleType, i: integer): boolean, Type, boolean local ok = true local name = var.tk local infertype = infertypes and infertypes.tuple[i] - if lax and infertype and infertype.typename == "nil" then + if self.feat_lax and infertype and infertype.typename == "nil" then infertype = nil end local decltype = node.decltuple and node.decltuple.tuple[i] if decltype then - if to_structural(decltype) == INVALID then - decltype = INVALID + local rdecltype = self:to_structural(decltype) + if rdecltype is InvalidType then + decltype = rdecltype end if infertype then - ok = assert_is_a(node.vars[i], infertype, decltype, context_name[node.kind], name) + local w = node.exps and node.exps[i] or node.vars[i] + ok = self:assert_is_a(w, infertype, decltype, context_name[node.kind], name) end else if infertype then if infertype is UnresolvableTypeArgType then - error_at(node.vars[i], "cannot infer declaration type; an explicit type annotation is necessary") ok = false - infertype = INVALID + infertype = self.errs:invalid_at(node.vars[i], "cannot infer declaration type; an explicit type annotation is necessary") elseif infertype is FunctionType and infertype.is_method then -- If we assign a method to a variable, e.g: -- `local myfunc = myobj.dothing`, @@ -10418,17 +10418,17 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end if var.attribute == "total" then - local rd = decltype and to_structural(decltype) + local rd = decltype and self:to_structural(decltype) if rd and (rd.typename ~= "map" and rd.typename ~= "record") then - error_at(var, "attribute only applies to maps and records") + self.errs:add(var, "attribute only applies to maps and records") ok = false elseif not infertype then - error_at(var, "variable declared does not declare an initialization value") + self.errs:add(var, "variable declared does not declare an initialization value") ok = false else local valnode = node.exps[i] if not valnode or valnode.kind ~= "literal_table" then - error_at(var, "attribute only applies to literal tables") + self.errs:add(var, "attribute only applies to literal tables") ok = false else if not valnode.is_total then @@ -10436,12 +10436,12 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if valnode.missing then missing = " (missing: " .. table.concat(valnode.missing, ", ") .. ")" end - local ri = to_structural(infertype) + local ri = self:to_structural(infertype) if ri is MapType then - error_at(var, "map variable declared does not declare values for all possible keys" .. missing) + self.errs:add(var, "map variable declared does not declare values for all possible keys" .. missing) ok = false elseif ri is RecordType then - error_at(var, "record variable declared does not declare values for all fields" .. missing) + self.errs:add(var, "record variable declared does not declare values for all fields" .. missing) ok = false end end @@ -10451,34 +10451,36 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local t = decltype or infertype if t == nil then - t = missing_initializer(node, i, name) + t = self:missing_initializer(node, i, name) elseif t is EmptyTableType then t.declared_at = node t.assigned_to = name elseif t is ArrayLikeType then t.inferred_len = nil + elseif t is NominalType then + self:resolve_nominal(t) end return ok, t, infertype ~= nil end - local function get_typedecl(value: Node): TypeDeclType, Variable + function TypeChecker:get_typedecl(value: Node): TypeDeclType, Variable if value.kind == "op" and value.op.op == "@funcall" and value.e1.kind == "variable" and value.e1.tk == "require" then - local t = special_functions["require"](value, find_var_type("require"), a_tuple { STRING }, 0) + local t = special_functions["require"](self, value, self:find_var_type("require"), a_tuple(value.e2, { a_type(value.e2[1], "string", {}) }), 0) local ty = t is TupleType and t.tuple[1] or t - ty = (ty is TypeAliasType) and resolve_typealias(ty) or ty - local td = (ty is TypeDeclType) and ty or a_type("typedecl", { def = ty } as TypeDeclType) + ty = (ty is TypeAliasType) and self:resolve_typealias(ty) or ty + local td = (ty is TypeDeclType) and ty or a_type(value, "typedecl", { def = ty } as TypeDeclType) return td else local newtype = value.newtype if newtype is TypeAliasType then - local aliasing = find_var(newtype.alias_to.names[1], "use_type") - return resolve_typealias(newtype), aliasing - else + local aliasing = self:find_var(newtype.alias_to.names[1], "use_type") + return self:resolve_typealias(newtype), aliasing + elseif newtype is TypeDeclType then return newtype, nil end end @@ -10509,15 +10511,14 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return is_total, missing end - local function total_map_check(t: MapType, seen_keys: {CheckableKey:Where}): boolean, {string} - local k = to_structural(t.keys) + local function total_map_check(keys: Type, seen_keys: {CheckableKey:Where}): boolean, {string} local is_total = true local missing: {string} - if k is EnumType then - for _, key in ipairs(sorted_keys(k.enumset)) do + if keys is EnumType then + for _, key in ipairs(sorted_keys(keys.enumset)) do is_total, missing = total_check_key(key, seen_keys, is_total, missing) end - elseif k.typename == "boolean" then + elseif keys.typename == "boolean" then for _, key in ipairs({ true, false }) do is_total, missing = total_check_key(key, seen_keys, is_total, missing) end @@ -10531,35 +10532,38 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string "missing" end - local function check_assignment(where: Where, vartype: Type, valtype: Type, varname: string, attr: Attribute): Type, Type, MissingError + function TypeChecker:check_assignment(varnode: Node, vartype: Type, valtype: Type): Type, Type, MissingError + local varname = varnode.tk + local attr = varnode.attribute + if varname then - if widen_back_var(varname) then - vartype, attr = find_var_type(varname) + if self:widen_back_var(varname) then + vartype, attr = self:find_var_type(varname) if not vartype then - error_at(where, "unknown variable") + self.errs:add(varnode, "unknown variable") return nil end end end if attr == "close" or attr == "const" or attr == "total" then - error_at(where, "cannot assign to <" .. attr .. "> variable") + self.errs:add(varnode, "cannot assign to <" .. attr .. "> variable") return nil end - local var = to_structural(vartype) + local var = self:to_structural(vartype) if var is TypeDeclType or var is TypeAliasType then - error_at(where, "cannot reassign a type") + self.errs:add(varnode, "cannot reassign a type") return nil end if not valtype then - error_at(where, "variable is not being assigned a value") + self.errs:add(varnode, "variable is not being assigned a value") return nil, nil, "missing" end - assert_is_a(where, valtype, vartype, "in assignment") + self:assert_is_a(varnode, valtype, vartype, "in assignment") - local val = to_structural(valtype) + local val = self:to_structural(valtype) return var, val end @@ -10571,185 +10575,186 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return resolve_tuple(t) end - local visit_node: Visitor = {} + local visit_node: Visitor = {} visit_node.cbs = { ["statements"] = { - before = function(node: Node) - begin_scope(node) + before = function(self: TypeChecker, node: Node) + self:begin_scope(node) end, - after = function(node: Node, _children: {Type}): Type + after = function(self: TypeChecker, node: Node, _children: {Type}): Type -- if at the top level - if #st == 2 then - fail_unresolved() + if #self.st == 2 then + self.errs:fail_unresolved_labels(self.st[2]) + self.errs:fail_unresolved_nominals(self.st[2], self.st[1]) end if not node.is_repeat then - end_scope(node) + self:end_scope(node) end - -- TODO extract node type from `return` + return NONE end }, ["local_type"] = { - before = function(node: Node) + before = function(self: TypeChecker, node: Node) local name = node.var.tk - local resolved, aliasing = get_typedecl(node.value) - local var = add_var(node.var, name, resolved, node.var.attribute) + local resolved, aliasing = self:get_typedecl(node.value) + local var = self:add_var(node.var, name, resolved, node.var.attribute) if aliasing then var.aliasing = aliasing end end, - after = function(node: Node, _children: {Type}): Type - dismiss_unresolved(node.var.tk) + after = function(self: TypeChecker, node: Node, _children: {Type}): Type + self:dismiss_unresolved(node.var.tk) return NONE end, }, ["global_type"] = { - before = function(node: Node) + before = function(self: TypeChecker, node: Node) + local global_scope = self.st[1] local name = node.var.tk - local unresolved = get_unresolved() if node.value then - local resolved, aliasing = get_typedecl(node.value) - local added = add_global(node.var, name, resolved) + local resolved, aliasing = self:get_typedecl(node.value) + local added = self:add_global(node.var, name, resolved) node.value.newtype = resolved if aliasing then added.aliasing = aliasing end - if added and unresolved.global_types[name] then - unresolved.global_types[name] = nil + if global_scope.pending_global_types[name] then + global_scope.pending_global_types[name] = nil end else - if not st[1][name] then - unresolved.global_types[name] = true + if not self.st[1].vars[name] then + global_scope.pending_global_types[name] = true end end end, - after = function(node: Node, _children: {Type}): Type - dismiss_unresolved(node.var.tk) + after = function(self: TypeChecker, node: Node, _children: {Type}): Type + self:dismiss_unresolved(node.var.tk) return NONE end, }, ["local_declaration"] = { - before = function(node: Node) - if tc then + before = function(self: TypeChecker, node: Node) + if self.collector then for _, var in ipairs(node.vars) do - tc.reserve_symbol_list_slot(var) + self.collector.reserve_symbol_list_slot(var) end end end, before_exp = set_expected_types_to_decltuple, - after = function(node: Node, children: {Type}): Type + after = function(self: TypeChecker, node: Node, children: {Type}): Type local valtuple = children[3] as TupleType -- may be nil local encountered_close = false - local infertypes = get_assignment_values(valtuple, #node.vars) + local infertypes = get_assignment_values(node, valtuple, #node.vars) for i, var in ipairs(node.vars) do if var.attribute == "close" then - if opts.gen_target == "5.4" then + if self.gen_target == "5.4" then if encountered_close then - error_at(var, "only one per declaration is allowed") + self.errs:add(var, "only one per declaration is allowed") else encountered_close = true end else - error_at(var, " attribute is only valid for Lua 5.4 (current target is " .. tostring(opts.gen_target) .. ")") + self.errs:add(var, " attribute is only valid for Lua 5.4 (current target is " .. tostring(self.gen_target) .. ")") end end - local ok, t = determine_declaration_type(var, node, infertypes, i) + local ok, t = self:determine_declaration_type(var, node, infertypes, i) if var.attribute == "close" then if not type_is_closable(t) then - error_at(var, "to-be-closed variable " .. var.tk .. " has a non-closable type %s", t) + self.errs:add(var, "to-be-closed variable " .. var.tk .. " has a non-closable type %s", t) elseif node.exps and node.exps[i] and expr_is_definitely_not_closable(node.exps[i]) then - error_at(var, "to-be-closed variable " .. var.tk .. " assigned a non-closable value") + self.errs:add(var, "to-be-closed variable " .. var.tk .. " assigned a non-closable value") end end assert(var) - add_var(var, var.tk, t, var.attribute, is_localizing_a_variable(node, i) and "declaration") + self:add_var(var, var.tk, t, var.attribute, is_localizing_a_variable(node, i) and "declaration") local infertype = infertypes.tuple[i] if ok and infertype then - local where = node.exps[i] or node.exps + local w = node.exps[i] or node.exps - local rt = to_structural(t) + local rt = self:to_structural(t) if (not rt is EnumType) and ((not t is NominalType) or (rt is UnionType)) - and not same_type(t, infertype) + and not self:same_type(t, infertype) then - t = infer_at(where, infertype) - add_var(where, var.tk, t, "const", "narrowed_declaration") + t = self:infer_at(w, infertype) + self:add_var(w, var.tk, t, "const", "narrowed_declaration") end end - if tc then - tc.store_type(var.y, var.x, t) + if self.collector then + self.collector.store_type(var.y, var.x, t) end - dismiss_unresolved(var.tk) + self:dismiss_unresolved(var.tk) end return NONE end, }, ["global_declaration"] = { before_exp = set_expected_types_to_decltuple, - after = function(node: Node, children: {Type}): Type + after = function(self: TypeChecker, node: Node, children: {Type}): Type local valtuple = children[3] as TupleType -- may be nil - local infertypes = get_assignment_values(valtuple, #node.vars) + local infertypes = get_assignment_values(node, valtuple, #node.vars) for i, var in ipairs(node.vars) do - local _, t, is_inferred = determine_declaration_type(var, node, infertypes, i) + local _, t, is_inferred = self:determine_declaration_type(var, node, infertypes, i) if var.attribute == "close" then - error_at(var, "globals may not be ") + self.errs:add(var, "globals may not be ") end - add_global(var, var.tk, t, is_inferred) + self:add_global(var, var.tk, t, is_inferred) - dismiss_unresolved(var.tk) + self:dismiss_unresolved(var.tk) end return NONE end, }, ["assignment"] = { before_exp = set_expected_types_to_decltuple, - after = function(node: Node, children: {Type}): Type + after = function(self: TypeChecker, node: Node, children: {Type}): Type local vartuple = children[1] assert(vartuple is TupleType) local vartypes = vartuple.tuple local valtuple = children[3] assert(valtuple is TupleType) - local valtypes = get_assignment_values(valtuple, #vartypes) + local valtypes = get_assignment_values(node, valtuple, #vartypes) for i, vartype in ipairs(vartypes) do local varnode = node.vars[i] local varname = varnode.tk local valtype = valtypes.tuple[i] - local rvar, rval, err = check_assignment(varnode, vartype, valtype, varname, varnode.attribute) + local rvar, rval, err = self:check_assignment(varnode, vartype, valtype) if err == "missing" then if #node.exps == 1 and node.exps[1].kind == "op" and node.exps[1].op.op == "@funcall" then local msg = #valtuple.tuple == 1 and "only 1 value is returned by the function" or ("only " .. #valtuple.tuple .. " values are returned by the function") - add_warning("hint", varnode, msg) + self.errs:add_warning("hint", varnode, msg) end end if rval and rvar then -- assigning a function if rval is FunctionType then - widen_all_unions() + self:widen_all_unions() end if varname and (rvar is UnionType or rvar is InterfaceType) then -- narrow unions and interfaces - add_var(varnode, varname, rval, nil, "narrow") + self:add_var(varnode, varname, rval, nil, "narrow") end - if tc then - tc.store_type(varnode.y, varnode.x, valtype) + if self.collector then + self.collector.store_type(varnode.y, varnode.x, valtype) end end end @@ -10758,7 +10763,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end, }, ["if"] = { - after = function(node: Node, _children: {Type}): Type + after = function(self: TypeChecker, node: Node, _children: {Type}): Type local all_return = true for _, b in ipairs(node.if_blocks) do if not b.block_returns then @@ -10768,26 +10773,26 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end if all_return then node.block_returns = true - infer_negation_of_if_blocks(node, node, #node.if_blocks) + self:infer_negation_of_if_blocks(node, node, #node.if_blocks) end return NONE end, }, ["if_block"] = { - before = function(node: Node) - begin_scope(node) + before = function(self: TypeChecker, node: Node) + self:begin_scope(node) if node.if_block_n > 1 then - infer_negation_of_if_blocks(node, node.if_parent, node.if_block_n - 1) + self:infer_negation_of_if_blocks(node, node.if_parent, node.if_block_n - 1) end end, - before_statements = function(node: Node) + before_statements = function(self: TypeChecker, node: Node) if node.exp then - apply_facts(node.exp, node.exp.known) + self:apply_facts(node.exp, node.exp.known) end end, - after = function(node: Node, _children: {Type}): Type - end_scope(node) + after = function(self: TypeChecker, node: Node, _children: {Type}): Type + self:end_scope(node) if #node.body > 0 and node.body[#node.body].block_returns then node.block_returns = true @@ -10797,76 +10802,96 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end }, ["while"] = { - before = function(node: Node) + before = function(self: TypeChecker, node: Node) -- widen all narrowed variables because we don't calculate a fixpoint yet - widen_all_unions(node) + self:widen_all_unions(node) end, - before_statements = function(node: Node) - begin_scope(node) - apply_facts(node.exp, node.exp.known) + before_statements = function(self: TypeChecker, node: Node) + self:begin_scope(node) + self:apply_facts(node.exp, node.exp.known) end, after = end_scope_and_none_type, }, ["label"] = { - before = function(node: Node) + before = function(self: TypeChecker, node: Node) -- widen all narrowed variables because we don't calculate a fixpoint yet - widen_all_unions() - local label_id = "::" .. node.label .. "::" - if st[#st][label_id] then - error_at(node, "label '" .. node.label .. "' already defined at " .. filename ) - end - local unresolved = find_unresolved() - local var = add_var(node, label_id, type_at(node, a_type("none", {}))) - if unresolved then - if unresolved.labels[node.label] then - var.used = true + self:widen_all_unions() + local label_id = node.label + do + local scope = self.st[#self.st] + scope.labels = scope.labels or {} + if scope.labels[label_id] then + self.errs:add(node, "label '" .. node.label .. "' already defined") + else + scope.labels[label_id] = node end - unresolved.labels[node.label] = nil end + + --for i = #self.st, 1, -1 do + local scope = self.st[#self.st] + if scope.pending_labels and scope.pending_labels[label_id] then + node.used_label = true + scope.pending_labels[label_id] = nil + --break + end + --end end, after = function(): Type return NONE end }, ["goto"] = { - after = function(node: Node, _children: {Type}): Type - if not find_var_type("::" .. node.label .. "::") then - local unresolved = get_unresolved(st[#st]) - unresolved.labels[node.label] = unresolved.labels[node.label] or {} - table.insert(unresolved.labels[node.label], node) + after = function(self: TypeChecker, node: Node, _children: {Type}): Type + local label_id = node.label + local found_label: Node + for i = #self.st, 1, -1 do + local scope = self.st[i] + if scope.labels and scope.labels[label_id] then + found_label = scope.labels[label_id] + break + end + end + + if found_label then + found_label.used_label = true + else + local scope = self.st[#self.st] + scope.pending_labels = scope.pending_labels or {} + scope.pending_labels[label_id] = scope.pending_labels[label_id] or {} + table.insert(scope.pending_labels[label_id], node) end return NONE end, }, ["repeat"] = { - before = function(node: Node) + before = function(self: TypeChecker, node: Node) -- widen all narrowed variables because we don't calculate a fixpoint yet - widen_all_unions(node) + self:widen_all_unions(node) end, -- only end scope after checking `until`, `statements` in repeat body has is_repeat == true after = end_scope_and_none_type, }, ["forin"] = { - before = function(node: Node) - begin_scope(node) + before = function(self: TypeChecker, node: Node) + self:begin_scope(node) end, - before_statements = function(node: Node, children: {Type}) + before_statements = function(self: TypeChecker, node: Node, children: {Type}) local exptuple = children[2] assert(exptuple is TupleType) local exptypes = exptuple.tuple - widen_all_unions(node) + self:widen_all_unions(node) local exp1 = node.exps[1] - local args = a_tuple { + local args = a_tuple(node.exps, { node.exps[2] and exptypes[2], node.exps[3] and exptypes[3] - } - local exp1type = resolve_for_call(exptypes[1], args, false) + }) + local exp1type = self:resolve_for_call(exptypes[1], args, false) if exp1type is PolyType then local _: Type - _, exp1type = type_check_function_call(exp1, exp1type, args, 0, exp1, {node.exps[2], node.exps[3]}) + _, exp1type = self:type_check_function_call(exp1, exp1type, args, 0, exp1, {node.exps[2], node.exps[3]}) end if exp1type is FunctionType then @@ -10879,69 +10904,69 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if rets.is_va then r = last else - r = lax and UNKNOWN or INVALID + r = self.feat_lax and an_unknown(v) or an_invalid(v) end end - add_var(v, v.tk, r) + self:add_var(v, v.tk, r) - if tc then - tc.store_type(v.y, v.x, r) + if self.collector then + self.collector.store_type(v.y, v.x, r) end last = r end local nrets = #rets.tuple - if (not lax) and (not rets.is_va and #node.vars > nrets) then + if (not self.feat_lax) and (not rets.is_va and #node.vars > nrets) then local at = node.vars[nrets + 1] local n_values = nrets == 1 and "1 value" or tostring(nrets) .. " values" - error_at(at, "too many variables for this iterator; it produces " .. n_values) + self.errs:add(at, "too many variables for this iterator; it produces " .. n_values) end else - if not (lax and is_unknown(exp1type)) then - error_at(exp1, "expression in for loop does not return an iterator") + if not (self.feat_lax and is_unknown(exp1type)) then + self.errs:add(exp1, "expression in for loop does not return an iterator") end end end, after = end_scope_and_none_type, }, ["fornum"] = { - before_statements = function(node: Node, children: {Type}) - widen_all_unions(node) - begin_scope(node) - local from_t = to_structural(resolve_tuple(children[2])) - local to_t = to_structural(resolve_tuple(children[3])) - local step_t = children[4] and to_structural(children[4]) - local t = (from_t.typename == "integer" and - to_t.typename == "integer" and - (not step_t or step_t.typename == "integer")) - and INTEGER - or NUMBER - add_var(node.var, node.var.tk, t) + before_statements = function(self: TypeChecker, node: Node, children: {Type}) + self:widen_all_unions(node) + self:begin_scope(node) + local from_t = self:to_structural(resolve_tuple(children[2])) + local to_t = self:to_structural(resolve_tuple(children[3])) + local step_t = children[4] and self:to_structural(children[4]) + local typename: TypeName = (from_t.typename == "integer" and + to_t.typename == "integer" and + (not step_t or step_t.typename == "integer")) + and "integer" + or "number" + self:add_var(node.var, node.var.tk, a_type(node.var, typename, {})) end, after = end_scope_and_none_type, }, ["return"] = { - before = function(node: Node) - local rets = find_var_type("@return") + before = function(self: TypeChecker, node: Node) + local rets = self:find_var_type("@return") if rets and rets is TupleType then for i, exp in ipairs(node.exps) do exp.expected = rets.tuple[i] end end end, - after = function(node: Node, children: {Type}): Type + after = function(self: TypeChecker, node: Node, children: {Type}): Type local got = children[1] assert(got is TupleType) local got_t = got.tuple local n_got = #got_t node.block_returns = true - local expected = find_var_type("@return") as TupleType + local expected = self:find_var_type("@return") as TupleType if not expected then -- if at the toplevel - expected = infer_at(node, got) - module_type = drop_constant_value(to_structural(resolve_tuple(expected))) - st[2]["@return"] = { t = expected } + expected = self:infer_at(node, got) + self.module_type = drop_constant_value(self:to_structural(resolve_tuple(expected))) + self.st[2].vars["@return"] = { t = expected } end local expected_t = expected.tuple @@ -10956,8 +10981,8 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string vatype = expected.is_va and expected.tuple[n_expected] end - if n_got > n_expected and (not lax) and not vatype then - error_at(node, what ..": excess return values, expected " .. n_expected .. " %s, got " .. n_got .. " %s", expected, got) + if n_got > n_expected and (not self.feat_lax) and not vatype then + self.errs:add(node, what ..": excess return values, expected " .. n_expected .. " %s, got " .. n_got .. " %s", expected, got) end if n_expected > 1 @@ -10965,18 +10990,18 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string and node.exps[1].kind == "op" and (node.exps[1].op.op == "and" or node.exps[1].op.op == "or") and node.exps[1].discarded_tuple then - add_warning("hint", node.exps[1].e2, "additional return values are being discarded due to '" .. node.exps[1].op.op .. "' expression; suggest parentheses if intentional") + self.errs:add_warning("hint", node.exps[1].e2, "additional return values are being discarded due to '" .. node.exps[1].op.op .. "' expression; suggest parentheses if intentional") end for i = 1, n_got do local e = expected_t[i] or vatype if e then e = resolve_tuple(e) - local where = (node.exps[i] and node.exps[i].x) - and node.exps[i] - or node.exps - assert(where and where.x) - assert_is_a(where, got_t[i], e, what) + local w = (node.exps[i] and node.exps[i].x) + and node.exps[i] + or node.exps + assert(w and w.x) + self:assert_is_a(w, got_t[i], e, what) end end @@ -10984,25 +11009,28 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end, }, ["variable_list"] = { - after = function(node: Node, children: {Type}): Type - local tuple = a_tuple(children) + after = function(self: TypeChecker, node: Node, children: {Type}): Type + local tuple = a_tuple(node, children) tuple = flatten_tuple(tuple) for i, t in ipairs(tuple.tuple) do - ensure_not_abstract(node[i], t) + local ok, err = ensure_not_abstract(t) + if not ok then + self.errs:add(node[i], err) + end end return tuple end, }, ["literal_table"] = { - before = function(node: Node) + before = function(self: TypeChecker, node: Node) if node.expected then - local decltype = to_structural(node.expected) + local decltype = self:to_structural(node.expected) if decltype is TypeVarType and decltype.constraint then - decltype = resolve_typedecl(to_structural(decltype.constraint)) + decltype = resolve_typedecl(self:to_structural(decltype.constraint)) end if decltype is TupleTableType then @@ -11034,19 +11062,19 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end end, - after = function(node: Node, children: {LiteralTableItemType}): Type + after = function(self: TypeChecker, node: Node, children: {LiteralTableItemType}): Type node.known = FACT_TRUTHY if not node.expected then - return infer_table_literal(node, children) + return infer_table_literal(self, node, children) end - local decltype = to_structural(node.expected) + local decltype = self:to_structural(node.expected) local constraint: Type if decltype is TypeVarType and decltype.constraint then constraint = resolve_typedecl(decltype.constraint) - decltype = to_structural(constraint) + decltype = self:to_structural(constraint) end if decltype is UnionType then @@ -11054,7 +11082,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local single_table_rt: Type for _, t in ipairs(decltype.types) do - local rt = to_structural(t) + local rt = self:to_structural(t) if is_lua_table_type(rt) then if single_table_type then -- multiple table types in union, give up @@ -11075,7 +11103,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end if not is_lua_table_type(decltype) then - return infer_table_literal(node, children) + return infer_table_literal(self, node, children) end local force_array: Type = nil @@ -11085,73 +11113,75 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string for i, child in ipairs(children) do local cvtype = resolve_tuple(child.vtype) local ck = child.kname + local cktype = child.ktype local n = node[i].key.constnum local b: boolean = nil - if child.ktype.typename == "boolean" then + if cktype is BooleanType then b = (node[i].key.tk == "true") end - check_redeclared_key(node[i], node.expected_context, seen_keys, ck or n or b) + self.errs:check_redeclared_key(node[i], node, seen_keys, ck or n or b) if decltype is RecordLikeType and ck then local df = decltype.fields[ck] if not df then - error_at(node[i], in_context(node.expected_context, "unknown field " .. ck)) + self.errs:add_in_context(node[i], node, "unknown field " .. ck) else if df is TypeDeclType or df is TypeAliasType then - error_at(node[i], in_context(node.expected_context, "cannot reassign a type")) + self.errs:add_in_context(node[i], node, "cannot reassign a type") else - assert_is_a(node[i], cvtype, df, "in record field", ck) + self:assert_is_a(node[i], cvtype, df, "in record field", ck) end end - elseif decltype is TupleTableType and is_number_type(child.ktype) then + elseif decltype is TupleTableType and cktype is NumericType then local dt = decltype.types[n as integer] if not n then - error_at(node[i], in_context(node.expected_context, "unknown index in tuple %s"), decltype) + self.errs:add_in_context(node[i], node, "unknown index in tuple %s", decltype) elseif not dt then - error_at(node[i], in_context(node.expected_context, "unexpected index " .. n .. " in tuple %s"), decltype) + self.errs:add_in_context(node[i], node, "unexpected index " .. n .. " in tuple %s", decltype) else - assert_is_a(node[i], cvtype, dt, in_context(node.expected_context, "in tuple"), "at index " .. tostring(n)) + self:assert_is_a(node[i], cvtype, dt, node, "in tuple: at index " .. tostring(n)) end - elseif decltype is ArrayLikeType and is_number_type(child.ktype) then + elseif decltype is ArrayLikeType and cktype is NumericType then local cv = child.vtype if cv is TupleType and i == #children and node[i].key_parsed == "implicit" then -- need to expand last item in an array (e.g { 1, 2, 3, f() }) for ti, tt in ipairs(cv.tuple) do - assert_is_a(node[i], tt, decltype.elements, in_context(node.expected_context, "expected an array"), "at index " .. tostring(i + ti - 1)) + self:assert_is_a(node[i], tt, decltype.elements, node, "expected an array: at index " .. tostring(i + ti - 1)) end else - assert_is_a(node[i], cvtype, decltype.elements, in_context(node.expected_context, "expected an array"), "at index " .. tostring(n)) + self:assert_is_a(node[i], cvtype, decltype.elements, node, "expected an array: at index " .. tostring(n)) end elseif node[i].key_parsed == "implicit" then if decltype is MapType then - assert_is_a(node[i], INTEGER, decltype.keys, in_context(node.expected_context, "in map key")) - assert_is_a(node[i], cvtype, decltype.values, in_context(node.expected_context, "in map value")) + self:assert_is_a(node[i].key, a_type(node[i].key, "integer", {}), decltype.keys, node, "in map key") + self:assert_is_a(node[i].value, cvtype, decltype.values, node, "in map value") end - force_array = expand_type(node[i], force_array, child.vtype) + force_array = self:expand_type(node[i], force_array, child.vtype) elseif decltype is MapType then force_array = nil - assert_is_a(node[i], child.ktype, decltype.keys, in_context(node.expected_context, "in map key")) - assert_is_a(node[i], cvtype, decltype.values, in_context(node.expected_context, "in map value")) + self:assert_is_a(node[i].key, cktype, decltype.keys, node, "in map key") + self:assert_is_a(node[i].value, cvtype, decltype.values, node, "in map value") else - error_at(node[i], in_context(node.expected_context, "unexpected key of type %s in table of type %s"), child.ktype, decltype) + self.errs:add_in_context(node[i], node, "unexpected key of type %s in table of type %s", cktype, decltype) end end local t: Type if force_array then - t = infer_at(node, an_array(force_array)) + t = self:infer_at(node, an_array(node, force_array)) else - t = resolve_typevars_at(node, node.expected) + t = self:resolve_typevars_at(node, node.expected) end if decltype is RecordType then - local rt = to_structural(t) + local rt = self:to_structural(t) if rt is RecordType then node.is_total, node.missing = total_record_check(decltype, seen_keys) end elseif decltype is MapType then - local rt = to_structural(t) + local rt = self:to_structural(t) if rt is MapType then - node.is_total, node.missing = total_map_check(decltype, seen_keys) + local rk = self:to_structural(rt.keys) + node.is_total, node.missing = total_map_check(rk, seen_keys) end end @@ -11163,13 +11193,13 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end, }, ["literal_table_item"] = { - after = function(node: Node, children: {Type}): Type + after = function(self: TypeChecker, node: Node, children: {Type}): Type local kname = node.key.conststr local ktype = children[1] local vtype = children[2] if node.itemtype then vtype = node.itemtype - assert_is_a(node.value, children[2], node.itemtype, "in table item") + self:assert_is_a(node.value, children[2], node.itemtype, node) end if vtype is FunctionType and vtype.is_method then -- If we assign a method to a table item, e.g. @@ -11178,210 +11208,210 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string vtype = shallow_copy_new_type(vtype) vtype.is_method = false end - return type_at(node, a_type("literal_table_item", { + return a_type(node, "literal_table_item", { kname = kname, ktype = ktype, vtype = vtype, - } as LiteralTableItemType)) + } as LiteralTableItemType) end, }, ["local_function"] = { - before = function(node: Node) - widen_all_unions() - if tc then - tc.reserve_symbol_list_slot(node) + before = function(self: TypeChecker, node: Node) + self:widen_all_unions() + if self.collector then + self.collector.reserve_symbol_list_slot(node) end - begin_scope(node) + self:begin_scope(node) end, - before_statements = function(node: Node, children: {Type}) + before_statements = function(self: TypeChecker, node: Node, children: {Type}) local args = children[2] assert(args is TupleType) - add_internal_function_variables(node, args) - add_function_definition_for_recursion(node, args) + self:add_internal_function_variables(node, args) + self:add_function_definition_for_recursion(node, args) end, - after = function(node: Node, children: {Type}): Type + after = function(self: TypeChecker, node: Node, children: {Type}): Type local args = children[2] assert(args is TupleType) local rets = children[3] assert(rets is TupleType) - end_function_scope(node) + self:end_function_scope(node) - local t = type_at(node, ensure_fresh_typeargs(a_function { + local t = self:ensure_fresh_typeargs(a_function(node, { min_arity = node.min_arity, typeargs = node.typeargs, args = args, - rets = get_rets(rets), + rets = self.get_rets(rets), })) - add_var(node, node.name.tk, t) + self:add_var(node, node.name.tk, t) return t end, }, ["local_macroexp"] = { - before = function(node: Node) - widen_all_unions() - if tc then - tc.reserve_symbol_list_slot(node) + before = function(self: TypeChecker, node: Node) + self:widen_all_unions() + if self.collector then + self.collector.reserve_symbol_list_slot(node) end - begin_scope(node) + self:begin_scope(node) end, - after = function(node: Node, children: {Type}): Type + after = function(self: TypeChecker, node: Node, children: {Type}): Type local args = children[2] assert(args is TupleType) local rets = children[3] assert(rets is TupleType) - end_function_scope(node) + self:end_function_scope(node) - check_macroexp_arg_use(node.macrodef) + self:check_macroexp_arg_use(node.macrodef) - local t = type_at(node, ensure_fresh_typeargs(a_function { + local t = self:ensure_fresh_typeargs(a_function(node, { min_arity = node.macrodef.min_arity, typeargs = node.typeargs, args = args, - rets = get_rets(rets), + rets = self.get_rets(rets), macroexp = node.macrodef, })) - add_var(node, node.name.tk, t) + self:add_var(node, node.name.tk, t) return t end, }, ["global_function"] = { - before = function(node: Node) - widen_all_unions() - begin_scope(node) + before = function(self: TypeChecker, node: Node) + self:widen_all_unions() + self:begin_scope(node) if node.implicit_global_function then - local typ = find_var_type(node.name.tk) + local typ = self:find_var_type(node.name.tk) if typ then if typ is FunctionType then node.is_predeclared_local_function = true - elseif not lax then - error_at(node, "cannot declare function: type of " .. node.name.tk .. " is %s", typ) + elseif not self.feat_lax then + self.errs:add(node, "cannot declare function: type of " .. node.name.tk .. " is %s", typ) end - elseif not lax then - error_at(node, "functions need an explicit 'local' or 'global' annotation") + elseif not self.feat_lax then + self.errs:add(node, "functions need an explicit 'local' or 'global' annotation") end end end, - before_statements = function(node: Node, children: {Type}) + before_statements = function(self: TypeChecker, node: Node, children: {Type}) local args = children[2] assert(args is TupleType) - add_internal_function_variables(node, args) - add_function_definition_for_recursion(node, args) + self:add_internal_function_variables(node, args) + self:add_function_definition_for_recursion(node, args) end, - after = function(node: Node, children: {Type}): Type + after = function(self: TypeChecker, node: Node, children: {Type}): Type local args = children[2] assert(args is TupleType) local rets = children[3] assert(rets is TupleType) - end_function_scope(node) + self:end_function_scope(node) if node.is_predeclared_local_function then return NONE end - add_global(node, node.name.tk, type_at(node, ensure_fresh_typeargs(a_function { + self:add_global(node, node.name.tk, self:ensure_fresh_typeargs(a_function(node, { min_arity = node.min_arity, typeargs = node.typeargs, args = args, - rets = get_rets(rets), + rets = self.get_rets(rets), }))) return NONE end, }, ["record_function"] = { - before = function(node: Node) - widen_all_unions() - begin_scope(node) + before = function(self: TypeChecker, node: Node) + self:widen_all_unions() + self:begin_scope(node) end, - before_arguments = function(_node: Node, children: {Type}) - local rtype = to_structural(resolve_typedecl(children[1])) + before_arguments = function(self: TypeChecker, _node: Node, children: {Type}) + local rtype = self:to_structural(resolve_typedecl(children[1])) -- add type arguments from the record implicitly if rtype is RecordLikeType and rtype.typeargs then for _, typ in ipairs(rtype.typeargs) do - add_var(nil, typ.typearg, type_at(typ, a_type("typearg", { + self:add_var(nil, typ.typearg, a_type(typ, "typearg", { typearg = typ.typearg, constraint = typ.constraint, - } as TypeArgType))) + } as TypeArgType)) end end end, - before_statements = function(node: Node, children: {Type}) + before_statements = function(self: TypeChecker, node: Node, children: {Type}) local args = children[3] assert(args is TupleType) local rets = children[4] assert(rets is TupleType) - local rtype = to_structural(resolve_typedecl(children[1])) + local rtype = self:to_structural(resolve_typedecl(children[1])) - if lax and rtype.typename == "unknown" then + if self.feat_lax and rtype is UnknownType then return end if rtype is EmptyTableType then - edit_type(rtype, "record") + edit_type(rtype, rtype, "record") local r = rtype as RecordType r.fields = {} r.field_order = {} end if not rtype is RecordLikeType then - error_at(node, "not a record: %s", rtype) + self.errs:add(node, "not a record: %s", rtype) return end - local selftype = get_self_type(node.fn_owner) + local selftype = self:get_self_type(node.fn_owner) if node.is_method then if not selftype then - error_at(node, "could not resolve type of self") + self.errs:add(node, "could not resolve type of self") return end args.tuple[1] = selftype - add_var(nil, "self", selftype) + self:add_var(nil, "self", selftype) end - local fn_type = type_at(node, ensure_fresh_typeargs(a_function { + local fn_type = self:ensure_fresh_typeargs(a_function(node, { min_arity = node.min_arity, is_method = node.is_method, typeargs = node.typeargs, args = args, - rets = get_rets(rets), + rets = self.get_rets(rets), })) - local open_t, open_v, owner_name = find_record_to_extend(node.fn_owner) + local open_t, open_v, owner_name = self:find_record_to_extend(node.fn_owner) local open_k = owner_name .. "." .. node.name.tk local rfieldtype = rtype.fields[node.name.tk] if rfieldtype then - rfieldtype = to_structural(rfieldtype) + rfieldtype = self:to_structural(rfieldtype) if open_v and open_v.implemented and open_v.implemented[open_k] then - redeclaration_warning(node) + self.errs:redeclaration_warning(node) end - local ok, err = same_type(fn_type, rfieldtype) + local ok, err = self:same_type(fn_type, rfieldtype) if not ok then if rfieldtype is PolyType then - add_errs_prefixing(node, err, errors, "type signature does not match declaration: field has multiple function definitions (such polymorphic declarations are intended for Lua module interoperability)") + self.errs:add_prefixing(node, err, "type signature does not match declaration: field has multiple function definitions (such polymorphic declarations are intended for Lua module interoperability): ") return end local shortname = selftype and show_type(selftype) or owner_name local msg = "type signature of '" .. node.name.tk .. "' does not match its declaration in " .. shortname .. ": " - add_errs_prefixing(node, err, errors, msg) + self.errs:add_prefixing(node, err, msg) return end else - if lax or rtype == open_t then + if self.feat_lax or rtype == open_t then rtype.fields[node.name.tk] = fn_type table.insert(rtype.field_order, node.name.tk) else - error_at(node, "cannot add undeclared function '" .. node.name.tk .. "' outside of the scope where '" .. owner_name .. "' was originally declared") + self.errs:add(node, "cannot add undeclared function '" .. node.name.tk .. "' outside of the scope where '" .. owner_name .. "' was originally declared") return end @@ -11394,32 +11424,32 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string open_v.implemented[open_k] = true end - add_internal_function_variables(node, args) + self:add_internal_function_variables(node, args) end, - after = function(node: Node, _children: {Type}): Type - end_function_scope(node) + after = function(self: TypeChecker, node: Node, _children: {Type}): Type + self:end_function_scope(node) return NONE end, }, ["function"] = { - before = function(node: Node) - widen_all_unions(node) - begin_scope(node) + before = function(self: TypeChecker, node: Node) + self:widen_all_unions(node) + self:begin_scope(node) end, - before_statements = function(node: Node, children: {Type}) + before_statements = function(self: TypeChecker, node: Node, children: {Type}) local args = children[1] assert(args is TupleType) - add_internal_function_variables(node, args) + self:add_internal_function_variables(node, args) end, - after = function(node: Node, children: {Type}): Type + after = function(self: TypeChecker, node: Node, children: {Type}): Type local args = children[1] assert(args is TupleType) local rets = children[2] assert(rets is TupleType) - end_function_scope(node) - return type_at(node, ensure_fresh_typeargs(a_function { + self:end_function_scope(node) + return self:ensure_fresh_typeargs(a_function(node, { min_arity = node.min_arity, typeargs = node.typeargs, args = args, @@ -11428,24 +11458,24 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end, }, ["macroexp"] = { - before = function(node: Node) - widen_all_unions(node) - begin_scope(node) + before = function(self: TypeChecker, node: Node) + self:widen_all_unions(node) + self:begin_scope(node) end, - before_exp = function(node: Node, children: {Type}) + before_exp = function(self: TypeChecker, node: Node, children: {Type}) local args = children[1] assert(args is TupleType) - add_internal_function_variables(node, args) + self:add_internal_function_variables(node, args) end, - after = function(node: Node, children: {Type}): Type + after = function(self: TypeChecker, node: Node, children: {Type}): Type local args = children[1] assert(args is TupleType) local rets = children[2] assert(rets is TupleType) - end_function_scope(node) - return type_at(node, ensure_fresh_typeargs(a_function { + self:end_function_scope(node) + return self:ensure_fresh_typeargs(a_function(node, { min_arity = node.min_arity, typeargs = node.typeargs, args = args, @@ -11454,22 +11484,22 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end, }, ["cast"] = { - after = function(node: Node, _children: {Type}): Type + after = function(_self: TypeChecker, node: Node, _children: {Type}): Type return node.casttype end }, ["paren"] = { - before = function(node: Node) + before = function(_self: TypeChecker, node: Node) node.e1.expected = node.expected end, - after = function(node: Node, children: {Type}): Type + after = function(_self: TypeChecker, node: Node, children: {Type}): Type node.known = node.e1 and node.e1.known return resolve_tuple(children[1]) end, }, ["op"] = { - before = function(node: Node) - begin_scope() + before = function(self: TypeChecker, node: Node) + self:begin_scope() if node.expected then if node.op.op == "and" then node.e2.expected = node.expected @@ -11481,18 +11511,19 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end end, - before_e2 = function(node: Node, children: {Type}) + before_e2 = function(self: TypeChecker, node: Node, children: {Type}) local e1type = children[1] if node.op.op == "and" then - apply_facts(node, node.e1.known) + self:apply_facts(node, node.e1.known) elseif node.op.op == "or" then - apply_facts(node, facts_not(node, node.e1.known)) + self:apply_facts(node, facts_not(node, node.e1.known)) elseif node.op.op == "@funcall" then if e1type is FunctionType then local argdelta = (node.e1.op and node.e1.op.op == ":") and -1 or 0 if node.expected then - is_a(e1type.rets, node.expected) + -- this forces typevars in function return types + self:is_a(e1type.rets, node.expected) end local e1args = e1type.args.tuple local at = argdelta @@ -11515,8 +11546,8 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end end, - after = function(node: Node, children: {Type}): Type - end_scope() + after = function(self: TypeChecker, node: Node, children: {Type}): Type + self:end_scope() -- given a and b: may be TupleType local ga: Type = children[1] @@ -11527,29 +11558,33 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local ub: Type -- resolved a and b: not NominalType - local ra: Type = to_structural(ua) + local ra: Type = self:to_structural(ua) local rb: Type if ra.typename == "circular_require" or (ra is TypeDeclType and ra.def and ra.def.typename == "circular_require") then - return invalid_at(node, "cannot dereference a type from a circular require") + return self.errs:invalid_at(node, "cannot dereference a type from a circular require") end if node.op.op == "@funcall" then - if lax and is_unknown(ua) then + if self.feat_lax and is_unknown(ua) then if node.e1.op and node.e1.op.op == ":" and node.e1.e1.kind == "variable" then - add_unknown_dot(node, node.e1.e1.tk .. "." .. node.e1.e2.tk) + self.errs:add_unknown_dot(node, node.e1.e1.tk .. "." .. node.e1.e2.tk) end end - local t = type_check_funcall(node, ua, gb) + assert(gb is TupleType) + local t = self:type_check_funcall(node, ua, gb) return t elseif node.op.op == "as" then return gb end - local expected = node.expected and to_structural(resolve_tuple(node.expected)) + local expected = node.expected and self:to_structural(resolve_tuple(node.expected)) - ensure_not_abstract(node.e1, ra) + local ok, err = ensure_not_abstract(ra) + if not ok then + self.errs:add(node.e1, err) + end if ra is TypeDeclType and ra.def.typename == "record" then ra = ra.def end @@ -11558,8 +11593,11 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string -- after they are handled above, we can resolve b's tuple and only use that instead. if gb then ub = resolve_tuple(gb) - rb = to_structural(ub) - ensure_not_abstract(node.e2, rb) + rb = self:to_structural(ub) + ok, err = ensure_not_abstract(rb) + if not ok then + self.errs:add(node.e2, err) + end if rb is TypeDeclType and rb.def.typename == "record" then rb = rb.def end @@ -11569,22 +11607,20 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string node.receiver = ua assert(node.e2.kind == "identifier") - local bnode: Node = { - y = node.e2.y, - x = node.e2.x, + local bnode: Node = node_at(node.e2, { tk = node.e2.tk, kind = "string", - } - local btype = type_at(node.e2, a_type("string", { literal = node.e2.tk } as StringType)) - local t = type_check_index(node.e1, bnode, ua, btype) + }) + local btype = a_type(node.e2, "string", { literal = node.e2.tk } as StringType) + local t = self:type_check_index(node.e1, bnode, ua, btype) - if t.needs_compat and opts.gen_compat ~= "off" then + if t.needs_compat and self.gen_compat ~= "off" then -- only apply to a literal use, not a propagated type if node.e1.kind == "variable" and node.e2.kind == "identifier" then local key = node.e1.tk .. "." .. node.e2.tk node.kind = "variable" node.tk = "_tl_" .. node.e1.tk .. "_" .. node.e2.tk - all_needs_compat[key] = true + self.all_needs_compat[key] = true end end @@ -11592,22 +11628,22 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end if node.op.op == "@index" then - return type_check_index(node.e1, node.e2, ua, ub) + return self:type_check_index(node.e1, node.e2, ua, ub) end if node.op.op == "is" then if rb.typename == "integer" then - all_needs_compat["math"] = true + self.all_needs_compat["math"] = true end if ra is TypeDeclType then - error_at(node, "can only use 'is' on variables, not types") + self.errs:add(node, "can only use 'is' on variables, not types") elseif node.e1.kind == "variable" then - check_metamethod(node, "__is", ra, resolve_typedecl(rb), ua, ub) - node.known = IsFact { var = node.e1.tk, typ = ub, where = node } + self:check_metamethod(node, "__is", ra, resolve_typedecl(rb), ua, ub) + node.known = IsFact { var = node.e1.tk, typ = ub, w = node } else - error_at(node, "can only use 'is' on variables") + self.errs:add(node, "can only use 'is' on variables") end - return BOOLEAN + return a_type(node, "boolean", {}) end if node.op.op == ":" then @@ -11615,16 +11651,16 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string -- we handle ':' separately from '.' because ':' is specific to records, -- so we produce different error messages - if lax and (is_unknown(ua) or ua.typename == "typevar") then + if self.feat_lax and (is_unknown(ua) or ua is TypeVarType) then if node.e1.kind == "variable" then - add_unknown_dot(node.e1, node.e1.tk .. "." .. node.e2.tk) + self.errs:add_unknown_dot(node.e1, node.e1.tk .. "." .. node.e2.tk) end - return UNKNOWN + return an_unknown(node) end - local t, e = match_record_key(ra, node.e1, node.e2.conststr or node.e2.tk) + local t, e = self:match_record_key(ra, node.e1, node.e2.conststr or node.e2.tk) if not t then - return invalid_at(node.e2, e, ua) + return self.errs:invalid_at(node.e2, e, ua) end return t @@ -11632,7 +11668,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if node.op.op == "not" then node.known = facts_not(node, node.e1.known) - return BOOLEAN + return a_type(node, "boolean", {}) end if node.op.op == "and" then @@ -11650,33 +11686,33 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string node.known = nil t = ua - elseif ((ra is EnumType and rb is StringType and is_a(rb, ra)) - or (ra is StringType and rb is EnumType and is_a(ra, rb))) then + elseif ((ra is EnumType and rb is StringType and self:is_a(rb, ra)) + or (ra is StringType and rb is EnumType and self:is_a(ra, rb))) then node.known = nil t = (ra is EnumType and ra or rb) elseif expected and expected is UnionType then -- must be checked after string/enum above node.known = facts_or(node, node.e1.known, node.e2.known) - local u = unite({ra, rb}, true) + local u = unite(node, {ra, rb}, true) if u is UnionType then - local ok, err = is_valid_union(u) + ok, err = is_valid_union(u) if not ok then - u = err and invalid_at(node, err, u) or INVALID + u = err and self.errs:invalid_at(node, err, u) or an_invalid(node) end end t = u else - local a_ge_b = is_a(rb, ra) - local b_ge_a = is_a(ra, rb) + local a_ge_b = self:is_a(rb, ra) + local b_ge_a = self:is_a(ra, rb) if a_ge_b or b_ge_a then node.known = facts_or(node, node.e1.known, node.e2.known) if expected then - local a_is = is_a(ua, expected) - local b_is = is_a(ub, expected) + local a_is = self:is_a(ua, expected) + local b_is = self:is_a(ub, expected) if a_is and b_is then - t = resolve_typevars_at(node, expected) + t = self:resolve_typevars_at(node, expected) end end if not t then @@ -11695,44 +11731,46 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if node.op.op == "==" or node.op.op == "~=" then -- if is_lua_table_type(ra) and is_lua_table_type(rb) then --- check_metamethod(node, binop_to_metamethod[node.op.op], ra, rb) +-- self:check_metamethod(node, binop_to_metamethod[node.op.op], ra, rb) -- end if ra is EnumType and rb is StringType then if not (rb.literal and ra.enumset[rb.literal]) then - return invalid_at(node, "%s is not a member of %s", ub, ua) + return self.errs:invalid_at(node, "%s is not a member of %s", ub, ua) end elseif ra is TupleTableType and rb is TupleTableType and #ra.types ~= #rb.types then - return invalid_at(node, "tuples are not the same size") - elseif is_a(ub, ua) or ua.typename == "typevar" then + return self.errs:invalid_at(node, "tuples are not the same size") + elseif self:is_a(ub, ua) or ua is TypeVarType then if node.op.op == "==" and node.e1.kind == "variable" then - node.known = EqFact { var = node.e1.tk, typ = ub, where = node } + node.known = EqFact { var = node.e1.tk, typ = ub, w = node } end - elseif is_a(ua, ub) or ub.typename == "typevar" then + elseif self:is_a(ua, ub) or ub is TypeVarType then if node.op.op == "==" and node.e2.kind == "variable" then - node.known = EqFact { var = node.e2.tk, typ = ua, where = node } + node.known = EqFact { var = node.e2.tk, typ = ua, w = node } end - elseif lax and (is_unknown(ua) or is_unknown(ub)) then - return UNKNOWN + elseif self.feat_lax and (is_unknown(ua) or is_unknown(ub)) then + return an_unknown(node) else - return invalid_at(node, "types are not comparable for equality: %s and %s", ua, ub) + return self.errs:invalid_at(node, "types are not comparable for equality: %s and %s", ua, ub) end - return BOOLEAN + return a_type(node, "boolean", {}) end if node.op.arity == 1 and unop_types[node.op.op] then if ra is UnionType then - ra = unite(ra.types, true) -- squash unions of string constants + ra = unite(node, ra.types, true) -- squash unions of string constants end local types_op = unop_types[node.op.op] - local t = types_op[ra.typename] + local tn = types_op[ra.typename] + local t = tn and a_type(node, tn, {}) if not t and ra is RecordLikeType then t = find_in_interface_list(ra, function(ty: Type): Type - return types_op[ty.typename] + local tname = types_op[ty.typename] + return tname and a_type(node, tname, {}) end) end @@ -11740,19 +11778,18 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if not t then local mt_name = unop_to_metamethod[node.op.op] if mt_name then - t, meta_on_operator = check_metamethod(node, mt_name, ra, nil, ua, nil) + t, meta_on_operator = self:check_metamethod(node, mt_name, ra, nil, ua, nil) end if not t then - error_at(node, "cannot use operator '" .. node.op.op:gsub("%%", "%%%%") .. "' on type %s", ua) - t = INVALID + t = self.errs:invalid_at(node, "cannot use operator '" .. node.op.op:gsub("%%", "%%%%") .. "' on type %s", ua) end end if ra is MapType then if ra.keys.typename == "number" or ra.keys.typename == "integer" then - add_warning("hint", node, "using the '#' operator on a map with numeric key type may produce unexpected results") + self.errs:add_warning("hint", node, "using the '#' operator on a map with numeric key type may produce unexpected results") else - error_at(node, "using the '#' operator on this map will always return 0") + self.errs:add(node, "using the '#' operator on this map will always return 0") end end @@ -11760,12 +11797,12 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string node.known = FACT_TRUTHY end - if node.op.op == "~" and env.gen_target == "5.1" then + if node.op.op == "~" and self.gen_target == "5.1" then if meta_on_operator then - all_needs_compat["mt"] = true + self.all_needs_compat["mt"] = true convert_node_to_compat_mt_call(node, unop_to_metamethod[node.op.op], 1, node.e1) else - all_needs_compat["bit32"] = true + self.all_needs_compat["bit32"] = true convert_node_to_compat_call(node, "bit32", "bnot", node.e1) end end @@ -11779,39 +11816,39 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end if ra is UnionType then - ra = unite(ra.types, true) -- squash unions of string constants + ra = unite(ra, ra.types, true) -- squash unions of string constants end if rb is UnionType then - rb = unite(rb.types, true) -- squash unions of string constants + rb = unite(rb, rb.types, true) -- squash unions of string constants end local types_op = binop_types[node.op.op] - local t = types_op[ra.typename] and types_op[ra.typename][rb.typename] + local tn = types_op[ra.typename] and types_op[ra.typename][rb.typename] + local t = tn and a_type(node, tn, {}) local meta_on_operator: integer if not t then local mt_name = binop_to_metamethod[node.op.op] if mt_name then - t, meta_on_operator = check_metamethod(node, mt_name, ra, rb, ua, ub) + t, meta_on_operator = self:check_metamethod(node, mt_name, ra, rb, ua, ub) end if not t then - error_at(node, "cannot use operator '" .. node.op.op:gsub("%%", "%%%%") .. "' for types %s and %s", ua, ub) - t = INVALID + t = self.errs:invalid_at(node, "cannot use operator '" .. node.op.op:gsub("%%", "%%%%") .. "' for types %s and %s", ua, ub) if node.op.op == "or" then - local u = unite({ua, ub}) + local u = unite(node, {ua, ub}) if u is UnionType and is_valid_union(u) then - add_warning("hint", node, "if a union type was intended, consider declaring it explicitly") + self.errs:add_warning("hint", node, "if a union type was intended, consider declaring it explicitly") end end end end if ua is NominalType and ub is NominalType and not meta_on_operator then - if is_a(ua, ub) then + if self:is_a(ua, ub) then t = ua else - error_at(node, "cannot use operator '" .. node.op.op:gsub("%%", "%%%%") .. "' for distinct nominal types %s and %s", ua, ub) + self.errs:add(node, "cannot use operator '" .. node.op.op:gsub("%%", "%%%%") .. "' for distinct nominal types %s and %s", ua, ub) end end @@ -11819,20 +11856,20 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string node.known = FACT_TRUTHY end - if node.op.op == "//" and env.gen_target == "5.1" then + if node.op.op == "//" and self.gen_target == "5.1" then if meta_on_operator then - all_needs_compat["mt"] = true + self.all_needs_compat["mt"] = true convert_node_to_compat_mt_call(node, "__idiv", meta_on_operator, node.e1, node.e2) else - local div: Node = { y = node.y, x = node.x, kind = "op", op = an_operator(node, 2, "/"), e1 = node.e1, e2 = node.e2 } + local div: Node = node_at(node, { kind = "op", op = an_operator(node, 2, "/"), e1 = node.e1, e2 = node.e2 }) convert_node_to_compat_call(node, "math", "floor", div) end - elseif bit_operators[node.op.op] and env.gen_target == "5.1" then + elseif bit_operators[node.op.op] and self.gen_target == "5.1" then if meta_on_operator then - all_needs_compat["mt"] = true + self.all_needs_compat["mt"] = true convert_node_to_compat_mt_call(node, binop_to_metamethod[node.op.op], meta_on_operator, node.e1, node.e2) else - all_needs_compat["bit32"] = true + self.all_needs_compat["bit32"] = true convert_node_to_compat_call(node, "bit32", bit_operators[node.op.op], node.e1, node.e2) end end @@ -11844,28 +11881,28 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end, }, ["variable"] = { - after = function(node: Node, _children: {Type}): Type + after = function(self: TypeChecker, node: Node, _children: {Type}): Type if node.tk == "..." then - local va_sentinel = find_var_type("@is_va") + local va_sentinel = self:find_var_type("@is_va") if not va_sentinel or va_sentinel.typename == "nil" then - return invalid_at(node, "cannot use '...' outside a vararg function") + return self.errs:invalid_at(node, "cannot use '...' outside a vararg function") end end local t: Type if node.tk == "_G" then - t, node.attribute = simulate_g() + t, node.attribute = self:simulate_g() else local use: VarUse = node.is_lvalue and "lvalue" or "use" - t, node.attribute = find_var_type(node.tk, use) + t, node.attribute = self:find_var_type(node.tk, use) end if not t then - if lax then - add_unknown(node, node.tk) - return UNKNOWN + if self.feat_lax then + self.errs:add_unknown(node, node.tk) + return an_unknown(node) end - return invalid_at(node, "unknown variable: " .. node.tk) + return self.errs:invalid_at(node, "unknown variable: " .. node.tk) end if t is TypeDeclType then @@ -11876,70 +11913,70 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end, }, ["type_identifier"] = { - after = function(node: Node, _children: {Type}): Type - local typ, attr = find_var_type(node.tk) + after = function(self: TypeChecker, node: Node, _children: {Type}): Type + local typ, attr = self:find_var_type(node.tk) node.attribute = attr if typ then return typ end - if lax then - add_unknown(node, node.tk) - return UNKNOWN + if self.feat_lax then + self.errs:add_unknown(node, node.tk) + return an_unknown(node) end - return invalid_at(node, "unknown variable: " .. node.tk) + return self.errs:invalid_at(node, "unknown variable: " .. node.tk) end, }, ["argument"] = { - after = function(node: Node, children: {Type}): Type + after = function(self: TypeChecker, node: Node, children: {Type}): Type local t = children[1] if not t then - t = UNKNOWN + t = an_unknown(node) end if node.tk == "..." then - t = a_vararg { t } + t = a_vararg(node, { t }) end - add_var(node, node.tk, t).is_func_arg = true + self:add_var(node, node.tk, t).is_func_arg = true return t end, }, ["identifier"] = { - after = function(_node: Node, _children: {Type}): Type + after = function(_self: TypeChecker, _node: Node, _children: {Type}): Type return NONE -- type is resolved elsewhere end, }, ["newtype"] = { - after = function(node: Node, _children: {Type}): Type + after = function(_self: TypeChecker, node: Node, _children: {Type}): Type return node.newtype end, }, ["error_node"] = { - after = function(_node: Node, _children: {Type}): Type - return INVALID + after = function(_self: TypeChecker, node: Node, _children: {Type}): Type + return an_invalid(node) end, } } visit_node.cbs["break"] = { - after = function(_node: Node, _children: {Type}): Type + after = function(_self: TypeChecker, _node: Node, _children: {Type}): Type return NONE end, } visit_node.cbs["do"] = visit_node.cbs["break"] - local function after_literal(node: Node): Type + local function after_literal(_self: TypeChecker, node: Node): Type node.known = FACT_TRUTHY - return type_at(node, a_type(node.kind as TypeName, {})) + return a_type(node, node.kind as TypeName, {}) end visit_node.cbs["string"] = { - after = function(node: Node, _children: {Type}): Type - local t = after_literal(node) as StringType + after = function(self: TypeChecker, node: Node, _children: {Type}): Type + local t = after_literal(self, node) as StringType t.literal = node.conststr - local expected = node.expected and to_structural(node.expected) - if expected and expected is EnumType and is_a(t, expected) then + local expected = node.expected and self:to_structural(node.expected) + if expected and expected is EnumType and self:is_a(t, expected) then return node.expected end @@ -11950,8 +11987,8 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string visit_node.cbs["integer"] = { after = after_literal } visit_node.cbs["boolean"] = { - after = function(node: Node, _children: {Type}): Type - local t = after_literal(node) + after = function(self: TypeChecker, node: Node, _children: {Type}): Type + local t = after_literal(self, node) node.known = (node.tk == "true") and FACT_TRUTHY or nil return t end, @@ -11962,7 +11999,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string visit_node.cbs["argument_list"] = visit_node.cbs["variable_list"] visit_node.cbs["expression_list"] = visit_node.cbs["variable_list"] - visit_node.after = function(node: Node, _children: {Type}, t: Type): Type + visit_node.after = function(_self: TypeChecker, node: Node, _children: {Type}, t: Type): Type if node.expanded then apply_macroexp(node) end @@ -11970,13 +12007,12 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return t end - local expand_interfaces: function(Type) do - local function add_interface_fields(what: string, fields: {string:Type}, field_order: {string}, resolved: RecordLikeType, named: NominalType, list?: MetaMode) + local function add_interface_fields(self: TypeChecker, what: string, fields: {string:Type}, field_order: {string}, resolved: RecordLikeType, named: NominalType, list?: MetaMode) for fname, ftype in fields_of(resolved, list) do if fields[fname] then - if not is_a(fields[fname], ftype) then - error_at(fields[fname], what .." '" .. fname .. "' does not match definition in interface %s", named) + if not self:is_a(fields[fname], ftype) then + self.errs:add(fields[fname], what .." '" .. fname .. "' does not match definition in interface %s", named) end else table.insert(field_order, fname) @@ -11985,18 +12021,21 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - local function collect_interfaces(list: {ArrayType | NominalType}, t: RecordLikeType, seen:{Type:boolean}): {ArrayType | NominalType} + local function collect_interfaces(self: TypeChecker, list: {ArrayType | NominalType}, t: RecordLikeType, seen:{Type:boolean}): {ArrayType | NominalType} if t.interface_list then for _, iface in ipairs(t.interface_list) do if iface is NominalType then - local ri = resolve_nominal(iface) + local ri = self:resolve_nominal(iface) if not (ri.typename == "invalid") then - assert(ri is InterfaceType, "nominal resolved to " .. ri.typename) - if not ri.interfaces_expanded and not seen[ri] then - seen[ri] = true - collect_interfaces(list, ri, seen) + if ri is InterfaceType then + if not ri.interfaces_expanded and not seen[ri] then + seen[ri] = true + collect_interfaces(self, list, ri, seen) + end + table.insert(list, iface) + else + self.errs:add(iface, "attempted to use %s as interface, but its type is %s", iface, ri) end - table.insert(list, iface) end else if not seen[iface] then @@ -12009,30 +12048,30 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return list end - expand_interfaces = function(t: RecordLikeType) + function TypeChecker:expand_interfaces(t: RecordLikeType) if t.interfaces_expanded then return end t.interfaces_expanded = true - t.interface_list = collect_interfaces({}, t, {}) + t.interface_list = collect_interfaces(self, {}, t, {}) for _, iface in ipairs(t.interface_list) do if iface is NominalType then - local ri = resolve_nominal(iface) + local ri = self:resolve_nominal(iface) assert(ri is InterfaceType) - add_interface_fields("field", t.fields, t.field_order, ri, iface) + add_interface_fields(self, "field", t.fields, t.field_order, ri, iface) if ri.meta_fields then t.meta_fields = t.meta_fields or {} t.meta_field_order = t.meta_field_order or {} - add_interface_fields("metamethod", t.meta_fields, t.meta_field_order, ri, iface, "meta") + add_interface_fields(self, "metamethod", t.meta_fields, t.meta_field_order, ri, iface, "meta") end else if not t.elements then t.elements = iface else - if not same_type(iface.elements, t.elements) then - error_at(t, "incompatible array interfaces") + if not self:same_type(iface.elements, t.elements) then + self.errs:add(t, "incompatible array interfaces") end end end @@ -12040,33 +12079,33 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - local visit_type: Visitor + local visit_type: Visitor visit_type = { cbs = { ["function"] = { - before = function(_typ: Type) - begin_scope() + before = function(self: TypeChecker, _typ: Type) + self:begin_scope() end, - after = function(typ: Type, _children: {Type}): Type - end_scope() - return ensure_fresh_typeargs(typ) + after = function(self: TypeChecker, typ: Type, _children: {Type}): Type + self:end_scope() + return self:ensure_fresh_typeargs(typ) end, }, ["record"] = { - before = function(typ: RecordType) - begin_scope() - add_var(nil, "@self", type_at(typ, a_typedecl(typ))) + before = function(self: TypeChecker, typ: RecordType) + self:begin_scope() + self:add_var(nil, "@self", type_at(typ, a_typedecl(typ, typ))) for fname, ftype in fields_of(typ) do if ftype is TypeAliasType then - resolve_nominal(ftype.alias_to) - add_var(nil, fname, ftype) + self:resolve_nominal(ftype.alias_to) + self:add_var(nil, fname, ftype) elseif ftype is TypeDeclType then - add_var(nil, fname, ftype) + self:add_var(nil, fname, ftype) end end end, - after = function(typ: RecordType, children: {Type}): Type + after = function(self: TypeChecker, typ: RecordType, children: {Type}): Type local i = 1 if typ.typeargs then for _, _ in ipairs(typ.typeargs) do @@ -12080,11 +12119,11 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if iface is ArrayType then typ.interface_list[j] = iface elseif iface is NominalType then - local ri = resolve_nominal(iface) + local ri = self:resolve_nominal(iface) if ri is InterfaceType then typ.interface_list[j] = iface else - error_at(children[i], "%s is not an interface", children[i]) + self.errs:add(children[i], "%s is not an interface", children[i]) end end i = i + 1 @@ -12124,7 +12163,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end elseif ftype is TypeAliasType then - resolve_typealias(ftype) + self:resolve_typealias(ftype) end typ.fields[name] = ftype @@ -12143,55 +12182,55 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end if typ.interface_list then - expand_interfaces(typ) + self:expand_interfaces(typ) end if fmacros then for _, t in ipairs(fmacros) do - local macroexp_type = recurse_node(t.macroexp, visit_node, visit_type) + local macroexp_type = recurse_node(self, t.macroexp, visit_node, visit_type) - check_macroexp_arg_use(t.macroexp) + self:check_macroexp_arg_use(t.macroexp) - if not is_a(macroexp_type, t) then - error_at(macroexp_type, "macroexp type does not match declaration") + if not self:is_a(macroexp_type, t) then + self.errs:add(macroexp_type, "macroexp type does not match declaration") end end end - end_scope() + self:end_scope() return typ end, }, ["typearg"] = { - after = function(typ: TypeArgType, _children: {Type}): Type - add_var(nil, typ.typearg, type_at(typ, a_type("typearg", { + after = function(self: TypeChecker, typ: TypeArgType, _children: {Type}): Type + self:add_var(nil, typ.typearg, a_type(typ, "typearg", { typearg = typ.typearg, constraint = typ.constraint, - } as TypeArgType))) + } as TypeArgType)) return typ end, }, ["typevar"] = { - after = function(typ: TypeVarType, _children: {Type}): Type - if not find_var_type(typ.typevar) then - error_at(typ, "undefined type variable " .. typ.typevar) + after = function(self: TypeChecker, typ: TypeVarType, _children: {Type}): Type + if not self:find_var_type(typ.typevar) then + self.errs:add(typ, "undefined type variable " .. typ.typevar) end return typ end, }, ["nominal"] = { - after = function(typ: NominalType, _children: {Type}): Type + after = function(self: TypeChecker, typ: NominalType, _children: {Type}): Type if typ.found then return typ end - local t = find_type(typ.names, true) + local t = self:find_type(typ.names, true) if t then if t is TypeArgType then -- convert nominal into a typevar typ.names = nil - edit_type(typ, "typevar") + edit_type(typ, typ, "typevar") local tv = typ as TypeVarType tv.typevar = t.typearg tv.constraint = t.constraint @@ -12202,18 +12241,19 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end else local name = typ.names[1] - local unresolved = get_unresolved() - unresolved.nominals[name] = unresolved.nominals[name] or {} - table.insert(unresolved.nominals[name], typ) + local scope = self.st[#self.st] + scope.pending_nominals = scope.pending_nominals or {} + scope.pending_nominals[name] = scope.pending_nominals[name] or {} + table.insert(scope.pending_nominals[name], typ) end return typ end, }, ["union"] = { - after = function(typ: UnionType, _children: {Type}): Type + after = function(self: TypeChecker, typ: UnionType, _children: {Type}): Type local ok, err = is_valid_union(typ) if not ok then - return err and invalid_at(typ, err, typ) or INVALID + return err and self.errs:invalid_at(typ, err, typ) or an_invalid(typ) end return typ end @@ -12221,59 +12261,8 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string }, } - local function internal_compiler_check(fn: function(W, {Type}, Type): (Type)): (function(W, {Type}, Type): (Type)) - return function(w: W, children: {Type}, t: Type): Type - t = fn and fn(w, children, t) or t - - if type(t) ~= "table" then - error(((w as Node).kind or (w as Type).typename) .. " did not produce a type") - end - if type(t.typename) ~= "string" then - error(((w as Node).kind or (w as Type).typename) .. " type does not have a typename") - end - - return t - end - end - - local function store_type_after(fn: function(W, {Type}, Type): (Type)): (function(W, {Type}, Type): (Type)) - return function(w: W, children: {Type}, t: Type): Type - t = fn and fn(w, children, t) or t - - local where = w as Where - - if where.y then - tc.store_type(where.y, where.x, t) - end - - return t - end - end - - local function debug_type_after(fn: function(Node, {Type}, Type): (Type)): (function(Node, {Type}, Type): (Type)) - return function(node: Node, children: {Type}, t: Type): Type - t = fn and fn(node, children, t) or t - node.debug_type = t - return t - end - end - - if opts.run_internal_compiler_checks then - visit_node.after = internal_compiler_check(visit_node.after) - visit_type.after = internal_compiler_check(visit_type.after) - end - - if tc then - visit_node.after = store_type_after(visit_node.after) - visit_type.after = store_type_after(visit_type.after) - end - - if TL_DEBUG then - visit_node.after = debug_type_after(visit_node.after) - end - local default_type_visitor = { - after = function(typ: Type, _children: {Type}): Type + after = function(_self: TypeChecker, typ: Type, _children: {Type}): Type return typ end, } @@ -12300,70 +12289,201 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string visit_type.cbs["any"] = default_type_visitor visit_type.cbs["unknown"] = default_type_visitor visit_type.cbs["invalid"] = default_type_visitor - visit_type.cbs["unresolved"] = default_type_visitor visit_type.cbs["none"] = default_type_visitor - assert(ast.kind == "statements") - recurse_node(ast, visit_node, visit_type) + local type VisitorAfterPatcher = function(VisitorAfter): VisitorAfter - close_types(st[1]) - check_for_unused_vars(st[1], true) + local function internal_compiler_check(fn: VisitorAfter): VisitorAfter + return function(s: S, n: N, children: {Type}, t: Type): Type + t = fn and fn(s, n, children, t) or t + + if type(t) ~= "table" then + error(((n as Node).kind or (n as Type).typename) .. " did not produce a type") + end + if type(t.typename) ~= "string" then + error(((n as Node).kind or (n as Type).typename) .. " type does not have a typename") + end - clear_redundant_errors(errors) + return t + end + end - add_compat_entries(ast, all_needs_compat, env.gen_compat) + local function store_type_after(fn: VisitorAfter): VisitorAfter + return function(self: TypeChecker, n: N, children: {Type}, t: Type): Type + t = fn and fn(self, n, children, t) or t - local result = { - ast = ast, - env = env, - type = module_type or BOOLEAN, - filename = filename, - warnings = warnings, - type_errors = errors, - dependencies = dependencies, - } + local w = n as Where - env.loaded[filename] = result - table.insert(env.loaded_order, filename) + if w.y then + self.collector.store_type(w.y, w.x, t) + end - if tc then - env.reporter:store_result(tc, env.globals) + return t + end end - return result -end + local function debug_type_after(fn: VisitorAfter): VisitorAfter + return function(s: S, node: Node, children: {Type}, t: Type): Type + t = fn and fn(s, node, children, t) or t --------------------------------------------------------------------------------- --- Report types --------------------------------------------------------------------------------- + node.debug_type = t + return t + end + end -function tl.symbols_in_scope(tr: TypeReport, y: integer, x: integer): {string:integer} - local function find(symbols: {{integer, integer, string, integer}}, at_y: integer, at_x: integer): integer - local function le(a: {integer, integer}, b: {integer, integer}): boolean - return a[1] < b[1] - or (a[1] == b[1] and a[2] <= b[2]) + local function patch_visitors(my_visit_node: Visitor, + after_node: VisitorAfterPatcher, + my_visit_type?: Visitor, + after_type?: VisitorAfterPatcher): + Visitor, + Visitor + if my_visit_node == visit_node then + my_visit_node = shallow_copy_table(my_visit_node) end - return binary_search(symbols, {at_y, at_x}, le) or 0 + my_visit_node.after = after_node(my_visit_node.after) + if my_visit_type then + if my_visit_type == visit_type then + my_visit_type = shallow_copy_table(my_visit_type) + end + my_visit_type.after = after_type(my_visit_type.after) + else + my_visit_type = visit_type + end + return my_visit_node, my_visit_type end - local ret: {string:integer} = {} + local function set_feat(feat: Feat, default: boolean): boolean + if feat then + return (feat == "on") + else + return default + end + end - local n = find(tr.symbols, y, x) + tl.type_check = function(ast: Node, filename: string, opts: TypeCheckOptions, env?: Env): Result, string + assert(filename is string, "tl.type_check signature has changed, pass filename separately") + assert((not opts) or (not (opts as {any:any}).env), "tl.type_check signature has changed, pass env separately") - local symbols = tr.symbols - while n >= 1 do - local s = symbols[n] - if s[3] == "@{" then - n = n - 1 - elseif s[3] == "@}" then - n = s[4] + filename = filename or "?" + + opts = opts or {} + + if not env then + local err: string + env, err = tl.new_env({ defaults = opts }) + if err then + return nil, err + end + end + + local self: TypeChecker = { + filename = filename, + env = env, + st = { + { + vars = env.globals, + pending_global_types = {}, + }, + }, + errs = Errors.new(filename), + all_needs_compat = {}, + dependencies = {}, + subtype_relations = TypeChecker.subtype_relations, + eqtype_relations = TypeChecker.eqtype_relations, + type_priorities = TypeChecker.type_priorities, + } + + setmetatable(self, { __index = TypeChecker }) + + self.feat_lax = set_feat(opts.feat_lax or env.defaults.feat_lax, false) + self.feat_arity = set_feat(opts.feat_arity or env.defaults.feat_arity, true) + self.gen_compat = opts.gen_compat or env.defaults.gen_compat or DEFAULT_GEN_COMPAT + self.gen_target = opts.gen_target or env.defaults.gen_target or DEFAULT_GEN_TARGET + + if self.gen_target == "5.4" and self.gen_compat ~= "off" then + return nil, "gen-compat must be explicitly 'off' when gen-target is '5.4'" + end + + if self.feat_lax then + self.type_priorities = shallow_copy_table(self.type_priorities) + self.type_priorities["unknown"] = 0 + + self.subtype_relations = shallow_copy_table(self.subtype_relations) + + self.subtype_relations["unknown"] = {} + self.subtype_relations["unknown"]["*"] = compare_true + + self.subtype_relations["*"] = shallow_copy_table(self.subtype_relations["*"]) + self.subtype_relations["*"]["unknown"] = compare_true + -- in .lua files, all values can be used in a boolean context + self.subtype_relations["*"]["boolean"] = compare_true + + self.get_rets = function(rets: TupleType): TupleType + if #rets.tuple == 0 then + return a_vararg(rets, { an_unknown(rets) }) + end + return rets + end else - ret[s[3]] = s[4] - n = n - 1 + self.get_rets = function(rets: TupleType): TupleType + return rets + end end - end - return ret + if env.report_types then + env.reporter = env.reporter or tl.new_type_reporter() + self.collector = env.reporter:get_collector(filename) + end + + local visit_node, visit_type = visit_node, visit_type + if opts.run_internal_compiler_checks then + visit_node, visit_type = patch_visitors( + visit_node, internal_compiler_check, + visit_type, internal_compiler_check + ) + end + if self.collector then + visit_node, visit_type = patch_visitors( + visit_node, store_type_after, + visit_type, store_type_after + ) + end + if TL_DEBUG then + visit_node, visit_type = patch_visitors( + visit_node, debug_type_after + ) + end + + assert(ast.kind == "statements") + recurse_node(self, ast, visit_node, visit_type) + + local global_scope = self.st[1] + close_types(global_scope) + self.errs:warn_unused_vars(global_scope, true) + + clear_redundant_errors(self.errs.errors) + + add_compat_entries(ast, self.all_needs_compat, self.gen_compat) + + local result = { + ast = ast, + env = env, + type = self.module_type or a_type(ast, "boolean", {}), + filename = filename, + warnings = self.errs.warnings, + type_errors = self.errs.errors, + dependencies = self.dependencies, + } + + env.loaded[filename] = result + table.insert(env.loaded_order, filename or "") + + if self.collector then + env.reporter:store_result(self.collector, env.globals) + end + + return result + end end -------------------------------------------------------------------------------- @@ -12379,9 +12499,24 @@ local function read_full_file(fd: FILE): string, string return content, err end -tl.process = function(filename: string, env: Env, fd?: FILE): Result, string - assert((not fd or type(fd) ~= "string"), "fd must be a file") +local function feat_lax_heuristic(filename?: string, input?: string): Feat + if filename then + local _, extension = filename:match("(.*)%.([a-z]+)$") + extension = extension and extension:lower() + + if extension == "tl" then + return "off" + elseif extension == "lua" then + return "on" + end + end + if input then + return (input:match("^#![^\n]*lua[^\n]*\n")) and "on" or "off" + end + return "off" +end +tl.process = function(filename: string, env: Env, fd?: FILE): Result, string if env and env.loaded and env.loaded[filename] then return env.loaded[filename] end @@ -12401,23 +12536,38 @@ tl.process = function(filename: string, env: Env, fd?: FILE): Result, string return nil, "could not read " .. filename .. ": " .. err end - local _, extension = filename:match("(.*)%.([a-z]+)$") - extension = extension and extension:lower() + return tl.process_string(input, env, filename) +end - local is_lua: boolean - if extension == "tl" then - is_lua = false - elseif extension == "lua" then - is_lua = true - else - is_lua = input:match("^#![^\n]*lua[^\n]*\n") as boolean +function tl.target_from_lua_version(str: string): GenTarget + if str == "Lua 5.1" + or str == "Lua 5.2" then + return "5.1" + elseif str == "Lua 5.3" then + return "5.3" + elseif str == "Lua 5.4" then + return "5.4" end +end - return tl.process_string(input, is_lua, env, filename) +local function default_env_opts(runtime: boolean, filename?: string, input?: string): EnvOptions + local gen_target = runtime and tl.target_from_lua_version(_VERSION) or DEFAULT_GEN_TARGET + local gen_compat: GenCompat = (gen_target == "5.4") and "off" or DEFAULT_GEN_COMPAT + return { + defaults = { + feat_lax = feat_lax_heuristic(filename, input), + gen_target = gen_target, + gen_compat = gen_compat, + run_internal_compiler_checks = false, + } + } end -function tl.process_string(input: string, is_lua: boolean, env: Env, filename?: string): Result - env = env or tl.init_env(is_lua) +function tl.process_string(input: string, env?: Env, filename?: string): Result + assert(type(env) ~= "boolean", "tl.process_string signature has changed") + + env = env or tl.new_env(default_env_opts(false, filename, input)) + if env.loaded and env.loaded[filename] then return env.loaded[filename] end @@ -12429,7 +12579,7 @@ function tl.process_string(input: string, is_lua: boolean, env: Env, filename?: local result = { ok = false, filename = filename, - type = BOOLEAN, + type = a_type({ f = filename, y = 1, x = 1 }, "boolean", {}), type_errors = {}, syntax_errors = syntax_errors, env = env, @@ -12439,14 +12589,7 @@ function tl.process_string(input: string, is_lua: boolean, env: Env, filename?: return result end - local opts: TypeCheckOptions = { - filename = filename, - lax = is_lua, - gen_compat = env.gen_compat, - gen_target = env.gen_target, - env = env, - } - local result = tl.type_check(program, opts) + local result = tl.type_check(program, filename, env.defaults, env) result.syntax_errors = syntax_errors @@ -12454,15 +12597,15 @@ function tl.process_string(input: string, is_lua: boolean, env: Env, filename?: end tl.gen = function(input: string, env: Env, pp: PrettyPrintOptions): string, Result - env = env or assert(tl.init_env(), "Default environment initialization failed") - local result = tl.process_string(input, false, env) + env = env or assert(tl.new_env(default_env_opts(false, nil, input)), "Default environment initialization failed") + local result = tl.process_string(input, env) if (not result.ast) or #result.syntax_errors > 0 then return nil, result end local code: string - code, result.gen_error = tl.pretty_print_ast(result.ast, env.gen_target, pp) + code, result.gen_error = tl.pretty_print_ast(result.ast, env.defaults.gen_target, pp) return code, result end @@ -12478,28 +12621,25 @@ local function tl_package_loader(module_name: string): any, any if #errs > 0 then error(found_filename .. ":" .. errs[1].y .. ":" .. errs[1].x .. ": " .. errs[1].msg) end - local lax = not not found_filename:match("lua$") local env = tl.package_loader_env if not env then - tl.package_loader_env = tl.init_env(lax) + tl.package_loader_env = assert(tl.new_env(), "Default environment initialization failed") env = tl.package_loader_env end - env.modules[module_name] = a_typedecl(CIRCULAR_REQUIRE) + local opts = default_env_opts(true, found_filename) - local result = tl.type_check(program, { - lax = lax, - filename = found_filename, - env = env, - run_internal_compiler_checks = false, - }) + local w = { f = found_filename, x = 1, y = 1 } + env.modules[module_name] = a_typedecl(w, a_type(w, "circular_require", {})) + + local result = tl.type_check(program, found_filename, opts.defaults, env) env.modules[module_name] = result.type -- TODO: should this be a hard error? this seems analogous to -- finding a lua file with a syntax error in it - local code = assert(tl.pretty_print_ast(program, env.gen_target, true)) + local code = assert(tl.pretty_print_ast(program, opts.defaults.gen_target, true)) local chunk, err = load(code, "@" .. found_filename, "t") if chunk then return function(modname: string, loader_data: string): any @@ -12525,21 +12665,10 @@ function tl.loader() end end -function tl.target_from_lua_version(str: string): TargetMode - if str == "Lua 5.1" - or str == "Lua 5.2" then - return "5.1" - elseif str == "Lua 5.3" then - return "5.3" - elseif str == "Lua 5.4" then - return "5.4" - end -end - -local function env_for(lax: boolean, env_tbl: {any:any}): Env +local function env_for(opts: EnvOptions, env_tbl: {any:any}): Env if not env_tbl then if not tl.package_loader_env then - tl.package_loader_env = tl.init_env(lax) + tl.package_loader_env = tl.new_env(opts) end return tl.package_loader_env end @@ -12548,7 +12677,7 @@ local function env_for(lax: boolean, env_tbl: {any:any}): Env tl.load_envs = setmetatable({}, { __mode = "k" }) end - tl.load_envs[env_tbl] = tl.load_envs[env_tbl] or tl.init_env(lax) + tl.load_envs[env_tbl] = tl.load_envs[env_tbl] or tl.new_env(opts) return tl.load_envs[env_tbl] end @@ -12558,17 +12687,14 @@ tl.load = function(input: string, chunkname: string, mode: LoadMode, ...: {any:a return nil, (chunkname or "") .. ":" .. errs[1].y .. ":" .. errs[1].x .. ": " .. errs[1].msg end - local lax = chunkname and not not chunkname:match("lua$") + local opts = default_env_opts(true, chunkname) + if not tl.package_loader_env then - tl.package_loader_env = tl.init_env(lax) + tl.package_loader_env = tl.new_env(opts) end - local result = tl.type_check(program, { - lax = lax, - filename = chunkname or ("string \"" .. input:sub(45) .. (#input > 45 and "..." or "") .. "\""), - env = env_for(lax, ...), - run_internal_compiler_checks = false, - }) + local filename = chunkname or ("string \"" .. input:sub(45) .. (#input > 45 and "..." or "") .. "\"") + local result = tl.type_check(program, filename, opts.defaults, env_for(opts, ...)) if mode and mode:match("c") then if #result.type_errors > 0 then @@ -12582,7 +12708,7 @@ tl.load = function(input: string, chunkname: string, mode: LoadMode, ...: {any:a mode = mode:gsub("c", "") as LoadMode end - local code, err = tl.pretty_print_ast(program, tl.target_from_lua_version(_VERSION), true) + local code, err = tl.pretty_print_ast(program, opts.defaults.gen_target, true) if not code then return nil, err end @@ -12590,4 +12716,29 @@ tl.load = function(input: string, chunkname: string, mode: LoadMode, ...: {any:a return load(code, chunkname, mode, ...) end +-------------------------------------------------------------------------------- +-- Backwards compatibility +-------------------------------------------------------------------------------- + +function tl.get_types(result: Result): TypeReport, TypeReporter + return result.env.reporter:get_report(), result.env.reporter +end + +tl.init_env = function(lax?: boolean, gen_compat?: boolean | GenCompat, gen_target?: GenTarget, predefined?: {string}): Env, string + local opts = { + defaults = { + feat_lax = (lax and "on" or "off") as Feat, + gen_compat = ((gen_compat is GenCompat) and gen_compat) or + (gen_compat == false and "off") or + (gen_compat == true or gen_compat == nil) and "optional", + gen_target = gen_target or + ((_VERSION == "Lua 5.1" or _VERSION == "Lua 5.2") and "5.1") or + "5.3", + }, + predefined_modules = predefined, + } + + return tl.new_env(opts) +end + return tl From 4ead134e78a072b188cdab3de5563aa19cec0148 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Fri, 26 Jan 2024 17:50:18 -0300 Subject: [PATCH 118/224] fix: check if for iterator does not return values Fixes #736. --- spec/statement/forin_spec.lua | 10 ++++++++++ tl.lua | 8 +++++++- tl.tl | 8 +++++++- 3 files changed, 24 insertions(+), 2 deletions(-) diff --git a/spec/statement/forin_spec.lua b/spec/statement/forin_spec.lua index 9f93f131b..c473b1c55 100644 --- a/spec/statement/forin_spec.lua +++ b/spec/statement/forin_spec.lua @@ -188,6 +188,16 @@ describe("forin", function() })) describe("regression tests", function() + it("catches if iterator function does not return values (#736)", util.check_type_error([[ + local function f() + end + + for k, v in f() do + end + ]], { + { y = 4, msg = "expression in 'for' statement does not return any values" }, + })) + it("does not accept annotations (#701)", util.check_syntax_error([[ for k , v in pairs(table as {string:any}) do k = "hello" diff --git a/tl.lua b/tl.lua index dba438acf..d0b82df11 100644 --- a/tl.lua +++ b/tl.lua @@ -10881,8 +10881,14 @@ self:expand_type(node, values, elements) }) assert(exptuple.typename == "tuple") local exptypes = exptuple.tuple - self:widen_all_unions(node) local exp1 = node.exps[1] + if #exptypes < 1 then + self.errs:invalid_at(exp1, "expression in 'for' statement does not return any values") + return + end + + self:widen_all_unions(node) + local args = a_type(node.exps, "tuple", { tuple = { node.exps[2] and exptypes[2], node.exps[3] and exptypes[3], diff --git a/tl.tl b/tl.tl index fea76409c..2ff5fa7b3 100644 --- a/tl.tl +++ b/tl.tl @@ -10881,8 +10881,14 @@ do assert(exptuple is TupleType) local exptypes = exptuple.tuple - self:widen_all_unions(node) local exp1 = node.exps[1] + if #exptypes < 1 then + self.errs:invalid_at(exp1, "expression in 'for' statement does not return any values") + return + end + + self:widen_all_unions(node) + local args = a_tuple(node.exps, { node.exps[2] and exptypes[2], node.exps[3] and exptypes[3] From 30515ef60893ebe52194edb356300959e9f146da Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Fri, 26 Jan 2024 17:59:14 -0300 Subject: [PATCH 119/224] fix: lax arity of returns in function literals The unwanted difference between the types for `f` and `g` in the regression test from this commit was observed when fixing the issue #736. --- spec/lax/lax_spec.lua | 23 +++++++++++++++++++++++ tl.lua | 2 +- tl.tl | 2 +- 3 files changed, 25 insertions(+), 2 deletions(-) diff --git a/spec/lax/lax_spec.lua b/spec/lax/lax_spec.lua index 92a9d6061..f3047374b 100644 --- a/spec/lax/lax_spec.lua +++ b/spec/lax/lax_spec.lua @@ -20,4 +20,27 @@ describe("lax mode", function() { msg = "three" }, { msg = "data" }, })) + + it("defines lax arity of returns in function literals (#736)", util.lax_check([[ + -- f: function(unknown...):unknown... + local f = function() + return function() end + end + + for a, b in f() do + end + + -- g: function(unknown...):unknown... + local function g() + return function() end + end + + for x, y in g() do + end + ]], { + { msg = "a" }, + { msg = "b" }, + { msg = "x" }, + { msg = "y" }, + })) end) diff --git a/tl.lua b/tl.lua index d0b82df11..60c9f57f5 100644 --- a/tl.lua +++ b/tl.lua @@ -11459,7 +11459,7 @@ self:expand_type(node, values, elements) }) min_arity = node.min_arity, typeargs = node.typeargs, args = args, - rets = rets, + rets = self.get_rets(rets), })) end, }, diff --git a/tl.tl b/tl.tl index 2ff5fa7b3..eac3f9590 100644 --- a/tl.tl +++ b/tl.tl @@ -11459,7 +11459,7 @@ do min_arity = node.min_arity, typeargs = node.typeargs, args = args, - rets = rets, + rets = self.get_rets(rets), })) end, }, From 9701800d59ff2b140e13081e3c0212ca11d7ad25 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Fri, 26 Jan 2024 18:21:08 -0300 Subject: [PATCH 120/224] fix: let lax mode perform emptytable key-value inference as normal Do not over-constrain types with `unknown`, let `assert_is_a` perform the check when t2 is `unresolved_emptytable_value`. --- spec/lax/lax_spec.lua | 13 +++++++++++++ tl.lua | 2 +- tl.tl | 2 +- 3 files changed, 15 insertions(+), 2 deletions(-) diff --git a/spec/lax/lax_spec.lua b/spec/lax/lax_spec.lua index f3047374b..e36b8d5b1 100644 --- a/spec/lax/lax_spec.lua +++ b/spec/lax/lax_spec.lua @@ -43,4 +43,17 @@ describe("lax mode", function() { msg = "x" }, { msg = "y" }, })) + + it("performs emptytable key-value inference as normal", util.lax_check([[ + local t = {} + + local s = "str" + + t[s] = 9 + + for k, v in pairs(t) do + print(k, v) + end + ]], {})) + end) diff --git a/tl.lua b/tl.lua index 60c9f57f5..d1eb87961 100644 --- a/tl.lua +++ b/tl.lua @@ -8585,7 +8585,7 @@ a.types[i], b.types[i]), } function TypeChecker:assert_is_a(w, t1, t2, ctx, name) t1 = resolve_tuple(t1) t2 = resolve_tuple(t2) - if self.feat_lax and (is_unknown(t1) or is_unknown(t2)) then + if self.feat_lax and (is_unknown(t1) or t2.typename == "unknown") then return true end diff --git a/tl.tl b/tl.tl index eac3f9590..33d0c70c9 100644 --- a/tl.tl +++ b/tl.tl @@ -8585,7 +8585,7 @@ do function TypeChecker:assert_is_a(w: Where, t1: Type, t2: Type, ctx?: string | Node, name?: string): boolean t1 = resolve_tuple(t1) t2 = resolve_tuple(t2) - if self.feat_lax and (is_unknown(t1) or is_unknown(t2)) then + if self.feat_lax and (is_unknown(t1) or t2 is UnknownType) then return true end From 705b61e9d3333c7a9da53bae4369313f26efd012 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Mon, 13 May 2024 14:05:46 -0300 Subject: [PATCH 121/224] fix 'where'/macroexp behavior with type arguments --- spec/code_gen/macroexp_spec.lua | 84 ++++++++++++++++++++++++++++ spec/declaration/record_spec.lua | 41 ++++++++++++++ tl.lua | 95 +++++++++++++++++++++++++------- tl.tl | 95 +++++++++++++++++++++++++------- 4 files changed, 275 insertions(+), 40 deletions(-) create mode 100644 spec/code_gen/macroexp_spec.lua diff --git a/spec/code_gen/macroexp_spec.lua b/spec/code_gen/macroexp_spec.lua new file mode 100644 index 000000000..83351d30d --- /dev/null +++ b/spec/code_gen/macroexp_spec.lua @@ -0,0 +1,84 @@ +local util = require("spec.util") + +describe("macroexp code generation", function() + it("can use where with generic types", util.gen([[ + local type Success = record + where self.error == false + + error: boolean + value: T + end + + local type Failure = record + where self.error == true + + error: boolean + value: T + end + + local function ok(value: T): Success + return { + error = false, + value = value, + } + end + + local function fail(value: T): Failure + return { + error = true, + value = value, + } + end + + local type Maybe = Success | Failure + + local function call_me(maybe: Maybe) + if maybe is Success then + print("hello, " .. tostring(maybe.value)) + end + end + + call_me(ok(8675309)) + call_me(fail(911)) + ]], [[ + + + + + + + + + + + + + + + local function ok(value) + return { + error = false, + value = value, + } + end + + local function fail(value) + return { + error = true, + value = value, + } + end + + + + local function call_me(maybe) + if maybe.error == false then + print("hello, " .. tostring(maybe.value)) + end + end + + call_me(ok(8675309)) + call_me(fail(911)) + ]])) +end) + diff --git a/spec/declaration/record_spec.lua b/spec/declaration/record_spec.lua index 7dc5b018d..a188a482c 100644 --- a/spec/declaration/record_spec.lua +++ b/spec/declaration/record_spec.lua @@ -438,6 +438,47 @@ for i, name in ipairs({"records", "arrayrecords", "interfaces", "arrayinterfaces end ]])) + it("can use where with generic types", util.check([[ + local type Success = ]]..statement..[[ ]]..array(i, "is {integer}")..[[ + where self.error == false + + error: boolean + value: T + end + + local type Failure = ]]..statement..[[ ]]..array(i, "is {integer}")..[[ + where self.error == true + + error: boolean + value: T + end + + local function ok(value: T): Success + return { + error = false, + value = value, + } + end + + local function fail(value: T): Failure + return { + error = true, + value = value, + } + end + + local type Maybe = Success | Failure + + local function call_me(maybe: Maybe) + if maybe is Success then + print("hello, " .. tostring(maybe.value)) + end + end + + call_me(ok(8675309)) + call_me(fail(911)) + ]])) + if statement == "record" then it("does not produce an esoteric type error (#167)", util.check_type_error([[ local type foo = ]]..statement..[[ ]]..array(i, "{foo}")..[[ diff --git a/tl.lua b/tl.lua index d1eb87961..142f39176 100644 --- a/tl.lua +++ b/tl.lua @@ -1891,6 +1891,7 @@ end + local TruthyFact = {} @@ -3615,13 +3616,16 @@ do } local function parse_macroexp(ps, istart, iargs) - - - - local node = new_node(ps, istart, "macroexp") + local i - i, node.args, node.min_arity = parse_argument_list(ps, iargs) + if ps.tokens[istart + 1].tk == "<" then + i, node.typeargs = parse_anglebracket_list(ps, istart + 1, parse_typearg) + else + i = iargs + end + + i, node.args, node.min_arity = parse_argument_list(ps, i) i, node.rets = parse_return_types(ps, i) i = verify_tk(ps, i, "return") i, node.exp = parse_expression(ps, i) @@ -3630,12 +3634,22 @@ do return i, node end - local function parse_where_clause(ps, i) + local function parse_where_clause(ps, i, typeargs) local node = new_node(ps, i, "macroexp") + + local selftype = new_nominal(ps, i, "@self") + if typeargs then + selftype.typevals = {} + for a, t in ipairs(typeargs) do + selftype.typevals[a] = a_nominal(node, { t.typearg }) + end + end + + node.is_method = true node.args = new_node(ps, i, "argument_list") node.args[1] = new_node(ps, i, "argument") node.args[1].tk = "self" - node.args[1].argtype = new_nominal(ps, i, "@self") + node.args[1].argtype = selftype node.min_arity = 1 node.rets = new_tuple(ps, i) node.rets.tuple[1] = new_type(ps, i, "boolean") @@ -3674,6 +3688,17 @@ do return i, t end + local function clone_typeargs(ps, i, typeargs) + local copy = {} + for a, ta in ipairs(typeargs) do + local cta = new_type(ps, i, "typearg") + cta.typearg = ta.typearg + copy[a] = cta + end + return copy + end + + parse_record_body = function(ps, i, def, node) local istart = i - 1 def.fields = {} @@ -3715,9 +3740,12 @@ do local wstart = i i = i + 1 local where_macroexp - i, where_macroexp = parse_where_clause(ps, i) + i, where_macroexp = parse_where_clause(ps, i, def.typeargs) local typ = new_type(ps, wstart, "function") + if def.typeargs then + typ.typeargs = clone_typeargs(ps, i, def.typeargs) + end typ.is_method = true typ.min_arity = 1 typ.args = new_tuple(ps, wstart, { @@ -3978,10 +4006,17 @@ do i = i + 2 local asgn = new_node(ps, i, node_name) - i, asgn.var = parse_variable_name(ps, i) - if not asgn.var then + local var + + i, var = verify_kind(ps, i, "identifier") + if not var then return fail(ps, i, "expected a type name") end + local typeargs + if ps.tokens[i].tk == "<" then + i, typeargs = parse_anglebracket_list(ps, i, parse_typearg) + end + asgn.var = var if node_name == "global_type" and ps.tokens[i].tk ~= "=" then return i, asgn @@ -4005,6 +4040,10 @@ do local nt = asgn.value.newtype if nt.typename == "typedecl" then + if typeargs then + nt.typeargs = typeargs + end + local def = nt.def if def.fields or def.typename == "enum" then if not def.declname then @@ -4399,6 +4438,11 @@ local function recurse_type(s, ast, visit) elseif ast.typename == "typealias" then table.insert(xs, recurse_type(s, ast.alias_to, visit)) elseif ast.typename == "typedecl" then + if ast.typeargs then + for _, child in ipairs(ast.typeargs) do + table.insert(xs, recurse_type(s, child, visit)) + end + end table.insert(xs, recurse_type(s, ast.def, visit)) end @@ -7146,8 +7190,17 @@ do local errs local seen = {} local resolved = {} + local resolve - local function resolve(t, all_same) + local function copy_typeargs(t, same) + local copy = {} + for i, tf in ipairs(t) do + copy[i], same = resolve(tf, same) + end + return copy, same + end + + resolve = function(t, all_same) local same = true @@ -7207,6 +7260,11 @@ do end elseif t.typename == "typedecl" then assert(copy.typename == "typedecl") + + if t.typeargs then + copy.typeargs, same = copy_typeargs(t.typeargs, same) + end + copy.def, same = resolve(t.def, same) elseif t.typename == "typealias" then assert(copy.typename == "typealias") @@ -7224,12 +7282,10 @@ do assert(copy.typename == "function") if t.typeargs then - copy.typeargs = {} - for i, tf in ipairs(t.typeargs) do - copy.typeargs[i], same = resolve(tf, same) - end + copy.typeargs, same = copy_typeargs(t.typeargs, same) end + copy.macroexp = t.macroexp copy.min_arity = t.min_arity copy.is_method = t.is_method copy.args, same = resolve(t.args, same) @@ -7239,10 +7295,7 @@ do copy.declname = t.declname if t.typeargs then - copy.typeargs = {} - for i, tf in ipairs(t.typeargs) do - copy.typeargs[i], same = resolve(tf, same) - end + copy.typeargs, same = copy_typeargs(t.typeargs, same) end @@ -7679,6 +7732,7 @@ do assert(not (def.typename == "nominal")) t.found = found + return nil, found end @@ -9066,6 +9120,7 @@ a.types[i], b.types[i]), } end local ret, f = check_call(self, node, e2, func, args, expected_rets, is_typedecl_funcall, argdelta or 0) + ret = self:resolve_typevars_at(node, ret) self:end_scope() @@ -12274,10 +12329,10 @@ self:expand_type(node, values, elements) }) } visit_type.cbs["interface"] = visit_type.cbs["record"] + visit_type.cbs["typedecl"] = visit_type.cbs["function"] visit_type.cbs["string"] = default_type_visitor visit_type.cbs["tupletable"] = default_type_visitor - visit_type.cbs["typedecl"] = default_type_visitor visit_type.cbs["typealias"] = default_type_visitor visit_type.cbs["array"] = default_type_visitor visit_type.cbs["map"] = default_type_visitor diff --git a/tl.tl b/tl.tl index 33d0c70c9..394667945 100644 --- a/tl.tl +++ b/tl.tl @@ -1614,6 +1614,7 @@ local record TypeDeclType is Type where self.typename == "typedecl" + typeargs: {TypeArgType} def: Type closed: boolean end @@ -3615,13 +3616,16 @@ local metamethod_names: {string:boolean} = { } local function parse_macroexp(ps: ParseState, istart: integer, iargs: integer): integer, Node --- TODO: generic macroexp --- if ps.tokens[i].tk == "<" then --- i, node.typeargs = parse_anglebracket_list(ps, i, parse_typearg) --- end local node = new_node(ps, istart, "macroexp") + local i: integer - i, node.args, node.min_arity = parse_argument_list(ps, iargs) + if ps.tokens[istart + 1].tk == "<" then + i, node.typeargs = parse_anglebracket_list(ps, istart + 1, parse_typearg) + else + i = iargs + end + + i, node.args, node.min_arity = parse_argument_list(ps, i) i, node.rets = parse_return_types(ps, i) i = verify_tk(ps, i, "return") i, node.exp = parse_expression(ps, i) @@ -3630,12 +3634,22 @@ local function parse_macroexp(ps: ParseState, istart: integer, iargs: integer): return i, node end -local function parse_where_clause(ps: ParseState, i: integer): integer, Node +local function parse_where_clause(ps: ParseState, i: integer, typeargs: {TypeArgType}): integer, Node local node = new_node(ps, i, "macroexp") + + local selftype = new_nominal(ps, i, "@self") + if typeargs then + selftype.typevals = {} + for a, t in ipairs(typeargs) do + selftype.typevals[a] = a_nominal(node, { t.typearg }) + end + end + + node.is_method = true node.args = new_node(ps, i, "argument_list") node.args[1] = new_node(ps, i, "argument") node.args[1].tk = "self" - node.args[1].argtype = new_nominal(ps, i, "@self") + node.args[1].argtype = selftype node.min_arity = 1 node.rets = new_tuple(ps, i) node.rets.tuple[1] = new_type(ps, i, "boolean") @@ -3674,6 +3688,17 @@ local function parse_array_interface_type(ps: ParseState, i: integer, def: Recor return i, t end +local function clone_typeargs(ps: ParseState, i: integer, typeargs: {TypeArgType}): {TypeArgType} + local copy = {} + for a, ta in ipairs(typeargs) do + local cta = new_type(ps, i, "typearg") as TypeArgType + cta.typearg = ta.typearg + copy[a] = cta + end + return copy +end + + parse_record_body = function(ps: ParseState, i: integer, def: RecordLikeType, node: Node): integer, Node local istart = i - 1 def.fields = {} @@ -3715,9 +3740,12 @@ parse_record_body = function(ps: ParseState, i: integer, def: RecordLikeType, no local wstart = i i = i + 1 local where_macroexp: Node - i, where_macroexp = parse_where_clause(ps, i) + i, where_macroexp = parse_where_clause(ps, i, def.typeargs) local typ = new_type(ps, wstart, "function") as FunctionType + if def.typeargs then + typ.typeargs = clone_typeargs(ps, i, def.typeargs) + end typ.is_method = true typ.min_arity = 1 typ.args = new_tuple(ps, wstart, { @@ -3978,10 +4006,17 @@ local function parse_type_declaration(ps: ParseState, i: integer, node_name: Nod i = i + 2 -- skip `local` or `global`, and `type` local asgn: Node = new_node(ps, i, node_name) - i, asgn.var = parse_variable_name(ps, i) - if not asgn.var then + local var: Node + + i, var = verify_kind(ps, i, "identifier") + if not var then return fail(ps, i, "expected a type name") end + local typeargs: {TypeArgType} + if ps.tokens[i].tk == "<" then + i, typeargs = parse_anglebracket_list(ps, i, parse_typearg) + end + asgn.var = var if node_name == "global_type" and ps.tokens[i].tk ~= "=" then return i, asgn @@ -4005,6 +4040,10 @@ local function parse_type_declaration(ps: ParseState, i: integer, node_name: Nod local nt = asgn.value.newtype if nt is TypeDeclType then + if typeargs then + nt.typeargs = typeargs + end + local def = nt.def if def is RecordLikeType or def is EnumType then if not def.declname then @@ -4399,6 +4438,11 @@ local function recurse_type(s: S, ast: Type, visit: Visitor(t: T, all_same: boolean): T, boolean + + local function copy_typeargs(t: {TypeArgType}, same: boolean): {TypeArgType}, boolean + local copy = {} + for i, tf in ipairs(t) do + copy[i], same = resolve(tf, same) as (TypeArgType, boolean) + end + return copy, same + end - local function resolve(t: T, all_same: boolean): T, boolean + resolve = function(t: T, all_same: boolean): T, boolean local same = true -- avoid copies of types that do not contain type variables @@ -7207,6 +7260,11 @@ do end elseif t is TypeDeclType then assert(copy is TypeDeclType) + + if t.typeargs then + copy.typeargs, same = copy_typeargs(t.typeargs, same) + end + copy.def, same = resolve(t.def, same) elseif t is TypeAliasType then assert(copy is TypeAliasType) @@ -7224,12 +7282,10 @@ do assert(copy is FunctionType) if t.typeargs then - copy.typeargs = {} - for i, tf in ipairs(t.typeargs) do - copy.typeargs[i], same = resolve(tf, same) as (TypeArgType, boolean) - end + copy.typeargs, same = copy_typeargs(t.typeargs, same) end + copy.macroexp = t.macroexp copy.min_arity = t.min_arity copy.is_method = t.is_method copy.args, same = resolve(t.args, same) as (TupleType, boolean) @@ -7239,10 +7295,7 @@ do copy.declname = t.declname if t.typeargs then - copy.typeargs = {} - for i, tf in ipairs(t.typeargs) do - copy.typeargs[i], same = resolve(tf, same) as (TypeArgType, boolean) - end + copy.typeargs, same = copy_typeargs(t.typeargs, same) end -- checking array interface @@ -7679,6 +7732,7 @@ do assert(not def is NominalType) t.found = found + return nil, found end @@ -9066,6 +9120,7 @@ do end local ret, f = check_call(self, node, e2, func, args, expected_rets, is_typedecl_funcall, argdelta or 0) + ret = self:resolve_typevars_at(node, ret) self:end_scope() @@ -12274,10 +12329,10 @@ do } visit_type.cbs["interface"] = visit_type.cbs["record"] + visit_type.cbs["typedecl"] = visit_type.cbs["function"] visit_type.cbs["string"] = default_type_visitor visit_type.cbs["tupletable"] = default_type_visitor - visit_type.cbs["typedecl"] = default_type_visitor visit_type.cbs["typealias"] = default_type_visitor visit_type.cbs["array"] = default_type_visitor visit_type.cbs["map"] = default_type_visitor From 3fb26dc0dffd01d2f6c892048ca1c6fed963edfd Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Sat, 1 Jun 2024 20:11:40 -0300 Subject: [PATCH 122/224] interfaces: fix collect_interfaces --- spec/subtyping/interface_spec.lua | 18 ++++++++++++++++++ tl.lua | 16 +++++++--------- tl.tl | 16 +++++++--------- 3 files changed, 32 insertions(+), 18 deletions(-) create mode 100644 spec/subtyping/interface_spec.lua diff --git a/spec/subtyping/interface_spec.lua b/spec/subtyping/interface_spec.lua new file mode 100644 index 000000000..727a17884 --- /dev/null +++ b/spec/subtyping/interface_spec.lua @@ -0,0 +1,18 @@ +local util = require("spec.util") + +describe("subtyping of interfaces:", function() + it("record inherits interface array definition", util.check([[ + local interface MyInterface + is {MyInterface} + x: integer + end + + local record MyRecord + is MyInterface + end + + local r: MyRecord = {} + print(#r) + ]])) +end) + diff --git a/tl.lua b/tl.lua index 142f39176..68d89d010 100644 --- a/tl.lua +++ b/tl.lua @@ -12087,16 +12087,14 @@ self:expand_type(node, values, elements) }) for _, iface in ipairs(t.interface_list) do if iface.typename == "nominal" then local ri = self:resolve_nominal(iface) - if not (ri.typename == "invalid") then - if ri.typename == "interface" then - if not ri.interfaces_expanded and not seen[ri] then - seen[ri] = true - collect_interfaces(self, list, ri, seen) - end - table.insert(list, iface) - else - self.errs:add(iface, "attempted to use %s as interface, but its type is %s", iface, ri) + if ri.typename == "interface" then + if ri.interfaces_expanded and not seen[ri] then + seen[ri] = true + collect_interfaces(self, list, ri, seen) end + table.insert(list, iface) + else + self.errs:add(iface, "attempted to use %s as interface, but its type is %s", iface, ri) end else if not seen[iface] then diff --git a/tl.tl b/tl.tl index 394667945..64364a3ac 100644 --- a/tl.tl +++ b/tl.tl @@ -12087,16 +12087,14 @@ do for _, iface in ipairs(t.interface_list) do if iface is NominalType then local ri = self:resolve_nominal(iface) - if not (ri.typename == "invalid") then - if ri is InterfaceType then - if not ri.interfaces_expanded and not seen[ri] then - seen[ri] = true - collect_interfaces(self, list, ri, seen) - end - table.insert(list, iface) - else - self.errs:add(iface, "attempted to use %s as interface, but its type is %s", iface, ri) + if ri is InterfaceType then + if ri.interfaces_expanded and not seen[ri] then + seen[ri] = true + collect_interfaces(self, list, ri, seen) end + table.insert(list, iface) + else + self.errs:add(iface, "attempted to use %s as interface, but its type is %s", iface, ri) end else if not seen[iface] then From 5bb371604786555b563e05eaf96ef7b7d7a86d31 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Sat, 1 Jun 2024 20:32:02 -0300 Subject: [PATCH 123/224] total: do not consider a record function to be a missing field Fixes #747. --- spec/declaration/local_spec.lua | 12 ++++++++++++ tl.lua | 4 +++- tl.tl | 4 +++- 3 files changed, 18 insertions(+), 2 deletions(-) diff --git a/spec/declaration/local_spec.lua b/spec/declaration/local_spec.lua index ea3819081..a1ae1b45a 100644 --- a/spec/declaration/local_spec.lua +++ b/spec/declaration/local_spec.lua @@ -412,6 +412,18 @@ describe("local", function() likes = {name='orange'} } ]])) + + it("does not consider a record function to be a missing field", util.check([[ + local record A + v: number + end + + function A:echo() + print('A:', self.v) + end + + local b : A = { v = 10 } + ]])) end) describe("", function() diff --git a/tl.lua b/tl.lua index 68d89d010..78ccf4f65 100644 --- a/tl.lua +++ b/tl.lua @@ -1892,6 +1892,7 @@ end + local TruthyFact = {} @@ -10559,7 +10560,7 @@ self:expand_type(node, values, elements) }) local missing for _, key in ipairs(t.field_order) do local ftype = t.fields[key] - if not (ftype.typename == "typedecl" or ftype.typename == "typealias") then + if not (ftype.typename == "typedecl" or ftype.typename == "typealias" or (ftype.typename == "function" and ftype.is_record_function)) then is_total, missing = total_check_key(key, seen_keys, is_total, missing) end end @@ -11443,6 +11444,7 @@ self:expand_type(node, values, elements) }) typeargs = node.typeargs, args = args, rets = self.get_rets(rets), + is_record_function = true, })) local open_t, open_v, owner_name = self:find_record_to_extend(node.fn_owner) diff --git a/tl.tl b/tl.tl index 64364a3ac..068d2a2ca 100644 --- a/tl.tl +++ b/tl.tl @@ -1784,6 +1784,7 @@ local record FunctionType where self.typename == "function" is_method: boolean + is_record_function: boolean min_arity: integer args: TupleType rets: TupleType @@ -10559,7 +10560,7 @@ do local missing: {string} for _, key in ipairs(t.field_order) do local ftype = t.fields[key] - if not (ftype is TypeDeclType or ftype is TypeAliasType) then + if not (ftype is TypeDeclType or ftype is TypeAliasType or (ftype is FunctionType and ftype.is_record_function)) then is_total, missing = total_check_key(key, seen_keys, is_total, missing) end end @@ -11443,6 +11444,7 @@ do typeargs = node.typeargs, args = args, rets = self.get_rets(rets), + is_record_function = true, })) local open_t, open_v, owner_name = self:find_record_to_extend(node.fn_owner) From 823a893bbc7aee706d82723ec2a5001a7f29fb9d Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Wed, 12 Jun 2024 09:49:22 -0300 Subject: [PATCH 124/224] total: ignore metamethods when checking for record totality Fixes #749. --- tl.lua | 4 ---- tl.tl | 4 ---- 2 files changed, 8 deletions(-) diff --git a/tl.lua b/tl.lua index 78ccf4f65..28aa5e77e 100644 --- a/tl.lua +++ b/tl.lua @@ -10552,10 +10552,6 @@ self:expand_type(node, values, elements) }) end local function total_record_check(t, seen_keys) - if t.meta_field_order then - return false - end - local is_total = true local missing for _, key in ipairs(t.field_order) do diff --git a/tl.tl b/tl.tl index 068d2a2ca..ab975353e 100644 --- a/tl.tl +++ b/tl.tl @@ -10552,10 +10552,6 @@ do end local function total_record_check(t: RecordLikeType, seen_keys: {CheckableKey:Where}): boolean, {string} - if t.meta_field_order then - return false - end - local is_total = true local missing: {string} for _, key in ipairs(t.field_order) do From da75aa2eddf48553134f87e97ae731a6ea5493ef Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Fri, 14 Jun 2024 15:44:15 -0300 Subject: [PATCH 125/224] Accept __lt and __le for > and >= operators. As in Lua, "A comparison a > b is translated to b < a and a >= b is translated to b <= a." --- spec/metamethods/le_spec.lua | 205 +++++++++++++++++++++++++++++++++++ spec/metamethods/lt_spec.lua | 205 +++++++++++++++++++++++++++++++++++ tl.lua | 12 ++ tl.tl | 12 ++ 4 files changed, 434 insertions(+) create mode 100644 spec/metamethods/le_spec.lua create mode 100644 spec/metamethods/lt_spec.lua diff --git a/spec/metamethods/le_spec.lua b/spec/metamethods/le_spec.lua new file mode 100644 index 000000000..c0eaaa36c --- /dev/null +++ b/spec/metamethods/le_spec.lua @@ -0,0 +1,205 @@ +local util = require("spec.util") + +describe("binary metamethod __le using <=", function() + it("can be set on a record", util.check([[ + local type Rec = record + x: number + metamethod __call: function(Rec, string, number): string + metamethod __le: function(Rec, Rec): boolean + end + + local rec_mt: metatable + rec_mt = { + __call = function(self: Rec, s: string, n: number): string + return tostring(self.x + n) .. s + end, + __le = function(a: Rec, b: Rec): boolean + return a.x <= b.x + end + } + + local r = setmetatable({ x = 10 } as Rec, rec_mt) + local s = setmetatable({ x = 20 } as Rec, rec_mt) + + if r <= s then + print("yes") + end + ]])) + + it("can be used on a record prototype", util.check([[ + local record A + value: number + metamethod __call: function(A, number): A + metamethod __le: function(A, A): boolean + end + local A_mt: metatable + A_mt = { + __call = function(a: A, v: number): A + return setmetatable({value = v} as A, A_mt) + end, + __le = function(a: A, b: A): boolean + return a.value <= b.value + end, + } + + A.value = 10 + if A <= A then + print("wat!?") + end + ]])) + + it("can be used via the second argument", util.check([[ + local type Rec = record + x: number + metamethod __le: function(number, Rec): Rec + end + + local rec_mt: metatable + rec_mt = { + __le = function(a: number, b: Rec): boolean + return a <= b.x + end + } + + local s = setmetatable({ y = 20 } as Rec, rec_mt) + + if 10 <= s then + print("yes") + end + ]])) + + it("preserves nominal type checking when resolving metamethods for operators", util.check_type_error([[ + local type Temperature = record + n: number + metamethod __le: function(t1: Temperature, t2: Temperature): boolean + end + + local type Date = record + n: number + metamethod __le: function(t1: Date, t2: Date): boolean + end + + local temp2: Temperature = { n = 45 } + local birthday2 : Date = { n = 34 } + + setmetatable(temp2, { + __le = function(t1: Temperature, t2: Temperature): boolean + return t1.n <= t2.n + end, + }) + + setmetatable(birthday2, { + __le = function(t1: Date, t2: Date): boolean + return t1.n <= t2.n + end, + }) + + if temp2 <= birthday2 then + print("wat") + end + ]], { + { y = 26, msg = "Date is not a Temperature" }, + })) +end) + +describe("binary metamethod __le using >=", function() + it("can be set on a record", util.check([[ + local type Rec = record + x: number + metamethod __call: function(Rec, string, number): string + metamethod __le: function(Rec, Rec): boolean + end + + local rec_mt: metatable + rec_mt = { + __call = function(self: Rec, s: string, n: number): string + return tostring(self.x + n) .. s + end, + __le = function(a: Rec, b: Rec): boolean + return a.x <= b.x + end + } + + local r = setmetatable({ x = 10 } as Rec, rec_mt) + local s = setmetatable({ x = 20 } as Rec, rec_mt) + + if s >= r then + print("yes") + end + ]])) + + it("can be used on a record prototype", util.check([[ + local record A + value: number + metamethod __call: function(A, number): A + metamethod __le: function(A, A): boolean + end + local A_mt: metatable + A_mt = { + __call = function(a: A, v: number): A + return setmetatable({value = v} as A, A_mt) + end, + __le = function(a: A, b: A): boolean + return a.value <= b.value + end, + } + + A.value = 10 + if A >= A then + print("wat!?") + end + ]])) + + it("can be used via the second argument", util.check([[ + local type Rec = record + x: number + metamethod __le: function(number, Rec): Rec + end + + local rec_mt: metatable + rec_mt = { + __le = function(a: number, b: Rec): boolean + return a <= b.x + end + } + + local s = setmetatable({ y = 20 } as Rec, rec_mt) + + if s >= 10 then + print("yes") + end + ]])) + + it("preserves nominal type checking when resolving metamethods for operators", util.check_type_error([[ + local type Temperature = record + n: number + metamethod __le: function(t1: Temperature, t2: Temperature): boolean + end + + local type Date = record + n: number + metamethod __le: function(t1: Date, t2: Date): boolean + end + + local temp2: Temperature = { n = 45 } + local birthday2 : Date = { n = 34 } + + setmetatable(temp2, { + __le = function(t1: Temperature, t2: Temperature): boolean + return t1.n <= t2.n + end, + }) + + setmetatable(birthday2, { + __le = function(t1: Date, t2: Date): boolean + return t1.n <= t2.n + end, + }) + + if birthday2 >= temp2 then + print("wat") + end + ]], { + { y = 26, msg = "Date is not a Temperature" }, + })) +end) diff --git a/spec/metamethods/lt_spec.lua b/spec/metamethods/lt_spec.lua new file mode 100644 index 000000000..56c87bbfb --- /dev/null +++ b/spec/metamethods/lt_spec.lua @@ -0,0 +1,205 @@ +local util = require("spec.util") + +describe("binary metamethod __lt using <", function() + it("can be set on a record", util.check([[ + local type Rec = record + x: number + metamethod __call: function(Rec, string, number): string + metamethod __lt: function(Rec, Rec): boolean + end + + local rec_mt: metatable + rec_mt = { + __call = function(self: Rec, s: string, n: number): string + return tostring(self.x + n) .. s + end, + __lt = function(a: Rec, b: Rec): boolean + return a.x < b.x + end + } + + local r = setmetatable({ x = 10 } as Rec, rec_mt) + local s = setmetatable({ x = 20 } as Rec, rec_mt) + + if r < s then + print("yes") + end + ]])) + + it("can be used on a record prototype", util.check([[ + local record A + value: number + metamethod __call: function(A, number): A + metamethod __lt: function(A, A): boolean + end + local A_mt: metatable + A_mt = { + __call = function(a: A, v: number): A + return setmetatable({value = v} as A, A_mt) + end, + __lt = function(a: A, b: A): boolean + return a.value < b.value + end, + } + + A.value = 10 + if A < A then + print("wat!?") + end + ]])) + + it("can be used via the second argument", util.check([[ + local type Rec = record + x: number + metamethod __lt: function(number, Rec): Rec + end + + local rec_mt: metatable + rec_mt = { + __lt = function(a: number, b: Rec): boolean + return a < b.x + end + } + + local s = setmetatable({ y = 20 } as Rec, rec_mt) + + if 10 < s then + print("yes") + end + ]])) + + it("preserves nominal type checking when resolving metamethods for operators", util.check_type_error([[ + local type Temperature = record + n: number + metamethod __lt: function(t1: Temperature, t2: Temperature): boolean + end + + local type Date = record + n: number + metamethod __lt: function(t1: Date, t2: Date): boolean + end + + local temp2: Temperature = { n = 45 } + local birthday2 : Date = { n = 34 } + + setmetatable(temp2, { + __lt = function(t1: Temperature, t2: Temperature): boolean + return t1.n < t2.n + end, + }) + + setmetatable(birthday2, { + __lt = function(t1: Date, t2: Date): boolean + return t1.n < t2.n + end, + }) + + if temp2 < birthday2 then + print("wat") + end + ]], { + { y = 26, msg = "Date is not a Temperature" }, + })) +end) + +describe("binary metamethod __lt using >", function() + it("can be set on a record", util.check([[ + local type Rec = record + x: number + metamethod __call: function(Rec, string, number): string + metamethod __lt: function(Rec, Rec): boolean + end + + local rec_mt: metatable + rec_mt = { + __call = function(self: Rec, s: string, n: number): string + return tostring(self.x + n) .. s + end, + __lt = function(a: Rec, b: Rec): boolean + return a.x < b.x + end + } + + local r = setmetatable({ x = 10 } as Rec, rec_mt) + local s = setmetatable({ x = 20 } as Rec, rec_mt) + + if s > r then + print("yes") + end + ]])) + + it("can be used on a record prototype", util.check([[ + local record A + value: number + metamethod __call: function(A, number): A + metamethod __lt: function(A, A): boolean + end + local A_mt: metatable + A_mt = { + __call = function(a: A, v: number): A + return setmetatable({value = v} as A, A_mt) + end, + __lt = function(a: A, b: A): boolean + return a.value < b.value + end, + } + + A.value = 10 + if A > A then + print("wat!?") + end + ]])) + + it("can be used via the second argument", util.check([[ + local type Rec = record + x: number + metamethod __lt: function(number, Rec): Rec + end + + local rec_mt: metatable + rec_mt = { + __lt = function(a: number, b: Rec): boolean + return a < b.x + end + } + + local s = setmetatable({ y = 20 } as Rec, rec_mt) + + if s > 10 then + print("yes") + end + ]])) + + it("preserves nominal type checking when resolving metamethods for operators", util.check_type_error([[ + local type Temperature = record + n: number + metamethod __lt: function(t1: Temperature, t2: Temperature): boolean + end + + local type Date = record + n: number + metamethod __lt: function(t1: Date, t2: Date): boolean + end + + local temp2: Temperature = { n = 45 } + local birthday2 : Date = { n = 34 } + + setmetatable(temp2, { + __lt = function(t1: Temperature, t2: Temperature): boolean + return t1.n < t2.n + end, + }) + + setmetatable(birthday2, { + __lt = function(t1: Date, t2: Date): boolean + return t1.n < t2.n + end, + }) + + if birthday2 > temp2 then + print("wat") + end + ]], { + { y = 26, msg = "Date is not a Temperature" }, + })) +end) diff --git a/tl.lua b/tl.lua index 28aa5e77e..105168f94 100644 --- a/tl.lua +++ b/tl.lua @@ -6435,6 +6435,11 @@ local binop_to_metamethod = { ["is"] = "__is", } +local flip_binop_to_metamethod = { + [">"] = "__lt", + [">="] = "__le", +} + local function is_unknown(t) return t.typename == "unknown" or t.typename == "unresolved_emptytable_value" @@ -11889,6 +11894,13 @@ self:expand_type(node, values, elements) }) local meta_on_operator if not t then local mt_name = binop_to_metamethod[node.op.op] + if not mt_name then + mt_name = flip_binop_to_metamethod[node.op.op] + if mt_name then + ra, rb = rb, ra + ua, ub = ub, ua + end + end if mt_name then t, meta_on_operator = self:check_metamethod(node, mt_name, ra, rb, ua, ub) end diff --git a/tl.tl b/tl.tl index ab975353e..ac35a49f2 100644 --- a/tl.tl +++ b/tl.tl @@ -6435,6 +6435,11 @@ local binop_to_metamethod: {string:string} = { ["is"] = "__is", } +local flip_binop_to_metamethod: {string:string} = { + [">"] = "__lt", + [">="] = "__le", +} + local function is_unknown(t: Type): boolean return t.typename == "unknown" or t.typename == "unresolved_emptytable_value" @@ -11889,6 +11894,13 @@ do local meta_on_operator: integer if not t then local mt_name = binop_to_metamethod[node.op.op] + if not mt_name then + mt_name = flip_binop_to_metamethod[node.op.op] + if mt_name then + ra, rb = rb, ra + ua, ub = ub, ua + end + end if mt_name then t, meta_on_operator = self:check_metamethod(node, mt_name, ra, rb, ua, ub) end From 73f104fd2c7f0f3db6e969a476e584d179daee8d Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Fri, 19 Jul 2024 13:46:54 -0300 Subject: [PATCH 126/224] fix: can use nested records in global 'table' record --- spec/declaration/local_spec.lua | 5 +++++ tl.lua | 2 +- tl.tl | 2 +- 3 files changed, 7 insertions(+), 2 deletions(-) diff --git a/spec/declaration/local_spec.lua b/spec/declaration/local_spec.lua index a1ae1b45a..6b93b8b1b 100644 --- a/spec/declaration/local_spec.lua +++ b/spec/declaration/local_spec.lua @@ -98,6 +98,11 @@ describe("local", function() { msg = "b" }, })) + it("local type can declare a type alias for table", util.check([[ + local type PackTable = table.PackTable + local args: table.PackTable = table.pack(1, 2, 3) + ]])) + it("local type can declare a nominal type alias (regression test for #238)", function () util.mock_io(finally, { ["module.tl"] = [[ diff --git a/tl.lua b/tl.lua index 105168f94..cc321c809 100644 --- a/tl.lua +++ b/tl.lua @@ -2591,7 +2591,7 @@ do local st = simple_types[tk] if st then return i + 1, new_type(ps, i, tk) - elseif tk == "table" then + elseif tk == "table" and ps.tokens[i + 1].tk ~= "." then local typ = new_type(ps, i, "map") typ.keys = new_type(ps, i, "any") typ.values = new_type(ps, i, "any") diff --git a/tl.tl b/tl.tl index ac35a49f2..5b941e29a 100644 --- a/tl.tl +++ b/tl.tl @@ -2591,7 +2591,7 @@ local function parse_simple_type_or_nominal(ps: ParseState, i: integer): integer local st = simple_types[tk as TypeName] if st then return i + 1, new_type(ps, i, tk as TypeName) - elseif tk == "table" then + elseif tk == "table" and ps.tokens[i + 1].tk ~= "." then local typ = new_type(ps, i, "map") as MapType typ.keys = new_type(ps, i, "any") typ.values = new_type(ps, i, "any") From 90f2a1ed2dfe55f0089086f555538672370878ce Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Fri, 19 Jul 2024 14:44:27 -0300 Subject: [PATCH 127/224] fix: do not crash when comparing type defined with 'function' Using the plain 'function' type needs to define 'min_arity'. --- spec/assignment/to_function_spec.lua | 12 ++++++++++++ tl.lua | 2 ++ tl.tl | 2 ++ 3 files changed, 16 insertions(+) create mode 100644 spec/assignment/to_function_spec.lua diff --git a/spec/assignment/to_function_spec.lua b/spec/assignment/to_function_spec.lua new file mode 100644 index 000000000..57c35d5af --- /dev/null +++ b/spec/assignment/to_function_spec.lua @@ -0,0 +1,12 @@ +local util = require("spec.util") + +describe("assignment to function", function() + it("does not crash when using plain function definitions", util.check([[ + local my_load: function(string, ? string, ? string, ? table): (function, string) + + local function run_file() + local chunk: function(any):(any) + chunk = my_load("") + end + ]])) +end) diff --git a/tl.lua b/tl.lua index cc321c809..609f69666 100644 --- a/tl.lua +++ b/tl.lua @@ -2582,6 +2582,8 @@ do else typ.args = new_tuple(ps, i, { new_type(ps, i, "any") }, true) typ.rets = new_tuple(ps, i, { new_type(ps, i, "any") }, true) + typ.is_method = false + typ.min_arity = 0 end return i, typ end diff --git a/tl.tl b/tl.tl index 5b941e29a..5f20ada41 100644 --- a/tl.tl +++ b/tl.tl @@ -2582,6 +2582,8 @@ local function parse_function_type(ps: ParseState, i: integer): integer, Functio else typ.args = new_tuple(ps, i, { new_type(ps, i, "any") }, true) typ.rets = new_tuple(ps, i, { new_type(ps, i, "any") }, true) + typ.is_method = false + typ.min_arity = 0 end return i, typ end From 8d6b58920af25883eab35d30baec6b17fbb701e3 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Fri, 19 Jul 2024 15:33:19 -0300 Subject: [PATCH 128/224] accept 'or' as a boolean in if/while/repeat expressions --- spec/statement/if_spec.lua | 9 +++++++++ spec/statement/repeat_spec.lua | 23 +++++++++++++++++++++++ spec/statement/while_spec.lua | 26 ++++++++++++++++++++++++++ tl.lua | 9 +++++++++ tl.tl | 9 +++++++++ 5 files changed, 76 insertions(+) create mode 100644 spec/statement/while_spec.lua diff --git a/spec/statement/if_spec.lua b/spec/statement/if_spec.lua index 99f009e12..7930942d1 100644 --- a/spec/statement/if_spec.lua +++ b/spec/statement/if_spec.lua @@ -16,6 +16,15 @@ describe("if", function() end ]])) + it("if expression propagates a boolean context", util.check([[ + local n = 123 + local s = "hello" + if n or s then + local ns: number | string = n or s + print(ns) + end + ]])) + it("accepts boolean expressions", util.check([[ local s = "Hallo, Welt" if string.match(s, "world") or s == "Hallo, Welt" then diff --git a/spec/statement/repeat_spec.lua b/spec/statement/repeat_spec.lua index fcfa43d6a..9bb2160d7 100644 --- a/spec/statement/repeat_spec.lua +++ b/spec/statement/repeat_spec.lua @@ -1,6 +1,29 @@ local util = require("spec.util") describe("repeat", function() + it("accepts a boolean", util.check([[ + local b = true + repeat + print(b) + until b + ]])) + + it("accepts a non-boolean", util.check([[ + local n = 123 + repeat + print(n) + until n + ]])) + + it("until expression propagates a boolean context", util.check([[ + local n = 123 + local s = "hello" + repeat + local ns: number | string = n or s + print(ns) + until n or s + ]])) + it("only closes scope after until", util.check([[ repeat local type R = record diff --git a/spec/statement/while_spec.lua b/spec/statement/while_spec.lua new file mode 100644 index 000000000..07023f603 --- /dev/null +++ b/spec/statement/while_spec.lua @@ -0,0 +1,26 @@ +local util = require("spec.util") + +describe("while", function() + it("accepts a boolean", util.check([[ + local b = true + while b do + print(b) + end + ]])) + + it("accepts a non-boolean", util.check([[ + local n = 123 + while n do + print(n) + end + ]])) + + it("while expression propagates a boolean context", util.check([[ + local n = 123 + local s = "hello" + while n or s do + local ns: number | string = n or s + print(ns) + end + ]])) +end) diff --git a/tl.lua b/tl.lua index 609f69666..92e2bf388 100644 --- a/tl.lua +++ b/tl.lua @@ -10844,6 +10844,9 @@ self:expand_type(node, values, elements) }) if node.if_block_n > 1 then self:infer_negation_of_if_blocks(node, node.if_parent, node.if_block_n - 1) end + if node.exp then + node.exp.expected = a_type(node, "boolean", {}) + end end, before_statements = function(self, node) if node.exp then @@ -10864,6 +10867,7 @@ self:expand_type(node, values, elements) }) before = function(self, node) self:widen_all_unions(node) + node.exp.expected = a_type(node, "boolean", {}) end, before_statements = function(self, node) self:begin_scope(node) @@ -10927,6 +10931,7 @@ self:expand_type(node, values, elements) }) before = function(self, node) self:widen_all_unions(node) + node.exp.expected = a_type(node, "boolean", {}) end, after = end_scope_and_none_type, @@ -11787,6 +11792,10 @@ self:expand_type(node, values, elements) }) end t = drop_constant_value(t) end + + if expected and expected.typename == "boolean" then + t = a_type(node, "boolean", {}) + end end if t then diff --git a/tl.tl b/tl.tl index 5f20ada41..1bc14f12e 100644 --- a/tl.tl +++ b/tl.tl @@ -10844,6 +10844,9 @@ do if node.if_block_n > 1 then self:infer_negation_of_if_blocks(node, node.if_parent, node.if_block_n - 1) end + if node.exp then + node.exp.expected = a_type(node, "boolean", {}) + end end, before_statements = function(self: TypeChecker, node: Node) if node.exp then @@ -10864,6 +10867,7 @@ do before = function(self: TypeChecker, node: Node) -- widen all narrowed variables because we don't calculate a fixpoint yet self:widen_all_unions(node) + node.exp.expected = a_type(node, "boolean", {}) end, before_statements = function(self: TypeChecker, node: Node) self:begin_scope(node) @@ -10927,6 +10931,7 @@ do before = function(self: TypeChecker, node: Node) -- widen all narrowed variables because we don't calculate a fixpoint yet self:widen_all_unions(node) + node.exp.expected = a_type(node, "boolean", {}) end, -- only end scope after checking `until`, `statements` in repeat body has is_repeat == true after = end_scope_and_none_type, @@ -11787,6 +11792,10 @@ do end t = drop_constant_value(t) end + + if expected and expected is BooleanType then + t = a_type(node, "boolean", {}) + end end if t then From e85ac753532a383d7bd64403cec5127ef14b3e1e Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Tue, 23 Jul 2024 11:31:11 -0300 Subject: [PATCH 129/224] we can now reexport nested types Fixes #765. --- spec/declaration/record_spec.lua | 34 +++++++ spec/util.lua | 6 ++ tl.lua | 130 +++++++++++++++++-------- tl.tl | 158 ++++++++++++++++++++----------- 4 files changed, 236 insertions(+), 92 deletions(-) diff --git a/spec/declaration/record_spec.lua b/spec/declaration/record_spec.lua index a188a482c..2af252bcc 100644 --- a/spec/declaration/record_spec.lua +++ b/spec/declaration/record_spec.lua @@ -555,6 +555,40 @@ for i, name in ipairs({"records", "arrayrecords", "interfaces", "arrayinterfaces ]])() end) + it("can reexport types as nested " .. name, function() + util.mock_io(finally, { + ["inner.tl"] = [[ + local record inner + ]]..statement..[[ SubType ]]..array(i, "{integer}")..[[ + item: K + end + end + + return inner + ]], + ["outer.tl"] = [[ + local core = require("inner") + + local record mod + f: string + type SubType = core.SubType + end + + return mod + ]], + }) + util.run_check_type_error([[ + local mod = require("outer") + + print(mod.f) + local v: mod.SubType = { + item = "hello" + } + ]], { + { msg = 'in record field: item: got string "hello", expected integer' } + }) + end) + it("resolves aliasing of nested " .. name .. " (see #400)", util.check([[ local ]]..statement..[[ Foo ]]..array(i, "{Foo}")..[[ ]]..statement..[[ Bar ]]..array(i, "{Bar}")..[[ diff --git a/spec/util.lua b/spec/util.lua index 25cbdae50..bb6b9af4b 100644 --- a/spec/util.lua +++ b/spec/util.lua @@ -436,6 +436,12 @@ local function check(lax, code, unknowns, gen_target) gen_compat = "off" end local result = tl.type_check(ast, "foo.lua", { feat_lax = lax and "on" or "off", gen_target = gen_target, gen_compat = gen_compat }) + + for _, mname in pairs(result.env.loaded_order) do + local mresult = result.env.loaded[mname] + batch:add(assert.same, {}, mresult.syntax_errors or {}, "Code was not expected to have syntax errors") + end + batch:add(assert.same, {}, result.type_errors) if unknowns then diff --git a/tl.lua b/tl.lua index 92e2bf388..6a7203f7e 100644 --- a/tl.lua +++ b/tl.lua @@ -1892,7 +1892,6 @@ end - local TruthyFact = {} @@ -2219,6 +2218,7 @@ do local parse_argument_list local parse_argument_type_list local parse_type + local parse_type_declaration local parse_newtype local parse_interface_name @@ -3776,14 +3776,19 @@ do elseif ps.tokens[i].tk == "type" and ps.tokens[i + 1].tk ~= ":" then i = i + 1 local iv = i - local v - i, v = verify_kind(ps, i, "identifier", "type_identifier") + + local lt + i, lt = parse_type_declaration(ps, i, "local_type") + if not lt then + return fail(ps, i, "expected a type definition") + end + + local v = lt.var if not v then return fail(ps, i, "expected a variable name") end - i = verify_tk(ps, i, "=") - local nt - i, nt = parse_newtype(ps, i) + + local nt = lt.value if not nt or not nt.newtype then return fail(ps, i, "expected a type definition") end @@ -4005,9 +4010,7 @@ do return i, asgn end - local function parse_type_declaration(ps, i, node_name) - i = i + 2 - + parse_type_declaration = function(ps, i, node_name) local asgn = new_node(ps, i, node_name) local var @@ -4053,6 +4056,10 @@ do def.declname = asgn.var.tk end end + elseif nt.typename == "typealias" then + if typeargs then + nt.typeargs = typeargs + end end return i, asgn @@ -4083,7 +4090,7 @@ do end local function skip_type_declaration(ps, i) - return parse_type_declaration(ps, i - 1, "local_type") + return parse_type_declaration(ps, i + 1, "local_type") end local function parse_local_macroexp(ps, i) @@ -4102,7 +4109,7 @@ do if ntk == "function" then return parse_local_function(ps, i) elseif ntk == "type" and ps.tokens[i + 2].kind == "identifier" then - return parse_type_declaration(ps, i, "local_type") + return parse_type_declaration(ps, i + 2, "local_type") elseif ntk == "macroexp" and ps.tokens[i + 2].kind == "identifier" then return parse_local_macroexp(ps, i) elseif parse_type_body_fns[tn] and ps.tokens[i + 2].kind == "identifier" then @@ -4117,7 +4124,7 @@ do if ntk == "function" then return parse_function(ps, i + 1, "global") elseif ntk == "type" and ps.tokens[i + 2].kind == "identifier" then - return parse_type_declaration(ps, i, "global_type") + return parse_type_declaration(ps, i + 2, "global_type") elseif parse_type_body_fns[tn] and ps.tokens[i + 2].kind == "identifier" then return parse_type_constructor(ps, i, "global_type", tn, parse_type_body_fns[tn]) elseif ps.tokens[i + 1].kind == "identifier" then @@ -4439,6 +4446,11 @@ local function recurse_type(s, ast, visit) table.insert(xs, recurse_type(s, ast.vtype, visit)) end elseif ast.typename == "typealias" then + if ast.typeargs then + for _, child in ipairs(ast.typeargs) do + table.insert(xs, recurse_type(s, child, visit)) + end + end table.insert(xs, recurse_type(s, ast.alias_to, visit)) elseif ast.typename == "typedecl" then if ast.typeargs then @@ -7014,7 +7026,7 @@ do fresh_typevar_ctr = fresh_typevar_ctr + 1 local ok ok, t = typevar_resolver(nil, t, fresh_typevar, fresh_typearg) - assert(ok, "Internal Compiler Error: error creating fresh type variables") + assert(ok and t, "Internal Compiler Error: error creating fresh type variables") return t end @@ -7727,6 +7739,10 @@ do found = found.alias_to.found end + if not found then + return self.errs:invalid_at(t, table.concat(t.names, ".") .. " is not a resolved type") + end + if not (found.typename == "typedecl") then return self.errs:invalid_at(t, table.concat(t.names, ".") .. " is not a type") end @@ -7747,7 +7763,7 @@ do local function resolve_decl_into_nominal(self, t, found) local def = found.def local resolved - if def.typename == "record" or def.typename == "function" then + if def.fields or def.typename == "function" then resolved = match_typevals(self, t, def) if not resolved then return self.errs:invalid_at(t, table.concat(t.names, ".") .. " cannot be resolved in scope") @@ -7773,7 +7789,7 @@ do local t = typealias.alias_to local immediate, found = find_nominal_type_decl(self, t) - if immediate then + if type(immediate) == "table" then return immediate end @@ -8148,6 +8164,16 @@ do end + function TypeChecker:forall_are_subtype_of(xs, t) + for _, x in ipairs(xs.types) do + if not self:is_a(x, t) then + return false + end + end + return true + end + + local emptytable_relations = { ["array"] = compare_true, ["map"] = compare_true, @@ -8279,6 +8305,15 @@ do ["*"] = compare_true, }, ["union"] = { + ["nominal"] = function(self, a, b) + + local rb = self:resolve_nominal(b) + if rb.typename == "union" then + return self:is_a(a, rb) + end + + return self:forall_are_subtype_of(a, b) + end, ["union"] = function(self, a, b) local used = {} for _, t in ipairs(a.types) do @@ -8297,14 +8332,7 @@ do end return true end, - ["*"] = function(self, a, b) - for _, t in ipairs(a.types) do - if not self:is_a(t, b) then - return false - end - end - return true - end, + ["*"] = TypeChecker.forall_are_subtype_of, }, ["poly"] = { ["*"] = function(self, a, b) @@ -8321,21 +8349,36 @@ do return true end - local rb = self:resolve_nominal(b) - if rb.typename == "interface" then + local ra = self:resolve_nominal(a) + local rb = self:resolve_nominal(b) + if ra.typename == "union" and rb.typename == "union" then + return self:is_a(ra, rb) + end + if ra.typename == "union" then + return self:is_a(ra, b) + end + if rb.typename == "union" then return self:is_a(a, rb) end - local ra = self:resolve_nominal(a) - if ra.typename == "union" or rb.typename == "union" then - return self:is_a(ra, rb) + if rb.typename == "interface" then + return self:is_a(a, rb) end return ok, errs end, + ["union"] = function(self, a, b) + + local ra = self:resolve_nominal(a) + if ra.typename == "union" then + return self:is_a(ra, b) + end + + return not not self:exists_supertype_in(a, b) + end, ["*"] = TypeChecker.subtype_nominal, }, ["enum"] = { @@ -10534,7 +10577,10 @@ self:expand_type(node, values, elements) }) value.e1.tk == "require" then local t = special_functions["require"](self, value, self:find_var_type("require"), a_type(value.e2, "tuple", { tuple = { a_type(value.e2[1], "string", {}) } }), 0) + + local ty = t.typename == "tuple" and t.tuple[1] or t + ty = (ty.typename == "typealias") and self:resolve_typealias(ty) or ty local td = (ty.typename == "typedecl") and ty or a_type(value, "typedecl", { def = ty }) return td @@ -10676,6 +10722,9 @@ self:expand_type(node, values, elements) }) if node.value then local resolved, aliasing = self:get_typedecl(node.value) local added = self:add_global(node.var, name, resolved) + if resolved.typename == "invalid" then + return + end node.value.newtype = resolved if aliasing then added.aliasing = aliasing @@ -12159,18 +12208,19 @@ self:expand_type(node, values, elements) }) end end + local visit_type_with_typeargs = { + before = function(self, _typ) + self:begin_scope() + end, + after = function(self, typ, _children) + self:end_scope() + return self:ensure_fresh_typeargs(typ) + end, + } + local visit_type visit_type = { cbs = { - ["function"] = { - before = function(self, _typ) - self:begin_scope() - end, - after = function(self, typ, _children) - self:end_scope() - return self:ensure_fresh_typeargs(typ) - end, - }, ["record"] = { before = function(self, typ) self:begin_scope() @@ -12348,11 +12398,13 @@ self:expand_type(node, values, elements) }) } visit_type.cbs["interface"] = visit_type.cbs["record"] - visit_type.cbs["typedecl"] = visit_type.cbs["function"] + + visit_type.cbs["function"] = visit_type_with_typeargs + visit_type.cbs["typedecl"] = visit_type_with_typeargs + visit_type.cbs["typealias"] = visit_type_with_typeargs visit_type.cbs["string"] = default_type_visitor visit_type.cbs["tupletable"] = default_type_visitor - visit_type.cbs["typealias"] = default_type_visitor visit_type.cbs["array"] = default_type_visitor visit_type.cbs["map"] = default_type_visitor visit_type.cbs["enum"] = default_type_visitor diff --git a/tl.tl b/tl.tl index 1bc14f12e..c9f03a003 100644 --- a/tl.tl +++ b/tl.tl @@ -1610,17 +1610,23 @@ local record BooleanType where self.typename == "boolean" end -local record TypeDeclType +local interface HasTypeArgs is Type - where self.typename == "typedecl" + where self.typeargs typeargs: {TypeArgType} +end + +local record TypeDeclType + is Type, HasTypeArgs + where self.typename == "typedecl" + def: Type closed: boolean end local record TypeAliasType - is Type + is Type, HasTypeArgs where self.typename == "typealias" alias_to: NominalType @@ -1648,13 +1654,6 @@ local record Scope narrows: {string:boolean} end -local interface HasTypeArgs - is Type - where self.typeargs - - typeargs: {TypeArgType} -end - local interface HasDeclName declname: string end @@ -2219,6 +2218,7 @@ local parse_statements: function(ParseState, integer, ? boolean): integer, Node local parse_argument_list: function(ParseState, integer): integer, Node, integer local parse_argument_type_list: function(ParseState, integer): integer, TupleType, boolean, integer local parse_type: function(ParseState, integer): integer, Type, integer +local parse_type_declaration: function(ps: ParseState, i: integer, node_name: NodeKind): integer, Node local parse_newtype: function(ps: ParseState, i: integer): integer, Node local parse_interface_name: function(ps: ParseState, i: integer): integer, Type, integer @@ -3776,14 +3776,19 @@ parse_record_body = function(ps: ParseState, i: integer, def: RecordLikeType, no elseif ps.tokens[i].tk == "type" and ps.tokens[i + 1].tk ~= ":" then i = i + 1 local iv = i - local v: Node - i, v = verify_kind(ps, i, "identifier", "type_identifier") + + local lt: Node + i, lt = parse_type_declaration(ps, i, "local_type") -- local_type Node will be discarded + if not lt then + return fail(ps, i, "expected a type definition") + end + + local v = lt.var if not v then return fail(ps, i, "expected a variable name") end - i = verify_tk(ps, i, "=") - local nt: Node - i, nt = parse_newtype(ps, i) + + local nt = lt.value if not nt or not nt.newtype then return fail(ps, i, "expected a type definition") end @@ -4005,9 +4010,7 @@ local function parse_variable_declarations(ps: ParseState, i: integer, node_name return i, asgn end -local function parse_type_declaration(ps: ParseState, i: integer, node_name: NodeKind): integer, Node - i = i + 2 -- skip `local` or `global`, and `type` - +parse_type_declaration = function(ps: ParseState, i: integer, node_name: NodeKind): integer, Node local asgn: Node = new_node(ps, i, node_name) local var: Node @@ -4053,6 +4056,10 @@ local function parse_type_declaration(ps: ParseState, i: integer, node_name: Nod def.declname = asgn.var.tk end end + elseif nt is TypeAliasType then + if typeargs then + nt.typeargs = typeargs + end end return i, asgn @@ -4083,7 +4090,7 @@ local function parse_type_constructor(ps: ParseState, i: integer, node_name: Nod end local function skip_type_declaration(ps: ParseState, i: integer): integer, Node - return parse_type_declaration(ps, i - 1, "local_type") + return parse_type_declaration(ps, i + 1, "local_type") end local function parse_local_macroexp(ps: ParseState, i: integer): integer, Node @@ -4101,8 +4108,8 @@ local function parse_local(ps: ParseState, i: integer): integer, Node local tn = ntk as TypeName if ntk == "function" then return parse_local_function(ps, i) - elseif ntk == "type" and ps.tokens[i+2].kind == "identifier" then - return parse_type_declaration(ps, i, "local_type") + elseif ntk == "type" and ps.tokens[i + 2].kind == "identifier" then + return parse_type_declaration(ps, i + 2, "local_type") elseif ntk == "macroexp" and ps.tokens[i+2].kind == "identifier" then return parse_local_macroexp(ps, i) elseif parse_type_body_fns[tn] and ps.tokens[i+2].kind == "identifier" then @@ -4116,8 +4123,8 @@ local function parse_global(ps: ParseState, i: integer): integer, Node local tn = ntk as TypeName if ntk == "function" then return parse_function(ps, i + 1, "global") - elseif ntk == "type" and ps.tokens[i+2].kind == "identifier" then - return parse_type_declaration(ps, i, "global_type") + elseif ntk == "type" and ps.tokens[i + 2].kind == "identifier" then + return parse_type_declaration(ps, i + 2, "global_type") elseif parse_type_body_fns[tn] and ps.tokens[i+2].kind == "identifier" then return parse_type_constructor(ps, i, "global_type", tn, parse_type_body_fns[tn]) elseif ps.tokens[i+1].kind == "identifier" then @@ -4439,6 +4446,11 @@ local function recurse_type(s: S, ast: Type, visit: Visitor visit_type = { cbs = { - ["function"] = { - before = function(self: TypeChecker, _typ: Type) - self:begin_scope() - end, - after = function(self: TypeChecker, typ: Type, _children: {Type}): Type - self:end_scope() - return self:ensure_fresh_typeargs(typ) - end, - }, ["record"] = { before = function(self: TypeChecker, typ: RecordType) self:begin_scope() @@ -12348,11 +12398,13 @@ do } visit_type.cbs["interface"] = visit_type.cbs["record"] - visit_type.cbs["typedecl"] = visit_type.cbs["function"] + + visit_type.cbs["function"] = visit_type_with_typeargs + visit_type.cbs["typedecl"] = visit_type_with_typeargs + visit_type.cbs["typealias"] = visit_type_with_typeargs visit_type.cbs["string"] = default_type_visitor visit_type.cbs["tupletable"] = default_type_visitor - visit_type.cbs["typealias"] = default_type_visitor visit_type.cbs["array"] = default_type_visitor visit_type.cbs["map"] = default_type_visitor visit_type.cbs["enum"] = default_type_visitor From 15c1bff49a20521272a1d0dd92544ddf1afb1e49 Mon Sep 17 00:00:00 2001 From: Victor Ilchev <46074073+V1K1NGbg@users.noreply.github.com> Date: Fri, 26 Jul 2024 16:54:23 +0300 Subject: [PATCH 130/224] add the code output to the io.open function (#769) Updated definition for io.open function to expose the 3rd parameter containing the returning code as integer * Add the code output to the io.open function * generate the tl.lua --- tl.lua | 2 +- tl.tl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tl.lua b/tl.lua index 6a7203f7e..4c74e437f 100644 --- a/tl.lua +++ b/tl.lua @@ -145,7 +145,7 @@ do lines: function(? string, (number | FileMode)...): (function(): ((string | number)...)) lines: function(? string, (number | string)...): (function(): (string...)) - open: function(string, ? OpenMode): FILE, string + open: function(string, ? OpenMode): FILE, string, integer output: function(? FILE): FILE popen: function(string, ? OpenMode): FILE, string diff --git a/tl.tl b/tl.tl index c9f03a003..6766c2e62 100644 --- a/tl.tl +++ b/tl.tl @@ -145,7 +145,7 @@ do lines: function(? string, (number | FileMode)...): (function(): ((string | number)...)) lines: function(? string, (number | string)...): (function(): (string...)) - open: function(string, ? OpenMode): FILE, string + open: function(string, ? OpenMode): FILE, string, integer output: function(? FILE): FILE popen: function(string, ? OpenMode): FILE, string From 9326efdd77e635fcd5bf5f423daaeb2b000f1112 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Tue, 23 Jul 2024 19:42:24 -0300 Subject: [PATCH 131/224] fix regression in bivariant arg check --- spec/call/function_spec.lua | 7 +++++++ tl.lua | 6 +++++- tl.tl | 6 +++++- 3 files changed, 17 insertions(+), 2 deletions(-) diff --git a/spec/call/function_spec.lua b/spec/call/function_spec.lua index 1aaab0d0f..3860c4047 100644 --- a/spec/call/function_spec.lua +++ b/spec/call/function_spec.lua @@ -97,5 +97,12 @@ describe("function calls", function() ]], { { y = 5, msg = "wrong number of arguments (given 3, expects at least 1 and at most 2)" }, })) + + it("with insufficient arguments (regression test)", util.check_type_error([[ + local chunk: function() + chunk = load() + ]], { + { y = 2, msg = "wrong number of arguments (given 0, expects at least 1 and at most 4)" }, + })) end) end) diff --git a/tl.lua b/tl.lua index 4c74e437f..b17b9bc54 100644 --- a/tl.lua +++ b/tl.lua @@ -8541,7 +8541,11 @@ a.types[i], b.types[i]), } table.insert(errs, Err("incompatible number of arguments: got " .. show_arity(a) .. " %s, expected " .. show_arity(b) .. " %s", a.args, b.args)) else for i = ((a.is_method or b.is_method) and 2 or 1), #aa do - self:arg_check(nil, errs, aa[i], ba[i] or ba[#ba], "bivariant", "argument", i) + local ai = aa[i] + local bi = ba[i] or ba[#ba] + if bi then + self:arg_check(nil, errs, ai, bi, "bivariant", "argument", i) + end end end diff --git a/tl.tl b/tl.tl index 6766c2e62..e7c5cb055 100644 --- a/tl.tl +++ b/tl.tl @@ -8541,7 +8541,11 @@ do table.insert(errs, Err("incompatible number of arguments: got " .. show_arity(a) .. " %s, expected " .. show_arity(b) .. " %s", a.args, b.args)) else for i = ((a.is_method or b.is_method) and 2 or 1), #aa do - self:arg_check(nil, errs, aa[i], ba[i] or ba[#ba], "bivariant", "argument", i) + local ai = aa[i] + local bi = ba[i] or ba[#ba] + if bi then + self:arg_check(nil, errs, ai, bi, "bivariant", "argument", i) + end end end From d121f2e5c3b408c0a66552ef6b8831a1450ea647 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Tue, 23 Jul 2024 20:18:48 -0300 Subject: [PATCH 132/224] can now apply '#' on enum values --- spec/operator/len_spec.lua | 14 ++++++++++++++ tl.lua | 1 + tl.tl | 1 + 3 files changed, 16 insertions(+) diff --git a/spec/operator/len_spec.lua b/spec/operator/len_spec.lua index f4779bf5c..8ae53e2d0 100644 --- a/spec/operator/len_spec.lua +++ b/spec/operator/len_spec.lua @@ -5,6 +5,20 @@ describe("#", function() local x: integer = #({1, 2, 3}) ]])) + it("can run on a string", util.check([[ + local s = "hello" + local len = #s + ]])) + + it("can run on an enum value", util.check([[ + local enum Enum + "hello" + end + + local s: Enum = "hello" + local len = #s + ]])) + it("returns an integer when used on tuple", util.check([[ local x: integer = #({1, "hi"}) ]])) diff --git a/tl.lua b/tl.lua index b17b9bc54..853cd1235 100644 --- a/tl.lua +++ b/tl.lua @@ -6309,6 +6309,7 @@ local equality_binop = { local unop_types = { ["#"] = { + ["enum"] = "integer", ["string"] = "integer", ["array"] = "integer", ["tupletable"] = "integer", diff --git a/tl.tl b/tl.tl index e7c5cb055..f67fa51d2 100644 --- a/tl.tl +++ b/tl.tl @@ -6309,6 +6309,7 @@ local equality_binop = { local unop_types: {string:{TypeName:TypeName}} = { ["#"] = { + ["enum"] = "integer", ["string"] = "integer", ["array"] = "integer", ["tupletable"] = "integer", From 222b8c16c00e51e13d40a1fa307871d7accbe1a5 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Tue, 23 Jul 2024 20:19:10 -0300 Subject: [PATCH 133/224] API: tl.pretty_print_ast third argument is optional --- tl.tl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tl.tl b/tl.tl index f67fa51d2..a6b6f61bf 100644 --- a/tl.tl +++ b/tl.tl @@ -4827,7 +4827,7 @@ local primitive: {TypeName:string} = { ["thread"] = "thread", } -function tl.pretty_print_ast(ast: Node, gen_target: GenTarget, mode: boolean | PrettyPrintOptions): string, string +function tl.pretty_print_ast(ast: Node, gen_target: GenTarget, mode?: boolean | PrettyPrintOptions): string, string local err: string local indent = 0 From a1dcf7900e5631f3d2c82113ba3097a17baf9148 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Fri, 26 Jul 2024 15:03:23 -0300 Subject: [PATCH 134/224] tests: add regression test to function call check crash Add another regression test to a bug fixed in f97625d3703ce71d67281c3cfc6b5ed7a1289fe8 Co-Authored-By: Victor Ilchev <46074073+V1K1NGbg@users.noreply.github.com> --- spec/call/function_spec.lua | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/spec/call/function_spec.lua b/spec/call/function_spec.lua index 3860c4047..5788ff5b6 100644 --- a/spec/call/function_spec.lua +++ b/spec/call/function_spec.lua @@ -104,5 +104,15 @@ describe("function calls", function() ]], { { y = 2, msg = "wrong number of arguments (given 0, expects at least 1 and at most 4)" }, })) + + it("with insufficient arguments (regression test)", util.check_type_error([[ + local function a(f: (function(): any), ...: any): (function():any) + return function(...): (function():any) + return f(a, ...) + end + end + ]], { + { y = 3, msg = "wrong number of arguments (given 2, expects 0)" }, + })) end) end) From 3995c24e82240eb4a500177dc96679bb8c583fb9 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Thu, 1 Aug 2024 13:05:51 -0300 Subject: [PATCH 135/224] standard library: special-case tuple support for table.unpack Add special cases for tuples of sizes up to 5. --- spec/call/generic_function_spec.lua | 4 ++-- spec/stdlib/table_spec.lua | 12 ++++++++++++ tl.lua | 4 ++++ tl.tl | 4 ++++ 4 files changed, 22 insertions(+), 2 deletions(-) diff --git a/spec/call/generic_function_spec.lua b/spec/call/generic_function_spec.lua index ec68bb3ff..d512c307c 100644 --- a/spec/call/generic_function_spec.lua +++ b/spec/call/generic_function_spec.lua @@ -494,9 +494,9 @@ describe("generic function", function() ]])) it("generic function definitions do not leak type variables (#322)", util.check([[ - local function my_unpack(_list: {T}, _x: number, _y: number): T... + local function my_move(_list: {T}, _a: integer, _b: integer, _c: integer, _t?: {T}): {T} end - local _tbl_unpack = my_unpack or table.unpack + local _tbl_move = my_move or table.move local _map: {string:number} = setmetatable(assert({}), { __mode = "k" }) ]])) diff --git a/spec/stdlib/table_spec.lua b/spec/stdlib/table_spec.lua index f5f1d0ade..a360d2efb 100644 --- a/spec/stdlib/table_spec.lua +++ b/spec/stdlib/table_spec.lua @@ -10,6 +10,18 @@ describe("table", function() local b = b as string local c = c as number ]])) + + -- standard library definition has special cases + -- for tuples of sizes up to 5 + it("can unpack some tuples", util.check([[ + local s = { 1234, "5678", 4566, "foo", 123 } + local a, b, c, d, e = table.unpack(s) + a = a + 1 -- number + b = b .. "!" -- string + c = c + 2 -- number + d = d .. "!" -- string + e = e + 3 -- number + ]])) end) describe("concat", function() diff --git a/tl.lua b/tl.lua index 853cd1235..40657d31e 100644 --- a/tl.lua +++ b/tl.lua @@ -353,6 +353,10 @@ do remove: function({A}, ? integer): A sort: function({A}, ? SortFunction) + unpack: function({A1, A2, A3, A4, A5}): A1, A2, A3, A4, A5 --[[needs_compat]] + unpack: function({A1, A2, A3, A4}): A1, A2, A3, A4 --[[needs_compat]] + unpack: function({A1, A2, A3}): A1, A2, A3 --[[needs_compat]] + unpack: function({A1, A2}): A1, A2 --[[needs_compat]] unpack: function({A}, ? number, ? number): A... --[[needs_compat]] end diff --git a/tl.tl b/tl.tl index a6b6f61bf..f71951ae5 100644 --- a/tl.tl +++ b/tl.tl @@ -353,6 +353,10 @@ do remove: function({A}, ? integer): A sort: function({A}, ? SortFunction) + unpack: function({A1, A2, A3, A4, A5}): A1, A2, A3, A4, A5 --[[needs_compat]] + unpack: function({A1, A2, A3, A4}): A1, A2, A3, A4 --[[needs_compat]] + unpack: function({A1, A2, A3}): A1, A2, A3 --[[needs_compat]] + unpack: function({A1, A2}): A1, A2 --[[needs_compat]] unpack: function({A}, ? number, ? number): A... --[[needs_compat]] end From 91afcc8b2fc9f04defa25fa88696d0617c4b4256 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Mon, 5 Aug 2024 17:44:33 -0300 Subject: [PATCH 136/224] docs/tutorial.md: fix typo Co-authored-by: Darren Jennings --- docs/tutorial.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/tutorial.md b/docs/tutorial.md index 1451d5e6f..18572c1ac 100644 --- a/docs/tutorial.md +++ b/docs/tutorial.md @@ -179,7 +179,7 @@ local record File close: function(File): boolean, string end --- a record can doubles as a record and an array, by declaring an array interface +-- a record can double as a record and an array, by declaring an array interface local record TreeNode is {TreeNode} item: T end From d0856b55de93e322cbd4b7ce3fefaedd2fdbf695 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Mon, 5 Aug 2024 17:29:51 -0300 Subject: [PATCH 137/224] tl types: do not crash if given input file doesn't exist --- spec/cli/types_spec.lua | 7 +++++++ tl | 12 ++++++++++-- 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/spec/cli/types_spec.lua b/spec/cli/types_spec.lua index a1334e71f..4abef1740 100644 --- a/spec/cli/types_spec.lua +++ b/spec/cli/types_spec.lua @@ -4,6 +4,13 @@ local util = require("spec.util") describe("tl types works like check", function() describe("on .tl files", function() + it("reports missing files", function() + local pd = io.popen(util.tl_cmd("types", "nonexistent_file") .. "2>&1 1>" .. util.os_null, "r") + local output = pd:read("*a") + util.assert_popen_close(1, pd:close()) + assert.match("could not open nonexistent_file", output, 1, true) + end) + it("works on empty files", function() local name = util.write_tmp_file(finally, [[]]) local pd = io.popen(util.tl_cmd("types", name) .. " 2>" .. util.os_null, "r") diff --git a/tl b/tl index fc1fa315e..e78f40322 100755 --- a/tl +++ b/tl @@ -898,15 +898,23 @@ do env.report_types = true for i, input_file in ipairs(args["file"]) do - local pok, err = pcall(process_module, input_file, env) + local pok, perr, err = pcall(process_module, input_file, env) if not pok then - die("Internal Compiler Error: " .. err) + die("Internal Compiler Error: " .. perr) end + if err then + printerr(err) + end + check_collect(i) end local ok, _, _, w = report_all_errors(tlconfig, env) + if not env.reporter then + os.exit(1) + end + local tr = env.reporter:get_report() if tr then if w or not ok then From 5f853333785323ab0db2dd2eb9130ff4f05dffdf Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Mon, 5 Aug 2024 17:30:23 -0300 Subject: [PATCH 138/224] local type require() accepts dot notation for nested record `local type MyType = require("module").MyType` is now valid. Closes #778. --- docs/grammar.md | 5 +- spec/cli/types_spec.lua | 11 ++++ spec/stdlib/pcall_spec.lua | 6 +++ spec/stdlib/require_spec.lua | 97 ++++++++++++++++++++++++++++++++++++ tl.lua | 76 ++++++++++++++++++++++------ tl.tl | 76 ++++++++++++++++++++++------ 6 files changed, 239 insertions(+), 32 deletions(-) diff --git a/docs/grammar.md b/docs/grammar.md index 49b0682e9..ac63d3a8c 100644 --- a/docs/grammar.md +++ b/docs/grammar.md @@ -108,12 +108,13 @@ precedence, see below. * typeargs ::= ‘<’ Name {‘,’ Name } ‘>’ * newtype ::= ‘record’ recordbody | ‘enum’ enumbody | type +* | ‘require’ ‘(’ LiteralString ‘)’ {‘.’ Name } * interfacelist ::= nominal {‘,’ nominal} | -* ‘{’ type ‘}’ {‘,’ nominal} +* ‘{’ type ‘}’ {‘,’ nominal} * recordbody ::= [typeargs] [‘is’ interfacelist] -* [‘where’ exp] {recordentry} ‘end’ +* [‘where’ exp] {recordentry} ‘end’ * recordentry ::= ‘userdata’ | * ‘type’ Name ‘=’ newtype | [‘metamethod’] recordkey ‘:’ type | diff --git a/spec/cli/types_spec.lua b/spec/cli/types_spec.lua index 4abef1740..24f36ff05 100644 --- a/spec/cli/types_spec.lua +++ b/spec/cli/types_spec.lua @@ -211,6 +211,17 @@ describe("tl types works like check", function() util.assert_popen_close(0, pd:close()) -- TODO check json output end) + + it("does not crash when a require() expression does not resolve (#778)", function() + local name = util.write_tmp_file(finally, [[ + local type Foo = require("missingmodule").baz + ]]) + local pd = io.popen(util.tl_cmd("types", name, "--gen-target=5.1") .. "2>&1 1>" .. util.os_null, "r") + local output = pd:read("*a") + util.assert_popen_close(1, pd:close()) + assert.match("1 error:", output, 1, true) + -- TODO check json output + end) end) describe("on .lua files", function() diff --git a/spec/stdlib/pcall_spec.lua b/spec/stdlib/pcall_spec.lua index 7599a29e2..7e99c8376 100644 --- a/spec/stdlib/pcall_spec.lua +++ b/spec/stdlib/pcall_spec.lua @@ -2,6 +2,12 @@ local tl = require("tl") local util = require("spec.util") describe("pcall", function() + it("can't use pcall in local type", util.check_syntax_error([[ + local type bla = pcall("require", "something") + ]], { + { msg = "pcall() cannot be used in type declarations" } + })) + it("can't pcall nothing", util.check_type_error([[ local pok = pcall() ]], { diff --git a/spec/stdlib/require_spec.lua b/spec/stdlib/require_spec.lua index ec1fb26ba..72ba21859 100644 --- a/spec/stdlib/require_spec.lua +++ b/spec/stdlib/require_spec.lua @@ -1172,4 +1172,101 @@ describe("require", function() { y = 14, msg = "got , expected Mod" }, })) end) + + it("in 'local type' accepts dots for extracting nested types", function () + -- ok + util.mock_io(finally, { + ["mod.tl"] = [[ + local record mod + record Foo + something: K + fn: function(): Foo + end + end + + return mod + ]], + ["main.tl"] = [[ + local type Foo = require("mod").Foo + local function f(v: Foo) + print(v.something) + -- check that aliasing works: + local x = v.fn() + x = v + end + ]], + }) + local result, err = tl.process("main.tl") + + assert.same({}, result.syntax_errors) + assert.same({}, result.type_errors) + end) + + it("in 'local type' with dots fails if not a record", function () + -- ok + util.mock_io(finally, { + ["mod.tl"] = [[ + return 123 + ]], + ["main.tl"] = [[ + local type Foo = require("mod").Foo + ]], + }) + local result, err = tl.process("main.tl") + + assert.same({}, result.syntax_errors) + assert.same({ + { filename = "main.tl", x = 37, y = 1, msg = "type is not a record" }, + { filename = "main.tl", x = 45, y = 1, msg = "cannot index key 'Foo' in type integer" }, + }, result.type_errors) + end) + + it("in 'local type' with dots fails if not a record", function () + -- ok + util.mock_io(finally, { + ["mod.tl"] = [[ + local record mod + record Bar + end + end + + return mod + ]], + ["main.tl"] = [[ + local type Foo = require("mod").Foo + ]], + }) + local result, err = tl.process("main.tl") + + assert.same({}, result.syntax_errors) + assert.same({ + { filename = "main.tl", x = 45, y = 1, msg = "nested type 'Foo' not found in record" }, + }, result.type_errors) + end) + + it("in 'local type' does not accept arbitrary expressions", function () + -- ok + util.mock_io(finally, { + ["mod.tl"] = [[ + local record mod + record Foo + something: K + end + end + + return mod + ]], + ["main.tl"] = [[ + local type Foo = require("mod") + "hello" + local function f(v: Foo) + print(v.something) + end + ]], + }) + local result, err = tl.process("main.tl") + + assert.same({ + { filename = "main.tl", x = 30, y = 1, msg = "require() in type declarations cannot be part of larger expressions" } + }, result.syntax_errors) + end) end) diff --git a/tl.lua b/tl.lua index 40657d31e..b8c9890a8 100644 --- a/tl.lua +++ b/tl.lua @@ -2813,22 +2813,37 @@ do end local function node_is_require_call(n) - if n.e1 and n.e2 and - n.e1.kind == "variable" and n.e1.tk == "require" and + if not (n.e1 and n.e2) then + return nil + end + if n.op and n.op.op == "." then + + return node_is_require_call(n.e1) + elseif n.e1.kind == "variable" and n.e1.tk == "require" and n.e2.kind == "expression_list" and #n.e2 == 1 and n.e2[1].kind == "string" then + return n.e2[1].conststr - elseif n.op and n.op.op == "@funcall" and + end + return nil + end + + local function node_is_require_call_or_pcall(n) + local r = node_is_require_call(n) + if r then + return r + end + if n.op and n.op.op == "@funcall" and n.e1 and n.e1.tk == "pcall" and n.e2 and #n.e2 == 2 and n.e2[1].kind == "variable" and n.e2[1].tk == "require" and n.e2[2].kind == "string" and n.e2[2].conststr then + return n.e2[2].conststr - else - return nil end + return nil end do @@ -3004,7 +3019,7 @@ do e1 = { f = ps.filename, y = args.y, x = args.x, kind = "op", op = op, e1 = e1, e2 = args } - table.insert(ps.required_modules, node_is_require_call(e1)) + table.insert(ps.required_modules, node_is_require_call_or_pcall(e1)) elseif tkop.tk == "[" then local op = new_operator(tkop, 2, "@index") @@ -3047,7 +3062,7 @@ do table.insert(args, argument) e1 = { f = ps.filename, y = args.y, x = args.x, kind = "op", op = op, e1 = e1, e2 = args } - table.insert(ps.required_modules, node_is_require_call(e1)) + table.insert(ps.required_modules, node_is_require_call_or_pcall(e1)) elseif tkop.tk == "as" or tkop.tk == "is" then local op = new_operator(tkop, 2, tkop.tk) @@ -4034,13 +4049,27 @@ do i = verify_tk(ps, i, "=") - if ps.tokens[i].kind == "identifier" and ps.tokens[i].tk == "require" then - local istart = i - i, asgn.value = parse_call_or_assignment(ps, i) - if asgn.value and not node_is_require_call(asgn.value) then - fail(ps, istart, "require() for type declarations must have a literal argument") + if ps.tokens[i].kind == "identifier" then + if ps.tokens[i].tk == "require" then + local istart = i + i, asgn.value = parse_expression(ps, i) + if asgn.value then + if asgn.value.op and asgn.value.op.op ~= "@funcall" and asgn.value.op.op ~= "." then + fail(ps, istart, "require() in type declarations cannot be part of larger expressions") + return i + end + if not node_is_require_call(asgn.value) then + fail(ps, istart, "require() for type declarations must have a literal argument") + return i + end + return i, asgn + else + return i + end + elseif ps.tokens[i].tk == "pcall" then + fail(ps, i, "pcall() cannot be used in type declarations") + return i end - return i, asgn end i, asgn.value = parse_newtype(ps, i) @@ -10591,8 +10620,25 @@ self:expand_type(node, values, elements) }) local ty = t.typename == "tuple" and t.tuple[1] or t ty = (ty.typename == "typealias") and self:resolve_typealias(ty) or ty - local td = (ty.typename == "typedecl") and ty or a_type(value, "typedecl", { def = ty }) - return td + return (ty.typename == "typedecl") and ty or (a_type(value, "typedecl", { def = ty })) + elseif value.kind == "op" and + value.op.op == "." then + + local ty = self:get_typedecl(value.e1) + if ty.typename == "typedecl" then + local def = ty.def + if def.typename == "record" then + local t = def.fields[value.e2.tk] + if t and t.typename == "typedecl" then + return t + else + return self.errs:invalid_at(value.e2, "nested type '" .. value.e2.tk .. "' not found in record") + end + else + return self.errs:invalid_at(value.e1, "type is not a record") + end + end + return ty else local newtype = value.newtype if newtype.typename == "typealias" then diff --git a/tl.tl b/tl.tl index f71951ae5..93abf21e2 100644 --- a/tl.tl +++ b/tl.tl @@ -2813,22 +2813,37 @@ local function parse_literal(ps: ParseState, i: integer): integer, Node end local function node_is_require_call(n: Node): string - if n.e1 and n.e2 -- literal require call - and n.e1.kind == "variable" and n.e1.tk == "require" + if not (n.e1 and n.e2) then + return nil + end + if n.op and n.op.op == "." then + -- `require("str").something` + return node_is_require_call(n.e1) + elseif n.e1.kind == "variable" and n.e1.tk == "require" and n.e2.kind == "expression_list" and #n.e2 == 1 and n.e2[1].kind == "string" then + -- `require("str")` return n.e2[1].conststr - elseif n.op and n.op.op == "@funcall" -- pcall(require, "str") + end + return nil -- table.insert cares about arity +end + +local function node_is_require_call_or_pcall(n: Node): string + local r = node_is_require_call(n) + if r then + return r + end + if n.op and n.op.op == "@funcall" and n.e1 and n.e1.tk == "pcall" and n.e2 and #n.e2 == 2 and n.e2[1].kind == "variable" and n.e2[1].tk == "require" and n.e2[2].kind == "string" and n.e2[2].conststr then + -- `pcall(require, "str")` return n.e2[2].conststr - else - return nil -- table.insert cares about arity end + return nil -- table.insert cares about arity end do @@ -3004,7 +3019,7 @@ do e1 = { f = ps.filename, y = args.y, x = args.x, kind = "op", op = op, e1 = e1, e2 = args } - table.insert(ps.required_modules, node_is_require_call(e1)) + table.insert(ps.required_modules, node_is_require_call_or_pcall(e1)) elseif tkop.tk == "[" then local op: Operator = new_operator(tkop, 2, "@index") @@ -3047,7 +3062,7 @@ do table.insert(args, argument) e1 = { f = ps.filename, y = args.y, x = args.x, kind = "op", op = op, e1 = e1, e2 = args } - table.insert(ps.required_modules, node_is_require_call(e1)) + table.insert(ps.required_modules, node_is_require_call_or_pcall(e1)) elseif tkop.tk == "as" or tkop.tk == "is" then local op: Operator = new_operator(tkop, 2, tkop.tk) @@ -4034,13 +4049,27 @@ parse_type_declaration = function(ps: ParseState, i: integer, node_name: NodeKin i = verify_tk(ps, i, "=") - if ps.tokens[i].kind == "identifier" and ps.tokens[i].tk == "require" then - local istart = i - i, asgn.value = parse_call_or_assignment(ps, i) - if asgn.value and not node_is_require_call(asgn.value) then - fail(ps, istart, "require() for type declarations must have a literal argument") + if ps.tokens[i].kind == "identifier" then + if ps.tokens[i].tk == "require" then + local istart = i + i, asgn.value = parse_expression(ps, i) + if asgn.value then + if asgn.value.op and asgn.value.op.op ~= "@funcall" and asgn.value.op.op ~= "." then + fail(ps, istart, "require() in type declarations cannot be part of larger expressions") + return i + end + if not node_is_require_call(asgn.value) then + fail(ps, istart, "require() for type declarations must have a literal argument") + return i + end + return i, asgn + else + return i + end + elseif ps.tokens[i].tk == "pcall" then + fail(ps, i, "pcall() cannot be used in type declarations") + return i end - return i, asgn end i, asgn.value = parse_newtype(ps, i) @@ -10591,8 +10620,25 @@ do local ty: Type = t is TupleType and t.tuple[1] or t ty = (ty is TypeAliasType) and self:resolve_typealias(ty) or ty - local td = (ty is TypeDeclType) and ty or a_type(value, "typedecl", { def = ty } as TypeDeclType) - return td + return (ty is TypeDeclType) and ty or (a_type(value, "typedecl", { def = ty } as TypeDeclType)) + elseif value.kind == "op" + and value.op.op == "." + then + local ty = self:get_typedecl(value.e1) + if ty is TypeDeclType then + local def = ty.def + if def is RecordType then + local t = def.fields[value.e2.tk] + if t and t is TypeDeclType then + return t + else + return self.errs:invalid_at(value.e2, "nested type '" .. value.e2.tk .. "' not found in record") + end + else + return self.errs:invalid_at(value.e1, "type is not a record") + end + end + return ty else local newtype = value.newtype if newtype is TypeAliasType then From 448fe8471e753e1737d644c031210b32a674bf55 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Wed, 7 Aug 2024 12:15:37 -0300 Subject: [PATCH 139/224] fix: delay resolution of type arguments Ideally we need to recurse type arguments and propagate those. But since they are arriving in typevar_resolver as nominal, we're just renaming the existing typeargs (see `HACK` in source), and test cases. This is most likely wrong, hope it doesn't break anything else! Closes #777. --- spec/declaration/local_spec.lua | 31 +++++++++++++++++++++++++++++++ spec/declaration/record_spec.lua | 19 +++++++++++++++++++ tl.lua | 29 +++++++++++++++++++---------- tl.tl | 31 ++++++++++++++++++++----------- 4 files changed, 89 insertions(+), 21 deletions(-) diff --git a/spec/declaration/local_spec.lua b/spec/declaration/local_spec.lua index 6b93b8b1b..932b11167 100644 --- a/spec/declaration/local_spec.lua +++ b/spec/declaration/local_spec.lua @@ -129,6 +129,37 @@ describe("local", function() }, result.type_errors) end) + it("local type can resolve a nominal with generics (regression test for #777)", function () + util.mock_io(finally, { + ["module.tl"] = [[ + local record module + record Foo + something: K + end + end + return module + ]], + ["main.tl"] = [[ + local module = require "module" + + local record Boo + field: MyFoo + end + + local type MyFoo = module.Foo + + local b: Boo = { field = { something = "hi" } } + local c: Boo = { field = { something = 123 } } + ]], + }) + local result, err = tl.process("main.tl") + + assert.same({}, result.syntax_errors) + assert.same({ + { y = 10, x = 55, filename = "main.tl", msg = "in record field: something: got integer, expected string" }, + }, result.type_errors) + end) + it("catches unknown types", util.check_type_error([[ local type MyType = UnknownType ]], { diff --git a/spec/declaration/record_spec.lua b/spec/declaration/record_spec.lua index 2af252bcc..b93b5802d 100644 --- a/spec/declaration/record_spec.lua +++ b/spec/declaration/record_spec.lua @@ -894,6 +894,25 @@ for i, name in ipairs({"records", "arrayrecords", "interfaces", "arrayinterfaces { y = 1, msg = "unknown type Bongo" }, { y = 1, msg = "unknown type Bingo" }, })) + + + it("reports error on unexpected generics", util.check_type_error([[ + local ]]..statement..[[ Foo ]]..array(i, "is {number}")..[[ + end + + local x: Foo + ]], { + { y = 4, x = 19, msg = "unexpected type argument" }, + })) + + it("reports error on unexpected generics", util.check_type_error([[ + local ]]..statement..[[ Foo ]]..array(i, "is {number}")..[[ + end + + local x: Foo + ]], { + { y = 4, x = 19, msg = "mismatch in number of type arguments" }, + })) end) end diff --git a/tl.lua b/tl.lua index b8c9890a8..43a715bbf 100644 --- a/tl.lua +++ b/tl.lua @@ -4577,11 +4577,6 @@ local function recurse_node(s, root, end end - local function walk_var_value(ast, xs) - xs[1] = recurse(ast.var) - xs[2] = recurse(ast.value) - end - local function walk_named_function(ast, xs) recurse_typeargs(s, ast, visit_type) xs[1] = recurse(ast.name) @@ -4628,8 +4623,16 @@ local function recurse_node(s, root, ["local_declaration"] = walk_vars_exps, ["global_declaration"] = walk_vars_exps, - ["local_type"] = walk_var_value, + ["local_type"] = function(ast, xs) + + + xs[1] = recurse(ast.var) + xs[2] = recurse(ast.value) + end, + ["global_type"] = function(ast, xs) + + xs[1] = recurse(ast.var) if ast.value then xs[2] = recurse(ast.value) @@ -7270,7 +7273,7 @@ do if t.typename == "typevar" then local rt = fn_var(self, t) if rt then - resolved[t.typevar] = true + resolved[t.typevar] = rt if no_nested_types[rt.typename] or (rt.typename == "nominal" and not rt.typevals) then seen[orig_t] = rt return rt, false @@ -7426,8 +7429,14 @@ do copy.typeargs then for i = #copy.typeargs, 1, -1 do - if resolved[copy.typeargs[i].typearg] then - table.remove(copy.typeargs, i) + local r = resolved[copy.typeargs[i].typearg] + if r then + + if r.typename == "nominal" and #r.names == 1 then + copy.typeargs[i].typearg = r.names[1] + else + table.remove(copy.typeargs, i) + end end end if not copy.typeargs[1] then @@ -7749,7 +7758,7 @@ do self:end_scope() return ret elseif t.typevals then - self.errs:add(t, "spurious type arguments") + self.errs:add(t, "unexpected type argument") return nil elseif def.typeargs then self.errs:add(t, "missing type arguments in %s", def) diff --git a/tl.tl b/tl.tl index 93abf21e2..668380be3 100644 --- a/tl.tl +++ b/tl.tl @@ -4577,11 +4577,6 @@ local function recurse_node(s: S, root: Node, end end - local function walk_var_value(ast: Node, xs: {T}) - xs[1] = recurse(ast.var) - xs[2] = recurse(ast.value) - end - local function walk_named_function(ast: Node, xs: {T}) recurse_typeargs(s, ast, visit_type) xs[1] = recurse(ast.name) @@ -4628,8 +4623,16 @@ local function recurse_node(s: S, root: Node, ["local_declaration"] = walk_vars_exps, ["global_declaration"] = walk_vars_exps, - ["local_type"] = walk_var_value, + ["local_type"] = function(ast: Node, xs: {T}) + -- TODO need to recurse typeargs + -- recurse_typeargs(s, ast, visit_type) + xs[1] = recurse(ast.var) + xs[2] = recurse(ast.value) + end, + ["global_type"] = function(ast: Node, xs: {T}) + -- TODO need to recurse typeargs + -- recurse_typeargs(s, ast, visit_type) xs[1] = recurse(ast.var) if ast.value then xs[2] = recurse(ast.value) @@ -7243,7 +7246,7 @@ do typevar_resolver = function(self: S, typ: Type, fn_var?: ResolveType, fn_arg?: ResolveType): boolean, Type, {Error} local errs: {Error} local seen: {Type:Type} = {} - local resolved: {string:boolean} = {} + local resolved: {string:Type} = {} local resolve: function(t: T, all_same: boolean): T, boolean local function copy_typeargs(t: {TypeArgType}, same: boolean): {TypeArgType}, boolean @@ -7270,7 +7273,7 @@ do if t is TypeVarType then local rt = fn_var(self, t) if rt then - resolved[t.typevar] = true + resolved[t.typevar] = rt if no_nested_types[rt.typename] or (rt is NominalType and not rt.typevals) then seen[orig_t] = rt return rt, false @@ -7426,8 +7429,14 @@ do copy.typeargs then for i = #copy.typeargs, 1, -1 do - if resolved[copy.typeargs[i].typearg] then - table.remove(copy.typeargs, i) + local r = resolved[copy.typeargs[i].typearg] + if r then + -- FIXME HACK!!! + if r is NominalType and #r.names == 1 then + copy.typeargs[i].typearg = r.names[1] + else + table.remove(copy.typeargs, i) + end end end if not copy.typeargs[1] then @@ -7749,7 +7758,7 @@ do self:end_scope() return ret elseif t.typevals then - self.errs:add(t, "spurious type arguments") + self.errs:add(t, "unexpected type argument") return nil elseif def.typeargs then self.errs:add(t, "missing type arguments in %s", def) From 9736418f4fd3b8220c1ed337aaefe2a717d46df8 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Wed, 7 Aug 2024 15:45:14 -0300 Subject: [PATCH 140/224] stricter checks for shadowing of type arguments Produces an error when using a base type as a type argument, and a warning when using another variable name. We can be stricter with type arguments because those exist only in Teal's type world. We can't be as strict with redeclarations in general because record types also exist as concrete Lua objects (example: the "string" table in the standard library). Fixes #764. --- spec/declaration/functiontype_spec.lua | 9 +++++ spec/declaration/record_function_spec.lua | 2 +- spec/declaration/record_method_spec.lua | 2 +- spec/declaration/record_spec.lua | 28 ++++++++++++++ spec/error_reporting/warning_spec.lua | 22 +++++------ tl.lua | 46 +++++++++++++---------- tl.tl | 46 +++++++++++++---------- 7 files changed, 104 insertions(+), 51 deletions(-) diff --git a/spec/declaration/functiontype_spec.lua b/spec/declaration/functiontype_spec.lua index 47d4b83a8..1aea61251 100644 --- a/spec/declaration/functiontype_spec.lua +++ b/spec/declaration/functiontype_spec.lua @@ -9,6 +9,15 @@ describe("functiontype declaration", function() end ]])) + it("cannot use base type as a type variable", util.check_type_error([[ + local record Foo + end + + local f: function(self:Foo):(string) + ]], { + { y = 4, x = 25, msg = "cannot use base type name 'integer' as a type variable" }, + })) + it("produces a nice error when declared with the old syntax", util.check_syntax_error([[ local t = functiontype(number, number): string diff --git a/spec/declaration/record_function_spec.lua b/spec/declaration/record_function_spec.lua index 26ea78719..290f5e617 100644 --- a/spec/declaration/record_function_spec.lua +++ b/spec/declaration/record_function_spec.lua @@ -69,7 +69,7 @@ describe("record function", function() return a - b end ]], { - { y = 8, msg = "redeclaration of function 'do_x'" }, + { y = 8, msg = "function shadows previous declaration of 'do_x'" }, })) it("a type signature does not count as a redeclaration", util.check_warnings([[ diff --git a/spec/declaration/record_method_spec.lua b/spec/declaration/record_method_spec.lua index 7f8cf2db6..6bd8766ce 100644 --- a/spec/declaration/record_method_spec.lua +++ b/spec/declaration/record_method_spec.lua @@ -543,7 +543,7 @@ describe("record method", function() return a - b end ]], { - { y = 8, msg = "redeclaration of function 'do_x'" }, + { y = 8, msg = "function shadows previous declaration of 'do_x'" }, })) it("a type signature does not count as a redeclaration", util.check_warnings([[ diff --git a/spec/declaration/record_spec.lua b/spec/declaration/record_spec.lua index b93b5802d..be46f1c62 100644 --- a/spec/declaration/record_spec.lua +++ b/spec/declaration/record_spec.lua @@ -340,6 +340,13 @@ for i, name in ipairs({"records", "arrayrecords", "interfaces", "arrayinterfaces f.example = { x = "hello" } ]])) + it("cannot use base type as a type variable", util.check_type_error([[ + local ]]..statement..[[ Foo + end + ]], { + { y = 1, msg = "cannot use base type name 'integer' as a type variable" }, + })) + it("can have nested enums", util.check([[ local type foo = ]]..statement..[[ ]]..array(i, "{foo}")..[[ enum Direction @@ -913,6 +920,27 @@ for i, name in ipairs({"records", "arrayrecords", "interfaces", "arrayinterfaces ]], { { y = 4, x = 19, msg = "mismatch in number of type arguments" }, })) + + it("reports on type variables that shadow other variables (#764)", util.check_warnings([[ + local record Foos + end + + local record test + record Foo + userdata + bar : function(self:Foo):(string) + bar2: function(self:Foo):(string) + end + end + + local x: test.Foo + x:bar() + x:bar2() + ]], { + { y = 5, x = 25, type = "redeclaration", msg = "type argument shadows previous declaration of 'Foos'" }, + { y = 8, x = 33, type = "redeclaration", msg = "type argument shadows previous declaration of 'Foos'" }, + { y = 1, x = 23, type = "unused", msg = "unused type Foos" }, + })) end) end diff --git a/spec/error_reporting/warning_spec.lua b/spec/error_reporting/warning_spec.lua index fa0bd2bdb..6be50782a 100644 --- a/spec/error_reporting/warning_spec.lua +++ b/spec/error_reporting/warning_spec.lua @@ -8,7 +8,7 @@ describe("warnings", function() local a = 2 print(a) ]], { - { y = 3, msg = "redeclaration of variable 'a' (originally declared at 1:16)" }, + { y = 3, msg = "variable shadows previous declaration of 'a' (originally declared at 1:16)" }, })) it("reports redefined variables in for loops", util.check_warnings([[ @@ -25,9 +25,9 @@ describe("warnings", function() print(k, v) end ]], { - { y = 3, msg = "redeclaration of variable 'i' (originally declared at 1:14)" }, - { y = 9, msg = "redeclaration of variable 'k' (originally declared at 7:14)" }, - { y = 10, msg = "redeclaration of variable 'v' (originally declared at 7:17)" }, + { y = 3, msg = "variable shadows previous declaration of 'i' (originally declared at 1:14)" }, + { y = 9, msg = "variable shadows previous declaration of 'k' (originally declared at 7:14)" }, + { y = 10, msg = "variable shadows previous declaration of 'v' (originally declared at 7:17)" }, })) it("reports use of pairs on arrays", util.check_warnings([[ @@ -56,7 +56,7 @@ describe("warnings", function() { y = 1, msg = [[unused variable foo: string]] } })) - it("does not report redeclaration of variables prefixed with '_'", util.check_warnings([[ + it("does not report variable shadows previous declaration ofs prefixed with '_'", util.check_warnings([[ local _ = 1 print(_) -- ensure usage local _ = 2 @@ -78,7 +78,7 @@ describe("warnings", function() print(a) end ]], { - { y = 3, msg = "redeclaration of variable 'a' (originally declared at 1:16)" }, + { y = 3, msg = "variable shadows previous declaration of 'a' (originally declared at 1:16)" }, { y = 1, msg = "unused variable a: integer" }, })) @@ -122,7 +122,7 @@ describe("warnings", function() print(i) end ]], { - { y = 2, msg = "redeclaration of variable 'i' (originally declared at 1:16)" }, + { y = 2, msg = "variable shadows previous declaration of 'i' (originally declared at 1:16)" }, { y = 1, msg = "unused variable i: integer" }, })) @@ -132,7 +132,7 @@ describe("warnings", function() print(i) end ]], { - { y = 2, msg = "redeclaration of variable 'i' (originally declared at 1:16)" }, + { y = 2, msg = "variable shadows previous declaration of 'i' (originally declared at 1:16)" }, { y = 1, msg = "unused variable i: integer" }, })) @@ -143,7 +143,7 @@ describe("warnings", function() local function a() end a() ]], { - { y = 3, msg = "redeclaration of function 'a' (originally declared at 1:10)" }, + { y = 3, msg = "function shadows previous declaration of 'a' (originally declared at 1:10)" }, })) it("reports local functions redefined as variables", util.check_warnings([[ @@ -152,7 +152,7 @@ describe("warnings", function() local a = 3 print(a) ]], { - { y = 3, msg = "redeclaration of variable 'a' (originally declared at 1:10)" }, + { y = 3, msg = "variable shadows previous declaration of 'a' (originally declared at 1:10)" }, })) it("reports local variables redefined as functions", util.check_warnings([[ @@ -161,7 +161,7 @@ describe("warnings", function() local function a() end a() ]], { - { y = 3, msg = "redeclaration of function 'a' (originally declared at 1:16)" }, + { y = 3, msg = "function shadows previous declaration of 'a' (originally declared at 1:16)" }, })) end) diff --git a/tl.lua b/tl.lua index 43a715bbf..5eae2beb1 100644 --- a/tl.lua +++ b/tl.lua @@ -2552,12 +2552,12 @@ do local function parse_typearg(ps, i) local name = ps.tokens[i].tk local constraint + local t = new_type(ps, i, "typearg") i = verify_kind(ps, i, "identifier") if ps.tokens[i].tk == "is" then i = i + 1 i, constraint = parse_interface_name(ps, i) end - local t = new_type(ps, i, "typearg") t.typearg = name t.constraint = constraint return i, t @@ -6084,21 +6084,14 @@ function Errors:add_unknown(node, name) self:add_warning("unknown", node, "unknown variable: %s", name) end -function Errors:redeclaration_warning(node, old_var) - if node.tk:sub(1, 1) == "_" then return end - - local var_kind = "variable" - local var_name = node.tk - if node.kind == "local_function" or node.kind == "record_function" then - var_kind = "function" - var_name = node.name.tk - end +function Errors:redeclaration_warning(at, var_name, var_kind, old_var) + if var_name:sub(1, 1) == "_" then return end - local short_error = "redeclaration of " .. var_kind .. " '%s'" + local short_error = var_kind .. " shadows previous declaration of '%s'" if old_var and old_var.declared_at then - self:add_warning("redeclaration", node, short_error .. " (originally declared at %d:%d)", var_name, old_var.declared_at.y, old_var.declared_at.x) + self:add_warning("redeclaration", at, short_error .. " (originally declared at %d:%d)", var_name, old_var.declared_at.y, old_var.declared_at.x) else - self:add_warning("redeclaration", node, short_error, var_name) + self:add_warning("redeclaration", at, short_error, var_name) end end @@ -7482,10 +7475,16 @@ do end - function TypeChecker:check_if_redeclaration(new_name, at) + function TypeChecker:check_if_redeclaration(new_name, node) local old = self:find_var(new_name, "check_only") - if old then - self.errs:redeclaration_warning(at, old) + if old or simple_types[new_name] then + local var_name = node.tk + local var_kind = "variable" + if node.kind == "local_function" or node.kind == "record_function" then + var_kind = "function" + var_name = node.name.tk + end + self.errs:redeclaration_warning(node, var_name, var_kind, old) end end @@ -11575,7 +11574,7 @@ self:expand_type(node, values, elements) }) rfieldtype = self:to_structural(rfieldtype) if open_v and open_v.implemented and open_v.implemented[open_k] then - self.errs:redeclaration_warning(node) + self.errs:redeclaration_warning(node, node.name.tk, "function") end local ok, err = self:same_type(fn_type, rfieldtype) @@ -12398,8 +12397,17 @@ self:expand_type(node, values, elements) }) }, ["typearg"] = { after = function(self, typ, _children) - self:add_var(nil, typ.typearg, a_type(typ, "typearg", { - typearg = typ.typearg, + local name = typ.typearg + local old = self:find_var(name, "check_only") + if old then + self.errs:redeclaration_warning(typ, name, "type argument", old) + end + if simple_types[name] then + self.errs:add(typ, "cannot use base type name '" .. name .. "' as a type variable") + end + + self:add_var(nil, name, a_type(typ, "typearg", { + typearg = name, constraint = typ.constraint, })) return typ diff --git a/tl.tl b/tl.tl index 668380be3..092c2e7dd 100644 --- a/tl.tl +++ b/tl.tl @@ -2552,12 +2552,12 @@ end local function parse_typearg(ps: ParseState, i: integer): integer, TypeArgType, integer local name = ps.tokens[i].tk local constraint: Type + local t = new_type(ps, i, "typearg") as TypeArgType i = verify_kind(ps, i, "identifier") if ps.tokens[i].tk == "is" then i = i + 1 i, constraint = parse_interface_name(ps, i) -- FIXME what about generic interfaces end - local t = new_type(ps, i, "typearg") as TypeArgType t.typearg = name t.constraint = constraint return i, t @@ -6084,21 +6084,14 @@ function Errors:add_unknown(node: Node, name: string) self:add_warning("unknown", node, "unknown variable: %s", name) end -function Errors:redeclaration_warning(node: Node, old_var?: Variable) - if node.tk:sub(1, 1) == "_" then return end - - local var_kind = "variable" - local var_name = node.tk - if node.kind == "local_function" or node.kind == "record_function" then - var_kind = "function" - var_name = node.name.tk - end +function Errors:redeclaration_warning(at: Where, var_name: string, var_kind: string, old_var?: Variable) + if var_name:sub(1, 1) == "_" then return end - local short_error = "redeclaration of " .. var_kind .. " '%s'" + local short_error = var_kind .. " shadows previous declaration of '%s'" if old_var and old_var.declared_at then - self:add_warning("redeclaration", node, short_error .. " (originally declared at %d:%d)", var_name, old_var.declared_at.y, old_var.declared_at.x) + self:add_warning("redeclaration", at, short_error .. " (originally declared at %d:%d)", var_name, old_var.declared_at.y, old_var.declared_at.x) else - self:add_warning("redeclaration", node, short_error, var_name) + self:add_warning("redeclaration", at, short_error, var_name) end end @@ -7482,10 +7475,16 @@ do end - function TypeChecker:check_if_redeclaration(new_name: string, at: Node) + function TypeChecker:check_if_redeclaration(new_name: string, node: Node) local old = self:find_var(new_name, "check_only") - if old then - self.errs:redeclaration_warning(at, old) + if old or simple_types[new_name as TypeName] then + local var_name = node.tk + local var_kind = "variable" + if node.kind == "local_function" or node.kind == "record_function" then + var_kind = "function" + var_name = node.name.tk + end + self.errs:redeclaration_warning(node, var_name, var_kind, old) end end @@ -11575,7 +11574,7 @@ do rfieldtype = self:to_structural(rfieldtype) if open_v and open_v.implemented and open_v.implemented[open_k] then - self.errs:redeclaration_warning(node) + self.errs:redeclaration_warning(node, node.name.tk, "function") end local ok, err = self:same_type(fn_type, rfieldtype) @@ -12398,8 +12397,17 @@ do }, ["typearg"] = { after = function(self: TypeChecker, typ: TypeArgType, _children: {Type}): Type - self:add_var(nil, typ.typearg, a_type(typ, "typearg", { - typearg = typ.typearg, + local name = typ.typearg + local old = self:find_var(name, "check_only") + if old then + self.errs:redeclaration_warning(typ, name, "type argument", old) + end + if simple_types[name as TypeName] then + self.errs:add(typ, "cannot use base type name '" .. name .. "' as a type variable") + end + + self:add_var(nil, name, a_type(typ, "typearg", { + typearg = name, constraint = typ.constraint, } as TypeArgType)) return typ From 49c2f4ef658bc749d259f0ed79bad1ea81129d76 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Wed, 7 Aug 2024 17:00:40 -0300 Subject: [PATCH 141/224] refactor: rename "declaration" narrowing to "localizing" --- tl.lua | 4 ++-- tl.tl | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tl.lua b/tl.lua index 5eae2beb1..0d18da510 100644 --- a/tl.lua +++ b/tl.lua @@ -9346,7 +9346,7 @@ a.types[i], b.types[i]), } local v = scope.vars[var] assert(v, "no " .. var .. " in scope") local narrow_mode = scope.vars[var].is_narrowed - if (not narrow_mode) or narrow_mode == "declaration" then + if (not narrow_mode) or narrow_mode == "localizing" then return false end @@ -10845,7 +10845,7 @@ self:expand_type(node, values, elements) }) end assert(var) - self:add_var(var, var.tk, t, var.attribute, is_localizing_a_variable(node, i) and "declaration") + self:add_var(var, var.tk, t, var.attribute, is_localizing_a_variable(node, i) and "localizing") local infertype = infertypes.tuple[i] if ok and infertype then diff --git a/tl.tl b/tl.tl index 092c2e7dd..1447b2ac6 100644 --- a/tl.tl +++ b/tl.tl @@ -719,7 +719,7 @@ local DEFAULT_GEN_TARGET : GenTarget = "5.3" local enum Narrow "narrow" "narrowed_declaration" - "declaration" + "localizing" end local record Variable @@ -9346,7 +9346,7 @@ do local v = scope.vars[var] assert(v, "no " .. var .. " in scope") local narrow_mode = scope.vars[var].is_narrowed - if (not narrow_mode) or narrow_mode == "declaration" then + if (not narrow_mode) or narrow_mode == "localizing" then return false end @@ -10845,7 +10845,7 @@ do end assert(var) - self:add_var(var, var.tk, t, var.attribute, is_localizing_a_variable(node, i) and "declaration") + self:add_var(var, var.tk, t, var.attribute, is_localizing_a_variable(node, i) and "localizing") local infertype = infertypes.tuple[i] if ok and infertype then From ccb508115be531067d8eaaab4f33673699c372d3 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Wed, 7 Aug 2024 17:09:14 -0300 Subject: [PATCH 142/224] fix: localizing a record does not make the new local a type This allows for a pattern where we can localize a record and load and alternative implementation on top of it. Fixes #759. --- spec/call/record_method_spec.lua | 5 +++-- spec/declaration/local_spec.lua | 13 +++++++++++++ spec/stdlib/require_spec.lua | 27 +++++++++++++++++++++++++++ tl.lua | 6 +++++- tl.tl | 6 +++++- 5 files changed, 53 insertions(+), 4 deletions(-) diff --git a/spec/call/record_method_spec.lua b/spec/call/record_method_spec.lua index 0e91056c1..5e000dc01 100644 --- a/spec/call/record_method_spec.lua +++ b/spec/call/record_method_spec.lua @@ -204,8 +204,9 @@ describe("record method call", function() end m.a.add(first) ]], { - -- FIXME this warning needs to go away when we detect that "m.a" and "first" are not the same - { y = 14, msg = "invoked method as a regular function: consider using ':' instead of '.'" } + -- FIXME these warnings need to go away when we detect that the arities are correct + { y = 10, msg = "invoked method as a regular function: consider using ':' instead of '.'" }, + { y = 14, msg = "invoked method as a regular function: consider using ':' instead of '.'" }, }, {})) it("for function declared in record body with self as different type from receiver", util.check_warnings([[ diff --git a/spec/declaration/local_spec.lua b/spec/declaration/local_spec.lua index 932b11167..044d7a47a 100644 --- a/spec/declaration/local_spec.lua +++ b/spec/declaration/local_spec.lua @@ -460,6 +460,19 @@ describe("local", function() local b : A = { v = 10 } ]])) + + it("localizing a record does not make the new local a type (#759)", util.check([[ + local record k + end + + local kk: k = {} + + local k = k + + k = {} + + kk = {} + ]])) end) describe("", function() diff --git a/spec/stdlib/require_spec.lua b/spec/stdlib/require_spec.lua index 72ba21859..5ecfb84a5 100644 --- a/spec/stdlib/require_spec.lua +++ b/spec/stdlib/require_spec.lua @@ -1269,4 +1269,31 @@ describe("require", function() { filename = "main.tl", x = 30, y = 1, msg = "require() in type declarations cannot be part of larger expressions" } }, result.syntax_errors) end) + + it("can localize a record and load and alternative implementation (#759)", function () + util.mock_io(finally, { + ["my-lua-utf8.tl"] = [[ + local record my_lua_utf8 + end + + return my_lua_utf8 + ]], + ["main.tl"] = [[ + local type Utf8 = utf8 + local utf8: Utf8 = utf8 + if not utf8 then + local ok, lutf8 = pcall(require, 'my-lua-utf8') as (boolean, Utf8) + if ok then + utf8 = lutf8 + end + end + print(utf8.charpattern) + ]], + }) + local result, err = tl.process("main.tl") + + assert.same({}, result.syntax_errors) + assert.same({}, result.type_errors) + end) + end) diff --git a/tl.lua b/tl.lua index 0d18da510..c22c1038e 100644 --- a/tl.lua +++ b/tl.lua @@ -6995,7 +6995,7 @@ do local scope = self.st[i] local var = scope.vars[name] if var then - if use == "lvalue" and var.is_narrowed then + if use == "lvalue" and var.is_narrowed and var.is_narrowed ~= "localizing" then if var.narrowed_from then var.used = true return { t = var.narrowed_from, attribute = var.attribute }, i, var.attribute @@ -10611,6 +10611,10 @@ self:expand_type(node, values, elements) }) t.inferred_len = nil elseif t.typename == "nominal" then self:resolve_nominal(t) + local rt = t.resolved + if rt.typename == "typedecl" then + t.resolved = rt.def + end end return ok, t, infertype ~= nil diff --git a/tl.tl b/tl.tl index 1447b2ac6..0ddd0e564 100644 --- a/tl.tl +++ b/tl.tl @@ -6995,7 +6995,7 @@ do local scope = self.st[i] local var = scope.vars[name] if var then - if use == "lvalue" and var.is_narrowed then + if use == "lvalue" and var.is_narrowed and var.is_narrowed ~= "localizing" then if var.narrowed_from then var.used = true return { t = var.narrowed_from, attribute = var.attribute }, i, var.attribute @@ -10611,6 +10611,10 @@ do t.inferred_len = nil elseif t is NominalType then self:resolve_nominal(t) + local rt = t.resolved + if rt is TypeDeclType then + t.resolved = rt.def + end end return ok, t, infertype ~= nil From 7ae8974cd3f2b68bf841473bc11d01191190fa46 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Wed, 7 Aug 2024 17:37:11 -0300 Subject: [PATCH 143/224] tests: add regression test for #749 --- spec/declaration/local_spec.lua | 37 +++++++++++++++++++++++++++------ 1 file changed, 31 insertions(+), 6 deletions(-) diff --git a/spec/declaration/local_spec.lua b/spec/declaration/local_spec.lua index 044d7a47a..b56b7556e 100644 --- a/spec/declaration/local_spec.lua +++ b/spec/declaration/local_spec.lua @@ -461,17 +461,29 @@ describe("local", function() local b : A = { v = 10 } ]])) - it("localizing a record does not make the new local a type (#759)", util.check([[ - local record k + it("does not consider a metamethod to be a missing field (regression test for #749", util.check([[ + local interface Op + op: string end - local kk: k = {} + local interface Binary + is Op + left: number + right: number + end - local k = k + local record Add + is Binary + where self.op == '+' -- removing the comment triggers an error + end - k = {} + local sum : Add = { + op = '+', + left = 10, + right = 20 + } - kk = {} + print(sum.op, sum.left, sum.right) ]])) end) @@ -515,4 +527,17 @@ describe("local", function() ]])) end) end) + + it("localizing a record does not make the new local a type (#759)", util.check([[ + local record k + end + + local kk: k = {} + + local k = k + + k = {} + + kk = {} + ]])) end) From 877e78586c04d1d37a48ee20efa357a1b627bf50 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Sat, 10 Aug 2024 23:35:16 -0300 Subject: [PATCH 144/224] only check against shadowing base types when declaring types --- spec/declaration/local_spec.lua | 11 +++++++++++ tl.lua | 6 +++--- tl.tl | 6 +++--- 3 files changed, 17 insertions(+), 6 deletions(-) diff --git a/spec/declaration/local_spec.lua b/spec/declaration/local_spec.lua index b56b7556e..a141056f3 100644 --- a/spec/declaration/local_spec.lua +++ b/spec/declaration/local_spec.lua @@ -540,4 +540,15 @@ describe("local", function() kk = {} ]])) + + it("using a base type name in a regular variable produces no warnings", util.check_warnings([[ + local any = true + print(any) + + local record integer + end + ]], { + { tag = "redeclaration", msg = "variable shadows previous declaration of 'integer'" }, + { tag = "unused", msg = "unused type integer" }, + })) end) diff --git a/tl.lua b/tl.lua index c22c1038e..862feedb8 100644 --- a/tl.lua +++ b/tl.lua @@ -7475,9 +7475,9 @@ do end - function TypeChecker:check_if_redeclaration(new_name, node) + function TypeChecker:check_if_redeclaration(new_name, node, t) local old = self:find_var(new_name, "check_only") - if old or simple_types[new_name] then + if old or (t.typename == "typedecl" and simple_types[new_name]) then local var_name = node.tk local var_kind = "variable" if node.kind == "local_function" or node.kind == "record_function" then @@ -7562,7 +7562,7 @@ do name ~= "..." and name:sub(1, 1) ~= "@" then - self:check_if_redeclaration(name, node) + self:check_if_redeclaration(name, node, t) end if var and not var.used then diff --git a/tl.tl b/tl.tl index 0ddd0e564..3efe2d069 100644 --- a/tl.tl +++ b/tl.tl @@ -7475,9 +7475,9 @@ do end - function TypeChecker:check_if_redeclaration(new_name: string, node: Node) + function TypeChecker:check_if_redeclaration(new_name: string, node: Node, t: Type) local old = self:find_var(new_name, "check_only") - if old or simple_types[new_name as TypeName] then + if old or (t is TypeDeclType and simple_types[new_name as TypeName]) then local var_name = node.tk local var_kind = "variable" if node.kind == "local_function" or node.kind == "record_function" then @@ -7562,7 +7562,7 @@ do and name ~= "..." and name:sub(1, 1) ~= "@" then - self:check_if_redeclaration(name, node) + self:check_if_redeclaration(name, node, t) end if var and not var.used then From 71cc2443115803d5783cf89fcda49c74fec4ca88 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Wed, 7 Aug 2024 17:53:28 -0300 Subject: [PATCH 145/224] interfaces: enforce them to be abstract Fixes #757. --- spec/declaration/record_spec.lua | 71 ++++++++++++++++++++++++++++++-- tl.lua | 10 ++++- tl.tl | 10 ++++- 3 files changed, 85 insertions(+), 6 deletions(-) diff --git a/spec/declaration/record_spec.lua b/spec/declaration/record_spec.lua index be46f1c62..b9ea201ad 100644 --- a/spec/declaration/record_spec.lua +++ b/spec/declaration/record_spec.lua @@ -434,16 +434,79 @@ for i, name in ipairs({"records", "arrayrecords", "interfaces", "arrayinterfaces })[i] })) - it("can extend generic functions", util.check([[ - local type foo = ]]..statement..[[ ]]..array(i, "{foo}")..[[ + it("only records can have record methods", util.check_type_error([[ + local ]]..statement..[[ Foo ]]..array(i, "{Foo}")..[[ + end + + function Foo:example(data: string) + print(data) + end + ]], { + ({ + nil, + nil, + { msg = "interfaces are abstract" }, + { msg = "interfaces are abstract" }, + })[i] + })) + + it("only records can have record functions", util.check_type_error([[ + local ]]..statement..[[ Foo ]]..array(i, "{Foo}")..[[ + end + + function Foo.example(data: string) + print(data) + end + ]], { + ({ + nil, + nil, + { msg = "interfaces are abstract" }, + { msg = "interfaces are abstract" }, + })[i] + })) + + it("functions can be implemented in instances", util.check([[ + local ]]..statement..[[ Foo ]]..array(i, "{Foo}")..[[ + example: function(string) + end + + local my_f: Foo = {} + + function my_f.example(data: string) + print(data) + end + ]])) + + it("methods can be implemented in instances", util.check([[ + local ]]..statement..[[ Foo ]]..array(i, "{Foo}")..[[ + example: function(Foo, string) + end + + local my_f: Foo = {} + + function my_f:example(data: string) + print(data) + end + ]])) + + it("with implement generic functions", util.check_type_error([[ + local type Foo = ]]..statement..[[ ]]..array(i, "{Foo}")..[[ type bar = function(T) example: bar end - function foo.example(data: string) + function Foo.example(data: string) print(data) end - ]])) + ]], { + ({ + nil, + nil, + { msg = "interfaces are abstract" }, + { msg = "interfaces are abstract" }, + })[i] + })) it("can use where with generic types", util.check([[ local type Success = ]]..statement..[[ ]]..array(i, "is {integer}")..[[ diff --git a/tl.lua b/tl.lua index 862feedb8..968f57607 100644 --- a/tl.lua +++ b/tl.lua @@ -11534,7 +11534,15 @@ self:expand_type(node, values, elements) }) local rets = children[4] assert(rets.typename == "tuple") - local rtype = self:to_structural(resolve_typedecl(children[1])) + local t = children[1] + local rtype = self:to_structural(resolve_typedecl(t)) + + do + local ok, err = ensure_not_abstract(t) + if not ok then + self.errs:add(node, err) + end + end if self.feat_lax and rtype.typename == "unknown" then return diff --git a/tl.tl b/tl.tl index 3efe2d069..dcb7ceda1 100644 --- a/tl.tl +++ b/tl.tl @@ -11534,7 +11534,15 @@ do local rets = children[4] assert(rets is TupleType) - local rtype = self:to_structural(resolve_typedecl(children[1])) + local t = children[1] + local rtype = self:to_structural(resolve_typedecl(t)) + + do + local ok, err = ensure_not_abstract(t) + if not ok then + self.errs:add(node, err) + end + end if self.feat_lax and rtype is UnknownType then return From 93c83ba306659d4a3238e5371b30c7b7cfa8fa9c Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Tue, 13 Aug 2024 10:53:48 -0300 Subject: [PATCH 146/224] fix: do not close nested types too early Closes #775. --- spec/declaration/record_function_spec.lua | 20 ++++++++++++ tl.lua | 37 +++++++++++++++++------ tl.tl | 37 +++++++++++++++++------ 3 files changed, 74 insertions(+), 20 deletions(-) diff --git a/spec/declaration/record_function_spec.lua b/spec/declaration/record_function_spec.lua index 290f5e617..75bd616f5 100644 --- a/spec/declaration/record_function_spec.lua +++ b/spec/declaration/record_function_spec.lua @@ -107,5 +107,25 @@ describe("record function", function() ]], {}, { { y = 7, msg = "different number of return values: got 1, expected 0" }, })) + + it("does not close nested types too early (regression test for #775)", util.check([[ + -- declare a nested record + local record mul + record Fil + mime: function(Fil) + end + end + + -- declare an alias + local type Fil = mul.Fil + + -- this works + function mul.Fil:new_method1(self: Fil) + end + + -- should work as well for alias + function Fil:new_method2(self: Fil) + end + ]])) end) end) diff --git a/tl.lua b/tl.lua index 968f57607..48c6fb4e0 100644 --- a/tl.lua +++ b/tl.lua @@ -12293,22 +12293,38 @@ self:expand_type(node, values, elements) }) end, } + function TypeChecker:begin_temporary_record_types(typ) + self:add_var(nil, "@self", type_at(typ, a_type(typ, "typedecl", { def = typ }))) + + for fname, ftype in fields_of(typ) do + if ftype.typename == "typealias" then + self:resolve_nominal(ftype.alias_to) + self:add_var(nil, fname, ftype) + elseif ftype.typename == "typedecl" then + self:add_var(nil, fname, ftype) + end + end + end + + function TypeChecker:end_temporary_record_types(typ) + + + local scope = self.st[#self.st] + scope.vars["@self"] = nil + for fname, ftype in fields_of(typ) do + if ftype.typename == "typealias" or ftype.typename == "typedecl" then + scope.vars[fname] = nil + end + end + end + local visit_type visit_type = { cbs = { ["record"] = { before = function(self, typ) self:begin_scope() - self:add_var(nil, "@self", type_at(typ, a_type(typ, "typedecl", { def = typ }))) - - for fname, ftype in fields_of(typ) do - if ftype.typename == "typealias" then - self:resolve_nominal(ftype.alias_to) - self:add_var(nil, fname, ftype) - elseif ftype.typename == "typedecl" then - self:add_var(nil, fname, ftype) - end - end + self:begin_temporary_record_types(typ) end, after = function(self, typ, children) local i = 1 @@ -12402,6 +12418,7 @@ self:expand_type(node, values, elements) }) end end + self:end_temporary_record_types(typ) self:end_scope() return typ diff --git a/tl.tl b/tl.tl index dcb7ceda1..c614555cb 100644 --- a/tl.tl +++ b/tl.tl @@ -12293,22 +12293,38 @@ do end, } + function TypeChecker:begin_temporary_record_types(typ: RecordType) + self:add_var(nil, "@self", type_at(typ, a_typedecl(typ, typ))) + + for fname, ftype in fields_of(typ) do + if ftype is TypeAliasType then + self:resolve_nominal(ftype.alias_to) + self:add_var(nil, fname, ftype) + elseif ftype is TypeDeclType then + self:add_var(nil, fname, ftype) + end + end + end + + function TypeChecker:end_temporary_record_types(typ: RecordType) + -- drop @self and nested records from scope + -- to avoid closing them prematurely in end_scope() + local scope = self.st[#self.st] + scope.vars["@self"] = nil + for fname, ftype in fields_of(typ) do + if ftype is TypeAliasType or ftype is TypeDeclType then + scope.vars[fname] = nil + end + end + end + local visit_type: Visitor visit_type = { cbs = { ["record"] = { before = function(self: TypeChecker, typ: RecordType) self:begin_scope() - self:add_var(nil, "@self", type_at(typ, a_typedecl(typ, typ))) - - for fname, ftype in fields_of(typ) do - if ftype is TypeAliasType then - self:resolve_nominal(ftype.alias_to) - self:add_var(nil, fname, ftype) - elseif ftype is TypeDeclType then - self:add_var(nil, fname, ftype) - end - end + self:begin_temporary_record_types(typ) end, after = function(self: TypeChecker, typ: RecordType, children: {Type}): Type local i = 1 @@ -12402,6 +12418,7 @@ do end end + self:end_temporary_record_types(typ) self:end_scope() return typ From 1d2dffda5479568289dfc18b7d19c3aa0ed4eaa5 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Thu, 15 Aug 2024 10:20:26 -0300 Subject: [PATCH 147/224] fix: avoid crash in `is` check This makes me think if this should happen implicitly... --- tl.lua | 2 +- tl.tl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tl.lua b/tl.lua index 48c6fb4e0..ed0be29ca 100644 --- a/tl.lua +++ b/tl.lua @@ -10612,7 +10612,7 @@ self:expand_type(node, values, elements) }) elseif t.typename == "nominal" then self:resolve_nominal(t) local rt = t.resolved - if rt.typename == "typedecl" then + if rt and rt.typename == "typedecl" then t.resolved = rt.def end end diff --git a/tl.tl b/tl.tl index c614555cb..eaa807d10 100644 --- a/tl.tl +++ b/tl.tl @@ -10612,7 +10612,7 @@ do elseif t is NominalType then self:resolve_nominal(t) local rt = t.resolved - if rt is TypeDeclType then + if rt and rt is TypeDeclType then t.resolved = rt.def end end From 06d67782413e8fa57f32ab16b6ba4e19b41a43e4 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Fri, 16 Aug 2024 10:41:38 -0300 Subject: [PATCH 148/224] fix: string equality check This triggered during debugging because show_type for two string with the same typeid didn't match. --- tl.lua | 26 +++++++++++--------------- tl.tl | 26 +++++++++++--------------- 2 files changed, 22 insertions(+), 30 deletions(-) diff --git a/tl.lua b/tl.lua index ed0be29ca..48ecff69f 100644 --- a/tl.lua +++ b/tl.lua @@ -7214,6 +7214,15 @@ do tostring(nfargs or 0) end + local function drop_constant_value(t) + if t.typename == "string" and t.literal then + local ret = shallow_copy_new_type(t) + ret.literal = nil + return ret + end + return t + end + local function resolve_typedecl(t) if t.typename == "typedecl" then return t.def @@ -7443,14 +7452,10 @@ do local rt = tc:find_var_type(t.typevar) if not rt then return nil - elseif rt.typename == "string" then - - return a_type(rt, "string", {}) end - return rt - end - + return drop_constant_value(rt) + end function TypeChecker:infer_emptytable(emptytable, fresh_t) local is_global = (emptytable.declared_at and emptytable.declared_at.kind == "global_declaration") @@ -7523,15 +7528,6 @@ do return ret end - local function drop_constant_value(t) - if t.typename == "string" and t.literal then - local ret = shallow_copy_table(t) - ret.literal = nil - return ret - end - return t - end - function TypeChecker:add_to_scope(node, name, t, attribute, narrow, dont_check_redeclaration) local scope = self.st[#self.st] local var = scope.vars[name] diff --git a/tl.tl b/tl.tl index eaa807d10..1e1babe34 100644 --- a/tl.tl +++ b/tl.tl @@ -7214,6 +7214,15 @@ do or tostring(nfargs or 0) end + local function drop_constant_value(t: Type): Type + if t is StringType and t.literal then + local ret = shallow_copy_new_type(t) + ret.literal = nil + return ret + end + return t + end + local function resolve_typedecl(t: Type): Type if t is TypeDeclType then return t.def @@ -7443,14 +7452,10 @@ do local rt = tc:find_var_type(t.typevar) if not rt then return nil - elseif rt is StringType then - -- tk is not propagated - return a_type(rt, "string", {}) end - return rt - end - + return drop_constant_value(rt) + end function TypeChecker:infer_emptytable(emptytable: EmptyTableType, fresh_t: Type) local is_global = (emptytable.declared_at and emptytable.declared_at.kind == "global_declaration") @@ -7523,15 +7528,6 @@ do return ret end - local function drop_constant_value(t: Type): Type - if t is StringType and t.literal then - local ret = shallow_copy_table(t) - ret.literal = nil - return ret - end - return t - end - function TypeChecker:add_to_scope(node: Node, name: string, t: Type, attribute: Attribute, narrow: Narrow, dont_check_redeclaration: boolean): Variable local scope = self.st[#self.st] local var = scope.vars[name] From 8b0fe09795935b34adb163e379873fe343234808 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Thu, 8 Aug 2024 11:28:20 -0300 Subject: [PATCH 149/224] use unknown only in lax mode --- spec/metamethods/call_spec.lua | 2 +- spec/metamethods/index_spec.lua | 2 +- spec/statement/forin_spec.lua | 2 +- tl.lua | 4 +++- tl.tl | 4 +++- 5 files changed, 9 insertions(+), 5 deletions(-) diff --git a/spec/metamethods/call_spec.lua b/spec/metamethods/call_spec.lua index f6c58a4d2..00a276593 100644 --- a/spec/metamethods/call_spec.lua +++ b/spec/metamethods/call_spec.lua @@ -109,7 +109,7 @@ describe("metamethod __call", function() } local Rec_class_mt = { - __call = function(_: Rec, x, y): Rec.Type + __call = function(_: Rec, x: number, y: number): Rec.Type return setmetatable({ x = x, y = y } as Rec.Type, Rec_instance_mt) end } diff --git a/spec/metamethods/index_spec.lua b/spec/metamethods/index_spec.lua index 6e58ca63e..811839272 100644 --- a/spec/metamethods/index_spec.lua +++ b/spec/metamethods/index_spec.lua @@ -113,7 +113,7 @@ describe("metamethod __index", function() } local Rec_class_mt = { - __call = function(_: Rec, x, y): Rec.Type + __call = function(_: Rec, x: number, y: number): Rec.Type return setmetatable({ x = x, y = y } as Rec.Type, Rec_instance_mt) end } diff --git a/spec/statement/forin_spec.lua b/spec/statement/forin_spec.lua index c473b1c55..3a5916404 100644 --- a/spec/statement/forin_spec.lua +++ b/spec/statement/forin_spec.lua @@ -76,7 +76,7 @@ describe("forin", function() end) it("with an explicit iterator", util.check([[ - local function iter(t): number + local function iter(t: T): number end local t = { 1, 2, 3 } for i in iter, t do diff --git a/tl.lua b/tl.lua index 48ecff69f..c9cb33a5d 100644 --- a/tl.lua +++ b/tl.lua @@ -12134,7 +12134,9 @@ self:expand_type(node, values, elements) }) after = function(self, node, children) local t = children[1] if not t then - t = a_type(node, "unknown", {}) + t = self.feat_lax and + a_type(node, "unknown", {}) or + a_type(node, "any", {}) end if node.tk == "..." then t = a_vararg(node, { t }) diff --git a/tl.tl b/tl.tl index 1e1babe34..7db35880f 100644 --- a/tl.tl +++ b/tl.tl @@ -12134,7 +12134,9 @@ do after = function(self: TypeChecker, node: Node, children: {Type}): Type local t = children[1] if not t then - t = an_unknown(node) + t = self.feat_lax + and an_unknown(node) + or a_type(node, "any", {}) end if node.tk == "..." then t = a_vararg(node, { t }) From 7cbedd2a0ef25bf3228f6755128d5775097dfec0 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Thu, 8 Aug 2024 11:09:21 -0300 Subject: [PATCH 150/224] minor tweak: no need to copy non-method functions --- tl.lua | 2 +- tl.tl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tl.lua b/tl.lua index c9cb33a5d..1a546b702 100644 --- a/tl.lua +++ b/tl.lua @@ -10164,8 +10164,8 @@ a.types[i], b.types[i]), } local ftype = table.remove(b.tuple, 1) - ftype = shallow_copy_new_type(ftype) if ftype.typename == "function" then + ftype = shallow_copy_new_type(ftype) ftype.is_method = false end diff --git a/tl.tl b/tl.tl index 7db35880f..2e3a7182a 100644 --- a/tl.tl +++ b/tl.tl @@ -10164,8 +10164,8 @@ do -- so we wish to avoid incorrect error messages / unnecessary warning messages -- associated with calling methods as functions local ftype = table.remove(b.tuple, 1) - ftype = shallow_copy_new_type(ftype) if ftype is FunctionType then + ftype = shallow_copy_new_type(ftype) ftype.is_method = false end From 0478bd42f4e8fc7a9b1b5e2c4f0495251fd8faa2 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Thu, 8 Aug 2024 09:07:54 -0300 Subject: [PATCH 151/224] Type: add __tostring metamethod --- tl.lua | 27 ++++++++++++++++++++++----- tl.tl | 27 ++++++++++++++++++++++----- 2 files changed, 44 insertions(+), 10 deletions(-) diff --git a/tl.lua b/tl.lua index 1a546b702..8eb68c993 100644 --- a/tl.lua +++ b/tl.lua @@ -2080,6 +2080,14 @@ local Node = {ExpectedContext = {}, } + +local show_type + +local type_mt = { + __tostring = function(t) + return show_type(t) + end, +} local function a_type(w, typename, t) t.typeid = new_typeid() @@ -2087,6 +2095,10 @@ local function a_type(w, typename, t) t.x = w.x t.y = w.y t.typename = typename + do + local ty = t + setmetatable(ty, type_mt) + end return t end @@ -2096,6 +2108,7 @@ local function edit_type(w, t, typename) t.x = w.x t.y = w.y t.typename = typename + setmetatable(t, type_mt) return t end @@ -2150,6 +2163,10 @@ local function shallow_copy_new_type(t) copy[k] = v end copy.typeid = new_typeid() + do + local ty = copy + setmetatable(ty, type_mt) + end return copy end @@ -2286,6 +2303,7 @@ do t.x = token.x t.y = token.y t.typename = typename + setmetatable(t, type_mt) return t end @@ -4329,8 +4347,6 @@ local function fields_of(t, meta) end end -local show_type - local tl_debug_indent = 0 @@ -6662,7 +6678,7 @@ local function show_type_base(t, short, seen) elseif t.typename == "typedecl" then return "type " .. show(t.def) else - return "<" .. t.typename .. " " .. tostring(t) .. ">" + return "<" .. t.typename .. ">" end end @@ -7288,6 +7304,7 @@ do local copy = {} seen[orig_t] = copy + setmetatable(copy, type_mt) copy.typename = t.typename copy.f = t.f copy.x = t.x @@ -7509,7 +7526,7 @@ do end if ret == t or t.typename == "typevar" then - ret = shallow_copy_table(ret) + ret = shallow_copy_new_type(ret) end return type_at(w, ret) end @@ -7521,7 +7538,7 @@ do end if ret == t or t.typename == "typevar" then - ret = shallow_copy_table(ret) + ret = shallow_copy_new_type(ret) end assert(w.f) ret.inferred_at = w diff --git a/tl.tl b/tl.tl index 2e3a7182a..5ff83a2ba 100644 --- a/tl.tl +++ b/tl.tl @@ -2081,12 +2081,24 @@ local record Node debug_type: Type end +local show_type: function(Type, ? boolean, ? {Type:string}): string + +local type_mt: metatable = { + __tostring = function(t: Type): string + return show_type(t) + end +} + local function a_type(w: Where, typename: TypeName, t: T): T t.typeid = new_typeid() t.f = w.f t.x = w.x t.y = w.y t.typename = typename + do + local ty: Type = t + setmetatable(ty, type_mt) + end return t end @@ -2096,6 +2108,7 @@ local function edit_type(w: Where, t: Type, typename: TypeName): Type t.x = w.x t.y = w.y t.typename = typename + setmetatable(t, type_mt) return t end @@ -2150,6 +2163,10 @@ local function shallow_copy_new_type(t: T): T copy[k] = v end copy.typeid = new_typeid() + do + local ty: Type = copy as T + setmetatable(ty, type_mt) + end return copy as T end @@ -2286,6 +2303,7 @@ local function new_type(ps: ParseState, i: integer, typename: TypeName): Type t.x = token.x t.y = token.y t.typename = typename + setmetatable(t, type_mt) return t end @@ -4329,8 +4347,6 @@ local function fields_of(t: RecordLikeType, meta?: MetaMode): (function(): strin end end -local show_type: function(Type, ? boolean, ? {Type:string}): string - local tl_debug_indent = 0 local record DebugEntry mark: string @@ -6662,7 +6678,7 @@ local function show_type_base(t: Type, short: boolean, seen: {Type:string}): str elseif t is TypeDeclType then return "type " .. show(t.def) else - return "<" .. t.typename .. " " .. tostring(t) .. ">" + return "<" .. t.typename .. ">" -- TODO add string.format("%p", t) with compat-5.4 end end @@ -7288,6 +7304,7 @@ do local copy: Type = {} seen[orig_t] = copy + setmetatable(copy, type_mt) copy.typename = t.typename copy.f = t.f copy.x = t.x @@ -7509,7 +7526,7 @@ do end if ret == t or t is TypeVarType then - ret = shallow_copy_table(ret) + ret = shallow_copy_new_type(ret) end return type_at(w, ret) end @@ -7521,7 +7538,7 @@ do end if ret == t or t is TypeVarType then - ret = shallow_copy_table(ret) + ret = shallow_copy_new_type(ret) end assert(w.f) ret.inferred_at = w From f8d56653cc16c7fd7c78a3bc5a323a6c21abf3e8 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Thu, 8 Aug 2024 10:49:30 -0300 Subject: [PATCH 152/224] no need to type_at(a_typedecl()) --- tl.lua | 2 +- tl.tl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tl.lua b/tl.lua index 8eb68c993..0df035fc4 100644 --- a/tl.lua +++ b/tl.lua @@ -12309,7 +12309,7 @@ self:expand_type(node, values, elements) }) } function TypeChecker:begin_temporary_record_types(typ) - self:add_var(nil, "@self", type_at(typ, a_type(typ, "typedecl", { def = typ }))) + self:add_var(nil, "@self", a_type(typ, "typedecl", { def = typ })) for fname, ftype in fields_of(typ) do if ftype.typename == "typealias" then diff --git a/tl.tl b/tl.tl index 5ff83a2ba..97b58b768 100644 --- a/tl.tl +++ b/tl.tl @@ -12309,7 +12309,7 @@ do } function TypeChecker:begin_temporary_record_types(typ: RecordType) - self:add_var(nil, "@self", type_at(typ, a_typedecl(typ, typ))) + self:add_var(nil, "@self", a_typedecl(typ, typ)) for fname, ftype in fields_of(typ) do if ftype is TypeAliasType then From bbe6c07cb34e94f8dd9b3d6debb9817b9cc9a8fa Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Sun, 18 Aug 2024 15:10:43 -0300 Subject: [PATCH 153/224] begin cleaning up generics in type declarations --- spec/declaration/local_spec.lua | 7 +++++++ tl.lua | 16 ++++++++++++++-- tl.tl | 16 ++++++++++++++-- 3 files changed, 35 insertions(+), 4 deletions(-) diff --git a/spec/declaration/local_spec.lua b/spec/declaration/local_spec.lua index a141056f3..cd7eb9ee7 100644 --- a/spec/declaration/local_spec.lua +++ b/spec/declaration/local_spec.lua @@ -551,4 +551,11 @@ describe("local", function() { tag = "redeclaration", msg = "variable shadows previous declaration of 'integer'" }, { tag = "unused", msg = "unused type integer" }, })) + + it("does not accept type arguments declared twice", util.check_syntax_error([[ + local type Foo = record + end + ]], { + { y = 1, msg = "cannot declare type arguments twice in type declaration" }, + })) end) diff --git a/tl.lua b/tl.lua index 0df035fc4..90a4ddcda 100644 --- a/tl.lua +++ b/tl.lua @@ -4056,6 +4056,7 @@ do return fail(ps, i, "expected a type name") end local typeargs + local itypeargs = i if ps.tokens[i].tk == "<" then i, typeargs = parse_anglebracket_list(ps, i, parse_typearg) end @@ -4097,11 +4098,22 @@ do local nt = asgn.value.newtype if nt.typename == "typedecl" then + local def = nt.def + if typeargs then - nt.typeargs = typeargs + if def.typeargs then + if def.typeargs then + fail(ps, itypeargs, "cannot declare type arguments twice in type declaration") + else + def.typeargs = typeargs + end + else + + + nt.typeargs = typeargs + end end - local def = nt.def if def.fields or def.typename == "enum" then if not def.declname then def.declname = asgn.var.tk diff --git a/tl.tl b/tl.tl index 97b58b768..4ca681f2f 100644 --- a/tl.tl +++ b/tl.tl @@ -4056,6 +4056,7 @@ parse_type_declaration = function(ps: ParseState, i: integer, node_name: NodeKin return fail(ps, i, "expected a type name") end local typeargs: {TypeArgType} + local itypeargs = i if ps.tokens[i].tk == "<" then i, typeargs = parse_anglebracket_list(ps, i, parse_typearg) end @@ -4097,11 +4098,22 @@ parse_type_declaration = function(ps: ParseState, i: integer, node_name: NodeKin local nt = asgn.value.newtype if nt is TypeDeclType then + local def = nt.def + if typeargs then - nt.typeargs = typeargs + if def is HasTypeArgs then + if def.typeargs then + fail(ps, itypeargs, "cannot declare type arguments twice in type declaration") + else + def.typeargs = typeargs + end + else + -- FIXME how to resolve type arguments in unions properly + -- fail(ps, itypeargs, def.typename .. " does not accept type arguments") + nt.typeargs = typeargs + end end - local def = nt.def if def is RecordLikeType or def is EnumType then if not def.declname then def.declname = asgn.var.tk From a12249fd136eb5c7fd5097d85fd9b6047668f160 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Sun, 18 Aug 2024 21:35:51 -0300 Subject: [PATCH 154/224] internal debugging improvements: identify named types --- tl.lua | 23 +++++++++++++++-------- tl.tl | 23 +++++++++++++++-------- 2 files changed, 30 insertions(+), 16 deletions(-) diff --git a/tl.lua b/tl.lua index 90a4ddcda..a05b7a735 100644 --- a/tl.lua +++ b/tl.lua @@ -6517,8 +6517,10 @@ local function is_unknown(t) t.typename == "unresolved_emptytable_value" end -local function display_typevar(typevar) - return TL_DEBUG and typevar or (typevar:gsub("@.*", "")) +local function display_typevar(typevar, what) + return TL_DEBUG and + (what .. " " .. typevar) or + typevar:gsub("@.*", "") end local function show_fields(t, show) @@ -6573,6 +6575,7 @@ local function show_type_base(t, short, seen) return "self" end + local ret if t.typevals then local out = { table.concat(t.names, "."), "<" } local vals = {} @@ -6581,10 +6584,14 @@ local function show_type_base(t, short, seen) end table.insert(out, table.concat(vals, ", ")) table.insert(out, ">") - return table.concat(out) + ret = table.concat(out) else - return table.concat(t.names, ".") + ret = table.concat(t.names, ".") + end + if TL_DEBUG then + ret = "nominal " .. ret end + return ret elseif t.typename == "tuple" then local out = {} for _, v in ipairs(t.tuple) do @@ -6670,11 +6677,11 @@ local function show_type_base(t, short, seen) (t.literal and string.format(" %q", t.literal) or "") end elseif t.typename == "typevar" then - return display_typevar(t.typevar) + return display_typevar(t.typevar, "typevar") elseif t.typename == "typearg" then - return display_typevar(t.typearg) + return display_typevar(t.typearg, "typearg") elseif t.typename == "unresolvable_typearg" then - return display_typevar(t.typearg) .. " (unresolved generic)" + return display_typevar(t.typearg, "typearg") .. " (unresolved generic)" elseif is_unknown(t) then return "" elseif t.typename == "invalid" then @@ -8198,7 +8205,7 @@ do if constraint then if not self:is_a(other, constraint) then - return false, { Err("given type %s does not satisfy %s constraint in type variable " .. display_typevar(typevar), other, constraint) } + return false, { Err("given type %s does not satisfy %s constraint in type variable " .. display_typevar(typevar, "typevar"), other, constraint) } end if self:same_type(other, constraint) then diff --git a/tl.tl b/tl.tl index 4ca681f2f..dc036799e 100644 --- a/tl.tl +++ b/tl.tl @@ -6517,8 +6517,10 @@ local function is_unknown(t: Type): boolean or t.typename == "unresolved_emptytable_value" end -local function display_typevar(typevar: string): string - return TL_DEBUG and typevar or (typevar:gsub("@.*", "")) +local function display_typevar(typevar: string, what: TypeName): string + return TL_DEBUG + and (what .. " " .. typevar) + or typevar:gsub("@.*", "") end local function show_fields(t: RecordLikeType, show: function(Type):(string)): string @@ -6573,6 +6575,7 @@ local function show_type_base(t: Type, short: boolean, seen: {Type:string}): str return "self" end + local ret: string if t.typevals then local out = { table.concat(t.names, "."), "<" } local vals: {string} = {} @@ -6581,10 +6584,14 @@ local function show_type_base(t: Type, short: boolean, seen: {Type:string}): str end table.insert(out, table.concat(vals, ", ")) table.insert(out, ">") - return table.concat(out) + ret = table.concat(out) else - return table.concat(t.names, ".") + ret = table.concat(t.names, ".") end + if TL_DEBUG then + ret = "nominal " .. ret + end + return ret elseif t is TupleType then local out: {string} = {} for _, v in ipairs(t.tuple) do @@ -6670,11 +6677,11 @@ local function show_type_base(t: Type, short: boolean, seen: {Type:string}): str (t.literal and string.format(" %q", t.literal) or "") end elseif t is TypeVarType then - return display_typevar(t.typevar) + return display_typevar(t.typevar, "typevar") elseif t is TypeArgType then - return display_typevar(t.typearg) + return display_typevar(t.typearg, "typearg") elseif t is UnresolvableTypeArgType then - return display_typevar(t.typearg) .. " (unresolved generic)" + return display_typevar(t.typearg, "typearg") .. " (unresolved generic)" elseif is_unknown(t) then return "" elseif t.typename == "invalid" then @@ -8198,7 +8205,7 @@ do -- but check interface constraint first if present if constraint then if not self:is_a(other, constraint) then - return false, { Err("given type %s does not satisfy %s constraint in type variable " .. display_typevar(typevar), other, constraint) } + return false, { Err("given type %s does not satisfy %s constraint in type variable " .. display_typevar(typevar, "typevar"), other, constraint) } end if self:same_type(other, constraint) then From f84fede74968ea9110d3249670232e384321ade9 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Sun, 18 Aug 2024 21:36:41 -0300 Subject: [PATCH 155/224] fix propagation of type arguments Fixes #777. --- spec/declaration/local_spec.lua | 22 ++++++++++++ tl.lua | 53 ++++++++++++++++++----------- tl.tl | 59 ++++++++++++++++++++------------- 3 files changed, 91 insertions(+), 43 deletions(-) diff --git a/spec/declaration/local_spec.lua b/spec/declaration/local_spec.lua index cd7eb9ee7..ca6324fc5 100644 --- a/spec/declaration/local_spec.lua +++ b/spec/declaration/local_spec.lua @@ -558,4 +558,26 @@ describe("local", function() ]], { { y = 1, msg = "cannot declare type arguments twice in type declaration" }, })) + + it("propagates type arguments correctly", util.check_type_error([[ + local record module + record Foo + first: A + second: B + end + end + + -- note inverted arguments + local type MyFoo = module.Foo + + local record Boo + field: MyFoo + end + + local b: Boo = { field = { first = "first", second = 2 } } -- bad, not inverted! + local c: Boo = { field = { first = 1, second = "second" } } -- good, inverted! + ]], { + { y = 15, x = 42, msg = 'in record field: first: got string "first", expected integer' }, + { y = 15, x = 60, msg = 'in record field: second: got integer, expected string' }, + })) end) diff --git a/tl.lua b/tl.lua index a05b7a735..e2f89e89f 100644 --- a/tl.lua +++ b/tl.lua @@ -7280,6 +7280,24 @@ do ["unknown"] = true, } + local function clear_resolved_typeargs(copy, resolved) + if not copy.typeargs then + return + end + + for i = #copy.typeargs, 1, -1 do + local r = resolved[copy.typeargs[i].typearg] + if r then + table.remove(copy.typeargs, i) + end + end + if not copy.typeargs[1] then + copy.typeargs = nil + end + + return + end + typevar_resolver = function(self, typ, fn_var, fn_arg) local errs local seen = {} @@ -7457,31 +7475,14 @@ do return copy, same and all_same end - local copy, same = resolve(typ, true) + local copy = resolve(typ, true) if errs then return false, a_type(typ, "invalid", {}), errs end - if (not same) and - (copy.typename == "function" or copy.fields) and - copy.typeargs then - - for i = #copy.typeargs, 1, -1 do - local r = resolved[copy.typeargs[i].typearg] - if r then + clear_resolved_typeargs(copy, resolved) - if r.typename == "nominal" and #r.names == 1 then - copy.typeargs[i].typearg = r.names[1] - else - table.remove(copy.typeargs, i) - end - end - end - if not copy.typeargs[1] then - copy.typeargs = nil - end - end - return true, copy + return true, copy, nil, resolved end local function resolve_typevar(tc, t) @@ -10804,6 +10805,18 @@ self:expand_type(node, values, elements) }) before = function(self, node) local name = node.var.tk local resolved, aliasing = self:get_typedecl(node.value) + local nt = node.value.newtype + if nt and nt.typename == "typealias" and resolved.typename == "typedecl" then + if nt.typeargs then + local def = resolved.def + + + + if def.typename == "record" or def.typename == "function" or def.typename == "interface" then + def.typeargs = nt.typeargs + end + end + end local var = self:add_var(node.var, name, resolved, node.var.attribute) if aliasing then var.aliasing = aliasing diff --git a/tl.tl b/tl.tl index dc036799e..d0b2050b1 100644 --- a/tl.tl +++ b/tl.tl @@ -7067,7 +7067,7 @@ do end local type ResolveType = function(S, Type): Type - local typevar_resolver: function(s: S, typ: Type, fn_var?: ResolveType, fn_arg?: ResolveType): boolean, Type, {Error} + local typevar_resolver: function(s: S, typ: Type, fn_var: ResolveType, fn_arg?: ResolveType): boolean, Type, {Error} local function fresh_typevar(_: nil, t: TypeVarType): Type, Type, boolean return a_type(t, "typevar", { @@ -7280,7 +7280,25 @@ do ["unknown"] = true, } - typevar_resolver = function(self: S, typ: Type, fn_var?: ResolveType, fn_arg?: ResolveType): boolean, Type, {Error} + local function clear_resolved_typeargs(copy: Type, resolved: {string:Type}) + if not copy is HasTypeArgs then + return + end + + for i = #copy.typeargs, 1, -1 do + local r = resolved[copy.typeargs[i].typearg] + if r then + table.remove(copy.typeargs, i) + end + end + if not copy.typeargs[1] then + copy.typeargs = nil + end + + return + end + + typevar_resolver = function(self: S, typ: Type, fn_var: ResolveType, fn_arg?: ResolveType): boolean, Type, {Error}, {string:Type} local errs: {Error} local seen: {Type:Type} = {} local resolved: {string:Type} = {} @@ -7457,31 +7475,14 @@ do return copy, same and all_same end - local copy, same = resolve(typ, true) + local copy = resolve(typ, true) if errs then return false, an_invalid(typ), errs end - if (not same) and - (copy is FunctionType or copy is RecordLikeType) and - copy.typeargs - then - for i = #copy.typeargs, 1, -1 do - local r = resolved[copy.typeargs[i].typearg] - if r then - -- FIXME HACK!!! - if r is NominalType and #r.names == 1 then - copy.typeargs[i].typearg = r.names[1] - else - table.remove(copy.typeargs, i) - end - end - end - if not copy.typeargs[1] then - copy.typeargs = nil - end - end - return true, copy + clear_resolved_typeargs(copy, resolved) + + return true, copy, nil, resolved end local function resolve_typevar(tc: TypeChecker, t: TypeVarType): Type @@ -10804,6 +10805,18 @@ do before = function(self: TypeChecker, node: Node) local name = node.var.tk local resolved, aliasing = self:get_typedecl(node.value) + local nt = node.value.newtype + if nt and nt is TypeAliasType and resolved is TypeDeclType then + if nt.typeargs then + local def = resolved.def + -- FIXME ideally we'd like to use `if def is HasTypeArgs` + -- here, but if def.typeargs happens to be nil, the `is` + -- check won't work + if def is RecordType or def is FunctionType or def is InterfaceType then + def.typeargs = nt.typeargs + end + end + end local var = self:add_var(node.var, name, resolved, node.var.attribute) if aliasing then var.aliasing = aliasing From 4a829b02552c5afab2a8aef2c38dd29438ee94ec Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Thu, 8 Aug 2024 10:49:47 -0300 Subject: [PATCH 156/224] self: introduce self type --- spec/cli/types_spec.lua | 14 +++---- tl.lua | 91 ++++++++++++++++++++++++++++++++-------- tl.tl | 93 +++++++++++++++++++++++++++++++++-------- 3 files changed, 156 insertions(+), 42 deletions(-) diff --git a/spec/cli/types_spec.lua b/spec/cli/types_spec.lua index 24f36ff05..61d41f549 100644 --- a/spec/cli/types_spec.lua +++ b/spec/cli/types_spec.lua @@ -335,17 +335,17 @@ describe("tl types works like check", function() assert(types.by_pos) local by_pos = types.by_pos[next(types.by_pos)] assert.same({ - ["19"] = 8, - ["22"] = 8, - ["23"] = 6, + ["19"] = 9, + ["22"] = 9, + ["23"] = 7, ["30"] = 2, - ["41"] = 8, + ["41"] = 9, }, by_pos["1"]) assert.same({ - ["17"] = 6, + ["17"] = 7, ["20"] = 2, - ["25"] = 9, - ["31"] = 8, + ["25"] = 10, + ["31"] = 9, }, by_pos["2"]) end) end) diff --git a/tl.lua b/tl.lua index e2f89e89f..50c57be25 100644 --- a/tl.lua +++ b/tl.lua @@ -687,6 +687,7 @@ tl.typecodes = { MAP = 0x00040008, TUPLE = 0x00080008, INTERFACE = 0x00100008, + SELF = 0x00200008, POLY = 0x20000020, UNION = 0x40000000, @@ -1537,6 +1538,7 @@ end + local table_types = { @@ -1544,6 +1546,7 @@ local table_types = { ["map"] = true, ["record"] = true, ["interface"] = true, + ["self"] = true, ["emptytable"] = true, ["tupletable"] = true, @@ -1892,6 +1895,11 @@ end + + + + + @@ -2213,6 +2221,7 @@ local simple_types = { ["thread"] = true, ["boolean"] = true, ["integer"] = true, + ["self"] = true, } do @@ -5614,6 +5623,7 @@ local typename_to_typecode = { ["map"] = tl.typecodes.MAP, ["tupletable"] = tl.typecodes.TUPLE, ["interface"] = tl.typecodes.INTERFACE, + ["self"] = tl.typecodes.SELF, ["record"] = tl.typecodes.RECORD, ["enum"] = tl.typecodes.ENUM, ["boolean"] = tl.typecodes.BOOLEAN, @@ -6592,6 +6602,8 @@ local function show_type_base(t, short, seen) ret = "nominal " .. ret end return ret + elseif t.typename == "self" then + return "self" elseif t.typename == "tuple" then local out = {} for _, v in ipairs(t.tuple) do @@ -6643,15 +6655,10 @@ local function show_type_base(t, short, seen) end table.insert(out, "(") local args = {} - if t.is_method then - table.insert(args, "self") - end for i, v in ipairs(t.args.tuple) do - if not t.is_method or i > 1 then - table.insert(args, ((i == #t.args.tuple and t.args.is_va) and "...: " or - (i > t.min_arity) and "? " or - "") .. show(v)) - end + table.insert(args, ((i == #t.args.tuple and t.args.is_va) and "...: " or + (i > t.min_arity) and "? " or + "") .. show(v)) end table.insert(out, table.concat(args, ", ")) table.insert(out, ")") @@ -8051,7 +8058,7 @@ do end local function is_self(t) - return t.typename == "nominal" and t.names[1] == "@self" + return t.typename == "self" or (t.typename == "nominal" and t.names[1] == "@self") end local function compare_true(_, _, _) @@ -8229,6 +8236,19 @@ do end end + function TypeChecker:type_of_self(w) + local t = self:find_var_type("@self") + if not t then + return a_type(w, "invalid", {}) + end + + if t.typename == "typedecl" then + t = t.def + end + + return t + end + function TypeChecker:exists_supertype_in(t, xs) for _, x in ipairs(xs.types) do @@ -8339,7 +8359,18 @@ do return any_errors(errs) end, }, + ["self"] = { + ["self"] = function(_self, _a, _b) + return true + end, + ["*"] = function(self, a, b) + return self:same_type(self:type_of_self(a), b) + end, + }, ["*"] = { + ["self"] = function(self, a, b) + return self:same_type(a, self:type_of_self(b)) + end, ["typevar"] = function(self, a, b) return self:compare_or_infer_typevar(b.typevar, a, nil, self.same_type) end, @@ -8641,6 +8672,14 @@ a.types[i], b.types[i]), } return any_errors(errs) end, }, + ["self"] = { + ["self"] = function(_self, _a, _b) + return true + end, + ["*"] = function(self, a, b) + return self:is_a(self:type_of_self(a), b) + end, + }, ["typearg"] = { ["typearg"] = function(_self, a, b) return a.typearg == b.typearg @@ -8653,6 +8692,9 @@ a.types[i], b.types[i]), } }, ["*"] = { ["any"] = compare_true, + ["self"] = function(self, a, b) + return self:is_a(a, self:type_of_self(b)) + end, ["tuple"] = function(self, a, b) return self:is_a(a_type(a, "tuple", { tuple = { a } }), b) end, @@ -9620,6 +9662,8 @@ a.types[i], b.types[i]), } errm = "cannot index this tuple with a variable because it would produce a union type that cannot be discriminated at runtime" end + elseif ra.typename == "self" then + return self:type_check_index(anode, bnode, self:type_of_self(a), b) elseif ra.elements and rb.typename == "integer" then return ra.elements elseif ra.typename == "emptytable" then @@ -9643,6 +9687,15 @@ a.types[i], b.types[i]), } elseif rb.typename == "string" and rb.literal then local t, e = self:match_record_key(a, anode, rb.literal) if t then + + if t.typename == "function" then + for i, p in ipairs(t.args.tuple) do + if p.typename == "self" then + t.args.tuple[i] = a + end + end + end + return t end @@ -11613,6 +11666,7 @@ self:expand_type(node, values, elements) }) end args.tuple[1] = selftype self:add_var(nil, "self", selftype) + self:add_var(nil, "@self", a_type(node, "typedecl", { def = selftype })) end local fn_type = self:ensure_fresh_typeargs(a_function(node, { @@ -12416,14 +12470,16 @@ self:expand_type(node, values, elements) }) local record_name = typ.declname if record_name then local selfarg = fargs[1] - if selfarg.names[1] ~= record_name or (typ.typeargs and not selfarg.typevals) then - ftype.is_method = false - elseif typ.typeargs then - for j = 1, #typ.typeargs do - local tv = selfarg.typevals[j] - if not (tv and tv.typename == "typevar" and tv.typevar == typ.typeargs[j].typearg) then - ftype.is_method = false - break + if selfarg.typename == "nominal" then + if selfarg.names[1] ~= record_name or (typ.typeargs and not selfarg.typevals) then + ftype.is_method = false + elseif typ.typeargs then + for j = 1, #typ.typeargs do + local tv = selfarg.typevals[j] + if not (tv and tv.typename == "typevar" and tv.typevar == typ.typeargs[j].typearg) then + ftype.is_method = false + break + end end end end @@ -12551,6 +12607,7 @@ self:expand_type(node, values, elements) }) visit_type.cbs["typedecl"] = visit_type_with_typeargs visit_type.cbs["typealias"] = visit_type_with_typeargs + visit_type.cbs["self"] = default_type_visitor visit_type.cbs["string"] = default_type_visitor visit_type.cbs["tupletable"] = default_type_visitor visit_type.cbs["array"] = default_type_visitor diff --git a/tl.tl b/tl.tl index d0b2050b1..bb92ebf48 100644 --- a/tl.tl +++ b/tl.tl @@ -687,6 +687,7 @@ tl.typecodes = { MAP = 0x00040008, TUPLE = 0x00080008, INTERFACE = 0x00100008, + SELF = 0x00200008, POLY = 0x20000020, UNION = 0x40000000, -- Indirect types @@ -1515,6 +1516,7 @@ local enum TypeName "tupletable" "record" "interface" + "self" "enum" "boolean" "string" @@ -1544,6 +1546,7 @@ local table_types : {TypeName:boolean} = { ["map"] = true, ["record"] = true, ["interface"] = true, + ["self"] = true, ["emptytable"] = true, ["tupletable"] = true, @@ -1672,6 +1675,11 @@ local record NominalType resolved: Type -- type is found and typeargs are resolved end +local record SelfType + is Type + where self.typename == "self" +end + local interface ArrayLikeType is Type where self.elements @@ -2213,6 +2221,7 @@ local simple_types: {TypeName:boolean} = { ["thread"] = true, ["boolean"] = true, ["integer"] = true, + ["self"] = true, } do ----------------------------------------------------------------------------- @@ -5614,6 +5623,7 @@ local typename_to_typecode : {TypeName:integer} = { ["map"] = tl.typecodes.MAP, ["tupletable"] = tl.typecodes.TUPLE, ["interface"] = tl.typecodes.INTERFACE, + ["self"] = tl.typecodes.SELF, ["record"] = tl.typecodes.RECORD, ["enum"] = tl.typecodes.ENUM, ["boolean"] = tl.typecodes.BOOLEAN, @@ -6592,6 +6602,8 @@ local function show_type_base(t: Type, short: boolean, seen: {Type:string}): str ret = "nominal " .. ret end return ret + elseif t is SelfType then + return "self" elseif t is TupleType then local out: {string} = {} for _, v in ipairs(t.tuple) do @@ -6643,15 +6655,10 @@ local function show_type_base(t: Type, short: boolean, seen: {Type:string}): str end table.insert(out, "(") local args = {} - if t.is_method then - table.insert(args, "self") - end for i, v in ipairs(t.args.tuple) do - if not t.is_method or i > 1 then - table.insert(args, ((i == #t.args.tuple and t.args.is_va) and "...: " - or (i > t.min_arity) and "? " - or "") .. show(v)) - end + table.insert(args, ((i == #t.args.tuple and t.args.is_va) and "...: " + or (i > t.min_arity) and "? " + or "") .. show(v)) end table.insert(out, table.concat(args, ", ")) table.insert(out, ")") @@ -8051,7 +8058,7 @@ do end local function is_self(t: Type): boolean - return t is NominalType and t.names[1] == "@self" + return t is SelfType or (t is NominalType and t.names[1] == "@self") end local function compare_true(_: TypeChecker, _: Type, _: Type): boolean, {Error} @@ -8229,6 +8236,19 @@ do end end + function TypeChecker:type_of_self(w: Where): Type + local t = self:find_var_type("@self") + if not t then + return an_invalid(w) + end + + if t is TypeDeclType then + t = t.def + end + + return t + end + -- ∃ x ∈ xs. t <: x function TypeChecker:exists_supertype_in(t: Type, xs: AggregateType): Type for _, x in ipairs(xs.types) do @@ -8339,7 +8359,18 @@ do return any_errors(errs) end, }, + ["self"] = { + ["self"] = function(_self: TypeChecker, _a: SelfType, _b: SelfType): boolean, {Error} + return true + end, + ["*"] = function(self: TypeChecker, a: SelfType, b: Type): boolean, {Error} + return self:same_type(self:type_of_self(a), b) + end, + }, ["*"] = { + ["self"] = function(self: TypeChecker, a: Type, b: SelfType): boolean, {Error} + return self:same_type(a, self:type_of_self(b)) + end, ["typevar"] = function(self: TypeChecker, a: Type, b: TypeVarType): boolean, {Error} return self:compare_or_infer_typevar(b.typevar, a, nil, self.same_type) end, @@ -8641,6 +8672,14 @@ do return any_errors(errs) end, }, + ["self"] = { + ["self"] = function(_self: TypeChecker, _a: SelfType, _b: SelfType): boolean, {Error} + return true + end, + ["*"] = function(self: TypeChecker, a: SelfType, b: Type): boolean, {Error} + return self:is_a(self:type_of_self(a), b) + end, + }, ["typearg"] = { ["typearg"] = function(_self: TypeChecker, a: TypeArgType, b: TypeArgType): boolean, {Error} return a.typearg == b.typearg @@ -8653,6 +8692,9 @@ do }, ["*"] = { ["any"] = compare_true, + ["self"] = function(self: TypeChecker, a: Type, b: SelfType): boolean, {Error} + return self:is_a(a, self:type_of_self(b)) + end, ["tuple"] = function(self: TypeChecker, a: Type, b: Type): boolean, {Error} return self:is_a(a_tuple(a, {a}), b) end, @@ -9620,6 +9662,8 @@ do errm = "cannot index this tuple with a variable because it would produce a union type that cannot be discriminated at runtime" end + elseif ra is SelfType then + return self:type_check_index(anode, bnode, self:type_of_self(a), b) elseif ra is ArrayLikeType and rb is IntegerType then return ra.elements elseif ra is EmptyTableType then @@ -9643,6 +9687,15 @@ do elseif rb is StringType and rb.literal then local t, e = self:match_record_key(a, anode, rb.literal) if t then + + if t is FunctionType then + for i, p in ipairs(t.args.tuple) do + if p is SelfType then + t.args.tuple[i] = a + end + end + end + return t end @@ -11613,6 +11666,7 @@ do end args.tuple[1] = selftype self:add_var(nil, "self", selftype) + self:add_var(nil, "@self", a_typedecl(node, selftype)) end local fn_type = self:ensure_fresh_typeargs(a_function(node, { @@ -12415,15 +12469,17 @@ do if fargs[1] then local record_name = typ.declname if record_name then - local selfarg = fargs[1] as NominalType - if selfarg.names[1] ~= record_name or (typ.typeargs and not selfarg.typevals) then - ftype.is_method = false - elseif typ.typeargs then - for j=1,#typ.typeargs do - local tv = selfarg.typevals[j] - if not (tv and tv is TypeVarType and tv.typevar == typ.typeargs[j].typearg) then - ftype.is_method = false - break + local selfarg = fargs[1] + if selfarg is NominalType then + if selfarg.names[1] ~= record_name or (typ.typeargs and not selfarg.typevals) then + ftype.is_method = false + elseif typ.typeargs then + for j=1,#typ.typeargs do + local tv = selfarg.typevals[j] + if not (tv and tv is TypeVarType and tv.typevar == typ.typeargs[j].typearg) then + ftype.is_method = false + break + end end end end @@ -12551,6 +12607,7 @@ do visit_type.cbs["typedecl"] = visit_type_with_typeargs visit_type.cbs["typealias"] = visit_type_with_typeargs + visit_type.cbs["self"] = default_type_visitor visit_type.cbs["string"] = default_type_visitor visit_type.cbs["tupletable"] = default_type_visitor visit_type.cbs["array"] = default_type_visitor From b6d5d7b98dcece5738d2bdf79182939328c08177 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Thu, 8 Aug 2024 10:46:10 -0300 Subject: [PATCH 157/224] self: use self type --- tl.lua | 20 ++++++++------------ tl.tl | 26 +++++++++++--------------- 2 files changed, 19 insertions(+), 27 deletions(-) diff --git a/tl.lua b/tl.lua index 50c57be25..ae8d63f2b 100644 --- a/tl.lua +++ b/tl.lua @@ -3683,22 +3683,14 @@ do return i, node end - local function parse_where_clause(ps, i, typeargs) + local function parse_where_clause(ps, i) local node = new_node(ps, i, "macroexp") - local selftype = new_nominal(ps, i, "@self") - if typeargs then - selftype.typevals = {} - for a, t in ipairs(typeargs) do - selftype.typevals[a] = a_nominal(node, { t.typearg }) - end - end - node.is_method = true node.args = new_node(ps, i, "argument_list") node.args[1] = new_node(ps, i, "argument") node.args[1].tk = "self" - node.args[1].argtype = selftype + node.args[1].argtype = new_type(ps, i, "self") node.min_arity = 1 node.rets = new_tuple(ps, i) node.rets.tuple[1] = new_type(ps, i, "boolean") @@ -3789,7 +3781,7 @@ do local wstart = i i = i + 1 local where_macroexp - i, where_macroexp = parse_where_clause(ps, i, def.typeargs) + i, where_macroexp = parse_where_clause(ps, i) local typ = new_type(ps, wstart, "function") if def.typeargs then @@ -8724,6 +8716,7 @@ a.types[i], b.types[i]), } TypeChecker.type_priorities = { + ["self"] = 1, ["tuple"] = 2, ["typevar"] = 3, ["nil"] = 4, @@ -11664,7 +11657,7 @@ self:expand_type(node, values, elements) }) self.errs:add(node, "could not resolve type of self") return end - args.tuple[1] = selftype + args.tuple[1] = a_type(node, "self", {}) self:add_var(nil, "self", selftype) self:add_var(nil, "@self", a_type(node, "typedecl", { def = selftype })) end @@ -12484,6 +12477,9 @@ self:expand_type(node, values, elements) }) end end end + if ftype.is_method then + fargs[1] = a_type(fargs[1], "self", {}) + end end end elseif ftype.typename == "typealias" then diff --git a/tl.tl b/tl.tl index bb92ebf48..004d1f8b0 100644 --- a/tl.tl +++ b/tl.tl @@ -640,7 +640,7 @@ local record TypeReporter next_num: integer tr: TypeReport - get_typenum: function(TypeReporter, Type): integer + get_typenum: function(self, Type): integer end tl.version = function(): string @@ -3683,22 +3683,14 @@ local function parse_macroexp(ps: ParseState, istart: integer, iargs: integer): return i, node end -local function parse_where_clause(ps: ParseState, i: integer, typeargs: {TypeArgType}): integer, Node +local function parse_where_clause(ps: ParseState, i: integer): integer, Node local node = new_node(ps, i, "macroexp") - local selftype = new_nominal(ps, i, "@self") - if typeargs then - selftype.typevals = {} - for a, t in ipairs(typeargs) do - selftype.typevals[a] = a_nominal(node, { t.typearg }) - end - end - node.is_method = true node.args = new_node(ps, i, "argument_list") node.args[1] = new_node(ps, i, "argument") node.args[1].tk = "self" - node.args[1].argtype = selftype + node.args[1].argtype = new_type(ps, i, "self") node.min_arity = 1 node.rets = new_tuple(ps, i) node.rets.tuple[1] = new_type(ps, i, "boolean") @@ -3789,7 +3781,7 @@ parse_record_body = function(ps: ParseState, i: integer, def: RecordLikeType, no local wstart = i i = i + 1 local where_macroexp: Node - i, where_macroexp = parse_where_clause(ps, i, def.typeargs) + i, where_macroexp = parse_where_clause(ps, i) local typ = new_type(ps, wstart, "function") as FunctionType if def.typeargs then @@ -7015,8 +7007,8 @@ do feat_arity: boolean feat_lax: boolean - same_type: function(TypeChecker, Type, Type): boolean, {Error} - is_a: function(TypeChecker, Type, Type): boolean, {Error} + same_type: function(self, Type, Type): boolean, {Error} + is_a: function(self, Type, Type): boolean, {Error} type_check_funcall: function(TypeChecker, node: Node, a: Type, b: TupleType, argdelta?: integer): InvalidOrTupleType @@ -8724,6 +8716,7 @@ do -- evaluation strategy TypeChecker.type_priorities = { -- types that have catch-all rules evaluate first + ["self"] = 1, ["tuple"] = 2, ["typevar"] = 3, ["nil"] = 4, @@ -11664,7 +11657,7 @@ do self.errs:add(node, "could not resolve type of self") return end - args.tuple[1] = selftype + args.tuple[1] = a_type(node, "self", {}) self:add_var(nil, "self", selftype) self:add_var(nil, "@self", a_typedecl(node, selftype)) end @@ -12484,6 +12477,9 @@ do end end end + if ftype.is_method then + fargs[1] = a_type(fargs[1], "self", {}) + end end end elseif ftype is TypeAliasType then From 30e603a8f59372999d3bc6bd2a703dadecb44b4f Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Thu, 8 Aug 2024 11:46:31 -0300 Subject: [PATCH 158/224] self: remove magic behavior for nominal "@self" --- tl.lua | 17 +++-------------- tl.tl | 17 +++-------------- 2 files changed, 6 insertions(+), 28 deletions(-) diff --git a/tl.lua b/tl.lua index ae8d63f2b..cc8e2e3f0 100644 --- a/tl.lua +++ b/tl.lua @@ -3790,7 +3790,7 @@ do typ.is_method = true typ.min_arity = 1 typ.args = new_tuple(ps, wstart, { - a_nominal(where_macroexp, { "@self" }), + a_type(where_macroexp, "self", {}), }) typ.rets = new_tuple(ps, wstart, { new_type(ps, wstart, "boolean") }) typ.macroexp = where_macroexp @@ -6573,10 +6573,6 @@ local function show_type_base(t, short, seen) end if t.typename == "nominal" then - if #t.names == 1 and t.names[1] == "@self" then - return "self" - end - local ret if t.typevals then local out = { table.concat(t.names, "."), "<" } @@ -8049,19 +8045,11 @@ do return arr_type end - local function is_self(t) - return t.typename == "self" or (t.typename == "nominal" and t.names[1] == "@self") - end - local function compare_true(_, _, _) return true end function TypeChecker:subtype_nominal(a, b) - if is_self(a) and is_self(b) then - return true - end - local ra = a.typename == "nominal" and self:resolve_nominal(a) or a local rb = b.typename == "nominal" and self:resolve_nominal(b) or b local ok, errs = self:is_a(ra, rb) @@ -9092,7 +9080,8 @@ a.types[i], b.types[i]), } if argdelta == -1 then from = 2 local errs = {} - if (not is_self(fargs[1])) and not self:arg_check(w, errs, fargs[1], args.tuple[1], "contravariant", "self") then + local first = fargs[1] + if (not (first.typename == "self")) and not self:arg_check(w, errs, first, args.tuple[1], "contravariant", "self") then return nil, errs end end diff --git a/tl.tl b/tl.tl index 004d1f8b0..4eb756561 100644 --- a/tl.tl +++ b/tl.tl @@ -3790,7 +3790,7 @@ parse_record_body = function(ps: ParseState, i: integer, def: RecordLikeType, no typ.is_method = true typ.min_arity = 1 typ.args = new_tuple(ps, wstart, { - a_nominal(where_macroexp, { "@self" }) + a_type(where_macroexp, "self", {}) }) typ.rets = new_tuple(ps, wstart, { new_type(ps, wstart, "boolean") }) typ.macroexp = where_macroexp @@ -6573,10 +6573,6 @@ local function show_type_base(t: Type, short: boolean, seen: {Type:string}): str end if t is NominalType then - if #t.names == 1 and t.names[1] == "@self" then - return "self" - end - local ret: string if t.typevals then local out = { table.concat(t.names, "."), "<" } @@ -8049,19 +8045,11 @@ do return arr_type end - local function is_self(t: Type): boolean - return t is SelfType or (t is NominalType and t.names[1] == "@self") - end - local function compare_true(_: TypeChecker, _: Type, _: Type): boolean, {Error} return true end function TypeChecker:subtype_nominal(a: Type, b: Type): boolean, {Error} - if is_self(a) and is_self(b) then - return true - end - local ra = a is NominalType and self:resolve_nominal(a) or a local rb = b is NominalType and self:resolve_nominal(b) or b local ok, errs = self:is_a(ra, rb) @@ -9092,7 +9080,8 @@ do if argdelta == -1 then from = 2 local errs = {} - if (not is_self(fargs[1])) and not self:arg_check(w, errs, fargs[1], args.tuple[1], "contravariant", "self") then + local first = fargs[1] + if (not first is SelfType) and not self:arg_check(w, errs, first, args.tuple[1], "contravariant", "self") then return nil, errs end end From 1b27f3ed859b16012f8afbd57f7324701479c89d Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Sun, 18 Aug 2024 22:00:40 -0300 Subject: [PATCH 159/224] tests: add test case using self See #756. --- spec/subtyping/self_spec.lua | 43 ++++++++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) create mode 100644 spec/subtyping/self_spec.lua diff --git a/spec/subtyping/self_spec.lua b/spec/subtyping/self_spec.lua new file mode 100644 index 000000000..bf9b2aa19 --- /dev/null +++ b/spec/subtyping/self_spec.lua @@ -0,0 +1,43 @@ +local util = require("spec.util") + +describe("subtyping of self", function() + it("self type resolves from abstract interface to concrete records (#756)", util.check([[ + local interface SoundMaker + make_sound: function(self) + end + + local record Animal is SoundMaker + species: string + end + + function Animal:create(species: string): Animal + return setmetatable({ species = species }, { __index = Animal }) + end + + function Animal:make_sound() + print("Animal sound") + end + + local record Person is SoundMaker + name: string + end + + function Person:create(name: string): Person + return setmetatable({ name = name }, { __index = Person }) + end + + function Person:make_sound() + print("Person sound") + end + + local things: {SoundMaker} = { + Animal:create("Dog"), + Person:create("John") + } + + for _, thing in ipairs(things) do + thing:make_sound() + end + ]])) +end) + From 333b9e8bb5fb583982a1889634833c3f917f869e Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Sun, 18 Aug 2024 22:24:27 -0300 Subject: [PATCH 160/224] fix: self heuristic: ensure that non-nominal types are detected The heuristic was only comparing against nominal types; if a non-nominal type such as `(self: integer)` was used, it was silently accepted as a method and promoted to the `self` type. --- spec/subtyping/self_spec.lua | 124 ++++++++++++++++++++++++++++++++++- tl.lua | 41 +++++++----- tl.tl | 41 +++++++----- 3 files changed, 171 insertions(+), 35 deletions(-) diff --git a/spec/subtyping/self_spec.lua b/spec/subtyping/self_spec.lua index bf9b2aa19..4da770d47 100644 --- a/spec/subtyping/self_spec.lua +++ b/spec/subtyping/self_spec.lua @@ -1,7 +1,7 @@ local util = require("spec.util") describe("subtyping of self", function() - it("self type resolves from abstract interface to concrete records (#756)", util.check([[ + it("self type resolves from abstract interface to concrete records, implicit self type by name (#756)", util.check([[ local interface SoundMaker make_sound: function(self) end @@ -39,5 +39,127 @@ describe("subtyping of self", function() thing:make_sound() end ]])) + + it("self type resolves from abstract interface to concrete records, explicit use of self type (#756)", util.check([[ + local interface SoundMaker + make_sound: function(self: self) + end + + local record Animal is SoundMaker + species: string + end + + function Animal:create(species: string): Animal + return setmetatable({ species = species }, { __index = Animal }) + end + + function Animal:make_sound() + print("Animal sound") + end + + local record Person is SoundMaker + name: string + end + + function Person:create(name: string): Person + return setmetatable({ name = name }, { __index = Person }) + end + + function Person:make_sound() + print("Person sound") + end + + local things: {SoundMaker} = { + Animal:create("Dog"), + Person:create("John") + } + + for _, thing in ipairs(things) do + thing:make_sound() + end + ]])) + + it("self type resolves from abstract interface to concrete records, self type self-reference heuristic (#756)", util.check([[ + local interface SoundMaker + make_sound: function(self: SoundMaker) + end + + local record Animal is SoundMaker + species: string + end + + function Animal:create(species: string): Animal + return setmetatable({ species = species }, { __index = Animal }) + end + + function Animal:make_sound() + print("Animal sound") + end + + local record Person is SoundMaker + name: string + end + + function Person:create(name: string): Person + return setmetatable({ name = name }, { __index = Person }) + end + + function Person:make_sound() + print("Person sound") + end + + local things: {SoundMaker} = { + Animal:create("Dog"), + Person:create("John") + } + + for _, thing in ipairs(things) do + thing:make_sound() + end + ]])) + + it("a self variable that is not a self-referential type has no special behavior", util.check_type_error([[ + local interface SoundMaker + make_sound: function(self: integer) + end + + local record Animal is SoundMaker + species: string + end + + function Animal:create(species: string): Animal + return setmetatable({ species = species }, { __index = Animal }) + end + + function Animal:make_sound() + print("Animal sound") + end + + local record Person is SoundMaker + name: string + end + + function Person:create(name: string): Person + return setmetatable({ name = name }, { __index = Person }) + end + + function Person:make_sound() + print("Person sound") + end + + local things: {SoundMaker} = { + Animal:create("Dog"), + Person:create("John") + } + + for _, thing in ipairs(things) do + thing:make_sound() + end + ]], { + { y = 13, msg = "type signature of 'make_sound' does not match its declaration in Animal: argument 0: got Animal, expected integer" }, + { y = 25, msg = "type signature of 'make_sound' does not match its declaration in Person: argument 0: got Person, expected integer" }, + { y = 35, msg = "self: got SoundMaker, expected integer" }, + })) + end) diff --git a/tl.lua b/tl.lua index cc8e2e3f0..d5172e832 100644 --- a/tl.lua +++ b/tl.lua @@ -3612,6 +3612,9 @@ do local iok = parse_body(ps, i, ndef, nt) if iok then i = iok + if ndef.fields then + ndef.declname = v.tk + end nt.newtype = new_typedecl(ps, itype, ndef) end @@ -12401,6 +12404,26 @@ self:expand_type(node, values, elements) }) end end + local function ensure_is_method_self(typ, fargs) + assert(typ.declname) + local selfarg = fargs[1] + if not (selfarg.typename == "nominal") then + return false + end + if selfarg.names[1] ~= typ.declname or (typ.typeargs and not selfarg.typevals) then + return false + end + if typ.typeargs then + for j = 1, #typ.typeargs do + local tv = selfarg.typevals[j] + if not (tv and tv.typename == "typevar" and tv.typevar == typ.typeargs[j].typearg) then + return false + end + end + end + return true + end + local visit_type visit_type = { cbs = { @@ -12449,23 +12472,7 @@ self:expand_type(node, values, elements) }) if ftype.is_method then local fargs = ftype.args.tuple if fargs[1] then - local record_name = typ.declname - if record_name then - local selfarg = fargs[1] - if selfarg.typename == "nominal" then - if selfarg.names[1] ~= record_name or (typ.typeargs and not selfarg.typevals) then - ftype.is_method = false - elseif typ.typeargs then - for j = 1, #typ.typeargs do - local tv = selfarg.typevals[j] - if not (tv and tv.typename == "typevar" and tv.typevar == typ.typeargs[j].typearg) then - ftype.is_method = false - break - end - end - end - end - end + ftype.is_method = ensure_is_method_self(typ, fargs) if ftype.is_method then fargs[1] = a_type(fargs[1], "self", {}) end diff --git a/tl.tl b/tl.tl index 4eb756561..58f439ad9 100644 --- a/tl.tl +++ b/tl.tl @@ -3612,6 +3612,9 @@ local function parse_nested_type(ps: ParseState, i: integer, def: RecordLikeType local iok = parse_body(ps, i, ndef, nt) if iok then i = iok + if ndef is RecordLikeType then + ndef.declname = v.tk + end nt.newtype = new_typedecl(ps, itype, ndef) end @@ -12401,6 +12404,26 @@ do end end + local function ensure_is_method_self(typ: RecordLikeType, fargs: {Type}): boolean + assert(typ.declname) + local selfarg = fargs[1] + if not selfarg is NominalType then + return false + end + if selfarg.names[1] ~= typ.declname or (typ.typeargs and not selfarg.typevals) then + return false + end + if typ.typeargs then + for j=1,#typ.typeargs do + local tv = selfarg.typevals[j] + if not (tv and tv is TypeVarType and tv.typevar == typ.typeargs[j].typearg) then + return false + end + end + end + return true + end + local visit_type: Visitor visit_type = { cbs = { @@ -12449,23 +12472,7 @@ do if ftype.is_method then local fargs = ftype.args.tuple if fargs[1] then - local record_name = typ.declname - if record_name then - local selfarg = fargs[1] - if selfarg is NominalType then - if selfarg.names[1] ~= record_name or (typ.typeargs and not selfarg.typevals) then - ftype.is_method = false - elseif typ.typeargs then - for j=1,#typ.typeargs do - local tv = selfarg.typevals[j] - if not (tv and tv is TypeVarType and tv.typevar == typ.typeargs[j].typearg) then - ftype.is_method = false - break - end - end - end - end - end + ftype.is_method = ensure_is_method_self(typ, fargs) if ftype.is_method then fargs[1] = a_type(fargs[1], "self", {}) end From a4a780f372c14504babb13df066a265bf3b512f3 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Mon, 19 Aug 2024 13:51:36 -0300 Subject: [PATCH 161/224] fix: do not crash when failing to read input --- tl.lua | 4 ++++ tl.tl | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/tl.lua b/tl.lua index d5172e832..4d7ba7e82 100644 --- a/tl.lua +++ b/tl.lua @@ -12822,6 +12822,10 @@ end local function read_full_file(fd) local bom = "\239\187\191" local content, err = fd:read("*a") + if not content then + return nil, err + end + if content:sub(1, bom:len()) == bom then content = content:sub(bom:len() + 1) end diff --git a/tl.tl b/tl.tl index 58f439ad9..95d31aa04 100644 --- a/tl.tl +++ b/tl.tl @@ -12822,6 +12822,10 @@ end local function read_full_file(fd: FILE): string, string local bom = "\239\187\191" -- "\xEF\xBB\xBF" local content, err = fd:read("*a") + if not content then + return nil, err + end + if content:sub(1, bom:len()) == bom then content = content:sub(bom:len() + 1) end From 8c2d4c6796308a27f4c13401cd621979b77f07c7 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Mon, 19 Aug 2024 13:51:52 -0300 Subject: [PATCH 162/224] self: fix: accept self type in function signature as a method indicator --- tl.lua | 3 +++ tl.tl | 3 +++ 2 files changed, 6 insertions(+) diff --git a/tl.lua b/tl.lua index 4d7ba7e82..9d6b90656 100644 --- a/tl.lua +++ b/tl.lua @@ -12407,6 +12407,9 @@ self:expand_type(node, values, elements) }) local function ensure_is_method_self(typ, fargs) assert(typ.declname) local selfarg = fargs[1] + if selfarg.typename == "self" then + return true + end if not (selfarg.typename == "nominal") then return false end diff --git a/tl.tl b/tl.tl index 95d31aa04..fd629cd1a 100644 --- a/tl.tl +++ b/tl.tl @@ -12407,6 +12407,9 @@ do local function ensure_is_method_self(typ: RecordLikeType, fargs: {Type}): boolean assert(typ.declname) local selfarg = fargs[1] + if selfarg is SelfType then + return true + end if not selfarg is NominalType then return false end From ee5c8fc5ba94eddccff1112c3c431969724cdb39 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Mon, 19 Aug 2024 16:27:32 -0300 Subject: [PATCH 163/224] tests: add regression test for #752. Thanks @svermeulen for the report! --- spec/declaration/local_spec.lua | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/spec/declaration/local_spec.lua b/spec/declaration/local_spec.lua index ca6324fc5..bcdc9850b 100644 --- a/spec/declaration/local_spec.lua +++ b/spec/declaration/local_spec.lua @@ -552,6 +552,22 @@ describe("local", function() { tag = "unused", msg = "unused type integer" }, })) + it("catches bad assignments of record tables (regression test for #752)", util.check_type_error([[ + local record Foo + qux: integer + end + + local record Bar + gorp: number + end + + local _bar1: Bar = 0.0 + local _bar2: Bar = Foo + ]], { + { y = 9, msg = "_bar1: got number, expected Bar" }, + { y = 10, msg = "_bar2: Foo is not a Bar" }, + })) + it("does not accept type arguments declared twice", util.check_syntax_error([[ local type Foo = record end From 2ff0df4ebcab3b3c5ef14a95cc52676e19e1cb08 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Wed, 21 Aug 2024 10:05:37 -0300 Subject: [PATCH 164/224] tests: split `local` and `local type` tests --- spec/declaration/local_spec.lua | 295 ++++----------------------- spec/declaration/local_type_spec.lua | 215 +++++++++++++++++++ 2 files changed, 257 insertions(+), 253 deletions(-) create mode 100644 spec/declaration/local_type_spec.lua diff --git a/spec/declaration/local_spec.lua b/spec/declaration/local_spec.lua index bcdc9850b..0fc8faff2 100644 --- a/spec/declaration/local_spec.lua +++ b/spec/declaration/local_spec.lua @@ -18,6 +18,48 @@ describe("local", function() local z: number z = x + y ]])) + + it("'type', 'record' and 'enum' are not reserved keywords", util.check([[ + local type = type + local record: string = "hello" + local enum: number = 123 + print(record) + print(enum + 123) + ]])) + + it("reports unset and untyped values as errors in tl mode", util.check_type_error([[ + local record T + x: number + y: number + end + + function T:returnsTwo(): number, number + return self.x, self.y + end + + function T:method() + local a, b = self.returnsTwo and self:returnsTwo() + end + ]], { + { msg = "assignment in declaration did not produce an initial value for variable 'b'" }, + })) + + it("reports unset values as unknown in Lua mode", util.lax_check([[ + local record T + x: number + y: number + end + + function T:returnsTwo(): number, number + return self.x, self.y + end + + function T:method() + local a, b = self.returnsTwo and self:returnsTwo() + end + ]], { + { msg = "b" }, + })) end) describe("multiple declaration", function() @@ -63,230 +105,6 @@ describe("local", function() z = x + string.byte(y) ]])) end) - - it("reports unset and untyped values as errors in tl mode", util.check_type_error([[ - local type T = record - x: number - y: number - end - - function T:returnsTwo(): number, number - return self.x, self.y - end - - function T:method() - local a, b = self.returnsTwo and self:returnsTwo() - end - ]], { - { msg = "assignment in declaration did not produce an initial value for variable 'b'" }, - })) - - it("reports unset values as unknown in Lua mode", util.lax_check([[ - local type T = record - x: number - y: number - end - - function T:returnsTwo(): number, number - return self.x, self.y - end - - function T:method() - local a, b = self.returnsTwo and self:returnsTwo() - end - ]], { - { msg = "b" }, - })) - - it("local type can declare a type alias for table", util.check([[ - local type PackTable = table.PackTable - local args: table.PackTable = table.pack(1, 2, 3) - ]])) - - it("local type can declare a nominal type alias (regression test for #238)", function () - util.mock_io(finally, { - ["module.tl"] = [[ - local record module - record Type - data: number - end - end - return module - ]], - ["main.tl"] = [[ - local module = require "module" - local type Boo = module.Type - local var: Boo = { dato = 0 } - print(var.dato) - ]], - }) - local result, err = tl.process("main.tl") - - assert.same({}, result.syntax_errors) - assert.same({ - { y = 3, x = 35, filename = "main.tl", msg = "in local declaration: var: unknown field dato" }, - { y = 4, x = 26, filename = "main.tl", msg = "invalid key 'dato' in record 'var' of type Boo" }, - }, result.type_errors) - end) - - it("local type can resolve a nominal with generics (regression test for #777)", function () - util.mock_io(finally, { - ["module.tl"] = [[ - local record module - record Foo - something: K - end - end - return module - ]], - ["main.tl"] = [[ - local module = require "module" - - local record Boo - field: MyFoo - end - - local type MyFoo = module.Foo - - local b: Boo = { field = { something = "hi" } } - local c: Boo = { field = { something = 123 } } - ]], - }) - local result, err = tl.process("main.tl") - - assert.same({}, result.syntax_errors) - assert.same({ - { y = 10, x = 55, filename = "main.tl", msg = "in record field: something: got integer, expected string" }, - }, result.type_errors) - end) - - it("catches unknown types", util.check_type_error([[ - local type MyType = UnknownType - ]], { - { msg = "unknown type UnknownType" } - })) - - it("nominal types can take type arguments", util.check([[ - local record Foo - item: R - end - - local type Foo2 = Foo - local type Bla = Foo - - local x: Bla = { item = 123 } - local y: Foo2 = { item = 123 } - ]])) - - it("types declared as nominal types are aliases", util.check([[ - local record Foo - item: R - end - - local type Foo2 = Foo - local type FooNumber = Foo - - local x: FooNumber = { item = 123 } - local y: Foo2 = { item = 123 } - - local type Foo3 = Foo - local type Foo4 = Foo2 - - local zep: Foo2 = { item = "hello" } - local zip: Foo3 = zep - local zup: Foo4 = zip - ]])) - - it("nested types can be resolved as aliases", util.check([[ - local record Foo - enum LocalEnum - "loc" - end - - record Nested - x: {LocalEnum} - y: R - end - - item: R - end - - local type Nested = Foo.Nested - ]])) - - it("'type', 'record' and 'enum' are not reserved keywords", util.check([[ - local type = type - local record: string = "hello" - local enum: number = 123 - print(record) - print(enum + 123) - ]])) - - it("local type can require a module", function () - util.mock_io(finally, { - ["class.tl"] = [[ - local record Class - data: number - end - return Class - ]], - ["main.tl"] = [[ - local type Class = require("class") - local obj: Class = { data = 2 } - ]], - }) - local result, err = tl.process("main.tl") - - assert.same({}, result.syntax_errors) - assert.same({}, result.type_errors) - end) - - it("local type can require a module and type is usable", function () - util.mock_io(finally, { - ["class.tl"] = [[ - local record Class - data: number - end - return Class - ]], - ["main.tl"] = [[ - local type Class = require("class") - local obj: Class = { invalid = 2 } - ]], - }) - local result, err = tl.process("main.tl") - - assert.same({}, result.syntax_errors) - assert.same({ - { y = 2, x = 37, filename = "main.tl", msg = "in local declaration: obj: unknown field invalid" }, - }, result.type_errors) - end) - - it("local type can require a module and its globals are visible", function () - util.mock_io(finally, { - ["class.tl"] = [[ - global record Glob - hello: number - end - - local record Class - data: number - end - return Class - ]], - ["main.tl"] = [[ - local type Class = require("class") - local obj: Glob = { hello = 2 } - local obj2: Glob = { invalid = 2 } - ]], - }) - local result, err = tl.process("main.tl") - - assert.same({}, result.syntax_errors) - assert.same({ - { y = 3, x = 37, filename = "main.tl", msg = "in local declaration: obj2: unknown field invalid" }, - }, result.type_errors) - end) end) describe("annotation", function() @@ -567,33 +385,4 @@ describe("local", function() { y = 9, msg = "_bar1: got number, expected Bar" }, { y = 10, msg = "_bar2: Foo is not a Bar" }, })) - - it("does not accept type arguments declared twice", util.check_syntax_error([[ - local type Foo = record - end - ]], { - { y = 1, msg = "cannot declare type arguments twice in type declaration" }, - })) - - it("propagates type arguments correctly", util.check_type_error([[ - local record module - record Foo - first: A - second: B - end - end - - -- note inverted arguments - local type MyFoo = module.Foo - - local record Boo - field: MyFoo - end - - local b: Boo = { field = { first = "first", second = 2 } } -- bad, not inverted! - local c: Boo = { field = { first = 1, second = "second" } } -- good, inverted! - ]], { - { y = 15, x = 42, msg = 'in record field: first: got string "first", expected integer' }, - { y = 15, x = 60, msg = 'in record field: second: got integer, expected string' }, - })) end) diff --git a/spec/declaration/local_type_spec.lua b/spec/declaration/local_type_spec.lua new file mode 100644 index 000000000..e9c0aa506 --- /dev/null +++ b/spec/declaration/local_type_spec.lua @@ -0,0 +1,215 @@ +local util = require("spec.util") +local tl = require("tl") + +describe("local type", function() + it("can declare a type alias for table", util.check([[ + local type PackTable = table.PackTable + local args: table.PackTable = table.pack(1, 2, 3) + ]])) + + it("can declare a nominal type alias (regression test for #238)", function () + util.mock_io(finally, { + ["module.tl"] = [[ + local record module + record Type + data: number + end + end + return module + ]], + ["main.tl"] = [[ + local module = require "module" + local type Boo = module.Type + local var: Boo = { dato = 0 } + print(var.dato) + ]], + }) + local result, err = tl.process("main.tl") + + assert.same({}, result.syntax_errors) + assert.same({ + { y = 3, x = 32, filename = "main.tl", msg = "in local declaration: var: unknown field dato" }, + { y = 4, x = 23, filename = "main.tl", msg = "invalid key 'dato' in record 'var' of type Boo" }, + }, result.type_errors) + end) + + it("can resolve a nominal with generics (regression test for #777)", function () + util.mock_io(finally, { + ["module.tl"] = [[ + local record module + record Foo + something: K + end + end + return module + ]], + ["main.tl"] = [[ + local module = require "module" + + local record Boo + field: MyFoo + end + + local type MyFoo = module.Foo + + local b: Boo = { field = { something = "hi" } } + local c: Boo = { field = { something = 123 } } + ]], + }) + local result, err = tl.process("main.tl") + + assert.same({}, result.syntax_errors) + assert.same({ + { y = 10, x = 52, filename = "main.tl", msg = "in record field: something: got integer, expected string" }, + }, result.type_errors) + end) + + it("catches unknown types", util.check_type_error([[ + local type MyType = UnknownType + ]], { + { msg = "unknown type UnknownType" } + })) + + it("nominal types can take type arguments", util.check([[ + local record Foo + item: R + end + + local type Foo2 = Foo + local type Bla = Foo + + local x: Bla = { item = 123 } + local y: Foo2 = { item = 123 } + ]])) + + it("declared as nominal types are aliases", util.check([[ + local record Foo + item: R + end + + local type Foo2 = Foo + local type FooNumber = Foo + + local x: FooNumber = { item = 123 } + local y: Foo2 = { item = 123 } + + local type Foo3 = Foo + local type Foo4 = Foo2 + + local zep: Foo2 = { item = "hello" } + local zip: Foo3 = zep + local zup: Foo4 = zip + ]])) + + it("nested types can be resolved as aliases", util.check([[ + local record Foo + enum LocalEnum + "loc" + end + + record Nested + x: {LocalEnum} + y: R + end + + item: R + end + + local type Nested = Foo.Nested + ]])) + + it("can require a module", function () + util.mock_io(finally, { + ["class.tl"] = [[ + local record Class + data: number + end + return Class + ]], + ["main.tl"] = [[ + local type Class = require("class") + local obj: Class = { data = 2 } + ]], + }) + local result, err = tl.process("main.tl") + + assert.same({}, result.syntax_errors) + assert.same({}, result.type_errors) + end) + + it("can require a module and type is usable", function () + util.mock_io(finally, { + ["class.tl"] = [[ + local record Class + data: number + end + return Class + ]], + ["main.tl"] = [[ + local type Class = require("class") + local obj: Class = { invalid = 2 } + ]], + }) + local result, err = tl.process("main.tl") + + assert.same({}, result.syntax_errors) + assert.same({ + { y = 2, x = 34, filename = "main.tl", msg = "in local declaration: obj: unknown field invalid" }, + }, result.type_errors) + end) + + it("can require a module and its globals are visible", function () + util.mock_io(finally, { + ["class.tl"] = [[ + global record Glob + hello: number + end + + local record Class + data: number + end + return Class + ]], + ["main.tl"] = [[ + local type Class = require("class") + local obj: Glob = { hello = 2 } + local obj2: Glob = { invalid = 2 } + ]], + }) + local result, err = tl.process("main.tl") + + assert.same({}, result.syntax_errors) + assert.same({ + { y = 3, x = 34, filename = "main.tl", msg = "in local declaration: obj2: unknown field invalid" }, + }, result.type_errors) + end) + + it("does not accept type arguments declared twice", util.check_syntax_error([[ + local type Foo = record + end + ]], { + { y = 1, msg = "cannot declare type arguments twice in type declaration" }, + })) + + it("propagates type arguments correctly", util.check_type_error([[ + local record module + record Foo + first: A + second: B + end + end + + -- note inverted arguments + local type MyFoo = module.Foo + + local record Boo + field: MyFoo + end + + local b: Boo = { field = { first = "first", second = 2 } } -- bad, not inverted! + local c: Boo = { field = { first = 1, second = "second" } } -- good, inverted! + ]], { + { y = 15, x = 42, msg = 'in record field: first: got string "first", expected integer' }, + { y = 15, x = 60, msg = 'in record field: second: got integer, expected string' }, + })) +end) From 6430aaa56532eb0d6716b377b3ebdb512b443911 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Wed, 21 Aug 2024 10:09:59 -0300 Subject: [PATCH 165/224] tests: add regression test for #754 --- spec/declaration/local_type_spec.lua | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/spec/declaration/local_type_spec.lua b/spec/declaration/local_type_spec.lua index e9c0aa506..59a2c88c2 100644 --- a/spec/declaration/local_type_spec.lua +++ b/spec/declaration/local_type_spec.lua @@ -212,4 +212,23 @@ describe("local type", function() { y = 15, x = 42, msg = 'in record field: first: got string "first", expected integer' }, { y = 15, x = 60, msg = 'in record field: second: got integer, expected string' }, })) + + it("resolves type arguments in nested types correctly (#754)", util.check_type_error([[ + local record MyNamespace + record MyGenericRecord + Data: T + end + end + + local enum MyEnum + "foo" + "bar" + end + + local type MyAlias = MyNamespace.MyGenericRecord + + local t: MyAlias = { Data = "invalid" } + ]], { + { y = 14, msg = 'in record field: Data: string "invalid" is not a member of MyEnum' } + })) end) From 234ee2a619197a7f3b0c4c7d858534691c7d4fd7 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Mon, 26 Aug 2024 09:29:53 -0300 Subject: [PATCH 166/224] fix: tl types: never trigger ICE on bad files This matches the behavior of master. Can't make a simple regression test for this one because that would be dependent on unspecified behaviors of the parser and type-checker. Fixes #795. --- tl | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/tl b/tl index e78f40322..6a90ac055 100755 --- a/tl +++ b/tl @@ -897,10 +897,15 @@ do env.keep_going = true env.report_types = true + local pcalls_ok = true for i, input_file in ipairs(args["file"]) do - local pok, perr, err = pcall(process_module, input_file, env) + -- we run the type-checker even on files that produce + -- syntax errors; this means we run it on incomplete and + -- potentially inconsistent trees which may crash the + -- type-checker; hence, we wrap it with a pcall here. + local pok, _, err = pcall(process_module, input_file, env) if not pok then - die("Internal Compiler Error: " .. perr) + pcalls_ok = false end if err then printerr(err) @@ -910,6 +915,9 @@ do end local ok, _, _, w = report_all_errors(tlconfig, env) + if not pcalls_ok then + ok = false + end if not env.reporter then os.exit(1) @@ -928,7 +936,7 @@ do x = tonumber(x) or 1 json_out_table(io.stdout, tl.symbols_in_scope(tr, y, x, filename)) else - tr.symbols = tr.symbols_by_file[filename] + tr.symbols = tr.symbols_by_file[filename] or { [0] = false } json_out_table(io.stdout, tr) end From d12f5345ae0aa6efe3b95c8eff5c2a77b09bd300 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Mon, 26 Aug 2024 10:44:31 -0300 Subject: [PATCH 167/224] fix: do not infer type variables as boolean in boolean contexts Introduces a special internal type, to be used only as the node.expected type in boolean contexts such as `if _ then`. It behaves exactly like boolean except that type variables do not infer to it. See #768. --- tl.lua | 36 ++++++++++++++++++++++++++++++++---- tl.tl | 42 +++++++++++++++++++++++++++++++++++------- 2 files changed, 67 insertions(+), 11 deletions(-) diff --git a/tl.lua b/tl.lua index 9d6b90656..8495d9727 100644 --- a/tl.lua +++ b/tl.lua @@ -1539,6 +1539,7 @@ end + local table_types = { @@ -1569,6 +1570,7 @@ local table_types = { ["unresolved_typearg"] = false, ["unresolvable_typearg"] = false, ["circular_require"] = false, + ["boolean_context"] = false, ["tuple"] = false, ["poly"] = false, ["any"] = false, @@ -1891,6 +1893,20 @@ end + + + + + + + + + + + + + + @@ -5630,6 +5646,7 @@ local typename_to_typecode = { ["union"] = tl.typecodes.UNION, ["nominal"] = tl.typecodes.NOMINAL, ["circular_require"] = tl.typecodes.NOMINAL, + ["boolean_context"] = tl.typecodes.BOOLEAN, ["emptytable"] = tl.typecodes.EMPTY_TABLE, ["unresolved_emptytable_value"] = tl.typecodes.EMPTY_TABLE, ["poly"] = tl.typecodes.POLY, @@ -6688,6 +6705,8 @@ local function show_type_base(t, short, seen) return "" elseif t.typename == "nil" then return "nil" + elseif t.typename == "boolean_context" then + return "boolean" elseif t.typename == "none" then return "" elseif t.typename == "typealias" then @@ -8350,7 +8369,11 @@ do return self:same_type(self:type_of_self(a), b) end, }, + ["boolean_context"] = { + ["boolean"] = compare_true, + }, ["*"] = { + ["boolean_context"] = compare_true, ["self"] = function(self, a, b) return self:same_type(a, self:type_of_self(b)) end, @@ -8673,8 +8696,12 @@ a.types[i], b.types[i]), } end end, }, + ["boolean_context"] = { + ["boolean"] = compare_true, + }, ["*"] = { ["any"] = compare_true, + ["boolean_context"] = compare_true, ["self"] = function(self, a, b) return self:is_a(a, self:type_of_self(b)) end, @@ -8712,6 +8739,7 @@ a.types[i], b.types[i]), } ["typevar"] = 3, ["nil"] = 4, ["any"] = 5, + ["boolean_context"] = 5, ["union"] = 6, ["poly"] = 7, @@ -11044,7 +11072,7 @@ self:expand_type(node, values, elements) }) self:infer_negation_of_if_blocks(node, node.if_parent, node.if_block_n - 1) end if node.exp then - node.exp.expected = a_type(node, "boolean", {}) + node.exp.expected = a_type(node, "boolean_context", {}) end end, before_statements = function(self, node) @@ -11066,7 +11094,7 @@ self:expand_type(node, values, elements) }) before = function(self, node) self:widen_all_unions(node) - node.exp.expected = a_type(node, "boolean", {}) + node.exp.expected = a_type(node, "boolean_context", {}) end, before_statements = function(self, node) self:begin_scope(node) @@ -11130,7 +11158,7 @@ self:expand_type(node, values, elements) }) before = function(self, node) self:widen_all_unions(node) - node.exp.expected = a_type(node, "boolean", {}) + node.exp.expected = a_type(node, "boolean_context", {}) end, after = end_scope_and_none_type, @@ -12001,7 +12029,7 @@ self:expand_type(node, values, elements) }) t = drop_constant_value(t) end - if expected and expected.typename == "boolean" then + if expected and expected.typename == "boolean_context" then t = a_type(node, "boolean", {}) end end diff --git a/tl.tl b/tl.tl index fd629cd1a..e2d273f7c 100644 --- a/tl.tl +++ b/tl.tl @@ -1532,11 +1532,12 @@ local enum TypeName "unresolved_typearg" "unresolvable_typearg" "circular_require" + "boolean_context" "tuple" - "poly" -- intersection types, currently restricted to polymorphic functions defined inside records + "poly" "any" - "unknown" -- to be used in lax mode only - "invalid" -- producing a new value of this type (not propagating) must always produce a type error + "unknown" + "invalid" "none" "*" end @@ -1569,6 +1570,7 @@ local table_types : {TypeName:boolean} = { ["unresolved_typearg"] = false, ["unresolvable_typearg"] = false, ["circular_require"] = false, + ["boolean_context"] = false, ["tuple"] = false, ["poly"] = false, ["any"] = false, @@ -1617,6 +1619,14 @@ local record BooleanType where self.typename == "boolean" end +-- This is a special internal type, to be used only as the node.expected +-- type in boolean contexts such as `if _ then`. It behaves exactly like +-- boolean except that type variables do not infer to it. +local record BooleanContextType + is Type + where self.typename == "boolean_context" +end + local interface HasTypeArgs is Type where self.typeargs @@ -1717,11 +1727,15 @@ local record InterfaceType where self.typename == "interface" end +-- producing a new value of this type (not propagating) +-- must always produce a type error local record InvalidType is Type where self.typename == "invalid" end +-- To be used in lax mode only: +-- this represents non-annotated types in .lua files. local record UnknownType is Type where self.typename == "unknown" @@ -1819,6 +1833,8 @@ local record TupleTableType where self.typename == "tupletable" end +-- Intersection types, currently restricted to polymorphic functions +-- defined inside records, representing polymorphic Lua APIs. local record PolyType is AggregateType where self.typename == "poly" @@ -5630,6 +5646,7 @@ local typename_to_typecode : {TypeName:integer} = { ["union"] = tl.typecodes.UNION, ["nominal"] = tl.typecodes.NOMINAL, ["circular_require"] = tl.typecodes.NOMINAL, + ["boolean_context"] = tl.typecodes.BOOLEAN, ["emptytable"] = tl.typecodes.EMPTY_TABLE, ["unresolved_emptytable_value"] = tl.typecodes.EMPTY_TABLE, ["poly"] = tl.typecodes.POLY, @@ -6688,6 +6705,8 @@ local function show_type_base(t: Type, short: boolean, seen: {Type:string}): str return "" elseif t.typename == "nil" then return "nil" + elseif t.typename == "boolean_context" then + return "boolean" elseif t.typename == "none" then return "" elseif t is TypeAliasType then @@ -8350,7 +8369,11 @@ do return self:same_type(self:type_of_self(a), b) end, }, + ["boolean_context"] = { + ["boolean"] = compare_true, + }, ["*"] = { + ["boolean_context"] = compare_true, ["self"] = function(self: TypeChecker, a: Type, b: SelfType): boolean, {Error} return self:same_type(a, self:type_of_self(b)) end, @@ -8673,8 +8696,12 @@ do end end, }, + ["boolean_context"] = { + ["boolean"] = compare_true, + }, ["*"] = { ["any"] = compare_true, + ["boolean_context"] = compare_true, ["self"] = function(self: TypeChecker, a: Type, b: SelfType): boolean, {Error} return self:is_a(a, self:type_of_self(b)) end, @@ -8712,6 +8739,7 @@ do ["typevar"] = 3, ["nil"] = 4, ["any"] = 5, + ["boolean_context"] = 5, ["union"] = 6, ["poly"] = 7, -- then typeargs @@ -11044,7 +11072,7 @@ do self:infer_negation_of_if_blocks(node, node.if_parent, node.if_block_n - 1) end if node.exp then - node.exp.expected = a_type(node, "boolean", {}) + node.exp.expected = a_type(node, "boolean_context", {}) end end, before_statements = function(self: TypeChecker, node: Node) @@ -11066,7 +11094,7 @@ do before = function(self: TypeChecker, node: Node) -- widen all narrowed variables because we don't calculate a fixpoint yet self:widen_all_unions(node) - node.exp.expected = a_type(node, "boolean", {}) + node.exp.expected = a_type(node, "boolean_context", {}) end, before_statements = function(self: TypeChecker, node: Node) self:begin_scope(node) @@ -11130,7 +11158,7 @@ do before = function(self: TypeChecker, node: Node) -- widen all narrowed variables because we don't calculate a fixpoint yet self:widen_all_unions(node) - node.exp.expected = a_type(node, "boolean", {}) + node.exp.expected = a_type(node, "boolean_context", {}) end, -- only end scope after checking `until`, `statements` in repeat body has is_repeat == true after = end_scope_and_none_type, @@ -12001,7 +12029,7 @@ do t = drop_constant_value(t) end - if expected and expected is BooleanType then + if expected and expected is BooleanContextType then t = a_type(node, "boolean", {}) end end From b8b5bb86762f9b2e7192915ffc4ebc1f2ebebba9 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Wed, 28 Aug 2024 00:12:29 -0300 Subject: [PATCH 168/224] generate inline compat code for table.pack --- spec/cli/gen_spec.lua | 8 ++++++-- tl.lua | 15 +++++++++------ tl.tl | 15 +++++++++------ 3 files changed, 24 insertions(+), 14 deletions(-) diff --git a/spec/cli/gen_spec.lua b/spec/cli/gen_spec.lua index fd3d67f51..7f537e109 100644 --- a/spec/cli/gen_spec.lua +++ b/spec/cli/gen_spec.lua @@ -297,6 +297,7 @@ describe("tl gen", function() local t = {1, 2, 3, 4} print(table.unpack(t)) + local t2 = table.pack(1, 2, "any") local n = 42 local maxi = math.maxinteger local mini = math.mininteger @@ -315,6 +316,7 @@ describe("tl gen", function() local t = { 1, 2, 3, 4 } print(table.unpack(t)) + local t2 = table.pack(1, 2, "any") local n = 42 local maxi = math.maxinteger local mini = math.mininteger @@ -330,9 +332,10 @@ describe("tl gen", function() ]] local output_code_with_optional_compat = [[ - local _tl_compat; if (tonumber((_VERSION or ''):match('[%d.]*$')) or 0) < 5.3 then local p, m = pcall(require, 'compat53.module'); if p then _tl_compat = m end end; local math = _tl_compat and _tl_compat.math or math; local _tl_math_maxinteger = math.maxinteger or math.pow(2, 53); local _tl_math_mininteger = math.mininteger or -math.pow(2, 53) - 1; local table = _tl_compat and _tl_compat.table or table; local _tl_table_unpack = unpack or table.unpack + local _tl_compat; if (tonumber((_VERSION or ''):match('[%d.]*$')) or 0) < 5.3 then local p, m = pcall(require, 'compat53.module'); if p then _tl_compat = m end end; local math = _tl_compat and _tl_compat.math or math; local _tl_math_maxinteger = math.maxinteger or math.pow(2, 53); local _tl_math_mininteger = math.mininteger or -math.pow(2, 53) - 1; local table = _tl_compat and _tl_compat.table or table; local _tl_table_pack = table.pack or function(...) return { n = select("#", ...), ... } end; local _tl_table_unpack = unpack or table.unpack local t = { 1, 2, 3, 4 } print(_tl_table_unpack(t)) + local t2 = _tl_table_pack(1, 2, "any") local n = 42 local maxi = _tl_math_maxinteger local mini = _tl_math_mininteger @@ -348,9 +351,10 @@ describe("tl gen", function() ]] local output_code_with_required_compat = [[ - local _tl_compat; if (tonumber((_VERSION or ''):match('[%d.]*$')) or 0) < 5.3 then local p, m = true, require('compat53.module'); if p then _tl_compat = m end end; local math = _tl_compat and _tl_compat.math or math; local _tl_math_maxinteger = math.maxinteger or math.pow(2, 53); local _tl_math_mininteger = math.mininteger or -math.pow(2, 53) - 1; local table = _tl_compat and _tl_compat.table or table; local _tl_table_unpack = unpack or table.unpack + local _tl_compat; if (tonumber((_VERSION or ''):match('[%d.]*$')) or 0) < 5.3 then local p, m = true, require('compat53.module'); if p then _tl_compat = m end end; local math = _tl_compat and _tl_compat.math or math; local _tl_math_maxinteger = math.maxinteger or math.pow(2, 53); local _tl_math_mininteger = math.mininteger or -math.pow(2, 53) - 1; local table = _tl_compat and _tl_compat.table or table; local _tl_table_pack = table.pack or function(...) return { n = select("#", ...), ... } end; local _tl_table_unpack = unpack or table.unpack local t = { 1, 2, 3, 4 } print(_tl_table_unpack(t)) + local t2 = _tl_table_pack(1, 2, "any") local n = 42 local maxi = _tl_math_maxinteger local mini = _tl_math_mininteger diff --git a/tl.lua b/tl.lua index 8495d9727..d8e10b3e5 100644 --- a/tl.lua +++ b/tl.lua @@ -347,17 +347,17 @@ do move: function({A}, integer, integer, integer, ? {A}): {A} - pack: function(T...): PackTable - pack: function(any...): {any:any} + pack: function(T...): PackTable --[[needs_compat]] + pack: function(any...): {any:any} --[[needs_compat]] remove: function({A}, ? integer): A sort: function({A}, ? SortFunction) - unpack: function({A1, A2, A3, A4, A5}): A1, A2, A3, A4, A5 --[[needs_compat]] - unpack: function({A1, A2, A3, A4}): A1, A2, A3, A4 --[[needs_compat]] - unpack: function({A1, A2, A3}): A1, A2, A3 --[[needs_compat]] - unpack: function({A1, A2}): A1, A2 --[[needs_compat]] unpack: function({A}, ? number, ? number): A... --[[needs_compat]] + unpack: function({A1, A2}): A1, A2 --[[needs_compat]] + unpack: function({A1, A2, A3}): A1, A2, A3 --[[needs_compat]] + unpack: function({A1, A2, A3, A4}): A1, A2, A3, A4 --[[needs_compat]] + unpack: function({A1, A2, A3, A4, A5}): A1, A2, A3, A4, A5 --[[needs_compat]] end global record utf8 @@ -6833,6 +6833,8 @@ local function add_compat_entries(program, used_set, gen_compat) for _, name in ipairs(used_list) do if name == "table.unpack" then load_code(name, "local _tl_table_unpack = unpack or table.unpack") + elseif name == "table.pack" then + load_code(name, [[local _tl_table_pack = table.pack or function(...) return { n = select("#", ...), ... } end]]) elseif name == "bit32" then load_code(name, "local bit32 = bit32; if not bit32 then local p, m = " .. req("bit32") .. "; if p then bit32 = m end") elseif name == "mt" then @@ -6964,6 +6966,7 @@ tl.new_env = function(opts) local table_t = (stdlib_globals["table"].t).def math_t.fields["maxinteger"].needs_compat = true math_t.fields["mininteger"].needs_compat = true + table_t.fields["pack"].needs_compat = true table_t.fields["unpack"].needs_compat = true diff --git a/tl.tl b/tl.tl index e2d273f7c..cec12f333 100644 --- a/tl.tl +++ b/tl.tl @@ -347,17 +347,17 @@ do move: function({A}, integer, integer, integer, ? {A}): {A} - pack: function(T...): PackTable - pack: function(any...): {any:any} + pack: function(T...): PackTable --[[needs_compat]] + pack: function(any...): {any:any} --[[needs_compat]] remove: function({A}, ? integer): A sort: function({A}, ? SortFunction) - unpack: function({A1, A2, A3, A4, A5}): A1, A2, A3, A4, A5 --[[needs_compat]] - unpack: function({A1, A2, A3, A4}): A1, A2, A3, A4 --[[needs_compat]] - unpack: function({A1, A2, A3}): A1, A2, A3 --[[needs_compat]] - unpack: function({A1, A2}): A1, A2 --[[needs_compat]] unpack: function({A}, ? number, ? number): A... --[[needs_compat]] + unpack: function({A1, A2}): A1, A2 --[[needs_compat]] + unpack: function({A1, A2, A3}): A1, A2, A3 --[[needs_compat]] + unpack: function({A1, A2, A3, A4}): A1, A2, A3, A4 --[[needs_compat]] + unpack: function({A1, A2, A3, A4, A5}): A1, A2, A3, A4, A5 --[[needs_compat]] end global record utf8 @@ -6833,6 +6833,8 @@ local function add_compat_entries(program: Node, used_set: {string: boolean}, ge for _, name in ipairs(used_list) do if name == "table.unpack" then load_code(name, "local _tl_table_unpack = unpack or table.unpack") + elseif name == "table.pack" then + load_code(name, [[local _tl_table_pack = table.pack or function(...) return { n = select("#", ...), ... } end]]) elseif name == "bit32" then load_code(name, "local bit32 = bit32; if not bit32 then local p, m = " .. req("bit32") .. "; if p then bit32 = m end") elseif name == "mt" then @@ -6964,6 +6966,7 @@ tl.new_env = function(opts?: EnvOptions): Env, string local table_t = (stdlib_globals["table"].t as TypeDeclType).def as RecordType math_t.fields["maxinteger"].needs_compat = true math_t.fields["mininteger"].needs_compat = true + table_t.fields["pack"].needs_compat = true table_t.fields["unpack"].needs_compat = true -- only global scope and vararg functions accept `...`: From 7a6a19ffc1b122e7952c3e2623a62c7970e8b743 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Sun, 16 Jun 2024 13:35:44 -0300 Subject: [PATCH 169/224] pragma: introduce --#pragma syntax --- spec/pragma/invalid_spec.lua | 29 +++++++++++++++ tl.lua | 64 +++++++++++++++++++++++++++++++++ tl.tl | 69 ++++++++++++++++++++++++++++++++++++ 3 files changed, 162 insertions(+) create mode 100644 spec/pragma/invalid_spec.lua diff --git a/spec/pragma/invalid_spec.lua b/spec/pragma/invalid_spec.lua new file mode 100644 index 000000000..105770481 --- /dev/null +++ b/spec/pragma/invalid_spec.lua @@ -0,0 +1,29 @@ +local util = require("spec.util") + +describe("invalid pragma", function() + it("rejects invalid pragma", util.check_syntax_error([[ + --#invalid_pragma on + ]], { + { y = 1, msg = "invalid token '--#invalid_pragma'" } + })) + + it("pragmas currently do not accept punctuation", util.check_syntax_error([[ + --#pragma something(other) + ]], { + { y = 1, msg = "invalid token '('" }, + { y = 1, msg = "invalid token ')'" }, + })) + + it("pragma arguments need to be in a single line", util.check_syntax_error([[ + --#pragma arity + on + + local function f(x: integer, y: integer) + print(x + y) + end + + print(f(10)) + ]], { + { msg = "expected pragma value" } + })) +end) diff --git a/tl.lua b/tl.lua index d8e10b3e5..8eaa70524 100644 --- a/tl.lua +++ b/tl.lua @@ -804,6 +804,8 @@ end + + @@ -838,6 +840,9 @@ do + + + @@ -874,6 +879,9 @@ do ["number hexfloat"] = "number", ["number power"] = "number", ["number powersign"] = "$ERR invalid_number$", + ["pragma"] = "pragma", + ["pragma any"] = nil, + ["pragma word"] = "pragma_identifier", } local keywords = { @@ -1267,11 +1275,39 @@ do elseif state == "got --" then if c == "[" then state = "got --[" + elseif c == "#" then + state = "pragma" else fwd = false state = "comment short" drop_token() end + elseif state == "pragma" then + if not lex_word[c] then + end_token_prev("pragma") + if tokens[nt].tk ~= "--#pragma" then + add_syntax_error() + end + fwd = false + state = "pragma any" + end + elseif state == "pragma any" then + if c == "\n" then + state = "any" + elseif lex_word[c] then + state = "pragma word" + begin_token() + elseif not lex_space[c] then + begin_token() + end_token_here("$ERR invalid$") + add_syntax_error() + end + elseif state == "pragma word" then + if not lex_word[c] then + end_token_prev("pragma_identifier") + fwd = false + state = (c == "\n") and "any" or "pragma any" + end elseif state == "got 0" then if c == "x" or c == "X" then state = "number hex" @@ -4220,7 +4256,27 @@ do return parse_function(ps, i, "record") end + local function parse_pragma(ps, i) + i = i + 1 + local pragma = new_node(ps, i, "pragma") + + if ps.tokens[i].kind ~= "pragma_identifier" then + return fail(ps, i, "expected pragma name") + end + pragma.pkey = ps.tokens[i].tk + i = i + 1 + + if ps.tokens[i].kind ~= "pragma_identifier" then + return fail(ps, i, "expected pragma value") + end + pragma.pvalue = ps.tokens[i].tk + i = i + 1 + + return i, pragma + end + local parse_statement_fns = { + ["--#pragma"] = parse_pragma, ["::"] = parse_label, ["do"] = parse_do, ["if"] = parse_if, @@ -4589,6 +4645,7 @@ local no_recurse_node = { ["break"] = true, ["label"] = true, ["number"] = true, + ["pragma"] = true, ["string"] = true, ["boolean"] = true, ["integer"] = true, @@ -5547,6 +5604,8 @@ function tl.pretty_print_ast(ast, gen_target, mode) return out end, }, + ["pragma"] = {}, + ["variable"] = emit_exactly_visitor_cbs, ["identifier"] = emit_exactly_visitor_cbs, @@ -12274,6 +12333,11 @@ self:expand_type(node, values, elements) }) return node.newtype end, }, + ["pragma"] = { + after = function(_self, _node, _children) + return NONE + end, + }, ["error_node"] = { after = function(_self, node, _children) return a_type(node, "invalid", {}) diff --git a/tl.tl b/tl.tl index cec12f333..08baa5db3 100644 --- a/tl.tl +++ b/tl.tl @@ -793,6 +793,8 @@ local enum TokenKind "identifier" "number" "integer" + "pragma" + "pragma_identifier" "$ERR unfinished_comment$" "$ERR invalid_string$" "$ERR invalid_number$" @@ -840,6 +842,9 @@ do "number hexfloat" "number power" "number powersign" + "pragma" + "pragma word" + "pragma any" end local last_token_kind : {LexState:TokenKind} = { @@ -874,6 +879,9 @@ do ["number hexfloat"] = "number", ["number power"] = "number", ["number powersign"] = "$ERR invalid_number$", + ["pragma"] = "pragma", + ["pragma any"] = nil, -- never in a token + ["pragma word"] = "pragma_identifier", -- never in a token } local keywords: {string:boolean} = { @@ -1267,11 +1275,39 @@ do elseif state == "got --" then if c == "[" then state = "got --[" + elseif c == "#" then + state = "pragma" else fwd = false state = "comment short" drop_token() end + elseif state == "pragma" then + if not lex_word[c] then + end_token_prev("pragma") + if tokens[nt].tk ~= "--#pragma" then + add_syntax_error() + end + fwd = false + state = "pragma any" + end + elseif state == "pragma any" then + if c == "\n" then + state = "any" + elseif lex_word[c] then + state = "pragma word" + begin_token() + elseif not lex_space[c] then + begin_token() + end_token_here("$ERR invalid$") + add_syntax_error() + end + elseif state == "pragma word" then + if not lex_word[c] then + end_token_prev("pragma_identifier") + fwd = false + state = (c == "\n") and "any" or "pragma any" + end elseif state == "got 0" then if c == "x" or c == "X" then state = "number hex" @@ -1902,6 +1938,7 @@ local enum NodeKind "macroexp" "local_macroexp" "interface" + "pragma" "error_node" end @@ -2100,6 +2137,10 @@ local record Node itemtype: Type decltuple: TupleType + -- pragma + pkey: string + pvalue: string + opt: boolean debug_type: Type @@ -4220,7 +4261,27 @@ local function parse_record_function(ps: ParseState, i: integer): integer, Node return parse_function(ps, i, "record") end +local function parse_pragma(ps: ParseState, i: integer): integer, Node + i = i + 1 -- skip "--#pragma" + local pragma = new_node(ps, i, "pragma") + + if ps.tokens[i].kind ~= "pragma_identifier" then + return fail(ps, i, "expected pragma name") + end + pragma.pkey = ps.tokens[i].tk + i = i + 1 + + if ps.tokens[i].kind ~= "pragma_identifier" then + return fail(ps, i, "expected pragma value") + end + pragma.pvalue = ps.tokens[i].tk + i = i + 1 + + return i, pragma +end + local parse_statement_fns: {string : function(ParseState, integer):(integer, Node)} = { + ["--#pragma"] = parse_pragma, ["::"] = parse_label, ["do"] = parse_do, ["if"] = parse_if, @@ -4589,6 +4650,7 @@ local no_recurse_node: {NodeKind : boolean} = { ["break"] = true, ["label"] = true, ["number"] = true, + ["pragma"] = true, ["string"] = true, ["boolean"] = true, ["integer"] = true, @@ -5547,6 +5609,8 @@ function tl.pretty_print_ast(ast: Node, gen_target: GenTarget, mode?: boolean | return out end, }, + ["pragma"] = { + }, ["variable"] = emit_exactly_visitor_cbs, ["identifier"] = emit_exactly_visitor_cbs, @@ -12274,6 +12338,11 @@ do return node.newtype end, }, + ["pragma"] = { + after = function(_self: TypeChecker, _node: Node, _children: {Type}): Type + return NONE + end, + }, ["error_node"] = { after = function(_self: TypeChecker, node: Node, _children: {Type}): Type return an_invalid(node) From 1b12d763109eeb074c568618cb78f8fd0d81819f Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Sun, 16 Jun 2024 13:37:35 -0300 Subject: [PATCH 170/224] pragma: arity on/off --- spec/cli/feat_spec.lua | 4 +- spec/pragma/arity_spec.lua | 237 +++++++++++++++++++++++++++++++++++++ tl.lua | 96 +++++++++++---- tl.tl | 91 ++++++++++---- 4 files changed, 384 insertions(+), 44 deletions(-) create mode 100644 spec/pragma/arity_spec.lua diff --git a/spec/cli/feat_spec.lua b/spec/cli/feat_spec.lua index 286033387..7741cc8f3 100644 --- a/spec/cli/feat_spec.lua +++ b/spec/cli/feat_spec.lua @@ -43,8 +43,8 @@ local test_cases = { status = 1, match = { "2 errors:", - ":9:22: wrong number of arguments (given 3, expects 2)", - ":19:22: wrong number of arguments (given 3, expects at least 1 and at most 2)", + ":9:22: wrong number of arguments (given 3, expects at most 2)", + ":19:22: wrong number of arguments (given 3, expects at most 2)", } } } diff --git a/spec/pragma/arity_spec.lua b/spec/pragma/arity_spec.lua new file mode 100644 index 000000000..02cf8ccc7 --- /dev/null +++ b/spec/pragma/arity_spec.lua @@ -0,0 +1,237 @@ +local util = require("spec.util") + +describe("pragma arity", function() + describe("on", function() + it("rejects function calls with missing arguments", util.check_type_error([[ + --#pragma arity on + + local function f(x: integer, y: integer) + print(x + y) + end + + print(f(10)) + ]], { + { msg = "wrong number of arguments (given 1, expects 2)" } + })) + + it("accepts optional arguments", util.check([[ + --#pragma arity on + + local function f(x: integer, y?: integer) + print(x + (y or 20)) + end + + print(f(10)) + ]])) + end) + + describe("off", function() + it("accepts function calls with missing arguments", util.check([[ + --#pragma arity off + + local function f(x: integer, y: integer) + print(x + (y or 20)) + end + + print(f(10)) + ]])) + + it("ignores optional argument annotations", util.check([[ + --#pragma arity off + + local function f(x: integer, y?: integer) + print(x + y) + end + + print(f(10)) + ]])) + end) + + describe("no propagation from required module upwards:", function() + it("on then off, with error in 'on'", function() + util.mock_io(finally, { + ["r.tl"] = [[ + --#pragma arity off + local function f(x: integer, y: integer, z: integer) + print(x + (y or 20)) + end + print(f(10)) + ]] + }) + util.check_type_error([[ + --#pragma arity on + + local function f(x: integer, y: integer) + print(x + y) + end + + print(f(10)) + + local r = require("r") + + local function g(x: integer, y: integer, z: integer, w: integer) + print(x + y) + end + + print(g(10, 20)) + ]], { + { filename = "foo.tl", y = 7, msg = "wrong number of arguments (given 1, expects 2)" }, + { filename = "foo.tl", y = 15, msg = "wrong number of arguments (given 2, expects 4)" }, + })() + end) + + it("on then on, with errors in both", function() + util.mock_io(finally, { + ["r.tl"] = [[ + --#pragma arity on + local function f(x: integer, y: integer, z: integer) + print(x + (y or 20)) + end + print(f(10)) + ]] + }) + util.check_type_error([[ + --#pragma arity on + + local function f(x: integer, y: integer) + print(x + y) + end + + print(f(10)) + + local r = require("r") + + local function g(x: integer, y: integer, z: integer, w: integer) + print(x + y) + end + + print(g(10, 20)) + ]], { + { filename = "r.tl", y = 5, msg = "wrong number of arguments (given 1, expects 3)" }, + { filename = "foo.tl", y = 7, msg = "wrong number of arguments (given 1, expects 2)" }, + { filename = "foo.tl", y = 15, msg = "wrong number of arguments (given 2, expects 4)" }, + })() + end) + + it("off then on, with error in 'on'", function() + util.mock_io(finally, { + ["r.tl"] = [[ + --#pragma arity on + + local function f(x: integer, y: integer) + print(x + y) + end + + print(f(10)) + ]] + }) + util.check_type_error([[ + --#pragma arity off + + local r = require("r") + + local function f(x: integer, y: integer) + print(x + y) + end + + print(f(10)) + ]], { + { y = 7, filename = "r.tl", msg = "wrong number of arguments (given 1, expects 2)" } + })() + end) + end) + + describe("does propagate downwards into required module:", function() + it("can trigger errors in required modules", function() + util.mock_io(finally, { + ["r.tl"] = [[ + local function f(x: integer, y: integer, z: integer) + print(x + (y or 20)) + end + print(f(10)) + + return { + f = f + } + ]] + }) + util.check_type_error([[ + --#pragma arity on + + local function f(x: integer, y: integer) + print(x + y) + end + + print(f(10)) + + local r = require("r") + + local function g(x: integer, y: integer, z: integer, w: integer) + print(x + y) + end + + print(g(10, 20)) + + r.f(10) + ]], { + { filename = "r.tl", y = 4, msg = "wrong number of arguments (given 1, expects 3)" }, + { filename = "foo.tl", y = 7, msg = "wrong number of arguments (given 1, expects 2)" }, + { filename = "foo.tl", y = 15, msg = "wrong number of arguments (given 2, expects 4)" }, + { filename = "foo.tl", y = 17, msg = "wrong number of arguments (given 1, expects 3)" }, + })() + end) + + it("can be used to load modules with different settings", function() + util.mock_io(finally, { + ["r.tl"] = [[ + local function f(x: integer, y: integer, z: integer) + print(x + (y or 20)) + end + print(f(10)) + + return { + f = f + } + ]] + }) + util.check_type_error([[ + --#pragma arity on + + local function f(x: integer, y: integer) + print(x + y) + end + + print(f(10)) + + --#pragma arity off + local r = require("r") + --#pragma arity on + + local function g(x: integer, y: integer, z: integer, w: integer) + print(x + y) + end + + print(g(10, 20)) + + r.f(10) -- no error here! + ]], { + { filename = "foo.tl", y = 7, msg = "wrong number of arguments (given 1, expects 2)" }, + { filename = "foo.tl", y = 17, msg = "wrong number of arguments (given 2, expects 4)" }, + })() + end) + end) + + describe("invalid", function() + it("rejects invalid value", util.check_type_error([[ + --#pragma arity invalid_value + + local function f(x: integer, y?: integer) + print(x + y) + end + + print(f(10)) + ]], { + { y = 1, msg = "invalid value for pragma 'arity': invalid_value" } + })) + end) +end) diff --git a/tl.lua b/tl.lua index 8eaa70524..fafb1a100 100644 --- a/tl.lua +++ b/tl.lua @@ -1956,6 +1956,7 @@ end + local TruthyFact = {} @@ -2136,6 +2137,10 @@ local Node = {ExpectedContext = {}, } + + + + @@ -6830,21 +6835,33 @@ function tl.search_module(module_name, search_dtl) return nil, nil, tried end -local function require_module(w, module_name, feat_lax, env) +local function require_module(w, module_name, opts, env) local mod = env.modules[module_name] if mod then return mod, env.module_filenames[module_name] end local found, fd = tl.search_module(module_name, true) - if found and (feat_lax or found:match("tl$")) then + if found and (opts.feat_lax == "on" or found:match("tl$")) then env.module_filenames[module_name] = found env.modules[module_name] = a_type(w, "typedecl", { def = a_type(w, "circular_require", {}) }) + local save_defaults = env.defaults + local defaults = { + feat_lax = opts.feat_lax or save_defaults.feat_lax, + feat_arity = opts.feat_arity or save_defaults.feat_arity, + gen_compat = opts.gen_compat or save_defaults.gen_compat, + gen_target = opts.gen_target or save_defaults.gen_target, + run_internal_compiler_checks = opts.run_internal_compiler_checks or save_defaults.run_internal_compiler_checks, + } + env.defaults = defaults + local found_result, err = tl.process(found, env, fd) assert(found_result, err) + env.defaults = save_defaults + env.modules[module_name] = found_result.type return found_result.type, found @@ -7050,7 +7067,11 @@ tl.new_env = function(opts) if opts.predefined_modules then for _, name in ipairs(opts.predefined_modules) do - local module_type = require_module(w, name, env.defaults.feat_lax == "on", env) + local tc_opts = { + feat_lax = env.defaults.feat_lax, + feat_arity = env.defaults.feat_arity, + } + local module_type = require_module(w, name, tc_opts, env) if module_type.typename == "invalid" then return nil, string.format("Error: could not predefine module '%s'", name) @@ -7323,9 +7344,15 @@ do local function show_arity(f) local nfargs = #f.args.tuple - return f.min_arity < nfargs and - "at least " .. f.min_arity .. (f.args.is_va and "" or " and at most " .. nfargs) or - tostring(nfargs or 0) + if f.min_arity < nfargs then + if f.min_arity > 0 then + return "at least " .. f.min_arity .. (f.args.is_va and "" or " and at most " .. nfargs) + else + return (f.args.is_va and "any number" or "at most " .. nfargs) + end + else + return tostring(nfargs or 0) + end end local function drop_constant_value(t) @@ -8977,7 +9004,11 @@ a.types[i], b.types[i]), } if self.feat_lax and is_unknown(func) then local unk = func - func = a_function(func, { min_arity = 0, args = a_vararg(func, { unk }), rets = a_vararg(func, { unk }) }) + func = a_function(func, { + min_arity = 0, + args = a_vararg(func, { unk }), + rets = a_vararg(func, { unk }), + }) end func = self:to_structural(func) @@ -9620,9 +9651,9 @@ a.types[i], b.types[i]), } end end - function TypeChecker:add_function_definition_for_recursion(node, fnargs) + function TypeChecker:add_function_definition_for_recursion(node, fnargs, feat_arity) self:add_var(nil, node.name.tk, a_function(node, { - min_arity = node.min_arity, + min_arity = feat_arity and node.min_arity or 0, typeargs = node.typeargs, args = fnargs, rets = self.get_rets(node.rets), @@ -10340,7 +10371,7 @@ a.types[i], b.types[i]), } local arg2 = node.e2[2] local msgh = table.remove(b.tuple, 1) local msgh_type = a_function(arg2, { - min_arity = 1, + min_arity = self.feat_arity and 1 or 0, args = a_type(arg2, "tuple", { tuple = { a_type(arg2, "any", {}) } }), rets = a_type(arg2, "tuple", { tuple = {} }), }) @@ -10428,7 +10459,11 @@ a.types[i], b.types[i]), } end local module_name = assert(node.e2[1].conststr) - local t, module_filename = require_module(node, module_name, self.feat_lax, self.env) + local tc_opts = { + feat_lax = self.feat_lax and "on" or "off", + feat_arity = self.feat_arity and "on" or "off", + } + local t, module_filename = require_module(node, module_name, tc_opts, self.env) if t.typename == "invalid" then if not module_filename then @@ -11587,7 +11622,7 @@ self:expand_type(node, values, elements) }) assert(args.typename == "tuple") self:add_internal_function_variables(node, args) - self:add_function_definition_for_recursion(node, args) + self:add_function_definition_for_recursion(node, args, self.feat_arity) end, after = function(self, node, children) local args = children[2] @@ -11598,7 +11633,7 @@ self:expand_type(node, values, elements) }) self:end_function_scope(node) local t = self:ensure_fresh_typeargs(a_function(node, { - min_arity = node.min_arity, + min_arity = self.feat_arity and node.min_arity or 0, typeargs = node.typeargs, args = args, rets = self.get_rets(rets), @@ -11627,7 +11662,7 @@ self:expand_type(node, values, elements) }) self:check_macroexp_arg_use(node.macrodef) local t = self:ensure_fresh_typeargs(a_function(node, { - min_arity = node.macrodef.min_arity, + min_arity = self.feat_arity and node.macrodef.min_arity or 0, typeargs = node.typeargs, args = args, rets = self.get_rets(rets), @@ -11660,7 +11695,7 @@ self:expand_type(node, values, elements) }) assert(args.typename == "tuple") self:add_internal_function_variables(node, args) - self:add_function_definition_for_recursion(node, args) + self:add_function_definition_for_recursion(node, args, self.feat_arity) end, after = function(self, node, children) local args = children[2] @@ -11674,7 +11709,7 @@ self:expand_type(node, values, elements) }) end self:add_global(node, node.name.tk, self:ensure_fresh_typeargs(a_function(node, { - min_arity = node.min_arity, + min_arity = self.feat_arity and node.min_arity or 0, typeargs = node.typeargs, args = args, rets = self.get_rets(rets), @@ -11745,7 +11780,7 @@ self:expand_type(node, values, elements) }) end local fn_type = self:ensure_fresh_typeargs(a_function(node, { - min_arity = node.min_arity, + min_arity = self.feat_arity and node.min_arity or 0, is_method = node.is_method, typeargs = node.typeargs, args = args, @@ -11819,7 +11854,7 @@ self:expand_type(node, values, elements) }) self:end_function_scope(node) return self:ensure_fresh_typeargs(a_function(node, { - min_arity = node.min_arity, + min_arity = self.feat_arity and node.min_arity or 0, typeargs = node.typeargs, args = args, rets = self.get_rets(rets), @@ -11845,7 +11880,7 @@ self:expand_type(node, values, elements) }) self:end_function_scope(node) return self:ensure_fresh_typeargs(a_function(node, { - min_arity = node.min_arity, + min_arity = self.feat_arity and node.min_arity or 0, typeargs = node.typeargs, args = args, rets = rets, @@ -12334,7 +12369,18 @@ self:expand_type(node, values, elements) }) end, }, ["pragma"] = { - after = function(_self, _node, _children) + after = function(self, node, _children) + if node.pkey == "arity" then + if node.pvalue == "on" then + self.feat_arity = true + elseif node.pvalue == "off" then + self.feat_arity = false + else + return self.errs:invalid_at(node, "invalid value for pragma 'arity': " .. node.pvalue) + end + else + return self.errs:invalid_at(node, "invalid pragma: " .. node.pkey) + end return NONE end, }, @@ -12525,6 +12571,15 @@ self:expand_type(node, values, elements) }) local visit_type visit_type = { cbs = { + ["function"] = { + before = visit_type_with_typeargs.before, + after = function(self, typ, children) + if self.feat_arity == false then + typ.min_arity = 0 + end + return visit_type_with_typeargs.after(self, typ, children) + end, + }, ["record"] = { before = function(self, typ) self:begin_scope() @@ -12693,7 +12748,6 @@ self:expand_type(node, values, elements) }) visit_type.cbs["interface"] = visit_type.cbs["record"] - visit_type.cbs["function"] = visit_type_with_typeargs visit_type.cbs["typedecl"] = visit_type_with_typeargs visit_type.cbs["typealias"] = visit_type_with_typeargs diff --git a/tl.tl b/tl.tl index 08baa5db3..5f150c924 100644 --- a/tl.tl +++ b/tl.tl @@ -6835,21 +6835,33 @@ function tl.search_module(module_name: string, search_dtl: boolean): string, FIL return nil, nil, tried end -local function require_module(w: Where, module_name: string, feat_lax: boolean, env: Env): Type, string +local function require_module(w: Where, module_name: string, opts: TypeCheckOptions, env: Env): Type, string local mod = env.modules[module_name] if mod then return mod, env.module_filenames[module_name] end local found, fd = tl.search_module(module_name, true) - if found and (feat_lax or found:match("tl$") as boolean) then + if found and (opts.feat_lax == "on" or found:match("tl$") as boolean) then env.module_filenames[module_name] = found env.modules[module_name] = a_typedecl(w, a_type(w, "circular_require", {})) + local save_defaults = env.defaults + local defaults : TypeCheckOptions = { + feat_lax = opts.feat_lax or save_defaults.feat_lax, + feat_arity = opts.feat_arity or save_defaults.feat_arity, + gen_compat = opts.gen_compat or save_defaults.gen_compat, + gen_target = opts.gen_target or save_defaults.gen_target, + run_internal_compiler_checks = opts.run_internal_compiler_checks or save_defaults.run_internal_compiler_checks, + } + env.defaults = defaults + local found_result, err: Result, string = tl.process(found, env, fd) assert(found_result, err) + env.defaults = save_defaults + env.modules[module_name] = found_result.type return found_result.type, found @@ -7055,7 +7067,11 @@ tl.new_env = function(opts?: EnvOptions): Env, string if opts.predefined_modules then for _, name in ipairs(opts.predefined_modules) do - local module_type = require_module(w, name, env.defaults.feat_lax == "on", env) + local tc_opts = { + feat_lax = env.defaults.feat_lax, + feat_arity = env.defaults.feat_arity, + } + local module_type = require_module(w, name, tc_opts, env) if module_type is InvalidType then return nil, string.format("Error: could not predefine module '%s'", name) @@ -7328,9 +7344,15 @@ do local function show_arity(f: FunctionType): string local nfargs = #f.args.tuple - return f.min_arity < nfargs - and "at least " .. f.min_arity .. (f.args.is_va and "" or " and at most " .. nfargs) - or tostring(nfargs or 0) + if f.min_arity < nfargs then + if f.min_arity > 0 then + return "at least " .. f.min_arity .. (f.args.is_va and "" or " and at most " .. nfargs) + else + return (f.args.is_va and "any number" or "at most " .. nfargs) + end + else + return tostring(nfargs or 0) + end end local function drop_constant_value(t: Type): Type @@ -8982,7 +9004,11 @@ do -- resolve unknown in lax mode, produce a general unknown function if self.feat_lax and is_unknown(func) then local unk = func - func = a_function(func, { min_arity = 0, args = a_vararg(func, { unk }), rets = a_vararg(func, { unk }) }) + func = a_function(func, { + min_arity = 0, + args = a_vararg(func, { unk }), + rets = a_vararg(func, { unk }) + }) end -- unwrap if tuple, resolve if nominal func = self:to_structural(func) @@ -9625,9 +9651,9 @@ do end end - function TypeChecker:add_function_definition_for_recursion(node: Node, fnargs: TupleType) + function TypeChecker:add_function_definition_for_recursion(node: Node, fnargs: TupleType, feat_arity: boolean) self:add_var(nil, node.name.tk, a_function(node, { - min_arity = node.min_arity, + min_arity = feat_arity and node.min_arity or 0, typeargs = node.typeargs, args = fnargs, rets = self.get_rets(node.rets), @@ -10345,7 +10371,7 @@ do local arg2 = node.e2[2] local msgh = table.remove(b.tuple, 1) local msgh_type = a_function(arg2, { - min_arity = 1, + min_arity = self.feat_arity and 1 or 0, args = a_tuple(arg2, { a_type(arg2, "any", {}) }), rets = a_tuple(arg2, {}) }) @@ -10433,7 +10459,11 @@ do end local module_name = assert(node.e2[1].conststr) - local t, module_filename = require_module(node, module_name, self.feat_lax, self.env) + local tc_opts: TypeCheckOptions = { + feat_lax = self.feat_lax and "on" or "off", + feat_arity = self.feat_arity and "on" or "off", + } + local t, module_filename = require_module(node, module_name, tc_opts, self.env) if t.typename == "invalid" then if not module_filename then @@ -11592,7 +11622,7 @@ do assert(args is TupleType) self:add_internal_function_variables(node, args) - self:add_function_definition_for_recursion(node, args) + self:add_function_definition_for_recursion(node, args, self.feat_arity) end, after = function(self: TypeChecker, node: Node, children: {Type}): Type local args = children[2] @@ -11603,7 +11633,7 @@ do self:end_function_scope(node) local t = self:ensure_fresh_typeargs(a_function(node, { - min_arity = node.min_arity, + min_arity = self.feat_arity and node.min_arity or 0, typeargs = node.typeargs, args = args, rets = self.get_rets(rets), @@ -11632,7 +11662,7 @@ do self:check_macroexp_arg_use(node.macrodef) local t = self:ensure_fresh_typeargs(a_function(node, { - min_arity = node.macrodef.min_arity, + min_arity = self.feat_arity and node.macrodef.min_arity or 0, typeargs = node.typeargs, args = args, rets = self.get_rets(rets), @@ -11665,7 +11695,7 @@ do assert(args is TupleType) self:add_internal_function_variables(node, args) - self:add_function_definition_for_recursion(node, args) + self:add_function_definition_for_recursion(node, args, self.feat_arity) end, after = function(self: TypeChecker, node: Node, children: {Type}): Type local args = children[2] @@ -11679,7 +11709,7 @@ do end self:add_global(node, node.name.tk, self:ensure_fresh_typeargs(a_function(node, { - min_arity = node.min_arity, + min_arity = self.feat_arity and node.min_arity or 0, typeargs = node.typeargs, args = args, rets = self.get_rets(rets), @@ -11750,7 +11780,7 @@ do end local fn_type = self:ensure_fresh_typeargs(a_function(node, { - min_arity = node.min_arity, + min_arity = self.feat_arity and node.min_arity or 0, is_method = node.is_method, typeargs = node.typeargs, args = args, @@ -11824,7 +11854,7 @@ do self:end_function_scope(node) return self:ensure_fresh_typeargs(a_function(node, { - min_arity = node.min_arity, + min_arity = self.feat_arity and node.min_arity or 0, typeargs = node.typeargs, args = args, rets = self.get_rets(rets), @@ -11850,7 +11880,7 @@ do self:end_function_scope(node) return self:ensure_fresh_typeargs(a_function(node, { - min_arity = node.min_arity, + min_arity = self.feat_arity and node.min_arity or 0, typeargs = node.typeargs, args = args, rets = rets, @@ -12339,7 +12369,18 @@ do end, }, ["pragma"] = { - after = function(_self: TypeChecker, _node: Node, _children: {Type}): Type + after = function(self: TypeChecker, node: Node, _children: {Type}): Type + if node.pkey == "arity" then + if node.pvalue == "on" then + self.feat_arity = true + elseif node.pvalue == "off" then + self.feat_arity = false + else + return self.errs:invalid_at(node, "invalid value for pragma 'arity': " .. node.pvalue) + end + else + return self.errs:invalid_at(node, "invalid pragma: " .. node.pkey) + end return NONE end, }, @@ -12530,6 +12571,15 @@ do local visit_type: Visitor visit_type = { cbs = { + ["function"] = { + before = visit_type_with_typeargs.before, + after = function(self: TypeChecker, typ: FunctionType, children: {Type}): Type + if self.feat_arity == false then + typ.min_arity = 0 + end + return visit_type_with_typeargs.after(self, typ, children) + end + }, ["record"] = { before = function(self: TypeChecker, typ: RecordType) self:begin_scope() @@ -12698,7 +12748,6 @@ do visit_type.cbs["interface"] = visit_type.cbs["record"] - visit_type.cbs["function"] = visit_type_with_typeargs visit_type.cbs["typedecl"] = visit_type_with_typeargs visit_type.cbs["typealias"] = visit_type_with_typeargs From e5ee053aef12e16f0daac1503d655cbc3cd9429b Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Mon, 2 Sep 2024 17:46:06 +0000 Subject: [PATCH 171/224] docs: pragmas (#798) Proofreading by Thijs Schreijer Co-authored-by: Thijs Schreijer --- docs/pragmas.md | 134 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 134 insertions(+) create mode 100644 docs/pragmas.md diff --git a/docs/pragmas.md b/docs/pragmas.md new file mode 100644 index 000000000..b33efb4c2 --- /dev/null +++ b/docs/pragmas.md @@ -0,0 +1,134 @@ +# Pragmas + +Teal is evolving as a language. Sometimes we need to add incompatible changes +to the language, but we don't want to break everybody's code at once. The way +to deal with this is by adding _pragmatic annotations_ (typically known in +compiler lingo as "pragmas") that tell the compiler about how to interpret +various minutiae of the language, in practice picking which "dialect" of the +language to use. This lets the programmer pedal back on certain language +changes and adopt them gradually as the existing codebase is converted to the +new version. + +Let's look at a concrete example where pragmas can help us: function arity +checks. + +## Function arity checks + +If you're coming from an older version of Teal, it is possible that you will +start getting lots of errors related to numbers of arguments, such as: + +``` +wrong number of arguments (given 2, expects 4) +``` + +This is because, up to Teal 0.15.x, the language was lenient on the _arity_ of +function calls (the number of expressions passed as arguments in the call). It +would just assume that any missing arguments were intended to be `nil` on +purpose. More often than not, this is not the case, and a missing argument +does not mean that the argument was optional, but rather that the programmer +forgot about it (this is common when adding new arguments during a code +refactor). + +Teal now features _optional function arguments_. if an argument can be +optionally elided, you now can, or rather, have to, annotate it explicitly +adding a `?` to its name: + +```lua +local function greet(greeting: string, name?: string) + if name then + print(string.format("%s, %s!", greeting, name)) + else + print(greeting .. "!") + end +end + +greet("Hello", "Teal") --> Hello, Teal! +greet("Hello") --> Hello! +greet() --> compile error: wrong number of arguments (given 0, expects at least 1 and at most 2) +``` + +However, there are many Teal libraries out there (and Lua libraries for which +[.d.tl type declaration files](declaration_files.md) were written), which were +prepared for earlier versions of Teal. + +The good news is that you don't have to convert all of them at once, neither +you have to make an all-or-nothing choice whether to have or not those +function arity checks. + +You can enable or disable arity checks using the `arity` pragma. Let's first +assume we have an old library written for older versions of Teal: + +```lua +-- old_library.tl +local record old_library +end + +-- no `?` annotations here, but `name` is an optional argument +function old_library.greet(greeting: string, name: string) + if name then + print(string.format("%s, %s!", greeting, name)) + else + print(greeting .. "!") + end +end + +return old_library +``` + +Now we want to use this library with the current version of Teal, but we don't +want to lose arity checks in our own code. We can temporarily disable arity +checks, require the library, then re-enable them: + +```lua +--#pragma arity off +local old_library = require("old_library) +--#pragma arity on + +local function add(a: number, b: number): number + return a + b +end + +print(add(1)) -- compile error: wrong number of arguments (given 1, expects 2) + +old_library.greet("Hello", "Teal") --> Hello, Teal! + +-- no compile error here, because in code loaded with `arity off`, +-- every argument is optional: +old_library.greet("Hello") --> Hello! + +-- no compile error here as well, +-- even though this call will crash at runtime: +old_library.greet() --> runtime error: attempt to concatenate a nil value (local 'greeting') +``` + +The `arity` pragma was introduced as a way to gradually convert codebases, as +opposed to the wholesale approach of passing `--feat-arity=off` to the +compiler command-line or setting `feat_arity = "off"` in `tlconfig.lua`, the +[compiler options](compiler_options.md) file. + +### Optional arities versus optional values + +Note that arity checks are about the number of _expressions_ used as arguments +in function calls: it does not check whether the _values_ are `nil` or not. +In the above example, even with arity check enabled, you could still write +`greet(nil, nil)` and that would be accepted by the compiler as valid, +even though it would crash at runtime. + +Explicit checking for `nil` is a separate feature, which may be added in a +future version of Teal. When that happens, we will definitely need a `pragma` +to allow for gradual adoption of it! + +## What pragmas are not + +One final word about pragmas: there is no well-established definition for a +"compiler pragma" in the literature, even though this is a common term. + +It's important to clarify here that Teal pragmas are not intended as +general-purpose annotations (the kind of things you usually see with `@-` +syntax in various other languages such as C#, Java or `#[]` in Rust). Pragmas +here are intended as compiler directives, more akin to compiler flags (e.g. +the `#pragma` use in C compilers). + +In short, our practical goal for pragmas is to allow for handling +compatibility issues when dealing with the language evolution. That is, in a +Teal codebase with no legacy concerns, there should be no pragmas. From 67193336f71b713a5d1e82902a1e62b5c56e8296 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Mon, 2 Sep 2024 17:46:45 +0000 Subject: [PATCH 172/224] docs: explain type aliasing (#799) --- docs/aliasing.md | 145 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 145 insertions(+) create mode 100644 docs/aliasing.md diff --git a/docs/aliasing.md b/docs/aliasing.md new file mode 100644 index 000000000..33c840745 --- /dev/null +++ b/docs/aliasing.md @@ -0,0 +1,145 @@ +# Type aliasing rules in Teal + +## The general rule + +In Teal we can declare new types with user-defined names. These are called +_nominal types_. These nominal types may be unique, or aliases. + +The `local type` syntax produces a new _nominal type_. Whenever you assign to +it another user-defined nominal type, it becomes a _type alias_. Whenever you +assign to it a type constructor, it becomes a new unique type. Type +constructors are syntax constructs such as: block constructors for records, +interfaces and enums (e.g. `record` ... `end`); function signature +declarations with `function()`; applications of generics with `<>`-notation; +declarations of array, tuple or map types with `{}`-notation; or a primitive +type name such as `number`. + +Syntax such as `local record R` is a shorthand to `local type R = record`, so +the same rules apply: it declares a new unique type. + +Nominal types are compared against each other _by name_, but type aliases are +considered to be equivalent. + +```lua +local record Point3D + x: number + y: number + z: number +end + +local record Vector3D + x: number + y: number + z: number +end + +local p: Point3D = { x = 1.0, y = 0.3, z = 2.5 } + +local v: Vector3D = p -- Teal compile error: Point3D is not a Vector3D + +local type P3D = Point3D + +local p2: P3D + +p2 = p -- ok! P3D is a type alias type Point3D +p = p2 -- ok! aliasing works both ways: they are effectively the same type +``` + +Nominal types are compared against non-nominal types _by structure_, so that +you can manipulate concrete values, which have inferred types. For example, +you can assign a plain function to a nominal function type, as long as the +signatures are compatible, and you can assign a number literal to a nominal +number type. + +```lua +local type MyFunction = function(number): string + +-- f has a nominal type +local f: MyFunction + +-- g is inferred a structural type: function(number): string +local g = function(n: number): string + return tostring(n) +end + +f = g -- ok! structural matched against nominal +g = f -- ok! nominal matched against structural +``` + +You can declare structural types for functions explicitly: + +```lua +local type MyFunction = function(number): string + +-- f has a nominal type +local f: MyFunction + +-- h was explicitly given a structural function type +local h: function(n: number): string + +f = h -- ok! +h = f -- ok! +``` + +By design, there is no syntax in Teal for declaring structural record types. + +## Some examples + +Type aliasing only happens when declaring a new user-defined nominal type +using an existing user-defined nominal type. + +```lua +local type Record1 = record + x: integer + y: integer +end + +local type Record2 = Record1 + +local r1: Record1 +assert(r1 is Record2) -- ok! +``` + +This does not apply to primitive types. Declaring a type name with the same +primitive type as a previous declaration is not an aliasing operation. This +allows you to create types based on primitive types which are distinct from +each other. + +```lua +local type Temperature = number + +local type TemperatureAlias = Temperature + +local type Width = number + +local temp: Temperature + +assert(temp is TemperatureAlias) -- ok! +assert(temp is Width) -- Teal compile error: temp (of type Temperature) can never be a Width +``` + +Like records, each declaration of a function type in the program source code +represents a distinct type. The `function(...):...` syntax for type +declaration is a type constructor. + +```lua +local type Function1 = function(number): string + +local type Function2 = function(number): string + +local f1: Function1 + +assert(f1 is Function2) -- Teal compile error: f1 (of type Function2) can never be a Function1 +``` + +However, user-defined nominal names referencing those function types can be +aliased. + +```lua +local type Function1 = function(number): string + +local type Function3 = Function1 + +local f1: Function1 +assert(f1 is Function3) -- ok! +``` From a84ac183040bdc2aaa407da17bcf0be372d8759e Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Sat, 31 Aug 2024 22:04:38 -0300 Subject: [PATCH 173/224] fix: resolve type aliases as early as possible --- spec/subtyping/typealias_spec.lua | 35 +++++++++++++++++++++++++++++++ tl.lua | 4 +++- tl.tl | 4 +++- 3 files changed, 41 insertions(+), 2 deletions(-) create mode 100644 spec/subtyping/typealias_spec.lua diff --git a/spec/subtyping/typealias_spec.lua b/spec/subtyping/typealias_spec.lua new file mode 100644 index 000000000..71fefa9a0 --- /dev/null +++ b/spec/subtyping/typealias_spec.lua @@ -0,0 +1,35 @@ +local util = require("spec.util") + +describe("typealias", function() + it("nested type aliases match", util.check([[ + local record R + enum E + end + + type E2 = E + end + + function R.f(_use_type: R.E) + end + + function R.g(use_alias: R.E2) + R.f(use_alias) + end + ]])) + + it("resolves early, works with unions", util.check([[ + local record R + record P + x: integer + end + + type Z = P + end + + function R.f(a: boolean | R.Z) + if a is R.Z then + print("hello") + end + end + ]])) +end) diff --git a/tl.lua b/tl.lua index fafb1a100..4622d5800 100644 --- a/tl.lua +++ b/tl.lua @@ -6774,7 +6774,7 @@ local function show_type_base(t, short, seen) elseif t.typename == "none" then return "" elseif t.typename == "typealias" then - return "type " .. show(t.alias_to) + return "type alias to " .. show(t.alias_to) elseif t.typename == "typedecl" then return "type " .. show(t.def) else @@ -12713,6 +12713,8 @@ self:expand_type(node, values, elements) }) local tv = typ tv.typevar = t.typearg tv.constraint = t.constraint + elseif t.typename == "typealias" then + typ.found = t.alias_to.found elseif t.typename == "typedecl" then if t.def.typename ~= "circular_require" then typ.found = t diff --git a/tl.tl b/tl.tl index 5f150c924..158bfccd7 100644 --- a/tl.tl +++ b/tl.tl @@ -6774,7 +6774,7 @@ local function show_type_base(t: Type, short: boolean, seen: {Type:string}): str elseif t.typename == "none" then return "" elseif t is TypeAliasType then - return "type " .. show(t.alias_to) + return "type alias to " .. show(t.alias_to) elseif t is TypeDeclType then return "type " .. show(t.def) else @@ -12713,6 +12713,8 @@ do local tv = typ as TypeVarType tv.typevar = t.typearg tv.constraint = t.constraint + elseif t is TypeAliasType then + typ.found = t.alias_to.found elseif t is TypeDeclType then if t.def.typename ~= "circular_require" then typ.found = t From dc15bc4fe0651903c9bd2246efdc5a4c73b445e7 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Mon, 2 Sep 2024 11:27:06 -0300 Subject: [PATCH 174/224] lexer: # is an operator --- tl.lua | 4 ++-- tl.tl | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tl.lua b/tl.lua index 4622d5800..eb16f19f4 100644 --- a/tl.lua +++ b/tl.lua @@ -965,11 +965,11 @@ do end local lex_any_char_kinds = {} - local single_char_kinds = { "[", "]", "(", ")", "{", "}", ",", "#", ";", "?" } + local single_char_kinds = { "[", "]", "(", ")", "{", "}", ",", ";", "?" } for _, c in ipairs(single_char_kinds) do lex_any_char_kinds[c] = c end - for _, c in ipairs({ "+", "*", "|", "&", "%", "^" }) do + for _, c in ipairs({ "#", "+", "*", "|", "&", "%", "^" }) do lex_any_char_kinds[c] = "op" end diff --git a/tl.tl b/tl.tl index 158bfccd7..a1cde4cee 100644 --- a/tl.tl +++ b/tl.tl @@ -787,7 +787,7 @@ local enum TokenKind "keyword" "op" "string" - "[" "]" "(" ")" "{" "}" "," ":" "#" "." ";" "?" + "[" "]" "(" ")" "{" "}" "," ":" "." ";" "?" "::" "..." "identifier" @@ -965,11 +965,11 @@ do end local lex_any_char_kinds: {string:TokenKind} = {} - local single_char_kinds: {TokenKind} = {"[", "]", "(", ")", "{", "}", ",", "#", ";", "?"} + local single_char_kinds: {TokenKind} = {"[", "]", "(", ")", "{", "}", ",", ";", "?"} for _, c in ipairs(single_char_kinds) do lex_any_char_kinds[c] = c end - for _, c in ipairs({"+", "*", "|", "&", "%", "^"}) do + for _, c in ipairs({"#", "+", "*", "|", "&", "%", "^"}) do lex_any_char_kinds[c] = "op" end From 1dd8f9e1ce8dbc36ba26d25305af3a03dd252854 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Sat, 31 Aug 2024 23:38:05 -0300 Subject: [PATCH 175/224] optional arity: lax arg check in functions with feat-arity off This might be too lax, but it does revert back to accepting assignments that were valid in Teal 0.15. --- tl.lua | 2 +- tl.tl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tl.lua b/tl.lua index eb16f19f4..15fe568a1 100644 --- a/tl.lua +++ b/tl.lua @@ -8738,7 +8738,7 @@ a.types[i], b.types[i]), } local errs = {} local aa, ba = a.args.tuple, b.args.tuple - if (not b.args.is_va) and a.min_arity > b.min_arity then + if (not b.args.is_va) and (self.feat_arity and a.min_arity > b.min_arity) then table.insert(errs, Err("incompatible number of arguments: got " .. show_arity(a) .. " %s, expected " .. show_arity(b) .. " %s", a.args, b.args)) else for i = ((a.is_method or b.is_method) and 2 or 1), #aa do diff --git a/tl.tl b/tl.tl index a1cde4cee..982f280d8 100644 --- a/tl.tl +++ b/tl.tl @@ -8738,7 +8738,7 @@ do local errs = {} local aa, ba = a.args.tuple, b.args.tuple - if (not b.args.is_va) and a.min_arity > b.min_arity then + if (not b.args.is_va) and (self.feat_arity and a.min_arity > b.min_arity) then table.insert(errs, Err("incompatible number of arguments: got " .. show_arity(a) .. " %s, expected " .. show_arity(b) .. " %s", a.args, b.args)) else for i = ((a.is_method or b.is_method) and 2 or 1), #aa do From e49ac8b7a2571a9960eee212e3dd3c167972cfc7 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Sat, 31 Aug 2024 22:34:12 -0300 Subject: [PATCH 176/224] fix: do not generate Lua table for interfaces --- tl.lua | 2 +- tl.tl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tl.lua b/tl.lua index 15fe568a1..60d0fda18 100644 --- a/tl.lua +++ b/tl.lua @@ -5075,7 +5075,7 @@ function tl.pretty_print_ast(ast, gen_target, mode) for fname, ftype in fields_of(typ) do if ftype.typename == "typedecl" then local def = ftype.def - if def.fields then + if def.typename == "record" then table.insert(out, fname) table.insert(out, " = ") table.insert(out, print_record_def(def)) diff --git a/tl.tl b/tl.tl index 982f280d8..5462194a1 100644 --- a/tl.tl +++ b/tl.tl @@ -5075,7 +5075,7 @@ function tl.pretty_print_ast(ast: Node, gen_target: GenTarget, mode?: boolean | for fname, ftype in fields_of(typ) do if ftype is TypeDeclType then local def = ftype.def - if def is RecordLikeType then + if def is RecordType then table.insert(out, fname) table.insert(out, " = ") table.insert(out, print_record_def(def)) From d04fa0f6b69c5379952f7ec885366c74f9c36594 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Mon, 2 Sep 2024 11:53:13 -0300 Subject: [PATCH 177/224] pragma: ignores other --# lines --- spec/pragma/invalid_spec.lua | 14 +++++++++----- tl.lua | 12 ++++++++---- tl.tl | 12 ++++++++---- 3 files changed, 25 insertions(+), 13 deletions(-) diff --git a/spec/pragma/invalid_spec.lua b/spec/pragma/invalid_spec.lua index 105770481..e6dadeff2 100644 --- a/spec/pragma/invalid_spec.lua +++ b/spec/pragma/invalid_spec.lua @@ -1,17 +1,21 @@ local util = require("spec.util") describe("invalid pragma", function() - it("rejects invalid pragma", util.check_syntax_error([[ - --#invalid_pragma on + it("ignores other --# lines", util.check([[ + --#invalid on + ]])) + + it("rejects invalid pragma", util.check_type_error([[ + --#pragma invalid_foo on ]], { - { y = 1, msg = "invalid token '--#invalid_pragma'" } + { y = 1, msg = "invalid pragma: invalid_foo" } })) it("pragmas currently do not accept punctuation", util.check_syntax_error([[ --#pragma something(other) ]], { - { y = 1, msg = "invalid token '('" }, - { y = 1, msg = "invalid token ')'" }, + { y = 1, x = 26, msg = "invalid token '('" }, + { y = 1, x = 32, msg = "invalid token ')'" }, })) it("pragma arguments need to be in a single line", util.check_syntax_error([[ diff --git a/tl.lua b/tl.lua index 60d0fda18..b9b628483 100644 --- a/tl.lua +++ b/tl.lua @@ -879,7 +879,7 @@ do ["number hexfloat"] = "number", ["number power"] = "number", ["number powersign"] = "$ERR invalid_number$", - ["pragma"] = "pragma", + ["pragma"] = nil, ["pragma any"] = nil, ["pragma word"] = "pragma_identifier", } @@ -1285,11 +1285,15 @@ do elseif state == "pragma" then if not lex_word[c] then end_token_prev("pragma") - if tokens[nt].tk ~= "--#pragma" then - add_syntax_error() + if tokens[nt].tk == "--#pragma" then + state = "pragma any" + else + state = "comment short" + table.remove(tokens) + nt = nt - 1 + drop_token() end fwd = false - state = "pragma any" end elseif state == "pragma any" then if c == "\n" then diff --git a/tl.tl b/tl.tl index 5462194a1..7ee6d1f82 100644 --- a/tl.tl +++ b/tl.tl @@ -879,7 +879,7 @@ do ["number hexfloat"] = "number", ["number power"] = "number", ["number powersign"] = "$ERR invalid_number$", - ["pragma"] = "pragma", + ["pragma"] = nil, -- drop comment ["pragma any"] = nil, -- never in a token ["pragma word"] = "pragma_identifier", -- never in a token } @@ -1285,11 +1285,15 @@ do elseif state == "pragma" then if not lex_word[c] then end_token_prev("pragma") - if tokens[nt].tk ~= "--#pragma" then - add_syntax_error() + if tokens[nt].tk == "--#pragma" then + state = "pragma any" + else + state = "comment short" + table.remove(tokens) + nt = nt - 1 + drop_token() end fwd = false - state = "pragma any" end elseif state == "pragma any" then if c == "\n" then From 3cc78b6443fb4f4a8d3521b16aff44546f5abace Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Mon, 2 Sep 2024 11:57:13 -0300 Subject: [PATCH 178/224] lexer: simplify TokenKind --- tl.lua | 70 ++++++++++++++++++++++++++------------------------------ tl.tl | 72 +++++++++++++++++++++++++++------------------------------- 2 files changed, 65 insertions(+), 77 deletions(-) diff --git a/tl.lua b/tl.lua index b9b628483..f2f12540b 100644 --- a/tl.lua +++ b/tl.lua @@ -802,9 +802,6 @@ end - - - @@ -864,21 +861,21 @@ do ["got /"] = "op", ["got :"] = "op", ["got --["] = nil, - ["string single"] = "$ERR invalid_string$", - ["string single got \\"] = "$ERR invalid_string$", - ["string double"] = "$ERR invalid_string$", - ["string double got \\"] = "$ERR invalid_string$", - ["string long"] = "$ERR invalid_string$", - ["string long got ]"] = "$ERR invalid_string$", + ["string single"] = "$ERR$", + ["string single got \\"] = "$ERR$", + ["string double"] = "$ERR$", + ["string double got \\"] = "$ERR$", + ["string long"] = "$ERR$", + ["string long got ]"] = "$ERR$", ["comment short"] = nil, - ["comment long"] = "$ERR unfinished_comment$", - ["comment long got ]"] = "$ERR unfinished_comment$", + ["comment long"] = "$ERR$", + ["comment long got ]"] = "$ERR$", ["number dec"] = "integer", ["number decfloat"] = "number", ["number hex"] = "integer", ["number hexfloat"] = "number", ["number power"] = "number", - ["number powersign"] = "$ERR invalid_number$", + ["number powersign"] = "$ERR$", ["pragma"] = nil, ["pragma any"] = nil, ["pragma word"] = "pragma_identifier", @@ -1104,23 +1101,13 @@ do in_token = false end - local function add_syntax_error() + local function add_syntax_error(msg) local t = tokens[nt] - local msg - if t.kind == "$ERR invalid_string$" then - msg = "malformed string" - elseif t.kind == "$ERR invalid_number$" then - msg = "malformed number" - elseif t.kind == "$ERR unfinished_comment$" then - msg = "unfinished long comment" - else - msg = "invalid token '" .. t.tk .. "'" - end table.insert(errs, { filename = filename, y = t.y, x = t.x, - msg = msg, + msg = msg or "invalid token '" .. t.tk .. "'", }) end @@ -1170,7 +1157,7 @@ do end_token(k, c) elseif not lex_space[c] then begin_token() - end_token_here("$ERR invalid$") + end_token_here("$ERR$") add_syntax_error() end end @@ -1303,7 +1290,7 @@ do begin_token() elseif not lex_space[c] then begin_token() - end_token_here("$ERR invalid$") + end_token_here("$ERR$") add_syntax_error() end elseif state == "pragma word" then @@ -1357,8 +1344,8 @@ do local skip, valid = lex_string_escape(input, i, c) i = i + skip if not valid then - end_token_here("$ERR invalid_string$") - add_syntax_error() + end_token_here("$ERR$") + add_syntax_error("malformed string") end x = x + skip state = "string double" @@ -1373,8 +1360,8 @@ do local skip, valid = lex_string_escape(input, i, c) i = i + skip if not valid then - end_token_here("$ERR invalid_string$") - add_syntax_error() + end_token_here("$ERR$") + add_syntax_error("malformed string") end x = x + skip state = "string single" @@ -1462,8 +1449,8 @@ do elseif lex_decimals[c] then state = "number power" else - end_token_here("$ERR invalid_number$") - add_syntax_error() + end_token_here("$ERR$") + add_syntax_error("malformed number") state = "any" end elseif state == "number power" then @@ -1478,8 +1465,17 @@ do if in_token then if last_token_kind[state] then end_token_prev(last_token_kind[state]) - if last_token_kind[state]:sub(1, 4) == "$ERR" then - add_syntax_error() + if last_token_kind[state] == "$ERR$" then + local state_type = state:sub(1, 6) + if state_type == "string" then + add_syntax_error("malformed string") + elseif state_type == "number" then + add_syntax_error("malformed number") + elseif state_type == "commen" then + add_syntax_error("unfinished long comment") + else + add_syntax_error() + end elseif keywords[tokens[nt].tk] then tokens[nt].kind = "keyword" end @@ -2892,10 +2888,8 @@ do return parse_table_literal(ps, i) elseif kind == "..." then return verify_kind(ps, i, "...") - elseif kind == "$ERR invalid_string$" then - return fail(ps, i, "malformed string") - elseif kind == "$ERR invalid_number$" then - return fail(ps, i, "malformed number") + elseif kind == "$ERR$" then + return fail(ps, i, "invalid token") end return fail(ps, i, "syntax error") end diff --git a/tl.tl b/tl.tl index 7ee6d1f82..71b24dcc8 100644 --- a/tl.tl +++ b/tl.tl @@ -795,10 +795,7 @@ local enum TokenKind "integer" "pragma" "pragma_identifier" - "$ERR unfinished_comment$" - "$ERR invalid_string$" - "$ERR invalid_number$" - "$ERR invalid$" + "$ERR$" "$EOF$" end @@ -864,21 +861,21 @@ do ["got /"] = "op", ["got :"] = "op", ["got --["] = nil, -- drop comment - ["string single"] = "$ERR invalid_string$", - ["string single got \\"] = "$ERR invalid_string$", - ["string double"] = "$ERR invalid_string$", - ["string double got \\"] = "$ERR invalid_string$", - ["string long"] = "$ERR invalid_string$", - ["string long got ]"] = "$ERR invalid_string$", + ["string single"] = "$ERR$", + ["string single got \\"] = "$ERR$", + ["string double"] = "$ERR$", + ["string double got \\"] = "$ERR$", + ["string long"] = "$ERR$", + ["string long got ]"] = "$ERR$", ["comment short"] = nil, -- drop comment - ["comment long"] = "$ERR unfinished_comment$", - ["comment long got ]"] = "$ERR unfinished_comment$", + ["comment long"] = "$ERR$", + ["comment long got ]"] = "$ERR$", ["number dec"] = "integer", ["number decfloat"] = "number", ["number hex"] = "integer", ["number hexfloat"] = "number", ["number power"] = "number", - ["number powersign"] = "$ERR invalid_number$", + ["number powersign"] = "$ERR$", ["pragma"] = nil, -- drop comment ["pragma any"] = nil, -- never in a token ["pragma word"] = "pragma_identifier", -- never in a token @@ -1104,23 +1101,13 @@ do in_token = false end - local function add_syntax_error() + local function add_syntax_error(msg?: string) local t = tokens[nt] - local msg: string - if t.kind == "$ERR invalid_string$" then - msg = "malformed string" - elseif t.kind == "$ERR invalid_number$" then - msg = "malformed number" - elseif t.kind == "$ERR unfinished_comment$" then - msg = "unfinished long comment" - else - msg = "invalid token '" .. t.tk .. "'" - end table.insert(errs, { filename = filename, y = t.y, x = t.x, - msg = msg, + msg = msg or "invalid token '" .. t.tk .. "'", }) end @@ -1170,7 +1157,7 @@ do end_token(k, c) elseif not lex_space[c] then begin_token() - end_token_here("$ERR invalid$") + end_token_here("$ERR$") add_syntax_error() end end @@ -1303,7 +1290,7 @@ do begin_token() elseif not lex_space[c] then begin_token() - end_token_here("$ERR invalid$") + end_token_here("$ERR$") add_syntax_error() end elseif state == "pragma word" then @@ -1357,8 +1344,8 @@ do local skip, valid = lex_string_escape(input, i, c) i = i + skip if not valid then - end_token_here("$ERR invalid_string$") - add_syntax_error() + end_token_here("$ERR$") + add_syntax_error("malformed string") end x = x + skip state = "string double" @@ -1373,8 +1360,8 @@ do local skip, valid = lex_string_escape(input, i, c) i = i + skip if not valid then - end_token_here("$ERR invalid_string$") - add_syntax_error() + end_token_here("$ERR$") + add_syntax_error("malformed string") end x = x + skip state = "string single" @@ -1462,8 +1449,8 @@ do elseif lex_decimals[c] then state = "number power" else - end_token_here("$ERR invalid_number$") - add_syntax_error() + end_token_here("$ERR$") + add_syntax_error("malformed number") state = "any" end elseif state == "number power" then @@ -1478,8 +1465,17 @@ do if in_token then if last_token_kind[state] then end_token_prev(last_token_kind[state]) - if last_token_kind[state]:sub(1, 4) == "$ERR" then - add_syntax_error() + if last_token_kind[state] == "$ERR$" then + local state_type = state:sub(1, 6) + if state_type == "string" then + add_syntax_error("malformed string") + elseif state_type == "number" then + add_syntax_error("malformed number") + elseif state_type == "commen" then + add_syntax_error("unfinished long comment") + else + add_syntax_error() + end elseif keywords[tokens[nt].tk] then tokens[nt].kind = "keyword" end @@ -2892,10 +2888,8 @@ local function parse_literal(ps: ParseState, i: integer): integer, Node return parse_table_literal(ps, i) elseif kind == "..." then return verify_kind(ps, i, "...") - elseif kind == "$ERR invalid_string$" then - return fail(ps, i, "malformed string") - elseif kind == "$ERR invalid_number$" then - return fail(ps, i, "malformed number") + elseif kind == "$ERR$" then + return fail(ps, i, "invalid token") end return fail(ps, i, "syntax error") end From fcdd86156b8e377b83e3e6f72ade3d784faa0648 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Mon, 2 Sep 2024 21:46:51 -0300 Subject: [PATCH 179/224] metatables: check metamethod types in metatable definition Add special-case behavior to specialize a type `metatable` using the definition of `metamethod` entries from `R` (and not just a type-variable application of `R` into the definition of `global record metatable` from the standard library definition. See tests in spec/declaration/metatable_spec.lua for examples of the added checks. Fixes #633. (At least the extent of it that can be resolved at this time, without explicit `nil` support -- a good explanation as to why the second case isn't resolved is given by @bjornbm in https://github.com/teal-language/tl/issues/633#issuecomment-1450731594 : "the record definition defines what keys/values are valid, but not that they are defined (or more generally perhaps the values may be nil, since nil is a valid value of every type). What is checked is that values for the defined keys have the right type, and that no other keys are added to the record. --- spec/declaration/metatable_spec.lua | 145 ++++++++++++++++++++++++++++ spec/metamethods/index_spec.lua | 4 +- spec/metamethods/le_spec.lua | 4 +- spec/metamethods/lt_spec.lua | 4 +- tl.lua | 96 +++++++++++++----- tl.tl | 96 +++++++++++++----- 6 files changed, 295 insertions(+), 54 deletions(-) create mode 100644 spec/declaration/metatable_spec.lua diff --git a/spec/declaration/metatable_spec.lua b/spec/declaration/metatable_spec.lua new file mode 100644 index 000000000..a57718f03 --- /dev/null +++ b/spec/declaration/metatable_spec.lua @@ -0,0 +1,145 @@ +local util = require("spec.util") + +describe("metatable declaration", function() + it("checks metamethod declarations in record against a general contract", util.check_type_error([[ + local type Rec = record + n: integer + metamethod __sub: function(self: Rec, b: integer, wat: integer): Rec + end + + local rec_mt: metatable + rec_mt = { + __add = function(self: Rec, b: Rec): Rec + return { n = self.n + b.n } + end, + } + + local r: Rec = setmetatable({ n = 10 }, rec_mt) + print((r - 3).n) + ]], { + { y = 3, x = 28, msg = "__sub does not follow metatable contract: got function(Rec, integer, integer): Rec, expected function(A, B): C" }, + { y = 14, x = 16, msg = "wrong number of arguments" }, + })) + + it("checks metatable against metamethod declarations", util.check_type_error([[ + local type Rec = record + n: integer + metamethod __add: function(self: Rec, b: integer): Rec + end + + local rec_mt: metatable + rec_mt = { + __add = function(self: Rec, b: Rec): Rec + return { n = self.n + b.n } + end, + } + + local r: Rec = setmetatable({ n = 10 }, rec_mt) + print((r + 9).n) + print((9 + r).n) + ]], { + { y = 8, x = 41, msg = "in record field: __add: argument 2: got Rec, expected integer" }, + { y = 15, x = 14, msg = "argument 1: got integer, expected Rec" }, + })) + + it("checks non-method metamethods with self in any position", util.check_type_error([[ + local type Rec = record + n: integer + metamethod __mul: function(a: integer, b: Rec): integer + end + + local rec_mt: metatable + rec_mt = { + __mul = function(a: integer, b: Rec): integer + return a * b.n + end, + } + + local r: Rec = setmetatable({ n = 10 }, rec_mt) + print((9 * r) + 3.0) + print((r * 9) + 3.0) + ]], { + { y = 15, x = 14, msg = "argument 1: got Rec, expected integer" }, + })) + + it("checks metamethods with multiple entries of the type", util.check_type_error([[ + local type Rec = record + n: integer + metamethod __div: function(a: Rec, b: Rec): integer + end + + local rec_mt: metatable + rec_mt = { + __div = function(a: Rec, b: Rec): integer + return a.n // b.n + end, + } + + local r: Rec = setmetatable({ n = 10 }, rec_mt) + print((r / 9) + 3.0) + print((r / r) + 3.0) + ]], { + { y = 14, x = 18, msg = "argument 2: got integer, expected Rec" }, + })) + + it("checks metamethods with method-like self", util.check_type_error([[ + local type Rec = record + n: integer + metamethod __index: function(Rec, s: string): Rec + end + + local rec_mt: metatable + rec_mt = { + __index = function(self: Rec, k: string): Rec + return { n = #k } + end, + } + + local r: Rec = setmetatable({ n = 10 }, rec_mt) + print(r["hello"]) + print(r[true]) + ]], { + { y = 15, x = 15, msg = "argument 1: got boolean, expected string" }, + })) + + it("checks metamethods with method-like self (explicit self)", util.check_type_error([[ + local type Rec = record + n: integer + metamethod __index: function(self: Rec, s: string): Rec + end + + local rec_mt: metatable + rec_mt = { + __index = function(r: Rec, k: string): Rec + return { n = #k } + end, + } + + local r: Rec = setmetatable({ n = 10 }, rec_mt) + print(r["hello"]) + print(r[true]) + ]], { + { y = 15, x = 15, msg = "argument 1: got boolean, expected string" }, + })) + + it("checks metamethods with method-like self (other name)", util.check_type_error([[ + local type Rec = record + n: integer + metamethod __index: function(r: Rec, s: string): Rec + end + + local rec_mt: metatable + rec_mt = { + __index = function(r: Rec, k: string): Rec + return { n = #k } + end, + } + + local r: Rec = setmetatable({ n = 10 }, rec_mt) + print(r["hello"]) + print(r[true]) + ]], { + { y = 15, x = 15, msg = "argument 1: got boolean, expected string" }, + })) + +end) diff --git a/spec/metamethods/index_spec.lua b/spec/metamethods/index_spec.lua index 811839272..0a3354179 100644 --- a/spec/metamethods/index_spec.lua +++ b/spec/metamethods/index_spec.lua @@ -8,8 +8,8 @@ describe("metamethod __index", function() end local rec_mt: metatable = { - __index = function(self: Rec, s: string, n: number): string - return tostring(self.x + n) .. s + __index = function(self: Rec, s: string): string + return tostring(self.x) .. s end } diff --git a/spec/metamethods/le_spec.lua b/spec/metamethods/le_spec.lua index c0eaaa36c..bb34f4a80 100644 --- a/spec/metamethods/le_spec.lua +++ b/spec/metamethods/le_spec.lua @@ -51,7 +51,7 @@ describe("binary metamethod __le using <=", function() it("can be used via the second argument", util.check([[ local type Rec = record x: number - metamethod __le: function(number, Rec): Rec + metamethod __le: function(number, Rec): boolean end local rec_mt: metatable @@ -153,7 +153,7 @@ describe("binary metamethod __le using >=", function() it("can be used via the second argument", util.check([[ local type Rec = record x: number - metamethod __le: function(number, Rec): Rec + metamethod __le: function(number, Rec): boolean end local rec_mt: metatable diff --git a/spec/metamethods/lt_spec.lua b/spec/metamethods/lt_spec.lua index 56c87bbfb..43a2a5a20 100644 --- a/spec/metamethods/lt_spec.lua +++ b/spec/metamethods/lt_spec.lua @@ -51,7 +51,7 @@ describe("binary metamethod __lt using <", function() it("can be used via the second argument", util.check([[ local type Rec = record x: number - metamethod __lt: function(number, Rec): Rec + metamethod __lt: function(number, Rec): boolean end local rec_mt: metatable @@ -153,7 +153,7 @@ describe("binary metamethod __lt using >", function() it("can be used via the second argument", util.check([[ local type Rec = record x: number - metamethod __lt: function(number, Rec): Rec + metamethod __lt: function(number, Rec): boolean end local rec_mt: metatable diff --git a/tl.lua b/tl.lua index f2f12540b..595fb816e 100644 --- a/tl.lua +++ b/tl.lua @@ -229,7 +229,7 @@ do __mode: Mode __name: string __tostring: function(T): string - __pairs: function(T): (function(): (K, V)) + __pairs: function(T): function(): (K, V) __index: any --[[FIXME: function | table | anything with an __index metamethod]] __newindex: any --[[FIXME: function | table | anything with an __index metamethod]] @@ -237,27 +237,27 @@ do __gc: function(T) __close: function(T) - __add: function(any, any): any - __sub: function(any, any): any - __mul: function(any, any): any - __div: function(any, any): any - __idiv: function(any, any): any - __mod: function(any, any): any - __pow: function(any, any): any - __band: function(any, any): any - __bor: function(any, any): any - __bxor: function(any, any): any - __shl: function(any, any): any - __shr: function(any, any): any - __concat: function(any, any): any - - __len: function(T): any - __unm: function(T): any - __bnot: function(T): any - - __eq: function(any, any): boolean - __lt: function(any, any): boolean - __le: function(any, any): boolean + __add: function(A, B): C + __sub: function(A, B): C + __mul: function(A, B): C + __div: function(A, B): C + __idiv: function(A, B): C + __mod: function(A, B): C + __pow: function(A, B): C + __band: function(A, B): C + __bor: function(A, B): C + __bxor: function(A, B): C + __shl: function(A, B): C + __shr: function(A, B): C + __concat: function(A, B): C + + __len: function(T): A + __unm: function(T): A + __bnot: function(T): A + + __eq: function(A, B): boolean + __lt: function(A, B): boolean + __le: function(A, B): boolean end global record os @@ -6330,7 +6330,7 @@ end function Errors:fail_unresolved_nominals(scope, global_scope) if global_scope and scope.pending_nominals then for name, types in pairs(scope.pending_nominals) do - if not global_scope.pending_global_types[name] then + if not global_scope.pending_global_types[name] and name ~= "metatable" then for _, typ in ipairs(types) do assert(typ.x) assert(typ.y) @@ -7120,6 +7120,8 @@ do + + @@ -7220,6 +7222,9 @@ do function TypeChecker:find_type(names, accept_typearg) local typ = self:find_var_type(names[1], "use_type") if not typ then + if #names == 1 and names[1] == "metatable" then + return self:find_type({ "_metatable" }) + end return nil end if typ.typename == "nominal" and typ.found then @@ -7891,6 +7896,27 @@ do self:add_var(nil, def.typeargs[i].typearg, tt) end local ret = self:resolve_typevars_at(t, def) + + if def == self.cache_std_metatable_type then + local tv = t.typevals[1] + if tv.typename == "nominal" then + local found = tv.found + if found and found.typename == "typedecl" then + local rec = found.def + if rec.fields and rec.meta_fields and ret.fields then + for fname, ftype in pairs(rec.meta_fields) do + if ret.fields[fname] then + if not self:is_a(ftype, ret.fields[fname]) then + self.errs:add(ftype, fname .. " does not follow metatable contract: got %s, expected %s", ftype, ret.fields[fname]) + end + end + ret.fields[fname] = ftype + end + end + end + end + end + self:end_scope() return ret elseif t.typevals then @@ -9440,7 +9466,12 @@ a.types[i], b.types[i]), } e2[2] = node.e2 args.tuple[2] = orig_b end - return self:to_structural(resolve_tuple((self:type_check_function_call(node, metamethod, args, -1, node, e2)))), meta_on_operator + + local mtdelta = metamethod.typename == "function" and metamethod.is_method and -1 or 0 + local ret_call = self:type_check_function_call(node, metamethod, args, mtdelta, node, e2) + local ret_unary = resolve_tuple(ret_call) + local ret = self:to_structural(ret_unary) + return ret, meta_on_operator else return nil, nil end @@ -12566,6 +12597,20 @@ self:expand_type(node, values, elements) }) return true end + local metamethod_is_method = { + ["__bnot"] = true, + ["__call"] = true, + ["__close"] = true, + ["__gc"] = true, + ["__index"] = true, + ["__is"] = true, + ["__len"] = true, + ["__newindex"] = true, + ["__pairs"] = true, + ["__tostring"] = true, + ["__unm"] = true, + } + local visit_type visit_type = { cbs = { @@ -12643,6 +12688,7 @@ self:expand_type(node, values, elements) }) fmacros = fmacros or {} table.insert(fmacros, ftype) end + ftype.is_method = metamethod_is_method[name] end typ.meta_fields[name] = ftype i = i + 1 @@ -12874,6 +12920,8 @@ self:expand_type(node, values, elements) }) type_priorities = TypeChecker.type_priorities, } + self.cache_std_metatable_type = env.globals["metatable"] and (env.globals["metatable"].t).def + setmetatable(self, { __index = TypeChecker }) self.feat_lax = set_feat(opts.feat_lax or env.defaults.feat_lax, false) diff --git a/tl.tl b/tl.tl index 71b24dcc8..8cc31b544 100644 --- a/tl.tl +++ b/tl.tl @@ -229,7 +229,7 @@ do __mode: Mode __name: string __tostring: function(T): string - __pairs: function(T): (function(): (K, V)) + __pairs: function(T): function(): (K, V) __index: any --[[FIXME: function | table | anything with an __index metamethod]] __newindex: any --[[FIXME: function | table | anything with an __index metamethod]] @@ -237,27 +237,27 @@ do __gc: function(T) __close: function(T) - __add: function(any, any): any - __sub: function(any, any): any - __mul: function(any, any): any - __div: function(any, any): any - __idiv: function(any, any): any - __mod: function(any, any): any - __pow: function(any, any): any - __band: function(any, any): any - __bor: function(any, any): any - __bxor: function(any, any): any - __shl: function(any, any): any - __shr: function(any, any): any - __concat: function(any, any): any - - __len: function(T): any - __unm: function(T): any - __bnot: function(T): any - - __eq: function(any, any): boolean - __lt: function(any, any): boolean - __le: function(any, any): boolean + __add: function(A, B): C + __sub: function(A, B): C + __mul: function(A, B): C + __div: function(A, B): C + __idiv: function(A, B): C + __mod: function(A, B): C + __pow: function(A, B): C + __band: function(A, B): C + __bor: function(A, B): C + __bxor: function(A, B): C + __shl: function(A, B): C + __shr: function(A, B): C + __concat: function(A, B): C + + __len: function(T): A + __unm: function(T): A + __bnot: function(T): A + + __eq: function(A, B): boolean + __lt: function(A, B): boolean + __le: function(A, B): boolean end global record os @@ -6330,7 +6330,7 @@ end function Errors:fail_unresolved_nominals(scope: Scope, global_scope: Scope) if global_scope and scope.pending_nominals then for name, types in pairs(scope.pending_nominals) do - if not global_scope.pending_global_types[name] then + if not global_scope.pending_global_types[name] and name ~= "metatable" then for _, typ in ipairs(types) do assert(typ.x) assert(typ.y) @@ -7101,6 +7101,8 @@ do dependencies: {string:string} collector: TypeCollector + cache_std_metatable_type: Type + gen_compat: GenCompat gen_target: GenTarget feat_arity: boolean @@ -7220,6 +7222,9 @@ do function TypeChecker:find_type(names: {string}, accept_typearg?: boolean): Type local typ = self:find_var_type(names[1], "use_type") if not typ then + if #names == 1 and names[1] == "metatable" then + return self:find_type({"_metatable"}) + end return nil end if typ is NominalType and typ.found then @@ -7891,6 +7896,27 @@ do self:add_var(nil, def.typeargs[i].typearg, tt) end local ret = self:resolve_typevars_at(t, def) + + if def == self.cache_std_metatable_type then + local tv = t.typevals[1] + if tv is NominalType then + local found = tv.found + if found and found is TypeDeclType then + local rec = found.def + if rec is RecordLikeType and rec.meta_fields and ret is RecordLikeType then + for fname, ftype in pairs(rec.meta_fields) do + if ret.fields[fname] then + if not self:is_a(ftype, ret.fields[fname]) then + self.errs:add(ftype, fname .. " does not follow metatable contract: got %s, expected %s", ftype, ret.fields[fname]) + end + end + ret.fields[fname] = ftype + end + end + end + end + end + self:end_scope() return ret elseif t.typevals then @@ -9440,7 +9466,12 @@ do e2[2] = node.e2 args.tuple[2] = orig_b end - return self:to_structural(resolve_tuple((self:type_check_function_call(node, metamethod, args, -1, node, e2)))), meta_on_operator + + local mtdelta = metamethod is FunctionType and metamethod.is_method and -1 or 0 + local ret_call = self:type_check_function_call(node, metamethod, args, mtdelta, node, e2) + local ret_unary = resolve_tuple(ret_call) + local ret = self:to_structural(ret_unary) + return ret, meta_on_operator else return nil, nil end @@ -12566,6 +12597,20 @@ do return true end + local metamethod_is_method: {string: boolean} = { + ["__bnot"] = true, + ["__call"] = true, + ["__close"] = true, + ["__gc"] = true, + ["__index"] = true, + ["__is"] = true, + ["__len"] = true, + ["__newindex"] = true, + ["__pairs"] = true, + ["__tostring"] = true, + ["__unm"] = true, + } + local visit_type: Visitor visit_type = { cbs = { @@ -12643,6 +12688,7 @@ do fmacros = fmacros or {} table.insert(fmacros, ftype) end + ftype.is_method = metamethod_is_method[name] end typ.meta_fields[name] = ftype i = i + 1 @@ -12874,6 +12920,8 @@ do type_priorities = TypeChecker.type_priorities, } + self.cache_std_metatable_type = env.globals["metatable"] and (env.globals["metatable"].t as TypeDeclType).def + setmetatable(self, { __index = TypeChecker }) self.feat_lax = set_feat(opts.feat_lax or env.defaults.feat_lax, false) From fe86b725865424abf4424c0954514a52575d34a0 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Tue, 3 Sep 2024 13:24:32 -0300 Subject: [PATCH 180/224] compiler debugging: TL_DEBUG_FACTS environment variable --- tl.lua | 26 ++++++++++++++++++++++++++ tl.tl | 26 ++++++++++++++++++++++++++ 2 files changed, 52 insertions(+) diff --git a/tl.lua b/tl.lua index 595fb816e..773a2c603 100644 --- a/tl.lua +++ b/tl.lua @@ -742,8 +742,13 @@ local DEFAULT_GEN_TARGET = "5.3" local TL_DEBUG = os.getenv("TL_DEBUG") +local TL_DEBUG_FACTS = os.getenv("TL_DEBUG_FACTS") local TL_DEBUG_MAXLINE = _tl_math_maxinteger +if TL_DEBUG_FACTS and not TL_DEBUG then + TL_DEBUG = "1" +end + if TL_DEBUG then local max = assert(tonumber(TL_DEBUG), "TL_DEBUG was defined, but not a number") if max < 0 then @@ -10361,6 +10366,27 @@ a.types[i], b.types[i]), } self:add_var(nil, v, t, "const", "narrow") end end + + if TL_DEBUG_FACTS then + local eval_indent = -1 + local real_eval_fact = eval_fact + eval_fact = function(self, known) + eval_indent = eval_indent + 1 + io.stderr:write((" "):rep(eval_indent)) + io.stderr:write("eval fact: ", tostring(known), "\n") + local facts = real_eval_fact(self, known) + if facts then + for _, k in ipairs(sorted_keys(facts)) do + local f = facts[k] + io.stderr:write((" "):rep(eval_indent), "=> ", tostring(f), "\n") + end + else + io.stderr:write((" "):rep(eval_indent), "=> .\n") + end + eval_indent = eval_indent - 1 + return facts + end + end end function TypeChecker:dismiss_unresolved(name) diff --git a/tl.tl b/tl.tl index 8cc31b544..027cfad56 100644 --- a/tl.tl +++ b/tl.tl @@ -742,8 +742,13 @@ end -------------------------------------------------------------------------------- local TL_DEBUG = os.getenv("TL_DEBUG") +local TL_DEBUG_FACTS = os.getenv("TL_DEBUG_FACTS") local TL_DEBUG_MAXLINE = math.maxinteger +if TL_DEBUG_FACTS and not TL_DEBUG then + TL_DEBUG="1" +end + if TL_DEBUG then local max = assert(tonumber(TL_DEBUG), "TL_DEBUG was defined, but not a number") if max < 0 then @@ -10361,6 +10366,27 @@ do self:add_var(nil, v, t, "const", "narrow") end end + + if TL_DEBUG_FACTS then + local eval_indent = -1 + local real_eval_fact = eval_fact + eval_fact = function(self: TypeChecker, known: Fact): {string: IsFact|EqFact} + eval_indent = eval_indent + 1 + io.stderr:write((" "):rep(eval_indent)) + io.stderr:write("eval fact: ", tostring(known), "\n") + local facts = real_eval_fact(self, known) + if facts then + for _, k in ipairs(sorted_keys(facts)) do + local f = facts[k] + io.stderr:write((" "):rep(eval_indent), "=> ", tostring(f), "\n") + end + else + io.stderr:write((" "):rep(eval_indent), "=> .\n") + end + eval_indent = eval_indent - 1 + return facts + end + end end function TypeChecker:dismiss_unresolved(name: string) From ed93fb31dcfee7db87993c07fb5dba64fbc289a8 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Tue, 3 Sep 2024 13:25:36 -0300 Subject: [PATCH 181/224] generics: begin generalizing code with HasTypeArgs --- tl.lua | 46 ++++++++++++---------------------------------- tl.tl | 48 +++++++++++++----------------------------------- 2 files changed, 25 insertions(+), 69 deletions(-) diff --git a/tl.lua b/tl.lua index 773a2c603..8ebbcff34 100644 --- a/tl.lua +++ b/tl.lua @@ -4518,6 +4518,14 @@ local function recurse_type(s, ast, visit) local xs = {} + if ast.typeargs then + if ast.typeargs then + for _, child in ipairs(ast.typeargs) do + table.insert(xs, recurse_type(s, child, visit)) + end + end + end + if ast.typename == "tuple" then for i, child in ipairs(ast.tuple) do xs[i] = recurse_type(s, child, visit) @@ -4530,11 +4538,6 @@ local function recurse_type(s, ast, visit) table.insert(xs, recurse_type(s, ast.keys, visit)) table.insert(xs, recurse_type(s, ast.values, visit)) elseif ast.fields then - if ast.typeargs then - for _, child in ipairs(ast.typeargs) do - table.insert(xs, recurse_type(s, child, visit)) - end - end if ast.interface_list then for _, child in ipairs(ast.interface_list) do table.insert(xs, recurse_type(s, child, visit)) @@ -4554,11 +4557,6 @@ local function recurse_type(s, ast, visit) end end elseif ast.typename == "function" then - if ast.typeargs then - for _, child in ipairs(ast.typeargs) do - table.insert(xs, recurse_type(s, child, visit)) - end - end if ast.args then for _, child in ipairs(ast.args.tuple) do table.insert(xs, recurse_type(s, child, visit)) @@ -4591,18 +4589,8 @@ local function recurse_type(s, ast, visit) table.insert(xs, recurse_type(s, ast.vtype, visit)) end elseif ast.typename == "typealias" then - if ast.typeargs then - for _, child in ipairs(ast.typeargs) do - table.insert(xs, recurse_type(s, child, visit)) - end - end table.insert(xs, recurse_type(s, ast.alias_to, visit)) elseif ast.typename == "typedecl" then - if ast.typeargs then - for _, child in ipairs(ast.typeargs) do - table.insert(xs, recurse_type(s, child, visit)) - end - end table.insert(xs, recurse_type(s, ast.def, visit)) end @@ -7461,6 +7449,10 @@ do copy.x = t.x copy.y = t.y + if t.typeargs then + (copy).typeargs, same = copy_typeargs(t.typeargs, same) + end + if t.typename == "array" then assert(copy.typename == "array") @@ -7487,11 +7479,6 @@ do end elseif t.typename == "typedecl" then assert(copy.typename == "typedecl") - - if t.typeargs then - copy.typeargs, same = copy_typeargs(t.typeargs, same) - end - copy.def, same = resolve(t.def, same) elseif t.typename == "typealias" then assert(copy.typename == "typealias") @@ -7507,11 +7494,6 @@ do copy.found = t.found elseif t.typename == "function" then assert(copy.typename == "function") - - if t.typeargs then - copy.typeargs, same = copy_typeargs(t.typeargs, same) - end - copy.macroexp = t.macroexp copy.min_arity = t.min_arity copy.is_method = t.is_method @@ -7521,10 +7503,6 @@ do assert(copy.typename == "record" or copy.typename == "interface") copy.declname = t.declname - if t.typeargs then - copy.typeargs, same = copy_typeargs(t.typeargs, same) - end - if t.elements then copy.elements, same = resolve(t.elements, same) diff --git a/tl.tl b/tl.tl index 027cfad56..be4ecb9c9 100644 --- a/tl.tl +++ b/tl.tl @@ -4518,6 +4518,14 @@ local function recurse_type(s: S, ast: Type, visit: Visitor(s: S, ast: Type, visit: Visitor(s: S, ast: Type, visit: Visitor(s: S, ast: Type, visit: Visitor Date: Tue, 3 Sep 2024 13:27:28 -0300 Subject: [PATCH 182/224] facts: avoid generating dummy And facts --- tl.lua | 3 +++ tl.tl | 3 +++ 2 files changed, 6 insertions(+) diff --git a/tl.lua b/tl.lua index 8ebbcff34..d92997b09 100644 --- a/tl.lua +++ b/tl.lua @@ -10091,6 +10091,9 @@ a.types[i], b.types[i]), } FACT_TRUTHY = TruthyFact({}) facts_and = function(w, f1, f2) + if not f1 and not f2 then + return + end return AndFact({ f1 = f1, f2 = f2, w = w }) end diff --git a/tl.tl b/tl.tl index be4ecb9c9..cde49da67 100644 --- a/tl.tl +++ b/tl.tl @@ -10091,6 +10091,9 @@ do FACT_TRUTHY = TruthyFact {} facts_and = function(w: Where, f1: Fact, f2: Fact): Fact + if not f1 and not f2 then + return + end return AndFact { f1 = f1, f2 = f2, w = w } end From 68610321864658693519502ac1319141b049e412 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Tue, 3 Sep 2024 23:48:43 -0300 Subject: [PATCH 183/224] parser: we don't need enum_item --- tl.lua | 4 +--- tl.tl | 4 +--- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/tl.lua b/tl.lua index d92997b09..bd733afdf 100644 --- a/tl.lua +++ b/tl.lua @@ -1961,7 +1961,6 @@ end - local TruthyFact = {} @@ -3687,9 +3686,8 @@ do def.enumset = {} while ps.tokens[i].tk ~= "$EOF$" and ps.tokens[i].tk ~= "end" do local item - i, item = verify_kind(ps, i, "string", "enum_item") + i, item = verify_kind(ps, i, "string", "string") if item then - table.insert(node, item) def.enumset[unquote(item.tk)] = true end end diff --git a/tl.tl b/tl.tl index cde49da67..fb612750b 100644 --- a/tl.tl +++ b/tl.tl @@ -1909,7 +1909,6 @@ local enum NodeKind "literal_table_item" "function" "expression_list" - "enum_item" "if" "if_block" "while" @@ -3687,9 +3686,8 @@ parse_enum_body = function(ps: ParseState, i: integer, def: EnumType, node: Node def.enumset = {} while ps.tokens[i].tk ~= "$EOF$" and ps.tokens[i].tk ~= "end" do local item: Node - i, item = verify_kind(ps, i, "string", "enum_item") + i, item = verify_kind(ps, i, "string", "string") if item then - table.insert(node, item) def.enumset[unquote(item.tk)] = true end end From 311fcd178788dc331136ef2b55029201b8c90dfa Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Wed, 4 Sep 2024 00:09:32 -0300 Subject: [PATCH 184/224] parser: refactor special behaviors in type declarations --- tl.lua | 50 +++++++++++++++++++++++++++++++------------------- tl.tl | 50 +++++++++++++++++++++++++++++++------------------- 2 files changed, 62 insertions(+), 38 deletions(-) diff --git a/tl.lua b/tl.lua index bd733afdf..af4655556 100644 --- a/tl.lua +++ b/tl.lua @@ -4109,6 +4109,33 @@ do return i, asgn end + local function parse_type_require(ps, i, asgn) + local istart = i + i, asgn.value = parse_expression(ps, i) + if not asgn.value then + return i + end + if asgn.value.op and asgn.value.op.op ~= "@funcall" and asgn.value.op.op ~= "." then + fail(ps, istart, "require() in type declarations cannot be part of larger expressions") + return i + end + if not node_is_require_call(asgn.value) then + fail(ps, istart, "require() for type declarations must have a literal argument") + return i + end + return i, asgn + end + + local function parse_special_type_declaration(ps, i, asgn) + if ps.tokens[i].tk == "require" then + return true, parse_type_require(ps, i, asgn) + elseif ps.tokens[i].tk == "pcall" then + fail(ps, i, "pcall() cannot be used in type declarations") + return true, i + end + return false, i, asgn + end + parse_type_declaration = function(ps, i, node_name) local asgn = new_node(ps, i, node_name) local var @@ -4131,25 +4158,10 @@ do i = verify_tk(ps, i, "=") if ps.tokens[i].kind == "identifier" then - if ps.tokens[i].tk == "require" then - local istart = i - i, asgn.value = parse_expression(ps, i) - if asgn.value then - if asgn.value.op and asgn.value.op.op ~= "@funcall" and asgn.value.op.op ~= "." then - fail(ps, istart, "require() in type declarations cannot be part of larger expressions") - return i - end - if not node_is_require_call(asgn.value) then - fail(ps, istart, "require() for type declarations must have a literal argument") - return i - end - return i, asgn - else - return i - end - elseif ps.tokens[i].tk == "pcall" then - fail(ps, i, "pcall() cannot be used in type declarations") - return i + local done + done, i, asgn = parse_special_type_declaration(ps, i, asgn) + if done then + return i, asgn end end diff --git a/tl.tl b/tl.tl index fb612750b..8d2a7762a 100644 --- a/tl.tl +++ b/tl.tl @@ -4109,6 +4109,33 @@ local function parse_variable_declarations(ps: ParseState, i: integer, node_name return i, asgn end +local function parse_type_require(ps: ParseState, i: integer, asgn: Node): integer, Node + local istart = i + i, asgn.value = parse_expression(ps, i) + if not asgn.value then + return i + end + if asgn.value.op and asgn.value.op.op ~= "@funcall" and asgn.value.op.op ~= "." then + fail(ps, istart, "require() in type declarations cannot be part of larger expressions") + return i + end + if not node_is_require_call(asgn.value) then + fail(ps, istart, "require() for type declarations must have a literal argument") + return i + end + return i, asgn +end + +local function parse_special_type_declaration(ps: ParseState, i: integer, asgn: Node): boolean, integer, Node + if ps.tokens[i].tk == "require" then + return true, parse_type_require(ps, i, asgn) + elseif ps.tokens[i].tk == "pcall" then + fail(ps, i, "pcall() cannot be used in type declarations") + return true, i + end + return false, i, asgn +end + parse_type_declaration = function(ps: ParseState, i: integer, node_name: NodeKind): integer, Node local asgn: Node = new_node(ps, i, node_name) local var: Node @@ -4131,25 +4158,10 @@ parse_type_declaration = function(ps: ParseState, i: integer, node_name: NodeKin i = verify_tk(ps, i, "=") if ps.tokens[i].kind == "identifier" then - if ps.tokens[i].tk == "require" then - local istart = i - i, asgn.value = parse_expression(ps, i) - if asgn.value then - if asgn.value.op and asgn.value.op.op ~= "@funcall" and asgn.value.op.op ~= "." then - fail(ps, istart, "require() in type declarations cannot be part of larger expressions") - return i - end - if not node_is_require_call(asgn.value) then - fail(ps, istart, "require() for type declarations must have a literal argument") - return i - end - return i, asgn - else - return i - end - elseif ps.tokens[i].tk == "pcall" then - fail(ps, i, "pcall() cannot be used in type declarations") - return i + local done: boolean + done, i, asgn = parse_special_type_declaration(ps, i, asgn) + if done then + return i, asgn end end From 6d6d5d4df925693df6d4c43d605fc33268825ecd Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Wed, 4 Sep 2024 00:28:10 -0300 Subject: [PATCH 185/224] parser: minor refactor --- tl.lua | 17 +++++++++++------ tl.tl | 17 +++++++++++------ 2 files changed, 22 insertions(+), 12 deletions(-) diff --git a/tl.lua b/tl.lua index af4655556..27d500300 100644 --- a/tl.lua +++ b/tl.lua @@ -2303,6 +2303,7 @@ do local parse_type_list + local parse_typeargs_if_any local parse_expression local parse_expression_and_tk local parse_statements @@ -2662,12 +2663,18 @@ do return i, t end + parse_typeargs_if_any = function(ps, i) + if ps.tokens[i].tk == "<" then + return parse_anglebracket_list(ps, i, parse_typearg) + end + return i + end + local function parse_function_type(ps, i) local typ = new_type(ps, i, "function") i = i + 1 - if ps.tokens[i].tk == "<" then - i, typ.typeargs = parse_anglebracket_list(ps, i, parse_typearg) - end + + i, typ.typeargs = parse_typeargs_if_any(ps, i) if ps.tokens[i].tk == "(" then i, typ.args, typ.is_method, typ.min_arity = parse_argument_type_list(ps, i) i, typ.rets = parse_return_types(ps, i) @@ -2838,9 +2845,7 @@ do local function parse_function_args_rets_body(ps, i, node) local istart = i - 1 - if ps.tokens[i].tk == "<" then - i, node.typeargs = parse_anglebracket_list(ps, i, parse_typearg) - end + i, node.typeargs = parse_typeargs_if_any(ps, i) i, node.args, node.min_arity = parse_argument_list(ps, i) i, node.rets = parse_return_types(ps, i) i, node.body = parse_statements(ps, i) diff --git a/tl.tl b/tl.tl index 8d2a7762a..81b835c15 100644 --- a/tl.tl +++ b/tl.tl @@ -2303,6 +2303,7 @@ local enum ParseTypeListMode end local parse_type_list: function(ParseState, integer, ParseTypeListMode): integer, TupleType +local parse_typeargs_if_any: function(ps: ParseState, i: integer): integer, {TypeArgType} local parse_expression: function(ParseState, integer): integer, Node, integer local parse_expression_and_tk: function(ps: ParseState, i: integer, tk: string): integer, Node local parse_statements: function(ParseState, integer, ? boolean): integer, Node @@ -2662,12 +2663,18 @@ local function parse_return_types(ps: ParseState, i: integer): integer, TupleTyp return i, t end +parse_typeargs_if_any = function(ps: ParseState, i: integer): integer, {TypeArgType} + if ps.tokens[i].tk == "<" then + return parse_anglebracket_list(ps, i, parse_typearg) + end + return i +end + local function parse_function_type(ps: ParseState, i: integer): integer, FunctionType local typ = new_type(ps, i, "function") as FunctionType i = i + 1 - if ps.tokens[i].tk == "<" then - i, typ.typeargs = parse_anglebracket_list(ps, i, parse_typearg) - end + + i, typ.typeargs = parse_typeargs_if_any(ps, i) if ps.tokens[i].tk == "(" then i, typ.args, typ.is_method, typ.min_arity = parse_argument_type_list(ps, i) i, typ.rets = parse_return_types(ps, i) @@ -2838,9 +2845,7 @@ end local function parse_function_args_rets_body(ps: ParseState, i: integer, node: Node): integer, Node local istart = i - 1 - if ps.tokens[i].tk == "<" then - i, node.typeargs = parse_anglebracket_list(ps, i, parse_typearg) - end + i, node.typeargs = parse_typeargs_if_any(ps, i) i, node.args, node.min_arity = parse_argument_list(ps, i) i, node.rets = parse_return_types(ps, i) i, node.body = parse_statements(ps, i) From f5a1d519e887c18594fb81f515e34c753747d5ab Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Wed, 4 Sep 2024 00:38:24 -0300 Subject: [PATCH 186/224] parser: refactor setting declname --- spec/inference/table_literal_spec.lua | 4 ++-- tl.lua | 21 +++++++++++---------- tl.tl | 21 +++++++++++---------- 3 files changed, 24 insertions(+), 22 deletions(-) diff --git a/spec/inference/table_literal_spec.lua b/spec/inference/table_literal_spec.lua index 71d5d91c7..fcd8b7abf 100644 --- a/spec/inference/table_literal_spec.lua +++ b/spec/inference/table_literal_spec.lua @@ -15,7 +15,7 @@ describe("bidirectional inference for table literals", function() } print(x) ]], { - { msg = "in record field: type: string \"who\" is not a member of enum" }, + { msg = "in record field: type: string \"who\" is not a member of TypeEnum" }, })) it("directed inference produces correct results for incomplete records (regression test for #348)", util.check([[ @@ -48,7 +48,7 @@ describe("bidirectional inference for table literals", function() f:bar({ "a", "b" }) ]], { - { msg = 'expected an array: at index 2: string "b" is not a member of enum' } + { msg = 'expected an array: at index 2: string "b" is not a member of Eno' } })) it("resolves nominals across nested generics (regression test for #499)", util.check_type_error([[ diff --git a/tl.lua b/tl.lua index 27d500300..35dabb782 100644 --- a/tl.lua +++ b/tl.lua @@ -3660,6 +3660,14 @@ do return true end + local function set_declname(def, declname) + if def.typename == "record" or def.typename == "interface" or def.typename == "enum" then + if not def.declname then + def.declname = declname + end + end + end + local function parse_nested_type(ps, i, def, typename, parse_body) i = i + 1 local iv = i @@ -3676,9 +3684,7 @@ do local iok = parse_body(ps, i, ndef, nt) if iok then i = iok - if ndef.fields then - ndef.declname = v.tk - end + set_declname(ndef, v.tk) nt.newtype = new_typedecl(ps, itype, ndef) end @@ -4193,11 +4199,7 @@ do end end - if def.fields or def.typename == "enum" then - if not def.declname then - def.declname = asgn.var.tk - end - end + set_declname(def, asgn.var.tk) elseif nt.typename == "typealias" then if typeargs then nt.typeargs = typeargs @@ -4221,8 +4223,7 @@ do return fail(ps, i, "expected a type name") end - assert(def.typename == "record" or def.typename == "interface" or def.typename == "enum") - def.declname = asgn.var.tk + set_declname(def, asgn.var.tk) i = parse_body(ps, i, def, nt) diff --git a/tl.tl b/tl.tl index 81b835c15..eadad990d 100644 --- a/tl.tl +++ b/tl.tl @@ -3660,6 +3660,14 @@ local function store_field_in_record(ps: ParseState, i: integer, field_name: str return true end +local function set_declname(def: Type, declname: string) + if def is RecordType or def is InterfaceType or def is EnumType then + if not def.declname then + def.declname = declname + end + end +end + local function parse_nested_type(ps: ParseState, i: integer, def: RecordLikeType, typename: TypeName, parse_body: ParseBody): integer, boolean i = i + 1 -- skip 'record' or 'enum' local iv = i @@ -3676,9 +3684,7 @@ local function parse_nested_type(ps: ParseState, i: integer, def: RecordLikeType local iok = parse_body(ps, i, ndef, nt) if iok then i = iok - if ndef is RecordLikeType then - ndef.declname = v.tk - end + set_declname(ndef, v.tk) nt.newtype = new_typedecl(ps, itype, ndef) end @@ -4193,11 +4199,7 @@ parse_type_declaration = function(ps: ParseState, i: integer, node_name: NodeKin end end - if def is RecordLikeType or def is EnumType then - if not def.declname then - def.declname = asgn.var.tk - end - end + set_declname(def, asgn.var.tk) elseif nt is TypeAliasType then if typeargs then nt.typeargs = typeargs @@ -4221,8 +4223,7 @@ local function parse_type_constructor(ps: ParseState, i: integer, node_name: Nod return fail(ps, i, "expected a type name") end - assert(def is RecordType or def is InterfaceType or def is EnumType) - def.declname = asgn.var.tk + set_declname(def, asgn.var.tk) i = parse_body(ps, i, def, nt) From 73faece7d0330cf1335147c19672f9bce6814df8 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Wed, 4 Sep 2024 00:46:24 -0300 Subject: [PATCH 187/224] parser: simplify parse_type_constructor signature --- tl.lua | 10 +++++----- tl.tl | 10 +++++----- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/tl.lua b/tl.lua index 35dabb782..a9b43c039 100644 --- a/tl.lua +++ b/tl.lua @@ -4209,12 +4209,12 @@ do return i, asgn end - local function parse_type_constructor(ps, i, node_name, type_name, parse_body) + local function parse_type_constructor(ps, i, node_name, tn) local asgn = new_node(ps, i, node_name) local nt = new_node(ps, i, "newtype") asgn.value = nt local itype = i - local def = new_type(ps, i, type_name) + local def = new_type(ps, i, tn) i = i + 2 @@ -4225,7 +4225,7 @@ do set_declname(def, asgn.var.tk) - i = parse_body(ps, i, def, nt) + i = parse_type_body_fns[tn](ps, i, def, nt) nt.newtype = new_typedecl(ps, itype, def) @@ -4256,7 +4256,7 @@ do elseif ntk == "macroexp" and ps.tokens[i + 2].kind == "identifier" then return parse_local_macroexp(ps, i) elseif parse_type_body_fns[tn] and ps.tokens[i + 2].kind == "identifier" then - return parse_type_constructor(ps, i, "local_type", tn, parse_type_body_fns[tn]) + return parse_type_constructor(ps, i, "local_type", tn) end return parse_variable_declarations(ps, i + 1, "local_declaration") end @@ -4269,7 +4269,7 @@ do elseif ntk == "type" and ps.tokens[i + 2].kind == "identifier" then return parse_type_declaration(ps, i + 2, "global_type") elseif parse_type_body_fns[tn] and ps.tokens[i + 2].kind == "identifier" then - return parse_type_constructor(ps, i, "global_type", tn, parse_type_body_fns[tn]) + return parse_type_constructor(ps, i, "global_type", tn) elseif ps.tokens[i + 1].kind == "identifier" then return parse_variable_declarations(ps, i + 1, "global_declaration") end diff --git a/tl.tl b/tl.tl index eadad990d..92b7db3d7 100644 --- a/tl.tl +++ b/tl.tl @@ -4209,12 +4209,12 @@ parse_type_declaration = function(ps: ParseState, i: integer, node_name: NodeKin return i, asgn end -local function parse_type_constructor(ps: ParseState, i: integer, node_name: NodeKind, type_name: TypeName, parse_body: ParseBody): integer, Node +local function parse_type_constructor(ps: ParseState, i: integer, node_name: NodeKind, tn: TypeName): integer, Node local asgn: Node = new_node(ps, i, node_name) local nt: Node = new_node(ps, i, "newtype") asgn.value = nt local itype = i - local def = new_type(ps, i, type_name) + local def = new_type(ps, i, tn) i = i + 2 -- skip `local` or `global`, and the constructor name @@ -4225,7 +4225,7 @@ local function parse_type_constructor(ps: ParseState, i: integer, node_name: Nod set_declname(def, asgn.var.tk) - i = parse_body(ps, i, def, nt) + i = parse_type_body_fns[tn](ps, i, def, nt) nt.newtype = new_typedecl(ps, itype, def) @@ -4256,7 +4256,7 @@ local function parse_local(ps: ParseState, i: integer): integer, Node elseif ntk == "macroexp" and ps.tokens[i+2].kind == "identifier" then return parse_local_macroexp(ps, i) elseif parse_type_body_fns[tn] and ps.tokens[i+2].kind == "identifier" then - return parse_type_constructor(ps, i, "local_type", tn, parse_type_body_fns[tn]) + return parse_type_constructor(ps, i, "local_type", tn) end return parse_variable_declarations(ps, i + 1, "local_declaration") end @@ -4269,7 +4269,7 @@ local function parse_global(ps: ParseState, i: integer): integer, Node elseif ntk == "type" and ps.tokens[i + 2].kind == "identifier" then return parse_type_declaration(ps, i + 2, "global_type") elseif parse_type_body_fns[tn] and ps.tokens[i+2].kind == "identifier" then - return parse_type_constructor(ps, i, "global_type", tn, parse_type_body_fns[tn]) + return parse_type_constructor(ps, i, "global_type", tn) elseif ps.tokens[i+1].kind == "identifier" then return parse_variable_declarations(ps, i + 1, "global_declaration") end From 9a89cf05e8d1fd2525cfdd1a3f95c5af39ac58b2 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Wed, 4 Sep 2024 01:11:04 -0300 Subject: [PATCH 188/224] parser: refactor parse_type_body --- spec/error_reporting/typecheck_error_spec.lua | 2 +- tl.lua | 112 +++++++++------- tl.tl | 121 ++++++++++-------- 3 files changed, 136 insertions(+), 99 deletions(-) diff --git a/spec/error_reporting/typecheck_error_spec.lua b/spec/error_reporting/typecheck_error_spec.lua index b6846beae..5afa15723 100644 --- a/spec/error_reporting/typecheck_error_spec.lua +++ b/spec/error_reporting/typecheck_error_spec.lua @@ -106,7 +106,7 @@ describe("typecheck errors", function() aaa.myfunc(b) ]], { - { msg = "argument 1: Thing (defined in ./bbb.tl:4) is not a Thing (defined in ./aaa.tl:1)" } + { msg = "argument 1: Thing (defined in ./bbb.tl:3) is not a Thing (defined in ./aaa.tl:1)" } }) end) diff --git a/tl.lua b/tl.lua index a9b43c039..27048d34a 100644 --- a/tl.lua +++ b/tl.lua @@ -2430,11 +2430,39 @@ do return skip_i end + local function parse_type_body(ps, i, istart, node, tn) + local typeargs + local def + i, typeargs = parse_typeargs_if_any(ps, i) + + def = new_type(ps, istart, tn) + + if typeargs then + if def.typename == "record" or def.typename == "interface" then + def.typeargs = typeargs + end + end + + local ok + i, ok = parse_type_body_fns[tn](ps, i, def) + if not ok then + return fail(ps, i, "expected a type") + end + + i = verify_end(ps, i, istart, node) + + + + + return i, def + end + local function skip_type_body(ps, i) local tn = ps.tokens[i].tk i = i + 1 assert(parse_type_body_fns[tn], tn .. " has no parse body function") - return parse_type_body_fns[tn](ps, i, {}, { kind = "function" }) + local ii, tt = parse_type_body(ps, i, i - 1, {}, tn) + return ii, not not tt end local function parse_table_value(ps, i) @@ -3668,7 +3696,8 @@ do end end - local function parse_nested_type(ps, i, def, typename, parse_body) + local function parse_nested_type(ps, i, def, typename) + local istart = i i = i + 1 local iv = i @@ -3678,22 +3707,23 @@ do return fail(ps, i, "expected a variable name") end - local nt = new_node(ps, i - 2, "newtype") - local ndef = new_type(ps, i, typename) - local itype = i - local iok = parse_body(ps, i, ndef, nt) - if iok then - i = iok - set_declname(ndef, v.tk) - nt.newtype = new_typedecl(ps, itype, ndef) + local nt = new_node(ps, istart, "newtype") + + local ndef + i, ndef = parse_type_body(ps, i, istart, nt, typename) + if not ndef then + return i end + set_declname(ndef, v.tk) + + nt.newtype = new_typedecl(ps, istart, ndef) + store_field_in_record(ps, iv, v.tk, nt.newtype, def.fields, def.field_order) return i end - parse_enum_body = function(ps, i, def, node) - local istart = i - 1 + parse_enum_body = function(ps, i, def) def.enumset = {} while ps.tokens[i].tk ~= "$EOF$" and ps.tokens[i].tk ~= "end" do local item @@ -3702,8 +3732,7 @@ do def.enumset[unquote(item.tk)] = true end end - i = verify_end(ps, i, istart, node) - return i, node + return i, true end local metamethod_names = { @@ -3812,15 +3841,10 @@ do end - parse_record_body = function(ps, i, def, node) - local istart = i - 1 + parse_record_body = function(ps, i, def) def.fields = {} def.field_order = {} - if ps.tokens[i].tk == "<" then - i, def.typeargs = parse_anglebracket_list(ps, i, parse_typearg) - end - if ps.tokens[i].tk == "{" then local atype i, atype = parse_array_interface_type(ps, i, def) @@ -3910,7 +3934,7 @@ do store_field_in_record(ps, iv, v.tk, nt.newtype, def.fields, def.field_order) elseif parse_type_body_fns[tn] and ps.tokens[i + 1].tk ~= ":" then - i = parse_nested_type(ps, i, def, tn, parse_type_body_fns[tn]) + i = parse_nested_type(ps, i, def, tn) else local is_metamethod = false if ps.tokens[i].tk == "metamethod" and ps.tokens[i + 1].tk ~= ":" then @@ -3979,8 +4003,7 @@ do end end end - i = verify_end(ps, i, istart, node) - return i, node + return i, true end parse_type_body_fns = { @@ -3993,31 +4016,25 @@ do local node = new_node(ps, i, "newtype") local def local tn = ps.tokens[i].tk - local itype = i - if parse_type_body_fns[tn] then - def = new_type(ps, i, tn) - i = i + 1 - i = parse_type_body_fns[tn](ps, i, def, node) - if not def then - return fail(ps, i, "expected a type") - end + local istart = i - node.newtype = new_typedecl(ps, itype, def) - return i, node + if parse_type_body_fns[tn] then + i, def = parse_type_body(ps, i + 1, istart, node, tn) else i, def = parse_type(ps, i) - if not def then - return fail(ps, i, "expected a type") - end + end - if def.typename == "nominal" then - node.newtype = new_typealias(ps, itype, def) - else - node.newtype = new_typedecl(ps, itype, def) - end + if not def then + return fail(ps, i, "expected a type") + end + if def.typename == "nominal" then + node.newtype = new_typealias(ps, istart, def) return i, node end + + node.newtype = new_typedecl(ps, istart, def) + return i, node end local function parse_assignment_expression_list(ps, i, asgn) @@ -4213,8 +4230,8 @@ do local asgn = new_node(ps, i, node_name) local nt = new_node(ps, i, "newtype") asgn.value = nt - local itype = i - local def = new_type(ps, i, tn) + local istart = i + local def i = i + 2 @@ -4223,11 +4240,14 @@ do return fail(ps, i, "expected a type name") end - set_declname(def, asgn.var.tk) + i, def = parse_type_body(ps, i, istart, nt, tn) + if not def then + return i + end - i = parse_type_body_fns[tn](ps, i, def, nt) + set_declname(def, asgn.var.tk) - nt.newtype = new_typedecl(ps, itype, def) + nt.newtype = new_typedecl(ps, istart, def) return i, asgn end diff --git a/tl.tl b/tl.tl index 92b7db3d7..85d8263ca 100644 --- a/tl.tl +++ b/tl.tl @@ -2314,9 +2314,9 @@ local parse_type_declaration: function(ps: ParseState, i: integer, node_name: No local parse_newtype: function(ps: ParseState, i: integer): integer, Node local parse_interface_name: function(ps: ParseState, i: integer): integer, Type, integer -local type ParseBody = function(ps: ParseState, i: integer, def: Type, node: Node): integer, Node -local parse_enum_body: function(ps: ParseState, i: integer, def: EnumType, node: Node): integer, Node -local parse_record_body: function(ps: ParseState, i: integer, def: Type, node: Node): integer, Node +local type ParseBody = function(ps: ParseState, i: integer, def: Type): integer, boolean +local parse_enum_body: function(ps: ParseState, i: integer, def: EnumType): integer, boolean +local parse_record_body: function(ps: ParseState, i: integer, def: Type): integer, boolean local parse_type_body_fns: {TypeName:ParseBody} local function fail(ps: ParseState, i: integer, msg: string): integer @@ -2412,9 +2412,9 @@ local function verify_kind(ps: ParseState, i: integer, kind: TokenKind, node_kin return fail(ps, i, "syntax error, expected " .. kind) end -local type SkipFunction = function(ParseState, integer): integer, Node +local type SkipFunction = function(ParseState, integer): integer, Node | boolean -local function skip(ps: ParseState, i: integer, skip_fn: SkipFunction): integer, Node +local function skip(ps: ParseState, i: integer, skip_fn: SkipFunction): integer, Node | boolean local err_ps: ParseState = { filename = ps.filename, tokens = ps.tokens, @@ -2430,11 +2430,36 @@ local function failskip(ps: ParseState, i: integer, msg: string, skip_fn: SkipFu return skip_i end -local function skip_type_body(ps: ParseState, i: integer): integer, Node +local function parse_type_body(ps: ParseState, i: integer, istart: integer, node: Node, tn: TypeName): integer, Type + local typeargs: {TypeArgType} + local def: Type + i, typeargs = parse_typeargs_if_any(ps, i) + + def = new_type(ps, istart, tn) + + if typeargs then + if def is RecordType or def is InterfaceType then + def.typeargs = typeargs + end + end + + local ok: boolean + i, ok = parse_type_body_fns[tn](ps, i, def) + if not ok then + return fail(ps, i, "expected a type") + end + + i = verify_end(ps, i, istart, node) + + return i, def +end + +local function skip_type_body(ps: ParseState, i: integer): integer, boolean local tn = ps.tokens[i].tk as TypeName i = i + 1 assert(parse_type_body_fns[tn], tn .. " has no parse body function") - return parse_type_body_fns[tn](ps, i, {}, { kind = "function" --[[ skip end_alignment_hint ]] }) + local ii, tt = parse_type_body(ps, i, i - 1, {}, tn) + return ii, not not tt end local function parse_table_value(ps: ParseState, i: integer): integer, Node, integer @@ -3668,7 +3693,8 @@ local function set_declname(def: Type, declname: string) end end -local function parse_nested_type(ps: ParseState, i: integer, def: RecordLikeType, typename: TypeName, parse_body: ParseBody): integer, boolean +local function parse_nested_type(ps: ParseState, i: integer, def: RecordLikeType, typename: TypeName): integer, boolean + local istart = i i = i + 1 -- skip 'record' or 'enum' local iv = i @@ -3678,22 +3704,23 @@ local function parse_nested_type(ps: ParseState, i: integer, def: RecordLikeType return fail(ps, i, "expected a variable name") end - local nt: Node = new_node(ps, i - 2, "newtype") - local ndef = new_type(ps, i, typename) - local itype = i - local iok = parse_body(ps, i, ndef, nt) - if iok then - i = iok - set_declname(ndef, v.tk) - nt.newtype = new_typedecl(ps, itype, ndef) + local nt: Node = new_node(ps, istart, "newtype") + + local ndef: Type + i, ndef = parse_type_body(ps, i, istart, nt, typename) + if not ndef then + return i end + set_declname(ndef, v.tk) + + nt.newtype = new_typedecl(ps, istart, ndef) + store_field_in_record(ps, iv, v.tk, nt.newtype, def.fields, def.field_order) return i end -parse_enum_body = function(ps: ParseState, i: integer, def: EnumType, node: Node): integer, Node - local istart = i - 1 +parse_enum_body = function(ps: ParseState, i: integer, def: EnumType): integer, boolean def.enumset = {} while ps.tokens[i].tk ~= "$EOF$" and ps.tokens[i].tk ~= "end" do local item: Node @@ -3702,8 +3729,7 @@ parse_enum_body = function(ps: ParseState, i: integer, def: EnumType, node: Node def.enumset[unquote(item.tk)] = true end end - i = verify_end(ps, i, istart, node) - return i, node + return i, true end local metamethod_names: {string:boolean} = { @@ -3812,15 +3838,10 @@ local function clone_typeargs(ps: ParseState, i: integer, typeargs: {TypeArgType end -parse_record_body = function(ps: ParseState, i: integer, def: RecordLikeType, node: Node): integer, Node - local istart = i - 1 +parse_record_body = function(ps: ParseState, i: integer, def: RecordLikeType): integer, boolean def.fields = {} def.field_order = {} - if ps.tokens[i].tk == "<" then - i, def.typeargs = parse_anglebracket_list(ps, i, parse_typearg) - end - if ps.tokens[i].tk == "{" then local atype: Type i, atype = parse_array_interface_type(ps, i, def) @@ -3910,7 +3931,7 @@ parse_record_body = function(ps: ParseState, i: integer, def: RecordLikeType, no store_field_in_record(ps, iv, v.tk, nt.newtype, def.fields, def.field_order) elseif parse_type_body_fns[tn] and ps.tokens[i+1].tk ~= ":" then - i = parse_nested_type(ps, i, def, tn, parse_type_body_fns[tn]) + i = parse_nested_type(ps, i, def, tn) else local is_metamethod = false if ps.tokens[i].tk == "metamethod" and ps.tokens[i+1].tk ~= ":" then @@ -3979,8 +4000,7 @@ parse_record_body = function(ps: ParseState, i: integer, def: RecordLikeType, no end end end - i = verify_end(ps, i, istart, node) - return i, node + return i, true end parse_type_body_fns = { @@ -3993,31 +4013,25 @@ parse_newtype = function(ps: ParseState, i: integer): integer, Node local node: Node = new_node(ps, i, "newtype") local def: Type local tn = ps.tokens[i].tk as TypeName - local itype = i - if parse_type_body_fns[tn] then - def = new_type(ps, i, tn) - i = i + 1 - i = parse_type_body_fns[tn](ps, i, def, node) - if not def then - return fail(ps, i, "expected a type") - end + local istart = i - node.newtype = new_typedecl(ps, itype, def) - return i, node + if parse_type_body_fns[tn] then + i, def = parse_type_body(ps, i + 1, istart, node, tn) else i, def = parse_type(ps, i) - if not def then - return fail(ps, i, "expected a type") - end + end - if def is NominalType then - node.newtype = new_typealias(ps, itype, def) - else - node.newtype = new_typedecl(ps, itype, def) - end + if not def then + return fail(ps, i, "expected a type") + end + if def is NominalType then + node.newtype = new_typealias(ps, istart, def) return i, node end + + node.newtype = new_typedecl(ps, istart, def) + return i, node end local function parse_assignment_expression_list(ps: ParseState, i: integer, asgn: Node): integer, Node @@ -4213,8 +4227,8 @@ local function parse_type_constructor(ps: ParseState, i: integer, node_name: Nod local asgn: Node = new_node(ps, i, node_name) local nt: Node = new_node(ps, i, "newtype") asgn.value = nt - local itype = i - local def = new_type(ps, i, tn) + local istart = i + local def: Type i = i + 2 -- skip `local` or `global`, and the constructor name @@ -4223,11 +4237,14 @@ local function parse_type_constructor(ps: ParseState, i: integer, node_name: Nod return fail(ps, i, "expected a type name") end - set_declname(def, asgn.var.tk) + i, def = parse_type_body(ps, i, istart, nt, tn) + if not def then + return i + end - i = parse_type_body_fns[tn](ps, i, def, nt) + set_declname(def, asgn.var.tk) - nt.newtype = new_typedecl(ps, itype, def) + nt.newtype = new_typedecl(ps, istart, def) return i, asgn end From 98044e38447ed722b64ca403fff5ad80bfc5e21f Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Wed, 4 Sep 2024 11:42:03 -0300 Subject: [PATCH 189/224] minor cleanup --- tl.lua | 19 +++++++++---------- tl.tl | 30 ++++++++++++++++-------------- 2 files changed, 25 insertions(+), 24 deletions(-) diff --git a/tl.lua b/tl.lua index 27048d34a..722772630 100644 --- a/tl.lua +++ b/tl.lua @@ -2451,9 +2451,6 @@ do i = verify_end(ps, i, istart, node) - - - return i, def end @@ -10894,13 +10891,15 @@ self:expand_type(node, values, elements) }) value.e1.kind == "variable" and value.e1.tk == "require" then - local t = special_functions["require"](self, value, self:find_var_type("require"), a_type(value.e2, "tuple", { tuple = { a_type(value.e2[1], "string", {}) } }), 0) - - - local ty = t.typename == "tuple" and t.tuple[1] or t - - ty = (ty.typename == "typealias") and self:resolve_typealias(ty) or ty - return (ty.typename == "typedecl") and ty or (a_type(value, "typedecl", { def = ty })) + local ty = resolve_tuple(special_functions["require"](self, value, + self:find_var_type("require"), + a_type(value.e2, "tuple", { tuple = { a_type(value.e2[1], "string", {}) } }))) + if ty.typename == "typealias" then + return self:resolve_typealias(ty) + elseif ty.typename == "typedecl" then + return ty + end + return a_type(value, "typedecl", { def = ty }) elseif value.kind == "op" and value.op.op == "." then diff --git a/tl.tl b/tl.tl index 85d8263ca..cd7bf317b 100644 --- a/tl.tl +++ b/tl.tl @@ -10417,7 +10417,7 @@ do end end - local function special_pcall_xpcall(self: TypeChecker, node: Node, _a: Type, b: TupleType, argdelta: integer): Type + local function special_pcall_xpcall(self: TypeChecker, node: Node, _a: Type, b: TupleType, argdelta?: integer): Type local base_nargs = (node.e1.tk == "xpcall") and 2 or 1 local bool = a_type(node, "boolean", {}) if #node.e2 < base_nargs then @@ -10463,8 +10463,8 @@ do return rets end - local special_functions: {string : function(TypeChecker, Node,Type,TupleType,integer):InvalidOrTupleType } = { - ["pairs"] = function(self: TypeChecker, node: Node, a: Type, b: TupleType, argdelta: integer): InvalidOrTupleType + local special_functions: {string : function(TypeChecker, Node, Type, TupleType, ? integer):InvalidOrTupleType } = { + ["pairs"] = function(self: TypeChecker, node: Node, a: Type, b: TupleType, argdelta?: integer): InvalidOrTupleType if not b.tuple[1] then return self.errs:invalid_at(node, "pairs requires an argument") end @@ -10489,7 +10489,7 @@ do return (self:type_check_function_call(node, a, b, argdelta)) end, - ["ipairs"] = function(self: TypeChecker, node: Node, a: Type, b: TupleType, argdelta: integer): InvalidOrTupleType + ["ipairs"] = function(self: TypeChecker, node: Node, a: Type, b: TupleType, argdelta?: integer): InvalidOrTupleType if not b.tuple[1] then return self.errs:invalid_at(node, "ipairs requires an argument") end @@ -10510,7 +10510,7 @@ do return (self:type_check_function_call(node, a, b, argdelta)) end, - ["rawget"] = function(self: TypeChecker, node: Node, _a: Type, b: TupleType, _argdelta: integer): InvalidOrTupleType + ["rawget"] = function(self: TypeChecker, node: Node, _a: Type, b: TupleType, _argdelta?: integer): InvalidOrTupleType -- TODO should those offsets be fixed by _argdelta? if #b.tuple == 2 then return a_tuple(node, { self:type_check_index(node.e2[1], node.e2[2], b.tuple[1], b.tuple[2]) }) @@ -10519,7 +10519,7 @@ do end end, - ["require"] = function(self: TypeChecker, node: Node, _a: Type, b: TupleType, _argdelta: integer): InvalidOrTupleType + ["require"] = function(self: TypeChecker, node: Node, _a: Type, b: TupleType, _argdelta?: integer): InvalidOrTupleType if #b.tuple ~= 1 then return self.errs:invalid_at(node, "require expects one literal argument") end @@ -10552,7 +10552,7 @@ do ["pcall"] = special_pcall_xpcall, ["xpcall"] = special_pcall_xpcall, - ["assert"] = function(self: TypeChecker, node: Node, a: Type, b: TupleType, argdelta: integer): InvalidOrTupleType + ["assert"] = function(self: TypeChecker, node: Node, a: Type, b: TupleType, argdelta?: integer): InvalidOrTupleType node.known = FACT_TRUTHY local r = self:type_check_function_call(node, a, b, argdelta) self:apply_facts(node, node.e2[1].known) @@ -10891,13 +10891,15 @@ do and value.e1.kind == "variable" and value.e1.tk == "require" then - local t = special_functions["require"](self, value, self:find_var_type("require"), a_tuple(value.e2, { a_type(value.e2[1], "string", {}) }), 0) - - -- FIXME why is this ': Type' annotation needed? - local ty: Type = t is TupleType and t.tuple[1] or t - - ty = (ty is TypeAliasType) and self:resolve_typealias(ty) or ty - return (ty is TypeDeclType) and ty or (a_type(value, "typedecl", { def = ty } as TypeDeclType)) + local ty = resolve_tuple(special_functions["require"](self, value, + self:find_var_type("require"), + a_tuple(value.e2, { a_type(value.e2[1], "string", {}) }))) + if ty is TypeAliasType then + return self:resolve_typealias(ty) + elseif ty is TypeDeclType then + return ty + end + return a_type(value, "typedecl", { def = ty } as TypeDeclType) elseif value.kind == "op" and value.op.op == "." then From 8f89ec76299ad68afbc2d8adb9dff0b8a991dc39 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Wed, 4 Sep 2024 11:57:00 -0300 Subject: [PATCH 190/224] refactor: remove TypeAliasType (unify with TypeDeclType) The TypeAliasType is now a case of TypeDeclType, where `is_alias` is true, and `def` is always a NominalType (asserts are sprinkled in the code to check this). --- tl.lua | 163 +++++++++++++++++++++++++---------------------------- tl.tl | 175 +++++++++++++++++++++++++++------------------------------ 2 files changed, 158 insertions(+), 180 deletions(-) diff --git a/tl.lua b/tl.lua index 722772630..0913c1b17 100644 --- a/tl.lua +++ b/tl.lua @@ -1580,7 +1580,6 @@ end - local table_types = { @@ -1593,7 +1592,6 @@ local table_types = { ["tupletable"] = true, ["typedecl"] = false, - ["typealias"] = false, ["typevar"] = false, ["typearg"] = false, ["function"] = false, @@ -1946,14 +1944,6 @@ end - - - - - - - - @@ -2391,12 +2381,6 @@ do return t, t.tuple end - local function new_typealias(ps, i, alias_to) - local t = new_type(ps, i, "typealias") - t.alias_to = alias_to - return t - end - local function new_nominal(ps, i, name) local t = new_type(ps, i, "nominal") if name then @@ -3925,7 +3909,7 @@ do end local ntt = nt.newtype - if ntt.typename == "typealias" then + if ntt.is_alias then ntt.is_nested_alias = true end @@ -4020,17 +4004,16 @@ do else i, def = parse_type(ps, i) end - if not def then return fail(ps, i, "expected a type") end + node.newtype = new_typedecl(ps, istart, def) + if def.typename == "nominal" then - node.newtype = new_typealias(ps, istart, def) - return i, node + node.newtype.is_alias = true end - node.newtype = new_typedecl(ps, istart, def) return i, node end @@ -4214,10 +4197,6 @@ do end set_declname(def, asgn.var.tk) - elseif nt.typename == "typealias" then - if typeargs then - nt.typeargs = typeargs - end end return i, asgn @@ -4621,8 +4600,6 @@ local function recurse_type(s, ast, visit) if ast.vtype then table.insert(xs, recurse_type(s, ast.vtype, visit)) end - elseif ast.typename == "typealias" then - table.insert(xs, recurse_type(s, ast.alias_to, visit)) elseif ast.typename == "typedecl" then table.insert(xs, recurse_type(s, ast.def, visit)) end @@ -5545,12 +5522,12 @@ function tl.pretty_print_ast(ast, gen_target, mode) after = function(_, node, _children) local out = { y = node.y, h = 0 } local nt = node.newtype - if nt.typename == "typealias" then - table.insert(out, table.concat(nt.alias_to.names, ".")) - elseif nt.typename == "typedecl" then + if nt.typename == "typedecl" then local def = nt.def if def.fields then table.insert(out, print_record_def(def)) + elseif def.typename == "nominal" then + table.insert(out, table.concat(def.names, ".")) else table.insert(out, "{}") end @@ -5664,7 +5641,6 @@ function tl.pretty_print_ast(ast, gen_target, mode) visit_type.cbs["string"] = default_type_visitor visit_type.cbs["typedecl"] = default_type_visitor - visit_type.cbs["typealias"] = default_type_visitor visit_type.cbs["typevar"] = default_type_visitor visit_type.cbs["typearg"] = default_type_visitor visit_type.cbs["function"] = default_type_visitor @@ -5746,7 +5722,6 @@ local typename_to_typecode = { ["tuple"] = tl.typecodes.UNKNOWN, ["literal_table_item"] = tl.typecodes.UNKNOWN, ["typedecl"] = tl.typecodes.UNKNOWN, - ["typealias"] = tl.typecodes.UNKNOWN, ["*"] = tl.typecodes.UNKNOWN, } @@ -5846,8 +5821,6 @@ function TypeReporter:get_typenum(t) if rt.typename == "typedecl" then rt = rt.def - elseif rt.typename == "typealias" then - rt = rt.alias_to end local ti = { @@ -5869,7 +5842,7 @@ function TypeReporter:get_typenum(t) rt = t end end - assert(not (rt.typename == "typedecl" or rt.typename == "typealias")) + assert(not (rt.typename == "typedecl")) if rt.fields then @@ -6248,7 +6221,6 @@ function Errors:unused_warning(name, var) var.is_func_arg and "argument" or t.typename == "function" and "function" or t.typename == "typedecl" and "type" or - t.typename == "typealias" and "type" or "variable", name, show_type(var.t)) @@ -6300,13 +6272,13 @@ local function check_for_unused_vars(scope, is_global) if var.used_as_type then var.declared_at.elide_type = true else - if (t.typename == "typedecl" or t.typename == "typealias") and not is_global then + if t.typename == "typedecl" and not is_global then var.declared_at.elide_type = true end list = list or {} table.insert(list, { y = var.declared_at.y, x = var.declared_at.x, name = name, var = var }) end - elseif var.used and (t.typename == "typedecl" or t.typename == "typealias") and var.aliasing then + elseif var.used and t.typename == "typedecl" and var.aliasing then var.aliasing.used = true var.aliasing.declared_at.elide_type = false end @@ -6797,10 +6769,8 @@ local function show_type_base(t, short, seen) return "boolean" elseif t.typename == "none" then return "" - elseif t.typename == "typealias" then - return "type alias to " .. show(t.alias_to) elseif t.typename == "typedecl" then - return "type " .. show(t.def) + return (t.is_alias and "type alias to " or "type ") .. show(t.def) else return "<" .. t.typename .. ">" end @@ -7276,7 +7246,7 @@ do typ = typ.found end end - if typ.typename == "typedecl" or typ.typename == "typealias" then + if typ.typename == "typedecl" then return typ elseif accept_typearg and typ.typename == "typearg" then return typ @@ -7286,8 +7256,6 @@ do local function type_for_union(t) if t.typename == "typedecl" then return type_for_union(t.def), t.def - elseif t.typename == "typealias" then - return type_for_union(t.alias_to), t.alias_to elseif t.typename == "tuple" then return type_for_union(t.tuple[1]), t.tuple[1] elseif t.typename == "nominal" then @@ -7396,8 +7364,6 @@ do local function resolve_typedecl(t) if t.typename == "typedecl" then return t.def - elseif t.typename == "typealias" then - return t.alias_to else return t end @@ -7513,9 +7479,7 @@ do elseif t.typename == "typedecl" then assert(copy.typename == "typedecl") copy.def, same = resolve(t.def, same) - elseif t.typename == "typealias" then - assert(copy.typename == "typealias") - copy.alias_to, same = resolve(t.alias_to, same) + copy.is_alias = t.is_alias copy.is_nested_alias = t.is_nested_alias elseif t.typename == "nominal" then assert(copy.typename == "nominal") @@ -7956,8 +7920,10 @@ do return self.errs:invalid_at(t, "unknown type %s", t) end - if found.typename == "typealias" then - found = found.alias_to.found + if found.typename == "typedecl" and found.is_alias then + local def = found.def + assert(def.typename == "nominal") + found = def.found end if not found then @@ -8006,23 +7972,34 @@ do return resolve_decl_into_nominal(self, t, found) end - function TypeChecker:resolve_typealias(typealias) - local t = typealias.alias_to + function TypeChecker:resolve_typealias(ta) + + local nom = ta.def + assert(nom.typename == "nominal") + + local immediate, found = find_nominal_type_decl(self, nom) - local immediate, found = find_nominal_type_decl(self, t) if type(immediate) == "table" then return immediate end - if not t.typevals then + + if not nom.typevals then + nom.resolved = found return found end - local resolved = resolve_decl_into_nominal(self, t, found) - local typedecl = a_type(typealias, "typedecl", { def = resolved }) - t.resolved = typedecl - return typedecl + + + local struc = resolve_decl_into_nominal(self, nom, found) + + + local td = a_type(ta, "typedecl", { def = struc }) + nom.resolved = td + + + return td end end @@ -9505,12 +9482,15 @@ a.types[i], b.types[i]), } end if tbl.typename == "typedecl" then - tbl = tbl.def - elseif tbl.typename == "typealias" then if tbl.is_nested_alias then return nil, "cannot use a nested type alias as a concrete value" + end + local def = tbl.def + if def.typename == "nominal" then + assert(tbl.is_alias) + tbl = self:resolve_nominal(def) else - tbl = self:resolve_nominal(tbl.alias_to) + tbl = def end end @@ -9964,9 +9944,13 @@ a.types[i], b.types[i]), } t = t.fields[fname] if t.typename == "typedecl" then - t = t.def - elseif t.typename == "typealias" then - t = t.alias_to.resolved + local def = t.def + if def.typename == "nominal" then + assert(t.is_alias) + t = def.resolved + else + t = def + end end return t, v, dname @@ -10020,7 +10004,7 @@ a.types[i], b.types[i]), } if def.fields and def.fields[exp.e2.tk] then table.insert(t.names, exp.e2.tk) local ft = def.fields[exp.e2.tk] - if type(ft) == "table" then + if ft.typename == "typedecl" then t.found = ft else return nil @@ -10894,9 +10878,10 @@ self:expand_type(node, values, elements) }) local ty = resolve_tuple(special_functions["require"](self, value, self:find_var_type("require"), a_type(value.e2, "tuple", { tuple = { a_type(value.e2[1], "string", {}) } }))) - if ty.typename == "typealias" then - return self:resolve_typealias(ty) - elseif ty.typename == "typedecl" then + if ty.typename == "typedecl" then + if ty.is_alias then + return self:resolve_typealias(ty) + end return ty end return a_type(value, "typedecl", { def = ty }) @@ -10920,12 +10905,13 @@ self:expand_type(node, values, elements) }) return ty else local newtype = value.newtype - if newtype.typename == "typealias" then - local aliasing = self:find_var(newtype.alias_to.names[1], "use_type") + if newtype.is_alias then + local def = newtype.def + assert(def.typename == "nominal") + local aliasing = self:find_var(def.names[1], "use_type") return self:resolve_typealias(newtype), aliasing - elseif newtype.typename == "typedecl" then - return newtype, nil end + return newtype, nil end end @@ -10943,7 +10929,7 @@ self:expand_type(node, values, elements) }) local missing for _, key in ipairs(t.field_order) do local ftype = t.fields[key] - if not (ftype.typename == "typedecl" or ftype.typename == "typealias" or (ftype.typename == "function" and ftype.is_record_function)) then + if not (ftype.typename == "typedecl" or (ftype.typename == "function" and ftype.is_record_function)) then is_total, missing = total_check_key(key, seen_keys, is_total, missing) end end @@ -10990,7 +10976,7 @@ self:expand_type(node, values, elements) }) end local var = self:to_structural(vartype) - if var.typename == "typedecl" or var.typename == "typealias" then + if var.typename == "typedecl" then self.errs:add(varnode, "cannot reassign a type") return nil end @@ -11040,7 +11026,7 @@ self:expand_type(node, values, elements) }) local name = node.var.tk local resolved, aliasing = self:get_typedecl(node.value) local nt = node.value.newtype - if nt and nt.typename == "typealias" and resolved.typename == "typedecl" then + if nt and nt.is_alias and resolved.typename == "typedecl" then if nt.typeargs then local def = resolved.def @@ -11590,7 +11576,7 @@ self:expand_type(node, values, elements) }) if not df then self.errs:add_in_context(node[i], node, "unknown field " .. ck) else - if df.typename == "typedecl" or df.typename == "typealias" then + if df.typename == "typedecl" then self.errs:add_in_context(node[i], node, "cannot reassign a type") else self:assert_is_a(node[i], cvtype, df, "in record field", ck) @@ -12595,10 +12581,12 @@ self:expand_type(node, values, elements) }) self:add_var(nil, "@self", a_type(typ, "typedecl", { def = typ })) for fname, ftype in fields_of(typ) do - if ftype.typename == "typealias" then - self:resolve_nominal(ftype.alias_to) - self:add_var(nil, fname, ftype) - elseif ftype.typename == "typedecl" then + if ftype.typename == "typedecl" then + local def = ftype.def + if def.typename == "nominal" then + assert(ftype.is_alias) + self:resolve_nominal(def) + end self:add_var(nil, fname, ftype) end end @@ -12610,7 +12598,7 @@ self:expand_type(node, values, elements) }) local scope = self.st[#self.st] scope.vars["@self"] = nil for fname, ftype in fields_of(typ) do - if ftype.typename == "typealias" or ftype.typename == "typedecl" then + if ftype.typename == "typedecl" then scope.vars[fname] = nil end end @@ -12716,7 +12704,7 @@ self:expand_type(node, values, elements) }) end end end - elseif ftype.typename == "typealias" then + elseif ftype.typename == "typedecl" and ftype.is_alias then self:resolve_typealias(ftype) end @@ -12799,10 +12787,12 @@ self:expand_type(node, values, elements) }) local tv = typ tv.typevar = t.typearg tv.constraint = t.constraint - elseif t.typename == "typealias" then - typ.found = t.alias_to.found elseif t.typename == "typedecl" then - if t.def.typename ~= "circular_require" then + local def = t.def + if t.is_alias then + assert(def.typename == "nominal") + typ.found = def.found + elseif def.typename ~= "circular_require" then typ.found = t end end @@ -12837,7 +12827,6 @@ self:expand_type(node, values, elements) }) visit_type.cbs["interface"] = visit_type.cbs["record"] visit_type.cbs["typedecl"] = visit_type_with_typeargs - visit_type.cbs["typealias"] = visit_type_with_typeargs visit_type.cbs["self"] = default_type_visitor visit_type.cbs["string"] = default_type_visitor diff --git a/tl.tl b/tl.tl index cd7bf317b..24a12f0f6 100644 --- a/tl.tl +++ b/tl.tl @@ -1548,7 +1548,6 @@ end local enum TypeName "typedecl" - "typealias" "typevar" "typearg" "function" @@ -1593,7 +1592,6 @@ local table_types : {TypeName:boolean} = { ["tupletable"] = true, ["typedecl"] = false, - ["typealias"] = false, ["typevar"] = false, ["typearg"] = false, ["function"] = false, @@ -1681,18 +1679,10 @@ local record TypeDeclType def: Type closed: boolean -end - -local record TypeAliasType - is Type, HasTypeArgs - where self.typename == "typealias" - - alias_to: NominalType + is_alias: boolean is_nested_alias: boolean end -local type TypeType = TypeDeclType | TypeAliasType - local record LiteralTableItemType is Type where self.typename == "literal_table_item" @@ -1722,7 +1712,7 @@ local record NominalType names: {string} typevals: {Type} - found: TypeType -- type is found but typeargs are not resolved + found: TypeDeclType -- type is found but typeargs are not resolved resolved: Type -- type is found and typeargs are resolved end @@ -2104,7 +2094,7 @@ local record Node exps: Node -- newtype - newtype: TypeType + newtype: TypeDeclType elide_type: boolean -- expressions @@ -2391,12 +2381,6 @@ local function new_tuple(ps: ParseState, i: integer, types?: {Type}, is_va?: boo return t, t.tuple end -local function new_typealias(ps: ParseState, i: integer, alias_to: NominalType): TypeAliasType - local t = new_type(ps, i, "typealias") as TypeAliasType - t.alias_to = alias_to - return t -end - local function new_nominal(ps: ParseState, i: integer, name?: string): NominalType local t = new_type(ps, i, "nominal") as NominalType if name then @@ -3925,7 +3909,7 @@ parse_record_body = function(ps: ParseState, i: integer, def: RecordLikeType): i end local ntt = nt.newtype - if ntt is TypeAliasType then + if ntt.is_alias then ntt.is_nested_alias = true end @@ -4020,17 +4004,16 @@ parse_newtype = function(ps: ParseState, i: integer): integer, Node else i, def = parse_type(ps, i) end - if not def then return fail(ps, i, "expected a type") end + node.newtype = new_typedecl(ps, istart, def) + if def is NominalType then - node.newtype = new_typealias(ps, istart, def) - return i, node + node.newtype.is_alias = true end - node.newtype = new_typedecl(ps, istart, def) return i, node end @@ -4214,10 +4197,6 @@ parse_type_declaration = function(ps: ParseState, i: integer, node_name: NodeKin end set_declname(def, asgn.var.tk) - elseif nt is TypeAliasType then - if typeargs then - nt.typeargs = typeargs - end end return i, asgn @@ -4621,8 +4600,6 @@ local function recurse_type(s: S, ast: Type, visit: Visitor: {TypeName:integer} = { ["tuple"] = tl.typecodes.UNKNOWN, ["literal_table_item"] = tl.typecodes.UNKNOWN, ["typedecl"] = tl.typecodes.UNKNOWN, - ["typealias"] = tl.typecodes.UNKNOWN, ["*"] = tl.typecodes.UNKNOWN, } @@ -5846,8 +5821,6 @@ function TypeReporter:get_typenum(t: Type): integer if rt is TypeDeclType then rt = rt.def - elseif rt is TypeAliasType then - rt = rt.alias_to end local ti: TypeInfo = { @@ -5869,7 +5842,7 @@ function TypeReporter:get_typenum(t: Type): integer rt = t end end - assert(not (rt is TypeDeclType or rt is TypeAliasType)) + assert(not (rt is TypeDeclType)) if rt is RecordLikeType then -- store record field info @@ -6248,7 +6221,6 @@ function Errors:unused_warning(name: string, var: Variable) var.is_func_arg and "argument" or t is FunctionType and "function" or t is TypeDeclType and "type" - or t is TypeAliasType and "type" or "variable", name, show_type(var.t) @@ -6300,13 +6272,13 @@ local function check_for_unused_vars(scope: Scope, is_global?: boolean): {Unused if var.used_as_type then var.declared_at.elide_type = true else - if (t is TypeDeclType or t is TypeAliasType) and not is_global then + if t is TypeDeclType and not is_global then var.declared_at.elide_type = true end list = list or {} table.insert(list, { y = var.declared_at.y, x = var.declared_at.x, name = name, var = var }) end - elseif var.used and (t is TypeDeclType or t is TypeAliasType) and var.aliasing then + elseif var.used and t is TypeDeclType and var.aliasing then var.aliasing.used = true var.aliasing.declared_at.elide_type = false end @@ -6797,10 +6769,8 @@ local function show_type_base(t: Type, short: boolean, seen: {Type:string}): str return "boolean" elseif t.typename == "none" then return "" - elseif t is TypeAliasType then - return "type alias to " .. show(t.alias_to) elseif t is TypeDeclType then - return "type " .. show(t.def) + return (t.is_alias and "type alias to " or "type ") .. show(t.def) else return "<" .. t.typename .. ">" -- TODO add string.format("%p", t) with compat-5.4 end @@ -7276,7 +7246,7 @@ do typ = typ.found end end - if typ is TypeDeclType or typ is TypeAliasType then + if typ is TypeDeclType then return typ elseif accept_typearg and typ is TypeArgType then return typ @@ -7286,8 +7256,6 @@ do local function type_for_union(t: Type): string, Type if t is TypeDeclType then return type_for_union(t.def), t.def - elseif t is TypeAliasType then - return type_for_union(t.alias_to), t.alias_to elseif t is TupleType then return type_for_union(t.tuple[1]), t.tuple[1] elseif t is NominalType then @@ -7396,8 +7364,6 @@ do local function resolve_typedecl(t: Type): Type if t is TypeDeclType then return t.def - elseif t is TypeAliasType then - return t.alias_to else return t end @@ -7513,9 +7479,7 @@ do elseif t is TypeDeclType then assert(copy is TypeDeclType) copy.def, same = resolve(t.def, same) - elseif t is TypeAliasType then - assert(copy is TypeAliasType) - copy.alias_to, same = resolve(t.alias_to, same) + copy.is_alias = t.is_alias copy.is_nested_alias = t.is_nested_alias elseif t is NominalType then assert(copy is NominalType) @@ -7956,8 +7920,10 @@ do return self.errs:invalid_at(t, "unknown type %s", t) end - if found is TypeAliasType then - found = found.alias_to.found + if found is TypeDeclType and found.is_alias then + local def = found.def + assert(def is NominalType) + found = def.found end if not found then @@ -8006,23 +7972,34 @@ do return resolve_decl_into_nominal(self, t, found) end - function TypeChecker:resolve_typealias(typealias: TypeAliasType): InvalidOrTypeDeclType - local t = typealias.alias_to + function TypeChecker:resolve_typealias(ta: TypeDeclType): InvalidOrTypeDeclType + -- given a typealias that points to a nominal, + local nom = ta.def + assert(nom is NominalType) - local immediate, found = find_nominal_type_decl(self, t) + local immediate, found = find_nominal_type_decl(self, nom) + -- if it was previously resolved (or a circular require, or an error), return that; if immediate is InvalidOrTypeDeclType then return immediate end - if not t.typevals then + -- if nominal has no type arguments, resolve alias to that; + if not nom.typevals then + nom.resolved = found return found end - local resolved = resolve_decl_into_nominal(self, t, found) + -- otherwise, this can't be an alias. + + -- resolve the nominal into a structural type + local struc = resolve_decl_into_nominal(self, nom, found) - local typedecl = a_type(typealias, "typedecl", { def = resolved } as TypeDeclType) - t.resolved = typedecl - return typedecl + -- wrap it into a new non-alias typedecl + local td = a_type(ta, "typedecl", { def = struc } as TypeDeclType) + nom.resolved = td + + -- and return it + return td end end @@ -9505,12 +9482,15 @@ do end if tbl is TypeDeclType then - tbl = tbl.def - elseif tbl is TypeAliasType then if tbl.is_nested_alias then return nil, "cannot use a nested type alias as a concrete value" + end + local def = tbl.def + if def is NominalType then + assert(tbl.is_alias) + tbl = self:resolve_nominal(def) else - tbl = self:resolve_nominal(tbl.alias_to) + tbl = def end end @@ -9964,9 +9944,13 @@ do t = t.fields[fname] if t is TypeDeclType then - t = t.def - elseif t is TypeAliasType then - t = t.alias_to.resolved + local def = t.def + if def is NominalType then + assert(t.is_alias) + t = def.resolved + else + t = def + end end return t, v, dname @@ -10020,7 +10004,7 @@ do if def is RecordLikeType and def.fields[exp.e2.tk] then table.insert(t.names, exp.e2.tk) local ft = def.fields[exp.e2.tk] - if ft is TypeType then + if ft is TypeDeclType then t.found = ft else return nil @@ -10894,9 +10878,10 @@ do local ty = resolve_tuple(special_functions["require"](self, value, self:find_var_type("require"), a_tuple(value.e2, { a_type(value.e2[1], "string", {}) }))) - if ty is TypeAliasType then - return self:resolve_typealias(ty) - elseif ty is TypeDeclType then + if ty is TypeDeclType then + if ty.is_alias then + return self:resolve_typealias(ty) + end return ty end return a_type(value, "typedecl", { def = ty } as TypeDeclType) @@ -10920,12 +10905,13 @@ do return ty else local newtype = value.newtype - if newtype is TypeAliasType then - local aliasing = self:find_var(newtype.alias_to.names[1], "use_type") + if newtype.is_alias then + local def = newtype.def + assert(def is NominalType) + local aliasing = self:find_var(def.names[1], "use_type") return self:resolve_typealias(newtype), aliasing - elseif newtype is TypeDeclType then - return newtype, nil end + return newtype, nil end end @@ -10943,7 +10929,7 @@ do local missing: {string} for _, key in ipairs(t.field_order) do local ftype = t.fields[key] - if not (ftype is TypeDeclType or ftype is TypeAliasType or (ftype is FunctionType and ftype.is_record_function)) then + if not (ftype is TypeDeclType or (ftype is FunctionType and ftype.is_record_function)) then is_total, missing = total_check_key(key, seen_keys, is_total, missing) end end @@ -10990,7 +10976,7 @@ do end local var = self:to_structural(vartype) - if var is TypeDeclType or var is TypeAliasType then + if var is TypeDeclType then self.errs:add(varnode, "cannot reassign a type") return nil end @@ -11040,12 +11026,12 @@ do local name = node.var.tk local resolved, aliasing = self:get_typedecl(node.value) local nt = node.value.newtype - if nt and nt is TypeAliasType and resolved is TypeDeclType then + if nt and nt.is_alias and resolved is TypeDeclType then if nt.typeargs then local def = resolved.def - -- FIXME ideally we'd like to use `if def is HasTypeArgs` - -- here, but if def.typeargs happens to be nil, the `is` - -- check won't work + -- FIXME this looks sketchy; not sure if just overwriting the + -- type variables in a resolved alias won't have bad side-effects. + -- also global types need to propagate type variables as well. if def is RecordType or def is FunctionType or def is InterfaceType then def.typeargs = nt.typeargs end @@ -11590,7 +11576,7 @@ do if not df then self.errs:add_in_context(node[i], node, "unknown field " .. ck) else - if df is TypeDeclType or df is TypeAliasType then + if df is TypeDeclType then self.errs:add_in_context(node[i], node, "cannot reassign a type") else self:assert_is_a(node[i], cvtype, df, "in record field", ck) @@ -12595,10 +12581,12 @@ do self:add_var(nil, "@self", a_typedecl(typ, typ)) for fname, ftype in fields_of(typ) do - if ftype is TypeAliasType then - self:resolve_nominal(ftype.alias_to) - self:add_var(nil, fname, ftype) - elseif ftype is TypeDeclType then + if ftype is TypeDeclType then + local def = ftype.def + if def is NominalType then + assert(ftype.is_alias) + self:resolve_nominal(def) + end self:add_var(nil, fname, ftype) end end @@ -12610,7 +12598,7 @@ do local scope = self.st[#self.st] scope.vars["@self"] = nil for fname, ftype in fields_of(typ) do - if ftype is TypeAliasType or ftype is TypeDeclType then + if ftype is TypeDeclType then scope.vars[fname] = nil end end @@ -12716,7 +12704,7 @@ do end end end - elseif ftype is TypeAliasType then + elseif ftype is TypeDeclType and ftype.is_alias then self:resolve_typealias(ftype) end @@ -12799,10 +12787,12 @@ do local tv = typ as TypeVarType tv.typevar = t.typearg tv.constraint = t.constraint - elseif t is TypeAliasType then - typ.found = t.alias_to.found elseif t is TypeDeclType then - if t.def.typename ~= "circular_require" then + local def = t.def + if t.is_alias then + assert(def is NominalType) + typ.found = def.found + elseif def.typename ~= "circular_require" then typ.found = t end end @@ -12837,7 +12827,6 @@ do visit_type.cbs["interface"] = visit_type.cbs["record"] visit_type.cbs["typedecl"] = visit_type_with_typeargs - visit_type.cbs["typealias"] = visit_type_with_typeargs visit_type.cbs["self"] = default_type_visitor visit_type.cbs["string"] = default_type_visitor From 64de5598ea0373b8185388f5a494fd75823b768a Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Sat, 31 Aug 2024 21:34:13 -0300 Subject: [PATCH 191/224] API: expose Teal 0.15.x-like API for backwards compatibility --- spec/api/get_types_spec.lua | 4 +- spec/parser/parser_error_spec.lua | 4 +- spec/util.lua | 10 +- tl | 4 +- tl.lua | 136 ++++++++++++++++------- tl.tl | 172 ++++++++++++++++++++---------- 6 files changed, 223 insertions(+), 107 deletions(-) diff --git a/spec/api/get_types_spec.lua b/spec/api/get_types_spec.lua index b6f55ec83..a36b5373e 100644 --- a/spec/api/get_types_spec.lua +++ b/spec/api/get_types_spec.lua @@ -4,7 +4,7 @@ describe("tl.get_types", function() it("skips over label nodes (#393)", function() local env = tl.init_env() env.report_types = true - local result = assert(tl.process_string([[ + local result = assert(tl.check_string([[ local function a() ::continue:: end @@ -18,7 +18,7 @@ describe("tl.get_types", function() it("reports resolved type on poly function calls", function() local env = tl.init_env() env.report_types = true - local result = assert(tl.process_string([[ + local result = assert(tl.check_string([[ local record R f: function(string) f: function(integer) diff --git a/spec/parser/parser_error_spec.lua b/spec/parser/parser_error_spec.lua index cfd2e077c..8125c72d8 100644 --- a/spec/parser/parser_error_spec.lua +++ b/spec/parser/parser_error_spec.lua @@ -2,7 +2,7 @@ local tl = require("tl") describe("parser errors", function() it("parse errors include filename", function () - local result = tl.process_string("local x 1", nil, "foo.tl") + local result = tl.check_string("local x 1", nil, "foo.tl") assert.same("foo.tl", result.syntax_errors[1].filename, "parse errors should contain .filename property") end) @@ -30,7 +30,7 @@ describe("parser errors", function() local code = [[ local bar = require "bar" ]] - local result = tl.process_string(code, nil, "foo.tl") + local result = tl.check_string(code, nil, "foo.tl") assert.is_not_nil(string.match(result.env.loaded["./bar.tl"].syntax_errors[1].filename, "bar.tl$"), "errors should contain .filename property") end) end) diff --git a/spec/util.lua b/spec/util.lua index bb6b9af4b..0442b16f6 100644 --- a/spec/util.lua +++ b/spec/util.lua @@ -435,7 +435,7 @@ local function check(lax, code, unknowns, gen_target) if gen_target == "5.4" then gen_compat = "off" end - local result = tl.type_check(ast, "foo.lua", { feat_lax = lax and "on" or "off", gen_target = gen_target, gen_compat = gen_compat }) + local result = tl.check(ast, "foo.lua", { feat_lax = lax and "on" or "off", gen_target = gen_target, gen_compat = gen_compat }) for _, mname in pairs(result.env.loaded_order) do local mresult = result.env.loaded[mname] @@ -462,7 +462,7 @@ local function check_type_error(lax, code, type_errors, gen_target) if gen_target == "5.4" then gen_compat = "off" end - local result = tl.type_check(ast, "foo.tl", { feat_lax = lax and "on" or "off", gen_target = gen_target, gen_compat = gen_compat }) + local result = tl.check(ast, "foo.tl", { feat_lax = lax and "on" or "off", gen_target = gen_target, gen_compat = gen_compat }) local result_type_errors = combine_result(result, "type_errors") batch_compare(batch, "type errors", type_errors, result_type_errors) @@ -531,7 +531,7 @@ function util.check_syntax_error(code, syntax_errors) local batch = batch_assertions() batch_compare(batch, "syntax errors", syntax_errors, errors) batch:assert() - tl.type_check(ast, "foo.tl", { feat_lax = "off" }) + tl.check(ast, "foo.tl", { feat_lax = "off" }) end end @@ -570,7 +570,7 @@ function util.check_types(code, types) local batch = batch_assertions() local env = tl.init_env() env.report_types = true - local result = tl.type_check(ast, "foo.tl", { feat_lax = "off" }, env) + local result = tl.check(ast, "foo.tl", { feat_lax = "off" }, env) batch:add(assert.same, {}, result.type_errors, "Code was not expected to have type errors") local tr = env.reporter:get_report() @@ -603,7 +603,7 @@ local function gen(lax, code, expected, gen_target) local ast, syntax_errors = tl.parse(code, "foo.tl") assert.same({}, syntax_errors, "Code was not expected to have syntax errors") local gen_compat = gen_target == "5.4" and "off" or nil - local result = tl.type_check(ast, "foo.tl", { feat_lax = lax and "on" or "off", gen_target = gen_target, gen_compat = gen_compat }) + local result = tl.check(ast, "foo.tl", { feat_lax = lax and "on" or "off", gen_target = gen_target, gen_compat = gen_compat }) assert.same({}, result.type_errors) local output_code = tl.pretty_print_ast(ast, gen_target) diff --git a/tl b/tl index 6a90ac055..9d7322d2b 100755 --- a/tl +++ b/tl @@ -260,7 +260,7 @@ local function type_check_and_load(tlconfig, filename) os.exit(1) end - local chunk; chunk, err = (loadstring or load)(tl.pretty_print_ast(result.ast, tlconfig.gen_target), "@" .. filename) + local chunk; chunk, err = (loadstring or load)(tl.generate(result.ast, tlconfig.gen_target), "@" .. filename) if err then die("Internal Compiler Error: Teal generator produced invalid Lua. " .. "Please report a bug at https://github.com/teal-language/tl\n\n" .. tostring(err)) @@ -281,7 +281,7 @@ local function write_out(tlconfig, result, output_file, pp_opts) end local _ - _, err = ofd:write(tl.pretty_print_ast(result.ast, tlconfig.gen_target, pp_opts) .. "\n") + _, err = ofd:write(tl.generate(result.ast, tlconfig.gen_target, pp_opts) .. "\n") if err then die("error writing " .. output_file .. ": " .. err) end diff --git a/tl.lua b/tl.lua index 0913c1b17..fc3e57753 100644 --- a/tl.lua +++ b/tl.lua @@ -492,7 +492,7 @@ local Errors = {} -local tl = {PrettyPrintOptions = {}, TypeCheckOptions = {}, Env = {}, Result = {}, Error = {}, TypeInfo = {}, TypeReport = {}, EnvOptions = {}, } +local tl = {GenerateOptions = {}, CheckOptions = {}, Env = {}, Result = {}, Error = {}, TypeInfo = {}, TypeReport = {}, EnvOptions = {}, TypeCheckOptions = {}, } @@ -634,7 +634,6 @@ local tl = {PrettyPrintOptions = {}, TypeCheckOptions = {}, Env = {}, Result = { -local TypeReporter = {} @@ -643,9 +642,44 @@ local TypeReporter = {} -tl.version = function() - return VERSION -end + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +local TypeReporter = {} + + + + + + + local wk = { ["unknown"] = true, @@ -1028,7 +1062,7 @@ do end end - function tl.lex(input, filename) + tl.lex = function(input, filename) local tokens = {} local state = "any" @@ -1518,7 +1552,7 @@ local function binary_search(list, item, cmp) end end -function tl.get_token_at(tks, y, x) +tl.get_token_at = function(tks, y, x) local _, found = binary_search( tks, nil, function(tk) @@ -4951,13 +4985,13 @@ local spaced_op = { } -local default_pretty_print_ast_opts = { +local default_generate_opts = { preserve_indent = true, preserve_newlines = true, preserve_hashbang = false, } -local fast_pretty_print_ast_opts = { +local fast_generate_opts = { preserve_indent = false, preserve_newlines = true, preserve_hashbang = false, @@ -4974,18 +5008,11 @@ local primitive = { ["thread"] = "thread", } -function tl.pretty_print_ast(ast, gen_target, mode) +function tl.generate(ast, gen_target, opts) local err local indent = 0 - local opts - if type(mode) == "table" then - opts = mode - elseif mode == true then - opts = fast_pretty_print_ast_opts - else - opts = default_pretty_print_ast_opts - end + opts = opts or default_generate_opts @@ -6807,7 +6834,7 @@ local function search_for(module_name, suffix, path, tried) return nil, nil, tried end -function tl.search_module(module_name, search_dtl) +tl.search_module = function(module_name, search_dtl) local found local fd local tried = {} @@ -6851,7 +6878,7 @@ local function require_module(w, module_name, opts, env) } env.defaults = defaults - local found_result, err = tl.process(found, env, fd) + local found_result, err = tl.check_file(found, env, fd) assert(found_result, err) env.defaults = save_defaults @@ -6885,7 +6912,7 @@ local function add_compat_entries(program, used_set, gen_compat) local code = compat_code_cache[name] if not code then code = tl.parse(text, "@internal") - tl.type_check(code, "@internal", { feat_lax = "off", gen_compat = "off" }) + tl.check(code, "@internal", { feat_lax = "off", gen_compat = "off" }) compat_code_cache[name] = code end for _, c in ipairs(code) do @@ -7025,7 +7052,7 @@ tl.new_env = function(opts) local program, syntax_errors = tl.parse(stdlib, "stdlib.d.tl") assert_no_stdlib_errors(syntax_errors, "syntax errors") - local result = tl.type_check(program, "@stdlib", {}, env) + local result = tl.check(program, "@stdlib", {}, env) assert_no_stdlib_errors(result.type_errors, "type errors") stdlib_globals = env.globals @@ -12918,10 +12945,7 @@ self:expand_type(node, values, elements) }) end end - tl.type_check = function(ast, filename, opts, env) - assert(type(filename) == "string", "tl.type_check signature has changed, pass filename separately") - assert((not opts) or (not (opts).env), "tl.type_check signature has changed, pass env separately") - + tl.check = function(ast, filename, opts, env) filename = filename or "?" opts = opts or {} @@ -13080,7 +13104,7 @@ local function feat_lax_heuristic(filename, input) return "off" end -tl.process = function(filename, env, fd) +tl.check_file = function(filename, env, fd) if env and env.loaded and env.loaded[filename] then return env.loaded[filename] end @@ -13100,7 +13124,7 @@ tl.process = function(filename, env, fd) return nil, "could not read " .. filename .. ": " .. err end - return tl.process_string(input, env, filename) + return tl.check_string(input, env, filename) end function tl.target_from_lua_version(str) @@ -13127,9 +13151,7 @@ local function default_env_opts(runtime, filename, input) } end -function tl.process_string(input, env, filename) - assert(type(env) ~= "boolean", "tl.process_string signature has changed") - +function tl.check_string(input, env, filename) env = env or tl.new_env(default_env_opts(false, filename, input)) if env.loaded and env.loaded[filename] then @@ -13153,23 +13175,23 @@ function tl.process_string(input, env, filename) return result end - local result = tl.type_check(program, filename, env.defaults, env) + local result = tl.check(program, filename, env.defaults, env) result.syntax_errors = syntax_errors return result end -tl.gen = function(input, env, pp) +tl.gen = function(input, env, opts) env = env or assert(tl.new_env(default_env_opts(false, nil, input)), "Default environment initialization failed") - local result = tl.process_string(input, env) + local result = tl.check_string(input, env) if (not result.ast) or #result.syntax_errors > 0 then return nil, result end local code - code, result.gen_error = tl.pretty_print_ast(result.ast, env.defaults.gen_target, pp) + code, result.gen_error = tl.generate(result.ast, env.defaults.gen_target, opts) return code, result end @@ -13197,13 +13219,13 @@ local function tl_package_loader(module_name) local w = { f = found_filename, x = 1, y = 1 } env.modules[module_name] = a_type(w, "typedecl", { def = a_type(w, "circular_require", {}) }) - local result = tl.type_check(program, found_filename, opts.defaults, env) + local result = tl.check(program, found_filename, opts.defaults, env) env.modules[module_name] = result.type - local code = assert(tl.pretty_print_ast(program, opts.defaults.gen_target, true)) + local code = assert(tl.generate(program, opts.defaults.gen_target, fast_generate_opts)) local chunk, err = load(code, "@" .. found_filename, "t") if chunk then return function(modname, loader_data) @@ -13258,7 +13280,7 @@ tl.load = function(input, chunkname, mode, ...) end local filename = chunkname or ("string \"" .. input:sub(45) .. (#input > 45 and "..." or "") .. "\"") - local result = tl.type_check(program, filename, opts.defaults, env_for(opts, ...)) + local result = tl.check(program, filename, opts.defaults, env_for(opts, ...)) if mode and mode:match("c") then if #result.type_errors > 0 then @@ -13272,7 +13294,7 @@ tl.load = function(input, chunkname, mode, ...) mode = mode:gsub("c", "") end - local code, err = tl.pretty_print_ast(program, opts.defaults.gen_target, true) + local code, err = tl.generate(program, opts.defaults.gen_target, fast_generate_opts) if not code then return nil, err end @@ -13280,6 +13302,10 @@ tl.load = function(input, chunkname, mode, ...) return load(code, chunkname, mode, ...) end +tl.version = function() + return VERSION +end + @@ -13305,4 +13331,36 @@ tl.init_env = function(lax, gen_compat, gen_target, predefined) return tl.new_env(opts) end +tl.type_check = function(ast, tc_opts) + local opts = { + feat_lax = tc_opts.lax and "on" or "off", + feat_arity = tc_opts.env and tc_opts.env.defaults.feat_arity or "on", + gen_compat = tc_opts.gen_compat, + gen_target = tc_opts.gen_target, + run_internal_compiler_checks = tc_opts.run_internal_compiler_checks, + } + return tl.check(ast, tc_opts.filename, opts, tc_opts.env) +end + +tl.pretty_print_ast = function(ast, gen_target, mode) + local opts + if type(mode) == "table" then + opts = mode + elseif mode == true then + opts = fast_generate_opts + else + opts = default_generate_opts + end + + return tl.generate(ast, gen_target, opts) +end + +tl.process = function(filename, env, fd) + return tl.check_file(filename, env, fd) +end + +tl.process_string = function(input, is_lua, env, filename, _module_name) + return tl.check_string(input, env or tl.init_env(is_lua), filename) +end + return tl diff --git a/tl.tl b/tl.tl index 24a12f0f6..8ab65623e 100644 --- a/tl.tl +++ b/tl.tl @@ -520,13 +520,13 @@ local record tl "off" end - record PrettyPrintOptions + record GenerateOptions preserve_indent: boolean preserve_newlines: boolean preserve_hashbang: boolean end - record TypeCheckOptions + record CheckOptions feat_lax: Feat feat_arity: Feat gen_compat: GenCompat @@ -543,7 +543,7 @@ local record tl reporter: TypeReporter keep_going: boolean report_types: boolean - defaults: TypeCheckOptions + defaults: CheckOptions end record Result @@ -615,20 +615,58 @@ local record tl end record EnvOptions - defaults: TypeCheckOptions + defaults: CheckOptions predefined_modules: {string} end - load: function(string, string, LoadMode, {any:any}): LoadFunction, string - process: function(string, Env, ? FILE): (Result, string) - process_string: function(string, Env, ? string): Result - gen: function(string, Env, PrettyPrintOptions): string, Result - type_check: function(Node, string, TypeCheckOptions, ? Env): Result, string + ----------------------------------------------------------------------------- + -- Public API + ----------------------------------------------------------------------------- + + check: function(Node, ? string, ? CheckOptions, ? Env): Result, string + gen: function(string, ? Env, ? GenerateOptions): string, Result + generate: function(ast: Node, gen_target: GenTarget, opts?: GenerateOptions): string, string + get_token_at: function(tks: {Token}, y: integer, x: integer): string + lex: function(input: string, filename: string): {Token}, {Error} + load: function(string, ? string, ? LoadMode, ...: {any:any}): LoadFunction, string + loader: function() new_env: function(? EnvOptions): Env, string + parse: function(input: string, filename: string): Node, {Error}, {string} + parse_program: function(tokens: {Token}, errs: {Error}, filename?: string): Node, {string} + check_file: function(filename: string, env?: Env, fd?: FILE): (Result, string) + check_string: function(input: string, env?: Env, filename?: string): Result + search_module: function(module_name: string, search_dtl: boolean): string, FILE, {string} + symbols_in_scope: function(tr: TypeReport, y: integer, x: integer, filename: string): {string:integer} + target_from_lua_version: function(str: string): GenTarget version: function(): string - -- Backwards compatibility - init_env: function(? boolean, ? boolean | GenCompat, ? GenTarget, ? {string}): Env, string + ----------------------------------------------------------------------------- + -- Deprecated, mantained for backwards compatibility: + ----------------------------------------------------------------------------- + + type CompatMode = GenCompat + type PrettyPrintOptions = GenerateOptions + type TargetMode = GenTarget + + record TypeCheckOptions + lax: boolean + filename: string + module_name: string + gen_compat: tl.CompatMode + gen_target: tl.TargetMode + env: Env + run_internal_compiler_checks: boolean + end + + init_env: function(? boolean, ? boolean | tl.CompatMode, ? tl.TargetMode, ? {string}): Env, string + pretty_print_ast: function(ast: Node, gen_target?: tl.TargetMode, mode?: boolean | tl.PrettyPrintOptions): string, string + process: function(filename: string, env?: Env, fd?: FILE): Result, string + process_string: function(input: string, is_lua: boolean, env: Env, filename: string, _module_name: string): Result + type_check: function(Node, TypeCheckOptions): Result, string + + ----------------------------------------------------------------------------- + -- Private data: + ----------------------------------------------------------------------------- package_loader_env: Env load_envs: { {any:any} : Env } @@ -643,10 +681,6 @@ local record TypeReporter get_typenum: function(self, Type): integer end -tl.version = function(): string - return VERSION -end - local wk : {tl.WarningKind:boolean} = { ["unknown"] = true, ["unused"] = true, @@ -707,9 +741,9 @@ local type GenCompat = tl.GenCompat local type GenTarget = tl.GenTarget local type LoadFunction = tl.LoadFunction local type LoadMode = tl.LoadMode -local type PrettyPrintOptions = tl.PrettyPrintOptions +local type GenerateOptions = tl.GenerateOptions local type Result = tl.Result -local type TypeCheckOptions = tl.TypeCheckOptions +local type CheckOptions = tl.CheckOptions local type TypeInfo = tl.TypeInfo local type TypeReport = tl.TypeReport local type WarningKind = tl.WarningKind @@ -1028,7 +1062,7 @@ do end end - function tl.lex(input: string, filename: string): {Token}, {Error} + tl.lex = function(input: string, filename: string): {Token}, {Error} local tokens: {Token} = {} local state: LexState = "any" @@ -1518,7 +1552,7 @@ local function binary_search(list: {T}, item: U, cmp: function(T, U): bool end end -function tl.get_token_at(tks: {Token}, y: integer, x: integer): string +tl.get_token_at = function(tks: {Token}, y: integer, x: integer): string local _, found = binary_search( tks, nil, function(tk: Token): boolean @@ -4370,7 +4404,7 @@ parse_statements = function(ps: ParseState, i: integer, toplevel?: boolean): int return i, node end -function tl.parse_program(tokens: {Token}, errs: {Error}, filename: string): Node, {string} +function tl.parse_program(tokens: {Token}, errs: {Error}, filename?: string): Node, {string} errs = errs or {} local ps: ParseState = { tokens = tokens, @@ -4906,7 +4940,7 @@ local function recurse_node(s: S, root: Node, end -------------------------------------------------------------------------------- --- Pretty-print AST +-- Lua code generation -------------------------------------------------------------------------------- local tight_op: {integer:{string:boolean}} = { @@ -4951,13 +4985,13 @@ local spaced_op: {integer:{string:boolean}} = { } -local default_pretty_print_ast_opts: PrettyPrintOptions = { +local default_generate_opts: GenerateOptions = { preserve_indent = true, preserve_newlines = true, preserve_hashbang = false, } -local fast_pretty_print_ast_opts: PrettyPrintOptions = { +local fast_generate_opts: GenerateOptions = { preserve_indent = false, preserve_newlines = true, preserve_hashbang = false, @@ -4974,18 +5008,11 @@ local primitive: {TypeName:string} = { ["thread"] = "thread", } -function tl.pretty_print_ast(ast: Node, gen_target: GenTarget, mode?: boolean | PrettyPrintOptions): string, string +function tl.generate(ast: Node, gen_target: GenTarget, opts?: GenerateOptions): string, string local err: string local indent = 0 - local opts: PrettyPrintOptions - if mode is PrettyPrintOptions then - opts = mode - elseif mode == true then - opts = fast_pretty_print_ast_opts - else - opts = default_pretty_print_ast_opts - end + opts = opts or default_generate_opts local record Output {string} @@ -6807,7 +6834,7 @@ local function search_for(module_name: string, suffix: string, path: string, tri return nil, nil, tried end -function tl.search_module(module_name: string, search_dtl: boolean): string, FILE, {string} +tl.search_module = function(module_name: string, search_dtl: boolean): string, FILE, {string} local found: string local fd: FILE local tried: {string} = {} @@ -6829,7 +6856,7 @@ function tl.search_module(module_name: string, search_dtl: boolean): string, FIL return nil, nil, tried end -local function require_module(w: Where, module_name: string, opts: TypeCheckOptions, env: Env): Type, string +local function require_module(w: Where, module_name: string, opts: CheckOptions, env: Env): Type, string local mod = env.modules[module_name] if mod then return mod, env.module_filenames[module_name] @@ -6842,7 +6869,7 @@ local function require_module(w: Where, module_name: string, opts: TypeCheckOpti env.modules[module_name] = a_typedecl(w, a_type(w, "circular_require", {})) local save_defaults = env.defaults - local defaults : TypeCheckOptions = { + local defaults : CheckOptions = { feat_lax = opts.feat_lax or save_defaults.feat_lax, feat_arity = opts.feat_arity or save_defaults.feat_arity, gen_compat = opts.gen_compat or save_defaults.gen_compat, @@ -6851,7 +6878,7 @@ local function require_module(w: Where, module_name: string, opts: TypeCheckOpti } env.defaults = defaults - local found_result, err: Result, string = tl.process(found, env, fd) + local found_result, err: Result, string = tl.check_file(found, env, fd) assert(found_result, err) env.defaults = save_defaults @@ -6885,7 +6912,7 @@ local function add_compat_entries(program: Node, used_set: {string: boolean}, ge local code: Node = compat_code_cache[name] if not code then code = tl.parse(text, "@internal") - tl.type_check(code, "@internal", { feat_lax = "off", gen_compat = "off" }) + tl.check(code, "@internal", { feat_lax = "off", gen_compat = "off" }) compat_code_cache[name] = code end for _, c in ipairs(code) do @@ -7025,7 +7052,7 @@ tl.new_env = function(opts?: EnvOptions): Env, string local program, syntax_errors = tl.parse(stdlib, "stdlib.d.tl") assert_no_stdlib_errors(syntax_errors, "syntax errors") - local result = tl.type_check(program, "@stdlib", {}, env) + local result = tl.check(program, "@stdlib", {}, env) assert_no_stdlib_errors(result.type_errors, "type errors") stdlib_globals = env.globals @@ -10512,7 +10539,7 @@ do end local module_name = assert(node.e2[1].conststr) - local tc_opts: TypeCheckOptions = { + local tc_opts: CheckOptions = { feat_lax = self.feat_lax and "on" or "off", feat_arity = self.feat_arity and "on" or "off", } @@ -12918,10 +12945,7 @@ do end end - tl.type_check = function(ast: Node, filename: string, opts: TypeCheckOptions, env?: Env): Result, string - assert(filename is string, "tl.type_check signature has changed, pass filename separately") - assert((not opts) or (not (opts as {any:any}).env), "tl.type_check signature has changed, pass env separately") - + tl.check = function(ast: Node, filename?: string, opts?: CheckOptions, env?: Env): Result, string filename = filename or "?" opts = opts or {} @@ -13080,7 +13104,7 @@ local function feat_lax_heuristic(filename?: string, input?: string): Feat return "off" end -tl.process = function(filename: string, env: Env, fd?: FILE): Result, string +tl.check_file = function(filename: string, env?: Env, fd?: FILE): Result, string if env and env.loaded and env.loaded[filename] then return env.loaded[filename] end @@ -13100,7 +13124,7 @@ tl.process = function(filename: string, env: Env, fd?: FILE): Result, string return nil, "could not read " .. filename .. ": " .. err end - return tl.process_string(input, env, filename) + return tl.check_string(input, env, filename) end function tl.target_from_lua_version(str: string): GenTarget @@ -13127,9 +13151,7 @@ local function default_env_opts(runtime: boolean, filename?: string, input?: str } end -function tl.process_string(input: string, env?: Env, filename?: string): Result - assert(type(env) ~= "boolean", "tl.process_string signature has changed") - +function tl.check_string(input: string, env?: Env, filename?: string): Result env = env or tl.new_env(default_env_opts(false, filename, input)) if env.loaded and env.loaded[filename] then @@ -13153,23 +13175,23 @@ function tl.process_string(input: string, env?: Env, filename?: string): Result return result end - local result = tl.type_check(program, filename, env.defaults, env) + local result = tl.check(program, filename, env.defaults, env) result.syntax_errors = syntax_errors return result end -tl.gen = function(input: string, env: Env, pp: PrettyPrintOptions): string, Result +tl.gen = function(input: string, env?: Env, opts?: GenerateOptions): string, Result env = env or assert(tl.new_env(default_env_opts(false, nil, input)), "Default environment initialization failed") - local result = tl.process_string(input, env) + local result = tl.check_string(input, env) if (not result.ast) or #result.syntax_errors > 0 then return nil, result end local code: string - code, result.gen_error = tl.pretty_print_ast(result.ast, env.defaults.gen_target, pp) + code, result.gen_error = tl.generate(result.ast, env.defaults.gen_target, opts) return code, result end @@ -13197,13 +13219,13 @@ local function tl_package_loader(module_name: string): any, any local w = { f = found_filename, x = 1, y = 1 } env.modules[module_name] = a_typedecl(w, a_type(w, "circular_require", {})) - local result = tl.type_check(program, found_filename, opts.defaults, env) + local result = tl.check(program, found_filename, opts.defaults, env) env.modules[module_name] = result.type -- TODO: should this be a hard error? this seems analogous to -- finding a lua file with a syntax error in it - local code = assert(tl.pretty_print_ast(program, opts.defaults.gen_target, true)) + local code = assert(tl.generate(program, opts.defaults.gen_target, fast_generate_opts)) local chunk, err = load(code, "@" .. found_filename, "t") if chunk then return function(modname: string, loader_data: string): any @@ -13245,7 +13267,7 @@ local function env_for(opts: EnvOptions, env_tbl: {any:any}): Env return tl.load_envs[env_tbl] end -tl.load = function(input: string, chunkname: string, mode: LoadMode, ...: {any:any}): LoadFunction, string +tl.load = function(input: string, chunkname?: string, mode?: LoadMode, ...: {any:any}): LoadFunction, string local program, errs = tl.parse(input, chunkname) if #errs > 0 then return nil, (chunkname or "") .. ":" .. errs[1].y .. ":" .. errs[1].x .. ": " .. errs[1].msg @@ -13258,7 +13280,7 @@ tl.load = function(input: string, chunkname: string, mode: LoadMode, ...: {any:a end local filename = chunkname or ("string \"" .. input:sub(45) .. (#input > 45 and "..." or "") .. "\"") - local result = tl.type_check(program, filename, opts.defaults, env_for(opts, ...)) + local result = tl.check(program, filename, opts.defaults, env_for(opts, ...)) if mode and mode:match("c") then if #result.type_errors > 0 then @@ -13272,7 +13294,7 @@ tl.load = function(input: string, chunkname: string, mode: LoadMode, ...: {any:a mode = mode:gsub("c", "") as LoadMode end - local code, err = tl.pretty_print_ast(program, opts.defaults.gen_target, true) + local code, err = tl.generate(program, opts.defaults.gen_target, fast_generate_opts) if not code then return nil, err end @@ -13280,6 +13302,10 @@ tl.load = function(input: string, chunkname: string, mode: LoadMode, ...: {any:a return load(code, chunkname, mode, ...) end +tl.version = function(): string + return VERSION +end + -------------------------------------------------------------------------------- -- Backwards compatibility -------------------------------------------------------------------------------- @@ -13305,4 +13331,36 @@ tl.init_env = function(lax?: boolean, gen_compat?: boolean | GenCompat, gen_targ return tl.new_env(opts) end +tl.type_check = function(ast: Node, tc_opts?: tl.TypeCheckOptions): Result, string + local opts: CheckOptions = { + feat_lax = tc_opts.lax and "on" or "off", + feat_arity = tc_opts.env and tc_opts.env.defaults.feat_arity or "on", + gen_compat = tc_opts.gen_compat, + gen_target = tc_opts.gen_target, + run_internal_compiler_checks = tc_opts.run_internal_compiler_checks, + } + return tl.check(ast, tc_opts.filename, opts, tc_opts.env) +end + +tl.pretty_print_ast = function(ast: Node, gen_target?: tl.TargetMode, mode?: boolean | tl.PrettyPrintOptions): string, string + local opts: GenerateOptions + if mode is tl.PrettyPrintOptions then + opts = mode + elseif mode == true then + opts = fast_generate_opts + else + opts = default_generate_opts + end + + return tl.generate(ast, gen_target, opts) +end + +tl.process = function(filename: string, env?: Env, fd?: FILE): Result, string + return tl.check_file(filename, env, fd) +end + +tl.process_string = function(input: string, is_lua: boolean, env: Env, filename: string, _module_name: string): Result + return tl.check_string(input, env or tl.init_env(is_lua), filename) +end + return tl From 65a904e7144d07812104b39183dd0a2d6310a4d9 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Sat, 31 Aug 2024 22:37:21 -0300 Subject: [PATCH 192/224] API: expose Token record and Node abstract interface --- tl.lua | 56 +++++++++++++++++++++++++--------------------- tl.tl | 70 +++++++++++++++++++++++++++++++--------------------------- 2 files changed, 69 insertions(+), 57 deletions(-) diff --git a/tl.lua b/tl.lua index fc3e57753..0292b3bcb 100644 --- a/tl.lua +++ b/tl.lua @@ -492,7 +492,35 @@ local Errors = {} -local tl = {GenerateOptions = {}, CheckOptions = {}, Env = {}, Result = {}, Error = {}, TypeInfo = {}, TypeReport = {}, EnvOptions = {}, TypeCheckOptions = {}, } +local tl = {GenerateOptions = {}, CheckOptions = {}, Env = {}, Result = {}, Error = {}, TypeInfo = {}, TypeReport = {}, EnvOptions = {}, Token = {}, TypeCheckOptions = {}, } + + + + + + + + + + + + + + + + + + + + + + + + + + + + @@ -748,6 +776,8 @@ tl.typecodes = { + + local DEFAULT_GEN_COMPAT = "optional" local DEFAULT_GEN_TARGET = "5.3" @@ -821,30 +851,6 @@ end - - - - - - - - - - - - - - - - - - - - - - - - do diff --git a/tl.tl b/tl.tl index 8ab65623e..235a018d0 100644 --- a/tl.tl +++ b/tl.tl @@ -619,6 +619,34 @@ local record tl predefined_modules: {string} end + -- abstract type + interface Node + end + + enum TokenKind + "hashbang" + "keyword" + "op" + "string" + "[" "]" "(" ")" "{" "}" "," ":" "." ";" "?" + "::" + "..." + "identifier" + "number" + "integer" + "pragma" + "pragma_identifier" + "$ERR$" + "$EOF$" + end + + record Token + x: integer + y: integer + tk: string + kind: TokenKind + end + ----------------------------------------------------------------------------- -- Public API ----------------------------------------------------------------------------- @@ -744,6 +772,8 @@ local type LoadMode = tl.LoadMode local type GenerateOptions = tl.GenerateOptions local type Result = tl.Result local type CheckOptions = tl.CheckOptions +local type Token = tl.Token +local type TokenKind = tl.TokenKind local type TypeInfo = tl.TypeInfo local type TypeReport = tl.TypeReport local type WarningKind = tl.WarningKind @@ -821,30 +851,6 @@ end -- Lexer -------------------------------------------------------------------------------- -local enum TokenKind - "hashbang" - "keyword" - "op" - "string" - "[" "]" "(" ")" "{" "}" "," ":" "." ";" "?" - "::" - "..." - "identifier" - "number" - "integer" - "pragma" - "pragma_identifier" - "$ERR$" - "$EOF$" -end - -local record Token - x: integer - y: integer - tk: string - kind: TokenKind -end - do local enum LexState "start" @@ -2063,7 +2069,7 @@ local attributes : {Attribute: boolean} = { local is_attribute : {string:boolean} = attributes as {string:boolean} local record Node - is {Node}, Where + is {Node}, tl.Node, Where where self.kind ~= nil record ExpectedContext @@ -4404,10 +4410,10 @@ parse_statements = function(ps: ParseState, i: integer, toplevel?: boolean): int return i, node end -function tl.parse_program(tokens: {Token}, errs: {Error}, filename?: string): Node, {string} +function tl.parse_program(tokens: {Token}, errs: {Error}, filename?: string): tl.Node, {string} errs = errs or {} local ps: ParseState = { - tokens = tokens, + tokens = tokens as {Token}, errs = errs, filename = filename or "", required_modules = {}, @@ -4427,7 +4433,7 @@ function tl.parse_program(tokens: {Token}, errs: {Error}, filename?: string): No return node, ps.required_modules end -function tl.parse(input: string, filename: string): Node, {Error}, {string} +function tl.parse(input: string, filename: string): tl.Node, {Error}, {string} local tokens, errs = tl.lex(input, filename) local node, required_modules = tl.parse_program(tokens, errs, filename) return node, errs, required_modules @@ -4695,7 +4701,7 @@ local no_recurse_node: {NodeKind : boolean} = { ["type_identifier"] = true, } -local function recurse_node(s: S, root: Node, +local function recurse_node(s: S, root: tl.Node, visit_node: Visitor, visit_type: Visitor): T if not root then @@ -4703,7 +4709,7 @@ local function recurse_node(s: S, root: Node, return end - local recurse: function(Node): T + local recurse: function(tl.Node): T local function walk_children(ast: Node, xs: {T}) for i, child in ipairs(ast) do @@ -5008,7 +5014,7 @@ local primitive: {TypeName:string} = { ["thread"] = "thread", } -function tl.generate(ast: Node, gen_target: GenTarget, opts?: GenerateOptions): string, string +function tl.generate(ast: tl.Node, gen_target: GenTarget, opts?: GenerateOptions): string, string local err: string local indent = 0 @@ -6911,7 +6917,7 @@ local function add_compat_entries(program: Node, used_set: {string: boolean}, ge local function load_code(name: string, text: string) local code: Node = compat_code_cache[name] if not code then - code = tl.parse(text, "@internal") + code = tl.parse(text, "@internal") as Node tl.check(code, "@internal", { feat_lax = "off", gen_compat = "off" }) compat_code_cache[name] = code end From 775ceda5ed0b727df6572d9174544f28aa5843d1 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Thu, 5 Sep 2024 01:00:35 -0300 Subject: [PATCH 193/224] funcall: minor simplification --- tl.lua | 12 ++++++------ tl.tl | 12 ++++++------ 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/tl.lua b/tl.lua index 0292b3bcb..989ecf2b8 100644 --- a/tl.lua +++ b/tl.lua @@ -9277,9 +9277,7 @@ a.types[i], b.types[i]), } infer_emptytables(self, w, where_args, args, f.args, argdelta) - mark_invalid_typeargs(self, f) - - return self:resolve_typevars_at(w, f.rets) + return f.rets end end @@ -9335,9 +9333,7 @@ a.types[i], b.types[i]), } local f = resolve_function_type(func, 1) - mark_invalid_typeargs(self, f) - - return self:resolve_typevars_at(w, f.rets) + return f.rets end local function check_call(self, w, where_args, func, args, expected_rets, is_typedecl_funcall, argdelta) @@ -9449,6 +9445,10 @@ a.types[i], b.types[i]), } local ret, f = check_call(self, node, e2, func, args, expected_rets, is_typedecl_funcall, argdelta or 0) + if f then + mark_invalid_typeargs(self, f) + end + ret = self:resolve_typevars_at(node, ret) self:end_scope() diff --git a/tl.tl b/tl.tl index 235a018d0..dbc702c30 100644 --- a/tl.tl +++ b/tl.tl @@ -9277,9 +9277,7 @@ do infer_emptytables(self, w, where_args, args, f.args, argdelta) - mark_invalid_typeargs(self, f) - - return self:resolve_typevars_at(w, f.rets) + return f.rets end end @@ -9335,9 +9333,7 @@ do local f = resolve_function_type(func, 1) - mark_invalid_typeargs(self, f) - - return self:resolve_typevars_at(w, f.rets) + return f.rets end local function check_call(self: TypeChecker, w: Where, where_args: {Node}, func: Type, args: TupleType, expected_rets: TupleType, is_typedecl_funcall: boolean, argdelta: integer): InvalidOrTupleType, FunctionType @@ -9449,6 +9445,10 @@ do local ret, f = check_call(self, node, e2, func, args, expected_rets, is_typedecl_funcall, argdelta or 0) + if f then + mark_invalid_typeargs(self, f) + end + ret = self:resolve_typevars_at(node, ret) self:end_scope() From ae05702d96782f7e0bfd23a2003e2355b007889b Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Fri, 6 Sep 2024 02:01:26 -0300 Subject: [PATCH 194/224] typevar_resolver: drop unused argument --- tl.lua | 2 +- tl.tl | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tl.lua b/tl.lua index 989ecf2b8..06fe5ad0f 100644 --- a/tl.lua +++ b/tl.lua @@ -7604,7 +7604,7 @@ do clear_resolved_typeargs(copy, resolved) - return true, copy, nil, resolved + return true, copy, nil end local function resolve_typevar(tc, t) diff --git a/tl.tl b/tl.tl index dbc702c30..5c340875a 100644 --- a/tl.tl +++ b/tl.tl @@ -7432,7 +7432,7 @@ do return end - typevar_resolver = function(self: S, typ: Type, fn_var: ResolveType, fn_arg?: ResolveType): boolean, Type, {Error}, {string:Type} + typevar_resolver = function(self: S, typ: Type, fn_var: ResolveType, fn_arg?: ResolveType): boolean, Type, {Error} local errs: {Error} local seen: {Type:Type} = {} local resolved: {string:Type} = {} @@ -7604,7 +7604,7 @@ do clear_resolved_typeargs(copy, resolved) - return true, copy, nil, resolved + return true, copy, nil end local function resolve_typevar(tc: TypeChecker, t: TypeVarType): Type From 34193b78fe4a7b7ddf1e5de73ee9326c91e27c8d Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Thu, 5 Sep 2024 13:51:51 -0300 Subject: [PATCH 195/224] funcall: narrow type of check_call --- tl.lua | 30 ++++++++++++++++-------------- tl.tl | 30 ++++++++++++++++-------------- 2 files changed, 32 insertions(+), 28 deletions(-) diff --git a/tl.lua b/tl.lua index 06fe5ad0f..9e5520b2a 100644 --- a/tl.lua +++ b/tl.lua @@ -9336,22 +9336,10 @@ a.types[i], b.types[i]), } return f.rets end - local function check_call(self, w, where_args, func, args, expected_rets, is_typedecl_funcall, argdelta) + local function check_call(self, w, where_args, func, args, expected_rets, is_typedecl_funcall, argdelta, is_method) assert(type(func) == "table") assert(type(args) == "table") - local is_method = (argdelta == -1) - - if not (func.typename == "function" or func.typename == "poly") then - func, is_method = self:resolve_for_call(func, args, is_method) - if is_method then - argdelta = -1 - end - if not (func.typename == "function" or func.typename == "poly") then - return self.errs:invalid_at(w, "not a function: %s", func) - end - end - if is_method and args.tuple[1] then self:add_var(nil, "@self", a_type(w, "typedecl", { def = args.tuple[1] })) end @@ -9443,7 +9431,21 @@ a.types[i], b.types[i]), } end end - local ret, f = check_call(self, node, e2, func, args, expected_rets, is_typedecl_funcall, argdelta or 0) + local is_method = (argdelta == -1) + + if not (func.typename == "function" or func.typename == "poly") then + func, is_method = self:resolve_for_call(func, args, is_method) + if is_method then + argdelta = -1 + end + end + + local ret, f + if func.typename == "function" or func.typename == "poly" then + ret, f = check_call(self, node, e2, func, args, expected_rets, is_typedecl_funcall, argdelta or 0, is_method) + else + ret = self.errs:invalid_at(node, "not a function: %s", func) + end if f then mark_invalid_typeargs(self, f) diff --git a/tl.tl b/tl.tl index 5c340875a..9221fec3f 100644 --- a/tl.tl +++ b/tl.tl @@ -9336,22 +9336,10 @@ do return f.rets end - local function check_call(self: TypeChecker, w: Where, where_args: {Node}, func: Type, args: TupleType, expected_rets: TupleType, is_typedecl_funcall: boolean, argdelta: integer): InvalidOrTupleType, FunctionType + local function check_call(self: TypeChecker, w: Where, where_args: {Node}, func: FunctionType | PolyType, args: TupleType, expected_rets: TupleType, is_typedecl_funcall: boolean, argdelta: integer, is_method: boolean): InvalidOrTupleType, FunctionType assert(type(func) == "table") assert(type(args) == "table") - local is_method = (argdelta == -1) - - if not (func is FunctionType or func is PolyType) then - func, is_method = self:resolve_for_call(func, args, is_method) - if is_method then - argdelta = -1 - end - if not (func is FunctionType or func is PolyType) then - return self.errs:invalid_at(w, "not a function: %s", func) - end - end - if is_method and args.tuple[1] then self:add_var(nil, "@self", a_typedecl(w, args.tuple[1])) end @@ -9443,7 +9431,21 @@ do end end - local ret, f = check_call(self, node, e2, func, args, expected_rets, is_typedecl_funcall, argdelta or 0) + local is_method = (argdelta == -1) + + if not (func is FunctionType or func is PolyType) then + func, is_method = self:resolve_for_call(func, args, is_method) + if is_method then + argdelta = -1 + end + end + + local ret, f: InvalidOrTupleType, FunctionType + if func is FunctionType or func is PolyType then + ret, f = check_call(self, node, e2, func, args, expected_rets, is_typedecl_funcall, argdelta or 0, is_method) + else + ret = self.errs:invalid_at(node, "not a function: %s", func) + end if f then mark_invalid_typeargs(self, f) From 96baa3d6f9ccb5ccb6731eb50d21ae6f65e70b56 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Thu, 5 Sep 2024 15:44:18 -0300 Subject: [PATCH 196/224] funcall: feat_lax=on implies feat_arity=off --- tl.lua | 8 ++++++-- tl.tl | 8 ++++++-- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/tl.lua b/tl.lua index 9e5520b2a..84016a83e 100644 --- a/tl.lua +++ b/tl.lua @@ -9371,11 +9371,11 @@ a.types[i], b.types[i]), } local min_arity = self.feat_arity and f.min_arity or 0 - if (passes == 1 and ((given <= wanted and given >= min_arity) or (f.args.is_va and given > wanted) or (self.feat_lax and given <= wanted))) or + if (passes == 1 and ((given <= wanted and given >= min_arity) or (f.args.is_va and given > wanted))) or (passes == 3 and ((pass == 1 and given == wanted) or - (pass == 2 and given < wanted and (self.feat_lax or given >= min_arity)) or + (pass == 2 and given < wanted and given >= min_arity) or (pass == 3 and f.args.is_va and given > wanted))) then @@ -12992,6 +12992,10 @@ self:expand_type(node, values, elements) }) self.gen_compat = opts.gen_compat or env.defaults.gen_compat or DEFAULT_GEN_COMPAT self.gen_target = opts.gen_target or env.defaults.gen_target or DEFAULT_GEN_TARGET + if self.feat_lax then + self.feat_arity = false + end + if self.gen_target == "5.4" and self.gen_compat ~= "off" then return nil, "gen-compat must be explicitly 'off' when gen-target is '5.4'" end diff --git a/tl.tl b/tl.tl index 9221fec3f..e1857f4c0 100644 --- a/tl.tl +++ b/tl.tl @@ -9371,11 +9371,11 @@ do local min_arity = self.feat_arity and f.min_arity or 0 -- simple functions: - if (passes == 1 and ((given <= wanted and given >= min_arity) or (f.args.is_va and given > wanted) or (self.feat_lax and given <= wanted))) + if (passes == 1 and ((given <= wanted and given >= min_arity) or (f.args.is_va and given > wanted))) -- poly, pass 1: try exact arity matches first or (passes == 3 and ((pass == 1 and given == wanted) -- poly, pass 2: then try adjusting with nils to missing arguments or using '...' - or (pass == 2 and given < wanted and (self.feat_lax or given >= min_arity)) + or (pass == 2 and given < wanted and given >= min_arity) -- poly, pass 3: then finally try vararg functions or (pass == 3 and f.args.is_va and given > wanted))) then @@ -12992,6 +12992,10 @@ do self.gen_compat = opts.gen_compat or env.defaults.gen_compat or DEFAULT_GEN_COMPAT self.gen_target = opts.gen_target or env.defaults.gen_target or DEFAULT_GEN_TARGET + if self.feat_lax then + self.feat_arity = false + end + if self.gen_target == "5.4" and self.gen_compat ~= "off" then return nil, "gen-compat must be explicitly 'off' when gen-target is '5.4'" end From bcaac589f48f389645cb8bb72a7104c3a83a0420 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Thu, 5 Sep 2024 18:02:21 -0300 Subject: [PATCH 197/224] funcall: split logic for FunctionType and PolyType --- tl.lua | 338 +++++++++++++++++++++++++++++------------------------- tl.tl | 352 +++++++++++++++++++++++++++++++-------------------------- 2 files changed, 371 insertions(+), 319 deletions(-) diff --git a/tl.lua b/tl.lua index 84016a83e..450d65ea2 100644 --- a/tl.lua +++ b/tl.lua @@ -1991,6 +1991,7 @@ end + local TruthyFact = {} @@ -2725,7 +2726,7 @@ do i, typ.typeargs = parse_typeargs_if_any(ps, i) if ps.tokens[i].tk == "(" then - i, typ.args, typ.is_method, typ.min_arity = parse_argument_type_list(ps, i) + i, typ.args, typ.maybe_method, typ.min_arity = parse_argument_type_list(ps, i) i, typ.rets = parse_return_types(ps, i) else typ.args = new_tuple(ps, i, { new_type(ps, i, "any") }, true) @@ -3985,6 +3986,9 @@ do if not t then return fail(ps, i, "expected a type") end + if t.typename == "function" and t.maybe_method then + t.is_method = true + end local field_name = v.conststr or v.tk local fields = def.fields @@ -6155,6 +6159,15 @@ local function Err(msg, t1, t2, t3) } end +local function Err_at(w, msg) + return { + msg = msg, + x = assert(w.x), + y = assert(w.y), + filename = assert(w.f), + } +end + local function insert_error(self, y, x, err) err.y = assert(y) err.x = assert(x) @@ -9216,194 +9229,202 @@ a.types[i], b.types[i]), } end end - local check_args_rets - do - - local function check_func_type_list(self, w, wheres, xs, ys, from, delta, v, mode) - assert(xs.typename == "tuple", xs.typename) - assert(ys.typename == "tuple", ys.typename) - - local errs = {} - local xt, yt = xs.tuple, ys.tuple - local n_xs = #xt - local n_ys = #yt + local function push_typeargs(self, func) + if func.typeargs then + for _, fnarg in ipairs(func.typeargs) do + self:add_var(nil, fnarg.typearg, a_type(fnarg, "unresolved_typearg", { + constraint = fnarg.constraint, + })) + end + end + end - for i = from, math.max(n_xs, n_ys) do - local pos = i + delta - local x = xt[i] or (xs.is_va and xt[n_xs]) or a_type(w, "nil", {}) - local y = yt[i] or (ys.is_va and yt[n_ys]) - if y then - local iw = wheres and wheres[pos] or w - if not self:arg_check(iw, errs, x, y, v, mode, pos) then - return nil, errs - end + local function pop_typeargs(self, func) + if func.typeargs then + for _, fnarg in ipairs(func.typeargs) do + if self.st[#self.st].vars[fnarg.typearg] then + self.st[#self.st].vars[fnarg.typearg] = nil end end - - return true end + end + + + - check_args_rets = function(self, w, where_args, f, args, expected_rets, argdelta) - local rets_ok = true - local rets_errs - local args_ok - local args_errs - local fargs = f.args.tuple - local from = 1 - if argdelta == -1 then - from = 2 + + + local check_call + do + local check_args_rets + do + + local function check_func_type_list(self, w, wheres, xs, ys, from, delta, v, mode) local errs = {} - local first = fargs[1] - if (not (first.typename == "self")) and not self:arg_check(w, errs, first, args.tuple[1], "contravariant", "self") then - return nil, errs + local xt, yt = xs.tuple, ys.tuple + local n_xs = #xt + local n_ys = #yt + + for i = from, math.max(n_xs, n_ys) do + local pos = i + delta + local x = xt[i] or (xs.is_va and xt[n_xs]) or a_type(w, "nil", {}) + local y = yt[i] or (ys.is_va and yt[n_ys]) + if y then + local iw = wheres and wheres[pos] or w + if not self:arg_check(iw, errs, x, y, v, mode, pos) then + return nil, errs + end + end end - end - if expected_rets then - expected_rets = self:infer_at(w, expected_rets) - infer_emptytables(self, w, nil, expected_rets, f.rets, 0) - - rets_ok, rets_errs = check_func_type_list(self, w, nil, f.rets, expected_rets, 1, 0, "covariant", "return") + return true end - args_ok, args_errs = check_func_type_list(self, w, where_args, f.args, args, from, argdelta, "contravariant", "argument") - if (not args_ok) or (not rets_ok) then - return nil, args_errs or {} - end + check_args_rets = function(self, w, wargs, f, args, expected_rets, argdelta) + local rets_ok, rets_errs = true, nil + local args_ok, args_errs = true, nil + local from = 1 + if argdelta == -1 then + from = 2 + local errs = {} + local first = f.args.tuple[1] + if (not (first.typename == "self")) and not self:arg_check(w, errs, first, args.tuple[1], "contravariant", "self") then + return nil, errs + end + end + if expected_rets then + expected_rets = self:infer_at(w, expected_rets) + infer_emptytables(self, w, nil, expected_rets, f.rets, 0) + rets_ok, rets_errs = check_func_type_list(self, w, nil, f.rets, expected_rets, 1, 0, "covariant", "return") + end - infer_emptytables(self, w, where_args, args, f.args, argdelta) + args_ok, args_errs = check_func_type_list(self, w, wargs, f.args, args, from, argdelta, "contravariant", "argument") + if (not args_ok) or (not rets_ok) then + return nil, args_errs or {} + end - return f.rets - end - end + infer_emptytables(self, w, wargs, args, f.args, argdelta) - local function push_typeargs(self, func) - if func.typeargs then - for _, fnarg in ipairs(func.typeargs) do - self:add_var(nil, fnarg.typearg, a_type(fnarg, "unresolved_typearg", { - constraint = fnarg.constraint, - })) + return true end end - end - local function pop_typeargs(self, func) - if func.typeargs then - for _, fnarg in ipairs(func.typeargs) do - if self.st[#self.st].vars[fnarg.typearg] then - self.st[#self.st].vars[fnarg.typearg] = nil - end + local function is_method_mismatch(self, w, arg1, farg1, cm) + if cm == "method" or not farg1 then + return false + end + if not (arg1 and self:is_a(arg1, farg1)) then + self.errs:add(w, "invoked method as a regular function: use ':' instead of '.'") + return true + end + if cm == "plain" then + self.errs:add_warning("hint", w, "invoked method as a regular function: consider using ':' instead of '.'") end + return false end - end - local function resolve_function_type(func, i) - if func.typename == "poly" then - return func.types[i] - else - return func + check_call = function(self, w, wargs, f, args, expected_rets, cm, argdelta) + local arg1 = args.tuple[1] + if cm == "method" and arg1 then + self:add_var(nil, "@self", a_type(w, "typedecl", { def = arg1 })) + end + + local fargs = f.args.tuple + if f.is_method and is_method_mismatch(self, w, arg1, fargs[1], cm) then + return false + end + + local given = #args.tuple + local wanted = #fargs + local min_arity = self.feat_arity and f.min_arity or 0 + + if given < min_arity or (given > wanted and not f.args.is_va) then + return nil, { Err_at(w, "wrong number of arguments (given " .. given .. ", expects " .. show_arity(f) .. ")") } + end + + push_typeargs(self, f) + + return check_args_rets(self, w, wargs, f, args, expected_rets, argdelta) end end - local function fail_call(self, w, func, nargs, errs) - if errs then - self.errs:collect(errs) - else - + local check_poly_call + do + local function fail_poly_call_arity(w, p, given) local expects = {} - if func.typename == "poly" then - for _, f in ipairs(func.types) do - table.insert(expects, show_arity(f)) - end - table.sort(expects) - for i = #expects, 1, -1 do - if expects[i] == expects[i + 1] then - table.remove(expects, i) - end + for _, f in ipairs(p.types) do + table.insert(expects, show_arity(f)) + end + table.sort(expects) + for i = #expects, 1, -1 do + if expects[i] == expects[i + 1] then + table.remove(expects, i) end - else - table.insert(expects, show_arity(func)) end - self.errs:add(w, "wrong number of arguments (given " .. nargs .. ", expects " .. table.concat(expects, " or ") .. ")") + return { Err_at(w, "wrong number of arguments (given " .. given .. ", expects " .. table.concat(expects, " or ") .. ")") } end - local f = resolve_function_type(func, 1) + check_poly_call = function(self, w, wargs, p, args, expected_rets, cm, argdelta) + local given = #args.tuple - return f.rets - end + local tried = {} + local first_errs - local function check_call(self, w, where_args, func, args, expected_rets, is_typedecl_funcall, argdelta, is_method) - assert(type(func) == "table") - assert(type(args) == "table") - - if is_method and args.tuple[1] then - self:add_var(nil, "@self", a_type(w, "typedecl", { def = args.tuple[1] })) - end - - local passes, n = 1, 1 - if func.typename == "poly" then - passes, n = 3, #func.types - end - - local given = #args.tuple - local tried - local first_errs - for pass = 1, passes do - for i = 1, n do - if (not tried) or not tried[i] then - local f = resolve_function_type(func, i) - local fargs = f.args.tuple - if f.is_method and not is_method then - if args.tuple[1] and self:is_a(args.tuple[1], fargs[1]) then - - if not is_typedecl_funcall then - self.errs:add_warning("hint", w, "invoked method as a regular function: consider using ':' instead of '.'") - end - else - return self.errs:invalid_at(w, "invoked method as a regular function: use ':' instead of '.'") - end - end - local wanted = #fargs + for pass = 1, 3 do + for i, f in ipairs(p.types) do + local wanted = #f.args.tuple local min_arity = self.feat_arity and f.min_arity or 0 + if (not tried[i]) and - if (passes == 1 and ((given <= wanted and given >= min_arity) or (f.args.is_va and given > wanted))) or + ((pass == 1 and given == wanted) or - (passes == 3 and ((pass == 1 and given == wanted) or + (pass == 2 and (given < wanted and given >= min_arity)) or - (pass == 2 and given < wanted and given >= min_arity) or + (pass == 3 and (f.args.is_va and given > wanted))) then - (pass == 3 and f.args.is_va and given > wanted))) then + local ok, errs = check_call(self, w, wargs, f, args, expected_rets, cm, argdelta) + if ok then + return f + elseif expected_rets then - push_typeargs(self, f) + infer_emptytables(self, w, wargs, f.rets, f.rets, argdelta) + end - local matched, errs = check_args_rets(self, w, where_args, f, args, expected_rets, argdelta) - if matched then + pop_typeargs(self, f) - return matched, f - end first_errs = first_errs or errs + tried[i] = true + end + end + end - if expected_rets then + if not first_errs then + return nil, fail_poly_call_arity(w, p, given) + end - infer_emptytables(self, w, where_args, f.rets, f.rets, argdelta) - end + return nil, first_errs + end + end - if passes == 3 then - tried = tried or {} - tried[i] = true - pop_typeargs(self, f) - end - end + local function should_warn_dot(node, e1, is_method) + if is_method then + return "method" + end + if node.kind == "op" and node.op.op == "@funcall" and e1 and e1.receiver then + local receiver = e1.receiver + if receiver.typename == "nominal" then + local resolved = receiver.resolved + if resolved and resolved.typename == "typedecl" then + return "type_dot" end end end - - return fail_call(self, w, func, given, first_errs) + return "plain" end function TypeChecker:type_check_function_call(node, func, args, argdelta, e1, e2) @@ -9418,19 +9439,6 @@ a.types[i], b.types[i]), } expected_rets = a_type(node, "tuple", { tuple = { node.expected } }) end - self:begin_scope() - - local is_typedecl_funcall - if node.kind == "op" and node.op.op == "@funcall" and e1 and e1.receiver then - local receiver = e1.receiver - if receiver.typename == "nominal" then - local resolved = receiver.resolved - if resolved and resolved.typename == "typedecl" then - is_typedecl_funcall = true - end - end - end - local is_method = (argdelta == -1) if not (func.typename == "function" or func.typename == "poly") then @@ -9440,13 +9448,31 @@ a.types[i], b.types[i]), } end end - local ret, f - if func.typename == "function" or func.typename == "poly" then - ret, f = check_call(self, node, e2, func, args, expected_rets, is_typedecl_funcall, argdelta or 0, is_method) + local cm = should_warn_dot(node, e1, is_method) + + local ok, errs + local f, ret + + self:begin_scope() + + if func.typename == "poly" then + f, errs = check_poly_call(self, node, e2, func, args, expected_rets, cm, argdelta) + if f then + ret = f.rets + else + ret = func.types[1].rets + end + elseif func.typename == "function" then + ok, errs = check_call(self, node, e2, func, args, expected_rets, cm, argdelta) + f, ret = func, func.rets else ret = self.errs:invalid_at(node, "not a function: %s", func) end + if errs then + self.errs:collect(errs) + end + if f then mark_invalid_typeargs(self, f) end diff --git a/tl.tl b/tl.tl index e1857f4c0..2ddb7739f 100644 --- a/tl.tl +++ b/tl.tl @@ -1880,6 +1880,7 @@ local record FunctionType where self.typename == "function" is_method: boolean + maybe_method: boolean is_record_function: boolean min_arity: integer args: TupleType @@ -2725,7 +2726,7 @@ local function parse_function_type(ps: ParseState, i: integer): integer, Functio i, typ.typeargs = parse_typeargs_if_any(ps, i) if ps.tokens[i].tk == "(" then - i, typ.args, typ.is_method, typ.min_arity = parse_argument_type_list(ps, i) + i, typ.args, typ.maybe_method, typ.min_arity = parse_argument_type_list(ps, i) i, typ.rets = parse_return_types(ps, i) else typ.args = new_tuple(ps, i, { new_type(ps, i, "any") }, true) @@ -3985,6 +3986,9 @@ parse_record_body = function(ps: ParseState, i: integer, def: RecordLikeType): i if not t then return fail(ps, i, "expected a type") end + if t is FunctionType and t.maybe_method then + t.is_method = true + end local field_name = v.conststr or v.tk local fields = def.fields @@ -6155,6 +6159,15 @@ local function Err(msg: string, t1?: Type, t2?: Type, t3?: Type): Error } end +local function Err_at(w: Where, msg: string): Error + return { + msg = msg, + x = assert(w.x), + y = assert(w.y), + filename = assert(w.f), + } +end + local function insert_error(self: Errors, y: integer, x: integer, err: Error) err.y = assert(y) err.x = assert(x) @@ -9216,71 +9229,6 @@ do end end - local check_args_rets: function(TypeChecker, w: Where, where_args: {Node}, f: FunctionType, args: TupleType, expected_rets: TupleType, argdelta: integer): TupleType, {Error} - do - -- check if a tuple `xs` matches tuple `ys` - local function check_func_type_list(self: TypeChecker, w: Where, wheres: {Where}, xs: TupleType, ys: TupleType, from: integer, delta: integer, v: VarianceMode, mode: ArgCheckMode): boolean, {Error} - assert(xs.typename == "tuple", xs.typename) - assert(ys.typename == "tuple", ys.typename) - - local errs = {} - local xt, yt = xs.tuple, ys.tuple - local n_xs = #xt - local n_ys = #yt - - for i = from, math.max(n_xs, n_ys) do - local pos = i + delta - local x = xt[i] or (xs.is_va and xt[n_xs]) or a_type(w, "nil", {}) - local y = yt[i] or (ys.is_va and yt[n_ys]) - if y then - local iw = wheres and wheres[pos] or w - if not self:arg_check(iw, errs, x, y, v, mode, pos) then - return nil, errs - end - end - end - - return true - end - - check_args_rets = function(self: TypeChecker, w: Where, where_args: {Node}, f: FunctionType, args: TupleType, expected_rets: TupleType, argdelta: integer): TupleType, {Error} - local rets_ok = true - local rets_errs: {Error} - local args_ok: boolean - local args_errs: {Error} - local fargs = f.args.tuple - - local from = 1 - if argdelta == -1 then - from = 2 - local errs = {} - local first = fargs[1] - if (not first is SelfType) and not self:arg_check(w, errs, first, args.tuple[1], "contravariant", "self") then - return nil, errs - end - end - - if expected_rets then - expected_rets = self:infer_at(w, expected_rets) - infer_emptytables(self, w, nil, expected_rets, f.rets, 0) - - rets_ok, rets_errs = check_func_type_list(self, w, nil, f.rets, expected_rets, 1, 0, "covariant", "return") - end - - args_ok, args_errs = check_func_type_list(self, w, where_args, f.args, args, from, argdelta, "contravariant", "argument") - if (not args_ok) or (not rets_ok) then - return nil, args_errs or {} - end - - -- if we got to this point without returning, - -- we got a valid function match - - infer_emptytables(self, w, where_args, args, f.args, argdelta) - - return f.rets - end - end - local function push_typeargs(self: TypeChecker, func: FunctionType) if func.typeargs then for _, fnarg in ipairs(func.typeargs) do @@ -9301,112 +9249,185 @@ do end end - local function resolve_function_type(func: FunctionType | PolyType, i: integer): FunctionType - if func is PolyType then - return func.types[i] - else - return func - end + local enum CallMode + "method" -- a method colon-call, e.g. `my_object:my_method()` + "plain" -- a plain call or a dot-call, e.g `my_func()` or `my_object.my_func()` + "type_dot" -- a dot-call where the receiver is a type, e.g. `MyRecord.my_func()` end - local function fail_call(self: TypeChecker, w: Where, func: FunctionType | PolyType, nargs: integer, errs: {Error}): TupleType - if errs then - self.errs:collect(errs) - else - -- found no arity match to try - local expects: {string} = {} - if func is PolyType then - for _, f in ipairs(func.types) do - table.insert(expects, show_arity(f)) - end - table.sort(expects) - for i = #expects, 1, -1 do - if expects[i] == expects[i+1] then - table.remove(expects, i) + local check_call: function(self: TypeChecker, w: Where, wargs: {Where}, f: FunctionType, args: TupleType, expected_rets: TupleType, cm: CallMode, argdelta: integer): boolean, {Error} + do + local check_args_rets: function(TypeChecker, w: Where, wargs: {Where}, f: FunctionType, args: TupleType, expected_rets: TupleType, argdelta: integer): boolean, {Error} + do + -- check if a tuple `xs` matches tuple `ys` + local function check_func_type_list(self: TypeChecker, w: Where, wheres: {Where}, xs: TupleType, ys: TupleType, from: integer, delta: integer, v: VarianceMode, mode: ArgCheckMode): boolean, {Error} + local errs = {} + local xt, yt = xs.tuple, ys.tuple + local n_xs = #xt + local n_ys = #yt + + for i = from, math.max(n_xs, n_ys) do + local pos = i + delta + local x = xt[i] or (xs.is_va and xt[n_xs]) or a_type(w, "nil", {}) + local y = yt[i] or (ys.is_va and yt[n_ys]) + if y then + local iw = wheres and wheres[pos] or w + if not self:arg_check(iw, errs, x, y, v, mode, pos) then + return nil, errs + end end end - else - table.insert(expects, show_arity(func)) + + return true + end + + check_args_rets = function(self: TypeChecker, w: Where, wargs: {Where}, f: FunctionType, args: TupleType, expected_rets: TupleType, argdelta: integer): boolean, {Error} + local rets_ok, rets_errs: boolean, {Error} = true, nil + local args_ok, args_errs: boolean, {Error} = true, nil + + local from = 1 + if argdelta == -1 then + from = 2 + local errs = {} + local first = f.args.tuple[1] + if (not first is SelfType) and not self:arg_check(w, errs, first, args.tuple[1], "contravariant", "self") then + return nil, errs + end + end + + if expected_rets then + expected_rets = self:infer_at(w, expected_rets) + infer_emptytables(self, w, nil, expected_rets, f.rets, 0) + + rets_ok, rets_errs = check_func_type_list(self, w, nil, f.rets, expected_rets, 1, 0, "covariant", "return") + end + + args_ok, args_errs = check_func_type_list(self, w, wargs, f.args, args, from, argdelta, "contravariant", "argument") + if (not args_ok) or (not rets_ok) then + return nil, args_errs or {} + end + + infer_emptytables(self, w, wargs, args, f.args, argdelta) + + return true end - self.errs:add(w, "wrong number of arguments (given " .. nargs .. ", expects " .. table.concat(expects, " or ") .. ")") end - local f = resolve_function_type(func, 1) + local function is_method_mismatch(self: TypeChecker, w: Where, arg1: Type, farg1: Type, cm: CallMode): boolean + if cm == "method" or not farg1 then + return false + end + if not (arg1 and self:is_a(arg1, farg1)) then + self.errs:add(w, "invoked method as a regular function: use ':' instead of '.'") + return true + end + if cm == "plain" then + self.errs:add_warning("hint", w, "invoked method as a regular function: consider using ':' instead of '.'") + end + return false + end - return f.rets - end + check_call = function(self: TypeChecker, w: Where, wargs: {Where}, f: FunctionType, args: TupleType, expected_rets: TupleType, cm: CallMode, argdelta: integer): boolean, {Error} + local arg1 = args.tuple[1] + if cm == "method" and arg1 then + self:add_var(nil, "@self", a_typedecl(w, arg1)) + end - local function check_call(self: TypeChecker, w: Where, where_args: {Node}, func: FunctionType | PolyType, args: TupleType, expected_rets: TupleType, is_typedecl_funcall: boolean, argdelta: integer, is_method: boolean): InvalidOrTupleType, FunctionType - assert(type(func) == "table") - assert(type(args) == "table") + local fargs = f.args.tuple + if f.is_method and is_method_mismatch(self, w, arg1, fargs[1], cm) then + return false + end + + local given = #args.tuple + local wanted = #fargs + local min_arity = self.feat_arity and f.min_arity or 0 + + if given < min_arity or (given > wanted and not f.args.is_va) then + return nil, { Err_at(w, "wrong number of arguments (given " .. given .. ", expects " .. show_arity(f) .. ")") } + end - if is_method and args.tuple[1] then - self:add_var(nil, "@self", a_typedecl(w, args.tuple[1])) + push_typeargs(self, f) + + return check_args_rets(self, w, wargs, f, args, expected_rets, argdelta) end + end - local passes, n = 1, 1 - if func is PolyType then - passes, n = 3, #func.types - end - - local given = #args.tuple - local tried: {integer:boolean} - local first_errs: {Error} - for pass = 1, passes do - for i = 1, n do - if (not tried) or not tried[i] then - local f = resolve_function_type(func, i) - local fargs = f.args.tuple - if f.is_method and not is_method then - if args.tuple[1] and self:is_a(args.tuple[1], fargs[1]) then - -- a non-"@funcall" means a synthesized call, e.g. from a metamethod - if not is_typedecl_funcall then - self.errs:add_warning("hint", w, "invoked method as a regular function: consider using ':' instead of '.'") - end - else - return self.errs:invalid_at(w, "invoked method as a regular function: use ':' instead of '.'") - end - end - local wanted = #fargs - local min_arity = self.feat_arity and f.min_arity or 0 + local check_poly_call: function(self: TypeChecker, w: Where, wargs: {Where}, p: PolyType, args: TupleType, expected_rets: TupleType, cm: CallMode, argdelta: integer): FunctionType, {Error} + do + local function fail_poly_call_arity(w: Where, p: PolyType, given: integer): {Error} + local expects: {string} = {} + for _, f in ipairs(p.types) do + table.insert(expects, show_arity(f)) + end + table.sort(expects) + for i = #expects, 1, -1 do + if expects[i] == expects[i+1] then + table.remove(expects, i) + end + end + return { Err_at(w, "wrong number of arguments (given " .. given .. ", expects " .. table.concat(expects, " or ") .. ")") } + end - -- simple functions: - if (passes == 1 and ((given <= wanted and given >= min_arity) or (f.args.is_va and given > wanted))) - -- poly, pass 1: try exact arity matches first - or (passes == 3 and ((pass == 1 and given == wanted) - -- poly, pass 2: then try adjusting with nils to missing arguments or using '...' - or (pass == 2 and given < wanted and given >= min_arity) - -- poly, pass 3: then finally try vararg functions - or (pass == 3 and f.args.is_va and given > wanted))) - then - push_typeargs(self, f) + check_poly_call = function(self: TypeChecker, w: Where, wargs: {Where}, p: PolyType, args: TupleType, expected_rets: TupleType, cm: CallMode, argdelta: integer): FunctionType, {Error} + local given = #args.tuple - local matched, errs = check_args_rets(self, w, where_args, f, args, expected_rets, argdelta) - if matched then - -- success! - return matched, f - end - first_errs = first_errs or errs + local tried: {integer:boolean} = {} + local first_errs: {Error} - if expected_rets then + for pass = 1, 3 do + for i, f in ipairs(p.types) do + local wanted = #f.args.tuple + local min_arity = self.feat_arity and f.min_arity or 0 + + if (not tried[i]) and + -- try exact arity matches first + ( (pass == 1 and given == wanted) + -- then try adjusting with nils to missing arguments or using '...' + or (pass == 2 and (given < wanted and given >= min_arity)) + -- then finally try vararg functions + or (pass == 3 and (f.args.is_va and given > wanted)) ) + then + local ok, errs = check_call(self, w, wargs, f, args, expected_rets, cm, argdelta) + if ok then + return f + elseif expected_rets then -- revert inferred returns - infer_emptytables(self, w, where_args, f.rets, f.rets, argdelta) + infer_emptytables(self, w, wargs, f.rets, f.rets, argdelta) end - if passes == 3 then - tried = tried or {} - tried[i] = true - pop_typeargs(self, f) - end + pop_typeargs(self, f) + + first_errs = first_errs or errs + tried[i] = true end end end + + if not first_errs then + return nil, fail_poly_call_arity(w, p, given) + end + + return nil, first_errs end + end - return fail_call(self, w, func, given, first_errs) + local function should_warn_dot(node: Node, e1: Node, is_method: boolean): CallMode + if is_method then + return "method" + end + if node.kind == "op" and node.op.op == "@funcall" and e1 and e1.receiver then + local receiver = e1.receiver + if receiver is NominalType then + local resolved = receiver.resolved + if resolved and resolved is TypeDeclType then + return "type_dot" + end + end + end + return "plain" end - function TypeChecker:type_check_function_call(node: Node, func: Type, args: TupleType, argdelta?: integer, e1?: Node, e2?: {Node}): InvalidOrTupleType, FunctionType + function TypeChecker:type_check_function_call(node: Node, func: Type, args: TupleType, argdelta: integer, e1?: Node, e2?: {Node}): InvalidOrTupleType, FunctionType e1 = e1 or node.e1 e2 = e2 or node.e2 @@ -9418,19 +9439,6 @@ do expected_rets = a_tuple(node, { node.expected }) end - self:begin_scope() - - local is_typedecl_funcall: boolean - if node.kind == "op" and node.op.op == "@funcall" and e1 and e1.receiver then - local receiver = e1.receiver - if receiver is NominalType then - local resolved = receiver.resolved - if resolved and resolved is TypeDeclType then - is_typedecl_funcall = true - end - end - end - local is_method = (argdelta == -1) if not (func is FunctionType or func is PolyType) then @@ -9440,13 +9448,31 @@ do end end - local ret, f: InvalidOrTupleType, FunctionType - if func is FunctionType or func is PolyType then - ret, f = check_call(self, node, e2, func, args, expected_rets, is_typedecl_funcall, argdelta or 0, is_method) + local cm = should_warn_dot(node, e1, is_method) + + local ok, errs: boolean, {Error} + local f, ret: FunctionType, InvalidOrTupleType + + self:begin_scope() + + if func is PolyType then + f, errs = check_poly_call(self, node, e2, func, args, expected_rets, cm, argdelta) + if f then + ret = f.rets + else + ret = func.types[1].rets + end + elseif func is FunctionType then + ok, errs = check_call(self, node, e2, func, args, expected_rets, cm, argdelta) + f, ret = func, func.rets else ret = self.errs:invalid_at(node, "not a function: %s", func) end + if errs then + self.errs:collect(errs) + end + if f then mark_invalid_typeargs(self, f) end From 033681e8688f6c19e826f5c529045bb507afacac Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Fri, 6 Sep 2024 02:24:34 -0300 Subject: [PATCH 198/224] infer_at: minor cleanup --- tl.lua | 10 +++------- tl.tl | 10 +++------- 2 files changed, 6 insertions(+), 14 deletions(-) diff --git a/tl.lua b/tl.lua index 450d65ea2..f62f62fc2 100644 --- a/tl.lua +++ b/tl.lua @@ -11677,12 +11677,8 @@ self:expand_type(node, values, elements) }) end end - local t - if force_array then - t = self:infer_at(node, a_type(node, "array", { elements = force_array })) - else - t = self:resolve_typevars_at(node, node.expected) - end + local t = force_array and a_type(node, "array", { elements = force_array }) or node.expected + t = self:infer_at(node, t) if decltype.typename == "record" then local rt = self:to_structural(t) @@ -12234,7 +12230,7 @@ self:expand_type(node, values, elements) }) local a_is = self:is_a(ua, expected) local b_is = self:is_a(ub, expected) if a_is and b_is then - t = self:resolve_typevars_at(node, expected) + t = self:infer_at(node, expected) end end if not t then diff --git a/tl.tl b/tl.tl index 2ddb7739f..1d73b3517 100644 --- a/tl.tl +++ b/tl.tl @@ -11677,12 +11677,8 @@ do end end - local t: Type - if force_array then - t = self:infer_at(node, an_array(node, force_array)) - else - t = self:resolve_typevars_at(node, node.expected) - end + local t = force_array and an_array(node, force_array) or node.expected + t = self:infer_at(node, t) if decltype is RecordType then local rt = self:to_structural(t) @@ -12234,7 +12230,7 @@ do local a_is = self:is_a(ua, expected) local b_is = self:is_a(ub, expected) if a_is and b_is then - t = self:resolve_typevars_at(node, expected) + t = self:infer_at(node, expected) end end if not t then From 373e06d754f428e97f340a4380c4afefc5c66af9 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Sun, 8 Sep 2024 16:21:37 -0300 Subject: [PATCH 199/224] find_type: narrow return type --- tl.lua | 36 +++++++++++++++++------------------- tl.tl | 36 +++++++++++++++++------------------- 2 files changed, 34 insertions(+), 38 deletions(-) diff --git a/tl.lua b/tl.lua index f62f62fc2..601d25c66 100644 --- a/tl.lua +++ b/tl.lua @@ -7261,7 +7261,7 @@ do return true end - function TypeChecker:find_type(names, accept_typearg) + function TypeChecker:find_type(names) local typ = self:find_var_type(names[1], "use_type") if not typ then if #names == 1 and names[1] == "metatable" then @@ -7294,8 +7294,8 @@ do end if typ.typename == "typedecl" then return typ - elseif accept_typearg and typ.typename == "typearg" then - return typ + elseif typ.typename == "typearg" then + return nil, typ end end @@ -12835,24 +12835,22 @@ self:expand_type(node, values, elements) }) return typ end - local t = self:find_type(typ.names, true) + local t, typearg = self:find_type(typ.names) if t then - if t.typename == "typearg" then - - typ.names = nil - edit_type(typ, typ, "typevar") - local tv = typ - tv.typevar = t.typearg - tv.constraint = t.constraint - elseif t.typename == "typedecl" then - local def = t.def - if t.is_alias then - assert(def.typename == "nominal") - typ.found = def.found - elseif def.typename ~= "circular_require" then - typ.found = t - end + local def = t.def + if t.is_alias then + assert(def.typename == "nominal") + typ.found = def.found + elseif def.typename ~= "circular_require" then + typ.found = t end + elseif typearg then + + typ.names = nil + edit_type(typ, typ, "typevar") + local tv = typ + tv.typevar = typearg.typearg + tv.constraint = typearg.constraint else local name = typ.names[1] local scope = self.st[#self.st] diff --git a/tl.tl b/tl.tl index 1d73b3517..4308d907f 100644 --- a/tl.tl +++ b/tl.tl @@ -7261,7 +7261,7 @@ do return true end - function TypeChecker:find_type(names: {string}, accept_typearg?: boolean): Type + function TypeChecker:find_type(names: {string}): TypeDeclType, TypeArgType local typ = self:find_var_type(names[1], "use_type") if not typ then if #names == 1 and names[1] == "metatable" then @@ -7294,8 +7294,8 @@ do end if typ is TypeDeclType then return typ - elseif accept_typearg and typ is TypeArgType then - return typ + elseif typ is TypeArgType then + return nil, typ end end @@ -12835,24 +12835,22 @@ do return typ end - local t = self:find_type(typ.names, true) + local t, typearg = self:find_type(typ.names) if t then - if t is TypeArgType then - -- convert nominal into a typevar - typ.names = nil - edit_type(typ, typ, "typevar") - local tv = typ as TypeVarType - tv.typevar = t.typearg - tv.constraint = t.constraint - elseif t is TypeDeclType then - local def = t.def - if t.is_alias then - assert(def is NominalType) - typ.found = def.found - elseif def.typename ~= "circular_require" then - typ.found = t - end + local def = t.def + if t.is_alias then + assert(def is NominalType) + typ.found = def.found + elseif def.typename ~= "circular_require" then + typ.found = t end + elseif typearg then + -- convert nominal into a typevar + typ.names = nil + edit_type(typ, typ, "typevar") + local tv = typ as TypeVarType + tv.typevar = typearg.typearg + tv.constraint = typearg.constraint else local name = typ.names[1] local scope = self.st[#self.st] From b78331353026baf19992bb6c1f69e718d8226364 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Mon, 9 Sep 2024 00:30:04 -0300 Subject: [PATCH 200/224] refactor: ensure_not_method --- tl.lua | 27 +++++++++++++++------------ tl.tl | 33 ++++++++++++++++++--------------- 2 files changed, 33 insertions(+), 27 deletions(-) diff --git a/tl.lua b/tl.lua index 601d25c66..a2466161d 100644 --- a/tl.lua +++ b/tl.lua @@ -7261,6 +7261,15 @@ do return true end + local function ensure_not_method(t) + + if t.typename == "function" and t.is_method then + t = shallow_copy_new_type(t) + t.is_method = false + end + return t + end + function TypeChecker:find_type(names) local typ = self:find_var_type(names[1], "use_type") if not typ then @@ -10470,14 +10479,12 @@ a.types[i], b.types[i]), } return a_type(node, "tuple", { tuple = { bool } }) end + local ftype = table.remove(b.tuple, 1) - local ftype = table.remove(b.tuple, 1) - if ftype.typename == "function" then - ftype = shallow_copy_new_type(ftype) - ftype.is_method = false - end + + ftype = ensure_not_method(ftype) local fe2 = node_at(node.e2, {}) if node.e1.tk == "xpcall" then @@ -10869,12 +10876,11 @@ self:expand_type(node, values, elements) }) if infertype.typename == "unresolvable_typearg" then ok = false infertype = self.errs:invalid_at(node.vars[i], "cannot infer declaration type; an explicit type annotation is necessary") - elseif infertype.typename == "function" and infertype.is_method then + else - infertype = shallow_copy_new_type(infertype) - infertype.is_method = false + infertype = ensure_not_method(infertype) end end end @@ -11709,13 +11715,10 @@ self:expand_type(node, values, elements) }) vtype = node.itemtype self:assert_is_a(node.value, children[2], node.itemtype, node) end - if vtype.typename == "function" and vtype.is_method then - vtype = shallow_copy_new_type(vtype) - vtype.is_method = false - end + vtype = ensure_not_method(vtype) return a_type(node, "literal_table_item", { kname = kname, ktype = ktype, diff --git a/tl.tl b/tl.tl index 4308d907f..49b46b651 100644 --- a/tl.tl +++ b/tl.tl @@ -7261,6 +7261,15 @@ do return true end + local function ensure_not_method(t: Type): Type + + if t is FunctionType and t.is_method then + t = shallow_copy_new_type(t) + t.is_method = false + end + return t + end + function TypeChecker:find_type(names: {string}): TypeDeclType, TypeArgType local typ = self:find_var_type(names[1], "use_type") if not typ then @@ -10470,14 +10479,12 @@ do return a_tuple(node, { bool }) end + local ftype = table.remove(b.tuple, 1) + -- The function called by pcall/xpcall is invoked as a regular function, -- so we wish to avoid incorrect error messages / unnecessary warning messages -- associated with calling methods as functions - local ftype = table.remove(b.tuple, 1) - if ftype is FunctionType then - ftype = shallow_copy_new_type(ftype) - ftype.is_method = false - end + ftype = ensure_not_method(ftype) local fe2: Node = node_at(node.e2, {}) if node.e1.tk == "xpcall" then @@ -10869,12 +10876,11 @@ do if infertype is UnresolvableTypeArgType then ok = false infertype = self.errs:invalid_at(node.vars[i], "cannot infer declaration type; an explicit type annotation is necessary") - elseif infertype is FunctionType and infertype.is_method then + else -- If we assign a method to a variable, e.g: -- `local myfunc = myobj.dothing`, -- the variable should not be treated as a method - infertype = shallow_copy_new_type(infertype) - infertype.is_method = false + infertype = ensure_not_method(infertype) end end end @@ -11709,13 +11715,10 @@ do vtype = node.itemtype self:assert_is_a(node.value, children[2], node.itemtype, node) end - if vtype is FunctionType and vtype.is_method then - -- If we assign a method to a table item, e.g. - -- `local a = { myfunc = myobj.dothing }` - -- the table item should not be treated as a method - vtype = shallow_copy_new_type(vtype) - vtype.is_method = false - end + -- If we assign a method to a table item, e.g. + -- `local a = { myfunc = myobj.dothing }` + -- the table item should not be treated as a method + vtype = ensure_not_method(vtype) return a_type(node, "literal_table_item", { kname = kname, ktype = ktype, From cb3369290237f94ce216e11bd4ae1a5ead154f10 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Mon, 9 Sep 2024 00:34:37 -0300 Subject: [PATCH 201/224] refactor: are_same_nominals --- tl.lua | 56 ++++++++++++++++++++++++++++++++++---------------------- tl.tl | 56 ++++++++++++++++++++++++++++++++++---------------------- 2 files changed, 68 insertions(+), 44 deletions(-) diff --git a/tl.lua b/tl.lua index a2466161d..3cae101c2 100644 --- a/tl.lua +++ b/tl.lua @@ -8085,35 +8085,45 @@ do return false, { Err(t1name .. " is not a " .. t2name) } end + local function nominal_found_type(self, nom) + local typedecl = nom.found + if not typedecl then + typedecl = self:find_type(nom.names) + if not typedecl then + return nil + end + end + local t = typedecl.def + + return t + end + function TypeChecker:are_same_nominals(t1, t2) - local same_names - if t1.found and t2.found then - same_names = t1.found.typeid == t2.found.typeid - else - local ft1 = t1.found or self:find_type(t1.names) - local ft2 = t2.found or self:find_type(t2.names) - if ft1 and ft2 then - same_names = ft1.typeid == ft2.typeid - else - if are_same_unresolved_global_type(self, t1, t2) then - return true - end + local t1f = nominal_found_type(self, t1) + local t2f = nominal_found_type(self, t2) + if (not t1f or not t2f) then + if are_same_unresolved_global_type(self, t1, t2) then + return true + end - if not ft1 then - self.errs:add(t1, "unknown type %s", t1) - end - if not ft2 then - self.errs:add(t2, "unknown type %s", t2) - end - return false, {} + if not t1f then + self.errs:add(t1, "unknown type %s", t1) + end + if not t2f then + self.errs:add(t2, "unknown type %s", t2) end + return false, {} end - if not same_names then + if t1f.typeid ~= t2f.typeid then return fail_nominals(self, t1, t2) - elseif t1.typevals == nil and t2.typevals == nil then + end + + if t1.typevals == nil and t2.typevals == nil then return true - elseif t1.typevals and t2.typevals and #t1.typevals == #t2.typevals then + end + + if t1.typevals and t2.typevals and #t1.typevals == #t2.typevals then local errs = {} for i = 1, #t1.typevals do local _, typeval_errs = self:same_type(t1.typevals[i], t2.typevals[i]) @@ -8121,6 +8131,8 @@ do end return any_errors(errs) end + + return true end end diff --git a/tl.tl b/tl.tl index 49b46b651..7dd411ac5 100644 --- a/tl.tl +++ b/tl.tl @@ -8085,35 +8085,45 @@ do return false, { Err(t1name .. " is not a " .. t2name) } end + local function nominal_found_type(self: TypeChecker, nom: NominalType): Type + local typedecl = nom.found + if not typedecl then + typedecl = self:find_type(nom.names) + if not typedecl then + return nil + end + end + local t = typedecl.def + + return t + end + function TypeChecker:are_same_nominals(t1: NominalType, t2: NominalType): boolean, {Error} - local same_names: boolean - if t1.found and t2.found then - same_names = t1.found.typeid == t2.found.typeid - else - local ft1 = t1.found or self:find_type(t1.names) - local ft2 = t2.found or self:find_type(t2.names) - if ft1 and ft2 then - same_names = ft1.typeid == ft2.typeid - else - if are_same_unresolved_global_type(self, t1, t2) then - return true - end + local t1f = nominal_found_type(self, t1) + local t2f = nominal_found_type(self, t2) + if (not t1f or not t2f) then + if are_same_unresolved_global_type(self, t1, t2) then + return true + end - if not ft1 then - self.errs:add(t1, "unknown type %s", t1) - end - if not ft2 then - self.errs:add(t2, "unknown type %s", t2) - end - return false, {} -- errors were already produced + if not t1f then + self.errs:add(t1, "unknown type %s", t1) + end + if not t2f then + self.errs:add(t2, "unknown type %s", t2) end + return false, {} -- errors were already produced end - if not same_names then + if t1f.typeid ~= t2f.typeid then return fail_nominals(self, t1, t2) - elseif t1.typevals == nil and t2.typevals == nil then + end + + if t1.typevals == nil and t2.typevals == nil then return true - elseif t1.typevals and t2.typevals and #t1.typevals == #t2.typevals then + end + + if t1.typevals and t2.typevals and #t1.typevals == #t2.typevals then local errs = {} for i = 1, #t1.typevals do local _, typeval_errs = self:same_type(t1.typevals[i], t2.typevals[i]) @@ -8121,6 +8131,8 @@ do end return any_errors(errs) end + + -- FIXME what if presence and arities of typevals don't match?... return true end end From 7a117b98f6831aa7245666e0481cb1d79c509a83 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Mon, 9 Sep 2024 17:58:25 -0300 Subject: [PATCH 202/224] check_poly_call: simplify use --- tl.lua | 16 +++++++--------- tl.tl | 20 +++++++++----------- 2 files changed, 16 insertions(+), 20 deletions(-) diff --git a/tl.lua b/tl.lua index 3cae101c2..61b37df57 100644 --- a/tl.lua +++ b/tl.lua @@ -9393,10 +9393,13 @@ a.types[i], b.types[i]), } local given = #args.tuple local tried = {} + local first_rets local first_errs for pass = 1, 3 do for i, f in ipairs(p.types) do + first_rets = first_rets or f.rets + local wanted = #f.args.tuple local min_arity = self.feat_arity and f.min_arity or 0 @@ -9410,7 +9413,7 @@ a.types[i], b.types[i]), } local ok, errs = check_call(self, w, wargs, f, args, expected_rets, cm, argdelta) if ok then - return f + return f, f.rets elseif expected_rets then infer_emptytables(self, w, wargs, f.rets, f.rets, argdelta) @@ -9425,10 +9428,10 @@ a.types[i], b.types[i]), } end if not first_errs then - return nil, fail_poly_call_arity(w, p, given) + return nil, first_rets, fail_poly_call_arity(w, p, given) end - return nil, first_errs + return nil, first_rets, first_errs end end @@ -9477,12 +9480,7 @@ a.types[i], b.types[i]), } self:begin_scope() if func.typename == "poly" then - f, errs = check_poly_call(self, node, e2, func, args, expected_rets, cm, argdelta) - if f then - ret = f.rets - else - ret = func.types[1].rets - end + f, ret, errs = check_poly_call(self, node, e2, func, args, expected_rets, cm, argdelta) elseif func.typename == "function" then ok, errs = check_call(self, node, e2, func, args, expected_rets, cm, argdelta) f, ret = func, func.rets diff --git a/tl.tl b/tl.tl index 7dd411ac5..5621b33ea 100644 --- a/tl.tl +++ b/tl.tl @@ -9373,7 +9373,7 @@ do end end - local check_poly_call: function(self: TypeChecker, w: Where, wargs: {Where}, p: PolyType, args: TupleType, expected_rets: TupleType, cm: CallMode, argdelta: integer): FunctionType, {Error} + local check_poly_call: function(self: TypeChecker, w: Where, wargs: {Where}, p: PolyType, args: TupleType, expected_rets: TupleType, cm: CallMode, argdelta: integer): FunctionType, TupleType, {Error} do local function fail_poly_call_arity(w: Where, p: PolyType, given: integer): {Error} local expects: {string} = {} @@ -9389,14 +9389,17 @@ do return { Err_at(w, "wrong number of arguments (given " .. given .. ", expects " .. table.concat(expects, " or ") .. ")") } end - check_poly_call = function(self: TypeChecker, w: Where, wargs: {Where}, p: PolyType, args: TupleType, expected_rets: TupleType, cm: CallMode, argdelta: integer): FunctionType, {Error} + check_poly_call = function(self: TypeChecker, w: Where, wargs: {Where}, p: PolyType, args: TupleType, expected_rets: TupleType, cm: CallMode, argdelta: integer): FunctionType, TupleType, {Error} local given = #args.tuple local tried: {integer:boolean} = {} + local first_rets: TupleType local first_errs: {Error} for pass = 1, 3 do for i, f in ipairs(p.types) do + first_rets = first_rets or f.rets + local wanted = #f.args.tuple local min_arity = self.feat_arity and f.min_arity or 0 @@ -9410,7 +9413,7 @@ do then local ok, errs = check_call(self, w, wargs, f, args, expected_rets, cm, argdelta) if ok then - return f + return f, f.rets elseif expected_rets then -- revert inferred returns infer_emptytables(self, w, wargs, f.rets, f.rets, argdelta) @@ -9425,10 +9428,10 @@ do end if not first_errs then - return nil, fail_poly_call_arity(w, p, given) + return nil, first_rets, fail_poly_call_arity(w, p, given) end - return nil, first_errs + return nil, first_rets, first_errs end end @@ -9477,12 +9480,7 @@ do self:begin_scope() if func is PolyType then - f, errs = check_poly_call(self, node, e2, func, args, expected_rets, cm, argdelta) - if f then - ret = f.rets - else - ret = func.types[1].rets - end + f, ret, errs = check_poly_call(self, node, e2, func, args, expected_rets, cm, argdelta) elseif func is FunctionType then ok, errs = check_call(self, node, e2, func, args, expected_rets, cm, argdelta) f, ret = func, func.rets From f157d8450556d50f7fdc63040a2848efa0856b0d Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Mon, 9 Sep 2024 22:47:04 -0300 Subject: [PATCH 203/224] refactor: move {} inference out of assert_is_a, into is_a --- spec/inference/emptytable_spec.lua | 2 +- tl.lua | 101 +++++++++++++++++------------ tl.tl | 101 +++++++++++++++++------------ 3 files changed, 123 insertions(+), 81 deletions(-) diff --git a/spec/inference/emptytable_spec.lua b/spec/inference/emptytable_spec.lua index a462cd396..2781a9d08 100644 --- a/spec/inference/emptytable_spec.lua +++ b/spec/inference/emptytable_spec.lua @@ -19,7 +19,7 @@ describe("empty table without type annotation", function() t.foo = "bar" ]], { - { msg = [[cannot index key 'foo' in array 't' of type {integer} (inferred at foo.tl:3:11)]] }, + { msg = [[cannot index key 'foo' in array 't' of type {integer} (inferred at foo.tl:3:10)]] }, })) it("first use can be a function call", util.check([[ diff --git a/tl.lua b/tl.lua index 61b37df57..a51134444 100644 --- a/tl.lua +++ b/tl.lua @@ -7525,6 +7525,9 @@ do elseif t.typename == "unresolvable_typearg" then assert(copy.typename == "unresolvable_typearg") copy.typearg = t.typearg + elseif t.typename == "unresolved_emptytable_value" then + assert(copy.typename == "unresolved_emptytable_value") + copy.emptytable_type = t.emptytable_type elseif t.typename == "typevar" then assert(copy.typename == "typevar") copy.typevar = t.typevar @@ -8443,6 +8446,18 @@ do return true end + local function compare_true_inferring_emptytable(self, a, b) + self:infer_emptytable(b, self:infer_at(b, a)) + return true + end + + local function compare_true_inferring_emptytable_if_not_userdata(self, a, b) + if a.is_userdata then + return false, { Err("{} cannot be used with userdata type %s", a) } + end + return compare_true_inferring_emptytable(self, a, b) + end + local emptytable_relations = { ["array"] = compare_true, @@ -8482,16 +8497,19 @@ do end return true end, + ["emptytable"] = compare_true_inferring_emptytable, }, ["array"] = { ["array"] = function(self, a, b) return self:same_type(a.elements, b.elements) end, + ["emptytable"] = compare_true_inferring_emptytable, }, ["map"] = { ["map"] = function(self, a, b) return compare_map(self, a.keys, b.keys, a.values, b.values, true) end, + ["emptytable"] = compare_true_inferring_emptytable, }, ["union"] = { ["union"] = function(self, a, b) @@ -8504,11 +8522,13 @@ do }, ["record"] = { ["record"] = TypeChecker.eqtype_record, + ["emptytable"] = compare_true_inferring_emptytable_if_not_userdata, }, ["interface"] = { ["interface"] = function(_self, a, b) return a.typeid == b.typeid end, + ["emptytable"] = compare_true_inferring_emptytable_if_not_userdata, }, ["function"] = { ["function"] = function(self, a, b) @@ -8557,6 +8577,9 @@ do } TypeChecker.subtype_relations = { + ["nil"] = { + ["*"] = compare_true, + }, ["tuple"] = { ["tuple"] = function(self, a, b) local at, bt = a.tuple, b.tuple @@ -8586,9 +8609,6 @@ do return self:compare_or_infer_typevar(a.typevar, nil, b, self.is_a) end, }, - ["nil"] = { - ["*"] = compare_true, - }, ["union"] = { ["nominal"] = function(self, a, b) @@ -8697,6 +8717,7 @@ do ["tupletable"] = function(self, a, b) return self.subtype_relations["record"]["tupletable"](self, a, b) end, + ["emptytable"] = compare_true_inferring_emptytable_if_not_userdata, }, ["emptytable"] = emptytable_relations, ["tupletable"] = { @@ -8739,6 +8760,7 @@ a.types[i], b.types[i]), } return compare_map(self, a_type(a, "integer", {}), b.keys, aa.elements, b.values) end, + ["emptytable"] = compare_true_inferring_emptytable, }, ["record"] = { ["record"] = TypeChecker.subtype_record, @@ -8774,6 +8796,7 @@ a.types[i], b.types[i]), } return self.subtype_relations["array"]["tupletable"](self, a, b) end end, + ["emptytable"] = compare_true_inferring_emptytable_if_not_userdata, }, ["array"] = { ["array"] = TypeChecker.subtype_array, @@ -8800,6 +8823,7 @@ a.types[i], b.types[i]), } end return true end, + ["emptytable"] = compare_true_inferring_emptytable, }, ["map"] = { ["map"] = function(self, a, b) @@ -8808,6 +8832,7 @@ a.types[i], b.types[i]), } ["array"] = function(self, a, b) return compare_map(self, a.keys, a_type(b, "integer", {}), a.values, b.elements) end, + ["emptytable"] = compare_true_inferring_emptytable, }, ["typedecl"] = { ["record"] = function(self, a, b) @@ -8875,6 +8900,19 @@ a.types[i], b.types[i]), } ["*"] = { ["any"] = compare_true, ["boolean_context"] = compare_true, + ["emptytable"] = function(_self, a, _b) + return false, { Err("assigning %s to a variable declared with {}", a) } + end, + ["unresolved_emptytable_value"] = function(self, a, b) + local bt = b.emptytable_type + assert(bt.typename == "emptytable", b.typename) + local bkeys = bt.keys + local infer_to = is_numeric_type(bkeys) and + a_type(b, "array", { elements = a }) or + a_type(b, "map", { keys = bkeys, values = a }) + self:infer_emptytable(bt, self:infer_at(b, infer_to)) + return true + end, ["self"] = function(self, a, b) return self:is_a(a, self:type_of_self(b)) end, @@ -8907,29 +8945,29 @@ a.types[i], b.types[i]), } TypeChecker.type_priorities = { - ["self"] = 1, - ["tuple"] = 2, - ["typevar"] = 3, - ["nil"] = 4, - ["any"] = 5, - ["boolean_context"] = 5, - ["union"] = 6, - ["poly"] = 7, - - ["typearg"] = 8, + ["nil"] = 0, + ["unresolved_emptytable_value"] = 1, + ["emptytable"] = 2, + ["self"] = 3, + ["tuple"] = 4, + ["typevar"] = 5, + ["any"] = 6, + ["boolean_context"] = 7, + ["union"] = 8, + ["poly"] = 9, - ["nominal"] = 9, + ["typearg"] = 10, - ["enum"] = 10, - ["string"] = 10, - ["integer"] = 10, - ["boolean"] = 10, + ["nominal"] = 11, - ["interface"] = 11, + ["enum"] = 12, + ["string"] = 12, + ["integer"] = 12, + ["boolean"] = 12, - ["emptytable"] = 12, - ["tupletable"] = 13, + ["interface"] = 13, + ["tupletable"] = 14, ["record"] = 14, ["array"] = 14, ["map"] = 14, @@ -9000,25 +9038,8 @@ a.types[i], b.types[i]), } return true end - - if t1.typename == "nil" then - return true - elseif t2.typename == "unresolved_emptytable_value" then - local t2keys = t2.emptytable_type.keys - if is_numeric_type(t2keys) then - self:infer_emptytable(t2.emptytable_type, self:infer_at(w, a_type(w, "array", { elements = t1 }))) - else - self:infer_emptytable(t2.emptytable_type, self:infer_at(w, a_type(w, "map", { keys = t2keys, values = t1 }))) - end - return true - elseif t2.typename == "emptytable" then - if is_lua_table_type(t1) then - self:infer_emptytable(t2, self:infer_at(w, t1)) - elseif not (t1.typename == "emptytable") then - self.errs:add(w, self.errs:get_context(ctx, name) .. "assigning %s to a variable declared with {}", t1) - return false - end - return true + if t2.typename == "emptytable" then + t2 = type_at(w, t2) end local ok, match_errs = self:is_a(t1, t2) diff --git a/tl.tl b/tl.tl index 5621b33ea..cf7e331fe 100644 --- a/tl.tl +++ b/tl.tl @@ -7525,6 +7525,9 @@ do elseif t is UnresolvableTypeArgType then assert(copy is UnresolvableTypeArgType) copy.typearg = t.typearg + elseif t is UnresolvedEmptyTableValueType then + assert(copy is UnresolvedEmptyTableValueType) + copy.emptytable_type = t.emptytable_type elseif t is TypeVarType then assert(copy is TypeVarType) copy.typevar = t.typevar @@ -8443,6 +8446,18 @@ do return true end + local function compare_true_inferring_emptytable(self: TypeChecker, a: Type, b: EmptyTableType): boolean, {Error} + self:infer_emptytable(b, self:infer_at(b, a)) + return true + end + + local function compare_true_inferring_emptytable_if_not_userdata(self: TypeChecker, a: RecordLikeType, b: EmptyTableType): boolean, {Error} + if a.is_userdata then + return false, { Err("{} cannot be used with userdata type %s", a) } + end + return compare_true_inferring_emptytable(self, a, b) + end + -- emptytable rules are the same in eqtype_relations and subtype_relations local emptytable_relations: {TypeName:CompareTypes} = { ["array"] = compare_true, @@ -8482,16 +8497,19 @@ do end return true end, + ["emptytable"] = compare_true_inferring_emptytable, }, ["array"] = { ["array"] = function(self: TypeChecker, a: ArrayType, b: ArrayType): boolean, {Error} return self:same_type(a.elements, b.elements) end, + ["emptytable"] = compare_true_inferring_emptytable, }, ["map"] = { ["map"] = function(self: TypeChecker, a: MapType, b: MapType): boolean, {Error} return compare_map(self, a.keys, b.keys, a.values, b.values, true) end, + ["emptytable"] = compare_true_inferring_emptytable, }, ["union"] = { ["union"] = function(self: TypeChecker, a: UnionType, b: UnionType): boolean, {Error} @@ -8504,11 +8522,13 @@ do }, ["record"] = { ["record"] = TypeChecker.eqtype_record, + ["emptytable"] = compare_true_inferring_emptytable_if_not_userdata, }, ["interface"] = { ["interface"] = function(_self:TypeChecker, a: InterfaceType, b: InterfaceType): boolean, {Error} return a.typeid == b.typeid end, + ["emptytable"] = compare_true_inferring_emptytable_if_not_userdata, }, ["function"] = { ["function"] = function(self:TypeChecker, a: FunctionType, b: FunctionType): boolean, {Error} @@ -8557,6 +8577,9 @@ do } TypeChecker.subtype_relations = { + ["nil"] = { + ["*"] = compare_true, + }, ["tuple"] = { ["tuple"] = function(self: TypeChecker, a: TupleType, b: TupleType): boolean, {Error} -- ∀ a[i] ∈ a, b[i] ∈ b. a[i] <: b[i] local at, bt = a.tuple, b.tuple -- ────────────────────────────────── @@ -8586,9 +8609,6 @@ do return self:compare_or_infer_typevar(a.typevar, nil, b, self.is_a) end, }, - ["nil"] = { - ["*"] = compare_true, - }, ["union"] = { ["nominal"] = function(self: TypeChecker, a: UnionType, b: NominalType): boolean, {Error} -- match unions structurally @@ -8697,6 +8717,7 @@ do ["tupletable"] = function(self: TypeChecker, a: Type, b: Type): boolean, {Error} return self.subtype_relations["record"]["tupletable"](self, a, b) end, + ["emptytable"] = compare_true_inferring_emptytable_if_not_userdata, }, ["emptytable"] = emptytable_relations, ["tupletable"] = { @@ -8739,6 +8760,7 @@ do return compare_map(self, a_type(a, "integer", {}), b.keys, aa.elements, b.values) end, + ["emptytable"] = compare_true_inferring_emptytable, }, ["record"] = { ["record"] = TypeChecker.subtype_record, @@ -8774,6 +8796,7 @@ do return self.subtype_relations["array"]["tupletable"](self, a, b) end end, + ["emptytable"] = compare_true_inferring_emptytable_if_not_userdata, }, ["array"] = { ["array"] = TypeChecker.subtype_array, @@ -8800,6 +8823,7 @@ do end return true end, + ["emptytable"] = compare_true_inferring_emptytable, }, ["map"] = { ["map"] = function(self: TypeChecker, a: MapType, b: MapType): boolean, {Error} @@ -8808,6 +8832,7 @@ do ["array"] = function(self: TypeChecker, a: MapType, b: ArrayType): boolean, {Error} return compare_map(self, a.keys, a_type(b, "integer", {}), a.values, b.elements) end, + ["emptytable"] = compare_true_inferring_emptytable, }, ["typedecl"] = { ["record"] = function(self: TypeChecker, a: TypeDeclType, b: RecordType): boolean, {Error} @@ -8875,6 +8900,19 @@ do ["*"] = { ["any"] = compare_true, ["boolean_context"] = compare_true, + ["emptytable"] = function(_self: TypeChecker, a: Type, _b: EmptyTableType): boolean, {Error} + return false, { Err("assigning %s to a variable declared with {}", a) } + end, + ["unresolved_emptytable_value"] = function(self: TypeChecker, a: Type, b: UnresolvedEmptyTableValueType): boolean, {Error} + local bt = b.emptytable_type + assert(bt is EmptyTableType, b.typename) + local bkeys = bt.keys + local infer_to = bkeys is NumericType -- ideally integer only + and an_array(b, a) + or a_map(b, bkeys, a) + self:infer_emptytable(bt, self:infer_at(b, infer_to)) + return true + end, ["self"] = function(self: TypeChecker, a: Type, b: SelfType): boolean, {Error} return self:is_a(a, self:type_of_self(b)) end, @@ -8907,29 +8945,29 @@ do -- evaluation strategy TypeChecker.type_priorities = { -- types that have catch-all rules evaluate first - ["self"] = 1, - ["tuple"] = 2, - ["typevar"] = 3, - ["nil"] = 4, - ["any"] = 5, - ["boolean_context"] = 5, - ["union"] = 6, - ["poly"] = 7, + ["nil"] = 0, + ["unresolved_emptytable_value"] = 1, + ["emptytable"] = 2, + ["self"] = 3, + ["tuple"] = 4, + ["typevar"] = 5, + ["any"] = 6, + ["boolean_context"] = 7, + ["union"] = 8, + ["poly"] = 9, -- then typeargs - ["typearg"] = 8, + ["typearg"] = 10, -- then nominals - ["nominal"] = 9, + ["nominal"] = 11, -- then base types - ["enum"] = 10, - ["string"] = 10, - ["integer"] = 10, - ["boolean"] = 10, + ["enum"] = 12, + ["string"] = 12, + ["integer"] = 12, + ["boolean"] = 12, -- then interfaces - ["interface"] = 11, + ["interface"] = 13, -- then special cases of tables - ["emptytable"] = 12, - ["tupletable"] = 13, - -- then other recursive types + ["tupletable"] = 14, ["record"] = 14, ["array"] = 14, ["map"] = 14, @@ -9000,25 +9038,8 @@ do return true end - -- some flow-based inference - if t1.typename == "nil" then - return true - elseif t2 is UnresolvedEmptyTableValueType then - local t2keys = t2.emptytable_type.keys - if t2keys is NumericType then -- ideally integer only - self:infer_emptytable(t2.emptytable_type, self:infer_at(w, an_array(w, t1))) - else - self:infer_emptytable(t2.emptytable_type, self:infer_at(w, a_map(w, t2keys, t1))) - end - return true - elseif t2 is EmptyTableType then - if is_lua_table_type(t1) then - self:infer_emptytable(t2, self:infer_at(w, t1)) - elseif not t1 is EmptyTableType then - self.errs:add(w, self.errs:get_context(ctx, name) .. "assigning %s to a variable declared with {}", t1) - return false - end - return true + if t2 is EmptyTableType then + t2 = type_at(w, t2) -- minor hack: tweak point of inference end local ok, match_errs = self:is_a(t1, t2) From 9e7c64577819bb4e437594717b96e2826b2adcae Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Mon, 9 Sep 2024 22:52:55 -0300 Subject: [PATCH 204/224] tests: add #cli tag in spec/cli tests This makes it easy to skip those tests when a faster busted run is desired in exchange of completeness. --- spec/cli/check_spec.lua | 2 +- spec/cli/feat_spec.lua | 2 +- spec/cli/gen_spec.lua | 2 +- spec/cli/global_env_def_spec.lua | 2 +- spec/cli/include_dir_spec.lua | 2 +- spec/cli/output_spec.lua | 2 +- spec/cli/quiet_spec.lua | 2 +- spec/cli/run_spec.lua | 2 +- spec/cli/types_spec.lua | 2 +- spec/cli/warning_spec.lua | 4 ++-- 10 files changed, 11 insertions(+), 11 deletions(-) diff --git a/spec/cli/check_spec.lua b/spec/cli/check_spec.lua index 347c369aa..17f5bf453 100644 --- a/spec/cli/check_spec.lua +++ b/spec/cli/check_spec.lua @@ -1,7 +1,7 @@ local assert = require("luassert") local util = require("spec.util") -describe("tl check", function() +describe("#cli tl check", function() describe("on .tl files", function() it("reports if file does not exist", function() local pd = io.popen(util.tl_cmd("check", "file_that_does_not_exist.tl") .. " 2>&1", "r") diff --git a/spec/cli/feat_spec.lua b/spec/cli/feat_spec.lua index 7741cc8f3..d96471bdb 100644 --- a/spec/cli/feat_spec.lua +++ b/spec/cli/feat_spec.lua @@ -52,7 +52,7 @@ local test_cases = { } } -describe("feat flags", function() +describe("#cli feat flags", function() for flag, tests in pairs(test_cases) do describe(flag, function() for _, case in ipairs(tests) do diff --git a/spec/cli/gen_spec.lua b/spec/cli/gen_spec.lua index 7f537e109..ae8460cb1 100644 --- a/spec/cli/gen_spec.lua +++ b/spec/cli/gen_spec.lua @@ -102,7 +102,7 @@ local function tl_to_lua(name) return (name:gsub("%.tl$", ".lua"):gsub("^" .. util.os_tmp .. util.os_sep, "")) end -describe("tl gen", function() +describe("#cli tl gen", function() setup(util.chdir_setup) teardown(util.chdir_teardown) describe("on .tl files", function() diff --git a/spec/cli/global_env_def_spec.lua b/spec/cli/global_env_def_spec.lua index 8dc71a01b..20c9bdaea 100644 --- a/spec/cli/global_env_def_spec.lua +++ b/spec/cli/global_env_def_spec.lua @@ -1,7 +1,7 @@ local assert = require("luassert") local util = require("spec.util") -describe("--global-env-def argument", function() +describe("#cli --global-env-def argument", function() it("exports globals from a module", function() util.do_in(util.write_tmp_dir(finally, { mod = { diff --git a/spec/cli/include_dir_spec.lua b/spec/cli/include_dir_spec.lua index 20bdbe00a..cadfb5d89 100644 --- a/spec/cli/include_dir_spec.lua +++ b/spec/cli/include_dir_spec.lua @@ -1,7 +1,7 @@ local assert = require("luassert") local util = require("spec.util") -describe("-I --include-dir argument", function() +describe("#cli -I --include-dir argument", function() it("adds a directory to package.path", function() util.do_in(util.write_tmp_dir(finally, { mod = { diff --git a/spec/cli/output_spec.lua b/spec/cli/output_spec.lua index 92956044f..b16dbccdb 100644 --- a/spec/cli/output_spec.lua +++ b/spec/cli/output_spec.lua @@ -1,6 +1,6 @@ local util = require("spec.util") -describe("-o --output", function() +describe("#cli -o --output", function() it("should gen in the current directory when not provided", function() util.run_mock_project(finally, { dir_structure = { diff --git a/spec/cli/quiet_spec.lua b/spec/cli/quiet_spec.lua index 11031fcc6..922b7cce9 100644 --- a/spec/cli/quiet_spec.lua +++ b/spec/cli/quiet_spec.lua @@ -1,6 +1,6 @@ local util = require("spec.util") -describe("-q --quiet flag", function() +describe("#cli -q --quiet flag", function() setup(util.chdir_setup) teardown(util.chdir_teardown) it("silences warnings from tlconfig.lua", function() diff --git a/spec/cli/run_spec.lua b/spec/cli/run_spec.lua index 26f9d7335..c785161c0 100644 --- a/spec/cli/run_spec.lua +++ b/spec/cli/run_spec.lua @@ -1,6 +1,6 @@ local util = require("spec.util") -describe("tl run", function() +describe("#cli tl run", function() setup(util.chdir_setup) teardown(util.chdir_teardown) describe("on .tl files", function() diff --git a/spec/cli/types_spec.lua b/spec/cli/types_spec.lua index 61d41f549..bcb35f1ff 100644 --- a/spec/cli/types_spec.lua +++ b/spec/cli/types_spec.lua @@ -2,7 +2,7 @@ local assert = require("luassert") local json = require("dkjson") local util = require("spec.util") -describe("tl types works like check", function() +describe("#cli tl types works like check", function() describe("on .tl files", function() it("reports missing files", function() local pd = io.popen(util.tl_cmd("types", "nonexistent_file") .. "2>&1 1>" .. util.os_null, "r") diff --git a/spec/cli/warning_spec.lua b/spec/cli/warning_spec.lua index 60eb4a74e..0d28a8b92 100644 --- a/spec/cli/warning_spec.lua +++ b/spec/cli/warning_spec.lua @@ -2,7 +2,7 @@ local assert = require("luassert") local util = require("spec.util") local tl = require("tl") -describe("tl warnings", function() +describe("#cli tl warnings", function() it("reports existing warning types when given no arguments", function() local pd = io.popen(util.tl_cmd("warnings"), "r") local output = pd:read("*a") @@ -16,7 +16,7 @@ describe("tl warnings", function() end) end) -describe("warning flags", function() +describe("#cli warning flags", function() describe("in tlconfig.lua", function() describe("disable_warnings", function() it("disables the given warnings", function() From ece9e50bee2ebfc33da0bf982bb9490f65922116 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Tue, 10 Sep 2024 10:57:30 -0300 Subject: [PATCH 205/224] pending test: detect mismatch in function generics See #801. --- spec/declaration/record_function_spec.lua | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/spec/declaration/record_function_spec.lua b/spec/declaration/record_function_spec.lua index 75bd616f5..79b81cb64 100644 --- a/spec/declaration/record_function_spec.lua +++ b/spec/declaration/record_function_spec.lua @@ -94,6 +94,26 @@ describe("record function", function() { y = 5, msg = "different number of input arguments: got 2, expected 3" }, })) + pending("detect mismatch in function generics", util.check_type_error([[ + local type List2 = record + new: function(initialItems: {T}, u: U): List2 + end + + function List2.new(initialItems: {T}, u: U): List2 -- mismatched return type + end + + local type Fruit2 = enum + "apple" + "peach" + "banana" + end + + local type L2 = List2 + local lunchbox = L2.new({"apple", "peach"}, true) + ]], { + { msg = "type signature does not match declaration" } + })) + it("report error in return args correctly (regression test for #618)", util.check_warnings([[ local record R _current: R From 1c3ba3d1509770151628928aa3f451274bd669a0 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Tue, 10 Sep 2024 11:55:25 -0300 Subject: [PATCH 206/224] refactor: store_field_in_record --- tl.lua | 28 ++++++++++++++++------------ tl.tl | 28 ++++++++++++++++------------ 2 files changed, 32 insertions(+), 24 deletions(-) diff --git a/tl.lua b/tl.lua index a51134444..16db20282 100644 --- a/tl.lua +++ b/tl.lua @@ -3690,24 +3690,28 @@ do return i, node end - local function store_field_in_record(ps, i, field_name, t, fields, field_order) + local function store_field_in_record(ps, i, field_name, newt, fields, field_order) if not fields[field_name] then - fields[field_name] = t + fields[field_name] = newt table.insert(field_order, field_name) - else - local prev_t = fields[field_name] - if t.typename == "function" and prev_t.typename == "function" then + return true + end + + local oldt = fields[field_name] + + if newt.typename == "function" then + if oldt.typename == "function" then local p = new_type(ps, i, "poly") - p.types = { prev_t, t } + p.types = { oldt, newt } fields[field_name] = p - elseif t.typename == "function" and prev_t.typename == "poly" then - table.insert(prev_t.types, t) - else - fail(ps, i, "attempt to redeclare field '" .. field_name .. "' (only functions can be overloaded)") - return false + return true + elseif oldt.typename == "poly" then + table.insert(oldt.types, newt) + return true end end - return true + fail(ps, i, "attempt to redeclare field '" .. field_name .. "' (only functions can be overloaded)") + return false end local function set_declname(def, declname) diff --git a/tl.tl b/tl.tl index cf7e331fe..abe65ebc4 100644 --- a/tl.tl +++ b/tl.tl @@ -3690,24 +3690,28 @@ local function parse_return(ps: ParseState, i: integer): integer, Node return i, node end -local function store_field_in_record(ps: ParseState, i: integer, field_name: string, t: Type, fields: {string: Type}, field_order: {string}): boolean +local function store_field_in_record(ps: ParseState, i: integer, field_name: string, newt: Type, fields: {string: Type}, field_order: {string}): boolean if not fields[field_name] then - fields[field_name] = t + fields[field_name] = newt table.insert(field_order, field_name) - else - local prev_t = fields[field_name] - if t is FunctionType and prev_t is FunctionType then + return true + end + + local oldt = fields[field_name] + + if newt is FunctionType then + if oldt is FunctionType then local p = new_type(ps, i, "poly") as PolyType - p.types = { prev_t, t } + p.types = { oldt, newt } fields[field_name] = p - elseif t is FunctionType and prev_t is PolyType then - table.insert(prev_t.types, t) - else - fail(ps, i, "attempt to redeclare field '" .. field_name .. "' (only functions can be overloaded)") - return false + return true + elseif oldt is PolyType then + table.insert(oldt.types, newt) + return true end end - return true + fail(ps, i, "attempt to redeclare field '" .. field_name .. "' (only functions can be overloaded)") + return false end local function set_declname(def: Type, declname: string) From 8f63a8636e27b8a3d9c9fd6d120537d61a9a4767 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Wed, 18 Sep 2024 10:50:37 -0300 Subject: [PATCH 207/224] assert: ignores additional arguments --- spec/stdlib/assert_spec.lua | 5 +++++ tl.lua | 2 +- tl.tl | 2 +- 3 files changed, 7 insertions(+), 2 deletions(-) diff --git a/spec/stdlib/assert_spec.lua b/spec/stdlib/assert_spec.lua index 07b0b2343..05b78a453 100644 --- a/spec/stdlib/assert_spec.lua +++ b/spec/stdlib/assert_spec.lua @@ -17,4 +17,9 @@ describe("assert", function() { y = 9, msg = "cannot use operator '..' for types string | boolean and string" }, { y = 11, msg = "cannot use operator '..' for types string | boolean and string" }, })) + + it("ignores additional arguments", util.check([[ + local f = assert(io.open("nonexistent.txt")) + ]])) + end) diff --git a/tl.lua b/tl.lua index 16db20282..50bd17fc4 100644 --- a/tl.lua +++ b/tl.lua @@ -396,7 +396,7 @@ do type XpcallMsghFunction = function(...: any): () arg: {string} - assert: function(A, ? B): A + assert: function(A, ? B, ...: any): A collectgarbage: function(? CollectGarbageCommand): number collectgarbage: function(CollectGarbageSetValue, integer): number diff --git a/tl.tl b/tl.tl index abe65ebc4..cc65b2222 100644 --- a/tl.tl +++ b/tl.tl @@ -396,7 +396,7 @@ do type XpcallMsghFunction = function(...: any): () arg: {string} - assert: function(A, ? B): A + assert: function(A, ? B, ...: any): A collectgarbage: function(? CollectGarbageCommand): number collectgarbage: function(CollectGarbageSetValue, integer): number From f0d714e5511aac74b270cc1f146c73deba1f72bc Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Wed, 18 Sep 2024 16:50:04 -0300 Subject: [PATCH 208/224] fix: do not infer typevars on failed poly match In `assert(f:read("a"))`, the return type of the failed poly match of `file:read` was matching the `A` typevar of `assert`. We now reset the temporary scope of a call typecheck fully when backtracking to test different poly entries (previously we were backtracking only the poly's own typeargs). --- spec/stdlib/io_spec.lua | 6 ++++++ tl.lua | 41 +++++++++++++++++++---------------------- tl.tl | 41 +++++++++++++++++++---------------------- 3 files changed, 44 insertions(+), 44 deletions(-) diff --git a/spec/stdlib/io_spec.lua b/spec/stdlib/io_spec.lua index 7b785a5b6..1b3c2e329 100644 --- a/spec/stdlib/io_spec.lua +++ b/spec/stdlib/io_spec.lua @@ -25,6 +25,12 @@ describe("io", function() print(n + m) ]])) + it("with a string format, in assert", util.check([[ + local f = assert(io.open("file.txt", "rb")) + local r = assert(f:read("a")) + f:close() + ]])) + it("with multiple formats", util.check([[ local a, b, c = io.read("l", 12, 13) print(a:upper()) diff --git a/tl.lua b/tl.lua index 50bd17fc4..ea0e14850 100644 --- a/tl.lua +++ b/tl.lua @@ -9275,26 +9275,6 @@ a.types[i], b.types[i]), } end end - local function push_typeargs(self, func) - if func.typeargs then - for _, fnarg in ipairs(func.typeargs) do - self:add_var(nil, fnarg.typearg, a_type(fnarg, "unresolved_typearg", { - constraint = fnarg.constraint, - })) - end - end - end - - local function pop_typeargs(self, func) - if func.typeargs then - for _, fnarg in ipairs(func.typeargs) do - if self.st[#self.st].vars[fnarg.typearg] then - self.st[#self.st].vars[fnarg.typearg] = nil - end - end - end - end - @@ -9373,6 +9353,16 @@ a.types[i], b.types[i]), } return false end + local function add_call_typeargs(self, func) + if func.typeargs then + for _, fnarg in ipairs(func.typeargs) do + self:add_var(nil, fnarg.typearg, a_type(fnarg, "unresolved_typearg", { + constraint = fnarg.constraint, + })) + end + end + end + check_call = function(self, w, wargs, f, args, expected_rets, cm, argdelta) local arg1 = args.tuple[1] if cm == "method" and arg1 then @@ -9392,7 +9382,7 @@ a.types[i], b.types[i]), } return nil, { Err_at(w, "wrong number of arguments (given " .. given .. ", expects " .. show_arity(f) .. ")") } end - push_typeargs(self, f) + add_call_typeargs(self, f) return check_args_rets(self, w, wargs, f, args, expected_rets, argdelta) end @@ -9414,6 +9404,13 @@ a.types[i], b.types[i]), } return { Err_at(w, "wrong number of arguments (given " .. given .. ", expects " .. table.concat(expects, " or ") .. ")") } end + local function reset_scope(self) + local vars = self.st[#self.st].vars + for k, _ in pairs(vars) do + vars[k] = nil + end + end + check_poly_call = function(self, w, wargs, p, args, expected_rets, cm, argdelta) local given = #args.tuple @@ -9444,7 +9441,7 @@ a.types[i], b.types[i]), } infer_emptytables(self, w, wargs, f.rets, f.rets, argdelta) end - pop_typeargs(self, f) + reset_scope(self) first_errs = first_errs or errs tried[i] = true diff --git a/tl.tl b/tl.tl index cc65b2222..c8767799b 100644 --- a/tl.tl +++ b/tl.tl @@ -9275,26 +9275,6 @@ do end end - local function push_typeargs(self: TypeChecker, func: FunctionType) - if func.typeargs then - for _, fnarg in ipairs(func.typeargs) do - self:add_var(nil, fnarg.typearg, a_type(fnarg, "unresolved_typearg", { - constraint = fnarg.constraint, - } as UnresolvedTypeArgType)) - end - end - end - - local function pop_typeargs(self: TypeChecker, func: FunctionType) - if func.typeargs then - for _, fnarg in ipairs(func.typeargs) do - if self.st[#self.st].vars[fnarg.typearg] then - self.st[#self.st].vars[fnarg.typearg] = nil - end - end - end - end - local enum CallMode "method" -- a method colon-call, e.g. `my_object:my_method()` "plain" -- a plain call or a dot-call, e.g `my_func()` or `my_object.my_func()` @@ -9373,6 +9353,16 @@ do return false end + local function add_call_typeargs(self: TypeChecker, func: FunctionType) + if func.typeargs then + for _, fnarg in ipairs(func.typeargs) do + self:add_var(nil, fnarg.typearg, a_type(fnarg, "unresolved_typearg", { + constraint = fnarg.constraint, + } as UnresolvedTypeArgType)) + end + end + end + check_call = function(self: TypeChecker, w: Where, wargs: {Where}, f: FunctionType, args: TupleType, expected_rets: TupleType, cm: CallMode, argdelta: integer): boolean, {Error} local arg1 = args.tuple[1] if cm == "method" and arg1 then @@ -9392,7 +9382,7 @@ do return nil, { Err_at(w, "wrong number of arguments (given " .. given .. ", expects " .. show_arity(f) .. ")") } end - push_typeargs(self, f) + add_call_typeargs(self, f) return check_args_rets(self, w, wargs, f, args, expected_rets, argdelta) end @@ -9414,6 +9404,13 @@ do return { Err_at(w, "wrong number of arguments (given " .. given .. ", expects " .. table.concat(expects, " or ") .. ")") } end + local function reset_scope(self: TypeChecker) + local vars = self.st[#self.st].vars + for k, _ in pairs(vars) do + vars[k] = nil + end + end + check_poly_call = function(self: TypeChecker, w: Where, wargs: {Where}, p: PolyType, args: TupleType, expected_rets: TupleType, cm: CallMode, argdelta: integer): FunctionType, TupleType, {Error} local given = #args.tuple @@ -9444,7 +9441,7 @@ do infer_emptytables(self, w, wargs, f.rets, f.rets, argdelta) end - pop_typeargs(self, f) + reset_scope(self) first_errs = first_errs or errs tried[i] = true From cfc986cf762f6ef0d4ad517ae22e30a578355244 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Wed, 18 Sep 2024 16:52:15 -0300 Subject: [PATCH 209/224] string.gsub: accept functions with no returns Note that we're doing a little hack here, but notating the return type as vararg, as a stand-in for something like a `?` in the return type. I guess we need `min_arity` calculations in the return types as well. Also fix the declaration of `gsub` because Lua does not accept functions that return booleans, but it does accept numbers. --- spec/stdlib/string_spec.lua | 71 +++++++++++++++++++++++++++++++++++++ tl.lua | 2 +- tl.tl | 2 +- 3 files changed, 73 insertions(+), 2 deletions(-) diff --git a/spec/stdlib/string_spec.lua b/spec/stdlib/string_spec.lua index e2a213be9..6e4ba5d4c 100644 --- a/spec/stdlib/string_spec.lua +++ b/spec/stdlib/string_spec.lua @@ -22,4 +22,75 @@ describe("string", function() ]])) end) + describe("gsub", function() + it("accepts a string, returns a string", util.check([[ + local s = "hello" + local hi: string = s:gsub("ello", "i") + ]])) + + it("accepts a string, returns a string and integer", util.check([[ + local s = "hello world" + local wordword, count: string, integer = s:gsub("%w+", "word") + ]])) + + it("accepts a string and integer, returns a string and integer", util.check([[ + local s = "hello world" + local helloword, count: string, integer = s:gsub("%w+", "word", 6) + ]])) + + it("accepts a map, returns a string and integer", util.check([[ + local s = "hello world" + local map = { + ["hello"] = "hola", + ["world"] = "mundo", + } + local holamundo, count: string, integer = s:gsub("%w+", map) + ]])) + + it("accepts a map and integer, returns a string and integer", util.check([[ + local s = "hello world" + local map = { + ["hello"] = "hola", + ["world"] = "mundo", + } + local hellomundo, count: string, integer = s:gsub("%w+", map, 6) + ]])) + + it("accepts a function to strings, returns a string", util.check([[ + local s = "hello world" + local function f(x: string): string + return x:upper() + end + local ret: string = s:gsub("%w+", f) + ]])) + + it("accepts a function to integers, returns a string", util.check([[ + local s = "hello world" + local function f(x: string): integer + return #x + end + local ret: string = s:gsub("%w+", f) + ]])) + + it("accepts a function to numbers, returns a string", util.check([[ + local s = "hello world" + local function f(x: string): number + return #x * 1.5 + end + local ret: string = s:gsub("%w+", f) + ]])) + + it("accepts a function that returns nothing", util.check([[ + local function parse_integers(s: string, i0: integer) : {integer} + local t, p = {}, i0 or 1 + local function f(x: string) + t[p] = math.tointeger(x) + p = p + 1 + end + s:gsub("[-%d]+", f) + return t + end + ]])) + end) + end) diff --git a/tl.lua b/tl.lua index ea0e14850..f3da48189 100644 --- a/tl.lua +++ b/tl.lua @@ -317,7 +317,7 @@ do gsub: function(string, string, string, ? integer): string, integer gsub: function(string, string, {string:string}, ? integer): string, integer - gsub: function(string, string, function(string...): (string | integer | boolean), ? integer): string, integer + gsub: function(string, string, function(string...): ((string | integer | number)...), ? integer): string, integer len: function(string): integer lower: function(string): string diff --git a/tl.tl b/tl.tl index c8767799b..7ba2888b2 100644 --- a/tl.tl +++ b/tl.tl @@ -317,7 +317,7 @@ do gsub: function(string, string, string, ? integer): string, integer gsub: function(string, string, {string:string}, ? integer): string, integer - gsub: function(string, string, function(string...): (string | integer | boolean), ? integer): string, integer + gsub: function(string, string, function(string...): ((string | integer | number)...), ? integer): string, integer len: function(string): integer lower: function(string): string From ccd2417a7d34dfa95dd040c8e977fbdb14358b72 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Wed, 18 Sep 2024 18:14:29 -0300 Subject: [PATCH 210/224] improve inference of nested emptytables See https://github.com/teal-language/tl/pull/732#issuecomment-2359380390 > There is one remaining tricky error triggering in the day12 for which I > already found a workaround (but which would merit a nicer solution since this > uncovered the fact that my empty-table inference does not backpropagate > correctly in the case of map[c1][c2] = true). I'll try to code a solution > quickly, otherwise I'll document the edge case in the testsuite and go with > the workaround. This is the workaround, and the documentation of the edge case is in the test case included in this commit. --- spec/inference/emptytable_spec.lua | 46 ++++++++++++++++++++++++++++++ tl.lua | 26 ++++++++++++----- tl.tl | 26 ++++++++++++----- 3 files changed, 84 insertions(+), 14 deletions(-) diff --git a/spec/inference/emptytable_spec.lua b/spec/inference/emptytable_spec.lua index 2781a9d08..906d6c7f4 100644 --- a/spec/inference/emptytable_spec.lua +++ b/spec/inference/emptytable_spec.lua @@ -150,4 +150,50 @@ describe("empty table without type annotation", function() print(x) end ]])) + + it("does not fail when resolving nested emptytables", util.check([[ + local function parse_input() : {string: {string: boolean}} + local m = {} + for line in io.lines("input.txt") do + local c1, c2 = line:match("^(%w+)-(%w+)$") + if not m[c1] then m[c1] = {} end + if not m[c2] then m[c2] = {} end + + -- Summary of the emptytable propagation (desired) behavior: + + local x = m[c1] + -- infer x to { typename = "unresolved_emptytable_value", emptytable_type = m, keys = STRING } + + local y = x[c2] + -- here we want to: + -- declare a new_emptytable + -- infer m to { typename = "map", keys = STRING, values = new_emptytable } + -- infer y to { typename = "unresolved_emptytable_value", emptytable_type = new_emptytable, keys = "string" } + + y = true + -- here we want to: + -- infer y to boolean + -- infer emptytable_type to { typename = "map", keys = "string", values = "boolean" } + -- by propagation, infer m to { typename = "map", keys = STRING, values = { typename = "map", keys = "string", values = "boolean" } } + -- FIXME: this is not propagating backwards correctly (probably because table objects are copied) + + -- same thing as the above, but written in the + -- idiomatic style as it first appeared in @catwell's code: + m[c1][c2] = true + m[c2][c1] = true + end + return m + end + ]])) + + it("does not fail when resolving nested emptytables, three levels deep", util.check([[ + local function f(a: string, b: string, c: string) : {string: {string: {string: boolean}}} + local m = {} + if not m[a] then m[a] = {} end + if not m[a][b] then m[a][b] = {} end + m[a][b][c] = true + return m + end + ]])) + end) diff --git a/tl.lua b/tl.lua index f3da48189..044dcfbfd 100644 --- a/tl.lua +++ b/tl.lua @@ -8462,6 +8462,18 @@ do return compare_true_inferring_emptytable(self, a, b) end + local function infer_emptytable_from_unresolved_value(self, w, u, values) + local et = u.emptytable_type + assert(et.typename == "emptytable", u.typename) + local keys = et.keys + if not (values.typename == "emptytable" or values.typename == "unresolved_emptytable_value") then + local infer_to = is_numeric_type(keys) and + a_type(w, "array", { elements = values }) or + a_type(w, "map", { keys = keys, values = values }) + self:infer_emptytable(et, self:infer_at(w, infer_to)) + end + end + local emptytable_relations = { ["array"] = compare_true, @@ -8908,13 +8920,7 @@ a.types[i], b.types[i]), } return false, { Err("assigning %s to a variable declared with {}", a) } end, ["unresolved_emptytable_value"] = function(self, a, b) - local bt = b.emptytable_type - assert(bt.typename == "emptytable", b.typename) - local bkeys = bt.keys - local infer_to = is_numeric_type(bkeys) and - a_type(b, "array", { elements = a }) or - a_type(b, "map", { keys = bkeys, values = a }) - self:infer_emptytable(bt, self:infer_at(b, infer_to)) + infer_emptytable_from_unresolved_value(self, b, b, a) return true end, ["self"] = function(self, a, b) @@ -9911,6 +9917,12 @@ a.types[i], b.types[i]), } end errm, erra, errb = "inconsistent index type: got %s, expected %s" .. inferred_msg(ra.keys, "type of keys "), b, ra.keys + elseif ra.typename == "unresolved_emptytable_value" then + local et = a_type(ra, "emptytable", { keys = b }) + infer_emptytable_from_unresolved_value(self, a, ra, et) + return a_type(anode, "unresolved_emptytable_value", { + emptytable_type = et, + }) elseif ra.typename == "map" then if self:is_a(b, ra.keys) then return ra.values diff --git a/tl.tl b/tl.tl index 7ba2888b2..dad8de5ab 100644 --- a/tl.tl +++ b/tl.tl @@ -8462,6 +8462,18 @@ do return compare_true_inferring_emptytable(self, a, b) end + local function infer_emptytable_from_unresolved_value(self: TypeChecker, w: Where, u: UnresolvedEmptyTableValueType, values: Type) + local et = u.emptytable_type + assert(et is EmptyTableType, u.typename) + local keys = et.keys + if not (values is EmptyTableType or values is UnresolvedEmptyTableValueType) then + local infer_to = keys is NumericType -- ideally integer only + and an_array(w, values) + or a_map(w, keys, values) + self:infer_emptytable(et, self:infer_at(w, infer_to)) + end + end + -- emptytable rules are the same in eqtype_relations and subtype_relations local emptytable_relations: {TypeName:CompareTypes} = { ["array"] = compare_true, @@ -8908,13 +8920,7 @@ do return false, { Err("assigning %s to a variable declared with {}", a) } end, ["unresolved_emptytable_value"] = function(self: TypeChecker, a: Type, b: UnresolvedEmptyTableValueType): boolean, {Error} - local bt = b.emptytable_type - assert(bt is EmptyTableType, b.typename) - local bkeys = bt.keys - local infer_to = bkeys is NumericType -- ideally integer only - and an_array(b, a) - or a_map(b, bkeys, a) - self:infer_emptytable(bt, self:infer_at(b, infer_to)) + infer_emptytable_from_unresolved_value(self, b, b, a) return true end, ["self"] = function(self: TypeChecker, a: Type, b: SelfType): boolean, {Error} @@ -9911,6 +9917,12 @@ do end errm, erra, errb = "inconsistent index type: got %s, expected %s" .. inferred_msg(ra.keys, "type of keys "), b, ra.keys + elseif ra is UnresolvedEmptyTableValueType then + local et = a_type(ra, "emptytable", { keys = b } as EmptyTableType) + infer_emptytable_from_unresolved_value(self, a, ra, et) + return a_type(anode, "unresolved_emptytable_value", { + emptytable_type = et + } as UnresolvedEmptyTableValueType) elseif ra is MapType then if self:is_a(b, ra.keys) then return ra.values From 58ebd4b7ab2c57dcd19f5c5e29c69221d8a016aa Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Wed, 18 Sep 2024 22:22:18 -0300 Subject: [PATCH 211/224] return: module returns the nominal's type, including typeargs Fixes #804. --- spec/statement/return_spec.lua | 88 ++++++++++++++++++++++++++++++++++ tl.lua | 8 +++- tl.tl | 8 +++- 3 files changed, 102 insertions(+), 2 deletions(-) diff --git a/spec/statement/return_spec.lua b/spec/statement/return_spec.lua index b52fc912a..1b9f17949 100644 --- a/spec/statement/return_spec.lua +++ b/spec/statement/return_spec.lua @@ -153,6 +153,94 @@ describe("return", function() assert.same({}, result.syntax_errors) assert.same({}, result.type_errors) end) + + it("when exporting a generic (regression test for #804)", function () + util.mock_io(finally, { + ["foo.tl"] = [[ + local record Foo + bar: T + end + return Foo + ]], + ["main.tl"] = [[ + local Foo = require("foo") + + local foo: Foo + + foo = { + bar = 5 + } + + print(string.format("bar: %d", foo.bar + 1)) + ]], + }) + + local tl = require("tl") + local result, err = tl.process("main.tl", assert(tl.init_env())) + + assert.same(nil, err) + assert.same({}, result.syntax_errors) + assert.same({}, result.type_errors) + end) + + it("when exporting a typealias (variation on regression test for #804)", function () + util.mock_io(finally, { + ["foo.tl"] = [[ + local record Foo + bar: T + end + local type FooInteger = Foo + return FooInteger + ]], + ["main.tl"] = [[ + local Foo = require("foo") + + local foo: Foo + + foo = { + bar = 5 + } + + print(string.format("bar: %d", foo.bar + 1)) + ]], + }) + + local tl = require("tl") + local result, err = tl.process("main.tl", assert(tl.init_env())) + + assert.same(nil, err) + assert.same({}, result.syntax_errors) + assert.same({}, result.type_errors) + end) + + it("when exporting a non-generic (variation on regression test for #804)", function () + util.mock_io(finally, { + ["foo.tl"] = [[ + local record Foo + bar: integer + end + return Foo + ]], + ["main.tl"] = [[ + local Foo = require("foo") + + local foo: Foo + + foo = { + bar = 5 + } + + print(string.format("bar: %d", foo.bar + 1)) + ]], + }) + + local tl = require("tl") + local result, err = tl.process("main.tl", assert(tl.init_env())) + + assert.same(nil, err) + assert.same({}, result.syntax_errors) + assert.same({}, result.type_errors) + end) end) it("when exporting type alias through multiple levels", function () diff --git a/tl.lua b/tl.lua index 044dcfbfd..4e2daccd8 100644 --- a/tl.lua +++ b/tl.lua @@ -11527,7 +11527,13 @@ self:expand_type(node, values, elements) }) if not expected then expected = self:infer_at(node, got) - self.module_type = drop_constant_value(self:to_structural(resolve_tuple(expected))) + local module_type = resolve_tuple(expected) + if module_type.typename == "nominal" then + self:resolve_nominal(module_type) + self.module_type = module_type.found + else + self.module_type = drop_constant_value(module_type) + end self.st[2].vars["@return"] = { t = expected } end local expected_t = expected.tuple diff --git a/tl.tl b/tl.tl index dad8de5ab..bd3a043d0 100644 --- a/tl.tl +++ b/tl.tl @@ -11527,7 +11527,13 @@ do if not expected then -- if at the toplevel expected = self:infer_at(node, got) - self.module_type = drop_constant_value(self:to_structural(resolve_tuple(expected))) + local module_type = resolve_tuple(expected) + if module_type is NominalType then + self:resolve_nominal(module_type) + self.module_type = module_type.found + else + self.module_type = drop_constant_value(module_type) + end self.st[2].vars["@return"] = { t = expected } end local expected_t = expected.tuple From af5aed5350464032ee33ecf170de2ae2b2d47c58 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Thu, 19 Sep 2024 14:57:20 -0300 Subject: [PATCH 212/224] fix: can assign an emptytable to an emptytable This should also avoid inferring types as `{} | {}`. --- spec/declaration/union_spec.lua | 8 ++++++++ spec/inference/emptytable_spec.lua | 13 +++++++++++++ tl.lua | 1 + tl.tl | 1 + 4 files changed, 23 insertions(+) diff --git a/spec/declaration/union_spec.lua b/spec/declaration/union_spec.lua index fa50f8a3c..43cc9c69d 100644 --- a/spec/declaration/union_spec.lua +++ b/spec/declaration/union_spec.lua @@ -139,4 +139,12 @@ describe("union declaration", function() { msg = "cannot discriminate a union between multiple function types" }, })) + it("collapses multiple emptytables on declaration", util.check([[ + local function count_sea_monsters(image: {string}) + local c, m, n = 0, {{}, {}}, 0 + for row = 1, #image - 2 do + m[row + 2] = {} + end + end + ]])) end) diff --git a/spec/inference/emptytable_spec.lua b/spec/inference/emptytable_spec.lua index 906d6c7f4..b9d6db751 100644 --- a/spec/inference/emptytable_spec.lua +++ b/spec/inference/emptytable_spec.lua @@ -196,4 +196,17 @@ describe("empty table without type annotation", function() end ]])) + it("can assign an emptytable to an emptytable", util.check([[ + local x = {} + + for i = 1, 20 do + if math.random(2) == 1 then + x = {} + else + x = {"hello"} + end + print(#x) + end + ]])) + end) diff --git a/tl.lua b/tl.lua index 4e2daccd8..9efdffc60 100644 --- a/tl.lua +++ b/tl.lua @@ -8476,6 +8476,7 @@ do local emptytable_relations = { + ["emptytable"] = compare_true, ["array"] = compare_true, ["map"] = compare_true, ["tupletable"] = compare_true, diff --git a/tl.tl b/tl.tl index bd3a043d0..d6c9a7b1c 100644 --- a/tl.tl +++ b/tl.tl @@ -8476,6 +8476,7 @@ do -- emptytable rules are the same in eqtype_relations and subtype_relations local emptytable_relations: {TypeName:CompareTypes} = { + ["emptytable"] = compare_true, ["array"] = compare_true, ["map"] = compare_true, ["tupletable"] = compare_true, From d8965faf5da9b0f8da6bc77cd8f3a1e8134e1794 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Thu, 19 Sep 2024 15:31:40 -0300 Subject: [PATCH 213/224] types: fix reporting of if block end position --- tl.lua | 1 - tl.tl | 1 - 2 files changed, 2 deletions(-) diff --git a/tl.lua b/tl.lua index 9efdffc60..f082dae71 100644 --- a/tl.lua +++ b/tl.lua @@ -3531,7 +3531,6 @@ do if not block.body then return i end - end_at(block.body, ps.tokens[i - 1]) block.yend, block.xend = block.body.yend, block.body.xend table.insert(node.if_blocks, block) return i, node diff --git a/tl.tl b/tl.tl index d6c9a7b1c..1b3b28fd7 100644 --- a/tl.tl +++ b/tl.tl @@ -3531,7 +3531,6 @@ local function parse_if_block(ps: ParseState, i: integer, n: integer, node: Node if not block.body then return i end - end_at(block.body, ps.tokens[i - 1]) block.yend, block.xend = block.body.yend, block.body.xend table.insert(node.if_blocks, block) return i, node From b618edb8f0f94d36bbfb47ed91534df6836982dc Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Thu, 19 Sep 2024 15:36:38 -0300 Subject: [PATCH 214/224] tests: add testcase for `tl types -p` if-block --- spec/cli/types_spec.lua | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/spec/cli/types_spec.lua b/spec/cli/types_spec.lua index bcb35f1ff..55940d00a 100644 --- a/spec/cli/types_spec.lua +++ b/spec/cli/types_spec.lua @@ -175,6 +175,39 @@ describe("#cli tl types works like check", function() assert(type(types["y"]) == "nil") end) + it("reports end of if-block correctly", function() + local filename = util.write_tmp_file(finally, [[ + -- test.tl + + local function hello(): number + return 1 + end + + if 1 == 1 then + local abc = hello() + local def = abc + + + + + + def = abc + end + ]]) + local pd = io.popen(util.tl_cmd("types", "-p", "12:1", filename), "r") + local output = pd:read("*a") + util.assert_popen_close(0, pd:close()) + local types = json.decode(output) + local n = 0 + for _ in pairs(types) do + n = n + 1 + end + assert(n == 3) + assert(type(types["abc"]) == "number") + assert(type(types["def"]) == "number") + assert(type(types["hello"]) == "number") + end) + it("reports number of errors in stderr and code 1 on syntax errors", function() local name = util.write_tmp_file(finally, [[ print(add("string", 20)))))) From 6117c8c2e023cfe9828d972fa91d985e0bda9e8b Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Fri, 20 Sep 2024 23:59:49 -0300 Subject: [PATCH 215/224] types: reports record functions in record field list Also, strip typedecls from type report, and represent records in the "str" field by their declared name. Thanks @pdesaulniers for the report! --- spec/api/get_types_spec.lua | 28 ++++++++++++++++++++++++++++ tl.lua | 15 +++++++++++++-- tl.tl | 15 +++++++++++++-- 3 files changed, 54 insertions(+), 4 deletions(-) diff --git a/spec/api/get_types_spec.lua b/spec/api/get_types_spec.lua index a36b5373e..d2aea26d1 100644 --- a/spec/api/get_types_spec.lua +++ b/spec/api/get_types_spec.lua @@ -33,4 +33,32 @@ describe("tl.get_types", function() local type_at_y_x = tr.by_pos[""][y][x] assert(tr.types[type_at_y_x].str == "function(string)") end) + + it("reports record functions in record field list", function() + local env = tl.init_env() + env.report_types = true + local result = assert(tl.check_string([[ + local record Point + x: number + y: number + end + + function Point:init(x: number, y: number) + self.x = x + self.y = y + end + ]], env)) + + local tr, trenv = tl.get_types(result) + local y = 1 + local x = 10 + local type_at_y_x = tr.by_pos[""][y][x] + assert(tr.types[type_at_y_x].str == "Point") + local fields = {} + for k, _ in pairs(tr.types[type_at_y_x].fields) do + table.insert(fields, k) + end + table.sort(fields) + assert.same(fields, {"init", "x", "y"}) + end) end) diff --git a/tl.lua b/tl.lua index f082dae71..8e9f08059 100644 --- a/tl.lua +++ b/tl.lua @@ -5860,7 +5860,7 @@ function TypeReporter:get_typenum(t) end if rt.typename == "typedecl" then - rt = rt.def + return self:get_typenum(rt.def) end local ti = { @@ -5916,6 +5916,13 @@ function TypeReporter:get_typenum(t) return n end +function TypeReporter:add_field(rtype, fname, ftype) + local n = self:get_typenum(rtype) + local ti = self.tr.types[n] + assert(ti.fields) + ti.fields[fname] = self:get_typenum(ftype) +end + @@ -6758,7 +6765,7 @@ local function show_type_base(t, short, seen) elseif t.typename == "enum" then return t.declname or "enum" elseif t.fields then - return short and t.typename or t.typename .. show_fields(t, show) + return short and (t.declname or t.typename) or t.typename .. show_fields(t, show) elseif t.typename == "function" then local out = { "function" } if t.typeargs then @@ -11981,6 +11988,10 @@ self:expand_type(node, values, elements) }) if self.feat_lax or rtype == open_t then rtype.fields[node.name.tk] = fn_type table.insert(rtype.field_order, node.name.tk) + + if self.collector then + self.env.reporter:add_field(rtype, node.name.tk, fn_type) + end else self.errs:add(node, "cannot add undeclared function '" .. node.name.tk .. "' outside of the scope where '" .. owner_name .. "' was originally declared") return diff --git a/tl.tl b/tl.tl index 1b3b28fd7..580e213fb 100644 --- a/tl.tl +++ b/tl.tl @@ -5860,7 +5860,7 @@ function TypeReporter:get_typenum(t: Type): integer end if rt is TypeDeclType then - rt = rt.def + return self:get_typenum(rt.def) end local ti: TypeInfo = { @@ -5916,6 +5916,13 @@ function TypeReporter:get_typenum(t: Type): integer return n end +function TypeReporter:add_field(rtype: RecordLikeType, fname: string, ftype: Type) + local n = self:get_typenum(rtype) + local ti = self.tr.types[n] + assert(ti.fields) + ti.fields[fname] = self:get_typenum(ftype) +end + local record TypeCollector record Symbol x: integer @@ -6758,7 +6765,7 @@ local function show_type_base(t: Type, short: boolean, seen: {Type:string}): str elseif t is EnumType then return t.declname or "enum" elseif t is RecordLikeType then - return short and t.typename or t.typename .. show_fields(t, show) + return short and (t.declname or t.typename) or t.typename .. show_fields(t, show) elseif t is FunctionType then local out: {string} = {"function"} if t.typeargs then @@ -11981,6 +11988,10 @@ do if self.feat_lax or rtype == open_t then rtype.fields[node.name.tk] = fn_type table.insert(rtype.field_order, node.name.tk) + + if self.collector then + self.env.reporter:add_field(rtype, node.name.tk, fn_type) + end else self.errs:add(node, "cannot add undeclared function '" .. node.name.tk .. "' outside of the scope where '" .. owner_name .. "' was originally declared") return From 78c08f9cc6154301f2b51dc21ed16a13f22db31b Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Thu, 3 Oct 2024 14:39:35 -0400 Subject: [PATCH 216/224] fix: do not corrupt record type object when assigning a method --- spec/declaration/record_function_spec.lua | 19 +++++++++++++++++++ tl.lua | 10 +++++++--- tl.tl | 12 ++++++++---- 3 files changed, 34 insertions(+), 7 deletions(-) diff --git a/spec/declaration/record_function_spec.lua b/spec/declaration/record_function_spec.lua index 79b81cb64..75a8ac8cf 100644 --- a/spec/declaration/record_function_spec.lua +++ b/spec/declaration/record_function_spec.lua @@ -147,5 +147,24 @@ describe("record function", function() function Fil:new_method2(self: Fil) end ]])) + + it("method assignment does not corrupt internal record data structure", util.check([[ + local interface MAI + x: integer + my_func: function(self, integer) + end + + local obj: MAI = { x = 20 } + + local record MR is MAI + b: string + end + + obj.my_func = function(self: MAI, n: integer) + end + + function MR:my_func(n: integer) + end + ]])) end) end) diff --git a/tl.lua b/tl.lua index 8e9f08059..80f45a403 100644 --- a/tl.lua +++ b/tl.lua @@ -9940,12 +9940,16 @@ a.types[i], b.types[i]), } local t, e = self:match_record_key(a, anode, rb.literal) if t then - if t.typename == "function" then - for i, p in ipairs(t.args.tuple) do + if t.typename == "function" and t.is_method then + local t2 = shallow_copy_new_type(t) + t2.args = shallow_copy_new_type(t.args) + t2.args.tuple = shallow_copy_table(t2.args.tuple) + for i, p in ipairs(t2.args.tuple) do if p.typename == "self" then - t.args.tuple[i] = a + t2.args.tuple[i] = a end end + return t2 end return t diff --git a/tl.tl b/tl.tl index 580e213fb..a766023ef 100644 --- a/tl.tl +++ b/tl.tl @@ -9940,12 +9940,16 @@ do local t, e = self:match_record_key(a, anode, rb.literal) if t then - if t is FunctionType then - for i, p in ipairs(t.args.tuple) do - if p is SelfType then - t.args.tuple[i] = a + if t is FunctionType and t.is_method then + local t2 = shallow_copy_new_type(t) + t2.args = shallow_copy_new_type(t.args) + t2.args.tuple = shallow_copy_table(t2.args.tuple) + for i, p in ipairs(t2.args.tuple) do + if p.typename == "self" then + t2.args.tuple[i] = a end end + return t2 end return t From 74a44d3ac389ca8cbb5e7aa0c4567d867f8b648e Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Thu, 3 Oct 2024 14:42:03 -0400 Subject: [PATCH 217/224] accept 'is userdata' syntax in records --- spec/declaration/record_spec.lua | 18 ++++++++++++++++++ tl.lua | 16 ++++++++++++++++ tl.tl | 16 ++++++++++++++++ 3 files changed, 50 insertions(+) diff --git a/spec/declaration/record_spec.lua b/spec/declaration/record_spec.lua index b9ea201ad..4a48886b5 100644 --- a/spec/declaration/record_spec.lua +++ b/spec/declaration/record_spec.lua @@ -857,6 +857,14 @@ for i, name in ipairs({"records", "arrayrecords", "interfaces", "arrayinterfaces end ]])) + it("can be declared as userdata with 'is'", util.check([[ + local type foo = ]]..statement..[[ ]]..array(i, "{foo}")..[[ + is userdata + x: number + y: number + end + ]])) + it("cannot be declared as userdata twice", util.check_syntax_error([[ local type foo = ]]..statement..[[ ]]..array(i, "{foo}")..[[ userdata @@ -868,6 +876,16 @@ for i, name in ipairs({"records", "arrayrecords", "interfaces", "arrayinterfaces { msg = "duplicated 'userdata' declaration" }, })) + it("cannot be declared as userdata with 'is' twice", util.check_syntax_error([[ + local type foo = ]]..statement..[[ ]]..array(i, "{foo}")..[[ + is userdata, userdata + x: number + y: number + end + ]], { + { msg = "duplicated 'userdata' declaration" }, + })) + it("untyped attributes are not accepted (#381)", util.check_syntax_error([[ local ]]..statement..[[ kons ]]..array(i, "{kons}")..[[ any_identifier other_sequence diff --git a/tl.lua b/tl.lua index 80f45a403..b36bb12b2 100644 --- a/tl.lua +++ b/tl.lua @@ -3865,6 +3865,18 @@ do return copy end + local function extract_userdata_from_interface_list(ps, i, def) + for j = #def.interface_list, 1, -1 do + local iface = def.interface_list[j] + if iface.typename == "nominal" and #iface.names == 1 and iface.names[1] == "userdata" then + table.remove(def.interface_list, j) + if def.is_userdata then + fail(ps, i, "duplicated 'userdata' declaration") + end + def.is_userdata = true + end + end + end parse_record_body = function(ps, i, def) def.fields = {} @@ -3896,6 +3908,10 @@ do else i, def.interface_list = parse_trying_list(ps, i, {}, parse_interface_name) end + + if def.interface_list then + extract_userdata_from_interface_list(ps, i, def) + end end if ps.tokens[i].tk == "where" then diff --git a/tl.tl b/tl.tl index a766023ef..6c161c37f 100644 --- a/tl.tl +++ b/tl.tl @@ -3865,6 +3865,18 @@ local function clone_typeargs(ps: ParseState, i: integer, typeargs: {TypeArgType return copy end +local function extract_userdata_from_interface_list(ps: ParseState, i: integer, def: RecordLikeType) + for j = #def.interface_list, 1, -1 do + local iface = def.interface_list[j] + if iface is NominalType and #iface.names == 1 and iface.names[1] == "userdata" then + table.remove(def.interface_list, j) + if def.is_userdata then + fail(ps, i, "duplicated 'userdata' declaration") + end + def.is_userdata = true + end + end +end parse_record_body = function(ps: ParseState, i: integer, def: RecordLikeType): integer, boolean def.fields = {} @@ -3896,6 +3908,10 @@ parse_record_body = function(ps: ParseState, i: integer, def: RecordLikeType): i else i, def.interface_list = parse_trying_list(ps, i, {}, parse_interface_name) end + + if def.interface_list then + extract_userdata_from_interface_list(ps, i, def) + end end if ps.tokens[i].tk == "where" then From d05df564970215b6848fffab355347db5548fcde Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Mon, 7 Oct 2024 09:13:37 -0700 Subject: [PATCH 218/224] fix: macroexp can resolve self for methods --- spec/code_gen/macroexp_spec.lua | 34 +++++++++++++++++++++++++++++++++ tl.lua | 16 +++++++++++++++- tl.tl | 16 +++++++++++++++- 3 files changed, 64 insertions(+), 2 deletions(-) diff --git a/spec/code_gen/macroexp_spec.lua b/spec/code_gen/macroexp_spec.lua index 83351d30d..c9668d2f8 100644 --- a/spec/code_gen/macroexp_spec.lua +++ b/spec/code_gen/macroexp_spec.lua @@ -80,5 +80,39 @@ describe("macroexp code generation", function() call_me(ok(8675309)) call_me(fail(911)) ]])) + + it("can resolve self for methods", util.gen([[ + local record R + x: number + + metamethod __call: function(self) = macroexp(self: R) + return print("R is " .. tostring(self.x) .. "!") + end + + get_x: function(self): number = macroexp(self: R): number + return self.x + end + end + + local r: R = { x = 10 } + print(r:get_x()) + r() + ]], [[ + + + + + + + + + + + + + local r: R = { x = 10 } + print(r.x) + print("R is " .. tostring(r.x) .. "!") + ]])) end) diff --git a/tl.lua b/tl.lua index b36bb12b2..6d3ce63f0 100644 --- a/tl.lua +++ b/tl.lua @@ -9555,7 +9555,21 @@ a.types[i], b.types[i]), } end if f and f.macroexp then - expand_macroexp(node, e2, f.macroexp) + local argexps + if is_method then + argexps = {} + if e1.kind == "op" then + table.insert(argexps, e1.e1) + else + table.insert(argexps, e1) + end + for _, e in ipairs(e2) do + table.insert(argexps, e) + end + else + argexps = e2 + end + expand_macroexp(node, argexps, f.macroexp) end return ret, f diff --git a/tl.tl b/tl.tl index 6c161c37f..ada740759 100644 --- a/tl.tl +++ b/tl.tl @@ -9555,7 +9555,21 @@ do end if f and f.macroexp then - expand_macroexp(node, e2, f.macroexp) + local argexps: {Node} + if is_method then + argexps = {} + if e1.kind == "op" then -- obj:method + table.insert(argexps, e1.e1) + else -- __call metamethod + table.insert(argexps, e1) + end + for _, e in ipairs(e2) do + table.insert(argexps, e) + end + else + argexps = e2 + end + expand_macroexp(node, argexps, f.macroexp) end return ret, f From 737a9876890774351848550acf488c2cb48b95bc Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Mon, 7 Oct 2024 11:02:00 -0700 Subject: [PATCH 219/224] fix: macroexp resolve metamethods correctly --- spec/code_gen/macroexp_spec.lua | 30 ++++++++++++++++++++++++++++++ tl.lua | 9 +++++++-- tl.tl | 9 +++++++-- 3 files changed, 44 insertions(+), 4 deletions(-) diff --git a/spec/code_gen/macroexp_spec.lua b/spec/code_gen/macroexp_spec.lua index c9668d2f8..fd5530bf9 100644 --- a/spec/code_gen/macroexp_spec.lua +++ b/spec/code_gen/macroexp_spec.lua @@ -114,5 +114,35 @@ describe("macroexp code generation", function() print(r.x) print("R is " .. tostring(r.x) .. "!") ]])) + + it("can resolve metamethods", util.gen([[ + local record R + x: number + + metamethod __lt: function(a: R, b: R): boolean = macroexp(a: R, b: R): boolean + return a.x < b.x + end + end + + local r: R = { x = 10 } + local s: R = { x = 20 } + if r > s then + print("yes") + end + ]], [[ + + + + + + + + + local r = { x = 10 } + local s = { x = 20 } + if s.x < r.x then + print("yes") + end + ]])) end) diff --git a/tl.lua b/tl.lua index 6d3ce63f0..20bf346ee 100644 --- a/tl.lua +++ b/tl.lua @@ -9576,7 +9576,7 @@ a.types[i], b.types[i]), } end end - function TypeChecker:check_metamethod(node, method_name, a, b, orig_a, orig_b) + function TypeChecker:check_metamethod(node, method_name, a, b, orig_a, orig_b, flipped) if self.feat_lax and ((a and is_unknown(a)) or (b and is_unknown(b))) then return a_type(node, "unknown", {}), nil end @@ -9604,6 +9604,9 @@ a.types[i], b.types[i]), } e2[2] = node.e2 args.tuple[2] = orig_b end + if flipped then + e2[2], e2[1] = e2[1], e2[2] + end local mtdelta = metamethod.typename == "function" and metamethod.is_method and -1 or 0 local ret_call = self:type_check_function_call(node, metamethod, args, mtdelta, node, e2) @@ -12450,15 +12453,17 @@ self:expand_type(node, values, elements) }) local meta_on_operator if not t then local mt_name = binop_to_metamethod[node.op.op] + local flipped = false if not mt_name then mt_name = flip_binop_to_metamethod[node.op.op] if mt_name then + flipped = true ra, rb = rb, ra ua, ub = ub, ua end end if mt_name then - t, meta_on_operator = self:check_metamethod(node, mt_name, ra, rb, ua, ub) + t, meta_on_operator = self:check_metamethod(node, mt_name, ra, rb, ua, ub, flipped) end if not t then t = self.errs:invalid_at(node, "cannot use operator '" .. node.op.op:gsub("%%", "%%%%") .. "' for types %s and %s", ua, ub) diff --git a/tl.tl b/tl.tl index ada740759..ed85a9958 100644 --- a/tl.tl +++ b/tl.tl @@ -9576,7 +9576,7 @@ do end end - function TypeChecker:check_metamethod(node: Node, method_name: string, a: Type, b: Type, orig_a: Type, orig_b: Type): Type, integer + function TypeChecker:check_metamethod(node: Node, method_name: string, a: Type, b: Type, orig_a: Type, orig_b: Type, flipped?: boolean): Type, integer if self.feat_lax and ((a and is_unknown(a)) or (b and is_unknown(b))) then return an_unknown(node), nil end @@ -9604,6 +9604,9 @@ do e2[2] = node.e2 args.tuple[2] = orig_b end + if flipped then + e2[2], e2[1] = e2[1], e2[2] + end local mtdelta = metamethod is FunctionType and metamethod.is_method and -1 or 0 local ret_call = self:type_check_function_call(node, metamethod, args, mtdelta, node, e2) @@ -12450,15 +12453,17 @@ do local meta_on_operator: integer if not t then local mt_name = binop_to_metamethod[node.op.op] + local flipped = false if not mt_name then mt_name = flip_binop_to_metamethod[node.op.op] if mt_name then + flipped = true ra, rb = rb, ra ua, ub = ub, ua end end if mt_name then - t, meta_on_operator = self:check_metamethod(node, mt_name, ra, rb, ua, ub) + t, meta_on_operator = self:check_metamethod(node, mt_name, ra, rb, ua, ub, flipped) end if not t then t = self.errs:invalid_at(node, "cannot use operator '" .. node.op.op:gsub("%%", "%%%%") .. "' for types %s and %s", ua, ub) From ab9bf51643331ca584f3ba477fde901654751d97 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Mon, 7 Oct 2024 12:34:55 -0700 Subject: [PATCH 220/224] types: reports reference of a nominal type --- spec/api/get_types_spec.lua | 33 +++++++++++++++++++++++++++++++++ tl.lua | 11 +++++++++++ tl.tl | 11 +++++++++++ 3 files changed, 55 insertions(+) diff --git a/spec/api/get_types_spec.lua b/spec/api/get_types_spec.lua index d2aea26d1..7b91ef91f 100644 --- a/spec/api/get_types_spec.lua +++ b/spec/api/get_types_spec.lua @@ -61,4 +61,37 @@ describe("tl.get_types", function() table.sort(fields) assert.same(fields, {"init", "x", "y"}) end) + + it("reports reference of a nominal type", function() + local env = tl.init_env() + env.report_types = true + local result = assert(tl.check_string([[ + local record Operator + operator: string + end + + local record Node + node1: Node + operator: Operator + end + + local function node_is_require_call(n: Node): string + if n.operator and n.operator.operator == "." then + return node_is_require_call(n.node1) + end + end + ]], env)) + + local tr, trenv = tl.get_types(result) + local y = 7 + local x = 24 + local type_at_y_x = tr.by_pos[""][y][x] + local ti = tr.types[type_at_y_x] + assert(ti) + assert.same(ti.str, "Operator") + assert(ti.ref) + local ti_ref = tr.types[ti.ref] + assert(ti ~= ti.ref) + assert.same(ti_ref.str, "Operator") + end) end) diff --git a/tl.lua b/tl.lua index 20bf346ee..8b9cef708 100644 --- a/tl.lua +++ b/tl.lua @@ -5939,6 +5939,12 @@ function TypeReporter:add_field(rtype, fname, ftype) ti.fields[fname] = self:get_typenum(ftype) end +function TypeReporter:set_ref(nom, resolved) + local n = self:get_typenum(nom) + local ti = self.tr.types[n] + ti.ref = self:get_typenum(resolved) +end + @@ -8028,6 +8034,10 @@ do t.found = found + if self.collector then + self.env.reporter:set_ref(t, found) + end + return nil, found end @@ -8044,6 +8054,7 @@ do end t.resolved = resolved + return resolved end diff --git a/tl.tl b/tl.tl index ed85a9958..d87ffe46d 100644 --- a/tl.tl +++ b/tl.tl @@ -5939,6 +5939,12 @@ function TypeReporter:add_field(rtype: RecordLikeType, fname: string, ftype: Typ ti.fields[fname] = self:get_typenum(ftype) end +function TypeReporter:set_ref(nom: NominalType, resolved: Type) + local n = self:get_typenum(nom) + local ti = self.tr.types[n] + ti.ref = self:get_typenum(resolved) +end + local record TypeCollector record Symbol x: integer @@ -8028,6 +8034,10 @@ do t.found = found + if self.collector then + self.env.reporter:set_ref(t, found) + end + return nil, found end @@ -8044,6 +8054,7 @@ do end t.resolved = resolved + return resolved end From 2c5219fb51a7f856f72b561e41e4a3c6aab7606e Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Mon, 7 Oct 2024 11:02:33 -0700 Subject: [PATCH 221/224] docs: macro expressions --- docs/macroexp.md | 105 +++++++++++++++++++++++++++++++++++++++++++++++ docs/tutorial.md | 3 ++ 2 files changed, 108 insertions(+) create mode 100644 docs/macroexp.md diff --git a/docs/macroexp.md b/docs/macroexp.md new file mode 100644 index 000000000..42b7a0c9b --- /dev/null +++ b/docs/macroexp.md @@ -0,0 +1,105 @@ +# Macro expressions + +Teal supports a restricted form of macro expansion via the `macroexp` +construct, which declares a macro expression. This was added to the +language as the support mechanism for implementing the `where` clauses +in records and interfaces, which power the type resolution performed +by the `is` operator. + +Macro expressions are always expanded inline in the generated Lua code. +The declaration itself produces no Lua code. + +A macro expression is declared similarly to a function, only using +`macroexp` instead of `function`: + +```lua +local macroexp add(a: number, b: number) + return a + b +end +``` + +There are two important restrictions: + +* the body of the macro expression can only contain a single `return` + statement with a single expression; +* each argument can only be used once in the macroexp body. + +The latter restriction allows for macroexp calls to be expanded inline in any +expression context, without the risk for producing double evaluation of +side-effecting expressions. This avoids the pitfalls commonly produced by C +macros in a simple way. + +Because macroexps do not generate code on declaration, you can also +declare a macroexp inline in a record definition: + +```lua +local record R + x: number + + get_x: function(self): number = macroexp(self: R): number + return self.x + end +end + +local r: R = { x = 10 } +print(r:get_x()) +``` + +This generates the following code: + +```lua +local r: R = { x = 10 } +print(r.x) +``` + +You can also use them for metamethods: this will cause the metamethod to +be expanded at compile-time, without requiring a metatable: + +```lua +local record R + x: number + + metamethod __lt: function(a: R, b: R) = macroexp(a: R, b: R) + return a.x < b.x + end +end + +local r: R = { x = 10 } +local s: R = { x = 20 } +if r > s then + print("yes") +end +``` + +This generates the following code: + +```lua +local r = { x = 10 } +local s = { x = 20 } +if s.x < r.x then + print("yes") +end +``` + +This is used to implement the pseudo-metamethod `__is`, which is used to +resolve the `is` operator. The `where` construct is syntax sugar to an +`__is` declaration, meaning the following two constructs are equivalent: + +```lua +local record MyRecord is MyInterface + where self.my_field == "my_record" +end + +-- ...is the same as: + +local record MyRecord is MyInterface + metamethod __is: function(self: MyRecord): boolean = macroexp(self: MyRecord): boolean + return self.my_field == "my_record" + end +end +``` + +At this time, macroexp declarations within records do not allow inference, +so the `function` type needs to be explicitly declared when implementinga +a field or metamethod as a `macroexp`. This requirement may be dropped in +the future. diff --git a/docs/tutorial.md b/docs/tutorial.md index 18572c1ac..70acfef62 100644 --- a/docs/tutorial.md +++ b/docs/tutorial.md @@ -705,6 +705,9 @@ end The only changes made to the code above were the addition of type signatures in both function declarations. +Teal also supports [macro expressions](macroexp.md), which are a restricted +form of function whose contents are expanded inline when generating Lua code. + ### Variadic functions Just like in Lua, some functions in Teal may receive a variable amount of From 4fa35a45652d69abb8e98f2f1f3b8ba8c0754553 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Thu, 3 Oct 2024 14:49:19 -0400 Subject: [PATCH 222/224] docs: updates for next release --- docs/tutorial.md | 205 ++++++++++++++++++++++++++++++++++++----- docs/type_variables.md | 119 ++++++++++++++++++++++++ 2 files changed, 300 insertions(+), 24 deletions(-) create mode 100644 docs/type_variables.md diff --git a/docs/tutorial.md b/docs/tutorial.md index 70acfef62..003f6d3ab 100644 --- a/docs/tutorial.md +++ b/docs/tutorial.md @@ -152,9 +152,9 @@ list with a few examples of each; we'll discuss them in more detail below: Finally, there are types that must be declared and referred to using names: -* enum -* record - * userdata +* enums +* records +* interfaces Here is an example declaration of each. Again, we'll go into more detail below, but this should give you an overview: @@ -172,17 +172,35 @@ local record Point y: number end +-- an interface: an abstract record type +local interface Character + sprite: Image + position: Point + kind: string +end + +-- records can implement interfaces, using a type-identifying `where` clause +local record Spaceship + is Character + where self.kind == "spaceship" + + weapon: Weapons +end + +-- a record can also declare an array interface, making it double as a record and an array +local record TreeNode + is {TreeNode} + + item: T +end + -- a userdata record: a record which is implemented as a userdata local record File - userdata + is userdata + status: function(): State close: function(File): boolean, string end - --- a record can double as a record and an array, by declaring an array interface -local record TreeNode is {TreeNode} - item: T -end ``` ## Local variables @@ -528,6 +546,75 @@ local x: http.Response = http.get("http://example.com") print(x.status_code) ``` +## Interfaces + +Interfaces are, in essence, abstract records. + +A concrete record is a type declared with `record`, which can be used +both as a Lua table and as a type. In object-oriented terms, the record +itself works as class whose fields work as class attributes, +while other tables declared with the record type are objects whose +fields are object atributes. For example: + +```lua +local record MyConcreteRecord + a: string + x: integer + y: integer +end + +MyConcreteRecord.a = "this works" + +local obj: MyConcreteRecord = { x = 10, y = 20 } -- this works too +``` + +An interface is abstract. It can declare fields, including those of +`function` type, but they cannot hold concrete values on their own. +Instances of an interface can hold values. + +```lua +local interface MyAbstractInterface + a: string + x: integer + y: integer + my_func: function(self, integer) +end + +MyAbstractInterface.a = "this doesn't work" -- error! + +local obj: MyAbstractInterface = { x = 10, y = 20 } -- this works + +-- error! this doesn't work +function MyAbstractInterface:my_func(n: integer) +end + +-- however, this works +obj.my_func = function(self: MyAbstractInterface, n: integer) +end +``` + +What is most useful about interfaces is that records can inherit +interfaces, using `is`: + +```lua +local record MyRecord is MyAbstractInterface + b: string +end + +local r: MyRecord = {} +r.b = "this works" +r.a = "this works too because 'a' comes from MyAbstractInterface" +``` + +Note that this refers strictly to subtyping of interfaces, not +inheritance of implementations. You cannot use `is` to do +`local MyRecord is AnotherRecord`, as Teal does not implement +a class/object model of its own, as it aims to be compatible +with the multiple class/object models that exist in the Lua +ecosystem. + + + ## Generics Teal supports a simple form of generics that is useful enough for dealing @@ -564,6 +651,33 @@ local t: Tree = { } ``` +A type variable can be constrained by an interface, using `is`: + +```lua +local function largest_shape(shapes: {S}): S + local max = 0 + local largest: S + for _, s in ipairs(shapes) do + if s.area >= max then + max = s.area + largest = s + end + end + return largest +end +``` + +The benefit of doing this instead of `largest_shape(shapes: {Shape}): Shape` +is that, if you call this function passing, say, an array `{Circle}` +(assuming that `record Circle is Shape`, Teal will infer `S` to `Circle`, +and that will be the type of the return value, while still allowing you +to use the specifics of the `Shape` interface within the implementation of +`largest_shape`. + +Keep in mind though, the type variables are inferred upon their first match, +so, especially when using constraints, that might demand [additional +care](type_variables.md). + ## Metamethods Lua supports metamethods to provide some advanced features such as operator @@ -661,19 +775,31 @@ well: ```lua local type Comparator = function(T, T): boolean -local function mysort(arr: {A}, cmp: Comparator) +local function mysort(arr: {A}, cmp?: Comparator) -- ... end ``` +Note that functions can have optional arguments, as in the `cmp?` example above. +This only affects the _arity_ of the functions (that is, the number of arguments +passed to a function), not their types. Note that the question mark is assigned +to the argument name, not its type. If an argument is not optional, it may still +be given explicitly as `nil`. + Another thing to know about function declarations is that you can parenthesize the declaration of return types, to avoid ambiguities when using nested declarations and multiple returns: ```lua -f: function(function():(number, number), number) +f: function(function(? string):(number, number), number) ``` +Note also that in this example the string argument of the return function type +is optional. When declaring optional arguments in function type declarations +which do not use argument names, The question mark is placed ahead of the +type. Again, this is an attribute of the argument position, not of the +argument type itself. + You can declare functions that generate iterators which can be used in `for` statements: the function needs to produce another function that iterates. This is an example [taken the book "Programming in Lua"](https://www.lua.org/pil/7.1.html): @@ -797,29 +923,60 @@ for union types in Teal. The first one is that the `is` operator always matches a variable, not arbitrary expressions. This limitation is there to avoid aliasing. -Since code generation for the `is` operator used for discrimination of union -types translates into a runtime `type()` check, we can only discriminate -across primitive types and at most one table type. +The second one is that Teal only accepts unions over a set of types that +it can discriminate at runtime, so that it can generate code for the +`is` operator properly. That means we can either only use one table +type in a union, or, if we want to use multiple table types in a union, +they need to be records or interfaces that were declared with a `where` +annotation to discriminate them. This means that these unions not accepted: ```lua -local invalid1: MyRecord | MyOtherRecord -local invalid2: {string} | {number} -local invalid3: {string} | {string:string} -local invalid4: {string} | MyRecord +local invalid1: {string} | {number} +local invalid2: {string} | {string:string} +local invalid3: {string} | MyRecord +``` + +However, the following union can be accepted, if we declare the record +types with `where` annotations: + +``` +local interface Named + name: string +end + +local record MyRecord is Named + where self.name == "MyRecord" +end + +local record AnotherRecord is Named + where self.name == "AnotherRecord" +end + +local valid: MyRecord | MyOtherRecord ``` -Also, since `is` checks for enums currently also translate into `type()` checks, -this means they are indistinguishable from strings at runtime. So, for now this -is also not accepted: +A `where` clause is any Teal expression that uses the identifier `self` +at most once (if you need to use it multiple times, you can always write +a function that implements the discriminator expression and call that +in the `where` clause passing `self` as an argument). + +Note that Teal has no way of proving at compile time that the set of `where` +clauses in the union is actually disjoint and can discriminate the values +correctly at runtime. Like the other aspects of setting up a Lua-based +object model, that is up to you. + +Another limitation on `is` checks comes up with enums, since these also +translate into `type()` checks. This means they are indistinguishable from +strings at runtime. So, for now these are also not accepted: ```lua -local invalid5: string | MyEnum +local invalid4: string | MyEnum +local invalid5: MyEnum | AnotherEnum ``` -This restriction between strings and enums may be removed in the future. The -restriction on records may also be lifted in the future. +This restriction on enums may be removed in the future. ## The type `any` diff --git a/docs/type_variables.md b/docs/type_variables.md new file mode 100644 index 000000000..0855ff2f3 --- /dev/null +++ b/docs/type_variables.md @@ -0,0 +1,119 @@ +# Type Variable Matching + +When Teal type-checks a generic function call, it infers any type variables +based on context. Type variables can appear in function arguments and return +types, so these are matched with the information available at the call site: + +* the place where the function call is made is used to infer + type variables in return types; +* the values passed as arguments are used to infer type variables + appearing in function arguments. + +For example, given a generic function with the following type: + +```lua +local my_f: function(T): U +``` + +...the following call will infer `T` to `boolean` and `U` +to `string`. + +``` +local s: string = my_f(true) +``` + +Note that each type variable is inferred upon its first match, and return +types are inferred first, then argument types. This means that if the type +signature was instead this: + +```lua +local my_f: function(T): T +``` + +then the call above would fail with an error like `argument 1: got boolean, +expected string`. + +Matching multiple type variables to types requires particular care when +type variables with `is`-constraints are used multiple types. Consider +the following example, which probably does not do what you want: + +```lua +local interface Shape + area: number +end + +local function largest_shape(a: S, b: S): S + if a.area > b.area then + return a + else + return b + end +end +``` + +When attempting to use this with different kinds of shapes at the same time, +we will get an error: + +```lua +local record Circle is Shape +end + +local record Square is Shape +end + +local c: Circle = { area = 10 } +local s: Square = { area = 20 } + +local l = largest_shape(c, s) -- error! argument 2: Square is not a Circle +``` + +The type variable `S` was matched to `c` first. We can instead do this: + +```lua +local function largest_shape(a: S, b: T): S | T + if a.area > b.area then + return a + else + return b + end +end +``` + +But then we have to make records that can be discriminated in a union, +by giving their definitions `where` clauses. This is a possible solution: + +```lua +-- we add a `name` to the interface +local interface Shape + name: string + area: number +end + +local function largest_shape(a: S, b: T): S | T + if a.area > b.area then + return a + else + return b + end +end + +-- we add `where` clauses to Circle and Square +local record Circle + is Shape + where self.name == "circle" +end + +local record Square + is Shape + where self.name == "square" +end + +-- we add the `name` fields so that the tables conform to their types; +-- in larger programs this would be typically done in constructor functions +local c: Circle = { area = 10, name = "circle" } +local s: Square = { area = 20, name = "square" } + +local l = largest(c, s) +``` + +...which results in `l` having type `Circle | Square`. From e289caf30f7a58b0567d074c43c258d0c6ee4a99 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Mon, 7 Oct 2024 11:31:12 -0700 Subject: [PATCH 223/224] docs: self type --- docs/tutorial.md | 25 +++++++++++++++++++------ 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/docs/tutorial.md b/docs/tutorial.md index 003f6d3ab..067471045 100644 --- a/docs/tutorial.md +++ b/docs/tutorial.md @@ -578,6 +578,7 @@ local interface MyAbstractInterface x: integer y: integer my_func: function(self, integer) + another_func: function(self, integer, self) end MyAbstractInterface.a = "this doesn't work" -- error! @@ -606,14 +607,26 @@ r.b = "this works" r.a = "this works too because 'a' comes from MyAbstractInterface" ``` -Note that this refers strictly to subtyping of interfaces, not -inheritance of implementations. You cannot use `is` to do -`local MyRecord is AnotherRecord`, as Teal does not implement -a class/object model of its own, as it aims to be compatible -with the multiple class/object models that exist in the Lua -ecosystem. +Keep in mind that this refers strictly to subtyping of interfaces, not +inheritance of implementations. You cannot use `is` to do `local MyRecord is +AnotherRecord`, as Teal does not implement a class/object model of its own, as +it aims to be compatible with the multiple class/object models that exist in +the Lua ecosystem. +Note also that the definition of `my_func` used `self` as a type name. `self` +is a valid type that can be used when declaring arguments in functions +declared in interfaces and records. When a record is declared to be a subtype +of an interface using `is`, any function arguments using `self` in the parent +interface type will then resolve to the child record's type. The type signature +of `another_func` makes it even more evident: +```lua +-- the following function complies to the type declared for `another_func` +-- in MyAbstractInterface, because MyRecord is the `self` type in this context +function MyRecord:another_func(n: integer, another: MyRecord) + print(n + self.x, another.b) +end +``` ## Generics From 261dacde799ee440bd38aa83b96b6100af9eeb81 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Mon, 7 Oct 2024 12:02:18 -0700 Subject: [PATCH 224/224] docs: prepare changelog for Teal 0.24.0 --- CHANGELOG.md | 120 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 120 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index b3cffc0f7..74416a6de 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,123 @@ +# 0.24.0 - Teal Spring '24 + +2024-10-07 + +This a big release! It is a culmination of work that has spanned multiple +years, and it would not be possible without the amazing feedback given by +the community on the `next` branch, where this was developed. This release +does include that "lot of new code" that I wanted to merge in the "near +future", as stated in the previous release's changelog. + +The main feature is the addition of interfaces, which introduces a model for +subtyping table types in the language. That should allow for representing the +various kinds of object models that are used across the Lua ecosystem, without +promoting any particular class/object system over the others. + +This release features commits by François Perrad, Victor Ilchev and Hisham +Muhammad. + +## What's New + +### Language + +* **Interfaces**: you can now declare abstract interfaces, and record types + can implement them. This allows you to declare subtyping relations, and + better support inheritance models (e.g. `local record Circle is Shape`). + * Records and interfaces may declare a `where` clause with an expression + over `self`, which allows the `is` operator to discriminate the type. + * With discriminated record types, it is now possible to declare unions + over multiple record types. + * "Arrayrecords" are no longer a distinct type: they are just records + that implement an array interface (e.g. `local record R is {T}`) + * Type variables in functions can now have constraints on interfaces + (e.g. `function my_func(...)`) + * `self` is now a valid type that can be used when declaring arguments + in functions declared in interfaces and records + * When a record is declared to be a subtype of an interface using `is`, + any function arguments using `self` in the parent interface type will + then resolve to the child record's type. +* **Optional/required arguments in function calls** - functions may declare + arguments as optional, affecting the required arity of function calls. + The rightmost arguments of a function can have their variable names + annotated with a `?` sign, indicating that the argument does not need + to be passed in a function call. This refers only to the presence of the + argument in a call, and not to its type or nullability: a required + argument may still be called with an explicit `null`, but an optional + argument may be elided. + * In previous versions of Teal, all function arguments were effectively + optional. This means that Teal is now stricter which checking for + function calls. You can use `--feat-arity=off` in the command line + or `feat_arity = "off"` in `tlconfig.lua` to obtain the previous + behavior. + * To convert code that uses the old arity checking rules to the new + behavior, you can also use compiler pragmas in the code, + `--#pragma arity off` and `--#pragma arity off` to disable or enable + stricter arity checks. +* **Macro expressions**: you can declare a restricted form of function called + a `macroexp`, which is always expanded inline. These can also be used + to declare compile-time metamethods, which expand without requiring + a metatable at runtime. The `where` clauses used in interfaces and + records are syntax sugar for macro expressions that implement a + pseudo-metamethod `__is`. +* Dynamic `require` calls that do not take a module name as a literal + string are now allowed, and return `any`; you can load a static type + definition using `local type MyType = require("...")` and then cast + your dynamic require like so: `local my_mod = require(var) as MyType` +* The type system is more nominal: all named types are now treated nominally, + except for unions, which are always type aliases. Previously record types + were nominal, but a named typed that resolved to a primitive such as + `integer` was structural. There are still subleties on the rules for when + a `local type` produces a new distinct type or an alias; they are + [explained in the docs](docs/aliasing.md). +* The `` attribute for variables can now only be applied when + initializing variables with literal tables. This is a minor breaking change, + but the usefulness of this attribute in other cases was very limited, + and it also produced misleading results. +* Improved type signatures in the standard library + * `select` produces variadic returns + +### API + +* Simplified API in the `tl` module. + * The 0.15 API is still supported for backwards compatibility. + +### Tooling + +* `tl build` was removed. [Cyan](https://github.com/teal-language/cyan) + should be used instead as the build tool for Teal projects. + +### Fixes + +* Fix commits included in this release refer both to bugs present + in the `master` branch as well as fixes specific to the 0.24 series, + reported by the community during the beta testing of this release. + The Git history and the GitHub issues list contain a more detailed + accounting of the bugfixes that went into this release. + Some of the fixes include: + * `tl check` now reports if the input file does not exist. + * Reporting location of the end of an `if` block correctly. + * No longer crash if a `require()` epression fails to resolve (#778). + * `tl types` now reports the types of variables in `for` loops. + * Improved error message when calling functions with insufficient + arguments. + * Disallowing using a base type as a type variable name. + * Type arguments resolving correctly in recursive functions. + * Type arguments resolving correctly in nested records (#754). + * `or` type inference between types with a subtyping relation + resolves to whichever is the larger type. + * Reporting an error if an iterator used in a generic `for` does not + declare a return value type (#736). + * When checking `` values, no longer reporting record + function and metamethods as missing fields (#749). + * Localizing a record no longer makes the local a type (#759). + * Bad assignments of record tables are reported (#752). + * Nominal type alias declarations work as expected (#238). + * Nominals with generics can be resolved correcty (#777). + * Nested types are not closed too early (#775). + * Not failing when resolving nested empty tables. + * Exporting generics and type aliases from modules correctly + using `return` at the toplevel (#804). + # 0.15.3 2023-11-05