Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Issue 949 complex step derivative #950

Closed
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions stan/math/prim/scal/functor/complex_step_derivative.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
#ifndef STAN_MATH_PRIM_SCAL_FUNCTOR_COMPLEX_STEP_DERIVATIVE_HPP
#define STAN_MATH_PRIM_SCAL_FUNCTOR_COMPLEX_STEP_DERIVATIVE_HPP

#include <stan/math/prim/scal/fun/value_of.hpp>
#include <stan/math/prim/scal/meta/return_type.hpp>

#include <vector>
#include <iostream>

namespace stan {
namespace math {

/**
* Return a double that has value of given functor F with signature
* (complex, std::vector<double>, std::vector<int>, stream*) : complex
*
* @tparam F type of functor F
* @param[in] f functor for the complex number evaluation,
* must support @c std::complex<double> as arg.
* @param[in] theta parameter where f and df/d(theta) is requested.
* @param[in] x_r continuous data vector for the ODE.
* @param[in] x_i integer data vector for the ODE.
* @param[out] msgs the print stream for warning messages.
* @return a var with value f(theta.val()) and derivative at theta.
*/
template <typename F>
double complex_step_derivative(const F& f, const double& theta,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The prim version should be fully templated (return a T, and theta is const T&). No reason this wouldn't work with the higher order stuff.

const std::vector<double>& x_r,
const std::vector<int>& x_i,
std::ostream* msgs) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we pass h through as an argument and have its default value be 1e-32 or whatever? That seems like a handy thing to control.

return f(theta, x_r, x_i, msgs);
}

} // namespace math
} // namespace stan

#endif
53 changes: 53 additions & 0 deletions stan/math/rev/scal/functor/complex_step_derivative.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
#ifndef STAN_MATH_REV_SCAL_FUNCTOR_COMPLEX_STEP_DERIVATIVE_HPP
#define STAN_MATH_REV_SCAL_FUNCTOR_COMPLEX_STEP_DERIVATIVE_HPP

#include <stan/math/prim/scal/fun/value_of.hpp>
#include <stan/math/prim/scal/meta/return_type.hpp>
#include <stan/math/prim/scal/functor/complex_step_derivative.hpp>
#include <stan/math/rev/core/precomp_v_vari.hpp>
#include <stan/math/rev/core/vari.hpp>
#include <stan/math/rev/core/var.hpp>

#include <vector>
#include <complex>
#include <iostream>

namespace stan {
namespace math {

/**
* Return a var that has value of given functor F and derivative
* of df/d(theta), using complex step derivative
* approximation. "f" does not have to support "var"
* type, as its signature should be
* (complex, std::vector<double>, std::vector<int>, stream*) : complex
*
* @tparam F type of functor F
* @param[in] f functor for the complex number evaluation,
* must support @c std::complex<double> as arg.
* @param[in] theta parameter where f and df/d(theta) is requested.
* @param[in] x_r continuous data vector for the ODE.
* @param[in] x_i integer data vector for the ODE.
* @param[out] msgs the print stream for warning messages.
* @return a var with value f(theta.val()) and derivative at theta.
*/
template <typename F>
stan::math::var complex_step_derivative(const F& f,
const stan::math::var& theta,
const std::vector<double>& x_r,
const std::vector<int>& x_i,
std::ostream* msgs) {
using stan::math::var;
using std::complex;
static double h = 1.e-32;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The precision is O(h^2) so why does h need to be 1e-32? Usually people do 1e-8.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You are right, < 1e-10 it becomes insensitive. I was showing merely we can do this way less than finite difference.

const double theta_d = theta.val();
const double res = complex_step_derivative(f, theta_d, x_r, x_i, msgs);
const double g
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems as if you should be able to do this in one function call that yields a complex<double> that has both the real and imaginary parts.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure what you mean. We need both f(x) and f(x + ih), don't we?

= std::imag(f(complex<double>(theta_d, h), x_r, x_i, msgs)) / h;
return var(new stan::math::precomp_v_vari(res, theta.vi_, g));
}

} // namespace math
} // namespace stan

#endif
44 changes: 44 additions & 0 deletions test/unit/math/rev/scal/functor/complex_step_derivative_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
#include <gtest/gtest.h>
#include <stan/math.hpp>
#include <stan/math/rev/scal/functor/complex_step_derivative.hpp>
#include <test/unit/util.hpp>
#include <iostream>
#include <limits>
#include <sstream>
#include <vector>

struct ComplexStepDerivativeScalTest : public ::testing::Test {
struct Fexp {
template <typename T>
inline T operator()(const T &theta, const std::vector<double> &x_r,
const std::vector<int> &x_i, std::ostream *msgs) const {
return exp(theta) / sqrt(theta) - 0.5 * exp(theta) * pow(theta, -1.5);
}
};

void SetUp() {}
Fexp f;
const std::vector<double> x_r;
const std::vector<int> x_i;
std::ostream *msgs = nullptr;
};

TEST_F(ComplexStepDerivativeScalTest, func_exp_sqrt) {
using stan::math::complex_step_derivative;
using stan::math::var;

/* f near x = 0 has very large derivative */
var x = 0.01;
var y = complex_step_derivative(f, x, x_r, x_i, msgs);

ASSERT_FLOAT_EQ(y.val(), f(x.val(), x_r, x_i, msgs));

std::vector<stan::math::var> xv{x};
std::vector<double> g1, g;
var y1 = f(x, x_r, x_i, msgs);
stan::math::set_zero_all_adjoints();
y1.grad(xv, g1);
stan::math::set_zero_all_adjoints();
y.grad(xv, g);
ASSERT_FLOAT_EQ(g[0], g1[0]);
}