Skip to content

Commit 33eea01

Browse files
authored
Rework SerializableByKey handling to improve performance (#6469)
Review: @95-martin-orion
1 parent a4ec796 commit 33eea01

14 files changed

+7585
-3891
lines changed

Diff for: cirq-core/cirq/protocols/json_serialization.py

+50-179
Original file line number | Diff line number | Diff line change
@@ -28,7 +28,6 @@
2828
Optional,
2929
overload,
3030
Sequence,
31-
Set,
3231
Tuple,
3332
Type,
3433
Union,
@@ -221,10 +220,22 @@ class CirqEncoder(json.JSONEncoder):
221220
See https://github.com/quantumlib/Cirq/issues/2014
222221
"""
223222

223+
def __init__(self, *args, **kwargs) -> None:
224+
super().__init__(*args, **kwargs)
225+
self._memo: dict[Any, dict] = {}
226+
224227
def default(self, o):
225228
# Object with custom method?
226229
if hasattr(o, '_json_dict_'):
227-
return _json_dict_with_cirq_type(o)
230+
json_dict = _json_dict_with_cirq_type(o)
231+
if isinstance(o, SerializableByKey):
232+
if ref := self._memo.get(o):
233+
return ref
234+
key = len(self._memo)
235+
ref = {"cirq_type": "REF", "key": key}
236+
self._memo[o] = ref
237+
return {"cirq_type": "VAL", "key": key, "val": json_dict}
238+
return json_dict
228239

229240
# Sympy object? (Must come before general number checks.)
230241
# TODO: More support for sympy
@@ -306,27 +317,46 @@ def default(self, o):
306317
return super().default(o) # pragma: no cover
307318

308319

309-
def _cirq_object_hook(d, resolvers: Sequence[JsonResolver], context_map: Dict[str, Any]):
310-
if 'cirq_type' not in d:
311-
return d
320+
class ObjectHook:
321+
"""Callable to be used as object_hook during deserialization."""
322+
323+
LEGACY_CONTEXT_TYPES = {'_ContextualSerialization', '_SerializedKey', '_SerializedContext'}
324+
325+
def __init__(self, resolvers: Sequence[JsonResolver]) -> None:
326+
self.resolvers = resolvers
327+
self.memo: Dict[int, SerializableByKey] = {}
328+
self.context_map: Dict[int, SerializableByKey] = {}
312329

313-
if d['cirq_type'] == '_SerializedKey':
314-
return _SerializedKey.read_from_context(context_map, **d)
330+
def __call__(self, d):
331+
cirq_type = d.get('cirq_type')
332+
if cirq_type is None:
333+
return d
315334

316-
if d['cirq_type'] == '_SerializedContext':
317-
_SerializedContext.update_context(context_map, **d)
318-
return None
335+
if cirq_type == 'VAL':
336+
obj = d['val']
337+
self.memo[d['key']] = obj
338+
return obj
319339

320-
if d['cirq_type'] == '_ContextualSerialization':
321-
return _ContextualSerialization.deserialize_with_context(**d)
340+
if cirq_type == 'REF':
341+
return self.memo[d['key']]
322342

323-
cls = factory_from_json(d['cirq_type'], resolvers=resolvers)
324-
from_json_dict = getattr(cls, '_from_json_dict_', None)
325-
if from_json_dict is not None:
326-
return from_json_dict(**d)
343+
# Deserialize from legacy "contextual serialization" format
344+
if cirq_type in self.LEGACY_CONTEXT_TYPES:
345+
if cirq_type == '_SerializedKey':
346+
return self.context_map[d['key']]
347+
if cirq_type == '_SerializedContext':
348+
self.context_map[d['key']] = d['obj']
349+
return None
350+
if cirq_type == '_ContextualSerialization':
351+
return d['object_dag'][-1]
327352

328-
del d['cirq_type']
329-
return cls(**d)
353+
cls = factory_from_json(cirq_type, resolvers=self.resolvers)
354+
from_json_dict = getattr(cls, '_from_json_dict_', None)
355+
if from_json_dict is not None:
356+
return from_json_dict(**d)
357+
358+
del d['cirq_type']
359+
return cls(**d)
330360

331361

332362
class SerializableByKey(SupportsJSON):
@@ -338,137 +368,6 @@ class SerializableByKey(SupportsJSON):
338368
"""
339369

340370

341-
class _SerializedKey(SupportsJSON):
342-
"""Internal object for holding a SerializableByKey key.
343-
344-
This is a private type used in contextual serialization. Its deserialization
345-
is context-dependent, and is not expected to match the original; in other
346-
words, `cls._from_json_dict_(obj._json_dict_())` does not return
347-
the original `obj` for this type.
348-
"""
349-
350-
def __init__(self, key: str):
351-
self.key = key
352-
353-
def _json_dict_(self):
354-
return obj_to_dict_helper(self, ['key'])
355-
356-
@classmethod
357-
def _from_json_dict_(cls, **kwargs):
358-
raise TypeError(f'Internal error: {cls} should never deserialize with _from_json_dict_.')
359-
360-
@classmethod
361-
def read_from_context(cls, context_map, key, **kwargs):
362-
return context_map[key]
363-
364-
365-
class _SerializedContext(SupportsJSON):
366-
"""Internal object for a single SerializableByKey key-to-object mapping.
367-
368-
This is a private type used in contextual serialization. Its deserialization
369-
is context-dependent, and is not expected to match the original; in other
370-
words, `cls._from_json_dict_(obj._json_dict_())` does not return
371-
the original `obj` for this type.
372-
"""
373-
374-
def __init__(self, obj: SerializableByKey, uid: int):
375-
self.key = uid
376-
self.obj = obj
377-
378-
def _json_dict_(self):
379-
return obj_to_dict_helper(self, ['key', 'obj'])
380-
381-
@classmethod
382-
def _from_json_dict_(cls, **kwargs):
383-
raise TypeError(f'Internal error: {cls} should never deserialize with _from_json_dict_.')
384-
385-
@classmethod
386-
def update_context(cls, context_map, key, obj, **kwargs):
387-
context_map.update({key: obj})
388-
389-
390-
class _ContextualSerialization(SupportsJSON):
391-
"""Internal object for serializing an object with its context.
392-
393-
This is a private type used in contextual serialization. Its deserialization
394-
is context-dependent, and is not expected to match the original; in other
395-
words, `cls._from_json_dict_(obj._json_dict_())` does not return
396-
the original `obj` for this type.
397-
"""
398-
399-
def __init__(self, obj: Any):
400-
# Context information and the wrapped object are stored together in
401-
# `object_dag` to ensure consistent serialization ordering.
402-
self.object_dag = []
403-
context = []
404-
for sbk in get_serializable_by_keys(obj):
405-
if sbk not in context:
406-
context.append(sbk)
407-
new_sc = _SerializedContext(sbk, len(context))
408-
self.object_dag.append(new_sc)
409-
self.object_dag += [obj]
410-
411-
def _json_dict_(self):
412-
return obj_to_dict_helper(self, ['object_dag'])
413-
414-
@classmethod
415-
def _from_json_dict_(cls, **kwargs):
416-
raise TypeError(f'Internal error: {cls} should never deserialize with _from_json_dict_.')
417-
418-
@classmethod
419-
def deserialize_with_context(cls, object_dag, **kwargs):
420-
# The last element of object_dag is the object to be deserialized.
421-
return object_dag[-1]
422-
423-
424-
def has_serializable_by_keys(obj: Any) -> bool:
425-
"""Returns true if obj contains one or more SerializableByKey objects."""
426-
if isinstance(obj, SerializableByKey):
427-
return True
428-
json_dict = getattr(obj, '_json_dict_', lambda: None)()
429-
if isinstance(json_dict, Dict):
430-
return any(has_serializable_by_keys(v) for v in json_dict.values())
431-
432-
# Handle primitive container types.
433-
if isinstance(obj, Dict):
434-
return any(has_serializable_by_keys(elem) for pair in obj.items() for elem in pair)
435-
436-
if hasattr(obj, '__iter__') and not isinstance(obj, str):
437-
# Return False on TypeError because some numpy values
438-
# (like np.array(1)) have iterable methods
439-
# yet return a TypeError when there is an attempt to iterate over them
440-
try:
441-
return any(has_serializable_by_keys(elem) for elem in obj)
442-
except TypeError:
443-
return False
444-
return False
445-
446-
447-
def get_serializable_by_keys(obj: Any) -> List[SerializableByKey]:
448-
"""Returns all SerializableByKeys contained by obj.
449-
450-
Objects are ordered such that nested objects appear before the object they
451-
are nested inside. This is required to ensure SerializableByKeys are only
452-
fully defined once in serialization.
453-
"""
454-
result = []
455-
if isinstance(obj, SerializableByKey):
456-
result.append(obj)
457-
json_dict = getattr(obj, '_json_dict_', lambda: None)()
458-
if isinstance(json_dict, Dict):
459-
for v in json_dict.values():
460-
result = get_serializable_by_keys(v) + result
461-
if result:
462-
return result
463-
464-
# Handle primitive container types.
465-
if isinstance(obj, Dict):
466-
return [sbk for pair in obj.items() for sbk in get_serializable_by_keys(pair)]
467-
if hasattr(obj, '__iter__') and not isinstance(obj, str):
468-
return [sbk for v in obj for sbk in get_serializable_by_keys(v)]
469-
return []
470-
471-
472371
def json_namespace(type_obj: Type) -> str:
473372
"""Returns a namespace for JSON serialization of `type_obj`.
474373
@@ -610,37 +509,12 @@ def to_json(
610509
party classes, prefer adding the `_json_dict_` magic method
611510
to your classes rather than overriding this default.
612511
"""
613-
if has_serializable_by_keys(obj):
614-
obj = _ContextualSerialization(obj)
615-
616-
class ContextualEncoder(cls): # type: ignore
617-
"""An encoder with a context map for concise serialization."""
618-
619-
# These lists populate gradually during serialization. An object
620-
# with components defined in 'context' will represent those
621-
# components using their keys instead of inline definition.
622-
seen: Set[str] = set()
623-
624-
def default(self, o):
625-
if not isinstance(o, SerializableByKey):
626-
return super().default(o)
627-
for candidate in obj.object_dag[:-1]:
628-
if candidate.obj == o:
629-
if not candidate.key in ContextualEncoder.seen:
630-
ContextualEncoder.seen.add(candidate.key)
631-
return _json_dict_with_cirq_type(candidate.obj)
632-
else:
633-
return _json_dict_with_cirq_type(_SerializedKey(candidate.key))
634-
raise ValueError("Object mutated during serialization.") # pragma: no cover
635-
636-
cls = ContextualEncoder
637-
638512
if file_or_fn is None:
639513
return json.dumps(obj, indent=indent, separators=separators, cls=cls)
640514

641515
if isinstance(file_or_fn, (str, pathlib.Path)):
642516
with open(file_or_fn, 'w') as actually_a_file:
643-
json.dump(obj, actually_a_file, indent=indent, cls=cls)
517+
json.dump(obj, actually_a_file, indent=indent, separators=separators, cls=cls)
644518
return None
645519

646520
json.dump(obj, file_or_fn, indent=indent, separators=separators, cls=cls)
@@ -682,10 +556,7 @@ def read_json(
682556
if resolvers is None:
683557
resolvers = DEFAULT_RESOLVERS
684558

685-
context_map: Dict[str, 'SerializableByKey'] = {}
686-
687-
def obj_hook(x):
688-
return _cirq_object_hook(x, resolvers, context_map)
559+
obj_hook = ObjectHook(resolvers)
689560

690561
if json_text is not None:
691562
return json.loads(json_text, object_hook=obj_hook)

Diff for: cirq-core/cirq/protocols/json_serialization_test.py

+11-39
Original file line number | Diff line number | Diff line change
@@ -373,6 +373,11 @@ def __eq__(self, other):
373373
and self.data_dict == other.data_dict
374374
)
375375

376+
def __hash__(self):
377+
return hash(
378+
(self.name, tuple(self.data_list), self.data_tuple, frozenset(self.data_dict.items()))
379+
)
380+
376381
def _json_dict_(self):
377382
return {
378383
"name": self.name,
@@ -386,12 +391,12 @@ def _from_json_dict_(cls, name, data_list, data_tuple, data_dict, **kwargs):
386391
return cls(name, data_list, tuple(data_tuple), data_dict)
387392

388393

389-
def test_context_serialization():
394+
def test_serializable_by_key():
390395
def custom_resolver(name):
391396
if name == 'SBKImpl':
392397
return SBKImpl
393398

394-
test_resolvers = [custom_resolver] + cirq.DEFAULT_RESOLVERS
399+
test_resolvers = [custom_resolver, *cirq.DEFAULT_RESOLVERS]
395400

396401
sbki_empty = SBKImpl('sbki_empty')
397402
assert_json_roundtrip_works(sbki_empty, resolvers=test_resolvers)
@@ -406,55 +411,22 @@ def custom_resolver(name):
406411
assert_json_roundtrip_works(sbki_dict, resolvers=test_resolvers)
407412

408413
sbki_json = str(cirq.to_json(sbki_dict))
409-
# There should be exactly one context item for each previous SBKImpl.
410-
assert sbki_json.count('"cirq_type": "_SerializedContext"') == 4
411-
# There should be exactly two key items for each of sbki_(empty|list|tuple),
412-
# plus one for the top-level sbki_dict.
413-
assert sbki_json.count('"cirq_type": "_SerializedKey"') == 7
414-
# The final object should be a _SerializedKey for sbki_dict.
415-
final_obj_idx = sbki_json.rfind('{')
416-
final_obj = sbki_json[final_obj_idx : sbki_json.find('}', final_obj_idx) + 1]
417-
assert (
418-
final_obj
419-
== """{
420-
"cirq_type": "_SerializedKey",
421-
"key": 4
422-
}"""
423-
)
414+
# There are 4 SBKImpl instances, one each for empty, list, tuple, dict.
415+
assert sbki_json.count('"cirq_type": "VAL"') == 4
416+
# There are 3 SBKImpl refs, one each for empty, list, and tuple.
417+
assert sbki_json.count('"cirq_type": "REF"') == 3
424418

425419
list_sbki = [sbki_dict]
426420
assert_json_roundtrip_works(list_sbki, resolvers=test_resolvers)
427421

428422
dict_sbki = {'a': sbki_dict}
429423
assert_json_roundtrip_works(dict_sbki, resolvers=test_resolvers)
430424

431-
assert sbki_list != json_serialization._SerializedKey(sbki_list)
432-
433425
# Serialization keys have unique suffixes.
434426
sbki_other_list = SBKImpl('sbki_list', data_list=[sbki_list])
435427
assert_json_roundtrip_works(sbki_other_list, resolvers=test_resolvers)
436428

437429

438-
def test_internal_serializer_types():
439-
sbki = SBKImpl('test_key')
440-
key = 1
441-
test_key = json_serialization._SerializedKey(key)
442-
test_context = json_serialization._SerializedContext(sbki, 1)
443-
test_serialization = json_serialization._ContextualSerialization(sbki)
444-
445-
key_json = test_key._json_dict_()
446-
with pytest.raises(TypeError, match='_from_json_dict_'):
447-
_ = json_serialization._SerializedKey._from_json_dict_(**key_json)
448-
449-
context_json = test_context._json_dict_()
450-
with pytest.raises(TypeError, match='_from_json_dict_'):
451-
_ = json_serialization._SerializedContext._from_json_dict_(**context_json)
452-
453-
serialization_json = test_serialization._json_dict_()
454-
with pytest.raises(TypeError, match='_from_json_dict_'):
455-
_ = json_serialization._ContextualSerialization._from_json_dict_(**serialization_json)
456-
457-
458430
# during test setup deprecated submodules are inspected and trigger the
459431
# deprecation error in testing. It is cleaner to just turn it off than to assert
460432
# deprecation for each submodule.

0 commit comments

Comments
 (0)