@@ -245,10 +245,49 @@ def fit(self, data: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]] = None
245
245
def predict (
246
246
self ,
247
247
data_prediction : Dict [str , Union [np .ndarray , pd .DataFrame , pd .Series ]] = None ,
248
- point_estimate : bool = True ,
249
248
):
250
249
"""
251
- Uses model to predict on unseen data.
250
+ Uses model to predict on unseen data and return point prediction of all the samples
251
+
252
+ Parameters
253
+ ---------
254
+ data_prediction : Dictionary of string and either of numpy array, pandas dataframe or pandas Series
255
+ It is the data we need to make prediction on using the model.
256
+
257
+ Returns
258
+ -------
259
+ returns dictionary of sample's mean of posterior predict.
260
+
261
+ Examples
262
+ --------
263
+ >>> data, model_config, sampler_config = LinearModel.create_sample_input()
264
+ >>> model = LinearModel(model_config, sampler_config)
265
+ >>> idata = model.fit(data)
266
+ >>> x_pred = []
267
+ >>> prediction_data = pd.DataFrame({'input':x_pred})
268
+ # point predict
269
+ >>> pred_mean = model.predict(prediction_data)
270
+ """
271
+
272
+ if data_prediction is not None : # set new input data
273
+ self ._data_setter (data_prediction )
274
+
275
+ with self .model : # sample with new input data
276
+ post_pred = pm .sample_posterior_predictive (self .idata )
277
+
278
+ # reshape output
279
+ post_pred = self ._extract_samples (post_pred )
280
+ for key in post_pred :
281
+ post_pred [key ] = post_pred [key ].mean (axis = 0 )
282
+
283
+ return post_pred
284
+
285
+ def predict_posterior (
286
+ self ,
287
+ data_prediction : Dict [str , Union [np .ndarray , pd .DataFrame , pd .Series ]] = None ,
288
+ ):
289
+ """
290
+ Uses model to predict samples on unseen data.
252
291
253
292
Parameters
254
293
---------
@@ -268,10 +307,8 @@ def predict(
268
307
>>> idata = model.fit(data)
269
308
>>> x_pred = []
270
309
>>> prediction_data = pd.DataFrame({'input':x_pred})
271
- # only point estimate
272
- >>> pred_mean = model.predict(prediction_data)
273
310
# samples
274
- >>> pred_samples = model.predict (prediction_data, point_estimate=False )
311
+ >>> pred_mean = model.predict_posterior (prediction_data)
275
312
"""
276
313
277
314
if data_prediction is not None : # set new input data
@@ -282,9 +319,6 @@ def predict(
282
319
283
320
# reshape output
284
321
post_pred = self ._extract_samples (post_pred )
285
- if point_estimate : # average, if point-like estimate desired
286
- for key in post_pred :
287
- post_pred [key ] = post_pred [key ].mean (axis = 0 )
288
322
289
323
return post_pred
290
324
0 commit comments