diff --git a/pymc_extras/inference/fit.py b/pymc_extras/inference/fit.py index bb695113..60d89777 100644 --- a/pymc_extras/inference/fit.py +++ b/pymc_extras/inference/fit.py @@ -11,9 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import arviz as az -def fit(method, **kwargs): +def fit(method: str, **kwargs) -> az.InferenceData: """ Fit a model with an inference algorithm diff --git a/pymc_extras/inference/pathfinder/pathfinder.py b/pymc_extras/inference/pathfinder/pathfinder.py index dfe5fc6a..531efc56 100644 --- a/pymc_extras/inference/pathfinder/pathfinder.py +++ b/pymc_extras/inference/pathfinder/pathfinder.py @@ -21,11 +21,9 @@ from collections.abc import Callable, Iterator from dataclasses import asdict, dataclass, field, replace from enum import Enum, auto -from importlib.util import find_spec from typing import Literal, TypeAlias import arviz as az -import blackjax import filelock import jax import numpy as np @@ -1736,8 +1734,8 @@ def fit_pathfinder( ) pathfinder_samples = mp_result.samples elif inference_backend == "blackjax": - if find_spec("blackjax") is None: - raise RuntimeError("Need BlackJAX to use `pathfinder`") + import blackjax + if version.parse(blackjax.__version__).major < 1: raise ImportError("fit_pathfinder requires blackjax 1.0 or above")