From 1d93510edce27641bf78316bee6a76c1f377abb1 Mon Sep 17 00:00:00 2001
From: Aarsh-Wankar <23110003@iitgn.ac.in>
Date: Thu, 13 Mar 2025 20:53:46 +0530
Subject: [PATCH 1/8] Refactor infer_shape methods to utilize
 _gufunc_to_out_shape for output shape computation

---
 pytensor/tensor/nlinalg.py | 23 +++++++----------------
 pytensor/tensor/slinalg.py | 20 +++++++-------------
 pytensor/tensor/utils.py   | 33 +++++++++++++++++++++++++++++++++
 3 files changed, 47 insertions(+), 29 deletions(-)

diff --git a/pytensor/tensor/nlinalg.py b/pytensor/tensor/nlinalg.py
index ee33f6533c..5af27a9919 100644
--- a/pytensor/tensor/nlinalg.py
+++ b/pytensor/tensor/nlinalg.py
@@ -17,6 +17,7 @@
 from pytensor.tensor.basic import as_tensor_variable, diagonal
 from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.type import Variable, dvector, lscalar, matrix, scalar, vector
+from pytensor.tensor.utils import _gufunc_to_out_shape
 
 
 class MatrixPinv(Op):
@@ -63,7 +64,7 @@ def L_op(self, inputs, outputs, g_outputs):
         return [grad]
 
     def infer_shape(self, fgraph, node, shapes):
-        return [list(reversed(shapes[0]))]
+        return _gufunc_to_out_shape(self.gufunc_signature, shapes)
 
 
 def pinv(x, hermitian=False):
@@ -156,7 +157,7 @@ def R_op(self, inputs, eval_points):
             return [None]
         return [-matrix_dot(xi, ev, xi)]
 
     def infer_shape(self, fgraph, node, shapes):
-        return shapes
+        return _gufunc_to_out_shape(self.gufunc_signature, shapes)
 
 
 inv = matrix_inverse = Blockwise(MatrixInverse())
@@ -225,7 +226,7 @@ def grad(self, inputs, g_outputs):
         return [gz * self(x) * matrix_inverse(x).T]
 
     def infer_shape(self, fgraph, node, shapes):
-        return [()]
+        return _gufunc_to_out_shape(self.gufunc_signature, shapes)
 
     def __str__(self):
         return "Det"
@@ -259,7 +260,7 @@ def perform(self, node, inputs, outputs):
             raise ValueError("Failed to compute determinant", x) from e
 
     def infer_shape(self, fgraph, node, shapes):
-        return [(), ()]
+        return _gufunc_to_out_shape(self.gufunc_signature, shapes)
 
     def __str__(self):
         return "SLogDet"
@@ -317,8 +318,7 @@ def perform(self, node, inputs, outputs):
         w[0], v[0] = (z.astype(x.dtype) for z in np.linalg.eig(x))
 
     def infer_shape(self, fgraph, node, shapes):
-        n = shapes[0][0]
-        return [(n,), (n, n)]
+        return _gufunc_to_out_shape(self.gufunc_signature, shapes)
 
 
 eig = Blockwise(Eig())
@@ -619,16 +619,7 @@ def perform(self, node, inputs, outputs):
             s[0] = np.linalg.svd(x, self.full_matrices, self.compute_uv)
 
     def infer_shape(self, fgraph, node, shapes):
-        (x_shape,) = shapes
-        M, N = x_shape
-        K = ptm.minimum(M, N)
-        s_shape = (K,)
-        if self.compute_uv:
-            u_shape = (M, M) if self.full_matrices else (M, K)
-            vt_shape = (N, N) if self.full_matrices else (K, N)
-            return [u_shape, s_shape, vt_shape]
-        else:
-            return [s_shape]
+        return _gufunc_to_out_shape(self.gufunc_signature, shapes)
 
     def L_op(
         self,
diff --git a/pytensor/tensor/slinalg.py b/pytensor/tensor/slinalg.py
index a8f9377170..f982e88166 100644
--- a/pytensor/tensor/slinalg.py
+++ b/pytensor/tensor/slinalg.py
@@ -20,6 +20,7 @@
 from pytensor.tensor.nlinalg import kron, matrix_dot
 from pytensor.tensor.shape import reshape
 from pytensor.tensor.type import matrix, tensor, vector
+from pytensor.tensor.utils import _gufunc_to_out_shape
 from pytensor.tensor.variable import TensorVariable
 
 
@@ -51,7 +52,7 @@ def __init__(
             self.destroy_map = {0: [0]}
 
     def infer_shape(self, fgraph, node, shapes):
-        return [shapes[0]]
+        return _gufunc_to_out_shape(self.gufunc_signature, shapes)
 
     def make_node(self, x):
         x = as_tensor_variable(x)
@@ -269,13 +270,7 @@ def make_node(self, A, b):
         return Apply(self, [A, b], [x])
 
     def infer_shape(self, fgraph, node, shapes):
-        Ashape, Bshape = shapes
-        rows = Ashape[1]
-        if len(Bshape) == 1:
-            return [(rows,)]
-        else:
-            cols = Bshape[1]
-            return [(rows, cols)]
+        return _gufunc_to_out_shape(self.gufunc_signature, shapes)
 
     def L_op(self, inputs, outputs, output_gradients):
         r"""Reverse-mode gradient updates for matrix solve operation :math:`c = A^{-1} b`.
@@ -891,7 +886,7 @@ def perform(self, node, inputs, output_storage):
         X[0] = scipy_linalg.solve_continuous_lyapunov(A, B).astype(out_dtype)
 
     def infer_shape(self, fgraph, node, shapes):
-        return [shapes[0]]
+        return _gufunc_to_out_shape(self.gufunc_signature, shapes)
 
     def grad(self, inputs, output_grads):
         # Gradient computations come from Kao and Hennequin (2020), https://arxiv.org/pdf/2011.11430.pdf
@@ -963,7 +958,7 @@ def perform(self, node, inputs, output_storage):
         )
 
     def infer_shape(self, fgraph, node, shapes):
-        return [shapes[0]]
+        return _gufunc_to_out_shape(self.gufunc_signature, shapes)
 
     def grad(self, inputs, output_grads):
         # Gradient computations come from Kao and Hennequin (2020), https://arxiv.org/pdf/2011.11430.pdf
@@ -1083,7 +1078,7 @@ def perform(self, node, inputs, output_storage):
         X[0] = scipy_linalg.solve_discrete_are(A, B, Q, R).astype(out_dtype)
 
     def infer_shape(self, fgraph, node, shapes):
-        return [shapes[0]]
+        return _gufunc_to_out_shape(self.gufunc_signature, shapes)
 
     def grad(self, inputs, output_grads):
         # Gradient computations come from Kao and Hennequin (2020), https://arxiv.org/pdf/2011.11430.pdf
@@ -1181,8 +1176,7 @@ def grad(self, inputs, gout):
         return [gout[0][slc] for slc in slices]
 
     def infer_shape(self, fgraph, nodes, shapes):
-        first, second = zip(*shapes, strict=True)
-        return [(pt.add(*first), pt.add(*second))]
+        return _gufunc_to_out_shape(self.gufunc_signature, shapes)
 
     def _validate_and_prepare_inputs(self, matrices, as_tensor_func):
         if len(matrices) != self.n_inputs:
diff --git a/pytensor/tensor/utils.py b/pytensor/tensor/utils.py
index 0ebb2e5434..11daefd499 100644
--- a/pytensor/tensor/utils.py
+++ b/pytensor/tensor/utils.py
@@ -202,6 +202,39 @@ def _parse_gufunc_signature(
     )
 
 
+def _gufunc_to_out_shape(
+    signature: str, shapes: list[tuple[int, ...]]
+) -> list[tuple[int, ...]]:
+    """
+    Compute the shape of the output of an Op given its gufunc signature and the
+    shapes of its inputs.
+
+    Parameters
+    ----------
+    signature : str
+        The gufunc signature of the Op.
+        e.g. "(m,n),(n,p)->(m,p)".
+
+    shapes : list of tuple of int
+        The list of shapes of the inputs.
+
+    Returns
+    -------
+    out_shape : list of tuple of int
+        The list of shapes of the outputs.
+    """
+    parsed = _parse_gufunc_signature(signature)
+    out_shape = []
+    dic = dict()
+    for i in range(len(parsed[0])):
+        for j in range(len(parsed[0][i])):
+            dic[parsed[0][i][j]] = shapes[i][j]
+    for i in range(len(parsed[1])):
+        temp_list = [dic[x] for x in parsed[1][i]]
+        out_shape.append(tuple(temp_list))
+    return out_shape
+
+
 def safe_signature(
     core_inputs_ndim: Sequence[int],
     core_outputs_ndim: Sequence[int],
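
Note on PATCH 1: the helper does pure name-matching against the signature. A
minimal standalone sketch of the computation (`parse_gufunc_signature` and
`gufunc_to_out_shape` below are illustrative stand-ins, not pytensor code; the
real helper reuses pytensor's `_parse_gufunc_signature`):

    import re

    def parse_gufunc_signature(signature):
        # "(m,n),(n,p)->(m,p)" -> ([("m", "n"), ("n", "p")], [("m", "p")])
        def terms(side):
            return [
                tuple(d for d in term.split(",") if d)
                for term in re.findall(r"\((.*?)\)", side)
            ]
        ins, outs = signature.split("->")
        return terms(ins), terms(outs)

    def gufunc_to_out_shape(signature, shapes):
        # Bind each dimension name to the size it takes in the inputs,
        # then read the output shapes straight off the signature.
        input_sig, output_sig = parse_gufunc_signature(signature)
        dim_to_size = {}
        for shape, names in zip(shapes, input_sig, strict=True):
            for size, name in zip(shape, names, strict=True):
                dim_to_size[name] = size
        return [tuple(dim_to_size[n] for n in names) for names in output_sig]

    # "(m,n),(n,p)->(m,p)" is matmul: (3, 4) @ (4, 5) -> (3, 5)
    assert gufunc_to_out_shape("(m,n),(n,p)->(m,p)", [(3, 4), (4, 5)]) == [(3, 5)]
    # "(m,m)->()" is Det: a square matrix maps to a scalar
    assert gufunc_to_out_shape("(m,m)->()", [(3, 3)]) == [()]

Inside `infer_shape` the sizes are symbolic scalar Variables rather than ints,
which is why PATCH 4 below has to decide which binding to keep when the same
dimension name appears in several inputs.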
+ """ + parsed = _parse_gufunc_signature(signature) + out_shape = [] + dic = dict() + for i in range(len(parsed[0])): + for j in range(len(parsed[0][i])): + dic[parsed[0][i][j]] = shapes[i][j] + for i in range(len(parsed[1])): + temp_list = [dic[x] for x in parsed[1][i]] + out_shape.append(tuple(temp_list)) + return out_shape + + def safe_signature( core_inputs_ndim: Sequence[int], core_outputs_ndim: Sequence[int], From 81899de2dcc6d8b74a678028c727f50d290be1bc Mon Sep 17 00:00:00 2001 From: Aarsh-Wankar <23110003@iitgn.ac.in> Date: Thu, 13 Mar 2025 22:03:26 +0530 Subject: [PATCH 2/8] Redacted changes for Ops with non-conclusive gufunc_signature --- pytensor/tensor/nlinalg.py | 11 ++++++++++- pytensor/tensor/slinalg.py | 5 +++-- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/pytensor/tensor/nlinalg.py b/pytensor/tensor/nlinalg.py index 5af27a9919..78fbe40869 100644 --- a/pytensor/tensor/nlinalg.py +++ b/pytensor/tensor/nlinalg.py @@ -619,7 +619,16 @@ def perform(self, node, inputs, outputs): s[0] = np.linalg.svd(x, self.full_matrices, self.compute_uv) def infer_shape(self, fgraph, node, shapes): - return _gufunc_to_out_shape(self.gufunc_signature, shapes) + (x_shape,) = shapes + M, N = x_shape + K = ptm.minimum(M, N) + s_shape = (K,) + if self.compute_uv: + u_shape = (M, M) if self.full_matrices else (M, K) + vt_shape = (N, N) if self.full_matrices else (K, N) + return [u_shape, s_shape, vt_shape] + else: + return [s_shape] def L_op( self, diff --git a/pytensor/tensor/slinalg.py b/pytensor/tensor/slinalg.py index f982e88166..5382130355 100644 --- a/pytensor/tensor/slinalg.py +++ b/pytensor/tensor/slinalg.py @@ -20,8 +20,8 @@ from pytensor.tensor.nlinalg import kron, matrix_dot from pytensor.tensor.shape import reshape from pytensor.tensor.type import matrix, tensor, vector -from pytensor.tensor.utils import _gufunc_to_out_shape from pytensor.tensor.variable import TensorVariable +from pytensor.tensor.utils import _gufunc_to_out_shape logger = logging.getLogger(__name__) @@ -1176,7 +1176,8 @@ def grad(self, inputs, gout): return [gout[0][slc] for slc in slices] def infer_shape(self, fgraph, nodes, shapes): - return _gufunc_to_out_shape(self.gufunc_signature, shapes) + first, second = zip(*shapes, strict=True) + return [(pt.add(*first), pt.add(*second))] def _validate_and_prepare_inputs(self, matrices, as_tensor_func): if len(matrices) != self.n_inputs: From 7931668f6a83e1d4324911587ea39cba683dbf49 Mon Sep 17 00:00:00 2001 From: Aarsh-Wankar <23110003@iitgn.ac.in> Date: Thu, 13 Mar 2025 22:05:44 +0530 Subject: [PATCH 3/8] fixed ruff format --- pytensor/tensor/slinalg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytensor/tensor/slinalg.py b/pytensor/tensor/slinalg.py index 5382130355..d9a023f0aa 100644 --- a/pytensor/tensor/slinalg.py +++ b/pytensor/tensor/slinalg.py @@ -20,8 +20,8 @@ from pytensor.tensor.nlinalg import kron, matrix_dot from pytensor.tensor.shape import reshape from pytensor.tensor.type import matrix, tensor, vector -from pytensor.tensor.variable import TensorVariable from pytensor.tensor.utils import _gufunc_to_out_shape +from pytensor.tensor.variable import TensorVariable logger = logging.getLogger(__name__) From 870b900d7fc9fb1dc3d64a02152717f3c226090c Mon Sep 17 00:00:00 2001 From: Aarsh-Wankar <23110003@iitgn.ac.in> Date: Fri, 21 Mar 2025 23:51:42 +0530 Subject: [PATCH 4/8] Refactor _gufunc_to_out_shape for giving priority to Constant dimensions along with error handling --- pytensor/tensor/utils.py | 52 

From 870b900d7fc9fb1dc3d64a02152717f3c226090c Mon Sep 17 00:00:00 2001
From: Aarsh-Wankar <23110003@iitgn.ac.in>
Date: Fri, 21 Mar 2025 23:51:42 +0530
Subject: [PATCH 4/8] Refactor _gufunc_to_out_shape to prefer Constant
 dimensions and add error handling

---
 pytensor/tensor/utils.py | 54 ++++++++++++++++++++++++++++++----------
 1 file changed, 40 insertions(+), 14 deletions(-)

diff --git a/pytensor/tensor/utils.py b/pytensor/tensor/utils.py
index 11daefd499..bcaeb5c3be 100644
--- a/pytensor/tensor/utils.py
+++ b/pytensor/tensor/utils.py
@@ -7,6 +7,9 @@
+from typing import Any
+
 import pytensor
 from pytensor.graph import FunctionGraph, Variable
+from pytensor.graph.basic import Constant
 from pytensor.npy_2_compat import normalize_axis_tuple
 from pytensor.utils import hash_from_code
 
 
@@ -203,8 +206,8 @@ def _parse_gufunc_signature(
 
 
 def _gufunc_to_out_shape(
-    signature: str, shapes: list[tuple[int, ...]]
-) -> list[tuple[int, ...]]:
+    signature: str, shapes: list[tuple[Any, ...]]
+) -> list[tuple[Any, ...]]:
     """
     Compute the shape of the output of an Op given its gufunc signature and the
     shapes of its inputs.
@@ -215,24 +218,47 @@ def _gufunc_to_out_shape(
         The gufunc signature of the Op.
         e.g. "(m,n),(n,p)->(m,p)".
 
-    shapes : list of tuple of int
+    shapes : list of tuple of Any
         The list of shapes of the inputs.
 
     Returns
     -------
-    out_shape : list of tuple of int
+    out_shape : list of tuple of Any
         The list of shapes of the outputs.
+
+    Raises
+    ------
+    ValueError
+        If the signature is invalid for the shapes of the inputs.
     """
-    parsed = _parse_gufunc_signature(signature)
-    out_shape = []
-    dic = dict()
-    for i in range(len(parsed[0])):
-        for j in range(len(parsed[0][i])):
-            dic[parsed[0][i][j]] = shapes[i][j]
-    for i in range(len(parsed[1])):
-        temp_list = [dic[x] for x in parsed[1][i]]
-        out_shape.append(tuple(temp_list))
-    return out_shape
+    input_sig, output_sig = _parse_gufunc_signature(signature)
+    dim_to_size: dict[str, Any] = {}
+    for input_shape, sig in zip(shapes, input_sig, strict=True):
+        for size, dim_name in zip(input_shape, sig, strict=True):
+            prev_size = dim_to_size.get(dim_name)
+            if prev_size is None:
+                dim_to_size[dim_name] = size
+            # Prefer constants
+            elif not isinstance(prev_size, Constant):
+                dim_to_size[dim_name] = size
+            elif prev_size.data != size:
+                raise ValueError(
+                    f"Invalid signature {signature} for shapes {shapes}. "
+                    f"Dimension {dim_name} is not consistent across inputs."
+                )
+    out_shapes = []
+    for output_shape in output_sig:
+        temp_list = []
+        for dim in output_shape:
+            if dim not in dim_to_size:
+                raise ValueError(
+                    f"Invalid signature {signature} for shapes {shapes}. "
+                    f"Dimension {dim} not in input dimensions."
+                )
+            else:
+                temp_list.append(dim_to_size[dim])
+        out_shapes.append((*temp_list,))
+    return out_shapes
 
 
 def safe_signature(
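
Note on PATCH 4: during shape inference each size is a symbolic scalar, and the
same dimension name may be bound by several inputs (e.g. `n` in
"(m,n),(n,p)->(m,p)"). Preferring a `Constant` keeps the most informative size
for later rewrites. A standalone sketch of the rule (`Constant` below is a mock
of `pytensor.graph.basic.Constant`, and the string sizes stand in for symbolic
shape variables):

    class Constant:
        def __init__(self, data):
            self.data = data

    def bind(dim_to_size, dim_name, size):
        # Keep the newest size unless the stored one is already a Constant.
        prev_size = dim_to_size.get(dim_name)
        if prev_size is None or not isinstance(prev_size, Constant):
            dim_to_size[dim_name] = size

    sizes = {}
    bind(sizes, "n", "A.shape[1]")  # symbolic placeholder
    bind(sizes, "n", Constant(4))   # a static size replaces it
    bind(sizes, "n", "b.shape[0]")  # but is never replaced itself
    assert isinstance(sizes["n"], Constant) and sizes["n"].data == 4

The consistency check `prev_size.data != size` compares a Python value against
a possibly-symbolic size, so it can trigger spuriously; PATCH 5 below removes
it and leaves genuinely inconsistent runtime shapes to fail at execution time.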

From c16145227768c14fd11b19c9ed2d95778a9f82bd Mon Sep 17 00:00:00 2001
From: Aarsh-Wankar <23110003@iitgn.ac.in>
Date: Sat, 22 Mar 2025 00:37:25 +0530
Subject: [PATCH 5/8] Remove error handling for inconsistent dimensions in
 _gufunc_to_out_shape

---
 pytensor/tensor/utils.py | 6 +-----
 1 file changed, 1 insertion(+), 5 deletions(-)

diff --git a/pytensor/tensor/utils.py b/pytensor/tensor/utils.py
index bcaeb5c3be..2dbfa9b8ea 100644
--- a/pytensor/tensor/utils.py
+++ b/pytensor/tensor/utils.py
@@ -241,11 +241,7 @@ def _gufunc_to_out_shape(
             # Prefer constants
             elif not isinstance(prev_size, Constant):
                 dim_to_size[dim_name] = size
-            elif prev_size.data != size:
-                raise ValueError(
-                    f"Invalid signature {signature} for shapes {shapes}. "
-                    f"Dimension {dim_name} is not consistent across inputs."
-                )
+
     out_shapes = []
     for output_shape in output_sig:
         temp_list = []

From 7b44445280d44fbe254b47ae7fcb59add99f5991 Mon Sep 17 00:00:00 2001
From: Aarsh-Wankar <23110003@iitgn.ac.in>
Date: Sat, 22 Mar 2025 02:26:07 +0530
Subject: [PATCH 6/8] Implement infer_shape method in Op class and remove
 redundant implementations in subclasses

---
 pytensor/graph/op.py       |  7 +++++++
 pytensor/tensor/nlinalg.py | 16 ----------------
 pytensor/tensor/slinalg.py | 16 ----------------
 3 files changed, 7 insertions(+), 32 deletions(-)

diff --git a/pytensor/graph/op.py b/pytensor/graph/op.py
index 690bb44df5..47292ade1e 100644
--- a/pytensor/graph/op.py
+++ b/pytensor/graph/op.py
@@ -20,6 +20,7 @@
     add_tag_trace,
     get_variable_trace_string,
 )
+from pytensor.tensor.utils import _gufunc_to_out_shape
 
 
 if TYPE_CHECKING:
@@ -596,6 +597,12 @@ def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op":
         # By default, do nothing
         return self
 
+    def infer_shape(self, fgraph, node, input_shapes):
+        if hasattr(self, "gufunc_signature"):
+            return _gufunc_to_out_shape(self.gufunc_signature, input_shapes)
+        else:
+            raise NotImplementedError(f"Op {self} does not implement infer_shape")
+
     def __str__(self):
         return getattr(type(self), "__name__", super().__str__())
 
diff --git a/pytensor/tensor/nlinalg.py b/pytensor/tensor/nlinalg.py
index 78fbe40869..446cf5ab44 100644
--- a/pytensor/tensor/nlinalg.py
+++ b/pytensor/tensor/nlinalg.py
@@ -17,7 +17,6 @@
 from pytensor.tensor.basic import as_tensor_variable, diagonal
 from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.type import Variable, dvector, lscalar, matrix, scalar, vector
-from pytensor.tensor.utils import _gufunc_to_out_shape
 
 
 class MatrixPinv(Op):
@@ -63,9 +62,6 @@ def L_op(self, inputs, outputs, g_outputs):
         ).T
         return [grad]
 
-    def infer_shape(self, fgraph, node, shapes):
-        return _gufunc_to_out_shape(self.gufunc_signature, shapes)
-
 
 def pinv(x, hermitian=False):
     """Computes the pseudo-inverse of a matrix :math:`A`.
@@ -156,9 +152,6 @@ def R_op(self, inputs, eval_points):
             return [None]
         return [-matrix_dot(xi, ev, xi)]
 
-    def infer_shape(self, fgraph, node, shapes):
-        return _gufunc_to_out_shape(self.gufunc_signature, shapes)
-
 
 inv = matrix_inverse = Blockwise(MatrixInverse())
@@ -225,9 +218,6 @@ def grad(self, inputs, g_outputs):
         (x,) = inputs
         return [gz * self(x) * matrix_inverse(x).T]
 
-    def infer_shape(self, fgraph, node, shapes):
-        return _gufunc_to_out_shape(self.gufunc_signature, shapes)
-
     def __str__(self):
         return "Det"
@@ -259,9 +249,6 @@ def perform(self, node, inputs, outputs):
         except Exception as e:
             raise ValueError("Failed to compute determinant", x) from e
 
-    def infer_shape(self, fgraph, node, shapes):
-        return _gufunc_to_out_shape(self.gufunc_signature, shapes)
-
     def __str__(self):
         return "SLogDet"
@@ -317,9 +304,6 @@ def perform(self, node, inputs, outputs):
         (w, v) = outputs
         w[0], v[0] = (z.astype(x.dtype) for z in np.linalg.eig(x))
 
-    def infer_shape(self, fgraph, node, shapes):
-        return _gufunc_to_out_shape(self.gufunc_signature, shapes)
-
 
 eig = Blockwise(Eig())
diff --git a/pytensor/tensor/slinalg.py b/pytensor/tensor/slinalg.py
index d9a023f0aa..ddbd74a9d4 100644
--- a/pytensor/tensor/slinalg.py
+++ b/pytensor/tensor/slinalg.py
@@ -20,7 +20,6 @@
 from pytensor.tensor.nlinalg import kron, matrix_dot
 from pytensor.tensor.shape import reshape
 from pytensor.tensor.type import matrix, tensor, vector
-from pytensor.tensor.utils import _gufunc_to_out_shape
 from pytensor.tensor.variable import TensorVariable
 
 
@@ -51,9 +50,6 @@ def __init__(
         if self.overwrite_a:
             self.destroy_map = {0: [0]}
 
-    def infer_shape(self, fgraph, node, shapes):
-        return _gufunc_to_out_shape(self.gufunc_signature, shapes)
-
     def make_node(self, x):
         x = as_tensor_variable(x)
         if x.type.ndim != 2:
@@ -269,9 +265,6 @@ def make_node(self, A, b):
         x = tensor(dtype=o_dtype, shape=b.type.shape)
         return Apply(self, [A, b], [x])
 
-    def infer_shape(self, fgraph, node, shapes):
-        return _gufunc_to_out_shape(self.gufunc_signature, shapes)
-
     def L_op(self, inputs, outputs, output_gradients):
         r"""Reverse-mode gradient updates for matrix solve operation :math:`c = A^{-1} b`.
 
@@ -885,9 +878,6 @@ def perform(self, node, inputs, output_storage):
         out_dtype = node.outputs[0].type.dtype
         X[0] = scipy_linalg.solve_continuous_lyapunov(A, B).astype(out_dtype)
 
-    def infer_shape(self, fgraph, node, shapes):
-        return _gufunc_to_out_shape(self.gufunc_signature, shapes)
-
     def grad(self, inputs, output_grads):
         # Gradient computations come from Kao and Hennequin (2020), https://arxiv.org/pdf/2011.11430.pdf
         # Note that they write the equation as AX + XA.H + Q = 0, while scipy uses AX + XA^H = Q,
@@ -957,9 +947,6 @@ def perform(self, node, inputs, output_storage):
             out_dtype
         )
 
-    def infer_shape(self, fgraph, node, shapes):
-        return _gufunc_to_out_shape(self.gufunc_signature, shapes)
-
     def grad(self, inputs, output_grads):
         # Gradient computations come from Kao and Hennequin (2020), https://arxiv.org/pdf/2011.11430.pdf
         A, Q = inputs
@@ -1077,9 +1064,6 @@ def perform(self, node, inputs, output_storage):
         out_dtype = node.outputs[0].type.dtype
         X[0] = scipy_linalg.solve_discrete_are(A, B, Q, R).astype(out_dtype)
 
-    def infer_shape(self, fgraph, node, shapes):
-        return _gufunc_to_out_shape(self.gufunc_signature, shapes)
-
     def grad(self, inputs, output_grads):
         # Gradient computations come from Kao and Hennequin (2020), https://arxiv.org/pdf/2011.11430.pdf
         A, B, Q, R = inputs
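
Note on PATCH 6: with the default in place, declaring a `gufunc_signature` is
all an Op needs for shape inference. A toy version of the dispatch, reusing the
`gufunc_to_out_shape` sketch from the note after PATCH 1 (mock classes, not
pytensor's):

    class Op:
        def infer_shape(self, fgraph, node, input_shapes):
            if hasattr(self, "gufunc_signature"):
                return gufunc_to_out_shape(self.gufunc_signature, input_shapes)
            raise NotImplementedError(f"Op {self} does not implement infer_shape")

    class Cholesky(Op):
        # Declaring the signature is now the whole shape-inference story.
        gufunc_signature = "(m,m)->(m,m)"

    assert Cholesky().infer_shape(None, None, [(5, 5)]) == [(5, 5)]

The module-level `from pytensor.tensor.utils import _gufunc_to_out_shape` added
to pytensor/graph/op.py here is the circular import that PATCH 7 fixes:
pytensor.graph cannot import pytensor.tensor at import time, since
pytensor.tensor is itself built on top of pytensor.graph.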

From 707c82e71c89ecb4655c4a8fbf2102970af8b2fc Mon Sep 17 00:00:00 2001
From: Aarsh-Wankar <23110003@iitgn.ac.in>
Date: Sat, 22 Mar 2025 02:42:33 +0530
Subject: [PATCH 7/8] Fix circular import

---
 pytensor/graph/op.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/pytensor/graph/op.py b/pytensor/graph/op.py
index 47292ade1e..a87e9c4a1b 100644
--- a/pytensor/graph/op.py
+++ b/pytensor/graph/op.py
@@ -20,7 +20,6 @@
     add_tag_trace,
     get_variable_trace_string,
 )
-from pytensor.tensor.utils import _gufunc_to_out_shape
 
 
 if TYPE_CHECKING:
@@ -599,6 +598,8 @@ def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op":
 
     def infer_shape(self, fgraph, node, input_shapes):
         if hasattr(self, "gufunc_signature"):
+            from pytensor.tensor.utils import _gufunc_to_out_shape
+
             return _gufunc_to_out_shape(self.gufunc_signature, input_shapes)
         else:
             raise NotImplementedError(f"Op {self} does not implement infer_shape")

From a5f6ce437ad190efed47d668e7d50e939e83408e Mon Sep 17 00:00:00 2001
From: Aarsh-Wankar <23110003@iitgn.ac.in>
Date: Sat, 22 Mar 2025 03:10:17 +0530
Subject: [PATCH 8/8] Raise ShapeError instead of NotImplementedError for
 unimplemented infer_shape in Op class

---
 pytensor/graph/op.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/pytensor/graph/op.py b/pytensor/graph/op.py
index a87e9c4a1b..79f0508588 100644
--- a/pytensor/graph/op.py
+++ b/pytensor/graph/op.py
@@ -602,7 +602,9 @@ def infer_shape(self, fgraph, node, input_shapes):
 
             return _gufunc_to_out_shape(self.gufunc_signature, input_shapes)
         else:
-            raise NotImplementedError(f"Op {self} does not implement infer_shape")
+            from pytensor.tensor.exceptions import ShapeError
+
+            raise ShapeError(f"Op {self} does not implement infer_shape")
 
     def __str__(self):
         return getattr(type(self), "__name__", super().__str__())
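
Net effect of the series: shape inference for signature-bearing linalg Ops
flows through a single code path, and Ops without a `gufunc_signature` raise
`ShapeError` instead of `NotImplementedError`. A usage sketch, assuming a
pytensor build with these patches applied:

    import pytensor
    import pytensor.tensor as pt
    from pytensor.tensor.nlinalg import matrix_inverse

    x = pt.matrix("x")  # static shape unknown
    # MatrixInverse has gufunc_signature "(m,m)->(m,m)", so the shape graph
    # below is derived by the shared Op.infer_shape default -- the inverse
    # itself is never computed.
    f = pytensor.function([x], matrix_inverse(x).shape)
    print(f([[2.0, 0.0], [0.0, 2.0]]))  # [2 2]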