Skip to content

Commit 2eb698f

Browse files
Created ControlValues for controlled gates/operations, fix for quantumlib#4512
created control_values.py which contains the ControlValues class.\nFreeVars and ConstrainedVars classes are provided for ease of use.\nwhile the basic idea of ControlValues integrating it inside the code base was challening\nthe old way of using control_values assumed it's a tuple of tuples of ints and was used as thus (comparasion, hashing, slicing, fomatting, conditioning, and loops), the ControlValues class had to provide these functionalities\nthe trickiest part to get right was the support for formatting!\nI'll create a follow up PR for unit tests for control_values.py
1 parent 7259925 commit 2eb698f

File tree

3 files changed

+271
-44
lines changed

3 files changed

+271
-44
lines changed

cirq-core/cirq/ops/control_values.py

Lines changed: 218 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,218 @@
1+
# Copyright 2018 The Cirq Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
from typing import Collection, Optional, Sequence, Union, Tuple, List, Type
15+
16+
import copy
17+
import itertools
18+
19+
import cirq
20+
21+
22+
def flatten(sequence):
23+
def _flatten_aux(sequence):
24+
if isinstance(sequence, int):
25+
yield sequence
26+
else:
27+
for item in sequence:
28+
yield from _flatten_aux(item)
29+
30+
return tuple(_flatten_aux(sequence))
31+
32+
33+
class ControlValues:
34+
def __init__(
35+
self, control_values: Sequence[Union[int, Collection[int], Type['ControlValues']]]
36+
):
37+
if len(control_values) == 0:
38+
self.nxt = None
39+
self.vals = None
40+
self.num_variables = 0
41+
self.itr = None
42+
return
43+
self.itr = None
44+
self.nxt = None
45+
self.vals = control_values[0]
46+
self.num_variables = 0
47+
48+
if len(control_values) > 1:
49+
self.nxt = ControlValues(control_values[1:])
50+
51+
if isinstance(self.vals, int):
52+
self.vals = ((self.vals,),)
53+
54+
if not isinstance(self.vals, ControlValues) and isinstance(self.vals[0], int):
55+
self.vals = tuple((val,) if isinstance(val, int) else tuple(val) for val in self.vals)
56+
57+
if isinstance(control_values[0], ControlValues):
58+
aux = control_values[0].copy().And(self.nxt)
59+
self.vals, self.num_variables, self.nxt = aux.vals, aux.num_variables, aux.nxt
60+
else:
61+
self.vals = (Tuple[Tuple[int, ...], ...], self.vals)
62+
self.num_variables = len(self.vals[0])
63+
64+
def And(self, other): # pylint: disable=invalid-name
65+
# Cartesian product of all combinations in self x other
66+
if other is None:
67+
return
68+
if not isinstance(other, ControlValues):
69+
raise ValueError
70+
other = other.copy()
71+
cur = self
72+
while cur.nxt:
73+
cur = cur.nxt
74+
cur.nxt = other
75+
76+
def __call__(self):
77+
return self.__iter__()
78+
79+
def __iter__(self):
80+
nxt = self.nxt if self.nxt else lambda: [()]
81+
if self.num_variables:
82+
self.itr = itertools.product(self.vals, nxt())
83+
else:
84+
self.itr = itertools.product(*())
85+
return self.itr
86+
87+
def __next__(self):
88+
for val in self.itr:
89+
yield val
90+
91+
def copy(self):
92+
new_copy = ControlValues(
93+
[
94+
copy.deepcopy(self.vals),
95+
]
96+
)
97+
new_copy.nxt = None
98+
if self.nxt:
99+
new_copy.nxt = self.nxt.copy()
100+
return new_copy
101+
102+
def __len__(self):
103+
cur = self
104+
num_variables = 0
105+
while cur is not None:
106+
num_variables += cur.num_variables
107+
cur = cur.nxt
108+
return num_variables
109+
110+
def __getitem__(self, key):
111+
if isinstance(key, slice):
112+
if key != slice(None, -1, None):
113+
raise TypeError('Unsupported slicing')
114+
return self.copy().pop()
115+
key = int(key)
116+
num_variables = len(self)
117+
if not 0 <= key < num_variables:
118+
key = key % num_variables
119+
if key < 0:
120+
key += num_variables
121+
cur = self
122+
while cur.num_variables <= key:
123+
key -= cur.num_variables
124+
cur = cur.nxt
125+
return cur
126+
127+
def __eq__(self, other):
128+
if other is None:
129+
return False
130+
if not isinstance(other, ControlValues):
131+
return self == ControlValues(other)
132+
self_values = set(flatten(A) for A in self)
133+
other_values = set(flatten(B) for B in other)
134+
return self_values == other_values
135+
136+
def identifier(self, companions: Sequence[Union[int, 'cirq.Qid']]):
137+
companions = tuple(companions)
138+
controls = []
139+
cur = self
140+
while cur is not None:
141+
controls.append((cur.vals, companions[: cur.num_variables]))
142+
companions = companions[cur.num_variables :]
143+
cur = cur.nxt
144+
return tuple(controls)
145+
146+
def check_dimentionality(
147+
self,
148+
qid_shape: Optional[Union[Tuple[int, ...], List[int]]] = None,
149+
controls: Optional[Union[Tuple['cirq.Qid', ...], List['cirq.Qid']]] = None,
150+
offset=0,
151+
):
152+
if self.num_variables == 0:
153+
return
154+
if qid_shape is None or len(qid_shape) == 0:
155+
qid_shape = [q.dimension for q in controls[: self.num_variables]]
156+
if self.vals is None:
157+
raise ValueError('vals can\'t be None.')
158+
for product in self.vals:
159+
product = flatten(product)
160+
for i in range(self.num_variables):
161+
if not 0 <= product[i] < qid_shape[i]:
162+
message = (
163+
'Control values <{!r}> outside of range ' 'for control qubit number <{!r}>.'
164+
).format(product[i], i + offset)
165+
if controls is not None:
166+
message = (
167+
'Control values <{product[i]!r}> outside of range'
168+
' for qubit <{controls[i]!r}>.'
169+
)
170+
raise ValueError(message)
171+
172+
if self.nxt is not None:
173+
self.nxt.check_dimentionality(
174+
qid_shape=qid_shape[self.num_variables :],
175+
controls=controls[self.num_variables :] if controls else None,
176+
offset=offset + self.num_variables,
177+
)
178+
179+
def are_same_value(self, value: int = 1):
180+
if self.vals is None:
181+
raise ValueError('vals can\'t be None.')
182+
for product in self.vals:
183+
product = flatten(product)
184+
if not all(v == value for v in product):
185+
return False
186+
if self.nxt is not None:
187+
return self.nxt.are_same_value(value)
188+
return True
189+
190+
def arrangements(self):
191+
_arrangements = []
192+
cur = self
193+
while cur is not None:
194+
if cur.num_variables == 1:
195+
_arrangements.append(flatten(cur.vals))
196+
else:
197+
_arrangements.append(flatten(product) for product in cur.vals)
198+
cur = cur.nxt
199+
return _arrangements
200+
201+
def pop(self):
202+
if self.nxt is None:
203+
return None
204+
cur = self
205+
while cur.nxt.nxt is not None:
206+
cur = cur.nxt
207+
cur.nxt = None
208+
return self
209+
210+
211+
class FreeVars(ControlValues):
212+
pass
213+
214+
215+
class ConstrainedVars(ControlValues):
216+
def __init__(self, control_values):
217+
sum_of_product = (tuple(zip(*control_values)),)
218+
super().__init__(sum_of_product)

cirq-core/cirq/ops/controlled_gate.py

Lines changed: 24 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,13 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from typing import AbstractSet, Any, cast, Collection, Dict, Optional, Sequence, Tuple, Union
15+
from typing import AbstractSet, Any, Collection, Dict, Optional, Sequence, Tuple, Union
1616

1717
import numpy as np
1818

1919
import cirq
2020
from cirq import protocols, value
21-
from cirq.ops import raw_types, controlled_operation as cop
21+
from cirq.ops import raw_types, controlled_operation as cop, control_values as cv
2222
from cirq.type_workarounds import NotImplementedType
2323

2424

@@ -33,7 +33,9 @@ def __init__(
3333
self,
3434
sub_gate: 'cirq.Gate',
3535
num_controls: int = None,
36-
control_values: Optional[Sequence[Union[int, Collection[int]]]] = None,
36+
control_values: Optional[
37+
Union[cv.ControlValues, Sequence[Union[int, Collection[int]]]]
38+
] = None,
3739
control_qid_shape: Optional[Sequence[int]] = None,
3840
) -> None:
3941
"""Initializes the controlled gate. If no arguments are specified for
@@ -76,22 +78,22 @@ def __init__(
7678
self.control_qid_shape = tuple(control_qid_shape)
7779

7880
# Convert to sorted tuples
79-
self.control_values = cast(
80-
Tuple[Tuple[int, ...], ...],
81-
tuple((val,) if isinstance(val, int) else tuple(sorted(val)) for val in control_values),
82-
)
83-
# Verify control values not out of bounds
84-
for i, (val, dimension) in enumerate(zip(self.control_values, self.control_qid_shape)):
85-
if not all(0 <= v < dimension for v in val):
86-
raise ValueError(
87-
'Control values <{!r}> outside of range for control qubit '
88-
'number <{!r}>.'.format(val, i)
81+
if not isinstance(control_values, cv.ControlValues):
82+
self.control_values = cv.ControlValues(
83+
tuple(
84+
(val,) if isinstance(val, int) else tuple(sorted(val)) for val in control_values
8985
)
86+
)
87+
else:
88+
self.control_values = control_values
89+
90+
# Verify control values not out of bounds
91+
self.control_values.check_dimentionality(self.control_qid_shape)
9092

9193
# Flatten nested ControlledGates.
9294
if isinstance(sub_gate, ControlledGate):
9395
self.sub_gate = sub_gate.sub_gate # type: ignore
94-
self.control_values += sub_gate.control_values
96+
self.control_values.And(sub_gate.control_values)
9597
self.control_qid_shape += sub_gate.control_qid_shape
9698
else:
9799
self.sub_gate = sub_gate
@@ -131,7 +133,7 @@ def _value_equality_values_(self):
131133
return (
132134
self.sub_gate,
133135
self.num_controls(),
134-
frozenset(zip(self.control_values, self.control_qid_shape)),
136+
frozenset(self.control_values.identifier(self.control_qid_shape)),
135137
)
136138

137139
def _apply_unitary_(self, args: 'protocols.ApplyUnitaryArgs') -> np.ndarray:
@@ -223,14 +225,14 @@ def get_symbol(vals):
223225

224226
return protocols.CircuitDiagramInfo(
225227
wire_symbols=(
226-
*(get_symbol(vals) for vals in self.control_values),
228+
*(get_symbol(vals) for vals in self.control_values.arrangements()),
227229
*sub_info.wire_symbols,
228230
),
229231
exponent=sub_info.exponent,
230232
)
231233

232234
def __str__(self) -> str:
233-
if set(self.control_values) == {(1,)}:
235+
if self.control_values.are_same_value(1):
234236

235237
def get_prefix(control_vals):
236238
return 'C'
@@ -241,26 +243,26 @@ def get_prefix(control_vals):
241243
control_vals_str = ''.join(map(str, sorted(control_vals)))
242244
return f'C{control_vals_str}'
243245

244-
return ''.join(map(get_prefix, self.control_values)) + str(self.sub_gate)
246+
return ''.join(map(get_prefix, self.control_values.arrangements())) + str(self.sub_gate)
245247

246248
def __repr__(self) -> str:
247-
if self.num_controls() == 1 and self.control_values == ((1,),):
249+
if self.num_controls() == 1 and self.control_values.are_same_value(1):
248250
return f'cirq.ControlledGate(sub_gate={self.sub_gate!r})'
249251

250-
if all(vals == (1,) for vals in self.control_values) and set(self.control_qid_shape) == {2}:
252+
if self.control_values.are_same_value(1) and set(self.control_qid_shape) == {2}:
251253
return (
252254
f'cirq.ControlledGate(sub_gate={self.sub_gate!r}, '
253255
f'num_controls={self.num_controls()!r})'
254256
)
255257
return (
256258
f'cirq.ControlledGate(sub_gate={self.sub_gate!r}, '
257-
f'control_values={self.control_values!r},'
259+
f'control_values={self.control_values.arrangements()!r},'
258260
f'control_qid_shape={self.control_qid_shape!r})'
259261
)
260262

261263
def _json_dict_(self) -> Dict[str, Any]:
262264
return {
263-
'control_values': self.control_values,
265+
'control_values': self.control_values.arrangements(),
264266
'control_qid_shape': self.control_qid_shape,
265267
'sub_gate': self.sub_gate,
266268
}

0 commit comments

Comments
 (0)