Skip to content

Commit

Permalink
Move over-max-arg packing to analysis
Browse files Browse the repository at this point in the history
  • Loading branch information
jeaye committed Nov 30, 2024
1 parent e4f9653 commit b9e1675
Show file tree
Hide file tree
Showing 14 changed files with 278 additions and 94 deletions.
38 changes: 38 additions & 0 deletions compiler+runtime/include/cpp/jank/analyze/expr/list.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
#pragma once

#include <jank/analyze/expression_base.hpp>

namespace jank::analyze::expr
{
using namespace jank::runtime;

template <typename E>
struct list : expression_base
{
native_vector<native_box<E>> data_exprs;
option<object_ptr> meta;
obj::persistent_list_ptr data{};

void propagate_position(expression_position const pos)
{
position = pos;
}

object_ptr to_runtime_data() const
{
object_ptr exprs(make_box<obj::persistent_vector>());
for(auto const &e : data_exprs)
{
exprs = conj(exprs, e->to_runtime_data());
}

return merge(static_cast<expression_base const *>(this)->to_runtime_data(),
obj::persistent_array_map::create_unique(make_box("__type"),
make_box("expr::list"),
make_box("data_exprs"),
exprs,
make_box("meta"),
jank::detail::to_runtime_data(meta)));
}
};
}
6 changes: 3 additions & 3 deletions compiler+runtime/include/cpp/jank/analyze/expr/vector.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,17 @@ namespace jank::analyze::expr

object_ptr to_runtime_data() const
{
object_ptr pair_maps(make_box<obj::persistent_vector>());
object_ptr exprs(make_box<obj::persistent_vector>());
for(auto const &e : data_exprs)
{
pair_maps = conj(pair_maps, e->to_runtime_data());
exprs = conj(exprs, e->to_runtime_data());
}

return merge(static_cast<expression_base const *>(this)->to_runtime_data(),
obj::persistent_array_map::create_unique(make_box("__type"),
make_box("expr::vector"),
make_box("data_exprs"),
pair_maps,
exprs,
make_box("meta"),
jank::detail::to_runtime_data(meta)));
}
Expand Down
2 changes: 2 additions & 0 deletions compiler+runtime/include/cpp/jank/analyze/expression.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <jank/analyze/expr/var_ref.hpp>
#include <jank/analyze/expr/call.hpp>
#include <jank/analyze/expr/primitive_literal.hpp>
#include <jank/analyze/expr/list.hpp>
#include <jank/analyze/expr/vector.hpp>
#include <jank/analyze/expr/map.hpp>
#include <jank/analyze/expr/set.hpp>
Expand All @@ -32,6 +33,7 @@ namespace jank::analyze
expr::var_ref<E>,
expr::call<E>,
expr::primitive_literal<E>,
expr::list<E>,
expr::vector<E>,
expr::map<E>,
expr::set<E>,
Expand Down
13 changes: 13 additions & 0 deletions compiler+runtime/include/cpp/jank/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,18 @@ extern "C"
jank_object_ptr a8,
jank_object_ptr a9,
jank_object_ptr a10);
jank_object_ptr jank_call11(jank_object_ptr f,
jank_object_ptr a1,
jank_object_ptr a2,
jank_object_ptr a3,
jank_object_ptr a4,
jank_object_ptr a5,
jank_object_ptr a6,
jank_object_ptr a7,
jank_object_ptr a8,
jank_object_ptr a9,
jank_object_ptr a10,
jank_object_ptr rest);

jank_object_ptr jank_nil();
jank_object_ptr jank_true();
Expand All @@ -102,6 +114,7 @@ extern "C"
jank_object_ptr jank_symbol_create(jank_object_ptr ns, jank_object_ptr name);
jank_object_ptr jank_character_create(char const *s);

jank_object_ptr jank_list_create(uint64_t size, ...);
jank_object_ptr jank_vector_create(uint64_t size, ...);
jank_object_ptr jank_map_create(uint64_t pairs, ...);
jank_object_ptr jank_set_create(uint64_t size, ...);
Expand Down
2 changes: 2 additions & 0 deletions compiler+runtime/include/cpp/jank/codegen/llvm_processor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ namespace jank::codegen
analyze::expr::function_arity<analyze::expression> const &);
llvm::Value *gen(analyze::expr::primitive_literal<analyze::expression> const &,
analyze::expr::function_arity<analyze::expression> const &);
llvm::Value *gen(analyze::expr::list<analyze::expression> const &,
analyze::expr::function_arity<analyze::expression> const &);
llvm::Value *gen(analyze::expr::vector<analyze::expression> const &,
analyze::expr::function_arity<analyze::expression> const &);
llvm::Value *gen(analyze::expr::map<analyze::expression> const &,
Expand Down
1 change: 1 addition & 0 deletions compiler+runtime/include/cpp/jank/evaluate.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ namespace jank::evaluate
runtime::object_ptr eval(analyze::expr::var_ref<analyze::expression> const &);
runtime::object_ptr eval(analyze::expr::call<analyze::expression> const &);
runtime::object_ptr eval(analyze::expr::primitive_literal<analyze::expression> const &);
runtime::object_ptr eval(analyze::expr::list<analyze::expression> const &);
runtime::object_ptr eval(analyze::expr::vector<analyze::expression> const &);
runtime::object_ptr eval(analyze::expr::map<analyze::expression> const &);
runtime::object_ptr eval(analyze::expr::set<analyze::expression> const &);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@ namespace jank::runtime
static native_box<static_object> create(native_box<static_object> s);

static_object() = default;
static_object(static_object &&) = default;
static_object(static_object &&) noexcept = default;
static_object(static_object const &) = default;
static_object(value_type &&d);
static_object(value_type const &d);
static_object(object_ptr meta, value_type const &d);

/* TODO: This is broken when `args` is a value_type list we're looking to wrap in another list.
* It just uses the copy ctor. */
Expand Down
23 changes: 7 additions & 16 deletions compiler+runtime/include/cpp/jank/runtime/visit.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -341,18 +341,9 @@ namespace jank::runtime
break;
default:
{
/* TODO: Use fmt when possible. */
throw std::runtime_error{ "invalid object type: "
+ std::to_string(static_cast<int>(const_erased->type)) };
//throw std::runtime_error
//{
// fmt::format
// (
// "invalid object type: {} raw value {}",
// magic_enum::enum_name(erased->type),
// static_cast<int>(erased->type)
// )
//};
throw std::runtime_error{ fmt::format("invalid object type: {} raw value {}",
magic_enum::enum_name(erased->type),
static_cast<int>(erased->type)) };
}
break;
}
Expand Down Expand Up @@ -537,7 +528,7 @@ namespace jank::runtime
return visit_seqable(
fn,
[=]() -> decltype(fn(obj::cons_ptr{}, std::forward<Args>(args)...)) {
throw std::runtime_error{ "not seqable: " + to_string(const_erased) };
throw std::runtime_error{ "not seqable: " + to_code_string(const_erased) };
},
const_erased,
std::forward<Args>(args)...);
Expand Down Expand Up @@ -586,7 +577,7 @@ namespace jank::runtime
return visit_map_like(
fn,
[=]() -> decltype(fn(obj::persistent_hash_map_ptr{}, std::forward<Args>(args)...)) {
throw std::runtime_error{ "not map-like: " + to_string(const_erased) };
throw std::runtime_error{ "not map-like: " + to_code_string(const_erased) };
},
const_erased,
std::forward<Args>(args)...);
Expand Down Expand Up @@ -630,7 +621,7 @@ namespace jank::runtime
return visit_set_like(
fn,
[=]() -> decltype(fn(obj::persistent_hash_set_ptr{}, std::forward<Args>(args)...)) {
throw std::runtime_error{ "not set-like: " + to_string(const_erased) };
throw std::runtime_error{ "not set-like: " + to_code_string(const_erased) };
},
const_erased,
std::forward<Args>(args)...);
Expand Down Expand Up @@ -675,7 +666,7 @@ namespace jank::runtime
return visit_number_like(
fn,
[=]() -> decltype(fn(obj::integer_ptr{}, std::forward<Args>(args)...)) {
throw std::runtime_error{ "not a number: " + to_string(const_erased) };
throw std::runtime_error{ "not a number: " + to_code_string(const_erased) };
},
const_erased,
std::forward<Args>(args)...);
Expand Down
43 changes: 40 additions & 3 deletions compiler+runtime/src/cpp/jank/analyze/processor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1387,17 +1387,54 @@ namespace jank::analyze
}

native_vector<expression_ptr> arg_exprs;
arg_exprs.reserve(arg_count);
for(auto const &s : o->data.rest())
arg_exprs.reserve(std::min(arg_count, runtime::max_params + 1));

auto it(o->data.rest());
for(size_t i{}; i < runtime::max_params && i < arg_count; ++i, it = it.rest())
{
auto arg_expr(analyze(s, current_frame, expression_position::value, fn_ctx, needs_arg_box));
auto arg_expr(analyze(it.first().unwrap(),
current_frame,
expression_position::value,
fn_ctx,
needs_arg_box));
if(arg_expr.is_err())
{
return arg_expr;
}
arg_exprs.emplace_back(arg_expr.expect_ok());
}

/* If we have more args than a fn allows, we need to pack all of the extras
* into a single list and tack that on at the end. So, if max_params is 10, and
* we pass 15 args, we'll pass 10 normally and then we'll have a special 11th
* arg which is a list containing the 5 remaining params. We rely on dynamic_call
* to do the hard work of packing that in the shape the function actually wants,
* based on its highest fixed arity flag. */
if(runtime::max_params < arg_count)
{
native_vector<expression_ptr> packed_arg_exprs;
for(size_t i{ runtime::max_params }; i < arg_count; ++i, it = it.rest())
{
auto arg_expr(analyze(it.first().unwrap(),
current_frame,
expression_position::value,
fn_ctx,
needs_arg_box));
if(arg_expr.is_err())
{
return arg_expr;
}
packed_arg_exprs.emplace_back(arg_expr.expect_ok());
}
expr::list<expression> list{
expression_base{ {}, expression_position::value, current_frame, needs_arg_box },
std::move(packed_arg_exprs),
none,
nullptr
};
arg_exprs.emplace_back(make_box<expression>(std::move(list)));
}

auto const recursion_ref(boost::get<expr::recursion_reference<expression>>(&source->data));
if(recursion_ref)
{
Expand Down
59 changes: 59 additions & 0 deletions compiler+runtime/src/cpp/jank/c_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,45 @@ extern "C"
a10_obj);
}

jank_object_ptr jank_call11(jank_object_ptr const f,
jank_object_ptr const a1,
jank_object_ptr const a2,
jank_object_ptr const a3,
jank_object_ptr const a4,
jank_object_ptr const a5,
jank_object_ptr const a6,
jank_object_ptr const a7,
jank_object_ptr const a8,
jank_object_ptr const a9,
jank_object_ptr const a10,
jank_object_ptr const rest)
{
auto const f_obj(reinterpret_cast<object *>(f));
auto const a1_obj(reinterpret_cast<object *>(a1));
auto const a2_obj(reinterpret_cast<object *>(a2));
auto const a3_obj(reinterpret_cast<object *>(a3));
auto const a4_obj(reinterpret_cast<object *>(a4));
auto const a5_obj(reinterpret_cast<object *>(a5));
auto const a6_obj(reinterpret_cast<object *>(a6));
auto const a7_obj(reinterpret_cast<object *>(a7));
auto const a8_obj(reinterpret_cast<object *>(a8));
auto const a9_obj(reinterpret_cast<object *>(a9));
auto const a10_obj(reinterpret_cast<object *>(a10));
auto const rest_obj(reinterpret_cast<object *>(rest));
return dynamic_call(f_obj,
a1_obj,
a2_obj,
a3_obj,
a4_obj,
a5_obj,
a6_obj,
a7_obj,
a8_obj,
a9_obj,
a10_obj,
try_object<obj::persistent_list>(rest_obj));
}

jank_object_ptr jank_nil()
{
return erase(obj::nil::nil_const());
Expand Down Expand Up @@ -338,6 +377,26 @@ extern "C"
return erase(make_box<obj::character>(read::parse::get_char_from_literal(s).unwrap()));
}

jank_object_ptr jank_list_create(uint64_t const size, ...)
{
/* NOLINTNEXTLINE(cppcoreguidelines-pro-type-vararg) */
va_list args;
va_start(args, size);

native_vector<object_ptr> v;

for(uint64_t i{}; i < size; ++i)
{
/* NOLINTNEXTLINE(cppcoreguidelines-pro-type-vararg) */
v.emplace_back(reinterpret_cast<object *>(va_arg(args, jank_object_ptr)));
}

va_end(args);

runtime::detail::native_persistent_list const npl{ v.rbegin(), v.rend() };
return erase(make_box<obj::persistent_list>(std::move(npl)));
}

jank_object_ptr jank_vector_create(uint64_t const size, ...)
{
/* NOLINTNEXTLINE(cppcoreguidelines-pro-type-vararg) */
Expand Down
33 changes: 31 additions & 2 deletions compiler+runtime/src/cpp/jank/codegen/llvm_processor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -301,12 +301,14 @@ namespace jank::codegen

static native_persistent_string arity_to_call_fn(size_t const arity)
{
/* Anything max_params + 1 or higher gets packed into a list so we
* just end up calling max_params + 1 at most. */
switch(arity)
{
case 0 ... 10:
case 0 ... runtime::max_params:
return fmt::format("jank_call{}", arity);
default:
throw std::runtime_error{ fmt::format("invalid fn arity: {}", arity) };
return fmt::format("jank_call{}", runtime::max_params + 1);
}
}

Expand Down Expand Up @@ -386,6 +388,33 @@ namespace jank::codegen
return ret;
}

llvm::Value *llvm_processor::gen(analyze::expr::list<analyze::expression> const &expr,
analyze::expr::function_arity<analyze::expression> const &arity)
{
auto const fn_type(
llvm::FunctionType::get(ctx->builder->getPtrTy(), { ctx->builder->getInt64Ty() }, true));
auto const fn(ctx->module->getOrInsertFunction("jank_list_create", fn_type));

auto const size(expr.data_exprs.size());
std::vector<llvm::Value *> args;
args.reserve(1 + size);
args.emplace_back(ctx->builder->getInt64(size));

for(auto const &expr : expr.data_exprs)
{
args.emplace_back(gen(expr, arity));
}

auto const call(ctx->builder->CreateCall(fn, args));

if(expr.position == analyze::expression_position::tail)
{
return ctx->builder->CreateRet(call);
}

return call;
}

llvm::Value *llvm_processor::gen(analyze::expr::vector<analyze::expression> const &expr,
analyze::expr::function_arity<analyze::expression> const &arity)
{
Expand Down
Loading

0 comments on commit b9e1675

Please sign in to comment.