Skip to content

Commit

Permalink
local type require() accepts dot notation for nested record
Browse files Browse the repository at this point in the history
`local type MyType = require("module").MyType` is now valid.

Closes #778.
  • Loading branch information
hishamhm committed Aug 5, 2024
1 parent df61b5f commit 09d0473
Show file tree
Hide file tree
Showing 4 changed files with 143 additions and 13 deletions.
11 changes: 11 additions & 0 deletions spec/cli/types_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,17 @@ describe("tl types works like check", function()
util.assert_popen_close(0, pd:close())
-- TODO check json output
end)

it("does not crash when a require() expression does not resolve (#778)", function()
local name = util.write_tmp_file(finally, [[
local type Foo = require("missingmodule").baz
]])
local pd = io.popen(util.tl_cmd("types", name, "--gen-target=5.1") .. "2>&1 1>" .. util.os_null, "r")
local output = pd:read("*a")
util.assert_popen_close(1, pd:close())
assert.match("2 errors:", output, 1, true)
-- TODO check json output
end)
end)

describe("on .lua files", function()
Expand Down
51 changes: 51 additions & 0 deletions spec/stdlib/require_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -1152,4 +1152,55 @@ describe("require", function()
assert.same({}, result.env.loaded["./types/person.tl"].type_errors)
end)
end)

it("in 'local type' accepts dots for extracting nested types", function ()
-- ok
util.mock_io(finally, {
["mod.tl"] = [[
local record mod
record Foo<K>
something: K
end
end
return mod
]],
["main.tl"] = [[
local type Foo = require("mod").Foo
local function f(v: Foo)
print(v.something)
end
]],
})
local result, err = tl.process("main.tl")

assert.same({}, result.syntax_errors)
assert.same({}, result.type_errors)
end)

it("in 'local type' does not accept arbitrary expressions", function ()
-- ok
util.mock_io(finally, {
["mod.tl"] = [[
local record mod
record Foo<K>
something: K
end
end
return mod
]],
["main.tl"] = [[
local type Foo = require("mod") + "hello"
local function f(v: Foo)
print(v.something)
end
]],
})
local result, err = tl.process("main.tl")

assert.same({
{ filename = "main.tl", x = 30, y = 1, msg = "require() in type declarations cannot be part of larger expressions" }
}, result.syntax_errors)
end)
end)
46 changes: 40 additions & 6 deletions tl.lua
Original file line number Diff line number Diff line change
Expand Up @@ -2808,18 +2808,25 @@ do
end

local function node_is_require_call(n)
if n.e1 and n.e2 and
n.e1.kind == "variable" and n.e1.tk == "require" and
if not (n.e1 and n.e2) then
return nil
end
if n.op and n.op.op == "." then

return node_is_require_call(n.e1)
elseif n.e1.kind == "variable" and n.e1.tk == "require" and
n.e2.kind == "expression_list" and #n.e2 == 1 and
n.e2[1].kind == "string" then


return n.e2[1].conststr
elseif n.op and n.op.op == "@funcall" and
n.e1 and n.e1.tk == "pcall" and
n.e2 and #n.e2 == 2 and
n.e2[1].kind == "variable" and n.e2[1].tk == "require" and
n.e2[2].kind == "string" and n.e2[2].conststr then


return n.e2[2].conststr
else
return nil
Expand Down Expand Up @@ -4031,11 +4038,20 @@ do

if ps.tokens[i].kind == "identifier" and ps.tokens[i].tk == "require" then
local istart = i
i, asgn.value = parse_call_or_assignment(ps, i)
if asgn.value and not node_is_require_call(asgn.value) then
fail(ps, istart, "require() for type declarations must have a literal argument")
i, asgn.value = parse_expression(ps, i)
if asgn.value then
if asgn.value.op and asgn.value.op.op ~= "@funcall" and asgn.value.op.op ~= "." then
fail(ps, istart, "require() in type declarations cannot be part of larger expressions")
return i
end
if not node_is_require_call(asgn.value) then
fail(ps, istart, "require() for type declarations must have a literal argument")
return i
end
return i, asgn
else
return i
end
return i, asgn
end

i, asgn.value = parse_newtype(ps, i)
Expand Down Expand Up @@ -10577,6 +10593,24 @@ self:expand_type(node, values, elements) })
ty = (ty.typename == "typealias") and self:resolve_typealias(ty) or ty
local td = (ty.typename == "typedecl") and ty or a_type(value, "typedecl", { def = ty })
return td
elseif value.kind == "op" and
value.op.op == "." then

local ty = self:get_typedecl(value.e1)
if ty.typename == "typedecl" then
local def = ty.def
if def.typename == "record" then
local t = def.fields[value.e2.tk]
if t then
return a_type(value, "typedecl", { def = t })
else
return self.errs:invalid_at(value.e2, "type not found")
end
else
return self.errs:invalid_at(value.e2, "type is not a record")
end
end
return ty
else
local newtype = value.newtype
if newtype.typename == "typealias" then
Expand Down
48 changes: 41 additions & 7 deletions tl.tl
Original file line number Diff line number Diff line change
Expand Up @@ -2808,18 +2808,25 @@ local function parse_literal(ps: ParseState, i: integer): integer, Node
end

local function node_is_require_call(n: Node): string
if n.e1 and n.e2 -- literal require call
and n.e1.kind == "variable" and n.e1.tk == "require"
if not (n.e1 and n.e2) then
return nil
end
if n.op and n.op.op == "." then
-- require("str").something
return node_is_require_call(n.e1)
elseif n.e1.kind == "variable" and n.e1.tk == "require"
and n.e2.kind == "expression_list" and #n.e2 == 1
and n.e2[1].kind == "string"
then
-- require("str")
return n.e2[1].conststr
elseif n.op and n.op.op == "@funcall" -- pcall(require, "str")
elseif n.op and n.op.op == "@funcall"
and n.e1 and n.e1.tk == "pcall"
and n.e2 and #n.e2 == 2
and n.e2[1].kind == "variable" and n.e2[1].tk == "require"
and n.e2[2].kind == "string" and n.e2[2].conststr
then
-- pcall(require, "str")
return n.e2[2].conststr
else
return nil -- table.insert cares about arity
Expand Down Expand Up @@ -4031,11 +4038,20 @@ parse_type_declaration = function(ps: ParseState, i: integer, node_name: NodeKin

if ps.tokens[i].kind == "identifier" and ps.tokens[i].tk == "require" then
local istart = i
i, asgn.value = parse_call_or_assignment(ps, i)
if asgn.value and not node_is_require_call(asgn.value) then
fail(ps, istart, "require() for type declarations must have a literal argument")
i, asgn.value = parse_expression(ps, i)
if asgn.value then
if asgn.value.op and asgn.value.op.op ~= "@funcall" and asgn.value.op.op ~= "." then
fail(ps, istart, "require() in type declarations cannot be part of larger expressions")
return i
end
if not node_is_require_call(asgn.value) then
fail(ps, istart, "require() for type declarations must have a literal argument")
return i
end
return i, asgn
else
return i
end
return i, asgn
end

i, asgn.value = parse_newtype(ps, i)
Expand Down Expand Up @@ -10577,6 +10593,24 @@ do
ty = (ty is TypeAliasType) and self:resolve_typealias(ty) or ty
local td = (ty is TypeDeclType) and ty or a_type(value, "typedecl", { def = ty } as TypeDeclType)
return td
elseif value.kind == "op"
and value.op.op == "."
then
local ty = self:get_typedecl(value.e1)
if ty is TypeDeclType then
local def = ty.def
if def is RecordType then
local t = def.fields[value.e2.tk]
if t then
return a_type(value, "typedecl", { def = t } as TypeDeclType)
else
return self.errs:invalid_at(value.e2, "type not found")
end
else
return self.errs:invalid_at(value.e2, "type is not a record")
end
end
return ty
else
local newtype = value.newtype
if newtype is TypeAliasType then
Expand Down

0 comments on commit 09d0473

Please sign in to comment.