Skip to content

Commit bdc6e85

Browse files
committed
Add expand kwarg to Censored.change_size
1 parent 9055b1e commit bdc6e85

File tree

2 files changed

+11
-1
lines changed

2 files changed

+11
-1
lines changed

pymc/distributions/censored.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -100,11 +100,12 @@ def ndim_supp(cls, *dist_params):
100100
return 0
101101

102102
@classmethod
103-
def change_size(cls, rv, new_size):
103+
def change_size(cls, rv, new_size, expand=False):
104104
dist_node = rv.tag.dist.owner
105105
lower = rv.tag.lower
106106
upper = rv.tag.upper
107107
rng, old_size, dtype, *dist_params = dist_node.inputs
108+
new_size = new_size if not expand else tuple(new_size) + tuple(old_size)
108109
new_dist = dist_node.op.make_node(rng, new_size, dtype, *dist_params).default_output()
109110
return cls.rv_op(new_dist, lower, upper)
110111

pymc/tests/test_distributions.py

+9
Original file line numberDiff line numberDiff line change
@@ -3357,6 +3357,15 @@ def test_censored_invalid_dist(self):
33573357
):
33583358
x = pm.Censored("x", registered_dist, lower=None, upper=None)
33593359

3360+
def test_change_size(self):
3361+
base_dist = pm.Censored.dist(pm.Normal.dist(), -1, 1, size=(3, 2))
3362+
3363+
new_dist = pm.Censored.change_size(base_dist, (4,))
3364+
assert new_dist.eval().shape == (4,)
3365+
3366+
new_dist = pm.Censored.change_size(base_dist, (4,), expand=True)
3367+
assert new_dist.eval().shape == (4, 3, 2)
3368+
33603369

33613370
class TestLKJCholeskCov:
33623371
def test_dist(self):

0 commit comments

Comments
 (0)