Skip to content

Commit e9e12ee

Browse files
remove partial CZs if allow_partial_czs=False (#6436)
1 parent 9f07ce8 commit e9e12ee

File tree

2 files changed

+41
-0
lines changed

2 files changed

+41
-0
lines changed

cirq-core/cirq/transformers/analytical_decompositions/two_qubit_to_cz.py

+24
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,24 @@
3131
import cirq
3232

3333

34+
def _remove_partial_czs_or_fail(
35+
operations: Iterable['cirq.Operation'], atol: float
36+
) -> List['cirq.Operation']:
37+
result = []
38+
for op in operations:
39+
if isinstance(op.gate, ops.CZPowGate):
40+
t = op.gate.exponent % 2 # CZ^t is periodic with period 2.
41+
if t < atol:
42+
continue # Identity.
43+
elif abs(t - 1) < atol:
44+
result.append(ops.CZ(*op.qubits)) # Was either CZ or CZ**-1.
45+
else:
46+
raise ValueError(f'CZ^t is not allowed for t={t}')
47+
else:
48+
result.append(op)
49+
return result
50+
51+
3452
def two_qubit_matrix_to_cz_operations(
3553
q0: 'cirq.Qid',
3654
q1: 'cirq.Qid',
@@ -53,10 +71,16 @@ def two_qubit_matrix_to_cz_operations(
5371
5472
Returns:
5573
A list of operations implementing the matrix.
74+
75+
Raises:
76+
ValueError: If allow_partial_czs=False and the matrix requires partial CZs.
5677
"""
5778
kak = linalg.kak_decomposition(mat, atol=atol)
5879
operations = _kak_decomposition_to_operations(q0, q1, kak, allow_partial_czs, atol=atol)
5980
if clean_operations:
81+
if not allow_partial_czs:
82+
# CZ^t is not allowed for any $t$ except $t=1$.
83+
return _remove_partial_czs_or_fail(cleanup_operations(operations), atol=atol)
6084
return cleanup_operations(operations)
6185
return operations
6286

cirq-core/cirq/transformers/analytical_decompositions/two_qubit_to_cz_test.py

+17
Original file line numberDiff line numberDiff line change
@@ -257,3 +257,20 @@ def test_decompose_to_diagonal_and_circuit(v):
257257
combined_circuit = cirq.Circuit(cirq.MatrixGate(diagonal)(b, c), ops)
258258
circuit_unitary = combined_circuit.unitary(qubits_that_should_be_present=[b, c])
259259
cirq.testing.assert_allclose_up_to_global_phase(circuit_unitary, v, atol=2e-6)
260+
261+
262+
def test_remove_partial_czs_or_fail():
263+
CZ = cirq.CZ(*cirq.LineQubit.range(2))
264+
assert (
265+
cirq.transformers.analytical_decompositions.two_qubit_to_cz._remove_partial_czs_or_fail(
266+
[CZ**1e-15], atol=1e-9
267+
)
268+
== []
269+
)
270+
assert cirq.transformers.analytical_decompositions.two_qubit_to_cz._remove_partial_czs_or_fail(
271+
[CZ**-1, CZ], atol=1e-9
272+
) == [CZ, CZ]
273+
with pytest.raises(ValueError):
274+
_ = cirq.transformers.analytical_decompositions.two_qubit_to_cz._remove_partial_czs_or_fail(
275+
[CZ**-0.5], atol=1e-9
276+
)

0 commit comments

Comments
 (0)