Skip to content

Commit 53fcdf9

Browse files
committed
Add example predictive model transforms
1 parent dd3c44d commit 53fcdf9

File tree

2 files changed

+187
-0
lines changed

2 files changed

+187
-0
lines changed

docs/api_reference.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ Model Transformations
4444

4545
conditioning.do
4646
conditioning.observe
47+
predict.forecast_timeseries
48+
predict.uncensor
4749

4850

4951
Utils
Lines changed: 185 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,185 @@
1+
from typing import List, Optional, Sequence, Union
2+
3+
import pytensor.tensor as pt
4+
from pymc import DiracDelta
5+
from pymc.distributions.censored import CensoredRV
6+
from pymc.distributions.timeseries import AR, AutoRegressiveRV
7+
from pymc.model import Model
8+
from pytensor.graph.basic import Variable
9+
10+
from pymc_experimental.utils.model_fgraph import (
11+
ModelFreeRV,
12+
ModelValuedVar,
13+
fgraph_from_model,
14+
model_free_rv,
15+
model_from_fgraph,
16+
toposort_replace,
17+
)
18+
19+
__all__ = (
20+
"uncensor",
21+
"forecast_timeseries",
22+
)
23+
24+
25+
ModelVariable = Union[Variable, str]
26+
SequenceModelVariables = Union[ModelVariable, Sequence[ModelVariable]]
27+
28+
29+
def parse_vars(model: Model, vars: SequenceModelVariables) -> List[Variable]:
30+
if not isinstance(vars, (list, tuple)):
31+
vars = (vars,)
32+
return [model[var] if isinstance(var, str) else var for var in vars]
33+
34+
35+
def uncensor(model: Model, vars: Optional[SequenceModelVariables] = None) -> Model:
36+
"""Replace censored variables in the model by uncensored ones.
37+
38+
.. code-block:: python
39+
40+
import pymc as pm
41+
from pymc_experimental.model_transform.predict import uncensor
42+
43+
with pm.Model() as model:
44+
x = pm.Normal("x")
45+
dist_raw = pm.Normal.dist(x, sigma=10)
46+
y = pm.Censored("y", dist=dist_raw, lower=0, upper=10, observed=[0, 5, 10])
47+
trace = pm.sample()
48+
49+
with uncensor(model):
50+
pp = pm.sample_posterior_predictive(trace, var_names=["y"])
51+
52+
53+
Parameters
54+
----------
55+
model: Model
56+
vars: optional
57+
Model variables that should be replaced by uncensored counterparts.
58+
Defaults to all censored variables.
59+
60+
Returns
61+
-------
62+
uncensored_model: Model
63+
Model with the censored variables replaced by uncensored versions
64+
65+
"""
66+
vars = parse_vars(model, vars) if vars is not None else []
67+
68+
fgraph, memo = fgraph_from_model(model)
69+
70+
target_vars = {memo[var] for var in vars}
71+
replacements = {}
72+
for node in fgraph.apply_nodes:
73+
if not isinstance(node.op, ModelValuedVar):
74+
continue
75+
76+
dummy_rv = node.outputs[0]
77+
if target_vars and dummy_rv not in target_vars:
78+
continue
79+
80+
rv, value, *dims = node.inputs
81+
if not isinstance(rv.owner.op, (CensoredRV,)):
82+
if target_vars:
83+
raise NotImplementedError(f"RV distribution {rv.owner.op} is not censored")
84+
else:
85+
continue
86+
87+
# The first argument is the `dist` RV
88+
new_rv = rv.owner.inputs[0]
89+
90+
new_rv.name = rv.name
91+
new_dummy_rv = model_free_rv(new_rv, new_rv.type(), None, *dims)
92+
replacements[dummy_rv] = new_dummy_rv
93+
94+
toposort_replace(fgraph, tuple(replacements.items()))
95+
return model_from_fgraph(fgraph)
96+
97+
98+
def forecast_timeseries(
99+
model: Model,
100+
vars: Optional[SequenceModelVariables] = None,
101+
*,
102+
steps: Optional[int] = None,
103+
) -> Model:
104+
"""Replace timeseries variables in the model by forecast that start at the last value.
105+
106+
.. code-block:: python
107+
108+
import pymc as pm
109+
from pymc_experimental.model_transform.predict import forecast_timeseries
110+
111+
with pm.Model() as model:
112+
rho = pm.Normal("rho")
113+
sigma = pm.HalfNormal("sigma")
114+
init_dist = pm.Normal.dist()
115+
y = pm.AR("y", init_dist=init_dist, rho=rho, sigma=sigma, observed=[0] * 100)
116+
trace = pm.sample()
117+
118+
with forecast_timeseries(model, steps=20):
119+
pp = pm.sample_posterior_predictive(trace, var_names=["y"], predictions=True)
120+
121+
122+
123+
Parameters
124+
----------
125+
model: Model
126+
vars: optional
127+
Model variables that should be replaced by forecast counterparts.
128+
Defaults to all timeseries variables.
129+
steps: int, optional
130+
Number of steps for the forecast. Defaults to the same as originally
131+
132+
Returns
133+
-------
134+
forecast_model: Model
135+
Model with the timeseries variables replaced by the forecast versions
136+
137+
"""
138+
vars = parse_vars(model, vars) if vars is not None else []
139+
140+
if steps is not None:
141+
steps = pt.as_tensor_variable(steps, dtype=int)
142+
143+
fgraph, memo = fgraph_from_model(model)
144+
145+
target_vars = {memo[var] for var in vars}
146+
replacements = {}
147+
for node in fgraph.apply_nodes:
148+
149+
if not isinstance(node.op, ModelValuedVar):
150+
continue
151+
152+
dummy_rv = node.outputs[0]
153+
if target_vars and dummy_rv not in target_vars:
154+
continue
155+
156+
rv, value, *dims = node.inputs
157+
if not isinstance(rv.owner.op, (AutoRegressiveRV,)):
158+
if target_vars:
159+
raise NotImplementedError(f"RV distribution {rv.owner.op} can't be forecasted")
160+
else:
161+
continue
162+
163+
# For free RVs we use the RV as the starting value
164+
# For observedRVs we use the actual value as the starting value
165+
if isinstance(node.op, ModelFreeRV):
166+
value = rv
167+
168+
if isinstance(rv.owner.op, AutoRegressiveRV):
169+
init_dist = DiracDelta.dist(value[-1])
170+
rhos, sigma, _, old_steps, _ = rv.owner.inputs
171+
new_rv = AR.rv_op(
172+
rhos,
173+
sigma,
174+
init_dist,
175+
steps=steps or old_steps,
176+
ar_order=rv.owner.op.ar_order,
177+
constant_term=rv.owner.op.constant_term,
178+
)
179+
180+
new_rv.name = rv.name
181+
new_dummy_rv = model_free_rv(new_rv, new_rv.type(), None, *dims)
182+
replacements[dummy_rv] = new_dummy_rv
183+
184+
toposort_replace(fgraph, tuple(replacements.items()))
185+
return model_from_fgraph(fgraph)

0 commit comments

Comments
 (0)