Improve default VCF compression #937


Merged

5 commits merged on Oct 18, 2022
41 changes: 41 additions & 0 deletions docs/vcf.rst
@@ -181,6 +181,47 @@ cloud storage. You can access files stored on Amazon S3 or Google Cloud Storage
using ``s3://`` or ``gs://`` URLs. Setting credentials or other options is
typically achieved using environment variables for the underlying cloud store.

Compression
-----------

Zarr offers a lot of flexibility in how data is compressed. Each variable can use
a different `compression algorithm <https://zarr.readthedocs.io/en/stable/tutorial.html#compressors>`_
and its own list of `filters <https://zarr.readthedocs.io/en/stable/tutorial.html#filters>`_.

The :func:`sgkit.io.vcf.vcf_to_zarr` function tries to choose good compression defaults, based on
each variable's dtype and the nature of the data being stored.

For example, ``variant_position`` (from the VCF ``POS`` field) is a monotonically increasing integer
(within a contig), so it benefits from delta encoding, which stores the differences between successive
values; these differences are smaller integers that compress better. This encoding is specified using the NumCodecs
`Delta <https://numcodecs.readthedocs.io/en/stable/delta.html>`_ codec as a Zarr filter.

When converting from VCF you can set the default compression algorithm for all variables
by passing ``compressor`` to :func:`sgkit.io.vcf.vcf_to_zarr`. There are trade-offs
between compression speed and size, which this `benchmark <http://alimanfoo.github.io/2016/09/21/genotype-compression-benchmark.html>`_
does a good job of exploring.

Sometimes you may want to override the compression for a particular variable. A good example
is VCF FORMAT fields that are floats. Floats don't compress well, and since there is a value for
every sample they can take up a lot of space. In many cases full float precision is not needed,
so it is a good idea to use a filter that transforms the floats to ints, which take less space.

For example, the following code creates an encoding that can be passed to :func:`sgkit.io.vcf.vcf_to_zarr`
to store the VCF ``DS`` FORMAT field to 2 decimal places. (``DS`` is a dosage field between 0 and 2,
so after scaling by 100 it fits into an unsigned 8-bit int.)::

    from numcodecs import FixedScaleOffset

    encoding = {
        "call_DS": {
            "filters": [FixedScaleOffset(offset=0, scale=100, dtype="f4", astype="u1")],
        },
    }

Note that this encoding won't work for floats that may be NaN. Consider using
`Quantize <https://numcodecs.readthedocs.io/en/stable/quantize.html>`_ (with ``astype=np.float16``)
or `Bitround <https://numcodecs.readthedocs.io/en/stable/bitround.html>`_ in that case.

.. _vcf_low_level_operation:

Low-level operation
2 changes: 2 additions & 0 deletions sgkit/io/vcf/__init__.py
@@ -3,13 +3,15 @@
try:
    from .vcf_partition import partition_into_regions
    from .vcf_reader import (
        FloatFormatFieldWarning,
        MaxAltAllelesExceededWarning,
        concat_zarrs,
        vcf_to_zarr,
        vcf_to_zarrs,
    )

    __all__ = [
        "FloatFormatFieldWarning",
        "MaxAltAllelesExceededWarning",
        "concat_zarrs",
        "partition_into_regions",
62 changes: 61 additions & 1 deletion sgkit/io/vcf/utils.py
@@ -3,10 +3,11 @@
import tempfile
import uuid
from contextlib import contextmanager
from typing import IO, Any, Dict, Iterator, Optional, Sequence, TypeVar
from typing import IO, Any, Dict, Hashable, Iterator, Optional, Sequence, TypeVar
from urllib.parse import urlparse

import fsspec
from numcodecs import Delta, PackBits
from yarl import URL

from sgkit.typing import PathType
@@ -168,3 +169,62 @@ def temporary_directory(
    finally:
        # Remove the temporary directory on exiting the context manager
        fs.rm(tempdir, recursive=True)


def get_default_vcf_encoding(ds, chunk_length, chunk_width, compressor):
    # Enforce uniform chunks in the variants dimension
    # Also chunk in the samples direction
    def get_chunk_size(dim: Hashable, size: int) -> int:
        if dim == "variants":
            return chunk_length
        elif dim == "samples":
            return chunk_width
        else:
            return size

    default_encoding = {}
    for var in ds.data_vars:
        var_chunks = tuple(
            get_chunk_size(dim, size)
            for (dim, size) in zip(ds[var].dims, ds[var].shape)
        )
        default_encoding[var] = dict(chunks=var_chunks, compressor=compressor)

        # Enable bit packing by default for boolean arrays
        if ds[var].dtype.kind == "b":
            default_encoding[var]["filters"] = [PackBits()]

        # Position is monotonically increasing (within a contig) so benefits from delta encoding
        if var == "variant_position":
            default_encoding[var]["filters"] = [Delta(ds[var].dtype)]

    return default_encoding


def merge_encodings(
    default_encoding: Dict[str, Dict[str, Any]], overrides: Dict[str, Dict[str, Any]]
) -> Dict[str, Dict[str, Any]]:
    """Merge a dictionary of dictionaries specifying encodings with another dictionary of dictionaries of overriding encodings.

    Parameters
    ----------
    default_encoding : Dict[str, Dict[str, Any]]
        The default encoding dictionary.
    overrides : Dict[str, Dict[str, Any]]
        A dictionary containing selective overrides.

    Returns
    -------
    Dict[str, Dict[str, Any]]
        The merged encoding dictionary.
    """
    merged = {}
    for var, d in default_encoding.items():
        if var in overrides:
            merged[var] = {**d, **overrides[var]}
        else:
            merged[var] = d
    for var, d in overrides.items():
        if var not in merged:
            merged[var] = d
    return merged
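The merge is shallow and per-variable — equivalent to this dict-comprehension sketch (illustrative values, not part of the PR):

```python
default_encoding = {
    "variant_position": {"compressor": "default", "filters": ["delta"]},
    "call_DS": {"compressor": "default"},
}
overrides = {"call_DS": {"filters": ["scale-offset"]}}

# Per-variable keys given in overrides win; unspecified keys keep their defaults
merged = {
    var: {**default_encoding.get(var, {}), **overrides.get(var, {})}
    for var in {**default_encoding, **overrides}
}
assert merged["call_DS"] == {"compressor": "default", "filters": ["scale-offset"]}
assert merged["variant_position"] == default_encoding["variant_position"]
```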
71 changes: 36 additions & 35 deletions sgkit/io/vcf/vcf_reader.py
@@ -5,25 +5,14 @@
from contextlib import contextmanager
from dataclasses import dataclass
from pathlib import Path
from typing import (
    Any,
    Dict,
    Hashable,
    Iterator,
    MutableMapping,
    Optional,
    Sequence,
    Tuple,
    Union,
)
from typing import Any, Dict, Iterator, MutableMapping, Optional, Sequence, Tuple, Union

import dask
import fsspec
import numpy as np
import xarray as xr
import zarr
from cyvcf2 import VCF, Variant
from numcodecs import PackBits

from sgkit import variables
from sgkit.io.utils import (
@@ -37,7 +26,14 @@
    STR_MISSING,
)
from sgkit.io.vcf import partition_into_regions
from sgkit.io.vcf.utils import build_url, chunks, temporary_directory, url_filename
from sgkit.io.vcf.utils import (
    build_url,
    chunks,
    get_default_vcf_encoding,
    merge_encodings,
    temporary_directory,
    url_filename,
)
from sgkit.io.vcfzarr_reader import (
    concat_zarrs_optimized,
    vcf_number_to_dimension_and_size,
@@ -65,6 +61,12 @@
DEFAULT_COMPRESSOR = None


class FloatFormatFieldWarning(UserWarning):
    """Warning for VCF FORMAT float fields, which can use a lot of storage."""

    pass


class MaxAltAllelesExceededWarning(UserWarning):
    """Warning when the number of alt alleles exceeds the maximum specified."""

@@ -529,34 +531,33 @@ def vcf_to_zarr_sequential(
        ds.attrs["max_alt_alleles_seen"] = max_alt_alleles_seen

        if first_variants_chunk:
            # Enforce uniform chunks in the variants dimension
            # Also chunk in the samples direction

            def get_chunk_size(dim: Hashable, size: int) -> int:
                if dim == "variants":
                    return chunk_length
                elif dim == "samples":
                    return chunk_width
                else:
                    return size

            default_encoding = {}
            # ensure that booleans are not stored as int8 by xarray https://github.com/pydata/xarray/issues/4386
            for var in ds.data_vars:
                var_chunks = tuple(
                    get_chunk_size(dim, size)
                    for (dim, size) in zip(ds[var].dims, ds[var].shape)
                )
                default_encoding[var] = dict(
                    chunks=var_chunks, compressor=compressor
                )
                if ds[var].dtype.kind == "b":
                    # ensure that booleans are not stored as int8 by xarray https://github.com/pydata/xarray/issues/4386
                    ds[var].attrs["dtype"] = "bool"
                    default_encoding[var]["filters"] = [PackBits()]

            # values from function args (encoding) take precedence over default_encoding
            default_encoding = get_default_vcf_encoding(
                ds, chunk_length, chunk_width, compressor
            )
            encoding = encoding or {}
            merged_encoding = {**default_encoding, **encoding}
            merged_encoding = merge_encodings(default_encoding, encoding)

            for var in ds.data_vars:
                # Issue warning for VCF FORMAT float fields with no filter
                if (
                    var.startswith("call_")
                    and ds[var].dtype == np.float32
                    and (
                        var not in merged_encoding
                        or "filters" not in merged_encoding[var]
                    )
                ):
                    warnings.warn(
Collaborator commented: Might be nice to have an explicit test for the warning with pytest.warns?

Collaborator (author) replied: Good idea - added
                        f"Storing call variable {var} (FORMAT field) as a float can result in large file sizes. "
                        f"Consider setting the encoding filters for this variable to FixedScaleOffset or similar.",
                        FloatFormatFieldWarning,
                    )

            ds.to_zarr(output, mode="w", encoding=merged_encoding)
            first_variants_chunk = False
13 changes: 12 additions & 1 deletion sgkit/tests/io/vcf/test_utils.py
@@ -6,7 +6,7 @@
import pytest
from callee.strings import StartsWith

from sgkit.io.vcf.utils import build_url, chunks, temporary_directory
from sgkit.io.vcf.utils import build_url, chunks, merge_encodings, temporary_directory
from sgkit.io.vcf.vcf_reader import get_region_start


@@ -118,3 +118,14 @@ def test_chunks(x, n, expected_values):
)
def test_get_region_start(region: str, expected: int):
    assert get_region_start(region) == expected


def test_merge_encodings():
    default_encoding = dict(a=dict(a1=1, a2=2), b=dict(b1=5))
    overrides = dict(a=dict(a1=0, a3=3), c=dict(c1=7))
    assert merge_encodings(default_encoding, overrides) == dict(
        a=dict(a1=0, a2=2, a3=3), b=dict(b1=5), c=dict(c1=7)
    )

    assert merge_encodings(default_encoding, {}) == default_encoding
    assert merge_encodings({}, overrides) == overrides
1 change: 1 addition & 0 deletions sgkit/tests/io/vcf/test_vcf_lossless_conversion.py
@@ -17,6 +17,7 @@
    ],
)
@pytest.mark.filterwarnings(
    "ignore::sgkit.io.vcf.FloatFormatFieldWarning",
    "ignore::sgkit.io.vcfzarr_reader.DimensionNameForFixedFormatFieldWarning",
)
def test_lossless_conversion(shared_datadir, tmp_path, vcf_file):