Skip to content

Commit adf5155

Browse files
Combine RB and XEB to compute inferred errors (#6455)
1 parent c3de706 commit adf5155

File tree

3 files changed

+356
-22
lines changed

3 files changed

+356
-22
lines changed

cirq-core/cirq/experiments/__init__.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -65,4 +65,8 @@
6565

6666
from cirq.experiments.xeb_fitting import XEBPhasedFSimCharacterizationOptions
6767

68-
from cirq.experiments.two_qubit_xeb import TwoQubitXEBResult, parallel_two_qubit_xeb
68+
from cirq.experiments.two_qubit_xeb import (
69+
InferredXEBResult,
70+
TwoQubitXEBResult,
71+
parallel_two_qubit_xeb,
72+
)

cirq-core/cirq/experiments/two_qubit_xeb.py

+177-13
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,10 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
from typing import Sequence, TYPE_CHECKING, Optional, Tuple, Dict
14+
from typing import Sequence, TYPE_CHECKING, Optional, Tuple, Dict, cast, Mapping
1515

1616
from dataclasses import dataclass
17+
from types import MappingProxyType
1718
import itertools
1819
import functools
1920

@@ -22,25 +23,24 @@
2223
import numpy as np
2324
import pandas as pd
2425

25-
from cirq import ops, devices, value, vis
26+
from cirq import ops, value, vis
2627
from cirq.experiments.xeb_sampling import sample_2q_xeb_circuits
2728
from cirq.experiments.xeb_fitting import benchmark_2q_xeb_fidelities
2829
from cirq.experiments.xeb_fitting import fit_exponential_decays, exponential_decay
2930
from cirq.experiments import random_quantum_circuit_generation as rqcg
31+
from cirq.experiments.qubit_characterizations import ParallelRandomizedBenchmarkingResult
32+
from cirq.qis import noise_utils
33+
from cirq._compat import cached_method
3034

3135
if TYPE_CHECKING:
3236
import cirq
3337

3438

35-
def _grid_qubits_for_sampler(sampler: 'cirq.Sampler'):
39+
def _grid_qubits_for_sampler(sampler: 'cirq.Sampler') -> Optional[Sequence['cirq.GridQubit']]:
3640
if hasattr(sampler, 'processor'):
3741
device = sampler.processor.get_device()
3842
return sorted(device.metadata.qubit_set)
39-
else:
40-
qubits = devices.GridQubit.rect(3, 2, 4, 3)
41-
# Delete one qubit from the rectangular arangement to
42-
# 1) make it irregular 2) simplify simulation.
43-
return qubits[:-1]
43+
return None
4444

4545

4646
def _manhattan_distance(qubit1: 'cirq.GridQubit', qubit2: 'cirq.GridQubit') -> int:
@@ -65,7 +65,7 @@ def all_qubit_pairs(self) -> Tuple[Tuple['cirq.GridQubit', 'cirq.GridQubit'], ..
6565
return tuple(sorted(self._qubit_pair_map.keys()))
6666

6767
def plot_heatmap(self, ax: Optional[plt.Axes] = None, **plot_kwargs) -> plt.Axes:
68-
"""plot the heatmap for xeb error.
68+
"""plot the heatmap of XEB errors.
6969
7070
Args:
7171
ax: the plt.Axes to plot on. If not given, a new figure is created,
@@ -75,7 +75,6 @@ def plot_heatmap(self, ax: Optional[plt.Axes] = None, **plot_kwargs) -> plt.Axes
7575
show_plot = not ax
7676
if not isinstance(ax, plt.Axes):
7777
fig, ax = plt.subplots(1, 1, figsize=(8, 8))
78-
7978
heatmap_data: Dict[Tuple['cirq.GridQubit', ...], float] = {
8079
pair: self.xeb_error(*pair) for pair in self.all_qubit_pairs
8180
}
@@ -131,10 +130,13 @@ def _record(self, q0, q1) -> pd.Series:
131130
q0, q1 = q1, q0
132131
return self.fidelities.iloc[self._qubit_pair_map[(q0, q1)]]
133132

133+
def xeb_fidelity(self, q0: 'cirq.GridQubit', q1: 'cirq.GridQubit') -> float:
134+
"""Return the XEB fidelity of a qubit pair."""
135+
return self._record(q0, q1).layer_fid
136+
134137
def xeb_error(self, q0: 'cirq.GridQubit', q1: 'cirq.GridQubit') -> float:
135138
"""Return the XEB error of a qubit pair."""
136-
p = self._record(q0, q1).layer_fid
137-
return 1 - p
139+
return 1 - self.xeb_fidelity(q0, q1)
138140

139141
def all_errors(self) -> Dict[Tuple['cirq.GridQubit', 'cirq.GridQubit'], float]:
140142
"""Return the XEB error of all qubit pairs."""
@@ -156,9 +158,163 @@ def plot_histogram(self, ax: Optional[plt.Axes] = None, **plot_kwargs) -> plt.Ax
156158
fig.show(**plot_kwargs)
157159
return ax
158160

161+
@cached_method
162+
def pauli_error(self) -> Dict[Tuple['cirq.GridQubit', 'cirq.GridQubit'], float]:
163+
"""Return the Pauli error of all qubit pairs."""
164+
return {
165+
pair: noise_utils.decay_constant_to_pauli_error(
166+
noise_utils.xeb_fidelity_to_decay_constant(self.xeb_fidelity(*pair), num_qubits=2),
167+
num_qubits=2,
168+
)
169+
for pair in self.all_qubit_pairs
170+
}
171+
172+
173+
@dataclass(frozen=True)
174+
class InferredXEBResult:
175+
"""Uses the results from XEB and RB to compute inferred two-qubit Pauli errors."""
176+
177+
rb_result: ParallelRandomizedBenchmarkingResult
178+
xeb_result: TwoQubitXEBResult
179+
180+
@property
181+
def all_qubit_pairs(self) -> Sequence[Tuple['cirq.GridQubit', 'cirq.GridQubit']]:
182+
return self.xeb_result.all_qubit_pairs
183+
184+
@cached_method
185+
def single_qubit_pauli_error(self) -> Mapping['cirq.Qid', float]:
186+
"""Return the single-qubit Pauli error for all qubits (RB results)."""
187+
return self.rb_result.pauli_error()
188+
189+
@cached_method
190+
def two_qubit_pauli_error(self) -> Mapping[Tuple['cirq.GridQubit', 'cirq.GridQubit'], float]:
191+
"""Return the two-qubit Pauli error for all pairs."""
192+
return MappingProxyType(self.xeb_result.pauli_error())
193+
194+
@cached_method
195+
def inferred_pauli_error(self) -> Mapping[Tuple['cirq.GridQubit', 'cirq.GridQubit'], float]:
196+
"""Return the inferred Pauli error for all pairs."""
197+
single_q_paulis = self.rb_result.pauli_error()
198+
xeb = self.xeb_result.pauli_error()
199+
200+
def _pauli_error(q0: 'cirq.GridQubit', q1: 'cirq.GridQubit') -> float:
201+
q0, q1 = sorted([q0, q1])
202+
return xeb[(q0, q1)] - single_q_paulis[q0] - single_q_paulis[q1]
203+
204+
return MappingProxyType({pair: _pauli_error(*pair) for pair in self.all_qubit_pairs})
205+
206+
@cached_method
207+
def inferred_decay_constant(self) -> Mapping[Tuple['cirq.GridQubit', 'cirq.GridQubit'], float]:
208+
"""Return the inferred decay constant for all pairs."""
209+
return MappingProxyType(
210+
{
211+
pair: noise_utils.pauli_error_to_decay_constant(pauli, 2)
212+
for pair, pauli in self.inferred_pauli_error().items()
213+
}
214+
)
215+
216+
@cached_method
217+
def inferred_xeb_error(self) -> Mapping[Tuple['cirq.GridQubit', 'cirq.GridQubit'], float]:
218+
"""Return the inferred XEB error for all pairs."""
219+
return MappingProxyType(
220+
{
221+
pair: 1 - noise_utils.decay_constant_to_xeb_fidelity(decay, 2)
222+
for pair, decay in self.inferred_decay_constant().items()
223+
}
224+
)
225+
226+
def _target_errors(
227+
self, target_error: str
228+
) -> Mapping[Tuple['cirq.GridQubit', 'cirq.GridQubit'], float]:
229+
error_funcs = {
230+
'pauli': self.inferred_pauli_error,
231+
'decay_constant': self.inferred_decay_constant,
232+
'xeb': self.inferred_xeb_error,
233+
}
234+
return error_funcs[target_error]()
235+
236+
def plot_heatmap(
237+
self, target_error: str = 'pauli', ax: Optional[plt.Axes] = None, **plot_kwargs
238+
) -> plt.Axes:
239+
"""plot the heatmap of the target errors.
240+
241+
Args:
242+
target_error: The error to draw. Must be one of 'xeb', 'pauli', or 'decay_constant'
243+
ax: the plt.Axes to plot on. If not given, a new figure is created,
244+
plotted on, and shown.
245+
**plot_kwargs: Arguments to be passed to 'plt.Axes.plot'.
246+
"""
247+
show_plot = not ax
248+
if not isinstance(ax, plt.Axes):
249+
fig, ax = plt.subplots(1, 1, figsize=(8, 8))
250+
heatmap_data = cast(
251+
Mapping[Tuple['cirq.GridQubit', ...], float], self._target_errors(target_error)
252+
)
253+
254+
name = f'{target_error} error' if target_error != 'decay_constant' else 'decay constant'
255+
ax.set_title(f'device {name} heatmap')
256+
257+
vis.TwoQubitInteractionHeatmap(heatmap_data).plot(ax=ax, **plot_kwargs)
258+
if show_plot:
259+
fig.show()
260+
return ax
261+
262+
def plot_histogram(
263+
self,
264+
target_error: str = 'pauli',
265+
ax: Optional[plt.Axes] = None,
266+
kind: str = 'two_qubit',
267+
**plot_kwargs,
268+
) -> plt.Axes:
269+
"""plot a histogram of target error.
270+
271+
Args:
272+
target_error: The error to draw. Must be one of 'xeb', 'pauli', or 'decay_constant'
273+
kind: Whether to plot the single-qubit RB errors ('single_qubit') or the
274+
two-qubit inferred errors ('two_qubit') or both ('both').
275+
ax: the plt.Axes to plot on. If not given, a new figure is created,
276+
plotted on, and shown.
277+
**plot_kwargs: Arguments to be passed to 'plt.Axes.plot'.
278+
279+
Raises:
280+
ValueError: If
281+
- `kind` is not one of 'single_qubit', 'two_qubit', or 'both'.
282+
- `target_error` is not one of 'pauli', 'xeb', or 'decay_constant'
283+
- single qubit error is requested and `target_error` is not 'pauli'.
284+
"""
285+
if kind not in ('single_qubit', 'two_qubit', 'both'):
286+
raise ValueError(
287+
f"kind must be one of 'single_qubit', 'two_qubit', or 'both', not {kind}"
288+
)
289+
if kind != 'two_qubit' and target_error != 'pauli':
290+
raise ValueError(f'{target_error} is not supported for single qubits')
291+
fig = None
292+
if ax is None:
293+
fig, ax = plt.subplots(1, 1, figsize=(8, 8))
294+
295+
alpha = 0.5 if kind == 'both' else 1.0
296+
if kind == 'single_qubit' or kind == 'both':
297+
self.rb_result.plot_integrated_histogram(
298+
ax=ax, alpha=alpha, label='single qubit', color='green', **plot_kwargs
299+
)
300+
if kind == 'two_qubit' or kind == 'both':
301+
vis.integrated_histogram(
302+
data=self._target_errors(target_error),
303+
ax=ax,
304+
alpha=alpha,
305+
label='two qubit',
306+
color='blue',
307+
**plot_kwargs,
308+
)
309+
310+
if fig is not None:
311+
fig.show(**plot_kwargs)
312+
return ax
313+
159314

160315
def parallel_two_qubit_xeb(
161316
sampler: 'cirq.Sampler',
317+
qubits: Optional[Sequence['cirq.GridQubit']] = None,
162318
entangling_gate: 'cirq.Gate' = ops.CZ,
163319
n_repetitions: int = 10**4,
164320
n_combinations: int = 10,
@@ -172,6 +328,7 @@ def parallel_two_qubit_xeb(
172328
173329
Args:
174330
sampler: The quantum engine or simulator to run the circuits.
331+
qubits: Qubits under test. If none, uses all qubits on the sampler's device.
175332
entangling_gate: The entangling gate to use.
176333
n_repetitions: The number of repetitions to use.
177334
n_combinations: The number of combinations to generate.
@@ -184,10 +341,17 @@ def parallel_two_qubit_xeb(
184341
185342
Returns:
186343
A TwoQubitXEBResult object representing the results of the experiment.
344+
345+
Raises:
346+
ValueError: If qubits are not specified and the sampler has no device.
187347
"""
188348
rs = value.parse_random_state(random_state)
189349

190-
qubits = _grid_qubits_for_sampler(sampler)
350+
if qubits is None:
351+
qubits = _grid_qubits_for_sampler(sampler)
352+
if qubits is None:
353+
raise ValueError("Couldn't determine qubits from sampler. Please specify them.")
354+
191355
graph = nx.Graph(
192356
pair for pair in itertools.combinations(qubits, 2) if _manhattan_distance(*pair) == 1
193357
)

0 commit comments

Comments
 (0)