Skip to content

Commit cc4f4cd

Browse files
committed
Rebase and update
1 parent 3524b8f commit cc4f4cd

File tree

3 files changed

+7
-5
lines changed

3 files changed

+7
-5
lines changed

docs/api.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,9 @@ Methods
2727
gwas_linear_regression
2828
hardy_weinberg_test
2929
regenie
30+
variant_stats
3031

3132
Utilities
32-
=========
3333

3434
.. autosummary::
3535
:toctree: generated/

sgkit/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
)
99
from .display import display_genotypes
1010
from .io.vcfzarr_reader import read_vcfzarr
11-
from .stats.aggregation import count_call_alleles, count_variant_alleles
11+
from .stats.aggregation import count_call_alleles, count_variant_alleles, variant_stats
1212
from .stats.association import gwas_linear_regression
1313
from .stats.hwe import hardy_weinberg_test
1414
from .stats.regenie import regenie
@@ -27,4 +27,5 @@
2727
"read_vcfzarr",
2828
"regenie",
2929
"hardy_weinberg_test",
30+
"variant_stats",
3031
]

sgkit/stats/aggregation.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -175,10 +175,11 @@ def call_rate(ds: Dataset, dim: Dimension) -> Dataset:
175175
def genotype_count(ds: Dataset, dim: Dimension) -> Dataset:
176176
odim = _swap(dim)[:-1]
177177
M, G = ds["call_genotype_mask"].any(dim="ploidy"), ds["call_genotype"]
178-
n_het = (G > 0).any(dim="ploidy") & (G == 0).any(dim="ploidy")
179178
n_hom_ref = (G == 0).all(dim="ploidy")
180-
n_hom_alt = (G > 0).all(dim="ploidy")
179+
n_hom_alt = ((G > 0) & (G[..., 0] == G)).all(dim="ploidy")
181180
n_non_ref = (G > 0).any(dim="ploidy")
181+
n_het = ~(n_hom_alt | n_hom_ref)
182+
# This would 0 out the `het` case with any missing calls
182183
agg = lambda x: xr.where(M, False, x).sum(dim=dim) # type: ignore[no-untyped-call]
183184
return xr.Dataset(
184185
{
@@ -191,7 +192,7 @@ def genotype_count(ds: Dataset, dim: Dimension) -> Dataset:
191192

192193

193194
def allele_frequency(ds: Dataset) -> Dataset:
194-
AC = count_alleles(ds)
195+
AC = count_variant_alleles(ds)
195196

196197
M = ds["call_genotype_mask"].stack(calls=("samples", "ploidy"))
197198
AN = (~M).sum(dim="calls") # type: ignore

0 commit comments

Comments
 (0)