diff --git a/src/fury/encoder/row_encode_trait.h b/src/fury/encoder/row_encode_trait.h index c0eeb9da34..e87d2f1c2b 100644 --- a/src/fury/encoder/row_encode_trait.h +++ b/src/fury/encoder/row_encode_trait.h @@ -70,9 +70,13 @@ inline constexpr bool IsString = template inline constexpr bool IsArray = meta::IsIterable && !IsString; +template inline constexpr bool IsOptional = false; + +template inline constexpr bool IsOptional> = true; + template inline constexpr bool IsClassButNotBuiltin = - std::is_class_v && !(IsString || IsArray); + std::is_class_v && !(IsString || IsArray || IsOptional); inline decltype(auto) GetChildType(RowWriter &writer, int index) { return writer.schema()->field(index)->type(); @@ -140,6 +144,24 @@ struct RowEncodeTrait< } }; +template +struct RowEncodeTrait< + T, std::enable_if_t>>> { + static auto Type() { return RowEncodeTrait::Type(); } + + template ::value, + int> = 0> + static void Write(V &&visitor, const T &value, W &writer, int index) { + if (value) { + RowEncodeTrait::Write(std::forward(visitor), + *value, writer, index); + } else { + writer.SetNullAt(index); + } + } +}; + template struct RowEncodeTrait< T, std::enable_if_t>>> { diff --git a/src/fury/encoder/row_encode_trait_test.cc b/src/fury/encoder/row_encode_trait_test.cc index 8f049efb47..e8a0ee4471 100644 --- a/src/fury/encoder/row_encode_trait_test.cc +++ b/src/fury/encoder/row_encode_trait_test.cc @@ -16,6 +16,7 @@ #include "gtest/gtest.h" #include +#include #include #include "fury/encoder/row_encode_trait.h" @@ -247,6 +248,52 @@ TEST(RowEncodeTrait, ArrayInArray) { ASSERT_EQ(array->GetArray(2)->GetInt32(2), 60); } +struct F { + bool a; + std::optional b; + int c; +}; + +FURY_FIELD_INFO(F, a, b, c); + +TEST(RowEncodeTrait, Optional) { + F x{false, 233, 111}, y{true, std::nullopt, 222}; + + auto schema = encoder::RowEncodeTrait::Type(); + ASSERT_EQ(schema->field(0)->type()->name(), "bool"); + ASSERT_EQ(schema->field(1)->type()->name(), "int32"); + ASSERT_EQ(schema->field(2)->type()->name(), "int32"); + + { + RowWriter writer(encoder::RowEncodeTrait::Schema()); + writer.Reset(); + + encoder::RowEncodeTrait::Write(encoder::EmptyWriteVisitor{}, x, writer); + + auto row = writer.ToRow(); + ASSERT_EQ(row->IsNullAt(0), false); + ASSERT_EQ(row->IsNullAt(1), false); + ASSERT_EQ(row->IsNullAt(2), false); + + ASSERT_EQ(row->GetInt32(1), 233); + ASSERT_EQ(row->GetInt32(2), 111); + } + + { + RowWriter writer(encoder::RowEncodeTrait::Schema()); + writer.Reset(); + + encoder::RowEncodeTrait::Write(encoder::EmptyWriteVisitor{}, y, writer); + + auto row = writer.ToRow(); + ASSERT_EQ(row->IsNullAt(0), false); + ASSERT_EQ(row->IsNullAt(1), true); + ASSERT_EQ(row->IsNullAt(2), false); + + ASSERT_EQ(row->GetInt32(2), 222); + } +} + } // namespace test } // namespace fury