Inplace Composite and ScalarLoop Ops with multiple outputs #1322

Merged: 5 commits, Apr 8, 2025
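This PR enables inplace operation for Elemwise ops whose scalar_op is a multi-output Composite or a ScalarLoop, in both the C and Numba backends. As orientation, a minimal sketch of the behavior it enables, adapted from the test added in tests/link/numba/test_elemwise.py below (variable names are illustrative):

import numpy as np

import pytensor.tensor as pt
from pytensor import config, function
from pytensor.scalar import Composite
from pytensor.tensor.elemwise import Elemwise

x = pt.vector()
y = pt.vector()
x_ = pt.scalar_from_tensor(x[0])
y_ = pt.scalar_from_tensor(y[0])

# A scalar Composite with two outputs, mapped elementwise over x and y;
# inplace_pattern={0: 0, 1: 1} asks output i to overwrite input i.
composite_op = Composite([x_, y_], [x_ + 1, y_ + 1])
elemwise_op = Elemwise(composite_op, inplace_pattern={0: 0, 1: 1})
out = elemwise_op(x, y)

fn = function([x, y], out, mode="NUMBA", accept_inplace=True)
x_test = np.array([1, 2, 3], dtype=config.floatX)
y_test = np.array([4, 5, 6], dtype=config.floatX)
out1, out2 = fn(x_test, y_test)
assert out1 is x_test and out2 is y_test  # outputs reuse the input buffers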
2 changes: 1 addition & 1 deletion pytensor/link/numba/dispatch/vectorize_codegen.py
@@ -265,7 +265,7 @@ def codegen(
            ctx.nrt.incref(
                builder,
                sig.return_type.types[inplace_idx],
-                outputs[inplace_idx]._get_value(),
+                outputs[inplace_idx]._getvalue(),
            )
        return ctx.make_tuple(
            builder, sig.return_type, [out._getvalue() for out in outputs]
48 changes: 19 additions & 29 deletions pytensor/scalar/basic.py
@@ -1302,19 +1302,7 @@ def __hash__(self):
    def __str__(self):
        if hasattr(self, "name") and self.name:
            return self.name
-        else:
-            param = [
-                (k, v)
-                for k, v in self.__dict__.items()
-                if k
-                not in ("name", "_op_use_c_code", "bool", "output_types_preference")
-            ]
-            if param:
-                classname = self.__class__.__name__
-                args = ", ".join(f"{k}={v}" for k, v in param)
-                return f"{classname}{{{args}}}"
-            else:
-                return self.__class__.__name__
+        return self.__class__.__name__

    def c_code_cache_version(self):
        return (4,)
@@ -4102,6 +4090,7 @@ class ScalarInnerGraphOp(ScalarOp, HasInnerGraph):

    def __init__(self, *args, **kwargs):
        self.prepare_node_called = set()
+        super().__init__(*args, **kwargs)

    def _cleanup_graph(self, inputs, outputs):
        # TODO: We could convert to TensorVariable, optimize graph,
@@ -4441,16 +4430,12 @@ def c_code_template(self):
        if hasattr(self, "_c_code"):
            return self._c_code

-        subd = dict(
-            chain(
-                ((e, f"%(i{int(i)})s") for i, e in enumerate(self.fgraph.inputs)),
-                ((e, f"%(o{int(i)})s") for i, e in enumerate(self.fgraph.outputs)),
-            )
-        )
+        fg = self.fgraph
+        subd = {e: f"%(i{int(i)})s" for i, e in enumerate(fg.inputs)}

-        for var in self.fgraph.variables:
+        for var in fg.variables:
            if var.owner is None:
-                if var not in self.fgraph.inputs:
+                if var not in fg.inputs:
                    # This is an orphan
                    if isinstance(var, Constant) and isinstance(var.type, CLinkerType):
                        subd[var] = f"({var.type.c_literal(var.data)})"
@@ -4465,30 +4450,35 @@ def c_code_template(self):
                # flag for elemwise ops to check.
                self.inner_float16 = True

-        _c_code = "{\n"
-        self.nodenames = [
-            f"%(nodename)s_subnode{int(j)}"
-            for j, n in enumerate(self.fgraph.toposort())
-        ]
+        self.nodenames = nodenames = []  # Used by self.c_support_code_apply
+
+        _c_code = "{\n"
        i = 0
-        for j, node in enumerate(self.fgraph.toposort()):
+        for j, node in enumerate(fg.toposort()):
            for output in node.outputs:
                if output not in subd:
                    i += 1
                    name = f"V%(id)s_tmp{int(i)}"
                    subd[output] = name
                    _c_code += f"{output.type.dtype_specs()[1]} {name};\n"

+            nodename = f"%(nodename)s_subnode{int(j)}"
+            nodenames.append(nodename)

            s = node.op.c_code(
                node,
-                self.nodenames[j],
+                nodename,
                [subd[input] for input in node.inputs],
                [subd[output] for output in node.outputs],
                dict(fail="%(fail)s", id=f"%(id)s_{int(j)}"),
            )
            _c_code += s
            _c_code += "\n"

+        # Copy the temporary outputs to the real outputs
+        for i, output in enumerate(fg.outputs):
+            _c_code += f"%(o{int(i)})s = {subd[output]};\n"
+
        _c_code += "}\n"

        self._c_code = _c_code
@@ -4512,7 +4502,7 @@ def c_code(self, node, nodename, inames, onames, sub):
        return self.c_code_template % d

    def c_code_cache_version_outer(self) -> tuple[int, ...]:
-        return (5,)
+        return (6,)


class Compositef32:
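The substantive change in basic.py above: Composite C code now computes every inner node into a temporary and assigns the real outputs only at the end (the "Copy the temporary outputs to the real outputs" block). A rough Python analogue of the emitted C for a two-output Composite, illustrative only:

# Hypothetical analogue of the C generated for Composite([x_, y_], [x_ + 1, y_ + 1])
def composite_reference(i0, i1):
    tmp1 = i0 + 1  # V%(id)s_tmp1: inner nodes write only temporaries
    tmp2 = i1 + 1  # V%(id)s_tmp2
    # Outputs are assigned last, so an output buffer that aliases an input
    # (inplace) can no longer clobber a value the inner graph still needs.
    return tmp1, tmp2  # %(o0)s = tmp1; %(o1)s = tmp2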
62 changes: 30 additions & 32 deletions pytensor/scalar/loop.py
@@ -55,6 +55,7 @@ def __init__(
        constant: Sequence[Variable] | None = None,
        until: Variable | None = None,
        name="ScalarLoop",
+        **kwargs,
    ):
        if constant is None:
            constant = []
@@ -75,7 +76,7 @@ def __init__(
        self.nout = len(self.outputs)
        self.name = name

-        super().__init__()
+        super().__init__(**kwargs)

    def output_types(self, input_types):
        return self.outputs_type
@@ -115,7 +116,7 @@ def fgraph(self):
        self._fgraph = fgraph
        return self._fgraph

-    def clone(self):
+    def clone(self, name=None, **kwargs):
Review comment (Member): unrelated but I just checked and the signature of clone varies quite a bit over the codebase. That's pretty maddening!

Reply (Member Author): Yeah, I don't think it's standardized, unlike node.clone
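With this change, clone accepts an optional new name and forwards extra kwargs to __init__, which is what lets make_new_inplace below collapse to a one-liner. A sketch of the resulting call patterns (loop_op stands for an arbitrary existing ScalarLoop instance, named here for illustration):

same_op = loop_op.clone()                   # keeps self.name and attributes
renamed_op = loop_op.clone(name="my_loop")  # overrides only the name
# make_new_inplace now simply forwards to clone:
inplace_op = loop_op.make_new_inplace(
    output_types_preference=None,  # passed through **kwargs to __init__
    name="my_loop_inplace",
)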

        if self.is_while:
            *update, until = self.outputs
        else:
@@ -127,28 +128,16 @@ def clone(self):
            update=update,
            constant=constant,
            until=until,
-            name=self.name,
+            name=self.name if name is None else name,
+            **kwargs,
        )

    @property
    def fn(self):
        raise NotImplementedError

    def make_new_inplace(self, output_types_preference=None, name=None):
-        """
-        This op.__init__ fct don't have the same parameter as other scalar op.
-        This break the insert_inplace_optimizer optimization.
-        This fct allow fix patch this.
-
-        """
-        d = {k: getattr(self, k) for k in self.init_param}
-        out = self.__class__(**d)
-        if name:
-            out.name = name
-        else:
-            name = out.name
-        super(ScalarLoop, out).__init__(output_types_preference, name)
-        return out
+        return self.clone(output_types_preference=output_types_preference, name=name)

    def make_node(self, n_steps, *inputs):
        assert len(inputs) == self.nin - 1
@@ -229,11 +218,11 @@ def c_code_template(self):
            c: f"%(i{int(i)})s"
            for i, c in enumerate(fgraph.inputs[n_update:], start=n_update + 1)
        }
-        update_subd = {
+        out_subd = {
            u: f"%(o{int(i)})s" for i, u in enumerate(fgraph.outputs[:n_update])
        }
        until_subd = {u: "until" for u in fgraph.outputs[n_update:]}
-        subd = {**carry_subd, **constant_subd, **update_subd, **until_subd}
+        subd = {**carry_subd, **constant_subd, **until_subd}

        for var in fgraph.variables:
            if var.owner is None:
@@ -257,11 +246,11 @@
        _c_code += "bool until = 1;\n\n"

        # Copy carried inputs
-        for i, (var, name) in enumerate(carry_subd.items()):
-            copy_var_name = f"{name}_copy{i}"
-            _c_code += f"{var.type.dtype_specs()[1]} {copy_var_name} = {name};\n"
-            carry_subd[var] = copy_var_name
-            subd[var] = copy_var_name
+        for i, (var, name) in enumerate(carry_subd.items(), start=1):
+            carry_var_name = f"{name}_carry{i}"
+            _c_code += f"{var.type.dtype_specs()[1]} {carry_var_name} = {name};\n"
+            carry_subd[var] = carry_var_name
+            subd[var] = carry_var_name

        # _c_code += 'printf("inputs=[");'
        # for i in range(1, len(fgraph.inputs)):
@@ -270,9 +259,8 @@

        _c_code += "\nfor(%(n_steps_dtype)s i = 0; i < %(n_steps)s; i++){\n"

-        self.nodenames = [
-            f"%(nodename)s_subnode{int(j)}" for j, n in enumerate(fgraph.toposort())
-        ]
+        # Used by self.c_support_code_apply
+        self.nodenames = nodenames = []

        i = 0
        for j, node in enumerate(fgraph.toposort()):
@@ -282,9 +270,13 @@
                    name = f"V%(id)s_tmp{int(i)}"
                    subd[output] = name
                    _c_code += f"{output.type.dtype_specs()[1]} {name};\n"

+            nodename = f"%(nodename)s_subnode{int(j)}"
+            nodenames.append(nodename)

            s = node.op.c_code(
                node,
-                self.nodenames[j],
+                nodename,
                # Any node that depended on `init` will depend on `update` instead
                # The initial value of `update` was set to `init` before the loop
                [subd[input] for input in node.inputs],
@@ -294,10 +286,12 @@
            _c_code += s
            _c_code += "\n"

-        # Set the carry variables to the output variables
+        # Update the carry variables to the output variables
        _c_code += "\n"
-        for init, update in zip(carry_subd.values(), update_subd.values(), strict=True):
-            _c_code += f"{init} = {update};\n"
+        for carry, out in zip(
+            carry_subd.values(), fgraph.outputs[:n_update], strict=True
+        ):
+            _c_code += f"{carry} = {subd[out]};\n"

        # _c_code += 'printf("%%ld\\n", i);\n'
        # for carry in range(1, 10):
@@ -309,6 +303,10 @@
        # End of the loop
        _c_code += "}\n"

+        # Assign the carry variables to the outputs
+        for out, carry in zip(out_subd.values(), carry_subd.values(), strict=True):
+            _c_code += f"{out} = {carry};\n"
+
        # Output until flag
        if self.is_while:
            _c_code += f"%(o{len(fgraph.outputs)-1})s = until;\n"
@@ -343,4 +341,4 @@ def c_code(self, node, nodename, inames, onames, sub):
        return res

    def c_code_cache_version_outer(self):
-        return (3,)
+        return (4,)
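Net effect of the loop.py changes: the generated C initializes one working variable per carry from the inputs ({name}_carry{i}), updates only those inside the loop, and writes the outputs once after the loop ends, so output storage may safely alias input storage. A rough Python analogue, illustrative only (update_fn stands in for the compiled inner graph):

def scalar_loop_reference(n_steps, *init, update_fn):
    carry = list(init)  # {name}_carry{i}: working copies of the inputs
    for _ in range(n_steps):
        # the inner graph reads and rewrites only the carry variables
        carry = list(update_fn(*carry))
    # outputs are assigned once, after the loop (%(o{i})s = carry),
    # so inplace outputs cannot clobber values still needed mid-loop
    return tuple(carry)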
16 changes: 2 additions & 14 deletions pytensor/tensor/rewriting/elemwise.py
@@ -24,7 +24,6 @@
)
from pytensor.graph.rewriting.db import SequenceDB
from pytensor.graph.utils import InconsistencyError, MethodNotDefined
-from pytensor.scalar.loop import ScalarLoop
from pytensor.scalar.math import Grad2F1Loop, _grad_2f1_loop
from pytensor.tensor.basic import (
MakeVector,
@@ -74,17 +73,6 @@ def print_profile(cls, stream, prof, level=0):
        for n in sorted(ndim):
            print(blanc, n, ndim[n], file=stream)

-    def candidate_input_idxs(self, node):
-        # TODO: Implement specialized InplaceCompositeOptimizer with logic
-        # needed to correctly assign inplace for multi-output Composites
-        # and ScalarLoops
-        if isinstance(node.op.scalar_op, ScalarLoop):
-            return []
-        if isinstance(node.op.scalar_op, ps.Composite) and (len(node.outputs) > 1):
-            return []
-        else:
-            return range(len(node.outputs))

    def apply(self, fgraph):
        r"""

@@ -175,7 +163,7 @@ def apply(self, fgraph):

                baseline = op.inplace_pattern
                candidate_outputs = [
-                    i for i in self.candidate_input_idxs(node) if i not in baseline
+                    i for i in range(len(node.outputs)) if i not in baseline
                ]
                # node inputs that are Constant, already destroyed,
                # or fgraph protected inputs and fgraph outputs can't be used as
@@ -192,7 +180,7 @@
                ]
            else:
                baseline = []
-                candidate_outputs = self.candidate_input_idxs(node)
+                candidate_outputs = range(len(node.outputs))
                # node inputs that are Constant, already destroyed,
                # fgraph protected inputs and fgraph outputs can't be used as inplace
                # target.
26 changes: 24 additions & 2 deletions tests/link/numba/test_elemwise.py
@@ -12,7 +12,7 @@
from pytensor.compile import get_mode
from pytensor.compile.ops import deep_copy_op
from pytensor.gradient import grad
-from pytensor.scalar import float64
+from pytensor.scalar import Composite, float64
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from pytensor.tensor.math import All, Any, Max, Min, Prod, ProdWithoutZeros, Sum
from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
@@ -548,7 +548,7 @@ def test_Argmax(x, axes, exc):
    )


-def test_elemwise_out_type():
+def test_elemwise_inplace_out_type():
    # Create a graph with an elemwise
    # Ravel failes if the elemwise output type is reported incorrectly
    x = pt.matrix()
@@ -563,6 +563,28 @@ def test_elemwise_out_type():
    assert func(x_val).shape == (18,)


+def test_elemwise_multiple_inplace_outs():
+    x = pt.vector()
+    y = pt.vector()
+
+    x_ = pt.scalar_from_tensor(x[0])
+    y_ = pt.scalar_from_tensor(y[0])
+    out_ = x_ + 1, y_ + 1
+
+    composite_op = Composite([x_, y_], out_)
+    elemwise_op = Elemwise(composite_op, inplace_pattern={0: 0, 1: 1})
+    out = elemwise_op(x, y)
+
+    fn = function([x, y], out, mode="NUMBA", accept_inplace=True)
+    x_test = np.array([1, 2, 3], dtype=config.floatX)
+    y_test = np.array([4, 5, 6], dtype=config.floatX)
+    out1, out2 = fn(x_test, y_test)
+    assert out1 is x_test
+    assert out2 is y_test
+    np.testing.assert_allclose(out1, [2, 3, 4])
+    np.testing.assert_allclose(out2, [5, 6, 7])


def test_scalar_loop():
    a = float64("a")
    scalar_loop = pytensor.scalar.ScalarLoop([a], [a + a])