diff --git a/pytensor/compile/profiling.py b/pytensor/compile/profiling.py index a68365527f..0d2e48c114 100644 --- a/pytensor/compile/profiling.py +++ b/pytensor/compile/profiling.py @@ -1480,8 +1480,8 @@ def print_tips(self, file): ps.XOR, ps.AND, ps.Invert, - ps.ScalarMaximum, - ps.ScalarMinimum, + ps.Maximum, + ps.Minimum, ps.Add, ps.Mul, ps.Sub, diff --git a/pytensor/link/jax/dispatch/scalar.py b/pytensor/link/jax/dispatch/scalar.py index d3e5ac11f7..14fd9844c8 100644 --- a/pytensor/link/jax/dispatch/scalar.py +++ b/pytensor/link/jax/dispatch/scalar.py @@ -14,6 +14,8 @@ Composite, Identity, IntDiv, + Maximum, + Minimum, Mod, Mul, ScalarOp, @@ -172,6 +174,22 @@ def elemwise(x, y): return elemwise +@jax_funcify.register(Maximum) +def jax_funcify_scalar_Maximum(op, **kwargs): + def elemwise(*inputs): + return functools.reduce(jnp.maximum, inputs[1:], inputs[0]) + + return elemwise + + +@jax_funcify.register(Minimum) +def jax_funcify_scalar_Minimum(op, **kwargs): + def elemwise(*inputs): + return functools.reduce(jnp.minimum, inputs[1:], inputs[0]) + + return elemwise + + @jax_funcify.register(Cast) def jax_funcify_Cast(op, **kwargs): def cast(x): diff --git a/pytensor/link/numba/dispatch/elemwise.py b/pytensor/link/numba/dispatch/elemwise.py index 9fd81dadcf..d9f9f80154 100644 --- a/pytensor/link/numba/dispatch/elemwise.py +++ b/pytensor/link/numba/dispatch/elemwise.py @@ -26,13 +26,13 @@ XOR, Add, IntDiv, + Maximum, + Minimum, Mul, - ScalarMaximum, - ScalarMinimum, Sub, TrueDiv, get_scalar_type, - scalar_maximum, + maximum, ) from pytensor.scalar.basic import add as add_as from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise @@ -103,16 +103,16 @@ def scalar_in_place_fn_IntDiv(op, idx, res, arr): return f"{res}[{idx}] //= {arr}" -@scalar_in_place_fn.register(ScalarMaximum) -def scalar_in_place_fn_ScalarMaximum(op, idx, res, arr): +@scalar_in_place_fn.register(Maximum) +def scalar_in_place_fn_Maximum(op, idx, res, arr): return f""" if {res}[{idx}] < {arr}: {res}[{idx}] = {arr} """ -@scalar_in_place_fn.register(ScalarMinimum) -def scalar_in_place_fn_ScalarMinimum(op, idx, res, arr): +@scalar_in_place_fn.register(Minimum) +def scalar_in_place_fn_Minimum(op, idx, res, arr): return f""" if {res}[{idx}] > {arr}: {res}[{idx}] = {arr} @@ -458,7 +458,7 @@ def numba_funcify_Softmax(op, node, **kwargs): if axis is not None: axis = normalize_axis_index(axis, x_at.ndim) reduce_max_py = create_multiaxis_reducer( - scalar_maximum, -np.inf, axis, x_at.ndim, x_dtype, keepdims=True + maximum, -np.inf, axis, x_at.ndim, x_dtype, keepdims=True ) reduce_sum_py = create_multiaxis_reducer( add_as, 0.0, (axis,), x_at.ndim, x_dtype, keepdims=True @@ -522,7 +522,7 @@ def numba_funcify_LogSoftmax(op, node, **kwargs): if axis is not None: axis = normalize_axis_index(axis, x_at.ndim) reduce_max_py = create_multiaxis_reducer( - scalar_maximum, + maximum, -np.inf, (axis,), x_at.ndim, diff --git a/pytensor/link/numba/dispatch/scalar.py b/pytensor/link/numba/dispatch/scalar.py index e9b637b00f..5af7343a27 100644 --- a/pytensor/link/numba/dispatch/scalar.py +++ b/pytensor/link/numba/dispatch/scalar.py @@ -9,6 +9,7 @@ create_numba_signature, generate_fallback_impl, numba_funcify, + numba_njit, ) from pytensor.link.numba.dispatch.cython_support import wrap_cython_function from pytensor.link.utils import ( @@ -16,12 +17,15 @@ get_name_for_object, unique_name_generator, ) +from pytensor.scalar import discrete_dtypes from pytensor.scalar.basic import ( Add, Cast, Clip, Composite, Identity, + Maximum, + Minimum, 
Mul, Reciprocal, ScalarOp, @@ -186,6 +190,37 @@ def numba_funcify_Mul(op, node, **kwargs): return numba_basic.numba_njit(signature)(nary_add_fn) +@numba_funcify.register(Maximum) +@numba_funcify.register(Minimum) +def numba_funcify_Extremum(op, node, **kwargs): + input_names = [f"x{i}" for i in range(len(node.inputs))] + input_signature = ", ".join(input_names) + assert len(input_names) > 0 + + inner_code = f"res = {input_names[0]}\n" + + if isinstance(op, Maximum): + op = ">" + func_name = "maximum" + else: + op = "<" + func_name = "minimum" + + if all(inp.dtype in discrete_dtypes for inp in node.inputs): + for x in input_names[1:]: + inner_code += f" res = {x} if {x} {op} res else res\n" + else: + for x in input_names[1:]: + inner_code += f" res = {x} if {x} {op} res else (res if res {op}= {x} else np.nan)\n" + inner_code += " return res" + + src = f""" +def {func_name}({input_signature}): + {inner_code} +""" + return numba_njit(compile_function_src(src, func_name, globals() | {"np": np})) + + @numba_funcify.register(Cast) def numba_funcify_Cast(op, node, **kwargs): dtype = np.dtype(op.o_type.dtype) diff --git a/pytensor/scalar/basic.py b/pytensor/scalar/basic.py index 26b551875c..6e7c6c2ffb 100644 --- a/pytensor/scalar/basic.py +++ b/pytensor/scalar/basic.py @@ -14,6 +14,7 @@ import math from collections.abc import Callable from copy import copy +from functools import reduce from itertools import chain from textwrap import dedent from typing import Any, TypeAlias @@ -1868,89 +1869,119 @@ def c_code(self, node, name, inputs, outputs, sub): ############## # Arithmetic ############## -class ScalarMaximum(BinaryScalarOp): +class AtLeastUnaryScalarOp(ScalarOp): + def make_node(self, *inputs): + if len(inputs) == 0: + raise TypeError(f"{self} requires at least 1 input: got 0") + return super().make_node(*inputs) + + +class Maximum(AtLeastUnaryScalarOp): commutative = True associative = True - nfunc_spec = ("maximum", 2, 1) - nfunc_variadic = "maximum" + nfunc_variadic = "max" identity = -np.inf def impl(self, *inputs): # The built-in max function don't support complex type - return np.maximum(*inputs) + return reduce(np.maximum, inputs) def c_code(self, node, name, inputs, outputs, sub): - (x, y) = inputs - (z,) = outputs if any(i.type in complex_types for i in node.inputs): raise NotImplementedError() - # Test for both y>x and x>=y to detect NaN - return f'{z} = (({y})>({x})? ({y}): (({x})>=({y})? ({x}): nan("")));' + + x, *ys = inputs + [z] = outputs + + # We need an intermediate variable in case we are working inplace + tmp = f"{z}_tmp" + res = f"{node.outputs[0].type.dtype_specs()[1]} {tmp} = ({x});" + if all(i.dtype in discrete_dtypes for i in node.inputs): + for y in ys: + res += f"\n{tmp} = (({y}) > {tmp})? ({y}): {tmp};" + else: + # Need to check for nans + for y in ys: + res += ( + f"\n{tmp} = (({y}) > {tmp})? ({y}): (({tmp} >= ({y}))? {tmp}: NAN);" + ) + res += f"\n{z} = {tmp};" + return res + + def c_code_cache_version(self): + return (2,) def L_op(self, inputs, outputs, gout): - (x, y) = inputs - (gz,) = gout + [gz] = gout if gz.type in complex_types: # max is currently defined for complex_types, # but the gradient for complex is not. raise NotImplementedError() - if outputs[0].type in discrete_types: - return [ - x.zeros_like(dtype=config.floatX), - y.zeros_like(dtype=config.floatX), - ] - # This form handle the case when both value are the same. - # In that case, gx will be gz, gy will be 0. 
-        e = eq(outputs[0], x)
-        gx = e * gz
-        gy = (constant(1, dtype=gz.dtype) - e) * gz
-        return (gx, gy)
+        [out] = outputs
+
+        if out.type in discrete_types:
+            return [inp.zeros_like(dtype=config.floatX) for inp in inputs]
+        # We propagate the gradient to the maximum value(s) in the input
+        return [eq(inp, out) * gz for inp in inputs]
-scalar_maximum = ScalarMaximum(upcast_out, name="maximum")
+maximum = Maximum(upcast_out, name="maximum")
-class ScalarMinimum(BinaryScalarOp):
+
+class Minimum(AtLeastUnaryScalarOp):
     commutative = True
     associative = True
-    nfunc_spec = ("minimum", 2, 1)
-    nfunc_variadic = "minimum"
+    nfunc_variadic = "min"
     identity = np.inf
     def impl(self, *inputs):
         # The built-in min function don't support complex type
-        return np.minimum(*inputs)
+        return reduce(np.minimum, inputs)
     def c_code(self, node, name, inputs, outputs, sub):
-        (x, y) = inputs
-        (z,) = outputs
         if any(i.type in complex_types for i in node.inputs):
             raise NotImplementedError()
-        return f'{z} = (({y})<({x})? ({y}): (({x})<=({y})? ({x}): nan("")));'
+
+        x, *ys = inputs
+        [z] = outputs
+
+        # We need an intermediate variable in case we are working inplace
+        tmp = f"{z}_tmp"
+        res = f"{node.outputs[0].type.dtype_specs()[1]} {tmp} = ({x});"
+        if all(i.dtype in discrete_dtypes for i in node.inputs):
+            for y in ys:
+                res += f"\n{tmp} = (({y}) < {tmp})? ({y}): {tmp};"
+        else:
+            # Need to check for nans
+            for y in ys:
+                res += (
+                    f"\n{tmp} = (({y}) < {tmp})? ({y}): (({tmp} <= ({y}))? {tmp}: NAN);"
+                )
+        res += f"\n{z} = {tmp};"
+        return res
+
+    def c_code_cache_version(self):
+        return (2,)
     def L_op(self, inputs, outputs, gout):
-        (x, y) = inputs
-        (gz,) = gout
+        [gz] = gout
         if gz.type in complex_types:
             # min is currently defined for complex_types,
             # but the gradient for complex is not.
             raise NotImplementedError()
-        if outputs[0].type in discrete_types:
-            return [
-                x.zeros_like(dtype=config.floatX),
-                y.zeros_like(dtype=config.floatX),
-            ]
-        # This form handle the case when both value are the same.
-        # In that case, gx will be gz, gy will be 0.
- e = eq(outputs[0], x) - gx = e * gz - gy = (constant(1, dtype=gz.dtype) - e) * gz - return (gx, gy) + [out] = outputs + + if out.type in discrete_types: + return [inp.zeros_like(dtype=config.floatX) for inp in inputs] + + # We propagate the gradient to the minimum value(s) in the input + return [eq(inp, out) * gz for inp in inputs] -scalar_minimum = ScalarMinimum(upcast_out, name="minimum") +minimum = Minimum(upcast_out, name="minimum") class Add(ScalarOp): diff --git a/pytensor/scalar/math.py b/pytensor/scalar/math.py index 86029e626f..e67fbf82bc 100644 --- a/pytensor/scalar/math.py +++ b/pytensor/scalar/math.py @@ -32,8 +32,8 @@ isinf, log, log1p, + maximum, reciprocal, - scalar_maximum, sqrt, switch, true_div, @@ -1305,7 +1305,7 @@ def c_code_cache_version(self): return v -softplus = Softplus(upgrade_to_float, name="scalar_softplus") +softplus = Softplus(upgrade_to_float, name="softplus") class Log1mexp(UnaryScalarOp): @@ -1575,9 +1575,7 @@ def inner_loop( derivative_new = K * (F1 * dK + F2) errapx = scalar_abs(derivative - derivative_new) - d_errapx = errapx / scalar_maximum( - err_threshold, scalar_abs(derivative_new) - ) + d_errapx = errapx / maximum(err_threshold, scalar_abs(derivative_new)) min_iters_cond = n > (min_iters - 1) derivative = switch( @@ -1823,7 +1821,7 @@ def inner_loop(*args): if len(grad_incs) == 1: [max_abs_grad_inc] = grad_incs else: - max_abs_grad_inc = reduce(scalar_maximum, abs_grad_incs) + max_abs_grad_inc = reduce(maximum, abs_grad_incs) return ( (*grads, *log_gs, *log_gs_signs, log_t, log_t_sign, sign_zk, k), diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py index 6bcb084f4e..5cade239c7 100644 --- a/pytensor/tensor/basic.py +++ b/pytensor/tensor/basic.py @@ -262,8 +262,8 @@ def _obj_is_wrappable_as_tensor(x): ps.Mul, ps.IntDiv, ps.TrueDiv, - ps.ScalarMinimum, - ps.ScalarMaximum, + ps.Minimum, + ps.Maximum, ) diff --git a/pytensor/tensor/blas.py b/pytensor/tensor/blas.py index 3124428016..8c1e5ceb61 100644 --- a/pytensor/tensor/blas.py +++ b/pytensor/tensor/blas.py @@ -947,8 +947,8 @@ def infer_shape(self, fgraph, node, input_shapes): z_shape, _, x_shape, y_shape, _ = input_shapes return [ ( - pytensor.scalar.scalar_maximum(z_shape[0], x_shape[0]), - pytensor.scalar.scalar_maximum(z_shape[1], y_shape[1]), + pytensor.scalar.maximum(z_shape[0], x_shape[0]), + pytensor.scalar.maximum(z_shape[1], y_shape[1]), ) ] diff --git a/pytensor/tensor/inplace.py b/pytensor/tensor/inplace.py index cb4476ede0..8c0df0e2e0 100644 --- a/pytensor/tensor/inplace.py +++ b/pytensor/tensor/inplace.py @@ -357,12 +357,12 @@ def second_inplace(a): pprint.assign(fill_inplace, printing.FunctionPrinter(["fill="])) -@scalar_elemwise(symbolname="scalar_maximum_inplace") +@scalar_elemwise def maximum_inplace(a, b): """elementwise addition (inplace on `a`)""" -@scalar_elemwise(symbolname="scalar_minimum_inplace") +@scalar_elemwise def minimum_inplace(a, b): """elementwise addition (inplace on `a`)""" diff --git a/pytensor/tensor/math.py b/pytensor/tensor/math.py index 714f597b32..30a99c3208 100644 --- a/pytensor/tensor/math.py +++ b/pytensor/tensor/math.py @@ -406,7 +406,7 @@ class Max(NonZeroDimsCAReduce): nfunc_spec = ("max", 1, 1) def __init__(self, axis): - super().__init__(ps.scalar_maximum, axis) + super().__init__(ps.maximum, axis) def clone(self, **kwargs): axis = kwargs.get("axis", self.axis) @@ -464,7 +464,7 @@ class Min(NonZeroDimsCAReduce): nfunc_spec = ("min", 1, 1) def __init__(self, axis): - super().__init__(ps.scalar_minimum, axis) + 
super().__init__(ps.minimum, axis) def clone(self, **kwargs): axis = kwargs.get("axis", self.axis) @@ -2762,7 +2762,7 @@ def median(x: TensorLike, axis=None) -> TensorVariable: return ifelse(even_k, even_median, odd_median, name="median") -@scalar_elemwise(symbolname="scalar_maximum") +@scalar_elemwise def maximum(x, y): """elemwise maximum. See max for the maximum in one tensor @@ -2798,7 +2798,7 @@ def maximum(x, y): # see decorator for function body -@scalar_elemwise(symbolname="scalar_minimum") +@scalar_elemwise def minimum(x, y): """elemwise minimum. See min for the minimum in one tensor diff --git a/pytensor/tensor/rewriting/__init__.py b/pytensor/tensor/rewriting/__init__.py index 4e75140ceb..d5d9d772dd 100644 --- a/pytensor/tensor/rewriting/__init__.py +++ b/pytensor/tensor/rewriting/__init__.py @@ -6,6 +6,7 @@ import pytensor.tensor.rewriting.einsum import pytensor.tensor.rewriting.elemwise import pytensor.tensor.rewriting.extra_ops +import pytensor.tensor.rewriting.extremum import pytensor.tensor.rewriting.jax import pytensor.tensor.rewriting.linalg import pytensor.tensor.rewriting.math diff --git a/pytensor/tensor/rewriting/basic.py b/pytensor/tensor/rewriting/basic.py index 59148fae3b..99b3ed7cca 100644 --- a/pytensor/tensor/rewriting/basic.py +++ b/pytensor/tensor/rewriting/basic.py @@ -1056,11 +1056,8 @@ def local_merge_switch_same_cond(fgraph, node): condition, to enable further simplification of their branches Example: switch(c, a, b) + switch(c, x, y) -> switch(c, a+x, b+y) """ - # node must be binary elemwise or add or mul - if not ( - isinstance(node.op, Elemwise) - and isinstance(node.op.scalar_op, ps.BinaryScalarOp | ps.Add | ps.Mul) - ): + # node must be binary elemwise with at least 2 inputs + if len(node.inputs) < 2: return # all inputs must be switch if not all( diff --git a/pytensor/tensor/rewriting/elemwise.py b/pytensor/tensor/rewriting/elemwise.py index eaba64c275..f9d956ce6a 100644 --- a/pytensor/tensor/rewriting/elemwise.py +++ b/pytensor/tensor/rewriting/elemwise.py @@ -493,65 +493,65 @@ def local_upcast_elemwise_constant_inputs(fgraph, node): """ if len(node.outputs) > 1: return - try: - shape_i = fgraph.shape_feature.shape_i - except AttributeError: - shape_i = None - if isinstance(node.op, Elemwise): - scalar_op = node.op.scalar_op - # print "aa", scalar_op.output_types_preference - if getattr(scalar_op, "output_types_preference", None) in ( - ps.upgrade_to_float, - ps.upcast_out, - ): - # this is the kind of op that we can screw with the input - # dtypes by upcasting explicitly - output_dtype = node.outputs[0].type.dtype - new_inputs = [] - for i in node.inputs: - if i.type.dtype == output_dtype: - new_inputs.append(i) - else: - try: - cval_i = get_underlying_scalar_constant_value( - i, only_process_constants=True + + if all(isinstance(i, Constant) for i in node.inputs): + # If all inputs are constant, constant_fold will take care of it + return + + if getattr(node.op.scalar_op, "output_types_preference", None) in ( + ps.upgrade_to_float, + ps.upcast_out, + ): + # this is the kind of op that we can screw with the input + # dtypes by upcasting explicitly + output_dtype = node.outputs[0].type.dtype + new_inputs = [] + for i in node.inputs: + if i.type.dtype == output_dtype: + new_inputs.append(i) + else: + try: + cval_i = get_underlying_scalar_constant_value( + i, only_process_constants=True + ) + if all(i.broadcastable): + new_inputs.append( + shape_padleft(cast(cval_i, output_dtype), i.ndim) ) - if all(i.broadcastable): - new_inputs.append( - 
shape_padleft(cast(cval_i, output_dtype), i.ndim) - ) - else: - if shape_i is None: - return - new_inputs.append( - alloc( - cast(cval_i, output_dtype), - *[shape_i(d)(i) for d in range(i.ndim)], - ) + else: + try: + shape_i = fgraph.shape_feature.shape_i + except AttributeError: + return + new_inputs.append( + alloc( + cast(cval_i, output_dtype), + *[shape_i(d)(i) for d in range(i.ndim)], ) - # print >> sys.stderr, "AAA", - # *[Shape_i(d)(i) for d in range(i.ndim)] - except NotScalarConstantError: - # for the case of a non-scalar - if isinstance(i, TensorConstant): - new_inputs.append(cast(i, output_dtype)) - else: - new_inputs.append(i) + ) + # print >> sys.stderr, "AAA", + # *[Shape_i(d)(i) for d in range(i.ndim)] + except NotScalarConstantError: + # for the case of a non-scalar + if isinstance(i, TensorConstant): + new_inputs.append(cast(i, output_dtype)) + else: + new_inputs.append(i) - if new_inputs != node.inputs: - rval = [node.op(*new_inputs)] - if not node.outputs[0].type.is_super(rval[0].type): - # This can happen for example when floatX=float32 - # and we do the true division between and int64 - # and a constant that will get typed as int8. + if new_inputs != node.inputs: + rval = [node.op(*new_inputs)] + if not node.outputs[0].type.is_super(rval[0].type): + # This can happen for example when floatX=float32 + # and we do the true division between and int64 + # and a constant that will get typed as int8. - # As this is just to allow merging more case, if - # the upcast don't work, we can just skip it. - return + # As this is just to allow merging more case, if + # the upcast don't work, we can just skip it. + return - # Copy over output stacktrace from before upcasting - copy_stack_trace(node.outputs[0], rval) - return rval + # Copy over output stacktrace from before upcasting + copy_stack_trace(node.outputs[0], rval) + return rval @node_rewriter([Elemwise]) diff --git a/pytensor/tensor/rewriting/extremum.py b/pytensor/tensor/rewriting/extremum.py new file mode 100644 index 0000000000..e906093106 --- /dev/null +++ b/pytensor/tensor/rewriting/extremum.py @@ -0,0 +1,555 @@ +import operator +from collections import deque + +import numpy as np + +from pytensor.compile import optdb +from pytensor.graph import Constant, node_rewriter +from pytensor.graph.rewriting.basic import ( + copy_stack_trace, + out2in, +) +from pytensor.scalar import ( + GE, + GT, + LE, + LT, + Abs, + Add, + Cast, + Exp, + Log, + Log1p, + Maximum, + Minimum, + Sqr, + Sub, + discrete_dtypes, +) +from pytensor.tensor.basic import atleast_Nd +from pytensor.tensor.elemwise import Elemwise +from pytensor.tensor.extra_ops import broadcast_arrays +from pytensor.tensor.math import add, maximum, minimum, switch, variadic_add +from pytensor.tensor.rewriting.basic import register_canonicalize +from pytensor.tensor.shape import Shape_i +from pytensor.tensor.type import uint_dtypes +from pytensor.tensor.utils import import_func_from_string +from pytensor.tensor.variable import TensorConstant + + +EXTREMUM_OPS = Minimum | Maximum +DIRECTIONAL_COMPARISON_OPS = GE | GT | LE | LT + + +@register_canonicalize +@node_rewriter([switch]) +def local_switch_to_extremum(fgraph, node): + """Rewrite switch(x >= y, x, y) -> maximum(x, y).""" + [out] = node.outputs + + if not all(out.type.broadcastable): + # Only do this for scalar graphs + return None + + if out.dtype not in discrete_dtypes: + # Switch ignores `nan` values so it is not equivalent to maximum in that case + return None + + cond, x, y = node.inputs + cond_node = cond.owner + 
if not ( + cond_node is not None + and isinstance(cond_node.op, Elemwise) + and isinstance(cond_node.op.scalar_op, DIRECTIONAL_COMPARISON_OPS) + ): + return None + + cond_x, cond_y = cond_node.inputs + logical_op = cond_node.op.scalar_op + if cond_x is x and cond_y is y: + if isinstance(logical_op, GT | GE): + return [maximum(x, y)] + else: + return [minimum(x, y)] + elif cond_x is y and cond_y is x: + # Flipped meaning + if isinstance(logical_op, GT | GE): + return [minimum(x, y)] + else: + return [maximum(x, y)] + + +@register_canonicalize +@node_rewriter([add]) +def local_extremum_plus_x(fgraph, node): + """Rewrite maximum(y, z) + x -> maximum(y+x, z+x). + + Only do this for scalar graphs and when x is a root variable + """ + if not all(node.out.type.broadcastable): + return None + + minmax_terms = [ + t + for t in node.inputs + if t.owner + and isinstance(t.owner.op, Elemwise) + and isinstance(t.owner.op.scalar_op, EXTREMUM_OPS) + ] + if len(minmax_terms) != 1: + return None + [minmax_term] = minmax_terms + other_terms = [t for t in node.inputs if t is not minmax_term] + if not all(t.owner is None for t in other_terms): + # Keep it to simple additions + return None + c = variadic_add(*other_terms) + + if isinstance(c, Constant) and c.unique_value == 0: + # Eager optimization if c is zero, to reduce number of passes + return [minmax_term] + + # To reduce passes we do c + t, as c is likely to be a constant and this is where the add_canonizer would put them next. + return [minmax_term.owner.op(*[c + t for t in minmax_term.owner.inputs])] + + +@register_canonicalize +@node_rewriter([minimum, maximum]) +def local_flatten_extremum(fgraph, node): + """Rewrite maximum(maximum(x, y), ..., maximum(w, z)) -> maximum(x, y, ..., w, z). + + This makes it easier to remove useless branches that don't seem to talk to each other. + + Also remove duplicated variables or multiple constants. + + Restricted to scalar graphs only. 
+ """ + if not all(node.out.type.broadcastable): + return None + + scalar_op = node.op.scalar_op + inputs = node.inputs + + # Quick exit circuit + if not ( + # Repeated inputs + len(inputs) != len(set(inputs)) + # There's a nested Op that is the same as the outer one + or any( + inp.owner is not None + and isinstance(inp.owner.op, Elemwise) + and inp.owner.op.scalar_op == scalar_op + for inp in inputs + ) + # There are multiple constants + or sum(isinstance(inp, Constant) for inp in inputs) > 1 + ): + return None + + old_inputs = deque(inputs) + new_inputs = [] + new_inputs_set = set() # For faster comparison, but we don't want random ordering + is_maximum = isinstance(scalar_op, Maximum) + extremum_const = None + while old_inputs: + old_inp = old_inputs.popleft() + if old_inp in new_inputs_set: + # duplicate inputs + continue + + if ( + old_inp.owner + and isinstance(old_inp.owner.op, Elemwise) + and old_inp.owner.op.scalar_op == scalar_op + ): + # Add to the queue to be flatten out + old_inputs.extend(old_inp.owner.inputs) + continue + + if isinstance(old_inp, Constant): + if extremum_const is None: + extremum_const = old_inp + else: + # Either discard this constant or the previous one + # TODO: We could apply this logic to non-scalars as well + data = old_inp.data.item() + extremum_data = extremum_const.data.item() + if (is_maximum and data <= extremum_data) or ( + not is_maximum and data >= extremum_data + ): + continue # discard this constant + + new_inputs.remove(extremum_const) + new_inputs_set.remove(extremum_const) + extremum_const = old_inp + + new_inputs.append(old_inp) + new_inputs_set.add(old_inp) + + if len(new_inputs) > 1: + new_out = node.op(*new_inputs) + copy_stack_trace(new_inputs, new_out) + else: + [new_out] = new_inputs + + # Removed constants may have broadcast or upcast the output + if new_out.dtype != node.out.type.dtype: + new_out = new_out.astype(node.out.type.dtype) + if new_out.ndim != node.out.type.ndim: + new_out = atleast_Nd(new_out, node.out.type.ndim) + return [new_out] + + +@register_canonicalize +@node_rewriter([maximum, minimum]) +def local_useless_extremum_x_plus_offset(fgraph, node): + """Rewrite maximum(x, x + 1) -> x + 1.""" + variables, constants = [], [] + for inp in node.inputs: + if ( + inp.owner is not None + and isinstance(inp.owner.op, Elemwise) + and isinstance(inp.owner.op.scalar_op, Add) + ): + if len(inp.owner.inputs) > 2: + # Addition with too many terms for us to reason about + return + x, y = inp.owner.inputs + if isinstance(x, TensorConstant) and x.unique_value is not None: + variables.append(y) + constants.append(x.unique_value) + elif isinstance(y, TensorConstant) and y.unique_value is not None: + variables.append(x) + constants.append(y.unique_value) + else: + return None + else: + variables.append(inp) + constants.append(0) + + if len(set(variables)) != 1: + # TODO: Implement logic for multiple subsets of variables + return None + + # Find the branch with the highest constant + if node.op == maximum: + new_out = node.inputs[np.argmax(constants)] + else: + new_out = node.inputs[np.argmin(constants)] + + # Removed branch may have broadcast or upcast the output + if new_out.dtype != node.out.type.dtype: + new_out = new_out.astype(node.out.type.dtype) + if new_out.type.broadcastable != node.out.type.broadcastable: + new_out = broadcast_arrays(new_out, *node.inputs)[0] + return [new_out] + + +def _estimate_upper_bound(var, atleast=None) -> float: + if atleast is not None and getattr(var.tag, "upper_bound", np.inf) <= atleast: + # We 
already proved an upper bound as low as atleast
+        return atleast  # type: ignore
+
+    ub = np.inf
+
+    if var.owner is None:
+        if isinstance(var, Constant):
+            ub = var.data.item()
+        else:
+            if var.dtype == "bool":
+                ub = 1
+
+    elif isinstance(var.owner.op, Elemwise):
+        scalar_op = var.owner.op.scalar_op
+
+        if isinstance(scalar_op, Minimum):
+            for min_var in var.owner.inputs:
+                ub = min(ub, _estimate_upper_bound(min_var, atleast=atleast))
+                if ub == atleast:
+                    break  # This is enough for us
+
+        elif isinstance(scalar_op, Maximum):
+            ub = -np.inf
+            for max_var in var.owner.inputs:
+                ub = max(ub, _estimate_upper_bound(max_var))
+                if ub == np.inf:
+                    break  # Don't bother with other inputs
+
+        elif isinstance(scalar_op, Add):
+            ub = 0
+            for inp in var.owner.inputs:
+                ub += _estimate_upper_bound(inp)
+                if ub == np.inf:
+                    # Don't bother with other inputs
+                    break
+
+        elif isinstance(scalar_op, Sub):
+            left, right = var.owner.inputs
+            ub = _estimate_upper_bound(left)
+            if ub != np.inf:
+                ub -= _estimate_lower_bound(right)
+
+        elif isinstance(scalar_op, Cast):
+            # Trivial case
+            if var.type.dtype == "bool":
+                ub = 1
+
+            if atleast is None or ub > atleast:
+                # We are not satisfied with the trivial upper bound of 1
+                [bef_cast] = var.owner.inputs
+                bef_ub = _estimate_upper_bound(bef_cast, atleast=atleast)
+                if bef_ub != np.inf:
+                    # If we actually got a bound, we can cast it
+                    bef_ub = np.array(bef_ub).astype(var.dtype).item()
+                    ub = min(ub, bef_ub)
+
+    var.tag.upper_bound = ub
+    return ub
+
+
+def _estimate_lower_bound(var, atleast=None) -> float:
+    if atleast is not None and getattr(var.tag, "lower_bound", -np.inf) >= atleast:
+        # We already proved a lower bound as high as atleast
+        return atleast  # type: ignore
+
+    lb = -np.inf
+
+    if var.owner is None:
+        if isinstance(var, Constant):
+            lb = var.data.item()
+        else:
+            # We can't reason about the lower bound of a root variable beyond what its dtype implies
+            if var.dtype == "bool":
+                lb = 0
+            elif var.dtype in uint_dtypes:
+                lb = 0
+
+    elif isinstance(var.owner.op, Shape_i):
+        lb = 0
+
+    elif isinstance(var.owner.op, Elemwise):
+        scalar_op = var.owner.op.scalar_op
+
+        if isinstance(scalar_op, Minimum):
+            lb = np.inf
+            for min_var in var.owner.inputs:
+                lb = min(lb, _estimate_lower_bound(min_var, atleast=atleast))
+                if lb == -np.inf:
+                    # Don't bother with other inputs
+                    break
+
+        elif isinstance(scalar_op, Maximum):
+            for max_var in var.owner.inputs:
+                lb = max(lb, _estimate_lower_bound(max_var))
+                if lb == atleast:
+                    break  # This is enough for us
+
+        elif isinstance(scalar_op, Add):
+            lb = 0
+            for inp in var.owner.inputs:
+                lb += _estimate_lower_bound(inp)
+                if lb == -np.inf:
+                    # Don't bother with other inputs
+                    break
+
+        elif isinstance(scalar_op, Sub):
+            left, right = var.owner.inputs
+            lb = _estimate_lower_bound(left)
+            if lb != -np.inf:
+                lb -= _estimate_upper_bound(right)
+
+        elif isinstance(scalar_op, Abs):
+            lb = 0  # Guaranteed by abs
+
+            # e.g. with atleast=3:
+            # lb=(-5, inf) -> lb(abs)=(0, inf) -> not enough
+            # lb=(3, inf) -> lb(abs)=(0, 5) -> not enough
+            # up=(-inf, -3) -> lb(abs) = (3, inf) -> maybe enough
+
+            if atleast is None or lb < atleast:
+                # We are not satisfied with the trivial lower bound of 0
+                [abs_var] = var.owner.inputs
+                lb = max(lb, _estimate_lower_bound(abs_var, atleast=atleast))
+                if atleast is None or lb < atleast:
+                    # If we are still not satisfied, we can try to estimate the upper bound
+                    # if the upper bound is smaller than the negative of the requested value we're good
+                    ub_negative = _estimate_upper_bound(
+                        abs_var, atleast=-atleast if
atleast is not None else None
+                    )
+                    if ub_negative < -lb:
+                        # We learned something more precise
+                        assert ub_negative < 0
+                        lb = abs(ub_negative)
+
+        elif isinstance(scalar_op, Exp | Sqr | Log | Log1p):
+            # Monotonic functions
+            if atleast is not None:
+                if isinstance(scalar_op, Exp):
+                    atleast = np.log(atleast)
+                elif isinstance(scalar_op, Sqr):
+                    atleast = np.sqrt(atleast)
+                elif isinstance(scalar_op, Log):
+                    atleast = np.exp(atleast)
+                elif isinstance(scalar_op, Log1p):
+                    atleast = np.expm1(atleast)
+
+            np_func = import_func_from_string(scalar_op.nfunc_spec[0])
+            lb = np_func(_estimate_lower_bound(var.owner.inputs[0], atleast))
+
+        elif isinstance(scalar_op, Cast):
+            # Some trivial cases for casts that round to zero
+            if var.type.dtype == "bool" or var.type.dtype in uint_dtypes:
+                lb = 0
+
+            if atleast is None or lb < atleast:
+                # We are not satisfied with the trivial lower bound of 0
+                [bef_cast] = var.owner.inputs
+                bef_lb = _estimate_lower_bound(bef_cast, atleast=atleast)
+                if bef_lb != -np.inf:
+                    # If we actually got a bound, we can cast it
+                    bef_lb = np.array(bef_lb).astype(var.dtype).item()
+                    lb = max(lb, bef_lb)
+
+    var.tag.lower_bound = lb
+    return lb
+
+
+# registered as a graph rewrite below to avoid too many calls
+@node_rewriter([switch])
+def local_useless_switch_branches(fgraph, node):
+    if node.out.dtype not in discrete_dtypes:
+        return None
+
+    cond, true_branch, false_branch = node.inputs
+    if not (
+        cond.owner is not None
+        and isinstance(cond.owner.op, Elemwise)
+        and isinstance(cond.owner.op.scalar_op, DIRECTIONAL_COMPARISON_OPS)
+    ):
+        return None
+
+    left, right = cond.owner.inputs
+
+    scalar_op = cond.owner.op.scalar_op
+    if isinstance(scalar_op, LE):
+        # Same as GE, but with left and right swapped
+        scalar_op = GE
+        left, right = right, left
+    elif isinstance(scalar_op, LT):
+        # Same as GT, but with left and right swapped
+        scalar_op = GT
+        left, right = right, left
+
+    if isinstance(scalar_op, GE):
+        # left >= right is useless when lower bound of left >= upper bound of right
+        # (5, inf) >= (-inf, 5) is always True
+        left_lb = _estimate_lower_bound(left)
+        if left_lb != -np.inf and left_lb >= _estimate_upper_bound(
+            right, atleast=left_lb
+        ):
+            return [true_branch]
+        # or upper bound of left < lower bound of right
+        # (-inf, 5) >= (5+eps, inf) is always false
+        left_ub = _estimate_upper_bound(left)
+        if left_ub != np.inf and left_ub < _estimate_lower_bound(
+            right, atleast=left_ub + 1e-5
+        ):
+            return [false_branch]
+
+    elif isinstance(scalar_op, GT):
+        # left > right is useless when lower bound of left > upper bound of right
+        # (5, inf) > (-inf, 5-eps) is always True
+        left_lb = _estimate_lower_bound(left)
+        if left_lb != -np.inf and left_lb > _estimate_upper_bound(
+            right, atleast=left_lb - 1e-5
+        ):
+            return [true_branch]
+        # or upper bound of left <= lower bound of right
+        # (-inf, 5) > (5, inf) is always false
+        left_ub = _estimate_upper_bound(left)
+        if left_ub != np.inf and left_ub <= _estimate_lower_bound(
+            right, atleast=left_ub
+        ):
+            return [false_branch]
+
+
+# registered as a graph rewrite below to avoid too many calls
+@node_rewriter([minimum, maximum])
+def local_useless_extremum_branches(fgraph, node):
+    """Rewrite useless branches in a maximum/minimum based on lower-upper bound reasoning.
+
+    In maximum(x, y, z), any input whose upper bound is <= another input's lower bound can be discarded (and conversely for minimum).
+
+    Example:
+        maximum(0, shape(x), y) -> maximum(shape(x), y) since shape(x) is already lower bounded by zero
+        maximum(2, minimum(x, 1)) -> 2, since minimum(x, 1) is already upper bounded by 1
+        maximum(1, minimum(x, 1 - shape(y))) -> 1
+    """
+    [old_out] = node.outputs
+    if not all(old_out.type.broadcastable):
+        return None
+
+    if isinstance(node.op.scalar_op, Minimum):
+        informative_bound = _estimate_lower_bound
+        uninformative_bound_value = -np.inf
+        reverse = True
+        reverse_bound = _estimate_upper_bound
+        logical_comp = operator.le
+    else:
+        informative_bound = _estimate_upper_bound
+        uninformative_bound_value = np.inf
+        reverse = False
+        reverse_bound = _estimate_lower_bound
+        logical_comp = operator.ge
+
+    inputs, bounds = zip(
+        *sorted(
+            ((inp, informative_bound(inp)) for inp in node.inputs),
+            key=lambda x: x[1],
+            reverse=reverse,
+        ),
+        strict=False,  # useless
+    )
+
+    while len(bounds) > 1 and bounds[0] != uninformative_bound_value:
+        most_restricted_bound = bounds[0]
+
+        # If any other branch has a lower bound >= this upper bound (reversed for minimum),
+        # the most restricted branch can be discarded
+        for other_inp in inputs[1:]:
+            if logical_comp(
+                reverse_bound(other_inp, atleast=most_restricted_bound),
+                most_restricted_bound,
+            ):
+                # We can remove the restricted bound input
+                inputs = inputs[1:]
+                bounds = bounds[1:]
+                break
+        else:  # no break
+            break
+
+    if len(inputs) == 1:
+        [new_out] = inputs
+    elif len(inputs) < len(node.inputs):
+        new_out = node.op(*inputs)
+        copy_stack_trace(old_out, new_out)
+    else:
+        return None
+
+    # Removed branches may have broadcast or upcast the output
+    if new_out.dtype != old_out.type.dtype:
+        new_out = new_out.astype(old_out.type.dtype)
+    if new_out.type.ndim != old_out.type.ndim:
+        new_out = atleast_Nd(new_out, old_out.type.ndim)
+
+    return [new_out]
+
+
+# This rewrite can be expensive, call it once going from out to in
+# After all the local rewrites in canonicalize have been applied
+# out2in is preferable because we truncate more on the outputs, and any
+# domain bound analyses that go up to the inputs are cached anyway.
+optdb["canonicalize"].register( + local_useless_extremum_branches.__name__, + out2in(local_useless_switch_branches, local_useless_extremum_branches), + "fast_run", +) diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index 9694a022e3..15637c897a 100644 --- a/pytensor/tensor/rewriting/math.py +++ b/pytensor/tensor/rewriting/math.py @@ -1450,7 +1450,7 @@ def local_useless_elemwise_comparison(fgraph, node): # Elemwise[{minimum,maximum}](X, X) -> X if ( - isinstance(node.op.scalar_op, ps.ScalarMinimum | ps.ScalarMaximum) + isinstance(node.op.scalar_op, ps.Minimum | ps.Maximum) and node.inputs[0] is node.inputs[1] ): res = node.inputs[0] @@ -1493,7 +1493,7 @@ def local_useless_elemwise_comparison(fgraph, node): return [res] # Elemwise[maximum](X.shape[i], 0) -> X.shape[i] - if isinstance(node.op.scalar_op, ps.ScalarMaximum): + if isinstance(node.op.scalar_op, ps.Maximum): for idx in range(2): if ( node.inputs[idx].owner @@ -1512,7 +1512,7 @@ def local_useless_elemwise_comparison(fgraph, node): return [res] # Elemwise[minimum](X.shape[i], 0) -> 0 - if isinstance(node.op.scalar_op, ps.ScalarMinimum): + if isinstance(node.op.scalar_op, ps.Minimum): for idx in range(2): if ( node.inputs[idx].owner diff --git a/pytensor/tensor/rewriting/subtensor.py b/pytensor/tensor/rewriting/subtensor.py index 1af10e52b4..a1249879b3 100644 --- a/pytensor/tensor/rewriting/subtensor.py +++ b/pytensor/tensor/rewriting/subtensor.py @@ -559,6 +559,22 @@ def local_subtensor_merge(fgraph, node): # Do not call make_node for test_value out = subtens(x, *sl_ins) + # Eagerly clean up merged subtensor graph, which can be a mess + # rewriter = EquilibriumGraphRewriter( + # [ + # local_extremum_plus_x, + # local_add_canonizer, + # local_mul_canonizer, + # local_intdiv_by_one, + # local_useless_extremum_branches, + # local_flatten_extremum, + # ], + # max_use_ratio=10.0, + # ) + # fg = FunctionGraph(outputs=[out], clone=False) + # rewriter.rewrite(fg) + # [out] = fg.outputs + # Copy over previous output stacktrace # and stacktrace from previous slicing operation. # Why? 
Because, the merged slicing operation could have failed diff --git a/pytensor/tensor/rewriting/uncanonicalize.py b/pytensor/tensor/rewriting/uncanonicalize.py index a44870ded2..94f5aeb28b 100644 --- a/pytensor/tensor/rewriting/uncanonicalize.py +++ b/pytensor/tensor/rewriting/uncanonicalize.py @@ -60,7 +60,7 @@ def local_max_to_min(fgraph, node): if ( max.owner and isinstance(max.owner.op, CAReduce) - and max.owner.op.scalar_op == ps.scalar_maximum + and max.owner.op.scalar_op == ps.maximum ): neg_node = max.owner.inputs[0] if neg_node.owner and neg_node.owner.op == neg: diff --git a/pytensor/tensor/subtensor.py b/pytensor/tensor/subtensor.py index 8e3e5cb902..8c2fd179f3 100644 --- a/pytensor/tensor/subtensor.py +++ b/pytensor/tensor/subtensor.py @@ -33,6 +33,7 @@ alloc, get_scalar_constant_value, nonzero, + switch, ) from pytensor.tensor.basic import ( constant as tensor_constant, @@ -40,7 +41,8 @@ from pytensor.tensor.blockwise import vectorize_node_fallback from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.exceptions import AdvancedIndexingError, NotScalarConstantError -from pytensor.tensor.math import clip +from pytensor.tensor.math import abs as pt_abs +from pytensor.tensor.math import clip, eq, ge, lt, maximum, minimum, sign from pytensor.tensor.shape import Reshape, Shape_i, specify_broadcastable from pytensor.tensor.type import ( TensorType, @@ -55,6 +57,7 @@ lscalar, tensor, ubscalar, + uint_dtypes, uiscalar, ulscalar, uwscalar, @@ -254,6 +257,25 @@ def get_idx_list(inputs, idx_list): return indices_from_subtensor(inputs[1:], idx_list) +def undo_scalarization(x): + """Undo scalarization of a variable. + + PyTensor Basic index operations use ScalarVariables for the indices/slice arguments. + But reasoning symbolically about the result of multiple indexing operations, we usually + want to work on TensorVariables, since rewrites work on those and not ScalarVariables. + + This function undoes ScalarFromTensor operation or converts ScalarConstants to TensorConstants. + """ + if isinstance(x, ScalarVariable): + if isinstance(x, ScalarConstant): + return tensor_constant(x.data, dtype=x.dtype) + elif x.owner is not None and isinstance(x.owner.op, ScalarFromTensor): + return x.owner.inputs[0] + else: + return as_tensor_variable(x) + return x + + @overload def get_canonical_form_slice( theslice: slice, @@ -296,25 +318,6 @@ def get_canonical_form_slice( direction Direction to iterate the resulting elements in. (-1 or 1). May be symbolic. """ - from pytensor.tensor import ge, lt, sign, switch - - def undo_scalarization(x): - """Undo scalarization of a variable. - - PyTensor Basic index operations use ScalarVariables for the indices/slice arguments. - But reasoning symbolically about the result of multiple indexing operations, we usually - want to work on TensorVariables, since rewrites work on those and not ScalarVariables. - - This function undoes ScalarFromTensor operation or converts ScalarConstants to TensorConstants. - """ - if isinstance(x, ScalarVariable): - if isinstance(x, ScalarConstant): - return tensor_constant(x.data, dtype=x.dtype) - elif x.owner is not None and isinstance(x.owner.op, ScalarFromTensor): - return x.owner.inputs[0] - else: - return as_tensor_variable(x) - return x def analyze(x): try: @@ -381,7 +384,7 @@ def analyze(x): ) is_stop_length = ( stop is None - or stop in [length, sys.maxsize] + or stop == length or (is_stop_constant and is_length_constant and stop >= length) ) if is_start_0: @@ -390,39 +393,27 @@ def analyze(x): # Full slice. 
return slice(0, length, 1), 1 if is_stop_constant and stop >= 0: - return (slice(0, switch(lt(stop, length), stop, length), 1), 1) + return slice(0, minimum(stop, length), 1), 1 stop_plus_len = stop + length stop = switch( lt(stop, 0), # stop < 0 - switch( - lt(stop_plus_len, 0), - # stop + len < 0 - 0, - # stop + len >= 0 - stop_plus_len, - ), + maximum(stop_plus_len, 0), # stop >= 0: use min(stop, length) - switch(lt(stop, length), stop, length), + minimum(stop, length), ) return slice(0, stop, 1), 1 elif is_stop_length: # start:length:1 if is_start_constant and start >= 0: - return slice(switch(lt(start, length), start, length), length, 1), 1 + return slice(minimum(start, length), length, 1), 1 start_plus_len = start + length start = switch( lt(start, 0), # start < 0 - switch( - lt(start_plus_len, 0), - # start + len < 0 - 0, - # start + len >= 0 - start_plus_len, - ), - # start >= 0: use min(start, length) - switch(lt(start, length), start, length), + maximum(start_plus_len, 0), + # start >= 0 + minimum(start, length), ) return slice(start, length, 1), 1 @@ -462,26 +453,23 @@ def switch_neg_step(a, b): start = switch(lt(start, 0), start + length, start) start = switch(lt(start, 0), switch_neg_step(-1, 0), start) start = switch(ge(start, length), switch_neg_step(length - 1, length), start) - if stop is None or stop == sys.maxsize: - # The special "maxsize" case is probably not needed here, - # as slices containing maxsize are not generated by - # __getslice__ anymore. + if stop is None: stop = defstop else: stop = switch(lt(stop, 0), stop + length, stop) stop = switch(lt(stop, 0), -1, stop) - stop = switch(ge(stop, length), length, stop) + stop = minimum(length, stop) nw_stop = switch_neg_step(start + 1, stop) slice_len = (start - stop - 1) // abs_step + 1 - slice_len = switch(lt(slice_len, 0), 0, slice_len) + slice_len = maximum(slice_len, 0) neg_start = nw_stop - (slice_len - 1) * abs_step - 1 neg_start = switch(lt(neg_start, 0), (nw_stop - 1), neg_start) nw_start = switch_neg_step(neg_start, start) - nw_start = switch(lt(nw_start, 0), 0, nw_start) - nw_stop = switch(lt(nw_stop, 0), 0, nw_stop) + nw_start = maximum(nw_start, 0) + nw_stop = maximum(nw_stop, 0) # Ensure start <= stop. 
- nw_start = switch(lt(nw_start, nw_stop), nw_start, nw_stop) + nw_start = minimum(nw_start, nw_stop) nw_step = abs_step if step != 1: @@ -845,6 +833,17 @@ def as_nontensor_scalar(a: Variable) -> ps.ScalarVariable: return ps.as_scalar(a) +def _eager_switch( + cond: TensorVariable | bool, a: TensorVariable, b: TensorVariable +) -> TensorVariable: + # Do not create a switch if cond is True/False + # We need this because uint types cannot be negative and creating the lazy switch could upcast everything to float64 + # It also simplifies immediately the graph that's returned + if isinstance(cond, bool): + return a if cond else b + return cast(TensorVariable, switch(cond, a, b)) + + class Subtensor(COp): """Basic NumPy indexing operator.""" @@ -956,27 +955,110 @@ def infer_shape(self, fgraph, node, shapes): padded = actual_idx_list + [slice(None, None, None)] * ( len(xshp) - len(self.idx_list) ) + + zero = tensor_constant(np.array(0, dtype="int64")) + one = tensor_constant(np.array(1, dtype="int64")) i = 0 for idx, xl in zip(padded, xshp, strict=True): if isinstance(idx, slice): - # If it is the default (None, None, None) slice, or a variant, - # the shape will be xl + a, b, step = idx.start, idx.stop, idx.step if ( - (idx.start in [None, 0]) - and (idx.stop in [None, sys.maxsize]) - and (idx.step is None or idx.step == 1) + a is None + and b is None + and step is not None + and get_scalar_constant_value(step, raise_not_constant=False) == -1 ): + # Shortcut for x[::-1] outshp.append(xl) + else: - cnf = get_canonical_form_slice(idx, xl)[0] - if cnf.step == 1: - length = cnf.stop - cnf.start + if step is None: + step_pos = True + unit_step = True + abs_step = one + else: + step = undo_scalarization(step) + if step.dtype in uint_dtypes: + step_pos = True + abs_step = step.astype("int64") + else: + step_pos = ge(step, zero) + abs_step = pt_abs(step) + unit_step = eq(abs_step, one) + + if a is None: + a_pos = True + a = _eager_switch(step_pos, zero, xl) else: - length = (cnf.stop - cnf.start - 1) // cnf.step + 1 - outshp.append(length) + a = undo_scalarization(a) + if a.dtype in uint_dtypes: + a_pos = True + a = a.astype("int64") + else: + a_pos = ge(a, zero) + + if b is None: + # For negative steps there is no numerical equivalent for stop=None. + # The formulas below work if we set it to -1 and consider `b_pos=True` + b_pos = True + b = _eager_switch(step_pos, xl, -one) + else: + b = undo_scalarization(b) + if b.dtype in uint_dtypes: + b = b.astype("int64") + b_pos = True + else: + b_pos = ge(b, zero) + + slice_length_pos_step = _eager_switch( + a_pos, + _eager_switch( + b_pos, + minimum(b - a, xl - a), # [a: b] + ((xl + b) - a), # [a: -b] + ), + _eager_switch( + b_pos, + # The [-a: b] is peculiar, the slice length actually decreases for larger arrays + # The branch -a is useless when b - a / 2 <= -a. 
Similar for the branch b + minimum(xl, b - a - xl, -a, b), # [-a: b] + minimum(b - a, xl + b), # [-a: -b] + ), + ) + + slice_length_neg_step = _eager_switch( + a_pos, + _eager_switch( + b_pos, + minimum(a - b, xl - b - one), # [a: b] + minimum(xl, a - (xl + b), a + one, -b - one), # [a: -b] + ), + _eager_switch( + b_pos, + ((xl + a) - b), # [-a: b] + minimum(a - b, xl + a + one), # [-a: -b] + ), + ) + + slice_length = _eager_switch( + step_pos, + slice_length_pos_step, + slice_length_neg_step, + ) + + # Incorporate step size + slice_length = _eager_switch( + unit_step, + slice_length, + (slice_length - one) // abs_step + one, + ) + # Catch negative sizes + slice_length = maximum(zero, slice_length) + outshp.append(slice_length) + i += 1 else: - # That dimension is dropped + # That dimension is dropped by integer indexing pass assert i == node.outputs[0].ndim assert len(outshp) == node.outputs[0].ndim diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py index 9a092663a9..471bfa6dc4 100644 --- a/tests/tensor/rewriting/test_math.py +++ b/tests/tensor/rewriting/test_math.py @@ -3829,10 +3829,7 @@ def check_max_log_sum_exp(x, axis, dimshuffle_op=None): fgraph = f.maker.fgraph.toposort() for node in fgraph: - if ( - hasattr(node.op, "scalar_op") - and node.op.scalar_op == ps.basic.scalar_maximum - ): + if hasattr(node.op, "scalar_op") and node.op.scalar_op == ps.basic.maximum: return # In mode FAST_COMPILE, the rewrites don't replace the diff --git a/tests/tensor/rewriting/test_subtensor.py b/tests/tensor/rewriting/test_subtensor.py index 0f0ec55695..c11ac8229f 100644 --- a/tests/tensor/rewriting/test_subtensor.py +++ b/tests/tensor/rewriting/test_subtensor.py @@ -1,3 +1,5 @@ +from functools import partial + import numpy as np import pytest @@ -16,17 +18,19 @@ from pytensor.graph.rewriting.utils import rewrite_graph from pytensor.graph.type import Type from pytensor.raise_op import Assert -from pytensor.tensor import inplace +from pytensor.tensor import inplace, switch from pytensor.tensor.basic import Alloc, MakeVector, _convert_to_int8, make_vector from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.elemwise import DimShuffle, Elemwise -from pytensor.tensor.math import Dot, add, dot, exp, sqr +from pytensor.tensor.math import Dot, add, dot, exp, maximum, minimum, sqr +from pytensor.tensor.math import abs as pt_abs from pytensor.tensor.rewriting.subtensor import ( local_replace_AdvancedSubtensor, local_subtensor_make_vector, local_subtensor_shape_constant, ) from pytensor.tensor.shape import ( + Shape_i, SpecifyShape, _shape, shape, @@ -984,7 +988,6 @@ def test_scalar(self): with pytest.raises(IndexError): g(x_val, idx) - @pytest.mark.slow def test_const2(self): # var[::-1][const] -> var[-1] x = matrix("x") @@ -2415,3 +2418,133 @@ def test_unknown_step(self): f(test_x, -2), test_x[0:3:-2, -1:-6:2, ::], ) + + +class TestSubtensorShapeSimplifies: + """The tests in this class make sure we don't end up with too crazy shape graphs for Subtensor. + + The exact form is not critical, so if rewrites change it slightly in the future, it's fine to tweak the tests. + Just make sure we don't end up with a bazillion nodes. 
+ + https://github.com/pymc-devs/pytensor/issues/112 + """ + + @classmethod + def setup_class(cls): + # Excluding neg_to_mul, because it makes it clumsier to write the expected graphs + cls.rewrite = partial( + rewrite_graph, + include=("ShapeOpt", "canonicalize"), + exclude=("local_neg_to_mul",), + ) + + def test_start(self): + rewrite = self.rewrite + x = vector("x", shape=(None,)) + x_sh0 = Shape_i(0)(x) + + sh = rewrite(x[3:].shape[0]) + expected_sh = maximum(0, -3 + x_sh0) + assert equal_computations([sh], [expected_sh], strict_dtype=False) + + sh = rewrite(x[-3:].shape[0]) + expected_sh = minimum(3, x_sh0) + assert equal_computations([sh], [expected_sh], strict_dtype=False) + + a = scalar("a", dtype="int64") + sh = rewrite(x[a:].shape[0]) + expected = maximum(0, switch(a >= 0, x_sh0 - a, minimum(x_sh0, -a))) + assert equal_computations([sh], [expected], strict_dtype=False) + + # Cases where a must be non-negative + sh = rewrite(x[pt_abs(a) :].shape[0]) + expected_sh = maximum(0, x_sh0 - pt_abs(a)) + assert equal_computations([sh], [expected_sh], strict_dtype=False) + + # Not implemented yet + # sh = rewrite(x[assert_op(a, a >= 0) :].shape[0]) + # expected_sh = maximum(x_sh0 - assert_op(a, a >= 0), 0) + # assert equal_computations([sh], [expected_sh], strict_dtype=False) + + y = pt.vector("y", shape=(None,)) + sh = rewrite(x[y.shape[0] :].shape[0]) + expected_sh = maximum(0, x_sh0 - Shape_i(0)(y)) + assert equal_computations([sh], [expected_sh], strict_dtype=False) + + a_uint = scalar("a_uint", dtype="uint64") + sh = rewrite(x[a_uint:].shape[0]) + expected_sh = maximum(0, x_sh0 - a_uint.astype("int64")) + assert equal_computations([sh], [expected_sh], strict_dtype=False) + + def test_stop(self): + rewrite = self.rewrite + x = vector("x", shape=(None,)) + x_sh0 = Shape_i(0)(x) + + sh = rewrite(x[:3].shape[0]) + expected_sh = minimum(3, x_sh0) + assert equal_computations([sh], [expected_sh], strict_dtype=False) + + sh = rewrite(x[:-3].shape[0]) + expected_sh = maximum(0, -3 + x_sh0) + assert equal_computations([sh], [expected_sh], strict_dtype=False) + + a = scalar("a", dtype="int64") + sh = rewrite(x[:a].shape[0]) + expected = maximum(0, switch(a >= 0, minimum(a, x_sh0), x_sh0 + a)) + assert equal_computations([sh], [expected], strict_dtype=False) + + # Cases where a must be non-negative + sh = rewrite(x[: abs(a)].shape[0]) + expected_sh = minimum(abs(a), x_sh0) + assert equal_computations([sh], [expected_sh], strict_dtype=False) + + # Not implemented yet + # sh = rewrite(x[: assert_op(a, a >= 0)].shape[0]) + # expected_sh = minimum(assert_op(a, a >= 0), x_sh0) + # assert equal_computations([sh], [expected_sh], strict_dtype=False) + + y = pt.vector("y", shape=(None,)) + sh = rewrite(x[: y.shape[0]].shape[0]) + expected_sh = minimum(Shape_i(0)(y), x_sh0) + assert equal_computations([sh], [expected_sh], strict_dtype=False) + + a_uint = scalar("a_uint", dtype="uint64") + sh = rewrite(x[:a_uint].shape[0]) + expected_sh = minimum(a_uint.astype("int64"), x_sh0) + assert equal_computations([sh], [expected_sh], strict_dtype=False) + + def test_nested_start(self): + rewrite = self.rewrite + x = vector("x", shape=(None,)) + x_sh0 = Shape_i(0)(x) + + sh = rewrite(x[3:][::-1][2:][::-1][4:].shape[0]) + expected_sh = maximum(0, -9 + x_sh0) + assert equal_computations([sh], [expected_sh], strict_dtype=False) + + sh = rewrite(x[-3:][::-1][-2:][::-1][-4:].shape[0]) + expected_sh = minimum(2, x_sh0) + assert equal_computations([sh], [expected_sh], strict_dtype=False) + + def test_nested_stop(self): + 
rewrite = self.rewrite + x = vector("x", shape=(None,)) + x_sh0 = Shape_i(0)(x) + + sh = rewrite(x[:3][::-1][:2][::-1][:4].shape[0]) + expected_sh = minimum(2, x_sh0) + assert equal_computations([sh], [expected_sh], strict_dtype=False) + + sh = rewrite(x[:-3][::-1][:-2][::-1][:-4].shape[0]) + expected_sh = maximum(0, -9 + x_sh0) + assert equal_computations([sh], [expected_sh], strict_dtype=False) + + def test_nested_start_and_stop(self): + rewrite = self.rewrite + x = vector("x", shape=(None,)) + x_sh0 = Shape_i(0)(x) + + sh = rewrite(x[1:-1][3:-2][1:-1].shape[0]) + expected_sh = maximum(0, -9 + x_sh0) + assert equal_computations([sh], [expected_sh], strict_dtype=False) diff --git a/tests/tensor/test_elemwise.py b/tests/tensor/test_elemwise.py index 77d41a03c5..cbc5214432 100644 --- a/tests/tensor/test_elemwise.py +++ b/tests/tensor/test_elemwise.py @@ -544,14 +544,14 @@ def with_mode( elif scalar_op == ps.mul: for axis in sorted(tosum, reverse=True): zv = np.multiply.reduce(zv, axis) - elif scalar_op == ps.scalar_maximum: + elif scalar_op == ps.maximum: # There is no identity value for the maximum function # So we can't support shape of dimensions 0. if np.prod(zv.shape) == 0: continue for axis in sorted(tosum, reverse=True): zv = np.maximum.reduce(zv, axis) - elif scalar_op == ps.scalar_minimum: + elif scalar_op == ps.minimum: # There is no identity value for the minimum function # So we can't support shape of dimensions 0. if np.prod(zv.shape) == 0: @@ -594,7 +594,7 @@ def with_mode( tosum = list(range(len(xsh))) f = pytensor.function([x], e.shape, mode=mode, on_unused_input="ignore") if not ( - scalar_op in [ps.scalar_maximum, ps.scalar_minimum] + scalar_op in [ps.maximum, ps.minimum] and (xsh == () or np.prod(xsh) == 0) ): assert all(f(xv) == zv.shape) @@ -606,8 +606,8 @@ def test_perform(self): for dtype in ["bool", "floatX", "complex64", "complex128", "int8", "uint8"]: self.with_mode(Mode(linker="py"), ps.add, dtype=dtype) self.with_mode(Mode(linker="py"), ps.mul, dtype=dtype) - self.with_mode(Mode(linker="py"), ps.scalar_maximum, dtype=dtype) - self.with_mode(Mode(linker="py"), ps.scalar_minimum, dtype=dtype) + self.with_mode(Mode(linker="py"), ps.maximum, dtype=dtype) + self.with_mode(Mode(linker="py"), ps.minimum, dtype=dtype) self.with_mode(Mode(linker="py"), ps.and_, dtype=dtype, tensor_op=pt_all) self.with_mode(Mode(linker="py"), ps.or_, dtype=dtype, tensor_op=pt_any) for dtype in ["int8", "uint8"]: @@ -619,12 +619,8 @@ def test_perform_nan(self): for dtype in ["floatX", "complex64", "complex128"]: self.with_mode(Mode(linker="py"), ps.add, dtype=dtype, test_nan=True) self.with_mode(Mode(linker="py"), ps.mul, dtype=dtype, test_nan=True) - self.with_mode( - Mode(linker="py"), ps.scalar_maximum, dtype=dtype, test_nan=True - ) - self.with_mode( - Mode(linker="py"), ps.scalar_minimum, dtype=dtype, test_nan=True - ) + self.with_mode(Mode(linker="py"), ps.maximum, dtype=dtype, test_nan=True) + self.with_mode(Mode(linker="py"), ps.minimum, dtype=dtype, test_nan=True) self.with_mode( Mode(linker="py"), ps.or_, @@ -659,8 +655,8 @@ def test_c(self): self.with_mode(Mode(linker="c"), ps.add, dtype=dtype) self.with_mode(Mode(linker="c"), ps.mul, dtype=dtype) for dtype in ["bool", "floatX", "int8", "uint8"]: - self.with_mode(Mode(linker="c"), ps.scalar_minimum, dtype=dtype) - self.with_mode(Mode(linker="c"), ps.scalar_maximum, dtype=dtype) + self.with_mode(Mode(linker="c"), ps.minimum, dtype=dtype) + self.with_mode(Mode(linker="c"), ps.maximum, dtype=dtype) self.with_mode(Mode(linker="c"), 
ps.and_, dtype=dtype, tensor_op=pt_all) self.with_mode(Mode(linker="c"), ps.or_, dtype=dtype, tensor_op=pt_any) for dtype in ["bool", "int8", "uint8"]: @@ -678,12 +674,8 @@ def test_c_nan(self): self.with_mode(Mode(linker="c"), ps.add, dtype=dtype, test_nan=True) self.with_mode(Mode(linker="c"), ps.mul, dtype=dtype, test_nan=True) for dtype in ["floatX"]: - self.with_mode( - Mode(linker="c"), ps.scalar_minimum, dtype=dtype, test_nan=True - ) - self.with_mode( - Mode(linker="c"), ps.scalar_maximum, dtype=dtype, test_nan=True - ) + self.with_mode(Mode(linker="c"), ps.minimum, dtype=dtype, test_nan=True) + self.with_mode(Mode(linker="c"), ps.maximum, dtype=dtype, test_nan=True) def test_infer_shape(self, dtype=None, pre_scalar_op=None): if dtype is None: diff --git a/tests/tensor/test_subtensor.py b/tests/tensor/test_subtensor.py index 78ec97eff3..5875ead934 100644 --- a/tests/tensor/test_subtensor.py +++ b/tests/tensor/test_subtensor.py @@ -15,10 +15,10 @@ from pytensor.compile.mode import Mode from pytensor.configdefaults import config from pytensor.gradient import grad -from pytensor.graph import Constant +from pytensor.graph import Constant, FunctionGraph from pytensor.graph.basic import equal_computations from pytensor.graph.op import get_test_value -from pytensor.graph.rewriting.utils import is_same_graph +from pytensor.graph.rewriting.utils import is_same_graph, rewrite_graph from pytensor.printing import pprint from pytensor.scalar.basic import as_scalar, int16 from pytensor.tensor import as_tensor, get_vector_length, vectorize @@ -71,6 +71,7 @@ lscalar, lvector, matrix, + scalar, tensor, tensor3, tensor4, @@ -1055,26 +1056,8 @@ def test_adv_sub1_idx_broadcast(self): assert np.allclose(g_0[0], 1) assert np.allclose(g_0[1:], 0) - @pytest.mark.slow - def test_shape_i_const(self): - # Each axis is treated independently by shape_i/shape operators - - mode_opt = self.mode - data = self.shared(np.array(np.arange(5), dtype=self.dtype)) - for start in [None, -8, -5, -1, 0, 1, 5, 8]: - outs = [] - shapes = [] - for stop in [None, -8, -5, -1, 0, 1, 5, 8]: - for step in [None, -3, -1, 2]: - outs += [data[start:stop:step].shape] - shapes += [data.get_value(borrow=True)[start:stop:step].shape] - f = self.function([], outs, mode=mode_opt, op=subtensor_ops, N=0) - t_shapes = f() - for t_shape, shape in zip(t_shapes, shapes, strict=True): - assert np.all(t_shape == shape) - assert Subtensor not in [x.op for x in f.maker.fgraph.toposort()] - def test_shape_i_scalar(self): + # TODO: Move this to infer_shape # Each axis is treated independently by shape_i/shape operators mode_opt = self.mode @@ -1466,6 +1449,70 @@ def test_adv1_inc_sub_notlastdim_1_2dval_no_broadcast(self): assert np.allclose(m2_val, m2_ref), (m2_val, m2_ref) +class TestSubtensorInferShape: + _NO_OPT_MODE = Mode(linker="py", optimizer=None) + + @pytest.mark.parametrize( + "b", [None, 0, 1, 7, 13, -1, -7, -13], ids=lambda x: f"b={x}" + ) + @pytest.mark.parametrize( + "a", [None, 0, 1, 7, 13, -1, -7, -13], ids=lambda x: f"a={x}" + ) + @pytest.mark.parametrize("step", [None, 1, 3, -1, -4], ids=lambda x: f"step={x}") + def test_constant_params(self, a, b, step): + x = vector("x", dtype="int64") + y = x[a:b:step].shape[0] + + fg = FunctionGraph(outputs=[y], clone=False) + rewrite_graph(fg, include=("ShapeOpt", "canonicalize"), clone=False) + assert not any(isinstance(node.op, Subtensor) for node in fg.apply_nodes) + assert len(fg.apply_nodes) <= 7 + + fn = pytensor.function( + [x], + fg.outputs[0], + trust_input=True, + 
mode=self._NO_OPT_MODE, + on_unused_input="ignore", + ) + x_full = np.arange(20) + for n in range(0, 20): + x_test = x_full[:n] + assert fn(x_test) == x_test[a:b:step].shape[0], f"failed with {n=}" + + @pytest.mark.parametrize("a_dtype", (None, "int64", "uint64")) + @pytest.mark.parametrize("b_dtype", (None, "int64", "uint64")) + @pytest.mark.parametrize("step_dtype", (None, "int64", "uint64")) + def test_uint(self, a_dtype, b_dtype, step_dtype): + a = None if a_dtype is None else scalar(dtype=a_dtype) + b = None if b_dtype is None else scalar(dtype=b_dtype) + step = None if step_dtype is None else scalar(dtype=step_dtype) + x = vector("x", dtype="int64") + + y = x[a:b:step].shape[0] + + final_y = rewrite_graph(y, include=("ShapeOpt", "canonicalize"), clone=False) + assert final_y.dtype == "int64" + + test_a = None if a is None else 1 if a_dtype.startswith("u") else -1 + test_b = None if b is None else 10 if b_dtype.startswith("u") else -2 + test_step = None if step is None else 2 if step_dtype.startswith("u") else -2 + test_x = np.arange(20) + + test_dict = {x: test_x} + if a is not None: + test_dict[a] = test_a + if b is not None: + test_dict[b] = test_b + if step is not None: + test_dict[step] = test_step + + final_y_eval = final_y.eval( + test_dict, mode=self._NO_OPT_MODE, on_unused_input="ignore" + ) + assert final_y_eval == test_x[test_a:test_b:test_step].shape[0] + + def test_take_basic(): with pytest.raises(TypeError): take(matrix(), lvector(), axis=lscalar())