Skip to content

Move on_each to cirq.Gate #4236

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 26 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cirq-core/cirq/ops/common_channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ def asymmetric_depolarize(


@value.value_equality
class DepolarizingChannel(gate_features.SupportsOnEachGate, raw_types.Gate):
class DepolarizingChannel(raw_types.Gate):
"""A channel that depolarizes one or several qubits."""

def __init__(self, p: float, n_qubits: int = 1) -> None:
Expand Down
17 changes: 15 additions & 2 deletions cirq-core/cirq/ops/common_channels_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,11 +241,24 @@ def test_deprecated_on_each_for_depolarizing_channel_one_qubit():


def test_deprecated_on_each_for_depolarizing_channel_two_qubits():
q0, q1 = cirq.LineQubit.range(2)
q0, q1, q2, q3, q4, q5 = cirq.LineQubit.range(6)
op = cirq.DepolarizingChannel(p=0.1, n_qubits=2)

with pytest.raises(ValueError, match="one qubit"):
op.on_each([(q0, q1)])
op.on_each([(q0, q1), (q2, q3)])
op.on_each(zip([q0, q2, q4], [q1, q3, q5]))
with pytest.raises(ValueError, match='cannot be in varargs form'):
op.on_each(q0, q1)
with pytest.raises(ValueError, match='Inputs to multi-qubit gates must be Sequence'):
op.on_each((q0, q1))
with pytest.raises(ValueError, match='Inputs to multi-qubit gates must be Sequence'):
op.on_each([q0, q1])
with pytest.raises(ValueError, match='Qid'):
op.on_each([('bogus object 0', 'bogus object 1')])
with pytest.raises(ValueError, match='Qid'):
op.on_each(['01'])
with pytest.raises(ValueError, match='Qid'):
op.on_each([(False, None)])


def test_depolarizing_channel_apply_two_qubits():
Expand Down
43 changes: 13 additions & 30 deletions cirq-core/cirq/ops/gate_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@
"""

import abc
from typing import Union, Iterable, Any, List
from typing import Union, Iterable, Any, List, Sequence

from cirq import value
from cirq._compat import deprecated_class
from cirq.ops import raw_types


Expand All @@ -31,38 +33,19 @@ def qubit_index_to_equivalence_group_key(self, index: int) -> int:
return 0


class SupportsOnEachGate(raw_types.Gate, metaclass=abc.ABCMeta):
"""A gate that can be applied to exactly one qubit."""
class _SupportsOnEachGateMeta(value.ABCMetaImplementAnyOneOf):
def __instancecheck__(cls, instance):
return isinstance(instance, SingleQubitGate) or issubclass(
type(instance), SupportsOnEachGate
)

def on_each(self, *targets: Union[raw_types.Qid, Iterable[Any]]) -> List[raw_types.Operation]:
"""Returns a list of operations applying the gate to all targets.

@deprecated_class(deadline='v0.14', fix='None, this feature is in `Gate` now.')
class SupportsOnEachGate(raw_types.Gate, metaclass=_SupportsOnEachGateMeta):
pass

Args:
*targets: The qubits to apply this gate to.

Returns:
Operations applying this gate to the target qubits.

Raises:
ValueError if targets are not instances of Qid or List[Qid].
ValueError if the gate operates on two or more Qids.
"""
if self._num_qubits_() > 1:
raise ValueError('This gate only supports on_each when it is a one qubit gate.')
operations = [] # type: List[raw_types.Operation]
for target in targets:
if isinstance(target, raw_types.Qid):
operations.append(self.on(target))
elif isinstance(target, Iterable) and not isinstance(target, str):
operations.extend(self.on_each(*target))
else:
raise ValueError(
f'Gate was called with type different than Qid. Type: {type(target)}'
)
return operations


class SingleQubitGate(SupportsOnEachGate, metaclass=abc.ABCMeta):
class SingleQubitGate(raw_types.Gate, metaclass=abc.ABCMeta):
"""A gate that must be applied to exactly one qubit."""

def _num_qubits_(self) -> int:
Expand Down
102 changes: 34 additions & 68 deletions cirq-core/cirq/ops/gate_features_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from collections.abc import Iterator
from typing import Any

import pytest

import cirq
from cirq.testing import assert_deprecated


def test_single_qubit_gate_validate_args():
Expand All @@ -38,26 +36,6 @@ def matrix(self):
g.validate_args([q1, q2])


def test_single_qubit_gate_validates_on_each():
class Dummy(cirq.SingleQubitGate):
def matrix(self):
pass

g = Dummy()
assert g.num_qubits() == 1

test_qubits = [cirq.NamedQubit(str(i)) for i in range(3)]

_ = g.on_each(*test_qubits)
_ = g.on_each(test_qubits)

test_non_qubits = [str(i) for i in range(3)]
with pytest.raises(ValueError):
_ = g.on_each(*test_non_qubits)
with pytest.raises(ValueError):
_ = g.on_each(*test_non_qubits)


def test_single_qubit_validates_on():
class Dummy(cirq.SingleQubitGate):
def matrix(self):
Expand Down Expand Up @@ -137,41 +115,6 @@ def matrix(self):
g.validate_args([a, b, c, d])


def test_on_each():
class CustomGate(cirq.SingleQubitGate):
pass

a = cirq.NamedQubit('a')
b = cirq.NamedQubit('b')
c = CustomGate()

assert c.on_each() == []
assert c.on_each(a) == [c(a)]
assert c.on_each(a, b) == [c(a), c(b)]
assert c.on_each(b, a) == [c(b), c(a)]

assert c.on_each([]) == []
assert c.on_each([a]) == [c(a)]
assert c.on_each([a, b]) == [c(a), c(b)]
assert c.on_each([b, a]) == [c(b), c(a)]
assert c.on_each([a, [b, a], b]) == [c(a), c(b), c(a), c(b)]

with pytest.raises(ValueError):
c.on_each('abcd')
with pytest.raises(ValueError):
c.on_each(['abcd'])
with pytest.raises(ValueError):
c.on_each([a, 'abcd'])

def iterator(qubits):
for i in range(len(qubits)):
yield qubits[i]

qubit_iterator = iterator([a, b, a, b])
assert isinstance(qubit_iterator, Iterator)
assert c.on_each(qubit_iterator) == [c(a), c(b), c(a), c(b)]


def test_qasm_output_args_validate():
args = cirq.QasmArgs(version='2.0')
args.validate_version('2.0')
Expand Down Expand Up @@ -231,16 +174,39 @@ def __init__(self, num_qubits):
g.validate_args([a, b, c, d])


def test_on_each_iterable_qid():
class QidIter(cirq.Qid):
@property
def dimension(self) -> int:
return 2
def test_supports_on_each_inheritance_shim():
class NotOnEach(cirq.Gate):
def num_qubits(self):
return 1 # coverage: ignore

class OnEach(cirq.ops.gate_features.SupportsOnEachGate):
def num_qubits(self):
return 1 # coverage: ignore

class SingleQ(cirq.SingleQubitGate):
pass

class TwoQ(cirq.TwoQubitGate):
pass

not_on_each = NotOnEach()
single_q = SingleQ()
two_q = TwoQ()
with assert_deprecated(deadline="v0.14"):
on_each = OnEach()

assert not isinstance(not_on_each, cirq.ops.gate_features.SupportsOnEachGate)
assert isinstance(on_each, cirq.ops.gate_features.SupportsOnEachGate)
assert isinstance(single_q, cirq.ops.gate_features.SupportsOnEachGate)
assert not isinstance(two_q, cirq.ops.gate_features.SupportsOnEachGate)
assert isinstance(cirq.X, cirq.ops.gate_features.SupportsOnEachGate)
assert not isinstance(cirq.CX, cirq.ops.gate_features.SupportsOnEachGate)

def _comparison_key(self) -> Any:
return 1

def __iter__(self):
raise NotImplementedError()
def test_supports_on_each_deprecation():
class CustomGate(cirq.ops.gate_features.SupportsOnEachGate):
def num_qubits(self):
return 1 # coverage: ignore

assert cirq.H.on_each(QidIter())[0] == cirq.H.on(QidIter())
with assert_deprecated(deadline="v0.14"):
assert isinstance(CustomGate(), cirq.ops.gate_features.SupportsOnEachGate)
4 changes: 2 additions & 2 deletions cirq-core/cirq/ops/identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,14 @@

from cirq import protocols, value
from cirq._doc import document
from cirq.ops import gate_features, raw_types
from cirq.ops import raw_types

if TYPE_CHECKING:
import cirq


@value.value_equality
class IdentityGate(gate_features.SupportsOnEachGate, raw_types.Gate):
class IdentityGate(raw_types.Gate):
"""A Gate that perform no operation on qubits.

The unitary matrix of this gate is a diagonal matrix with all 1s on the
Expand Down
25 changes: 23 additions & 2 deletions cirq-core/cirq/ops/identity_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,29 @@ def test_identity_on_each_only_single_qubit():
cirq.IdentityGate(1, (3,)).on(q0_3),
cirq.IdentityGate(1, (3,)).on(q1_3),
]
with pytest.raises(ValueError, match='one qubit'):
cirq.IdentityGate(num_qubits=2).on_each(q0, q1)


def test_identity_on_each_two_qubits():
q0, q1, q2, q3 = cirq.LineQubit.range(4)
q0_3, q1_3 = q0.with_dimension(3), q1.with_dimension(3)
assert cirq.IdentityGate(2).on_each([(q0, q1)]) == [cirq.IdentityGate(2)(q0, q1)]
assert cirq.IdentityGate(2).on_each([(q0, q1), (q2, q3)]) == [
cirq.IdentityGate(2)(q0, q1),
cirq.IdentityGate(2)(q2, q3),
]
assert cirq.IdentityGate(2, (3, 3)).on_each([(q0_3, q1_3)]) == [
cirq.IdentityGate(2, (3, 3))(q0_3, q1_3),
]
with pytest.raises(ValueError, match='cannot be in varargs form'):
cirq.IdentityGate(2).on_each(q0, q1)
with pytest.raises(ValueError, match='Inputs to multi-qubit gates must be Sequence'):
cirq.IdentityGate(2).on_each((q0, q1))
with pytest.raises(ValueError, match='All values in sequence should be Qids'):
cirq.IdentityGate(2).on_each([[(q0, q1)]])
with pytest.raises(ValueError, match='Expected 2 qubits'):
cirq.IdentityGate(2).on_each([(q0,)])
with pytest.raises(ValueError, match='Expected 2 qubits'):
cirq.IdentityGate(2).on_each([(q0, q1, q2)])


@pytest.mark.parametrize('num_qubits', [1, 2, 4])
Expand Down
51 changes: 51 additions & 0 deletions cirq-core/cirq/ops/raw_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
TypeVar,
TYPE_CHECKING,
Union,
Iterable,
List,
)

import numpy as np
Expand Down Expand Up @@ -375,6 +377,55 @@ def _rmul_with_qubits(self, qubits: Tuple['cirq.Qid', ...], other):
def _json_dict_(self) -> Dict[str, Any]:
return protocols.obj_to_dict_helper(self, attribute_names=[])

def on_each(self, *targets: Union[Qid, Iterable[Any]]) -> List['Operation']:
"""Returns a list of operations applying the gate to all targets.

Args:
*targets: The qubits to apply this gate to. For single-qubit gates
this can be provided as varargs or a combination of nested
iterables. For multi-qubit gates this must be provided as an
`Iterable[Sequence[Qid]]`, where each sequence has `num_qubits`
qubits.

Returns:
Operations applying this gate to the target qubits.

Raises:
ValueError if targets are not instances of Qid or Iterable[Qid].
ValueError if the gate qubit number is incompatible.
"""
operations: List['Operation'] = []
if self._num_qubits_() > 1:
if len(targets) != 1 or not isinstance(targets[0], Iterable):
raise ValueError(f'The inputs for multi-qubit gates cannot be in varargs form.')
for target in targets[0]:
if not isinstance(target, Sequence):
if isinstance(target, Qid):
raise ValueError(
f'The inputs for multi-qubit gates cannot be in varargs form.'
)
else:
raise ValueError(
f'Inputs to multi-qubit gates must be Sequence[Qid].'
f' Type: {type(target)}'
)
if not all(isinstance(x, Qid) for x in target):
raise ValueError(f'All values in sequence should be Qids, but got {target}')
if len(target) != self._num_qubits_():
raise ValueError(f'Expected {self._num_qubits_()} qubits, got {target}')
operations.append(self.on(*target))
else:
for target in targets:
if isinstance(target, Qid):
operations.append(self.on(target))
elif isinstance(target, Iterable) and not isinstance(target, str):
operations.extend(self.on_each(*target))
else:
raise ValueError(
f'Gate was called with type different than Qid. Type: {type(target)}'
)
return operations


TSelf = TypeVar('TSelf', bound='Operation')

Expand Down
Loading