diff --git a/adaptive/learner/learner2D.py b/adaptive/learner/learner2D.py index a2aec2069..1ea381794 100644 --- a/adaptive/learner/learner2D.py +++ b/adaptive/learner/learner2D.py @@ -231,6 +231,65 @@ def default_loss(ip: LinearNDInterpolator) -> np.ndarray: return losses +def thresholded_loss_function( + lower_threshold: float | None = None, + upper_threshold: float | None = None, + priority_factor: float = 0.1, +) -> Callable[[LinearNDInterpolator], np.ndarray]: + """ + Factory function to create a custom loss function that deprioritizes + values above an upper threshold and below a lower threshold. + + Parameters + ---------- + lower_threshold : float, optional + The lower threshold for deprioritizing values. If None (default), + there is no lower threshold. + upper_threshold : float, optional + The upper threshold for deprioritizing values. If None (default), + there is no upper threshold. + priority_factor : float, default: 0.1 + The factor by which the loss is multiplied for values outside + the specified thresholds. + + Returns + ------- + custom_loss : Callable[[LinearNDInterpolator], np.ndarray] + A custom loss function that can be used with Learner2D. + """ + + def custom_loss(ip: LinearNDInterpolator) -> np.ndarray: + """Loss function that deprioritizes values outside an upper and lower threshold. + + Parameters + ---------- + ip : `scipy.interpolate.LinearNDInterpolator` instance + + Returns + ------- + losses : numpy.ndarray + Loss per triangle in ``ip.tri``. + """ + losses = default_loss(ip) + + if lower_threshold is not None or upper_threshold is not None: + simplices = ip.tri.simplices + values = ip.values[simplices] + if lower_threshold is not None: + mask_lower = (values < lower_threshold).all(axis=(1, -1)) + if mask_lower.any(): + losses[mask_lower] *= priority_factor + + if upper_threshold is not None: + mask_upper = (values > upper_threshold).all(axis=(1, -1)) + if mask_upper.any(): + losses[mask_upper] *= priority_factor + + return losses + + return custom_loss + + def choose_point_in_triangle(triangle: np.ndarray, max_badness: int) -> np.ndarray: """Choose a new point in inside a triangle. diff --git a/adaptive/tests/test_learners.py b/adaptive/tests/test_learners.py index 17af7f9b8..e32aa75ef 100644 --- a/adaptive/tests/test_learners.py +++ b/adaptive/tests/test_learners.py @@ -53,6 +53,7 @@ adaptive.learner.learner2D.uniform_loss, adaptive.learner.learner2D.minimize_triangle_surface_loss, adaptive.learner.learner2D.resolution_loss_function(), + adaptive.learner.learner2D.thresholded_loss_function(upper_threshold=0.5), ), ), LearnerND: (