diff --git a/CHANGELOG.md b/CHANGELOG.md index 4c3aa2373..a1e06c193 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -31,8 +31,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/). Attention: The newest changes should be on top --> ### Added - -- ENH: Add save functionality to `_MonteCarloPlots.all` method [#848](https://github.com/RocketPy-Team/RocketPy/pull/848) +- ENH: Built-in flight comparison tool (`FlightComparator`) to validate simulations against external data [#888](https://github.com/RocketPy-Team/RocketPy/pull/888) - ENH: Add persistent caching for ThrustCurve API [#881](https://github.com/RocketPy-Team/RocketPy/pull/881) - ENH: Compatibility with MERRA-2 atmosphere reanalysis files [#825](https://github.com/RocketPy-Team/RocketPy/pull/825) - ENH: Enable only radial burning [#815](https://github.com/RocketPy-Team/RocketPy/pull/815) diff --git a/docs/user/flight_comparator.rst b/docs/user/flight_comparator.rst new file mode 100644 index 000000000..162d4c173 --- /dev/null +++ b/docs/user/flight_comparator.rst @@ -0,0 +1,183 @@ +Flight Comparator +================= + +This example demonstrates how to use the RocketPy ``FlightComparator`` class to +compare a Flight simulation against external data sources. + +Users must explicitly create a `FlightComparator` instance. + + +This class is designed to compare a RocketPy Flight simulation against external +data sources, such as: + +- Real flight data (avionics logs, altimeter CSVs) +- Simulations from other software (OpenRocket, RASAero) +- Theoretical models or manual calculations + +Unlike ``CompareFlights`` (which compares multiple RocketPy simulations), +``FlightComparator`` specifically handles the challenge of aligning different +time steps and calculating error metrics (RMSE, MAE, etc.) between a +"Reference" simulation and "External" data. + +Importing classes +----------------- + +We will start by importing the necessary classes and modules: + +.. jupyter-execute:: + + import numpy as np + + from rocketpy import Environment, Rocket, SolidMotor, Flight + from rocketpy.simulation import FlightComparator, FlightDataImporter + +Create Simulation (Reference) +----------------------------- + +First, let's create the standard RocketPy simulation that will serve as our +"Ground Truth" or reference model. This follows the standard setup. + +.. jupyter-execute:: + + # 1. Setup Environment + env = Environment( + date=(2022, 10, 1, 12), + latitude=32.990254, + longitude=-106.974998, + elevation=1400, + ) + env.set_atmospheric_model(type="standard_atmosphere") + + # 2. Setup Motor + Pro75M1670 = SolidMotor( + thrust_source="../data/motors/cesaroni/Cesaroni_M1670.eng", + burn_time=3.9, + grain_number=5, + grain_density=1815, + grain_outer_radius=33 / 1000, + grain_initial_inner_radius=15 / 1000, + grain_initial_height=120 / 1000, + grain_separation=5 / 1000, + nozzle_radius=33 / 1000, + throat_radius=11 / 1000, + interpolation_method="linear", + coordinate_system_orientation="nozzle_to_combustion_chamber", + dry_mass=1.815, + dry_inertia=(0.125, 0.125, 0.002), + grains_center_of_mass_position=0.33, + center_of_dry_mass_position=0.317, + nozzle_position=0, + ) + + # 3. Setup Rocket + calisto = Rocket( + radius=127 / 2000, + mass=19.197 - 2.956, + inertia=(6.321, 6.321, 0.034), + power_off_drag="../data/calisto/powerOffDragCurve.csv", + power_on_drag="../data/calisto/powerOnDragCurve.csv", + center_of_mass_without_motor=0, + coordinate_system_orientation="tail_to_nose", + ) + + calisto.set_rail_buttons(0.0818, -0.618, 45) + calisto.add_motor(Pro75M1670, position=-1.255) + + # Add aerodynamic surfaces + nosecone = calisto.add_nose(length=0.55829, kind="vonKarman", position=0.71971) + fin_set = calisto.add_trapezoidal_fins( + n=4, + root_chord=0.120, + tip_chord=0.040, + span=0.100, + position=-1.04956, + cant_angle=0.5, + airfoil=("../data/calisto/fins/NACA0012-radians.txt", "radians"), + ) + tail = calisto.add_tail( + top_radius=0.0635, + bottom_radius=0.0435, + length=0.060, + position=-1.194656, + ) + + # 4. Simulate + flight = Flight( + rocket=calisto, + environment=env, + rail_length=5.2, + inclination=85, + heading=0, + ) + + # 5. Create FlightComparator instance + comparator = FlightComparator(flight) + +Adding Another Flight Object +---------------------------- + +You can compare against another RocketPy Flight simulation directly: + +.. jupyter-execute:: + + # Create a second simulation with slightly different parameters + flight2 = Flight( + rocket=calisto, + environment=env, + rail_length=5.0, # Slightly shorter rail + inclination=85, + heading=0, + ) + + # Add the second flight directly + comparator.add_data("Alternative Sim", flight2) + + print(f"Added variables from flight2: {list(comparator.data_sources['Alternative Sim'].keys())}") + +Importing External Data (dict) +------------------------------ + +The primary data format expected by ``FlightComparator.add_data`` is a dictionary +where keys are variable names (e.g. ``"z"``, ``"vz"``, ``"altitude"``) and values +are either: + +- A RocketPy ``Function`` object, or +- A tuple of ``(time_array, data_array)``. + +Let's create some synthetic external data to compare against our reference +simulation: + +.. jupyter-execute:: + + # Generate synthetic external data with realistic noise + time_external = np.linspace(0, flight.t_final, 80) # Different time steps + external_altitude = flight.z(time_external) + np.random.normal(0, 3, 80) # 3m noise + external_velocity = flight.vz(time_external) + np.random.normal(0, 0.5, 80) # 0.5 m/s noise + + # Add the external data to our comparator + comparator.add_data( + "External Simulator", + { + "altitude": (time_external, external_altitude), + "vz": (time_external, external_velocity), + } + ) + +Running Comparisons +------------------- + +Now we can run the various comparison methods: + +.. jupyter-execute:: + + # Generate summary with key events + comparator.summary() + + # Compare specific variable + comparator.compare("altitude") + + # Compare all available variables + comparator.all() + + # Plot 2D trajectory + comparator.trajectories_2d(plane="xz") diff --git a/rocketpy/plots/compare/compare_flights.py b/rocketpy/plots/compare/compare_flights.py index 521b2cf6b..4ff064858 100644 --- a/rocketpy/plots/compare/compare_flights.py +++ b/rocketpy/plots/compare/compare_flights.py @@ -139,7 +139,7 @@ def positions( limit and second item, the y axis upper limit. If set to None, will be calculated automatically by matplotlib. legend : bool, optional - Weather or not to show the legend, by default True + Whether or not to show the legend, by default True filename : str, optional If a filename is provided, the plot will be saved to a file, by default None. Image options are: png, pdf, ps, eps and svg. @@ -196,7 +196,7 @@ def velocities( limit and second item, the y axis upper limit. If set to None, will be calculated automatically by matplotlib. legend : bool, optional - Weather or not to show the legend, by default True + Whether or not to show the legend, by default True filename : str, optional If a filename is provided, the plot will be saved to a file, by default None. Image options are: png, pdf, ps, eps and svg. @@ -253,7 +253,7 @@ def stream_velocities( limit and second item, the y axis upper limit. If set to None, will be calculated automatically by matplotlib. legend : bool, optional - Weather or not to show the legend, by default True + Whether or not to show the legend, by default True filename : str, optional If a filename is provided, the plot will be saved to a file, by default None. Image options are: png, pdf, ps, eps and svg. @@ -319,7 +319,7 @@ def accelerations( limit and second item, the y axis upper limit. If set to None, will be calculated automatically by matplotlib. legend : bool, optional - Weather or not to show the legend, by default True + Whether or not to show the legend, by default True filename : str, optional If a filename is provided, the plot will be saved to a file, by default None. Image options are: png, pdf, ps, eps and svg. @@ -375,7 +375,7 @@ def euler_angles( limit and second item, the y axis upper limit. If set to None, will be calculated automatically by matplotlib. legend : bool, optional - Weather or not to show the legend, by default True + Whether or not to show the legend, by default True filename : str, optional If a filename is provided, the plot will be saved to a file, by default None. Image options are: png, pdf, ps, eps and svg. @@ -435,7 +435,7 @@ def quaternions( limit and second item, the y axis upper limit. If set to None, will be calculated automatically by matplotlib. legend : bool, optional - Weather or not to show the legend, by default True + Whether or not to show the legend, by default True filename : str, optional If a filename is provided, the plot will be saved to a file, by default None. Image options are: png, pdf, ps, eps and svg. @@ -491,7 +491,7 @@ def attitude_angles( limit and second item, the y axis upper limit. If set to None, will be calculated automatically by matplotlib. legend : bool, optional - Weather or not to show the legend, by default True + Whether or not to show the legend, by default True filename : str, optional If a filename is provided, the plot will be saved to a file, by default None. Image options are: png, pdf, ps, eps and svg. @@ -546,7 +546,7 @@ def angular_velocities( limit and second item, the y axis upper limit. If set to None, will be calculated automatically by matplotlib. legend : bool, optional - Weather or not to show the legend, by default True + Whether or not to show the legend, by default True filename : str, optional If a filename is provided, the plot will be saved to a file, by default None. Image options are: png, pdf, ps, eps and svg. @@ -601,7 +601,7 @@ def angular_accelerations( limit and second item, the y axis upper limit. If set to None, will be calculated automatically by matplotlib. legend : bool, optional - Weather or not to show the legend, by default True + Whether or not to show the legend, by default True filename : str, optional If a filename is provided, the plot will be saved to a file, by default None. Image options are: png, pdf, ps, eps and svg. @@ -661,7 +661,7 @@ def aerodynamic_forces( limit and second item, the y axis upper limit. If set to None, will be calculated automatically by matplotlib. legend : bool, optional - Weather or not to show the legend, by default True + Whether or not to show the legend, by default True filename : str, optional If a filename is provided, the plot will be saved to a file, by default None. Image options are: png, pdf, ps, eps and svg. @@ -720,7 +720,7 @@ def aerodynamic_moments( limit and second item, the y axis upper limit. If set to None, will be calculated automatically by matplotlib. legend : bool, optional - Weather or not to show the legend, by default True + Whether or not to show the legend, by default True filename : str, optional If a filename is provided, the plot will be saved to a file, by default None. @@ -774,7 +774,7 @@ def energies( limit and second item, the y axis upper limit. If set to None, will be calculated automatically by matplotlib. legend : bool, optional - Weather or not to show the legend, by default True + Whether or not to show the legend, by default True filename : str, optional If a filename is provided, the plot will be saved to a file, by default None. @@ -834,7 +834,7 @@ def powers( limit and second item, the y axis upper limit. If set to None, will be calculated automatically by matplotlib. legend : bool, optional - Weather or not to show the legend, by default True + Whether or not to show the legend, by default True filename : str, optional If a filename is provided, the plot will be saved to a file, by default None. @@ -889,7 +889,7 @@ def rail_buttons_forces( limit and second item, the y axis upper limit. If set to None, will be calculated automatically by matplotlib. legend : bool, optional - Weather or not to show the legend, by default True + Whether or not to show the legend, by default True filename : str, optional If a filename is provided, the plot will be saved to a file, by default None. @@ -955,7 +955,7 @@ def angles_of_attack( limit and second item, the y axis upper limit. If set to None, will be calculated automatically by matplotlib. legend : bool, optional - Weather or not to show the legend, by default True + Whether or not to show the legend, by default True filename : str, optional If a filename is provided, the plot will be saved to a file, by default None. @@ -1011,7 +1011,7 @@ def fluid_mechanics( limit and second item, the y axis upper limit. If set to None, will be calculated automatically by matplotlib. legend : bool, optional - Weather or not to show the legend, by default True + Whether or not to show the legend, by default True filename : str, optional If a filename is provided, the plot will be saved to a file, by default None. @@ -1074,7 +1074,7 @@ def stability_margin( limit and second item, the y axis upper limit. If set to None, will be calculated automatically by matplotlib. legend : bool, optional - Weather or not to show the legend, by default True + Whether or not to show the legend, by default True filename : str, optional If a filename is provided, the plot will be saved to a file, by default None. @@ -1113,7 +1113,7 @@ def attitude_frequency( limit and second item, the y axis upper limit. If set to None, will be calculated automatically by matplotlib. legend : bool, optional - Weather or not to show the legend, by default True + Whether or not to show the legend, by default True filename : str, optional If a filename is provided, the plot will be saved to a file, by default None. diff --git a/rocketpy/simulation/flight_comparator.py b/rocketpy/simulation/flight_comparator.py new file mode 100644 index 000000000..7f1b34286 --- /dev/null +++ b/rocketpy/simulation/flight_comparator.py @@ -0,0 +1,904 @@ +import warnings + +import matplotlib.pyplot as plt +import numpy as np + +from rocketpy.mathutils import Function +from rocketpy.simulation.flight import Flight +from rocketpy.simulation.flight_data_importer import FlightDataImporter + +from ..plots.plot_helpers import show_or_save_fig + + +class FlightComparator: + """ + A class to compare a RocketPy Flight simulation against external data sources + (such as flight logs, OpenRocket simulations, RASAero). + + This class handles the time-interpolation required to compare datasets + recorded at different frequencies, and computes error metrics (RMSE, MAE, etc.) + between your RocketPy simulation and external or reference data. + + + Parameters + ---------- + flight : Flight + The reference RocketPy Flight object to compare against. + + Attributes + ---------- + flight : Flight + The reference RocketPy Flight object to compare against. + data_sources : dict + Dictionary storing external data sources in the format + {'Source Name': {'variable': Function}}. + + Examples + -------- + + .. code-block:: python + + from rocketpy.simulation import FlightComparator + + # Suppose you have a Flight object named 'my_flight' + comparator = FlightComparator(my_flight) + + # Add external data (e.g., from OpenRocket or logs) + comparator.add_data('OpenRocket', { + 'altitude': (time_array, altitude_array), + 'vz': (time_array, velocity_array) + }) + + # You can also add another RocketPy Flight directly: + comparator.add_data('OtherSimulation', other_flight) + + # Run comparisons + comparator.compare('altitude') + comparator.summary() + events_table = comparator.compare_key_events() + """ + + DEFAULT_GRID_POINTS = 1000 # number of points for interpolation grids + + def __init__(self, flight: Flight): + """ + Initialize the comparator with a reference RocketPy Flight. + + Parameters + ---------- + flight : rocketpy.Flight + The reference RocketPy Flight object to compare against. + + Returns + ------- + None + """ + # Duck-typed validation gives clear errors for Flight-like objects, more flexible than an isinstance check + required_attrs = ("t_final", "apogee", "apogee_time", "impact_velocity") + missing = [attr for attr in required_attrs if not hasattr(flight, attr)] + if missing: + raise TypeError( + "flight must be a rocketpy.Flight or Flight-like object with attributes " + f"{required_attrs}. Missing: {', '.join(missing)}" + ) + + self.flight = flight + self.data_sources = {} # The format is {'Source Name': {'variable': Function}} + + def add_data(self, label, data_dict): # pylint: disable=too-many-statements + """ + Add an external dataset to the comparator. + + Parameters + ---------- + label : str + Name of the data source (e.g., "Avionics", "OpenRocket", "RASAero"). + data_dict : dict, Flight, or FlightDataImporter + External data to be compared. + + If a dict, keys must be variable names (e.g., 'z', 'vz', 'az', 'altitude') + and values can be: + - A RocketPy Function object + - A tuple/list of (time_array, data_array) + + If a Flight object is provided, standard Flight attributes such as + 'z', 'vz', 'x', 'y', 'speed', 'vx', 'vy', 'ax', 'ay', 'az', 'acceleration' + will be registered automatically when available. + + If a FlightDataImporter object is provided, all flight attributes will be + registered automatically. In both cases, 'altitude' will be aliased to 'z' + if present. + """ + + if isinstance(data_dict, dict) and not data_dict: + raise ValueError("data_dict cannot be empty") + + processed_data = {} + + # Case 1: dict + if isinstance(data_dict, dict): + for key, value in data_dict.items(): + if isinstance(value, Function): + processed_data[key] = value + elif isinstance(value, (tuple, list)) and len(value) == 2: + time_arr, data_arr = value + processed_data[key] = Function( + np.column_stack((time_arr, data_arr)), + inputs="Time (s)", + outputs=key, + interpolation="linear", + ) + else: + warnings.warn( + f"Skipping '{key}' in '{label}'. Format not recognized. " + "Expected RocketPy Function or (time, data) tuple." + ) + + # Case 2: Flight instance + elif isinstance(data_dict, Flight): + external_flight = data_dict + candidate_vars = [ + "z", + "vz", + "x", + "y", + "speed", + "vx", + "vy", + "ax", + "ay", + "az", + "acceleration", + ] + for var in candidate_vars: + if hasattr(external_flight, var): + value = getattr(external_flight, var) + if isinstance(value, Function): + processed_data[var] = value + + # Provide 'altitude' alias for convenience if 'z' exists + if "z" in processed_data and "altitude" not in processed_data: + processed_data["altitude"] = processed_data["z"] + + if not processed_data: + warnings.warn( + f"No comparable variables found when using Flight " + f"object for data source '{label}'." + ) + + # Case 3: FlightDataImporter instance + elif isinstance(data_dict, FlightDataImporter): + importer = data_dict + candidate_vars = [ + "z", + "vz", + "x", + "y", + "speed", + "vx", + "vy", + "ax", + "ay", + "az", + "acceleration", + ] + for var in candidate_vars: + if hasattr(importer, var): + value = getattr(importer, var) + if isinstance(value, Function): + processed_data[var] = value + + # Provide 'altitude' alias for convenience if 'z' exists + if "z" in processed_data and "altitude" not in processed_data: + processed_data["altitude"] = processed_data["z"] + + if not processed_data: + warnings.warn( + f"No comparable variables found when using FlightDataImporter " + f"for data source '{label}'." + ) + + else: + warnings.warn( + f"Data source '{label}' not recognized. Expected a dict, Flight, " + "or FlightDataImporter object." + ) + + if label in self.data_sources: + warnings.warn(f"Data source '{label}' already exists. Overwriting.") + + self.data_sources[label] = processed_data + print( + f"Added data source '{label}' with variables: {list(processed_data.keys())}" + ) + + def _process_time_range(self, time_range): + """ + Validate and normalize the time_range argument. + + Parameters + ---------- + time_range : tuple of (float, float) or list of (float, float) or None + Tuple or list specifying the start and end times (in seconds) for the comparison. + If None, the full flight duration [0, flight.t_final] is used. + + Returns + ------- + tuple of (float, float) + The validated (t_min, t_max) time range in seconds, where + 0.0 <= t_min < t_max <= flight.t_final. + + Raises + ------ + TypeError + If time_range is not a tuple or list of two numeric values. + ValueError + If time_range values are invalid or out of bounds. + """ + if time_range is None: + return 0.0, self.flight.t_final + + if not isinstance(time_range, (tuple, list)) or len(time_range) != 2: + raise TypeError( + "time_range must be a (start_time, end_time) tuple or list." + ) + + t_min, t_max = time_range + if not isinstance(t_min, (int, float)) or not isinstance(t_max, (int, float)): + raise TypeError("time_range values must be numeric.") + + if t_min >= t_max: + raise ValueError("time_range[0] must be strictly less than time_range[1].") + + if t_min < 0 or t_max > self.flight.t_final: + raise ValueError( + "time_range must lie within [0, flight.t_final]. " + f"Got [{t_min}, {t_max}], flight.t_final={self.flight.t_final}." + ) + + return float(t_min), float(t_max) + + def _build_time_grid(self, t_min, t_max): + """ + Build a time grid for interpolation between t_min and t_max. + + Parameters + ---------- + t_min : float + Start time of the grid, in seconds. + t_max : float + End time of the grid, in seconds. + + Returns + ------- + numpy.ndarray + Array of time points (in seconds) linearly spaced between t_min and t_max, + with length equal to DEFAULT_GRID_POINTS. + """ + return np.linspace(t_min, t_max, self.DEFAULT_GRID_POINTS) + + def _setup_compare_figure(self, figsize, attribute): + """ + Create a matplotlib figure and axes for the compare() method. + + Parameters + ---------- + figsize : tuple of float + Size of the figure in inches as (width, height). + attribute : str + Name of the attribute being compared, used for the plot title. + + Returns + ------- + tuple + A tuple containing: + - fig : matplotlib.figure.Figure + The created figure object. + - ax1 : matplotlib.axes.Axes + The axes object for the main comparison plot. + - ax2 : matplotlib.axes.Axes + The axes object for the residuals (error) plot. + """ + fig, (ax1, ax2) = plt.subplots( + 2, + 1, + figsize=figsize, + sharex=True, + gridspec_kw={"height_ratios": [2, 1]}, + ) + ax1.set_title(f"Flight Comparison: {attribute}") + ax2.set_title("Residuals (Simulation - External)") + ax2.set_xlabel("Time (s)") + return fig, ax1, ax2 + + def _plot_reference_series(self, ax, t_grid, y_sim): + """ + Plot RocketPy reference curve on the given axes. + + Parameters + ---------- + ax : matplotlib.axes.Axes + he axes object on which to plot the reference curve. + t_grid : numpy.ndarray + Array of time points, in seconds. + y_sim : numpy.ndarray + Array of simulated values corresponding to t_grid. + + Returns + ------- + None + + """ + ax.plot( + t_grid, + y_sim, + label="RocketPy Simulation", + linewidth=2, + color="black", + alpha=0.8, + ) + + def _plot_external_sources( + self, + attribute, + t_grid, + y_sim, + ax_values, + ax_errors, + ): + """ + Plot external data sources and print error metrics. + + Parameters + ---------- + attribute : str + Name of the attribute to compare (e.g., 'altitude', 'vz'). + t_grid : np.ndarray + 1D array of time points (in seconds) at which to evaluate and plot the data. + y_sim : np.ndarray + 1D array of simulated values corresponding to t_grid. + ax_values : matplotlib.axes.Axes + Axes object to plot the simulation and external data values. + ax_errors : matplotlib.axes.Axes + Axes object to plot the error (residuals) between simulation and external data. + + Returns + ------- + bool + True if at least one external source had the specified attribute and data was plotted. + """ + has_plots = False + + print(f"\n{'-' * 20}") + print(f"COMPARISON REPORT: {attribute}") + print(f"{'-' * 20}") + + for label, dataset in self.data_sources.items(): + if attribute not in dataset: + continue + + has_plots = True + ext_func = dataset[attribute] + + # Interpolate External Data onto the same grid + y_ext = ext_func(t_grid) + + # Calculate Error (Residuals) + error = y_sim - y_ext + + # Calculate Metrics + mae = np.mean(np.abs(error)) # Mean Absolute Error + rmse = np.sqrt(np.mean(error**2)) # Root Mean Square Error + max_dev = np.max(np.abs(error)) # Max Deviation + + mean_abs_y_sim = np.mean(np.abs(y_sim)) + relative_error_pct = ( + (rmse / mean_abs_y_sim) * 100 if mean_abs_y_sim != 0 else np.inf + ) + + # Print Metrics + print(f"Source: {label}") + print(f" - MAE: {mae:.4f}") + print(f" - RMSE: {rmse:.4f}") + print(f" - Max Deviation: {max_dev:.4f}") + print(f" - Relative Error: {relative_error_pct:.2f}%") + + # Plot Data + ax_values.plot(t_grid, y_ext, label=label, linestyle="--") + + # Plot Error + ax_errors.plot(t_grid, error, label=f"Error ({label})") + + return has_plots + + def _finalize_compare_figure( + self, fig, ax_values, ax_errors, attribute, legend, filename + ): + """ + Apply formatting, legends, and show/save the comparison figure. + + Parameters + ---------- + fig : matplotlib.figure.Figure + The figure object containing the comparison plots. + ax_values : matplotlib.axes.Axes + The axes object displaying the compared values. + ax_errors : matplotlib.axes.Axes + The axes object displaying the residuals/errors. + attribute : str + Name of the attribute being compared (used for labels). + legend : bool + Whether to display legends on both axes. + filename : str or None + If provided, save the figure to this file path. If None, display the figure. + + Returns + ------- + None + """ + ax_values.set_ylabel(attribute) + ax_values.grid(True, linestyle=":", alpha=0.6) + + ax_errors.set_ylabel("Difference") + ax_errors.grid(True, linestyle=":", alpha=0.6) + + if legend: + ax_values.legend() + ax_errors.legend() + + fig.tight_layout() + show_or_save_fig(fig, filename) + if filename: + print(f"Plot saved to file: {filename}") + + def compare( # pylint: disable=too-many-statements + self, + attribute, + time_range=None, + figsize=(10, 8), + legend=True, + filename=None, + ): + """ + Compare a specific attribute (e.g., altitude, velocity) across all added data sources. + + This method generates a plot comparing the specified attribute from the reference + RocketPy Flight object and all added external data sources (e.g., OpenRocket, flight logs). + It interpolates all data onto a common time grid, computes error metrics (RMSE, MAE, + relative error), and displays or saves the resulting plot. + + Parameters + ---------- + attribute : str + Name of the attribute to compare (e.g., "altitude", "vz", "ax"). + The attribute must be present as a callable (function or property) in the + reference Flight object and in each external data source. + time_range : tuple of float, optional + Tuple specifying the time range (t_min, t_max) in seconds for the comparison. + If None (default), uses the full time range of the reference Flight. + figsize : tuple of float, optional + Size of the figure in inches, as (width, height). Default is (10, 8). + legend : bool, optional + Whether to display a legend on the plot. Default is True. + filename : str or None, optional + If provided, saves the plot to the specified file path. If None (default), + the plot is shown interactively. + + Returns + ------- + None + + """ + # 1. Get RocketPy Simulation Data + if not hasattr(self.flight, attribute): + warnings.warn( + f"Attribute '{attribute}' not found in the RocketPy Flight object." + ) + return + + sim_func = getattr(self.flight, attribute) + + # 2. Process time range and build grid + t_min, t_max = self._process_time_range(time_range) + t_grid = self._build_time_grid(t_min, t_max) + + # Interpolate Simulation onto the grid + y_sim = sim_func(t_grid) + + # 3. Set up figure and plot reference + fig, ax_values, ax_errors = self._setup_compare_figure(figsize, attribute) + self._plot_reference_series(ax_values, t_grid, y_sim) + + # 4. Plot external sources and metrics + has_plots = self._plot_external_sources( + attribute=attribute, + t_grid=t_grid, + y_sim=y_sim, + ax_values=ax_values, + ax_errors=ax_errors, + ) + + if not has_plots: + warnings.warn(f"No external sources have data for variable '{attribute}'.") + plt.close(fig) + return + + # 5. Final formatting and save/show + self._finalize_compare_figure( + fig=fig, + ax_values=ax_values, + ax_errors=ax_errors, + attribute=attribute, + legend=legend, + filename=filename, + ) + + def compare_key_events(self): # pylint: disable=too-many-statements + """ + Compare critical flight events across all data sources. + + Returns + ------- + dict + Comparison dictionary with metrics as keys, containing RocketPy values + and errors for each external data source. + """ + # Initialize results dictionary + results = {} + + # Create time grid for interpolation + t_grid = np.linspace(0, self.flight.t_final, self.DEFAULT_GRID_POINTS) + altitude_cache = {} + for label, dataset in self.data_sources.items(): + if "altitude" in dataset or "z" in dataset: + alt_func = dataset.get("altitude", dataset.get("z")) + altitude_cache[label] = alt_func(t_grid) + # 1. Compare Apogee Altitude + rocketpy_apogee = self.flight.apogee + apogee_results = {"RocketPy": rocketpy_apogee} + + for label, dataset in self.data_sources.items(): + if label in altitude_cache: + altitudes = altitude_cache[label] + ext_apogee = np.max(altitudes) + error = ext_apogee - rocketpy_apogee + rel_error = ( + (error / rocketpy_apogee) * 100 if rocketpy_apogee != 0 else np.inf + ) + + apogee_results[label] = { + "value": ext_apogee, + "error": error, + "error_percentage": rel_error, + } + + results["Apogee Altitude (m)"] = apogee_results + + # 2. Compare Apogee Time + rocketpy_apogee_time = self.flight.apogee_time + apogee_time_results = {"RocketPy": rocketpy_apogee_time} + + for label, dataset in self.data_sources.items(): + if label in altitude_cache: + altitudes = altitude_cache[label] + ext_apogee_time = t_grid[np.argmax(altitudes)] + error = ext_apogee_time - rocketpy_apogee_time + rel_error = ( + (error / rocketpy_apogee_time) * 100 + if rocketpy_apogee_time != 0 + else np.inf + ) + + apogee_time_results[label] = { + "value": ext_apogee_time, + "error": error, + "error_percentage": rel_error, + } + + results["Apogee Time (s)"] = apogee_time_results + + # 3. Compare Maximum Velocity + rocketpy_max_vel = self.flight.max_speed + max_vel_results = {"RocketPy": rocketpy_max_vel} + + for label, dataset in self.data_sources.items(): + if "speed" in dataset: + speed_func = dataset["speed"] + speeds = speed_func(t_grid) + ext_max_vel = np.max(speeds) + error = ext_max_vel - rocketpy_max_vel + rel_error = ( + (error / rocketpy_max_vel) * 100 + if rocketpy_max_vel != 0 + else np.inf + ) + + max_vel_results[label] = { + "value": ext_max_vel, + "error": error, + "error_percentage": rel_error, + "approximation": False, + } + elif "vz" in dataset: + vz_func = dataset["vz"] + vz_vals = vz_func(t_grid) + ext_max_vel = np.max(np.abs(vz_vals)) + error = ext_max_vel - rocketpy_max_vel + rel_error = ( + (error / rocketpy_max_vel) * 100 + if rocketpy_max_vel != 0 + else np.inf + ) + + max_vel_results[label] = { + "value": ext_max_vel, + "error": error, + "error_percentage": rel_error, + "approximation": True, + } + + results["Max Velocity (m/s)"] = max_vel_results + + # 4. Compare Impact Velocity + rocketpy_impact_vel = self.flight.impact_velocity + impact_vel_results = {"RocketPy": rocketpy_impact_vel} + + for label, dataset in self.data_sources.items(): + if "speed" in dataset: + speed_func = dataset["speed"] + ext_impact_vel = abs(speed_func(t_grid[-1])) + error = ext_impact_vel - rocketpy_impact_vel + rel_error = ( + (error / rocketpy_impact_vel) * 100 + if rocketpy_impact_vel != 0 + else np.inf + ) + + impact_vel_results[label] = { + "value": ext_impact_vel, + "error": error, + "error_percentage": rel_error, + "approximation": False, + } + elif "vz" in dataset: + vz_func = dataset["vz"] + ext_impact_vel = abs(vz_func(t_grid[-1])) + error = ext_impact_vel - rocketpy_impact_vel + rel_error = ( + (error / rocketpy_impact_vel) * 100 + if rocketpy_impact_vel != 0 + else np.inf + ) + + impact_vel_results[label] = { + "value": ext_impact_vel, + "error": error, + "error_percentage": rel_error, + "approximation": True, + } + + results["Impact Velocity (m/s)"] = impact_vel_results + + return results + + def _format_key_events_table(self, results): + """ + Format key events results as a string table. + + Parameters + ---------- + results : dict + Results from compare_key_events() + + Returns + ------- + str + Formatted table string + """ + lines = [] + + # Get all source names + sources = [] + for metric_data in results.values(): + for key in metric_data.keys(): + if key != "RocketPy" and key not in sources: + sources.append(key) + + # Header + header = f"{'Metric':<25} {'RocketPy':<15}" + for source in sources: + header += ( + f" {source:<15} {source + ' Error':<15} {source + ' Error (%)':<15}" + ) + lines.append(header) + lines.append("-" * len(header)) + + # Rows + for metric, data in results.items(): + row = f"{metric:<25} {data['RocketPy']:<15.2f}" + + for source in sources: + if source in data: + value = data[source]["value"] + error = data[source]["error"] + error_pct = data[source]["error_percentage"] + approx = "*" if data[source].get("approximation", False) else "" + row += f" {value:<15.2f}{approx} {error:<15.2f} {error_pct:<15.2f}" + else: + row += f" {'N/A':<15} {'N/A':<15} {'N/A':<15}" + + lines.append(row) + + return "\n".join(lines) + + def summary(self): # pylint: disable=too-many-statements + """ + Print comprehensive comparison summary including key events and metrics. + + Returns + ------- + None + """ + print("\n" + "=" * 60) + print("FLIGHT COMPARISON SUMMARY") + print("=" * 60) + + print("\nRocketPy Simulation:") + print( + f" - Apogee: {self.flight.apogee:.2f} m at t={self.flight.apogee_time:.2f} s" + ) + print(f" - Max velocity: {self.flight.max_speed:.2f} m/s") + print(f" - Impact velocity: {self.flight.impact_velocity:.2f} m/s") + print(f" - Flight duration: {self.flight.t_final:.2f} s") + + print(f"\nExternal Data Sources: {list(self.data_sources.keys())}") + + try: + events_results = self.compare_key_events() + print("\n" + self._format_key_events_table(events_results)) + print( + "\nNote: Values marked with * are approximations " + "(e.g., speed from vz only)" + ) + except (KeyError, AttributeError, ValueError) as exc: + print( + "Could not generate key events table. " + "Ensure external data sources contain compatible variables " + "such as 'altitude' or 'z' for altitude and 'speed' or 'vz' " + "for velocity. Details: " + f"{exc}" + ) + + print("\n" + "=" * 60) + + def all(self, time_range=None, figsize=(10, 8), legend=True): + """ + Generate comparison plots for all common variables found in both + the RocketPy simulation and external data sources. + + Parameters + ---------- + time_range : tuple, optional + (start_time, end_time) to restrict comparisons. + If None, uses the full flight duration. + figsize : tuple, optional + standard matplotlib figsize to be used in the plots, by default + (10, 8), where the tuple means (width, height). + legend : bool, optional + Whether or not to show the legend, by default True + + Returns + ------- + None + """ + # Common variables to check for + common_vars = [ + "z", + "vz", + "ax", + "ay", + "az", + "altitude", + "speed", + "vx", + "vy", + "acceleration", + ] + + # Find which variables are available in both simulation and at least one source + available_vars = [] + for var in common_vars: + if hasattr(self.flight, var): + # Check if at least one source has this variable + for dataset in self.data_sources.values(): + if var in dataset: + available_vars.append(var) + break + + if not available_vars: + print("No common variables found for comparison.") + return + + print(f"\nGenerating comparison plots for: {', '.join(available_vars)}\n") + + # Generate a plot for each available variable + for var in available_vars: + self.compare(var, time_range=time_range, figsize=figsize, legend=legend) + + def trajectories_2d(self, plane="xz", figsize=(7, 7), legend=True, filename=None): # pylint: disable=too-many-statements + """ + Compare 2D flight trajectories between RocketPy simulation and external sources. + Coordinates are plotted in the inertial NED-like frame used by Flight: + x is East, y is North and z is Up. + + Parameters + ---------- + plane : str, optional + Plane to plot: 'xy', 'xz', or 'yz'. Default is 'xz'. + figsize : tuple, optional + standard matplotlib figsize to be used in the plots, by default + (7, 7), where the tuple means (width, height). + legend : bool, optional + Whether or not to show the legend, by default True + filename : str, optional + If a filename is provided, the plot will be saved to a file, by + default None. Image options are: png, pdf, ps, eps and svg. + + Returns + ------- + None + """ + if plane not in ["xy", "xz", "yz"]: + raise ValueError("plane must be 'xy', 'xz', or 'yz'") + + axis1, axis2 = plane[0], plane[1] + + # Check if Flight object has the required attributes + if not hasattr(self.flight, axis1) or not hasattr(self.flight, axis2): + warnings.warn(f"Flight object missing {axis1} or {axis2} attributes") + return + + # Create figure + fig = plt.figure(figsize=figsize) + fig.suptitle("Flight Trajectories Comparison", fontsize=16, y=0.95, x=0.5) + ax = plt.subplot(111) + + # Create time grid for evaluation + t_grid = np.linspace(0, self.flight.t_final, self.DEFAULT_GRID_POINTS) + + # Plot RocketPy trajectory + x_sim = getattr(self.flight, axis1)(t_grid) + y_sim = getattr(self.flight, axis2)(t_grid) + + ax.plot(x_sim, y_sim, label="RocketPy", linewidth=2, color="black", alpha=0.8) + + # Plot external sources + has_plots = False + for label, dataset in self.data_sources.items(): + if axis1 in dataset and axis2 in dataset: + has_plots = True + x_ext = dataset[axis1](t_grid) + y_ext = dataset[axis2](t_grid) + ax.plot(x_ext, y_ext, label=label, linestyle="--", linewidth=1.5) + + if not has_plots: + warnings.warn(f"No external sources have both {axis1} and {axis2} data.") + plt.close(fig) + return + + # Formatting + axis_labels = {"x": "X - East (m)", "y": "Y - North (m)", "z": "Z - Up (m)"} + ax.set_xlabel(axis_labels.get(axis1, f"{axis1} (m)")) + ax.set_ylabel(axis_labels.get(axis2, f"{axis2} (m)")) + ax.scatter(0, 0, color="black", s=10, marker="o") + ax.grid(True) + + # Add legend + if legend: + fig.legend() + + fig.tight_layout() + + show_or_save_fig(fig, filename) + if filename: + print(f"Plot saved to file: {filename}") diff --git a/tests/conftest.py b/tests/conftest.py index 12d07c334..456de43ca 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,6 @@ +import matplotlib import netCDF4 import numpy as np -import matplotlib import pytest # Configure matplotlib to use non-interactive backend for tests diff --git a/tests/integration/simulation/test_flight_comparator_workflow.py b/tests/integration/simulation/test_flight_comparator_workflow.py new file mode 100644 index 000000000..18c6daeab --- /dev/null +++ b/tests/integration/simulation/test_flight_comparator_workflow.py @@ -0,0 +1,92 @@ +"""Integration tests for FlightComparator. + +These tests exercise the full workflow of: +- running a Flight simulation, +- adding external data, +- generating comparison plots, +- summarizing key events. +""" + +import numpy as np + +from rocketpy.simulation.flight_comparator import FlightComparator +from rocketpy.simulation.flight_data_importer import FlightDataImporter + + +def test_full_workflow(flight_calisto): + """Test complete workflow: add data, compare, summary, plots. + + Parameters + ---------- + flight_calisto : rocketpy.Flight + Flight object to be tested. See conftest.py for more info. + """ + comparator = FlightComparator(flight_calisto) + + # Simulate external data with realistic errors + time_data = np.linspace(0, flight_calisto.t_final, 100) + + comparator.add_data( + "OpenRocket", + { + "altitude": ( + time_data, + flight_calisto.z(time_data) + np.random.normal(0, 5, 100), + ), + "vz": ( + time_data, + flight_calisto.vz(time_data) + np.random.normal(0, 1, 100), + ), + "x": (time_data, flight_calisto.x(time_data)), + "z": (time_data, flight_calisto.z(time_data)), + }, + ) + + # Test all methods - should run without error + comparator.summary() + comparator.compare("altitude") + results = comparator.compare_key_events() + comparator.trajectories_2d(plane="xz") + + # Verify results - compare_key_events now returns a dict + assert isinstance(results, dict) + assert len(results) >= 4 # At least 4 metrics + assert "Apogee Altitude (m)" in results + assert "Apogee Time (s)" in results + assert "Max Velocity (m/s)" in results + assert "Impact Velocity (m/s)" in results + + +def test_full_workflow_with_importer(flight_calisto, tmp_path): + """Full workflow using FlightDataImporter as external source.""" + comparator = FlightComparator(flight_calisto) + + # Create a tiny CSV with time,z,vz columns + csv_path = tmp_path / "flight_log.csv" + time_data = np.linspace(0, flight_calisto.t_final, 50) + z_data = flight_calisto.z(time_data) * 0.97 + vz_data = flight_calisto.vz(time_data) + + lines = ["time,z,vz\n"] + for t, z, vz in zip(time_data, z_data, vz_data): + lines.append(f"{t},{z},{vz}\n") + csv_path.write_text("".join(lines), encoding="utf-8") + + # Build importer + importer = FlightDataImporter( + paths=str(csv_path), + columns_map={"time": "time", "z": "z", "vz": "vz"}, + units=None, + ) + + # Use importer directly + comparator.add_data("Imported Log", importer) + + comparator.summary() + comparator.compare("z") + results = comparator.compare_key_events() + comparator.trajectories_2d(plane="xz") + + assert isinstance(results, dict) + assert "Apogee Altitude (m)" in results + assert "Impact Velocity (m/s)" in results diff --git a/tests/unit/simulation/test_flight_comparator.py b/tests/unit/simulation/test_flight_comparator.py new file mode 100644 index 000000000..cb708803c --- /dev/null +++ b/tests/unit/simulation/test_flight_comparator.py @@ -0,0 +1,561 @@ +"""Tests for the FlightComparator class. + +This module tests the FlightComparator class which compares RocketPy Flight +simulations against external data sources such as flight logs, OpenRocket +simulations, and RASAero simulations. +""" + +import os + +import numpy as np +import pytest + +from rocketpy import Function +from rocketpy.simulation.flight_comparator import FlightComparator +from rocketpy.simulation.flight_data_importer import FlightDataImporter + + +# Test FlightComparator initialization +def test_flight_comparator_init(flight_calisto): + """Test FlightComparator initialization. + + Parameters + ---------- + flight_calisto : rocketpy.Flight + Flight object to be tested. See conftest.py for more info. + """ + comparator = FlightComparator(flight_calisto) + + assert comparator.flight == flight_calisto + assert isinstance(comparator.data_sources, dict) + assert len(comparator.data_sources) == 0 + + +# Test add_data method with different input formats +def test_add_data_with_function(flight_calisto): + """Test adding external data using RocketPy Function objects. + + Parameters + ---------- + flight_calisto : rocketpy.Flight + Flight object to be tested. See conftest.py for more info. + """ + comparator = FlightComparator(flight_calisto) + + # Create mock Function object + time_data = np.linspace(0, flight_calisto.t_final, 100) + altitude_data = flight_calisto.z(time_data) + np.random.normal(0, 5, 100) + + alt_function = Function( + np.column_stack((time_data, altitude_data)), + inputs="Time (s)", + outputs="Altitude (m)", + interpolation="linear", + ) + + comparator.add_data("Mock Data", {"z": alt_function}) + + assert "Mock Data" in comparator.data_sources + assert "z" in comparator.data_sources["Mock Data"] + assert isinstance(comparator.data_sources["Mock Data"]["z"], Function) + + +def test_add_data_with_tuple(flight_calisto): + """Test adding external data using (time, data) tuples. + + Parameters + ---------- + flight_calisto : rocketpy.Flight + Flight object to be tested. See conftest.py for more info. + """ + comparator = FlightComparator(flight_calisto) + + # Create mock data as tuples + time_data = np.linspace(0, flight_calisto.t_final, 100) + altitude_data = flight_calisto.z(time_data) + np.random.normal(0, 5, 100) + velocity_data = flight_calisto.vz(time_data) + np.random.normal(0, 1, 100) + + comparator.add_data( + "External Simulator", + {"z": (time_data, altitude_data), "vz": (time_data, velocity_data)}, + ) + + assert "External Simulator" in comparator.data_sources + assert "z" in comparator.data_sources["External Simulator"] + assert "vz" in comparator.data_sources["External Simulator"] + assert isinstance(comparator.data_sources["External Simulator"]["z"], Function) + + +def test_add_data_overwrite_warning(flight_calisto): + """Test that adding data with same label raises warning. + + Parameters + ---------- + flight_calisto : rocketpy.Flight + Flight object to be tested. See conftest.py for more info. + """ + comparator = FlightComparator(flight_calisto) + + time_data = np.linspace(0, flight_calisto.t_final, 100) + altitude_data = flight_calisto.z(time_data) + + comparator.add_data("Test", {"z": (time_data, altitude_data)}) + + with pytest.warns(UserWarning, match="already exists"): + comparator.add_data("Test", {"z": (time_data, altitude_data)}) + + +def test_add_data_empty_dict_raises_error(flight_calisto): + """Test that empty data_dict raises ValueError. + + Parameters + ---------- + flight_calisto : rocketpy.Flight + Flight object to be tested. See conftest.py for more info. + """ + comparator = FlightComparator(flight_calisto) + + with pytest.raises(ValueError, match="cannot be empty"): + comparator.add_data("Empty", {}) + + +def test_add_data_invalid_format_warning(flight_calisto): + """Test that invalid data format raises warning and skips variable. + + Parameters + ---------- + flight_calisto : rocketpy.Flight + Flight object to be tested. See conftest.py for more info. + """ + comparator = FlightComparator(flight_calisto) + + with pytest.warns(UserWarning, match="Format not recognized"): + comparator.add_data("Invalid", {"z": "invalid_string", "vz": 12345}) + + # Should have added the source but with no valid variables + assert "Invalid" in comparator.data_sources + assert len(comparator.data_sources["Invalid"]) == 0 + + +# Test compare method +def test_compare_basic(flight_calisto): + """Test basic comparison plot generation. + + Parameters + ---------- + flight_calisto : rocketpy.Flight + Flight object to be tested. See conftest.py for more info. + """ + comparator = FlightComparator(flight_calisto) + + # Add mock external data + time_data = np.linspace(0, flight_calisto.t_final, 100) + altitude_data = flight_calisto.z(time_data) + 10 # 10m offset + + comparator.add_data("OpenRocket", {"z": (time_data, altitude_data)}) + + # This should generate plots and print metrics without error + comparator.compare("z") + + +def test_compare_with_time_range(flight_calisto): + """Test comparison with time_range parameter. + + Parameters + ---------- + flight_calisto : rocketpy.Flight + Flight object to be tested. See conftest.py for more info. + """ + comparator = FlightComparator(flight_calisto) + + time_data = np.linspace(0, flight_calisto.t_final, 100) + altitude_data = flight_calisto.z(time_data) + + comparator.add_data("Test", {"z": (time_data, altitude_data)}) + + # Compare only first 5 seconds + comparator.compare("z", time_range=(0, 5)) + + +def test_compare_missing_attribute_warning(flight_calisto): + """Test that comparing non-existent attribute raises warning. + + Parameters + ---------- + flight_calisto : rocketpy.Flight + Flight object to be tested. See conftest.py for more info. + """ + comparator = FlightComparator(flight_calisto) + + with pytest.warns(UserWarning, match="not found in the RocketPy Flight"): + comparator.compare("nonexistent_attribute") + + +def test_compare_no_external_data_warning(flight_calisto): + """Test warning when no external sources have the requested variable. + + Parameters + ---------- + flight_calisto : rocketpy.Flight + Flight object to be tested. See conftest.py for more info. + """ + comparator = FlightComparator(flight_calisto) + + # Add data but without 'z' variable + time_data = np.linspace(0, flight_calisto.t_final, 100) + velocity_data = flight_calisto.vz(time_data) + + comparator.add_data("Test", {"vz": (time_data, velocity_data)}) + + with pytest.warns(UserWarning, match="No external sources have data"): + comparator.compare("z") + + +@pytest.mark.parametrize("filename", [None, "test_comparison.png"]) +def test_compare_save_file(flight_calisto, filename): + """Test comparison plot saving functionality. + + Parameters + ---------- + flight_calisto : rocketpy.Flight + Flight object to be tested. See conftest.py for more info. + filename : str or None + Filename to save plot to, or None to show plot. + """ + comparator = FlightComparator(flight_calisto) + + time_data = np.linspace(0, flight_calisto.t_final, 100) + altitude_data = flight_calisto.z(time_data) + + comparator.add_data("Test", {"z": (time_data, altitude_data)}) + comparator.compare("z", filename=filename) + + if filename: + assert os.path.exists(filename) + os.remove(filename) + + +# Test compare_key_events method +def test_compare_key_events_basic(flight_calisto): + """Test compare_key_events returns proper dict. + + Parameters + ---------- + flight_calisto : rocketpy.Flight + Flight object to be tested. See conftest.py for more info. + """ + comparator = FlightComparator(flight_calisto) + + # Add mock data with slight offset + time_data = np.linspace(0, flight_calisto.t_final, 100) + altitude_data = flight_calisto.z(time_data) + 5 + velocity_data = flight_calisto.vz(time_data) + + comparator.add_data( + "Simulator", + {"altitude": (time_data, altitude_data), "vz": (time_data, velocity_data)}, + ) + + results = comparator.compare_key_events() + + # Check it's a dict + assert isinstance(results, dict) + + # Check metrics exist + assert "Apogee Altitude (m)" in results + assert "Apogee Time (s)" in results + assert "Max Velocity (m/s)" in results + assert "Impact Velocity (m/s)" in results + apogee_alt = results["Apogee Altitude (m)"] + # Check RocketPy values exist + assert "RocketPy" in apogee_alt + + # Check external source data exists + assert "Simulator" in apogee_alt + assert set(apogee_alt["Simulator"].keys()) == { + "value", + "error", + "error_percentage", + } + + +def test_compare_key_events_multiple_sources(flight_calisto): + """Test compare_key_events with multiple data sources. + + Parameters + ---------- + flight_calisto : rocketpy.Flight + Flight object to be tested. See conftest.py for more info. + """ + comparator = FlightComparator(flight_calisto) + + time_data = np.linspace(0, flight_calisto.t_final, 100) + + # Add two different simulators + comparator.add_data( + "OpenRocket", {"z": (time_data, flight_calisto.z(time_data) + 10)} + ) + + comparator.add_data("RASAero", {"z": (time_data, flight_calisto.z(time_data) - 5)}) + + results = comparator.compare_key_events() + apogee_alt = results["Apogee Altitude (m)"] + # Check both sources are in the results + assert "OpenRocket" in apogee_alt + assert "RASAero" in apogee_alt + + # Check data structure for each source + for src in ("OpenRocket", "RASAero"): + assert "value" in apogee_alt[src] + assert "error" in apogee_alt[src] + assert "error_percentage" in apogee_alt[src] + + +# Test summary method +def test_summary(flight_calisto, capsys): + """Test summary method prints correct information. + + Parameters + ---------- + flight_calisto : rocketpy.Flight + Flight object to be tested. See conftest.py for more info. + capsys : + Pytest fixture to capture stdout/stderr. + """ + comparator = FlightComparator(flight_calisto) + + time_data = np.linspace(0, flight_calisto.t_final, 100) + comparator.add_data("Test", {"z": (time_data, flight_calisto.z(time_data))}) + + comparator.summary() + + captured = capsys.readouterr() + assert "FLIGHT COMPARISON SUMMARY" in captured.out + assert "RocketPy Simulation:" in captured.out + assert "External Data Sources:" in captured.out + assert "Test" in captured.out + + +# Test all method +def test_all_plots(flight_calisto): + """Test that all() method generates plots for common variables. + + Parameters + ---------- + flight_calisto : rocketpy.Flight + Flight object to be tested. See conftest.py for more info. + """ + comparator = FlightComparator(flight_calisto) + + time_data = np.linspace(0, flight_calisto.t_final, 100) + + # Add multiple variables + comparator.add_data( + "Simulator", + { + "z": (time_data, flight_calisto.z(time_data)), + "vz": (time_data, flight_calisto.vz(time_data)), + }, + ) + + # This should run without error + comparator.all() + + +def test_all_no_common_variables(flight_calisto, capsys): + """Test all() when no common variables exist. + + Parameters + ---------- + flight_calisto : rocketpy.Flight + Flight object to be tested. See conftest.py for more info. + capsys : + Pytest fixture to capture stdout/stderr. + """ + comparator = FlightComparator(flight_calisto) + + # Don't add any data + comparator.all() + + captured = capsys.readouterr() + assert "No common variables found" in captured.out + + +# Test trajectories_2d method +@pytest.mark.parametrize("plane", ["xy", "xz", "yz"]) +def test_trajectories_2d_planes(flight_calisto, plane): + """Test 2D trajectory plotting in different planes. + + Parameters + ---------- + flight_calisto : rocketpy.Flight + Flight object to be tested. See conftest.py for more info. + plane : str + Plane to plot trajectory in ('xy', 'xz', or 'yz'). + """ + comparator = FlightComparator(flight_calisto) + + time_data = np.linspace(0, flight_calisto.t_final, 100) + + # Add trajectory data + comparator.add_data( + "External", + { + "x": (time_data, flight_calisto.x(time_data)), + "y": (time_data, flight_calisto.y(time_data)), + "z": (time_data, flight_calisto.z(time_data)), + }, + ) + + # Should run without error + comparator.trajectories_2d(plane=plane) + + +def test_trajectories_2d_invalid_plane(flight_calisto): + """Test that invalid plane raises ValueError. + + Parameters + ---------- + flight_calisto : rocketpy.Flight + Flight object to be tested. See conftest.py for more info. + """ + comparator = FlightComparator(flight_calisto) + + with pytest.raises(ValueError, match="plane must be"): + comparator.trajectories_2d(plane="invalid") + + +def test_trajectories_2d_missing_data_warning(flight_calisto): + """Test warning when external data missing required axes. + + Parameters + ---------- + flight_calisto : rocketpy.Flight + Flight object to be tested. See conftest.py for more info. + """ + comparator = FlightComparator(flight_calisto) + + time_data = np.linspace(0, flight_calisto.t_final, 100) + + # Add only 'z', missing 'x', should give a warning + comparator.add_data("Incomplete", {"z": (time_data, flight_calisto.z(time_data))}) + + with pytest.warns(UserWarning, match="No external sources have both"): + comparator.trajectories_2d(plane="xz") + + +@pytest.mark.parametrize("filename", [None, "test_trajectory.png"]) +def test_trajectories_2d_save(flight_calisto, filename): + """Test trajectory plot saving. + + Parameters + ---------- + flight_calisto : rocketpy.Flight + Flight object to be tested. See conftest.py for more info. + filename : str or None + Filename to save plot to, or None to show plot. + """ + comparator = FlightComparator(flight_calisto) + + time_data = np.linspace(0, flight_calisto.t_final, 100) + + comparator.add_data( + "Test", + { + "x": (time_data, flight_calisto.x(time_data)), + "z": (time_data, flight_calisto.z(time_data)), + }, + ) + + comparator.trajectories_2d(plane="xz", filename=filename) + + if filename: + assert os.path.exists(filename) + os.remove(filename) + + +def test_add_data_with_flight_object(flight_calisto): + """Test adding external data by passing a Flight instance directly. + + Parameters + ---------- + flight_calisto : rocketpy.Flight + Flight object to be tested. See conftest.py for more info. + """ + comparator = FlightComparator(flight_calisto) + + # Use the same Flight as an external "simulator" just to exercise the path + comparator.add_data("Baseline Flight", flight_calisto) + + assert "Baseline Flight" in comparator.data_sources + source = comparator.data_sources["Baseline Flight"] + + # Standard variables should be registered when present + assert "z" in source + assert "vz" in source + assert "altitude" in source # alias to z + assert isinstance(source["z"], Function) + + +def test_compare_with_time_range_valid(flight_calisto): + """Test compare() with a valid time_range.""" + comparator = FlightComparator(flight_calisto) + + t = np.linspace(0, flight_calisto.t_final, 100) + comparator.add_data("Sim", {"z": (t, flight_calisto.z(t))}) + + # Should run without error for a proper sub-range + comparator.compare("z", time_range=(0.1, flight_calisto.t_final - 0.1)) + + +@pytest.mark.parametrize( + "time_range,exc_type", + [ + (("a", "b"), TypeError), + ((1.0, 1.0), ValueError), + ((-0.1, 1.0), ValueError), + ((0.0, 1e9), ValueError), + ("not_a_tuple", TypeError), + ], +) +def test_compare_with_invalid_time_range(flight_calisto, time_range, exc_type): + """Test that invalid time_range raises appropriate errors.""" + comparator = FlightComparator(flight_calisto) + + t = np.linspace(0, flight_calisto.t_final, 100) + comparator.add_data("Sim", {"z": (t, flight_calisto.z(t))}) + + with pytest.raises(exc_type): + comparator.compare("z", time_range=time_range) + + +def test_add_data_with_flight_data_importer(flight_calisto, tmp_path): + """Test adding external data by passing a FlightDataImporter instance.""" + comparator = FlightComparator(flight_calisto) + + # Minimal CSV with time and z + csv_path = tmp_path / "importer_log.csv" + time_data = np.linspace(0, flight_calisto.t_final, 20) + z_data = flight_calisto.z(time_data) + 3.0 + + lines = ["time,z\n"] + for t, z in zip(time_data, z_data): + lines.append(f"{t},{z}\n") + csv_path.write_text("".join(lines), encoding="utf-8") + + importer = FlightDataImporter( + paths=str(csv_path), + columns_map={"time": "time", "z": "z"}, + units=None, + ) + + comparator.add_data("Imported", importer) + + assert "Imported" in comparator.data_sources + source = comparator.data_sources["Imported"] + + # z should be registered, altitude alias should exist + assert "z" in source + assert isinstance(source["z"], Function) + assert "altitude" in source + assert source["altitude"] is source["z"]