Hardcode common Op parametrizations to allow numba caching #1341

Open · wants to merge 1 commit into main
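Context for the change: numba only writes a compiled function to its on-disk cache when the function is backed by a real source file and does not capture dynamic globals. The numba backend builds many Op implementations at runtime (source generated and compiled with `compile_function_src`, or closures over node-specific objects), so those functions are recompiled in every new process. The pattern in this PR is to return statically defined functions for the most common Op parametrizations (fixed-arity `Add`/`Mul`, single-dimension `Subtensor` slices, low-dimensional `SpecifyShape` and `Alloc`, left `expand_dims` `DimShuffle`) and keep the generated, uncacheable path only as a fallback. A rough, self-contained sketch of the idea, using an illustrative helper name (`make_add`) rather than PyTensor's actual dispatch code:

```python
# Illustrative sketch only: `make_add` is a made-up helper, not part of
# pytensor.link.numba.dispatch.
import numpy as np
from numba import njit


def make_add(n_inputs):
    """Return a jitted n-ary addition function."""
    if n_inputs == 2:
        # Hardcoded common parametrization: a statically defined function
        # lives in a real source file, so numba can persist the compiled
        # result to its on-disk cache and reuse it in later processes.
        def add(i0, i1):
            return i0 + i1

        return njit(cache=True)(add)

    # General fallback (analogous to generating source and compiling it with
    # compile_function_src): it works, but numba cannot cache a function whose
    # source exists only in memory, so every process pays the compile cost.
    args = ", ".join(f"i{k}" for k in range(n_inputs))
    body = " + ".join(f"i{k}" for k in range(n_inputs))
    src = f"def add({args}):\n    return {body}\n"
    namespace = {}
    exec(src, namespace)
    return njit(cache=False)(namespace["add"])


add2 = make_add(2)  # cacheable across runs
add6 = make_add(6)  # recompiled in every new process
print(add2(1.0, 2.0), add6(*np.arange(6.0)))
```

The same trade-off shows up in the diff below: the remaining generated fallback paths now pass `cache=False` explicitly, while the hardcoded variants are returned through plain `numba_njit` calls.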
35 changes: 33 additions & 2 deletions pytensor/link/numba/dispatch/basic.py
@@ -39,7 +39,7 @@
from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
from pytensor.tensor.slinalg import Solve
from pytensor.tensor.type import TensorType
from pytensor.tensor.type_other import MakeSlice, NoneConst
from pytensor.tensor.type_other import MakeSlice, NoneConst, NoneTypeT


def global_numba_func(func):
@@ -75,7 +75,7 @@ def numba_njit(*args, fastmath=None, **kwargs):
message=(
"(\x1b\\[1m)*" # ansi escape code for bold text
"Cannot cache compiled function "
'"(numba_funcified_fgraph|store_core_outputs|cholesky|solve|solve_triangular|cho_solve)" '
'"(store_core_outputs|cholesky|solve|solve_triangular|cho_solve)" '
"as it uses dynamic globals"
),
category=NumbaWarning,
@@ -477,6 +477,37 @@ def reshape(x, shape):

@numba_funcify.register(SpecifyShape)
def numba_funcify_SpecifyShape(op, node, **kwargs):
x, *shape = node.inputs
ndim = x.type.ndim
specified_dims = tuple(not isinstance(dim.type, NoneTypeT) for dim in shape)
match (ndim, specified_dims):
case (1, (True,)):

def func(x, shape_0):
assert x.shape[0] == shape_0
return x
case (2, (True, False)):

def func(x, shape_0, shape_1):
assert x.shape[0] == shape_0
return x
case (2, (False, True)):

def func(x, shape_0, shape_1):
assert x.shape[1] == shape_1
return x
case (2, (True, True)):

def func(x, shape_0, shape_1):
assert x.shape[0] == shape_0
assert x.shape[1] == shape_1
return x
case _:
func = None

if func is not None:
return numba_njit(func)

shape_inputs = node.inputs[1:]
shape_input_names = ["shape_" + str(i) for i in range(len(shape_inputs))]

10 changes: 9 additions & 1 deletion pytensor/link/numba/dispatch/elemwise.py
@@ -411,7 +411,15 @@ def numba_funcify_CAReduce(op, node, **kwargs):


@numba_funcify.register(DimShuffle)
def numba_funcify_DimShuffle(op, node, **kwargs):
def numba_funcify_DimShuffle(op: DimShuffle, node, **kwargs):
if op.is_left_expand_dims and op.new_order.count("x") == 1:
# Most common case; numba compiles it more quickly
@numba_njit
def left_expand_dims(x):
return np.expand_dims(x, 0)

return left_expand_dims

# We use `as_strided` to achieve the DimShuffle behavior of transposing and expanding/squeezing dimensions in one call
# Numba doesn't currently support multiple expand/squeeze, and reshape is limited to contiguous arrays.
new_order = tuple(op._new_order)
52 changes: 49 additions & 3 deletions pytensor/link/numba/dispatch/scalar.py
@@ -172,18 +172,64 @@ def {binary_op_name}({input_signature}):

@numba_funcify.register(Add)
def numba_funcify_Add(op, node, **kwargs):
match len(node.inputs):
case 2:

def add(i0, i1):
return i0 + i1
case 3:

def add(i0, i1, i2):
return i0 + i1 + i2
case 4:

def add(i0, i1, i2, i3):
return i0 + i1 + i2 + i3
case 5:

def add(i0, i1, i2, i3, i4):
return i0 + i1 + i2 + i3 + i4
case _:
add = None

if add is not None:
return numba_basic.numba_njit(add)

signature = create_numba_signature(node, force_scalar=True)
nary_add_fn = binary_to_nary_func(node.inputs, "add", "+")

return numba_basic.numba_njit(signature)(nary_add_fn)
return numba_basic.numba_njit(signature, cache=False)(nary_add_fn)


@numba_funcify.register(Mul)
def numba_funcify_Mul(op, node, **kwargs):
match len(node.inputs):
case 2:

def mul(i0, i1):
return i0 * i1
case 3:

def mul(i0, i1, i2):
return i0 * i1 * i2
case 4:

def mul(i0, i1, i2, i3):
return i0 * i1 * i2 * i3
case 5:

def mul(i0, i1, i2, i3, i4):
return i0 * i1 * i2 * i3 * i4
case _:
mul = None

if mul is not None:
return numba_basic.numba_njit(mul)

signature = create_numba_signature(node, force_scalar=True)
nary_add_fn = binary_to_nary_func(node.inputs, "mul", "*")

return numba_basic.numba_njit(signature)(nary_add_fn)
return numba_basic.numba_njit(signature, cache=False)(nary_add_fn)


@numba_funcify.register(Cast)
@@ -233,7 +279,7 @@ def numba_funcify_Composite(op, node, **kwargs):

_ = kwargs.pop("storage_map", None)

composite_fn = numba_basic.numba_njit(signature)(
composite_fn = numba_basic.numba_njit(signature, cache=False)(
numba_funcify(op.fgraph, squeeze_output=True, **kwargs)
)
return composite_fn
100 changes: 76 additions & 24 deletions pytensor/link/numba/dispatch/subtensor.py
@@ -23,6 +23,60 @@
def numba_funcify_default_subtensor(op, node, **kwargs):
"""Create a Python function that assembles and uses an index on an array."""

if isinstance(op, Subtensor) and len(op.idx_list) == 1:
# Hard code indices along first dimension to allow caching
[idx] = op.idx_list

if isinstance(idx, slice):
slice_info = (
idx.start is not None,
idx.stop is not None,
idx.step is not None,
)
match slice_info:
case (False, False, False):

def subtensor(x):
return x

case (True, False, False):

def subtensor(x, start):
return x[start:]
case (False, True, False):

def subtensor(x, stop):
return x[:stop]
case (False, False, True):

def subtensor(x, step):
return x[::step]

case (True, True, False):

def subtensor(x, start, stop):
return x[start:stop]
case (True, False, True):

def subtensor(x, start, step):
return x[start::step]
case (False, True, True):

def subtensor(x, stop, step):
return x[:stop:step]

case (True, True, True):

def subtensor(x, start, stop, step):
return x[start:stop:step]

else:

def subtensor(x, i):
return np.asarray(x[i])

return numba_njit(subtensor)

unique_names = unique_name_generator(
["subtensor", "incsubtensor", "z"], suffix_sep="_"
)
@@ -100,7 +154,7 @@ def {function_name}({", ".join(input_names)}):
function_name=function_name,
global_env=globals() | {"np": np},
)
return numba_njit(func, boundscheck=True)
return numba_njit(func, boundscheck=True, cache=False)


@numba_funcify.register(AdvancedSubtensor)
@@ -294,7 +348,9 @@ def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs):
if broadcast:

@numba_njit(boundscheck=True)
def advancedincsubtensor1_inplace(x, val, idxs):
def advanced_incsubtensor1(x, val, idxs):
out = x if inplace else x.copy()

if val.ndim == x.ndim:
core_val = val[0]
elif val.ndim == 0:
@@ -304,24 +360,28 @@ def advancedincsubtensor1_inplace(x, val, idxs):
core_val = val

for idx in idxs:
x[idx] = core_val
return x
out[idx] = core_val
return out

else:

@numba_njit(boundscheck=True)
def advancedincsubtensor1_inplace(x, vals, idxs):
def advanced_incsubtensor1(x, vals, idxs):
out = x if inplace else x.copy()

if not len(idxs) == len(vals):
raise ValueError("The number of indices and values must match.")
# no strict argument because incompatible with numba
for idx, val in zip(idxs, vals): # noqa: B905
x[idx] = val
return x
out[idx] = val
return out
else:
if broadcast:

@numba_njit(boundscheck=True)
def advancedincsubtensor1_inplace(x, val, idxs):
def advanced_incsubtensor1(x, val, idxs):
out = x if inplace else x.copy()

if val.ndim == x.ndim:
core_val = val[0]
elif val.ndim == 0:
@@ -331,29 +391,21 @@ def advancedincsubtensor1_inplace(x, vals, idxs):
core_val = val

for idx in idxs:
x[idx] += core_val
return x
out[idx] += core_val
return out

else:

@numba_njit(boundscheck=True)
def advancedincsubtensor1_inplace(x, vals, idxs):
def advanced_incsubtensor1(x, vals, idxs):
out = x if inplace else x.copy()

if not len(idxs) == len(vals):
raise ValueError("The number of indices and values must match.")
# no strict argument because unsupported by numba
# TODO: this doesn't come up in tests
for idx, val in zip(idxs, vals): # noqa: B905
x[idx] += val
return x

if inplace:
return advancedincsubtensor1_inplace

else:

@numba_njit
def advancedincsubtensor1(x, vals, idxs):
x = x.copy()
return advancedincsubtensor1_inplace(x, vals, idxs)
out[idx] += val
return out

return advancedincsubtensor1
return advanced_incsubtensor1
37 changes: 33 additions & 4 deletions pytensor/link/numba/dispatch/tensor_basic.py
@@ -58,6 +58,36 @@ def allocempty({", ".join(shape_var_names)}):

@numba_funcify.register(Alloc)
def numba_funcify_Alloc(op, node, **kwargs):
x, *shape = node.inputs
if all(x.type.broadcastable):
match len(shape):
case 1:

def alloc(val, dim0):
shape = (dim0.item(),)
res = np.empty(shape, dtype=val.dtype)
res[...] = val
return res
case 2:

def alloc(val, dim0, dim1):
shape = (dim0.item(), dim1.item())
res = np.empty(shape, dtype=val.dtype)
res[...] = val
return res
case 3:

def alloc(val, dim0, dim1, dim2):
shape = (dim0.item(), dim1.item(), dim2.item())
res = np.empty(shape, dtype=val.dtype)
res[...] = val
return res
case _:
alloc = None

if alloc is not None:
return numba_basic.numba_njit(alloc)

global_env = {"np": np, "to_scalar": numba_basic.to_scalar}

unique_names = unique_name_generator(
@@ -68,7 +98,7 @@ def numba_funcify_Alloc(op, node, **kwargs):
shape_var_item_names = [f"{name}_item" for name in shape_var_names]
shapes_to_items_src = indent(
"\n".join(
f"{item_name} = to_scalar({shape_name})"
f"{item_name} = {shape_name}.item()"
for item_name, shape_name in zip(
shape_var_item_names, shape_var_names, strict=True
)
@@ -86,12 +116,11 @@

alloc_def_src = f"""
def alloc(val, {", ".join(shape_var_names)}):
val_np = np.asarray(val)
{shapes_to_items_src}
scalar_shape = {create_tuple_string(shape_var_item_names)}
{check_runtime_broadcast_src}
res = np.empty(scalar_shape, dtype=val_np.dtype)
res[...] = val_np
res = np.empty(scalar_shape, dtype=val.dtype)
res[...] = val
return res
"""
alloc_fn = compile_function_src(alloc_def_src, "alloc", {**globals(), **global_env})