Skip to content

Commit dbabe42

Browse files
committed
Add option to use dispersion fitter without rich.progress
1 parent 4a25807 commit dbabe42

File tree

4 files changed

+169
-86
lines changed

4 files changed

+169
-86
lines changed

tests/test_plugins/test_dispersion_fitter.py

+1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import pytest
55
import responses
66
import tidy3d as td
7+
from tidy3d.components.dispersion_fitter import fit
78
from tidy3d.exceptions import SetupError, ValidationError
89
from tidy3d.plugins.dispersion import (
910
AdvancedFastFitterParam,

tidy3d/components/dispersion_fitter.py

+120-72
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,10 @@
77
import numpy as np
88
import scipy
99
from pydantic.v1 import Field, NonNegativeFloat, PositiveFloat, PositiveInt, validator
10-
from rich.progress import Progress
1110

1211
from ..constants import fp_eps
1312
from ..exceptions import ValidationError
14-
from ..log import get_logging_console, log
13+
from ..log import Progress, get_logging_console, log
1514
from .base import Tidy3dBaseModel, cached_property, skip_if_fields_missing
1615
from .types import ArrayComplex1D, ArrayComplex2D, ArrayFloat1D, ArrayFloat2D
1716

@@ -823,7 +822,6 @@ def fit(
823822
The dispersive medium parameters have the form (resp_inf, poles, residues)
824823
and are in the original unscaled units.
825824
"""
826-
827825
if max_num_poles < min_num_poles:
828826
raise ValidationError(
829827
"Dispersion fitter cannot have 'max_num_poles' less than 'min_num_poles'."
@@ -864,86 +862,82 @@ def make_configs():
864862

865863
with Progress(console=get_logging_console()) as progress:
866864
task = progress.add_task(
867-
f"Fitting to weighted RMS of {tolerance_rms}...",
865+
description=f"Fitting to weighted RMS of {tolerance_rms}...",
868866
total=len(configs),
869867
visible=init_model.show_progress,
870868
)
871869

872-
while not progress.finished:
873-
# try different initial pole configurations
874-
for num_poles, relaxed, smooth, logspacing, optimize_eps_inf in configs:
875-
model = init_model.updated_copy(
876-
num_poles=num_poles,
877-
relaxed=relaxed,
878-
smooth=smooth,
879-
logspacing=logspacing,
880-
optimize_eps_inf=optimize_eps_inf,
870+
# try different initial pole configurations
871+
for num_poles, relaxed, smooth, logspacing, optimize_eps_inf in configs:
872+
model = init_model.updated_copy(
873+
num_poles=num_poles,
874+
relaxed=relaxed,
875+
smooth=smooth,
876+
logspacing=logspacing,
877+
optimize_eps_inf=optimize_eps_inf,
878+
)
879+
model = _fit_fixed_parameters((min_num_poles, max_num_poles), model)
880+
881+
if model.rms_error < best_model.rms_error:
882+
log.debug(
883+
f"Fitter: possible improved fit with "
884+
f"rms_error={model.rms_error:.3g} found using "
885+
f"relaxed={model.relaxed}, "
886+
f"smooth={model.smooth}, "
887+
f"logspacing={model.logspacing}, "
888+
f"optimize_eps_inf={model.optimize_eps_inf}, "
889+
f"loss_in_bounds={model.loss_in_bounds}, "
890+
f"passivity_optimized={model.passivity_optimized}, "
891+
f"sellmeier_passivity={model.sellmeier_passivity}."
881892
)
882-
model = _fit_fixed_parameters((min_num_poles, max_num_poles), model)
883-
884-
if model.rms_error < best_model.rms_error:
885-
log.debug(
886-
f"Fitter: possible improved fit with "
887-
f"rms_error={model.rms_error:.3g} found using "
888-
f"relaxed={model.relaxed}, "
889-
f"smooth={model.smooth}, "
890-
f"logspacing={model.logspacing}, "
891-
f"optimize_eps_inf={model.optimize_eps_inf}, "
892-
f"loss_in_bounds={model.loss_in_bounds}, "
893-
f"passivity_optimized={model.passivity_optimized}, "
894-
f"sellmeier_passivity={model.sellmeier_passivity}."
895-
)
896-
if model.loss_in_bounds and model.sellmeier_passivity:
897-
best_model = model
898-
else:
899-
if (
900-
not warned_about_passivity_num_iters
901-
and model.passivity_num_iters_too_small
902-
):
903-
warned_about_passivity_num_iters = True
904-
log.warning(
905-
"Did not finish enforcing passivity in dispersion fitter. "
906-
"If the fit is not good enough, consider increasing "
907-
"'AdvancedFastFitterParam.passivity_num_iters'."
908-
)
909-
if (
910-
not warned_about_slsqp_constraint_scale
911-
and model.slsqp_constraint_scale_too_small
912-
):
913-
warned_about_slsqp_constraint_scale = True
914-
log.warning(
915-
"SLSQP constraint scale may be too small. "
916-
"If the fit is not good enough, consider increasing "
917-
"'AdvancedFastFitterParam.slsqp_constraint_scale'."
918-
)
893+
if model.loss_in_bounds and model.sellmeier_passivity:
894+
best_model = model
895+
else:
896+
if not warned_about_passivity_num_iters and model.passivity_num_iters_too_small:
897+
warned_about_passivity_num_iters = True
898+
log.warning(
899+
"Did not finish enforcing passivity in dispersion fitter. "
900+
"If the fit is not good enough, consider increasing "
901+
"'AdvancedFastFitterParam.passivity_num_iters'."
902+
)
903+
if (
904+
not warned_about_slsqp_constraint_scale
905+
and model.slsqp_constraint_scale_too_small
906+
):
907+
warned_about_slsqp_constraint_scale = True
908+
log.warning(
909+
"SLSQP constraint scale may be too small. "
910+
"If the fit is not good enough, consider increasing "
911+
"'AdvancedFastFitterParam.slsqp_constraint_scale'."
912+
)
913+
progress.update(
914+
task,
915+
advance=1,
916+
description=f"Best weighted RMS error so far: {best_model.rms_error:.3g}",
917+
refresh=True,
918+
)
919+
920+
# if below tolerance, return
921+
if best_model.rms_error < tolerance_rms:
919922
progress.update(
920923
task,
921-
advance=1,
922-
description=f"Best weighted RMS error so far: {best_model.rms_error:.3g}",
924+
completed=len(configs),
925+
description=f"Best weighted RMS error: {best_model.rms_error:.3g}",
923926
refresh=True,
924927
)
925-
926-
# if below tolerance, return
927-
if best_model.rms_error < tolerance_rms:
928-
progress.update(
929-
task,
930-
completed=len(configs),
931-
description=f"Best weighted RMS error: {best_model.rms_error:.3g}",
932-
refresh=True,
933-
)
928+
log.info(
929+
"Found optimal fit with weighted RMS error %.3g",
930+
best_model.rms_error,
931+
)
932+
if best_model.show_unweighted_rms:
934933
log.info(
935-
"Found optimal fit with weighted RMS error %.3g",
936-
best_model.rms_error,
937-
)
938-
if best_model.show_unweighted_rms:
939-
log.info(
940-
"Unweighted RMS error %.3g",
941-
best_model.unweighted_rms_error,
942-
)
943-
return (
944-
best_model.pole_residue,
945-
best_model.rms_error,
934+
"Unweighted RMS error %.3g",
935+
best_model.unweighted_rms_error,
946936
)
937+
return (
938+
best_model.pole_residue,
939+
best_model.rms_error,
940+
)
947941

948942
# if exited loop, did not reach tolerance (warn)
949943
progress.update(
@@ -967,3 +961,57 @@ def make_configs():
967961
best_model.pole_residue,
968962
best_model.rms_error,
969963
)
964+
965+
966+
def constant_loss_tangent_model(
967+
eps_real: float,
968+
loss_tangent: float,
969+
frequency_range: Tuple[float, float],
970+
max_num_poles: PositiveInt = DEFAULT_MAX_POLES,
971+
number_sampling_frequency: PositiveInt = 10,
972+
tolerance_rms: NonNegativeFloat = DEFAULT_TOLERANCE_RMS,
973+
scale_factor: float = 1,
974+
) -> Tuple[Tuple[float, ArrayComplex1D, ArrayComplex1D], float]:
975+
"""Fit a constant loss tangent material model.
976+
977+
Parameters
978+
----------
979+
eps_real : float
980+
Real part of permittivity
981+
loss_tangent : float
982+
Loss tangent.
983+
frequency_range : Tuple[float, float]
984+
Freqquency range for the material to exhibit constant loss tangent response.
985+
max_num_poles : PositiveInt, optional
986+
Maximum number of poles in the model.
987+
number_sampling_frequency : PositiveInt, optional
988+
Number of sampling frequencies to compute RMS error for fitting.
989+
tolerance_rms : float, optional
990+
Weighted RMS error below which the fit is successful and the result is returned.
991+
scale_factor : PositiveFloat, optional
992+
Factor to rescale frequency by before fitting.
993+
994+
Returns
995+
-------
996+
Tuple[Tuple[float, ArrayComplex1D, ArrayComplex1D], float]
997+
Best fitting result: (dispersive medium parameters, weighted RMS error).
998+
The dispersive medium parameters have the form (resp_inf, poles, residues)
999+
and are in the original unscaled units.
1000+
"""
1001+
if number_sampling_frequency < 2:
1002+
frequencies = np.array([np.mean(frequency_range)])
1003+
else:
1004+
frequencies = np.linspace(frequency_range[0], frequency_range[1], number_sampling_frequency)
1005+
eps_real_array = np.ones_like(frequencies) * eps_real
1006+
loss_tangent_array = np.ones_like(frequencies) * loss_tangent
1007+
1008+
omega_data = frequencies * 2 * np.pi
1009+
eps_complex = eps_real_array * (1 + 1j * loss_tangent_array)
1010+
1011+
return fit(
1012+
omega_data=omega_data,
1013+
resp_data=eps_complex,
1014+
max_num_poles=max_num_poles,
1015+
tolerance_rms=tolerance_rms,
1016+
scale_factor=scale_factor,
1017+
)

tidy3d/log.py

+27
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Logging for Tidy3d."""
22

33
import inspect
4+
from contextlib import contextmanager
45
from datetime import datetime
56
from typing import Callable, List, Tuple, Union
67

@@ -442,3 +443,29 @@ def get_logging_console() -> Console:
442443
if "console" not in log.handlers:
443444
set_logging_console()
444445
return log.handlers["console"].console
446+
447+
448+
class NoOpProgress:
449+
def __enter__(self):
450+
return self
451+
452+
def __exit__(self, *args, **kwargs):
453+
pass
454+
455+
def add_task(self, *args, **kwargs):
456+
pass
457+
458+
def update(self, *args, **kwargs):
459+
pass
460+
461+
462+
@contextmanager
463+
def Progress(console):
464+
try:
465+
from rich.progress import Progress
466+
467+
with Progress(console=console) as progress:
468+
yield progress
469+
except ImportError:
470+
with NoOpProgress() as progress:
471+
yield progress

tidy3d/plugins/dispersion/fit_fast.py

+21-14
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,13 @@
77
import numpy as np
88
from pydantic.v1 import NonNegativeFloat, PositiveInt
99

10-
from ...components.dispersion_fitter import AdvancedFastFitterParam, fit
10+
from ...components.dispersion_fitter import (
11+
AdvancedFastFitterParam,
12+
constant_loss_tangent_model,
13+
fit,
14+
)
1115
from ...components.medium import PoleResidue
12-
from ...constants import C_0, HBAR
16+
from ...constants import HBAR
1317
from .fit import DispersionFitter
1418

1519
# numerical tolerance for pole relocation for fast fitter
@@ -144,15 +148,18 @@ def constant_loss_tangent_model(
144148
:class:`.PoleResidue
145149
Best results of multiple fits.
146150
"""
147-
if number_sampling_frequency < 2:
148-
frequencies = np.array([np.mean(frequency_range)])
149-
else:
150-
frequencies = np.linspace(
151-
frequency_range[0], frequency_range[1], number_sampling_frequency
152-
)
153-
wvl_um = C_0 / frequencies
154-
eps_real_array = np.ones_like(frequencies) * eps_real
155-
loss_tangent_array = np.ones_like(frequencies) * loss_tangent
156-
fitter = cls.from_loss_tangent(wvl_um, eps_real_array, loss_tangent_array)
157-
material, _ = fitter.fit(max_num_poles=max_num_poles, tolerance_rms=tolerance_rms)
158-
return material
151+
params, _ = constant_loss_tangent_model(
152+
eps_real=eps_real,
153+
loss_tangent=loss_tangent,
154+
frequency_range=frequency_range,
155+
max_num_poles=max_num_poles,
156+
number_sampling_frequency=number_sampling_frequency,
157+
tolerance_rms=tolerance_rms,
158+
scale_factor=HBAR,
159+
)
160+
161+
eps_inf, poles, residues = params
162+
163+
medium = PoleResidue(eps_inf=eps_inf, poles=list(zip(poles, residues)))
164+
165+
return medium

0 commit comments

Comments
 (0)