Skip to content

Commit df92f59

Browse files
Add support for allocating qubits in decompose to cirq.unitary (quantumlib#6112)
* Add support for allocating qubits in decompose to cirq.unitary * fixed apply_unitaries * fix mypy * refactored tests * addressing comments * added sample_gates_test.py * Improved sample_gates.py implementation and unitary_protocol tests. Also added docstrings * fixed lint * retrigger checks --------- Co-authored-by: Tanuj Khattar <[email protected]>
1 parent 8694c04 commit df92f59

7 files changed

+303
-22
lines changed

cirq/protocols/apply_unitary_protocol.py

Lines changed: 60 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,33 @@ def default(
133133
state = qis.one_hot(index=(0,) * num_qubits, shape=qid_shape, dtype=np.complex128)
134134
return ApplyUnitaryArgs(state, np.empty_like(state), range(num_qubits))
135135

136+
@classmethod
137+
def for_unitary(
138+
cls, num_qubits: Optional[int] = None, *, qid_shape: Optional[Tuple[int, ...]] = None
139+
) -> 'ApplyUnitaryArgs':
140+
"""A default instance corresponding to an identity matrix.
141+
142+
Specify exactly one argument.
143+
144+
Args:
145+
num_qubits: The number of qubits to make space for in the state.
146+
qid_shape: A tuple representing the number of quantum levels of each
147+
qubit the identity matrix applies to. `qid_shape` is (2, 2, 2) for
148+
a three-qubit identity operation tensor.
149+
150+
Raises:
151+
TypeError: If exactly neither `num_qubits` or `qid_shape` is provided or
152+
both are provided.
153+
"""
154+
if (num_qubits is None) == (qid_shape is None):
155+
raise TypeError('Specify exactly one of num_qubits or qid_shape.')
156+
if num_qubits is not None:
157+
qid_shape = (2,) * num_qubits
158+
qid_shape = cast(Tuple[int, ...], qid_shape) # Satisfy mypy
159+
num_qubits = len(qid_shape)
160+
state = qis.eye_tensor(qid_shape, dtype=np.complex128)
161+
return ApplyUnitaryArgs(state, np.empty_like(state), range(num_qubits))
162+
136163
def with_axes_transposed_to_start(self) -> 'ApplyUnitaryArgs':
137164
"""Returns a transposed view of the same arguments.
138165
@@ -409,19 +436,7 @@ def _strat_apply_unitary_from_apply_unitary(
409436
return _incorporate_result_into_target(args, sub_args, sub_result)
410437

411438

412-
def _strat_apply_unitary_from_unitary(
413-
unitary_value: Any, args: ApplyUnitaryArgs
414-
) -> Optional[np.ndarray]:
415-
# Check for magic method.
416-
method = getattr(unitary_value, '_unitary_', None)
417-
if method is None:
418-
return NotImplemented
419-
420-
# Attempt to get the unitary matrix.
421-
matrix = method()
422-
if matrix is NotImplemented or matrix is None:
423-
return matrix
424-
439+
def _apply_unitary_from_matrix(matrix: np.ndarray, unitary_value: Any, args: ApplyUnitaryArgs):
425440
if args.slices is None:
426441
val_qid_shape = qid_shape_protocol.qid_shape(unitary_value, default=(2,) * len(args.axes))
427442
slices = tuple(slice(0, size) for size in val_qid_shape)
@@ -450,11 +465,42 @@ def _strat_apply_unitary_from_unitary(
450465
return _incorporate_result_into_target(args, sub_args, sub_result)
451466

452467

468+
def _strat_apply_unitary_from_unitary(
469+
unitary_value: Any, args: ApplyUnitaryArgs
470+
) -> Optional[np.ndarray]:
471+
# Check for magic method.
472+
method = getattr(unitary_value, '_unitary_', None)
473+
if method is None:
474+
return NotImplemented
475+
476+
# Attempt to get the unitary matrix.
477+
matrix = method()
478+
if matrix is NotImplemented or matrix is None:
479+
return matrix
480+
481+
return _apply_unitary_from_matrix(matrix, unitary_value, args)
482+
483+
453484
def _strat_apply_unitary_from_decompose(val: Any, args: ApplyUnitaryArgs) -> Optional[np.ndarray]:
454485
operations, qubits, _ = _try_decompose_into_operations_and_qubits(val)
455486
if operations is None:
456487
return NotImplemented
457-
return apply_unitaries(operations, qubits, args, None)
488+
all_qubits = frozenset([q for op in operations for q in op.qubits])
489+
ancilla = tuple(sorted(all_qubits.difference(qubits)))
490+
if not len(ancilla):
491+
return apply_unitaries(operations, qubits, args, None)
492+
ordered_qubits = ancilla + tuple(qubits)
493+
all_qid_shapes = qid_shape_protocol.qid_shape(ordered_qubits)
494+
result = apply_unitaries(
495+
operations, ordered_qubits, ApplyUnitaryArgs.for_unitary(qid_shape=all_qid_shapes), None
496+
)
497+
if result is None or result is NotImplemented:
498+
return result
499+
result = result.reshape((np.prod(all_qid_shapes, dtype=np.int64), -1))
500+
val_qid_shape = qid_shape_protocol.qid_shape(qubits)
501+
state_vec_length = np.prod(val_qid_shape, dtype=np.int64)
502+
result = result[:state_vec_length, :state_vec_length]
503+
return _apply_unitary_from_matrix(result, val, args)
458504

459505

460506
def apply_unitaries(

cirq/protocols/apply_unitary_protocol_test.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -717,3 +717,53 @@ def test_cast_to_complex():
717717
np.ComplexWarning, match='Casting complex values to real discards the imaginary part'
718718
):
719719
cirq.apply_unitary(y0, args)
720+
721+
722+
class NotDecomposableGate(cirq.Gate):
723+
def num_qubits(self):
724+
return 1
725+
726+
727+
class DecomposableGate(cirq.Gate):
728+
def __init__(self, sub_gate: cirq.Gate, allocate_ancilla: bool) -> None:
729+
super().__init__()
730+
self._sub_gate = sub_gate
731+
self._allocate_ancilla = allocate_ancilla
732+
733+
def num_qubits(self):
734+
return 1
735+
736+
def _decompose_(self, qubits):
737+
if self._allocate_ancilla:
738+
yield cirq.Z(cirq.NamedQubit('DecomposableGateQubit'))
739+
yield self._sub_gate(qubits[0])
740+
741+
742+
def test_strat_apply_unitary_from_decompose():
743+
state = np.eye(2, dtype=np.complex128)
744+
args = cirq.ApplyUnitaryArgs(
745+
target_tensor=state, available_buffer=np.zeros_like(state), axes=(0,)
746+
)
747+
np.testing.assert_allclose(
748+
cirq.apply_unitaries(
749+
[DecomposableGate(cirq.X, False)(cirq.LineQubit(0))], [cirq.LineQubit(0)], args
750+
),
751+
[[0, 1], [1, 0]],
752+
)
753+
754+
with pytest.raises(TypeError):
755+
_ = cirq.apply_unitaries(
756+
[DecomposableGate(NotDecomposableGate(), True)(cirq.LineQubit(0))],
757+
[cirq.LineQubit(0)],
758+
args,
759+
)
760+
761+
762+
def test_unitary_construction():
763+
with pytest.raises(TypeError):
764+
_ = cirq.ApplyUnitaryArgs.for_unitary()
765+
766+
np.testing.assert_allclose(
767+
cirq.ApplyUnitaryArgs.for_unitary(num_qubits=3).target_tensor,
768+
cirq.eye_tensor((2,) * 3, dtype=np.complex128),
769+
)

cirq/protocols/unitary_protocol.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
import numpy as np
1818
from typing_extensions import Protocol
1919

20-
from cirq import qis
2120
from cirq._doc import doc_private
2221
from cirq.protocols import qid_shape_protocol
2322
from cirq.protocols.apply_unitary_protocol import ApplyUnitaryArgs, apply_unitaries
@@ -162,9 +161,7 @@ def _strat_unitary_from_apply_unitary(val: Any) -> Optional[np.ndarray]:
162161
return NotImplemented
163162

164163
# Apply unitary effect to an identity matrix.
165-
state = qis.eye_tensor(val_qid_shape, dtype=np.complex128)
166-
buffer = np.empty_like(state)
167-
result = method(ApplyUnitaryArgs(state, buffer, range(len(val_qid_shape))))
164+
result = method(ApplyUnitaryArgs.for_unitary(qid_shape=val_qid_shape))
168165

169166
if result is NotImplemented or result is None:
170167
return result
@@ -179,15 +176,26 @@ def _strat_unitary_from_decompose(val: Any) -> Optional[np.ndarray]:
179176
if operations is None:
180177
return NotImplemented
181178

179+
all_qubits = frozenset(q for op in operations for q in op.qubits)
180+
work_qubits = frozenset(qubits)
181+
ancillas = tuple(sorted(all_qubits.difference(work_qubits)))
182+
183+
ordered_qubits = ancillas + tuple(qubits)
184+
val_qid_shape = qid_shape_protocol.qid_shape(ancillas) + val_qid_shape
185+
182186
# Apply sub-operations' unitary effects to an identity matrix.
183-
state = qis.eye_tensor(val_qid_shape, dtype=np.complex128)
184-
buffer = np.empty_like(state)
185187
result = apply_unitaries(
186-
operations, qubits, ApplyUnitaryArgs(state, buffer, range(len(val_qid_shape))), None
188+
operations, ordered_qubits, ApplyUnitaryArgs.for_unitary(qid_shape=val_qid_shape), None
187189
)
188190

189191
# Package result.
190192
if result is None:
191193
return None
194+
192195
state_len = np.prod(val_qid_shape, dtype=np.int64)
193-
return result.reshape((state_len, state_len))
196+
result = result.reshape((state_len, state_len))
197+
# Assuming borrowable qubits are restored to their original state and
198+
# clean qubits restord to the zero state then the desired unitary is
199+
# the upper left square.
200+
work_state_len = np.prod(val_qid_shape[len(ancillas) :], dtype=np.int64)
201+
return result[:work_state_len, :work_state_len]

cirq/protocols/unitary_protocol_test.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import pytest
1818

1919
import cirq
20+
from cirq import testing
2021

2122
m0: np.ndarray = np.array([])
2223
# yapf: disable
@@ -188,6 +189,42 @@ def test_has_unitary():
188189
assert not cirq.has_unitary(FullyImplemented(False))
189190

190191

192+
def _test_gate_that_allocates_qubits(gate):
193+
from cirq.protocols.unitary_protocol import _strat_unitary_from_decompose
194+
195+
op = gate.on(*cirq.LineQubit.range(cirq.num_qubits(gate)))
196+
moment = cirq.Moment(op)
197+
circuit = cirq.FrozenCircuit(op)
198+
circuit_op = cirq.CircuitOperation(circuit)
199+
for val in [gate, op, moment, circuit, circuit_op]:
200+
unitary_from_strat = _strat_unitary_from_decompose(val)
201+
assert unitary_from_strat is not None
202+
np.testing.assert_allclose(unitary_from_strat, gate.narrow_unitary())
203+
204+
205+
@pytest.mark.parametrize('theta', np.linspace(0, 2 * np.pi, 10))
206+
@pytest.mark.parametrize('phase_state', [0, 1])
207+
@pytest.mark.parametrize('target_bitsize', [1, 2, 3])
208+
@pytest.mark.parametrize('ancilla_bitsize', [1, 4])
209+
def test_decompose_gate_that_allocates_clean_qubits(
210+
theta: float, phase_state: int, target_bitsize: int, ancilla_bitsize: int
211+
):
212+
213+
gate = testing.PhaseUsingCleanAncilla(theta, phase_state, target_bitsize, ancilla_bitsize)
214+
_test_gate_that_allocates_qubits(gate)
215+
216+
217+
@pytest.mark.parametrize('phase_state', [0, 1])
218+
@pytest.mark.parametrize('target_bitsize', [1, 2, 3])
219+
@pytest.mark.parametrize('ancilla_bitsize', [1, 4])
220+
def test_decompose_gate_that_allocates_dirty_qubits(
221+
phase_state: int, target_bitsize: int, ancilla_bitsize: int
222+
):
223+
224+
gate = testing.PhaseUsingDirtyAncilla(phase_state, target_bitsize, ancilla_bitsize)
225+
_test_gate_that_allocates_qubits(gate)
226+
227+
191228
def test_decompose_and_get_unitary():
192229
from cirq.protocols.unitary_protocol import _strat_unitary_from_decompose
193230

cirq/testing/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,3 +107,5 @@
107107
)
108108

109109
from cirq.testing.sample_circuits import nonoptimal_toffoli_circuit
110+
111+
from cirq.testing.sample_gates import PhaseUsingCleanAncilla, PhaseUsingDirtyAncilla

cirq/testing/sample_gates.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
# Copyright 2023 The Cirq Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import dataclasses
15+
16+
import cirq
17+
import numpy as np
18+
from cirq import ops, qis
19+
20+
21+
def _matrix_for_phasing_state(num_qubits, phase_state, phase):
22+
matrix = qis.eye_tensor((2,) * num_qubits, dtype=np.complex128)
23+
matrix = matrix.reshape((2**num_qubits, 2**num_qubits))
24+
matrix[phase_state, phase_state] = phase
25+
print(num_qubits, phase_state, phase)
26+
print(matrix)
27+
return matrix
28+
29+
30+
@dataclasses.dataclass(frozen=True)
31+
class PhaseUsingCleanAncilla(ops.Gate):
32+
r"""Phases the state $|phase_state>$ by $\exp(1j * \pi * \theta)$ using one clean ancilla."""
33+
34+
theta: float
35+
phase_state: int = 1
36+
target_bitsize: int = 1
37+
ancilla_bitsize: int = 1
38+
39+
def _num_qubits_(self):
40+
return self.target_bitsize
41+
42+
def _decompose_(self, qubits):
43+
anc = ops.NamedQubit.range(self.ancilla_bitsize, prefix="anc")
44+
cv = [int(x) for x in f'{self.phase_state:0{self.target_bitsize}b}']
45+
cnot_ladder = [cirq.CNOT(anc[i - 1], anc[i]) for i in range(1, self.ancilla_bitsize)]
46+
47+
yield ops.X(anc[0]).controlled_by(*qubits, control_values=cv)
48+
yield [cnot_ladder, ops.Z(anc[-1]) ** self.theta, reversed(cnot_ladder)]
49+
yield ops.X(anc[0]).controlled_by(*qubits, control_values=cv)
50+
51+
def narrow_unitary(self) -> np.ndarray:
52+
"""Narrowed unitary corresponding to the unitary effect applied on target qubits."""
53+
phase = np.exp(1j * np.pi * self.theta)
54+
return _matrix_for_phasing_state(self.target_bitsize, self.phase_state, phase)
55+
56+
57+
@dataclasses.dataclass(frozen=True)
58+
class PhaseUsingDirtyAncilla(ops.Gate):
59+
r"""Phases the state $|phase_state>$ by -1 using one dirty ancilla."""
60+
61+
phase_state: int = 1
62+
target_bitsize: int = 1
63+
ancilla_bitsize: int = 1
64+
65+
def _num_qubits_(self):
66+
return self.target_bitsize
67+
68+
def _decompose_(self, qubits):
69+
anc = ops.NamedQubit.range(self.ancilla_bitsize, prefix="anc")
70+
cv = [int(x) for x in f'{self.phase_state:0{self.target_bitsize}b}']
71+
cnot_ladder = [cirq.CNOT(anc[i - 1], anc[i]) for i in range(1, self.ancilla_bitsize)]
72+
yield ops.X(anc[0]).controlled_by(*qubits, control_values=cv)
73+
yield [cnot_ladder, ops.Z(anc[-1]), reversed(cnot_ladder)]
74+
yield ops.X(anc[0]).controlled_by(*qubits, control_values=cv)
75+
yield [cnot_ladder, ops.Z(anc[-1]), reversed(cnot_ladder)]
76+
77+
def narrow_unitary(self) -> np.ndarray:
78+
"""Narrowed unitary corresponding to the unitary effect applied on target qubits."""
79+
return _matrix_for_phasing_state(self.target_bitsize, self.phase_state, -1)

cirq/testing/sample_gates_test.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
# Copyright 2023 The Cirq Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import pytest
15+
16+
import numpy as np
17+
from cirq.testing import sample_gates
18+
import cirq
19+
20+
21+
@pytest.mark.parametrize('theta', np.linspace(0, 2 * np.pi, 20))
22+
def test_phase_using_clean_ancilla(theta: float):
23+
g = sample_gates.PhaseUsingCleanAncilla(theta)
24+
q = cirq.LineQubit(0)
25+
qubit_order = cirq.QubitOrder.explicit([q], fallback=cirq.QubitOrder.DEFAULT)
26+
decomposed_unitary = cirq.Circuit(cirq.decompose_once(g.on(q))).unitary(qubit_order=qubit_order)
27+
phase = np.exp(1j * np.pi * theta)
28+
np.testing.assert_allclose(g.narrow_unitary(), np.array([[1, 0], [0, phase]]))
29+
np.testing.assert_allclose(
30+
decomposed_unitary,
31+
# fmt: off
32+
np.array(
33+
[
34+
[1 , 0 , 0 , 0],
35+
[0 , phase, 0 , 0],
36+
[0 , 0 , phase, 0],
37+
[0 , 0 , 0 , 1],
38+
]
39+
),
40+
# fmt: on
41+
)
42+
43+
44+
@pytest.mark.parametrize(
45+
'target_bitsize, phase_state', [(1, 0), (1, 1), (2, 0), (2, 1), (2, 2), (2, 3)]
46+
)
47+
@pytest.mark.parametrize('ancilla_bitsize', [1, 4])
48+
def test_phase_using_dirty_ancilla(target_bitsize, phase_state, ancilla_bitsize):
49+
g = sample_gates.PhaseUsingDirtyAncilla(phase_state, target_bitsize, ancilla_bitsize)
50+
q = cirq.LineQubit.range(target_bitsize)
51+
qubit_order = cirq.QubitOrder.explicit(q, fallback=cirq.QubitOrder.DEFAULT)
52+
decomposed_circuit = cirq.Circuit(cirq.decompose_once(g.on(*q)))
53+
decomposed_unitary = decomposed_circuit.unitary(qubit_order=qubit_order)
54+
phase_matrix = np.eye(2**target_bitsize)
55+
phase_matrix[phase_state, phase_state] = -1
56+
np.testing.assert_allclose(g.narrow_unitary(), phase_matrix)
57+
np.testing.assert_allclose(
58+
decomposed_unitary, np.kron(phase_matrix, np.eye(2**ancilla_bitsize)), atol=1e-5
59+
)

0 commit comments

Comments
 (0)