From 26cc6a1b9bb50a3352f1bb0bad62bab6c63c36a5 Mon Sep 17 00:00:00 2001 From: MateusStano Date: Thu, 28 Mar 2024 17:46:38 +0100 Subject: [PATCH 01/15] ENH: rework setters and validation --- rocketpy/mathutils/function.py | 632 +++++++++++++++------------------ 1 file changed, 286 insertions(+), 346 deletions(-) diff --git a/rocketpy/mathutils/function.py b/rocketpy/mathutils/function.py index c887a904a..86135e180 100644 --- a/rocketpy/mathutils/function.py +++ b/rocketpy/mathutils/function.py @@ -20,6 +20,14 @@ from ..tools import cached_property NUMERICAL_TYPES = (float, int, complex, np.ndarray, np.integer, np.floating) +INTERPOLATION_TYPES = { + "linear": 0, + "polynomial": 1, + "akima": 2, + "spline": 3, + "shepard": 4, +} +EXTRAPOLATION_TYPES = {"zero": 0, "natural": 1, "constant": 2} class Function: @@ -110,27 +118,25 @@ def __init__( (II) Fields in CSV files may be enclosed in double quotes. If fields are not quoted, double quotes should not appear inside them. """ - - inputs, outputs = Function._validate_inputs_outputs(inputs, outputs) - # initialize variables to avoid errors when being called by other methods - self.get_value_opt = None self.__polynomial_coefficients__ = None self.__akima_coefficients__ = None self.__spline_coefficients__ = None - self.EXTRAPOLATION_TYPES = { - "zero": 0, - "natural": 1, - "constant": 2, - } - - # store variables - self.set_inputs(inputs) - self.set_outputs(outputs) + + # initialize parameters + self.source = source + self.__inputs__ = inputs + self.__outputs__ = outputs self.__interpolation__ = interpolation self.__extrapolation__ = extrapolation - self.set_source(source) - self.set_title(title) + self.title = title + self.__img_dim__ = 1 # always 1, here for backwards compatibility + + # args must be passed from self. + self.set_source(self.source) + self.set_inputs(self.__inputs__) + self.set_outputs(self.__outputs__) + self.set_title(self.title) # Define all set methods def set_inputs(self, inputs): @@ -145,8 +151,7 @@ def set_inputs(self, inputs): ------- self : Function """ - self.__inputs__ = [inputs] if isinstance(inputs, str) else list(inputs) - self.__dom_dim__ = len(self.__inputs__) + self.__inputs__ = self._validate_inputs(inputs) return self def set_outputs(self, outputs): @@ -161,8 +166,7 @@ def set_outputs(self, outputs): ------- self : Function """ - self.__outputs__ = [outputs] if isinstance(outputs, str) else list(outputs) - self.__img_dim__ = len(self.__outputs__) + self.__outputs__ = self._validate_outputs(outputs) return self def set_source(self, source): @@ -211,20 +215,10 @@ def set_source(self, source): self : Function Returns the Function instance. """ - source, inputs, outputs, interpolation, extrapolation = self._check_user_input( - source, - self.__inputs__, - self.__outputs__, - self.__interpolation__, - self.__extrapolation__, - ) - # updates inputs and outputs (could be modified due to csv headers) - self.set_inputs(inputs) - self.set_outputs(outputs) + source = self._validate_source(source) # Handle callable source or number source if callable(source): - self.source = source self.get_value_opt = source self.__interpolation__ = None self.__extrapolation__ = None @@ -232,13 +226,16 @@ def set_source(self, source): # Set arguments name and domain dimensions parameters = signature(source).parameters self.__dom_dim__ = len(parameters) - if self.__inputs__ == ["Scalar"]: + if self.__inputs__ is None: self.__inputs__ = list(parameters) # Handle ndarray source else: + # Evaluate dimension + self.__dom_dim__ = source.shape[1] - 1 + # Check to see if dimensions match incoming data set - new_total_dim = len(source[0, :]) + new_total_dim = source.shape[1] old_total_dim = self.__dom_dim__ + self.__img_dim__ # If they don't, update default values or throw error @@ -250,20 +247,25 @@ def set_source(self, source): # if Function is 1D, sort source by x. If 2D, set z if self.__dom_dim__ == 1: source = source[source[:, 0].argsort()] + self.x_array = source[:, 0] + self.x_initial, self.x_final = self.x_array[0], self.x_array[-1] + self.y_array = source[:, 1] + self.y_initial, self.y_final = self.y_array[0], self.y_array[-1] + self.get_value_opt = self._get_value_opt_1d elif self.__dom_dim__ == 2: + self.x_array = source[:, 0] + self.x_initial, self.x_final = self.x_array[0], self.x_array[-1] + self.y_array = source[:, 1] + self.y_initial, self.y_final = self.y_array[0], self.y_array[-1] self.z_array = source[:, 2] self.z_initial, self.z_final = self.z_array[0], self.z_array[-1] + self.get_value_opt = self._get_value_opt_nd + else: + self.get_value_opt = self._get_value_opt_nd - # Set x and y arrays (common for 1D or multivariate) - self.x_array = source[:, 0] - self.x_initial, self.x_final = self.x_array[0], self.x_array[-1] - self.y_array = source[:, 1] - self.y_initial, self.y_final = self.y_array[0], self.y_array[-1] - - # Finally set source, update extrapolation and interpolation - self.source = source - self.__extrapolation__ = extrapolation # to avoid calling set_get_value_opt - self.set_interpolation(interpolation) + self.source = source + self.set_interpolation(self.__interpolation__) + self.set_extrapolation(self.__extrapolation__) return self @cached_property @@ -303,18 +305,27 @@ def set_interpolation(self, method="spline"): ------- self : Function """ - self.__interpolation__ = method + if not callable(self.source): + self.__interpolation__ = self._validate_interpolation(method) + self._update_interpolation_coefficients(self.__interpolation__) + self._set_interpolation_function() + return self + + def _update_interpolation_coefficients(self, method): + """Update interpolation coefficients for the given method.""" # Spline, akima and polynomial need data processing # Shepard, and linear do not - if method == "spline": - self.__interpolate_spline__() - elif method == "polynomial": + if method == "polynomial": self.__interpolate_polynomial__() + self._coeffs = self.__polynomial_coefficients__ elif method == "akima": self.__interpolate_akima__() - - self.set_get_value_opt() - return self + self._coeffs = self.__akima_coefficients__ + elif method == "spline" or method is None: + self.__interpolate_spline__() + self._coeffs = self.__spline_coefficients__ + else: + self._coeffs = [] def set_extrapolation(self, method="constant"): """Set extrapolation behavior of data set. @@ -333,142 +344,159 @@ def set_extrapolation(self, method="constant"): self : Function The Function object. """ - self.__extrapolation__ = method - self.set_get_value_opt() + if not callable(self.source): + self.__extrapolation__ = self._validate_extrapolation(method) + self._set_extrapolation_function() return self - def set_get_value_opt(self): - """Crates a method that evaluates interpolations rather quickly - when compared to other options available, such as just calling - the object instance or calling ``Function.get_value`` directly. See - ``Function.get_value_opt`` for documentation. + def _set_interpolation_function(self): + """Return a function that interpolates the Function.""" + interpolation = INTERPOLATION_TYPES[self.__interpolation__] + if interpolation == 0: # linear + + def linear_interpolation(x, x_min, x_max, x_data, y_data, coeffs): + x_interval = np.searchsorted(x_data, x) + x_left = x_data[x_interval - 1] + y_left = y_data[x_interval - 1] + dx = float(x_data[x_interval] - x_left) + dy = float(y_data[x_interval] - y_left) + y = (x - x_left) * (dy / dx) + y_left + return y - Returns - ------- - self : Function - """ - # Retrieve general info - x_data = self.x_array - y_data = self.y_array - x_min, x_max = self.x_initial, self.x_final - try: - extrapolation = self.EXTRAPOLATION_TYPES[self.__extrapolation__] - except KeyError as err: - raise ValueError( - f"Invalid extrapolation type '{self.__extrapolation__}'" - ) from err + self._interpolation_function = linear_interpolation - # Crete method to interpolate this info for each interpolation type - if self.__interpolation__ == "spline": - coeffs = self.__spline_coefficients__ + elif interpolation == 1: # polynomial - def get_value_opt(x): - if x_min <= x <= x_max: - # Interpolate - x_interval = np.searchsorted(x_data, x) - x_interval = max(x_interval, 1) - a = coeffs[:, x_interval - 1] - x = x - x_data[x_interval - 1] - y = a[3] * x**3 + a[2] * x**2 + a[1] * x + a[0] - else: - # Extrapolate - if extrapolation == 0: # Extrapolation == zero - y = 0 - elif extrapolation == 1: # Extrapolation == natural - if x < x_min: - a = coeffs[:, 0] - x = x - x_data[0] - else: - a = coeffs[:, -1] - x = x - x_data[-2] - y = a[3] * x**3 + a[2] * x**2 + a[1] * x + a[0] - else: # Extrapolation is set to constant - y = y_data[0] if x < x_min else y_data[-1] + def polynomial_interpolation(x, x_min, x_max, x_data, y_data, coeffs): + y = np.sum(coeffs * x ** np.arange(len(coeffs))) + return y + + self._interpolation_function = polynomial_interpolation + + elif interpolation == 2: # akima + + def akima_interpolation(x, x_min, x_max, x_data, y_data, coeffs): + x_interval = np.searchsorted(x_data, x) + x_interval = x_interval if x_interval != 0 else 1 + a = coeffs[4 * x_interval - 4 : 4 * x_interval] + y = a[3] * x**3 + a[2] * x**2 + a[1] * x + a[0] + return y + + self._interpolation_function = akima_interpolation + + elif interpolation == 3: # spline + + def spline_interpolation(x, x_min, x_max, x_data, y_data, coeffs): + x_interval = np.searchsorted(x_data, x) + x_interval = max(x_interval, 1) + a = coeffs[:, x_interval - 1] + x = x - x_data[x_interval - 1] + y = a[3] * x**3 + a[2] * x**2 + a[1] * x + a[0] return y - elif self.__interpolation__ == "linear": + self._interpolation_function = spline_interpolation + + elif interpolation == 4: # shepard does not use interpolation function + self._interpolation_function = None + + def _set_extrapolation_function(self): + """Return a function that extrapolates the Function.""" + interpolation = INTERPOLATION_TYPES[self.__interpolation__] + extrapolation = EXTRAPOLATION_TYPES[self.__extrapolation__] + + if interpolation == 4: # shepard does not use extrapolation function + self._extrapolation_function = None + + elif extrapolation == 0: # zero + + def zero_extrapolation(x, x_min, x_max, x_data, y_data, coeffs): + return 0 - def get_value_opt(x): - if x_min <= x <= x_max: - # Interpolate - x_interval = np.searchsorted(x_data, x) + self._extrapolation_function = zero_extrapolation + elif extrapolation == 1: # natural + if interpolation == 0: # linear + + def natural_extrapolation(x, x_min, x_max, x_data, y_data, coeffs): + x_interval = 1 if x < x_min else -1 x_left = x_data[x_interval - 1] y_left = y_data[x_interval - 1] dx = float(x_data[x_interval] - x_left) dy = float(y_data[x_interval] - y_left) y = (x - x_left) * (dy / dx) + y_left - else: - # Extrapolate - if extrapolation == 0: # Extrapolation == zero - y = 0 - elif extrapolation == 1: # Extrapolation == natural - x_interval = 1 if x < x_min else -1 - x_left = x_data[x_interval - 1] - y_left = y_data[x_interval - 1] - dx = float(x_data[x_interval] - x_left) - dy = float(y_data[x_interval] - y_left) - y = (x - x_left) * (dy / dx) + y_left - else: # Extrapolation is set to constant - y = y_data[0] if x < x_min else y_data[-1] - return y + return y + + elif interpolation == 1: # polynomial + + def natural_extrapolation(x, x_min, x_max, x_data, y_data, coeffs): + y = np.sum(coeffs * x ** np.arange(len(coeffs))) + return y - elif self.__interpolation__ == "akima": - coeffs = np.array(self.__akima_coefficients__) + elif interpolation == 2: # akima - def get_value_opt(x): - if x_min <= x <= x_max: - # Interpolate - x_interval = np.searchsorted(x_data, x) - x_interval = x_interval if x_interval != 0 else 1 - a = coeffs[4 * x_interval - 4 : 4 * x_interval] + def natural_extrapolation(x, x_min, x_max, x_data, y_data, coeffs): + a = coeffs[:4] if x < x_min else coeffs[-4:] y = a[3] * x**3 + a[2] * x**2 + a[1] * x + a[0] - else: - # Extrapolate - if extrapolation == 0: # Extrapolation == zero - y = 0 - elif extrapolation == 1: # Extrapolation == natural - a = coeffs[:4] if x < x_min else coeffs[-4:] - y = a[3] * x**3 + a[2] * x**2 + a[1] * x + a[0] - else: # Extrapolation is set to constant - y = y_data[0] if x < x_min else y_data[-1] - return y + return y - elif self.__interpolation__ == "polynomial": - coeffs = self.__polynomial_coefficients__ + elif interpolation == 3: # spline - def get_value_opt(x): - # Interpolate... or extrapolate - if x_min <= x <= x_max: - # Interpolate - y = np.sum(coeffs * x ** np.arange(len(coeffs))) - else: - # Extrapolate - if extrapolation == 0: # Extrapolation == zero - y = 0 - elif extrapolation == 1: # Extrapolation == natural - y = 0 - for i, coef in enumerate(coeffs): - y += coef * (x**i) - else: # Extrapolation is set to constant - y = y_data[0] if x < x_min else y_data[-1] - return y + def natural_extrapolation(x, x_min, x_max, x_data, y_data, coeffs): + if x < x_min: + a = coeffs[:, 0] + x = x - x_data[0] + else: + a = coeffs[:, -1] + x = x - x_data[-2] + y = a[3] * x**3 + a[2] * x**2 + a[1] * x + a[0] + return y - elif self.__interpolation__ == "shepard": - # change the function's name to avoid mypy's error - def get_value_opt_multiple(*args): - return self.__interpolate_shepard__(args) + self._extrapolation_function = natural_extrapolation + elif extrapolation == 2: # constant - get_value_opt = get_value_opt_multiple + def constant_extrapolation(x, x_min, x_max, x_data, y_data, coeffs): + return y_data[0] if x < x_min else y_data[-1] - try: - self.get_value_opt = get_value_opt - except UnboundLocalError: - warnings.warn( - "Cannot set the get_value_opt method when interpolation is " - f"{self.__interpolation}. Try using the set_interpolation method first." - ) + self._extrapolation_function = constant_extrapolation + + def set_get_value_opt(self): + """Crates a method that evaluates interpolations rather quickly + when compared to other options available, such as just calling + the object instance or calling ``Function.get_value`` directly. See + ``Function.get_value_opt`` for documentation. + + Returns + ------- + self : Function + """ + if callable(self.source): + self.get_value_opt = self.source + elif self.__dom_dim__ == 1: + self.get_value_opt = self._get_value_opt_1d + elif self.__dom_dim__ > 1: + self.get_value_opt = self._get_value_opt_nd return self + def _get_value_opt_1d(self, x, *args): + # Retrieve general info + x_data = self.x_array + y_data = self.y_array + x_min, x_max = self.x_initial, self.x_final + coeffs = self._coeffs + if x_min <= x <= x_max: + y = self._interpolation_function( + x, x_min, x_max, x_data, y_data, coeffs, *args + ) + else: + y = self._extrapolation_function( + x, x_min, x_max, x_data, y_data, coeffs, *args + ) + return y + + def _get_value_opt_nd(self, *args): + # always use shepard for N-D functions + y = self.__interpolate_shepard__(list(args)) + return y + def set_discrete( self, lower=0, @@ -534,7 +562,7 @@ def set_discrete( ys = func.get_value(xs.tolist()) if one_by_one else func.get_value(xs) func.set_source(np.concatenate(([xs], [ys])).transpose()) func.set_interpolation(interpolation) - func.__extrapolation__ = extrapolation # avoid calling set_get_value_opt + func.set_extrapolation(extrapolation) elif func.__dom_dim__ == 2: lower = 2 * [lower] if isinstance(lower, (int, float)) else lower upper = 2 * [upper] if isinstance(upper, (int, float)) else upper @@ -673,7 +701,7 @@ def set_discrete_based_on_model( ) func.set_interpolation(interp) - func.__extrapolation__ = extrap # avoid calling set_get_value_opt + func.set_extrapolation(extrap) return func @@ -732,7 +760,7 @@ def reset( if interpolation is not None and interpolation != self.__interpolation__: self.set_interpolation(interpolation) if extrapolation is not None and extrapolation != self.__extrapolation__: - self.__extrapolation__ = extrapolation + self.set_extrapolation(extrapolation) self.set_title(title) @@ -862,9 +890,8 @@ def get_value(self, *args): if all(isinstance(arg, Iterable) for arg in args): return [self.source(*arg) for arg in zip(*args)] - # Returns value for shepard interpolation - elif self.__interpolation__ == "shepard": - return self.__interpolate_shepard__(args) + elif self.__dom_dim__ > 1: # deals with nd functions and shepard interp + return self.get_value_opt(*args) # Returns value for other interpolation type else: # interpolation is "polynomial", "spline", "akima" or "linear" @@ -2857,79 +2884,31 @@ def savetxt( np.savetxt(file, data_points, fmt=fmt, delimiter=delimiter, newline=newline) # Input validators - - @staticmethod - def _check_user_input( - source, - inputs=None, - outputs=None, - interpolation=None, - extrapolation=None, - ): - """ - Validates and processes the user input parameters for creating or - modifying a Function object. This function ensures the inputs, outputs, - interpolation, and extrapolation parameters are compatible with the - given source. It converts the source to a numpy array if necessary, sets - default values and raises warnings or errors for incompatible or - ill-defined parameters. + def _validate_source(self, source): + """Used to validate the source parameter for creating a Function object. Parameters ---------- - source : list, np.ndarray, or callable - The source data or Function object. If a list or ndarray, it should - contain numeric data. If a Function, its inputs and outputs are - checked against the provided inputs and outputs. - inputs : list of str or None - The names of the input variables. If None, defaults are generated - based on the dimensionality of the source. - outputs : str or list of str - The name(s) of the output variable(s). If a list is provided, it - must have a single element. - interpolation : str or None - The method of interpolation to be used. For multidimensional sources - it defaults to 'shepard' if not provided. - extrapolation : str or None - The method of extrapolation to be used. For multidimensional sources - it defaults to 'natural' if not provided. + source : np.ndarray, callable, str, Path, Function, list + The source data of the Function object. This can be a numpy array, + a callable function, a string or Path object to a csv or txt file, + a Function object, or a list of numbers. Returns ------- - tuple - A tuple containing the processed source, inputs, outputs, - interpolation, and extrapolation parameters. + np.ndarray, callable + The validated source parameter. Raises ------ ValueError - If the dimensionality of the source does not match the combined - dimensions of inputs and outputs. If the outputs list has more than - one element. - - Examples - -------- - >>> from rocketpy import Function - >>> source = np.array([(1, 1), (2, 4), (3, 9)]) - >>> inputs = "x" - >>> outputs = ["y"] - >>> interpolation = 'linear' - >>> extrapolation = 'zero' - >>> inputs, outputs, interpolation, extrapolation = Function._check_user_input( - ... source, inputs, outputs, interpolation, extrapolation - ... ) - >>> inputs - ['x'] - >>> outputs - ['y'] - >>> interpolation - 'linear' - >>> extrapolation - 'zero' + If the source is not a valid type or if the source is not a 2D array + or a callable function. """ if isinstance(source, Function): - source = source.get_source() + return source.get_source() - elif isinstance(source, (str, Path)): + if isinstance(source, (str, Path)): # Read csv or txt files and create a numpy array try: source = np.loadtxt(source, delimiter=",", dtype=np.float64) @@ -2941,26 +2920,25 @@ def _check_user_input( source = np.loadtxt(data, delimiter=",", dtype=np.float64) if len(source[0]) == len(header): - if inputs == ["Scalar"]: - inputs = header[:-1] - if outputs == ["Scalar"]: - outputs = [header[-1]] + if self.__inputs__ is None: + self.__inputs__ = header[:-1] + if self.__outputs__ is None: + self.__outputs__ = [header[-1]] except Exception as e: raise ValueError( "Could not read the csv or txt file to create Function source." ) from e - if isinstance(source, list): + if isinstance(source, list) or isinstance(source, np.ndarray): # Triggers an error if source is not a list of numbers source = np.array(source, dtype=np.float64) - if isinstance(source, np.ndarray): - inputs, interpolation, extrapolation = ( - Function._validate_interpolation_and_extrapolation( - inputs, interpolation, extrapolation, source + # Checks if 2D array + if len(source.shape) != 2: + raise ValueError( + "Source must be a 2D array in the form [[x1, x2 ..., xn, y], ...]." ) - ) - Function._validate_source_dimensions(inputs, outputs, source) + return source if isinstance(source, (int, float)): # Convert number source into vectorized lambda function @@ -2969,86 +2947,78 @@ def _check_user_input( def source_function(_): return temp - source = source_function - return source, inputs, outputs, interpolation, extrapolation + return source_function - @staticmethod - def _validate_inputs_outputs(inputs, outputs): - """Used to validate the inputs and outputs parameters for creating a - Function object. It sets default values if they are not provided. + # If source is a callable function + return source + + def _validate_inputs(self, inputs): + """Used to validate the inputs parameter for creating a Function object. + It sets a default value if it is not provided. Parameters ---------- - inputs : str, list of str, None + inputs : list of str, None The name(s) of the input variable(s). If None, defaults to "Scalar". - outputs : - The name of the output variables. If None, defaults to "Scalar". Returns ------- - tuple - A tuple containing the validated inputs and outputs parameters. - - Raises - ------ - ValueError - If the output has more than one element. - """ - if inputs is None: - inputs = ["Scalar"] - if outputs is None: - outputs = ["Scalar"] - # check output type and dimensions - if isinstance(inputs, str): - inputs = [inputs] - if isinstance(outputs, str): - outputs = [outputs] - elif len(outputs) > 1: + list + The validated inputs parameter. + """ + if self.__dom_dim__ == 1: + if inputs is None: + return ["Scalar"] + if isinstance(inputs, str): + return [inputs] + if isinstance(inputs, list): + if len(inputs) == 1 and isinstance(inputs[0], str): + return inputs + raise ValueError( + "Inputs must be a string or a list of strings with " + "the length of the domain dimension." + ) + if self.__dom_dim__ > 1: + if inputs is None: + return [f"Input {i+1}" for i in range(self.__dom_dim__)] + if isinstance(inputs, list): + if len(inputs) == self.__dom_dim__ and all( + isinstance(i, str) for i in inputs + ): + return inputs raise ValueError( - "Output must either be a string or have dimension 1, " - + f"it currently has dimension ({len(outputs)})." + "Inputs must be a list of strings with " + "the length of the domain dimension." ) - return inputs, outputs - @staticmethod - def _validate_interpolation_and_extrapolation( - inputs, interpolation, extrapolation, source - ): - """Used to validate the interpolation and extrapolation methods for - creating a Function object. It sets default values for interpolation - and extrapolation if they are not provided or if they are not supported - for the given source. The inputs and outputs may be modified if the - source is multidimensional. + def _validate_outputs(self, outputs): + """Used to validate the outputs parameter for creating a Function object. + It sets a default value if it is not provided. Parameters ---------- - inputs : list of strings - List of inputs, each input is a string. Example: ['x', 'y'] - interpolation : str, None - The type of interpolation to use. The default method is 'spline'. - Currently supported values are 'spline', 'linear', 'polynomial', - 'akima', and 'shepard'. - extrapolation : str, None - The type of extrapolation to use. Currently supported values are - 'constant', 'natural', and 'zero'. The default method is 'constant'. - source : np.ndarray - The source data of the Function object. This has to be a numpy - array. + outputs : str, list of str, None + The name of the output variables. If None, defaults to "Scalar". Returns ------- - tuple - A tuple with the validated inputs, interpolation, and extrapolation - parameters (inputs, interpolation, extrapolation). - - Raises - ------ - ValueError - If the source has less than 2 dimensions. + list + The validated outputs parameter. """ - source_dim = source.shape[1] - ## single dimension (1D Functions) - if source_dim == 2: + if outputs is None: + return ["Scalar"] + if isinstance(outputs, str): + return [outputs] + if isinstance(outputs, list): + if len(outputs) > 1 or not isinstance(outputs[0], str): + raise ValueError( + "Output must either be a string or a list of strings with " + + f"one item. It currently has dimension ({len(outputs)})." + ) + return outputs + + def _validate_interpolation(self, interpolation): + if self.__dom_dim__ == 1: # possible interpolation values: linear, polynomial, akima and spline if interpolation is None: interpolation = "spline" @@ -3063,8 +3033,20 @@ def _validate_interpolation_and_extrapolation( f"{interpolation} method is not supported." ) interpolation = "spline" + ## multiple dimensions + elif self.__dom_dim__ > 1: + if interpolation not in [None, "shepard"]: + warnings.warn( + ( + "Interpolation method set to 'shepard'. Only 'shepard' " + "interpolation is supported for multiple dimensions." + ), + ) + interpolation = "shepard" + return interpolation - # possible extrapolation values: constant, natural, zero + def _validate_extrapolation(self, extrapolation): + if self.__dom_dim__ == 1: if extrapolation is None: extrapolation = "constant" elif extrapolation.lower() not in ["constant", "natural", "zero"]: @@ -3075,56 +3057,14 @@ def _validate_interpolation_and_extrapolation( extrapolation = "constant" ## multiple dimensions - elif source_dim > 2: - if inputs == ["Scalar"]: - inputs = [f"Input {i+1}" for i in range(source_dim - 1)] - - if interpolation not in [None, "shepard"]: - warnings.warn( - ( - "Interpolation method set to 'shepard'. Other methods " - "are not supported yet." - ), - ) - interpolation = "shepard" - + elif self.__dom_dim__ > 1: if extrapolation not in [None, "natural"]: warnings.warn( "Extrapolation method set to 'natural'. Other methods " "are not supported yet." ) extrapolation = "natural" - else: - raise ValueError("Source must have at least 2 dimensions.") - return inputs, interpolation, extrapolation - - @staticmethod - def _validate_source_dimensions(inputs, outputs, source): - """Used to check whether the source dimensions match the inputs and - outputs. - - Parameters - ---------- - inputs : list of strings - List of inputs, each input is a string. Example: ['x', 'y'] - outputs : list of strings - List of outputs, each output is a string. Example: ['z'] - source : np.ndarray - The source data of the Function object. This has to be a numpy - array. - - Raises - ------ - ValueError - In case the source dimensions do not match the inputs and outputs. - """ - source_dim = source.shape[1] - in_out_dim = len(inputs) + len(outputs) - if source_dim != in_out_dim: - raise ValueError( - f"Source dimension ({source_dim}) does not match input " - f"and output dimension ({in_out_dim})." - ) + return extrapolation class PiecewiseFunction(Function): From 387e5326da2bf10f2f7966f8d210fd83efa09f73 Mon Sep 17 00:00:00 2001 From: MateusStano Date: Thu, 28 Mar 2024 17:47:08 +0100 Subject: [PATCH 02/15] TST: fix function test for correct behaviour --- tests/test_function.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_function.py b/tests/test_function.py index 2ce94f691..70613671e 100644 --- a/tests/test_function.py +++ b/tests/test_function.py @@ -102,7 +102,8 @@ def test_setters(func_from_csv, func_2d_from_csv): func_2d_from_csv.set_interpolation("shepard") assert func_2d_from_csv.get_interpolation_method() == "shepard" func_2d_from_csv.set_extrapolation("zero") - assert func_2d_from_csv.get_extrapolation_method() == "zero" + # 2d functions do not support zero extrapolation, must change to natural + assert func_2d_from_csv.get_extrapolation_method() == "natural" @patch("matplotlib.pyplot.show") From c175e76d3dad30534ba18a1c7d0e07d635ac8c9d Mon Sep 17 00:00:00 2001 From: MateusStano Date: Thu, 28 Mar 2024 17:47:47 +0100 Subject: [PATCH 03/15] BUG: revert comit https://github.com/RocketPy-Team/RocketPy/pull/581/commits/9a637a834e443736e7fa988830d1586aaf1f9201 --- tests/test_flight.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_flight.py b/tests/test_flight.py index a4b03d101..db882e185 100644 --- a/tests/test_flight.py +++ b/tests/test_flight.py @@ -729,7 +729,7 @@ def test_velocities(flight_calisto_custom_wind, flight_time, expected_values): ("t_initial", (1.6542528, 0.65918, -0.067107)), ("out_of_rail_time", (5.05334, 2.01364, -1.7541)), ("apogee_time", (2.35291, -1.8275, -0.87851)), - ("t_final", (0, 0, 159.3292416824044)), + ("t_final", (0, 0, 159.2212)), ], ) def test_aerodynamic_forces(flight_calisto_custom_wind, flight_time, expected_values): From 81a79e3c34b98804645cc10ffd3c6b5895e9cb29 Mon Sep 17 00:00:00 2001 From: MateusStano Date: Thu, 28 Mar 2024 18:51:31 +0100 Subject: [PATCH 04/15] ENH: change np.searchsorted to bisect_left --- rocketpy/mathutils/function.py | 43 ++++++++++++++++------------------ 1 file changed, 20 insertions(+), 23 deletions(-) diff --git a/rocketpy/mathutils/function.py b/rocketpy/mathutils/function.py index 86135e180..107c39c61 100644 --- a/rocketpy/mathutils/function.py +++ b/rocketpy/mathutils/function.py @@ -13,6 +13,7 @@ import matplotlib.pyplot as plt import numpy as np from scipy import integrate, linalg, optimize +from bisect import bisect_left try: from functools import cached_property @@ -308,7 +309,7 @@ def set_interpolation(self, method="spline"): if not callable(self.source): self.__interpolation__ = self._validate_interpolation(method) self._update_interpolation_coefficients(self.__interpolation__) - self._set_interpolation_function() + self._set_interpolation_func() return self def _update_interpolation_coefficients(self, method): @@ -346,16 +347,16 @@ def set_extrapolation(self, method="constant"): """ if not callable(self.source): self.__extrapolation__ = self._validate_extrapolation(method) - self._set_extrapolation_function() + self._set_extrapolation_func() return self - def _set_interpolation_function(self): + def _set_interpolation_func(self): """Return a function that interpolates the Function.""" interpolation = INTERPOLATION_TYPES[self.__interpolation__] if interpolation == 0: # linear def linear_interpolation(x, x_min, x_max, x_data, y_data, coeffs): - x_interval = np.searchsorted(x_data, x) + x_interval = bisect_left(x_data, x) x_left = x_data[x_interval - 1] y_left = y_data[x_interval - 1] dx = float(x_data[x_interval] - x_left) @@ -363,7 +364,7 @@ def linear_interpolation(x, x_min, x_max, x_data, y_data, coeffs): y = (x - x_left) * (dy / dx) + y_left return y - self._interpolation_function = linear_interpolation + self._interpolation_func = linear_interpolation elif interpolation == 1: # polynomial @@ -371,48 +372,48 @@ def polynomial_interpolation(x, x_min, x_max, x_data, y_data, coeffs): y = np.sum(coeffs * x ** np.arange(len(coeffs))) return y - self._interpolation_function = polynomial_interpolation + self._interpolation_func = polynomial_interpolation elif interpolation == 2: # akima def akima_interpolation(x, x_min, x_max, x_data, y_data, coeffs): - x_interval = np.searchsorted(x_data, x) + x_interval = bisect_left(x_data, x) x_interval = x_interval if x_interval != 0 else 1 a = coeffs[4 * x_interval - 4 : 4 * x_interval] y = a[3] * x**3 + a[2] * x**2 + a[1] * x + a[0] return y - self._interpolation_function = akima_interpolation + self._interpolation_func = akima_interpolation elif interpolation == 3: # spline def spline_interpolation(x, x_min, x_max, x_data, y_data, coeffs): - x_interval = np.searchsorted(x_data, x) + x_interval = bisect_left(x_data, x) x_interval = max(x_interval, 1) a = coeffs[:, x_interval - 1] x = x - x_data[x_interval - 1] y = a[3] * x**3 + a[2] * x**2 + a[1] * x + a[0] return y - self._interpolation_function = spline_interpolation + self._interpolation_func = spline_interpolation elif interpolation == 4: # shepard does not use interpolation function - self._interpolation_function = None + self._interpolation_func = None - def _set_extrapolation_function(self): + def _set_extrapolation_func(self): """Return a function that extrapolates the Function.""" interpolation = INTERPOLATION_TYPES[self.__interpolation__] extrapolation = EXTRAPOLATION_TYPES[self.__extrapolation__] if interpolation == 4: # shepard does not use extrapolation function - self._extrapolation_function = None + self._extrapolation_func = None elif extrapolation == 0: # zero def zero_extrapolation(x, x_min, x_max, x_data, y_data, coeffs): return 0 - self._extrapolation_function = zero_extrapolation + self._extrapolation_func = zero_extrapolation elif extrapolation == 1: # natural if interpolation == 0: # linear @@ -450,13 +451,13 @@ def natural_extrapolation(x, x_min, x_max, x_data, y_data, coeffs): y = a[3] * x**3 + a[2] * x**2 + a[1] * x + a[0] return y - self._extrapolation_function = natural_extrapolation + self._extrapolation_func = natural_extrapolation elif extrapolation == 2: # constant def constant_extrapolation(x, x_min, x_max, x_data, y_data, coeffs): return y_data[0] if x < x_min else y_data[-1] - self._extrapolation_function = constant_extrapolation + self._extrapolation_func = constant_extrapolation def set_get_value_opt(self): """Crates a method that evaluates interpolations rather quickly @@ -476,20 +477,16 @@ def set_get_value_opt(self): self.get_value_opt = self._get_value_opt_nd return self - def _get_value_opt_1d(self, x, *args): + def _get_value_opt_1d(self, x): # Retrieve general info x_data = self.x_array y_data = self.y_array x_min, x_max = self.x_initial, self.x_final coeffs = self._coeffs if x_min <= x <= x_max: - y = self._interpolation_function( - x, x_min, x_max, x_data, y_data, coeffs, *args - ) + y = self._interpolation_func(x, x_min, x_max, x_data, y_data, coeffs) else: - y = self._extrapolation_function( - x, x_min, x_max, x_data, y_data, coeffs, *args - ) + y = self._extrapolation_func(x, x_min, x_max, x_data, y_data, coeffs) return y def _get_value_opt_nd(self, *args): From c1d13247133be3402c68216b3009156861f9df20 Mon Sep 17 00:00:00 2001 From: MateusStano Date: Thu, 28 Mar 2024 19:08:12 +0100 Subject: [PATCH 05/15] DOC: adds docs to new methods --- rocketpy/mathutils/function.py | 30 ++++++++++++++++++++++++------ 1 file changed, 24 insertions(+), 6 deletions(-) diff --git a/rocketpy/mathutils/function.py b/rocketpy/mathutils/function.py index 107c39c61..ef9b13287 100644 --- a/rocketpy/mathutils/function.py +++ b/rocketpy/mathutils/function.py @@ -351,7 +351,11 @@ def set_extrapolation(self, method="constant"): return self def _set_interpolation_func(self): - """Return a function that interpolates the Function.""" + """Defines interpolation function used by the Function. Each + interpolation method has its own function with exception of shepard, + which has its interpolation/extrapolation function defined in + ``Function.__interpolate_shepard__``. The function is stored in + the attribute _interpolation_func.""" interpolation = INTERPOLATION_TYPES[self.__interpolation__] if interpolation == 0: # linear @@ -401,7 +405,9 @@ def spline_interpolation(x, x_min, x_max, x_data, y_data, coeffs): self._interpolation_func = None def _set_extrapolation_func(self): - """Return a function that extrapolates the Function.""" + """Defines extrapolation function used by the Function. Each + extrapolation method has its own function. The function is stored in + the attribute _extrapolation_func.""" interpolation = INTERPOLATION_TYPES[self.__interpolation__] extrapolation = EXTRAPOLATION_TYPES[self.__extrapolation__] @@ -460,10 +466,7 @@ def constant_extrapolation(x, x_min, x_max, x_data, y_data, coeffs): self._extrapolation_func = constant_extrapolation def set_get_value_opt(self): - """Crates a method that evaluates interpolations rather quickly - when compared to other options available, such as just calling - the object instance or calling ``Function.get_value`` directly. See - ``Function.get_value_opt`` for documentation. + """Defines a method that evaluates interpolations. Returns ------- @@ -478,6 +481,19 @@ def set_get_value_opt(self): return self def _get_value_opt_1d(self, x): + """Evaluate the Function at a single point x. This method is used + when the Function is 1-D. + + Parameters + ---------- + x : scalar + Value where the Function is to be evaluated. + + Returns + ------- + y : scalar + Value of the Function at the specified point. + """ # Retrieve general info x_data = self.x_array y_data = self.y_array @@ -490,6 +506,8 @@ def _get_value_opt_1d(self, x): return y def _get_value_opt_nd(self, *args): + """Evaluate the Function at a single point (x, y, z). This method is + used when the Function is N-D.""" # always use shepard for N-D functions y = self.__interpolate_shepard__(list(args)) return y From d5d9c31e67eff9564f8382d424d5ff0e7ad206b8 Mon Sep 17 00:00:00 2001 From: MateusStano Date: Sat, 30 Mar 2024 22:11:47 +0100 Subject: [PATCH 06/15] MNT: run isort --- rocketpy/mathutils/function.py | 1 + 1 file changed, 1 insertion(+) diff --git a/rocketpy/mathutils/function.py b/rocketpy/mathutils/function.py index ef9b13287..0cdb139cc 100644 --- a/rocketpy/mathutils/function.py +++ b/rocketpy/mathutils/function.py @@ -5,6 +5,7 @@ """ import warnings +from bisect import bisect_left from collections.abc import Iterable from copy import deepcopy from inspect import signature From 7cd270b956fa020305a6a6bdc45d52cab7283ce7 Mon Sep 17 00:00:00 2001 From: MateusStano Date: Sat, 30 Mar 2024 22:12:02 +0100 Subject: [PATCH 07/15] MNT: remove unecessary varibale intialization --- rocketpy/mathutils/function.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/rocketpy/mathutils/function.py b/rocketpy/mathutils/function.py index 0cdb139cc..20c9e41d3 100644 --- a/rocketpy/mathutils/function.py +++ b/rocketpy/mathutils/function.py @@ -14,7 +14,6 @@ import matplotlib.pyplot as plt import numpy as np from scipy import integrate, linalg, optimize -from bisect import bisect_left try: from functools import cached_property @@ -120,11 +119,6 @@ def __init__( (II) Fields in CSV files may be enclosed in double quotes. If fields are not quoted, double quotes should not appear inside them. """ - # initialize variables to avoid errors when being called by other methods - self.__polynomial_coefficients__ = None - self.__akima_coefficients__ = None - self.__spline_coefficients__ = None - # initialize parameters self.source = source self.__inputs__ = inputs From 0bd6f6396e8716fb0c76025dd2e29380528cfa0d Mon Sep 17 00:00:00 2001 From: MateusStano Date: Sat, 30 Mar 2024 22:37:48 +0100 Subject: [PATCH 08/15] ENH: x, y and z array for all ndarray functions --- rocketpy/mathutils/function.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/rocketpy/mathutils/function.py b/rocketpy/mathutils/function.py index 20c9e41d3..731fc7810 100644 --- a/rocketpy/mathutils/function.py +++ b/rocketpy/mathutils/function.py @@ -240,7 +240,7 @@ def set_source(self, source): self.__dom_dim__ = new_total_dim - 1 self.__inputs__ = self.__dom_dim__ * self.__inputs__ - # if Function is 1D, sort source by x. If 2D, set z + # set x and y. If Function is 2D, also set z if self.__dom_dim__ == 1: source = source[source[:, 0].argsort()] self.x_array = source[:, 0] @@ -248,7 +248,7 @@ def set_source(self, source): self.y_array = source[:, 1] self.y_initial, self.y_final = self.y_array[0], self.y_array[-1] self.get_value_opt = self._get_value_opt_1d - elif self.__dom_dim__ == 2: + elif self.__dom_dim__ > 1: self.x_array = source[:, 0] self.x_initial, self.x_final = self.x_array[0], self.x_array[-1] self.y_array = source[:, 1] @@ -256,8 +256,6 @@ def set_source(self, source): self.z_array = source[:, 2] self.z_initial, self.z_final = self.z_array[0], self.z_array[-1] self.get_value_opt = self._get_value_opt_nd - else: - self.get_value_opt = self._get_value_opt_nd self.source = source self.set_interpolation(self.__interpolation__) From ca1ce44c912223c2906341e16077d947027cf7ed Mon Sep 17 00:00:00 2001 From: MateusStano <69485049+MateusStano@users.noreply.github.com> Date: Sat, 30 Mar 2024 19:28:37 -0300 Subject: [PATCH 09/15] Update rocketpy/mathutils/function.py Co-authored-by: Gui-FernandesBR <63590233+Gui-FernandesBR@users.noreply.github.com> --- rocketpy/mathutils/function.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/rocketpy/mathutils/function.py b/rocketpy/mathutils/function.py index 731fc7810..9dd9ebdf3 100644 --- a/rocketpy/mathutils/function.py +++ b/rocketpy/mathutils/function.py @@ -366,8 +366,7 @@ def linear_interpolation(x, x_min, x_max, x_data, y_data, coeffs): elif interpolation == 1: # polynomial def polynomial_interpolation(x, x_min, x_max, x_data, y_data, coeffs): - y = np.sum(coeffs * x ** np.arange(len(coeffs))) - return y + return np.sum(coeffs * x ** np.arange(len(coeffs))) self._interpolation_func = polynomial_interpolation From 8cfaedcae674f317e5b40d5829e5be6e66a1b46b Mon Sep 17 00:00:00 2001 From: MateusStano <69485049+MateusStano@users.noreply.github.com> Date: Sun, 31 Mar 2024 07:45:26 -0300 Subject: [PATCH 10/15] Update rocketpy/mathutils/function.py Co-authored-by: Gui-FernandesBR <63590233+Gui-FernandesBR@users.noreply.github.com> --- rocketpy/mathutils/function.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/rocketpy/mathutils/function.py b/rocketpy/mathutils/function.py index 9dd9ebdf3..94c9d70ff 100644 --- a/rocketpy/mathutils/function.py +++ b/rocketpy/mathutils/function.py @@ -501,8 +501,7 @@ def _get_value_opt_nd(self, *args): """Evaluate the Function at a single point (x, y, z). This method is used when the Function is N-D.""" # always use shepard for N-D functions - y = self.__interpolate_shepard__(list(args)) - return y + return self.__interpolate_shepard__(list(args)) def set_discrete( self, From 07f9ed50dccf78eda4095042d8868185ecd152dd Mon Sep 17 00:00:00 2001 From: MateusStano Date: Sun, 31 Mar 2024 12:50:25 +0200 Subject: [PATCH 11/15] MNT: remove unecessary list casting --- rocketpy/mathutils/function.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rocketpy/mathutils/function.py b/rocketpy/mathutils/function.py index 94c9d70ff..2e3ea0e63 100644 --- a/rocketpy/mathutils/function.py +++ b/rocketpy/mathutils/function.py @@ -501,7 +501,7 @@ def _get_value_opt_nd(self, *args): """Evaluate the Function at a single point (x, y, z). This method is used when the Function is N-D.""" # always use shepard for N-D functions - return self.__interpolate_shepard__(list(args)) + return self.__interpolate_shepard__(args) def set_discrete( self, From fa8a89208fcfb0234df25ed88d7c4d457c307643 Mon Sep 17 00:00:00 2001 From: MateusStano Date: Sun, 31 Mar 2024 12:50:44 +0200 Subject: [PATCH 12/15] MNT: return interp and extrap results directly --- rocketpy/mathutils/function.py | 21 +++++++-------------- 1 file changed, 7 insertions(+), 14 deletions(-) diff --git a/rocketpy/mathutils/function.py b/rocketpy/mathutils/function.py index 2e3ea0e63..02277ce1c 100644 --- a/rocketpy/mathutils/function.py +++ b/rocketpy/mathutils/function.py @@ -358,8 +358,7 @@ def linear_interpolation(x, x_min, x_max, x_data, y_data, coeffs): y_left = y_data[x_interval - 1] dx = float(x_data[x_interval] - x_left) dy = float(y_data[x_interval] - y_left) - y = (x - x_left) * (dy / dx) + y_left - return y + return (x - x_left) * (dy / dx) + y_left self._interpolation_func = linear_interpolation @@ -376,8 +375,7 @@ def akima_interpolation(x, x_min, x_max, x_data, y_data, coeffs): x_interval = bisect_left(x_data, x) x_interval = x_interval if x_interval != 0 else 1 a = coeffs[4 * x_interval - 4 : 4 * x_interval] - y = a[3] * x**3 + a[2] * x**2 + a[1] * x + a[0] - return y + return a[3] * x**3 + a[2] * x**2 + a[1] * x + a[0] self._interpolation_func = akima_interpolation @@ -388,8 +386,7 @@ def spline_interpolation(x, x_min, x_max, x_data, y_data, coeffs): x_interval = max(x_interval, 1) a = coeffs[:, x_interval - 1] x = x - x_data[x_interval - 1] - y = a[3] * x**3 + a[2] * x**2 + a[1] * x + a[0] - return y + return a[3] * x**3 + a[2] * x**2 + a[1] * x + a[0] self._interpolation_func = spline_interpolation @@ -421,21 +418,18 @@ def natural_extrapolation(x, x_min, x_max, x_data, y_data, coeffs): y_left = y_data[x_interval - 1] dx = float(x_data[x_interval] - x_left) dy = float(y_data[x_interval] - y_left) - y = (x - x_left) * (dy / dx) + y_left - return y + return (x - x_left) * (dy / dx) + y_left elif interpolation == 1: # polynomial def natural_extrapolation(x, x_min, x_max, x_data, y_data, coeffs): - y = np.sum(coeffs * x ** np.arange(len(coeffs))) - return y + return np.sum(coeffs * x ** np.arange(len(coeffs))) elif interpolation == 2: # akima def natural_extrapolation(x, x_min, x_max, x_data, y_data, coeffs): a = coeffs[:4] if x < x_min else coeffs[-4:] - y = a[3] * x**3 + a[2] * x**2 + a[1] * x + a[0] - return y + return a[3] * x**3 + a[2] * x**2 + a[1] * x + a[0] elif interpolation == 3: # spline @@ -446,8 +440,7 @@ def natural_extrapolation(x, x_min, x_max, x_data, y_data, coeffs): else: a = coeffs[:, -1] x = x - x_data[-2] - y = a[3] * x**3 + a[2] * x**2 + a[1] * x + a[0] - return y + return a[3] * x**3 + a[2] * x**2 + a[1] * x + a[0] self._extrapolation_func = natural_extrapolation elif extrapolation == 2: # constant From 29bb5fae988226ad2e119171a78e124e95275f87 Mon Sep 17 00:00:00 2001 From: MateusStano Date: Sun, 31 Mar 2024 13:04:23 +0200 Subject: [PATCH 13/15] ENH: completely private methods --- rocketpy/mathutils/function.py | 44 +++++++++++++++++----------------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/rocketpy/mathutils/function.py b/rocketpy/mathutils/function.py index 02277ce1c..4a7c81154 100644 --- a/rocketpy/mathutils/function.py +++ b/rocketpy/mathutils/function.py @@ -147,7 +147,7 @@ def set_inputs(self, inputs): ------- self : Function """ - self.__inputs__ = self._validate_inputs(inputs) + self.__inputs__ = self.__validate_inputs(inputs) return self def set_outputs(self, outputs): @@ -162,7 +162,7 @@ def set_outputs(self, outputs): ------- self : Function """ - self.__outputs__ = self._validate_outputs(outputs) + self.__outputs__ = self.__validate_outputs(outputs) return self def set_source(self, source): @@ -211,7 +211,7 @@ def set_source(self, source): self : Function Returns the Function instance. """ - source = self._validate_source(source) + source = self.__validate_source(source) # Handle callable source or number source if callable(source): @@ -247,7 +247,7 @@ def set_source(self, source): self.x_initial, self.x_final = self.x_array[0], self.x_array[-1] self.y_array = source[:, 1] self.y_initial, self.y_final = self.y_array[0], self.y_array[-1] - self.get_value_opt = self._get_value_opt_1d + self.get_value_opt = self.__get_value_opt_1d elif self.__dom_dim__ > 1: self.x_array = source[:, 0] self.x_initial, self.x_final = self.x_array[0], self.x_array[-1] @@ -255,7 +255,7 @@ def set_source(self, source): self.y_initial, self.y_final = self.y_array[0], self.y_array[-1] self.z_array = source[:, 2] self.z_initial, self.z_final = self.z_array[0], self.z_array[-1] - self.get_value_opt = self._get_value_opt_nd + self.get_value_opt = self.__get_value_opt_nd self.source = source self.set_interpolation(self.__interpolation__) @@ -300,12 +300,12 @@ def set_interpolation(self, method="spline"): self : Function """ if not callable(self.source): - self.__interpolation__ = self._validate_interpolation(method) - self._update_interpolation_coefficients(self.__interpolation__) - self._set_interpolation_func() + self.__interpolation__ = self.__validate_interpolation(method) + self.__update_interpolation_coefficients(self.__interpolation__) + self.__set_interpolation_func() return self - def _update_interpolation_coefficients(self, method): + def __update_interpolation_coefficients(self, method): """Update interpolation coefficients for the given method.""" # Spline, akima and polynomial need data processing # Shepard, and linear do not @@ -339,11 +339,11 @@ def set_extrapolation(self, method="constant"): The Function object. """ if not callable(self.source): - self.__extrapolation__ = self._validate_extrapolation(method) - self._set_extrapolation_func() + self.__extrapolation__ = self.__validate_extrapolation(method) + self.__set_extrapolation_func() return self - def _set_interpolation_func(self): + def __set_interpolation_func(self): """Defines interpolation function used by the Function. Each interpolation method has its own function with exception of shepard, which has its interpolation/extrapolation function defined in @@ -393,7 +393,7 @@ def spline_interpolation(x, x_min, x_max, x_data, y_data, coeffs): elif interpolation == 4: # shepard does not use interpolation function self._interpolation_func = None - def _set_extrapolation_func(self): + def __set_extrapolation_func(self): """Defines extrapolation function used by the Function. Each extrapolation method has its own function. The function is stored in the attribute _extrapolation_func.""" @@ -460,12 +460,12 @@ def set_get_value_opt(self): if callable(self.source): self.get_value_opt = self.source elif self.__dom_dim__ == 1: - self.get_value_opt = self._get_value_opt_1d + self.get_value_opt = self.__get_value_opt_1d elif self.__dom_dim__ > 1: - self.get_value_opt = self._get_value_opt_nd + self.get_value_opt = self.__get_value_opt_nd return self - def _get_value_opt_1d(self, x): + def __get_value_opt_1d(self, x): """Evaluate the Function at a single point x. This method is used when the Function is 1-D. @@ -490,7 +490,7 @@ def _get_value_opt_1d(self, x): y = self._extrapolation_func(x, x_min, x_max, x_data, y_data, coeffs) return y - def _get_value_opt_nd(self, *args): + def __get_value_opt_nd(self, *args): """Evaluate the Function at a single point (x, y, z). This method is used when the Function is N-D.""" # always use shepard for N-D functions @@ -2883,7 +2883,7 @@ def savetxt( np.savetxt(file, data_points, fmt=fmt, delimiter=delimiter, newline=newline) # Input validators - def _validate_source(self, source): + def __validate_source(self, source): """Used to validate the source parameter for creating a Function object. Parameters @@ -2951,7 +2951,7 @@ def source_function(_): # If source is a callable function return source - def _validate_inputs(self, inputs): + def __validate_inputs(self, inputs): """Used to validate the inputs parameter for creating a Function object. It sets a default value if it is not provided. @@ -2990,7 +2990,7 @@ def _validate_inputs(self, inputs): "the length of the domain dimension." ) - def _validate_outputs(self, outputs): + def __validate_outputs(self, outputs): """Used to validate the outputs parameter for creating a Function object. It sets a default value if it is not provided. @@ -3016,7 +3016,7 @@ def _validate_outputs(self, outputs): ) return outputs - def _validate_interpolation(self, interpolation): + def __validate_interpolation(self, interpolation): if self.__dom_dim__ == 1: # possible interpolation values: linear, polynomial, akima and spline if interpolation is None: @@ -3044,7 +3044,7 @@ def _validate_interpolation(self, interpolation): interpolation = "shepard" return interpolation - def _validate_extrapolation(self, extrapolation): + def __validate_extrapolation(self, extrapolation): if self.__dom_dim__ == 1: if extrapolation is None: extrapolation = "constant" From 3bee49b3c723dcdc33380254513e985ffec0cf97 Mon Sep 17 00:00:00 2001 From: MateusStano Date: Thu, 4 Apr 2024 13:17:37 +0200 Subject: [PATCH 14/15] TST: improve tests coverage --- rocketpy/mathutils/function.py | 18 ++++-------------- tests/test_function.py | 27 ++++++++++++++++++++++++++- tests/unit/test_function.py | 28 ++++++++++++++++++++++++++++ 3 files changed, 58 insertions(+), 15 deletions(-) diff --git a/rocketpy/mathutils/function.py b/rocketpy/mathutils/function.py index 4a7c81154..ea98bff78 100644 --- a/rocketpy/mathutils/function.py +++ b/rocketpy/mathutils/function.py @@ -230,16 +230,6 @@ def set_source(self, source): # Evaluate dimension self.__dom_dim__ = source.shape[1] - 1 - # Check to see if dimensions match incoming data set - new_total_dim = source.shape[1] - old_total_dim = self.__dom_dim__ + self.__img_dim__ - - # If they don't, update default values or throw error - if new_total_dim != old_total_dim: - # Update dimensions and inputs - self.__dom_dim__ = new_total_dim - 1 - self.__inputs__ = self.__dom_dim__ * self.__inputs__ - # set x and y. If Function is 2D, also set z if self.__dom_dim__ == 1: source = source[source[:, 0].argsort()] @@ -2970,8 +2960,8 @@ def __validate_inputs(self, inputs): return ["Scalar"] if isinstance(inputs, str): return [inputs] - if isinstance(inputs, list): - if len(inputs) == 1 and isinstance(inputs[0], str): + if isinstance(inputs, (list, tuple)): + if len(inputs) == 1: return inputs raise ValueError( "Inputs must be a string or a list of strings with " @@ -3008,8 +2998,8 @@ def __validate_outputs(self, outputs): return ["Scalar"] if isinstance(outputs, str): return [outputs] - if isinstance(outputs, list): - if len(outputs) > 1 or not isinstance(outputs[0], str): + if isinstance(outputs, (list, tuple)): + if len(outputs) > 1: raise ValueError( "Output must either be a string or a list of strings with " + f"one item. It currently has dimension ({len(outputs)})." diff --git a/tests/test_function.py b/tests/test_function.py index 70613671e..6f4122e47 100644 --- a/tests/test_function.py +++ b/tests/test_function.py @@ -182,7 +182,32 @@ def test_extrapolation_methods(linear_func): assert linear_func.get_extrapolation_method() == "constant" assert np.isclose(linear_func.get_value(-1), 0, atol=1e-6) - # Test natural + # Test natural for linear interpolation + linear_func.set_interpolation("linear") + assert isinstance(linear_func.set_extrapolation("natural"), Function) + linear_func.set_extrapolation("natural") + assert isinstance(linear_func.get_extrapolation_method(), str) + assert linear_func.get_extrapolation_method() == "natural" + assert np.isclose(linear_func.get_value(-1), -1, atol=1e-6) + + # Test natural for spline interpolation + linear_func.set_interpolation("spline") + assert isinstance(linear_func.set_extrapolation("natural"), Function) + linear_func.set_extrapolation("natural") + assert isinstance(linear_func.get_extrapolation_method(), str) + assert linear_func.get_extrapolation_method() == "natural" + assert np.isclose(linear_func.get_value(-1), -1, atol=1e-6) + + # Test natural for akima interpolation + linear_func.set_interpolation("akima") + assert isinstance(linear_func.set_extrapolation("natural"), Function) + linear_func.set_extrapolation("natural") + assert isinstance(linear_func.get_extrapolation_method(), str) + assert linear_func.get_extrapolation_method() == "natural" + assert np.isclose(linear_func.get_value(-1), -1, atol=1e-6) + + # Test natural for polynomial interpolation + linear_func.set_interpolation("polynomial") assert isinstance(linear_func.set_extrapolation("natural"), Function) linear_func.set_extrapolation("natural") assert isinstance(linear_func.get_extrapolation_method(), str) diff --git a/tests/unit/test_function.py b/tests/unit/test_function.py index 8bcefb818..17da59498 100644 --- a/tests/unit/test_function.py +++ b/tests/unit/test_function.py @@ -306,3 +306,31 @@ def test_remove_outliers_iqr(x, y, expected_x, expected_y): assert filtered_func.__interpolation__ == func.__interpolation__ assert filtered_func.__extrapolation__ == func.__extrapolation__ assert filtered_func.title == func.title + + +def test_set_get_value_opt(): + """Test the set_value_opt and get_value_opt methods of the Function class.""" + func = Function(lambda x: x**2) + func.source = np.array([[1, 1], [2, 4], [3, 9], [4, 16], [5, 25]]) + func.x_array = np.array([1, 2, 3, 4, 5]) + func.y_array = np.array([1, 4, 9, 16, 25]) + func.x_initial = 1 + func.x_final = 5 + func.set_interpolation("linear") + func.set_get_value_opt() + assert func.get_value_opt(2.5) == 6.5 + + +def test_get_image_dim(linear_func): + """Test the get_img_dim method of the Function class.""" + assert linear_func.get_image_dim() == 1 + + +def test_get_domain_dim(linear_func): + """Test the get_domain_dim method of the Function class.""" + assert linear_func.get_domain_dim() == 1 + + +def test_bool(linear_func): + """Test the __bool__ method of the Function class.""" + assert bool(linear_func) == True From 00a3c1f2971d23933cb86e0ed852dddfe72461c2 Mon Sep 17 00:00:00 2001 From: MateusStano Date: Thu, 4 Apr 2024 13:19:14 +0200 Subject: [PATCH 15/15] DEV: changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 90aa086ca..b96efbcc1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -32,6 +32,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/). ### Added +- ENH: Function Validation Rework & Swap `np.searchsorted` to `bisect_left` [#582](https://github.com/RocketPy-Team/RocketPy/pull/582) - ENH: Add new stability margin properties to Flight class [#572](https://github.com/RocketPy-Team/RocketPy/pull/572) - ENH: adds `Function.remove_outliers` method [#554](https://github.com/RocketPy-Team/RocketPy/pull/554)