Skip to content

Commit cb37d42

Browse files
Cleanup composite analysis (#1397)
### Summary Thanks to #1342, we can clean up the internals of `CompositeCurveAnalysis`. This PR introduces no API break and no feature upgrade. ### Details and comments Previously, the curve data and fit summary data were created internally in `CurveAnalysis` but immediately discarded. The implementation of `CurveAnalysis._run_analysis` was therefore manually copied into `CompositeCurveAnalysis._run_analysis` so that it could access these artifact data and create composite artifact data from them. This made the code fragile, since developers needed to manually update both base classes. With this PR, the implementation of the component analyses is encapsulated.
1 parent e9acd22 commit cb37d42

File tree

2 files changed

+46
-106
lines changed

2 files changed

+46
-106
lines changed

qiskit_experiments/curve_analysis/base_curve_analysis.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -356,7 +356,7 @@ def _create_curve_data(
356356
"""
357357
samples = []
358358

359-
for model_name, sub_data in list(curve_data.groupby("model_name")):
359+
for model_name, sub_data in list(curve_data.dataframe.groupby("model_name")):
360360
raw_datum = AnalysisResultData(
361361
name=DATA_ENTRY_PREFIX + self.__class__.__name__,
362362
value={

qiskit_experiments/curve_analysis/composite_curve_analysis.py

Lines changed: 45 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,11 @@
1616
# pylint: disable=invalid-name
1717
import warnings
1818
from typing import Dict, List, Optional, Tuple, Union
19+
from collections import defaultdict
1920

2021
import lmfit
2122
import numpy as np
2223
import pandas as pd
23-
from uncertainties import unumpy as unp
2424

2525
from qiskit.utils.deprecation import deprecate_func
2626

@@ -39,10 +39,9 @@
3939
)
4040

4141
from qiskit_experiments.framework.containers import FigureType, ArtifactData
42-
from .base_curve_analysis import DATA_ENTRY_PREFIX, BaseCurveAnalysis, PARAMS_ENTRY_PREFIX
42+
from .base_curve_analysis import BaseCurveAnalysis
4343
from .curve_data import CurveFitResult
4444
from .scatter_table import ScatterTable
45-
from .utils import eval_with_uncertainties
4645

4746

4847
class CompositeCurveAnalysis(BaseAnalysis):
@@ -344,123 +343,64 @@ def _run_analysis(
344343
else:
345344
plot = getattr(self, "_generate_figures", "always")
346345

347-
fit_dataset = {}
348-
curve_data_set = []
349-
for analysis in self._analyses:
350-
analysis._initialize(experiment_data)
351-
analysis.set_options(plot=False)
352-
353-
metadata = analysis.options.extra.copy()
346+
sub_artifacts = defaultdict(list)
347+
for source_analysis in self._analyses:
348+
analysis = source_analysis.copy()
349+
metadata = analysis.options.extra
354350
metadata["group"] = analysis.name
351+
analysis.set_options(
352+
plot=False,
353+
extra=metadata,
354+
return_fit_parameters=self.options.return_fit_parameters,
355+
return_data_points=self.options.return_data_points,
356+
)
357+
results, _ = analysis._run_analysis(experiment_data)
358+
for res in results:
359+
if isinstance(res, ArtifactData):
360+
sub_artifacts[res.name].append((analysis.name, res.data))
361+
else:
362+
result_data.append(res)
363+
364+
if "curve_data" in sub_artifacts:
365+
combined_curve_data = ScatterTable.from_dataframe(
366+
data=pd.concat([d.dataframe for _, d in sub_artifacts["curve_data"]])
367+
)
368+
artifacts.append(ArtifactData(name="curve_data", data=combined_curve_data))
369+
else:
370+
combined_curve_data = None
355371

356-
table = analysis._format_data(analysis._run_data_processing(experiment_data.data()))
357-
formatted_subset = table.filter(category=analysis.options.fit_category)
358-
fit_data = analysis._run_curve_fit(formatted_subset)
359-
fit_dataset[analysis.name] = fit_data
360-
361-
if fit_data.success:
362-
quality = analysis._evaluate_quality(fit_data)
363-
else:
364-
quality = "bad"
365-
366-
if self.options.return_fit_parameters:
367-
# Store fit status overview entry regardless of success.
368-
# This is sometime useful when debugging the fitting code.
369-
overview = AnalysisResultData(
370-
name=PARAMS_ENTRY_PREFIX + analysis.name,
371-
value=fit_data,
372-
quality=quality,
373-
extra=metadata,
374-
)
375-
result_data.append(overview)
376-
377-
if fit_data.success:
378-
# Add fit data to curve data table
379-
model_names = analysis.model_names()
380-
for series_id, sub_data in formatted_subset.iter_by_series_id():
381-
xval = sub_data.x
382-
if len(xval) == 0:
383-
# If data is empty, skip drawing this model.
384-
# This is the case when fit model exist but no data to fit is provided.
385-
continue
386-
# Compute X, Y values with fit parameters.
387-
xval_arr_fit = np.linspace(np.min(xval), np.max(xval), num=100, dtype=float)
388-
uval_arr_fit = eval_with_uncertainties(
389-
x=xval_arr_fit,
390-
model=analysis.models[series_id],
391-
params=fit_data.ufloat_params,
392-
)
393-
yval_arr_fit = unp.nominal_values(uval_arr_fit)
394-
if fit_data.covar is not None:
395-
yerr_arr_fit = unp.std_devs(uval_arr_fit)
396-
else:
397-
yerr_arr_fit = np.zeros_like(xval_arr_fit)
398-
for xval, yval, yerr in zip(xval_arr_fit, yval_arr_fit, yerr_arr_fit):
399-
table.add_row(
400-
xval=xval,
401-
yval=yval,
402-
yerr=yerr,
403-
series_name=model_names[series_id],
404-
series_id=series_id,
405-
category="fitted",
406-
analysis=analysis.name,
407-
)
408-
result_data.extend(
409-
analysis._create_analysis_results(
410-
fit_data=fit_data,
411-
quality=quality,
412-
**metadata.copy(),
413-
)
414-
)
415-
416-
if self.options.return_data_points:
417-
# Add raw data points
418-
warnings.warn(
419-
f"{DATA_ENTRY_PREFIX + self.name} has been moved to experiment data artifacts. "
420-
"Saving this result with 'return_data_points'=True will be disabled in "
421-
"Qiskit Experiments 0.7.",
422-
DeprecationWarning,
423-
)
424-
result_data.extend(
425-
analysis._create_curve_data(curve_data=formatted_subset, **metadata)
426-
)
427-
428-
curve_data_set.append(table)
429-
430-
combined_curve_data = ScatterTable.from_dataframe(
431-
pd.concat([d.dataframe for d in curve_data_set])
432-
)
433-
total_quality = self._evaluate_quality(fit_dataset)
372+
if "fit_summary" in sub_artifacts:
373+
combined_summary = dict(sub_artifacts["fit_summary"])
374+
artifacts.append(ArtifactData(name="fit_summary", data=combined_summary))
375+
total_quality = self._evaluate_quality(combined_summary)
376+
else:
377+
combined_summary = None
378+
total_quality = "No Information"
434379

435380
# After the quality is determined, plot can become a boolean flag for whether
436381
# to generate the figure
437382
plot_bool = plot == "always" or (plot == "selective" and total_quality == "bad")
438383

439384
# Create analysis results by combining all fit data
440-
if all(fit_data.success for fit_data in fit_dataset.values()):
385+
if combined_summary and all(fit_data.success for fit_data in combined_summary.values()):
441386
composite_results = self._create_analysis_results(
442-
fit_data=fit_dataset, quality=total_quality, **self.options.extra.copy()
387+
fit_data=combined_summary,
388+
quality=total_quality,
389+
**self.options.extra.copy(),
443390
)
444391
result_data.extend(composite_results)
445392
else:
446393
composite_results = []
447394

448-
artifacts.append(
449-
ArtifactData(
450-
name="curve_data",
451-
data=combined_curve_data,
452-
)
453-
)
454-
artifacts.append(
455-
ArtifactData(
456-
name="fit_summary",
457-
data=fit_dataset,
458-
)
459-
)
460-
461-
if plot_bool:
395+
if plot_bool and combined_curve_data:
396+
if combined_summary:
397+
red_chi_dict = {
398+
k: v.reduced_chisq for k, v in combined_summary.items() if v.success
399+
}
400+
else:
401+
red_chi_dict = {}
462402
self.plotter.set_supplementary_data(
463-
fit_red_chi={k: v.reduced_chisq for k, v in fit_dataset.items() if v.success},
403+
fit_red_chi=red_chi_dict,
464404
primary_results=composite_results,
465405
)
466406
figures.extend(self._create_figures(curve_data=combined_curve_data))

0 commit comments

Comments
 (0)