Skip to content

Commit cfa255a

Browse files
authored
Add add_deep_support flag to @cirq.transformer decorator (#5108)
* Add flag to @cirq.transformer decorator * Fix mypy type errors and remove typos * Rename add_support_for_deep to add_deep_support
1 parent aed4eb8 commit cfa255a

File tree

2 files changed

+119
-8
lines changed

2 files changed

+119
-8
lines changed

cirq-core/cirq/transformers/transformer_api.py

+69-5
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,9 @@
2020
import functools
2121
import textwrap
2222
from typing import (
23+
cast,
2324
Any,
25+
Callable,
2426
Tuple,
2527
Hashable,
2628
List,
@@ -29,9 +31,12 @@
2931
Type,
3032
TYPE_CHECKING,
3133
TypeVar,
34+
Union,
3235
)
3336
from typing_extensions import Protocol
3437

38+
from cirq import circuits
39+
3540
if TYPE_CHECKING:
3641
import cirq
3742

@@ -214,10 +219,13 @@ class TransformerContext:
214219
circuit. Transformers should not transform any operation marked with a tag that
215220
belongs to this tuple. Note that any instance of a Hashable type (like `str`,
216221
`cirq.VirtualTag` etc.) is a valid tag.
222+
deep: If true, the transformer should be recursively applied to all sub-circuits wrapped
223+
inside circuit operations.
217224
"""
218225

219226
logger: TransformerLogger = NoOpTransformerLogger()
220227
tags_to_ignore: Tuple[Hashable, ...] = ()
228+
deep: bool = False
221229

222230

223231
class TRANSFORMER(Protocol):
@@ -229,19 +237,31 @@ def __call__(
229237

230238
_TRANSFORMER_T = TypeVar('_TRANSFORMER_T', bound=TRANSFORMER)
231239
_TRANSFORMER_CLS_T = TypeVar('_TRANSFORMER_CLS_T', bound=Type[TRANSFORMER])
240+
_TRANSFORMER_OR_CLS_T = TypeVar(
241+
'_TRANSFORMER_OR_CLS_T', bound=Union[TRANSFORMER, Type[TRANSFORMER]]
242+
)
243+
244+
245+
@overload
246+
def transformer(
247+
*, add_deep_support: bool = False
248+
) -> Callable[[_TRANSFORMER_OR_CLS_T], _TRANSFORMER_OR_CLS_T]:
249+
pass
232250

233251

234252
@overload
235-
def transformer(cls_or_func: _TRANSFORMER_T) -> _TRANSFORMER_T:
253+
def transformer(cls_or_func: _TRANSFORMER_T, *, add_deep_support: bool = False) -> _TRANSFORMER_T:
236254
pass
237255

238256

239257
@overload
240-
def transformer(cls_or_func: _TRANSFORMER_CLS_T) -> _TRANSFORMER_CLS_T:
258+
def transformer(
259+
cls_or_func: _TRANSFORMER_CLS_T, *, add_deep_support: bool = False
260+
) -> _TRANSFORMER_CLS_T:
241261
pass
242262

243263

244-
def transformer(cls_or_func: Any) -> Any:
264+
def transformer(cls_or_func: Any = None, *, add_deep_support: bool = False) -> Any:
245265
"""Decorator to verify API and append logging functionality to transformer functions & classes.
246266
247267
A transformer is a callable that takes as inputs a `cirq.AbstractCircuit` and
@@ -284,10 +304,22 @@ def transformer(cls_or_func: Any) -> Any:
284304
285305
Args:
286306
cls_or_func: The callable class or function to be decorated.
307+
add_deep_support: If True, the decorator adds the logic to first apply the
308+
decorated transformer on subcircuits wrapped inside `cirq.CircuitOperation`s
309+
before applying it on the top-level circuit, if context.deep is True.
287310
288311
Returns:
289312
Decorated class / function which includes additional logging boilerplate.
290313
"""
314+
315+
# If keyword arguments were specified, python invokes the decorator method
316+
# without a `cls` argument, then passes `cls` into the result.
317+
if cls_or_func is None:
318+
return lambda deferred_cls_or_func: transformer(
319+
deferred_cls_or_func,
320+
add_deep_support=add_deep_support,
321+
)
322+
291323
if isinstance(cls_or_func, type):
292324
cls = cls_or_func
293325
method = cls.__call__
@@ -298,6 +330,7 @@ def method_with_logging(
298330
self, circuit: 'cirq.AbstractCircuit', **kwargs
299331
) -> 'cirq.AbstractCircuit':
300332
return _transform_and_log(
333+
add_deep_support,
301334
lambda circuit, **kwargs: method(self, circuit, **kwargs),
302335
cls.__name__,
303336
circuit,
@@ -315,6 +348,7 @@ def method_with_logging(
315348
@functools.wraps(func)
316349
def func_with_logging(circuit: 'cirq.AbstractCircuit', **kwargs) -> 'cirq.AbstractCircuit':
317350
return _transform_and_log(
351+
add_deep_support,
318352
func,
319353
func.__name__,
320354
circuit,
@@ -325,7 +359,7 @@ def func_with_logging(circuit: 'cirq.AbstractCircuit', **kwargs) -> 'cirq.Abstra
325359
return func_with_logging
326360

327361

328-
def _get_default_context(func: TRANSFORMER):
362+
def _get_default_context(func: TRANSFORMER) -> TransformerContext:
329363
sig = inspect.signature(func)
330364
default_context = sig.parameters["context"].default
331365
assert (
@@ -334,7 +368,35 @@ def _get_default_context(func: TRANSFORMER):
334368
return default_context
335369

336370

371+
def _run_transformer_on_circuit(
372+
add_deep_support: bool,
373+
func: TRANSFORMER,
374+
circuit: 'cirq.AbstractCircuit',
375+
extracted_context: Optional[TransformerContext],
376+
**kwargs,
377+
) -> 'cirq.AbstractCircuit':
378+
mutable_circuit = None
379+
if extracted_context and extracted_context.deep and add_deep_support:
380+
batch_replace = []
381+
for i, op in circuit.findall_operations(
382+
lambda o: isinstance(o.untagged, circuits.CircuitOperation)
383+
):
384+
op_untagged = cast(circuits.CircuitOperation, op.untagged)
385+
if not set(op.tags).isdisjoint(extracted_context.tags_to_ignore):
386+
continue
387+
op_untagged = op_untagged.replace(
388+
circuit=_run_transformer_on_circuit(
389+
add_deep_support, func, op_untagged.circuit, extracted_context, **kwargs
390+
).freeze()
391+
)
392+
batch_replace.append((i, op, op_untagged.with_tags(*op.tags)))
393+
mutable_circuit = circuit.unfreeze(copy=True)
394+
mutable_circuit.batch_replace(batch_replace)
395+
return func(mutable_circuit if mutable_circuit else circuit, **kwargs)
396+
397+
337398
def _transform_and_log(
399+
add_deep_support: bool,
338400
func: TRANSFORMER,
339401
transformer_name: str,
340402
circuit: 'cirq.AbstractCircuit',
@@ -344,7 +406,9 @@ def _transform_and_log(
344406
"""Helper to log initial and final circuits before and after calling the transformer."""
345407
if extracted_context:
346408
extracted_context.logger.register_initial(circuit, transformer_name)
347-
transformed_circuit = func(circuit, **kwargs)
409+
transformed_circuit = _run_transformer_on_circuit(
410+
add_deep_support, func, circuit, extracted_context, **kwargs
411+
)
348412
if extracted_context:
349413
extracted_context.logger.register_final(transformed_circuit, transformer_name)
350414
return transformed_circuit

cirq-core/cirq/transformers/transformer_api_test.py

+50-3
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import pytest
2222

2323

24-
@cirq.transformer
24+
@cirq.transformer()
2525
class MockTransformerClass:
2626
def __init__(self):
2727
self.mock = mock.Mock()
@@ -59,6 +59,11 @@ def __call__(
5959
return circuit[::-1]
6060

6161

62+
@cirq.transformer(add_deep_support=True)
63+
class MockTransformerClassSupportsDeep(MockTransformerClass):
64+
pass
65+
66+
6267
def make_transformer_func_with_defaults() -> cirq.TRANSFORMER:
6368
my_mock = mock.Mock()
6469

@@ -77,10 +82,10 @@ def func(
7782
return func
7883

7984

80-
def make_transformer_func() -> cirq.TRANSFORMER:
85+
def make_transformer_func(add_deep_support: bool = False) -> cirq.TRANSFORMER:
8186
my_mock = mock.Mock()
8287

83-
@cirq.transformer
88+
@cirq.transformer(add_deep_support=add_deep_support)
8489
def mock_tranformer_func(
8590
circuit: cirq.AbstractCircuit, *, context: Optional[cirq.TransformerContext] = None
8691
) -> cirq.Circuit:
@@ -134,6 +139,48 @@ def test_transformer_decorator_with_defaults(transformer):
134139
transformer.mock.assert_called_with(circuit, context, 1e-2, CustomArg(12))
135140

136141

142+
@pytest.mark.parametrize(
143+
'transformer, supports_deep',
144+
[
145+
(MockTransformerClass(), False),
146+
(make_transformer_func(), False),
147+
(MockTransformerClassSupportsDeep(), True),
148+
(make_transformer_func(add_deep_support=True), True),
149+
],
150+
)
151+
def test_transformer_decorator_adds_support_for_deep(transformer, supports_deep):
152+
q = cirq.NamedQubit("q")
153+
c_nested_x = cirq.FrozenCircuit(cirq.X(q))
154+
c_nested_y = cirq.FrozenCircuit(cirq.Y(q))
155+
c_nested_xy = cirq.FrozenCircuit(
156+
cirq.CircuitOperation(c_nested_x).repeat(5).with_tags("ignore"),
157+
cirq.CircuitOperation(c_nested_y).repeat(7).with_tags("preserve_tag"),
158+
)
159+
c_nested_yx = cirq.FrozenCircuit(
160+
cirq.CircuitOperation(c_nested_y).repeat(7).with_tags("ignore"),
161+
cirq.CircuitOperation(c_nested_x).repeat(5).with_tags("preserve_tag"),
162+
)
163+
c_orig = cirq.Circuit(
164+
cirq.CircuitOperation(c_nested_xy).repeat(4),
165+
cirq.CircuitOperation(c_nested_x).repeat(5).with_tags("ignore"),
166+
cirq.CircuitOperation(c_nested_y).repeat(6),
167+
cirq.CircuitOperation(c_nested_yx).repeat(7),
168+
)
169+
context = cirq.TransformerContext(tags_to_ignore=["ignore"], deep=True)
170+
transformer(c_orig, context=context)
171+
expected_calls = [mock.call(c_orig, context)]
172+
if supports_deep:
173+
expected_calls = [
174+
mock.call(c_nested_y, context), # c_orig --> xy --> y
175+
mock.call(c_nested_xy, context), # c_orig --> xy
176+
mock.call(c_nested_y, context), # c_orig --> y
177+
mock.call(c_nested_x, context), # c_orig --> yx --> x
178+
mock.call(c_nested_yx, context), # c_orig --> yx
179+
mock.call(c_orig, context), # c_orig
180+
]
181+
transformer.mock.assert_has_calls(expected_calls)
182+
183+
137184
@cirq.transformer
138185
class T1:
139186
def __call__(

0 commit comments

Comments
 (0)