Skip to content

Commit 185b53a

Browse files
daxfohlMichaelBroughton
authored andcommitted
Implement GlobalPhaseGate (quantumlib#4697)
Implements GlobalPhaseOperation in terms of a GateOperation on a new class GlobalPhaseGate. Mostly involved moving existing functions from the operation to the gate, and then having the operation call those methods under the hood.
1 parent 9e94b9e commit 185b53a

33 files changed

+253
-98
lines changed

cirq-core/cirq/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,9 @@
219219
generalized_amplitude_damp,
220220
GeneralizedAmplitudeDampingChannel,
221221
givens,
222+
GlobalPhaseGate,
222223
GlobalPhaseOperation,
224+
global_phase_operation,
223225
H,
224226
HPowGate,
225227
I,

cirq-core/cirq/circuits/circuit.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -1193,7 +1193,7 @@ def default_namer(label_entity):
11931193
diagram.write(0, i, name)
11941194
first_annotation_row = max(label_map.values(), default=0) + 1
11951195

1196-
if any(isinstance(op.untagged, cirq.GlobalPhaseOperation) for op in self.all_operations()):
1196+
if any(isinstance(op.gate, cirq.GlobalPhaseGate) for op in self.all_operations()):
11971197
diagram.write(0, max(label_map.values(), default=0) + 1, 'global phase:')
11981198
first_annotation_row += 1
11991199

@@ -2359,7 +2359,7 @@ def _get_moment_annotations(
23592359
if op.qubits:
23602360
continue
23612361
op = op.untagged
2362-
if isinstance(op, ops.GlobalPhaseOperation):
2362+
if isinstance(op.gate, ops.GlobalPhaseGate):
23632363
continue
23642364
if isinstance(op, CircuitOperation):
23652365
for m in op.circuit:
@@ -2493,8 +2493,8 @@ def _draw_moment_in_diagram(
24932493

24942494

24952495
def _get_global_phase_and_tags_for_op(op: 'cirq.Operation') -> Tuple[Optional[complex], List[Any]]:
2496-
if isinstance(op.untagged, ops.GlobalPhaseOperation):
2497-
return complex(op.untagged.coefficient), list(op.tags)
2496+
if isinstance(op.gate, ops.GlobalPhaseGate):
2497+
return complex(op.gate.coefficient), list(op.tags)
24982498
elif isinstance(op.untagged, CircuitOperation):
24992499
op_phase, op_tags = _get_global_phase_and_tags_for_ops(op.untagged.circuit.all_operations())
25002500
return op_phase, list(op.tags) + op_tags

cirq-core/cirq/circuits/circuit_operation_test.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -324,11 +324,11 @@ def test_string_format():
324324
assert str(op0) == f"[ ]"
325325

326326
fc0_global_phase_inner = cirq.FrozenCircuit(
327-
cirq.GlobalPhaseOperation(1j), cirq.GlobalPhaseOperation(1j)
327+
cirq.global_phase_operation(1j), cirq.global_phase_operation(1j)
328328
)
329329
op0_global_phase_inner = cirq.CircuitOperation(fc0_global_phase_inner)
330330
fc0_global_phase_outer = cirq.FrozenCircuit(
331-
op0_global_phase_inner, cirq.GlobalPhaseOperation(1j)
331+
op0_global_phase_inner, cirq.global_phase_operation(1j)
332332
)
333333
op0_global_phase_outer = cirq.CircuitOperation(fc0_global_phase_outer)
334334
assert (

cirq-core/cirq/circuits/circuit_test.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -2568,7 +2568,7 @@ def test_diagram_wgate_none_precision(circuit_cls):
25682568
@pytest.mark.parametrize('circuit_cls', [cirq.Circuit, cirq.FrozenCircuit])
25692569
def test_diagram_global_phase(circuit_cls):
25702570
qa = cirq.NamedQubit('a')
2571-
global_phase = cirq.GlobalPhaseOperation(coefficient=1j)
2571+
global_phase = cirq.global_phase_operation(coefficient=1j)
25722572
c = circuit_cls([global_phase])
25732573
cirq.testing.assert_has_diagram(
25742574
c, "\n\nglobal phase: 0.5pi", use_unicode_characters=False, precision=2
@@ -2601,7 +2601,9 @@ def test_diagram_global_phase(circuit_cls):
26012601

26022602
c = circuit_cls(
26032603
cirq.X(cirq.LineQubit(2)),
2604-
cirq.CircuitOperation(circuit_cls(cirq.GlobalPhaseOperation(-1).with_tags("tag")).freeze()),
2604+
cirq.CircuitOperation(
2605+
circuit_cls(cirq.global_phase_operation(-1).with_tags("tag")).freeze()
2606+
),
26052607
)
26062608
cirq.testing.assert_has_diagram(
26072609
c,
@@ -5131,7 +5133,7 @@ def _circuit_diagram_info_(self, args) -> str:
51315133
cirq.Moment(
51325134
cirq.H(cirq.LineQubit(0)),
51335135
CustomOperationAnnotation("a"),
5134-
cirq.GlobalPhaseOperation(1j),
5136+
cirq.global_phase_operation(1j),
51355137
),
51365138
),
51375139
"""

cirq-core/cirq/interop/quirk/cells/scalar_cells.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,11 @@
2121

2222

2323
def generate_all_scalar_cell_makers() -> Iterator[CellMaker]:
24-
yield _scalar("NeGate", ops.GlobalPhaseOperation(-1))
25-
yield _scalar("i", ops.GlobalPhaseOperation(1j))
26-
yield _scalar("-i", ops.GlobalPhaseOperation(-1j))
27-
yield _scalar("√i", ops.GlobalPhaseOperation(1j ** 0.5))
28-
yield _scalar("√-i", ops.GlobalPhaseOperation((-1j) ** 0.5))
24+
yield _scalar("NeGate", ops.global_phase_operation(-1))
25+
yield _scalar("i", ops.global_phase_operation(1j))
26+
yield _scalar("-i", ops.global_phase_operation(-1j))
27+
yield _scalar("√i", ops.global_phase_operation(1j ** 0.5))
28+
yield _scalar("√-i", ops.global_phase_operation((-1j) ** 0.5))
2929

3030

3131
def _scalar(identifier: str, operation: 'cirq.Operation') -> CellMaker:

cirq-core/cirq/interop/quirk/cells/scalar_cells_test.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -20,17 +20,19 @@ def test_scalar_operations():
2020
assert_url_to_circuit_returns('{"cols":[["…"]]}', cirq.Circuit())
2121

2222
assert_url_to_circuit_returns(
23-
'{"cols":[["NeGate"]]}', cirq.Circuit(cirq.GlobalPhaseOperation(-1))
23+
'{"cols":[["NeGate"]]}', cirq.Circuit(cirq.global_phase_operation(-1))
2424
)
2525

26-
assert_url_to_circuit_returns('{"cols":[["i"]]}', cirq.Circuit(cirq.GlobalPhaseOperation(1j)))
26+
assert_url_to_circuit_returns('{"cols":[["i"]]}', cirq.Circuit(cirq.global_phase_operation(1j)))
2727

28-
assert_url_to_circuit_returns('{"cols":[["-i"]]}', cirq.Circuit(cirq.GlobalPhaseOperation(-1j)))
28+
assert_url_to_circuit_returns(
29+
'{"cols":[["-i"]]}', cirq.Circuit(cirq.global_phase_operation(-1j))
30+
)
2931

3032
assert_url_to_circuit_returns(
31-
'{"cols":[["√i"]]}', cirq.Circuit(cirq.GlobalPhaseOperation(1j ** 0.5))
33+
'{"cols":[["√i"]]}', cirq.Circuit(cirq.global_phase_operation(1j ** 0.5))
3234
)
3335

3436
assert_url_to_circuit_returns(
35-
'{"cols":[["√-i"]]}', cirq.Circuit(cirq.GlobalPhaseOperation(1j ** -0.5))
37+
'{"cols":[["√-i"]]}', cirq.Circuit(cirq.global_phase_operation(1j ** -0.5))
3638
)

cirq-core/cirq/json_resolver_cache.py

+1
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ def _parallel_gate_op(gate, qubits):
8787
'GateOperation': cirq.GateOperation,
8888
'Gateset': cirq.Gateset,
8989
'GeneralizedAmplitudeDampingChannel': cirq.GeneralizedAmplitudeDampingChannel,
90+
'GlobalPhaseGate': cirq.GlobalPhaseGate,
9091
'GlobalPhaseOperation': cirq.GlobalPhaseOperation,
9192
'GridInteractionLayer': GridInteractionLayer,
9293
'GridParallelXEBMetadata': GridParallelXEBMetadata,

cirq-core/cirq/linalg/decompositions.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -535,7 +535,7 @@ def _decompose_(self, qubits):
535535

536536
a, b = qubits
537537
return [
538-
ops.GlobalPhaseOperation(self.global_phase),
538+
ops.global_phase_operation(self.global_phase),
539539
ops.MatrixGate(self.single_qubit_operations_before[0]).on(a),
540540
ops.MatrixGate(self.single_qubit_operations_before[1]).on(b),
541541
np.exp(1j * ops.X(a) * ops.X(b) * self.interaction_coefficients[0]),

cirq-core/cirq/ops/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,9 @@
123123
)
124124

125125
from cirq.ops.global_phase_op import (
126+
GlobalPhaseGate,
126127
GlobalPhaseOperation,
128+
global_phase_operation,
127129
)
128130

129131
from cirq.ops.kraus_channel import (

cirq-core/cirq/ops/dense_pauli_string.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ def _decompose_(self, qubits):
162162
return NotImplemented
163163
result = [PAULI_GATES[p].on(q) for p, q in zip(self.pauli_mask, qubits) if p]
164164
if self.coefficient != 1:
165-
result.append(global_phase_op.GlobalPhaseOperation(self.coefficient))
165+
result.append(global_phase_op.global_phase_operation(self.coefficient))
166166
return result
167167

168168
def _is_parameterized_(self) -> bool:

cirq-core/cirq/ops/diagonal_gate.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ def _decompose_(self, qubits: Sequence['cirq.Qid']) -> 'cirq.OP_TREE':
187187
# diagonal gate for sub-system like controlled gate, it is no longer equivalent. Hence,
188188
# we add global phase.
189189
decomposed_circ: List[Any] = [
190-
global_phase_op.GlobalPhaseOperation(np.exp(1j * hat_angles[0]))
190+
global_phase_op.global_phase_operation(np.exp(1j * hat_angles[0]))
191191
]
192192
for i, bit_flip in _gen_gray_code(n):
193193
decomposed_circ.extend(self._decompose_for_basis(i, bit_flip, -hat_angles[i], qubits))

cirq-core/cirq/ops/gate_operation.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ def __repr__(self):
128128

129129
def __str__(self) -> str:
130130
qubits = ', '.join(str(e) for e in self.qubits)
131-
return f'{self.gate}({qubits})'
131+
return f'{self.gate}({qubits})' if qubits else str(self.gate)
132132

133133
def _json_dict_(self) -> Dict[str, Any]:
134134
return protocols.obj_to_dict_helper(self, ['gate', 'qubits'])

cirq-core/cirq/ops/gateset.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -326,6 +326,9 @@ def __contains__(self, item: Union[raw_types.Gate, raw_types.Operation]) -> bool
326326
g = item if isinstance(item, raw_types.Gate) else item.gate
327327
assert g is not None, f'`item`: {item} must be a gate or have a valid `item.gate`'
328328

329+
if isinstance(g, global_phase_op.GlobalPhaseGate):
330+
return self._accept_global_phase_op
331+
329332
if g in self._instance_gate_families:
330333
assert item in self._instance_gate_families[g], (
331334
f"{item} instance matches {self._instance_gate_families[g]} but "
@@ -396,8 +399,6 @@ def _validate_operation(self, op: raw_types.Operation) -> bool:
396399
lambda q: cast(circuit_operation.CircuitOperation, op).qubit_map.get(q, q)
397400
)
398401
return self.validate(op_circuit)
399-
elif isinstance(op, global_phase_op.GlobalPhaseOperation):
400-
return self._accept_global_phase_op
401402
else:
402403
return False
403404

cirq-core/cirq/ops/gateset_test.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ def test_gate_family_eq():
132132
(cirq.SingleQubitGate(), False),
133133
(cirq.X ** 0.5, False),
134134
(None, False),
135-
(cirq.GlobalPhaseOperation(1j), False),
135+
(cirq.global_phase_operation(1j), False),
136136
],
137137
),
138138
(
@@ -144,7 +144,7 @@ def test_gate_family_eq():
144144
(CustomX ** 3, True),
145145
(CustomX ** sympy.Symbol('theta'), False),
146146
(None, False),
147-
(cirq.GlobalPhaseOperation(1j), False),
147+
(cirq.global_phase_operation(1j), False),
148148
],
149149
),
150150
(
@@ -255,7 +255,7 @@ def get_ops(use_circuit_op, use_global_phase):
255255
)
256256
yield [circuit_op, recursive_circuit_op]
257257
if use_global_phase:
258-
yield cirq.GlobalPhaseOperation(1j)
258+
yield cirq.global_phase_operation(1j)
259259

260260
def assert_validate_and_contains_consistent(gateset, op_tree, result):
261261
assert all(op in gateset for op in cirq.flatten_to_ops(op_tree)) is result

cirq-core/cirq/ops/global_phase_op.py

+51-14
Original file line numberDiff line numberDiff line change
@@ -12,42 +12,69 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
"""A no-qubit global phase operation."""
15-
from typing import Any, Dict, Tuple, TYPE_CHECKING
15+
from typing import Any, Dict, Sequence, Tuple, TYPE_CHECKING
1616

1717
import numpy as np
1818

1919
from cirq import value, protocols
20-
from cirq.ops import raw_types
20+
from cirq._compat import deprecated_class
21+
from cirq.ops import gate_operation, raw_types
2122

2223
if TYPE_CHECKING:
2324
import cirq
2425

2526

2627
@value.value_equality(approximate=True)
27-
class GlobalPhaseOperation(raw_types.Operation):
28+
@deprecated_class(deadline='v0.16', fix='Use cirq.global_phase_operation')
29+
class GlobalPhaseOperation(gate_operation.GateOperation):
2830
def __init__(self, coefficient: value.Scalar, atol: float = 1e-8) -> None:
29-
if abs(1 - abs(coefficient)) > atol:
30-
raise ValueError(f'Coefficient is not unitary: {coefficient!r}')
31-
self.coefficient = coefficient
32-
33-
@property
34-
def qubits(self) -> Tuple['cirq.Qid', ...]:
35-
return ()
31+
gate = GlobalPhaseGate(coefficient, atol)
32+
super().__init__(gate, [])
3633

3734
def with_qubits(self, *new_qubits) -> 'GlobalPhaseOperation':
3835
if new_qubits:
3936
raise ValueError(f'{self!r} applies to 0 qubits but new_qubits={new_qubits!r}.')
4037
return self
4138

39+
@property
40+
def coefficient(self) -> value.Scalar:
41+
return self.gate.coefficient # type: ignore
42+
43+
@coefficient.setter
44+
def coefficient(self, coefficient: value.Scalar):
45+
# coverage: ignore
46+
self.gate._coefficient = coefficient # type: ignore
47+
48+
def __str__(self) -> str:
49+
return str(self.coefficient)
50+
51+
def __repr__(self) -> str:
52+
return f'cirq.GlobalPhaseOperation({self.coefficient!r})'
53+
54+
def _json_dict_(self) -> Dict[str, Any]:
55+
return protocols.obj_to_dict_helper(self, ['coefficient'])
56+
57+
58+
@value.value_equality(approximate=True)
59+
class GlobalPhaseGate(raw_types.Gate):
60+
def __init__(self, coefficient: value.Scalar, atol: float = 1e-8) -> None:
61+
if abs(1 - abs(coefficient)) > atol:
62+
raise ValueError(f'Coefficient is not unitary: {coefficient!r}')
63+
self._coefficient = coefficient
64+
65+
@property
66+
def coefficient(self) -> value.Scalar:
67+
return self._coefficient
68+
4269
def _value_equality_values_(self) -> Any:
4370
return self.coefficient
4471

4572
def _has_unitary_(self) -> bool:
4673
return True
4774

48-
def __pow__(self, power):
75+
def __pow__(self, power) -> 'cirq.GlobalPhaseGate':
4976
if isinstance(power, (int, float)):
50-
return GlobalPhaseOperation(self.coefficient ** power)
77+
return GlobalPhaseGate(self.coefficient ** power)
5178
return NotImplemented
5279

5380
def _unitary_(self) -> np.ndarray:
@@ -60,7 +87,7 @@ def _apply_unitary_(self, args) -> np.ndarray:
6087
def _has_stabilizer_effect_(self) -> bool:
6188
return True
6289

63-
def _act_on_(self, args: 'cirq.ActOnArgs'):
90+
def _act_on_(self, args: 'cirq.ActOnArgs', qubits):
6491
from cirq.sim import clifford
6592

6693
if isinstance(args, clifford.ActOnCliffordTableauArgs):
@@ -78,7 +105,17 @@ def __str__(self) -> str:
78105
return str(self.coefficient)
79106

80107
def __repr__(self) -> str:
81-
return f'cirq.GlobalPhaseOperation({self.coefficient!r})'
108+
return f'cirq.GlobalPhaseGate({self.coefficient!r})'
109+
110+
def _op_repr_(self, qubits: Sequence['cirq.Qid']) -> str:
111+
return f'cirq.global_phase_operation({self.coefficient!r})'
82112

83113
def _json_dict_(self) -> Dict[str, Any]:
84114
return protocols.obj_to_dict_helper(self, ['coefficient'])
115+
116+
def _qid_shape_(self) -> Tuple[int, ...]:
117+
return tuple()
118+
119+
120+
def global_phase_operation(coefficient: value.Scalar, atol: float = 1e-8) -> 'cirq.GateOperation':
121+
return GlobalPhaseGate(coefficient, atol)()

0 commit comments

Comments
 (0)