diff --git a/pytensor/link/jax/dispatch/slinalg.py b/pytensor/link/jax/dispatch/slinalg.py index ca362e4531..5430ce1da4 100644 --- a/pytensor/link/jax/dispatch/slinalg.py +++ b/pytensor/link/jax/dispatch/slinalg.py @@ -53,7 +53,6 @@ def solve(a, b, lower=lower): @jax_funcify.register(SolveTriangular) def jax_funcify_SolveTriangular(op, **kwargs): lower = op.lower - trans = op.trans unit_diagonal = op.unit_diagonal check_finite = op.check_finite @@ -62,7 +61,7 @@ def solve_triangular(A, b): A, b, lower=lower, - trans=trans, + trans=0, # this is handled by explicitly transposing A, so it will always be 0 when we get to here. unit_diagonal=unit_diagonal, check_finite=check_finite, ) diff --git a/pytensor/link/numba/dispatch/slinalg.py b/pytensor/link/numba/dispatch/slinalg.py index 4b5f518926..c64b5fdb3e 100644 --- a/pytensor/link/numba/dispatch/slinalg.py +++ b/pytensor/link/numba/dispatch/slinalg.py @@ -180,7 +180,6 @@ def impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b): @numba_funcify.register(SolveTriangular) def numba_funcify_SolveTriangular(op, node, **kwargs): - trans = bool(op.trans) lower = op.lower unit_diagonal = op.unit_diagonal check_finite = op.check_finite @@ -208,7 +207,7 @@ def solve_triangular(a, b): res = _solve_triangular( a, b, - trans=trans, + trans=0, # transposing is handled explicitly on the graph, so we never use this argument lower=lower, unit_diagonal=unit_diagonal, overwrite_b=overwrite_b, diff --git a/pytensor/tensor/slinalg.py b/pytensor/tensor/slinalg.py index 94973810fd..25ee69a07d 100644 --- a/pytensor/tensor/slinalg.py +++ b/pytensor/tensor/slinalg.py @@ -296,13 +296,12 @@ def L_op(self, inputs, outputs, output_gradients): # We need to return (dC/d[inv(A)], dC/db) c_bar = output_gradients[0] - trans_solve_op = type(self)( - **{ - k: (not getattr(self, k) if k == "lower" else getattr(self, k)) - for k in self.__props__ - } - ) - b_bar = trans_solve_op(A.T, c_bar) + props_dict = self._props_dict() + props_dict["lower"] = not self.lower + + solve_op = type(self)(**props_dict) + + b_bar = solve_op(A.T, c_bar) # force outer product if vector second input A_bar = -ptm.outer(b_bar, c) if c.ndim == 1 else -b_bar.dot(c.T) @@ -385,7 +384,6 @@ class SolveTriangular(SolveBase): """Solve a system of linear equations.""" __props__ = ( - "trans", "unit_diagonal", "lower", "check_finite", @@ -393,11 +391,10 @@ class SolveTriangular(SolveBase): "overwrite_b", ) - def __init__(self, *, trans=0, unit_diagonal=False, **kwargs): + def __init__(self, *, unit_diagonal=False, **kwargs): if kwargs.get("overwrite_a", False): raise ValueError("overwrite_a is not supported for SolverTriangulare") super().__init__(**kwargs) - self.trans = trans self.unit_diagonal = unit_diagonal def perform(self, node, inputs, outputs): @@ -406,7 +403,7 @@ def perform(self, node, inputs, outputs): A, b, lower=self.lower, - trans=self.trans, + trans=0, unit_diagonal=self.unit_diagonal, check_finite=self.check_finite, overwrite_b=self.overwrite_b, @@ -445,9 +442,9 @@ def solve_triangular( Parameters ---------- - a + a: TensorVariable Square input data - b + b: TensorVariable Input data for the right hand side. lower : bool, optional Use only data contained in the lower triangle of `a`. Default is to use upper triangle. @@ -468,10 +465,17 @@ def solve_triangular( This will influence how batched dimensions are interpreted. 
""" b_ndim = _default_b_ndim(b, b_ndim) + + if trans in [1, "T", True]: + a = a.mT + lower = not lower + if trans in [2, "C"]: + a = a.conj().mT + lower = not lower + ret = Blockwise( SolveTriangular( lower=lower, - trans=trans, unit_diagonal=unit_diagonal, check_finite=check_finite, b_ndim=b_ndim, @@ -534,6 +538,7 @@ def solve( *, assume_a="gen", lower=False, + transposed=False, check_finite=True, b_ndim: int | None = None, ): @@ -564,8 +569,10 @@ def solve( b : (..., N, NRHS) array_like Input data for the right hand side. lower : bool, optional - If True, only the data contained in the lower triangle of `a`. Default + If True, use only the data contained in the lower triangle of `a`. Default is to use upper triangle. (ignored for ``'gen'``) + transposed: bool, optional + If True, solves the system A^T x = b. Default is False. check_finite : bool, optional Whether to check that the input matrices contain only finite numbers. Disabling may give a performance gain, but may result in problems @@ -577,6 +584,11 @@ def solve( This will influence how batched dimensions are interpreted. """ b_ndim = _default_b_ndim(b, b_ndim) + + if transposed: + a = a.mT + lower = not lower + return Blockwise( Solve( lower=lower, diff --git a/tests/link/jax/test_slinalg.py b/tests/link/jax/test_slinalg.py index 2656b0fd04..c446437ddd 100644 --- a/tests/link/jax/test_slinalg.py +++ b/tests/link/jax/test_slinalg.py @@ -5,6 +5,7 @@ import pytest import pytensor.tensor as pt +import tests.unittest_tools as utt from pytensor.configdefaults import config from pytensor.tensor import nlinalg as pt_nlinalg from pytensor.tensor import slinalg as pt_slinalg @@ -103,28 +104,41 @@ def test_jax_basic(): ) -@pytest.mark.parametrize("check_finite", [False, True]) -@pytest.mark.parametrize("lower", [False, True]) -@pytest.mark.parametrize("trans", [0, 1, 2]) -def test_jax_SolveTriangular(trans, lower, check_finite): - x = matrix("x") - b = vector("b") +def test_jax_solve(): + rng = np.random.default_rng(utt.fetch_seed()) + + A = pt.tensor("A", shape=(5, 5)) + b = pt.tensor("B", shape=(5, 5)) + + out = pt_slinalg.solve(A, b, lower=False, transposed=False) + + A_val = rng.normal(size=(5, 5)).astype(config.floatX) + b_val = rng.normal(size=(5, 5)).astype(config.floatX) - out = pt_slinalg.solve_triangular( - x, - b, - trans=trans, - lower=lower, - check_finite=check_finite, - ) compare_jax_and_py( - [x, b], + [A, b], [out], - [ - np.eye(10).astype(config.floatX), - np.arange(10).astype(config.floatX), - ], + [A_val, b_val], + ) + + +def test_jax_SolveTriangular(): + rng = np.random.default_rng(utt.fetch_seed()) + + A = pt.tensor("A", shape=(5, 5)) + b = pt.tensor("B", shape=(5, 5)) + + A_val = rng.normal(size=(5, 5)).astype(config.floatX) + b_val = rng.normal(size=(5, 5)).astype(config.floatX) + + out = pt_slinalg.solve_triangular( + A, + b, + trans=0, + lower=True, + unit_diagonal=False, ) + compare_jax_and_py([A, b], [out], [A_val, b_val]) def test_jax_block_diag(): diff --git a/tests/link/numba/test_slinalg.py b/tests/link/numba/test_slinalg.py index 5caeb8bef9..defbcf6c86 100644 --- a/tests/link/numba/test_slinalg.py +++ b/tests/link/numba/test_slinalg.py @@ -5,7 +5,6 @@ import numpy as np import pytest from numpy.testing import assert_allclose -from scipy import linalg as scipy_linalg import pytensor import pytensor.tensor as pt @@ -26,9 +25,9 @@ def transpose_func(x, trans): if trans == 0: return x if trans == 1: - return x.conj().T - if trans == 2: return x.T + if trans == 2: + return x.conj().T @pytest.mark.parametrize( @@ 
-59,18 +58,18 @@ def test_solve_triangular(b_shape: tuple[int], lower, trans, unit_diag, is_compl def A_func(x): x = x @ x.conj().T - x_tri = scipy_linalg.cholesky(x, lower=lower).astype(dtype) + x_tri = pt.linalg.cholesky(x, lower=lower).astype(dtype) if unit_diag: - x_tri[np.diag_indices_from(x_tri)] = 1.0 + x_tri = pt.fill_diagonal(x_tri, 1.0) - return x_tri.astype(dtype) + return x_tri solve_op = partial( pt.linalg.solve_triangular, lower=lower, trans=trans, unit_diagonal=unit_diag ) - X = solve_op(A, b) + X = solve_op(A_func(A), b) f = pytensor.function([A, b], X, mode="NUMBA") A_val = np.random.normal(size=(5, 5)) @@ -80,20 +79,20 @@ def A_func(x): A_val = A_val + np.random.normal(size=(5, 5)) * 1j b_val = b_val + np.random.normal(size=b_shape) * 1j - X_np = f(A_func(A_val), b_val) - - test_input = transpose_func(A_func(A_val), trans) - - ATOL = 1e-8 if floatX.endswith("64") else 1e-4 - RTOL = 1e-8 if floatX.endswith("64") else 1e-4 - - np.testing.assert_allclose(test_input @ X_np, b_val, atol=ATOL, rtol=RTOL) + X_np = f(A_val.copy(), b_val.copy()) + A_val_transformed = transpose_func(A_func(A_val), trans).eval() + np.testing.assert_allclose( + A_val_transformed @ X_np, + b_val, + atol=1e-8 if floatX.endswith("64") else 1e-4, + rtol=1e-8 if floatX.endswith("64") else 1e-4, + ) compiled_fgraph = f.maker.fgraph compare_numba_and_py( compiled_fgraph.inputs, compiled_fgraph.outputs, - [A_func(A_val), b_val], + [A_val, b_val], ) @@ -145,7 +144,6 @@ def test_solve_triangular_overwrite_b_correct(overwrite_b): b_test_nb = b_test_py.copy(order="F") op = SolveTriangular( - trans=0, unit_diagonal=False, lower=False, check_finite=True, diff --git a/tests/tensor/test_slinalg.py b/tests/tensor/test_slinalg.py index 34f1396f4c..f1a6b0fe56 100644 --- a/tests/tensor/test_slinalg.py +++ b/tests/tensor/test_slinalg.py @@ -214,7 +214,38 @@ def test_solve_raises_on_invalid_A(): Solve(assume_a="test", b_ndim=2) +solve_test_cases = [ + ("gen", False, False), + ("gen", False, True), + ("sym", False, False), + ("sym", True, False), + ("sym", True, True), + ("pos", False, False), + ("pos", True, False), + ("pos", True, True), +] +solve_test_ids = [ + f'{assume_a}_{"lower" if lower else "upper"}_{"A^T" if transposed else "A"}' + for assume_a, lower, transposed in solve_test_cases +] + + class TestSolve(utt.InferShapeTester): + @staticmethod + def A_func(x, assume_a): + if assume_a == "pos": + return x @ x.T + elif assume_a == "sym": + return (x + x.T) / 2 + else: + return x + + @staticmethod + def T(x, transposed): + if transposed: + return x.T + return x + @pytest.mark.parametrize("b_shape", [(5, 1), (5,)]) def test_infer_shape(self, b_shape): rng = np.random.default_rng(utt.fetch_seed()) @@ -235,8 +266,12 @@ def test_infer_shape(self, b_shape): @pytest.mark.parametrize( "b_size", [(5, 1), (5, 5), (5,)], ids=["b_col_vec", "b_matrix", "b_vec"] ) - @pytest.mark.parametrize("assume_a", ["gen", "sym", "pos"], ids=str) - def test_solve_correctness(self, b_size: tuple[int], assume_a: str): + @pytest.mark.parametrize( + "assume_a, lower, transposed", solve_test_cases, ids=solve_test_ids + ) + def test_solve_correctness( + self, b_size: tuple[int], assume_a: str, lower: bool, transposed: bool + ): rng = np.random.default_rng(utt.fetch_seed()) A = pt.tensor("A", shape=(5, 5)) b = pt.tensor("b", shape=b_size) @@ -244,19 +279,18 @@ def test_solve_correctness(self, b_size: tuple[int], assume_a: str): A_val = rng.normal(size=(5, 5)).astype(config.floatX) b_val = rng.normal(size=b_size).astype(config.floatX) - solve_op = 
functools.partial(solve, assume_a=assume_a, b_ndim=len(b_size)) + A_func = functools.partial(self.A_func, assume_a=assume_a) + T = functools.partial(self.T, transposed=transposed) - def A_func(x): - if assume_a == "pos": - return x @ x.T - elif assume_a == "sym": - return (x + x.T) / 2 - else: - return x - - solve_input_val = A_func(A_val) + y = solve( + A_func(A), + b, + assume_a=assume_a, + lower=lower, + transposed=transposed, + b_ndim=len(b_size), + ) - y = solve_op(A_func(A), b) solve_func = pytensor.function([A, b], y) X_np = solve_func(A_val.copy(), b_val.copy()) @@ -264,22 +298,34 @@ def A_func(x): RTOL = 1e-8 if config.floatX.endswith("64") else 1e-4 np.testing.assert_allclose( - scipy.linalg.solve(solve_input_val, b_val, assume_a=assume_a), + scipy.linalg.solve( + A_func(A_val), + b_val, + assume_a=assume_a, + transposed=transposed, + lower=lower, + ), X_np, atol=ATOL, rtol=RTOL, ) - np.testing.assert_allclose(A_func(A_val) @ X_np, b_val, atol=ATOL, rtol=RTOL) + np.testing.assert_allclose(T(A_func(A_val)) @ X_np, b_val, atol=ATOL, rtol=RTOL) @pytest.mark.parametrize( "b_size", [(5, 1), (5, 5), (5,)], ids=["b_col_vec", "b_matrix", "b_vec"] ) - @pytest.mark.parametrize("assume_a", ["gen", "sym", "pos"], ids=str) + @pytest.mark.parametrize( + "assume_a, lower, transposed", + solve_test_cases, + ids=solve_test_ids, + ) @pytest.mark.skipif( config.floatX == "float32", reason="Gradients not numerically stable in float32" ) - def test_solve_gradient(self, b_size: tuple[int], assume_a: str): + def test_solve_gradient( + self, b_size: tuple[int], assume_a: str, lower: bool, transposed: bool + ): rng = np.random.default_rng(utt.fetch_seed()) eps = 2e-8 if config.floatX == "float64" else None @@ -287,15 +333,8 @@ def test_solve_gradient(self, b_size: tuple[int], assume_a: str): A_val = rng.normal(size=(5, 5)).astype(config.floatX) b_val = rng.normal(size=b_size).astype(config.floatX) - def A_func(x): - if assume_a == "pos": - return x @ x.T - elif assume_a == "sym": - return (x + x.T) / 2 - else: - return x - solve_op = functools.partial(solve, assume_a=assume_a, b_ndim=len(b_size)) + A_func = functools.partial(self.A_func, assume_a=assume_a) # To correctly check the gradients, we need to include a transformation from the space of unconstrained matrices # (A) to a valid input matrix for the given solver. This is done by the A_func function. 
If this isn't included, @@ -307,11 +346,27 @@ def A_func(x): class TestSolveTriangular(utt.InferShapeTester): + @staticmethod + def A_func(x, lower, unit_diagonal): + x = x @ x.T + x = pt.linalg.cholesky(x, lower=lower) + if unit_diagonal: + x = pt.fill_diagonal(x, 1) + return x + + @staticmethod + def T(x, trans): + if trans == 1: + return x.T + elif trans == 2: + return x.conj().T + return x + @pytest.mark.parametrize("b_shape", [(5, 1), (5,)]) def test_infer_shape(self, b_shape): rng = np.random.default_rng(utt.fetch_seed()) A = matrix() - b_val = np.asarray(rng.random(b_shape), dtype=config.floatX) + b_val = rng.random(b_shape).astype(config.floatX) b = pt.as_tensor_variable(b_val).type() self._compile_and_check( [A, b], @@ -324,56 +379,78 @@ def test_infer_shape(self, b_shape): warn=False, ) + @pytest.mark.parametrize( + "b_shape", [(5, 1), (5,), (5, 5)], ids=["b_col_vec", "b_vec", "b_matrix"] + ) @pytest.mark.parametrize("lower", [True, False]) - def test_correctness(self, lower): + @pytest.mark.parametrize("trans", [0, 1, 2]) + @pytest.mark.parametrize("unit_diagonal", [True, False]) + def test_correctness(self, b_shape: tuple[int], lower, trans, unit_diagonal): rng = np.random.default_rng(utt.fetch_seed()) + A = pt.tensor("A", shape=(5, 5)) + b = pt.tensor("b", shape=b_shape) - b_val = np.asarray(rng.random((5, 1)), dtype=config.floatX) + A_val = rng.random((5, 5)).astype(config.floatX) + b_val = rng.random(b_shape).astype(config.floatX) - A_val = np.asarray(rng.random((5, 5)), dtype=config.floatX) - A_val = np.dot(A_val.transpose(), A_val) + A_func = functools.partial( + self.A_func, lower=lower, unit_diagonal=unit_diagonal + ) - C_val = scipy.linalg.cholesky(A_val, lower=lower) + x = solve_triangular( + A_func(A), + b, + lower=lower, + trans=trans, + unit_diagonal=unit_diagonal, + b_ndim=len(b_shape), + ) - A = matrix() - b = matrix() + f = pytensor.function([A, b], x) - cholesky = Cholesky(lower=lower) - C = cholesky(A) - y_lower = solve_triangular(C, b, lower=lower) - lower_solve_func = pytensor.function([C, b], y_lower) + x_pt = f(A_val, b_val) + x_sp = scipy.linalg.solve_triangular( + A_func(A_val).eval(), + b_val, + lower=lower, + trans=trans, + unit_diagonal=unit_diagonal, + ) - assert np.allclose( - scipy.linalg.solve_triangular(C_val, b_val, lower=lower), - lower_solve_func(C_val, b_val), + np.testing.assert_allclose( + x_pt, + x_sp, + atol=1e-8 if config.floatX == "float64" else 1e-4, + rtol=1e-8 if config.floatX == "float64" else 1e-4, ) @pytest.mark.parametrize( - "m, n, lower", - [ - (5, None, False), - (5, None, True), - (4, 2, False), - (4, 2, True), - ], + "b_shape", [(5, 1), (5,), (5, 5)], ids=["b_col_vec", "b_vec", "b_matrix"] ) - def test_solve_grad(self, m, n, lower): - rng = np.random.default_rng(utt.fetch_seed()) + @pytest.mark.parametrize("lower", [True, False]) + @pytest.mark.parametrize("trans", [0, 1]) + @pytest.mark.parametrize("unit_diagonal", [True, False]) + def test_solve_triangular_grad(self, b_shape, lower, trans, unit_diagonal): + if config.floatX == "float32": + pytest.skip(reason="Not enough precision in float32 to get a good gradient") - # Ensure diagonal elements of `A` are relatively large to avoid - # numerical precision issues - A_val = (rng.normal(size=(m, m)) * 0.5 + np.eye(m)).astype(config.floatX) + rng = np.random.default_rng(utt.fetch_seed()) + A_val = rng.normal(size=(5, 5)).astype(config.floatX) + b_val = rng.normal(size=b_shape).astype(config.floatX) - if n is None: - b_val = rng.normal(size=m).astype(config.floatX) - else: - b_val 
= rng.normal(size=(m, n)).astype(config.floatX) + A_func = functools.partial( + self.A_func, lower=lower, unit_diagonal=unit_diagonal + ) eps = None if config.floatX == "float64": eps = 2e-8 - solve_op = SolveTriangular(lower=lower, b_ndim=1 if n is None else 2) + def solve_op(A, b): + return solve_triangular( + A_func(A), b, lower=lower, trans=trans, unit_diagonal=unit_diagonal + ) + utt.verify_grad(solve_op, [A_val, b_val], 3, rng, eps=eps)
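
For reference, a minimal usage sketch (not part of the patch) of the behaviour this diff introduces: `solve_triangular` now resolves `trans` at graph-construction time by transposing `a` and flipping `lower`, and `solve` gains a `transposed` keyword that applies the same rewrite before the Op is built. Variable names and test values below are illustrative only and assume the `.mT` helper already used in the patch.

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.slinalg import solve, solve_triangular

A = pt.matrix("A")
b = pt.vector("b")

# With this patch, trans=1 is rewritten as a solve against A.mT with `lower`
# flipped; the SolveTriangular Op itself no longer carries a `trans` property.
x_trans = solve_triangular(A, b, lower=True, trans=1)
x_explicit = solve_triangular(A.mT, b, lower=False)

# solve() now accepts `transposed`, which applies the same graph-level rewrite.
y = solve(A, b, assume_a="gen", transposed=True)

f = pytensor.function([A, b], [x_trans, x_explicit, y])

floatX = pytensor.config.floatX
rng = np.random.default_rng(0)
# Well-conditioned lower-triangular test matrix
A_val = (np.tril(rng.normal(size=(3, 3))) + 3.0 * np.eye(3)).astype(floatX)
b_val = rng.normal(size=3).astype(floatX)

tol = 1e-8 if floatX.endswith("64") else 1e-4
r1, r2, r3 = f(A_val, b_val)
np.testing.assert_allclose(r1, r2, atol=tol, rtol=tol)               # both solve A.T @ x = b
np.testing.assert_allclose(A_val.T @ r3, b_val, atol=tol, rtol=tol)  # transposed=True solves A.T @ x = b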