Skip to content

Commit c631fa8

Browse files
committed
BUG: add torch.repeat
1 parent 3e5fdc0 commit c631fa8

File tree

2 files changed

+7
-3
lines changed

2 files changed

+7
-3
lines changed

Diff for: array_api_compat/torch/_aliases.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -574,6 +574,11 @@ def count_nonzero(
574574
return result
575575

576576

577+
# "repeat" is torch.repeat_interleave; also the dim argument
578+
def repeat(x: Array, repeats: int | array, /, *, axis: int | None = None) -> Array:
579+
return torch.repeat_interleave(x, repeats, axis)
580+
581+
577582
def where(
578583
condition: Array,
579584
x1: Array | bool | int | float | complex,
@@ -854,6 +859,6 @@ def sign(x: Array, /) -> Array:
854859
'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult',
855860
'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
856861
'matmul', 'matrix_transpose', 'vecdot', 'tensordot', 'isdtype',
857-
'take', 'take_along_axis', 'sign', 'finfo', 'iinfo']
862+
'take', 'take_along_axis', 'sign', 'finfo', 'iinfo', 'repeat']
858863

859864
_all_ignore = ['torch', 'get_xp']

Diff for: torch-xfails.txt

+1-2
Original file line numberDiff line numberDiff line change
@@ -120,9 +120,8 @@ array_api_tests/test_data_type_functions.py::test_finfo_dtype
120120
array_api_tests/test_data_type_functions.py::test_iinfo_dtype
121121

122122
# 2023.12 support
123-
array_api_tests/test_has_names.py::test_has_names[manipulation-repeat]
123+
# https://github.com/pytorch/pytorch/issues/151311: torch.repeat_interleave rejects short integers
124124
array_api_tests/test_manipulation_functions.py::test_repeat
125-
array_api_tests/test_signatures.py::test_func_signature[repeat]
126125
# Argument 'device' missing from signature
127126
array_api_tests/test_signatures.py::test_func_signature[from_dlpack]
128127
# Argument 'max_version' missing from signature

0 commit comments

Comments
 (0)