Skip to content

Commit 16a1d76

Browse files
lucianopazrpgoldman
authored andcommitted
Added point parameter to rand call
1 parent 8f74ea9 commit 16a1d76

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

Diff for: pymc3/distributions/distribution.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -382,13 +382,15 @@ def __init__(
382382

383383
def random(self, point=None, size=None, **kwargs):
384384
if self.rand is not None:
385+
not_broadcast_kwargs = dict(point=point)
386+
not_broadcast_kwargs.update(**kwargs)
385387
if self.wrap_random_with_dist_shape:
386388
size = to_tuple(size)
387389
with _DrawValuesContextBlocker():
388390
test_draw = generate_samples(
389391
self.rand,
390392
size=None,
391-
not_broadcast_kwargs=kwargs,
393+
not_broadcast_kwargs=not_broadcast_kwargs,
392394
)
393395
test_shape = test_draw.shape
394396
if self.shape[:len(size)] == size:
@@ -406,10 +408,10 @@ def random(self, point=None, size=None, **kwargs):
406408
self.rand,
407409
broadcast_shape=broadcast_shape,
408410
size=size,
409-
not_broadcast_kwargs=kwargs,
411+
not_broadcast_kwargs=not_broadcast_kwargs,
410412
)
411413
else:
412-
samples = self.rand(size=size, **kwargs)
414+
samples = self.rand(point=point, size=size, **kwargs)
413415
if self.check_shape_in_random:
414416
expected_shape = (
415417
self.shape

0 commit comments

Comments
 (0)