From ed7d8fcb3907bb4be37de6d77a34703bbc5cb6f1 Mon Sep 17 00:00:00 2001 From: Tanuj Khattar Date: Sun, 20 Mar 2022 23:16:22 -0700 Subject: [PATCH 1/3] Add flag to @cirq.transformer decorator --- .../cirq/transformers/transformer_api.py | 64 +++++++++++++++++-- .../cirq/transformers/transformer_api_test.py | 58 ++++++++++++++++- 2 files changed, 114 insertions(+), 8 deletions(-) diff --git a/cirq-core/cirq/transformers/transformer_api.py b/cirq-core/cirq/transformers/transformer_api.py index 6983312ab02..fc0b35f0de2 100644 --- a/cirq-core/cirq/transformers/transformer_api.py +++ b/cirq-core/cirq/transformers/transformer_api.py @@ -20,6 +20,7 @@ import functools import textwrap from typing import ( + cast, Any, Tuple, Hashable, @@ -32,6 +33,8 @@ ) from typing_extensions import Protocol +from cirq import circuits + if TYPE_CHECKING: import cirq @@ -214,10 +217,13 @@ class TransformerContext: circuit. Transformers should not transform any operation marked with a tag that belongs to this tuple. Note that any instance of a Hashable type (like `str`, `cirq.VirtualTag` etc.) is a valid tag. + deep: If true, the transformer should be recursively applied to all sub-circuits wrapped + inside circuit operations. """ logger: TransformerLogger = NoOpTransformerLogger() tags_to_ignore: Tuple[Hashable, ...] = () + deep: bool = False class TRANSFORMER(Protocol): @@ -232,16 +238,20 @@ def __call__( @overload -def transformer(cls_or_func: _TRANSFORMER_T) -> _TRANSFORMER_T: +def transformer( + cls_or_func: _TRANSFORMER_T, *, add_support_for_deep: bool = False +) -> _TRANSFORMER_T: pass @overload -def transformer(cls_or_func: _TRANSFORMER_CLS_T) -> _TRANSFORMER_CLS_T: +def transformer( + cls_or_func: _TRANSFORMER_CLS_T, *, add_support_for_deep: bool = False +) -> _TRANSFORMER_CLS_T: pass -def transformer(cls_or_func: Any) -> Any: +def transformer(cls_or_func: Any = None, *, add_support_for_deep: bool = False) -> Any: """Decorator to verify API and append logging functionality to transformer functions & classes. A transformer is a callable that takes as inputs a `cirq.AbstractCircuit` and @@ -284,10 +294,22 @@ def transformer(cls_or_func: Any) -> Any: Args: cls_or_func: The callable class or function to be decorated. + add_support_for_deep: If True, the decorator adds the logic to first apply the + decorated transformer on subcircuits wrapped inside `cirq.CircuitOperation`s + before applying it on the top-level circuit, if context.deep is True. Returns: Decorated class / function which includes additional logging boilerplate. """ + + # If keyword arguments were specified, python invokes the decorator method + # without a `cls` argument, then passes `cls` into the result. + if cls_or_func is None: + return lambda deferred_cls_or_func: transformer( + deferred_cls_or_func, + add_support_for_deep=add_support_for_deep, + ) + if isinstance(cls_or_func, type): cls = cls_or_func method = cls.__call__ @@ -298,6 +320,7 @@ def method_with_logging( self, circuit: 'cirq.AbstractCircuit', **kwargs ) -> 'cirq.AbstractCircuit': return _transform_and_log( + add_support_for_deep, lambda circuit, **kwargs: method(self, circuit, **kwargs), cls.__name__, circuit, @@ -315,6 +338,7 @@ def method_with_logging( @functools.wraps(func) def func_with_logging(circuit: 'cirq.AbstractCircuit', **kwargs) -> 'cirq.AbstractCircuit': return _transform_and_log( + add_support_for_deep, func, func.__name__, circuit, @@ -325,7 +349,7 @@ def func_with_logging(circuit: 'cirq.AbstractCircuit', **kwargs) -> 'cirq.Abstra return func_with_logging -def _get_default_context(func: TRANSFORMER): +def _get_default_context(func: TRANSFORMER) -> TransformerContext: sig = inspect.signature(func) default_context = sig.parameters["context"].default assert ( @@ -334,7 +358,35 @@ def _get_default_context(func: TRANSFORMER): return default_context +def _run_transformer_on_circuit( + add_support_for_deep: bool, + func: TRANSFORMER, + circuit: 'cirq.AbstractCircuit', + extracted_context: Optional[TransformerContext], + **kwargs, +) -> 'cirq.AbstractCircuit': + mutable_circuit = None + if extracted_context and extracted_context.deep and add_support_for_deep: + batch_replace = [] + for i, op in circuit.findall_operations( + lambda o: isinstance(o.untagged, circuits.CircuitOperation) + ): + op_untagged = cast(circuits.CircuitOperation, op.untagged) + if not set(op.tags).isdisjoint(extracted_context.tags_to_ignore): + continue + op_untagged = op_untagged.replace( + circuit=_run_transformer_on_circuit( + add_support_for_deep, func, op_untagged.circuit, extracted_context, **kwargs + ).freeze() + ) + batch_replace.append((i, op, op_untagged.with_tags(*op.tags))) + mutable_circuit = circuit.unfreeze(copy=True) + mutable_circuit.batch_replace(batch_replace) + return func(mutable_circuit if mutable_circuit else circuit, **kwargs) + + def _transform_and_log( + add_support_for_deep: bool, func: TRANSFORMER, transformer_name: str, circuit: 'cirq.AbstractCircuit', @@ -344,7 +396,9 @@ def _transform_and_log( """Helper to log initial and final circuits before and after calling the transformer.""" if extracted_context: extracted_context.logger.register_initial(circuit, transformer_name) - transformed_circuit = func(circuit, **kwargs) + transformed_circuit = _run_transformer_on_circuit( + add_support_for_deep, func, circuit, extracted_context, **kwargs + ) if extracted_context: extracted_context.logger.register_final(transformed_circuit, transformer_name) return transformed_circuit diff --git a/cirq-core/cirq/transformers/transformer_api_test.py b/cirq-core/cirq/transformers/transformer_api_test.py index 8b35656fca8..9f04ea40a49 100644 --- a/cirq-core/cirq/transformers/transformer_api_test.py +++ b/cirq-core/cirq/transformers/transformer_api_test.py @@ -21,7 +21,7 @@ import pytest -@cirq.transformer +@cirq.transformer() class MockTransformerClass: def __init__(self): self.mock = mock.Mock() @@ -59,6 +59,11 @@ def __call__( return circuit[::-1] +@cirq.transformer(add_support_for_deep=True) +class MockTransformerClassSupportsDeep(MockTransformerClass): + pass + + def make_transformer_func_with_defaults() -> cirq.TRANSFORMER: my_mock = mock.Mock() @@ -77,16 +82,18 @@ def func( return func -def make_transformer_func() -> cirq.TRANSFORMER: +def make_transformer_func(add_support_for_deep: bool = False) -> cirq.TRANSFORMER: my_mock = mock.Mock() - @cirq.transformer + @cirq.transformer(add_support_for_deep=add_support_for_deep) def mock_tranformer_func( circuit: cirq.AbstractCircuit, *, context: Optional[cirq.TransformerContext] = None ) -> cirq.Circuit: + print("CALLED:", circuit, context, sep="\n\n----------\n\n") my_mock(circuit, context) return circuit.unfreeze() + # mock_transformer = mock_tranformer_func(add_support_for_deep) mock_tranformer_func.mock = my_mock # type: ignore return mock_tranformer_func @@ -134,6 +141,51 @@ def test_transformer_decorator_with_defaults(transformer): transformer.mock.assert_called_with(circuit, context, 1e-2, CustomArg(12)) +# from unittest.mock import patch, call + + +@pytest.mark.parametrize( + 'transformer, supports_deep', + [ + (MockTransformerClass(), False), + (make_transformer_func(), False), + (MockTransformerClassSupportsDeep(), True), + (make_transformer_func(add_support_for_deep=True), True), + ], +) +def test_transformer_decorator_adds_support_for_deep(transformer, supports_deep): + q1, q2 = cirq.LineQubit.range(2) + c_nested_x = cirq.FrozenCircuit(cirq.X(q1)) + c_nested_y = cirq.FrozenCircuit(cirq.Y(q1)) + c_nested_xy = cirq.FrozenCircuit( + cirq.CircuitOperation(c_nested_x).repeat(5).with_tags("ignore"), + cirq.CircuitOperation(c_nested_y).repeat(7).with_tags("preserve_tag"), + ) + c_nested_yx = cirq.FrozenCircuit( + cirq.CircuitOperation(c_nested_y).repeat(7).with_tags("ignore"), + cirq.CircuitOperation(c_nested_x).repeat(5).with_tags("preserve_tag"), + ) + c_orig = cirq.Circuit( + cirq.CircuitOperation(c_nested_xy).repeat(4), + cirq.CircuitOperation(c_nested_x).repeat(5).with_tags("ignore"), + cirq.CircuitOperation(c_nested_y).repeat(6), + cirq.CircuitOperation(c_nested_yx).repeat(7), + ) + context = cirq.TransformerContext(tags_to_ignore=["ignore"], deep=True) + transformer(c_orig, context=context) + expected_calls = [mock.call(c_orig, context)] + if supports_deep: + expected_calls = [ + mock.call(c_nested_y, context), # c_orig --> xy --> y + mock.call(c_nested_xy, context), # c_orig --> xy + mock.call(c_nested_y, context), # c_orig --> y + mock.call(c_nested_x, context), # c_orig --> yx --> x + mock.call(c_nested_yx, context), # c_orig --> yx + mock.call(c_orig, context), # c_orig + ] + transformer.mock.assert_has_calls(expected_calls) + + @cirq.transformer class T1: def __call__( From fed3908df09c3fc7c205f74fe6e789783394abc2 Mon Sep 17 00:00:00 2001 From: Tanuj Khattar Date: Mon, 21 Mar 2022 00:16:59 -0700 Subject: [PATCH 2/3] Fix mypy type errors and remove typos --- cirq-core/cirq/transformers/transformer_api.py | 12 ++++++++++++ cirq-core/cirq/transformers/transformer_api_test.py | 11 +++-------- 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/cirq-core/cirq/transformers/transformer_api.py b/cirq-core/cirq/transformers/transformer_api.py index fc0b35f0de2..ae6f00dcd20 100644 --- a/cirq-core/cirq/transformers/transformer_api.py +++ b/cirq-core/cirq/transformers/transformer_api.py @@ -22,6 +22,7 @@ from typing import ( cast, Any, + Callable, Tuple, Hashable, List, @@ -30,6 +31,7 @@ Type, TYPE_CHECKING, TypeVar, + Union, ) from typing_extensions import Protocol @@ -235,6 +237,16 @@ def __call__( _TRANSFORMER_T = TypeVar('_TRANSFORMER_T', bound=TRANSFORMER) _TRANSFORMER_CLS_T = TypeVar('_TRANSFORMER_CLS_T', bound=Type[TRANSFORMER]) +_TRANSFORMER_OR_CLS_T = TypeVar( + '_TRANSFORMER_OR_CLS_T', bound=Union[TRANSFORMER, Type[TRANSFORMER]] +) + + +@overload +def transformer( + *, add_support_for_deep: bool = False +) -> Callable[[_TRANSFORMER_OR_CLS_T], _TRANSFORMER_OR_CLS_T]: + pass @overload diff --git a/cirq-core/cirq/transformers/transformer_api_test.py b/cirq-core/cirq/transformers/transformer_api_test.py index 9f04ea40a49..ad5ea09135c 100644 --- a/cirq-core/cirq/transformers/transformer_api_test.py +++ b/cirq-core/cirq/transformers/transformer_api_test.py @@ -89,11 +89,9 @@ def make_transformer_func(add_support_for_deep: bool = False) -> cirq.TRANSFORME def mock_tranformer_func( circuit: cirq.AbstractCircuit, *, context: Optional[cirq.TransformerContext] = None ) -> cirq.Circuit: - print("CALLED:", circuit, context, sep="\n\n----------\n\n") my_mock(circuit, context) return circuit.unfreeze() - # mock_transformer = mock_tranformer_func(add_support_for_deep) mock_tranformer_func.mock = my_mock # type: ignore return mock_tranformer_func @@ -141,9 +139,6 @@ def test_transformer_decorator_with_defaults(transformer): transformer.mock.assert_called_with(circuit, context, 1e-2, CustomArg(12)) -# from unittest.mock import patch, call - - @pytest.mark.parametrize( 'transformer, supports_deep', [ @@ -154,9 +149,9 @@ def test_transformer_decorator_with_defaults(transformer): ], ) def test_transformer_decorator_adds_support_for_deep(transformer, supports_deep): - q1, q2 = cirq.LineQubit.range(2) - c_nested_x = cirq.FrozenCircuit(cirq.X(q1)) - c_nested_y = cirq.FrozenCircuit(cirq.Y(q1)) + q = cirq.NamedQubit("q") + c_nested_x = cirq.FrozenCircuit(cirq.X(q)) + c_nested_y = cirq.FrozenCircuit(cirq.Y(q)) c_nested_xy = cirq.FrozenCircuit( cirq.CircuitOperation(c_nested_x).repeat(5).with_tags("ignore"), cirq.CircuitOperation(c_nested_y).repeat(7).with_tags("preserve_tag"), From 60edd614a345c66a488338e1fedaaa6e1535340b Mon Sep 17 00:00:00 2001 From: Tanuj Khattar Date: Mon, 21 Mar 2022 11:14:33 -0700 Subject: [PATCH 3/3] Rename add_support_for_deep to add_deep_support --- .../cirq/transformers/transformer_api.py | 28 +++++++++---------- .../cirq/transformers/transformer_api_test.py | 8 +++--- 2 files changed, 17 insertions(+), 19 deletions(-) diff --git a/cirq-core/cirq/transformers/transformer_api.py b/cirq-core/cirq/transformers/transformer_api.py index ae6f00dcd20..ee4f6118bf6 100644 --- a/cirq-core/cirq/transformers/transformer_api.py +++ b/cirq-core/cirq/transformers/transformer_api.py @@ -244,26 +244,24 @@ def __call__( @overload def transformer( - *, add_support_for_deep: bool = False + *, add_deep_support: bool = False ) -> Callable[[_TRANSFORMER_OR_CLS_T], _TRANSFORMER_OR_CLS_T]: pass @overload -def transformer( - cls_or_func: _TRANSFORMER_T, *, add_support_for_deep: bool = False -) -> _TRANSFORMER_T: +def transformer(cls_or_func: _TRANSFORMER_T, *, add_deep_support: bool = False) -> _TRANSFORMER_T: pass @overload def transformer( - cls_or_func: _TRANSFORMER_CLS_T, *, add_support_for_deep: bool = False + cls_or_func: _TRANSFORMER_CLS_T, *, add_deep_support: bool = False ) -> _TRANSFORMER_CLS_T: pass -def transformer(cls_or_func: Any = None, *, add_support_for_deep: bool = False) -> Any: +def transformer(cls_or_func: Any = None, *, add_deep_support: bool = False) -> Any: """Decorator to verify API and append logging functionality to transformer functions & classes. A transformer is a callable that takes as inputs a `cirq.AbstractCircuit` and @@ -306,7 +304,7 @@ def transformer(cls_or_func: Any = None, *, add_support_for_deep: bool = False) Args: cls_or_func: The callable class or function to be decorated. - add_support_for_deep: If True, the decorator adds the logic to first apply the + add_deep_support: If True, the decorator adds the logic to first apply the decorated transformer on subcircuits wrapped inside `cirq.CircuitOperation`s before applying it on the top-level circuit, if context.deep is True. @@ -319,7 +317,7 @@ def transformer(cls_or_func: Any = None, *, add_support_for_deep: bool = False) if cls_or_func is None: return lambda deferred_cls_or_func: transformer( deferred_cls_or_func, - add_support_for_deep=add_support_for_deep, + add_deep_support=add_deep_support, ) if isinstance(cls_or_func, type): @@ -332,7 +330,7 @@ def method_with_logging( self, circuit: 'cirq.AbstractCircuit', **kwargs ) -> 'cirq.AbstractCircuit': return _transform_and_log( - add_support_for_deep, + add_deep_support, lambda circuit, **kwargs: method(self, circuit, **kwargs), cls.__name__, circuit, @@ -350,7 +348,7 @@ def method_with_logging( @functools.wraps(func) def func_with_logging(circuit: 'cirq.AbstractCircuit', **kwargs) -> 'cirq.AbstractCircuit': return _transform_and_log( - add_support_for_deep, + add_deep_support, func, func.__name__, circuit, @@ -371,14 +369,14 @@ def _get_default_context(func: TRANSFORMER) -> TransformerContext: def _run_transformer_on_circuit( - add_support_for_deep: bool, + add_deep_support: bool, func: TRANSFORMER, circuit: 'cirq.AbstractCircuit', extracted_context: Optional[TransformerContext], **kwargs, ) -> 'cirq.AbstractCircuit': mutable_circuit = None - if extracted_context and extracted_context.deep and add_support_for_deep: + if extracted_context and extracted_context.deep and add_deep_support: batch_replace = [] for i, op in circuit.findall_operations( lambda o: isinstance(o.untagged, circuits.CircuitOperation) @@ -388,7 +386,7 @@ def _run_transformer_on_circuit( continue op_untagged = op_untagged.replace( circuit=_run_transformer_on_circuit( - add_support_for_deep, func, op_untagged.circuit, extracted_context, **kwargs + add_deep_support, func, op_untagged.circuit, extracted_context, **kwargs ).freeze() ) batch_replace.append((i, op, op_untagged.with_tags(*op.tags))) @@ -398,7 +396,7 @@ def _run_transformer_on_circuit( def _transform_and_log( - add_support_for_deep: bool, + add_deep_support: bool, func: TRANSFORMER, transformer_name: str, circuit: 'cirq.AbstractCircuit', @@ -409,7 +407,7 @@ def _transform_and_log( if extracted_context: extracted_context.logger.register_initial(circuit, transformer_name) transformed_circuit = _run_transformer_on_circuit( - add_support_for_deep, func, circuit, extracted_context, **kwargs + add_deep_support, func, circuit, extracted_context, **kwargs ) if extracted_context: extracted_context.logger.register_final(transformed_circuit, transformer_name) diff --git a/cirq-core/cirq/transformers/transformer_api_test.py b/cirq-core/cirq/transformers/transformer_api_test.py index ad5ea09135c..9ec7d506661 100644 --- a/cirq-core/cirq/transformers/transformer_api_test.py +++ b/cirq-core/cirq/transformers/transformer_api_test.py @@ -59,7 +59,7 @@ def __call__( return circuit[::-1] -@cirq.transformer(add_support_for_deep=True) +@cirq.transformer(add_deep_support=True) class MockTransformerClassSupportsDeep(MockTransformerClass): pass @@ -82,10 +82,10 @@ def func( return func -def make_transformer_func(add_support_for_deep: bool = False) -> cirq.TRANSFORMER: +def make_transformer_func(add_deep_support: bool = False) -> cirq.TRANSFORMER: my_mock = mock.Mock() - @cirq.transformer(add_support_for_deep=add_support_for_deep) + @cirq.transformer(add_deep_support=add_deep_support) def mock_tranformer_func( circuit: cirq.AbstractCircuit, *, context: Optional[cirq.TransformerContext] = None ) -> cirq.Circuit: @@ -145,7 +145,7 @@ def test_transformer_decorator_with_defaults(transformer): (MockTransformerClass(), False), (make_transformer_func(), False), (MockTransformerClassSupportsDeep(), True), - (make_transformer_func(add_support_for_deep=True), True), + (make_transformer_func(add_deep_support=True), True), ], ) def test_transformer_decorator_adds_support_for_deep(transformer, supports_deep):