|
| 1 | +from collections.abc import Mapping |
| 2 | +from functools import singledispatch |
| 3 | +from typing import Dict, Optional, Union |
| 4 | + |
| 5 | +import aesara.tensor as at |
| 6 | +from aesara import config |
| 7 | +from aesara.gradient import disconnected_grad |
| 8 | +from aesara.graph.basic import Constant, clone, graph_inputs, io_toposort |
| 9 | +from aesara.graph.fg import FunctionGraph |
| 10 | +from aesara.graph.op import Op, compute_test_value |
| 11 | +from aesara.tensor.random.op import RandomVariable |
| 12 | +from aesara.tensor.random.opt import local_subtensor_rv_lift |
| 13 | +from aesara.tensor.subtensor import ( |
| 14 | + AdvancedIncSubtensor, |
| 15 | + AdvancedIncSubtensor1, |
| 16 | + AdvancedSubtensor, |
| 17 | + AdvancedSubtensor1, |
| 18 | + IncSubtensor, |
| 19 | + Subtensor, |
| 20 | +) |
| 21 | +from aesara.tensor.var import TensorVariable |
| 22 | + |
| 23 | +from aeppl.logpdf import _logpdf |
| 24 | +from aeppl.utils import ( |
| 25 | + extract_rv_and_value_vars, |
| 26 | + indices_from_subtensor, |
| 27 | + rvs_to_value_vars, |
| 28 | +) |
| 29 | + |
| 30 | + |
def loglik(
    var: TensorVariable,
    rv_values: Optional[
        Union[TensorVariable, Dict[TensorVariable, TensorVariable]]
    ] = None,
    *,
    sum: bool = True,
    **kwargs,
) -> TensorVariable:
    """Create a measure-space (i.e. log-likelihood) graph for a random variable at a given point.

    The input `var` determines which log-likelihood graph is used and
    `rv_value` is that graph's input parameter. For example, if `var` is
    the output of a ``NormalRV`` ``Op``, then the output is a graph of the
    density function for `var` set to the value `rv_value`.

    Parameters
    ==========
    var
        The `RandomVariable` output that determines the log-likelihood graph.
    rv_values
        A variable, or ``dict`` of variables, that represents the value of
        `var` in its log-likelihood. If no `rv_value` is provided,
        ``var.tag.value_var`` will be checked and, when available, used.
    sum
        When ``True`` (the default), sum the log-likelihood terms into a
        scalar; otherwise, return the elementwise log-likelihood graph.

    Raises
    ======
    ValueError
        If `var` is a `RandomVariable` output but no value variable can be
        found for it.

    """
    if not isinstance(rv_values, Mapping):
        rv_values = {var: rv_values} if rv_values is not None else {}

    rv_var, rv_value_var = extract_rv_and_value_vars(var)

    rv_value = rv_values.get(rv_var, rv_value_var)

    if rv_var is not None and rv_value is None:
        raise ValueError(f"No value variable specified or associated with {rv_var}")

    if rv_value is not None:
        rv_value = at.as_tensor(rv_value)

        if rv_var is not None:
            # Make sure that the value is compatible with the random variable
            rv_value = rv_var.type.filter_variable(rv_value.astype(rv_var.dtype))

        if rv_value_var is None:
            rv_value_var = rv_value

    if rv_var is None:
        if var.owner is not None:
            # `var` is not a `RandomVariable` output itself; dispatch on the
            # `Op` that produced it (e.g. a `Subtensor` or `IncSubtensor`).
            # NOTE: the original did not forward `**kwargs` here even though
            # every registered `_loglik` handler accepts them; they are now
            # propagated so options reach the handlers.
            return _loglik(
                var.owner.op,
                var,
                rv_values,
                *var.owner.inputs,
                **kwargs,
            )

        # A graph input with no owner contributes nothing to the measure.
        return at.zeros_like(var)

    rv_node = rv_var.owner

    rng, size, dtype, *dist_params = rv_node.inputs

    # Here, we plug the actual random variable into the log-likelihood graph,
    # because we want a log-likelihood graph that only contains
    # random variables. This is important, because a random variable's
    # parameters can contain random variables themselves.
    # Ultimately, with a graph containing only random variables and
    # "deterministics", we can simply replace all the random variables with
    # their value variables and be done.
    logpdf_var = _logpdf(rv_node.op, rv_value_var, *dist_params, **kwargs)

    # Replace random variables with their value variables
    replacements = rv_values.copy()
    replacements.update({rv_var: rv_value, rv_value_var: rv_value})

    (logpdf_var,), _ = rvs_to_value_vars(
        (logpdf_var,),
        initial_replacements=replacements,
    )

    # NOTE: the original tested ``if sum:`` against the *builtin* ``sum``,
    # which is always truthy, so the graph was unconditionally summed.  The
    # explicit keyword-only ``sum`` parameter (default ``True``) preserves
    # that behavior while making it controllable.
    if sum:
        logpdf_var = at.sum(logpdf_var)

    # Recompute test values for the changes introduced by the replacements
    # above.
    if config.compute_test_value != "off":
        for node in io_toposort(graph_inputs((logpdf_var,)), (logpdf_var,)):
            compute_test_value(node)

    if rv_var.name is not None:
        logpdf_var.name = f"__logp_{rv_var.name}"

    return logpdf_var
| 125 | + |
| 126 | + |
@singledispatch
def _loglik(
    op: Op,
    var: TensorVariable,
    rvs_to_values: Dict[TensorVariable, TensorVariable],
    *inputs: TensorVariable,
    **kwargs,
):
    """Create a graph for the log-likelihood of a ``Variable``.

    This function dispatches on the type of ``op``. If you want to implement
    new graphs for an ``Op``, register a new function on this dispatcher.

    The default returns a graph producing only zeros.

    """
    # When no value variable is mapped for `var`, fall back to `var` itself
    # so the zeros graph still has the right type/shape.
    return at.zeros_like(rvs_to_values.get(var, var))
| 145 | + |
| 146 | + |
@_loglik.register(IncSubtensor)
@_loglik.register(AdvancedIncSubtensor)
@_loglik.register(AdvancedIncSubtensor1)
def incsubtensor_loglik(
    op, var, rvs_to_values, indexed_rv_var, rv_values, *indices, **kwargs
):
    """Create the log-likelihood graph for a ``*IncSubtensor*`` output.

    Parameters
    ==========
    op
        The ``IncSubtensor``-like `Op` that produced `var`; its ``idx_list``
        (when present) describes how `indices` are to be interpreted.
    var
        The indexed-assignment output whose log-likelihood is needed.
    rvs_to_values
        Mapping from random variables to their value variables (unused here;
        present for the `_loglik` dispatch signature).
    indexed_rv_var
        The variable being (partially) overwritten by the subtensor update.
    rv_values
        The values set at the indexed positions.
    indices
        The raw index inputs of the ``IncSubtensor`` node.

    """

    index = indices_from_subtensor(getattr(op, "idx_list", None), indices)

    # Clone `indexed_rv_var`'s graph — sharing its non-constant inputs
    # (`copy_inputs=False`, `copy_orphans=False`) — so a distinct stand-in
    # variable can be indexed below without aliasing `indexed_rv_var` itself.
    _, (new_rv_var,) = clone(
        tuple(
            v for v in graph_inputs((indexed_rv_var,)) if not isinstance(v, Constant)
        ),
        (indexed_rv_var,),
        copy_inputs=False,
        copy_orphans=False,
    )
    # Build a complete value for `indexed_rv_var`: the clone's output
    # everywhere except at `index`, where `rv_values` are substituted.
    # `disconnected_grad` stops gradients from flowing through the clone.
    new_values = at.set_subtensor(disconnected_grad(new_rv_var)[index], rv_values)
    # Delegate to `loglik` with the reconstructed full value.
    logp_var = loglik(indexed_rv_var, new_values, **kwargs)

    return logp_var
| 168 | + |
| 169 | + |
@_loglik.register(Subtensor)
@_loglik.register(AdvancedSubtensor)
@_loglik.register(AdvancedSubtensor1)
def subtensor_loglik(op, var, rvs_to_values, indexed_rv_var, *indices, **kwargs):
    """Create the log-likelihood graph for a ``*Subtensor*`` of a random variable.

    Parameters
    ==========
    op
        The ``Subtensor``-like `Op` that produced `var`; its ``idx_list``
        (when present) describes how `indices` are to be interpreted.
    var
        The subtensor output whose log-likelihood is needed.
    rvs_to_values
        Mapping from random variables to their value variables; `var`'s value
        is looked up here, falling back to ``var.tag.value_var``.
    indexed_rv_var
        The variable being indexed; must be a `RandomVariable` output.
    indices
        The raw index inputs of the ``Subtensor`` node.

    Raises
    ======
    NotImplementedError
        If `indexed_rv_var` is not produced by a `RandomVariable` `Op`.

    """

    index = indices_from_subtensor(getattr(op, "idx_list", None), indices)

    rv_value = rvs_to_values.get(var, getattr(var.tag, "value_var", None))

    if indexed_rv_var.owner and isinstance(indexed_rv_var.owner.op, RandomVariable):

        # We need to lift the index operation through the random variable so
        # that we have a new random variable consisting of only the relevant
        # subset of variables per the index.
        var_copy = var.owner.clone().default_output()
        fgraph = FunctionGraph(
            [i for i in graph_inputs((indexed_rv_var,)) if not isinstance(i, Constant)],
            [var_copy],
            clone=False,
        )

        # Apply the subtensor-lift rewrite to the copied node, yielding a
        # random variable restricted to the indexed subset.
        (lifted_var,) = local_subtensor_rv_lift.transform(
            fgraph, fgraph.outputs[0].owner
        )

        # Associate the lifted (index-restricted) random variable with the
        # value of the original subtensor expression.
        new_rvs_to_values = rvs_to_values.copy()
        new_rvs_to_values[lifted_var] = rv_value

        logp_var = loglik(lifted_var, new_rvs_to_values, **kwargs)

        # The indices themselves may involve random variables, so their
        # log-likelihood contributions are accumulated as well.
        for idx_var in index:
            logp_var += loglik(idx_var, rvs_to_values, **kwargs)

    # TODO: We could add the constant case (i.e. `indexed_rv_var.owner is None`)
    else:
        raise NotImplementedError(
            f"`Subtensor` log-likelihood not implemented for {indexed_rv_var.owner}"
        )

    return logp_var
0 commit comments