
Fix slow dot in numba #1426


Merged: 1 commit, merged on May 27, 2025
pytensor/link/numba/dispatch/basic.py (63 changes: 48 additions & 15 deletions)

@@ -565,26 +565,27 @@
 def int_to_float_fn(inputs, out_dtype):
     """Create a Numba function that converts integer and boolean ``ndarray``s to floats."""

-    if all(
-        input.type.numpy_dtype == np.dtype(out_dtype) for input in inputs
-    ) and isinstance(np.dtype(out_dtype), np.floating):
+    if (
+        all(inp.type.dtype == out_dtype for inp in inputs)
+        and np.dtype(out_dtype).kind == "f"
+    ):
Member Author commented:

The old check doesn't work, or stopped working at some point:

    assert (
        isinstance(np.dtype("float64"), np.floating)
        or isinstance(np.float64, np.floating)
    )  # fails
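
A quick REPL check, not part of the PR, of why the isinstance test can never hold and of the checks that do work:

    import numpy as np

    # A dtype object is an instance of np.dtype, not of the abstract scalar type:
    isinstance(np.dtype("float64"), np.floating)  # False
    # np.float64 is itself a class, so it is not an *instance* of np.floating either:
    isinstance(np.float64, np.floating)  # False

    # What does work:
    np.dtype("float64").kind == "f"  # True -- the check this PR switches to
    np.issubdtype(np.float64, np.floating)  # True -- NumPy's canonical subtype test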

Member Author commented:

We don't use this function (`int_to_float_fn`) in dot anymore, but it's still used elsewhere.

-        @numba_njit
+        @numba_njit(inline="always")
         def inputs_cast(x):
             return x

-    elif any(i.type.numpy_dtype.kind in "ib" for i in inputs):
+    elif any(i.type.numpy_dtype.kind in "uib" for i in inputs):
         args_dtype = np.dtype(f"f{out_dtype.itemsize}")

-        @numba_njit
+        @numba_njit(inline="always")
         def inputs_cast(x):
             return x.astype(args_dtype)

     else:
         args_dtype_sz = max(_arg.type.numpy_dtype.itemsize for _arg in inputs)
         args_dtype = np.dtype(f"f{args_dtype_sz}")

-        @numba_njit
+        @numba_njit(inline="always")
         def inputs_cast(x):
             return x.astype(args_dtype)

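Side note on the "ib" to "uib" change above: np.dtype.kind is a one-character type code, and the old string silently missed unsigned integers. A quick illustration, not from the diff:

    import numpy as np

    np.dtype("uint8").kind    # 'u' -- unsigned int; not matched by the old "ib" check
    np.dtype("int64").kind    # 'i' -- signed int
    np.dtype("bool").kind     # 'b' -- boolean
    np.dtype("float32").kind  # 'f' -- float, handled by the first branch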
@@ -593,17 +594,49 @@

 @numba_funcify.register(Dot)
 def numba_funcify_Dot(op, node, **kwargs):
-    # Numba's `np.dot` does not support integer dtypes, so we need to cast to
-    # float.
+    # Numba's `np.dot` does not support integer dtypes, so we need to cast to float.
+    x, y = node.inputs
+    [out] = node.outputs

-    out_dtype = node.outputs[0].type.numpy_dtype
-    inputs_cast = int_to_float_fn(node.inputs, out_dtype)
+    x_dtype = x.type.dtype
+    y_dtype = y.type.dtype
+    dot_dtype = f"float{max((32, out.type.numpy_dtype.itemsize * 8))}"
+    out_dtype = out.type.dtype

-    @numba_njit
-    def dot(x, y):
-        return np.asarray(np.dot(inputs_cast(x), inputs_cast(y))).astype(out_dtype)
+    if x_dtype == dot_dtype and y_dtype == dot_dtype:
Member Author commented:

I need these branches; otherwise I get "failed to unify" errors during Numba compilation. It doesn't like the pattern

    if x.dtype != dot_dtype:
        x = x.astype(dot_dtype)

when the cast is actually needed. This can be simplified once astype(copy=False) is implemented in Numba.
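
A minimal, self-contained reproduction of that unification failure (an illustrative sketch, not code from the PR; `maybe_upcast` is a hypothetical name):

    import numba
    import numpy as np

    @numba.njit
    def maybe_upcast(x, needs_cast):
        if needs_cast:
            # Rebinds `x` from array(float32, ...) to array(float64, ...);
            # at the merge point Numba cannot unify the two array types.
            x = x.astype(np.float64)
        return x

    maybe_upcast(np.ones(3, dtype=np.float32), True)  # TypingError: Cannot unify ...

Selecting one of four specialized kernels at funcify time, as the branches below do, sidesteps the rebinding at the cost of some boilerplate.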


+        @numba_njit
+        def dot(x, y):
+            return np.asarray(np.dot(x, y))
+
+    elif x_dtype == dot_dtype and y_dtype != dot_dtype:
+
+        @numba_njit
+        def dot(x, y):
+            return np.asarray(np.dot(x, y.astype(dot_dtype)))
+
+    elif x_dtype != dot_dtype and y_dtype == dot_dtype:
+
+        @numba_njit
+        def dot(x, y):
+            return np.asarray(np.dot(x.astype(dot_dtype), y))

+    else:
+
+        @numba_njit()
+        def dot(x, y):
+            return np.asarray(np.dot(x.astype(dot_dtype), y.astype(dot_dtype)))
+
+    if out_dtype == dot_dtype:
+        return dot
+
+    else:
+
+        @numba_njit
+        def dot_with_cast(x, y):
+            return dot(x, y).astype(out_dtype)

-    return dot
+        return dot_with_cast


 @numba_funcify.register(Solve)
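
For reference, here is what the dot_dtype expression above resolves to for a few output dtypes (a worked example using a hypothetical helper that mirrors the inline f-string):

    import numpy as np

    def dot_dtype_for(out_dtype):
        # Mirrors: f"float{max((32, out.type.numpy_dtype.itemsize * 8))}"
        return f"float{max((32, np.dtype(out_dtype).itemsize * 8))}"

    dot_dtype_for("float16")  # 'float32' -- float16 dots are computed in float32
    dot_dtype_for("float32")  # 'float32'
    dot_dtype_for("float64")  # 'float64'
    dot_dtype_for("int64")    # 'float64' -- integer dots are computed in float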
tests/link/numba/test_basic.py (41 changes: 27 additions & 14 deletions)

@@ -30,7 +30,7 @@
 from pytensor.link.numba.linker import NumbaLinker
 from pytensor.raise_op import assert_op
 from pytensor.scalar.basic import ScalarOp, as_scalar
-from pytensor.tensor import blas
+from pytensor.tensor import blas, tensor
 from pytensor.tensor.elemwise import Elemwise
 from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
 from pytensor.tensor.sort import ArgSortOp, SortOp

@@ -603,43 +603,41 @@ def test_perform_type_convert():
Expand Down Expand Up @@ -603,43 +603,41 @@ def test_perform_type_convert():


 @pytest.mark.parametrize(
-    "x, y, exc",
+    "x, y",
     [
         (
             (pt.matrix(), rng.random(size=(3, 2)).astype(config.floatX)),
             (pt.vector(), rng.random(size=(2,)).astype(config.floatX)),
-            None,
         ),
         (
             (pt.matrix(dtype="float64"), rng.random(size=(3, 2)).astype("float64")),
             (pt.vector(dtype="float32"), rng.random(size=(2,)).astype("float32")),
-            None,
         ),
         (
             (pt.lmatrix(), rng.poisson(size=(3, 2))),
             (pt.fvector(), rng.random(size=(2,)).astype("float32")),
-            None,
         ),
         (
             (pt.lvector(), rng.random(size=(2,)).astype(np.int64)),
             (pt.lvector(), rng.random(size=(2,)).astype(np.int64)),
-            None,
         ),
+        (
+            (pt.vector(dtype="int16"), rng.random(size=(2,)).astype(np.int16)),
+            (pt.vector(dtype="uint8"), rng.random(size=(2,)).astype(np.uint8)),
+        ),
     ],
 )
-def test_Dot(x, y, exc):
+def test_Dot(x, y):
     x, x_test_value = x
     y, y_test_value = y

     g = ptm.Dot()(x, y)

-    cm = contextlib.suppress() if exc is None else pytest.warns(exc)
-    with cm:
-        compare_numba_and_py(
-            [x, y],
-            [g],
-            [x_test_value, y_test_value],
-        )
+    compare_numba_and_py(
+        [x, y],
+        [g],
+        [x_test_value, y_test_value],
+    )


 @pytest.mark.parametrize(

@@ -937,3 +935,18 @@ def test_Nonzero(input_data):

     compare_numba_and_py(
         graph_inputs=[a], graph_outputs=graph_outputs, test_inputs=[input_data]
     )


@pytest.mark.parametrize("dtype", ("float64", "float32", "mixed"))
def test_mat_vec_dot_performance(dtype, benchmark):
A = tensor("A", shape=(512, 512), dtype="float64" if dtype == "mixed" else dtype)
x = tensor("x", shape=(512,), dtype="float32" if dtype == "mixed" else dtype)
out = ptm.dot(A, x)

fn = function([A, x], out, mode="NUMBA", trust_input=True)

rng = np.random.default_rng(948)
A_test = rng.standard_normal(size=A.type.shape, dtype=A.type.dtype)
x_test = rng.standard_normal(size=x.type.shape, dtype=x.type.dtype)
np.testing.assert_allclose(fn(A_test, x_test), np.dot(A_test, x_test), atol=1e-4)
benchmark(fn, A_test, x_test)
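
To try the benchmark locally: the benchmark fixture comes from pytest-benchmark, so (depending on how this repo's pytest configuration treats benchmarks) an invocation along these lines should run the new test:

    pytest tests/link/numba/test_basic.py::test_mat_vec_dot_performance --benchmark-only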