scan over discrete latent variables causes tracer leak #1998

Open
fehiepsi opened this issue Mar 6, 2025 · 1 comment · May be fixed by #2002
Labels
bug Something isn't working

Comments

@fehiepsi
Member

fehiepsi commented Mar 6, 2025

Bug Description

This is one of the issues reported in #1981. Running the following test raises an error (or is reported as an expected failure, xfail).

Steps to Reproduce

JAX_CHECK_TRACER_LEAKS=1 pytest -vs test/contrib/test_infer_discrete.py::test_scan_hmm_smoke

Expected Behavior

The test should pass.

fehiepsi added the "bug" (Something isn't working) label on Mar 6, 2025
@fehiepsi
Member Author

fehiepsi commented Mar 7, 2025

The leak appears to be caused by this line:

with funsor.adjoint.AdjointTape() as tape:

The stateful adjoint tape is not compatible with JAX's scan.
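To illustrate why a stateful tape conflicts with scan, here is a minimal sketch in plain JAX, independent of funsor and NumPyro. `HypotheticalTape` is a hypothetical stand-in for `funsor.adjoint.AdjointTape`, not its real implementation: it just appends intermediate values to a Python list. Because `jax.lax.scan` traces its body function into a jaxpr instead of running it step by step, what gets recorded is an abstract tracer that outlives the trace, which is the kind of leak `JAX_CHECK_TRACER_LEAKS=1` is meant to flag. The second body avoids the problem by returning per-step values as scan outputs instead of mutating external state.

```python
import jax
import jax.numpy as jnp


class HypotheticalTape:
    """Hypothetical stand-in for a stateful tape (NOT funsor's AdjointTape).

    It stores intermediate values in ordinary Python state, outside JAX's control.
    """

    def __init__(self):
        self.records = []


tape = HypotheticalTape()


def leaky_body(carry, x):
    y = carry + x
    # Side effect: the value created while scan traces this body is a tracer,
    # and storing it in Python state that outlives the trace leaks it.
    tape.records.append(y)
    return y, None


jax.lax.scan(leaky_body, jnp.zeros(()), jnp.arange(3.0))
# scan traced the body instead of executing it per step, so the recorded
# entry is an abstract tracer rather than a concrete number.
print(type(tape.records[0]).__name__)


def pure_body(carry, x):
    y = carry + x
    # Functional alternative: emit each step's value as a scan output.
    return y, y


final, ys = jax.lax.scan(pure_body, jnp.zeros(()), jnp.arange(3.0))
print(final, ys)  # running sum of [0, 1, 2]: final == 3.0, ys == [0., 1., 3.]
```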

Switching back to the lazy interpretation seems to fix the leak, but it causes some other tests to fail:

-    with funsor.adjoint.AdjointTape() as tape:
+    with funsor.interpretations.lazy:
         with block(), enum(first_available_dim=first_available_dim):
             log_prob, model_tr, log_measures = _enum_log_density(
                 model, args, kwargs, {}, sum_op, prod_op
             )
 
     with approx:
-        approx_factors = tape.adjoint(sum_op, prod_op, log_prob)
+        approx_factors = funsor.adjoint.adjoint(sum_op, prod_op, log_prob)
