|
62 | 62 | joint_logp,
|
63 | 63 | )
|
64 | 64 | from pymc.logprob.utils import rvs_to_value_vars, walk_model
|
65 |
| -from pymc.tests.helpers import assert_no_rvs, select_by_precision |
| 65 | +from pymc.tests.helpers import assert_no_rvs |
66 | 66 | from pymc.tests.logprob.utils import joint_logprob
|
67 | 67 |
|
68 | 68 |
|
@@ -409,64 +409,6 @@ def test_joint_logp_incsubtensor(indices, size):
|
409 | 409 | np.testing.assert_almost_equal(logp_vals, exp_obs_logps)
|
410 | 410 |
|
411 | 411 |
|
412 |
| -def test_joint_logp_subtensor(): |
413 |
| - """Make sure we can compute a log-likelihood for ``Y[I]`` where ``Y`` and ``I`` are random variables.""" |
414 |
| - |
415 |
| - size = 5 |
416 |
| - |
417 |
| - mu_base = pm.floatX(np.power(10, np.arange(np.prod(size)))).reshape(size) |
418 |
| - mu = np.stack([mu_base, -mu_base]) |
419 |
| - sigma = 0.001 |
420 |
| - rng = pytensor.shared(np.random.RandomState(232), borrow=True) |
421 |
| - |
422 |
| - A_rv = pm.Normal.dist(mu, sigma, rng=rng) |
423 |
| - A_rv.name = "A" |
424 |
| - |
425 |
| - p = 0.5 |
426 |
| - |
427 |
| - I_rv = pm.Bernoulli.dist(p, size=size, rng=rng) |
428 |
| - I_rv.name = "I" |
429 |
| - |
430 |
| - A_idx = A_rv[I_rv, at.ogrid[A_rv.shape[-1] :]] |
431 |
| - |
432 |
| - assert isinstance(A_idx.owner.op, (Subtensor, AdvancedSubtensor, AdvancedSubtensor1)) |
433 |
| - |
434 |
| - A_idx_value_var = A_idx.type() |
435 |
| - A_idx_value_var.name = "A_idx_value" |
436 |
| - |
437 |
| - I_value_var = I_rv.type() |
438 |
| - I_value_var.name = "I_value" |
439 |
| - |
440 |
| - A_idx_logps = joint_logp( |
441 |
| - (A_idx, I_rv), |
442 |
| - rvs_to_values={A_idx: A_idx_value_var, I_rv: I_value_var}, |
443 |
| - rvs_to_transforms={}, |
444 |
| - rvs_to_total_sizes={}, |
445 |
| - ) |
446 |
| - A_idx_logp = at.add(*A_idx_logps) |
447 |
| - |
448 |
| - logp_vals_fn = pytensor.function([A_idx_value_var, I_value_var], A_idx_logp) |
449 |
| - |
450 |
| - # The compiled graph should not contain any `RandomVariables` |
451 |
| - assert_no_rvs(logp_vals_fn.maker.fgraph.outputs[0]) |
452 |
| - |
453 |
| - decimals = select_by_precision(float64=6, float32=4) |
454 |
| - |
455 |
| - for i in range(10): |
456 |
| - bern_sp = sp.bernoulli(p) |
457 |
| - I_value = bern_sp.rvs(size=size).astype(I_rv.dtype) |
458 |
| - |
459 |
| - norm_sp = sp.norm(mu[I_value, np.ogrid[mu.shape[1] :]], sigma) |
460 |
| - A_idx_value = norm_sp.rvs().astype(A_idx.dtype) |
461 |
| - |
462 |
| - exp_obs_logps = norm_sp.logpdf(A_idx_value) |
463 |
| - exp_obs_logps += bern_sp.logpmf(I_value) |
464 |
| - |
465 |
| - logp_vals = logp_vals_fn(A_idx_value, I_value) |
466 |
| - |
467 |
| - np.testing.assert_almost_equal(logp_vals, exp_obs_logps, decimal=decimals) |
468 |
| - |
469 |
| - |
470 | 412 | def test_logp_helper():
|
471 | 413 | value = at.vector("value")
|
472 | 414 | x = pm.Normal.dist(0, 1)
|
|
0 commit comments