Skip to content

Commit 28bac77

Browse files
ArmavicaricardoV94
authored andcommitted
Remove test redundant with test_mixture.py
1 parent ab680a5 commit 28bac77

File tree

1 file changed

+1
-59
lines changed

1 file changed

+1
-59
lines changed

pymc/tests/logprob/test_joint_logprob.py

Lines changed: 1 addition & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@
6262
joint_logp,
6363
)
6464
from pymc.logprob.utils import rvs_to_value_vars, walk_model
65-
from pymc.tests.helpers import assert_no_rvs, select_by_precision
65+
from pymc.tests.helpers import assert_no_rvs
6666
from pymc.tests.logprob.utils import joint_logprob
6767

6868

@@ -409,64 +409,6 @@ def test_joint_logp_incsubtensor(indices, size):
409409
np.testing.assert_almost_equal(logp_vals, exp_obs_logps)
410410

411411

412-
def test_joint_logp_subtensor():
413-
"""Make sure we can compute a log-likelihood for ``Y[I]`` where ``Y`` and ``I`` are random variables."""
414-
415-
size = 5
416-
417-
mu_base = pm.floatX(np.power(10, np.arange(np.prod(size)))).reshape(size)
418-
mu = np.stack([mu_base, -mu_base])
419-
sigma = 0.001
420-
rng = pytensor.shared(np.random.RandomState(232), borrow=True)
421-
422-
A_rv = pm.Normal.dist(mu, sigma, rng=rng)
423-
A_rv.name = "A"
424-
425-
p = 0.5
426-
427-
I_rv = pm.Bernoulli.dist(p, size=size, rng=rng)
428-
I_rv.name = "I"
429-
430-
A_idx = A_rv[I_rv, at.ogrid[A_rv.shape[-1] :]]
431-
432-
assert isinstance(A_idx.owner.op, (Subtensor, AdvancedSubtensor, AdvancedSubtensor1))
433-
434-
A_idx_value_var = A_idx.type()
435-
A_idx_value_var.name = "A_idx_value"
436-
437-
I_value_var = I_rv.type()
438-
I_value_var.name = "I_value"
439-
440-
A_idx_logps = joint_logp(
441-
(A_idx, I_rv),
442-
rvs_to_values={A_idx: A_idx_value_var, I_rv: I_value_var},
443-
rvs_to_transforms={},
444-
rvs_to_total_sizes={},
445-
)
446-
A_idx_logp = at.add(*A_idx_logps)
447-
448-
logp_vals_fn = pytensor.function([A_idx_value_var, I_value_var], A_idx_logp)
449-
450-
# The compiled graph should not contain any `RandomVariables`
451-
assert_no_rvs(logp_vals_fn.maker.fgraph.outputs[0])
452-
453-
decimals = select_by_precision(float64=6, float32=4)
454-
455-
for i in range(10):
456-
bern_sp = sp.bernoulli(p)
457-
I_value = bern_sp.rvs(size=size).astype(I_rv.dtype)
458-
459-
norm_sp = sp.norm(mu[I_value, np.ogrid[mu.shape[1] :]], sigma)
460-
A_idx_value = norm_sp.rvs().astype(A_idx.dtype)
461-
462-
exp_obs_logps = norm_sp.logpdf(A_idx_value)
463-
exp_obs_logps += bern_sp.logpmf(I_value)
464-
465-
logp_vals = logp_vals_fn(A_idx_value, I_value)
466-
467-
np.testing.assert_almost_equal(logp_vals, exp_obs_logps, decimal=decimals)
468-
469-
470412
def test_logp_helper():
471413
value = at.vector("value")
472414
x = pm.Normal.dist(0, 1)

0 commit comments

Comments
 (0)