From a8857d9490f09763803bc06e0b4f6593138c6f4b Mon Sep 17 00:00:00 2001 From: eric-czech Date: Tue, 7 Jul 2020 18:40:07 -0400 Subject: [PATCH 1/9] HWE exact test implementation for scalar genotype counts --- requirements.txt | 1 + sgkit/stats/hwe.py | 83 +++++++++++++++++++++++++++++++++++ sgkit/tests/test_hwe.py | 97 +++++++++++++++++++++++++++++++++++++++++ 3 files changed, 181 insertions(+) create mode 100644 sgkit/stats/hwe.py create mode 100644 sgkit/tests/test_hwe.py diff --git a/requirements.txt b/requirements.txt index 894396a41..357677228 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,3 +2,4 @@ numpy xarray dask[array] scipy +numba \ No newline at end of file diff --git a/sgkit/stats/hwe.py b/sgkit/stats/hwe.py new file mode 100644 index 000000000..88d678453 --- /dev/null +++ b/sgkit/stats/hwe.py @@ -0,0 +1,83 @@ +from numba import njit +import numpy as np + +# TODO: Is there a way to get coverage on jit functions? + +def hardy_weinberg_p_value(obs_hets: int, obs_hom1: int, obs_hom2: int) -> float: # pragma: no cover + """Exact test for HWE as described in Wigginton et al. 2005 [1] + + Parameters + ---------- + obs_hets : int + Number of heterozygotes with minor variant + obs_hom1 : int + Number of reference/major homozygotes + obs_hom2 : int + Number of alternate/minor homozygotes + + Returns + ------- + float + P value in [0, 1] + + References + ---------- + - [1] Wigginton, Janis E., David J. Cutler, and Goncalo R. Abecasis. 2005. + “A Note on Exact Tests of Hardy-Weinberg Equilibrium.” American Journal of + Human Genetics 76 (5): 887–93. + + Raises + ------ + ValueError + If any observed counts are negative + """ + if obs_hom1 < 0 or obs_hom2 < 0 or obs_hets < 0: + raise ValueError('Observed genotype counts must be positive') + + obs_homc = obs_hom2 if obs_hom1 < obs_hom2 else obs_hom1 + obs_homr = obs_hom1 if obs_hom1 < obs_hom2 else obs_hom2 + obs_mac = 2 * obs_homr + obs_hets + obs_n = obs_hets + obs_homc + obs_homr + het_probs = np.zeros(obs_mac + 1, dtype=np.float64) + + if obs_n == 0: + return np.nan + + # Identify distribution midpoint + mid = int(obs_mac * (2 * obs_n - obs_mac) / (2 * obs_n)) + if ((obs_mac & 1) ^ (mid & 1)): + mid += 1 + het_probs[mid] = 1.0 + prob_sum = het_probs[mid] + + # Integrate downward from distribution midpoint + curr_hets = mid + curr_homr = int((obs_mac - mid) / 2) + curr_homc = obs_n - curr_hets - curr_homr + while curr_hets > 1: + het_probs[curr_hets - 2] = het_probs[curr_hets] * curr_hets * (curr_hets - 1.0) / (4.0 * (curr_homr + 1.0) * (curr_homc + 1.0)) + prob_sum += het_probs[curr_hets - 2] + curr_homr += 1 + curr_homc += 1 + curr_hets -= 2 + + # Integrate upward from distribution midpoint + curr_hets = mid + curr_homr = int((obs_mac - mid) / 2) + curr_homc = obs_n - curr_hets - curr_homr + while curr_hets <= obs_mac - 2: + het_probs[curr_hets + 2] = het_probs[curr_hets] * 4.0 * curr_homr * curr_homc / ((curr_hets + 2.0) * (curr_hets + 1.0)) + prob_sum += het_probs[curr_hets + 2] + curr_homr -= 1 + curr_homc -= 1 + curr_hets += 2 + + if prob_sum <= 0: + return np.nan + het_probs = het_probs / prob_sum + p = het_probs[het_probs <= het_probs[obs_hets]].sum() + p = max(min(1.0, p), 0.0) + + return p + +hardy_weinberg_p_value_jit = njit(hardy_weinberg_p_value) \ No newline at end of file diff --git a/sgkit/tests/test_hwe.py b/sgkit/tests/test_hwe.py new file mode 100644 index 000000000..77567a839 --- /dev/null +++ b/sgkit/tests/test_hwe.py @@ -0,0 +1,97 @@ +import pytest +import numpy as np +from sgkit.stats.hwe import ( + hardy_weinberg_p_value_jit as hwep_jit, + hardy_weinberg_p_value as hwep +) + +def get_obs_gt_cts(): + n, step = 10_000, 50 + rs = np.random.RandomState(0) + n_het = np.expand_dims(np.arange(n, step=step) + 1, -1) + frac = rs.uniform(.3, .7, size=(n // step, 2)) + n_hom = frac * n_het + n_hom = n_hom.astype(int) + return np.concatenate((n_het, n_hom), axis=1) + +# Export from execution of C/C++ code at http://csg.sph.umich.edu/abecasis/Exact/snp_hwe.c +EXPECTED_P_VAL = [ + 1.00000000e+000, 8.45926829e-001, 8.89304245e-001, 3.68487492e-001, + 2.83442131e-001, 1.93780506e-001, 3.46415612e-002, 9.77805142e-007, + 9.00169099e-002, 2.77392776e-004, 5.78595078e-006, 1.56290046e-001, + 3.11983705e-002, 7.78234779e-001, 6.28255056e-001, 9.17242816e-001, + 8.81087089e-001, 1.20954751e-004, 6.51960684e-002, 4.87927509e-007, + 6.14320396e-002, 1.67216769e-003, 2.58323982e-003, 9.22666204e-012, + 1.15591803e-003, 1.00000000e+000, 5.21303203e-001, 2.40595832e-012, + 1.79017126e-001, 8.50964237e-004, 4.08782584e-018, 2.65625649e-003, + 1.73047163e-007, 2.61257337e-002, 3.40282167e-002, 5.57265342e-006, + 2.28187711e-010, 3.71009969e-005, 2.02796027e-015, 2.85690782e-015, + 4.43715904e-004, 1.24880234e-005, 1.39680904e-002, 6.69133747e-009, + 9.43219724e-010, 6.10161450e-001, 1.93499955e-003, 1.44451527e-014, + 1.15651799e-011, 6.16416362e-006, 2.18519190e-001, 2.67902896e-020, + 3.81265044e-003, 1.87170429e-002, 2.87276124e-001, 1.46939801e-004, + 5.90523804e-001, 9.00712608e-003, 7.82143524e-011, 1.55029275e-016, + 1.00796610e-003, 6.51775272e-018, 7.22627291e-001, 3.50621941e-033, + 2.15694037e-001, 5.36554440e-001, 4.98209450e-023, 1.00725415e-002, + 2.83256119e-004, 2.31647615e-001, 5.40831311e-004, 2.28693251e-006, + 2.33943256e-016, 4.63666449e-002, 1.95571664e-029, 1.32013500e-001, + 1.93010279e-006, 1.72246817e-002, 4.44008208e-010, 2.64771353e-025, + 1.42567926e-002, 2.34658222e-023, 5.14985651e-044, 4.48467881e-038, + 2.38901290e-003, 3.00019737e-020, 9.91998679e-058, 3.85771324e-001, + 1.19901665e-004, 1.09586529e-012, 4.52696626e-007, 4.52117435e-005, + 3.74269466e-022, 1.84769664e-002, 9.01235925e-001, 4.71167421e-016, + 7.26213285e-001, 2.68067642e-005, 1.95763513e-027, 3.44681033e-030, + 6.72973257e-001, 1.90998085e-021, 2.71129678e-092, 1.33474542e-002, + 9.42328262e-016, 6.04559513e-002, 2.73568136e-002, 3.45497420e-013, + 1.85964309e-010, 2.25791165e-016, 8.88002002e-023, 7.31645858e-001, + 6.20103273e-001, 2.02013957e-003, 3.26543825e-041, 9.55096556e-034, + 1.58435946e-031, 1.67723973e-017, 3.01571822e-004, 5.94647843e-004, + 3.50999380e-003, 1.42692287e-018, 4.40701593e-002, 1.02072821e-010, + 6.12844453e-020, 4.01149386e-007, 4.52329633e-028, 6.36621011e-004, + 2.40691727e-003, 1.51079564e-004, 1.46439431e-059, 1.19603499e-007, + 2.30499126e-023, 3.90483620e-004, 3.00491712e-033, 4.67334134e-075, + 2.14446525e-007, 5.74808603e-002, 7.54901939e-059, 1.00820382e-028, + 5.45503604e-002, 2.00408985e-029, 2.60055020e-038, 1.37950333e-021, + 1.67336706e-003, 5.11497091e-038, 9.63001456e-002, 1.85048263e-012, + 7.60512104e-005, 1.90260703e-097, 8.41707732e-055, 5.02772009e-056, + 4.74769747e-021, 1.53427038e-108, 3.65547065e-022, 3.59345583e-005, + 4.29008968e-115, 2.29690838e-003, 5.12962271e-001, 2.82010264e-044, + 1.25488919e-059, 4.26516777e-072, 2.92597766e-014, 1.13938024e-020, + 2.65101694e-019, 6.39260807e-003, 3.44575391e-019, 2.46964669e-042, + 2.18893082e-023, 2.32535921e-005, 3.67548497e-033, 6.28178465e-050, + 4.01855250e-010, 8.14210277e-007, 7.19942047e-038, 1.23293898e-028, + 1.04555107e-001, 2.80977631e-008, 3.38829632e-065, 3.67682844e-014, + 7.97794167e-001, 9.88137129e-001, 7.83054274e-016, 6.10205517e-003, + 3.54737998e-051, 1.00000000e+000, 1.23015267e-024, 7.06536040e-069, + 2.27403687e-082, 2.12853071e-001, 2.09868517e-014, 4.20835611e-040, + 1.72349554e-079, 1.58828256e-003, 6.46108778e-001, 1.80557310e-058, + 2.70043232e-001, 1.84978056e-007, 6.97911818e-017, 6.09976723e-137] + +def test_hwep_against_reference_impl(): + args = get_obs_gt_cts() + p = [hwep(*arg) for arg in args] + np.testing.assert_allclose(p, EXPECTED_P_VAL) + +def test_hwep_raise_on_negative(): + args = [ + [-1, 0, 0], + [0, -1, 0], + [0, 0, -1] + ] + for arg in args: + with pytest.raises(ValueError): + hwep(*arg) + +def test_hwep_zeros(): + assert np.isnan(hwep(0, 0, 0)) + +def test_hwep_large_counts(): + # Note: use jit-compiled function for large counts to avoid slowing build down + for n_het in 10**np.arange(3, 8): + # Test case in perfect equilibrium + p = hwep_jit(n_het, n_het//2, n_het//2) + assert np.isclose(p, 1.0, atol=1e-8) + # Test case way out of equilibrium + p = hwep_jit(n_het, n_het//10, n_het//2 + n_het//10) + assert np.isclose(p, 0, atol=1e-8) + \ No newline at end of file From d04b586e45261b2858ddddec753721a9e5cccb35 Mon Sep 17 00:00:00 2001 From: eric-czech Date: Tue, 7 Jul 2020 18:40:45 -0400 Subject: [PATCH 2/9] Formatting --- requirements.txt | 2 +- setup.cfg | 2 +- sgkit/stats/hwe.py | 37 ++++-- sgkit/tests/test_hwe.py | 283 ++++++++++++++++++++++++++++++---------- 4 files changed, 245 insertions(+), 79 deletions(-) diff --git a/requirements.txt b/requirements.txt index 357677228..f9fcad207 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,4 +2,4 @@ numpy xarray dask[array] scipy -numba \ No newline at end of file +numba diff --git a/setup.cfg b/setup.cfg index 700ae0930..6189dde87 100644 --- a/setup.cfg +++ b/setup.cfg @@ -50,7 +50,7 @@ ignore = [isort] default_section = THIRDPARTY known_first_party = sgkit -known_third_party = dask,numpy,pandas,pytest,setuptools,statsmodels,xarray +known_third_party = dask,numba,numpy,pandas,pytest,setuptools,statsmodels,xarray multi_line_output = 3 include_trailing_comma = True force_grid_wrap = 0 diff --git a/sgkit/stats/hwe.py b/sgkit/stats/hwe.py index 88d678453..e054568f6 100644 --- a/sgkit/stats/hwe.py +++ b/sgkit/stats/hwe.py @@ -1,9 +1,12 @@ -from numba import njit import numpy as np +from numba import njit # TODO: Is there a way to get coverage on jit functions? -def hardy_weinberg_p_value(obs_hets: int, obs_hom1: int, obs_hom2: int) -> float: # pragma: no cover + +def hardy_weinberg_p_value( + obs_hets: int, obs_hom1: int, obs_hom2: int +) -> float: # pragma: no cover """Exact test for HWE as described in Wigginton et al. 2005 [1] Parameters @@ -32,8 +35,8 @@ def hardy_weinberg_p_value(obs_hets: int, obs_hom1: int, obs_hom2: int) -> float If any observed counts are negative """ if obs_hom1 < 0 or obs_hom2 < 0 or obs_hets < 0: - raise ValueError('Observed genotype counts must be positive') - + raise ValueError("Observed genotype counts must be positive") + obs_homc = obs_hom2 if obs_hom1 < obs_hom2 else obs_hom1 obs_homr = obs_hom1 if obs_hom1 < obs_hom2 else obs_hom2 obs_mac = 2 * obs_homr + obs_hets @@ -42,20 +45,25 @@ def hardy_weinberg_p_value(obs_hets: int, obs_hom1: int, obs_hom2: int) -> float if obs_n == 0: return np.nan - + # Identify distribution midpoint mid = int(obs_mac * (2 * obs_n - obs_mac) / (2 * obs_n)) - if ((obs_mac & 1) ^ (mid & 1)): + if (obs_mac & 1) ^ (mid & 1): mid += 1 het_probs[mid] = 1.0 prob_sum = het_probs[mid] - + # Integrate downward from distribution midpoint curr_hets = mid curr_homr = int((obs_mac - mid) / 2) curr_homc = obs_n - curr_hets - curr_homr while curr_hets > 1: - het_probs[curr_hets - 2] = het_probs[curr_hets] * curr_hets * (curr_hets - 1.0) / (4.0 * (curr_homr + 1.0) * (curr_homc + 1.0)) + het_probs[curr_hets - 2] = ( + het_probs[curr_hets] + * curr_hets + * (curr_hets - 1.0) + / (4.0 * (curr_homr + 1.0) * (curr_homc + 1.0)) + ) prob_sum += het_probs[curr_hets - 2] curr_homr += 1 curr_homc += 1 @@ -66,7 +74,13 @@ def hardy_weinberg_p_value(obs_hets: int, obs_hom1: int, obs_hom2: int) -> float curr_homr = int((obs_mac - mid) / 2) curr_homc = obs_n - curr_hets - curr_homr while curr_hets <= obs_mac - 2: - het_probs[curr_hets + 2] = het_probs[curr_hets] * 4.0 * curr_homr * curr_homc / ((curr_hets + 2.0) * (curr_hets + 1.0)) + het_probs[curr_hets + 2] = ( + het_probs[curr_hets] + * 4.0 + * curr_homr + * curr_homc + / ((curr_hets + 2.0) * (curr_hets + 1.0)) + ) prob_sum += het_probs[curr_hets + 2] curr_homr -= 1 curr_homc -= 1 @@ -77,7 +91,8 @@ def hardy_weinberg_p_value(obs_hets: int, obs_hom1: int, obs_hom2: int) -> float het_probs = het_probs / prob_sum p = het_probs[het_probs <= het_probs[obs_hets]].sum() p = max(min(1.0, p), 0.0) - + return p -hardy_weinberg_p_value_jit = njit(hardy_weinberg_p_value) \ No newline at end of file + +hardy_weinberg_p_value_jit = njit(hardy_weinberg_p_value) diff --git a/sgkit/tests/test_hwe.py b/sgkit/tests/test_hwe.py index 77567a839..33dc04af5 100644 --- a/sgkit/tests/test_hwe.py +++ b/sgkit/tests/test_hwe.py @@ -1,97 +1,248 @@ -import pytest import numpy as np -from sgkit.stats.hwe import ( - hardy_weinberg_p_value_jit as hwep_jit, - hardy_weinberg_p_value as hwep -) +import pytest + +from sgkit.stats.hwe import hardy_weinberg_p_value as hwep +from sgkit.stats.hwe import hardy_weinberg_p_value_jit as hwep_jit + def get_obs_gt_cts(): n, step = 10_000, 50 rs = np.random.RandomState(0) n_het = np.expand_dims(np.arange(n, step=step) + 1, -1) - frac = rs.uniform(.3, .7, size=(n // step, 2)) + frac = rs.uniform(0.3, 0.7, size=(n // step, 2)) n_hom = frac * n_het n_hom = n_hom.astype(int) return np.concatenate((n_het, n_hom), axis=1) - + + # Export from execution of C/C++ code at http://csg.sph.umich.edu/abecasis/Exact/snp_hwe.c EXPECTED_P_VAL = [ - 1.00000000e+000, 8.45926829e-001, 8.89304245e-001, 3.68487492e-001, - 2.83442131e-001, 1.93780506e-001, 3.46415612e-002, 9.77805142e-007, - 9.00169099e-002, 2.77392776e-004, 5.78595078e-006, 1.56290046e-001, - 3.11983705e-002, 7.78234779e-001, 6.28255056e-001, 9.17242816e-001, - 8.81087089e-001, 1.20954751e-004, 6.51960684e-002, 4.87927509e-007, - 6.14320396e-002, 1.67216769e-003, 2.58323982e-003, 9.22666204e-012, - 1.15591803e-003, 1.00000000e+000, 5.21303203e-001, 2.40595832e-012, - 1.79017126e-001, 8.50964237e-004, 4.08782584e-018, 2.65625649e-003, - 1.73047163e-007, 2.61257337e-002, 3.40282167e-002, 5.57265342e-006, - 2.28187711e-010, 3.71009969e-005, 2.02796027e-015, 2.85690782e-015, - 4.43715904e-004, 1.24880234e-005, 1.39680904e-002, 6.69133747e-009, - 9.43219724e-010, 6.10161450e-001, 1.93499955e-003, 1.44451527e-014, - 1.15651799e-011, 6.16416362e-006, 2.18519190e-001, 2.67902896e-020, - 3.81265044e-003, 1.87170429e-002, 2.87276124e-001, 1.46939801e-004, - 5.90523804e-001, 9.00712608e-003, 7.82143524e-011, 1.55029275e-016, - 1.00796610e-003, 6.51775272e-018, 7.22627291e-001, 3.50621941e-033, - 2.15694037e-001, 5.36554440e-001, 4.98209450e-023, 1.00725415e-002, - 2.83256119e-004, 2.31647615e-001, 5.40831311e-004, 2.28693251e-006, - 2.33943256e-016, 4.63666449e-002, 1.95571664e-029, 1.32013500e-001, - 1.93010279e-006, 1.72246817e-002, 4.44008208e-010, 2.64771353e-025, - 1.42567926e-002, 2.34658222e-023, 5.14985651e-044, 4.48467881e-038, - 2.38901290e-003, 3.00019737e-020, 9.91998679e-058, 3.85771324e-001, - 1.19901665e-004, 1.09586529e-012, 4.52696626e-007, 4.52117435e-005, - 3.74269466e-022, 1.84769664e-002, 9.01235925e-001, 4.71167421e-016, - 7.26213285e-001, 2.68067642e-005, 1.95763513e-027, 3.44681033e-030, - 6.72973257e-001, 1.90998085e-021, 2.71129678e-092, 1.33474542e-002, - 9.42328262e-016, 6.04559513e-002, 2.73568136e-002, 3.45497420e-013, - 1.85964309e-010, 2.25791165e-016, 8.88002002e-023, 7.31645858e-001, - 6.20103273e-001, 2.02013957e-003, 3.26543825e-041, 9.55096556e-034, - 1.58435946e-031, 1.67723973e-017, 3.01571822e-004, 5.94647843e-004, - 3.50999380e-003, 1.42692287e-018, 4.40701593e-002, 1.02072821e-010, - 6.12844453e-020, 4.01149386e-007, 4.52329633e-028, 6.36621011e-004, - 2.40691727e-003, 1.51079564e-004, 1.46439431e-059, 1.19603499e-007, - 2.30499126e-023, 3.90483620e-004, 3.00491712e-033, 4.67334134e-075, - 2.14446525e-007, 5.74808603e-002, 7.54901939e-059, 1.00820382e-028, - 5.45503604e-002, 2.00408985e-029, 2.60055020e-038, 1.37950333e-021, - 1.67336706e-003, 5.11497091e-038, 9.63001456e-002, 1.85048263e-012, - 7.60512104e-005, 1.90260703e-097, 8.41707732e-055, 5.02772009e-056, - 4.74769747e-021, 1.53427038e-108, 3.65547065e-022, 3.59345583e-005, - 4.29008968e-115, 2.29690838e-003, 5.12962271e-001, 2.82010264e-044, - 1.25488919e-059, 4.26516777e-072, 2.92597766e-014, 1.13938024e-020, - 2.65101694e-019, 6.39260807e-003, 3.44575391e-019, 2.46964669e-042, - 2.18893082e-023, 2.32535921e-005, 3.67548497e-033, 6.28178465e-050, - 4.01855250e-010, 8.14210277e-007, 7.19942047e-038, 1.23293898e-028, - 1.04555107e-001, 2.80977631e-008, 3.38829632e-065, 3.67682844e-014, - 7.97794167e-001, 9.88137129e-001, 7.83054274e-016, 6.10205517e-003, - 3.54737998e-051, 1.00000000e+000, 1.23015267e-024, 7.06536040e-069, - 2.27403687e-082, 2.12853071e-001, 2.09868517e-014, 4.20835611e-040, - 1.72349554e-079, 1.58828256e-003, 6.46108778e-001, 1.80557310e-058, - 2.70043232e-001, 1.84978056e-007, 6.97911818e-017, 6.09976723e-137] + 1.00000000e000, + 8.45926829e-001, + 8.89304245e-001, + 3.68487492e-001, + 2.83442131e-001, + 1.93780506e-001, + 3.46415612e-002, + 9.77805142e-007, + 9.00169099e-002, + 2.77392776e-004, + 5.78595078e-006, + 1.56290046e-001, + 3.11983705e-002, + 7.78234779e-001, + 6.28255056e-001, + 9.17242816e-001, + 8.81087089e-001, + 1.20954751e-004, + 6.51960684e-002, + 4.87927509e-007, + 6.14320396e-002, + 1.67216769e-003, + 2.58323982e-003, + 9.22666204e-012, + 1.15591803e-003, + 1.00000000e000, + 5.21303203e-001, + 2.40595832e-012, + 1.79017126e-001, + 8.50964237e-004, + 4.08782584e-018, + 2.65625649e-003, + 1.73047163e-007, + 2.61257337e-002, + 3.40282167e-002, + 5.57265342e-006, + 2.28187711e-010, + 3.71009969e-005, + 2.02796027e-015, + 2.85690782e-015, + 4.43715904e-004, + 1.24880234e-005, + 1.39680904e-002, + 6.69133747e-009, + 9.43219724e-010, + 6.10161450e-001, + 1.93499955e-003, + 1.44451527e-014, + 1.15651799e-011, + 6.16416362e-006, + 2.18519190e-001, + 2.67902896e-020, + 3.81265044e-003, + 1.87170429e-002, + 2.87276124e-001, + 1.46939801e-004, + 5.90523804e-001, + 9.00712608e-003, + 7.82143524e-011, + 1.55029275e-016, + 1.00796610e-003, + 6.51775272e-018, + 7.22627291e-001, + 3.50621941e-033, + 2.15694037e-001, + 5.36554440e-001, + 4.98209450e-023, + 1.00725415e-002, + 2.83256119e-004, + 2.31647615e-001, + 5.40831311e-004, + 2.28693251e-006, + 2.33943256e-016, + 4.63666449e-002, + 1.95571664e-029, + 1.32013500e-001, + 1.93010279e-006, + 1.72246817e-002, + 4.44008208e-010, + 2.64771353e-025, + 1.42567926e-002, + 2.34658222e-023, + 5.14985651e-044, + 4.48467881e-038, + 2.38901290e-003, + 3.00019737e-020, + 9.91998679e-058, + 3.85771324e-001, + 1.19901665e-004, + 1.09586529e-012, + 4.52696626e-007, + 4.52117435e-005, + 3.74269466e-022, + 1.84769664e-002, + 9.01235925e-001, + 4.71167421e-016, + 7.26213285e-001, + 2.68067642e-005, + 1.95763513e-027, + 3.44681033e-030, + 6.72973257e-001, + 1.90998085e-021, + 2.71129678e-092, + 1.33474542e-002, + 9.42328262e-016, + 6.04559513e-002, + 2.73568136e-002, + 3.45497420e-013, + 1.85964309e-010, + 2.25791165e-016, + 8.88002002e-023, + 7.31645858e-001, + 6.20103273e-001, + 2.02013957e-003, + 3.26543825e-041, + 9.55096556e-034, + 1.58435946e-031, + 1.67723973e-017, + 3.01571822e-004, + 5.94647843e-004, + 3.50999380e-003, + 1.42692287e-018, + 4.40701593e-002, + 1.02072821e-010, + 6.12844453e-020, + 4.01149386e-007, + 4.52329633e-028, + 6.36621011e-004, + 2.40691727e-003, + 1.51079564e-004, + 1.46439431e-059, + 1.19603499e-007, + 2.30499126e-023, + 3.90483620e-004, + 3.00491712e-033, + 4.67334134e-075, + 2.14446525e-007, + 5.74808603e-002, + 7.54901939e-059, + 1.00820382e-028, + 5.45503604e-002, + 2.00408985e-029, + 2.60055020e-038, + 1.37950333e-021, + 1.67336706e-003, + 5.11497091e-038, + 9.63001456e-002, + 1.85048263e-012, + 7.60512104e-005, + 1.90260703e-097, + 8.41707732e-055, + 5.02772009e-056, + 4.74769747e-021, + 1.53427038e-108, + 3.65547065e-022, + 3.59345583e-005, + 4.29008968e-115, + 2.29690838e-003, + 5.12962271e-001, + 2.82010264e-044, + 1.25488919e-059, + 4.26516777e-072, + 2.92597766e-014, + 1.13938024e-020, + 2.65101694e-019, + 6.39260807e-003, + 3.44575391e-019, + 2.46964669e-042, + 2.18893082e-023, + 2.32535921e-005, + 3.67548497e-033, + 6.28178465e-050, + 4.01855250e-010, + 8.14210277e-007, + 7.19942047e-038, + 1.23293898e-028, + 1.04555107e-001, + 2.80977631e-008, + 3.38829632e-065, + 3.67682844e-014, + 7.97794167e-001, + 9.88137129e-001, + 7.83054274e-016, + 6.10205517e-003, + 3.54737998e-051, + 1.00000000e000, + 1.23015267e-024, + 7.06536040e-069, + 2.27403687e-082, + 2.12853071e-001, + 2.09868517e-014, + 4.20835611e-040, + 1.72349554e-079, + 1.58828256e-003, + 6.46108778e-001, + 1.80557310e-058, + 2.70043232e-001, + 1.84978056e-007, + 6.97911818e-017, + 6.09976723e-137, +] + def test_hwep_against_reference_impl(): args = get_obs_gt_cts() p = [hwep(*arg) for arg in args] np.testing.assert_allclose(p, EXPECTED_P_VAL) + def test_hwep_raise_on_negative(): - args = [ - [-1, 0, 0], - [0, -1, 0], - [0, 0, -1] - ] + args = [[-1, 0, 0], [0, -1, 0], [0, 0, -1]] for arg in args: with pytest.raises(ValueError): hwep(*arg) + def test_hwep_zeros(): assert np.isnan(hwep(0, 0, 0)) + def test_hwep_large_counts(): # Note: use jit-compiled function for large counts to avoid slowing build down - for n_het in 10**np.arange(3, 8): + for n_het in 10 ** np.arange(3, 8): # Test case in perfect equilibrium - p = hwep_jit(n_het, n_het//2, n_het//2) + p = hwep_jit(n_het, n_het // 2, n_het // 2) assert np.isclose(p, 1.0, atol=1e-8) # Test case way out of equilibrium - p = hwep_jit(n_het, n_het//10, n_het//2 + n_het//10) + p = hwep_jit(n_het, n_het // 10, n_het // 2 + n_het // 10) assert np.isclose(p, 0, atol=1e-8) - \ No newline at end of file From 59290e72cd99b4baf5abe07ebda63c9a7f91c7b3 Mon Sep 17 00:00:00 2001 From: eric-czech Date: Tue, 7 Jul 2020 18:43:40 -0400 Subject: [PATCH 3/9] Formatting --- sgkit/tests/test_hwe.py | 58 ++++++++++++++++++++--------------------- 1 file changed, 29 insertions(+), 29 deletions(-) diff --git a/sgkit/tests/test_hwe.py b/sgkit/tests/test_hwe.py index 33dc04af5..acad42350 100644 --- a/sgkit/tests/test_hwe.py +++ b/sgkit/tests/test_hwe.py @@ -15,6 +15,34 @@ def get_obs_gt_cts(): return np.concatenate((n_het, n_hom), axis=1) +def test_hwep_against_reference_impl(): + args = get_obs_gt_cts() + p = [hwep(*arg) for arg in args] + np.testing.assert_allclose(p, EXPECTED_P_VAL) + + +def test_hwep_raise_on_negative(): + args = [[-1, 0, 0], [0, -1, 0], [0, 0, -1]] + for arg in args: + with pytest.raises(ValueError): + hwep(*arg) + + +def test_hwep_zeros(): + assert np.isnan(hwep(0, 0, 0)) + + +def test_hwep_large_counts(): + # Note: use jit-compiled function for large counts to avoid slowing build down + for n_het in 10 ** np.arange(3, 8): + # Test case in perfect equilibrium + p = hwep_jit(n_het, n_het // 2, n_het // 2) + assert np.isclose(p, 1.0, atol=1e-8) + # Test case way out of equilibrium + p = hwep_jit(n_het, n_het // 10, n_het // 2 + n_het // 10) + assert np.isclose(p, 0, atol=1e-8) + + # Export from execution of C/C++ code at http://csg.sph.umich.edu/abecasis/Exact/snp_hwe.c EXPECTED_P_VAL = [ 1.00000000e000, @@ -217,32 +245,4 @@ def get_obs_gt_cts(): 1.84978056e-007, 6.97911818e-017, 6.09976723e-137, -] - - -def test_hwep_against_reference_impl(): - args = get_obs_gt_cts() - p = [hwep(*arg) for arg in args] - np.testing.assert_allclose(p, EXPECTED_P_VAL) - - -def test_hwep_raise_on_negative(): - args = [[-1, 0, 0], [0, -1, 0], [0, 0, -1]] - for arg in args: - with pytest.raises(ValueError): - hwep(*arg) - - -def test_hwep_zeros(): - assert np.isnan(hwep(0, 0, 0)) - - -def test_hwep_large_counts(): - # Note: use jit-compiled function for large counts to avoid slowing build down - for n_het in 10 ** np.arange(3, 8): - # Test case in perfect equilibrium - p = hwep_jit(n_het, n_het // 2, n_het // 2) - assert np.isclose(p, 1.0, atol=1e-8) - # Test case way out of equilibrium - p = hwep_jit(n_het, n_het // 10, n_het // 2 + n_het // 10) - assert np.isclose(p, 0, atol=1e-8) +] \ No newline at end of file From 9d9de8ac83de784665a88e2d94ce0bd6fcee4b3b Mon Sep 17 00:00:00 2001 From: eric-czech Date: Wed, 8 Jul 2020 12:46:07 -0400 Subject: [PATCH 4/9] Adding more tests --- sgkit/stats/hwe.py | 46 ++++++++++++++++++- sgkit/tests/test_hwe.py | 99 +++++++++++++++++++++++++++++++++++++++-- 2 files changed, 140 insertions(+), 5 deletions(-) diff --git a/sgkit/stats/hwe.py b/sgkit/stats/hwe.py index e054568f6..a8e2a4e02 100644 --- a/sgkit/stats/hwe.py +++ b/sgkit/stats/hwe.py @@ -1,7 +1,11 @@ +import dask.array as da import numpy as np +import xarray as xr from numba import njit +from numpy import ndarray +from xarray import Dataset -# TODO: Is there a way to get coverage on jit functions? +# TODO: Is there a way to get coverage on jit-compiled functions? def hardy_weinberg_p_value( @@ -95,4 +99,42 @@ def hardy_weinberg_p_value( return p -hardy_weinberg_p_value_jit = njit(hardy_weinberg_p_value) +# Benchmarks show ~25% improvement w/ fastmath on large (~10M) counts +hardy_weinberg_p_value_jit = njit(hardy_weinberg_p_value, fastmath=True) + + +def hardy_weinberg_p_value_vec( + obs_hets: ndarray, obs_hom1: ndarray, obs_hom2: ndarray +) -> ndarray: + arrs = [obs_hets, obs_hom1, obs_hom2] + if len(set(map(len, arrs))) != 1: + raise ValueError("All arrays must have same length") + if list(set(map(lambda x: x.ndim, arrs))) != [1]: + raise ValueError("All arrays must be 1D") + n = len(obs_hets) + p = np.empty(n, dtype=np.float64) + for i in range(n): + p[i] = hardy_weinberg_p_value_jit(obs_hets[i], obs_hom1[i], obs_hom2[i]) + return p + + +hardy_weinberg_p_value_vec_jit = njit(hardy_weinberg_p_value_vec, fastmath=True) + + +def hardy_weinberg_test(ds: Dataset): + if ds.dims["ploidy"] != 2: + raise NotImplementedError("HWE test only implemented for diploid genotypes") + if ds.dims["alleles"] != 2: + raise NotImplementedError("HWE test only implemented for biallelic genotypes") + if "call/allele_count" in ds: + ac = ds["call/allele_count"] + else: + # TODO: centralize allele counting like this somewhere + mask = ds["call/genotype_mask"].any(dim="ploidy") + ac = xr.where(mask, -1, ds["call/genotype"].sum(dim="ploidy")) + # Split into separate per-variant sums for homozygotes and heterozygotes; + # note that negative values will be ignored + cts = [1, 0, 2] # arg order: hets, hom1, hom2 + obs = [da.asarray((ac == ct).sum(dim="samples")) for ct in cts] + p = da.map_blocks(hardy_weinberg_p_value_vec_jit, *obs) + return xr.Dataset({"variant/hwe_p_value": ("variants", p)}) diff --git a/sgkit/tests/test_hwe.py b/sgkit/tests/test_hwe.py index acad42350..2e64cff84 100644 --- a/sgkit/tests/test_hwe.py +++ b/sgkit/tests/test_hwe.py @@ -1,11 +1,82 @@ +from typing import Optional, Tuple + import numpy as np import pytest +from sgkit.api import create_genotype_call_dataset from sgkit.stats.hwe import hardy_weinberg_p_value as hwep from sgkit.stats.hwe import hardy_weinberg_p_value_jit as hwep_jit +from sgkit.stats.hwe import hardy_weinberg_p_value_vec as hwep_vec +from sgkit.stats.hwe import hardy_weinberg_p_value_vec_jit as hwep_vec_jit +from sgkit.stats.hwe import hardy_weinberg_test as hwep_test + + +def to_genotype_call_dataset( + call_genotype, n_contig: int = 1, seed: Optional[int] = None +): + """Wrap a genotype call array in a Dataset instance + + Parameters + ---------- + call_genotype : (M, N, P) array-like + Genotype call array + n_contig : int, optional + Number of contigs to create in result, by default 1 + seed : int, optional + Seed for random number generation + + Returns + ------- + Dataset + Dataset from `sgkit.create_genotype_call_dataset` + """ + rs = np.random.RandomState(seed=seed) + m, n = call_genotype.shape[:2] + contig_size = np.ceil(m / n_contig).astype(int) + contig = np.arange(m) % contig_size + contig_names = np.unique(contig) + position = np.concatenate([np.arange(contig_size) for i in range(n_contig)])[:m] + alleles = rs.choice(["A", "C", "G", "T"], size=(m, 2)).astype("S") + sample_id = np.array([f"S{i}" for i in range(n)]) + return create_genotype_call_dataset( + variant_contig_names=list(contig_names), + variant_contig=contig, + variant_position=position, + variant_alleles=alleles, + sample_id=sample_id, + call_genotype=call_genotype, + ) + + +def simulate_genotype_calls(m: int, n: int, p: Tuple[float, float, float]): + """Get dataset with diploid calls simulated from genotype distribution + + Parameters + ---------- + m : int + Number of variants + n : int + Number of samples + p : Tuple[float, float, float] + Genotype distribution as float in [0, 1] with order + homozygous ref, heterozygous, homozygous alt + Returns + ------- + call_genotype: array-like + Dataset from `sgkit.create_genotype_call_dataset` + """ + rs = np.random.RandomState(1) + # Draw genotype codes with provided distribution + gt = np.stack([rs.choice([0, 1, 2], size=n, replace=True, p=p) for i in range(m)]) + # Expand 3rd dimenion with calls matching genotypes + return np.stack([np.where(gt == 0, 0, 1), np.where(gt == 2, 1, 0)], axis=-1) -def get_obs_gt_cts(): + +def get_genotype_counts(): + # Arguments for hwe calculations generated here + # match those generated externally for validation + # against C implementation (i.e. do not parameterize) n, step = 10_000, 50 rs = np.random.RandomState(0) n_het = np.expand_dims(np.arange(n, step=step) + 1, -1) @@ -16,7 +87,7 @@ def get_obs_gt_cts(): def test_hwep_against_reference_impl(): - args = get_obs_gt_cts() + args = get_genotype_counts() p = [hwep(*arg) for arg in args] np.testing.assert_allclose(p, EXPECTED_P_VAL) @@ -43,6 +114,28 @@ def test_hwep_large_counts(): assert np.isclose(p, 0, atol=1e-8) +def test_hwep_vec(): + args = get_genotype_counts() + p = hwep_vec(*args.T) + np.testing.assert_allclose(p, EXPECTED_P_VAL) + p = hwep_vec_jit(*args.T) + np.testing.assert_allclose(p, EXPECTED_P_VAL) + + +def test_hwep_dataset(): + gt_dist = [0.25, 0.5, 0.25] + call_genotype = simulate_genotype_calls(50, 1000, p=gt_dist) + ds = to_genotype_call_dataset(call_genotype) + p = hwep_test(ds)["variant/hwe_p_value"].values + assert np.all(p > 1e-8) + + gt_dist = [0.9, 0.05, 0.05] + call_genotype = simulate_genotype_calls(50, 1000, p=gt_dist) + ds = to_genotype_call_dataset(call_genotype) + p = hwep_test(ds)["variant/hwe_p_value"].values + assert np.all(p < 1e-8) + + # Export from execution of C/C++ code at http://csg.sph.umich.edu/abecasis/Exact/snp_hwe.c EXPECTED_P_VAL = [ 1.00000000e000, @@ -245,4 +338,4 @@ def test_hwep_large_counts(): 1.84978056e-007, 6.97911818e-017, 6.09976723e-137, -] \ No newline at end of file +] From fcdf6b532cf98fd0bc9a762af4ed0f5ab41c83c1 Mon Sep 17 00:00:00 2001 From: eric-czech Date: Wed, 8 Jul 2020 16:16:57 -0400 Subject: [PATCH 5/9] Adding tests for full coverage --- sgkit/tests/test_hwe.py | 26 +++++++++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/sgkit/tests/test_hwe.py b/sgkit/tests/test_hwe.py index 2e64cff84..0f303429f 100644 --- a/sgkit/tests/test_hwe.py +++ b/sgkit/tests/test_hwe.py @@ -2,6 +2,7 @@ import numpy as np import pytest +import xarray as xr from sgkit.api import create_genotype_call_dataset from sgkit.stats.hwe import hardy_weinberg_p_value as hwep @@ -33,7 +34,7 @@ def to_genotype_call_dataset( rs = np.random.RandomState(seed=seed) m, n = call_genotype.shape[:2] contig_size = np.ceil(m / n_contig).astype(int) - contig = np.arange(m) % contig_size + contig = np.arange(m) // contig_size contig_names = np.unique(contig) position = np.concatenate([np.arange(contig_size) for i in range(n_contig)])[:m] alleles = rs.choice(["A", "C", "G", "T"], size=(m, 2)).astype("S") @@ -122,19 +123,42 @@ def test_hwep_vec(): np.testing.assert_allclose(p, EXPECTED_P_VAL) +def test_hwep_vec_bad_args(): + with pytest.raises(ValueError): + hwep_vec(np.zeros(2), np.zeros(1), np.zeros(1)) + with pytest.raises(ValueError): + hwep_vec(np.zeros((2, 2)), np.zeros(2), np.zeros(2)) + + def test_hwep_dataset(): + # Test cases in equilibrium gt_dist = [0.25, 0.5, 0.25] call_genotype = simulate_genotype_calls(50, 1000, p=gt_dist) ds = to_genotype_call_dataset(call_genotype) p = hwep_test(ds)["variant/hwe_p_value"].values assert np.all(p > 1e-8) + # Test cases out of equilibrium gt_dist = [0.9, 0.05, 0.05] call_genotype = simulate_genotype_calls(50, 1000, p=gt_dist) ds = to_genotype_call_dataset(call_genotype) p = hwep_test(ds)["variant/hwe_p_value"].values assert np.all(p < 1e-8) + # Test with pre-assigned counts + ds = ds.assign(**{"call/allele_count": ds["call/genotype"].sum(dim="ploidy")}) + p = hwep_test(ds) + assert np.all(p < 1e-8) + + +def test_hwep_dataset_bad_args(): + with pytest.raises(NotImplementedError): + ds = xr.Dataset({"x": (("ploidy", "alleles"), np.zeros((3, 2)))}) + hwep_test(ds) + with pytest.raises(NotImplementedError): + ds = xr.Dataset({"x": (("ploidy", "alleles"), np.zeros((2, 3)))}) + hwep_test(ds) + # Export from execution of C/C++ code at http://csg.sph.umich.edu/abecasis/Exact/snp_hwe.c EXPECTED_P_VAL = [ From d2ec9a5344a8c5eb7021d2363aea8540e77450dd Mon Sep 17 00:00:00 2001 From: eric-czech Date: Wed, 29 Jul 2020 07:57:25 -0400 Subject: [PATCH 6/9] Refactoring tests to match new conventions --- requirements-dev.txt | 1 + setup.cfg | 2 + sgkit/stats/hwe.py | 37 +-- sgkit/tests/test_hwe.py | 393 ++++++++------------------------ sgkit/tests/test_hwe/sim_01.csv | 201 ++++++++++++++++ 5 files changed, 317 insertions(+), 317 deletions(-) create mode 100644 sgkit/tests/test_hwe/sim_01.csv diff --git a/requirements-dev.txt b/requirements-dev.txt index c5a02b2b5..d17d2c369 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -2,4 +2,5 @@ codecov pre-commit pytest pytest-cov +pytest-datadir statsmodels diff --git a/setup.cfg b/setup.cfg index 14693e3ba..4917f620e 100644 --- a/setup.cfg +++ b/setup.cfg @@ -67,6 +67,8 @@ ignore_missing_imports = True ignore_missing_imports = True [mypy-dask.*] ignore_missing_imports = True +[mypy-numba.*] +ignore_missing_imports = True [mypy-pytest.*] ignore_missing_imports = True [mypy-statsmodels.*] diff --git a/sgkit/stats/hwe.py b/sgkit/stats/hwe.py index a8e2a4e02..a6506bea0 100644 --- a/sgkit/stats/hwe.py +++ b/sgkit/stats/hwe.py @@ -1,3 +1,5 @@ +from typing import Hashable, Optional + import dask.array as da import numpy as np import xarray as xr @@ -5,12 +7,8 @@ from numpy import ndarray from xarray import Dataset -# TODO: Is there a way to get coverage on jit-compiled functions? - -def hardy_weinberg_p_value( - obs_hets: int, obs_hom1: int, obs_hom2: int -) -> float: # pragma: no cover +def hardy_weinberg_p_value(obs_hets: int, obs_hom1: int, obs_hom2: int) -> float: """Exact test for HWE as described in Wigginton et al. 2005 [1] Parameters @@ -48,7 +46,7 @@ def hardy_weinberg_p_value( het_probs = np.zeros(obs_mac + 1, dtype=np.float64) if obs_n == 0: - return np.nan + return np.nan # type: ignore[no-any-return] # Identify distribution midpoint mid = int(obs_mac * (2 * obs_n - obs_mac) / (2 * obs_n)) @@ -90,13 +88,13 @@ def hardy_weinberg_p_value( curr_homc -= 1 curr_hets += 2 - if prob_sum <= 0: - return np.nan + if prob_sum <= 0: # pragma: no cover + return np.nan # type: ignore[no-any-return] het_probs = het_probs / prob_sum p = het_probs[het_probs <= het_probs[obs_hets]].sum() p = max(min(1.0, p), 0.0) - return p + return p # type: ignore[no-any-return] # Benchmarks show ~25% improvement w/ fastmath on large (~10M) counts @@ -121,20 +119,23 @@ def hardy_weinberg_p_value_vec( hardy_weinberg_p_value_vec_jit = njit(hardy_weinberg_p_value_vec, fastmath=True) -def hardy_weinberg_test(ds: Dataset): +def hardy_weinberg_test( + ds: Dataset, genotype_counts: Optional[Hashable] = None +) -> Dataset: if ds.dims["ploidy"] != 2: raise NotImplementedError("HWE test only implemented for diploid genotypes") if ds.dims["alleles"] != 2: raise NotImplementedError("HWE test only implemented for biallelic genotypes") - if "call/allele_count" in ds: - ac = ds["call/allele_count"] + # Use precomputed genotype counts, if provided + if genotype_counts is not None: + obs = list(da.asarray(ds[genotype_counts]).T) + # Otherwise, compute genotype counts from calls else: - # TODO: centralize allele counting like this somewhere + # TODO: Use API genotype counting function instead, e.g. + # https://github.com/pystatgen/sgkit/issues/29#issuecomment-656691069 mask = ds["call/genotype_mask"].any(dim="ploidy") - ac = xr.where(mask, -1, ds["call/genotype"].sum(dim="ploidy")) - # Split into separate per-variant sums for homozygotes and heterozygotes; - # note that negative values will be ignored - cts = [1, 0, 2] # arg order: hets, hom1, hom2 - obs = [da.asarray((ac == ct).sum(dim="samples")) for ct in cts] + gtc = xr.where(mask, -1, ds["call/genotype"].sum(dim="ploidy")) # type: ignore[no-untyped-call] + cts = [1, 0, 2] # arg order: hets, hom1, hom2 + obs = [da.asarray((gtc == ct).sum(dim="samples")) for ct in cts] p = da.map_blocks(hardy_weinberg_p_value_vec_jit, *obs) return xr.Dataset({"variant/hwe_p_value": ("variants", p)}) diff --git a/sgkit/tests/test_hwe.py b/sgkit/tests/test_hwe.py index 0f303429f..b92685e6b 100644 --- a/sgkit/tests/test_hwe.py +++ b/sgkit/tests/test_hwe.py @@ -1,62 +1,31 @@ -from typing import Optional, Tuple +from pathlib import Path +from typing import Tuple import numpy as np +import pandas as pd import pytest import xarray as xr +from pandas import DataFrame +from xarray import DataArray, Dataset -from sgkit.api import create_genotype_call_dataset from sgkit.stats.hwe import hardy_weinberg_p_value as hwep from sgkit.stats.hwe import hardy_weinberg_p_value_jit as hwep_jit from sgkit.stats.hwe import hardy_weinberg_p_value_vec as hwep_vec from sgkit.stats.hwe import hardy_weinberg_p_value_vec_jit as hwep_vec_jit from sgkit.stats.hwe import hardy_weinberg_test as hwep_test +from sgkit.testing import simulate_genotype_call_dataset -def to_genotype_call_dataset( - call_genotype, n_contig: int = 1, seed: Optional[int] = None -): - """Wrap a genotype call array in a Dataset instance +def simulate_genotype_calls( + n_variant: int, n_sample: int, p: Tuple[float, float, float], seed: int = 0 +) -> DataArray: + """Get dataset with diploid calls simulated from provided genotype distribution Parameters ---------- - call_genotype : (M, N, P) array-like - Genotype call array - n_contig : int, optional - Number of contigs to create in result, by default 1 - seed : int, optional - Seed for random number generation - - Returns - ------- - Dataset - Dataset from `sgkit.create_genotype_call_dataset` - """ - rs = np.random.RandomState(seed=seed) - m, n = call_genotype.shape[:2] - contig_size = np.ceil(m / n_contig).astype(int) - contig = np.arange(m) // contig_size - contig_names = np.unique(contig) - position = np.concatenate([np.arange(contig_size) for i in range(n_contig)])[:m] - alleles = rs.choice(["A", "C", "G", "T"], size=(m, 2)).astype("S") - sample_id = np.array([f"S{i}" for i in range(n)]) - return create_genotype_call_dataset( - variant_contig_names=list(contig_names), - variant_contig=contig, - variant_position=position, - variant_alleles=alleles, - sample_id=sample_id, - call_genotype=call_genotype, - ) - - -def simulate_genotype_calls(m: int, n: int, p: Tuple[float, float, float]): - """Get dataset with diploid calls simulated from genotype distribution - - Parameters - ---------- - m : int + n_variant : int Number of variants - n : int + n_sample : int Number of samples p : Tuple[float, float, float] Genotype distribution as float in [0, 1] with order @@ -64,47 +33,55 @@ def simulate_genotype_calls(m: int, n: int, p: Tuple[float, float, float]): Returns ------- - call_genotype: array-like - Dataset from `sgkit.create_genotype_call_dataset` + call_genotype : (variants, samples, ploidy) DataArray + Genotype call matrix as 3D array with ploidy = 2. """ - rs = np.random.RandomState(1) + rs = np.random.RandomState(seed) # Draw genotype codes with provided distribution - gt = np.stack([rs.choice([0, 1, 2], size=n, replace=True, p=p) for i in range(m)]) - # Expand 3rd dimenion with calls matching genotypes - return np.stack([np.where(gt == 0, 0, 1), np.where(gt == 2, 1, 0)], axis=-1) + gt = np.stack( + [ + rs.choice([0, 1, 2], size=n_sample, replace=True, p=p) + for i in range(n_variant) + ] + ) + # Expand 3rd dimension with calls matching genotypes + gt = np.stack([np.where(gt == 0, 0, 1), np.where(gt == 2, 1, 0)], axis=-1) + return xr.DataArray(gt, dims=("variants", "samples", "ploidy")) -def get_genotype_counts(): - # Arguments for hwe calculations generated here - # match those generated externally for validation - # against C implementation (i.e. do not parameterize) - n, step = 10_000, 50 - rs = np.random.RandomState(0) - n_het = np.expand_dims(np.arange(n, step=step) + 1, -1) - frac = rs.uniform(0.3, 0.7, size=(n // step, 2)) - n_hom = frac * n_het - n_hom = n_hom.astype(int) - return np.concatenate((n_het, n_hom), axis=1) +def get_simulation_data(datadir: Path) -> DataFrame: + return pd.read_csv(datadir / "sim_01.csv") -def test_hwep_against_reference_impl(): - args = get_genotype_counts() - p = [hwep(*arg) for arg in args] - np.testing.assert_allclose(p, EXPECTED_P_VAL) +def test_hwep__reference_impl_comparison(datadir): + df = get_simulation_data(datadir) + cts = df[["n_het", "n_hom_1", "n_hom_2"]].values + p_expected = df["p"].values + p_actual = hwep_vec(*cts.T) + np.testing.assert_allclose(p_expected, p_actual) + p_actual = hwep_vec_jit(*cts.T) + np.testing.assert_allclose(p_expected, p_actual) -def test_hwep_raise_on_negative(): +def test_hwep__raise_on_negative(): args = [[-1, 0, 0], [0, -1, 0], [0, 0, -1]] for arg in args: with pytest.raises(ValueError): hwep(*arg) -def test_hwep_zeros(): +def test_hwep__zeros(): assert np.isnan(hwep(0, 0, 0)) -def test_hwep_large_counts(): +def test_hwep__pass(): + # These seemingly arbitrary arguments trigger separate conditional + # branches based on odd/even midpoints in the Levene-Haldane distribution + assert not np.isnan(hwep(1, 1, 1)) + assert not np.isnan(hwep(1, 2, 2)) + + +def test_hwep__large_counts(): # Note: use jit-compiled function for large counts to avoid slowing build down for n_het in 10 ** np.arange(3, 8): # Test case in perfect equilibrium @@ -115,251 +92,69 @@ def test_hwep_large_counts(): assert np.isclose(p, 0, atol=1e-8) -def test_hwep_vec(): - args = get_genotype_counts() - p = hwep_vec(*args.T) - np.testing.assert_allclose(p, EXPECTED_P_VAL) - p = hwep_vec_jit(*args.T) - np.testing.assert_allclose(p, EXPECTED_P_VAL) +def test_hwep_vec__raise_on_unequal_dims(): + with pytest.raises(ValueError, match="All arrays must have same length"): + hwep_vec(np.zeros(2), np.zeros(1), np.zeros(1)) -def test_hwep_vec_bad_args(): - with pytest.raises(ValueError): - hwep_vec(np.zeros(2), np.zeros(1), np.zeros(1)) - with pytest.raises(ValueError): +def test_hwep_vec__raise_on_non1d(): + with pytest.raises(ValueError, match="All arrays must be 1D"): hwep_vec(np.zeros((2, 2)), np.zeros(2), np.zeros(2)) -def test_hwep_dataset(): - # Test cases in equilibrium - gt_dist = [0.25, 0.5, 0.25] - call_genotype = simulate_genotype_calls(50, 1000, p=gt_dist) - ds = to_genotype_call_dataset(call_genotype) - p = hwep_test(ds)["variant/hwe_p_value"].values +@pytest.fixture(scope="module") +def ds_eq(): + """Dataset with all variants near HWE""" + ds = simulate_genotype_call_dataset(n_variant=50, n_sample=1000) + gt_dist = (0.25, 0.5, 0.25) + ds["call/genotype"] = simulate_genotype_calls( + ds.dims["variants"], ds.dims["samples"], p=gt_dist + ) + return ds + + +@pytest.fixture(scope="module") +def ds_neq(): + """Dataset with all variants well out of HWE""" + ds = simulate_genotype_call_dataset(n_variant=50, n_sample=1000) + gt_dist = (0.9, 0.05, 0.05) + ds["call/genotype"] = simulate_genotype_calls( + ds.dims["variants"], ds.dims["samples"], p=gt_dist + ) + return ds + + +def test_hwep_dataset__in_eq(ds_eq: Dataset) -> None: + p = hwep_test(ds_eq)["variant/hwe_p_value"].values assert np.all(p > 1e-8) - # Test cases out of equilibrium - gt_dist = [0.9, 0.05, 0.05] - call_genotype = simulate_genotype_calls(50, 1000, p=gt_dist) - ds = to_genotype_call_dataset(call_genotype) - p = hwep_test(ds)["variant/hwe_p_value"].values + +def test_hwep_dataset__out_of_eq(ds_neq: Dataset) -> None: + p = hwep_test(ds_neq)["variant/hwe_p_value"].values assert np.all(p < 1e-8) - # Test with pre-assigned counts - ds = ds.assign(**{"call/allele_count": ds["call/genotype"].sum(dim="ploidy")}) - p = hwep_test(ds) + +def test_hwep_dataset__precomputed_counts(ds_neq: Dataset) -> None: + ds = ds_neq + ac = ds["call/genotype"].sum(dim="ploidy") + cts = [1, 0, 2] # arg order: hets, hom1, hom2 + gtc = xr.concat([(ac == ct).sum(dim="samples") for ct in cts], dim="counts").T # type: ignore[no-untyped-call] + ds = ds.assign(**{"variant/genotype_counts": gtc}) + p = hwep_test(ds, genotype_counts="variant/genotype_counts") assert np.all(p < 1e-8) -def test_hwep_dataset_bad_args(): - with pytest.raises(NotImplementedError): +def test_hwep_dataset__raise_on_nondiploid(): + with pytest.raises( + NotImplementedError, match="HWE test only implemented for diploid genotypes" + ): ds = xr.Dataset({"x": (("ploidy", "alleles"), np.zeros((3, 2)))}) hwep_test(ds) - with pytest.raises(NotImplementedError): - ds = xr.Dataset({"x": (("ploidy", "alleles"), np.zeros((2, 3)))}) - hwep_test(ds) -# Export from execution of C/C++ code at http://csg.sph.umich.edu/abecasis/Exact/snp_hwe.c -EXPECTED_P_VAL = [ - 1.00000000e000, - 8.45926829e-001, - 8.89304245e-001, - 3.68487492e-001, - 2.83442131e-001, - 1.93780506e-001, - 3.46415612e-002, - 9.77805142e-007, - 9.00169099e-002, - 2.77392776e-004, - 5.78595078e-006, - 1.56290046e-001, - 3.11983705e-002, - 7.78234779e-001, - 6.28255056e-001, - 9.17242816e-001, - 8.81087089e-001, - 1.20954751e-004, - 6.51960684e-002, - 4.87927509e-007, - 6.14320396e-002, - 1.67216769e-003, - 2.58323982e-003, - 9.22666204e-012, - 1.15591803e-003, - 1.00000000e000, - 5.21303203e-001, - 2.40595832e-012, - 1.79017126e-001, - 8.50964237e-004, - 4.08782584e-018, - 2.65625649e-003, - 1.73047163e-007, - 2.61257337e-002, - 3.40282167e-002, - 5.57265342e-006, - 2.28187711e-010, - 3.71009969e-005, - 2.02796027e-015, - 2.85690782e-015, - 4.43715904e-004, - 1.24880234e-005, - 1.39680904e-002, - 6.69133747e-009, - 9.43219724e-010, - 6.10161450e-001, - 1.93499955e-003, - 1.44451527e-014, - 1.15651799e-011, - 6.16416362e-006, - 2.18519190e-001, - 2.67902896e-020, - 3.81265044e-003, - 1.87170429e-002, - 2.87276124e-001, - 1.46939801e-004, - 5.90523804e-001, - 9.00712608e-003, - 7.82143524e-011, - 1.55029275e-016, - 1.00796610e-003, - 6.51775272e-018, - 7.22627291e-001, - 3.50621941e-033, - 2.15694037e-001, - 5.36554440e-001, - 4.98209450e-023, - 1.00725415e-002, - 2.83256119e-004, - 2.31647615e-001, - 5.40831311e-004, - 2.28693251e-006, - 2.33943256e-016, - 4.63666449e-002, - 1.95571664e-029, - 1.32013500e-001, - 1.93010279e-006, - 1.72246817e-002, - 4.44008208e-010, - 2.64771353e-025, - 1.42567926e-002, - 2.34658222e-023, - 5.14985651e-044, - 4.48467881e-038, - 2.38901290e-003, - 3.00019737e-020, - 9.91998679e-058, - 3.85771324e-001, - 1.19901665e-004, - 1.09586529e-012, - 4.52696626e-007, - 4.52117435e-005, - 3.74269466e-022, - 1.84769664e-002, - 9.01235925e-001, - 4.71167421e-016, - 7.26213285e-001, - 2.68067642e-005, - 1.95763513e-027, - 3.44681033e-030, - 6.72973257e-001, - 1.90998085e-021, - 2.71129678e-092, - 1.33474542e-002, - 9.42328262e-016, - 6.04559513e-002, - 2.73568136e-002, - 3.45497420e-013, - 1.85964309e-010, - 2.25791165e-016, - 8.88002002e-023, - 7.31645858e-001, - 6.20103273e-001, - 2.02013957e-003, - 3.26543825e-041, - 9.55096556e-034, - 1.58435946e-031, - 1.67723973e-017, - 3.01571822e-004, - 5.94647843e-004, - 3.50999380e-003, - 1.42692287e-018, - 4.40701593e-002, - 1.02072821e-010, - 6.12844453e-020, - 4.01149386e-007, - 4.52329633e-028, - 6.36621011e-004, - 2.40691727e-003, - 1.51079564e-004, - 1.46439431e-059, - 1.19603499e-007, - 2.30499126e-023, - 3.90483620e-004, - 3.00491712e-033, - 4.67334134e-075, - 2.14446525e-007, - 5.74808603e-002, - 7.54901939e-059, - 1.00820382e-028, - 5.45503604e-002, - 2.00408985e-029, - 2.60055020e-038, - 1.37950333e-021, - 1.67336706e-003, - 5.11497091e-038, - 9.63001456e-002, - 1.85048263e-012, - 7.60512104e-005, - 1.90260703e-097, - 8.41707732e-055, - 5.02772009e-056, - 4.74769747e-021, - 1.53427038e-108, - 3.65547065e-022, - 3.59345583e-005, - 4.29008968e-115, - 2.29690838e-003, - 5.12962271e-001, - 2.82010264e-044, - 1.25488919e-059, - 4.26516777e-072, - 2.92597766e-014, - 1.13938024e-020, - 2.65101694e-019, - 6.39260807e-003, - 3.44575391e-019, - 2.46964669e-042, - 2.18893082e-023, - 2.32535921e-005, - 3.67548497e-033, - 6.28178465e-050, - 4.01855250e-010, - 8.14210277e-007, - 7.19942047e-038, - 1.23293898e-028, - 1.04555107e-001, - 2.80977631e-008, - 3.38829632e-065, - 3.67682844e-014, - 7.97794167e-001, - 9.88137129e-001, - 7.83054274e-016, - 6.10205517e-003, - 3.54737998e-051, - 1.00000000e000, - 1.23015267e-024, - 7.06536040e-069, - 2.27403687e-082, - 2.12853071e-001, - 2.09868517e-014, - 4.20835611e-040, - 1.72349554e-079, - 1.58828256e-003, - 6.46108778e-001, - 1.80557310e-058, - 2.70043232e-001, - 1.84978056e-007, - 6.97911818e-017, - 6.09976723e-137, -] +def test_hwep_dataset__raise_on_biallelic(): + with pytest.raises( + NotImplementedError, match="HWE test only implemented for biallelic genotypes" + ): + ds = xr.Dataset({"x": (("ploidy", "alleles"), np.zeros((2, 3)))}) + hwep_test(ds) diff --git a/sgkit/tests/test_hwe/sim_01.csv b/sgkit/tests/test_hwe/sim_01.csv new file mode 100644 index 000000000..97bbb05d7 --- /dev/null +++ b/sgkit/tests/test_hwe/sim_01.csv @@ -0,0 +1,201 @@ +n_het,n_hom_1,n_hom_2,p +1,0,0,1.0 +51,27,26,0.845926828898329 +101,47,56,0.88930424473698 +151,71,99,0.3684874920023832 +201,137,91,0.28344213097688165 +251,154,128,0.1937805064539818 +301,158,201,0.03464156123462855 +351,115,117,9.778051416840305e-07 +401,123,253,0.09001690989721362 +451,275,292,0.00027739277561481717 +501,346,310,5.785950777451914e-06 +551,267,337,0.15629004579821143 +601,208,334,0.031198370505333042 +651,232,441,0.7782347792381512 +701,356,326,0.6282550560763566 +751,304,457,0.9172428161320041 +801,386,422,0.8810870885296063 +851,261,465,0.00012095475067049487 +901,490,492,0.06519606837210713 +951,644,544,4.879275092149919e-07 +1001,444,475,0.0614320395959403 +1051,608,340,0.0016721676932256468 +1101,623,625,0.002583239824235713 +1151,442,404,9.22666204121466e-12 +1201,511,535,0.0011559180266171758 +1251,660,594,0.999999999999999 +1301,904,443,0.5213032029206021 +1351,518,492,2.4059583160333022e-12 +1401,786,562,0.17901712606288267 +1451,705,577,0.0008509642369958977 +1501,545,516,4.0878258413432674e-18 +1551,872,551,0.0026562564901761523 +1601,606,716,1.7304716281165935e-07 +1651,1037,559,0.026125733725720014 +1701,1080,575,0.03402821665047714 +1751,1209,853,5.572653418323456e-06 +1801,1243,976,2.2818771131803835e-10 +1851,1102,584,3.710099688663844e-05 +1901,785,661,2.0279602738807187e-15 +1951,816,677,2.8569078183446102e-15 +2001,854,931,0.00044371590402349367 +2051,667,1183,1.2488023350848009e-05 +2101,1106,853,0.01396809036282682 +2151,1095,726,6.691337471001894e-09 +2201,1167,1478,9.432197243064005e-10 +2251,962,1276,0.610161450028642 +2301,811,1349,0.0019349995547350155 +2351,977,877,1.444515271924519e-14 +2401,1283,739,1.1565179920072941e-11 +2451,1547,739,6.164163620718031e-06 +2501,1428,1020,0.21851919021140662 +2551,1515,1747,2.6790289570097594e-20 +2601,1039,1379,0.003812650444672208 +2651,1423,1402,0.018717042880622592 +2701,1051,1839,0.2872761239911692 +2751,1317,1756,0.00014693980061811332 +2801,1623,1173,0.5905238038361554 +2851,1783,1307,0.009007126080770289 +2901,1892,1544,7.821435236313817e-11 +2951,1926,1702,1.5502927534087137e-16 +3001,1770,1502,0.0010079660986331209 +3051,2082,1701,6.517752718327402e-18 +3101,1456,1682,0.7226272913884884 +3151,969,1325,3.506219413030305e-33 +3201,1805,1331,0.21569403737115994 +3251,1778,1532,0.5365544398754494 +3301,1169,1384,4.982094499906052e-23 +3351,1769,1797,0.010072541488350021 +3401,1801,1908,0.00028325611912255233 +3451,1935,1630,0.23164761482387491 +3501,2305,1565,0.0005408313105776889 +3551,1684,2332,2.2869325074804336e-06 +3601,2241,2094,2.3394325551477485e-16 +3651,1241,2438,0.04636664490239822 +3701,2167,2588,1.9557166448260108e-29 +3751,1349,2427,0.13201350023523056 +3801,1387,2076,1.9301027903365754e-06 +3851,1346,2461,0.017224681698312864 +3901,2430,2058,4.440082083118905e-10 +3951,1828,1294,2.647713527040051e-25 +4001,2316,1926,0.014256792586706637 +4051,2385,2619,2.3465822237767657e-23 +4101,2830,2634,5.149856509652088e-44 +4151,1264,1843,4.48467880717666e-38 +4201,2486,1548,0.002389012903737995 +4251,2161,1367,3.000197368786019e-20 +4301,1634,1322,9.919986792554593e-58 +4351,2686,1695,0.38577132434045974 +4401,1928,2954,0.0001199016650457927 +4451,2589,1391,1.0958652901993421e-12 +4501,1646,2469,4.5269662611402833e-07 +4551,2416,1798,4.521174352826995e-05 +4601,3099,2510,3.742694664287275e-22 +4651,2391,2492,0.01847696637280268 +4701,2783,1996,0.9012359250831754 +4751,2182,1824,4.711674213296269e-16 +4801,1797,3253,0.7262132846779064 +4851,2890,2406,2.680676415905425e-05 +4901,1916,1968,1.9576351291926123e-27 +4951,1600,2345,3.4468103297849734e-30 +5001,2124,2893,0.6729732567000968 +5051,2278,1878,1.9099808504594497e-21 +5101,1580,1667,2.7112967783602924e-92 +5151,2945,2480,0.0133474541803742 +5201,2676,3425,9.423282617882861e-16 +5251,3655,2030,0.06045595125773339 +5301,2996,2148,0.027356813588423446 +5351,1649,3228,3.454974198424328e-13 +5401,2311,2448,1.8596430888102382e-10 +5451,2918,3447,2.2579116525279354e-16 +5501,3034,3570,8.880020019491782e-23 +5551,2272,3437,0.7316458577549785 +5601,2096,3814,0.6201032727285236 +5651,3249,2182,0.002020139572176394 +5701,3870,3376,3.265438252095903e-41 +5751,2309,2216,9.550965561735892e-34 +5801,2942,1799,1.584359459942145e-31 +5851,2240,2749,1.6772397283662627e-17 +5901,2653,2864,0.0003015718221877463 +5951,2446,3182,0.0005946478430517058 +6001,3873,2082,0.0035099937981355705 +6051,3067,2134,1.4269228740338587e-18 +6101,3579,2796,0.044070159340087074 +6151,3236,2296,1.0207282058200817e-10 +6201,2219,3070,6.128444533922714e-20 +6251,2764,4226,4.0114938634006145e-07 +6301,3819,3777,4.523296332293788e-28 +6351,4201,2117,0.0006366210109203425 +6401,3334,3416,0.0024069172689109197 +6451,4417,2689,0.0001510795635284454 +6501,2576,2211,1.464394314932261e-59 +6551,2008,4401,1.1960349880178966e-07 +6601,3749,4053,2.3049912620301826e-23 +6651,2744,3555,0.00039048361950001163 +6701,2181,3311,3.0049171175443936e-33 +6751,4664,4392,4.673341342704634e-75 +6801,2960,4656,2.1444652549020045e-07 +6851,2690,4656,0.05748086029092399 +6901,4668,4276,7.549019391941723e-59 +6951,3838,4516,1.0082038188372868e-28 +7001,2920,4477,0.05455036043254638 +7051,3857,2152,2.004089851444536e-29 +7101,3116,2551,2.600550198892281e-38 +7151,4953,3513,1.3795033294064868e-21 +7201,3592,4002,0.0016733670587886574 +7251,3244,2572,5.114970910833606e-38 +7301,4591,2744,0.09630014560631561 +7351,3708,2864,1.850482626161518e-12 +7401,2509,4772,7.605121039219745e-05 +7451,5134,5098,1.9026070284002625e-97 +7501,4970,4572,8.417077321620368e-55 +7551,3271,2510,5.027720092225525e-56 +7601,3518,2986,4.747697469921551e-21 +7651,2700,2458,1.5342703779996893e-108 +7701,4545,2345,3.6554706456314347e-22 +7751,4714,2780,3.593455829002233e-05 +7801,2588,2619,4.290089679393479e-115 +7851,4465,3125,0.002296908375565855 +7901,3699,4131,0.5129622712973023 +7951,5122,4697,2.8201026369032667e-44 +8001,3265,2821,1.2548891947495432e-59 +8051,2593,3386,4.265167768158848e-72 +8101,3279,3908,2.925977657425333e-14 +8151,4673,4713,1.1393802442050531e-20 +8201,3390,3706,2.6510169444939377e-19 +8251,3073,5077,0.006392608069466881 +8301,2679,4804,3.445753907043592e-19 +8351,5106,5102,2.4696466905259013e-42 +8401,3392,3776,2.188930822752493e-23 +8451,4521,3457,2.325359213936893e-05 +8501,3811,3220,3.675484972066898e-33 +8551,4138,2717,6.2817846465740635e-50 +8601,5331,2845,4.01855250200448e-10 +8651,4390,3656,8.142102766041389e-07 +8701,4620,5949,7.199420468330689e-38 +8751,4885,2749,1.232938976024364e-28 +8801,4155,4435,0.10455510731861506 +8851,4553,5067,2.8097763112411485e-08 +8901,3658,3129,3.388296324761072e-65 +8951,4091,6109,3.676828438505e-14 +9001,3374,5955,0.7977941668959887 +9051,4684,4369,0.9881371290166111 +9101,5941,4399,7.830542740608026e-16 +9151,5396,4205,0.006102055167403756 +9201,6087,5299,3.5473799808013304e-51 +9251,5364,3987,1.0 +9301,5605,5156,1.2301526743633959e-24 +9351,3703,3405,7.065360404513176e-69 +9401,5815,6427,2.27403687143121e-82 +9451,4567,5069,0.2128530705427019 +9501,6109,4587,2.098685170824773e-14 +9551,6501,5064,4.208356112420144e-40 +9601,6032,6370,1.7234955440937876e-79 +9651,6043,3510,0.001588282559753373 +9701,5350,4456,0.6461087776569374 +9751,3169,4579,1.8055731034783026e-58 +9801,3954,6268,0.270043231915957 +9851,3086,6734,1.8497805574159475e-07 +9901,4377,4383,6.979118182818968e-17 +9951,3050,3722,6.099767234370392e-137 From eaec6f37dc505ea134e20de1819525e23fcf3734 Mon Sep 17 00:00:00 2001 From: eric-czech Date: Wed, 29 Jul 2020 08:35:19 -0400 Subject: [PATCH 7/9] Cleaning up docs --- sgkit/stats/association.py | 14 +++++------ sgkit/stats/hwe.py | 51 ++++++++++++++++++++++++++++++++------ 2 files changed, 51 insertions(+), 14 deletions(-) diff --git a/sgkit/stats/association.py b/sgkit/stats/association.py index 7e63e1b77..9fd56bf81 100644 --- a/sgkit/stats/association.py +++ b/sgkit/stats/association.py @@ -183,7 +183,7 @@ def gwas_linear_regression( 2D trait arrays will be assumed to contain separate traits within columns and concatenated to any 1D traits along the second axis (columns). add_intercept : bool, optional - Add intercept term to covariate set, by default True + Add intercept term to covariate set, by default True. Warnings -------- @@ -210,12 +210,12 @@ def gwas_linear_regression( ------- Dataset Dataset containing (N = num variants, O = num traits): - beta : (N, O) array-like - Beta values associated with each variant and trait - t_value : (N, O) array-like - T statistics for each beta - p_value : (N, O) array-like - P values as float in [0, 1] + variant/beta : (N, O) ArrayLike + Beta values associated with each variant and trait. + variant/t_value : (N, O) ArrayLike + T statistics for each beta. + variant/p_value : (N, O) ArrayLike + P values as float in [0, 1]. """ G = _get_loop_covariates(ds, dosage=dosage) X = _get_core_covariates(ds, covariates, add_intercept=add_intercept) diff --git a/sgkit/stats/hwe.py b/sgkit/stats/hwe.py index a6506bea0..0f4b52597 100644 --- a/sgkit/stats/hwe.py +++ b/sgkit/stats/hwe.py @@ -9,16 +9,16 @@ def hardy_weinberg_p_value(obs_hets: int, obs_hom1: int, obs_hom2: int) -> float: - """Exact test for HWE as described in Wigginton et al. 2005 [1] + """Exact test for HWE as described in Wigginton et al. 2005 [1]. Parameters ---------- obs_hets : int - Number of heterozygotes with minor variant + Number of heterozygotes with minor variant. obs_hom1 : int - Number of reference/major homozygotes + Number of reference/major homozygotes. obs_hom2 : int - Number of alternate/minor homozygotes + Number of alternate/minor homozygotes. Returns ------- @@ -34,7 +34,7 @@ def hardy_weinberg_p_value(obs_hets: int, obs_hom1: int, obs_hom2: int) -> float Raises ------ ValueError - If any observed counts are negative + If any observed counts are negative. """ if obs_hom1 < 0 or obs_hom2 < 0 or obs_hets < 0: raise ValueError("Observed genotype counts must be positive") @@ -122,14 +122,51 @@ def hardy_weinberg_p_value_vec( def hardy_weinberg_test( ds: Dataset, genotype_counts: Optional[Hashable] = None ) -> Dataset: + """Exact test for HWE as described in Wigginton et al. 2005 [1]. + + Parameters + ---------- + ds : Dataset + Dataset containing genotype calls or precomputed genotype counts. + genotype_counts : Optional[Hashable], optional + Name of variable containing precomputed genotype counts, by default + None. If not provided, these counts will be computed automatically + from genotype calls. If present, must correspond to an (`N`, 3) array + where `N` is equal to the number of variants and the 3 columns contain + heterozygous, homozygous reference, and homozygous alternate counts + (in that order) across all samples for a variant. + + Warnings + -------- + This function is only applicable to diploid, biallelic datasets. + + Returns + ------- + Dataset + Dataset containing (N = num variants): + variant/hwe_p_value : (N,) ArrayLike + P values from HWE test for each variant as float in [0, 1]. + + References + ---------- + - [1] Wigginton, Janis E., David J. Cutler, and Goncalo R. Abecasis. 2005. + “A Note on Exact Tests of Hardy-Weinberg Equilibrium.” American Journal of + Human Genetics 76 (5): 887–93. + + Raises + ------ + NotImplementedError + * If ploidy of provided dataset != 2 + * If maximum number of alleles in provided dataset != 2 + """ if ds.dims["ploidy"] != 2: raise NotImplementedError("HWE test only implemented for diploid genotypes") if ds.dims["alleles"] != 2: raise NotImplementedError("HWE test only implemented for biallelic genotypes") - # Use precomputed genotype counts, if provided + # Use precomputed genotype counts if provided if genotype_counts is not None: obs = list(da.asarray(ds[genotype_counts]).T) - # Otherwise, compute genotype counts from calls + # Otherwise compute genotype counts from calls else: # TODO: Use API genotype counting function instead, e.g. # https://github.com/pystatgen/sgkit/issues/29#issuecomment-656691069 From 1ef1072ec39260cb13944b407376b3c0302f976b Mon Sep 17 00:00:00 2001 From: eric-czech Date: Fri, 7 Aug 2020 12:12:16 -0400 Subject: [PATCH 8/9] Fix typo in test name --- sgkit/tests/test_hwe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sgkit/tests/test_hwe.py b/sgkit/tests/test_hwe.py index b92685e6b..0b43ad75d 100644 --- a/sgkit/tests/test_hwe.py +++ b/sgkit/tests/test_hwe.py @@ -152,7 +152,7 @@ def test_hwep_dataset__raise_on_nondiploid(): hwep_test(ds) -def test_hwep_dataset__raise_on_biallelic(): +def test_hwep_dataset__raise_on_nonbiallelic(): with pytest.raises( NotImplementedError, match="HWE test only implemented for biallelic genotypes" ): From 835e745c81ac24c7823eef40e6c64c0f9ef91e71 Mon Sep 17 00:00:00 2001 From: eric-czech Date: Fri, 7 Aug 2020 12:21:36 -0400 Subject: [PATCH 9/9] Update variable names for new convention --- sgkit/stats/hwe.py | 8 ++++---- sgkit/tests/test_hwe.py | 14 +++++++------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/sgkit/stats/hwe.py b/sgkit/stats/hwe.py index 0f4b52597..607cb389c 100644 --- a/sgkit/stats/hwe.py +++ b/sgkit/stats/hwe.py @@ -144,7 +144,7 @@ def hardy_weinberg_test( ------- Dataset Dataset containing (N = num variants): - variant/hwe_p_value : (N,) ArrayLike + variant_hwe_p_value : (N,) ArrayLike P values from HWE test for each variant as float in [0, 1]. References @@ -170,9 +170,9 @@ def hardy_weinberg_test( else: # TODO: Use API genotype counting function instead, e.g. # https://github.com/pystatgen/sgkit/issues/29#issuecomment-656691069 - mask = ds["call/genotype_mask"].any(dim="ploidy") - gtc = xr.where(mask, -1, ds["call/genotype"].sum(dim="ploidy")) # type: ignore[no-untyped-call] + mask = ds["call_genotype_mask"].any(dim="ploidy") + gtc = xr.where(mask, -1, ds["call_genotype"].sum(dim="ploidy")) # type: ignore[no-untyped-call] cts = [1, 0, 2] # arg order: hets, hom1, hom2 obs = [da.asarray((gtc == ct).sum(dim="samples")) for ct in cts] p = da.map_blocks(hardy_weinberg_p_value_vec_jit, *obs) - return xr.Dataset({"variant/hwe_p_value": ("variants", p)}) + return xr.Dataset({"variant_hwe_p_value": ("variants", p)}) diff --git a/sgkit/tests/test_hwe.py b/sgkit/tests/test_hwe.py index 0b43ad75d..6c3a66cb0 100644 --- a/sgkit/tests/test_hwe.py +++ b/sgkit/tests/test_hwe.py @@ -107,7 +107,7 @@ def ds_eq(): """Dataset with all variants near HWE""" ds = simulate_genotype_call_dataset(n_variant=50, n_sample=1000) gt_dist = (0.25, 0.5, 0.25) - ds["call/genotype"] = simulate_genotype_calls( + ds["call_genotype"] = simulate_genotype_calls( ds.dims["variants"], ds.dims["samples"], p=gt_dist ) return ds @@ -118,29 +118,29 @@ def ds_neq(): """Dataset with all variants well out of HWE""" ds = simulate_genotype_call_dataset(n_variant=50, n_sample=1000) gt_dist = (0.9, 0.05, 0.05) - ds["call/genotype"] = simulate_genotype_calls( + ds["call_genotype"] = simulate_genotype_calls( ds.dims["variants"], ds.dims["samples"], p=gt_dist ) return ds def test_hwep_dataset__in_eq(ds_eq: Dataset) -> None: - p = hwep_test(ds_eq)["variant/hwe_p_value"].values + p = hwep_test(ds_eq)["variant_hwe_p_value"].values assert np.all(p > 1e-8) def test_hwep_dataset__out_of_eq(ds_neq: Dataset) -> None: - p = hwep_test(ds_neq)["variant/hwe_p_value"].values + p = hwep_test(ds_neq)["variant_hwe_p_value"].values assert np.all(p < 1e-8) def test_hwep_dataset__precomputed_counts(ds_neq: Dataset) -> None: ds = ds_neq - ac = ds["call/genotype"].sum(dim="ploidy") + ac = ds["call_genotype"].sum(dim="ploidy") cts = [1, 0, 2] # arg order: hets, hom1, hom2 gtc = xr.concat([(ac == ct).sum(dim="samples") for ct in cts], dim="counts").T # type: ignore[no-untyped-call] - ds = ds.assign(**{"variant/genotype_counts": gtc}) - p = hwep_test(ds, genotype_counts="variant/genotype_counts") + ds = ds.assign(**{"variant_genotype_counts": gtc}) + p = hwep_test(ds, genotype_counts="variant_genotype_counts") assert np.all(p < 1e-8)