10
10
# mypy: disable-error-code=no-any-decorated
11
11
# pyright: reportUnknownParameterType=false,reportMissingParameterType=false
12
12
13
-
14
- @pytest .mark .parametrize (
13
+ param_assert_equal_close = pytest .mark .parametrize (
15
14
"func" ,
16
15
[
17
16
xp_assert_equal ,
21
20
),
22
21
],
23
22
)
23
+
24
+
25
+ @param_assert_equal_close
24
26
def test_assert_close_equal_basic (xp : ModuleType , func : Callable [..., None ]): # type: ignore[no-any-explicit]
25
27
func (xp .asarray (0 ), xp .asarray (0 ))
26
28
func (xp .asarray ([1 , 2 ]), xp .asarray ([1 , 2 ]))
@@ -40,16 +42,7 @@ def test_assert_close_equal_basic(xp: ModuleType, func: Callable[..., None]): #
40
42
41
43
@pytest .mark .skip_xp_backend (Backend .NUMPY , reason = "test other ns vs. numpy" )
42
44
@pytest .mark .skip_xp_backend (Backend .NUMPY_READONLY , reason = "test other ns vs. numpy" )
43
- @pytest .mark .parametrize (
44
- "func" ,
45
- [
46
- xp_assert_equal ,
47
- pytest .param (
48
- xp_assert_close ,
49
- marks = pytest .mark .skip_xp_backend (Backend .SPARSE , reason = "no isdtype" ),
50
- ),
51
- ],
52
- )
45
+ @param_assert_equal_close
53
46
def test_assert_close_equal_namespace (xp : ModuleType , func : Callable [..., None ]): # type: ignore[no-any-explicit]
54
47
with pytest .raises (AssertionError ):
55
48
func (xp .asarray (0 ), np .asarray (0 ))
@@ -68,3 +61,30 @@ def test_assert_close_tolerance(xp: ModuleType):
68
61
xp_assert_close (xp .asarray ([100.0 ]), xp .asarray ([102.0 ]), atol = 3 )
69
62
with pytest .raises (AssertionError ):
70
63
xp_assert_close (xp .asarray ([100.0 ]), xp .asarray ([102.0 ]), atol = 1 )
64
+
65
+
66
+ @param_assert_equal_close
67
+ @pytest .mark .skip_xp_backend (Backend .SPARSE , reason = "no bool indexing by sparse arrays" )
68
+ def test_assert_close_equal_none_shape (xp : ModuleType , func : Callable [..., None ]): # type: ignore[no-any-explicit]
69
+ """On dask and other lazy backends, test that a shape with NaN's or None's
70
+ can be compared to a real shape.
71
+ """
72
+ a = xp .asarray ([1 , 2 ])
73
+ a = a [a > 1 ]
74
+
75
+ func (a , xp .asarray ([2 ]))
76
+ with pytest .raises (AssertionError ):
77
+ func (a , xp .asarray ([2 , 3 ]))
78
+ with pytest .raises (AssertionError ):
79
+ func (a , xp .asarray (2 ))
80
+ with pytest .raises (AssertionError ):
81
+ func (a , xp .asarray ([3 ]))
82
+
83
+ # Swap actual and desired
84
+ func (xp .asarray ([2 ]), a )
85
+ with pytest .raises (AssertionError ):
86
+ func (xp .asarray ([2 , 3 ]), a )
87
+ with pytest .raises (AssertionError ):
88
+ func (xp .asarray (2 ), a )
89
+ with pytest .raises (AssertionError ):
90
+ func (xp .asarray ([3 ]), a )
0 commit comments