From 9a11f9f4903135f63d08dc91271231469d1c3578 Mon Sep 17 00:00:00 2001 From: Athan Reines Date: Thu, 3 Nov 2022 00:45:22 -0700 Subject: [PATCH 1/4] Add support for specifying the output array dtype --- spec/API_specification/array_api/linalg.py | 26 ++++++++++++++++++++-- 1 file changed, 24 insertions(+), 2 deletions(-) diff --git a/spec/API_specification/array_api/linalg.py b/spec/API_specification/array_api/linalg.py index 5336d93c6..f6e4b25f4 100644 --- a/spec/API_specification/array_api/linalg.py +++ b/spec/API_specification/array_api/linalg.py @@ -437,10 +437,20 @@ def tensordot(x1: array, x2: array, /, *, axes: Union[int, Tuple[Sequence[int], Alias for :func:`~array_api.tensordot`. """ -def trace(x: array, /, *, offset: int = 0) -> array: +def trace(x: array, /, *, offset: int = 0, dtype: Optional[dtype] = None) -> array: """ Returns the sum along the specified diagonals of a matrix (or a stack of matrices) ``x``. + **Special Cases** + + Let ``N`` equal the number of elements over which to compute the sum. + + - If ``N`` is ``0``, the sum is ``0`` (i.e., the empty sum). + + For floating-point operands, + + - If ``x_i`` is ``NaN``, the sum is ``NaN`` (i.e., ``NaN`` values propagate). + Parameters ---------- x: array @@ -453,6 +463,18 @@ def trace(x: array, /, *, offset: int = 0) -> array: - ``offset < 0``: off-diagonal below the main diagonal. Default: ``0``. + dtype: Optional[dtype] + data type of the returned array. If ``None``, + + - if the default data type corresponding to the data type "kind" (integer or floating-point) of ``x`` has a smaller range of values than the data type of ``x`` (e.g., ``x`` has data type ``int64`` and the default data type is ``int32``, or ``x`` has data type ``uint64`` and the default data type is ``int64``), the returned array must have the same data type as ``x``. + - if ``x`` has a real-valued floating-point data type, the returned array must have the default real-valued floating-point data type. + - if ``x`` has a signed integer data type (e.g., ``int16``), the returned array must have the default integer data type. + - if ``x`` has an unsigned integer data type (e.g., ``uint16``), the returned array must have an unsigned integer data type having the same number of bits as the default integer data type (e.g., if the default integer data type is ``int32``, the returned array must have a ``uint32`` data type). + + If the data type (either specified or resolved) differs from the data type of ``x``, the input array should be cast to the specified data type before computing the sum. Default: ``None``. + + .. note:: + keyword argument is intended to help prevent data type overflows. Returns ------- @@ -463,7 +485,7 @@ def trace(x: array, /, *, offset: int = 0) -> array: out[i, j, k, ..., l] = trace(a[i, j, k, ..., l, :, :]) - The returned array must have the same data type as ``x``. + The returned array must have a data type as described by the ``dtype`` parameter above. """ def vecdot(x1: array, x2: array, /, *, axis: int = None) -> array: From 6b9fa1e7e691543b9a74f3a9cd2a49cf157ed421 Mon Sep 17 00:00:00 2001 From: Athan Reines Date: Thu, 3 Nov 2022 01:00:57 -0700 Subject: [PATCH 2/4] Remove kwarg --- spec/API_specification/array_api/linalg.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/spec/API_specification/array_api/linalg.py b/spec/API_specification/array_api/linalg.py index f6e4b25f4..3f16698de 100644 --- a/spec/API_specification/array_api/linalg.py +++ b/spec/API_specification/array_api/linalg.py @@ -463,18 +463,6 @@ def trace(x: array, /, *, offset: int = 0, dtype: Optional[dtype] = None) -> arr - ``offset < 0``: off-diagonal below the main diagonal. Default: ``0``. - dtype: Optional[dtype] - data type of the returned array. If ``None``, - - - if the default data type corresponding to the data type "kind" (integer or floating-point) of ``x`` has a smaller range of values than the data type of ``x`` (e.g., ``x`` has data type ``int64`` and the default data type is ``int32``, or ``x`` has data type ``uint64`` and the default data type is ``int64``), the returned array must have the same data type as ``x``. - - if ``x`` has a real-valued floating-point data type, the returned array must have the default real-valued floating-point data type. - - if ``x`` has a signed integer data type (e.g., ``int16``), the returned array must have the default integer data type. - - if ``x`` has an unsigned integer data type (e.g., ``uint16``), the returned array must have an unsigned integer data type having the same number of bits as the default integer data type (e.g., if the default integer data type is ``int32``, the returned array must have a ``uint32`` data type). - - If the data type (either specified or resolved) differs from the data type of ``x``, the input array should be cast to the specified data type before computing the sum. Default: ``None``. - - .. note:: - keyword argument is intended to help prevent data type overflows. Returns ------- From 51296918530096564e3f68bb4656f7eb57f5e487 Mon Sep 17 00:00:00 2001 From: Athan Reines Date: Thu, 3 Nov 2022 01:02:09 -0700 Subject: [PATCH 3/4] Revert change --- spec/API_specification/array_api/linalg.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/spec/API_specification/array_api/linalg.py b/spec/API_specification/array_api/linalg.py index 3f16698de..bad36b865 100644 --- a/spec/API_specification/array_api/linalg.py +++ b/spec/API_specification/array_api/linalg.py @@ -441,16 +441,6 @@ def trace(x: array, /, *, offset: int = 0, dtype: Optional[dtype] = None) -> arr """ Returns the sum along the specified diagonals of a matrix (or a stack of matrices) ``x``. - **Special Cases** - - Let ``N`` equal the number of elements over which to compute the sum. - - - If ``N`` is ``0``, the sum is ``0`` (i.e., the empty sum). - - For floating-point operands, - - - If ``x_i`` is ``NaN``, the sum is ``NaN`` (i.e., ``NaN`` values propagate). - Parameters ---------- x: array @@ -463,6 +453,18 @@ def trace(x: array, /, *, offset: int = 0, dtype: Optional[dtype] = None) -> arr - ``offset < 0``: off-diagonal below the main diagonal. Default: ``0``. + dtype: Optional[dtype] + data type of the returned array. If ``None``, + + - if the default data type corresponding to the data type "kind" (integer or floating-point) of ``x`` has a smaller range of values than the data type of ``x`` (e.g., ``x`` has data type ``int64`` and the default data type is ``int32``, or ``x`` has data type ``uint64`` and the default data type is ``int64``), the returned array must have the same data type as ``x``. + - if ``x`` has a real-valued floating-point data type, the returned array must have the default real-valued floating-point data type. + - if ``x`` has a signed integer data type (e.g., ``int16``), the returned array must have the default integer data type. + - if ``x`` has an unsigned integer data type (e.g., ``uint16``), the returned array must have an unsigned integer data type having the same number of bits as the default integer data type (e.g., if the default integer data type is ``int32``, the returned array must have a ``uint32`` data type). + + If the data type (either specified or resolved) differs from the data type of ``x``, the input array should be cast to the specified data type before computing the sum. Default: ``None``. + + .. note:: + keyword argument is intended to help prevent data type overflows. Returns ------- From eddb2680d0720178c39872bfa0c3246b9e815b69 Mon Sep 17 00:00:00 2001 From: Athan Reines Date: Thu, 3 Nov 2022 01:03:22 -0700 Subject: [PATCH 4/4] Fix missing type --- spec/API_specification/array_api/linalg.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/spec/API_specification/array_api/linalg.py b/spec/API_specification/array_api/linalg.py index bad36b865..43cf983e8 100644 --- a/spec/API_specification/array_api/linalg.py +++ b/spec/API_specification/array_api/linalg.py @@ -1,4 +1,4 @@ -from ._types import Literal, Optional, Tuple, Union, Sequence, array +from ._types import Literal, Optional, Tuple, Union, Sequence, array, dtype from .constants import inf def cholesky(x: array, /, *, upper: bool = False) -> array: @@ -441,6 +441,16 @@ def trace(x: array, /, *, offset: int = 0, dtype: Optional[dtype] = None) -> arr """ Returns the sum along the specified diagonals of a matrix (or a stack of matrices) ``x``. + **Special Cases** + + Let ``N`` equal the number of elements over which to compute the sum. + + - If ``N`` is ``0``, the sum is ``0`` (i.e., the empty sum). + + For floating-point operands, + + - If ``x_i`` is ``NaN``, the sum is ``NaN`` (i.e., ``NaN`` values propagate). + Parameters ---------- x: array