Skip to content

Commit 9d200ea

Browse files
authored
Add repeat to the specification (#690)
1 parent 11273e6 commit 9d200ea

File tree

2 files changed

+49
-0
lines changed

2 files changed

+49
-0
lines changed

Diff for: spec/draft/API_specification/manipulation_functions.rst

+1
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ Objects in API
2525
flip
2626
moveaxis
2727
permute_dims
28+
repeat
2829
reshape
2930
roll
3031
squeeze

Diff for: src/array_api_stubs/_draft/manipulation_functions.py

+48
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
"flip",
77
"moveaxis",
88
"permute_dims",
9+
"repeat",
910
"reshape",
1011
"roll",
1112
"squeeze",
@@ -159,6 +160,53 @@ def permute_dims(x: array, /, axes: Tuple[int, ...]) -> array:
159160
"""
160161

161162

163+
def repeat(
164+
x: array,
165+
repeats: Union[int, array],
166+
/,
167+
*,
168+
axis: Optional[int] = None,
169+
) -> array:
170+
"""
171+
Repeats each element of an array a specified number of times on a per-element basis.
172+
173+
.. admonition:: Data-dependent output shape
174+
:class: important
175+
176+
When ``repeats`` is an array, the shape of the output array for this function depends on the data values in the ``repeats`` array; hence, array libraries which build computation graphs (e.g., JAX, Dask, etc.) may find this function difficult to implement without knowing the values in ``repeats``. Accordingly, such libraries may choose to omit support for ``repeats`` arrays; however, conforming implementations must support providing a literal ``int``. See :ref:`data-dependent-output-shapes` section for more details.
177+
178+
Parameters
179+
----------
180+
x: array
181+
input array containing elements to repeat.
182+
repeats: Union[int, array]
183+
the number of repetitions for each element.
184+
185+
If ``axis`` is ``None``, let ``N = prod(x.shape)`` and
186+
187+
- if ``repeats`` is an array, ``repeats`` must be broadcast compatible with the shape ``(N,)`` (i.e., be a one-dimensional array having shape ``(1,)`` or ``(N,)``).
188+
- if ``repeats`` is an integer, ``repeats`` must be broadcasted to the shape `(N,)`.
189+
190+
If ``axis`` is not ``None``, let ``M = x.shape[axis]`` and
191+
192+
- if ``repeats`` is an array, ``repeats`` must be broadcast compatible with the shape ``(M,)`` (i.e., be a one-dimensional array having shape ``(1,)`` or ``(M,)``).
193+
- if ``repeats`` is an integer, ``repeats`` must be broadcasted to the shape ``(M,)``.
194+
195+
If ``repeats`` is an array, the array must have an integer data type.
196+
197+
.. note::
198+
For specification-conforming array libraries supporting hardware acceleration, providing an array for ``repeats`` may cause device synchronization due to an unknown output shape. For those array libraries where synchronization concerns are applicable, conforming array libraries are advised to include a warning in their documentation regarding potential performance degradation when ``repeats`` is an array.
199+
200+
axis: Optional[int]
201+
the axis (dimension) along which to repeat elements. If ``axis`` is `None`, the function must flatten the input array ``x`` and then repeat elements of the flattened input array and return the result as a one-dimensional output array. A flattened input array must be flattened in row-major, C-style order. Default: ``None``.
202+
203+
Returns
204+
-------
205+
out: array
206+
an output array containing repeated elements. The returned array must have the same data type as ``x``. If ``axis`` is ``None``, the returned array must be a one-dimensional array; otherwise, the returned array must have the same shape as ``x``, except for the axis (dimension) along which elements were repeated.
207+
"""
208+
209+
162210
def reshape(
163211
x: array, /, shape: Tuple[int, ...], *, copy: Optional[bool] = None
164212
) -> array:

0 commit comments

Comments
 (0)