@@ -125,6 +125,7 @@ def test_specialized_control(input_gate, specialized_output):
125
125
assert input_gate .controlled () == specialized_output
126
126
assert input_gate .controlled (num_controls = 1 ) == specialized_output
127
127
assert input_gate .controlled (control_values = ((1 ,),)) == specialized_output
128
+ assert input_gate .controlled (control_values = cirq .SumOfProducts ([[1 ]])) == specialized_output
128
129
assert input_gate .controlled (control_qid_shape = (2 ,)) == specialized_output
129
130
assert np .allclose (
130
131
cirq .unitary (specialized_output ),
@@ -166,6 +167,28 @@ def test_specialized_control(input_gate, specialized_output):
166
167
)
167
168
168
169
170
+ @pytest .mark .parametrize (
171
+ 'input_gate, specialized_output' ,
172
+ [
173
+ (cirq .Z , cirq .CCZ ),
174
+ (cirq .X , cirq .CCX ),
175
+ (cirq .ZPowGate (exponent = 0.5 ), cirq .CCZPowGate (exponent = 0.5 )),
176
+ (cirq .XPowGate (exponent = 0.5 ), cirq .CCXPowGate (exponent = 0.5 )),
177
+ ],
178
+ )
179
+ def test_specialized_control_two_step (input_gate , specialized_output ):
180
+ # Two-qubit control on the input gate gives the specialized output
181
+ assert input_gate .controlled ().controlled () == specialized_output
182
+ assert input_gate .controlled (num_controls = 2 ) == specialized_output
183
+ assert input_gate .controlled (control_values = [1 , 1 ]) == specialized_output
184
+ assert input_gate .controlled (control_values = cirq .SumOfProducts ([[1 , 1 ]])) == specialized_output
185
+ assert input_gate .controlled (control_qid_shape = (2 , 2 )) == specialized_output
186
+ assert np .allclose (
187
+ cirq .unitary (specialized_output ),
188
+ cirq .unitary (cirq .ControlledGate (input_gate , num_controls = 2 )),
189
+ )
190
+
191
+
169
192
@pytest .mark .parametrize (
170
193
'gate, specialized_type' ,
171
194
[
0 commit comments