20
20
import functools
21
21
import textwrap
22
22
from typing import (
23
+ cast ,
23
24
Any ,
25
+ Callable ,
24
26
Tuple ,
25
27
Hashable ,
26
28
List ,
29
31
Type ,
30
32
TYPE_CHECKING ,
31
33
TypeVar ,
34
+ Union ,
32
35
)
33
36
from typing_extensions import Protocol
34
37
38
+ from cirq import circuits
39
+
35
40
if TYPE_CHECKING :
36
41
import cirq
37
42
@@ -214,10 +219,13 @@ class TransformerContext:
214
219
circuit. Transformers should not transform any operation marked with a tag that
215
220
belongs to this tuple. Note that any instance of a Hashable type (like `str`,
216
221
`cirq.VirtualTag` etc.) is a valid tag.
222
+ deep: If true, the transformer should be recursively applied to all sub-circuits wrapped
223
+ inside circuit operations.
217
224
"""
218
225
219
226
logger : TransformerLogger = NoOpTransformerLogger ()
220
227
tags_to_ignore : Tuple [Hashable , ...] = ()
228
+ deep : bool = False
221
229
222
230
223
231
class TRANSFORMER (Protocol ):
@@ -229,19 +237,31 @@ def __call__(
229
237
230
238
_TRANSFORMER_T = TypeVar ('_TRANSFORMER_T' , bound = TRANSFORMER )
231
239
_TRANSFORMER_CLS_T = TypeVar ('_TRANSFORMER_CLS_T' , bound = Type [TRANSFORMER ])
240
+ _TRANSFORMER_OR_CLS_T = TypeVar (
241
+ '_TRANSFORMER_OR_CLS_T' , bound = Union [TRANSFORMER , Type [TRANSFORMER ]]
242
+ )
243
+
244
+
245
+ @overload
246
+ def transformer (
247
+ * , add_deep_support : bool = False
248
+ ) -> Callable [[_TRANSFORMER_OR_CLS_T ], _TRANSFORMER_OR_CLS_T ]:
249
+ pass
232
250
233
251
234
252
@overload
235
- def transformer (cls_or_func : _TRANSFORMER_T ) -> _TRANSFORMER_T :
253
+ def transformer (cls_or_func : _TRANSFORMER_T , * , add_deep_support : bool = False ) -> _TRANSFORMER_T :
236
254
pass
237
255
238
256
239
257
@overload
240
- def transformer (cls_or_func : _TRANSFORMER_CLS_T ) -> _TRANSFORMER_CLS_T :
258
+ def transformer (
259
+ cls_or_func : _TRANSFORMER_CLS_T , * , add_deep_support : bool = False
260
+ ) -> _TRANSFORMER_CLS_T :
241
261
pass
242
262
243
263
244
- def transformer (cls_or_func : Any ) -> Any :
264
+ def transformer (cls_or_func : Any = None , * , add_deep_support : bool = False ) -> Any :
245
265
"""Decorator to verify API and append logging functionality to transformer functions & classes.
246
266
247
267
A transformer is a callable that takes as inputs a `cirq.AbstractCircuit` and
@@ -284,10 +304,22 @@ def transformer(cls_or_func: Any) -> Any:
284
304
285
305
Args:
286
306
cls_or_func: The callable class or function to be decorated.
307
+ add_deep_support: If True, the decorator adds the logic to first apply the
308
+ decorated transformer on subcircuits wrapped inside `cirq.CircuitOperation`s
309
+ before applying it on the top-level circuit, if context.deep is True.
287
310
288
311
Returns:
289
312
Decorated class / function which includes additional logging boilerplate.
290
313
"""
314
+
315
+ # If keyword arguments were specified, python invokes the decorator method
316
+ # without a `cls` argument, then passes `cls` into the result.
317
+ if cls_or_func is None :
318
+ return lambda deferred_cls_or_func : transformer (
319
+ deferred_cls_or_func ,
320
+ add_deep_support = add_deep_support ,
321
+ )
322
+
291
323
if isinstance (cls_or_func , type ):
292
324
cls = cls_or_func
293
325
method = cls .__call__
@@ -298,6 +330,7 @@ def method_with_logging(
298
330
self , circuit : 'cirq.AbstractCircuit' , ** kwargs
299
331
) -> 'cirq.AbstractCircuit' :
300
332
return _transform_and_log (
333
+ add_deep_support ,
301
334
lambda circuit , ** kwargs : method (self , circuit , ** kwargs ),
302
335
cls .__name__ ,
303
336
circuit ,
@@ -315,6 +348,7 @@ def method_with_logging(
315
348
@functools .wraps (func )
316
349
def func_with_logging (circuit : 'cirq.AbstractCircuit' , ** kwargs ) -> 'cirq.AbstractCircuit' :
317
350
return _transform_and_log (
351
+ add_deep_support ,
318
352
func ,
319
353
func .__name__ ,
320
354
circuit ,
@@ -325,7 +359,7 @@ def func_with_logging(circuit: 'cirq.AbstractCircuit', **kwargs) -> 'cirq.Abstra
325
359
return func_with_logging
326
360
327
361
328
- def _get_default_context (func : TRANSFORMER ):
362
+ def _get_default_context (func : TRANSFORMER ) -> TransformerContext :
329
363
sig = inspect .signature (func )
330
364
default_context = sig .parameters ["context" ].default
331
365
assert (
@@ -334,7 +368,35 @@ def _get_default_context(func: TRANSFORMER):
334
368
return default_context
335
369
336
370
371
+ def _run_transformer_on_circuit (
372
+ add_deep_support : bool ,
373
+ func : TRANSFORMER ,
374
+ circuit : 'cirq.AbstractCircuit' ,
375
+ extracted_context : Optional [TransformerContext ],
376
+ ** kwargs ,
377
+ ) -> 'cirq.AbstractCircuit' :
378
+ mutable_circuit = None
379
+ if extracted_context and extracted_context .deep and add_deep_support :
380
+ batch_replace = []
381
+ for i , op in circuit .findall_operations (
382
+ lambda o : isinstance (o .untagged , circuits .CircuitOperation )
383
+ ):
384
+ op_untagged = cast (circuits .CircuitOperation , op .untagged )
385
+ if not set (op .tags ).isdisjoint (extracted_context .tags_to_ignore ):
386
+ continue
387
+ op_untagged = op_untagged .replace (
388
+ circuit = _run_transformer_on_circuit (
389
+ add_deep_support , func , op_untagged .circuit , extracted_context , ** kwargs
390
+ ).freeze ()
391
+ )
392
+ batch_replace .append ((i , op , op_untagged .with_tags (* op .tags )))
393
+ mutable_circuit = circuit .unfreeze (copy = True )
394
+ mutable_circuit .batch_replace (batch_replace )
395
+ return func (mutable_circuit if mutable_circuit else circuit , ** kwargs )
396
+
397
+
337
398
def _transform_and_log (
399
+ add_deep_support : bool ,
338
400
func : TRANSFORMER ,
339
401
transformer_name : str ,
340
402
circuit : 'cirq.AbstractCircuit' ,
@@ -344,7 +406,9 @@ def _transform_and_log(
344
406
"""Helper to log initial and final circuits before and after calling the transformer."""
345
407
if extracted_context :
346
408
extracted_context .logger .register_initial (circuit , transformer_name )
347
- transformed_circuit = func (circuit , ** kwargs )
409
+ transformed_circuit = _run_transformer_on_circuit (
410
+ add_deep_support , func , circuit , extracted_context , ** kwargs
411
+ )
348
412
if extracted_context :
349
413
extracted_context .logger .register_final (transformed_circuit , transformer_name )
350
414
return transformed_circuit
0 commit comments