Skip to content

Commit 21c2e6c

Browse files
committed
Add moment to GaussianRandomWalk and fix mu/sigma broadcasting bug
1 parent 1a5b3d6 commit 21c2e6c

File tree

3 files changed

+38
-3
lines changed

3 files changed

+38
-3
lines changed

Diff for: pymc/distributions/timeseries.py

+15-2
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from pymc.distributions import distribution, multivariate
2727
from pymc.distributions.continuous import Flat, Normal, get_tau_sigma
2828
from pymc.distributions.dist_math import check_parameters
29+
from pymc.distributions.distribution import moment
2930
from pymc.distributions.logprob import ignore_logprob, logp
3031
from pymc.distributions.shape_utils import rv_size_is_none, to_tuple
3132
from pymc.util import check_dist_not_registered
@@ -131,7 +132,9 @@ def rng_fn(
131132
else:
132133
dist_shape = (*size, int(steps))
133134

134-
innovations = rng.normal(loc=mu, scale=sigma, size=dist_shape)
135+
# Add one dimension to the right, so that mu and sigma broadcast safely along
136+
# the steps dimension
137+
innovations = rng.normal(loc=mu[..., None], scale=sigma[..., None], size=dist_shape)
135138
grw = np.concatenate([init[..., None], innovations], axis=-1)
136139
return np.cumsum(grw, axis=-1)
137140

@@ -211,6 +214,14 @@ def dist(
211214

212215
return super().dist([mu, sigma, init, steps], size=size, **kwargs)
213216

217+
def moment(rv, size, mu, sigma, init, steps):
218+
grw_moment = at.zeros_like(rv)
219+
grw_moment = at.set_subtensor(grw_moment[..., 0], moment(init))
220+
# Add one dimension to the right, so that mu broadcasts safely along the steps
221+
# dimension
222+
grw_moment = at.set_subtensor(grw_moment[..., 1:], mu[..., None])
223+
return at.cumsum(grw_moment, axis=-1)
224+
214225
def logp(
215226
value: at.Variable,
216227
mu: at.Variable,
@@ -225,7 +236,9 @@ def logp(
225236

226237
# Make time series stationary around the mean value
227238
stationary_series = value[..., 1:] - value[..., :-1]
228-
series_logp = logp(Normal.dist(mu, sigma), stationary_series)
239+
# Add one dimension to the right, so that mu and sigma broadcast safely along
240+
# the steps dimension
241+
series_logp = logp(Normal.dist(mu[..., None], sigma[..., None]), stationary_series)
229242

230243
return check_parameters(
231244
init_logp + series_logp.sum(axis=-1),

Diff for: pymc/tests/test_distributions_moments.py

-1
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,6 @@ def test_all_distributions_have_moments():
103103
dist_module.timeseries.AR,
104104
dist_module.timeseries.AR1,
105105
dist_module.timeseries.GARCH11,
106-
dist_module.timeseries.GaussianRandomWalk,
107106
dist_module.timeseries.MvGaussianRandomWalk,
108107
dist_module.timeseries.MvStudentTRandomWalk,
109108
}

Diff for: pymc/tests/test_distributions_timeseries.py

+23
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from pymc.model import Model
3030
from pymc.sampling import sample, sample_posterior_predictive
3131
from pymc.tests.helpers import select_by_precision
32+
from pymc.tests.test_distributions_moments import assert_moment_is_expected
3233
from pymc.tests.test_distributions_random import BaseTestDistributionRandom
3334

3435

@@ -142,6 +143,28 @@ def test_gaussian_random_walk_init_dist_logp(self, init):
142143
pm.logp(init, 0).eval() + scipy.stats.norm.logpdf(0),
143144
)
144145

146+
@pytest.mark.parametrize(
147+
"mu, sigma, init, steps, size, expected",
148+
[
149+
(0, 1, Normal.dist(1), 10, None, np.ones((11,))),
150+
(1, 1, Normal.dist(0), 10, (2,), np.full((2, 11), np.arange(11))),
151+
(1, 1, Normal.dist([0, 1]), 10, None, np.vstack((np.arange(11), np.arange(11) + 1))),
152+
(0, [1, 1], Normal.dist(0), 10, None, np.zeros((2, 11))),
153+
(
154+
[1, -1],
155+
1,
156+
Normal.dist(0),
157+
10,
158+
(4, 2),
159+
np.full((4, 2, 11), np.vstack((np.arange(11), -np.arange(11)))),
160+
),
161+
],
162+
)
163+
def test_moment(self, mu, sigma, init, steps, size, expected):
164+
with Model() as model:
165+
GaussianRandomWalk("x", mu=mu, sigma=sigma, init=init, steps=steps, size=size)
166+
assert_moment_is_expected(model, expected)
167+
145168

146169
@pytest.mark.xfail(reason="Timeseries not refactored")
147170
def test_AR():

0 commit comments

Comments
 (0)