26
26
from pymc .distributions import distribution , multivariate
27
27
from pymc .distributions .continuous import Flat , Normal , get_tau_sigma
28
28
from pymc .distributions .dist_math import check_parameters
29
+ from pymc .distributions .distribution import moment
29
30
from pymc .distributions .logprob import ignore_logprob , logp
30
31
from pymc .distributions .shape_utils import rv_size_is_none , to_tuple
31
32
from pymc .util import check_dist_not_registered
@@ -131,7 +132,9 @@ def rng_fn(
131
132
else :
132
133
dist_shape = (* size , int (steps ))
133
134
134
- innovations = rng .normal (loc = mu , scale = sigma , size = dist_shape )
135
+ # Add one dimension to the right, so that mu and sigma broadcast safely along
136
+ # the steps dimension
137
+ innovations = rng .normal (loc = mu [..., None ], scale = sigma [..., None ], size = dist_shape )
135
138
grw = np .concatenate ([init [..., None ], innovations ], axis = - 1 )
136
139
return np .cumsum (grw , axis = - 1 )
137
140
@@ -211,6 +214,14 @@ def dist(
211
214
212
215
return super ().dist ([mu , sigma , init , steps ], size = size , ** kwargs )
213
216
217
+ def moment (rv , size , mu , sigma , init , steps ):
218
+ grw_moment = at .zeros_like (rv )
219
+ grw_moment = at .set_subtensor (grw_moment [..., 0 ], moment (init ))
220
+ # Add one dimension to the right, so that mu broadcasts safely along the steps
221
+ # dimension
222
+ grw_moment = at .set_subtensor (grw_moment [..., 1 :], mu [..., None ])
223
+ return at .cumsum (grw_moment , axis = - 1 )
224
+
214
225
def logp (
215
226
value : at .Variable ,
216
227
mu : at .Variable ,
@@ -225,7 +236,9 @@ def logp(
225
236
226
237
# Make time series stationary around the mean value
227
238
stationary_series = value [..., 1 :] - value [..., :- 1 ]
228
- series_logp = logp (Normal .dist (mu , sigma ), stationary_series )
239
+ # Add one dimension to the right, so that mu and sigma broadcast safely along
240
+ # the steps dimension
241
+ series_logp = logp (Normal .dist (mu [..., None ], sigma [..., None ]), stationary_series )
229
242
230
243
return check_parameters (
231
244
init_logp + series_logp .sum (axis = - 1 ),
0 commit comments