Skip to content

Commit 6f7887a

Browse files
committed
Use sgkit.distarray for count_variant_alleles and variant_stats
Get count_hom working with explicit sum reduction over samples
1 parent 8c3d2c3 commit 6f7887a

File tree

3 files changed

+23
-17
lines changed

3 files changed

+23
-17
lines changed

.github/workflows/cubed.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -30,4 +30,4 @@ jobs:
3030
3131
- name: Test with pytest
3232
run: |
33-
pytest -v sgkit/tests/test_aggregation.py -k "test_count_call_alleles" --use-cubed
33+
pytest -v sgkit/tests/test_aggregation.py -k 'test_count_call_alleles or (test_count_variant_alleles and not test_count_variant_alleles__chunked[call_genotype]) or (test_variant_stats and not test_variant_stats__chunks[chunks2-False])' --use-cubed

sgkit/stats/aggregation.py

+16-11
Original file line numberDiff line numberDiff line change
@@ -183,11 +183,14 @@ def count_variant_alleles(
183183
variables.validate(ds, {call_genotype: variables.call_genotype_spec})
184184
n_alleles = ds.sizes["alleles"]
185185
n_variant = ds.sizes["variants"]
186-
G = da.asarray(ds[call_genotype]).reshape((n_variant, -1))
186+
G = da.asarray(ds[call_genotype])
187+
G = da.reshape(G, (n_variant, -1))
187188
shape = (G.chunks[0], n_alleles)
188189
# use uint64 dummy array to return uin64 counts array
189190
N = np.empty(n_alleles, dtype=np.uint64)
190-
AC = da.map_blocks(count_alleles, G, N, chunks=shape, drop_axis=1, new_axis=1)
191+
AC = da.map_blocks(
192+
count_alleles, G, N, chunks=shape, dtype=np.uint64, drop_axis=1, new_axis=1
193+
)
191194
AC = xr.DataArray(AC, dims=["variants", "alleles"])
192195
else:
193196
options = {variables.call_genotype, variables.call_allele_count}
@@ -692,22 +695,23 @@ def variant_stats(
692695
using=variables.call_genotype, # improved performance
693696
merge=False,
694697
)[variant_allele_count]
695-
G = da.array(ds[call_genotype].data)
698+
G = da.asarray(ds[call_genotype].data)
696699
H = xr.DataArray(
697700
da.map_blocks(
698-
count_hom,
701+
lambda *args: count_hom(*args)[:, np.newaxis, :],
699702
G,
700703
np.zeros(3, np.uint64),
701-
drop_axis=(1, 2),
702-
new_axis=1,
704+
drop_axis=2,
705+
new_axis=2,
703706
dtype=np.int64,
704-
chunks=(G.chunks[0], 3),
707+
chunks=(G.chunks[0], 1, 3),
705708
),
706-
dims=["variants", "categories"],
709+
dims=["variants", "samples", "categories"],
707710
)
711+
H = H.sum(axis=1)
708712
_, n_sample, _ = G.shape
709713
n_called = H.sum(axis=-1)
710-
call_rate = n_called / n_sample
714+
call_rate = n_called.astype(float) / float(n_sample)
711715
n_hom_ref = H[:, 0]
712716
n_hom_alt = H[:, 1]
713717
n_het = H[:, 2]
@@ -723,7 +727,8 @@ def variant_stats(
723727
variables.variant_n_non_ref: n_non_ref,
724728
variables.variant_allele_count: AC,
725729
variables.variant_allele_total: allele_total,
726-
variables.variant_allele_frequency: AC / allele_total,
730+
variables.variant_allele_frequency: AC.astype(float)
731+
/ allele_total.astype(float),
727732
}
728733
)
729734
# for backwards compatible behavior
@@ -798,7 +803,7 @@ def sample_stats(
798803
mixed_ploidy = ds[call_genotype].attrs.get("mixed_ploidy", False)
799804
if mixed_ploidy:
800805
raise ValueError("Mixed-ploidy dataset")
801-
G = da.array(ds[call_genotype].data)
806+
G = da.asarray(ds[call_genotype].data)
802807
H = xr.DataArray(
803808
da.map_blocks(
804809
count_hom,

sgkit/tests/test_aggregation.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ def test_count_variant_alleles__chunked(using):
144144
chunks={"variants": 5, "samples": 5}
145145
)
146146
ac2 = count_variant_alleles(ds, using=using)
147-
assert isinstance(ac2["variant_allele_count"].data, da.Array)
147+
assert hasattr(ac2["variant_allele_count"].data, "chunks")
148148
xr.testing.assert_equal(ac1, ac2)
149149

150150

@@ -786,13 +786,14 @@ def test_variant_stats__tetraploid():
786786
)
787787

788788

789-
@pytest.mark.parametrize(
790-
"chunks", [(-1, -1, -1), (100, -1, -1), (100, 10, -1), (100, 10, 1)]
791-
)
792-
def test_variant_stats__chunks(chunks):
789+
@pytest.mark.parametrize("precompute_variant_allele_count", [False, True])
790+
@pytest.mark.parametrize("chunks", [(-1, -1, -1), (100, -1, -1), (100, 10, -1)])
791+
def test_variant_stats__chunks(precompute_variant_allele_count, chunks):
793792
ds = simulate_genotype_call_dataset(
794793
n_variant=1000, n_sample=30, missing_pct=0.01, seed=0
795794
)
795+
if precompute_variant_allele_count:
796+
ds = count_variant_alleles(ds)
796797
expect = variant_stats(ds, merge=False).compute()
797798
ds["call_genotype"] = ds["call_genotype"].chunk(chunks)
798799
actual = variant_stats(ds, merge=False).compute()

0 commit comments

Comments
 (0)