|
| 1 | +from typing import List, Optional, Sequence, Union |
| 2 | + |
| 3 | +import pytensor.tensor as pt |
| 4 | +from pymc import DiracDelta |
| 5 | +from pymc.distributions.censored import CensoredRV |
| 6 | +from pymc.distributions.timeseries import AR, AutoRegressiveRV |
| 7 | +from pymc.model import Model |
| 8 | +from pytensor.graph.basic import Variable |
| 9 | + |
| 10 | +from pymc_experimental.utils.model_fgraph import ( |
| 11 | + ModelFreeRV, |
| 12 | + ModelValuedVar, |
| 13 | + fgraph_from_model, |
| 14 | + model_free_rv, |
| 15 | + model_from_fgraph, |
| 16 | + toposort_replace, |
| 17 | +) |
| 18 | + |
| 19 | +__all__ = ( |
| 20 | + "uncensor", |
| 21 | + "forecast_timeseries", |
| 22 | +) |
| 23 | + |
| 24 | + |
| 25 | +ModelVariable = Union[Variable, str] |
| 26 | +SequenceModelVariables = Union[ModelVariable, Sequence[ModelVariable]] |
| 27 | + |
| 28 | + |
| 29 | +def parse_vars(model: Model, vars: SequenceModelVariables) -> List[Variable]: |
| 30 | + if not isinstance(vars, (list, tuple)): |
| 31 | + vars = (vars,) |
| 32 | + return [model[var] if isinstance(var, str) else var for var in vars] |
| 33 | + |
| 34 | + |
| 35 | +def uncensor(model: Model, vars: Optional[SequenceModelVariables] = None) -> Model: |
| 36 | + """Replace censored variables in the model by uncensored ones. |
| 37 | +
|
| 38 | + .. code-block:: python |
| 39 | +
|
| 40 | + import pymc as pm |
| 41 | + from pymc_experimental.model_transform.predict import uncensor |
| 42 | +
|
| 43 | + with pm.Model() as model: |
| 44 | + x = pm.Normal("x") |
| 45 | + dist_raw = pm.Normal.dist(x, sigma=10) |
| 46 | + y = pm.Censored("y", dist=dist_raw, lower=0, upper=10, observed=[0, 5, 10]) |
| 47 | + trace = pm.sample() |
| 48 | +
|
| 49 | + with uncensor(model): |
| 50 | + pp = pm.sample_posterior_predictive(trace, var_names=["y"]) |
| 51 | +
|
| 52 | +
|
| 53 | + Parameters |
| 54 | + ---------- |
| 55 | + model: Model |
| 56 | + vars: optional |
| 57 | + Model variables that should be replaced by uncensored counterparts. |
| 58 | + Defaults to all censored variables. |
| 59 | +
|
| 60 | + Returns |
| 61 | + ------- |
| 62 | + uncensored_model: Model |
| 63 | + Model with the censored variables replaced by uncensored versions |
| 64 | +
|
| 65 | + """ |
| 66 | + vars = parse_vars(model, vars) if vars is not None else [] |
| 67 | + |
| 68 | + fgraph, memo = fgraph_from_model(model) |
| 69 | + |
| 70 | + target_vars = {memo[var] for var in vars} |
| 71 | + replacements = {} |
| 72 | + for node in fgraph.apply_nodes: |
| 73 | + if not isinstance(node.op, ModelValuedVar): |
| 74 | + continue |
| 75 | + |
| 76 | + dummy_rv = node.outputs[0] |
| 77 | + if target_vars and dummy_rv not in target_vars: |
| 78 | + continue |
| 79 | + |
| 80 | + rv, value, *dims = node.inputs |
| 81 | + if not isinstance(rv.owner.op, (CensoredRV,)): |
| 82 | + if target_vars: |
| 83 | + raise NotImplementedError(f"RV distribution {rv.owner.op} is not censored") |
| 84 | + else: |
| 85 | + continue |
| 86 | + |
| 87 | + # The first argument is the `dist` RV |
| 88 | + new_rv = rv.owner.inputs[0] |
| 89 | + |
| 90 | + new_rv.name = rv.name |
| 91 | + new_dummy_rv = model_free_rv(new_rv, new_rv.type(), None, *dims) |
| 92 | + replacements[dummy_rv] = new_dummy_rv |
| 93 | + |
| 94 | + toposort_replace(fgraph, tuple(replacements.items())) |
| 95 | + return model_from_fgraph(fgraph) |
| 96 | + |
| 97 | + |
| 98 | +def forecast_timeseries( |
| 99 | + model: Model, |
| 100 | + vars: Optional[SequenceModelVariables] = None, |
| 101 | + *, |
| 102 | + steps: Optional[int] = None, |
| 103 | +) -> Model: |
| 104 | + """Replace timeseries variables in the model by forecast that start at the last value. |
| 105 | +
|
| 106 | + .. code-block:: python |
| 107 | +
|
| 108 | + import pymc as pm |
| 109 | + from pymc_experimental.model_transform.predict import forecast_timeseries |
| 110 | +
|
| 111 | + with pm.Model() as model: |
| 112 | + rho = pm.Normal("rho") |
| 113 | + sigma = pm.HalfNormal("sigma") |
| 114 | + init_dist = pm.Normal.dist() |
| 115 | + y = pm.AR("y", init_dist=init_dist, rho=rho, sigma=sigma, observed=[0] * 100) |
| 116 | + trace = pm.sample() |
| 117 | +
|
| 118 | + with forecast_timeseries(model, steps=20): |
| 119 | + pp = pm.sample_posterior_predictive(trace, var_names=["y"], predictions=True) |
| 120 | +
|
| 121 | +
|
| 122 | +
|
| 123 | + Parameters |
| 124 | + ---------- |
| 125 | + model: Model |
| 126 | + vars: optional |
| 127 | + Model variables that should be replaced by forecast counterparts. |
| 128 | + Defaults to all timeseries variables. |
| 129 | + steps: int, optional |
| 130 | + Number of steps for the forecast. Defaults to the same as originally |
| 131 | +
|
| 132 | + Returns |
| 133 | + ------- |
| 134 | + forecast_model: Model |
| 135 | + Model with the timeseries variables replaced by the forecast versions |
| 136 | +
|
| 137 | + """ |
| 138 | + vars = parse_vars(model, vars) if vars is not None else [] |
| 139 | + |
| 140 | + if steps is not None: |
| 141 | + steps = pt.as_tensor_variable(steps, dtype=int) |
| 142 | + |
| 143 | + fgraph, memo = fgraph_from_model(model) |
| 144 | + |
| 145 | + target_vars = {memo[var] for var in vars} |
| 146 | + replacements = {} |
| 147 | + for node in fgraph.apply_nodes: |
| 148 | + |
| 149 | + if not isinstance(node.op, ModelValuedVar): |
| 150 | + continue |
| 151 | + |
| 152 | + dummy_rv = node.outputs[0] |
| 153 | + if target_vars and dummy_rv not in target_vars: |
| 154 | + continue |
| 155 | + |
| 156 | + rv, value, *dims = node.inputs |
| 157 | + if not isinstance(rv.owner.op, (AutoRegressiveRV,)): |
| 158 | + if target_vars: |
| 159 | + raise NotImplementedError(f"RV distribution {rv.owner.op} can't be forecasted") |
| 160 | + else: |
| 161 | + continue |
| 162 | + |
| 163 | + # For free RVs we use the RV as the starting value |
| 164 | + # For observedRVs we use the actual value as the starting value |
| 165 | + if isinstance(node.op, ModelFreeRV): |
| 166 | + value = rv |
| 167 | + |
| 168 | + if isinstance(rv.owner.op, AutoRegressiveRV): |
| 169 | + init_dist = DiracDelta.dist(value[-1]) |
| 170 | + rhos, sigma, _, old_steps, _ = rv.owner.inputs |
| 171 | + new_rv = AR.rv_op( |
| 172 | + rhos, |
| 173 | + sigma, |
| 174 | + init_dist, |
| 175 | + steps=steps or old_steps, |
| 176 | + ar_order=rv.owner.op.ar_order, |
| 177 | + constant_term=rv.owner.op.constant_term, |
| 178 | + ) |
| 179 | + |
| 180 | + new_rv.name = rv.name |
| 181 | + new_dummy_rv = model_free_rv(new_rv, new_rv.type(), None, *dims) |
| 182 | + replacements[dummy_rv] = new_dummy_rv |
| 183 | + |
| 184 | + toposort_replace(fgraph, tuple(replacements.items())) |
| 185 | + return model_from_fgraph(fgraph) |
0 commit comments