Skip to content

Commit b83ca1b

Browse files
tomwhite authored and mergify[bot] committed
Test popgen functions on data chunked in variants dimension.
1 parent eecbb93 commit b83ca1b

File tree

2 files changed

+23
-20
lines changed

2 files changed

+23
-20
lines changed

sgkit/stats/popgen.py

+5-8
Original file line numberDiff line numberDiff line change
@@ -82,23 +82,22 @@ def diversity(
8282

8383
# c = cohorts, k = alleles
8484
@guvectorize( # type: ignore
85-
["void(int64[:, :], int64[:], float64[:,:])"],
86-
"(c, k),(c)->(c,c)",
85+
["void(int64[:, :], float64[:,:])"],
86+
"(c, k)->(c,c)",
8787
nopython=True,
8888
)
89-
def _divergence(ac: ArrayLike, an: ArrayLike, out: ArrayLike) -> None:
89+
def _divergence(ac: ArrayLike, out: ArrayLike) -> None:
9090
"""Generalized U-function for computing divergence.
9191
9292
Parameters
9393
----------
9494
ac
9595
Allele counts of shape (cohorts, alleles) containing per-cohort allele counts.
96-
an
97-
Allele totals of shape (cohorts,) containing per-cohort allele totals.
9896
out
9997
Pairwise divergence stats with shape (cohorts, cohorts), where the entry at
10098
(i, j) is the divergence between cohort i and cohort j.
10199
"""
100+
an = ac.sum(axis=-1)
102101
out[:, :] = np.nan # (cohorts, cohorts)
103102
n_cohorts = ac.shape[0]
104103
n_alleles = ac.shape[1]
@@ -171,14 +170,12 @@ def divergence(
171170
else:
172171
variables.validate(ds, {allele_counts: variables.cohort_allele_count_spec})
173172
ac = ds[allele_counts]
174-
an = ac.sum(axis=2)
175173

176174
n_variants = ds.dims["variants"]
177175
n_cohorts = ds.dims["cohorts"]
178176
ac = da.asarray(ac)
179-
an = da.asarray(an)
180177
shape = (ac.chunks[0], n_cohorts, n_cohorts)
181-
d = da.map_blocks(_divergence, ac, an, chunks=shape, dtype=np.float64)
178+
d = da.map_blocks(_divergence, ac, chunks=shape, dtype=np.float64)
182179
assert_array_shape(d, n_variants, n_cohorts, n_cohorts)
183180

184181
d_sum = d.sum(axis=0)

sgkit/tests/test_popgen.py

+18-12
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from sgkit import Fst, Tajimas_D, create_genotype_call_dataset, divergence, diversity
1010

1111

12-
def ts_to_dataset(ts, samples=None):
12+
def ts_to_dataset(ts, chunks=None, samples=None):
1313
"""
1414
Convert the specified tskit tree sequence into an sgkit dataset.
1515
Note this just generates haploids for now. With msprime 1.0, we'll be
@@ -26,22 +26,24 @@ def ts_to_dataset(ts, samples=None):
2626
alleles = np.array(alleles).astype("S")
2727
genotypes = np.expand_dims(genotypes, axis=2)
2828

29-
df = create_genotype_call_dataset(
29+
ds = create_genotype_call_dataset(
3030
variant_contig_names=["1"],
3131
variant_contig=np.zeros(len(tables.sites), dtype=int),
3232
variant_position=tables.sites.position.astype(int),
3333
variant_alleles=alleles,
3434
sample_id=np.array([f"tsk_{u}" for u in samples]).astype("U"),
3535
call_genotype=genotypes,
3636
)
37-
return df
37+
if chunks is not None:
38+
ds = ds.chunk(dict(zip(["variants", "samples"], chunks)))
39+
return ds
3840

3941

4042
@pytest.mark.parametrize("size", [2, 3, 10, 100])
4143
@pytest.mark.parametrize("chunks", [(-1, -1), (10, -1)])
4244
def test_diversity(size, chunks):
4345
ts = msprime.simulate(size, length=100, mutation_rate=0.05, random_seed=42)
44-
ds = ts_to_dataset(ts) # type: ignore[no-untyped-call]
46+
ds = ts_to_dataset(ts, chunks) # type: ignore[no-untyped-call]
4547
ds = ds.chunk(dict(zip(["variants", "samples"], chunks)))
4648
sample_cohorts = np.full_like(ts.samples(), 0)
4749
ds["sample_cohort"] = xr.DataArray(sample_cohorts, dims="samples")
@@ -56,10 +58,11 @@ def test_diversity(size, chunks):
5658
"size, n_cohorts",
5759
[(2, 2), (3, 2), (3, 3), (10, 2), (10, 3), (10, 4), (100, 2), (100, 3), (100, 4)],
5860
)
59-
def test_divergence(size, n_cohorts):
61+
@pytest.mark.parametrize("chunks", [(-1, -1), (10, -1)])
62+
def test_divergence(size, n_cohorts, chunks):
6063
ts = msprime.simulate(size, length=100, mutation_rate=0.05, random_seed=42)
6164
subsets = np.array_split(ts.samples(), n_cohorts)
62-
ds = ts_to_dataset(ts) # type: ignore[no-untyped-call]
65+
ds = ts_to_dataset(ts, chunks) # type: ignore[no-untyped-call]
6366
sample_cohorts = np.concatenate(
6467
[np.full_like(subset, i) for i, subset in enumerate(subsets)]
6568
)
@@ -84,12 +87,13 @@ def test_divergence(size, n_cohorts):
8487

8588

8689
@pytest.mark.parametrize("size", [2, 3, 10, 100])
87-
def test_Fst__Hudson(size):
90+
@pytest.mark.parametrize("chunks", [(-1, -1), (10, -1)])
91+
def test_Fst__Hudson(size, chunks):
8892
# scikit-allel can only calculate Fst for pairs of cohorts (populations)
8993
n_cohorts = 2
9094
ts = msprime.simulate(size, length=100, mutation_rate=0.05, random_seed=42)
9195
subsets = np.array_split(ts.samples(), n_cohorts)
92-
ds = ts_to_dataset(ts) # type: ignore[no-untyped-call]
96+
ds = ts_to_dataset(ts, chunks) # type: ignore[no-untyped-call]
9397
sample_cohorts = np.concatenate(
9498
[np.full_like(subset, i) for i, subset in enumerate(subsets)]
9599
)
@@ -112,10 +116,11 @@ def test_Fst__Hudson(size):
112116
"size, n_cohorts",
113117
[(2, 2), (3, 2), (3, 3), (10, 2), (10, 3), (10, 4), (100, 2), (100, 3), (100, 4)],
114118
)
115-
def test_Fst__Nei(size, n_cohorts):
119+
@pytest.mark.parametrize("chunks", [(-1, -1), (10, -1)])
120+
def test_Fst__Nei(size, n_cohorts, chunks):
116121
ts = msprime.simulate(size, length=100, mutation_rate=0.05, random_seed=42)
117122
subsets = np.array_split(ts.samples(), n_cohorts)
118-
ds = ts_to_dataset(ts) # type: ignore[no-untyped-call]
123+
ds = ts_to_dataset(ts, chunks) # type: ignore[no-untyped-call]
119124
sample_cohorts = np.concatenate(
120125
[np.full_like(subset, i) for i, subset in enumerate(subsets)]
121126
)
@@ -142,9 +147,10 @@ def test_Fst__unknown_estimator():
142147

143148

144149
@pytest.mark.parametrize("size", [2, 3, 10, 100])
145-
def test_Tajimas_D(size):
150+
@pytest.mark.parametrize("chunks", [(-1, -1), (10, -1)])
151+
def test_Tajimas_D(size, chunks):
146152
ts = msprime.simulate(size, length=100, mutation_rate=0.05, random_seed=42)
147-
ds = ts_to_dataset(ts) # type: ignore[no-untyped-call]
153+
ds = ts_to_dataset(ts, chunks) # type: ignore[no-untyped-call]
148154
sample_cohorts = np.full_like(ts.samples(), 0)
149155
ds["sample_cohort"] = xr.DataArray(sample_cohorts, dims="samples")
150156
ds = Tajimas_D(ds)

0 commit comments

Comments (0)