From bba1f266c358feb92159c0017e91c779c71fa4cc Mon Sep 17 00:00:00 2001 From: Arkadiusz Szczepkowicz Date: Tue, 23 Jul 2024 14:03:02 +0200 Subject: [PATCH 01/14] #268: Add tests for deserialization of the polymorphic types --- tests/unit/test_polymorphic.cc | 127 +++++++++++++++++++++++++++ tests/unit/test_virtual_serialize.cc | 22 +++++ 2 files changed, 149 insertions(+) create mode 100644 tests/unit/test_polymorphic.cc diff --git a/tests/unit/test_polymorphic.cc b/tests/unit/test_polymorphic.cc new file mode 100644 index 00000000..b9e22f97 --- /dev/null +++ b/tests/unit/test_polymorphic.cc @@ -0,0 +1,127 @@ +/* +//@HEADER +// ***************************************************************************** +// +// test_polymorphic.cc +// DARMA/checkpoint => Serialization Library +// +// Copyright 2019 National Technology & Engineering Solutions of Sandia, LLC +// (NTESS). Under the terms of Contract DE-NA0003525 with NTESS, the U.S. +// Government retains certain rights in this software. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// * Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// +// * Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// * Neither the name of the copyright holder nor the names of its +// contributors may be used to endorse or promote products derived from this +// software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +// POSSIBILITY OF SUCH DAMAGE. +// +// Questions? Contact darma@sandia.gov +// +// ***************************************************************************** +//@HEADER +*/ +#include + +#include "test_harness.h" + +#include +#include + +namespace checkpoint { namespace tests { namespace unit { + +using TestPolymorphic = TestHarness; + +struct Base { + explicit Base() = default; + explicit Base(int val_in): base_val_(val_in) {}; + virtual ~Base() = default; + + checkpoint_virtual_serialize_root() + + int base_val_; + virtual int getVal() { + return base_val_; + } + + template + void serialize(Serializer& s) { + s | base_val_; + } +}; + +struct Derived1: public Base { + explicit Derived1() = default; + explicit Derived1(int val_in): Base(0), derived_val_(val_in) {}; + virtual ~Derived1() = default; + + checkpoint_virtual_serialize_derived_from(Base) + + int derived_val_; + int getVal() override { + return derived_val_; + } + + template + void serialize(Serializer& s) { + s | derived_val_; + } +}; + +struct Derived2: public Derived1 { + explicit Derived2() = default; + explicit Derived2(int val_in): Derived1(0), derived_val_2_(val_in) {}; + virtual ~Derived2() = default; + + checkpoint_virtual_serialize_derived_from(Derived1) + + int derived_val_2_; + int getVal() override { + return derived_val_2_; + } + + template + void serialize(Serializer& s) { + s | derived_val_2_; + } +}; + +template +void testPolymorphicTypes(int val) { + std::unique_ptr task(new Derived(val)); + auto ret = checkpoint::serialize(*task); + auto out = checkpoint::deserialize(std::move(ret)); + + EXPECT_TRUE(nullptr != out); + EXPECT_EQ(val, out->getVal()); +} + +TEST_F(TestPolymorphic, test_polumorphic_type) { + testPolymorphicTypes(5); + testPolymorphicTypes(50); + testPolymorphicTypes(500); + testPolymorphicTypes(10); + testPolymorphicTypes(100); + testPolymorphicTypes(1); +} + +}}} // end namespace checkpoint::tests::unit diff --git a/tests/unit/test_virtual_serialize.cc b/tests/unit/test_virtual_serialize.cc index 85df6c60..dfe586f3 100644 --- a/tests/unit/test_virtual_serialize.cc +++ b/tests/unit/test_virtual_serialize.cc @@ -611,6 +611,28 @@ INSTANTIATE_TYPED_TEST_CASE_P( test_virtual_serialize_inst, TestVirtualSerialize, ConstructTypes, ); +/* + * Test for deserialization when using the base class type + */ + +using TestDeserializationFromBase = TestHarness; + +template +void testDeserializationFromBase() { + std::unique_ptr task(new Derived(TEST_CONSTRUCT{})); + auto ret = checkpoint::serialize(*task); + auto out = checkpoint::deserialize(std::move(ret)); + + EXPECT_TRUE(nullptr != out); + out->check(); +} + +TEST_F(TestDeserializationFromBase, test_deserilization_from_base) { + testDeserializationFromBase(); + testDeserializationFromBase(); + testDeserializationFromBase(); +} + //////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////// From 2927fc3c56b41b2d7d2f04c8dabb109b6b3135f0 Mon Sep 17 00:00:00 2001 From: Arkadiusz Szczepkowicz Date: Tue, 23 Jul 2024 15:10:39 +0200 Subject: [PATCH 02/14] #268: Add support for deserialization when called from the base type --- src/checkpoint/dispatch/dispatch.h | 50 +++++++++++- src/checkpoint/dispatch/dispatch.impl.h | 103 ++++++++++++++++++++++-- 2 files changed, 143 insertions(+), 10 deletions(-) diff --git a/src/checkpoint/dispatch/dispatch.h b/src/checkpoint/dispatch/dispatch.h index e5c6da1b..46e0a2a3 100644 --- a/src/checkpoint/dispatch/dispatch.h +++ b/src/checkpoint/dispatch/dispatch.h @@ -188,6 +188,42 @@ struct Standard { static SerialByteType* allocate(); }; +/** + * \struct Prefixed + * + * \brief Traversal for polymorphic types prefixed with the vrt::TypeIdx + */ +struct Prefixed { + /** + * \brief Traverse a \c target of type \c T recursively with a general \c + * TraverserT that gets applied to each element. + * Allows to traverse only part of the data. + * + * \param[in,out] target the target to traverse + * \param[in] len the len of the target. If > 1, \c target is an array + * \param[in] check_type the flag to control type validation + * \param[in] check_mem the flag to control memory validation + * \param[in] args the args to pass to the traverser for construction + * + * \return the traverser after traversal is complete + */ + template + static TraverserT traverse(T& target, SerialSizeType len, bool check_type, bool check_mem, Args&&... args); + + /** + * \brief Unpack \c T from packed byte-buffer \c mem + * + * \param[in] t_buf bytes holding a serialized \c T + * \param[in] check_type the flag to control type validation + * \param[in] check_mem the flag to control memory validation + * \param[in] args arguments to the unpacker's constructor + * + * \return a pointer to an unpacked \c T + */ + template + static T* unpack(T* t_buf, bool check_type = true, bool check_mem = true, Args&&... args); +}; + template buffer::ImplReturnType packBuffer( T& target, SerialSizeType size, BufferObtainFnType fn @@ -197,10 +233,20 @@ template inline void serializeArray(Serializer& s, T* array, SerialSizeType const len); template -buffer::ImplReturnType serializeType(T& target, BufferObtainFnType fn = nullptr); +typename std::enable_if::value && vrt::VirtualSerializeTraits::has_virtual_serialize, buffer::ImplReturnType>::type +serializeType(T& target, BufferObtainFnType fn = nullptr); + +template +typename std::enable_if::value || !vrt::VirtualSerializeTraits::has_virtual_serialize, buffer::ImplReturnType>::type +serializeType(T& target, BufferObtainFnType fn = nullptr); + +template +typename std::enable_if::value && vrt::VirtualSerializeTraits::has_virtual_serialize, T*>::type +deserializeType(SerialByteType* data, SerialByteType* allocBuf = nullptr); template -T* deserializeType(SerialByteType* data, SerialByteType* allocBuf = nullptr); +typename std::enable_if::value || !vrt::VirtualSerializeTraits::has_virtual_serialize, T*>::type +deserializeType(SerialByteType* data, SerialByteType* allocBuf = nullptr); template void deserializeType(InPlaceTag, SerialByteType* data, T* t); diff --git a/src/checkpoint/dispatch/dispatch.impl.h b/src/checkpoint/dispatch/dispatch.impl.h index 44aac6a2..d88e3b00 100644 --- a/src/checkpoint/dispatch/dispatch.impl.h +++ b/src/checkpoint/dispatch/dispatch.impl.h @@ -71,7 +71,7 @@ struct serialization_error : public std::runtime_error { }; template -TraverserT& withTypeIdx(TraverserT& t) { +TraverserT& withTypeIdx(TraverserT& t, bool check_type = true) { using CleanT = typename CleanType::CleanT; using DispatchType = typename TraverserT::template DispatcherType; @@ -87,9 +87,9 @@ TraverserT& withTypeIdx(TraverserT& t) { auto val = cleanType(&serTypeIdx); ap(t, val, len); - if ( - typeregistry::validateIndex(serTypeIdx) == false || - thisTypeIdx != serTypeIdx + if (check_type && + (typeregistry::validateIndex(serTypeIdx) == false || + thisTypeIdx != serTypeIdx) ) { auto const err = std::string("Unpacking wrong type, got=") + typeregistry::getTypeNameForIdx(thisTypeIdx) + @@ -105,7 +105,7 @@ TraverserT& withTypeIdx(TraverserT& t) { } template -TraverserT& withMemUsed(TraverserT& t, SerialSizeType len) { +TraverserT& withMemUsed(TraverserT& t, SerialSizeType len, bool check_mem = true) { using DispatchType = typename TraverserT::template DispatcherType; SerializerDispatch ap; @@ -120,7 +120,7 @@ TraverserT& withMemUsed(TraverserT& t, SerialSizeType len) { auto val = cleanType(&serMemUsed); ap(t, val, memUsedLen); - if (memUsed != serMemUsed) { + if (check_mem && memUsed != serMemUsed) { using CleanT = typename CleanType::CleanT; std::string msg = "For type '" + typeregistry::getTypeName() + "' serialization used " + std::to_string(serMemUsed) + @@ -133,6 +133,37 @@ TraverserT& withMemUsed(TraverserT& t, SerialSizeType len) { return t; } +template +TraverserT Prefixed::traverse(T& target, SerialSizeType len, bool check_type, bool check_mem, Args&&... args) { + using CleanT = typename CleanType::CleanT; + using DispatchType = + typename TraverserT::template DispatcherType; + + TraverserT t(std::forward(args)...); + + withTypeIdx(t, check_type); + + auto val = cleanType(&target); + SerializerDispatch ap; + + #if defined(SERIALIZATION_ERROR_CHECKING) + try { + ap(t, val, len); + } catch (serialization_error const& err) { + auto const depth = err.depth_ + 1; + auto const what = std::string(err.what()) + "\n#" + std::to_string(depth) + + " " + typeregistry::getTypeName(); + throw serialization_error(what, depth); + } + #else + ap(t, val, len); + #endif + + withMemUsed(t, 1, check_mem); + + return t; +} + template TraverserT& Traverse::with(T& target, TraverserT& t, SerialSizeType len) { using CleanT = typename CleanType::CleanT; @@ -221,6 +252,12 @@ T* Standard::unpack(T* t_buf, Args&&... args) { return t_buf; } +template +T* Prefixed::unpack(T* t_buf, bool check_type, bool check_mem, Args&&... args) { + Prefixed::traverse(*t_buf, 1, check_type, check_mem, std::forward(args)...); + return t_buf; +} + template T* Standard::construct(SerialByteType* mem) { return Traverse::reconstruct(mem); @@ -266,14 +303,17 @@ packBuffer(T& target, SerialSizeType size, BufferObtainFnType fn) { } template -buffer::ImplReturnType serializeType(T& target, BufferObtainFnType fn) { +typename std::enable_if::value || !vrt::VirtualSerializeTraits::has_virtual_serialize, buffer::ImplReturnType>::type +serializeType(T& target, BufferObtainFnType fn) { auto len = Standard::size(target); debug_checkpoint("serializeType: len=%ld\n", len); return packBuffer(target, len, fn); } template -T* deserializeType(SerialByteType* data, SerialByteType* allocBuf) { +typename std::enable_if::value + || !vrt::VirtualSerializeTraits::has_virtual_serialize, T*>::type +deserializeType(SerialByteType* data, SerialByteType* allocBuf) { auto mem = allocBuf ? allocBuf : Standard::allocate(); auto t_buf = std::unique_ptr(Standard::construct(mem)); T* traverser = @@ -282,11 +322,58 @@ T* deserializeType(SerialByteType* data, SerialByteType* allocBuf) { return traverser; } +// TODO: this also needs to be updated template void deserializeType(InPlaceTag, SerialByteType* data, T* t) { Standard::unpack>(t, data); } +template +struct PrefixedType { + explicit PrefixedType(T* target) : target_(target) { + prefix_ = target->_checkpointDynamicTypeIndex(); + } + + vrt::TypeIdx prefix_; + T* target_; + + template + void serialize(SerializerT& s) { + s | prefix_; + s | *target_; + } +}; + +template +typename std::enable_if::value && vrt::VirtualSerializeTraits::has_virtual_serialize, buffer::ImplReturnType>::type +serializeType(T& target, BufferObtainFnType fn) { + auto prefixed = PrefixedType(&target); + auto len = Standard::size, Sizer>(prefixed); + debug_checkpoint("serializeType: len=%ld\n", len); + return packBuffer>(prefixed, len, fn); +} + +template +typename std::enable_if::value + && vrt::VirtualSerializeTraits::has_virtual_serialize, T*>::type +deserializeType(SerialByteType* data, SerialByteType* allocBuf) { + using BaseType = vrt::checkpoint_base_type_t; + + auto prefix_mem = Standard::allocate(); + auto prefix_buf = std::unique_ptr(Standard::construct(prefix_mem)); + vrt::TypeIdx* prefix = + Prefixed::unpack>(prefix_buf.get(), false, false, data); + prefix_buf.release(); + + auto mem = allocBuf ? allocBuf : vrt::objregistry::allocateConcreteType(*prefix); + auto t_buf = vrt::objregistry::constructConcreteType(*prefix, mem); + auto prefixed = PrefixedType(t_buf); + + auto* traverser = + Prefixed::unpack>(&prefixed, false, true, data); + return static_cast(traverser->target_); +} + }} /* end namespace checkpoint::dispatch */ #endif /*INCLUDED_SRC_CHECKPOINT_DISPATCH_DISPATCH_IMPL_H*/ From d1a6d6c388aaec54f58f611808c7de191ed2b27b Mon Sep 17 00:00:00 2001 From: Arkadiusz Szczepkowicz Date: Tue, 23 Jul 2024 15:48:43 +0200 Subject: [PATCH 03/14] #268: Add comments in the deserializeType method --- src/checkpoint/dispatch/dispatch.h | 4 ++-- src/checkpoint/dispatch/dispatch.impl.h | 27 ++++++++++++++++--------- 2 files changed, 19 insertions(+), 12 deletions(-) diff --git a/src/checkpoint/dispatch/dispatch.h b/src/checkpoint/dispatch/dispatch.h index 46e0a2a3..16f87129 100644 --- a/src/checkpoint/dispatch/dispatch.h +++ b/src/checkpoint/dispatch/dispatch.h @@ -213,7 +213,7 @@ struct Prefixed { /** * \brief Unpack \c T from packed byte-buffer \c mem * - * \param[in] t_buf bytes holding a serialized \c T + * \param[in] mem bytes holding a serialized \c T * \param[in] check_type the flag to control type validation * \param[in] check_mem the flag to control memory validation * \param[in] args arguments to the unpacker's constructor @@ -221,7 +221,7 @@ struct Prefixed { * \return a pointer to an unpacked \c T */ template - static T* unpack(T* t_buf, bool check_type = true, bool check_mem = true, Args&&... args); + static T* unpack(T* mem, bool check_type = true, bool check_mem = true, Args&&... args); }; template diff --git a/src/checkpoint/dispatch/dispatch.impl.h b/src/checkpoint/dispatch/dispatch.impl.h index d88e3b00..cfadb5d1 100644 --- a/src/checkpoint/dispatch/dispatch.impl.h +++ b/src/checkpoint/dispatch/dispatch.impl.h @@ -303,7 +303,9 @@ packBuffer(T& target, SerialSizeType size, BufferObtainFnType fn) { } template -typename std::enable_if::value || !vrt::VirtualSerializeTraits::has_virtual_serialize, buffer::ImplReturnType>::type +typename std::enable_if< + !std::is_class::value || !vrt::VirtualSerializeTraits::has_virtual_serialize, + buffer::ImplReturnType>::type serializeType(T& target, BufferObtainFnType fn) { auto len = Standard::size(target); debug_checkpoint("serializeType: len=%ld\n", len); @@ -311,8 +313,9 @@ serializeType(T& target, BufferObtainFnType fn) { } template -typename std::enable_if::value - || !vrt::VirtualSerializeTraits::has_virtual_serialize, T*>::type +typename std::enable_if< + !std::is_class::value || !vrt::VirtualSerializeTraits::has_virtual_serialize, + T*>::type deserializeType(SerialByteType* data, SerialByteType* allocBuf) { auto mem = allocBuf ? allocBuf : Standard::allocate(); auto t_buf = std::unique_ptr(Standard::construct(mem)); @@ -322,7 +325,6 @@ deserializeType(SerialByteType* data, SerialByteType* allocBuf) { return traverser; } -// TODO: this also needs to be updated template void deserializeType(InPlaceTag, SerialByteType* data, T* t) { Standard::unpack>(t, data); @@ -345,30 +347,35 @@ struct PrefixedType { }; template -typename std::enable_if::value && vrt::VirtualSerializeTraits::has_virtual_serialize, buffer::ImplReturnType>::type +typename std::enable_if< + std::is_class::value && vrt::VirtualSerializeTraits::has_virtual_serialize, + buffer::ImplReturnType>::type serializeType(T& target, BufferObtainFnType fn) { auto prefixed = PrefixedType(&target); - auto len = Standard::size, Sizer>(prefixed); + auto len = Standard::size(prefixed); debug_checkpoint("serializeType: len=%ld\n", len); - return packBuffer>(prefixed, len, fn); + return packBuffer(prefixed, len, fn); } template -typename std::enable_if::value - && vrt::VirtualSerializeTraits::has_virtual_serialize, T*>::type +typename std::enable_if< + std::is_class::value && vrt::VirtualSerializeTraits::has_virtual_serialize, + T*>::type deserializeType(SerialByteType* data, SerialByteType* allocBuf) { using BaseType = vrt::checkpoint_base_type_t; auto prefix_mem = Standard::allocate(); auto prefix_buf = std::unique_ptr(Standard::construct(prefix_mem)); + // Unpack TypeIdx, ignore checks for type and memory used - unpacking will only use a part of the data vrt::TypeIdx* prefix = Prefixed::unpack>(prefix_buf.get(), false, false, data); prefix_buf.release(); + // allocate memory based on the readed TypeIdx auto mem = allocBuf ? allocBuf : vrt::objregistry::allocateConcreteType(*prefix); auto t_buf = vrt::objregistry::constructConcreteType(*prefix, mem); auto prefixed = PrefixedType(t_buf); - + // Unpack PrefixedType, ignore checks for unpacked type and execute checks for memory used auto* traverser = Prefixed::unpack>(&prefixed, false, true, data); return static_cast(traverser->target_); From de5058395b6caed02a914cb1a18a10046c8f973b Mon Sep 17 00:00:00 2001 From: Arkadiusz Szczepkowicz Date: Tue, 23 Jul 2024 16:04:38 +0200 Subject: [PATCH 04/14] #268: Remove default parameters from Prefixed::unpack method --- src/checkpoint/dispatch/dispatch.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/checkpoint/dispatch/dispatch.h b/src/checkpoint/dispatch/dispatch.h index 16f87129..67193ea9 100644 --- a/src/checkpoint/dispatch/dispatch.h +++ b/src/checkpoint/dispatch/dispatch.h @@ -221,7 +221,7 @@ struct Prefixed { * \return a pointer to an unpacked \c T */ template - static T* unpack(T* mem, bool check_type = true, bool check_mem = true, Args&&... args); + static T* unpack(T* mem, bool check_type, bool check_mem, Args&&... args); }; template From 57a77ba9df892c7d6e5e25810e70f1f113e08e5c Mon Sep 17 00:00:00 2001 From: Arkadiusz Szczepkowicz Date: Tue, 23 Jul 2024 17:17:21 +0200 Subject: [PATCH 05/14] #268: Add prefix validation to protect against requests with invalid types --- src/checkpoint/dispatch/dispatch.h | 16 +++++++++++---- src/checkpoint/dispatch/dispatch.impl.h | 20 +++++++++++++++---- src/checkpoint/dispatch/vrt/object_registry.h | 5 +++++ tests/unit/test_polymorphic.cc | 2 +- 4 files changed, 34 insertions(+), 9 deletions(-) diff --git a/src/checkpoint/dispatch/dispatch.h b/src/checkpoint/dispatch/dispatch.h index 67193ea9..8470a0e1 100644 --- a/src/checkpoint/dispatch/dispatch.h +++ b/src/checkpoint/dispatch/dispatch.h @@ -222,6 +222,14 @@ struct Prefixed { */ template static T* unpack(T* mem, bool check_type, bool check_mem, Args&&... args); + + /** + * \brief Check if prefix is valid + * + * \param[in] prefix the prefix to be validated + */ + template + static void validatePrefix(vrt::TypeIdx prefix); }; template @@ -233,19 +241,19 @@ template inline void serializeArray(Serializer& s, T* array, SerialSizeType const len); template -typename std::enable_if::value && vrt::VirtualSerializeTraits::has_virtual_serialize, buffer::ImplReturnType>::type +typename std::enable_if::has_virtual_serialize, buffer::ImplReturnType>::type serializeType(T& target, BufferObtainFnType fn = nullptr); template -typename std::enable_if::value || !vrt::VirtualSerializeTraits::has_virtual_serialize, buffer::ImplReturnType>::type +typename std::enable_if::has_virtual_serialize, buffer::ImplReturnType>::type serializeType(T& target, BufferObtainFnType fn = nullptr); template -typename std::enable_if::value && vrt::VirtualSerializeTraits::has_virtual_serialize, T*>::type +typename std::enable_if::has_virtual_serialize, T*>::type deserializeType(SerialByteType* data, SerialByteType* allocBuf = nullptr); template -typename std::enable_if::value || !vrt::VirtualSerializeTraits::has_virtual_serialize, T*>::type +typename std::enable_if::has_virtual_serialize, T*>::type deserializeType(SerialByteType* data, SerialByteType* allocBuf = nullptr); template diff --git a/src/checkpoint/dispatch/dispatch.impl.h b/src/checkpoint/dispatch/dispatch.impl.h index cfadb5d1..35a1db3d 100644 --- a/src/checkpoint/dispatch/dispatch.impl.h +++ b/src/checkpoint/dispatch/dispatch.impl.h @@ -304,7 +304,7 @@ packBuffer(T& target, SerialSizeType size, BufferObtainFnType fn) { template typename std::enable_if< - !std::is_class::value || !vrt::VirtualSerializeTraits::has_virtual_serialize, + !vrt::VirtualSerializeTraits::has_virtual_serialize, buffer::ImplReturnType>::type serializeType(T& target, BufferObtainFnType fn) { auto len = Standard::size(target); @@ -314,7 +314,7 @@ serializeType(T& target, BufferObtainFnType fn) { template typename std::enable_if< - !std::is_class::value || !vrt::VirtualSerializeTraits::has_virtual_serialize, + !vrt::VirtualSerializeTraits::has_virtual_serialize, T*>::type deserializeType(SerialByteType* data, SerialByteType* allocBuf) { auto mem = allocBuf ? allocBuf : Standard::allocate(); @@ -348,7 +348,7 @@ struct PrefixedType { template typename std::enable_if< - std::is_class::value && vrt::VirtualSerializeTraits::has_virtual_serialize, + vrt::VirtualSerializeTraits::has_virtual_serialize, buffer::ImplReturnType>::type serializeType(T& target, BufferObtainFnType fn) { auto prefixed = PrefixedType(&target); @@ -357,9 +357,19 @@ serializeType(T& target, BufferObtainFnType fn) { return packBuffer(prefixed, len, fn); } +template +void Prefixed::validatePrefix(vrt::TypeIdx prefix) { + if (!vrt::objregistry::isValidIdx(prefix)) { + std::string const err = std::string("Unpacking invalid prefix type (") + + std::to_string(prefix) + std::string(") from object registry for type=") + + std::string(typeregistry::getTypeName()); + throw serialization_error(err); + } +} + template typename std::enable_if< - std::is_class::value && vrt::VirtualSerializeTraits::has_virtual_serialize, + vrt::VirtualSerializeTraits::has_virtual_serialize, T*>::type deserializeType(SerialByteType* data, SerialByteType* allocBuf) { using BaseType = vrt::checkpoint_base_type_t; @@ -371,6 +381,8 @@ deserializeType(SerialByteType* data, SerialByteType* allocBuf) { Prefixed::unpack>(prefix_buf.get(), false, false, data); prefix_buf.release(); + Prefixed::validatePrefix(*prefix); + // allocate memory based on the readed TypeIdx auto mem = allocBuf ? allocBuf : vrt::objregistry::allocateConcreteType(*prefix); auto t_buf = vrt::objregistry::constructConcreteType(*prefix, mem); diff --git a/src/checkpoint/dispatch/vrt/object_registry.h b/src/checkpoint/dispatch/vrt/object_registry.h index 54da0dea..e92c26f3 100644 --- a/src/checkpoint/dispatch/vrt/object_registry.h +++ b/src/checkpoint/dispatch/vrt/object_registry.h @@ -124,6 +124,11 @@ inline auto getObjIdx(TypeIdx han) { return getRegistry().at(han).idx_; } +template +inline auto isValidIdx(TypeIdx han) { + return getRegistry().size() > static_cast(han); +} + template inline auto getSizeConcreteType(TypeIdx han) { return getRegistry().at(han).size_; diff --git a/tests/unit/test_polymorphic.cc b/tests/unit/test_polymorphic.cc index b9e22f97..57944fd8 100644 --- a/tests/unit/test_polymorphic.cc +++ b/tests/unit/test_polymorphic.cc @@ -115,7 +115,7 @@ void testPolymorphicTypes(int val) { EXPECT_EQ(val, out->getVal()); } -TEST_F(TestPolymorphic, test_polumorphic_type) { +TEST_F(TestPolymorphic, test_polymorphic_type) { testPolymorphicTypes(5); testPolymorphicTypes(50); testPolymorphicTypes(500); From 0aab79218a8d43a27123ef84b2b223d035ae6882 Mon Sep 17 00:00:00 2001 From: Arkadiusz Szczepkowicz Date: Wed, 24 Jul 2024 19:22:30 +0200 Subject: [PATCH 06/14] #268: Refactor dispatcher to properly deserialize polymorphic types with/without serialization error checking --- src/checkpoint/dispatch/dispatch.h | 44 ---------- src/checkpoint/dispatch/dispatch.impl.h | 90 ++++++-------------- src/checkpoint/serializers/base_serializer.h | 7 ++ src/checkpoint/serializers/unpacker.h | 6 +- src/checkpoint/serializers/unpacker.impl.h | 10 ++- 5 files changed, 47 insertions(+), 110 deletions(-) diff --git a/src/checkpoint/dispatch/dispatch.h b/src/checkpoint/dispatch/dispatch.h index 8470a0e1..91dba3d6 100644 --- a/src/checkpoint/dispatch/dispatch.h +++ b/src/checkpoint/dispatch/dispatch.h @@ -188,50 +188,6 @@ struct Standard { static SerialByteType* allocate(); }; -/** - * \struct Prefixed - * - * \brief Traversal for polymorphic types prefixed with the vrt::TypeIdx - */ -struct Prefixed { - /** - * \brief Traverse a \c target of type \c T recursively with a general \c - * TraverserT that gets applied to each element. - * Allows to traverse only part of the data. - * - * \param[in,out] target the target to traverse - * \param[in] len the len of the target. If > 1, \c target is an array - * \param[in] check_type the flag to control type validation - * \param[in] check_mem the flag to control memory validation - * \param[in] args the args to pass to the traverser for construction - * - * \return the traverser after traversal is complete - */ - template - static TraverserT traverse(T& target, SerialSizeType len, bool check_type, bool check_mem, Args&&... args); - - /** - * \brief Unpack \c T from packed byte-buffer \c mem - * - * \param[in] mem bytes holding a serialized \c T - * \param[in] check_type the flag to control type validation - * \param[in] check_mem the flag to control memory validation - * \param[in] args arguments to the unpacker's constructor - * - * \return a pointer to an unpacked \c T - */ - template - static T* unpack(T* mem, bool check_type, bool check_mem, Args&&... args); - - /** - * \brief Check if prefix is valid - * - * \param[in] prefix the prefix to be validated - */ - template - static void validatePrefix(vrt::TypeIdx prefix); -}; - template buffer::ImplReturnType packBuffer( T& target, SerialSizeType size, BufferObtainFnType fn diff --git a/src/checkpoint/dispatch/dispatch.impl.h b/src/checkpoint/dispatch/dispatch.impl.h index 35a1db3d..efa3d659 100644 --- a/src/checkpoint/dispatch/dispatch.impl.h +++ b/src/checkpoint/dispatch/dispatch.impl.h @@ -71,7 +71,7 @@ struct serialization_error : public std::runtime_error { }; template -TraverserT& withTypeIdx(TraverserT& t, bool check_type = true) { +TraverserT& withTypeIdx(TraverserT& t) { using CleanT = typename CleanType::CleanT; using DispatchType = typename TraverserT::template DispatcherType; @@ -87,9 +87,9 @@ TraverserT& withTypeIdx(TraverserT& t, bool check_type = true) { auto val = cleanType(&serTypeIdx); ap(t, val, len); - if (check_type && - (typeregistry::validateIndex(serTypeIdx) == false || - thisTypeIdx != serTypeIdx) + if ( + typeregistry::validateIndex(serTypeIdx) == false || + thisTypeIdx != serTypeIdx ) { auto const err = std::string("Unpacking wrong type, got=") + typeregistry::getTypeNameForIdx(thisTypeIdx) + @@ -105,7 +105,7 @@ TraverserT& withTypeIdx(TraverserT& t, bool check_type = true) { } template -TraverserT& withMemUsed(TraverserT& t, SerialSizeType len, bool check_mem = true) { +TraverserT& withMemUsed(TraverserT& t, SerialSizeType len) { using DispatchType = typename TraverserT::template DispatcherType; SerializerDispatch ap; @@ -120,7 +120,7 @@ TraverserT& withMemUsed(TraverserT& t, SerialSizeType len, bool check_mem = true auto val = cleanType(&serMemUsed); ap(t, val, memUsedLen); - if (check_mem && memUsed != serMemUsed) { + if (t.shouldValidateMemory() && memUsed != serMemUsed) { using CleanT = typename CleanType::CleanT; std::string msg = "For type '" + typeregistry::getTypeName() + "' serialization used " + std::to_string(serMemUsed) + @@ -133,37 +133,6 @@ TraverserT& withMemUsed(TraverserT& t, SerialSizeType len, bool check_mem = true return t; } -template -TraverserT Prefixed::traverse(T& target, SerialSizeType len, bool check_type, bool check_mem, Args&&... args) { - using CleanT = typename CleanType::CleanT; - using DispatchType = - typename TraverserT::template DispatcherType; - - TraverserT t(std::forward(args)...); - - withTypeIdx(t, check_type); - - auto val = cleanType(&target); - SerializerDispatch ap; - - #if defined(SERIALIZATION_ERROR_CHECKING) - try { - ap(t, val, len); - } catch (serialization_error const& err) { - auto const depth = err.depth_ + 1; - auto const what = std::string(err.what()) + "\n#" + std::to_string(depth) + - " " + typeregistry::getTypeName(); - throw serialization_error(what, depth); - } - #else - ap(t, val, len); - #endif - - withMemUsed(t, 1, check_mem); - - return t; -} - template TraverserT& Traverse::with(T& target, TraverserT& t, SerialSizeType len) { using CleanT = typename CleanType::CleanT; @@ -252,12 +221,6 @@ T* Standard::unpack(T* t_buf, Args&&... args) { return t_buf; } -template -T* Prefixed::unpack(T* t_buf, bool check_type, bool check_mem, Args&&... args) { - Prefixed::traverse(*t_buf, 1, check_type, check_mem, std::forward(args)...); - return t_buf; -} - template T* Standard::construct(SerialByteType* mem) { return Traverse::reconstruct(mem); @@ -313,9 +276,7 @@ serializeType(T& target, BufferObtainFnType fn) { } template -typename std::enable_if< - !vrt::VirtualSerializeTraits::has_virtual_serialize, - T*>::type +typename std::enable_if::has_virtual_serialize, T*>::type deserializeType(SerialByteType* data, SerialByteType* allocBuf) { auto mem = allocBuf ? allocBuf : Standard::allocate(); auto t_buf = std::unique_ptr(Standard::construct(mem)); @@ -351,14 +312,17 @@ typename std::enable_if< vrt::VirtualSerializeTraits::has_virtual_serialize, buffer::ImplReturnType>::type serializeType(T& target, BufferObtainFnType fn) { + using BaseType = vrt::checkpoint_base_type_t; + using PrefixedType = PrefixedType; + auto prefixed = PrefixedType(&target); - auto len = Standard::size(prefixed); + auto len = Standard::size(prefixed); debug_checkpoint("serializeType: len=%ld\n", len); - return packBuffer(prefixed, len, fn); + return packBuffer(prefixed, len, fn); } template -void Prefixed::validatePrefix(vrt::TypeIdx prefix) { +void validatePrefix(vrt::TypeIdx prefix) { if (!vrt::objregistry::isValidIdx(prefix)) { std::string const err = std::string("Unpacking invalid prefix type (") + std::to_string(prefix) + std::string(") from object registry for type=") + @@ -368,28 +332,28 @@ void Prefixed::validatePrefix(vrt::TypeIdx prefix) { } template -typename std::enable_if< - vrt::VirtualSerializeTraits::has_virtual_serialize, - T*>::type +typename std::enable_if::has_virtual_serialize, T*>::type deserializeType(SerialByteType* data, SerialByteType* allocBuf) { using BaseType = vrt::checkpoint_base_type_t; + using PrefixedType = PrefixedType; - auto prefix_mem = Standard::allocate(); - auto prefix_buf = std::unique_ptr(Standard::construct(prefix_mem)); - // Unpack TypeIdx, ignore checks for type and memory used - unpacking will only use a part of the data - vrt::TypeIdx* prefix = - Prefixed::unpack>(prefix_buf.get(), false, false, data); - prefix_buf.release(); + auto prefix_mem = allocBuf ? allocBuf : vrt::objregistry::allocateConcreteType(0); + auto prefix_buf = vrt::objregistry::constructConcreteType(0, prefix_mem); + auto prefix_struct = PrefixedType(prefix_buf); + // Disable memory check during first unpacking. + // Unpacking BaseType will always result in memory amount missmatch between serialization/deserialization + auto* prefix = + Standard::unpack>(&prefix_struct, data, false); + delete prefix_buf; - Prefixed::validatePrefix(*prefix); + validatePrefix(prefix->prefix_); // allocate memory based on the readed TypeIdx - auto mem = allocBuf ? allocBuf : vrt::objregistry::allocateConcreteType(*prefix); - auto t_buf = vrt::objregistry::constructConcreteType(*prefix, mem); + auto mem = allocBuf ? allocBuf : vrt::objregistry::allocateConcreteType(prefix->prefix_); + auto t_buf = vrt::objregistry::constructConcreteType(prefix->prefix_, mem); auto prefixed = PrefixedType(t_buf); - // Unpack PrefixedType, ignore checks for unpacked type and execute checks for memory used auto* traverser = - Prefixed::unpack>(&prefixed, false, true, data); + Standard::unpack>(&prefixed, data); return static_cast(traverser->target_); } diff --git a/src/checkpoint/serializers/base_serializer.h b/src/checkpoint/serializers/base_serializer.h index e64251cc..5bcd5904 100644 --- a/src/checkpoint/serializers/base_serializer.h +++ b/src/checkpoint/serializers/base_serializer.h @@ -203,6 +203,13 @@ struct BaseSerializer { */ void setVirtualDisabled(bool val) { virtual_disabled_ = val; } + /** + * \brief Check if used memory should be validated + * + * \return whether memory should be validated + */ + bool shouldValidateMemory() const { return true; } + protected: ModeType cur_mode_ = ModeType::Invalid; /**< The current mode */ bool virtual_disabled_ = false; /**< Virtual serialization disabled */ diff --git a/src/checkpoint/serializers/unpacker.h b/src/checkpoint/serializers/unpacker.h index 85e9a850..4c31f82b 100644 --- a/src/checkpoint/serializers/unpacker.h +++ b/src/checkpoint/serializers/unpacker.h @@ -54,7 +54,7 @@ template struct UnpackerBuffer : MemorySerializer { using BufferPtrType = std::unique_ptr; - explicit UnpackerBuffer(SerialByteType* buf); + explicit UnpackerBuffer(SerialByteType* buf, bool validate = true); template explicit UnpackerBuffer(Args&&... args); @@ -62,11 +62,15 @@ struct UnpackerBuffer : MemorySerializer { void contiguousBytes(void* ptr, SerialSizeType size, SerialSizeType num_elms); SerialSizeType usedBufferSize() const; + bool shouldValidateMemory() const; + private: // Size of the actually used memory (for error checking) SerialSizeType usedSize_ = 0; BufferPtrType buffer_ = nullptr; + + bool validate_memory_ = true; }; using Unpacker = UnpackerBuffer; diff --git a/src/checkpoint/serializers/unpacker.impl.h b/src/checkpoint/serializers/unpacker.impl.h index ad2429b2..73bab01e 100644 --- a/src/checkpoint/serializers/unpacker.impl.h +++ b/src/checkpoint/serializers/unpacker.impl.h @@ -54,9 +54,10 @@ namespace checkpoint { template -UnpackerBuffer::UnpackerBuffer(SerialByteType* buf) +UnpackerBuffer::UnpackerBuffer(SerialByteType* buf, bool validate) : MemorySerializer(ModeType::Unpacking), - buffer_(std::make_unique(buf, 0)) + buffer_(std::make_unique(buf, 0)), + validate_memory_(validate) { MemorySerializer::initializeBuffer(buffer_->getBuffer()); @@ -103,6 +104,11 @@ SerialSizeType UnpackerBuffer::usedBufferSize() const { return usedSize_; } +template +bool UnpackerBuffer::shouldValidateMemory() const { + return validate_memory_; +} + } /* end namespace checkpoint */ #endif /*INCLUDED_SRC_CHECKPOINT_SERIALIZERS_UNPACKER_IMPL_H*/ From 468af6ad4458c00588bd34503199e7bb434e13f9 Mon Sep 17 00:00:00 2001 From: Arkadiusz Szczepkowicz Date: Wed, 24 Jul 2024 19:31:06 +0200 Subject: [PATCH 07/14] #268: Add check for type id in virtual serialize tests --- tests/unit/test_virtual_serialize.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/unit/test_virtual_serialize.cc b/tests/unit/test_virtual_serialize.cc index dfe586f3..f7967511 100644 --- a/tests/unit/test_virtual_serialize.cc +++ b/tests/unit/test_virtual_serialize.cc @@ -624,6 +624,7 @@ void testDeserializationFromBase() { auto out = checkpoint::deserialize(std::move(ret)); EXPECT_TRUE(nullptr != out); + EXPECT_EQ(TestEnum::Derived2, out->getID()); out->check(); } From a2e7abe93a8c66645da6ae526081f3d26f806bc0 Mon Sep 17 00:00:00 2001 From: Arkadiusz Szczepkowicz Date: Wed, 24 Jul 2024 20:51:00 +0200 Subject: [PATCH 08/14] #268: Add validate parameter to all constructors in Unpacker --- src/checkpoint/serializers/unpacker.impl.h | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/checkpoint/serializers/unpacker.impl.h b/src/checkpoint/serializers/unpacker.impl.h index 73bab01e..23faaeb7 100644 --- a/src/checkpoint/serializers/unpacker.impl.h +++ b/src/checkpoint/serializers/unpacker.impl.h @@ -72,7 +72,8 @@ template template UnpackerBuffer::UnpackerBuffer(Args&&... args) : MemorySerializer(ModeType::Unpacking), - buffer_(std::make_unique(std::forward(args)...)) + buffer_(std::make_unique(std::forward(args)...)), + validate_memory_(true) { MemorySerializer::initializeBuffer(buffer_->getBuffer()); From 35e050cdffe87a17e53cba0f6d106660e6e991ad Mon Sep 17 00:00:00 2001 From: Arkadiusz Szczepkowicz Date: Wed, 24 Jul 2024 20:51:23 +0200 Subject: [PATCH 09/14] #268: Add check for type id in virtual serialize test --- tests/unit/test_virtual_serialize.cc | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/tests/unit/test_virtual_serialize.cc b/tests/unit/test_virtual_serialize.cc index f7967511..2d881ad8 100644 --- a/tests/unit/test_virtual_serialize.cc +++ b/tests/unit/test_virtual_serialize.cc @@ -618,20 +618,21 @@ INSTANTIATE_TYPED_TEST_CASE_P( using TestDeserializationFromBase = TestHarness; template -void testDeserializationFromBase() { +void testDeserializationFromBase(TestEnum expected_id) { std::unique_ptr task(new Derived(TEST_CONSTRUCT{})); auto ret = checkpoint::serialize(*task); auto out = checkpoint::deserialize(std::move(ret)); EXPECT_TRUE(nullptr != out); - EXPECT_EQ(TestEnum::Derived2, out->getID()); + EXPECT_EQ(expected_id, out->getID()); out->check(); } -TEST_F(TestDeserializationFromBase, test_deserilization_from_base) { - testDeserializationFromBase(); - testDeserializationFromBase(); - testDeserializationFromBase(); +TEST_F(TestDeserializationFromBase, test_deserialization_from_base) { + testDeserializationFromBase( + TestEnum::Derived3); + testDeserializationFromBase( + TestEnum::Derived2); } //////////////////////////////////////////////////////////////////////////////// From 7f78f7313ee7785eecd523875b1b1a0c40307b38 Mon Sep 17 00:00:00 2001 From: Arkadiusz Szczepkowicz Date: Wed, 24 Jul 2024 21:34:47 +0200 Subject: [PATCH 10/14] #268: Move delete to the bottom of deserializeType function --- src/checkpoint/dispatch/dispatch.impl.h | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/checkpoint/dispatch/dispatch.impl.h b/src/checkpoint/dispatch/dispatch.impl.h index efa3d659..fc6a5dc3 100644 --- a/src/checkpoint/dispatch/dispatch.impl.h +++ b/src/checkpoint/dispatch/dispatch.impl.h @@ -342,9 +342,7 @@ deserializeType(SerialByteType* data, SerialByteType* allocBuf) { auto prefix_struct = PrefixedType(prefix_buf); // Disable memory check during first unpacking. // Unpacking BaseType will always result in memory amount missmatch between serialization/deserialization - auto* prefix = - Standard::unpack>(&prefix_struct, data, false); - delete prefix_buf; + auto* prefix = Standard::unpack>(&prefix_struct, data, false); validatePrefix(prefix->prefix_); @@ -352,8 +350,9 @@ deserializeType(SerialByteType* data, SerialByteType* allocBuf) { auto mem = allocBuf ? allocBuf : vrt::objregistry::allocateConcreteType(prefix->prefix_); auto t_buf = vrt::objregistry::constructConcreteType(prefix->prefix_, mem); auto prefixed = PrefixedType(t_buf); - auto* traverser = - Standard::unpack>(&prefixed, data); + auto* traverser = Standard::unpack>(&prefixed, data); + + delete prefix_buf; return static_cast(traverser->target_); } From 148b93cbb36b5a2f27d85782d0ac1ce1f88322fb Mon Sep 17 00:00:00 2001 From: Arkadiusz Szczepkowicz Date: Tue, 30 Jul 2024 15:29:38 +0200 Subject: [PATCH 11/14] #268: Remove the need for the second pass when deserializing a polymorphic type --- src/checkpoint/dispatch/dispatch.impl.h | 56 +++++++++++++------------ 1 file changed, 29 insertions(+), 27 deletions(-) diff --git a/src/checkpoint/dispatch/dispatch.impl.h b/src/checkpoint/dispatch/dispatch.impl.h index fc6a5dc3..c5d4afc6 100644 --- a/src/checkpoint/dispatch/dispatch.impl.h +++ b/src/checkpoint/dispatch/dispatch.impl.h @@ -291,18 +291,44 @@ void deserializeType(InPlaceTag, SerialByteType* data, T* t) { Standard::unpack>(t, data); } +template +void validatePrefix(vrt::TypeIdx prefix) { + if (!vrt::objregistry::isValidIdx(prefix)) { + std::string const err = std::string("Unpacking invalid prefix type (") + + std::to_string(prefix) + std::string(") from object registry for type=") + + std::string(typeregistry::getTypeName()); + throw serialization_error(err); + } +} + template struct PrefixedType { + using BaseType = vrt::checkpoint_base_type_t; + explicit PrefixedType(T* target) : target_(target) { prefix_ = target->_checkpointDynamicTypeIndex(); } - vrt::TypeIdx prefix_; - T* target_; + explicit PrefixedType(SerialByteType* allocBuf) + : unpack_buf_(allocBuf) { + } + + vrt::TypeIdx prefix_ = 0; + T* target_ = nullptr; + SerialByteType* unpack_buf_ = nullptr; template void serialize(SerializerT& s) { s | prefix_; + + // Determine the correct type and allocate memory + if (s.isUnpacking()) { + validatePrefix(prefix_); + + auto mem = unpack_buf_ ? unpack_buf_ : vrt::objregistry::allocateConcreteType(prefix_); + target_ = vrt::objregistry::constructConcreteType(prefix_, mem); + } + s | *target_; } }; @@ -321,38 +347,14 @@ serializeType(T& target, BufferObtainFnType fn) { return packBuffer(prefixed, len, fn); } -template -void validatePrefix(vrt::TypeIdx prefix) { - if (!vrt::objregistry::isValidIdx(prefix)) { - std::string const err = std::string("Unpacking invalid prefix type (") + - std::to_string(prefix) + std::string(") from object registry for type=") + - std::string(typeregistry::getTypeName()); - throw serialization_error(err); - } -} - template typename std::enable_if::has_virtual_serialize, T*>::type deserializeType(SerialByteType* data, SerialByteType* allocBuf) { using BaseType = vrt::checkpoint_base_type_t; using PrefixedType = PrefixedType; - auto prefix_mem = allocBuf ? allocBuf : vrt::objregistry::allocateConcreteType(0); - auto prefix_buf = vrt::objregistry::constructConcreteType(0, prefix_mem); - auto prefix_struct = PrefixedType(prefix_buf); - // Disable memory check during first unpacking. - // Unpacking BaseType will always result in memory amount missmatch between serialization/deserialization - auto* prefix = Standard::unpack>(&prefix_struct, data, false); - - validatePrefix(prefix->prefix_); - - // allocate memory based on the readed TypeIdx - auto mem = allocBuf ? allocBuf : vrt::objregistry::allocateConcreteType(prefix->prefix_); - auto t_buf = vrt::objregistry::constructConcreteType(prefix->prefix_, mem); - auto prefixed = PrefixedType(t_buf); + auto prefixed = PrefixedType(allocBuf); auto* traverser = Standard::unpack>(&prefixed, data); - - delete prefix_buf; return static_cast(traverser->target_); } From 03df7fa376e9f0d271ffd5bcd05e1e0779e474b5 Mon Sep 17 00:00:00 2001 From: Arkadiusz Szczepkowicz Date: Tue, 30 Jul 2024 15:56:41 +0200 Subject: [PATCH 12/14] #268: Remove validate_memory_ member from UnpackerBuffer --- src/checkpoint/dispatch/dispatch.impl.h | 46 +++++++++++--------- src/checkpoint/serializers/base_serializer.h | 7 --- src/checkpoint/serializers/unpacker.h | 6 +-- src/checkpoint/serializers/unpacker.impl.h | 13 ++---- 4 files changed, 29 insertions(+), 43 deletions(-) diff --git a/src/checkpoint/dispatch/dispatch.impl.h b/src/checkpoint/dispatch/dispatch.impl.h index c5d4afc6..faf35db1 100644 --- a/src/checkpoint/dispatch/dispatch.impl.h +++ b/src/checkpoint/dispatch/dispatch.impl.h @@ -120,7 +120,7 @@ TraverserT& withMemUsed(TraverserT& t, SerialSizeType len) { auto val = cleanType(&serMemUsed); ap(t, val, memUsedLen); - if (t.shouldValidateMemory() && memUsed != serMemUsed) { + if (memUsed != serMemUsed) { using CleanT = typename CleanType::CleanT; std::string msg = "For type '" + typeregistry::getTypeName() + "' serialization used " + std::to_string(serMemUsed) + @@ -291,31 +291,17 @@ void deserializeType(InPlaceTag, SerialByteType* data, T* t) { Standard::unpack>(t, data); } -template -void validatePrefix(vrt::TypeIdx prefix) { - if (!vrt::objregistry::isValidIdx(prefix)) { - std::string const err = std::string("Unpacking invalid prefix type (") + - std::to_string(prefix) + std::string(") from object registry for type=") + - std::string(typeregistry::getTypeName()); - throw serialization_error(err); - } -} - template struct PrefixedType { using BaseType = vrt::checkpoint_base_type_t; - explicit PrefixedType(T* target) : target_(target) { + // Create PrefixedType for serialization purposes + explicit PrefixedType(BaseType* target) : target_(target) { prefix_ = target->_checkpointDynamicTypeIndex(); } - explicit PrefixedType(SerialByteType* allocBuf) - : unpack_buf_(allocBuf) { - } - - vrt::TypeIdx prefix_ = 0; - T* target_ = nullptr; - SerialByteType* unpack_buf_ = nullptr; + // Create PrefixedType for deserialization purposes + explicit PrefixedType(SerialByteType* allocBuf) : unpack_buf_(allocBuf) { } template void serialize(SerializerT& s) { @@ -323,7 +309,7 @@ struct PrefixedType { // Determine the correct type and allocate memory if (s.isUnpacking()) { - validatePrefix(prefix_); + validatePrefix(prefix_); auto mem = unpack_buf_ ? unpack_buf_ : vrt::objregistry::allocateConcreteType(prefix_); target_ = vrt::objregistry::constructConcreteType(prefix_, mem); @@ -331,6 +317,24 @@ struct PrefixedType { s | *target_; } + + BaseType* getTarget() const { + return target_; + } + +private: + void validatePrefix(vrt::TypeIdx prefix) { + if (!vrt::objregistry::isValidIdx(prefix)) { + std::string const err = std::string("Unpacking invalid prefix type (") + + std::to_string(prefix) + std::string(") from object registry for type=") + + std::string(typeregistry::getTypeName()); + throw serialization_error(err); + } + } + + vrt::TypeIdx prefix_ = 0; + BaseType* target_ = nullptr; + SerialByteType* unpack_buf_ = nullptr; }; template @@ -355,7 +359,7 @@ deserializeType(SerialByteType* data, SerialByteType* allocBuf) { auto prefixed = PrefixedType(allocBuf); auto* traverser = Standard::unpack>(&prefixed, data); - return static_cast(traverser->target_); + return static_cast(traverser->getTarget()); } }} /* end namespace checkpoint::dispatch */ diff --git a/src/checkpoint/serializers/base_serializer.h b/src/checkpoint/serializers/base_serializer.h index 5bcd5904..e64251cc 100644 --- a/src/checkpoint/serializers/base_serializer.h +++ b/src/checkpoint/serializers/base_serializer.h @@ -203,13 +203,6 @@ struct BaseSerializer { */ void setVirtualDisabled(bool val) { virtual_disabled_ = val; } - /** - * \brief Check if used memory should be validated - * - * \return whether memory should be validated - */ - bool shouldValidateMemory() const { return true; } - protected: ModeType cur_mode_ = ModeType::Invalid; /**< The current mode */ bool virtual_disabled_ = false; /**< Virtual serialization disabled */ diff --git a/src/checkpoint/serializers/unpacker.h b/src/checkpoint/serializers/unpacker.h index 4c31f82b..85e9a850 100644 --- a/src/checkpoint/serializers/unpacker.h +++ b/src/checkpoint/serializers/unpacker.h @@ -54,7 +54,7 @@ template struct UnpackerBuffer : MemorySerializer { using BufferPtrType = std::unique_ptr; - explicit UnpackerBuffer(SerialByteType* buf, bool validate = true); + explicit UnpackerBuffer(SerialByteType* buf); template explicit UnpackerBuffer(Args&&... args); @@ -62,15 +62,11 @@ struct UnpackerBuffer : MemorySerializer { void contiguousBytes(void* ptr, SerialSizeType size, SerialSizeType num_elms); SerialSizeType usedBufferSize() const; - bool shouldValidateMemory() const; - private: // Size of the actually used memory (for error checking) SerialSizeType usedSize_ = 0; BufferPtrType buffer_ = nullptr; - - bool validate_memory_ = true; }; using Unpacker = UnpackerBuffer; diff --git a/src/checkpoint/serializers/unpacker.impl.h b/src/checkpoint/serializers/unpacker.impl.h index 23faaeb7..ad2429b2 100644 --- a/src/checkpoint/serializers/unpacker.impl.h +++ b/src/checkpoint/serializers/unpacker.impl.h @@ -54,10 +54,9 @@ namespace checkpoint { template -UnpackerBuffer::UnpackerBuffer(SerialByteType* buf, bool validate) +UnpackerBuffer::UnpackerBuffer(SerialByteType* buf) : MemorySerializer(ModeType::Unpacking), - buffer_(std::make_unique(buf, 0)), - validate_memory_(validate) + buffer_(std::make_unique(buf, 0)) { MemorySerializer::initializeBuffer(buffer_->getBuffer()); @@ -72,8 +71,7 @@ template template UnpackerBuffer::UnpackerBuffer(Args&&... args) : MemorySerializer(ModeType::Unpacking), - buffer_(std::make_unique(std::forward(args)...)), - validate_memory_(true) + buffer_(std::make_unique(std::forward(args)...)) { MemorySerializer::initializeBuffer(buffer_->getBuffer()); @@ -105,11 +103,6 @@ SerialSizeType UnpackerBuffer::usedBufferSize() const { return usedSize_; } -template -bool UnpackerBuffer::shouldValidateMemory() const { - return validate_memory_; -} - } /* end namespace checkpoint */ #endif /*INCLUDED_SRC_CHECKPOINT_SERIALIZERS_UNPACKER_IMPL_H*/ From 639fea019345a94373aea2eaa4da7b4a400300d1 Mon Sep 17 00:00:00 2001 From: Arkadiusz Szczepkowicz Date: Tue, 24 Sep 2024 16:09:14 +0200 Subject: [PATCH 13/14] #268: Update license in the test_polymorphic file --- tests/unit/test_polymorphic.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/test_polymorphic.cc b/tests/unit/test_polymorphic.cc index 57944fd8..84b24104 100644 --- a/tests/unit/test_polymorphic.cc +++ b/tests/unit/test_polymorphic.cc @@ -3,7 +3,7 @@ // ***************************************************************************** // // test_polymorphic.cc -// DARMA/checkpoint => Serialization Library +// DARMA/magistrate => Serialization Library // // Copyright 2019 National Technology & Engineering Solutions of Sandia, LLC // (NTESS). Under the terms of Contract DE-NA0003525 with NTESS, the U.S. From 6fec2216e0b86c9885f07991f27fbfe1022694a8 Mon Sep 17 00:00:00 2001 From: Arkadiusz Szczepkowicz Date: Wed, 25 Sep 2024 13:03:52 +0200 Subject: [PATCH 14/14] #268: Add check for type ID in tests for polymorphic types --- tests/unit/test_polymorphic.cc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/unit/test_polymorphic.cc b/tests/unit/test_polymorphic.cc index 84b24104..215e3ce9 100644 --- a/tests/unit/test_polymorphic.cc +++ b/tests/unit/test_polymorphic.cc @@ -112,6 +112,8 @@ void testPolymorphicTypes(int val) { auto out = checkpoint::deserialize(std::move(ret)); EXPECT_TRUE(nullptr != out); + EXPECT_EQ(typeid(*task), typeid(*out)); + EXPECT_TRUE(nullptr != dynamic_cast(out.get())); EXPECT_EQ(val, out->getVal()); }