diff --git a/CHANGELOG.md b/CHANGELOG.md index 90aa086ca..b96efbcc1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -32,6 +32,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/). ### Added +- ENH: Function Validation Rework & Swap `np.searchsorted` to `bisect_left` [#582](https://github.com/RocketPy-Team/RocketPy/pull/582) - ENH: Add new stability margin properties to Flight class [#572](https://github.com/RocketPy-Team/RocketPy/pull/572) - ENH: adds `Function.remove_outliers` method [#554](https://github.com/RocketPy-Team/RocketPy/pull/554) diff --git a/rocketpy/mathutils/function.py b/rocketpy/mathutils/function.py index c887a904a..ea98bff78 100644 --- a/rocketpy/mathutils/function.py +++ b/rocketpy/mathutils/function.py @@ -5,6 +5,7 @@ """ import warnings +from bisect import bisect_left from collections.abc import Iterable from copy import deepcopy from inspect import signature @@ -20,6 +21,14 @@ from ..tools import cached_property NUMERICAL_TYPES = (float, int, complex, np.ndarray, np.integer, np.floating) +INTERPOLATION_TYPES = { + "linear": 0, + "polynomial": 1, + "akima": 2, + "spline": 3, + "shepard": 4, +} +EXTRAPOLATION_TYPES = {"zero": 0, "natural": 1, "constant": 2} class Function: @@ -110,27 +119,20 @@ def __init__( (II) Fields in CSV files may be enclosed in double quotes. If fields are not quoted, double quotes should not appear inside them. """ - - inputs, outputs = Function._validate_inputs_outputs(inputs, outputs) - - # initialize variables to avoid errors when being called by other methods - self.get_value_opt = None - self.__polynomial_coefficients__ = None - self.__akima_coefficients__ = None - self.__spline_coefficients__ = None - self.EXTRAPOLATION_TYPES = { - "zero": 0, - "natural": 1, - "constant": 2, - } - - # store variables - self.set_inputs(inputs) - self.set_outputs(outputs) + # initialize parameters + self.source = source + self.__inputs__ = inputs + self.__outputs__ = outputs self.__interpolation__ = interpolation self.__extrapolation__ = extrapolation - self.set_source(source) - self.set_title(title) + self.title = title + self.__img_dim__ = 1 # always 1, here for backwards compatibility + + # args must be passed from self. + self.set_source(self.source) + self.set_inputs(self.__inputs__) + self.set_outputs(self.__outputs__) + self.set_title(self.title) # Define all set methods def set_inputs(self, inputs): @@ -145,8 +147,7 @@ def set_inputs(self, inputs): ------- self : Function """ - self.__inputs__ = [inputs] if isinstance(inputs, str) else list(inputs) - self.__dom_dim__ = len(self.__inputs__) + self.__inputs__ = self.__validate_inputs(inputs) return self def set_outputs(self, outputs): @@ -161,8 +162,7 @@ def set_outputs(self, outputs): ------- self : Function """ - self.__outputs__ = [outputs] if isinstance(outputs, str) else list(outputs) - self.__img_dim__ = len(self.__outputs__) + self.__outputs__ = self.__validate_outputs(outputs) return self def set_source(self, source): @@ -211,20 +211,10 @@ def set_source(self, source): self : Function Returns the Function instance. """ - source, inputs, outputs, interpolation, extrapolation = self._check_user_input( - source, - self.__inputs__, - self.__outputs__, - self.__interpolation__, - self.__extrapolation__, - ) - # updates inputs and outputs (could be modified due to csv headers) - self.set_inputs(inputs) - self.set_outputs(outputs) + source = self.__validate_source(source) # Handle callable source or number source if callable(source): - self.source = source self.get_value_opt = source self.__interpolation__ = None self.__extrapolation__ = None @@ -232,38 +222,34 @@ def set_source(self, source): # Set arguments name and domain dimensions parameters = signature(source).parameters self.__dom_dim__ = len(parameters) - if self.__inputs__ == ["Scalar"]: + if self.__inputs__ is None: self.__inputs__ = list(parameters) # Handle ndarray source else: - # Check to see if dimensions match incoming data set - new_total_dim = len(source[0, :]) - old_total_dim = self.__dom_dim__ + self.__img_dim__ - - # If they don't, update default values or throw error - if new_total_dim != old_total_dim: - # Update dimensions and inputs - self.__dom_dim__ = new_total_dim - 1 - self.__inputs__ = self.__dom_dim__ * self.__inputs__ + # Evaluate dimension + self.__dom_dim__ = source.shape[1] - 1 - # if Function is 1D, sort source by x. If 2D, set z + # set x and y. If Function is 2D, also set z if self.__dom_dim__ == 1: source = source[source[:, 0].argsort()] - elif self.__dom_dim__ == 2: + self.x_array = source[:, 0] + self.x_initial, self.x_final = self.x_array[0], self.x_array[-1] + self.y_array = source[:, 1] + self.y_initial, self.y_final = self.y_array[0], self.y_array[-1] + self.get_value_opt = self.__get_value_opt_1d + elif self.__dom_dim__ > 1: + self.x_array = source[:, 0] + self.x_initial, self.x_final = self.x_array[0], self.x_array[-1] + self.y_array = source[:, 1] + self.y_initial, self.y_final = self.y_array[0], self.y_array[-1] self.z_array = source[:, 2] self.z_initial, self.z_final = self.z_array[0], self.z_array[-1] + self.get_value_opt = self.__get_value_opt_nd - # Set x and y arrays (common for 1D or multivariate) - self.x_array = source[:, 0] - self.x_initial, self.x_final = self.x_array[0], self.x_array[-1] - self.y_array = source[:, 1] - self.y_initial, self.y_final = self.y_array[0], self.y_array[-1] - - # Finally set source, update extrapolation and interpolation - self.source = source - self.__extrapolation__ = extrapolation # to avoid calling set_get_value_opt - self.set_interpolation(interpolation) + self.source = source + self.set_interpolation(self.__interpolation__) + self.set_extrapolation(self.__extrapolation__) return self @cached_property @@ -303,18 +289,27 @@ def set_interpolation(self, method="spline"): ------- self : Function """ - self.__interpolation__ = method + if not callable(self.source): + self.__interpolation__ = self.__validate_interpolation(method) + self.__update_interpolation_coefficients(self.__interpolation__) + self.__set_interpolation_func() + return self + + def __update_interpolation_coefficients(self, method): + """Update interpolation coefficients for the given method.""" # Spline, akima and polynomial need data processing # Shepard, and linear do not - if method == "spline": - self.__interpolate_spline__() - elif method == "polynomial": + if method == "polynomial": self.__interpolate_polynomial__() + self._coeffs = self.__polynomial_coefficients__ elif method == "akima": self.__interpolate_akima__() - - self.set_get_value_opt() - return self + self._coeffs = self.__akima_coefficients__ + elif method == "spline" or method is None: + self.__interpolate_spline__() + self._coeffs = self.__spline_coefficients__ + else: + self._coeffs = [] def set_extrapolation(self, method="constant"): """Set extrapolation behavior of data set. @@ -333,141 +328,163 @@ def set_extrapolation(self, method="constant"): self : Function The Function object. """ - self.__extrapolation__ = method - self.set_get_value_opt() + if not callable(self.source): + self.__extrapolation__ = self.__validate_extrapolation(method) + self.__set_extrapolation_func() return self + def __set_interpolation_func(self): + """Defines interpolation function used by the Function. Each + interpolation method has its own function with exception of shepard, + which has its interpolation/extrapolation function defined in + ``Function.__interpolate_shepard__``. The function is stored in + the attribute _interpolation_func.""" + interpolation = INTERPOLATION_TYPES[self.__interpolation__] + if interpolation == 0: # linear + + def linear_interpolation(x, x_min, x_max, x_data, y_data, coeffs): + x_interval = bisect_left(x_data, x) + x_left = x_data[x_interval - 1] + y_left = y_data[x_interval - 1] + dx = float(x_data[x_interval] - x_left) + dy = float(y_data[x_interval] - y_left) + return (x - x_left) * (dy / dx) + y_left + + self._interpolation_func = linear_interpolation + + elif interpolation == 1: # polynomial + + def polynomial_interpolation(x, x_min, x_max, x_data, y_data, coeffs): + return np.sum(coeffs * x ** np.arange(len(coeffs))) + + self._interpolation_func = polynomial_interpolation + + elif interpolation == 2: # akima + + def akima_interpolation(x, x_min, x_max, x_data, y_data, coeffs): + x_interval = bisect_left(x_data, x) + x_interval = x_interval if x_interval != 0 else 1 + a = coeffs[4 * x_interval - 4 : 4 * x_interval] + return a[3] * x**3 + a[2] * x**2 + a[1] * x + a[0] + + self._interpolation_func = akima_interpolation + + elif interpolation == 3: # spline + + def spline_interpolation(x, x_min, x_max, x_data, y_data, coeffs): + x_interval = bisect_left(x_data, x) + x_interval = max(x_interval, 1) + a = coeffs[:, x_interval - 1] + x = x - x_data[x_interval - 1] + return a[3] * x**3 + a[2] * x**2 + a[1] * x + a[0] + + self._interpolation_func = spline_interpolation + + elif interpolation == 4: # shepard does not use interpolation function + self._interpolation_func = None + + def __set_extrapolation_func(self): + """Defines extrapolation function used by the Function. Each + extrapolation method has its own function. The function is stored in + the attribute _extrapolation_func.""" + interpolation = INTERPOLATION_TYPES[self.__interpolation__] + extrapolation = EXTRAPOLATION_TYPES[self.__extrapolation__] + + if interpolation == 4: # shepard does not use extrapolation function + self._extrapolation_func = None + + elif extrapolation == 0: # zero + + def zero_extrapolation(x, x_min, x_max, x_data, y_data, coeffs): + return 0 + + self._extrapolation_func = zero_extrapolation + elif extrapolation == 1: # natural + if interpolation == 0: # linear + + def natural_extrapolation(x, x_min, x_max, x_data, y_data, coeffs): + x_interval = 1 if x < x_min else -1 + x_left = x_data[x_interval - 1] + y_left = y_data[x_interval - 1] + dx = float(x_data[x_interval] - x_left) + dy = float(y_data[x_interval] - y_left) + return (x - x_left) * (dy / dx) + y_left + + elif interpolation == 1: # polynomial + + def natural_extrapolation(x, x_min, x_max, x_data, y_data, coeffs): + return np.sum(coeffs * x ** np.arange(len(coeffs))) + + elif interpolation == 2: # akima + + def natural_extrapolation(x, x_min, x_max, x_data, y_data, coeffs): + a = coeffs[:4] if x < x_min else coeffs[-4:] + return a[3] * x**3 + a[2] * x**2 + a[1] * x + a[0] + + elif interpolation == 3: # spline + + def natural_extrapolation(x, x_min, x_max, x_data, y_data, coeffs): + if x < x_min: + a = coeffs[:, 0] + x = x - x_data[0] + else: + a = coeffs[:, -1] + x = x - x_data[-2] + return a[3] * x**3 + a[2] * x**2 + a[1] * x + a[0] + + self._extrapolation_func = natural_extrapolation + elif extrapolation == 2: # constant + + def constant_extrapolation(x, x_min, x_max, x_data, y_data, coeffs): + return y_data[0] if x < x_min else y_data[-1] + + self._extrapolation_func = constant_extrapolation + def set_get_value_opt(self): - """Crates a method that evaluates interpolations rather quickly - when compared to other options available, such as just calling - the object instance or calling ``Function.get_value`` directly. See - ``Function.get_value_opt`` for documentation. + """Defines a method that evaluates interpolations. Returns ------- self : Function """ + if callable(self.source): + self.get_value_opt = self.source + elif self.__dom_dim__ == 1: + self.get_value_opt = self.__get_value_opt_1d + elif self.__dom_dim__ > 1: + self.get_value_opt = self.__get_value_opt_nd + return self + + def __get_value_opt_1d(self, x): + """Evaluate the Function at a single point x. This method is used + when the Function is 1-D. + + Parameters + ---------- + x : scalar + Value where the Function is to be evaluated. + + Returns + ------- + y : scalar + Value of the Function at the specified point. + """ # Retrieve general info x_data = self.x_array y_data = self.y_array x_min, x_max = self.x_initial, self.x_final - try: - extrapolation = self.EXTRAPOLATION_TYPES[self.__extrapolation__] - except KeyError as err: - raise ValueError( - f"Invalid extrapolation type '{self.__extrapolation__}'" - ) from err - - # Crete method to interpolate this info for each interpolation type - if self.__interpolation__ == "spline": - coeffs = self.__spline_coefficients__ - - def get_value_opt(x): - if x_min <= x <= x_max: - # Interpolate - x_interval = np.searchsorted(x_data, x) - x_interval = max(x_interval, 1) - a = coeffs[:, x_interval - 1] - x = x - x_data[x_interval - 1] - y = a[3] * x**3 + a[2] * x**2 + a[1] * x + a[0] - else: - # Extrapolate - if extrapolation == 0: # Extrapolation == zero - y = 0 - elif extrapolation == 1: # Extrapolation == natural - if x < x_min: - a = coeffs[:, 0] - x = x - x_data[0] - else: - a = coeffs[:, -1] - x = x - x_data[-2] - y = a[3] * x**3 + a[2] * x**2 + a[1] * x + a[0] - else: # Extrapolation is set to constant - y = y_data[0] if x < x_min else y_data[-1] - return y - - elif self.__interpolation__ == "linear": - - def get_value_opt(x): - if x_min <= x <= x_max: - # Interpolate - x_interval = np.searchsorted(x_data, x) - x_left = x_data[x_interval - 1] - y_left = y_data[x_interval - 1] - dx = float(x_data[x_interval] - x_left) - dy = float(y_data[x_interval] - y_left) - y = (x - x_left) * (dy / dx) + y_left - else: - # Extrapolate - if extrapolation == 0: # Extrapolation == zero - y = 0 - elif extrapolation == 1: # Extrapolation == natural - x_interval = 1 if x < x_min else -1 - x_left = x_data[x_interval - 1] - y_left = y_data[x_interval - 1] - dx = float(x_data[x_interval] - x_left) - dy = float(y_data[x_interval] - y_left) - y = (x - x_left) * (dy / dx) + y_left - else: # Extrapolation is set to constant - y = y_data[0] if x < x_min else y_data[-1] - return y - - elif self.__interpolation__ == "akima": - coeffs = np.array(self.__akima_coefficients__) - - def get_value_opt(x): - if x_min <= x <= x_max: - # Interpolate - x_interval = np.searchsorted(x_data, x) - x_interval = x_interval if x_interval != 0 else 1 - a = coeffs[4 * x_interval - 4 : 4 * x_interval] - y = a[3] * x**3 + a[2] * x**2 + a[1] * x + a[0] - else: - # Extrapolate - if extrapolation == 0: # Extrapolation == zero - y = 0 - elif extrapolation == 1: # Extrapolation == natural - a = coeffs[:4] if x < x_min else coeffs[-4:] - y = a[3] * x**3 + a[2] * x**2 + a[1] * x + a[0] - else: # Extrapolation is set to constant - y = y_data[0] if x < x_min else y_data[-1] - return y - - elif self.__interpolation__ == "polynomial": - coeffs = self.__polynomial_coefficients__ - - def get_value_opt(x): - # Interpolate... or extrapolate - if x_min <= x <= x_max: - # Interpolate - y = np.sum(coeffs * x ** np.arange(len(coeffs))) - else: - # Extrapolate - if extrapolation == 0: # Extrapolation == zero - y = 0 - elif extrapolation == 1: # Extrapolation == natural - y = 0 - for i, coef in enumerate(coeffs): - y += coef * (x**i) - else: # Extrapolation is set to constant - y = y_data[0] if x < x_min else y_data[-1] - return y - - elif self.__interpolation__ == "shepard": - # change the function's name to avoid mypy's error - def get_value_opt_multiple(*args): - return self.__interpolate_shepard__(args) - - get_value_opt = get_value_opt_multiple + coeffs = self._coeffs + if x_min <= x <= x_max: + y = self._interpolation_func(x, x_min, x_max, x_data, y_data, coeffs) + else: + y = self._extrapolation_func(x, x_min, x_max, x_data, y_data, coeffs) + return y - try: - self.get_value_opt = get_value_opt - except UnboundLocalError: - warnings.warn( - "Cannot set the get_value_opt method when interpolation is " - f"{self.__interpolation}. Try using the set_interpolation method first." - ) - return self + def __get_value_opt_nd(self, *args): + """Evaluate the Function at a single point (x, y, z). This method is + used when the Function is N-D.""" + # always use shepard for N-D functions + return self.__interpolate_shepard__(args) def set_discrete( self, @@ -534,7 +551,7 @@ def set_discrete( ys = func.get_value(xs.tolist()) if one_by_one else func.get_value(xs) func.set_source(np.concatenate(([xs], [ys])).transpose()) func.set_interpolation(interpolation) - func.__extrapolation__ = extrapolation # avoid calling set_get_value_opt + func.set_extrapolation(extrapolation) elif func.__dom_dim__ == 2: lower = 2 * [lower] if isinstance(lower, (int, float)) else lower upper = 2 * [upper] if isinstance(upper, (int, float)) else upper @@ -673,7 +690,7 @@ def set_discrete_based_on_model( ) func.set_interpolation(interp) - func.__extrapolation__ = extrap # avoid calling set_get_value_opt + func.set_extrapolation(extrap) return func @@ -732,7 +749,7 @@ def reset( if interpolation is not None and interpolation != self.__interpolation__: self.set_interpolation(interpolation) if extrapolation is not None and extrapolation != self.__extrapolation__: - self.__extrapolation__ = extrapolation + self.set_extrapolation(extrapolation) self.set_title(title) @@ -862,9 +879,8 @@ def get_value(self, *args): if all(isinstance(arg, Iterable) for arg in args): return [self.source(*arg) for arg in zip(*args)] - # Returns value for shepard interpolation - elif self.__interpolation__ == "shepard": - return self.__interpolate_shepard__(args) + elif self.__dom_dim__ > 1: # deals with nd functions and shepard interp + return self.get_value_opt(*args) # Returns value for other interpolation type else: # interpolation is "polynomial", "spline", "akima" or "linear" @@ -2857,79 +2873,31 @@ def savetxt( np.savetxt(file, data_points, fmt=fmt, delimiter=delimiter, newline=newline) # Input validators - - @staticmethod - def _check_user_input( - source, - inputs=None, - outputs=None, - interpolation=None, - extrapolation=None, - ): - """ - Validates and processes the user input parameters for creating or - modifying a Function object. This function ensures the inputs, outputs, - interpolation, and extrapolation parameters are compatible with the - given source. It converts the source to a numpy array if necessary, sets - default values and raises warnings or errors for incompatible or - ill-defined parameters. + def __validate_source(self, source): + """Used to validate the source parameter for creating a Function object. Parameters ---------- - source : list, np.ndarray, or callable - The source data or Function object. If a list or ndarray, it should - contain numeric data. If a Function, its inputs and outputs are - checked against the provided inputs and outputs. - inputs : list of str or None - The names of the input variables. If None, defaults are generated - based on the dimensionality of the source. - outputs : str or list of str - The name(s) of the output variable(s). If a list is provided, it - must have a single element. - interpolation : str or None - The method of interpolation to be used. For multidimensional sources - it defaults to 'shepard' if not provided. - extrapolation : str or None - The method of extrapolation to be used. For multidimensional sources - it defaults to 'natural' if not provided. + source : np.ndarray, callable, str, Path, Function, list + The source data of the Function object. This can be a numpy array, + a callable function, a string or Path object to a csv or txt file, + a Function object, or a list of numbers. Returns ------- - tuple - A tuple containing the processed source, inputs, outputs, - interpolation, and extrapolation parameters. + np.ndarray, callable + The validated source parameter. Raises ------ ValueError - If the dimensionality of the source does not match the combined - dimensions of inputs and outputs. If the outputs list has more than - one element. - - Examples - -------- - >>> from rocketpy import Function - >>> source = np.array([(1, 1), (2, 4), (3, 9)]) - >>> inputs = "x" - >>> outputs = ["y"] - >>> interpolation = 'linear' - >>> extrapolation = 'zero' - >>> inputs, outputs, interpolation, extrapolation = Function._check_user_input( - ... source, inputs, outputs, interpolation, extrapolation - ... ) - >>> inputs - ['x'] - >>> outputs - ['y'] - >>> interpolation - 'linear' - >>> extrapolation - 'zero' + If the source is not a valid type or if the source is not a 2D array + or a callable function. """ if isinstance(source, Function): - source = source.get_source() + return source.get_source() - elif isinstance(source, (str, Path)): + if isinstance(source, (str, Path)): # Read csv or txt files and create a numpy array try: source = np.loadtxt(source, delimiter=",", dtype=np.float64) @@ -2941,26 +2909,25 @@ def _check_user_input( source = np.loadtxt(data, delimiter=",", dtype=np.float64) if len(source[0]) == len(header): - if inputs == ["Scalar"]: - inputs = header[:-1] - if outputs == ["Scalar"]: - outputs = [header[-1]] + if self.__inputs__ is None: + self.__inputs__ = header[:-1] + if self.__outputs__ is None: + self.__outputs__ = [header[-1]] except Exception as e: raise ValueError( "Could not read the csv or txt file to create Function source." ) from e - if isinstance(source, list): + if isinstance(source, list) or isinstance(source, np.ndarray): # Triggers an error if source is not a list of numbers source = np.array(source, dtype=np.float64) - if isinstance(source, np.ndarray): - inputs, interpolation, extrapolation = ( - Function._validate_interpolation_and_extrapolation( - inputs, interpolation, extrapolation, source + # Checks if 2D array + if len(source.shape) != 2: + raise ValueError( + "Source must be a 2D array in the form [[x1, x2 ..., xn, y], ...]." ) - ) - Function._validate_source_dimensions(inputs, outputs, source) + return source if isinstance(source, (int, float)): # Convert number source into vectorized lambda function @@ -2969,86 +2936,78 @@ def _check_user_input( def source_function(_): return temp - source = source_function - return source, inputs, outputs, interpolation, extrapolation + return source_function - @staticmethod - def _validate_inputs_outputs(inputs, outputs): - """Used to validate the inputs and outputs parameters for creating a - Function object. It sets default values if they are not provided. + # If source is a callable function + return source + + def __validate_inputs(self, inputs): + """Used to validate the inputs parameter for creating a Function object. + It sets a default value if it is not provided. Parameters ---------- - inputs : str, list of str, None + inputs : list of str, None The name(s) of the input variable(s). If None, defaults to "Scalar". - outputs : - The name of the output variables. If None, defaults to "Scalar". Returns ------- - tuple - A tuple containing the validated inputs and outputs parameters. - - Raises - ------ - ValueError - If the output has more than one element. - """ - if inputs is None: - inputs = ["Scalar"] - if outputs is None: - outputs = ["Scalar"] - # check output type and dimensions - if isinstance(inputs, str): - inputs = [inputs] - if isinstance(outputs, str): - outputs = [outputs] - elif len(outputs) > 1: + list + The validated inputs parameter. + """ + if self.__dom_dim__ == 1: + if inputs is None: + return ["Scalar"] + if isinstance(inputs, str): + return [inputs] + if isinstance(inputs, (list, tuple)): + if len(inputs) == 1: + return inputs + raise ValueError( + "Inputs must be a string or a list of strings with " + "the length of the domain dimension." + ) + if self.__dom_dim__ > 1: + if inputs is None: + return [f"Input {i+1}" for i in range(self.__dom_dim__)] + if isinstance(inputs, list): + if len(inputs) == self.__dom_dim__ and all( + isinstance(i, str) for i in inputs + ): + return inputs raise ValueError( - "Output must either be a string or have dimension 1, " - + f"it currently has dimension ({len(outputs)})." + "Inputs must be a list of strings with " + "the length of the domain dimension." ) - return inputs, outputs - @staticmethod - def _validate_interpolation_and_extrapolation( - inputs, interpolation, extrapolation, source - ): - """Used to validate the interpolation and extrapolation methods for - creating a Function object. It sets default values for interpolation - and extrapolation if they are not provided or if they are not supported - for the given source. The inputs and outputs may be modified if the - source is multidimensional. + def __validate_outputs(self, outputs): + """Used to validate the outputs parameter for creating a Function object. + It sets a default value if it is not provided. Parameters ---------- - inputs : list of strings - List of inputs, each input is a string. Example: ['x', 'y'] - interpolation : str, None - The type of interpolation to use. The default method is 'spline'. - Currently supported values are 'spline', 'linear', 'polynomial', - 'akima', and 'shepard'. - extrapolation : str, None - The type of extrapolation to use. Currently supported values are - 'constant', 'natural', and 'zero'. The default method is 'constant'. - source : np.ndarray - The source data of the Function object. This has to be a numpy - array. + outputs : str, list of str, None + The name of the output variables. If None, defaults to "Scalar". Returns ------- - tuple - A tuple with the validated inputs, interpolation, and extrapolation - parameters (inputs, interpolation, extrapolation). - - Raises - ------ - ValueError - If the source has less than 2 dimensions. + list + The validated outputs parameter. """ - source_dim = source.shape[1] - ## single dimension (1D Functions) - if source_dim == 2: + if outputs is None: + return ["Scalar"] + if isinstance(outputs, str): + return [outputs] + if isinstance(outputs, (list, tuple)): + if len(outputs) > 1: + raise ValueError( + "Output must either be a string or a list of strings with " + + f"one item. It currently has dimension ({len(outputs)})." + ) + return outputs + + def __validate_interpolation(self, interpolation): + if self.__dom_dim__ == 1: # possible interpolation values: linear, polynomial, akima and spline if interpolation is None: interpolation = "spline" @@ -3063,8 +3022,20 @@ def _validate_interpolation_and_extrapolation( f"{interpolation} method is not supported." ) interpolation = "spline" + ## multiple dimensions + elif self.__dom_dim__ > 1: + if interpolation not in [None, "shepard"]: + warnings.warn( + ( + "Interpolation method set to 'shepard'. Only 'shepard' " + "interpolation is supported for multiple dimensions." + ), + ) + interpolation = "shepard" + return interpolation - # possible extrapolation values: constant, natural, zero + def __validate_extrapolation(self, extrapolation): + if self.__dom_dim__ == 1: if extrapolation is None: extrapolation = "constant" elif extrapolation.lower() not in ["constant", "natural", "zero"]: @@ -3075,56 +3046,14 @@ def _validate_interpolation_and_extrapolation( extrapolation = "constant" ## multiple dimensions - elif source_dim > 2: - if inputs == ["Scalar"]: - inputs = [f"Input {i+1}" for i in range(source_dim - 1)] - - if interpolation not in [None, "shepard"]: - warnings.warn( - ( - "Interpolation method set to 'shepard'. Other methods " - "are not supported yet." - ), - ) - interpolation = "shepard" - + elif self.__dom_dim__ > 1: if extrapolation not in [None, "natural"]: warnings.warn( "Extrapolation method set to 'natural'. Other methods " "are not supported yet." ) extrapolation = "natural" - else: - raise ValueError("Source must have at least 2 dimensions.") - return inputs, interpolation, extrapolation - - @staticmethod - def _validate_source_dimensions(inputs, outputs, source): - """Used to check whether the source dimensions match the inputs and - outputs. - - Parameters - ---------- - inputs : list of strings - List of inputs, each input is a string. Example: ['x', 'y'] - outputs : list of strings - List of outputs, each output is a string. Example: ['z'] - source : np.ndarray - The source data of the Function object. This has to be a numpy - array. - - Raises - ------ - ValueError - In case the source dimensions do not match the inputs and outputs. - """ - source_dim = source.shape[1] - in_out_dim = len(inputs) + len(outputs) - if source_dim != in_out_dim: - raise ValueError( - f"Source dimension ({source_dim}) does not match input " - f"and output dimension ({in_out_dim})." - ) + return extrapolation class PiecewiseFunction(Function): diff --git a/tests/test_flight.py b/tests/test_flight.py index a4b03d101..db882e185 100644 --- a/tests/test_flight.py +++ b/tests/test_flight.py @@ -729,7 +729,7 @@ def test_velocities(flight_calisto_custom_wind, flight_time, expected_values): ("t_initial", (1.6542528, 0.65918, -0.067107)), ("out_of_rail_time", (5.05334, 2.01364, -1.7541)), ("apogee_time", (2.35291, -1.8275, -0.87851)), - ("t_final", (0, 0, 159.3292416824044)), + ("t_final", (0, 0, 159.2212)), ], ) def test_aerodynamic_forces(flight_calisto_custom_wind, flight_time, expected_values): diff --git a/tests/test_function.py b/tests/test_function.py index 2ce94f691..6f4122e47 100644 --- a/tests/test_function.py +++ b/tests/test_function.py @@ -102,7 +102,8 @@ def test_setters(func_from_csv, func_2d_from_csv): func_2d_from_csv.set_interpolation("shepard") assert func_2d_from_csv.get_interpolation_method() == "shepard" func_2d_from_csv.set_extrapolation("zero") - assert func_2d_from_csv.get_extrapolation_method() == "zero" + # 2d functions do not support zero extrapolation, must change to natural + assert func_2d_from_csv.get_extrapolation_method() == "natural" @patch("matplotlib.pyplot.show") @@ -181,7 +182,32 @@ def test_extrapolation_methods(linear_func): assert linear_func.get_extrapolation_method() == "constant" assert np.isclose(linear_func.get_value(-1), 0, atol=1e-6) - # Test natural + # Test natural for linear interpolation + linear_func.set_interpolation("linear") + assert isinstance(linear_func.set_extrapolation("natural"), Function) + linear_func.set_extrapolation("natural") + assert isinstance(linear_func.get_extrapolation_method(), str) + assert linear_func.get_extrapolation_method() == "natural" + assert np.isclose(linear_func.get_value(-1), -1, atol=1e-6) + + # Test natural for spline interpolation + linear_func.set_interpolation("spline") + assert isinstance(linear_func.set_extrapolation("natural"), Function) + linear_func.set_extrapolation("natural") + assert isinstance(linear_func.get_extrapolation_method(), str) + assert linear_func.get_extrapolation_method() == "natural" + assert np.isclose(linear_func.get_value(-1), -1, atol=1e-6) + + # Test natural for akima interpolation + linear_func.set_interpolation("akima") + assert isinstance(linear_func.set_extrapolation("natural"), Function) + linear_func.set_extrapolation("natural") + assert isinstance(linear_func.get_extrapolation_method(), str) + assert linear_func.get_extrapolation_method() == "natural" + assert np.isclose(linear_func.get_value(-1), -1, atol=1e-6) + + # Test natural for polynomial interpolation + linear_func.set_interpolation("polynomial") assert isinstance(linear_func.set_extrapolation("natural"), Function) linear_func.set_extrapolation("natural") assert isinstance(linear_func.get_extrapolation_method(), str) diff --git a/tests/unit/test_function.py b/tests/unit/test_function.py index 8bcefb818..17da59498 100644 --- a/tests/unit/test_function.py +++ b/tests/unit/test_function.py @@ -306,3 +306,31 @@ def test_remove_outliers_iqr(x, y, expected_x, expected_y): assert filtered_func.__interpolation__ == func.__interpolation__ assert filtered_func.__extrapolation__ == func.__extrapolation__ assert filtered_func.title == func.title + + +def test_set_get_value_opt(): + """Test the set_value_opt and get_value_opt methods of the Function class.""" + func = Function(lambda x: x**2) + func.source = np.array([[1, 1], [2, 4], [3, 9], [4, 16], [5, 25]]) + func.x_array = np.array([1, 2, 3, 4, 5]) + func.y_array = np.array([1, 4, 9, 16, 25]) + func.x_initial = 1 + func.x_final = 5 + func.set_interpolation("linear") + func.set_get_value_opt() + assert func.get_value_opt(2.5) == 6.5 + + +def test_get_image_dim(linear_func): + """Test the get_img_dim method of the Function class.""" + assert linear_func.get_image_dim() == 1 + + +def test_get_domain_dim(linear_func): + """Test the get_domain_dim method of the Function class.""" + assert linear_func.get_domain_dim() == 1 + + +def test_bool(linear_func): + """Test the __bool__ method of the Function class.""" + assert bool(linear_func) == True