Skip to content

Commit 718fe3b

Browse files
tomwhite authored and mergify[bot] committed
Issue a warning if the number of alt alleles exceeds the maximum specified
1 parent: 5f43ecc; commit: 718fe3b

File tree

5 files changed

+92
-19
lines changed

5 files changed

+92
-19
lines changed

sgkit/io/utils.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,11 @@ def zarrs_to_dataset(
100100
ds[variable_name] = ds[variable_name].astype(f"S{max_length}")
101101
del ds.attrs[attr]
102102

103+
if "max_alt_alleles_seen" in datasets[0].attrs:
104+
ds.attrs["max_alt_alleles_seen"] = max(
105+
ds.attrs["max_alt_alleles_seen"] for ds in datasets
106+
)
107+
103108
return ds
104109

105110

sgkit/io/vcf/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,10 @@
33
try:
44
from ..utils import zarrs_to_dataset
55
from .vcf_partition import partition_into_regions
6-
from .vcf_reader import vcf_to_zarr, vcf_to_zarrs
6+
from .vcf_reader import MaxAltAllelesExceededWarning, vcf_to_zarr, vcf_to_zarrs
77

88
__all__ = [
9+
"MaxAltAllelesExceededWarning",
910
"partition_into_regions",
1011
"vcf_to_zarr",
1112
"vcf_to_zarrs",

sgkit/io/vcf/vcf_reader.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from numcodecs import PackBits
2525

2626
from sgkit import variables
27+
from sgkit.io.dataset import load_dataset
2728
from sgkit.io.utils import zarrs_to_dataset
2829
from sgkit.io.vcf import partition_into_regions
2930
from sgkit.io.vcf.utils import build_url, chunks, temporary_directory, url_filename
@@ -50,6 +51,12 @@
5051
DEFAULT_COMPRESSOR = None
5152

5253

54+
class MaxAltAllelesExceededWarning(UserWarning):
55+
"""Warning when the number of alt alleles exceeds the maximum specified."""
56+
57+
pass
58+
59+
5360
@contextmanager
5461
def open_vcf(path: PathType) -> Iterator[VCF]:
5562
"""A context manager for opening a VCF file."""
@@ -362,6 +369,7 @@ def vcf_to_zarr_sequential(
362369
# Remember max lengths of variable-length strings
363370
max_variant_id_length = 0
364371
max_variant_allele_length = 0
372+
max_alt_alleles_seen = 0
365373

366374
# Iterate through variants in batches of chunk_length
367375

@@ -413,6 +421,7 @@ def vcf_to_zarr_sequential(
413421
variant_position[i] = variant.POS
414422

415423
alleles = [variant.REF] + variant.ALT
424+
max_alt_alleles_seen = max(max_alt_alleles_seen, len(variant.ALT))
416425
if len(alleles) > n_allele:
417426
alleles = alleles[:n_allele]
418427
elif len(alleles) < n_allele:
@@ -457,6 +466,7 @@ def vcf_to_zarr_sequential(
457466
if add_str_max_length_attrs:
458467
ds.attrs["max_length_variant_id"] = max_variant_id_length
459468
ds.attrs["max_length_variant_allele"] = max_variant_allele_length
469+
ds.attrs["max_alt_alleles_seen"] = max_alt_alleles_seen
460470

461471
if first_variants_chunk:
462472
# Enforce uniform chunks in the variants dimension
@@ -839,6 +849,15 @@ def vcf_to_zarr(
839849
field_defs=field_defs,
840850
)
841851

852+
# Issue a warning if max_alt_alleles caused data to be dropped
853+
ds = load_dataset(output)
854+
max_alt_alleles_seen = ds.attrs["max_alt_alleles_seen"]
855+
if max_alt_alleles_seen > max_alt_alleles:
856+
warnings.warn(
857+
f"Some alternate alleles were dropped, since actual max value {max_alt_alleles_seen} exceeded max_alt_alleles setting of {max_alt_alleles}.",
858+
MaxAltAllelesExceededWarning,
859+
)
860+
842861

843862
def count_variants(path: PathType, region: Optional[str] = None) -> int:
844863
"""Count the number of variants in a VCF file."""

sgkit/tests/io/vcf/test_vcf_reader.py

Lines changed: 62 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,11 @@
88
from numpy.testing import assert_allclose, assert_array_equal
99

1010
from sgkit import load_dataset
11-
from sgkit.io.vcf import partition_into_regions, vcf_to_zarr
11+
from sgkit.io.vcf import (
12+
MaxAltAllelesExceededWarning,
13+
partition_into_regions,
14+
vcf_to_zarr,
15+
)
1216

1317
from .utils import path_for_test
1418

@@ -98,30 +102,35 @@ def test_vcf_to_zarr__max_alt_alleles(shared_datadir, is_path, tmp_path):
98102
path = path_for_test(shared_datadir, "sample.vcf.gz", is_path)
99103
output = tmp_path.joinpath("vcf.zarr").as_posix()
100104

101-
vcf_to_zarr(path, output, chunk_length=5, chunk_width=2, max_alt_alleles=1)
102-
ds = xr.open_zarr(output)
105+
with pytest.warns(MaxAltAllelesExceededWarning):
106+
vcf_to_zarr(path, output, chunk_length=5, chunk_width=2, max_alt_alleles=1)
107+
ds = xr.open_zarr(output)
103108

104-
# extra alt alleles are silently dropped
105-
assert_array_equal(
106-
ds["variant_allele"],
107-
[
108-
["A", "C"],
109-
["A", "G"],
110-
["G", "A"],
111-
["T", "A"],
112-
["A", "G"],
113-
["T", ""],
114-
["G", "GA"],
115-
["T", ""],
116-
["AC", "A"],
117-
],
118-
)
109+
# extra alt alleles are dropped
110+
assert_array_equal(
111+
ds["variant_allele"],
112+
[
113+
["A", "C"],
114+
["A", "G"],
115+
["G", "A"],
116+
["T", "A"],
117+
["A", "G"],
118+
["T", ""],
119+
["G", "GA"],
120+
["T", ""],
121+
["AC", "A"],
122+
],
123+
)
124+
125+
# the maximum number of alt alleles actually seen is stored as an attribute
126+
assert ds.attrs["max_alt_alleles_seen"] == 3
119127

120128

121129
@pytest.mark.parametrize(
122130
"is_path",
123131
[True, False],
124132
)
133+
@pytest.mark.filterwarnings("ignore::sgkit.io.vcf.MaxAltAllelesExceededWarning")
125134
def test_vcf_to_zarr__large_vcf(shared_datadir, is_path, tmp_path):
126135
path = path_for_test(shared_datadir, "CEUTrio.20.21.gatk3.4.g.vcf.bgz", is_path)
127136
output = tmp_path.joinpath("vcf.zarr").as_posix()
@@ -159,6 +168,7 @@ def test_vcf_to_zarr__plain_vcf_with_no_index(shared_datadir, tmp_path):
159168
"is_path",
160169
[True, False],
161170
)
171+
@pytest.mark.filterwarnings("ignore::sgkit.io.vcf.MaxAltAllelesExceededWarning")
162172
def test_vcf_to_zarr__mutable_mapping(shared_datadir, is_path):
163173
path = path_for_test(shared_datadir, "CEUTrio.20.21.gatk3.4.g.vcf.bgz", is_path)
164174
output: MutableMapping[str, bytes] = {}
@@ -217,6 +227,7 @@ def test_vcf_to_zarr__compressor_and_filters(shared_datadir, is_path, tmp_path):
217227
"is_path",
218228
[True, False],
219229
)
230+
@pytest.mark.filterwarnings("ignore::sgkit.io.vcf.MaxAltAllelesExceededWarning")
220231
def test_vcf_to_zarr__parallel(shared_datadir, is_path, tmp_path):
221232
path = path_for_test(shared_datadir, "CEUTrio.20.21.gatk3.4.g.vcf.bgz", is_path)
222233
output = tmp_path.joinpath("vcf_concat.zarr").as_posix()
@@ -266,6 +277,7 @@ def test_vcf_to_zarr__empty_region(shared_datadir, is_path, tmp_path):
266277
"is_path",
267278
[False],
268279
)
280+
@pytest.mark.filterwarnings("ignore::sgkit.io.vcf.MaxAltAllelesExceededWarning")
269281
def test_vcf_to_zarr__parallel_temp_chunk_length(shared_datadir, is_path, tmp_path):
270282
path = path_for_test(shared_datadir, "CEUTrio.20.21.gatk3.4.g.vcf.bgz", is_path)
271283
output = tmp_path.joinpath("vcf_concat.zarr").as_posix()
@@ -354,6 +366,7 @@ def test_vcf_to_zarr__parallel_partitioned_by_size(shared_datadir, is_path, tmp_
354366
"is_path",
355367
[True, False],
356368
)
369+
@pytest.mark.filterwarnings("ignore::sgkit.io.vcf.MaxAltAllelesExceededWarning")
357370
def test_vcf_to_zarr__multiple(shared_datadir, is_path, tmp_path):
358371
paths = [
359372
path_for_test(shared_datadir, "CEUTrio.20.gatk3.4.g.vcf.bgz", is_path),
@@ -381,6 +394,7 @@ def test_vcf_to_zarr__multiple(shared_datadir, is_path, tmp_path):
381394
"is_path",
382395
[True, False],
383396
)
397+
@pytest.mark.filterwarnings("ignore::sgkit.io.vcf.MaxAltAllelesExceededWarning")
384398
def test_vcf_to_zarr__multiple_partitioned(shared_datadir, is_path, tmp_path):
385399
paths = [
386400
path_for_test(shared_datadir, "CEUTrio.20.gatk3.4.g.vcf.bgz", is_path),
@@ -410,6 +424,7 @@ def test_vcf_to_zarr__multiple_partitioned(shared_datadir, is_path, tmp_path):
410424
"is_path",
411425
[True, False],
412426
)
427+
@pytest.mark.filterwarnings("ignore::sgkit.io.vcf.MaxAltAllelesExceededWarning")
413428
def test_vcf_to_zarr__multiple_partitioned_by_size(shared_datadir, is_path, tmp_path):
414429
paths = [
415430
path_for_test(shared_datadir, "CEUTrio.20.gatk3.4.g.vcf.bgz", is_path),
@@ -456,6 +471,31 @@ def test_vcf_to_zarr__mutiple_partitioned_invalid_regions(
456471
vcf_to_zarr(paths, output, regions=regions, chunk_length=5_000)
457472

458473

474+
@pytest.mark.parametrize(
475+
"is_path",
476+
[True, False],
477+
)
478+
def test_vcf_to_zarr__multiple_max_alt_alleles(shared_datadir, is_path, tmp_path):
479+
paths = [
480+
path_for_test(shared_datadir, "CEUTrio.20.gatk3.4.g.vcf.bgz", is_path),
481+
path_for_test(shared_datadir, "CEUTrio.21.gatk3.4.g.vcf.bgz", is_path),
482+
]
483+
output = tmp_path.joinpath("vcf_concat.zarr").as_posix()
484+
485+
with pytest.warns(MaxAltAllelesExceededWarning):
486+
vcf_to_zarr(
487+
paths,
488+
output,
489+
target_part_size="40KB",
490+
chunk_length=5_000,
491+
max_alt_alleles=1,
492+
)
493+
ds = xr.open_zarr(output)
494+
495+
# the maximum number of alt alleles actually seen is stored as an attribute
496+
assert ds.attrs["max_alt_alleles_seen"] == 7
497+
498+
459499
@pytest.mark.parametrize(
460500
"ploidy,mixed_ploidy,truncate_calls,regions",
461501
[
@@ -647,6 +687,7 @@ def test_vcf_to_zarr__fields(shared_datadir, tmp_path):
647687
assert ds["call_DP"].attrs["comment"] == "Read Depth"
648688

649689

690+
@pytest.mark.filterwarnings("ignore::sgkit.io.vcf.MaxAltAllelesExceededWarning")
650691
def test_vcf_to_zarr__parallel_with_fields(shared_datadir, tmp_path):
651692
path = path_for_test(shared_datadir, "CEUTrio.20.21.gatk3.4.g.vcf.bgz")
652693
output = tmp_path.joinpath("vcf.zarr").as_posix()
@@ -703,6 +744,7 @@ def test_vcf_to_zarr__field_defs(shared_datadir, tmp_path):
703744
assert "comment" not in ds["variant_DP"].attrs
704745

705746

747+
@pytest.mark.filterwarnings("ignore::sgkit.io.vcf.MaxAltAllelesExceededWarning")
706748
def test_vcf_to_zarr__field_number_A(shared_datadir, tmp_path):
707749
path = path_for_test(shared_datadir, "sample.vcf.gz")
708750
output = tmp_path.joinpath("vcf.zarr").as_posix()
@@ -736,6 +778,7 @@ def test_vcf_to_zarr__field_number_A(shared_datadir, tmp_path):
736778
)
737779

738780

781+
@pytest.mark.filterwarnings("ignore::sgkit.io.vcf.MaxAltAllelesExceededWarning")
739782
def test_vcf_to_zarr__field_number_R(shared_datadir, tmp_path):
740783
path = path_for_test(shared_datadir, "CEUTrio.21.gatk3.4.g.vcf.bgz")
741784
output = tmp_path.joinpath("vcf.zarr").as_posix()
@@ -768,6 +811,7 @@ def test_vcf_to_zarr__field_number_R(shared_datadir, tmp_path):
768811
)
769812

770813

814+
@pytest.mark.filterwarnings("ignore::sgkit.io.vcf.MaxAltAllelesExceededWarning")
771815
def test_vcf_to_zarr__field_number_G(shared_datadir, tmp_path):
772816
path = path_for_test(shared_datadir, "CEUTrio.21.gatk3.4.g.vcf.bgz")
773817
output = tmp_path.joinpath("vcf.zarr").as_posix()

sgkit/tests/io/vcf/test_vcf_roundtrip.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ def test_default_fields(shared_datadir, tmpdir):
7979
sg_vcfzarr_path = create_sg_vcfzarr(shared_datadir, tmpdir)
8080
sg_ds = sg.load_dataset(str(sg_vcfzarr_path))
8181
sg_ds = sg_ds.drop_vars("call_genotype_phased") # not included in scikit-allel
82+
del sg_ds.attrs["max_alt_alleles_seen"] # not saved by scikit-allel
8283

8384
assert_identical(allel_ds, sg_ds)
8485

@@ -107,6 +108,7 @@ def test_DP_field(shared_datadir, tmpdir):
107108
)
108109
sg_ds = sg.load_dataset(str(sg_vcfzarr_path))
109110
sg_ds = sg_ds.drop_vars("call_genotype_phased") # not included in scikit-allel
111+
del sg_ds.attrs["max_alt_alleles_seen"] # not saved by scikit-allel
110112

111113
assert_identical(allel_ds, sg_ds)
112114

@@ -120,6 +122,7 @@ def test_DP_field(shared_datadir, tmpdir):
120122
("CEUTrio.20.21.gatk3.4.g.vcf.bgz", ["calldata/PL"], ["FORMAT/PL"]),
121123
],
122124
)
125+
@pytest.mark.filterwarnings("ignore::sgkit.io.vcf.MaxAltAllelesExceededWarning")
123126
def test_all_fields(
124127
shared_datadir, tmpdir, vcf_file, allel_exclude_fields, sgkit_exclude_fields
125128
):
@@ -159,6 +162,7 @@ def test_all_fields(
159162
)
160163
sg_ds = sg.load_dataset(str(sg_vcfzarr_path))
161164
sg_ds = sg_ds.drop_vars("call_genotype_phased") # not included in scikit-allel
165+
del sg_ds.attrs["max_alt_alleles_seen"] # not saved by scikit-allel
162166

163167
# scikit-allel only records contigs for which there are actual variants,
164168
# whereas sgkit records contigs from the header

0 commit comments

Comments
 (0)