Skip to content

Commit 0c37ce7

Browse files
authored
Merge pull request #137 from asmeurer/signbit-nan
Fix sign() for torch and cupy
2 parents c656782 + b9854a7 commit 0c37ce7

File tree

4 files changed

+24
-5
lines changed

4 files changed

+24
-5
lines changed

Diff for: array_api_compat/cupy/_aliases.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,13 @@ def asarray(
108108

109109
return cp.array(obj, dtype=dtype, **kwargs)
110110

111+
def sign(x: ndarray, /) -> ndarray:
112+
# CuPy sign() does not propagate nans. See
113+
# https://github.com/data-apis/array-api-compat/issues/136
114+
out = cp.sign(x)
115+
out[cp.isnan(x)] = cp.nan
116+
return out
117+
111118
# These functions are completely new here. If the library already has them
112119
# (i.e., numpy 2.0), use the library version instead of our wrapper.
113120
if hasattr(cp, 'vecdot'):
@@ -122,6 +129,6 @@ def asarray(
122129
__all__ = _aliases.__all__ + ['asarray', 'bool', 'acos',
123130
'acosh', 'asin', 'asinh', 'atan', 'atan2',
124131
'atanh', 'bitwise_left_shift', 'bitwise_invert',
125-
'bitwise_right_shift', 'concat', 'pow']
132+
'bitwise_right_shift', 'concat', 'pow', 'sign']
126133

127134
_all_ignore = ['cp', 'get_xp']

Diff for: array_api_compat/torch/_aliases.py

+16-1
Original file line numberDiff line numberDiff line change
@@ -706,6 +706,21 @@ def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) -
706706
axis = 0
707707
return torch.index_select(x, axis, indices, **kwargs)
708708

709+
def sign(x: array, /) -> array:
710+
# torch sign() does not support complex numbers and does not propagate
711+
# nans. See https://github.com/data-apis/array-api-compat/issues/136
712+
if x.dtype.is_complex:
713+
out = x/torch.abs(x)
714+
# sign(0) = 0 but the above formula would give nan
715+
out[x == 0+0j] = 0+0j
716+
return out
717+
else:
718+
out = torch.sign(x)
719+
if x.dtype.is_floating_point:
720+
out[torch.isnan(x)] = torch.nan
721+
return out
722+
723+
709724
__all__ = ['result_type', 'can_cast', 'permute_dims', 'bitwise_invert',
710725
'newaxis', 'conj', 'add', 'atan2', 'bitwise_and',
711726
'bitwise_left_shift', 'bitwise_or', 'bitwise_right_shift',
@@ -719,6 +734,6 @@ def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) -
719734
'broadcast_arrays', 'UniqueAllResult', 'UniqueCountsResult',
720735
'UniqueInverseResult', 'unique_all', 'unique_counts',
721736
'unique_inverse', 'unique_values', 'matmul', 'matrix_transpose',
722-
'vecdot', 'tensordot', 'isdtype', 'take']
737+
'vecdot', 'tensordot', 'isdtype', 'take', 'sign']
723738

724739
_all_ignore = ['torch', 'get_xp']

Diff for: cupy-xfails.txt

-1
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,6 @@ array_api_tests/test_special_cases.py::test_unary[expm1(x_i is -0) -> -0]
160160
array_api_tests/test_special_cases.py::test_unary[floor(x_i is -0) -> -0]
161161
array_api_tests/test_special_cases.py::test_unary[log1p(x_i is -0) -> -0]
162162
array_api_tests/test_special_cases.py::test_unary[round(x_i is -0) -> -0]
163-
array_api_tests/test_special_cases.py::test_unary[sign(x_i is NaN) -> NaN]
164163
array_api_tests/test_special_cases.py::test_unary[sin(x_i is -0) -> -0]
165164
array_api_tests/test_special_cases.py::test_unary[sinh(x_i is -0) -> -0]
166165
array_api_tests/test_special_cases.py::test_unary[sqrt(x_i is -0) -> -0]

Diff for: torch-xfails.txt

-2
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,6 @@ array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity
169169
array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity]
170170
array_api_tests/test_special_cases.py::test_iop[__imod__(x1_i is +0 and x2_i < 0) -> -0]
171171
array_api_tests/test_special_cases.py::test_iop[__imod__(x1_i is -0 and x2_i > 0) -> +0]
172-
array_api_tests/test_special_cases.py::test_unary[sign(x_i is NaN) -> NaN]
173172

174173
# Float correction is not supported by pytorch
175174
# (https://github.com/data-apis/array-api-tests/issues/168)
@@ -186,7 +185,6 @@ array_api_tests/test_statistical_functions.py::test_sum
186185
array_api_tests/test_statistical_functions.py::test_prod
187186

188187
# These functions do not yet support complex numbers
189-
array_api_tests/test_operators_and_elementwise_functions.py::test_sign
190188
array_api_tests/test_operators_and_elementwise_functions.py::test_expm1
191189
array_api_tests/test_operators_and_elementwise_functions.py::test_round
192190
array_api_tests/test_set_functions.py::test_unique_counts

0 commit comments

Comments
 (0)