Skip to content

Commit ed7d8fc

Browse files
committed
Add flag to @cirq.transformer decorator
1 parent 6af5387 commit ed7d8fc

File tree

2 files changed

+114
-8
lines changed

2 files changed

+114
-8
lines changed

cirq-core/cirq/transformers/transformer_api.py

+59-5
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import functools
2121
import textwrap
2222
from typing import (
23+
cast,
2324
Any,
2425
Tuple,
2526
Hashable,
@@ -32,6 +33,8 @@
3233
)
3334
from typing_extensions import Protocol
3435

36+
from cirq import circuits
37+
3538
if TYPE_CHECKING:
3639
import cirq
3740

@@ -214,10 +217,13 @@ class TransformerContext:
214217
circuit. Transformers should not transform any operation marked with a tag that
215218
belongs to this tuple. Note that any instance of a Hashable type (like `str`,
216219
`cirq.VirtualTag` etc.) is a valid tag.
220+
deep: If true, the transformer should be recursively applied to all sub-circuits wrapped
221+
inside circuit operations.
217222
"""
218223

219224
logger: TransformerLogger = NoOpTransformerLogger()
220225
tags_to_ignore: Tuple[Hashable, ...] = ()
226+
deep: bool = False
221227

222228

223229
class TRANSFORMER(Protocol):
@@ -232,16 +238,20 @@ def __call__(
232238

233239

234240
@overload
235-
def transformer(cls_or_func: _TRANSFORMER_T) -> _TRANSFORMER_T:
241+
def transformer(
242+
cls_or_func: _TRANSFORMER_T, *, add_support_for_deep: bool = False
243+
) -> _TRANSFORMER_T:
236244
pass
237245

238246

239247
@overload
240-
def transformer(cls_or_func: _TRANSFORMER_CLS_T) -> _TRANSFORMER_CLS_T:
248+
def transformer(
249+
cls_or_func: _TRANSFORMER_CLS_T, *, add_support_for_deep: bool = False
250+
) -> _TRANSFORMER_CLS_T:
241251
pass
242252

243253

244-
def transformer(cls_or_func: Any) -> Any:
254+
def transformer(cls_or_func: Any = None, *, add_support_for_deep: bool = False) -> Any:
245255
"""Decorator to verify API and append logging functionality to transformer functions & classes.
246256
247257
A transformer is a callable that takes as inputs a `cirq.AbstractCircuit` and
@@ -284,10 +294,22 @@ def transformer(cls_or_func: Any) -> Any:
284294
285295
Args:
286296
cls_or_func: The callable class or function to be decorated.
297+
add_support_for_deep: If True, the decorator adds the logic to first apply the
298+
decorated transformer on subcircuits wrapped inside `cirq.CircuitOperation`s
299+
before applying it on the top-level circuit, if context.deep is True.
287300
288301
Returns:
289302
Decorated class / function which includes additional logging boilerplate.
290303
"""
304+
305+
# If keyword arguments were specified, python invokes the decorator method
306+
# without a `cls` argument, then passes `cls` into the result.
307+
if cls_or_func is None:
308+
return lambda deferred_cls_or_func: transformer(
309+
deferred_cls_or_func,
310+
add_support_for_deep=add_support_for_deep,
311+
)
312+
291313
if isinstance(cls_or_func, type):
292314
cls = cls_or_func
293315
method = cls.__call__
@@ -298,6 +320,7 @@ def method_with_logging(
298320
self, circuit: 'cirq.AbstractCircuit', **kwargs
299321
) -> 'cirq.AbstractCircuit':
300322
return _transform_and_log(
323+
add_support_for_deep,
301324
lambda circuit, **kwargs: method(self, circuit, **kwargs),
302325
cls.__name__,
303326
circuit,
@@ -315,6 +338,7 @@ def method_with_logging(
315338
@functools.wraps(func)
316339
def func_with_logging(circuit: 'cirq.AbstractCircuit', **kwargs) -> 'cirq.AbstractCircuit':
317340
return _transform_and_log(
341+
add_support_for_deep,
318342
func,
319343
func.__name__,
320344
circuit,
@@ -325,7 +349,7 @@ def func_with_logging(circuit: 'cirq.AbstractCircuit', **kwargs) -> 'cirq.Abstra
325349
return func_with_logging
326350

327351

328-
def _get_default_context(func: TRANSFORMER):
352+
def _get_default_context(func: TRANSFORMER) -> TransformerContext:
329353
sig = inspect.signature(func)
330354
default_context = sig.parameters["context"].default
331355
assert (
@@ -334,7 +358,35 @@ def _get_default_context(func: TRANSFORMER):
334358
return default_context
335359

336360

361+
def _run_transformer_on_circuit(
362+
add_support_for_deep: bool,
363+
func: TRANSFORMER,
364+
circuit: 'cirq.AbstractCircuit',
365+
extracted_context: Optional[TransformerContext],
366+
**kwargs,
367+
) -> 'cirq.AbstractCircuit':
368+
mutable_circuit = None
369+
if extracted_context and extracted_context.deep and add_support_for_deep:
370+
batch_replace = []
371+
for i, op in circuit.findall_operations(
372+
lambda o: isinstance(o.untagged, circuits.CircuitOperation)
373+
):
374+
op_untagged = cast(circuits.CircuitOperation, op.untagged)
375+
if not set(op.tags).isdisjoint(extracted_context.tags_to_ignore):
376+
continue
377+
op_untagged = op_untagged.replace(
378+
circuit=_run_transformer_on_circuit(
379+
add_support_for_deep, func, op_untagged.circuit, extracted_context, **kwargs
380+
).freeze()
381+
)
382+
batch_replace.append((i, op, op_untagged.with_tags(*op.tags)))
383+
mutable_circuit = circuit.unfreeze(copy=True)
384+
mutable_circuit.batch_replace(batch_replace)
385+
return func(mutable_circuit if mutable_circuit else circuit, **kwargs)
386+
387+
337388
def _transform_and_log(
389+
add_support_for_deep: bool,
338390
func: TRANSFORMER,
339391
transformer_name: str,
340392
circuit: 'cirq.AbstractCircuit',
@@ -344,7 +396,9 @@ def _transform_and_log(
344396
"""Helper to log initial and final circuits before and after calling the transformer."""
345397
if extracted_context:
346398
extracted_context.logger.register_initial(circuit, transformer_name)
347-
transformed_circuit = func(circuit, **kwargs)
399+
transformed_circuit = _run_transformer_on_circuit(
400+
add_support_for_deep, func, circuit, extracted_context, **kwargs
401+
)
348402
if extracted_context:
349403
extracted_context.logger.register_final(transformed_circuit, transformer_name)
350404
return transformed_circuit

cirq-core/cirq/transformers/transformer_api_test.py

+55-3
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import pytest
2222

2323

24-
@cirq.transformer
24+
@cirq.transformer()
2525
class MockTransformerClass:
2626
def __init__(self):
2727
self.mock = mock.Mock()
@@ -59,6 +59,11 @@ def __call__(
5959
return circuit[::-1]
6060

6161

62+
@cirq.transformer(add_support_for_deep=True)
63+
class MockTransformerClassSupportsDeep(MockTransformerClass):
64+
pass
65+
66+
6267
def make_transformer_func_with_defaults() -> cirq.TRANSFORMER:
6368
my_mock = mock.Mock()
6469

@@ -77,16 +82,18 @@ def func(
7782
return func
7883

7984

80-
def make_transformer_func() -> cirq.TRANSFORMER:
85+
def make_transformer_func(add_support_for_deep: bool = False) -> cirq.TRANSFORMER:
8186
my_mock = mock.Mock()
8287

83-
@cirq.transformer
88+
@cirq.transformer(add_support_for_deep=add_support_for_deep)
8489
def mock_tranformer_func(
8590
circuit: cirq.AbstractCircuit, *, context: Optional[cirq.TransformerContext] = None
8691
) -> cirq.Circuit:
92+
print("CALLED:", circuit, context, sep="\n\n----------\n\n")
8793
my_mock(circuit, context)
8894
return circuit.unfreeze()
8995

96+
# mock_transformer = mock_tranformer_func(add_support_for_deep)
9097
mock_tranformer_func.mock = my_mock # type: ignore
9198
return mock_tranformer_func
9299

@@ -134,6 +141,51 @@ def test_transformer_decorator_with_defaults(transformer):
134141
transformer.mock.assert_called_with(circuit, context, 1e-2, CustomArg(12))
135142

136143

144+
# from unittest.mock import patch, call
145+
146+
147+
@pytest.mark.parametrize(
148+
'transformer, supports_deep',
149+
[
150+
(MockTransformerClass(), False),
151+
(make_transformer_func(), False),
152+
(MockTransformerClassSupportsDeep(), True),
153+
(make_transformer_func(add_support_for_deep=True), True),
154+
],
155+
)
156+
def test_transformer_decorator_adds_support_for_deep(transformer, supports_deep):
157+
q1, q2 = cirq.LineQubit.range(2)
158+
c_nested_x = cirq.FrozenCircuit(cirq.X(q1))
159+
c_nested_y = cirq.FrozenCircuit(cirq.Y(q1))
160+
c_nested_xy = cirq.FrozenCircuit(
161+
cirq.CircuitOperation(c_nested_x).repeat(5).with_tags("ignore"),
162+
cirq.CircuitOperation(c_nested_y).repeat(7).with_tags("preserve_tag"),
163+
)
164+
c_nested_yx = cirq.FrozenCircuit(
165+
cirq.CircuitOperation(c_nested_y).repeat(7).with_tags("ignore"),
166+
cirq.CircuitOperation(c_nested_x).repeat(5).with_tags("preserve_tag"),
167+
)
168+
c_orig = cirq.Circuit(
169+
cirq.CircuitOperation(c_nested_xy).repeat(4),
170+
cirq.CircuitOperation(c_nested_x).repeat(5).with_tags("ignore"),
171+
cirq.CircuitOperation(c_nested_y).repeat(6),
172+
cirq.CircuitOperation(c_nested_yx).repeat(7),
173+
)
174+
context = cirq.TransformerContext(tags_to_ignore=["ignore"], deep=True)
175+
transformer(c_orig, context=context)
176+
expected_calls = [mock.call(c_orig, context)]
177+
if supports_deep:
178+
expected_calls = [
179+
mock.call(c_nested_y, context), # c_orig --> xy --> y
180+
mock.call(c_nested_xy, context), # c_orig --> xy
181+
mock.call(c_nested_y, context), # c_orig --> y
182+
mock.call(c_nested_x, context), # c_orig --> yx --> x
183+
mock.call(c_nested_yx, context), # c_orig --> yx
184+
mock.call(c_orig, context), # c_orig
185+
]
186+
transformer.mock.assert_has_calls(expected_calls)
187+
188+
137189
@cirq.transformer
138190
class T1:
139191
def __call__(

0 commit comments

Comments
 (0)