diff --git a/cirq-core/cirq/__init__.py b/cirq-core/cirq/__init__.py index 062423bf431..2923d10979f 100644 --- a/cirq-core/cirq/__init__.py +++ b/cirq-core/cirq/__init__.py @@ -508,6 +508,7 @@ circuit_diagram_info, CircuitDiagramInfo, CircuitDiagramInfoArgs, + cirq_type_from_json, commutes, control_keys, decompose, @@ -520,10 +521,13 @@ has_mixture, has_stabilizer_effect, has_unitary, + HasJSONNamespace, inverse, is_measurement, is_parameterized, JsonResolver, + json_cirq_type, + json_namespace, json_serializable_dataclass, dataclass_json_dict, kraus, diff --git a/cirq-core/cirq/json_resolver_cache.py b/cirq-core/cirq/json_resolver_cache.py index 04e1e00487a..1237716c7ef 100644 --- a/cirq-core/cirq/json_resolver_cache.py +++ b/cirq-core/cirq/json_resolver_cache.py @@ -91,7 +91,6 @@ def _parallel_gate_op(gate, qubits): 'HPowGate': cirq.HPowGate, 'ISwapPowGate': cirq.ISwapPowGate, 'IdentityGate': cirq.IdentityGate, - 'IdentityOperation': _identity_operation_from_dict, 'InitObsSetting': cirq.work.InitObsSetting, 'KrausChannel': cirq.KrausChannel, 'LinearDict': cirq.LinearDict, @@ -115,7 +114,6 @@ def _parallel_gate_op(gate, qubits): '_PauliY': cirq.ops.pauli_gates._PauliY, '_PauliZ': cirq.ops.pauli_gates._PauliZ, 'ParamResolver': cirq.ParamResolver, - 'ParallelGateOperation': _parallel_gate_op, # Removed in v0.14 'ParallelGate': cirq.ParallelGate, 'PauliMeasurementGate': cirq.PauliMeasurementGate, 'PauliString': cirq.PauliString, @@ -134,7 +132,6 @@ def _parallel_gate_op(gate, qubits): 'RepetitionsStoppingCriteria': cirq.work.RepetitionsStoppingCriteria, 'ResetChannel': cirq.ResetChannel, 'SingleQubitCliffordGate': cirq.SingleQubitCliffordGate, - 'SingleQubitMatrixGate': single_qubit_matrix_gate, 'SingleQubitPauliStringGateOperation': cirq.SingleQubitPauliStringGateOperation, 'SingleQubitReadoutCalibrationResult': cirq.experiments.SingleQubitReadoutCalibrationResult, 'StabilizerStateChForm': cirq.StabilizerStateChForm, @@ -147,7 +144,6 @@ def _parallel_gate_op(gate, qubits): 'Rx': cirq.Rx, 'Ry': cirq.Ry, 'Rz': cirq.Rz, - 'TwoQubitMatrixGate': two_qubit_matrix_gate, '_UnconstrainedDevice': cirq.devices.unconstrained_device._UnconstrainedDevice, 'VarianceStoppingCriteria': cirq.work.VarianceStoppingCriteria, 'VirtualTag': cirq.VirtualTag, @@ -163,6 +159,11 @@ def _parallel_gate_op(gate, qubits): 'YYPowGate': cirq.YYPowGate, 'ZPowGate': cirq.ZPowGate, 'ZZPowGate': cirq.ZZPowGate, + # Old types, only supported for backwards-compatibility + 'IdentityOperation': _identity_operation_from_dict, + 'ParallelGateOperation': _parallel_gate_op, # Removed in v0.14 + 'SingleQubitMatrixGate': single_qubit_matrix_gate, + 'TwoQubitMatrixGate': two_qubit_matrix_gate, # not a cirq class, but treated as one: 'pandas.DataFrame': pd.DataFrame, 'pandas.Index': pd.Index, diff --git a/cirq-core/cirq/protocols/__init__.py b/cirq-core/cirq/protocols/__init__.py index 6e16edb630e..a56f0f7965a 100644 --- a/cirq-core/cirq/protocols/__init__.py +++ b/cirq-core/cirq/protocols/__init__.py @@ -82,9 +82,13 @@ inverse, ) from cirq.protocols.json_serialization import ( + cirq_type_from_json, DEFAULT_RESOLVERS, + HasJSONNamespace, JsonResolver, json_serializable_dataclass, + json_cirq_type, + json_namespace, to_json_gzip, read_json_gzip, to_json, diff --git a/cirq-core/cirq/protocols/json_serialization.py b/cirq-core/cirq/protocols/json_serialization.py index a1e2d5f772c..b3ddb386b91 100644 --- a/cirq-core/cirq/protocols/json_serialization.py +++ b/cirq-core/cirq/protocols/json_serialization.py @@ -128,6 +128,23 @@ def _json_dict_(self) -> Union[None, NotImplementedType, Dict[Any, Any]]: pass +class HasJSONNamespace(Protocol): + """An object which prepends a namespace to its JSON cirq_type. + + Classes which implement this method have the following cirq_type format: + + f"{obj._json_namespace_()}.{obj.__class__.__name__} + + Classes outside of Cirq or its submodules MUST implement this method to be + used in type serialization. + """ + + @doc_private + @classmethod + def _json_namespace_(cls) -> str: + pass + + def obj_to_dict_helper( obj: Any, attribute_names: Iterable[str], namespace: Optional[str] = None ) -> Dict[str, Any]: @@ -350,13 +367,7 @@ def _cirq_object_hook(d, resolvers: Sequence[JsonResolver], context_map: Dict[st if d['cirq_type'] == '_ContextualSerialization': return _ContextualSerialization.deserialize_with_context(**d) - for resolver in resolvers: - cls = resolver(d['cirq_type']) - if cls is not None: - break - else: - raise ValueError(f"Could not resolve type '{d['cirq_type']}' during deserialization") - + cls = factory_from_json(d['cirq_type'], resolvers=resolvers) from_json_dict = getattr(cls, '_from_json_dict_', None) if from_json_dict is not None: return from_json_dict(**d) @@ -505,6 +516,97 @@ def get_serializable_by_keys(obj: Any) -> List[SerializableByKey]: return [] +def json_namespace(type_obj: Type) -> str: + """Returns a namespace for JSON serialization of `type_obj`. + + Types can provide custom namespaces with `_json_namespace_`; otherwise, a + Cirq type will not include a namespace in its cirq_type. Non-Cirq types + must provide a namespace for serialization in Cirq. + + Args: + type_obj: Type to retrieve the namespace from. + + Returns: + The namespace to prepend `type_obj` with in its JSON cirq_type. + + Raises: + ValueError: if `type_obj` is not a Cirq type and does not explicitly + define its namespace with _json_namespace_. + """ + if hasattr(type_obj, '_json_namespace_'): + return type_obj._json_namespace_() + if type_obj.__module__.startswith('cirq'): + return '' + raise ValueError(f'{type_obj} is not a Cirq type, and does not define _json_namespace_.') + + +def json_cirq_type(type_obj: Type) -> str: + """Returns a string type for JSON serialization of `type_obj`. + + This method is not part of the base serialization path. Together with + `cirq_type_from_json`, it can be used to provide type-object serialization + for classes that need it. + """ + namespace = json_namespace(type_obj) + if namespace: + return f'{namespace}.{type_obj.__name__}' + return type_obj.__name__ + + +def factory_from_json( + type_str: str, resolvers: Optional[Sequence[JsonResolver]] = None +) -> ObjectFactory: + """Returns a factory for constructing objects of type `type_str`. + + DEFAULT_RESOLVERS is updated dynamically as cirq submodules are imported. + + Args: + type_str: string representation of the type to deserialize. + resolvers: list of JsonResolvers to use in type resolution. If this is + left blank, DEFAULT_RESOLVERS will be used. + + Returns: + An ObjectFactory that can be called to construct an object whose type + matches the name `type_str`. + + Raises: + ValueError: if type_str does not have a match in `resolvers`. + """ + resolvers = resolvers if resolvers is not None else DEFAULT_RESOLVERS + for resolver in resolvers: + cirq_type = resolver(type_str) + if cirq_type is not None: + return cirq_type + raise ValueError(f"Could not resolve type '{type_str}' during deserialization") + + +def cirq_type_from_json(type_str: str, resolvers: Optional[Sequence[JsonResolver]] = None) -> Type: + """Returns a type object for JSON deserialization of `type_str`. + + This method is not part of the base deserialization path. Together with + `json_cirq_type`, it can be used to provide type-object deserialization + for classes that need it. + + Args: + type_str: string representation of the type to deserialize. + resolvers: list of JsonResolvers to use in type resolution. If this is + left blank, DEFAULT_RESOLVERS will be used. + + Returns: + The type object T for which json_cirq_type(T) matches `type_str`. + + Raises: + ValueError: if type_str does not have a match in `resolvers`, or if the + match found is a factory method instead of a type. + """ + cirq_type = factory_from_json(type_str, resolvers) + if isinstance(cirq_type, type): + return cirq_type + # We assume that if factory_from_json returns a factory, there is not + # another resolver which resolves `type_str` to a type object. + raise ValueError(f"Type {type_str} maps to a factory method instead of a type.") + + # pylint: disable=function-redefined @overload def to_json( diff --git a/cirq-core/cirq/protocols/json_serialization_test.py b/cirq-core/cirq/protocols/json_serialization_test.py index 8fe5a305b13..d3e67f61760 100644 --- a/cirq-core/cirq/protocols/json_serialization_test.py +++ b/cirq-core/cirq/protocols/json_serialization_test.py @@ -21,7 +21,7 @@ import pathlib import sys import warnings -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Tuple, Type from unittest import mock import numpy as np @@ -534,6 +534,73 @@ def test_json_test_data_coverage(mod_spec: ModuleJsonTestSpec, cirq_obj_name: st ) +@dataclasses.dataclass +class SerializableTypeObject: + test_type: Type + + def _json_dict_(self): + return { + 'cirq_type': 'SerializableTypeObject', + 'test_type': json_serialization.json_cirq_type(self.test_type), + } + + @classmethod + def _from_json_dict_(cls, test_type, **kwargs): + return cls(json_serialization.cirq_type_from_json(test_type)) + + +@pytest.mark.parametrize( + 'mod_spec,cirq_obj_name,cls', + _list_public_classes_for_tested_modules(), +) +def test_type_serialization(mod_spec: ModuleJsonTestSpec, cirq_obj_name: str, cls): + if cirq_obj_name in mod_spec.tested_elsewhere: + pytest.skip("Tested elsewhere.") + + if cirq_obj_name in mod_spec.not_yet_serializable: + return pytest.xfail(reason="Not serializable (yet)") + + if cls is None: + pytest.skip(f'No serialization for None-mapped type: {cirq_obj_name}') + + try: + typename = cirq.json_cirq_type(cls) + except ValueError as e: + pytest.skip(f'No serialization for non-Cirq type: {str(e)}') + + def custom_resolver(name): + if name == 'SerializableTypeObject': + return SerializableTypeObject + + sto = SerializableTypeObject(cls) + test_resolvers = [custom_resolver] + cirq.DEFAULT_RESOLVERS + expected_json = ( + f'{{\n "cirq_type": "SerializableTypeObject",\n' f' "test_type": "{typename}"\n}}' + ) + assert cirq.to_json(sto) == expected_json + assert cirq.read_json(json_text=expected_json, resolvers=test_resolvers) == sto + assert_json_roundtrip_works(sto, resolvers=test_resolvers) + + +def test_invalid_type_deserialize(): + def custom_resolver(name): + if name == 'SerializableTypeObject': + return SerializableTypeObject + + test_resolvers = [custom_resolver] + cirq.DEFAULT_RESOLVERS + invalid_json = ( + f'{{\n "cirq_type": "SerializableTypeObject",\n' f' "test_type": "bad_type"\n}}' + ) + with pytest.raises(ValueError, match='Could not resolve type'): + _ = cirq.read_json(json_text=invalid_json, resolvers=test_resolvers) + + factory_json = ( + f'{{\n "cirq_type": "SerializableTypeObject",\n' f' "test_type": "sympy.Add"\n}}' + ) + with pytest.raises(ValueError, match='maps to a factory method'): + _ = cirq.read_json(json_text=factory_json, resolvers=test_resolvers) + + def test_to_from_strings(): x_json_text = """{ "cirq_type": "_PauliX", diff --git a/cirq-core/cirq/protocols/json_test_data/spec.py b/cirq-core/cirq/protocols/json_test_data/spec.py index 11f87518c27..ed52f0569e8 100644 --- a/cirq-core/cirq/protocols/json_test_data/spec.py +++ b/cirq-core/cirq/protocols/json_test_data/spec.py @@ -130,6 +130,7 @@ 'SimulatesFinalState', 'NamedTopology', # protocols: + 'HasJSONNamespace', 'SupportsActOn', 'SupportsActOnQubits', 'SupportsApplyChannel', diff --git a/cirq-google/cirq_google/workflow/io.py b/cirq-google/cirq_google/workflow/io.py index 53b127c85fa..544cab16801 100644 --- a/cirq-google/cirq_google/workflow/io.py +++ b/cirq-google/cirq_google/workflow/io.py @@ -66,8 +66,12 @@ def load(self, *, base_data_dir: str = ".") -> 'cg.ExecutableGroupResult': ], ) + @classmethod + def _json_namespace_(cls) -> str: + return 'cirq.google' + def _json_dict_(self) -> Dict[str, Any]: - return dataclass_json_dict(self, namespace='cirq.google') + return dataclass_json_dict(self, namespace=cirq.json_namespace(type(self))) def __repr__(self) -> str: return _compat.dataclass_repr(self, namespace='cirq_google') diff --git a/cirq-google/cirq_google/workflow/quantum_executable.py b/cirq-google/cirq_google/workflow/quantum_executable.py index d1d4a52d1bb..b11dc475211 100644 --- a/cirq-google/cirq_google/workflow/quantum_executable.py +++ b/cirq-google/cirq_google/workflow/quantum_executable.py @@ -53,8 +53,12 @@ class KeyValueExecutableSpec(ExecutableSpec): executable_family: str key_value_pairs: Tuple[Tuple[str, Any], ...] = () + @classmethod + def _json_namespace_(cls) -> str: + return 'cirq.google' + def _json_dict_(self) -> Dict[str, Any]: - return cirq.dataclass_json_dict(self, namespace='cirq.google') + return cirq.dataclass_json_dict(self, namespace=cirq.json_namespace(type(self))) @classmethod def from_dict(cls, d: Dict[str, Any], *, executable_family: str) -> 'KeyValueExecutableSpec': @@ -90,8 +94,12 @@ class BitstringsMeasurement: n_repetitions: int + @classmethod + def _json_namespace_(cls) -> str: + return 'cirq.google' + def _json_dict_(self): - return cirq.dataclass_json_dict(self, namespace='cirq.google') + return cirq.dataclass_json_dict(self, namespace=cirq.json_namespace(type(self))) def __repr__(self): return cirq._compat.dataclass_repr(self, namespace='cirq_google') @@ -198,8 +206,12 @@ def __str__(self): def __repr__(self): return _compat.dataclass_repr(self, namespace='cirq_google') + @classmethod + def _json_namespace_(cls) -> str: + return 'cirq.google' + def _json_dict_(self): - return cirq.dataclass_json_dict(self, namespace='cirq.google') + return cirq.dataclass_json_dict(self, namespace=cirq.json_namespace(type(self))) @dataclass(frozen=True) @@ -248,5 +260,9 @@ def __repr__(self) -> str: def __hash__(self) -> int: return self._hash # type: ignore + @classmethod + def _json_namespace_(cls) -> str: + return 'cirq.google' + def _json_dict_(self) -> Dict[str, Any]: - return cirq.dataclass_json_dict(self, namespace='cirq.google') + return cirq.dataclass_json_dict(self, namespace=cirq.json_namespace(type(self))) diff --git a/cirq-google/cirq_google/workflow/quantum_runtime.py b/cirq-google/cirq_google/workflow/quantum_runtime.py index 12301aa97ce..f6a09d5b4ba 100644 --- a/cirq-google/cirq_google/workflow/quantum_runtime.py +++ b/cirq-google/cirq_google/workflow/quantum_runtime.py @@ -43,8 +43,12 @@ class SharedRuntimeInfo: run_id: str + @classmethod + def _json_namespace_(cls) -> str: + return 'cirq.google' + def _json_dict_(self) -> Dict[str, Any]: - return dataclass_json_dict(self, namespace='cirq.google') + return dataclass_json_dict(self, namespace=cirq.json_namespace(type(self))) def __repr__(self) -> str: return _compat.dataclass_repr(self, namespace='cirq_google') @@ -63,8 +67,12 @@ class RuntimeInfo: execution_index: int + @classmethod + def _json_namespace_(cls) -> str: + return 'cirq.google' + def _json_dict_(self) -> Dict[str, Any]: - return dataclass_json_dict(self, namespace='cirq.google') + return dataclass_json_dict(self, namespace=cirq.json_namespace(type(self))) def __repr__(self) -> str: return _compat.dataclass_repr(self, namespace='cirq_google') @@ -85,8 +93,12 @@ class ExecutableResult: runtime_info: RuntimeInfo raw_data: cirq.Result + @classmethod + def _json_namespace_(cls) -> str: + return 'cirq.google' + def _json_dict_(self) -> Dict[str, Any]: - return dataclass_json_dict(self, namespace='cirq.google') + return dataclass_json_dict(self, namespace=cirq.json_namespace(type(self))) def __repr__(self) -> str: return _compat.dataclass_repr(self, namespace='cirq_google')