7
7
import numpy as np
8
8
import scipy
9
9
from pydantic .v1 import Field , NonNegativeFloat , PositiveFloat , PositiveInt , validator
10
- from rich .progress import Progress
11
10
12
11
from ..constants import fp_eps
13
12
from ..exceptions import ValidationError
@@ -759,6 +758,7 @@ def fit(
759
758
tolerance_rms : NonNegativeFloat = DEFAULT_TOLERANCE_RMS ,
760
759
advanced_param : AdvancedFastFitterParam = None ,
761
760
scale_factor : PositiveFloat = 1 ,
761
+ show_progress : bool = True ,
762
762
) -> Tuple [Tuple [float , ArrayComplex1D , ArrayComplex1D ], float ]:
763
763
"""Fit data using a fast fitting algorithm.
764
764
@@ -815,6 +815,8 @@ def fit(
815
815
Advanced parameters for fitting.
816
816
scale_factor : PositiveFloat, optional
817
817
Factor to rescale frequency by before fitting.
818
+ show_progress : bool = True
819
+ Whether to show a progress bar for the fitting.
818
820
819
821
Returns
820
822
-------
@@ -823,6 +825,8 @@ def fit(
823
825
The dispersive medium parameters have the form (resp_inf, poles, residues)
824
826
and are in the original unscaled units.
825
827
"""
828
+ if show_progress :
829
+ from rich .progress import Progress
826
830
827
831
if max_num_poles < min_num_poles :
828
832
raise ValidationError (
@@ -862,6 +866,82 @@ def make_configs():
862
866
863
867
configs = make_configs ()
864
868
869
+ if not show_progress :
870
+ for num_poles , relaxed , smooth , logspacing , optimize_eps_inf in configs :
871
+ model = init_model .updated_copy (
872
+ num_poles = num_poles ,
873
+ relaxed = relaxed ,
874
+ smooth = smooth ,
875
+ logspacing = logspacing ,
876
+ optimize_eps_inf = optimize_eps_inf ,
877
+ )
878
+ model = _fit_fixed_parameters ((min_num_poles , max_num_poles ), model )
879
+
880
+ if model .rms_error < best_model .rms_error :
881
+ log .debug (
882
+ f"Fitter: possible improved fit with "
883
+ f"rms_error={ model .rms_error :.3g} found using "
884
+ f"relaxed={ model .relaxed } , "
885
+ f"smooth={ model .smooth } , "
886
+ f"logspacing={ model .logspacing } , "
887
+ f"optimize_eps_inf={ model .optimize_eps_inf } , "
888
+ f"loss_in_bounds={ model .loss_in_bounds } , "
889
+ f"passivity_optimized={ model .passivity_optimized } , "
890
+ f"sellmeier_passivity={ model .sellmeier_passivity } ."
891
+ )
892
+ if model .loss_in_bounds and model .sellmeier_passivity :
893
+ best_model = model
894
+ else :
895
+ if not warned_about_passivity_num_iters and model .passivity_num_iters_too_small :
896
+ warned_about_passivity_num_iters = True
897
+ log .warning (
898
+ "Did not finish enforcing passivity in dispersion fitter. "
899
+ "If the fit is not good enough, consider increasing "
900
+ "'AdvancedFastFitterParam.passivity_num_iters'."
901
+ )
902
+ if (
903
+ not warned_about_slsqp_constraint_scale
904
+ and model .slsqp_constraint_scale_too_small
905
+ ):
906
+ warned_about_slsqp_constraint_scale = True
907
+ log .warning (
908
+ "SLSQP constraint scale may be too small. "
909
+ "If the fit is not good enough, consider increasing "
910
+ "'AdvancedFastFitterParam.slsqp_constraint_scale'."
911
+ )
912
+
913
+ # if below tolerance, return
914
+ if best_model .rms_error < tolerance_rms :
915
+ log .info (
916
+ "Found optimal fit with weighted RMS error %.3g" ,
917
+ best_model .rms_error ,
918
+ )
919
+ if best_model .show_unweighted_rms :
920
+ log .info (
921
+ "Unweighted RMS error %.3g" ,
922
+ best_model .unweighted_rms_error ,
923
+ )
924
+ return (
925
+ best_model .pole_residue ,
926
+ best_model .rms_error ,
927
+ )
928
+
929
+ # if exited loop, did not reach tolerance (warn)
930
+ log .warning (
931
+ "Unable to fit with weighted RMS error under 'tolerance_rms' of %.3g" , tolerance_rms
932
+ )
933
+ log .info ("Returning best fit with weighted RMS error %.3g" , best_model .rms_error )
934
+ if best_model .show_unweighted_rms :
935
+ log .info (
936
+ "Unweighted RMS error %.3g" ,
937
+ best_model .unweighted_rms_error ,
938
+ )
939
+
940
+ return (
941
+ best_model .pole_residue ,
942
+ best_model .rms_error ,
943
+ )
944
+
865
945
with Progress (console = get_logging_console ()) as progress :
866
946
task = progress .add_task (
867
947
f"Fitting to weighted RMS of { tolerance_rms } ..." ,
0 commit comments