Skip to content

Commit 99e8a13

Browse files
authored
Ensure that cirq.decompose traverses the yielded OP-TREE in dfs ordering (#6117)
1 parent b1e09a9 commit 99e8a13

7 files changed

+84
-26
lines changed

cirq-core/cirq/ops/classically_controlled_operation.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
import sympy
2929

3030
from cirq import protocols, value
31-
from cirq.ops import raw_types
31+
from cirq.ops import op_tree, raw_types
3232

3333
if TYPE_CHECKING:
3434
import cirq
@@ -105,11 +105,13 @@ def with_qubits(self, *new_qubits):
105105
)
106106

107107
def _decompose_(self):
108-
result = protocols.decompose_once(self._sub_operation, NotImplemented)
108+
result = protocols.decompose_once(self._sub_operation, NotImplemented, flatten=False)
109109
if result is NotImplemented:
110110
return NotImplemented
111111

112-
return [ClassicallyControlledOperation(op, self._conditions) for op in result]
112+
return op_tree.transform_op_tree(
113+
result, lambda op: ClassicallyControlledOperation(op, self._conditions)
114+
)
113115

114116
def _value_equality_values_(self):
115117
return (frozenset(self._conditions), self._sub_operation)

cirq-core/cirq/ops/controlled_gate.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,13 @@
2828
import numpy as np
2929

3030
from cirq import protocols, value, _import
31-
from cirq.ops import raw_types, controlled_operation as cop, matrix_gates, control_values as cv
31+
from cirq.ops import (
32+
raw_types,
33+
controlled_operation as cop,
34+
op_tree,
35+
matrix_gates,
36+
control_values as cv,
37+
)
3238
from cirq.type_workarounds import NotImplementedType
3339

3440
if TYPE_CHECKING:
@@ -194,17 +200,17 @@ def _decompose_(
194200
return NotImplemented
195201

196202
result = protocols.decompose_once_with_qubits(
197-
self.sub_gate, qubits[self.num_controls() :], NotImplemented
203+
self.sub_gate, qubits[self.num_controls() :], NotImplemented, flatten=False
198204
)
199205
if result is NotImplemented:
200206
return NotImplemented
201207

202-
decomposed: List['cirq.Operation'] = []
203-
for op in result:
204-
decomposed.append(
205-
op.controlled_by(*qubits[: self.num_controls()], control_values=self.control_values)
206-
)
207-
return decomposed
208+
return op_tree.transform_op_tree(
209+
result,
210+
lambda op: op.controlled_by(
211+
*qubits[: self.num_controls()], control_values=self.control_values
212+
),
213+
)
208214

209215
def on(self, *qubits: 'cirq.Qid') -> cop.ControlledOperation:
210216
if len(qubits) == 0:

cirq-core/cirq/ops/controlled_operation.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
eigen_gate,
3535
gate_operation,
3636
matrix_gates,
37+
op_tree,
3738
raw_types,
3839
control_values as cv,
3940
)
@@ -145,7 +146,9 @@ def with_qubits(self, *new_qubits):
145146
)
146147

147148
def _decompose_(self):
148-
result = protocols.decompose_once_with_qubits(self.gate, self.qubits, NotImplemented)
149+
result = protocols.decompose_once_with_qubits(
150+
self.gate, self.qubits, NotImplemented, flatten=False
151+
)
149152
if result is not NotImplemented:
150153
return result
151154

@@ -154,13 +157,13 @@ def _decompose_(self):
154157
# local phase in the controlled variant and hence cannot be ignored.
155158
return NotImplemented
156159

157-
result = protocols.decompose_once(self.sub_operation, NotImplemented)
160+
result = protocols.decompose_once(self.sub_operation, NotImplemented, flatten=False)
158161
if result is NotImplemented:
159162
return NotImplemented
160163

161-
return [
162-
op.controlled_by(*self.controls, control_values=self.control_values) for op in result
163-
]
164+
return op_tree.transform_op_tree(
165+
result, lambda op: op.controlled_by(*self.controls, control_values=self.control_values)
166+
)
164167

165168
def _value_equality_values_(self):
166169
sorted_controls, expanded_cvals = tuple(

cirq-core/cirq/ops/gate_operation.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,9 @@ def _num_qubits_(self):
160160
return len(self._qubits)
161161

162162
def _decompose_(self) -> 'cirq.OP_TREE':
163-
return protocols.decompose_once_with_qubits(self.gate, self.qubits, NotImplemented)
163+
return protocols.decompose_once_with_qubits(
164+
self.gate, self.qubits, NotImplemented, flatten=False
165+
)
164166

165167
def _pauli_expansion_(self) -> value.LinearDict[str]:
166168
getter = getattr(self.gate, '_pauli_expansion_', None)

cirq-core/cirq/ops/raw_types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -830,7 +830,7 @@ def _json_dict_(self) -> Dict[str, Any]:
830830
return protocols.obj_to_dict_helper(self, ['sub_operation', 'tags'])
831831

832832
def _decompose_(self) -> 'cirq.OP_TREE':
833-
return protocols.decompose_once(self.sub_operation, default=None)
833+
return protocols.decompose_once(self.sub_operation, default=None, flatten=False)
834834

835835
def _pauli_expansion_(self) -> value.LinearDict[str]:
836836
return protocols.pauli_expansion(self.sub_operation)

cirq-core/cirq/protocols/decompose_protocol.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ def _decompose_dfs(item: Any, args: _DecomposeArgs) -> Iterator['cirq.Operation'
160160
decomposed = _try_op_decomposer(item, args.intercepting_decomposer)
161161

162162
if decomposed is NotImplemented or decomposed is None:
163-
decomposed = decompose_once(item, default=None)
163+
decomposed = decompose_once(item, default=None, flatten=False)
164164

165165
if decomposed is NotImplemented or decomposed is None:
166166
decomposed = _try_op_decomposer(item, args.fallback_decomposer)
@@ -275,12 +275,14 @@ def decompose_once(val: Any, **kwargs) -> List['cirq.Operation']:
275275

276276
@overload
277277
def decompose_once(
278-
val: Any, default: TDefault, *args, **kwargs
278+
val: Any, default: TDefault, *args, flatten: bool = True, **kwargs
279279
) -> Union[TDefault, List['cirq.Operation']]:
280280
pass
281281

282282

283-
def decompose_once(val: Any, default=RaiseTypeErrorIfNotProvided, *args, **kwargs):
283+
def decompose_once(
284+
val: Any, default=RaiseTypeErrorIfNotProvided, *args, flatten: bool = True, **kwargs
285+
):
284286
"""Decomposes a value into operations, if possible.
285287
286288
This method decomposes the value exactly once, instead of decomposing it
@@ -296,6 +298,7 @@ def decompose_once(val: Any, default=RaiseTypeErrorIfNotProvided, *args, **kwarg
296298
*args: Positional arguments to forward into the `_decompose_` method of
297299
`val`. For example, this is used to tell gates what qubits they are
298300
being applied to.
301+
flatten: If True, the returned OP-TREE will be flattened to a list of operations.
299302
**kwargs: Keyword arguments to forward into the `_decompose_` method of
300303
`val`.
301304
@@ -311,9 +314,8 @@ def decompose_once(val: Any, default=RaiseTypeErrorIfNotProvided, *args, **kwarg
311314
"""
312315
method = getattr(val, '_decompose_', None)
313316
decomposed = NotImplemented if method is None else method(*args, **kwargs)
314-
315317
if decomposed is not NotImplemented and decomposed is not None:
316-
return list(ops.flatten_op_tree(decomposed))
318+
return list(ops.flatten_to_ops(decomposed)) if flatten else decomposed
317319

318320
if default is not RaiseTypeErrorIfNotProvided:
319321
return default
@@ -332,13 +334,16 @@ def decompose_once_with_qubits(val: Any, qubits: Iterable['cirq.Qid']) -> List['
332334

333335
@overload
334336
def decompose_once_with_qubits(
335-
val: Any, qubits: Iterable['cirq.Qid'], default: Optional[TDefault]
337+
val: Any, qubits: Iterable['cirq.Qid'], default: Optional[TDefault], flatten: bool = True
336338
) -> Union[TDefault, List['cirq.Operation']]:
337339
pass
338340

339341

340342
def decompose_once_with_qubits(
341-
val: Any, qubits: Iterable['cirq.Qid'], default=RaiseTypeErrorIfNotProvided
343+
val: Any,
344+
qubits: Iterable['cirq.Qid'],
345+
default=RaiseTypeErrorIfNotProvided,
346+
flatten: bool = True,
342347
):
343348
"""Decomposes a value into operations on the given qubits.
344349
@@ -355,6 +360,7 @@ def decompose_once_with_qubits(
355360
`_decompose_` method or that method returns `NotImplemented` or
356361
`None`. If not specified, non-decomposable values cause a
357362
`TypeError`.
363+
flatten: If True, the returned OP-TREE will be flattened to a list of operations.
358364
359365
Returns:
360366
The result of `val._decompose_(qubits)`, if `val` has a
@@ -366,7 +372,7 @@ def decompose_once_with_qubits(
366372
`val` didn't have a `_decompose_` method (or that method returned
367373
`NotImplemented` or `None`) and `default` wasn't set.
368374
"""
369-
return decompose_once(val, default, tuple(qubits))
375+
return decompose_once(val, default, tuple(qubits), flatten=flatten)
370376

371377

372378
# pylint: enable=function-redefined

cirq-core/cirq/protocols/decompose_protocol_test.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
from typing import Optional
15+
from unittest import mock
1416
import pytest
1517

1618
import cirq
@@ -308,3 +310,40 @@ def test_decompose_tagged_operation():
308310
'tag',
309311
)
310312
assert cirq.decompose_once(op) == cirq.decompose_once(op.untagged)
313+
314+
315+
def test_decompose_recursive_dfs():
316+
class RecursiveDecompose(cirq.Gate):
317+
def __init__(self, recurse: bool = True, mock_qm: Optional[mock.Mock] = None):
318+
self.recurse = recurse
319+
self.mock_qm = mock.Mock() if mock_qm is None else mock_qm
320+
321+
def _num_qubits_(self) -> int:
322+
return 2
323+
324+
def _decompose_(self, qubits):
325+
self.mock_qm.qalloc(self.recurse)
326+
yield RecursiveDecompose(recurse=False, mock_qm=self.mock_qm).on(
327+
*qubits
328+
) if self.recurse else cirq.Z.on_each(*qubits)
329+
self.mock_qm.qfree(self.recurse)
330+
331+
def _has_unitary_(self):
332+
return True
333+
334+
expected_calls = [
335+
mock.call.qalloc(True),
336+
mock.call.qalloc(False),
337+
mock.call.qfree(False),
338+
mock.call.qfree(True),
339+
]
340+
mock_qm = mock.Mock(spec=["qalloc", "qfree"])
341+
q = cirq.LineQubit.range(3)
342+
gate = RecursiveDecompose(mock_qm=mock_qm)
343+
gate_op = gate.on(*q[:2])
344+
controlled_op = gate_op.controlled_by(q[2])
345+
classically_controlled_op = gate_op.with_classical_controls('key')
346+
for op in [gate_op, controlled_op, classically_controlled_op]:
347+
mock_qm.reset_mock()
348+
_ = cirq.decompose(op)
349+
assert mock_qm.method_calls == expected_calls

0 commit comments

Comments
 (0)