Skip to content

Implement tridiagonal solve in numba backend #1311

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 27, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pytensor/link/numba/dispatch/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def numba_njit(*args, fastmath=None, **kwargs):
message=(
"(\x1b\\[1m)*" # ansi escape code for bold text
"Cannot cache compiled function "
'"(numba_funcified_fgraph|store_core_outputs)" '
'"(numba_funcified_fgraph|store_core_outputs|cholesky|solve|solve_triangular|cho_solve)" '
"as it uses dynamic globals"
),
category=NumbaWarning,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -390,3 +390,70 @@ def numba_xposv(cls, dtype):
_ptr_int, # INFO
)
return functype(lapack_ptr)

@classmethod
def numba_xgttrf(cls, dtype):
    """
    Build a ctypes callable for the LAPACK routine ``xGTTRF``.

    ``xGTTRF`` computes the LU factorization of a tridiagonal matrix A
    using row interchanges.

    Parameters
    ----------
    dtype
        Forwarded to ``_get_lapack_ptr_and_ptr_type`` to select the
        routine variant and the matching data pointer type.

    Returns
    -------
    A ctypes function object wrapping the routine's address. Every
    argument is passed by pointer and the call itself returns nothing.
    """
    routine_addr, data_ptr = _get_lapack_ptr_and_ptr_type(dtype, "gttrf")
    # Positional argument types, in the order the routine expects them.
    arg_types = (
        _ptr_int,  # N
        data_ptr,  # DL
        data_ptr,  # D
        data_ptr,  # DU
        data_ptr,  # DU2
        _ptr_int,  # IPIV
        _ptr_int,  # INFO
    )
    # ``None`` restype: status is reported through INFO, not a return value.
    prototype = ctypes.CFUNCTYPE(None, *arg_types)
    return prototype(routine_addr)

@classmethod
def numba_xgttrs(cls, dtype):
    """
    Build a ctypes callable for the LAPACK routine ``xGTTRS``.

    ``xGTTRS`` solves a system of linear equations A @ X = B with a
    tridiagonal matrix A, using the LU factorization computed by the
    routine wrapped in ``numba_xgttrf``.

    Parameters
    ----------
    dtype
        Forwarded to ``_get_lapack_ptr_and_ptr_type`` to select the
        routine variant and the matching data pointer type.

    Returns
    -------
    A ctypes function object wrapping the routine's address. Every
    argument is passed by pointer and the call itself returns nothing.
    """
    routine_addr, data_ptr = _get_lapack_ptr_and_ptr_type(dtype, "gttrs")
    # Positional argument types, in the order the routine expects them.
    arg_types = (
        _ptr_int,  # TRANS
        _ptr_int,  # N
        _ptr_int,  # NRHS
        data_ptr,  # DL
        data_ptr,  # D
        data_ptr,  # DU
        data_ptr,  # DU2
        _ptr_int,  # IPIV
        data_ptr,  # B
        _ptr_int,  # LDB
        _ptr_int,  # INFO
    )
    # ``None`` restype: status is reported through INFO, not a return value.
    prototype = ctypes.CFUNCTYPE(None, *arg_types)
    return prototype(routine_addr)

@classmethod
def numba_xgtcon(cls, dtype):
    """
    Build a ctypes callable for the LAPACK routine ``xGTCON``.

    ``xGTCON`` estimates the reciprocal of the condition number of a
    tridiagonal matrix A, using the LU factorization computed by the
    routine wrapped in ``numba_xgttrf``.

    Parameters
    ----------
    dtype
        Forwarded to ``_get_lapack_ptr_and_ptr_type`` to select the
        routine variant and the matching data pointer type.

    Returns
    -------
    A ctypes function object wrapping the routine's address. Every
    argument is passed by pointer and the call itself returns nothing.
    """
    routine_addr, data_ptr = _get_lapack_ptr_and_ptr_type(dtype, "gtcon")
    # NOTE(review): the complex LAPACK variants of xGTCON take no IWORK
    # argument -- confirm this prototype is only used with real dtypes.
    arg_types = (
        _ptr_int,  # NORM
        _ptr_int,  # N
        data_ptr,  # DL
        data_ptr,  # D
        data_ptr,  # DU
        data_ptr,  # DU2
        _ptr_int,  # IPIV
        data_ptr,  # ANORM
        data_ptr,  # RCOND
        data_ptr,  # WORK
        _ptr_int,  # IWORK
        _ptr_int,  # INFO
    )
    # ``None`` restype: status is reported through INFO, not a return value.
    prototype = ctypes.CFUNCTYPE(None, *arg_types)
    return prototype(routine_addr)
Empty file.
Empty file.
66 changes: 66 additions & 0 deletions pytensor/link/numba/dispatch/linalg/decomposition/cholesky.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import numpy as np
from numba.core.extending import overload
from numba.np.linalg import _copy_to_fortran_order, ensure_lapack
from scipy import linalg

from pytensor.link.numba.dispatch.linalg._LAPACK import (
_LAPACK,
_get_underlying_float,
int_ptr_to_val,
val_to_int_ptr,
)
from pytensor.link.numba.dispatch.linalg.utils import _check_scipy_linalg_matrix


def _cholesky(a, lower=False, overwrite_a=False, check_finite=True):
return (
linalg.cholesky(
a, lower=lower, overwrite_a=overwrite_a, check_finite=check_finite
),
0,
)


@overload(_cholesky)
def cholesky_impl(A, lower=0, overwrite_a=False, check_finite=True):
    """
    Numba overload of ``_cholesky`` that calls the LAPACK ``potrf`` wrapper.

    The input type is validated and the LAPACK function pointer is looked
    up once, at overload-resolution time; the returned ``impl`` is what
    gets compiled.

    ``impl`` returns ``(factor, info)``: the factor with the triangle
    opposite to ``lower`` explicitly zeroed, and the INFO value written by
    LAPACK. ``check_finite`` is accepted for signature compatibility but
    is not referenced in the body.
    """
    ensure_lapack()
    _check_scipy_linalg_matrix(A, "cholesky")
    dtype = A.dtype
    real_dtype = _get_underlying_float(dtype)
    potrf = _LAPACK().numba_xpotrf(dtype)

    def impl(A, lower=0, overwrite_a=False, check_finite=True):
        n = np.int32(A.shape[-1])
        if A.shape[-2] != n:
            raise linalg.LinAlgError("Last 2 dimensions of A must be square")

        # Reuse the caller's buffer only when that is permitted AND it is
        # already Fortran-ordered; otherwise work on a Fortran-ordered copy.
        if overwrite_a and A.flags.f_contiguous:
            factor = A
        else:
            factor = _copy_to_fortran_order(A)

        # Every LAPACK argument is passed by pointer.
        uplo_ptr = val_to_int_ptr(ord("L") if lower else ord("U"))
        n_ptr = val_to_int_ptr(n)
        lda_ptr = val_to_int_ptr(n)
        info_ptr = val_to_int_ptr(0)

        potrf(uplo_ptr, n_ptr, factor.view(real_dtype).ctypes, lda_ptr, info_ptr)

        # Clear the opposite triangle explicitly so the result is strictly
        # triangular.
        if lower:
            # Strict upper triangle (row < col).
            for row in range(n):
                for col in range(row + 1, n):
                    factor[row, col] = 0.0
        else:
            # Strict lower triangle (row > col).
            for row in range(1, n):
                for col in range(row):
                    factor[row, col] = 0.0

        return factor, int_ptr_to_val(info_ptr)

    return impl
Empty file.
87 changes: 87 additions & 0 deletions pytensor/link/numba/dispatch/linalg/solve/cholesky.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import numpy as np
from numba.core.extending import overload
from numba.np.linalg import ensure_lapack
from scipy import linalg

from pytensor.link.numba.dispatch.linalg._LAPACK import (
_LAPACK,
_get_underlying_float,
int_ptr_to_val,
val_to_int_ptr,
)
from pytensor.link.numba.dispatch.linalg.solve.utils import _solve_check_input_shapes
from pytensor.link.numba.dispatch.linalg.utils import (
_check_scipy_linalg_matrix,
_copy_to_fortran_order_even_if_1d,
_solve_check,
)


def _cho_solve(
C: np.ndarray, B: np.ndarray, lower: bool, overwrite_b: bool, check_finite: bool
):
"""
Solve a positive-definite linear system using the Cholesky decomposition.
"""
return linalg.cho_solve(
(C, lower), b=B, overwrite_b=overwrite_b, check_finite=check_finite
)


@overload(_cho_solve)
def cho_solve_impl(C, B, lower=False, overwrite_b=False, check_finite=True):
    """
    Numba overload of ``_cho_solve`` that calls the LAPACK ``potrs`` wrapper.

    Both inputs are type-checked and the LAPACK function pointer is looked
    up once, at overload-resolution time; the returned ``impl`` is what
    gets compiled.

    ``impl`` returns the solution with the same dimensionality as ``B``
    (a 1-D ``B`` is promoted to a single column for the call and squeezed
    back afterwards). ``check_finite`` is accepted for signature
    compatibility but is not referenced in the body.
    """
    ensure_lapack()
    _check_scipy_linalg_matrix(C, "cho_solve")
    _check_scipy_linalg_matrix(B, "cho_solve")
    dtype = C.dtype
    real_dtype = _get_underlying_float(dtype)
    potrs = _LAPACK().numba_xpotrs(dtype)

    def impl(C, B, lower=False, overwrite_b=False, check_finite=True):
        _solve_check_input_shapes(C, B)

        n = np.int32(C.shape[-1])

        # The factor is only read, so any contiguous layout is used without
        # copying.
        if C.flags.c_contiguous:
            # An upper/lower triangular c_contiguous is the same as a
            # lower/upper triangular f_contiguous
            # NOTE(review): this reinterpretation is a plain transpose, not a
            # conjugate transpose -- confirm it is valid for complex factors.
            factor = C
            lower = not lower
        elif C.flags.f_contiguous:
            factor = C
        else:
            factor = np.asfortranarray(C)

        # The right-hand side is overwritten with the solution, so it is
        # reused only when allowed and already Fortran-ordered.
        if overwrite_b and B.flags.f_contiguous:
            rhs = B
        else:
            rhs = _copy_to_fortran_order_even_if_1d(B)

        vector_rhs = B.ndim == 1
        if vector_rhs:
            # Present a 1-D right-hand side as a single column.
            rhs = np.expand_dims(rhs, -1)

        n_rhs = 1 if vector_rhs else int(B.shape[-1])

        # Every LAPACK argument is passed by pointer.
        uplo_ptr = val_to_int_ptr(ord("L") if lower else ord("U"))
        n_ptr = val_to_int_ptr(n)
        nrhs_ptr = val_to_int_ptr(n_rhs)
        lda_ptr = val_to_int_ptr(n)
        ldb_ptr = val_to_int_ptr(n)
        info_ptr = val_to_int_ptr(0)

        # NOTE(review): both buffers are viewed with the float type derived
        # from C's dtype -- presumably callers guarantee B shares that dtype.
        potrs(
            uplo_ptr,
            n_ptr,
            nrhs_ptr,
            factor.view(real_dtype).ctypes,
            lda_ptr,
            rhs.view(real_dtype).ctypes,
            ldb_ptr,
            info_ptr,
        )

        _solve_check(n, int_ptr_to_val(info_ptr))

        if vector_rhs:
            return rhs[..., 0]
        return rhs

    return impl
Loading