Skip to content

Commit

Permalink
tl gen: --keep-hashbang flag
Browse files Browse the repository at this point in the history
Closes #646.
  • Loading branch information
hishamhm committed Nov 8, 2023
1 parent 3002e0c commit 2fc72bf
Show file tree
Hide file tree
Showing 6 changed files with 106 additions and 26 deletions.
1 change: 1 addition & 0 deletions docs/compiler_options.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ return {
| `--gen-target` | `gen_target` | `string` | `build` `gen` `run` | Minimum targeted Lua version for generated code. Options are `5.1`, `5.3` and `5.4`. See [below](#generated-code) for details.
| | `include` | `{string}` | `build` | The set of files to compile/check. See below for details on patterns.
| | `exclude` | `{string}` | `build` | The set of files to exclude. See below for details on patterns.
| `--keep-hashbang` | | | `gen` | Preserve hashbang line (`#!`) at the top of file if present.
| `-s --source-dir` | `source_dir` | `string` | `build` | Set the directory to be searched for files. `build` will compile every .tl file in every subdirectory by default.
| `-b --build-dir` | `build_dir` | `string` | `build` | Set the directory for generated files, mimicking the file structure of the source files.
| | `files` | `{string}` | `build` | The names of files to be compiled. Does not accept patterns like `include`.
Expand Down
30 changes: 30 additions & 0 deletions spec/cli/gen_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,16 @@ end
local c = 100
]]

local script_with_hashbang = [[
#!/usr/bin/env lua
print("hello world")
]]

local script_without_hashbang = [[
print("hello world")
]]

local function tl_to_lua(name)
return (name:gsub("%.tl$", ".lua"):gsub("^" .. util.os_tmp .. util.os_sep, ""))
end
Expand Down Expand Up @@ -185,6 +195,26 @@ describe("tl gen", function()
end)
end)

it("preserves hashbang with --keep-hashbang", function()
local name = util.write_tmp_file(finally, script_with_hashbang)
local pd = io.popen(util.tl_cmd("gen", "--keep-hashbang", name), "r")
local output = pd:read("*a")
util.assert_popen_close(0, pd:close())
local lua_name = tl_to_lua(name)
assert.match("Wrote: " .. lua_name, output, 1, true)
util.assert_line_by_line(script_with_hashbang, util.read_file(lua_name))
end)

it("drops hashbang when not using --keep-hashbang", function()
local name = util.write_tmp_file(finally, script_with_hashbang)
local pd = io.popen(util.tl_cmd("gen", name), "r")
local output = pd:read("*a")
util.assert_popen_close(0, pd:close())
local lua_name = tl_to_lua(name)
assert.match("Wrote: " .. lua_name, output, 1, true)
util.assert_line_by_line(script_without_hashbang, util.read_file(lua_name))
end)

describe("with --gen-target=5.1", function()
it("targets generated code to Lua 5.1+", function()
local name = util.write_tmp_file(finally, [[
Expand Down
4 changes: 2 additions & 2 deletions spec/lexer/hashbang_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ describe("lexer", function()
it("skips hashbang at the beginning of a file", function()
local syntax_errors = {}
local tokens = tl.lex("#!/usr/bin/env lua\nlocal x = 1")
assert.same({"#!/usr/bin/env lua\n", "local", "x", "=", "1", "$EOF$"}, map(function(x) return x.tk end, tokens))

tl.parse_program(tokens, syntax_errors)
assert.same({}, syntax_errors)
assert.same(5, #tokens)
assert.same({"local", "x", "=", "1", "$EOF$"}, map(function(x) return x.tk end, tokens))
end)
end)
13 changes: 10 additions & 3 deletions tl
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ local function type_check_and_load(tlconfig, filename)
return chunk
end

local function write_out(tlconfig, result, output_file)
local function write_out(tlconfig, result, output_file, pp_opts)
if tlconfig["pretend"] then
print("Would Write: " .. output_file)
return
Expand All @@ -243,7 +243,7 @@ local function write_out(tlconfig, result, output_file)
end

local _
_, err = ofd:write(tl.pretty_print_ast(result.ast, tlconfig.gen_target) .. "\n")
_, err = ofd:write(tl.pretty_print_ast(result.ast, tlconfig.gen_target, pp_opts) .. "\n")
if err then
die("error writing " .. output_file .. ": " .. err)
end
Expand Down Expand Up @@ -863,6 +863,7 @@ local function get_args_parser()
local gen_command = parser:command("gen", "Generate a Lua file for one or more Teal files.")
gen_command:argument("file", "The Teal source file."):args("+")
gen_command:flag("-c --check", "Type check and fail on type errors.")
gen_command:flag("--keep-hashbang", "Preserve hashbang line (#!) at the top of file if present.")
gen_command:option("-o --output", "Write to <filename> instead.")
:argname("<filename>")

Expand Down Expand Up @@ -1227,9 +1228,15 @@ commands["gen"] = function(tlconfig, args)
local results = {}
local err
local env
local pp_opts
for i, input_file in ipairs(args["file"]) do
if not env then
env = setup_env(tlconfig, input_file)
pp_opts = {
preserve_indent = true,
preserve_newlines = true,
preserve_hashbang = args["keep_hashbang"]
}
end

local res = {
Expand All @@ -1248,7 +1255,7 @@ commands["gen"] = function(tlconfig, args)

for _, res in ipairs(results) do
if #res.tl_result.syntax_errors == 0 then
write_out(tlconfig, res.tl_result, args["output"] or res.output_file)
write_out(tlconfig, res.tl_result, args["output"] or res.output_file, pp_opts)
end
end

Expand Down
37 changes: 29 additions & 8 deletions tl.lua
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
local _tl_compat; if (tonumber((_VERSION or ''):match('[%d.]*$')) or 0) < 5.3 then local p, m = pcall(require, 'compat53.module'); if p then _tl_compat = m end end; local assert = _tl_compat and _tl_compat.assert or assert; local debug = _tl_compat and _tl_compat.debug or debug; local io = _tl_compat and _tl_compat.io or io; local ipairs = _tl_compat and _tl_compat.ipairs or ipairs; local load = _tl_compat and _tl_compat.load or load; local math = _tl_compat and _tl_compat.math or math; local _tl_math_maxinteger = math.maxinteger or math.pow(2, 53); local os = _tl_compat and _tl_compat.os or os; local package = _tl_compat and _tl_compat.package or package; local pairs = _tl_compat and _tl_compat.pairs or pairs; local string = _tl_compat and _tl_compat.string or string; local table = _tl_compat and _tl_compat.table or table; local _tl_table_unpack = unpack or table.unpack
local VERSION = "0.15.3+dev"

local tl = {TypeCheckOptions = {}, Env = {}, Symbol = {}, Result = {}, Error = {}, TypeInfo = {}, TypeReport = {}, TypeReportEnv = {}, }
local tl = {PrettyPrintOptions = {}, TypeCheckOptions = {}, Env = {}, Symbol = {}, Result = {}, Error = {}, TypeInfo = {}, TypeReport = {}, TypeReportEnv = {}, }









Expand Down Expand Up @@ -217,6 +223,7 @@ tl.typecodes = {




local TL_DEBUG = os.getenv("TL_DEBUG")
local TL_DEBUG_MAXLINE = _tl_math_maxinteger

Expand Down Expand Up @@ -279,6 +286,7 @@ end






do
Expand Down Expand Up @@ -592,10 +600,12 @@ do

local len = #input
if input:sub(1, 2) == "#!" then
begin_token()
i = input:find("\n")
if not i then
i = len + 1
end
end_token_here("hashbang")
y = 2
x = 0
end
Expand Down Expand Up @@ -1327,6 +1337,7 @@ local is_attribute = attributes






local function is_array_type(t)
Expand Down Expand Up @@ -3164,7 +3175,16 @@ function tl.parse_program(tokens, errs, filename)
filename = filename or "",
required_modules = {},
}
local _, node = parse_statements(ps, 1, true)
local i = 1
local hashbang
if ps.tokens[i].kind == "hashbang" then
hashbang = ps.tokens[i].tk
i = i + 1
end
local _, node = parse_statements(ps, i, true)
if hashbang then
node.hashbang = hashbang
end

clear_redundant_errors(errs)
return node, ps.required_modules
Expand Down Expand Up @@ -3689,18 +3709,16 @@ local spaced_op = {
}






local default_pretty_print_ast_opts = {
preserve_indent = true,
preserve_newlines = true,
preserve_hashbang = false,
}

local fast_pretty_print_ast_opts = {
preserve_indent = false,
preserve_newlines = true,
preserve_hashbang = false,
}

local primitive = {
Expand Down Expand Up @@ -3837,6 +3855,9 @@ function tl.pretty_print_ast(ast, gen_target, mode)
["statements"] = {
after = function(node, children)
local out = { y = node.y, h = 0 }
if opts.preserve_hashbang and node.hashbang then
table.insert(out, node.hashbang)
end
local space
for i, child in ipairs(children) do
add_child(out, child, space, indent)
Expand Down Expand Up @@ -10854,7 +10875,7 @@ function tl.process_string(input, is_lua, env, filename, module_name)
return result
end

tl.gen = function(input, env)
tl.gen = function(input, env, pp)
env = env or assert(tl.init_env(), "Default environment initialization failed")
local result = tl.process_string(input, false, env)

Expand All @@ -10863,7 +10884,7 @@ tl.gen = function(input, env)
end

local code
code, result.gen_error = tl.pretty_print_ast(result.ast, env.gen_target)
code, result.gen_error = tl.pretty_print_ast(result.ast, env.gen_target, pp)
return code, result
end

Expand Down
47 changes: 34 additions & 13 deletions tl.tl
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,12 @@ local record tl
"5.4"
end

record PrettyPrintOptions
preserve_indent: boolean
preserve_newlines: boolean
preserve_hashbang: boolean
end

record TypeCheckOptions
lax: boolean
filename: string
Expand Down Expand Up @@ -125,7 +131,7 @@ local record tl
load: function(string, string, LoadMode, {any:any}): LoadFunction, string
process: function(string, Env, string, FILE): (Result, string)
process_string: function(string, boolean, Env, string, string): Result
gen: function(string, Env): string, Result
gen: function(string, Env, PrettyPrintOptions): string, Result
type_check: function(Node, TypeCheckOptions): Result, string
init_env: function(boolean, boolean | CompatMode, TargetMode, {string}): Env, string
version: function(): string
Expand Down Expand Up @@ -204,6 +210,7 @@ local type Result = tl.Result
local type Env = tl.Env
local type Error = tl.Error
local type CompatMode = tl.CompatMode
local type PrettyPrintOptions = tl.PrettyPrintOptions
local type TypeCheckOptions = tl.TypeCheckOptions
local type LoadMode = tl.LoadMode
local type LoadFunction = tl.LoadFunction
Expand Down Expand Up @@ -258,6 +265,7 @@ end
--------------------------------------------------------------------------------

local enum TokenKind
"hashbang"
"keyword"
"op"
"string"
Expand Down Expand Up @@ -592,10 +600,12 @@ do

local len = #input
if input:sub(1,2) == "#!" then
begin_token()
i = input:find("\n")
if not i then
i = len + 1
end
end_token_here("hashbang")
y = 2
x = 0
end
Expand Down Expand Up @@ -1250,6 +1260,7 @@ local record Node
kind: NodeKind
symbol_list_slot: integer
semicolon: boolean
hashbang: string

is_longstring: boolean

Expand Down Expand Up @@ -3164,7 +3175,16 @@ function tl.parse_program(tokens: {Token}, errs: {Error}, filename: string): Nod
filename = filename or "",
required_modules = {},
}
local _, node = parse_statements(ps, 1, true)
local i = 1
local hashbang: string
if ps.tokens[i].kind == "hashbang" then
hashbang = ps.tokens[i].tk
i = i + 1
end
local _, node = parse_statements(ps, i, true)
if hashbang then
node.hashbang = hashbang
end

clear_redundant_errors(errs)
return node, ps.required_modules
Expand Down Expand Up @@ -3688,19 +3708,17 @@ local spaced_op: {integer:{string:boolean}} = {
},
}

local record PrettyPrintOpts
preserve_indent: boolean
preserve_newlines: boolean
end

local default_pretty_print_ast_opts: PrettyPrintOpts = {
local default_pretty_print_ast_opts: PrettyPrintOptions = {
preserve_indent = true,
preserve_newlines = true,
preserve_hashbang = false,
}

local fast_pretty_print_ast_opts: PrettyPrintOpts = {
local fast_pretty_print_ast_opts: PrettyPrintOptions = {
preserve_indent = false,
preserve_newlines = true,
preserve_hashbang = false,
}

local primitive: {TypeName:string} = {
Expand All @@ -3714,12 +3732,12 @@ local primitive: {TypeName:string} = {
["thread"] = "thread",
}

function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | PrettyPrintOpts): string, string
function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | PrettyPrintOptions): string, string
local err: string
local indent = 0

local opts: PrettyPrintOpts
if mode is PrettyPrintOpts then
local opts: PrettyPrintOptions
if mode is PrettyPrintOptions then
opts = mode
elseif mode == true then
opts = fast_pretty_print_ast_opts
Expand Down Expand Up @@ -3837,6 +3855,9 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean |
["statements"] = {
after = function(node: Node, children: {Output}): Output
local out: Output = { y = node.y, h = 0 }
if opts.preserve_hashbang and node.hashbang then
table.insert(out, node.hashbang)
end
local space: string
for i, child in ipairs(children) do
add_child(out, child, space, indent)
Expand Down Expand Up @@ -10854,7 +10875,7 @@ function tl.process_string(input: string, is_lua: boolean, env: Env, filename: s
return result
end

tl.gen = function(input: string, env: Env): string, Result
tl.gen = function(input: string, env: Env, pp: PrettyPrintOptions): string, Result
env = env or assert(tl.init_env(), "Default environment initialization failed")
local result = tl.process_string(input, false, env)

Expand All @@ -10863,7 +10884,7 @@ tl.gen = function(input: string, env: Env): string, Result
end

local code: string
code, result.gen_error = tl.pretty_print_ast(result.ast, env.gen_target)
code, result.gen_error = tl.pretty_print_ast(result.ast, env.gen_target, pp)
return code, result
end

Expand Down

0 comments on commit 2fc72bf

Please sign in to comment.