Skip to content

Commit 8f6af09

Browse files
daxfohlCirqBot
andauthored
Simplify controlled gate for SumOfProducts (quantumlib#5873)
Co-authored-by: Cirq Bot <[email protected]>
1 parent ae149fb commit 8f6af09

File tree

3 files changed

+31
-6
lines changed

3 files changed

+31
-6
lines changed

cirq/ops/common_gates_test.py

+23
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ def test_specialized_control(input_gate, specialized_output):
125125
assert input_gate.controlled() == specialized_output
126126
assert input_gate.controlled(num_controls=1) == specialized_output
127127
assert input_gate.controlled(control_values=((1,),)) == specialized_output
128+
assert input_gate.controlled(control_values=cirq.SumOfProducts([[1]])) == specialized_output
128129
assert input_gate.controlled(control_qid_shape=(2,)) == specialized_output
129130
assert np.allclose(
130131
cirq.unitary(specialized_output),
@@ -166,6 +167,28 @@ def test_specialized_control(input_gate, specialized_output):
166167
)
167168

168169

170+
@pytest.mark.parametrize(
171+
'input_gate, specialized_output',
172+
[
173+
(cirq.Z, cirq.CCZ),
174+
(cirq.X, cirq.CCX),
175+
(cirq.ZPowGate(exponent=0.5), cirq.CCZPowGate(exponent=0.5)),
176+
(cirq.XPowGate(exponent=0.5), cirq.CCXPowGate(exponent=0.5)),
177+
],
178+
)
179+
def test_specialized_control_two_step(input_gate, specialized_output):
180+
# Two-qubit control on the input gate gives the specialized output
181+
assert input_gate.controlled().controlled() == specialized_output
182+
assert input_gate.controlled(num_controls=2) == specialized_output
183+
assert input_gate.controlled(control_values=[1, 1]) == specialized_output
184+
assert input_gate.controlled(control_values=cirq.SumOfProducts([[1, 1]])) == specialized_output
185+
assert input_gate.controlled(control_qid_shape=(2, 2)) == specialized_output
186+
assert np.allclose(
187+
cirq.unitary(specialized_output),
188+
cirq.unitary(cirq.ControlledGate(input_gate, num_controls=2)),
189+
)
190+
191+
169192
@pytest.mark.parametrize(
170193
'gate, specialized_type',
171194
[

cirq/ops/controlled_gate.py

+5
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,11 @@ def __init__(
8080
bounds, or if the sub_gate is not a unitary or mixture.
8181
"""
8282
_validate_sub_object(sub_gate)
83+
84+
# Simplify a single SumOfProducts
85+
if isinstance(control_values, cv.SumOfProducts) and len(control_values._conjunctions) == 1:
86+
control_values = control_values._conjunctions[0]
87+
8388
if num_controls is None:
8489
if control_values is not None:
8590
num_controls = (

cirq/transformers/measurement_transformers.py

+3-6
Original file line numberDiff line numberDiff line change
@@ -121,12 +121,9 @@ def defer(op: 'cirq.Operation', _) -> 'cirq.OP_TREE':
121121
if c.key not in measurement_qubits:
122122
raise ValueError(f'Deferred measurement for key={c.key} not found.')
123123
qs = measurement_qubits[c.key]
124-
if len(qs) == 1:
125-
control_values: Any = [range(1, qs[0].dimension)]
126-
else:
127-
all_values = itertools.product(*[range(q.dimension) for q in qs])
128-
anything_but_all_zeros = tuple(itertools.islice(all_values, 1, None))
129-
control_values = ops.SumOfProducts(anything_but_all_zeros)
124+
all_values = itertools.product(*[range(q.dimension) for q in qs])
125+
anything_but_all_zeros = tuple(itertools.islice(all_values, 1, None))
126+
control_values = ops.SumOfProducts(anything_but_all_zeros)
130127
new_op = new_op.controlled_by(*qs, control_values=control_values)
131128
else:
132129
raise ValueError('Only KeyConditions are allowed.')

0 commit comments

Comments
 (0)