Skip to content

Commit 098617c

Browse files
authored
Add support for decompositions of parameterized cirq.DiagonalGate (#5085)
- Adds support for decomposing parameterized `cirq.DiagonalGate`. - Global phase is ignored for parameterized version because `cirq.GlobalPhaseGate` doesn't yet support symbols. - Part of #4858
1 parent 2fb5651 commit 098617c

File tree

2 files changed

+27
-15
lines changed

2 files changed

+27
-15
lines changed

cirq-core/cirq/ops/diagonal_gate.py

+9-8
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ def _value_equality_values_(self) -> Any:
144144
return tuple(self._diag_angles_radians)
145145

146146
def _decompose_for_basis(
147-
self, index: int, bit_flip: int, theta: float, qubits: Sequence['cirq.Qid']
147+
self, index: int, bit_flip: int, theta: value.TParamVal, qubits: Sequence['cirq.Qid']
148148
) -> Iterator[Union['cirq.ZPowGate', 'cirq.CXPowGate']]:
149149
if index == 0:
150150
return []
@@ -166,7 +166,7 @@ def _decompose_(self, qubits: Sequence['cirq.Qid']) -> 'cirq.OP_TREE':
166166
│ │ │ │
167167
2: ───Rz(1)───@───────────@───────────────────────@───────────────────────@───────────
168168
169-
where the angles in Rz gates are corresponding to the fast-walsh-Hadamard transfrom
169+
where the angles in Rz gates are corresponding to the fast-walsh-Hadamard transform
170170
of diagonal_angles in the Gray Code order.
171171
172172
For n qubits decomposition looks similar but with 2^n-1 Rz gates and 2^n-2 CNOT gates.
@@ -176,19 +176,20 @@ def _decompose_(self, qubits: Sequence['cirq.Qid']) -> 'cirq.OP_TREE':
176176
ancillas." New Journal of Physics 16.3 (2014): 033040.
177177
https://iopscience.iop.org/article/10.1088/1367-2630/16/3/033040/meta
178178
"""
179-
if protocols.is_parameterized(self):
180-
return NotImplemented
181-
182179
n = self._num_qubits_()
183180
hat_angles = _fast_walsh_hadamard_transform(self._diag_angles_radians) / (2 ** n)
184181

185182
# There is one global phase shift between unitary matrix of the diagonal gate and the
186183
# decomposed gates. On its own it is not physically observable. However, if using this
187184
# diagonal gate for sub-system like controlled gate, it is no longer equivalent. Hence,
188185
# we add global phase.
189-
decomposed_circ: List[Any] = [
190-
global_phase_op.global_phase_operation(np.exp(1j * hat_angles[0]))
191-
]
186+
# Global phase is ignored for parameterized gates as `cirq.GlobalPhaseGate` expects a
187+
# scalar value.
188+
decomposed_circ: List[Any] = (
189+
[global_phase_op.global_phase_operation(np.exp(1j * hat_angles[0]))]
190+
if not protocols.is_parameterized(hat_angles[0])
191+
else []
192+
)
192193
for i, bit_flip in _gen_gray_code(n):
193194
decomposed_circ.extend(self._decompose_for_basis(i, bit_flip, -hat_angles[i], qubits))
194195
return decomposed_circ

cirq-core/cirq/ops/diagonal_gate_test.py

+18-7
Original file line numberDiff line numberDiff line change
@@ -77,13 +77,24 @@ def test_decomposition_diagonal_exponent(n):
7777
np.testing.assert_allclose(decomposed_f, expected_f)
7878

7979

80-
def test_decomposition_with_parameterization():
81-
diagonal_gate = cirq.DiagonalGate([2, 3, 5, sympy.Symbol('a')])
82-
op = diagonal_gate(*cirq.LineQubit.range(2))
83-
84-
# We do not support the decomposition of parameterized case yet.
85-
# So cirq.decompose should do nothing.
86-
assert cirq.decompose(op) == [op]
80+
@pytest.mark.parametrize('n', [1, 2, 3, 4])
81+
def test_decomposition_with_parameterization(n):
82+
angles = sympy.symbols([f'x_{i}' for i in range(2 ** n)])
83+
exponent = sympy.Symbol('e')
84+
diagonal_gate = cirq.DiagonalGate(angles) ** exponent
85+
parameterized_op = diagonal_gate(*cirq.LineQubit.range(n))
86+
decomposed_circuit = cirq.Circuit(cirq.decompose(parameterized_op))
87+
for exponent_value in [-0.5, 0.5, 1]:
88+
for i in range(len(_candidate_angles) - 2 ** n + 1):
89+
resolver = {exponent: exponent_value}
90+
resolver.update(
91+
{angles[j]: x_j for j, x_j in enumerate(_candidate_angles[i : i + 2 ** n])}
92+
)
93+
resolved_op = cirq.resolve_parameters(parameterized_op, resolver)
94+
resolved_circuit = cirq.resolve_parameters(decomposed_circuit, resolver)
95+
cirq.testing.assert_allclose_up_to_global_phase(
96+
cirq.unitary(resolved_op), cirq.unitary(resolved_circuit), atol=1e-8
97+
)
8798

8899

89100
def test_diagram():

0 commit comments

Comments
 (0)