Skip to content

Updated load method in ModelBuilder #104

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 21 commits into from
Feb 15, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -11,7 +11,7 @@ repos:
args: [--branch, main]
- id: trailing-whitespace
- repo: https://github.com/PyCQA/isort
rev: 5.10.1
rev: 5.12.0
hooks:
- id: isort
name: isort
40 changes: 21 additions & 19 deletions pymc_experimental/model_builder.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2022 The PyMC Developers
# Copyright 2023 The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -185,20 +185,19 @@ def load(cls, fname):
"""

filepath = Path(str(fname))
data = az.from_netcdf(filepath)
idata = data
# The following code is commented out because of an issue with attrs getting saved in netcdf format, which will be fixed in the future.
# Issue link: https://github.com/arviz-devs/arviz/issues/2109
# if model.idata.attrs is not None:
# if model.idata.attrs['id'] == self.idata.attrs['id']:
# self = model
# self.idata = data
# return self
# else:
# raise ValueError(
# f"The route '{file}' does not contain an inference data of the same model '{self.__name__}'"
# )
return idata
idata = az.from_netcdf(filepath)
self = cls(
dict(zip(idata.attrs["model_config_keys"], idata.attrs["model_config_values"])),
dict(zip(idata.attrs["sample_config_keys"], idata.attrs["sample_config_values"])),
idata.fit_data.to_dataframe(),
)
self.idata = idata
if self.id() != idata.attrs["id"]:
raise ValueError(
f"The file '{fname}' does not contain an inference data of the same model or configuration as '{self._model_type}'"
)

return self

def fit(self, data: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]] = None):
"""
@@ -238,8 +237,11 @@ def fit(self, data: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]] = None
self.idata.attrs["id"] = self.id()
self.idata.attrs["model_type"] = self._model_type
self.idata.attrs["version"] = self.version
self.idata.attrs["sample_conifg"] = self.sample_config
self.idata.attrs["model_config"] = self.model_config
self.idata.attrs["sample_config_keys"] = tuple(self.sample_config.keys())
self.idata.attrs["sample_config_values"] = tuple(self.sample_config.values())
self.idata.attrs["model_config_keys"] = tuple(self.model_config.keys())
self.idata.attrs["model_config_values"] = tuple(self.model_config.values())
self.idata.add_groups(fit_data=self.data.to_xarray())
return self.idata

def predict(
@@ -306,7 +308,7 @@ def predict_posterior(
>>> model = LinearModel(model_config, sampler_config)
>>> idata = model.fit(data)
>>> x_pred = []
>>> prediction_data = pd.DataFrame({'input':x_pred})
>>> prediction_data = pd.DataFrame({'input': x_pred})
# samples
>>> pred_mean = model.predict_posterior(prediction_data)
"""
@@ -355,5 +357,5 @@ def id(self):
hasher.update(str(self.model_config.values()).encode())
hasher.update(self.version.encode())
hasher.update(self._model_type.encode())
hasher.update(str(self.sample_config.values()).encode())
# hasher.update(str(self.sample_config.values()).encode())
return hasher.hexdigest()[:16]
110 changes: 28 additions & 82 deletions pymc_experimental/tests/test_model_builder.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2022 The PyMC Developers
# Copyright 2023 The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,9 +13,13 @@
# limitations under the License.


import sys
import tempfile

import numpy as np
import pandas as pd
import pymc as pm
import pytest

from pymc_experimental.model_builder import ModelBuilder

@@ -77,93 +81,35 @@ def create_sample_input(cls):


def test_fit():
with pm.Model() as model:
x = np.linspace(start=0, stop=1, num=100)
y = 5 * x + 3
x = pm.MutableData("x", x)
y_data = pm.MutableData("y_data", y)

a_loc = 7
a_scale = 3
b_loc = 5
b_scale = 3
obs_error = 2

a = pm.Normal("a", a_loc, sigma=a_scale)
b = pm.Normal("b", b_loc, sigma=b_scale)
obs_error = pm.HalfNormal("σ_model_fmc", obs_error)

y_model = pm.Normal("y_model", a + b * x, obs_error, observed=y_data)

idata = pm.sample(tune=100, draws=200, chains=1, cores=1, target_accept=0.5)
idata.extend(pm.sample_prior_predictive())
idata.extend(pm.sample_posterior_predictive(idata))

data, model_config, sampler_config = test_ModelBuilder.create_sample_input()
model_2 = test_ModelBuilder(model_config, sampler_config, data)
model_2.idata = model_2.fit()
assert str(model_2.idata.groups) == str(idata.groups)

model = test_ModelBuilder(model_config, sampler_config, data)
model.fit()
assert model.idata is not None
assert "posterior" in model.idata.groups()

def test_predict():
x_pred = np.random.uniform(low=0, high=1, size=100)
prediction_data = pd.DataFrame({"input": x_pred})
data, model_config, sampler_config = test_ModelBuilder.create_sample_input()
model_2 = test_ModelBuilder(model_config, sampler_config, data)
model_2.idata = model_2.fit()
model_2.predict(prediction_data)
with pm.Model() as model:
x = np.linspace(start=0, stop=1, num=100)
y = 5 * x + 3
x = pm.MutableData("x", x)
y_data = pm.MutableData("y_data", y)
a_loc = 7
a_scale = 3
b_loc = 5
b_scale = 3
obs_error = 2

a = pm.Normal("a", a_loc, sigma=a_scale)
b = pm.Normal("b", b_loc, sigma=b_scale)
obs_error = pm.HalfNormal("σ_model_fmc", obs_error)

y_model = pm.Normal("y_model", a + b * x, obs_error, observed=y_data)
pred = model.predict(prediction_data)
assert "y_model" in pred.keys()
post_pred = model.predict_posterior(prediction_data)
assert "y_model" in post_pred.keys()

idata = pm.sample(tune=10, draws=20, chains=3, cores=1)
idata.extend(pm.sample_prior_predictive())
idata.extend(pm.sample_posterior_predictive(idata))
y_test = pm.sample_posterior_predictive(idata)

assert str(model_2.idata.groups) == str(idata.groups)

@pytest.mark.skipif(
sys.platform == "win32", reason="Permissions for temp files not granted on windows CI."
)
def test_save_load():
data, model_config, sampler_config = test_ModelBuilder.create_sample_input()
model = test_ModelBuilder(model_config, sampler_config, data)
model.fit()
temp = tempfile.NamedTemporaryFile(mode="w", encoding="utf-8", delete=False)
model.save(temp.name)
model2 = test_ModelBuilder.load(temp.name)
assert model.idata.groups() == model2.idata.groups()

def test_predict_posterior():
x_pred = np.random.uniform(low=0, high=1, size=100)
prediction_data = pd.DataFrame({"input": x_pred})
data, model_config, sampler_config = test_ModelBuilder.create_sample_input()
model_2 = test_ModelBuilder(model_config, sampler_config, data)
model_2.idata = model_2.fit()
model_2.predict_posterior(prediction_data)
with pm.Model() as model:
x = np.linspace(start=0, stop=1, num=100)
y = 5 * x + 3
x = pm.MutableData("x", x)
y_data = pm.MutableData("y_data", y)
a_loc = 7
a_scale = 3
b_loc = 5
b_scale = 3
obs_error = 2

a = pm.Normal("a", a_loc, sigma=a_scale)
b = pm.Normal("b", b_loc, sigma=b_scale)
obs_error = pm.HalfNormal("σ_model_fmc", obs_error)

y_model = pm.Normal("y_model", a + b * x, obs_error, observed=y_data)

idata = pm.sample(tune=10, draws=20, chains=3, cores=1)
idata.extend(pm.sample_prior_predictive())
idata.extend(pm.sample_posterior_predictive(idata))
y_test = pm.sample_posterior_predictive(idata)

assert str(model_2.idata.groups) == str(idata.groups)
pred1 = model.predict(prediction_data)
pred2 = model2.predict(prediction_data)
assert pred1["y_model"].shape == pred2["y_model"].shape
temp.close()