Skip to content

Commit f9f0d66

Browse files
Introduce gauge compilation (#6526)
This PR introduces the abstraction for Gauge compilation as well as implementation for Sycamore gate, CZ gate, SqrtCZ gate, ZZ (a.k.a spin inversion), ISWAP gate, and SQRT_ISWAP gate
1 parent 45c5fa3 commit f9f0d66

20 files changed

+902
-0
lines changed

cirq-core/cirq/transformers/__init__.py

+13
Original file line numberDiff line numberDiff line change
@@ -119,3 +119,16 @@
119119
unroll_circuit_op_greedy_earliest,
120120
unroll_circuit_op_greedy_frontier,
121121
)
122+
123+
124+
from cirq.transformers.gauge_compiling import (
125+
CZGaugeTransformer,
126+
ConstantGauge,
127+
Gauge,
128+
GaugeSelector,
129+
GaugeTransformer,
130+
ISWAPGaugeTransformer,
131+
SpinInversionGaugeTransformer,
132+
SqrtCZGaugeTransformer,
133+
SqrtISWAPGaugeTransformer,
134+
)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
# Copyright 2024 The Cirq Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
from cirq.transformers.gauge_compiling.gauge_compiling import (
17+
ConstantGauge,
18+
Gauge,
19+
GaugeSelector,
20+
GaugeTransformer,
21+
)
22+
from cirq.transformers.gauge_compiling.sqrt_cz_gauge import SqrtCZGaugeTransformer
23+
from cirq.transformers.gauge_compiling.spin_inversion_gauge import SpinInversionGaugeTransformer
24+
from cirq.transformers.gauge_compiling.cz_gauge import CZGaugeTransformer
25+
from cirq.transformers.gauge_compiling.iswap_gauge import ISWAPGaugeTransformer
26+
from cirq.transformers.gauge_compiling.sqrt_iswap_gauge import SqrtISWAPGaugeTransformer
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
# Copyright 2024 The Cirq Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""A Gauge Transformer for the CZ gate."""
16+
17+
from cirq.transformers.gauge_compiling.gauge_compiling import (
18+
GaugeTransformer,
19+
GaugeSelector,
20+
ConstantGauge,
21+
)
22+
from cirq.ops.common_gates import CZ
23+
from cirq import ops
24+
25+
CZGaugeSelector = GaugeSelector(
26+
gauges=[
27+
ConstantGauge(two_qubit_gate=CZ, pre_q0=ops.I, pre_q1=ops.I, post_q0=ops.I, post_q1=ops.I),
28+
ConstantGauge(two_qubit_gate=CZ, pre_q0=ops.I, pre_q1=ops.X, post_q0=ops.Z, post_q1=ops.X),
29+
ConstantGauge(two_qubit_gate=CZ, pre_q0=ops.I, pre_q1=ops.Y, post_q0=ops.Z, post_q1=ops.Y),
30+
ConstantGauge(two_qubit_gate=CZ, pre_q0=ops.I, pre_q1=ops.Z, post_q0=ops.I, post_q1=ops.Z),
31+
ConstantGauge(two_qubit_gate=CZ, pre_q0=ops.X, pre_q1=ops.I, post_q0=ops.X, post_q1=ops.Z),
32+
ConstantGauge(two_qubit_gate=CZ, pre_q0=ops.X, pre_q1=ops.X, post_q0=ops.Y, post_q1=ops.Y),
33+
ConstantGauge(two_qubit_gate=CZ, pre_q0=ops.X, pre_q1=ops.Y, post_q0=ops.Y, post_q1=ops.X),
34+
ConstantGauge(two_qubit_gate=CZ, pre_q0=ops.X, pre_q1=ops.Z, post_q0=ops.X, post_q1=ops.I),
35+
ConstantGauge(two_qubit_gate=CZ, pre_q0=ops.Y, pre_q1=ops.I, post_q0=ops.Y, post_q1=ops.Z),
36+
ConstantGauge(two_qubit_gate=CZ, pre_q0=ops.Y, pre_q1=ops.X, post_q0=ops.X, post_q1=ops.Y),
37+
ConstantGauge(two_qubit_gate=CZ, pre_q0=ops.Y, pre_q1=ops.Y, post_q0=ops.X, post_q1=ops.X),
38+
ConstantGauge(two_qubit_gate=CZ, pre_q0=ops.Y, pre_q1=ops.Z, post_q0=ops.Y, post_q1=ops.I),
39+
ConstantGauge(two_qubit_gate=CZ, pre_q0=ops.Z, pre_q1=ops.I, post_q0=ops.Z, post_q1=ops.I),
40+
ConstantGauge(two_qubit_gate=CZ, pre_q0=ops.Z, pre_q1=ops.X, post_q0=ops.I, post_q1=ops.X),
41+
ConstantGauge(two_qubit_gate=CZ, pre_q0=ops.Z, pre_q1=ops.Y, post_q0=ops.I, post_q1=ops.Y),
42+
ConstantGauge(two_qubit_gate=CZ, pre_q0=ops.Z, pre_q1=ops.Z, post_q0=ops.Z, post_q1=ops.Z),
43+
]
44+
)
45+
46+
CZGaugeTransformer = GaugeTransformer(target=CZ, gauge_selector=CZGaugeSelector)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# Copyright 2024 The Cirq Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
import cirq
17+
from cirq.transformers.gauge_compiling import CZGaugeTransformer
18+
from cirq.transformers.gauge_compiling.gauge_compiling_test_utils import GaugeTester
19+
20+
21+
class TestCZGauge(GaugeTester):
22+
two_qubit_gate = cirq.CZ
23+
gauge_transformer = CZGaugeTransformer
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
# Copyright 2024 The Cirq Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Creates the abstraction for gauge compiling as a cirq transformer."""
16+
17+
from typing import Callable, Tuple, Optional, Sequence, Union, List
18+
import abc
19+
import itertools
20+
import functools
21+
22+
from dataclasses import dataclass
23+
from attrs import frozen, field
24+
import numpy as np
25+
26+
from cirq.transformers import transformer_api
27+
from cirq import ops, circuits
28+
29+
30+
class Gauge(abc.ABC):
31+
"""A gauge replaces a two qubit gate with an equivalent subcircuit.
32+
0: pre_q0───────two_qubit_gate───────post_q0
33+
|
34+
1: pre_q1───────two_qubit_gate───────post_q1
35+
36+
The Gauge class in general represents a family of closely related gauges
37+
(e.g. random z-rotations); Use `sample` method to get a specific gauge.
38+
"""
39+
40+
def weight(self) -> float:
41+
"""Returns the relative frequency for selecting this gauge."""
42+
return 1.0
43+
44+
@abc.abstractmethod
45+
def sample(self, gate: ops.Gate, prng: np.random.Generator) -> "ConstantGauge":
46+
"""Returns a ConstantGauge sampled from a family of gauges.
47+
48+
Args:
49+
gate: The two qubit gate to replace.
50+
prng: A numpy random number generator.
51+
52+
Returns:
53+
A ConstantGauge.
54+
"""
55+
56+
57+
@frozen
58+
class ConstantGauge(Gauge):
59+
"""A gauge that replaces a two qubit gate with a constant gauge."""
60+
61+
two_qubit_gate: ops.Gate
62+
pre_q0: Tuple[ops.Gate, ...] = field(
63+
default=(), converter=lambda g: (g,) if isinstance(g, ops.Gate) else tuple(g)
64+
)
65+
pre_q1: Tuple[ops.Gate, ...] = field(
66+
default=(), converter=lambda g: (g,) if isinstance(g, ops.Gate) else tuple(g)
67+
)
68+
post_q0: Tuple[ops.Gate, ...] = field(
69+
default=(), converter=lambda g: (g,) if isinstance(g, ops.Gate) else tuple(g)
70+
)
71+
post_q1: Tuple[ops.Gate, ...] = field(
72+
default=(), converter=lambda g: (g,) if isinstance(g, ops.Gate) else tuple(g)
73+
)
74+
75+
def sample(self, gate: ops.Gate, prng: np.random.Generator) -> "ConstantGauge":
76+
return self
77+
78+
@property
79+
def pre(self) -> Tuple[Tuple[ops.Gate, ...], Tuple[ops.Gate, ...]]:
80+
"""A tuple (ops to apply to q0, ops to apply to q1)."""
81+
return self.pre_q0, self.pre_q1
82+
83+
@property
84+
def post(self) -> Tuple[Tuple[ops.Gate, ...], Tuple[ops.Gate, ...]]:
85+
"""A tuple (ops to apply to q0, ops to apply to q1)."""
86+
return self.post_q0, self.post_q1
87+
88+
89+
def _select(choices: Sequence[Gauge], probabilites: np.ndarray, prng: np.random.Generator) -> Gauge:
90+
return choices[prng.choice(len(choices), p=probabilites)]
91+
92+
93+
@dataclass(frozen=True)
94+
class GaugeSelector:
95+
"""Samples a gauge from a list of gauges."""
96+
97+
gauges: Sequence[Gauge]
98+
99+
@functools.cached_property
100+
def _weights(self) -> np.ndarray:
101+
weights = np.array([g.weight() for g in self.gauges])
102+
return weights / np.sum(weights)
103+
104+
def __call__(self, prng: np.random.Generator) -> Gauge:
105+
"""Randomly selects a gauge with probability proportional to its weight."""
106+
return _select(self.gauges, self._weights, prng)
107+
108+
109+
@transformer_api.transformer
110+
class GaugeTransformer:
111+
def __init__(
112+
self,
113+
# target can be either a specific gate, gatefamily or gateset
114+
# which allows matching parametric gates.
115+
target: Union[ops.Gate, ops.Gateset, ops.GateFamily],
116+
gauge_selector: Callable[[np.random.Generator], Gauge],
117+
) -> None:
118+
"""Constructs a GaugeTransformer.
119+
120+
Args:
121+
target: Target two-qubit gate, a gate-family or a gate-set of two-qubit gates.
122+
gauge_selector: A callable that takes a numpy random number generator
123+
as an argument and returns a Gauge.
124+
"""
125+
self.target = ops.GateFamily(target) if isinstance(target, ops.Gate) else target
126+
self.gauge_selector = gauge_selector
127+
128+
def __call__(
129+
self,
130+
circuit: circuits.AbstractCircuit,
131+
*,
132+
context: Optional[transformer_api.TransformerContext] = None,
133+
prng: Optional[np.random.Generator] = None,
134+
) -> circuits.AbstractCircuit:
135+
rng = np.random.default_rng() if prng is None else prng
136+
if context is None:
137+
context = transformer_api.TransformerContext(deep=False)
138+
if context.deep:
139+
raise ValueError('GaugeTransformer cannot be used with deep=True')
140+
new_moments = []
141+
left: List[List[ops.Operation]] = []
142+
right: List[List[ops.Operation]] = []
143+
for moment in circuit:
144+
left.clear()
145+
right.clear()
146+
center: List[ops.Operation] = []
147+
for op in moment:
148+
if isinstance(op, ops.TaggedOperation) and set(op.tags).intersection(
149+
context.tags_to_ignore
150+
):
151+
center.append(op)
152+
continue
153+
if op.gate is not None and len(op.qubits) == 2 and op in self.target:
154+
gauge = self.gauge_selector(rng).sample(op.gate, rng)
155+
q0, q1 = op.qubits
156+
left.extend([g(q) for g in gs] for q, gs in zip(op.qubits, gauge.pre))
157+
center.append(gauge.two_qubit_gate(q0, q1))
158+
right.extend([g(q) for g in gs] for q, gs in zip(op.qubits, gauge.post))
159+
else:
160+
center.append(op)
161+
if left:
162+
new_moments.extend(_build_moments(left))
163+
new_moments.append(center)
164+
if right:
165+
new_moments.extend(_build_moments(right))
166+
return circuits.Circuit.from_moments(*new_moments)
167+
168+
169+
def _build_moments(operation_by_qubits: List[List[ops.Operation]]) -> List[List[ops.Operation]]:
170+
"""Builds moments from a list of operations grouped by qubits.
171+
172+
Returns a list of moments from a list whose ith element is a list of operations applied
173+
to qubit i.
174+
"""
175+
moments = []
176+
for moment in itertools.zip_longest(*operation_by_qubits):
177+
moments.append([op for op in moment if op is not None])
178+
return moments
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# Copyright 2024 The Cirq Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import pytest
16+
import numpy as np
17+
import cirq
18+
from cirq.transformers.gauge_compiling import GaugeTransformer, CZGaugeTransformer
19+
20+
21+
def test_deep_transformation_not_supported():
22+
23+
with pytest.raises(ValueError, match="cannot be used with deep=True"):
24+
_ = GaugeTransformer(target=cirq.CZ, gauge_selector=lambda _: None)(
25+
cirq.Circuit(), context=cirq.TransformerContext(deep=True)
26+
)
27+
28+
29+
def test_ignore_tags():
30+
c = cirq.Circuit(cirq.CZ(*cirq.LineQubit.range(2)).with_tags('foo'))
31+
assert c == CZGaugeTransformer(c, context=cirq.TransformerContext(tags_to_ignore={"foo"}))
32+
33+
34+
def test_target_can_be_gateset():
35+
qs = cirq.LineQubit.range(2)
36+
c = cirq.Circuit(cirq.CZ(*qs))
37+
transformer = GaugeTransformer(
38+
target=cirq.Gateset(cirq.CZ), gauge_selector=CZGaugeTransformer.gauge_selector
39+
)
40+
want = cirq.Circuit(cirq.Y.on_each(qs), cirq.CZ(*qs), cirq.X.on_each(qs))
41+
assert transformer(c, prng=np.random.default_rng(0)) == want

0 commit comments

Comments
 (0)