Skip to content

Commit

Permalink
sparse: replace macros with constexpr bools (#2260)
Browse files Browse the repository at this point in the history
  • Loading branch information
cwpearson authored Jun 28, 2024
1 parent 31be658 commit bbfc3ff
Show file tree
Hide file tree
Showing 3 changed files with 149 additions and 141 deletions.
61 changes: 35 additions & 26 deletions sparse/impl/KokkosSparse_spadd_numeric_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -169,10 +169,11 @@ struct UnsortedNumericSumFunctor {
const CcolindsT Bpos;
};

// Helper macro to check that two types are the same (ignoring const)
#define SAME_TYPE(A, B) \
std::is_same<typename std::remove_const<A>::type, \
typename std::remove_const<B>::type>::value
// Two types are the same (ignoring const)
template <typename T, typename U>
constexpr bool spadd_numeric_same_type =
std::is_same_v<typename std::remove_const_t<T>,
typename std::remove_const_t<U>>;

template <
typename execution_space, typename KernelHandle, typename alno_row_view_t,
Expand All @@ -193,46 +194,56 @@ void spadd_numeric_impl(
typedef typename KernelHandle::nnz_scalar_t scalar_type;
// Check that A/B/C data types match KernelHandle types, and that C data types
// are nonconst (doesn't matter if A/B types are const)
static_assert(SAME_TYPE(ascalar_t, scalar_type),
static_assert(spadd_numeric_same_type<ascalar_t, scalar_type>,
"A scalar type must match handle scalar type");
static_assert(SAME_TYPE(bscalar_t, scalar_type),
static_assert(spadd_numeric_same_type<bscalar_t, scalar_type>,
"B scalar type must match handle scalar type");
static_assert(SAME_TYPE(typename alno_row_view_t::value_type, size_type),
"add_symbolic: A size_type must match KernelHandle size_type "
"(const doesn't matter)");
static_assert(SAME_TYPE(typename blno_row_view_t::value_type, size_type),
"add_symbolic: B size_type must match KernelHandle size_type "
"(const doesn't matter)");
static_assert(
SAME_TYPE(typename clno_row_view_t::non_const_value_type, size_type),
spadd_numeric_same_type<typename alno_row_view_t::value_type, size_type>,
"add_symbolic: A size_type must match KernelHandle size_type "
"(const doesn't matter)");
static_assert(
spadd_numeric_same_type<typename blno_row_view_t::value_type, size_type>,
"add_symbolic: B size_type must match KernelHandle size_type "
"(const doesn't matter)");
static_assert(
spadd_numeric_same_type<typename clno_row_view_t::non_const_value_type,
size_type>,
"add_symbolic: C size_type must match KernelHandle size_type)");
static_assert(SAME_TYPE(typename alno_nnz_view_t::value_type, ordinal_type),
static_assert(spadd_numeric_same_type<typename alno_nnz_view_t::value_type,
ordinal_type>,
"add_symbolic: A entry type must match KernelHandle entry type "
"(aka nnz_lno_t, and const doesn't matter)");
static_assert(SAME_TYPE(typename blno_nnz_view_t::value_type, ordinal_type),
static_assert(spadd_numeric_same_type<typename blno_nnz_view_t::value_type,
ordinal_type>,
"add_symbolic: B entry type must match KernelHandle entry type "
"(aka nnz_lno_t, and const doesn't matter)");
static_assert(SAME_TYPE(typename clno_nnz_view_t::value_type, ordinal_type),
static_assert(spadd_numeric_same_type<typename clno_nnz_view_t::value_type,
ordinal_type>,
"add_symbolic: C entry type must match KernelHandle entry type "
"(aka nnz_lno_t)");
static_assert(std::is_same<typename clno_nnz_view_t::non_const_value_type,
typename clno_nnz_view_t::value_type>::value,
static_assert(std::is_same_v<typename clno_nnz_view_t::non_const_value_type,
typename clno_nnz_view_t::value_type>,
"add_symbolic: C entry type must not be const");
static_assert(
SAME_TYPE(typename ascalar_nnz_view_t::value_type, scalar_type),
spadd_numeric_same_type<typename ascalar_nnz_view_t::value_type,
scalar_type>,
"add_symbolic: A scalar type must match KernelHandle entry type (aka "
"nnz_lno_t, and const doesn't matter)");
static_assert(
SAME_TYPE(typename bscalar_nnz_view_t::value_type, scalar_type),
spadd_numeric_same_type<typename bscalar_nnz_view_t::value_type,
scalar_type>,
"add_symbolic: B scalar type must match KernelHandle entry type (aka "
"nnz_lno_t, and const doesn't matter)");
static_assert(
SAME_TYPE(typename cscalar_nnz_view_t::value_type, scalar_type),
spadd_numeric_same_type<typename cscalar_nnz_view_t::value_type,
scalar_type>,
"add_symbolic: C scalar type must match KernelHandle entry type (aka "
"nnz_lno_t)");
static_assert(std::is_same<typename cscalar_nnz_view_t::non_const_value_type,
typename cscalar_nnz_view_t::value_type>::value,
"add_symbolic: C scalar type must not be const");
static_assert(
std::is_same_v<typename cscalar_nnz_view_t::non_const_value_type,
typename cscalar_nnz_view_t::value_type>,
"add_symbolic: C scalar type must not be const");
typedef Kokkos::RangePolicy<execution_space, size_type> range_type;
auto addHandle = kernel_handle->get_spadd_handle();
// rowmap length can be 0 or 1 if #rows is 0.
Expand Down Expand Up @@ -269,8 +280,6 @@ void spadd_numeric_impl(
addHandle->set_call_numeric();
}

#undef SAME_TYPE

} // namespace Impl
} // namespace KokkosSparse

Expand Down
34 changes: 19 additions & 15 deletions sparse/impl/KokkosSparse_spadd_symbolic_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,11 @@
namespace KokkosSparse {
namespace Impl {

// Helper macro to check that two types are the same (ignoring const)
#define SAME_TYPE(A, B) \
std::is_same<typename std::remove_const<A>::type, \
typename std::remove_const<B>::type>::value
// Two types are the same (ignoring const)
template <typename T, typename U>
constexpr bool spadd_symbolic_same_type =
std::is_same_v<typename std::remove_const_t<T>,
typename std::remove_const_t<U>>;

// get C rowmap for sorted input
template <typename size_type, typename ordinal_type, typename ARowPtrsT,
Expand Down Expand Up @@ -479,29 +480,34 @@ void spadd_symbolic_impl(
// Check that A/B/C data types match KernelHandle types, and that C data types
// are nonconst (doesn't matter if A/B types are const)
static_assert(
SAME_TYPE(typename alno_row_view_t_::non_const_value_type, size_type),
spadd_symbolic_same_type<typename alno_row_view_t_::non_const_value_type,
size_type>,
"add_symbolic: A size_type must match KernelHandle size_type (const "
"doesn't matter)");
static_assert(
SAME_TYPE(typename blno_row_view_t_::non_const_value_type, size_type),
spadd_symbolic_same_type<typename blno_row_view_t_::non_const_value_type,
size_type>,
"add_symbolic: B size_type must match KernelHandle size_type (const "
"doesn't matter)");
static_assert(
SAME_TYPE(typename clno_row_view_t_::non_const_value_type, size_type),
spadd_symbolic_same_type<typename clno_row_view_t_::non_const_value_type,
size_type>,
"add_symbolic: C size_type must match KernelHandle size_type)");
static_assert(std::is_same<typename clno_row_view_t_::non_const_value_type,
typename clno_row_view_t_::value_type>::value,
static_assert(std::is_same_v<typename clno_row_view_t_::non_const_value_type,
typename clno_row_view_t_::value_type>,
"add_symbolic: C size_type must not be const");
static_assert(
SAME_TYPE(typename alno_nnz_view_t_::non_const_value_type, ordinal_type),
spadd_symbolic_same_type<typename alno_nnz_view_t_::non_const_value_type,
ordinal_type>,
"add_symbolic: A entry type must match KernelHandle entry type (aka "
"nnz_lno_t, and const doesn't matter)");
static_assert(
SAME_TYPE(typename blno_nnz_view_t_::non_const_value_type, ordinal_type),
spadd_symbolic_same_type<typename blno_nnz_view_t_::non_const_value_type,
ordinal_type>,
"add_symbolic: B entry type must match KernelHandle entry type (aka "
"nnz_lno_t, and const doesn't matter)");
static_assert(std::is_same<typename clno_row_view_t_::non_const_value_type,
typename clno_row_view_t_::value_type>::value,
static_assert(std::is_same_v<typename clno_row_view_t_::non_const_value_type,
typename clno_row_view_t_::value_type>,
"add_symbolic: C entry type must not be const");
// symbolic just needs to compute c_rowmap
// easy for sorted, but for unsorted is easiest to just compute the whole sum
Expand Down Expand Up @@ -594,8 +600,6 @@ void spadd_symbolic_impl(
addHandle->set_call_numeric(false);
}

#undef SAME_TYPE

} // namespace Impl
} // namespace KokkosSparse

Expand Down
Loading

0 comments on commit bbfc3ff

Please sign in to comment.