|
6 | 6 | "flip",
|
7 | 7 | "moveaxis",
|
8 | 8 | "permute_dims",
|
| 9 | + "repeat", |
9 | 10 | "reshape",
|
10 | 11 | "roll",
|
11 | 12 | "squeeze",
|
@@ -159,6 +160,53 @@ def permute_dims(x: array, /, axes: Tuple[int, ...]) -> array:
|
159 | 160 | """
|
160 | 161 |
|
161 | 162 |
|
| 163 | +def repeat( |
| 164 | + x: array, |
| 165 | + repeats: Union[int, array], |
| 166 | + /, |
| 167 | + *, |
| 168 | + axis: Optional[int] = None, |
| 169 | +) -> array: |
| 170 | + """ |
| 171 | + Repeats each element of an array a specified number of times on a per-element basis. |
| 172 | +
|
| 173 | + .. admonition:: Data-dependent output shape |
| 174 | + :class: important |
| 175 | +
|
| 176 | + When ``repeats`` is an array, the shape of the output array for this function depends on the data values in the ``repeats`` array; hence, array libraries which build computation graphs (e.g., JAX, Dask, etc.) may find this function difficult to implement without knowing the values in ``repeats``. Accordingly, such libraries may choose to omit support for ``repeats`` arrays; however, conforming implementations must support providing a literal ``int``. See :ref:`data-dependent-output-shapes` section for more details. |
| 177 | +
|
| 178 | + Parameters |
| 179 | + ---------- |
| 180 | + x: array |
| 181 | + input array containing elements to repeat. |
| 182 | + repeats: Union[int, array] |
| 183 | + the number of repetitions for each element. |
| 184 | +
|
| 185 | + If ``axis`` is ``None``, let ``N = prod(x.shape)`` and |
| 186 | +
|
| 187 | + - if ``repeats`` is an array, ``repeats`` must be broadcast compatible with the shape ``(N,)`` (i.e., be a one-dimensional array having shape ``(1,)`` or ``(N,)``). |
| 188 | + - if ``repeats`` is an integer, ``repeats`` must be broadcasted to the shape `(N,)`. |
| 189 | +
|
| 190 | + If ``axis`` is not ``None``, let ``M = x.shape[axis]`` and |
| 191 | +
|
| 192 | + - if ``repeats`` is an array, ``repeats`` must be broadcast compatible with the shape ``(M,)`` (i.e., be a one-dimensional array having shape ``(1,)`` or ``(M,)``). |
| 193 | + - if ``repeats`` is an integer, ``repeats`` must be broadcasted to the shape ``(M,)``. |
| 194 | +
|
| 195 | + If ``repeats`` is an array, the array must have an integer data type. |
| 196 | +
|
| 197 | + .. note:: |
| 198 | + For specification-conforming array libraries supporting hardware acceleration, providing an array for ``repeats`` may cause device synchronization due to an unknown output shape. For those array libraries where synchronization concerns are applicable, conforming array libraries are advised to include a warning in their documentation regarding potential performance degradation when ``repeats`` is an array. |
| 199 | +
|
| 200 | + axis: Optional[int] |
| 201 | + the axis (dimension) along which to repeat elements. If ``axis`` is `None`, the function must flatten the input array ``x`` and then repeat elements of the flattened input array and return the result as a one-dimensional output array. A flattened input array must be flattened in row-major, C-style order. Default: ``None``. |
| 202 | +
|
| 203 | + Returns |
| 204 | + ------- |
| 205 | + out: array |
| 206 | + an output array containing repeated elements. The returned array must have the same data type as ``x``. If ``axis`` is ``None``, the returned array must be a one-dimensional array; otherwise, the returned array must have the same shape as ``x``, except for the axis (dimension) along which elements were repeated. |
| 207 | + """ |
| 208 | + |
| 209 | + |
162 | 210 | def reshape(
|
163 | 211 | x: array, /, shape: Tuple[int, ...], *, copy: Optional[bool] = None
|
164 | 212 | ) -> array:
|
|
0 commit comments