Skip to content

Commit 20b3d93

Browse files
authored
Add _decompose_with_context_ protocol to enable passing qubit manager within decompose (#6118)
* Add _decompose_with_context_ protocol to enable passing qubit manager within decompose * Add more test cases and use typing_extensions for runtime_checkable * Fix lint and coverage tests * another attempt to fix coverage * Fix mypy type check
1 parent 9177708 commit 20b3d93

14 files changed

+375
-45
lines changed

cirq-core/cirq/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,7 @@
284284
qft,
285285
Qid,
286286
QuantumFourierTransformGate,
287+
QubitManager,
287288
QubitOrder,
288289
QubitOrderOrList,
289290
QubitPermutationGate,
@@ -566,6 +567,7 @@
566567
decompose,
567568
decompose_once,
568569
decompose_once_with_qubits,
570+
DecompositionContext,
569571
DEFAULT_RESOLVERS,
570572
definitely_commutes,
571573
equal_up_to_global_phase,

cirq-core/cirq/ops/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,8 @@
120120

121121
from cirq.ops.controlled_operation import ControlledOperation
122122

123+
from cirq.ops.qubit_manager import BorrowableQubit, CleanQubit, QubitManager, SimpleQubitManager
124+
123125
from cirq.ops.qubit_order import QubitOrder
124126

125127
from cirq.ops.qubit_order_or_list import QubitOrderOrList

cirq-core/cirq/ops/classically_controlled_operation.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,12 @@ def with_qubits(self, *new_qubits):
105105
)
106106

107107
def _decompose_(self):
108-
result = protocols.decompose_once(self._sub_operation, NotImplemented, flatten=False)
108+
return self._decompose_with_context_()
109+
110+
def _decompose_with_context_(self, context: Optional['cirq.DecompositionContext'] = None):
111+
result = protocols.decompose_once(
112+
self._sub_operation, NotImplemented, flatten=False, context=context
113+
)
109114
if result is NotImplemented:
110115
return NotImplemented
111116

cirq-core/cirq/ops/controlled_gate.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,11 @@ def _qid_shape_(self) -> Tuple[int, ...]:
151151

152152
def _decompose_(
153153
self, qubits: Tuple['cirq.Qid', ...]
154+
) -> Union[None, NotImplementedType, 'cirq.OP_TREE']:
155+
return self._decompose_with_context_(qubits)
156+
157+
def _decompose_with_context_(
158+
self, qubits: Tuple['cirq.Qid', ...], context: Optional['cirq.DecompositionContext'] = None
154159
) -> Union[None, NotImplementedType, 'cirq.OP_TREE']:
155160
if (
156161
protocols.has_unitary(self.sub_gate)
@@ -192,15 +197,21 @@ def _decompose_(
192197
)
193198
)
194199
if self != controlled_z:
195-
return protocols.decompose_once_with_qubits(controlled_z, qubits, NotImplemented)
200+
return protocols.decompose_once_with_qubits(
201+
controlled_z, qubits, NotImplemented, context=context
202+
)
196203

197204
if isinstance(self.sub_gate, matrix_gates.MatrixGate):
198205
# Default decompositions of 2/3 qubit `cirq.MatrixGate` ignores global phase, which is
199206
# local phase in the controlled variant and hence cannot be ignored.
200207
return NotImplemented
201208

202209
result = protocols.decompose_once_with_qubits(
203-
self.sub_gate, qubits[self.num_controls() :], NotImplemented, flatten=False
210+
self.sub_gate,
211+
qubits[self.num_controls() :],
212+
NotImplemented,
213+
flatten=False,
214+
context=context,
204215
)
205216
if result is NotImplemented:
206217
return NotImplemented

cirq-core/cirq/ops/controlled_operation.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -146,8 +146,11 @@ def with_qubits(self, *new_qubits):
146146
)
147147

148148
def _decompose_(self):
149+
return self._decompose_with_context_()
150+
151+
def _decompose_with_context_(self, context: Optional['cirq.DecompositionContext'] = None):
149152
result = protocols.decompose_once_with_qubits(
150-
self.gate, self.qubits, NotImplemented, flatten=False
153+
self.gate, self.qubits, NotImplemented, flatten=False, context=context
151154
)
152155
if result is not NotImplemented:
153156
return result
@@ -157,7 +160,9 @@ def _decompose_(self):
157160
# local phase in the controlled variant and hence cannot be ignored.
158161
return NotImplemented
159162

160-
result = protocols.decompose_once(self.sub_operation, NotImplemented, flatten=False)
163+
result = protocols.decompose_once(
164+
self.sub_operation, NotImplemented, flatten=False, context=context
165+
)
161166
if result is NotImplemented:
162167
return NotImplemented
163168

cirq-core/cirq/ops/gate_operation.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,8 +160,13 @@ def _num_qubits_(self):
160160
return len(self._qubits)
161161

162162
def _decompose_(self) -> 'cirq.OP_TREE':
163+
return self._decompose_with_context_()
164+
165+
def _decompose_with_context_(
166+
self, context: Optional['cirq.DecompositionContext'] = None
167+
) -> 'cirq.OP_TREE':
163168
return protocols.decompose_once_with_qubits(
164-
self.gate, self.qubits, NotImplemented, flatten=False
169+
self.gate, self.qubits, NotImplemented, flatten=False, context=context
165170
)
166171

167172
def _pauli_expansion_(self) -> value.LinearDict[str]:

cirq-core/cirq/ops/qubit_manager.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
# Copyright 2023 The Cirq Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import abc
16+
import dataclasses
17+
from typing import Iterable, List, TYPE_CHECKING
18+
from cirq.ops import raw_types
19+
20+
if TYPE_CHECKING:
21+
import cirq
22+
23+
24+
class QubitManager(metaclass=abc.ABCMeta):
25+
@abc.abstractmethod
26+
def qalloc(self, n: int, dim: int = 2) -> List['cirq.Qid']:
27+
"""Allocate `n` clean qubits, i.e. qubits guaranteed to be in state |0>."""
28+
29+
@abc.abstractmethod
30+
def qborrow(self, n: int, dim: int = 2) -> List['cirq.Qid']:
31+
"""Allocate `n` dirty qubits, i.e. the returned qubits can be in any state."""
32+
33+
@abc.abstractmethod
34+
def qfree(self, qubits: Iterable['cirq.Qid']) -> None:
35+
"""Free pre-allocated clean or dirty qubits managed by this qubit manager."""
36+
37+
38+
@dataclasses.dataclass(frozen=True)
39+
class _BaseAncillaQid(raw_types.Qid):
40+
id: int
41+
dim: int = 2
42+
43+
def _comparison_key(self) -> int:
44+
return self.id
45+
46+
@property
47+
def dimension(self) -> int:
48+
return self.dim
49+
50+
def __repr__(self) -> str:
51+
dim_str = f', dim={self.dim}' if self.dim != 2 else ''
52+
return f"cirq.ops.{type(self).__name__}({self.id}{dim_str})"
53+
54+
55+
class CleanQubit(_BaseAncillaQid):
56+
"""An internal qid type that represents a clean ancilla allocation."""
57+
58+
def __str__(self) -> str:
59+
dim_str = f' (d={self.dimension})' if self.dim != 2 else ''
60+
return f"_c({self.id}){dim_str}"
61+
62+
63+
class BorrowableQubit(_BaseAncillaQid):
64+
"""An internal qid type that represents a dirty ancilla allocation."""
65+
66+
def __str__(self) -> str:
67+
dim_str = f' (d={self.dimension})' if self.dim != 2 else ''
68+
return f"_b({self.id}){dim_str}"
69+
70+
71+
class SimpleQubitManager(QubitManager):
72+
"""Allocates a new `CleanQubit`/`BorrowableQubit` for every `qalloc`/`qborrow` request."""
73+
74+
def __init__(self):
75+
self._clean_id = 0
76+
self._borrow_id = 0
77+
78+
def qalloc(self, n: int, dim: int = 2) -> List['cirq.Qid']:
79+
self._clean_id += n
80+
return [CleanQubit(i, dim) for i in range(self._clean_id - n, self._clean_id)]
81+
82+
def qborrow(self, n: int, dim: int = 2) -> List['cirq.Qid']:
83+
self._borrow_id = self._borrow_id + n
84+
return [BorrowableQubit(i, dim) for i in range(self._borrow_id - n, self._borrow_id)]
85+
86+
def qfree(self, qubits: Iterable['cirq.Qid']) -> None:
87+
for q in qubits:
88+
good = isinstance(q, CleanQubit) and q.id < self._clean_id
89+
good |= isinstance(q, BorrowableQubit) and q.id < self._borrow_id
90+
if not good:
91+
raise ValueError(f"{q} was not allocated by {self}")
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
# Copyright 2023 The Cirq Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import cirq
16+
from cirq.ops import qubit_manager as cqi
17+
import pytest
18+
19+
20+
def test_clean_qubits():
21+
q = cqi.CleanQubit(1)
22+
assert q.id == 1
23+
assert q.dimension == 2
24+
assert str(q) == '_c(1)'
25+
assert repr(q) == 'cirq.ops.CleanQubit(1)'
26+
27+
q = cqi.CleanQubit(2, dim=3)
28+
assert q.id == 2
29+
assert q.dimension == 3
30+
assert str(q) == '_c(2) (d=3)'
31+
assert repr(q) == 'cirq.ops.CleanQubit(2, dim=3)'
32+
33+
assert cqi.CleanQubit(1) < cqi.CleanQubit(2)
34+
35+
36+
def test_borrow_qubits():
37+
q = cqi.BorrowableQubit(10)
38+
assert q.id == 10
39+
assert q.dimension == 2
40+
assert str(q) == '_b(10)'
41+
assert repr(q) == 'cirq.ops.BorrowableQubit(10)'
42+
43+
q = cqi.BorrowableQubit(20, dim=4)
44+
assert q.id == 20
45+
assert q.dimension == 4
46+
assert str(q) == '_b(20) (d=4)'
47+
assert repr(q) == 'cirq.ops.BorrowableQubit(20, dim=4)'
48+
49+
assert cqi.BorrowableQubit(1) < cqi.BorrowableQubit(2)
50+
51+
52+
@pytest.mark.parametrize('_', range(2))
53+
def test_simple_qubit_manager(_):
54+
qm = cirq.ops.SimpleQubitManager()
55+
assert qm.qalloc(1) == [cqi.CleanQubit(0)]
56+
assert qm.qalloc(2) == [cqi.CleanQubit(1), cqi.CleanQubit(2)]
57+
assert qm.qalloc(1, dim=3) == [cqi.CleanQubit(3, dim=3)]
58+
assert qm.qborrow(1) == [cqi.BorrowableQubit(0)]
59+
assert qm.qborrow(2) == [cqi.BorrowableQubit(1), cqi.BorrowableQubit(2)]
60+
assert qm.qborrow(1, dim=3) == [cqi.BorrowableQubit(3, dim=3)]
61+
qm.qfree([cqi.CleanQubit(i) for i in range(3)] + [cqi.CleanQubit(3, dim=3)])
62+
qm.qfree([cqi.BorrowableQubit(i) for i in range(3)] + [cqi.BorrowableQubit(3, dim=3)])
63+
with pytest.raises(ValueError, match="not allocated"):
64+
qm.qfree([cqi.CleanQubit(10)])
65+
with pytest.raises(ValueError, match="not allocated"):
66+
qm.qfree([cqi.BorrowableQubit(10)])

cirq-core/cirq/ops/raw_types.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -830,7 +830,14 @@ def _json_dict_(self) -> Dict[str, Any]:
830830
return protocols.obj_to_dict_helper(self, ['sub_operation', 'tags'])
831831

832832
def _decompose_(self) -> 'cirq.OP_TREE':
833-
return protocols.decompose_once(self.sub_operation, default=None, flatten=False)
833+
return self._decompose_with_context_()
834+
835+
def _decompose_with_context_(
836+
self, context: Optional['cirq.DecompositionContext'] = None
837+
) -> 'cirq.OP_TREE':
838+
return protocols.decompose_once(
839+
self.sub_operation, default=None, flatten=False, context=context
840+
)
834841

835842
def _pauli_expansion_(self) -> value.LinearDict[str]:
836843
return protocols.pauli_expansion(self.sub_operation)
@@ -979,7 +986,14 @@ def __pow__(self, power):
979986
return NotImplemented
980987

981988
def _decompose_(self, qubits):
982-
return protocols.inverse(protocols.decompose_once_with_qubits(self._original, qubits))
989+
return self._decompose_with_context_(qubits)
990+
991+
def _decompose_with_context_(
992+
self, qubits: Sequence['cirq.Qid'], context: Optional['cirq.DecompositionContext'] = None
993+
) -> 'cirq.OP_TREE':
994+
return protocols.inverse(
995+
protocols.decompose_once_with_qubits(self._original, qubits, context=context)
996+
)
983997

984998
def _has_unitary_(self):
985999
from cirq import protocols, devices

cirq-core/cirq/ops/raw_types_test.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,8 @@ def __repr__(self):
201201
assert i**-1 == t
202202
assert t**-1 == i
203203
assert cirq.decompose(i) == [cirq.X(a), cirq.S(b) ** -1, cirq.Z(a)]
204+
assert [*i._decompose_()] == [cirq.X(a), cirq.S(b) ** -1, cirq.Z(a)]
205+
assert [*i.gate._decompose_([a, b])] == [cirq.X(a), cirq.S(b) ** -1, cirq.Z(a)]
204206
cirq.testing.assert_allclose_up_to_global_phase(
205207
cirq.unitary(i), cirq.unitary(t).conj().T, atol=1e-8
206208
)
@@ -618,6 +620,7 @@ def test_tagged_operation_forwards_protocols():
618620
np.testing.assert_equal(cirq.unitary(tagged_h), cirq.unitary(h))
619621
assert cirq.has_unitary(tagged_h)
620622
assert cirq.decompose(tagged_h) == cirq.decompose(h)
623+
assert [*tagged_h._decompose_()] == cirq.decompose(h)
621624
assert cirq.pauli_expansion(tagged_h) == cirq.pauli_expansion(h)
622625
assert cirq.equal_up_to_global_phase(h, tagged_h)
623626
assert np.isclose(cirq.kraus(h), cirq.kraus(tagged_h)).all()

cirq-core/cirq/protocols/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
decompose,
5151
decompose_once,
5252
decompose_once_with_qubits,
53+
DecompositionContext,
5354
SupportsDecompose,
5455
SupportsDecomposeWithQubits,
5556
)

0 commit comments

Comments
 (0)