diff --git a/doc/library/tensor/basic.rst b/doc/library/tensor/basic.rst
index 4f087b6788..4d3a5736a7 100644
--- a/doc/library/tensor/basic.rst
+++ b/doc/library/tensor/basic.rst
@@ -619,9 +619,8 @@ dimensions, see :meth:`_tensor_py_operators.dimshuffle`.
 
 .. function:: shape_padleft(x, n_ones=1)
 
-    Reshape `x` by left padding the shape with `n_ones` 1s. Note that all
-    this new dimension will be broadcastable. To make them non-broadcastable
-    see the :func:`unbroadcast`.
+    Reshape `x` by left padding the shape with `n_ones` 1s.
+    All new dimensions will be broadcastable.
 
     :param x: variable to be reshaped
     :type x: any `TensorVariable` (or compatible)
@@ -633,9 +632,8 @@ dimensions, see :meth:`_tensor_py_operators.dimshuffle`.
 
 .. function:: shape_padright(x, n_ones=1)
 
-    Reshape `x` by right padding the shape with `n_ones` ones. Note that all
-    this new dimension will be broadcastable. To make them non-broadcastable
-    see the :func:`unbroadcast`.
+    Reshape `x` by right padding the shape with `n_ones` ones.
+    All new dimensions will be broadcastable.
 
     :param x: variable to be reshaped
     :type x: any TensorVariable (or compatible)
@@ -646,9 +644,8 @@ dimensions, see :meth:`_tensor_py_operators.dimshuffle`.
 
 .. function:: shape_padaxis(t, axis)
 
-    Reshape `t` by inserting ``1`` at the dimension `axis`. Note that this new
-    dimension will be broadcastable. To make it non-broadcastable
-    see the :func:`unbroadcast`.
+    Reshape `t` by inserting ``1`` at the dimension `axis`.
+    The new dimension will be broadcastable.
 
     :type x: any `TensorVariable` (or compatible)
     :param x: variable to be reshaped
diff --git a/pytensor/compile/function/pfunc.py b/pytensor/compile/function/pfunc.py
index 749ec5cb42..91d6e1a588 100644
--- a/pytensor/compile/function/pfunc.py
+++ b/pytensor/compile/function/pfunc.py
@@ -292,14 +292,8 @@ def clone_inputs(i):
                     f" shared_var.type={store_into.type},"
                     f" update_val={update_val}, update_val.type={getattr(update_val, 'type', None)})."
                 )
-                err_sug = (
-                    "If the difference is related to the broadcast pattern,"
-                    " you can call the"
-                    " tensor.shape.unbroadcast(var, axis_to_unbroadcast[, ...])"
-                    " function to mask broadcastable dimensions."
- ) - raise TypeError(err_msg, err_sug) + raise TypeError(err_msg) assert store_into.type.is_super(update_val.type) update_d[store_into] = update_val diff --git a/pytensor/ifelse.py b/pytensor/ifelse.py index c458e5b296..8c07a99280 100644 --- a/pytensor/ifelse.py +++ b/pytensor/ifelse.py @@ -26,7 +26,7 @@ from pytensor.graph.replace import clone_replace from pytensor.graph.rewriting.basic import GraphRewriter, in2out, node_rewriter from pytensor.graph.type import HasDataType, HasShape -from pytensor.tensor.shape import Reshape, Shape, SpecifyShape, Unbroadcast +from pytensor.tensor.shape import Reshape, Shape, SpecifyShape if TYPE_CHECKING: @@ -481,7 +481,6 @@ def cond_make_inplace(fgraph, node): Shape, SpecifyShape, Reshape, - Unbroadcast, pt.math.Dot, pt.math.Max, pt.math.Argmax, diff --git a/pytensor/link/jax/dispatch/shape.py b/pytensor/link/jax/dispatch/shape.py index 6d809252a7..d7c1d0bcbd 100644 --- a/pytensor/link/jax/dispatch/shape.py +++ b/pytensor/link/jax/dispatch/shape.py @@ -4,7 +4,7 @@ from pytensor.graph.basic import Apply from pytensor.graph.op import Op from pytensor.link.jax.dispatch.basic import jax_funcify -from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape, Unbroadcast +from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape from pytensor.tensor.type import TensorType @@ -104,11 +104,3 @@ def specifyshape(x, *shape): return x return specifyshape - - -@jax_funcify.register(Unbroadcast) -def jax_funcify_Unbroadcast(op, **kwargs): - def unbroadcast(x): - return x - - return unbroadcast diff --git a/pytensor/link/numba/dispatch/tensor_basic.py b/pytensor/link/numba/dispatch/tensor_basic.py index 80b05d4e81..8f5972c058 100644 --- a/pytensor/link/numba/dispatch/tensor_basic.py +++ b/pytensor/link/numba/dispatch/tensor_basic.py @@ -17,7 +17,6 @@ Split, TensorFromScalar, ) -from pytensor.tensor.shape import Unbroadcast @numba_funcify.register(AllocEmpty) @@ -232,15 +231,6 @@ def makevector({", ".join(input_names)}): return numba_basic.numba_njit(makevector_fn) -@numba_funcify.register(Unbroadcast) -def numba_funcify_Unbroadcast(op, **kwargs): - @numba_basic.numba_njit - def unbroadcast(x): - return x - - return unbroadcast - - @numba_funcify.register(TensorFromScalar) def numba_funcify_TensorFromScalar(op, **kwargs): @numba_basic.numba_njit(inline="always") diff --git a/pytensor/link/pytorch/dispatch/shape.py b/pytensor/link/pytorch/dispatch/shape.py index c15b3a3779..1305211b0c 100644 --- a/pytensor/link/pytorch/dispatch/shape.py +++ b/pytensor/link/pytorch/dispatch/shape.py @@ -2,7 +2,7 @@ from pytensor.graph.basic import Constant from pytensor.link.pytorch.dispatch.basic import pytorch_funcify -from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape, Unbroadcast +from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape @pytorch_funcify.register(Reshape) @@ -56,11 +56,3 @@ def specifyshape(x, *shape): return x return specifyshape - - -@pytorch_funcify.register(Unbroadcast) -def pytorch_funcify_Unbroadcast(op, **kwargs): - def unbroadcast(x): - return x - - return unbroadcast diff --git a/pytensor/scan/basic.py b/pytensor/scan/basic.py index ab2b53061d..ae3785958c 100644 --- a/pytensor/scan/basic.py +++ b/pytensor/scan/basic.py @@ -15,7 +15,7 @@ from pytensor.tensor.basic import get_underlying_scalar_constant_value from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.math import minimum -from pytensor.tensor.shape import shape_padleft, unbroadcast +from pytensor.tensor.shape 
import shape_padleft from pytensor.tensor.type import TensorType, integer_dtypes from pytensor.updates import OrderedUpdates @@ -748,7 +748,7 @@ def wrap_into_list(x): # defined in scan utils sit_sot_scan_inputs.append( expand_empty( - unbroadcast(shape_padleft(actual_arg), 0), + shape_padleft(actual_arg), actual_n_steps, ) ) @@ -865,13 +865,13 @@ def wrap_into_list(x): if n_fixed_steps in (1, -1): for pos, inner_out in enumerate(outputs): # we need to see if we need to pad our sequences with an - # unbroadcastable dimension; case example : we return an + # extra dimension; case example : we return an # output for which we want all intermediate. If n_steps is 1 # then, if we return the output as given by the innner function # this will represent only a slice and it will have one # dimension less. if isinstance(inner_out.type, TensorType) and return_steps.get(pos, 0) != 1: - outputs[pos] = unbroadcast(shape_padleft(inner_out), 0) + outputs[pos] = shape_padleft(inner_out) if not return_list and len(outputs) == 1: outputs = outputs[0] @@ -1002,7 +1002,7 @@ def wrap_into_list(x): sit_sot_inner_inputs.append(new_var) sit_sot_scan_inputs.append( expand_empty( - unbroadcast(shape_padleft(input.variable), 0), + shape_padleft(input.variable), actual_n_steps, ) ) diff --git a/pytensor/scan/op.py b/pytensor/scan/op.py index 4f2739fb69..2c3f404449 100644 --- a/pytensor/scan/op.py +++ b/pytensor/scan/op.py @@ -166,8 +166,7 @@ def check_broadcast(v1, v2): "axis %d in `output_info`. This can happen if one of the " "dimension is fixed to 1 in the input, while it is still " "variable in the output, or vice-verca. You have to make " - "them consistent, e.g. using pytensor.tensor." - "{unbroadcast, specify_broadcastable}." + "them consistent, e.g. using pytensor.tensor.specify_broadcastable." ) size = min(v1.type.ndim, v2.type.ndim) for n, (b1, b2) in enumerate( diff --git a/pytensor/scan/rewriting.py b/pytensor/scan/rewriting.py index 3b74471cd4..b8e6b009d8 100644 --- a/pytensor/scan/rewriting.py +++ b/pytensor/scan/rewriting.py @@ -58,7 +58,11 @@ from pytensor.tensor.elemwise import DimShuffle, Elemwise from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.math import Dot, dot, maximum, minimum -from pytensor.tensor.rewriting.basic import constant_folding, local_useless_switch +from pytensor.tensor.rewriting.basic import ( + broadcasted_by, + constant_folding, + local_useless_switch, +) from pytensor.tensor.rewriting.elemwise import local_upcast_elemwise_constant_inputs from pytensor.tensor.rewriting.math import local_abs_merge, local_mul_switch_sink from pytensor.tensor.shape import shape @@ -1182,6 +1186,44 @@ def while_scan_merge_subtensor_last_element(fgraph, scan_node): return subtensor_merge_replacements +def _is_default_scan_buffer(x: TensorVariable) -> bool: + node = x.owner + + if node is None: + return False + + op = node.op + if not ( + isinstance(op, IncSubtensor) + and op.set_instead_of_inc + and op.idx_list == [slice(None, ps.int64)] + ): + return False + + x, y, *_ = node.inputs + if not (x.owner is not None and isinstance(x.owner.op, AllocEmpty)): + return False + + # The value may have been broadcast to fill in the initial taps. 
+ # If the user specified outputs as: + # x = scalar(); init = alloc(x, 2); + # outputs_info=[init, taps=(-2, -1)] + # Scan will generate an initial buffer that looks like + # alloc_empty(2 + nsteps)[:2].set(alloc(x, 2)) + # PyTensor will then rewrite it as: + # alloc_empty(2 + nsteps)[:2].set(x) + # When the initial value (x) is being broadcast by the set_subtensor + # we can't recreate a newly sized buffer working with x alone + # We want to check that: + # 1. alloc_empty(2 + nsteps)[:2].broadcastable == x.broadcastable + # But due to laziness we use the slightly more conservative check: + # 2. alloc_empty(2 + nsteps).broadcastable == x.broadcastable + if broadcasted_by(y, x): + return False + + return True + + def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation: bool): r"""Graph optimizer that reduces scan memory consumption. @@ -1520,51 +1562,28 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation: # 3.2 check orphane outputs to see if we can eliminate any required, not_required = scan_can_remove_outs(node.op, orphane_outs) - # 3.3. compose replace pairs for those nodes that need not - # to store everything in memory ( or ar orphane and required - # by the inner function .. ) + + # 3.3. compose replace pairs for those nodes that need not store everything in memory + # (or ar orphan but required by the inner function) replaced_outs = [] offset = 1 + op_info.n_seqs + op_info.n_mit_mot - for idx, _val in enumerate(store_steps[op_info.n_mit_mot :]): + for idx, val in enumerate(store_steps[op_info.n_mit_mot :]): i = idx + op_info.n_mit_mot - if not (isinstance(_val, int) and _val <= 0 and i not in required): - if idx + op_info.n_mit_mot in required: - val = 1 - else: - val = _val + if not (isinstance(val, int) and val <= 0 and i not in required): + required_orphan = idx + op_info.n_mit_mot in required # If the memory for this output has been pre-allocated # before going into the scan op (by an alloc node) if idx < op_info.n_mit_sot + op_info.n_sit_sot: - # In case the input is still an alloc node, we - # actually have two options: - # a) the input is a set_subtensor, in that case we - # can replace the initial tensor by a slice, - # b) it is not, and we simply take a slice of it. 
- # TODO: commit change below with Razvan - if ( - nw_inputs[offset + idx].owner - and isinstance(nw_inputs[offset + idx].owner.op, IncSubtensor) - and nw_inputs[offset + idx].owner.op.set_instead_of_inc - and isinstance( - nw_inputs[offset + idx].owner.op.idx_list[0], slice - ) - # Don't try to create a smart Alloc, if set_subtensor is broadcasting the fill value - # As it happens in set_subtensor(empty(2)[:], 0) - and not ( - nw_inputs[offset + idx].ndim - > nw_inputs[offset + idx].owner.inputs[1].ndim - ) - ): - _nw_input = nw_inputs[offset + idx].owner.inputs[1] - cval = pt.as_tensor_variable(val) - initl = pt.as_tensor_variable(init_l[i]) - tmp_idx = pt.switch(cval < initl, cval + initl, cval - initl) - nw_input = expand_empty(_nw_input, tmp_idx) + nw_input = nw_inputs[offset + idx] + + # Recreate default buffers with new size + if _is_default_scan_buffer(nw_input): + extra_size = 1 if required_orphan else val - init_l[i] + nw_input = expand_empty(nw_input.owner.inputs[1], extra_size) + # Otherwise, just trim with a slice else: - tmp = pt.as_tensor_variable(val) - initl = pt.as_tensor_variable(init_l[i]) - tmp = maximum(tmp, initl) - nw_input = nw_inputs[offset + idx][:tmp] + stop = init_l[i] if required_orphan else val + nw_input = nw_input[:stop] nw_inputs[offset + idx] = nw_input replaced_outs.append(op_info.n_mit_mot + idx) @@ -1588,7 +1607,7 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation: + op_info.n_shared_outs ) if nw_inputs[pos] == node.inputs[0]: - nw_inputs[pos] = val + nw_inputs[pos] = 1 if required_orphan else val odx = op_info.n_mit_mot + idx replaced_outs.append(odx) old_outputs += [ @@ -1600,8 +1619,7 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation: ], ) ] - # 3.4. Recompute inputs for everything else based on the new - # number of steps + # 3.4. Recompute inputs for everything else based on the new number of steps if global_nsteps is not None: for idx, val in enumerate(store_steps[op_info.n_mit_mot :]): if val == 0: @@ -1609,28 +1627,14 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation: # results for that state, including the initial values. if idx < op_info.n_mit_sot + op_info.n_sit_sot: in_idx = offset + idx - # Number of steps in the initial state - initl = init_l[op_info.n_mit_mot + idx] - - # If the initial buffer has the form - # inc_subtensor(zeros(...)[...], _nw_input) - # we want to make the zeros tensor as small as - # possible (nw_steps + initl), and call - # inc_subtensor on that instead. - # Otherwise, simply take 0:(nw_steps+initl). 
- if ( - nw_inputs[in_idx].owner - and isinstance(nw_inputs[in_idx].owner.op, IncSubtensor) - and isinstance( - nw_inputs[in_idx].owner.op.idx_list[0], slice - ) - ): - _nw_input = nw_inputs[in_idx].owner.inputs[1] - nw_input = expand_empty(_nw_input, nw_steps) - nw_inputs[in_idx] = nw_input + nw_input = nw_inputs[in_idx] + if _is_default_scan_buffer(nw_input): + nw_input = expand_empty(nw_input.owner.inputs[1], nw_steps) else: - # FIXME: This is never used - nw_input = nw_inputs[in_idx][: (initl + nw_steps)] + # Number of steps in the initial state + init_l_pt = pt.as_tensor(init_l[op_info.n_mit_mot + idx]) + nw_input = nw_input[: (init_l_pt + nw_steps)] + nw_inputs[in_idx] = nw_input elif ( idx < op_info.n_mit_sot + op_info.n_sit_sot + op_info.n_nit_sot diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py index 2f3b94f104..a108d87f42 100644 --- a/pytensor/tensor/basic.py +++ b/pytensor/tensor/basic.py @@ -53,7 +53,6 @@ from pytensor.tensor.shape import ( Shape, Shape_i, - Unbroadcast, shape, shape_padaxis, shape_padleft, @@ -334,9 +333,7 @@ def _get_underlying_scalar_constant_value( if not only_process_constants and getattr(v, "owner", None) and max_recur > 0: op = v.owner.op max_recur -= 1 - if isinstance( - op, Alloc | DimShuffle | Unbroadcast | OutputGuard | DeepCopyOp - ): + if isinstance(op, Alloc | DimShuffle | OutputGuard | DeepCopyOp): # OutputGuard is only used in debugmode but we # keep it here to avoid problems with old pickles v = v.owner.inputs[0] @@ -498,14 +495,6 @@ def _get_underlying_scalar_constant_value( grandparent = leftmost_parent.owner.inputs[0] gp_shape = grandparent.type.shape ndim = grandparent.type.ndim - if grandparent.owner and isinstance( - grandparent.owner.op, Unbroadcast - ): - ggp_shape = grandparent.owner.inputs[0].type.shape - l = [ - _get_underlying_scalar_constant_value(s) for s in ggp_shape - ] - gp_shape = tuple(l) if not (idx < ndim): msg = ( diff --git a/pytensor/tensor/rewriting/shape.py b/pytensor/tensor/rewriting/shape.py index 9462504e78..1eb10d247b 100644 --- a/pytensor/tensor/rewriting/shape.py +++ b/pytensor/tensor/rewriting/shape.py @@ -42,9 +42,7 @@ Shape, Shape_i, SpecifyShape, - Unbroadcast, specify_shape, - unbroadcast, ) from pytensor.tensor.subtensor import Subtensor, get_idx_list from pytensor.tensor.type import TensorType, discrete_dtypes, integer_dtypes @@ -1296,78 +1294,3 @@ def local_track_shape_i(fgraph, node): # structure. replacement = shape_feature.scheduled[node] return [shape_feature.shape_of[replacement][node.op.i]] - - -@register_useless -@register_canonicalize -@register_specialize -@node_rewriter([Unbroadcast]) -def local_useless_unbroadcast(fgraph, node): - """Remove `Unbroadcast` if it does not actually change the broadcasting pattern.""" - if isinstance(node.op, Unbroadcast): - x = node.inputs[0] - if x.type.ndim == node.outputs[0].type.ndim and all( - s1 == s2 - for s1, s2 in zip(x.type.shape, node.outputs[0].type.shape, strict=True) - if s1 == 1 or s2 == 1 - ): - # No broadcastable flag was modified - # No need to copy over stack trace, - # because x should already have a stack trace. 
- return [x] - else: - # Keep the flags that modify something - new_axes = tuple(ax for ax in node.op.axes if x.type.shape[ax] == 1) - if new_axes == node.op.axes: - # All flags are useful - return None - else: - r = unbroadcast(x, *new_axes) - # Copy over stacktrace from previous output - copy_stack_trace(node.outputs, r) - return [r] - - -@register_canonicalize -@register_specialize -@node_rewriter([Unbroadcast]) -def local_unbroadcast_lift(fgraph, node): - """ - Lifts `Unbroadcast` through unary Elemwise operations, - and merges consecutive `Unbroadcast`s. - - Unbroadcast(Elemwise(x)) => Elemwise(Unbroadcast(x)) - Unbroadcast(Unbroadcast(x)) => Unbroadcast(x) - - TODO: Implement equivalent Elemwise lift for SpecifyShape - """ - op = node.op - if not isinstance(op, Unbroadcast): - return False - - inp = node.inputs[0] - inode = inp.owner - if inode and isinstance(inode.op, Elemwise) and len(inode.inputs) == 1: - if len(fgraph.clients.get(inp, ())) == 1: - unbroadcasted = unbroadcast(inode.inputs[0], *op.axes) - copy_stack_trace(node.outputs, unbroadcasted) - - rval = inode.op.make_node(unbroadcasted).outputs - - # Copy over stacktrace from previous output (after unbroadcasting) - # and input (after elemwise operation) to new output, because an - # error in the new graph could have been caused by either of the - # two ops. - copy_stack_trace(node.outputs + node.inputs, rval) - return rval - - if inode and isinstance(inode.op, Unbroadcast): - # Merge axis of each unbroadcast - axis = tuple(set(inode.op.axes).union(set(op.axes))) - iinput = inode.inputs[0] - rval = [unbroadcast(iinput, *axis)] - # Copy over stacktrace from previous output (after second unbroadcasting) - # and from previous input (after first unbroadcasting) because an error in - # the new graph could have been caused by either of the two Unbroadcast ops. - copy_stack_trace(node.outputs + node.inputs, rval) - return rval diff --git a/pytensor/tensor/rewriting/subtensor.py b/pytensor/tensor/rewriting/subtensor.py index c7a4574a91..1af10e52b4 100644 --- a/pytensor/tensor/rewriting/subtensor.py +++ b/pytensor/tensor/rewriting/subtensor.py @@ -59,11 +59,9 @@ from pytensor.tensor.shape import ( Shape, SpecifyShape, - Unbroadcast, shape_padleft, shape_tuple, specify_shape, - unbroadcast, ) from pytensor.tensor.sharedvar import TensorSharedVariable from pytensor.tensor.subtensor import ( @@ -429,7 +427,6 @@ def local_subtensor_lift(fgraph, node): Handles the following unary ops: elemwise(x,...)[idx] -> elemwise(x[idx],...) when x,... 
are broadcasted scalar or not broadcasted at all - Unbroadcast(x)[idx] => Unbroadcast(x[idx]) """ if isinstance(node.op, Subtensor): @@ -488,40 +485,6 @@ def local_subtensor_lift(fgraph, node): copy_stack_trace([node.outputs[0], node.inputs[0]], ret) return [ret] - if isinstance(u.owner.op, Unbroadcast): - # Subtensor might reduce dim., adapt broadcast pattern accordingly - old_axes = u.owner.op.axes - new_axes = [] - - # loop through indices being subtensor-ed - # i indexes broadcastable pattern before subtensor - # j indexes broadcastable pattern after subtensor - j = 0 - for i, x in enumerate(node.op.idx_list): - # if it is not a slice, it will reduce the dimension, should - # not appear in the broascastable dimensions - if isinstance(x, slice): - if i in old_axes: - new_axes.append(j) - j += 1 - # now keep the broadcastable pattern of all - # items not appearing in subtensor list - for i in range(len(node.op.idx_list), len(u.broadcastable)): - if i in old_axes: - new_axes.append(j) - j += 1 - - subt_x = node.op(u.owner.inputs[0], *node.inputs[1:]) - # Copy over previous output stacktrace - copy_stack_trace(node.outputs[0], subt_x) - - rbcast_subt_x = unbroadcast(subt_x, *new_axes) - # Copy over previous output stacktrace - # and stacktrace from previous unary operation - copy_stack_trace([node.outputs[0], node.inputs[0]], rbcast_subt_x) - - return [rbcast_subt_x] - @register_canonicalize @register_specialize diff --git a/pytensor/tensor/shape.py b/pytensor/tensor/shape.py index 1fc4e6dd2b..5a4cfdc52a 100644 --- a/pytensor/tensor/shape.py +++ b/pytensor/tensor/shape.py @@ -18,7 +18,6 @@ from pytensor.npy_2_compat import normalize_axis_tuple from pytensor.tensor import _get_vector_length, as_tensor_variable, get_vector_length from pytensor.tensor import basic as ptb -from pytensor.tensor.elemwise import get_normalized_batch_axes from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.type import DenseTensorType, TensorType, int_dtypes, tensor from pytensor.tensor.type_other import NoneConst, NoneTypeT @@ -1008,118 +1007,3 @@ def specify_broadcastable(x, *axes): axes = normalize_axis_tuple(axes, x.type.ndim) shape_info = [1 if i in axes else s for i, s in enumerate(x.type.shape)] return specify_shape(x, shape_info) - - -class Unbroadcast(COp): - """ - Mask static broadcastable dimensions of input as `None` - - See Also - -------- - unbroadcast - - - Examples - -------- - ``Unbroadcast((1,))(x)`` would make `x` second static dimension be `None` - - """ - - view_map = {0: [0]} - _f16_ok = True - # Mapping from Type to C code (and version) to use. - # In the C code, the name of the input variable is %(iname)s, - # the output variable is %(oname)s. - c_code_and_version: dict = {} - - check_input = False - __props__ = ("axes",) - _f16_ok = True - - def __init__(self, *axis): - # Sort them to make sure we merge all possible case. - items = tuple(sorted(axis)) - self.axes = items - for axis in self.axes: - if not isinstance(axis, np.integer | int): - raise TypeError(f"Unbroadcast needs integer axes. 
Got {axis}") - - def __str__(self): - return f"{self.__class__.__name__}{{{','.join(str(i) for i in self.axes)}}}" - - def make_node(self, x): - x = as_tensor_variable(x) - if x.type.ndim <= max(self.axes): - raise ValueError("Trying to unbroadcast of non-existent dimension") - shape = [ - None if (sh == 1 and i in self.axes) else sh - for i, sh in enumerate(x.type.shape) - ] - return Apply(self, [x], [x.type.clone(shape=shape)()]) - - def perform(self, node, inp, out_): - (x,) = inp - (out,) = out_ - out[0] = x - - def grad(self, inp, grads): - (x,) = inp - (gz,) = grads - # restore the broadcasting pattern of the input - return [specify_shape(gz, x.type.shape)] - - def infer_shape(self, fgraph, node, ishapes): - assert len(ishapes) == 1 - return [tuple(ishapes[0])] - - def R_op(self, inputs, eval_points): - if eval_points[0] is None: - return [None] - return self(*eval_points, return_list=True) - - def c_code(self, node, nodename, inp, out, sub): - (iname,) = inp - (oname,) = out - - return f""" - Py_XDECREF({oname}); - {oname} = {iname}; - Py_XINCREF({oname}); - """ - - def c_code_cache_version(self): - return (3,) - - -def unbroadcast(x, *axes): - """ - Mask static broadcastable dimensions of input as `None` - - Parameters - ---------- - x : tensor_like - Input pytensor tensor. - axis : an int or an iterable object such as list or tuple of int values - The broadcastable dimensions of x that should be unbroadcasted. - - Returns - ------- - tensor - A pytensor tensor, with static broadcastable dimensions masked as `None` - - """ - x = as_tensor_variable(x) - unbroadcasted_axes = [axis for axis in axes if x.type.shape[axis] == 1] - if not unbroadcasted_axes: - return x - return Unbroadcast(*unbroadcasted_axes)(x) - - -@_vectorize_node.register(Unbroadcast) -def _vectorize_unbroadcast( - op: Unbroadcast, node: Apply, batch_x: TensorVariable -) -> Apply: - core_ndim = node.inputs[0].type.ndim - batch_ndim = batch_x.type.ndim - core_ndim - batch_axes = get_normalized_batch_axes(op.axes, core_ndim, batch_ndim) - return cast(Apply, unbroadcast(batch_x, *batch_axes).owner) diff --git a/tests/link/jax/test_shape.py b/tests/link/jax/test_shape.py index 085f67f411..751c4cb418 100644 --- a/tests/link/jax/test_shape.py +++ b/tests/link/jax/test_shape.py @@ -4,7 +4,7 @@ import pytensor.tensor as pt from pytensor.compile.ops import DeepCopyOp, ViewOp from pytensor.configdefaults import config -from pytensor.tensor.shape import Shape, Shape_i, Unbroadcast, reshape +from pytensor.tensor.shape import Shape, Shape_i, reshape from pytensor.tensor.type import iscalar, vector from tests.link.jax.test_basic import compare_jax_and_py @@ -70,10 +70,6 @@ def test_jax_compile_ops(): compare_jax_and_py([], [x], []) x_np = np.zeros((20, 1, 1)) - x = Unbroadcast(0, 2)(pt.as_tensor_variable(x_np)) - - compare_jax_and_py([], [x], []) - x = ViewOp()(pt.as_tensor_variable(x_np)) compare_jax_and_py([], [x], []) diff --git a/tests/link/numba/test_tensor_basic.py b/tests/link/numba/test_tensor_basic.py index 0eebe115e9..09963f9d36 100644 --- a/tests/link/numba/test_tensor_basic.py +++ b/tests/link/numba/test_tensor_basic.py @@ -7,7 +7,6 @@ from pytensor import config, function from pytensor.compile import get_mode from pytensor.scalar import Add -from pytensor.tensor.shape import Unbroadcast from tests.link.numba.test_basic import ( compare_numba_and_py, compare_shape_dtype, @@ -75,16 +74,6 @@ def test_ScalarFromTensor(): ) -def test_Unbroadcast(): - v, v_test = pt.row(), np.array([[1.0, 2.0]], dtype=config.floatX) - g = 
Unbroadcast(0)(v) - compare_numba_and_py( - [v], - g, - [v_test], - ) - - @pytest.mark.parametrize( "vals, dtype", [ diff --git a/tests/link/pytorch/test_shape.py b/tests/link/pytorch/test_shape.py index 4bfe6e1a2b..30c2f0a5c0 100644 --- a/tests/link/pytorch/test_shape.py +++ b/tests/link/pytorch/test_shape.py @@ -2,7 +2,7 @@ import pytensor.tensor as pt from pytensor.configdefaults import config -from pytensor.tensor.shape import Shape, Shape_i, Unbroadcast, reshape +from pytensor.tensor.shape import Shape, Shape_i, reshape from pytensor.tensor.type import iscalar, vector from tests.link.pytorch.test_basic import compare_pytorch_and_py @@ -50,10 +50,3 @@ def test_pytorch_Reshape_dynamic(): compare_pytorch_and_py( [a, shape_pt], [x], [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX), 2] ) - - -def test_pytorch_unbroadcast(): - x_np = np.zeros((20, 1, 1)) - x = Unbroadcast(0, 2)(pt.as_tensor_variable(x_np)) - - compare_pytorch_and_py([], [x], []) diff --git a/tests/scan/test_printing.py b/tests/scan/test_printing.py index 70c781a0c9..f6f395a96d 100644 --- a/tests/scan/test_printing.py +++ b/tests/scan/test_printing.py @@ -36,32 +36,31 @@ def test_debugprint_sitsot(): │ │ │ │ │ ├─ k [id D] │ │ │ │ │ └─ Subtensor{i} [id H] │ │ │ │ │ ├─ Shape [id I] - │ │ │ │ │ │ └─ Unbroadcast{0} [id J] - │ │ │ │ │ │ └─ ExpandDims{axis=0} [id K] - │ │ │ │ │ │ └─ Second [id L] - │ │ │ │ │ │ ├─ A [id M] - │ │ │ │ │ │ └─ ExpandDims{axis=0} [id N] - │ │ │ │ │ │ └─ 1.0 [id O] - │ │ │ │ │ └─ 0 [id P] - │ │ │ │ └─ Subtensor{i} [id Q] + │ │ │ │ │ │ └─ ExpandDims{axis=0} [id J] + │ │ │ │ │ │ └─ Second [id K] + │ │ │ │ │ │ ├─ A [id L] + │ │ │ │ │ │ └─ ExpandDims{axis=0} [id M] + │ │ │ │ │ │ └─ 1.0 [id N] + │ │ │ │ │ └─ 0 [id O] + │ │ │ │ └─ Subtensor{i} [id P] │ │ │ │ ├─ Shape [id I] │ │ │ │ │ └─ ··· - │ │ │ │ └─ 1 [id R] - │ │ │ ├─ Unbroadcast{0} [id J] + │ │ │ │ └─ 1 [id Q] + │ │ │ ├─ ExpandDims{axis=0} [id J] │ │ │ │ └─ ··· - │ │ │ └─ ScalarFromTensor [id S] + │ │ │ └─ ScalarFromTensor [id R] │ │ │ └─ Subtensor{i} [id H] │ │ │ └─ ··· - │ │ └─ A [id M] (outer_in_non_seqs-0) - │ └─ 1 [id T] - └─ -1 [id U] + │ │ └─ A [id L] (outer_in_non_seqs-0) + │ └─ 1 [id S] + └─ -1 [id T] Inner graphs: Scan{scan_fn, while_loop=False, inplace=none} [id C] - ← Mul [id V] (inner_out_sit_sot-0) - ├─ *0- [id W] -> [id E] (inner_in_sit_sot-0) - └─ *1- [id X] -> [id M] (inner_in_non_seqs-0) + ← Mul [id U] (inner_out_sit_sot-0) + ├─ *0- [id V] -> [id E] (inner_in_sit_sot-0) + └─ *1- [id W] -> [id L] (inner_in_non_seqs-0) """ for truth, out in zip(expected_output.split("\n"), lines, strict=True): @@ -94,32 +93,31 @@ def test_debugprint_sitsot_no_extra_info(): │ │ │ │ │ ├─ k [id D] │ │ │ │ │ └─ Subtensor{i} [id H] │ │ │ │ │ ├─ Shape [id I] - │ │ │ │ │ │ └─ Unbroadcast{0} [id J] - │ │ │ │ │ │ └─ ExpandDims{axis=0} [id K] - │ │ │ │ │ │ └─ Second [id L] - │ │ │ │ │ │ ├─ A [id M] - │ │ │ │ │ │ └─ ExpandDims{axis=0} [id N] - │ │ │ │ │ │ └─ 1.0 [id O] - │ │ │ │ │ └─ 0 [id P] - │ │ │ │ └─ Subtensor{i} [id Q] + │ │ │ │ │ │ └─ ExpandDims{axis=0} [id J] + │ │ │ │ │ │ └─ Second [id K] + │ │ │ │ │ │ ├─ A [id L] + │ │ │ │ │ │ └─ ExpandDims{axis=0} [id M] + │ │ │ │ │ │ └─ 1.0 [id N] + │ │ │ │ │ └─ 0 [id O] + │ │ │ │ └─ Subtensor{i} [id P] │ │ │ │ ├─ Shape [id I] │ │ │ │ │ └─ ··· - │ │ │ │ └─ 1 [id R] - │ │ │ ├─ Unbroadcast{0} [id J] + │ │ │ │ └─ 1 [id Q] + │ │ │ ├─ ExpandDims{axis=0} [id J] │ │ │ │ └─ ··· - │ │ │ └─ ScalarFromTensor [id S] + │ │ │ └─ ScalarFromTensor [id R] │ │ │ └─ Subtensor{i} [id H] │ │ │ └─ ··· - │ │ └─ A [id M] - │ └─ 1 [id T] - └─ -1 
[id U] + │ │ └─ A [id L] + │ └─ 1 [id S] + └─ -1 [id T] Inner graphs: Scan{scan_fn, while_loop=False, inplace=none} [id C] - ← Mul [id V] - ├─ *0- [id W] -> [id E] - └─ *1- [id X] -> [id M] + ← Mul [id U] + ├─ *0- [id V] -> [id E] + └─ *1- [id W] -> [id L] """ for truth, out in zip(expected_output.split("\n"), lines, strict=True): @@ -278,32 +276,31 @@ def compute_A_k(A, k): │ │ │ │ │ │ ├─ *3- [id BF] -> [id X] (inner_in_non_seqs-1) │ │ │ │ │ │ └─ Subtensor{i} [id BJ] │ │ │ │ │ │ ├─ Shape [id BK] - │ │ │ │ │ │ │ └─ Unbroadcast{0} [id BL] - │ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id BM] - │ │ │ │ │ │ │ └─ Second [id BN] - │ │ │ │ │ │ │ ├─ *2- [id BO] -> [id W] (inner_in_non_seqs-0) - │ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id BP] - │ │ │ │ │ │ │ └─ 1.0 [id BQ] - │ │ │ │ │ │ └─ 0 [id BR] - │ │ │ │ │ └─ Subtensor{i} [id BS] + │ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id BL] + │ │ │ │ │ │ │ └─ Second [id BM] + │ │ │ │ │ │ │ ├─ *2- [id BN] -> [id W] (inner_in_non_seqs-0) + │ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id BO] + │ │ │ │ │ │ │ └─ 1.0 [id BP] + │ │ │ │ │ │ └─ 0 [id BQ] + │ │ │ │ │ └─ Subtensor{i} [id BR] │ │ │ │ │ ├─ Shape [id BK] │ │ │ │ │ │ └─ ··· - │ │ │ │ │ └─ 1 [id BT] - │ │ │ │ ├─ Unbroadcast{0} [id BL] + │ │ │ │ │ └─ 1 [id BS] + │ │ │ │ ├─ ExpandDims{axis=0} [id BL] │ │ │ │ │ └─ ··· - │ │ │ │ └─ ScalarFromTensor [id BU] + │ │ │ │ └─ ScalarFromTensor [id BT] │ │ │ │ └─ Subtensor{i} [id BJ] │ │ │ │ └─ ··· - │ │ │ └─ *2- [id BO] -> [id W] (inner_in_non_seqs-0) (outer_in_non_seqs-0) - │ │ └─ 1 [id BV] - │ └─ -1 [id BW] - └─ ExpandDims{axis=0} [id BX] - └─ *1- [id BY] -> [id U] (inner_in_seqs-1) + │ │ │ └─ *2- [id BN] -> [id W] (inner_in_non_seqs-0) (outer_in_non_seqs-0) + │ │ └─ 1 [id BU] + │ └─ -1 [id BV] + └─ ExpandDims{axis=0} [id BW] + └─ *1- [id BX] -> [id U] (inner_in_seqs-1) Scan{scan_fn, while_loop=False, inplace=none} [id BE] - ← Mul [id BZ] (inner_out_sit_sot-0) - ├─ *0- [id CA] -> [id BG] (inner_in_sit_sot-0) - └─ *1- [id CB] -> [id BO] (inner_in_non_seqs-0) + ← Mul [id BY] (inner_out_sit_sot-0) + ├─ *0- [id BZ] -> [id BG] (inner_in_sit_sot-0) + └─ *1- [id CA] -> [id BN] (inner_in_non_seqs-0) """ for truth, out in zip(expected_output.split("\n"), lines, strict=True): @@ -375,34 +372,33 @@ def compute_A_k(A, k): │ │ │ │ │ │ ├─ *3- [id BB] (inner_in_non_seqs-1) │ │ │ │ │ │ └─ Subtensor{i} [id BL] │ │ │ │ │ │ ├─ Shape [id BM] - │ │ │ │ │ │ │ └─ Unbroadcast{0} [id BN] - │ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id BO] - │ │ │ │ │ │ │ └─ Second [id BP] - │ │ │ │ │ │ │ ├─ *2- [id BA] (inner_in_non_seqs-0) - │ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id BQ] - │ │ │ │ │ │ │ └─ 1.0 [id BR] - │ │ │ │ │ │ └─ 0 [id BS] - │ │ │ │ │ └─ Subtensor{i} [id BT] + │ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id BN] + │ │ │ │ │ │ │ └─ Second [id BO] + │ │ │ │ │ │ │ ├─ *2- [id BA] (inner_in_non_seqs-0) + │ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id BP] + │ │ │ │ │ │ │ └─ 1.0 [id BQ] + │ │ │ │ │ │ └─ 0 [id BR] + │ │ │ │ │ └─ Subtensor{i} [id BS] │ │ │ │ │ ├─ Shape [id BM] │ │ │ │ │ │ └─ ··· - │ │ │ │ │ └─ 1 [id BU] - │ │ │ │ ├─ Unbroadcast{0} [id BN] + │ │ │ │ │ └─ 1 [id BT] + │ │ │ │ ├─ ExpandDims{axis=0} [id BN] │ │ │ │ │ └─ ··· - │ │ │ │ └─ ScalarFromTensor [id BV] + │ │ │ │ └─ ScalarFromTensor [id BU] │ │ │ │ └─ Subtensor{i} [id BL] │ │ │ │ └─ ··· │ │ │ └─ *2- [id BA] (inner_in_non_seqs-0) (outer_in_non_seqs-0) - │ │ └─ 1 [id BW] - │ └─ -1 [id BX] - └─ ExpandDims{axis=0} [id BY] + │ │ └─ 1 [id BV] + │ └─ -1 [id BW] + └─ ExpandDims{axis=0} [id BX] └─ *1- [id Z] (inner_in_seqs-1) Scan{scan_fn, while_loop=False, inplace=none} [id 
BH] - → *0- [id BZ] -> [id BI] (inner_in_sit_sot-0) - → *1- [id CA] -> [id BA] (inner_in_non_seqs-0) - ← Mul [id CB] (inner_out_sit_sot-0) - ├─ *0- [id BZ] (inner_in_sit_sot-0) - └─ *1- [id CA] (inner_in_non_seqs-0) + → *0- [id BY] -> [id BI] (inner_in_sit_sot-0) + → *1- [id BZ] -> [id BA] (inner_in_non_seqs-0) + ← Mul [id CA] (inner_out_sit_sot-0) + ├─ *0- [id BY] (inner_in_sit_sot-0) + └─ *1- [id BZ] (inner_in_non_seqs-0) """ for truth, out in zip(expected_output.split("\n"), lines, strict=True): @@ -516,105 +512,104 @@ def test_debugprint_mitmot(): │ │ │ │ │ │ │ ├─ k [id G] │ │ │ │ │ │ │ └─ Subtensor{i} [id K] │ │ │ │ │ │ │ ├─ Shape [id L] - │ │ │ │ │ │ │ │ └─ Unbroadcast{0} [id M] - │ │ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id N] - │ │ │ │ │ │ │ │ └─ Second [id O] - │ │ │ │ │ │ │ │ ├─ A [id P] - │ │ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id Q] - │ │ │ │ │ │ │ │ └─ 1.0 [id R] - │ │ │ │ │ │ │ └─ 0 [id S] - │ │ │ │ │ │ └─ Subtensor{i} [id T] + │ │ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id M] + │ │ │ │ │ │ │ │ └─ Second [id N] + │ │ │ │ │ │ │ │ ├─ A [id O] + │ │ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id P] + │ │ │ │ │ │ │ │ └─ 1.0 [id Q] + │ │ │ │ │ │ │ └─ 0 [id R] + │ │ │ │ │ │ └─ Subtensor{i} [id S] │ │ │ │ │ │ ├─ Shape [id L] │ │ │ │ │ │ │ └─ ··· - │ │ │ │ │ │ └─ 1 [id U] - │ │ │ │ │ ├─ Unbroadcast{0} [id M] + │ │ │ │ │ │ └─ 1 [id T] + │ │ │ │ │ ├─ ExpandDims{axis=0} [id M] │ │ │ │ │ │ └─ ··· - │ │ │ │ │ └─ ScalarFromTensor [id V] + │ │ │ │ │ └─ ScalarFromTensor [id U] │ │ │ │ │ └─ Subtensor{i} [id K] │ │ │ │ │ └─ ··· - │ │ │ │ └─ A [id P] (outer_in_non_seqs-0) - │ │ │ └─ 0 [id W] - │ │ └─ 1 [id X] - │ ├─ Subtensor{:stop} [id Y] (outer_in_seqs-0) - │ │ ├─ Subtensor{::step} [id Z] - │ │ │ ├─ Subtensor{:stop} [id BA] + │ │ │ │ └─ A [id O] (outer_in_non_seqs-0) + │ │ │ └─ 0 [id V] + │ │ └─ 1 [id W] + │ ├─ Subtensor{:stop} [id X] (outer_in_seqs-0) + │ │ ├─ Subtensor{::step} [id Y] + │ │ │ ├─ Subtensor{:stop} [id Z] │ │ │ │ ├─ Scan{scan_fn, while_loop=False, inplace=none} [id F] (outer_out_sit_sot-0) │ │ │ │ │ └─ ··· - │ │ │ │ └─ -1 [id BB] - │ │ │ └─ -1 [id BC] - │ │ └─ ScalarFromTensor [id BD] + │ │ │ │ └─ -1 [id BA] + │ │ │ └─ -1 [id BB] + │ │ └─ ScalarFromTensor [id BC] │ │ └─ Sub [id C] │ │ └─ ··· - │ ├─ Subtensor{:stop} [id BE] (outer_in_seqs-1) - │ │ ├─ Subtensor{:stop} [id BF] - │ │ │ ├─ Subtensor{::step} [id BG] + │ ├─ Subtensor{:stop} [id BD] (outer_in_seqs-1) + │ │ ├─ Subtensor{:stop} [id BE] + │ │ │ ├─ Subtensor{::step} [id BF] │ │ │ │ ├─ Scan{scan_fn, while_loop=False, inplace=none} [id F] (outer_out_sit_sot-0) │ │ │ │ │ └─ ··· - │ │ │ │ └─ -1 [id BH] - │ │ │ └─ -1 [id BI] - │ │ └─ ScalarFromTensor [id BJ] + │ │ │ │ └─ -1 [id BG] + │ │ │ └─ -1 [id BH] + │ │ └─ ScalarFromTensor [id BI] │ │ └─ Sub [id C] │ │ └─ ··· - │ ├─ Subtensor{::step} [id BK] (outer_in_mit_mot-0) - │ │ ├─ IncSubtensor{start:} [id BL] - │ │ │ ├─ Second [id BM] + │ ├─ Subtensor{::step} [id BJ] (outer_in_mit_mot-0) + │ │ ├─ IncSubtensor{start:} [id BK] + │ │ │ ├─ Second [id BL] │ │ │ │ ├─ Scan{scan_fn, while_loop=False, inplace=none} [id F] (outer_out_sit_sot-0) │ │ │ │ │ └─ ··· - │ │ │ │ └─ ExpandDims{axes=[0, 1]} [id BN] - │ │ │ │ └─ 0.0 [id BO] - │ │ │ ├─ IncSubtensor{i} [id BP] - │ │ │ │ ├─ Second [id BQ] - │ │ │ │ │ ├─ Subtensor{start:} [id BR] + │ │ │ │ └─ ExpandDims{axes=[0, 1]} [id BM] + │ │ │ │ └─ 0.0 [id BN] + │ │ │ ├─ IncSubtensor{i} [id BO] + │ │ │ │ ├─ Second [id BP] + │ │ │ │ │ ├─ Subtensor{start:} [id BQ] │ │ │ │ │ │ ├─ Scan{scan_fn, while_loop=False, inplace=none} [id F] (outer_out_sit_sot-0) │ │ │ │ │ │ │ └─ 
··· - │ │ │ │ │ │ └─ 1 [id BS] - │ │ │ │ │ └─ ExpandDims{axes=[0, 1]} [id BT] - │ │ │ │ │ └─ 0.0 [id BU] - │ │ │ │ ├─ Second [id BV] - │ │ │ │ │ ├─ Subtensor{i} [id BW] - │ │ │ │ │ │ ├─ Subtensor{start:} [id BR] + │ │ │ │ │ │ └─ 1 [id BR] + │ │ │ │ │ └─ ExpandDims{axes=[0, 1]} [id BS] + │ │ │ │ │ └─ 0.0 [id BT] + │ │ │ │ ├─ Second [id BU] + │ │ │ │ │ ├─ Subtensor{i} [id BV] + │ │ │ │ │ │ ├─ Subtensor{start:} [id BQ] │ │ │ │ │ │ │ └─ ··· - │ │ │ │ │ │ └─ -1 [id BX] - │ │ │ │ │ └─ ExpandDims{axis=0} [id BY] - │ │ │ │ │ └─ Second [id BZ] - │ │ │ │ │ ├─ Sum{axes=None} [id CA] - │ │ │ │ │ │ └─ Subtensor{i} [id BW] + │ │ │ │ │ │ └─ -1 [id BW] + │ │ │ │ │ └─ ExpandDims{axis=0} [id BX] + │ │ │ │ │ └─ Second [id BY] + │ │ │ │ │ ├─ Sum{axes=None} [id BZ] + │ │ │ │ │ │ └─ Subtensor{i} [id BV] │ │ │ │ │ │ └─ ··· - │ │ │ │ │ └─ 1.0 [id CB] - │ │ │ │ └─ -1 [id BX] - │ │ │ └─ 1 [id BS] - │ │ └─ -1 [id CC] - │ ├─ Alloc [id CD] (outer_in_sit_sot-0) - │ │ ├─ 0.0 [id CE] - │ │ ├─ Add [id CF] + │ │ │ │ │ └─ 1.0 [id CA] + │ │ │ │ └─ -1 [id BW] + │ │ │ └─ 1 [id BR] + │ │ └─ -1 [id CB] + │ ├─ Alloc [id CC] (outer_in_sit_sot-0) + │ │ ├─ 0.0 [id CD] + │ │ ├─ Add [id CE] │ │ │ ├─ Sub [id C] │ │ │ │ └─ ··· - │ │ │ └─ 1 [id CG] - │ │ └─ Subtensor{i} [id CH] - │ │ ├─ Shape [id CI] - │ │ │ └─ A [id P] - │ │ └─ 0 [id CJ] - │ └─ A [id P] (outer_in_non_seqs-0) - └─ -1 [id CK] + │ │ │ └─ 1 [id CF] + │ │ └─ Subtensor{i} [id CG] + │ │ ├─ Shape [id CH] + │ │ │ └─ A [id O] + │ │ └─ 0 [id CI] + │ └─ A [id O] (outer_in_non_seqs-0) + └─ -1 [id CJ] Inner graphs: Scan{grad_of_scan_fn, while_loop=False, inplace=none} [id B] - ← Add [id CL] (inner_out_mit_mot-0-0) - ├─ Mul [id CM] - │ ├─ *2- [id CN] -> [id BK] (inner_in_mit_mot-0-0) - │ └─ *5- [id CO] -> [id P] (inner_in_non_seqs-0) - └─ *3- [id CP] -> [id BK] (inner_in_mit_mot-0-1) - ← Add [id CQ] (inner_out_sit_sot-0) - ├─ Mul [id CR] - │ ├─ *2- [id CN] -> [id BK] (inner_in_mit_mot-0-0) - │ └─ *0- [id CS] -> [id Y] (inner_in_seqs-0) - └─ *4- [id CT] -> [id CD] (inner_in_sit_sot-0) + ← Add [id CK] (inner_out_mit_mot-0-0) + ├─ Mul [id CL] + │ ├─ *2- [id CM] -> [id BJ] (inner_in_mit_mot-0-0) + │ └─ *5- [id CN] -> [id O] (inner_in_non_seqs-0) + └─ *3- [id CO] -> [id BJ] (inner_in_mit_mot-0-1) + ← Add [id CP] (inner_out_sit_sot-0) + ├─ Mul [id CQ] + │ ├─ *2- [id CM] -> [id BJ] (inner_in_mit_mot-0-0) + │ └─ *0- [id CR] -> [id X] (inner_in_seqs-0) + └─ *4- [id CS] -> [id CC] (inner_in_sit_sot-0) Scan{scan_fn, while_loop=False, inplace=none} [id F] - ← Mul [id CU] (inner_out_sit_sot-0) - ├─ *0- [id CS] -> [id H] (inner_in_sit_sot-0) - └─ *1- [id CV] -> [id P] (inner_in_non_seqs-0) + ← Mul [id CT] (inner_out_sit_sot-0) + ├─ *0- [id CR] -> [id H] (inner_in_sit_sot-0) + └─ *1- [id CU] -> [id O] (inner_in_non_seqs-0) """ for truth, out in zip(expected_output.split("\n"), lines, strict=True): diff --git a/tests/scan/test_rewriting.py b/tests/scan/test_rewriting.py index e9a6d437ca..1b687afcdc 100644 --- a/tests/scan/test_rewriting.py +++ b/tests/scan/test_rewriting.py @@ -1621,7 +1621,7 @@ def test_while_scan_taps_and_map(self): np.testing.assert_allclose(f(x0=0, seq=test_seq, n_steps=200), 100) np.testing.assert_allclose(f(x0=1, seq=test_seq, n_steps=20), 21) np.testing.assert_allclose(f(x0=np.e, seq=test_seq, n_steps=1), np.e + 1) - with pytest.raises(AssertionError, match="n_steps > 0"): + with pytest.raises((AssertionError, IndexError)): f(x0=0, seq=test_seq, n_steps=0) # Evaluate the shape of ys_trace and len_zs to confirm the rewrite worked correctly. 
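The hunk that follows replaces `test_vector_zeros_init` with a parametrized regression test for scan initial states that are broadcast from a smaller value, matching the new `_is_default_scan_buffer` guard in `pytensor/scan/rewriting.py`. A minimal sketch of the user-facing pattern being guarded (illustrative only, not part of the patch; names `val` and `init` mirror the new test, and the exact buffer sizes depend on the save-memory rewrite):

import numpy as np
import pytensor
import pytensor.tensor as pt

# A scalar value broadcast into the two initial taps of a scan: the case
# the save-memory rewrite used to mishandle when shrinking the buffer.
val = pt.scalar("val")
init = pt.full((2,), val)  # broadcasted initial state for taps (-2, -1)

ys, _ = pytensor.scan(
    fn=lambda ytm2, ytm1: ytm2 + ytm1,
    outputs_info=[{"initial": init, "taps": (-2, -1)}],
    n_steps=100,
)

# Only the last 50 steps are requested, so the rewrite can trim the buffer.
f = pytensor.function([val], ys[-50:])
assert f(np.array(0.0)).shape == (50,)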
@@ -1634,21 +1634,33 @@ def test_while_scan_taps_and_map(self): assert stored_ys_steps == 2 assert stored_zs_steps == 1 - def test_vector_zeros_init(self): + @pytest.mark.parametrize("val_ndim", (0, 1)) + @pytest.mark.parametrize("keep_beginning", (False, True)) + def test_broadcasted_init(self, keep_beginning, val_ndim): + # Regression test when the original value is a broadcasted alloc + # The scan save mem rewrite used to wrongly slice on the unbroadcasted value + val_shape = (1,) * val_ndim + val = pt.tensor("val", shape=val_shape) + val_test = np.zeros(val_shape, dtype=val.dtype) + + init = pt.full((2,), val) ys, _ = pytensor.scan( - fn=lambda ytm2, ytm1: ytm1 + ytm2, - outputs_info=[{"initial": pt.zeros(2), "taps": range(-2, 0)}], + fn=lambda *args: pt.add(*args), + outputs_info=[{"initial": init, "taps": (-2, -1)}], n_steps=100, ) - fn = pytensor.function([], ys[-50:], mode=self.mode) - assert tuple(fn().shape) == (50,) + out = ys[:-50] if keep_beginning else ys[-50:] + fn = pytensor.function([val], out, mode=self.mode) + assert fn(val_test).shape == (50,) # Check that rewrite worked [scan_node] = (n for n in fn.maker.fgraph.apply_nodes if isinstance(n.op, Scan)) _, ys_trace = scan_node.inputs - debug_fn = pytensor.function([], ys_trace.shape[0], accept_inplace=True) - assert debug_fn() == 50 + buffer_size_fn = pytensor.function( + [val], ys_trace.shape[0], accept_inplace=True + ) + assert buffer_size_fn(val_test) == 52 if keep_beginning else 50 def test_inner_replace_dot(): diff --git a/tests/tensor/rewriting/test_basic.py b/tests/tensor/rewriting/test_basic.py index ac8576a8a1..1730ae46ac 100644 --- a/tests/tensor/rewriting/test_basic.py +++ b/tests/tensor/rewriting/test_basic.py @@ -77,9 +77,7 @@ Reshape, Shape_i, SpecifyShape, - Unbroadcast, specify_shape, - unbroadcast, ) from pytensor.tensor.subtensor import ( AdvancedIncSubtensor1, @@ -558,48 +556,6 @@ def test_local_useless_tile(self): f(data) -class TestUnbroadcast: - def setup_method(self): - self.mode = get_default_mode().including("canonicalize") - - def test_local_useless_unbroadcast(self): - x1 = tensor(dtype="float64", shape=(1, 2)) - x2 = tensor(dtype="float64", shape=(2, 1)) - unbroadcast_op = Unbroadcast(0) - - f = function([x1], unbroadcast_op(x1), mode=self.mode) - assert ( - sum(isinstance(node.op, Unbroadcast) for node in f.maker.fgraph.toposort()) - == 1 - ) - - f = function([x2], unbroadcast_op(x2), mode=self.mode) - assert ( - sum(isinstance(node.op, Unbroadcast) for node in f.maker.fgraph.toposort()) - == 0 - ) - - def test_local_unbroadcast_lift(self): - x = tensor(dtype="float64", shape=(1, 1)) - y = unbroadcast(pt.exp(unbroadcast(x, 0)), 1) - - assert ( - sum( - isinstance(node.op, Unbroadcast) - for node in FunctionGraph([x], [y], copy_inputs=False).toposort() - ) - == 2 - ) - - f = function([x], y, mode=self.mode) - assert ( - sum(isinstance(node.op, Unbroadcast) for node in f.maker.fgraph.toposort()) - == 1 - ) - - np.testing.assert_almost_equal(f([[1]]), np.exp([[1]])) - - class TestUselessElemwise: def setup_method(self): self.mode = get_default_mode().including("canonicalize", "local_fill_to_alloc") diff --git a/tests/tensor/rewriting/test_subtensor.py b/tests/tensor/rewriting/test_subtensor.py index fcfd72ddf2..0f0ec55695 100644 --- a/tests/tensor/rewriting/test_subtensor.py +++ b/tests/tensor/rewriting/test_subtensor.py @@ -28,7 +28,6 @@ ) from pytensor.tensor.shape import ( SpecifyShape, - Unbroadcast, _shape, shape, specify_shape, @@ -55,7 +54,6 @@ lscalar, lscalars, matrix, - row, scalar, 
tensor, tensor3, @@ -921,64 +919,6 @@ def test_basic_7(self): assert len(prog) == 2 f([1, 2, 3], 4) # let debugmode test something - def test_basic_8(self): - # Test that Subtensor(Unbroadcast(x)) gets optimized into - # Unbroadcast(Subtensor(x)). - - # test basic case - x = row("x") - xval = np.random.random((1, 10)).astype(config.floatX) - assert x.broadcastable == (True, False) - newx = Unbroadcast(0)(x) - assert newx.broadcastable == (False, False) - - f1 = function([x], newx[:2, :5], mode=mode_opt) - # Check stacktrace was copied over correctly after opt was applied - assert check_stack_trace(f1, ops_to_check=[Subtensor, Unbroadcast]) - prog = f1.maker.fgraph.toposort() - assert isinstance(prog[0].op, Subtensor) - assert isinstance(prog[1].op, Unbroadcast) - assert (f1(xval) == xval[:2, :5]).all() - - # corner case 1: Unbroadcast changes dims which are dropped through subtensor - y = tensor(dtype="float64", shape=(1, 10, 1, 3), name="x") - yval = np.random.random((1, 10, 1, 3)).astype(config.floatX) - assert y.broadcastable == (True, False, True, False) - newy = Unbroadcast(0, 2)(y) - assert newy.broadcastable == (False, False, False, False) - - f2 = function([y], newy[:, 3, 0, :], mode=mode_opt) - # Check stacktrace was copied over correctly after opt was applied - assert check_stack_trace(f2, ops_to_check=[Subtensor, Unbroadcast]) - prog = f2.maker.fgraph.toposort() - assert isinstance(prog[0].op, Subtensor) - assert isinstance(prog[1].op, Unbroadcast) - assert (f2(yval) == yval[:, 3, 0, :]).all() - - # corner case 2: subtensor idx_list is shorter than resulting broadcast pattern - f3 = function([y], newy[:, 3, 0], mode=mode_opt) - # Check stacktrace was copied over correctly after opt was applied - assert check_stack_trace(f3, ops_to_check=[Subtensor, Unbroadcast]) - prog = f3.maker.fgraph.toposort() - assert isinstance(prog[0].op, Subtensor) - assert isinstance(prog[1].op, Unbroadcast) - assert (f3(yval) == yval[:, 3, 0]).all() - - # corner case 3: subtensor idx_list is shorter than Unbroadcast.axis - z = tensor(dtype="float64", shape=(4, 10, 3, 1), name="x") - zval = np.random.random((4, 10, 3, 1)).astype(config.floatX) - assert z.broadcastable == (False, False, False, True) - newz = Unbroadcast(3)(z) - assert newz.broadcastable == (False, False, False, False) - - f4 = function([z], newz[:, 3, 0], mode=mode_opt) - # Check stacktrace was copied over correctly after opt was applied - assert check_stack_trace(f4, ops_to_check=[Subtensor, Unbroadcast]) - prog = f4.maker.fgraph.toposort() - assert isinstance(prog[0].op, Subtensor) - assert isinstance(prog[1].op, Unbroadcast) - assert (f4(zval) == zval[:, 3, 0]).all() - class TestLocalSubtensorMerge: def setup_method(self): diff --git a/tests/tensor/test_basic.py b/tests/tensor/test_basic.py index 1186aeb35c..9be5044f95 100644 --- a/tests/tensor/test_basic.py +++ b/tests/tensor/test_basic.py @@ -287,7 +287,7 @@ def _numpy_second(x, y): ), ) -# unbroadcast a row to a matrix +# broadcast a row to a matrix TestAllocb1GradBroadcast = makeBroadcastTester( name="Allocb1GradTester", op=lambda x: alloc(x, s1, s2), @@ -299,7 +299,7 @@ def _numpy_second(x, y): ), ) -# unbroadcast a row to a tensor3 +# broadcast a row to a tensor3 TestAllocb2GradBroadcast = makeBroadcastTester( name="Allocb2GradTester", op=lambda x: alloc(x, s1, s2, s3), @@ -311,7 +311,7 @@ def _numpy_second(x, y): ), ) -# unbroadcast a col to a matrix +# broadcast a col to a matrix TestAllocb3GradBroadcast = makeBroadcastTester( name="Allocb3GradTester", op=lambda x: alloc(x, 
s1, s2), @@ -323,7 +323,7 @@ def _numpy_second(x, y): ), ) -# unbroadcast a col to a tensor3 +# broadcast a col to a tensor3 TestAllocb4GradBroadcast = makeBroadcastTester( name="Allocb4GradTester", op=lambda x: alloc(x, s1, s2, s3), @@ -336,7 +336,7 @@ def _numpy_second(x, y): ) -# Partial unbroadcast of a dimshuffled input +# Partial broadcast of a dimshuffled input TestAllocDimshuffleGradBroadcast = makeBroadcastTester( name="Allocb4GradTester", op=lambda x: alloc(x.dimshuffle("x", "x", 0), 1, s2, s3), diff --git a/tests/tensor/test_shape.py b/tests/tensor/test_shape.py index 090819f349..2b37eada72 100644 --- a/tests/tensor/test_shape.py +++ b/tests/tensor/test_shape.py @@ -19,14 +19,12 @@ Shape, Shape_i, SpecifyShape, - Unbroadcast, _specify_shape, reshape, shape, shape_tuple, specify_broadcastable, specify_shape, - unbroadcast, ) from pytensor.tensor.subtensor import Subtensor from pytensor.tensor.type import ( @@ -696,66 +694,6 @@ def test_get_vector_length(): assert get_vector_length(x) == 10 -class TestUnbroadcast: - def test_basic(self): - x = matrix() - assert unbroadcast(x, 0) is x - assert unbroadcast(x, 1) is x - assert unbroadcast(x, 1, 0) is x - assert unbroadcast(x, 0, 1) is x - - x = row() - assert unbroadcast(x, 0) is not x - assert unbroadcast(x, 1) is x - assert unbroadcast(x, 1, 0) is not x - assert unbroadcast(x, 0, 1) is not x - - assert unbroadcast(unbroadcast(x, 0), 0).owner.inputs[0] is x - - def test_infer_shape(self): - x = matrix() - y = unbroadcast(x, 0) - f = pytensor.function([x], y.shape) - assert (f(np.zeros((2, 5), dtype=config.floatX)) == [2, 5]).all() - topo = f.maker.fgraph.toposort() - if config.mode != "FAST_COMPILE": - assert len(topo) == 3 - assert isinstance(topo[0].op, Shape_i) - assert isinstance(topo[1].op, Shape_i) - assert isinstance(topo[2].op, MakeVector) - - x = row() - y = unbroadcast(x, 0) - f = pytensor.function([x], y.shape) - assert (f(np.zeros((1, 5), dtype=config.floatX)) == [1, 5]).all() - topo = f.maker.fgraph.toposort() - if config.mode != "FAST_COMPILE": - assert len(topo) == 2 - assert isinstance(topo[0].op, Shape_i) - assert isinstance(topo[1].op, MakeVector) - - def test_error_checks(self): - with pytest.raises(TypeError, match="needs integer axes"): - Unbroadcast(0.0) - - with pytest.raises(ValueError, match="^Trying to unbroadcast"): - Unbroadcast(1)(vector()) - - -class TestUnbroadcastInferShape(utt.InferShapeTester): - def test_basic(self): - rng = np.random.default_rng(3453) - adtens4 = tensor(dtype="float64", shape=(1, 1, 1, None)) - adtens4_val = rng.random((1, 1, 1, 3)).astype(config.floatX) - self._compile_and_check( - [adtens4], - [Unbroadcast(0, 2)(adtens4)], - [adtens4_val], - Unbroadcast, - warn=False, - ) - - def test_shape_tuple(): x = Variable(MyType2(), None, None) assert shape_tuple(x) == () @@ -882,16 +820,3 @@ def test_specify_shape(self): match="Invalid number of shape arguments passed into vectorize node of SpecifyShape", ): vectorize_node(node, tns, *(5, 3, 2, x)) - - def test_unbroadcast(self): - mat = tensor( - shape=( - 1, - 1, - ) - ) - tns = tensor(shape=(4, 1, 1, 1)) - - node = unbroadcast(mat, 0).owner - vect_node = vectorize_node(node, tns) - assert equal_computations(vect_node.outputs, [unbroadcast(tns, 2)]) diff --git a/tests/test_rop.py b/tests/test_rop.py index b592f557a5..2e7d4691bb 100644 --- a/tests/test_rop.py +++ b/tests/test_rop.py @@ -28,7 +28,6 @@ from pytensor.graph.op import Op from pytensor.tensor.math import argmax, dot from pytensor.tensor.math import max as pt_max -from 
pytensor.tensor.shape import unbroadcast from pytensor.tensor.type import matrix, vector from tests import unittest_tools as utt @@ -252,13 +251,6 @@ def test_dimshuffle(self): # vector self.check_rop_lop(self.x[:4].dimshuffle("x", 0).sum(axis=0), (4,)) - def test_unbroadcast(self): - # I need the sum, because the setup expects the output to be a - # vector - self.check_rop_lop( - unbroadcast(self.x[:4].dimshuffle("x", 0), 0).sum(axis=1), (1,) - ) - def test_join(self): tv = np.asarray(self.rng.uniform(size=(10,)), pytensor.config.floatX) t = pytensor.shared(tv)
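With `Unbroadcast` and `unbroadcast` removed, graphs that previously masked a broadcastable dimension can usually just drop the call: the static length-1 dimension is kept and broadcasting still works. The opposite direction, asserting that a dimension has length 1, remains available through `specify_broadcastable`, which the updated `check_broadcast` message now points users to. A small illustrative sketch of the migration (not part of the patch):

import pytensor.tensor as pt

x = pt.row("x")              # static shape (1, None)

# Before this change one might have written
#   y = pt.unbroadcast(pt.shape_padleft(x), 0)
# to hide the new length-1 dimension. Now the padded dimension
# simply stays broadcastable:
y = pt.shape_padleft(x)      # static shape (1, 1, None)

# Asserting that a dimension *is* length 1 is still done with
# specify_broadcastable:
m = pt.matrix("m")
z = pt.specify_broadcastable(m, 0)   # static shape (1, None)

assert y.type.shape == (1, 1, None)
assert z.type.shape == (1, None)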