Skip to content

Add Quantum Fourier Transform #2135

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Sep 24, 2019
3 changes: 3 additions & 0 deletions cirq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,9 +209,12 @@
phase_damp,
phase_flip,
PhaseDampingChannel,
PhaseGradientGate,
PhasedXPowGate,
PhaseFlipChannel,
QFT,
Qid,
QuantumFourierTransformGate,
QubitOrder,
QubitOrderOrList,
reset,
Expand Down
6 changes: 5 additions & 1 deletion cirq/circuits/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -1777,7 +1777,11 @@ def _draw_moment_in_diagram(
if exponent is not None:
if info.connected:
# Add an exponent to the last label only.
out_diagram.write(x, y2, '^' + exponent)
if info.exponent_qubit_index is not None:
y3 = qubit_map[op.qubits[info.exponent_qubit_index]]
else:
y3 = y2
out_diagram.write(x, y3, '^' + exponent)
else:
# Add an exponent to every label
for index in indices:
Expand Down
6 changes: 6 additions & 0 deletions cirq/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,12 @@
from cirq.ops.eigen_gate import (
EigenGate,)

from cirq.ops.fourier_transform import (
PhaseGradientGate,
QFT,
QuantumFourierTransformGate,
)

from cirq.ops.fsim_gate import (
FSimGate,)

Expand Down
7 changes: 5 additions & 2 deletions cirq/ops/controlled_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,5 +193,8 @@ def get_symbol(vals):

wire_symbols = (*(get_symbol(vals) for vals in self.control_values),
*sub_info.wire_symbols)
return protocols.CircuitDiagramInfo(wire_symbols=wire_symbols,
exponent=sub_info.exponent)
return protocols.CircuitDiagramInfo(
wire_symbols=wire_symbols,
exponent=sub_info.exponent,
exponent_qubit_index=None if sub_info.exponent_qubit_index is None
else sub_info.exponent_qubit_index + 1)
195 changes: 195 additions & 0 deletions cirq/ops/fourier_transform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
# Copyright 2019 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Union

import numpy as np
import sympy

import cirq
from cirq import value, _compat
from cirq.ops import raw_types


@value.value_equality
class QuantumFourierTransformGate(raw_types.Gate):
"""Switches from the computational basis to the frequency basis."""

def __init__(self, num_qubits: int, *, without_reverse: bool = False):
"""
Args:
num_qubits: The number of qubits the gate applies to.
without_reverse: Whether or not to include the swaps at the end
of the circuit decomposition that reverse the order of the
qubits. These are technically necessary in order to perform the
correct effect, but can almost always be optimized away by just
performing later operations on different qubits.
"""
self._num_qubits = num_qubits
self._without_reverse = without_reverse

def _json_dict_(self):
return {
'cirq_type': self.__class__.__name__,
'num_qubits': self._num_qubits,
'without_reverse': self._without_reverse
}

def _value_equality_values_(self):
return self._num_qubits, self._without_reverse

def num_qubits(self) -> int:
return self._num_qubits

def _decompose_(self, qubits):
if len(qubits) == 0:
return
yield cirq.H(qubits[0])
for i in range(1, len(qubits)):
yield PhaseGradientGate(
num_qubits=i,
exponent=0.5).on(*qubits[:i][::-1]).controlled_by(qubits[i])
yield cirq.H(qubits[i])
if not self._without_reverse:
for i in range(len(qubits) // 2):
yield cirq.SWAP(qubits[i], qubits[-i - 1])

def _has_unitary_(self):
return True

def __str__(self):
return 'QFT[norev]' if self._without_reverse else 'QFT'

def __repr__(self):
return ('cirq.QuantumFourierTransformGate(num_qubits={!r}, '
'without_reverse={!r})'.format(self._num_qubits,
self._without_reverse))

def _circuit_diagram_info_(self, args: 'cirq.CircuitDiagramInfoArgs'):
return cirq.CircuitDiagramInfo(
wire_symbols=(str(self),) +
tuple(f'#{k+1}' for k in range(1, self._num_qubits)),
exponent_qubit_index=0)


@value.value_equality
class PhaseGradientGate(raw_types.Gate):
"""Phases each state |k⟩ out of n by e^(2*pi*i*k/n*exponent).
"""

def __init__(self, *, num_qubits: int, exponent: Union[float, sympy.Basic]):
self._num_qubits = num_qubits
self.exponent = exponent

def _json_dict_(self):
return {
'cirq_type': self.__class__.__name__,
'num_qubits': self._num_qubits,
'exponent': self.exponent
}

def _value_equality_values_(self):
return self._num_qubits, self.exponent

def num_qubits(self) -> int:
return self._num_qubits

def _decompose_(self, qubits):
for i, q in enumerate(qubits):
yield cirq.Z(q)**(self.exponent / 2**i)

def _apply_unitary_(self, args: 'cirq.ApplyUnitaryArgs'):
if isinstance(self.exponent, sympy.Basic):
return NotImplemented

n = int(np.product([args.target_tensor.shape[k] for k in args.axes]))
for i in range(n):
p = 1j**(4 * i / n * self.exponent)
args.target_tensor[args.subspace_index(big_endian_bits_int=i)] *= p

return args.target_tensor

def __pow__(self, power):
new_exponent = cirq.mul(self.exponent, power, NotImplemented)
if new_exponent is NotImplemented:
# coverage: ignore
return NotImplemented
return PhaseGradientGate(num_qubits=self._num_qubits,
exponent=new_exponent)

def _unitary_(self):
if isinstance(self.exponent, sympy.Basic):
return NotImplemented

size = 1 << self._num_qubits
return np.diag(
[1j**(4 * i / size * self.exponent) for i in range(size)])

def _has_unitary_(self):
return not isinstance(self.exponent, sympy.Basic)

def _is_parameterized_(self):
return isinstance(self.exponent, sympy.Basic)

def _resolve_parameters_(self, resolver):
new_exponent = cirq.resolve_parameters(self.exponent, resolver)
if new_exponent is self.exponent:
return self
return PhaseGradientGate(num_qubits=self._num_qubits,
exponent=new_exponent)

def __str__(self):
return f'Grad[{self._num_qubits}]' + (f'^{self.exponent}'
if self.exponent != 1 else '')

def __repr__(self):
return 'cirq.PhaseGradientGate(num_qubits={!r}, exponent={})'.format(
self._num_qubits, _compat.proper_repr(self.exponent))

def _circuit_diagram_info_(self, args: 'cirq.CircuitDiagramInfoArgs'):
return cirq.CircuitDiagramInfo(
wire_symbols=('Grad',) +
tuple(f'#{k+1}' for k in range(1, self._num_qubits)),
exponent=self.exponent,
exponent_qubit_index=0)


def QFT(*qubits: 'cirq.Qid',
without_reverse: bool = False,
inverse: bool = False) -> 'cirq.Operation':
"""The quantum Fourier transform.

Transforms a qubit register from the computational basis to the frequency
basis.

The inverse quantum Fourier transform is `cirq.QFT(*qubits)**-1` or
equivalently `cirq.inverse(cirq.QFT(*qubits))`.

Args:
qubits: The qubits to apply the QFT to.
without_reverse: When set, swap gates at the end of the QFT are omitted.
This reverses the qubit order relative to the standard QFT effect,
but makes the gate cheaper to apply.
inverse: If set, the inverse QFT is performed instead of the QFT.
Equivalent to calling `cirq.inverse` on the result, or raising it
to the -1.

Returns:
A `cirq.Operation` applying the QFT to the given qubits.
"""
result = QuantumFourierTransformGate(
len(qubits), without_reverse=without_reverse).on(*qubits)
if inverse:
result = cirq.inverse(result)
return result
147 changes: 147 additions & 0 deletions cirq/ops/fourier_transform_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
# Copyright 2019 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import sympy

import cirq


def test_phase_gradient():
np.testing.assert_allclose(
cirq.unitary(cirq.PhaseGradientGate(num_qubits=2, exponent=1)),
np.diag([1, 1j, -1, -1j]))

for k in range(4):
cirq.testing.assert_implements_consistent_protocols(
cirq.PhaseGradientGate(num_qubits=k, exponent=1))


def test_phase_gradient_symbolic():
a = cirq.PhaseGradientGate(num_qubits=2, exponent=0.5)
b = cirq.PhaseGradientGate(num_qubits=2, exponent=sympy.Symbol('t'))
assert not cirq.is_parameterized(a)
assert cirq.is_parameterized(b)
assert cirq.has_unitary(a)
assert not cirq.has_unitary(b)
assert cirq.resolve_parameters(a, {'t': 0.25}) is a
assert cirq.resolve_parameters(b, {'t': 0.5}) == a
assert cirq.resolve_parameters(b, {'t': 0.25}) == cirq.PhaseGradientGate(
num_qubits=2, exponent=0.25)


def test_str():
assert str(cirq.PhaseGradientGate(num_qubits=2,
exponent=0.5)) == 'Grad[2]^0.5'
assert str(cirq.PhaseGradientGate(num_qubits=2, exponent=1)) == 'Grad[2]'


def test_pow():
a = cirq.PhaseGradientGate(num_qubits=2, exponent=0.5)
assert a**0.5 == cirq.PhaseGradientGate(num_qubits=2, exponent=0.25)
assert a**sympy.Symbol('t') == cirq.PhaseGradientGate(num_qubits=2,
exponent=0.5 *
sympy.Symbol('t'))


def test_qft():
np.testing.assert_allclose(cirq.unitary(cirq.QFT(*cirq.LineQubit.range(2))),
np.array([
[1, 1, 1, 1],
[1, 1j, -1, -1j],
[1, -1, 1, -1],
[1, -1j, -1, 1j],
]) / 2,
atol=1e-8)

np.testing.assert_allclose(cirq.unitary(
cirq.QFT(*cirq.LineQubit.range(2), without_reverse=True)),
np.array([
[1, 1, 1, 1],
[1, -1, 1, -1],
[1, 1j, -1, -1j],
[1, -1j, -1, 1j],
]) / 2,
atol=1e-8)

np.testing.assert_allclose(
cirq.unitary(cirq.QFT(*cirq.LineQubit.range(4))),
np.array([[np.exp(2j * np.pi * i * j / 16)
for i in range(16)]
for j in range(16)]) / 4,
atol=1e-8)

np.testing.assert_allclose(cirq.unitary(
cirq.QFT(*cirq.LineQubit.range(2))**-1),
np.array([
[1, 1, 1, 1],
[1, -1j, -1, 1j],
[1, -1, 1, -1],
[1, 1j, -1, -1j],
]) / 2,
atol=1e-8)

for k in range(4):
for b in [False, True]:
cirq.testing.assert_implements_consistent_protocols(
cirq.QuantumFourierTransformGate(num_qubits=k,
without_reverse=b))


def test_inverse():
a, b, c = cirq.LineQubit.range(3)
assert cirq.QFT(a, b, c, inverse=True) == cirq.QFT(a, b, c)**-1
assert cirq.QFT(a, b, c, inverse=True,
without_reverse=True) == cirq.inverse(
cirq.QFT(a, b, c, without_reverse=True))


def test_circuit_diagram():
cirq.testing.assert_has_diagram(
cirq.Circuit.from_ops(
cirq.decompose_once(cirq.QFT(*cirq.LineQubit.range(4)))), """
0: ───H───Grad^0.5───────#2─────────────#3─────────────×───
│ │ │ │
1: ───────@──────────H───Grad^0.5───────#2─────────×───┼───
│ │ │ │
2: ──────────────────────@──────────H───Grad^0.5───×───┼───
│ │
3: ─────────────────────────────────────@──────────H───×───
""")

cirq.testing.assert_has_diagram(
cirq.Circuit.from_ops(
cirq.decompose_once(
cirq.QFT(*cirq.LineQubit.range(4), without_reverse=True))), """
0: ───H───Grad^0.5───────#2─────────────#3─────────────
│ │ │
1: ───────@──────────H───Grad^0.5───────#2─────────────
│ │
2: ──────────────────────@──────────H───Grad^0.5───────
3: ─────────────────────────────────────@──────────H───
""")

cirq.testing.assert_has_diagram(
cirq.Circuit.from_ops(cirq.QFT(*cirq.LineQubit.range(4)),
cirq.inverse(cirq.QFT(*cirq.LineQubit.range(4)))),
"""
0: ───QFT───QFT^-1───
│ │
1: ───#2────#2───────
│ │
2: ───#3────#3───────
│ │
3: ───#4────#4───────
""")
Loading