12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
14
from collections .abc import Sequence
15
- from typing import Literal
15
+ from typing import Any , Literal
16
16
17
17
from arviz import InferenceData
18
18
from xarray import Dataset
@@ -36,6 +36,7 @@ def compute_log_likelihood(
36
36
model : Model | None = None ,
37
37
sample_dims : Sequence [str ] = ("chain" , "draw" ),
38
38
progressbar = True ,
39
+ compile_kwargs : dict [str , Any ] | None = None ,
39
40
):
40
41
"""Compute elemwise log_likelihood of model given InferenceData with posterior group
41
42
@@ -51,6 +52,8 @@ def compute_log_likelihood(
51
52
model : Model, optional
52
53
sample_dims : sequence of str, default ("chain", "draw")
53
54
progressbar : bool, default True
55
+ compile_kwargs : dict[str, Any] | None
56
+ Extra compilation arguments to supply to :py:func:`~pymc.stats.compute_log_density`
54
57
55
58
Returns
56
59
-------
@@ -65,6 +68,7 @@ def compute_log_likelihood(
65
68
kind = "likelihood" ,
66
69
sample_dims = sample_dims ,
67
70
progressbar = progressbar ,
71
+ compile_kwargs = compile_kwargs ,
68
72
)
69
73
70
74
@@ -75,6 +79,7 @@ def compute_log_prior(
75
79
model : Model | None = None ,
76
80
sample_dims : Sequence [str ] = ("chain" , "draw" ),
77
81
progressbar = True ,
82
+ compile_kwargs = None ,
78
83
):
79
84
"""Compute elemwise log_prior of model given InferenceData with posterior group
80
85
@@ -90,6 +95,8 @@ def compute_log_prior(
90
95
model : Model, optional
91
96
sample_dims : sequence of str, default ("chain", "draw")
92
97
progressbar : bool, default True
98
+ compile_kwargs : dict[str, Any] | None
99
+ Extra compilation arguments to supply to :py:func:`~pymc.stats.compute_log_density`
93
100
94
101
Returns
95
102
-------
@@ -104,6 +111,7 @@ def compute_log_prior(
104
111
kind = "prior" ,
105
112
sample_dims = sample_dims ,
106
113
progressbar = progressbar ,
114
+ compile_kwargs = compile_kwargs ,
107
115
)
108
116
109
117
@@ -116,14 +124,42 @@ def compute_log_density(
116
124
kind : Literal ["likelihood" , "prior" ] = "likelihood" ,
117
125
sample_dims : Sequence [str ] = ("chain" , "draw" ),
118
126
progressbar = True ,
127
+ compile_kwargs = None ,
119
128
) -> InferenceData | Dataset :
120
129
"""
121
130
Compute elemwise log_likelihood or log_prior of model given InferenceData with posterior group
131
+
132
+ Parameters
133
+ ----------
134
+ idata : InferenceData
135
+ InferenceData with posterior group
136
+ var_names : sequence of str, optional
137
+ List of Observed variable names for which to compute log_prior.
138
+ Defaults to all all free variables.
139
+ extend_inferencedata : bool, default True
140
+ Whether to extend the original InferenceData or return a new one
141
+ model : Model, optional
142
+ kind: Literal["likelihood", "prior"]
143
+ Whether to compute the log density of the observed random variables (likelihood)
144
+ or to compute the log density of the latent random variables (prior). This
145
+ parameter determines the group that gets added to the returned `~arviz.InferenceData` object.
146
+ sample_dims : sequence of str, default ("chain", "draw")
147
+ progressbar : bool, default True
148
+ compile_kwargs : dict[str, Any] | None
149
+ Extra compilation arguments to supply to :py:func:`pymc.model.core.Model.compile_fn`
150
+
151
+ Returns
152
+ -------
153
+ idata : InferenceData
154
+ InferenceData with the ``log_likelihood`` group when ``kind == "likelihood"``
155
+ or the ``log_prior`` group when ``kind == "prior"``.
122
156
"""
123
157
124
158
posterior = idata ["posterior" ]
125
159
126
160
model = modelcontext (model )
161
+ if compile_kwargs is None :
162
+ compile_kwargs = {}
127
163
128
164
if kind not in ("likelihood" , "prior" ):
129
165
raise ValueError ("kind must be either 'likelihood' or 'prior'" )
@@ -150,6 +186,7 @@ def compute_log_density(
150
186
inputs = umodel .value_vars ,
151
187
outs = umodel .logp (vars = vars , sum = False ),
152
188
on_unused_input = "ignore" ,
189
+ ** compile_kwargs ,
153
190
)
154
191
155
192
coords , dims = coords_and_dims_for_inferencedata (umodel )
0 commit comments