Skip to content

Commit 6fba8be

Browse files
Add handling for sympy conditions in deferred measurement transformer (#5824)
* Add handling for sympy conditions in deferred measurement transformer * docstring * mypy * mypy * cover * Make this more generic, covers all kinds of conditions. * Better docs * Sympy can also be CX * docs * docs * Add mixed tests, simplify loop, add simplification in ControlledGate * Fix error message * Simplify error message * Inline variable * fix merge * qudit sympy test * fix build * Fix test * Fix test * Remove need for ControlledGate change * mypy, comment * nits Co-authored-by: Tanuj Khattar <[email protected]>
1 parent c485504 commit 6fba8be

File tree

2 files changed

+242
-25
lines changed

2 files changed

+242
-25
lines changed

cirq-core/cirq/transformers/measurement_transformers.py

Lines changed: 68 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,18 @@
1313
# limitations under the License.
1414

1515
import itertools
16-
from typing import Any, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union
16+
from typing import (
17+
Any,
18+
Dict,
19+
Iterable,
20+
List,
21+
Mapping,
22+
Optional,
23+
Sequence,
24+
Tuple,
25+
TYPE_CHECKING,
26+
Union,
27+
)
1728

1829
import numpy as np
1930

@@ -85,7 +96,6 @@ def defer_measurements(
8596
A circuit with equivalent logic, but all measurements at the end of the
8697
circuit.
8798
Raises:
88-
ValueError: If sympy-based classical conditions are used.
8999
NotImplementedError: When attempting to defer a measurement with a
90100
confusion map. (https://github.com/quantumlib/Cirq/issues/5482)
91101
"""
@@ -109,25 +119,34 @@ def defer(op: 'cirq.Operation', _) -> 'cirq.OP_TREE':
109119
)
110120
for indexes, m in gate.confusion_map.items()
111121
]
112-
cxs = [_mod_add(q, target) for q, target in zip(op.qubits, targets)]
113122
xs = [ops.X(targets[i]) for i, b in enumerate(gate.full_invert_mask()) if b]
114123
return cxs + confusions + xs
115124
elif protocols.is_measurement(op):
116125
return [defer(op, None) for op in protocols.decompose_once(op)]
117126
elif op.classical_controls:
118-
new_op = op.without_classical_controls()
119-
for c in op.classical_controls:
120-
if isinstance(c, value.KeyCondition):
121-
if c.key not in measurement_qubits:
122-
raise ValueError(f'Deferred measurement for key={c.key} not found.')
123-
qs = measurement_qubits[c.key]
124-
all_values = itertools.product(*[range(q.dimension) for q in qs])
125-
anything_but_all_zeros = tuple(itertools.islice(all_values, 1, None))
126-
control_values = ops.SumOfProducts(anything_but_all_zeros)
127-
new_op = new_op.controlled_by(*qs, control_values=control_values)
128-
else:
129-
raise ValueError('Only KeyConditions are allowed.')
130-
return new_op
127+
# Convert to a quantum control
128+
keys = sorted(set(key for c in op.classical_controls for key in c.keys))
129+
for key in keys:
130+
if key not in measurement_qubits:
131+
raise ValueError(f'Deferred measurement for key={key} not found.')
132+
133+
# Try every possible datastore state (exponential in the number of keys) against the
134+
# condition, and the ones that work are the control values for the new op.
135+
compatible_datastores = [
136+
store
137+
for store in _all_possible_datastore_states(keys, measurement_qubits)
138+
if all(c.resolve(store) for c in op.classical_controls)
139+
]
140+
141+
# Rearrange these into the format expected by SumOfProducts
142+
products = [
143+
[i for key in keys for i in store.records[key][0]]
144+
for store in compatible_datastores
145+
]
146+
147+
control_values = ops.SumOfProducts(products)
148+
qs = [q for key in keys for q in measurement_qubits[key]]
149+
return op.without_classical_controls().controlled_by(*qs, control_values=control_values)
131150
return op
132151

133152
circuit = transformer_primitives.map_operations_and_unroll(
@@ -141,6 +160,39 @@ def defer(op: 'cirq.Operation', _) -> 'cirq.OP_TREE':
141160
return circuit
142161

143162

163+
def _all_possible_datastore_states(
164+
keys: Iterable['cirq.MeasurementKey'],
165+
measurement_qubits: Mapping['cirq.MeasurementKey', Iterable['cirq.Qid']],
166+
) -> Iterable['cirq.ClassicalDataStoreReader']:
167+
"""The cartesian product of all possible DataStore states for the given keys."""
168+
# First we get the list of all possible values. So if we have a key mapped to qubits of shape
169+
# (2, 2) and a key mapped to a qutrit, the possible measurement values are:
170+
# [((0, 0), (0,)),
171+
# ((0, 0), (1,)),
172+
# ((0, 0), (2,)),
173+
# ((0, 1), (0,)),
174+
# ((0, 1), (1,)),
175+
# ((0, 1), (2,)),
176+
# ((1, 0), (0,)),
177+
# ((1, 0), (1,)),
178+
# ((1, 0), (2,)),
179+
# ((1, 1), (0,)),
180+
# ((1, 1), (1,)),
181+
# ((1, 1), (2,))]
182+
all_values = itertools.product(
183+
*[
184+
tuple(itertools.product(*[range(q.dimension) for q in measurement_qubits[k]]))
185+
for k in keys
186+
]
187+
)
188+
# Then we create the ClassicalDataDictionaryStore for each of the above.
189+
for sequences in all_values:
190+
lookup = {k: [sequence] for k, sequence in zip(keys, sequences)}
191+
yield value.ClassicalDataDictionaryStore(
192+
_records=lookup, _measured_qubits={k: [tuple(measurement_qubits[k])] for k in keys}
193+
)
194+
195+
144196
@transformer_api.transformer
145197
def dephase_measurements(
146198
circuit: 'cirq.AbstractCircuit',

cirq-core/cirq/transformers/measurement_transformers_test.py

Lines changed: 174 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import numpy as np
1616
import pytest
1717
import sympy
18+
from sympy.parsing import sympy_parser
1819

1920
import cirq
2021
from cirq.transformers.measurement_transformers import _ConfusionChannel, _MeasurementQid, _mod_add
@@ -79,6 +80,179 @@ def test_qudits():
7980
)
8081

8182

83+
def test_sympy_control():
84+
q0, q1 = cirq.LineQubit.range(2)
85+
circuit = cirq.Circuit(
86+
cirq.measure(q0, key='a'),
87+
cirq.X(q1).with_classical_controls(sympy.Symbol('a')),
88+
cirq.measure(q1, key='b'),
89+
)
90+
assert_equivalent_to_deferred(circuit)
91+
deferred = cirq.defer_measurements(circuit)
92+
q_ma = _MeasurementQid('a', q0)
93+
cirq.testing.assert_same_circuits(
94+
deferred,
95+
cirq.Circuit(
96+
cirq.CX(q0, q_ma),
97+
cirq.CX(q_ma, q1),
98+
cirq.measure(q_ma, key='a'),
99+
cirq.measure(q1, key='b'),
100+
),
101+
)
102+
103+
104+
def test_sympy_qudits():
105+
q0, q1 = cirq.LineQid.range(2, dimension=3)
106+
circuit = cirq.Circuit(
107+
cirq.measure(q0, key='a'),
108+
cirq.XPowGate(dimension=3).on(q1).with_classical_controls(sympy.Symbol('a')),
109+
cirq.measure(q1, key='b'),
110+
)
111+
assert_equivalent_to_deferred(circuit)
112+
deferred = cirq.defer_measurements(circuit)
113+
q_ma = _MeasurementQid('a', q0)
114+
cirq.testing.assert_same_circuits(
115+
deferred,
116+
cirq.Circuit(
117+
_mod_add(q0, q_ma),
118+
cirq.XPowGate(dimension=3).on(q1).controlled_by(q_ma, control_values=[[1, 2]]),
119+
cirq.measure(q_ma, key='a'),
120+
cirq.measure(q1, key='b'),
121+
),
122+
)
123+
124+
125+
def test_sympy_control_complex():
126+
q0, q1, q2 = cirq.LineQubit.range(3)
127+
circuit = cirq.Circuit(
128+
cirq.measure(q0, key='a'),
129+
cirq.measure(q1, key='b'),
130+
cirq.X(q2).with_classical_controls(sympy_parser.parse_expr('a >= b')),
131+
cirq.measure(q2, key='c'),
132+
)
133+
assert_equivalent_to_deferred(circuit)
134+
deferred = cirq.defer_measurements(circuit)
135+
q_ma = _MeasurementQid('a', q0)
136+
q_mb = _MeasurementQid('b', q1)
137+
cirq.testing.assert_same_circuits(
138+
deferred,
139+
cirq.Circuit(
140+
cirq.CX(q0, q_ma),
141+
cirq.CX(q1, q_mb),
142+
cirq.ControlledOperation(
143+
[q_ma, q_mb], cirq.X(q2), cirq.SumOfProducts([[0, 0], [1, 0], [1, 1]])
144+
),
145+
cirq.measure(q_ma, key='a'),
146+
cirq.measure(q_mb, key='b'),
147+
cirq.measure(q2, key='c'),
148+
),
149+
)
150+
151+
152+
def test_sympy_control_complex_qudit():
153+
q0, q1, q2 = cirq.LineQid.for_qid_shape((4, 2, 2))
154+
circuit = cirq.Circuit(
155+
cirq.measure(q0, key='a'),
156+
cirq.measure(q1, key='b'),
157+
cirq.X(q2).with_classical_controls(sympy_parser.parse_expr('a > b')),
158+
cirq.measure(q2, key='c'),
159+
)
160+
assert_equivalent_to_deferred(circuit)
161+
deferred = cirq.defer_measurements(circuit)
162+
q_ma = _MeasurementQid('a', q0)
163+
q_mb = _MeasurementQid('b', q1)
164+
cirq.testing.assert_same_circuits(
165+
deferred,
166+
cirq.Circuit(
167+
_mod_add(q0, q_ma),
168+
cirq.CX(q1, q_mb),
169+
cirq.ControlledOperation(
170+
[q_ma, q_mb],
171+
cirq.X(q2),
172+
cirq.SumOfProducts([[1, 0], [2, 0], [3, 0], [2, 1], [3, 1]]),
173+
),
174+
cirq.measure(q_ma, key='a'),
175+
cirq.measure(q_mb, key='b'),
176+
cirq.measure(q2, key='c'),
177+
),
178+
)
179+
180+
181+
def test_multiple_sympy_control_complex():
182+
q0, q1, q2 = cirq.LineQubit.range(3)
183+
circuit = cirq.Circuit(
184+
cirq.measure(q0, key='a'),
185+
cirq.measure(q1, key='b'),
186+
cirq.X(q2)
187+
.with_classical_controls(sympy_parser.parse_expr('a >= b'))
188+
.with_classical_controls(sympy_parser.parse_expr('a <= b')),
189+
cirq.measure(q2, key='c'),
190+
)
191+
assert_equivalent_to_deferred(circuit)
192+
deferred = cirq.defer_measurements(circuit)
193+
q_ma = _MeasurementQid('a', q0)
194+
q_mb = _MeasurementQid('b', q1)
195+
cirq.testing.assert_same_circuits(
196+
deferred,
197+
cirq.Circuit(
198+
cirq.CX(q0, q_ma),
199+
cirq.CX(q1, q_mb),
200+
cirq.ControlledOperation(
201+
[q_ma, q_mb], cirq.X(q2), cirq.SumOfProducts([[0, 0], [1, 1]])
202+
),
203+
cirq.measure(q_ma, key='a'),
204+
cirq.measure(q_mb, key='b'),
205+
cirq.measure(q2, key='c'),
206+
),
207+
)
208+
209+
210+
def test_sympy_and_key_control():
211+
q0, q1 = cirq.LineQubit.range(2)
212+
circuit = cirq.Circuit(
213+
cirq.measure(q0, key='a'),
214+
cirq.X(q1).with_classical_controls(sympy.Symbol('a')).with_classical_controls('a'),
215+
cirq.measure(q1, key='b'),
216+
)
217+
assert_equivalent_to_deferred(circuit)
218+
deferred = cirq.defer_measurements(circuit)
219+
q_ma = _MeasurementQid('a', q0)
220+
cirq.testing.assert_same_circuits(
221+
deferred,
222+
cirq.Circuit(
223+
cirq.CX(q0, q_ma),
224+
cirq.CX(q_ma, q1),
225+
cirq.measure(q_ma, key='a'),
226+
cirq.measure(q1, key='b'),
227+
),
228+
)
229+
230+
231+
def test_sympy_control_multiqubit():
232+
q0, q1, q2 = cirq.LineQubit.range(3)
233+
circuit = cirq.Circuit(
234+
cirq.measure(q0, q1, key='a'),
235+
cirq.X(q2).with_classical_controls(sympy_parser.parse_expr('a >= 2')),
236+
cirq.measure(q2, key='c'),
237+
)
238+
assert_equivalent_to_deferred(circuit)
239+
deferred = cirq.defer_measurements(circuit)
240+
q_ma0 = _MeasurementQid('a', q0)
241+
q_ma1 = _MeasurementQid('a', q1)
242+
cirq.testing.assert_same_circuits(
243+
deferred,
244+
cirq.Circuit(
245+
cirq.CX(q0, q_ma0),
246+
cirq.CX(q1, q_ma1),
247+
cirq.ControlledOperation(
248+
[q_ma0, q_ma1], cirq.X(q2), cirq.SumOfProducts([[1, 0], [1, 1]])
249+
),
250+
cirq.measure(q_ma0, q_ma1, key='a'),
251+
cirq.measure(q2, key='c'),
252+
),
253+
)
254+
255+
82256
def test_nocompile_context():
83257
q0, q1 = cirq.LineQubit.range(2)
84258
circuit = cirq.Circuit(
@@ -316,15 +490,6 @@ def test_repr(qid: _MeasurementQid):
316490
test_repr(_MeasurementQid('0:1:a', cirq.LineQid(9, 4)))
317491

318492

319-
def test_sympy_control():
320-
q0, q1 = cirq.LineQubit.range(2)
321-
circuit = cirq.Circuit(
322-
cirq.measure(q0, q1, key='a'), cirq.X(q1).with_classical_controls(sympy.Symbol('a'))
323-
)
324-
with pytest.raises(ValueError, match='Only KeyConditions are allowed'):
325-
_ = cirq.defer_measurements(circuit)
326-
327-
328493
def test_confusion_map():
329494
q0, q1 = cirq.LineQubit.range(2)
330495
circuit = cirq.Circuit(

0 commit comments

Comments
 (0)