Skip to content

Commit ea06bfd

Browse files
Refactored and created an abstraction for control values (#5362)
Created a concrete implementation that replicates current behaviour. refactored the code for control_values and abstracted its calls this is a first step towards implementing more ways to represent control values in order to solve #4512
1 parent 7c9c779 commit ea06bfd

15 files changed

+397
-81
lines changed

Diff for: cirq-core/cirq/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,7 @@
278278
PhasedXZGate,
279279
PhaseFlipChannel,
280280
StatePreparationChannel,
281+
ProductOfSums,
281282
ProjectorString,
282283
ProjectorSum,
283284
RandomGateChannel,

Diff for: cirq-core/cirq/json_resolver_cache.py

+1
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,7 @@ def _symmetricalqidpair(qids):
158158
'PhasedXPowGate': cirq.PhasedXPowGate,
159159
'PhasedXZGate': cirq.PhasedXZGate,
160160
'ProductState': cirq.ProductState,
161+
'ProductOfSums': cirq.ProductOfSums,
161162
'ProjectorString': cirq.ProjectorString,
162163
'ProjectorSum': cirq.ProjectorSum,
163164
'QasmUGate': cirq.circuits.qasm_output.QasmUGate,

Diff for: cirq-core/cirq/ops/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -209,3 +209,5 @@
209209
from cirq.ops.wait_gate import wait, WaitGate
210210

211211
from cirq.ops.state_preparation_channel import StatePreparationChannel
212+
213+
from cirq.ops.control_values import AbstractControlValues, ProductOfSums

Diff for: cirq-core/cirq/ops/common_gates.py

+13-5
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
from cirq import protocols, value
4545
from cirq._compat import proper_repr
4646
from cirq._doc import document
47-
from cirq.ops import controlled_gate, eigen_gate, gate_features, raw_types
47+
from cirq.ops import controlled_gate, eigen_gate, gate_features, raw_types, control_values as cv
4848

4949
from cirq.type_workarounds import NotImplementedType
5050

@@ -159,7 +159,9 @@ def _trace_distance_bound_(self) -> Optional[float]:
159159
def controlled(
160160
self,
161161
num_controls: int = None,
162-
control_values: Optional[Sequence[Union[int, Collection[int]]]] = None,
162+
control_values: Optional[
163+
Union[cv.AbstractControlValues, Sequence[Union[int, Collection[int]]]]
164+
] = None,
163165
control_qid_shape: Optional[Tuple[int, ...]] = None,
164166
) -> raw_types.Gate:
165167
"""Returns a controlled `XPowGate`, using a `CXPowGate` where possible.
@@ -566,7 +568,9 @@ def with_canonical_global_phase(self) -> 'ZPowGate':
566568
def controlled(
567569
self,
568570
num_controls: int = None,
569-
control_values: Optional[Sequence[Union[int, Collection[int]]]] = None,
571+
control_values: Optional[
572+
Union[cv.AbstractControlValues, Sequence[Union[int, Collection[int]]]]
573+
] = None,
570574
control_qid_shape: Optional[Tuple[int, ...]] = None,
571575
) -> raw_types.Gate:
572576
"""Returns a controlled `ZPowGate`, using a `CZPowGate` where possible.
@@ -998,7 +1002,9 @@ def _phase_by_(self, phase_turns, qubit_index):
9981002
def controlled(
9991003
self,
10001004
num_controls: int = None,
1001-
control_values: Optional[Sequence[Union[int, Collection[int]]]] = None,
1005+
control_values: Optional[
1006+
Union[cv.AbstractControlValues, Sequence[Union[int, Collection[int]]]]
1007+
] = None,
10021008
control_qid_shape: Optional[Tuple[int, ...]] = None,
10031009
) -> raw_types.Gate:
10041010
"""Returns a controlled `CZPowGate`, using a `CCZPowGate` where possible.
@@ -1187,7 +1193,9 @@ def _pauli_expansion_(self) -> value.LinearDict[str]:
11871193
def controlled(
11881194
self,
11891195
num_controls: int = None,
1190-
control_values: Optional[Sequence[Union[int, Collection[int]]]] = None,
1196+
control_values: Optional[
1197+
Union[cv.AbstractControlValues, Sequence[Union[int, Collection[int]]]]
1198+
] = None,
11911199
control_qid_shape: Optional[Tuple[int, ...]] = None,
11921200
) -> raw_types.Gate:
11931201
"""Returns a controlled `CXPowGate`, using a `CCXPowGate` where possible.

Diff for: cirq-core/cirq/ops/control_values.py

+182
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
1+
# Copyright 2022 The Cirq Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import abc
15+
from typing import Union, Tuple, List, TYPE_CHECKING, Any, Dict, Generator, cast, Iterator
16+
from dataclasses import dataclass
17+
18+
import itertools
19+
20+
if TYPE_CHECKING:
21+
import cirq
22+
23+
24+
class AbstractControlValues(abc.ABC):
25+
"""Abstract base class defining the API for control values.
26+
27+
`AbstractControlValues` is an abstract class that defines the API for control values
28+
and implements functions common to all implementations (e.g. comparison).
29+
30+
`cirq.ControlledGate` and `cirq.ControlledOperation` are useful to augment
31+
existing gates and operations to have one or more control qubits. For every
32+
control qubit, the set of integer values for which the control should be enabled
33+
is represented by one of the implementations of `cirq.AbstractControlValues`.
34+
35+
Implementations of `cirq.AbstractControlValues` can use different internal
36+
representations to store control values, but they must satisfy the public API
37+
defined here and be immutable.
38+
"""
39+
40+
@abc.abstractmethod
41+
def __and__(self, other: 'AbstractControlValues') -> 'AbstractControlValues':
42+
"""Sets self to be the cartesian product of all combinations in self x other.
43+
44+
Args:
45+
other: An object that implements AbstractControlValues.
46+
47+
Returns:
48+
An object that represents the cartesian product of the two inputs.
49+
"""
50+
51+
@abc.abstractmethod
52+
def _expand(self) -> Iterator[Tuple[int, ...]]:
53+
"""Expands the (possibly compressed) internal representation into a sum of products representation.""" # pylint: disable=line-too-long
54+
55+
@abc.abstractmethod
56+
def diagram_repr(self) -> str:
57+
"""Returns a string representation to be used in circuit diagrams."""
58+
59+
@abc.abstractmethod
60+
def _number_variables(self) -> int:
61+
"""Returns the number of variables controlled by the object."""
62+
63+
@abc.abstractmethod
64+
def __len__(self) -> int:
65+
pass
66+
67+
@abc.abstractmethod
68+
def _identifier(self) -> Any:
69+
"""Returns the internal representation of the object."""
70+
71+
@abc.abstractmethod
72+
def __hash__(self) -> int:
73+
pass
74+
75+
@abc.abstractmethod
76+
def __repr__(self) -> str:
77+
pass
78+
79+
@abc.abstractmethod
80+
def validate(self, qid_shapes: Union[Tuple[int, ...], List[int]]) -> None:
81+
"""Validates control values
82+
83+
Validate that control values are in the half closed interval
84+
[0, qid_shapes) for each qubit.
85+
"""
86+
87+
@abc.abstractmethod
88+
def _are_ones(self) -> bool:
89+
"""Checks whether all control values are equal to 1."""
90+
91+
@abc.abstractmethod
92+
def _json_dict_(self) -> Dict[str, Any]:
93+
pass
94+
95+
@abc.abstractmethod
96+
def __getitem__(
97+
self, key: Union[slice, int]
98+
) -> Union['AbstractControlValues', Tuple[int, ...]]:
99+
pass
100+
101+
def __iter__(self) -> Generator[Tuple[int, ...], None, None]:
102+
for assignment in self._expand():
103+
yield assignment
104+
105+
def __eq__(self, other) -> bool:
106+
"""Returns True iff self and other represent the same configurations.
107+
108+
Args:
109+
other: A AbstractControlValues object.
110+
111+
Returns:
112+
boolean whether the two objects are equivalent or not.
113+
"""
114+
if not isinstance(other, AbstractControlValues):
115+
other = ProductOfSums(other)
116+
return sorted(v for v in self) == sorted(v for v in other)
117+
118+
119+
@dataclass(frozen=True, eq=False)
120+
class ProductOfSums(AbstractControlValues):
121+
"""ProductOfSums represents control values in a form of a cartesian product of tuples."""
122+
123+
_internal_representation: Tuple[Tuple[int, ...], ...]
124+
125+
def _identifier(self) -> Tuple[Tuple[int, ...], ...]:
126+
return self._internal_representation
127+
128+
def _expand(self) -> Iterator[Tuple[int, ...]]:
129+
"""Returns the combinations tracked by the object."""
130+
self = cast('ProductOfSums', self)
131+
return itertools.product(*self._internal_representation)
132+
133+
def __repr__(self) -> str:
134+
return f'cirq.ProductOfSums({str(self._identifier())})'
135+
136+
def _number_variables(self) -> int:
137+
return len(self._internal_representation)
138+
139+
def __len__(self) -> int:
140+
return self._number_variables()
141+
142+
def __hash__(self) -> int:
143+
return hash(self._internal_representation)
144+
145+
def validate(self, qid_shapes: Union[Tuple[int, ...], List[int]]) -> None:
146+
for i, (vals, shape) in enumerate(zip(self._internal_representation, qid_shapes)):
147+
if not all(0 <= v < shape for v in vals):
148+
message = (
149+
f'Control values <{vals!r}> outside of range for control qubit '
150+
f'number <{i}>.'
151+
)
152+
raise ValueError(message)
153+
154+
def _are_ones(self) -> bool:
155+
return frozenset(self._internal_representation) == {(1,)}
156+
157+
def diagram_repr(self) -> str:
158+
if self._are_ones():
159+
return 'C' * self._number_variables()
160+
161+
def get_prefix(control_vals):
162+
control_vals_str = ''.join(map(str, sorted(control_vals)))
163+
return f'C{control_vals_str}'
164+
165+
return ''.join(map(get_prefix, self._internal_representation))
166+
167+
def __getitem__(
168+
self, key: Union[int, slice]
169+
) -> Union['AbstractControlValues', Tuple[int, ...]]:
170+
if isinstance(key, slice):
171+
return ProductOfSums(self._internal_representation[key])
172+
return self._internal_representation[key]
173+
174+
def _json_dict_(self) -> Dict[str, Any]:
175+
return {'_internal_representation': self._internal_representation}
176+
177+
def __and__(self, other: AbstractControlValues) -> 'ProductOfSums':
178+
if not isinstance(other, ProductOfSums):
179+
raise TypeError(
180+
f'And operation not supported between types ProductOfSums and {type(other)}'
181+
)
182+
return type(self)(self._internal_representation + other._internal_representation)

Diff for: cirq-core/cirq/ops/control_values_test.py

+54
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
# Copyright 2022 The Cirq Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import cirq
16+
import pytest
17+
from cirq.ops import control_values as cv
18+
19+
20+
def test_init_productOfSum():
21+
eq = cirq.testing.EqualsTester()
22+
tests = [
23+
(((1,),), {(1,)}),
24+
(((0, 1), (1,)), {(0, 1), (1, 1)}),
25+
((((0, 1), (1, 0))), {(0, 0), (0, 1), (1, 0), (1, 1)}),
26+
]
27+
for control_values, want in tests:
28+
print(control_values)
29+
got = {c for c in cv.ProductOfSums(control_values)}
30+
eq.add_equality_group(got, want)
31+
32+
33+
def test_and_operation():
34+
eq = cirq.testing.EqualsTester()
35+
originals = [((1,),), ((0, 1), (1,)), (((0, 1), (1, 0)))]
36+
for control_values1 in originals:
37+
for control_values2 in originals:
38+
control_vals1 = cv.ProductOfSums(control_values1)
39+
control_vals2 = cv.ProductOfSums(control_values2)
40+
want = [v1 + v2 for v1 in control_vals1 for v2 in control_vals2]
41+
got = [c for c in control_vals1 & control_vals2]
42+
eq.add_equality_group(got, want)
43+
44+
45+
def test_and_supported_types():
46+
CV = cv.ProductOfSums((1,))
47+
with pytest.raises(TypeError):
48+
_ = CV & 1
49+
50+
51+
def test_repr():
52+
product_of_sums_data = [((1,),), ((0, 1), (1,)), (((0, 1), (1, 0)))]
53+
for t in map(cv.ProductOfSums, product_of_sums_data):
54+
cirq.testing.assert_equivalent_repr(t)

0 commit comments

Comments
 (0)