Skip to content

Commit a95f009

Browse files
authored
Override gate.controlled() for GlobalPhaseGate to return a ZPowGate (#6073)
* Override gate.controlled() for GlobalPhaseGate to return a ZPowGate * Test unitary equivalence * Override controlled only if gate is not parameterized * Fix typo * Fix type check * another attempt at fixing types * Add a comment and additional tests
1 parent f2cd706 commit a95f009

File tree

2 files changed

+49
-2
lines changed

2 files changed

+49
-2
lines changed

Diff for: cirq-core/cirq/ops/global_phase_op.py

+28-2
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,14 @@
1313
# limitations under the License.
1414
"""A no-qubit global phase operation."""
1515

16-
from typing import AbstractSet, Any, cast, Dict, Sequence, Tuple, Union
16+
from typing import AbstractSet, Any, cast, Dict, Sequence, Tuple, Union, Optional, Collection
1717

1818
import numpy as np
1919
import sympy
2020

2121
import cirq
2222
from cirq import value, protocols
23-
from cirq.ops import raw_types
23+
from cirq.ops import raw_types, controlled_gate, control_values as cv
2424
from cirq.type_workarounds import NotImplementedType
2525

2626

@@ -91,6 +91,32 @@ def _resolve_parameters_(
9191
coefficient = protocols.resolve_parameters(self.coefficient, resolver, recursive)
9292
return GlobalPhaseGate(coefficient=coefficient)
9393

94+
def controlled(
95+
self,
96+
num_controls: Optional[int] = None,
97+
control_values: Optional[
98+
Union[cv.AbstractControlValues, Sequence[Union[int, Collection[int]]]]
99+
] = None,
100+
control_qid_shape: Optional[Tuple[int, ...]] = None,
101+
) -> raw_types.Gate:
102+
result = super().controlled(num_controls, control_values, control_qid_shape)
103+
if (
104+
not self._is_parameterized_()
105+
and isinstance(result, controlled_gate.ControlledGate)
106+
and isinstance(result.control_values, cv.ProductOfSums)
107+
and result.control_values[-1] == (1,)
108+
and result.control_qid_shape[-1] == 2
109+
):
110+
# A `GlobalPhaseGate` controlled on a qubit in state `|1>` is equivalent
111+
# to applying a `ZPowGate`. This override ensures that `global_phase_gate.controlled()`
112+
# returns a `ZPowGate` instead of a `ControlledGate(sub_gate=global_phase_gate)`.
113+
coefficient = complex(self.coefficient)
114+
exponent = float(np.angle(coefficient) / np.pi)
115+
return cirq.ZPowGate(exponent=exponent).controlled(
116+
result.num_controls() - 1, result.control_values[:-1], result.control_qid_shape[:-1]
117+
)
118+
return result
119+
94120

95121
def global_phase_operation(
96122
coefficient: 'cirq.TParamValComplex', atol: float = 1e-8

Diff for: cirq-core/cirq/ops/global_phase_op_test.py

+21
Original file line numberDiff line numberDiff line change
@@ -279,3 +279,24 @@ def test_resolve_error(resolve_fn):
279279
gpt = cirq.GlobalPhaseGate(coefficient=t)
280280
with pytest.raises(ValueError, match='Coefficient is not unitary'):
281281
resolve_fn(gpt, {'t': -2})
282+
283+
284+
@pytest.mark.parametrize(
285+
'coeff, exp', [(-1, 1), (1j, 0.5), (-1j, -0.5), (1 / np.sqrt(2) * (1 + 1j), 0.25)]
286+
)
287+
def test_global_phase_gate_controlled(coeff, exp):
288+
g = cirq.GlobalPhaseGate(coeff)
289+
op = cirq.global_phase_operation(coeff)
290+
q = cirq.LineQubit.range(3)
291+
for num_controls, target_gate in zip(range(1, 4), [cirq.Z, cirq.CZ, cirq.CCZ]):
292+
assert g.controlled(num_controls) == target_gate**exp
293+
np.testing.assert_allclose(
294+
cirq.unitary(cirq.ControlledGate(g, num_controls)),
295+
cirq.unitary(g.controlled(num_controls)),
296+
)
297+
assert op.controlled_by(*q[:num_controls]) == target_gate(*q[:num_controls]) ** exp
298+
assert g.controlled(control_values=[0]) == cirq.ControlledGate(g, control_values=[0])
299+
xor_control_values = cirq.SumOfProducts(((0, 0), (1, 1)))
300+
assert g.controlled(control_values=xor_control_values) == cirq.ControlledGate(
301+
g, control_values=xor_control_values
302+
)

0 commit comments

Comments
 (0)