Skip to content

Commit 0bdc9f1

Browse files
authored
Deferred measurements transformer (#4849)
Closes #4818, Also reimplements `mux` simulation based on this, in preparation to deprecate `ignore_measurement_results`. Needs a follow-up after #4512 to support classical controls on multi-qubit measurements, as we need some way of defining the condition "at least one qubit is not zero" to match the classical interpretation of a multi-qubit measurement.
1 parent aed1964 commit 0bdc9f1

8 files changed

+569
-4
lines changed

cirq-core/cirq/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -362,6 +362,8 @@
362362
decompose_multi_controlled_x,
363363
decompose_multi_controlled_rotation,
364364
decompose_two_qubit_interaction_into_four_fsim_gates,
365+
defer_measurements,
366+
dephase_measurements,
365367
drop_empty_moments,
366368
drop_negligible_operations,
367369
eject_phased_paulis,

cirq-core/cirq/ops/kraus_channel.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def __init__(
5555
self._key = key
5656

5757
@staticmethod
58-
def from_channel(channel: 'KrausChannel', key: Union[str, 'cirq.MeasurementKey', None] = None):
58+
def from_channel(channel: 'cirq.Gate', key: Union[str, 'cirq.MeasurementKey', None] = None):
5959
"""Creates a copy of a channel with the given measurement key."""
6060
return KrausChannel(kraus_ops=list(protocols.kraus(channel)), key=key)
6161

cirq-core/cirq/ops/tags.py

+3
Original file line numberDiff line numberDiff line change
@@ -34,3 +34,6 @@ def __repr__(self) -> str:
3434

3535
def _json_dict_(self) -> Dict[str, str]:
3636
return {}
37+
38+
def __hash__(self):
39+
return hash(VirtualTag)

cirq-core/cirq/sim/mux.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from cirq._doc import document
2626
from cirq.sim import sparse_simulator, density_matrix_simulator
2727
from cirq.sim.clifford import clifford_simulator
28+
from cirq.transformers import measurement_transformers
2829

2930
if TYPE_CHECKING:
3031
import cirq
@@ -281,9 +282,10 @@ def final_density_matrix(
281282
dtype=dtype,
282283
noise=noise,
283284
seed=seed,
284-
ignore_measurement_results=(ignore_measurement_results),
285285
).simulate(
286-
program=circuit_like,
286+
program=measurement_transformers.dephase_measurements(circuit_like)
287+
if ignore_measurement_results
288+
else circuit_like,
287289
initial_state=initial_state,
288290
qubit_order=qubit_order,
289291
param_resolver=param_resolver,

cirq-core/cirq/transformers/__init__.py

+5
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,11 @@
5353

5454
from cirq.transformers.eject_z import eject_z
5555

56+
from cirq.transformers.measurement_transformers import (
57+
defer_measurements,
58+
dephase_measurements,
59+
)
60+
5661
from cirq.transformers.synchronize_terminal_measurements import synchronize_terminal_measurements
5762

5863
from cirq.transformers.transformer_api import (
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,177 @@
1+
# Copyright 2022 The Cirq Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import Any, Dict, List, Optional, TYPE_CHECKING, Union
16+
17+
from cirq import ops, protocols, value
18+
from cirq.transformers import (
19+
transformer_api,
20+
transformer_primitives,
21+
)
22+
from cirq.transformers.synchronize_terminal_measurements import find_terminal_measurements
23+
24+
if TYPE_CHECKING:
25+
import cirq
26+
27+
28+
class _MeasurementQid(ops.Qid):
29+
"""A qubit that substitutes in for a deferred measurement.
30+
31+
Exactly one qubit will be created per qubit in the measurement gate.
32+
"""
33+
34+
def __init__(self, key: Union[str, 'cirq.MeasurementKey'], qid: 'cirq.Qid'):
35+
"""Initializes the qubit.
36+
37+
Args:
38+
key: The key of the measurement gate being deferred.
39+
qid: One qubit that is being measured. Each deferred measurement
40+
should create one new _MeasurementQid per qubit being measured
41+
by that gate.
42+
"""
43+
self._key = value.MeasurementKey.parse_serialized(key) if isinstance(key, str) else key
44+
self._qid = qid
45+
46+
@property
47+
def dimension(self) -> int:
48+
return self._qid.dimension
49+
50+
def _comparison_key(self) -> Any:
51+
return (str(self._key), self._qid._comparison_key())
52+
53+
def __str__(self) -> str:
54+
return f"M('{self._key}', q={self._qid})"
55+
56+
def __repr__(self) -> str:
57+
return f'_MeasurementQid({self._key!r}, {self._qid!r})'
58+
59+
60+
@transformer_api.transformer
61+
def defer_measurements(
62+
circuit: 'cirq.AbstractCircuit', *, context: Optional['cirq.TransformerContext'] = None
63+
) -> 'cirq.Circuit':
64+
"""Implements the Deferred Measurement Principle.
65+
66+
Uses the Deferred Measurement Principle to move all measurements to the
67+
end of the circuit. All non-terminal measurements are changed to
68+
conditional quantum gates onto ancilla qubits, and classically controlled
69+
operations are transformed to quantum controls from those ancilla qubits.
70+
Finally, measurements of all ancilla qubits are appended to the end of the
71+
circuit.
72+
73+
Optimizing deferred measurements is an area of active research, and future
74+
iterations may contain optimizations that reduce the number of ancilla
75+
qubits, so one should not depend on the exact shape of the output from this
76+
function. Only the logical equivalence is guaranteed to remain unchanged.
77+
Moment and subcircuit structure is not preserved.
78+
79+
Args:
80+
circuit: The circuit to transform. It will not be modified.
81+
context: `cirq.TransformerContext` storing common configurable options
82+
for transformers.
83+
Returns:
84+
A circuit with equivalent logic, but all measurements at the end of the
85+
circuit.
86+
Raises:
87+
ValueError: If sympy-based classical conditions are used, or if
88+
conditions based on multi-qubit measurements exist. (The latter of
89+
these is planned to be implemented soon).
90+
"""
91+
92+
circuit = transformer_primitives.unroll_circuit_op(circuit, deep=True, tags_to_check=None)
93+
terminal_measurements = {op for _, op in find_terminal_measurements(circuit)}
94+
measurement_qubits: Dict['cirq.MeasurementKey', List['_MeasurementQid']] = {}
95+
96+
def defer(op: 'cirq.Operation', _) -> 'cirq.OP_TREE':
97+
if op in terminal_measurements:
98+
return op
99+
gate = op.gate
100+
if isinstance(gate, ops.MeasurementGate):
101+
key = value.MeasurementKey.parse_serialized(gate.key)
102+
targets = [_MeasurementQid(key, q) for q in op.qubits]
103+
measurement_qubits[key] = targets
104+
cxs = [ops.CX(q, target) for q, target in zip(op.qubits, targets)]
105+
xs = [ops.X(targets[i]) for i, b in enumerate(gate.full_invert_mask()) if b]
106+
return cxs + xs
107+
elif protocols.is_measurement(op):
108+
return [defer(op, None) for op in protocols.decompose_once(op)]
109+
elif op.classical_controls:
110+
controls = []
111+
for c in op.classical_controls:
112+
if isinstance(c, value.KeyCondition):
113+
if c.key not in measurement_qubits:
114+
raise ValueError(f'Deferred measurement for key={c.key} not found.')
115+
qubits = measurement_qubits[c.key]
116+
if len(qubits) != 1:
117+
# TODO: Multi-qubit conditions require
118+
# https://github.com/quantumlib/Cirq/issues/4512
119+
# Remember to update docstring above once this works.
120+
raise ValueError('Only single qubit conditions are allowed.')
121+
controls.extend(qubits)
122+
else:
123+
raise ValueError('Only KeyConditions are allowed.')
124+
return op.without_classical_controls().controlled_by(
125+
*controls, control_values=[tuple(range(1, q.dimension)) for q in controls]
126+
)
127+
return op
128+
129+
circuit = transformer_primitives.map_operations_and_unroll(
130+
circuit=circuit,
131+
map_func=defer,
132+
tags_to_ignore=context.tags_to_ignore if context else (),
133+
raise_if_add_qubits=False,
134+
).unfreeze()
135+
for k, qubits in measurement_qubits.items():
136+
circuit.append(ops.measure(*qubits, key=k))
137+
return circuit
138+
139+
140+
@transformer_api.transformer
141+
def dephase_measurements(
142+
circuit: 'cirq.AbstractCircuit', *, context: Optional['cirq.TransformerContext'] = None
143+
) -> 'cirq.Circuit':
144+
"""Changes all measurements to a dephase operation.
145+
146+
This transformer is useful when using a density matrix simulator, when
147+
wishing to calculate the final density matrix of a circuit and not simulate
148+
the measurements themselves.
149+
150+
Args:
151+
circuit: The circuit to transform. It will not be modified.
152+
context: `cirq.TransformerContext` storing common configurable options
153+
for transformers.
154+
Returns:
155+
A copy of the circuit, with dephase operations in place of all
156+
measurements.
157+
Raises:
158+
ValueError: If the circuit contains classical controls. In this case,
159+
it is required to change these to quantum controls via
160+
`cirq.defer_measurements` first. Since deferral adds ancilla qubits
161+
to the circuit, this is not done automatically, to prevent
162+
surprises.
163+
"""
164+
165+
def dephase(op: 'cirq.Operation', _) -> 'cirq.OP_TREE':
166+
gate = op.gate
167+
if isinstance(gate, ops.MeasurementGate):
168+
key = value.MeasurementKey.parse_serialized(gate.key)
169+
return ops.KrausChannel.from_channel(ops.phase_damp(1), key=key).on_each(op.qubits)
170+
elif isinstance(op, ops.ClassicallyControlledOperation):
171+
raise ValueError('Use cirq.defer_measurements first to remove classical controls.')
172+
return op
173+
174+
ignored = () if context is None else context.tags_to_ignore
175+
return transformer_primitives.map_operations(
176+
circuit, dephase, deep=True, tags_to_ignore=ignored
177+
).unfreeze()

0 commit comments

Comments
 (0)