Skip to content

Commit bfa2602

Browse files
authored
Make PauliMeasurementGate respect sign of the pauli observable. (#4836)
Fixes #4814 Note that this is a breaking change because: - Serialization of the `PauliMeasurementGate` is now different -- the serialized observable is `DensePauliString` instead of a tuple of Pauli's. - A DensePauliString with coefficient != +1/-1 will now raise a `ValueError` whereas earlier the coefficient was simply ignored.
1 parent 3a6ad87 commit bfa2602

6 files changed

+158
-56
lines changed

cirq-core/cirq/ops/pauli_measurement_gate.py

+38-15
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,14 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from typing import Any, Dict, FrozenSet, Iterable, Tuple, Sequence, TYPE_CHECKING, Union
15+
from typing import Any, Dict, FrozenSet, Iterable, Tuple, Sequence, TYPE_CHECKING, Union, cast
1616

1717
from cirq import protocols, value
1818
from cirq.ops import (
1919
raw_types,
2020
measurement_gate,
2121
op_tree,
22-
dense_pauli_string,
22+
dense_pauli_string as dps,
2323
pauli_gates,
2424
pauli_string_phasor,
2525
)
@@ -38,25 +38,36 @@ class PauliMeasurementGate(raw_types.Gate):
3838

3939
def __init__(
4040
self,
41-
observable: Iterable['cirq.Pauli'],
41+
observable: Union['cirq.BaseDensePauliString', Iterable['cirq.Pauli']],
4242
key: Union[str, 'cirq.MeasurementKey'] = '',
4343
) -> None:
4444
"""Inits PauliMeasurementGate.
4545
4646
Args:
4747
observable: Pauli observable to measure. Any `Iterable[cirq.Pauli]`
48-
is a valid Pauli observable, including `cirq.DensePauliString`
49-
instances, which do not contain any identity gates.
48+
is a valid Pauli observable (with a +1 coefficient by default).
49+
If you wish to measure pauli observables with coefficient -1,
50+
then pass a `cirq.DensePauliString` as observable.
5051
key: The string key of the measurement.
5152
5253
Raises:
5354
ValueError: If the observable is empty.
5455
"""
5556
if not observable:
5657
raise ValueError(f'Pauli observable {observable} is empty.')
57-
if not all(isinstance(p, pauli_gates.Pauli) for p in observable):
58+
if not all(
59+
isinstance(p, pauli_gates.Pauli) for p in cast(Iterable['cirq.Gate'], observable)
60+
):
5861
raise ValueError(f'Pauli observable {observable} must be Iterable[`cirq.Pauli`].')
59-
self._observable = tuple(observable)
62+
coefficient = (
63+
observable.coefficient if isinstance(observable, dps.BaseDensePauliString) else 1
64+
)
65+
if coefficient not in [+1, -1]:
66+
raise ValueError(
67+
f'`cirq.DensePauliString` observable {observable} must have coefficient +1/-1.'
68+
)
69+
70+
self._observable = dps.DensePauliString(observable, coefficient=coefficient)
6071
self.key = key # type: ignore
6172

6273
@property
@@ -94,9 +105,15 @@ def _with_rescoped_keys_(
94105
def _with_measurement_key_mapping_(self, key_map: Dict[str, str]) -> 'PauliMeasurementGate':
95106
return self.with_key(protocols.with_measurement_key_mapping(self.mkey, key_map))
96107

97-
def with_observable(self, observable: Iterable['cirq.Pauli']) -> 'PauliMeasurementGate':
108+
def with_observable(
109+
self, observable: Union['cirq.BaseDensePauliString', Iterable['cirq.Pauli']]
110+
) -> 'PauliMeasurementGate':
98111
"""Creates a pauli measurement gate with the new observable and same key."""
99-
if tuple(observable) == self._observable:
112+
if (
113+
observable
114+
if isinstance(observable, dps.BaseDensePauliString)
115+
else dps.DensePauliString(observable)
116+
) == self._observable:
100117
return self
101118
return PauliMeasurementGate(observable, key=self.key)
102119

@@ -111,24 +128,30 @@ def _measurement_key_obj_(self) -> 'cirq.MeasurementKey':
111128

112129
def observable(self) -> 'cirq.DensePauliString':
113130
"""Pauli observable which should be measured by the gate."""
114-
return dense_pauli_string.DensePauliString(self._observable)
131+
return self._observable
115132

116133
def _decompose_(
117134
self, qubits: Tuple['cirq.Qid', ...]
118135
) -> 'protocols.decompose_protocol.DecomposeResult':
119136
any_qubit = qubits[0]
120-
to_z_ops = op_tree.freeze_op_tree(self.observable().on(*qubits).to_z_basis_ops())
137+
to_z_ops = op_tree.freeze_op_tree(self._observable.on(*qubits).to_z_basis_ops())
121138
xor_decomp = tuple(pauli_string_phasor.xor_nonlocal_decompose(qubits, any_qubit))
122139
yield to_z_ops
123140
yield xor_decomp
124-
yield measurement_gate.MeasurementGate(1, self.mkey).on(any_qubit)
141+
yield measurement_gate.MeasurementGate(
142+
1, self.mkey, invert_mask=(self._observable.coefficient != 1,)
143+
).on(any_qubit)
125144
yield protocols.inverse(xor_decomp)
126145
yield protocols.inverse(to_z_ops)
127146

128147
def _circuit_diagram_info_(
129148
self, args: 'cirq.CircuitDiagramInfoArgs'
130149
) -> 'cirq.CircuitDiagramInfo':
131-
symbols = [f'M({g})' for g in self._observable]
150+
coefficient = '' if self._observable.coefficient == 1 else '-'
151+
symbols = [
152+
f'M({"" if i else coefficient}{self._observable[i]})'
153+
for i in range(len(self._observable))
154+
]
132155

133156
# Mention the measurement key.
134157
label_map = args.label_map or {}
@@ -141,14 +164,14 @@ def _circuit_diagram_info_(
141164
return protocols.CircuitDiagramInfo(tuple(symbols))
142165

143166
def _op_repr_(self, qubits: Sequence['cirq.Qid']) -> str:
144-
args = [repr(self.observable().on(*qubits))]
167+
args = [repr(self._observable.on(*qubits))]
145168
if self.key != _default_measurement_key(qubits):
146169
args.append(f'key={self.mkey!r}')
147170
arg_list = ', '.join(args)
148171
return f'cirq.measure_single_paulistring({arg_list})'
149172

150173
def __repr__(self) -> str:
151-
return f'cirq.PauliMeasurementGate(' f'{self._observable!r}, ' f'{self.mkey!r})'
174+
return f'cirq.PauliMeasurementGate({self._observable!r}, {self.mkey!r})'
152175

153176
def _value_equality_values_(self) -> Any:
154177
return self.key, self._observable

cirq-core/cirq/ops/pauli_measurement_gate_test.py

+21-1
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def test_init(observable, key):
4343
assert g.num_qubits() == len(observable)
4444
assert g.key == 'a'
4545
assert g.mkey == cirq.MeasurementKey('a')
46-
assert g._observable == tuple(observable)
46+
assert g._observable == cirq.DensePauliString(observable)
4747
assert cirq.qid_shape(g) == (2,) * len(observable)
4848

4949

@@ -162,6 +162,9 @@ def test_bad_observable_raises():
162162
with pytest.raises(ValueError, match=r'Pauli observable .* must be Iterable\[`cirq.Pauli`\]'):
163163
_ = cirq.PauliMeasurementGate(cirq.DensePauliString('XYZI'))
164164

165+
with pytest.raises(ValueError, match=r'must have coefficient \+1/-1.'):
166+
_ = cirq.PauliMeasurementGate(cirq.DensePauliString('XYZ', coefficient=1j))
167+
165168

166169
def test_with_observable():
167170
o1 = [cirq.Z, cirq.Y, cirq.X]
@@ -170,3 +173,20 @@ def test_with_observable():
170173
g2 = cirq.PauliMeasurementGate(o2, key='a')
171174
assert g1.with_observable(o2) == g2
172175
assert g1.with_observable(o1) is g1
176+
177+
178+
@pytest.mark.parametrize(
179+
'rot, obs, out',
180+
[
181+
(cirq.I, cirq.DensePauliString("Z", coefficient=+1), 0),
182+
(cirq.I, cirq.DensePauliString("Z", coefficient=-1), 1),
183+
(cirq.Y ** 0.5, cirq.DensePauliString("X", coefficient=+1), 0),
184+
(cirq.Y ** 0.5, cirq.DensePauliString("X", coefficient=-1), 1),
185+
(cirq.X ** -0.5, cirq.DensePauliString("Y", coefficient=+1), 0),
186+
(cirq.X ** -0.5, cirq.DensePauliString("Y", coefficient=-1), 1),
187+
],
188+
)
189+
def test_pauli_measurement_gate_samples(rot, obs, out):
190+
q = cirq.NamedQubit("q")
191+
c = cirq.Circuit(rot(q), cirq.PauliMeasurementGate(obs, key='out').on(q))
192+
assert cirq.Simulator().sample(c)['out'][0] == out
Original file line numberDiff line numberDiff line change
@@ -1,42 +1,53 @@
1-
[{
2-
"cirq_type": "PauliMeasurementGate",
3-
"observable": [
4-
{
5-
"cirq_type": "_PauliX",
6-
"exponent": 1.0,
7-
"global_shift": 0.0
1+
[
2+
{
3+
"cirq_type": "PauliMeasurementGate",
4+
"observable": {
5+
"cirq_type": "DensePauliString",
6+
"pauli_mask": [
7+
1,
8+
2,
9+
3
10+
],
11+
"coefficient": {
12+
"cirq_type": "complex",
13+
"real": 1.0,
14+
"imag": 0.0
15+
}
816
},
9-
{
10-
"cirq_type": "_PauliY",
11-
"exponent": 1.0,
12-
"global_shift": 0.0
17+
"key": "key"
18+
},
19+
{
20+
"cirq_type": "PauliMeasurementGate",
21+
"observable": {
22+
"cirq_type": "DensePauliString",
23+
"pauli_mask": [
24+
1,
25+
2,
26+
3
27+
],
28+
"coefficient": {
29+
"cirq_type": "complex",
30+
"real": 1.0,
31+
"imag": 0.0
32+
}
1333
},
14-
{
15-
"cirq_type": "_PauliZ",
16-
"exponent": 1.0,
17-
"global_shift": 0.0
18-
}
19-
],
20-
"key": "key"
21-
},
22-
{
23-
"cirq_type": "PauliMeasurementGate",
24-
"observable": [
25-
{
26-
"cirq_type": "_PauliX",
27-
"exponent": 1.0,
28-
"global_shift": 0.0
34+
"key": "p:q:key"
35+
},
36+
{
37+
"cirq_type": "PauliMeasurementGate",
38+
"observable": {
39+
"cirq_type": "DensePauliString",
40+
"pauli_mask": [
41+
1,
42+
2,
43+
3
44+
],
45+
"coefficient": {
46+
"cirq_type": "complex",
47+
"real": -1.0,
48+
"imag": 0.0
49+
}
2950
},
30-
{
31-
"cirq_type": "_PauliY",
32-
"exponent": 1.0,
33-
"global_shift": 0.0
34-
},
35-
{
36-
"cirq_type": "_PauliZ",
37-
"exponent": 1.0,
38-
"global_shift": 0.0
39-
}
40-
],
41-
"key": "p:q:key"
42-
}]
51+
"key": "key"
52+
}
53+
]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
[{
2+
"cirq_type": "PauliMeasurementGate",
3+
"observable": [
4+
{
5+
"cirq_type": "_PauliX",
6+
"exponent": 1.0,
7+
"global_shift": 0.0
8+
},
9+
{
10+
"cirq_type": "_PauliY",
11+
"exponent": 1.0,
12+
"global_shift": 0.0
13+
},
14+
{
15+
"cirq_type": "_PauliZ",
16+
"exponent": 1.0,
17+
"global_shift": 0.0
18+
}
19+
],
20+
"key": "key"
21+
},
22+
{
23+
"cirq_type": "PauliMeasurementGate",
24+
"observable": [
25+
{
26+
"cirq_type": "_PauliX",
27+
"exponent": 1.0,
28+
"global_shift": 0.0
29+
},
30+
{
31+
"cirq_type": "_PauliY",
32+
"exponent": 1.0,
33+
"global_shift": 0.0
34+
},
35+
{
36+
"cirq_type": "_PauliZ",
37+
"exponent": 1.0,
38+
"global_shift": 0.0
39+
}
40+
],
41+
"key": "p:q:key"
42+
}]
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
[
22
cirq.PauliMeasurementGate((cirq.X, cirq.Y, cirq.Z), cirq.MeasurementKey(name='key')),
3-
cirq.PauliMeasurementGate((cirq.X, cirq.Y, cirq.Z), cirq.MeasurementKey(path=('p', 'q'), name='key')),
3+
cirq.PauliMeasurementGate(cirq.DensePauliString("XYZ"), cirq.MeasurementKey(path=('p', 'q'), name='key')),
4+
cirq.PauliMeasurementGate(cirq.DensePauliString("XYZ", coefficient=-1), cirq.MeasurementKey(name='key')),
45
]
56

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
[
2+
cirq.PauliMeasurementGate((cirq.X, cirq.Y, cirq.Z), cirq.MeasurementKey(name='key')),
3+
cirq.PauliMeasurementGate((cirq.X, cirq.Y, cirq.Z), cirq.MeasurementKey(path=('p', 'q'), name='key')),
4+
]
5+

0 commit comments

Comments
 (0)