Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move type arg to the end to match Aten constructors. #5379

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions extension/tensor/tensor_impl_ptr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,11 @@ struct TensorImplPtrDeleter final {
} // namespace

TensorImplPtr make_tensor_impl_ptr(
exec_aten::ScalarType type,
std::vector<exec_aten::SizesType> sizes,
void* data,
std::vector<exec_aten::DimOrderType> dim_order,
std::vector<exec_aten::StridesType> strides,
exec_aten::ScalarType type,
exec_aten::TensorShapeDynamism dynamism,
std::function<void(void*)> deleter) {
const auto dim = sizes.size();
Expand Down Expand Up @@ -129,24 +129,24 @@ TensorImplPtr make_tensor_impl_ptr(
}

TensorImplPtr make_tensor_impl_ptr(
exec_aten::ScalarType scalar_type,
std::vector<exec_aten::SizesType> sizes,
std::vector<uint8_t> data,
std::vector<exec_aten::DimOrderType> dim_order,
std::vector<exec_aten::StridesType> strides,
exec_aten::ScalarType type,
exec_aten::TensorShapeDynamism dynamism) {
ET_CHECK_MSG(
data.size() >= exec_aten::compute_numel(sizes.data(), sizes.size()) *
exec_aten::elementSize(scalar_type),
exec_aten::elementSize(type),
"Data size is smaller than required by sizes and scalar type.");
auto raw_data_ptr = data.data();
auto data_ptr = std::make_shared<std::vector<uint8_t>>(std::move(data));
return make_tensor_impl_ptr(
scalar_type,
std::move(sizes),
raw_data_ptr,
std::move(dim_order),
std::move(strides),
type,
dynamism,
[data_ptr = std::move(data_ptr)](void*) {});
}
Expand Down
204 changes: 173 additions & 31 deletions extension/tensor/tensor_impl_ptr.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,14 @@ namespace extension {

#ifndef USE_ATEN_LIB
/**
* A smart pointer type for managing the lifecycle of a TensorImpl.
* A smart pointer for managing the lifecycle of a TensorImpl.
*
* TensorImplPtr uses a shared pointer because multiple Tensor objects might
* share the same underlying data and metadata. This shared ownership model
* ensures that the TensorImpl is only destroyed when all references to it are
* gone, providing a safe and efficient way to manage shared tensor
* implementations. This abstraction is designed to be a safer and more
* convenient alternative to the original TensorImpl, which does not
* manage metadata by design.
* TensorImplPtr uses a shared pointer since multiple Tensor objects may
* share the same underlying data and metadata. This shared ownership ensures
* that the TensorImpl is destroyed only when all references to it are gone,
* providing a safe and efficient way to manage shared tensor implementations.
* It serves as a safer, more convenient alternative to the original TensorImpl,
* which does not manage its metadata by design.
*/
using TensorImplPtr = std::shared_ptr<exec_aten::TensorImpl>;
#else
Expand All @@ -48,23 +47,23 @@ using TensorImplPtr =
* Creates a TensorImplPtr that manages a newly created TensorImpl with the
* specified properties.
*
* @param type The scalar type of the tensor elements.
* @param sizes A vector specifying the size of each dimension.
* @param data A pointer to the data buffer.
* @param dim_order A vector specifying the order of dimensions.
* @param strides A vector specifying the strides of each dimension.
* @param type The scalar type of the tensor elements.
* @param dynamism Specifies the mutability of the tensor's shape.
* @param deleter A custom deleter function for managing the lifetime of the
* data buffer. If provided, this deleter will be called when the managed
* TensorImpl object is destroyed.
* data buffer. If provided, this deleter is called when the managed TensorImpl
* is destroyed.
* @return A TensorImplPtr managing the newly created TensorImpl.
*/
TensorImplPtr make_tensor_impl_ptr(
exec_aten::ScalarType type,
std::vector<exec_aten::SizesType> sizes,
void* data,
std::vector<exec_aten::DimOrderType> dim_order = {},
std::vector<exec_aten::StridesType> strides = {},
std::vector<exec_aten::DimOrderType> dim_order,
std::vector<exec_aten::StridesType> strides,
exec_aten::ScalarType type = exec_aten::ScalarType::Float,
exec_aten::TensorShapeDynamism dynamism =
exec_aten::TensorShapeDynamism::DYNAMIC_BOUND,
std::function<void(void*)> deleter = nullptr);
Expand All @@ -73,37 +72,64 @@ TensorImplPtr make_tensor_impl_ptr(
* Creates a TensorImplPtr that manages a newly created TensorImpl with the
* specified properties.
*
* This template overload is specialized for cases where the tensor data is
* provided as a vector. The scalar type is automatically deduced from the
* vector's data type. The deleter ensures that the data vector is properly
* managed and its lifetime is tied to the TensorImpl.
* @param sizes A vector specifying the size of each dimension.
* @param data A pointer to the data buffer.
* @param type The scalar type of the tensor elements.
* @param dynamism Specifies the mutability of the tensor's shape.
* @param deleter A custom deleter function for managing the lifetime of the
* data buffer. If provided, this deleter is called when the managed TensorImpl
* is destroyed.
* @return A TensorImplPtr managing the newly created TensorImpl.
*/
inline TensorImplPtr make_tensor_impl_ptr(
std::vector<exec_aten::SizesType> sizes,
void* data,
exec_aten::ScalarType type = exec_aten::ScalarType::Float,
exec_aten::TensorShapeDynamism dynamism =
exec_aten::TensorShapeDynamism::DYNAMIC_BOUND,
std::function<void(void*)> deleter = nullptr) {
return make_tensor_impl_ptr(
std::move(sizes), data, {}, {}, type, dynamism, std::move(deleter));
}

/**
* Creates a TensorImplPtr that manages a newly created TensorImpl with the
* specified properties.
*
* This template overload is specialized for cases where tensor data is provided
* as a vector. The scalar type is automatically deduced from the vector's data
* type. The deleter ensures that the data vector is properly managed, with its
* lifetime tied to the TensorImpl.
*
* @tparam T The C++ type of the tensor elements, deduced from the vector.
* @param sizes A vector specifying the size of each dimension.
* @param data A vector containing the tensor's data.
* @param dim_order A vector specifying the order of dimensions.
* @param strides A vector specifying the strides of each dimension.
* @param type The scalar type of the tensor elements.
* @param dynamism Specifies the mutability of the tensor's shape.
* @return A TensorImplPtr that manages the newly created TensorImpl.
*/
template <typename T = float>
template <
typename T = float,
exec_aten::ScalarType deduced_type = runtime::CppTypeToScalarType<T>::value>
inline TensorImplPtr make_tensor_impl_ptr(
std::vector<exec_aten::SizesType> sizes,
std::vector<T> data,
std::vector<exec_aten::DimOrderType> dim_order = {},
std::vector<exec_aten::StridesType> strides = {},
exec_aten::ScalarType type = deduced_type,
exec_aten::TensorShapeDynamism dynamism =
exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) {
constexpr exec_aten::ScalarType scalar_type =
runtime::CppTypeToScalarType<T>::value;
ET_CHECK_MSG(type == deduced_type, "Type does not match the deduced type.");
const auto raw_data_ptr = data.data();
auto data_ptr = std::make_shared<std::vector<T>>(std::move(data));
return make_tensor_impl_ptr(
scalar_type,
std::move(sizes),
raw_data_ptr,
std::move(dim_order),
std::move(strides),
type,
dynamism,
[data_ptr = std::move(data_ptr)](void*) {});
}
Expand All @@ -119,43 +145,159 @@ inline TensorImplPtr make_tensor_impl_ptr(
*
* @tparam T The C++ type of the tensor elements, deduced from the vector.
* @param data A vector containing the tensor's data.
* @param type The scalar type of the tensor elements.
* @param dynamism Specifies the mutability of the tensor's shape.
* @return A TensorImplPtr that manages the newly created TensorImpl.
*/
template <typename T = float>
template <
typename T = float,
exec_aten::ScalarType deduced_type = runtime::CppTypeToScalarType<T>::value>
inline TensorImplPtr make_tensor_impl_ptr(
std::vector<T> data,
exec_aten::ScalarType type = deduced_type,
exec_aten::TensorShapeDynamism dynamism =
exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) {
ET_CHECK_MSG(type == deduced_type, "Type does not match the deduced type.");
std::vector<exec_aten::SizesType> sizes{exec_aten::SizesType(data.size())};
return make_tensor_impl_ptr(
std::move(sizes), std::move(data), {0}, {1}, dynamism);
std::move(sizes), std::move(data), {0}, {1}, type, dynamism);
}

/**
* Creates a TensorImplPtr that manages a newly created TensorImpl with the
* specified properties.
*
* This template overload is specialized for cases where tensor data is provided
* as an initializer list. The scalar type is automatically deduced from the
* initializer list's data type. The deleter ensures that the data is properly
* managed, with its lifetime tied to the TensorImpl.
*
* @tparam T The C++ type of the tensor elements, deduced from the initializer
* list.
* @param sizes A vector specifying the size of each dimension.
* @param list An initializer list containing the tensor's data.
* @param dim_order A vector specifying the order of dimensions.
* @param strides A vector specifying the strides of each dimension.
* @param type The scalar type of the tensor elements.
* @param dynamism Specifies the mutability of the tensor's shape.
* @return A TensorImplPtr that manages the newly created TensorImpl.
*/
template <
typename T = float,
exec_aten::ScalarType deduced_type = runtime::CppTypeToScalarType<T>::value>
inline TensorImplPtr make_tensor_impl_ptr(
std::vector<exec_aten::SizesType> sizes,
std::initializer_list<T> list,
std::vector<exec_aten::DimOrderType> dim_order = {},
std::vector<exec_aten::StridesType> strides = {},
exec_aten::ScalarType type = deduced_type,
exec_aten::TensorShapeDynamism dynamism =
exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) {
ET_CHECK_MSG(type == deduced_type, "Type does not match the deduced type.");
auto data = std::vector<T>(std::move(list));
const auto raw_data_ptr = data.data();
auto data_ptr = std::make_shared<std::vector<T>>(std::move(data));
return make_tensor_impl_ptr(
std::move(sizes),
raw_data_ptr,
std::move(dim_order),
std::move(strides),
type,
dynamism,
[data_ptr = std::move(data_ptr)](void*) {});
}

/**
* Creates a TensorImplPtr that manages a newly created TensorImpl with the
* specified properties.
*
* This template overload is specialized for cases where the tensor data is
* provided as an initializer list. The scalar type is automatically deduced
* from the initializer list's data type. The deleter ensures that the data is
* properly managed and its lifetime is tied to the TensorImpl.
*
* @tparam T The C++ type of the tensor elements, deduced from the initializer
* list.
* @param sizes A vector specifying the size of each dimension.
* @param list An initializer list containing the tensor's data.
* @param type The scalar type of the tensor elements.
* @param dynamism Specifies the mutability of the tensor's shape.
* @return A TensorImplPtr that manages the newly created TensorImpl.
*/
template <
typename T = float,
exec_aten::ScalarType deduced_type = runtime::CppTypeToScalarType<T>::value>
inline TensorImplPtr make_tensor_impl_ptr(
std::initializer_list<T> list,
exec_aten::ScalarType type = deduced_type,
exec_aten::TensorShapeDynamism dynamism =
exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) {
ET_CHECK_MSG(type == deduced_type, "Type does not match the deduced type.");
std::vector<exec_aten::SizesType> sizes{exec_aten::SizesType(list.size())};
return make_tensor_impl_ptr(
std::move(sizes), std::move(list), {0}, {1}, type, dynamism);
}

/**
* Creates a TensorImplPtr to manage a Tensor with a single scalar value.
*
* @tparam T The C++ type of the scalar value.
* @param value The scalar value used for the Tensor.
* @return A TensorImplPtr managing the newly created TensorImpl.
*/
template <typename T>
inline TensorImplPtr make_tensor_impl_ptr(T value) {
return make_tensor_impl_ptr({}, std::vector<T>{value});
}

/**
* Creates a TensorImplPtr that manages a newly created TensorImpl with the
* specified properties.
*
* This overload accepts a raw memory buffer stored in a std::vector<uint8_t>
* and a scalar type to interpret the data. The vector is managed, and the
* memory's lifetime is tied to the TensorImpl.
* and a scalar type to interpret the data. The vector is managed, and its
* lifetime is tied to the TensorImpl.
*
* @param scalar_type The scalar type of the tensor elements.
* @param sizes A vector specifying the size of each dimension.
* @param data A vector containing the raw memory for the tensor's data.
* @param data A vector containing the raw memory buffer for the tensor's data.
* @param dim_order A vector specifying the order of dimensions.
* @param strides A vector specifying the strides of each dimension.
* @param type The scalar type of the tensor elements.
* @param dynamism Specifies the mutability of the tensor's shape.
* @return A TensorImplPtr managing the newly created TensorImpl.
*/
TensorImplPtr make_tensor_impl_ptr(
exec_aten::ScalarType scalar_type,
std::vector<exec_aten::SizesType> sizes,
std::vector<uint8_t> data,
std::vector<exec_aten::DimOrderType> dim_order = {},
std::vector<exec_aten::StridesType> strides = {},
std::vector<exec_aten::DimOrderType> dim_order,
std::vector<exec_aten::StridesType> strides,
exec_aten::ScalarType type = exec_aten::ScalarType::Float,
exec_aten::TensorShapeDynamism dynamism =
exec_aten::TensorShapeDynamism::DYNAMIC_BOUND);

/**
* Creates a TensorImplPtr that manages a newly created TensorImpl with the
* specified properties.
*
* This overload accepts a raw memory buffer stored in a std::vector<uint8_t>
* and a scalar type to interpret the data. The vector is managed, and the
* memory's lifetime is tied to the TensorImpl.
*
* @param sizes A vector specifying the size of each dimension.
* @param data A vector containing the raw memory for the tensor's data.
* @param type The scalar type of the tensor elements.
* @param dynamism Specifies the mutability of the tensor's shape.
* @return A TensorImplPtr managing the newly created TensorImpl.
*/
inline TensorImplPtr make_tensor_impl_ptr(
std::vector<exec_aten::SizesType> sizes,
std::vector<uint8_t> data,
exec_aten::ScalarType type = exec_aten::ScalarType::Float,
exec_aten::TensorShapeDynamism dynamism =
exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) {
return make_tensor_impl_ptr(
std::move(sizes), std::move(data), {}, {}, type, dynamism);
}

} // namespace extension
} // namespace executorch
Loading
Loading