Skip to content

Commit 8078bf8

Browse files
committed
Generalise Fst and PBS tests to test variable numbers of cohorts
1 parent 92cbb3c commit 8078bf8

File tree

1 file changed

+39
-30
lines changed

1 file changed

+39
-30
lines changed

sgkit/tests/test_popgen.py

Lines changed: 39 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,7 @@ def test_Fst__unknown_estimator():
255255

256256
@pytest.mark.parametrize(
257257
"sample_size, n_cohorts",
258-
[(10, 2)],
258+
[(10, 2), (10, 3)],
259259
)
260260
@pytest.mark.parametrize("chunks", [(-1, -1), (50, -1)])
261261
def test_Fst__windowed(sample_size, n_cohorts, chunks):
@@ -280,16 +280,18 @@ def test_Fst__windowed(sample_size, n_cohorts, chunks):
280280

281281
np.testing.assert_allclose(fst, ts_fst)
282282

283+
# scikit-allel
283284
fst_ds = Fst(ds, estimator="Hudson")
284-
fst = fst_ds["stat_Fst"].sel(cohorts_0="co_0", cohorts_1="co_1").values
285+
for i, j in itertools.combinations(range(n_cohorts), 2):
286+
fst = fst_ds["stat_Fst"].sel(cohorts_0=f"co_{i}", cohorts_1=f"co_{j}").values
285287

286-
ac1 = fst_ds.cohort_allele_count.values[:, 0, :]
287-
ac2 = fst_ds.cohort_allele_count.values[:, 1, :]
288-
ska_fst = allel.moving_hudson_fst(ac1, ac2, size=25)
288+
ac_i = fst_ds.cohort_allele_count.values[:, i, :]
289+
ac_j = fst_ds.cohort_allele_count.values[:, j, :]
290+
ska_fst = allel.moving_hudson_fst(ac_i, ac_j, size=25)
289291

290-
np.testing.assert_allclose(
291-
fst[:-1], ska_fst
292-
) # scikit-allel has final window missing
292+
np.testing.assert_allclose(
293+
fst[:-1], ska_fst
294+
) # scikit-allel has final window missing
293295

294296

295297
@pytest.mark.parametrize("sample_size", [2, 3, 10, 100])
@@ -307,56 +309,63 @@ def test_Tajimas_D(sample_size):
307309

308310
@pytest.mark.parametrize(
309311
"sample_size, n_cohorts",
310-
[(10, 3)],
312+
[(10, 3), (20, 4)],
311313
)
312314
def test_pbs(sample_size, n_cohorts):
313315
ts = msprime.simulate(sample_size, length=100, mutation_rate=0.05, random_seed=42)
314316
ds = ts_to_dataset(ts) # type: ignore[no-untyped-call]
315-
ds, subsets = add_cohorts(ds, ts, n_cohorts) # type: ignore[no-untyped-call]
317+
ds, subsets = add_cohorts(ds, ts, n_cohorts, cohort_key_names=["cohorts_0", "cohorts_1", "cohorts_2"]) # type: ignore[no-untyped-call]
316318
n_variants = ds.dims["variants"]
317319
ds = window(ds, size=n_variants) # single window
318320

319321
ds = pbs(ds)
320-
stat_pbs = ds["stat_pbs"]
321322

322323
# scikit-allel
323-
ac1 = ds.cohort_allele_count.values[:, 0, :]
324-
ac2 = ds.cohort_allele_count.values[:, 1, :]
325-
ac3 = ds.cohort_allele_count.values[:, 2, :]
326-
327-
ska_pbs_value = np.full([1, n_cohorts, n_cohorts, n_cohorts], np.nan)
328324
for i, j, k in itertools.combinations(range(n_cohorts), 3):
329-
ska_pbs_value[0, i, j, k] = allel.pbs(ac1, ac2, ac3, window_size=n_variants)
325+
stat_pbs = (
326+
ds["stat_pbs"]
327+
.sel(cohorts_0=f"co_{i}", cohorts_1=f"co_{j}", cohorts_2=f"co_{k}")
328+
.values
329+
)
330330

331-
np.testing.assert_allclose(stat_pbs, ska_pbs_value)
331+
ac_i = ds.cohort_allele_count.values[:, i, :]
332+
ac_j = ds.cohort_allele_count.values[:, j, :]
333+
ac_k = ds.cohort_allele_count.values[:, k, :]
334+
335+
ska_pbs_value = allel.pbs(ac_i, ac_j, ac_k, window_size=n_variants)
336+
337+
np.testing.assert_allclose(stat_pbs, ska_pbs_value)
332338

333339

334340
@pytest.mark.parametrize(
335341
"sample_size, n_cohorts",
336-
[(10, 3)],
342+
[(10, 3), (20, 4)],
337343
)
338344
@pytest.mark.parametrize("chunks", [(-1, -1), (50, -1)])
339345
def test_pbs__windowed(sample_size, n_cohorts, chunks):
340346
ts = msprime.simulate(sample_size, length=200, mutation_rate=0.05, random_seed=42)
341347
ds = ts_to_dataset(ts, chunks) # type: ignore[no-untyped-call]
342-
ds, subsets = add_cohorts(ds, ts, n_cohorts) # type: ignore[no-untyped-call]
348+
ds, subsets = add_cohorts(ds, ts, n_cohorts, cohort_key_names=["cohorts_0", "cohorts_1", "cohorts_2"]) # type: ignore[no-untyped-call]
343349
ds = window(ds, size=25)
344350

345351
ds = pbs(ds)
346-
stat_pbs = ds["stat_pbs"].values
347352

348353
# scikit-allel
349-
ac1 = ds.cohort_allele_count.values[:, 0, :]
350-
ac2 = ds.cohort_allele_count.values[:, 1, :]
351-
ac3 = ds.cohort_allele_count.values[:, 2, :]
352-
353-
# scikit-allel has final window missing
354-
n_windows = ds.dims["windows"] - 1
355-
ska_pbs_value = np.full([n_windows, n_cohorts, n_cohorts, n_cohorts], np.nan)
356354
for i, j, k in itertools.combinations(range(n_cohorts), 3):
357-
ska_pbs_value[:, i, j, k] = allel.pbs(ac1, ac2, ac3, window_size=25)
355+
stat_pbs = (
356+
ds["stat_pbs"]
357+
.sel(cohorts_0=f"co_{i}", cohorts_1=f"co_{j}", cohorts_2=f"co_{k}")
358+
.values
359+
)
360+
361+
ac_i = ds.cohort_allele_count.values[:, i, :]
362+
ac_j = ds.cohort_allele_count.values[:, j, :]
363+
ac_k = ds.cohort_allele_count.values[:, k, :]
364+
365+
ska_pbs_value = allel.pbs(ac_i, ac_j, ac_k, window_size=25, window_step=25)
358366

359-
np.testing.assert_allclose(stat_pbs[:-1], ska_pbs_value)
367+
# scikit-allel has final window missing
368+
np.testing.assert_allclose(stat_pbs[:-1], ska_pbs_value)
360369

361370

362371
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)