12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
14
15
- from typing import Any , Dict , FrozenSet , Iterable , Tuple , Sequence , TYPE_CHECKING , Union
15
+ from typing import Any , Dict , FrozenSet , Iterable , Tuple , Sequence , TYPE_CHECKING , Union , cast
16
16
17
17
from cirq import protocols , value
18
18
from cirq .ops import (
19
19
raw_types ,
20
20
measurement_gate ,
21
21
op_tree ,
22
- dense_pauli_string ,
22
+ dense_pauli_string as dps ,
23
23
pauli_gates ,
24
24
pauli_string_phasor ,
25
25
)
@@ -38,25 +38,36 @@ class PauliMeasurementGate(raw_types.Gate):
38
38
39
39
def __init__ (
40
40
self ,
41
- observable : Iterable ['cirq.Pauli' ],
41
+ observable : Union [ 'cirq.BaseDensePauliString' , Iterable ['cirq.Pauli' ] ],
42
42
key : Union [str , 'cirq.MeasurementKey' ] = '' ,
43
43
) -> None :
44
44
"""Inits PauliMeasurementGate.
45
45
46
46
Args:
47
47
observable: Pauli observable to measure. Any `Iterable[cirq.Pauli]`
48
- is a valid Pauli observable, including `cirq.DensePauliString`
49
- instances, which do not contain any identity gates.
48
+ is a valid Pauli observable (with a +1 coefficient by default).
49
+ If you wish to measure pauli observables with coefficient -1,
50
+ then pass a `cirq.DensePauliString` as observable.
50
51
key: The string key of the measurement.
51
52
52
53
Raises:
53
54
ValueError: If the observable is empty.
54
55
"""
55
56
if not observable :
56
57
raise ValueError (f'Pauli observable { observable } is empty.' )
57
- if not all (isinstance (p , pauli_gates .Pauli ) for p in observable ):
58
+ if not all (
59
+ isinstance (p , pauli_gates .Pauli ) for p in cast (Iterable ['cirq.Gate' ], observable )
60
+ ):
58
61
raise ValueError (f'Pauli observable { observable } must be Iterable[`cirq.Pauli`].' )
59
- self ._observable = tuple (observable )
62
+ coefficient = (
63
+ observable .coefficient if isinstance (observable , dps .BaseDensePauliString ) else 1
64
+ )
65
+ if coefficient not in [+ 1 , - 1 ]:
66
+ raise ValueError (
67
+ f'`cirq.DensePauliString` observable { observable } must have coefficient +1/-1.'
68
+ )
69
+
70
+ self ._observable = dps .DensePauliString (observable , coefficient = coefficient )
60
71
self .key = key # type: ignore
61
72
62
73
@property
@@ -94,9 +105,15 @@ def _with_rescoped_keys_(
94
105
def _with_measurement_key_mapping_ (self , key_map : Dict [str , str ]) -> 'PauliMeasurementGate' :
95
106
return self .with_key (protocols .with_measurement_key_mapping (self .mkey , key_map ))
96
107
97
- def with_observable (self , observable : Iterable ['cirq.Pauli' ]) -> 'PauliMeasurementGate' :
108
+ def with_observable (
109
+ self , observable : Union ['cirq.BaseDensePauliString' , Iterable ['cirq.Pauli' ]]
110
+ ) -> 'PauliMeasurementGate' :
98
111
"""Creates a pauli measurement gate with the new observable and same key."""
99
- if tuple (observable ) == self ._observable :
112
+ if (
113
+ observable
114
+ if isinstance (observable , dps .BaseDensePauliString )
115
+ else dps .DensePauliString (observable )
116
+ ) == self ._observable :
100
117
return self
101
118
return PauliMeasurementGate (observable , key = self .key )
102
119
@@ -111,24 +128,30 @@ def _measurement_key_obj_(self) -> 'cirq.MeasurementKey':
111
128
112
129
def observable (self ) -> 'cirq.DensePauliString' :
113
130
"""Pauli observable which should be measured by the gate."""
114
- return dense_pauli_string . DensePauliString ( self ._observable )
131
+ return self ._observable
115
132
116
133
def _decompose_ (
117
134
self , qubits : Tuple ['cirq.Qid' , ...]
118
135
) -> 'protocols.decompose_protocol.DecomposeResult' :
119
136
any_qubit = qubits [0 ]
120
- to_z_ops = op_tree .freeze_op_tree (self .observable () .on (* qubits ).to_z_basis_ops ())
137
+ to_z_ops = op_tree .freeze_op_tree (self ._observable .on (* qubits ).to_z_basis_ops ())
121
138
xor_decomp = tuple (pauli_string_phasor .xor_nonlocal_decompose (qubits , any_qubit ))
122
139
yield to_z_ops
123
140
yield xor_decomp
124
- yield measurement_gate .MeasurementGate (1 , self .mkey ).on (any_qubit )
141
+ yield measurement_gate .MeasurementGate (
142
+ 1 , self .mkey , invert_mask = (self ._observable .coefficient != 1 ,)
143
+ ).on (any_qubit )
125
144
yield protocols .inverse (xor_decomp )
126
145
yield protocols .inverse (to_z_ops )
127
146
128
147
def _circuit_diagram_info_ (
129
148
self , args : 'cirq.CircuitDiagramInfoArgs'
130
149
) -> 'cirq.CircuitDiagramInfo' :
131
- symbols = [f'M({ g } )' for g in self ._observable ]
150
+ coefficient = '' if self ._observable .coefficient == 1 else '-'
151
+ symbols = [
152
+ f'M({ "" if i else coefficient } { self ._observable [i ]} )'
153
+ for i in range (len (self ._observable ))
154
+ ]
132
155
133
156
# Mention the measurement key.
134
157
label_map = args .label_map or {}
@@ -141,14 +164,14 @@ def _circuit_diagram_info_(
141
164
return protocols .CircuitDiagramInfo (tuple (symbols ))
142
165
143
166
def _op_repr_ (self , qubits : Sequence ['cirq.Qid' ]) -> str :
144
- args = [repr (self .observable () .on (* qubits ))]
167
+ args = [repr (self ._observable .on (* qubits ))]
145
168
if self .key != _default_measurement_key (qubits ):
146
169
args .append (f'key={ self .mkey !r} ' )
147
170
arg_list = ', ' .join (args )
148
171
return f'cirq.measure_single_paulistring({ arg_list } )'
149
172
150
173
def __repr__ (self ) -> str :
151
- return f'cirq.PauliMeasurementGate(' f' { self ._observable !r} , ' f' { self .mkey !r} )'
174
+ return f'cirq.PauliMeasurementGate({ self ._observable !r} , { self .mkey !r} )'
152
175
153
176
def _value_equality_values_ (self ) -> Any :
154
177
return self .key , self ._observable
0 commit comments