From 7a8134d6003c650058f03f91e827432bb79686a4 Mon Sep 17 00:00:00 2001 From: Suguru ARAKAWA Date: Wed, 3 Jul 2024 15:57:13 +0900 Subject: [PATCH] feat: introduce option `cast_literal_in_context`. --- .../mizugaki/analyzer/sql_analyzer_options.h | 21 +++++ .../analyzer/details/analyze_literal.cpp | 91 ++++++++++--------- .../analyzer/details/analyze_literal_test.cpp | 45 ++++++++- 3 files changed, 111 insertions(+), 46 deletions(-) diff --git a/include/mizugaki/analyzer/sql_analyzer_options.h b/include/mizugaki/analyzer/sql_analyzer_options.h index b51edc2..172d48c 100644 --- a/include/mizugaki/analyzer/sql_analyzer_options.h +++ b/include/mizugaki/analyzer/sql_analyzer_options.h @@ -118,6 +118,12 @@ class sql_analyzer_options { */ static constexpr bool default_validate_scalar_expressions = false; + /** + * @brief the default value of whether or not to automatically cast literals if the context type is specified. + * @see cast_literals_in_context() + */ + static constexpr bool default_cast_literals_in_context = true; + /** * @brief the default value of a function name to advance a sequence value. * @see advance_sequence_function_name() @@ -382,6 +388,20 @@ class sql_analyzer_options { return validate_scalar_expressions_; } + /** + * @brief returns whether or not to automatically cast literals if the context type is specified. + * @return true if inserting cast is enabled + * @return false otherwise + */ + [[nodiscard]] bool& cast_literals_in_context() noexcept { + return cast_literals_in_context_; + } + + /// @copydoc cast_literals_in_context() + [[nodiscard]] bool const& cast_literals_in_context() const noexcept { + return cast_literals_in_context_; + } + /** * @brief returns the function name to advance a sequence value. * @details This function requires a symbol of the target sequence. @@ -419,6 +439,7 @@ class sql_analyzer_options { bool host_parameter_declaration_starts_with_colon_ { default_host_parameter_declaration_starts_with_colon }; bool allow_context_independent_null_ { default_allow_context_independent_null }; bool validate_scalar_expressions_ { default_validate_scalar_expressions }; + bool cast_literals_in_context_ { default_cast_literals_in_context }; std::string_view advance_sequence_function_name_ { default_advance_sequence_function_name }; }; diff --git a/src/mizugaki/analyzer/details/analyze_literal.cpp b/src/mizugaki/analyzer/details/analyze_literal.cpp index df586a8..4543558 100644 --- a/src/mizugaki/analyzer/details/analyze_literal.cpp +++ b/src/mizugaki/analyzer/details/analyze_literal.cpp @@ -35,10 +35,26 @@ class engine { {} [[nodiscard]] std::unique_ptr process(ast::literal::literal const& literal) { - return ast::literal::dispatch(*this, literal); + auto result = ast::literal::dispatch(*this, literal); + if (!result) { + return {}; + } + if (auto&& t = value_context_.type(); + context_.options()->cast_literals_in_context() && + t && *t != result->type()) { + // NOTE: here, we only apply cast operation to constant values. + // later optimization (if it available) will reduce this operation + return context_.create( + result->region(), + t, + tscalar::cast_loss_policy::error, + std::move(result)); + } + return result; + } - std::unique_ptr operator()(ast::literal::literal const& value) { + std::unique_ptr operator()(ast::literal::literal const& value) { context_.report( sql_analyzer_code::unsupported_feature, string_builder {} @@ -49,22 +65,22 @@ class engine { return {}; } - std::unique_ptr operator()(ast::literal::boolean const& value) { + std::unique_ptr operator()(ast::literal::boolean const& value) { if (value.value() == ast::literal::boolean_kind::unknown) { - return adapt(context_.create( + return context_.create( value.region(), context_.values().get(tvalue::unknown {}), - context_.types().get(ttype::boolean {}))); + context_.types().get(ttype::boolean {})); } - return adapt(context_.create( + return context_.create( value.region(), context_.values().get(tvalue::boolean { value.value() == ast::literal::boolean_kind::true_, }), - context_.types().get(ttype::boolean {}))); + context_.types().get(ttype::boolean {})); } - std::unique_ptr operator()(ast::literal::numeric const& value) { + std::unique_ptr operator()(ast::literal::numeric const& value) { using ast::literal::kind; switch (value.node_kind()) { case kind::exact_numeric: return process_exact_numeric(value); @@ -83,7 +99,7 @@ class engine { return {}; } - std::unique_ptr operator()(ast::literal::string const& value) { + std::unique_ptr operator()(ast::literal::string const& value) { using ast::literal::kind; switch (value.node_kind()) { case kind::character_string: return process_character_string(value); @@ -101,7 +117,7 @@ class engine { return {}; } - std::unique_ptr operator()(ast::literal::datetime const& value) { + std::unique_ptr operator()(ast::literal::datetime const& value) { using ast::literal::kind; switch (value.node_kind()) { default: @@ -118,12 +134,12 @@ class engine { return {}; } -// std::unique_ptr operator()(ast::literal::interval const& value) { +// std::unique_ptr operator()(ast::literal::interval const& value) { // (void) value; // return {}; // } - std::unique_ptr operator()(ast::literal::null const& value) { + std::unique_ptr operator()(ast::literal::null const& value) { if (auto t = value_context_.type()) { return context_.create( value.region(), @@ -142,12 +158,12 @@ class engine { return {}; } -// std::unique_ptr operator()(ast::literal::empty const& value) { +// std::unique_ptr operator()(ast::literal::empty const& value) { // (void) value; // return {}; // } - std::unique_ptr operator()(ast::literal::default_ const& value) { + std::unique_ptr operator()(ast::literal::default_ const& value) { auto t = value_context_.type(); auto v = value_context_.default_value(); if (t && v) { @@ -174,7 +190,7 @@ class engine { return {}; } - [[nodiscard]] std::unique_ptr process_exact_numeric(ast::literal::numeric const& value) { // NOLINT(*-function-cognitive-complexity) + [[nodiscard]] std::unique_ptr process_exact_numeric(ast::literal::numeric const& value) { // NOLINT(*-function-cognitive-complexity) // FIXME: move to yugawara auto max_precision = static_cast(context_.options()->max_decimal_precision()); ::decimal::Context context { @@ -221,44 +237,44 @@ class engine { if (integer) { if (context_.options()->prefer_small_integer_literals()) { if (auto r = soft_cast(*integer)) { - return adapt(context_.create( + return context_.create( value.region(), context_.values().get(tvalue::int4 { *r }), - context_.types().get(ttype::int1 {}))); + context_.types().get(ttype::int1 {})); } if (auto r = soft_cast(*integer)) { - return adapt(context_.create( + return context_.create( value.region(), context_.values().get(tvalue::int4 { *r }), - context_.types().get(ttype::int2 {}))); + context_.types().get(ttype::int2 {})); } if (auto r = soft_cast(*integer)) { - return adapt(context_.create( + return context_.create( value.region(), context_.values().get(tvalue::int4 { *r }), - context_.types().get(ttype::int4 {}))); + context_.types().get(ttype::int4 {})); } } - return adapt(context_.create( + return context_.create( value.region(), context_.values().get(tvalue::int8 { *integer }), - context_.types().get(ttype::int8 {}))); + context_.types().get(ttype::int8 {})); } } std::optional precision {}; if (context_.options()->prefer_small_decimal_literals()) { precision = static_cast(::decimal::Decimal::radix()); } - return adapt(context_.create( + return context_.create( value.region(), context_.values().get(tvalue::decimal { ::takatori::decimal::triple { v } }), context_.types().get(ttype::decimal { precision, static_cast(v.exponent()), - }))); + })); } - [[nodiscard]] std::unique_ptr process_approximate_numeric(ast::literal::numeric const& value) { + [[nodiscard]] std::unique_ptr process_approximate_numeric(ast::literal::numeric const& value) { // FIXME: more accurate values for fp std::size_t index {}; double result {}; @@ -293,10 +309,10 @@ class engine { value.region()); return {}; } - return adapt(context_.create( + return context_.create( value.region(), context_.values().get(tvalue::float8 { result }), - context_.types().get(ttype::float8 {}))); + context_.types().get(ttype::float8 {})); } [[nodiscard]] std::optional count_characters(ast::literal::string::value_type const& string) { @@ -340,7 +356,7 @@ class engine { } } - [[nodiscard]] std::unique_ptr process_character_string(ast::literal::string const& value) { + [[nodiscard]] std::unique_ptr process_character_string(ast::literal::string const& value) { std::size_t nchars = 0; if (auto n = count_characters(value.value())) { nchars += *n; @@ -366,26 +382,13 @@ class engine { if (context_.options()->prefer_small_character_literals()) { size = nchars; } - return adapt(context_.create( + return context_.create( value.region(), context_.values().get(tvalue::character { std::move(string) }), context_.types().get(ttype::character { ttype::varying, size, - }))); - } - - [[nodiscard]] std::unique_ptr adapt(std::unique_ptr expression) { - if (auto t = value_context_.type(); t && *t != expression->type()) { - // NOTE: here, we only apply cast operation to constant values. - // later optimization (if it available) will reduce this operation - return context_.create( - expression->region(), - std::move(t), - tscalar::cast_loss_policy::error, - std::move(expression)); - } - return expression; + })); } }; diff --git a/test/mizugaki/analyzer/details/analyze_literal_test.cpp b/test/mizugaki/analyzer/details/analyze_literal_test.cpp index 44ede4b..fd36020 100644 --- a/test/mizugaki/analyzer/details/analyze_literal_test.cpp +++ b/test/mizugaki/analyzer/details/analyze_literal_test.cpp @@ -11,6 +11,7 @@ #include #include +#include #include #include // NOLINT @@ -234,7 +235,7 @@ TEST_F(analyze_literal_test, null_wo_context) { EXPECT_TRUE(contains(context(), diagnostic_code::missing_context_of_null)); } -TEST_F(analyze_literal_test, dfault) { +TEST_F(analyze_literal_test, default) { auto r = analyze_literal( context(), ast::literal::default_ {}, @@ -249,7 +250,7 @@ TEST_F(analyze_literal_test, dfault) { })); } -TEST_F(analyze_literal_test, dfault_wo_context) { +TEST_F(analyze_literal_test, default_wo_context) { auto r = analyze_literal( context(), ast::literal::default_ {}, @@ -260,4 +261,44 @@ TEST_F(analyze_literal_test, dfault_wo_context) { EXPECT_TRUE(contains(context(), diagnostic_code::missing_context_of_default_value)); } +TEST_F(analyze_literal_test, conversion_by_context_enabled) { + options_.cast_literals_in_context() = true; + auto r = analyze_literal( + context(), + ast::literal::string { + ast::literal::kind::character_string, + "'1'", + }, + { + ttype::int8 {}, + }); + ASSERT_TRUE(r); + EXPECT_EQ(*r, (tscalar::cast { + ttype::int8 {}, + tscalar::cast_loss_policy::error, + tscalar::immediate { + tvalue::character { "1" }, + ttype::character { ttype::varying }, + } + })); +} + +TEST_F(analyze_literal_test, conversion_by_context_disabled) { + options_.cast_literals_in_context() = false; + auto r = analyze_literal( + context(), + ast::literal::string { + ast::literal::kind::character_string, + "'1'", + }, + { + ttype::int8 {}, + }); + ASSERT_TRUE(r); + EXPECT_EQ(*r, (tscalar::immediate { + tvalue::character { "1" }, + ttype::character { ttype::varying }, + })); +} + } // namespace mizugaki::analyzer::details