Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BEAM solver #160

Open
wants to merge 19 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions GNUmakefile
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,8 @@ abs_prefix := ${abspath ${prefix}}
export CXX blas blas_int blas_threaded openmp static gpu_backend

CXXFLAGS += -O3 -std=c++17 -Wall -Wshadow -pedantic -MMD
NVCCFLAGS += -O3 -std=c++11 --compiler-options '-Wall -Wno-unused-function'
HIPCCFLAGS += -std=c++14 -DTCE_HIP -fno-gpu-rdc
NVCCFLAGS += -O3 -std=c++17 --compiler-options '-Wall -Wno-unused-function'
HIPCCFLAGS += -std=c++17 -DTCE_HIP -fno-gpu-rdc

force: ;

Expand Down Expand Up @@ -480,6 +480,7 @@ ifneq ($(only_unit),1)
src/internal/internal_getrf.cc \
src/internal/internal_getrf_nopiv.cc \
src/internal/internal_getrf_tntpiv.cc \
src/internal/internal_getrf_addmod.cc \
src/internal/internal_hbnorm.cc \
src/internal/internal_hebr.cc \
src/internal/internal_hegst.cc \
Expand All @@ -501,6 +502,8 @@ ifneq ($(only_unit),1)
src/internal/internal_trnorm.cc \
src/internal/internal_trsm.cc \
src/internal/internal_trsmA.cc \
src/internal/internal_trsm_addmod.cc \
src/internal/internal_trsmA_addmod.cc \
src/internal/internal_trtri.cc \
src/internal/internal_trtrm.cc \
src/internal/internal_ttlqt.cc \
Expand Down Expand Up @@ -530,6 +533,7 @@ cuda_src := \
src/cuda/device_synorm.cu \
src/cuda/device_transpose.cu \
src/cuda/device_trnorm.cu \
src/cuda/device_trsm_addmod.cu \
src/cuda/device_tzadd.cu \
src/cuda/device_tzcopy.cu \
src/cuda/device_tzscale.cu \
Expand Down Expand Up @@ -599,13 +603,17 @@ ifneq ($(only_unit),1)
src/gesv_mixed_gmres.cc \
src/gesv_nopiv.cc \
src/gesv_rbt.cc \
src/gesv_addmod.cc \
src/gesv_addmod_ir.cc \
src/getrf.cc \
src/getrf_nopiv.cc \
src/getrf_tntpiv.cc \
src/getrf_addmod.cc \
src/getri.cc \
src/getriOOP.cc \
src/getrs.cc \
src/getrs_nopiv.cc \
src/getrs_addmod.cc \
src/hb2st.cc \
src/hbmm.cc \
src/he2hb.cc \
Expand Down Expand Up @@ -657,6 +665,9 @@ ifneq ($(only_unit),1)
src/trsm.cc \
src/trsmA.cc \
src/trsmB.cc \
src/trsm_addmod.cc \
src/trsmA_addmod.cc \
src/trsmB_addmod.cc \
src/trtri.cc \
src/trtrm.cc \
src/unmlq.cc \
Expand All @@ -667,6 +678,8 @@ ifneq ($(only_unit),1)
src/work/work_trmm.cc \
src/work/work_trsm.cc \
src/work/work_trsmA.cc \
src/work/work_trsm_addmod.cc \
src/work/work_trsmA_addmod.cc \
# End. Add alphabetically.
endif

Expand Down
88 changes: 88 additions & 0 deletions include/slate/addmod.hh
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
// Copyright (c) 2022, University of Tennessee. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
// This program is free software: you can redistribute it and/or modify it under
// the terms of the BSD 3-Clause license. See the accompanying LICENSE file.

//------------------------------------------------------------------------------
/// @file
///
#ifndef SLATE_ADDMOD_HH
#define SLATE_ADDMOD_HH

#include "slate/Matrix.hh"

#include <vector>

namespace slate {

//------------------------------------------------------------------------------
// auxiliary type for modifications and Woodbury matrices

template <typename scalar_t>
class AddModFactors {
using real_t = blas::real_type<scalar_t>;
public:
int64_t block_size;
int64_t num_modifications;
BlockFactor factorType;

Matrix<scalar_t> A;
Matrix<scalar_t> U_factors;
Matrix<scalar_t> VT_factors;
std::vector<std::vector<real_t>> singular_values;
std::vector<std::vector<scalar_t>> modifications;
std::vector<std::vector<int64_t>> modification_indices;
Matrix<scalar_t> capacitance_matrix;
Pivots capacitance_pivots;

Matrix<scalar_t> S_VT_Rinv;
Matrix<scalar_t> Linv_U;
};

//------------------------------------------------------------------------------
// Routines

template <typename scalar_t>
void gesv_addmod(Matrix<scalar_t>& A, AddModFactors<scalar_t>& W, Matrix<scalar_t>& B,
Options const& opts = Options());

template <typename scalar_t>
void gesv_addmod_ir( Matrix<scalar_t>& A, AddModFactors<scalar_t>& W,
Matrix<scalar_t>& B,
Matrix<scalar_t>& X,
int& iter,
Options const& opts);

template <typename scalar_t>
void getrf_addmod(Matrix<scalar_t>& A, AddModFactors<scalar_t>& W,
Options const& opts = Options());

template <typename scalar_t>
void getrs_addmod(AddModFactors<scalar_t>& W,
Matrix<scalar_t>& B,
Options const& opts);

template <typename scalar_t>
void trsm_addmod(
Side side, Uplo uplo,
scalar_t alpha, AddModFactors<scalar_t>& W,
Matrix<scalar_t>& B,
Options const& opts = Options());

template <typename scalar_t>
void trsmA_addmod(
Side side, Uplo uplo,
scalar_t alpha, AddModFactors<scalar_t>& W,
Matrix<scalar_t>& B,
Options const& opts = Options());

template <typename scalar_t>
void trsmB_addmod(
Side side, Uplo uplo,
scalar_t alpha, AddModFactors<scalar_t>& W,
Matrix<scalar_t>& B,
Options const& opts = Options());

} // namespace slate

#endif // SLATE_ADDMOD_HH
43 changes: 42 additions & 1 deletion include/slate/enums.hh
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#ifndef SLATE_ENUMS_HH
#define SLATE_ENUMS_HH

#include <algorithm>
#include <blas.hh>
#include <lapack.hh>

Expand Down Expand Up @@ -74,6 +75,9 @@ enum class Option : char {
MaxIterations, ///< maximum iteration count
UseFallbackSolver, ///< whether to fallback to a robust solver if iterations do not converge
PivotThreshold, ///< threshold for pivoting, >= 0, <= 1
AdditiveTolerance, ///< tolerance for additive modification, >= 0
UseWoodbury, ///< whether to apply the Woodbury formula
BlockFactor, ///< how to factor the diagonal blocks in the addmod solver

// Printing parameters
PrintVerbose = 50, ///< verbose, 0: no printing,
Expand All @@ -87,7 +91,6 @@ enum class Option : char {
PrintWidth, ///< width print format specifier
PrintPrecision, ///< precision print format specifier
///< For correct printing, PrintWidth = PrintPrecision + 6.

// Methods, listed alphabetically.
MethodCholQR = 60, ///< Select the algorithm to compute A^H * A
MethodEig, ///< Select the algorithm to compute eigenpairs of tridiagonal matrix
Expand Down Expand Up @@ -144,6 +147,44 @@ enum MOSI {
typedef short MOSI_State;



//------------------------------------------------------------------------------
enum class BlockFactor : char {
SVD,
QLP,
QRCP,
QR
};
inline BlockFactor str2blockfactor(const char* method)
{
std::string method_ = method;
std::transform(
method_.begin(), method_.end(), method_.begin(), ::tolower );

if (method_ == "svd")
return BlockFactor::SVD;
else if (method_ == "qlp")
return BlockFactor::QLP;
else if (method_ == "qrcp")
return BlockFactor::QRCP;
else if (method_ == "qr")
return BlockFactor::QR;
else
// throw slate::Exception("unknown BlockFactor");
return BlockFactor::SVD;
}

inline const char* blockfactor2str(BlockFactor method)
{
switch (method) {
case BlockFactor::SVD: return "SVD";
case BlockFactor::QLP: return "QLP";
case BlockFactor::QRCP: return "QRCP";
case BlockFactor::QR: return "QR";
default: return "error";
}
}

} // namespace slate

#endif // SLATE_ENUMS_HH
20 changes: 20 additions & 0 deletions include/slate/internal/device.hh
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,26 @@ void trnorm(
blas::real_type<scalar_t>* values, int64_t ldv,
int64_t batch_count, blas::Queue& queue);

//------------------------------------------------------------------------------
template <typename scalar_t>
void batch_trsm_addmod(
BlockFactor factorType,
blas::Layout layout,
blas::Side side,
blas::Uplo uplo,
int64_t mb,
int64_t nb,
int64_t ib,
scalar_t alpha,
std::vector<scalar_t*> Aarray, int64_t ldda,
std::vector<scalar_t*> Uarray, int64_t lddu,
std::vector<scalar_t*> VTarray, int64_t lddvt,
std::vector<blas::real_type<scalar_t>*> Sarray,
std::vector<scalar_t*> Barray, int64_t lddb,
std::vector<scalar_t*> dwork,
const size_t batch,
blas::Queue &queue );

//------------------------------------------------------------------------------
// In-place, square.
template <typename scalar_t>
Expand Down
9 changes: 7 additions & 2 deletions include/slate/method.hh
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,16 @@ namespace MethodTrsm {
const Method TrsmB = 2; ///< Select trsmB algorithm

template <typename TA, typename TB>
inline Method select_algo(TA& A, TB& B, Options const& opts) {
inline Method select_algo(TA& A, TB& B, Side side, Options const& opts) {
Target target = get_option( opts, Option::Target, Target::HostTask );
int n_devices = A.num_devices();

Method method = (B.nt() < 2 ? TrsmA : TrsmB);
Method method;
if (side == Side::Left) {
method = (A.nt()>B.nt() && B.nt() < 2 ? TrsmA : TrsmB);
} else {
method = (A.mt()>B.mt() && B.mt() < 2 ? TrsmA : TrsmB);
}

if (method == TrsmA && target == Target::Devices && n_devices > 1)
method = TrsmB;
Expand Down
5 changes: 5 additions & 0 deletions include/slate/slate.hh
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
#include "slate/types.hh"
#include "slate/print.hh"

#include "slate/addmod.hh"

//------------------------------------------------------------------------------
/// @namespace slate
/// SLATE's top-level namespace.
Expand Down Expand Up @@ -644,6 +646,9 @@ void getri(
Matrix<scalar_t>& B,
Options const& opts = Options());

//-----------------------------------------
// LU with additive modifications

//-----------------------------------------
// Cholesky

Expand Down
3 changes: 3 additions & 0 deletions include/slate/types.hh
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ public:
OptionValue(Target t) : i_(int(t))
{}

OptionValue(BlockFactor f) : i_(int(f))
{}

OptionValue(MethodEig m) : i_(int(m))
{}

Expand Down
Loading
Loading