Skip to content

Commit 8e31ed8

Browse files
95-martin-orionrht
authored andcommitted
Support and test type serialization (quantumlib#4693)
Prior to this PR, Cirq supported serialization of _instances_ of Cirq types, but not the types themselves. This PR adds serialization support for Cirq types, with the format: ``` { 'cirq_type': 'type', 'typename': $NAME } ``` where `$NAME` is the `cirq_type` of the object in its JSON representation. For type T, `$NAME` is usually `T.__name__`, but some types (mostly in `cirq_google`) do not follow this rule. The `json_cirq_type` protocol and `_json_cirq_type_` magic method are provided to handle this. It is worth noting that this PR explicitly **does not** support serialization of non-Cirq types (e.g. python builtins, sympy and numpy objects) despite instances of these objects being serializable in Cirq. This support can be added to `json_cirq_type` and `_cirq_object_hook` in `json_serialization.py` if we decide it is necessary; I left it out of this PR as it is not required by the motivating changes behind this PR (quantumlib#4640 and sub-PRs).
1 parent 5e9bef6 commit 8e31ed8

File tree

9 files changed

+231
-20
lines changed

9 files changed

+231
-20
lines changed

cirq-core/cirq/__init__.py

+4
Original file line numberDiff line numberDiff line change
@@ -508,6 +508,7 @@
508508
circuit_diagram_info,
509509
CircuitDiagramInfo,
510510
CircuitDiagramInfoArgs,
511+
cirq_type_from_json,
511512
commutes,
512513
control_keys,
513514
decompose,
@@ -520,10 +521,13 @@
520521
has_mixture,
521522
has_stabilizer_effect,
522523
has_unitary,
524+
HasJSONNamespace,
523525
inverse,
524526
is_measurement,
525527
is_parameterized,
526528
JsonResolver,
529+
json_cirq_type,
530+
json_namespace,
527531
json_serializable_dataclass,
528532
dataclass_json_dict,
529533
kraus,

cirq-core/cirq/json_resolver_cache.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,6 @@ def _parallel_gate_op(gate, qubits):
9191
'HPowGate': cirq.HPowGate,
9292
'ISwapPowGate': cirq.ISwapPowGate,
9393
'IdentityGate': cirq.IdentityGate,
94-
'IdentityOperation': _identity_operation_from_dict,
9594
'InitObsSetting': cirq.work.InitObsSetting,
9695
'KrausChannel': cirq.KrausChannel,
9796
'LinearDict': cirq.LinearDict,
@@ -115,7 +114,6 @@ def _parallel_gate_op(gate, qubits):
115114
'_PauliY': cirq.ops.pauli_gates._PauliY,
116115
'_PauliZ': cirq.ops.pauli_gates._PauliZ,
117116
'ParamResolver': cirq.ParamResolver,
118-
'ParallelGateOperation': _parallel_gate_op, # Removed in v0.14
119117
'ParallelGate': cirq.ParallelGate,
120118
'PauliMeasurementGate': cirq.PauliMeasurementGate,
121119
'PauliString': cirq.PauliString,
@@ -134,7 +132,6 @@ def _parallel_gate_op(gate, qubits):
134132
'RepetitionsStoppingCriteria': cirq.work.RepetitionsStoppingCriteria,
135133
'ResetChannel': cirq.ResetChannel,
136134
'SingleQubitCliffordGate': cirq.SingleQubitCliffordGate,
137-
'SingleQubitMatrixGate': single_qubit_matrix_gate,
138135
'SingleQubitPauliStringGateOperation': cirq.SingleQubitPauliStringGateOperation,
139136
'SingleQubitReadoutCalibrationResult': cirq.experiments.SingleQubitReadoutCalibrationResult,
140137
'StabilizerStateChForm': cirq.StabilizerStateChForm,
@@ -147,7 +144,6 @@ def _parallel_gate_op(gate, qubits):
147144
'Rx': cirq.Rx,
148145
'Ry': cirq.Ry,
149146
'Rz': cirq.Rz,
150-
'TwoQubitMatrixGate': two_qubit_matrix_gate,
151147
'_UnconstrainedDevice': cirq.devices.unconstrained_device._UnconstrainedDevice,
152148
'VarianceStoppingCriteria': cirq.work.VarianceStoppingCriteria,
153149
'VirtualTag': cirq.VirtualTag,
@@ -163,6 +159,11 @@ def _parallel_gate_op(gate, qubits):
163159
'YYPowGate': cirq.YYPowGate,
164160
'ZPowGate': cirq.ZPowGate,
165161
'ZZPowGate': cirq.ZZPowGate,
162+
# Old types, only supported for backwards-compatibility
163+
'IdentityOperation': _identity_operation_from_dict,
164+
'ParallelGateOperation': _parallel_gate_op, # Removed in v0.14
165+
'SingleQubitMatrixGate': single_qubit_matrix_gate,
166+
'TwoQubitMatrixGate': two_qubit_matrix_gate,
166167
# not a cirq class, but treated as one:
167168
'pandas.DataFrame': pd.DataFrame,
168169
'pandas.Index': pd.Index,

cirq-core/cirq/protocols/__init__.py

+4
Original file line numberDiff line numberDiff line change
@@ -82,9 +82,13 @@
8282
inverse,
8383
)
8484
from cirq.protocols.json_serialization import (
85+
cirq_type_from_json,
8586
DEFAULT_RESOLVERS,
87+
HasJSONNamespace,
8688
JsonResolver,
8789
json_serializable_dataclass,
90+
json_cirq_type,
91+
json_namespace,
8892
to_json_gzip,
8993
read_json_gzip,
9094
to_json,

cirq-core/cirq/protocols/json_serialization.py

+109-7
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,23 @@ def _json_dict_(self) -> Union[None, NotImplementedType, Dict[Any, Any]]:
128128
pass
129129

130130

131+
class HasJSONNamespace(Protocol):
132+
"""An object which prepends a namespace to its JSON cirq_type.
133+
134+
Classes which implement this method have the following cirq_type format:
135+
136+
f"{obj._json_namespace_()}.{obj.__class__.__name__}
137+
138+
Classes outside of Cirq or its submodules MUST implement this method to be
139+
used in type serialization.
140+
"""
141+
142+
@doc_private
143+
@classmethod
144+
def _json_namespace_(cls) -> str:
145+
pass
146+
147+
131148
def obj_to_dict_helper(
132149
obj: Any, attribute_names: Iterable[str], namespace: Optional[str] = None
133150
) -> Dict[str, Any]:
@@ -350,13 +367,7 @@ def _cirq_object_hook(d, resolvers: Sequence[JsonResolver], context_map: Dict[st
350367
if d['cirq_type'] == '_ContextualSerialization':
351368
return _ContextualSerialization.deserialize_with_context(**d)
352369

353-
for resolver in resolvers:
354-
cls = resolver(d['cirq_type'])
355-
if cls is not None:
356-
break
357-
else:
358-
raise ValueError(f"Could not resolve type '{d['cirq_type']}' during deserialization")
359-
370+
cls = factory_from_json(d['cirq_type'], resolvers=resolvers)
360371
from_json_dict = getattr(cls, '_from_json_dict_', None)
361372
if from_json_dict is not None:
362373
return from_json_dict(**d)
@@ -505,6 +516,97 @@ def get_serializable_by_keys(obj: Any) -> List[SerializableByKey]:
505516
return []
506517

507518

519+
def json_namespace(type_obj: Type) -> str:
520+
"""Returns a namespace for JSON serialization of `type_obj`.
521+
522+
Types can provide custom namespaces with `_json_namespace_`; otherwise, a
523+
Cirq type will not include a namespace in its cirq_type. Non-Cirq types
524+
must provide a namespace for serialization in Cirq.
525+
526+
Args:
527+
type_obj: Type to retrieve the namespace from.
528+
529+
Returns:
530+
The namespace to prepend `type_obj` with in its JSON cirq_type.
531+
532+
Raises:
533+
ValueError: if `type_obj` is not a Cirq type and does not explicitly
534+
define its namespace with _json_namespace_.
535+
"""
536+
if hasattr(type_obj, '_json_namespace_'):
537+
return type_obj._json_namespace_()
538+
if type_obj.__module__.startswith('cirq'):
539+
return ''
540+
raise ValueError(f'{type_obj} is not a Cirq type, and does not define _json_namespace_.')
541+
542+
543+
def json_cirq_type(type_obj: Type) -> str:
544+
"""Returns a string type for JSON serialization of `type_obj`.
545+
546+
This method is not part of the base serialization path. Together with
547+
`cirq_type_from_json`, it can be used to provide type-object serialization
548+
for classes that need it.
549+
"""
550+
namespace = json_namespace(type_obj)
551+
if namespace:
552+
return f'{namespace}.{type_obj.__name__}'
553+
return type_obj.__name__
554+
555+
556+
def factory_from_json(
557+
type_str: str, resolvers: Optional[Sequence[JsonResolver]] = None
558+
) -> ObjectFactory:
559+
"""Returns a factory for constructing objects of type `type_str`.
560+
561+
DEFAULT_RESOLVERS is updated dynamically as cirq submodules are imported.
562+
563+
Args:
564+
type_str: string representation of the type to deserialize.
565+
resolvers: list of JsonResolvers to use in type resolution. If this is
566+
left blank, DEFAULT_RESOLVERS will be used.
567+
568+
Returns:
569+
An ObjectFactory that can be called to construct an object whose type
570+
matches the name `type_str`.
571+
572+
Raises:
573+
ValueError: if type_str does not have a match in `resolvers`.
574+
"""
575+
resolvers = resolvers if resolvers is not None else DEFAULT_RESOLVERS
576+
for resolver in resolvers:
577+
cirq_type = resolver(type_str)
578+
if cirq_type is not None:
579+
return cirq_type
580+
raise ValueError(f"Could not resolve type '{type_str}' during deserialization")
581+
582+
583+
def cirq_type_from_json(type_str: str, resolvers: Optional[Sequence[JsonResolver]] = None) -> Type:
584+
"""Returns a type object for JSON deserialization of `type_str`.
585+
586+
This method is not part of the base deserialization path. Together with
587+
`json_cirq_type`, it can be used to provide type-object deserialization
588+
for classes that need it.
589+
590+
Args:
591+
type_str: string representation of the type to deserialize.
592+
resolvers: list of JsonResolvers to use in type resolution. If this is
593+
left blank, DEFAULT_RESOLVERS will be used.
594+
595+
Returns:
596+
The type object T for which json_cirq_type(T) matches `type_str`.
597+
598+
Raises:
599+
ValueError: if type_str does not have a match in `resolvers`, or if the
600+
match found is a factory method instead of a type.
601+
"""
602+
cirq_type = factory_from_json(type_str, resolvers)
603+
if isinstance(cirq_type, type):
604+
return cirq_type
605+
# We assume that if factory_from_json returns a factory, there is not
606+
# another resolver which resolves `type_str` to a type object.
607+
raise ValueError(f"Type {type_str} maps to a factory method instead of a type.")
608+
609+
508610
# pylint: disable=function-redefined
509611
@overload
510612
def to_json(

cirq-core/cirq/protocols/json_serialization_test.py

+68-1
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import pathlib
2222
import sys
2323
import warnings
24-
from typing import Dict, List, Optional, Tuple
24+
from typing import Dict, List, Optional, Tuple, Type
2525
from unittest import mock
2626

2727
import numpy as np
@@ -534,6 +534,73 @@ def test_json_test_data_coverage(mod_spec: ModuleJsonTestSpec, cirq_obj_name: st
534534
)
535535

536536

537+
@dataclasses.dataclass
538+
class SerializableTypeObject:
539+
test_type: Type
540+
541+
def _json_dict_(self):
542+
return {
543+
'cirq_type': 'SerializableTypeObject',
544+
'test_type': json_serialization.json_cirq_type(self.test_type),
545+
}
546+
547+
@classmethod
548+
def _from_json_dict_(cls, test_type, **kwargs):
549+
return cls(json_serialization.cirq_type_from_json(test_type))
550+
551+
552+
@pytest.mark.parametrize(
553+
'mod_spec,cirq_obj_name,cls',
554+
_list_public_classes_for_tested_modules(),
555+
)
556+
def test_type_serialization(mod_spec: ModuleJsonTestSpec, cirq_obj_name: str, cls):
557+
if cirq_obj_name in mod_spec.tested_elsewhere:
558+
pytest.skip("Tested elsewhere.")
559+
560+
if cirq_obj_name in mod_spec.not_yet_serializable:
561+
return pytest.xfail(reason="Not serializable (yet)")
562+
563+
if cls is None:
564+
pytest.skip(f'No serialization for None-mapped type: {cirq_obj_name}')
565+
566+
try:
567+
typename = cirq.json_cirq_type(cls)
568+
except ValueError as e:
569+
pytest.skip(f'No serialization for non-Cirq type: {str(e)}')
570+
571+
def custom_resolver(name):
572+
if name == 'SerializableTypeObject':
573+
return SerializableTypeObject
574+
575+
sto = SerializableTypeObject(cls)
576+
test_resolvers = [custom_resolver] + cirq.DEFAULT_RESOLVERS
577+
expected_json = (
578+
f'{{\n "cirq_type": "SerializableTypeObject",\n' f' "test_type": "{typename}"\n}}'
579+
)
580+
assert cirq.to_json(sto) == expected_json
581+
assert cirq.read_json(json_text=expected_json, resolvers=test_resolvers) == sto
582+
assert_json_roundtrip_works(sto, resolvers=test_resolvers)
583+
584+
585+
def test_invalid_type_deserialize():
586+
def custom_resolver(name):
587+
if name == 'SerializableTypeObject':
588+
return SerializableTypeObject
589+
590+
test_resolvers = [custom_resolver] + cirq.DEFAULT_RESOLVERS
591+
invalid_json = (
592+
f'{{\n "cirq_type": "SerializableTypeObject",\n' f' "test_type": "bad_type"\n}}'
593+
)
594+
with pytest.raises(ValueError, match='Could not resolve type'):
595+
_ = cirq.read_json(json_text=invalid_json, resolvers=test_resolvers)
596+
597+
factory_json = (
598+
f'{{\n "cirq_type": "SerializableTypeObject",\n' f' "test_type": "sympy.Add"\n}}'
599+
)
600+
with pytest.raises(ValueError, match='maps to a factory method'):
601+
_ = cirq.read_json(json_text=factory_json, resolvers=test_resolvers)
602+
603+
537604
def test_to_from_strings():
538605
x_json_text = """{
539606
"cirq_type": "_PauliX",

cirq-core/cirq/protocols/json_test_data/spec.py

+1
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,7 @@
130130
'SimulatesFinalState',
131131
'NamedTopology',
132132
# protocols:
133+
'HasJSONNamespace',
133134
'SupportsActOn',
134135
'SupportsActOnQubits',
135136
'SupportsApplyChannel',

cirq-google/cirq_google/workflow/io.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,12 @@ def load(self, *, base_data_dir: str = ".") -> 'cg.ExecutableGroupResult':
6666
],
6767
)
6868

69+
@classmethod
70+
def _json_namespace_(cls) -> str:
71+
return 'cirq.google'
72+
6973
def _json_dict_(self) -> Dict[str, Any]:
70-
return dataclass_json_dict(self, namespace='cirq.google')
74+
return dataclass_json_dict(self, namespace=cirq.json_namespace(type(self)))
7175

7276
def __repr__(self) -> str:
7377
return _compat.dataclass_repr(self, namespace='cirq_google')

cirq-google/cirq_google/workflow/quantum_executable.py

+20-4
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,12 @@ class KeyValueExecutableSpec(ExecutableSpec):
5353
executable_family: str
5454
key_value_pairs: Tuple[Tuple[str, Any], ...] = ()
5555

56+
@classmethod
57+
def _json_namespace_(cls) -> str:
58+
return 'cirq.google'
59+
5660
def _json_dict_(self) -> Dict[str, Any]:
57-
return cirq.dataclass_json_dict(self, namespace='cirq.google')
61+
return cirq.dataclass_json_dict(self, namespace=cirq.json_namespace(type(self)))
5862

5963
@classmethod
6064
def from_dict(cls, d: Dict[str, Any], *, executable_family: str) -> 'KeyValueExecutableSpec':
@@ -90,8 +94,12 @@ class BitstringsMeasurement:
9094

9195
n_repetitions: int
9296

97+
@classmethod
98+
def _json_namespace_(cls) -> str:
99+
return 'cirq.google'
100+
93101
def _json_dict_(self):
94-
return cirq.dataclass_json_dict(self, namespace='cirq.google')
102+
return cirq.dataclass_json_dict(self, namespace=cirq.json_namespace(type(self)))
95103

96104
def __repr__(self):
97105
return cirq._compat.dataclass_repr(self, namespace='cirq_google')
@@ -198,8 +206,12 @@ def __str__(self):
198206
def __repr__(self):
199207
return _compat.dataclass_repr(self, namespace='cirq_google')
200208

209+
@classmethod
210+
def _json_namespace_(cls) -> str:
211+
return 'cirq.google'
212+
201213
def _json_dict_(self):
202-
return cirq.dataclass_json_dict(self, namespace='cirq.google')
214+
return cirq.dataclass_json_dict(self, namespace=cirq.json_namespace(type(self)))
203215

204216

205217
@dataclass(frozen=True)
@@ -248,5 +260,9 @@ def __repr__(self) -> str:
248260
def __hash__(self) -> int:
249261
return self._hash # type: ignore
250262

263+
@classmethod
264+
def _json_namespace_(cls) -> str:
265+
return 'cirq.google'
266+
251267
def _json_dict_(self) -> Dict[str, Any]:
252-
return cirq.dataclass_json_dict(self, namespace='cirq.google')
268+
return cirq.dataclass_json_dict(self, namespace=cirq.json_namespace(type(self)))

0 commit comments

Comments
 (0)