|
12 | 12 | # See the License for the specific language governing permissions and
|
13 | 13 | # limitations under the License.
|
14 | 14 |
|
| 15 | +from typing import Union |
| 16 | + |
15 | 17 | import cirq
|
16 | 18 | from cirq.protocols.decompose_protocol import DecomposeResult
|
17 | 19 | from cirq.transformers.optimize_for_target_gateset import _decompose_operations_to_target_gateset
|
@@ -243,3 +245,151 @@ def test_optimize_for_target_gateset_deep():
|
243 | 245 | 1: ───#2───────────────────────────────────────────────────────────────────────────
|
244 | 246 | ''',
|
245 | 247 | )
|
| 248 | + |
| 249 | + |
| 250 | +@pytest.mark.parametrize('max_num_passes', [2, None]) |
| 251 | +def test_optimize_for_target_gateset_multiple_passes(max_num_passes: Union[int, None]): |
| 252 | + gateset = cirq.CZTargetGateset() |
| 253 | + |
| 254 | + input_circuit = cirq.Circuit( |
| 255 | + [ |
| 256 | + cirq.Moment( |
| 257 | + cirq.X(cirq.LineQubit(1)), |
| 258 | + cirq.X(cirq.LineQubit(2)), |
| 259 | + cirq.X(cirq.LineQubit(3)), |
| 260 | + cirq.X(cirq.LineQubit(6)), |
| 261 | + ), |
| 262 | + cirq.Moment( |
| 263 | + cirq.H(cirq.LineQubit(0)), |
| 264 | + cirq.H(cirq.LineQubit(1)), |
| 265 | + cirq.H(cirq.LineQubit(2)), |
| 266 | + cirq.H(cirq.LineQubit(3)), |
| 267 | + cirq.H(cirq.LineQubit(4)), |
| 268 | + cirq.H(cirq.LineQubit(5)), |
| 269 | + cirq.H(cirq.LineQubit(6)), |
| 270 | + ), |
| 271 | + cirq.Moment( |
| 272 | + cirq.H(cirq.LineQubit(1)), cirq.H(cirq.LineQubit(3)), cirq.H(cirq.LineQubit(5)) |
| 273 | + ), |
| 274 | + cirq.Moment( |
| 275 | + cirq.CZ(cirq.LineQubit(0), cirq.LineQubit(1)), |
| 276 | + cirq.CZ(cirq.LineQubit(2), cirq.LineQubit(3)), |
| 277 | + cirq.CZ(cirq.LineQubit(4), cirq.LineQubit(5)), |
| 278 | + ), |
| 279 | + cirq.Moment( |
| 280 | + cirq.CZ(cirq.LineQubit(2), cirq.LineQubit(1)), |
| 281 | + cirq.CZ(cirq.LineQubit(4), cirq.LineQubit(3)), |
| 282 | + cirq.CZ(cirq.LineQubit(6), cirq.LineQubit(5)), |
| 283 | + ), |
| 284 | + ] |
| 285 | + ) |
| 286 | + desired_circuit = cirq.Circuit.from_moments( |
| 287 | + cirq.Moment( |
| 288 | + cirq.PhasedXZGate(axis_phase_exponent=0.5, x_exponent=-0.5, z_exponent=1.0).on( |
| 289 | + cirq.LineQubit(4) |
| 290 | + ) |
| 291 | + ), |
| 292 | + cirq.Moment(cirq.CZ(cirq.LineQubit(4), cirq.LineQubit(5))), |
| 293 | + cirq.Moment( |
| 294 | + cirq.PhasedXZGate(axis_phase_exponent=-1.0, x_exponent=1, z_exponent=0).on( |
| 295 | + cirq.LineQubit(1) |
| 296 | + ), |
| 297 | + cirq.PhasedXZGate(axis_phase_exponent=0.5, x_exponent=-0.5, z_exponent=1.0).on( |
| 298 | + cirq.LineQubit(0) |
| 299 | + ), |
| 300 | + cirq.PhasedXZGate(axis_phase_exponent=-1.0, x_exponent=1, z_exponent=0).on( |
| 301 | + cirq.LineQubit(3) |
| 302 | + ), |
| 303 | + cirq.PhasedXZGate(axis_phase_exponent=-0.5, x_exponent=0.5, z_exponent=0.0).on( |
| 304 | + cirq.LineQubit(2) |
| 305 | + ), |
| 306 | + ), |
| 307 | + cirq.Moment( |
| 308 | + cirq.CZ(cirq.LineQubit(0), cirq.LineQubit(1)), |
| 309 | + cirq.CZ(cirq.LineQubit(2), cirq.LineQubit(3)), |
| 310 | + ), |
| 311 | + cirq.Moment( |
| 312 | + cirq.CZ(cirq.LineQubit(2), cirq.LineQubit(1)), |
| 313 | + cirq.CZ(cirq.LineQubit(4), cirq.LineQubit(3)), |
| 314 | + ), |
| 315 | + cirq.Moment( |
| 316 | + cirq.PhasedXZGate(axis_phase_exponent=-0.5, x_exponent=0.5, z_exponent=0.0).on( |
| 317 | + cirq.LineQubit(6) |
| 318 | + ) |
| 319 | + ), |
| 320 | + cirq.Moment(cirq.CZ(cirq.LineQubit(6), cirq.LineQubit(5))), |
| 321 | + ) |
| 322 | + got = cirq.optimize_for_target_gateset( |
| 323 | + input_circuit, gateset=gateset, max_num_passes=max_num_passes |
| 324 | + ) |
| 325 | + cirq.testing.assert_same_circuits(got, desired_circuit) |
| 326 | + |
| 327 | + |
| 328 | +@pytest.mark.parametrize('max_num_passes', [2, None]) |
| 329 | +def test_optimize_for_target_gateset_multiple_passes_dont_preserve_moment_structure( |
| 330 | + max_num_passes: Union[int, None] |
| 331 | +): |
| 332 | + gateset = cirq.CZTargetGateset(preserve_moment_structure=False) |
| 333 | + |
| 334 | + input_circuit = cirq.Circuit( |
| 335 | + [ |
| 336 | + cirq.Moment( |
| 337 | + cirq.X(cirq.LineQubit(1)), |
| 338 | + cirq.X(cirq.LineQubit(2)), |
| 339 | + cirq.X(cirq.LineQubit(3)), |
| 340 | + cirq.X(cirq.LineQubit(6)), |
| 341 | + ), |
| 342 | + cirq.Moment( |
| 343 | + cirq.H(cirq.LineQubit(0)), |
| 344 | + cirq.H(cirq.LineQubit(1)), |
| 345 | + cirq.H(cirq.LineQubit(2)), |
| 346 | + cirq.H(cirq.LineQubit(3)), |
| 347 | + cirq.H(cirq.LineQubit(4)), |
| 348 | + cirq.H(cirq.LineQubit(5)), |
| 349 | + cirq.H(cirq.LineQubit(6)), |
| 350 | + ), |
| 351 | + cirq.Moment( |
| 352 | + cirq.H(cirq.LineQubit(1)), cirq.H(cirq.LineQubit(3)), cirq.H(cirq.LineQubit(5)) |
| 353 | + ), |
| 354 | + cirq.Moment( |
| 355 | + cirq.CZ(cirq.LineQubit(0), cirq.LineQubit(1)), |
| 356 | + cirq.CZ(cirq.LineQubit(2), cirq.LineQubit(3)), |
| 357 | + cirq.CZ(cirq.LineQubit(4), cirq.LineQubit(5)), |
| 358 | + ), |
| 359 | + cirq.Moment( |
| 360 | + cirq.CZ(cirq.LineQubit(2), cirq.LineQubit(1)), |
| 361 | + cirq.CZ(cirq.LineQubit(4), cirq.LineQubit(3)), |
| 362 | + cirq.CZ(cirq.LineQubit(6), cirq.LineQubit(5)), |
| 363 | + ), |
| 364 | + ] |
| 365 | + ) |
| 366 | + desired_circuit = cirq.Circuit( |
| 367 | + cirq.PhasedXZGate(axis_phase_exponent=0.5, x_exponent=-0.5, z_exponent=1.0).on( |
| 368 | + cirq.LineQubit(4) |
| 369 | + ), |
| 370 | + cirq.PhasedXZGate(axis_phase_exponent=-1.0, x_exponent=1, z_exponent=0).on( |
| 371 | + cirq.LineQubit(1) |
| 372 | + ), |
| 373 | + cirq.PhasedXZGate(axis_phase_exponent=-0.5, x_exponent=0.5, z_exponent=0.0).on( |
| 374 | + cirq.LineQubit(2) |
| 375 | + ), |
| 376 | + cirq.PhasedXZGate(axis_phase_exponent=0.5, x_exponent=-0.5, z_exponent=1.0).on( |
| 377 | + cirq.LineQubit(0) |
| 378 | + ), |
| 379 | + cirq.PhasedXZGate(axis_phase_exponent=-1.0, x_exponent=1, z_exponent=0).on( |
| 380 | + cirq.LineQubit(3) |
| 381 | + ), |
| 382 | + cirq.PhasedXZGate(axis_phase_exponent=-0.5, x_exponent=0.5, z_exponent=0.0).on( |
| 383 | + cirq.LineQubit(6) |
| 384 | + ), |
| 385 | + cirq.CZ(cirq.LineQubit(4), cirq.LineQubit(5)), |
| 386 | + cirq.CZ(cirq.LineQubit(0), cirq.LineQubit(1)), |
| 387 | + cirq.CZ(cirq.LineQubit(2), cirq.LineQubit(3)), |
| 388 | + cirq.CZ(cirq.LineQubit(2), cirq.LineQubit(1)), |
| 389 | + cirq.CZ(cirq.LineQubit(4), cirq.LineQubit(3)), |
| 390 | + cirq.CZ(cirq.LineQubit(6), cirq.LineQubit(5)), |
| 391 | + ) |
| 392 | + got = cirq.optimize_for_target_gateset( |
| 393 | + input_circuit, gateset=gateset, max_num_passes=max_num_passes |
| 394 | + ) |
| 395 | + cirq.testing.assert_same_circuits(got, desired_circuit) |
0 commit comments