Skip to content

Commit d84c718

Browse files
committed
Add fallback to prior when moment is not implemented
1 parent 216dcd5 commit d84c718

File tree

2 files changed

+40
-1
lines changed

2 files changed

+40
-1
lines changed

pymc/initial_point.py

+14-1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import functools
15+
import warnings
1516

1617
from typing import Callable, Dict, List, Optional, Sequence, Set, Union
1718

@@ -269,7 +270,19 @@ def make_initial_point_expression(
269270

270271
if isinstance(strategy, str):
271272
if strategy == "moment":
272-
value = get_moment(variable)
273+
try:
274+
value = get_moment(variable)
275+
except NotImplementedError:
276+
warnings.warn(
277+
f"Moment not defined for variable {variable} of type "
278+
f"{variable.owner.op.__class__.__name__}, defaulting to "
279+
f"a draw from the prior. This can lead to difficulties "
280+
f"during tuning. You can manually define an initval or "
281+
f"implement a get_moment dispatched function for this "
282+
f"distribution.",
283+
UserWarning,
284+
)
285+
value = variable
273286
elif strategy == "prior":
274287
value = variable
275288
else:

pymc/tests/test_initial_point.py

+26
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
import numpy as np
1818
import pytest
1919

20+
from aesara.tensor.random.op import RandomVariable
21+
2022
import pymc as pm
2123

2224
from pymc.distributions.distribution import get_moment
@@ -255,6 +257,30 @@ def test_moment_from_dims(self, rv_cls):
255257
assert tuple(get_moment(rv).shape.eval()) == (4, 3)
256258
pass
257259

260+
def test_moment_not_implemented_fallback(self):
261+
class MyNormalRV(RandomVariable):
262+
name = "my_normal"
263+
ndim_supp = 0
264+
ndims_params = [0, 0]
265+
dtype = "floatX"
266+
267+
@classmethod
268+
def rng_fn(cls, rng, mu, sigma, size):
269+
return np.pi
270+
271+
class MyNormalDistribution(pm.Normal):
272+
rv_op = MyNormalRV()
273+
274+
with pm.Model() as m:
275+
x = MyNormalDistribution("x", 0, 1, initval="moment")
276+
277+
with pytest.warns(
278+
UserWarning, match="Moment not defined for variable x of type MyNormalRV"
279+
):
280+
res = m.recompute_initial_point()
281+
282+
assert np.isclose(res["x"], np.pi)
283+
258284

259285
def test_pickling_issue_5090():
260286
with pm.Model() as model:

0 commit comments

Comments
 (0)