
Conversation

@minettekaum (Contributor)

Description

Added PERP (Parameter-Efficient Recovery Protocol) recovery algorithms to restore model quality after compression. The implementation includes multiple variants for both text-to-image and text-to-text models, supporting LoRA adapters and in-place finetuning mechanisms.
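
For orientation: the LoRA variants attach small trainable low-rank adapters through HuggingFace PEFT while the compressed base weights stay frozen, whereas the in-place variants unfreeze a small subset of existing parameters (norms, biases, heads). A minimal sketch of the LoRA side using PEFT's public API on a toy module; the model and target-module names are illustrative, not pruna's actual wiring:

import torch
from peft import LoraConfig, get_peft_model

# Toy stand-in for a compressed model; PERP would receive the real one.
class TinyDenoiser(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.proj = torch.nn.Linear(16, 16)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)

# Only the low-rank adapter weights on the named target modules are
# trainable during recovery; the base weights remain frozen.
config = LoraConfig(r=8, lora_alpha=16, target_modules=["proj"])
model = get_peft_model(TinyDenoiser(), config)
model.print_trainable_parameters()  # adapters are a small fraction of the total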

Algorithm Locations

Core PERP Algorithms

  • Main algorithm classes: src/pruna/algorithms/perp.py

    • TextToImagePERP - General-purpose recovery for text-to-image models
    • TextToImageInPlacePERP - In-place variant without LoRA (faster inference)
    • TextToImageLoRA - LoRA-only variant for text-to-image
    • TextToTextPERP - Recovery for text-to-text models (LLMs)
    • TextToTextInPlacePERP - In-place variant for text-to-text
    • TextToTextLoRA - LoRA-only variant for text-to-text
  • Distillation variants: src/pruna/algorithms/distillation_perp.py

    • TextToImagePERPDistillation - PERP with distillation for text-to-image
    • TextToImageInPlacePERPDistillation - In-place distillation variant
    • TextToImageLoraDistillation - LoRA distillation variant
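
For context, a hypothetical invocation through pruna's smash/SmashConfig entry point; the "recoverer" key and its value below are assumptions about how these classes register, not something this description confirms:

from diffusers import StableDiffusionPipeline
from pruna import SmashConfig, smash

pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")

smash_config = SmashConfig()
smash_config["pruner"] = "torch_structured"       # compress first
smash_config["recoverer"] = "text_to_image_perp"  # hypothetical key and value
smashed = smash(model=pipe, smash_config=smash_config)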

Core Recovery Infrastructure

  • Base recoverer: src/pruna/algorithms/global_utils/recovery/perp_recoverer.py

    • PERPRecoverer - Base class implementing core recovery logic with adapter management, finetuning orchestration, and scheduler handling
  • Adapter implementations: src/pruna/algorithms/global_utils/recovery/adapters/

    • norm.py - NormAdapter for in-place norm finetuning
    • bias.py - BiasAdapter for in-place bias finetuning
    • head.py - HeadAdapter for head finetuning (text-to-text models)
    • lora.py - LoraAdapter using HuggingFace PEFT
    • utils.py - Adapter utilities and parameter freezing
  • Finetuner implementations: src/pruna/algorithms/global_utils/recovery/finetuners/

    • text_to_image_finetuner.py - TextToImageFinetuner with training loop, loss computation, and scheduler integration
    • text_to_text_finetuner.py - TextToTextFinetuner for LLM recovery
    • text_to_image_distiller.py - TextToImageDistiller for distillation-based recovery
    • diffusers/scheduler_interface.py - Training scheduler utilities (get_training_scheduler, sample_timesteps, add_noise, get_target); the training step these support is sketched after this list
    • diffusers/pack_and_predict.py - Pipeline-specific prediction functions for SD, SDXL, Sana, Flux
    • diffusers/utils.py - Diffusers pipeline utilities (denoiser extraction, prompt encoding, device management)
    • diffusers/distillation_arg_utils.py - Utilities for distillation argument handling
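
The scheduler utilities map onto the standard diffusers noise-prediction training step. A minimal sketch of that step with diffusers' DDPMScheduler and a dummy denoiser; the wiring is illustrative, not the finetuner's actual code:

import torch
from diffusers import DDPMScheduler

scheduler = DDPMScheduler(num_train_timesteps=1000, prediction_type="epsilon")

latents = torch.randn(2, 4, 64, 64)  # clean latents (e.g. from the VAE)
noise = torch.randn_like(latents)

# sample_timesteps: draw a random diffusion step per sample
timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (latents.shape[0],))

# add_noise: forward-diffuse the clean latents to the sampled timesteps
noisy_latents = scheduler.add_noise(latents, noise, timesteps)

# get_target: epsilon-prediction regresses the injected noise,
# v-prediction regresses the scheduler's velocity instead
if scheduler.config.prediction_type == "epsilon":
    target = noise
else:
    target = scheduler.get_velocity(latents, noise, timesteps)

denoiser = torch.nn.Identity()  # stand-in for the compressed UNet/transformer
loss = torch.nn.functional.mse_loss(denoiser(noisy_latents).float(), target.float())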

Type of Change

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • This change requires a documentation update

How Has This Been Tested?

  • Text-to-image PERP tests: tests/algorithms/testers/tti_perp.py

    • Tests for TextToImagePERP algorithm with various model types
  • Text-to-image in-place PERP tests: tests/algorithms/testers/tti_inplace_perp.py

    • Tests for TextToImageInPlacePERP (no LoRA variant)
  • Text-to-image LoRA tests: tests/algorithms/testers/tti_lora.py

    • Tests for TextToImageLoRA (LoRA-only variant)
  • Text-to-text PERP tests: tests/algorithms/testers/ttt_perp.py

    • Tests for TextToTextPERP algorithm with LLM models
  • Text-to-text in-place PERP tests: tests/algorithms/testers/ttt_inplace_perp.py

    • Tests for TextToTextInPlacePERP (no LoRA variant)
  • Text-to-text LoRA tests: tests/algorithms/testers/ttt_lora.py

    • Tests for TextToTextLoRA (LoRA-only variant)
  • Distillation PERP tests: tests/algorithms/testers/tti_distillation_perp.py

    • Tests for TextToImagePERPDistillation algorithm
  • In-place distillation PERP tests: tests/algorithms/testers/tti_distillation_inplace_perp.py

    • Tests for TextToImageInPlacePERPDistillation algorithm
  • LoRA distillation PERP tests: tests/algorithms/testers/tti_distillation_lora.py

    • Tests for TextToImageLoraDistillation algorithm
  • Test utilities: tests/algorithms/testers/utils.py

    • restrict_recovery_time() - Utility to limit training time for testing
    • replace_datamodule_with_distillation_datamodule() - Helper for distillation tests
    • get_model_sparsity() - Utility to measure model sparsity (a sketch of the likely computation follows this list)
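
As an illustration of the sparsity utility, a sketch of the likely computation (the actual helper in tests/algorithms/testers/utils.py may differ):

import torch

def get_model_sparsity(model: torch.nn.Module) -> float:
    # Fraction of exactly-zero entries across all parameters (sketch).
    total = 0
    zeros = 0
    for param in model.parameters():
        total += param.numel()
        zeros += int((param == 0).sum().item())
    return zeros / max(total, 1)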

Checklist

  • My code follows the style guidelines of this project
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@gsprochette (Collaborator)

@cursor review

@cursor (bot) left a comment

Comment @cursor review or bugbot run to trigger another review on this PR

    The number of parameters found that match the given name but were not trainable.
    """
    if len(target_modules) == 0:
        return

Missing return value causes None unpacking error

High Severity

The unfreeze_parameters_by_name function returns None (via bare return) when target_modules is empty, but the function signature declares tuple[int, int] as the return type. Callers like BiasAdapter.activate unpack the result as num_activ_param, num_skip_param = utils.unfreeze_parameters_by_name(...), which would raise a TypeError if called with empty target_modules. The similar functions unfreeze_submodules_by_type and unfreeze_submodules_by_class_name correctly return (0, 0) for empty inputs.
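
For context, a sketch of what such a helper plausibly looks like (the real implementation in adapters/utils.py may differ, and the counting semantics are assumed), with the empty-input path already returning the (0, 0) tuple that the suggested fix further down introduces:

import torch

def unfreeze_parameters_by_name(
    model: torch.nn.Module, target_modules: list[str]
) -> tuple[int, int]:
    # Sketch; always returns a tuple so callers can safely unpack it.
    if len(target_modules) == 0:
        return (0, 0)  # a bare `return` here is the bug described above
    num_activated = 0
    num_skipped = 0
    for name, param in model.named_parameters():
        if any(target in name for target in target_modules):
            if param.requires_grad:
                num_skipped += 1  # matched but already trainable (assumed semantics)
            else:
                param.requires_grad = True
                num_activated += 1
    return (num_activated, num_skipped)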


if len(model_heads) != 1:
    # = 0: model with no head, e.g. diffusers
    # > 1: model with multiple heads, e.g. for localization, not currently supported
    model_head_names = [h[0] for h in model_heads]  # type: ignore[index]

Incorrect indexing of Linear modules instead of names

Medium Severity

The code attempts [h[0] for h in model_heads] to extract head names, but model_heads is a list of torch.nn.Linear modules (not tuples). The list comprehension at lines 71-75 stores only component, discarding comp_name. When there are multiple heads (>1), indexing a Linear module with [0] will raise a TypeError. The fix requires storing (comp_name, component) tuples in model_heads or collecting names separately.


@gsprochette (Collaborator)

Suggested change:

-        model_head_names = [h[0] for h in model_heads]  # type: ignore[index]
+        model_head_names = [
+            comp_name
+            for comp_name, component in inspect.getmembers(model)
+            if isinstance(component, torch.nn.Linear) and "head" in comp_name.lower()
+        ]

@simlang (Member)

A bit hacky, but okay for me. Alternatively, we could collect the name in model_heads at line 71, so we don't have to go through the model twice.

    A text field is only provided if the dataset is a huggingface dataset not yet tokenized.
    """
    column_names = dataset.column_names  # type: ignore[union-attr]
    if hasattr(dataset, "column_names") and "input_ids" in column_names:

Attribute access before hasattr check causes crash

High Severity

The function _format_dataset_for_causal_lm accepts Dataset | torch.utils.data.Dataset per its type signature, but line 223 unconditionally accesses dataset.column_names before line 224 checks hasattr(dataset, "column_names"). When a standard PyTorch Dataset (which lacks column_names) is passed, line 223 raises an AttributeError before the hasattr check can run. The # type: ignore[union-attr] comment acknowledges the type issue but doesn't prevent the runtime crash.


@gsprochette (Collaborator)

This is a reasonable comment; here is my take:

  • This problem is 100% due to trying to handle both torch datasets and HF datasets. The handling of torch datasets is a bit slow, so one option is to drop that support, but it feels very restrictive.
  • The other option is to reorder and properly check the type (sketched below):
    • isinstance(dataset, Dataset) on one side; only compute the column names in this case, which should help with typing as well.
    • isinstance(dataset, torch.utils.data.Dataset) on the other side; this handles only an (input, output) case for next-token prediction and converts it to (extended_inputs,), but I think we should also handle a dataset already in this format.
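
A sketch of that reordering (the function's real signature and return convention are assumptions; the point is that column_names is only read behind the HF-dataset isinstance check):

import torch
from datasets import Dataset  # HuggingFace datasets

def _format_dataset_for_causal_lm(dataset):
    # Sketch only; return shape and text-field selection are assumptions.
    if isinstance(dataset, Dataset):
        # Only HF datasets expose `column_names`, so it is safe to read here.
        if "input_ids" in dataset.column_names:
            return dataset, None  # already tokenized, no text field needed
        return dataset, dataset.column_names[0]  # assumed: first column holds text
    if isinstance(dataset, torch.utils.data.Dataset):
        # Assumed (input, output) pairs for next-token prediction, concatenated
        # into a single (extended_inputs,) sequence per sample.
        return [(torch.cat([inp, out]),) for inp, out in dataset], None
    raise TypeError(f"Unsupported dataset type: {type(dataset)!r}")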

)
# make directory for logs and checkpoints
model_path = pathlib.Path(smash_config.cache_dir) / "recovery"
model_path.mkdir(parents=True)

Missing exist_ok causes failure on repeated runs

Medium Severity

The call model_path.mkdir(parents=True) is missing exist_ok=True, unlike the equivalent code in text_to_image_finetuner.py line 218 which correctly uses mkdir(exist_ok=True, parents=True). This causes a FileExistsError if the recovery directory already exists from a previous run or if distillation is run multiple times with the same cache directory.


@gsprochette (Collaborator)

Suggested change:

-model_path.mkdir(parents=True)
+model_path.mkdir(exist_ok=True, parents=True)

filenames: List[str] = []

for batch in tqdm(dataloader, desc=desc):
    captions = batch if isinstance(batch[0], str) else batch[0]

Missing None check for optional dataloader causes crash

Medium Severity

The _prepare_one_dataset method accepts dataloader: Optional[DataLoader] but doesn't handle the None case. At line 175, for batch in tqdm(dataloader, desc=desc) will raise TypeError: 'NoneType' object is not iterable when dataloader is None. This occurs when val_dataloader() or test_dataloader() return None, which is common when validation or test data isn't provided to the caption datamodule.


@gsprochette (Collaborator) left a comment

I left some code suggestions fixing the bugbot slander.
There's one specific improvement that can be made to the next-token-prediction dataset handling, which is currently very crude. Let me know if you need me to update the function.
If you accept the small changes, it's an instant approve on my side :)

    The number of parameters found that match the given name but were not trainable.
    """
    if len(target_modules) == 0:
        return

Suggested change:

-        return
+        return (0, 0)




        List[str]
            The filenames of the dataset.
        """
        Path(self.save_path / subdir_name).mkdir(exist_ok=True, parents=True)

Adding this None case and extending the type hint should take care of the bugbot comment above. Maybe failing more strongly would be a better choice, though.

Suggested change:

-        Path(self.save_path / subdir_name).mkdir(exist_ok=True, parents=True)
+        if dataloader is None:
+            pruna_logger.warning(f"Missing dataloader for {subdir_name} data")
+            return None
+        Path(self.save_path / subdir_name).mkdir(exist_ok=True, parents=True)


    def _prepare_one_dataset(
        self,
        dataloader: Optional[DataLoader],

Suggested change:

-        dataloader: Optional[DataLoader],
+        dataloader: Optional[DataLoader] | None,

)
from pruna.algorithms.global_utils.recovery.finetuners.diffusers.utils import get_denoiser_attr
from pruna.data.pruna_datamodule import PrunaDataModule


Suggested change:

+from pruna.logging.logger import pruna_logger
Copy link
Member

@simlang simlang left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nothing to add, thank you!
LGTM 🚀

