Skip to content

Add handling for sympy conditions in deferred measurement transformer #5824

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 31 commits into from
Nov 4, 2022
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
b4849ca
Add handling for sympy conditions in deferred measurement transformer
daxfohl Aug 12, 2022
c88df71
docstring
daxfohl Aug 12, 2022
d9f2776
mypy
daxfohl Aug 12, 2022
48c5736
mypy
daxfohl Aug 12, 2022
17f5e80
cover
daxfohl Aug 12, 2022
d74a699
Make this more generic, covers all kinds of conditions.
daxfohl Aug 12, 2022
d37cbad
Better docs
daxfohl Aug 13, 2022
2c1d003
Sympy can also be CX
daxfohl Aug 13, 2022
27ff43d
docs
daxfohl Aug 13, 2022
78ef78d
docs
daxfohl Aug 13, 2022
4dcb564
Merge branch 'master' into sympy-deferred
tanujkhattar Aug 25, 2022
6ffe765
Merge branch 'master' into sympy-deferred
tanujkhattar Aug 25, 2022
ed1257d
Merge branch 'master' into sympy-deferred
daxfohl Sep 5, 2022
6bcc71c
Add mixed tests, simplify loop, add simplification in ControlledGate
daxfohl Sep 5, 2022
8f045bc
Fix error message
daxfohl Sep 5, 2022
1c32404
Simplify error message
daxfohl Sep 6, 2022
9fd971b
Inline variable
daxfohl Sep 14, 2022
4a484b0
Merge branch 'master' into sympy-deferred
daxfohl Sep 21, 2022
30b9121
Merge branch 'master' into sympy-deferred
daxfohl Oct 11, 2022
9d1f5ef
fix merge
daxfohl Oct 11, 2022
d4c80b9
qudit sympy test
daxfohl Oct 11, 2022
2d1cabf
Merge branch 'master' into sympy-deferred
daxfohl Oct 12, 2022
f033b39
fix build
daxfohl Oct 13, 2022
72388ce
Merge branch 'master' into sympy-deferred
daxfohl Oct 13, 2022
8e8dfc1
Fix test
daxfohl Oct 16, 2022
e733e89
Fix test
daxfohl Oct 16, 2022
3a6c750
Remove need for ControlledGate change
daxfohl Oct 28, 2022
1846525
mypy, comment
daxfohl Oct 28, 2022
05b22d3
Merge branch 'master' into sympy-deferred
daxfohl Oct 31, 2022
ed2c020
nits
daxfohl Nov 4, 2022
4b47de2
Merge branch 'master' into sympy-deferred
daxfohl Nov 4, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 53 additions & 14 deletions cirq-core/cirq/transformers/measurement_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

import itertools
from typing import Any, Dict, List, Optional, TYPE_CHECKING, Union
from typing import Any, Dict, Iterable, List, Mapping, Optional, TYPE_CHECKING, Union

from cirq import ops, protocols, value
from cirq.transformers import transformer_api, transformer_primitives
Expand Down Expand Up @@ -82,7 +82,6 @@ def defer_measurements(
A circuit with equivalent logic, but all measurements at the end of the
circuit.
Raises:
ValueError: If sympy-based classical conditions are used.
NotImplementedError: When attempting to defer a measurement with a
confusion map. (https://github.com/quantumlib/Cirq/issues/5482)
"""
Expand Down Expand Up @@ -112,19 +111,28 @@ def defer(op: 'cirq.Operation', _) -> 'cirq.OP_TREE':
elif op.classical_controls:
new_op = op.without_classical_controls()
for c in op.classical_controls:
if isinstance(c, value.KeyCondition):
if c.key not in measurement_qubits:
raise ValueError(f'Deferred measurement for key={c.key} not found.')
qs = measurement_qubits[c.key]
if len(qs) == 1:
control_values: Any = range(1, qs[0].dimension)
else:
all_values = itertools.product(*[range(q.dimension) for q in qs])
anything_but_all_zeros = tuple(itertools.islice(all_values, 1, None))
control_values = ops.SumOfProducts(anything_but_all_zeros)
new_op = new_op.controlled_by(*qs, control_values=control_values)
# Convert to a quantum control
missing_keys = [k for k in c.keys if k not in measurement_qubits]
if missing_keys:
raise ValueError(f'Deferred measurement for key={missing_keys[0]} not found.')
qs = [q for k in c.keys for q in measurement_qubits[k]]

# Try every possible datastore state against the condition, and the ones that work
# are the control values for the new op.
datastores = _all_possible_datastore_states(c.keys, measurement_qubits)
compatible_datastores = [store for store in datastores if c.resolve(store)]

# Rearrange these into the format expected by SumOfProducts
products = [
[i for k in c.keys for i in store.records[k][0]]
for store in compatible_datastores
]
if len(qs) == 1:
# Convenience: this renders more nicely than SumOfProducts.
control_values: Any = [[x[0] for x in products]]
else:
raise ValueError('Only KeyConditions are allowed.')
control_values = ops.SumOfProducts(products)
new_op = new_op.controlled_by(*qs, control_values=control_values)
return new_op
return op

Expand All @@ -139,6 +147,37 @@ def defer(op: 'cirq.Operation', _) -> 'cirq.OP_TREE':
return circuit


def _all_possible_datastore_states(
keys: Iterable['cirq.MeasurementKey'],
measurement_qubits: Mapping['cirq.MeasurementKey', Iterable['cirq.Qid']],
) -> Iterable['cirq.ClassicalDataStoreReader']:
"""The full product of possible DatsStores states for the given keys."""
# First we get the list of all possible values. So if we have a key mapped to qubits of shape
# (2, 2) and a key mapped to a qutrit, the possible measurement values are:
# [((0, 0), (0,)),
# ((0, 0), (1,)),
# ((0, 0), (2,)),
# ((0, 1), (0,)),
# ((0, 1), (1,)),
# ((0, 1), (2,)),
# ((1, 0), (0,)),
# ((1, 0), (1,)),
# ((1, 0), (2,)),
# ((1, 1), (0,)),
# ((1, 1), (1,)),
# ((1, 1), (2,))]
all_values = itertools.product(
*[
tuple(itertools.product(*[range(q.dimension) for q in measurement_qubits[k]]))
for k in keys
]
)
# Then we create the ClassicalDataDictionaryStore for each of the above.
for sequences in all_values:
lookup = {k: [sequence] for k, sequence in zip(keys, sequences)}
yield value.ClassicalDataDictionaryStore(_records=lookup)


@transformer_api.transformer
def dephase_measurements(
circuit: 'cirq.AbstractCircuit',
Expand Down
83 changes: 74 additions & 9 deletions cirq-core/cirq/transformers/measurement_transformers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import numpy as np
import pytest
import sympy
from sympy.parsing import sympy_parser

import cirq
from cirq.transformers.measurement_transformers import _MeasurementQid
Expand Down Expand Up @@ -58,6 +59,79 @@ def test_basic():
)


def test_sympy_control():
q0, q1 = cirq.LineQubit.range(2)
circuit = cirq.Circuit(
cirq.measure(q0, key='a'),
cirq.X(q1).with_classical_controls(sympy.Symbol('a')),
cirq.measure(q1, key='b'),
)
assert_equivalent_to_deferred(circuit)
deferred = cirq.defer_measurements(circuit)
q_ma = _MeasurementQid('a', q0)
cirq.testing.assert_same_circuits(
deferred,
cirq.Circuit(
cirq.CX(q0, q_ma),
cirq.CX(q_ma, q1),
cirq.measure(q_ma, key='a'),
cirq.measure(q1, key='b'),
),
)


def test_sympy_control_complex():
q0, q1, q2 = cirq.LineQubit.range(3)
circuit = cirq.Circuit(
cirq.measure(q0, key='a'),
cirq.measure(q1, key='b'),
cirq.X(q2).with_classical_controls(sympy_parser.parse_expr('a >= b')),
cirq.measure(q2, key='c'),
)
assert_equivalent_to_deferred(circuit)
deferred = cirq.defer_measurements(circuit)
q_ma = _MeasurementQid('a', q0)
q_mb = _MeasurementQid('b', q1)
cirq.testing.assert_same_circuits(
deferred,
cirq.Circuit(
cirq.CX(q0, q_ma),
cirq.CX(q1, q_mb),
cirq.ControlledOperation(
[q_ma, q_mb], cirq.X(q2), cirq.SumOfProducts([[0, 0], [1, 0], [1, 1]])
),
cirq.measure(q_ma, key='a'),
cirq.measure(q_mb, key='b'),
cirq.measure(q2, key='c'),
),
)


def test_sympy_control_multiqubit():
q0, q1, q2 = cirq.LineQubit.range(3)
circuit = cirq.Circuit(
cirq.measure(q0, q1, key='a'),
cirq.X(q2).with_classical_controls(sympy_parser.parse_expr('a >= 2')),
cirq.measure(q2, key='c'),
)
assert_equivalent_to_deferred(circuit)
deferred = cirq.defer_measurements(circuit)
q_ma0 = _MeasurementQid('a', q0)
q_ma1 = _MeasurementQid('a', q1)
cirq.testing.assert_same_circuits(
deferred,
cirq.Circuit(
cirq.CX(q0, q_ma0),
cirq.CX(q1, q_ma1),
cirq.ControlledOperation(
[q_ma0, q_ma1], cirq.X(q2), cirq.SumOfProducts([[1, 0], [1, 1]])
),
cirq.measure(q_ma0, q_ma1, key='a'),
cirq.measure(q2, key='c'),
),
)


def test_nocompile_context():
q0, q1 = cirq.LineQubit.range(2)
circuit = cirq.Circuit(
Expand Down Expand Up @@ -295,15 +369,6 @@ def test_repr(qid: _MeasurementQid):
test_repr(_MeasurementQid('0:1:a', cirq.LineQid(9, 4)))


def test_sympy_control():
q0, q1 = cirq.LineQubit.range(2)
circuit = cirq.Circuit(
cirq.measure(q0, q1, key='a'), cirq.X(q1).with_classical_controls(sympy.Symbol('a'))
)
with pytest.raises(ValueError, match='Only KeyConditions are allowed'):
_ = cirq.defer_measurements(circuit)


def test_confusion_map():
q0, q1 = cirq.LineQubit.range(2)
circuit = cirq.Circuit(
Expand Down