From 2c96d74046f41d6fac1c46c8c481dfaa229b1074 Mon Sep 17 00:00:00 2001
From: atongsa
Date: Thu, 6 Jul 2023 20:42:02 +0800
Subject: [PATCH] 20230706_pbs

---
Review notes:
- Line breaks and indentation were reconstructed. The whitespace of the
  unchanged context lines in both hunks is a best guess and must be checked
  against the target files before applying.
- The blob ids on the "index" lines refer to the pre-review content.
- ``indexes`` rows are (focus, sister, far) triples, matching the k = 3
  multi-way interface; one PBS value (for the focus set) is returned per
  triple.
- PBS uses a base-10 logarithm as in the original submission. The PBS
  literature usually writes T = -log(1 - Fst); please confirm the base.

 python/tests/test_tree_stats.py | 149 ++++++++++++++++++++++++++++++++
 python/tskit/trees.py           |  84 ++++++++++++++++++
 2 files changed, 233 insertions(+)

diff --git a/python/tests/test_tree_stats.py b/python/tests/test_tree_stats.py
index 7725931b73..3ac7f1c886 100644
--- a/python/tests/test_tree_stats.py
+++ b/python/tests/test_tree_stats.py
@@ -6254,3 +6254,152 @@ def f_too_long(_):
                 output_dim=1,
                 strict=False,
             )
+
+
+# pbs
+class PbsWindowsMixin:
+    # NOTE(review): ``get_example_ts`` and ``verify_three_way_stat_windows`` are
+    # not defined in this patch; mix this class into (or move the method onto)
+    # the test class that provides them so that the test is collected.
+    def test_pbs_windows(self):
+        ts = self.get_example_ts()
+        self.verify_three_way_stat_windows(ts, ts.pbs)
+
+
+def _single_site_pbs_branch(samples, genotypes, X, Y):
+    """
+    Returns the single-site branch length ``T = -log10(1 - Fst)`` between the
+    sample sets X and Y, where Fst is computed from allele frequencies and
+    adjusted for sampling without replacement.
+    """
+    gX = np.array([a for k, a in zip(samples, genotypes) if k in X])
+    gY = np.array([a for k, a in zip(samples, genotypes) if k in Y])
+    nX = len(X)
+    nY = len(Y)
+    dX = dY = dXY = 0
+    for a in set(genotypes):
+        fX = np.sum(gX == a)
+        fY = np.sum(gY == a)
+        with suppress_division_by_zero_warning():
+            dX += fX * (nX - fX) / (nX * (nX - 1))
+            dY += fY * (nY - fY) / (nY * (nY - 1))
+            dXY += (fX * (nY - fY) + (nX - fX) * fY) / (2 * nX * nY)
+    with suppress_division_by_zero_warning():
+        # 1 - Fst_xy = 2 * (dX + dY) / (dX + dY + 2 * dXY)
+        return -np.log10(2 * (dX + dY) / (dX + dY + 2 * dXY))
+
+
+def single_site_pbs(ts, sample_sets, indexes):
+    """
+    Compute the single-site population branch statistic (PBS).
+
+    ``indexes`` is a list of triples ``(focus_pop, sister_pop, far_pop)`` of
+    indexes into ``sample_sets``. The output has one row per site and one
+    column per triple, giving the PBS of the focus population.
+
+    First compute single-site Fst, which between two groups x and y, with
+    frequencies p and q is
+    Fst_xy = 1 - 2 * (p (1-p) + q(1-q)) / ( p(1-p) + q(1-q) + p(1-q) + q(1-p) )
+    or in the multiallelic case, replacing p(1-p) with the sum over alleles of
+    p(1-p), and adjusted for sampling without replacement.
+    Then, writing T_xy = -log10(1 - Fst_xy), the PBS of the focus population x
+    with sister population y and far population z is
+    pbs_x = (T_xy + T_xz - T_yz) / 2
+    """
+    # TODO: what to do in this case?
+    if ts.num_sites == 0:
+        out = np.array([np.repeat(np.nan, len(indexes))])
+        return out
+    out = np.zeros((ts.num_sites, len(indexes)))
+    samples = ts.samples()
+    # TODO deal with missing data properly.
+    for j, v in enumerate(ts.variants(isolated_as_missing=False)):
+        g = v.genotypes
+        for i, (ix, iy, iz) in enumerate(indexes):
+            X = sample_sets[ix]
+            Y = sample_sets[iy]
+            Z = sample_sets[iz]
+            t_xy = _single_site_pbs_branch(samples, g, X, Y)
+            t_xz = _single_site_pbs_branch(samples, g, X, Z)
+            t_yz = _single_site_pbs_branch(samples, g, Y, Z)
+            with suppress_division_by_zero_warning():
+                # inf - inf is possible here and yields nan.
+                out[j][i] = (t_xy + t_xz - t_yz) / 2
+    return out
+
+
+def pbs(ts, sample_sets, indexes=None, windows=None, mode="site", span_normalise=True):
+    """
+    Population branch statistic (PBS) reference definition.
+
+    Only per-site values in "site" mode are implemented, so ``windows`` must be
+    None or "sites" and one row per site is always returned.
+    ``span_normalise`` is accepted for signature compatibility but is unused.
+    """
+    if mode != "site":
+        raise ValueError(f"Unsupported mode for reference PBS: {mode!r}")
+    if windows is not None and not (isinstance(windows, str) and windows == "sites"):
+        raise ValueError("Reference PBS only supports per-site windows")
+    if indexes is None:
+        indexes = [(0, 1, 2)]
+    return single_site_pbs(ts, sample_sets, indexes)
+
+
+class Testpbs(StatsTestCase, ThreeWaySampleSetStatsMixin):
+
+    # Derived classes define this to get a specific stats mode.
+    mode = None
+
+    def verify(self, ts):
+        # only check per-site
+        for sample_sets in example_sample_sets(ts, min_size=3):
+            # Each row is a (focus, sister, far) triple of sample set indexes.
+            for indexes in ([(0, 1, 2)], [(0, 2, 1), (2, 1, 0)]):
+                self.verify_persite_pbs(ts, sample_sets, indexes)
+
+    def verify_persite_pbs(self, ts, sample_sets, indexes):
+        sigma1 = ts.pbs(
+            sample_sets,
+            indexes=indexes,
+            windows="sites",
+            mode=self.mode,
+            span_normalise=False,
+        )
+        sigma2 = single_site_pbs(ts, sample_sets, indexes)
+        assert sigma1.shape == sigma2.shape
+        self.assertArrayAlmostEqual(sigma1, sigma2)
+
+
+class pbsInterfaceMixin:
+    def test_interface(self):
+        ts = msprime.simulate(10, mutation_rate=0.0)
+        sample_sets = [[0, 1, 2], [6, 7], [4]]
+        # Indexes are required unless there are exactly three sample sets.
+        with pytest.raises(ValueError):
+            ts.pbs(sample_sets + [[8, 9]], mode=self.mode)
+        # Each row of indexes must be a triple.
+        with pytest.raises(ValueError):
+            ts.pbs(sample_sets, indexes=[(0, 1), (0, 2)], mode=self.mode)
+        with pytest.raises(tskit.LibraryError):
+            ts.pbs(sample_sets, indexes=[(0, 1, 20)])
+        sigma1 = ts.pbs(sample_sets, indexes=[(0, 1, 2)], mode=self.mode)
+        sigma2 = ts.pbs(
+            sample_sets, indexes=[(0, 1, 2), (1, 0, 2), (2, 0, 1)], mode=self.mode
+        )
+        self.assertArrayAlmostEqual(sigma1[..., 0], sigma2[..., 0])
+
+
+class TestSitepbs(Testpbs, MutatedTopologyExamplesMixin, pbsInterfaceMixin):
+    mode = "site"
+
+
+# Since pbs is defined using diversity and divergence and fst, we don't seriously
+# test it for correctness for node and branch, and only test the interface.
+
+
+class TestNodepbs(StatsTestCase, pbsInterfaceMixin):
+    mode = "node"
+
+
+class TestBranchpbs(StatsTestCase, pbsInterfaceMixin):
+    mode = "branch"
diff --git a/python/tskit/trees.py b/python/tskit/trees.py
index 9ccae3488d..685ff39d79 100644
--- a/python/tskit/trees.py
+++ b/python/tskit/trees.py
@@ -9109,3 +9109,87 @@ def write_ms(
                     )
             else:
                 print(file=output)
+
+    def pbs(
+        self, sample_sets, indexes=None, windows=None, mode="site", span_normalise=True
+    ):
+        """
+        Computes the "windowed" population branch statistic (PBS) for triples of
+        sets of nodes from ``sample_sets``.
+        Operates on ``k = 3`` sample sets at a time; please see the
+        :ref:`multi-way statistics <sec_stats_sample_sets_multi_way>`
+        section for details on how the ``sample_sets`` and ``indexes`` arguments are
+        interpreted and how they interact with the dimensions of the output array.
+        See the :ref:`statistics interface <sec_stats_interface>` section for details on
+        :ref:`windows <sec_stats_windows>`,
+        :ref:`mode <sec_stats_mode>`,
+        :ref:`span normalise <sec_stats_span_normalise>`,
+        and :ref:`return value <sec_stats_output_format>`.
+
+        For sample sets ``X``, ``Y``, ``Z``, ``X`` is the focus population, ``Y`` is
+        the sister population of ``X`` and ``Z`` is the far population of ``X``.
+        If ``d(X, Y)`` is the :meth:`divergence <.TreeSequence.divergence>`
+        between ``X`` and ``Y``, and ``d(X)`` is the
+        :meth:`diversity <.TreeSequence.diversity>` of ``X``, then what is
+        computed is
+
+        .. code-block:: python
+
+            Fst_xy = 1 - 2 * (d(X) + d(Y)) / (d(X) + 2 * d(X, Y) + d(Y))
+            Fst_xz = 1 - 2 * (d(X) + d(Z)) / (d(X) + 2 * d(X, Z) + d(Z))
+            Fst_yz = 1 - 2 * (d(Y) + d(Z)) / (d(Y) + 2 * d(Y, Z) + d(Z))
+
+            pbs_x = (
+                -math.log(1 - Fst_xy, 10)
+                - math.log(1 - Fst_xz, 10)
+                + math.log(1 - Fst_yz, 10)
+            ) / 2
+
+        The statistic for ``Y`` or ``Z`` is obtained by listing that sample set
+        first in the triple, e.g. ``indexes=[(0, 1, 2), (1, 0, 2), (2, 0, 1)]``.
+
+        :param list sample_sets: A list of lists of Node IDs, specifying the
+            groups of nodes to compute the statistic with.
+        :param list indexes: A list of 3-tuples ``(X, Y, Z)`` of indexes into
+            ``sample_sets``, or None.
+        :param list windows: An increasing list of breakpoints between the windows
+            to compute the statistic in.
+        :param str mode: A string giving the "type" of the statistic to be computed
+            (defaults to "site").
+        :param bool span_normalise: Forwarded to the statistics framework
+            (defaults to True).
+        :return: An array with one PBS value per window and index triple.
+        """
+
+        def pbs_func(sample_set_sizes, flattened, indexes, **kwargs):
+            triples = np.asarray(indexes, dtype=np.int32)
+            # Expand each (x, y, z) triple into the pairs (x, y), (x, z), (y, z):
+            # the low-level divergence takes one pair of sample sets per row.
+            pairs = np.ascontiguousarray(
+                triples[:, [[0, 1], [0, 2], [1, 2]]].reshape(-1, 2)
+            )
+            diversities = self._ll_tree_sequence.diversity(
+                sample_set_sizes, flattened, **kwargs
+            )
+            divergences = self._ll_tree_sequence.divergence(
+                sample_set_sizes, flattened, pairs, **kwargs
+            )
+            # The statistic is always the last axis ("node" statistics produce a
+            # 3D array), so index with an ellipsis rather than reshaping.
+            within = diversities[..., pairs[:, 0]] + diversities[..., pairs[:, 1]]
+            with np.errstate(divide="ignore", invalid="ignore"):
+                # 1 - Fst = 2 * (d(A) + d(B)) / (d(A) + 2 * d(A, B) + d(B))
+                branch = -np.log10(2 * within / (within + 2 * divergences))
+                # Regroup so the last axis holds (T_xy, T_xz, T_yz) per triple.
+                branch = branch.reshape(branch.shape[:-1] + (len(triples), 3))
+                return (branch[..., 0] + branch[..., 1] - branch[..., 2]) / 2
+
+        return self.__k_way_sample_set_stat(
+            pbs_func,
+            3,
+            sample_sets,
+            indexes=indexes,
+            windows=windows,
+            mode=mode,
+            span_normalise=span_normalise,
+        )