28
28
Optional ,
29
29
overload ,
30
30
Sequence ,
31
- Set ,
32
31
Tuple ,
33
32
Type ,
34
33
Union ,
@@ -221,10 +220,22 @@ class CirqEncoder(json.JSONEncoder):
221
220
See https://github.com/quantumlib/Cirq/issues/2014
222
221
"""
223
222
223
+ def __init__ (self , * args , ** kwargs ) -> None :
224
+ super ().__init__ (* args , ** kwargs )
225
+ self ._memo : dict [Any , dict ] = {}
226
+
224
227
def default (self , o ):
225
228
# Object with custom method?
226
229
if hasattr (o , '_json_dict_' ):
227
- return _json_dict_with_cirq_type (o )
230
+ json_dict = _json_dict_with_cirq_type (o )
231
+ if isinstance (o , SerializableByKey ):
232
+ if ref := self ._memo .get (o ):
233
+ return ref
234
+ key = len (self ._memo )
235
+ ref = {"cirq_type" : "REF" , "key" : key }
236
+ self ._memo [o ] = ref
237
+ return {"cirq_type" : "VAL" , "key" : key , "val" : json_dict }
238
+ return json_dict
228
239
229
240
# Sympy object? (Must come before general number checks.)
230
241
# TODO: More support for sympy
@@ -306,27 +317,46 @@ def default(self, o):
306
317
return super ().default (o ) # pragma: no cover
307
318
308
319
309
- def _cirq_object_hook (d , resolvers : Sequence [JsonResolver ], context_map : Dict [str , Any ]):
310
- if 'cirq_type' not in d :
311
- return d
320
+ class ObjectHook :
321
+ """Callable to be used as object_hook during deserialization."""
322
+
323
+ LEGACY_CONTEXT_TYPES = {'_ContextualSerialization' , '_SerializedKey' , '_SerializedContext' }
324
+
325
+ def __init__ (self , resolvers : Sequence [JsonResolver ]) -> None :
326
+ self .resolvers = resolvers
327
+ self .memo : Dict [int , SerializableByKey ] = {}
328
+ self .context_map : Dict [int , SerializableByKey ] = {}
312
329
313
- if d ['cirq_type' ] == '_SerializedKey' :
314
- return _SerializedKey .read_from_context (context_map , ** d )
330
+ def __call__ (self , d ):
331
+ cirq_type = d .get ('cirq_type' )
332
+ if cirq_type is None :
333
+ return d
315
334
316
- if d ['cirq_type' ] == '_SerializedContext' :
317
- _SerializedContext .update_context (context_map , ** d )
318
- return None
335
+ if cirq_type == 'VAL' :
336
+ obj = d ['val' ]
337
+ self .memo [d ['key' ]] = obj
338
+ return obj
319
339
320
- if d [ ' cirq_type' ] == '_ContextualSerialization ' :
321
- return _ContextualSerialization . deserialize_with_context ( ** d )
340
+ if cirq_type == 'REF ' :
341
+ return self . memo [ d [ 'key' ]]
322
342
323
- cls = factory_from_json (d ['cirq_type' ], resolvers = resolvers )
324
- from_json_dict = getattr (cls , '_from_json_dict_' , None )
325
- if from_json_dict is not None :
326
- return from_json_dict (** d )
343
+ # Deserialize from legacy "contextual serialization" format
344
+ if cirq_type in self .LEGACY_CONTEXT_TYPES :
345
+ if cirq_type == '_SerializedKey' :
346
+ return self .context_map [d ['key' ]]
347
+ if cirq_type == '_SerializedContext' :
348
+ self .context_map [d ['key' ]] = d ['obj' ]
349
+ return None
350
+ if cirq_type == '_ContextualSerialization' :
351
+ return d ['object_dag' ][- 1 ]
327
352
328
- del d ['cirq_type' ]
329
- return cls (** d )
353
+ cls = factory_from_json (cirq_type , resolvers = self .resolvers )
354
+ from_json_dict = getattr (cls , '_from_json_dict_' , None )
355
+ if from_json_dict is not None :
356
+ return from_json_dict (** d )
357
+
358
+ del d ['cirq_type' ]
359
+ return cls (** d )
330
360
331
361
332
362
class SerializableByKey (SupportsJSON ):
@@ -338,137 +368,6 @@ class SerializableByKey(SupportsJSON):
338
368
"""
339
369
340
370
341
- class _SerializedKey (SupportsJSON ):
342
- """Internal object for holding a SerializableByKey key.
343
-
344
- This is a private type used in contextual serialization. Its deserialization
345
- is context-dependent, and is not expected to match the original; in other
346
- words, `cls._from_json_dict_(obj._json_dict_())` does not return
347
- the original `obj` for this type.
348
- """
349
-
350
- def __init__ (self , key : str ):
351
- self .key = key
352
-
353
- def _json_dict_ (self ):
354
- return obj_to_dict_helper (self , ['key' ])
355
-
356
- @classmethod
357
- def _from_json_dict_ (cls , ** kwargs ):
358
- raise TypeError (f'Internal error: { cls } should never deserialize with _from_json_dict_.' )
359
-
360
- @classmethod
361
- def read_from_context (cls , context_map , key , ** kwargs ):
362
- return context_map [key ]
363
-
364
-
365
- class _SerializedContext (SupportsJSON ):
366
- """Internal object for a single SerializableByKey key-to-object mapping.
367
-
368
- This is a private type used in contextual serialization. Its deserialization
369
- is context-dependent, and is not expected to match the original; in other
370
- words, `cls._from_json_dict_(obj._json_dict_())` does not return
371
- the original `obj` for this type.
372
- """
373
-
374
- def __init__ (self , obj : SerializableByKey , uid : int ):
375
- self .key = uid
376
- self .obj = obj
377
-
378
- def _json_dict_ (self ):
379
- return obj_to_dict_helper (self , ['key' , 'obj' ])
380
-
381
- @classmethod
382
- def _from_json_dict_ (cls , ** kwargs ):
383
- raise TypeError (f'Internal error: { cls } should never deserialize with _from_json_dict_.' )
384
-
385
- @classmethod
386
- def update_context (cls , context_map , key , obj , ** kwargs ):
387
- context_map .update ({key : obj })
388
-
389
-
390
- class _ContextualSerialization (SupportsJSON ):
391
- """Internal object for serializing an object with its context.
392
-
393
- This is a private type used in contextual serialization. Its deserialization
394
- is context-dependent, and is not expected to match the original; in other
395
- words, `cls._from_json_dict_(obj._json_dict_())` does not return
396
- the original `obj` for this type.
397
- """
398
-
399
- def __init__ (self , obj : Any ):
400
- # Context information and the wrapped object are stored together in
401
- # `object_dag` to ensure consistent serialization ordering.
402
- self .object_dag = []
403
- context = []
404
- for sbk in get_serializable_by_keys (obj ):
405
- if sbk not in context :
406
- context .append (sbk )
407
- new_sc = _SerializedContext (sbk , len (context ))
408
- self .object_dag .append (new_sc )
409
- self .object_dag += [obj ]
410
-
411
- def _json_dict_ (self ):
412
- return obj_to_dict_helper (self , ['object_dag' ])
413
-
414
- @classmethod
415
- def _from_json_dict_ (cls , ** kwargs ):
416
- raise TypeError (f'Internal error: { cls } should never deserialize with _from_json_dict_.' )
417
-
418
- @classmethod
419
- def deserialize_with_context (cls , object_dag , ** kwargs ):
420
- # The last element of object_dag is the object to be deserialized.
421
- return object_dag [- 1 ]
422
-
423
-
424
- def has_serializable_by_keys (obj : Any ) -> bool :
425
- """Returns true if obj contains one or more SerializableByKey objects."""
426
- if isinstance (obj , SerializableByKey ):
427
- return True
428
- json_dict = getattr (obj , '_json_dict_' , lambda : None )()
429
- if isinstance (json_dict , Dict ):
430
- return any (has_serializable_by_keys (v ) for v in json_dict .values ())
431
-
432
- # Handle primitive container types.
433
- if isinstance (obj , Dict ):
434
- return any (has_serializable_by_keys (elem ) for pair in obj .items () for elem in pair )
435
-
436
- if hasattr (obj , '__iter__' ) and not isinstance (obj , str ):
437
- # Return False on TypeError because some numpy values
438
- # (like np.array(1)) have iterable methods
439
- # yet return a TypeError when there is an attempt to iterate over them
440
- try :
441
- return any (has_serializable_by_keys (elem ) for elem in obj )
442
- except TypeError :
443
- return False
444
- return False
445
-
446
-
447
- def get_serializable_by_keys (obj : Any ) -> List [SerializableByKey ]:
448
- """Returns all SerializableByKeys contained by obj.
449
-
450
- Objects are ordered such that nested objects appear before the object they
451
- are nested inside. This is required to ensure SerializableByKeys are only
452
- fully defined once in serialization.
453
- """
454
- result = []
455
- if isinstance (obj , SerializableByKey ):
456
- result .append (obj )
457
- json_dict = getattr (obj , '_json_dict_' , lambda : None )()
458
- if isinstance (json_dict , Dict ):
459
- for v in json_dict .values ():
460
- result = get_serializable_by_keys (v ) + result
461
- if result :
462
- return result
463
-
464
- # Handle primitive container types.
465
- if isinstance (obj , Dict ):
466
- return [sbk for pair in obj .items () for sbk in get_serializable_by_keys (pair )]
467
- if hasattr (obj , '__iter__' ) and not isinstance (obj , str ):
468
- return [sbk for v in obj for sbk in get_serializable_by_keys (v )]
469
- return []
470
-
471
-
472
371
def json_namespace (type_obj : Type ) -> str :
473
372
"""Returns a namespace for JSON serialization of `type_obj`.
474
373
@@ -610,37 +509,12 @@ def to_json(
610
509
party classes, prefer adding the `_json_dict_` magic method
611
510
to your classes rather than overriding this default.
612
511
"""
613
- if has_serializable_by_keys (obj ):
614
- obj = _ContextualSerialization (obj )
615
-
616
- class ContextualEncoder (cls ): # type: ignore
617
- """An encoder with a context map for concise serialization."""
618
-
619
- # These lists populate gradually during serialization. An object
620
- # with components defined in 'context' will represent those
621
- # components using their keys instead of inline definition.
622
- seen : Set [str ] = set ()
623
-
624
- def default (self , o ):
625
- if not isinstance (o , SerializableByKey ):
626
- return super ().default (o )
627
- for candidate in obj .object_dag [:- 1 ]:
628
- if candidate .obj == o :
629
- if not candidate .key in ContextualEncoder .seen :
630
- ContextualEncoder .seen .add (candidate .key )
631
- return _json_dict_with_cirq_type (candidate .obj )
632
- else :
633
- return _json_dict_with_cirq_type (_SerializedKey (candidate .key ))
634
- raise ValueError ("Object mutated during serialization." ) # pragma: no cover
635
-
636
- cls = ContextualEncoder
637
-
638
512
if file_or_fn is None :
639
513
return json .dumps (obj , indent = indent , separators = separators , cls = cls )
640
514
641
515
if isinstance (file_or_fn , (str , pathlib .Path )):
642
516
with open (file_or_fn , 'w' ) as actually_a_file :
643
- json .dump (obj , actually_a_file , indent = indent , cls = cls )
517
+ json .dump (obj , actually_a_file , indent = indent , separators = separators , cls = cls )
644
518
return None
645
519
646
520
json .dump (obj , file_or_fn , indent = indent , separators = separators , cls = cls )
@@ -682,10 +556,7 @@ def read_json(
682
556
if resolvers is None :
683
557
resolvers = DEFAULT_RESOLVERS
684
558
685
- context_map : Dict [str , 'SerializableByKey' ] = {}
686
-
687
- def obj_hook (x ):
688
- return _cirq_object_hook (x , resolvers , context_map )
559
+ obj_hook = ObjectHook (resolvers )
689
560
690
561
if json_text is not None :
691
562
return json .loads (json_text , object_hook = obj_hook )
0 commit comments