From 8764700533a0308a065d8f5024121ef8ef0ee3a9 Mon Sep 17 00:00:00 2001 From: Athan Reines Date: Wed, 20 Sep 2023 18:32:30 -0700 Subject: [PATCH 01/19] Add `repeat` to the specification --- .../manipulation_functions.rst | 1 + .../_draft/manipulation_functions.py | 43 ++++++++++++++++++- 2 files changed, 43 insertions(+), 1 deletion(-) diff --git a/spec/draft/API_specification/manipulation_functions.rst b/spec/draft/API_specification/manipulation_functions.rst index 7eb7fa8b0..75f9e1ba6 100644 --- a/spec/draft/API_specification/manipulation_functions.rst +++ b/spec/draft/API_specification/manipulation_functions.rst @@ -25,6 +25,7 @@ Objects in API flip moveaxis permute_dims + repeat reshape roll squeeze diff --git a/src/array_api_stubs/_draft/manipulation_functions.py b/src/array_api_stubs/_draft/manipulation_functions.py index 2bc929134..0130c9009 100644 --- a/src/array_api_stubs/_draft/manipulation_functions.py +++ b/src/array_api_stubs/_draft/manipulation_functions.py @@ -6,6 +6,7 @@ "flip", "moveaxis", "permute_dims", + "repeat", "reshape", "roll", "squeeze", @@ -14,7 +15,7 @@ ] -from ._types import List, Optional, Tuple, Union, array +from ._types import List, Optional, Tuple, Union, Sequence, array def broadcast_arrays(*arrays: array) -> List[array]: @@ -158,6 +159,46 @@ def permute_dims(x: array, /, axes: Tuple[int, ...]) -> array: """ +def repeat( + x: array, + repeats: Union[int, Sequence[int], array], + /, + *, + axis: Optional[int] = None, +) -> array: + """ + Repeats elements of an array. + + Parameters + ---------- + x: array + input array containing elements to repeat. + repeats: Union[int, Sequence[int], array] + the number of repetitions for each element. + + If ``axis`` is ``None``, let ``N = prod(x.shape)`` and + + - if ``repeats`` is an array, ``repeats`` must be broadcast compatible with the shape ``(N)`` (i.e., be a one-dimensional array having shape ``(1)`` or ``(N)``). + - if ``repeats`` is a sequence of integers, ``len(repeats)`` must be broadcast compatible with the shape ``(N)`` (i.e., the number of sequence elements be either ``1`` or ``N``). + - if ``repeats`` is an integer, ``repeats`` must be broadcasted to the shape `(N)`. + + If ``axis`` is not ``None``, let ``M = x.shape[axis]`` and + + - if ``repeats`` is an array, ``repeats`` must be broadcast compatible with the shape ``(M)`` (i.e., be a one-dimensional array having shape ``(1)`` or ``(M)``. + - if ``repeats`` is a sequence of integers, ``len(repeats)`` must be broadcast compatible with the shape ``(M)`` (i.e., the number of sequence elements must be either ``1`` or ``M``). + - if ``repeats`` is an integer, ``repeats`` must be broadcasted to the shape ``(M)``. + + If ``repeats`` is an array, the array must have an integer data type. + axis: Optional[int] + the axis (dimension) along which to repeat elements. If ``axis`` is `None`, the function must repeat elements of a flattened input array ``x`` and return the result as a one-dimensional output array. Default: ``None``. + + Returns + ------- + out: array + an output array containing repeated elements. The returned array must have the same data type as ``x``. If ``axis`` is ``None``, the returned array must be a one-dimensional array; otherwise, the returned array have the same shape as ``x``, except for the axis (dimension) along which elements were repeated. + """ + + def reshape( x: array, /, shape: Tuple[int, ...], *, copy: Optional[bool] = None ) -> array: From 2decb2a73c9d46c426e7ee14fedcd9c0e9c9316c Mon Sep 17 00:00:00 2001 From: Athan Reines Date: Wed, 20 Sep 2023 18:47:16 -0700 Subject: [PATCH 02/19] Add guidance on flattening behavior --- src/array_api_stubs/_draft/manipulation_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/array_api_stubs/_draft/manipulation_functions.py b/src/array_api_stubs/_draft/manipulation_functions.py index 0130c9009..5a7adb211 100644 --- a/src/array_api_stubs/_draft/manipulation_functions.py +++ b/src/array_api_stubs/_draft/manipulation_functions.py @@ -190,7 +190,7 @@ def repeat( If ``repeats`` is an array, the array must have an integer data type. axis: Optional[int] - the axis (dimension) along which to repeat elements. If ``axis`` is `None`, the function must repeat elements of a flattened input array ``x`` and return the result as a one-dimensional output array. Default: ``None``. + the axis (dimension) along which to repeat elements. If ``axis`` is `None`, the function must repeat elements of the flattened input array ``x`` and return the result as a one-dimensional output array. A flattened input array must be flattened in row-major, C-style order. Default: ``None``. Returns ------- From 0ee5df95b23351af9891bf95cfbb67d218683fe4 Mon Sep 17 00:00:00 2001 From: Athan Reines Date: Thu, 21 Sep 2023 01:43:15 -0700 Subject: [PATCH 03/19] Fix missing parenthesis --- src/array_api_stubs/_draft/manipulation_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/array_api_stubs/_draft/manipulation_functions.py b/src/array_api_stubs/_draft/manipulation_functions.py index 5a7adb211..88757b7e9 100644 --- a/src/array_api_stubs/_draft/manipulation_functions.py +++ b/src/array_api_stubs/_draft/manipulation_functions.py @@ -184,7 +184,7 @@ def repeat( If ``axis`` is not ``None``, let ``M = x.shape[axis]`` and - - if ``repeats`` is an array, ``repeats`` must be broadcast compatible with the shape ``(M)`` (i.e., be a one-dimensional array having shape ``(1)`` or ``(M)``. + - if ``repeats`` is an array, ``repeats`` must be broadcast compatible with the shape ``(M)`` (i.e., be a one-dimensional array having shape ``(1)`` or ``(M)``). - if ``repeats`` is a sequence of integers, ``len(repeats)`` must be broadcast compatible with the shape ``(M)`` (i.e., the number of sequence elements must be either ``1`` or ``M``). - if ``repeats`` is an integer, ``repeats`` must be broadcasted to the shape ``(M)``. From 7b31dd85a13b3565dd69b875303d296c485d0264 Mon Sep 17 00:00:00 2001 From: Athan Reines Date: Wed, 18 Oct 2023 16:50:31 -0700 Subject: [PATCH 04/19] Document shape as a proper tuple --- src/array_api_stubs/_draft/manipulation_functions.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/array_api_stubs/_draft/manipulation_functions.py b/src/array_api_stubs/_draft/manipulation_functions.py index 88757b7e9..1dc0359d3 100644 --- a/src/array_api_stubs/_draft/manipulation_functions.py +++ b/src/array_api_stubs/_draft/manipulation_functions.py @@ -178,15 +178,15 @@ def repeat( If ``axis`` is ``None``, let ``N = prod(x.shape)`` and - - if ``repeats`` is an array, ``repeats`` must be broadcast compatible with the shape ``(N)`` (i.e., be a one-dimensional array having shape ``(1)`` or ``(N)``). - - if ``repeats`` is a sequence of integers, ``len(repeats)`` must be broadcast compatible with the shape ``(N)`` (i.e., the number of sequence elements be either ``1`` or ``N``). - - if ``repeats`` is an integer, ``repeats`` must be broadcasted to the shape `(N)`. + - if ``repeats`` is an array, ``repeats`` must be broadcast compatible with the shape ``(N,)`` (i.e., be a one-dimensional array having shape ``(1,)`` or ``(N,)``). + - if ``repeats`` is a sequence of integers, ``len(repeats)`` must be broadcast compatible with the shape ``(N,)`` (i.e., the number of sequence elements be either ``1`` or ``N``). + - if ``repeats`` is an integer, ``repeats`` must be broadcasted to the shape `(N,)`. If ``axis`` is not ``None``, let ``M = x.shape[axis]`` and - - if ``repeats`` is an array, ``repeats`` must be broadcast compatible with the shape ``(M)`` (i.e., be a one-dimensional array having shape ``(1)`` or ``(M)``). - - if ``repeats`` is a sequence of integers, ``len(repeats)`` must be broadcast compatible with the shape ``(M)`` (i.e., the number of sequence elements must be either ``1`` or ``M``). - - if ``repeats`` is an integer, ``repeats`` must be broadcasted to the shape ``(M)``. + - if ``repeats`` is an array, ``repeats`` must be broadcast compatible with the shape ``(M,)`` (i.e., be a one-dimensional array having shape ``(1,)`` or ``(M,)``). + - if ``repeats`` is a sequence of integers, ``len(repeats)`` must be broadcast compatible with the shape ``(M,)`` (i.e., the number of sequence elements must be either ``1`` or ``M``). + - if ``repeats`` is an integer, ``repeats`` must be broadcasted to the shape ``(M,)``. If ``repeats`` is an array, the array must have an integer data type. axis: Optional[int] From cb7417eb10abc5276b2c5a6068f095d37e893d9a Mon Sep 17 00:00:00 2001 From: Athan Reines Date: Wed, 18 Oct 2023 17:08:52 -0700 Subject: [PATCH 05/19] Add note advising accelerator libraries to provide a warning in their documentation --- src/array_api_stubs/_draft/manipulation_functions.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/array_api_stubs/_draft/manipulation_functions.py b/src/array_api_stubs/_draft/manipulation_functions.py index 1dc0359d3..76eed7b21 100644 --- a/src/array_api_stubs/_draft/manipulation_functions.py +++ b/src/array_api_stubs/_draft/manipulation_functions.py @@ -189,6 +189,11 @@ def repeat( - if ``repeats`` is an integer, ``repeats`` must be broadcasted to the shape ``(M,)``. If ``repeats`` is an array, the array must have an integer data type. + + + .. note:: + For specification-conforming array libraries supporting hardware acceleration, providing an array of ``repeats`` may cause device synchronization due to an unknown output shape. Conforming array libraries are advised to include a warning in their documentation regarding potential performance degradation when ``repeats`` is an array. + axis: Optional[int] the axis (dimension) along which to repeat elements. If ``axis`` is `None`, the function must repeat elements of the flattened input array ``x`` and return the result as a one-dimensional output array. A flattened input array must be flattened in row-major, C-style order. Default: ``None``. From 23d0514bba649a7a26cfcf3818b0422f8d5015af Mon Sep 17 00:00:00 2001 From: Athan Reines Date: Thu, 19 Oct 2023 02:36:58 -0700 Subject: [PATCH 06/19] Fix note --- src/array_api_stubs/_draft/manipulation_functions.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/array_api_stubs/_draft/manipulation_functions.py b/src/array_api_stubs/_draft/manipulation_functions.py index 76eed7b21..1358d21ec 100644 --- a/src/array_api_stubs/_draft/manipulation_functions.py +++ b/src/array_api_stubs/_draft/manipulation_functions.py @@ -191,8 +191,8 @@ def repeat( If ``repeats`` is an array, the array must have an integer data type. - .. note:: - For specification-conforming array libraries supporting hardware acceleration, providing an array of ``repeats`` may cause device synchronization due to an unknown output shape. Conforming array libraries are advised to include a warning in their documentation regarding potential performance degradation when ``repeats`` is an array. + .. note:: + For specification-conforming array libraries supporting hardware acceleration, providing an array of ``repeats`` may cause device synchronization due to an unknown output shape. Conforming array libraries are advised to include a warning in their documentation regarding potential performance degradation when ``repeats`` is an array. axis: Optional[int] the axis (dimension) along which to repeat elements. If ``axis`` is `None`, the function must repeat elements of the flattened input array ``x`` and return the result as a one-dimensional output array. A flattened input array must be flattened in row-major, C-style order. Default: ``None``. From dccf3662ee1a8e97496db2e0aab1c10e5fc754cf Mon Sep 17 00:00:00 2001 From: Athan Reines Date: Thu, 25 Jan 2024 04:12:01 -0800 Subject: [PATCH 07/19] Update copy --- src/array_api_stubs/_draft/manipulation_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/array_api_stubs/_draft/manipulation_functions.py b/src/array_api_stubs/_draft/manipulation_functions.py index e72debaaa..98df7c106 100644 --- a/src/array_api_stubs/_draft/manipulation_functions.py +++ b/src/array_api_stubs/_draft/manipulation_functions.py @@ -196,7 +196,7 @@ def repeat( For specification-conforming array libraries supporting hardware acceleration, providing an array of ``repeats`` may cause device synchronization due to an unknown output shape. Conforming array libraries are advised to include a warning in their documentation regarding potential performance degradation when ``repeats`` is an array. axis: Optional[int] - the axis (dimension) along which to repeat elements. If ``axis`` is `None`, the function must repeat elements of the flattened input array ``x`` and return the result as a one-dimensional output array. A flattened input array must be flattened in row-major, C-style order. Default: ``None``. + the axis (dimension) along which to repeat elements. If ``axis`` is `None`, the function must flatten the input array ``x`` and then repeat elements of the flattened input array and return the result as a one-dimensional output array. A flattened input array must be flattened in row-major, C-style order. Default: ``None``. Returns ------- From b93274f73f56d6813f3b5fcc647f0c1aa49e19f6 Mon Sep 17 00:00:00 2001 From: Athan Reines Date: Thu, 25 Jan 2024 04:17:19 -0800 Subject: [PATCH 08/19] Add admonition regarding data-dependent output shape --- src/array_api_stubs/_draft/manipulation_functions.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/array_api_stubs/_draft/manipulation_functions.py b/src/array_api_stubs/_draft/manipulation_functions.py index 98df7c106..d4cac8c24 100644 --- a/src/array_api_stubs/_draft/manipulation_functions.py +++ b/src/array_api_stubs/_draft/manipulation_functions.py @@ -170,6 +170,11 @@ def repeat( """ Repeats elements of an array. + .. admonition:: Data-dependent output shape + :class: important + + The shape of the output array for this function depends on the data values in the ``repeats`` array; hence, array libraries which build computation graphs (e.g., JAX, Dask, etc.) may find this function difficult to implement without knowing array values. Accordingly, such libraries may choose to omit this function. See :ref:`data-dependent-output-shapes` section for more details. + Parameters ---------- x: array From d27cdad61743a3af0783a39dd9de475cf98e85ad Mon Sep 17 00:00:00 2001 From: Athan Reines Date: Thu, 25 Jan 2024 04:18:31 -0800 Subject: [PATCH 09/19] Remove admonition, as covered by subsequent note --- src/array_api_stubs/_draft/manipulation_functions.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/array_api_stubs/_draft/manipulation_functions.py b/src/array_api_stubs/_draft/manipulation_functions.py index d4cac8c24..98df7c106 100644 --- a/src/array_api_stubs/_draft/manipulation_functions.py +++ b/src/array_api_stubs/_draft/manipulation_functions.py @@ -170,11 +170,6 @@ def repeat( """ Repeats elements of an array. - .. admonition:: Data-dependent output shape - :class: important - - The shape of the output array for this function depends on the data values in the ``repeats`` array; hence, array libraries which build computation graphs (e.g., JAX, Dask, etc.) may find this function difficult to implement without knowing array values. Accordingly, such libraries may choose to omit this function. See :ref:`data-dependent-output-shapes` section for more details. - Parameters ---------- x: array From 8f9c063c2544ecc315b7b497d1e0d14721ab8858 Mon Sep 17 00:00:00 2001 From: Athan Reines Date: Thu, 25 Jan 2024 04:19:01 -0800 Subject: [PATCH 10/19] Update copy --- src/array_api_stubs/_draft/manipulation_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/array_api_stubs/_draft/manipulation_functions.py b/src/array_api_stubs/_draft/manipulation_functions.py index 98df7c106..960dca59b 100644 --- a/src/array_api_stubs/_draft/manipulation_functions.py +++ b/src/array_api_stubs/_draft/manipulation_functions.py @@ -193,7 +193,7 @@ def repeat( .. note:: - For specification-conforming array libraries supporting hardware acceleration, providing an array of ``repeats`` may cause device synchronization due to an unknown output shape. Conforming array libraries are advised to include a warning in their documentation regarding potential performance degradation when ``repeats`` is an array. + For specification-conforming array libraries supporting hardware acceleration, providing an array for ``repeats`` may cause device synchronization due to an unknown output shape. Conforming array libraries are advised to include a warning in their documentation regarding potential performance degradation when ``repeats`` is an array. axis: Optional[int] the axis (dimension) along which to repeat elements. If ``axis`` is `None`, the function must flatten the input array ``x`` and then repeat elements of the flattened input array and return the result as a one-dimensional output array. A flattened input array must be flattened in row-major, C-style order. Default: ``None``. From 6e0064ed417c71a77e74aa43b2d0fecce65f0be4 Mon Sep 17 00:00:00 2001 From: Athan Reines Date: Wed, 7 Feb 2024 20:57:00 -0800 Subject: [PATCH 11/19] Update copy --- src/array_api_stubs/_draft/manipulation_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/array_api_stubs/_draft/manipulation_functions.py b/src/array_api_stubs/_draft/manipulation_functions.py index 960dca59b..4807ccce3 100644 --- a/src/array_api_stubs/_draft/manipulation_functions.py +++ b/src/array_api_stubs/_draft/manipulation_functions.py @@ -193,7 +193,7 @@ def repeat( .. note:: - For specification-conforming array libraries supporting hardware acceleration, providing an array for ``repeats`` may cause device synchronization due to an unknown output shape. Conforming array libraries are advised to include a warning in their documentation regarding potential performance degradation when ``repeats`` is an array. + For specification-conforming array libraries supporting hardware acceleration, providing an array for ``repeats`` may cause device synchronization due to an unknown output shape. For those array libraries where synchronization concerns are applicable, conforming array libraries are advised to include a warning in their documentation regarding potential performance degradation when ``repeats`` is an array. axis: Optional[int] the axis (dimension) along which to repeat elements. If ``axis`` is `None`, the function must flatten the input array ``x`` and then repeat elements of the flattened input array and return the result as a one-dimensional output array. A flattened input array must be flattened in row-major, C-style order. Default: ``None``. From 5c5e34ace7e52e46a839af7a3baa32bcf816bbe1 Mon Sep 17 00:00:00 2001 From: Athan Reines Date: Tue, 13 Feb 2024 02:16:05 -0800 Subject: [PATCH 12/19] style: fix indentation level --- src/array_api_stubs/_draft/manipulation_functions.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/array_api_stubs/_draft/manipulation_functions.py b/src/array_api_stubs/_draft/manipulation_functions.py index 4807ccce3..da3e8315e 100644 --- a/src/array_api_stubs/_draft/manipulation_functions.py +++ b/src/array_api_stubs/_draft/manipulation_functions.py @@ -191,9 +191,8 @@ def repeat( If ``repeats`` is an array, the array must have an integer data type. - - .. note:: - For specification-conforming array libraries supporting hardware acceleration, providing an array for ``repeats`` may cause device synchronization due to an unknown output shape. For those array libraries where synchronization concerns are applicable, conforming array libraries are advised to include a warning in their documentation regarding potential performance degradation when ``repeats`` is an array. + .. note:: + For specification-conforming array libraries supporting hardware acceleration, providing an array for ``repeats`` may cause device synchronization due to an unknown output shape. For those array libraries where synchronization concerns are applicable, conforming array libraries are advised to include a warning in their documentation regarding potential performance degradation when ``repeats`` is an array. axis: Optional[int] the axis (dimension) along which to repeat elements. If ``axis`` is `None`, the function must flatten the input array ``x`` and then repeat elements of the flattened input array and return the result as a one-dimensional output array. A flattened input array must be flattened in row-major, C-style order. Default: ``None``. From 8d7d7008f5e277192e7734614e934f244913ab35 Mon Sep 17 00:00:00 2001 From: Athan Reines Date: Tue, 20 Feb 2024 02:15:10 -0800 Subject: [PATCH 13/19] docs: add data-dependent shape admonition --- src/array_api_stubs/_draft/manipulation_functions.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/array_api_stubs/_draft/manipulation_functions.py b/src/array_api_stubs/_draft/manipulation_functions.py index da3e8315e..82abde3e6 100644 --- a/src/array_api_stubs/_draft/manipulation_functions.py +++ b/src/array_api_stubs/_draft/manipulation_functions.py @@ -170,6 +170,11 @@ def repeat( """ Repeats elements of an array. + .. admonition:: Data-dependent output shape + :class: important + + The shape of the output array for this function depend on the data values in the `repeats` array; hence, array libraries which build computation graphs (e.g., JAX, Dask, etc.) may find this function difficult to implement without knowing array values. Accordingly, such libraries may choose to omit this function. See :ref:`data-dependent-output-shapes` section for more details. + Parameters ---------- x: array From fa3fbd3c9917051f685a462f2b1d508c6ed80167 Mon Sep 17 00:00:00 2001 From: Athan Reines Date: Tue, 20 Feb 2024 02:16:49 -0800 Subject: [PATCH 14/19] docs: fix grammar --- src/array_api_stubs/_draft/manipulation_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/array_api_stubs/_draft/manipulation_functions.py b/src/array_api_stubs/_draft/manipulation_functions.py index 82abde3e6..a18b259ee 100644 --- a/src/array_api_stubs/_draft/manipulation_functions.py +++ b/src/array_api_stubs/_draft/manipulation_functions.py @@ -173,7 +173,7 @@ def repeat( .. admonition:: Data-dependent output shape :class: important - The shape of the output array for this function depend on the data values in the `repeats` array; hence, array libraries which build computation graphs (e.g., JAX, Dask, etc.) may find this function difficult to implement without knowing array values. Accordingly, such libraries may choose to omit this function. See :ref:`data-dependent-output-shapes` section for more details. + The shape of the output array for this function depends on the data values in the `repeats` array; hence, array libraries which build computation graphs (e.g., JAX, Dask, etc.) may find this function difficult to implement without knowing array values. Accordingly, such libraries may choose to omit this function. See :ref:`data-dependent-output-shapes` section for more details. Parameters ---------- From 35e9d44811895815366655c7762111e1139224e9 Mon Sep 17 00:00:00 2001 From: Athan Reines Date: Tue, 20 Feb 2024 22:47:04 -0800 Subject: [PATCH 15/19] remove: drop support for sequences --- src/array_api_stubs/_draft/manipulation_functions.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/array_api_stubs/_draft/manipulation_functions.py b/src/array_api_stubs/_draft/manipulation_functions.py index a18b259ee..2d86d8f5f 100644 --- a/src/array_api_stubs/_draft/manipulation_functions.py +++ b/src/array_api_stubs/_draft/manipulation_functions.py @@ -162,7 +162,7 @@ def permute_dims(x: array, /, axes: Tuple[int, ...]) -> array: def repeat( x: array, - repeats: Union[int, Sequence[int], array], + repeats: Union[int, array], /, *, axis: Optional[int] = None, @@ -173,25 +173,23 @@ def repeat( .. admonition:: Data-dependent output shape :class: important - The shape of the output array for this function depends on the data values in the `repeats` array; hence, array libraries which build computation graphs (e.g., JAX, Dask, etc.) may find this function difficult to implement without knowing array values. Accordingly, such libraries may choose to omit this function. See :ref:`data-dependent-output-shapes` section for more details. + Unless ``repeats`` is a literal ``int``, the shape of the output array for this function depends on the data values in the ``repeats`` array; hence, array libraries which build computation graphs (e.g., JAX, Dask, etc.) may find this function difficult to implement without knowing array values. Accordingly, such libraries may choose to omit this function. See :ref:`data-dependent-output-shapes` section for more details. Parameters ---------- x: array input array containing elements to repeat. - repeats: Union[int, Sequence[int], array] + repeats: Union[int, array] the number of repetitions for each element. If ``axis`` is ``None``, let ``N = prod(x.shape)`` and - if ``repeats`` is an array, ``repeats`` must be broadcast compatible with the shape ``(N,)`` (i.e., be a one-dimensional array having shape ``(1,)`` or ``(N,)``). - - if ``repeats`` is a sequence of integers, ``len(repeats)`` must be broadcast compatible with the shape ``(N,)`` (i.e., the number of sequence elements be either ``1`` or ``N``). - if ``repeats`` is an integer, ``repeats`` must be broadcasted to the shape `(N,)`. If ``axis`` is not ``None``, let ``M = x.shape[axis]`` and - if ``repeats`` is an array, ``repeats`` must be broadcast compatible with the shape ``(M,)`` (i.e., be a one-dimensional array having shape ``(1,)`` or ``(M,)``). - - if ``repeats`` is a sequence of integers, ``len(repeats)`` must be broadcast compatible with the shape ``(M,)`` (i.e., the number of sequence elements must be either ``1`` or ``M``). - if ``repeats`` is an integer, ``repeats`` must be broadcasted to the shape ``(M,)``. If ``repeats`` is an array, the array must have an integer data type. From 03099a1e174f2dc9d7e4d3c5f3168dfd649c4604 Mon Sep 17 00:00:00 2001 From: Athan Reines Date: Tue, 20 Feb 2024 22:56:01 -0800 Subject: [PATCH 16/19] docs: update note --- src/array_api_stubs/_draft/manipulation_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/array_api_stubs/_draft/manipulation_functions.py b/src/array_api_stubs/_draft/manipulation_functions.py index 2d86d8f5f..8bd24a70a 100644 --- a/src/array_api_stubs/_draft/manipulation_functions.py +++ b/src/array_api_stubs/_draft/manipulation_functions.py @@ -173,7 +173,7 @@ def repeat( .. admonition:: Data-dependent output shape :class: important - Unless ``repeats`` is a literal ``int``, the shape of the output array for this function depends on the data values in the ``repeats`` array; hence, array libraries which build computation graphs (e.g., JAX, Dask, etc.) may find this function difficult to implement without knowing array values. Accordingly, such libraries may choose to omit this function. See :ref:`data-dependent-output-shapes` section for more details. + When ``repeats`` is an array, the shape of the output array for this function depends on the data values in the ``repeats`` array; hence, array libraries which build computation graphs (e.g., JAX, Dask, etc.) may find this function difficult to implement without knowing the values in ``repeats``. Accordingly, such libraries may choose to omit support for ``repeats`` arrays; however, conforming implementations must support providing a literal ``int``. See :ref:`data-dependent-output-shapes` section for more details. Parameters ---------- From e377c4ebe622a117a3b20e693f3ef78de2000b07 Mon Sep 17 00:00:00 2001 From: Athan Reines Date: Tue, 20 Feb 2024 22:59:25 -0800 Subject: [PATCH 17/19] refactor: remove unused type --- src/array_api_stubs/_draft/manipulation_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/array_api_stubs/_draft/manipulation_functions.py b/src/array_api_stubs/_draft/manipulation_functions.py index 8bd24a70a..db1bbd58e 100644 --- a/src/array_api_stubs/_draft/manipulation_functions.py +++ b/src/array_api_stubs/_draft/manipulation_functions.py @@ -16,7 +16,7 @@ ] -from ._types import List, Optional, Tuple, Union, Sequence, array +from ._types import List, Optional, Tuple, Union, array def broadcast_arrays(*arrays: array) -> List[array]: From 0418959d3f29621cd5ac8a4cdbb4f91f94743c6e Mon Sep 17 00:00:00 2001 From: Athan Reines Date: Wed, 21 Feb 2024 13:27:48 -0800 Subject: [PATCH 18/19] docs: fix missing word --- src/array_api_stubs/_draft/manipulation_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/array_api_stubs/_draft/manipulation_functions.py b/src/array_api_stubs/_draft/manipulation_functions.py index db1bbd58e..aab1bacf7 100644 --- a/src/array_api_stubs/_draft/manipulation_functions.py +++ b/src/array_api_stubs/_draft/manipulation_functions.py @@ -203,7 +203,7 @@ def repeat( Returns ------- out: array - an output array containing repeated elements. The returned array must have the same data type as ``x``. If ``axis`` is ``None``, the returned array must be a one-dimensional array; otherwise, the returned array have the same shape as ``x``, except for the axis (dimension) along which elements were repeated. + an output array containing repeated elements. The returned array must have the same data type as ``x``. If ``axis`` is ``None``, the returned array must be a one-dimensional array; otherwise, the returned array must have the same shape as ``x``, except for the axis (dimension) along which elements were repeated. """ From 8c89fdeb4a871012192166c74c32b587122ccfc4 Mon Sep 17 00:00:00 2001 From: Athan Reines Date: Wed, 21 Feb 2024 13:33:09 -0800 Subject: [PATCH 19/19] docs: clarify function behavior --- src/array_api_stubs/_draft/manipulation_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/array_api_stubs/_draft/manipulation_functions.py b/src/array_api_stubs/_draft/manipulation_functions.py index aab1bacf7..4d7a17dda 100644 --- a/src/array_api_stubs/_draft/manipulation_functions.py +++ b/src/array_api_stubs/_draft/manipulation_functions.py @@ -168,7 +168,7 @@ def repeat( axis: Optional[int] = None, ) -> array: """ - Repeats elements of an array. + Repeats each element of an array a specified number of times on a per-element basis. .. admonition:: Data-dependent output shape :class: important