Commit 740e6fe

lucianopaz authored and rpgoldman committed
Added fix and documented it
1 parent 04e6ef3 commit 740e6fe

2 files changed

+223
-28
lines changed

docs/source/api/distributions/utilities.rst

+3
Original file line number / Diff line number / Diff line change
@@ -9,6 +9,7 @@ Distribution utility classes and functions
99
Discrete
1010
Continuous
1111
NoDistribution
12+
DensityDist
1213
TensorType
1314

1415
draw_values
@@ -19,6 +20,8 @@ Distribution utility classes and functions
1920
.. autoclass:: Discrete
2021
.. autoclass:: Continuous
2122
.. autoclass:: NoDistribution
23+
.. autoclass:: DensityDist
24+
:members:
2225
.. autofunction:: TensorType
2326

2427
.. autofunction:: draw_values

pymc3/distributions/distribution.py

+220 -28
Original file line number / Diff line number / Diff line change
@@ -198,16 +198,7 @@ class DensityDist(Distribution):
198198
199199
A distribution with the passed log density function is created.
200200
Requires a custom random function passed as kwarg `random` to
201-
enable sampling.
202-
203-
Example:
204-
--------
205-
.. code-block:: python
206-
with pm.Model():
207-
mu = pm.Normal('mu',0,1)
208-
normal_dist = pm.Normal.dist(mu, 1)
209-
pm.DensityDist('density_dist', normal_dist.logp, observed=np.random.randn(100), random=normal_dist.random)
210-
trace = pm.sample(100)
201+
enable prior or posterior predictive sampling.
211202
212203
"""
213204

@@ -220,6 +211,7 @@ def __init__(
220211
random=None,
221212
wrap_random_with_dist_shape=True,
222213
check_shape_in_random=True,
214+
pymc3_size_interpretation=True,
223215
*args,
224216
**kwargs
225217
):
@@ -246,6 +238,10 @@ def __init__(
246238
If a callable, it is used as the distribution's ``random`` method.
247239
The behavior of this callable can be altered with the
248240
``wrap_random_with_dist_shape`` parameter.
241+
The supplied callable must have the following signature:
242+
``random(size=None, **kwargs)``, where ``size`` is the number of
243+
IID draws to take from the distribution. Any extra keyword
244+
argument can be added as required.
249245
wrap_random_with_dist_shape: bool (Optional)
250246
If ``True``, the provided ``random`` callable is passed through
251247
``generate_samples`` to make the random number generator aware of
@@ -255,15 +251,190 @@ def __init__(
255251
If ``True``, the shape of the random samples generated in the
256252
``random`` method is checked with the expected return shape. This
257253
test is only performed if ``wrap_random_with_dist_shape is False``.
254+
pymc3_size_interpretation: bool (Optional)
255+
This flag affects how the ``size`` parameter supplied to the
256+
passed ``random`` function is interpreted. If no ``random``
257+
callable is supplied, this flag is ignored. Furthermore, this flag
258+
is only used if ``wrap_random_with_dist_shape`` is ``True``.
259+
If ``True``, the ``size`` parameter of ``random`` is interpreted
260+
as the number of IID draws to take from a given distribution,
261+
which is how every pymc3 distribution interprets the ``size``
262+
parameter.
263+
If it is ``False``, the ``size`` parameter is used just like in
264+
any scipy random variate generator, i.e. the shape of the returned
265+
array of samples.
266+
The difference is subtle and is only relevant in multidimensional
267+
distributions. Consider that a multivariate normal of rank 3 is
268+
used as the random number generator. With the pymc3 interpretation
269+
of ``size``, a call like ``random(size=10)`` will return an array
270+
with shape ``(10, 3)``. With the scipy interpretation, to get a
271+
similarly shaped result, the call must be changed to
272+
``random(size=(10, 3))``.
273+
A quick rule of thumb is that if the supplied ``random`` callable
274+
is a pymc3 distribution's random method, you should set this flag
275+
to ``True``. If you use a scipy rvs, you should set this flag to
276+
``False``. Refer to the examples for more information.
258277
args, kwargs: (Optional)
259278
These are passed to the parent class' ``__init__``.
279+
260280
Note
261281
----
262-
If the ``random`` method is wrapped with dist shape, what this means
263-
is that the ``random`` callable will be wrapped with the
282+
If the ``random`` method is wrapped with dist shape, what this
283+
means is that the ``random`` callable will be wrapped with the
264284
:func:`~generate_samples` function. The distribution's shape will
265285
be passed to :func:`~generate_samples` as the ``dist_shape``
266-
parameter.
286+
parameter. Any extra ``kwargs`` provided to ``random`` will be
287+
passed as ``not_broadcast_kwargs`` of :func:`~generate_samples`.
288+
289+
Examples
290+
--------
291+
.. code-block:: python
292+
293+
with pm.Model():
294+
mu = pm.Normal('mu',0,1)
295+
normal_dist = pm.Normal.dist(mu, 1)
296+
pm.DensityDist(
297+
'density_dist',
298+
normal_dist.logp,
299+
observed=np.random.randn(100),
300+
random=normal_dist.random
301+
)
302+
trace = pm.sample(100)
303+
304+
If the ``DensityDist`` is multidimensional, some care must be taken
305+
with the supplied ``random`` method. By default, the supplied random
306+
is wrapped by :func:`~generate_samples` to make it aware of the
307+
multidimensional distribution's shape.
308+
This can be prevented by setting ``wrap_random_with_dist_shape=False``.
309+
Furthermore, the ``size`` parameter is interpreted as the number of
310+
IID draws to take from this multidimensional distribution.
311+
312+
313+
.. code-block:: python
314+
315+
with pm.Model():
316+
mu = pm.Normal('mu', 0 , 1)
317+
normal_dist = pm.Normal.dist(mu, 1, shape=3)
318+
dens = pm.DensityDist(
319+
'density_dist',
320+
normal_dist.logp,
321+
observed=np.random.randn(100, 3),
322+
shape=3,
323+
random=normal_dist.random,
324+
)
325+
prior = pm.sample_prior_predictive(10)['density_dist']
326+
assert prior.shape == (10, 100, 3)
327+
328+
If ``wrap_random_with_dist_shape=False``, we start to get samples of
329+
an incorrect shape. By default, we can try to catch these situations.
330+
331+
332+
.. code-block:: python
333+
334+
with pm.Model():
335+
mu = pm.Normal('mu', 0 , 1)
336+
normal_dist = pm.Normal.dist(mu, 1, shape=3)
337+
dens = pm.DensityDist(
338+
'density_dist',
339+
normal_dist.logp,
340+
observed=np.random.randn(100, 3),
341+
shape=3,
342+
random=normal_dist.random,
343+
wrap_random_with_dist_shape=False, # Is True by default
344+
)
345+
err = None
346+
try:
347+
prior = pm.sample_prior_predictive(10)['density_dist']
348+
except RuntimeError as e:
349+
err = e
350+
assert isinstance(err, RuntimeError)
351+
352+
The default catching can be disabled with the
353+
``check_shape_in_random`` parameter.
354+
355+
356+
.. code-block:: python
357+
358+
with pm.Model():
359+
mu = pm.Normal('mu', 0 , 1)
360+
normal_dist = pm.Normal.dist(mu, 1, shape=3)
361+
dens = pm.DensityDist(
362+
'density_dist',
363+
normal_dist.logp,
364+
observed=np.random.randn(100, 3),
365+
shape=3,
366+
random=normal_dist.random,
367+
wrap_random_with_dist_shape=False, # Is True by default
368+
check_shape_in_random=False, # Is True by default
369+
)
370+
prior = pm.sample_prior_predictive(10)['density_dist']
371+
# We get samples with an incorrect shape
372+
assert prior.shape != (10, 100, 3)
373+
396+
One final word of caution. If you use a pymc3 distribution's random
397+
method, you should stick with ``pymc3_size_interpretation=True``, or
398+
you will get incorrectly shaped samples.
399+
400+
401+
.. code-block:: python
402+
403+
with pm.Model():
404+
mu = pm.Normal('mu', 0 , 1)
405+
normal_dist = pm.Normal.dist(mu, 1, shape=3)
406+
dens = pm.DensityDist(
407+
'density_dist',
408+
normal_dist.logp,
409+
observed=np.random.randn(100, 3),
410+
shape=3,
411+
random=normal_dist.random,
412+
pymc3_size_interpretation=False, # Is True by default
413+
)
414+
prior = pm.sample_prior_predictive(10)['density_dist']
415+
assert prior.shape != (10, 100, 3)
416+
assert prior.shape == (10, 100, 3, 3)
417+
418+
If you use callables that work with ``scipy.stats`` rvs, you should
419+
set ``pymc3_size_interpretation=False``.
420+
421+
422+
.. code-block:: python
423+
424+
with pm.Model():
425+
mu = pm.Normal('mu', 0 , 1)
426+
normal_dist = pm.Normal.dist(mu, 1, shape=3)
427+
dens = pm.DensityDist(
428+
'density_dist',
429+
normal_dist.logp,
430+
observed=np.random.randn(100, 3),
431+
shape=3,
432+
random=stats.norm.rvs,
433+
pymc3_size_interpretation=False, # Is True by default
434+
)
435+
prior = pm.sample_prior_predictive(10)['density_dist']
436+
assert prior.shape == (10, 100, 3)
437+
267438
"""
268439
if dtype is None:
269440
dtype = theano.config.floatX
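
The two readings of ``size`` described in the docstring above can be illustrated without pymc3 or scipy. The following is a minimal sketch in pure Python. The names `shape_of`, `random_pymc3_style` and `random_scipy_style` are hypothetical and exist only for this illustration. Both callables follow the `random(size=None, **kwargs)` signature that the docstring requires, and the "pymc3 style" one stands in for a rank 3 multivariate generator.

```python
import random


def shape_of(nested):
    """Return the shape of a regularly nested, non-empty list as a tuple."""
    shape = []
    while isinstance(nested, list):
        shape.append(len(nested))
        nested = nested[0]
    return tuple(shape)


def random_pymc3_style(size=None, **kwargs):
    """Hypothetical generator: ``size`` counts IID draws of a length-3 vector.

    Extra keyword arguments are accepted (and ignored here) only to match
    the ``random(size=None, **kwargs)`` signature.
    """
    n_draws = 1 if size is None else size
    draws = [[random.gauss(0.0, 1.0) for _ in range(3)] for _ in range(n_draws)]
    return draws[0] if size is None else draws


def random_scipy_style(size=None, **kwargs):
    """Hypothetical generator: ``size`` is the full shape of the result."""
    if size is None:
        return random.gauss(0.0, 1.0)
    if isinstance(size, int):
        size = (size,)
    if len(size) == 0:
        return random.gauss(0.0, 1.0)
    head, *tail = size
    return [random_scipy_style(size=tuple(tail)) for _ in range(head)]


# Same request, different meaning of ``size``.
print(shape_of(random_pymc3_style(size=10)))       # (10, 3)
print(shape_of(random_scipy_style(size=10)))       # (10,)
print(shape_of(random_scipy_style(size=(10, 3))))  # (10, 3)
```

With the draw-count reading, the event dimension of length 3 is appended automatically. With the full-shape reading, the caller has to spell it out, which is the situation `pymc3_size_interpretation=False` is meant to cover.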
@@ -272,20 +443,41 @@ def __init__(
272443
self.rand = random
273444
self.wrap_random_with_dist_shape = wrap_random_with_dist_shape
274445
self.check_shape_in_random = check_shape_in_random
446+
self.pymc3_size_interpretation = pymc3_size_interpretation
275447

276-
def random(self, *args, **kwargs):
448+
def random(self, point=None, size=None, **kwargs):
277449
if self.rand is not None:
278450
if self.wrap_random_with_dist_shape:
451+
size = to_tuple(size)
452+
with _DrawValuesContextBlocker():
453+
test_draw = generate_samples(
454+
self.rand,
455+
size=None,
456+
not_broadcast_kwargs=kwargs,
457+
)
458+
test_shape = test_draw.shape
459+
if self.shape[:len(size)] == size:
460+
dist_shape = size + self.shape
461+
else:
462+
dist_shape = self.shape
463+
broadcast_shape = broadcast_dist_samples_shape(
464+
[dist_shape, test_shape],
465+
size=size
466+
)
467+
if self.pymc3_size_interpretation:
468+
len(broadcast_shape) - len(test_shape)
469+
broadcast_shape = broadcast_shape[
470+
:len(broadcast_shape) - len(test_shape)
471+
]
279472
samples = generate_samples(
280-
self.rand, dist_shape=self.shape, *args, **kwargs
473+
self.rand,
474+
broadcast_shape=broadcast_shape,
475+
size=size,
476+
not_broadcast_kwargs=kwargs,
281477
)
282478
else:
283-
samples = self.rand(*args, **kwargs)
479+
samples = self.rand(size=size, **kwargs)
284480
if self.check_shape_in_random:
285-
try:
286-
size = args[1]
287-
except IndexError:
288-
size = kwargs.get("size", None)
289481
expected_shape = (
290482
self.shape
291483
if size is None else
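
The shape trimming in the wrapped branch of this hunk can be followed with plain tuples. The sketch below is an illustration under stated assumptions, not the library code: `as_tuple` is a hypothetical stand-in for `to_tuple`, `shape_for_generator` is a hypothetical helper that isolates the `pymc3_size_interpretation` slice, and the value of `broadcast_shape` is assumed for the example rather than computed by `broadcast_dist_samples_shape`.

```python
def as_tuple(size):
    """Hypothetical stand-in for ``to_tuple``: None -> (), int -> (int,)."""
    if size is None:
        return ()
    if isinstance(size, int):
        return (size,)
    return tuple(size)


def shape_for_generator(broadcast_shape, test_shape,
                        pymc3_size_interpretation=True):
    """Mirror the slice applied to ``broadcast_shape`` in the hunk above.

    With the pymc3 interpretation, the trailing dimensions that a single
    test draw already provides are dropped, so the generator is asked only
    for the number of IID draws.
    """
    if pymc3_size_interpretation:
        return broadcast_shape[:len(broadcast_shape) - len(test_shape)]
    return broadcast_shape


size = as_tuple(10)                  # (10,)
test_shape = (3,)                    # shape of one draw taken with size=None
broadcast_shape = size + test_shape  # assumed value, for illustration only

print(shape_for_generator(broadcast_shape, test_shape, True))   # (10,)
print(shape_for_generator(broadcast_shape, test_shape, False))  # (10, 3)
```

Slicing with the explicit length `len(broadcast_shape) - len(test_shape)` also behaves sensibly for scalar draws: when `test_shape` is `()`, the whole `broadcast_shape` is kept, whereas a negative index such as `[:-0]` would return an empty tuple.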
@@ -384,17 +576,17 @@ def draw_values(params, point=None, size=None):
384576
Draw (fix) parameter values. Handles a number of cases:
385577
386578
1) The parameter is a scalar
387-
2) The parameter is an *RV
579+
2) The parameter is an RV
388580
389581
a) parameter can be fixed to the value in the point
390-
b) parameter can be fixed by sampling from the *RV
582+
b) parameter can be fixed by sampling from the RV
391583
c) parameter can be fixed using tag.test_value (last resort)
392584
393585
3) The parameter is a tensor variable/constant. Can be evaluated using
394586
theano.function, but a variable may contain nodes which
395587
396588
a) are named parameters in the point
397-
b) are *RVs with a random method
589+
b) are RVs with a random method
398590
"""
399591
# Get fast drawable values (i.e. things in point or numbers, arrays,
400592
# constants or shares, or things that were already drawn in related
@@ -740,13 +932,13 @@ def generate_samples(generator, *args, **kwargs):
740932
generator : function
741933
Function to generate the random samples. The function is
742934
expected to take parameters for generating samples and
743-
a keyword argument `size` which determines the shape
935+
a keyword argument ``size`` which determines the shape
744936
of the samples.
745-
The *args and **kwargs (stripped of the keywords below) will be
937+
The args and kwargs (stripped of the keywords below) will be
746938
passed to the generator function.
747939
748940
keyword arguments
749-
~~~~~~~~~~~~~~~~
941+
~~~~~~~~~~~~~~~~~
750942
751943
dist_shape : int or tuple of int
752944
The shape of the random variable (i.e., the shape attribute).
@@ -760,9 +952,9 @@ def generate_samples(generator, *args, **kwargs):
760952
the shape of the probabilities in the Categorical distribution.
761953
not_broadcast_kwargs: dict or None
762954
Keyword argument dictionary to provide to the random generator, which
763-
must not be broadcasted with the rest of the *args and **kwargs.
955+
must not be broadcasted with the rest of the args and kwargs.
764956
765-
Any remaining *args and **kwargs are passed on to the generator function.
957+
Any remaining args and kwargs are passed on to the generator function.
766958
"""
767959
dist_shape = kwargs.pop('dist_shape', ())
768960
one_d = _is_one_d(dist_shape)
