@@ -198,16 +198,7 @@ class DensityDist(Distribution):
198
198
199
199
A distribution with the passed log density function is created.
200
200
Requires a custom random function passed as kwarg `random` to
201
- enable sampling.
202
-
203
- Example:
204
- --------
205
- .. code-block:: python
206
- with pm.Model():
207
- mu = pm.Normal('mu',0,1)
208
- normal_dist = pm.Normal.dist(mu, 1)
209
- pm.DensityDist('density_dist', normal_dist.logp, observed=np.random.randn(100), random=normal_dist.random)
210
- trace = pm.sample(100)
201
+ enable prior or posterior predictive sampling.
211
202
212
203
"""
213
204
@@ -220,6 +211,7 @@ def __init__(
220
211
random = None ,
221
212
wrap_random_with_dist_shape = True ,
222
213
check_shape_in_random = True ,
214
+ pymc3_size_interpretation = True ,
223
215
* args ,
224
216
** kwargs
225
217
):
@@ -246,6 +238,10 @@ def __init__(
246
238
If a callable, it is used as the distribution's ``random`` method.
247
239
The behavior of this callable can be altered with the
248
240
``wrap_random_with_dist_shape`` parameter.
241
+ The supplied callable must have the following signature:
242
+ ``random(size=None, **kwargs)``, where ``size`` is the number of
243
+ IID draws to take from the distribution. Any extra keyword
244
+ argument can be added as required.
249
245
wrap_random_with_dist_shape: bool (Optional)
250
246
If ``True``, the provided ``random`` callable is passed through
251
247
``generate_samples`` to make the random number generator aware of
@@ -255,15 +251,190 @@ def __init__(
255
251
If ``True``, the shape of the random samples generated in the
256
252
``random`` method is checked with the expected return shape. This
257
253
test is only performed if ``wrap_random_with_dist_shape is False``.
254
+ pymc3_size_interpretation: bool (Optional)
255
+ This flag affects how the ``size`` parameter supplied to the
256
+ passed ``random`` function is interpreted. If no ``random``
257
+ callable is supplied, this flag is ignored. Furthermore, this flag
258
+ is only used if ``wrap_random_with_dist_shape`` is ``True``.
259
+ If ``True``, the ``size`` parameter of ``random`` is interpreted
260
+ as the number of IID draws to take from a given distribution,
261
+ which is how every pymc3 distribution interprets the ``size``
262
+ parameter.
263
+ If it is ``False``, the ``size`` parameter is used just like in
264
+ any scipy random variate generator, i.e. the shape of the returned
265
+ array of samples.
266
+ The difference is subtle and is only relevant in multidimensional
267
+ distributions. Consider that a multivariate normal of rank 3 is
268
+ used as the random number generator. With the pymc3 interpretation
269
+ of ``size``, a call like ``random(size=10)`` will return an array
270
+ with shape ``(10, 3)``. With the scipy interpretation, to get a
271
+ similarly shaped result, the call must be changed to
272
+ ``random(size=(10, 3))``.
273
+ A quick rule of thumb is that if the supplied ``random`` callable
274
+ is a pymc3 distribution's random method, you should set this flag
275
+ to ``True``. If you use a scipy rvs, you should set this flag to
276
+ ``False``. Refer to the examples for more information.
258
277
args, kwargs: (Optional)
259
278
These are passed to the parent class' ``__init__``.
279
+
260
280
Note
261
281
----
262
- If the ``random`` method is wrapped with dist shape, what this means
263
- is that the ``random`` callable will be wrapped with the
282
+ If the ``random`` method is wrapped with dist shape, what this
283
+ means is that the ``random`` callable will be wrapped with the
264
284
:func:`~generate_samples` function. The distribution's shape will
265
285
be passed to :func:`~generate_samples` as the ``dist_shape``
266
- parameter.
286
+ parameter. Any extra ``kwargs`` provided to ``random`` will be
287
+ passed as ``not_broadcast_kwargs`` of :func:`~generate_samples`.
288
+
289
+ Examples
290
+ --------
291
+ .. code-block:: python
292
+
293
+ with pm.Model():
294
+ mu = pm.Normal('mu',0,1)
295
+ normal_dist = pm.Normal.dist(mu, 1)
296
+ pm.DensityDist(
297
+ 'density_dist',
298
+ normal_dist.logp,
299
+ observed=np.random.randn(100),
300
+ random=normal_dist.random
301
+ )
302
+ trace = pm.sample(100)
303
+
304
+ If the ``DensityDist`` is multidimensional, some care must be taken
305
+ with the supplied ``random`` method. By default, the supplied random
306
+ is wrapped by :func:`~generate_samples` to make it aware of the
307
+ multidimensional distribution's shape.
308
+ This can be prevented by setting ``wrap_random_with_dist_shape=False``.
309
+ Furthermore, the ``size`` parameter is interpreted as the number of
310
+ IID draws to take from this multidimensional distribution.
311
+
312
+
313
+ .. code-block:: python
314
+
315
+ with pm.Model():
316
+ mu = pm.Normal('mu', 0 , 1)
317
+ normal_dist = pm.Normal.dist(mu, 1, shape=3)
318
+ dens = pm.DensityDist(
319
+ 'density_dist',
320
+ normal_dist.logp,
321
+ observed=np.random.randn(100, 3),
322
+ shape=3,
323
+ random=normal_dist.random,
324
+ )
325
+ prior = pm.sample_prior_predictive(10)['density_dist']
326
+ assert prior.shape == (10, 100, 3)
327
+
328
+ If ``wrap_random_with_dist_shape=False``, we start to get samples of
329
+ an incorrect shape. By default, we can try to catch these situations.
330
+
331
+
332
+ .. code-block:: python
333
+
334
+ with pm.Model():
335
+ mu = pm.Normal('mu', 0 , 1)
336
+ normal_dist = pm.Normal.dist(mu, 1, shape=3)
337
+ dens = pm.DensityDist(
338
+ 'density_dist',
339
+ normal_dist.logp,
340
+ observed=np.random.randn(100, 3),
341
+ shape=3,
342
+ random=normal_dist.random,
343
+ wrap_random_with_dist_shape=False, # Is True by default
344
+ )
345
+ err = None
346
+ try:
347
+ prior = pm.sample_prior_predictive(10)['density_dist']
348
+ except RuntimeError as e:
349
+ err = e
350
+ assert isinstance(err, RuntimeError)
351
+
352
+ The default catching can be disabled with the
353
+ ``check_shape_in_random`` parameter.
354
+
355
+
356
+ .. code-block:: python
357
+
358
+ with pm.Model():
359
+ mu = pm.Normal('mu', 0 , 1)
360
+ normal_dist = pm.Normal.dist(mu, 1, shape=3)
361
+ dens = pm.DensityDist(
362
+ 'density_dist',
363
+ normal_dist.logp,
364
+ observed=np.random.randn(100, 3),
365
+ shape=3,
366
+ random=normal_dist.random,
367
+ wrap_random_with_dist_shape=False, # Is True by default
368
+ check_shape_in_random=False, # Is True by default
369
+ )
370
+ prior = pm.sample_prior_predictive(10)['density_dist']
371
+ # We get samples with an incorrect shape
372
+ assert prior.shape != (10, 100, 3)
373
+
374
+ The default catching can be disabled with the
375
+ ``check_shape_in_random`` parameter.
376
+
377
+
378
+ .. code-block:: python
379
+
380
+ with pm.Model():
381
+ mu = pm.Normal('mu', 0 , 1)
382
+ normal_dist = pm.Normal.dist(mu, 1, shape=3)
383
+ dens = pm.DensityDist(
384
+ 'density_dist',
385
+ normal_dist.logp,
386
+ observed=np.random.randn(100, 3),
387
+ shape=3,
388
+ random=normal_dist.random,
389
+ wrap_random_with_dist_shape=False, # Is True by default
390
+ check_shape_in_random=False, # Is True by default
391
+ )
392
+ prior = pm.sample_prior_predictive(10)['density_dist']
393
+ # We get samples with an incorrect shape
394
+ assert prior.shape != (10, 100, 3)
395
+
396
+ One final word of caution. If you use a pymc3 distribution's random
397
+ method, you should stick with ``pymc3_size_interpretation=True``, or
398
+ you will get incorrectly shaped samples.
399
+
400
+
401
+ .. code-block:: python
402
+
403
+ with pm.Model():
404
+ mu = pm.Normal('mu', 0 , 1)
405
+ normal_dist = pm.Normal.dist(mu, 1, shape=3)
406
+ dens = pm.DensityDist(
407
+ 'density_dist',
408
+ normal_dist.logp,
409
+ observed=np.random.randn(100, 3),
410
+ shape=3,
411
+ random=normal_dist.random,
412
+ pymc3_size_interpretation=False, # Is True by default
413
+ )
414
+ prior = pm.sample_prior_predictive(10)['density_dist']
415
+ assert prior.shape != (10, 100, 3)
416
+ assert prior.shape == (10, 100, 3, 3)
417
+
418
+ If you use callables that work with ``scipy.stats`` rvs, you should
419
+ set ``pymc3_size_interpretation=False``.
420
+
421
+
422
+ .. code-block:: python
423
+
424
+ with pm.Model():
425
+ mu = pm.Normal('mu', 0 , 1)
426
+ normal_dist = pm.Normal.dist(mu, 1, shape=3)
427
+ dens = pm.DensityDist(
428
+ 'density_dist',
429
+ normal_dist.logp,
430
+ observed=np.random.randn(100, 3),
431
+ shape=3,
432
+ random=stats.norm.rvs,
433
+ pymc3_size_interpretation=False, # Is True by default
434
+ )
435
+ prior = pm.sample_prior_predictive(10)['density_dist']
436
+ assert prior.shape == (10, 100, 3)
437
+
267
438
"""
268
439
if dtype is None :
269
440
dtype = theano .config .floatX
@@ -272,20 +443,41 @@ def __init__(
272
443
self .rand = random
273
444
self .wrap_random_with_dist_shape = wrap_random_with_dist_shape
274
445
self .check_shape_in_random = check_shape_in_random
446
+ self .pymc3_size_interpretation = pymc3_size_interpretation
275
447
276
- def random (self , * args , ** kwargs ):
448
+ def random (self , point = None , size = None , ** kwargs ):
277
449
if self .rand is not None :
278
450
if self .wrap_random_with_dist_shape :
451
+ size = to_tuple (size )
452
+ with _DrawValuesContextBlocker ():
453
+ test_draw = generate_samples (
454
+ self .rand ,
455
+ size = None ,
456
+ not_broadcast_kwargs = kwargs ,
457
+ )
458
+ test_shape = test_draw .shape
459
+ if self .shape [:len (size )] == size :
460
+ dist_shape = size + self .shape
461
+ else :
462
+ dist_shape = self .shape
463
+ broadcast_shape = broadcast_dist_samples_shape (
464
+ [dist_shape , test_shape ],
465
+ size = size
466
+ )
467
+ if self .pymc3_size_interpretation :
468
+ len (broadcast_shape ) - len (test_shape )
469
+ broadcast_shape = broadcast_shape [
470
+ :len (broadcast_shape ) - len (test_shape )
471
+ ]
279
472
samples = generate_samples (
280
- self .rand , dist_shape = self .shape , * args , ** kwargs
473
+ self .rand ,
474
+ broadcast_shape = broadcast_shape ,
475
+ size = size ,
476
+ not_broadcast_kwargs = kwargs ,
281
477
)
282
478
else :
283
- samples = self .rand (* args , ** kwargs )
479
+ samples = self .rand (size = size , ** kwargs )
284
480
if self .check_shape_in_random :
285
- try :
286
- size = args [1 ]
287
- except IndexError :
288
- size = kwargs .get ("size" , None )
289
481
expected_shape = (
290
482
self .shape
291
483
if size is None else
@@ -384,17 +576,17 @@ def draw_values(params, point=None, size=None):
384
576
Draw (fix) parameter values. Handles a number of cases:
385
577
386
578
1) The parameter is a scalar
387
- 2) The parameter is an * RV
579
+ 2) The parameter is an RV
388
580
389
581
a) parameter can be fixed to the value in the point
390
- b) parameter can be fixed by sampling from the * RV
582
+ b) parameter can be fixed by sampling from the RV
391
583
c) parameter can be fixed using tag.test_value (last resort)
392
584
393
585
3) The parameter is a tensor variable/constant. Can be evaluated using
394
586
theano.function, but a variable may contain nodes which
395
587
396
588
a) are named parameters in the point
397
- b) are * RVs with a random method
589
+ b) are RVs with a random method
398
590
"""
399
591
# Get fast drawable values (i.e. things in point or numbers, arrays,
400
592
# constants or shares, or things that were already drawn in related
@@ -740,13 +932,13 @@ def generate_samples(generator, *args, **kwargs):
740
932
generator : function
741
933
Function to generate the random samples. The function is
742
934
expected to take parameters for generating samples and
743
- a keyword argument `size` which determines the shape
935
+ a keyword argument ``size`` which determines the shape
744
936
of the samples.
745
- The * args and ** kwargs (stripped of the keywords below) will be
937
+ The args and kwargs (stripped of the keywords below) will be
746
938
passed to the generator function.
747
939
748
940
keyword arguments
749
- ~~~~~~~~~~~~~~~~
941
+ ~~~~~~~~~~~~~~~~~
750
942
751
943
dist_shape : int or tuple of int
752
944
The shape of the random variable (i.e., the shape attribute).
@@ -760,9 +952,9 @@ def generate_samples(generator, *args, **kwargs):
760
952
the shape of the probabilities in the Categorical distribution.
761
953
not_broadcast_kwargs: dict or None
762
954
Key word argument dictionary to provide to the random generator, which
763
- must not be broadcasted with the rest of the * args and ** kwargs.
955
+ must not be broadcasted with the rest of the args and kwargs.
764
956
765
- Any remaining * args and ** kwargs are passed on to the generator function.
957
+ Any remaining args and kwargs are passed on to the generator function.
766
958
"""
767
959
dist_shape = kwargs .pop ('dist_shape' , ())
768
960
one_d = _is_one_d (dist_shape )
0 commit comments