Skip to content

Commit e85677b

Browse files
More bugfixes for statespace (#346)
* Allow forward sampling of statespace models in JAX mode Explicitly set data shape to avoid broadcasting error Better handling of measurement error dims in `SARIMAX` models Freeze auxiliary models before forward sampling Bugfixes for posterior predictive sampling helpers Allow specification of time dimension name when registering data Save info about exogenous data for post-estimation tasks Restore `_exog_data_info` member variable Be more consistent with the names of filter outputs * Adjust test suite to reflect API changes Modify structural tests to accommodate deterministic models Save kalman filter outputs to idata for statespace tests Remove test related to `add_exogenous` Adjust structural module tests * Add JAX test suite * Bug-fixes and changes to statespace distributions Remove tests related to the `add_exogenous` method Add dummy `MvNormalSVDRV` for forward jax sampling with `method="SVD"` Dynamically generate `LinearGaussianStateSpaceRV` signature from inputs Add signature and simple test for `SequenceMvNormal` * Re-run example notebooks * Add helper function to sample prior/posterior statespace matrices * fix tests * Wrap jax MvNormal rewrite in try/except block * Don't use `action` keyword in `catch_warnings` * Skip JAX test if `numpyro` is not installed * Handle batch dims on `SequenceMvNormal` * Remove unused batch_dim logic in SequenceMvNormal * Restore `get_support_shape_1d` import
1 parent 51704bd commit e85677b

18 files changed

+2090
-1684
lines changed

Diff for: notebooks/Making a Custom Statespace Model.ipynb

+63-137
Large diffs are not rendered by default.

Diff for: notebooks/SARMA Example.ipynb

+603-572
Large diffs are not rendered by default.

Diff for: notebooks/Structural Timeseries Modeling.ipynb

+344-405
Large diffs are not rendered by default.

Diff for: notebooks/VARMAX Example.ipynb

+345-288
Large diffs are not rendered by default.

Diff for: pymc_experimental/statespace/core/representation.py

+27-6
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import copy
12
from typing import Optional, Type, Union
23

34
import numpy as np
@@ -10,7 +11,7 @@
1011
)
1112

1213
floatX = pytensor.config.floatX
13-
KeyLike = Union[tuple[Union[str, int]], str]
14+
KeyLike = Union[tuple[str | int, ...], str]
1415

1516

1617
class PytensorRepresentation:
@@ -152,6 +153,22 @@ class PytensorRepresentation:
152153
http://www.chadfulton.com/files/fulton_statsmodels_2017_v1.pdf
153154
"""
154155

156+
__slots__ = (
157+
"k_endog",
158+
"k_states",
159+
"k_posdef",
160+
"shapes",
161+
"design",
162+
"obs_intercept",
163+
"obs_cov",
164+
"transition",
165+
"state_intercept",
166+
"selection",
167+
"state_cov",
168+
"initial_state",
169+
"initial_state_cov",
170+
)
171+
155172
def __init__(
156173
self,
157174
k_endog: int,
@@ -206,16 +223,17 @@ def _validate_key(self, key: KeyLike) -> None:
206223
if key not in self.shapes:
207224
raise IndexError(f"{key} is an invalid state space matrix name")
208225

209-
def _update_shape(self, key: KeyLike, value: Union[np.ndarray, pt.TensorType]) -> None:
226+
def _update_shape(self, key: KeyLike, value: Union[np.ndarray, pt.Variable]) -> None:
210227
if isinstance(value, (pt.TensorConstant, pt.TensorVariable)):
211228
shape = value.type.shape
212229
else:
213230
shape = value.shape
214231

215232
old_shape = self.shapes[key]
216-
if not all([a == b for a, b in zip(shape[1:], old_shape[1:])]):
233+
ndim_core = 1 if key in VECTOR_VALUED else 2
234+
if not all([a == b for a, b in zip(shape[-ndim_core:], old_shape[-ndim_core:])]):
217235
raise ValueError(
218-
f"The last two dimensions of {key} must be {old_shape[1:]}, found {shape[1:]}"
236+
f"The last two dimensions of {key} must be {old_shape[-ndim_core:]}, found {shape[-ndim_core:]}"
219237
)
220238

221239
# Add time dimension dummy if none present
@@ -229,7 +247,7 @@ def _update_shape(self, key: KeyLike, value: Union[np.ndarray, pt.TensorType]) -
229247

230248
def _add_time_dim_to_slice(
231249
self, name: str, slice_: Union[list[int], tuple[int]], n_dim: int
232-
) -> tuple[int]:
250+
) -> tuple[int | slice, ...]:
233251
# Case 1: There is never a time dim. No changes needed.
234252
if name in NEVER_TIME_VARYING:
235253
return slice_
@@ -389,7 +407,7 @@ def __getitem__(self, key: KeyLike) -> pt.TensorVariable:
389407
else:
390408
raise IndexError("First index must the name of a valid state space matrix.")
391409

392-
def __setitem__(self, key: KeyLike, value: Union[float, int, np.ndarray]) -> None:
410+
def __setitem__(self, key: KeyLike, value: Union[float, int, np.ndarray, pt.Variable]) -> None:
393411
_type = type(key)
394412

395413
# Case 1: key is a string: we are setting an entire matrix.
@@ -416,3 +434,6 @@ def __setitem__(self, key: KeyLike, value: Union[float, int, np.ndarray]) -> Non
416434
matrix.name = name
417435

418436
setattr(self, name, matrix)
437+
438+
def copy(self):
439+
return copy.copy(self)

0 commit comments

Comments
 (0)