Allow transposed argument in linalg.solve #1231

Merged · 2 commits · Mar 4, 2025
3 changes: 1 addition & 2 deletions pytensor/link/jax/dispatch/slinalg.py
@@ -53,7 +53,6 @@ def solve(a, b, lower=lower):
@jax_funcify.register(SolveTriangular)
def jax_funcify_SolveTriangular(op, **kwargs):
lower = op.lower
trans = op.trans
unit_diagonal = op.unit_diagonal
check_finite = op.check_finite

@@ -62,7 +61,7 @@ def solve_triangular(A, b):
A,
b,
lower=lower,
trans=trans,
trans=0, # this is handled by explicitly transposing A, so it will always be 0 when we get to here.
unit_diagonal=unit_diagonal,
check_finite=check_finite,
)
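Reviewer note (not part of the diff): with `trans` dropped from the Op, the JAX backend always receives `trans=0`, because any requested transpose is applied to `A` when the graph is built. A minimal scipy sketch of the identity this relies on, assuming a well-conditioned triangular matrix:

```python
# Hypothetical check, not from the PR: solving A^T x = b via scipy's `trans` flag
# matches transposing A explicitly and flipping `lower`, which is what the
# pytensor graph now does before dispatching to the backend.
import numpy as np
from scipy.linalg import solve_triangular

rng = np.random.default_rng(0)
A = np.tril(rng.normal(size=(4, 4))) + 4.0 * np.eye(4)  # lower triangular, well conditioned
b = rng.normal(size=4)

x_flag = solve_triangular(A, b, lower=True, trans=1)         # scipy handles the transpose
x_explicit = solve_triangular(A.T, b, lower=False, trans=0)  # transpose done up front
np.testing.assert_allclose(x_flag, x_explicit)
```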
3 changes: 1 addition & 2 deletions pytensor/link/numba/dispatch/slinalg.py
@@ -180,7 +180,6 @@ def impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b):

@numba_funcify.register(SolveTriangular)
def numba_funcify_SolveTriangular(op, node, **kwargs):
trans = bool(op.trans)
lower = op.lower
unit_diagonal = op.unit_diagonal
check_finite = op.check_finite
@@ -208,7 +207,7 @@ def solve_triangular(a, b):
res = _solve_triangular(
a,
b,
trans=trans,
trans=0, # transposing is handled explicitly on the graph, so we never use this argument
lower=lower,
unit_diagonal=unit_diagonal,
overwrite_b=overwrite_b,
42 changes: 27 additions & 15 deletions pytensor/tensor/slinalg.py
@@ -296,13 +296,12 @@ def L_op(self, inputs, outputs, output_gradients):
# We need to return (dC/d[inv(A)], dC/db)
c_bar = output_gradients[0]

trans_solve_op = type(self)(
**{
k: (not getattr(self, k) if k == "lower" else getattr(self, k))
for k in self.__props__
}
)
b_bar = trans_solve_op(A.T, c_bar)
props_dict = self._props_dict()
props_dict["lower"] = not self.lower

solve_op = type(self)(**props_dict)

b_bar = solve_op(A.T, c_bar)
# force outer product if vector second input
A_bar = -ptm.outer(b_bar, c) if c.ndim == 1 else -b_bar.dot(c.T)

@@ -385,19 +384,17 @@ class SolveTriangular(SolveBase):
"""Solve a system of linear equations."""

__props__ = (
"trans",
"unit_diagonal",
"lower",
"check_finite",
"b_ndim",
"overwrite_b",
)

def __init__(self, *, trans=0, unit_diagonal=False, **kwargs):
def __init__(self, *, unit_diagonal=False, **kwargs):
if kwargs.get("overwrite_a", False):
raise ValueError("overwrite_a is not supported for SolverTriangulare")
super().__init__(**kwargs)
self.trans = trans
self.unit_diagonal = unit_diagonal

def perform(self, node, inputs, outputs):
@@ -406,7 +403,7 @@ def perform(self, node, inputs, outputs):
A,
b,
lower=self.lower,
trans=self.trans,
trans=0,
unit_diagonal=self.unit_diagonal,
check_finite=self.check_finite,
overwrite_b=self.overwrite_b,
@@ -445,9 +442,9 @@ def solve_triangular(

Parameters
----------
a
a: TensorVariable
Square input data
b
b: TensorVariable
Input data for the right hand side.
lower : bool, optional
Use only data contained in the lower triangle of `a`. Default is to use upper triangle.
@@ -468,10 +465,17 @@
This will influence how batched dimensions are interpreted.
"""
b_ndim = _default_b_ndim(b, b_ndim)

if trans in [1, "T", True]:
a = a.mT
lower = not lower
if trans in [2, "C"]:
a = a.conj().mT
lower = not lower

ret = Blockwise(
SolveTriangular(
lower=lower,
trans=trans,
unit_diagonal=unit_diagonal,
check_finite=check_finite,
b_ndim=b_ndim,
@@ -534,6 +538,7 @@ def solve(
*,
assume_a="gen",
lower=False,
transposed=False,
check_finite=True,
b_ndim: int | None = None,
):
@@ -564,8 +569,10 @@
b : (..., N, NRHS) array_like
Input data for the right hand side.
lower : bool, optional
If True, only the data contained in the lower triangle of `a`. Default
If True, use only the data contained in the lower triangle of `a`. Default
is to use upper triangle. (ignored for ``'gen'``)
transposed: bool, optional
If True, solves the system A^T x = b. Default is False.
check_finite : bool, optional
Whether to check that the input matrices contain only finite numbers.
Disabling may give a performance gain, but may result in problems
@@ -577,6 +584,11 @@
This will influence how batched dimensions are interpreted.
"""
b_ndim = _default_b_ndim(b, b_ndim)

if transposed:
a = a.mT
lower = not lower

return Blockwise(
Solve(
lower=lower,
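Usage sketch (not part of the diff, using only the public `solve` from `pytensor.tensor.slinalg`): the new `transposed` flag is resolved at graph-construction time by transposing `a` and flipping `lower`, so `solve(A, b, transposed=True)` should build the same computation as `solve(A.mT, b)`:

```python
# Hedged example: demonstrates the intended equivalence of the new `transposed` argument.
import numpy as np
import pytensor
from pytensor.tensor import matrix, vector
from pytensor.tensor.slinalg import solve

A = matrix("A")
b = vector("b")

x_transposed = solve(A, b, assume_a="gen", transposed=True)  # solves A.T @ x = b
x_manual = solve(A.mT, b, assume_a="gen")                    # same system, transposed by hand

f = pytensor.function([A, b], [x_transposed, x_manual])

rng = np.random.default_rng(0)
A_val = (rng.normal(size=(3, 3)) + 4.0 * np.eye(3)).astype(pytensor.config.floatX)
b_val = rng.normal(size=3).astype(pytensor.config.floatX)

out_transposed, out_manual = f(A_val, b_val)
np.testing.assert_allclose(out_transposed, out_manual, rtol=1e-4)
np.testing.assert_allclose(A_val.T @ out_transposed, b_val, rtol=1e-4)
```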
50 changes: 32 additions & 18 deletions tests/link/jax/test_slinalg.py
@@ -5,6 +5,7 @@
import pytest

import pytensor.tensor as pt
import tests.unittest_tools as utt
from pytensor.configdefaults import config
from pytensor.tensor import nlinalg as pt_nlinalg
from pytensor.tensor import slinalg as pt_slinalg
@@ -103,28 +104,41 @@ def test_jax_basic():
)


@pytest.mark.parametrize("check_finite", [False, True])
@pytest.mark.parametrize("lower", [False, True])
@pytest.mark.parametrize("trans", [0, 1, 2])
def test_jax_SolveTriangular(trans, lower, check_finite):
x = matrix("x")
b = vector("b")
def test_jax_solve():
rng = np.random.default_rng(utt.fetch_seed())

A = pt.tensor("A", shape=(5, 5))
b = pt.tensor("B", shape=(5, 5))

out = pt_slinalg.solve(A, b, lower=False, transposed=False)

A_val = rng.normal(size=(5, 5)).astype(config.floatX)
b_val = rng.normal(size=(5, 5)).astype(config.floatX)

out = pt_slinalg.solve_triangular(
x,
b,
trans=trans,
lower=lower,
check_finite=check_finite,
)
compare_jax_and_py(
[x, b],
[A, b],
[out],
[
np.eye(10).astype(config.floatX),
np.arange(10).astype(config.floatX),
],
[A_val, b_val],
)


def test_jax_SolveTriangular():
rng = np.random.default_rng(utt.fetch_seed())

A = pt.tensor("A", shape=(5, 5))
b = pt.tensor("B", shape=(5, 5))

A_val = rng.normal(size=(5, 5)).astype(config.floatX)
b_val = rng.normal(size=(5, 5)).astype(config.floatX)

out = pt_slinalg.solve_triangular(
A,
b,
trans=0,
lower=True,
unit_diagonal=False,
)
compare_jax_and_py([A, b], [out], [A_val, b_val])


def test_jax_block_diag():
32 changes: 15 additions & 17 deletions tests/link/numba/test_slinalg.py
@@ -5,7 +5,6 @@
import numpy as np
import pytest
from numpy.testing import assert_allclose
from scipy import linalg as scipy_linalg

import pytensor
import pytensor.tensor as pt
@@ -26,9 +25,9 @@ def transpose_func(x, trans):
if trans == 0:
return x
if trans == 1:
return x.conj().T
if trans == 2:
return x.T
if trans == 2:
return x.conj().T


@pytest.mark.parametrize(
@@ -59,18 +58,18 @@ def test_solve_triangular(b_shape: tuple[int], lower, trans, unit_diag, is_compl

def A_func(x):
x = x @ x.conj().T
x_tri = scipy_linalg.cholesky(x, lower=lower).astype(dtype)
x_tri = pt.linalg.cholesky(x, lower=lower).astype(dtype)

if unit_diag:
x_tri[np.diag_indices_from(x_tri)] = 1.0
x_tri = pt.fill_diagonal(x_tri, 1.0)

return x_tri.astype(dtype)
return x_tri

solve_op = partial(
pt.linalg.solve_triangular, lower=lower, trans=trans, unit_diagonal=unit_diag
)

X = solve_op(A, b)
X = solve_op(A_func(A), b)
f = pytensor.function([A, b], X, mode="NUMBA")

A_val = np.random.normal(size=(5, 5))
@@ -80,20 +79,20 @@ def A_func(x):
A_val = A_val + np.random.normal(size=(5, 5)) * 1j
b_val = b_val + np.random.normal(size=b_shape) * 1j

X_np = f(A_func(A_val), b_val)

test_input = transpose_func(A_func(A_val), trans)

ATOL = 1e-8 if floatX.endswith("64") else 1e-4
RTOL = 1e-8 if floatX.endswith("64") else 1e-4

np.testing.assert_allclose(test_input @ X_np, b_val, atol=ATOL, rtol=RTOL)
X_np = f(A_val.copy(), b_val.copy())
A_val_transformed = transpose_func(A_func(A_val), trans).eval()
np.testing.assert_allclose(
A_val_transformed @ X_np,
b_val,
atol=1e-8 if floatX.endswith("64") else 1e-4,
rtol=1e-8 if floatX.endswith("64") else 1e-4,
)

compiled_fgraph = f.maker.fgraph
compare_numba_and_py(
compiled_fgraph.inputs,
compiled_fgraph.outputs,
[A_func(A_val), b_val],
[A_val, b_val],
)


@@ -145,7 +144,6 @@ def test_solve_triangular_overwrite_b_correct(overwrite_b):
b_test_nb = b_test_py.copy(order="F")

op = SolveTriangular(
trans=0,
unit_diagonal=False,
lower=False,
check_finite=True,
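Reviewer note on the `A_func` change in tests/link/numba/test_slinalg.py (not part of the diff): the test now builds the triangular matrix symbolically, so the in-place NumPy assignment had to become `pt.fill_diagonal`, since symbolic variables do not support item assignment. A small sketch of the behaviour being relied on:

```python
# Hedged sketch: pt.fill_diagonal returns a new symbolic tensor equivalent to the
# old in-place NumPy update `x[np.diag_indices_from(x)] = 1.0`.
import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.matrix("x")
x_unit = pt.fill_diagonal(x, 1.0)  # unit diagonal, expressed functionally on the graph

f = pytensor.function([x], x_unit)

a = np.arange(9.0).reshape(3, 3).astype(pytensor.config.floatX)
expected = a.copy()
expected[np.diag_indices_from(expected)] = 1.0
np.testing.assert_allclose(f(a), expected)
```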