|
1 |
| -# Copyright 2022 The PyMC Developers |
| 1 | +# Copyright 2023 The PyMC Developers |
2 | 2 | #
|
3 | 3 | # Licensed under the Apache License, Version 2.0 (the "License");
|
4 | 4 | # you may not use this file except in compliance with the License.
|
|
13 | 13 | # limitations under the License.
|
14 | 14 |
|
15 | 15 |
|
| 16 | +import sys |
| 17 | +import tempfile |
| 18 | + |
16 | 19 | import numpy as np
|
17 | 20 | import pandas as pd
|
18 | 21 | import pymc as pm
|
| 22 | +import pytest |
19 | 23 |
|
20 | 24 | from pymc_experimental.model_builder import ModelBuilder
|
21 | 25 |
|
@@ -77,93 +81,35 @@ def create_sample_input(cls):
|
77 | 81 |
|
78 | 82 |
|
79 | 83 | def test_fit():
|
80 |
| - with pm.Model() as model: |
81 |
| - x = np.linspace(start=0, stop=1, num=100) |
82 |
| - y = 5 * x + 3 |
83 |
| - x = pm.MutableData("x", x) |
84 |
| - y_data = pm.MutableData("y_data", y) |
85 |
| - |
86 |
| - a_loc = 7 |
87 |
| - a_scale = 3 |
88 |
| - b_loc = 5 |
89 |
| - b_scale = 3 |
90 |
| - obs_error = 2 |
91 |
| - |
92 |
| - a = pm.Normal("a", a_loc, sigma=a_scale) |
93 |
| - b = pm.Normal("b", b_loc, sigma=b_scale) |
94 |
| - obs_error = pm.HalfNormal("σ_model_fmc", obs_error) |
95 |
| - |
96 |
| - y_model = pm.Normal("y_model", a + b * x, obs_error, observed=y_data) |
97 |
| - |
98 |
| - idata = pm.sample(tune=100, draws=200, chains=1, cores=1, target_accept=0.5) |
99 |
| - idata.extend(pm.sample_prior_predictive()) |
100 |
| - idata.extend(pm.sample_posterior_predictive(idata)) |
101 |
| - |
102 | 84 | data, model_config, sampler_config = test_ModelBuilder.create_sample_input()
|
103 |
| - model_2 = test_ModelBuilder(model_config, sampler_config, data) |
104 |
| - model_2.idata = model_2.fit() |
105 |
| - assert str(model_2.idata.groups) == str(idata.groups) |
106 |
| - |
| 85 | + model = test_ModelBuilder(model_config, sampler_config, data) |
| 86 | + model.fit() |
| 87 | + assert model.idata is not None |
| 88 | + assert "posterior" in model.idata.groups() |
107 | 89 |
|
108 |
| -def test_predict(): |
109 | 90 | x_pred = np.random.uniform(low=0, high=1, size=100)
|
110 | 91 | prediction_data = pd.DataFrame({"input": x_pred})
|
111 |
| - data, model_config, sampler_config = test_ModelBuilder.create_sample_input() |
112 |
| - model_2 = test_ModelBuilder(model_config, sampler_config, data) |
113 |
| - model_2.idata = model_2.fit() |
114 |
| - model_2.predict(prediction_data) |
115 |
| - with pm.Model() as model: |
116 |
| - x = np.linspace(start=0, stop=1, num=100) |
117 |
| - y = 5 * x + 3 |
118 |
| - x = pm.MutableData("x", x) |
119 |
| - y_data = pm.MutableData("y_data", y) |
120 |
| - a_loc = 7 |
121 |
| - a_scale = 3 |
122 |
| - b_loc = 5 |
123 |
| - b_scale = 3 |
124 |
| - obs_error = 2 |
125 |
| - |
126 |
| - a = pm.Normal("a", a_loc, sigma=a_scale) |
127 |
| - b = pm.Normal("b", b_loc, sigma=b_scale) |
128 |
| - obs_error = pm.HalfNormal("σ_model_fmc", obs_error) |
129 |
| - |
130 |
| - y_model = pm.Normal("y_model", a + b * x, obs_error, observed=y_data) |
| 92 | + pred = model.predict(prediction_data) |
| 93 | + assert "y_model" in pred.keys() |
| 94 | + post_pred = model.predict_posterior(prediction_data) |
| 95 | + assert "y_model" in post_pred.keys() |
131 | 96 |
|
132 |
| - idata = pm.sample(tune=10, draws=20, chains=3, cores=1) |
133 |
| - idata.extend(pm.sample_prior_predictive()) |
134 |
| - idata.extend(pm.sample_posterior_predictive(idata)) |
135 |
| - y_test = pm.sample_posterior_predictive(idata) |
136 |
| - |
137 |
| - assert str(model_2.idata.groups) == str(idata.groups) |
138 | 97 |
|
| 98 | +@pytest.mark.skipif( |
| 99 | + sys.platform == "win32", reason="Permissions for temp files not granted on windows CI." |
| 100 | +) |
| 101 | +def test_save_load(): |
| 102 | + data, model_config, sampler_config = test_ModelBuilder.create_sample_input() |
| 103 | + model = test_ModelBuilder(model_config, sampler_config, data) |
| 104 | + model.fit() |
| 105 | + temp = tempfile.NamedTemporaryFile(mode="w", encoding="utf-8", delete=False) |
| 106 | + model.save(temp.name) |
| 107 | + model2 = test_ModelBuilder.load(temp.name) |
| 108 | + assert model.idata.groups() == model2.idata.groups() |
139 | 109 |
|
140 |
| -def test_predict_posterior(): |
141 | 110 | x_pred = np.random.uniform(low=0, high=1, size=100)
|
142 | 111 | prediction_data = pd.DataFrame({"input": x_pred})
|
143 |
| - data, model_config, sampler_config = test_ModelBuilder.create_sample_input() |
144 |
| - model_2 = test_ModelBuilder(model_config, sampler_config, data) |
145 |
| - model_2.idata = model_2.fit() |
146 |
| - model_2.predict_posterior(prediction_data) |
147 |
| - with pm.Model() as model: |
148 |
| - x = np.linspace(start=0, stop=1, num=100) |
149 |
| - y = 5 * x + 3 |
150 |
| - x = pm.MutableData("x", x) |
151 |
| - y_data = pm.MutableData("y_data", y) |
152 |
| - a_loc = 7 |
153 |
| - a_scale = 3 |
154 |
| - b_loc = 5 |
155 |
| - b_scale = 3 |
156 |
| - obs_error = 2 |
157 |
| - |
158 |
| - a = pm.Normal("a", a_loc, sigma=a_scale) |
159 |
| - b = pm.Normal("b", b_loc, sigma=b_scale) |
160 |
| - obs_error = pm.HalfNormal("σ_model_fmc", obs_error) |
161 |
| - |
162 |
| - y_model = pm.Normal("y_model", a + b * x, obs_error, observed=y_data) |
163 |
| - |
164 |
| - idata = pm.sample(tune=10, draws=20, chains=3, cores=1) |
165 |
| - idata.extend(pm.sample_prior_predictive()) |
166 |
| - idata.extend(pm.sample_posterior_predictive(idata)) |
167 |
| - y_test = pm.sample_posterior_predictive(idata) |
168 |
| - |
169 |
| - assert str(model_2.idata.groups) == str(idata.groups) |
| 112 | + pred1 = model.predict(prediction_data) |
| 113 | + pred2 = model2.predict(prediction_data) |
| 114 | + assert pred1["y_model"].shape == pred2["y_model"].shape |
| 115 | + temp.close() |
0 commit comments