Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Allow distributions to return vectors #2751

Open
wants to merge 8 commits into
base: develop
Choose a base branch
from
38 changes: 38 additions & 0 deletions stan/math/fwd/functor/operands_and_partials.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class ops_partials_edge<Dx, fvar<Dx>> {
const Op& operand_;

Dx dx() { return this->partials_[0] * this->operand_.d_; }
Dx dx_v() { return this->partials_[0] * this->operand_.d_; }
};
} // namespace internal

Expand Down Expand Up @@ -106,6 +107,14 @@ class operands_and_partials<Op1, Op2, Op3, Op4, Op5, fvar<Dx>> {
= edge1_.dx() + edge2_.dx() + edge3_.dx() + edge4_.dx() + edge5_.dx();
return T_return_type(value, deriv);
}

template <typename EigVec, require_eigen_vector_t<EigVec>* = nullptr>
auto build(EigVec&& value) {
Eigen::Array<fvar<Dx>, -1, 1> ret(value.template cast<fvar<Dx>>());
ret.d() = (edge1_.dx_v() + edge2_.dx_v() + edge3_.dx_v() + edge4_.dx_v()
+ edge5_.dx_v());
return ret;
}
};

namespace internal {
Expand Down Expand Up @@ -135,6 +144,13 @@ class ops_partials_edge<Dx, std::vector<fvar<Dx>>> {
}
return derivative;
}
Eigen::Array<Dx, -1, 1> dx_v() {
Eigen::Array<Dx, -1, 1> derivative(this->operands_.size());
for (size_t i = 0; i < this->operands_.size(); ++i) {
derivative[i] = this->partials_[i] * this->operands_[i].d_;
}
return derivative;
}
};

template <typename Dx, int R, int C>
Expand All @@ -161,6 +177,10 @@ class ops_partials_edge<Dx, Eigen::Matrix<fvar<Dx>, R, C>> {
}
return derivative;
}

Eigen::Array<Dx, -1, 1> dx_v() {
return this->partials_.array() * this->operands_.d().array();
}
};

// Multivariate; vectors of eigen types
Expand Down Expand Up @@ -191,6 +211,14 @@ class ops_partials_edge<Dx, std::vector<Eigen::Matrix<fvar<Dx>, R, C>>> {
}
return derivative;
}

Eigen::Array<Dx, -1, 1> dx_v() {
Eigen::Array<Dx, -1, 1> derivative(this->operands_.size());
for (size_t i = 0; i < this->operands_.size(); ++i) {
derivative[i] = this->partials_vec_[i].dot(this->operands_[i].d());
}
return derivative;
}
};

template <typename Dx>
Expand Down Expand Up @@ -220,6 +248,16 @@ class ops_partials_edge<Dx, std::vector<std::vector<fvar<Dx>>>> {
}
return derivative;
}
Eigen::Array<Dx, -1, 1> dx_v() {
Eigen::Array<Dx, -1, 1> derivative
= Eigen::Array<Dx, -1, 1>::Zero(this->operands_.size());
for (size_t i = 0; i < this->operands_.size(); ++i) {
for (int j = 0; j < this->operands_[i].size(); ++j) {
derivative[i] = this->partials_vec_[i][j] * this->operands_[i][j].d_;
}
}
return derivative;
}
};

} // namespace internal
Expand Down
6 changes: 5 additions & 1 deletion stan/math/prim/functor/operands_and_partials.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ class ops_partials_edge<ViewElt, Op, require_st_arithmetic<Op>> {
* expression returning zero.
*/
static constexpr double dx() noexcept { return 0.0; }
static constexpr double dx_v() noexcept { return 0.0; }
/**
* Return the size of the operand for the edge. For doubles this is a compile
* time expression returning zero.
Expand Down Expand Up @@ -140,7 +141,10 @@ class operands_and_partials {
* @param value the return value of the function we are compressing
* @return the value with its derivative
*/
inline double build(double value) const noexcept { return value; }
template <typename T>
inline auto build(T&& value) const noexcept {
return std::forward<T>(value);
}

// These will always be 0 size base template instantiations (above).
internal::ops_partials_edge<double, std::decay_t<Op1>> edge1_;
Expand Down
226 changes: 226 additions & 0 deletions stan/math/prim/functor/prob_reducer.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,226 @@
#ifndef STAN_MATH_PRIM_FUNCTOR_PROB_REDUCER_HPP
#define STAN_MATH_PRIM_FUNCTOR_PROB_REDUCER_HPP

#include <stan/math/prim/meta.hpp>

namespace stan {
namespace math {

/**
* Used by distributions to decide whether return type shoudl be Scalar or
* Vector.
*/
enum class ProbReturnType { Scalar, Vector };

/**
* For scalars performs summations and is a no-op reducer for eigen vectors.
* Used in the probability distributions for scalar or vector return types.
*/
template <typename T, typename = void>
struct prob_reducer;

/**
* For scalars performs summations when given eigen types.
* @tparam A stan scalar type
*/
template <typename T>
struct prob_reducer<T, require_stan_scalar_t<T>> {
T ret_; // Underlying return type

/**
* Construct from an Eigen type while ignoring size argument passed.
* @tparam EigArr A type inheriting from `Eigen::EigenBase`
* @tparam Tossed An integral type
* @param x will be summed and passed into `ret_`.
*/
template <typename EigArr, typename Tossed,
require_eigen_t<EigArr>* = nullptr,
require_integral_t<Tossed>* = nullptr>
prob_reducer(EigArr&& x, Tossed&& /* */)
: ret_(sum(std::forward<EigArr>(x))) {}

/**
* Construct from an Eigen type while ignoring size argument passed.
* @tparam Scalar a scalar
* @tparam Tossed an integral type
* @param x will be summed and inserted into `ret_`.
*/
template <typename Scalar, typename Tossed,
require_stan_scalar_t<Scalar>* = nullptr>
prob_reducer(Scalar&& x, Tossed&& /* */) : ret_(x) {}

/**
* Perform summation and then assignment
* @tparam EigArr A type inheriting from `Eigen::EigenBase`
*/
template <typename EigArr, require_eigen_t<EigArr>* = nullptr>
inline auto operator=(EigArr&& x) {
ret_ = sum(x);
return *this;
}

/**
* Assignment
* @tparam Scalar A scalar type
*/
template <typename Scalar, require_stan_scalar_t<Scalar>* = nullptr>
inline auto operator=(Scalar x) {
ret_ = x;
return *this;
}

/**
* Perform summation and then `+=`
* @tparam EigArr A type inheriting from `Eigen::EigenBase`
* @param x Eigen object to be summed.
*/
template <typename EigArr, require_eigen_t<EigArr>* = nullptr>
inline auto operator+=(EigArr&& x) {
ret_ += sum(x);
return *this;
}

template <typename Scalar, require_stan_scalar_t<Scalar>* = nullptr>
inline auto operator+=(Scalar&& x) {
ret_ += x;
return *this;
}

/**
* Perform summation and then `-=`
* @tparam EigArr A type inheriting from `Eigen::EigenBase`
* @param x Eigen object to be summed.
*/
template <typename EigArr, require_eigen_t<EigArr>* = nullptr>
inline auto operator-=(EigArr&& x) {
ret_ -= sum(x);
return *this;
}

template <typename Scalar, require_stan_scalar_t<Scalar>* = nullptr>
inline auto operator-=(Scalar&& x) {
ret_ -= x;
return *this;
}

/**
* Return the underlying scalar return type.
*/
inline auto ret() noexcept { return ret_; }

/**
* Return a zero value, used when distribution has special cases that
* immedietly return zero.
* @tparam Types types to deduce the overall return type of the function.
*/
template <typename... Types>
static auto zero(int /* */) {
return return_type_t<Types...>(0);
}
};

template <typename T>
struct prob_reducer<T, require_eigen_t<T>> {
T ret_;

/**
* Construct from an Eigen type while ignoring size argument passed.
* @tparam EigArr A type inheriting from `Eigen::EigenBase`
* @tparam Size An integral type
* @param x will be forwarded to `ret_`.
*/
template <typename EigArr, typename Size, require_eigen_t<EigArr>* = nullptr,
require_integral_t<Size>* = nullptr>
prob_reducer(EigArr&& x, Size /* x */) : ret_(std::forward<EigArr>(x)) {}

/**
* Construct from a scalar type.
* @tparam Scalar a scalar
* @tparam Size An integral type
* @param x passed to `ret_` along with size to fill with a base value.
* @param n The size `ret_` should be
*/
template <typename Scalar, typename Size,
require_stan_scalar_t<Scalar>* = nullptr,
require_integral_t<Size>* = nullptr>
prob_reducer(Scalar constant, Size n) : ret_(T::Constant(n, constant)) {}

/**
* assignment
* @tparam EigArr A type inheriting from `Eigen::EigenBase`
*/
template <typename EigArr, require_eigen_t<EigArr>* = nullptr>
inline auto operator=(EigArr&& x) {
ret_ = std::forward<EigArr>(x);
return *this;
}

/**
* assignm scalar by propogating value over `ret_`
* @tparam Scalar a stan scalar
* @param x The value to fill `ret_` with.
*/
template <typename Scalar, require_stan_scalar_t<Scalar>* = nullptr>
inline auto operator=(Scalar x) {
ret_ = Eigen::Array<value_type_t<T>, -1, 1>::Constant(x, ret_.size());
return *this;
}

template <typename EigArr, require_eigen_t<EigArr>* = nullptr>
inline auto operator+=(EigArr&& x) {
ret_ += std::forward<EigArr>(x);
return *this;
}

template <typename Scalar, require_stan_scalar_t<Scalar>* = nullptr>
inline auto operator+=(Scalar&& x) {
ret_ += x;
return *this;
}

template <typename EigArr, require_eigen_t<EigArr>* = nullptr>
inline auto operator-=(EigArr&& x) {
ret_ -= std::forward<EigArr>(x);
return *this;
}

template <typename Scalar, require_stan_scalar_t<Scalar>* = nullptr>
inline auto operator-=(Scalar&& x) {
ret_ -= x;
return *this;
}

/**
* Return the underlying scalar return type. Passed the underlying by
* moving it which can cause `ret_` to be uninitialized after.
*/
inline auto&& ret() noexcept { return std::move(ret_); }

/**
* Return a zero value, used when distribution has special cases that
* immedietly return zero.
* @tparam Types types to deduce the overall return type of the function.
* @param size The size of the array to return
*/
template <typename... Types>
static auto zero(int size) {
return Eigen::Array<return_type_t<Types...>, -1, 1>::Constant(0, size)
.eval();
}
};

/**
* Generate a reducer with correct return type.
* @tparam ReturnType Either Scalar or Vector.
* @tparam Types A parameter pack of types to deduce the underlying scalar type
* from
*/
template <ProbReturnType ReturnType, typename... Types>
using prob_return_t = prob_reducer<std::conditional_t<
ReturnType == ProbReturnType::Scalar, return_type_t<Types...>,
Eigen::Array<return_type_t<Types...>, -1, 1>>>;

} // namespace math
} // namespace stan

#endif
12 changes: 4 additions & 8 deletions stan/math/prim/prob/normal_log.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,20 +29,16 @@ namespace math {
* @tparam T_loc Type of location parameter.
*/
template <bool propto, typename T_y, typename T_loc, typename T_scale>
inline return_type_t<T_y, T_loc, T_scale> normal_log(const T_y& y,
const T_loc& mu,
const T_scale& sigma) {
return normal_lpdf<propto, T_y, T_loc, T_scale>(y, mu, sigma);
inline auto normal_log(const T_y& y, const T_loc& mu, const T_scale& sigma) {
return normal_lpdf<propto>(y, mu, sigma);
}

/** \ingroup prob_dists
* @deprecated use <code>normal_lpdf</code>
*/
template <typename T_y, typename T_loc, typename T_scale>
inline return_type_t<T_y, T_loc, T_scale> normal_log(const T_y& y,
const T_loc& mu,
const T_scale& sigma) {
return normal_lpdf<T_y, T_loc, T_scale>(y, mu, sigma);
inline auto normal_log(const T_y& y, const T_loc& mu, const T_scale& sigma) {
return normal_lpdf<false>(y, mu, sigma);
}

} // namespace math
Expand Down
Loading