|
12 | 12 | # See the License for the specific language governing permissions and
|
13 | 13 | # limitations under the License.
|
14 | 14 |
|
15 |
| -from typing import AbstractSet |
| 15 | +from typing import AbstractSet, Iterator, Any |
16 | 16 |
|
17 | 17 | import pytest
|
18 | 18 | import numpy as np
|
@@ -739,3 +739,176 @@ def qubits(self):
|
739 | 739 | cirq.act_on(NoActOn()(q).with_tags("test"), args)
|
740 | 740 | with pytest.raises(TypeError, match="Failed to act"):
|
741 | 741 | cirq.act_on(MissingActOn().with_tags("test"), args)
|
| 742 | + |
| 743 | + |
| 744 | +def test_single_qubit_gate_validates_on_each(): |
| 745 | + class Dummy(cirq.SingleQubitGate): |
| 746 | + def matrix(self): |
| 747 | + pass |
| 748 | + |
| 749 | + g = Dummy() |
| 750 | + assert g.num_qubits() == 1 |
| 751 | + |
| 752 | + test_qubits = [cirq.NamedQubit(str(i)) for i in range(3)] |
| 753 | + |
| 754 | + _ = g.on_each(*test_qubits) |
| 755 | + _ = g.on_each(test_qubits) |
| 756 | + |
| 757 | + test_non_qubits = [str(i) for i in range(3)] |
| 758 | + with pytest.raises(ValueError): |
| 759 | + _ = g.on_each(*test_non_qubits) |
| 760 | + with pytest.raises(ValueError): |
| 761 | + _ = g.on_each(*test_non_qubits) |
| 762 | + |
| 763 | + |
| 764 | +def test_on_each(): |
| 765 | + class CustomGate(cirq.SingleQubitGate): |
| 766 | + pass |
| 767 | + |
| 768 | + a = cirq.NamedQubit('a') |
| 769 | + b = cirq.NamedQubit('b') |
| 770 | + c = CustomGate() |
| 771 | + |
| 772 | + assert c.on_each() == [] |
| 773 | + assert c.on_each(a) == [c(a)] |
| 774 | + assert c.on_each(a, b) == [c(a), c(b)] |
| 775 | + assert c.on_each(b, a) == [c(b), c(a)] |
| 776 | + |
| 777 | + assert c.on_each([]) == [] |
| 778 | + assert c.on_each([a]) == [c(a)] |
| 779 | + assert c.on_each([a, b]) == [c(a), c(b)] |
| 780 | + assert c.on_each([b, a]) == [c(b), c(a)] |
| 781 | + assert c.on_each([a, [b, a], b]) == [c(a), c(b), c(a), c(b)] |
| 782 | + |
| 783 | + with pytest.raises(ValueError): |
| 784 | + c.on_each('abcd') |
| 785 | + with pytest.raises(ValueError): |
| 786 | + c.on_each(['abcd']) |
| 787 | + with pytest.raises(ValueError): |
| 788 | + c.on_each([a, 'abcd']) |
| 789 | + |
| 790 | + qubit_iterator = (q for q in [a, b, a, b]) |
| 791 | + assert isinstance(qubit_iterator, Iterator) |
| 792 | + assert c.on_each(qubit_iterator) == [c(a), c(b), c(a), c(b)] |
| 793 | + |
| 794 | + |
| 795 | +def test_on_each_two_qubits(): |
| 796 | + class CustomGate(cirq.ops.gate_features.SupportsOnEachGate, cirq.TwoQubitGate): |
| 797 | + pass |
| 798 | + |
| 799 | + a = cirq.NamedQubit('a') |
| 800 | + b = cirq.NamedQubit('b') |
| 801 | + g = CustomGate() |
| 802 | + |
| 803 | + assert g.on_each([]) == [] |
| 804 | + assert g.on_each([(a, b)]) == [g(a, b)] |
| 805 | + assert g.on_each([[a, b]]) == [g(a, b)] |
| 806 | + assert g.on_each([(b, a)]) == [g(b, a)] |
| 807 | + assert g.on_each([(a, b), (b, a)]) == [g(a, b), g(b, a)] |
| 808 | + assert g.on_each(zip([a, b], [b, a])) == [g(a, b), g(b, a)] |
| 809 | + assert g.on_each() == [] |
| 810 | + assert g.on_each((b, a)) == [g(b, a)] |
| 811 | + assert g.on_each((a, b), (a, b)) == [g(a, b), g(a, b)] |
| 812 | + assert g.on_each(*zip([a, b], [b, a])) == [g(a, b), g(b, a)] |
| 813 | + with pytest.raises(TypeError, match='object is not iterable'): |
| 814 | + g.on_each(a) |
| 815 | + with pytest.raises(ValueError, match='Inputs to multi-qubit gates must be Sequence'): |
| 816 | + g.on_each(a, b) |
| 817 | + with pytest.raises(ValueError, match='Inputs to multi-qubit gates must be Sequence'): |
| 818 | + g.on_each([12]) |
| 819 | + with pytest.raises(ValueError, match='Inputs to multi-qubit gates must be Sequence'): |
| 820 | + g.on_each([(a, b), 12]) |
| 821 | + with pytest.raises(ValueError, match='All values in sequence should be Qids'): |
| 822 | + g.on_each([(a, b), [(a, b)]]) |
| 823 | + with pytest.raises(ValueError, match='Expected 2 qubits'): |
| 824 | + g.on_each([()]) |
| 825 | + with pytest.raises(ValueError, match='Expected 2 qubits'): |
| 826 | + g.on_each([(a,)]) |
| 827 | + with pytest.raises(ValueError, match='Expected 2 qubits'): |
| 828 | + g.on_each([(a, b, a)]) |
| 829 | + with pytest.raises(ValueError, match='Expected 2 qubits'): |
| 830 | + g.on_each(zip([a, a])) |
| 831 | + with pytest.raises(ValueError, match='Expected 2 qubits'): |
| 832 | + g.on_each(zip([a, a], [b, b], [a, a])) |
| 833 | + with pytest.raises(ValueError, match='All values in sequence should be Qids'): |
| 834 | + g.on_each('ab') |
| 835 | + with pytest.raises(ValueError, match='All values in sequence should be Qids'): |
| 836 | + g.on_each(('ab',)) |
| 837 | + with pytest.raises(ValueError, match='All values in sequence should be Qids'): |
| 838 | + g.on_each([('ab',)]) |
| 839 | + with pytest.raises(ValueError, match='All values in sequence should be Qids'): |
| 840 | + g.on_each([(a, 'ab')]) |
| 841 | + with pytest.raises(ValueError, match='All values in sequence should be Qids'): |
| 842 | + g.on_each([(a, 'b')]) |
| 843 | + |
| 844 | + qubit_iterator = (qs for qs in [[a, b], [a, b]]) |
| 845 | + assert isinstance(qubit_iterator, Iterator) |
| 846 | + assert g.on_each(qubit_iterator) == [g(a, b), g(a, b)] |
| 847 | + |
| 848 | + |
| 849 | +def test_on_each_three_qubits(): |
| 850 | + class CustomGate(cirq.ops.gate_features.SupportsOnEachGate, cirq.ThreeQubitGate): |
| 851 | + pass |
| 852 | + |
| 853 | + a = cirq.NamedQubit('a') |
| 854 | + b = cirq.NamedQubit('b') |
| 855 | + c = cirq.NamedQubit('c') |
| 856 | + g = CustomGate() |
| 857 | + |
| 858 | + assert g.on_each([]) == [] |
| 859 | + assert g.on_each([(a, b, c)]) == [g(a, b, c)] |
| 860 | + assert g.on_each([[a, b, c]]) == [g(a, b, c)] |
| 861 | + assert g.on_each([(c, b, a)]) == [g(c, b, a)] |
| 862 | + assert g.on_each([(a, b, c), (c, b, a)]) == [g(a, b, c), g(c, b, a)] |
| 863 | + assert g.on_each(zip([a, c], [b, b], [c, a])) == [g(a, b, c), g(c, b, a)] |
| 864 | + assert g.on_each() == [] |
| 865 | + assert g.on_each((c, b, a)) == [g(c, b, a)] |
| 866 | + assert g.on_each((a, b, c), (c, b, a)) == [g(a, b, c), g(c, b, a)] |
| 867 | + assert g.on_each(*zip([a, c], [b, b], [c, a])) == [g(a, b, c), g(c, b, a)] |
| 868 | + with pytest.raises(TypeError, match='object is not iterable'): |
| 869 | + g.on_each(a) |
| 870 | + with pytest.raises(ValueError, match='Inputs to multi-qubit gates must be Sequence'): |
| 871 | + g.on_each(a, b, c) |
| 872 | + with pytest.raises(ValueError, match='Inputs to multi-qubit gates must be Sequence'): |
| 873 | + g.on_each([12]) |
| 874 | + with pytest.raises(ValueError, match='Inputs to multi-qubit gates must be Sequence'): |
| 875 | + g.on_each([(a, b, c), 12]) |
| 876 | + with pytest.raises(ValueError, match='All values in sequence should be Qids'): |
| 877 | + g.on_each([(a, b, c), [(a, b, c)]]) |
| 878 | + with pytest.raises(ValueError, match='Expected 3 qubits'): |
| 879 | + g.on_each([(a,)]) |
| 880 | + with pytest.raises(ValueError, match='Expected 3 qubits'): |
| 881 | + g.on_each([(a, b)]) |
| 882 | + with pytest.raises(ValueError, match='Expected 3 qubits'): |
| 883 | + g.on_each([(a, b, c, a)]) |
| 884 | + with pytest.raises(ValueError, match='Expected 3 qubits'): |
| 885 | + g.on_each(zip([a, a], [b, b])) |
| 886 | + with pytest.raises(ValueError, match='All values in sequence should be Qids'): |
| 887 | + g.on_each('abc') |
| 888 | + with pytest.raises(ValueError, match='All values in sequence should be Qids'): |
| 889 | + g.on_each(('abc',)) |
| 890 | + with pytest.raises(ValueError, match='All values in sequence should be Qids'): |
| 891 | + g.on_each([('abc',)]) |
| 892 | + with pytest.raises(ValueError, match='All values in sequence should be Qids'): |
| 893 | + g.on_each([(a, 'abc')]) |
| 894 | + with pytest.raises(ValueError, match='All values in sequence should be Qids'): |
| 895 | + g.on_each([(a, 'bc')]) |
| 896 | + |
| 897 | + qubit_iterator = (qs for qs in [[a, b, c], [a, b, c]]) |
| 898 | + assert isinstance(qubit_iterator, Iterator) |
| 899 | + assert g.on_each(qubit_iterator) == [g(a, b, c), g(a, b, c)] |
| 900 | + |
| 901 | + |
| 902 | +def test_on_each_iterable_qid(): |
| 903 | + class QidIter(cirq.Qid): |
| 904 | + @property |
| 905 | + def dimension(self) -> int: |
| 906 | + return 2 |
| 907 | + |
| 908 | + def _comparison_key(self) -> Any: |
| 909 | + return 1 |
| 910 | + |
| 911 | + def __iter__(self): |
| 912 | + raise NotImplementedError() |
| 913 | + |
| 914 | + assert cirq.H.on_each(QidIter())[0] == cirq.H.on(QidIter()) |
0 commit comments