Skip to content

Support and test type serialization #4693

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Nov 23, 2021
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions cirq-core/cirq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,6 +502,7 @@
circuit_diagram_info,
CircuitDiagramInfo,
CircuitDiagramInfoArgs,
cirq_type_from_json,
commutes,
control_keys,
decompose,
Expand All @@ -514,10 +515,13 @@
has_mixture,
has_stabilizer_effect,
has_unitary,
HasJSONNamespace,
inverse,
is_measurement,
is_parameterized,
JsonResolver,
json_cirq_type,
json_namespace,
json_serializable_dataclass,
dataclass_json_dict,
kraus,
Expand Down
9 changes: 5 additions & 4 deletions cirq-core/cirq/json_resolver_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,6 @@ def _parallel_gate_op(gate, qubits):
'HPowGate': cirq.HPowGate,
'ISwapPowGate': cirq.ISwapPowGate,
'IdentityGate': cirq.IdentityGate,
'IdentityOperation': _identity_operation_from_dict,
'InitObsSetting': cirq.work.InitObsSetting,
'KrausChannel': cirq.KrausChannel,
'LinearDict': cirq.LinearDict,
Expand All @@ -115,7 +114,6 @@ def _parallel_gate_op(gate, qubits):
'_PauliY': cirq.ops.pauli_gates._PauliY,
'_PauliZ': cirq.ops.pauli_gates._PauliZ,
'ParamResolver': cirq.ParamResolver,
'ParallelGateOperation': _parallel_gate_op, # Removed in v0.14
'ParallelGate': cirq.ParallelGate,
'PauliMeasurementGate': cirq.PauliMeasurementGate,
'PauliString': cirq.PauliString,
Expand All @@ -134,7 +132,6 @@ def _parallel_gate_op(gate, qubits):
'RepetitionsStoppingCriteria': cirq.work.RepetitionsStoppingCriteria,
'ResetChannel': cirq.ResetChannel,
'SingleQubitCliffordGate': cirq.SingleQubitCliffordGate,
'SingleQubitMatrixGate': single_qubit_matrix_gate,
'SingleQubitPauliStringGateOperation': cirq.SingleQubitPauliStringGateOperation,
'SingleQubitReadoutCalibrationResult': cirq.experiments.SingleQubitReadoutCalibrationResult,
'StabilizerStateChForm': cirq.StabilizerStateChForm,
Expand All @@ -147,7 +144,6 @@ def _parallel_gate_op(gate, qubits):
'Rx': cirq.Rx,
'Ry': cirq.Ry,
'Rz': cirq.Rz,
'TwoQubitMatrixGate': two_qubit_matrix_gate,
'_UnconstrainedDevice': cirq.devices.unconstrained_device._UnconstrainedDevice,
'VarianceStoppingCriteria': cirq.work.VarianceStoppingCriteria,
'VirtualTag': cirq.VirtualTag,
Expand All @@ -163,6 +159,11 @@ def _parallel_gate_op(gate, qubits):
'YYPowGate': cirq.YYPowGate,
'ZPowGate': cirq.ZPowGate,
'ZZPowGate': cirq.ZZPowGate,
# Old types, only supported for backwards-compatibility
'IdentityOperation': _identity_operation_from_dict,
'ParallelGateOperation': _parallel_gate_op, # Removed in v0.14
'SingleQubitMatrixGate': single_qubit_matrix_gate,
'TwoQubitMatrixGate': two_qubit_matrix_gate,
# not a cirq class, but treated as one:
'pandas.DataFrame': pd.DataFrame,
'pandas.Index': pd.Index,
Expand Down
4 changes: 4 additions & 0 deletions cirq-core/cirq/protocols/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,13 @@
inverse,
)
from cirq.protocols.json_serialization import (
cirq_type_from_json,
DEFAULT_RESOLVERS,
HasJSONNamespace,
JsonResolver,
json_serializable_dataclass,
json_cirq_type,
json_namespace,
to_json_gzip,
read_json_gzip,
to_json,
Expand Down
116 changes: 109 additions & 7 deletions cirq-core/cirq/protocols/json_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,23 @@ def _json_dict_(self) -> Union[None, NotImplementedType, Dict[Any, Any]]:
pass


class HasJSONNamespace(Protocol):
"""An object which prepends a namespace to its JSON cirq_type.

Classes which implement this method have the following cirq_type format:

f"{obj._json_namespace_()}.{obj.__class__.__name__}

Classes outside of Cirq or its submodules MUST implement this method to be
used in type serialization.
"""

@doc_private
@classmethod
def _json_namespace_(cls) -> str:
pass


def obj_to_dict_helper(
obj: Any, attribute_names: Iterable[str], namespace: Optional[str] = None
) -> Dict[str, Any]:
Expand Down Expand Up @@ -350,13 +367,7 @@ def _cirq_object_hook(d, resolvers: Sequence[JsonResolver], context_map: Dict[st
if d['cirq_type'] == '_ContextualSerialization':
return _ContextualSerialization.deserialize_with_context(**d)

for resolver in resolvers:
cls = resolver(d['cirq_type'])
if cls is not None:
break
else:
raise ValueError(f"Could not resolve type '{d['cirq_type']}' during deserialization")

cls = factory_from_json(d['cirq_type'], resolvers=resolvers)
from_json_dict = getattr(cls, '_from_json_dict_', None)
if from_json_dict is not None:
return from_json_dict(**d)
Expand Down Expand Up @@ -505,6 +516,97 @@ def get_serializable_by_keys(obj: Any) -> List[SerializableByKey]:
return []


def json_namespace(type_obj: Type) -> str:
"""Returns a namespace for JSON serialization of `type_obj`.

Types can provide custom namespaces with `_json_namespace_`; otherwise, a
Cirq type will not include a namespace in its cirq_type. Non-Cirq types
must provide a namespace for serialization in Cirq.

Args:
type_obj: Type to retrieve the namespace from.

Returns:
The namespace to prepend `type_obj` with in its JSON cirq_type.

Raises:
ValueError: if `type_obj` is not a Cirq type and does not explicitly
define its namespace with _json_namespace_.
"""
if hasattr(type_obj, '_json_namespace_'):
return type_obj._json_namespace_()
if type_obj.__module__.startswith('cirq'):
return ''
raise ValueError(f'{type_obj} is not a Cirq type, and does not define _json_namespace_.')


def json_cirq_type(type_obj: Type) -> str:
"""Returns a string type for JSON serialization of `type_obj`.

This method is not part of the base serialization path. Together with
`cirq_type_from_json`, it can be used to provide type-object serialization
for classes that need it.
"""
namespace = json_namespace(type_obj)
if namespace:
return f'{namespace}.{type_obj.__name__}'
return type_obj.__name__


def factory_from_json(
type_str: str, resolvers: Optional[Sequence[JsonResolver]] = None
) -> ObjectFactory:
"""Returns a factory for constructing objects of type `type_str`.

DEFAULT_RESOLVERS is updated dynamically as cirq submodules are imported.

Args:
type_str: string representation of the type to deserialize.
resolvers: list of JsonResolvers to use in type resolution. If this is
left blank, DEFAULT_RESOLVERS will be used.

Returns:
An ObjectFactory that can be called to construct an object whose type
matches the name `type_str`.

Raises:
ValueError: if type_str does not have a match in `resolvers`.
"""
resolvers = resolvers if resolvers is not None else DEFAULT_RESOLVERS
for resolver in resolvers:
cirq_type = resolver(type_str)
if cirq_type is not None:
return cirq_type
raise ValueError(f"Could not resolve type '{type_str}' during deserialization")


def cirq_type_from_json(type_str: str, resolvers: Optional[Sequence[JsonResolver]] = None) -> Type:
"""Returns a type object for JSON deserialization of `type_str`.

This method is not part of the base deserialization path. Together with
`json_cirq_type`, it can be used to provide type-object deserialization
for classes that need it.

Args:
type_str: string representation of the type to deserialize.
resolvers: list of JsonResolvers to use in type resolution. If this is
left blank, DEFAULT_RESOLVERS will be used.

Returns:
The type object T for which json_cirq_type(T) matches `type_str`.

Raises:
ValueError: if type_str does not have a match in `resolvers`, or if the
match found is a factory method instead of a type.
"""
cirq_type = factory_from_json(type_str, resolvers)
if isinstance(cirq_type, type):
return cirq_type
# We assume that if factory_from_json returns a factory, there is not
# another resolver which resolves `type_str` to a type object.
raise ValueError(f"Type {type_str} maps to a factory method instead of a type.")


# pylint: disable=function-redefined
@overload
def to_json(
Expand Down
69 changes: 68 additions & 1 deletion cirq-core/cirq/protocols/json_serialization_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import pathlib
import sys
import warnings
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Optional, Tuple, Type
from unittest import mock

import numpy as np
Expand Down Expand Up @@ -534,6 +534,73 @@ def test_json_test_data_coverage(mod_spec: ModuleJsonTestSpec, cirq_obj_name: st
)


@dataclasses.dataclass
class SerializableTypeObject:
test_type: Type

def _json_dict_(self):
return {
'cirq_type': 'SerializableTypeObject',
'test_type': json_serialization.json_cirq_type(self.test_type),
}

@classmethod
def _from_json_dict_(cls, test_type, **kwargs):
return cls(json_serialization.cirq_type_from_json(test_type))


@pytest.mark.parametrize(
'mod_spec,cirq_obj_name,cls',
_list_public_classes_for_tested_modules(),
)
def test_type_serialization(mod_spec: ModuleJsonTestSpec, cirq_obj_name: str, cls):
if cirq_obj_name in mod_spec.tested_elsewhere:
pytest.skip("Tested elsewhere.")

if cirq_obj_name in mod_spec.not_yet_serializable:
return pytest.xfail(reason="Not serializable (yet)")

if cls is None:
pytest.skip(f'No serialization for None-mapped type: {cirq_obj_name}')

try:
typename = cirq.json_cirq_type(cls)
except ValueError as e:
pytest.skip(f'No serialization for non-Cirq type: {str(e)}')

def custom_resolver(name):
if name == 'SerializableTypeObject':
return SerializableTypeObject

sto = SerializableTypeObject(cls)
test_resolvers = [custom_resolver] + cirq.DEFAULT_RESOLVERS
expected_json = (
f'{{\n "cirq_type": "SerializableTypeObject",\n' f' "test_type": "{typename}"\n}}'
)
assert cirq.to_json(sto) == expected_json
assert cirq.read_json(json_text=expected_json, resolvers=test_resolvers) == sto
assert_json_roundtrip_works(sto, resolvers=test_resolvers)


def test_invalid_type_deserialize():
def custom_resolver(name):
if name == 'SerializableTypeObject':
return SerializableTypeObject

test_resolvers = [custom_resolver] + cirq.DEFAULT_RESOLVERS
invalid_json = (
f'{{\n "cirq_type": "SerializableTypeObject",\n' f' "test_type": "bad_type"\n}}'
)
with pytest.raises(ValueError, match='Could not resolve type'):
_ = cirq.read_json(json_text=invalid_json, resolvers=test_resolvers)

factory_json = (
f'{{\n "cirq_type": "SerializableTypeObject",\n' f' "test_type": "sympy.Add"\n}}'
)
with pytest.raises(ValueError, match='maps to a factory method'):
_ = cirq.read_json(json_text=factory_json, resolvers=test_resolvers)


def test_to_from_strings():
x_json_text = """{
"cirq_type": "_PauliX",
Expand Down
1 change: 1 addition & 0 deletions cirq-core/cirq/protocols/json_test_data/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@
'SimulatesFinalState',
'NamedTopology',
# protocols:
'HasJSONNamespace',
'SupportsActOn',
'SupportsActOnQubits',
'SupportsApplyChannel',
Expand Down
6 changes: 5 additions & 1 deletion cirq-google/cirq_google/workflow/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,12 @@ def load(self, *, base_data_dir: str = ".") -> 'cg.ExecutableGroupResult':
],
)

@classmethod
def _json_namespace_(cls) -> str:
return 'cirq.google'

def _json_dict_(self) -> Dict[str, Any]:
return dataclass_json_dict(self, namespace='cirq.google')
return dataclass_json_dict(self, namespace=cirq.json_namespace(type(self)))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this seems rather complicated. Why not namespace=self._json_namespace_())?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Calling cirq.json_namespace is more friendly to changing _json_namespace_ in a backwards-compatible way, as it limits the number of places that need to change to the _json_namespace_ invocation in the protocol method. (We encountered something similar for _resolve_parameters_ in the past.)

This is primarily a concern outside of Cirq, but since these are the only examples in Cirq I figured I should use this format to encourage its use elsewhere.


def __repr__(self) -> str:
return _compat.dataclass_repr(self, namespace='cirq_google')
Expand Down
24 changes: 20 additions & 4 deletions cirq-google/cirq_google/workflow/quantum_executable.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,12 @@ class KeyValueExecutableSpec(ExecutableSpec):
executable_family: str
key_value_pairs: Tuple[Tuple[str, Any], ...] = ()

@classmethod
def _json_namespace_(cls) -> str:
return 'cirq.google'

def _json_dict_(self) -> Dict[str, Any]:
return cirq.dataclass_json_dict(self, namespace='cirq.google')
return cirq.dataclass_json_dict(self, namespace=cirq.json_namespace(type(self)))

@classmethod
def from_dict(cls, d: Dict[str, Any], *, executable_family: str) -> 'KeyValueExecutableSpec':
Expand Down Expand Up @@ -90,8 +94,12 @@ class BitstringsMeasurement:

n_repetitions: int

@classmethod
def _json_namespace_(cls) -> str:
return 'cirq.google'

def _json_dict_(self):
return cirq.dataclass_json_dict(self, namespace='cirq.google')
return cirq.dataclass_json_dict(self, namespace=cirq.json_namespace(type(self)))

def __repr__(self):
return cirq._compat.dataclass_repr(self, namespace='cirq_google')
Expand Down Expand Up @@ -198,8 +206,12 @@ def __str__(self):
def __repr__(self):
return _compat.dataclass_repr(self, namespace='cirq_google')

@classmethod
def _json_namespace_(cls) -> str:
return 'cirq.google'

def _json_dict_(self):
return cirq.dataclass_json_dict(self, namespace='cirq.google')
return cirq.dataclass_json_dict(self, namespace=cirq.json_namespace(type(self)))


@dataclass(frozen=True)
Expand Down Expand Up @@ -248,5 +260,9 @@ def __repr__(self) -> str:
def __hash__(self) -> int:
return self._hash # type: ignore

@classmethod
def _json_namespace_(cls) -> str:
return 'cirq.google'

def _json_dict_(self) -> Dict[str, Any]:
return cirq.dataclass_json_dict(self, namespace='cirq.google')
return cirq.dataclass_json_dict(self, namespace=cirq.json_namespace(type(self)))
Loading