NumPyro sampling fails in CI #4853

Closed
twiecki opened this issue Jul 11, 2021 · 2 comments

@twiecki
Member

twiecki commented Jul 11, 2021


pymc3/tests/test_sampling_jax.py::test_transform_samples FAILED          [100%]

=================================== FAILURES ===================================
____________________________ test_transform_samples ____________________________

    def test_transform_samples():
        aesara.config.on_opt_error = "raise"
        np.random.seed(13244)
    
        obs = np.random.normal(10, 2, size=100)
        obs_at = aesara.shared(obs, borrow=True, name="obs")
        with pm.Model() as model:
            a = pm.Uniform("a", -20, 20)
            sigma = pm.HalfNormal("sigma")
            b = pm.Normal("b", a, sigma=sigma, observed=obs_at)
    
>           trace = sample_numpyro_nuts(chains=1, random_seed=1322, keep_untransformed=True)

pymc3/tests/test_sampling_jax.py:20: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
pymc3/sampling_jax.py:212: in sample_numpyro_nuts
    _sample = compile_rv_inplace(
pymc3/aesaraf.py:888: in compile_rv_inplace
    aesara_function = aesara.function(inputs, outputs, mode=mode, **kwargs)
/usr/share/miniconda/envs/pymc3-dev-py39/lib/python3.9/site-packages/aesara/compile/function/__init__.py:337: in function
    fn = pfunc(
/usr/share/miniconda/envs/pymc3-dev-py39/lib/python3.9/site-packages/aesara/compile/function/pfunc.py:524: in pfunc
    return orig_function(
/usr/share/miniconda/envs/pymc3-dev-py39/lib/python3.9/site-packages/aesara/compile/function/types.py:1983: in orig_function
    fn = m.create(defaults)
/usr/share/miniconda/envs/pymc3-dev-py39/lib/python3.9/site-packages/aesara/compile/function/types.py:1838: in create
    _fn, _i, _o = self.linker.make_thunk(
/usr/share/miniconda/envs/pymc3-dev-py39/lib/python3.9/site-packages/aesara/link/basic.py:282: in make_thunk
    return self.make_all(
/usr/share/miniconda/envs/pymc3-dev-py39/lib/python3.9/site-packages/aesara/link/basic.py:739: in make_all
    thunks, nodes = self.create_jitable_thunk(
/usr/share/miniconda/envs/pymc3-dev-py39/lib/python3.9/site-packages/aesara/link/basic.py:683: in create_jitable_thunk
    converted_fgraph = self.fgraph_convert(
/usr/share/miniconda/envs/pymc3-dev-py39/lib/python3.9/site-packages/aesara/link/jax/linker.py:13: in fgraph_convert
    return jax_funcify(fgraph, **kwargs)
/usr/share/miniconda/envs/pymc3-dev-py39/lib/python3.9/functools.py:877: in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
/usr/share/miniconda/envs/pymc3-dev-py39/lib/python3.9/site-packages/aesara/link/jax/dispatch.py:597: in jax_funcify_FunctionGraph
    return fgraph_to_python(
/usr/share/miniconda/envs/pymc3-dev-py39/lib/python3.9/site-packages/aesara/link/utils.py:718: in fgraph_to_python
    compiled_func = op_conversion_fn(
/usr/share/miniconda/envs/pymc3-dev-py39/lib/python3.9/functools.py:877: in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
pymc3/sampling_jax.py:87: in jax_funcify_NumPyroNUTS
    from numpyro.infer import MCMC, NUTS
/usr/share/miniconda/envs/pymc3-dev-py39/lib/python3.9/site-packages/numpyro/__init__.py:4: in <module>
    from numpyro import compat, diagnostics, distributions, handlers, infer, optim
/usr/share/miniconda/envs/pymc3-dev-py39/lib/python3.9/site-packages/numpyro/distributions/__init__.py:4: in <module>
    from numpyro.distributions.conjugate import BetaBinomial, GammaPoisson
/usr/share/miniconda/envs/pymc3-dev-py39/lib/python3.9/site-packages/numpyro/distributions/conjugate.py:10: in <module>
    from numpyro.distributions.discrete import (
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    import numpy as np
    
    from jax import device_put, lax
    from jax.dtypes import canonicalize_dtype
    from jax.nn import softmax
    import jax.numpy as jnp
    import jax.random as random
    from jax.scipy.special import expit, gammaln, logsumexp, xlog1py, xlogy
    
    from numpyro.distributions import constraints
    from numpyro.distributions.distribution import Distribution
>   from numpyro.distributions.util import (
        binary_cross_entropy_with_logits,
        binomial,
        categorical,
        clamp_probs,
        get_dtype,
        lazy_property,
        multinomial,
        promote_shapes,
        sum_rightmost,
        validate_sample,
    )
E   ImportError: cannot import name 'get_dtype' from 'numpyro.distributions.util' (/usr/share/miniconda
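
The ImportError is raised inside NumPyro itself: `numpyro.distributions.discrete` imports `get_dtype` from `numpyro.distributions.util` and does not find it. That pattern could point to mismatched NumPyro files or package versions in the CI environment, although the thread does not establish the cause. The following is a minimal diagnostic sketch, assuming Python 3.8+ (for `importlib.metadata`). The helper names `installed_versions` and `module_exports` are hypothetical and are not part of PyMC3 or NumPyro.

```python
"""Hypothetical diagnostic: report installed versions and check a module export."""
import importlib
from importlib import metadata


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if it is absent."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


def module_exports(module_name, attr):
    """Return True if `module_name` imports cleanly and defines `attr`."""
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        # Covers both a missing package and a broken import chain like the one above.
        return False
    return hasattr(module, attr)


if __name__ == "__main__":
    # Package list is an assumption: the libraries that appear in the traceback.
    print(installed_versions(["jax", "jaxlib", "numpyro", "aesara"]))
    print("get_dtype available:", module_exports("numpyro.distributions.util", "get_dtype"))
```

Running this in the CI job before the test suite would record which versions were installed when the failure occurred, so a later passing run can be compared against it.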
@michaelosthege
Member

It no longer fails in the latest runs on main. Does anybody know which commit fixed it? The traceback above looks like a deterministic problem...

@twiecki
Member Author

twiecki commented Jul 17, 2021

Odd. Well, if it's fixed, that's all that matters, I guess 🤞.

twiecki closed this as completed on Jul 17, 2021