Skip to content

Commit 1fff560

Browse files
authored
Minor fix of blackjax import in fit_pathfinder function (#443)
* Moved the import statement for blackjax to ensure it is only imported when needed. * Moved blackjax import statement prevents import errors for users on Windows. * Updated the fit function to specify the return type as az.InferenceData.
1 parent 7d62c53 commit 1fff560

File tree

2 files changed

+4
-5
lines changed

2 files changed

+4
-5
lines changed

Diff for: pymc_extras/inference/fit.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,10 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import arviz as az
1415

1516

16-
def fit(method, **kwargs):
17+
def fit(method: str, **kwargs) -> az.InferenceData:
1718
"""
1819
Fit a model with an inference algorithm
1920

Diff for: pymc_extras/inference/pathfinder/pathfinder.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,9 @@
2121
from collections.abc import Callable, Iterator
2222
from dataclasses import asdict, dataclass, field, replace
2323
from enum import Enum, auto
24-
from importlib.util import find_spec
2524
from typing import Literal, TypeAlias
2625

2726
import arviz as az
28-
import blackjax
2927
import filelock
3028
import jax
3129
import numpy as np
@@ -1736,8 +1734,8 @@ def fit_pathfinder(
17361734
)
17371735
pathfinder_samples = mp_result.samples
17381736
elif inference_backend == "blackjax":
1739-
if find_spec("blackjax") is None:
1740-
raise RuntimeError("Need BlackJAX to use `pathfinder`")
1737+
import blackjax
1738+
17411739
if version.parse(blackjax.__version__).major < 1:
17421740
raise ImportError("fit_pathfinder requires blackjax 1.0 or above")
17431741

0 commit comments

Comments
 (0)