Skip to content

Commit 3080d93

Browse files
Add a new gauge for SqrtCZ and support SqrtCZ† and fix and improve spin inversion gauge (#6571)
1 parent 614c78a commit 3080d93

File tree

5 files changed

+104
-15
lines changed

5 files changed

+104
-15
lines changed

cirq-core/cirq/transformers/gauge_compiling/gauge_compiling.py

+37-1
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ class ConstantGauge(Gauge):
7171
post_q1: Tuple[ops.Gate, ...] = field(
7272
default=(), converter=lambda g: (g,) if isinstance(g, ops.Gate) else tuple(g)
7373
)
74+
swap_qubits: bool = False
7475

7576
def sample(self, gate: ops.Gate, prng: np.random.Generator) -> "ConstantGauge":
7677
return self
@@ -85,6 +86,41 @@ def post(self) -> Tuple[Tuple[ops.Gate, ...], Tuple[ops.Gate, ...]]:
8586
"""A tuple (ops to apply to q0, ops to apply to q1)."""
8687
return self.post_q0, self.post_q1
8788

89+
def on(self, q0: ops.Qid, q1: ops.Qid) -> ops.Operation:
90+
"""Returns the operation that replaces the two qubit gate."""
91+
if self.swap_qubits:
92+
return self.two_qubit_gate(q1, q0)
93+
return self.two_qubit_gate(q0, q1)
94+
95+
96+
@frozen
97+
class SameGateGauge(Gauge):
98+
"""Same as ConstantGauge but the new two-qubit gate equals the old gate."""
99+
100+
pre_q0: Tuple[ops.Gate, ...] = field(
101+
default=(), converter=lambda g: (g,) if isinstance(g, ops.Gate) else tuple(g)
102+
)
103+
pre_q1: Tuple[ops.Gate, ...] = field(
104+
default=(), converter=lambda g: (g,) if isinstance(g, ops.Gate) else tuple(g)
105+
)
106+
post_q0: Tuple[ops.Gate, ...] = field(
107+
default=(), converter=lambda g: (g,) if isinstance(g, ops.Gate) else tuple(g)
108+
)
109+
post_q1: Tuple[ops.Gate, ...] = field(
110+
default=(), converter=lambda g: (g,) if isinstance(g, ops.Gate) else tuple(g)
111+
)
112+
swap_qubits: bool = False
113+
114+
def sample(self, gate: ops.Gate, prng: np.random.Generator) -> ConstantGauge:
115+
return ConstantGauge(
116+
two_qubit_gate=gate,
117+
pre_q0=self.pre_q0,
118+
pre_q1=self.pre_q1,
119+
post_q0=self.post_q0,
120+
post_q1=self.post_q1,
121+
swap_qubits=self.swap_qubits,
122+
)
123+
88124

89125
def _select(choices: Sequence[Gauge], probabilites: np.ndarray, prng: np.random.Generator) -> Gauge:
90126
return choices[prng.choice(len(choices), p=probabilites)]
@@ -154,7 +190,7 @@ def __call__(
154190
gauge = self.gauge_selector(rng).sample(op.gate, rng)
155191
q0, q1 = op.qubits
156192
left.extend([g(q) for g in gs] for q, gs in zip(op.qubits, gauge.pre))
157-
center.append(gauge.two_qubit_gate(q0, q1))
193+
center.append(gauge.on(q0, q1))
158194
right.extend([g(q) for g in gs] for q, gs in zip(op.qubits, gauge.post))
159195
else:
160196
center.append(op)

cirq-core/cirq/transformers/gauge_compiling/spin_inversion_gauge.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -17,17 +17,17 @@
1717
from cirq.transformers.gauge_compiling.gauge_compiling import (
1818
GaugeTransformer,
1919
GaugeSelector,
20-
ConstantGauge,
20+
SameGateGauge,
2121
)
2222
from cirq import ops
2323

2424
SpinInversionGaugeSelector = GaugeSelector(
2525
gauges=[
26-
ConstantGauge(two_qubit_gate=ops.ZZ, pre_q0=ops.X, post_q0=ops.X),
27-
ConstantGauge(two_qubit_gate=ops.ZZ, pre_q1=ops.X, post_q1=ops.X),
26+
SameGateGauge(pre_q0=ops.X, post_q0=ops.X, pre_q1=ops.X, post_q1=ops.X),
27+
SameGateGauge(),
2828
]
2929
)
3030

3131
SpinInversionGaugeTransformer = GaugeTransformer(
32-
target=ops.ZZ, gauge_selector=SpinInversionGaugeSelector
32+
target=ops.GateFamily(ops.ZZPowGate), gauge_selector=SpinInversionGaugeSelector
3333
)

cirq-core/cirq/transformers/gauge_compiling/spin_inversion_gauge_test.py

+16-2
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,26 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
1615
import cirq
1716
from cirq.transformers.gauge_compiling import SpinInversionGaugeTransformer
1817
from cirq.transformers.gauge_compiling.gauge_compiling_test_utils import GaugeTester
1918

2019

21-
class TestSpinInversionGauge(GaugeTester):
20+
class TestSpinInversionGauge_0(GaugeTester):
2221
two_qubit_gate = cirq.ZZ
2322
gauge_transformer = SpinInversionGaugeTransformer
23+
24+
25+
class TestSpinInversionGauge_1(GaugeTester):
26+
two_qubit_gate = cirq.ZZ**0.1
27+
gauge_transformer = SpinInversionGaugeTransformer
28+
29+
30+
class TestSpinInversionGauge_2(GaugeTester):
31+
two_qubit_gate = cirq.ZZ**-1
32+
gauge_transformer = SpinInversionGaugeTransformer
33+
34+
35+
class TestSpinInversionGauge_3(GaugeTester):
36+
two_qubit_gate = cirq.ZZ**0.3
37+
gauge_transformer = SpinInversionGaugeTransformer

cirq-core/cirq/transformers/gauge_compiling/sqrt_cz_gauge.py

+42-7
Original file line numberDiff line numberDiff line change
@@ -12,18 +12,53 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
"""A Gauge transformer for CZ**0.5 gate."""
15+
"""A Gauge transformer for CZ**0.5 and CZ**-0.5 gates."""
16+
17+
18+
from typing import TYPE_CHECKING
19+
import numpy as np
1620

1721
from cirq.transformers.gauge_compiling.gauge_compiling import (
1822
GaugeTransformer,
1923
GaugeSelector,
2024
ConstantGauge,
25+
Gauge,
2126
)
22-
from cirq.ops.common_gates import CZ
23-
from cirq import ops
27+
from cirq.ops import CZ, S, X, Gateset
28+
29+
if TYPE_CHECKING:
30+
import cirq
31+
32+
_SQRT_CZ = CZ**0.5
33+
_ADJ_S = S**-1
2434

25-
SqrtCZGaugeSelector = GaugeSelector(
26-
gauges=[ConstantGauge(pre_q0=ops.X, post_q0=ops.X, post_q1=ops.Z**0.5, two_qubit_gate=CZ**-0.5)]
27-
)
2835

29-
SqrtCZGaugeTransformer = GaugeTransformer(target=CZ**0.5, gauge_selector=SqrtCZGaugeSelector)
36+
class SqrtCZGauge(Gauge):
37+
38+
def weight(self) -> float:
39+
return 3.0
40+
41+
def sample(self, gate: 'cirq.Gate', prng: np.random.Generator) -> ConstantGauge:
42+
if prng.choice([True, False]):
43+
return ConstantGauge(two_qubit_gate=gate)
44+
swap_qubits = prng.choice([True, False])
45+
if swap_qubits:
46+
return ConstantGauge(
47+
pre_q1=X,
48+
post_q1=X,
49+
post_q0=S if gate == _SQRT_CZ else _ADJ_S,
50+
two_qubit_gate=gate**-1,
51+
swap_qubits=True,
52+
)
53+
else:
54+
return ConstantGauge(
55+
pre_q0=X,
56+
post_q0=X,
57+
post_q1=S if gate == _SQRT_CZ else _ADJ_S,
58+
two_qubit_gate=gate**-1,
59+
)
60+
61+
62+
SqrtCZGaugeTransformer = GaugeTransformer(
63+
target=Gateset(_SQRT_CZ, _SQRT_CZ**-1), gauge_selector=GaugeSelector(gauges=[SqrtCZGauge()])
64+
)

cirq-core/cirq/transformers/gauge_compiling/sqrt_cz_gauge_test.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
1615
import cirq
1716
from cirq.transformers.gauge_compiling import SqrtCZGaugeTransformer
1817
from cirq.transformers.gauge_compiling.gauge_compiling_test_utils import GaugeTester
@@ -21,3 +20,8 @@
2120
class TestSqrtCZGauge(GaugeTester):
2221
two_qubit_gate = cirq.CZ**0.5
2322
gauge_transformer = SqrtCZGaugeTransformer
23+
24+
25+
class TestAdjointSqrtCZGauge(GaugeTester):
26+
two_qubit_gate = cirq.CZ**-0.5
27+
gauge_transformer = SqrtCZGaugeTransformer

0 commit comments

Comments
 (0)