33
33
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
34
34
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
35
35
# SOFTWARE.
36
-
37
-
38
36
from typing import cast
39
37
40
38
import pytensor .tensor as pt
41
39
42
40
from pytensor .graph .basic import Apply
43
41
from pytensor .graph .fg import FunctionGraph
44
42
from pytensor .graph .rewriting .basic import node_rewriter
45
- from pytensor .tensor .elemwise import Elemwise
46
43
from pytensor .tensor .math import Max
47
- from pytensor .tensor .random .op import RandomVariable
48
44
from pytensor .tensor .variable import TensorVariable
49
45
50
46
from pymc .logprob .abstract import (
47
+ MeasurableElemwise ,
48
+ MeasurableOp ,
51
49
MeasurableOpMixin ,
52
50
_logcdf_helper ,
53
51
_logprob ,
54
52
_logprob_helper ,
55
53
)
56
54
from pymc .logprob .rewriting import measurable_ir_rewrites_db
57
- from pymc .logprob .utils import find_negated_var
58
55
from pymc .math import logdiffexp
59
56
from pymc .pytensorf import constant_fold
60
57
@node_rewriter(tracks=[Max])
def find_measurable_max(fgraph: FunctionGraph, node: Apply) -> list[TensorVariable] | None:
    """Replace a `Max` over a measurable i.i.d. univariate variable by a measurable max op.

    Returns a one-element list with the new measurable max variable, or ``None``
    when the rewrite does not apply (graph not measurable, non-i.i.d. base
    variable, partial-axis reduction, ...).
    """
    rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None)
    if rv_map_feature is None:
        return None  # pragma: no cover

    # Don't reprocess nodes this rewrite already produced.
    if isinstance(node.op, MeasurableMax | MeasurableMaxDiscrete):
        return None

    [base_var] = node.inputs

    if base_var.owner is None:
        return None

    if not rv_map_feature.request_measurable(node.inputs):
        return None

    # We allow Max of RandomVariables or Elemwise of univariate RandomVariables
    if isinstance(base_var.owner.op, MeasurableElemwise):
        latent_base_vars = [
            var
            for var in base_var.owner.inputs
            if (var.owner and isinstance(var.owner.op, MeasurableOp))
        ]
        # Exactly one measurable input is required, so the max maps to a
        # single latent distribution.
        if len(latent_base_vars) != 1:
            return None
        [latent_base_var] = latent_base_vars
    else:
        latent_base_var = base_var

    latent_op = latent_base_var.owner.op
    # FIX: `getattr(latent_op, "ndim_supp")` without a default raised
    # AttributeError for measurable ops that expose `dist_params` but no
    # `ndim_supp`; supply a default so such nodes are rejected instead.
    if not (hasattr(latent_op, "dist_params") and getattr(latent_op, "ndim_supp", None) == 0):
        return None

    # univariate i.i.d. test which also rules out other distributions
    if not all(
        all(params.type.broadcastable) for params in latent_op.dist_params(latent_base_var.owner)
    ):
        return None

    base_var = cast(TensorVariable, base_var)

    if node.op.axis is None:
        axis = tuple(range(base_var.ndim))
    else:
        # Only a full reduction over all axes is supported.
        axis = tuple(sorted(node.op.axis))
        if axis != tuple(range(base_var.ndim)):
            return None

    # distinguish measurable discrete and continuous (because logprob is different)
    measurable_max_class = (
        MeasurableMaxDiscrete if latent_base_var.type.dtype.startswith("int") else MeasurableMax
    )
    max_rv = cast(TensorVariable, measurable_max_class(axis)(base_var))
    return [max_rv]
125
measurable_ir_rewrites_db .register (
@_logprob.register(MeasurableMax)
def max_logprob(op, values, base_rv, **kwargs):
    r"""Compute the log-likelihood graph for the `Max` operation."""
    (value,) = values

    # Broadcast the value to the base variable's shape so the helpers see
    # one value per i.i.d. component.
    shape = constant_fold(tuple(base_rv.shape), raise_not_constant=False)
    value_bcast = pt.broadcast_to(value, shape)
    logcdf = _logcdf_helper(base_rv, value_bcast)[0]
    logprob = _logprob_helper(base_rv, value_bcast)[0]

    # ln(f_(n)(x)) = ln(n) + (n - 1) ln(F(x)) + ln(f(x))
    n = pt.prod(shape)
    return pt.math.log(n) + (n - 1) * logcdf + logprob
137
145
138
146
139
147
@_logprob.register(MeasurableMaxDiscrete)
def max_logprob_discrete(op, values, base_rv, **kwargs):
    r"""Compute the log-likelihood graph for the `Max` of discrete i.i.d. variables.

    .. math::
        \ln(P_{(n)}(x)) = \ln(F(x)^n - F(x - 1)^n)

    where $P_{(n)}(x)$ represents the p.m.f of the maximum statistic and $F(x)$ represents the c.d.f of the i.i.d. variables.
    """
    (value,) = values

    # Broadcast the value to the base variable's shape so the helpers see
    # one value per i.i.d. component.
    shape = constant_fold(tuple(base_rv.shape), raise_not_constant=False)
    value_bcast = pt.broadcast_to(value, shape)
    logcdf_at = _logcdf_helper(base_rv, value_bcast)[0]
    logcdf_below = _logcdf_helper(base_rv, value_bcast - 1)[0]

    # log(F(x)^n - F(x-1)^n) evaluated stably in log-space.
    n = pt.prod(shape)
    return logdiffexp(n * logcdf_at, n * logcdf_below)
0 commit comments