
The tag 'local_logsoftmax' is already present in the database. #4645


Closed
JohnGoertz opened this issue Apr 15, 2021 · 6 comments

@JohnGoertz

If you have questions about a specific use case, or you are not sure whether this is a bug or not, please post it to our discourse channel: https://discourse.pymc.io

sampling_jax isn't imported via pymc3/__init__.py, and a manual import raises an exception

Not sure if this was intentional, but from pymc3.sampling_jax import * is missing from pymc3/__init__.py. Manually importing the module raises the following exception, which looks like it comes from the latest JAX version. Maybe the except AttributeError clause just needs to be amended to also catch Exception?
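A minimal sketch of the suggested change, assuming the guard lives in a small helper. disable_omnistaging_if_supported is a hypothetical name, not Theano's actual code; it only mirrors the try/except shown in jax_dispatch.py in the traceback below, widened to catch Exception. Since AttributeError is a subclass of Exception, one except Exception clause covers both the old "flag does not exist" case and the generic Exception that JAX 0.2.12+ raises.

```python
from typing import Any


def disable_omnistaging_if_supported(config: Any) -> bool:
    """Try to disable JAX omnistaging; return True if the call succeeded.

    Hypothetical helper: mirrors the try/except in theano/link/jax/jax_dispatch.py,
    but catches Exception rather than only AttributeError.
    """
    try:
        config.disable_omnistaging()
    except Exception:  # also covers AttributeError (JAX < 0.2.0 has no such flag)
        return False
    return True
```

Catching a bare Exception only keeps the import from failing. It does not establish that Theano's JAX backend works correctly with omnistaging left on, and it can also hide unrelated errors raised by that call.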

---------------------------------------------------------------------------
Exception                                 Traceback (most recent call last)
<ipython-input-61-7010767e2f65> in <module>
----> 1 from pymc3 import sampling_jax

~/miniconda3/envs/candas/lib/python3.8/site-packages/pymc3/sampling_jax.py in <module>
     16 import theano.graph.fg
     17 
---> 18 from theano.link.jax.jax_dispatch import jax_funcify
     19 
     20 import pymc3 as pm

~/miniconda3/envs/candas/lib/python3.8/site-packages/theano/link/jax/jax_dispatch.py in <module>
     84 # Older versions < 0.2.0 do not have this flag so we don't need to set it.
     85 try:
---> 86     jax.config.disable_omnistaging()
     87 except AttributeError:
     88     pass

~/miniconda3/envs/candas/lib/python3.8/site-packages/jax/config.py in disable_omnistaging(self)
    165 
    166   def disable_omnistaging(self):
--> 167     raise Exception(
    168       "Disabling of omnistaging is no longer supported in JAX version 0.2.12 and higher: "
    169       "see https://github.com/google/jax/blob/master/design_notes/omnistaging.md.")

Exception: Disabling of omnistaging is no longer supported in JAX version 0.2.12 and higher: see https://github.com/google/jax/blob/master/design_notes/omnistaging.md.
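An alternative to catching the exception is to check the JAX version before calling disable_omnistaging(). The sketch below is an assumption, not code from Theano or JAX; the 0.2.12 cutoff is taken from the error message above. It compares numeric tuples rather than strings, because a plain string comparison would wrongly rank "0.2.9" above "0.2.12".

```python
import re
from typing import Tuple

# Cutoff taken from the error message: omnistaging can't be disabled from 0.2.12 on.
OMNISTAGING_REMOVED_IN: Tuple[int, int, int] = (0, 2, 12)


def parse_release(version: str) -> Tuple[int, ...]:
    """Extract the leading numeric release segment, e.g. '0.2.12rc1' -> (0, 2, 12)."""
    match = re.match(r"(\d+(?:\.\d+)*)", version)
    if match is None:
        raise ValueError(f"Unrecognized version string: {version!r}")
    return tuple(int(part) for part in match.group(1).split("."))


def omnistaging_always_on(jax_version: str) -> bool:
    """True if this JAX version no longer allows disabling omnistaging."""
    return parse_release(jax_version) >= OMNISTAGING_REMOVED_IN
```

Usage would look like omnistaging_always_on(jax.__version__). For real code, packaging.version.Version handles pre-release and local version tags more thoroughly than this regex.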

Versions and main components

  • PyMC3 Version: 3.11.2
  • JAX Version: 0.2.12
  • Theano Version: 1.1.2
  • Operating system: Ubuntu (WSL on Windows 10)
  • How did you install PyMC3: conda
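To gather the versions listed above for a report like this one, a small script can query installed distributions with importlib.metadata (available since Python 3.8, the version used here). The distribution names are assumptions: PyMC3 3.11 depends on Theano-PyMC rather than the original Theano, and any name not found under the given spelling is reported as None.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Iterable, Optional


def installed_versions(packages: Iterable[str]) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if not found."""
    found: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    import platform

    # Distribution names are assumptions; adjust them to match your environment.
    report = installed_versions(["pymc3", "theano-pymc", "aesara", "jax", "jaxlib", "arviz"])
    for name, ver in report.items():
        print(f"{name}: {ver or 'not installed'}")
    print(f"Python: {platform.python_version()}")
```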
@PedroSebe

I faced the same issue in a Google Colab environment. I had PyMC3 updated to 3.11.2 using pip and the same version of JAX.

@twiecki
Member

twiecki commented Jul 19, 2021

This is fixed in a recent aesara version. Can you try installing aesara and the pymc3 main branch?
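When switching to aesara, it can help to see which of theano and aesara are importable side by side. The arviz converter in the traceback later in this thread tries import aesara first and only falls back to theano on ImportError, so an installed aesara gets imported even when pymc3 itself uses theano. The sketch below uses importlib.util.find_spec to check this without importing anything. The helper names are hypothetical, and a module being findable does not guarantee its import succeeds.

```python
import importlib.util
from typing import Optional


def importable(module_name: str) -> bool:
    """True if a top-level module can be found on sys.path, without importing it."""
    return importlib.util.find_spec(module_name) is not None


def backend_arviz_tries_first() -> Optional[str]:
    """Mimic the fallback order in arviz's PyMC3 converter: aesara, then theano."""
    if importable("aesara"):
        return "aesara"
    if importable("theano"):
        return "theano"
    return None


if __name__ == "__main__":
    print({name: importable(name) for name in ("pymc3", "theano", "aesara", "jax")})
    print("arviz tries first:", backend_arviz_tries_first())
```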

@twiecki twiecki closed this as completed Jul 19, 2021
@Bodisatva

Sorry, but it doesn't look like it's fixed. I was getting the same error messages as #4645 (comment) with the same configuration. After uninstalling pymc3 completely, I reinstalled aesara and pymc3. I still couldn't get the JAX sampler working (same error message), and I also ran into new issues. Now the regular pymc3 sampler gives me this error from aesara:

/tmp/ipykernel_15166/2622855677.py in fit(self, X, y, weigth, regular_sampling, **kwds)
     47         with self.hierarchical_model:
     48 
---> 49             if regular_sampling:hierarchical_trace = pm.sample(2000, tune=1000, target_accept=0.9, return_inferencedata=True, **kwds)
     50             else: hierarchical_trace = pm.sampling_jax.sample_numpyro_nuts(2000,tune=1000, target_accept=0.9, **kwd)
     51 

~/.local/share/virtualenvs/python3.8-CXITuAic/lib/python3.8/site-packages/pymc3/sampling.py in sample(draws, step, init, n_init, start, trace, chain_idx, chains, cores, tune, progressbar, model, random_seed, discard_tuned_samples, compute_convergence_checks, callback, jitter_max_retries, return_inferencedata, idata_kwargs, mp_ctx, pickle_backend, **kwargs)
    637         if idata_kwargs:
    638             ikwargs.update(idata_kwargs)
--> 639         idata = arviz.from_pymc3(trace, **ikwargs)
    640 
    641     if compute_convergence_checks:

~/.local/share/virtualenvs/python3.8-CXITuAic/lib/python3.8/site-packages/arviz/data/io_pymc3.py in from_pymc3(trace, prior, posterior_predictive, log_likelihood, coords, dims, model, save_warmup, density_dist_obs)
    561     InferenceData
    562     """
--> 563     return PyMC3Converter(
    564         trace=trace,
    565         prior=prior,

~/.local/share/virtualenvs/python3.8-CXITuAic/lib/python3.8/site-packages/arviz/data/io_pymc3.py in __init__(self, trace, prior, posterior_predictive, log_likelihood, predictions, coords, dims, model, save_warmup, density_dist_obs)
     75 
     76         try:
---> 77             import aesara  # pylint: disable=redefined-outer-name
     78         except ImportError:
     79             import theano as aesara

~/.local/share/virtualenvs/python3.8-CXITuAic/lib/python3.8/site-packages/aesara/__init__.py in <module>
     77 __api_version__ = 1
     78 
---> 79 from aesara import scalar, tensor
     80 from aesara.compile import (
     81     In,

~/.local/share/virtualenvs/python3.8-CXITuAic/lib/python3.8/site-packages/aesara/tensor/__init__.py in <module>
     52 from aesara.gradient import consider_constant, grad, hessian, jacobian
     53 from aesara.tensor import sharedvar  # adds shared-variable constructors
---> 54 from aesara.tensor import (
     55     basic_opt,
     56     blas,

~/.local/share/virtualenvs/python3.8-CXITuAic/lib/python3.8/site-packages/aesara/tensor/nnet/__init__.py in <module>
     10     separable_conv2d,
     11 )
---> 12 from aesara.tensor.nnet.basic import (
     13     binary_crossentropy,
     14     categorical_crossentropy,

~/.local/share/virtualenvs/python3.8-CXITuAic/lib/python3.8/site-packages/aesara/tensor/nnet/basic.py in <module>
    803 @register_specialize("stabilize", "fast_compile")
    804 @local_optimizer([Elemwise])
--> 805 def local_logsoftmax(fgraph, node):
    806     """
    807     Detect Log(Softmax(x)) and replace it with LogSoftmax(x)

~/.local/share/virtualenvs/python3.8-CXITuAic/lib/python3.8/site-packages/aesara/tensor/basic_opt.py in register(inner_lopt)
    583 
    584         def register(inner_lopt):
--> 585             return register_specialize(inner_lopt, lopt, *tags, **kwargs)
    586 
    587         return register

~/.local/share/virtualenvs/python3.8-CXITuAic/lib/python3.8/site-packages/aesara/tensor/basic_opt.py in register_specialize(lopt, *tags, **kwargs)
    588     else:
    589         name = kwargs.pop("name", None) or lopt.__name__
--> 590         compile.optdb["specialize"].register(name, lopt, "fast_run", *tags, **kwargs)
    591         return lopt
    592 

~/.local/share/virtualenvs/python3.8-CXITuAic/lib/python3.8/site-packages/aesara/graph/optdb.py in register(self, name, obj, final_opt, cleanup, *tags, **kwtags)
    349         if final_opt and cleanup:
    350             raise ValueError("`final_opt` and `cleanup` cannot both be true.")
--> 351         super().register(name, obj, *tags, **kwtags)
    352         self.__final__[name] = final_opt
    353         self.__cleanup__[name] = cleanup

~/.local/share/virtualenvs/python3.8-CXITuAic/lib/python3.8/site-packages/aesara/graph/optdb.py in register(self, name, optimizer, use_db_name_as_tag, *tags)
     63 
     64         if name in self.__db__:
---> 65             raise ValueError(f"The tag '{name}' is already present in the database.")
     66 
     67         if use_db_name_as_tag:

ValueError: The tag 'local_logsoftmax' is already present in the database.
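The last frames show the mechanism: register() raises as soon as a name is already in the database, and the @register_specialize decorator on local_logsoftmax runs when aesara/tensor/nnet/basic.py is executed. The toy registry below is a hypothetical stand-in that reproduces only that duplicate-name check, not aesara's real optdb API. It shows that registering the same rewrite name twice in one process yields exactly this message; why the registration happened twice in this environment is not established by the traceback.

```python
from typing import Callable, Dict


class OptimizationRegistry:
    """Toy name -> rewrite registry that rejects duplicate names.

    Hypothetical stand-in for aesara.graph.optdb: only the duplicate check is mirrored.
    """

    def __init__(self) -> None:
        self._db: Dict[str, Callable] = {}

    def register(self, name: str, rewrite: Callable) -> None:
        if name in self._db:
            raise ValueError(f"The tag '{name}' is already present in the database.")
        self._db[name] = rewrite


def local_logsoftmax(fgraph, node):
    """Placeholder for the real rewrite; its body is irrelevant here."""
    return None


registry = OptimizationRegistry()
registry.register(local_logsoftmax.__name__, local_logsoftmax)

# A second registration under the same name (e.g. if the defining module's
# top-level code ran twice in one process) fails like the traceback above.
try:
    registry.register(local_logsoftmax.__name__, local_logsoftmax)
except ValueError as err:
    print(err)  # The tag 'local_logsoftmax' is already present in the database.
```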

@twiecki
Member

twiecki commented Jul 28, 2021

This looks like an aesara issue, so I'm moving the issue.

@twiecki twiecki reopened this Jul 28, 2021
@twiecki twiecki changed the title from "sampling_jax import issues" to "The tag 'local_logsoftmax' is already present in the database." Jul 28, 2021
@twiecki
Member

twiecki commented Jul 28, 2021

Ah, can't transfer to a different org. Can you open an issue there with the traceback?

@twiecki twiecki closed this as completed Jul 28, 2021
@Bodisatva

> Ah, can't transfer to a different org. Can you open an issue there with the traceback?

Will do, with the title: "The tag 'local_logsoftmax' is already present in the database."
