Skip to content

Commit de39ee0

Browse files
committed
BUG: make ceil,trunc,floor always respect view/copy semantics
1 parent 7b376a0 commit de39ee0

File tree

1 file changed

+15
-9
lines changed

1 file changed

+15
-9
lines changed

array_api_compat/common/_aliases.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -528,21 +528,27 @@ def nonzero(x: Array, /, xp: Namespace, **kwargs: object) -> tuple[Array, ...]:
528528

529529

530530
def ceil(x: Array, /, xp: Namespace, **kwargs: object) -> Array:
531-
if xp.issubdtype(x.dtype, xp.integer):
532-
return x
533-
return xp.ceil(x, **kwargs)
531+
result = xp.ceil(x, **kwargs)
532+
if result.dtype != x.dtype:
533+
# numpy < 2: ceil(int array) is float
534+
result = xp.asarray(result, dtype=x.dtype)
535+
return result
534536

535537

536538
def floor(x: Array, /, xp: Namespace, **kwargs: object) -> Array:
537-
if xp.issubdtype(x.dtype, xp.integer):
538-
return x
539-
return xp.floor(x, **kwargs)
539+
result = xp.floor(x, **kwargs)
540+
if result.dtype != x.dtype:
541+
# numpy < 2: floor(int array) is float
542+
result = xp.asarray(result, dtype=x.dtype)
543+
return result
540544

541545

542546
def trunc(x: Array, /, xp: Namespace, **kwargs: object) -> Array:
543-
if xp.issubdtype(x.dtype, xp.integer):
544-
return x
545-
return xp.trunc(x, **kwargs)
547+
result = xp.trunc(x, **kwargs)
548+
if result.dtype != x.dtype:
549+
# numpy < 2: trunc(int array) is float
550+
result = xp.asarray(result, dtype=x.dtype)
551+
return result
546552

547553

548554
# linear algebra functions

0 commit comments

Comments
 (0)