Skip to content

Commit 8d77a2a

Browse files
committed
ENH: test can_cast(complex dtypes)
1 parent ad81cf6 commit 8d77a2a

File tree

1 file changed

+19
-3
lines changed

1 file changed

+19
-3
lines changed

Diff for: array_api_tests/test_data_type_functions.py

+19-3
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,18 @@ def non_complex_dtypes():
1919
return xps.boolean_dtypes() | hh.real_dtypes
2020

2121

22+
def numeric_dtypes():
23+
return xps.boolean_dtypes() | hh.real_dtypes | hh.complex_dtypes
24+
25+
2226
def float32(n: Union[int, float]) -> float:
2327
return struct.unpack("!f", struct.pack("!f", float(n)))[0]
2428

2529

30+
def _float_match_complex(complex_dtype):
31+
return xp.float32 if complex_dtype == xp.complex64 else xp.float64
32+
33+
2634
@given(
2735
x_dtype=non_complex_dtypes(),
2836
dtype=non_complex_dtypes(),
@@ -107,7 +115,7 @@ def test_broadcast_to(x, data):
107115
# TODO: test values
108116

109117

110-
@given(_from=non_complex_dtypes(), to=non_complex_dtypes(), data=st.data())
118+
@given(_from=numeric_dtypes(), to=numeric_dtypes(), data=st.data())
111119
def test_can_cast(_from, to, data):
112120
from_ = data.draw(
113121
st.just(_from) | hh.arrays(dtype=_from, shape=hh.shapes()), label="from_"
@@ -127,8 +135,15 @@ def test_can_cast(_from, to, data):
127135
break
128136
assert same_family is not None # sanity check
129137
if same_family:
130-
from_min, from_max = dh.dtype_ranges[_from]
131-
to_min, to_max = dh.dtype_ranges[to]
138+
from_dtype = (_float_match_complex(_from)
139+
if _from in (xp.complex64, xp.complex128)
140+
else _from)
141+
to_dtype = (_float_match_complex(to)
142+
if to in (xp.complex64, xp.complex128)
143+
else to)
144+
145+
from_min, from_max = dh.dtype_ranges[from_dtype]
146+
to_min, to_max = dh.dtype_ranges[to_dtype]
132147
expected = from_min >= to_min and from_max <= to_max
133148
else:
134149
expected = False
@@ -139,6 +154,7 @@ def test_can_cast(_from, to, data):
139154
assert out == expected, f"{out=}, but should be {expected} {f_func}"
140155

141156

157+
142158
@pytest.mark.parametrize("dtype", dh.real_float_dtypes)
143159
def test_finfo(dtype):
144160
out = xp.finfo(dtype)

0 commit comments

Comments
 (0)