@@ -231,6 +231,64 @@ def default_loss(ip: LinearNDInterpolator) -> np.ndarray:
231
231
return losses
232
232
233
233
234
+ def thresholded_loss_factory (
235
+ lower_threshold : float | None = None ,
236
+ upper_threshold : float | None = None ,
237
+ priority_factor : float = 0.1 ,
238
+ ) -> Callable [[LinearNDInterpolator ], np .ndarray ]:
239
+ """
240
+ Factory function to create a custom loss function that deprioritizes
241
+ values above an upper threshold and below a lower threshold.
242
+
243
+ Parameters
244
+ ----------
245
+ lower_threshold : float, optional
246
+ The lower threshold for deprioritizing values. If None (default),
247
+ there is no lower threshold.
248
+ upper_threshold : float, optional
249
+ The upper threshold for deprioritizing values. If None (default),
250
+ there is no upper threshold.
251
+ priority_factor : float, default: 0.1
252
+ The factor by which the loss is multiplied for values outside
253
+ the specified thresholds.
254
+
255
+ Returns
256
+ -------
257
+ custom_loss : Callable[[LinearNDInterpolator], np.ndarray]
258
+ A custom loss function that can be used with Learner2D.
259
+ """
260
+
261
+ def custom_loss (ip : LinearNDInterpolator ) -> np .ndarray :
262
+ """Loss function that deprioritizes values above an upper and lower threshold.
263
+
264
+ Parameters
265
+ ----------
266
+ ip : `scipy.interpolate.LinearNDInterpolator` instance
267
+
268
+ Returns
269
+ -------
270
+ losses : numpy.ndarray
271
+ Loss per triangle in ``ip.tri``.
272
+ """
273
+ losses = default_loss (ip )
274
+
275
+ if lower_threshold is not None or upper_threshold is not None :
276
+ values = ip .values
277
+ if lower_threshold is not None :
278
+ mask_lower = values < lower_threshold
279
+ if mask_lower .any ():
280
+ losses [mask_lower ] *= priority_factor
281
+
282
+ if upper_threshold is not None :
283
+ mask_upper = values > upper_threshold
284
+ if mask_upper .any ():
285
+ losses [mask_upper ] *= priority_factor
286
+
287
+ return losses
288
+
289
+ return custom_loss
290
+
291
+
234
292
def choose_point_in_triangle (triangle : np .ndarray , max_badness : int ) -> np .ndarray :
235
293
"""Choose a new point in inside a triangle.
236
294
0 commit comments