Add inverse ILR simplex transform #3170

Open
spinkney opened this issue Mar 31, 2025 · 8 comments · May be fixed by #3171

@spinkney
Collaborator

Add inverse ILR simplex.

@WardBrian we can do this in two loops using the online softmax trick (https://arxiv.org/abs/1805.02867). The first loop constructs the sum-to-zero vector, tracks the running maximum of that vector, and accumulates the sum of exponentials with the maximum subtracted. The second loop takes the safe exponential of the sum-to-zero vector (maximum subtracted) and divides by the sum of exponentials from the first loop. The Jacobian adjustment is applied afterward.
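
For reference, the paper's single-pass recurrence (restated here; the notation below is mine, not from the original comment) updates the running maximum $m_i$ and the rescaled normalizer $d_i$ together:

$$m_i = \max(m_{i-1}, x_i), \qquad d_i = d_{i-1}\, e^{m_{i-1} - m_i} + e^{x_i - m_i}, \qquad m_0 = -\infty,\; d_0 = 0,$$

so that after all $K$ elements, $\operatorname{softmax}(x)_j = e^{x_j - m_K} / d_K$.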

template <typename Vec, typename Lp, require_eigen_vector_t<Vec>* = nullptr,
          require_not_st_var<Vec>* = nullptr,
          require_convertible_t<value_type_t<Vec>, Lp>* = nullptr>
inline plain_type_t<Vec> simplex_ilr_constrain(const Vec& y, Lp& lp) {
  const auto N = y.size();

  plain_type_t<Vec> z = Eigen::VectorXd::Zero(N + 1);
  if (unlikely(N == 0)) {
    return z;
  }

  auto&& y_ref = to_ref(y);
  value_type_t<Vec> sum_w(0);

  // new: online softmax state
  double d = 0;  // running sum of exponentials
  double max_val = 0;
  double max_val_old = 0;

  for (int i = N; i > 0; --i) {
    double n = static_cast<double>(i);
    auto w = y_ref(i - 1) * inv_sqrt(n * (n + 1));
    sum_w += w;

    z.coeffRef(i - 1) += sum_w;
    z.coeffRef(i) -= w * n;
    
    // new
    max_val = max(max_val_old, z.coeff(i));
    d = d * exp(max_val_old - max_val) + exp(z.coeff(i) - max_val);
    max_val_old = max_val;
  }

  // new loop
  for (int i = 0; i < N; ++i) {
    z.coeffRef(i) = exp(z.coeff(i) - max_val) / d;
  }

  lp += -N * log(d) + 0.5 * log(N);

  return z;
}
@WardBrian
Member

WardBrian commented Apr 1, 2025

I think this would be the first time we ever change a constraint parametrization, so there are some details to work out. I would propose:

  1. Renaming the existing simplex_[constrain, free] to something like simplex_stickbreaking_...
  2. Adding the ILR version under the name simplex_ilr_...
  3. Defining new simplex_[constrain, free] functions that just call the simplex_ilr_... counterparts

My reasoning:
Doing (1) lets us keep the existing simplex transform exposed as functions for the user.
Doing (3) is the step that actually changes the meaning of simplex[N] in the language.
At first, (2) might seem odd, but it lets the user be explicit when they decide to call our functions manually. They can call simplex_stickbreaking_jacobian, simplex_ilr_jacobian, or simplex_jacobian, with the idea being that simplex_jacobian is always the same as whatever simplex[N] does. So if someone wants to keep using the ILR even if we change the default again in the future, they can use ..._ilr_..., but if they want the default they can just use simplex_jacobian.
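
A minimal sketch of the layering this implies (names from the proposal above; declarations abbreviated, not actual implementations):

// (1) the existing transform stays available under an explicit name
template <typename Vec, typename Lp>
plain_type_t<Vec> simplex_stickbreaking_constrain(const Vec& y, Lp& lp);

// (2) the ILR transform gets its own explicit name
template <typename Vec, typename Lp>
plain_type_t<Vec> simplex_ilr_constrain(const Vec& y, Lp& lp);

// (3) the unqualified name always matches whatever simplex[N] does
template <typename Vec, typename Lp>
plain_type_t<Vec> simplex_constrain(const Vec& y, Lp& lp) {
  return simplex_ilr_constrain(y, lp);
}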

How does this sound @spinkney @bob-carpenter?

@bob-carpenter
Member

I'm curious where you think people will want the current simplex transform. @spinkney---do you know of anywhere it outperforms the ILR version? If it's just backwards compatibility of sampling, that's not something we guarantee release over release. We will be changing the meaning of simplex[K] declarations in the language anyway.

@spinkney
Collaborator Author

spinkney commented Apr 1, 2025

I'm curious where you think people will want the current simplex transform. @spinkney---do you know of anywhere it outperforms the ILR version? If it's just backwards compatibility of sampling, that's not something we guarantee release over release. We will be changing the meaning of simplex[K] declarations in the language anyway.

I don't know of anywhere the stick-breaking outperforms, though it might be more competitive after this change that @WardBrian is adding. Regardless, my push to have multiple constraint-transform options for the same constrained type is that there are cases where one transform is better than another for different models. I'd like users to have the freedom to swap them out.

I'm thinking we can tell people that our dev priority is the default transforms, which have been shown to work well in many cases. Transforms that work well only in specific cases will be low priority to include. Transforms shown to outperform the default in most cases will be considered as replacements for the current default.

@WardBrian
Member

My reasoning for keeping the existing transform around:

  1. It’s not particularly hard to keep around; this part of the code is relatively slow moving.
  2. It sets a precedent for cases where there may not be as clear an advantage, but we still want to give users an option besides the current default.

It’s also possible there are users relying on specific implementation details who would want to keep updating while not changing too much else of their pipeline, though I don’t imagine that’s a huge group.

@WardBrian
Member

On the original topic: I assume we would also want to move the existing row/column stochastic types to the ILR technique?
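
If so, the change could be largely mechanical. A hypothetical sketch (not the actual stan::math implementation) that applies the ILR simplex constrain column by column:

template <typename Lp>
inline Eigen::MatrixXd stochastic_column_constrain_ilr(const Eigen::MatrixXd& y,
                                                       Lp& lp) {
  // each free column of length N maps to a simplex column of length N + 1
  Eigen::MatrixXd x(y.rows() + 1, y.cols());
  for (Eigen::Index j = 0; j < y.cols(); ++j) {
    // simplex_constrain here is the ILR version proposed in this issue
    x.col(j) = simplex_constrain(y.col(j), lp);
  }
  return x;
}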

@spinkney
Collaborator Author

spinkney commented Apr 2, 2025

On the original topic: I assume we would also want to move the existing row/column stochastic types to the ILR technique?

yep!

@WardBrian
Member

WardBrian commented Apr 2, 2025

There were a couple issues with the online softmax in the code in the OP. I believe this fixes them:

template <typename Vec, typename Lp, require_eigen_vector_t<Vec>* = nullptr,
          require_not_st_var<Vec>* = nullptr,
          require_convertible_t<value_type_t<Vec>, Lp>* = nullptr>
inline plain_type_t<Vec> simplex_constrain(const Vec& y, Lp& lp) {
  using std::log;
  using T = value_type_t<Vec>;
  const auto N = y.size();

  plain_type_t<Vec> z = Eigen::VectorXd::Zero(N + 1);
  if (unlikely(N == 0)) {
    // FIX: size zero input should yield a vector of {1}, not {0}
    z.coeffRef(0) = 1;
    return z;
  }

  auto&& y_ref = to_ref(y);
  T sum_w(0);

  T d(0);  // sum of exponentials
  T max_val(0);
  // FIX: initialize at -inf
  T max_val_old(negative_infinity());

  // this is equivalent to softmax(sum_to_zero_constrain(y))
  // but is more efficient and stable if computed this way

  for (int i = N; i > 0; --i) {
    double n = static_cast<double>(i);
    auto w = y_ref(i - 1) * inv_sqrt(n * (n + 1));
    sum_w += w;

    z.coeffRef(i - 1) += sum_w;
    z.coeffRef(i) -= w * n;

    max_val = max(max_val_old, z.coeff(i));
    d = d * exp(max_val_old - max_val) + exp(z.coeff(i) - max_val);
    max_val_old = max_val;
  }

  // NEW:
  // above loop doesn't reach i==0
  max_val = max(max_val_old, z.coeff(0));
  d = d * exp(max_val_old - max_val) + exp(z.coeff(0) - max_val);

  // FIX: off by one error in original loop end
  for (int i = 0; i <= N; ++i) {
    z.coeffRef(i) = exp(z.coeff(i) - max_val) / d;
  }

  lp += -(N + 1) * (max_val + log(d)) + 0.5 * log(N + 1);

  return z;
}
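
As a check on the lp line (a derivation sketch, not from the original comment): writing $s = \texttt{sum\_to\_zero\_constrain}(y)$ and $z = \operatorname{softmax}(s)$, the log absolute Jacobian determinant of the composite map is

$$\log|J| = \tfrac{1}{2}\log(N+1) + \sum_{i=1}^{N+1} \log z_i,$$

and because the $s_i$ sum to zero,

$$\sum_{i=1}^{N+1} \log z_i = \sum_{i=1}^{N+1} \left(s_i - m - \log d\right) = -(N+1)\,(m + \log d),$$

with $m$ the final running max and $d$ the normalizer, which is exactly the -(N + 1) * (max_val + log(d)) term above.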

@spinkney
Collaborator Author

spinkney commented Apr 2, 2025

There were a couple issues with the online softmax in the code in the OP. I believe this fixes them:

Awesome! That's what I get for not testing it

@WardBrian WardBrian linked a pull request Apr 2, 2025 that will close this issue