Skip to content

Commit

Permalink
[SYCL] Detect conflicts between kernel properties (#15510)
Browse files Browse the repository at this point in the history
The `max_work_group_size` and `max_linear_work_group_size` kernel
properties conflict with the `work_group_size` property when the
required work-group size exceeds either of the maximum sizes.
  • Loading branch information
frasercrmck authored Sep 30, 2024
1 parent b9eb520 commit f18ed8c
Show file tree
Hide file tree
Showing 4 changed files with 156 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -45,23 +45,12 @@ using HasUsmKind = HasProperty<usm_kind_key, PropertyListT>;
template <typename PropertyListT>
using HasBufferLocation = HasProperty<buffer_location_key, PropertyListT>;

// Get the value of a property from a property list
template <typename PropKey, typename ConstType, typename DefaultPropVal,
typename PropertyListT>
struct GetPropertyValueFromPropList {};

template <typename PropKey, typename ConstType, typename DefaultPropVal,
typename... Props>
struct GetPropertyValueFromPropList<PropKey, ConstType, DefaultPropVal,
detail::properties_t<Props...>> {
using prop_val_t = std::conditional_t<
detail::ContainsProperty<PropKey, std::tuple<Props...>>::value,
typename detail::FindCompileTimePropertyValueType<
PropKey, std::tuple<Props...>>::type,
DefaultPropVal>;
static constexpr ConstType value =
detail::PropertyMetaInfo<std::remove_const_t<prop_val_t>>::value;
};
detail::properties_t<Props...>>
: GetPropertyValueFromPropList<PropKey, ConstType, DefaultPropVal,
std::tuple<Props...>> {};

// Get the value of alignment from a property list
// If alignment is not present in the property list, set to default value 0
Expand Down
79 changes: 76 additions & 3 deletions sycl/include/sycl/ext/oneapi/kernel_properties/properties.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,12 @@
#pragma once

#include <array> // for array
#include <limits>
#include <stddef.h> // for size_t
#include <stdint.h> // for uint32_T
#include <sycl/aspects.hpp> // for aspect
#include <sycl/ext/oneapi/experimental/forward_progress.hpp> // for forward_progress_guarantee enum
#include <sycl/ext/oneapi/properties/property.hpp> // for PropKind
#include <sycl/ext/oneapi/properties/property_utils.hpp> // for SizeListToStr
#include <sycl/ext/oneapi/properties/property_value.hpp> // for property_value
#include <sycl/ext/oneapi/properties/properties.hpp>
#include <type_traits> // for true_type
#include <utility> // for declval
namespace sycl {
Expand Down Expand Up @@ -351,6 +350,80 @@ struct HasKernelPropertiesGetMethod<T,
decltype(std::declval<T>().get(std::declval<properties_tag>()));
};

// Trait for property compile-time meta names and values.
template <typename PropertyT> struct WGSizePropertyMetaInfo {
static constexpr std::array<size_t, 0> WGSize = {};
static constexpr size_t LinearSize = 0;
};

template <size_t Dim0, size_t... Dims>
struct WGSizePropertyMetaInfo<work_group_size_key::value_t<Dim0, Dims...>> {
static constexpr std::array<size_t, sizeof...(Dims) + 1> WGSize = {Dim0,
Dims...};
static constexpr size_t LinearSize = (Dim0 * ... * Dims);
};

template <size_t Dim0, size_t... Dims>
struct WGSizePropertyMetaInfo<max_work_group_size_key::value_t<Dim0, Dims...>> {
static constexpr std::array<size_t, sizeof...(Dims) + 1> WGSize = {Dim0,
Dims...};
static constexpr size_t LinearSize = (Dim0 * ... * Dims);
};

// Get the value of a work-group size related property from a property list
template <typename PropKey, typename PropertiesT>
struct GetWGPropertyFromPropList {};

template <typename PropKey, typename... PropertiesT>
struct GetWGPropertyFromPropList<PropKey, std::tuple<PropertiesT...>> {
using prop_val_t = std::conditional_t<
ContainsProperty<PropKey, std::tuple<PropertiesT...>>::value,
typename FindCompileTimePropertyValueType<
PropKey, std::tuple<PropertiesT...>>::type,
void>;
static constexpr auto WGSize =
WGSizePropertyMetaInfo<std::remove_const_t<prop_val_t>>::WGSize;
static constexpr size_t LinearSize =
WGSizePropertyMetaInfo<std::remove_const_t<prop_val_t>>::LinearSize;
};

// If work_group_size and max_work_group_size coexist, check that the
// dimensionality matches and that the required work-group size doesn't
// trivially exceed the maximum size.
template <typename Properties>
struct ConflictingProperties<max_work_group_size_key, Properties>
: std::false_type {
using WGSizeVal = GetWGPropertyFromPropList<work_group_size_key, Properties>;
using MaxWGSizeVal =
GetWGPropertyFromPropList<max_work_group_size_key, Properties>;
// If work_group_size_key doesn't exist in the list of properties, WGSize is
// an empty array and so Dims == 0.
static constexpr size_t Dims = WGSizeVal::WGSize.size();
static_assert(
Dims == 0 || Dims == MaxWGSizeVal::WGSize.size(),
"work_group_size and max_work_group_size dimensionality must match");
static_assert(Dims < 1 || WGSizeVal::WGSize[0] <= MaxWGSizeVal::WGSize[0],
"work_group_size must not exceed max_work_group_size");
static_assert(Dims < 2 || WGSizeVal::WGSize[1] <= MaxWGSizeVal::WGSize[1],
"work_group_size must not exceed max_work_group_size");
static_assert(Dims < 3 || WGSizeVal::WGSize[2] <= MaxWGSizeVal::WGSize[2],
"work_group_size must not exceed max_work_group_size");
};

// If work_group_size and max_linear_work_group_size coexist, check that the
// required linear work-group size doesn't trivially exceed the maximum size.
template <typename Properties>
struct ConflictingProperties<max_linear_work_group_size_key, Properties>
: std::false_type {
using WGSizeVal = GetWGPropertyFromPropList<work_group_size_key, Properties>;
using MaxLinearWGSizeVal =
GetPropertyValueFromPropList<max_linear_work_group_size_key, size_t, void,
Properties>;
static_assert(WGSizeVal::WGSize.empty() ||
WGSizeVal::LinearSize <= MaxLinearWGSizeVal::value,
"work_group_size must not exceed max_linear_work_group_size");
};

} // namespace detail
} // namespace ext::oneapi::experimental
} // namespace _V1
Expand Down
18 changes: 18 additions & 0 deletions sycl/include/sycl/ext/oneapi/properties/properties.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,24 @@ struct ExtractProperties<PropertyArgsT,
}
};

// Get the value of a property from a property list
template <typename PropKey, typename ConstType, typename DefaultPropVal,
typename PropertiesT>
struct GetPropertyValueFromPropList {};

template <typename PropKey, typename ConstType, typename DefaultPropVal,
typename... PropertiesT>
struct GetPropertyValueFromPropList<PropKey, ConstType, DefaultPropVal,
std::tuple<PropertiesT...>> {
using prop_val_t = std::conditional_t<
ContainsProperty<PropKey, std::tuple<PropertiesT...>>::value,
typename FindCompileTimePropertyValueType<
PropKey, std::tuple<PropertiesT...>>::type,
DefaultPropVal>;
static constexpr ConstType value =
PropertyMetaInfo<std::remove_const_t<prop_val_t>>::value;
};

} // namespace detail

template <typename PropertiesT> class properties {
Expand Down
59 changes: 59 additions & 0 deletions sycl/test/extensions/properties/properties_kernel_negative.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -317,9 +317,68 @@ void check_sub_group_size() {
KernelFunctorWithSGSize<2>{});
}

void check_max_work_group_size() {
sycl::queue Q;

// expected-error-re@sycl/ext/oneapi/kernel_properties/properties.hpp:* {{static assertion failed due to requirement {{.+}}: work_group_size and max_work_group_size dimensionality must match}}
Q.single_task(
sycl::ext::oneapi::experimental::properties{
sycl::ext::oneapi::experimental::work_group_size<2, 2>,
sycl::ext::oneapi::experimental::max_work_group_size<1>},
[]() {});

// expected-error-re@sycl/ext/oneapi/kernel_properties/properties.hpp:* {{static assertion failed due to requirement {{.+}}: work_group_size must not exceed max_work_group_size}}
Q.single_task(
sycl::ext::oneapi::experimental::properties{
sycl::ext::oneapi::experimental::work_group_size<2>,
sycl::ext::oneapi::experimental::max_work_group_size<1>},
[]() {});

// expected-error-re@sycl/ext/oneapi/kernel_properties/properties.hpp:* {{static assertion failed due to requirement {{.+}}: work_group_size must not exceed max_work_group_size}}
Q.single_task(
sycl::ext::oneapi::experimental::properties{
sycl::ext::oneapi::experimental::work_group_size<2, 2>,
sycl::ext::oneapi::experimental::max_work_group_size<2, 1>},
[]() {});

// expected-error-re@sycl/ext/oneapi/kernel_properties/properties.hpp:* {{static assertion failed due to requirement {{.+}}: work_group_size must not exceed max_work_group_size}}
Q.single_task(
sycl::ext::oneapi::experimental::properties{
sycl::ext::oneapi::experimental::work_group_size<2, 2, 2>,
sycl::ext::oneapi::experimental::max_work_group_size<2, 2, 1>},
[]() {});
}

void check_max_linear_work_group_size() {
sycl::queue Q;

// expected-error-re@sycl/ext/oneapi/kernel_properties/properties.hpp:* {{static assertion failed due to requirement {{.+}}: work_group_size must not exceed max_linear_work_group_size}}
Q.single_task(
sycl::ext::oneapi::experimental::properties{
sycl::ext::oneapi::experimental::work_group_size<2>,
sycl::ext::oneapi::experimental::max_linear_work_group_size<1>},
[]() {});

// expected-error-re@sycl/ext/oneapi/kernel_properties/properties.hpp:* {{static assertion failed due to requirement {{.+}}: work_group_size must not exceed max_linear_work_group_size}}
Q.single_task(
sycl::ext::oneapi::experimental::properties{
sycl::ext::oneapi::experimental::work_group_size<2, 4>,
sycl::ext::oneapi::experimental::max_linear_work_group_size<7>},
[]() {});

// expected-error-re@sycl/ext/oneapi/kernel_properties/properties.hpp:* {{static assertion failed due to requirement {{.+}}: work_group_size must not exceed max_linear_work_group_size}}
Q.single_task(
sycl::ext::oneapi::experimental::properties{
sycl::ext::oneapi::experimental::work_group_size<2, 4, 2>,
sycl::ext::oneapi::experimental::max_linear_work_group_size<15>},
[]() {});
}

int main() {
check_work_group_size();
check_work_group_size_hint();
check_sub_group_size();
check_max_work_group_size();
check_max_linear_work_group_size();
return 0;
}

0 comments on commit f18ed8c

Please sign in to comment.