RFC: add support for repeating each element of an array #654

Closed
@kgryte

This RFC proposes adding support to the array API specification for repeating each element of an array.

Overview

Based on array comparison data, the API is available in most array libraries. The main exception is PyTorch, which deviates in its naming convention (repeat_interleave vs. repeat in NumPy et al.).

Prior art

Proposal

def repeat(x: array, repeats: Union[int, Sequence[int], array], /, *, axis: Optional[int] = None)
  • repeats: the number of repetitions for each element.

    If axis is not None,

    • if repeats is an array, repeats.shape must broadcast to (x.shape[axis],).
    • if repeats is a sequence of ints, it must have length 1 or x.shape[axis], so that it broadcasts to (x.shape[axis],).
    • if repeats is an integer, it is broadcast to match the size of the specified axis.

    If axis is None,

    • if repeats is an array, repeats.shape must broadcast to (prod(x.shape),).
    • if repeats is a sequence of ints, it must have length 1 or prod(x.shape), so that it broadcasts to (prod(x.shape),).
    • if repeats is an integer, it is broadcast to match the size of the flattened array.
  • axis: specifies the axis along which to repeat values. If None, use a flattened input array and return a flat output array.
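
To make the axis=None rules concrete, here is a minimal pure-Python sketch that operates on nested lists rather than real arrays. The helper names flatten_row_major and repeat_flat are hypothetical and not part of the proposal, and row-major flattening is an assumption (it is one of the open questions in this RFC).

```python
from typing import Any, List, Sequence, Union


def flatten_row_major(x: Any) -> List[Any]:
    """Flatten a nested list in row-major (C) order."""
    if not isinstance(x, list):
        return [x]
    out: List[Any] = []
    for item in x:
        out.extend(flatten_row_major(item))
    return out


def repeat_flat(x: Any, repeats: Union[int, Sequence[int]]) -> List[Any]:
    """Illustrative repeat(x, repeats, axis=None) for nested lists."""
    flat = flatten_row_major(x)
    n = len(flat)
    # Broadcast `repeats` to one count per element of the flattened input.
    if isinstance(repeats, int):
        counts = [repeats] * n
    elif len(repeats) == 1:
        counts = [repeats[0]] * n
    elif len(repeats) == n:
        counts = list(repeats)
    else:
        raise ValueError("repeats must have length 1 or match the number of elements")
    if any(c < 0 for c in counts):
        raise ValueError("repeats must be non-negative")
    out: List[Any] = []
    for value, count in zip(flat, counts):
        out.extend([value] * count)
    return out


print(repeat_flat([[1, 2], [3, 4]], 2))             # [1, 1, 2, 2, 3, 3, 4, 4]
print(repeat_flat([[1, 2], [3, 4]], [1, 0, 2, 3]))  # [1, 3, 3, 4, 4, 4]
```

Each element is repeated in place, so the output keeps the input's element order; this is what distinguishes the operation from tiling, which repeats the whole array.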

Questions

  • Both PyTorch and JAX support a kwarg for specifying the output size in order to avoid stream synchronization (PyTorch) and to allow compilation (JAX). Without such kwarg support, is this API viable? And what are the reasons for needing this kwarg when other array libraries (TensorFlow) omit such a kwarg?
  • When flattening the input array, flatten in row-major order? (precedent: nonzero)
  • Is PyTorch okay adding a repeat function in its main namespace, given the divergence in behavior for torch.Tensor.repeat, which behaves similarly to np.tile?
  • CuPy only allows int, List, and Tuple for repeats, not an array. PyTorch may prefer a list of ints (see pytorch/pytorch#108968, "Unnecessary cuda synchronizations that we should remove in PyTorch").
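
The output-size question above comes down to the result's shape being data-dependent whenever repeats is not a single integer. The following sketch illustrates that point in plain Python; output_length is a hypothetical helper used only for illustration, not an API of any library.

```python
from typing import Sequence, Union


def output_length(n: int, repeats: Union[int, Sequence[int]]) -> int:
    """Length of the flat output when repeating n elements (hypothetical helper)."""
    if isinstance(repeats, int):
        # Known from the input shape alone: no values need to be inspected.
        return n * repeats
    if len(repeats) == 1:
        # A single count broadcast to all n elements.
        return n * repeats[0]
    # Depends on the *values* stored in repeats, not just on its shape.
    return sum(repeats)


print(output_length(4, 3))             # 12
print(output_length(4, [1, 0, 2, 3]))  # 6
```

When repeats is an array that lives on an accelerator, computing this sum means reading its values back to the host, and under tracing the values are not available at all. Those are the cases the output_size keyword of torch.repeat_interleave and the total_repeat_length keyword of jax.numpy.repeat address, by letting the caller state the output length up front.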

Related
