From 36f3e506befcdb5f4ee46689d86615c451692f35 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Sun, 13 Oct 2024 05:33:43 -0300 Subject: [PATCH] fix: generate __is-aware code for `is` on unions This expands an `is` applied to a union (even in behind an alias) into a chain of `or` tests with individual `is` checks on the types of the union. If the individual entries have __is metamethods, expands their code as well. This implementation is not recursive (i.e. it does not handle unions of unions), but it is a major improvement over the previous behavior. Fixes #742. --- spec/lang/operator/is_spec.lua | 118 +++++++++++++++++++++++++++++++++ tl.lua | 44 ++++++++++-- tl.tl | 44 ++++++++++-- 3 files changed, 198 insertions(+), 8 deletions(-) diff --git a/spec/lang/operator/is_spec.lua b/spec/lang/operator/is_spec.lua index 154ba657..c771bfc1 100644 --- a/spec/lang/operator/is_spec.lua +++ b/spec/lang/operator/is_spec.lua @@ -649,6 +649,124 @@ end]])) end end ]])) + + it("generates type checks expanding unions (#742)", util.gen([[ + global record Foo + bar: string + end + + global function repro(x:Foo | string | nil): integer + local y = x + if y is string | Foo then + return 1 + elseif y is nil then + return 2 + end + return 3 + end + ]], [[ + Foo = {} + + + + function repro(x) + local y = x + if type(y) == "string" or type(y) == "table" then + return 1 + elseif y == nil then + return 2 + end + return 3 + end + ]])) + + it("generates type checks applying __is to discriminated records in unions", util.gen([[ + local interface Type + typename: string + end + + local record FooType is Type where self.typename == "foo" + end + + local record BarType is Type where self.typename == "bar" + end + + global function repro(x:Type | string | nil): integer + local y = x + if y is FooType | BarType then + return 1 + elseif y is nil then + return 2 + end + return 3 + end + ]], [[ + + + + + + + + + + + function repro(x) + local y = x + if y.typename == "foo" or y.typename == "bar" then + return 1 + elseif y == nil then + return 2 + end + return 3 + end + ]])) + + it("generates type checks applying __is to discriminated records in unions expanding alias", util.gen([[ + local interface Type + typename: string + end + + local record FooType is Type where self.typename == "foo" + end + + local record BarType is Type where self.typename == "bar" + end + + local type FooBar = FooType | BarType + + global function repro(x:Type | string | nil): integer + local y = x + if y is FooBar then + return 1 + elseif y is nil then + return 2 + end + return 3 + end + ]], [[ + + + + + + + + + + + + + function repro(x) + local y = x + if y.typename == "foo" or y.typename == "bar" then + return 1 + elseif y == nil then + return 2 + end + return 3 + end + ]])) end) end) diff --git a/tl.lua b/tl.lua index 287cc002..fdaae079 100644 --- a/tl.lua +++ b/tl.lua @@ -8063,7 +8063,7 @@ do local immediate, found = find_nominal_type_decl(self, nom) - if type(immediate) == "table" then + if immediate and (immediate.typename == "invalid" or immediate.typename == "typedecl") then return immediate end @@ -9670,6 +9670,36 @@ a.types[i], b.types[i]), } end end + local function make_is_node(self, var, v, t) + local node = node_at(var, { kind = "op", op = { op = "is", arity = 2, prec = 3 } }) + node.e1 = var + node.e2 = node_at(var, { kind = "cast", casttype = self:infer_at(var, t) }) + self:check_metamethod(node, "__is", self:to_structural(v), self:to_structural(t), v, t) + if node.expanded then + apply_macroexp(node) + end + node.known = IsFact({ var = var.tk, typ = t, w = node }) + return node + end + + local function convert_is_of_union_to_or_of_is(self, node, v, u) + local var = node.e1 + node.op.op = "or" + node.op.arity = 2 + node.op.prec = 1 + node.e1 = make_is_node(self, var, v, u.types[1]) + local at = node + local n = #u.types + for i = 2, n - 1 do + at.e2 = node_at(var, { kind = "op", op = { op = "or", arity = 2, prec = 1 } }) + at.e2.e1 = make_is_node(self, var, v, u.types[i]) + node.known = OrFact({ f1 = at.e1.known, f2 = at.e2.known, w = node }) + at = at.e2 + end + at.e2 = make_is_node(self, var, v, u.types[n]) + node.known = OrFact({ f1 = at.e1.known, f2 = at.e2.known, w = node }) + end + function TypeChecker:match_record_key(tbl, rec, key) assert(type(tbl) == "table") assert(type(rec) == "table") @@ -12320,9 +12350,15 @@ self:expand_type(node, values, elements) }) if rb.typename == "integer" then self.all_needs_compat["math"] = true end - if node.e1.kind == "variable" then - self:check_metamethod(node, "__is", ra, resolve_typedecl(rb), ua, ub) - node.known = IsFact({ var = node.e1.tk, typ = ub, w = node }) + if ra.typename == "typedecl" then + self.errs:add(node, "can only use 'is' on variables, not types") + elseif node.e1.kind == "variable" then + if rb.typename == "union" then + convert_is_of_union_to_or_of_is(self, node, ra, rb) + else + self:check_metamethod(node, "__is", ra, resolve_typedecl(rb), ua, ub) + node.known = IsFact({ var = node.e1.tk, typ = ub, w = node }) + end else self.errs:add(node, "can only use 'is' on variables") end diff --git a/tl.tl b/tl.tl index 3b4755f5..ab3ce77d 100644 --- a/tl.tl +++ b/tl.tl @@ -8063,7 +8063,7 @@ do local immediate, found = find_nominal_type_decl(self, nom) -- if it was previously resolved (or a circular require, or an error), return that; - if immediate is InvalidOrTypeDeclType then + if immediate and immediate is InvalidOrTypeDeclType then return immediate end @@ -9670,6 +9670,36 @@ do end end + local function make_is_node(self: TypeChecker, var: Node, v: Type, t: Type): Node + local node = node_at(var, { kind = "op", op = { op = "is", arity = 2, prec = 3 } }) + node.e1 = var + node.e2 = node_at(var, { kind = "cast", casttype = self:infer_at(var, t) }) + self:check_metamethod(node, "__is", self:to_structural(v), self:to_structural(t), v, t) + if node.expanded then + apply_macroexp(node) + end + node.known = IsFact { var = var.tk, typ = t, w = node } + return node + end + + local function convert_is_of_union_to_or_of_is(self: TypeChecker, node: Node, v: Type, u: UnionType) + local var = node.e1 + node.op.op = "or" + node.op.arity = 2 + node.op.prec = 1 + node.e1 = make_is_node(self, var, v, u.types[1]) + local at = node + local n = #u.types + for i = 2, n - 1 do + at.e2 = node_at(var, { kind = "op", op = { op = "or", arity = 2, prec = 1 } }) + at.e2.e1 = make_is_node(self, var, v, u.types[i]) + node.known = OrFact { f1 = at.e1.known, f2 = at.e2.known, w = node } + at = at.e2 + end + at.e2 = make_is_node(self, var, v, u.types[n]) + node.known = OrFact { f1 = at.e1.known, f2 = at.e2.known, w = node } + end + function TypeChecker:match_record_key(tbl: Type, rec: Node, key: string): Type, string assert(type(tbl) == "table") assert(type(rec) == "table") @@ -12320,9 +12350,15 @@ do if rb.typename == "integer" then self.all_needs_compat["math"] = true end - if node.e1.kind == "variable" then - self:check_metamethod(node, "__is", ra, resolve_typedecl(rb), ua, ub) - node.known = IsFact { var = node.e1.tk, typ = ub, w = node } + if ra is TypeDeclType then + self.errs:add(node, "can only use 'is' on variables, not types") + elseif node.e1.kind == "variable" then + if rb is UnionType then + convert_is_of_union_to_or_of_is(self, node, ra, rb) + else + self:check_metamethod(node, "__is", ra, resolve_typedecl(rb), ua, ub) + node.known = IsFact { var = node.e1.tk, typ = ub, w = node } + end else self.errs:add(node, "can only use 'is' on variables") end