VI using deprecated Aesara MRG sampler #4523

Closed
twiecki opened this issue Mar 10, 2021 · 14 comments · Fixed by #6304
Labels
v4 VI Variational Inference

@twiecki
Member

twiecki commented Mar 10, 2021

ADVI currently draws its random numbers with the MRG sampler, which does not support JAX (pymc-devs/pytensor#322). It should be fairly easy to switch it to the new RandomVariable Op (pymc-devs/pytensor#296) instead, which would give ADVI JAX support.

CC @ferrine

@twiecki twiecki added the jax label Mar 10, 2021
@ferrine
Member

ferrine commented Mar 13, 2021

@brandonwillard, should I use the master branch or v4 for development?

@brandonwillard
Contributor

We've been putting in PRs to the v4 branch.

@ferrine
Member

ferrine commented Mar 13, 2021

OK, I'll open a PR against the v4 branch :tada:

@brandonwillard
Contributor

brandonwillard commented Mar 13, 2021

Oh, wait, the change requested here could go toward v3 or v4.

I would prioritize v4, but we need to port the VI code to v4 first and foremost, and that work is independent of this exact issue. Both could be done simultaneously, though.

@ferrine
Member

ferrine commented Mar 14, 2021

@brandonwillard, running pymc3/tests/test_variational_inference.py gives me errors unrelated to VI. Is that expected?

ERROR pymc3/tests/test_variational_inference.py::test_init_groups[raises0-grouping0]
ERROR pymc3/tests/test_variational_inference.py::test_init_groups[raises1-grouping1]
ERROR pymc3/tests/test_variational_inference.py::test_init_groups[raises2-grouping2]
ERROR pymc3/tests/test_variational_inference.py::test_init_groups[raises3-grouping3]
ERROR pymc3/tests/test_variational_inference.py::test_init_groups[raises4-grouping4]
ERROR pymc3/tests/test_variational_inference.py::test_init_groups[raises5-grouping5]
ERROR pymc3/tests/test_variational_inference.py::test_sample_aevb[MeanFieldGroup: {}]
ERROR pymc3/tests/test_variational_inference.py::test_sample_aevb[FullRankGroup: {}]
ERROR pymc3/tests/test_variational_inference.py::test_sample_aevb[NormalizingFlowGroup: {'flow': 'scale'}]
ERROR pymc3/tests/test_variational_inference.py::test_sample_aevb[NormalizingFlowGroup: {'flow': 'loc'}]
ERROR pymc3/tests/test_variational_inference.py::test_sample_aevb[NormalizingFlowGroup: {'flow': 'hh'}]
ERROR pymc3/tests/test_variational_inference.py::test_sample_aevb[NormalizingFlowGroup: {'flow': 'planar'}]
ERROR pymc3/tests/test_variational_inference.py::test_sample_aevb[NormalizingFlowGroup: {'flow': 'radial'}]
ERROR pymc3/tests/test_variational_inference.py::test_sample_aevb[NormalizingFlowGroup: {'flow': 'radial-loc'}]
ERROR pymc3/tests/test_variational_inference.py::test_replacements_in_sample_node_aevb[MeanFieldGroup: {}]
ERROR pymc3/tests/test_variational_inference.py::test_replacements_in_sample_node_aevb[FullRankGroup: {}]
ERROR pymc3/tests/test_variational_inference.py::test_replacements_in_sample_node_aevb[NormalizingFlowGroup: {'flow': 'scale'}]
ERROR pymc3/tests/test_variational_inference.py::test_replacements_in_sample_node_aevb[NormalizingFlowGroup: {'flow': 'loc'}]
ERROR pymc3/tests/test_variational_inference.py::test_replacements_in_sample_node_aevb[NormalizingFlowGroup: {'flow': 'hh'}]
ERROR pymc3/tests/test_variational_inference.py::test_replacements_in_sample_node_aevb[NormalizingFlowGroup: {'flow': 'planar'}]
ERROR pymc3/tests/test_variational_inference.py::test_replacements_in_sample_node_aevb[NormalizingFlowGroup: {'flow': 'radial'}]
ERROR pymc3/tests/test_variational_inference.py::test_replacements_in_sample_node_aevb[NormalizingFlowGroup: {'flow': 'radial-loc'}]
ERROR pymc3/tests/test_variational_inference.py::test_logq_mini_1_sample_1_var[MeanFieldGroup: {}]
ERROR pymc3/tests/test_variational_inference.py::test_logq_mini_1_sample_1_var[FullRankGroup: {}]
ERROR pymc3/tests/test_variational_inference.py::test_logq_mini_1_sample_1_var[NormalizingFlowGroup: {'flow': 'scale'}]
ERROR pymc3/tests/test_variational_inference.py::test_logq_mini_1_sample_1_var[NormalizingFlowGroup: {'flow': 'loc'}]
ERROR pymc3/tests/test_variational_inference.py::test_logq_mini_1_sample_1_var[NormalizingFlowGroup: {'flow': 'hh'}]
ERROR pymc3/tests/test_variational_inference.py::test_logq_mini_1_sample_1_var[NormalizingFlowGroup: {'flow': 'planar'}]
ERROR pymc3/tests/test_variational_inference.py::test_logq_mini_1_sample_1_var[NormalizingFlowGroup: {'flow': 'radial'}]
ERROR pymc3/tests/test_variational_inference.py::test_logq_mini_1_sample_1_var[NormalizingFlowGroup: {'flow': 'radial-loc'}]
ERROR pymc3/tests/test_variational_inference.py::test_logq_mini_2_sample_2_var[MeanFieldGroup: {}]
ERROR pymc3/tests/test_variational_inference.py::test_logq_mini_2_sample_2_var[FullRankGroup: {}]
ERROR pymc3/tests/test_variational_inference.py::test_logq_mini_2_sample_2_var[NormalizingFlowGroup: {'flow': 'scale'}]
ERROR pymc3/tests/test_variational_inference.py::test_logq_mini_2_sample_2_var[NormalizingFlowGroup: {'flow': 'loc'}]
ERROR pymc3/tests/test_variational_inference.py::test_logq_mini_2_sample_2_var[NormalizingFlowGroup: {'flow': 'hh'}]
ERROR pymc3/tests/test_variational_inference.py::test_logq_mini_2_sample_2_var[NormalizingFlowGroup: {'flow': 'planar'}]
ERROR pymc3/tests/test_variational_inference.py::test_logq_mini_2_sample_2_var[NormalizingFlowGroup: {'flow': 'radial'}]
ERROR pymc3/tests/test_variational_inference.py::test_logq_mini_2_sample_2_var[NormalizingFlowGroup: {'flow': 'radial-loc'}]
ERROR pymc3/tests/test_variational_inference.py::test_logq_mini_sample_aevb[MeanFieldGroup: {}]
ERROR pymc3/tests/test_variational_inference.py::test_logq_mini_sample_aevb[FullRankGroup: {}]
ERROR pymc3/tests/test_variational_inference.py::test_logq_mini_sample_aevb[NormalizingFlowGroup: {'flow': 'scale'}]
ERROR pymc3/tests/test_variational_inference.py::test_logq_mini_sample_aevb[NormalizingFlowGroup: {'flow': 'loc'}]
ERROR pymc3/tests/test_variational_inference.py::test_logq_mini_sample_aevb[NormalizingFlowGroup: {'flow': 'hh'}]
ERROR pymc3/tests/test_variational_inference.py::test_logq_mini_sample_aevb[NormalizingFlowGroup: {'flow': 'planar'}]
ERROR pymc3/tests/test_variational_inference.py::test_logq_mini_sample_aevb[NormalizingFlowGroup: {'flow': 'radial'}]
ERROR pymc3/tests/test_variational_inference.py::test_logq_mini_sample_aevb[NormalizingFlowGroup: {'flow': 'radial-loc'}]
ERROR pymc3/tests/test_variational_inference.py::test_logq_aevb[MeanFieldGroup: {}]
ERROR pymc3/tests/test_variational_inference.py::test_logq_aevb[FullRankGroup: {}]
ERROR pymc3/tests/test_variational_inference.py::test_logq_aevb[NormalizingFlowGroup: {'flow': 'scale'}]
ERROR pymc3/tests/test_variational_inference.py::test_logq_aevb[NormalizingFlowGroup: {'flow': 'loc'}]
ERROR pymc3/tests/test_variational_inference.py::test_logq_aevb[NormalizingFlowGroup: {'flow': 'hh'}]
ERROR pymc3/tests/test_variational_inference.py::test_logq_aevb[NormalizingFlowGroup: {'flow': 'planar'}]
ERROR pymc3/tests/test_variational_inference.py::test_logq_aevb[NormalizingFlowGroup: {'flow': 'radial'}]
ERROR pymc3/tests/test_variational_inference.py::test_logq_aevb[NormalizingFlowGroup: {'flow': 'radial-loc'}]
ERROR pymc3/tests/test_variational_inference.py::test_logq_globals[MeanFieldGroup: None]
ERROR pymc3/tests/test_variational_inference.py::test_logq_globals[FullRankGroup: None, MeanFieldGroup: ['one']]
ERROR pymc3/tests/test_variational_inference.py::test_logq_globals[MeanFieldGroup: ['one'], FullRankGroup: ['two'], NormalizingFlowGroup: ['three']]
ERROR pymc3/tests/test_variational_inference.py::test_logq_globals[MeanFieldGroup: ['one'], FullRankGroup: ['two', 'three']]
ERROR pymc3/tests/test_variational_inference.py::test_logq_globals[MeanFieldGroup: ['one'], EmpiricalGroup: ['two', 'three']]
ERROR pymc3/tests/test_variational_inference.py::test_group_api_vfam[raises0-mean_field-MeanFieldGroup-kw0]
ERROR pymc3/tests/test_variational_inference.py::test_group_api_vfam[raises1-mf-MeanFieldGroup-kw1]
ERROR pymc3/tests/test_variational_inference.py::test_group_api_vfam[raises2-full_rank-FullRankGroup-kw2]
ERROR pymc3/tests/test_variational_inference.py::test_group_api_vfam[raises3-fr-FullRankGroup-kw3]
ERROR pymc3/tests/test_variational_inference.py::test_group_api_vfam[raises4-FR-FullRankGroup-kw4]
ERROR pymc3/tests/test_variational_inference.py::test_group_api_vfam[raises5-loc-NormalizingFlowGroup-kw5]
ERROR pymc3/tests/test_variational_inference.py::test_group_api_vfam[raises6-scale-NormalizingFlowGroup-kw6]
ERROR pymc3/tests/test_variational_inference.py::test_group_api_vfam[raises7-hh-NormalizingFlowGroup-kw7]
ERROR pymc3/tests/test_variational_inference.py::test_group_api_vfam[raises8-planar-NormalizingFlowGroup-kw8]
ERROR pymc3/tests/test_variational_inference.py::test_group_api_vfam[raises9-radial-NormalizingFlowGroup-kw9]
ERROR pymc3/tests/test_variational_inference.py::test_group_api_vfam[raises10-scale-loc-NormalizingFlowGroup-kw10]
ERROR pymc3/tests/test_variational_inference.py::test_group_api_vfam[raises11-empirical-EmpiricalGroup-kw11]
ERROR pymc3/tests/test_variational_inference.py::test_group_api_vfam[raises12-empirical-EmpiricalGroup-kw12]
ERROR pymc3/tests/test_variational_inference.py::test_group_api_params[raises0-params0-MeanFieldGroup-kw0-None]
ERROR pymc3/tests/test_variational_inference.py::test_group_api_params[raises1-params1-FullRankGroup-kw1-None]
ERROR pymc3/tests/test_variational_inference.py::test_group_api_params[raises2-params2-NormalizingFlowGroup-kw2-loc]
ERROR pymc3/tests/test_variational_inference.py::test_group_api_params[raises3-params3-NormalizingFlowGroup-kw3-scale]
ERROR pymc3/tests/test_variational_inference.py::test_group_api_params[raises4-params4-NormalizingFlowGroup-kw4-hh]
ERROR pymc3/tests/test_variational_inference.py::test_group_api_params[raises5-params5-NormalizingFlowGroup-kw5-planar]
ERROR pymc3/tests/test_variational_inference.py::test_group_api_params[raises6-params6-NormalizingFlowGroup-kw6-radial]
ERROR pymc3/tests/test_variational_inference.py::test_group_api_params[raises7-params7-NormalizingFlowGroup-kw7-scale-loc]
ERROR pymc3/tests/test_variational_inference.py::test_group_api_params[raises8-params8-EmpiricalGroup-kw8-None]
ERROR pymc3/tests/test_variational_inference.py::test_single_group_shortcuts[MeanFieldGroup-MeanField-kw0]
ERROR pymc3/tests/test_variational_inference.py::test_single_group_shortcuts[FullRankGroup-FullRank-kw1]
ERROR pymc3/tests/test_variational_inference.py::test_single_group_shortcuts[EmpiricalGroup-Empirical-kw2]
ERROR pymc3/tests/test_variational_inference.py::test_single_group_shortcuts[NormalizingFlowGroup-NormalizingFlow-kw3]
ERROR pymc3/tests/test_variational_inference.py::test_single_group_shortcuts[NormalizingFlowGroup-NormalizingFlow-kw4]
ERROR pymc3/tests/test_variational_inference.py::test_single_group_shortcuts[NormalizingFlowGroup-NormalizingFlow-kw5]
ERROR pymc3/tests/test_variational_inference.py::test_fit_oo[NFVI=scale-loc-mini]
ERROR pymc3/tests/test_variational_inference.py::test_profile[NFVI=scale-loc-mini]
ERROR pymc3/tests/test_variational_inference.py::test_fit_oo[NFVI=scale-loc-full]
ERROR pymc3/tests/test_variational_inference.py::test_profile[NFVI=scale-loc-full]
ERROR pymc3/tests/test_variational_inference.py::test_fit_oo[ADVI-full] - Att...
ERROR pymc3/tests/test_variational_inference.py::test_profile[ADVI-full] - At...
ERROR pymc3/tests/test_variational_inference.py::test_fit_oo[ADVI-mini] - Att...
ERROR pymc3/tests/test_variational_inference.py::test_profile[ADVI-mini] - At...
ERROR pymc3/tests/test_variational_inference.py::test_aevb[ADVI] - Deprecatio...
ERROR pymc3/tests/test_variational_inference.py::test_replacements[ADVI] - Ty...
ERROR pymc3/tests/test_variational_inference.py::test_sample_replacements[ADVI]
ERROR pymc3/tests/test_variational_inference.py::test_fit_oo[FullRankADVI-full]
ERROR pymc3/tests/test_variational_inference.py::test_profile[FullRankADVI-full]
ERROR pymc3/tests/test_variational_inference.py::test_fit_oo[FullRankADVI-mini]
ERROR pymc3/tests/test_variational_inference.py::test_profile[FullRankADVI-mini]
ERROR pymc3/tests/test_variational_inference.py::test_aevb[FullRankADVI] - De...
ERROR pymc3/tests/test_variational_inference.py::test_replacements[FullRankADVI]
ERROR pymc3/tests/test_variational_inference.py::test_sample_replacements[FullRankADVI]
ERROR pymc3/tests/test_variational_inference.py::test_fit_oo[SVGD-full] - Att...
ERROR pymc3/tests/test_variational_inference.py::test_profile[SVGD-full] - At...
ERROR pymc3/tests/test_variational_inference.py::test_fit_oo[SVGD-mini] - Att...
ERROR pymc3/tests/test_variational_inference.py::test_profile[SVGD-mini] - At...
ERROR pymc3/tests/test_variational_inference.py::test_aevb[SVGD] - Deprecatio...
ERROR pymc3/tests/test_variational_inference.py::test_replacements[SVGD] - Ty...
ERROR pymc3/tests/test_variational_inference.py::test_sample_replacements[SVGD]
ERROR pymc3/tests/test_variational_inference.py::test_fit_oo[ASVGD-full] - At...
ERROR pymc3/tests/test_variational_inference.py::test_profile[ASVGD-full] - A...
ERROR pymc3/tests/test_variational_inference.py::test_fit_oo[ASVGD-mini] - At...
ERROR pymc3/tests/test_variational_inference.py::test_profile[ASVGD-mini] - A...
ERROR pymc3/tests/test_variational_inference.py::test_aevb[ASVGD] - Deprecati...
ERROR pymc3/tests/test_variational_inference.py::test_replacements[ASVGD] - T...
ERROR pymc3/tests/test_variational_inference.py::test_sample_replacements[ASVGD]
ERROR pymc3/tests/test_variational_inference.py::test_aevb[NFVI=scale-loc] - ...
ERROR pymc3/tests/test_variational_inference.py::test_replacements[NFVI=scale-loc]
ERROR pymc3/tests/test_variational_inference.py::test_sample_replacements[NFVI=scale-loc]
ERROR pymc3/tests/test_variational_inference.py::test_rowwise_approx[MeanFieldGroup: {}]
ERROR pymc3/tests/test_variational_inference.py::test_rowwise_approx[FullRankGroup: {}]
ERROR pymc3/tests/test_variational_inference.py::test_rowwise_approx[NormalizingFlowGroup: {'flow': 'scale'}]
ERROR pymc3/tests/test_variational_inference.py::test_rowwise_approx[NormalizingFlowGroup: {'flow': 'loc'}]
ERROR pymc3/tests/test_variational_inference.py::test_rowwise_approx[NormalizingFlowGroup: {'flow': 'hh'}]
ERROR pymc3/tests/test_variational_inference.py::test_rowwise_approx[NormalizingFlowGroup: {'flow': 'planar'}]
ERROR pymc3/tests/test_variational_inference.py::test_rowwise_approx[NormalizingFlowGroup: {'flow': 'radial'}]
ERROR pymc3/tests/test_variational_inference.py::test_rowwise_approx[NormalizingFlowGroup: {'flow': 'radial-loc'}]
ERROR pymc3/tests/test_variational_inference.py::test_pickle_approx[MeanFieldGroup: None]
ERROR pymc3/tests/test_variational_inference.py::test_pickle_approx[FullRankGroup: None, MeanFieldGroup: ['one']]
ERROR pymc3/tests/test_variational_inference.py::test_pickle_approx[MeanFieldGroup: ['one'], FullRankGroup: ['two'], NormalizingFlowGroup: ['three']]
ERROR pymc3/tests/test_variational_inference.py::test_pickle_approx[MeanFieldGroup: ['one'], FullRankGroup: ['two', 'three']]
ERROR pymc3/tests/test_variational_inference.py::test_pickle_approx[MeanFieldGroup: ['one'], EmpiricalGroup: ['two', 'three']]
ERROR pymc3/tests/test_variational_inference.py::test_pickle_single_group - D...
ERROR pymc3/tests/test_variational_inference.py::test_pickle_approx_aevb[MeanFieldGroup: {}]
ERROR pymc3/tests/test_variational_inference.py::test_pickle_approx_aevb[FullRankGroup: {}]
ERROR pymc3/tests/test_variational_inference.py::test_pickle_approx_aevb[NormalizingFlowGroup: {'flow': 'scale'}]
ERROR pymc3/tests/test_variational_inference.py::test_pickle_approx_aevb[NormalizingFlowGroup: {'flow': 'loc'}]
ERROR pymc3/tests/test_variational_inference.py::test_pickle_approx_aevb[NormalizingFlowGroup: {'flow': 'hh'}]
ERROR pymc3/tests/test_variational_inference.py::test_pickle_approx_aevb[NormalizingFlowGroup: {'flow': 'planar'}]
ERROR pymc3/tests/test_variational_inference.py::test_pickle_approx_aevb[NormalizingFlowGroup: {'flow': 'radial'}]
ERROR pymc3/tests/test_variational_inference.py::test_pickle_approx_aevb[NormalizingFlowGroup: {'flow': 'radial-loc'}]

@ferrine
Member

ferrine commented Mar 14, 2021

I have Aesara v2.0.2 installed.

@ferrine
Member

ferrine commented Mar 14, 2021

Using Aesara master gives the same result.

@ferrine ferrine added the v4 label Mar 14, 2021
@ferrine
Member

ferrine commented Mar 14, 2021

The errors look like this:

self = <pymc3.step_methods.metropolis.Metropolis object at 0x7f44b6c0efd0>, vars = [x], S = None, proposal_dist = None, scaling = 1.0, tune = True, tune_interval = 100
model = <pymc3.model.Model object at 0x7f44b64c5430>, mode = None, kwargs = {}

    def __init__(
        self,
        vars=None,
        S=None,
        proposal_dist=None,
        scaling=1.0,
        tune=True,
        tune_interval=100,
        model=None,
        mode=None,
        **kwargs
    ):
        """Create an instance of a Metropolis stepper

        Parameters
        ----------
        vars: list
            List of variables for sampler
        S: standard deviation or covariance matrix
            Some measure of variance to parameterize proposal distribution
        proposal_dist: function
            Function that returns zero-mean deviates when parameterized with
            S (and n). Defaults to normal.
        scaling: scalar or array
            Initial scale factor for proposal. Defaults to 1.
        tune: bool
            Flag for tuning. Defaults to True.
        tune_interval: int
            The frequency of tuning. Defaults to 100 iterations.
        model: PyMC Model
            Optional model for sampling step. Defaults to None (taken from context).
        mode: string or `Mode` instance.
            compilation mode passed to Aesara functions
        """

        model = pm.modelcontext(model)

        if vars is None:
            vars = model.vars
        vars = pm.inputvars(vars)

        if S is None:
            # XXX: This needs to be refactored
            S = None  # np.ones(sum(v.dsize for v in vars))

        if proposal_dist is not None:
            self.proposal_dist = proposal_dist(S)
>       elif S.ndim == 1:
E       AttributeError: 'NoneType' object has no attribute 'ndim'
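This particular failure comes from `S` staying `None`: the previous default, preserved in the commented-out line, is disabled, so the `S.ndim` dispatch receives `None`. A minimal NumPy sketch of that dispatch with the old default restored; `default_proposal_scale`, `proposal_kind`, and the `var_sizes` list are hypothetical stand-ins for `np.ones(sum(v.dsize for v in vars))` and the Metropolis branch, not PyMC code.

```python
import numpy as np


def default_proposal_scale(var_sizes):
    """Hypothetical helper: one unit scale per flattened entry of every
    sampled variable, mirroring ``np.ones(sum(v.dsize for v in vars))``."""
    return np.ones(sum(var_sizes))


def proposal_kind(S):
    """Hypothetical mirror of the ``S.ndim`` branch that raised for S=None."""
    if S.ndim == 1:
        return "normal"  # diagonal scale vector
    elif S.ndim == 2:
        return "multivariate normal"  # full covariance matrix
    raise ValueError(f"Invalid rank for variance: {S.ndim}")


S = None
var_sizes = [1]  # assumption: a single scalar variable, like ``x`` in the test
if S is None:
    S = default_proposal_scale(var_sizes)

print(S, proposal_kind(S))
```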

@brandonwillard
Contributor

Yes, you should expect all sorts of errors.

@fonnesbeck
Member

Does any of this change with the Numba backend?

@brandonwillard
Contributor

Does any of this change with the Numba backend?

It shouldn't; in general, if we use Aesara, its backends should handle everything without requiring backend-specific logic at this level.

@ricardoV94 ricardoV94 removed the jax label Nov 18, 2021
@michaelosthege michaelosthege modified the milestones: v4.0.0b2, v4.0.0b3 Jan 7, 2022
@ricardoV94 ricardoV94 modified the milestones: v4.0.0b3, v4.0.0 Feb 7, 2022
@bukson

bukson commented Aug 26, 2022

It does not work with PyMC 4:

import pymc as pm

with pm.Model():
    x = pm.Normal("x")
    pm.fit()

WARNING (aesara.tensor.blas): Using NumPy C-API based implementation for BLAS functions.
Traceback (most recent call last):
File "/usr/local/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3398, in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
File "", line 5, in <cell line: 3>
pm.fit()
File "/usr/local/lib/python3.10/site-packages/pymc/variational/inference.py", line 744, in fit
return inference.fit(n, **kwargs)
File "/usr/local/lib/python3.10/site-packages/pymc/variational/inference.py", line 138, in fit
step_func = self.objective.step_function(score=score, **kwargs)
File "/usr/local/lib/python3.10/site-packages/aesara/configparser.py", line 47, in res
return f(*args, **kwargs)
File "/usr/local/lib/python3.10/site-packages/pymc/variational/opvi.py", line 355, in step_function
updates = self.updates(
File "/usr/local/lib/python3.10/site-packages/pymc/variational/opvi.py", line 244, in updates
self.add_obj_updates(
File "/usr/local/lib/python3.10/site-packages/pymc/variational/opvi.py", line 289, in add_obj_updates
obj_target = self(
File "/usr/local/lib/python3.10/site-packages/aesara/configparser.py", line 47, in res
return f(*args, **kwargs)
File "/usr/local/lib/python3.10/site-packages/pymc/variational/opvi.py", line 407, in __call__
a = self.approx.set_size_and_deterministic(a, nmc, 0, kwargs.get("more_replacements"))
File "/usr/local/lib/python3.10/site-packages/aesara/configparser.py", line 47, in res
return f(*args, **kwargs)
File "/usr/local/lib/python3.10/site-packages/pymc/variational/opvi.py", line 1359, in set_size_and_deterministic
flat2rand = self.make_size_and_deterministic_replacements(s, d, more_replacements)
File "/usr/local/lib/python3.10/site-packages/pymc/variational/opvi.py", line 1333, in make_size_and_deterministic_replacements
flat2rand.update(g.make_size_and_deterministic_replacements(s, d, more_replacements))
File "/usr/local/lib/python3.10/site-packages/pymc/variational/opvi.py", line 1067, in make_size_and_deterministic_replacements
initial = self._new_initial(s, d, more_replacements)
File "/usr/local/lib/python3.10/site-packages/pymc/variational/opvi.py", line 978, in _new_initial
return getattr(self._rng, dist_name)(size=shape)
File "/usr/local/lib/python3.10/site-packages/aesara/sandbox/rng_mrg.py", line 1184, in normal
uniform = self.uniform(
File "/usr/local/lib/python3.10/site-packages/aesara/sandbox/rng_mrg.py", line 914, in uniform
rstates = self.get_substream_rstates(nstreams, dtype)
File "/usr/local/lib/python3.10/site-packages/aesara/configparser.py", line 47, in res
return f(*args, **kwargs)
File "/usr/local/lib/python3.10/site-packages/aesara/sandbox/rng_mrg.py", line 818, in get_substream_rstates
multMatVect(rval[0], A1p72, M1, A2p72, M2)
File "/usr/local/lib/python3.10/site-packages/aesara/sandbox/rng_mrg.py", line 66, in multMatVect
multMatVect.dot_modulo = function(
File "/usr/local/lib/python3.10/site-packages/aesara/compile/function/__init__.py", line 317, in function
fn = pfunc(
File "/usr/local/lib/python3.10/site-packages/aesara/compile/function/pfunc.py", line 374, in pfunc
return orig_function(
File "/usr/local/lib/python3.10/site-packages/aesara/compile/function/types.py", line 1763, in orig_function
fn = m.create(defaults)
File "/usr/local/lib/python3.10/site-packages/aesara/compile/function/types.py", line 1656, in create
_fn, _i, _o = self.linker.make_thunk(
File "/usr/local/lib/python3.10/site-packages/aesara/link/basic.py", line 254, in make_thunk
return self.make_all(
File "/usr/local/lib/python3.10/site-packages/aesara/link/basic.py", line 697, in make_all
thunks, nodes, jit_fn = self.create_jitable_thunk(
File "/usr/local/lib/python3.10/site-packages/aesara/link/basic.py", line 646, in create_jitable_thunk
converted_fgraph = self.fgraph_convert(
File "/usr/local/lib/python3.10/site-packages/aesara/link/jax/linker.py", line 13, in fgraph_convert
return jax_funcify(fgraph, **kwargs)
File "/usr/local/lib/python3.10/functools.py", line 889, in wrapper
return dispatch(args[0].__class__)(*args, **kw)
File "/usr/local/lib/python3.10/site-packages/aesara/link/jax/dispatch.py", line 670, in jax_funcify_FunctionGraph
return fgraph_to_python(
File "/usr/local/lib/python3.10/site-packages/aesara/link/utils.py", line 741, in fgraph_to_python
compiled_func = op_conversion_fn(
File "/usr/local/lib/python3.10/functools.py", line 889, in wrapper
return dispatch(args[0].__class__)(*args, **kw)
File "/usr/local/lib/python3.10/site-packages/aesara/link/jax/dispatch.py", line 143, in jax_funcify
raise NotImplementedError(f"No JAX conversion for the given Op: {op}")
NotImplementedError: No JAX conversion for the given Op: DotModulo

@ricardoV94
Member

@bukson please open an issue on our discourse at: https://discourse.pymc.io/

@twiecki
Member Author

twiecki commented Sep 15, 2022

No, this is a legitimate issue: we should port VI to use RandomVariable, see aesara-devs/aesara#322 (comment). CC @ferrine

@twiecki twiecki reopened this Sep 15, 2022
@ricardoV94 ricardoV94 changed the title Port VI to use RandomVariable VI fails in JAX mode Nov 5, 2022
@ricardoV94 ricardoV94 added the VI Variational Inference label Nov 5, 2022
@ricardoV94 ricardoV94 modified the milestones: v4.0.0, v4.4.0 Nov 5, 2022
@ricardoV94 ricardoV94 pinned this issue Nov 5, 2022
@ricardoV94 ricardoV94 changed the title VI fails in JAX mode VI using deprecated Aesara MRG sampler Nov 7, 2022
@ricardoV94 ricardoV94 removed this from the v4.4.0 milestone Nov 9, 2022
@ricardoV94 ricardoV94 unpinned this issue Dec 14, 2022
7 participants