1
+ from typing import Any , Dict , Hashable
2
+
1
3
import dask .array as da
2
4
import numpy as np
5
+ import xarray as xr
3
6
from numba import guvectorize
4
7
from typing_extensions import Literal
5
8
from xarray import Dataset
@@ -181,7 +184,7 @@ def genotype_count(ds: Dataset, dim: Dimension) -> Dataset:
181
184
n_het = ~ (n_hom_alt | n_hom_ref )
182
185
# This would 0 out the `het` case with any missing calls
183
186
agg = lambda x : xr .where (M , False , x ).sum (dim = dim ) # type: ignore[no-untyped-call]
184
- return xr . Dataset (
187
+ return Dataset (
185
188
{
186
189
f"{ odim } _n_het" : agg (n_het ), # type: ignore[no-untyped-call]
187
190
f"{ odim } _n_hom_ref" : agg (n_hom_ref ), # type: ignore[no-untyped-call]
@@ -192,26 +195,29 @@ def genotype_count(ds: Dataset, dim: Dimension) -> Dataset:
192
195
193
196
194
197
def allele_frequency(ds: Dataset) -> Dataset:
    """Compute per-variant allele totals and allele frequencies.

    The allele count is taken from ``ds["variant_allele_count"]`` when that
    variable is already present. Otherwise it is computed with
    ``count_variant_alleles`` and included in the returned dataset.

    Parameters
    ----------
    ds
        Dataset containing ``call_genotype_mask`` with ``samples`` and
        ``ploidy`` dimensions, and a ``variants`` dimension.

    Returns
    -------
    Dataset
        A new dataset holding ``variant_allele_total`` and
        ``variant_allele_frequency``. ``variant_allele_count`` is included
        only when it had to be computed here (i.e. it was absent from ``ds``).
    """
    data_vars: Dict[Hashable, Any] = {}
    # Only compute the variant allele count if it is not already in the dataset.
    if "variant_allele_count" in ds:
        AC = ds["variant_allele_count"]
    else:
        AC = count_variant_alleles(ds, merge=False)["variant_allele_count"]
        data_vars["variant_allele_count"] = AC

    # Flatten samples x ploidy into one "calls" axis, then count the entries
    # whose mask value is False for each variant.
    M = ds["call_genotype_mask"].stack(calls=("samples", "ploidy"))
    AN = (~M).sum(dim="calls")  # type: ignore
    # `sizes` is the name -> length mapping; using `Dataset.dims` as a mapping
    # is deprecated in xarray.
    assert AN.shape == (ds.sizes["variants"],)

    data_vars["variant_allele_total"] = AN
    data_vars["variant_allele_frequency"] = AC / AN
    return Dataset(data_vars)
208
213
209
214
210
def variant_stats(ds: Dataset, merge: bool = True) -> Dataset:
    """Compute the per-variant summary statistics for a dataset.

    Combines the results of ``call_rate`` and ``genotype_count`` (both taken
    over the ``samples`` dimension) with ``allele_frequency``.

    Parameters
    ----------
    ds
        Input dataset passed unchanged to each of the statistic functions.
    merge
        If True (the default), return ``ds`` merged with the newly computed
        variables. If False, return only the newly computed variables.

    Returns
    -------
    Dataset
        Either the merged dataset or just the new statistics, depending on
        ``merge``.
    """
    partial_results = [
        call_rate(ds, dim="samples"),
        genotype_count(ds, dim="samples"),
        allele_frequency(ds),
    ]
    new_ds = xr.merge(partial_results)
    if merge:
        return ds.merge(new_ds)
    return new_ds
0 commit comments