diff --git a/im2deep/calibrate.py b/im2deep/calibrate.py index d4dd8e1..fd8c551 100644 --- a/im2deep/calibrate.py +++ b/im2deep/calibrate.py @@ -24,13 +24,13 @@ ... ) """ +from __future__ import annotations + import logging -from typing import Dict, Union, Optional +from typing import cast import numpy as np import pandas as pd -from numpy import ndarray -from psm_utils.peptidoform import Peptidoform from im2deep._exceptions import CalibrationError @@ -40,27 +40,28 @@ def _validate_calibration_inputs( cal_df: pd.DataFrame, reference_dataset: pd.DataFrame, - required_cal_columns: Optional[list] = None, - required_ref_columns: Optional[list] = None, + required_cal_columns: list | None = None, + required_ref_columns: list | None = None, ) -> None: """ Validate input dataframes for calibration functions. Parameters ---------- - cal_df : pd.DataFrame + cal_df Calibration dataset - reference_dataset : pd.DataFrame + reference_dataset Reference dataset - required_cal_columns : list, optional + required_cal_columns Required columns for calibration dataset - required_ref_columns : list, optional + required_ref_columns Required columns for reference dataset Raises ------ CalibrationError If validation fails + """ if cal_df.empty: raise CalibrationError("Calibration dataset is empty") @@ -91,11 +92,11 @@ def get_ccs_shift( Parameters ---------- - cal_df : pd.DataFrame + cal_df PSMs with CCS values. Must contain columns: 'sequence', 'charge', 'ccs_observed' - reference_dataset : pd.DataFrame + reference_dataset Reference dataset with CCS values. Must contain columns: 'peptidoform', 'charge', 'CCS' - use_charge_state : int, default 2 + use_charge_state Charge state to use for CCS shift calculation. Should be in range [2,4]. Returns @@ -120,6 +121,7 @@ def get_ccs_shift( -------- >>> shift = get_ccs_shift(calibration_df, reference_df, use_charge_state=2) >>> print(f"CCS shift factor: {shift:.2f} Ų") + """ # Validate inputs _validate_calibration_inputs( @@ -187,7 +189,7 @@ def get_ccs_shift( def get_ccs_shift_per_charge( cal_df: pd.DataFrame, reference_dataset: pd.DataFrame -) -> Dict[int, float]: +) -> dict[int, float]: """ Calculate CCS shift factors per charge state. @@ -197,9 +199,9 @@ def get_ccs_shift_per_charge( Parameters ---------- - cal_df : pd.DataFrame + cal_df PSMs with CCS values. Must contain columns: 'sequence', 'charge', 'ccs_observed' - reference_dataset : pd.DataFrame + reference_dataset Reference dataset with CCS values. Must contain columns: 'peptidoform', 'charge', 'CCS' Returns @@ -228,6 +230,7 @@ def get_ccs_shift_per_charge( >>> shifts = get_ccs_shift_per_charge(calibration_df, reference_df) >>> print(shifts) {2: 5.2, 3: 3.8, 4: 2.1} + """ # Validate inputs _validate_calibration_inputs( @@ -277,9 +280,7 @@ def get_ccs_shift_per_charge( # Check for unreasonably large shifts large_shifts = {k: v for k, v in shift_dict.items() if abs(v) > 100} if large_shifts: - LOGGER.warning( - f"Large CCS shifts detected: {large_shifts}. " "Please verify data quality." - ) + LOGGER.warning(f"Large CCS shifts detected: {large_shifts}. Please verify data quality.") return shift_dict @@ -288,8 +289,8 @@ def calculate_ccs_shift( cal_df: pd.DataFrame, reference_dataset: pd.DataFrame, per_charge: bool = True, - use_charge_state: Optional[int] = None, -) -> Union[float, Dict[int, float]]: + use_charge_state: int | None = None, +) -> float | dict[int, float]: """ Calculate CCS shift factors with validation and filtering. @@ -299,14 +300,14 @@ def calculate_ccs_shift( Parameters ---------- - cal_df : pd.DataFrame + cal_df PSMs with CCS values. Must contain columns: 'sequence', 'charge', 'ccs_observed' - reference_dataset : pd.DataFrame + reference_dataset Reference dataset with CCS values. Must contain columns: 'peptidoform', 'charge', 'CCS' - per_charge : bool, default True + per_charge Whether to calculate shift factors per charge state. If False, calculates a single global shift factor using the specified charge state. - use_charge_state : int, optional + use_charge_state Charge state to use for global shift calculation when per_charge=False. Should be in range [2,4]. Default is 2 if not specified. @@ -334,6 +335,7 @@ def calculate_ccs_shift( >>> >>> # Global calibration using charge 2 >>> shift = calculate_ccs_shift(cal_df, ref_df, per_charge=False, use_charge_state=2) + """ # Validate inputs _validate_calibration_inputs(cal_df, reference_dataset) @@ -378,7 +380,7 @@ def linear_calibration( calibration_dataset: pd.DataFrame, reference_dataset: pd.DataFrame, per_charge: bool = True, - use_charge_state: Optional[int] = None, + use_charge_state: int | None = None, ) -> pd.DataFrame: """ Calibrate CCS predictions using linear calibration. @@ -389,20 +391,20 @@ def linear_calibration( Parameters ---------- - preds_df : pd.DataFrame + preds_df PSMs with CCS predictions. Must contain 'predicted_ccs' column. Will be modified to include 'charge' and 'shift' columns. - calibration_dataset : pd.DataFrame + calibration_dataset Calibration dataset with observed CCS values. Must contain columns: 'peptidoform', 'ccs_observed' - reference_dataset : pd.DataFrame + reference_dataset Reference dataset with CCS values. Must contain columns: 'peptidoform', 'CCS' - per_charge : bool, default True + per_charge Whether to calculate and apply shift factors per charge state. If True, uses charge-specific calibration with fallback to global shift. If False, applies single global shift factor. - use_charge_state : int, optional + use_charge_state Charge state to use for global shift calculation when per_charge=False. Default is 2 if not specified. @@ -446,6 +448,7 @@ def linear_calibration( ... use_charge_state=2 ... ) """ + LOGGER.info("Calibrating CCS values using linear calibration...") # Validate input dataframes @@ -478,7 +481,7 @@ def linear_calibration( ) except (AttributeError, ValueError, IndexError) as e: - raise CalibrationError(f"Error parsing peptidoform data: {e}") + raise CalibrationError(f"Error parsing peptidoform data: {e}") from e if per_charge: LOGGER.info("Calculating general shift factor for fallback...") @@ -489,6 +492,8 @@ def linear_calibration( per_charge=False, use_charge_state=use_charge_state or 2, ) + # per_charge=False returns float + general_shift = cast(float, general_shift) except CalibrationError as e: LOGGER.warning( f"Could not calculate general shift factor: {e}. Using 0.0 as fallback." @@ -499,6 +504,8 @@ def linear_calibration( shift_factor_dict = calculate_ccs_shift( calibration_dataset, reference_dataset, per_charge=True ) + # per_charge=True returns dict[int, float] + shift_factor_dict = cast(dict[int, float], shift_factor_dict) # Add charge information to predictions if not present if "charge" not in preds_df.columns: @@ -525,6 +532,8 @@ def linear_calibration( per_charge=False, use_charge_state=use_charge_state or 2, ) + # per_charge=False returns floats + shift_factor = cast(float, shift_factor) preds_df["predicted_ccs"] += shift_factor preds_df["shift"] = shift_factor LOGGER.info(f"Applied global shift factor: {shift_factor:.3f}") diff --git a/im2deep/im2deep.py b/im2deep/im2deep.py index 4711505..cb3215d 100644 --- a/im2deep/im2deep.py +++ b/im2deep/im2deep.py @@ -28,32 +28,34 @@ >>> predictions = predict_ccs(psm_list, calibration_data, multi=True) """ +from __future__ import annotations + import logging +from os import PathLike from pathlib import Path -import sys -from typing import Optional, Union, List +from typing import cast import pandas as pd from deeplc import DeepLC from psm_utils.psm_list import PSMList +from im2deep._exceptions import IM2DeepError from im2deep.calibrate import linear_calibration from im2deep.utils import ccs2im -from im2deep._exceptions import IM2DeepError LOGGER = logging.getLogger(__name__) REFERENCE_DATASET_PATH = Path(__file__).parent / "reference_data" / "reference_ccs.zip" -def _validate_inputs(psm_list_pred: PSMList, output_file: Optional[str] = None) -> None: +def _validate_inputs(psm_list_pred: PSMList, output_file: str | PathLike | None = None) -> None: """ Validate input parameters for prediction. Parameters ---------- - psm_list_pred : PSMList + psm_list_pred PSM list for prediction - output_file : str, optional + output_file Output file path Raises @@ -67,24 +69,24 @@ def _validate_inputs(psm_list_pred: PSMList, output_file: Optional[str] = None) if len(psm_list_pred) == 0: raise IM2DeepError("PSM list for prediction is empty") - if output_file and not isinstance(output_file, (str, Path)): - raise IM2DeepError("output_file must be a string or Path object") + if output_file and not isinstance(output_file, (str, PathLike)): + raise IM2DeepError("output_file must be a string or PathLike object") -def _get_model_paths(model_name: str, use_single_model: bool) -> List[Path]: +def _get_model_paths(model_name: str, use_single_model: bool) -> list[Path]: """ Get model file paths based on model name and configuration. Parameters ---------- - model_name : str + model_name Name of the model ('tims') - use_single_model : bool + use_single_model Whether to use single model or ensemble Returns ------- - List[Path] + list[Path] List of model file paths Raises @@ -120,9 +122,9 @@ def _get_model_paths(model_name: str, use_single_model: bool) -> List[Path]: def _write_output_file( - output_file: str, + output_file: str | PathLike, psm_list_pred_df: pd.DataFrame, - pred_df: Optional[pd.DataFrame] = None, + pred_df: pd.DataFrame | None = None, ion_mobility: bool = False, multi: bool = False, ) -> None: @@ -131,19 +133,24 @@ def _write_output_file( Parameters ---------- - output_file : str + output_file Path to output file - psm_list_pred_df : pd.DataFrame + psm_list_pred_df DataFrame with predictions - pred_df : pd.DataFrame, optional + pred_df Multi-conformer predictions - ion_mobility : bool + ion_mobility Whether to output ion mobility instead of CCS - multi : bool + multi Whether multi-conformer predictions are included """ + if multi and pred_df is None: + raise IM2DeepError("Multi-conformer predictions requested but pred_df is None") + else: + pred_df = cast(pd.DataFrame, pred_df) try: with open(output_file, "w", encoding="utf-8") as f: + # TODO: Consider using dictwriter or Pandas to_csv if ion_mobility: if multi: f.write( @@ -155,6 +162,7 @@ def _write_output_file( psm_list_pred_df["predicted_im"], psm_list_pred_df["predicted_im_multi_1"], psm_list_pred_df["predicted_im_multi_2"], + strict=True, ): f.write(f"{peptidoform},{charge},{IM_single},{IM_multi_1},{IM_multi_2}\n") else: @@ -163,6 +171,7 @@ def _write_output_file( psm_list_pred_df["peptidoform"], psm_list_pred_df["charge"], psm_list_pred_df["predicted_im"], + strict=True, ): f.write(f"{peptidoform},{charge},{IM}\n") else: @@ -176,6 +185,7 @@ def _write_output_file( psm_list_pred_df["predicted_ccs"], pred_df["predicted_ccs_multi_1"], pred_df["predicted_ccs_multi_2"], + strict=True, ): f.write( f"{peptidoform},{charge},{CCS_single},{CCS_multi_1},{CCS_multi_2}\n" @@ -186,31 +196,32 @@ def _write_output_file( psm_list_pred_df["peptidoform"], psm_list_pred_df["charge"], psm_list_pred_df["predicted_ccs"], + strict=True, ): f.write(f"{peptidoform},{charge},{CCS}\n") LOGGER.info(f"Results written to: {output_file}") - except IOError as e: - raise IM2DeepError(f"Failed to write output file {output_file}: {e}") + except OSError as e: + raise IM2DeepError(f"Failed to write output file {output_file}: {e}") from e def predict_ccs( psm_list_pred: PSMList, - psm_list_cal: Optional[Union[PSMList, pd.DataFrame]] = None, - file_reference: Optional[Union[str, Path]] = None, - output_file: Optional[Union[str, Path]] = None, + psm_list_cal: PSMList | pd.DataFrame | None = None, + file_reference: PathLike | None = None, + output_file: PathLike | None = None, model_name: str = "tims", multi: bool = False, calibrate_per_charge: bool = True, use_charge_state: int = 2, use_single_model: bool = True, - n_jobs: Optional[int] = None, + n_jobs: int | None = None, write_output: bool = False, ion_mobility: bool = False, - pred_df: Optional[pd.DataFrame] = None, - cal_df: Optional[pd.DataFrame] = None, -) -> Union[pd.Series, pd.DataFrame]: + pred_df: pd.DataFrame | None = None, + cal_df: pd.DataFrame | None = None, +) -> pd.Series | pd.DataFrame: """ Predict CCS values for peptides using IM2Deep models. @@ -219,42 +230,42 @@ def predict_ccs( Parameters ---------- - psm_list_pred : PSMList + psm_list_pred PSM list containing peptides for CCS prediction. Each PSM should contain a valid peptidoform with sequence and modifications. - psm_list_cal : PSMList or pd.DataFrame, optional + psm_list_cal PSM list or DataFrame for calibration with observed CCS values. If PSMList: CCS values should be in metadata with key "CCS". If DataFrame: should have "ccs_observed" column. Required for calibration. Default is None (no calibration). - file_reference : str or Path, optional + file_reference Path to reference dataset file for calibration. Default uses built-in reference dataset. - output_file : str or Path, optional + output_file Path to write output predictions. If None, no file is written. - model_name : str, default "tims" + model_name Name of the model to use. Currently only "tims" is supported. - multi : bool, default False + multi Whether to include multi-conformer predictions. Requires optional dependencies (torch, im2deeptrainer). - calibrate_per_charge : bool, default True + calibrate_per_charge Whether to perform calibration per charge state. If False, uses global calibration with specified charge state. - use_charge_state : int, default 2 + use_charge_state Charge state to use for global calibration when calibrate_per_charge=False. Should be in range [2,4] for best results. - use_single_model : bool, default True + use_single_model Whether to use a single model (faster) or ensemble of models (potentially more accurate). Single model recommended for most applications. - n_jobs : int, optional + n_jobs Number of parallel jobs for model prediction. If None, uses all available CPUs. - write_output : bool, default False + write_output Whether to write predictions to output file. - ion_mobility : bool, default False + ion_mobility Whether to output ion mobility (1/K0) instead of CCS values. - pred_df : pd.DataFrame, optional + pred_df Pre-computed prediction DataFrame (used internally). - cal_df : pd.DataFrame, optional + cal_df Pre-computed calibration DataFrame (used internally). Returns @@ -279,8 +290,8 @@ def predict_ccs( 5. Convert to ion mobility if requested 6. Write output file if requested - Calibration is highly recommended for accurate predictions and requires - a set of peptides with known CCS values that overlap with the reference dataset. + Calibration is highly recommended for accurate predictions and requires a set of peptides with + known CCS values that overlap with the reference dataset. Examples -------- @@ -320,7 +331,7 @@ def predict_ccs( reference_dataset = pd.read_csv(file_reference) LOGGER.debug(f"Loaded reference dataset with {len(reference_dataset)} entries") except Exception as e: - raise IM2DeepError(f"Failed to load reference dataset from {file_reference}: {e}") + raise IM2DeepError(f"Failed to load reference dataset from {file_reference}: {e}") from e if reference_dataset.empty: raise IM2DeepError("Reference dataset is empty") @@ -329,7 +340,7 @@ def predict_ccs( try: path_model_list = _get_model_paths(model_name, use_single_model) except Exception as e: - raise IM2DeepError(f"Failed to load models: {e}") + raise IM2DeepError(f"Failed to load models: {e}") from e # Initialize DeepLC for CCS prediction try: @@ -338,7 +349,7 @@ def predict_ccs( preds = dlc.make_preds(psm_list=psm_list_pred, calibrate=False) LOGGER.info(f"CCS values predicted for {len(preds)} peptides.") except Exception as e: - raise IM2DeepError(f"CCS prediction failed: {e}") + raise IM2DeepError(f"CCS prediction failed: {e}") from e if len(preds) == 0: raise IM2DeepError("No predictions generated") @@ -351,14 +362,14 @@ def predict_ccs( lambda x: x.precursor_charge ) except Exception as e: - raise IM2DeepError(f"Failed to process predictions: {e}") + raise IM2DeepError(f"Failed to process predictions: {e}") from e # Apply calibration if calibration data provided pred_df = None if psm_list_cal is not None: try: LOGGER.info("Applying calibration...") - + # Handle both PSMList and DataFrame input if isinstance(psm_list_cal, pd.DataFrame): # Input is already a DataFrame with ccs_observed column @@ -375,7 +386,7 @@ def predict_ccs( ccs_values.append(float(psm.metadata["CCS"])) else: ccs_values.append(None) - + # Convert to DataFrame and add CCS values psm_list_cal_df = psm_list_cal.to_dataframe() psm_list_cal_df["ccs_observed"] = ccs_values @@ -413,37 +424,40 @@ def predict_ccs( use_charge_state, ) LOGGER.info("Multiconformational predictions completed.") - except ImportError: + except ImportError as e: raise IM2DeepError( "Multi-conformer prediction requires optional dependencies. " "Please install with: pip install 'im2deep[er]'" - ) + ) from e except Exception as e: - raise IM2DeepError(f"Multi-conformer prediction failed: {e}") + raise IM2DeepError(f"Multi-conformer prediction failed: {e}") from e # Convert to ion mobility if requested if ion_mobility: try: + mz_array = psm_list_pred_df["peptidoform"].apply(lambda x: x.theoretical_mz).to_numpy() + charge_array = psm_list_pred_df["charge"].to_numpy() + psm_list_pred_df["predicted_im"] = ccs2im( - psm_list_pred_df["predicted_ccs"], - psm_list_pred_df["peptidoform"].apply(lambda x: x.theoretical_mz), - psm_list_pred_df["charge"], + psm_list_pred_df["predicted_ccs"].to_numpy(), + mz_array, + charge_array, ) if multi and pred_df is not None: psm_list_pred_df["predicted_im_multi_1"] = ccs2im( - pred_df["predicted_ccs_multi_1"], - psm_list_pred_df["peptidoform"].apply(lambda x: x.theoretical_mz), - psm_list_pred_df["charge"], + pred_df["predicted_ccs_multi_1"].to_numpy(), + mz_array, + charge_array, ) psm_list_pred_df["predicted_im_multi_2"] = ccs2im( - pred_df["predicted_ccs_multi_2"], - psm_list_pred_df["peptidoform"].apply(lambda x: x.theoretical_mz), - psm_list_pred_df["charge"], + pred_df["predicted_ccs_multi_2"].to_numpy(), + mz_array, + charge_array, ) except Exception as e: - raise IM2DeepError(f"Ion mobility conversion failed: {e}") + raise IM2DeepError(f"Ion mobility conversion failed: {e}") from e # Write output file if requested if write_output and output_file: diff --git a/im2deep/predict_multi.py b/im2deep/predict_multi.py index 2b8fdfe..3e346a4 100644 --- a/im2deep/predict_multi.py +++ b/im2deep/predict_multi.py @@ -28,17 +28,19 @@ pip install 'im2deep[er]' """ -from pathlib import Path +from __future__ import annotations + import logging -from typing import Dict, Optional, Union +from pathlib import Path +from typing import cast + +import numpy as np +import pandas as pd try: import torch - import pandas as pd - import numpy as np - + from im2deeptrainer.extract_data import _get_matrices # TODO: Should be public function? from im2deeptrainer.model import IM2DeepMultiTransfer - from im2deeptrainer.extract_data import _get_matrices from im2deeptrainer.utils import FlexibleLossSorted TORCH_AVAILABLE = True @@ -50,15 +52,12 @@ FlexibleLossSorted = None TORCH_AVAILABLE = False - import pandas as pd - import numpy as np - +from im2deep._exceptions import CalibrationError, IM2DeepError from im2deep.utils import multi_config -from im2deep._exceptions import IM2DeepError, CalibrationError LOGGER = logging.getLogger(__name__) -MULTI_CKPT_PATH = Path(__file__).parent / "models" / "TIMS_multi" / "multi_output.ckpt" -REFERENCE_DATASET_PATH = Path(__file__).parent / "reference_data" / "multi_reference_ccs.gz" +MULTI_CKPT_PATH: Path = Path(__file__).parent / "models" / "TIMS_multi" / "multi_output.ckpt" +REFERENCE_DATASET_PATH: Path = Path(__file__).parent / "reference_data" / "multi_reference_ccs.gz" def _validate_multi_inputs(df_cal: pd.DataFrame, reference_dataset: pd.DataFrame) -> None: @@ -67,9 +66,9 @@ def _validate_multi_inputs(df_cal: pd.DataFrame, reference_dataset: pd.DataFrame Parameters ---------- - df_cal : pd.DataFrame + df_cal Calibration dataset - reference_dataset : pd.DataFrame + reference_dataset Reference dataset Raises @@ -106,13 +105,13 @@ def get_ccs_shift_multi( Parameters ---------- - df_cal : pd.DataFrame + df_cal Calibration peptides with observed CCS values. Must contain columns: 'seq', 'modifications', 'charge', 'ccs_observed' - reference_dataset : pd.DataFrame + reference_dataset Reference dataset with known CCS values. Must contain columns: 'seq', 'modifications', 'charge', 'CCS' - use_charge_state : int, default 2 + use_charge_state Charge state to use for CCS shift calculation. Recommended range [2,4]. Returns @@ -144,7 +143,7 @@ def get_ccs_shift_multi( raise CalibrationError(f"Invalid charge state {use_charge_state}") LOGGER.debug( - f"Using charge state {use_charge_state} for calibration of multiconformer predictions." + f"Using charge state {use_charge_state} for calibration of multi-conformer predictions." ) # Filter by charge state @@ -195,7 +194,7 @@ def get_ccs_shift_multi( def get_ccs_shift_per_charge_multi( df_cal: pd.DataFrame, reference_dataset: pd.DataFrame -) -> Dict[int, float]: +) -> dict[int, float]: """ Calculate CCS shift factors per charge state for multi-conformer predictions. @@ -205,10 +204,10 @@ def get_ccs_shift_per_charge_multi( Parameters ---------- - df_cal : pd.DataFrame + df_cal Calibration peptides with observed CCS values. Must contain columns: 'seq', 'modifications', 'charge', 'ccs_observed' - reference_dataset : pd.DataFrame + reference_dataset Reference dataset with known CCS values. Must contain columns: 'seq', 'modifications', 'charge', 'CCS' @@ -281,8 +280,8 @@ def calculate_ccs_shift_multi( df_cal: pd.DataFrame, reference_dataset: pd.DataFrame, per_charge: bool = True, - use_charge_state: Optional[int] = None, -) -> Union[float, Dict[int, float]]: + use_charge_state: int | None = None, +) -> float | dict[int, float]: """ Calculate CCS shift factors for multi-conformer predictions with validation. @@ -292,19 +291,19 @@ def calculate_ccs_shift_multi( Parameters ---------- - df_cal : pd.DataFrame + df_cal Calibration peptides with observed CCS values. - reference_dataset : pd.DataFrame + reference_dataset Reference dataset with known CCS values. - per_charge : bool, default True + per_charge Whether to calculate shift factors per charge state. - use_charge_state : int, optional + use_charge_state Charge state for global calibration when per_charge=False. Default is 2 if not specified. Returns ------- - Union[float, Dict[int, float]] + float | dict[int, float] If per_charge=True: Dictionary of shift factors per charge If per_charge=False: Single global shift factor @@ -367,7 +366,7 @@ def linear_calibration_multi( df_cal: pd.DataFrame, reference_dataset: pd.DataFrame, per_charge: bool = True, - use_charge_state: Optional[int] = None, + use_charge_state: int | None = None, ) -> pd.DataFrame: """ Calibrate multi-conformer CCS predictions using linear calibration. @@ -378,16 +377,16 @@ def linear_calibration_multi( Parameters ---------- - df_pred : pd.DataFrame + df_pred DataFrame with multi-conformer CCS predictions. Must contain columns: 'predicted_ccs_multi_1', 'predicted_ccs_multi_2', 'peptidoform' - df_cal : pd.DataFrame + df_cal Calibration dataset with observed CCS values. - reference_dataset : pd.DataFrame + reference_dataset Reference dataset for multi-conformer calibration. - per_charge : bool, default True + per_charge Whether to apply calibration per charge state. - use_charge_state : int, optional + use_charge_state Charge state for global calibration when per_charge=False. Returns @@ -416,7 +415,7 @@ def linear_calibration_multi( ... pred_df, cal_df, ref_df, per_charge=True ... ) """ - LOGGER.info("Calibrating multiconformer predictions using linear calibration...") + LOGGER.info("Calibrating multi-conformer predictions using linear calibration...") if df_pred.empty: raise CalibrationError("Predictions dataframe is empty") @@ -431,16 +430,19 @@ def linear_calibration_multi( try: if per_charge: - LOGGER.info("Generating general shift factor for multiconformer predictions...") + LOGGER.info("Generating general shift factor for multi-conformer predictions...") general_shift = calculate_ccs_shift_multi( df_cal, reference_dataset, per_charge=False, use_charge_state=use_charge_state or 2 ) + general_shift = cast(float, general_shift) # per_charge=False returns float - LOGGER.info("Getting shift factors per charge state for multiconformer...") + LOGGER.info("Getting shift factors per charge state for multi-conformer...") df_pred["charge"] = df_pred["peptidoform"].apply(lambda x: x.precursor_charge) shift_factor_dict = calculate_ccs_shift_multi( df_cal, reference_dataset, per_charge=True ) + # per_charge=True returns dict[int, float] + shift_factor_dict = cast(dict[int, float], shift_factor_dict) # Apply charge-specific shifts with fallback df_pred["shift_multi"] = df_pred["charge"].map(shift_factor_dict).fillna(general_shift) @@ -455,20 +457,21 @@ def linear_calibration_multi( shift_factor = calculate_ccs_shift_multi( df_cal, reference_dataset, per_charge=False, use_charge_state=use_charge_state or 2 ) + shift_factor = cast(float, shift_factor) # per_charge=False returns float df_pred["predicted_ccs_multi_1"] = df_pred["predicted_ccs_multi_1"] + shift_factor df_pred["predicted_ccs_multi_2"] = df_pred["predicted_ccs_multi_2"] + shift_factor df_pred["shift_multi"] = shift_factor - LOGGER.info("Multiconformer predictions calibrated successfully.") + LOGGER.info("Multi-conformer predictions calibrated successfully.") return df_pred except Exception as e: - raise CalibrationError(f"Multi-conformer calibration failed: {e}") + raise CalibrationError(f"Multi-conformer calibration failed: {e}") from e def predict_multi( df_pred_psm_list, - df_cal: Optional[pd.DataFrame], + df_cal: pd.DataFrame | None, calibrate_per_charge: bool, use_charge_state: int, ) -> pd.DataFrame: @@ -481,13 +484,13 @@ def predict_multi( Parameters ---------- - df_pred_psm_list : PSMList + df_pred_psm_list PSM list containing peptides for prediction. - df_cal : pd.DataFrame, optional + df_cal Calibration dataset. If provided, predictions will be calibrated. - calibrate_per_charge : bool + calibrate_per_charge Whether to perform per-charge calibration. - use_charge_state : int + use_charge_state Charge state for global calibration. Returns @@ -597,4 +600,4 @@ def predict_multi( return df_pred[["predicted_ccs_multi_1", "predicted_ccs_multi_2"]] except Exception as e: - raise IM2DeepError(f"Multi-conformer prediction failed: {e}") + raise IM2DeepError(f"Multi-conformer prediction failed: {e}") from e diff --git a/im2deep/utils.py b/im2deep/utils.py index c5e937b..4485cd7 100644 --- a/im2deep/utils.py +++ b/im2deep/utils.py @@ -13,8 +13,11 @@ MULTI_BACKBONE_PATH: Path to the multi-conformer model backbone """ +from __future__ import annotations + from pathlib import Path -from typing import Union, Any, Dict +from typing import Any + import numpy as np MULTI_BACKBONE_PATH = ( @@ -23,13 +26,13 @@ def im2ccs( - reverse_im: Union[float, np.ndarray], - mz: Union[float, np.ndarray], - charge: Union[int, np.ndarray], + reverse_im: float | np.ndarray, + mz: float | np.ndarray, + charge: int | np.ndarray, mass_gas: float = 28.013, temp: float = 31.85, t_diff: float = 273.15, -) -> Union[float, np.ndarray]: +) -> float | np.ndarray: """ Convert reduced ion mobility to collisional cross section. @@ -103,13 +106,13 @@ def im2ccs( def ccs2im( - ccs: Union[float, np.ndarray], - mz: Union[float, np.ndarray], - charge: Union[int, np.ndarray], + ccs: float | np.ndarray, + mz: float | np.ndarray, + charge: int | np.ndarray, mass_gas: float = 28.013, temp: float = 31.85, t_diff: float = 273.15, -) -> Union[float, np.ndarray]: +) -> float | np.ndarray: """ Convert collisional cross section to reduced ion mobility. @@ -181,7 +184,7 @@ def ccs2im( # Configuration for multi-conformer model -multi_config: Dict[str, Any] = { +multi_config: dict[str, Any] = { "model_name": "IM2DeepMulti", "batch_size": 16, "learning_rate": 0.0001, diff --git a/pyproject.toml b/pyproject.toml index dd1fc60..b6e0977 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,14 +18,7 @@ classifiers = [ ] dynamic = ["version"] requires-python = ">=3.10" -dependencies = [ - "click", - "deeplc<4", - "psm_utils", - "pandas", - "numpy==1.26.0", - "rich", -] +dependencies = ["click", "deeplc<4", "psm_utils", "pandas", "numpy", "rich"] [project.optional-dependencies] dev = [ @@ -35,7 +28,7 @@ dev = [ "pytest-cov>=4.0", "pytest-mock>=3.10", "mypy>=1.0", - "pre-commit>=3.0" + "pre-commit>=3.0", ] docs = [ "sphinx>=6.0", @@ -45,11 +38,9 @@ docs = [ "toml>=0.10", "semver>=2.13", "sphinx_rtd_theme>=1.2", - "sphinx-autobuild>=2021.3" + "sphinx-autobuild>=2021.3", ] -er = [ - "im2deeptrainer", - "torch==2.3.0"] +er = ["im2deeptrainer", "torch==2.3.0"] [project.urls] GitHub = "https://github.com/CompOmics/IM2Deep" @@ -63,7 +54,7 @@ requires = ["setuptools"] build-backend = "setuptools.build_meta" [tool.setuptools.dynamic] -version = {attr = "im2deep.__version__"} +version = { attr = "im2deep.__version__" } [tool.isort] profile = "black" @@ -105,40 +96,23 @@ warn_unreachable = true strict_equality = true [[tool.mypy.overrides]] -module = [ - "deeplc.*", - "torch.*", - "im2deeptrainer.*", - "psm_utils.*" -] +module = ["deeplc.*", "torch.*", "im2deeptrainer.*", "psm_utils.*"] ignore_missing_imports = true [tool.pytest.ini_options] minversion = "7.0" -addopts = [ - "-ra", - "--strict-markers", - "--strict-config" -] +addopts = ["-ra", "--strict-markers", "--strict-config"] testpaths = ["tests"] -filterwarnings = [ - "error", - "ignore::UserWarning", - "ignore::DeprecationWarning" -] +filterwarnings = ["error", "ignore::UserWarning", "ignore::DeprecationWarning"] markers = [ "slow: marks tests as slow (deselect with '-m \"not slow\"')", "integration: marks tests as integration tests", - "unit: marks tests as unit tests" + "unit: marks tests as unit tests", ] [tool.coverage.run] source = ["im2deep"] -omit = [ - "*/tests/*", - "*/test_*", - "*/__main__.py" -] +omit = ["*/tests/*", "*/test_*", "*/__main__.py"] [tool.coverage.report] exclude_lines = [ @@ -151,7 +125,7 @@ exclude_lines = [ "if 0:", "if __name__ == .__main__.:", "class .*\\bProtocol\\):", - "@(abc\\.)?abstractmethod" + "@(abc\\.)?abstractmethod", ] [tool.ruff] @@ -167,7 +141,7 @@ select = [ "UP", # pyupgrade ] ignore = [ - "E501", # line too long, handled by black - "B008", # do not perform function calls in argument defaults - "C901", # too complex + "E501", # line too long, handled by black + "B008", # do not perform function calls in argument defaults + "C901", # too complex ]