Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 83 additions & 33 deletions stan/math/prim/constraint/simplex_constrain.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@

#include <stan/math/prim/meta.hpp>
#include <stan/math/prim/fun/Eigen.hpp>
#include <stan/math/prim/fun/inv_logit.hpp>
#include <stan/math/prim/fun/constants.hpp>
#include <stan/math/prim/fun/log.hpp>
#include <stan/math/prim/fun/log1p_exp.hpp>
#include <stan/math/prim/fun/logit.hpp>
#include <stan/math/prim/fun/fmax.hpp>
#include <stan/math/prim/fun/exp.hpp>
#include <stan/math/prim/fun/inv_sqrt.hpp>
#include <cmath>

namespace stan {
Expand All @@ -18,7 +19,11 @@ namespace math {
* to 0 that sum to 1. A vector with (K-1) unconstrained values
* will produce a simplex of size K.
*
* The transform is based on a centered stick-breaking process.
* The simplex transform is defined using the inverse of the
* isometric log ratio (ILR) transform. This code is equivalent to
* `softmax(sum_to_zero_constrain(y))`, but is more efficient and
* stable if computed this way thanks to the use of the online
* softmax algorithm courtesy of https://arxiv.org/abs/1805.02867.
*
* @tparam Vec type of the vector
* @param y Free vector input of dimensionality K - 1.
Expand All @@ -27,29 +32,54 @@ namespace math {
// NOTE(review): this span is a rendered diff whose +/- markers were stripped,
// so statements from the previous stick-breaking body (x, Km1, stick_len, z_k)
// and from the new ILR / online-softmax body (z, sum_w, d, max_val) are
// interleaved. The comments below describe each statement locally; confirm
// against the repository which lines are actually live.
template <typename Vec, require_eigen_vector_t<Vec>* = nullptr,
require_not_st_var<Vec>* = nullptr>
inline plain_type_t<Vec> simplex_constrain(const Vec& y) {
// cut & paste simplex_constrain(Eigen::Matrix, T) w/o Jacobian
using std::log;
using T = value_type_t<Vec>;
const auto N = y.size();

// Output buffer with N + 1 entries, zero-filled because the loop below
// accumulates into it with += and -=.
plain_type_t<Vec> z = Eigen::VectorXd::Zero(N + 1);
// Empty input: the single output entry is set to 1 and returned directly.
if (unlikely(N == 0)) {
z.coeffRef(0) = 1;
return z;
}

auto&& y_ref = to_ref(y);
// Running sum of the scaled inputs w seen so far (i counts down from N).
T sum_w(0);

T d(0); // sum of exponentials
// Running maximum of the entries of z processed so far; it is assigned by
// fmax before its first read, so the initial 0 is never used.
T max_val(0);
T max_val_old(negative_infinity());

// Stick-breaking variant: x is the output, stick_len the remaining mass.
int Km1 = y.size();
plain_type_t<Vec> x(Km1 + 1);
T stick_len(1.0);
for (Eigen::Index k = 0; k < Km1; ++k) {
// Fraction of the remaining stick assigned to entry k.
T z_k = inv_logit(y.coeff(k) - log(Km1 - k));
x.coeffRef(k) = stick_len * z_k;
stick_len -= x.coeff(k);
// ILR variant: walk i from N down to 1.
for (int i = N; i > 0; --i) {
double n = static_cast<double>(i);
// Scale y[i - 1] by 1 / sqrt(i * (i + 1)).
auto w = y_ref(i - 1) * inv_sqrt(n * (n + 1));
sum_w += w;

// After these two updates z[i] is not written again by this loop, so it
// can be folded into the running softmax statistics immediately.
z.coeffRef(i - 1) += sum_w;
z.coeffRef(i) -= w * n;

// Online softmax update: rescale the accumulated sum d when the running
// maximum changes, then add the contribution of z[i].
max_val = fmax(max_val_old, z.coeff(i));
d = d * exp(max_val_old - max_val) + exp(z.coeff(i) - max_val);
max_val_old = max_val;
}
// Stick-breaking variant: the last entry receives whatever mass remains.
x.coeffRef(Km1) = stick_len;
return x;

// above loop doesn't reach i==0
max_val = fmax(max_val_old, z.coeff(0));
d = d * exp(max_val_old - max_val) + exp(z.coeff(0) - max_val);

// Normalise: exp(z - max) / sum(exp(z - max)).
z.array() = (z.array() - max_val).exp() / d;

return z;
}

/**
* Return the simplex corresponding to the specified free vector
* and increment the specified log probability reference with
* the log absolute Jacobian determinant of the transform.
*
* The simplex transform is defined through a centered
* stick-breaking process.
* The simplex transform is defined using the inverse of the
* isometric log ratio (ILR) transform. This code is equivalent to
* `softmax(sum_to_zero_constrain(y))`, but is more efficient and
* stable if computed this way thanks to the use of the online
* softmax algorithm courtesy of https://arxiv.org/abs/1805.02867.
*
* @tparam Vec type of the vector
* @tparam Lp A scalar type for the lp argument. The scalar type of Vec should
Expand All @@ -62,26 +92,46 @@ template <typename Vec, typename Lp, require_eigen_vector_t<Vec>* = nullptr,
require_not_st_var<Vec>* = nullptr,
require_convertible_t<value_type_t<Vec>, Lp>* = nullptr>
// NOTE(review): this span is a rendered diff whose +/- markers were stripped,
// so statements from the previous stick-breaking body (x, Km1, stick_len,
// adj_y_k) and from the new ILR / online-softmax body (z, sum_w, d, max_val)
// are interleaved. Comments describe each statement locally; confirm against
// the repository which lines are actually live.
inline plain_type_t<Vec> simplex_constrain(const Vec& y, Lp& lp) {
using Eigen::Dynamic;
using Eigen::Matrix;
using std::log;
using T = value_type_t<Vec>;
const auto N = y.size();

int Km1 = y.size(); // K = Km1 + 1
plain_type_t<Vec> x(Km1 + 1);
T stick_len(1.0);
for (Eigen::Index k = 0; k < Km1; ++k) {
double eq_share = -log(Km1 - k); // = logit(1.0/(Km1 + 1 - k));
T adj_y_k = y.coeff(k) + eq_share;
T z_k = inv_logit(adj_y_k);
x.coeffRef(k) = stick_len * z_k;
// Stick-breaking variant: per-element contributions added to lp.
lp += log(stick_len);
lp -= log1p_exp(-adj_y_k);
lp -= log1p_exp(adj_y_k);
stick_len -= x.coeff(k); // equivalently *= (1 - z_k);
// Output buffer with N + 1 entries, zero-filled because the loop below
// accumulates into it with += and -=.
plain_type_t<Vec> z = Eigen::VectorXd::Zero(N + 1);
// Empty input: the single output entry is set to 1; lp is left untouched.
if (unlikely(N == 0)) {
z.coeffRef(0) = 1;
return z;
}
x.coeffRef(Km1) = stick_len; // no Jacobian contrib for last dim
return x;

auto&& y_ref = to_ref(y);
// Running sum of the scaled inputs w seen so far (i counts down from N).
T sum_w(0);

T d(0); // sum of exponentials
// Running maximum of the entries of z processed so far; it is assigned by
// fmax before its first read, so the initial 0 is never used.
T max_val(0);
T max_val_old(negative_infinity());

for (int i = N; i > 0; --i) {
double n = static_cast<double>(i);
// Scale y[i - 1] by 1 / sqrt(i * (i + 1)).
auto w = y_ref(i - 1) * inv_sqrt(n * (n + 1));
sum_w += w;

// Each w is added i times in total (via sum_w) and subtracted i times
// (w * n), so the pre-softmax entries of z sum to zero.
z.coeffRef(i - 1) += sum_w;
z.coeffRef(i) -= w * n;

// Online softmax update: rescale the accumulated sum d when the running
// maximum changes, then add the contribution of z[i].
max_val = fmax(max_val_old, z.coeff(i));
d = d * exp(max_val_old - max_val) + exp(z.coeff(i) - max_val);
max_val_old = max_val;
}

// above loop doesn't reach i==0
max_val = fmax(max_val_old, z.coeff(0));
d = d * exp(max_val_old - max_val) + exp(z.coeff(0) - max_val);

// Normalise: exp(z - max) / sum(exp(z - max)).
z.array() = (z.array() - max_val).exp() / d;

// equivalent to z.log().sum() + 0.5 * log(N + 1)
// (each log-entry is z_raw - max_val - log(d) and the z_raw values sum to
// zero, leaving -(N + 1) * (max_val + log(d)) for the sum of logs)
lp += -(N + 1) * (max_val + log(d)) + 0.5 * log(N + 1);

return z;
}

/**
Expand Down
20 changes: 10 additions & 10 deletions stan/math/prim/constraint/simplex_free.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#include <stan/math/prim/err.hpp>
#include <stan/math/prim/fun/Eigen.hpp>
#include <stan/math/prim/fun/log.hpp>
#include <stan/math/prim/fun/logit.hpp>
#include <stan/math/prim/fun/sqrt.hpp>
#include <stan/math/prim/fun/to_ref.hpp>
#include <cmath>

Expand All @@ -17,8 +17,8 @@ namespace math {
* the specified simplex. It applies to a simplex of dimensionality
* K and produces an unconstrained vector of dimensionality (K-1).
*
* <p>The simplex transform is defined through a centered
* stick-breaking process.
* The simplex transform is defined using the isometric log ratio (ILR)
* transform.
*
* @tparam Vec type of the simplex (must be a column vector)
* @param x Simplex of dimensionality K.
Expand All @@ -28,20 +28,20 @@ namespace math {
*/
// NOTE(review): this span is a rendered diff whose +/- markers were stripped,
// so the previous backward stick-breaking loop (stick_len, z_k) and the new
// forward ILR loop (cumsum) are interleaved. Comments describe each statement
// locally; confirm against the repository which lines are actually live.
template <typename Vec, require_eigen_vector_t<Vec>* = nullptr>
inline plain_type_t<Vec> simplex_free(const Vec& x) {
using std::log;
using T = value_type_t<Vec>;

const auto& x_ref = to_ref(x);
// Argument check on the evaluated input; check_simplex is defined elsewhere
// (presumably it rejects inputs that are not a simplex -- verify there).
check_simplex("stan::math::simplex_free", "Simplex variable", x_ref);
// The unconstrained result has one entry fewer than the input.
Eigen::Index Km1 = x_ref.size() - 1;
plain_type_t<Vec> y(Km1);
// Stick-breaking variant: rebuild the remaining stick length from the back.
T stick_len = x_ref.coeff(Km1);
// k runs from Km1 - 1 down to 0 (the decrement happens inside the condition).
for (Eigen::Index k = Km1; --k >= 0;) {
stick_len += x_ref.coeff(k);
T z_k = x_ref.coeff(k) / stick_len;
y.coeffRef(k) = logit(z_k) + log(Km1 - k);
// note: log(Km1 - k) = logit(1.0 / (Km1 + 1 - k));

// ILR variant: cumsum holds log(x[0]) + ... + log(x[i]).
T cumsum = 0.0;
for (int i = 0; i < Km1; ++i) {
cumsum += log(x_ref.coeff(i));
double n = static_cast<double>(i + 1);
// y[i] = (sum of the first n logs - n * log(x[i + 1])) / sqrt(n * (n + 1)).
y.coeffRef(i) = (cumsum - n * log(x_ref.coeff(i + 1))) / sqrt(n * (n + 1));
}

return y;
}

Expand Down
11 changes: 4 additions & 7 deletions stan/math/prim/constraint/stochastic_column_constrain.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,6 @@

#include <stan/math/prim/meta.hpp>
#include <stan/math/prim/fun/Eigen.hpp>
#include <stan/math/prim/fun/inv_logit.hpp>
#include <stan/math/prim/fun/log.hpp>
#include <stan/math/prim/fun/log1p_exp.hpp>
#include <stan/math/prim/fun/logit.hpp>
#include <stan/math/prim/constraint/simplex_constrain.hpp>
#include <cmath>

Expand All @@ -16,7 +12,8 @@ namespace math {
/**
* Return a column stochastic matrix.
*
* The transform is based on a centered stick-breaking process.
* The transform is defined using the inverse of the
* isometric log ratio (ILR) transform.
*
* @tparam Mat type of the Matrix
* @param y Free Matrix input of dimensionality (K - 1, M)
Expand All @@ -39,8 +36,8 @@ inline plain_type_t<Mat> stochastic_column_constrain(const Mat& y) {
* and increment the specified log probability reference with
* the log absolute Jacobian determinant of the transform.
*
* The simplex transform is defined through a centered
* stick-breaking process.
* The simplex transform is defined using the inverse of the
* isometric log ratio (ILR) transform.
*
* @tparam Mat type of the Matrix
* @tparam Lp A scalar type for the lp argument. The scalar type of Mat should
Expand Down
43 changes: 12 additions & 31 deletions stan/math/prim/constraint/stochastic_row_constrain.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,6 @@

#include <stan/math/prim/meta.hpp>
#include <stan/math/prim/fun/Eigen.hpp>
#include <stan/math/prim/fun/inv_logit.hpp>
#include <stan/math/prim/fun/log.hpp>
#include <stan/math/prim/fun/log1p_exp.hpp>
#include <stan/math/prim/fun/logit.hpp>
#include <stan/math/prim/constraint/simplex_constrain.hpp>
#include <cmath>

Expand All @@ -16,7 +12,8 @@ namespace math {
/**
* Return a row stochastic matrix.
*
* The transform is based on a centered stick-breaking process.
* The transform is defined using the inverse of the
* isometric log ratio (ILR) transform.
*
* @tparam Mat type of the Matrix
* @param y Free Matrix input of dimensionality (N, K - 1).
Expand All @@ -27,23 +24,17 @@ template <typename Mat, require_eigen_matrix_dynamic_t<Mat>* = nullptr,
// NOTE(review): this span is a rendered diff whose +/- markers were stripped,
// so the previous column-wise stick-breaking body (x, Km1, stick_len) and the
// new row-wise body (ret) are interleaved. Comments describe each statement
// locally; confirm against the repository which lines are actually live.
inline plain_type_t<Mat> stochastic_row_constrain(const Mat& y) {
auto&& y_ref = to_ref(y);
const Eigen::Index N = y_ref.rows();
// Stick-breaking variant: one stick length per row, all starting at 1.
int Km1 = y_ref.cols();
plain_type_t<Mat> x(N, Km1 + 1);
using eigen_arr = Eigen::Array<scalar_type_t<Mat>, -1, 1>;
eigen_arr stick_len = eigen_arr::Constant(N, 1.0);
// Processes one column at a time across all rows.
for (Eigen::Index k = 0; k < Km1; ++k) {
auto z_k = inv_logit(y_ref.array().col(k) - log(Km1 - k));
x.array().col(k) = stick_len * z_k;
stick_len -= x.array().col(k);
// Row-wise variant: the result has one more column than the input, and each
// row is produced by simplex_constrain applied to the matching input row.
plain_type_t<Mat> ret(N, y_ref.cols() + 1);
for (Eigen::Index i = 0; i < N; ++i) {
ret.row(i) = simplex_constrain(y_ref.row(i));
}
// Stick-breaking variant: the last column receives the remaining mass.
x.array().col(Km1) = stick_len;
return x;
return ret;
}

/**
* Return a row stochastic matrix.
* The simplex transform is defined through a centered
* stick-breaking process.
* The simplex transform is defined using the inverse of the
* isometric log ratio (ILR) transform.
*
* @tparam Mat type of the matrix
* @tparam Lp A scalar type for the lp argument. The scalar type of Mat should
Expand All @@ -59,21 +50,11 @@ template <typename Mat, typename Lp,
// NOTE(review): this span is a rendered diff whose +/- markers were stripped,
// so the previous column-wise stick-breaking body (x, Km1, stick_len) and the
// new row-wise body (ret) are interleaved. Comments describe each statement
// locally; confirm against the repository which lines are actually live.
inline plain_type_t<Mat> stochastic_row_constrain(const Mat& y, Lp& lp) {
auto&& y_ref = to_ref(y);
const Eigen::Index N = y_ref.rows();
// Stick-breaking variant: one stick length per row, all starting at 1.
Eigen::Index Km1 = y_ref.cols();
plain_type_t<Mat> x(N, Km1 + 1);
Eigen::Array<scalar_type_t<Mat>, -1, 1> stick_len
= Eigen::Array<scalar_type_t<Mat>, -1, 1>::Constant(N, 1.0);
for (Eigen::Index k = 0; k < Km1; ++k) {
const auto eq_share = -log(Km1 - k); // = logit(1.0/(Km1 + 1 - k));
// .eval() materialises the shifted column, which is then used three times.
auto adj_y_k = (y_ref.array().col(k) + eq_share).eval();
auto z_k = inv_logit(adj_y_k);
x.array().col(k) = stick_len * z_k;
// Column-wise contribution to lp, summed over all rows.
lp += -sum(log1p_exp(adj_y_k)) - sum(log1p_exp(-adj_y_k))
+ sum(log(stick_len));
stick_len -= x.array().col(k); // equivalently *= (1 - z_k);
// Row-wise variant: each row is transformed by simplex_constrain, which is
// handed the same lp reference on every iteration.
plain_type_t<Mat> ret(N, y_ref.cols() + 1);
for (Eigen::Index i = 0; i < N; ++i) {
ret.row(i) = simplex_constrain(y_ref.row(i), lp);
}
// Stick-breaking variant: the last column receives the remaining mass.
x.col(Km1).array() = stick_len;
return x;
return ret;
}

/**
Expand Down
Loading