20
20
import functools
21
21
import textwrap
22
22
from typing import (
23
+ cast ,
23
24
Any ,
24
25
Tuple ,
25
26
Hashable ,
32
33
)
33
34
from typing_extensions import Protocol
34
35
36
+ from cirq import circuits
37
+
35
38
if TYPE_CHECKING :
36
39
import cirq
37
40
@@ -214,10 +217,13 @@ class TransformerContext:
214
217
circuit. Transformers should not transform any operation marked with a tag that
215
218
belongs to this tuple. Note that any instance of a Hashable type (like `str`,
216
219
`cirq.VirtualTag` etc.) is a valid tag.
220
+ deep: If true, the transformer should be recursively applied to all sub-circuits wrapped
221
+ inside circuit operations.
217
222
"""
218
223
219
224
logger : TransformerLogger = NoOpTransformerLogger ()
220
225
tags_to_ignore : Tuple [Hashable , ...] = ()
226
+ deep : bool = False
221
227
222
228
223
229
class TRANSFORMER (Protocol ):
@@ -232,16 +238,20 @@ def __call__(
232
238
233
239
234
240
@overload
235
- def transformer (cls_or_func : _TRANSFORMER_T ) -> _TRANSFORMER_T :
241
+ def transformer (
242
+ cls_or_func : _TRANSFORMER_T , * , add_support_for_deep : bool = False
243
+ ) -> _TRANSFORMER_T :
236
244
pass
237
245
238
246
239
247
@overload
240
- def transformer (cls_or_func : _TRANSFORMER_CLS_T ) -> _TRANSFORMER_CLS_T :
248
+ def transformer (
249
+ cls_or_func : _TRANSFORMER_CLS_T , * , add_support_for_deep : bool = False
250
+ ) -> _TRANSFORMER_CLS_T :
241
251
pass
242
252
243
253
244
- def transformer (cls_or_func : Any ) -> Any :
254
+ def transformer (cls_or_func : Any = None , * , add_support_for_deep : bool = False ) -> Any :
245
255
"""Decorator to verify API and append logging functionality to transformer functions & classes.
246
256
247
257
A transformer is a callable that takes as inputs a `cirq.AbstractCircuit` and
@@ -284,10 +294,22 @@ def transformer(cls_or_func: Any) -> Any:
284
294
285
295
Args:
286
296
cls_or_func: The callable class or function to be decorated.
297
+ add_support_for_deep: If True, the decorator adds the logic to first apply the
298
+ decorated transformer on subcircuits wrapped inside `cirq.CircuitOperation`s
299
+ before applying it on the top-level circuit, if context.deep is True.
287
300
288
301
Returns:
289
302
Decorated class / function which includes additional logging boilerplate.
290
303
"""
304
+
305
+ # If keyword arguments were specified, python invokes the decorator method
306
+ # without a `cls` argument, then passes `cls` into the result.
307
+ if cls_or_func is None :
308
+ return lambda deferred_cls_or_func : transformer (
309
+ deferred_cls_or_func ,
310
+ add_support_for_deep = add_support_for_deep ,
311
+ )
312
+
291
313
if isinstance (cls_or_func , type ):
292
314
cls = cls_or_func
293
315
method = cls .__call__
@@ -298,6 +320,7 @@ def method_with_logging(
298
320
self , circuit : 'cirq.AbstractCircuit' , ** kwargs
299
321
) -> 'cirq.AbstractCircuit' :
300
322
return _transform_and_log (
323
+ add_support_for_deep ,
301
324
lambda circuit , ** kwargs : method (self , circuit , ** kwargs ),
302
325
cls .__name__ ,
303
326
circuit ,
@@ -315,6 +338,7 @@ def method_with_logging(
315
338
@functools .wraps (func )
316
339
def func_with_logging (circuit : 'cirq.AbstractCircuit' , ** kwargs ) -> 'cirq.AbstractCircuit' :
317
340
return _transform_and_log (
341
+ add_support_for_deep ,
318
342
func ,
319
343
func .__name__ ,
320
344
circuit ,
@@ -325,7 +349,7 @@ def func_with_logging(circuit: 'cirq.AbstractCircuit', **kwargs) -> 'cirq.Abstra
325
349
return func_with_logging
326
350
327
351
328
- def _get_default_context (func : TRANSFORMER ):
352
+ def _get_default_context (func : TRANSFORMER ) -> TransformerContext :
329
353
sig = inspect .signature (func )
330
354
default_context = sig .parameters ["context" ].default
331
355
assert (
@@ -334,7 +358,35 @@ def _get_default_context(func: TRANSFORMER):
334
358
return default_context
335
359
336
360
361
+ def _run_transformer_on_circuit (
362
+ add_support_for_deep : bool ,
363
+ func : TRANSFORMER ,
364
+ circuit : 'cirq.AbstractCircuit' ,
365
+ extracted_context : Optional [TransformerContext ],
366
+ ** kwargs ,
367
+ ) -> 'cirq.AbstractCircuit' :
368
+ mutable_circuit = None
369
+ if extracted_context and extracted_context .deep and add_support_for_deep :
370
+ batch_replace = []
371
+ for i , op in circuit .findall_operations (
372
+ lambda o : isinstance (o .untagged , circuits .CircuitOperation )
373
+ ):
374
+ op_untagged = cast (circuits .CircuitOperation , op .untagged )
375
+ if not set (op .tags ).isdisjoint (extracted_context .tags_to_ignore ):
376
+ continue
377
+ op_untagged = op_untagged .replace (
378
+ circuit = _run_transformer_on_circuit (
379
+ add_support_for_deep , func , op_untagged .circuit , extracted_context , ** kwargs
380
+ ).freeze ()
381
+ )
382
+ batch_replace .append ((i , op , op_untagged .with_tags (* op .tags )))
383
+ mutable_circuit = circuit .unfreeze (copy = True )
384
+ mutable_circuit .batch_replace (batch_replace )
385
+ return func (mutable_circuit if mutable_circuit else circuit , ** kwargs )
386
+
387
+
337
388
def _transform_and_log (
389
+ add_support_for_deep : bool ,
338
390
func : TRANSFORMER ,
339
391
transformer_name : str ,
340
392
circuit : 'cirq.AbstractCircuit' ,
@@ -344,7 +396,9 @@ def _transform_and_log(
344
396
"""Helper to log initial and final circuits before and after calling the transformer."""
345
397
if extracted_context :
346
398
extracted_context .logger .register_initial (circuit , transformer_name )
347
- transformed_circuit = func (circuit , ** kwargs )
399
+ transformed_circuit = _run_transformer_on_circuit (
400
+ add_support_for_deep , func , circuit , extracted_context , ** kwargs
401
+ )
348
402
if extracted_context :
349
403
extracted_context .logger .register_final (transformed_circuit , transformer_name )
350
404
return transformed_circuit
0 commit comments