Skip to content

Commit 7da7f64

Browse files
Add ParallelRandomizedBenchmarkingResult class (#6412)
1 parent 9c451a2 commit 7da7f64

File tree

2 files changed

+136
-4
lines changed

2 files changed

+136
-4
lines changed

cirq-core/cirq/experiments/qubit_characterizations.py

+130-3
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,14 @@
3232
from scipy.optimize import curve_fit
3333

3434
from matplotlib import pyplot as plt
35+
import cirq.vis.heatmap as cirq_heatmap
36+
import cirq.vis.histogram as cirq_histogram
3537

3638
# this is for older systems with matplotlib <3.2 otherwise 3d projections fail
3739
from mpl_toolkits import mplot3d
3840
from cirq import circuits, ops, protocols
41+
from cirq.devices import grid_qubit
42+
3943

4044
if TYPE_CHECKING:
4145
import cirq
@@ -144,6 +148,127 @@ def _fit_exponential(self) -> Tuple[np.ndarray, np.ndarray]:
144148
)
145149

146150

151+
@dataclasses.dataclass(frozen=True)
152+
class ParallelRandomizedBenchmarkingResult:
153+
"""Results from a parallel randomized benchmarking experiment."""
154+
155+
results_dictionary: Mapping['cirq.Qid', 'RandomizedBenchMarkResult']
156+
157+
def plot_single_qubit(
158+
self, qubit: 'cirq.Qid', ax: Optional[plt.Axes] = None, **plot_kwargs: Any
159+
) -> plt.Axes:
160+
"""Plot the raw data for the specified qubit.
161+
162+
Args:
163+
qubit: Plot data for this qubit.
164+
ax: the plt.Axes to plot on. If not given, a new figure is created,
165+
plotted on, and shown.
166+
**plot_kwargs: Arguments to be passed to 'plt.Axes.plot'.
167+
Returns:
168+
The plt.Axes containing the plot.
169+
"""
170+
171+
return self.results_dictionary[qubit].plot(ax, **plot_kwargs)
172+
173+
def pauli_error(self) -> Mapping['cirq.Qid', float]:
174+
"""Return a dictionary of Pauli errors.
175+
Returns:
176+
A dictionary containing the Pauli errors for all qubits.
177+
"""
178+
179+
return {
180+
qubit: self.results_dictionary[qubit].pauli_error() for qubit in self.results_dictionary
181+
}
182+
183+
def plot_heatmap(
184+
self,
185+
ax: Optional[plt.Axes] = None,
186+
annotation_format: str = '0.1%',
187+
title: str = 'Single-qubit Pauli error',
188+
**plot_kwargs: Any,
189+
) -> plt.Axes:
190+
"""Plot a heatmap of the Pauli errors. If qubits are not cirq.GridQubits, throws an error.
191+
192+
Args:
193+
ax: the plt.Axes to plot on. If not given, a new figure is created,
194+
plotted on, and shown.
195+
annotation_format: The format string for the numbers in the heatmap.
196+
title: The title printed above the heatmap.
197+
**plot_kwargs: Arguments to be passed to 'cirq.Heatmap.plot()'.
198+
Returns:
199+
The plt.Axes containing the plot.
200+
"""
201+
202+
pauli_errors = self.pauli_error()
203+
pauli_errors_with_grid_qubit_keys = {}
204+
for qubit in pauli_errors:
205+
assert type(qubit) == grid_qubit.GridQubit, "qubits must be cirq.GridQubits"
206+
pauli_errors_with_grid_qubit_keys[qubit] = pauli_errors[qubit] # just for typecheck
207+
208+
if ax is None:
209+
_, ax = plt.subplots(dpi=200, facecolor='white')
210+
211+
ax, _ = cirq_heatmap.Heatmap(pauli_errors_with_grid_qubit_keys).plot(
212+
ax, annotation_format=annotation_format, title=title, **plot_kwargs
213+
)
214+
return ax
215+
216+
def plot_integrated_histogram(
217+
self,
218+
ax: Optional[plt.Axes] = None,
219+
cdf_on_x: bool = False,
220+
axis_label: str = 'Pauli error',
221+
semilog: bool = True,
222+
median_line: bool = True,
223+
median_label: Optional[str] = 'median',
224+
mean_line: bool = False,
225+
mean_label: Optional[str] = 'mean',
226+
show_zero: bool = False,
227+
title: Optional[str] = None,
228+
**kwargs,
229+
) -> plt.Axes:
230+
"""Plot the Pauli errors using cirq.integrated_histogram().
231+
232+
Args:
233+
ax: The axis to plot on. If None, we generate one.
234+
cdf_on_x: If True, flip the axes compared the above example.
235+
axis_label: Label for x axis (y-axis if cdf_on_x is True).
236+
semilog: If True, force the x-axis to be logarithmic.
237+
median_line: If True, draw a vertical line on the median value.
238+
median_label: If drawing median line, optional label for it.
239+
mean_line: If True, draw a vertical line on the mean value.
240+
mean_label: If drawing mean line, optional label for it.
241+
title: Title of the plot. If None, we assign "N={len(data)}".
242+
show_zero: If True, moves the step plot up by one unit by prepending 0
243+
to the data.
244+
**kwargs: Kwargs to forward to `ax.step()`. Some examples are
245+
color: Color of the line.
246+
linestyle: Linestyle to use for the plot.
247+
lw: linewidth for integrated histogram.
248+
ms: marker size for a histogram trace.
249+
label: An optional label which can be used in a legend.
250+
Returns:
251+
The axis that was plotted on.
252+
"""
253+
254+
ax = cirq_histogram.integrated_histogram(
255+
data=self.pauli_error(),
256+
ax=ax,
257+
cdf_on_x=cdf_on_x,
258+
axis_label=axis_label,
259+
semilog=semilog,
260+
median_line=median_line,
261+
median_label=median_label,
262+
mean_line=mean_line,
263+
mean_label=mean_label,
264+
show_zero=show_zero,
265+
title=title,
266+
**kwargs,
267+
)
268+
ax.set_ylabel('Percentile')
269+
return ax
270+
271+
147272
class TomographyResult:
148273
"""Results from a state tomography experiment."""
149274

@@ -265,7 +390,7 @@ def single_qubit_randomized_benchmarking(
265390
num_circuits=num_circuits,
266391
repetitions=repetitions,
267392
)
268-
return result[qubit]
393+
return result.results_dictionary[qubit]
269394

270395

271396
def parallel_single_qubit_randomized_benchmarking(
@@ -278,7 +403,7 @@ def parallel_single_qubit_randomized_benchmarking(
278403
),
279404
num_circuits: int = 10,
280405
repetitions: int = 1000,
281-
) -> Mapping['cirq.Qid', 'RandomizedBenchMarkResult']:
406+
) -> 'ParallelRandomizedBenchmarkingResult':
282407
"""Clifford-based randomized benchmarking (RB) single qubits in parallel.
283408
284409
This is the same as `single_qubit_randomized_benchmarking` except on all
@@ -321,7 +446,9 @@ def parallel_single_qubit_randomized_benchmarking(
321446
idx += 1
322447
for qubit in qubits:
323448
gnd_probs[qubit].append(1.0 - np.mean(excited_probs[qubit]))
324-
return {q: RandomizedBenchMarkResult(num_clifford_range, gnd_probs[q]) for q in qubits}
449+
return ParallelRandomizedBenchmarkingResult(
450+
{q: RandomizedBenchMarkResult(num_clifford_range, gnd_probs[q]) for q in qubits}
451+
)
325452

326453

327454
def two_qubit_randomized_benchmarking(

cirq-core/cirq/experiments/qubit_characterizations_test.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -126,8 +126,13 @@ def test_parallel_single_qubit_randomized_benchmarking():
126126
simulator, num_clifford_range=num_cfds, repetitions=100, qubits=qubits
127127
)
128128
for qubit in qubits:
129-
g_pops = np.asarray(results[qubit].data)[:, 1]
129+
g_pops = np.asarray(results.results_dictionary[qubit].data)[:, 1]
130130
assert np.isclose(np.mean(g_pops), 1.0)
131+
_ = results.plot_single_qubit(qubit)
132+
pauli_errors = results.pauli_error()
133+
assert len(pauli_errors) == len(qubits)
134+
_ = results.plot_heatmap()
135+
_ = results.plot_integrated_histogram()
131136

132137

133138
def test_two_qubit_randomized_benchmarking():

0 commit comments

Comments
 (0)