Skip to content

Commit d35ee58

Browse files
rgommershonno
authored andcommitted
Fix linalg.trace test, result dtype was incorrect
As the spec says, the output dtype should be the default integer/float/complex dtype just like for `sum`. Given that the reference was implemented with `sum`, avoiding to pass an explicit dtype should be enough to obtain the correct results.
1 parent b6370dc commit d35ee58

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

Diff for: array_api_tests/test_linalg.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -626,7 +626,7 @@ def true_trace(x_stack):
626626
x_stack_diag = [x_stack[i, i + offset] for i in range(diag_size)]
627627
else:
628628
x_stack_diag = [x_stack[i - offset, i] for i in range(diag_size)]
629-
return _array_module.sum(asarray(x_stack_diag, dtype=x.dtype), dtype=x.dtype)
629+
return _array_module.sum(asarray(x_stack_diag, dtype=x.dtype))
630630

631631
_test_stacks(linalg.trace, x, **kw, res=res, dims=0, true_val=true_trace)
632632

0 commit comments

Comments
 (0)