from dataclasses import dataclass
from functools import singledispatch
from typing import Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
import pymc as pm
import pytensor
import pytensor.tensor as pt
import scipy.special
from pymc.logprob.transforms import Transform
from pymc.model.fgraph import (
    ModelDeterministic,
    ModelNamed,
    fgraph_from_model,
    model_deterministic,
    model_free_rv,
    model_from_fgraph,
    model_named,
)
from pymc.pytensorf import toposort_replace
from pytensor.graph.basic import Apply, Variable
from pytensor.tensor.random.op import RandomVariable


@dataclass
class VIP:
    r"""Helper to reparametrize a model with VIP.

    Manipulation of :math:`\lambda` in the equation below is done through this helper class;
    :math:`\lambda_k = 1` corresponds to the centered parametrization and
    :math:`\lambda_k = 0` to the non-centered one.

    .. math::

        \begin{align*}
        \eta_{k} &\sim \text{normal}(\lambda_{k} \cdot \mu, \sigma^{\lambda_{k}})\\
        \theta_{k} &= \mu + \sigma^{1 - \lambda_{k}} ( \eta_{k} - \lambda_{k} \cdot \mu)
        \sim \text{normal}(\mu, \sigma).
        \end{align*}
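
    Examples
    --------
    A minimal usage sketch: ``model`` is any centered PyMC model, ``"theta"`` stands in
    for a variable being reparametrized, and ``vip`` is the helper returned by
    :func:`vip_reparametrize`.

    .. code-block:: python

        Reparam_model, vip = vip_reparametrize(model, ["theta"])
        with Reparam_model:
            vip.set_all_lambda(0.5)  # halfway between centered and non-centered
            trace = pm.sample()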
    """

    _logit_lambda: Dict[str, pytensor.tensor.sharedvar.TensorSharedVariable]

    @property
    def variational_parameters(self) -> List[pytensor.tensor.sharedvar.TensorSharedVariable]:
        r"""Return raw :math:`\operatorname{logit}(\lambda_k)` for custom optimization.

        Examples
        --------
        .. code-block:: python

            with model:
                # set all parameterizations to a mix of centered and non-centered
                vip.set_all_lambda(0.5)

                pm.fit(more_obj_params=vip.variational_parameters, method="fullrank_advi")
        """
        return list(self._logit_lambda.values())

    def truncate_lambda(self, **kwargs: float):
        r"""Truncate :math:`\lambda_k` with :math:`\varepsilon`.

        .. math::

            \hat \lambda_k = \begin{cases}
            0, \quad &\lambda_k \le \varepsilon\\
            \lambda_k, \quad &\varepsilon \lt \lambda_k \lt 1-\varepsilon\\
            1, \quad &\lambda_k \ge 1-\varepsilon\\
            \end{cases}

        Parameters
        ----------
        kwargs : Dict[str, float]
            Mapping from variable name to :math:`\varepsilon`.
            If :math:`\lambda_k` (or :math:`1 - \lambda_k`) falls below the
            threshold :math:`\varepsilon`, it is rounded to 0 (or 1, respectively).
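
        Examples
        --------
        A sketch, assuming ``"theta"`` was reparametrized as in :func:`vip_reparametrize`:

        .. code-block:: python

            # round lambda for "theta" to 0 or 1 if it is within 0.1 of either extreme
            vip.truncate_lambda(theta=0.1)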
        """
        lambdas = self.get_lambda()
        update = dict()
        for var, eps in kwargs.items():
            lam = lambdas[var]
            update[var] = np.piecewise(
                lam,
                [lam < eps, lam > (1 - eps)],
                [0, 1, lambda x: x],
            )
        self.set_lambda(**update)

    def truncate_all_lambda(self, value: float):
        r"""Truncate all :math:`\lambda_k` with :math:`\varepsilon`.

        .. math::

            \hat \lambda_k = \begin{cases}
            0, \quad &\lambda_k \le \varepsilon\\
            \lambda_k, \quad &\varepsilon \lt \lambda_k \lt 1-\varepsilon\\
            1, \quad &\lambda_k \ge 1-\varepsilon\\
            \end{cases}

        Parameters
        ----------
        value : float
            :math:`\varepsilon`
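
        Examples
        --------
        .. code-block:: python

            # apply the same threshold to every reparametrized variable
            vip.truncate_all_lambda(0.1)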
        """
        truncate = dict.fromkeys(
            self._logit_lambda.keys(),
            value,
        )
        self.truncate_lambda(**truncate)

    def get_lambda(self) -> Dict[str, np.ndarray]:
        r"""Get :math:`\lambda_k` that are currently used by the model.

        Returns
        -------
        Dict[str, np.ndarray]
            Mapping from variable name to :math:`\lambda_k`.
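
        Examples
        --------
        A sketch, assuming ``"theta"`` was reparametrized and all lambdas were set to 0.5:

        .. code-block:: python

            vip.set_all_lambda(0.5)
            vip.get_lambda()  # {"theta": array([0.5, 0.5, ...])}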
        """
        return {
            name: scipy.special.expit(shared.get_value())
            for name, shared in self._logit_lambda.items()
        }

    def set_lambda(self, **kwargs: Union[np.ndarray, float]):
        r"""Set :math:`\lambda_k` per variable."""
        for key, value in kwargs.items():
            logit_lam = scipy.special.logit(value)
            shared = self._logit_lambda[key]
            fill = np.broadcast_to(
                logit_lam,
                shared.type.shape,
            )
            shared.set_value(fill)

    def set_all_lambda(self, value: Union[np.ndarray, float]):
        r"""Set :math:`\lambda_k` globally."""
        config = dict.fromkeys(
            self._logit_lambda.keys(),
            value,
        )
        self.set_lambda(**config)

    def fit(self, *args, **kwargs) -> pm.Approximation:
        r"""Set :math:`\lambda_k` using Variational Inference.

        Examples
        --------

        .. code-block:: python

            with model:
                # set all parameterizations to a mix of centered and non-centered
                vip.set_all_lambda(0.5)

                # fit using ADVI
                mf = vip.fit(random_seed=42)
        """
        kwargs.setdefault("obj_optimizer", pm.adagrad_window(learning_rate=0.1))
        kwargs.setdefault("method", "advi")
        return pm.fit(
            *args,
            more_obj_params=self.variational_parameters,
            **kwargs,
        )


def vip_reparam_node(
    op: RandomVariable,
    node: Apply,
    name: str,
    dims: List[Variable],
    transform: Optional[Transform],
) -> Tuple[ModelDeterministic, ModelNamed]:
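    """Replace a random-variable node with its VIP-reparametrized counterpart.

    Creates a shared ``logit(lambda)`` variable (one entry per element of the RV),
    registers it as a named model variable, and dispatches on the RV's Op to build
    the reparametrized deterministic. Returns the new deterministic and the
    ``logit(lambda)`` named variable.
    """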
    if not isinstance(node.op, RandomVariable):
        raise TypeError("Op should be RandomVariable type")
    size = node.inputs[1]
    if not isinstance(size, pt.TensorConstant):
        raise ValueError("Size should be static for autoreparametrization.")
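    # One logit(lambda) per element of the RV, stored in a shared variable so it can be
    # read and updated after the model is compiled (see VIP.get_lambda / set_lambda).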
    logit_lam_ = pytensor.shared(
        np.zeros(size.data),
        shape=size.data,
        name=f"{name}::lam_logit__",
    )
    logit_lam = model_named(logit_lam_, *dims)
    lam = pt.sigmoid(logit_lam)
    return (
        _vip_reparam_node(
            op,
            node=node,
            name=name,
            dims=dims,
            transform=transform,
            lam=lam,
        ),
        logit_lam,
    )


@singledispatch
def _vip_reparam_node(
    op: RandomVariable,
    node: Apply,
    name: str,
    dims: List[Variable],
    transform: Optional[Transform],
    lam: pt.TensorVariable,
) -> ModelDeterministic:
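    """Build the reparametrized deterministic for a supported distribution.

    Dispatches on the type of ``op``; distributions without a registered
    implementation raise ``NotImplementedError``.
    """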
    raise NotImplementedError


@_vip_reparam_node.register
def _(
    op: pm.Normal,
    node: Apply,
    name: str,
    dims: List[Variable],
    transform: Optional[Transform],
    lam: pt.TensorVariable,
) -> ModelDeterministic:
    rng, size, _, loc, scale = node.inputs
    if transform is not None:
        raise NotImplementedError("Reparametrization of Normal with Transform is not implemented")
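    # Partially non-centered base variable: eta_k ~ Normal(lam_k * mu, sigma**lam_k).
    # lam = 1 recovers the centered parametrization, lam = 0 the fully non-centered one.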
    vip_rv_ = pm.Normal.dist(
        lam * loc,
        scale**lam,
        size=size,
        rng=rng,
    )
    vip_rv_.name = f"{name}::tau_"

    vip_rv = model_free_rv(
        vip_rv_,
        vip_rv_.clone(),
        None,
        *dims,
    )

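    # Recover theta_k = mu + sigma**(1 - lam_k) * (eta_k - lam_k * mu) ~ Normal(mu, sigma),
    # registered below as a deterministic under the original variable name.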
    vip_rep_ = loc + scale ** (1 - lam) * (vip_rv - lam * loc)

    vip_rep_.name = name

    vip_rep = model_deterministic(vip_rep_, *dims)
    return vip_rep


def vip_reparametrize(
    model: pm.Model,
    var_names: Sequence[str],
) -> Tuple[pm.Model, VIP]:
    r"""Reparametrize a model using the Variationally Informed Parametrization (VIP).

    .. math::

        \begin{align*}
        \eta_{k} &\sim \text{normal}(\lambda_{k} \cdot \mu, \sigma^{\lambda_{k}})\\
        \theta_{k} &= \mu + \sigma^{1 - \lambda_{k}} ( \eta_{k} - \lambda_{k} \cdot \mu)
        \sim \text{normal}(\mu, \sigma).
        \end{align*}

    Parameters
    ----------
    model : Model
        Model with centered parameterizations for variables.
    var_names : Sequence[str]
        Target variables to reparametrize.

    Returns
    -------
    Tuple[Model, VIP]
        Updated model and the VIP helper used to reparametrize or infer the parametrization of the model.

    Examples
    --------
    The traditional eight-schools example.

    .. code-block:: python

        import pymc as pm
        import numpy as np

        J = 8
        y = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])
        sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])

        with pm.Model() as Centered_eight:
            mu = pm.Normal("mu", mu=0, sigma=5)
            tau = pm.HalfCauchy("tau", beta=5)
            theta = pm.Normal("theta", mu=mu, sigma=tau, shape=J)
            obs = pm.Normal("obs", mu=theta, sigma=sigma, observed=y)

    The regular model definition with centered parametrization is sufficient to use VIP.
    To change the model parametrization, use the following function.

    .. code-block:: python

        from pymc_experimental.model.transforms.autoreparam import vip_reparametrize

        Reparam_eight, vip = vip_reparametrize(Centered_eight, ["theta"])

        with Reparam_eight:
            # set all parameterizations to centered (not needed)
            vip.set_all_lambda(1)

            # set all parameterizations to non-centered (desired)
            vip.set_all_lambda(0)

            # or per variable
            vip.set_lambda(theta=0)

            # sample with the non-centered parameterization just set
            trace = pm.sample()

    However, setting :math:`\lambda` manually is not always a great experience; we can learn it instead.

    .. code-block:: python

        with Reparam_eight:
            # set all parameterizations to a mix of centered and non-centered
            vip.set_all_lambda(0.5)

            # fit using ADVI
            mf = vip.fit(random_seed=42)

            # display lambdas
            print(vip.get_lambda())

        # {'theta': array([0.01473405, 0.02221006, 0.03656685, 0.03798879, 0.04876761,
        #         0.0300203 , 0.02733082, 0.01817754])}

    Now you can use sampling again:

    .. code-block:: python

        with Reparam_eight:
            trace = pm.sample()

    Sometimes it makes sense to enable clipping (it is off by default).
    The idea is to round :math:`\lambda_k` to the closest extremum (:math:`0` or :math:`1`)
    whenever it is within :math:`\varepsilon` of it:

    .. math::

        \hat \lambda_k = \begin{cases}
        0, \quad &\lambda_k \le \varepsilon\\
        \lambda_k, \quad &\varepsilon \lt \lambda_k \lt 1-\varepsilon\\
        1, \quad &\lambda_k \ge 1-\varepsilon
        \end{cases}

    .. code-block:: python

        vip.truncate_all_lambda(0.1)

    Sampling has to be performed again:

    .. code-block:: python

        with Reparam_eight:
            trace = pm.sample()
    """
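    # Convert the model to a FunctionGraph so the target RV nodes can be replaced.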
    fmodel, memo = fgraph_from_model(model)
    lambda_names = []
    replacements = []
    for name in var_names:
        old = memo[model.named_vars[name]]
        rv, _, *dims = old.owner.inputs
        new, lam = vip_reparam_node(
            rv.owner.op,
            rv.owner,
            name=rv.name,
            dims=dims,
            transform=old.owner.op.transform,
        )
        replacements.append((old, new))
        lambda_names.append(lam.name)
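    # Apply all replacements in topological order, rebuild the PyMC model, and collect
    # the logit(lambda) variables for the VIP helper.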
    toposort_replace(fmodel, replacements, reverse=True)
    reparam_model = model_from_fgraph(fmodel)
    model_lambdas = {n: reparam_model[l] for l, n in zip(lambda_names, var_names)}
    vip = VIP(model_lambdas)
    return reparam_model, vip