Commit 8b3ecf9

lucianopaz authored and rpgoldman committed
Wrap DensityDist's random with generate_samples (#3554)
* Fix for 3553. Wrapped random with generate_samples. Added tests and extra control in DensityDist. Extended the related docstring and made sure it was linked into the online docs. Make the docs for `generate_samples` and `draw_values` appear in the API manual pages. Also add docs for other bits of distributions/distribution.py Added point parameter to rand call. Updated release notes.
1 parent 517f10d commit 8b3ecf9

6 files changed, +363 -43 lines changed

Diff for: RELEASE-NOTES.md

+3
@@ -17,6 +17,9 @@
 - SMC is no longer a step method of `pm.sample`; it should now be called using `pm.sample_smc` [3579](https://github.com/pymc-devs/pymc3/pull/3579)
 - Now uses `multiprocessing` rather than `psutil` to count CPUs, which results in reliable core counts on Chromebooks.
 - `sample_posterior_predictive` now preallocates the memory required for its output to improve memory usage. Addresses problems raised in this [discourse thread](https://discourse.pymc.io/t/memory-error-with-posterior-predictive-sample/2891/4).
+- Fixed a bug in `Categorical.logp`. In the case of multidimensional `p`'s, the indexing was done incorrectly, leading to incorrectly shaped tensors that consumed `O(n**2)` memory instead of `O(n)`. This fixes issue [#3535](https://github.com/pymc-devs/pymc3/issues/3535)
+- Fixed a defect in `OrderedLogistic.__init__` that unnecessarily increased the dimensionality of the underlying `p`. Related to issue [#3535](https://github.com/pymc-devs/pymc3/issues/3535) but was not the true cause of it.
+- Wrapped `DensityDist.rand` with `generate_samples` to make it aware of the distribution's shape. Added control flow attributes to still be able to behave as in earlier versions, and to control how to interpret the `size` parameter in the `random` callable signature. Fixes [3553](https://github.com/pymc-devs/pymc3/issues/3553)
 
 
 ## PyMC3 3.7 (May 29 2019)
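
The last release note added above concerns how the `size` argument of a user-supplied `random` callable is interpreted: as a number of IID draws, not as the full shape of the returned array. The following is a minimal sketch of an adapter between the two conventions, written with the standard library only. The names `shape_style_rvs`, `as_iid_draws` and `nested_shape` are hypothetical and are not part of PyMC3 or SciPy; nested lists stand in for arrays.

```python
import random


def shape_style_rvs(size=None):
    """Hypothetical stand-in for a SciPy-style sampler: ``size`` is the
    full shape of the returned (nested-list) array, not a draw count."""
    if size is None or size == ():
        return random.gauss(0.0, 1.0)
    if isinstance(size, int):
        size = (size,)
    head, *rest = size
    return [shape_style_rvs(tuple(rest)) for _ in range(head)]


def as_iid_draws(rvs, dist_shape):
    """Adapt ``rvs`` so that ``size`` counts IID draws of ``dist_shape``."""
    dist_shape = tuple(dist_shape)

    def random_callable(point=None, size=None):
        # ``point`` is accepted because the diff calls the callable with it;
        # this toy sampler simply ignores it.
        if size is None:
            draws = ()
        elif isinstance(size, int):
            draws = (size,)
        else:
            draws = tuple(size)
        # Ask the shape-style sampler for (draws + distribution shape).
        return rvs(size=draws + dist_shape)

    return random_callable


def nested_shape(value):
    """Shape of a regular, non-empty nested list (scalars have shape ``()``)."""
    shape = []
    while isinstance(value, list):
        shape.append(len(value))
        value = value[0]
    return tuple(shape)


sampler = as_iid_draws(shape_style_rvs, dist_shape=(3,))
print(nested_shape(sampler(size=10)))  # (10, 3)
```

With a real SciPy sampler the same idea would apply to `rvs(size=...)`, but the exact wrapper needed depends on the distribution's parameters and is left as an assumption here.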

Diff for: docs/source/api/distributions.rst

+1
@@ -9,3 +9,4 @@ Distributions
    distributions/multivariate
    distributions/mixture
    distributions/timeseries
+   distributions/utilities

Diff for: docs/source/api/distributions/utilities.rst

+29
@@ -0,0 +1,29 @@
+*******************************************
+Distribution utility classes and functions
+*******************************************
+
+.. currentmodule:: pymc3.distributions
+.. autosummary::
+
+    Distribution
+    Discrete
+    Continuous
+    NoDistribution
+    DensityDist
+    TensorType
+
+    draw_values
+    generate_samples
+
+
+.. autoclass:: Distribution
+.. autoclass:: Discrete
+.. autoclass:: Continuous
+.. autoclass:: NoDistribution
+.. autoclass:: DensityDist
+    :members:
+.. autofunction:: TensorType
+
+.. autofunction:: draw_values
+.. autofunction:: generate_samples
+

Diff for: pymc3/distributions/distribution.py

+241 -20
@@ -198,29 +198,250 @@ class DensityDist(Distribution):
 
         A distribution with the passed log density function is created.
         Requires a custom random function passed as kwarg `random` to
-        enable sampling.
+        enable prior or posterior predictive sampling.
 
-        Example:
+    """
+
+    def __init__(
+        self,
+        logp,
+        shape=(),
+        dtype=None,
+        testval=0,
+        random=None,
+        wrap_random_with_dist_shape=True,
+        check_shape_in_random=True,
+        *args,
+        **kwargs
+    ):
+        """
+        Parameters
+        ----------
+
+        logp: callable
+            A callable that has the following signature ``logp(value)`` and
+            returns a theano tensor that represents the distribution's log
+            probability density.
+        shape: tuple (Optional): defaults to `()`
+            The shape of the distribution. The default value indicates a scalar.
+            If the distribution is *not* scalar-valued, the programmer should pass
+            a value here.
+        dtype: None, str (Optional)
+            The dtype of the distribution.
+        testval: number or array (Optional)
+            The ``testval`` of the RV's tensor that follows the ``DensityDist``
+            distribution.
+        random: None or callable (Optional)
+            If ``None``, no random method is attached to the ``DensityDist``
+            instance.
+            If a callable, it is used as the distribution's ``random`` method.
+            The behavior of this callable can be altered with the
+            ``wrap_random_with_dist_shape`` parameter.
+            The supplied callable must have the following signature:
+            ``random(size=None, **kwargs)``, where ``size`` is the number of
+            IID draws to take from the distribution. Any extra keyword
+            argument can be added as required.
+        wrap_random_with_dist_shape: bool (Optional)
+            If ``True``, the provided ``random`` callable is passed through
+            ``generate_samples`` to make the random number generator aware of
+            the ``DensityDist`` instance's ``shape``.
+            If ``False``, it is used exactly as it was provided.
+        check_shape_in_random: bool (Optional)
+            If ``True``, the shape of the random samples generated in the
+            ``random`` method is checked against the expected return shape. This
+            test is only performed if ``wrap_random_with_dist_shape is False``.
+        args, kwargs: (Optional)
+            These are passed to the parent class' ``__init__``.
+
+        Note
+        ----
+            If the ``random`` method is wrapped with dist shape, what this
+            means is that the ``random`` callable will be wrapped with the
+            :func:`~generate_samples` function. The distribution's shape will
+            be passed to :func:`~generate_samples` as the ``dist_shape``
+            parameter. Any extra ``kwargs`` provided to ``random`` will be
+            passed as ``not_broadcast_kwargs`` of :func:`~generate_samples`.
+
+        Examples
         --------
-            .. code-block:: python
-                with pm.Model():
-                    mu = pm.Normal('mu',0,1)
-                    normal_dist = pm.Normal.dist(mu, 1)
-                    pm.DensityDist('density_dist', normal_dist.logp, observed=np.random.randn(100), random=normal_dist.random)
-                    trace = pm.sample(100)
+            .. code-block:: python
+
+                with pm.Model():
+                    mu = pm.Normal('mu', 0, 1)
+                    normal_dist = pm.Normal.dist(mu, 1)
+                    pm.DensityDist(
+                        'density_dist',
+                        normal_dist.logp,
+                        observed=np.random.randn(100),
+                        random=normal_dist.random
+                    )
+                    trace = pm.sample(100)
+
+            If the ``DensityDist`` is multidimensional, some care must be taken
+            with the supplied ``random`` method. By default, the supplied random
+            is wrapped by :func:`~generate_samples` to make it aware of the
+            multidimensional distribution's shape.
+            This can be prevented by setting ``wrap_random_with_dist_shape=False``.
+            Furthermore, the ``size`` parameter is interpreted as the number of
+            IID draws to take from this multidimensional distribution.
+
+
+            .. code-block:: python
+
+                with pm.Model():
+                    mu = pm.Normal('mu', 0, 1)
+                    normal_dist = pm.Normal.dist(mu, 1, shape=3)
+                    dens = pm.DensityDist(
+                        'density_dist',
+                        normal_dist.logp,
+                        observed=np.random.randn(100, 3),
+                        shape=3,
+                        random=normal_dist.random,
+                    )
+                    prior = pm.sample_prior_predictive(10)['density_dist']
+                assert prior.shape == (10, 100, 3)
+
+            If ``wrap_random_with_dist_shape=False``, we start to get samples of
+            an incorrect shape. By default, we can try to catch these situations.
 
-    """
 
-    def __init__(self, logp, shape=(), dtype=None, testval=0, random=None, *args, **kwargs):
+            .. code-block:: python
+
+                with pm.Model():
+                    mu = pm.Normal('mu', 0, 1)
+                    normal_dist = pm.Normal.dist(mu, 1, shape=3)
+                    dens = pm.DensityDist(
+                        'density_dist',
+                        normal_dist.logp,
+                        observed=np.random.randn(100, 3),
+                        shape=3,
+                        random=normal_dist.random,
+                        wrap_random_with_dist_shape=False, # Is True by default
+                    )
+                    err = None
+                    try:
+                        prior = pm.sample_prior_predictive(10)['density_dist']
+                    except RuntimeError as e:
+                        err = e
+                    assert isinstance(err, RuntimeError)
+
+            The default catching can be disabled with the
+            ``check_shape_in_random`` parameter.
+
+
+            .. code-block:: python
+
+                with pm.Model():
+                    mu = pm.Normal('mu', 0, 1)
+                    normal_dist = pm.Normal.dist(mu, 1, shape=3)
+                    dens = pm.DensityDist(
+                        'density_dist',
+                        normal_dist.logp,
+                        observed=np.random.randn(100, 3),
+                        shape=3,
+                        random=normal_dist.random,
+                        wrap_random_with_dist_shape=False, # Is True by default
+                        check_shape_in_random=False, # Is True by default
+                    )
+                    prior = pm.sample_prior_predictive(10)['density_dist']
+                # We get samples with an incorrect shape
+                assert prior.shape != (10, 100, 3)
+
+            If you use callables that work with ``scipy.stats`` rvs, you must
+            be aware that their ``size`` parameter is not the number of IID
+            samples to draw from a distribution, but the desired ``shape`` of
+            the returned array of samples. It is the user's responsibility to
+            wrap the callable to make it comply with PyMC3's interpretation
+            of ``size``.
+
+
+            .. code-block:: python
+
+                with pm.Model():
+                    mu = pm.Normal('mu', 0, 1)
+                    normal_dist = pm.Normal.dist(mu, 1, shape=3)
+                    dens = pm.DensityDist(
+                        'density_dist',
+                        normal_dist.logp,
+                        observed=np.random.randn(100, 3),
+                        shape=3,
+                        random=stats.norm.rvs,
+                        pymc3_size_interpretation=False, # Is True by default
+                    )
+                    prior = pm.sample_prior_predictive(10)['density_dist']
+                assert prior.shape == (10, 100, 3)
+
+        """
         if dtype is None:
             dtype = theano.config.floatX
         super().__init__(shape, dtype, testval, *args, **kwargs)
         self.logp = logp
         self.rand = random
+        self.wrap_random_with_dist_shape = wrap_random_with_dist_shape
+        self.check_shape_in_random = check_shape_in_random
 
-    def random(self, *args, **kwargs):
+    def random(self, point=None, size=None, **kwargs):
         if self.rand is not None:
-            return self.rand(*args, **kwargs)
+            not_broadcast_kwargs = dict(point=point)
+            not_broadcast_kwargs.update(**kwargs)
+            if self.wrap_random_with_dist_shape:
+                size = to_tuple(size)
+                with _DrawValuesContextBlocker():
+                    test_draw = generate_samples(
+                        self.rand,
+                        size=None,
+                        not_broadcast_kwargs=not_broadcast_kwargs,
+                    )
+                    test_shape = test_draw.shape
+                if self.shape[:len(size)] == size:
+                    dist_shape = size + self.shape
+                else:
+                    dist_shape = self.shape
+                broadcast_shape = broadcast_dist_samples_shape(
+                    [dist_shape, test_shape],
+                    size=size
+                )
+                broadcast_shape = broadcast_shape[
+                    :len(broadcast_shape) - len(test_shape)
+                ]
+                samples = generate_samples(
+                    self.rand,
+                    broadcast_shape=broadcast_shape,
+                    size=size,
+                    not_broadcast_kwargs=not_broadcast_kwargs,
+                )
+            else:
+                samples = self.rand(point=point, size=size, **kwargs)
+                if self.check_shape_in_random:
+                    expected_shape = (
+                        self.shape
+                        if size is None else
+                        to_tuple(size) + self.shape
+                    )
+                    if not expected_shape == samples.shape:
+                        raise RuntimeError(
+                            "DensityDist encountered a shape inconsistency "
+                            "while drawing samples using the supplied random "
+                            "function. Was expecting to get samples of shape "
+                            "{expected} but got {got} instead.\n"
+                            "Whenever possible wrap_random_with_dist_shape = True "
+                            "is recommended.\n"
+                            "Be aware that the random callable provided as the "
+                            "DensityDist random method cannot "
+                            "adapt to shape changes in the distribution's "
+                            "shape, which sometimes are necessary for sampling "
+                            "when the model uses pymc3.Data or theano shared "
+                            "tensors, or when the DensityDist has observed "
+                            "values.\n"
+                            "This check can be disabled by passing "
+                            "check_shape_in_random=False when the DensityDist "
+                            "is initialized.".
+                            format(
+                                expected=expected_shape,
+                                got=samples.shape,
+                            )
+                        )
+            return samples
         else:
             raise ValueError("Distribution was not passed any random method. "
                              "Define a custom random method and pass it as kwarg random")
@@ -290,17 +511,17 @@ def draw_values(params, point=None, size=None):
     Draw (fix) parameter values. Handles a number of cases:
 
         1) The parameter is a scalar
-        2) The parameter is an *RV
+        2) The parameter is an RV
 
             a) parameter can be fixed to the value in the point
-            b) parameter can be fixed by sampling from the *RV
+            b) parameter can be fixed by sampling from the RV
             c) parameter can be fixed using tag.test_value (last resort)
 
         3) The parameter is a tensor variable/constant. Can be evaluated using
         theano.function, but a variable may contain nodes which
 
             a) are named parameters in the point
-            b) are *RVs with a random method
+            b) are RVs with a random method
     """
     # Get fast drawable values (i.e. things in point or numbers, arrays,
     # constants or shares, or things that were already drawn in related
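
The `draw_values` docstring in the hunk above lists an order of preference for fixing an RV's value: the point, then the RV's random method, then the test value as a last resort. The following toy sketch illustrates only that fallback order for cases 1 and 2. `ToyRV` and `fix_value` are hypothetical names introduced for illustration; they are not PyMC3 classes or functions, and case 3 (evaluating tensor variables) is not covered.

```python
class ToyRV:
    """Hypothetical stand-in for a random variable; not a PyMC3 class."""

    def __init__(self, name, random=None, test_value=None):
        self.name = name
        self.random = random          # optional zero-argument sampler
        self.test_value = test_value  # optional fallback value


def fix_value(param, point=None):
    """Fix one parameter value following the docstring's order of preference."""
    if not isinstance(param, ToyRV):
        return param                  # case 1: a plain scalar
    if point is not None and param.name in point:
        return point[param.name]      # case 2a: value taken from the point
    if param.random is not None:
        return param.random()         # case 2b: sample from the RV
    if param.test_value is not None:
        return param.test_value       # case 2c: last resort
    raise ValueError("Cannot fix a value for {}".format(param.name))


mu = ToyRV("mu", random=lambda: 1.5, test_value=0.0)
print(fix_value(mu, point={"mu": 2.0}))  # 2.0, the point takes precedence
```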
@@ -646,13 +867,13 @@ def generate_samples(generator, *args, **kwargs):
     generator : function
         Function to generate the random samples. The function is
         expected to take parameters for generating samples and
-        a keyword argument `size` which determines the shape
+        a keyword argument ``size`` which determines the shape
         of the samples.
-        The *args and **kwargs (stripped of the keywords below) will be
+        The args and kwargs (stripped of the keywords below) will be
         passed to the generator function.
 
     keyword arguments
-    ~~~~~~~~~~~~~~~~
+    ~~~~~~~~~~~~~~~~~
 
     dist_shape : int or tuple of int
         The shape of the random variable (i.e., the shape attribute).
@@ -666,9 +887,9 @@ def generate_samples(generator, *args, **kwargs):
         the shape of the probabilities in the Categorical distribution.
     not_broadcast_kwargs: dict or None
         Key word argument dictionary to provide to the random generator, which
-        must not be broadcasted with the rest of the *args and **kwargs.
+        must not be broadcasted with the rest of the args and kwargs.
 
-    Any remaining *args and **kwargs are passed on to the generator function.
+    Any remaining args and kwargs are passed on to the generator function.
     """
     dist_shape = kwargs.pop('dist_shape', ())
     one_d = _is_one_d(dist_shape)
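
The shape check that this commit adds to `DensityDist.random` (performed only when `wrap_random_with_dist_shape` is `False`) compares the shape of the returned samples against the IID draws prepended to the distribution's shape. A minimal sketch of that comparison on plain tuples follows. The `to_tuple` below is a simplified, hypothetical stand-in for the helper of the same name used in the diff, and `expected_shape` and `check_samples_shape` are illustrative names that do not exist in PyMC3.

```python
def to_tuple(size):
    """Simplified, hypothetical stand-in for the ``to_tuple`` helper in the diff."""
    if size is None:
        return ()
    if isinstance(size, int):
        return (size,)
    return tuple(size)


def expected_shape(size, dist_shape):
    """Shape expected from ``random``: IID draws first, then the distribution shape."""
    dist_shape = tuple(dist_shape)
    return dist_shape if size is None else to_tuple(size) + dist_shape


def check_samples_shape(samples_shape, size, dist_shape):
    """Raise ``RuntimeError`` when the sampled shape differs from the expected one."""
    expected = expected_shape(size, dist_shape)
    got = tuple(samples_shape)
    if expected != got:
        raise RuntimeError(
            "Was expecting to get samples of shape "
            "{expected} but got {got} instead.".format(expected=expected, got=got)
        )
    return expected


print(expected_shape(10, (100, 3)))  # (10, 100, 3)
```

A callable that treats `size` as the full output shape would return `(10,)` for `size=10`, which this check would reject for a distribution of shape `(100, 3)`.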
