Skip to content

Commit 8f74ea9

Browse files
lucianopazrpgoldman
authored andcommitted
Remove pymc3_size_interpretation kwarg
1 parent fdf7a65 commit 8f74ea9

File tree

2 files changed

+9
-96
lines changed

2 files changed

+9
-96
lines changed

pymc3/distributions/distribution.py

+9-76
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,6 @@ def __init__(
211211
random=None,
212212
wrap_random_with_dist_shape=True,
213213
check_shape_in_random=True,
214-
pymc3_size_interpretation=True,
215214
*args,
216215
**kwargs
217216
):
@@ -251,29 +250,6 @@ def __init__(
251250
If ``True``, the shape of the random samples generate in the
252251
``random`` method is checked with the expected return shape. This
253252
test is only performed if ``wrap_random_with_dist_shape is False``.
254-
pymc3_size_interpretation: bool (Optional)
255-
This flag affects how the ``size`` parameter supplied to the
256-
passed ``random`` function is interpreted. If no ``random``
257-
callable is supplied, this flag is ignored. Furthermore, this flag
258-
is only used if ``wrap_random_with_dist_shape`` is `True``.
259-
If ``True``, the ``size`` parameter of ``random`` is interpreted
260-
as the number of IID draws to take from a given distribution,
261-
which is how every pymc3 distribution interprets the ``size``
262-
parameter.
263-
If it is ``False``, the ``size`` parameter is used just like in
264-
any scipy random variate generator, i.e. the shape of the returned
265-
array of samples.
266-
The difference is subtle and is only relevant in multidimensional
267-
distributions. Consider that a multivariate normal of rank 3 is
268-
used as the random number generator. With the pymc3 interpretation
269-
of ``size``, a call like ``random(size=10)`` will return an array
270-
with shape ``(10, 3)``. With the scipy interpretation, to get a
271-
similarly shaped result, the call must be changed to
272-
``random(size=(10, 3))``.
273-
A quick rule of thumb is that if the supplied ``random`` callable
274-
is a pymc3 distribution's random method, you should set this flag
275-
to ``True``. If you use a scipy rvs, you should set this flag to
276-
``False``. Refer to the examples for more information.
277253
args, kwargs: (Optional)
278254
These are passed to the parent class' ``__init__``.
279255
@@ -371,52 +347,12 @@ def __init__(
371347
# We get samples with an incorrect shape
372348
assert prior.shape != (10, 100, 3)
373349
374-
The default catching can be disabled with the
375-
``check_shape_in_random`` parameter.
376-
377-
378-
.. code-block:: python
379-
380-
with pm.Model():
381-
mu = pm.Normal('mu', 0 , 1)
382-
normal_dist = pm.Normal.dist(mu, 1, shape=3)
383-
dens = pm.DensityDist(
384-
'density_dist',
385-
normal_dist.logp,
386-
observed=np.random.randn(100, 3),
387-
shape=3,
388-
random=normal_dist.random,
389-
wrap_random_with_dist_shape=False, # Is True by default
390-
check_shape_in_random=False, # Is True by default
391-
)
392-
prior = pm.sample_prior_predictive(10)['density_dist']
393-
# We get samples with an incorrect shape
394-
assert prior.shape != (10, 100, 3)
395-
396-
One final word of caution. If you use a pymc3 distribution's random
397-
method, you should stick with ``pymc3_size_interpretation=True``, or
398-
you will get incorrectly shaped samples.
399-
400-
401-
.. code-block:: python
402-
403-
with pm.Model():
404-
mu = pm.Normal('mu', 0 , 1)
405-
normal_dist = pm.Normal.dist(mu, 1, shape=3)
406-
dens = pm.DensityDist(
407-
'density_dist',
408-
normal_dist.logp,
409-
observed=np.random.randn(100, 3),
410-
shape=3,
411-
random=normal_dist.random,
412-
pymc3_size_interpretation=False, # Is True by default
413-
)
414-
prior = pm.sample_prior_predictive(10)['density_dist']
415-
assert prior.shape != (10, 100, 3)
416-
assert prior.shape == (10, 100, 3, 3)
417-
418-
If you use callables that work with ``scipy.stats`` rvs, you should
419-
set ``pymc3_size_interpretation=False``.
350+
If you use callables that work with ``scipy.stats`` rvs, you must
351+
be aware that their ``size`` parameter is not the number of IID
352+
samples to draw from a distribution, but the desired ``shape`` of
353+
the returned array of samples. It is the user's responsibility to
354+
wrap the callable to make it comply with PyMC3's interpretation
355+
of ``size``.
420356
421357
422358
.. code-block:: python
@@ -443,7 +379,6 @@ def __init__(
443379
self.rand = random
444380
self.wrap_random_with_dist_shape = wrap_random_with_dist_shape
445381
self.check_shape_in_random = check_shape_in_random
446-
self.pymc3_size_interpretation = pymc3_size_interpretation
447382

448383
def random(self, point=None, size=None, **kwargs):
449384
if self.rand is not None:
@@ -464,11 +399,9 @@ def random(self, point=None, size=None, **kwargs):
464399
[dist_shape, test_shape],
465400
size=size
466401
)
467-
if self.pymc3_size_interpretation:
468-
len(broadcast_shape) - len(test_shape)
469-
broadcast_shape = broadcast_shape[
470-
:len(broadcast_shape) - len(test_shape)
471-
]
402+
broadcast_shape = broadcast_shape[
403+
:len(broadcast_shape) - len(test_shape)
404+
]
472405
samples = generate_samples(
473406
self.rand,
474407
broadcast_shape=broadcast_shape,

pymc3/tests/test_distributions_random.py

-20
Original file line numberDiff line numberDiff line change
@@ -981,26 +981,6 @@ def test_density_dist_with_random_sampleable_hidden_error(self, shape):
981981
assert len(ppc['density_dist']) == samples
982982
assert ((samples,) + obs.distribution.shape) != ppc['density_dist'].shape
983983

984-
@pytest.mark.parametrize("shape", [(), (3,), (3, 2)], ids=str)
985-
def test_density_dist_with_stats_random_sampleable(self, shape):
986-
with pm.Model() as model:
987-
mu = pm.Normal('mu', 0, 1)
988-
normal_dist = pm.Normal.dist(mu, 1, shape=shape)
989-
obs = pm.DensityDist(
990-
'density_dist',
991-
normal_dist.logp,
992-
observed=np.random.randn(100, *shape),
993-
shape=shape,
994-
random=st.norm.rvs,
995-
pymc3_size_interpretation=False
996-
)
997-
trace = pm.sample(100)
998-
999-
samples = 500
1000-
ppc = pm.sample_posterior_predictive(trace, samples=samples, model=model)
1001-
assert len(ppc['density_dist']) == samples
1002-
assert ((samples,) + obs.distribution.shape) == ppc['density_dist'].shape
1003-
1004984
def test_density_dist_with_random_sampleable_handcrafted_success(self):
1005985
with pm.Model() as model:
1006986
mu = pm.Normal('mu', 0, 1)

0 commit comments

Comments
 (0)