Skip to content

Commit c2d96cf

Browse files
committed
Use unit normal as default init_dist in GaussianRandomWalk and AR
1 parent 2536920 commit c2d96cf

File tree

3 files changed

+8
-7
lines changed

3 files changed

+8
-7
lines changed

pymc/distributions/timeseries.py

+4-5
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,7 @@ class GaussianRandomWalk(distribution.Continuous):
225225
sigma > 0, innovation standard deviation, defaults to 1.0
226226
init : unnamed distribution
227227
Univariate distribution of the initial value, created with the `.dist()` API.
228-
Defaults to Normal with same `mu` and `sigma` as the GaussianRandomWalk
228+
Defaults to a unit Normal.
229229
230230
.. warning:: init will be cloned, rendering them independent of the ones passed as input.
231231
@@ -265,7 +265,7 @@ def dist(
265265

266266
# If no scalar distribution is passed then initialize with a Normal of same mu and sigma
267267
if init is None:
268-
init = Normal.dist(mu, sigma)
268+
init = Normal.dist(0, 1)
269269
else:
270270
if not (
271271
isinstance(init, at.TensorVariable)
@@ -361,7 +361,7 @@ class AR(SymbolicDistribution):
361361
Whether the first element of rho should be used as a constant term in the AR
362362
process. Defaults to False
363363
init_dist: unnamed distribution, optional
364-
Scalar or vector distribution for initial values. Defaults to Normal(0, sigma).
364+
Scalar or vector distribution for initial values. Defaults to a unit Normal.
365365
Distribution should be created via the `.dist()` API, and have dimension
366366
(*size, ar_order). If not, it will be automatically resized.
367367
@@ -452,8 +452,7 @@ def dist(
452452
f"got ndim_supp={init_dist.owner.op.ndim_supp}.",
453453
)
454454
else:
455-
# Sigma must broadcast with ar_order
456-
init_dist = Normal.dist(sigma=at.shape_padright(sigma), size=(*sigma.shape, ar_order))
455+
init_dist = Normal.dist(0, 1, size=(*sigma.shape, ar_order))
457456

458457
# Tell Aeppl to ignore init_dist, as it will be accounted for in the logp term
459458
init_dist = ignore_logprob(init_dist)

pymc/tests/test_distributions.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -2610,7 +2610,7 @@ def test_gaussianrandomwalk(self):
26102610
def ref_logp(value, mu, sigma, steps):
26112611
# Relying on fact that init will be normal by default
26122612
return (
2613-
scipy.stats.norm.logpdf(value[0], mu, sigma)
2613+
scipy.stats.norm.logpdf(value[0])
26142614
+ scipy.stats.norm.logpdf(np.diff(value), mu, sigma).sum()
26152615
)
26162616

pymc/tests/test_distributions_timeseries.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -333,6 +333,7 @@ def test_batched_sigma(self):
333333
"y",
334334
beta_tp,
335335
sigma=sigma,
336+
init_dist=Normal.dist(0, sigma[..., None]),
336337
size=batch_size,
337338
steps=steps,
338339
initval=y_tp,
@@ -346,6 +347,7 @@ def test_batched_sigma(self):
346347
f"y_{i}{j}",
347348
beta_tp,
348349
sigma=sigma[i][j],
350+
init_dist=Normal.dist(0, sigma[i][j]),
349351
shape=steps,
350352
initval=y_tp[i, j],
351353
ar_order=ar_order,
@@ -371,7 +373,7 @@ def test_batched_init_dist(self):
371373
beta_tp = aesara.shared(np.random.randn(ar_order), shape=(3,))
372374
y_tp = np.random.randn(batch_size, steps)
373375
with Model() as t0:
374-
init_dist = Normal.dist(0.0, 0.01, size=(batch_size, ar_order))
376+
init_dist = Normal.dist(0.0, 1.0, size=(batch_size, ar_order))
375377
AR("y", beta_tp, sigma=0.01, init_dist=init_dist, steps=steps, initval=y_tp)
376378
with Model() as t1:
377379
for i in range(batch_size):

0 commit comments

Comments
 (0)