Skip to content

Commit bd84bfb

Browse files
committed
Add more default filters for VCF
Issue warning for VCF FORMAT float fields
1 parent 0efd28a commit bd84bfb

File tree

5 files changed

+98
-41
lines changed

5 files changed

+98
-41
lines changed

sgkit/io/vcf/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,15 @@
33
try:
44
from .vcf_partition import partition_into_regions
55
from .vcf_reader import (
6+
FloatFormatFieldWarning,
67
MaxAltAllelesExceededWarning,
78
concat_zarrs,
89
vcf_to_zarr,
910
vcf_to_zarrs,
1011
)
1112

1213
__all__ = [
14+
"FloatFormatFieldWarning",
1315
"MaxAltAllelesExceededWarning",
1416
"concat_zarrs",
1517
"partition_into_regions",

sgkit/io/vcf/utils.py

+32-1
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,11 @@
33
import tempfile
44
import uuid
55
from contextlib import contextmanager
6-
from typing import IO, Any, Dict, Iterator, Optional, Sequence, TypeVar
6+
from typing import IO, Any, Dict, Hashable, Iterator, Optional, Sequence, TypeVar
77
from urllib.parse import urlparse
88

99
import fsspec
10+
from numcodecs import Delta, PackBits
1011
from yarl import URL
1112

1213
from sgkit.typing import PathType
@@ -170,6 +171,36 @@ def temporary_directory(
170171
fs.rm(tempdir, recursive=True)
171172

172173

174+
def get_default_vcf_encoding(
    ds: Any, chunk_length: int, chunk_width: int, compressor: Any
) -> Dict[str, Dict[str, Any]]:
    """Build the default Zarr encoding for every data variable in a VCF dataset.

    Parameters
    ----------
    ds
        Dataset whose ``data_vars`` are inspected; presumably an
        ``xarray.Dataset`` produced by the VCF reader — confirm at call site.
    chunk_length
        Chunk size to enforce along the ``variants`` dimension.
    chunk_width
        Chunk size to enforce along the ``samples`` dimension.
    compressor
        Numcodecs compressor applied to every variable (may be ``None``).

    Returns
    -------
    Mapping from variable name to its Zarr encoding dict
    (``chunks``, ``compressor``, and optionally ``filters``).
    """

    # Enforce uniform chunks in the variants dimension; also chunk in the
    # samples direction. Any other dimension keeps its full size as one chunk.
    def get_chunk_size(dim: Hashable, size: int) -> int:
        if dim == "variants":
            return chunk_length
        elif dim == "samples":
            return chunk_width
        else:
            return size

    default_encoding: Dict[str, Dict[str, Any]] = {}
    for var in ds.data_vars:
        var_chunks = tuple(
            get_chunk_size(dim, size)
            for (dim, size) in zip(ds[var].dims, ds[var].shape)
        )
        default_encoding[var] = dict(chunks=var_chunks, compressor=compressor)

        # Enable bit packing by default for boolean arrays
        if ds[var].dtype.kind == "b":
            default_encoding[var]["filters"] = [PackBits()]

        # Position is monotonically increasing (within a contig) so benefits
        # from delta encoding
        if var == "variant_position":
            default_encoding[var]["filters"] = [Delta(dtype="i4", astype="i4")]

    return default_encoding
202+
203+
173204
def merge_encodings(
174205
default_encoding: Dict[str, Dict[str, Any]], overrides: Dict[str, Dict[str, Any]]
175206
) -> Dict[str, Dict[str, Any]]:

sgkit/io/vcf/vcf_reader.py

+28-33
Original file line numberDiff line numberDiff line change
@@ -5,25 +5,14 @@
55
from contextlib import contextmanager
66
from dataclasses import dataclass
77
from pathlib import Path
8-
from typing import (
9-
Any,
10-
Dict,
11-
Hashable,
12-
Iterator,
13-
MutableMapping,
14-
Optional,
15-
Sequence,
16-
Tuple,
17-
Union,
18-
)
8+
from typing import Any, Dict, Iterator, MutableMapping, Optional, Sequence, Tuple, Union
199

2010
import dask
2111
import fsspec
2212
import numpy as np
2313
import xarray as xr
2414
import zarr
2515
from cyvcf2 import VCF, Variant
26-
from numcodecs import PackBits
2716

2817
from sgkit import variables
2918
from sgkit.io.utils import (
@@ -40,6 +29,7 @@
4029
from sgkit.io.vcf.utils import (
4130
build_url,
4231
chunks,
32+
get_default_vcf_encoding,
4333
merge_encodings,
4434
temporary_directory,
4535
url_filename,
@@ -71,6 +61,12 @@
7161
DEFAULT_COMPRESSOR = None
7262

7363

64+
class FloatFormatFieldWarning(UserWarning):
    """Warning for VCF FORMAT float fields, which can use a lot of storage."""
68+
69+
7470
class MaxAltAllelesExceededWarning(UserWarning):
7571
"""Warning when the number of alt alleles exceeds the maximum specified."""
7672

@@ -535,35 +531,34 @@ def vcf_to_zarr_sequential(
535531
ds.attrs["max_alt_alleles_seen"] = max_alt_alleles_seen
536532

537533
if first_variants_chunk:
538-
# Enforce uniform chunks in the variants dimension
539-
# Also chunk in the samples direction
540-
541-
def get_chunk_size(dim: Hashable, size: int) -> int:
542-
if dim == "variants":
543-
return chunk_length
544-
elif dim == "samples":
545-
return chunk_width
546-
else:
547-
return size
548-
549-
default_encoding = {}
534+
# ensure that booleans are not stored as int8 by xarray https://github.com/pydata/xarray/issues/4386
550535
for var in ds.data_vars:
551-
var_chunks = tuple(
552-
get_chunk_size(dim, size)
553-
for (dim, size) in zip(ds[var].dims, ds[var].shape)
554-
)
555-
default_encoding[var] = dict(
556-
chunks=var_chunks, compressor=compressor
557-
)
558536
if ds[var].dtype.kind == "b":
559-
# ensure that booleans are not stored as int8 by xarray https://github.com/pydata/xarray/issues/4386
560537
ds[var].attrs["dtype"] = "bool"
561-
default_encoding[var]["filters"] = [PackBits()]
562538

563539
# values from function args (encoding) take precedence over default_encoding
540+
default_encoding = get_default_vcf_encoding(
541+
ds, chunk_length, chunk_width, compressor
542+
)
564543
encoding = encoding or {}
565544
merged_encoding = merge_encodings(default_encoding, encoding)
566545

546+
for var in ds.data_vars:
547+
# Issue warning for VCF FORMAT float fields with no filter
548+
if (
549+
var.startswith("call_")
550+
and ds[var].dtype == np.float32
551+
and (
552+
var not in merged_encoding
553+
or "filters" not in merged_encoding[var]
554+
)
555+
):
556+
warnings.warn(
557+
f"Storing call variable {var} (FORMAT field) as a float can result in large file sizes. "
558+
f"Consider setting the encoding filters for this variable to FixedScaleOffset or similar.",
559+
FloatFormatFieldWarning,
560+
)
561+
567562
ds.to_zarr(output, mode="w", encoding=merged_encoding)
568563
first_variants_chunk = False
569564
else:

sgkit/tests/io/vcf/test_vcf_lossless_conversion.py

+1
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
],
1818
)
1919
@pytest.mark.filterwarnings(
20+
"ignore::sgkit.io.vcf.FloatFormatFieldWarning",
2021
"ignore::sgkit.io.vcfzarr_reader.DimensionNameForFixedFormatFieldWarning",
2122
)
2223
def test_lossless_conversion(shared_datadir, tmp_path, vcf_file):

sgkit/tests/io/vcf/test_vcf_reader.py

+35-7
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import pytest
55
import xarray as xr
66
import zarr
7-
from numcodecs import Blosc, PackBits, VLenUTF8
7+
from numcodecs import Blosc, Delta, FixedScaleOffset, PackBits, VLenUTF8
88
from numpy.testing import assert_allclose, assert_array_equal
99

1010
from sgkit import load_dataset, save_dataset
@@ -246,6 +246,10 @@ def test_vcf_to_zarr__compressor_and_filters(shared_datadir, is_path, tmp_path):
246246
assert z["variant_id_mask"].filters is None
247247
assert z["variant_id_mask"].chunks == (5,)
248248

249+
assert z["variant_position"].filters == [
250+
Delta(dtype="i4", astype="i4")
251+
] # sgkit default
252+
249253

250254
@pytest.mark.parametrize(
251255
"is_path",
@@ -259,7 +263,7 @@ def test_vcf_to_zarr__parallel_compressor_and_filters(
259263
output = tmp_path.joinpath("vcf_concat.zarr").as_posix()
260264
regions = ["20", "21"]
261265

262-
default_compressor = Blosc("zlib", 1, Blosc.NOSHUFFLE)
266+
compressor = Blosc("zlib", 1, Blosc.NOSHUFFLE)
263267
variant_id_compressor = Blosc("zlib", 2, Blosc.NOSHUFFLE)
264268
encoding = dict(
265269
variant_id=dict(compressor=variant_id_compressor),
@@ -270,18 +274,29 @@ def test_vcf_to_zarr__parallel_compressor_and_filters(
270274
output,
271275
regions=regions,
272276
chunk_length=5_000,
273-
compressor=default_compressor,
277+
compressor=compressor,
274278
encoding=encoding,
275279
)
276280

277281
# look at actual Zarr store to check compressor and filters
278282
z = zarr.open(output)
279-
assert z["call_genotype"].compressor == default_compressor
280-
assert z["call_genotype"].filters is None
281-
assert z["call_genotype_mask"].filters == [PackBits()]
283+
assert z["call_genotype"].compressor == compressor
284+
assert z["call_genotype"].filters is None # sgkit default
285+
assert z["call_genotype"].chunks == (5000, 1, 2)
286+
assert z["call_genotype_mask"].compressor == compressor
287+
assert z["call_genotype_mask"].filters == [PackBits()] # sgkit default
288+
assert z["call_genotype_mask"].chunks == (5000, 1, 2)
282289

283290
assert z["variant_id"].compressor == variant_id_compressor
291+
assert z["variant_id"].filters == [VLenUTF8()] # sgkit default
292+
assert z["variant_id"].chunks == (5000,)
293+
assert z["variant_id_mask"].compressor == compressor
284294
assert z["variant_id_mask"].filters is None
295+
assert z["variant_id_mask"].chunks == (5000,)
296+
297+
assert z["variant_position"].filters == [
298+
Delta(dtype="i4", astype="i4")
299+
] # sgkit default
285300

286301

287302
@pytest.mark.parametrize(
@@ -992,7 +1007,20 @@ def test_vcf_to_zarr__field_number_G_non_diploid(shared_datadir, tmp_path):
9921007
path = path_for_test(shared_datadir, "simple.output.mixed_depth.likelihoods.vcf")
9931008
output = tmp_path.joinpath("vcf.zarr").as_posix()
9941009

995-
vcf_to_zarr(path, output, ploidy=4, max_alt_alleles=3, fields=["FORMAT/GL"])
1010+
# store GL field as 2dp
1011+
encoding = {
1012+
"call_GL": {
1013+
"filters": [FixedScaleOffset(offset=0, scale=100, dtype="f4", astype="u1")]
1014+
}
1015+
}
1016+
vcf_to_zarr(
1017+
path,
1018+
output,
1019+
ploidy=4,
1020+
max_alt_alleles=3,
1021+
fields=["FORMAT/GL"],
1022+
encoding=encoding,
1023+
)
9961024
ds = xr.open_zarr(output)
9971025

9981026
# comb(n_alleles + ploidy - 1, ploidy) = comb(4 + 4 - 1, 4) = comb(7, 4) = 35

0 commit comments

Comments
 (0)