@@ -528,21 +528,27 @@ def nonzero(x: Array, /, xp: Namespace, **kwargs: object) -> tuple[Array, ...]:
528
528
529
529
530
530
def ceil (x : Array , / , xp : Namespace , ** kwargs : object ) -> Array :
531
- if xp .issubdtype (x .dtype , xp .integer ):
532
- return x
533
- return xp .ceil (x , ** kwargs )
531
+ result = xp .ceil (x , ** kwargs )
532
+ if result .dtype != x .dtype :
533
+ # numpy < 2: ceil(int array) is float
534
+ result = xp .asarray (result , dtype = x .dtype )
535
+ return result
534
536
535
537
536
538
def floor (x : Array , / , xp : Namespace , ** kwargs : object ) -> Array :
537
- if xp .issubdtype (x .dtype , xp .integer ):
538
- return x
539
- return xp .floor (x , ** kwargs )
539
+ result = xp .floor (x , ** kwargs )
540
+ if result .dtype != x .dtype :
541
+ # numpy < 2: floor(int array) is float
542
+ result = xp .asarray (result , dtype = x .dtype )
543
+ return result
540
544
541
545
542
546
def trunc (x : Array , / , xp : Namespace , ** kwargs : object ) -> Array :
543
- if xp .issubdtype (x .dtype , xp .integer ):
544
- return x
545
- return xp .trunc (x , ** kwargs )
547
+ result = xp .trunc (x , ** kwargs )
548
+ if result .dtype != x .dtype :
549
+ # numpy < 2: trunc(int array) is float
550
+ result = xp .asarray (result , dtype = x .dtype )
551
+ return result
546
552
547
553
548
554
# linear algebra functions
0 commit comments