Skip to content

Commit

Permalink
local type require() accepts dot notation for nested record
Browse files Browse the repository at this point in the history
`local type MyType = require("module").MyType` is now valid.

Closes #778.
  • Loading branch information
hishamhm committed Aug 7, 2024
1 parent e95abaa commit fbac8f3
Show file tree
Hide file tree
Showing 6 changed files with 239 additions and 32 deletions.
5 changes: 3 additions & 2 deletions docs/grammar.md
Original file line number Diff line number Diff line change
Expand Up @@ -108,12 +108,13 @@ precedence, see below.
* typeargs ::= ‘<’ Name {‘,’ Name } ‘>’
* newtype ::= ‘record’ recordbody | ‘enum’ enumbody | type
* | ‘require’ ‘(’ LiteralString ‘)’ {‘.’ Name }
* interfacelist ::= nominal {‘,’ nominal} |
* ‘{’ type ‘}’ {‘,’ nominal}
* ‘{’ type ‘}’ {‘,’ nominal}
* recordbody ::= [typeargs] [‘is’ interfacelist]
* [‘where’ exp] {recordentry} ‘end’
* [‘where’ exp] {recordentry} ‘end’
* recordentry ::= ‘userdata’ |
* ‘type’ Name ‘=’ newtype | [‘metamethod’] recordkey ‘:’ type |
Expand Down
11 changes: 11 additions & 0 deletions spec/cli/types_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,17 @@ describe("tl types works like check", function()
util.assert_popen_close(0, pd:close())
-- TODO check json output
end)

it("does not crash when a require() expression does not resolve (#778)", function()
local name = util.write_tmp_file(finally, [[
local type Foo = require("missingmodule").baz
]])
local pd = io.popen(util.tl_cmd("types", name, "--gen-target=5.1") .. "2>&1 1>" .. util.os_null, "r")
local output = pd:read("*a")
util.assert_popen_close(1, pd:close())
assert.match("1 error:", output, 1, true)
-- TODO check json output
end)
end)

describe("on .lua files", function()
Expand Down
6 changes: 6 additions & 0 deletions spec/stdlib/pcall_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,12 @@ local tl = require("tl")
local util = require("spec.util")

describe("pcall", function()
it("can't use pcall in local type", util.check_syntax_error([[
local type bla = pcall("require", "something")
]], {
{ msg = "pcall() cannot be used in type declarations" }
}))

it("can't pcall nothing", util.check_type_error([[
local pok = pcall()
]], {
Expand Down
97 changes: 97 additions & 0 deletions spec/stdlib/require_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -1152,4 +1152,101 @@ describe("require", function()
assert.same({}, result.env.loaded["./types/person.tl"].type_errors)
end)
end)

it("in 'local type' accepts dots for extracting nested types", function ()
-- ok
util.mock_io(finally, {
["mod.tl"] = [[
local record mod
record Foo<K>
something: K
fn: function(): Foo<K>
end
end
return mod
]],
["main.tl"] = [[
local type Foo = require("mod").Foo
local function f(v: Foo<integer>)
print(v.something)
-- check that aliasing works:
local x = v.fn()
x = v
end
]],
})
local result, err = tl.process("main.tl")

assert.same({}, result.syntax_errors)
assert.same({}, result.type_errors)
end)

it("in 'local type' with dots fails if not a record", function ()
-- ok
util.mock_io(finally, {
["mod.tl"] = [[
return 123
]],
["main.tl"] = [[
local type Foo = require("mod").Foo
]],
})
local result, err = tl.process("main.tl")

assert.same({}, result.syntax_errors)
assert.same({
{ filename = "main.tl", x = 37, y = 1, msg = "type is not a record" },
{ filename = "main.tl", x = 45, y = 1, msg = "cannot index key 'Foo' in type integer" },
}, result.type_errors)
end)

it("in 'local type' with dots fails if not a record", function ()
-- ok
util.mock_io(finally, {
["mod.tl"] = [[
local record mod
record Bar
end
end
return mod
]],
["main.tl"] = [[
local type Foo = require("mod").Foo
]],
})
local result, err = tl.process("main.tl")

assert.same({}, result.syntax_errors)
assert.same({
{ filename = "main.tl", x = 45, y = 1, msg = "nested type 'Foo' not found in record" },
}, result.type_errors)
end)

it("in 'local type' does not accept arbitrary expressions", function ()
-- ok
util.mock_io(finally, {
["mod.tl"] = [[
local record mod
record Foo<K>
something: K
end
end
return mod
]],
["main.tl"] = [[
local type Foo = require("mod") + "hello"
local function f(v: Foo)
print(v.something)
end
]],
})
local result, err = tl.process("main.tl")

assert.same({
{ filename = "main.tl", x = 30, y = 1, msg = "require() in type declarations cannot be part of larger expressions" }
}, result.syntax_errors)
end)
end)
76 changes: 61 additions & 15 deletions tl.lua
Original file line number Diff line number Diff line change
Expand Up @@ -2808,22 +2808,37 @@ do
end

local function node_is_require_call(n)
if n.e1 and n.e2 and
n.e1.kind == "variable" and n.e1.tk == "require" and
if not (n.e1 and n.e2) then
return nil
end
if n.op and n.op.op == "." then

return node_is_require_call(n.e1)
elseif n.e1.kind == "variable" and n.e1.tk == "require" and
n.e2.kind == "expression_list" and #n.e2 == 1 and
n.e2[1].kind == "string" then


return n.e2[1].conststr
elseif n.op and n.op.op == "@funcall" and
end
return nil
end

local function node_is_require_call_or_pcall(n)
local r = node_is_require_call(n)
if r then
return r
end
if n.op and n.op.op == "@funcall" and
n.e1 and n.e1.tk == "pcall" and
n.e2 and #n.e2 == 2 and
n.e2[1].kind == "variable" and n.e2[1].tk == "require" and
n.e2[2].kind == "string" and n.e2[2].conststr then


return n.e2[2].conststr
else
return nil
end
return nil
end

do
Expand Down Expand Up @@ -2999,7 +3014,7 @@ do

e1 = { f = ps.filename, y = args.y, x = args.x, kind = "op", op = op, e1 = e1, e2 = args }

table.insert(ps.required_modules, node_is_require_call(e1))
table.insert(ps.required_modules, node_is_require_call_or_pcall(e1))
elseif tkop.tk == "[" then
local op = new_operator(tkop, 2, "@index")

Expand Down Expand Up @@ -3042,7 +3057,7 @@ do
table.insert(args, argument)
e1 = { f = ps.filename, y = args.y, x = args.x, kind = "op", op = op, e1 = e1, e2 = args }

table.insert(ps.required_modules, node_is_require_call(e1))
table.insert(ps.required_modules, node_is_require_call_or_pcall(e1))
elseif tkop.tk == "as" or tkop.tk == "is" then
local op = new_operator(tkop, 2, tkop.tk)

Expand Down Expand Up @@ -4029,13 +4044,27 @@ do

i = verify_tk(ps, i, "=")

if ps.tokens[i].kind == "identifier" and ps.tokens[i].tk == "require" then
local istart = i
i, asgn.value = parse_call_or_assignment(ps, i)
if asgn.value and not node_is_require_call(asgn.value) then
fail(ps, istart, "require() for type declarations must have a literal argument")
if ps.tokens[i].kind == "identifier" then
if ps.tokens[i].tk == "require" then
local istart = i
i, asgn.value = parse_expression(ps, i)
if asgn.value then
if asgn.value.op and asgn.value.op.op ~= "@funcall" and asgn.value.op.op ~= "." then
fail(ps, istart, "require() in type declarations cannot be part of larger expressions")
return i
end
if not node_is_require_call(asgn.value) then
fail(ps, istart, "require() for type declarations must have a literal argument")
return i
end
return i, asgn
else
return i
end
elseif ps.tokens[i].tk == "pcall" then
fail(ps, i, "pcall() cannot be used in type declarations")
return i
end
return i, asgn
end

i, asgn.value = parse_newtype(ps, i)
Expand Down Expand Up @@ -10575,8 +10604,25 @@ self:expand_type(node, values, elements) })
local ty = t.typename == "tuple" and t.tuple[1] or t

ty = (ty.typename == "typealias") and self:resolve_typealias(ty) or ty
local td = (ty.typename == "typedecl") and ty or a_type(value, "typedecl", { def = ty })
return td
return (ty.typename == "typedecl") and ty or (a_type(value, "typedecl", { def = ty }))
elseif value.kind == "op" and
value.op.op == "." then

local ty = self:get_typedecl(value.e1)
if ty.typename == "typedecl" then
local def = ty.def
if def.typename == "record" then
local t = def.fields[value.e2.tk]
if t and t.typename == "typedecl" then
return t
else
return self.errs:invalid_at(value.e2, "nested type '" .. value.e2.tk .. "' not found in record")
end
else
return self.errs:invalid_at(value.e1, "type is not a record")
end
end
return ty
else
local newtype = value.newtype
if newtype.typename == "typealias" then
Expand Down
76 changes: 61 additions & 15 deletions tl.tl
Original file line number Diff line number Diff line change
Expand Up @@ -2808,22 +2808,37 @@ local function parse_literal(ps: ParseState, i: integer): integer, Node
end

local function node_is_require_call(n: Node): string
if n.e1 and n.e2 -- literal require call
and n.e1.kind == "variable" and n.e1.tk == "require"
if not (n.e1 and n.e2) then
return nil
end
if n.op and n.op.op == "." then
-- `require("str").something`
return node_is_require_call(n.e1)
elseif n.e1.kind == "variable" and n.e1.tk == "require"
and n.e2.kind == "expression_list" and #n.e2 == 1
and n.e2[1].kind == "string"
then
-- `require("str")`
return n.e2[1].conststr
elseif n.op and n.op.op == "@funcall" -- pcall(require, "str")
end
return nil -- table.insert cares about arity
end

local function node_is_require_call_or_pcall(n: Node): string
local r = node_is_require_call(n)
if r then
return r
end
if n.op and n.op.op == "@funcall"
and n.e1 and n.e1.tk == "pcall"
and n.e2 and #n.e2 == 2
and n.e2[1].kind == "variable" and n.e2[1].tk == "require"
and n.e2[2].kind == "string" and n.e2[2].conststr
then
-- `pcall(require, "str")`
return n.e2[2].conststr
else
return nil -- table.insert cares about arity
end
return nil -- table.insert cares about arity
end

do
Expand Down Expand Up @@ -2999,7 +3014,7 @@ do

e1 = { f = ps.filename, y = args.y, x = args.x, kind = "op", op = op, e1 = e1, e2 = args }

table.insert(ps.required_modules, node_is_require_call(e1))
table.insert(ps.required_modules, node_is_require_call_or_pcall(e1))
elseif tkop.tk == "[" then
local op: Operator = new_operator(tkop, 2, "@index")

Expand Down Expand Up @@ -3042,7 +3057,7 @@ do
table.insert(args, argument)
e1 = { f = ps.filename, y = args.y, x = args.x, kind = "op", op = op, e1 = e1, e2 = args }

table.insert(ps.required_modules, node_is_require_call(e1))
table.insert(ps.required_modules, node_is_require_call_or_pcall(e1))
elseif tkop.tk == "as" or tkop.tk == "is" then
local op: Operator = new_operator(tkop, 2, tkop.tk)

Expand Down Expand Up @@ -4029,13 +4044,27 @@ parse_type_declaration = function(ps: ParseState, i: integer, node_name: NodeKin

i = verify_tk(ps, i, "=")

if ps.tokens[i].kind == "identifier" and ps.tokens[i].tk == "require" then
local istart = i
i, asgn.value = parse_call_or_assignment(ps, i)
if asgn.value and not node_is_require_call(asgn.value) then
fail(ps, istart, "require() for type declarations must have a literal argument")
if ps.tokens[i].kind == "identifier" then
if ps.tokens[i].tk == "require" then
local istart = i
i, asgn.value = parse_expression(ps, i)
if asgn.value then
if asgn.value.op and asgn.value.op.op ~= "@funcall" and asgn.value.op.op ~= "." then
fail(ps, istart, "require() in type declarations cannot be part of larger expressions")
return i
end
if not node_is_require_call(asgn.value) then
fail(ps, istart, "require() for type declarations must have a literal argument")
return i
end
return i, asgn
else
return i
end
elseif ps.tokens[i].tk == "pcall" then
fail(ps, i, "pcall() cannot be used in type declarations")
return i
end
return i, asgn
end

i, asgn.value = parse_newtype(ps, i)
Expand Down Expand Up @@ -10575,8 +10604,25 @@ do
local ty: Type = t is TupleType and t.tuple[1] or t

ty = (ty is TypeAliasType) and self:resolve_typealias(ty) or ty
local td = (ty is TypeDeclType) and ty or a_type(value, "typedecl", { def = ty } as TypeDeclType)
return td
return (ty is TypeDeclType) and ty or (a_type(value, "typedecl", { def = ty } as TypeDeclType))
elseif value.kind == "op"
and value.op.op == "."
then
local ty = self:get_typedecl(value.e1)
if ty is TypeDeclType then
local def = ty.def
if def is RecordType then
local t = def.fields[value.e2.tk]
if t and t is TypeDeclType then
return t
else
return self.errs:invalid_at(value.e2, "nested type '" .. value.e2.tk .. "' not found in record")
end
else
return self.errs:invalid_at(value.e1, "type is not a record")
end
end
return ty
else
local newtype = value.newtype
if newtype is TypeAliasType then
Expand Down

0 comments on commit fbac8f3

Please sign in to comment.