Skip to content

Commit b766706

Browse files
authored
Mark xfail for leaked tracer tests (#1997)
* Mark xfail for leaked tracer tests * test_handlers::test_plate is renamed to test_jit_trace * remove tracer leak xfail for the current passing tests * remove all global jax live arrays * add issues for each tracer leak test * fix failing tests * fix jax core deprecation in provenance
1 parent 3b7d7f0 commit b766706

22 files changed

+210
-92
lines changed

Diff for: .github/workflows/ci.yml

+14
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,20 @@ jobs:
7878
- name: Test x64
7979
run: |
8080
JAX_ENABLE_X64=1 pytest -vs test/test_distributions.py -k powerLaw
81+
- name: Test tracer leak
82+
if: matrix.python-version == '3.10'
83+
env:
84+
JAX_CHECK_TRACER_LEAKS: 1
85+
run: |
86+
pytest -vs test/contrib/einstein/test_steinvi.py::test_run_smoke -k ASVGD
87+
pytest -vs test/contrib/test_infer_discrete.py::test_scan_hmm_smoke
88+
pytest -vs test/infer/test_mcmc.py::test_chain_inside_jit
89+
pytest -vs test/infer/test_mcmc.py::test_chain_jit_args_smoke
90+
pytest -vs test/infer/test_mcmc.py::test_reuse_mcmc_run
91+
pytest -vs test/infer/test_mcmc.py::test_model_with_multiple_exec_paths
92+
pytest -vs test/infer/test_svi.py::test_mutable_state
93+
pytest -vs test/test_distributions.py::test_mean_var -k Gompertz
94+
8195
- name: Coveralls
8296
if: github.repository == 'pyro-ppl/numpyro' && matrix.python-version == '3.10'
8397
uses: coverallsapp/github-action@v2

Diff for: numpyro/distributions/discrete.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -944,7 +944,7 @@ def entropy(self):
944944
logq = -jax.nn.softplus(self.logits)
945945
logp = -jax.nn.softplus(-self.logits)
946946
p = jax.scipy.special.expit(self.logits)
947-
p_clip = jnp.clip(p, min=jnp.finfo(p).tiny)
947+
p_clip = jnp.clip(p, jnp.finfo(p).tiny)
948948
return -(1 - p) * logq / p_clip - logp
949949

950950

Diff for: numpyro/ops/provenance.py

+29-5
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33

44
import jax
55
from jax.api_util import flatten_fun, shaped_abstractify
6-
import jax.core as core
76
from jax.experimental.pjit import pjit_p
87
import jax.util as util
98

@@ -12,11 +11,21 @@
1211
except ImportError:
1312
import jax.linear_util as lu
1413

14+
try:
15+
from jax.extend.core import Literal
16+
except ImportError:
17+
from jax.core import Literal
18+
1519
try:
1620
from jax.extend.core.primitives import call_p, closed_call_p
1721
except ImportError:
1822
from jax.core import call_p, closed_call_p
1923

24+
try:
25+
from jax.api_util import debug_info
26+
except ImportError:
27+
debug_info = None
28+
2029
from jax.interpreters.partial_eval import trace_to_jaxpr_dynamic
2130
from jax.interpreters.pxla import xla_pmap_p
2231

@@ -44,14 +53,29 @@ def eval_provenance(fn, **kwargs):
4453
"""
4554
# Flatten the function and its arguments
4655
args, in_tree = jax.tree.flatten(((), kwargs))
47-
wrapped_fun, out_tree = flatten_fun(lu.wrap_init(fn), in_tree)
56+
fn_info = (
57+
dict(debug_info=debug_info("eval_provenance fn", fn, (), kwargs))
58+
if debug_info is not None
59+
else {}
60+
)
61+
wrapped_fun, out_tree = flatten_fun(lu.wrap_init(fn, **fn_info), in_tree)
4862
# Abstract eval to get output pytree
4963
avals = util.safe_map(shaped_abstractify, args)
5064
# XXX: we split out the process of abstract evaluation and provenance tracking
5165
# for simplicity. In principle, they can be merged so that we only need to walk
5266
# through the equations once.
67+
68+
wrapped_info = (
69+
dict(
70+
debug_info=debug_info(
71+
"eval_provenance wrapped", wrapped_fun.call_wrapped, args, {}
72+
)
73+
)
74+
if debug_info is not None
75+
else {}
76+
)
5377
jaxpr, avals_out, _ = trace_to_jaxpr_dynamic(
54-
lu.wrap_init(wrapped_fun.call_wrapped, {}), avals
78+
lu.wrap_init(wrapped_fun.call_wrapped, {}, **wrapped_info), avals
5579
)
5680

5781
# get provenances of flatten kwargs
@@ -69,12 +93,12 @@ def track_deps_jaxpr(jaxpr, provenance_inputs):
6993
env = {}
7094

7195
def read(v):
72-
if isinstance(v, core.Literal):
96+
if isinstance(v, Literal):
7397
return frozenset()
7498
return env.get(v, frozenset())
7599

76100
def write(v, p):
77-
if isinstance(v, core.Literal):
101+
if isinstance(v, Literal):
78102
return
79103
env[v] = read(v) | p
80104

Diff for: test/conftest.py

+8
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,22 @@
33

44
import os
55

6+
import jax
67
from jax import config
78

89
from numpyro.util import set_rng_seed
910

1011
config.update("jax_platform_name", "cpu") # noqa: E702
1112

1213

14+
SETUP_STATE = {"is_first_test": True}
15+
16+
1317
def pytest_runtest_setup(item):
18+
if SETUP_STATE["is_first_test"]:
19+
SETUP_STATE["is_first_test"] = False
20+
assert len(jax.live_arrays()) == 0
21+
1422
if "JAX_ENABLE_X64" in os.environ:
1523
config.update("jax_enable_x64", True)
1624
set_rng_seed(0)

Diff for: test/contrib/einstein/test_steinvi.py

+4
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from collections import namedtuple
55
from functools import partial
6+
import os
67
import string
78

89
import numpy as np
@@ -119,6 +120,9 @@ def model(features, labels):
119120
@pytest.mark.parametrize("kernel", KERNELS)
120121
@pytest.mark.parametrize("problem", (uniform_normal, regression))
121122
@pytest.mark.parametrize("method", ("ASVGD", "SVGD", "SteinVI"))
123+
@pytest.mark.xfail(
124+
os.getenv("JAX_CHECK_TRACER_LEAKS") == "1", reason="Expected tracer leak"
125+
)
122126
def test_run_smoke(kernel, problem, method):
123127
true_coefs, data, model = problem()
124128
if method == "ASVGD":

Diff for: test/contrib/hsgp/test_laplacian.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -131,12 +131,12 @@ def test_eigenfunctions(x: ArrayLike, ell: float | int, m: int | list[int]):
131131
(1, 2, False),
132132
([1, 1], 2, False),
133133
(np.array([1, 1])[..., None], 2, False),
134-
(jnp.array([1, 1])[..., None], 2, False),
134+
(np.array([1, 1])[..., None], 2, False),
135+
(np.array([1, 1]), 2, True),
135136
(np.array([1, 1]), 2, True),
136-
(jnp.array([1, 1]), 2, True),
137137
([1, 1], 1, True),
138138
(np.array([1, 1]), 1, True),
139-
(jnp.array([1, 1]), 1, True),
139+
(np.array([1, 1]), 1, True),
140140
],
141141
ids=[
142142
"ell-float",

Diff for: test/contrib/stochastic_support/test_dcc.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,14 @@
1919

2020
@pytest.mark.parametrize(
2121
"branch_dist",
22-
[dist.Normal(0, 1), dist.Gamma(1, 1)],
22+
[lambda: dist.Normal(0, 1), lambda: dist.Gamma(1, 1)],
2323
)
2424
@pytest.mark.xfail(raises=RuntimeError)
2525
def test_continuous_branching(branch_dist):
2626
rng_key = random.PRNGKey(0)
2727

2828
def model():
29-
model1 = numpyro.sample("model1", branch_dist, infer={"branching": True})
29+
model1 = numpyro.sample("model1", branch_dist(), infer={"branching": True})
3030
mean = 1.0 if model1 == 0 else 2.0
3131
numpyro.sample("obs", dist.Normal(mean, 1.0), obs=0.2)
3232

Diff for: test/contrib/test_enum_elbo.py

-4
Original file line numberDiff line numberDiff line change
@@ -1247,7 +1247,6 @@ def test_elbo_enumerate_plates_6(scale):
12471247

12481248
@config_enumerate
12491249
@handlers.scale(scale=scale)
1250-
@handlers.trace
12511250
def model_iplate_iplate(data, params):
12521251
probs_a = pyro.param(
12531252
"probs_a", params["probs_a"], constraint=constraints.simplex
@@ -1305,7 +1304,6 @@ def model_iplate_plate(data, params):
13051304

13061305
@config_enumerate
13071306
@handlers.scale(scale=scale)
1308-
@handlers.trace
13091307
def model_plate_iplate(data, params):
13101308
probs_a = pyro.param(
13111309
"probs_a", params["probs_a"], constraint=constraints.simplex
@@ -1423,7 +1421,6 @@ def test_elbo_enumerate_plates_7(scale):
14231421

14241422
@config_enumerate
14251423
@handlers.scale(scale=scale)
1426-
@handlers.trace
14271424
def model_iplate_iplate(data, params):
14281425
probs_a = pyro.param(
14291426
"probs_a", params["probs_a"], constraint=constraints.simplex
@@ -1489,7 +1486,6 @@ def model_iplate_plate(data, params):
14891486

14901487
@config_enumerate
14911488
@handlers.scale(scale=scale)
1492-
@handlers.trace
14931489
def model_plate_iplate(data, params):
14941490
probs_a = pyro.param(
14951491
"probs_a", params["probs_a"], constraint=constraints.simplex

Diff for: test/contrib/test_infer_discrete.py

+5
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
# SPDX-License-Identifier: Apache-2.0
33

44
import logging
5+
import os
56

67
import numpy as np
78
from numpy.testing import assert_allclose
@@ -95,6 +96,10 @@ def hmm(data, hidden_dim=10):
9596
],
9697
)
9798
@pytest.mark.parametrize("temperature", [0, 1])
99+
@pytest.mark.xfail(
100+
os.getenv("JAX_CHECK_TRACER_LEAKS") == "1",
101+
reason="Expected tracer leak: https://github.com/pyro-ppl/numpyro/issues/1998",
102+
)
98103
def test_scan_hmm_smoke(length, temperature):
99104
# This should match the example in the infer_discrete docstring.
100105
def hmm(data, hidden_dim=10):

Diff for: test/contrib/test_tfp.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ def model(data):
179179
(
180180
"ReplicaExchangeMC",
181181
dict(
182-
inverse_temperatures=0.5 ** jnp.arange(4), make_kernel_fn=make_kernel_fn
182+
inverse_temperatures=0.5 ** np.arange(4), make_kernel_fn=make_kernel_fn
183183
),
184184
),
185185
],

Diff for: test/infer/test_ensemble_mcmc.py

+12-6
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# Copyright Contributors to the Pyro project.
22
# SPDX-License-Identifier: Apache-2.0
33

4+
import numpy as np
45
import pytest
56

67
import jax.numpy as jnp
@@ -15,10 +16,13 @@
1516
# reused for all smoke-tests
1617
N, dim = 3000, 3
1718

18-
data = random.normal(random.PRNGKey(0), (N, dim))
19-
true_coefs = jnp.arange(1.0, dim + 1.0)
20-
logits = jnp.sum(true_coefs * data, axis=-1)
21-
labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(1))
19+
data = np.random.default_rng(0).normal(N, dim)
20+
true_coefs = np.arange(1.0, dim + 1.0)
21+
logits = np.sum(true_coefs * data, axis=-1)
22+
23+
24+
def labels_maker():
25+
return dist.Bernoulli(logits=logits).sample(random.PRNGKey(1))
2226

2327

2428
def model(labels):
@@ -54,7 +58,7 @@ def test_chain_smoke(kernel_cls, n_chain, method):
5458
)
5559

5660
with pytest.raises(AssertionError, match="chain_method"):
57-
mcmc.run(random.PRNGKey(2), labels)
61+
mcmc.run(random.PRNGKey(2), labels_maker())
5862

5963

6064
@pytest.mark.parametrize("kernel_cls", [AIES, ESS])
@@ -70,7 +74,7 @@ def test_out_shape_smoke(kernel_cls):
7074
num_chains=n_chains,
7175
chain_method="vectorized",
7276
)
73-
mcmc.run(random.PRNGKey(2), labels)
77+
mcmc.run(random.PRNGKey(2), labels_maker())
7478

7579
assert mcmc.get_samples(group_by_chain=True)["coefs"].shape[0] == n_chains
7680

@@ -94,6 +98,7 @@ def test_multirun(kernel_cls):
9498
num_chains=n_chains,
9599
chain_method="vectorized",
96100
)
101+
labels = labels_maker()
97102
mcmc.run(random.PRNGKey(2), labels)
98103
mcmc.run(random.PRNGKey(3), labels)
99104

@@ -111,5 +116,6 @@ def test_warmup(kernel_cls):
111116
num_chains=n_chains,
112117
chain_method="vectorized",
113118
)
119+
labels = labels_maker()
114120
mcmc.warmup(random.PRNGKey(2), labels)
115121
mcmc.run(random.PRNGKey(3), labels)

Diff for: test/infer/test_gradient.py

+10-10
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def guide_0(data, params):
3737
pyro.sample("z", dist.Categorical(probs))
3838

3939

40-
params_0 = {"probs": jnp.array([[0.4, 0.6], [0.5, 0.5]])}
40+
params_0 = {"probs": np.array([[0.4, 0.6], [0.5, 0.5]])}
4141

4242

4343
def model_1(data, params):
@@ -57,8 +57,8 @@ def guide_1(data, params):
5757

5858

5959
params_1 = {
60-
"probs_a": jnp.array([0.5, 0.5]),
61-
"probs_b": jnp.array([[[0.5, 0.5], [0.6, 0.4]], [[0.4, 0.6], [0.35, 0.65]]]),
60+
"probs_a": np.array([0.5, 0.5]),
61+
"probs_b": np.array([[[0.5, 0.5], [0.6, 0.4]], [[0.4, 0.6], [0.35, 0.65]]]),
6262
}
6363

6464

@@ -88,19 +88,19 @@ def guide_2(data, params):
8888

8989

9090
params_2 = {
91-
"probs_a": jnp.array([0.5, 0.5]),
92-
"probs_b": jnp.array([[0.4, 0.6], [0.3, 0.7]]),
93-
"probs_c": jnp.array([[[0.3, 0.7], [0.8, 0.2]], [[0.2, 0.8], [0.5, 0.5]]]),
94-
"probs_d": jnp.array([[[0.2, 0.8], [0.9, 0.1]], [[0.1, 0.9], [0.4, 0.6]]]),
91+
"probs_a": np.array([0.5, 0.5]),
92+
"probs_b": np.array([[0.4, 0.6], [0.3, 0.7]]),
93+
"probs_c": np.array([[[0.3, 0.7], [0.8, 0.2]], [[0.2, 0.8], [0.5, 0.5]]]),
94+
"probs_d": np.array([[[0.2, 0.8], [0.9, 0.1]], [[0.1, 0.9], [0.4, 0.6]]]),
9595
}
9696

9797

9898
@pytest.mark.parametrize(
9999
"model,guide,params,data",
100100
[
101-
(model_0, guide_0, params_0, jnp.array([-0.5, 2.0])),
102-
(model_1, guide_1, params_1, jnp.array([-0.5, 2.0])),
103-
(model_2, guide_2, params_2, jnp.array([0, 1])),
101+
(model_0, guide_0, params_0, np.array([-0.5, 2.0])),
102+
(model_1, guide_1, params_1, np.array([-0.5, 2.0])),
103+
(model_2, guide_2, params_2, np.array([0, 1])),
104104
],
105105
)
106106
def test_gradient(model, guide, params, data):

Diff for: test/infer/test_hmc_util.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -125,8 +125,8 @@ def register_fn(model):
125125
num_steps=100,
126126
q_i={"x": 0.0},
127127
p_i={"x": 1.0},
128-
q_f={"x": jnp.sin(1.0)},
129-
p_f={"x": jnp.cos(1.0)},
128+
q_f={"x": np.sin(1.0)},
129+
p_f={"x": np.cos(1.0)},
130130
m_inv=np.array([1.0]),
131131
prec=1e-4,
132132
)

0 commit comments

Comments
 (0)