speed up posterior predictive sampling #6208


Merged: 8 commits, Oct 27, 2022
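For context, a minimal usage sketch of the `sample_dims` argument this PR adds. The toy model is an assumption for illustration only; the call pattern follows the new `test_sample_dims` test below.

```python
import numpy as np
import pymc as pm

# Toy model, assumed for illustration; not part of this PR.
with pm.Model() as pmodel:
    mu = pm.Normal("mu", 0.0, 1.0)
    pm.Normal("y", mu, 1.0, observed=np.random.standard_normal(10))
    idata = pm.sample(chains=2, draws=50)

    # Stack chain and draw into a single "sample" dimension, then tell
    # sample_posterior_predictive to loop over that one dimension.
    post = idata.posterior.stack(sample=["chain", "draw"])
    pp = pm.sample_posterior_predictive(post, sample_dims=["sample"])
```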
47 changes: 25 additions & 22 deletions pymc/backends/arviz.py
@@ -7,6 +7,7 @@
Any,
Dict,
Iterable,
List,
Mapping,
Optional,
Sequence,
@@ -15,7 +16,6 @@
)

import numpy as np
import xarray as xr

from aesara.graph.basic import Constant
from aesara.tensor.sharedvar import SharedVariable
@@ -162,6 +162,7 @@ def __init__(
predictions=None,
coords: Optional[CoordSpec] = None,
dims: Optional[DimSpec] = None,
sample_dims: Optional[List] = None,
model=None,
save_warmup: Optional[bool] = None,
):
@@ -223,6 +224,9 @@ def __init__(
for var_name, dims in self.model.RV_dims.items()
}
self.dims = {**model_dims, **self.dims}
if sample_dims is None:
sample_dims = ["chain", "draw"]
self.sample_dims = sample_dims

self.observations = find_observations(self.model)

@@ -419,36 +423,31 @@ def log_likelihood_to_xarray(self):
),
)

def translate_posterior_predictive_dict_to_xarray(self, dct, kind) -> xr.Dataset:
"""Take Dict of variables to numpy ndarrays (samples) and translate into dataset."""
data = {}
warning_vars = []
for k, ary in dct.items():
if (ary.shape[0] == self.nchains) and (ary.shape[1] == self.ndraws):
data[k] = ary
else:
data[k] = np.expand_dims(ary, 0)
warning_vars.append(k)
if warning_vars:
warnings.warn(
f"The shape of variables {', '.join(warning_vars)} in {kind} group is not compatible "
"with number of chains and draws. The automatic dimension naming might not have worked. "
"This can also mean that some draws or even whole chains are not represented.",
UserWarning,
)
return dict_to_dataset(data, library=pymc, coords=self.coords, dims=self.dims)
return dict_to_dataset(
data, library=pymc, coords=self.coords, dims=self.dims, default_dims=self.sample_dims
)

@requires(["posterior_predictive"])
def posterior_predictive_to_xarray(self):
"""Convert posterior_predictive samples to xarray."""
return self.translate_posterior_predictive_dict_to_xarray(
self.posterior_predictive, "posterior_predictive"
data = self.posterior_predictive
dims = {
var_name: self.sample_dims + self.dims.get(var_name, []) for var_name in data.keys()
}
return dict_to_dataset(
data, library=pymc, coords=self.coords, dims=dims, default_dims=self.sample_dims
)

@requires(["predictions"])
def predictions_to_xarray(self):
"""Convert predictions (out of sample predictions) to xarray."""
return self.translate_posterior_predictive_dict_to_xarray(self.predictions, "predictions")
data = self.predictions
dims = {
var_name: self.sample_dims + self.dims.get(var_name, []) for var_name in data.keys()
}
return dict_to_dataset(
data, library=pymc, coords=self.coords, dims=dims, default_dims=self.sample_dims
)

def priors_to_xarray(self):
"""Convert prior samples (and if possible prior predictive too) to xarray."""
@@ -537,6 +536,7 @@ def to_inference_data(
log_likelihood: Union[bool, Iterable[str]] = True,
coords: Optional[CoordSpec] = None,
dims: Optional[DimSpec] = None,
sample_dims: Optional[List] = None,
model: Optional["Model"] = None,
save_warmup: Optional[bool] = None,
) -> InferenceData:
@@ -586,6 +586,7 @@
log_likelihood=log_likelihood,
coords=coords,
dims=dims,
sample_dims=sample_dims,
model=model,
save_warmup=save_warmup,
).to_inference_data()
@@ -599,6 +600,7 @@ def predictions_to_inference_data(
model: Optional["Model"] = None,
coords: Optional[CoordSpec] = None,
dims: Optional[DimSpec] = None,
sample_dims: Optional[List] = None,
idata_orig: Optional[InferenceData] = None,
inplace: bool = False,
) -> InferenceData:
Expand Down Expand Up @@ -644,6 +646,7 @@ def predictions_to_inference_data(
model=model,
coords=coords,
dims=dims,
sample_dims=sample_dims,
log_likelihood=False,
)
if hasattr(idata_orig, "posterior"):
61 changes: 39 additions & 22 deletions pymc/sampling.py
@@ -75,7 +75,6 @@
from pymc.step_methods.arraystep import BlockedStep, PopulationArrayStepShared
from pymc.step_methods.hmc import quadpotential
from pymc.util import (
chains_and_samples,
dataset_to_point_list,
get_default_varnames,
get_untransformed_name,
@@ -1771,6 +1770,7 @@ def sample_posterior_predictive(
trace,
model: Optional[Model] = None,
var_names: Optional[List[str]] = None,
sample_dims: Optional[List[str]] = None,
random_seed: RandomState = None,
progressbar: bool = True,
return_inferencedata: bool = True,
@@ -1791,6 +1791,10 @@
generally be the model used to generate the ``trace``, but it doesn't need to be.
var_names : Iterable[str]
Names of variables for which to compute the posterior predictive samples.
sample_dims : list of str, optional
Dimensions over which to loop and generate posterior predictive samples.
When `sample_dims` is ``None`` (default), both "chain" and "draw" are considered sample
dimensions. Only taken into account when `trace` is an InferenceData or a Dataset.
random_seed : int, RandomState or Generator, optional
Seed for the random number generator.
progressbar : bool
@@ -1827,6 +1831,14 @@
thinned_idata = idata.sel(draw=slice(None, None, 5))
with model:
idata.extend(pymc.sample_posterior_predictive(thinned_idata))

Generate 5 posterior predictive samples per posterior sample.

.. code:: python

expanded_data = idata.posterior.expand_dims(pred_id=5)
with model:
idata.extend(pymc.sample_posterior_predictive(expanded_data))
"""

_trace: Union[MultiTrace, PointList]
@@ -1835,36 +1847,34 @@
idata_kwargs = {}
else:
idata_kwargs = idata_kwargs.copy()
if sample_dims is None:
sample_dims = ["chain", "draw"]
constant_data: Dict[str, np.ndarray] = {}
trace_coords: Dict[str, np.ndarray] = {}
if "coords" not in idata_kwargs:
idata_kwargs["coords"] = {}
idata: Optional[InferenceData] = None
stacked_dims = None
if isinstance(trace, InferenceData):
idata_kwargs["coords"].setdefault("draw", trace["posterior"]["draw"])
idata_kwargs["coords"].setdefault("chain", trace["posterior"]["chain"])
_constant_data = getattr(trace, "constant_data", None)
if _constant_data is not None:
trace_coords.update({str(k): v.data for k, v in _constant_data.coords.items()})
constant_data.update({str(k): v.data for k, v in _constant_data.items()})
trace_coords.update({str(k): v.data for k, v in trace["posterior"].coords.items()})
_trace = dataset_to_point_list(trace["posterior"])
nchain, len_trace = chains_and_samples(trace)
elif isinstance(trace, xarray.Dataset):
idata_kwargs["coords"].setdefault("draw", trace["draw"])
idata_kwargs["coords"].setdefault("chain", trace["chain"])
idata = trace
trace = trace["posterior"]
if isinstance(trace, xarray.Dataset):
trace_coords.update({str(k): v.data for k, v in trace.coords.items()})
_trace = dataset_to_point_list(trace)
nchain, len_trace = chains_and_samples(trace)
_trace, stacked_dims = dataset_to_point_list(trace, sample_dims)
nchain = 1
elif isinstance(trace, MultiTrace):
_trace = trace
nchain = _trace.nchains
len_trace = len(_trace)
elif isinstance(trace, list) and all(isinstance(x, dict) for x in trace):
_trace = trace
nchain = 1
len_trace = len(_trace)
else:
raise TypeError(f"Unsupported type for `trace` argument: {type(trace)}.")
len_trace = len(_trace)

if isinstance(_trace, MultiTrace):
samples = sum(len(v) for v in _trace._straces.values())
@@ -1967,23 +1977,30 @@ def sample_posterior_predictive(
ppc_trace = ppc_trace_t.trace_dict

for k, ary in ppc_trace.items():
ppc_trace[k] = ary.reshape((nchain, len_trace, *ary.shape[1:]))
if stacked_dims is not None:
ppc_trace[k] = ary.reshape(
(*[len(coord) for coord in stacked_dims.values()], *ary.shape[1:])
)
else:
ppc_trace[k] = ary.reshape((nchain, len_trace, *ary.shape[1:]))

if not return_inferencedata:
return ppc_trace
ikwargs: Dict[str, Any] = dict(model=model, **idata_kwargs)
ikwargs.setdefault("sample_dims", sample_dims)
if stacked_dims is not None:
coords = ikwargs.get("coords", {})
ikwargs["coords"] = {**stacked_dims, **coords}
if predictions:
if extend_inferencedata:
ikwargs.setdefault("idata_orig", trace)
ikwargs.setdefault("idata_orig", idata)
ikwargs.setdefault("inplace", True)
return pm.predictions_to_inference_data(ppc_trace, **ikwargs)
converter = pm.backends.arviz.InferenceDataConverter(posterior_predictive=ppc_trace, **ikwargs)
converter.nchains = nchain
converter.ndraws = len_trace
idata_pp = converter.to_inference_data()
if extend_inferencedata:
trace.extend(idata_pp)
return trace
idata_pp = pm.to_inference_data(posterior_predictive=ppc_trace, **ikwargs)

if extend_inferencedata and idata is not None:
idata.extend(idata_pp)
return idata
return idata_pp


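As an aside, a toy sketch of the reshape step above, with assumed shapes for illustration:

```python
import numpy as np

# Suppose two stacked sample dims of lengths 8 and 5 (e.g. "sample" and
# "pred_id"), and each draw of a variable having core shape (3,).
stacked_lengths = [8, 5]
flat_draws = np.zeros((8 * 5, 3))  # draws arrive flattened over the sample dims
reshaped = flat_draws.reshape((*stacked_lengths, *flat_draws.shape[1:]))
assert reshaped.shape == (8, 5, 3)
```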
16 changes: 16 additions & 0 deletions pymc/tests/test_sampling.py
@@ -1618,6 +1618,22 @@ def test_aesara_function_kwargs(self):

assert np.all(pp["y"] == np.arange(5) * 2)

def test_sample_dims(self, point_list_arg_bug_fixture):
pmodel, trace = point_list_arg_bug_fixture
with pmodel:
post = pm.to_inference_data(trace).posterior.stack(sample=["chain", "draw"])
pp = pm.sample_posterior_predictive(post, var_names=["d"], sample_dims=["sample"])
assert "sample" in pp.posterior_predictive
assert len(pp.posterior_predictive["sample"]) == len(post["sample"])
post = post.expand_dims(pred_id=5)
pp = pm.sample_posterior_predictive(
post, var_names=["d"], sample_dims=["sample", "pred_id"]
)
assert "sample" in pp.posterior_predictive
assert "pred_id" in pp.posterior_predictive
assert len(pp.posterior_predictive["sample"]) == len(post["sample"])
assert len(pp.posterior_predictive["pred_id"]) == 5


class TestDraw(SeededTest):
def test_univariate(self):
4 changes: 2 additions & 2 deletions pymc/tests/test_util.py
@@ -154,7 +154,7 @@ def fn(a=UNSET):
def test_dataset_to_point_list():
ds = xarray.Dataset()
ds["A"] = xarray.DataArray([[1, 2, 3]] * 2, dims=("chain", "draw"))
pl = dataset_to_point_list(ds)
pl, _ = dataset_to_point_list(ds, sample_dims=["chain", "draw"])
assert isinstance(pl, list)
assert len(pl) == 6
assert isinstance(pl[0], dict)
@@ -163,4 +163,4 @@ def test_dataset_to_point_list():
# Check that non-str keys are caught
ds[3] = xarray.DataArray([1, 2, 3])
with pytest.raises(ValueError, match="must be str"):
dataset_to_point_list(ds)
dataset_to_point_list(ds, sample_dims=["chain", "draw"])
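For reference, a sketch of calling the updated helper directly; explicit coords are added here for clarity (a toy example, not part of the test suite):

```python
import xarray

from pymc.util import dataset_to_point_list

ds = xarray.Dataset(
    {"A": (("chain", "draw"), [[1, 2, 3], [4, 5, 6]])},
    coords={"chain": [0, 1], "draw": [0, 1, 2]},
)
points, stacked_dims = dataset_to_point_list(ds, sample_dims=["chain", "draw"])
assert len(points) == 6  # one point dict per (chain, draw) pair
assert list(stacked_dims) == ["chain", "draw"]
```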
47 changes: 17 additions & 30 deletions pymc/util.py
@@ -14,9 +14,8 @@

import functools

from typing import Dict, Hashable, List, Tuple, Union, cast
from typing import Any, Dict, List, Tuple, cast

import arviz
import cloudpickle
import numpy as np
import xarray
@@ -231,38 +230,26 @@ def enhanced(*args, **kwargs):
return enhanced


def dataset_to_point_list(ds: xarray.Dataset) -> List[Dict[str, np.ndarray]]:
def dataset_to_point_list(
ds: xarray.Dataset, sample_dims: List
) -> Tuple[List[Dict[str, np.ndarray]], Dict[str, Any]]:
# All keys of the dataset must be a str
for vn in ds.keys():
var_names = list(ds.keys())
for vn in var_names:
if not isinstance(vn, str):
raise ValueError(f"Variable names must be str, but dataset key {vn} is a {type(vn)}.")
# make dicts
points: List[Dict[Hashable, np.ndarray]] = []
da: "xarray.DataArray"
for c in ds.chain:
for d in ds.draw:
points.append({vn: da.sel(chain=c, draw=d).values for vn, da in ds.items()})
num_sample_dims = len(sample_dims)
stacked_dims = {dim_name: ds[dim_name] for dim_name in sample_dims}
ds = ds.transpose(*sample_dims, ...)
stacked_dict = {
vn: da.values.reshape((-1, *da.shape[num_sample_dims:])) for vn, da in ds.items()
}
    points = [
        {vn: stacked_dict[vn][i, ...] for vn in var_names}
        for i in range(np.product([len(coords) for coords in stacked_dims.values()]))
    ]

Review thread on this list comprehension:

Member: Perhaps we could yield instead of returning the whole list at once?

Member: I agree with using a lazy generator approach unless the whole list is needed at once for some reason.

Author: If that works later in the code then yes! I only kept the list because the function is called _to_list. You should assume I have no idea about the format we need to interface with the aesara random drawing function. Here, would that be a () comprehension or an explicit loop with a yield? Or either?

Member: I like the () comprehension more.

Member: Me too.

ricardoV94 (Oct 25, 2022): I noticed the code downstream of this may be incompatible with generators (it asks for len() and sometimes checks the first point...).
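For illustration, a minimal sketch of the generator variant discussed in this thread. It is hypothetical and was not adopted, since the downstream code needs len() and peeks at the first point:

```python
from typing import Dict, Iterator

import numpy as np


def iter_points(
    stacked_dict: Dict[str, np.ndarray], n_points: int
) -> Iterator[Dict[str, np.ndarray]]:
    # Hypothetical lazy variant: yield one point dict at a time instead
    # of materializing the whole list in memory.
    for i in range(n_points):
        yield {vn: ary[i, ...] for vn, ary in stacked_dict.items()}
```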
# use the list of points
return cast(List[Dict[str, np.ndarray]], points)


def chains_and_samples(data: Union[xarray.Dataset, arviz.InferenceData]) -> Tuple[int, int]:
"""Extract and return number of chains and samples in xarray or arviz traces."""
dataset: xarray.Dataset
if isinstance(data, xarray.Dataset):
dataset = data
elif isinstance(data, arviz.InferenceData):
dataset = data["posterior"]
else:
raise ValueError(
"Argument must be xarray Dataset or arviz InferenceData. Got %s",
data.__class__,
)

coords = dataset.coords
nchains = coords["chain"].sizes["chain"]
nsamples = coords["draw"].sizes["draw"]
return nchains, nsamples
return cast(List[Dict[str, np.ndarray]], points), stacked_dims


def hashable(a=None) -> int: