Commit 5b20bdf

Add function that caches sampling results

1 parent 150fb0f commit 5b20bdf

3 files changed (+219, -0 lines)


Diff for: docs/api_reference.rst (+1)

@@ -49,6 +49,7 @@ Utils

    spline.bspline_interpolation
    prior.prior_from_idata
+   cache.cache_sampling


Statespace Models

Diff for: pymc_experimental/tests/utils/test_cache.py (new file, +44 lines)

import os

import pymc as pm

from pymc_experimental.utils.cache import cache_sampling


def test_cache_sampling(tmpdir):

    with pm.Model() as m:
        x = pm.Normal("x", 0, 1)
        y = pm.Normal("y", mu=x, observed=[0, 1, 2])

        cache_prior = cache_sampling(pm.sample_prior_predictive, path=tmpdir)
        cache_post = cache_sampling(pm.sample, path=tmpdir)
        cache_pred = cache_sampling(pm.sample_posterior_predictive, path=tmpdir)
        assert len(os.listdir(tmpdir)) == 0

        prior1, prior2 = (cache_prior(samples=5) for _ in range(2))
        assert len(os.listdir(tmpdir)) == 1
        assert prior1.prior["x"].mean() == prior2.prior["x"].mean()

        post1, post2 = (cache_post(tune=5, draws=5, progressbar=False) for _ in range(2))
        assert len(os.listdir(tmpdir)) == 2
        assert post1.posterior["x"].mean() == post2.posterior["x"].mean()

    # Change model
    with pm.Model() as m:
        x = pm.Normal("x", 0, 1)
        y = pm.Normal("y", mu=x, observed=[0, 1, 2, 3])

        post3 = cache_post(tune=5, draws=5, progressbar=False)
        assert len(os.listdir(tmpdir)) == 3
        assert post3.posterior["x"].mean() != post1.posterior["x"].mean()

        pred1, pred2 = (cache_pred(trace=post3, progressbar=False) for _ in range(2))
        assert len(os.listdir(tmpdir)) == 4
        assert pred1.posterior_predictive["y"].mean() == pred2.posterior_predictive["y"].mean()
        assert "x" not in pred1.posterior_predictive

        # Change kwargs
        pred3 = cache_pred(trace=post3, progressbar=False, var_names=["x"])
        assert len(os.listdir(tmpdir)) == 5
        assert "x" in pred3.posterior_predictive

Diff for: pymc_experimental/utils/cache.py (new file, +174 lines)

import hashlib
import os
import sys
from typing import Callable, Literal

import arviz as az
import numpy as np
from pymc import (
    modelcontext,
    sample,
    sample_posterior_predictive,
    sample_prior_predictive,
)
from pymc.model.fgraph import fgraph_from_model
from pytensor.compile import SharedVariable
from pytensor.graph import Constant, FunctionGraph, Variable
from pytensor.scalar import ScalarType
from pytensor.tensor import TensorType
from pytensor.tensor.random.type import RandomType
from pytensor.tensor.type_other import NoneTypeT


def hash_data(c: Variable) -> str:
    if isinstance(c.type, NoneTypeT):
        return "None"
    if isinstance(c.type, (ScalarType, TensorType)):
        if isinstance(c, Constant):
            arr = c.data
        elif isinstance(c, SharedVariable):
            arr = c.get_value(borrow=True)
        arr_data = arr.view(np.uint8) if arr.size > 1 else arr.tobytes()
        return hashlib.sha1(arr_data).hexdigest()
    else:
        raise NotImplementedError(f"Hashing not implemented for type {c.type}")


def get_name_and_props(obj):
    name = str(obj)
    props = str(getattr(obj, "_props", lambda: {})())
    return name, props


def hash_from_fg(fg: FunctionGraph) -> str:
    objects_to_hash = []
    for node in fg.toposort():
        objects_to_hash.append(
            (
                get_name_and_props(node.op),
                tuple(get_name_and_props(inp.type) for inp in node.inputs),
                tuple(get_name_and_props(out.type) for out in node.outputs),
                # Name is not a symbolic input in the fgraph representation, maybe it should?
                tuple(inp.name for inp in node.inputs if inp.name),
                tuple(out.name for out in node.outputs if out.name),
            )
        )
        objects_to_hash.append(
            tuple(
                hash_data(c)
                for c in node.inputs
                if (
                    isinstance(c, (Constant, SharedVariable))
                    # Ignore RNG values
                    and not isinstance(c.type, RandomType)
                )
            )
        )
    str_hash = "\n".join(map(str, objects_to_hash))
    return hashlib.sha1(str_hash.encode()).hexdigest()


def cache_sampling(
    sampling_fn: Literal[sample, sample_prior_predictive, sample_posterior_predictive],
    path: str = "",
    force_sample: bool = False,
) -> Callable:
    """Cache the result of PyMC sampling.

    Parameters
    ----------
    sampling_fn : Callable
        Must be one of `pymc.sample`, `pymc.sample_prior_predictive` or
        `pymc.sample_posterior_predictive`. Positional arguments are disallowed.
    path : str, optional
        The path where the results should be saved or retrieved from. Defaults to the working directory.
    force_sample : bool, optional
        Whether to force sampling even if a cache is found. Defaults to False.

    Returns
    -------
    cached_sampling_fn : Callable
        Function that wraps ``sampling_fn``. When called, the wrapped function looks for a valid cached result.
        A valid cache requires the same:

        1. Model and data
        2. Sampling function
        3. Sampling kwargs, ignoring ``random_seed``, ``trace``, ``progressbar``, ``extend_inferencedata`` and ``compile_kwargs``.

        If a valid cache is found, sampling is bypassed altogether, unless ``force_sample=True``.
        Otherwise, sampling is performed and the result is cached for future reuse.
        Caching is done on the basis of SHA-1 hashing, so false positives are possible, although unlikely.

    Examples
    --------

    .. code-block:: python

        import pymc as pm
        from pymc_experimental.utils.cache import cache_sampling

        with pm.Model() as m:
            y_data = pm.MutableData("y_data", [0, 1, 2])
            x = pm.Normal("x", 0, 1)
            y = pm.Normal("y", mu=x, observed=y_data)

            cache_sample = cache_sampling(pm.sample, path="data")
            idata1 = cache_sample(chains=2)

            # Cache hit! Returning stored result
            idata2 = cache_sample(chains=2)

            pm.set_data({"y_data": [1, 1, 1]})
            idata3 = cache_sample(chains=2)

            assert idata1.posterior["x"].mean() == idata2.posterior["x"].mean()
            assert idata1.posterior["x"].mean() != idata3.posterior["x"].mean()

    """
    allowed_fns = (sample, sample_prior_predictive, sample_posterior_predictive)
    if sampling_fn not in allowed_fns:
        raise ValueError(f"Cache sampling can only be used with {allowed_fns}")

    def wrapped_sampling_fn(*args, model=None, random_seed=None, **kwargs):
        if args:
            raise ValueError("Non-keyword arguments not allowed in cache_sampling")

        extend_inferencedata = kwargs.pop("extend_inferencedata", False)

        # Model hash
        model = modelcontext(model)
        fg, _ = fgraph_from_model(model)
        model_hash = hash_from_fg(fg)

        # Sampling hash
        sampling_hash_kwargs = kwargs.copy()
        sampling_hash_kwargs["sampling_fn"] = str(sampling_fn)
        sampling_hash_kwargs.pop("trace", None)
        sampling_hash_kwargs.pop("random_seed", None)
        sampling_hash_kwargs.pop("progressbar", None)
        sampling_hash_kwargs.pop("compile_kwargs", None)
        sampling_hash = str(sampling_hash_kwargs)

        file_name = hashlib.sha1((model_hash + sampling_hash).encode()).hexdigest() + ".nc"
        file_path = os.path.join(path, file_name)

        if not force_sample and os.path.exists(file_path):
            print("Cache hit! Returning stored result", file=sys.stdout)
            idata_out = az.from_netcdf(file_path)

        else:
            idata_out = sampling_fn(*args, **kwargs, model=model, random_seed=random_seed)
            if os.path.exists(file_path):
                os.remove(file_path)
            if not os.path.exists(path):
                os.mkdir(path)
            az.to_netcdf(idata_out, file_path)

        # We save the InferenceData separately and extend it only if requested
        if extend_inferencedata:
            trace = kwargs["trace"]
            trace.extend(idata_out)
            idata_out = trace

        return idata_out

    return wrapped_sampling_fn

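For context, here is a minimal sketch (not part of the commit) of how the model hash reacts to a data change, using the fgraph_from_model and hash_from_fg helpers shown above; the exact hash values do not matter, only that they differ.

import pymc as pm
from pymc.model.fgraph import fgraph_from_model

from pymc_experimental.utils.cache import hash_from_fg

with pm.Model() as m:
    y_data = pm.MutableData("y_data", [0, 1, 2])
    x = pm.Normal("x", 0, 1)
    y = pm.Normal("y", mu=x, observed=y_data)

# Hash the model graph, ignoring RNG values, just as cache_sampling does internally
fg, _ = fgraph_from_model(m)
hash_before = hash_from_fg(fg)

# Changing the observed data changes the model hash, so a cached call would re-sample
pm.set_data({"y_data": [1, 1, 1]}, model=m)
fg, _ = fgraph_from_model(m)
hash_after = hash_from_fg(fg)
assert hash_before != hash_after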
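A short usage sketch of the new helper with posterior predictive sampling, mirroring what the test above exercises; the "cache_dir" path and the tune/draws values are illustrative choices, not part of the commit.

import pymc as pm

from pymc_experimental.utils.cache import cache_sampling

with pm.Model() as m:
    x = pm.Normal("x", 0, 1)
    y = pm.Normal("y", mu=x, observed=[0, 1, 2])

    cache_post = cache_sampling(pm.sample, path="cache_dir")
    cache_pred = cache_sampling(pm.sample_posterior_predictive, path="cache_dir")

    idata = cache_post(tune=100, draws=100)  # samples and writes <hash>.nc into cache_dir
    idata = cache_post(tune=100, draws=100)  # cache hit: the stored result is loaded

    pred = cache_pred(trace=idata)  # samples and caches; `trace` itself is not hashed
    pred = cache_pred(trace=idata)  # cache hit
    # A kwarg that is part of the hash (e.g. var_names) triggers fresh sampling
    pred_x = cache_pred(trace=idata, var_names=["x"])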