15
15
from pytensor .tensor import TensorLike , as_tensor_variable
16
16
from pytensor .tensor import basic as ptb
17
17
from pytensor .tensor import math as ptm
18
+ from pytensor .tensor .basic import diagonal
18
19
from pytensor .tensor .blockwise import Blockwise
19
20
from pytensor .tensor .nlinalg import kron , matrix_dot
20
21
from pytensor .tensor .shape import reshape
@@ -260,10 +261,10 @@ def make_node(self, A, b):
260
261
raise ValueError (f"`b` must have { self .b_ndim } dims; got { b .type } instead." )
261
262
262
263
# Infer dtype by solving the most simple case with 1x1 matrices
263
- inp_arr = [ np . eye ( 1 ). astype ( A . dtype ), np . eye ( 1 ). astype ( b . dtype )]
264
- out_arr = [[ None ]]
265
- self . perform ( None , inp_arr , out_arr )
266
- o_dtype = out_arr [ 0 ][ 0 ] .dtype
264
+ o_dtype = scipy_linalg . solve (
265
+ np . ones (( 1 , 1 ), dtype = A . dtype ),
266
+ np . ones (( 1 ,), dtype = b . dtype ),
267
+ ) .dtype
267
268
x = tensor (dtype = o_dtype , shape = b .type .shape )
268
269
return Apply (self , [A , b ], [x ])
269
270
@@ -315,7 +316,7 @@ def _default_b_ndim(b, b_ndim):
315
316
316
317
b = as_tensor_variable (b )
317
318
if b_ndim is None :
318
- return min (b .ndim , 2 ) # By default assume the core case is a matrix
319
+ return min (b .ndim , 2 ) # By default, assume the core case is a matrix
319
320
320
321
321
322
class CholeskySolve (SolveBase ):
@@ -332,6 +333,19 @@ def __init__(self, **kwargs):
332
333
kwargs .setdefault ("lower" , True )
333
334
super ().__init__ (** kwargs )
334
335
336
+ def make_node (self , * inputs ):
337
+ # Allow base class to do input validation
338
+ super_apply = super ().make_node (* inputs )
339
+ A , b = super_apply .inputs
340
+ [super_out ] = super_apply .outputs
341
+ # The dtype of chol_solve does not match solve, which the base class checks
342
+ dtype = scipy_linalg .cho_solve (
343
+ (np .ones ((1 , 1 ), dtype = A .dtype ), False ),
344
+ np .ones ((1 ,), dtype = b .dtype ),
345
+ ).dtype
346
+ out = tensor (dtype = dtype , shape = super_out .type .shape )
347
+ return Apply (self , [A , b ], [out ])
348
+
335
349
def perform (self , node , inputs , output_storage ):
336
350
C , b = inputs
337
351
rval = scipy_linalg .cho_solve (
@@ -499,8 +513,33 @@ class Solve(SolveBase):
499
513
)
500
514
501
515
def __init__ (self , * , assume_a = "gen" , ** kwargs ):
502
- if assume_a not in ("gen" , "sym" , "her" , "pos" ):
503
- raise ValueError (f"{ assume_a } is not a recognized matrix structure" )
516
+ # Triangular and diagonal are handled outside of Solve
517
+ valid_options = ["gen" , "sym" , "her" , "pos" , "tridiagonal" , "banded" ]
518
+
519
+ assume_a = assume_a .lower ()
520
+ # We use the old names as the different dispatches are more likely to support them
521
+ long_to_short = {
522
+ "general" : "gen" ,
523
+ "symmetric" : "sym" ,
524
+ "hermitian" : "her" ,
525
+ "positive definite" : "pos" ,
526
+ }
527
+ assume_a = long_to_short .get (assume_a , assume_a )
528
+
529
+ if assume_a not in valid_options :
530
+ raise ValueError (
531
+ f"Invalid assume_a: { assume_a } . It must be one of { valid_options } or { list (long_to_short .keys ())} "
532
+ )
533
+
534
+ if assume_a in ("tridiagonal" , "banded" ):
535
+ from scipy import __version__ as sp_version
536
+
537
+ if tuple (map (int , sp_version .split ("." )[:- 1 ])) < (1 , 15 ):
538
+ warnings .warn (
539
+ f"assume_a={ assume_a } requires scipy>=1.5.0. Defaulting to assume_a='gen'." ,
540
+ UserWarning ,
541
+ )
542
+ assume_a = "gen"
504
543
505
544
super ().__init__ (** kwargs )
506
545
self .assume_a = assume_a
@@ -536,10 +575,12 @@ def solve(
536
575
a ,
537
576
b ,
538
577
* ,
539
- assume_a = "gen" ,
540
- lower = False ,
541
- transposed = False ,
542
- check_finite = True ,
578
+ lower : bool = False ,
579
+ overwrite_a : bool = False ,
580
+ overwrite_b : bool = False ,
581
+ check_finite : bool = True ,
582
+ assume_a : str = "gen" ,
583
+ transposed : bool = False ,
543
584
b_ndim : int | None = None ,
544
585
):
545
586
"""Solves the linear equation set ``a * x = b`` for the unknown ``x`` for square ``a`` matrix.
@@ -548,14 +589,19 @@ def solve(
548
589
corresponding string to ``assume_a`` key chooses the dedicated solver.
549
590
The available options are
550
591
551
- =================== ========
552
- generic matrix 'gen'
553
- symmetric 'sym'
554
- hermitian 'her'
555
- positive definite 'pos'
556
- =================== ========
592
+ =================== ================================
593
+ diagonal 'diagonal'
594
+ tridiagonal 'tridiagonal'
595
+ banded 'banded'
596
+ upper triangular 'upper triangular'
597
+ lower triangular 'lower triangular'
598
+ symmetric 'symmetric' (or 'sym')
599
+ hermitian 'hermitian' (or 'her')
600
+ positive definite 'positive definite' (or 'pos')
601
+ general 'general' (or 'gen')
602
+ =================== ================================
557
603
558
- If omitted, ``'gen '`` is the default structure.
604
+ If omitted, ``'general '`` is the default structure.
559
605
560
606
The datatype of the arrays define which solver is called regardless
561
607
of the values. In other words, even when the complex array entries have
@@ -568,23 +614,52 @@ def solve(
568
614
Square input data
569
615
b : (..., N, NRHS) array_like
570
616
Input data for the right hand side.
571
- lower : bool, optional
572
- If True, use only the data contained in the lower triangle of `a`. Default
573
- is to use upper triangle. (ignored for ``'gen'``)
574
- transposed: bool, optional
575
- If True, solves the system A^T x = b. Default is False.
617
+ lower : bool, default False
618
+ Ignored unless ``assume_a`` is one of ``'sym'``, ``'her'``, or ``'pos'``.
619
+ If True, the calculation uses only the data in the lower triangle of `a`;
620
+ entries above the diagonal are ignored. If False (default), the
621
+ calculation uses only the data in the upper triangle of `a`; entries
622
+ below the diagonal are ignored.
623
+ overwrite_a : bool
624
+ Unused by PyTensor. PyTensor will always perform the operation in-place if possible.
625
+ overwrite_b : bool
626
+ Unused by PyTensor. PyTensor will always perform the operation in-place if possible.
576
627
check_finite : bool, optional
577
628
Whether to check that the input matrices contain only finite numbers.
578
629
Disabling may give a performance gain, but may result in problems
579
630
(crashes, non-termination) if the inputs do contain infinities or NaNs.
580
631
assume_a : str, optional
581
632
Valid entries are explained above.
633
+ transposed: bool, default False
634
+ If True, solves the system A^T x = b. Default is False.
582
635
b_ndim : int
583
636
Whether the core case of b is a vector (1) or matrix (2).
584
637
This will influence how batched dimensions are interpreted.
638
+ By default, we assume b_ndim = b.ndim is 2 if b.ndim > 1, else 1.
585
639
"""
640
+ assume_a = assume_a .lower ()
641
+
642
+ if assume_a in ("lower triangular" , "upper triangular" ):
643
+ lower = "lower" in assume_a
644
+ return solve_triangular (
645
+ a ,
646
+ b ,
647
+ lower = lower ,
648
+ trans = transposed ,
649
+ check_finite = check_finite ,
650
+ b_ndim = b_ndim ,
651
+ )
652
+
586
653
b_ndim = _default_b_ndim (b , b_ndim )
587
654
655
+ if assume_a == "diagonal" :
656
+ a_diagonal = diagonal (a , axis1 = - 2 , axis2 = - 1 )
657
+ b_transposed = b [None , :] if b_ndim == 1 else b .mT
658
+ x = (b_transposed / pt .expand_dims (a_diagonal , - 2 )).mT
659
+ if b_ndim == 1 :
660
+ x = x .squeeze (- 1 )
661
+ return x
662
+
588
663
if transposed :
589
664
a = a .mT
590
665
lower = not lower
0 commit comments