Fix numba 0.61 compatibility #1186

Merged: 2 commits, Feb 3, 2025
18 changes: 0 additions & 18 deletions pytensor/link/numba/dispatch/basic.py
@@ -1,7 +1,6 @@
import operator
import sys
import warnings
from contextlib import contextmanager
from copy import copy
from functools import singledispatch
from textwrap import dedent
@@ -362,23 +361,6 @@ def create_arg_string(x):
return args


@contextmanager
def use_optimized_cheap_pass(*args, **kwargs):
"""Temporarily replace the cheap optimization pass with a better one."""
from numba.core.registry import cpu_target

context = cpu_target.target_context._internal_codegen
old_pm = context._mpm_cheap
new_pm = context._module_pass_manager(
loop_vectorize=True, slp_vectorize=True, opt=3, cost="cheap"
)
context._mpm_cheap = new_pm
try:
yield
finally:
context._mpm_cheap = old_pm


@singledispatch
def numba_typify(data, dtype=None, **kwargs):
return data
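The deleted `use_optimized_cheap_pass` helper reached into numba's private codegen state (`cpu_target.target_context._internal_codegen._mpm_cheap`) to temporarily swap the "cheap" optimization pass for a more aggressive one. Those internals presumably changed in numba 0.61, so this PR drops the trick instead of porting it and compiles the affected functions with a plain `numba_njit(..., boundscheck=False)` call, as seen in elemwise.py below. A minimal sketch of that plain compilation path, using `numba.njit` directly rather than pytensor's `numba_njit` wrapper (the reduction function here is illustrative, not taken from the codebase):

```python
import numpy as np
from numba import njit


def careduce_like_sum(x):
    # A toy nested reduction loop, similar in shape to the kernels that
    # numba_funcify_CAReduce generates from the graph.
    acc = 0.0
    for i in range(x.shape[0]):
        for j in range(x.shape[1]):
            acc += x[i, j]
    return acc


# Plain lazy compilation with bounds checking disabled: no optimization-pass
# juggling, and argument types are inferred on the first call.
compiled = njit(careduce_like_sum, boundscheck=False)
print(compiled(np.ones((3, 4))))  # 12.0
```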
52 changes: 4 additions & 48 deletions pytensor/link/numba/dispatch/elemwise.py
@@ -9,10 +9,8 @@
from pytensor.graph.op import Op
from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.basic import (
create_numba_signature,
numba_funcify,
numba_njit,
use_optimized_cheap_pass,
)
from pytensor.link.numba.dispatch.vectorize_codegen import (
_jit_options,
@@ -245,47 +243,6 @@ def {careduce_fn_name}(x):
return careduce_fn


def jit_compile_reducer(
node, fn, *, reduce_to_scalar=False, infer_signature=True, **kwds
):
"""Compile Python source for reduction loops using additional optimizations.

Parameters
==========
node
An node from which the signature can be derived.
fn
The Python function object to compile.
reduce_to_scalar: bool, default False
Whether to reduce output to a scalar (instead of 0d array)
infer_signature: bool: default True
Whether to try and infer the function signature from the Apply node.
kwds
Extra keywords to be added to the :func:`numba.njit` function.

Returns
=======
A :func:`numba.njit`-compiled function.

"""
if infer_signature:
signature = create_numba_signature(node, reduce_to_scalar=reduce_to_scalar)
args = (signature,)
else:
args = ()

# Eagerly compile the function using increased optimizations. This should
# help improve nested loop reductions.
with use_optimized_cheap_pass():
res = numba_basic.numba_njit(
*args,
boundscheck=False,
**kwds,
)(fn)

return res


def create_axis_apply_fn(fn, axis, ndim, dtype):
axis = normalize_axis_index(axis, ndim)

@@ -448,7 +405,7 @@ def numba_funcify_CAReduce(op, node, **kwargs):
np.dtype(node.outputs[0].type.dtype),
)

careduce_fn = jit_compile_reducer(node, careduce_py_fn, reduce_to_scalar=False)
careduce_fn = numba_njit(careduce_py_fn, boundscheck=False)
return careduce_fn


@@ -579,7 +536,7 @@ def softmax_py_fn(x):
sm = e_x / w
return sm

softmax = jit_compile_reducer(node, softmax_py_fn)
softmax = numba_njit(softmax_py_fn, boundscheck=False)

return softmax

@@ -608,8 +565,7 @@ def softmax_grad_py_fn(dy, sm):
dx = dy_times_sm - sum_dy_times_sm * sm
return dx

# The signature inferred by jit_compile_reducer is wrong when dy is a constant (readonly=True)
softmax_grad = jit_compile_reducer(node, softmax_grad_py_fn, infer_signature=False)
softmax_grad = numba_njit(softmax_grad_py_fn, boundscheck=False)

return softmax_grad

@@ -647,7 +603,7 @@ def log_softmax_py_fn(x):
lsm = xdev - np.log(reduce_sum(np.exp(xdev)))
return lsm

log_softmax = jit_compile_reducer(node, log_softmax_py_fn)
log_softmax = numba_njit(log_softmax_py_fn, boundscheck=False)
return log_softmax


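The pattern of the elemwise.py change: every `jit_compile_reducer(node, fn, ...)` call becomes `numba_njit(fn, boundscheck=False)`. The removed helper derived an explicit numba signature from the Apply node via `create_numba_signature` and compiled eagerly under the optimized cheap pass; the replacement simply lets numba infer types on the first call, which also sidesteps the readonly-input signature problem noted in the removed softmax_grad comment. A hedged sketch of the two styles with a simplified 1-d softmax kernel (the explicit signature below is illustrative, not what `create_numba_signature` would produce):

```python
import numpy as np
from numba import float64, njit


def softmax_py_fn(x):
    # Simplified 1-d stand-in for the generated softmax kernel.
    e_x = np.exp(x - x.max())
    return e_x / e_x.sum()


# Roughly the old style: eager compilation against an explicit argument
# signature (jit_compile_reducer derived one from the Apply node).
eager = njit((float64[::1],), boundscheck=False)(softmax_py_fn)

# The style this PR switches to: lazy compilation, types inferred at first call.
lazy = njit(softmax_py_fn, boundscheck=False)

x = np.random.default_rng(0).normal(size=8)
np.testing.assert_allclose(eager(x), lazy(x))
```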
138 changes: 68 additions & 70 deletions tests/link/numba/test_elemwise.py
@@ -130,25 +130,6 @@ def test_elemwise_runtime_broadcast():
check_elemwise_runtime_broadcast(get_mode("NUMBA"))


def test_elemwise_speed(benchmark):
x = pt.dmatrix("y")
y = pt.dvector("z")

out = np.exp(2 * x * y + y)

rng = np.random.default_rng(42)

x_val = rng.normal(size=(200, 500))
y_val = rng.normal(size=500)

func = function([x, y], out, mode="NUMBA")
func = func.vm.jit_fn
(out,) = func(x_val, y_val)
np.testing.assert_allclose(np.exp(2 * x_val * y_val + y_val), out)

benchmark(func, x_val, y_val)


@pytest.mark.parametrize(
"v, new_order",
[
@@ -631,41 +612,6 @@ def test_Argmax(x, axes, exc):
)


@pytest.mark.parametrize("size", [(10, 10), (1000, 1000), (10000, 10000)])
@pytest.mark.parametrize("axis", [0, 1])
def test_logsumexp_benchmark(size, axis, benchmark):
X = pt.matrix("X")
X_max = pt.max(X, axis=axis, keepdims=True)
X_max = pt.switch(pt.isinf(X_max), 0, X_max)
X_lse = pt.log(pt.sum(pt.exp(X - X_max), axis=axis, keepdims=True)) + X_max

rng = np.random.default_rng(23920)
X_val = rng.normal(size=size)

X_lse_fn = pytensor.function([X], X_lse, mode="NUMBA")

# JIT compile first
res = X_lse_fn(X_val)
exp_res = scipy.special.logsumexp(X_val, axis=axis, keepdims=True)
np.testing.assert_array_almost_equal(res, exp_res)
benchmark(X_lse_fn, X_val)


def test_fused_elemwise_benchmark(benchmark):
rng = np.random.default_rng(123)
size = 100_000
x = pytensor.shared(rng.normal(size=size), name="x")
mu = pytensor.shared(rng.normal(size=size), name="mu")

logp = -((x - mu) ** 2) / 2
grad_logp = grad(logp.sum(), x)

func = pytensor.function([], [logp, grad_logp], mode="NUMBA")
# JIT compile first
func()
benchmark(func)


def test_elemwise_out_type():
# Create a graph with an elemwise
# Ravel fails if the elemwise output type is reported incorrectly
@@ -681,22 +627,6 @@ def test_elemwise_out_type():
assert func(x_val).shape == (18,)


@pytest.mark.parametrize(
"axis",
(0, 1, 2, (0, 1), (0, 2), (1, 2), None),
ids=lambda x: f"axis={x}",
)
@pytest.mark.parametrize(
"c_contiguous",
(True, False),
ids=lambda x: f"c_contiguous={x}",
)
def test_numba_careduce_benchmark(axis, c_contiguous, benchmark):
return careduce_benchmark_tester(
axis, c_contiguous, mode="NUMBA", benchmark=benchmark
)


def test_scalar_loop():
a = float64("a")
scalar_loop = pytensor.scalar.ScalarLoop([a], [a + a])
@@ -709,3 +639,71 @@ def test_scalar_loop():
([x], [elemwise_loop]),
(np.array([1, 2, 3], dtype="float64"),),
)


class TestsBenchmark:
def test_elemwise_speed(self, benchmark):
x = pt.dmatrix("y")
y = pt.dvector("z")

out = np.exp(2 * x * y + y)

rng = np.random.default_rng(42)

x_val = rng.normal(size=(200, 500))
y_val = rng.normal(size=500)

func = function([x, y], out, mode="NUMBA")
func = func.vm.jit_fn
(out,) = func(x_val, y_val)
np.testing.assert_allclose(np.exp(2 * x_val * y_val + y_val), out)

benchmark(func, x_val, y_val)

def test_fused_elemwise_benchmark(self, benchmark):
rng = np.random.default_rng(123)
size = 100_000
x = pytensor.shared(rng.normal(size=size), name="x")
mu = pytensor.shared(rng.normal(size=size), name="mu")

logp = -((x - mu) ** 2) / 2
grad_logp = grad(logp.sum(), x)

func = pytensor.function([], [logp, grad_logp], mode="NUMBA")
# JIT compile first
func()
benchmark(func)

@pytest.mark.parametrize("size", [(10, 10), (1000, 1000), (10000, 10000)])
@pytest.mark.parametrize("axis", [0, 1])
def test_logsumexp_benchmark(self, size, axis, benchmark):
X = pt.matrix("X")
X_max = pt.max(X, axis=axis, keepdims=True)
X_max = pt.switch(pt.isinf(X_max), 0, X_max)
X_lse = pt.log(pt.sum(pt.exp(X - X_max), axis=axis, keepdims=True)) + X_max

rng = np.random.default_rng(23920)
X_val = rng.normal(size=size)

X_lse_fn = pytensor.function([X], X_lse, mode="NUMBA")

# JIT compile first
res = X_lse_fn(X_val)
exp_res = scipy.special.logsumexp(X_val, axis=axis, keepdims=True)
np.testing.assert_array_almost_equal(res, exp_res)
benchmark(X_lse_fn, X_val)

@pytest.mark.parametrize(
"axis",
(0, 1, 2, (0, 1), (0, 2), (1, 2), None),
ids=lambda x: f"axis={x}",
)
@pytest.mark.parametrize(
"c_contiguous",
(True, False),
ids=lambda x: f"c_contiguous={x}",
)
def test_numba_careduce_benchmark(self, axis, c_contiguous, benchmark):
return careduce_benchmark_tester(
axis, c_contiguous, mode="NUMBA", benchmark=benchmark
)