Skip to content

Commit 2bbec11

Browse files
Speed up layer refinement by caching and skip exterior structures
1 parent 35edc0b commit 2bbec11

File tree

2 files changed

+93
-55
lines changed

2 files changed

+93
-55
lines changed

tidy3d/components/grid/grid_spec.py

+45-12
Original file line numberDiff line numberDiff line change
@@ -1226,8 +1226,8 @@ def center_axis(self) -> float:
12261226

12271227
@cached_property
12281228
def _is_inplane_unbounded(self) -> bool:
1229-
"""Whether the layer is unbounded in inplane dimensions."""
1230-
return np.isinf(self.size[(self.axis + 1) % 3]) and np.isinf(self.size[(self.axis + 2) % 3])
1229+
"""Whether the layer is unbounded in any of the inplane dimensions."""
1230+
return np.isinf(self.size[(self.axis + 1) % 3]) or np.isinf(self.size[(self.axis + 2) % 3])
12311231

12321232
def _unpop_axis(self, ax_coord: float, plane_coord: Any) -> CoordinateOptional:
12331233
"""Combine coordinate along axis with identical coordinates on the plane tangential to the axis.
@@ -1318,7 +1318,14 @@ def _corners(self, structure_list: List[Structure]) -> List[CoordinateOptional]:
13181318
if self.corner_finder is None:
13191319
return []
13201320

1321-
inplane_points = self.corner_finder.corners(self.axis, self.center_axis, structure_list)
1321+
# filter structures outside the layer
1322+
structures_intersect = structure_list
1323+
if not self._is_inplane_unbounded:
1324+
structures_intersect = [s for s in structure_list if self.intersects(s.geometry)]
1325+
inplane_points = self.corner_finder.corners(
1326+
self.axis, self.center_axis, structures_intersect
1327+
)
1328+
13221329
# filter corners outside the inplane bounds
13231330
if not self._is_inplane_unbounded:
13241331
inplane_points = [point for point in inplane_points if self._inplane_inside(point)]
@@ -1625,22 +1632,30 @@ def internal_snapping_points(self, structures: List[Structure]) -> List[Coordina
16251632
snapping_points += layer_spec.generate_snapping_points(list(structures))
16261633
return snapping_points
16271634

1628-
def all_snapping_points(self, structures: List[Structure]) -> List[CoordinateOptional]:
1635+
def all_snapping_points(
1636+
self,
1637+
structures: List[Structure],
1638+
internal_snapping_points: List[CoordinateOptional] = None,
1639+
) -> List[CoordinateOptional]:
16291640
"""Internal and external snapping points. External snapping points take higher priority.
16301641
So far, internal snapping points are generated by `layer_refinement_specs`.
16311642
16321643
Parameters
16331644
----------
16341645
structures : List[Structure]
16351646
List of physical structures.
1647+
internal_snapping_points : List[CoordinateOptional]
1648+
If `None`, recomputes internal snapping points.
16361649
16371650
Returns
16381651
-------
16391652
List[CoordinateOptional]
16401653
List of snapping points coordinates.
16411654
"""
16421655

1643-
return self.internal_snapping_points(structures) + list(self.snapping_points)
1656+
if internal_snapping_points is None:
1657+
return self.internal_snapping_points(structures) + list(self.snapping_points)
1658+
return internal_snapping_points + list(self.snapping_points)
16441659

16451660
@property
16461661
def external_override_structures(self) -> List[StructureType]:
@@ -1677,7 +1692,11 @@ def internal_override_structures(
16771692
return override_structures
16781693

16791694
def all_override_structures(
1680-
self, structures: List[Structure], wavelength: pd.PositiveFloat, sim_size: Tuple[float, 3]
1695+
self,
1696+
structures: List[Structure],
1697+
wavelength: pd.PositiveFloat,
1698+
sim_size: Tuple[float, 3],
1699+
internal_override_structures: List[MeshOverrideStructure] = None,
16811700
) -> List[StructureType]:
16821701
"""Internal and external mesh override structures. External override structures take higher priority.
16831702
So far, internal override structures all come from `layer_refinement_specs`.
@@ -1688,17 +1707,22 @@ def all_override_structures(
16881707
List of structures, with the simulation structure being the first item.
16891708
wavelength : pd.PositiveFloat
16901709
Wavelength to use for minimal step size in vaccum.
1710+
internal_override_structures : List[MeshOverrideStructure]
1711+
If `None`, recomputes internal override structures.
16911712
16921713
Returns
16931714
-------
16941715
List[StructureType]
16951716
List of override structures.
16961717
"""
16971718

1698-
return (
1699-
self.internal_override_structures(structures, wavelength, sim_size)
1700-
+ self.external_override_structures
1701-
)
1719+
if internal_override_structures is None:
1720+
return (
1721+
self.internal_override_structures(structures, wavelength, sim_size)
1722+
+ self.external_override_structures
1723+
)
1724+
1725+
return internal_override_structures + self.external_override_structures
17021726

17031727
def _min_vacuum_dl_in_autogrid(self, wavelength: float, sim_size: Tuple[float, 3]) -> float:
17041728
"""Compute grid step size in vacuum for Autogrd. If AutoGrid is applied along more than 1 dimension,
@@ -1757,6 +1781,8 @@ def make_grid(
17571781
periodic: Tuple[bool, bool, bool],
17581782
sources: List[SourceType],
17591783
num_pml_layers: List[Tuple[pd.NonNegativeInt, pd.NonNegativeInt]],
1784+
internal_override_structures: List[MeshOverrideStructure] = None,
1785+
internal_snapping_points: List[CoordinateOptional] = None,
17601786
) -> Grid:
17611787
"""Make the entire simulation grid based on some simulation parameters.
17621788
@@ -1772,6 +1798,10 @@ def make_grid(
17721798
List of sources.
17731799
num_pml_layers : List[Tuple[float, float]]
17741800
List containing the number of absorber layers in - and + boundaries.
1801+
internal_override_structures : List[MeshOverrideStructure]
1802+
If `None`, recomputes internal override structures.
1803+
internal_snapping_points : List[CoordinateOptional]
1804+
If `None`, recomputes internal snapping points.
17751805
17761806
Returns
17771807
-------
@@ -1830,7 +1860,10 @@ def make_grid(
18301860

18311861
sim_size = list(structures[0].geometry.size)
18321862
all_structures = list(structures) + self.all_override_structures(
1833-
list(structures), wavelength, sim_size
1863+
list(structures),
1864+
wavelength,
1865+
sim_size,
1866+
internal_override_structures,
18341867
)
18351868

18361869
# apply internal `dl_min` if any AutoGrid has unset `dl_min`
@@ -1856,7 +1889,7 @@ def make_grid(
18561889
periodic=periodic[idim],
18571890
wavelength=wavelength,
18581891
num_pml_layers=num_pml_layers[idim],
1859-
snapping_points=self.all_snapping_points(structures),
1892+
snapping_points=self.all_snapping_points(structures, internal_snapping_points),
18601893
)
18611894

18621895
coords = Coords(**coords_dict)

tidy3d/components/simulation.py

+48-43
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,17 @@
9898
from .source.utils import SourceType
9999
from .structure import MeshOverrideStructure, Structure
100100
from .subpixel_spec import SubpixelSpec
101-
from .types import TYPE_TAG_STR, Ax, Axis, FreqBound, InterpMethod, Literal, Symmetry, annotate_type
101+
from .types import (
102+
TYPE_TAG_STR,
103+
Ax,
104+
Axis,
105+
CoordinateOptional,
106+
FreqBound,
107+
InterpMethod,
108+
Literal,
109+
Symmetry,
110+
annotate_type,
111+
)
102112
from .validators import (
103113
assert_objects_in_sim_bounds,
104114
validate_mode_objects_symmetry,
@@ -722,6 +732,31 @@ def pml_thicknesses(self) -> List[Tuple[float, float]]:
722732

723733
return pml_thicknesses
724734

735+
@cached_property
736+
def internal_override_structures(self) -> List[MeshOverrideStructure]:
737+
"""Internal mesh override structures. So far, internal override structures all come from `layer_refinement_specs`.
738+
739+
Returns
740+
-------
741+
List[MeshOverrideSructure]
742+
List of override structures.
743+
"""
744+
wavelength = self.grid_spec.get_wavelength(self.sources)
745+
return self.grid_spec.internal_override_structures(
746+
self.scene.all_structures, wavelength, self.geometry.size
747+
)
748+
749+
@cached_property
750+
def internal_snapping_points(self) -> List[CoordinateOptional]:
751+
"""Internal snapping points. So far, internal snapping points are generated by `layer_refinement_specs`.
752+
753+
Returns
754+
-------
755+
List[CoordinateOptional]
756+
List of snapping points coordinates.
757+
"""
758+
return self.grid_spec.internal_snapping_points(self.scene.all_structures)
759+
725760
@equal_aspect
726761
@add_ax_if_none
727762
def plot_lumped_elements(
@@ -842,12 +877,8 @@ def plot_grid(
842877
plot_params[0] = plot_params[0].include_kwargs(edgecolor=kwargs["colors_internal"])
843878

844879
if self.grid_spec.auto_grid_used:
845-
wavelength = self.grid_spec.get_wavelength(self.sources)
846-
internal_override_structures = self.grid_spec.internal_override_structures(
847-
self.scene.all_structures, wavelength, self.geometry.size
848-
)
849880
all_override_structures = [
850-
internal_override_structures,
881+
self.internal_override_structures,
851882
self.grid_spec.external_override_structures,
852883
]
853884
for structures, plot_param in zip(all_override_structures, plot_params):
@@ -867,11 +898,8 @@ def plot_grid(
867898
ax.add_patch(rect)
868899

869900
# Plot snapping points
870-
internal_snapping_points = self.grid_spec.internal_snapping_points(
871-
self.scene.all_structures
872-
)
873901
for points, plot_param in zip(
874-
[internal_snapping_points, self.grid_spec.snapping_points], plot_params
902+
[self.internal_snapping_points, self.grid_spec.snapping_points], plot_params
875903
):
876904
for point in points:
877905
_, (x_point, y_point) = Geometry.pop_axis(point, axis=axis)
@@ -1056,20 +1084,27 @@ def grid(self) -> Grid:
10561084

10571085
# Add a simulation Box as the first structure
10581086
structures = [Structure(geometry=self.geometry, medium=self.medium)]
1059-
structures += self.structures
1087+
structures += self.static_structures
10601088

10611089
grid = self.grid_spec.make_grid(
10621090
structures=structures,
10631091
symmetry=self.symmetry,
10641092
periodic=self._periodic,
10651093
sources=self.sources,
10661094
num_pml_layers=self.num_pml_layers,
1095+
internal_snapping_points=self.internal_snapping_points,
1096+
internal_override_structures=self.internal_override_structures,
10671097
)
10681098

10691099
# This would AutoGrid the in-plane directions of the 2D materials
10701100
# return self._grid_corrections_2dmaterials(grid)
10711101
return grid
10721102

1103+
@cached_property
1104+
def static_structures(self) -> list[Structure]:
1105+
"""Structures in simulation with all autograd tracers removed."""
1106+
return [structure.to_static() for structure in self.structures]
1107+
10731108
@cached_property
10741109
def num_cells(self) -> int:
10751110
"""Number of cells in the simulation.
@@ -3774,11 +3809,6 @@ def _validate_time_monitors_num_steps(self) -> None:
37743809
"at which the monitor stores data."
37753810
)
37763811

3777-
@cached_property
3778-
def static_structures(self) -> list[Structure]:
3779-
"""Structures in simulation with all autograd tracers removed."""
3780-
return [structure.to_static() for structure in self.structures]
3781-
37823812
@cached_property
37833813
def monitors_data_size(self) -> Dict[str, float]:
37843814
"""Dictionary mapping monitor names to their estimated storage size in bytes."""
@@ -4550,6 +4580,8 @@ def _grid_corrections_2dmaterials(self, grid: Grid) -> Grid:
45504580
symmetry=self.symmetry,
45514581
sources=self.sources,
45524582
num_pml_layers=self.num_pml_layers,
4583+
internal_snapping_points=self.internal_snapping_points,
4584+
internal_override_structures=self.internal_override_structures,
45534585
)
45544586

45554587
# Handle 2D materials if ``AutoGrid`` is used for in-plane directions
@@ -4587,33 +4619,6 @@ def _grid_corrections_2dmaterials(self, grid: Grid) -> Grid:
45874619

45884620
return Grid(boundaries=Coords(**dict(zip("xyz", coords_all))))
45894621

4590-
@cached_property
4591-
def grid(self) -> Grid:
4592-
"""FDTD grid spatial locations and information.
4593-
4594-
Returns
4595-
-------
4596-
:class:`.Grid`
4597-
:class:`.Grid` storing the spatial locations relevant to the simulation.
4598-
"""
4599-
4600-
# Add a simulation Box as the first structure
4601-
structures = [Structure(geometry=self.geometry, medium=self.medium)]
4602-
4603-
structures += self.static_structures
4604-
4605-
grid = self.grid_spec.make_grid(
4606-
structures=structures,
4607-
symmetry=self.symmetry,
4608-
periodic=self._periodic,
4609-
sources=self.sources,
4610-
num_pml_layers=self.num_pml_layers,
4611-
)
4612-
4613-
# This would AutoGrid the in-plane directions of the 2D materials
4614-
# return self._grid_corrections_2dmaterials(grid)
4615-
return grid
4616-
46174622
@cached_property
46184623
def num_cells(self) -> int:
46194624
"""Number of cells in the simulation grid.

0 commit comments

Comments
 (0)