diff --git a/spec/API_specification/array_api/linalg.py b/spec/API_specification/array_api/linalg.py index 5336d93c6..43cf983e8 100644 --- a/spec/API_specification/array_api/linalg.py +++ b/spec/API_specification/array_api/linalg.py @@ -1,4 +1,4 @@ -from ._types import Literal, Optional, Tuple, Union, Sequence, array +from ._types import Literal, Optional, Tuple, Union, Sequence, array, dtype from .constants import inf def cholesky(x: array, /, *, upper: bool = False) -> array: @@ -437,10 +437,20 @@ def tensordot(x1: array, x2: array, /, *, axes: Union[int, Tuple[Sequence[int], Alias for :func:`~array_api.tensordot`. """ -def trace(x: array, /, *, offset: int = 0) -> array: +def trace(x: array, /, *, offset: int = 0, dtype: Optional[dtype] = None) -> array: """ Returns the sum along the specified diagonals of a matrix (or a stack of matrices) ``x``. + **Special Cases** + + Let ``N`` equal the number of elements over which to compute the sum. + + - If ``N`` is ``0``, the sum is ``0`` (i.e., the empty sum). + + For floating-point operands, + + - If ``x_i`` is ``NaN``, the sum is ``NaN`` (i.e., ``NaN`` values propagate). + Parameters ---------- x: array @@ -453,6 +463,18 @@ def trace(x: array, /, *, offset: int = 0) -> array: - ``offset < 0``: off-diagonal below the main diagonal. Default: ``0``. + dtype: Optional[dtype] + data type of the returned array. If ``None``, + + - if the default data type corresponding to the data type "kind" (integer or floating-point) of ``x`` has a smaller range of values than the data type of ``x`` (e.g., ``x`` has data type ``int64`` and the default data type is ``int32``, or ``x`` has data type ``uint64`` and the default data type is ``int64``), the returned array must have the same data type as ``x``. + - if ``x`` has a real-valued floating-point data type, the returned array must have the default real-valued floating-point data type. + - if ``x`` has a signed integer data type (e.g., ``int16``), the returned array must have the default integer data type. + - if ``x`` has an unsigned integer data type (e.g., ``uint16``), the returned array must have an unsigned integer data type having the same number of bits as the default integer data type (e.g., if the default integer data type is ``int32``, the returned array must have a ``uint32`` data type). + + If the data type (either specified or resolved) differs from the data type of ``x``, the input array should be cast to the specified data type before computing the sum. Default: ``None``. + + .. note:: + keyword argument is intended to help prevent data type overflows. Returns ------- @@ -463,7 +485,7 @@ def trace(x: array, /, *, offset: int = 0) -> array: out[i, j, k, ..., l] = trace(a[i, j, k, ..., l, :, :]) - The returned array must have the same data type as ``x``. + The returned array must have a data type as described by the ``dtype`` parameter above. """ def vecdot(x1: array, x2: array, /, *, axis: int = None) -> array: