Skip to content

Commit 4e7cc1f

Browse files
tomwhite and mergify[bot]
authored and committed
Ignore samples that are not in any cohort #400
1 parent 4e117c9 commit 4e7cc1f

File tree

2 files changed

+9
-6
lines changed

2 files changed

+9
-6
lines changed

sgkit/stats/aggregation.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,9 @@ def _count_cohort_alleles(
8888
n_samples, n_alleles = ac.shape
8989
for i in range(n_samples):
9090
for j in range(n_alleles):
91-
out[cohorts[i], j] += ac[i, j]
91+
c = cohorts[i]
92+
if c >= 0:
93+
out[c, j] += ac[i, j]
9294

9395

9496
def count_call_alleles(

sgkit/tests/test_aggregation.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -216,13 +216,14 @@ def test_count_call_alleles__chunked():
216216
def test_count_cohort_alleles__multi_variant_multi_sample():
217217
ds = get_dataset(
218218
[
219-
[[0, 0], [0, 0], [0, 0]],
220-
[[0, 0], [0, 0], [0, 1]],
221-
[[1, 1], [0, 1], [1, 0]],
222-
[[1, 1], [1, 1], [1, 1]],
219+
[[0, 0], [0, 0], [0, 0], [0, 0]],
220+
[[0, 0], [0, 0], [0, 1], [0, 1]],
221+
[[1, 1], [0, 1], [1, 0], [1, 0]],
222+
[[1, 1], [1, 1], [1, 1], [1, 1]],
223223
]
224224
)
225-
ds["sample_cohort"] = xr.DataArray(np.array([0, 1, 1]), dims="samples")
225+
# -1 means that the sample is not in any cohort
226+
ds["sample_cohort"] = xr.DataArray(np.array([0, 1, 1, -1]), dims="samples")
226227
ds = count_cohort_alleles(ds)
227228
ac = ds.cohort_allele_count
228229
np.testing.assert_equal(

0 commit comments

Comments
 (0)