diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..8c1a2c9 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,37 @@ +# Pre-commit configuration for IM2Deep +# Install with: pip install pre-commit && pre-commit install + +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.4.0 + hooks: + - id: trailing-whitespace + - id: end-of-file-fixer + - id: check-yaml + - id: check-added-large-files + - id: check-merge-conflict + - id: debug-statements + + - repo: https://github.com/psf/black + rev: 23.3.0 + hooks: + - id: black + language_version: python3 + + - repo: https://github.com/pycqa/isort + rev: 5.12.0 + hooks: + - id: isort + args: ["--profile", "black"] + + - repo: https://github.com/pycqa/flake8 + rev: 6.0.0 + hooks: + - id: flake8 + args: [--max-line-length=99] + + - repo: https://github.com/pre-commit/mirrors-mypy + rev: v1.3.0 + hooks: + - id: mypy + additional_dependencies: [types-requests] diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..fa87517 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,31 @@ +# Changelog + +All notable changes to this project will be documented in this file. + +The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), +and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). + +## [Unreleased] + +### Added +- This CHANGELOG file +- Comprehensive documentation with API reference, development guide, and tutorial +- Enhanced error handling with custom exceptions +- Input validation throughout the codebase +- Type hints for better code clarity +- Detailed logging with different verbosity levels + +### Changed +- Improved function signatures with better parameter validation +- Enhanced docstrings with NumPy style documentation +- Better error messages with more context +- Improved CLI with better help text and validation +- More robust file handling with proper encoding +- Enhanced calibration functions with edge case handling + +### Fixed +- Import error handling for optional dependencies +- File path validation in CLI +- Memory management improvements +- Edge cases in CCS/ion mobility conversions +- Calibration edge cases with insufficient data diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000..56fc296 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,63 @@ +# Contributing to IM2Deep + +We welcome contributions to IM2Deep! This document provides guidelines for contributing to the project. + +## Getting Started + +1. Fork the repository on GitHub +2. Clone your fork locally +3. Set up the development environment +4. Create a feature branch +5. Make your changes +6. Run tests +7. Submit a pull request + +## Development Setup + +```bash +# Clone your fork +git clone https://github.com/yourusername/IM2Deep.git +cd IM2Deep + +# Create virtual environment +python -m venv venv +source venv/bin/activate # On Windows: venv\Scripts\activate + +# Install in development mode +pip install -e .[dev,test] +``` + +## Code Standards + +### Style Guide +- Follow PEP 8 +- Use Black for code formatting: `black im2deep/` +- Use isort for imports: `isort im2deep/` +- Maximum line length: 99 characters + +### Documentation +- Use NumPy-style docstrings +- Include type hints +- Provide examples in docstrings +- Update documentation for new features + +## Pull Request Process + +1. Create a feature branch from `main` +2. Make your changes +3. Add tests for new functionality +4. Update documentation +5. Ensure all tests pass +6. Update CHANGELOG.md +7. Submit pull request + +## Code of Conduct + +- Be respectful and inclusive +- Provide constructive feedback +- Focus on the code, not the person +- Help others learn and grow + +## Questions? + +Feel free to open an issue for questions or discussion! diff --git a/im2deep/__init__.py b/im2deep/__init__.py index ec65ce9..35b5085 100644 --- a/im2deep/__init__.py +++ b/im2deep/__init__.py @@ -1,3 +1,54 @@ -"""IM2Deep: Deep learning framework for peptide collisional cross section prediction.""" +""" +IM2Deep: Deep learning framework for peptide collisional cross section prediction. -__version__ = "1.0.3" +IM2Deep is a Python package that provides accurate CCS (Collisional Cross Section) +prediction for peptides and modified peptides using deep learning models trained +specifically for TIMS (Trapped Ion Mobility Spectrometry) data. + +Key Features: + - Single-conformer CCS prediction using ensemble of neural networks + - Multi-conformer CCS prediction for peptides with multiple conformations + - Linear calibration using reference datasets + - Support for modified peptides + - Ion mobility conversion utilities + - Command-line interface for easy usage + +Example: + Basic usage for CCS prediction: + + >>> from im2deep.im2deep import predict_ccs + >>> from psm_utils.psm_list import PSMList + >>> predictions = predict_ccs(psm_list, calibration_data) + +Dependencies: + - deeplc: For deep learning model infrastructure + - psm_utils: For peptide and PSM handling + - pandas: For data manipulation + - numpy: For numerical computations + - click: For command-line interface + +Authors: + - Robbe Devreese + - Robbin Bouwmeester + - Ralf Gabriels + +License: + Apache License 2.0 +""" + +__version__ = "1.1.0" + +# Import main functionality for easier access +from im2deep.im2deep import predict_ccs +from im2deep.calibrate import linear_calibration +from im2deep.utils import ccs2im, im2ccs +from im2deep._exceptions import IM2DeepError, CalibrationError + +__all__ = [ + "predict_ccs", + "linear_calibration", + "ccs2im", + "im2ccs", + "IM2DeepError", + "CalibrationError", +] diff --git a/im2deep/__main__.py b/im2deep/__main__.py index 2d5ec10..e3b23d5 100644 --- a/im2deep/__main__.py +++ b/im2deep/__main__.py @@ -1,4 +1,42 @@ -"""Command line interface to IM2Deep.""" +""" +Command line interface for IM2Deep. + +This module provides a comprehensive command-line interface for the IM2Deep +CCS prediction package. It handles input file parsing, model configuration, +calibration setup, and output generation. + +The CLI supports: +- Multiple input file formats (CSV with seq/modifications or PSM formats) +- Optional calibration using reference datasets +- Single-conformer and multi-conformer predictions +- Ion mobility output conversion +- Ensemble or single model prediction +- Comprehensive logging and error reporting + +Usage: + Basic prediction: + im2deep input_peptides.csv + + With calibration (recommended): + im2deep input_peptides.csv -c calibration_data.csv + + Multi-conformer prediction: + im2deep input_peptides.csv -c calibration_data.csv -e + + Ion mobility output: + im2deep input_peptides.csv -c calibration_data.csv -i + +Dependencies: + - click: Command-line interface framework + - psm_utils: Peptide and PSM data handling + - rich: Enhanced logging and progress display + - pandas: Data manipulation + +Authors: + - Robbe Devreese + - Robbin Bouwmeester + - Ralf Gabriels +""" from __future__ import annotations @@ -10,7 +48,6 @@ import click import pandas as pd -# from deeplc import DeepLC from psm_utils.io import read_file from psm_utils.io.exceptions import PSMUtilsIOException from psm_utils.io.peptide_record import peprec_to_proforma @@ -18,15 +55,25 @@ from psm_utils.psm_list import PSMList from rich.logging import RichHandler - -# from im2deep.calibrate import linear_calibration - REFERENCE_DATASET_PATH = Path(__file__).parent / "reference_data" / "reference_ccs.zip" LOGGER = logging.getLogger(__name__) -def setup_logging(passed_level): +def setup_logging(passed_level: str) -> None: + """ + Configure logging with Rich formatting. + + Parameters + ---------- + passed_level : str + Logging level name (debug, info, warning, error, critical) + + Raises + ------ + ValueError + If invalid logging level provided + """ log_mapping = { "debug": logging.DEBUG, "info": logging.INFO, @@ -37,8 +84,7 @@ def setup_logging(passed_level): if passed_level.lower() not in log_mapping: raise ValueError( - f"""Invalid log level: {passed_level}. - Should be one of {log_mapping.keys()}""" + f"Invalid log level: {passed_level}. " f"Should be one of {list(log_mapping.keys())}" ) logging.basicConfig( @@ -49,182 +95,334 @@ def setup_logging(passed_level): ) -def check_optional_dependencies(): +def check_optional_dependencies() -> None: + """ + Check if optional dependencies for multi-conformer prediction are available. + + Raises + ------ + SystemExit + If required dependencies are missing + """ try: import torch import im2deeptrainer + + LOGGER.debug("Optional dependencies for multi-conformer prediction found") except ImportError: LOGGER.error( - "In order to run multiconformational precursor CCS predictions, IM2Deep requires the installation of 'torch' and 'im2deeptrainer'.\nPlease re-install IM2Deep with the optional dependencies by running 'pip install 'im2deep[er]'." + "Multi-conformer prediction requires optional dependencies.\n" + "Please install IM2Deep with optional dependencies:\n" + "pip install 'im2deep[er]'" ) sys.exit(1) -# Command line arguments TODO: Make config_parser script +def _validate_file_format(file_path: str, file_type: str = "input") -> bool: + """ + Validate file format and accessibility. + + Parameters + ---------- + file_path : str + Path to file to validate + file_type : str + Type of file for error messages + + Returns + ------- + bool + True if file is valid + + Raises + ------ + click.ClickException + If file validation fails + """ + path = Path(file_path) + + if not path.exists(): + raise click.ClickException(f"{file_type.capitalize()} file not found: {file_path}") + + if not path.is_file(): + raise click.ClickException(f"{file_type.capitalize()} path is not a file: {file_path}") + + if path.suffix.lower() not in [".csv", ".txt", ".tsv"]: + LOGGER.warning(f"Unexpected file extension for {file_type} file: {path.suffix}") + + try: + with open(file_path, "r", encoding="utf-8") as f: + first_line = f.readline().strip() + if not first_line: + raise click.ClickException(f"{file_type.capitalize()} file appears to be empty") + except Exception as e: + raise click.ClickException(f"Error reading {file_type} file: {e}") + + return True + + +def _parse_csv_input(file_path: str, file_type: str = "prediction") -> PSMList: + """ + Parse CSV input file into PSMList. + + Parameters + ---------- + file_path : str + Path to CSV file + file_type : str + Type of file for error messages + + Returns + ------- + PSMList + Parsed PSM data + + Raises + ------ + click.ClickException + If parsing fails + """ + try: + df = pd.read_csv(file_path) + df = df.fillna("") + + required_cols = ["seq", "modifications", "charge"] + missing_cols = set(required_cols) - set(df.columns) + if missing_cols: + raise click.ClickException( + f"Missing required columns in {file_type} file: {missing_cols}\n" + f"Required columns: {required_cols}" + ) + + if file_type == "calibration" and "CCS" not in df.columns: + raise click.ClickException("Calibration file must contain 'CCS' column") + + list_of_psms = [] + for idx, row in df.iterrows(): + try: + peptidoform = peprec_to_proforma(row["seq"], row["modifications"], row["charge"]) + metadata = {} + if file_type == "calibration" and "CCS" in row: + metadata["CCS"] = float(row["CCS"]) + + psm = PSM(peptidoform=peptidoform, metadata=metadata, spectrum_id=idx) + list_of_psms.append(psm) + except Exception as e: + LOGGER.warning(f"Skipping row {idx} due to parsing error: {e}") + continue + + if not list_of_psms: + raise click.ClickException(f"No valid peptides found in {file_type} file") + + LOGGER.info(f"Parsed {len(list_of_psms)} peptides from {file_type} file") + return PSMList(psm_list=list_of_psms) + + except pd.errors.EmptyDataError: + raise click.ClickException(f"{file_type.capitalize()} file is empty") + except pd.errors.ParserError as e: + raise click.ClickException(f"Error parsing {file_type} file: {e}") + except Exception as e: + raise click.ClickException(f"Unexpected error reading {file_type} file: {e}") + + +# Command line interface with comprehensive options @click.command() -@click.argument("psm-file", type=click.Path(exists=True, dir_okay=False)) +@click.argument("psm-file", type=click.Path(exists=True, dir_okay=False), metavar="INPUT_FILE") @click.option( "-c", "--calibration-file", - type=click.Path(exists=False), + type=click.Path(exists=True, dir_okay=False), default=None, - help="Calibration file name.", + help="Path to calibration file with known CCS values. Highly recommended for accurate predictions.", ) @click.option( "-o", "--output-file", - type=click.Path(exists=False), + type=click.Path(dir_okay=False), default=None, - help="Output file name.", + help="Output file path. If not specified, creates file next to input with '_IM2Deep-predictions.csv' suffix.", ) @click.option( "-m", "--model-name", - type=click.Choice(["tims"]), + type=click.Choice(["tims"], case_sensitive=False), default="tims", - help="Model name.", + help="Neural network model to use for prediction.", ) @click.option( "-e", "--multi", - default=False, is_flag=True, - help="Use multi-conformer model in addition to classical model.", + default=False, + help="Enable multi-conformer prediction. Requires optional dependencies: pip install 'im2deep[er]'", ) @click.option( "-l", "--log-level", - type=click.Choice(["debug", "info", "warning", "error", "critical"]), + type=click.Choice(["debug", "info", "warning", "error", "critical"], case_sensitive=False), default="info", - help="Logging level.", + help="Set logging verbosity level.", ) @click.option( "-n", "--n-jobs", - type=click.INT, + type=click.IntRange(min=1), default=None, - help="Number of jobs to use for parallel processing.", + help="Number of parallel jobs for model inference. Default uses all available CPU cores.", ) @click.option( "--calibrate-per-charge", type=click.BOOL, default=True, - help="Calibrate CCS values per charge state. Default is True.", + help="Apply calibration per charge state for improved accuracy. Disable for global calibration.", ) @click.option( "--use-charge-state", - type=click.INT, + type=click.IntRange(min=1, max=6), default=2, - help="Charge state to use for calibration. Only used if calibrate_per_charge is set to False.", + help="Charge state for global calibration when --calibrate-per-charge is disabled.", ) @click.option( "--use-single-model", type=click.BOOL, default=True, - help="Use a single model for prediction. If False, an ensemble of models will be used, which may slightly improve prediction accuracy but increase runtimes. Default is True.", + help="Use single model (faster) vs ensemble of models (potentially slightly more accurate).", ) @click.option( "-i", "--ion-mobility", - type=click.BOOL, - default=False, - help="Output predictions in ion mobility (1/K0) instead of CCS. Default is False.", is_flag=True, + default=False, + help="Output ion mobility (1/K0) instead of CCS values.", ) def main( psm_file: str, calibration_file: Optional[str] = None, output_file: Optional[str] = None, - model_name: Optional[str] = "tims", - multi: Optional[bool] = False, - log_level: Optional[str] = "info", + model_name: str = "tims", + multi: bool = False, + log_level: str = "info", n_jobs: Optional[int] = None, - use_single_model: Optional[bool] = True, - calibrate_per_charge: Optional[bool] = True, - use_charge_state: Optional[int] = 2, - ion_mobility: Optional[bool] = False, -): - """Command line interface to IM2Deep.""" - setup_logging(log_level) - - if multi: - check_optional_dependencies() - - from im2deep._exceptions import IM2DeepError - from im2deep.im2deep import predict_ccs - - with open(psm_file) as f: - first_line_pred = f.readline().strip() - if calibration_file: - with open(calibration_file) as fc: - first_line_cal = fc.readline().strip() - - if "modifications" in first_line_pred.split(",") and "seq" in first_line_pred.split(","): - # Read input file - df_pred = pd.read_csv(psm_file) - df_pred.fillna("", inplace=True) + use_single_model: bool = True, + calibrate_per_charge: bool = True, + use_charge_state: int = 2, + ion_mobility: bool = False, +) -> None: + """ + IM2Deep: Predict CCS values for peptides using deep learning. - list_of_psms = [] - for seq, mod, charge, ident in zip( - df_pred["seq"], df_pred["modifications"], df_pred["charge"], df_pred.index - ): - list_of_psms.append( - PSM(peptidoform=peprec_to_proforma(seq, mod, charge), spectrum_id=ident) - ) - psm_list_pred = PSMList(psm_list=list_of_psms) - - else: - try: - psm_list_pred = read_file(psm_file) - except PSMUtilsIOException: - LOGGER.error("Invalid input file. Please check the format of the input file.") - sys.exit(1) - - psm_list_cal = [] - if ( - calibration_file - and "modifications" in first_line_cal.split(",") - and "seq" in first_line_cal.split(",") - ): - try: - df_cal = pd.read_csv(calibration_file) - df_cal.fillna("", inplace=True) - del calibration_file - - list_of_cal_psms = [] - for seq, mod, charge, ident, CCS in zip( - df_cal["seq"], - df_cal["modifications"], - df_cal["charge"], - df_cal.index, - df_cal["CCS"], - ): - list_of_cal_psms.append( - PSM(peptidoform=peprec_to_proforma(seq, mod, charge), spectrum_id=ident) - ) - psm_list_cal = PSMList(psm_list=list_of_cal_psms) - psm_list_cal_df = psm_list_cal.to_dataframe() - psm_list_cal_df["ccs_observed"] = df_cal["CCS"] + IM2Deep predicts Collisional Cross Section (CCS) values for peptides, + including those with post-translational modifications. The tool supports + both single-conformer and multi-conformer predictions with optional + calibration using reference datasets. - except IOError: - LOGGER.error( - "Invalid calibration file. Please check the format of the calibration file." - ) - sys.exit(1) + INPUT_FILE should be a CSV file with columns: + \b + - seq: Peptide sequence (required) + - modifications: Modifications in format "position|name" (required, can be empty) + - charge: Charge state (required) + + For calibration files, an additional 'CCS' column with observed values is required. + + Examples: + \b + # Basic prediction + im2deep peptides.csv + + # With calibration (recommended) + im2deep peptides.csv -c calibration.csv + + # Multi-conformer prediction + im2deep peptides.csv -c calibration.csv -e + + # Ion mobility output + im2deep peptides.csv -c calibration.csv -i + + # Ensemble prediction with custom output + im2deep peptides.csv -c calibration.csv -o results.csv --use-single-model False + """ + try: + # Setup logging first + setup_logging(log_level) - else: - LOGGER.warning( - "No calibration file found. Proceeding without calibration. Calibration is HIGHLY recommended for accurate CCS prediction." + LOGGER.info("IM2Deep command-line interface started") + LOGGER.debug( + f"Input arguments: psm_file={psm_file}, calibration_file={calibration_file}, " + f"multi={multi}, ion_mobility={ion_mobility}" ) - psm_list_cal_df = None + + # Import main functionality (after logging setup) + from im2deep._exceptions import IM2DeepError + from im2deep.im2deep import predict_ccs + + # Check optional dependencies if multi-conformer requested + if multi: + check_optional_dependencies() + + # Validate input files + _validate_file_format(psm_file, "input") + if calibration_file: + _validate_file_format(calibration_file, "calibration") + + # Parse input files + LOGGER.info("Parsing input files...") + + # Try to determine file format + with open(psm_file, "r", encoding="utf-8") as f: + first_line = f.readline().strip() + + # Check if it's the expected CSV format + if "modifications" in first_line and "seq" in first_line: + psm_list_pred = _parse_csv_input(psm_file, "prediction") + df_pred = pd.read_csv(psm_file).fillna("") + else: + # Try psm_utils for other formats + try: + psm_list_pred = read_file(psm_file) + df_pred = None + LOGGER.info(f"Loaded {len(psm_list_pred)} PSMs using psm_utils") + except PSMUtilsIOException as e: + raise click.ClickException( + f"Could not parse input file. Expected CSV with columns 'seq', 'modifications', 'charge' " + f"or a format supported by psm_utils. Error: {e}" + ) + + # Parse calibration file + psm_list_cal = None df_cal = None + if calibration_file: + with open(calibration_file, "r", encoding="utf-8") as f: + cal_first_line = f.readline().strip() - if not output_file: - output_file = Path(psm_file).parent / (Path(psm_file).stem + "_IM2Deep-predictions.csv") - try: + if ( + "modifications" in cal_first_line + and "seq" in cal_first_line + and "CCS" in cal_first_line + ): + psm_list_cal = _parse_csv_input(calibration_file, "calibration") + df_cal = pd.read_csv(calibration_file).fillna("") + else: + raise click.ClickException( + "Calibration file must be CSV with columns: 'seq', 'modifications', 'charge', 'CCS'" + ) + else: + LOGGER.warning( + "No calibration file provided. Predictions will be uncalibrated. " + "Calibration is HIGHLY recommended for accurate results." + ) + + # Set up output file + if not output_file: + input_path = Path(psm_file) + output_file = input_path.parent / f"{input_path.stem}_IM2Deep-predictions.csv" + + LOGGER.info(f"Output will be written to: {output_file}") + + # Run prediction + LOGGER.info("Starting CCS prediction...") predict_ccs( psm_list_pred, - psm_list_cal_df, + psm_list_cal, output_file=output_file, model_name=model_name, multi=multi, @@ -235,9 +433,21 @@ def main( ion_mobility=ion_mobility, pred_df=df_pred, cal_df=df_cal, + write_output=True, ) + + LOGGER.info("IM2Deep completed successfully!") + except IM2DeepError as e: - LOGGER.error(e) + LOGGER.error(f"IM2Deep error: {e}") + sys.exit(1) + except click.ClickException: + # Re-raise click exceptions to preserve formatting + raise + except Exception as e: + LOGGER.error(f"Unexpected error: {e}") + if log_level.lower() == "debug": + LOGGER.exception("Full traceback:") sys.exit(1) diff --git a/im2deep/_exceptions.py b/im2deep/_exceptions.py index caea6ac..525b5b0 100644 --- a/im2deep/_exceptions.py +++ b/im2deep/_exceptions.py @@ -1,8 +1,53 @@ -"""IM2Deep exceptions.""" +""" +Custom exceptions for IM2Deep package. + +This module defines custom exception classes used throughout the IM2Deep package +for better error handling and debugging. + +Classes: + IM2DeepError: Base exception class for all IM2Deep-related errors + CalibrationError: Exception raised when calibration-related errors occur +""" + class IM2DeepError(Exception): + """ + Base exception class for all IM2Deep-related errors. + + This exception serves as the base class for all custom exceptions + in the IM2Deep package, allowing users to catch all package-specific + errors with a single except clause. + + Attributes: + message (str): Human readable string describing the exception. + + Example: + >>> try: + ... predict_ccs(invalid_data) + ... except IM2DeepError as e: + ... print(f"IM2Deep error occurred: {e}") + """ pass class CalibrationError(IM2DeepError): + """ + Exception raised when calibration-related errors occur. + + This exception is raised when there are issues with calibration data, + reference datasets, or calibration procedures that prevent successful + CCS calibration. + + Common scenarios: + - Insufficient overlapping peptides between calibration and reference data + - Invalid calibration file format + - Missing required columns in calibration data + - Numerical issues during calibration calculation + + Example: + >>> try: + ... linear_calibration(pred_df, cal_df, ref_df) + ... except CalibrationError as e: + ... print(f"Calibration failed: {e}") + """ pass diff --git a/im2deep/calibrate.py b/im2deep/calibrate.py index 091abbd..d4dd8e1 100644 --- a/im2deep/calibrate.py +++ b/im2deep/calibrate.py @@ -1,36 +1,154 @@ +""" +Calibration functions for CCS predictions in IM2Deep. + +This module provides functions for calibrating CCS predictions using reference datasets. Calibration is performed by calculating +shift factors based on overlapping peptides between calibration and reference data. + +The calibration process involves: +1. Finding overlapping peptide-charge pairs between calibration and reference datasets +2. Calculating mean CCS differences (shift factors) +3. Applying shifts to predictions either globally or per charge state + +Functions: + get_ccs_shift: Calculate global CCS shift factor for a specific charge state + get_ccs_shift_per_charge: Calculate CCS shift factors per charge state + calculate_ccs_shift: Wrapper function for shift calculation with validation + linear_calibration: Apply linear calibration to CCS predictions + +Example: + >>> calibrated_df = linear_calibration( + ... predictions_df, + ... calibration_df, + ... reference_df, + ... per_charge=True + ... ) +""" + import logging +from typing import Dict, Union, Optional import numpy as np import pandas as pd from numpy import ndarray from psm_utils.peptidoform import Peptidoform +from im2deep._exceptions import CalibrationError + LOGGER = logging.getLogger(__name__) +def _validate_calibration_inputs( + cal_df: pd.DataFrame, + reference_dataset: pd.DataFrame, + required_cal_columns: Optional[list] = None, + required_ref_columns: Optional[list] = None, +) -> None: + """ + Validate input dataframes for calibration functions. + + Parameters + ---------- + cal_df : pd.DataFrame + Calibration dataset + reference_dataset : pd.DataFrame + Reference dataset + required_cal_columns : list, optional + Required columns for calibration dataset + required_ref_columns : list, optional + Required columns for reference dataset + + Raises + ------ + CalibrationError + If validation fails + """ + if cal_df.empty: + raise CalibrationError("Calibration dataset is empty") + if reference_dataset.empty: + raise CalibrationError("Reference dataset is empty") + + if required_cal_columns: + missing_cols = set(required_cal_columns) - set(cal_df.columns) + if missing_cols: + raise CalibrationError(f"Missing columns in calibration data: {missing_cols}") + + if required_ref_columns: + missing_cols = set(required_ref_columns) - set(reference_dataset.columns) + if missing_cols: + raise CalibrationError(f"Missing columns in reference data: {missing_cols}") + + def get_ccs_shift( cal_df: pd.DataFrame, reference_dataset: pd.DataFrame, use_charge_state: int = 2 ) -> float: """ - Calculate CCS shift factor, i.e. a constant offset, - based on identical precursors as in reference dataset. + Calculate CCS shift factor for a specific charge state. + + This function calculates a constant offset based on identical precursors + between calibration and reference datasets for a specific charge state. + The shift represents how much the calibration CCS values differ from + reference CCS values on average. Parameters ---------- - cal_df - PSMs with CCS values. - reference_dataset - Reference dataset with CCS values. - use_charge_state - Charge state to use for CCS shift calculation, needs to be [2,4], by default 2. - return_shift_factor - CCS shift factor. + cal_df : pd.DataFrame + PSMs with CCS values. Must contain columns: 'sequence', 'charge', 'ccs_observed' + reference_dataset : pd.DataFrame + Reference dataset with CCS values. Must contain columns: 'peptidoform', 'charge', 'CCS' + use_charge_state : int, default 2 + Charge state to use for CCS shift calculation. Should be in range [2,4]. + + Returns + ------- + float + CCS shift factor. Positive values indicate calibration CCS is higher + than reference CCS on average. + + Raises + ------ + CalibrationError + If charge state is invalid or no overlapping peptides found + + Notes + ----- + The function: + 1. Filters both datasets to the specified charge state + 2. Merges on sequence and charge to find overlapping peptides + 3. Calculates mean difference: mean(ccs_observed - CCS_reference) + Examples + -------- + >>> shift = get_ccs_shift(calibration_df, reference_df, use_charge_state=2) + >>> print(f"CCS shift factor: {shift:.2f} Ų") """ + # Validate inputs + _validate_calibration_inputs( + cal_df, + reference_dataset, + required_cal_columns=["sequence", "charge", "ccs_observed"], + required_ref_columns=["peptidoform", "charge", "CCS"], + ) + + if not 1 <= use_charge_state <= 6: + raise CalibrationError( + f"Invalid charge state {use_charge_state}. Should be between 1 and 6." + ) + LOGGER.debug(f"Using charge state {use_charge_state} for CCS shift calculation.") + # Filter data by charge state reference_tmp = reference_dataset[reference_dataset["charge"] == use_charge_state] df_tmp = cal_df[cal_df["charge"] == use_charge_state] + + if reference_tmp.empty: + LOGGER.warning(f"No reference data found for charge state {use_charge_state}") + return 0.0 + + if df_tmp.empty: + LOGGER.warning(f"No calibration data found for charge state {use_charge_state}") + return 0.0 + + # Merge datasets to find overlapping peptides both = pd.merge( left=reference_tmp, right=df_tmp, @@ -39,36 +157,87 @@ def get_ccs_shift( how="inner", suffixes=("_ref", "_data"), ) + LOGGER.debug( - """Calculating CCS shift based on {} overlapping peptide-charge pairs - between PSMs and reference dataset""".format( - both.shape[0] - ) + f"Calculating CCS shift based on {both.shape[0]} overlapping peptide-charge pairs " + f"between PSMs and reference dataset" ) - # How much CCS in calibration data is larger than reference CCS, so predictions - # need to be increased by this amount - return 0 if both.empty else np.mean(both["ccs_observed"] - both["CCS"]) + if both.empty: + LOGGER.warning("No overlapping peptides found between calibration and reference data") + return 0.0 + + if both.shape[0] < 10: + LOGGER.warning( + f"Only {both.shape[0]} overlapping peptides found. " + "Consider using more calibration data for reliable results." + ) + + # Calculate shift: how much calibration CCS is larger than reference CCS + shift = np.mean(both["ccs_observed"] - both["CCS"]) + + if abs(shift) > 100: # Sanity check for unreasonably large shifts + LOGGER.warning( + f"Large CCS shift detected ({shift:.2f} Å^2). " + "Please verify calibration and reference data quality." + ) + + return float(shift) -def get_ccs_shift_per_charge(cal_df: pd.DataFrame, reference_dataset: pd.DataFrame) -> ndarray: +def get_ccs_shift_per_charge( + cal_df: pd.DataFrame, reference_dataset: pd.DataFrame +) -> Dict[int, float]: """ - Calculate CCS shift factor per charge state, - i.e. a constant offset based on identical precursors as in reference. + Calculate CCS shift factors per charge state. + + This function calculates individual shift factors for each charge state + present in both calibration and reference datasets. This allows for + charge-specific calibration which often improves accuracy. Parameters ---------- - cal_df - PSMs with CCS values. - reference_dataset - Reference dataset with CCS values. + cal_df : pd.DataFrame + PSMs with CCS values. Must contain columns: 'sequence', 'charge', 'ccs_observed' + reference_dataset : pd.DataFrame + Reference dataset with CCS values. Must contain columns: 'peptidoform', 'charge', 'CCS' Returns ------- - ndarray - CCS shift factors per charge state. + Dict[int, float] + Dictionary mapping charge states to their shift factors. + Keys are charge states (int), values are shift factors (float). + + Raises + ------ + CalibrationError + If required columns are missing or no overlapping data found + + Notes + ----- + The function: + 1. Merges calibration and reference data on sequence and charge + 2. Groups by charge state + 3. Calculates mean difference for each charge state + Charge states with insufficient data (< 5 overlapping peptides) will be + logged as warnings but still included in results. + + Examples + -------- + >>> shifts = get_ccs_shift_per_charge(calibration_df, reference_df) + >>> print(shifts) + {2: 5.2, 3: 3.8, 4: 2.1} """ + # Validate inputs + _validate_calibration_inputs( + cal_df, + reference_dataset, + required_cal_columns=["sequence", "charge", "ccs_observed"], + required_ref_columns=["peptidoform", "charge", "CCS"], + ) + + # Merge datasets to find overlapping peptides both = pd.merge( left=reference_dataset, right=cal_df, @@ -77,46 +246,130 @@ def get_ccs_shift_per_charge(cal_df: pd.DataFrame, reference_dataset: pd.DataFra how="inner", suffixes=("_ref", "_data"), ) - return both.groupby("charge").apply(lambda x: np.mean(x["ccs_observed"] - x["CCS"])).to_dict() + + if both.empty: + raise CalibrationError( + "No overlapping peptides found between calibration and reference data" + ) + + LOGGER.debug(f"Found {both.shape[0]} total overlapping peptide-charge pairs") + + # Check data distribution across charge states + charge_counts = both.groupby("charge").size() + LOGGER.debug(f"Peptides per charge state: {charge_counts.to_dict()}") + + # Warn about charge states with low data + low_data_charges = charge_counts[charge_counts < 5].index.tolist() + if low_data_charges: + LOGGER.warning( + f"Charge states with <5 peptides: {low_data_charges}. " + "Consider using global calibration for these charges." + ) + + # Calculate shift per charge state + shift_dict = ( + both.groupby("charge").apply(lambda x: np.mean(x["ccs_observed"] - x["CCS"])).to_dict() + ) + + # Convert numpy types to native Python types for JSON serialization + shift_dict = {int(k): float(v) for k, v in shift_dict.items()} + + # Check for unreasonably large shifts + large_shifts = {k: v for k, v in shift_dict.items() if abs(v) > 100} + if large_shifts: + LOGGER.warning( + f"Large CCS shifts detected: {large_shifts}. " "Please verify data quality." + ) + + return shift_dict def calculate_ccs_shift( - cal_df: pd.DataFrame, reference_dataset: pd.DataFrame, per_charge=True, use_charge_state=None -) -> float: + cal_df: pd.DataFrame, + reference_dataset: pd.DataFrame, + per_charge: bool = True, + use_charge_state: Optional[int] = None, +) -> Union[float, Dict[int, float]]: """ - Apply CCS shift to CCS values. + Calculate CCS shift factors with validation and filtering. + + This is the main interface for calculating CCS shift factors. It provides + input validation, charge filtering, and can return either global or + per-charge shift factors. Parameters ---------- - cal_df - PSMs with CCS values. - reference_dataset - Reference dataset with CCS values. - per_charge - Whether to calculate shift factor per charge state, default True. - use_charge_state - Charge state to use for CCS shift calculation, needs to be [2,4], by default None. + cal_df : pd.DataFrame + PSMs with CCS values. Must contain columns: 'sequence', 'charge', 'ccs_observed' + reference_dataset : pd.DataFrame + Reference dataset with CCS values. Must contain columns: 'peptidoform', 'charge', 'CCS' + per_charge : bool, default True + Whether to calculate shift factors per charge state. If False, calculates + a single global shift factor using the specified charge state. + use_charge_state : int, optional + Charge state to use for global shift calculation when per_charge=False. + Should be in range [2,4]. Default is 2 if not specified. Returns ------- - float - CCS shift factor. + Union[float, Dict[int, float]] + If per_charge=True: Dictionary mapping charge states to shift factors + If per_charge=False: Single shift factor (float) + + Raises + ------ + CalibrationError + If validation fails or invalid parameters provided + Notes + ----- + The function automatically filters out charges >6 as IM2Deep predictions + are not reliable for very high charge states. A warning is logged if + any peptides are filtered out. + + Examples + -------- + >>> # Per-charge calibration + >>> shifts = calculate_ccs_shift(cal_df, ref_df, per_charge=True) + >>> + >>> # Global calibration using charge 2 + >>> shift = calculate_ccs_shift(cal_df, ref_df, per_charge=False, use_charge_state=2) """ - cal_df = cal_df[cal_df["charge"] < 7] # predictions do not go higher for IM2Deep + # Validate inputs + _validate_calibration_inputs(cal_df, reference_dataset) - if not per_charge: - shift_factor = get_ccs_shift( - cal_df, - reference_dataset, - use_charge_state=use_charge_state, + if use_charge_state is not None and not 1 <= use_charge_state <= 6: + raise CalibrationError( + f"Invalid charge state {use_charge_state}. Should be between 1 and 6." ) - LOGGER.debug(f"CCS shift factor: {shift_factor}") - return shift_factor + # Filter high charge states (IM2Deep predictions are unreliable >6) + original_size = len(cal_df) + cal_df = cal_df[cal_df["charge"] < 7].copy() + + if len(cal_df) < original_size: + filtered_count = original_size - len(cal_df) + LOGGER.info( + f"Filtered out {filtered_count} peptides with charge >6 " + "(predictions not reliable for z>6)" + ) + + if cal_df.empty: + raise CalibrationError("No valid calibration data remaining after filtering") + + if not per_charge: + # Global calibration using specified charge state + if use_charge_state is None: + use_charge_state = 2 + LOGGER.debug("No charge state specified for global calibration, using charge 2") + + shift_factor = get_ccs_shift(cal_df, reference_dataset, use_charge_state) + LOGGER.debug(f"Global CCS shift factor: {shift_factor:.3f}") + return shift_factor else: + # Per-charge calibration shift_factor_dict = get_ccs_shift_per_charge(cal_df, reference_dataset) - LOGGER.debug(f"CCS shift factor dict: {shift_factor_dict}") + LOGGER.debug(f"CCS shift factors per charge: {shift_factor_dict}") return shift_factor_dict @@ -125,66 +378,156 @@ def linear_calibration( calibration_dataset: pd.DataFrame, reference_dataset: pd.DataFrame, per_charge: bool = True, - use_charge_state: int = None, + use_charge_state: Optional[int] = None, ) -> pd.DataFrame: """ - Calibrate PSM df using linear calibration. + Calibrate CCS predictions using linear calibration. + + This function performs linear calibration of CCS predictions by applying + shift factors calculated from overlapping peptides between calibration + and reference datasets. Calibration can be applied globally or per charge state. Parameters ---------- - preds_df - PSMs with CCS values. - calibration_dataset - Calibration dataset with CCS values. - reference_dataset - Reference dataset with CCS values. - per_charge - Whether to calculate shift factor per charge state, default True. - use_charge_state - Charge state to use for CCS shift calculation, needs to be [2,4], by default None. + preds_df : pd.DataFrame + PSMs with CCS predictions. Must contain 'predicted_ccs' column. + Will be modified to include 'charge' and 'shift' columns. + calibration_dataset : pd.DataFrame + Calibration dataset with observed CCS values. Must contain columns: + 'peptidoform', 'ccs_observed' + reference_dataset : pd.DataFrame + Reference dataset with CCS values. Must contain columns: + 'peptidoform', 'CCS' + per_charge : bool, default True + Whether to calculate and apply shift factors per charge state. + If True, uses charge-specific calibration with fallback to global shift. + If False, applies single global shift factor. + use_charge_state : int, optional + Charge state to use for global shift calculation when per_charge=False. + Default is 2 if not specified. Returns ------- pd.DataFrame - PSMs with calibrated CCS values. + Calibrated PSMs with updated 'predicted_ccs' values and added 'shift' column. + + Raises + ------ + CalibrationError + If calibration fails due to data issues or missing columns + + Notes + ----- + The calibration process: + 1. Extracts sequence and charge information from peptidoforms + 2. Calculates shift factors from calibration vs reference data + 3. Applies shifts to predictions + 4. For per-charge calibration: uses charge-specific shifts with global fallback + + Per-charge calibration is recommended as it typically provides better accuracy + by accounting for charge-dependent systematic biases. + Examples + -------- + >>> # Per-charge calibration (recommended) + >>> calibrated_df = linear_calibration( + ... predictions_df, + ... calibration_df, + ... reference_df, + ... per_charge=True + ... ) + >>> + >>> # Global calibration using charge 2 + >>> calibrated_df = linear_calibration( + ... predictions_df, + ... calibration_df, + ... reference_df, + ... per_charge=False, + ... use_charge_state=2 + ... ) """ LOGGER.info("Calibrating CCS values using linear calibration...") - calibration_dataset["sequence"] = calibration_dataset["peptidoform"].apply( - lambda x: x.proforma.split("\\")[0] - ) - calibration_dataset["charge"] = calibration_dataset["peptidoform"].apply( - lambda x: x.precursor_charge - ) - # reference_dataset['sequence'] = reference_dataset['peptidoform'].apply(lambda x: x.split('/')[0]) - reference_dataset["charge"] = reference_dataset["peptidoform"].apply( - lambda x: int(x.split("/")[1]) - ) - if per_charge: - LOGGER.info("Getting general shift factor") - general_shift = calculate_ccs_shift( - calibration_dataset, - reference_dataset, - per_charge=False, - use_charge_state=use_charge_state, + # Validate input dataframes + if preds_df.empty: + raise CalibrationError("Predictions dataframe is empty") + if "predicted_ccs" not in preds_df.columns: + raise CalibrationError("Predictions dataframe missing 'predicted_ccs' column") + + # Create working copy to avoid modifying original + preds_df = preds_df.copy() + calibration_dataset = calibration_dataset.copy() + reference_dataset = reference_dataset.copy() + + try: + # Extract sequence and charge from calibration peptidoforms + LOGGER.debug("Extracting sequence and charge from calibration peptidoforms...") + calibration_dataset["sequence"] = calibration_dataset["peptidoform"].apply( + lambda x: x.proforma.split("\\")[0] if hasattr(x, "proforma") else str(x).split("/")[0] + ) + calibration_dataset["charge"] = calibration_dataset["peptidoform"].apply( + lambda x: ( + x.precursor_charge if hasattr(x, "precursor_charge") else int(str(x).split("/")[1]) + ) + ) + + # Extract charge from reference peptidoforms + LOGGER.debug("Extracting charge from reference peptidoforms...") + reference_dataset["charge"] = reference_dataset["peptidoform"].apply( + lambda x: int(x.split("/")[1]) if isinstance(x, str) else x.precursor_charge ) - LOGGER.info("Getting shift factors per charge state") + + except (AttributeError, ValueError, IndexError) as e: + raise CalibrationError(f"Error parsing peptidoform data: {e}") + + if per_charge: + LOGGER.info("Calculating general shift factor for fallback...") + try: + general_shift = calculate_ccs_shift( + calibration_dataset, + reference_dataset, + per_charge=False, + use_charge_state=use_charge_state or 2, + ) + except CalibrationError as e: + LOGGER.warning( + f"Could not calculate general shift factor: {e}. Using 0.0 as fallback." + ) + general_shift = 0.0 + + LOGGER.info("Calculating shift factors per charge state...") shift_factor_dict = calculate_ccs_shift( calibration_dataset, reference_dataset, per_charge=True ) + # Add charge information to predictions if not present + if "charge" not in preds_df.columns: + preds_df["charge"] = preds_df["peptidoform"].apply( + lambda x: x.precursor_charge if hasattr(x, "precursor_charge") else 2 + ) + + # Apply charge-specific shifts with fallback to general shift preds_df["shift"] = preds_df["charge"].map(shift_factor_dict).fillna(general_shift) preds_df["predicted_ccs"] = preds_df["predicted_ccs"] + preds_df["shift"] + # Log calibration statistics + used_charges = set(shift_factor_dict.keys()) + fallback_charges = set(preds_df[preds_df["shift"] == general_shift]["charge"].unique()) + if fallback_charges: + LOGGER.info(f"Used charge-specific calibration for charges: {sorted(used_charges)}") + LOGGER.info(f"Used fallback calibration for charges: {sorted(fallback_charges)}") + else: + # Global calibration shift_factor = calculate_ccs_shift( calibration_dataset, reference_dataset, per_charge=False, - use_charge_state=use_charge_state, + use_charge_state=use_charge_state or 2, ) preds_df["predicted_ccs"] += shift_factor + preds_df["shift"] = shift_factor + LOGGER.info(f"Applied global shift factor: {shift_factor:.3f}") - LOGGER.info("CCS values calibrated.") + LOGGER.info("CCS values calibrated successfully.") return preds_df diff --git a/im2deep/im2deep.py b/im2deep/im2deep.py index 63b836a..0ed7520 100644 --- a/im2deep/im2deep.py +++ b/im2deep/im2deep.py @@ -1,5 +1,37 @@ +""" +Main CCS prediction module for IM2Deep. + +This module provides the core functionality for predicting Collisional Cross Section (CCS) +values for peptides using deep learning models. It supports both single-conformer and +multi-conformer predictions with optional calibration. + +The module handles: +- Loading and running neural network models for CCS prediction +- Calibrating predictions using reference datasets +- Converting between CCS and ion mobility +- Outputting results in various formats + +Functions: + predict_ccs: Main function for CCS prediction with optional calibration + +Dependencies: + - deeplc: For neural network model infrastructure + - psm_utils: For peptide data handling + - pandas/numpy: For data manipulation + +Example: + Basic CCS prediction: + >>> from im2deep.im2deep import predict_ccs + >>> predictions = predict_ccs(psm_list, calibration_data) + + Multi-conformer prediction: + >>> predictions = predict_ccs(psm_list, calibration_data, multi=True) +""" + import logging from pathlib import Path +import sys +from typing import Optional, Union, List import pandas as pd from deeplc import DeepLC @@ -7,108 +39,113 @@ from im2deep.calibrate import linear_calibration from im2deep.utils import ccs2im +from im2deep._exceptions import IM2DeepError LOGGER = logging.getLogger(__name__) REFERENCE_DATASET_PATH = Path(__file__).parent / "reference_data" / "reference_ccs.zip" -# TODO: get file reading out of the function -def predict_ccs( - psm_list_pred: PSMList, - psm_list_cal_df=None, - file_reference=REFERENCE_DATASET_PATH, - output_file=None, - model_name="tims", - multi=False, - calibrate_per_charge=True, - use_charge_state=2, - use_single_model=True, - n_jobs=None, - write_output=True, - ion_mobility=False, - pred_df=None, - cal_df=None, -): - """Run IM2Deep.""" - LOGGER.info("IM2Deep started.") - reference_dataset = pd.read_csv(file_reference) +def _validate_inputs(psm_list_pred: PSMList, output_file: Optional[str] = None) -> None: + """ + Validate input parameters for prediction. + + Parameters + ---------- + psm_list_pred : PSMList + PSM list for prediction + output_file : str, optional + Output file path + + Raises + ------ + IM2DeepError + If validation fails + """ + if not isinstance(psm_list_pred, PSMList): + raise IM2DeepError("psm_list_pred must be a PSMList instance") + + if len(psm_list_pred) == 0: + raise IM2DeepError("PSM list for prediction is empty") + + if output_file and not isinstance(output_file, (str, Path)): + raise IM2DeepError("output_file must be a string or Path object") + +def _get_model_paths(model_name: str, use_single_model: bool) -> List[Path]: + """ + Get model file paths based on model name and configuration. + + Parameters + ---------- + model_name : str + Name of the model ('tims') + use_single_model : bool + Whether to use single model or ensemble + + Returns + ------- + List[Path] + List of model file paths + + Raises + ------ + IM2DeepError + If model files not found + """ if model_name == "tims": path_model = Path(__file__).parent / "models" / "TIMS" + else: + raise IM2DeepError(f"Unsupported model name: {model_name}") + + if not path_model.exists(): + raise IM2DeepError(f"Model directory not found: {path_model}") path_model_list = list(path_model.glob("*.keras")) - if use_single_model: - LOGGER.debug("Using model {}".format(path_model_list[2])) - path_model_list = [path_model_list[2]] - - dlc = DeepLC(path_model=path_model_list, n_jobs=n_jobs, predict_ccs=True) - LOGGER.info("Predicting CCS values...") - preds = dlc.make_preds(psm_list=psm_list_pred, calibrate=False) - LOGGER.info("CCS values predicted.") - psm_list_pred_df = psm_list_pred.to_dataframe() - psm_list_pred_df["predicted_ccs"] = preds - psm_list_pred_df["charge"] = psm_list_pred_df["peptidoform"].apply( - lambda x: x.precursor_charge - ) - - if psm_list_cal_df is not None: - psm_list_pred_df = linear_calibration( - psm_list_pred_df, - calibration_dataset=psm_list_cal_df, - reference_dataset=reference_dataset, - per_charge=calibrate_per_charge, - use_charge_state=use_charge_state, - ) - if multi: - from im2deep.predict_multi import predict_multi - - LOGGER.info("Predicting multiconformer CCS values...") - pred_df = predict_multi( - psm_list_pred, - cal_df, - calibrate_per_charge, - use_charge_state, - ) + if not path_model_list: + raise IM2DeepError(f"No model files found in {path_model}") - if ion_mobility: - if not multi: - psm_list_pred_df["predicted_im"] = ccs2im( - psm_list_pred_df["predicted_ccs"], - psm_list_pred_df["peptidoform"].apply(lambda x: x.theoretical_mz), - psm_list_pred_df["charge"], - ) - if write_output: - LOGGER.info("Writing output file for ion mobility prediction...") - with open(output_file, "w") as f: - f.write("modified_seq,charge,predicted IM\n") - for peptidoform, charge, IM in zip( - psm_list_pred_df["peptidoform"], - psm_list_pred_df["charge"], - psm_list_pred_df["predicted_im"], - ): - f.write(f"{peptidoform},{charge},{IM}\n") - LOGGER.info("IM2Deep finished!") - return psm_list_pred_df["predicted_im"] + if use_single_model: + # Use the third model by default (index 2) for consistency + if len(path_model_list) > 2: + selected_model = path_model_list[2] + LOGGER.debug(f"Using single model: {selected_model}") + return [selected_model] else: - psm_list_pred_df["predicted_im"] = ccs2im( - psm_list_pred_df["predicted_ccs"], - psm_list_pred_df["peptidoform"].apply(lambda x: x.theoretical_mz), - psm_list_pred_df["charge"], - ) - psm_list_pred_df["predicted_im_multi_1"] = ccs2im( - pred_df["predicted_ccs_multi_1"], - psm_list_pred_df["peptidoform"].apply(lambda x: x.theoretical_mz), - psm_list_pred_df["charge"], - ) - psm_list_pred_df["predicted_im_multi_2"] = ccs2im( - pred_df["predicted_ccs_multi_2"], - psm_list_pred_df["peptidoform"].apply(lambda x: x.theoretical_mz), - psm_list_pred_df["charge"], - ) - if write_output: - LOGGER.info("Writing output file for multi-ion mobility prediction...") - with open(output_file, "w") as f: + LOGGER.warning("Less than 3 models available, using first model") + return [path_model_list[0]] + else: + LOGGER.debug(f"Using ensemble of {len(path_model_list)} models") + return path_model_list + + +def _write_output_file( + output_file: str, + psm_list_pred_df: pd.DataFrame, + pred_df: Optional[pd.DataFrame] = None, + ion_mobility: bool = False, + multi: bool = False, +) -> None: + """ + Write predictions to output file. + + Parameters + ---------- + output_file : str + Path to output file + psm_list_pred_df : pd.DataFrame + DataFrame with predictions + pred_df : pd.DataFrame, optional + Multi-conformer predictions + ion_mobility : bool + Whether to output ion mobility instead of CCS + multi : bool + Whether multi-conformer predictions are included + """ + try: + with open(output_file, "w", encoding="utf-8") as f: + if ion_mobility: + if multi: f.write( "modified_seq,charge,predicted IM single,predicted IM multi 1,predicted IM multi 2\n" ) @@ -120,26 +157,16 @@ def predict_ccs( psm_list_pred_df["predicted_im_multi_2"], ): f.write(f"{peptidoform},{charge},{IM_single},{IM_multi_1},{IM_multi_2}\n") - LOGGER.info("IM2Deep finished!") - return psm_list_pred_df["predicted_im"] - else: - if not multi: - if write_output: - LOGGER.info("Writing output file for CCS prediction...") - with open(output_file, "w") as f: - f.write("modified_seq,charge,predicted CCS\n") - for peptidoform, charge, CCS in zip( + else: + f.write("modified_seq,charge,predicted IM\n") + for peptidoform, charge, IM in zip( psm_list_pred_df["peptidoform"], psm_list_pred_df["charge"], - psm_list_pred_df["predicted_ccs"], + psm_list_pred_df["predicted_im"], ): - f.write(f"{peptidoform},{charge},{CCS}\n") - LOGGER.info("IM2Deep finished!") - return psm_list_pred_df["predicted_ccs"] - else: - if write_output: - LOGGER.info("Writing output file for multi-CCS prediction...") - with open(output_file, "w") as f: + f.write(f"{peptidoform},{charge},{IM}\n") + else: + if multi: f.write( "modified_seq,charge,predicted CCS single,predicted CCS multi 1,predicted CCS multi 2\n" ) @@ -153,5 +180,264 @@ def predict_ccs( f.write( f"{peptidoform},{charge},{CCS_single},{CCS_multi_1},{CCS_multi_2}\n" ) - LOGGER.info("IM2Deep finished!") - return psm_list_pred_df["predicted_ccs"] + else: + f.write("modified_seq,charge,predicted CCS\n") + for peptidoform, charge, CCS in zip( + psm_list_pred_df["peptidoform"], + psm_list_pred_df["charge"], + psm_list_pred_df["predicted_ccs"], + ): + f.write(f"{peptidoform},{charge},{CCS}\n") + + LOGGER.info(f"Results written to: {output_file}") + + except IOError as e: + raise IM2DeepError(f"Failed to write output file {output_file}: {e}") + + +def predict_ccs( + psm_list_pred: PSMList, + psm_list_cal: Optional[PSMList] = None, + file_reference: Optional[Union[str, Path]] = None, + output_file: Optional[Union[str, Path]] = None, + model_name: str = "tims", + multi: bool = False, + calibrate_per_charge: bool = True, + use_charge_state: int = 2, + use_single_model: bool = True, + n_jobs: Optional[int] = None, + write_output: bool = False, + ion_mobility: bool = False, + pred_df: Optional[pd.DataFrame] = None, + cal_df: Optional[pd.DataFrame] = None, +) -> Union[pd.Series, pd.DataFrame]: + """ + Predict CCS values for peptides using IM2Deep models. + + This is the main function for CCS prediction. It can perform single-conformer + or multi-conformer predictions with optional calibration using reference datasets. + + Parameters + ---------- + psm_list_pred : PSMList + PSM list containing peptides for CCS prediction. Each PSM should contain + a valid peptidoform with sequence and modifications. + psm_list_cal : PSMList, optional + PSM list for calibration with observed CCS values in metadata. + Required for calibration. Default is None (no calibration). + file_reference : str or Path, optional + Path to reference dataset file for calibration. Default uses built-in + reference dataset. + output_file : str or Path, optional + Path to write output predictions. If None, no file is written. + model_name : str, default "tims" + Name of the model to use. Currently only "tims" is supported. + multi : bool, default False + Whether to include multi-conformer predictions. Requires optional + dependencies (torch, im2deeptrainer). + calibrate_per_charge : bool, default True + Whether to perform calibration per charge state. If False, uses + global calibration with specified charge state. + use_charge_state : int, default 2 + Charge state to use for global calibration when calibrate_per_charge=False. + Should be in range [2,4] for best results. + use_single_model : bool, default True + Whether to use a single model (faster) or ensemble of models (potentially + more accurate). Single model recommended for most applications. + n_jobs : int, optional + Number of parallel jobs for model prediction. If None, uses all available CPUs. + write_output : bool, default False + Whether to write predictions to output file. + ion_mobility : bool, default False + Whether to output ion mobility (1/K0) instead of CCS values. + pred_df : pd.DataFrame, optional + Pre-computed prediction DataFrame (used internally). + cal_df : pd.DataFrame, optional + Pre-computed calibration DataFrame (used internally). + + Returns + ------- + pd.Series or pd.DataFrame + If ion_mobility=True: Series with predicted ion mobility values + If ion_mobility=False: Series with predicted CCS values + For multi-conformer predictions, additional columns are included. + + Raises + ------ + IM2DeepError + If prediction fails due to invalid inputs, missing models, or other errors. + + Notes + ----- + The prediction workflow: + 1. Validate inputs and load appropriate models + 2. Generate CCS predictions using neural networks + 3. Apply calibration if calibration data provided + 4. Optionally run multi-conformer predictions + 5. Convert to ion mobility if requested + 6. Write output file if requested + + Calibration is highly recommended for accurate predictions and requires + a set of peptides with known CCS values that overlap with the reference dataset. + + Examples + -------- + Basic CCS prediction without calibration: + >>> predictions = predict_ccs(psm_list) + + CCS prediction with calibration: + >>> predictions = predict_ccs(psm_list, psm_list_calibration) + + Multi-conformer prediction with ion mobility output: + >>> predictions = predict_ccs( + ... psm_list, + ... psm_list_calibration, + ... multi=True, + ... ion_mobility=True + ... ) + + Ensemble prediction with file output: + >>> predictions = predict_ccs( + ... psm_list, + ... psm_list_calibration, + ... use_single_model=False, + ... output_file="predictions.csv", + ... write_output=True + ... ) + """ + LOGGER.info("IM2Deep started.") + + # Validate inputs + _validate_inputs(psm_list_pred, output_file) + + # Load reference dataset + if file_reference is None: + file_reference = REFERENCE_DATASET_PATH + + try: + reference_dataset = pd.read_csv(file_reference) + LOGGER.debug(f"Loaded reference dataset with {len(reference_dataset)} entries") + except Exception as e: + raise IM2DeepError(f"Failed to load reference dataset from {file_reference}: {e}") + + if reference_dataset.empty: + raise IM2DeepError("Reference dataset is empty") + + # Get model paths + try: + path_model_list = _get_model_paths(model_name, use_single_model) + except Exception as e: + raise IM2DeepError(f"Failed to load models: {e}") + + # Initialize DeepLC for CCS prediction + try: + dlc = DeepLC(path_model=path_model_list, n_jobs=n_jobs, predict_ccs=True) + LOGGER.info("Predicting CCS values...") + preds = dlc.make_preds(psm_list=psm_list_pred, calibrate=False) + LOGGER.info(f"CCS values predicted for {len(preds)} peptides.") + except Exception as e: + raise IM2DeepError(f"CCS prediction failed: {e}") + + if len(preds) == 0: + raise IM2DeepError("No predictions generated") + + # Convert PSM list to DataFrame and add predictions + try: + psm_list_pred_df = psm_list_pred.to_dataframe() + psm_list_pred_df["predicted_ccs"] = preds + psm_list_pred_df["charge"] = psm_list_pred_df["peptidoform"].apply( + lambda x: x.precursor_charge + ) + except Exception as e: + raise IM2DeepError(f"Failed to process predictions: {e}") + + # Apply calibration if calibration data provided + pred_df = None + if psm_list_cal is not None: + try: + LOGGER.info("Applying calibration...") + psm_list_cal_df = psm_list_cal.to_dataframe() + psm_list_cal_df["ccs_observed"] = psm_list_cal_df["metadata"].apply( + lambda x: float(x.get("CCS")) if x and "CCS" in x else None + ) + + # Filter out entries without CCS values + psm_list_cal_df = psm_list_cal_df[psm_list_cal_df["ccs_observed"].notnull()] + + if psm_list_cal_df.empty: + LOGGER.warning("No valid calibration data found (missing CCS values)") + else: + psm_list_pred_df = linear_calibration( + psm_list_pred_df, + calibration_dataset=psm_list_cal_df, + reference_dataset=reference_dataset, + per_charge=calibrate_per_charge, + use_charge_state=use_charge_state, + ) + LOGGER.info("Calibration applied successfully.") + + except Exception as e: + LOGGER.error(f"Calibration failed: {e}") + # Continue without calibration rather than failing completely + LOGGER.warning("Continuing without calibration") + + # Multi-conformer prediction + if multi: + try: + from im2deep.predict_multi import predict_multi + + LOGGER.info("Predicting multiconformer CCS values...") + pred_df = predict_multi( + psm_list_pred, + cal_df, + calibrate_per_charge, + use_charge_state, + ) + LOGGER.info("Multiconformational predictions completed.") + except ImportError: + raise IM2DeepError( + "Multi-conformer prediction requires optional dependencies. " + "Please install with: pip install 'im2deep[er]'" + ) + except Exception as e: + raise IM2DeepError(f"Multi-conformer prediction failed: {e}") + + # Convert to ion mobility if requested + if ion_mobility: + try: + psm_list_pred_df["predicted_im"] = ccs2im( + psm_list_pred_df["predicted_ccs"], + psm_list_pred_df["peptidoform"].apply(lambda x: x.theoretical_mz), + psm_list_pred_df["charge"], + ) + + if multi and pred_df is not None: + psm_list_pred_df["predicted_im_multi_1"] = ccs2im( + pred_df["predicted_ccs_multi_1"], + psm_list_pred_df["peptidoform"].apply(lambda x: x.theoretical_mz), + psm_list_pred_df["charge"], + ) + psm_list_pred_df["predicted_im_multi_2"] = ccs2im( + pred_df["predicted_ccs_multi_2"], + psm_list_pred_df["peptidoform"].apply(lambda x: x.theoretical_mz), + psm_list_pred_df["charge"], + ) + + except Exception as e: + raise IM2DeepError(f"Ion mobility conversion failed: {e}") + + # Write output file if requested + if write_output and output_file: + try: + _write_output_file(output_file, psm_list_pred_df, pred_df, ion_mobility, multi) + except Exception as e: + LOGGER.error(f"Failed to write output: {e}") + # Don't fail the entire prediction because of output issues + + LOGGER.info("IM2Deep finished!") + + # Return appropriate predictions + if ion_mobility: + return psm_list_pred_df["predicted_im"] + else: + return psm_list_pred_df["predicted_ccs"] diff --git a/im2deep/models/TIMS_multi/multi_output.ckpt b/im2deep/models/TIMS_multi/multi_output.ckpt index b14673e..e1c7068 100644 Binary files a/im2deep/models/TIMS_multi/multi_output.ckpt and b/im2deep/models/TIMS_multi/multi_output.ckpt differ diff --git a/im2deep/models/TIMS_multi/multi_output_pre_revision.ckpt b/im2deep/models/TIMS_multi/multi_output_pre_revision.ckpt new file mode 100644 index 0000000..b14673e Binary files /dev/null and b/im2deep/models/TIMS_multi/multi_output_pre_revision.ckpt differ diff --git a/im2deep/predict_multi.py b/im2deep/predict_multi.py index 1f63ab1..2b8fdfe 100644 --- a/im2deep/predict_multi.py +++ b/im2deep/predict_multi.py @@ -1,44 +1,163 @@ +""" +Multi-conformer CCS prediction module for IM2Deep. + +This module provides functionality for predicting CCS values for peptides that +can exist in multiple conformational states. It uses specialized neural network +models trained to predict multiple CCS values per peptide. + +The multi-conformer prediction pipeline: +1. Extract molecular features from peptide sequences +2. Run multi-output neural network models +3. Apply calibration using multi-conformer reference data +4. Return multiple CCS predictions per peptide + +Functions: + get_ccs_shift_multi: Calculate CCS shift for multi-conformer predictions + get_ccs_shift_per_charge_multi: Calculate per-charge shifts for multi predictions + calculate_ccs_shift_multi: Main shift calculation with validation + linear_calibration_multi: Apply calibration to multi-conformer predictions + predict_multi: Main function for multi-conformer CCS prediction + +Dependencies: + - torch: For neural network inference + - im2deeptrainer: For handling specialized im2deep models + - pandas/numpy: For data manipulation + +Note: + This module requires optional dependencies that can be installed with: + pip install 'im2deep[er]' +""" + from pathlib import Path import logging +from typing import Dict, Optional, Union + +try: + import torch + import pandas as pd + import numpy as np + + from im2deeptrainer.model import IM2DeepMultiTransfer + from im2deeptrainer.extract_data import _get_matrices + from im2deeptrainer.utils import FlexibleLossSorted -import torch -import pandas as pd -import numpy as np + TORCH_AVAILABLE = True +except ImportError: + # Optional dependencies not available + torch = None + IM2DeepMultiTransfer = None + _get_matrices = None + FlexibleLossSorted = None + TORCH_AVAILABLE = False -from im2deeptrainer.model import IM2DeepMultiTransfer -from im2deeptrainer.extract_data import _get_matrices -from im2deeptrainer.utils import FlexibleLossSorted + import pandas as pd + import numpy as np from im2deep.utils import multi_config +from im2deep._exceptions import IM2DeepError, CalibrationError LOGGER = logging.getLogger(__name__) MULTI_CKPT_PATH = Path(__file__).parent / "models" / "TIMS_multi" / "multi_output.ckpt" REFERENCE_DATASET_PATH = Path(__file__).parent / "reference_data" / "multi_reference_ccs.gz" +def _validate_multi_inputs(df_cal: pd.DataFrame, reference_dataset: pd.DataFrame) -> None: + """ + Validate inputs for multi-conformer calibration. + + Parameters + ---------- + df_cal : pd.DataFrame + Calibration dataset + reference_dataset : pd.DataFrame + Reference dataset + + Raises + ------ + CalibrationError + If validation fails + """ + required_cal_cols = ["seq", "modifications", "charge", "CCS"] + required_ref_cols = ["seq", "modifications", "charge", "ccs_observed"] + + if df_cal.empty: + raise CalibrationError("Calibration dataset is empty") + if reference_dataset.empty: + raise CalibrationError("Reference dataset is empty") + + missing_cal = set(required_cal_cols) - set(df_cal.columns) + if missing_cal: + raise CalibrationError(f"Missing columns in calibration data: {missing_cal}") + + missing_ref = set(required_ref_cols) - set(reference_dataset.columns) + if missing_ref: + raise CalibrationError(f"Missing columns in reference data: {missing_ref}") + + def get_ccs_shift_multi( df_cal: pd.DataFrame, reference_dataset: pd.DataFrame, use_charge_state: int = 2 ) -> float: - """Calculate CCS shift factor for multi predictions. + """ + Calculate CCS shift factor for multi-conformer predictions. + + This function calculates a shift factor specifically for multi-conformer + predictions by comparing calibration data with reference data for a + specific charge state. Parameters ---------- - df_cal - Peptides with CCS values. - reference_dataset - Reference dataset with CCS values. - use_charge_state - Charge state to use for CCS shift calculation, by default 2. - - Returns CCS shift factor + df_cal : pd.DataFrame + Calibration peptides with observed CCS values. Must contain columns: + 'seq', 'modifications', 'charge', 'ccs_observed' + reference_dataset : pd.DataFrame + Reference dataset with known CCS values. Must contain columns: + 'seq', 'modifications', 'charge', 'CCS' + use_charge_state : int, default 2 + Charge state to use for CCS shift calculation. Recommended range [2,4]. + + Returns + ------- + float + CCS shift factor for multi-conformer predictions. Positive values + indicate calibration CCS is higher than reference on average. + + Raises + ------ + CalibrationError + If inputs are invalid or no overlapping data found + + Notes + ----- + Multi-conformer shift calculation differs from single-conformer by: + - Using sequence + modifications for matching instead of peptidoform + - Typically having fewer overlapping peptides due to stricter matching + - Requiring specific reference data trained for multi-conformer models + + Examples + -------- + >>> shift = get_ccs_shift_multi(cal_df, ref_df, use_charge_state=2) + >>> print(f"Multi-conformer shift: {shift:.2f} Ų") """ + _validate_multi_inputs(df_cal, reference_dataset) + + if not use_charge_state <= 6: + raise CalibrationError(f"Invalid charge state {use_charge_state}") + LOGGER.debug( f"Using charge state {use_charge_state} for calibration of multiconformer predictions." ) + # Filter by charge state reference_tmp = reference_dataset[reference_dataset["charge"] == use_charge_state] df_tmp = df_cal[df_cal["charge"] == use_charge_state] + if reference_tmp.empty or df_tmp.empty: + LOGGER.warning( + f"No data found for charge state {use_charge_state} in multi-conformer calibration" + ) + return 0.0 + + # Merge on sequence and modifications for multi-conformer matching both = pd.merge( left=reference_tmp, right=df_tmp, @@ -47,30 +166,78 @@ def get_ccs_shift_multi( suffixes=("_ref", "_data"), ) + LOGGER.debug(f"Head of overlapping peptides:\n{both.head()}") + LOGGER.debug( - f"Calculating CCS shift based on {both.shape[0]} overlapping peptide-charge pairs between PSMs and reference dataset." + "" + f"Calculating CCS shift based on {both.shape[0]} overlapping peptide-charge pairs " + f"between PSMs and reference dataset." ) - return 0 if both.empty else np.mean(both["ccs_observed"] - both["CCS"]) + if both.empty: + LOGGER.warning("No overlapping peptides found for multi-conformer calibration") + return 0.0 + + if both.shape[0] < 10: + LOGGER.warning( + f"Only {both.shape[0]} overlapping peptides found for multi-conformer calibration. " + "Results may be unreliable." + ) + + # Calculate mean shift + shift = np.mean(both["ccs_observed"] - both["CCS"]) + + if abs(shift) > 50: + LOGGER.warning(f"Large multi-conformer CCS shift detected ({shift:.2f} Å^2)") + + return float(shift) def get_ccs_shift_per_charge_multi( df_cal: pd.DataFrame, reference_dataset: pd.DataFrame -) -> np.ndarray: - """Calculate CCS shift factor per charge state for multiconformational predictions +) -> Dict[int, float]: + """ + Calculate CCS shift factors per charge state for multi-conformer predictions. + + This function calculates charge-specific shift factors for multi-conformer + predictions, allowing for more accurate calibration across different + charge states. Parameters ---------- - df_cal - Peptides with CCS values. - reference_dataset - Reference dataset with CCS values. + df_cal : pd.DataFrame + Calibration peptides with observed CCS values. Must contain columns: + 'seq', 'modifications', 'charge', 'ccs_observed' + reference_dataset : pd.DataFrame + Reference dataset with known CCS values. Must contain columns: + 'seq', 'modifications', 'charge', 'CCS' Returns ------- - np.ndarray - CCS shift factor per charge state. + Dict[int, float] + Dictionary mapping charge states to their shift factors. + + Raises + ------ + CalibrationError + If inputs are invalid or no overlapping data found + + Notes + ----- + Multi-conformer per-charge calibration: + - Matches peptides exactly on sequence, modifications, and charge + - Typically yields fewer matches than single-conformer calibration + - Provides charge-specific corrections for systematic biases + + Examples + -------- + >>> shifts = get_ccs_shift_per_charge_multi(cal_df, ref_df) + >>> print("Multi-conformer shifts:", shifts) + {2: 4.1, 3: 2.8, 4: 1.5} """ + _validate_multi_inputs(df_cal, reference_dataset) + + # Merge datasets for exact matching both = pd.merge( left=reference_dataset, right=df_cal, @@ -78,44 +245,120 @@ def get_ccs_shift_per_charge_multi( how="inner", suffixes=("_ref", "_data"), ) - return both.groupby("charge").apply(lambda x: np.mean(x["ccs_observed"] - x["CCS"])).to_dict() + + if both.empty: + raise CalibrationError( + "No overlapping peptides found for multi-conformer per-charge calibration" + ) + + LOGGER.debug( + f"Found {both.shape[0]} overlapping peptides for multi-conformer per-charge calibration" + ) + + # Check data distribution + charge_counts = both.groupby("charge").size() + LOGGER.debug(f"Multi-conformer peptides per charge: {charge_counts.to_dict()}") + + # Warn about insufficient data + low_data_charges = charge_counts[charge_counts < 5].index.tolist() + if low_data_charges: + LOGGER.warning( + f"Charge states with <5 peptides in multi-conformer calibration: {low_data_charges}" + ) + + # Calculate shifts per charge + shift_dict = ( + both.groupby("charge").apply(lambda x: np.mean(x["ccs_observed"] - x["CCS"])).to_dict() + ) + + # Convert to native Python types + shift_dict = {int(k): float(v) for k, v in shift_dict.items()} + + return shift_dict def calculate_ccs_shift_multi( df_cal: pd.DataFrame, reference_dataset: pd.DataFrame, per_charge: bool = True, - use_charge_state: int = None, -) -> float: + use_charge_state: Optional[int] = None, +) -> Union[float, Dict[int, float]]: """ - Apply CCS shift to CCS values for multiconformational predictions. + Calculate CCS shift factors for multi-conformer predictions with validation. + + This is the main interface for calculating shift factors for multi-conformer + predictions. It provides input validation, charge filtering, and supports + both global and per-charge calibration modes. Parameters ---------- - df_cal - Peptides with CCS values. - reference_dataset - Reference dataset with CCS values. - per_charge - Apply CCS shift per charge state, by default True. - use_charge_state - Charge state to use for CCS shift calculation, needs to be [2,4], by default None. + df_cal : pd.DataFrame + Calibration peptides with observed CCS values. + reference_dataset : pd.DataFrame + Reference dataset with known CCS values. + per_charge : bool, default True + Whether to calculate shift factors per charge state. + use_charge_state : int, optional + Charge state for global calibration when per_charge=False. + Default is 2 if not specified. Returns ------- - float - CCS shift factor. + Union[float, Dict[int, float]] + If per_charge=True: Dictionary of shift factors per charge + If per_charge=False: Single global shift factor + + Raises + ------ + CalibrationError + If validation fails or invalid parameters + + Notes + ----- + Multi-conformer models are typically trained for charges 2-4, so higher + charges are filtered out automatically. The function logs filtering actions + for transparency. + + Examples + -------- + >>> # Per-charge calibration (recommended) + >>> shifts = calculate_ccs_shift_multi(cal_df, ref_df, per_charge=True) + >>> + >>> # Global calibration + >>> shift = calculate_ccs_shift_multi(cal_df, ref_df, per_charge=False, use_charge_state=2) """ - df_cal = df_cal[(df_cal["charge"] < 5) & (df_cal["charge"] > 1)] + _validate_multi_inputs(df_cal, reference_dataset) + + if use_charge_state is not None and not use_charge_state <= 6: + raise CalibrationError(f"Invalid charge state {use_charge_state}") + + # Filter charge states (multi-conformer models typically work best for 2-4) + original_size = len(df_cal) + df_cal = df_cal[(df_cal["charge"] < 5)].copy() + + if len(df_cal) < original_size: + filtered_count = original_size - len(df_cal) + LOGGER.info( + f"Filtered {filtered_count} peptides outside charge range 2-4 " + "for multi-conformer calibration" + ) + + if df_cal.empty: + raise CalibrationError( + "No valid calibration data for multi-conformer prediction after filtering" + ) if not per_charge: + if use_charge_state is None: + use_charge_state = 2 + LOGGER.debug("Using charge 2 for global multi-conformer calibration") + shift_factor = get_ccs_shift_multi(df_cal, reference_dataset, use_charge_state) - LOGGER.debug(f"CCS shift factor: {shift_factor}") + LOGGER.debug(f"Multi-conformer global shift factor: {shift_factor:.3f}") return shift_factor - else: shift_factor_dict = get_ccs_shift_per_charge_multi(df_cal, reference_dataset) - LOGGER.debug(f"CCS shift factors: {shift_factor_dict}") + LOGGER.debug(f"Multi-conformer shift factors: {shift_factor_dict}") return shift_factor_dict @@ -124,99 +367,234 @@ def linear_calibration_multi( df_cal: pd.DataFrame, reference_dataset: pd.DataFrame, per_charge: bool = True, - use_charge_state: int = None, + use_charge_state: Optional[int] = None, ) -> pd.DataFrame: """ - Calibrate multiconformer predictions using linear calibration. + Calibrate multi-conformer CCS predictions using linear calibration. + + This function applies linear calibration specifically designed for + multi-conformer CCS predictions. It calculates and applies shift factors + to both conformer predictions. Parameters ---------- - df_pred - Peptides with CCS predictions. - df_cal - Peptides with CCS values. - reference_dataset - Reference dataset with CCS values. - per_charge - Apply calibration per charge state, by default True. - use_charge_state - Charge state to use for calibration, needs to be [2,4], by default None. + df_pred : pd.DataFrame + DataFrame with multi-conformer CCS predictions. Must contain columns: + 'predicted_ccs_multi_1', 'predicted_ccs_multi_2', 'peptidoform' + df_cal : pd.DataFrame + Calibration dataset with observed CCS values. + reference_dataset : pd.DataFrame + Reference dataset for multi-conformer calibration. + per_charge : bool, default True + Whether to apply calibration per charge state. + use_charge_state : int, optional + Charge state for global calibration when per_charge=False. Returns ------- pd.DataFrame - Calibrated PSMs. + DataFrame with calibrated multi-conformer predictions. + + Raises + ------ + CalibrationError + If calibration fails + + Notes + ----- + Multi-conformer calibration: + - Applies the same shift to both conformer predictions + - Uses specialized reference data for multi-conformer models + - Supports both global and per-charge calibration strategies + + The calibration preserves the relative differences between conformers + while correcting systematic biases. + + Examples + -------- + >>> calibrated_df = linear_calibration_multi( + ... pred_df, cal_df, ref_df, per_charge=True + ... ) """ LOGGER.info("Calibrating multiconformer predictions using linear calibration...") - if per_charge: - LOGGER.info("Generating general shift factor for multiconformer predictions...") - general_shift = calculate_ccs_shift_multi( - df_cal, reference_dataset, per_charge=False, use_charge_state=use_charge_state - ) - LOGGER.info("Getting shift factors per charge state...") - df_pred["charge"] = df_pred["peptidoform"].apply(lambda x: x.precursor_charge) - shift_factor_dict = calculate_ccs_shift_multi(df_cal, reference_dataset, per_charge=True) - df_pred["shift_multi"] = df_pred["charge"].map(shift_factor_dict).fillna(general_shift) - df_pred["predicted_ccs_multi_1"] = ( - df_pred["predicted_ccs_multi_1"] + df_pred["shift_multi"] - ) - df_pred["predicted_ccs_multi_2"] = ( - df_pred["predicted_ccs_multi_2"] + df_pred["shift_multi"] - ) + if df_pred.empty: + raise CalibrationError("Predictions dataframe is empty") + + required_cols = ["predicted_ccs_multi_1", "predicted_ccs_multi_2", "peptidoform"] + missing_cols = set(required_cols) - set(df_pred.columns) + if missing_cols: + raise CalibrationError(f"Missing columns in predictions: {missing_cols}") + + # Create working copy + df_pred = df_pred.copy() + + try: + if per_charge: + LOGGER.info("Generating general shift factor for multiconformer predictions...") + general_shift = calculate_ccs_shift_multi( + df_cal, reference_dataset, per_charge=False, use_charge_state=use_charge_state or 2 + ) + + LOGGER.info("Getting shift factors per charge state for multiconformer...") + df_pred["charge"] = df_pred["peptidoform"].apply(lambda x: x.precursor_charge) + shift_factor_dict = calculate_ccs_shift_multi( + df_cal, reference_dataset, per_charge=True + ) + + # Apply charge-specific shifts with fallback + df_pred["shift_multi"] = df_pred["charge"].map(shift_factor_dict).fillna(general_shift) + df_pred["predicted_ccs_multi_1"] = ( + df_pred["predicted_ccs_multi_1"] + df_pred["shift_multi"] + ) + df_pred["predicted_ccs_multi_2"] = ( + df_pred["predicted_ccs_multi_2"] + df_pred["shift_multi"] + ) + + else: + shift_factor = calculate_ccs_shift_multi( + df_cal, reference_dataset, per_charge=False, use_charge_state=use_charge_state or 2 + ) + df_pred["predicted_ccs_multi_1"] = df_pred["predicted_ccs_multi_1"] + shift_factor + df_pred["predicted_ccs_multi_2"] = df_pred["predicted_ccs_multi_2"] + shift_factor + df_pred["shift_multi"] = shift_factor + + LOGGER.info("Multiconformer predictions calibrated successfully.") + return df_pred + + except Exception as e: + raise CalibrationError(f"Multi-conformer calibration failed: {e}") + + +def predict_multi( + df_pred_psm_list, + df_cal: Optional[pd.DataFrame], + calibrate_per_charge: bool, + use_charge_state: int, +) -> pd.DataFrame: + """ + Generate multi-conformer CCS predictions for peptides. - else: - shift_factor = calculate_ccs_shift_multi( - df_cal, reference_dataset, per_charge=False, use_charge_state=use_charge_state + This is the main function for multi-conformer CCS prediction. It loads + the specialized multi-output neural network model and generates predictions + for multiple conformational states of each peptide. + + Parameters + ---------- + df_pred_psm_list : PSMList + PSM list containing peptides for prediction. + df_cal : pd.DataFrame, optional + Calibration dataset. If provided, predictions will be calibrated. + calibrate_per_charge : bool + Whether to perform per-charge calibration. + use_charge_state : int + Charge state for global calibration. + + Returns + ------- + pd.DataFrame + DataFrame with columns 'predicted_ccs_multi_1' and 'predicted_ccs_multi_2' + containing CCS predictions for two conformational states. + + Raises + ------ + IM2DeepError + If multi-conformer prediction fails + + Notes + ----- + Multi-conformer prediction workflow: + 1. Extract molecular features using im2deeptrainer + 2. Load pre-trained multi-output model + 3. Generate predictions for two conformational states + 4. Apply calibration if calibration data provided + 5. Return predictions as DataFrame + + The model predicts two CCS values per peptide, representing the most + probable conformational states based on the training data. + + Examples + -------- + >>> multi_preds = predict_multi(psm_list, cal_df, True, 2) + >>> print(multi_preds.columns) + ['predicted_ccs_multi_1', 'predicted_ccs_multi_2'] + """ + # Check if optional dependencies are available + if not TORCH_AVAILABLE: + raise IM2DeepError( + "Multi-conformer prediction requires optional dependencies. " + "Please install with: pip install 'im2deep[er]'" ) - df_pred["predicted_ccs_multi_1"] = df_pred["predicted_ccs_multi_1"] + shift_factor - df_pred["predicted_ccs_multi_2"] = df_pred["predicted_ccs_multi_2"] + shift_factor - LOGGER.info("Multiconformer predictions calibrated.") - return df_pred + try: + # Initialize model components + criterion = FlexibleLossSorted() + # Check if model file exists + if not MULTI_CKPT_PATH.exists(): + raise IM2DeepError(f"Multi-conformer model not found: {MULTI_CKPT_PATH}") -def predict_multi(df_pred_psm_list, df_cal, calibrate_per_charge, use_charge_state): - criterion = FlexibleLossSorted() - model = IM2DeepMultiTransfer.load_from_checkpoint( - MULTI_CKPT_PATH, config=multi_config, criterion=criterion - ) + model = IM2DeepMultiTransfer.load_from_checkpoint( + MULTI_CKPT_PATH, config=multi_config, criterion=criterion + ) - # df_pred["tr"] = 0 # Placeholder for DeepLC compatibility - # print(df_pred) - matrices = _get_matrices(df_pred_psm_list, inference=True) + LOGGER.debug("Multi-conformer model loaded successfully") - tensors = {} - for key in matrices: - tensors[key] = torch.tensor(matrices[key]).type(torch.FloatTensor) + # Extract molecular features + LOGGER.debug("Extracting molecular features for multi-conformer prediction...") + matrices = _get_matrices(df_pred_psm_list, inference=True) - dataset = torch.utils.data.TensorDataset(*[tensors[key] for key in tensors]) - dataloader = torch.utils.data.DataLoader( - dataset, batch_size=multi_config["batch_size"], shuffle=False - ) + # Convert to tensors + tensors = {} + for key in matrices: + tensors[key] = torch.tensor(matrices[key]).type(torch.FloatTensor) - model.eval() - with torch.no_grad(): - preds = [] - for index, batch in enumerate(dataloader): - prediction = model.predict_step(batch, inference=True) - preds.append(prediction) - predictions = torch.cat(preds).numpy() - - df_pred = df_pred_psm_list.to_dataframe() - - df_pred["predicted_ccs_multi_1"] = predictions[:, 0] - df_pred["predicted_ccs_multi_2"] = predictions[:, 1] - - if df_cal is not None: - df_pred = linear_calibration_multi( - df_pred, - df_cal, - reference_dataset=pd.read_csv( - REFERENCE_DATASET_PATH, compression="gzip", keep_default_na=False - ), - per_charge=calibrate_per_charge, - use_charge_state=use_charge_state, + # Create data loader + dataset = torch.utils.data.TensorDataset(*[tensors[key] for key in tensors]) + dataloader = torch.utils.data.DataLoader( + dataset, batch_size=multi_config["batch_size"], shuffle=False ) - return df_pred[["predicted_ccs_multi_1", "predicted_ccs_multi_2"]] + # Generate predictions + model.eval() + with torch.no_grad(): + preds = [] + for batch in dataloader: + prediction = model.predict_step(batch, inference=True) + preds.append(prediction) + predictions = torch.cat(preds).numpy() + + LOGGER.debug(f"Generated multi-conformer predictions for {len(predictions)} peptides") + + # Convert PSM list to DataFrame and add predictions + df_pred = df_pred_psm_list.to_dataframe() + + if len(predictions) != len(df_pred): + raise IM2DeepError(f"Prediction count mismatch: {len(predictions)} vs {len(df_pred)}") + + df_pred["predicted_ccs_multi_1"] = predictions[:, 0] + df_pred["predicted_ccs_multi_2"] = predictions[:, 1] + + # Apply calibration if calibration data provided + if df_cal is not None: + try: + LOGGER.debug("Loading multi-conformer reference dataset...") + reference_dataset = pd.read_csv( + REFERENCE_DATASET_PATH, compression="gzip", keep_default_na=False + ) + + df_pred = linear_calibration_multi( + df_pred, + df_cal, + reference_dataset=reference_dataset, + per_charge=calibrate_per_charge, + use_charge_state=use_charge_state, + ) + except Exception as e: + LOGGER.warning(f"Multi-conformer calibration failed: {e}") + LOGGER.warning("Returning uncalibrated multi-conformer predictions") + + return df_pred[["predicted_ccs_multi_1", "predicted_ccs_multi_2"]] + + except Exception as e: + raise IM2DeepError(f"Multi-conformer prediction failed: {e}") diff --git a/im2deep/utils.py b/im2deep/utils.py index 751dfaa..c5e937b 100644 --- a/im2deep/utils.py +++ b/im2deep/utils.py @@ -1,4 +1,20 @@ +""" +Utility functions for IM2Deep package. + +This module provides utility functions for converting between different +mobility measurements and configuration settings for multi-conformer models. + +Functions: + im2ccs: Convert ion mobility to collisional cross section + ccs2im: Convert collisional cross section to ion mobility + +Constants: + multi_config: Configuration dictionary for multi-conformer model + MULTI_BACKBONE_PATH: Path to the multi-conformer model backbone +""" + from pathlib import Path +from typing import Union, Any, Dict import numpy as np MULTI_BACKBONE_PATH = ( @@ -6,67 +22,166 @@ ) -def im2ccs(reverse_im, mz, charge, mass_gas=28.013, temp=31.85, t_diff=273.15): +def im2ccs( + reverse_im: Union[float, np.ndarray], + mz: Union[float, np.ndarray], + charge: Union[int, np.ndarray], + mass_gas: float = 28.013, + temp: float = 31.85, + t_diff: float = 273.15, +) -> Union[float, np.ndarray]: """ - Convert ion mobility to collisional cross section. + Convert reduced ion mobility to collisional cross section. + + This function converts reduced ion mobility (1/K0) values to collisional + cross section (CCS) using the Mason-Schamp equation. The conversion is + temperature and gas-dependent. Parameters ---------- - reverse_im - Reduced ion mobility. - mz - Precursor m/z. - charge - Precursor charge. - mass_gas - Mass of gas, default 28.013 - temp - Temperature in Celsius, default 31.85 - t_diff - Factor to convert Celsius to Kelvin, default 273.15 + reverse_im : float or array-like + Reduced ion mobility (1/K0) in V⋅s/cm². + mz : float or array-like + Precursor m/z ratio. + charge : int or array-like + Precursor charge state. + mass_gas : float, optional + Mass of drift gas in atomic mass units. Default is 28.013 (N₂). + temp : float, optional + Temperature in Celsius. Default is 31.85°C + t_diff : float, optional + Temperature conversion factor (°C to K). Default is 273.15. + + Returns + ------- + float or np.ndarray + Collisional cross section in Ų (square Angstroms). Notes ----- - Adapted from theGreatHerrLebert/ionmob (https://doi.org/10.1093/bioinformatics/btad486) + The conversion uses the Mason-Schamp equation: + CCS = (18509.8632163405 * z) / (sqrt(μ * T) * K0) + + Where: + - z is the charge + - μ is the reduced mass + - T is temperature in Kelvin + - K0 is the ion mobility + + References + ---------- + Adapted from theGreatHerrLebert/ionmob + (https://doi.org/10.1093/bioinformatics/btad486) + + Examples + -------- + >>> im2ccs(0.7, 500.0, 2) + 425.3 + >>> # For arrays + >>> import numpy as np + >>> ims = np.array([0.7, 0.8, 0.9]) + >>> mzs = np.array([500.0, 600.0, 700.0]) + >>> charges = np.array([2, 2, 3]) + >>> ccs_values = im2ccs(ims, mzs, charges) """ + # Validate inputs + if np.any(reverse_im <= 0): + raise ValueError("Reduced ion mobility must be positive") + if np.any(mz <= 0): + raise ValueError("m/z must be positive") + if np.any(charge <= 0): + raise ValueError("Charge must be positive") + if mass_gas <= 0: + raise ValueError("Gas mass must be positive") + if temp <= -t_diff: + raise ValueError("Temperature must be above absolute zero") SUMMARY_CONSTANT = 18509.8632163405 reduced_mass = (mz * charge * mass_gas) / (mz * charge + mass_gas) return (SUMMARY_CONSTANT * charge) / (np.sqrt(reduced_mass * (temp + t_diff)) * 1 / reverse_im) -def ccs2im(ccs, mz, charge, mass_gas=28.013, temp=31.85, t_diff=273.15): +def ccs2im( + ccs: Union[float, np.ndarray], + mz: Union[float, np.ndarray], + charge: Union[int, np.ndarray], + mass_gas: float = 28.013, + temp: float = 31.85, + t_diff: float = 273.15, +) -> Union[float, np.ndarray]: """ - Convert collisional cross section to ion mobility. + Convert collisional cross section to reduced ion mobility. + + This function converts collisional cross section (CCS) values to reduced + ion mobility (1/K0) using the inverse of the Mason-Schamp equation. Parameters ---------- - ccs - Collisional cross section. - mz - Precursor m/z. - charge - Precursor charge. - mass_gas - Mass of gas, default 28.013 - temp - Temperature in Celsius, default 31.85 - t_diff - Factor to convert Celsius to Kelvin, default 273.15 + ccs : float or array-like + Collisional cross section in Ų (square Angstroms). + mz : float or array-like + Precursor m/z ratio. + charge : int or array-like + Precursor charge state. + mass_gas : float, optional + Mass of drift gas in atomic mass units. Default is 28.013 (N₂). + temp : float, optional + Temperature in Celsius. Default is 31.85°C (typical for TIMS). + t_diff : float, optional + Temperature conversion factor (°C to K). Default is 273.15. + + Returns + ------- + float or np.ndarray + Reduced ion mobility (1/K0) in V⋅s/cm². Notes ----- - Adapted from theGreatHerrLebert/ionmob (https://doi.org/10.1093/bioinformatics/btad486) + The conversion uses the inverse Mason-Schamp equation: + 1/K0 = (sqrt(μ * T) * CCS) / (18509.8632163405 * z) + + Where: + - μ is the reduced mass + - T is temperature in Kelvin + - z is the charge + + References + ---------- + Adapted from theGreatHerrLebert/ionmob + (https://doi.org/10.1093/bioinformatics/btad486) + + Examples + -------- + >>> ccs2im(425.3, 500.0, 2) + 0.7 + >>> # For arrays + >>> import numpy as np + >>> ccs_values = np.array([425.3, 510.2, 680.5]) + >>> mzs = np.array([500.0, 600.0, 700.0]) + >>> charges = np.array([2, 2, 3]) + >>> ims = ccs2im(ccs_values, mzs, charges) """ + # Validate inputs + if np.any(ccs <= 0): + raise ValueError("CCS must be positive") + if np.any(mz <= 0): + raise ValueError("m/z must be positive") + if np.any(charge <= 0): + raise ValueError("Charge must be positive") + if mass_gas <= 0: + raise ValueError("Gas mass must be positive") + if temp <= -t_diff: + raise ValueError("Temperature must be above absolute zero") SUMMARY_CONSTANT = 18509.8632163405 reduced_mass = (mz * charge * mass_gas) / (mz * charge + mass_gas) return ((np.sqrt(reduced_mass * (temp + t_diff))) * ccs) / (SUMMARY_CONSTANT * charge) -multi_config = { +# Configuration for multi-conformer model +multi_config: Dict[str, Any] = { "model_name": "IM2DeepMulti", "batch_size": 16, "learning_rate": 0.0001, diff --git a/pyproject.toml b/pyproject.toml index 0b90821..dd1fc60 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,16 +28,24 @@ dependencies = [ ] [project.optional-dependencies] -dev = ["black", "isort>5", "pytest", "pytest-cov"] +dev = [ + "black>=23.0", + "isort>=5.12", + "pytest>=7.0", + "pytest-cov>=4.0", + "pytest-mock>=3.10", + "mypy>=1.0", + "pre-commit>=3.0" +] docs = [ - "sphinx", - "numpydoc>=1,<2", - "recommonmark", - "sphinx-mdinclude", - "toml", - "semver>=2", - "sphinx_rtd_theme", - "sphinx-autobuild", + "sphinx>=6.0", + "numpydoc>=1.5", + "recommonmark>=0.7", + "sphinx-mdinclude>=0.5", + "toml>=0.10", + "semver>=2.13", + "sphinx_rtd_theme>=1.2", + "sphinx-autobuild>=2021.3" ] er = [ "im2deeptrainer", @@ -59,11 +67,107 @@ version = {attr = "im2deep.__version__"} [tool.isort] profile = "black" +line_length = 99 +known_first_party = ["im2deep"] +known_third_party = ["click", "deeplc", "psm_utils", "pandas", "numpy", "rich"] [tool.black] line-length = 99 -target-version = ['py38'] +target-version = ['py310'] +include = '\.pyi?$' +extend-exclude = ''' +/( + # directories + \.eggs + | \.git + | \.hg + | \.mypy_cache + | \.tox + | \.venv + | build + | dist +)/ +''' + +[tool.mypy] +python_version = "3.10" +warn_return_any = true +warn_unused_configs = true +disallow_untyped_defs = true +disallow_incomplete_defs = true +check_untyped_defs = true +disallow_untyped_decorators = true +no_implicit_optional = true +warn_redundant_casts = true +warn_unused_ignores = true +warn_no_return = true +warn_unreachable = true +strict_equality = true + +[[tool.mypy.overrides]] +module = [ + "deeplc.*", + "torch.*", + "im2deeptrainer.*", + "psm_utils.*" +] +ignore_missing_imports = true + +[tool.pytest.ini_options] +minversion = "7.0" +addopts = [ + "-ra", + "--strict-markers", + "--strict-config" +] +testpaths = ["tests"] +filterwarnings = [ + "error", + "ignore::UserWarning", + "ignore::DeprecationWarning" +] +markers = [ + "slow: marks tests as slow (deselect with '-m \"not slow\"')", + "integration: marks tests as integration tests", + "unit: marks tests as unit tests" +] + +[tool.coverage.run] +source = ["im2deep"] +omit = [ + "*/tests/*", + "*/test_*", + "*/__main__.py" +] + +[tool.coverage.report] +exclude_lines = [ + "pragma: no cover", + "def __repr__", + "if self.debug:", + "if settings.DEBUG", + "raise AssertionError", + "raise NotImplementedError", + "if 0:", + "if __name__ == .__main__.:", + "class .*\\bProtocol\\):", + "@(abc\\.)?abstractmethod" +] [tool.ruff] line-length = 99 -target-version = "py38" +target-version = "py310" +select = [ + "E", # pycodestyle errors + "W", # pycodestyle warnings + "F", # pyflakes + "I", # isort + "C", # flake8-comprehensions + "B", # flake8-bugbear + "UP", # pyupgrade +] +ignore = [ + "E501", # line too long, handled by black + "B008", # do not perform function calls in argument defaults + "C901", # too complex +]