diff --git a/stan/math/prim/scal/functor/complex_step_derivative.hpp b/stan/math/prim/scal/functor/complex_step_derivative.hpp new file mode 100644 index 00000000000..28f192e6905 --- /dev/null +++ b/stan/math/prim/scal/functor/complex_step_derivative.hpp @@ -0,0 +1,39 @@ +#ifndef STAN_MATH_PRIM_SCAL_FUNCTOR_COMPLEX_STEP_DERIVATIVE_HPP +#define STAN_MATH_PRIM_SCAL_FUNCTOR_COMPLEX_STEP_DERIVATIVE_HPP + +#include +#include + +#include +#include + +namespace stan { +namespace math { + +/** + * Return a double that has value of given functor F with signature + * (complex, std::vector, std::vector, stream*) : complex + * + * @tparam F type of functor F + * @param[in] f functor for the complex number evaluation, + * must support @c std::complex as arg. + * @param[in] theta parameter where f and df/d(theta) is requested. + * @param[in] x_r continuous data vector for the ODE. + * @param[in] x_i integer data vector for the ODE. + * @param[in] h complex step size + * @param[out] msgs the print stream for warning messages. + * @return a var with value f(theta.val()) and derivative at theta. + */ +template +double complex_step_derivative(const F& f, const double& theta, + const std::vector& x_r, + const std::vector& x_i, + const double h, + std::ostream* msgs) { + return f(theta, x_r, x_i, msgs); +} + +} // namespace math +} // namespace stan + +#endif diff --git a/stan/math/rev/scal/functor/complex_step_derivative.hpp b/stan/math/rev/scal/functor/complex_step_derivative.hpp new file mode 100644 index 00000000000..986cc0bb80b --- /dev/null +++ b/stan/math/rev/scal/functor/complex_step_derivative.hpp @@ -0,0 +1,75 @@ +#ifndef STAN_MATH_REV_SCAL_FUNCTOR_COMPLEX_STEP_DERIVATIVE_HPP +#define STAN_MATH_REV_SCAL_FUNCTOR_COMPLEX_STEP_DERIVATIVE_HPP + +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace stan { +namespace math { + +/** + * Return a var that has value of given functor F and derivative + * of df/d(theta), using complex step derivative + * approximation. "f" does not have to support "var" + * type, as its signature should be + * (complex, std::vector, std::vector, stream*) : complex + * + * @tparam F type of functor F + * @param[in] f functor for the complex number evaluation, + * must support @c std::complex as arg. + * @param[in] theta parameter where f and df/d(theta) is requested. + * @param[in] x_r continuous data vector for the ODE. + * @param[in] x_i integer data vector for the ODE. + * @param[in] h complex step size + * @param[out] msgs the print stream for warning messages. + * @return a var with value f(theta.val()) and derivative at theta. + */ +template +stan::math::var complex_step_derivative(const F& f, + const stan::math::var& theta, + const std::vector& x_r, + const std::vector& x_i, + const double h, + std::ostream* msgs) { + using stan::math::var; + using std::complex; + const double theta_d = theta.val(); + const complex res = f(complex(theta_d, h), x_r, x_i, msgs); + const double fx = std::real(res); + const double g = std::imag(res) / h; + return var(new stan::math::precomp_v_vari(fx, theta.vi_, g)); +} + +/** + * CSDA, default h version, with h = 1.E-20 + * + * @tparam F type of functor F + * @param[in] f functor for the complex number evaluation, + * must support @c std::complex as arg. + * @param[in] theta parameter where f and df/d(theta) is requested. + * @param[in] x_r continuous data vector for the ODE. + * @param[in] x_i integer data vector for the ODE. + * @param[out] msgs the print stream for warning messages. + * @return a var with value f(theta.val()) and derivative at theta. + */ +template +stan::math::var complex_step_derivative(const F& f, + const stan::math::var& theta, + const std::vector& x_r, + const std::vector& x_i, + std::ostream* msgs) { + return complex_step_derivative(f, theta, x_r, x_i, 1.E-20, msgs); +} + +} // namespace math +} // namespace stan + +#endif diff --git a/test/unit/math/rev/scal/functor/complex_step_derivative_test.cpp b/test/unit/math/rev/scal/functor/complex_step_derivative_test.cpp new file mode 100644 index 00000000000..e1b0c7309b9 --- /dev/null +++ b/test/unit/math/rev/scal/functor/complex_step_derivative_test.cpp @@ -0,0 +1,64 @@ +#include +#include +#include +#include +#include +#include +#include +#include + +struct ComplexStepDerivativeScalTest : public ::testing::Test { + struct Fexp { + template + inline T operator()(const T &theta, const std::vector &x_r, + const std::vector &x_i, std::ostream *msgs) const { + return exp(theta) / sqrt(theta) - 0.5 * exp(theta) * pow(theta, -1.5); + } + }; + + void SetUp() {} + Fexp f; + const std::vector x_r; + const std::vector x_i; + std::ostream *msgs = nullptr; +}; + +TEST_F(ComplexStepDerivativeScalTest, func_exp_sqrt) { + using stan::math::complex_step_derivative; + using stan::math::var; + + /* f near x = 0 has very large derivative */ + var x = 0.01; + var y = complex_step_derivative(f, x, x_r, x_i, 1.e-20, msgs); + + ASSERT_FLOAT_EQ(y.val(), f(x.val(), x_r, x_i, msgs)); + + std::vector xv{x}; + std::vector g1, g; + var y1 = f(x, x_r, x_i, msgs); + stan::math::set_zero_all_adjoints(); + y1.grad(xv, g1); + stan::math::set_zero_all_adjoints(); + y.grad(xv, g); + ASSERT_FLOAT_EQ(g[0], g1[0]); +} + +TEST_F(ComplexStepDerivativeScalTest, func_exp_sqrt_default_h) { + using stan::math::complex_step_derivative; + using stan::math::var; + + /* f near x = 0 has very large derivative */ + var x = 0.01; + var y = complex_step_derivative(f, x, x_r, x_i, msgs); + + ASSERT_FLOAT_EQ(y.val(), f(x.val(), x_r, x_i, msgs)); + + std::vector xv{x}; + std::vector g1, g; + var y1 = f(x, x_r, x_i, msgs); + stan::math::set_zero_all_adjoints(); + y1.grad(xv, g1); + stan::math::set_zero_all_adjoints(); + y.grad(xv, g); + ASSERT_FLOAT_EQ(g[0], g1[0]); +}