Skip to content

Override gate.controlled() for GlobalPhaseGate to return a ZPowGate #6073

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
May 10, 2023
30 changes: 28 additions & 2 deletions cirq-core/cirq/ops/global_phase_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@
# limitations under the License.
"""A no-qubit global phase operation."""

from typing import AbstractSet, Any, cast, Dict, Sequence, Tuple, Union
from typing import AbstractSet, Any, cast, Dict, Sequence, Tuple, Union, Optional, Collection

import numpy as np
import sympy

import cirq
from cirq import value, protocols
from cirq.ops import raw_types
from cirq.ops import raw_types, controlled_gate, control_values as cv
from cirq.type_workarounds import NotImplementedType


Expand Down Expand Up @@ -91,6 +91,32 @@ def _resolve_parameters_(
coefficient = protocols.resolve_parameters(self.coefficient, resolver, recursive)
return GlobalPhaseGate(coefficient=coefficient)

def controlled(
self,
num_controls: Optional[int] = None,
control_values: Optional[
Union[cv.AbstractControlValues, Sequence[Union[int, Collection[int]]]]
] = None,
control_qid_shape: Optional[Tuple[int, ...]] = None,
) -> raw_types.Gate:
result = super().controlled(num_controls, control_values, control_qid_shape)
if (
not self._is_parameterized_()
and isinstance(result, controlled_gate.ControlledGate)
and isinstance(result.control_values, cv.ProductOfSums)
and result.control_values[-1] == (1,)
and result.control_qid_shape[-1] == 2
):
# A `GlobalPhaseGate` controlled on a qubit in state `|1>` is equivalent
# to applying a `ZPowGate`. This override ensures that `global_phase_gate.controlled()`
# returns a `ZPowGate` instead of a `ControlledGate(sub_gate=global_phase_gate)`.
coefficient = complex(self.coefficient)
exponent = float(np.angle(coefficient) / np.pi)
return cirq.ZPowGate(exponent=exponent).controlled(
result.num_controls() - 1, result.control_values[:-1], result.control_qid_shape[:-1]
)
return result


def global_phase_operation(
coefficient: 'cirq.TParamValComplex', atol: float = 1e-8
Expand Down
21 changes: 21 additions & 0 deletions cirq-core/cirq/ops/global_phase_op_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,3 +279,24 @@ def test_resolve_error(resolve_fn):
gpt = cirq.GlobalPhaseGate(coefficient=t)
with pytest.raises(ValueError, match='Coefficient is not unitary'):
resolve_fn(gpt, {'t': -2})


@pytest.mark.parametrize(
'coeff, exp', [(-1, 1), (1j, 0.5), (-1j, -0.5), (1 / np.sqrt(2) * (1 + 1j), 0.25)]
)
def test_global_phase_gate_controlled(coeff, exp):
g = cirq.GlobalPhaseGate(coeff)
op = cirq.global_phase_operation(coeff)
q = cirq.LineQubit.range(3)
for num_controls, target_gate in zip(range(1, 4), [cirq.Z, cirq.CZ, cirq.CCZ]):
assert g.controlled(num_controls) == target_gate**exp
np.testing.assert_allclose(
cirq.unitary(cirq.ControlledGate(g, num_controls)),
cirq.unitary(g.controlled(num_controls)),
)
assert op.controlled_by(*q[:num_controls]) == target_gate(*q[:num_controls]) ** exp
assert g.controlled(control_values=[0]) == cirq.ControlledGate(g, control_values=[0])
xor_control_values = cirq.SumOfProducts(((0, 0), (1, 1)))
assert g.controlled(control_values=xor_control_values) == cirq.ControlledGate(
g, control_values=xor_control_values
)