Skip to content

Commit

Permalink
feat: introduce option cast_literal_in_context.
Browse files Browse the repository at this point in the history
  • Loading branch information
ashigeru committed Jul 3, 2024
1 parent 3d70411 commit 7a8134d
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 46 deletions.
21 changes: 21 additions & 0 deletions include/mizugaki/analyzer/sql_analyzer_options.h
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,12 @@ class sql_analyzer_options {
*/
static constexpr bool default_validate_scalar_expressions = false;

/**
* @brief the default value of whether or not to automatically cast literals if the context type is specified.
* @see cast_literals_in_context()
*/
static constexpr bool default_cast_literals_in_context = true;

/**
* @brief the default value of a function name to advance a sequence value.
* @see advance_sequence_function_name()
Expand Down Expand Up @@ -382,6 +388,20 @@ class sql_analyzer_options {
return validate_scalar_expressions_;
}

/**
* @brief returns whether or not to automatically cast literals if the context type is specified.
* @return true if inserting cast is enabled
* @return false otherwise
*/
[[nodiscard]] bool& cast_literals_in_context() noexcept {
return cast_literals_in_context_;
}

/// @copydoc cast_literals_in_context()
[[nodiscard]] bool const& cast_literals_in_context() const noexcept {
return cast_literals_in_context_;
}

/**
* @brief returns the function name to advance a sequence value.
* @details This function requires a symbol of the target sequence.
Expand Down Expand Up @@ -419,6 +439,7 @@ class sql_analyzer_options {
bool host_parameter_declaration_starts_with_colon_ { default_host_parameter_declaration_starts_with_colon };
bool allow_context_independent_null_ { default_allow_context_independent_null };
bool validate_scalar_expressions_ { default_validate_scalar_expressions };
bool cast_literals_in_context_ { default_cast_literals_in_context };

std::string_view advance_sequence_function_name_ { default_advance_sequence_function_name };
};
Expand Down
91 changes: 47 additions & 44 deletions src/mizugaki/analyzer/details/analyze_literal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,26 @@ class engine {
{}

[[nodiscard]] std::unique_ptr<tscalar::expression> process(ast::literal::literal const& literal) {
return ast::literal::dispatch(*this, literal);
auto result = ast::literal::dispatch(*this, literal);
if (!result) {
return {};
}
if (auto&& t = value_context_.type();
context_.options()->cast_literals_in_context() &&
t && *t != result->type()) {
// NOTE: here, we only apply cast operation to constant values.
// later optimization (if it available) will reduce this operation
return context_.create<tscalar::cast>(
result->region(),
t,
tscalar::cast_loss_policy::error,
std::move(result));
}
return result;

}

std::unique_ptr<tscalar::expression> operator()(ast::literal::literal const& value) {
std::unique_ptr<tscalar::immediate> operator()(ast::literal::literal const& value) {
context_.report(
sql_analyzer_code::unsupported_feature,
string_builder {}
Expand All @@ -49,22 +65,22 @@ class engine {
return {};
}

std::unique_ptr<tscalar::expression> operator()(ast::literal::boolean const& value) {
std::unique_ptr<tscalar::immediate> operator()(ast::literal::boolean const& value) {
if (value.value() == ast::literal::boolean_kind::unknown) {
return adapt(context_.create<tscalar::immediate>(
return context_.create<tscalar::immediate>(
value.region(),
context_.values().get(tvalue::unknown {}),
context_.types().get(ttype::boolean {})));
context_.types().get(ttype::boolean {}));
}
return adapt(context_.create<tscalar::immediate>(
return context_.create<tscalar::immediate>(
value.region(),
context_.values().get(tvalue::boolean {
value.value() == ast::literal::boolean_kind::true_,
}),
context_.types().get(ttype::boolean {})));
context_.types().get(ttype::boolean {}));
}

std::unique_ptr<tscalar::expression> operator()(ast::literal::numeric const& value) {
std::unique_ptr<tscalar::immediate> operator()(ast::literal::numeric const& value) {
using ast::literal::kind;
switch (value.node_kind()) {
case kind::exact_numeric: return process_exact_numeric(value);
Expand All @@ -83,7 +99,7 @@ class engine {
return {};
}

std::unique_ptr<tscalar::expression> operator()(ast::literal::string const& value) {
std::unique_ptr<tscalar::immediate> operator()(ast::literal::string const& value) {
using ast::literal::kind;
switch (value.node_kind()) {
case kind::character_string: return process_character_string(value);
Expand All @@ -101,7 +117,7 @@ class engine {
return {};
}

std::unique_ptr<tscalar::expression> operator()(ast::literal::datetime const& value) {
std::unique_ptr<tscalar::immediate> operator()(ast::literal::datetime const& value) {
using ast::literal::kind;
switch (value.node_kind()) {
default:
Expand All @@ -118,12 +134,12 @@ class engine {
return {};
}

// std::unique_ptr<tscalar::expression> operator()(ast::literal::interval const& value) {
// std::unique_ptr<tscalar::immediate> operator()(ast::literal::interval const& value) {
// (void) value;
// return {};
// }

std::unique_ptr<tscalar::expression> operator()(ast::literal::null const& value) {
std::unique_ptr<tscalar::immediate> operator()(ast::literal::null const& value) {
if (auto t = value_context_.type()) {
return context_.create<tscalar::immediate>(
value.region(),
Expand All @@ -142,12 +158,12 @@ class engine {
return {};
}

// std::unique_ptr<tscalar::expression> operator()(ast::literal::empty const& value) {
// std::unique_ptr<tscalar::immediate> operator()(ast::literal::empty const& value) {
// (void) value;
// return {};
// }

std::unique_ptr<tscalar::expression> operator()(ast::literal::default_ const& value) {
std::unique_ptr<tscalar::immediate> operator()(ast::literal::default_ const& value) {
auto t = value_context_.type();
auto v = value_context_.default_value();
if (t && v) {
Expand All @@ -174,7 +190,7 @@ class engine {
return {};
}

[[nodiscard]] std::unique_ptr<tscalar::expression> process_exact_numeric(ast::literal::numeric const& value) { // NOLINT(*-function-cognitive-complexity)
[[nodiscard]] std::unique_ptr<tscalar::immediate> process_exact_numeric(ast::literal::numeric const& value) { // NOLINT(*-function-cognitive-complexity)
// FIXME: move to yugawara
auto max_precision = static_cast<mpd_ssize_t>(context_.options()->max_decimal_precision());
::decimal::Context context {
Expand Down Expand Up @@ -221,44 +237,44 @@ class engine {
if (integer) {
if (context_.options()->prefer_small_integer_literals()) {
if (auto r = soft_cast<std::int8_t>(*integer)) {
return adapt(context_.create<tscalar::immediate>(
return context_.create<tscalar::immediate>(
value.region(),
context_.values().get(tvalue::int4 { *r }),
context_.types().get(ttype::int1 {})));
context_.types().get(ttype::int1 {}));
}
if (auto r = soft_cast<std::int16_t>(*integer)) {
return adapt(context_.create<tscalar::immediate>(
return context_.create<tscalar::immediate>(
value.region(),
context_.values().get(tvalue::int4 { *r }),
context_.types().get(ttype::int2 {})));
context_.types().get(ttype::int2 {}));
}
if (auto r = soft_cast<std::int32_t>(*integer)) {
return adapt(context_.create<tscalar::immediate>(
return context_.create<tscalar::immediate>(
value.region(),
context_.values().get(tvalue::int4 { *r }),
context_.types().get(ttype::int4 {})));
context_.types().get(ttype::int4 {}));
}
}
return adapt(context_.create<tscalar::immediate>(
return context_.create<tscalar::immediate>(
value.region(),
context_.values().get(tvalue::int8 { *integer }),
context_.types().get(ttype::int8 {})));
context_.types().get(ttype::int8 {}));
}
}
std::optional<std::size_t> precision {};
if (context_.options()->prefer_small_decimal_literals()) {
precision = static_cast<std::size_t>(::decimal::Decimal::radix());
}
return adapt(context_.create<tscalar::immediate>(
return context_.create<tscalar::immediate>(
value.region(),
context_.values().get(tvalue::decimal { ::takatori::decimal::triple { v } }),
context_.types().get(ttype::decimal {
precision,
static_cast<std::size_t>(v.exponent()),
})));
}));
}

[[nodiscard]] std::unique_ptr<tscalar::expression> process_approximate_numeric(ast::literal::numeric const& value) {
[[nodiscard]] std::unique_ptr<tscalar::immediate> process_approximate_numeric(ast::literal::numeric const& value) {
// FIXME: more accurate values for fp
std::size_t index {};
double result {};
Expand Down Expand Up @@ -293,10 +309,10 @@ class engine {
value.region());
return {};
}
return adapt(context_.create<tscalar::immediate>(
return context_.create<tscalar::immediate>(
value.region(),
context_.values().get(tvalue::float8 { result }),
context_.types().get(ttype::float8 {})));
context_.types().get(ttype::float8 {}));
}

[[nodiscard]] std::optional<std::size_t> count_characters(ast::literal::string::value_type const& string) {
Expand Down Expand Up @@ -340,7 +356,7 @@ class engine {
}
}

[[nodiscard]] std::unique_ptr<tscalar::expression> process_character_string(ast::literal::string const& value) {
[[nodiscard]] std::unique_ptr<tscalar::immediate> process_character_string(ast::literal::string const& value) {
std::size_t nchars = 0;
if (auto n = count_characters(value.value())) {
nchars += *n;
Expand All @@ -366,26 +382,13 @@ class engine {
if (context_.options()->prefer_small_character_literals()) {
size = nchars;
}
return adapt(context_.create<tscalar::immediate>(
return context_.create<tscalar::immediate>(
value.region(),
context_.values().get(tvalue::character { std::move(string) }),
context_.types().get(ttype::character {
ttype::varying,
size,
})));
}

[[nodiscard]] std::unique_ptr<tscalar::expression> adapt(std::unique_ptr<tscalar::immediate> expression) {
if (auto t = value_context_.type(); t && *t != expression->type()) {
// NOTE: here, we only apply cast operation to constant values.
// later optimization (if it available) will reduce this operation
return context_.create<tscalar::cast>(
expression->region(),
std::move(t),
tscalar::cast_loss_policy::error,
std::move(expression));
}
return expression;
}));
}
};

Expand Down
45 changes: 43 additions & 2 deletions test/mizugaki/analyzer/details/analyze_literal_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <takatori/type/character.h>

#include <takatori/scalar/immediate.h>
#include <takatori/scalar/cast.h>

#include <mizugaki/ast/literal/numeric.h>
#include <mizugaki/ast/literal/string.h> // NOLINT
Expand Down Expand Up @@ -234,7 +235,7 @@ TEST_F(analyze_literal_test, null_wo_context) {
EXPECT_TRUE(contains(context(), diagnostic_code::missing_context_of_null));
}

TEST_F(analyze_literal_test, dfault) {
TEST_F(analyze_literal_test, default) {
auto r = analyze_literal(
context(),
ast::literal::default_ {},
Expand All @@ -249,7 +250,7 @@ TEST_F(analyze_literal_test, dfault) {
}));
}

TEST_F(analyze_literal_test, dfault_wo_context) {
TEST_F(analyze_literal_test, default_wo_context) {
auto r = analyze_literal(
context(),
ast::literal::default_ {},
Expand All @@ -260,4 +261,44 @@ TEST_F(analyze_literal_test, dfault_wo_context) {
EXPECT_TRUE(contains(context(), diagnostic_code::missing_context_of_default_value));
}

TEST_F(analyze_literal_test, conversion_by_context_enabled) {
options_.cast_literals_in_context() = true;
auto r = analyze_literal(
context(),
ast::literal::string {
ast::literal::kind::character_string,
"'1'",
},
{
ttype::int8 {},
});
ASSERT_TRUE(r);
EXPECT_EQ(*r, (tscalar::cast {
ttype::int8 {},
tscalar::cast_loss_policy::error,
tscalar::immediate {
tvalue::character { "1" },
ttype::character { ttype::varying },
}
}));
}

TEST_F(analyze_literal_test, conversion_by_context_disabled) {
options_.cast_literals_in_context() = false;
auto r = analyze_literal(
context(),
ast::literal::string {
ast::literal::kind::character_string,
"'1'",
},
{
ttype::int8 {},
});
ASSERT_TRUE(r);
EXPECT_EQ(*r, (tscalar::immediate {
tvalue::character { "1" },
ttype::character { ttype::varying },
}));
}

} // namespace mizugaki::analyzer::details

0 comments on commit 7a8134d

Please sign in to comment.