Skip to content

Commit 0ffdf22

Browse files
committed
Add workaround for floatX == 'float32' and discrete variables
1 parent 22da46a commit 0ffdf22

File tree

2 files changed

+26
-2
lines changed

2 files changed

+26
-2
lines changed

Diff for: pymc3/smc/smc.py

+17-2
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,14 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import warnings
16+
1517
from collections import OrderedDict
1618

1719
import aesara.tensor as at
1820
import numpy as np
1921

22+
from aesara import config
2023
from aesara import function as aesara_function
2124
from scipy.special import logsumexp
2225
from scipy.stats import multivariate_normal
@@ -290,9 +293,21 @@ def logp_forw(point, out_vars, vars, shared):
290293
shared: List
291294
containing :class:`aesara.tensor.Tensor` for depended shared data
292295
"""
296+
293297
out_list, inarray0 = join_nonshared_inputs(point, out_vars, vars, shared)
294-
f = aesara_function([inarray0], out_list[0])
295-
f.trust_input = True
298+
# TODO: Figure out how to safely accept float32 (floatX) input when there are
299+
# discrete variables of int64 dtype in `vars`.
300+
# See https://github.com/pymc-devs/pymc3/pull/4769#issuecomment-861494080
301+
if config.floatX == "float32" and any(var.dtype == "int64" for var in vars):
302+
warnings.warn(
303+
"SMC sampling may run slower due to the presence of discrete variables "
304+
"together with aesara.config.floatX == `float32`",
305+
UserWarning,
306+
)
307+
f = aesara_function([inarray0], out_list[0], allow_input_downcast=True)
308+
else:
309+
f = aesara_function([inarray0], out_list[0])
310+
f.trust_input = False
296311
return f
297312

298313

Diff for: pymc3/tests/test_smc.py

+9
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import aesara
1516
import aesara.tensor as at
1617
import numpy as np
1718
import pytest
@@ -97,6 +98,14 @@ def test_start(self):
9798
}
9899
trace = pm.sample_smc(500, start=start)
99100

101+
def test_slowdown_warning(self):
102+
with aesara.config.change_flags(floatX="float32"):
103+
with pytest.warns(UserWarning, match="SMC sampling may run slower due to"):
104+
with pm.Model() as model:
105+
a = pm.Poisson("a", 5)
106+
y = pm.Normal("y", a, 5, observed=[1, 2, 3, 4])
107+
trace = pm.sample_smc()
108+
100109

101110
@pytest.mark.xfail(reason="SMC-ABC not refactored yet")
102111
class TestSMCABC(SeededTest):

0 commit comments

Comments
 (0)