Skip to content

Commit aa04196

Browse files
authored
Create a generalized uniform superposition state gate (#6506)
* Create generalized_uniform_superposition_gate.py Creates a generalized uniform superposition state, $\frac{1}{\sqrt{M}} \sum_{j=0}^{M-1} \ket{j} $ (where 1< M <= 2^n), using n qubits, according to the Shukla-Vedula algorithm [SV24]. Note: The Shukla-Vedula algorithm [SV24] offers an efficient approach for creation of a generalized uniform superposition state of the form, $\frac{1}{\sqrt{M}} \sum_{j=0}^{M-1} \ket{j} $, requiring only $O(log_2 (M))$ qubits and $O(log_2 (M))$ gates. This provides an exponential improvement (in the context of reduced resources and complexity) over other approaches in the literature. Reference: [SV24] A. Shukla and P. Vedula, “An efficient quantum algorithm for preparation of uniform quantum superposition states,” Quantum Information Processing, 23(38): pp. 1-32 (2024).
1 parent df07e94 commit aa04196

7 files changed

+227
-0
lines changed

cirq-core/cirq/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,7 @@
332332
ZPowGate,
333333
ZZ,
334334
ZZPowGate,
335+
UniformSuperpositionGate,
335336
)
336337

337338
from cirq.transformers import (

cirq-core/cirq/json_resolver_cache.py

+1
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,7 @@ def _symmetricalqidpair(qids):
247247
'ZipLongest': cirq.ZipLongest,
248248
'ZPowGate': cirq.ZPowGate,
249249
'ZZPowGate': cirq.ZZPowGate,
250+
'UniformSuperpositionGate': cirq.UniformSuperpositionGate,
250251
# Old types, only supported for backwards-compatibility
251252
'BooleanHamiltonian': _boolean_hamiltonian_gate_op, # Removed in v0.15
252253
'CrossEntropyResult': _cross_entropy_result, # Removed in v0.16

cirq-core/cirq/ops/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -217,3 +217,5 @@
217217
from cirq.ops.state_preparation_channel import StatePreparationChannel
218218

219219
from cirq.ops.control_values import AbstractControlValues, ProductOfSums, SumOfProducts
220+
221+
from cirq.ops.uniform_superposition_gate import UniformSuperpositionGate
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
# Copyright 2024 The Cirq Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import Sequence, Any, Dict, TYPE_CHECKING
16+
17+
import numpy as np
18+
from cirq.ops.common_gates import H, ry
19+
from cirq.ops.pauli_gates import X
20+
from cirq.ops import raw_types
21+
22+
23+
if TYPE_CHECKING:
24+
import cirq
25+
26+
27+
class UniformSuperpositionGate(raw_types.Gate):
28+
r"""Creates a uniform superposition state on the states $[0, M)$
29+
The gate creates the state $\frac{1}{\sqrt{M}}\sum_{j=0}^{M-1}\ket{j}$
30+
(where $1\leq M \leq 2^n$), using n qubits, according to the Shukla-Vedula algorithm [SV24].
31+
References:
32+
[SV24]
33+
[An efficient quantum algorithm for preparation of uniform quantum superposition
34+
states](https://arxiv.org/abs/2306.11747)
35+
"""
36+
37+
def __init__(self, m_value: int, num_qubits: int) -> None:
38+
"""Initializes UniformSuperpositionGate.
39+
40+
Args:
41+
m_value: The number of computational basis states.
42+
num_qubits: The number of qubits used.
43+
44+
Raises:
45+
ValueError: If `m_value` is not a positive integer, or
46+
if `num_qubits` is not an integer greater than or equal to log2(m_value).
47+
"""
48+
if not (isinstance(m_value, int) and (m_value > 0)):
49+
raise ValueError("m_value must be a positive integer.")
50+
log_two_m_value = m_value.bit_length()
51+
52+
if (m_value & (m_value - 1)) == 0:
53+
log_two_m_value = log_two_m_value - 1
54+
if not (isinstance(num_qubits, int) and (num_qubits >= log_two_m_value)):
55+
raise ValueError(
56+
"num_qubits must be an integer greater than or equal to log2(m_value)."
57+
)
58+
self._m_value = m_value
59+
self._num_qubits = num_qubits
60+
61+
def _decompose_(self, qubits: Sequence["cirq.Qid"]) -> "cirq.OP_TREE":
62+
"""Decomposes the gate into a sequence of standard gates.
63+
Implements the construction from https://arxiv.org/pdf/2306.11747.
64+
"""
65+
qreg = list(qubits)
66+
qreg.reverse()
67+
68+
if self._m_value == 1: # if m_value is 1, do nothing
69+
return
70+
if (self._m_value & (self._m_value - 1)) == 0: # if m_value is an integer power of 2
71+
m = self._m_value.bit_length() - 1
72+
yield H.on_each(qreg[:m])
73+
return
74+
k = self._m_value.bit_length()
75+
l_value = []
76+
for i in range(self._m_value.bit_length()):
77+
if (self._m_value >> i) & 1:
78+
l_value.append(i) # Locations of '1's
79+
80+
yield X.on_each(qreg[q_bit] for q_bit in l_value[1:k])
81+
m_current = 2 ** (l_value[0])
82+
theta = -2 * np.arccos(np.sqrt(m_current / self._m_value))
83+
if l_value[0] > 0: # if m_value is even
84+
yield H.on_each(qreg[: l_value[0]])
85+
86+
yield ry(theta).on(qreg[l_value[1]])
87+
88+
for i in range(l_value[0], l_value[1]):
89+
yield H(qreg[i]).controlled_by(qreg[l_value[1]], control_values=[False])
90+
91+
for m in range(1, len(l_value) - 1):
92+
theta = -2 * np.arccos(np.sqrt(2 ** l_value[m] / (self._m_value - m_current)))
93+
yield ry(theta).on(qreg[l_value[m + 1]]).controlled_by(
94+
qreg[l_value[m]], control_values=[0]
95+
)
96+
for i in range(l_value[m], l_value[m + 1]):
97+
yield H.on(qreg[i]).controlled_by(qreg[l_value[m + 1]], control_values=[0])
98+
99+
m_current = m_current + 2 ** (l_value[m])
100+
101+
def num_qubits(self) -> int:
102+
return self._num_qubits
103+
104+
@property
105+
def m_value(self) -> int:
106+
return self._m_value
107+
108+
def __eq__(self, other):
109+
if isinstance(other, UniformSuperpositionGate):
110+
return (self._m_value == other._m_value) and (self._num_qubits == other._num_qubits)
111+
return False
112+
113+
def __repr__(self) -> str:
114+
return f'UniformSuperpositionGate(m_value={self._m_value}, num_qubits={self._num_qubits})'
115+
116+
def _json_dict_(self) -> Dict[str, Any]:
117+
d = {}
118+
d['m_value'] = self._m_value
119+
d['num_qubits'] = self._num_qubits
120+
return d
121+
122+
def __str__(self) -> str:
123+
return f'UniformSuperpositionGate(m_value={self._m_value}, num_qubits={self._num_qubits})'
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
# Copyright 2024 The Cirq Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import numpy as np
16+
import pytest
17+
import cirq
18+
19+
20+
@pytest.mark.parametrize(
21+
["m", "n"],
22+
[[int(m), n] for n in range(3, 7) for m in np.random.randint(1, 1 << n, size=3)]
23+
+ [(1, 2), (4, 2), (6, 3), (7, 3)],
24+
)
25+
def test_generated_unitary_is_uniform(m: int, n: int) -> None:
26+
r"""The code checks that the unitary matrix corresponds to the generated uniform superposition
27+
states (see uniform_superposition_gate.py). It is enough to check that the
28+
first colum of the unitary matrix (which corresponds to the action of the gate on
29+
$\ket{0}^n$ is $\frac{1}{\sqrt{M}} [1 1 \cdots 1 0 \cdots 0]^T$, where the first $M$
30+
entries are all "1"s (excluding the normalization factor of $\frac{1}{\sqrt{M}}$ and the
31+
remaining $2^n-M$ entries are all "0"s.
32+
"""
33+
gate = cirq.UniformSuperpositionGate(m, n)
34+
matrix = np.array(cirq.unitary(gate))
35+
np.testing.assert_allclose(
36+
matrix[:, 0], (1 / np.sqrt(m)) * np.array([1] * m + [0] * (2**n - m)), atol=1e-8
37+
)
38+
39+
40+
@pytest.mark.parametrize(["m", "n"], [(1, 1), (-2, 1), (-3.1, 2), (6, -4), (5, 6.1)])
41+
def test_incompatible_m_value_and_qubit_args(m: int, n: int) -> None:
42+
r"""The code checks that test errors are raised if the arguments m (number of
43+
superposition states and n (number of qubits) are positive integers and are compatible
44+
(i.e., n >= log2(m)).
45+
"""
46+
47+
if not (isinstance(m, int)):
48+
with pytest.raises(ValueError, match="m_value must be a positive integer."):
49+
cirq.UniformSuperpositionGate(m, n)
50+
elif not (isinstance(n, int)):
51+
with pytest.raises(
52+
ValueError,
53+
match="num_qubits must be an integer greater than or equal to log2\\(m_value\\).",
54+
):
55+
cirq.UniformSuperpositionGate(m, n)
56+
elif m < 1:
57+
with pytest.raises(ValueError, match="m_value must be a positive integer."):
58+
cirq.UniformSuperpositionGate(int(m), int(n))
59+
elif n < np.log2(m):
60+
with pytest.raises(
61+
ValueError,
62+
match="num_qubits must be an integer greater than or equal to log2\\(m_value\\).",
63+
):
64+
cirq.UniformSuperpositionGate(m, n)
65+
66+
67+
def test_repr():
68+
assert (
69+
repr(cirq.UniformSuperpositionGate(7, 3))
70+
== 'UniformSuperpositionGate(m_value=7, num_qubits=3)'
71+
)
72+
73+
74+
def test_uniform_superposition_gate_json_dict():
75+
assert cirq.UniformSuperpositionGate(7, 3)._json_dict_() == {'m_value': 7, 'num_qubits': 3}
76+
77+
78+
def test_str():
79+
assert (
80+
str(cirq.UniformSuperpositionGate(7, 3))
81+
== 'UniformSuperpositionGate(m_value=7, num_qubits=3)'
82+
)
83+
84+
85+
@pytest.mark.parametrize(["m", "n"], [(5, 3), (10, 4)])
86+
def test_eq(m: int, n: int) -> None:
87+
a = cirq.UniformSuperpositionGate(m, n)
88+
b = cirq.UniformSuperpositionGate(m, n)
89+
c = cirq.UniformSuperpositionGate(m + 1, n)
90+
d = cirq.X
91+
assert a.m_value == b.m_value
92+
assert a.__eq__(b)
93+
assert not (a.__eq__(c))
94+
assert not (a.__eq__(d))
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
{
2+
"cirq_type": "UniformSuperpositionGate",
3+
"m_value": 7,
4+
"num_qubits": 3
5+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
cirq.UniformSuperpositionGate(m_value=7, num_qubits=3)

0 commit comments

Comments
 (0)