Skip to content

Commit

Permalink
Incorporate Feedback by Olivier Parcollet
Browse files Browse the repository at this point in the history
- Various simplifications / refactorings
- Add documentation
- Remove address space generic calloc
- Rename lapack_cuda_ -> cusolver_
- Rename lapack_cxx_ -> cxx_
  • Loading branch information
Wentzell committed Aug 22, 2023
1 parent 1e34be8 commit 9bad8a6
Show file tree
Hide file tree
Showing 29 changed files with 185 additions and 197 deletions.
2 changes: 1 addition & 1 deletion c++/nda/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
file(GLOB_RECURSE sources *.cpp)
if(NOT CudaSupport)
list(FILTER sources EXCLUDE REGEX "(cublas_interface|cuda_interface)")
list(FILTER sources EXCLUDE REGEX "(cublas_interface|cusolver_interface)")
endif()
add_library(${PROJECT_NAME}_c ${sources})
add_library(${PROJECT_NAME}::${PROJECT_NAME}_c ALIAS ${PROJECT_NAME}_c)
Expand Down
9 changes: 4 additions & 5 deletions c++/nda/basic_array.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,8 @@ namespace nda {
explicit basic_array(basic_array const &x) noexcept : lay(x.indexmap()), sto(x.sto) {}

/// Makes a deep copy, given a basic_array with a different container policy
template <char Algebra_other, typename ContainerPolicy_other>
explicit basic_array(basic_array<ValueType, Rank, LayoutPolicy, Algebra_other, ContainerPolicy_other> x) noexcept
template <char AlgebraOther, typename ContainerPolicyOther>
explicit basic_array(basic_array<ValueType, Rank, LayoutPolicy, AlgebraOther, ContainerPolicyOther> x) noexcept
: lay(x.indexmap()), sto(std::move(x.storage())) {}

/**
Expand All @@ -112,12 +112,11 @@ namespace nda {
* @param i0, is ... are the extents (lengths) in each dimension
*/
template <std::integral... Int>
requires(sizeof...(Int) == Rank)
explicit basic_array(Int... is) noexcept {
static_assert(sizeof...(Int) == Rank, "Incorrect number of extents");
// Constructing layout and storage in constructor body improves error message for wrong # of args
lay = layout_t{std::array{long(is)...}};
sto = storage_t{lay.size()};
// It would be more natural to construct lay, storage from the start, but the error message in case of false # of parameters (very common)
// is better like this. FIXME to be tested in benchs
}

/**
Expand Down
8 changes: 4 additions & 4 deletions c++/nda/blas/dot.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,12 @@ namespace nda::blas {
return x * y;
} else {
static_assert(have_same_value_type_v<X, Y>, "Vectors must have same value type");
static_assert(mem::have_compatible_addr_space_v<X, Y>, "Vectors must have compatible memory address space");
static_assert(mem::have_compatible_addr_space<X, Y>, "Vectors must have compatible memory address space");
static_assert(is_blas_lapack_v<get_value_t<X>>, "Vectors hold value_type incompatible with blas");

EXPECTS(x.shape() == y.shape());

if constexpr (mem::have_device_compatible_addr_space_v<X, Y>) {
if constexpr (mem::have_device_compatible_addr_space<X, Y>) {
#if defined(NDA_HAVE_DEVICE)
return device::dot(x.size(), x.data(), x.indexmap().strides()[0], y.data(), y.indexmap().strides()[0]);
#else
Expand All @@ -56,14 +56,14 @@ namespace nda::blas {
return conj(x) * y;
} else {
static_assert(have_same_value_type_v<X, Y>, "Vectors must have same value type");
static_assert(mem::have_compatible_addr_space_v<X, Y>, "Vectors must have same memory address space");
static_assert(mem::have_compatible_addr_space<X, Y>, "Vectors must have same memory address space");
static_assert(is_blas_lapack_v<get_value_t<X>>, "Vectors hold value_type incompatible with blas");

EXPECTS(x.shape() == y.shape());

if constexpr (!is_complex_v<get_value_t<X>>) {
return dot(x, y);
} else if constexpr (mem::have_device_compatible_addr_space_v<X, Y>) {
} else if constexpr (mem::have_device_compatible_addr_space<X, Y>) {
#if defined(NDA_HAVE_DEVICE)
return device::dotc(x.size(), x.data(), x.indexmap().strides()[0], y.data(), y.indexmap().strides()[0]);
#else
Expand Down
5 changes: 2 additions & 3 deletions c++/nda/blas/gemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ namespace nda::blas {

using A = decltype(a);
using B = decltype(b);
static_assert(mem::have_compatible_addr_space_v<A, B, C>, "Matrices must have compatible memory address space");
static_assert(mem::have_compatible_addr_space<A, B, C>, "Matrices must have compatible memory address space");

EXPECTS(a.extent(1) == b.extent(0));
EXPECTS(a.extent(0) == c.extent(0));
Expand All @@ -91,14 +91,13 @@ namespace nda::blas {
// c is in C order: compute the transpose of the product in Fortran order
if constexpr (has_C_layout<C>) {
gemm(alpha, transpose(y), transpose(x), beta, transpose(std::forward<C>(c)));
return;
} else { // c is in Fortran order
char op_a = get_op<conj_A, /*transpose =*/has_C_layout<A>>;
char op_b = get_op<conj_B, /*transpose =*/has_C_layout<B>>;
auto [m, k] = a.shape();
auto n = b.extent(1);

if constexpr (mem::have_device_compatible_addr_space_v<A, B, C>) {
if constexpr (mem::have_device_compatible_addr_space<A, B, C>) {
#if defined(NDA_HAVE_DEVICE)
device::gemm(op_a, op_b, m, n, k, alpha, a.data(), get_ld(a), b.data(), get_ld(b), beta, c.data(), get_ld(c));
#else
Expand Down
25 changes: 17 additions & 8 deletions c++/nda/blas/gemm_batch.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,15 @@ namespace nda::blas {

/**
* Batched version of GEMM taking vectors of matrices as arguments
*
* @tparam VBATCH bool, allow for variable size matrices
* @tparam X Matrix type for input vector vx, fulfills Matrix concept
* @tparam Y Matrix type for input vector vy, fulfills Matrix concept
* @tparam C Matrix type for output vector vc, fulfills MemoryMatrix concept
* @param alpha Scalar prefactor for all gemm operations
* @param vx Vector of input matrices to be multiplied from the left
* @param vy Vector of input matrices to be multiplied from the right
* @param vc Vector of output matrices
*/
template <bool VBATCH = false, Matrix X, Matrix Y, MemoryMatrix C>
requires((MemoryMatrix<X> or is_conj_array_expr<X>) and //
Expand All @@ -37,6 +46,7 @@ namespace nda::blas {
if (vx.empty()) return;
int batch_count = vx.size();

// FIXMEOP : move into tools : repeated in every function
auto to_mat = []<typename Z>(Z &z) -> auto & {
if constexpr (is_conj_array_expr<Z>)
return std::get<0>(z.a);
Expand All @@ -52,7 +62,7 @@ namespace nda::blas {

using A = decltype(a0);
using B = decltype(b0);
static_assert(mem::have_compatible_addr_space_v<A, B, C>, "Matrices must have same memory address space");
static_assert(mem::have_compatible_addr_space<A, B, C>, "Matrices must have same memory address space");

// c is in C order: compute the transpose of the product in Fortran order
if constexpr (has_C_layout<C>) {
Expand Down Expand Up @@ -91,8 +101,8 @@ namespace nda::blas {
if constexpr (VBATCH) {

// Create vectors of size 'batch_count + 1' as required by Magma
vector<int, heap<vec_adr_spc>> vm(batch_count + 1), vk(batch_count + 1), vn(batch_count + 1), vlda(batch_count + 1), vldb(batch_count + 1),
vldc(batch_count + 1);
nda::vector<int, heap<vec_adr_spc>> vm(batch_count + 1), vk(batch_count + 1), vn(batch_count + 1), vlda(batch_count + 1),
vldb(batch_count + 1), vldc(batch_count + 1);

for (auto i : range(batch_count)) {
auto &ai = to_mat(vx[i]);
Expand All @@ -112,7 +122,7 @@ namespace nda::blas {
vldc[i] = get_ld(ci);
}

if constexpr (mem::have_device_compatible_addr_space_v<A, B, C>) {
if constexpr (mem::have_device_compatible_addr_space<A, B, C>) {
#if defined(NDA_HAVE_DEVICE)
device::gemm_vbatch(op_a, op_b, vm.data(), vn.data(), vk.data(), alpha, a_ptrs.data(), vlda.data(), b_ptrs.data(), vldb.data(), beta,
c_ptrs.data(), vldc.data(), batch_count);
Expand All @@ -132,7 +142,7 @@ namespace nda::blas {
auto [m, k] = a0.shape();
auto n = b0.extent(1);

if constexpr (mem::have_device_compatible_addr_space_v<A, B, C>) {
if constexpr (mem::have_device_compatible_addr_space<A, B, C>) {
#if defined(NDA_HAVE_DEVICE)
device::gemm_batch(op_a, op_b, m, n, k, alpha, a_ptrs.data(), get_ld(a0), b_ptrs.data(), get_ld(b0), beta, c_ptrs.data(), get_ld(c0),
batch_count);
Expand Down Expand Up @@ -183,7 +193,7 @@ namespace nda::blas {

using A = decltype(a);
using B = decltype(b);
static_assert(mem::have_compatible_addr_space_v<A, B, C>, "Arrays must have same memory address space");
static_assert(mem::have_compatible_addr_space<A, B, C>, "Arrays must have same memory address space");

auto _ = nda::range::all;
auto a0 = a(0, _, _);
Expand All @@ -200,7 +210,6 @@ namespace nda::blas {

// c is in C order: compute the transpose of the product in Fortran order
if constexpr (has_C_layout<C>) {
//Reconsider ..
gemm_batch_strided(alpha, transposed_view<1, 2>(y), transposed_view<1, 2>(x), beta, transposed_view<1, 2>(std::forward<C>(c)));
return;
} else { // c is in Fortran order
Expand All @@ -209,7 +218,7 @@ namespace nda::blas {
auto [m, k] = a0.shape();
auto n = b0.extent(1);

if constexpr (mem::have_device_compatible_addr_space_v<A, B, C>) {
if constexpr (mem::have_device_compatible_addr_space<A, B, C>) {
#if defined(NDA_HAVE_DEVICE)
device::gemm_batch_strided(op_a, op_b, m, n, k, alpha, a.data(), get_ld(a0), a.indexmap().strides()[0], b.data(), get_ld(b0), b.strides()[0],
beta, c.data(), get_ld(c0), c.indexmap().strides()[0], a.extent(0));
Expand Down
4 changes: 2 additions & 2 deletions c++/nda/blas/gemv.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ namespace nda::blas {
static constexpr bool conj_A = is_conj_array_expr<X>;

using A = decltype(a);
static_assert(mem::have_compatible_addr_space_v<A, B, C>);
static_assert(mem::have_compatible_addr_space<A, B, C>);

EXPECTS(a.extent(1) == b.extent(0));
EXPECTS(a.extent(0) == c.extent(0));
Expand All @@ -85,7 +85,7 @@ namespace nda::blas {
auto [m, n] = a.shape();
if constexpr (has_C_layout<A>) std::swap(m, n);

if constexpr (mem::have_device_compatible_addr_space_v<A, B, C>) {
if constexpr (mem::have_device_compatible_addr_space<A, B, C>) {
#if defined(NDA_HAVE_DEVICE)
device::gemv(op_a, m, n, alpha, a.data(), get_ld(a), b.data(), b.indexmap().strides()[0], beta, c.data(), c.indexmap().strides()[0]);
#else
Expand Down
8 changes: 4 additions & 4 deletions c++/nda/blas/ger.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ namespace nda::blas {
* * m has the correct dimension given a, b.
*/
template <MemoryVector X, MemoryVector Y, MemoryMatrix M>
requires(have_same_value_type_v<X, Y, M> and mem::have_compatible_addr_space_v<X, Y, M> and is_blas_lapack_v<get_value_t<X>>)
requires(have_same_value_type_v<X, Y, M> and mem::have_compatible_addr_space<X, Y, M> and is_blas_lapack_v<get_value_t<X>>)
void ger(get_value_t<X> alpha, X const &x, Y const &y, M &&m) {

EXPECTS(m.extent(0) == x.extent(0));
Expand All @@ -52,7 +52,7 @@ namespace nda::blas {
return;
}

if constexpr (mem::have_device_compatible_addr_space_v<X, Y, M>) {
if constexpr (mem::have_device_compatible_addr_space<X, Y, M>) {
#if defined(NDA_HAVE_DEVICE)
device::ger(m.extent(0), m.extent(1), alpha, x.data(), x.indexmap().strides()[0], y.data(), y.indexmap().strides()[0], m.data(), get_ld(m));
#else
Expand All @@ -64,7 +64,7 @@ namespace nda::blas {
}

/**
* Calculate the outer product of two (contiguous) arrays a and b
* Calculate the outer product of two contiguous arrays a and b
*
* $$ c_{i,j,k,...,u,v,w,...} = a_{i,j,k,...} * b_{u,v,w,...} $$
*
Expand All @@ -85,7 +85,7 @@ namespace nda::blas {
} else {
if (not a.is_contiguous()) NDA_RUNTIME_ERROR << "First argument to outer_product call has non-contiguous layout";
if (not b.is_contiguous()) NDA_RUNTIME_ERROR << "Second argument to outer_product call has non-contiguous layout";
auto res = zeros<get_value_t<A>, mem::get_addr_space<A>>(stdutil::join(a.shape(), b.shape()));
auto res = zeros<get_value_t<A>, mem::common_addr_space<A, B>>(stdutil::join(a.shape(), b.shape()));

auto a_vec = reshape(a, std::array{a.size()});
auto b_vec = reshape(b, std::array{b.size()});
Expand Down
Loading

0 comments on commit 9bad8a6

Please sign in to comment.