Skip to content

Commit 77a3f8f

Browse files
5hv5hvnktwiecki
andauthored
added predict_posterior (#90)
* added predict_posterior * changed tests * Update pymc_experimental/model_builder.py Co-authored-by: Thomas Wiecki <[email protected]>
1 parent e64d1f2 commit 77a3f8f

File tree

2 files changed

+74
-8
lines changed

2 files changed

+74
-8
lines changed

Diff for: pymc_experimental/model_builder.py

+42-8
Original file line numberDiff line numberDiff line change
@@ -245,10 +245,49 @@ def fit(self, data: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]] = None
245245
def predict(
246246
self,
247247
data_prediction: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]] = None,
248-
point_estimate: bool = True,
249248
):
250249
"""
251-
Uses model to predict on unseen data.
250+
Uses model to predict on unseen data and return point prediction of all the samples
251+
252+
Parameters
253+
---------
254+
data_prediction : Dictionary of string and either of numpy array, pandas dataframe or pandas Series
255+
It is the data we need to make prediction on using the model.
256+
257+
Returns
258+
-------
259+
returns dictionary of sample's mean of posterior predict.
260+
261+
Examples
262+
--------
263+
>>> data, model_config, sampler_config = LinearModel.create_sample_input()
264+
>>> model = LinearModel(model_config, sampler_config)
265+
>>> idata = model.fit(data)
266+
>>> x_pred = []
267+
>>> prediction_data = pd.DataFrame({'input':x_pred})
268+
# point predict
269+
>>> pred_mean = model.predict(prediction_data)
270+
"""
271+
272+
if data_prediction is not None: # set new input data
273+
self._data_setter(data_prediction)
274+
275+
with self.model: # sample with new input data
276+
post_pred = pm.sample_posterior_predictive(self.idata)
277+
278+
# reshape output
279+
post_pred = self._extract_samples(post_pred)
280+
for key in post_pred:
281+
post_pred[key] = post_pred[key].mean(axis=0)
282+
283+
return post_pred
284+
285+
def predict_posterior(
286+
self,
287+
data_prediction: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]] = None,
288+
):
289+
"""
290+
Uses model to predict samples on unseen data.
252291
253292
Parameters
254293
---------
@@ -268,10 +307,8 @@ def predict(
268307
>>> idata = model.fit(data)
269308
>>> x_pred = []
270309
>>> prediction_data = pd.DataFrame({'input':x_pred})
271-
# only point estimate
272-
>>> pred_mean = model.predict(prediction_data)
273310
# samples
274-
>>> pred_samples = model.predict(prediction_data, point_estimate=False)
311+
>>> pred_mean = model.predict_posterior(prediction_data)
275312
"""
276313

277314
if data_prediction is not None: # set new input data
@@ -282,9 +319,6 @@ def predict(
282319

283320
# reshape output
284321
post_pred = self._extract_samples(post_pred)
285-
if point_estimate: # average, if point-like estimate desired
286-
for key in post_pred:
287-
post_pred[key] = post_pred[key].mean(axis=0)
288322

289323
return post_pred
290324

Diff for: pymc_experimental/tests/test_model_builder.py

+32
Original file line numberDiff line numberDiff line change
@@ -135,3 +135,35 @@ def test_predict():
135135
y_test = pm.sample_posterior_predictive(idata)
136136

137137
assert str(model_2.idata.groups) == str(idata.groups)
138+
139+
140+
def test_predict_posterior():
141+
x_pred = np.random.uniform(low=0, high=1, size=100)
142+
prediction_data = pd.DataFrame({"input": x_pred})
143+
data, model_config, sampler_config = test_ModelBuilder.create_sample_input()
144+
model_2 = test_ModelBuilder(model_config, sampler_config, data)
145+
model_2.idata = model_2.fit()
146+
model_2.predict_posterior(prediction_data)
147+
with pm.Model() as model:
148+
x = np.linspace(start=0, stop=1, num=100)
149+
y = 5 * x + 3
150+
x = pm.MutableData("x", x)
151+
y_data = pm.MutableData("y_data", y)
152+
a_loc = 7
153+
a_scale = 3
154+
b_loc = 5
155+
b_scale = 3
156+
obs_error = 2
157+
158+
a = pm.Normal("a", a_loc, sigma=a_scale)
159+
b = pm.Normal("b", b_loc, sigma=b_scale)
160+
obs_error = pm.HalfNormal("σ_model_fmc", obs_error)
161+
162+
y_model = pm.Normal("y_model", a + b * x, obs_error, observed=y_data)
163+
164+
idata = pm.sample(tune=10, draws=20, chains=3, cores=1)
165+
idata.extend(pm.sample_prior_predictive())
166+
idata.extend(pm.sample_posterior_predictive(idata))
167+
y_test = pm.sample_posterior_predictive(idata)
168+
169+
assert str(model_2.idata.groups) == str(idata.groups)

0 commit comments

Comments
 (0)