Skip to content

Commit 5711f96

Browse files
committed
Add function that caches sampling results
1 parent 150fb0f commit 5711f96

File tree

3 files changed

+210
-0
lines changed

3 files changed

+210
-0
lines changed

docs/api_reference.rst

+1
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ Utils
4949

5050
spline.bspline_interpolation
5151
prior.prior_from_idata
52+
cache.cache_sampling
5253

5354

5455
Statespace Models
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
import os
2+
3+
import pymc as pm
4+
5+
from pymc_experimental.utils.cache import cache_sampling
6+
7+
8+
def test_cache_sampling(tmpdir):
9+
10+
with pm.Model() as m:
11+
x = pm.Normal("x", 0, 1)
12+
y = pm.Normal("y", mu=x, observed=[0, 1, 2])
13+
14+
cache_prior = cache_sampling(pm.sample_prior_predictive, path=tmpdir)
15+
cache_post = cache_sampling(pm.sample, path=tmpdir)
16+
cache_pred = cache_sampling(pm.sample_posterior_predictive, path=tmpdir)
17+
assert len(os.listdir(tmpdir)) == 0
18+
19+
prior1, prior2 = (cache_prior(samples=5) for _ in range(2))
20+
assert len(os.listdir(tmpdir)) == 1
21+
assert prior1.prior["x"].mean() == prior2.prior["x"].mean()
22+
23+
post1, post2 = (cache_post(tune=5, draws=5, progressbar=False) for _ in range(2))
24+
assert len(os.listdir(tmpdir)) == 2
25+
assert post1.posterior["x"].mean() == post2.posterior["x"].mean()
26+
27+
# Change model
28+
with pm.Model() as m:
29+
x = pm.Normal("x", 0, 1)
30+
y = pm.Normal("y", mu=x, observed=[0, 1, 2, 3])
31+
32+
post3 = cache_post(tune=5, draws=5, progressbar=False)
33+
assert len(os.listdir(tmpdir)) == 3
34+
assert post3.posterior["x"].mean() != post1.posterior["x"].mean()
35+
36+
pred1, pred2 = (cache_pred(trace=post3, progressbar=False) for _ in range(2))
37+
assert len(os.listdir(tmpdir)) == 4
38+
assert pred1.posterior_predictive["y"].mean() == pred2.posterior_predictive["y"].mean()
39+
assert "x" not in pred1.posterior_predictive
40+
41+
# Change kwargs
42+
pred3 = cache_pred(trace=post3, progressbar=False, var_names=["x"])
43+
assert len(os.listdir(tmpdir)) == 5
44+
assert "x" in pred3.posterior_predictive

pymc_experimental/utils/cache.py

+165
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
import hashlib
2+
import os
3+
import sys
4+
from typing import Literal
5+
6+
import arviz as az
7+
import numpy as np
8+
from pymc import (
9+
modelcontext,
10+
sample,
11+
sample_posterior_predictive,
12+
sample_prior_predictive,
13+
)
14+
from pymc.model.fgraph import fgraph_from_model
15+
from pytensor.compile import SharedVariable
16+
from pytensor.graph import Constant, FunctionGraph
17+
from pytensor.scalar import ScalarType
18+
from pytensor.tensor import TensorType
19+
from pytensor.tensor.random.type import RandomType
20+
from pytensor.tensor.type_other import NoneTypeT
21+
22+
23+
def hash_data(c):
24+
if isinstance(c.type, NoneTypeT):
25+
return ""
26+
if isinstance(c.type, (ScalarType, TensorType)):
27+
if isinstance(c, Constant):
28+
arr = c.data
29+
elif isinstance(c, SharedVariable):
30+
arr = c.get_value(borrow=True)
31+
arr_data = arr.view(np.uint8) if arr.size > 1 else arr.tobytes()
32+
return hashlib.sha1(arr_data).hexdigest()
33+
else:
34+
raise NotImplementedError(f"Hashing not implemented for type {c.type}")
35+
36+
37+
def get_name_and_props(obj):
38+
name = str(obj)
39+
props = str(getattr(obj, "_props", lambda: {})())
40+
return name, props
41+
42+
43+
def hash_from_fg(fg: FunctionGraph) -> int:
44+
objects_to_hash = []
45+
for node in fg.toposort():
46+
objects_to_hash.append(
47+
(
48+
get_name_and_props(node.op),
49+
tuple(get_name_and_props(inp.type) for inp in node.inputs),
50+
tuple(get_name_and_props(out.type) for out in node.outputs),
51+
# Name is not a symbolic input in the fgraph representation, maybe it should?
52+
tuple(inp.name for inp in node.inputs if inp.name),
53+
tuple(out.name for out in node.outputs if out.name),
54+
)
55+
)
56+
objects_to_hash.append(
57+
tuple(
58+
hash_data(c)
59+
for c in node.inputs
60+
if (
61+
isinstance(c, (Constant, SharedVariable))
62+
# Ignore RNG values
63+
and not isinstance(c.type, RandomType)
64+
)
65+
)
66+
)
67+
str_hash = "\n".join(map(str, objects_to_hash))
68+
return hashlib.sha1(str_hash.encode()).hexdigest()
69+
70+
71+
def cache_sampling(
72+
sampling_fn: Literal[sample, sample_prior_predictive, sample_posterior_predictive],
73+
path: str = "",
74+
force_sample: bool = False,
75+
):
76+
"""Cache the result of PyMC sampling.
77+
78+
Parameter
79+
---------
80+
sampling_fn: Callable
81+
Must be one of `pymc.sample`, `pymc.sample_prior_predictive` or `pymc.sample_posterior_predictive`.
82+
Positional arguments are disallowed.
83+
path: string, Optional
84+
The path where the results should be saved or retrieved from. Defaults to working directory.
85+
force_sample: bool, Optional
86+
Whether to force sampling even if cache is found. Defaults to False.
87+
88+
Returns
89+
-------
90+
cached_sampling_fn: Callable
91+
Function that wraps the sampling_fn. When called, the wrapped function will look for a valid cached result.
92+
A valid cache requires the same:
93+
1. Model and data
94+
2. Sampling function
95+
3. Sampling kwargs, ignoring ``random_seed``, ``trace``, ``progressbar``, ``extend_inferencedata`` and ``compile_kwargs``.
96+
If o valid cache is found, sampling is bypassed altogether, unless ``force_sample=True``.
97+
Otherwise, sampling is performed and the result cached for future reuse.
98+
Caching is done on the basis of SHA-1 hashing, and there could be unlikely false positives.
99+
100+
101+
Examples
102+
--------
103+
104+
.. code-block:: python
105+
106+
import pymc as pm
107+
from pymc_experimental.utils import cache_sampling
108+
109+
with pm.Model() as m:
110+
x = pm.Normal("x", 0, 1)
111+
y = pm.Normal("y", mu=x, observed=[0, 1, 2])
112+
113+
idata = cache_sampling(pm.sample)()
114+
115+
with m:
116+
idata = cache_sampling(pm.sample)() # Cache hit! Returning stored result
117+
118+
"""
119+
allowed_fns = (sample, sample_prior_predictive, sample_posterior_predictive)
120+
if sampling_fn not in allowed_fns:
121+
raise ValueError(f"Cache sampling can only be used with {allowed_fns}")
122+
123+
def wrapped_sampling_fn(*args, model=None, random_seed=None, **kwargs):
124+
if args:
125+
raise ValueError("Non-keyword arguments not allowed in cache_sampling")
126+
127+
extend_inferencedata = kwargs.pop("extend_inferencedata", False)
128+
129+
# Model hash
130+
model = modelcontext(model)
131+
fg, _ = fgraph_from_model(model)
132+
model_hash = hash_from_fg(fg)
133+
134+
# Sampling hash
135+
sampling_hash_kwargs = kwargs.copy()
136+
sampling_hash_kwargs["sampling_fn"] = str(sampling_fn)
137+
sampling_hash_kwargs.pop("trace", None)
138+
sampling_hash_kwargs.pop("random_seed", None)
139+
sampling_hash_kwargs.pop("progressbar", None)
140+
sampling_hash_kwargs.pop("compile_kwargs", None)
141+
sampling_hash = str(sampling_hash_kwargs)
142+
143+
file_name = hashlib.sha1((model_hash + sampling_hash).encode()).hexdigest() + ".nc"
144+
file_path = os.path.join(path, file_name)
145+
146+
if not force_sample and os.path.exists(file_path):
147+
print("Cache hit! Returning stored result", file=sys.stdout)
148+
idata_out = az.from_netcdf(file_path)
149+
150+
else:
151+
idata_out = sampling_fn(*args, **kwargs, model=model, random_seed=random_seed)
152+
153+
if os.path.exists(file_path):
154+
os.remove(file_path)
155+
az.to_netcdf(idata_out, file_path)
156+
157+
# We save inferencedata separately and extend if needed
158+
if extend_inferencedata:
159+
trace = kwargs["trace"]
160+
trace.extend(idata_out)
161+
idata_out = trace
162+
163+
return idata_out
164+
165+
return wrapped_sampling_fn

0 commit comments

Comments
 (0)