Skip to content

Commit 712bb7a

Browse files
5hv5hvnk and twiecki authored
Updated load method in ModelBuilder (#104)
* added predict_posterior * changed tests * Update pymc_experimental/model_builder.py * updated load method * updated load method * added required changes * added required changes * fixed load function * fixed load function * Update pymc_experimental/model_builder.py * Update pymc_experimental/model_builder.py * Apply suggestions from code review * Restructure tests. Fix load(). * Bump isort. * Skip saving on windows. * Different approach. * Different approach. * Different approach. * Different approach. --------- Co-authored-by: Thomas Wiecki <[email protected]>
1 parent e8b1ad7 commit 712bb7a

File tree

3 files changed

+50
-102
lines changed

3 files changed

+50
-102
lines changed

Diff for: .pre-commit-config.yaml

+1 -1
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,7 @@ repos:
1111
args: [--branch, main]
1212
- id: trailing-whitespace
1313
- repo: https://github.com/PyCQA/isort
14-
rev: 5.10.1
14+
rev: 5.12.0
1515
hooks:
1616
- id: isort
1717
name: isort

Diff for: pymc_experimental/model_builder.py

+21 -19
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2022 The PyMC Developers
1+
# Copyright 2023 The PyMC Developers
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
@@ -185,20 +185,19 @@ def load(cls, fname):
185185
"""
186186

187187
filepath = Path(str(fname))
188-
data = az.from_netcdf(filepath)
189-
idata = data
190-
# Since there is an issue with attrs getting saved in netcdf format which will be fixed in future the following part of code is commented
191-
# Link of issue -> https://github.com/arviz-devs/arviz/issues/2109
192-
# if model.idata.attrs is not None:
193-
# if model.idata.attrs['id'] == self.idata.attrs['id']:
194-
# self = model
195-
# self.idata = data
196-
# return self
197-
# else:
198-
# raise ValueError(
199-
# f"The route '{file}' does not contain an inference data of the same model '{self.__name__}'"
200-
# )
201-
return idata
188+
idata = az.from_netcdf(filepath)
189+
self = cls(
190+
dict(zip(idata.attrs["model_config_keys"], idata.attrs["model_config_values"])),
191+
dict(zip(idata.attrs["sample_config_keys"], idata.attrs["sample_config_values"])),
192+
idata.fit_data.to_dataframe(),
193+
)
194+
self.idata = idata
195+
if self.id() != idata.attrs["id"]:
196+
raise ValueError(
197+
f"The file '{fname}' does not contain an inference data of the same model or configuration as '{self._model_type}'"
198+
)
199+
200+
return self
202201

203202
def fit(self, data: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]] = None):
204203
"""
@@ -238,8 +237,11 @@ def fit(self, data: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]] = None
238237
self.idata.attrs["id"] = self.id()
239238
self.idata.attrs["model_type"] = self._model_type
240239
self.idata.attrs["version"] = self.version
241-
self.idata.attrs["sample_conifg"] = self.sample_config
242-
self.idata.attrs["model_config"] = self.model_config
240+
self.idata.attrs["sample_config_keys"] = tuple(self.sample_config.keys())
241+
self.idata.attrs["sample_config_values"] = tuple(self.sample_config.values())
242+
self.idata.attrs["model_config_keys"] = tuple(self.model_config.keys())
243+
self.idata.attrs["model_config_values"] = tuple(self.model_config.values())
244+
self.idata.add_groups(fit_data=self.data.to_xarray())
243245
return self.idata
244246

245247
def predict(
@@ -306,7 +308,7 @@ def predict_posterior(
306308
>>> model = LinearModel(model_config, sampler_config)
307309
>>> idata = model.fit(data)
308310
>>> x_pred = []
309-
>>> prediction_data = pd.DataFrame({'input':x_pred})
311+
>>> prediction_data = pd.DataFrame({'input': x_pred})
310312
# samples
311313
>>> pred_mean = model.predict_posterior(prediction_data)
312314
"""
@@ -355,5 +357,5 @@ def id(self):
355357
hasher.update(str(self.model_config.values()).encode())
356358
hasher.update(self.version.encode())
357359
hasher.update(self._model_type.encode())
358-
hasher.update(str(self.sample_config.values()).encode())
360+
# hasher.update(str(self.sample_config.values()).encode())
359361
return hasher.hexdigest()[:16]

Diff for: pymc_experimental/tests/test_model_builder.py

+28 -82
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2022 The PyMC Developers
1+
# Copyright 2023 The PyMC Developers
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
@@ -13,9 +13,13 @@
1313
# limitations under the License.
1414

1515

16+
import sys
17+
import tempfile
18+
1619
import numpy as np
1720
import pandas as pd
1821
import pymc as pm
22+
import pytest
1923

2024
from pymc_experimental.model_builder import ModelBuilder
2125

@@ -77,93 +81,35 @@ def create_sample_input(cls):
7781

7882

7983
def test_fit():
80-
with pm.Model() as model:
81-
x = np.linspace(start=0, stop=1, num=100)
82-
y = 5 * x + 3
83-
x = pm.MutableData("x", x)
84-
y_data = pm.MutableData("y_data", y)
85-
86-
a_loc = 7
87-
a_scale = 3
88-
b_loc = 5
89-
b_scale = 3
90-
obs_error = 2
91-
92-
a = pm.Normal("a", a_loc, sigma=a_scale)
93-
b = pm.Normal("b", b_loc, sigma=b_scale)
94-
obs_error = pm.HalfNormal("σ_model_fmc", obs_error)
95-
96-
y_model = pm.Normal("y_model", a + b * x, obs_error, observed=y_data)
97-
98-
idata = pm.sample(tune=100, draws=200, chains=1, cores=1, target_accept=0.5)
99-
idata.extend(pm.sample_prior_predictive())
100-
idata.extend(pm.sample_posterior_predictive(idata))
101-
10284
data, model_config, sampler_config = test_ModelBuilder.create_sample_input()
103-
model_2 = test_ModelBuilder(model_config, sampler_config, data)
104-
model_2.idata = model_2.fit()
105-
assert str(model_2.idata.groups) == str(idata.groups)
106-
85+
model = test_ModelBuilder(model_config, sampler_config, data)
86+
model.fit()
87+
assert model.idata is not None
88+
assert "posterior" in model.idata.groups()
10789

108-
def test_predict():
10990
x_pred = np.random.uniform(low=0, high=1, size=100)
11091
prediction_data = pd.DataFrame({"input": x_pred})
111-
data, model_config, sampler_config = test_ModelBuilder.create_sample_input()
112-
model_2 = test_ModelBuilder(model_config, sampler_config, data)
113-
model_2.idata = model_2.fit()
114-
model_2.predict(prediction_data)
115-
with pm.Model() as model:
116-
x = np.linspace(start=0, stop=1, num=100)
117-
y = 5 * x + 3
118-
x = pm.MutableData("x", x)
119-
y_data = pm.MutableData("y_data", y)
120-
a_loc = 7
121-
a_scale = 3
122-
b_loc = 5
123-
b_scale = 3
124-
obs_error = 2
125-
126-
a = pm.Normal("a", a_loc, sigma=a_scale)
127-
b = pm.Normal("b", b_loc, sigma=b_scale)
128-
obs_error = pm.HalfNormal("σ_model_fmc", obs_error)
129-
130-
y_model = pm.Normal("y_model", a + b * x, obs_error, observed=y_data)
92+
pred = model.predict(prediction_data)
93+
assert "y_model" in pred.keys()
94+
post_pred = model.predict_posterior(prediction_data)
95+
assert "y_model" in post_pred.keys()
13196

132-
idata = pm.sample(tune=10, draws=20, chains=3, cores=1)
133-
idata.extend(pm.sample_prior_predictive())
134-
idata.extend(pm.sample_posterior_predictive(idata))
135-
y_test = pm.sample_posterior_predictive(idata)
136-
137-
assert str(model_2.idata.groups) == str(idata.groups)
13897

98+
@pytest.mark.skipif(
99+
sys.platform == "win32", reason="Permissions for temp files not granted on windows CI."
100+
)
101+
def test_save_load():
102+
data, model_config, sampler_config = test_ModelBuilder.create_sample_input()
103+
model = test_ModelBuilder(model_config, sampler_config, data)
104+
model.fit()
105+
temp = tempfile.NamedTemporaryFile(mode="w", encoding="utf-8", delete=False)
106+
model.save(temp.name)
107+
model2 = test_ModelBuilder.load(temp.name)
108+
assert model.idata.groups() == model2.idata.groups()
139109

140-
def test_predict_posterior():
141110
x_pred = np.random.uniform(low=0, high=1, size=100)
142111
prediction_data = pd.DataFrame({"input": x_pred})
143-
data, model_config, sampler_config = test_ModelBuilder.create_sample_input()
144-
model_2 = test_ModelBuilder(model_config, sampler_config, data)
145-
model_2.idata = model_2.fit()
146-
model_2.predict_posterior(prediction_data)
147-
with pm.Model() as model:
148-
x = np.linspace(start=0, stop=1, num=100)
149-
y = 5 * x + 3
150-
x = pm.MutableData("x", x)
151-
y_data = pm.MutableData("y_data", y)
152-
a_loc = 7
153-
a_scale = 3
154-
b_loc = 5
155-
b_scale = 3
156-
obs_error = 2
157-
158-
a = pm.Normal("a", a_loc, sigma=a_scale)
159-
b = pm.Normal("b", b_loc, sigma=b_scale)
160-
obs_error = pm.HalfNormal("σ_model_fmc", obs_error)
161-
162-
y_model = pm.Normal("y_model", a + b * x, obs_error, observed=y_data)
163-
164-
idata = pm.sample(tune=10, draws=20, chains=3, cores=1)
165-
idata.extend(pm.sample_prior_predictive())
166-
idata.extend(pm.sample_posterior_predictive(idata))
167-
y_test = pm.sample_posterior_predictive(idata)
168-
169-
assert str(model_2.idata.groups) == str(idata.groups)
112+
pred1 = model.predict(prediction_data)
113+
pred2 = model2.predict(prediction_data)
114+
assert pred1["y_model"].shape == pred2["y_model"].shape
115+
temp.close()

0 commit comments

Comments
 (0)