Skip to content

Commit 5ca4ab6

Browse files
authored
Move on_each to Gate, and deprecate SupportsOnEachGate (#4303)
Part of #4034. (See #4034 (comment)), moves the on_each functionality to cirq.Gate, and deprecates SupportsOnEachGate. Note the tests that are deleted here had already been duplicated in raw_types_test.py by https://github.com/quantumlib/Cirq/pull/4281/files.
1 parent 4852c46 commit 5ca4ab6

File tree

6 files changed

+107
-105
lines changed

6 files changed

+107
-105
lines changed

cirq-core/cirq/ops/common_channels.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,7 @@ def asymmetric_depolarize(
246246

247247

248248
@value.value_equality
249-
class DepolarizingChannel(gate_features.SupportsOnEachGate, raw_types.Gate):
249+
class DepolarizingChannel(raw_types.Gate):
250250
"""A channel that depolarizes one or several qubits."""
251251

252252
def __init__(self, p: float, n_qubits: int = 1) -> None:

cirq-core/cirq/ops/gate_features.py

Lines changed: 19 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,9 @@
1818
"""
1919

2020
import abc
21-
from typing import Union, Iterable, Any, List, Sequence
2221

22+
from cirq import value, ops
23+
from cirq._compat import deprecated_class
2324
from cirq.ops import raw_types
2425

2526

@@ -31,57 +32,23 @@ def qubit_index_to_equivalence_group_key(self, index: int) -> int:
3132
return 0
3233

3334

34-
class SupportsOnEachGate(raw_types.Gate, metaclass=abc.ABCMeta):
35-
"""A gate that can be applied to exactly one qubit."""
36-
37-
def on_each(self, *targets: Union[raw_types.Qid, Iterable[Any]]) -> List[raw_types.Operation]:
38-
"""Returns a list of operations applying the gate to all targets.
39-
Args:
40-
*targets: The qubits to apply this gate to. For single-qubit gates
41-
this can be provided as varargs or a combination of nested
42-
iterables. For multi-qubit gates this must be provided as an
43-
`Iterable[Sequence[Qid]]`, where each sequence has `num_qubits`
44-
qubits.
45-
Returns:
46-
Operations applying this gate to the target qubits.
47-
Raises:
48-
ValueError if targets are not instances of Qid or Iterable[Qid].
49-
ValueError if the gate qubit number is incompatible.
50-
"""
51-
operations: List[raw_types.Operation] = []
52-
if self._num_qubits_() > 1:
53-
iterator: Iterable = targets
54-
if len(targets) == 1:
55-
if not isinstance(targets[0], Iterable):
56-
raise TypeError(f'{targets[0]} object is not iterable.')
57-
t0 = list(targets[0])
58-
iterator = [t0] if t0 and isinstance(t0[0], raw_types.Qid) else t0
59-
for target in iterator:
60-
if not isinstance(target, Sequence):
61-
raise ValueError(
62-
f'Inputs to multi-qubit gates must be Sequence[Qid].'
63-
f' Type: {type(target)}'
64-
)
65-
if not all(isinstance(x, raw_types.Qid) for x in target):
66-
raise ValueError(f'All values in sequence should be Qids, but got {target}')
67-
if len(target) != self._num_qubits_():
68-
raise ValueError(f'Expected {self._num_qubits_()} qubits, got {target}')
69-
operations.append(self.on(*target))
70-
return operations
71-
72-
for target in targets:
73-
if isinstance(target, raw_types.Qid):
74-
operations.append(self.on(target))
75-
elif isinstance(target, Iterable) and not isinstance(target, str):
76-
operations.extend(self.on_each(*target))
77-
else:
78-
raise ValueError(
79-
f'Gate was called with type different than Qid. Type: {type(target)}'
80-
)
81-
return operations
82-
83-
84-
class SingleQubitGate(SupportsOnEachGate, metaclass=abc.ABCMeta):
35+
class _SupportsOnEachGateMeta(value.ABCMetaImplementAnyOneOf):
36+
def __instancecheck__(cls, instance):
37+
return isinstance(instance, (SingleQubitGate, ops.DepolarizingChannel)) or issubclass(
38+
type(instance), SupportsOnEachGate
39+
)
40+
41+
42+
@deprecated_class(
43+
deadline='v0.14',
44+
fix='Remove `SupportsOnEachGate` from the list of parent classes. '
45+
'`on_each` is now directly supported in the `Gate` base class.',
46+
)
47+
class SupportsOnEachGate(raw_types.Gate, metaclass=_SupportsOnEachGateMeta):
48+
pass
49+
50+
51+
class SingleQubitGate(raw_types.Gate, metaclass=abc.ABCMeta):
8552
"""A gate that must be applied to exactly one qubit."""
8653

8754
def _num_qubits_(self) -> int:

cirq-core/cirq/ops/gate_features_test.py

Lines changed: 35 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,10 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from collections.abc import Iterator
16-
from typing import Any
17-
1815
import pytest
1916

2017
import cirq
18+
from cirq.testing import assert_deprecated
2119

2220

2321
def test_single_qubit_gate_validate_args():
@@ -137,41 +135,6 @@ def matrix(self):
137135
g.validate_args([a, b, c, d])
138136

139137

140-
def test_on_each():
141-
class CustomGate(cirq.SingleQubitGate):
142-
pass
143-
144-
a = cirq.NamedQubit('a')
145-
b = cirq.NamedQubit('b')
146-
c = CustomGate()
147-
148-
assert c.on_each() == []
149-
assert c.on_each(a) == [c(a)]
150-
assert c.on_each(a, b) == [c(a), c(b)]
151-
assert c.on_each(b, a) == [c(b), c(a)]
152-
153-
assert c.on_each([]) == []
154-
assert c.on_each([a]) == [c(a)]
155-
assert c.on_each([a, b]) == [c(a), c(b)]
156-
assert c.on_each([b, a]) == [c(b), c(a)]
157-
assert c.on_each([a, [b, a], b]) == [c(a), c(b), c(a), c(b)]
158-
159-
with pytest.raises(ValueError):
160-
c.on_each('abcd')
161-
with pytest.raises(ValueError):
162-
c.on_each(['abcd'])
163-
with pytest.raises(ValueError):
164-
c.on_each([a, 'abcd'])
165-
166-
def iterator(qubits):
167-
for i in range(len(qubits)):
168-
yield qubits[i]
169-
170-
qubit_iterator = iterator([a, b, a, b])
171-
assert isinstance(qubit_iterator, Iterator)
172-
assert c.on_each(qubit_iterator) == [c(a), c(b), c(a), c(b)]
173-
174-
175138
def test_qasm_output_args_validate():
176139
args = cirq.QasmArgs(version='2.0')
177140
args.validate_version('2.0')
@@ -231,16 +194,40 @@ def __init__(self, num_qubits):
231194
g.validate_args([a, b, c, d])
232195

233196

234-
def test_on_each_iterable_qid():
235-
class QidIter(cirq.Qid):
236-
@property
237-
def dimension(self) -> int:
238-
return 2
197+
def test_supports_on_each_inheritance_shim():
198+
class NotOnEach(cirq.Gate):
199+
def num_qubits(self):
200+
return 1 # coverage: ignore
201+
202+
class OnEach(cirq.ops.gate_features.SupportsOnEachGate):
203+
def num_qubits(self):
204+
return 1 # coverage: ignore
205+
206+
class SingleQ(cirq.SingleQubitGate):
207+
pass
208+
209+
class TwoQ(cirq.TwoQubitGate):
210+
pass
211+
212+
not_on_each = NotOnEach()
213+
single_q = SingleQ()
214+
two_q = TwoQ()
215+
with assert_deprecated(deadline="v0.14"):
216+
on_each = OnEach()
217+
218+
assert not isinstance(not_on_each, cirq.ops.gate_features.SupportsOnEachGate)
219+
assert isinstance(on_each, cirq.ops.gate_features.SupportsOnEachGate)
220+
assert isinstance(single_q, cirq.ops.gate_features.SupportsOnEachGate)
221+
assert not isinstance(two_q, cirq.ops.gate_features.SupportsOnEachGate)
222+
assert isinstance(cirq.X, cirq.ops.gate_features.SupportsOnEachGate)
223+
assert not isinstance(cirq.CX, cirq.ops.gate_features.SupportsOnEachGate)
224+
assert isinstance(cirq.DepolarizingChannel(0.01), cirq.ops.gate_features.SupportsOnEachGate)
239225

240-
def _comparison_key(self) -> Any:
241-
return 1
242226

243-
def __iter__(self):
244-
raise NotImplementedError()
227+
def test_supports_on_each_deprecation():
228+
class CustomGate(cirq.ops.gate_features.SupportsOnEachGate):
229+
def num_qubits(self):
230+
return 1 # coverage: ignore
245231

246-
assert cirq.H.on_each(QidIter())[0] == cirq.H.on(QidIter())
232+
with assert_deprecated(deadline="v0.14"):
233+
assert isinstance(CustomGate(), cirq.ops.gate_features.SupportsOnEachGate)

cirq-core/cirq/ops/identity.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,14 @@
2020

2121
from cirq import protocols, value
2222
from cirq._doc import document
23-
from cirq.ops import gate_features, raw_types
23+
from cirq.ops import raw_types
2424

2525
if TYPE_CHECKING:
2626
import cirq
2727

2828

2929
@value.value_equality
30-
class IdentityGate(gate_features.SupportsOnEachGate, raw_types.Gate):
30+
class IdentityGate(raw_types.Gate):
3131
"""A Gate that perform no operation on qubits.
3232
3333
The unitary matrix of this gate is a diagonal matrix with all 1s on the

cirq-core/cirq/ops/raw_types.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
Collection,
2424
Dict,
2525
Hashable,
26+
Iterable,
27+
List,
2628
Optional,
2729
Sequence,
2830
Tuple,
@@ -211,6 +213,52 @@ def on(self, *qubits: Qid) -> 'Operation':
211213

212214
return gate_operation.GateOperation(self, list(qubits))
213215

216+
def on_each(self, *targets: Union[Qid, Iterable[Any]]) -> List['cirq.Operation']:
217+
"""Returns a list of operations applying the gate to all targets.
218+
Args:
219+
*targets: The qubits to apply this gate to. For single-qubit gates
220+
this can be provided as varargs or a combination of nested
221+
iterables. For multi-qubit gates this must be provided as an
222+
`Iterable[Sequence[Qid]]`, where each sequence has `num_qubits`
223+
qubits.
224+
Returns:
225+
Operations applying this gate to the target qubits.
226+
Raises:
227+
ValueError if targets are not instances of Qid or Iterable[Qid].
228+
ValueError if the gate qubit number is incompatible.
229+
"""
230+
operations: List['cirq.Operation'] = []
231+
if self._num_qubits_() > 1:
232+
iterator: Iterable = targets
233+
if len(targets) == 1:
234+
if not isinstance(targets[0], Iterable):
235+
raise TypeError(f'{targets[0]} object is not iterable.')
236+
t0 = list(targets[0])
237+
iterator = [t0] if t0 and isinstance(t0[0], Qid) else t0
238+
for target in iterator:
239+
if not isinstance(target, Sequence):
240+
raise ValueError(
241+
f'Inputs to multi-qubit gates must be Sequence[Qid].'
242+
f' Type: {type(target)}'
243+
)
244+
if not all(isinstance(x, Qid) for x in target):
245+
raise ValueError(f'All values in sequence should be Qids, but got {target}')
246+
if len(target) != self._num_qubits_():
247+
raise ValueError(f'Expected {self._num_qubits_()} qubits, got {target}')
248+
operations.append(self.on(*target))
249+
return operations
250+
251+
for target in targets:
252+
if isinstance(target, Qid):
253+
operations.append(self.on(target))
254+
elif isinstance(target, Iterable) and not isinstance(target, str):
255+
operations.extend(self.on_each(*target))
256+
else:
257+
raise ValueError(
258+
f'Gate was called with type different than Qid. Type: {type(target)}'
259+
)
260+
return operations
261+
214262
def wrap_in_linear_combination(
215263
self, coefficient: Union[complex, float, int] = 1
216264
) -> 'cirq.LinearCombinationOfGates':

cirq-core/cirq/ops/raw_types_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -801,7 +801,7 @@ class CustomGate(cirq.SingleQubitGate):
801801

802802

803803
def test_on_each_two_qubits():
804-
class CustomGate(cirq.ops.gate_features.SupportsOnEachGate, cirq.TwoQubitGate):
804+
class CustomGate(cirq.TwoQubitGate):
805805
pass
806806

807807
a = cirq.NamedQubit('a')
@@ -855,7 +855,7 @@ class CustomGate(cirq.ops.gate_features.SupportsOnEachGate, cirq.TwoQubitGate):
855855

856856

857857
def test_on_each_three_qubits():
858-
class CustomGate(cirq.ops.gate_features.SupportsOnEachGate, cirq.ThreeQubitGate):
858+
class CustomGate(cirq.ThreeQubitGate):
859859
pass
860860

861861
a = cirq.NamedQubit('a')

0 commit comments

Comments
 (0)