Skip to content

Commit

Permalink
#268: Refactor dispatcher to properly deserialize polymorphic types w…
Browse files Browse the repository at this point in the history
…ith/without serialization error checking
  • Loading branch information
thearusable committed Sep 24, 2024
1 parent dfaf97f commit ffc3644
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 110 deletions.
44 changes: 0 additions & 44 deletions src/checkpoint/dispatch/dispatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -188,50 +188,6 @@ struct Standard {
static SerialByteType* allocate();
};

/**
* \struct Prefixed
*
* \brief Traversal for polymorphic types prefixed with the vrt::TypeIdx
*/
struct Prefixed {
/**
* \brief Traverse a \c target of type \c T recursively with a general \c
* TraverserT that gets applied to each element.
* Allows to traverse only part of the data.
*
* \param[in,out] target the target to traverse
* \param[in] len the len of the target. If > 1, \c target is an array
* \param[in] check_type the flag to control type validation
* \param[in] check_mem the flag to control memory validation
* \param[in] args the args to pass to the traverser for construction
*
* \return the traverser after traversal is complete
*/
template <typename T, typename TraverserT, typename... Args>
static TraverserT traverse(T& target, SerialSizeType len, bool check_type, bool check_mem, Args&&... args);

/**
* \brief Unpack \c T from packed byte-buffer \c mem
*
* \param[in] mem bytes holding a serialized \c T
* \param[in] check_type the flag to control type validation
* \param[in] check_mem the flag to control memory validation
* \param[in] args arguments to the unpacker's constructor
*
* \return a pointer to an unpacked \c T
*/
template <typename T, typename UnpackerT, typename... Args>
static T* unpack(T* mem, bool check_type, bool check_mem, Args&&... args);

/**
* \brief Check if prefix is valid
*
* \param[in] prefix the prefix to be validated
*/
template <typename T>
static void validatePrefix(vrt::TypeIdx prefix);
};

template <typename T>
buffer::ImplReturnType packBuffer(
T& target, SerialSizeType size, BufferObtainFnType fn
Expand Down
90 changes: 27 additions & 63 deletions src/checkpoint/dispatch/dispatch.impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ struct serialization_error : public std::runtime_error {
};

template <typename T, typename TraverserT>
TraverserT& withTypeIdx(TraverserT& t, bool check_type = true) {
TraverserT& withTypeIdx(TraverserT& t) {
using CleanT = typename CleanType<typeregistry::DecodedIndex>::CleanT;
using DispatchType =
typename TraverserT::template DispatcherType<TraverserT, CleanT>;
Expand All @@ -87,9 +87,9 @@ TraverserT& withTypeIdx(TraverserT& t, bool check_type = true) {
auto val = cleanType(&serTypeIdx);
ap(t, val, len);

if (check_type &&
(typeregistry::validateIndex(serTypeIdx) == false ||
thisTypeIdx != serTypeIdx)
if (
typeregistry::validateIndex(serTypeIdx) == false ||
thisTypeIdx != serTypeIdx
) {
auto const err = std::string("Unpacking wrong type, got=") +
typeregistry::getTypeNameForIdx(thisTypeIdx) +
Expand All @@ -105,7 +105,7 @@ TraverserT& withTypeIdx(TraverserT& t, bool check_type = true) {
}

template <typename T, typename TraverserT>
TraverserT& withMemUsed(TraverserT& t, SerialSizeType len, bool check_mem = true) {
TraverserT& withMemUsed(TraverserT& t, SerialSizeType len) {
using DispatchType =
typename TraverserT::template DispatcherType<TraverserT, SerialSizeType>;
SerializerDispatch<TraverserT, SerialSizeType, DispatchType> ap;
Expand All @@ -120,7 +120,7 @@ TraverserT& withMemUsed(TraverserT& t, SerialSizeType len, bool check_mem = true
auto val = cleanType(&serMemUsed);
ap(t, val, memUsedLen);

if (check_mem && memUsed != serMemUsed) {
if (t.shouldValidateMemory() && memUsed != serMemUsed) {
using CleanT = typename CleanType<T>::CleanT;
std::string msg = "For type '" + typeregistry::getTypeName<CleanT>() +
"' serialization used " + std::to_string(serMemUsed) +
Expand All @@ -133,37 +133,6 @@ TraverserT& withMemUsed(TraverserT& t, SerialSizeType len, bool check_mem = true
return t;
}

template <typename T, typename TraverserT, typename... Args>
TraverserT Prefixed::traverse(T& target, SerialSizeType len, bool check_type, bool check_mem, Args&&... args) {
using CleanT = typename CleanType<T>::CleanT;
using DispatchType =
typename TraverserT::template DispatcherType<TraverserT, CleanT>;

TraverserT t(std::forward<Args>(args)...);

withTypeIdx<CleanT>(t, check_type);

auto val = cleanType(&target);
SerializerDispatch<TraverserT, CleanT, DispatchType> ap;

#if defined(SERIALIZATION_ERROR_CHECKING)
try {
ap(t, val, len);
} catch (serialization_error const& err) {
auto const depth = err.depth_ + 1;
auto const what = std::string(err.what()) + "\n#" + std::to_string(depth) +
" " + typeregistry::getTypeName<T>();
throw serialization_error(what, depth);
}
#else
ap(t, val, len);
#endif

withMemUsed<CleanT>(t, 1, check_mem);

return t;
}

template <typename T, typename TraverserT>
TraverserT& Traverse::with(T& target, TraverserT& t, SerialSizeType len) {
using CleanT = typename CleanType<T>::CleanT;
Expand Down Expand Up @@ -252,12 +221,6 @@ T* Standard::unpack(T* t_buf, Args&&... args) {
return t_buf;
}

template <typename T, typename UnpackerT, typename... Args>
T* Prefixed::unpack(T* t_buf, bool check_type, bool check_mem, Args&&... args) {
Prefixed::traverse<T, UnpackerT>(*t_buf, 1, check_type, check_mem, std::forward<Args>(args)...);
return t_buf;
}

template <typename T>
T* Standard::construct(SerialByteType* mem) {
return Traverse::reconstruct<T>(mem);
Expand Down Expand Up @@ -313,9 +276,7 @@ serializeType(T& target, BufferObtainFnType fn) {
}

template <typename T>
typename std::enable_if<
!vrt::VirtualSerializeTraits<T>::has_virtual_serialize,
T*>::type
typename std::enable_if<!vrt::VirtualSerializeTraits<T>::has_virtual_serialize, T*>::type
deserializeType(SerialByteType* data, SerialByteType* allocBuf) {
auto mem = allocBuf ? allocBuf : Standard::allocate<T>();
auto t_buf = std::unique_ptr<T>(Standard::construct<T>(mem));
Expand Down Expand Up @@ -351,14 +312,17 @@ typename std::enable_if<
vrt::VirtualSerializeTraits<T>::has_virtual_serialize,
buffer::ImplReturnType>::type
serializeType(T& target, BufferObtainFnType fn) {
using BaseType = vrt::checkpoint_base_type_t<T>;
using PrefixedType = PrefixedType<BaseType>;

auto prefixed = PrefixedType(&target);
auto len = Standard::size<decltype(prefixed), Sizer>(prefixed);
auto len = Standard::size<PrefixedType, Sizer>(prefixed);
debug_checkpoint("serializeType: len=%ld\n", len);
return packBuffer<decltype(prefixed)>(prefixed, len, fn);
return packBuffer<PrefixedType>(prefixed, len, fn);
}

template <typename T>
void Prefixed::validatePrefix(vrt::TypeIdx prefix) {
void validatePrefix(vrt::TypeIdx prefix) {
if (!vrt::objregistry::isValidIdx<T>(prefix)) {
std::string const err = std::string("Unpacking invalid prefix type (") +
std::to_string(prefix) + std::string(") from object registry for type=") +
Expand All @@ -368,28 +332,28 @@ void Prefixed::validatePrefix(vrt::TypeIdx prefix) {
}

template <typename T>
typename std::enable_if<
vrt::VirtualSerializeTraits<T>::has_virtual_serialize,
T*>::type
typename std::enable_if<vrt::VirtualSerializeTraits<T>::has_virtual_serialize, T*>::type
deserializeType(SerialByteType* data, SerialByteType* allocBuf) {
using BaseType = vrt::checkpoint_base_type_t<T>;
using PrefixedType = PrefixedType<BaseType>;

auto prefix_mem = Standard::allocate<vrt::TypeIdx>();
auto prefix_buf = std::unique_ptr<vrt::TypeIdx>(Standard::construct<vrt::TypeIdx>(prefix_mem));
// Unpack TypeIdx, ignore checks for type and memory used - unpacking will only use a part of the data
vrt::TypeIdx* prefix =
Prefixed::unpack<vrt::TypeIdx, UnpackerBuffer<buffer::UserBuffer>>(prefix_buf.get(), false, false, data);
prefix_buf.release();
auto prefix_mem = allocBuf ? allocBuf : vrt::objregistry::allocateConcreteType<BaseType>(0);
auto prefix_buf = vrt::objregistry::constructConcreteType<BaseType>(0, prefix_mem);
auto prefix_struct = PrefixedType(prefix_buf);
// Disable memory check during first unpacking.
// Unpacking BaseType will always result in memory amount missmatch between serialization/deserialization
auto* prefix =
Standard::unpack<PrefixedType, UnpackerBuffer<buffer::UserBuffer>>(&prefix_struct, data, false);
delete prefix_buf;

Prefixed::validatePrefix<BaseType>(*prefix);
validatePrefix<BaseType>(prefix->prefix_);

// allocate memory based on the readed TypeIdx
auto mem = allocBuf ? allocBuf : vrt::objregistry::allocateConcreteType<BaseType>(*prefix);
auto t_buf = vrt::objregistry::constructConcreteType<BaseType>(*prefix, mem);
auto mem = allocBuf ? allocBuf : vrt::objregistry::allocateConcreteType<BaseType>(prefix->prefix_);
auto t_buf = vrt::objregistry::constructConcreteType<BaseType>(prefix->prefix_, mem);
auto prefixed = PrefixedType(t_buf);
// Unpack PrefixedType, ignore checks for unpacked type and execute checks for memory used
auto* traverser =
Prefixed::unpack<decltype(prefixed), UnpackerBuffer<buffer::UserBuffer>>(&prefixed, false, true, data);
Standard::unpack<PrefixedType, UnpackerBuffer<buffer::UserBuffer>>(&prefixed, data);
return static_cast<T*>(traverser->target_);
}

Expand Down
7 changes: 7 additions & 0 deletions src/checkpoint/serializers/base_serializer.h
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,13 @@ struct BaseSerializer {
*/
void setVirtualDisabled(bool val) { virtual_disabled_ = val; }

/**
* \brief Check if used memory should be validated
*
* \return whether memory should be validated
*/
bool shouldValidateMemory() const { return true; }

protected:
ModeType cur_mode_ = ModeType::Invalid; /**< The current mode */
bool virtual_disabled_ = false; /**< Virtual serialization disabled */
Expand Down
6 changes: 5 additions & 1 deletion src/checkpoint/serializers/unpacker.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,19 +54,23 @@ template <typename BufferT>
struct UnpackerBuffer : MemorySerializer {
using BufferPtrType = std::unique_ptr<BufferT>;

explicit UnpackerBuffer(SerialByteType* buf);
explicit UnpackerBuffer(SerialByteType* buf, bool validate = true);

template <typename... Args>
explicit UnpackerBuffer(Args&&... args);

void contiguousBytes(void* ptr, SerialSizeType size, SerialSizeType num_elms);
SerialSizeType usedBufferSize() const;

bool shouldValidateMemory() const;

private:
// Size of the actually used memory (for error checking)
SerialSizeType usedSize_ = 0;

BufferPtrType buffer_ = nullptr;

bool validate_memory_ = true;
};

using Unpacker = UnpackerBuffer<buffer::UserBuffer>;
Expand Down
10 changes: 8 additions & 2 deletions src/checkpoint/serializers/unpacker.impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,10 @@
namespace checkpoint {

template <typename BufferT>
UnpackerBuffer<BufferT>::UnpackerBuffer(SerialByteType* buf)
UnpackerBuffer<BufferT>::UnpackerBuffer(SerialByteType* buf, bool validate)
: MemorySerializer(ModeType::Unpacking),
buffer_(std::make_unique<BufferT>(buf, 0))
buffer_(std::make_unique<BufferT>(buf, 0)),
validate_memory_(validate)
{
MemorySerializer::initializeBuffer(buffer_->getBuffer());

Expand Down Expand Up @@ -103,6 +104,11 @@ SerialSizeType UnpackerBuffer<BufferT>::usedBufferSize() const {
return usedSize_;
}

template <typename BufferT>
bool UnpackerBuffer<BufferT>::shouldValidateMemory() const {
return validate_memory_;
}

} /* end namespace checkpoint */

#endif /*INCLUDED_SRC_CHECKPOINT_SERIALIZERS_UNPACKER_IMPL_H*/

0 comments on commit ffc3644

Please sign in to comment.