|
12 | 12 | # See the License for the specific language governing permissions and
|
13 | 13 | # limitations under the License.
|
14 | 14 |
|
| 15 | +import itertools |
15 | 16 | from typing import Any, Dict, List, Optional, TYPE_CHECKING, Union
|
16 | 17 |
|
17 | 18 | from cirq import ops, protocols, value
|
@@ -81,9 +82,7 @@ def defer_measurements(
|
81 | 82 | A circuit with equivalent logic, but all measurements at the end of the
|
82 | 83 | circuit.
|
83 | 84 | Raises:
|
84 |
| - ValueError: If sympy-based classical conditions are used, or if |
85 |
| - conditions based on multi-qubit measurements exist. (The latter of |
86 |
| - these is planned to be implemented soon). |
| 85 | + ValueError: If sympy-based classical conditions are used. |
87 | 86 | NotImplementedError: When attempting to defer a measurement with a
|
88 | 87 | confusion map. (https://github.com/quantumlib/Cirq/issues/5482)
|
89 | 88 | """
|
@@ -111,23 +110,22 @@ def defer(op: 'cirq.Operation', _) -> 'cirq.OP_TREE':
|
111 | 110 | elif protocols.is_measurement(op):
|
112 | 111 | return [defer(op, None) for op in protocols.decompose_once(op)]
|
113 | 112 | elif op.classical_controls:
|
114 |
| - controls = [] |
| 113 | + new_op = op.without_classical_controls() |
115 | 114 | for c in op.classical_controls:
|
116 | 115 | if isinstance(c, value.KeyCondition):
|
117 | 116 | if c.key not in measurement_qubits:
|
118 | 117 | raise ValueError(f'Deferred measurement for key={c.key} not found.')
|
119 |
| - qubits = measurement_qubits[c.key] |
120 |
| - if len(qubits) != 1: |
121 |
| - # TODO: Multi-qubit conditions require |
122 |
| - # https://github.com/quantumlib/Cirq/issues/4512 |
123 |
| - # Remember to update docstring above once this works. |
124 |
| - raise ValueError('Only single qubit conditions are allowed.') |
125 |
| - controls.extend(qubits) |
| 118 | + qs = measurement_qubits[c.key] |
| 119 | + if len(qs) == 1: |
| 120 | + control_values: Any = range(1, qs[0].dimension) |
| 121 | + else: |
| 122 | + all_values = itertools.product(*[range(q.dimension) for q in qs]) |
| 123 | + anything_but_all_zeros = tuple(itertools.islice(all_values, 1, None)) |
| 124 | + control_values = ops.SumOfProducts(anything_but_all_zeros) |
| 125 | + new_op = new_op.controlled_by(*qs, control_values=control_values) |
126 | 126 | else:
|
127 | 127 | raise ValueError('Only KeyConditions are allowed.')
|
128 |
| - return op.without_classical_controls().controlled_by( |
129 |
| - *controls, control_values=[tuple(range(1, q.dimension)) for q in controls] |
130 |
| - ) |
| 128 | + return new_op |
131 | 129 | return op
|
132 | 130 |
|
133 | 131 | circuit = transformer_primitives.map_operations_and_unroll(
|
|
0 commit comments