We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Setting `initvals` on variables that use `ZeroSumTransform` fails during a type cast with `AttributeError: 'int' object has no attribute 'astype'`.
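It seems to be caused by the input being a NumPy array rather than a PyTensor variable. For a NumPy array, `array.shape[axis]` in `ZeroSumTransform.extend_axis` is a plain Python `int`, which has no `.astype` method.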
The fix seems simple; I will post a PR for it next.
import pymc as pm, numpy as np with pm.Model() as model: pm.ZeroSumNormal('zsn',shape=(10,)) pm.Normal('n', shape=(10,), transform=pm.distributions.transforms.ZeroSumTransform(zerosum_axes=[0])) mp = pm.find_MAP() pm.sample(initvals=mp)
--------------------------------------------------------------------------- AttributeError Traceback (most recent call last) /home/velochy/salk/sandbox/sandy.ipynb Cell 1 line 8 5 pm.Normal('n', shape=(10,), transform=pm.distributions.transforms.ZeroSumTransform(zerosum_axes=[0])) 6 mp = pm.find_MAP() ----> 8 pm.sample(initvals=mp) File ~/miniconda3/envs/salk/lib/python3.12/site-packages/pymc/sampling/mcmc.py:832, in sample(draws, tune, chains, cores, random_seed, progressbar, progressbar_theme, step, var_names, nuts_sampler, initvals, init, jitter_max_retries, n_init, trace, discard_tuned_samples, compute_convergence_checks, keep_warning_stat, return_inferencedata, idata_kwargs, nuts_sampler_kwargs, callback, mp_ctx, blas_cores, model, compile_kwargs, **kwargs) 830 [kwargs.setdefault(k, v) for k, v in nuts_kwargs.items()] 831 with joined_blas_limiter(): --> 832 initial_points, step = init_nuts( 833 init=init, 834 chains=chains, 835 n_init=n_init, 836 model=model, 837 random_seed=random_seed_list, 838 progressbar=progress_bool, 839 jitter_max_retries=jitter_max_retries, 840 tune=tune, 841 initvals=initvals, 842 compile_kwargs=compile_kwargs, 843 **kwargs, 844 ) 845 else: 846 # Get initial points 847 ipfns = make_initial_point_fns_per_chain( 848 model=model, 849 overrides=initvals, 850 jitter_rvs=set(), 851 chains=chains, 852 ) File ~/miniconda3/envs/salk/lib/python3.12/site-packages/pymc/sampling/mcmc.py:1605, in init_nuts(init, chains, n_init, model, random_seed, progressbar, jitter_max_retries, tune, initvals, compile_kwargs, **kwargs) 1602 q, _ = DictToArrayBijection.map(ip) 1603 return logp_dlogp_func([q], extra_vars={})[0] -> 1605 initial_points = _init_jitter( 1606 model, 1607 initvals, 1608 seeds=random_seed_list, 1609 jitter="jitter" in init, 1610 jitter_max_retries=jitter_max_retries, 1611 logp_fn=model_logp_fn, 1612 ) 1614 apoints = [DictToArrayBijection.map(point) for point in initial_points] 1615 apoints_data = [apoint.data for apoint in apoints] File 
~/miniconda3/envs/salk/lib/python3.12/site-packages/pymc/sampling/mcmc.py:1462, in _init_jitter(model, initvals, seeds, jitter, jitter_max_retries, logp_fn) 1432 def _init_jitter( 1433 model: Model, 1434 initvals: StartDict | Sequence[StartDict | None] | None, (...) 1438 logp_fn: Callable[[PointType], np.ndarray] | None = None, 1439 ) -> list[PointType]: 1440 """Apply a uniform jitter in [-1, 1] to the test value as starting point in each chain. 1441 1442 ``model.check_start_vals`` is used to test whether the jittered starting (...) 1460 List of starting points for the sampler 1461 """ -> 1462 ipfns = make_initial_point_fns_per_chain( 1463 model=model, 1464 overrides=initvals, 1465 jitter_rvs=set(model.free_RVs) if jitter else set(), 1466 chains=len(seeds), 1467 ) 1469 if not jitter: 1470 return [ipfn(seed) for ipfn, seed in zip(ipfns, seeds)] File ~/miniconda3/envs/salk/lib/python3.12/site-packages/pymc/initial_point.py:101, in make_initial_point_fns_per_chain(model, overrides, jitter_rvs, chains) 72 """Create an initial point function for each chain, as defined by initvals. 73 74 If a single initval dictionary is passed, the function is replicated for each (...) 95 96 """ 97 if isinstance(overrides, dict) or overrides is None: 98 # One strategy for all chains 99 # Only one function compilation is needed. 100 ipfns = [ --> 101 make_initial_point_fn( 102 model=model, 103 overrides=overrides, 104 jitter_rvs=jitter_rvs, 105 return_transformed=True, 106 ) 107 ] * chains 108 elif len(overrides) == chains: 109 ipfns = [ 110 make_initial_point_fn( 111 model=model, (...) 116 for chain_overrides in overrides 117 ] File ~/miniconda3/envs/salk/lib/python3.12/site-packages/pymc/initial_point.py:152, in make_initial_point_fn(model, overrides, jitter_rvs, default_strategy, return_transformed) 126 def make_initial_point_fn( 127 *, 128 model, (...) 
132 return_transformed: bool = True, 133 ) -> Callable[[SeedSequenceSeed], PointType]: 134 """Create seeded function that computes initial values for all free model variables. 135 136 Parameters (...) 150 initial_point_fn : Callable[[SeedSequenceSeed], dict[str, np.ndarray]] 151 """ --> 152 sdict_overrides = convert_str_to_rv_dict(model, overrides or {}) 153 initval_strats = { 154 **model.rvs_to_initial_values, 155 **sdict_overrides, 156 } 158 initial_values = make_initial_point_expression( 159 free_rvs=model.free_RVs, 160 rvs_to_transforms=model.rvs_to_transforms, (...) 164 return_transformed=return_transformed, 165 ) File ~/miniconda3/envs/salk/lib/python3.12/site-packages/pymc/initial_point.py:57, in convert_str_to_rv_dict(model, start) 55 if is_transformed_name(key): 56 rv = model[get_untransformed_name(key)] ---> 57 initvals[rv] = model.rvs_to_transforms[rv].backward(initval, *rv.owner.inputs) 58 else: 59 initvals[model[key]] = initval File ~/miniconda3/envs/salk/lib/python3.12/site-packages/pymc/distributions/transforms.py:309, in ZeroSumTransform.backward(self, value, *rv_inputs) 307 def backward(self, value, *rv_inputs): 308 for axis in self.zerosum_axes: --> 309 value = self.extend_axis(value, axis=axis) 310 return value File ~/miniconda3/envs/salk/lib/python3.12/site-packages/pymc/distributions/transforms.py:281, in ZeroSumTransform.extend_axis(array, axis) 279 @staticmethod 280 def extend_axis(array, axis): --> 281 n = (array.shape[axis] + 1).astype("floatX") 282 sum_vals = array.sum(axis, keepdims=True) 283 norm = sum_vals / (pt.sqrt(n) + n) AttributeError: 'int' object has no attribute 'astype'
pymc 5.22.0
I wanted to experiment with setting `initvals` from MAP and Pathfinder results, and ran into this issue.
The text was updated successfully, but these errors were encountered:
One-line fix at #7773
Sorry, something went wrong.
Successfully merging a pull request may close this issue.
Describe the issue:
Setting `initvals` on variables that use `ZeroSumTransform` fails during a type cast with `AttributeError: 'int' object has no attribute 'astype'`.
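It seems to be caused by the input being a NumPy array rather than a PyTensor variable. For a NumPy array, `array.shape[axis]` in `ZeroSumTransform.extend_axis` is a plain Python `int`, which has no `.astype` method.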
The fix seems simple; I will post a PR for it next.
Reproducible code example:
Error message:
PyMC version information:
pymc 5.22.0
Context for the issue:
I wanted to experiment with setting `initvals` from MAP and Pathfinder results, and ran into this issue.
The text was updated successfully, but these errors were encountered: