Skip to content

Commit 23bbbf9

Browse files
committed
Fix mypy type errors and remove typos
1 parent ed7d8fc commit 23bbbf9

File tree

4 files changed

+88
-10
lines changed

4 files changed

+88
-10
lines changed

Diff for: cirq-core/cirq/transformers/align.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
"""Transformer passes which align operations to the left or right of the circuit."""
1616

17+
import dataclasses
1718
from typing import Optional, TYPE_CHECKING
1819
from cirq import circuits, ops
1920
from cirq.transformers import transformer_api
@@ -22,7 +23,7 @@
2223
import cirq
2324

2425

25-
@transformer_api.transformer
26+
@transformer_api.transformer(add_support_for_deep=True)
2627
def align_left(
2728
circuit: 'cirq.AbstractCircuit', *, context: Optional['cirq.TransformerContext'] = None
2829
) -> 'cirq.Circuit':
@@ -54,7 +55,7 @@ def align_left(
5455
return ret
5556

5657

57-
@transformer_api.transformer
58+
@transformer_api.transformer(add_support_for_deep=True)
5859
def align_right(
5960
circuit: 'cirq.AbstractCircuit', *, context: Optional['cirq.TransformerContext'] = None
6061
) -> 'cirq.Circuit':
@@ -70,4 +71,6 @@ def align_right(
7071
Returns:
7172
Copy of the transformed input circuit.
7273
"""
74+
if context is not None and context.deep is True:
75+
context = dataclasses.replace(context, deep=False)
7376
return align_left(circuit[::-1], context=context)[::-1]

Diff for: cirq-core/cirq/transformers/align_test.py

+68
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,41 @@ def test_align_left_no_compile_context():
7171
)
7272

7373

74+
def test_align_left_deep():
75+
q1, q2 = cirq.LineQubit.range(2)
76+
c_nested = cirq.FrozenCircuit(
77+
[
78+
cirq.Moment([cirq.X(q1)]),
79+
cirq.Moment([cirq.Y(q2)]),
80+
cirq.Moment([cirq.Z(q1), cirq.Y(q2).with_tags("nocompile")]),
81+
cirq.Moment([cirq.Y(q1)]),
82+
cirq.measure(q2, key='a'),
83+
cirq.Z(q1).with_classical_controls('a'),
84+
]
85+
)
86+
c_nested_aligned = cirq.FrozenCircuit(
87+
cirq.Moment(cirq.X(q1), cirq.Y(q2)),
88+
cirq.Moment(cirq.Z(q1)),
89+
cirq.Moment([cirq.Y(q1), cirq.Y(q2).with_tags("nocompile")]),
90+
cirq.measure(q2, key='a'),
91+
cirq.Z(q1).with_classical_controls('a'),
92+
)
93+
c_orig = cirq.Circuit(
94+
c_nested,
95+
cirq.CircuitOperation(c_nested).repeat(6).with_tags("nocompile"),
96+
c_nested,
97+
cirq.CircuitOperation(c_nested).repeat(5).with_tags("preserve_tag"),
98+
)
99+
c_expected = cirq.Circuit(
100+
c_nested_aligned,
101+
cirq.CircuitOperation(c_nested).repeat(6).with_tags("nocompile"),
102+
c_nested_aligned,
103+
cirq.CircuitOperation(c_nested_aligned).repeat(5).with_tags("preserve_tag"),
104+
)
105+
context = cirq.TransformerContext(tags_to_ignore=["nocompile"], deep=True)
106+
cirq.testing.assert_same_circuits(cirq.align_left(c_orig, context=context), c_expected)
107+
108+
74109
def test_align_left_subset_of_operations():
75110
q1 = cirq.NamedQubit('q1')
76111
q2 = cirq.NamedQubit('q2')
@@ -133,6 +168,39 @@ def test_align_right_no_compile_context():
133168
)
134169

135170

171+
def test_align_right_deep():
172+
q1, q2 = cirq.LineQubit.range(2)
173+
c_nested = cirq.FrozenCircuit(
174+
cirq.Moment([cirq.X(q1)]),
175+
cirq.Moment([cirq.Y(q1), cirq.X(q2).with_tags("nocompile")]),
176+
cirq.Moment([cirq.X(q2)]),
177+
cirq.Moment([cirq.Y(q1)]),
178+
cirq.measure(q1, key='a'),
179+
cirq.Z(q2).with_classical_controls('a'),
180+
)
181+
c_nested_aligned = cirq.FrozenCircuit(
182+
cirq.Moment([cirq.X(q1), cirq.X(q2).with_tags("nocompile")]),
183+
[cirq.Y(q1), cirq.Y(q1)],
184+
cirq.Moment(cirq.measure(q1, key='a'), cirq.X(q2)),
185+
cirq.Z(q2).with_classical_controls('a'),
186+
)
187+
c_orig = cirq.Circuit(
188+
c_nested,
189+
cirq.CircuitOperation(c_nested).repeat(6).with_tags("nocompile"),
190+
c_nested,
191+
cirq.CircuitOperation(c_nested).repeat(5).with_tags("preserve_tag"),
192+
)
193+
c_expected = cirq.Circuit(
194+
c_nested_aligned,
195+
cirq.CircuitOperation(c_nested).repeat(6).with_tags("nocompile"),
196+
cirq.Moment(),
197+
c_nested_aligned,
198+
cirq.CircuitOperation(c_nested_aligned).repeat(5).with_tags("preserve_tag"),
199+
)
200+
context = cirq.TransformerContext(tags_to_ignore=["nocompile"], deep=True)
201+
cirq.testing.assert_same_circuits(cirq.align_right(c_orig, context=context), c_expected)
202+
203+
136204
def test_classical_control():
137205
q0, q1 = cirq.LineQubit.range(2)
138206
circuit = cirq.Circuit(

Diff for: cirq-core/cirq/transformers/transformer_api.py

+12
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from typing import (
2323
cast,
2424
Any,
25+
Callable,
2526
Tuple,
2627
Hashable,
2728
List,
@@ -30,6 +31,7 @@
3031
Type,
3132
TYPE_CHECKING,
3233
TypeVar,
34+
Union,
3335
)
3436
from typing_extensions import Protocol
3537

@@ -235,6 +237,16 @@ def __call__(
235237

236238
_TRANSFORMER_T = TypeVar('_TRANSFORMER_T', bound=TRANSFORMER)
237239
_TRANSFORMER_CLS_T = TypeVar('_TRANSFORMER_CLS_T', bound=Type[TRANSFORMER])
240+
_TRANSFORMER_OR_CLS_T = TypeVar(
241+
'_TRANSFORMER_OR_CLS_T', bound=Union[TRANSFORMER, Type[TRANSFORMER]]
242+
)
243+
244+
245+
@overload
246+
def transformer(
247+
*, add_support_for_deep: bool = False
248+
) -> Callable[[_TRANSFORMER_OR_CLS_T], _TRANSFORMER_OR_CLS_T]:
249+
pass
238250

239251

240252
@overload

Diff for: cirq-core/cirq/transformers/transformer_api_test.py

+3-8
Original file line numberDiff line numberDiff line change
@@ -89,11 +89,9 @@ def make_transformer_func(add_support_for_deep: bool = False) -> cirq.TRANSFORME
8989
def mock_tranformer_func(
9090
circuit: cirq.AbstractCircuit, *, context: Optional[cirq.TransformerContext] = None
9191
) -> cirq.Circuit:
92-
print("CALLED:", circuit, context, sep="\n\n----------\n\n")
9392
my_mock(circuit, context)
9493
return circuit.unfreeze()
9594

96-
# mock_transformer = mock_tranformer_func(add_support_for_deep)
9795
mock_tranformer_func.mock = my_mock # type: ignore
9896
return mock_tranformer_func
9997

@@ -141,9 +139,6 @@ def test_transformer_decorator_with_defaults(transformer):
141139
transformer.mock.assert_called_with(circuit, context, 1e-2, CustomArg(12))
142140

143141

144-
# from unittest.mock import patch, call
145-
146-
147142
@pytest.mark.parametrize(
148143
'transformer, supports_deep',
149144
[
@@ -154,9 +149,9 @@ def test_transformer_decorator_with_defaults(transformer):
154149
],
155150
)
156151
def test_transformer_decorator_adds_support_for_deep(transformer, supports_deep):
157-
q1, q2 = cirq.LineQubit.range(2)
158-
c_nested_x = cirq.FrozenCircuit(cirq.X(q1))
159-
c_nested_y = cirq.FrozenCircuit(cirq.Y(q1))
152+
q = cirq.NamedQubit("q")
153+
c_nested_x = cirq.FrozenCircuit(cirq.X(q))
154+
c_nested_y = cirq.FrozenCircuit(cirq.Y(q))
160155
c_nested_xy = cirq.FrozenCircuit(
161156
cirq.CircuitOperation(c_nested_x).repeat(5).with_tags("ignore"),
162157
cirq.CircuitOperation(c_nested_y).repeat(7).with_tags("preserve_tag"),

0 commit comments

Comments
 (0)