4
4
import xarray as xr
5
5
from xarray import Dataset
6
6
7
- from sgkit .stats .aggregation import count_alleles
7
+ from sgkit .stats .aggregation import count_call_alleles , count_variant_alleles
8
8
from sgkit .testing import simulate_genotype_call_dataset
9
9
from sgkit .typing import ArrayLike
10
10
@@ -20,23 +20,23 @@ def get_dataset(calls: ArrayLike, **kwargs: Any) -> Dataset:
20
20
return ds
21
21
22
22
23
- def test_count_alleles__single_variant_single_sample ():
24
- ac = count_alleles (get_dataset ([[[1 , 0 ]]]))
23
+ def test_count_variant_alleles__single_variant_single_sample ():
24
+ ac = count_variant_alleles (get_dataset ([[[1 , 0 ]]]))
25
25
np .testing .assert_equal (ac , np .array ([[1 , 1 ]]))
26
26
27
27
28
- def test_count_alleles__multi_variant_single_sample ():
29
- ac = count_alleles (get_dataset ([[[0 , 0 ]], [[0 , 1 ]], [[1 , 0 ]], [[1 , 1 ]]]))
28
+ def test_count_variant_alleles__multi_variant_single_sample ():
29
+ ac = count_variant_alleles (get_dataset ([[[0 , 0 ]], [[0 , 1 ]], [[1 , 0 ]], [[1 , 1 ]]]))
30
30
np .testing .assert_equal (ac , np .array ([[2 , 0 ], [1 , 1 ], [1 , 1 ], [0 , 2 ]]))
31
31
32
32
33
- def test_count_alleles__single_variant_multi_sample ():
34
- ac = count_alleles (get_dataset ([[[0 , 0 ], [1 , 0 ], [0 , 1 ], [1 , 1 ]]]))
33
+ def test_count_variant_alleles__single_variant_multi_sample ():
34
+ ac = count_variant_alleles (get_dataset ([[[0 , 0 ], [1 , 0 ], [0 , 1 ], [1 , 1 ]]]))
35
35
np .testing .assert_equal (ac , np .array ([[4 , 4 ]]))
36
36
37
37
38
- def test_count_alleles__multi_variant_multi_sample ():
39
- ac = count_alleles (
38
+ def test_count_variant_alleles__multi_variant_multi_sample ():
39
+ ac = count_variant_alleles (
40
40
get_dataset (
41
41
[
42
42
[[0 , 0 ], [0 , 0 ], [0 , 0 ]],
@@ -49,8 +49,8 @@ def test_count_alleles__multi_variant_multi_sample():
49
49
np .testing .assert_equal (ac , np .array ([[6 , 0 ], [5 , 1 ], [2 , 4 ], [0 , 6 ]]))
50
50
51
51
52
- def test_count_alleles__missing_data ():
53
- ac = count_alleles (
52
+ def test_count_variant_alleles__missing_data ():
53
+ ac = count_variant_alleles (
54
54
get_dataset (
55
55
[
56
56
[[- 1 , - 1 ], [- 1 , - 1 ], [- 1 , - 1 ]],
@@ -63,8 +63,8 @@ def test_count_alleles__missing_data():
63
63
np .testing .assert_equal (ac , np .array ([[0 , 0 ], [2 , 1 ], [1 , 2 ], [0 , 6 ]]))
64
64
65
65
66
- def test_count_alleles__higher_ploidy ():
67
- ac = count_alleles (
66
+ def test_count_variant_alleles__higher_ploidy ():
67
+ ac = count_variant_alleles (
68
68
get_dataset (
69
69
[
70
70
[[- 1 , - 1 , 0 ], [- 1 , - 1 , 1 ], [- 1 , - 1 , 2 ]],
@@ -77,12 +77,108 @@ def test_count_alleles__higher_ploidy():
77
77
np .testing .assert_equal (ac , np .array ([[1 , 1 , 1 , 0 ], [1 , 2 , 2 , 1 ]]))
78
78
79
79
80
- def test_count_alleles__chunked ():
80
+ def test_count_variant_alleles__chunked ():
81
81
rs = np .random .RandomState (0 )
82
82
calls = rs .randint (0 , 1 , size = (50 , 10 , 2 ))
83
83
ds = get_dataset (calls )
84
- ac1 = count_alleles (ds )
84
+ ac1 = count_variant_alleles (ds )
85
85
# Coerce from numpy to multiple chunks in all dimensions
86
86
ds ["call_genotype" ] = ds ["call_genotype" ].chunk (chunks = (5 , 5 , 1 )) # type: ignore[arg-type]
87
- ac2 = count_alleles (ds )
87
+ ac2 = count_variant_alleles (ds )
88
+ xr .testing .assert_equal (ac1 , ac2 ) # type: ignore[no-untyped-call]
89
+
90
+
91
+ def test_count_call_alleles__single_variant_single_sample ():
92
+ ac = count_call_alleles (get_dataset ([[[1 , 0 ]]]))
93
+ np .testing .assert_equal (ac , np .array ([[[1 , 1 ]]]))
94
+
95
+
96
+ def test_count_call_alleles__multi_variant_single_sample ():
97
+ ac = count_call_alleles (get_dataset ([[[0 , 0 ]], [[0 , 1 ]], [[1 , 0 ]], [[1 , 1 ]]]))
98
+ np .testing .assert_equal (ac , np .array ([[[2 , 0 ]], [[1 , 1 ]], [[1 , 1 ]], [[0 , 2 ]]]))
99
+
100
+
101
+ def test_count_call_alleles__single_variant_multi_sample ():
102
+ ac = count_call_alleles (get_dataset ([[[0 , 0 ], [1 , 0 ], [0 , 1 ], [1 , 1 ]]]))
103
+ np .testing .assert_equal (ac , np .array ([[[2 , 0 ], [1 , 1 ], [1 , 1 ], [0 , 2 ]]]))
104
+
105
+
106
+ def test_count_call_alleles__multi_variant_multi_sample ():
107
+ ac = count_call_alleles (
108
+ get_dataset (
109
+ [
110
+ [[0 , 0 ], [0 , 0 ], [0 , 0 ]],
111
+ [[0 , 0 ], [0 , 0 ], [0 , 1 ]],
112
+ [[1 , 1 ], [0 , 1 ], [1 , 0 ]],
113
+ [[1 , 1 ], [1 , 1 ], [1 , 1 ]],
114
+ ]
115
+ )
116
+ )
117
+ np .testing .assert_equal (
118
+ ac ,
119
+ np .array (
120
+ [
121
+ [[2 , 0 ], [2 , 0 ], [2 , 0 ]],
122
+ [[2 , 0 ], [2 , 0 ], [1 , 1 ]],
123
+ [[0 , 2 ], [1 , 1 ], [1 , 1 ]],
124
+ [[0 , 2 ], [0 , 2 ], [0 , 2 ]],
125
+ ]
126
+ ),
127
+ )
128
+
129
+
130
+ def test_count_call_alleles__missing_data ():
131
+ ac = count_call_alleles (
132
+ get_dataset (
133
+ [
134
+ [[- 1 , - 1 ], [- 1 , - 1 ], [- 1 , - 1 ]],
135
+ [[- 1 , - 1 ], [0 , 0 ], [- 1 , 1 ]],
136
+ [[1 , 1 ], [- 1 , - 1 ], [- 1 , 0 ]],
137
+ [[1 , 1 ], [1 , 1 ], [1 , 1 ]],
138
+ ]
139
+ )
140
+ )
141
+ np .testing .assert_equal (
142
+ ac ,
143
+ np .array (
144
+ [
145
+ [[0 , 0 ], [0 , 0 ], [0 , 0 ]],
146
+ [[0 , 0 ], [2 , 0 ], [0 , 1 ]],
147
+ [[0 , 2 ], [0 , 0 ], [1 , 0 ]],
148
+ [[0 , 2 ], [0 , 2 ], [0 , 2 ]],
149
+ ]
150
+ ),
151
+ )
152
+
153
+
154
+ def test_count_call_alleles__higher_ploidy ():
155
+ ac = count_call_alleles (
156
+ get_dataset (
157
+ [
158
+ [[- 1 , - 1 , 0 ], [- 1 , - 1 , 1 ], [- 1 , - 1 , 2 ]],
159
+ [[0 , 1 , 2 ], [1 , 2 , 3 ], [- 1 , - 1 , - 1 ]],
160
+ ],
161
+ n_allele = 4 ,
162
+ n_ploidy = 3 ,
163
+ )
164
+ )
165
+ np .testing .assert_equal (
166
+ ac ,
167
+ np .array (
168
+ [
169
+ [[1 , 0 , 0 , 0 ], [0 , 1 , 0 , 0 ], [0 , 0 , 1 , 0 ]],
170
+ [[1 , 1 , 1 , 0 ], [0 , 1 , 1 , 1 ], [0 , 0 , 0 , 0 ]],
171
+ ]
172
+ ),
173
+ )
174
+
175
+
176
+ def test_count_call_alleles__chunked ():
177
+ rs = np .random .RandomState (0 )
178
+ calls = rs .randint (0 , 1 , size = (50 , 10 , 2 ))
179
+ ds = get_dataset (calls )
180
+ ac1 = count_call_alleles (ds )
181
+ # Coerce from numpy to multiple chunks in all dimensions
182
+ ds ["call_genotype" ] = ds ["call_genotype" ].chunk (chunks = (5 , 5 , 1 )) # type: ignore[arg-type]
183
+ ac2 = count_call_alleles (ds )
88
184
xr .testing .assert_equal (ac1 , ac2 ) # type: ignore[no-untyped-call]
0 commit comments