(marginalizing-models)=
:::{post} Jan 20, 2024
:tags: mixture model
:category: intermediate, how-to
:author: Rob Zinkov
:::
PyMC is very amenable to sampling models with discrete latent variables. But if you insist on using the NUTS sampler exclusively, you will need to get rid of your discrete variables somehow. The best way to do this is by marginalizing them out, as then you benefit from the Rao-Blackwell theorem and get a lower-variance estimate of your parameters.
Formally, the argument goes like this: samplers can be understood as approximating the expectation $\mathbb{E}_{p(x, z)}[f(x, z)]$ for any function $f$. By the law of total expectation,
$$ \mathbb{E}_{p(x, z)}[f(x, z)] = \mathbb{E}_{p(z)}\left[\mathbb{E}_{p(x \mid z)}\left[f(x, z)\right]\right] $$
Letting $g(z) = \mathbb{E}_{p(x \mid z)}\left[f(x, z)\right]$, the law of total variance gives
$$ \mathbb{V}_{p(x, z)}[f(x, z)] = \mathbb{V}_{p(z)}[g(z)] + \mathbb{E}_{p(z)}\left[\mathbb{V}_{p(x \mid z)}\left[f(x, z)\right]\right] $$
Because the second term is an expectation of a variance, it is always non-negative, and thus we know
$$ \mathbb{V}_{p(x, z)}[f(x, z)] \geq \mathbb{V}_{p(z)}[g(z)] $$
Intuitively, marginalizing variables in your model lets you work with the lower-variance $g(z)$ in place of $f(x, z)$.
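To make that concrete, here is a small standalone check of the inequality (a sketch in plain NumPy, not part of the original example; the parameters mirror the mixture used below). Taking $f(x, z) = x$ and $g(z) = \mathbb{E}_{p(x \mid z)}[x]$, the draws of $g(z)$ have a noticeably smaller variance than the draws of $f(x, z)$.

```python
# A minimal sketch: check the variance inequality by simulation for
# z ~ Bernoulli(0.7), x | z ~ Normal(mu[z], 1). Names are illustrative.
import numpy as np

check_rng = np.random.default_rng(0)
mu_check = np.array([-2.0, 2.0])

z = check_rng.binomial(1, 0.7, size=200_000)  # discrete latent
x = check_rng.normal(mu_check[z], 1.0)        # f(x, z) = x
g = mu_check[z]                               # g(z) = E[x | z]

print("Var[f(x, z)] ~", round(x.var(), 2))    # ~ 4.36
print("Var[g(z)]    ~", round(g.var(), 2))    # ~ 3.36, never larger
```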
Unfortunately, the computation needed to marginalize variables out by hand is often tedious and unintuitive. Luckily, `pymc-extras` now supports a way to do this work automatically!
import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pymc as pm
import pytensor.tensor as pt
:::{include} ../extra_installs.md
:::
import pymc_extras as pmx
%config InlineBackend.figure_format = 'retina' # high resolution figures
az.style.use("arviz-darkgrid")
rng = np.random.default_rng(32)
As a motivating example, consider a Gaussian mixture model.
+++
+++
There are two ways to specify the same model: one where the choice of mixture component is explicit.
mu = pt.as_tensor([-2.0, 2.0])
with pmx.MarginalModel() as explicit_mixture:
    idx = pm.Bernoulli("idx", 0.7)
    y = pm.Normal("y", mu=mu[idx], sigma=1.0)
plt.hist(pm.draw(y, draws=2000, random_seed=rng), bins=30, rwidth=0.9);
The other way is to use the built-in {class}`NormalMixture <pymc.NormalMixture>` distribution. Here the mixture assignment is not an explicit variable in our model. There is nothing unique about the first model other than that we initialize it with {class}`pmx.MarginalModel <pymc_extras.MarginalModel>` instead of {class}`pm.Model <pymc.model.core.Model>`. This different class is what will allow us to marginalize out variables later.
with pm.Model() as prebuilt_mixture:
    y = pm.NormalMixture("y", w=[0.3, 0.7], mu=[-2, 2])
plt.hist(pm.draw(y, draws=2000, random_seed=rng), bins=30, rwidth=0.9);
with prebuilt_mixture:
    idata = pm.sample(draws=2000, chains=4, random_seed=rng)
az.summary(idata)
with explicit_mixture:
    idata = pm.sample(draws=2000, chains=4, random_seed=rng)
az.summary(idata)
We can immediately see that the marginalized model has a higher ESS. Let's now marginalize out the choice and see what changes in our model.
explicit_mixture.marginalize(["idx"])
with explicit_mixture:
    idata = pm.sample(draws=2000, chains=4, random_seed=rng)
az.summary(idata)
As we can see, the `idx` variable is gone now. We were also able to use the NUTS sampler, and the ESS has improved.
But {class}`MarginalModel <pymc_extras.MarginalModel>` has a distinct advantage. It still knows about the discrete variables that were marginalized out, and we can obtain estimates for the posterior of `idx` given the other variables. We do this using the {meth}`recover_marginals <pymc_extras.MarginalModel.recover_marginals>` method.
explicit_mixture.recover_marginals(idata, random_seed=rng);
az.summary(idata)
This `idx` variable lets us recover the mixture assignments after running the NUTS sampler! We can split the samples of `y` by reading off the mixture label from the associated `idx` for each sample.
# fmt: off
post = idata.posterior
plt.hist(
    post.where(post.idx == 0).y.values.reshape(-1),
    bins=30,
    rwidth=0.9,
    alpha=0.75,
    label='idx = 0',
)
plt.hist(
    post.where(post.idx == 1).y.values.reshape(-1),
    bins=30,
    rwidth=0.9,
    alpha=0.75,
    label='idx = 1'
)
# fmt: on
plt.legend();
One important thing to notice is that this discrete variable has a lower ESS, particularly in the tails. This means `idx` might not be estimated well there. If this is important, I recommend using `lp_idx` instead, which is the log-probability of `idx` given the sampled values of the other variables on each iteration. The benefits of working with `lp_idx` will be explored further in the next example.
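As a quick illustration of what `lp_idx` gives us, here is a sketch (an assumption-laden example, not part of the original notebook) that converts the per-draw log-probabilities into membership probabilities and averages them. It assumes the categories of `idx` lie along the last axis of `lp_idx` and renormalizes defensively in case the log-probabilities are not already normalized.

```python
# A hedged sketch: turn the per-draw log-probabilities in `lp_idx` into
# posterior membership probabilities. The (chain, draw, category) layout
# with categories on the last axis is an assumption.
lp = idata.posterior["lp_idx"].values
probs = np.exp(lp - lp.max(axis=-1, keepdims=True))
probs /= probs.sum(axis=-1, keepdims=True)  # renormalize defensively
print("Average P(idx = 1):", probs[..., 1].mean().round(3))
```

Averaging these per-draw probabilities is itself a Rao-Blackwellized estimate, so it is typically much less noisy than the empirical frequency of the sampled `idx` values.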
+++
+++
The same methods work for the {ref}`Coal mining <pymc:pymc_overview#case-study-2-coal-mining-disasters>` switchpoint model as well. The coal mining dataset records the number of coal mining disasters in the UK between 1851 and 1962. The time series captures a period when mining safety regulations were being introduced, and we try to estimate when this occurred using a discrete `switchpoint` variable.
# fmt: off
disaster_data = pd.Series(
    [4, 5, 4, 0, 1, 4, 3, 4, 0, 6, 3, 3, 4, 0, 2, 6,
     3, 3, 5, 4, 5, 3, 1, 4, 4, 1, 5, 5, 3, 4, 2, 5,
     2, 2, 3, 4, 2, 1, 3, np.nan, 2, 1, 1, 1, 1, 3, 0, 0,
     1, 0, 1, 1, 0, 0, 3, 1, 0, 3, 2, 2, 0, 1, 1, 1,
     0, 1, 0, 1, 0, 0, 0, 2, 1, 0, 0, 0, 1, 1, 0, 2,
     3, 3, 1, np.nan, 2, 1, 1, 1, 1, 2, 4, 2, 0, 0, 1, 4,
     0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1]
)
# fmt: on
years = np.arange(1851, 1962)
with pmx.MarginalModel() as disaster_model:
    switchpoint = pm.DiscreteUniform("switchpoint", lower=years.min(), upper=years.max())
    early_rate = pm.Exponential("early_rate", 1.0, initval=3)
    late_rate = pm.Exponential("late_rate", 1.0, initval=1)
    rate = pm.math.switch(switchpoint >= years, early_rate, late_rate)
    disasters = pm.Poisson("disasters", rate, observed=disaster_data)
We will sample the model both before and after we marginalize out the `switchpoint` variable.
with disaster_model:
    before_marg = pm.sample(chains=2, random_seed=rng)
disaster_model.marginalize(["switchpoint"])
with disaster_model:
    after_marg = pm.sample(chains=2, random_seed=rng)
az.summary(before_marg, var_names=["~disasters"], filter_vars="like")
az.summary(after_marg, var_names=["~disasters"], filter_vars="like")
As before, the ESS improved massively.
+++
Finally, let us recover the `switchpoint` variable.
disaster_model.recover_marginals(after_marg);
az.summary(after_marg, var_names=["~disasters", "~lp"], filter_vars="like")
While `recover_marginals` is able to sample the discrete variables that were marginalized out, the probabilities associated with each draw often offer a cleaner estimate of the discrete variable, particularly for lower-probability values. This is best illustrated by comparing the histogram of the sampled values with the plot of the log-probabilities.
post = after_marg.posterior.switchpoint.values.reshape(-1)
bins = np.arange(post.min(), post.max())
plt.hist(post, bins, rwidth=0.9);
lp_switchpoint = after_marg.posterior.lp_switchpoint.mean(dim=["chain", "draw"])
x_max = years[lp_switchpoint.argmax()]
plt.scatter(years, lp_switchpoint)
plt.axvline(x=x_max, c="orange")
plt.xlabel(r"$\mathrm{year}$")
plt.ylabel(r"$\log p(\mathrm{switchpoint}=\mathrm{year})$");
By plotting a histogram of sampled values instead of working with the log-probabilities directly, we are left with noisier and more incomplete exploration of the underlying discrete distribution.
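If a probability scale is easier to read than log-probabilities, one option (again a sketch, assuming the candidate years lie along the last axis of `lp_switchpoint`) is to exponentiate the per-draw log-probabilities and average them over draws, which gives a Rao-Blackwellized estimate of the switchpoint posterior.

```python
# A hedged sketch: Rao-Blackwellized posterior over the switchpoint year,
# assuming `lp_switchpoint` stores one log-probability per candidate year
# on its last axis.
lp = after_marg.posterior.lp_switchpoint.values  # (chain, draw, n_years), assumed layout
p_year = np.exp(lp).mean(axis=(0, 1))            # average P(switchpoint = year) over draws
p_year /= p_year.sum()                           # renormalize defensively

plt.bar(years, p_year, width=0.9)
plt.xlabel(r"$\mathrm{year}$")
plt.ylabel(r"$p(\mathrm{switchpoint}=\mathrm{year})$");
```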
+++
- Authored by Rob Zinkov in January 2024
+++
:::{bibliography}
:filter: docname in docnames
:::
+++
%load_ext watermark
%watermark -n -u -v -iv -w -p pytensor,xarray
:::{include} ../page_footer.md
:::