Skip to content

Add handling for sympy conditions in deferred measurement transformer #5824

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 31 commits into from
Nov 4, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
b4849ca
Add handling for sympy conditions in deferred measurement transformer
daxfohl Aug 12, 2022
c88df71
docstring
daxfohl Aug 12, 2022
d9f2776
mypy
daxfohl Aug 12, 2022
48c5736
mypy
daxfohl Aug 12, 2022
17f5e80
cover
daxfohl Aug 12, 2022
d74a699
Make this more generic, covers all kinds of conditions.
daxfohl Aug 12, 2022
d37cbad
Better docs
daxfohl Aug 13, 2022
2c1d003
Sympy can also be CX
daxfohl Aug 13, 2022
27ff43d
docs
daxfohl Aug 13, 2022
78ef78d
docs
daxfohl Aug 13, 2022
4dcb564
Merge branch 'master' into sympy-deferred
tanujkhattar Aug 25, 2022
6ffe765
Merge branch 'master' into sympy-deferred
tanujkhattar Aug 25, 2022
ed1257d
Merge branch 'master' into sympy-deferred
daxfohl Sep 5, 2022
6bcc71c
Add mixed tests, simplify loop, add simplification in ControlledGate
daxfohl Sep 5, 2022
8f045bc
Fix error message
daxfohl Sep 5, 2022
1c32404
Simplify error message
daxfohl Sep 6, 2022
9fd971b
Inline variable
daxfohl Sep 14, 2022
4a484b0
Merge branch 'master' into sympy-deferred
daxfohl Sep 21, 2022
30b9121
Merge branch 'master' into sympy-deferred
daxfohl Oct 11, 2022
9d1f5ef
fix merge
daxfohl Oct 11, 2022
d4c80b9
qudit sympy test
daxfohl Oct 11, 2022
2d1cabf
Merge branch 'master' into sympy-deferred
daxfohl Oct 12, 2022
f033b39
fix build
daxfohl Oct 13, 2022
72388ce
Merge branch 'master' into sympy-deferred
daxfohl Oct 13, 2022
8e8dfc1
Fix test
daxfohl Oct 16, 2022
e733e89
Fix test
daxfohl Oct 16, 2022
3a6c750
Remove need for ControlledGate change
daxfohl Oct 28, 2022
1846525
mypy, comment
daxfohl Oct 28, 2022
05b22d3
Merge branch 'master' into sympy-deferred
daxfohl Oct 31, 2022
ed2c020
nits
daxfohl Nov 4, 2022
4b47de2
Merge branch 'master' into sympy-deferred
daxfohl Nov 4, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 68 additions & 16 deletions cirq-core/cirq/transformers/measurement_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,18 @@
# limitations under the License.

import itertools
from typing import Any, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union
from typing import (
Any,
Dict,
Iterable,
List,
Mapping,
Optional,
Sequence,
Tuple,
TYPE_CHECKING,
Union,
)

import numpy as np

Expand Down Expand Up @@ -85,7 +96,6 @@ def defer_measurements(
A circuit with equivalent logic, but all measurements at the end of the
circuit.
Raises:
ValueError: If sympy-based classical conditions are used.
NotImplementedError: When attempting to defer a measurement with a
confusion map. (https://github.com/quantumlib/Cirq/issues/5482)
"""
Expand All @@ -109,25 +119,34 @@ def defer(op: 'cirq.Operation', _) -> 'cirq.OP_TREE':
)
for indexes, m in gate.confusion_map.items()
]
cxs = [_mod_add(q, target) for q, target in zip(op.qubits, targets)]
xs = [ops.X(targets[i]) for i, b in enumerate(gate.full_invert_mask()) if b]
return cxs + confusions + xs
elif protocols.is_measurement(op):
return [defer(op, None) for op in protocols.decompose_once(op)]
elif op.classical_controls:
new_op = op.without_classical_controls()
for c in op.classical_controls:
if isinstance(c, value.KeyCondition):
if c.key not in measurement_qubits:
raise ValueError(f'Deferred measurement for key={c.key} not found.')
qs = measurement_qubits[c.key]
all_values = itertools.product(*[range(q.dimension) for q in qs])
anything_but_all_zeros = tuple(itertools.islice(all_values, 1, None))
control_values = ops.SumOfProducts(anything_but_all_zeros)
new_op = new_op.controlled_by(*qs, control_values=control_values)
else:
raise ValueError('Only KeyConditions are allowed.')
return new_op
# Convert to a quantum control
keys = sorted(set(key for c in op.classical_controls for key in c.keys))
for key in keys:
if key not in measurement_qubits:
raise ValueError(f'Deferred measurement for key={key} not found.')

# Try every possible datastore state (exponential in the number of keys) against the
# condition, and the ones that work are the control values for the new op.
compatible_datastores = [
store
for store in _all_possible_datastore_states(keys, measurement_qubits)
if all(c.resolve(store) for c in op.classical_controls)
]

# Rearrange these into the format expected by SumOfProducts
products = [
[i for key in keys for i in store.records[key][0]]
for store in compatible_datastores
]

control_values = ops.SumOfProducts(products)
qs = [q for key in keys for q in measurement_qubits[key]]
return op.without_classical_controls().controlled_by(*qs, control_values=control_values)
return op

circuit = transformer_primitives.map_operations_and_unroll(
Expand All @@ -141,6 +160,39 @@ def defer(op: 'cirq.Operation', _) -> 'cirq.OP_TREE':
return circuit


def _all_possible_datastore_states(
keys: Iterable['cirq.MeasurementKey'],
measurement_qubits: Mapping['cirq.MeasurementKey', Iterable['cirq.Qid']],
) -> Iterable['cirq.ClassicalDataStoreReader']:
"""The cartesian product of all possible DataStore states for the given keys."""
# First we get the list of all possible values. So if we have a key mapped to qubits of shape
# (2, 2) and a key mapped to a qutrit, the possible measurement values are:
# [((0, 0), (0,)),
# ((0, 0), (1,)),
# ((0, 0), (2,)),
# ((0, 1), (0,)),
# ((0, 1), (1,)),
# ((0, 1), (2,)),
# ((1, 0), (0,)),
# ((1, 0), (1,)),
# ((1, 0), (2,)),
# ((1, 1), (0,)),
# ((1, 1), (1,)),
# ((1, 1), (2,))]
all_values = itertools.product(
*[
tuple(itertools.product(*[range(q.dimension) for q in measurement_qubits[k]]))
for k in keys
]
)
# Then we create the ClassicalDataDictionaryStore for each of the above.
for sequences in all_values:
lookup = {k: [sequence] for k, sequence in zip(keys, sequences)}
yield value.ClassicalDataDictionaryStore(
_records=lookup, _measured_qubits={k: [tuple(measurement_qubits[k])] for k in keys}
)


@transformer_api.transformer
def dephase_measurements(
circuit: 'cirq.AbstractCircuit',
Expand Down
183 changes: 174 additions & 9 deletions cirq-core/cirq/transformers/measurement_transformers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import numpy as np
import pytest
import sympy
from sympy.parsing import sympy_parser

import cirq
from cirq.transformers.measurement_transformers import _ConfusionChannel, _MeasurementQid, _mod_add
Expand Down Expand Up @@ -79,6 +80,179 @@ def test_qudits():
)


def test_sympy_control():
q0, q1 = cirq.LineQubit.range(2)
circuit = cirq.Circuit(
cirq.measure(q0, key='a'),
cirq.X(q1).with_classical_controls(sympy.Symbol('a')),
cirq.measure(q1, key='b'),
)
assert_equivalent_to_deferred(circuit)
deferred = cirq.defer_measurements(circuit)
q_ma = _MeasurementQid('a', q0)
cirq.testing.assert_same_circuits(
deferred,
cirq.Circuit(
cirq.CX(q0, q_ma),
cirq.CX(q_ma, q1),
cirq.measure(q_ma, key='a'),
cirq.measure(q1, key='b'),
),
)


def test_sympy_qudits():
q0, q1 = cirq.LineQid.range(2, dimension=3)
circuit = cirq.Circuit(
cirq.measure(q0, key='a'),
cirq.XPowGate(dimension=3).on(q1).with_classical_controls(sympy.Symbol('a')),
cirq.measure(q1, key='b'),
)
assert_equivalent_to_deferred(circuit)
deferred = cirq.defer_measurements(circuit)
q_ma = _MeasurementQid('a', q0)
cirq.testing.assert_same_circuits(
deferred,
cirq.Circuit(
_mod_add(q0, q_ma),
cirq.XPowGate(dimension=3).on(q1).controlled_by(q_ma, control_values=[[1, 2]]),
cirq.measure(q_ma, key='a'),
cirq.measure(q1, key='b'),
),
)


def test_sympy_control_complex():
q0, q1, q2 = cirq.LineQubit.range(3)
circuit = cirq.Circuit(
cirq.measure(q0, key='a'),
cirq.measure(q1, key='b'),
cirq.X(q2).with_classical_controls(sympy_parser.parse_expr('a >= b')),
cirq.measure(q2, key='c'),
)
assert_equivalent_to_deferred(circuit)
deferred = cirq.defer_measurements(circuit)
q_ma = _MeasurementQid('a', q0)
q_mb = _MeasurementQid('b', q1)
cirq.testing.assert_same_circuits(
deferred,
cirq.Circuit(
cirq.CX(q0, q_ma),
cirq.CX(q1, q_mb),
cirq.ControlledOperation(
[q_ma, q_mb], cirq.X(q2), cirq.SumOfProducts([[0, 0], [1, 0], [1, 1]])
),
cirq.measure(q_ma, key='a'),
cirq.measure(q_mb, key='b'),
cirq.measure(q2, key='c'),
),
)


def test_sympy_control_complex_qudit():
q0, q1, q2 = cirq.LineQid.for_qid_shape((4, 2, 2))
circuit = cirq.Circuit(
cirq.measure(q0, key='a'),
cirq.measure(q1, key='b'),
cirq.X(q2).with_classical_controls(sympy_parser.parse_expr('a > b')),
cirq.measure(q2, key='c'),
)
assert_equivalent_to_deferred(circuit)
deferred = cirq.defer_measurements(circuit)
q_ma = _MeasurementQid('a', q0)
q_mb = _MeasurementQid('b', q1)
cirq.testing.assert_same_circuits(
deferred,
cirq.Circuit(
_mod_add(q0, q_ma),
cirq.CX(q1, q_mb),
cirq.ControlledOperation(
[q_ma, q_mb],
cirq.X(q2),
cirq.SumOfProducts([[1, 0], [2, 0], [3, 0], [2, 1], [3, 1]]),
),
cirq.measure(q_ma, key='a'),
cirq.measure(q_mb, key='b'),
cirq.measure(q2, key='c'),
),
)


def test_multiple_sympy_control_complex():
q0, q1, q2 = cirq.LineQubit.range(3)
circuit = cirq.Circuit(
cirq.measure(q0, key='a'),
cirq.measure(q1, key='b'),
cirq.X(q2)
.with_classical_controls(sympy_parser.parse_expr('a >= b'))
.with_classical_controls(sympy_parser.parse_expr('a <= b')),
cirq.measure(q2, key='c'),
)
assert_equivalent_to_deferred(circuit)
deferred = cirq.defer_measurements(circuit)
q_ma = _MeasurementQid('a', q0)
q_mb = _MeasurementQid('b', q1)
cirq.testing.assert_same_circuits(
deferred,
cirq.Circuit(
cirq.CX(q0, q_ma),
cirq.CX(q1, q_mb),
cirq.ControlledOperation(
[q_ma, q_mb], cirq.X(q2), cirq.SumOfProducts([[0, 0], [1, 1]])
),
cirq.measure(q_ma, key='a'),
cirq.measure(q_mb, key='b'),
cirq.measure(q2, key='c'),
),
)


def test_sympy_and_key_control():
q0, q1 = cirq.LineQubit.range(2)
circuit = cirq.Circuit(
cirq.measure(q0, key='a'),
cirq.X(q1).with_classical_controls(sympy.Symbol('a')).with_classical_controls('a'),
cirq.measure(q1, key='b'),
)
assert_equivalent_to_deferred(circuit)
deferred = cirq.defer_measurements(circuit)
q_ma = _MeasurementQid('a', q0)
cirq.testing.assert_same_circuits(
deferred,
cirq.Circuit(
cirq.CX(q0, q_ma),
cirq.CX(q_ma, q1),
cirq.measure(q_ma, key='a'),
cirq.measure(q1, key='b'),
),
)


def test_sympy_control_multiqubit():
q0, q1, q2 = cirq.LineQubit.range(3)
circuit = cirq.Circuit(
cirq.measure(q0, q1, key='a'),
cirq.X(q2).with_classical_controls(sympy_parser.parse_expr('a >= 2')),
cirq.measure(q2, key='c'),
)
assert_equivalent_to_deferred(circuit)
deferred = cirq.defer_measurements(circuit)
q_ma0 = _MeasurementQid('a', q0)
q_ma1 = _MeasurementQid('a', q1)
cirq.testing.assert_same_circuits(
deferred,
cirq.Circuit(
cirq.CX(q0, q_ma0),
cirq.CX(q1, q_ma1),
cirq.ControlledOperation(
[q_ma0, q_ma1], cirq.X(q2), cirq.SumOfProducts([[1, 0], [1, 1]])
),
cirq.measure(q_ma0, q_ma1, key='a'),
cirq.measure(q2, key='c'),
),
)


def test_nocompile_context():
q0, q1 = cirq.LineQubit.range(2)
circuit = cirq.Circuit(
Expand Down Expand Up @@ -316,15 +490,6 @@ def test_repr(qid: _MeasurementQid):
test_repr(_MeasurementQid('0:1:a', cirq.LineQid(9, 4)))


def test_sympy_control():
q0, q1 = cirq.LineQubit.range(2)
circuit = cirq.Circuit(
cirq.measure(q0, q1, key='a'), cirq.X(q1).with_classical_controls(sympy.Symbol('a'))
)
with pytest.raises(ValueError, match='Only KeyConditions are allowed'):
_ = cirq.defer_measurements(circuit)


def test_confusion_map():
q0, q1 = cirq.LineQubit.range(2)
circuit = cirq.Circuit(
Expand Down