Skip to content

Add repeat to the specification #690

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 21 commits into from
Feb 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions spec/draft/API_specification/manipulation_functions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ Objects in API
flip
moveaxis
permute_dims
repeat
reshape
roll
squeeze
Expand Down
48 changes: 48 additions & 0 deletions src/array_api_stubs/_draft/manipulation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
"flip",
"moveaxis",
"permute_dims",
"repeat",
"reshape",
"roll",
"squeeze",
Expand Down Expand Up @@ -159,6 +160,53 @@ def permute_dims(x: array, /, axes: Tuple[int, ...]) -> array:
"""


def repeat(
x: array,
repeats: Union[int, array],
/,
*,
axis: Optional[int] = None,
) -> array:
"""
Repeats each element of an array a specified number of times on a per-element basis.

.. admonition:: Data-dependent output shape
:class: important

When ``repeats`` is an array, the shape of the output array for this function depends on the data values in the ``repeats`` array; hence, array libraries which build computation graphs (e.g., JAX, Dask, etc.) may find this function difficult to implement without knowing the values in ``repeats``. Accordingly, such libraries may choose to omit support for ``repeats`` arrays; however, conforming implementations must support providing a literal ``int``. See :ref:`data-dependent-output-shapes` section for more details.

Parameters
----------
x: array
input array containing elements to repeat.
repeats: Union[int, array]
the number of repetitions for each element.

If ``axis`` is ``None``, let ``N = prod(x.shape)`` and

- if ``repeats`` is an array, ``repeats`` must be broadcast compatible with the shape ``(N,)`` (i.e., be a one-dimensional array having shape ``(1,)`` or ``(N,)``).
- if ``repeats`` is an integer, ``repeats`` must be broadcasted to the shape `(N,)`.

If ``axis`` is not ``None``, let ``M = x.shape[axis]`` and

- if ``repeats`` is an array, ``repeats`` must be broadcast compatible with the shape ``(M,)`` (i.e., be a one-dimensional array having shape ``(1,)`` or ``(M,)``).
- if ``repeats`` is an integer, ``repeats`` must be broadcasted to the shape ``(M,)``.

If ``repeats`` is an array, the array must have an integer data type.

.. note::
For specification-conforming array libraries supporting hardware acceleration, providing an array for ``repeats`` may cause device synchronization due to an unknown output shape. For those array libraries where synchronization concerns are applicable, conforming array libraries are advised to include a warning in their documentation regarding potential performance degradation when ``repeats`` is an array.

axis: Optional[int]
the axis (dimension) along which to repeat elements. If ``axis`` is `None`, the function must flatten the input array ``x`` and then repeat elements of the flattened input array and return the result as a one-dimensional output array. A flattened input array must be flattened in row-major, C-style order. Default: ``None``.

Returns
-------
out: array
an output array containing repeated elements. The returned array must have the same data type as ``x``. If ``axis`` is ``None``, the returned array must be a one-dimensional array; otherwise, the returned array must have the same shape as ``x``, except for the axis (dimension) along which elements were repeated.
"""


def reshape(
x: array, /, shape: Tuple[int, ...], *, copy: Optional[bool] = None
) -> array:
Expand Down