Skip to content

Commit 051d572

Browse files
authored
Extend default decomposition of cirq.ControlledGate and cirq.ControlledOperation to end in X/Y/Z/CZ target gateset (quantumlib#5091)
When decomposed, controlled gates and operations simply fall back on the decomposition of underlying sub_gate / sub_operation and return apply appropriate controls to each decomposed operation. If we can ensure that all underlying gates / operations decompose to X/Y/Z/CZ target gateset, then their controlled versions will decompose to: - Multi controlled single qubit rotations (corresponding to (X/Y/Z).controlled_by(...)) OR - Multi controlled CZs, which is also equivalent to a multi controlled single qubit rotation (Z.controlled_by(...)) In Cirq, we have an analytical method to decompose a multi controlled rotation into X/Y/Z/CZ - `cirq.decompose_multi_controlled_rotation`, which is now used in the `_decompose_` method of controlled gates. However, there are many corner cases and limitations of the current approach, which are dealt appropriately in this PR to enable a "best-effort" decomposition of controlled gates to the cirq target gateset. Some of the limitations are: - If decomposition of sub_gate / sub_operation ignores global phase, then the controlled operation cannot directly rely on decomposing the sub operation. An explicit check is added to not fallback on sub_gate if sub_gate is a MatrixGate. - `decompose_multi_controlled_rotation` works only for qubits (doesn't work for qudits) and when all control_values are 1. Appropriate logic is added to extend its functionality to handle control_values which are 0 or (0, 1). - We have explicit types for a few important controlled gates, like `CCZ`, `CZ`, `CCX`, `CX` etc. in cirq. Appropriate type conversion logic is added to smartly infer the types of equivalent gates (eg: Controlled(sub_gate=CZ) should be inferred as CCZ) such that their decompositions can be used for decomposing the controlled gates. This is definitely the most tricky one to get right and I've added appropriate tests to cover the different cases. Part of quantumlib#4858
1 parent 2ea535d commit 051d572

5 files changed

+179
-43
lines changed

cirq/ops/controlled_gate.py

+39-1
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,10 @@
3636
if TYPE_CHECKING:
3737
import cirq
3838

39+
controlled_gate_decomposition = _import.LazyLoader(
40+
'controlled_gate_decomposition', globals(), 'cirq.transformers.analytical_decompositions'
41+
)
42+
common_gates = _import.LazyLoader('common_gates', globals(), 'cirq.ops')
3943
line_qubit = _import.LazyLoader('line_qubit', globals(), 'cirq.devices')
4044

4145

@@ -156,6 +160,40 @@ def _qid_shape_(self) -> Tuple[int, ...]:
156160
return self.control_qid_shape + protocols.qid_shape(self.sub_gate)
157161

158162
def _decompose_(self, qubits):
163+
if (
164+
protocols.has_unitary(self.sub_gate)
165+
and protocols.num_qubits(self.sub_gate) == 1
166+
and self._qid_shape_() == (2,) * len(self._qid_shape_())
167+
):
168+
control_qubits = list(qubits[: self.num_controls()])
169+
invert_ops: List['cirq.Operation'] = []
170+
for cvals, cqbit in zip(self.control_values, qubits[: self.num_controls()]):
171+
if set(cvals) == {0}:
172+
invert_ops.append(common_gates.X(cqbit))
173+
elif set(cvals) == {0, 1}:
174+
control_qubits.remove(cqbit)
175+
decomposed_ops = controlled_gate_decomposition.decompose_multi_controlled_rotation(
176+
protocols.unitary(self.sub_gate), control_qubits, qubits[-1]
177+
)
178+
return invert_ops + decomposed_ops + invert_ops
179+
180+
if isinstance(self.sub_gate, common_gates.CZPowGate):
181+
z_sub_gate = common_gates.ZPowGate(
182+
exponent=self.sub_gate.exponent, global_shift=self.sub_gate.global_shift
183+
)
184+
kwargs = {
185+
'num_controls': self.num_controls() + 1,
186+
'control_values': self.control_values + (1,),
187+
'control_qid_shape': self.control_qid_shape + (2,),
188+
}
189+
controlled_z = (
190+
z_sub_gate.controlled(**kwargs)
191+
if protocols.is_parameterized(self)
192+
else ControlledGate(z_sub_gate, **kwargs)
193+
)
194+
if self != controlled_z:
195+
return protocols.decompose_once_with_qubits(controlled_z, qubits, NotImplemented)
196+
159197
if isinstance(self.sub_gate, matrix_gates.MatrixGate):
160198
# Default decompositions of 2/3 qubit `cirq.MatrixGate` ignores global phase, which is
161199
# local phase in the controlled variant and hence cannot be ignored.
@@ -170,7 +208,7 @@ def _decompose_(self, qubits):
170208
decomposed: List['cirq.Operation'] = []
171209
for op in result:
172210
decomposed.append(
173-
cop.ControlledOperation(qubits[: self.num_controls()], op, self.control_values)
211+
op.controlled_by(*qubits[: self.num_controls()], control_values=self.control_values)
174212
)
175213
return decomposed
176214

cirq/ops/controlled_gate_test.py

+42-19
Original file line numberDiff line numberDiff line change
@@ -39,19 +39,28 @@ def __repr__(self):
3939

4040

4141
class GateAllocatingNewSpaceForResult(cirq.SingleQubitGate):
42+
def __init__(self):
43+
self._matrix = cirq.testing.random_unitary(2, random_state=4321)
44+
4245
def _apply_unitary_(self, args: cirq.ApplyUnitaryArgs) -> Union[np.ndarray, NotImplementedType]:
4346
assert len(args.axes) == 1
4447
a = args.axes[0]
4548
seed = cast(Tuple[Union[int, slice, 'ellipsis'], ...], (slice(None),))
4649
zero = seed * a + (0, Ellipsis)
4750
one = seed * a + (1, Ellipsis)
4851
result = np.zeros(args.target_tensor.shape, args.target_tensor.dtype)
49-
result[zero] = args.target_tensor[zero] * 2 + args.target_tensor[one] * 3
50-
result[one] = args.target_tensor[zero] * 5 + args.target_tensor[one] * 7
52+
result[zero] = (
53+
args.target_tensor[zero] * self._matrix[0][0]
54+
+ args.target_tensor[one] * self._matrix[0][1]
55+
)
56+
result[one] = (
57+
args.target_tensor[zero] * self._matrix[1][0]
58+
+ args.target_tensor[one] * self._matrix[1][1]
59+
)
5160
return result
5261

5362
def _unitary_(self):
54-
return np.array([[2, 3], [5, 7]])
63+
return self._matrix
5564

5665
def __eq__(self, other):
5766
return isinstance(other, type(self))
@@ -316,28 +325,42 @@ def test_unitary():
316325

317326

318327
@pytest.mark.parametrize(
319-
'gate',
328+
'gate, should_decompose_to_target',
320329
[
321-
cirq.X,
322-
cirq.X ** 0.5,
323-
cirq.rx(np.pi),
324-
cirq.rx(np.pi / 2),
325-
cirq.Z,
326-
cirq.H,
327-
cirq.CNOT,
328-
cirq.SWAP,
329-
cirq.CCZ,
330-
cirq.ControlledGate(cirq.ControlledGate(cirq.CCZ)),
331-
GateUsingWorkspaceForApplyUnitary(),
332-
GateAllocatingNewSpaceForResult(),
333-
cirq.IdentityGate(qid_shape=(3, 4)),
330+
(cirq.X, True),
331+
(cirq.X ** 0.5, True),
332+
(cirq.rx(np.pi), True),
333+
(cirq.rx(np.pi / 2), True),
334+
(cirq.Z, True),
335+
(cirq.H, True),
336+
(cirq.CNOT, True),
337+
(cirq.SWAP, True),
338+
(cirq.CCZ, True),
339+
(cirq.ControlledGate(cirq.ControlledGate(cirq.CCZ)), True),
340+
(GateUsingWorkspaceForApplyUnitary(), True),
341+
(GateAllocatingNewSpaceForResult(), True),
342+
(cirq.IdentityGate(qid_shape=(3, 4)), True),
343+
(
344+
cirq.ControlledGate(
345+
cirq.XXPowGate(exponent=0.25, global_shift=-0.5),
346+
num_controls=2,
347+
control_values=(1, (1, 0)),
348+
),
349+
True,
350+
),
334351
# Single qudit gate with dimension 4.
335-
cirq.MatrixGate(np.kron(*(cirq.unitary(cirq.H),) * 2)),
352+
(cirq.MatrixGate(np.kron(*(cirq.unitary(cirq.H),) * 2), qid_shape=(4,)), False),
353+
(cirq.MatrixGate(cirq.testing.random_unitary(4, random_state=1234)), False),
354+
(cirq.XX ** sympy.Symbol("s"), True),
355+
(cirq.CZ ** sympy.Symbol("s"), True),
336356
],
337357
)
338-
def test_controlled_gate_is_consistent(gate: cirq.Gate):
358+
def test_controlled_gate_is_consistent(gate: cirq.Gate, should_decompose_to_target):
339359
cgate = cirq.ControlledGate(gate)
340360
cirq.testing.assert_implements_consistent_protocols(cgate)
361+
cirq.testing.assert_decompose_ends_at_default_gateset(
362+
cgate, ignore_known_gates=not should_decompose_to_target
363+
)
341364

342365

343366
def test_pow_inverse():

cirq/ops/controlled_operation.py

+13-2
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030

3131
from cirq import protocols, qis, value
3232
from cirq._compat import deprecated
33-
from cirq.ops import raw_types, gate_operation, controlled_gate
33+
from cirq.ops import raw_types, gate_operation, controlled_gate, matrix_gates
3434
from cirq.type_workarounds import NotImplementedType
3535

3636
if TYPE_CHECKING:
@@ -130,11 +130,22 @@ def with_qubits(self, *new_qubits):
130130
)
131131

132132
def _decompose_(self):
133+
result = protocols.decompose_once_with_qubits(self.gate, self.qubits, NotImplemented)
134+
if result is not NotImplemented:
135+
return result
136+
137+
if isinstance(self.sub_operation.gate, matrix_gates.MatrixGate):
138+
# Default decompositions of 2/3 qubit `cirq.MatrixGate` ignores global phase, which is
139+
# local phase in the controlled variant and hence cannot be ignored.
140+
return NotImplemented
141+
133142
result = protocols.decompose_once(self.sub_operation, NotImplemented)
134143
if result is NotImplemented:
135144
return NotImplemented
136145

137-
return [ControlledOperation(self.controls, op, self.control_values) for op in result]
146+
return [
147+
op.controlled_by(*self.controls, control_values=self.control_values) for op in result
148+
]
138149

139150
def _value_equality_values_(self):
140151
return (frozenset(zip(self.controls, self.control_values)), self.sub_operation)

cirq/ops/controlled_operation_test.py

+75-17
Original file line numberDiff line numberDiff line change
@@ -40,19 +40,28 @@ def __repr__(self):
4040

4141

4242
class GateAllocatingNewSpaceForResult(cirq.SingleQubitGate):
43+
def __init__(self):
44+
self._matrix = cirq.testing.random_unitary(2, random_state=1234)
45+
4346
def _apply_unitary_(self, args: cirq.ApplyUnitaryArgs) -> Union[np.ndarray, NotImplementedType]:
4447
assert len(args.axes) == 1
4548
a = args.axes[0]
4649
seed = cast(Tuple[Union[int, slice, 'ellipsis'], ...], (slice(None),))
4750
zero = seed * a + (0, Ellipsis)
4851
one = seed * a + (1, Ellipsis)
4952
result = np.zeros(args.target_tensor.shape, args.target_tensor.dtype)
50-
result[zero] = args.target_tensor[zero] * 2 + args.target_tensor[one] * 3
51-
result[one] = args.target_tensor[zero] * 5 + args.target_tensor[one] * 7
53+
result[zero] = (
54+
args.target_tensor[zero] * self._matrix[0][0]
55+
+ args.target_tensor[one] * self._matrix[0][1]
56+
)
57+
result[one] = (
58+
args.target_tensor[zero] * self._matrix[1][0]
59+
+ args.target_tensor[one] * self._matrix[1][1]
60+
)
5261
return result
5362

5463
def _unitary_(self):
55-
return np.array([[2, 3], [5, 7]])
64+
return self._matrix
5665

5766
def __eq__(self, other):
5867
return isinstance(other, type(self))
@@ -297,33 +306,82 @@ class UndiagrammableGate(cirq.SingleQubitGate):
297306

298307

299308
@pytest.mark.parametrize(
300-
'gate',
309+
'gate, should_decompose_to_target',
301310
[
302-
cirq.X(cirq.NamedQubit('q1')),
303-
cirq.X(cirq.NamedQubit('q1')) ** 0.5,
304-
cirq.rx(np.pi)(cirq.NamedQubit('q1')),
305-
cirq.rx(np.pi / 2)(cirq.NamedQubit('q1')),
306-
cirq.Z(cirq.NamedQubit('q1')),
307-
cirq.H(cirq.NamedQubit('q1')),
308-
cirq.CNOT(cirq.NamedQubit('q1'), cirq.NamedQubit('q2')),
309-
cirq.SWAP(cirq.NamedQubit('q1'), cirq.NamedQubit('q2')),
310-
cirq.CCZ(cirq.NamedQubit('q1'), cirq.NamedQubit('q2'), cirq.NamedQubit('q3')),
311-
cirq.ControlledGate(cirq.ControlledGate(cirq.CCZ))(*cirq.LineQubit.range(5)),
312-
GateUsingWorkspaceForApplyUnitary()(cirq.NamedQubit('q1')),
313-
GateAllocatingNewSpaceForResult()(cirq.NamedQubit('q1')),
311+
(cirq.X(cirq.NamedQubit('q1')), True),
312+
(cirq.X(cirq.NamedQubit('q1')) ** 0.5, True),
313+
(cirq.rx(np.pi)(cirq.NamedQubit('q1')), True),
314+
(cirq.rx(np.pi / 2)(cirq.NamedQubit('q1')), True),
315+
(cirq.Z(cirq.NamedQubit('q1')), True),
316+
(cirq.H(cirq.NamedQubit('q1')), True),
317+
(cirq.CNOT(cirq.NamedQubit('q1'), cirq.NamedQubit('q2')), True),
318+
(cirq.SWAP(cirq.NamedQubit('q1'), cirq.NamedQubit('q2')), True),
319+
(cirq.CCZ(cirq.NamedQubit('q1'), cirq.NamedQubit('q2'), cirq.NamedQubit('q3')), True),
320+
(cirq.ControlledGate(cirq.ControlledGate(cirq.CCZ))(*cirq.LineQubit.range(5)), True),
321+
(GateUsingWorkspaceForApplyUnitary()(cirq.NamedQubit('q1')), True),
322+
(GateAllocatingNewSpaceForResult()(cirq.NamedQubit('q1')), True),
323+
(
324+
cirq.MatrixGate(np.kron(*(cirq.unitary(cirq.H),) * 2), qid_shape=(4,)).on(
325+
cirq.NamedQid("q", 4)
326+
),
327+
False,
328+
),
329+
(
330+
cirq.MatrixGate(cirq.testing.random_unitary(4, random_state=1234)).on(
331+
cirq.NamedQubit('q1'), cirq.NamedQubit('q2')
332+
),
333+
False,
334+
),
335+
(cirq.XX(cirq.NamedQubit('q1'), cirq.NamedQubit('q2')) ** sympy.Symbol("s"), True),
336+
(cirq.DiagonalGate(sympy.symbols("s1, s2")).on(cirq.NamedQubit("q")), False),
314337
],
315338
)
316-
def test_controlled_operation_is_consistent(gate: cirq.GateOperation):
339+
def test_controlled_operation_is_consistent(
340+
gate: cirq.GateOperation, should_decompose_to_target: bool
341+
):
317342
cb = cirq.NamedQubit('ctr')
318343
cgate = cirq.ControlledOperation([cb], gate)
319344
cirq.testing.assert_implements_consistent_protocols(cgate)
345+
cirq.testing.assert_decompose_ends_at_default_gateset(
346+
cgate, ignore_known_gates=not should_decompose_to_target
347+
)
320348

321349
cgate = cirq.ControlledOperation([cb], gate, control_values=[0])
322350
cirq.testing.assert_implements_consistent_protocols(cgate)
351+
cirq.testing.assert_decompose_ends_at_default_gateset(
352+
cgate, ignore_known_gates=(not should_decompose_to_target or cirq.is_parameterized(gate))
353+
)
354+
355+
cgate = cirq.ControlledOperation([cb], gate, control_values=[(0, 1)])
356+
cirq.testing.assert_implements_consistent_protocols(cgate)
357+
cirq.testing.assert_decompose_ends_at_default_gateset(
358+
cgate, ignore_known_gates=(not should_decompose_to_target or cirq.is_parameterized(gate))
359+
)
323360

324361
cb3 = cb.with_dimension(3)
325362
cgate = cirq.ControlledOperation([cb3], gate, control_values=[(0, 2)])
326363
cirq.testing.assert_implements_consistent_protocols(cgate)
364+
cirq.testing.assert_decompose_ends_at_default_gateset(cgate)
365+
366+
367+
def test_controlled_circuit_operation_is_consistent():
368+
op = cirq.CircuitOperation(
369+
cirq.FrozenCircuit(
370+
cirq.XXPowGate(exponent=0.25, global_shift=-0.5).on(*cirq.LineQubit.range(2))
371+
)
372+
)
373+
cb = cirq.NamedQubit('ctr')
374+
cop = cirq.ControlledOperation([cb], op)
375+
cirq.testing.assert_implements_consistent_protocols(cop, exponents=(-1, 1, 2))
376+
cirq.testing.assert_decompose_ends_at_default_gateset(cop)
377+
378+
cop = cirq.ControlledOperation([cb], op, control_values=[0])
379+
cirq.testing.assert_implements_consistent_protocols(cop, exponents=(-1, 1, 2))
380+
cirq.testing.assert_decompose_ends_at_default_gateset(cop)
381+
382+
cop = cirq.ControlledOperation([cb], op, control_values=[(0, 1)])
383+
cirq.testing.assert_implements_consistent_protocols(cop, exponents=(-1, 1, 2))
384+
cirq.testing.assert_decompose_ends_at_default_gateset(cop)
327385

328386

329387
@pytest.mark.parametrize('resolve_fn', [cirq.resolve_parameters, cirq.resolve_parameters_once])

cirq/testing/consistent_decomposition.py

+10-4
Original file line numberDiff line numberDiff line change
@@ -53,16 +53,22 @@ def _known_gate_with_no_decomposition(val: Any):
5353
"""Checks whether `val` is a known gate with no default decomposition to default gateset."""
5454
if isinstance(val, ops.MatrixGate):
5555
return protocols.qid_shape(val) not in [(2,), (2,) * 2, (2,) * 3]
56+
if isinstance(val, ops.ControlledGate):
57+
if protocols.is_parameterized(val):
58+
return True
59+
if isinstance(val.sub_gate, ops.MatrixGate) and protocols.num_qubits(val.sub_gate) > 1:
60+
return True
61+
if val.control_qid_shape != (2,) * val.num_controls():
62+
return True
63+
return _known_gate_with_no_decomposition(val.sub_gate)
5664
return False
5765

5866

59-
def assert_decompose_ends_at_default_gateset(val: Any):
67+
def assert_decompose_ends_at_default_gateset(val: Any, ignore_known_gates: bool = True):
6068
"""Asserts that cirq.decompose(val) ends at default cirq gateset or a known gate."""
61-
if _known_gate_with_no_decomposition(val):
62-
return # coverage: ignore
6369
args = () if isinstance(val, ops.Operation) else (tuple(devices.LineQid.for_gate(val)),)
6470
dec_once = protocols.decompose_once(val, [val(*args[0]) if args else val], *args)
6571
for op in [*ops.flatten_to_ops(protocols.decompose(d) for d in dec_once)]:
66-
assert _known_gate_with_no_decomposition(op.gate) or (
72+
assert (_known_gate_with_no_decomposition(op.gate) and ignore_known_gates) or (
6773
op in protocols.decompose_protocol.DECOMPOSE_TARGET_GATESET
6874
), f'{val} decomposed to {op}, which is not part of default cirq target gateset.'

0 commit comments

Comments
 (0)