Skip to content

Add Quantum Fourier Transform #2135

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Sep 24, 2019
3 changes: 3 additions & 0 deletions cirq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,9 +206,12 @@
phase_damp,
phase_flip,
PhaseDampingChannel,
PhaseGradientGate,
PhasedXPowGate,
PhaseFlipChannel,
QFT,
Qid,
QuantumFourierTransformGate,
QubitOrder,
QubitOrderOrList,
reset,
Expand Down
6 changes: 5 additions & 1 deletion cirq/circuits/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -1781,7 +1781,11 @@ def _draw_moment_in_diagram(
if exponent is not None:
if info.connected:
# Add an exponent to the last label only.
out_diagram.write(x, y2, '^' + exponent)
if info.exponent_qubit_index is not None:
y3 = qubit_map[op.qubits[info.exponent_qubit_index]]
else:
y3 = y2
out_diagram.write(x, y3, '^' + exponent)
else:
# Add an exponent to every label
for index in indices:
Expand Down
6 changes: 6 additions & 0 deletions cirq/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,12 @@
from cirq.ops.eigen_gate import (
EigenGate,)

from cirq.ops.fourier_transform import (
QuantumFourierTransformGate,
PhaseGradientGate,
QFT,
)

from cirq.ops.fsim_gate import (
FSimGate,)

Expand Down
7 changes: 5 additions & 2 deletions cirq/ops/controlled_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,5 +193,8 @@ def get_symbol(vals):

wire_symbols = (*(get_symbol(vals) for vals in self.control_values),
*sub_info.wire_symbols)
return protocols.CircuitDiagramInfo(wire_symbols=wire_symbols,
exponent=sub_info.exponent)
return protocols.CircuitDiagramInfo(
wire_symbols=wire_symbols,
exponent=sub_info.exponent,
exponent_qubit_index=None if sub_info.exponent_qubit_index is None
else sub_info.exponent_qubit_index + 1)
187 changes: 187 additions & 0 deletions cirq/ops/fourier_transform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
# Copyright 2018 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Union

import numpy as np
import sympy

import cirq
from cirq import value, _compat
from cirq.ops import raw_types


@value.value_equality
class QuantumFourierTransformGate(raw_types.Gate):
"""Switches from the computational basis to the frequency basis."""

def __init__(self, num_qubits: int, *, without_reverse: bool = False):
"""
Args:
num_qubits: The number of qubits the gate applies to.
without_reverse: Whether or not to include the swaps at the end
of the circuit decomposition that reverse the order of the
qubits. These are technically necessary in order to perform the
correct effect, but can almost always be optimized away by just
performing later operations on different qubits.
"""
self._num_qubits = num_qubits
self._without_reverse = without_reverse

def _json_dict_(self):
return {
'cirq_type': self.__class__.__name__,
'num_qubits': self._num_qubits,
'without_reverse': self._without_reverse
}

def _value_equality_values_(self):
return self._num_qubits, self._without_reverse

def num_qubits(self) -> int:
return self._num_qubits

def _decompose_(self, qubits):
if len(qubits) == 0:
return
yield cirq.H(qubits[0])
for i in range(1, len(qubits)):
yield PhaseGradientGate(
num_qubits=i,
exponent=0.5).on(*qubits[:i][::-1]).controlled_by(qubits[i])
yield cirq.H(qubits[i])
if not self._without_reverse:
for i in range(len(qubits) // 2):
yield cirq.SWAP(qubits[i], qubits[-i - 1])

def _has_unitary_(self):
return True

def __str__(self):
return 'QFT[norev]' if self._without_reverse else 'QFT'

def __repr__(self):
return ('cirq.QuantumFourierTransformGate(num_qubits={!r}, '
'without_reverse=False)'.format(self._num_qubits,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How will self._without_reverse .format if you hardcode it to False?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Whooooops. Added the true case to the test battery inputs and fixed this.

self._without_reverse))

def _circuit_diagram_info_(self, args: 'cirq.CircuitDiagramInfoArgs'):
return cirq.CircuitDiagramInfo(
wire_symbols=(str(self),) +
tuple(f'#{k+1}' for k in range(1, self._num_qubits)),
exponent_qubit_index=0)


@value.value_equality
class PhaseGradientGate(raw_types.Gate):
"""Phases each computational basis state |k⟩ out of n by e^(i*k/n*exponent).
"""

def __init__(self, *, num_qubits: int, exponent: Union[float, sympy.Basic]):
self._num_qubits = num_qubits
self.exponent = exponent

def _json_dict_(self):
return {
'cirq_type': self.__class__.__name__,
'num_qubits': self._num_qubits,
'exponent': self.exponent
}

def _value_equality_values_(self):
return self._num_qubits, self.exponent

def num_qubits(self) -> int:
return self._num_qubits

def _decompose_(self, qubits):
for i, q in enumerate(qubits):
yield cirq.Z(q)**(self.exponent / 2**i)

def _apply_unitary_(self, args: 'cirq.ApplyUnitaryArgs'):
if isinstance(self.exponent, sympy.Basic):
return NotImplemented

n = int(np.product([args.target_tensor.shape[k] for k in args.axes]))
for i in range(n):
p = 1j**(4 * i / n * self.exponent)
args.target_tensor[args.subspace_index(big_endian_bits_int=i)] *= p

return args.target_tensor

def __pow__(self, power):
new_exponent = cirq.mul(self.exponent, power, NotImplemented)
if new_exponent is NotImplemented:
# coverage: ignore
return NotImplemented
return PhaseGradientGate(num_qubits=self._num_qubits,
exponent=new_exponent)

def _unitary_(self):
if isinstance(self.exponent, sympy.Basic):
return NotImplemented

size = 1 << self._num_qubits
return np.diag(
[1j**(4 * i / size * self.exponent) for i in range(size)])

def _has_unitary_(self):
return not isinstance(self.exponent, sympy.Basic)

def _is_parameterized_(self):
return isinstance(self.exponent, sympy.Basic)

def _resolve_parameters_(self, resolver):
new_exponent = cirq.resolve_parameters(self.exponent, resolver)
if new_exponent is self.exponent:
return self
return PhaseGradientGate(num_qubits=self._num_qubits,
exponent=new_exponent)

def __str__(self):
return f'Grad[{self._num_qubits}]' + (f'^{self.exponent}'
if self.exponent != 1 else '')

def __repr__(self):
return 'cirq.PhaseGradientGate(num_qubits={!r}, exponent={})'.format(
self._num_qubits, _compat.proper_repr(self.exponent))

def _circuit_diagram_info_(self, args: 'cirq.CircuitDiagramInfoArgs'):
return cirq.CircuitDiagramInfo(
wire_symbols=('Grad',) +
tuple(f'#{k+1}' for k in range(1, self._num_qubits)),
exponent=self.exponent,
exponent_qubit_index=0)


def QFT(*qubits: 'cirq.Qid', without_reverse: bool = False) -> 'cirq.Operation':
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should this be lowercase?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was trying to make it appear the same as e.g. cirq.H and cirq.SWAP. A user learning the API won't be aware that it's a function instead of a global constant (and isn't a function technically a global constant anyways?).

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mpharrigan Any other issue?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm

"""The quantum Fourier transform.

Transforms a qubit register from the computational basis to the frequency
basis.

The inverse quantum Fourier transform is `cirq.QFT(*qubits)**-1` or
equivalently `cirq.inverse(cirq.QFT(*qubits))`.

Args:
qubits: The qubits to apply the QFT to.
without_reverse: When set, swap gates at the end of the QFT are omitted.
This reverses the qubit order relative to the standard QFT effect,
but makes the gate cheaper to apply.

Returns:
A `cirq.Operation` applying the QFT to the given qubits.
"""
return QuantumFourierTransformGate(
len(qubits), without_reverse=without_reverse).on(*qubits)
137 changes: 137 additions & 0 deletions cirq/ops/fourier_transform_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
# Copyright 2018 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import sympy

import cirq


def test_phase_gradient():
np.testing.assert_allclose(
cirq.unitary(cirq.PhaseGradientGate(num_qubits=2, exponent=1)),
np.diag([1, 1j, -1, -1j]))

for k in range(4):
cirq.testing.assert_implements_consistent_protocols(
cirq.PhaseGradientGate(num_qubits=k, exponent=1))


def test_phase_gradient_symbolic():
a = cirq.PhaseGradientGate(num_qubits=2, exponent=0.5)
b = cirq.PhaseGradientGate(num_qubits=2, exponent=sympy.Symbol('t'))
assert not cirq.is_parameterized(a)
assert cirq.is_parameterized(b)
assert cirq.has_unitary(a)
assert not cirq.has_unitary(b)
assert cirq.resolve_parameters(a, {'t': 0.25}) is a
assert cirq.resolve_parameters(b, {'t': 0.5}) == a
assert cirq.resolve_parameters(b, {'t': 0.25}) == cirq.PhaseGradientGate(
num_qubits=2, exponent=0.25)


def test_str():
assert str(cirq.PhaseGradientGate(num_qubits=2,
exponent=0.5)) == 'Grad[2]^0.5'
assert str(cirq.PhaseGradientGate(num_qubits=2, exponent=1)) == 'Grad[2]'


def test_pow():
a = cirq.PhaseGradientGate(num_qubits=2, exponent=0.5)
assert a**0.5 == cirq.PhaseGradientGate(num_qubits=2, exponent=0.25)
assert a**sympy.Symbol('t') == cirq.PhaseGradientGate(num_qubits=2,
exponent=0.5 *
sympy.Symbol('t'))


def test_qft():
np.testing.assert_allclose(cirq.unitary(cirq.QFT(*cirq.LineQubit.range(2))),
np.array([
[1, 1, 1, 1],
[1, 1j, -1, -1j],
[1, -1, 1, -1],
[1, -1j, -1, 1j],
]) / 2,
atol=1e-8)

np.testing.assert_allclose(cirq.unitary(
cirq.QFT(*cirq.LineQubit.range(2), without_reverse=True)),
np.array([
[1, 1, 1, 1],
[1, -1, 1, -1],
[1, 1j, -1, -1j],
[1, -1j, -1, 1j],
]) / 2,
atol=1e-8)

np.testing.assert_allclose(
cirq.unitary(cirq.QFT(*cirq.LineQubit.range(4))),
np.array([[np.exp(2j * np.pi * i * j / 16)
for i in range(16)]
for j in range(16)]) / 4,
atol=1e-8)

np.testing.assert_allclose(cirq.unitary(
cirq.QFT(*cirq.LineQubit.range(2))**-1),
np.array([
[1, 1, 1, 1],
[1, -1j, -1, 1j],
[1, -1, 1, -1],
[1, 1j, -1, -1j],
]) / 2,
atol=1e-8)

for k in range(4):
cirq.testing.assert_implements_consistent_protocols(
cirq.QuantumFourierTransformGate(num_qubits=k))


def test_circuit_diagram():
cirq.testing.assert_has_diagram(
cirq.Circuit.from_ops(
cirq.decompose_once(cirq.QFT(*cirq.LineQubit.range(4)))), """
0: ───H───Grad^0.5───────#2─────────────#3─────────────×───
│ │ │ │
1: ───────@──────────H───Grad^0.5───────#2─────────×───┼───
│ │ │ │
2: ──────────────────────@──────────H───Grad^0.5───×───┼───
│ │
3: ─────────────────────────────────────@──────────H───×───
""")

cirq.testing.assert_has_diagram(
cirq.Circuit.from_ops(
cirq.decompose_once(
cirq.QFT(*cirq.LineQubit.range(4), without_reverse=True))), """
0: ───H───Grad^0.5───────#2─────────────#3─────────────
│ │ │
1: ───────@──────────H───Grad^0.5───────#2─────────────
│ │
2: ──────────────────────@──────────H───Grad^0.5───────
3: ─────────────────────────────────────@──────────H───
""")

cirq.testing.assert_has_diagram(
cirq.Circuit.from_ops(cirq.QFT(*cirq.LineQubit.range(4)),
cirq.inverse(cirq.QFT(*cirq.LineQubit.range(4)))),
"""
0: ───QFT───QFT^-1───
│ │
1: ───#2────#2───────
│ │
2: ───#3────#3───────
│ │
3: ───#4────#4───────
""")
9 changes: 9 additions & 0 deletions cirq/ops/raw_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from cirq import value, protocols

if TYPE_CHECKING:
import cirq
from cirq.ops import gate_operation, linear_combinations


Expand Down Expand Up @@ -416,6 +417,14 @@ def _decompose_(self, qubits):
def _value_equality_values_(self):
return self._original

def _circuit_diagram_info_(self, args: 'cirq.CircuitDiagramInfoArgs'):
sub_info = protocols.circuit_diagram_info(self._original,
default=NotImplemented)
if sub_info is NotImplemented:
return NotImplemented
sub_info.exponent *= -1
return sub_info

def __repr__(self):
return '({!r}**-1)'.format(self._original)

Expand Down
Loading