@@ -320,12 +320,8 @@ def slogdet(x: Array, /) -> SlogdetResult:
320
320
def _solve (a : np .ndarray , b : np .ndarray ) -> np .ndarray :
321
321
try :
322
322
from numpy .linalg ._linalg import ( # type: ignore[attr-defined]
323
- _assert_stacked_2d ,
324
- _assert_stacked_square ,
325
- _commonType ,
326
- _makearray ,
327
- _raise_linalgerror_singular ,
328
- isComplexType ,
323
+ _makearray , _assert_stacked_2d , _assert_stacked_square ,
324
+ _commonType , isComplexType , _raise_linalgerror_singular
329
325
)
330
326
except ImportError :
331
327
from numpy .linalg .linalg import ( # type: ignore[attr-defined]
@@ -412,7 +408,8 @@ def trace(x: Array, /, *, offset: int = 0, dtype: DType | None = None) -> Array:
412
408
413
409
# Note: trace always operates on the last two axes, whereas np.trace
414
410
# operates on the first two axes by default
415
- return Array ._new (np .asarray (np .trace (x ._array , offset = offset , axis1 = - 2 , axis2 = - 1 , dtype = np_dtype )), device = x .device )
411
+ res = np .trace (x ._array , offset = offset , axis1 = - 2 , axis2 = - 1 , dtype = np_dtype )
412
+ return Array ._new (np .asarray (res ), device = x .device )
416
413
417
414
# Note: the name here is different from norm(). The array API norm is split
418
415
# into matrix_norm and vector_norm().
0 commit comments