Skip to content

Commit 4ec906d

Browse files
authored
Allow on_each for multi-qubit gates (#4281)
Adds on_each support to the SupportsOnEachGate mixin for multi-qubit gates. The handling here is not as flexible as for single-qubit gates, which allows any tree of gates and applies them depth-first. This allows the following two options for multi-qubit gates: ``` A: varargs form gate.on_each([q1, q2], [q3, q4]) B: explicit form gate.on_each([[q1, q2], [q3, q4]]) ``` Discussion here, #4034 (comment). Part of #4236.
1 parent e2b4477 commit 4ec906d

File tree

4 files changed

+237
-14
lines changed

4 files changed

+237
-14
lines changed

cirq-core/cirq/ops/common_channels_test.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -241,11 +241,22 @@ def test_deprecated_on_each_for_depolarizing_channel_one_qubit():
241241

242242

243243
def test_deprecated_on_each_for_depolarizing_channel_two_qubits():
244-
q0, q1 = cirq.LineQubit.range(2)
244+
q0, q1, q2, q3, q4, q5 = cirq.LineQubit.range(6)
245245
op = cirq.DepolarizingChannel(p=0.1, n_qubits=2)
246246

247-
with pytest.raises(ValueError, match="one qubit"):
247+
op.on_each([(q0, q1)])
248+
op.on_each([(q0, q1), (q2, q3)])
249+
op.on_each(zip([q0, q2, q4], [q1, q3, q5]))
250+
op.on_each((q0, q1))
251+
op.on_each([q0, q1])
252+
with pytest.raises(ValueError, match='Inputs to multi-qubit gates must be Sequence'):
248253
op.on_each(q0, q1)
254+
with pytest.raises(ValueError, match='All values in sequence should be Qids'):
255+
op.on_each([('bogus object 0', 'bogus object 1')])
256+
with pytest.raises(ValueError, match='All values in sequence should be Qids'):
257+
op.on_each(['01'])
258+
with pytest.raises(ValueError, match='All values in sequence should be Qids'):
259+
op.on_each([(False, None)])
249260

250261

251262
def test_depolarizing_channel_apply_two_qubits():

cirq-core/cirq/ops/gate_features.py

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
"""
1919

2020
import abc
21-
from typing import Union, Iterable, Any, List
21+
from typing import Union, Iterable, Any, List, Sequence
2222

2323
from cirq.ops import raw_types
2424

@@ -36,20 +36,39 @@ class SupportsOnEachGate(raw_types.Gate, metaclass=abc.ABCMeta):
3636

3737
def on_each(self, *targets: Union[raw_types.Qid, Iterable[Any]]) -> List[raw_types.Operation]:
3838
"""Returns a list of operations applying the gate to all targets.
39-
4039
Args:
41-
*targets: The qubits to apply this gate to.
42-
40+
*targets: The qubits to apply this gate to. For single-qubit gates
41+
this can be provided as varargs or a combination of nested
42+
iterables. For multi-qubit gates this must be provided as an
43+
`Iterable[Sequence[Qid]]`, where each sequence has `num_qubits`
44+
qubits.
4345
Returns:
4446
Operations applying this gate to the target qubits.
45-
4647
Raises:
47-
ValueError if targets are not instances of Qid or List[Qid].
48-
ValueError if the gate operates on two or more Qids.
48+
ValueError if targets are not instances of Qid or Iterable[Qid].
49+
ValueError if the gate qubit number is incompatible.
4950
"""
51+
operations: List[raw_types.Operation] = []
5052
if self._num_qubits_() > 1:
51-
raise ValueError('This gate only supports on_each when it is a one qubit gate.')
52-
operations = [] # type: List[raw_types.Operation]
53+
iterator: Iterable = targets
54+
if len(targets) == 1:
55+
if not isinstance(targets[0], Iterable):
56+
raise TypeError(f'{targets[0]} object is not iterable.')
57+
t0 = list(targets[0])
58+
iterator = [t0] if t0 and isinstance(t0[0], raw_types.Qid) else t0
59+
for target in iterator:
60+
if not isinstance(target, Sequence):
61+
raise ValueError(
62+
f'Inputs to multi-qubit gates must be Sequence[Qid].'
63+
f' Type: {type(target)}'
64+
)
65+
if not all(isinstance(x, raw_types.Qid) for x in target):
66+
raise ValueError(f'All values in sequence should be Qids, but got {target}')
67+
if len(target) != self._num_qubits_():
68+
raise ValueError(f'Expected {self._num_qubits_()} qubits, got {target}')
69+
operations.append(self.on(*target))
70+
return operations
71+
5372
for target in targets:
5473
if isinstance(target, raw_types.Qid):
5574
operations.append(self.on(target))

cirq-core/cirq/ops/identity_test.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,28 @@ def test_identity_on_each_only_single_qubit():
6666
cirq.IdentityGate(1, (3,)).on(q0_3),
6767
cirq.IdentityGate(1, (3,)).on(q1_3),
6868
]
69-
with pytest.raises(ValueError, match='one qubit'):
70-
cirq.IdentityGate(num_qubits=2).on_each(q0, q1)
69+
70+
71+
def test_identity_on_each_two_qubits():
72+
q0, q1, q2, q3 = cirq.LineQubit.range(4)
73+
q0_3, q1_3 = q0.with_dimension(3), q1.with_dimension(3)
74+
assert cirq.IdentityGate(2).on_each([(q0, q1)]) == [cirq.IdentityGate(2)(q0, q1)]
75+
assert cirq.IdentityGate(2).on_each([(q0, q1), (q2, q3)]) == [
76+
cirq.IdentityGate(2)(q0, q1),
77+
cirq.IdentityGate(2)(q2, q3),
78+
]
79+
assert cirq.IdentityGate(2, (3, 3)).on_each([(q0_3, q1_3)]) == [
80+
cirq.IdentityGate(2, (3, 3))(q0_3, q1_3),
81+
]
82+
assert cirq.IdentityGate(2).on_each((q0, q1)) == [cirq.IdentityGate(2)(q0, q1)]
83+
with pytest.raises(ValueError, match='Inputs to multi-qubit gates must be Sequence'):
84+
cirq.IdentityGate(2).on_each(q0, q1)
85+
with pytest.raises(ValueError, match='All values in sequence should be Qids'):
86+
cirq.IdentityGate(2).on_each([[(q0, q1)]])
87+
with pytest.raises(ValueError, match='Expected 2 qubits'):
88+
cirq.IdentityGate(2).on_each([(q0,)])
89+
with pytest.raises(ValueError, match='Expected 2 qubits'):
90+
cirq.IdentityGate(2).on_each([(q0, q1, q2)])
7191

7292

7393
@pytest.mark.parametrize('num_qubits', [1, 2, 4])

cirq-core/cirq/ops/raw_types_test.py

Lines changed: 174 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from typing import AbstractSet
15+
from typing import AbstractSet, Iterator, Any
1616

1717
import pytest
1818
import numpy as np
@@ -739,3 +739,176 @@ def qubits(self):
739739
cirq.act_on(NoActOn()(q).with_tags("test"), args)
740740
with pytest.raises(TypeError, match="Failed to act"):
741741
cirq.act_on(MissingActOn().with_tags("test"), args)
742+
743+
744+
def test_single_qubit_gate_validates_on_each():
745+
class Dummy(cirq.SingleQubitGate):
746+
def matrix(self):
747+
pass
748+
749+
g = Dummy()
750+
assert g.num_qubits() == 1
751+
752+
test_qubits = [cirq.NamedQubit(str(i)) for i in range(3)]
753+
754+
_ = g.on_each(*test_qubits)
755+
_ = g.on_each(test_qubits)
756+
757+
test_non_qubits = [str(i) for i in range(3)]
758+
with pytest.raises(ValueError):
759+
_ = g.on_each(*test_non_qubits)
760+
with pytest.raises(ValueError):
761+
_ = g.on_each(*test_non_qubits)
762+
763+
764+
def test_on_each():
765+
class CustomGate(cirq.SingleQubitGate):
766+
pass
767+
768+
a = cirq.NamedQubit('a')
769+
b = cirq.NamedQubit('b')
770+
c = CustomGate()
771+
772+
assert c.on_each() == []
773+
assert c.on_each(a) == [c(a)]
774+
assert c.on_each(a, b) == [c(a), c(b)]
775+
assert c.on_each(b, a) == [c(b), c(a)]
776+
777+
assert c.on_each([]) == []
778+
assert c.on_each([a]) == [c(a)]
779+
assert c.on_each([a, b]) == [c(a), c(b)]
780+
assert c.on_each([b, a]) == [c(b), c(a)]
781+
assert c.on_each([a, [b, a], b]) == [c(a), c(b), c(a), c(b)]
782+
783+
with pytest.raises(ValueError):
784+
c.on_each('abcd')
785+
with pytest.raises(ValueError):
786+
c.on_each(['abcd'])
787+
with pytest.raises(ValueError):
788+
c.on_each([a, 'abcd'])
789+
790+
qubit_iterator = (q for q in [a, b, a, b])
791+
assert isinstance(qubit_iterator, Iterator)
792+
assert c.on_each(qubit_iterator) == [c(a), c(b), c(a), c(b)]
793+
794+
795+
def test_on_each_two_qubits():
796+
class CustomGate(cirq.ops.gate_features.SupportsOnEachGate, cirq.TwoQubitGate):
797+
pass
798+
799+
a = cirq.NamedQubit('a')
800+
b = cirq.NamedQubit('b')
801+
g = CustomGate()
802+
803+
assert g.on_each([]) == []
804+
assert g.on_each([(a, b)]) == [g(a, b)]
805+
assert g.on_each([[a, b]]) == [g(a, b)]
806+
assert g.on_each([(b, a)]) == [g(b, a)]
807+
assert g.on_each([(a, b), (b, a)]) == [g(a, b), g(b, a)]
808+
assert g.on_each(zip([a, b], [b, a])) == [g(a, b), g(b, a)]
809+
assert g.on_each() == []
810+
assert g.on_each((b, a)) == [g(b, a)]
811+
assert g.on_each((a, b), (a, b)) == [g(a, b), g(a, b)]
812+
assert g.on_each(*zip([a, b], [b, a])) == [g(a, b), g(b, a)]
813+
with pytest.raises(TypeError, match='object is not iterable'):
814+
g.on_each(a)
815+
with pytest.raises(ValueError, match='Inputs to multi-qubit gates must be Sequence'):
816+
g.on_each(a, b)
817+
with pytest.raises(ValueError, match='Inputs to multi-qubit gates must be Sequence'):
818+
g.on_each([12])
819+
with pytest.raises(ValueError, match='Inputs to multi-qubit gates must be Sequence'):
820+
g.on_each([(a, b), 12])
821+
with pytest.raises(ValueError, match='All values in sequence should be Qids'):
822+
g.on_each([(a, b), [(a, b)]])
823+
with pytest.raises(ValueError, match='Expected 2 qubits'):
824+
g.on_each([()])
825+
with pytest.raises(ValueError, match='Expected 2 qubits'):
826+
g.on_each([(a,)])
827+
with pytest.raises(ValueError, match='Expected 2 qubits'):
828+
g.on_each([(a, b, a)])
829+
with pytest.raises(ValueError, match='Expected 2 qubits'):
830+
g.on_each(zip([a, a]))
831+
with pytest.raises(ValueError, match='Expected 2 qubits'):
832+
g.on_each(zip([a, a], [b, b], [a, a]))
833+
with pytest.raises(ValueError, match='All values in sequence should be Qids'):
834+
g.on_each('ab')
835+
with pytest.raises(ValueError, match='All values in sequence should be Qids'):
836+
g.on_each(('ab',))
837+
with pytest.raises(ValueError, match='All values in sequence should be Qids'):
838+
g.on_each([('ab',)])
839+
with pytest.raises(ValueError, match='All values in sequence should be Qids'):
840+
g.on_each([(a, 'ab')])
841+
with pytest.raises(ValueError, match='All values in sequence should be Qids'):
842+
g.on_each([(a, 'b')])
843+
844+
qubit_iterator = (qs for qs in [[a, b], [a, b]])
845+
assert isinstance(qubit_iterator, Iterator)
846+
assert g.on_each(qubit_iterator) == [g(a, b), g(a, b)]
847+
848+
849+
def test_on_each_three_qubits():
850+
class CustomGate(cirq.ops.gate_features.SupportsOnEachGate, cirq.ThreeQubitGate):
851+
pass
852+
853+
a = cirq.NamedQubit('a')
854+
b = cirq.NamedQubit('b')
855+
c = cirq.NamedQubit('c')
856+
g = CustomGate()
857+
858+
assert g.on_each([]) == []
859+
assert g.on_each([(a, b, c)]) == [g(a, b, c)]
860+
assert g.on_each([[a, b, c]]) == [g(a, b, c)]
861+
assert g.on_each([(c, b, a)]) == [g(c, b, a)]
862+
assert g.on_each([(a, b, c), (c, b, a)]) == [g(a, b, c), g(c, b, a)]
863+
assert g.on_each(zip([a, c], [b, b], [c, a])) == [g(a, b, c), g(c, b, a)]
864+
assert g.on_each() == []
865+
assert g.on_each((c, b, a)) == [g(c, b, a)]
866+
assert g.on_each((a, b, c), (c, b, a)) == [g(a, b, c), g(c, b, a)]
867+
assert g.on_each(*zip([a, c], [b, b], [c, a])) == [g(a, b, c), g(c, b, a)]
868+
with pytest.raises(TypeError, match='object is not iterable'):
869+
g.on_each(a)
870+
with pytest.raises(ValueError, match='Inputs to multi-qubit gates must be Sequence'):
871+
g.on_each(a, b, c)
872+
with pytest.raises(ValueError, match='Inputs to multi-qubit gates must be Sequence'):
873+
g.on_each([12])
874+
with pytest.raises(ValueError, match='Inputs to multi-qubit gates must be Sequence'):
875+
g.on_each([(a, b, c), 12])
876+
with pytest.raises(ValueError, match='All values in sequence should be Qids'):
877+
g.on_each([(a, b, c), [(a, b, c)]])
878+
with pytest.raises(ValueError, match='Expected 3 qubits'):
879+
g.on_each([(a,)])
880+
with pytest.raises(ValueError, match='Expected 3 qubits'):
881+
g.on_each([(a, b)])
882+
with pytest.raises(ValueError, match='Expected 3 qubits'):
883+
g.on_each([(a, b, c, a)])
884+
with pytest.raises(ValueError, match='Expected 3 qubits'):
885+
g.on_each(zip([a, a], [b, b]))
886+
with pytest.raises(ValueError, match='All values in sequence should be Qids'):
887+
g.on_each('abc')
888+
with pytest.raises(ValueError, match='All values in sequence should be Qids'):
889+
g.on_each(('abc',))
890+
with pytest.raises(ValueError, match='All values in sequence should be Qids'):
891+
g.on_each([('abc',)])
892+
with pytest.raises(ValueError, match='All values in sequence should be Qids'):
893+
g.on_each([(a, 'abc')])
894+
with pytest.raises(ValueError, match='All values in sequence should be Qids'):
895+
g.on_each([(a, 'bc')])
896+
897+
qubit_iterator = (qs for qs in [[a, b, c], [a, b, c]])
898+
assert isinstance(qubit_iterator, Iterator)
899+
assert g.on_each(qubit_iterator) == [g(a, b, c), g(a, b, c)]
900+
901+
902+
def test_on_each_iterable_qid():
903+
class QidIter(cirq.Qid):
904+
@property
905+
def dimension(self) -> int:
906+
return 2
907+
908+
def _comparison_key(self) -> Any:
909+
return 1
910+
911+
def __iter__(self):
912+
raise NotImplementedError()
913+
914+
assert cirq.H.on_each(QidIter())[0] == cirq.H.on(QidIter())

0 commit comments

Comments
 (0)