Skip to content

Commit 2024bd1

Browse files
authored
Add support for specifying the output array dtype in linalg.trace (#502)
* Add support for specifying the output array dtype * Remove kwarg * Revert change * Fix missing type
1 parent 7572728 commit 2024bd1

File tree

1 file changed

+25
-3
lines changed

1 file changed

+25
-3
lines changed

Diff for: spec/API_specification/array_api/linalg.py

+25-3
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from ._types import Literal, Optional, Tuple, Union, Sequence, array
1+
from ._types import Literal, Optional, Tuple, Union, Sequence, array, dtype
22
from .constants import inf
33

44
def cholesky(x: array, /, *, upper: bool = False) -> array:
@@ -437,10 +437,20 @@ def tensordot(x1: array, x2: array, /, *, axes: Union[int, Tuple[Sequence[int],
437437
Alias for :func:`~array_api.tensordot`.
438438
"""
439439

440-
def trace(x: array, /, *, offset: int = 0) -> array:
440+
def trace(x: array, /, *, offset: int = 0, dtype: Optional[dtype] = None) -> array:
441441
"""
442442
Returns the sum along the specified diagonals of a matrix (or a stack of matrices) ``x``.
443443
444+
**Special Cases**
445+
446+
Let ``N`` equal the number of elements over which to compute the sum.
447+
448+
- If ``N`` is ``0``, the sum is ``0`` (i.e., the empty sum).
449+
450+
For floating-point operands,
451+
452+
- If ``x_i`` is ``NaN``, the sum is ``NaN`` (i.e., ``NaN`` values propagate).
453+
444454
Parameters
445455
----------
446456
x: array
@@ -453,6 +463,18 @@ def trace(x: array, /, *, offset: int = 0) -> array:
453463
- ``offset < 0``: off-diagonal below the main diagonal.
454464
455465
Default: ``0``.
466+
dtype: Optional[dtype]
467+
data type of the returned array. If ``None``,
468+
469+
- if the default data type corresponding to the data type "kind" (integer or floating-point) of ``x`` has a smaller range of values than the data type of ``x`` (e.g., ``x`` has data type ``int64`` and the default data type is ``int32``, or ``x`` has data type ``uint64`` and the default data type is ``int64``), the returned array must have the same data type as ``x``.
470+
- if ``x`` has a real-valued floating-point data type, the returned array must have the default real-valued floating-point data type.
471+
- if ``x`` has a signed integer data type (e.g., ``int16``), the returned array must have the default integer data type.
472+
- if ``x`` has an unsigned integer data type (e.g., ``uint16``), the returned array must have an unsigned integer data type having the same number of bits as the default integer data type (e.g., if the default integer data type is ``int32``, the returned array must have a ``uint32`` data type).
473+
474+
If the data type (either specified or resolved) differs from the data type of ``x``, the input array should be cast to the specified data type before computing the sum. Default: ``None``.
475+
476+
.. note::
477+
keyword argument is intended to help prevent data type overflows.
456478
457479
Returns
458480
-------
@@ -463,7 +485,7 @@ def trace(x: array, /, *, offset: int = 0) -> array:
463485
464486
out[i, j, k, ..., l] = trace(a[i, j, k, ..., l, :, :])
465487
466-
The returned array must have the same data type as ``x``.
488+
The returned array must have a data type as described by the ``dtype`` parameter above.
467489
"""
468490

469491
def vecdot(x1: array, x2: array, /, *, axis: int = None) -> array:

0 commit comments

Comments
 (0)