Skip to content

Commit 498200c

Browse files
StrilancCirqBot
authored andcommitted
Add Quantum Fourier Transform (#2135)
- Add PhaseGradientGate, QuantumFourierTransformGate, and QFT method - QFT does not define _apply_unitary_ or _unitary_ because the most efficient way to get them is actually via the decomposition - Add exponent_qubit_index to CircuitDiagramInfo - Update ControlledOperation _circuit_diagram_info_ to forward exponent_qubit_index - Add _circuit_diagram_info_ to _InverseCompositeGate - Update phase_estimator example to use built-in QFT - Update hhl example to use built-in QFT - Add `big_endian_bits` option to `ApplyUnitaryArgs.subspace_index` This one had snuck into a couple examples as redundant code. It's also one that I've seen people ask about.
1 parent 0f8518c commit 498200c

15 files changed

+485
-143
lines changed

cirq/__init__.py

+3
Original file line numberDiff line numberDiff line change
@@ -209,9 +209,12 @@
209209
phase_damp,
210210
phase_flip,
211211
PhaseDampingChannel,
212+
PhaseGradientGate,
212213
PhasedXPowGate,
213214
PhaseFlipChannel,
215+
QFT,
214216
Qid,
217+
QuantumFourierTransformGate,
215218
QubitOrder,
216219
QubitOrderOrList,
217220
reset,

cirq/circuits/circuit.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -1777,7 +1777,11 @@ def _draw_moment_in_diagram(
17771777
if exponent is not None:
17781778
if info.connected:
17791779
# Add an exponent to the last label only.
1780-
out_diagram.write(x, y2, '^' + exponent)
1780+
if info.exponent_qubit_index is not None:
1781+
y3 = qubit_map[op.qubits[info.exponent_qubit_index]]
1782+
else:
1783+
y3 = y2
1784+
out_diagram.write(x, y3, '^' + exponent)
17811785
else:
17821786
# Add an exponent to every label
17831787
for index in indices:

cirq/ops/__init__.py

+6
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,12 @@
8484
from cirq.ops.eigen_gate import (
8585
EigenGate,)
8686

87+
from cirq.ops.fourier_transform import (
88+
PhaseGradientGate,
89+
QFT,
90+
QuantumFourierTransformGate,
91+
)
92+
8793
from cirq.ops.fsim_gate import (
8894
FSimGate,)
8995

cirq/ops/controlled_operation.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -193,5 +193,8 @@ def get_symbol(vals):
193193

194194
wire_symbols = (*(get_symbol(vals) for vals in self.control_values),
195195
*sub_info.wire_symbols)
196-
return protocols.CircuitDiagramInfo(wire_symbols=wire_symbols,
197-
exponent=sub_info.exponent)
196+
return protocols.CircuitDiagramInfo(
197+
wire_symbols=wire_symbols,
198+
exponent=sub_info.exponent,
199+
exponent_qubit_index=None if sub_info.exponent_qubit_index is None
200+
else sub_info.exponent_qubit_index + 1)

cirq/ops/fourier_transform.py

+195
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,195 @@
1+
# Copyright 2019 The Cirq Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import Union
16+
17+
import numpy as np
18+
import sympy
19+
20+
import cirq
21+
from cirq import value, _compat
22+
from cirq.ops import raw_types
23+
24+
25+
@value.value_equality
26+
class QuantumFourierTransformGate(raw_types.Gate):
27+
"""Switches from the computational basis to the frequency basis."""
28+
29+
def __init__(self, num_qubits: int, *, without_reverse: bool = False):
30+
"""
31+
Args:
32+
num_qubits: The number of qubits the gate applies to.
33+
without_reverse: Whether or not to include the swaps at the end
34+
of the circuit decomposition that reverse the order of the
35+
qubits. These are technically necessary in order to perform the
36+
correct effect, but can almost always be optimized away by just
37+
performing later operations on different qubits.
38+
"""
39+
self._num_qubits = num_qubits
40+
self._without_reverse = without_reverse
41+
42+
def _json_dict_(self):
43+
return {
44+
'cirq_type': self.__class__.__name__,
45+
'num_qubits': self._num_qubits,
46+
'without_reverse': self._without_reverse
47+
}
48+
49+
def _value_equality_values_(self):
50+
return self._num_qubits, self._without_reverse
51+
52+
def num_qubits(self) -> int:
53+
return self._num_qubits
54+
55+
def _decompose_(self, qubits):
56+
if len(qubits) == 0:
57+
return
58+
yield cirq.H(qubits[0])
59+
for i in range(1, len(qubits)):
60+
yield PhaseGradientGate(
61+
num_qubits=i,
62+
exponent=0.5).on(*qubits[:i][::-1]).controlled_by(qubits[i])
63+
yield cirq.H(qubits[i])
64+
if not self._without_reverse:
65+
for i in range(len(qubits) // 2):
66+
yield cirq.SWAP(qubits[i], qubits[-i - 1])
67+
68+
def _has_unitary_(self):
69+
return True
70+
71+
def __str__(self):
72+
return 'QFT[norev]' if self._without_reverse else 'QFT'
73+
74+
def __repr__(self):
75+
return ('cirq.QuantumFourierTransformGate(num_qubits={!r}, '
76+
'without_reverse={!r})'.format(self._num_qubits,
77+
self._without_reverse))
78+
79+
def _circuit_diagram_info_(self, args: 'cirq.CircuitDiagramInfoArgs'):
80+
return cirq.CircuitDiagramInfo(
81+
wire_symbols=(str(self),) +
82+
tuple(f'#{k+1}' for k in range(1, self._num_qubits)),
83+
exponent_qubit_index=0)
84+
85+
86+
@value.value_equality
87+
class PhaseGradientGate(raw_types.Gate):
88+
"""Phases each state |k⟩ out of n by e^(2*pi*i*k/n*exponent).
89+
"""
90+
91+
def __init__(self, *, num_qubits: int, exponent: Union[float, sympy.Basic]):
92+
self._num_qubits = num_qubits
93+
self.exponent = exponent
94+
95+
def _json_dict_(self):
96+
return {
97+
'cirq_type': self.__class__.__name__,
98+
'num_qubits': self._num_qubits,
99+
'exponent': self.exponent
100+
}
101+
102+
def _value_equality_values_(self):
103+
return self._num_qubits, self.exponent
104+
105+
def num_qubits(self) -> int:
106+
return self._num_qubits
107+
108+
def _decompose_(self, qubits):
109+
for i, q in enumerate(qubits):
110+
yield cirq.Z(q)**(self.exponent / 2**i)
111+
112+
def _apply_unitary_(self, args: 'cirq.ApplyUnitaryArgs'):
113+
if isinstance(self.exponent, sympy.Basic):
114+
return NotImplemented
115+
116+
n = int(np.product([args.target_tensor.shape[k] for k in args.axes]))
117+
for i in range(n):
118+
p = 1j**(4 * i / n * self.exponent)
119+
args.target_tensor[args.subspace_index(big_endian_bits_int=i)] *= p
120+
121+
return args.target_tensor
122+
123+
def __pow__(self, power):
124+
new_exponent = cirq.mul(self.exponent, power, NotImplemented)
125+
if new_exponent is NotImplemented:
126+
# coverage: ignore
127+
return NotImplemented
128+
return PhaseGradientGate(num_qubits=self._num_qubits,
129+
exponent=new_exponent)
130+
131+
def _unitary_(self):
132+
if isinstance(self.exponent, sympy.Basic):
133+
return NotImplemented
134+
135+
size = 1 << self._num_qubits
136+
return np.diag(
137+
[1j**(4 * i / size * self.exponent) for i in range(size)])
138+
139+
def _has_unitary_(self):
140+
return not isinstance(self.exponent, sympy.Basic)
141+
142+
def _is_parameterized_(self):
143+
return isinstance(self.exponent, sympy.Basic)
144+
145+
def _resolve_parameters_(self, resolver):
146+
new_exponent = cirq.resolve_parameters(self.exponent, resolver)
147+
if new_exponent is self.exponent:
148+
return self
149+
return PhaseGradientGate(num_qubits=self._num_qubits,
150+
exponent=new_exponent)
151+
152+
def __str__(self):
153+
return f'Grad[{self._num_qubits}]' + (f'^{self.exponent}'
154+
if self.exponent != 1 else '')
155+
156+
def __repr__(self):
157+
return 'cirq.PhaseGradientGate(num_qubits={!r}, exponent={})'.format(
158+
self._num_qubits, _compat.proper_repr(self.exponent))
159+
160+
def _circuit_diagram_info_(self, args: 'cirq.CircuitDiagramInfoArgs'):
161+
return cirq.CircuitDiagramInfo(
162+
wire_symbols=('Grad',) +
163+
tuple(f'#{k+1}' for k in range(1, self._num_qubits)),
164+
exponent=self.exponent,
165+
exponent_qubit_index=0)
166+
167+
168+
def QFT(*qubits: 'cirq.Qid',
169+
without_reverse: bool = False,
170+
inverse: bool = False) -> 'cirq.Operation':
171+
"""The quantum Fourier transform.
172+
173+
Transforms a qubit register from the computational basis to the frequency
174+
basis.
175+
176+
The inverse quantum Fourier transform is `cirq.QFT(*qubits)**-1` or
177+
equivalently `cirq.inverse(cirq.QFT(*qubits))`.
178+
179+
Args:
180+
qubits: The qubits to apply the QFT to.
181+
without_reverse: When set, swap gates at the end of the QFT are omitted.
182+
This reverses the qubit order relative to the standard QFT effect,
183+
but makes the gate cheaper to apply.
184+
inverse: If set, the inverse QFT is performed instead of the QFT.
185+
Equivalent to calling `cirq.inverse` on the result, or raising it
186+
to the -1.
187+
188+
Returns:
189+
A `cirq.Operation` applying the QFT to the given qubits.
190+
"""
191+
result = QuantumFourierTransformGate(
192+
len(qubits), without_reverse=without_reverse).on(*qubits)
193+
if inverse:
194+
result = cirq.inverse(result)
195+
return result

cirq/ops/fourier_transform_test.py

+147
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
# Copyright 2019 The Cirq Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import numpy as np
16+
import sympy
17+
18+
import cirq
19+
20+
21+
def test_phase_gradient():
22+
np.testing.assert_allclose(
23+
cirq.unitary(cirq.PhaseGradientGate(num_qubits=2, exponent=1)),
24+
np.diag([1, 1j, -1, -1j]))
25+
26+
for k in range(4):
27+
cirq.testing.assert_implements_consistent_protocols(
28+
cirq.PhaseGradientGate(num_qubits=k, exponent=1))
29+
30+
31+
def test_phase_gradient_symbolic():
32+
a = cirq.PhaseGradientGate(num_qubits=2, exponent=0.5)
33+
b = cirq.PhaseGradientGate(num_qubits=2, exponent=sympy.Symbol('t'))
34+
assert not cirq.is_parameterized(a)
35+
assert cirq.is_parameterized(b)
36+
assert cirq.has_unitary(a)
37+
assert not cirq.has_unitary(b)
38+
assert cirq.resolve_parameters(a, {'t': 0.25}) is a
39+
assert cirq.resolve_parameters(b, {'t': 0.5}) == a
40+
assert cirq.resolve_parameters(b, {'t': 0.25}) == cirq.PhaseGradientGate(
41+
num_qubits=2, exponent=0.25)
42+
43+
44+
def test_str():
45+
assert str(cirq.PhaseGradientGate(num_qubits=2,
46+
exponent=0.5)) == 'Grad[2]^0.5'
47+
assert str(cirq.PhaseGradientGate(num_qubits=2, exponent=1)) == 'Grad[2]'
48+
49+
50+
def test_pow():
51+
a = cirq.PhaseGradientGate(num_qubits=2, exponent=0.5)
52+
assert a**0.5 == cirq.PhaseGradientGate(num_qubits=2, exponent=0.25)
53+
assert a**sympy.Symbol('t') == cirq.PhaseGradientGate(num_qubits=2,
54+
exponent=0.5 *
55+
sympy.Symbol('t'))
56+
57+
58+
def test_qft():
59+
np.testing.assert_allclose(cirq.unitary(cirq.QFT(*cirq.LineQubit.range(2))),
60+
np.array([
61+
[1, 1, 1, 1],
62+
[1, 1j, -1, -1j],
63+
[1, -1, 1, -1],
64+
[1, -1j, -1, 1j],
65+
]) / 2,
66+
atol=1e-8)
67+
68+
np.testing.assert_allclose(cirq.unitary(
69+
cirq.QFT(*cirq.LineQubit.range(2), without_reverse=True)),
70+
np.array([
71+
[1, 1, 1, 1],
72+
[1, -1, 1, -1],
73+
[1, 1j, -1, -1j],
74+
[1, -1j, -1, 1j],
75+
]) / 2,
76+
atol=1e-8)
77+
78+
np.testing.assert_allclose(
79+
cirq.unitary(cirq.QFT(*cirq.LineQubit.range(4))),
80+
np.array([[np.exp(2j * np.pi * i * j / 16)
81+
for i in range(16)]
82+
for j in range(16)]) / 4,
83+
atol=1e-8)
84+
85+
np.testing.assert_allclose(cirq.unitary(
86+
cirq.QFT(*cirq.LineQubit.range(2))**-1),
87+
np.array([
88+
[1, 1, 1, 1],
89+
[1, -1j, -1, 1j],
90+
[1, -1, 1, -1],
91+
[1, 1j, -1, -1j],
92+
]) / 2,
93+
atol=1e-8)
94+
95+
for k in range(4):
96+
for b in [False, True]:
97+
cirq.testing.assert_implements_consistent_protocols(
98+
cirq.QuantumFourierTransformGate(num_qubits=k,
99+
without_reverse=b))
100+
101+
102+
def test_inverse():
103+
a, b, c = cirq.LineQubit.range(3)
104+
assert cirq.QFT(a, b, c, inverse=True) == cirq.QFT(a, b, c)**-1
105+
assert cirq.QFT(a, b, c, inverse=True,
106+
without_reverse=True) == cirq.inverse(
107+
cirq.QFT(a, b, c, without_reverse=True))
108+
109+
110+
def test_circuit_diagram():
111+
cirq.testing.assert_has_diagram(
112+
cirq.Circuit.from_ops(
113+
cirq.decompose_once(cirq.QFT(*cirq.LineQubit.range(4)))), """
114+
0: ───H───Grad^0.5───────#2─────────────#3─────────────×───
115+
│ │ │ │
116+
1: ───────@──────────H───Grad^0.5───────#2─────────×───┼───
117+
│ │ │ │
118+
2: ──────────────────────@──────────H───Grad^0.5───×───┼───
119+
│ │
120+
3: ─────────────────────────────────────@──────────H───×───
121+
""")
122+
123+
cirq.testing.assert_has_diagram(
124+
cirq.Circuit.from_ops(
125+
cirq.decompose_once(
126+
cirq.QFT(*cirq.LineQubit.range(4), without_reverse=True))), """
127+
0: ───H───Grad^0.5───────#2─────────────#3─────────────
128+
│ │ │
129+
1: ───────@──────────H───Grad^0.5───────#2─────────────
130+
│ │
131+
2: ──────────────────────@──────────H───Grad^0.5───────
132+
133+
3: ─────────────────────────────────────@──────────H───
134+
""")
135+
136+
cirq.testing.assert_has_diagram(
137+
cirq.Circuit.from_ops(cirq.QFT(*cirq.LineQubit.range(4)),
138+
cirq.inverse(cirq.QFT(*cirq.LineQubit.range(4)))),
139+
"""
140+
0: ───QFT───QFT^-1───
141+
│ │
142+
1: ───#2────#2───────
143+
│ │
144+
2: ───#3────#3───────
145+
│ │
146+
3: ───#4────#4───────
147+
""")

0 commit comments

Comments
 (0)