diff --git a/pytensor/link/numba/dispatch/vectorize_codegen.py b/pytensor/link/numba/dispatch/vectorize_codegen.py
index 74870e29bd..e6bd7fa4ca 100644
--- a/pytensor/link/numba/dispatch/vectorize_codegen.py
+++ b/pytensor/link/numba/dispatch/vectorize_codegen.py
@@ -265,7 +265,7 @@ def codegen(
             ctx.nrt.incref(
                 builder,
                 sig.return_type.types[inplace_idx],
-                outputs[inplace_idx]._get_value(),
+                outputs[inplace_idx]._getvalue(),
             )
         return ctx.make_tuple(
             builder, sig.return_type, [out._getvalue() for out in outputs]
diff --git a/pytensor/scalar/basic.py b/pytensor/scalar/basic.py
index 26b551875c..909fc47c27 100644
--- a/pytensor/scalar/basic.py
+++ b/pytensor/scalar/basic.py
@@ -1302,19 +1302,7 @@ def __hash__(self):
     def __str__(self):
         if hasattr(self, "name") and self.name:
            return self.name
-        else:
-            param = [
-                (k, v)
-                for k, v in self.__dict__.items()
-                if k
-                not in ("name", "_op_use_c_code", "bool", "output_types_preference")
-            ]
-            if param:
-                classname = self.__class__.__name__
-                args = ", ".join(f"{k}={v}" for k, v in param)
-                return f"{classname}{{{args}}}"
-            else:
-                return self.__class__.__name__
+        return self.__class__.__name__

     def c_code_cache_version(self):
         return (4,)
@@ -4102,6 +4090,7 @@ class ScalarInnerGraphOp(ScalarOp, HasInnerGraph):

     def __init__(self, *args, **kwargs):
         self.prepare_node_called = set()
+        super().__init__(*args, **kwargs)

     def _cleanup_graph(self, inputs, outputs):
         # TODO: We could convert to TensorVariable, optimize graph,
@@ -4344,9 +4333,9 @@ def __str__(self):

         # Rename internal variables
         for i, r in enumerate(self.fgraph.inputs):
-            r.name = f"i{int(i)}"
+            r.name = f"i{i}"
         for i, r in enumerate(self.fgraph.outputs):
-            r.name = f"o{int(i)}"
+            r.name = f"o{i}"
         io = set(self.fgraph.inputs + self.fgraph.outputs)
         for i, r in enumerate(self.fgraph.variables):
             if (
@@ -4354,7 +4343,7 @@ def __str__(self):
                 and r not in io
                 and len(self.fgraph.clients[r]) > 1
             ):
-                r.name = f"t{int(i)}"
+                r.name = f"t{i}"

         if len(self.fgraph.outputs) > 1 or len(self.fgraph.apply_nodes) > 10:
             self._name = "Composite{...}"
@@ -4441,16 +4430,12 @@ def c_code_template(self):
         if hasattr(self, "_c_code"):
             return self._c_code

-        subd = dict(
-            chain(
-                ((e, f"%(i{int(i)})s") for i, e in enumerate(self.fgraph.inputs)),
-                ((e, f"%(o{int(i)})s") for i, e in enumerate(self.fgraph.outputs)),
-            )
-        )
+        fg = self.fgraph
+        subd = {e: f"%(i{i})s" for i, e in enumerate(fg.inputs)}

-        for var in self.fgraph.variables:
+        for var in fg.variables:
             if var.owner is None:
-                if var not in self.fgraph.inputs:
+                if var not in fg.inputs:
                     # This is an orphan
                     if isinstance(var, Constant) and isinstance(var.type, CLinkerType):
                         subd[var] = f"({var.type.c_literal(var.data)})"
@@ -4465,30 +4450,35 @@ def c_code_template(self):
                 # flag for elemwise ops to check.
                 self.inner_float16 = True

-        _c_code = "{\n"
-        self.nodenames = [
-            f"%(nodename)s_subnode{int(j)}"
-            for j, n in enumerate(self.fgraph.toposort())
-        ]
+        self.nodenames = nodenames = []  # Used by self.c_support_code_apply

+        _c_code = "{\n"
         i = 0
-        for j, node in enumerate(self.fgraph.toposort()):
+        for j, node in enumerate(fg.toposort()):
             for output in node.outputs:
                 if output not in subd:
                     i += 1
-                    name = f"V%(id)s_tmp{int(i)}"
+                    name = f"V%(id)s_tmp{i}"
                     subd[output] = name
                     _c_code += f"{output.type.dtype_specs()[1]} {name};\n"
+
+            nodename = f"%(nodename)s_subnode{j}"
+            nodenames.append(nodename)
+
             s = node.op.c_code(
                 node,
-                self.nodenames[j],
+                nodename,
                 [subd[input] for input in node.inputs],
                 [subd[output] for output in node.outputs],
-                dict(fail="%(fail)s", id=f"%(id)s_{int(j)}"),
+                dict(fail="%(fail)s", id=f"%(id)s_{j}"),
             )
             _c_code += s
             _c_code += "\n"

+        # Copy the temporary outputs to the real outputs
+        for i, output in enumerate(fg.outputs):
+            _c_code += f"%(o{i})s = {subd[output]};\n"
+
         _c_code += "}\n"

         self._c_code = _c_code
@@ -4498,8 +4488,8 @@ def c_code_template(self):
     def c_code(self, node, nodename, inames, onames, sub):
         d = dict(
             chain(
-                zip((f"i{int(i)}" for i in range(len(inames))), inames, strict=True),
-                zip((f"o{int(i)}" for i in range(len(onames))), onames, strict=True),
+                zip((f"i{i}" for i in range(len(inames))), inames, strict=True),
+                zip((f"o{i}" for i in range(len(onames))), onames, strict=True),
             ),
             **sub,
         )
@@ -4512,7 +4502,7 @@ def c_code(self, node, nodename, inames, onames, sub):
         return self.c_code_template % d

     def c_code_cache_version_outer(self) -> tuple[int, ...]:
-        return (5,)
+        return (6,)


 class Compositef32:
diff --git a/pytensor/scalar/loop.py b/pytensor/scalar/loop.py
index 0408cba9b3..1023e6a127 100644
--- a/pytensor/scalar/loop.py
+++ b/pytensor/scalar/loop.py
@@ -55,6 +55,7 @@ def __init__(
         constant: Sequence[Variable] | None = None,
         until: Variable | None = None,
         name="ScalarLoop",
+        **kwargs,
     ):
         if constant is None:
             constant = []
@@ -75,7 +76,7 @@ def __init__(
         self.nout = len(self.outputs)
         self.name = name

-        super().__init__()
+        super().__init__(**kwargs)

     def output_types(self, input_types):
         return self.outputs_type
@@ -115,7 +116,7 @@ def fgraph(self):
         self._fgraph = fgraph
         return self._fgraph

-    def clone(self):
+    def clone(self, name=None, **kwargs):
         if self.is_while:
             *update, until = self.outputs
         else:
@@ -127,7 +128,8 @@ def clone(self):
             update=update,
             constant=constant,
             until=until,
-            name=self.name,
+            name=self.name if name is None else name,
+            **kwargs,
         )

     @property
@@ -135,20 +137,7 @@ def fn(self):
         raise NotImplementedError

     def make_new_inplace(self, output_types_preference=None, name=None):
-        """
-        This op.__init__ fct don't have the same parameter as other scalar op.
-        This break the insert_inplace_optimizer optimization.
-        This fct allow fix patch this.
-
-        """
-        d = {k: getattr(self, k) for k in self.init_param}
-        out = self.__class__(**d)
-        if name:
-            out.name = name
-        else:
-            name = out.name
-        super(ScalarLoop, out).__init__(output_types_preference, name)
-        return out
+        return self.clone(output_types_preference=output_types_preference, name=name)

     def make_node(self, n_steps, *inputs):
         assert len(inputs) == self.nin - 1
@@ -223,17 +212,15 @@ def c_code_template(self):
         # The first input is `n_steps` so we skip it in the mapping dictionary
         n_update = len(self.outputs) - (1 if self.is_while else 0)
         carry_subd = {
-            c: f"%(i{int(i)})s" for i, c in enumerate(fgraph.inputs[:n_update], start=1)
+            c: f"%(i{i})s" for i, c in enumerate(fgraph.inputs[:n_update], start=1)
         }
         constant_subd = {
-            c: f"%(i{int(i)})s"
+            c: f"%(i{i})s"
             for i, c in enumerate(fgraph.inputs[n_update:], start=n_update + 1)
         }
-        update_subd = {
-            u: f"%(o{int(i)})s" for i, u in enumerate(fgraph.outputs[:n_update])
-        }
+        out_subd = {u: f"%(o{i})s" for i, u in enumerate(fgraph.outputs[:n_update])}
         until_subd = {u: "until" for u in fgraph.outputs[n_update:]}
-        subd = {**carry_subd, **constant_subd, **update_subd, **until_subd}
+        subd = {**carry_subd, **constant_subd, **until_subd}

         for var in fgraph.variables:
             if var.owner is None:
@@ -257,11 +244,11 @@ def c_code_template(self):
         _c_code += "bool until = 1;\n\n"

         # Copy carried inputs
-        for i, (var, name) in enumerate(carry_subd.items()):
-            copy_var_name = f"{name}_copy{i}"
-            _c_code += f"{var.type.dtype_specs()[1]} {copy_var_name} = {name};\n"
-            carry_subd[var] = copy_var_name
-            subd[var] = copy_var_name
+        for i, (var, name) in enumerate(carry_subd.items(), start=1):
+            carry_var_name = f"{name}_carry{i}"
+            _c_code += f"{var.type.dtype_specs()[1]} {carry_var_name} = {name};\n"
+            carry_subd[var] = carry_var_name
+            subd[var] = carry_var_name

         # _c_code += 'printf("inputs=[");'
         # for i in range(1, len(fgraph.inputs)):
@@ -270,34 +257,39 @@ def c_code_template(self):

         _c_code += "\nfor(%(n_steps_dtype)s i = 0; i < %(n_steps)s; i++){\n"

-        self.nodenames = [
-            f"%(nodename)s_subnode{int(j)}" for j, n in enumerate(fgraph.toposort())
-        ]
+        # Used by self.c_support_code_apply
+        self.nodenames = nodenames = []

         i = 0
         for j, node in enumerate(fgraph.toposort()):
             for output in node.outputs:
                 if output not in subd:
                     i += 1
-                    name = f"V%(id)s_tmp{int(i)}"
+                    name = f"V%(id)s_tmp{i}"
                     subd[output] = name
                     _c_code += f"{output.type.dtype_specs()[1]} {name};\n"
+
+            nodename = f"%(nodename)s_subnode{j}"
+            nodenames.append(nodename)
+
             s = node.op.c_code(
                 node,
-                self.nodenames[j],
+                nodename,
                 # Any node that depended on `init` will depend on `update` instead
                 # The initial value of `update` was set to `init` before the loop
                 [subd[input] for input in node.inputs],
                 [subd[output] for output in node.outputs],
-                dict(fail="%(fail)s", id=f"%(id)s_{int(j)}"),
+                dict(fail="%(fail)s", id=f"%(id)s_{j}"),
             )
             _c_code += s
             _c_code += "\n"

-        # Set the carry variables to the output variables
+        # Update the carry variables to the output variables
         _c_code += "\n"
-        for init, update in zip(carry_subd.values(), update_subd.values(), strict=True):
-            _c_code += f"{init} = {update};\n"
+        for carry, out in zip(
+            carry_subd.values(), fgraph.outputs[:n_update], strict=True
+        ):
+            _c_code += f"{carry} = {subd[out]};\n"

         # _c_code += 'printf("%%ld\\n", i);\n'
         # for carry in range(1, 10):
@@ -309,6 +301,10 @@ def c_code_template(self):
         # End of the loop
         _c_code += "}\n"

+        # Assign the carry variables to the outputs
+        for out, carry in zip(out_subd.values(), carry_subd.values(), strict=True):
+            _c_code += f"{out} = {carry};\n"
+
         # Output until flag
         if self.is_while:
             _c_code += f"%(o{len(fgraph.outputs)-1})s = until;\n"
@@ -322,8 +318,8 @@ def c_code(self, node, nodename, inames, onames, sub):
         d = dict(
             chain(
-                zip((f"i{int(i)}" for i in range(len(inames))), inames, strict=True),
-                zip((f"o{int(i)}" for i in range(len(onames))), onames, strict=True),
+                zip((f"i{i}" for i in range(len(inames))), inames, strict=True),
+                zip((f"o{i}" for i in range(len(onames))), onames, strict=True),
             ),
             **sub,
         )
@@ -343,4 +339,4 @@ def c_code(self, node, nodename, inames, onames, sub):
         return res

     def c_code_cache_version_outer(self):
-        return (3,)
+        return (4,)
diff --git a/pytensor/tensor/rewriting/elemwise.py b/pytensor/tensor/rewriting/elemwise.py
index eaba64c275..4b5a5075eb 100644
--- a/pytensor/tensor/rewriting/elemwise.py
+++ b/pytensor/tensor/rewriting/elemwise.py
@@ -24,7 +24,6 @@
 )
 from pytensor.graph.rewriting.db import SequenceDB
 from pytensor.graph.utils import InconsistencyError, MethodNotDefined
-from pytensor.scalar.loop import ScalarLoop
 from pytensor.scalar.math import Grad2F1Loop, _grad_2f1_loop
 from pytensor.tensor.basic import (
     MakeVector,
@@ -74,17 +73,6 @@ def print_profile(cls, stream, prof, level=0):
             for n in sorted(ndim):
                 print(blanc, n, ndim[n], file=stream)

-    def candidate_input_idxs(self, node):
-        # TODO: Implement specialized InplaceCompositeOptimizer with logic
-        # needed to correctly assign inplace for multi-output Composites
-        # and ScalarLoops
-        if isinstance(node.op.scalar_op, ScalarLoop):
-            return []
-        if isinstance(node.op.scalar_op, ps.Composite) and (len(node.outputs) > 1):
-            return []
-        else:
-            return range(len(node.outputs))
-
     def apply(self, fgraph):
         r"""

@@ -175,7 +163,7 @@ def apply(self, fgraph):
                     baseline = op.inplace_pattern
                     candidate_outputs = [
-                        i for i in self.candidate_input_idxs(node) if i not in baseline
+                        i for i in range(len(node.outputs)) if i not in baseline
                     ]
                     # node inputs that are Constant, already destroyed,
                     # or fgraph protected inputs and fgraph outputs can't be used as
@@ -192,7 +180,7 @@ def apply(self, fgraph):
                     ]
                 else:
                     baseline = []
-                    candidate_outputs = self.candidate_input_idxs(node)
+                    candidate_outputs = range(len(node.outputs))
                     # node inputs that are Constant, already destroyed,
                     # fgraph protected inputs and fgraph outputs can't be used as inplace
                     # target.
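Note on the approach (illustrative, not part of the patch): removing `candidate_input_idxs` makes multi-output `Composite` and `ScalarLoop` nodes eligible for inplacing again. That is only safe because the C templates above now compute every output into a temporary (or carry) variable and write the real output storage last. A minimal NumPy sketch of the aliasing hazard, using the same graph as `test_fusion_multiout_inplace` below (`fused_unsafe`/`fused_safe` are hypothetical names):

```python
import numpy as np

def fused_unsafe(x):
    # With inplace_pattern={0: 0}, output 0 reuses x's buffer. Writing it
    # first clobbers the `+ x` term that output 1 still needs.
    x[...] = x + 1
    return x, np.cos(x) + x  # reads the already-overwritten x

def fused_safe(x):
    # What the rewritten templates emit: temporaries first, real outputs last.
    tmp0 = x + 1
    tmp1 = np.cos(x + 1) + x  # still sees the original x
    x[...] = tmp0             # only now reuse x's storage for output 0
    return x, tmp1

x = np.array([0.0, 1.0, 2.0])
out0, out1 = fused_safe(x)
assert np.allclose(out1, np.cos([1.0, 2.0, 3.0]) + [0.0, 1.0, 2.0])
```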
diff --git a/tests/link/numba/test_elemwise.py b/tests/link/numba/test_elemwise.py
index 7ef5705c07..25efd69a8d 100644
--- a/tests/link/numba/test_elemwise.py
+++ b/tests/link/numba/test_elemwise.py
@@ -12,7 +12,7 @@
 from pytensor.compile import get_mode
 from pytensor.compile.ops import deep_copy_op
 from pytensor.gradient import grad
-from pytensor.scalar import float64
+from pytensor.scalar import Composite, float64
 from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
 from pytensor.tensor.math import All, Any, Max, Min, Prod, ProdWithoutZeros, Sum
 from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
@@ -548,7 +548,7 @@ def test_Argmax(x, axes, exc):
     )


-def test_elemwise_out_type():
+def test_elemwise_inplace_out_type():
     # Create a graph with an elemwise
     # Ravel fails if the elemwise output type is reported incorrectly
     x = pt.matrix()
@@ -563,6 +563,28 @@ def test_elemwise_out_type():
     assert func(x_val).shape == (18,)


+def test_elemwise_multiple_inplace_outs():
+    x = pt.vector()
+    y = pt.vector()
+
+    x_ = pt.scalar_from_tensor(x[0])
+    y_ = pt.scalar_from_tensor(y[0])
+    out_ = x_ + 1, y_ + 1
+
+    composite_op = Composite([x_, y_], out_)
+    elemwise_op = Elemwise(composite_op, inplace_pattern={0: 0, 1: 1})
+    out = elemwise_op(x, y)
+
+    fn = function([x, y], out, mode="NUMBA", accept_inplace=True)
+    x_test = np.array([1, 2, 3], dtype=config.floatX)
+    y_test = np.array([4, 5, 6], dtype=config.floatX)
+    out1, out2 = fn(x_test, y_test)
+    assert out1 is x_test
+    assert out2 is y_test
+    np.testing.assert_allclose(out1, [2, 3, 4])
+    np.testing.assert_allclose(out2, [5, 6, 7])
+
+
 def test_scalar_loop():
     a = float64("a")
     scalar_loop = pytensor.scalar.ScalarLoop([a], [a + a])
diff --git a/tests/scalar/test_loop.py b/tests/scalar/test_loop.py
index 88d14c6e43..6e46a56cdc 100644
--- a/tests/scalar/test_loop.py
+++ b/tests/scalar/test_loop.py
@@ -3,7 +3,8 @@
 import numpy as np
 import pytest

-from pytensor import Mode, function
+from pytensor import In, Mode, function
+from pytensor.compile import get_default_mode
 from pytensor.scalar import (
     Composite,
     as_scalar,
@@ -18,6 +19,8 @@
 )
 from pytensor.scalar.loop import ScalarLoop
 from pytensor.tensor import exp as tensor_exp
+from pytensor.tensor import lvector
+from pytensor.tensor.elemwise import Elemwise


 mode = pytest.mark.parametrize(
@@ -255,3 +258,45 @@ def test_inner_loop(mode):
         out16,
         3**2 + 2.5,
     )
+
+
+@pytest.mark.parametrize("mutate_arg_idx", (0, 1, 2, 3))
+def test_elemwise_inplace(mutate_arg_idx):
+    x0 = int64("x0")
+    y0 = int64("y0")
+    c = int64("c")
+    x = x0 - y0 + c
+    y = y0 - x0 + c
+    op = Elemwise(ScalarLoop(init=[x0, y0], constant=[c], update=[x, y]))
+
+    n_steps = lvector("n_steps")
+    x0v = lvector("x0")
+    y0v = lvector("y0")
+    cv = lvector("c")
+    xv, yv = op(n_steps, x0v, y0v, cv)
+
+    inputs = [
+        In(inp, mutable=i == mutate_arg_idx)
+        for i, inp in enumerate([n_steps, x0v, y0v, cv])
+    ]
+
+    fn = function(
+        inputs,
+        [xv, yv],
+        mode=get_default_mode().including("inplace"),
+    )
+    elem_op = fn.maker.fgraph.outputs[0].owner.op
+    assert isinstance(elem_op, Elemwise) and isinstance(elem_op.scalar_op, ScalarLoop)
+    destroy_map = elem_op.destroy_map
+    assert destroy_map == {0: [mutate_arg_idx]}
+
+    n_test = np.array([1, 4, 8], dtype="int64")
+    x0v_test = np.array([0, 0, 0], dtype="int64")
+    y0v_test = np.array([1, 1, 1], dtype="int64")
+    cv_test = np.array([0, 0, 0], dtype="int64")
+
+    xv_res, yv_res = fn(n_test, x0v_test, y0v_test, cv_test)
+    # Check the outputs are the destroyed inputs
+    assert xv_res is (n_test, x0v_test, y0v_test, cv_test)[mutate_arg_idx]
+    np.testing.assert_allclose(xv_res, [-1, -8, -128])
+    np.testing.assert_allclose(yv_res, [1, 8, 128])
diff --git a/tests/tensor/rewriting/test_elemwise.py b/tests/tensor/rewriting/test_elemwise.py
index 6fb0594ed5..4e7fe54581 100644
--- a/tests/tensor/rewriting/test_elemwise.py
+++ b/tests/tensor/rewriting/test_elemwise.py
@@ -1104,7 +1104,8 @@ def test_add_mul_fusion_inplace(self):
             np.random.random((5, 5)), np.random.random((5, 5)), np.random.random((5, 5))
         )

-    def test_fusion_multiout_inplace(self):
+    @pytest.mark.parametrize("linker", ["cvm", "py"])
+    def test_fusion_multiout_inplace(self, linker):
         x = vector("x")

         # Create Composite where inplacing the first non-constant output would corrupt the second output
@@ -1118,17 +1119,16 @@ def test_fusion_multiout_inplace(self):
         f = pytensor.function(
             [In(x, mutable=True)],
             outs,
-            mode=self.mode.including("inplace"),
+            mode=Mode(linker=linker, optimizer=self.rewrites.including("inplace")),
         )
         (composite_node,) = f.maker.fgraph.apply_nodes

-        # Destroy map must be None or the last toposorted output
         destroy_map = composite_node.op.destroy_map
-        assert (destroy_map == {}) or (
-            destroy_map == {1: [composite_node.inputs.index(x)]}
-        )
+        assert destroy_map == {0: [0]}

-        res = f([0, 1, 2])
+        inp = np.array([0, 1, 2], dtype=config.floatX)
+        res = f(inp)
+        assert not np.allclose(inp, [0, 1, 2])
         assert np.allclose(res[0], [1, 2, 3])
         assert np.allclose(res[1], np.cos([1, 2, 3]) + np.array([0, 1, 2]))
diff --git a/tests/tensor/test_math_scipy.py b/tests/tensor/test_math_scipy.py
index e15293e979..e7579b10ac 100644
--- a/tests/tensor/test_math_scipy.py
+++ b/tests/tensor/test_math_scipy.py
@@ -431,11 +431,13 @@ def test_gammaincc_ddk_performance(benchmark):
     x = vector("x")

     out = gammaincc(k, x)
-    grad_fn = function([k, x], grad(out.sum(), wrt=[k]), mode="FAST_RUN")
+    grad_fn = function(
+        [k, x], grad(out.sum(), wrt=[k]), mode="FAST_RUN", trust_input=True
+    )
     vals = [
         # Values that hit the second branch of the gradient
-        np.full((1000,), 3.2),
-        np.full((1000,), 0.01),
+        np.full((1000,), 3.2, dtype=k.dtype),
+        np.full((1000,), 0.01, dtype=x.dtype),
     ]

     verify_grad(gammaincc, vals, rng=rng)
@@ -1127,9 +1129,13 @@ def test_benchmark(self, case, wrt, benchmark):
         a1, a2, b1, z = pt.scalars("a1", "a2", "b1", "z")
         hyp2f1_out = pt.hyp2f1(a1, a2, b1, z)
         hyp2f1_grad = pt.grad(hyp2f1_out, wrt=a1 if wrt == "a" else [a1, a2, b1, z])
-        f_grad = function([a1, a2, b1, z], hyp2f1_grad)
+        f_grad = function([a1, a2, b1, z], hyp2f1_grad, trust_input=True)

         (test_a1, test_a2, test_b1, test_z, *expected_dds) = case
+        test_a1 = np.array(test_a1, dtype=a1.dtype)
+        test_a2 = np.array(test_a2, dtype=a2.dtype)
+        test_b1 = np.array(test_b1, dtype=b1.dtype)
+        test_z = np.array(test_z, dtype=z.dtype)

         result = benchmark(f_grad, test_a1, test_a2, test_b1, test_z)
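For reference, a hedged end-to-end sketch of what this change enables (it mirrors the new `test_elemwise_inplace`; variable names are illustrative, and whether the inplace rewrite actually fires depends on the compilation mode):

```python
from pytensor import In, function
from pytensor.compile import get_default_mode
from pytensor.scalar import int64
from pytensor.scalar.loop import ScalarLoop
from pytensor.tensor import lvector
from pytensor.tensor.elemwise import Elemwise

# Two-output scalar loop, broadcast over vectors by Elemwise
x0, y0 = int64("x0"), int64("y0")
op = Elemwise(ScalarLoop(init=[x0, y0], update=[x0 - y0, y0 - x0]))

n_steps, xv, yv = lvector("n_steps"), lvector("x"), lvector("y")
fn = function(
    [In(n_steps), In(xv, mutable=True), In(yv)],
    op(n_steps, xv, yv),
    mode=get_default_mode().including("inplace"),
)
# With the candidate guard removed, output 0 may now destroy the mutable
# input (input index 1), matching the assertion in the new test:
print(fn.maker.fgraph.outputs[0].owner.op.destroy_map)  # expected: {0: [1]}
```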