-
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
speed up posterior predictive sampling #6208
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 7 commits
ac4616f
160b0d8
61a05f1
870db74
507de2f
9130d81
65f69d6
75a032e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,9 +14,8 @@ | |
|
||
import functools | ||
|
||
from typing import Dict, Hashable, List, Tuple, Union, cast | ||
from typing import Any, Dict, List, Tuple, cast | ||
|
||
import arviz | ||
import cloudpickle | ||
import numpy as np | ||
import xarray | ||
|
@@ -231,38 +230,26 @@ def enhanced(*args, **kwargs): | |
return enhanced | ||
|
||
|
||
def dataset_to_point_list(ds: xarray.Dataset) -> List[Dict[str, np.ndarray]]: | ||
def dataset_to_point_list( | ||
ds: xarray.Dataset, sample_dims: List | ||
) -> Tuple[List[Dict[str, np.ndarray]], Dict[str, Any]]: | ||
# All keys of the dataset must be a str | ||
for vn in ds.keys(): | ||
var_names = list(ds.keys()) | ||
for vn in var_names: | ||
if not isinstance(vn, str): | ||
raise ValueError(f"Variable names must be str, but dataset key {vn} is a {type(vn)}.") | ||
# make dicts | ||
points: List[Dict[Hashable, np.ndarray]] = [] | ||
da: "xarray.DataArray" | ||
for c in ds.chain: | ||
for d in ds.draw: | ||
points.append({vn: da.sel(chain=c, draw=d).values for vn, da in ds.items()}) | ||
num_sample_dims = len(sample_dims) | ||
stacked_dims = {dim_name: ds[dim_name] for dim_name in sample_dims} | ||
ds = ds.transpose(*sample_dims, ...) | ||
OriolAbril marked this conversation as resolved.
Show resolved
Hide resolved
|
||
stacked_dict = { | ||
vn: da.values.reshape((-1, *da.shape[num_sample_dims:])) for vn, da in ds.items() | ||
} | ||
points = [ | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Perhaps we could yield instead of returning the whole list at once? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I agree with using a lazy generator approach, unless the whole list is needed at once for some reason. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. If that works later in the code then yes! I only kept the list because the function is called `dataset_to_point_list`. Here would that be using a […]? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I like the […]. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Me too. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I noticed the code downstream of this may be incompatible with generators (it asks for `len` and sometimes checks the first point...) |
||
{vn: stacked_dict[vn][i, ...] for vn in var_names} | ||
for i in range(np.product([len(coords) for coords in stacked_dims.values()])) | ||
] | ||
# use the list of points | ||
return cast(List[Dict[str, np.ndarray]], points) | ||
|
||
|
||
def chains_and_samples(data: Union[xarray.Dataset, arviz.InferenceData]) -> Tuple[int, int]: | ||
"""Extract and return number of chains and samples in xarray or arviz traces.""" | ||
dataset: xarray.Dataset | ||
if isinstance(data, xarray.Dataset): | ||
dataset = data | ||
elif isinstance(data, arviz.InferenceData): | ||
dataset = data["posterior"] | ||
else: | ||
raise ValueError( | ||
"Argument must be xarray Dataset or arviz InferenceData. Got %s", | ||
data.__class__, | ||
) | ||
|
||
coords = dataset.coords | ||
nchains = coords["chain"].sizes["chain"] | ||
nsamples = coords["draw"].sizes["draw"] | ||
return nchains, nsamples | ||
return cast(List[Dict[str, np.ndarray]], points), stacked_dims | ||
|
||
|
||
def hashable(a=None) -> int: | ||
|
Uh oh!
There was an error while loading. Please reload this page.