Add inverse ILR simplex transform #3170
Comments
I think this would be the first time we ever change a constraint parametrization, so there are some details to work out. I would propose:
My reasoning: How does this sound, @spinkney @bob-carpenter?
I'm curious where you think people will want the current simplex transform. @spinkney, do you know of anywhere it outperforms the ILR version? If it's just backwards compatibility of sampling, that's not something we guarantee release over release. We will be changing the meaning of
I don't know of anywhere the stick-breaking transform outperforms, though it might be more competitive after this change that @WardBrian adds. Regardless, my push to have multiple constraint transform options for the same constraint is that there are cases where one transform is better than another for different models. I'd like users to have the freedom to swap them out. I'm thinking we can tell people that our dev priority is the default transforms, which have been shown to work well in many cases. Transforms that work well only in specific cases will be low priority to include. Transforms shown to outperform the default in most cases will be considered as replacements for the current default.
My reasoning for having the existing transform still around was motivated by
It's also possible there are users relying on specific implementation details who would want to keep updating while not changing too much else of their pipeline, though I don't imagine that's a large group.
On the original topic: I assume we would also want to move the existing row/column stochastic types to the ILR technique?
yep!
There were a couple of issues with the online softmax in the code in the OP. I believe this fixes them:

```cpp
template <typename Vec, typename Lp, require_eigen_vector_t<Vec>* = nullptr,
          require_not_st_var<Vec>* = nullptr,
          require_convertible_t<value_type_t<Vec>, Lp>* = nullptr>
inline plain_type_t<Vec> simplex_constrain(const Vec& y, Lp& lp) {
  using std::log;
  using T = value_type_t<Vec>;
  const auto N = y.size();
  plain_type_t<Vec> z = Eigen::VectorXd::Zero(N + 1);
  if (unlikely(N == 0)) {
    // FIX: size zero input should yield a vector of {1}, not {0}
    z.coeffRef(0) = 1;
    return z;
  }
  auto&& y_ref = to_ref(y);
  T sum_w(0);
  T d(0);  // sum of exponentials
  T max_val(0);
  // FIX: initialize at -inf
  T max_val_old(negative_infinity());
  // this is equivalent to softmax(sum_to_zero_constrain(y))
  // but is more efficient and stable if computed this way
  for (int i = N; i > 0; --i) {
    double n = static_cast<double>(i);
    auto w = y_ref(i - 1) * inv_sqrt(n * (n + 1));
    sum_w += w;
    z.coeffRef(i - 1) += sum_w;
    z.coeffRef(i) -= w * n;
    max_val = max(max_val_old, z.coeff(i));
    d = d * exp(max_val_old - max_val) + exp(z.coeff(i) - max_val);
    max_val_old = max_val;
  }
  // NEW: above loop doesn't reach i == 0
  max_val = max(max_val_old, z.coeff(0));
  d = d * exp(max_val_old - max_val) + exp(z.coeff(0) - max_val);
  // FIX: off-by-one error in original loop end
  for (int i = 0; i <= N; ++i) {
    z.coeffRef(i) = exp(z.coeff(i) - max_val) / d;
  }
  lp += -(N + 1) * (max_val + log(d)) + 0.5 * log(N + 1);
  return z;
}
```
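As a sanity check, the same fused loop can be sketched outside of Stan math with plain `std::vector` (the function name here is mine, and the real implementation above uses Eigen and Stan's type traits; this is just a minimal standalone version for checking the simplex properties):

```cpp
#include <cassert>
#include <cmath>
#include <limits>
#include <vector>

// Standalone sketch of the fused ILR simplex constrain: one pass
// builds the sum-to-zero vector from y while folding in the online
// softmax normalizer; a second pass normalizes. `lp` accumulates the
// log-Jacobian term.
std::vector<double> ilr_simplex_constrain(const std::vector<double>& y,
                                          double& lp) {
  const int N = static_cast<int>(y.size());
  std::vector<double> z(N + 1, 0.0);
  if (N == 0) {
    z[0] = 1.0;  // zero-size input maps to the singleton simplex {1}
    return z;
  }
  double sum_w = 0.0;
  double d = 0.0;  // running sum of exponentials
  double max_val = 0.0;
  double max_val_old = -std::numeric_limits<double>::infinity();
  for (int i = N; i > 0; --i) {
    double n = static_cast<double>(i);
    double w = y[i - 1] / std::sqrt(n * (n + 1));
    sum_w += w;
    z[i - 1] += sum_w;
    z[i] -= w * n;
    max_val = std::max(max_val_old, z[i]);
    d = d * std::exp(max_val_old - max_val) + std::exp(z[i] - max_val);
    max_val_old = max_val;
  }
  // the loop above stops before reaching i == 0
  max_val = std::max(max_val_old, z[0]);
  d = d * std::exp(max_val_old - max_val) + std::exp(z[0] - max_val);
  for (int i = 0; i <= N; ++i) {
    z[i] = std::exp(z[i] - max_val) / d;
  }
  lp += -(N + 1) * (max_val + std::log(d)) + 0.5 * std::log(N + 1.0);
  return z;
}
```

With this version it is easy to assert that the output is strictly positive and sums to 1 for arbitrary unconstrained input, and that `y = {0}` maps to the uniform simplex `{0.5, 0.5}`.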
Awesome! That's what I get for not testing it |
Add inverse ILR simplex.
@WardBrian we can do this in 2 loops using the online softmax (https://arxiv.org/abs/1805.02867). The first loop constructs the sum-to-zero vector, tracks the running maximum of that vector, and accumulates the sum of exponentials with the max subtracted. The second loop takes the safe exponential of the sum-to-zero vector with the max subtracted and divides by the sum of exponentials from the first loop. The Jacobian adjustment is applied afterwards.
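The normalizer trick from that paper can be sketched in isolation. This is a minimal standalone version (not Stan code; the name is mine): a single pass keeps a running maximum `m` and a running sum `d` of shifted exponentials, rescaling `d` by `exp(m_old - m_new)` whenever the maximum grows, so no intermediate `exp` can overflow:

```cpp
#include <cassert>
#include <cmath>
#include <limits>
#include <vector>

// Online normalizer calculation for softmax (arXiv:1805.02867):
// compute the max and the sum of exponentials in a single pass.
std::vector<double> online_softmax(const std::vector<double>& x) {
  double m = -std::numeric_limits<double>::infinity();
  double d = 0.0;
  for (double xi : x) {
    double m_new = std::max(m, xi);
    // rescale the running sum to the new maximum before adding
    d = d * std::exp(m - m_new) + std::exp(xi - m_new);
    m = m_new;
  }
  std::vector<double> out(x.size());
  for (std::size_t i = 0; i < x.size(); ++i) {
    out[i] = std::exp(x[i] - m) / d;
  }
  return out;
}
```

On the first iteration `m` is `-inf`, so `d * exp(m - m_new)` is `0 * 0 = 0` and the recursion starts cleanly; inputs like `{1000, 1000}` then normalize without overflow.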