@@ -198,29 +198,250 @@ class DensityDist(Distribution):
198
198
199
199
A distribution with the passed log density function is created.
200
200
Requires a custom random function passed as kwarg `random` to
201
- enable sampling.
201
+ enable prior or posterior predictive sampling.
202
202
203
- Example:
203
+ """
204
+
205
+ def __init__ (
206
+ self ,
207
+ logp ,
208
+ shape = (),
209
+ dtype = None ,
210
+ testval = 0 ,
211
+ random = None ,
212
+ wrap_random_with_dist_shape = True ,
213
+ check_shape_in_random = True ,
214
+ * args ,
215
+ ** kwargs
216
+ ):
217
+ """
218
+ Parameters
219
+ ----------
220
+
221
+ logp: callable
222
+ A callable that has the following signature ``logp(value)`` and
223
+ returns a theano tensor that represents the distribution's log
224
+ probability density.
225
+ shape: tuple (Optional): defaults to `()`
226
+ The shape of the distribution. The default value indicates a scalar.
227
+ If the distribution is *not* scalar-valued, the programmer should pass
228
+ a value here.
229
+ dtype: None, str (Optional)
230
+ The dtype of the distribution.
231
+ testval: number or array (Optional)
232
+ The ``testval`` of the RV's tensor that follow the ``DensityDist``
233
+ distribution.
234
+ random: None or callable (Optional)
235
+ If ``None``, no random method is attached to the ``DensityDist``
236
+ instance.
237
+ If a callable, it is used as the distribution's ``random`` method.
238
+ The behavior of this callable can be altered with the
239
+ ``wrap_random_with_dist_shape`` parameter.
240
+ The supplied callable must have the following signature:
241
+ ``random(size=None, **kwargs)``, where ``size`` is the number of
242
+ IID draws to take from the distribution. Any extra keyword
243
+ argument can be added as required.
244
+ wrap_random_with_dist_shape: bool (Optional)
245
+ If ``True``, the provided ``random`` callable is passed through
246
+ ``generate_samples`` to make the random number generator aware of
247
+ the ``DensityDist`` instance's ``shape``.
248
+ If ``False``, it is used exactly as it was provided.
249
+ check_shape_in_random: bool (Optional)
250
+ If ``True``, the shape of the random samples generate in the
251
+ ``random`` method is checked with the expected return shape. This
252
+ test is only performed if ``wrap_random_with_dist_shape is False``.
253
+ args, kwargs: (Optional)
254
+ These are passed to the parent class' ``__init__``.
255
+
256
+ Note
257
+ ----
258
+ If the ``random`` method is wrapped with dist shape, what this
259
+ means is that the ``random`` callable will be wrapped with the
260
+ :func:`~genereate_samples` function. The distribution's shape will
261
+ be passed to :func:`~generate_samples` as the ``dist_shape``
262
+ parameter. Any extra ``kwargs`` provided to ``random`` will be
263
+ passed as ``not_broadcast_kwargs`` of :func:`~generate_samples`.
264
+
265
+ Examples
204
266
--------
205
- .. code-block:: python
206
- with pm.Model():
207
- mu = pm.Normal('mu',0,1)
208
- normal_dist = pm.Normal.dist(mu, 1)
209
- pm.DensityDist('density_dist', normal_dist.logp, observed=np.random.randn(100), random=normal_dist.random)
210
- trace = pm.sample(100)
267
+ .. code-block:: python
268
+
269
+ with pm.Model():
270
+ mu = pm.Normal('mu',0,1)
271
+ normal_dist = pm.Normal.dist(mu, 1)
272
+ pm.DensityDist(
273
+ 'density_dist',
274
+ normal_dist.logp,
275
+ observed=np.random.randn(100),
276
+ random=normal_dist.random
277
+ )
278
+ trace = pm.sample(100)
279
+
280
+ If the ``DensityDist`` is multidimensional, some care must be taken
281
+ with the supplied ``random`` method. By default, the supplied random
282
+ is wrapped by :func:`~generate_samples` to make it aware of the
283
+ multidimensional distribution's shape.
284
+ This can be prevented setting ``wrap_random_with_dist_shape=False``.
285
+ Furthermore, the ``size`` parameter is interpreted as the number of
286
+ IID draws to take from this multidimensional distribution.
287
+
288
+
289
+ .. code-block:: python
290
+
291
+ with pm.Model():
292
+ mu = pm.Normal('mu', 0 , 1)
293
+ normal_dist = pm.Normal.dist(mu, 1, shape=3)
294
+ dens = pm.DensityDist(
295
+ 'density_dist',
296
+ normal_dist.logp,
297
+ observed=np.random.randn(100, 3),
298
+ shape=3,
299
+ random=normal_dist.random,
300
+ )
301
+ prior = pm.sample_prior_predictive(10)['density_dist']
302
+ assert prior.shape == (10, 100, 3)
303
+
304
+ If ``wrap_random_with_dist_shape=False``, we start to get samples of
305
+ an incorrect shape. By default, we can try to catch these situations.
211
306
212
- """
213
307
214
- def __init__ (self , logp , shape = (), dtype = None , testval = 0 , random = None , * args , ** kwargs ):
308
+ .. code-block:: python
309
+
310
+ with pm.Model():
311
+ mu = pm.Normal('mu', 0 , 1)
312
+ normal_dist = pm.Normal.dist(mu, 1, shape=3)
313
+ dens = pm.DensityDist(
314
+ 'density_dist',
315
+ normal_dist.logp,
316
+ observed=np.random.randn(100, 3),
317
+ shape=3,
318
+ random=normal_dist.random,
319
+ wrap_random_with_dist_shape=False, # Is True by default
320
+ )
321
+ err = None
322
+ try:
323
+ prior = pm.sample_prior_predictive(10)['density_dist']
324
+ except RuntimeError as e:
325
+ err = e
326
+ assert isinstance(err, RuntimeError)
327
+
328
+ The default catching can be disabled with the
329
+ ``check_shape_in_random`` parameter.
330
+
331
+
332
+ .. code-block:: python
333
+
334
+ with pm.Model():
335
+ mu = pm.Normal('mu', 0 , 1)
336
+ normal_dist = pm.Normal.dist(mu, 1, shape=3)
337
+ dens = pm.DensityDist(
338
+ 'density_dist',
339
+ normal_dist.logp,
340
+ observed=np.random.randn(100, 3),
341
+ shape=3,
342
+ random=normal_dist.random,
343
+ wrap_random_with_dist_shape=False, # Is True by default
344
+ check_shape_in_random=False, # Is True by default
345
+ )
346
+ prior = pm.sample_prior_predictive(10)['density_dist']
347
+ # We get samples with an incorrect shape
348
+ assert prior.shape != (10, 100, 3)
349
+
350
+ If you use callables that work with ``scipy.stats`` rvs, you must
351
+ be aware that their ``size`` parameter is not the number of IID
352
+ samples to draw from a distribution, but the desired ``shape`` of
353
+ the returned array of samples. It is the user's responsibility to
354
+ wrap the callable to make it comply with PyMC3's interpretation
355
+ of ``size``.
356
+
357
+
358
+ .. code-block:: python
359
+
360
+ with pm.Model():
361
+ mu = pm.Normal('mu', 0 , 1)
362
+ normal_dist = pm.Normal.dist(mu, 1, shape=3)
363
+ dens = pm.DensityDist(
364
+ 'density_dist',
365
+ normal_dist.logp,
366
+ observed=np.random.randn(100, 3),
367
+ shape=3,
368
+ random=stats.norm.rvs,
369
+ pymc3_size_interpretation=False, # Is True by default
370
+ )
371
+ prior = pm.sample_prior_predictive(10)['density_dist']
372
+ assert prior.shape == (10, 100, 3)
373
+
374
+ """
215
375
if dtype is None :
216
376
dtype = theano .config .floatX
217
377
super ().__init__ (shape , dtype , testval , * args , ** kwargs )
218
378
self .logp = logp
219
379
self .rand = random
380
+ self .wrap_random_with_dist_shape = wrap_random_with_dist_shape
381
+ self .check_shape_in_random = check_shape_in_random
220
382
221
- def random (self , * args , ** kwargs ):
383
+ def random (self , point = None , size = None , ** kwargs ):
222
384
if self .rand is not None :
223
- return self .rand (* args , ** kwargs )
385
+ not_broadcast_kwargs = dict (point = point )
386
+ not_broadcast_kwargs .update (** kwargs )
387
+ if self .wrap_random_with_dist_shape :
388
+ size = to_tuple (size )
389
+ with _DrawValuesContextBlocker ():
390
+ test_draw = generate_samples (
391
+ self .rand ,
392
+ size = None ,
393
+ not_broadcast_kwargs = not_broadcast_kwargs ,
394
+ )
395
+ test_shape = test_draw .shape
396
+ if self .shape [:len (size )] == size :
397
+ dist_shape = size + self .shape
398
+ else :
399
+ dist_shape = self .shape
400
+ broadcast_shape = broadcast_dist_samples_shape (
401
+ [dist_shape , test_shape ],
402
+ size = size
403
+ )
404
+ broadcast_shape = broadcast_shape [
405
+ :len (broadcast_shape ) - len (test_shape )
406
+ ]
407
+ samples = generate_samples (
408
+ self .rand ,
409
+ broadcast_shape = broadcast_shape ,
410
+ size = size ,
411
+ not_broadcast_kwargs = not_broadcast_kwargs ,
412
+ )
413
+ else :
414
+ samples = self .rand (point = point , size = size , ** kwargs )
415
+ if self .check_shape_in_random :
416
+ expected_shape = (
417
+ self .shape
418
+ if size is None else
419
+ to_tuple (size ) + self .shape
420
+ )
421
+ if not expected_shape == samples .shape :
422
+ raise RuntimeError (
423
+ "DensityDist encountered a shape inconsistency "
424
+ "while drawing samples using the supplied random "
425
+ "function. Was expecting to get samples of shape "
426
+ "{expected} but got {got} instead.\n "
427
+ "Whenever possible wrap_random_with_dist_shape = True "
428
+ "is recommended.\n "
429
+ "Be aware that the random callable provided as the "
430
+ "DensityDist random method cannot "
431
+ "adapt to shape changes in the distribution's "
432
+ "shape, which sometimes are necessary for sampling "
433
+ "when the model uses pymc3.Data or theano shared "
434
+ "tensors, or when the DensityDist has observed "
435
+ "values.\n "
436
+ "This check can be disabled by passing "
437
+ "check_shape_in_random=False when the DensityDist "
438
+ "is initialized." .
439
+ format (
440
+ expected = expected_shape ,
441
+ got = samples .shape ,
442
+ )
443
+ )
444
+ return samples
224
445
else :
225
446
raise ValueError ("Distribution was not passed any random method "
226
447
"Define a custom random method and pass it as kwarg random" )
@@ -290,17 +511,17 @@ def draw_values(params, point=None, size=None):
290
511
Draw (fix) parameter values. Handles a number of cases:
291
512
292
513
1) The parameter is a scalar
293
- 2) The parameter is an * RV
514
+ 2) The parameter is an RV
294
515
295
516
a) parameter can be fixed to the value in the point
296
- b) parameter can be fixed by sampling from the * RV
517
+ b) parameter can be fixed by sampling from the RV
297
518
c) parameter can be fixed using tag.test_value (last resort)
298
519
299
520
3) The parameter is a tensor variable/constant. Can be evaluated using
300
521
theano.function, but a variable may contain nodes which
301
522
302
523
a) are named parameters in the point
303
- b) are * RVs with a random method
524
+ b) are RVs with a random method
304
525
"""
305
526
# Get fast drawable values (i.e. things in point or numbers, arrays,
306
527
# constants or shares, or things that were already drawn in related
@@ -646,13 +867,13 @@ def generate_samples(generator, *args, **kwargs):
646
867
generator : function
647
868
Function to generate the random samples. The function is
648
869
expected take parameters for generating samples and
649
- a keyword argument `size` which determines the shape
870
+ a keyword argument `` size` ` which determines the shape
650
871
of the samples.
651
- The * args and ** kwargs (stripped of the keywords below) will be
872
+ The args and kwargs (stripped of the keywords below) will be
652
873
passed to the generator function.
653
874
654
875
keyword arguments
655
- ~~~~~~~~~~~~~~~~
876
+ ~~~~~~~~~~~~~~~~~
656
877
657
878
dist_shape : int or tuple of int
658
879
The shape of the random variable (i.e., the shape attribute).
@@ -666,9 +887,9 @@ def generate_samples(generator, *args, **kwargs):
666
887
the shape of the probabilities in the Categorical distribution.
667
888
not_broadcast_kwargs: dict or None
668
889
Key word argument dictionary to provide to the random generator, which
669
- must not be broadcasted with the rest of the * args and ** kwargs.
890
+ must not be broadcasted with the rest of the args and kwargs.
670
891
671
- Any remaining * args and ** kwargs are passed on to the generator function.
892
+ Any remaining args and kwargs are passed on to the generator function.
672
893
"""
673
894
dist_shape = kwargs .pop ('dist_shape' , ())
674
895
one_d = _is_one_d (dist_shape )
0 commit comments