Skip to content

Commit 1fd4efa

Browse files
Created ControlValues for controlled gates/operations, fix for quantumlib#4512
created control_values.py which contains the ControlValues class.\nFreeVars and ConstrainedVars classes are provided for ease of use.\nwhile the basic idea of ControlValues integrating it inside the code base was challening\nthe old way of using control_values assumed it's a tuple of tuples of ints and was used as thus (comparasion, hashing, slicing, fomatting, conditioning, and loops), the ControlValues class had to provide these functionalities\nthe trickiest part to get right was the support for formatting!\nI'll create a follow up PR for unit tests for control_values.py
1 parent 7259925 commit 1fd4efa

File tree

3 files changed

+263
-44
lines changed

3 files changed

+263
-44
lines changed

cirq-core/cirq/ops/control_values.py

Lines changed: 210 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,210 @@
1+
# Copyright 2018 The Cirq Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
from typing import Collection, cast, Optional, Sequence, Union
15+
16+
import copy
17+
import itertools
18+
19+
20+
def flatten(sequence):
21+
def _flatten_aux(sequence):
22+
if isinstance(sequence, int):
23+
yield sequence
24+
else:
25+
for item in sequence:
26+
yield from _flatten_aux(item)
27+
28+
return tuple(_flatten_aux(sequence))
29+
30+
31+
class ControlValues:
32+
def __init__(self, control_values: Sequence[Union[int, Collection[int], 'ControlValues']]):
33+
if len(control_values) == 0:
34+
self.nxt = None
35+
self.vals = None
36+
self.num_variables = 0
37+
self.itr = None
38+
return
39+
self.itr = None
40+
self.nxt = None
41+
self.vals = control_values[0]
42+
self.num_variables = 0
43+
44+
if len(control_values) > 1:
45+
self.nxt = ControlValues(control_values[1:])
46+
47+
if isinstance(self.vals, int):
48+
self.vals = ((self.vals,),)
49+
50+
if isinstance(self.vals[0], int):
51+
self.vals = tuple(tuple(val) for val in self.vals)
52+
53+
if isinstance(control_values[0], ControlValues):
54+
aux = control_values[0].copy().And(self.nxt)
55+
self.vals, self.num_variables, self.nxt = aux.vals, aux.num_variables, aux.nxt
56+
else:
57+
self.vals = cast(Tuple[Tuple[int, ...], ...], self.vals)
58+
self.num_variables = len(self.vals[0])
59+
60+
def And(self, other): # pylint: disable=invalid-name
61+
# Cartesian product of all combinations in self x other
62+
if other is None:
63+
return
64+
if not isinstance(other, ControlValues):
65+
raise ValueError
66+
other = other.copy()
67+
cur = self
68+
while cur.nxt:
69+
cur = cur.nxt
70+
cur.nxt = other
71+
72+
def __call__(self):
73+
return self.__iter__()
74+
75+
def __iter__(self):
76+
nxt = self.nxt if self.nxt else lambda: [()]
77+
if self.num_variables:
78+
self.itr = itertools.product(self.vals, nxt())
79+
else:
80+
self.itr = itertools.product(*())
81+
return self.itr
82+
83+
def __next__(self):
84+
for val in self.itr:
85+
yield val
86+
87+
def copy(self):
88+
new_copy = ControlValues(
89+
[
90+
copy.deepcopy(self.vals),
91+
]
92+
)
93+
new_copy.nxt = None
94+
if self.nxt:
95+
new_copy.nxt = self.nxt.copy()
96+
return new_copy
97+
98+
def __len__(self):
99+
cur = self
100+
num_variables = 0
101+
while cur is not None:
102+
num_variables += cur.num_variables
103+
cur = cur.nxt
104+
return num_variables
105+
106+
def __getitem__(self, key):
107+
if isinstance(key, slice):
108+
if key != slice(None, -1, None):
109+
raise TypeError('Unsupported slicing')
110+
return self.copy().pop()
111+
key = int(key)
112+
num_variables = len(self)
113+
if not 0 <= key < num_variables:
114+
key = key % num_variables
115+
if key < 0:
116+
key += num_variables
117+
cur = self
118+
while cur.num_variables <= key:
119+
key -= cur.num_variables
120+
cur = cur.nxt
121+
return cur
122+
123+
def __eq__(self, other):
124+
if other is None:
125+
return False
126+
if not isinstance(other, ControlValues):
127+
return self == ControlValues(other)
128+
self_values = set(flatten(A) for A in self)
129+
other_values = set(flatten(B) for B in other)
130+
return self_values == other_values
131+
132+
def identifier(self, companions: Sequence[Union[int, 'cirq.Qid']]):
133+
companions = tuple(companions)
134+
controls = []
135+
cur = self
136+
while cur is not None:
137+
controls.append((cur.vals, companions[: cur.num_variables]))
138+
companions = companions[cur.num_variables :]
139+
cur = cur.nxt
140+
return tuple(controls)
141+
142+
def check_dimentionality(
143+
self,
144+
qid_shape: Optional[Sequence[int]] = None,
145+
controls: Optional[Sequence['cirq.Qid']] = None,
146+
offset=0,
147+
):
148+
if self.num_variables == 0:
149+
return
150+
if qid_shape is None or len(qid_shape) == 0:
151+
qid_shape = [q.dimension for q in controls[: self.num_variables]]
152+
for product in self.vals:
153+
product = flatten(product)
154+
for i in range(self.num_variables):
155+
if not 0 <= product[i] < qid_shape[i]:
156+
message = (
157+
'Control values <{!r}> outside of range ' 'for control qubit number <{!r}>.'
158+
).format(product[i], i + offset)
159+
if controls is not None:
160+
message = (
161+
'Control values <{product[i]!r}> outside of range'
162+
' for qubit <{controls[i]!r}>.'
163+
)
164+
raise ValueError(message)
165+
166+
if self.nxt is not None:
167+
self.nxt.check_dimentionality(
168+
qid_shape=qid_shape[self.num_variables :],
169+
controls=controls[self.num_variables :] if controls else None,
170+
offset=offset + self.num_variables,
171+
)
172+
173+
def are_same_value(self, value: int = 1):
174+
for product in self.vals:
175+
product = flatten(product)
176+
if not all(v == value for v in product):
177+
return False
178+
if self.nxt is not None:
179+
return self.nxt.are_same_value(value)
180+
return True
181+
182+
def arrangements(self):
183+
_arrangements = []
184+
cur = self
185+
while cur is not None:
186+
if cur.num_variables == 1:
187+
_arrangements.append(flatten(cur.vals))
188+
else:
189+
_arrangements.append(flatten(product) for product in cur.vals)
190+
cur = cur.nxt
191+
return _arrangements
192+
193+
def pop(self):
194+
if self.nxt is None:
195+
return None
196+
cur = self
197+
while cur.nxt.nxt is not None:
198+
cur = cur.nxt
199+
cur.nxt = None
200+
return self
201+
202+
203+
class FreeVars(ControlValues):
204+
pass
205+
206+
207+
class ConstrainedVars(ControlValues):
208+
def __init__(self, control_values):
209+
sum_of_product = (tuple(zip(*control_values)),)
210+
super().__init__(sum_of_product)

cirq-core/cirq/ops/controlled_gate.py

Lines changed: 24 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,13 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from typing import AbstractSet, Any, cast, Collection, Dict, Optional, Sequence, Tuple, Union
15+
from typing import AbstractSet, Any, Collection, Dict, Optional, Sequence, Tuple, Union
1616

1717
import numpy as np
1818

1919
import cirq
2020
from cirq import protocols, value
21-
from cirq.ops import raw_types, controlled_operation as cop
21+
from cirq.ops import raw_types, controlled_operation as cop, control_values as cv
2222
from cirq.type_workarounds import NotImplementedType
2323

2424

@@ -33,7 +33,9 @@ def __init__(
3333
self,
3434
sub_gate: 'cirq.Gate',
3535
num_controls: int = None,
36-
control_values: Optional[Sequence[Union[int, Collection[int]]]] = None,
36+
control_values: Optional[
37+
Union[cv.ControlValues, Sequence[Union[int, Collection[int]]]]
38+
] = None,
3739
control_qid_shape: Optional[Sequence[int]] = None,
3840
) -> None:
3941
"""Initializes the controlled gate. If no arguments are specified for
@@ -76,22 +78,22 @@ def __init__(
7678
self.control_qid_shape = tuple(control_qid_shape)
7779

7880
# Convert to sorted tuples
79-
self.control_values = cast(
80-
Tuple[Tuple[int, ...], ...],
81-
tuple((val,) if isinstance(val, int) else tuple(sorted(val)) for val in control_values),
82-
)
83-
# Verify control values not out of bounds
84-
for i, (val, dimension) in enumerate(zip(self.control_values, self.control_qid_shape)):
85-
if not all(0 <= v < dimension for v in val):
86-
raise ValueError(
87-
'Control values <{!r}> outside of range for control qubit '
88-
'number <{!r}>.'.format(val, i)
81+
if not isinstance(control_values, cv.ControlValues):
82+
self.control_values = cv.ControlValues(
83+
tuple(
84+
(val,) if isinstance(val, int) else tuple(sorted(val)) for val in control_values
8985
)
86+
)
87+
else:
88+
self.control_values = control_values
89+
90+
# Verify control values not out of bounds
91+
self.control_values.check_dimentionality(self.control_qid_shape)
9092

9193
# Flatten nested ControlledGates.
9294
if isinstance(sub_gate, ControlledGate):
9395
self.sub_gate = sub_gate.sub_gate # type: ignore
94-
self.control_values += sub_gate.control_values
96+
self.control_values.And(sub_gate.control_values)
9597
self.control_qid_shape += sub_gate.control_qid_shape
9698
else:
9799
self.sub_gate = sub_gate
@@ -131,7 +133,7 @@ def _value_equality_values_(self):
131133
return (
132134
self.sub_gate,
133135
self.num_controls(),
134-
frozenset(zip(self.control_values, self.control_qid_shape)),
136+
frozenset(self.control_values.identifier(self.control_qid_shape)),
135137
)
136138

137139
def _apply_unitary_(self, args: 'protocols.ApplyUnitaryArgs') -> np.ndarray:
@@ -223,14 +225,14 @@ def get_symbol(vals):
223225

224226
return protocols.CircuitDiagramInfo(
225227
wire_symbols=(
226-
*(get_symbol(vals) for vals in self.control_values),
228+
*(get_symbol(vals) for vals in self.control_values.arrangements()),
227229
*sub_info.wire_symbols,
228230
),
229231
exponent=sub_info.exponent,
230232
)
231233

232234
def __str__(self) -> str:
233-
if set(self.control_values) == {(1,)}:
235+
if self.control_values.are_same_value(1):
234236

235237
def get_prefix(control_vals):
236238
return 'C'
@@ -241,26 +243,26 @@ def get_prefix(control_vals):
241243
control_vals_str = ''.join(map(str, sorted(control_vals)))
242244
return f'C{control_vals_str}'
243245

244-
return ''.join(map(get_prefix, self.control_values)) + str(self.sub_gate)
246+
return ''.join(map(get_prefix, self.control_values.arrangements())) + str(self.sub_gate)
245247

246248
def __repr__(self) -> str:
247-
if self.num_controls() == 1 and self.control_values == ((1,),):
249+
if self.num_controls() == 1 and self.control_values.are_same_value(1):
248250
return f'cirq.ControlledGate(sub_gate={self.sub_gate!r})'
249251

250-
if all(vals == (1,) for vals in self.control_values) and set(self.control_qid_shape) == {2}:
252+
if self.control_values.are_same_value(1) and set(self.control_qid_shape) == {2}:
251253
return (
252254
f'cirq.ControlledGate(sub_gate={self.sub_gate!r}, '
253255
f'num_controls={self.num_controls()!r})'
254256
)
255257
return (
256258
f'cirq.ControlledGate(sub_gate={self.sub_gate!r}, '
257-
f'control_values={self.control_values!r},'
259+
f'control_values={self.control_values.arrangements()!r},'
258260
f'control_qid_shape={self.control_qid_shape!r})'
259261
)
260262

261263
def _json_dict_(self) -> Dict[str, Any]:
262264
return {
263-
'control_values': self.control_values,
265+
'control_values': self.control_values.arrangements(),
264266
'control_qid_shape': self.control_qid_shape,
265267
'sub_gate': self.sub_gate,
266268
}

0 commit comments

Comments
 (0)