|
13 | 13 | # limitations under the License.
|
14 | 14 | """A no-qubit global phase operation."""
|
15 | 15 |
|
16 |
| -from typing import AbstractSet, Any, cast, Dict, Sequence, Tuple, Union |
| 16 | +from typing import AbstractSet, Any, cast, Dict, Sequence, Tuple, Union, Optional, Collection |
17 | 17 |
|
18 | 18 | import numpy as np
|
19 | 19 | import sympy
|
20 | 20 |
|
21 | 21 | import cirq
|
22 | 22 | from cirq import value, protocols
|
23 |
| -from cirq.ops import raw_types |
| 23 | +from cirq.ops import raw_types, controlled_gate, control_values as cv |
24 | 24 | from cirq.type_workarounds import NotImplementedType
|
25 | 25 |
|
26 | 26 |
|
@@ -91,6 +91,32 @@ def _resolve_parameters_(
|
91 | 91 | coefficient = protocols.resolve_parameters(self.coefficient, resolver, recursive)
|
92 | 92 | return GlobalPhaseGate(coefficient=coefficient)
|
93 | 93 |
|
| 94 | + def controlled( |
| 95 | + self, |
| 96 | + num_controls: Optional[int] = None, |
| 97 | + control_values: Optional[ |
| 98 | + Union[cv.AbstractControlValues, Sequence[Union[int, Collection[int]]]] |
| 99 | + ] = None, |
| 100 | + control_qid_shape: Optional[Tuple[int, ...]] = None, |
| 101 | + ) -> raw_types.Gate: |
| 102 | + result = super().controlled(num_controls, control_values, control_qid_shape) |
| 103 | + if ( |
| 104 | + not self._is_parameterized_() |
| 105 | + and isinstance(result, controlled_gate.ControlledGate) |
| 106 | + and isinstance(result.control_values, cv.ProductOfSums) |
| 107 | + and result.control_values[-1] == (1,) |
| 108 | + and result.control_qid_shape[-1] == 2 |
| 109 | + ): |
| 110 | + # A `GlobalPhaseGate` controlled on a qubit in state `|1>` is equivalent |
| 111 | + # to applying a `ZPowGate`. This override ensures that `global_phase_gate.controlled()` |
| 112 | + # returns a `ZPowGate` instead of a `ControlledGate(sub_gate=global_phase_gate)`. |
| 113 | + coefficient = complex(self.coefficient) |
| 114 | + exponent = float(np.angle(coefficient) / np.pi) |
| 115 | + return cirq.ZPowGate(exponent=exponent).controlled( |
| 116 | + result.num_controls() - 1, result.control_values[:-1], result.control_qid_shape[:-1] |
| 117 | + ) |
| 118 | + return result |
| 119 | + |
94 | 120 |
|
95 | 121 | def global_phase_operation(
|
96 | 122 | coefficient: 'cirq.TParamValComplex', atol: float = 1e-8
|
|
0 commit comments