Skip to content

Commit 370ffe0

Browse files
committed
Added initvals to parameters, constants and observations to returnvalue for pathfinder and cleaned relevant docs a bit
1 parent 7d62c53 commit 370ffe0

File tree

6 files changed

+33
-16
lines changed

6 files changed

+33
-16
lines changed

Diff for: docs/api_reference.rst

+3
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,10 @@ Inference
2323
.. autosummary::
2424
:toctree: generated/
2525

26+
find_MAP
2627
fit
28+
fit_laplace
29+
fit_pathfinder
2730

2831

2932
Distributions

Diff for: pymc_extras/__init__.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,7 @@
1515

1616
from pymc_extras import gp, statespace, utils
1717
from pymc_extras.distributions import *
18-
from pymc_extras.inference.find_map import find_MAP
19-
from pymc_extras.inference.fit import fit
20-
from pymc_extras.inference.laplace import fit_laplace
18+
from pymc_extras.inference import find_MAP, fit, fit_laplace, fit_pathfinder
2119
from pymc_extras.model.marginal.marginal_model import (
2220
MarginalModel,
2321
marginalize,

Diff for: pymc_extras/inference/__init__.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
15+
from pymc_extras.inference.find_map import find_MAP
1616
from pymc_extras.inference.fit import fit
17+
from pymc_extras.inference.laplace import fit_laplace
18+
from pymc_extras.inference.pathfinder.pathfinder import fit_pathfinder
1719

18-
__all__ = ["fit"]
20+
__all__ = ["fit", "fit_pathfinder", "fit_laplace", "find_MAP"]

Diff for: pymc_extras/inference/fit.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -15,19 +15,20 @@
1515

1616
def fit(method, **kwargs):
1717
"""
18-
Fit a model with an inference algorithm
18+
Fit a model with an inference algorithm.
19+
See :func:`fit_pathfinder` and :func:`fit_laplace` for more details.
1920
2021
Parameters
2122
----------
2223
method : str
2324
Which inference method to run.
2425
Supported: pathfinder or laplace
2526
26-
kwargs are passed on.
27+
kwargs: keyword arguments are passed on to the inference method.
2728
2829
Returns
2930
-------
30-
arviz.InferenceData
31+
:class:`~arviz.InferenceData`
3132
"""
3233
if method == "pathfinder":
3334
from pymc_extras.inference.pathfinder import fit_pathfinder

Diff for: pymc_extras/inference/laplace.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -509,7 +509,7 @@ def fit_laplace(
509509
510510
Returns
511511
-------
512-
idata: az.InferenceData
512+
:class:`~arviz.InferenceData`
513513
An InferenceData object containing the approximated posterior samples.
514514
515515
Examples

Diff for: pymc_extras/inference/pathfinder/pathfinder.py

+20-7
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
1516
import collections
1617
import logging
1718
import time
@@ -67,6 +68,7 @@
6768
# TODO: change to typing.Self after Python versions greater than 3.10
6869
from typing_extensions import Self
6970

71+
from pymc_extras.inference.laplace import add_data_to_inferencedata
7072
from pymc_extras.inference.pathfinder.importance_sampling import (
7173
importance_sampling as _importance_sampling,
7274
)
@@ -1630,6 +1632,7 @@ def fit_pathfinder(
16301632
inference_backend: Literal["pymc", "blackjax"] = "pymc",
16311633
pathfinder_kwargs: dict = {},
16321634
compile_kwargs: dict = {},
1635+
initvals: dict | None = None,
16331636
) -> az.InferenceData:
16341637
"""
16351638
Fit the Pathfinder Variational Inference algorithm.
@@ -1665,12 +1668,12 @@ def fit_pathfinder(
16651668
importance_sampling : str, None, optional
16661669
Method to apply sampling based on log importance weights (logP - logQ).
16671670
Options are:
1668-
"psis" : Pareto Smoothed Importance Sampling (default)
1669-
Recommended for more stable results.
1670-
"psir" : Pareto Smoothed Importance Resampling
1671-
Less stable than PSIS.
1672-
"identity" : Applies log importance weights directly without resampling.
1673-
None : No importance sampling weights. Returns raw samples of size (num_paths, num_draws_per_path, N) where N is number of model parameters. Other methods return samples of size (num_draws, N).
1671+
1672+
- **"psis"** : Pareto Smoothed Importance Sampling (default). Usually most stable.
1673+
- **"psir"** : Pareto Smoothed Importance Resampling. Less stable than PSIS.
1674+
- **"identity"** : Applies log importance weights directly without resampling.
1675+
- **None** : No importance sampling weights. Returns raw samples of size (num_paths, num_draws_per_path, N) where N is number of model parameters. Other methods return samples of size (num_draws, N).
1676+
16741677
progressbar : bool, optional
16751678
Whether to display a progress bar (default is True). Setting this to False will likely reduce the computation time.
16761679
random_seed : RandomSeed, optional
@@ -1685,17 +1688,24 @@ def fit_pathfinder(
16851688
Additional keyword arguments for the Pathfinder algorithm.
16861689
compile_kwargs
16871690
Additional keyword arguments for the PyTensor compiler. If not provided, the default linker is "cvm_nogc".
1691+
initvals: dict | None = None
1692+
Initial values for the model parameters, as str:ndarray key-value pairs. Paritial initialization is permitted.
1693+
If None, the model's default initial values are used.
16881694
16891695
Returns
16901696
-------
1691-
arviz.InferenceData
1697+
:class:`~arviz.InferenceData`
16921698
The inference data containing the results of the Pathfinder algorithm.
16931699
16941700
References
16951701
----------
16961702
Zhang, L., Carpenter, B., Gelman, A., & Vehtari, A. (2022). Pathfinder: Parallel quasi-Newton variational inference. Journal of Machine Learning Research, 23(306), 1-49.
16971703
"""
16981704

1705+
if initvals is not None:
1706+
for rv_name, ivals in initvals.items():
1707+
model.set_initval(model.named_vars[rv_name], ivals)
1708+
16991709
model = modelcontext(model)
17001710

17011711
valid_importance_sampling = {"psis", "psir", "identity", None}
@@ -1775,4 +1785,7 @@ def fit_pathfinder(
17751785
model=model,
17761786
importance_sampling=importance_sampling,
17771787
)
1788+
1789+
idata = add_data_to_inferencedata(idata, progressbar, model, compile_kwargs)
1790+
17781791
return idata

0 commit comments

Comments
 (0)