|
13 | 13 | # limitations under the License.
|
14 | 14 | from typing import Iterable, cast
|
15 | 15 |
|
| 16 | +import dataclasses |
16 | 17 | import numpy as np
|
17 | 18 | import pytest
|
18 | 19 | import sympy
|
@@ -53,29 +54,42 @@ def assert_optimizes(
|
53 | 54 | )
|
54 | 55 |
|
55 | 56 | # And match the expected circuit.
|
56 |
| - assert circuit == expected, ( |
57 |
| - "Circuit wasn't optimized as expected.\n" |
58 |
| - "INPUT:\n" |
59 |
| - "{}\n" |
60 |
| - "\n" |
61 |
| - "EXPECTED OUTPUT:\n" |
62 |
| - "{}\n" |
63 |
| - "\n" |
64 |
| - "ACTUAL OUTPUT:\n" |
65 |
| - "{}\n" |
66 |
| - "\n" |
67 |
| - "EXPECTED OUTPUT (detailed):\n" |
68 |
| - "{!r}\n" |
69 |
| - "\n" |
70 |
| - "ACTUAL OUTPUT (detailed):\n" |
71 |
| - "{!r}" |
72 |
| - ).format(before, expected, circuit, expected, circuit) |
| 57 | + cirq.testing.assert_same_circuits(circuit, expected) |
73 | 58 |
|
74 | 59 | # And it should be idempotent.
|
75 | 60 | circuit = cirq.eject_phased_paulis(
|
76 | 61 | circuit, eject_parameterized=eject_parameterized, context=context
|
77 | 62 | )
|
78 |
| - assert circuit == expected |
| 63 | + cirq.testing.assert_same_circuits(circuit, expected) |
| 64 | + |
| 65 | + # Nested sub-circuits should also get optimized. |
| 66 | + q = before.all_qubits() |
| 67 | + c_nested = cirq.Circuit( |
| 68 | + [cirq.PhasedXPowGate(phase_exponent=0.5).on_each(*q), (cirq.Z ** 0.5).on_each(*q)], |
| 69 | + cirq.CircuitOperation(before.freeze()).repeat(2).with_tags("ignore"), |
| 70 | + [cirq.Y.on_each(*q), cirq.X.on_each(*q)], |
| 71 | + cirq.CircuitOperation(before.freeze()).repeat(3).with_tags("preserve_tag"), |
| 72 | + ) |
| 73 | + c_expected = cirq.Circuit( |
| 74 | + cirq.PhasedXPowGate(phase_exponent=0.75).on_each(*q), |
| 75 | + cirq.Moment(cirq.CircuitOperation(before.freeze()).repeat(2).with_tags("ignore")), |
| 76 | + cirq.Z.on_each(*q), |
| 77 | + cirq.Moment(cirq.CircuitOperation(expected.freeze()).repeat(3).with_tags("preserve_tag")), |
| 78 | + ) |
| 79 | + if context is None: |
| 80 | + context = cirq.TransformerContext(tags_to_ignore=("ignore",), deep=True) |
| 81 | + else: |
| 82 | + context = dataclasses.replace( |
| 83 | + context, tags_to_ignore=context.tags_to_ignore + ("ignore",), deep=True |
| 84 | + ) |
| 85 | + c_nested = cirq.eject_phased_paulis( |
| 86 | + c_nested, context=context, eject_parameterized=eject_parameterized |
| 87 | + ) |
| 88 | + cirq.testing.assert_same_circuits(c_nested, c_expected) |
| 89 | + c_nested = cirq.eject_phased_paulis( |
| 90 | + c_nested, context=context, eject_parameterized=eject_parameterized |
| 91 | + ) |
| 92 | + cirq.testing.assert_same_circuits(c_nested, c_expected) |
79 | 93 |
|
80 | 94 |
|
81 | 95 | def quick_circuit(*moments: Iterable[cirq.OP_TREE]) -> cirq.Circuit:
|
|
0 commit comments