|
| 1 | +from typing import Any, Dict, Hashable |
| 2 | + |
1 | 3 | import dask.array as da
|
2 | 4 | import numpy as np
|
| 5 | +import xarray as xr |
3 | 6 | from numba import guvectorize
|
| 7 | +from typing_extensions import Literal |
4 | 8 | from xarray import Dataset
|
5 | 9 |
|
6 | 10 | from sgkit.typing import ArrayLike
|
7 | 11 | from sgkit.utils import merge_datasets
|
8 | 12 |
|
| 13 | +Dimension = Literal["samples", "variants"] |
| 14 | + |
9 | 15 |
|
10 | 16 | @guvectorize( # type: ignore
|
11 | 17 | [
|
@@ -162,3 +168,91 @@ def count_variant_alleles(ds: Dataset, merge: bool = True) -> Dataset:
|
162 | 168 | }
|
163 | 169 | )
|
164 | 170 | return merge_datasets(ds, new_ds) if merge else new_ds
|
| 171 | + |
| 172 | + |
| 173 | +def _swap(dim: Dimension) -> Dimension: |
| 174 | + return "samples" if dim == "variants" else "variants" |
| 175 | + |
| 176 | + |
| 177 | +def call_rate(ds: Dataset, dim: Dimension) -> Dataset: |
| 178 | + odim = _swap(dim)[:-1] |
| 179 | + n_called = (~ds["call_genotype_mask"].any(dim="ploidy")).sum(dim=dim) |
| 180 | + return xr.Dataset( |
| 181 | + {f"{odim}_n_called": n_called, f"{odim}_call_rate": n_called / ds.dims[dim]} |
| 182 | + ) |
| 183 | + |
| 184 | + |
| 185 | +def genotype_count(ds: Dataset, dim: Dimension) -> Dataset: |
| 186 | + odim = _swap(dim)[:-1] |
| 187 | + M, G = ds["call_genotype_mask"].any(dim="ploidy"), ds["call_genotype"] |
| 188 | + n_hom_ref = (G == 0).all(dim="ploidy") |
| 189 | + n_hom_alt = ((G > 0) & (G[..., 0] == G)).all(dim="ploidy") |
| 190 | + n_non_ref = (G > 0).any(dim="ploidy") |
| 191 | + n_het = ~(n_hom_alt | n_hom_ref) |
| 192 | + # This would 0 out the `het` case with any missing calls |
| 193 | + agg = lambda x: xr.where(M, False, x).sum(dim=dim) # type: ignore[no-untyped-call] |
| 194 | + return Dataset( |
| 195 | + { |
| 196 | + f"{odim}_n_het": agg(n_het), # type: ignore[no-untyped-call] |
| 197 | + f"{odim}_n_hom_ref": agg(n_hom_ref), # type: ignore[no-untyped-call] |
| 198 | + f"{odim}_n_hom_alt": agg(n_hom_alt), # type: ignore[no-untyped-call] |
| 199 | + f"{odim}_n_non_ref": agg(n_non_ref), # type: ignore[no-untyped-call] |
| 200 | + } |
| 201 | + ) |
| 202 | + |
| 203 | + |
| 204 | +def allele_frequency(ds: Dataset) -> Dataset: |
| 205 | + data_vars: Dict[Hashable, Any] = {} |
| 206 | + # only compute variant allele count if not already in dataset |
| 207 | + if "variant_allele_count" in ds: |
| 208 | + AC = ds["variant_allele_count"] |
| 209 | + else: |
| 210 | + AC = count_variant_alleles(ds, merge=False)["variant_allele_count"] |
| 211 | + data_vars["variant_allele_count"] = AC |
| 212 | + |
| 213 | + M = ds["call_genotype_mask"].stack(calls=("samples", "ploidy")) |
| 214 | + AN = (~M).sum(dim="calls") # type: ignore |
| 215 | + assert AN.shape == (ds.dims["variants"],) |
| 216 | + |
| 217 | + data_vars["variant_allele_total"] = AN |
| 218 | + data_vars["variant_allele_frequency"] = AC / AN |
| 219 | + return Dataset(data_vars) |
| 220 | + |
| 221 | + |
| 222 | +def variant_stats(ds: Dataset, merge: bool = True) -> Dataset: |
| 223 | + """Compute quality control variant statistics from genotype calls. |
| 224 | +
|
| 225 | + Parameters |
| 226 | + ---------- |
| 227 | + ds : Dataset |
| 228 | + Genotype call dataset such as from |
| 229 | + `sgkit.create_genotype_call_dataset`. |
| 230 | + merge : bool, optional |
| 231 | + If True (the default), merge the input dataset and the computed |
| 232 | + output variables into a single dataset. Output variables will |
| 233 | + overwrite any input variables with the same name, and a warning |
| 234 | + will be issued in this case. |
| 235 | + If False, return only the computed output variables. |
| 236 | +
|
| 237 | + Returns |
| 238 | + ------- |
| 239 | + Dataset |
| 240 | + A dataset containing the following variables: |
| 241 | + - `variant_n_called` (variants): The number of samples with called genotypes. |
| 242 | + - `variant_call_rate` (variants): The fraction of samples with called genotypes. |
| 243 | + - `variant_n_het` (variants): The number of samples with heterozygous calls. |
| 244 | + - `variant_n_hom_ref` (variants): The number of samples with homozygous reference calls. |
| 245 | + - `variant_n_hom_alt` (variants): The number of samples with homozygous alternate calls. |
| 246 | + - `variant_n_non_ref` (variants): The number of samples that are not homozygous reference calls. |
| 247 | + - `variant_allele_count` (variants, alleles): The number of occurrences of each allele. |
| 248 | + - `variant_allele_total` (variants): The number of occurrences of all alleles. |
| 249 | + - `variant_allele_frequency` (variants, alleles): The frequency of occurence of each allele. |
| 250 | + """ |
| 251 | + new_ds = xr.merge( |
| 252 | + [ |
| 253 | + call_rate(ds, dim="samples"), |
| 254 | + genotype_count(ds, dim="samples"), |
| 255 | + allele_frequency(ds), |
| 256 | + ] |
| 257 | + ) |
| 258 | + return merge_datasets(ds, new_ds) if merge else new_ds |
0 commit comments