Skip to content

Commit 0efd28a

Browse files
committed
Ensure default compressor is not ignored
1 parent b0f70cd commit 0efd28a

File tree

4 files changed

+61
-8
lines changed

4 files changed

+61
-8
lines changed

sgkit/io/vcf/utils.py

+29
Original file line numberDiff line numberDiff line change
@@ -168,3 +168,32 @@ def temporary_directory(
168168
finally:
169169
# Remove the temporary directory on exiting the context manager
170170
fs.rm(tempdir, recursive=True)
171+
172+
173+
def merge_encodings(
174+
default_encoding: Dict[str, Dict[str, Any]], overrides: Dict[str, Dict[str, Any]]
175+
) -> Dict[str, Dict[str, Any]]:
176+
"""Merge a dictionary of dictionaries specifying encodings with another dictionary of dictionaries of overriding encodings.
177+
178+
Parameters
179+
----------
180+
default_encoding : Dict[str, Dict[str, Any]]
181+
The default encoding dictionary.
182+
overrides : Dict[str, Dict[str, Any]]
183+
A dictionary containing selective overrides.
184+
185+
Returns
186+
-------
187+
Dict[str, Dict[str, Any]]
188+
The merged encoding dictionary
189+
"""
190+
merged = {}
191+
for var, d in default_encoding.items():
192+
if var in overrides:
193+
merged[var] = {**d, **overrides[var]}
194+
else:
195+
merged[var] = d
196+
for var, d in overrides.items():
197+
if var not in merged:
198+
merged[var] = d
199+
return merged

sgkit/io/vcf/vcf_reader.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,13 @@
3737
STR_MISSING,
3838
)
3939
from sgkit.io.vcf import partition_into_regions
40-
from sgkit.io.vcf.utils import build_url, chunks, temporary_directory, url_filename
40+
from sgkit.io.vcf.utils import (
41+
build_url,
42+
chunks,
43+
merge_encodings,
44+
temporary_directory,
45+
url_filename,
46+
)
4147
from sgkit.io.vcfzarr_reader import (
4248
concat_zarrs_optimized,
4349
vcf_number_to_dimension_and_size,
@@ -556,7 +562,7 @@ def get_chunk_size(dim: Hashable, size: int) -> int:
556562

557563
# values from function args (encoding) take precedence over default_encoding
558564
encoding = encoding or {}
559-
merged_encoding = {**default_encoding, **encoding}
565+
merged_encoding = merge_encodings(default_encoding, encoding)
560566

561567
ds.to_zarr(output, mode="w", encoding=merged_encoding)
562568
first_variants_chunk = False

sgkit/tests/io/vcf/test_utils.py

+12-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import pytest
77
from callee.strings import StartsWith
88

9-
from sgkit.io.vcf.utils import build_url, chunks, temporary_directory
9+
from sgkit.io.vcf.utils import build_url, chunks, merge_encodings, temporary_directory
1010
from sgkit.io.vcf.vcf_reader import get_region_start
1111

1212

@@ -118,3 +118,14 @@ def test_chunks(x, n, expected_values):
118118
)
119119
def test_get_region_start(region: str, expected: int):
120120
assert get_region_start(region) == expected
121+
122+
123+
def test_merge_encodings():
124+
default_encoding = dict(a=dict(a1=1, a2=2), b=dict(b1=5))
125+
overrides = dict(a=dict(a1=0, a3=3), c=dict(c1=7))
126+
assert merge_encodings(default_encoding, overrides) == dict(
127+
a=dict(a1=0, a2=2, a3=3), b=dict(b1=5), c=dict(c1=7)
128+
)
129+
130+
assert merge_encodings(default_encoding, {}) == default_encoding
131+
assert merge_encodings({}, overrides) == overrides

sgkit/tests/io/vcf/test_vcf_reader.py

+12-5
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@ def test_vcf_to_zarr__compressor_and_filters(shared_datadir, is_path, tmp_path):
215215
path = path_for_test(shared_datadir, "sample.vcf.gz", is_path)
216216
output = tmp_path.joinpath("vcf.zarr").as_posix()
217217

218-
default_compressor = Blosc("zlib", 1, Blosc.NOSHUFFLE)
218+
compressor = Blosc("zlib", 1, Blosc.NOSHUFFLE)
219219
variant_id_compressor = Blosc("zlib", 2, Blosc.NOSHUFFLE)
220220
encoding = dict(
221221
variant_id=dict(compressor=variant_id_compressor),
@@ -226,18 +226,25 @@ def test_vcf_to_zarr__compressor_and_filters(shared_datadir, is_path, tmp_path):
226226
output,
227227
chunk_length=5,
228228
chunk_width=2,
229-
compressor=default_compressor,
229+
compressor=compressor,
230230
encoding=encoding,
231231
)
232232

233233
# look at actual Zarr store to check compressor and filters
234234
z = zarr.open(output)
235-
assert z["call_genotype"].compressor == default_compressor
236-
assert z["call_genotype"].filters is None
237-
assert z["call_genotype_mask"].filters == [PackBits()]
235+
assert z["call_genotype"].compressor == compressor
236+
assert z["call_genotype"].filters is None # sgkit default
237+
assert z["call_genotype"].chunks == (5, 2, 2)
238+
assert z["call_genotype_mask"].compressor == compressor
239+
assert z["call_genotype_mask"].filters == [PackBits()] # sgkit default
240+
assert z["call_genotype_mask"].chunks == (5, 2, 2)
238241

239242
assert z["variant_id"].compressor == variant_id_compressor
243+
assert z["variant_id"].filters == [VLenUTF8()] # sgkit default
244+
assert z["variant_id"].chunks == (5,)
245+
assert z["variant_id_mask"].compressor == compressor
240246
assert z["variant_id_mask"].filters is None
247+
assert z["variant_id_mask"].chunks == (5,)
241248

242249

243250
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)