diff --git a/pytensor/link/numba/dispatch/basic.py b/pytensor/link/numba/dispatch/basic.py
index 2f3cac6ea6..2b264df7e7 100644
--- a/pytensor/link/numba/dispatch/basic.py
+++ b/pytensor/link/numba/dispatch/basic.py
@@ -39,7 +39,7 @@
 from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
 from pytensor.tensor.slinalg import Solve
 from pytensor.tensor.type import TensorType
-from pytensor.tensor.type_other import MakeSlice, NoneConst
+from pytensor.tensor.type_other import MakeSlice, NoneConst, NoneTypeT


 def global_numba_func(func):
@@ -75,7 +75,7 @@ def numba_njit(*args, fastmath=None, **kwargs):
         message=(
             "(\x1b\\[1m)*"  # ansi escape code for bold text
             "Cannot cache compiled function "
-            '"(numba_funcified_fgraph|store_core_outputs|cholesky|solve|solve_triangular|cho_solve)" '
+            '"(store_core_outputs|cholesky|solve|solve_triangular|cho_solve)" '
             "as it uses dynamic globals"
         ),
         category=NumbaWarning,
@@ -477,6 +477,37 @@ def reshape(x, shape):

 @numba_funcify.register(SpecifyShape)
 def numba_funcify_SpecifyShape(op, node, **kwargs):
+    x, *shape = node.inputs
+    ndim = x.type.ndim
+    specified_dims = tuple(not isinstance(dim.type, NoneTypeT) for dim in shape)
+    match (ndim, specified_dims):
+        case (1, (True,)):
+
+            def func(x, shape_0):
+                assert x.shape[0] == shape_0
+                return x
+        case (2, (True, False)):
+
+            def func(x, shape_0, shape_1):
+                assert x.shape[0] == shape_0
+                return x
+        case (2, (False, True)):
+
+            def func(x, shape_0, shape_1):
+                assert x.shape[1] == shape_1
+                return x
+        case (2, (True, True)):
+
+            def func(x, shape_0, shape_1):
+                assert x.shape[0] == shape_0
+                assert x.shape[1] == shape_1
+                return x
+        case _:
+            func = None
+
+    if func is not None:
+        return numba_njit(func)
+
     shape_inputs = node.inputs[1:]
     shape_input_names = ["shape_" + str(i) for i in range(len(shape_inputs))]
diff --git a/pytensor/link/numba/dispatch/elemwise.py b/pytensor/link/numba/dispatch/elemwise.py
index 9fd81dadcf..c81cc89830 100644
--- a/pytensor/link/numba/dispatch/elemwise.py
+++ b/pytensor/link/numba/dispatch/elemwise.py
@@ -411,7 +411,15 @@ def numba_funcify_CAReduce(op, node, **kwargs):


 @numba_funcify.register(DimShuffle)
-def numba_funcify_DimShuffle(op, node, **kwargs):
+def numba_funcify_DimShuffle(op: DimShuffle, node, **kwargs):
+    if op.is_left_expand_dims and op.new_order.count("x") == 1:
+        # Most common case, numba compiles it more quickly
+        @numba_njit
+        def left_expand_dims(x):
+            return np.expand_dims(x, 0)
+
+        return left_expand_dims
+
     # We use `as_strided` to achieve the DimShuffle behavior of transposing and expanding/squezing dimensions in one call
     # Numba doesn't currently support multiple expand/squeeze, and reshape is limited to contiguous arrays.
     new_order = tuple(op._new_order)
diff --git a/pytensor/link/numba/dispatch/scalar.py b/pytensor/link/numba/dispatch/scalar.py
index e9b637b00f..82369c3976 100644
--- a/pytensor/link/numba/dispatch/scalar.py
+++ b/pytensor/link/numba/dispatch/scalar.py
@@ -172,18 +172,64 @@ def {binary_op_name}({input_signature}):

 @numba_funcify.register(Add)
 def numba_funcify_Add(op, node, **kwargs):
+    match len(node.inputs):
+        case 2:
+
+            def add(i0, i1):
+                return i0 + i1
+        case 3:
+
+            def add(i0, i1, i2):
+                return i0 + i1 + i2
+        case 4:
+
+            def add(i0, i1, i2, i3):
+                return i0 + i1 + i2 + i3
+        case 5:
+
+            def add(i0, i1, i2, i3, i4):
+                return i0 + i1 + i2 + i3 + i4
+        case _:
+            add = None
+
+    if add is not None:
+        return numba_basic.numba_njit(add)
+
     signature = create_numba_signature(node, force_scalar=True)
     nary_add_fn = binary_to_nary_func(node.inputs, "add", "+")

-    return numba_basic.numba_njit(signature)(nary_add_fn)
+    return numba_basic.numba_njit(signature, cache=False)(nary_add_fn)


 @numba_funcify.register(Mul)
 def numba_funcify_Mul(op, node, **kwargs):
+    match len(node.inputs):
+        case 2:
+
+            def mul(i0, i1):
+                return i0 * i1
+        case 3:
+
+            def mul(i0, i1, i2):
+                return i0 * i1 * i2
+        case 4:
+
+            def mul(i0, i1, i2, i3):
+                return i0 * i1 * i2 * i3
+        case 5:
+
+            def mul(i0, i1, i2, i3, i4):
+                return i0 * i1 * i2 * i3 * i4
+        case _:
+            mul = None
+
+    if mul is not None:
+        return numba_basic.numba_njit(mul)
+
     signature = create_numba_signature(node, force_scalar=True)
     nary_add_fn = binary_to_nary_func(node.inputs, "mul", "*")

-    return numba_basic.numba_njit(signature)(nary_add_fn)
+    return numba_basic.numba_njit(signature, cache=False)(nary_add_fn)


 @numba_funcify.register(Cast)
@@ -233,7 +279,7 @@ def numba_funcify_Composite(op, node, **kwargs):

     _ = kwargs.pop("storage_map", None)

-    composite_fn = numba_basic.numba_njit(signature)(
+    composite_fn = numba_basic.numba_njit(signature, cache=False)(
         numba_funcify(op.fgraph, squeeze_output=True, **kwargs)
     )
     return composite_fn
diff --git a/pytensor/link/numba/dispatch/subtensor.py b/pytensor/link/numba/dispatch/subtensor.py
index ee9e183d16..c35d49c485 100644
--- a/pytensor/link/numba/dispatch/subtensor.py
+++ b/pytensor/link/numba/dispatch/subtensor.py
@@ -23,6 +23,60 @@
 def numba_funcify_default_subtensor(op, node, **kwargs):
     """Create a Python function that assembles and uses an index on an array."""
+    if isinstance(op, Subtensor) and len(op.idx_list) == 1:
+        # Hard code indices along first dimension to allow caching
+        [idx] = op.idx_list
+
+        if isinstance(idx, slice):
+            slice_info = (
+                idx.start is not None,
+                idx.stop is not None,
+                idx.step is not None,
+            )
+            match slice_info:
+                case (False, False, False):
+
+                    def subtensor(x):
+                        return x
+
+                case (True, False, False):
+
+                    def subtensor(x, start):
+                        return x[start:]
+                case (False, True, False):
+
+                    def subtensor(x, stop):
+                        return x[:stop]
+                case (False, False, True):
+
+                    def subtensor(x, step):
+                        return x[::step]
+
+                case (True, True, False):
+
+                    def subtensor(x, start, stop):
+                        return x[start:stop]
+                case (True, False, True):
+
+                    def subtensor(x, start, step):
+                        return x[start::step]
+                case (False, True, True):
+
+                    def subtensor(x, stop, step):
+                        return x[:stop:step]
+
+                case (True, True, True):
+
+                    def subtensor(x, start, stop, step):
+                        return x[start:stop:step]
+
+        else:
+
+            def subtensor(x, i):
+                return np.asarray(x[i])
+
+        return numba_njit(subtensor)
+
     unique_names = unique_name_generator(
         ["subtensor", "incsubtensor", "z"], suffix_sep="_"
     )
@@ -100,7 +154,7 @@ def {function_name}({", ".join(input_names)}):
".join(input_names)}): function_name=function_name, global_env=globals() | {"np": np}, ) - return numba_njit(func, boundscheck=True) + return numba_njit(func, boundscheck=True, cache=False) @numba_funcify.register(AdvancedSubtensor) @@ -294,7 +348,9 @@ def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs): if broadcast: @numba_njit(boundscheck=True) - def advancedincsubtensor1_inplace(x, val, idxs): + def advanced_incsubtensor1(x, val, idxs): + out = x if inplace else x.copy() + if val.ndim == x.ndim: core_val = val[0] elif val.ndim == 0: @@ -304,24 +360,28 @@ def advancedincsubtensor1_inplace(x, val, idxs): core_val = val for idx in idxs: - x[idx] = core_val - return x + out[idx] = core_val + return out else: @numba_njit(boundscheck=True) - def advancedincsubtensor1_inplace(x, vals, idxs): + def advanced_incsubtensor1(x, vals, idxs): + out = x if inplace else x.copy() + if not len(idxs) == len(vals): raise ValueError("The number of indices and values must match.") # no strict argument because incompatible with numba for idx, val in zip(idxs, vals): # noqa: B905 - x[idx] = val - return x + out[idx] = val + return out else: if broadcast: @numba_njit(boundscheck=True) - def advancedincsubtensor1_inplace(x, val, idxs): + def advanced_incsubtensor1(x, val, idxs): + out = x if inplace else x.copy() + if val.ndim == x.ndim: core_val = val[0] elif val.ndim == 0: @@ -331,29 +391,21 @@ def advancedincsubtensor1_inplace(x, val, idxs): core_val = val for idx in idxs: - x[idx] += core_val - return x + out[idx] += core_val + return out else: @numba_njit(boundscheck=True) - def advancedincsubtensor1_inplace(x, vals, idxs): + def advanced_incsubtensor1(x, vals, idxs): + out = x if inplace else x.copy() + if not len(idxs) == len(vals): raise ValueError("The number of indices and values must match.") # no strict argument because unsupported by numba # TODO: this doesn't come up in tests for idx, val in zip(idxs, vals): # noqa: B905 - x[idx] += val - return x - - if inplace: - return advancedincsubtensor1_inplace - - else: - - @numba_njit - def advancedincsubtensor1(x, vals, idxs): - x = x.copy() - return advancedincsubtensor1_inplace(x, vals, idxs) + out[idx] += val + return out - return advancedincsubtensor1 + return advanced_incsubtensor1 diff --git a/pytensor/link/numba/dispatch/tensor_basic.py b/pytensor/link/numba/dispatch/tensor_basic.py index 8f5972c058..570a234fb9 100644 --- a/pytensor/link/numba/dispatch/tensor_basic.py +++ b/pytensor/link/numba/dispatch/tensor_basic.py @@ -58,6 +58,36 @@ def allocempty({", ".join(shape_var_names)}): @numba_funcify.register(Alloc) def numba_funcify_Alloc(op, node, **kwargs): + x, *shape = node.inputs + if all(x.type.broadcastable): + match len(shape): + case 1: + + def alloc(val, dim0): + shape = (dim0.item(),) + res = np.empty(shape, dtype=val.dtype) + res[...] = val + return res + case 2: + + def alloc(val, dim0, dim1): + shape = (dim0.item(), dim1.item()) + res = np.empty(shape, dtype=val.dtype) + res[...] = val + return res + case 3: + + def alloc(val, dim0, dim1, dim2): + shape = (dim0.item(), dim1.item(), dim2.item()) + res = np.empty(shape, dtype=val.dtype) + res[...] 
+                    return res
+            case _:
+                alloc = None
+
+        if alloc is not None:
+            return numba_basic.numba_njit(alloc)
+
     global_env = {"np": np, "to_scalar": numba_basic.to_scalar}

     unique_names = unique_name_generator(
@@ -68,7 +98,7 @@ def numba_funcify_Alloc(op, node, **kwargs):
     shape_var_item_names = [f"{name}_item" for name in shape_var_names]
     shapes_to_items_src = indent(
         "\n".join(
-            f"{item_name} = to_scalar({shape_name})"
+            f"{item_name} = {shape_name}.item()"
             for item_name, shape_name in zip(
                 shape_var_item_names, shape_var_names, strict=True
             )
@@ -86,12 +116,11 @@

     alloc_def_src = f"""
 def alloc(val, {", ".join(shape_var_names)}):
-    val_np = np.asarray(val)
 {shapes_to_items_src}
     scalar_shape = {create_tuple_string(shape_var_item_names)}
 {check_runtime_broadcast_src}
-    res = np.empty(scalar_shape, dtype=val_np.dtype)
-    res[...] = val_np
+    res = np.empty(scalar_shape, dtype=val.dtype)
+    res[...] = val
     return res
 """
     alloc_fn = compile_function_src(alloc_def_src, "alloc", {**globals(), **global_env})
diff --git a/pytensor/link/numba/dispatch/vectorize_codegen.py b/pytensor/link/numba/dispatch/vectorize_codegen.py
index 74870e29bd..6ad6121719 100644
--- a/pytensor/link/numba/dispatch/vectorize_codegen.py
+++ b/pytensor/link/numba/dispatch/vectorize_codegen.py
@@ -35,6 +35,97 @@ def store_core_outputs(i0, i1, ..., in, o0, o1, ..., on):
         on[...] = ton

     """
+    # Hardcode some cases for numba caching
+    match (nin, nout):
+        case (1, 1):
+
+            def func(i0, o0):
+                t0 = core_op_fn(i0)
+                o0[...] = t0
+        case (1, 2):
+
+            def func(i0, o0, o1):
+                t0, t1 = core_op_fn(i0)
+                o0[...] = t0
+                o1[...] = t1
+        case (1, 3):
+
+            def func(i0, o0, o1, o2):
+                t0, t1, t2 = core_op_fn(i0)
+                o0[...] = t0
+                o1[...] = t1
+                o2[...] = t2
+
+        case (2, 1):
+
+            def func(i0, i1, o0):
+                t0 = core_op_fn(i0, i1)
+                o0[...] = t0
+        case (2, 2):
+
+            def func(i0, i1, o0, o1):
+                t0, t1 = core_op_fn(i0, i1)
+                o0[...] = t0
+                o1[...] = t1
+        case (2, 3):
+
+            def func(i0, i1, o0, o1, o2):
+                t0, t1, t2 = core_op_fn(i0, i1)
+                o0[...] = t0
+                o1[...] = t1
+                o2[...] = t2
+
+        case (3, 1):
+
+            def func(i0, i1, i2, o0):
+                t0 = core_op_fn(i0, i1, i2)
+                o0[...] = t0
+
+        case (3, 2):
+
+            def func(i0, i1, i2, o0, o1):
+                t0, t1 = core_op_fn(i0, i1, i2)
+                o0[...] = t0
+                o1[...] = t1
+        case (3, 3):
+
+            def func(i0, i1, i2, o0, o1, o2):
+                t0, t1, t2 = core_op_fn(i0, i1, i2)
+                o0[...] = t0
+                o1[...] = t1
+                o2[...] = t2
+
+        case (4, 1):
+
+            def func(i0, i1, i2, i3, o0):
+                t0 = core_op_fn(i0, i1, i2, i3)
+                o0[...] = t0
+
+        case (4, 2):
+
+            def func(i0, i1, i2, i3, o0, o1):
+                t0, t1 = core_op_fn(i0, i1, i2, i3)
+                o0[...] = t0
+                o1[...] = t1
+
+        case (5, 1):
+
+            def func(i0, i1, i2, i3, i4, o0):
+                t0 = core_op_fn(i0, i1, i2, i3, i4)
+                o0[...] = t0
+
+        case (5, 2):
+
+            def func(i0, i1, i2, i3, i4, o0, o1):
+                t0, t1 = core_op_fn(i0, i1, i2, i3, i4)
+                o0[...] = t0
+                o1[...] = t1
+        case _:
+            func = None
+
+    if func is not None:
+        return numba_basic.numba_njit(func)
+
     inputs = [f"i{i}" for i in range(nin)]
     outputs = [f"o{i}" for i in range(nout)]
     inner_outputs = [f"t{output}" for output in outputs]
@@ -55,7 +146,7 @@ def store_core_outputs({inp_signature}, {out_signature}):
     func = compile_function_src(
         func_src, "store_core_outputs", {**globals(), **global_env}
     )
-    return cast(Callable, numba_basic.numba_njit(func))
+    return cast(Callable, numba_basic.numba_njit(func, cache=False))


 _jit_options = {
diff --git a/pytensor/link/numba/linker.py b/pytensor/link/numba/linker.py
index 553c5ef217..ccad308db5 100644
--- a/pytensor/link/numba/linker.py
+++ b/pytensor/link/numba/linker.py
@@ -12,7 +12,10 @@ def fgraph_convert(self, fgraph, **kwargs):
     def jit_compile(self, fn):
        from pytensor.link.numba.dispatch.basic import numba_njit

-        jitted_fn = numba_njit(fn, no_cpython_wrapper=False, no_cfunc_wrapper=False)
+        # NUMBA can't cache our dynamically generated funcified_fgraph
+        jitted_fn = numba_njit(
+            fn, no_cpython_wrapper=False, no_cfunc_wrapper=False, cache=False
+        )
         return jitted_fn

     def create_thunk_inputs(self, storage_map):
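The unifying theme of this patch is numba's on-disk cache: functions defined statically in a real source file can be cached and reused across processes, while functions compiled from generated source (as `compile_function_src` produces) trigger numba's "Cannot cache compiled function ... as it uses dynamic globals" warning, the same message filtered in basic.py, and are now explicitly marked `cache=False`. Below is a minimal standalone sketch of that distinction, assuming only that numba is installed; `add3` and `addn` are illustrative names, not identifiers from this patch:

import numba

# Statically defined in a source file: numba can cache the compiled
# machine code to disk, so later processes skip recompilation.
@numba.njit(cache=True)
def add3(i0, i1, i2):
    return i0 + i1 + i2

# Built from generated source at runtime: the function's __globals__ is a
# fresh dict ("dynamic globals"), so numba cannot cache it; opting out with
# cache=False accepts recompilation in every new process without a warning.
namespace = {}
exec("def addn(i0, i1, i2):\n    return i0 + i1 + i2\n", namespace)
addn = numba.njit(cache=False)(namespace["addn"])

print(add3(1.0, 2.0, 3.0), addn(1.0, 2.0, 3.0))  # 6.0 6.0

This is also why the warning filter in basic.py no longer needs to match "numba_funcified_fgraph": linker.py now passes `cache=False` for the generated top-level function instead of relying on a suppressed warning.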