Skip to content

Commit 7df6f8f

Browse files
committed
BUG: fix cholesky upper decomp for complex dtypes
1 parent 5c82ea3 commit 7df6f8f

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

Diff for: array_api_compat/common/_linalg.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from numpy.core.numeric import normalize_axis_tuple
99

10-
from ._aliases import matmul, matrix_transpose, tensordot, vecdot
10+
from ._aliases import matmul, matrix_transpose, tensordot, vecdot, isdtype
1111
from .._internal import get_xp
1212

1313
# These are in the main NumPy namespace but not in numpy.linalg
@@ -55,7 +55,10 @@ def svd(x: ndarray, /, xp, *, full_matrices: bool = True, **kwargs) -> SVDResult
5555
def cholesky(x: ndarray, /, xp, *, upper: bool = False, **kwargs) -> ndarray:
5656
L = xp.linalg.cholesky(x, **kwargs)
5757
if upper:
58-
return get_xp(xp)(matrix_transpose)(L)
58+
U = get_xp(xp)(matrix_transpose)(L)
59+
if get_xp(xp)(isdtype)(U.dtype, 'complex floating'):
60+
U = xp.conj(U)
61+
return U
5962
return L
6063

6164
# The rtol keyword argument of matrix_rank() and pinv() is new from NumPy.

0 commit comments

Comments
 (0)