|
16 | 16 | # pylint: disable=invalid-name
|
17 | 17 | import warnings
|
18 | 18 | from typing import Dict, List, Optional, Tuple, Union
|
| 19 | +from collections import defaultdict |
19 | 20 |
|
20 | 21 | import lmfit
|
21 | 22 | import numpy as np
|
22 | 23 | import pandas as pd
|
23 |
| -from uncertainties import unumpy as unp |
24 | 24 |
|
25 | 25 | from qiskit.utils.deprecation import deprecate_func
|
26 | 26 |
|
|
39 | 39 | )
|
40 | 40 |
|
41 | 41 | from qiskit_experiments.framework.containers import FigureType, ArtifactData
|
42 |
| -from .base_curve_analysis import DATA_ENTRY_PREFIX, BaseCurveAnalysis, PARAMS_ENTRY_PREFIX |
| 42 | +from .base_curve_analysis import BaseCurveAnalysis |
43 | 43 | from .curve_data import CurveFitResult
|
44 | 44 | from .scatter_table import ScatterTable
|
45 |
| -from .utils import eval_with_uncertainties |
46 | 45 |
|
47 | 46 |
|
48 | 47 | class CompositeCurveAnalysis(BaseAnalysis):
|
@@ -344,123 +343,64 @@ def _run_analysis(
|
344 | 343 | else:
|
345 | 344 | plot = getattr(self, "_generate_figures", "always")
|
346 | 345 |
|
347 |
| - fit_dataset = {} |
348 |
| - curve_data_set = [] |
349 |
| - for analysis in self._analyses: |
350 |
| - analysis._initialize(experiment_data) |
351 |
| - analysis.set_options(plot=False) |
352 |
| - |
353 |
| - metadata = analysis.options.extra.copy() |
| 346 | + sub_artifacts = defaultdict(list) |
| 347 | + for source_analysis in self._analyses: |
| 348 | + analysis = source_analysis.copy() |
| 349 | + metadata = analysis.options.extra |
354 | 350 | metadata["group"] = analysis.name
|
| 351 | + analysis.set_options( |
| 352 | + plot=False, |
| 353 | + extra=metadata, |
| 354 | + return_fit_parameters=self.options.return_fit_parameters, |
| 355 | + return_data_points=self.options.return_data_points, |
| 356 | + ) |
| 357 | + results, _ = analysis._run_analysis(experiment_data) |
| 358 | + for res in results: |
| 359 | + if isinstance(res, ArtifactData): |
| 360 | + sub_artifacts[res.name].append((analysis.name, res.data)) |
| 361 | + else: |
| 362 | + result_data.append(res) |
| 363 | + |
| 364 | + if "curve_data" in sub_artifacts: |
| 365 | + combined_curve_data = ScatterTable.from_dataframe( |
| 366 | + data=pd.concat([d.dataframe for _, d in sub_artifacts["curve_data"]]) |
| 367 | + ) |
| 368 | + artifacts.append(ArtifactData(name="curve_data", data=combined_curve_data)) |
| 369 | + else: |
| 370 | + combined_curve_data = None |
355 | 371 |
|
356 |
| - table = analysis._format_data(analysis._run_data_processing(experiment_data.data())) |
357 |
| - formatted_subset = table.filter(category=analysis.options.fit_category) |
358 |
| - fit_data = analysis._run_curve_fit(formatted_subset) |
359 |
| - fit_dataset[analysis.name] = fit_data |
360 |
| - |
361 |
| - if fit_data.success: |
362 |
| - quality = analysis._evaluate_quality(fit_data) |
363 |
| - else: |
364 |
| - quality = "bad" |
365 |
| - |
366 |
| - if self.options.return_fit_parameters: |
367 |
| - # Store fit status overview entry regardless of success. |
368 |
| - # This is sometime useful when debugging the fitting code. |
369 |
| - overview = AnalysisResultData( |
370 |
| - name=PARAMS_ENTRY_PREFIX + analysis.name, |
371 |
| - value=fit_data, |
372 |
| - quality=quality, |
373 |
| - extra=metadata, |
374 |
| - ) |
375 |
| - result_data.append(overview) |
376 |
| - |
377 |
| - if fit_data.success: |
378 |
| - # Add fit data to curve data table |
379 |
| - model_names = analysis.model_names() |
380 |
| - for series_id, sub_data in formatted_subset.iter_by_series_id(): |
381 |
| - xval = sub_data.x |
382 |
| - if len(xval) == 0: |
383 |
| - # If data is empty, skip drawing this model. |
384 |
| - # This is the case when fit model exist but no data to fit is provided. |
385 |
| - continue |
386 |
| - # Compute X, Y values with fit parameters. |
387 |
| - xval_arr_fit = np.linspace(np.min(xval), np.max(xval), num=100, dtype=float) |
388 |
| - uval_arr_fit = eval_with_uncertainties( |
389 |
| - x=xval_arr_fit, |
390 |
| - model=analysis.models[series_id], |
391 |
| - params=fit_data.ufloat_params, |
392 |
| - ) |
393 |
| - yval_arr_fit = unp.nominal_values(uval_arr_fit) |
394 |
| - if fit_data.covar is not None: |
395 |
| - yerr_arr_fit = unp.std_devs(uval_arr_fit) |
396 |
| - else: |
397 |
| - yerr_arr_fit = np.zeros_like(xval_arr_fit) |
398 |
| - for xval, yval, yerr in zip(xval_arr_fit, yval_arr_fit, yerr_arr_fit): |
399 |
| - table.add_row( |
400 |
| - xval=xval, |
401 |
| - yval=yval, |
402 |
| - yerr=yerr, |
403 |
| - series_name=model_names[series_id], |
404 |
| - series_id=series_id, |
405 |
| - category="fitted", |
406 |
| - analysis=analysis.name, |
407 |
| - ) |
408 |
| - result_data.extend( |
409 |
| - analysis._create_analysis_results( |
410 |
| - fit_data=fit_data, |
411 |
| - quality=quality, |
412 |
| - **metadata.copy(), |
413 |
| - ) |
414 |
| - ) |
415 |
| - |
416 |
| - if self.options.return_data_points: |
417 |
| - # Add raw data points |
418 |
| - warnings.warn( |
419 |
| - f"{DATA_ENTRY_PREFIX + self.name} has been moved to experiment data artifacts. " |
420 |
| - "Saving this result with 'return_data_points'=True will be disabled in " |
421 |
| - "Qiskit Experiments 0.7.", |
422 |
| - DeprecationWarning, |
423 |
| - ) |
424 |
| - result_data.extend( |
425 |
| - analysis._create_curve_data(curve_data=formatted_subset, **metadata) |
426 |
| - ) |
427 |
| - |
428 |
| - curve_data_set.append(table) |
429 |
| - |
430 |
| - combined_curve_data = ScatterTable.from_dataframe( |
431 |
| - pd.concat([d.dataframe for d in curve_data_set]) |
432 |
| - ) |
433 |
| - total_quality = self._evaluate_quality(fit_dataset) |
| 372 | + if "fit_summary" in sub_artifacts: |
| 373 | + combined_summary = dict(sub_artifacts["fit_summary"]) |
| 374 | + artifacts.append(ArtifactData(name="fit_summary", data=combined_summary)) |
| 375 | + total_quality = self._evaluate_quality(combined_summary) |
| 376 | + else: |
| 377 | + combined_summary = None |
| 378 | + total_quality = "No Information" |
434 | 379 |
|
435 | 380 | # After the quality is determined, plot can become a boolean flag for whether
|
436 | 381 | # to generate the figure
|
437 | 382 | plot_bool = plot == "always" or (plot == "selective" and total_quality == "bad")
|
438 | 383 |
|
439 | 384 | # Create analysis results by combining all fit data
|
440 |
| - if all(fit_data.success for fit_data in fit_dataset.values()): |
| 385 | + if combined_summary and all(fit_data.success for fit_data in combined_summary.values()): |
441 | 386 | composite_results = self._create_analysis_results(
|
442 |
| - fit_data=fit_dataset, quality=total_quality, **self.options.extra.copy() |
| 387 | + fit_data=combined_summary, |
| 388 | + quality=total_quality, |
| 389 | + **self.options.extra.copy(), |
443 | 390 | )
|
444 | 391 | result_data.extend(composite_results)
|
445 | 392 | else:
|
446 | 393 | composite_results = []
|
447 | 394 |
|
448 |
| - artifacts.append( |
449 |
| - ArtifactData( |
450 |
| - name="curve_data", |
451 |
| - data=combined_curve_data, |
452 |
| - ) |
453 |
| - ) |
454 |
| - artifacts.append( |
455 |
| - ArtifactData( |
456 |
| - name="fit_summary", |
457 |
| - data=fit_dataset, |
458 |
| - ) |
459 |
| - ) |
460 |
| - |
461 |
| - if plot_bool: |
| 395 | + if plot_bool and combined_curve_data: |
| 396 | + if combined_summary: |
| 397 | + red_chi_dict = { |
| 398 | + k: v.reduced_chisq for k, v in combined_summary.items() if v.success |
| 399 | + } |
| 400 | + else: |
| 401 | + red_chi_dict = {} |
462 | 402 | self.plotter.set_supplementary_data(
|
463 |
| - fit_red_chi={k: v.reduced_chisq for k, v in fit_dataset.items() if v.success}, |
| 403 | + fit_red_chi=red_chi_dict, |
464 | 404 | primary_results=composite_results,
|
465 | 405 | )
|
466 | 406 | figures.extend(self._create_figures(curve_data=combined_curve_data))
|
|
0 commit comments