-
-
Notifications
You must be signed in to change notification settings - Fork 187
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Issue 949 complex step derivative #950
Changes from 5 commits
54bf6e2
fb42412
4538112
f57a8de
a65c79d
d0d2b27
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
#ifndef STAN_MATH_PRIM_SCAL_FUNCTOR_COMPLEX_STEP_DERIVATIVE_HPP | ||
#define STAN_MATH_PRIM_SCAL_FUNCTOR_COMPLEX_STEP_DERIVATIVE_HPP | ||
|
||
#include <stan/math/prim/scal/fun/value_of.hpp> | ||
#include <stan/math/prim/scal/meta/return_type.hpp> | ||
|
||
#include <vector> | ||
#include <iostream> | ||
|
||
namespace stan { | ||
namespace math { | ||
|
||
/** | ||
* Return a double that has value of given functor F with signature | ||
* (complex, std::vector<double>, std::vector<int>, stream*) : complex | ||
* | ||
* @tparam F type of functor F | ||
* @param[in] f functor for the complex number evaluation, | ||
* must support @c std::complex<double> as arg. | ||
* @param[in] theta parameter where f and df/d(theta) is requested. | ||
* @param[in] x_r continuous data vector for the ODE. | ||
* @param[in] x_i integer data vector for the ODE. | ||
* @param[out] msgs the print stream for warning messages. | ||
* @return a var with value f(theta.val()) and derivative at theta. | ||
*/ | ||
template <typename F> | ||
double complex_step_derivative(const F& f, const double& theta, | ||
const std::vector<double>& x_r, | ||
const std::vector<int>& x_i, | ||
std::ostream* msgs) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we pass h through as an argument and have its default value be 1e-32 or whatever? That seems like a handy thing to control. |
||
return f(theta, x_r, x_i, msgs); | ||
} | ||
|
||
} // namespace math | ||
} // namespace stan | ||
|
||
#endif |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
#ifndef STAN_MATH_REV_SCAL_FUNCTOR_COMPLEX_STEP_DERIVATIVE_HPP | ||
#define STAN_MATH_REV_SCAL_FUNCTOR_COMPLEX_STEP_DERIVATIVE_HPP | ||
|
||
#include <stan/math/prim/scal/fun/value_of.hpp> | ||
#include <stan/math/prim/scal/meta/return_type.hpp> | ||
#include <stan/math/prim/scal/functor/complex_step_derivative.hpp> | ||
#include <stan/math/rev/core/precomp_v_vari.hpp> | ||
#include <stan/math/rev/core/vari.hpp> | ||
#include <stan/math/rev/core/var.hpp> | ||
|
||
#include <vector> | ||
#include <complex> | ||
#include <iostream> | ||
|
||
namespace stan { | ||
namespace math { | ||
|
||
/** | ||
* Return a var that has value of given functor F and derivative | ||
* of df/d(theta), using complex step derivative | ||
* approximation. "f" does not have to support "var" | ||
* type, as its signature should be | ||
* (complex, std::vector<double>, std::vector<int>, stream*) : complex | ||
* | ||
* @tparam F type of functor F | ||
* @param[in] f functor for the complex number evaluation, | ||
* must support @c std::complex<double> as arg. | ||
* @param[in] theta parameter where f and df/d(theta) is requested. | ||
* @param[in] x_r continuous data vector for the ODE. | ||
* @param[in] x_i integer data vector for the ODE. | ||
* @param[out] msgs the print stream for warning messages. | ||
* @return a var with value f(theta.val()) and derivative at theta. | ||
*/ | ||
template <typename F> | ||
stan::math::var complex_step_derivative(const F& f, | ||
const stan::math::var& theta, | ||
const std::vector<double>& x_r, | ||
const std::vector<int>& x_i, | ||
std::ostream* msgs) { | ||
using stan::math::var; | ||
using std::complex; | ||
static double h = 1.e-32; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The precision is O(h^2) so why does h need to be 1e-32? Usually people do 1e-8. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You are right, < 1e-10 it becomes insensitive. I was showing merely we can do this way less than finite difference. |
||
const double theta_d = theta.val(); | ||
const double res = complex_step_derivative(f, theta_d, x_r, x_i, msgs); | ||
const double g | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It seems as if you should be able to do this in one function call that yields a There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not sure what you mean. We need both |
||
= std::imag(f(complex<double>(theta_d, h), x_r, x_i, msgs)) / h; | ||
return var(new stan::math::precomp_v_vari(res, theta.vi_, g)); | ||
} | ||
|
||
} // namespace math | ||
} // namespace stan | ||
|
||
#endif |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
#include <gtest/gtest.h> | ||
#include <stan/math.hpp> | ||
#include <stan/math/rev/scal/functor/complex_step_derivative.hpp> | ||
#include <test/unit/util.hpp> | ||
#include <iostream> | ||
#include <limits> | ||
#include <sstream> | ||
#include <vector> | ||
|
||
struct ComplexStepDerivativeScalTest : public ::testing::Test { | ||
struct Fexp { | ||
template <typename T> | ||
inline T operator()(const T &theta, const std::vector<double> &x_r, | ||
const std::vector<int> &x_i, std::ostream *msgs) const { | ||
return exp(theta) / sqrt(theta) - 0.5 * exp(theta) * pow(theta, -1.5); | ||
} | ||
}; | ||
|
||
void SetUp() {} | ||
Fexp f; | ||
const std::vector<double> x_r; | ||
const std::vector<int> x_i; | ||
std::ostream *msgs = nullptr; | ||
}; | ||
|
||
TEST_F(ComplexStepDerivativeScalTest, func_exp_sqrt) { | ||
using stan::math::complex_step_derivative; | ||
using stan::math::var; | ||
|
||
/* f near x = 0 has very large derivative */ | ||
var x = 0.01; | ||
var y = complex_step_derivative(f, x, x_r, x_i, msgs); | ||
|
||
ASSERT_FLOAT_EQ(y.val(), f(x.val(), x_r, x_i, msgs)); | ||
|
||
std::vector<stan::math::var> xv{x}; | ||
std::vector<double> g1, g; | ||
var y1 = f(x, x_r, x_i, msgs); | ||
stan::math::set_zero_all_adjoints(); | ||
y1.grad(xv, g1); | ||
stan::math::set_zero_all_adjoints(); | ||
y.grad(xv, g); | ||
ASSERT_FLOAT_EQ(g[0], g1[0]); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The prim version should be fully templated (return a T, and theta is const T&). No reason this wouldn't work with the higher order stuff.