Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 40 additions & 31 deletions im2deep/calibrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,13 @@
... )
"""

from __future__ import annotations

import logging
from typing import Dict, Union, Optional
from typing import cast

import numpy as np
import pandas as pd
from numpy import ndarray
from psm_utils.peptidoform import Peptidoform

from im2deep._exceptions import CalibrationError

Expand All @@ -40,27 +40,28 @@
def _validate_calibration_inputs(
cal_df: pd.DataFrame,
reference_dataset: pd.DataFrame,
required_cal_columns: Optional[list] = None,
required_ref_columns: Optional[list] = None,
required_cal_columns: list | None = None,
required_ref_columns: list | None = None,
) -> None:
"""
Validate input dataframes for calibration functions.

Parameters
----------
cal_df : pd.DataFrame
cal_df
Calibration dataset
reference_dataset : pd.DataFrame
reference_dataset
Reference dataset
required_cal_columns : list, optional
required_cal_columns
Required columns for calibration dataset
required_ref_columns : list, optional
required_ref_columns
Required columns for reference dataset

Raises
------
CalibrationError
If validation fails

"""
if cal_df.empty:
raise CalibrationError("Calibration dataset is empty")
Expand Down Expand Up @@ -91,11 +92,11 @@ def get_ccs_shift(

Parameters
----------
cal_df : pd.DataFrame
cal_df
PSMs with CCS values. Must contain columns: 'sequence', 'charge', 'ccs_observed'
reference_dataset : pd.DataFrame
reference_dataset
Reference dataset with CCS values. Must contain columns: 'peptidoform', 'charge', 'CCS'
use_charge_state : int, default 2
use_charge_state
Charge state to use for CCS shift calculation. Should be in range [2,4].

Returns
Expand All @@ -120,6 +121,7 @@ def get_ccs_shift(
--------
>>> shift = get_ccs_shift(calibration_df, reference_df, use_charge_state=2)
>>> print(f"CCS shift factor: {shift:.2f} Ų")

"""
# Validate inputs
_validate_calibration_inputs(
Expand Down Expand Up @@ -187,7 +189,7 @@ def get_ccs_shift(

def get_ccs_shift_per_charge(
cal_df: pd.DataFrame, reference_dataset: pd.DataFrame
) -> Dict[int, float]:
) -> dict[int, float]:
"""
Calculate CCS shift factors per charge state.

Expand All @@ -197,9 +199,9 @@ def get_ccs_shift_per_charge(

Parameters
----------
cal_df : pd.DataFrame
cal_df
PSMs with CCS values. Must contain columns: 'sequence', 'charge', 'ccs_observed'
reference_dataset : pd.DataFrame
reference_dataset
Reference dataset with CCS values. Must contain columns: 'peptidoform', 'charge', 'CCS'

Returns
Expand Down Expand Up @@ -228,6 +230,7 @@ def get_ccs_shift_per_charge(
>>> shifts = get_ccs_shift_per_charge(calibration_df, reference_df)
>>> print(shifts)
{2: 5.2, 3: 3.8, 4: 2.1}

"""
# Validate inputs
_validate_calibration_inputs(
Expand Down Expand Up @@ -277,9 +280,7 @@ def get_ccs_shift_per_charge(
# Check for unreasonably large shifts
large_shifts = {k: v for k, v in shift_dict.items() if abs(v) > 100}
if large_shifts:
LOGGER.warning(
f"Large CCS shifts detected: {large_shifts}. " "Please verify data quality."
)
LOGGER.warning(f"Large CCS shifts detected: {large_shifts}. Please verify data quality.")

return shift_dict

Expand All @@ -288,8 +289,8 @@ def calculate_ccs_shift(
cal_df: pd.DataFrame,
reference_dataset: pd.DataFrame,
per_charge: bool = True,
use_charge_state: Optional[int] = None,
) -> Union[float, Dict[int, float]]:
use_charge_state: int | None = None,
) -> float | dict[int, float]:
"""
Calculate CCS shift factors with validation and filtering.

Expand All @@ -299,14 +300,14 @@ def calculate_ccs_shift(

Parameters
----------
cal_df : pd.DataFrame
cal_df
PSMs with CCS values. Must contain columns: 'sequence', 'charge', 'ccs_observed'
reference_dataset : pd.DataFrame
reference_dataset
Reference dataset with CCS values. Must contain columns: 'peptidoform', 'charge', 'CCS'
per_charge : bool, default True
per_charge
Whether to calculate shift factors per charge state. If False, calculates
a single global shift factor using the specified charge state.
use_charge_state : int, optional
use_charge_state
Charge state to use for global shift calculation when per_charge=False.
Should be in range [2,4]. Default is 2 if not specified.

Expand Down Expand Up @@ -334,6 +335,7 @@ def calculate_ccs_shift(
>>>
>>> # Global calibration using charge 2
>>> shift = calculate_ccs_shift(cal_df, ref_df, per_charge=False, use_charge_state=2)

"""
# Validate inputs
_validate_calibration_inputs(cal_df, reference_dataset)
Expand Down Expand Up @@ -378,7 +380,7 @@ def linear_calibration(
calibration_dataset: pd.DataFrame,
reference_dataset: pd.DataFrame,
per_charge: bool = True,
use_charge_state: Optional[int] = None,
use_charge_state: int | None = None,
) -> pd.DataFrame:
"""
Calibrate CCS predictions using linear calibration.
Expand All @@ -389,20 +391,20 @@ def linear_calibration(

Parameters
----------
preds_df : pd.DataFrame
preds_df
PSMs with CCS predictions. Must contain 'predicted_ccs' column.
Will be modified to include 'charge' and 'shift' columns.
calibration_dataset : pd.DataFrame
calibration_dataset
Calibration dataset with observed CCS values. Must contain columns:
'peptidoform', 'ccs_observed'
reference_dataset : pd.DataFrame
reference_dataset
Reference dataset with CCS values. Must contain columns:
'peptidoform', 'CCS'
per_charge : bool, default True
per_charge
Whether to calculate and apply shift factors per charge state.
If True, uses charge-specific calibration with fallback to global shift.
If False, applies single global shift factor.
use_charge_state : int, optional
use_charge_state
Charge state to use for global shift calculation when per_charge=False.
Default is 2 if not specified.

Expand Down Expand Up @@ -446,6 +448,7 @@ def linear_calibration(
... use_charge_state=2
... )
"""

LOGGER.info("Calibrating CCS values using linear calibration...")

# Validate input dataframes
Expand Down Expand Up @@ -478,7 +481,7 @@ def linear_calibration(
)

except (AttributeError, ValueError, IndexError) as e:
raise CalibrationError(f"Error parsing peptidoform data: {e}")
raise CalibrationError(f"Error parsing peptidoform data: {e}") from e

if per_charge:
LOGGER.info("Calculating general shift factor for fallback...")
Expand All @@ -489,6 +492,8 @@ def linear_calibration(
per_charge=False,
use_charge_state=use_charge_state or 2,
)
# per_charge=False returns float
general_shift = cast(float, general_shift)
except CalibrationError as e:
LOGGER.warning(
f"Could not calculate general shift factor: {e}. Using 0.0 as fallback."
Expand All @@ -499,6 +504,8 @@ def linear_calibration(
shift_factor_dict = calculate_ccs_shift(
calibration_dataset, reference_dataset, per_charge=True
)
# per_charge=True returns dict[int, float]
shift_factor_dict = cast(dict[int, float], shift_factor_dict)

# Add charge information to predictions if not present
if "charge" not in preds_df.columns:
Expand All @@ -525,6 +532,8 @@ def linear_calibration(
per_charge=False,
use_charge_state=use_charge_state or 2,
)
# per_charge=False returns float
shift_factor = cast(float, shift_factor)
preds_df["predicted_ccs"] += shift_factor
preds_df["shift"] = shift_factor
LOGGER.info(f"Applied global shift factor: {shift_factor:.3f}")
Expand Down
Loading