Skip to content

Commit 96459db

Browse files
daletovarjeromekelleher
authored andcommitted
add minimal diversity and divergence
remove ts_to_dataset from public api make divergence take in two datasets add minimal fst Add read_vcfzarr (#40) add ts_to_dataset add minimal diversity and divergence remove ts_to_dataset from public api add tajimas d add ts_to_dataset add minimal diversity and divergence remove ts_to_dataset from public api make divergence take in two datasets add minimal fst add tajimas d add ts_to_dataset add minimal diversity and divergence remove ts_to_dataset from public api add minimal fst add tajimas d fix allele count update cfg remove spaces add msprime and use np.testing add libgsl-dev dependency add docstrings ignore dep warning add ts_to_dataset add minimal diversity and divergence remove ts_to_dataset from public api make divergence take in two datasets add minimal fst Add read_vcfzarr (#40) add ts_to_dataset add minimal diversity and divergence remove ts_to_dataset from public api add tajimas d add ts_to_dataset add minimal diversity and divergence remove ts_to_dataset from public api make divergence take in two datasets add minimal fst add tajimas d add ts_to_dataset add minimal diversity and divergence remove ts_to_dataset from public api add minimal fst add tajimas d fix allele count update cfg remove spaces add msprime and use np.testing add libgsl-dev dependency add docstrings fix divide by zero
1 parent 2f511be commit 96459db

File tree

7 files changed

+255
-29
lines changed

7 files changed

+255
-29
lines changed

.github/workflows/build.yml

+1
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ jobs:
1818
python-version: ${{ matrix.python-version }}
1919
- name: Install dependencies
2020
run: |
21+
sudo apt install libgsl-dev # Needed for msprime < 1.0. Binary wheels include GSL for >= 1.0
2122
python -m pip install --upgrade pip
2223
pip install -r requirements.txt -r requirements-dev.txt
2324
- name: Run pre-commit

requirements-dev.txt

+1
Original file line numberDiff line numberDiff line change
@@ -6,5 +6,6 @@ pytest-datadir
66
hypothesis
77
statsmodels
88
zarr
9+
msprime
910
sphinx
1011
sphinx_rtd_theme

setup.cfg

+2-1
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ addopts = --doctest-modules --ignore=validation
4444
norecursedirs = .eggs docs
4545
filterwarnings =
4646
error
47+
ignore::DeprecationWarning
4748

4849
[flake8]
4950
ignore =
@@ -61,7 +62,7 @@ ignore =
6162
profile = black
6263
default_section = THIRDPARTY
6364
known_first_party = sgkit
64-
known_third_party = dask,fire,glow,hail,hypothesis,invoke,numba,numpy,pandas,pkg_resources,pyspark,pytest,setuptools,sgkit_plink,typing_extensions,xarray,yaml,zarr
65+
known_third_party = dask,fire,glow,hail,hypothesis,invoke,msprime,numba,numpy,pandas,pkg_resources,pyspark,pytest,setuptools,sgkit_plink,typing_extensions,xarray,yaml,zarr
6566
multi_line_output = 3
6667
include_trailing_comma = True
6768
force_grid_wrap = 0

sgkit/__init__.py

+5
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from .stats.aggregation import count_call_alleles, count_variant_alleles, variant_stats
1212
from .stats.association import gwas_linear_regression
1313
from .stats.hwe import hardy_weinberg_test
14+
from .stats.popgen import Fst, Tajimas_D, divergence, diversity
1415
from .stats.regenie import regenie
1516

1617
__all__ = [
@@ -28,4 +29,8 @@
2829
"regenie",
2930
"hardy_weinberg_test",
3031
"variant_stats",
32+
"diversity",
33+
"divergence",
34+
"Fst",
35+
"Tajimas_D",
3136
]

sgkit/api.py

+5-28
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ def create_genotype_call_dataset(
2424
variant_id: Any = None,
2525
) -> xr.Dataset:
2626
"""Create a dataset of genotype calls.
27+
2728
Parameters
2829
----------
2930
variant_contig_names : list of str
@@ -45,10 +46,12 @@ def create_genotype_call_dataset(
4546
omitted all calls are unphased.
4647
variant_id: array_like, str or object, optional
4748
The unique identifier of the variant.
49+
4850
Returns
4951
-------
5052
:class:`xarray.Dataset`
5153
The dataset of genotype calls.
54+
5255
"""
5356
check_array_like(variant_contig, kind="i", ndim=1)
5457
check_array_like(variant_position, kind="i", ndim=1)
@@ -112,10 +115,12 @@ def create_genotype_dosage_dataset(
112115
missing value.
113116
variant_id: array_like, str or object, optional
114117
The unique identifier of the variant.
118+
115119
Returns
116120
-------
117121
xr.Dataset
118122
The dataset of genotype calls.
123+
119124
"""
120125
check_array_like(variant_contig, kind="i", ndim=1)
121126
check_array_like(variant_position, kind="i", ndim=1)
@@ -144,31 +149,3 @@ def create_genotype_dosage_dataset(
144149
data_vars["variant_id"] = ([DIM_VARIANT], variant_id)
145150
attrs: Dict[Hashable, Any] = {"contigs": variant_contig_names}
146151
return xr.Dataset(data_vars=data_vars, attrs=attrs)
147-
148-
149-
def ts_to_dataset(ts, samples=None):
150-
"""
151-
Convert the specified tskit tree sequence into an sgkit dataset.
152-
Note this just generates haploids for now. With msprime 1.0, we'll be
153-
able to generate diploid/whatever-ploid individuals easily.
154-
"""
155-
if samples is None:
156-
samples = ts.samples()
157-
tables = ts.dump_tables()
158-
alleles = []
159-
genotypes = []
160-
for var in ts.variants(samples=samples):
161-
alleles.append(var.alleles)
162-
genotypes.append(var.genotypes)
163-
alleles = np.array(alleles).astype("S")
164-
genotypes = np.expand_dims(genotypes, axis=2)
165-
166-
df = create_genotype_call_dataset(
167-
variant_contig_names=["1"],
168-
variant_contig=np.zeros(len(tables.sites), dtype=int),
169-
variant_position=tables.sites.position.astype(int),
170-
variant_alleles=alleles,
171-
sample_id=np.array([f"tsk_{u}" for u in samples]).astype("U"),
172-
call_genotype=genotypes,
173-
)
174-
return df

sgkit/stats/popgen.py

+166
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
from typing import Hashable
2+
3+
import dask.array as da
4+
import numpy as np
5+
import xarray as xr
6+
from xarray import DataArray, Dataset
7+
8+
from .aggregation import count_variant_alleles
9+
10+
11+
def diversity(
12+
ds: Dataset, allele_counts: Hashable = "variant_allele_count",
13+
) -> DataArray:
14+
"""Compute diversity from allele counts.
15+
16+
Because we're not providing any arguments on windowing, etc,
17+
we return the total over the whole region. Maybe this isn't
18+
the behaviour we want, but it's a starting point. Note that
19+
this is different to the tskit default behaviour where we
20+
normalise by the size of windows so that results
21+
in different windows are comparable. However, we don't have
22+
any information about the overall length of the sequence here
23+
so we can't normalise by it.
24+
25+
Parameters
26+
----------
27+
ds : Dataset
28+
Genotype call dataset.
29+
allele_counts : Hashable
30+
allele counts to use or calculate.
31+
32+
Returns
33+
-------
34+
DataArray
35+
diversity value.
36+
"""
37+
if len(ds.samples) < 2:
38+
return xr.DataArray(np.nan)
39+
if allele_counts not in ds:
40+
ds = count_variant_alleles(ds)
41+
ac = ds[allele_counts]
42+
an = ac.sum(axis=1)
43+
n_pairs = an * (an - 1) / 2
44+
n_same = (ac * (ac - 1) / 2).sum(axis=1)
45+
n_diff = n_pairs - n_same
46+
pi = n_diff / n_pairs
47+
return pi.sum() # type: ignore[no-any-return]
48+
49+
50+
def divergence(
51+
ds1: Dataset, ds2: Dataset, allele_counts: Hashable = "variant_allele_count",
52+
) -> DataArray:
53+
"""Compute divergence between two genotype call datasets.
54+
55+
Parameters
56+
----------
57+
ds1 : Dataset
58+
Genotype call dataset.
59+
ds2 : Dataset
60+
Genotype call dataset.
61+
allele_counts : Hashable
62+
allele counts to use or calculate.
63+
64+
Returns
65+
-------
66+
DataArray
67+
divergence value between the two datasets.
68+
"""
69+
if allele_counts not in ds1:
70+
ds1 = count_variant_alleles(ds1)
71+
ac1 = ds1[allele_counts]
72+
if allele_counts not in ds2:
73+
ds2 = count_variant_alleles(ds2)
74+
ac2 = ds2[allele_counts]
75+
an1 = ds1[allele_counts].sum(axis=1)
76+
an2 = ds2[allele_counts].sum(axis=1)
77+
78+
n_pairs = an1 * an2
79+
n_same = (ac1 * ac2).sum(axis=1)
80+
n_diff = n_pairs - n_same
81+
div = n_diff / n_pairs
82+
return div.sum() # type: ignore[no-any-return]
83+
84+
85+
def Fst(
86+
ds1: Dataset, ds2: Dataset, allele_counts: Hashable = "variant_allele_count",
87+
) -> DataArray:
88+
"""Compute Fst between two genotype call datasets.
89+
90+
Parameters
91+
----------
92+
ds1 : Dataset
93+
Genotype call dataset.
94+
ds2 : Dataset
95+
Genotype call dataset.
96+
allele_counts : Hashable
97+
allele counts to use or calculate.
98+
99+
Returns
100+
-------
101+
DataArray
102+
fst value between the two datasets.
103+
"""
104+
total_div = diversity(ds1) + diversity(ds2)
105+
gs = divergence(ds1, ds2)
106+
den = total_div + 2 * gs # type: ignore[operator]
107+
fst = 1 - (2 * total_div / den)
108+
return fst # type: ignore[no-any-return]
109+
110+
111+
def Tajimas_D(
112+
ds: Dataset, allele_counts: Hashable = "variant_allele_count",
113+
) -> DataArray:
114+
"""Compute Tajimas' D for a genotype call dataset.
115+
116+
Parameters
117+
----------
118+
ds : Dataset
119+
Genotype call dataset.
120+
allele_counts : Hashable
121+
allele counts to use or calculate.
122+
123+
Returns
124+
-------
125+
DataArray
126+
Tajimas' D value.
127+
"""
128+
if allele_counts not in ds:
129+
ds = count_variant_alleles(ds)
130+
ac = ds[allele_counts]
131+
132+
# count segregating
133+
S = ((ac > 0).sum(axis=1) > 1).sum()
134+
135+
# assume number of chromosomes sampled is constant for all variants
136+
n = ac.sum(axis=1).max()
137+
138+
# (n-1)th harmonic number
139+
a1 = (1 / da.arange(1, n)).sum()
140+
141+
# calculate Watterson's theta (absolute value)
142+
theta = S / a1
143+
144+
# calculate diversity
145+
div = diversity(ds)
146+
147+
# N.B., both theta estimates are usually divided by the number of
148+
# (accessible) bases but here we want the absolute difference
149+
d = div - theta
150+
151+
# calculate the denominator (standard deviation)
152+
a2 = (1 / (da.arange(1, n) ** 2)).sum()
153+
b1 = (n + 1) / (3 * (n - 1))
154+
b2 = 2 * (n ** 2 + n + 3) / (9 * n * (n - 1))
155+
c1 = b1 - (1 / a1)
156+
c2 = b2 - ((n + 2) / (a1 * n)) + (a2 / (a1 ** 2))
157+
e1 = c1 / a1
158+
e2 = c2 / (a1 ** 2 + a2)
159+
d_stdev = np.sqrt((e1 * S) + (e2 * S * (S - 1)))
160+
161+
if d_stdev == 0:
162+
return xr.DataArray(np.nan)
163+
164+
# finally calculate Tajima's D
165+
D = d / d_stdev
166+
return D # type: ignore[no-any-return]

sgkit/tests/test_popgen.py

+75
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
import msprime # type: ignore
2+
import numpy as np
3+
import pytest
4+
5+
from sgkit import Fst, Tajimas_D, create_genotype_call_dataset, divergence, diversity
6+
7+
8+
def ts_to_dataset(ts, samples=None):
9+
"""
10+
Convert the specified tskit tree sequence into an sgkit dataset.
11+
Note this just generates haploids for now. With msprime 1.0, we'll be
12+
able to generate diploid/whatever-ploid individuals easily.
13+
"""
14+
if samples is None:
15+
samples = ts.samples()
16+
tables = ts.dump_tables()
17+
alleles = []
18+
genotypes = []
19+
for var in ts.variants(samples=samples):
20+
alleles.append(var.alleles)
21+
genotypes.append(var.genotypes)
22+
alleles = np.array(alleles).astype("S")
23+
genotypes = np.expand_dims(genotypes, axis=2)
24+
25+
df = create_genotype_call_dataset(
26+
variant_contig_names=["1"],
27+
variant_contig=np.zeros(len(tables.sites), dtype=int),
28+
variant_position=tables.sites.position.astype(int),
29+
variant_alleles=alleles,
30+
sample_id=np.array([f"tsk_{u}" for u in samples]).astype("U"),
31+
call_genotype=genotypes,
32+
)
33+
return df
34+
35+
36+
@pytest.mark.parametrize("size", [2, 3, 10, 100])
37+
def test_diversity(size):
38+
ts = msprime.simulate(size, length=100, mutation_rate=0.05, random_seed=42)
39+
ds = ts_to_dataset(ts) # type: ignore[no-untyped-call]
40+
div = diversity(ds).compute()
41+
ts_div = ts.diversity(span_normalise=False)
42+
np.testing.assert_allclose(div, ts_div)
43+
44+
45+
@pytest.mark.parametrize("size", [2, 3, 10, 100])
46+
def test_divergence(size):
47+
ts = msprime.simulate(size, length=100, mutation_rate=0.05, random_seed=42)
48+
subset_1 = ts.samples()[: ts.num_samples // 2]
49+
subset_2 = ts.samples()[ts.num_samples // 2 :]
50+
ds1 = ts_to_dataset(ts, subset_1) # type: ignore[no-untyped-call]
51+
ds2 = ts_to_dataset(ts, subset_2) # type: ignore[no-untyped-call]
52+
div = divergence(ds1, ds2).compute()
53+
ts_div = ts.divergence([subset_1, subset_2], span_normalise=False)
54+
np.testing.assert_allclose(div, ts_div)
55+
56+
57+
@pytest.mark.parametrize("size", [2, 3, 10, 100])
58+
def test_Fst(size):
59+
ts = msprime.simulate(size, length=100, mutation_rate=0.05, random_seed=42)
60+
subset_1 = ts.samples()[: ts.num_samples // 2]
61+
subset_2 = ts.samples()[ts.num_samples // 2 :]
62+
ds1 = ts_to_dataset(ts, subset_1) # type: ignore[no-untyped-call]
63+
ds2 = ts_to_dataset(ts, subset_2) # type: ignore[no-untyped-call]
64+
fst = Fst(ds1, ds2).compute()
65+
ts_fst = ts.Fst([subset_1, subset_2])
66+
np.testing.assert_allclose(fst, ts_fst)
67+
68+
69+
@pytest.mark.parametrize("size", [2, 3, 10, 100])
70+
def test_Tajimas_D(size):
71+
ts = msprime.simulate(size, length=100, mutation_rate=0.05, random_seed=42)
72+
ds = ts_to_dataset(ts) # type: ignore[no-untyped-call]
73+
ts_d = ts.Tajimas_D()
74+
d = Tajimas_D(ds).compute()
75+
np.testing.assert_allclose(d, ts_d)

0 commit comments

Comments
 (0)