"""PyMC3-ArviZ conversion code."""
import logging
import warnings
from typing import ( # pylint: disable=unused-import
TYPE_CHECKING,
Any,
Dict,
Iterable,
List,
Mapping,
Optional,
Tuple,
Union,
)
import numpy as np
import xarray as xr
from aesara.graph.basic import Constant
from aesara.tensor.sharedvar import SharedVariable
from aesara.tensor.subtensor import AdvancedIncSubtensor
from arviz import InferenceData, concat, rcParams
from arviz.data.base import CoordSpec, DimSpec
from arviz.data.base import dict_to_dataset as _dict_to_dataset
from arviz.data.base import generate_dims_coords, make_attrs, requires
import pymc3
from pymc3.aesaraf import extract_obs_data
from pymc3.distributions import logpt
from pymc3.model import modelcontext
from pymc3.util import get_default_varnames
if TYPE_CHECKING:
from typing import Set # pylint: disable=ungrouped-imports
from pymc3.backends.base import MultiTrace # pylint: disable=invalid-name
from pymc3.model import Model
___all__ = [""]
_log = logging.getLogger("pymc3")
# random variable object ...
Var = Any # pylint: disable=invalid-name


class _DefaultTrace:
    """
    Utility for collecting samples into a dictionary.

    Name comes from its similarity to ``defaultdict``:
    entries are lazily created.

    Parameters
    ----------
    samples : int
        The number of samples that will be collected, per variable,
        into the trace.

    Attributes
    ----------
    trace_dict : Dict[str, np.ndarray]
        A dictionary constituting a trace. Should be extracted
        after a procedure has filled the ``_DefaultTrace`` using the
        ``insert()`` method.
    """

    trace_dict: Dict[str, np.ndarray] = {}
    _len: Optional[int] = None

    def __init__(self, samples: int):
        self._len = samples
        self.trace_dict = {}

    def insert(self, k: str, v, idx: int):
        """
        Insert `v` as the value of the `idx`th sample for the variable `k`.

        Parameters
        ----------
        k : str
            Name of the variable.
        v : anything that can go into a numpy array (including a numpy array)
            The value of the `idx`th sample from variable `k`.
        idx : int
            The index of the sample we are inserting into the trace.
        """
        value_shape = np.shape(v)

        # initialize the backing array on first insertion for this variable
        if k not in self.trace_dict:
            array_shape = (self._len,) + value_shape
            self.trace_dict[k] = np.empty(array_shape, dtype=np.array(v).dtype)

        # do the actual insertion
        if value_shape == ():
            self.trace_dict[k][idx] = v
        else:
            self.trace_dict[k][idx, :] = v
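
# A minimal illustrative sketch of how ``_DefaultTrace`` is used (the variable
# name "mu" and the values are hypothetical, not part of this module's API):
#
#     t = _DefaultTrace(samples=2)
#     t.insert("mu", np.array([0.1, -0.2]), 0)  # allocates a (2, 2) float array
#     t.insert("mu", np.array([0.3, 0.5]), 1)
#     t.trace_dict["mu"]  # -> array([[ 0.1, -0.2], [ 0.3,  0.5]])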


def dict_to_dataset(
    data,
    library=None,
    coords=None,
    dims=None,
    attrs=None,
    default_dims=None,
    skip_event_dims=None,
    index_origin=None,
):
    """Temporary workaround for ``dict_to_dataset``.

    Once an ArviZ release newer than 0.11.2 is available, only two changes are
    needed for everything to work: 1) this function should be deleted, and
    2) ``dict_to_dataset`` should be imported as-is from arviz (no underscore).
    Also remove the then-unnecessary imports.
    """
    if default_dims is None:
        return _dict_to_dataset(
            data, library=library, coords=coords, dims=dims, skip_event_dims=skip_event_dims
        )
    else:
        out_data = {}
        for name, vals in data.items():
            vals = np.atleast_1d(vals)
            val_dims = dims.get(name)
            val_dims, coords = generate_dims_coords(vals.shape, name, dims=val_dims, coords=coords)
            coords = {key: xr.IndexVariable((key,), data=coords[key]) for key in val_dims}
            out_data[name] = xr.DataArray(vals, dims=val_dims, coords=coords)
        return xr.Dataset(data_vars=out_data, attrs=make_attrs(library=library))
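
# Illustrative sketch: with ``default_dims=[]`` (as the constant/observed data
# groups below use it), no chain/draw dimensions are added, so a plain array
# becomes a dataset variable indexed only by its own dims. The variable name
# "x" here is hypothetical:
#
#     ds = dict_to_dataset(
#         {"x": np.arange(3)}, library=pymc3, coords={}, dims={}, default_dims=[]
#     )
#     ds["x"].dims  # -> ("x_dim_0",)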


class InferenceDataConverter:  # pylint: disable=too-many-instance-attributes
    """Encapsulate InferenceData specific logic."""

    model = None  # type: Optional[Model]
    nchains = None  # type: int
    ndraws = None  # type: int
    posterior_predictive = None  # type: Optional[Mapping[str, np.ndarray]]
    predictions = None  # type: Optional[Mapping[str, np.ndarray]]
    prior = None  # type: Optional[Mapping[str, np.ndarray]]

    def __init__(
        self,
        *,
        trace=None,
        prior=None,
        posterior_predictive=None,
        log_likelihood=True,
        predictions=None,
        coords: Optional[CoordSpec] = None,
        dims: Optional[DimSpec] = None,
        model=None,
        save_warmup: Optional[bool] = None,
        density_dist_obs: bool = True,
        index_origin: Optional[int] = None,
    ):
        self.save_warmup = rcParams["data.save_warmup"] if save_warmup is None else save_warmup
        self.trace = trace

        # this permits us to get the model from a keyword argument or from a
        # ``with model:`` context
        self.model = modelcontext(model)

        self.attrs = None
        if trace is not None:
            self.nchains = trace.nchains if hasattr(trace, "nchains") else 1
            if hasattr(trace.report, "n_draws") and trace.report.n_draws is not None:
                self.ndraws = trace.report.n_draws
                self.attrs = {
                    "sampling_time": trace.report.t_sampling,
                    "tuning_steps": trace.report.n_tune,
                }
            else:
                self.ndraws = len(trace)
                if self.save_warmup:
                    warnings.warn(
                        "Warmup samples will be stored in posterior group and will not be"
                        " excluded from stats and diagnostics."
                        " Do not slice the trace manually before conversion",
                        UserWarning,
                    )
            self.ntune = len(self.trace) - self.ndraws
            self.posterior_trace, self.warmup_trace = self.split_trace()
        else:
            self.nchains = self.ndraws = 0

        self.prior = prior
        self.posterior_predictive = posterior_predictive
        self.log_likelihood = log_likelihood
        self.predictions = predictions
        self.index_origin = rcParams["data.index_origin"] if index_origin is None else index_origin

        def arbitrary_element(dct: Dict[Any, np.ndarray]) -> np.ndarray:
            return next(iter(dct.values()))

        if trace is None:
            # if you have a posterior_predictive built with keep_dims,
            # you'll lose here, but there's nothing I can do about that.
            self.nchains = 1
            get_from = None
            if predictions is not None:
                get_from = predictions
            elif posterior_predictive is not None:
                get_from = posterior_predictive
            elif prior is not None:
                get_from = prior
            if get_from is None:
                # pylint: disable=line-too-long
                raise ValueError(
                    "When constructing InferenceData you must have at least"
                    " one of trace, prior, posterior_predictive or predictions."
                )

            aelem = arbitrary_element(get_from)
            self.ndraws = aelem.shape[0]

        self.coords = {} if coords is None else coords
        if hasattr(self.model, "coords"):
            self.coords = {**self.model.coords, **self.coords}
        self.coords = {key: value for key, value in self.coords.items() if value is not None}

        self.dims = {} if dims is None else dims
        if hasattr(self.model, "RV_dims"):
            model_dims = {
                var_name: [dim for dim in dims if dim is not None]
                for var_name, dims in self.model.RV_dims.items()
            }
            self.dims = {**model_dims, **self.dims}

        self.density_dist_obs = density_dist_obs
        self.observations = self.find_observations()

    def find_observations(self) -> Optional[Dict[str, Var]]:
        """If there are observations available, return them as a dictionary."""
        if self.model is None:
            return None
        observations = {}
        for obs in self.model.observed_RVs:
            aux_obs = getattr(obs.tag, "observations", None)
            if aux_obs is not None:
                try:
                    obs_data = extract_obs_data(aux_obs)
                    observations[obs.name] = obs_data
                except TypeError:
                    warnings.warn(f"Could not extract data from symbolic observation {obs}")
            else:
                warnings.warn(f"No data for observation {obs}")
        return observations

    def split_trace(self) -> Tuple[Union[None, "MultiTrace"], Union[None, "MultiTrace"]]:
        """Split MultiTrace object into posterior and warmup.

        Returns
        -------
        trace_posterior : MultiTrace or None
            The slice of the trace corresponding to the posterior. If the posterior
            trace is empty, None is returned.
        trace_warmup : MultiTrace or None
            The slice of the trace corresponding to the warmup. If the warmup trace is
            empty or ``save_warmup=False``, None is returned.
        """
        trace_posterior = None
        trace_warmup = None
        if self.save_warmup and self.ntune > 0:
            trace_warmup = self.trace[: self.ntune]
        if self.ndraws > 0:
            trace_posterior = self.trace[self.ntune :]
        return trace_posterior, trace_warmup
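
    # Illustrative sketch with assumed numbers: if ``len(self.trace) == 1500``
    # and the sampler reported ``n_draws == 1000``, ``__init__`` sets
    # ``self.ntune = 500``. Warmup iterations come first in a MultiTrace, so
    # this returns ``(self.trace[500:], self.trace[:500])`` when
    # ``save_warmup`` is enabled.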

    def log_likelihood_vals_point(self, point, var, log_like_fun):
        """Compute log likelihood for each observed point."""
        # TODO: This is a cheap hack; we should filter-out the correct
        # variables some other way
        point = {i.name: point[i.name] for i in log_like_fun.f.maker.inputs if i.name in point}
        log_like_val = np.atleast_1d(log_like_fun(point))
        if isinstance(var.owner.op, AdvancedIncSubtensor):
            try:
                obs_data = extract_obs_data(var.tag.observations)
            except TypeError:
                warnings.warn(f"Could not extract data from symbolic observation {var}")
            else:
                # mask out the log likelihood of imputed (missing) values
                mask = obs_data.mask
                if np.ndim(mask) > np.ndim(log_like_val):
                    mask = np.any(mask, axis=-1)
                log_like_val = np.where(mask, np.nan, log_like_val)
        return log_like_val

    def _extract_log_likelihood(self, trace):
        """Compute log likelihood of each observation."""
        if self.trace is None:
            return None
        if self.model is None:
            return None

        if self.log_likelihood is True:
            cached = [(var, self.model.fn(logpt(var))) for var in self.model.observed_RVs]
        else:
            cached = [
                (var, self.model.fn(logpt(var)))
                for var in self.model.observed_RVs
                if var.name in self.log_likelihood
            ]
        log_likelihood_dict = _DefaultTrace(len(trace.chains))
        for var, log_like_fun in cached:
            for k, chain in enumerate(trace.chains):
                log_like_chain = [
                    self.log_likelihood_vals_point(point, var, log_like_fun)
                    for point in trace.points([chain])
                ]
                log_likelihood_dict.insert(var.name, np.stack(log_like_chain), k)
        return log_likelihood_dict.trace_dict
@requires("trace")
def posterior_to_xarray(self):
"""Convert the posterior to an xarray dataset."""
var_names = get_default_varnames(self.trace.varnames, include_transformed=False)
data = {}
data_warmup = {}
for var_name in var_names:
if self.warmup_trace:
data_warmup[var_name] = np.array(
self.warmup_trace.get_values(var_name, combine=False, squeeze=False)
)
if self.posterior_trace:
data[var_name] = np.array(
self.posterior_trace.get_values(var_name, combine=False, squeeze=False)
)
return (
dict_to_dataset(
data,
library=pymc3,
coords=self.coords,
dims=self.dims,
attrs=self.attrs,
index_origin=self.index_origin,
),
dict_to_dataset(
data_warmup,
library=pymc3,
coords=self.coords,
dims=self.dims,
attrs=self.attrs,
index_origin=self.index_origin,
),
)
@requires("trace")
def sample_stats_to_xarray(self):
"""Extract sample_stats from PyMC3 trace."""
data = {}
rename_key = {
"model_logp": "lp",
"mean_tree_accept": "acceptance_rate",
"depth": "tree_depth",
"tree_size": "n_steps",
}
data = {}
data_warmup = {}
for stat in self.trace.stat_names:
name = rename_key.get(stat, stat)
if name == "tune":
continue
if self.warmup_trace:
data_warmup[name] = np.array(
self.warmup_trace.get_sampler_stats(stat, combine=False)
)
if self.posterior_trace:
data[name] = np.array(self.posterior_trace.get_sampler_stats(stat, combine=False))
return (
dict_to_dataset(
data,
library=pymc3,
dims=None,
coords=self.coords,
attrs=self.attrs,
index_origin=self.index_origin,
),
dict_to_dataset(
data_warmup,
library=pymc3,
dims=None,
coords=self.coords,
attrs=self.attrs,
index_origin=self.index_origin,
),
)
@requires("trace")
@requires("model")
def log_likelihood_to_xarray(self):
"""Extract log likelihood and log_p data from PyMC3 trace."""
if self.predictions or not self.log_likelihood:
return None
data_warmup = {}
data = {}
warn_msg = (
"Could not compute log_likelihood, it will be omitted. "
"Check your model object or set log_likelihood=False"
)
if self.posterior_trace:
try:
data = self._extract_log_likelihood(self.posterior_trace)
except TypeError:
warnings.warn(warn_msg)
if self.warmup_trace:
try:
data_warmup = self._extract_log_likelihood(self.warmup_trace)
except TypeError:
warnings.warn(warn_msg)
return (
dict_to_dataset(
data,
library=pymc3,
dims=self.dims,
coords=self.coords,
skip_event_dims=True,
index_origin=self.index_origin,
),
dict_to_dataset(
data_warmup,
library=pymc3,
dims=self.dims,
coords=self.coords,
skip_event_dims=True,
index_origin=self.index_origin,
),
)

    def translate_posterior_predictive_dict_to_xarray(self, dct) -> xr.Dataset:
        """Take Dict of variables to numpy ndarrays (samples) and translate into dataset."""
        data = {}
        for k, ary in dct.items():
            shape = ary.shape
            if shape[0] == self.nchains and shape[1] == self.ndraws:
                data[k] = ary
            elif shape[0] == self.nchains * self.ndraws:
                data[k] = ary.reshape((self.nchains, self.ndraws, *shape[1:]))
            else:
                data[k] = np.expand_dims(ary, 0)
                # pylint: disable=line-too-long
                _log.warning(
                    "posterior predictive variable %s's shape not compatible with number of chains and draws. "
                    "This can mean that some draws or even whole chains are not represented.",
                    k,
                )
        return dict_to_dataset(
            data, library=pymc3, coords=self.coords, dims=self.dims, index_origin=self.index_origin
        )
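
    # Illustrative sketch of the shape handling above (hypothetical numbers):
    # with nchains=2 and ndraws=500, an array of shape (1000, 3) is reshaped
    # to (2, 500, 3); an array already shaped (2, 500, 3) is kept as-is; any
    # other leading shape is wrapped as a single chain via np.expand_dims and
    # a warning is logged.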
@requires(["posterior_predictive"])
def posterior_predictive_to_xarray(self):
"""Convert posterior_predictive samples to xarray."""
return self.translate_posterior_predictive_dict_to_xarray(self.posterior_predictive)
@requires(["predictions"])
def predictions_to_xarray(self):
"""Convert predictions (out of sample predictions) to xarray."""
return self.translate_posterior_predictive_dict_to_xarray(self.predictions)

    def priors_to_xarray(self):
        """Convert prior samples (and if possible prior predictive too) to xarray."""
        if self.prior is None:
            return {"prior": None, "prior_predictive": None}
        if self.observations is not None:
            prior_predictive_vars = list(self.observations.keys())
            prior_vars = [key for key in self.prior.keys() if key not in prior_predictive_vars]
        else:
            prior_vars = list(self.prior.keys())
            prior_predictive_vars = None

        priors_dict = {}
        for group, var_names in zip(
            ("prior", "prior_predictive"), (prior_vars, prior_predictive_vars)
        ):
            priors_dict[group] = (
                None
                if var_names is None
                else dict_to_dataset(
                    {k: np.expand_dims(self.prior[k], 0) for k in var_names},
                    library=pymc3,
                    coords=self.coords,
                    dims=self.dims,
                    index_origin=self.index_origin,
                )
            )
        return priors_dict
@requires("observations")
@requires("model")
def observed_data_to_xarray(self):
"""Convert observed data to xarray."""
if self.predictions:
return None
return dict_to_dataset(
self.observations,
library=pymc3,
coords=self.coords,
dims=self.dims,
default_dims=[],
index_origin=self.index_origin,
)
@requires(["trace", "predictions"])
@requires("model")
def constant_data_to_xarray(self):
"""Convert constant data to xarray."""
# For constant data, we are concerned only with deterministics and
# data. The constant data vars must be either pm.Data
# (TensorSharedVariable) or pm.Deterministic
constant_data_vars = {} # type: Dict[str, Var]
def is_data(name, var) -> bool:
assert self.model is not None
return (
var not in self.model.deterministics
and var not in self.model.observed_RVs
and var not in self.model.free_RVs
and var not in self.model.potentials
and (self.observations is None or name not in self.observations)
and isinstance(var, (Constant, SharedVariable))
)
# I don't know how to find pm.Data, except that they are named
# variables that aren't observed or free RVs, nor are they
# deterministics, and then we eliminate observations.
for name, var in self.model.named_vars.items():
if is_data(name, var):
constant_data_vars[name] = var
if not constant_data_vars:
return None
constant_data = {}
for name, vals in constant_data_vars.items():
if hasattr(vals, "get_value"):
vals = vals.get_value()
elif hasattr(vals, "data"):
vals = vals.data
constant_data[name] = vals
return dict_to_dataset(
constant_data,
library=pymc3,
coords=self.coords,
dims=self.dims,
default_dims=[],
index_origin=self.index_origin,
)

    def to_inference_data(self):
        """Convert all available data to an InferenceData object.

        Note that if groups cannot be created (e.g., there is no ``trace``, so
        the ``posterior`` and ``sample_stats`` cannot be extracted), then the
        InferenceData will not have those groups.
        """
        id_dict = {
            "posterior": self.posterior_to_xarray(),
            "sample_stats": self.sample_stats_to_xarray(),
            "log_likelihood": self.log_likelihood_to_xarray(),
            "posterior_predictive": self.posterior_predictive_to_xarray(),
            "predictions": self.predictions_to_xarray(),
            **self.priors_to_xarray(),
            "observed_data": self.observed_data_to_xarray(),
        }
        if self.predictions:
            id_dict["predictions_constant_data"] = self.constant_data_to_xarray()
        else:
            id_dict["constant_data"] = self.constant_data_to_xarray()
        return InferenceData(save_warmup=self.save_warmup, **id_dict)


def to_inference_data(
    trace: Optional["MultiTrace"] = None,
    *,
    prior: Optional[Dict[str, Any]] = None,
    posterior_predictive: Optional[Dict[str, Any]] = None,
    log_likelihood: Union[bool, Iterable[str]] = True,
    coords: Optional[CoordSpec] = None,
    dims: Optional[DimSpec] = None,
    model: Optional["Model"] = None,
    save_warmup: Optional[bool] = None,
    density_dist_obs: bool = True,
) -> InferenceData:
    """Convert pymc3 data into an InferenceData object.

    All arguments are optional, but at least one of ``trace``, ``prior``
    and ``posterior_predictive`` must be present.
    For a usage example read the
    :ref:`Creating InferenceData section on from_pymc3 <creating_InferenceData>`

    Parameters
    ----------
    trace : MultiTrace, optional
        Trace generated from MCMC sampling. Output of
        :func:`~pymc3.sampling.sample`.
    prior : dict, optional
        Dictionary with the variable names as keys, and values numpy arrays
        containing prior and prior predictive samples.
    posterior_predictive : dict, optional
        Dictionary with the variable names as keys, and values numpy arrays
        containing posterior predictive samples.
    log_likelihood : bool or array_like of str, optional
        List of variables for which to calculate ``log_likelihood``. Defaults to True,
        which calculates ``log_likelihood`` for all observed variables. If set to False,
        ``log_likelihood`` is skipped.
    coords : dict of {str: array-like}, optional
        Map of coordinate names to coordinate values.
    dims : dict of {str: list of str}, optional
        Map of variable names to the coordinate names to use to index its dimensions.
    model : Model, optional
        Model used to generate ``trace``. It is not necessary to pass ``model`` if
        called inside a ``with model:`` context.
    save_warmup : bool, optional
        Save warmup iterations in the InferenceData object. If not defined, use the
        default defined by the rcParams.
    density_dist_obs : bool, default True
        Store variables passed with the ``observed`` arg to
        :class:`~pymc3.distributions.DensityDist` in the generated InferenceData.

    Returns
    -------
    arviz.InferenceData
    """
    if isinstance(trace, InferenceData):
        return trace

    return InferenceDataConverter(
        trace=trace,
        prior=prior,
        posterior_predictive=posterior_predictive,
        log_likelihood=log_likelihood,
        coords=coords,
        dims=dims,
        model=model,
        save_warmup=save_warmup,
        density_dist_obs=density_dist_obs,
    ).to_inference_data()
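
# A minimal usage sketch (the model and its variable are hypothetical; this
# mirrors the intended call pattern rather than a tested example):
#
#     with pymc3.Model():
#         mu = pymc3.Normal("mu", 0.0, 1.0)
#         trace = pymc3.sample(return_inferencedata=False)
#         idata = to_inference_data(trace)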


### Later I could have this return ``None`` if the ``idata_orig`` argument is supplied. But
### perhaps we should have an inplace argument?
def predictions_to_inference_data(
    predictions,
    posterior_trace: Optional["MultiTrace"] = None,
    model: Optional["Model"] = None,
    coords: Optional[CoordSpec] = None,
    dims: Optional[DimSpec] = None,
    idata_orig: Optional[InferenceData] = None,
    inplace: bool = False,
) -> InferenceData:
    """Translate out-of-sample predictions into ``InferenceData``.

    Parameters
    ----------
    predictions : Dict[str, np.ndarray]
        The predictions are the return value of :func:`~pymc3.sample_posterior_predictive`,
        a dictionary of strings (variable names) to numpy ndarrays (draws).
    posterior_trace : MultiTrace
        This should be a trace that has been thinned appropriately for
        ``pymc3.sample_posterior_predictive``. Specifically, any variable whose shape is
        a deterministic function of the shape of any predictor (explanatory, independent, etc.)
        variables must be *removed* from this trace.
    model : Model
        The pymc3 model. It can be omitted if within a model context.
    coords : Dict[str, array-like[Any]]
        Coordinates for the variables. Map from coordinate names to coordinate values.
    dims : Dict[str, array-like[str]]
        Map from variable name to ordered set of coordinate names.
    idata_orig : InferenceData, optional
        If supplied, then modify this inference data in place, adding ``predictions`` and
        (if available) ``predictions_constant_data`` groups. If this is not supplied, make a
        fresh InferenceData.
    inplace : boolean, optional
        If ``idata_orig`` is supplied and ``inplace`` is True, merge the predictions into
        ``idata_orig``, rather than returning a fresh InferenceData object.

    Returns
    -------
    InferenceData:
        May be a modified ``idata_orig``.
    """
    if inplace and not idata_orig:
        raise ValueError(
            "Do not pass True for inplace unless passing "
            "an existing InferenceData as idata_orig"
        )
    new_idata = InferenceDataConverter(
        trace=posterior_trace,
        predictions=predictions,
        model=model,
        coords=coords,
        dims=dims,
        log_likelihood=False,
    ).to_inference_data()
    if idata_orig is None:
        return new_idata
    elif inplace:
        concat([idata_orig, new_idata], dim=None, inplace=True)
        return idata_orig
    else:
        # if we are not returning in place, then merge the old groups into the new inference
        # data and return that.
        concat([new_idata, idata_orig], dim=None, copy=True, inplace=True)
        return new_idata
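
# A minimal usage sketch (``model``, ``thinned_trace`` and ``idata`` are
# hypothetical objects produced elsewhere; shown only to illustrate the
# in-place merge path):
#
#     with model:
#         preds = pymc3.sample_posterior_predictive(thinned_trace)
#     idata = predictions_to_inference_data(
#         preds, posterior_trace=thinned_trace, idata_orig=idata, inplace=True
#     )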