diff --git a/pytensor/link/numba/dispatch/basic.py b/pytensor/link/numba/dispatch/basic.py
index 845d6afc7a..f6e62ae2f8 100644
--- a/pytensor/link/numba/dispatch/basic.py
+++ b/pytensor/link/numba/dispatch/basic.py
@@ -565,18 +565,19 @@ def specify_shape(x, {create_arg_string(shape_input_names)}):
 def int_to_float_fn(inputs, out_dtype):
     """Create a Numba function that converts integer and boolean ``ndarray``s to floats."""
 
-    if all(
-        input.type.numpy_dtype == np.dtype(out_dtype) for input in inputs
-    ) and isinstance(np.dtype(out_dtype), np.floating):
+    if (
+        all(inp.type.dtype == out_dtype for inp in inputs)
+        and np.dtype(out_dtype).kind == "f"
+    ):
 
-        @numba_njit
+        @numba_njit(inline="always")
         def inputs_cast(x):
             return x
 
-    elif any(i.type.numpy_dtype.kind in "ib" for i in inputs):
+    elif any(i.type.numpy_dtype.kind in "uib" for i in inputs):
         args_dtype = np.dtype(f"f{out_dtype.itemsize}")
 
-        @numba_njit
+        @numba_njit(inline="always")
         def inputs_cast(x):
             return x.astype(args_dtype)
 
@@ -584,7 +585,7 @@ def inputs_cast(x):
         args_dtype_sz = max(_arg.type.numpy_dtype.itemsize for _arg in inputs)
         args_dtype = np.dtype(f"f{args_dtype_sz}")
 
-        @numba_njit
+        @numba_njit(inline="always")
         def inputs_cast(x):
             return x.astype(args_dtype)
 
@@ -593,17 +594,49 @@ def inputs_cast(x):
 
 @numba_funcify.register(Dot)
 def numba_funcify_Dot(op, node, **kwargs):
-    # Numba's `np.dot` does not support integer dtypes, so we need to cast to
-    # float.
+    # Numba's `np.dot` does not support integer dtypes, so we need to cast to float.
+    x, y = node.inputs
+    [out] = node.outputs
 
-    out_dtype = node.outputs[0].type.numpy_dtype
-    inputs_cast = int_to_float_fn(node.inputs, out_dtype)
+    x_dtype = x.type.dtype
+    y_dtype = y.type.dtype
+    dot_dtype = f"float{max((32, out.type.numpy_dtype.itemsize * 8))}"
+    out_dtype = out.type.dtype
 
-    @numba_njit
-    def dot(x, y):
-        return np.asarray(np.dot(inputs_cast(x), inputs_cast(y))).astype(out_dtype)
+    if x_dtype == dot_dtype and y_dtype == dot_dtype:
+
+        @numba_njit
+        def dot(x, y):
+            return np.asarray(np.dot(x, y))
+
+    elif x_dtype == dot_dtype and y_dtype != dot_dtype:
+
+        @numba_njit
+        def dot(x, y):
+            return np.asarray(np.dot(x, y.astype(dot_dtype)))
+
+    elif x_dtype != dot_dtype and y_dtype == dot_dtype:
+
+        @numba_njit
+        def dot(x, y):
+            return np.asarray(np.dot(x.astype(dot_dtype), y))
+
+    else:
+
+        @numba_njit
+        def dot(x, y):
+            return np.asarray(np.dot(x.astype(dot_dtype), y.astype(dot_dtype)))
+
+    if out_dtype == dot_dtype:
+        return dot
+
+    else:
+
+        @numba_njit
+        def dot_with_cast(x, y):
+            return dot(x, y).astype(out_dtype)
 
-    return dot
+    return dot_with_cast
 
 
 @numba_funcify.register(Solve)
diff --git a/tests/link/numba/test_basic.py b/tests/link/numba/test_basic.py
index 9132d7b202..3b880616df 100644
--- a/tests/link/numba/test_basic.py
+++ b/tests/link/numba/test_basic.py
@@ -30,7 +30,7 @@
 from pytensor.link.numba.linker import NumbaLinker
 from pytensor.raise_op import assert_op
 from pytensor.scalar.basic import ScalarOp, as_scalar
-from pytensor.tensor import blas
+from pytensor.tensor import blas, tensor
 from pytensor.tensor.elemwise import Elemwise
 from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
 from pytensor.tensor.sort import ArgSortOp, SortOp
@@ -603,43 +603,41 @@ def test_perform_type_convert():
 
 
 @pytest.mark.parametrize(
-    "x, y, exc",
+    "x, y",
     [
         (
             (pt.matrix(), rng.random(size=(3, 2)).astype(config.floatX)),
             (pt.vector(), rng.random(size=(2,)).astype(config.floatX)),
-            None,
         ),
         (
             (pt.matrix(dtype="float64"), rng.random(size=(3, 2)).astype("float64")),
             (pt.vector(dtype="float32"), rng.random(size=(2,)).astype("float32")),
-            None,
         ),
         (
             (pt.lmatrix(), rng.poisson(size=(3, 2))),
             (pt.fvector(), rng.random(size=(2,)).astype("float32")),
-            None,
         ),
         (
             (pt.lvector(), rng.random(size=(2,)).astype(np.int64)),
             (pt.lvector(), rng.random(size=(2,)).astype(np.int64)),
-            None,
+        ),
+        (
+            (pt.vector(dtype="int16"), rng.random(size=(2,)).astype(np.int16)),
+            (pt.vector(dtype="uint8"), rng.random(size=(2,)).astype(np.uint8)),
         ),
     ],
 )
-def test_Dot(x, y, exc):
+def test_Dot(x, y):
     x, x_test_value = x
     y, y_test_value = y
 
     g = ptm.Dot()(x, y)
 
-    cm = contextlib.suppress() if exc is None else pytest.warns(exc)
-    with cm:
-        compare_numba_and_py(
-            [x, y],
-            [g],
-            [x_test_value, y_test_value],
-        )
+    compare_numba_and_py(
+        [x, y],
+        [g],
+        [x_test_value, y_test_value],
+    )
 
 
 @pytest.mark.parametrize(
@@ -937,3 +935,18 @@ def test_Nonzero(input_data):
     compare_numba_and_py(
         graph_inputs=[a], graph_outputs=graph_outputs, test_inputs=[input_data]
     )
+
+
+@pytest.mark.parametrize("dtype", ("float64", "float32", "mixed"))
+def test_mat_vec_dot_performance(dtype, benchmark):
+    A = tensor("A", shape=(512, 512), dtype="float64" if dtype == "mixed" else dtype)
+    x = tensor("x", shape=(512,), dtype="float32" if dtype == "mixed" else dtype)
+    out = ptm.dot(A, x)
+
+    fn = function([A, x], out, mode="NUMBA", trust_input=True)
+
+    rng = np.random.default_rng(948)
+    A_test = rng.standard_normal(size=A.type.shape, dtype=A.type.dtype)
+    x_test = rng.standard_normal(size=x.type.shape, dtype=x.type.dtype)
+    np.testing.assert_allclose(fn(A_test, x_test), np.dot(A_test, x_test), atol=1e-4)
+    benchmark(fn, A_test, x_test)
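
Reviewer note: the sketch below is a minimal, pure-NumPy model of the casting policy the new numba_funcify_Dot implements; it is not part of the patch. pick_dot_kernel is a hypothetical helper name, and the runtime `if` checks inside dot stand in for the four branches the patch unrolls at funcify time, so that Numba compiles monomorphic kernels with no per-call dtype dispatch.

    # Pure-NumPy model of the casting decisions in numba_funcify_Dot.
    # pick_dot_kernel is a hypothetical helper; the patch instead emits one of
    # four @numba_njit kernels when funcifying the graph, so these `if`s run
    # once at compile time rather than on every call.
    import numpy as np


    def pick_dot_kernel(x_dtype, y_dtype, out_dtype):
        # Numba's np.dot lacks integer support: compute in a float dtype at
        # least as wide as the output, and never narrower than float32.
        dot_dtype = f"float{max(32, np.dtype(out_dtype).itemsize * 8)}"

        def dot(x, y):
            if x_dtype != dot_dtype:
                x = x.astype(dot_dtype)  # cast only the operand that needs it
            if y_dtype != dot_dtype:
                y = y.astype(dot_dtype)
            return np.asarray(np.dot(x, y))

        if out_dtype == dot_dtype:
            return dot  # result already has the requested dtype
        return lambda x, y: dot(x, y).astype(out_dtype)


    # int16 @ uint8 with an int64 output: both operands are upcast to float64,
    # the product is computed there, and the result is cast back to int64.
    f = pick_dot_kernel("int16", "uint8", "int64")
    print(f(np.ones((3, 2), "int16"), np.ones(2, "uint8")).dtype)  # int64

Because the all-float64 and all-float32 branches perform no astype at all, they avoid the input copies the old int_to_float_fn path always paid for, which is what test_mat_vec_dot_performance benchmarks, including the mixed float64/float32 case.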