Skip to content

Commit 980129b

Browse files
committed
Add det/slogdet workaround for dask
1 parent 376038e commit 980129b

File tree

2 files changed

+75
-4
lines changed

2 files changed

+75
-4
lines changed

array_api_compat/dask/array/linalg.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,78 @@ def svdvals(x: Array) -> Array:
6363
vector_norm = get_xp(da)(_linalg.vector_norm)
6464
diagonal = get_xp(da)(_linalg.diagonal)
6565

66+
# Calculate determinant via PLU decomp
67+
def det(x: Array) -> Array:
68+
import scipy.linalg
69+
70+
# L has det 1 so don't need to worry about it
71+
p, _, u = da.linalg.lu(x)
72+
73+
# TODO: numerical stability?
74+
u_det = da.prod(da.diag(u))
75+
76+
# Now, time to calculate determinant of p
77+
78+
# (from reading the source code)
79+
# We know that dask lu decomp forces square chunks
80+
# We also know that the P matrix will only be non-zero
81+
# for a block i, j if and only if i = j
82+
83+
# So we will calculate the determinant of each block on
84+
# the diagonal (of blocks)
85+
86+
# This isn't ideal, but hopefully still lets out of core work
87+
# properly since each block should be able to fit in memory
88+
89+
blocks_shape = p.blocks.shape
90+
n_row_blocks = blocks_shape[0]
91+
92+
p_det = 1
93+
for i in range(n_row_blocks):
94+
p_det *= scipy.linalg.det(p.blocks[i, i].compute())
95+
return p_det * u_det
96+
97+
SlogdetResult = _linalg.SlogdetResult
98+
99+
# Calculate determinant via PLU decomp
100+
def slogdet(x: Array) -> Array:
101+
import scipy.linalg
102+
103+
# L has det 1 so don't need to worry about it
104+
p, _, u = da.linalg.lu(x)
105+
106+
u_diag = da.diag(u)
107+
neg_cnt = (u_diag < 0).sum()
108+
109+
u_logabsdet = da.sum(da.log(da.abs(u_diag)))
110+
111+
# Now, time to calculate determinant of p
112+
113+
# (from reading the source code)
114+
# We know that dask lu decomp forces square chunks
115+
# We also know that the P matrix will only be non-zero
116+
# for a block i, j if and only if i = j
117+
118+
# So we will calculate the determinant of each block on
119+
# the diagonal (of blocks)
120+
121+
# This isn't ideal, but hopefully still lets out of core work
122+
# properly since each block should be able to fit in memory
123+
124+
blocks_shape = p.blocks.shape
125+
n_row_blocks = blocks_shape[0]
126+
127+
sign = 1
128+
for i in range(n_row_blocks):
129+
sign *= scipy.linalg.det(p.blocks[i, i].compute())
130+
131+
if neg_cnt % 2 != 0:
132+
sign *= -1
133+
return SlogdetResult(sign, u_logabsdet)
134+
135+
136+
137+
66138
__all__ = linalg_all + ["trace", "outer", "matmul", "tensordot",
67139
"matrix_transpose", "vecdot", "EighResult",
68140
"QRResult", "SlogdetResult", "SVDResult", "qr",

dask-xfails.txt

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,9 @@ array_api_tests/test_linalg.py::test_cholesky
8080
array_api_tests/test_linalg.py::test_tensordot
8181
# probably same reason for failing as numpy
8282
array_api_tests/test_linalg.py::test_trace
83+
# our version depends on dask's LU, which doesn't support ndim > 2
84+
array_api_tests/test_linalg.py::test_det
85+
array_api_tests/test_linalg.py::test_slogdet
8386

8487
# AssertionError: out.dtype=uint64, but should be uint8 [tensordot(uint8, uint8)]
8588
array_api_tests/test_linalg.py::test_linalg_tensordot
@@ -97,18 +100,14 @@ array_api_tests/test_linalg.py::test_linalg_matmul
97100

98101
# Linalg - these don't exist in dask
99102
array_api_tests/test_signatures.py::test_extension_func_signature[linalg.cross]
100-
array_api_tests/test_signatures.py::test_extension_func_signature[linalg.det]
101103
array_api_tests/test_signatures.py::test_extension_func_signature[linalg.eigh]
102104
array_api_tests/test_signatures.py::test_extension_func_signature[linalg.eigvalsh]
103105
array_api_tests/test_signatures.py::test_extension_func_signature[linalg.matrix_power]
104106
array_api_tests/test_signatures.py::test_extension_func_signature[linalg.pinv]
105-
array_api_tests/test_signatures.py::test_extension_func_signature[linalg.slogdet]
106107
array_api_tests/test_linalg.py::test_cross
107-
array_api_tests/test_linalg.py::test_det
108108
array_api_tests/test_linalg.py::test_eigh
109109
array_api_tests/test_linalg.py::test_eigvalsh
110110
array_api_tests/test_linalg.py::test_pinv
111-
array_api_tests/test_linalg.py::test_slogdet
112111
array_api_tests/test_has_names.py::test_has_names[linalg-cross]
113112
array_api_tests/test_has_names.py::test_has_names[linalg-det]
114113
array_api_tests/test_has_names.py::test_has_names[linalg-eigh]

0 commit comments

Comments
 (0)