Skip to content

Commit def0928

Browse files
Created ControlValues for controlled gates/operations, fix for quantumlib#4512
created control_values.py which contains the ControlValues class. FreeVars and ConstrainedVars classes are provided for ease of use. while the basic idea of ControlValues integrating it inside the code base was challening the old way of using control_values assumed it's a tuple of tuples of ints and was used as thus (comparasion, hashing, slicing, fomatting, conditioning, and loops), the ControlValues class had to provide these functionalities the trickiest part to get right was the support for formatting!
1 parent 320511c commit def0928

File tree

5 files changed

+144
-78
lines changed

5 files changed

+144
-78
lines changed

cirq-core/cirq/ops/control_values.py

Lines changed: 98 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -48,77 +48,79 @@ def __init__(
4848
self, control_values: Sequence[Union[int, Collection[int], Type['ControlValues']]]
4949
):
5050
if len(control_values) == 0:
51-
self.vals = cast(Tuple[Tuple[int, ...], ...], (()))
52-
self.num_variables = 0
53-
self.nxt = None
54-
self.itr = None
51+
self._vals = cast(Tuple[Tuple[int, ...], ...], (()))
52+
self._num_variables = 0
53+
self._nxt = None
5554
return
56-
self.itr = None
57-
self.nxt = None
55+
self._nxt = None
5856

5957
if len(control_values) > 1:
60-
self.nxt = ControlValues(control_values[1:])
58+
self._nxt = ControlValues(control_values[1:])
6159

6260
if isinstance(control_values[0], ControlValues):
6361
aux = control_values[0].copy()
64-
aux.product(self.nxt)
65-
self.vals, self.num_variables, self.nxt = aux.vals, aux.num_variables, aux.nxt
66-
self.vals = cast(Tuple[Tuple[int, ...], ...], self.vals)
62+
aux.product(self._nxt)
63+
self._vals, self._num_variables, self._nxt = aux._vals, aux._num_variables, aux._nxt
64+
self._vals = cast(Tuple[Tuple[int, ...], ...], self._vals)
6765
return
6866

6967
val = control_values[0]
7068
if isinstance(val, int):
71-
self.vals = _from_int(val)
69+
self._vals = _from_int(val)
7270
elif isinstance(val, (list, tuple)):
7371
if isinstance(val[0], int):
74-
self.vals = _from_sequence_int(val)
72+
self._vals = _from_sequence_int(val)
7573
else:
76-
self.vals = _from_sequence_sequence(val)
74+
self._vals = _from_sequence_sequence(val)
7775
else:
7876
raise TypeError(f'{val} is of Unsupported type {type(val)}')
79-
self.num_variables = len(self.vals[0])
77+
self._num_variables = len(self._vals[0])
8078

8179
def product(self, other):
82-
# Cartesian product of all combinations in self x other
80+
"""Sets self to be the cartesian product of all combinations in self x other.
81+
82+
Args:
83+
other: A ControlValues object
84+
"""
8385
if other is None:
8486
return
8587
other = other.copy()
8688
cur = self
87-
while cur.nxt is not None:
88-
cur = cur.nxt
89-
cur.nxt = other
89+
while cur._nxt is not None:
90+
cur = cur._nxt
91+
cur._nxt = other
9092

9193
def __call__(self):
9294
return self.__iter__()
9395

9496
def __iter__(self):
95-
nxt = self.nxt if self.nxt else lambda: [()]
96-
if self.num_variables:
97-
self.itr = itertools.product(self.vals, nxt())
97+
nxt = self._nxt if self._nxt else lambda: [()]
98+
if self._num_variables:
99+
return itertools.product(self._vals, nxt())
98100
else:
99-
self.itr = itertools.product(*(), nxt())
100-
return self.itr
101+
return itertools.product(*(), nxt())
101102

102103
def copy(self):
103-
if self.num_variables == 0:
104+
"""Returns a deep copy of the object."""
105+
if self._num_variables == 0:
104106
new_copy = ControlValues([])
105107
else:
106108
new_copy = ControlValues(
107109
[
108-
copy.deepcopy(self.vals),
110+
copy.deepcopy(self._vals),
109111
]
110112
)
111-
new_copy.nxt = None
112-
if self.nxt:
113-
new_copy.nxt = self.nxt.copy()
113+
new_copy._nxt = None
114+
if self._nxt:
115+
new_copy._nxt = self._nxt.copy()
114116
return new_copy
115117

116118
def __len__(self):
117119
cur = self
118120
num_variables = 0
119121
while cur is not None:
120-
num_variables += cur.num_variables
121-
cur = cur.nxt
122+
num_variables += cur._num_variables
123+
cur = cur._nxt
122124
return num_variables
123125

124126
def __getitem__(self, key):
@@ -131,9 +133,9 @@ def __getitem__(self, key):
131133
if not 0 <= key < num_variables:
132134
key = key % num_variables
133135
cur = self
134-
while cur.num_variables <= key:
135-
key -= cur.num_variables
136-
cur = cur.nxt
136+
while cur._num_variables <= key:
137+
key -= cur._num_variables
138+
cur = cur._nxt
137139
return cur
138140

139141
def __eq__(self, other):
@@ -147,33 +149,55 @@ def __eq__(self, other):
147149
return self_values == other_values
148150

149151
def identifier(self, companions: Sequence[Union[int, 'cirq.Qid']]):
152+
"""Returns an identifier that pairs the control_values with their counterparts
153+
in the companion sequence (e.g. sequence of the control qubits ids)
154+
155+
Args:
156+
companions: sequence of the same length as the number of control qubits.
157+
158+
Returns:
159+
Tuple of pairs of the control values paired with its corresponding companion.
160+
"""
150161
companions = tuple(companions)
151162
controls = []
152163
cur = cast(Optional[ControlValues], self)
153164
while cur is not None:
154-
controls.append((cur.vals, companions[: cur.num_variables]))
155-
companions = companions[cur.num_variables :]
156-
cur = cur.nxt
165+
controls.append((cur._vals, companions[: cur._num_variables]))
166+
companions = companions[cur._num_variables :]
167+
cur = cur._nxt
157168
return tuple(controls)
158169

159-
def check_dimentionality(
170+
def check_dimensionality(
160171
self,
161172
qid_shape: Optional[Union[Tuple[int, ...], List[int]]] = None,
162173
controls: Optional[Union[Tuple['cirq.Qid', ...], List['cirq.Qid']]] = None,
163174
offset=0,
164175
):
165-
if self.num_variables == 0:
176+
"""Checks the dimentionality of control values with respect to qid_shape or controls.
177+
*At least one of qid_shape and controls must be provided.
178+
*if both are provided then controls is ignored*
179+
180+
Args:
181+
qid_shape: Sequence of shapes of the control qubits.
182+
controls: Sequence of Qids.
183+
offset: starting index.
184+
Raises:
185+
ValueError:
186+
- if none of qid_shape or controls are provided or both are empty.
187+
- if one of the control values violates the shape of its qubit.
188+
"""
189+
if self._num_variables == 0:
166190
return
167191
if qid_shape is None and controls is None:
168192
raise ValueError('At least one of qid_shape or controls has to be not given.')
169193
if controls is not None:
170194
controls = tuple(controls)
171195
if (qid_shape is None or len(qid_shape) == 0) and controls is not None:
172-
qid_shape = tuple(q.dimension for q in controls[: self.num_variables])
196+
qid_shape = tuple(q.dimension for q in controls[: self._num_variables])
173197
qid_shape = cast(Tuple[int], qid_shape)
174-
for product in self.vals:
198+
for product in self._vals:
175199
product = flatten(product)
176-
for i in range(self.num_variables):
200+
for i in range(self._num_variables):
177201
if not 0 <= product[i] < qid_shape[i]:
178202
message = (
179203
'Control values <{!r}> outside of range ' 'for control qubit number <{!r}>.'
@@ -185,40 +209,59 @@ def check_dimentionality(
185209
)
186210
raise ValueError(message)
187211

188-
if self.nxt is not None:
189-
self.nxt.check_dimentionality(
190-
qid_shape=qid_shape[self.num_variables :],
191-
controls=controls[self.num_variables :] if controls else None,
192-
offset=offset + self.num_variables,
212+
if self._nxt is not None:
213+
self._nxt.check_dimensionality(
214+
qid_shape=qid_shape[self._num_variables :],
215+
controls=controls[self._num_variables :] if controls else None,
216+
offset=offset + self._num_variables,
193217
)
194218

195219
def are_same_value(self, value: int = 1):
196-
for product in self.vals:
220+
for product in self._vals:
197221
product = flatten(product)
198222
if not all(v == value for v in product):
199223
return False
200-
if self.nxt is not None:
201-
return self.nxt.are_same_value(value)
224+
if self._nxt is not None:
225+
return self._nxt.are_same_value(value)
202226
return True
203227

204228
def arrangements(self):
229+
"""Returns a list of the control values.
230+
231+
Returns:
232+
lists containing the control values whose product is the list of all
233+
possible combinations of the control values.
234+
"""
205235
_arrangements = []
206236
cur = self
207237
while cur is not None:
208-
if cur.num_variables == 1:
209-
_arrangements.append(flatten(cur.vals))
238+
if cur._num_variables == 1:
239+
_arrangements.append(flatten(cur._vals))
210240
else:
211-
_arrangements.append(tuple(flatten(product) for product in cur.vals))
212-
cur = cur.nxt
241+
_arrangements.append(tuple(flatten(product) for product in cur._vals))
242+
cur = cur._nxt
213243
return _arrangements
214244

215245
def pop(self):
216-
if self.nxt is None:
246+
"""Removes the last control values combination."""
247+
if self._nxt is None:
217248
return None
218-
self.nxt = self.nxt.pop()
249+
self._nxt = self._nxt.pop()
219250
return self
220251

221252

253+
def to_control_values(
254+
values: Union[ControlValues, Sequence[Union[int, Collection[int]]]]
255+
) -> ControlValues:
256+
if not isinstance(values, ControlValues):
257+
# Convert to sorted tuples
258+
return ControlValues(
259+
tuple((val,) if isinstance(val, int) else tuple(sorted(val)) for val in values)
260+
)
261+
else:
262+
return values
263+
264+
222265
class FreeVars(ControlValues):
223266
pass
224267

cirq-core/cirq/ops/control_values_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,13 +78,13 @@ def test_slicing_not_supported():
7878
_ = control_vals[0:1]
7979

8080

81-
def test_check_dimentionality():
81+
def test_check_dimensionality():
8282
empty_control_vals = cv.ControlValues([])
83-
empty_control_vals.check_dimentionality()
83+
empty_control_vals.check_dimensionality()
8484

8585
control_values = cv.ControlValues([[0, 1], 1])
8686
with pytest.raises(ValueError):
87-
control_values.check_dimentionality()
87+
control_values.check_dimensionality()
8888

8989

9090
def test_pop():

cirq-core/cirq/ops/controlled_gate.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -79,18 +79,10 @@ def __init__(
7979
raise ValueError('len(control_qid_shape) != num_controls')
8080
self.control_qid_shape = tuple(control_qid_shape)
8181

82-
# Convert to sorted tuples
83-
if not isinstance(control_values, cv.ControlValues):
84-
self.control_values = cv.ControlValues(
85-
tuple(
86-
(val,) if isinstance(val, int) else tuple(sorted(val)) for val in control_values
87-
)
88-
)
89-
else:
90-
self.control_values = control_values
82+
self.control_values = cv.to_control_values(control_values)
9183

9284
# Verify control values not out of bounds
93-
self.control_values.check_dimentionality(self.control_qid_shape)
85+
self.control_values.check_dimensionality(self.control_qid_shape)
9486

9587
# Flatten nested ControlledGates.
9688
if isinstance(sub_gate, ControlledGate):
@@ -100,6 +92,16 @@ def __init__(
10092
else:
10193
self.sub_gate = sub_gate
10294

95+
@property
96+
def control_values(self) -> cv.ControlValues:
97+
return self._control_values
98+
99+
@control_values.setter
100+
def control_values(
101+
self, values: Union[cv.ControlValues, Sequence[Union[int, Collection[int]]]]
102+
) -> None:
103+
self._control_values = cv.to_control_values(values)
104+
103105
def num_controls(self) -> int:
104106
return len(self.control_qid_shape)
105107

cirq-core/cirq/ops/controlled_gate_test.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
import cirq
2222
from cirq.type_workarounds import NotImplementedType
23+
from cirq.ops import control_values as cv
2324

2425

2526
class GateUsingWorkspaceForApplyUnitary(cirq.SingleQubitGate):
@@ -158,6 +159,22 @@ def test_init2():
158159
assert gate.num_qubits() == 3
159160
assert cirq.qid_shape(gate) == (3, 3, 2)
160161

162+
gate = cirq.ControlledGate(cirq.Z, control_values=cv.FreeVars([0, (0, 1)]))
163+
assert gate.sub_gate is cirq.Z
164+
assert gate.num_controls() == 2
165+
assert gate.control_values == ((0,), (0, 1))
166+
assert gate.control_qid_shape == (2, 2)
167+
assert gate.num_qubits() == 3
168+
assert cirq.qid_shape(gate) == (2, 2, 2)
169+
170+
gate = cirq.ControlledGate(cirq.Z, control_values=cv.ConstrainedVars([(1, 0), (0, 1)]))
171+
assert gate.sub_gate is cirq.Z
172+
assert gate.num_controls() == 2
173+
assert gate.control_values == [((1, 0), (0, 1))]
174+
assert gate.control_qid_shape == (2, 2)
175+
assert gate.num_qubits() == 3
176+
assert cirq.qid_shape(gate) == (2, 2, 2)
177+
161178

162179
def test_validate_args():
163180
a = cirq.NamedQubit('a')

cirq-core/cirq/ops/controlled_operation.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -53,17 +53,11 @@ def __init__(
5353
control_values = ((1,),) * len(controls)
5454
if len(control_values) != len(controls):
5555
raise ValueError('len(control_values) != len(controls)')
56-
if not isinstance(control_values, cv.ControlValues):
57-
self.control_values = cv.ControlValues(
58-
# Convert to sorted tuples
59-
tuple(
60-
(val,) if isinstance(val, int) else tuple(sorted(val)) for val in control_values
61-
)
62-
)
63-
else:
64-
self.control_values = control_values
56+
57+
self.control_values = cv.to_control_values(control_values)
58+
6559
# Verify control values not out of bounds
66-
self.control_values.check_dimentionality(controls=tuple(controls))
60+
self.control_values.check_dimensionality(controls=tuple(controls))
6761

6862
if not isinstance(sub_operation, ControlledOperation):
6963
self.controls = tuple(controls)
@@ -88,6 +82,16 @@ def gate(self) -> Optional['cirq.ControlledGate']:
8882
def qubits(self):
8983
return self.controls + self.sub_operation.qubits
9084

85+
@property
86+
def control_values(self) -> cv.ControlValues:
87+
return self._control_values
88+
89+
@control_values.setter
90+
def control_values(
91+
self, values: Union[cv.ControlValues, Sequence[Union[int, Collection[int]]]]
92+
) -> None:
93+
self._control_values = cv.to_control_values(values)
94+
9195
def with_qubits(self, *new_qubits):
9296
n = len(self.controls)
9397
return ControlledOperation(

0 commit comments

Comments
 (0)