Skip to content

Commit 3c566d2

Browse files
Support gzip serialization (#3662)
For comparison with #3655. Observed benchmark results, running locally: ``` nested: 7.016707049 s size: 299304 flattened: 1.278540381 s size: 20911 ```
1 parent 97e9a91 commit 3c566d2

File tree

6 files changed

+264
-0
lines changed

6 files changed

+264
-0
lines changed

cirq/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -493,6 +493,7 @@
493493
qid_shape,
494494
quil,
495495
QuilFormatter,
496+
read_json_gzip,
496497
read_json,
497498
resolve_parameters,
498499
resolve_parameters_once,
@@ -521,6 +522,7 @@
521522
SupportsQasmWithArgsAndQubits,
522523
SupportsTraceDistanceBound,
523524
SupportsUnitary,
525+
to_json_gzip,
524526
to_json,
525527
obj_to_dict_helper,
526528
trace_distance_bound,

cirq/protocols/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,8 @@
7979
DEFAULT_RESOLVERS,
8080
JsonResolver,
8181
json_serializable_dataclass,
82+
to_json_gzip,
83+
read_json_gzip,
8284
to_json,
8385
read_json,
8486
obj_to_dict_helper,

cirq/protocols/json_serialization.py

+39
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414
import dataclasses
1515
import functools
16+
import gzip
1617
import json
1718
import numbers
1819
import pathlib
@@ -706,3 +707,41 @@ def obj_hook(x):
706707
return json.load(file, object_hook=obj_hook)
707708

708709
return json.load(cast(IO, file_or_fn), object_hook=obj_hook)
710+
711+
712+
def to_json_gzip(
    obj: Any,
    file_or_fn: Union[None, IO, pathlib.Path, str] = None,
    *,
    indent: int = 2,
    cls: Type[json.JSONEncoder] = CirqEncoder,
) -> Optional[bytes]:
    """Serialize `obj` to JSON text and gzip-compress the result.

    Args:
        obj: The object to serialize; it is passed to `to_json` unchanged.
        file_or_fn: Where the compressed data goes. A `str` or
            `pathlib.Path` is opened with `gzip.open` and written to. `None`
            means the compressed bytes are returned instead. Any other value
            is treated as an already-open object whose `write` method is
            called with the compressed bytes.
        indent: Forwarded to `to_json`.
        cls: Forwarded to `to_json`.

    Returns:
        The gzip-compressed, UTF-8 encoded JSON when `file_or_fn` is None;
        otherwise None.
    """
    json_text = to_json(obj, indent=indent, cls=cls)

    # No destination given: hand the compressed payload back to the caller.
    if file_or_fn is None:
        return gzip.compress(bytes(json_text, encoding='utf-8'))  # type: ignore

    # A path: let gzip manage the file and do the text encoding itself.
    if isinstance(file_or_fn, (str, pathlib.Path)):
        with gzip.open(file_or_fn, 'wt', encoding='utf-8') as gzip_text_file:
            gzip_text_file.write(json_text)
        return None

    # Anything else is used as an open file object receiving the raw bytes.
    compressed = gzip.compress(bytes(json_text, encoding='utf-8'))  # type: ignore
    file_or_fn.write(compressed)
    return None
731+
732+
733+
def read_json_gzip(
    file_or_fn: Union[None, IO, pathlib.Path, str] = None,
    *,
    gzip_raw: Optional[bytes] = None,
    resolvers: Optional[Sequence[JsonResolver]] = None,
):
    """Read a gzip-compressed JSON representation of an object.

    Exactly one of `file_or_fn` and `gzip_raw` must be given.

    Args:
        file_or_fn: A file name (`str` or `pathlib.Path`) or an open binary
            file object holding gzip-compressed JSON. It is handed to
            `gzip.open`.
        gzip_raw: The gzip-compressed JSON as an in-memory bytes object.
        resolvers: Forwarded to `read_json`.

    Returns:
        Whatever `read_json` returns for the decompressed JSON.

    Raises:
        ValueError: If both or neither of `file_or_fn` and `gzip_raw` are
            specified.
    """
    if (file_or_fn is None) == (gzip_raw is None):
        raise ValueError('Must specify ONE of "file_or_fn" or "gzip_raw".')

    if gzip_raw is not None:
        json_str = gzip.decompress(gzip_raw).decode(encoding='utf-8')
        return read_json(json_text=json_str, resolvers=resolvers)

    # Decode explicitly as UTF-8: `to_json_gzip` always writes UTF-8, and the
    # text-mode default would otherwise depend on the platform locale.
    with gzip.open(file_or_fn, 'rt', encoding='utf-8') as json_file:  # type: ignore
        return read_json(cast(IO, json_file), resolvers=resolvers)

cirq/protocols/json_serialization_test.py

+43
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,32 @@ def test_op_roundtrip_filename(tmpdir):
9191
op2 = cirq.read_json(filename)
9292
assert op1 == op2
9393

94+
gzip_filename = f'{tmpdir}/op.gz'
95+
cirq.to_json_gzip(op1, gzip_filename)
96+
assert os.path.exists(gzip_filename)
97+
op3 = cirq.read_json_gzip(gzip_filename)
98+
assert op1 == op3
99+
100+
101+
def test_op_roundtrip_file_obj(tmpdir):
    """Round-trips an operation through open file objects, plain and gzipped."""
    qubit = cirq.LineQubit(5)
    op1 = cirq.rx(0.123).on(qubit)

    # Plain JSON through a text-mode file object.
    filename = f'{tmpdir}/op.json'
    with open(filename, 'w+') as json_file:
        cirq.to_json(op1, json_file)
        assert os.path.exists(filename)
        # Rewind so the same handle can be read back.
        json_file.seek(0)
        assert op1 == cirq.read_json(json_file)

    # Gzipped JSON through a binary-mode file object.
    gzip_filename = f'{tmpdir}/op.gz'
    with open(gzip_filename, 'w+b') as binary_file:
        cirq.to_json_gzip(op1, binary_file)
        assert os.path.exists(gzip_filename)
        binary_file.seek(0)
        assert op1 == cirq.read_json_gzip(binary_file)
119+
94120

95121
def test_fail_to_resolve():
96122
buffer = io.StringIO()
@@ -640,6 +666,19 @@ def test_to_from_strings():
640666
cirq.read_json(io.StringIO(), json_text=x_json_text)
641667

642668

669+
def test_to_from_json_gzip():
    """Round-trips a circuit through in-memory gzip bytes and checks arg validation."""
    q0, q1 = cirq.LineQubit.range(2)
    original = cirq.Circuit(cirq.H(q0), cirq.CX(q0, q1))

    compressed = cirq.to_json_gzip(original)
    restored = cirq.read_json_gzip(gzip_raw=compressed)
    assert original == restored

    # Supplying both sources is rejected.
    with pytest.raises(ValueError):
        _ = cirq.read_json_gzip(io.StringIO(), gzip_raw=compressed)
    # Supplying no source is rejected as well.
    with pytest.raises(ValueError):
        _ = cirq.read_json_gzip()
680+
681+
643682
def _eval_repr_data_file(path: pathlib.Path):
644683
return eval(
645684
path.read_text(),
@@ -736,6 +775,10 @@ def test_pathlib_paths(tmpdir):
736775
cirq.to_json(cirq.X, path)
737776
assert cirq.read_json(path) == cirq.X
738777

778+
gzip_path = pathlib.Path(tmpdir) / 'op.gz'
779+
cirq.to_json_gzip(cirq.X, gzip_path)
780+
assert cirq.read_json_gzip(gzip_path) == cirq.X
781+
739782

740783
def test_json_serializable_dataclass():
741784
@cirq.json_serializable_dataclass
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
# Copyright 2020 The Cirq Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Tool for benchmarking serialization of large circuits.
16+
17+
This tool was originally introduced to enable comparison of the two JSON
18+
serialization protocols (gzip and non-gzip):
19+
https://github.com/quantumlib/Cirq/pull/3662
20+
21+
This is part of the "efficient serialization" effort:
22+
https://github.com/quantumlib/Cirq/issues/3438
23+
24+
Run this benchmark with the following command (make sure to install cirq-dev):
25+
26+
python3 dev_tools/profiling/benchmark_serializers.py \
27+
--num_gates=<int> --nesting_depth=<int> --num_repetitions=<int>
28+
29+
WARNING: runtime increases exponentially with nesting_depth. Values much
30+
higher than nesting_depth=10 are not recommended.
31+
"""
32+
33+
import argparse
34+
import sys
35+
import timeit
36+
37+
import numpy as np
38+
39+
import cirq
40+
41+
_JSON_GZIP = 'json_gzip'
42+
_JSON = 'json'
43+
44+
NUM_QUBITS = 8
45+
46+
SUFFIXES = ['B', 'kB', 'MB', 'GB', 'TB']
47+
48+
49+
def serialize(serializer: str, num_gates: int, nesting_depth: int) -> int:
    """Runs a round-trip of the serializer.

    Builds a random circuit of `num_gates` gates on row 0 of a grid with
    `NUM_QUBITS` columns, nests it `nesting_depth - 1` times (each level holds
    two copies of the frozen previous level), then serializes and
    deserializes the result with the requested serializer.

    Args:
        serializer: Either `_JSON` ('json') or `_JSON_GZIP` ('json_gzip').
        num_gates: Number of gates in the innermost circuit.
        nesting_depth: Number of circuit levels; 1 means no nesting.

    Returns:
        The length of the serialized data: characters of JSON text for the
        plain serializer, bytes for the gzip serializer.

    Raises:
        ValueError: If `serializer` is not one of the known serializers.
    """
    # Fail fast and clearly instead of hitting an unbound `data_size` after
    # all the circuit-building work has been done.
    if serializer not in (_JSON, _JSON_GZIP):
        raise ValueError(
            f'Unknown serializer: {serializer!r}. Expected {_JSON!r} or {_JSON_GZIP!r}.'
        )

    circuit = cirq.Circuit()
    for _ in range(num_gates):
        which = np.random.choice(['expz', 'expw', 'exp11'])
        if which == 'expw':
            q1 = cirq.GridQubit(0, np.random.randint(NUM_QUBITS))
            circuit.append(
                cirq.PhasedXPowGate(
                    phase_exponent=np.random.random(), exponent=np.random.random()
                ).on(q1)
            )
        elif which == 'expz':
            q1 = cirq.GridQubit(0, np.random.randint(NUM_QUBITS))
            circuit.append(cirq.Z(q1) ** np.random.random())
        elif which == 'exp11':
            # Pick a column that leaves room for the neighbouring qubit.
            q1 = cirq.GridQubit(0, np.random.randint(NUM_QUBITS - 1))
            q2 = cirq.GridQubit(0, q1.col + 1)
            circuit.append(cirq.CZ(q1, q2) ** np.random.random())

    # Only the outermost circuit is needed, so each level replaces the last.
    test_circuit = circuit
    for _ in range(1, nesting_depth):
        frozen = test_circuit.freeze()
        test_circuit = cirq.Circuit(frozen.to_op(), frozen.to_op())

    if serializer == _JSON:
        json_data = cirq.to_json(test_circuit)
        assert json_data is not None
        data_size = len(json_data)
        cirq.read_json(json_text=json_data)
    else:
        gzip_data = cirq.to_json_gzip(test_circuit)
        assert gzip_data is not None
        data_size = len(gzip_data)
        cirq.read_json_gzip(gzip_raw=gzip_data)
    return data_size
85+
86+
87+
def main(
    num_gates: int,
    nesting_depth: int,
    num_repetitions: int,
    setup: str = 'from __main__ import serialize',
):
    """Times each serializer's round trip and prints the serialized size.

    Args:
        num_gates: Passed through to `serialize`.
        nesting_depth: Passed through to `serialize`.
        num_repetitions: Number of timed executions; the printed time is the
            total divided by this value.
        setup: Setup statement given to `timeit.timeit`; it must bring
            `serialize` into scope for the timed statement.
    """
    for serializer in (_JSON_GZIP, _JSON):
        print()
        print(f'Using serializer "{serializer}":')

        statement = f"serialize('{serializer}', {num_gates}, {nesting_depth})"
        total_seconds = timeit.timeit(statement, setup, number=num_repetitions)
        print(f'Round-trip serializer time: {total_seconds / num_repetitions}s')

        # One extra, untimed run provides the size to report.
        size = float(serialize(serializer, num_gates, nesting_depth))
        unit = 0
        while size > 1000:
            size, unit = size / 1024, unit + 1
        print(f'Serialized data size: {size} {SUFFIXES[unit]}.')
105+
106+
107+
def parse_arguments(args):
    """Parses the benchmark's command-line flags.

    Args:
        args: The argument strings to parse, e.g. `sys.argv[1:]`.

    Returns:
        A dict with the integer entries 'num_gates', 'nesting_depth' and
        'num_repetitions'.
    """
    # The text must be passed as `description`; the first positional parameter
    # of ArgumentParser is `prog`, which would put it in the usage line.
    parser = argparse.ArgumentParser(description='Benchmark a serializer.')
    parser.add_argument(
        '--num_gates', default=100, type=int, help='Number of gates at the bottom nesting layer.'
    )
    parser.add_argument(
        '--nesting_depth',
        default=1,
        type=int,
        help=(
            'Depth of nested subcircuits. '
            'Total gate count will be 2^(nesting_depth - 1) * num_gates.'
        ),
    )
    parser.add_argument(
        '--num_repetitions', default=10, type=int, help='Number of times to repeat serialization.'
    )
    return vars(parser.parse_args(args))
122+
123+
124+
if __name__ == '__main__':
125+
main(**parse_arguments(sys.argv[1:]))
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
# Copyright 2018 The Cirq Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Tests for the simulator benchmarker."""
16+
17+
from dev_tools.profiling import benchmark_serializers
18+
19+
20+
def test_gzip_serializer():
    """Runs the gzip round trip for a small grid of sizes and depths."""
    cases = [(gates, depth) for gates in (10, 20) for depth in (1, 8)]
    for num_gates, nesting_depth in cases:
        benchmark_serializers.serialize('json_gzip', num_gates, nesting_depth)
24+
25+
26+
def test_json_serializer():
    """Runs the plain JSON round trip for a small grid of sizes and depths."""
    cases = [(gates, depth) for gates in (10, 20) for depth in (1, 8)]
    for num_gates, nesting_depth in cases:
        benchmark_serializers.serialize('json', num_gates, nesting_depth)
30+
31+
32+
def test_args_have_defaults():
    """Checks that parsing an empty argument list leaves no flag unset."""
    parsed = benchmark_serializers.parse_arguments([])
    for value in parsed.values():
        assert value is not None
36+
37+
38+
def test_main_loop():
    """Runs the benchmark's main loop once with the default flag values."""
    # `timeit` executes the statement in its own namespace, so `serialize`
    # must be imported from the module rather than from `__main__`.
    benchmark_serializers.main(
        **benchmark_serializers.parse_arguments([]),
        setup='from dev_tools.profiling.benchmark_serializers import serialize',
    )
44+
45+
46+
def test_parse_args():
    """Checks that every flag is parsed into its integer value."""
    expected = {
        'num_gates': 5,
        'nesting_depth': 8,
        'num_repetitions': 2,
    }
    argv = '--num_gates 5 --nesting_depth 8 --num_repetitions 2'.split()
    assert benchmark_serializers.parse_arguments(argv) == expected

0 commit comments

Comments
 (0)