Skip to content

Commit 7f53521

Browse files
tanujkhattarrht
authored andcommitted
Adds default decompositions for cirq.MatrixGate into X/Y/Z/CZ target gateset. (quantumlib#5088)
* Add default decompositions for cirq.MatrixGate * Add special case to handle MatrixGate as a sub gate in ControlledGate
1 parent c305340 commit 7f53521

File tree

4 files changed

+70
-21
lines changed

4 files changed

+70
-21
lines changed

cirq-core/cirq/ops/controlled_gate.py

Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,32 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from typing import AbstractSet, Any, cast, Collection, Dict, Optional, Sequence, Tuple, Union
15+
from typing import (
16+
AbstractSet,
17+
Any,
18+
cast,
19+
Collection,
20+
Dict,
21+
List,
22+
Optional,
23+
Sequence,
24+
Tuple,
25+
Union,
26+
TYPE_CHECKING,
27+
)
1628

1729
import numpy as np
1830

19-
import cirq
20-
from cirq import protocols, value
31+
from cirq import protocols, value, _import
2132
from cirq._compat import deprecated
22-
from cirq.ops import raw_types, controlled_operation as cop
33+
from cirq.ops import raw_types, controlled_operation as cop, matrix_gates
2334
from cirq.type_workarounds import NotImplementedType
2435

36+
if TYPE_CHECKING:
37+
import cirq
38+
39+
line_qubit = _import.LazyLoader('line_qubit', globals(), 'cirq.devices')
40+
2541

2642
@value.value_equality
2743
class ControlledGate(raw_types.Gate):
@@ -137,17 +153,21 @@ def num_controls(self) -> int:
137153
return len(self.control_qid_shape)
138154

139155
def _qid_shape_(self) -> Tuple[int, ...]:
140-
return self.control_qid_shape + cirq.qid_shape(self.sub_gate)
156+
return self.control_qid_shape + protocols.qid_shape(self.sub_gate)
141157

142158
def _decompose_(self, qubits):
159+
if isinstance(self.sub_gate, matrix_gates.MatrixGate):
160+
# Default decompositions of 2/3 qubit `cirq.MatrixGate` ignores global phase, which is
161+
# local phase in the controlled variant and hence cannot be ignored.
162+
return NotImplemented
163+
143164
result = protocols.decompose_once_with_qubits(
144165
self.sub_gate, qubits[self.num_controls() :], NotImplemented
145166
)
146-
147167
if result is NotImplemented:
148168
return NotImplemented
149169

150-
decomposed = []
170+
decomposed: List['cirq.Operation'] = []
151171
for op in result:
152172
decomposed.append(
153173
cop.ControlledOperation(qubits[: self.num_controls()], op, self.control_values)
@@ -172,7 +192,7 @@ def _value_equality_values_(self):
172192
)
173193

174194
def _apply_unitary_(self, args: 'protocols.ApplyUnitaryArgs') -> np.ndarray:
175-
qubits = cirq.LineQid.for_gate(self)
195+
qubits = line_qubit.LineQid.for_gate(self)
176196
op = self.sub_gate.on(*qubits[self.num_controls() :])
177197
c_op = cop.ControlledOperation(qubits[: self.num_controls()], op, self.control_values)
178198
return protocols.apply_unitary(c_op, args, default=NotImplemented)
@@ -181,7 +201,7 @@ def _has_unitary_(self) -> bool:
181201
return protocols.has_unitary(self.sub_gate)
182202

183203
def _unitary_(self) -> Union[np.ndarray, NotImplementedType]:
184-
qubits = cirq.LineQid.for_gate(self)
204+
qubits = line_qubit.LineQid.for_gate(self)
185205
op = self.sub_gate.on(*qubits[self.num_controls() :])
186206
c_op = cop.ControlledOperation(qubits[: self.num_controls()], op, self.control_values)
187207

@@ -191,7 +211,7 @@ def _has_mixture_(self) -> bool:
191211
return protocols.has_mixture(self.sub_gate)
192212

193213
def _mixture_(self) -> Union[np.ndarray, NotImplementedType]:
194-
qubits = cirq.LineQid.for_gate(self)
214+
qubits = line_qubit.LineQid.for_gate(self)
195215
op = self.sub_gate.on(*qubits[self.num_controls() :])
196216
c_op = cop.ControlledOperation(qubits[: self.num_controls()], op, self.control_values)
197217
return protocols.mixture(c_op, default=NotImplemented)

cirq-core/cirq/ops/matrix_gates.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,23 @@
1818

1919
import numpy as np
2020

21-
from cirq import linalg, protocols
21+
from cirq import linalg, protocols, _import
2222
from cirq._compat import proper_repr
2323
from cirq.ops import raw_types
2424

2525
if TYPE_CHECKING:
2626
import cirq
2727

28+
single_qubit_decompositions = _import.LazyLoader(
29+
'single_qubit_decompositions', globals(), 'cirq.transformers.analytical_decompositions'
30+
)
31+
two_qubit_to_cz = _import.LazyLoader(
32+
'two_qubit_to_cz', globals(), 'cirq.transformers.analytical_decompositions'
33+
)
34+
three_qubit_decomposition = _import.LazyLoader(
35+
'three_qubit_decomposition', globals(), 'cirq.transformers.analytical_decompositions'
36+
)
37+
2838

2939
class MatrixGate(raw_types.Gate):
3040
"""A unitary qubit or qudit gate defined entirely by its matrix."""
@@ -116,6 +126,20 @@ def _phase_by_(self, phase_turns: float, qubit_index: int) -> 'MatrixGate':
116126
result[linalg.slice_for_qubits_equal_to([j], 1)] *= np.conj(p)
117127
return MatrixGate(matrix=result.reshape(self._matrix.shape), qid_shape=self._qid_shape)
118128

129+
def _decompose_(self, qubits: Tuple['cirq.Qid', ...]) -> 'cirq.OP_TREE':
130+
if self._qid_shape == (2,):
131+
return [
132+
g.on(qubits[0])
133+
for g in single_qubit_decompositions.single_qubit_matrix_to_gates(self._matrix)
134+
]
135+
if self._qid_shape == (2,) * 2:
136+
return two_qubit_to_cz.two_qubit_matrix_to_cz_operations(
137+
*qubits, self._matrix, allow_partial_czs=True
138+
)
139+
if self._qid_shape == (2,) * 3:
140+
return three_qubit_decomposition.three_qubit_matrix_to_operations(*qubits, self._matrix)
141+
return NotImplemented
142+
119143
def _has_unitary_(self) -> bool:
120144
return True
121145

cirq-core/cirq/ops/matrix_gates_test.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -276,16 +276,19 @@ def test_str_executes():
276276
assert '0' in str(cirq.MatrixGate(np.eye(4)))
277277

278278

279-
def test_one_qubit_consistent():
280-
u = cirq.testing.random_unitary(2)
281-
g = cirq.MatrixGate(u)
282-
cirq.testing.assert_implements_consistent_protocols(g)
283-
284-
285-
def test_two_qubit_consistent():
286-
u = cirq.testing.random_unitary(4)
287-
g = cirq.MatrixGate(u)
288-
cirq.testing.assert_implements_consistent_protocols(g)
279+
@pytest.mark.parametrize('n', [1, 2, 3, 4, 5])
280+
def test_implements_consistent_protocols(n):
281+
u = cirq.testing.random_unitary(2 ** n)
282+
g1 = cirq.MatrixGate(u)
283+
cirq.testing.assert_implements_consistent_protocols(g1, ignoring_global_phase=True)
284+
cirq.testing.assert_decompose_ends_at_default_gateset(g1)
285+
286+
if n == 1:
287+
return
288+
289+
g2 = cirq.MatrixGate(u, qid_shape=(4,) + (2,) * (n - 2))
290+
cirq.testing.assert_implements_consistent_protocols(g2, ignoring_global_phase=True)
291+
cirq.testing.assert_decompose_ends_at_default_gateset(g2)
289292

290293

291294
def test_repr():

cirq-core/cirq/testing/consistent_decomposition.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@ def assert_decompose_is_consistent_with_unitary(val: Any, ignoring_global_phase:
5151

5252
def _known_gate_with_no_decomposition(val: Any):
5353
"""Checks whether `val` is a known gate with no default decomposition to default gateset."""
54+
if isinstance(val, ops.MatrixGate):
55+
return protocols.qid_shape(val) not in [(2,), (2,) * 2, (2,) * 3]
5456
return False
5557

5658

0 commit comments

Comments
 (0)