Skip to content

Commit d027138

Browse files
mpharriganrht
authored andcommitted
dataclass_json_dict (quantumlib#4391)
Dataclasses keep track of their relevant fields, so we can automatically generate these. Dataclasses are implemented with somewhat complex metaprogramming, and tooling (PyCharm, mypy) have special cases for dealing with classes decorated with `@dataclass`. There is very little support (and no plans for support) for decorators that wrap `@dataclass` (like `@cirq.json_serializable_dataclass`) or combining additional decorators with `@dataclass`. Although not as elegant, you may want to consider explicitly defining `_json_dict_` on your dataclasses which simply `return dataclass_json_dict(self)`.
1 parent df9c9e8 commit d027138

File tree

5 files changed

+50
-12
lines changed

5 files changed

+50
-12
lines changed

cirq-core/cirq/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -512,6 +512,7 @@
512512
is_parameterized,
513513
JsonResolver,
514514
json_serializable_dataclass,
515+
dataclass_json_dict,
515516
kraus,
516517
measurement_key,
517518
measurement_key_name,

cirq-core/cirq/experiments/xeb_fitting.py

+8-12
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
"""Estimation of fidelity associated with experimental circuit executions."""
15+
import dataclasses
1516
from abc import abstractmethod, ABC
16-
from dataclasses import dataclass
1717
from typing import (
1818
List,
1919
Optional,
@@ -29,20 +29,14 @@
2929
import scipy.optimize
3030
import scipy.stats
3131
import sympy
32-
33-
from cirq import ops
32+
from cirq import ops, protocols
3433
from cirq.circuits import Circuit
3534
from cirq.experiments.xeb_simulation import simulate_2q_xeb_circuits
3635

3736
if TYPE_CHECKING:
3837
import cirq
3938
import multiprocessing
4039

41-
# Workaround for mypy custom dataclasses (python/mypy#5406)
42-
from dataclasses import dataclass as json_serializable_dataclass
43-
else:
44-
from cirq.protocols import json_serializable_dataclass
45-
4640
THETA_SYMBOL, ZETA_SYMBOL, CHI_SYMBOL, GAMMA_SYMBOL, PHI_SYMBOL = sympy.symbols(
4741
'theta zeta chi gamma phi'
4842
)
@@ -191,8 +185,7 @@ def phased_fsim_angles_from_gate(gate: 'cirq.Gate') -> Dict[str, float]:
191185
raise ValueError(f"Unknown default angles for {gate}.")
192186

193187

194-
# mypy issue: https://github.com/python/mypy/issues/5374
195-
@json_serializable_dataclass(frozen=True) # type: ignore
188+
@dataclasses.dataclass(frozen=True)
196189
class XEBPhasedFSimCharacterizationOptions(XEBCharacterizationOptions):
197190
"""Options for calibrating a PhasedFSim-like gate using XEB.
198191
@@ -320,6 +313,9 @@ def with_defaults_from_gate(
320313
**gate_to_angles_func(gate),
321314
)
322315

316+
def _json_dict_(self):
317+
return protocols.dataclass_json_dict(self)
318+
323319

324320
def SqrtISwapXEBOptions(*args, **kwargs):
325321
"""Options for calibrating a sqrt(ISWAP) gate using XEB."""
@@ -346,7 +342,7 @@ def parameterize_circuit(
346342
QPair_T = Tuple['cirq.Qid', 'cirq.Qid']
347343

348344

349-
@dataclass(frozen=True)
345+
@dataclasses.dataclass(frozen=True)
350346
class XEBCharacterizationResult:
351347
"""The result of `characterize_phased_fsim_parameters_with_xeb`.
352348
@@ -437,7 +433,7 @@ def _mean_infidelity(angles):
437433
)
438434

439435

440-
@dataclass(frozen=True)
436+
@dataclasses.dataclass(frozen=True)
441437
class _CharacterizePhasedFsimParametersWithXebClosure:
442438
"""A closure object to wrap `characterize_phased_fsim_parameters_with_xeb` for use in
443439
multiprocessing."""

cirq-core/cirq/protocols/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@
8888
to_json,
8989
read_json,
9090
obj_to_dict_helper,
91+
dataclass_json_dict,
9192
SerializableByKey,
9293
SupportsJSON,
9394
)

cirq-core/cirq/protocols/json_serialization.py

+22
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,13 @@ def json_serializable_dataclass(
177177
the ``_json_dict_`` protocol method which automatically determines
178178
the appropriate fields from the dataclass.
179179
180+
Dataclasses are implemented with somewhat complex metaprogramming, and
181+
tooling (PyCharm, mypy) have special cases for dealing with classes
182+
decorated with @dataclass. There is very little support (and no plans for
183+
support) for decorators that wrap @dataclass like this. Consider explicitly
184+
defining `_json_dict_` on your dataclasses which simply
185+
`return dataclass_json_dict(self)`.
186+
180187
Args:
181188
namespace: An optional prefix to the value associated with the
182189
key "cirq_type". The namespace name will be joined with the
@@ -209,6 +216,21 @@ def wrap(cls):
209216
# pylint: enable=redefined-builtin
210217

211218

219+
def dataclass_json_dict(obj: Any, namespace: str = None) -> Dict[str, Any]:
220+
"""Return a dictionary suitable for _json_dict_ from a dataclass.
221+
222+
Dataclasses keep track of their relevant fields, so we can automatically generate these.
223+
224+
Dataclasses are implemented with somewhat complex metaprogramming, and tooling (PyCharm, mypy)
225+
have special cases for dealing with classes decorated with @dataclass. There is very little
226+
support (and no plans for support) for decorators that wrap @dataclass (like
227+
@cirq.json_serializable_dataclass) or combining additional decorators with @dataclass.
228+
Although not as elegant, you may want to consider explicitly defining `_json_dict_` on your
229+
dataclasses which simply `return dataclass_json_dict(self)`.
230+
"""
231+
return obj_to_dict_helper(obj, [f.name for f in dataclasses.fields(obj)], namespace=namespace)
232+
233+
212234
class CirqEncoder(json.JSONEncoder):
213235
"""Extend json.JSONEncoder to support Cirq objects.
214236

cirq-core/cirq/protocols/json_serialization_test.py

+18
Original file line numberDiff line numberDiff line change
@@ -775,6 +775,24 @@ def custom_resolver(name):
775775
assert_json_roundtrip_works(my_dc, resolvers=[custom_resolver] + cirq.DEFAULT_RESOLVERS)
776776

777777

778+
def test_dataclass_json_dict():
779+
@dataclasses.dataclass(frozen=True)
780+
class MyDC:
781+
q: cirq.LineQubit
782+
desc: str
783+
784+
def _json_dict_(self):
785+
return cirq.dataclass_json_dict(self)
786+
787+
def custom_resolver(name):
788+
if name == 'MyDC':
789+
return MyDC
790+
791+
my_dc = MyDC(cirq.LineQubit(4), 'hi mom')
792+
793+
assert_json_roundtrip_works(my_dc, resolvers=[custom_resolver, *cirq.DEFAULT_RESOLVERS])
794+
795+
778796
def test_json_serializable_dataclass_namespace():
779797
@cirq.json_serializable_dataclass(namespace='cirq.experiments')
780798
class QuantumVolumeParams:

0 commit comments

Comments
 (0)