Skip to content

Commit 0680957

Browse files
lucianopazricardoV94
authored andcommitted
Add compile_kwargs to compute_log_density functions
1 parent 064822a commit 0680957

File tree

2 files changed

+54
-1
lines changed

2 files changed

+54
-1
lines changed

pymc/stats/log_density.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
from collections.abc import Sequence
15-
from typing import Literal
15+
from typing import Any, Literal
1616

1717
from arviz import InferenceData
1818
from xarray import Dataset
@@ -36,6 +36,7 @@ def compute_log_likelihood(
3636
model: Model | None = None,
3737
sample_dims: Sequence[str] = ("chain", "draw"),
3838
progressbar=True,
39+
compile_kwargs: dict[str, Any] | None = None,
3940
):
4041
"""Compute elemwise log_likelihood of model given InferenceData with posterior group
4142
@@ -51,6 +52,8 @@ def compute_log_likelihood(
5152
model : Model, optional
5253
sample_dims : sequence of str, default ("chain", "draw")
5354
progressbar : bool, default True
55+
compile_kwargs : dict[str, Any] | None
56+
Extra compilation arguments to supply to :py:func:`~pymc.stats.compute_log_density`
5457
5558
Returns
5659
-------
@@ -65,6 +68,7 @@ def compute_log_likelihood(
6568
kind="likelihood",
6669
sample_dims=sample_dims,
6770
progressbar=progressbar,
71+
compile_kwargs=compile_kwargs,
6872
)
6973

7074

@@ -75,6 +79,7 @@ def compute_log_prior(
7579
model: Model | None = None,
7680
sample_dims: Sequence[str] = ("chain", "draw"),
7781
progressbar=True,
82+
compile_kwargs=None,
7883
):
7984
"""Compute elemwise log_prior of model given InferenceData with posterior group
8085
@@ -90,6 +95,8 @@ def compute_log_prior(
9095
model : Model, optional
9196
sample_dims : sequence of str, default ("chain", "draw")
9297
progressbar : bool, default True
98+
compile_kwargs : dict[str, Any] | None
99+
Extra compilation arguments to supply to :py:func:`~pymc.stats.compute_log_density`
93100
94101
Returns
95102
-------
@@ -104,6 +111,7 @@ def compute_log_prior(
104111
kind="prior",
105112
sample_dims=sample_dims,
106113
progressbar=progressbar,
114+
compile_kwargs=compile_kwargs,
107115
)
108116

109117

@@ -116,14 +124,42 @@ def compute_log_density(
116124
kind: Literal["likelihood", "prior"] = "likelihood",
117125
sample_dims: Sequence[str] = ("chain", "draw"),
118126
progressbar=True,
127+
compile_kwargs=None,
119128
) -> InferenceData | Dataset:
120129
"""
121130
Compute elemwise log_likelihood or log_prior of model given InferenceData with posterior group
131+
132+
Parameters
133+
----------
134+
idata : InferenceData
135+
InferenceData with posterior group
136+
var_names : sequence of str, optional
137+
List of Observed variable names for which to compute log_prior.
138+
Defaults to all all free variables.
139+
extend_inferencedata : bool, default True
140+
Whether to extend the original InferenceData or return a new one
141+
model : Model, optional
142+
kind: Literal["likelihood", "prior"]
143+
Whether to compute the log density of the observed random variables (likelihood)
144+
or to compute the log density of the latent random variables (prior). This
145+
parameter determines the group that gets added to the returned `~arviz.InferenceData` object.
146+
sample_dims : sequence of str, default ("chain", "draw")
147+
progressbar : bool, default True
148+
compile_kwargs : dict[str, Any] | None
149+
Extra compilation arguments to supply to :py:func:`pymc.model.core.Model.compile_fn`
150+
151+
Returns
152+
-------
153+
idata : InferenceData
154+
InferenceData with the ``log_likelihood`` group when ``kind == "likelihood"``
155+
or the ``log_prior`` group when ``kind == "prior"``.
122156
"""
123157

124158
posterior = idata["posterior"]
125159

126160
model = modelcontext(model)
161+
if compile_kwargs is None:
162+
compile_kwargs = {}
127163

128164
if kind not in ("likelihood", "prior"):
129165
raise ValueError("kind must be either 'likelihood' or 'prior'")
@@ -150,6 +186,7 @@ def compute_log_density(
150186
inputs=umodel.value_vars,
151187
outs=umodel.logp(vars=vars, sum=False),
152188
on_unused_input="ignore",
189+
**compile_kwargs,
153190
)
154191

155192
coords, dims = coords_and_dims_for_inferencedata(umodel)

tests/stats/test_log_density.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
from unittest.mock import patch
15+
1416
import numpy as np
1517
import pytest
1618
import scipy.stats as st
@@ -174,3 +176,17 @@ def test_deterministic_log_prior(self):
174176
res.log_prior["x"].values,
175177
st.norm(0, 1).logpdf(idata.posterior["x"].values),
176178
)
179+
180+
def test_compilation_kwargs(self):
181+
with Model() as m:
182+
x = Normal("x")
183+
Deterministic("d", 2 * x)
184+
Normal("y", x, observed=[0, 1, 2])
185+
186+
idata = InferenceData(posterior=dict_to_dataset({"x": np.arange(100).reshape(4, 25)}))
187+
with patch("pymc.model.core.compile_pymc") as patched_compile_pymc:
188+
compute_log_prior(idata, compile_kwargs={"mode": "JAX"})
189+
compute_log_likelihood(idata, compile_kwargs={"mode": "NUMBA"})
190+
assert len(patched_compile_pymc.call_args_list) == 2
191+
assert patched_compile_pymc.call_args_list[0].kwargs["mode"] == "JAX"
192+
assert patched_compile_pymc.call_args_list[1].kwargs["mode"] == "NUMBA"

0 commit comments

Comments
 (0)