Skip to content

Commit 17d8d19

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 607e0d5 commit 17d8d19

7 files changed

+80
-48
lines changed

notebooks/Making a Custom Statespace Model.ipynb

+8-9
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,13 @@
1919
"\n",
2020
"numpyro.set_host_device_count(4)\n",
2121
"\n",
22-
"import numpy as np\n",
23-
"import matplotlib.pyplot as plt\n",
2422
"import arviz as az\n",
25-
"\n",
26-
"from pymc_experimental.statespace.core.statespace import PyMCStateSpace\n",
23+
"import matplotlib.pyplot as plt\n",
24+
"import numpy as np\n",
25+
"import pymc as pm\n",
2726
"import pytensor.tensor as pt\n",
28-
"import pymc as pm"
27+
"\n",
28+
"from pymc_experimental.statespace.core.statespace import PyMCStateSpace"
2929
]
3030
},
3131
{
@@ -1092,7 +1092,7 @@
10921092
],
10931093
"source": [
10941094
"az.plot_posterior(\n",
1095-
" idata, var_names=[\"ar_params\", \"sigma_x\"], ref_val=true_ar.tolist() + [true_sigma_x]\n",
1095+
" idata, var_names=[\"ar_params\", \"sigma_x\"], ref_val=[*true_ar.tolist(), true_sigma_x]\n",
10961096
");"
10971097
]
10981098
},
@@ -1169,13 +1169,12 @@
11691169
"metadata": {},
11701170
"outputs": [],
11711171
"source": [
1172+
"from pymc_experimental.statespace.models.utilities import make_default_coords\n",
11721173
"from pymc_experimental.statespace.utils.constants import (\n",
1173-
" ALL_STATE_DIM,\n",
11741174
" ALL_STATE_AUX_DIM,\n",
1175-
" OBS_STATE_DIM,\n",
1175+
" ALL_STATE_DIM,\n",
11761176
" SHOCK_DIM,\n",
11771177
")\n",
1178-
"from pymc_experimental.statespace.models.utilities import make_default_coords\n",
11791178
"\n",
11801179
"\n",
11811180
"class AutoRegressiveThree(PyMCStateSpace):\n",

notebooks/SARMA Example.ipynb

+8-9
Original file line numberDiff line numberDiff line change
@@ -31,20 +31,19 @@
3131
"\n",
3232
"numpyro.set_host_device_count(8)\n",
3333
"\n",
34-
"import pymc as pm\n",
35-
"from pytensor import tensor as pt\n",
36-
"\n",
3734
"import arviz as az\n",
38-
"import statsmodels.api as sm\n",
3935
"import matplotlib.pyplot as plt\n",
4036
"import numpy as np\n",
4137
"import pandas as pd\n",
42-
"from scipy import stats\n",
38+
"import pymc as pm\n",
39+
"import statsmodels.api as sm\n",
4340
"\n",
44-
"import pymc_experimental.statespace as pmss\n",
41+
"from pymc.model.transform.optimization import freeze_dims_and_data\n",
42+
"from pytensor import tensor as pt\n",
4543
"from pytensor.link.jax.dispatch import jax_funcify\n",
4644
"from pytensor.tensor.nlinalg import KroneckerProduct\n",
47-
"from pymc.model.transform.optimization import freeze_dims_and_data\n",
45+
"\n",
46+
"import pymc_experimental.statespace as pmss\n",
4847
"\n",
4948
"\n",
5049
"@jax_funcify.register(KroneckerProduct)\n",
@@ -2582,8 +2581,8 @@
25822581
"source": [
25832582
"fig, ax = plt.subplots()\n",
25842583
"post = az.extract(post_pred).map(np.exp)\n",
2585-
"hdi = az.hdi(post_pred.map(np.exp))[f\"predicted_posterior_observed\"]\n",
2586-
"post[f\"predicted_posterior_observed\"].isel(observed_state=0).mean(dim=\"sample\").plot.line(\n",
2584+
"hdi = az.hdi(post_pred.map(np.exp))[\"predicted_posterior_observed\"]\n",
2585+
"post[\"predicted_posterior_observed\"].isel(observed_state=0).mean(dim=\"sample\").plot.line(\n",
25872586
" x=\"time\", ax=ax, add_legend=False, label=\"Posterior Mean, Predicted\"\n",
25882587
")\n",
25892588
"ax.fill_between(\n",

notebooks/Structural Timeseries Modeling.ipynb

+10-10
Original file line numberDiff line numberDiff line change
@@ -22,17 +22,18 @@
2222
"import sys\n",
2323
"\n",
2424
"sys.path.append(\"..\")\n",
25-
"from pymc_experimental.statespace import structural as st\n",
26-
"from pymc_experimental.statespace.utils.constants import SHORT_NAME_TO_LONG, MATRIX_NAMES\n",
27-
"import matplotlib.pyplot as plt\n",
28-
"import pymc as pm\n",
2925
"import arviz as az\n",
30-
"import pytensor\n",
31-
"import pytensor.tensor as pt\n",
26+
"import matplotlib.pyplot as plt\n",
3227
"import numpy as np\n",
3328
"import pandas as pd\n",
29+
"import pymc as pm\n",
30+
"import pytensor.tensor as pt\n",
31+
"\n",
3432
"from patsy import dmatrix\n",
3533
"\n",
34+
"from pymc_experimental.statespace import structural as st\n",
35+
"from pymc_experimental.statespace.utils.constants import SHORT_NAME_TO_LONG\n",
36+
"\n",
3637
"plt.rcParams.update(\n",
3738
" {\n",
3839
" \"figure.figsize\": (14, 4),\n",
@@ -61,15 +62,14 @@
6162
},
6263
"outputs": [],
6364
"source": [
64-
"from pymc_experimental.statespace.filters.kalman_filter import StandardFilter\n",
65-
"from pymc_experimental.statespace.filters.kalman_smoother import KalmanSmoother\n",
65+
"from pymc.pytensorf import compile_pymc, inputvars\n",
66+
"\n",
6667
"from pymc_experimental.statespace.filters.distributions import LinearGaussianStateSpace\n",
67-
"from pymc.pytensorf import inputvars, compile_pymc\n",
6868
"\n",
6969
"\n",
7070
"def make_numpy_function(mod):\n",
7171
" mod = mod.build(verbose=False)\n",
72-
" data = pt.matrix(\"data\", shape=(None, 1))\n",
72+
" pt.matrix(\"data\", shape=(None, 1))\n",
7373
" steps = pt.iscalar(\"steps\")\n",
7474
" x0, _, c, d, T, Z, R, H, Q = mod._unpack_statespace_with_placeholders()\n",
7575
" sequence_names = [x.name for x in [c, d] if x.ndim == 2]\n",

notebooks/VARMAX Example.ipynb

+8-8
Original file line numberDiff line numberDiff line change
@@ -14,21 +14,21 @@
1414
"\n",
1515
"numpyro.set_host_device_count(8)\n",
1616
"\n",
17+
"import sys\n",
18+
"\n",
19+
"import arviz as az\n",
20+
"import matplotlib.pyplot as plt\n",
1721
"import numpy as np\n",
18-
"import statsmodels.api as sm\n",
1922
"import pandas as pd\n",
20-
"\n",
2123
"import pymc as pm\n",
2224
"import pytensor.tensor as pt\n",
23-
"import arviz as az\n",
24-
"\n",
25-
"import matplotlib.pyplot as plt\n",
26-
"import sys\n",
25+
"import statsmodels.api as sm\n",
2726
"\n",
2827
"sys.path.append(\"..\")\n",
29-
"import pymc_experimental.statespace as pmss\n",
3028
"import re\n",
3129
"\n",
30+
"import pymc_experimental.statespace as pmss\n",
31+
"\n",
3232
"config = {\n",
3333
" \"figure.figsize\": [12.0, 4.0],\n",
3434
" \"figure.dpi\": 72.0 * 2,\n",
@@ -679,7 +679,7 @@
679679
" new_labels = []\n",
680680
" for label in axis.yaxis.get_majorticklabels():\n",
681681
" old_text = \"[\" + label.get_text().split(\"[\")[-1]\n",
682-
" labels = eval(re.sub(\"([\\d\\w]+)\", '\"\\g<1>\"', old_text))\n",
682+
" labels = eval(re.sub(r\"([\\d\\w]+)\", r'\"\\g<1>\"', old_text))\n",
683683
" lag, other_var = labels\n",
684684
" new_text = f\"L{lag}.{other_var}\"\n",
685685
" new_labels.append(new_text)\n",

0 commit comments

Comments
 (0)