Skip to content

Commit ec76158

Browse files
Create an abstraction for control values and a concrete implementation that replicates current behaviour.
refactored the code for control_values and abstracted its calls this is a first step towards implementing more ways to represent control values in order to solve quantumlib#4512
1 parent 5162326 commit ec76158

14 files changed

+316
-83
lines changed

cirq-core/cirq/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,7 @@
283283
PhasedXZGate,
284284
PhaseFlipChannel,
285285
StatePreparationChannel,
286+
ProductOfSums,
286287
ProjectorString,
287288
ProjectorSum,
288289
RandomGateChannel,

cirq-core/cirq/json_resolver_cache.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,7 @@ def _symmetricalqidpair(qids):
157157
'PhasedXPowGate': cirq.PhasedXPowGate,
158158
'PhasedXZGate': cirq.PhasedXZGate,
159159
'ProductState': cirq.ProductState,
160+
'ProductOfSums': cirq.ProductOfSums,
160161
'ProjectorString': cirq.ProjectorString,
161162
'ProjectorSum': cirq.ProjectorSum,
162163
'QasmUGate': cirq.circuits.qasm_output.QasmUGate,

cirq-core/cirq/ops/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,3 +208,5 @@
208208
from cirq.ops.wait_gate import wait, WaitGate
209209

210210
from cirq.ops.state_preparation_channel import StatePreparationChannel
211+
212+
from cirq.ops.control_values import AbstractControlValues, ProductOfSums

cirq-core/cirq/ops/common_gates.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
from cirq import protocols, value
4545
from cirq._compat import proper_repr
4646
from cirq._doc import document
47-
from cirq.ops import controlled_gate, eigen_gate, gate_features, raw_types
47+
from cirq.ops import controlled_gate, eigen_gate, gate_features, raw_types, control_values as cv
4848

4949
from cirq.type_workarounds import NotImplementedType
5050

@@ -159,7 +159,9 @@ def _trace_distance_bound_(self) -> Optional[float]:
159159
def controlled(
160160
self,
161161
num_controls: int = None,
162-
control_values: Optional[Sequence[Union[int, Collection[int]]]] = None,
162+
control_values: Optional[
163+
Union[cv.AbstractControlValues, Sequence[Union[int, Collection[int]]]]
164+
] = None,
163165
control_qid_shape: Optional[Tuple[int, ...]] = None,
164166
) -> raw_types.Gate:
165167
"""Returns a controlled `XPowGate`, using a `CXPowGate` where possible.
@@ -557,7 +559,9 @@ def with_canonical_global_phase(self) -> 'ZPowGate':
557559
def controlled(
558560
self,
559561
num_controls: int = None,
560-
control_values: Optional[Sequence[Union[int, Collection[int]]]] = None,
562+
control_values: Optional[
563+
Union[cv.AbstractControlValues, Sequence[Union[int, Collection[int]]]]
564+
] = None,
561565
control_qid_shape: Optional[Tuple[int, ...]] = None,
562566
) -> raw_types.Gate:
563567
"""Returns a controlled `ZPowGate`, using a `CZPowGate` where possible.
@@ -978,7 +982,9 @@ def _phase_by_(self, phase_turns, qubit_index):
978982
def controlled(
979983
self,
980984
num_controls: int = None,
981-
control_values: Optional[Sequence[Union[int, Collection[int]]]] = None,
985+
control_values: Optional[
986+
Union[cv.AbstractControlValues, Sequence[Union[int, Collection[int]]]]
987+
] = None,
982988
control_qid_shape: Optional[Tuple[int, ...]] = None,
983989
) -> raw_types.Gate:
984990
"""Returns a controlled `CZPowGate`, using a `CCZPowGate` where possible.
@@ -1167,7 +1173,9 @@ def _pauli_expansion_(self) -> value.LinearDict[str]:
11671173
def controlled(
11681174
self,
11691175
num_controls: int = None,
1170-
control_values: Optional[Sequence[Union[int, Collection[int]]]] = None,
1176+
control_values: Optional[
1177+
Union[cv.AbstractControlValues, Sequence[Union[int, Collection[int]]]]
1178+
] = None,
11711179
control_qid_shape: Optional[Tuple[int, ...]] = None,
11721180
) -> raw_types.Gate:
11731181
"""Returns a controlled `CXPowGate`, using a `CCXPowGate` where possible.

cirq-core/cirq/ops/control_values.py

Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
# Copyright 2022 The Cirq Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import abc
15+
from typing import Union, Tuple, List, TYPE_CHECKING, Any, Dict
16+
from dataclasses import dataclass
17+
18+
import itertools
19+
20+
if TYPE_CHECKING:
21+
import cirq
22+
23+
24+
@dataclass(frozen=True, eq=False) # type: ignore
25+
class AbstractControlValues(abc.ABC):
26+
"""AbstractControlValues is an abstract immutable data class.
27+
28+
AbstractControlValues defines an API for control values and implements
29+
functions common to all implementations (e.g. comparison).
30+
"""
31+
32+
_internal_representation: Any
33+
34+
def __and__(self, other: 'AbstractControlValues') -> 'AbstractControlValues':
35+
"""Sets self to be the cartesian product of all combinations in self x other.
36+
37+
Args:
38+
other: An object that implements AbstractControlValues.
39+
40+
Returns:
41+
An object that represents the cartesian product of the two inputs.
42+
"""
43+
return type(self)(self._internal_representation + other._internal_representation)
44+
45+
def _iterator(self):
46+
return self._expand()
47+
48+
@abc.abstractmethod
49+
def _expand(self):
50+
"""Returns the control values tracked by the object."""
51+
52+
@abc.abstractmethod
53+
def diagram_repr(self) -> str:
54+
"""Returns a string representation to be used in circuit diagrams."""
55+
56+
@abc.abstractmethod
57+
def _number_variables(self):
58+
"""Returns the control values tracked by the object."""
59+
60+
@abc.abstractmethod
61+
def __len__(self):
62+
pass
63+
64+
@abc.abstractmethod
65+
def identifier(self) -> Tuple[Any]:
66+
"""Returns an identifier from which the object can be rebuilt."""
67+
68+
@abc.abstractmethod
69+
def __hash__(self):
70+
pass
71+
72+
@abc.abstractmethod
73+
def __repr__(self) -> str:
74+
pass
75+
76+
@abc.abstractmethod
77+
def _validate(self, qid_shapes: Union[Tuple[int, ...], List[int]]) -> None:
78+
"""Validates control values
79+
80+
Validate that control values are in the half closed interval
81+
[0, qid_shapes) for each qubit.
82+
"""
83+
84+
@abc.abstractmethod
85+
def _are_ones(self) -> bool:
86+
"""Checks whether all control values are equal to 1."""
87+
88+
@abc.abstractmethod
89+
def _json_dict_(self) -> Dict[str, Any]:
90+
pass
91+
92+
@abc.abstractmethod
93+
def __getitem__(self, key):
94+
pass
95+
96+
def __iter__(self):
97+
for assignment in self._iterator():
98+
yield assignment
99+
100+
def __eq__(self, other):
101+
"""Returns True iff self and other represent the same configurations.
102+
103+
Args:
104+
other: A AbstractControlValues object.
105+
106+
Returns:
107+
boolean whether the two objects are equivalent or not.
108+
"""
109+
if not isinstance(other, AbstractControlValues):
110+
other = ProductOfSums(other)
111+
return sorted(v for v in self) == sorted(v for v in other)
112+
113+
114+
@AbstractControlValues.register
115+
class ProductOfSums(AbstractControlValues):
116+
"""ProductOfSums represents control values in a form of a cartesian product of tuples."""
117+
118+
_internal_representation: Tuple[Tuple[int, ...]]
119+
120+
def identifier(self):
121+
return self._internal_representation
122+
123+
def _expand(self):
124+
"""Returns the combinations tracked by the object."""
125+
return itertools.product(*self._internal_representation)
126+
127+
def __repr__(self):
128+
return f'cirq.ProductOfSums({str(self.identifier())})'
129+
130+
def _number_variables(self) -> int:
131+
return len(self._internal_representation)
132+
133+
def __len__(self):
134+
return self._number_variables()
135+
136+
def __hash__(self):
137+
return hash(self._internal_representation)
138+
139+
def _validate(self, qid_shapes: Union[Tuple[int, ...], List[int]]) -> None:
140+
for i, (vals, shape) in enumerate(zip(self._internal_representation, qid_shapes)):
141+
if not all(0 <= v < shape for v in vals):
142+
message = (
143+
f'Control values <{vals!r}> outside of range for control qubit '
144+
f'number <{i}>.'
145+
)
146+
raise ValueError(message)
147+
148+
def _are_ones(self) -> bool:
149+
return frozenset(self._internal_representation) == {(1,)}
150+
151+
def diagram_repr(self) -> str:
152+
if self._are_ones():
153+
return 'C' * self._number_variables()
154+
155+
def get_prefix(control_vals):
156+
control_vals_str = ''.join(map(str, sorted(control_vals)))
157+
return f'C{control_vals_str}'
158+
159+
return ''.join(map(get_prefix, self._internal_representation))
160+
161+
def __getitem__(self, key):
162+
if isinstance(key, slice):
163+
return ProductOfSums(self._internal_representation[key])
164+
return self._internal_representation[key]
165+
166+
def _json_dict_(self) -> Dict[str, Any]:
167+
return {
168+
'_internal_representation': self._internal_representation,
169+
'cirq_type': 'ProductOfSums',
170+
}
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# Copyright 2022 The Cirq Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import cirq
16+
from cirq.ops import control_values as cv
17+
18+
19+
def test_init_productOfSum():
20+
eq = cirq.testing.EqualsTester()
21+
tests = [
22+
(((1,),), {(1,)}),
23+
(((0, 1), (1,)), {(0, 1), (1, 1)}),
24+
((((0, 1), (1, 0))), {(0, 0), (0, 1), (1, 0), (1, 1)}),
25+
]
26+
for control_values, want in tests:
27+
print(control_values)
28+
got = {c for c in cv.ProductOfSums(control_values)}
29+
eq.add_equality_group(got, want)
30+
31+
32+
def test_and_operation():
33+
eq = cirq.testing.EqualsTester()
34+
originals = [((1,),), ((0, 1), (1,)), (((0, 1), (1, 0)))]
35+
for control_values1 in originals:
36+
for control_values2 in originals:
37+
control_vals1 = cv.ProductOfSums(control_values1)
38+
control_vals2 = cv.ProductOfSums(control_values2)
39+
want = [v1 + v2 for v1 in control_vals1 for v2 in control_vals2]
40+
got = [c for c in control_vals1 & control_vals2]
41+
eq.add_equality_group(got, want)

0 commit comments

Comments
 (0)