Skip to content

Commit a69db92

Browse files
committed
Add Learner2D loss function 'thresholded_loss_factory'
1 parent 067444f commit a69db92

File tree

1 file changed

+58
-0
lines changed

1 file changed

+58
-0
lines changed

adaptive/learner/learner2D.py

+58
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,64 @@ def default_loss(ip: LinearNDInterpolator) -> np.ndarray:
231231
return losses
232232

233233

234+
def thresholded_loss_factory(
235+
lower_threshold: float | None = None,
236+
upper_threshold: float | None = None,
237+
priority_factor: float = 0.1,
238+
) -> Callable[[LinearNDInterpolator], np.ndarray]:
239+
"""
240+
Factory function to create a custom loss function that deprioritizes
241+
values above an upper threshold and below a lower threshold.
242+
243+
Parameters
244+
----------
245+
lower_threshold : float, optional
246+
The lower threshold for deprioritizing values. If None (default),
247+
there is no lower threshold.
248+
upper_threshold : float, optional
249+
The upper threshold for deprioritizing values. If None (default),
250+
there is no upper threshold.
251+
priority_factor : float, default: 0.1
252+
The factor by which the loss is multiplied for values outside
253+
the specified thresholds.
254+
255+
Returns
256+
-------
257+
custom_loss : Callable[[LinearNDInterpolator], np.ndarray]
258+
A custom loss function that can be used with Learner2D.
259+
"""
260+
261+
def custom_loss(ip: LinearNDInterpolator) -> np.ndarray:
262+
"""Loss function that deprioritizes values above an upper and lower threshold.
263+
264+
Parameters
265+
----------
266+
ip : `scipy.interpolate.LinearNDInterpolator` instance
267+
268+
Returns
269+
-------
270+
losses : numpy.ndarray
271+
Loss per triangle in ``ip.tri``.
272+
"""
273+
losses = default_loss(ip)
274+
275+
if lower_threshold is not None or upper_threshold is not None:
276+
values = ip.values
277+
if lower_threshold is not None:
278+
mask_lower = values < lower_threshold
279+
if mask_lower.any():
280+
losses[mask_lower] *= priority_factor
281+
282+
if upper_threshold is not None:
283+
mask_upper = values > upper_threshold
284+
if mask_upper.any():
285+
losses[mask_upper] *= priority_factor
286+
287+
return losses
288+
289+
return custom_loss
290+
291+
234292
def choose_point_in_triangle(triangle: np.ndarray, max_badness: int) -> np.ndarray:
235293
"""Choose a new point in inside a triangle.
236294

0 commit comments

Comments
 (0)