@@ -63,6 +63,78 @@ def svdvals(x: Array) -> Array:
63
63
vector_norm = get_xp (da )(_linalg .vector_norm )
64
64
diagonal = get_xp (da )(_linalg .diagonal )
65
65
66
+ # Calculate determinant via PLU decomp
67
+ def det (x : Array ) -> Array :
68
+ import scipy .linalg
69
+
70
+ # L has det 1 so don't need to worry about it
71
+ p , _ , u = da .linalg .lu (x )
72
+
73
+ # TODO: numerical stability?
74
+ u_det = da .prod (da .diag (u ))
75
+
76
+ # Now, time to calculate determinant of p
77
+
78
+ # (from reading the source code)
79
+ # We know that dask lu decomp forces square chunks
80
+ # We also know that the P matrix will only be non-zero
81
+ # for a block i, j if and only if i = j
82
+
83
+ # So we will calculate the determinant of each block on
84
+ # the diagonal (of blocks)
85
+
86
+ # This isn't ideal, but hopefully still lets out of core work
87
+ # properly since each block should be able to fit in memory
88
+
89
+ blocks_shape = p .blocks .shape
90
+ n_row_blocks = blocks_shape [0 ]
91
+
92
+ p_det = 1
93
+ for i in range (n_row_blocks ):
94
+ p_det *= scipy .linalg .det (p .blocks [i , i ].compute ())
95
+ return p_det * u_det
96
+
97
+ SlogdetResult = _linalg .SlogdetResult
98
+
99
+ # Calculate determinant via PLU decomp
100
+ def slogdet (x : Array ) -> Array :
101
+ import scipy .linalg
102
+
103
+ # L has det 1 so don't need to worry about it
104
+ p , _ , u = da .linalg .lu (x )
105
+
106
+ u_diag = da .diag (u )
107
+ neg_cnt = (u_diag < 0 ).sum ()
108
+
109
+ u_logabsdet = da .sum (da .log (da .abs (u_diag )))
110
+
111
+ # Now, time to calculate determinant of p
112
+
113
+ # (from reading the source code)
114
+ # We know that dask lu decomp forces square chunks
115
+ # We also know that the P matrix will only be non-zero
116
+ # for a block i, j if and only if i = j
117
+
118
+ # So we will calculate the determinant of each block on
119
+ # the diagonal (of blocks)
120
+
121
+ # This isn't ideal, but hopefully still lets out of core work
122
+ # properly since each block should be able to fit in memory
123
+
124
+ blocks_shape = p .blocks .shape
125
+ n_row_blocks = blocks_shape [0 ]
126
+
127
+ sign = 1
128
+ for i in range (n_row_blocks ):
129
+ sign *= scipy .linalg .det (p .blocks [i , i ].compute ())
130
+
131
+ if neg_cnt % 2 != 0 :
132
+ sign *= - 1
133
+ return SlogdetResult (sign , u_logabsdet )
134
+
135
+
136
+
137
+
66
138
__all__ = linalg_all + ["trace" , "outer" , "matmul" , "tensordot" ,
67
139
"matrix_transpose" , "vecdot" , "EighResult" ,
68
140
"QRResult" , "SlogdetResult" , "SVDResult" , "qr" ,
0 commit comments