Skip to content

Commit 03c0c94

Browse files
committed
Add option to use dispersion fitter without rich.progress
1 parent 2bbec11 commit 03c0c94

File tree

2 files changed

+97
-1
lines changed

2 files changed

+97
-1
lines changed

Diff for: tests/test_plugins/test_dispersion_fitter.py

+16
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import pytest
55
import responses
66
import tidy3d as td
7+
from tidy3d.components.dispersion_fitter import fit
78
from tidy3d.exceptions import SetupError, ValidationError
89
from tidy3d.plugins.dispersion import (
910
AdvancedFastFitterParam,
@@ -285,3 +286,18 @@ def test_dispersion_loss_samples():
285286
ep = nAlGaN_mat.eps_model(freq_list)
286287
for e in ep:
287288
assert e.imag >= 0
289+
290+
291+
@responses.activate
292+
def test_fit_no_progress(random_data):
293+
wvl_um, n_data, k_data = random_data
294+
eps_complex = (n_data + 1j * k_data) ** 2
295+
omega = 2 * np.pi * td.C_0 / wvl_um
296+
297+
medium, rms = fit(
298+
omega_data=omega,
299+
resp_data=eps_complex,
300+
scale_factor=td.HBAR,
301+
advanced_param=advanced_param,
302+
show_progress=False,
303+
)

Diff for: tidy3d/components/dispersion_fitter.py

+81-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import numpy as np
88
import scipy
99
from pydantic.v1 import Field, NonNegativeFloat, PositiveFloat, PositiveInt, validator
10-
from rich.progress import Progress
1110

1211
from ..constants import fp_eps
1312
from ..exceptions import ValidationError
@@ -759,6 +758,7 @@ def fit(
759758
tolerance_rms: NonNegativeFloat = DEFAULT_TOLERANCE_RMS,
760759
advanced_param: AdvancedFastFitterParam = None,
761760
scale_factor: PositiveFloat = 1,
761+
show_progress: bool = True,
762762
) -> Tuple[Tuple[float, ArrayComplex1D, ArrayComplex1D], float]:
763763
"""Fit data using a fast fitting algorithm.
764764
@@ -815,6 +815,8 @@ def fit(
815815
Advanced parameters for fitting.
816816
scale_factor : PositiveFloat, optional
817817
Factor to rescale frequency by before fitting.
818+
show_progress : bool = True
819+
Whether to show a progress bar for the fitting.
818820
819821
Returns
820822
-------
@@ -823,6 +825,8 @@ def fit(
823825
The dispersive medium parameters have the form (resp_inf, poles, residues)
824826
and are in the original unscaled units.
825827
"""
828+
if show_progress:
829+
from rich.progress import Progress
826830

827831
if max_num_poles < min_num_poles:
828832
raise ValidationError(
@@ -862,6 +866,82 @@ def make_configs():
862866

863867
configs = make_configs()
864868

869+
if not show_progress:
870+
for num_poles, relaxed, smooth, logspacing, optimize_eps_inf in configs:
871+
model = init_model.updated_copy(
872+
num_poles=num_poles,
873+
relaxed=relaxed,
874+
smooth=smooth,
875+
logspacing=logspacing,
876+
optimize_eps_inf=optimize_eps_inf,
877+
)
878+
model = _fit_fixed_parameters((min_num_poles, max_num_poles), model)
879+
880+
if model.rms_error < best_model.rms_error:
881+
log.debug(
882+
f"Fitter: possible improved fit with "
883+
f"rms_error={model.rms_error:.3g} found using "
884+
f"relaxed={model.relaxed}, "
885+
f"smooth={model.smooth}, "
886+
f"logspacing={model.logspacing}, "
887+
f"optimize_eps_inf={model.optimize_eps_inf}, "
888+
f"loss_in_bounds={model.loss_in_bounds}, "
889+
f"passivity_optimized={model.passivity_optimized}, "
890+
f"sellmeier_passivity={model.sellmeier_passivity}."
891+
)
892+
if model.loss_in_bounds and model.sellmeier_passivity:
893+
best_model = model
894+
else:
895+
if not warned_about_passivity_num_iters and model.passivity_num_iters_too_small:
896+
warned_about_passivity_num_iters = True
897+
log.warning(
898+
"Did not finish enforcing passivity in dispersion fitter. "
899+
"If the fit is not good enough, consider increasing "
900+
"'AdvancedFastFitterParam.passivity_num_iters'."
901+
)
902+
if (
903+
not warned_about_slsqp_constraint_scale
904+
and model.slsqp_constraint_scale_too_small
905+
):
906+
warned_about_slsqp_constraint_scale = True
907+
log.warning(
908+
"SLSQP constraint scale may be too small. "
909+
"If the fit is not good enough, consider increasing "
910+
"'AdvancedFastFitterParam.slsqp_constraint_scale'."
911+
)
912+
913+
# if below tolerance, return
914+
if best_model.rms_error < tolerance_rms:
915+
log.info(
916+
"Found optimal fit with weighted RMS error %.3g",
917+
best_model.rms_error,
918+
)
919+
if best_model.show_unweighted_rms:
920+
log.info(
921+
"Unweighted RMS error %.3g",
922+
best_model.unweighted_rms_error,
923+
)
924+
return (
925+
best_model.pole_residue,
926+
best_model.rms_error,
927+
)
928+
929+
# if exited loop, did not reach tolerance (warn)
930+
log.warning(
931+
"Unable to fit with weighted RMS error under 'tolerance_rms' of %.3g", tolerance_rms
932+
)
933+
log.info("Returning best fit with weighted RMS error %.3g", best_model.rms_error)
934+
if best_model.show_unweighted_rms:
935+
log.info(
936+
"Unweighted RMS error %.3g",
937+
best_model.unweighted_rms_error,
938+
)
939+
940+
return (
941+
best_model.pole_residue,
942+
best_model.rms_error,
943+
)
944+
865945
with Progress(console=get_logging_console()) as progress:
866946
task = progress.add_task(
867947
f"Fitting to weighted RMS of {tolerance_rms}...",

0 commit comments

Comments
 (0)