feat: add recoverer algorithms #491
base: main
Conversation
@cursor review
```python
        The number of parameters found that match the given name but were not trainable.
    """
    if len(target_modules) == 0:
        return
```
Missing return value causes None unpacking error
High Severity
The unfreeze_parameters_by_name function returns None (via bare return) when target_modules is empty, but the function signature declares tuple[int, int] as the return type. Callers like BiasAdapter.activate unpack the result as num_activ_param, num_skip_param = utils.unfreeze_parameters_by_name(...), which would raise a TypeError if called with empty target_modules. The similar functions unfreeze_submodules_by_type and unfreeze_submodules_by_class_name correctly return (0, 0) for empty inputs.
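For illustration, a minimal sketch of the fix the sister functions already implement; the parameter list shown here is an assumption, only the empty-input branch and the declared return type come from the comment above.

```python
# Sketch only: parameter names are assumed; the point is that the empty-input
# branch must return a (0, 0) tuple instead of falling through with a bare return.
def unfreeze_parameters_by_name(model, target_modules) -> tuple[int, int]:
    if len(target_modules) == 0:
        return (0, 0)  # mirrors unfreeze_submodules_by_type / unfreeze_submodules_by_class_name
    num_activ_param, num_skip_param = 0, 0
    ...  # actual unfreezing logic elided
    return num_activ_param, num_skip_param
```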
```python
    if len(model_heads) != 1:
        # = 0: model with no head, e.g. diffusers
        # > 1: model with multiple heads, e.g. for localization, not currently supported
        model_head_names = [h[0] for h in model_heads]  # type: ignore[index]
```
Incorrect indexing of Linear modules instead of names
Medium Severity
The code attempts [h[0] for h in model_heads] to extract head names, but model_heads is a list of torch.nn.Linear modules (not tuples). The list comprehension at lines 71-75 stores only component, discarding comp_name. When there are multiple heads (>1), indexing a Linear module with [0] will raise a TypeError. The fix requires storing (comp_name, component) tuples in model_heads or collecting names separately.
Suggested change:
```diff
-        model_head_names = [h[0] for h in model_heads]  # type: ignore[index]
+        model_head_names = [
+            comp_name
+            for comp_name, component in inspect.getmembers(model)
+            if isinstance(component, torch.nn.Linear) and "head" in comp_name.lower()
+        ]
```
A bit hacky, but okay with me. Alternatively, we could also collect the name in model_heads at line 71, so we don't have to go through the model twice.
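For illustration, a rough sketch of that alternative: keep (name, module) pairs when model_heads is first built, so the multi-head error path can reuse the names without a second pass. The helper name is hypothetical; the filtering criteria are taken from the suggestion above.

```python
import inspect

import torch


def _find_model_heads(model: torch.nn.Module) -> list[tuple[str, torch.nn.Linear]]:
    """Collect (name, module) pairs for Linear heads in one pass (hypothetical helper)."""
    return [
        (comp_name, component)
        for comp_name, component in inspect.getmembers(model)
        if isinstance(component, torch.nn.Linear) and "head" in comp_name.lower()
    ]


# The error path can then derive the names without re-inspecting the model:
# model_heads = _find_model_heads(model)
# if len(model_heads) != 1:
#     model_head_names = [name for name, _ in model_heads]
```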
```python
    A text field is only provided if the dataset is a huggingface dataset not yet tokenized.
    """
    column_names = dataset.column_names  # type: ignore[union-attr]
    if hasattr(dataset, "column_names") and "input_ids" in column_names:
```
Attribute access before hasattr check causes crash
High Severity
The function _format_dataset_for_causal_lm accepts Dataset | torch.utils.data.Dataset per its type signature, but line 223 unconditionally accesses dataset.column_names before line 224 checks hasattr(dataset, "column_names"). When a standard PyTorch Dataset (which lacks column_names) is passed, line 223 raises an AttributeError before the hasattr check can run. The # type: ignore[union-attr] comment acknowledges the type issue but doesn't prevent the runtime crash.
This is a reasonable comment, here is my takedown:
- This problem is entirely due to trying to handle both torch datasets and HF datasets. The handling of torch datasets is a bit slow, so one option is to drop that support, but that feels very restrictive.
- The other option is to reorder and properly check the type:
  - isinstance(dataset, Dataset) on one side: only compute the column names in this case, which should help with typing as well.
  - isinstance(dataset, torch.utils.data.Dataset) on the other side: this currently handles only an (input, output) case for next-token prediction and converts it to (extended_inputs,), but I think we should also handle a dataset already in this format.
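A rough sketch of the reordering described above; the branch bodies are placeholders and the exact field handling is an assumption. The point is only that column_names is touched solely on the datasets.Dataset branch.

```python
import torch
from datasets import Dataset


def _format_dataset_for_causal_lm(dataset):
    # Sketch of the reordered type check; branch bodies are placeholders.
    if isinstance(dataset, Dataset):
        # HF dataset: safe to read column_names here, which also helps typing.
        if "input_ids" in dataset.column_names:
            ...  # already tokenized
        else:
            ...  # text field still present, tokenize downstream
    elif isinstance(dataset, torch.utils.data.Dataset):
        # torch dataset: assumed to yield (input, output) pairs for next-token
        # prediction; a dataset already in the extended format could also be
        # accepted here, as discussed above.
        ...
    return dataset
```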
```python
    )
    # make directory for logs and checkpoints
    model_path = pathlib.Path(smash_config.cache_dir) / "recovery"
    model_path.mkdir(parents=True)
```
Missing exist_ok causes failure on repeated runs
Medium Severity
The call model_path.mkdir(parents=True) is missing exist_ok=True, unlike the equivalent code in text_to_image_finetuner.py line 218 which correctly uses mkdir(exist_ok=True, parents=True). This causes a FileExistsError if the recovery directory already exists from a previous run or if distillation is run multiple times with the same cache directory.
Suggested change:
```diff
-    model_path.mkdir(parents=True)
+    model_path.mkdir(exist_ok=True, parents=True)
```
```python
        filenames: List[str] = []

        for batch in tqdm(dataloader, desc=desc):
            captions = batch if isinstance(batch[0], str) else batch[0]
```
Missing None check for optional dataloader causes crash
Medium Severity
The _prepare_one_dataset method accepts dataloader: Optional[DataLoader] but doesn't handle the None case. At line 175, for batch in tqdm(dataloader, desc=desc) will raise TypeError: 'NoneType' object is not iterable when dataloader is None. This occurs when val_dataloader() or test_dataloader() return None, which is common when validation or test data isn't provided to the caption datamodule.
gsprochette left a comment
I left some code suggestions fixing the bugbot slander.
There's one specific improvement that can be made to the next-token-prediction dataset handling, which is currently very crude. Let me know if you need me to update the function.
If you accept the small changes it's an instant approve on my side :)
```python
        The number of parameters found that match the given name but were not trainable.
    """
    if len(target_modules) == 0:
        return
```
Suggested change:
```diff
-        return
+        return (0, 0)
```
```python
        List[str]
            The filenames of the dataset.
        """
        Path(self.save_path / subdir_name).mkdir(exist_ok=True, parents=True)
```
Adding this None case and extending the type hint should take care of the bugbot comment below. Maybe failing more strongly would be a better choice though.
Suggested change:
```diff
-        Path(self.save_path / subdir_name).mkdir(exist_ok=True, parents=True)
+        if dataloader is None:
+            pruna_logger.warning(f"Missing dataloader for {subdir_name} data")
+            return None
+        Path(self.save_path / subdir_name).mkdir(exist_ok=True, parents=True)
```
```python
    def _prepare_one_dataset(
        self,
        dataloader: Optional[DataLoader],
```
Suggested change:
```diff
-        dataloader: Optional[DataLoader],
+        dataloader: Optional[DataLoader] | None,
```
```python
)
from pruna.algorithms.global_utils.recovery.finetuners.diffusers.utils import get_denoiser_attr
from pruna.data.pruna_datamodule import PrunaDataModule
```
Suggested change:
```diff
+from pruna.logging.logger import pruna_logger
```
simlang left a comment
Nothing to add, thank you!
LGTM 🚀
Description
Added PERP (Parameter-Efficient Recovery Protocol) recovery algorithms to restore model quality after compression. The implementation includes multiple variants for both text-to-image and text-to-text models, supporting LoRA adapters and in-place finetuning mechanisms.
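As a rough orientation, a hypothetical sketch of how one of these recoverers might be applied through pruna's existing `smash` / `SmashConfig` entry points. The `"recoverer"` config key and the `"text_to_image_perp"` identifier are assumptions for illustration, not the interface confirmed by this PR.

```python
# Hypothetical usage sketch: the "recoverer" key and the "text_to_image_perp" name
# are assumptions; smash() and SmashConfig are existing pruna entry points.
from diffusers import StableDiffusionPipeline

from pruna import SmashConfig, smash

pipe = StableDiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")

smash_config = SmashConfig()
smash_config["quantizer"] = "half"                # some compression step to recover from
smash_config["recoverer"] = "text_to_image_perp"  # assumed name for TextToImagePERP

smashed_pipe = smash(model=pipe, smash_config=smash_config)
```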
Algorithm Locations
Core PERP Algorithms
Main algorithm classes:
`src/pruna/algorithms/perp.py`
- `TextToImagePERP` - General-purpose recovery for text-to-image models
- `TextToImageInPlacePERP` - In-place variant without LoRA (faster inference)
- `TextToImageLoRA` - LoRA-only variant for text-to-image
- `TextToTextPERP` - Recovery for text-to-text models (LLMs)
- `TextToTextInPlacePERP` - In-place variant for text-to-text
- `TextToTextLoRA` - LoRA-only variant for text-to-text

Distillation variants:
`src/pruna/algorithms/distillation_perp.py`
- `TextToImagePERPDistillation` - PERP with distillation for text-to-image
- `TextToImageInPlacePERPDistillation` - In-place distillation variant
- `TextToImageLoraDistillation` - LoRA distillation variant

Core Recovery Infrastructure
Base recoverer:
`src/pruna/algorithms/global_utils/recovery/perp_recoverer.py`
- `PERPRecoverer` - Base class implementing core recovery logic with adapter management, finetuning orchestration, and scheduler handling

Adapter implementations:
`src/pruna/algorithms/global_utils/recovery/adapters/`
- `norm.py` - NormAdapter for in-place norm finetuning
- `bias.py` - BiasAdapter for in-place bias finetuning
- `head.py` - HeadAdapter for head finetuning (text-to-text models)
- `lora.py` - LoraAdapter using HuggingFace PEFT
- `utils.py` - Adapter utilities and parameter freezing

Finetuner implementations:
`src/pruna/algorithms/global_utils/recovery/finetuners/`
- `text_to_image_finetuner.py` - TextToImageFinetuner with training loop, loss computation, and scheduler integration
- `text_to_text_finetuner.py` - TextToTextFinetuner for LLM recovery
- `text_to_image_distiller.py` - TextToImageDistiller for distillation-based recovery
- `diffusers/scheduler_interface.py` - Training scheduler utilities (get_training_scheduler, sample_timesteps, add_noise, get_target)
- `diffusers/pack_and_predict.py` - Pipeline-specific prediction functions for SD, SDXL, Sana, Flux
- `diffusers/utils.py` - Diffusers pipeline utilities (denoiser extraction, prompt encoding, device management)
- `diffusers/distillation_arg_utils.py` - Utilities for distillation argument handling

Type of Change
How Has This Been Tested?
Text-to-image PERP tests:
`tests/algorithms/testers/tti_perp.py` - Tests the TextToImagePERP algorithm with various model types

Text-to-image in-place PERP tests:
`tests/algorithms/testers/tti_inplace_perp.py` - Tests TextToImageInPlacePERP (no LoRA variant)

Text-to-image LoRA tests:
`tests/algorithms/testers/tti_lora.py` - Tests TextToImageLoRA (LoRA-only variant)

Text-to-text PERP tests:
`tests/algorithms/testers/ttt_perp.py` - Tests the TextToTextPERP algorithm with LLM models

Text-to-text in-place PERP tests:
`tests/algorithms/testers/ttt_inplace_perp.py` - Tests TextToTextInPlacePERP (no LoRA variant)

Text-to-text LoRA tests:
`tests/algorithms/testers/ttt_lora.py` - Tests TextToTextLoRA (LoRA-only variant)

Distillation PERP tests:
`tests/algorithms/testers/tti_distillation_perp.py` - Tests the TextToImagePERPDistillation algorithm

In-place distillation PERP tests:
`tests/algorithms/testers/tti_distillation_inplace_perp.py` - Tests the TextToImageInPlacePERPDistillation algorithm

LoRA distillation PERP tests:
`tests/algorithms/testers/tti_distillation_lora.py` - Tests the TextToImageLoraDistillation algorithm

Test utilities:
`tests/algorithms/testers/utils.py`
- `restrict_recovery_time()` - Utility to limit training time for testing
- `replace_datamodule_with_distillation_datamodule()` - Helper for distillation tests
- `get_model_sparsity()` - Utility to measure model sparsity

Checklist