Skip to content

Add support for specifying the output array dtype in linalg.trace #502

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Nov 14, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 25 additions & 3 deletions spec/API_specification/array_api/linalg.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from ._types import Literal, Optional, Tuple, Union, Sequence, array
from ._types import Literal, Optional, Tuple, Union, Sequence, array, dtype
from .constants import inf

def cholesky(x: array, /, *, upper: bool = False) -> array:
Expand Down Expand Up @@ -437,10 +437,20 @@ def tensordot(x1: array, x2: array, /, *, axes: Union[int, Tuple[Sequence[int],
Alias for :func:`~array_api.tensordot`.
"""

def trace(x: array, /, *, offset: int = 0) -> array:
def trace(x: array, /, *, offset: int = 0, dtype: Optional[dtype] = None) -> array:
"""
Returns the sum along the specified diagonals of a matrix (or a stack of matrices) ``x``.

**Special Cases**

Let ``N`` equal the number of elements over which to compute the sum.

- If ``N`` is ``0``, the sum is ``0`` (i.e., the empty sum).

For floating-point operands,

- If ``x_i`` is ``NaN``, the sum is ``NaN`` (i.e., ``NaN`` values propagate).

Parameters
----------
x: array
Expand All @@ -453,6 +463,18 @@ def trace(x: array, /, *, offset: int = 0) -> array:
- ``offset < 0``: off-diagonal below the main diagonal.

Default: ``0``.
dtype: Optional[dtype]
data type of the returned array. If ``None``,

- if the default data type corresponding to the data type "kind" (integer or floating-point) of ``x`` has a smaller range of values than the data type of ``x`` (e.g., ``x`` has data type ``int64`` and the default data type is ``int32``, or ``x`` has data type ``uint64`` and the default data type is ``int64``), the returned array must have the same data type as ``x``.
- if ``x`` has a real-valued floating-point data type, the returned array must have the default real-valued floating-point data type.
- if ``x`` has a signed integer data type (e.g., ``int16``), the returned array must have the default integer data type.
- if ``x`` has an unsigned integer data type (e.g., ``uint16``), the returned array must have an unsigned integer data type having the same number of bits as the default integer data type (e.g., if the default integer data type is ``int32``, the returned array must have a ``uint32`` data type).

If the data type (either specified or resolved) differs from the data type of ``x``, the input array should be cast to the specified data type before computing the sum. Default: ``None``.

.. note::
keyword argument is intended to help prevent data type overflows.

Returns
-------
Expand All @@ -463,7 +485,7 @@ def trace(x: array, /, *, offset: int = 0) -> array:

out[i, j, k, ..., l] = trace(a[i, j, k, ..., l, :, :])

The returned array must have the same data type as ``x``.
The returned array must have a data type as described by the ``dtype`` parameter above.
"""

def vecdot(x1: array, x2: array, /, *, axis: int = None) -> array:
Expand Down