BUG: ZeroSumTransform fails with initvalues #7772

Open · velochy opened this issue May 2, 2025 · 1 comment · May be fixed by #7773

velochy (Contributor) commented May 2, 2025

Describe the issue:

Trying to set initvals on ZeroSumTransform'ed variables leads to a type-casting error.

It seems to be caused by the input being a NumPy array rather than a PyTensor tensor.

The fix seems simple; I am posting a PR for it next.

Reproducible code example:

import pymc as pm

with pm.Model() as model:
    pm.ZeroSumNormal('zsn', shape=(10,))
    pm.Normal('n', shape=(10,), transform=pm.distributions.transforms.ZeroSumTransform(zerosum_axes=[0]))
    mp = pm.find_MAP()

    pm.sample(initvals=mp)

Error message:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
/home/velochy/salk/sandbox/sandy.ipynb Cell 1 line 8
      5 pm.Normal('n', shape=(10,), transform=pm.distributions.transforms.ZeroSumTransform(zerosum_axes=[0]))
      6 mp = pm.find_MAP()
----> 8 pm.sample(initvals=mp)

File ~/miniconda3/envs/salk/lib/python3.12/site-packages/pymc/sampling/mcmc.py:832, in sample(draws, tune, chains, cores, random_seed, progressbar, progressbar_theme, step, var_names, nuts_sampler, initvals, init, jitter_max_retries, n_init, trace, discard_tuned_samples, compute_convergence_checks, keep_warning_stat, return_inferencedata, idata_kwargs, nuts_sampler_kwargs, callback, mp_ctx, blas_cores, model, compile_kwargs, **kwargs)
    830         [kwargs.setdefault(k, v) for k, v in nuts_kwargs.items()]
    831     with joined_blas_limiter():
--> 832         initial_points, step = init_nuts(
    833             init=init,
    834             chains=chains,
    835             n_init=n_init,
    836             model=model,
    837             random_seed=random_seed_list,
    838             progressbar=progress_bool,
    839             jitter_max_retries=jitter_max_retries,
    840             tune=tune,
    841             initvals=initvals,
    842             compile_kwargs=compile_kwargs,
    843             **kwargs,
    844         )
    845 else:
    846     # Get initial points
    847     ipfns = make_initial_point_fns_per_chain(
    848         model=model,
    849         overrides=initvals,
    850         jitter_rvs=set(),
    851         chains=chains,
    852     )

File ~/miniconda3/envs/salk/lib/python3.12/site-packages/pymc/sampling/mcmc.py:1605, in init_nuts(init, chains, n_init, model, random_seed, progressbar, jitter_max_retries, tune, initvals, compile_kwargs, **kwargs)
   1602     q, _ = DictToArrayBijection.map(ip)
   1603     return logp_dlogp_func([q], extra_vars={})[0]
-> 1605 initial_points = _init_jitter(
   1606     model,
   1607     initvals,
   1608     seeds=random_seed_list,
   1609     jitter="jitter" in init,
   1610     jitter_max_retries=jitter_max_retries,
   1611     logp_fn=model_logp_fn,
   1612 )
   1614 apoints = [DictToArrayBijection.map(point) for point in initial_points]
   1615 apoints_data = [apoint.data for apoint in apoints]

File ~/miniconda3/envs/salk/lib/python3.12/site-packages/pymc/sampling/mcmc.py:1462, in _init_jitter(model, initvals, seeds, jitter, jitter_max_retries, logp_fn)
   1432 def _init_jitter(
   1433     model: Model,
   1434     initvals: StartDict | Sequence[StartDict | None] | None,
   (...)
   1438     logp_fn: Callable[[PointType], np.ndarray] | None = None,
   1439 ) -> list[PointType]:
   1440     """Apply a uniform jitter in [-1, 1] to the test value as starting point in each chain.
   1441 
   1442     ``model.check_start_vals`` is used to test whether the jittered starting
   (...)
   1460         List of starting points for the sampler
   1461     """
-> 1462     ipfns = make_initial_point_fns_per_chain(
   1463         model=model,
   1464         overrides=initvals,
   1465         jitter_rvs=set(model.free_RVs) if jitter else set(),
   1466         chains=len(seeds),
   1467     )
   1469     if not jitter:
   1470         return [ipfn(seed) for ipfn, seed in zip(ipfns, seeds)]

File ~/miniconda3/envs/salk/lib/python3.12/site-packages/pymc/initial_point.py:101, in make_initial_point_fns_per_chain(model, overrides, jitter_rvs, chains)
     72 """Create an initial point function for each chain, as defined by initvals.
     73 
     74 If a single initval dictionary is passed, the function is replicated for each
   (...)
     95 
     96 """
     97 if isinstance(overrides, dict) or overrides is None:
     98     # One strategy for all chains
     99     # Only one function compilation is needed.
    100     ipfns = [
--> 101         make_initial_point_fn(
    102             model=model,
    103             overrides=overrides,
    104             jitter_rvs=jitter_rvs,
    105             return_transformed=True,
    106         )
    107     ] * chains
    108 elif len(overrides) == chains:
    109     ipfns = [
    110         make_initial_point_fn(
    111             model=model,
   (...)
    116         for chain_overrides in overrides
    117     ]

File ~/miniconda3/envs/salk/lib/python3.12/site-packages/pymc/initial_point.py:152, in make_initial_point_fn(model, overrides, jitter_rvs, default_strategy, return_transformed)
    126 def make_initial_point_fn(
    127     *,
    128     model,
   (...)
    132     return_transformed: bool = True,
    133 ) -> Callable[[SeedSequenceSeed], PointType]:
    134     """Create seeded function that computes initial values for all free model variables.
    135 
    136     Parameters
   (...)
    150     initial_point_fn : Callable[[SeedSequenceSeed], dict[str, np.ndarray]]
    151     """
--> 152     sdict_overrides = convert_str_to_rv_dict(model, overrides or {})
    153     initval_strats = {
    154         **model.rvs_to_initial_values,
    155         **sdict_overrides,
    156     }
    158     initial_values = make_initial_point_expression(
    159         free_rvs=model.free_RVs,
    160         rvs_to_transforms=model.rvs_to_transforms,
   (...)
    164         return_transformed=return_transformed,
    165     )

File ~/miniconda3/envs/salk/lib/python3.12/site-packages/pymc/initial_point.py:57, in convert_str_to_rv_dict(model, start)
     55 if is_transformed_name(key):
     56     rv = model[get_untransformed_name(key)]
---> 57     initvals[rv] = model.rvs_to_transforms[rv].backward(initval, *rv.owner.inputs)
     58 else:
     59     initvals[model[key]] = initval

File ~/miniconda3/envs/salk/lib/python3.12/site-packages/pymc/distributions/transforms.py:309, in ZeroSumTransform.backward(self, value, *rv_inputs)
    307 def backward(self, value, *rv_inputs):
    308     for axis in self.zerosum_axes:
--> 309         value = self.extend_axis(value, axis=axis)
    310     return value

File ~/miniconda3/envs/salk/lib/python3.12/site-packages/pymc/distributions/transforms.py:281, in ZeroSumTransform.extend_axis(array, axis)
    279 @staticmethod
    280 def extend_axis(array, axis):
--> 281     n = (array.shape[axis] + 1).astype("floatX")
    282     sum_vals = array.sum(axis, keepdims=True)
    283     norm = sum_vals / (pt.sqrt(n) + n)

AttributeError: 'int' object has no attribute 'astype'

PyMC version information:

pymc 5.22.0

Context for the issue:

I wanted to experiment with setting initvals from MAP and Pathfinder estimates, and ran into this issue.
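
The traceback points at ZeroSumTransform.extend_axis: for a symbolic PyTensor input, array.shape[axis] + 1 is a symbolic scalar that supports .astype, but for the plain NumPy arrays returned by find_MAP it is a Python int, which does not. A minimal sketch of that difference (illustration only, not part of the original report):

import numpy as np
import pytensor.tensor as pt

symbolic = pt.vector("x")  # symbolic value, as seen inside the model graph
numeric = np.zeros(9)      # plain array, as returned by find_MAP

print(type(symbolic.shape[0] + 1))  # TensorVariable -> has .astype("floatX")
print(type(numeric.shape[0] + 1))   # int -> no .astype, hence the AttributeError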

velochy added the bug label on May 2, 2025
velochy linked pull request #7773 on May 2, 2025 that will close this issue
velochy (Contributor, Author) commented May 2, 2025

One-line fix at #7773
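
For reference, a hedged sketch of the direction such a one-line fix could take (the authoritative change is the one in #7773): promote the incoming value to a PyTensor tensor before taking its shape, so that extend_axis also accepts NumPy arrays. Only the lines quoted in the traceback are reproduced; the rest of the method is elided.

import pytensor.tensor as pt

def extend_axis(array, axis):
    array = pt.as_tensor_variable(array)  # assumed fix: promote NumPy input to a tensor
    n = (array.shape[axis] + 1).astype("floatX")
    sum_vals = array.sum(axis, keepdims=True)
    norm = sum_vals / (pt.sqrt(n) + n)
    ...  # remainder of the original method unchanged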
