Skip to content

Commit 1d8f7f5

Browse files
authored
Filter warning for batched_dot until we change it (#455)
1 parent 1fff560 commit 1d8f7f5

File tree

2 files changed

+16
-10
lines changed

2 files changed

+16
-10
lines changed

Diff for: pymc_extras/inference/pathfinder/pathfinder.py

+13-9
Original file line numberDiff line numberDiff line change
@@ -502,7 +502,9 @@ def bfgs_sample_dense(
502502

503503
logdet = 2.0 * pt.sum(pt.log(pt.abs(pt.diagonal(Lchol, axis1=-2, axis2=-1))), axis=-1)
504504

505-
mu = x - pt.batched_dot(H_inv, g)
505+
with _warnings.catch_warnings():
506+
_warnings.simplefilter("ignore", category=FutureWarning)
507+
mu = x - pt.batched_dot(H_inv, g)
506508

507509
phi = pt.matrix_transpose(
508510
# (L, N, 1)
@@ -572,14 +574,16 @@ def bfgs_sample_sparse(
572574
logdet += pt.sum(pt.log(alpha), axis=-1)
573575

574576
# NOTE: changed the sign from "x + " to "x -" of the expression to match Stan which differs from Zhang et al., (2022). same for dense version.
575-
mu = x - (
576-
# (L, N), (L, N) -> (L, N)
577-
pt.batched_dot(alpha_diag, g)
578-
# beta @ gamma @ beta.T
579-
# (L, N, 2J), (L, 2J, 2J), (L, 2J, N) -> (L, N, N)
580-
# (L, N, N), (L, N) -> (L, N)
581-
+ pt.batched_dot((beta @ gamma @ pt.matrix_transpose(beta)), g)
582-
)
577+
with _warnings.catch_warnings():
578+
_warnings.simplefilter("ignore", category=FutureWarning)
579+
mu = x - (
580+
# (L, N), (L, N) -> (L, N)
581+
pt.batched_dot(alpha_diag, g)
582+
# beta @ gamma @ beta.T
583+
# (L, N, 2J), (L, 2J, 2J), (L, 2J, N) -> (L, N, N)
584+
# (L, N, N), (L, N) -> (L, N)
585+
+ pt.batched_dot((beta @ gamma @ pt.matrix_transpose(beta)), g)
586+
)
583587

584588
phi = pt.matrix_transpose(
585589
# (L, N, 1)

Diff for: tests/test_pathfinder.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@
1818
import pymc as pm
1919
import pytest
2020

21-
pytestmark = pytest.mark.filterwarnings("ignore:compile_pymc was renamed to compile:FutureWarning")
21+
pytestmark = pytest.mark.filterwarnings(
22+
"ignore:compile_pymc was renamed to compile:FutureWarning",
23+
)
2224

2325
import pymc_extras as pmx
2426

0 commit comments

Comments
 (0)