InferenceData coords for NumPyro plates #2022

Closed
ColdTeapot273K opened this issue Apr 28, 2022 · 3 comments · Fixed by #2441

ColdTeapot273K commented Apr 28, 2022

Tell us about it

When creating InferenceData using

az.from_numpyro(...)

the resulting autogenerated coordinates are not very telling:
[image]

Now consider the NumPyro model below that produced these samples.
These coords with autogenerated names are in fact plate dimensions. And plates have names.

import numpy as np
import pandas as pd

# INFO: PPL specific imports

import jax.numpy as jnp
from jax import random

import numpyro
import numpyro.distributions as dist
import numpyro.optim as optim
from numpyro.infer import SVI, Trace_ELBO, Predictive
from numpyro.infer import MCMC, NUTS
from numpyro.infer.autoguide import AutoLaplaceApproximation, AutoNormal

from jax import lax, random
from jax.scipy.special import expit

import arviz as az

# %%

data_uri = "https://raw.githubusercontent.com/rmcelreath/rethinking/master/data/NWOGrants.csv"
df_dev = pd.read_csv(data_uri, sep=";")
df_dev.head()

df_dev["gender"] = df_dev["gender"] == "m"
df_dev["gender"] = df_dev["gender"].astype(int)
df_dev["discipline"] = df_dev["discipline"].astype("category").cat.codes

# %%


def model(data: pd.DataFrame, observed=True):
    applications = data["applications"].values
    awards = data["awards"].values

    discipline = data["discipline"].values
    discipline_card = np.unique(discipline).shape[0]
    gender = data["gender"].values
    gender_card = np.unique(gender).shape[0]

    observations_card = data.shape[0]

    # INFO: good plate version
    with numpyro.plate("plate_gender", gender_card):
        with numpyro.plate("plate_discipline", discipline_card):
            alpha_gender_discipline = numpyro.sample("alpha_gender_discipline", dist.Normal(-1, 1))

    assert alpha_gender_discipline.shape == (9, 2)

    link_p = numpyro.deterministic("link_p", alpha_gender_discipline[discipline, gender])

    with numpyro.plate("plate_observations", observations_card):
        numpyro.sample(
            "awards", dist.Binomial(total_count=applications, logits=link_p), obs=awards if observed else None
        )


kernel = NUTS(model)
mcmc = MCMC(
    kernel,
    num_warmup=1000,
    num_samples=5000,
    num_chains=1,
    progress_bar=True,
)
mcmc.run(random.PRNGKey(0), df_dev)
samples = mcmc.get_samples()


az.from_numpyro(mcmc)

Thoughts on implementation

It would be handy if from_numpyro could extract those plate names from the NumPyro model.

I'll note that providing custom coords like so az.from_numpyro(mcmc, coords={"gender": np.array([0, 1])}) produces no effect (coords don't get renamed or anything). For now I stick to .rename({"alpha_gender_dim_0": "gender"}).
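The rename workaround above can be generated rather than typed by hand. This is a minimal sketch, assuming ArviZ names unlabeled axes with the pattern `<var>_dim_<i>` (as the autogenerated name in the workaround suggests). The helpers `default_dim_names` and `build_rename_map` are hypothetical names introduced for illustration, not part of ArviZ.

```python
def default_dim_names(var_name, ndim):
    """Names ArviZ is assumed to generate for unnamed axes: <var>_dim_<i>."""
    return [f"{var_name}_dim_{i}" for i in range(ndim)]


def build_rename_map(var_name, new_dims):
    """Map each autogenerated dim name to a meaningful one, axis by axis."""
    return dict(zip(default_dim_names(var_name, len(new_dims)), new_dims))


# In the model above, alpha_gender_discipline has shape (9, 2):
# axis 0 is discipline and axis 1 is gender.
rename_map = build_rename_map("alpha_gender_discipline", ["discipline", "gender"])
print(rename_map)

# Intended use (not run here): az.from_numpyro(mcmc).rename(rename_map)
```

The order of the new names has to follow the array's axis order, which for nested plates is the reverse of the nesting order in this model.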

@OriolAbril
Member

That would be nice. Do you know how the dimensions are stored in the numpyro model or mcmc object?

If it is possible to get a dictionary with variable names as keys and a list of the dimensions/plates as values, it would be straightforward to implement. The PyMC converter does that already.

> I'll note that providing custom coords like so az.from_numpyro(mcmc, coords={"gender": np.array([0, 1])}) produces no effect (coords don't get renamed or anything). For now I stick to .rename({"alpha_gender_dim_0": "gender"}).

This sounds like a confusion between dimensions and coordinates. It uses cmdstanpy, but maybe this blogpost I wrote can help clarify the difference.
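To make the distinction concrete: dims name the axes of a variable, while coords label the positions along an already named axis. Passing only coords for a dimension that no variable uses therefore changes nothing. The sketch below assumes the `dims` and `coords` keyword arguments of `az.from_numpyro`; the helper names `coords_match_shape` and `to_idata` are hypothetical, and the dimension names are chosen for illustration.

```python
# dims: variable name -> one dimension name per array axis (after chain and draw)
dims = {"alpha_gender_discipline": ["discipline", "gender"]}

# coords: dimension name -> labels along that dimension
coords = {"discipline": list(range(9)), "gender": [0, 1]}


def coords_match_shape(var_name, shape):
    """Check that each named axis has exactly one coordinate label per entry."""
    names = dims[var_name]
    return len(names) == len(shape) and all(
        len(coords[name]) == size for name, size in zip(names, shape)
    )


def to_idata(mcmc):
    """Convert a fitted MCMC object, naming the axes as well as labeling them."""
    import arviz as az  # imported lazily so the check above runs without ArviZ

    return az.from_numpyro(mcmc, dims=dims, coords=coords)
```

With the model from the issue, `coords_match_shape("alpha_gender_discipline", (9, 2))` holds, so `to_idata(mcmc)` should yield `discipline` and `gender` in place of the autogenerated names.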

@kylejcaron
Contributor

Leaving some notes here. I might be interested in picking this up later if I have time; I implemented something similar in my own code recently and figured I'd share it. I'll hopefully have time to revisit this.

In NumPyro, dims are stored in the model (as opposed to the mcmc object) and tend to get captured by plates, which typically represent independent draws across dimensions.

def model(X, y=None):
    ...
    with numpyro.plate("categories", n_categories):
        alpha = numpyro.sample("alpha", dist.Normal(0, 5))
    ...

They can be pulled out of the model as follows:

plates = numpyro.infer.inspect.get_model_relations(
    model,
    model_args=model_args, 
    model_kwargs=model_kwargs
)['plate_sample']

# I haven't really checked this; it's probably wrong when there are multiple dims, but it's a start
dims = {
    value:[key]
    for key,lst in plates.items()
    for value in lst
}
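For sites that sit under several plates, the dict comprehension above keeps only the last plate it sees, because each site key is overwritten. A minimal sketch of an accumulating version follows. The `plate_sample` dict is a hand-written stand-in for what `get_model_relations(...)["plate_sample"]` is assumed to return for the model in the issue, not actual output, and `invert_plates` is a hypothetical helper name.

```python
from collections import defaultdict

# Hand-written stand-in (an assumption, not real output): plate name -> sites inside it.
plate_sample = {
    "plate_gender": ["alpha_gender_discipline"],
    "plate_discipline": ["alpha_gender_discipline"],
    "plate_observations": ["awards"],
}


def invert_plates(plate_sample):
    """Turn plate -> sites into site -> plates, keeping every plate of a site."""
    site_dims = defaultdict(list)
    for plate_name, sites in plate_sample.items():
        for site in sites:
            site_dims[site].append(plate_name)
    return dict(site_dims)


dims = invert_plates(plate_sample)
print(dims)
```

This still does not settle the axis order: the plates are listed in the order the dict yields them, whereas the model's assert shows the array is shaped (9, 2), i.e. discipline first. The lists would need reordering before being used as dims.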

Note that this also requires model_args and model_kwargs, so the model's inputs have to be supplied for the inspect tool to work. from_numpyro would then need to accept both the model and the model inputs as optional arguments, which makes for a complicated pattern that I wouldn't recommend:

idata = az.from_numpyro(mcmc, model, *model_args, **model_kwargs)

I'm guessing there's probably a way to get the plate dims without the model args and kwargs, but I haven't figured out how yet. That would simplify the example above to a much more reasonable state:

idata = az.from_numpyro(mcmc, model)

However, there can also be dependent draws across dimensions, and these aren't represented with plates. For instance, a categorical variable modeled as a ZeroSumNormal with one entry for each of the 50 US states isn't represented with a plate in NumPyro, since the zero-sum constraint makes the entries dependent.

def model(X, y=None):
    ...
    b_state = numpyro.sample("b_state", dist.ZeroSumNormal(scale=1, event_shape=(50,)))
    ...

This wouldn't get captured by the example above. Not sure if there is a solution for that.
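One possible workaround, offered only as an assumption and not as an existing ArviZ or NumPyro feature, is to let users name such event dims by hand and merge them with whatever was inferred from plates. In this sketch `merge_dims`, the dim name `state`, and both input dicts are hypothetical.

```python
def merge_dims(inferred_dims, user_dims=None):
    """Combine plate-derived dims with user-supplied ones; user entries win."""
    merged = {site: list(names) for site, names in inferred_dims.items()}
    for site, names in (user_dims or {}).items():
        merged[site] = list(names)
    return merged


# Dims recovered from plates (hypothetical result for the plate example above).
inferred = {"alpha": ["categories"]}
# Dims supplied by hand for a site whose axis is an event dim, not a plate.
manual = {"b_state": ["state"]}

all_dims = merge_dims(inferred, manual)
print(all_dims)
```

Giving user-supplied entries precedence keeps the inference a convenience rather than a constraint.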

@kylejcaron
Contributor

kylejcaron commented Mar 30, 2025

Hi @OriolAbril, I opened a PR in numpyro that would help make this possible (and if it gets accepted, I can make a PR here). I have an edge case I wanted to ask about.

Here's an example model:

def model(X=None, y=None):
    alpha = numpyro.sample("alpha", dist.Normal(0, 2.5))
    sigma = numpyro.sample("sigma", dist.HalfNormal(2.5))
    if X is not None:
        with numpyro.plate("features", X.shape[-1]):
            beta = numpyro.sample("beta", dist.Normal(0,1))
    
    bX = 0
    with numpyro.plate("obs_idx", len(X) if X is not None else len(y)):
        if X is not None:
            bX += jnp.dot(X, beta)
        
        # shape = len(X) if X is not none else 1
        mu = numpyro.deterministic("mu", alpha + bX)
        return numpyro.sample("y", dist.Normal(mu, sigma), obs=y)

Here's the inspect output when X is None:

inspect.get_site_dims(model, X=None, y=y)
# {'mu': {'batch_dims': ['obs_idx'], 'event_dims': []},
#   'y': {'batch_dims': ['obs_idx'], 'event_dims': []}}

For this model, the problem is that obs_idx is a named dim for mu, but when X is None the batch shape of mu is None instead of len(coords['obs_idx']), which would fail in az.from_numpyro.

Some questions:

  1. Would it make sense for NumpyroConverter to broadcast automatically under the hood where applicable (i.e. mu would be broadcast from (chain, draw, 1) to (chain, draw, obs_idx))?
  2. If dims aren't provided to NumpyroConverter, should they be grabbed automatically from inspect.get_site_dims, or should users be required to provide dims themselves?
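The broadcasting in question 1 could look roughly like the following. This is a sketch using NumPy's `np.broadcast_to`, not code from NumpyroConverter; `broadcast_to_named_dims` is a hypothetical helper name, and the sample shapes are made up for illustration.

```python
import numpy as np


def broadcast_to_named_dims(samples, dim_lengths):
    """Broadcast samples shaped (chain, draw, ...) to (chain, draw, *dim_lengths)."""
    samples = np.asarray(samples)
    target = (*samples.shape[:2], *dim_lengths)
    missing = len(target) - samples.ndim
    if missing < 0:
        raise ValueError("samples have more axes than the named dims allow")
    # Pad missing trailing axes with length 1 so a (chain, draw) array also works.
    padded = samples.reshape(samples.shape + (1,) * missing)
    # np.broadcast_to raises ValueError if an axis is neither 1 nor the target size.
    return np.broadcast_to(padded, target)


# 2 chains, 3 draws, and a size-1 axis standing in for obs_idx.
mu = np.arange(6.0).reshape(2, 3, 1)
mu_full = broadcast_to_named_dims(mu, (4,))
print(mu_full.shape)
```

The result is a read-only view that repeats each draw along the new axis, so it costs no extra memory until it is copied.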
