diff --git a/.github/workflows/model_jobs.yml b/.github/workflows/model_jobs.yml
index e96c7ef16a07..94f6dece6bc2 100644
--- a/.github/workflows/model_jobs.yml
+++ b/.github/workflows/model_jobs.yml
@@ -186,7 +186,18 @@ jobs:
         env:
           report_name_prefix: ${{ inputs.report_name_prefix }}
         run: |
-          cat "/transformers/reports/${machine_type}_${report_name_prefix}_${matrix_folders}_test_reports/captured_info.txt"
+          shopt -s nullglob
+          captured_info_files=("/transformers/reports/${machine_type}_${report_name_prefix}_${matrix_folders}_test_reports"/captured_info*.txt)
+
+          if [ ${#captured_info_files[@]} -eq 0 ]; then
+            echo "No captured information files found."
+            exit 0
+          fi
+
+          for captured_info_file in "${captured_info_files[@]}"; do
+            echo "===== ${captured_info_file##*/} ====="
+            cat "$captured_info_file"
+          done
 
       - name: Copy test_outputs.txt
         if: ${{ always() }}
diff --git a/all_requirements.txt b/all_requirements.txt
new file mode 100644
index 000000000000..eacb47727a64
--- /dev/null
+++ b/all_requirements.txt
@@ -0,0 +1,107 @@
+gpustat==1.1.1
+psutil==6.0.0
+psycopg2==2.9.9
+pandas>=1.5.0
+numpy>=1.21.0
+psutil>=5.8.0
+nvidia-ml-py>=12.0.0
+torch>=2.0.0
+datasets>=2.10.0
+huggingface_hub>=0.16.0
+amdsmi>=7.0.2
+git+https://github.com/huggingface/transformers.git@main # install main or adjust it with vX.X.X for installing a specific transformers version
+datasets==1.8.0
+accelerate >= 0.12.0
+datasets >= 1.8.0
+torch >= 1.3.0
+evaluate
+accelerate >= 0.21.0
+sentencepiece != 0.1.92
+protobuf
+torch >= 1.3
+datasets[audio]>=1.14.0
+evaluate
+librosa
+torchaudio
+torch>=1.6
+accelerate >= 0.12.0
+datasets >= 1.8.0
+sentencepiece != 0.1.92
+protobuf
+sacrebleu >= 1.4.12
+py7zr
+torch >= 1.3
+evaluate
+datasets >= 2.0.0
+torch >= 1.3
+accelerate
+evaluate
+Pillow
+albumentations >= 1.4.16
+accelerate >= 0.12.0
+datasets >= 1.8.0
+sentencepiece != 0.1.92
+protobuf
+rouge-score
+nltk
+py7zr
+torch >= 1.3
+evaluate
+torch>=1.5.0
+torchvision>=0.6.0
+datasets>=1.8.0
+accelerate >= 0.12.0
+datasets >= 1.8.0
+sentencepiece != 0.1.92
+scipy
+scikit-learn
+protobuf
+torch >= 1.3
+evaluate
+accelerate>=0.12.0
+torch>=1.5.0
+torchvision>=0.6.0
+datasets>=2.14.0
+evaluate
+scikit-learn
+accelerate >= 0.12.0
+torch >= 1.3
+datasets >= 2.14.0
+sentencepiece != 0.1.92
+protobuf
+evaluate
+scikit-learn
+accelerate >= 0.12.0
+seqeval
+datasets >= 1.8.0
+torch >= 1.3
+evaluate
+albumentations >= 1.4.16
+timm
+datasets>=4.0
+torchmetrics
+pycocotools
+datasets[audio] >= 1.18.0
+torch >= 1.5
+torchaudio
+librosa
+jiwer
+evaluate
+datasets[audio] >= 1.12.0
+torch >= 1.5
+torchaudio
+accelerate >= 0.12.0
+librosa
+torch>=1.5.0
+torchvision>=0.6.0
+datasets>=1.8.0
+albumentations >= 1.4.16
+timm
+datasets
+torchmetrics
+pycocotools
+accelerate >= 0.12.0
+sentencepiece != 0.1.92
+protobuf
+torch >= 1.3
+evaluate
diff --git a/docker/transformers-all-latest-gpu/Dockerfile b/docker/transformers-all-latest-gpu/Dockerfile
index bd3de7b27311..ee71c087dde2 100644
--- a/docker/transformers-all-latest-gpu/Dockerfile
+++ b/docker/transformers-all-latest-gpu/Dockerfile
@@ -18,9 +18,20 @@ ARG TORCHCODEC='0.11.0'
 ARG FLASH_ATTN='false'
 
+# 'x86_64' or 'arm64'
+ARG ARCHITECTURE='x86_64'
+
 RUN apt update
-RUN apt install -y git libsndfile1-dev tesseract-ocr espeak-ng python3 python3-pip ffmpeg git-lfs
+RUN apt install -y git libsndfile1-dev tesseract-ocr espeak-ng python3 python3-pip ffmpeg git-lfs curl
 RUN git lfs install
+
+RUN set -e; \
+if [ "$ARCHITECTURE" = "arm64" ]; then \
+    curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y;\
PATH="/root/.cargo/bin:${PATH}";\ + rustc --version;\ +fi; + RUN python3 -m pip install --no-cache-dir --upgrade pip ARG REF=main @@ -36,7 +47,11 @@ RUN set -e; \ # Determine torch version if [ ${#PYTORCH} -gt 0 ] && [ "$PYTORCH" != "pre" ]; then \ VERSION="torch==${PYTORCH}.*"; \ - TORCHCODEC_VERSION="torchcodec==${TORCHCODEC}.*"; \ + if [ "$ARCHITECTURE" = "arm64" ]; then \ + TORCHCODEC_VERSION="torchcodec"; \ + else \ + TORCHCODEC_VERSION="torchcodec==${TORCHCODEC}.*"; \ + fi; \ else \ VERSION="torch"; \ TORCHCODEC_VERSION="torchcodec"; \ diff --git a/docs/source/en/auto_docstring.md b/docs/source/en/auto_docstring.md index 5426af13fa31..1b55f0fcc5d1 100644 --- a/docs/source/en/auto_docstring.md +++ b/docs/source/en/auto_docstring.md @@ -134,11 +134,11 @@ class MyModelConfig(PreTrainedConfig): Description of another model-specific parameter. ```python - >>> from transformers import MyModelConfig, MyModel + from transformers import MyModelConfig, MyModel - >>> configuration = MyModelConfig() - >>> model = MyModel(configuration) - >>> configuration = model.config + configuration = MyModelConfig() + model = MyModel(configuration) + configuration = model.config ``` """ diff --git a/docs/source/en/internal/import_utils.md b/docs/source/en/internal/import_utils.md index 41ee64f1611c..abb85008d53e 100644 --- a/docs/source/en/internal/import_utils.md +++ b/docs/source/en/internal/import_utils.md @@ -29,18 +29,24 @@ This object is still importable: ```python >>> from transformers import DetrImageProcessor ->>> print(DetrImageProcessor) - +>>> print(DetrImageProcessor) # doctest: +ELLIPSIS + ``` However, no method can be called on that object: ```python +>>> from transformers.utils.import_utils import BACKENDS_MAPPING, DummyObject +>>> _torchvision_backend = BACKENDS_MAPPING["torchvision"] +>>> BACKENDS_MAPPING["torchvision"] = (lambda: False, _torchvision_backend[1].lstrip("\n")) +>>> DetrImageProcessor = DummyObject("DetrImageProcessor", (), {"_backends": ["torchvision"]}) >>> DetrImageProcessor.from_pretrained() -ImportError: -DetrImageProcessor requires the Torchvision library but it was not found in your environment. Check out the instructions on the +Traceback (most recent call last): +... +ImportError: DetrImageProcessor requires the Torchvision library but it was not found in your environment. Check out the instructions on the installation page: https://pytorch.org/get-started/locally/ and follow the ones that match your environment. Please note that you may need to restart your runtime after installation. +>>> BACKENDS_MAPPING["torchvision"] = _torchvision_backend ``` Let's see how to specify specific object dependencies. diff --git a/docs/source/en/main_classes/pipelines.md b/docs/source/en/main_classes/pipelines.md index faca097d1160..16f20999a954 100644 --- a/docs/source/en/main_classes/pipelines.md +++ b/docs/source/en/main_classes/pipelines.md @@ -34,6 +34,7 @@ pipeline but can provide additional quality of life. 
 Simple call on one item:
 
 ```python
+>>> from transformers import pipeline
 >>> pipe = pipeline("text-classification")
 >>> pipe("This restaurant is awesome")
 [{'label': 'POSITIVE', 'score': 0.9998743534088135}]
diff --git a/docs/source/en/model_doc/pe_audio_video.md b/docs/source/en/model_doc/pe_audio_video.md
index e116724d43f5..af0db76537f5 100644
--- a/docs/source/en/model_doc/pe_audio_video.md
+++ b/docs/source/en/model_doc/pe_audio_video.md
@@ -26,7 +26,50 @@ TODO
 ### Basic usage
 
 ```py
-TODO
+import torch
+
+from transformers import PeAudioVideoModel, PeAudioVideoProcessor
+
+model = PeAudioVideoModel.from_pretrained("facebook/pe-av-large", device_map="cuda", dtype=torch.bfloat16)
+processor = PeAudioVideoProcessor.from_pretrained("facebook/pe-av-large")
+
+from huggingface_hub import hf_hub_download
+
+video_path = hf_hub_download(
+    repo_id="eustlb/dummy-video-dataset", filename="audiobox.mp4", repo_type="dataset"
+)
+
+video_path2 = hf_hub_download(
+    repo_id="eustlb/dummy-video-dataset", filename="glass_breaking.mp4", repo_type="dataset"
+)
+
+audio_path = hf_hub_download(
+    repo_id="eustlb/dummy-video-dataset", filename="audiobox.mp4", repo_type="dataset"
+)
+
+audio_path2 = hf_hub_download(
+    repo_id="eustlb/dummy-video-dataset", filename="glass_breaking.mp4", repo_type="dataset"
+)
+
+video_files = [video_path, video_path2]
+descriptions = ["A woman and a man speaking", "A glass breaking"]
+audio_files = [audio_path, audio_path2]
+
+inputs = processor(
+    videos=video_files, text=descriptions, audio=audio_files, return_tensors="pt", padding=True
+)
+
+with torch.inference_mode(), torch.autocast(model.device.type, dtype=torch.bfloat16):
+    outputs = model(**inputs.to(model.device, dtype=model.dtype))
+
+audio_embeds = outputs.audio_embeds  # Audio-only embeddings
+video_embeds = outputs.video_embeds  # Video-only embeddings
+audio_video_embeds = outputs.audio_video_embeds  # Joint audio-video embeddings
+text_audio_embeds = outputs.text_audio_embeds  # Text embeddings aligned to audio
+text_video_embeds = outputs.text_video_embeds  # Text embeddings aligned to video
+text_audio_video_embeds = outputs.text_audio_video_embeds  # Text embeddings aligned to audio-video
+audio_plus_text_embeds = outputs.audio_plus_text_embeds  # Joint audio and text embedding
+video_plus_text_embeds = outputs.video_plus_text_embeds  # Joint video and text embedding
 ```
 
 ## PeAudioVideoProcessor
diff --git a/docs/source/en/model_doc/qwen3_5.md b/docs/source/en/model_doc/qwen3_5.md
index 1d542dd918ce..aae67e8a8e7a 100644
--- a/docs/source/en/model_doc/qwen3_5.md
+++ b/docs/source/en/model_doc/qwen3_5.md
@@ -70,14 +70,19 @@ TODO
 [[autodoc]] Qwen3_5ForCausalLM
     - forward
 
+## Qwen3_5ForConditionalGeneration
+
+[[autodoc]] Qwen3_5ForConditionalGeneration
+    - forward
+
 ## Qwen3_5ForSequenceClassification
 
 [[autodoc]] Qwen3_5ForSequenceClassification
     - forward
 
-## Qwen3_5ForConditionalGeneration
+## Qwen3_5TextForSequenceClassification
 
-[[autodoc]] Qwen3_5ForConditionalGeneration
+[[autodoc]] Qwen3_5TextForSequenceClassification
     - forward
 
 ## Qwen3_5Tokenizer
diff --git a/docs/source/en/tasks/zero_shot_object_detection.md b/docs/source/en/tasks/zero_shot_object_detection.md
index 8a5506939898..aa15ff46f05d 100644
--- a/docs/source/en/tasks/zero_shot_object_detection.md
+++ b/docs/source/en/tasks/zero_shot_object_detection.md
@@ -168,8 +168,7 @@ boxes have the correct coordinates relative to the original image:
 ...     outputs = model(**inputs)
 
 >>> results = processor.post_process_grounded_object_detection(
-...     outputs, threshold=0.50, target_sizes=[(image.height, image.width)], text_labels=text_labels,
-... 
)[0] +... outputs, threshold=0.50, target_sizes=[(image.height, image.width)], text_labels=text_labels)[0] >>> draw = ImageDraw.Draw(image) diff --git a/examples/modular-transformers/modeling_new_task_model.py b/examples/modular-transformers/modeling_new_task_model.py index ff6e666a804e..a207b5d32e0f 100644 --- a/examples/modular-transformers/modeling_new_task_model.py +++ b/examples/modular-transformers/modeling_new_task_model.py @@ -160,10 +160,8 @@ def create_causal_mask_mapping( # from `forward` call. If users run a `forward` call, we have no option to infer `is_first_iteration` because users may be # running generation with custom loop. Thus we need to infer it in a `non-perfect` way # NOTE: Determining prefill in that case requires checking data values, which is not compile-compatible. - is_first_iteration = ( - is_first_iteration - if is_first_iteration - else (past_key_values is None or not past_key_values.is_initialized or pixel_values is not None) + is_first_iteration = is_first_iteration or ( + past_key_values is None or not past_key_values.is_initialized or pixel_values is not None ) if is_first_iteration or not kwargs.get("use_cache", True): @@ -256,9 +254,9 @@ def get_placeholder_mask( n_image_tokens = special_image_mask.sum() n_image_features = image_features.shape[0] * image_features.shape[1] - special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + special_image_mask = special_image_mask.unsqueeze(-1).to(inputs_embeds.device) torch_compilable_check( - inputs_embeds[special_image_mask].numel() == image_features.numel(), + n_image_tokens * inputs_embeds.shape[-1] == image_features.numel(), f"Image features and image tokens do not match, tokens: {n_image_tokens}, features: {n_image_features}", ) return special_image_mask diff --git a/examples/pytorch/image-pretraining/run_mim_no_trainer.py b/examples/pytorch/image-pretraining/run_mim_no_trainer.py index f6d13078bbc6..5e539047a6b9 100644 --- a/examples/pytorch/image-pretraining/run_mim_no_trainer.py +++ b/examples/pytorch/image-pretraining/run_mim_no_trainer.py @@ -633,7 +633,7 @@ def preprocess_images(examples): ) # On TPU, the tie weights in our model have been disconnected, so we need to restore the ties. - if accelerator.distributed_type == DistributedType.TPU: + if accelerator.distributed_type == DistributedType.XLA: model.tie_weights() # We need to recalculate our total training steps as the size of the training dataloader may have changed. diff --git a/examples/pytorch/language-modeling/run_clm_no_trainer.py b/examples/pytorch/language-modeling/run_clm_no_trainer.py index 0f8d2cd0d6e3..d91340a3afe4 100755 --- a/examples/pytorch/language-modeling/run_clm_no_trainer.py +++ b/examples/pytorch/language-modeling/run_clm_no_trainer.py @@ -553,7 +553,7 @@ def group_texts(examples): ) # On TPU, the tie weights in our model have been disconnected, so we need to restore the ties. - if accelerator.distributed_type == DistributedType.TPU: + if accelerator.distributed_type == DistributedType.XLA: model.tie_weights() # We need to recalculate our total training steps as the size of the training dataloader may have changed. 
@@ -627,6 +627,7 @@ def group_texts(examples): model.train() if args.with_tracking: total_loss = 0 + total_samples = 0 if args.resume_from_checkpoint and epoch == starting_epoch and resume_step is not None: # We skip the first `n` batches in the dataloader when resuming from a checkpoint active_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step) @@ -638,7 +639,9 @@ def group_texts(examples): loss = outputs.loss # We keep track of the loss at each epoch if args.with_tracking: - total_loss += loss.detach().float() + batch_size = batch["input_ids"].shape[0] + total_loss += loss.detach().float() * batch_size + total_samples += batch_size accelerator.backward(loss) optimizer.step() lr_scheduler.step() @@ -665,7 +668,8 @@ def group_texts(examples): outputs = model(**batch) loss = outputs.loss - losses.append(accelerator.gather_for_metrics(loss.repeat(args.per_device_eval_batch_size))) + batch_size = batch["input_ids"].shape[0] + losses.append(accelerator.gather_for_metrics(loss.repeat(batch_size))) losses = torch.cat(losses) try: @@ -681,7 +685,7 @@ def group_texts(examples): { "perplexity": perplexity, "eval_loss": eval_loss, - "train_loss": total_loss.item() / len(train_dataloader), + "train_loss": total_loss.item() / total_samples, "epoch": epoch, "step": completed_steps, }, diff --git a/examples/pytorch/language-modeling/run_fim_no_trainer.py b/examples/pytorch/language-modeling/run_fim_no_trainer.py index 962e497b72e0..a0c0ff8b7da0 100644 --- a/examples/pytorch/language-modeling/run_fim_no_trainer.py +++ b/examples/pytorch/language-modeling/run_fim_no_trainer.py @@ -743,7 +743,7 @@ def apply_fim(examples): ) # On TPU, the tie weights in our model have been disconnected, so we need to restore the ties. - if accelerator.distributed_type == DistributedType.TPU: + if accelerator.distributed_type == DistributedType.XLA: model.tie_weights() # We need to recalculate our total training steps as the size of the training dataloader may have changed. 
@@ -817,6 +817,7 @@ def apply_fim(examples): model.train() if args.with_tracking: total_loss = 0 + total_samples = 0 if args.resume_from_checkpoint and epoch == starting_epoch and resume_step is not None: # We skip the first `n` batches in the dataloader when resuming from a checkpoint active_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step) @@ -828,7 +829,9 @@ def apply_fim(examples): loss = outputs.loss # We keep track of the loss at each epoch if args.with_tracking: - total_loss += loss.detach().float() + batch_size = batch["input_ids"].shape[0] + total_loss += loss.detach().float() * batch_size + total_samples += batch_size accelerator.backward(loss) optimizer.step() lr_scheduler.step() @@ -855,7 +858,8 @@ def apply_fim(examples): outputs = model(**batch) loss = outputs.loss - losses.append(accelerator.gather_for_metrics(loss.repeat(args.per_device_eval_batch_size))) + batch_size = batch["input_ids"].shape[0] + losses.append(accelerator.gather_for_metrics(loss.repeat(batch_size))) losses = torch.cat(losses) try: @@ -871,7 +875,7 @@ def apply_fim(examples): { "perplexity": perplexity, "eval_loss": eval_loss, - "train_loss": total_loss.item() / len(train_dataloader), + "train_loss": total_loss.item() / total_samples, "epoch": epoch, "step": completed_steps, }, diff --git a/examples/pytorch/language-modeling/run_mlm_no_trainer.py b/examples/pytorch/language-modeling/run_mlm_no_trainer.py index 981a496badad..a4ed188c0fa1 100755 --- a/examples/pytorch/language-modeling/run_mlm_no_trainer.py +++ b/examples/pytorch/language-modeling/run_mlm_no_trainer.py @@ -582,7 +582,7 @@ def group_texts(examples): ) # On TPU, the tie weights in our model have been disconnected, so we need to restore the ties. - if accelerator.distributed_type == DistributedType.TPU: + if accelerator.distributed_type == DistributedType.XLA: model.tie_weights() # We need to recalculate our total training steps as the size of the training dataloader may have changed. 
@@ -656,6 +656,7 @@ def group_texts(examples): model.train() if args.with_tracking: total_loss = 0 + total_samples = 0 if args.resume_from_checkpoint and epoch == starting_epoch and resume_step is not None: # We skip the first `n` batches in the dataloader when resuming from a checkpoint active_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step) @@ -667,7 +668,9 @@ def group_texts(examples): loss = outputs.loss # We keep track of the loss at each epoch if args.with_tracking: - total_loss += loss.detach().float() + batch_size = batch["input_ids"].shape[0] + total_loss += loss.detach().float() * batch_size + total_samples += batch_size accelerator.backward(loss) optimizer.step() lr_scheduler.step() @@ -695,7 +698,8 @@ def group_texts(examples): outputs = model(**batch) loss = outputs.loss - losses.append(accelerator.gather_for_metrics(loss.repeat(args.per_device_eval_batch_size))) + batch_size = batch["input_ids"].shape[0] + losses.append(accelerator.gather_for_metrics(loss.repeat(batch_size))) losses = torch.cat(losses) try: @@ -711,7 +715,7 @@ def group_texts(examples): { "perplexity": perplexity, "eval_loss": eval_loss, - "train_loss": total_loss.item() / len(train_dataloader), + "train_loss": total_loss.item() / total_samples, "epoch": epoch, "step": completed_steps, }, diff --git a/examples/pytorch/text-classification/run_classification.py b/examples/pytorch/text-classification/run_classification.py index 457ccc9001bf..573adbe46c81 100755 --- a/examples/pytorch/text-classification/run_classification.py +++ b/examples/pytorch/text-classification/run_classification.py @@ -412,8 +412,9 @@ def main(): # Trying to have good defaults here, don't hesitate to tweak to your needs. + label_feature = raw_datasets["train"].features["label"] is_regression = ( - raw_datasets["train"].features["label"].dtype in ["float32", "float64"] + getattr(label_feature, "dtype", None) in ["float32", "float64"] if data_args.do_regression is None else data_args.do_regression ) @@ -439,7 +440,7 @@ def main(): raise error else: # classification - if raw_datasets["train"].features["label"].dtype == "list": # multi-label classification + if isinstance(raw_datasets["train"].features["label"], datasets.Sequence): # multi-label classification is_multi_label = True logger.info("Label type is list, doing multi-label classification") # Trying to find the number of labels in a multi-label classification task diff --git a/examples/pytorch/translation/run_translation.py b/examples/pytorch/translation/run_translation.py index a705bc94a7f3..6fb8a786dc27 100755 --- a/examples/pytorch/translation/run_translation.py +++ b/examples/pytorch/translation/run_translation.py @@ -446,6 +446,9 @@ def main(): ) model.config.forced_bos_token_id = forced_bos_token_id + if hasattr(model, "generation_config") and model.generation_config is not None: + model.generation_config.forced_bos_token_id = forced_bos_token_id + # Get the language codes for input/target. 
 source_lang = data_args.source_lang.split("_")[0]
 target_lang = data_args.target_lang.split("_")[0]
diff --git a/scripts/check_tokenizers.py b/scripts/check_tokenizers.py
index 93d7fb5afdc6..cd136a67124c 100644
--- a/scripts/check_tokenizers.py
+++ b/scripts/check_tokenizers.py
@@ -10,37 +10,27 @@
 logging.set_verbosity_info()
 
+# Mapping of slow -> fast tokenizer classes
 TOKENIZER_CLASSES = {
     name: (getattr(transformers, name), getattr(transformers, name + "Fast")) for name in SLOW_TO_FAST_CONVERTERS
 }
 
-dataset = datasets.load_dataset("facebook/xnli", split="test+validation")  # no-script
+# Load a small English subset of XNLI for quick checks; use the "all_languages" config with the full "test+validation" split for an exhaustive run
+dataset = datasets.load_dataset("facebook/xnli", "en", split="test+validation[:10]")
 
-total = 0
-perfect = 0
-imperfect = 0
-wrong = 0
+total = perfect = imperfect = wrong = 0
 
 
 def check_diff(
     spm_diff: list[int], tok_diff: list[int], slow: PreTrainedTokenizerBase, fast: PreTrainedTokenizerBase
 ) -> bool:
     if spm_diff == list(reversed(tok_diff)):
-        # AAA -> AA+A vs A+AA case.
         return True
     elif len(spm_diff) == len(tok_diff) and fast.decode(spm_diff) == fast.decode(tok_diff):
-        # Second order OK
-        # Barrich -> Barr + ich vs Bar + rich
         return True
     spm_reencoded = slow.encode(slow.decode(spm_diff))
     tok_reencoded = fast.encode(fast.decode(spm_diff))
     if spm_reencoded != spm_diff and spm_reencoded == tok_reencoded:
-        # Type 3 error.
-        # Snehagatha ->
-        #       Sne, h, aga, th, a
-        #       Sne, ha, gat, ha
-        # Encoding the wrong with sp does not even recover what spm gave us
-        # It fits tokenizer however...
         return True
     return False
 
@@ -59,8 +49,6 @@ def check_LTR_mark(line: str, idx: int, fast: PreTrainedTokenizerBase) -> bool:
 def check_details(
     line: str, spm_ids: list[int], tok_ids: list[int], slow: PreTrainedTokenizerBase, fast: PreTrainedTokenizerBase
 ) -> bool:
-    # Encoding can be the same with same result AAA -> A + AA vs AA + A
-    # We can check that we use at least exactly the same number of tokens.
for i, (spm_id, tok_id) in enumerate(zip(spm_ids, tok_ids)): if spm_id != tok_id: break @@ -80,11 +68,9 @@ def check_details( return True if last - first > 5: - # We might have twice a single problem, attempt to subdivide the disjointed tokens into smaller problems spms = Counter(spm_ids[first:last]) toks = Counter(tok_ids[first:last]) - - removable_tokens = {spm_ for (spm_, si) in spms.items() if toks.get(spm_, 0) == si} + removable_tokens = {spm_ for spm_, si in spms.items() if toks.get(spm_, 0) == si} min_width = 3 for i in range(last - first - min_width): if all(spm_ids[first + i + j] in removable_tokens for j in range(min_width)): @@ -105,25 +91,11 @@ def check_details( ): return True - print(f"Spm: {[fast.decode([spm_ids[i]]) for i in range(first, last)]}") - try: - print(f"Tok: {[fast.decode([tok_ids[i]]) for i in range(first, last)]}") - except Exception as e: - print(f"Could not decode tok_ids: {e}") - - fast.decode(spm_ids[:first]) - fast.decode(spm_ids[last:]) - wrong = fast.decode(spm_ids[first:last]) - print() - print(wrong) return False def test_string(slow: PreTrainedTokenizerBase, fast: PreTrainedTokenizerBase, text: str) -> None: - global perfect - global imperfect - global wrong - global total + global perfect, imperfect, wrong, total slow_ids = slow.encode(text) fast_ids = fast.encode(text) @@ -140,9 +112,6 @@ def test_string(slow: PreTrainedTokenizerBase, fast: PreTrainedTokenizerBase, te else: perfect += 1 - if total % 10000 == 0: - print(f"({perfect} / {imperfect} / {wrong} ----- {perfect + imperfect + wrong})") - if skip_assert: return @@ -151,29 +120,51 @@ def test_string(slow: PreTrainedTokenizerBase, fast: PreTrainedTokenizerBase, te ) -def test_tokenizer(slow: PreTrainedTokenizerBase, fast: PreTrainedTokenizerBase) -> None: - global batch_total - for i in range(len(dataset)): - # premise, all languages - for text in dataset[i]["premise"].values(): - test_string(slow, fast, text) - - # hypothesis, all languages - for text in dataset[i]["hypothesis"]["translation"]: - test_string(slow, fast, text) +def test_tokenizer(slow, fast, dry_run=True): + global total, perfect, imperfect, wrong + total = perfect = imperfect = wrong = 0 + n_samples = 5 if dry_run else len(dataset) + for i in range(n_samples): + premise = dataset[i]["premise"] + hypothesis = dataset[i]["hypothesis"] + test_string(slow, fast, premise) + test_string(slow, fast, hypothesis) if __name__ == "__main__": + DEFAULT_CHECKPOINTS = { + "BertTokenizer": "bert-base-uncased", + "BertTokenizerFast": "bert-base-uncased", + "AlbertTokenizer": "albert-base-v2", + "AlbertTokenizerFast": "albert-base-v2", + "BartTokenizer": "facebook/bart-base", + "BartTokenizerFast": "facebook/bart-base", + "BarthezTokenizer": "facebook/barthez", + "DPRReaderTokenizer": "facebook/dpr-reader-single-nq-base", + "DPRReaderTokenizerFast": "facebook/dpr-reader-single-nq-base", + } + for name, (slow_class, fast_class) in TOKENIZER_CLASSES.items(): - checkpoint_names = list(slow_class.max_model_input_sizes.keys()) - for checkpoint in checkpoint_names: - imperfect = 0 - perfect = 0 - wrong = 0 - total = 0 + checkpoint = DEFAULT_CHECKPOINTS.get(name) + if checkpoint is None: + print(f"Skipping {name}: no compatible checkpoint defined") + continue + try: print(f"========================== Checking {name}: {checkpoint} ==========================") slow = slow_class.from_pretrained(checkpoint, force_download=True) fast = fast_class.from_pretrained(checkpoint, force_download=True) - test_tokenizer(slow, fast) - print(f"Accuracy {perfect * 
100 / total:.2f}") + + test_tokenizer(slow, fast, dry_run=True) + + if total > 0: + print(f"Accuracy {perfect * 100 / total:.2f}% ({perfect}/{total} perfect)") + else: + print("No samples tested.") + + except ImportError as e: + print(f"Skipping {name} due to missing dependency: {e}") + continue + except Exception as e: + print(f"Skipping {name} due to error: {e}") + continue diff --git a/setup.py b/setup.py index 42c865b1b9ba..a7c57f463852 100644 --- a/setup.py +++ b/setup.py @@ -93,7 +93,7 @@ "kenlm", "kernels>=0.12.0,<0.13", "librosa", - "mistral-common[image]>=1.10.0", + "mistral-common[image,audio]>=1.10.0", "nltk<=3.8.1", "num2words", "numpy>=1.17", @@ -165,6 +165,7 @@ "opentelemetry-api", "opentelemetry-exporter-otlp", "opentelemetry-sdk", + "requests", ] # This is a lookup table with items like: {"tokenizers": "tokenizers==0.9.4", "packaging": "packaging"}, i.e. @@ -192,7 +193,7 @@ def deps_list(*pkgs): extras["kernels"] = deps_list("kernels") extras["sentencepiece"] = deps_list("sentencepiece", "protobuf") extras["tiktoken"] = deps_list("tiktoken", "blobfile") -extras["mistral-common"] = deps_list("mistral-common[image]") +extras["mistral-common"] = deps_list("mistral-common[image,audio]") extras["chat_template"] = deps_list("jinja2", "jmespath") extras["sklearn"] = deps_list("scikit-learn") extras["accelerate"] = deps_list("accelerate") @@ -205,7 +206,9 @@ def deps_list(*pkgs): extras["ray"] = deps_list("ray[tune]") extras["integrations"] += extras["ray"] extras["codecarbon"] = deps_list("codecarbon") -extras["serving"] = deps_list("openai", "pydantic", "uvicorn", "fastapi", "starlette", "rich") + extras["torch"] +extras["serving"] = ( + deps_list("openai", "pydantic", "uvicorn", "fastapi", "starlette", "rich", "requests") + extras["torch"] +) extras["num2words"] = deps_list("num2words") extras["benchmark"] = deps_list("optimum-benchmark") extras["ja"] = deps_list("fugashi", "ipadic", "unidic_lite", "unidic", "rhoknp") diff --git a/src/transformers/audio_utils.py b/src/transformers/audio_utils.py index c89618f2d9cb..9f02d5146326 100644 --- a/src/transformers/audio_utils.py +++ b/src/transformers/audio_utils.py @@ -88,6 +88,12 @@ def load_audio(audio: str | np.ndarray, sampling_rate=16000, timeout=None) -> np # needed. Do not raise any errors if not installed or versions do not match if is_torchcodec_available() and version.parse("0.3.0") <= TORCHCODEC_VERSION: audio = load_audio_torchcodec(audio, sampling_rate=sampling_rate, timeout=timeout) + elif audio.rsplit("?", 1)[0].lower().endswith((".mp4", ".mkv", ".avi", ".mov", ".webm", ".flv", ".wmv")): + raise RuntimeError( + f"The audio source appears to be a video file ('{audio.split('/')[-1]}'). " + "librosa cannot decode video containers. " + "Install torchcodec>=0.3.0 (`pip install torchcodec`) to load audio from video files." 
+ ) else: audio = load_audio_librosa(audio, sampling_rate=sampling_rate, timeout=timeout) elif not isinstance(audio, np.ndarray): diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 95a47ae39fdf..da144bd8897a 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -353,6 +353,24 @@ def get_max_cache_shape(self) -> int: """Return the maximum cache shape of the cache""" return self.max_cache_len + def crop(self, max_length: int) -> None: + """Crop the cache to the given length.""" + if not self.is_initialized: + return + + current_length = self.cumulative_length.item() + + if max_length < 0: + raise ValueError(f"`max_length` passed to `StaticLayer.crop()` must be >= 0, got {max_length}.") + + if max_length >= current_length: + return + + self.keys[:, :, max_length:, :].zero_() + self.values[:, :, max_length:, :].zero_() + + self.cumulative_length.fill_(max_length) + class StaticSlidingWindowLayer(StaticLayer): """ @@ -531,6 +549,14 @@ def update( self._quantized_values = self._quantize(value_states.contiguous(), axis=self.axis_value) return key_states, value_states + # After reset, quantized data is cleared + if self._quantized_keys is None: + self._quantized_keys = self._quantize(key_states.contiguous(), axis=self.axis_key) + self._quantized_values = self._quantize(value_states.contiguous(), axis=self.axis_value) + self.keys = torch.tensor([], dtype=key_states.dtype, device=key_states.device) + self.values = torch.tensor([], dtype=key_states.dtype, device=key_states.device) + return key_states, value_states + dequant_keys = self._dequantize(self._quantized_keys) dequant_values = self._dequantize(self._quantized_values) keys_to_return = torch.cat([dequant_keys, self.keys, key_states], dim=-2) @@ -552,6 +578,11 @@ def _quantize(self, tensor, axis): ... @abstractmethod def _dequantize(self, q_tensor): ... + def reset(self) -> None: + super().reset() + self._quantized_keys = None + self._quantized_values = None + def get_seq_length(self) -> int: """Returns the sequence length of the cached states.""" return self.cumulative_length @@ -1337,6 +1368,17 @@ def __init__( offload_only_non_sliding: bool = True, **kwargs, ): + if kwargs: + raise TypeError(f"Unknown arguments passed to StaticCache: {list(kwargs.keys())}") + + if not isinstance(offloading, bool): + raise TypeError( + f"`offloading` must be a bool, got {type(offloading)}. " + "Did you accidentally pass `device` as a positional argument?" 
+ ) + if not isinstance(offload_only_non_sliding, bool): + raise TypeError(f"`offload_only_non_sliding` must be a bool, got {type(offload_only_non_sliding)}.") + config = config.get_text_config(decoder=True) layer_types = getattr(config, "layer_types", None) # If `layer_types` is not explicitly provided, infer if the model is fully sliding diff --git a/src/transformers/cli/serve.py b/src/transformers/cli/serve.py index 3d7c6a0c51ba..77fd7b134e01 100644 --- a/src/transformers/cli/serve.py +++ b/src/transformers/cli/serve.py @@ -150,6 +150,7 @@ def __init__( completion_handler=self._completion_handler, response_handler=self._response_handler, transcription_handler=self._transcription_handler, + generation_state=self._generation_state, enable_cors=enable_cors, ) diff --git a/src/transformers/cli/serving/server.py b/src/transformers/cli/serving/server.py index 13a9565db590..f3fc46e9ad1c 100644 --- a/src/transformers/cli/serving/server.py +++ b/src/transformers/cli/serving/server.py @@ -32,7 +32,7 @@ from .model_manager import ModelManager from .response import ResponseHandler from .transcription import TranscriptionHandler -from .utils import X_REQUEST_ID +from .utils import X_REQUEST_ID, CBWorkerDeadError, GenerationState logger = logging.get_logger(__name__) @@ -44,6 +44,7 @@ def build_server( completion_handler: CompletionHandler, response_handler: ResponseHandler, transcription_handler: TranscriptionHandler, + generation_state: GenerationState, enable_cors: bool = False, ) -> FastAPI: """Build and return a configured FastAPI application. @@ -52,6 +53,7 @@ def build_server( model_manager: Handles model loading, caching, and cleanup. chat_handler: Handles `/v1/chat/completions` requests. response_handler: Handles `/v1/responses` requests. + generation_state: Shared generation state, used by `/health` to report CB liveness. enable_cors: If `True`, adds permissive CORS middleware (allow all origins). Returns: @@ -65,6 +67,12 @@ async def lifespan(app: FastAPI): app = FastAPI(lifespan=lifespan) + @app.exception_handler(CBWorkerDeadError) + async def _cb_dead_handler(_request: Request, exc: CBWorkerDeadError): + # CB worker died (e.g. CUDA illegal memory access); reject new requests with 503 + # carrying the cause, instead of letting them hang in the input queue forever. + return JSONResponse({"error": str(exc)}, status_code=503) + if enable_cors: app.add_middleware( CORSMiddleware, @@ -128,6 +136,8 @@ def list_models(): @app.get("/health") def health(): + if not generation_state.is_cb_alive(): + return JSONResponse({"status": "unhealthy", "reason": "cb_worker_dead"}, status_code=503) return JSONResponse({"status": "ok"}) return app diff --git a/src/transformers/cli/serving/utils.py b/src/transformers/cli/serving/utils.py index d786a828fc28..165a56e8ddd7 100644 --- a/src/transformers/cli/serving/utils.py +++ b/src/transformers/cli/serving/utils.py @@ -73,6 +73,14 @@ class _GenerationCancelled(Exception): """Raised inside ``DirectStreamer.put()`` to abort ``model.generate()``.""" +class CBWorkerDeadError(RuntimeError): + """Raised when a request is submitted to a CB worker that has died. + + Surfaced as 503 by the FastAPI exception handler. Carries the original error message + that killed the worker so the client knows why the server is in this state. + """ + + # Fallback tool call configs for models that don't declare stc_token/etc_token/response_schema # on their tokenizer. # Keys are matched via substring against model_type (e.g. "qwen" matches "qwen2", "qwen3_vl", etc.). 
@@ -635,6 +643,21 @@ def init_cb(self, model: "PreTrainedModel", gen_config: "GenerationConfig") -> N ) self._cb.start() + def is_alive(self) -> bool: + """Whether the CB worker is healthy and able to serve new requests.""" + return self._cb is not None and self._cb.fatal_error is None + + def _check_alive(self, request_id: str) -> None: + """Raise :class:`CBWorkerDeadError` if the CB worker has died. + + Called at request entry to fail fast — submitting to a dead worker would otherwise + enqueue the request into a void where it never gets processed. + """ + if self._cb is not None and self._cb.fatal_error is not None: + raise CBWorkerDeadError( + f"CB worker is dead and cannot accept request {request_id}: {self._cb.fatal_error}" + ) + def generate_streaming( self, model: "PreTrainedModel", @@ -648,6 +671,7 @@ def generate_streaming( cb = self._cb if cb is None: raise RuntimeError("CB manager not initialized. Call `init_cb()` first.") + self._check_alive(request_id) loop = asyncio.get_running_loop() text_queue: asyncio.Queue = asyncio.Queue() @@ -669,7 +693,13 @@ def generate_streaming( def _on_output(output): try: streamer.put(output) - if output.is_finished(): + # ``error`` is set together with ``status = FAILED`` in CB's _handle_request_error. + # Surface it as an end-of-stream error so the SSE handler can emit it and close, + # instead of leaving the client hanging on a stream that will never end. + if output.error is not None: + text_queue.put_nowait(_StreamError(output.error)) + streamer.end() + elif output.is_finished(): streamer.end() except Exception as e: text_queue.put_nowait(_StreamError(str(e))) @@ -689,6 +719,7 @@ async def generate_non_streaming( cb = self._cb if cb is None: raise RuntimeError("CB manager not initialized. Call `init_cb()` first.") + self._check_alive(request_id) input_ids = inputs["input_ids"] input_len = len(input_ids) @@ -711,8 +742,16 @@ def _on_result(result): eos_token_id=gen_config.eos_token_id, ) result = await future - if result is None: - raise RuntimeError(f"CB manager stopped before producing a result for {request_id}") + # CB signals a failed request by setting ``error`` (and ``status = FAILED``) on the + # delivered GenerationOutput, often with empty ``generated_tokens``. Surface it instead + # of returning an empty success that downstream parsing/decoding would silently mask. + # If the worker itself died, route to CBWorkerDeadError so the client gets the same 503 + # as requests submitted post-crash; otherwise it's a per-request failure (e.g. unsupported + # logit-processor kwarg) and a plain RuntimeError -> 500 is appropriate. + if result.error is not None: + if self._cb.fatal_error is not None: + raise CBWorkerDeadError(f"CB worker died during request {request_id}: {result.error}") + raise RuntimeError(f"CB generation failed for {request_id}: {result.error}") generated_ids = result.generated_tokens text = processor.decode(generated_ids, skip_special_tokens=True) return text, input_len, generated_ids @@ -805,6 +844,12 @@ def shutdown(self) -> None: self._cb_manager.stop() self._cb_manager = None + def is_cb_alive(self) -> bool: + """Whether the CB worker is healthy. ``True`` if CB is disabled or not yet initialized.""" + if self._cb_manager is None: + return True + return self._cb_manager.is_alive() + class BaseHandler: """Shared logic for chat completion and responses handlers. 
diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index 2dcdc5333f35..073b23172251 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -21,7 +21,7 @@ from collections.abc import Sequence from dataclasses import MISSING, dataclass, fields from functools import wraps -from typing import TYPE_CHECKING, Any, ClassVar, Literal, TypeVar, Union +from typing import Any, ClassVar, Literal, TypeVar from huggingface_hub import create_repo from huggingface_hub.dataclasses import strict @@ -43,10 +43,7 @@ logging, ) from .utils.generic import is_timm_config_dict - - -if TYPE_CHECKING: - import torch +from .utils.type_validators import dtype_validator logger = logging.get_logger(__name__) @@ -229,7 +226,7 @@ class PreTrainedConfig(PushToHubMixin, RotaryEmbeddingConfigMixin): # Common attributes for all models output_hidden_states: bool | None = False return_dict: bool | None = True - dtype: Union[str, "torch.dtype"] | None = None + dtype: Any = dtype_validator(default=None) chunk_size_feed_forward: int = 0 is_encoder_decoder: bool = False @@ -1161,6 +1158,7 @@ def _remove_keys_not_serialized(self, d: dict[str, Any]) -> None: "ignore_keys_at_rope_validation", "base_model_tp_plan", "base_model_pp_plan", + "distributed_config", ]: d.pop(key_to_remove, None) diff --git a/src/transformers/conversion_mapping.py b/src/transformers/conversion_mapping.py index dadfeb4224ad..5a865164747c 100755 --- a/src/transformers/conversion_mapping.py +++ b/src/transformers/conversion_mapping.py @@ -62,19 +62,8 @@ "rt_detr_v2": "rt_detr", "pp_doclayout_v2": "rt_detr", "pp_doclayout_v3": "rt_detr", - "paligemma": "llava", - "aya_vision": "llava", - "got_ocr2": "llava", - "shieldgemma2": "llava", - "gemma3": "llava", - "internvl": "llava", - "llava_next_video": "llava_next", - "llava_onevision": "llava_next", - "vipllava": "llava", - "mistral3": "llava", "qwen2_5_vl": "qwen2_vl", "sam3_tracker_video": "sam3_tracker", - "pp_chart2table": "llava", "altclip_vision_model": "clip_vision_model", "chinese_clip_vision_model": "clip_vision_model", "clipseg_vision_model": "clip_vision_model", @@ -89,6 +78,32 @@ "siglip_text_model": "clip_text_model", "siglip2_text_model": "clip_text_model", "xclip_text_model": "clip_text_model", + "shield_gemma2": "llava", + "paligemma": "llava", + "aya_vision": "llava", + "got_ocr2": "llava", + "gemma3": "llava", + "internvl": "llava", + "vipllava": "llava", + "mistral3": "llava", + "pp_chart2table": "llava", + "llava_next_video": "llava_next", + "llava_onevision": "llava_next", + # class-based mappings + "PaliGemmaModel": "LlavaModel", + "AyaVisionModel": "LlavaModel", + "GotOcr2Model": "LlavaModel", + "Gemma3Model": "LlavaModel", + "InternVLModel": "LlavaModel", + "VipLlavaModel": "LlavaModel", + "Mistral3Model": "LlavaModel", + "PPChart2TableModel": "LlavaModel", + "LlavaNextModel": "LlavaModel", + "LlavaNextVideoModel": "LlavaModel", + "LlavaOnevisionModel": "LlavaModel", + "FuyuModel": "LlavaModel", + "MllamaModel": "LlavaModel", + "Qwen2_5_VLModel": "Qwen2VLModel", } @@ -97,42 +112,55 @@ def _build_checkpoint_conversion_mapping(): "altclip": [ WeightRenaming(source_patterns=r"layer\.", target_patterns="layers."), ], + "LlavaModel": [ + WeightRenaming(source_patterns=r"^language_model.model", target_patterns="language_model"), + ], "llava": [ - WeightRenaming(source_patterns=r"^language_model.model", target_patterns="model.language_model"), WeightRenaming(source_patterns=r"^language_model.lm_head", 
target_patterns="lm_head"), + WeightRenaming(source_patterns=r"^language_model", target_patterns="model.language_model"), WeightRenaming(source_patterns=r"^vision_tower", target_patterns="model.vision_tower"), WeightRenaming(source_patterns=r"^multi_modal_projector", target_patterns="model.multi_modal_projector"), ], "llava_next": [ - WeightRenaming(source_patterns=r"^language_model.model", target_patterns="model.language_model"), WeightRenaming(source_patterns=r"^language_model.lm_head", target_patterns="lm_head"), + WeightRenaming(source_patterns=r"^language_model", target_patterns="model.language_model"), WeightRenaming(source_patterns=r"^vision_tower", target_patterns="model.vision_tower"), WeightRenaming(source_patterns=r"^multi_modal_projector", target_patterns="model.multi_modal_projector"), WeightRenaming(source_patterns=r"^image_newline", target_patterns="model.image_newline"), ], - "clip_vision_model": [PrefixChange(prefix_to_remove="vision_model")], + "clip_vision_model": [ + PrefixChange(prefix_to_remove="vision_model"), + # Keep old CLIP-like checkpoints loadable after fixing the historical typo in module names. + WeightRenaming(source_patterns=r"layrnorm", target_patterns="layernorm"), + ], "clip_text_model": [PrefixChange(prefix_to_remove="text_model")], + "VideoLlavaModel": [ + WeightRenaming(source_patterns=r"^language_model.model", target_patterns="language_model"), + ], "video_llava": [ - WeightRenaming(source_patterns=r"^language_model.model", target_patterns="model.language_model"), WeightRenaming(source_patterns=r"^language_model.lm_head", target_patterns="lm_head"), + WeightRenaming(source_patterns=r"^language_model", target_patterns="model.language_model"), WeightRenaming(source_patterns=r"^image_tower", target_patterns="model.image_tower"), WeightRenaming(source_patterns=r"^video_tower", target_patterns="model.video_tower"), WeightRenaming(source_patterns=r"^multi_modal_projector", target_patterns="model.multi_modal_projector"), ], "fuyu": [ - WeightRenaming(source_patterns=r"^language_model.model", target_patterns="model.language_model"), WeightRenaming(source_patterns=r"^language_model.lm_head", target_patterns="lm_head"), + WeightRenaming(source_patterns=r"^language_model", target_patterns="model.language_model"), WeightRenaming(source_patterns=r"^vision_embed_tokens", target_patterns="model.vision_embed_tokens"), ], "mllama": [ - WeightRenaming(source_patterns=r"^language_model.model", target_patterns="model.language_model"), WeightRenaming(source_patterns=r"^language_model.lm_head", target_patterns="lm_head"), + WeightRenaming(source_patterns=r"^language_model", target_patterns="model.language_model"), WeightRenaming(source_patterns=r"^vision_model", target_patterns="model.vision_model"), WeightRenaming(source_patterns=r"^multi_modal_projector", target_patterns="model.multi_modal_projector"), ], + "Emu3Model": [ + WeightRenaming(source_patterns=r"^text_model.model", target_patterns="text_model"), + ], "emu3": [ - WeightRenaming(source_patterns=r"^text_model.model", target_patterns="model.text_model"), WeightRenaming(source_patterns=r"^text_model.lm_head", target_patterns="lm_head"), + WeightRenaming(source_patterns=r"^text_model", target_patterns="model.text_model"), WeightRenaming(source_patterns=r"^vqmodel", target_patterns="model.vqmodel"), ], "paddleocr_vl": [ @@ -143,15 +171,12 @@ def _build_checkpoint_conversion_mapping(): target_patterns="model.language_model", ), ], + "Qwen2VLModel": [WeightRenaming(source_patterns=r"^model.", target_patterns="")], 
"qwen2_vl": [ + WeightRenaming(source_patterns=r"^visual", target_patterns="model.visual"), WeightRenaming( source_patterns=r"(? None: + """ + Register a conversion mapping for a model type string or a class name. + + Class names take priority over ``model_type`` strings during lookup (see + :func:`extract_weight_conversions_for_model`), making it possible to define + task-head-specific or class-specific conversions that differ from the shared + ``model_type`` baseline. + """ global _checkpoint_conversion_mapping_cache if _checkpoint_conversion_mapping_cache is None: _checkpoint_conversion_mapping_cache = _build_checkpoint_conversion_mapping() - if model_type in _checkpoint_conversion_mapping_cache and not overwrite: - raise ValueError(f"Model type {model_type} already exists in the checkpoint conversion mapping.") - _checkpoint_conversion_mapping_cache[model_type] = mapping + if model_type_or_class_name in _checkpoint_conversion_mapping_cache and not overwrite: + raise ValueError( + f"Conversion mapping for '{model_type_or_class_name}' already exists. Pass overwrite=True to replace it." + ) + _checkpoint_conversion_mapping_cache[model_type_or_class_name] = mapping -def extract_weight_conversions_for_model(model: PreTrainedModel, model_prefix: str) -> list[WeightTransform] | None: +def extract_weight_conversions_for_model( + model: PreTrainedModel, +) -> list[WeightTransform] | None: + """ + Return the registered conversion list for ``model``, or ``None`` if none exists. + + Looks up by class name first (enables task-head-specific overrides), then + falls back to ``model.config.model_type``. Transforms are returned + unmodified; the caller sets ``scope_prefix`` on each transform for sub-module isolation. + """ + class_name = type(model).__name__ model_type = getattr(model.config, "model_type", None) - if model_type is not None: - model_specific_conversions = get_checkpoint_conversion_mapping(model_type) - # In this case, add the prefix to `PrefixChange` instances, in order to know where to add/remove the prefix - if model_specific_conversions is not None and model_prefix != "": - for i, conversion in enumerate(model_specific_conversions): - # In this case, add the prefix, as otherwise we don't know where we need to re-add it exactly in the module name chain - if isinstance(conversion, PrefixChange): - model_specific_conversions[i] = conversion.with_submodel_prefix(model_prefix) - return model_specific_conversions - return None + + # Class name takes priority — allows ForXxx-specific overrides + conversions = get_checkpoint_conversion_mapping(class_name) + if conversions is None and model_type is not None: + conversions = get_checkpoint_conversion_mapping(model_type) + return conversions def get_model_conversion_mapping( @@ -660,11 +696,17 @@ def get_model_conversion_mapping( add_legacy: bool = True, ) -> list[WeightTransform]: """ - For a given `model`, obtain the weight conversion mapping if any are registered either as a simple renaming - `_checkpoint_conversion_mapping` class argument, or in the general WeightConverter mapping. + Collect the ordered list of weight transforms for ``model`` (used during + loading and, when reversed, during saving). + + Each ``PreTrainedModel`` sub-module is looked up by class name then + ``model_type``. Root transforms are applied globally; sub-module transforms + have their ``scope_prefix`` set so they only match keys under that prefix. 
After any + sub-module is processed, both its class name and ``model_type`` are marked + seen to prevent ``XForY`` / ``XModel`` pairs from applying the same mapping + twice via different lookup paths. """ # Lazy import to avoid circular import issues - from .modeling_utils import PreTrainedModel # note: this function is used in PEFT, so changing the API requires coordination weight_conversions = [] @@ -673,16 +715,45 @@ def get_model_conversion_mapping( if key_mapping is not None: weight_conversions = [WeightRenaming(source_patterns=k, target_patterns=v) for k, v in key_mapping.items()] - # Model have several `PreTrainedModel` within with the same model type, for example: XForConditionalGeneration -> XModel - # We don't want to apply the same conversion pattern twice because of that - seen_model_types = set() - # Recurse over submodules and collect all conversions - for name, submodule in model.named_modules(): - if isinstance(submodule, PreTrainedModel) and submodule.config.model_type not in seen_model_types: - conversions = extract_weight_conversions_for_model(submodule, name) - if conversions is not None: - weight_conversions.extend(conversions) - seen_model_types.add(submodule.config.model_type) + seen_identifiers: set[str] = set() + + named_pretrained = getattr(model, "_named_pretrained_submodules", None) + if named_pretrained is None: + from .modeling_utils import PreTrainedModel + + named_pretrained = [(name, m) for name, m in model.named_modules() if isinstance(m, PreTrainedModel)] + for module_name, submodule in named_pretrained: + class_name = type(submodule).__name__ + model_type = getattr(submodule.config, "model_type", None) + + # Skip if this architecture was already processed via either lookup path. + if class_name in seen_identifiers or (model_type and model_type in seen_identifiers): + continue + + # Try class name first, then model_type. Track which path produced the hit so + # we know whether to block model_type for subsequent sub-modules (see below). + conversions = get_checkpoint_conversion_mapping(class_name) + found_via_class = conversions is not None + if not found_via_class and model_type is not None: + conversions = get_checkpoint_conversion_mapping(model_type) + + if conversions is None: + continue + + is_root_model = module_name == "" + if not is_root_model: + # Scope each transform so it only matches keys under this sub-module's prefix. + for transform in conversions: + transform.scope_prefix = module_name + weight_conversions.extend(conversions) + + seen_identifiers.add(class_name) + # Only block model_type when the hit was via model_type. When the hit was via + # class name, sub-modules that share the same model_type but have no class-specific + # mapping of their own (e.g. DetrModel under DetrForSegmentation) must still be + # reachable so their base transforms are picked up and scoped automatically. 
+ if not found_via_class and model_type: + seen_identifiers.add(model_type) if add_legacy: weight_conversions.extend(get_checkpoint_conversion_mapping("legacy")) diff --git a/src/transformers/core_model_loading.py b/src/transformers/core_model_loading.py index cd0710649c91..0e9e88dfbc83 100644 --- a/src/transformers/core_model_loading.py +++ b/src/transformers/core_model_loading.py @@ -591,6 +591,7 @@ class WeightTransform: "_original_source_patterns", "_original_target_patterns", "_was_used", + "scope_prefix", ) def __init__(self, source_patterns: str | list[str], target_patterns: str | list[str]): @@ -608,6 +609,9 @@ def __init__(self, source_patterns: str | list[str], target_patterns: str | list # Flag to notice if the Transform was used self._was_used = False + # Optional prefix scope: when set, this transform only applies to keys starting with + # ``scope_prefix + "."``, stripping / re-attaching the prefix around the pattern match. + self.scope_prefix: str | None = None # We need to process a few exceptions here when instantiating the reverse mapping (i.e. the targets become # sources, and sources become targets). The issues lie in the sources usually, so here we need to check the @@ -673,6 +677,27 @@ def add_tensor(self, target_key: str, source_key: str, source_pattern: str, futu self.collected_tensors[source_pattern].append(future) self.layer_targets[target_key].add(source_key) + def _scoped_match(self, source_key: str) -> tuple[str | None, str, re.Match[str]] | None: + """ + Apply ``scope_prefix`` stripping (if any), then match ``compiled_sources`` against the suffix. + + Returns ``(prefix_dot, key_to_match, match_object)`` when a branch matches, where ``prefix_dot`` is ``None`` + if ``scope_prefix`` is unset, else ``f"{scope_prefix}."``. Returns ``None`` when out of scope or unmatched. + Does not set ``_was_used``. + """ + prefix_dot = None + key_to_match = source_key + if self.scope_prefix is not None: + prefix_dot = self.scope_prefix + "." + if not source_key.startswith(prefix_dot): + return None + key_to_match = source_key[len(prefix_dot) :] + + match_object = self.compiled_sources.search(key_to_match) + if match_object is None: + return None + return (prefix_dot, key_to_match, match_object) + def rename_source_key(self, source_key: str) -> tuple[str, str | None]: """ Return a tuple (renamed_key, source_pattern_producing_the_match). @@ -680,11 +705,12 @@ def rename_source_key(self, source_key: str) -> tuple[str, str | None]: In case of a one-to-many transform, i.e. we have several target patterns, the matching source pattern will be replaced by the first of all the target patterns (they are then correctly expanded in the Operations). 
""" - # Try matching one of the alternation branches - match_object = self.compiled_sources.search(source_key) - if match_object is None: + matched = self._scoped_match(source_key) + if matched is None: return source_key, None + prefix_dot, key_to_match, match_object = matched + # We have a match, so the Transform was used self._was_used = True @@ -699,7 +725,9 @@ def rename_source_key(self, source_key: str) -> tuple[str, str | None]: # inside that matched named group replaced_group_idx = self.compiled_sources.groupindex[matching_group_name] + 1 replacement = replacement.replace(r"\1", match_object.group(replaced_group_idx)) - renamed_key = source_key.replace(match_object.group(0), replacement, 1) + renamed_key = key_to_match.replace(match_object.group(0), replacement, 1) + if prefix_dot is not None: + renamed_key = prefix_dot + renamed_key return renamed_key, source_pattern_that_matched def reverse_transform(self) -> WeightTransform: @@ -717,7 +745,7 @@ def reverse_transform(self) -> WeightTransform: reverse_transform = self.__class__( source_patterns=self._original_target_patterns, target_patterns=self._original_source_patterns, **kwargs ) - + reverse_transform.scope_prefix = self.scope_prefix return reverse_transform def materialize_tensors(self) -> dict[str, list[torch.Tensor]]: @@ -836,15 +864,11 @@ def reverse_transform(self) -> WeightTransform: raise ValueError("Cannot reverse the transform with TP or quantization") # Only one of the 2 can ever be used, so 1 is always None - return PrefixChange( + result = PrefixChange( prefix_to_add=self.prefix_to_remove, prefix_to_remove=self.prefix_to_add, model_prefix=self.model_prefix ) - - def with_submodel_prefix(self, prefix: str) -> PrefixChange: - new_prefix = f"{prefix}.{self.model_prefix}" if self.model_prefix != "" else prefix - return PrefixChange( - prefix_to_add=self.prefix_to_add, prefix_to_remove=self.prefix_to_remove, model_prefix=new_prefix - ) + result.scope_prefix = self.scope_prefix + return result # List of classes that are known to be able to use m:n @@ -1077,6 +1101,8 @@ def set_param_for_module( if ref is not None and param_value.shape != expected_shape and hf_quantizer is None: loading_info.mismatched_keys.add((target_name, param_value.shape, expected_shape)) else: + if distributed_operation is not None: + param_value = distributed_operation.post_shard_wrap(param_value) # super important otherwise _init_weight will re-init the param param_value._is_hf_initialized = True setattr(module_obj, param_name, param_value) @@ -1112,30 +1138,50 @@ class SkipParameters(Exception): def rename_source_key( source_key: str, - weight_renamings: list[WeightRenaming], - weight_converters: list[WeightConverter], + weight_transforms: list[WeightTransform], prefix: str | None = None, meta_state_dict: dict | None = None, ) -> tuple[str, str | None]: """ - Rename a source key given all the renaming and weight conversion patterns we have. Also takes care of adding/removing - the base model prefix during loading if necessary. + Rename a source key according to ``weight_transforms``, also handling the base model prefix. + + Transforms are applied in list order, interleaving ``WeightRenaming`` and ``WeightConverter`` + instances as they appear. The same list, reversed and with each transform individually + inverted, is used on the save path, so relative ordering is preserved in both directions. + + At most one ``WeightConverter`` fires per key; subsequent converters are skipped. 
+ ``WeightRenaming`` always runs, even after a converter has already fired. + + Example (root rename followed by a scoped sub-model converter):: + + transforms = [ + WeightRenaming("^old_prefix", "model.vlm"), + WeightConverter("^q_proj", "qkv_proj", ...), # scope_prefix="model.vlm" + ] + # Load: "old_prefix.q_proj" + # → WeightRenaming → "model.vlm.q_proj" + # → WeightConverter → "model.vlm.qkv_proj" + # + # Save (inverted list, each transform reversed): + # "model.vlm.q_proj" + # → rev(WeightConverter) → "model.vlm.q_proj" + # → rev(WeightRenaming) → "old_prefix.q_proj" """ renamed_key = source_key - # 1. apply all renamings in turns (if multiple match, it's the responsibility of the mappings to make sure they - # are coherent) - for renaming in weight_renamings: - renamed_key, _ = renaming.rename_source_key(renamed_key) - - # 2. apply renaming through weight conversions on the key if we have any WeightConverter (here we stop after - # the first match, as we assume only 1 converter can match any source key) source_pattern = None - for converter in weight_converters: - renamed_key, source_pattern = converter.rename_source_key(renamed_key) - if source_pattern is not None: - break - # 3. check if we need to add or remove prefix if necessary (only during loading, not saving) + for transform in weight_transforms: + if isinstance(transform, WeightConverter): + if source_pattern is not None: + # Already matched a converter; skip subsequent converters. + continue + renamed_key, sp = transform.rename_source_key(renamed_key) + if sp is not None: + source_pattern = sp + else: + renamed_key, _ = transform.rename_source_key(renamed_key) + + # check if we need to add or remove prefix if necessary (only during loading, not saving) if prefix is not None and meta_state_dict is not None: if ( renamed_key.startswith(prefix) @@ -1277,7 +1323,6 @@ def convert_and_load_state_dict_in_model( else: thread_pool = ThreadPoolExecutor(max_workers=GLOBAL_WORKERS) - renamings = [entry for entry in weight_mapping if isinstance(entry, WeightRenaming)] converters = [entry for entry in weight_mapping if isinstance(entry, WeightConverter)] param_name_to_load: dict[str, WeightRenaming | WeightConverter] = {} @@ -1292,13 +1337,11 @@ def convert_and_load_state_dict_in_model( state_dict = sorted(state_dict.items(), key=lambda kv: dot_natural_key(kv[0])) for original_key, tensor in state_dict: - # 1. Rename the key according to all renaming pattern and optional weight converter patterns - renamed_key, source_pattern = rename_source_key( - original_key, renamings, converters, prefix, meta_model_state_dict - ) + # 1. Rename the key according to all renaming and weight conversion patterns. + renamed_key, source_pattern = rename_source_key(original_key, weight_mapping, prefix, meta_model_state_dict) if renamed_key not in meta_model_state_dict and original_key in meta_model_state_dict: - # Key should probably not have been renamed but we might need the `prefix` to be added.` - renamed_key, source_pattern = rename_source_key(original_key, [], [], prefix, meta_model_state_dict) + # Key should probably not have been renamed but we might need the `prefix` to be added. + renamed_key, source_pattern = rename_source_key(original_key, [], prefix, meta_model_state_dict) # 2. 
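
A minimal, self-contained sketch of the dispatch loop described in the docstring above, using plain `(kind, pattern, replacement)` tuples as stand-ins for the real `WeightRenaming`/`WeightConverter` objects: renamings always run, while only the first matching converter fires.

```python
import re

# Stand-ins for the real transform classes.
transforms = [
    ("renaming", r"^old_prefix", "model.vlm"),
    ("converter", r"q_proj", "qkv_proj"),
    ("converter", r"k_proj", "qkv_proj"),  # skipped once a converter has already fired
]

def rename(key: str) -> tuple[str, str | None]:
    matched_converter_pattern = None
    for kind, pattern, replacement in transforms:
        if kind == "converter" and matched_converter_pattern is not None:
            continue  # at most one converter per key
        new_key, n_substitutions = re.subn(pattern, replacement, key, count=1)
        if n_substitutions and kind == "converter":
            matched_converter_pattern = pattern
        key = new_key
    return key, matched_converter_pattern

print(rename("old_prefix.q_proj.weight"))  # ('model.vlm.qkv_proj.weight', 'q_proj')
print(rename("old_prefix.mlp.weight"))     # ('model.vlm.mlp.weight', None)
```
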
finally, collect the tensor into the proper converter if renamed_key in meta_model_state_dict: @@ -1460,15 +1503,14 @@ def revert_weight_conversion(model: PreTrainedModel, state_dict: dict[str, torch # Reverse all Transform to correctly match keys reverse_weight_conversion = [conversion.reverse_transform() for conversion in weight_conversions] # If we are still here, we need to create the (reverse) conversion mapping from scratch - renamings = [entry for entry in reverse_weight_conversion if isinstance(entry, WeightRenaming)] converters = [entry for entry in reverse_weight_conversion if isinstance(entry, WeightConverter)] pattern_to_converter = {k: converter for converter in converters for k in converter.source_patterns} conversion_mapping = {} state_dict = sorted(state_dict.items(), key=lambda kv: dot_natural_key(kv[0])) for original_key, tensor in state_dict: - # Rename the key according to all renaming pattern and optional weight converter patterns - renamed_key, source_pattern = rename_source_key(original_key, renamings, converters) + renamed_key, source_pattern = rename_source_key(original_key, reverse_weight_conversion) + if source_pattern is not None: new_converter = deepcopy(pattern_to_converter[source_pattern]) # each target key gets its own converter instance diff --git a/src/transformers/data/data_collator.py b/src/transformers/data/data_collator.py index 8412ab5ae25a..aea74cee059b 100644 --- a/src/transformers/data/data_collator.py +++ b/src/transformers/data/data_collator.py @@ -1369,6 +1369,7 @@ class DataCollatorWithFlattening(DefaultDataCollator): - uses `separator_id` to separate sequences within the concatenated `labels`, default value is -100 - no padding will be added, returns `input_ids`, `labels` and `position_ids` by default - optionally returns the kwargs contained in FlashAttentionKwargs + - optionally returns `cu_seqlens` for FLA-style kernels - optionally returns seq_idx indicating which sequence each token belongs to @@ -1385,6 +1386,7 @@ def __init__( return_position_ids=True, separator_id=-100, return_flash_attn_kwargs=False, + return_cu_seqlens=False, return_seq_idx=False, **kwargs, ): @@ -1392,6 +1394,7 @@ def __init__( self.return_position_ids = return_position_ids self.separator_id = separator_id self.return_flash_attn_kwargs = return_flash_attn_kwargs + self.return_cu_seqlens = return_cu_seqlens self.return_seq_idx = return_seq_idx self._int_64_keys = {"labels", "position_ids", "input_ids"} self._batch_dim_keys = {"labels", "position_ids", "input_ids", "seq_idx"} @@ -1408,7 +1411,7 @@ def __call__(self, features, return_tensors=None, separator_id=None): batch.update({"position_ids": []}) if self.return_seq_idx: batch.update({"seq_idx": []}) - if self.return_flash_attn_kwargs: + if self.return_flash_attn_kwargs or self.return_cu_seqlens: cu_seq_lens = [0] max_length = 0 for seq_idx, sample in enumerate(features): @@ -1423,20 +1426,25 @@ def __call__(self, features, return_tensors=None, separator_id=None): # Convert to list if tensor if hasattr(labels, "tolist"): labels = labels.tolist() - batch["labels"] += [separator_id] + labels[1:] + if isinstance(labels, (list, tuple)): + batch["labels"] += [separator_id] + labels[1:] + else: + batch["labels"] += [labels] * len(input_ids) else: batch["labels"] += [separator_id] + input_ids[1:] if self.return_position_ids: batch["position_ids"] += list(range(len(input_ids))) if self.return_seq_idx: batch["seq_idx"] += [seq_idx for _ in range(len(input_ids))] - if self.return_flash_attn_kwargs: + if 
self.return_flash_attn_kwargs or self.return_cu_seqlens: cu_seq_lens.append(cu_seq_lens[-1] + len(input_ids)) max_length = max(max_length, len(input_ids)) if self.return_flash_attn_kwargs: batch["cu_seq_lens_q"] = batch["cu_seq_lens_k"] = cu_seq_lens batch["max_length_q"] = batch["max_length_k"] = max_length + if self.return_cu_seqlens: + batch["cu_seqlens"] = cu_seq_lens # FlashAttentionKwargs and seq_idx are expected to be int32s. if return_tensors == "pt": diff --git a/src/transformers/debug_utils.py b/src/transformers/debug_utils.py index 38ff0399641b..ae44ef1eb899 100644 --- a/src/transformers/debug_utils.py +++ b/src/transformers/debug_utils.py @@ -155,7 +155,7 @@ def __init__(self, model, max_frames_to_save=21, trace_batch_nums=None, abort_af self.batch_number = 0 self.total_calls = 0 self.detected_overflow = False - self.prefix = " " + self.prefix = " " * 17 self.analyse_model() diff --git a/src/transformers/dynamic_module_utils.py b/src/transformers/dynamic_module_utils.py index 9c9e7b929f6f..4598a6760090 100644 --- a/src/transformers/dynamic_module_utils.py +++ b/src/transformers/dynamic_module_utils.py @@ -311,6 +311,42 @@ def get_class_in_module( return getattr(module, class_name) +def _compute_local_source_files_hash( + pretrained_model_name_or_path: str | os.PathLike, + module_file: str | os.PathLike, + resolved_module_file: str | os.PathLike, + modules_needed: list[str], +) -> str: + """ + Computes a stable hash from the bytes of the local source file and its direct relative-import source files. + """ + model_path = Path(pretrained_model_name_or_path).resolve() + module_parent = Path(module_file).parent + + resolved_module_file = Path(resolved_module_file).resolve() + + def _resolve_relative_source_path(source_file_path: Path) -> str: + try: + return source_file_path.relative_to(model_path).as_posix() + except ValueError: + # Fallback for edge cases where the source file is not under the local model directory. + return source_file_path.as_posix() + + files_to_hash = [ + (_resolve_relative_source_path(resolved_module_file), resolved_module_file), + ] + for module_needed in modules_needed: + module_needed_path = (model_path / module_parent / f"{module_needed}.py").resolve() + files_to_hash.append((_resolve_relative_source_path(module_needed_path), module_needed_path)) + + source_files_hash = hashlib.sha256() + for relative_path, file_path in sorted(files_to_hash, key=lambda entry: entry[0]): + source_files_hash.update(relative_path.encode("utf-8")) + source_files_hash.update(file_path.read_bytes()) + + return source_files_hash.hexdigest()[:16] + + def get_cached_module_file( pretrained_model_name_or_path: str | os.PathLike, module_file: str, @@ -374,11 +410,10 @@ def get_cached_module_file( local_files_only = True # Download and cache module_file from the repo `pretrained_model_name_or_path` of grab it if it's a local file. 
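
Stepping back to the `DataCollatorWithFlattening` hunk above, a small worked example of what the cumulative sequence lengths look like when two samples are packed into one flattened row (plain Python, no collator instance needed):

```python
# Two samples to be flattened into one packed row.
features = [
    {"input_ids": [101, 7, 8, 102]},       # length 4
    {"input_ids": [101, 9, 10, 11, 102]},  # length 5
]

input_ids, position_ids, cu_seqlens = [], [], [0]
for sample in features:
    ids = sample["input_ids"]
    input_ids += ids
    position_ids += list(range(len(ids)))         # positions restart per sequence
    cu_seqlens.append(cu_seqlens[-1] + len(ids))  # running boundary of each sequence

print(input_ids)     # [101, 7, 8, 102, 101, 9, 10, 11, 102]
print(position_ids)  # [0, 1, 2, 3, 0, 1, 2, 3, 4]
print(cu_seqlens)    # [0, 4, 9] -> sequence i spans input_ids[cu_seqlens[i]:cu_seqlens[i+1]]
```
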
- pretrained_model_name_or_path = str(pretrained_model_name_or_path) + pretrained_model_name_or_path = str(pretrained_model_name_or_path).rstrip(os.sep) is_local = os.path.isdir(pretrained_model_name_or_path) - if is_local: - submodule = _sanitize_module_name(os.path.basename(pretrained_model_name_or_path)) - else: + cached_module = None + if not is_local: submodule = os.path.sep.join(map(_sanitize_module_name, pretrained_model_name_or_path.split("/"))) cached_module = try_to_load_from_cache( pretrained_model_name_or_path, module_file, cache_dir=cache_dir, revision=_commit_hash, repo_type=repo_type @@ -408,19 +443,28 @@ def get_cached_module_file( # Check we have all the requirements in our environment modules_needed = check_imports(resolved_module_file) + if is_local: + local_model_name = _sanitize_module_name(os.path.basename(os.path.normpath(pretrained_model_name_or_path))) + local_source_files_hash = _compute_local_source_files_hash( + pretrained_model_name_or_path, module_file, resolved_module_file, modules_needed + ) + if local_model_name: + submodule = os.path.sep.join([local_model_name, local_source_files_hash]) + else: + submodule = local_source_files_hash # Now we move the module inside our cached dynamic modules. full_submodule = TRANSFORMERS_DYNAMIC_MODULE_NAME + os.path.sep + submodule create_dynamic_module(full_submodule) submodule_path = Path(HF_MODULES_CACHE) / full_submodule - if submodule == _sanitize_module_name(os.path.basename(pretrained_model_name_or_path)): + if is_local: # We copy local files to avoid putting too many folders in sys.path. This copy is done when the file is new or # has changed since last copy. if not (submodule_path / module_file).exists() or not filecmp.cmp( resolved_module_file, str(submodule_path / module_file) ): (submodule_path / module_file).parent.mkdir(parents=True, exist_ok=True) - shutil.copy(resolved_module_file, submodule_path / module_file) + shutil.copyfile(resolved_module_file, submodule_path / module_file) importlib.invalidate_caches() for module_needed in modules_needed: module_needed = Path(module_file).parent / f"{module_needed}.py" @@ -428,7 +472,7 @@ def get_cached_module_file( if not (submodule_path / module_needed).exists() or not filecmp.cmp( module_needed_file, str(submodule_path / module_needed) ): - shutil.copy(module_needed_file, submodule_path / module_needed) + shutil.copyfile(module_needed_file, submodule_path / module_needed) importlib.invalidate_caches() else: # Get the commit hash @@ -442,7 +486,7 @@ def get_cached_module_file( create_dynamic_module(Path(full_submodule_module_file_path).parent) if not (submodule_path / module_file).exists(): - shutil.copy(resolved_module_file, submodule_path / module_file) + shutil.copyfile(resolved_module_file, submodule_path / module_file) importlib.invalidate_caches() # Make sure we also have every file with relative for module_needed in modules_needed: @@ -647,13 +691,13 @@ def _set_auto_map_in_config(_config): # Copy module file to the output folder. object_file = sys.modules[obj.__module__].__file__ dest_file = Path(folder) / (Path(object_file).name) - shutil.copy(object_file, dest_file) + shutil.copyfile(object_file, dest_file) result.append(dest_file) # Gather all relative imports recursively and make sure they are copied as well. 
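
A minimal sketch of the idea behind `_compute_local_source_files_hash` above, using a hypothetical local repo in a temporary directory: the cache key mixes each file's repo-relative path and raw bytes, so editing the custom module or any of its relative imports yields a new submodule directory.

```python
import hashlib
import tempfile
from pathlib import Path

def source_files_hash(model_dir: Path, files: list[Path]) -> str:
    """Stable 16-hex-char digest over (relative path, file bytes) pairs, sorted by path."""
    digest = hashlib.sha256()
    for relative_path, file_path in sorted((f.relative_to(model_dir).as_posix(), f) for f in files):
        digest.update(relative_path.encode("utf-8"))
        digest.update(file_path.read_bytes())
    return digest.hexdigest()[:16]

with tempfile.TemporaryDirectory() as tmp:
    model_dir = Path(tmp)
    modeling = model_dir / "modeling_custom.py"       # hypothetical custom-code files
    config = model_dir / "configuration_custom.py"
    modeling.write_text("print('v1')")
    config.write_text("HIDDEN_SIZE = 16")

    before = source_files_hash(model_dir, [modeling, config])
    modeling.write_text("print('v2')")                # editing any file changes the hash
    after = source_files_hash(model_dir, [modeling, config])
    assert before != after
```
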
for needed_file in get_relative_import_files(object_file): dest_file = Path(folder) / (Path(needed_file).name) - shutil.copy(needed_file, dest_file) + shutil.copyfile(needed_file, dest_file) result.append(dest_file) return result diff --git a/src/transformers/feature_extraction_utils.py b/src/transformers/feature_extraction_utils.py index f69b3fdfd9b0..e9840a1fd3a1 100644 --- a/src/transformers/feature_extraction_utils.py +++ b/src/transformers/feature_extraction_utils.py @@ -32,6 +32,8 @@ TensorType, _is_tensor_or_array_like, copy_func, + is_mlx_array, + is_mlx_available, is_numpy_array, is_torch_available, is_torch_device, @@ -142,6 +144,26 @@ def as_tensor(value): return torch.tensor(value) is_tensor = torch.is_tensor + + elif tensor_type == TensorType.MLX: + if not is_mlx_available(): + raise ImportError("Unable to convert output to MLX tensors format, MLX is not installed.") + import mlx.core as mx + + def as_tensor(value): + if isinstance(value, (list, tuple)) and len(value) > 0: + if isinstance(value[0], np.ndarray): + value = np.array(value) + elif ( + isinstance(value[0], (list, tuple)) + and len(value[0]) > 0 + and isinstance(value[0][0], np.ndarray) + ): + value = np.array(value) + return mx.array(value) + + is_tensor = is_mlx_array + else: def as_tensor(value, dtype=None): diff --git a/src/transformers/generation/continuous_batching/continuous_api.py b/src/transformers/generation/continuous_batching/continuous_api.py index 459dcfc1c2fa..6121b57909b8 100644 --- a/src/transformers/generation/continuous_batching/continuous_api.py +++ b/src/transformers/generation/continuous_batching/continuous_api.py @@ -943,7 +943,6 @@ def cancel_request(self, request_id: str) -> None: if self.batch_processor is not None: self.batch_processor.scheduler.set_request_cancellation(request_id) - # TODO:handle benchmarking properly when updating / fixing the requeue logic def get_result(self, request_id: str | None = None, timeout: float | None = None) -> GenerationOutput | None: """Retrieve one result from the output queue. 
@@ -956,14 +955,28 @@ def get_result(self, request_id: str | None = None, timeout: float | None = None """ if self._generation_thread is None and self.output_router.output_queue.empty(): return None + + deadline = None if timeout is None else perf_counter() + timeout + deferred: list[GenerationOutput] = [] + try: - result = self.output_router.output_queue.get(block=True, timeout=timeout) - if request_id is not None and result.request_id != request_id: - self.output_router.output_queue.put(result) - return None - return result - except queue.Empty: - return None + while True: + remaining = None if deadline is None else max(0.0, deadline - perf_counter()) + if remaining == 0.0: + return None + + try: + result = self.output_router.output_queue.get(timeout=remaining) + except queue.Empty: + return None + + if request_id is None or result.request_id == request_id: + return result + + deferred.append(result) + finally: + for item in deferred: + self.output_router.output_queue.put(item) def __iter__(self): """Iterate over results as they become available.""" @@ -980,11 +993,16 @@ def request_id_iter(self, request_id: str) -> Generator[GenerationOutput]: """ while self._generation_thread is not None and self._generation_thread.is_alive(): result = self.get_result(request_id=request_id, timeout=0.1) + if result is not None: yield result if result.is_finished(): return + if self.batch_processor is not None: + if self.batch_processor.scheduler.request_is_cancelled(request_id): + return + def register_result_handler(self, request_id: str, callback: Callable) -> None: """Register a callback for result delivery (streaming or non-streaming). diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py index 9c47e551cee8..3d0c70dd7413 100644 --- a/src/transformers/generation/logits_process.py +++ b/src/transformers/generation/logits_process.py @@ -1005,7 +1005,14 @@ def __init__( @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + if torch.isneginf(scores).all(dim=-1).any(): + raise ValueError( + "EtaLogitsWarper received a row with all logits set to -inf. " + "This usually means previous logits processors masked every token." + ) + probabilities = scores.softmax(dim=-1) + entropy = torch.distributions.Categorical(logits=scores).entropy() eta = torch.min(self.epsilon, torch.sqrt(self.epsilon) * torch.exp(-entropy))[..., None] indices_to_remove = probabilities < eta @@ -1661,13 +1668,22 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to class InfNanRemoveLogitsProcessor(LogitsProcessor): r""" - [`LogitsProcessor`] that removes all `nan` and `inf` values to avoid the generation method to fail. Note that using - the logits processor should only be used if necessary since it can slow down the generation method. + [`LogitsProcessor`] that removes all `nan` and `inf` values to avoid the generation method to fail. This version + has been extended to sanitize both logits and hidden state output tensors to handle instabilities in very wide + models or ones sharded across many devices. + + Note that using the logits processor should only be used if necessary since it can slow down the generation method. This logits processor has no `generate` example, as there shouldn't be a correct combination of flags that warrants - its use. + its use. 
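
A stripped-down sketch of the waiting pattern used by the new `get_result` above, with a plain `queue.Queue` standing in for the output router: results belonging to other requests are set aside while waiting, and put back in a `finally` block so they are never lost.

```python
import queue
from time import perf_counter

def get_matching(q: queue.Queue, request_id: str | None, timeout: float | None):
    deadline = None if timeout is None else perf_counter() + timeout
    deferred = []
    try:
        while True:
            remaining = None if deadline is None else max(0.0, deadline - perf_counter())
            if remaining == 0.0:
                return None
            try:
                result = q.get(timeout=remaining)
            except queue.Empty:
                return None
            if request_id is None or result["request_id"] == request_id:
                return result
            deferred.append(result)  # belongs to another request, set it aside
    finally:
        for item in deferred:        # requeue everything we skipped over
            q.put(item)

q = queue.Queue()
q.put({"request_id": "a", "text": "hello"})
q.put({"request_id": "b", "text": "world"})
print(get_matching(q, "b", timeout=0.5))  # {'request_id': 'b', 'text': 'world'}
print(q.qsize())                          # 1 -> the result for "a" was requeued
```
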
However, when dealing with sharded models across many GPUs or models with very wide hidden dimensions that + can produce unstable values, setting `remove_invalid_values=True` in generation config will activate this processor + automatically. """ + def __init__(self, hidden_states_aware=True): + # Flag to control whether we also want to clean hidden states + self.hidden_states_aware = hidden_states_aware + @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: # set all nan values to 0.0 diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 388cef73566a..16aa5a7ff17e 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1086,9 +1086,31 @@ def _get_logits_processor( UserWarning, ) if generation_config.repetition_penalty is not None and generation_config.repetition_penalty != 1.0: - processors.append(RepetitionPenaltyLogitsProcessor(penalty=generation_config.repetition_penalty)) + if self.config.is_encoder_decoder: + processors.append(RepetitionPenaltyLogitsProcessor(penalty=generation_config.repetition_penalty)) + else: + inputs_embeds = model_kwargs.get("inputs_embeds") if model_kwargs is not None else None + if inputs_embeds is not None and (input_ids_seq_length is None or input_ids_seq_length == 0): + warnings.warn( + "Passing `repetition_penalty` requires some form of `input_ids` to be passed to " + "`generate`, ignoring the argument.", + UserWarning, + ) + else: + processors.append(RepetitionPenaltyLogitsProcessor(penalty=generation_config.repetition_penalty)) if generation_config.no_repeat_ngram_size is not None and generation_config.no_repeat_ngram_size > 0: - processors.append(NoRepeatNGramLogitsProcessor(generation_config.no_repeat_ngram_size)) + if self.config.is_encoder_decoder: + processors.append(NoRepeatNGramLogitsProcessor(generation_config.no_repeat_ngram_size)) + else: + inputs_embeds = model_kwargs.get("inputs_embeds") if model_kwargs is not None else None + if inputs_embeds is not None and (input_ids_seq_length is None or input_ids_seq_length == 0): + warnings.warn( + "Passing `no_repeat_ngram_size` requires some form of `input_ids` to be passed to " + "`generate`, ignoring the argument.", + UserWarning, + ) + else: + processors.append(NoRepeatNGramLogitsProcessor(generation_config.no_repeat_ngram_size)) if ( generation_config.encoder_no_repeat_ngram_size is not None and generation_config.encoder_no_repeat_ngram_size > 0 @@ -1720,12 +1742,50 @@ def _prepare_generation_config( "parameters explicitly, but not both.", ) + # Safety: if the model is sharded across multiple devices (hf_device_map/device_map) and we are + # doing sampling, enable `remove_invalid_values` by default to avoid NaN/Inf logits causing CUDA + # asserts during multinomial sampling. Users can still override this by passing the flag explicitly. 
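
The reason the two processors above are skipped when only `inputs_embeds` are available: a repetition penalty has to look up the scores of previously generated token ids, which requires `input_ids`. A minimal sketch of that lookup (not the library implementation, just the gather/scatter idea):

```python
import torch

def apply_repetition_penalty(input_ids: torch.LongTensor, scores: torch.FloatTensor, penalty: float):
    # Gather the logits of tokens already present in the sequence...
    seen = torch.gather(scores, 1, input_ids)
    # ...and push them away from being re-sampled (divide positives, multiply negatives).
    seen = torch.where(seen < 0, seen * penalty, seen / penalty)
    return scores.scatter(1, input_ids, seen)

scores = torch.tensor([[2.0, -1.0, 0.5]])
input_ids = torch.tensor([[0, 2]])  # tokens 0 and 2 were already generated
print(apply_repetition_penalty(input_ids, scores, penalty=1.5))
# tensor([[ 1.3333, -1.0000,  0.3333]])
```
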
+ try: + is_sharded_map = False + hf_map = getattr(self, "hf_device_map", None) + if hf_map is not None and isinstance(hf_map, dict) and len(set(hf_map.values())) > 1: + # consider sharded if more than one device (excluding "cpu"/"disk") + devices = set(hf_map.values()) + gpu_devices = {d for d in devices if d not in {"cpu", "disk"}} + if len(gpu_devices) > 1: + is_sharded_map = True + + # also accept legacy `device_map` attribute or accelerate hooks + device_map_attr = getattr(self, "device_map", None) + if not is_sharded_map and device_map_attr is not None: + # device_map can be a dict mapping module->device or other structures; if it's a dict and maps + # to multiple cuda devices, consider it sharded + if isinstance(device_map_attr, dict) and len(set(device_map_attr.values())) > 1: + devices = set(device_map_attr.values()) + gpu_devices = {d for d in devices if d not in {"cpu", "disk"}} + if len(gpu_devices) > 1: + is_sharded_map = True + + if is_sharded_map and generation_config.do_sample and generation_config.remove_invalid_values is False: + generation_config.remove_invalid_values = True + logger.info( + "Enabling `remove_invalid_values=True` for sharded sampling to avoid NaN/Inf logits during sampling." + ) + except Exception as exception: + # never fail generation config preparation due to best-effort safety check + logger.debug("Failed to detect sharded generation setup: %s", exception) # Finally keep output_xxx args in `model_kwargs` so it can be passed to `forward` output_attentions = generation_config.output_attentions output_hidden_states = generation_config.output_hidden_states model_kwargs.update({"output_attentions": output_attentions} if output_attentions else {}) model_kwargs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {}) + # Enforce deterministic greedy decoding if do_sample=False and num_beams = 1 + if generation_config.do_sample is False and generation_config.num_beams == 1: + generation_config.temperature = 1.0 + generation_config.top_k = 0 + generation_config.top_p = 1.0 + return generation_config, model_kwargs def _prepare_static_cache( @@ -1993,6 +2053,10 @@ def _tensor_or_none(token, device=None): generation_config._pad_token_tensor = pad_token_tensor generation_config._decoder_start_token_tensor = decoder_start_token_tensor + def _is_dynamo_compilation_disabled(self) -> bool: + """Check standard environment variables that explicitly disable torch.dynamo compilation.""" + return os.getenv("TORCHDYNAMO_DISABLE", "").lower() in {"1", "true", "yes", "on"} + def _valid_auto_compile_criteria( self: "GenerativePreTrainedModel", model_kwargs: dict[str, Any], generation_config: GenerationConfig ) -> bool: @@ -2003,6 +2067,9 @@ def _valid_auto_compile_criteria( if generation_config.disable_compile: return False + if self._is_dynamo_compilation_disabled(): + return False + cache = model_kwargs.get("past_key_values", model_kwargs.get("cache_params")) # Base logic @@ -2969,9 +3036,17 @@ def _get_top_k_continuations( # Gather the top K scores from _all_ beams. 
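
The device-map check added above boils down to "more than one non-CPU, non-disk device in the map". A standalone sketch of that predicate, evaluated on hypothetical device maps:

```python
def is_sharded(device_map: dict | None) -> bool:
    """True when the map places modules on more than one accelerator device."""
    if not isinstance(device_map, dict):
        return False
    gpu_devices = {d for d in device_map.values() if d not in {"cpu", "disk"}}
    return len(gpu_devices) > 1

print(is_sharded({"model.embed_tokens": 0, "model.layers.0": 0, "model.layers.1": 1}))  # True
print(is_sharded({"": 0}))                                   # False -> single device
print(is_sharded({"model.layers.0": 0, "lm_head": "cpu"}))   # False -> one GPU + CPU offload
```
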
if do_sample: - topk_indices = torch.multinomial( - nn.functional.softmax(accumulated_log_probs, dim=-1), num_samples=beams_to_keep - ) + # Handle potential NaN values in accumulated_log_probs + probs = nn.functional.softmax(accumulated_log_probs, dim=-1) + # Replace NaN values with uniform distribution + if torch.isnan(probs).any(): + # Create a mask for NaN positions + nan_mask = torch.isnan(probs) + # Replace NaN with a small uniform probability + probs = torch.where(nan_mask, torch.ones_like(probs) / probs.shape[-1], probs) + # Renormalize to ensure probabilities sum to 1 + probs = probs / probs.sum(dim=-1, keepdim=True) + topk_indices = torch.multinomial(probs, num_samples=beams_to_keep) topk_log_probs = torch.gather(input=accumulated_log_probs, dim=1, index=topk_indices) else: topk_log_probs, topk_indices = torch.topk(accumulated_log_probs, k=beams_to_keep) diff --git a/src/transformers/hf_argparser.py b/src/transformers/hf_argparser.py index b9e6f99b041d..32def64950b5 100644 --- a/src/transformers/hf_argparser.py +++ b/src/transformers/hf_argparser.py @@ -175,10 +175,19 @@ def _parse_dataclass_field(parser: ArgumentParser, field: dataclasses.Field): " the argument parser only supports one type per argument." f" Problem encountered in field '{field.name}'." ) + # filter `dict` in Union because argparse does not support it + if dict in field.type.__args__: + remaining_types = tuple(arg for arg in field.type.__args__ if arg is not dict) + field.type = remaining_types[0] + for remaining_type in remaining_types[1:]: + field.type |= remaining_type if type(None) not in field.type.__args__: - # filter `str` in Union - field.type = field.type.__args__[0] if field.type.__args__[1] is str else field.type.__args__[1] - origin_type = getattr(field.type, "__origin__", field.type) + if len(field.type.__args__) > 2: + origin_type = str + else: + # filter `str` in Union + field.type = field.type.__args__[0] if field.type.__args__[1] is str else field.type.__args__[1] + origin_type = getattr(field.type, "__origin__", field.type) elif bool not in field.type.__args__: # filter `NoneType` in Union (except for `Union[bool, NoneType]`) field.type = ( diff --git a/src/transformers/image_processing_utils.py b/src/transformers/image_processing_utils.py index 704001c476a6..74069f93aff6 100644 --- a/src/transformers/image_processing_utils.py +++ b/src/transformers/image_processing_utils.py @@ -675,6 +675,10 @@ def get_patch_output_size(image, target_resolution, input_data_format): original_height, original_width = get_image_size(image, channel_dim=input_data_format) target_height, target_width = target_resolution + if original_width == 0: + raise ValueError("original_width can not be 0") + if original_height == 0: + raise ValueError("original_height can not be 0") scale_w = target_width / original_width scale_h = target_height / original_height diff --git a/src/transformers/image_transforms.py b/src/transformers/image_transforms.py index 88160d1bced3..9ea4bfed897e 100644 --- a/src/transformers/image_transforms.py +++ b/src/transformers/image_transforms.py @@ -15,7 +15,7 @@ from collections import defaultdict from collections.abc import Collection, Iterable from math import ceil -from typing import Optional, Union +from typing import Any, Optional, Union, overload import numpy as np @@ -26,7 +26,7 @@ get_image_size, infer_channel_dimension_format, ) -from .utils import ExplicitEnum, TensorType, is_torch_tensor +from .utils import ExplicitEnum, is_torch_tensor from .utils.import_utils import ( 
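
A self-contained demonstration of the NaN handling above: rows whose softmax produced NaNs are replaced by a uniform distribution and renormalized before `torch.multinomial` is called.

```python
import torch

probs = torch.tensor([
    [0.7, 0.2, 0.1],
    [float("nan"), float("nan"), float("nan")],  # e.g. softmax over an all -inf row
])

nan_mask = torch.isnan(probs)
probs = torch.where(nan_mask, torch.ones_like(probs) / probs.shape[-1], probs)
probs = probs / probs.sum(dim=-1, keepdim=True)  # renormalize so each row sums to 1

samples = torch.multinomial(probs, num_samples=2)
print(probs)    # second row is now uniform: [1/3, 1/3, 1/3]
print(samples)  # valid indices for both rows instead of an error / CUDA assert
```
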
is_torch_available, is_vision_available, @@ -547,7 +547,15 @@ def _center_to_corners_format_numpy(bboxes_center: np.ndarray) -> np.ndarray: # 2 functions below inspired by https://github.com/facebookresearch/detr/blob/master/util/box_ops.py -def center_to_corners_format(bboxes_center: TensorType) -> TensorType: +@overload +def center_to_corners_format(bboxes_center: "torch.Tensor") -> "torch.Tensor": ... + + +@overload +def center_to_corners_format(bboxes_center: np.ndarray) -> np.ndarray: ... + + +def center_to_corners_format(bboxes_center: "torch.Tensor | np.ndarray") -> Any: """ Converts bounding boxes from center format to corners format. @@ -590,7 +598,15 @@ def _corners_to_center_format_numpy(bboxes_corners: np.ndarray) -> np.ndarray: return bboxes_center -def corners_to_center_format(bboxes_corners: TensorType) -> TensorType: +@overload +def corners_to_center_format(bboxes_corners: "torch.Tensor") -> "torch.Tensor": ... + + +@overload +def corners_to_center_format(bboxes_corners: np.ndarray) -> np.ndarray: ... + + +def corners_to_center_format(bboxes_corners: "torch.Tensor | np.ndarray") -> Any: """ Converts bounding boxes from corners format to center format. diff --git a/src/transformers/image_utils.py b/src/transformers/image_utils.py index 984d80964fad..8ed1ab73af1f 100644 --- a/src/transformers/image_utils.py +++ b/src/transformers/image_utils.py @@ -17,6 +17,7 @@ from collections.abc import Iterable from dataclasses import dataclass, fields from io import BytesIO +from pathlib import Path from typing import Any, Union import httpx @@ -463,14 +464,14 @@ def valid_coco_panoptic_annotations(annotations: Iterable[dict[str, list | tuple def load_image( - image: Union[str, "PIL.Image.Image"], + image: Union[str, Path, "PIL.Image.Image"], timeout: float | None = None, ) -> "PIL.Image.Image": """ Loads `image` to a PIL Image. Args: - image (`str` or `PIL.Image.Image`): + image (`str`, `Path` or `PIL.Image.Image`): The image to convert to the PIL Image format. timeout (`float`, *optional*): The timeout value in seconds for the URL request. @@ -479,6 +480,9 @@ def load_image( `PIL.Image.Image`: A PIL Image. """ requires_backends(load_image, ["vision"]) + if isinstance(image, Path): + image = str(image) + if isinstance(image, str): if image.startswith("http://") or image.startswith("https://"): # We need to actually check for a real protocol, otherwise it's impossible to use a local file diff --git a/src/transformers/initialization.py b/src/transformers/initialization.py index b0ebb053086b..28072ba3b022 100644 --- a/src/transformers/initialization.py +++ b/src/transformers/initialization.py @@ -15,6 +15,7 @@ import sys from collections import defaultdict from contextlib import contextmanager +from contextvars import ContextVar import torch @@ -38,6 +39,19 @@ "sparse_": torch.nn.init.sparse_, } +# Track the current no-tie scope per execution context so concurrent model loads +# do not leak tie_weights suppression across threads. +_SKIP_TIE_WEIGHTS_SCOPE: ContextVar[object | None] = ContextVar("_SKIP_TIE_WEIGHTS_SCOPE", default=None) + + +def should_skip_tie_weights(model) -> bool: + scope = _SKIP_TIE_WEIGHTS_SCOPE.get() + if scope is None: + return False + + # Only skip tying for the model instance created inside the active scope. 
+ return getattr(model, "_skip_tie_weights_scope", None) is scope + def uniform_( tensor: torch.Tensor, a: float = 0.0, b: float = 1.0, generator: torch.Generator | None = None @@ -287,19 +301,13 @@ def no_tie_weights(): weights in the state_dict during `from_pretrained`, and otherwise tying them would remove them from it, as it's called in `post_init` when instantiating. """ - from .modeling_utils import PreTrainedModel - - def empty_func(*args, **kwargs): - pass - + # Use an opaque scope token so nested or concurrent loads can identify only + # the models instantiated under this context manager. + state_token = _SKIP_TIE_WEIGHTS_SCOPE.set(object()) try: - original_tie_weights = PreTrainedModel.tie_weights - PreTrainedModel.tie_weights = empty_func - yield finally: - # Set back the original - PreTrainedModel.tie_weights = original_tie_weights + _SKIP_TIE_WEIGHTS_SCOPE.reset(state_token) @contextmanager diff --git a/src/transformers/integrations/accelerate.py b/src/transformers/integrations/accelerate.py index c2b7fa603570..d7a1e4808f30 100644 --- a/src/transformers/integrations/accelerate.py +++ b/src/transformers/integrations/accelerate.py @@ -399,7 +399,12 @@ def accelerate_dispatch(model, hf_quantizer, device_map, offload_folder, offload ): device_map_kwargs["offload_buffers"] = True - if not is_fsdp_enabled() and not is_deepspeed_zero3_enabled(): + is_quantized_bnb = ( + hf_quantizer is not None + and hf_quantizer.quantization_config.quant_method == QuantizationMethod.BITS_AND_BYTES + ) + + if not is_fsdp_enabled() and not is_deepspeed_zero3_enabled() and not is_quantized_bnb: dispatch_model(model, **device_map_kwargs) @@ -446,15 +451,13 @@ def accelerate_disk_offload( renamed) will be mapped to where they already reside on disk. Otherwise, the parameters will be resaved inside `disk_offload_folder` during loading. """ - from ..core_model_loading import WeightRenaming, rename_source_key + from ..core_model_loading import rename_source_key if disk_offload_folder is not None: os.makedirs(disk_offload_folder, exist_ok=True) is_offloaded_safetensors = checkpoint_files is not None and checkpoint_files[0].endswith(".safetensors") - renamings = [] - if weight_mapping is not None: - renamings = [entry for entry in weight_mapping if isinstance(entry, WeightRenaming)] + transforms = weight_mapping if weight_mapping is not None else [] # In this case, the offload index is simply the existing safetensors (except if using custom weight loading # Operation, e.g. the MoE models, where we need to resave the weights that were changed at loading time) @@ -470,7 +473,7 @@ def accelerate_disk_offload( # Update the weight names according to the `weight_mapping` weight_renaming_map = { - rename_source_key(k, renamings, [], model.base_model_prefix, meta_state_dict)[0]: k for k in weight_map + rename_source_key(k, transforms, model.base_model_prefix, meta_state_dict)[0]: k for k in weight_map } # Prepare the index using existing safetensors files diff --git a/src/transformers/integrations/deepspeed.py b/src/transformers/integrations/deepspeed.py index 9703f642f8bc..79f3896cb48c 100644 --- a/src/transformers/integrations/deepspeed.py +++ b/src/transformers/integrations/deepspeed.py @@ -347,7 +347,7 @@ def _apply_weight_conversions_to_state_dict(model, state_dict, weight_mapping): "in your DeepSpeed config or convert your checkpoint to the expected format first." 
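
A minimal sketch of the `ContextVar`-scoped suppression pattern used by `no_tie_weights` above, with a toy class instead of `PreTrainedModel`: only objects created inside the active scope observe the flag, so a concurrent load in another thread is unaffected.

```python
from contextlib import contextmanager
from contextvars import ContextVar

_SCOPE: ContextVar[object | None] = ContextVar("_SCOPE", default=None)

class ToyModel:
    def __init__(self):
        # Remember which scope (if any) was active when this instance was created.
        self._scope = _SCOPE.get()

    def tie_weights(self):
        if self._scope is not None and self._scope is _SCOPE.get():
            return "skipped"
        return "tied"

@contextmanager
def no_tie():
    token = _SCOPE.set(object())  # opaque token identifying this scope
    try:
        yield
    finally:
        _SCOPE.reset(token)

with no_tie():
    inside = ToyModel()
    print(inside.tie_weights())   # skipped
outside = ToyModel()
print(outside.tie_weights())      # tied
print(inside.tie_weights())       # tied again -> the scope has ended
```
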
) - from ..core_model_loading import WeightConverter, WeightRenaming, dot_natural_key, rename_source_key + from ..core_model_loading import WeightConverter, dot_natural_key, rename_source_key # Preserve metadata from the original state dict metadata = getattr(state_dict, "_metadata", None) @@ -360,14 +360,13 @@ def _apply_weight_conversions_to_state_dict(model, state_dict, weight_mapping): for key, param in model.state_dict().items(): model_state_dict[key] = torch.empty(param.shape, dtype=param.dtype, device="meta") - renamings = [entry for entry in weight_mapping if isinstance(entry, WeightRenaming)] converters = [entry for entry in weight_mapping if isinstance(entry, WeightConverter)] # Fast path: if we only have simple renamings and no converters, we can skip the expensive collection logic if len(converters) == 0: new_state_dict = {} for original_key, tensor in state_dict.items(): - renamed_key, _ = rename_source_key(original_key, renamings, [], prefix, model_state_dict) + renamed_key, _ = rename_source_key(original_key, weight_mapping, prefix, model_state_dict) if renamed_key in model_state_dict: new_state_dict[renamed_key] = tensor # Attach metadata to the new state dict @@ -386,7 +385,7 @@ def _apply_weight_conversions_to_state_dict(model, state_dict, weight_mapping): sorted_keys = sorted(state_dict.keys(), key=lambda k: dot_natural_key(k)) for original_key in sorted_keys: tensor = state_dict.pop(original_key) - renamed_key, source_pattern = rename_source_key(original_key, renamings, converters, prefix, model_state_dict) + renamed_key, source_pattern = rename_source_key(original_key, weight_mapping, prefix, model_state_dict) # Only process if the renamed key is in the model's state dict if renamed_key in model_state_dict: diff --git a/src/transformers/integrations/dsa_kernels.py b/src/transformers/integrations/dsa_kernels.py new file mode 100644 index 000000000000..aa7498a387be --- /dev/null +++ b/src/transformers/integrations/dsa_kernels.py @@ -0,0 +1,479 @@ +import torch + +from ..utils import logging as transformers_logging + + +logger = transformers_logging.get_logger(__name__) + +# Try to import tilelang for accelerated kernels +_tilelang_available = False +try: + import tilelang + import tilelang.language as T + + tilelang.set_log_level("WARNING") + + pass_configs = { + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_FAST_MATH: True, + } + _tilelang_available = True +except Exception: + T = None + +FP8 = "float8_e4m3" +BF16 = "bfloat16" +FP32 = "float32" + + +# ---- TileLang kernel definitions (only if tilelang is available) ---- +if _tilelang_available: + + def fast_log2_ceil(x): + bits_x = T.reinterpret("uint32", x) + exp_x = (bits_x >> 23) & 0xFF + man_bits = bits_x & ((1 << 23) - 1) + return T.Cast("int32", exp_x - 127 + T.if_then_else(man_bits != 0, 1, 0)) + + def fast_pow2(x): + bits_x = (x + 127) << 23 + return T.reinterpret("float32", bits_x) + + def fast_round_scale(amax, fp8_max_inv): + return fast_pow2(fast_log2_ceil(amax * fp8_max_inv)) + + @tilelang.jit(pass_configs=pass_configs) + def act_quant_kernel(N, in_dtype=BF16, out_dtype=FP8, scale_dtype=FP32, round_scale=False): + M = T.symbolic("M") + fp8_min = -448.0 + fp8_max = 448.0 + fp8_max_inv = 1 / fp8_max + num_stages = 0 if round_scale else 2 + blk_m = 32 + group_size = 128 + + @T.prim_func + def act_quant_kernel_( + X: T.Tensor[(M, N), in_dtype], + Y: T.Tensor[(M, N), out_dtype], + S: T.Tensor[(M, T.ceildiv(N, 
group_size)), scale_dtype], + ): + with T.Kernel(T.ceildiv(M, blk_m), T.ceildiv(N, group_size), threads=128) as ( + pid_m, + pid_n, + ): + x_shared = T.alloc_shared((blk_m, group_size), in_dtype) + x_local = T.alloc_fragment((blk_m, group_size), in_dtype) + amax_local = T.alloc_fragment((blk_m,), scale_dtype) + s_local = T.alloc_fragment((blk_m,), scale_dtype) + y_local = T.alloc_fragment((blk_m, group_size), out_dtype) + y_shared = T.alloc_shared((blk_m, group_size), out_dtype) + + for _ in T.Pipelined(1, num_stages=num_stages): + T.copy(X[pid_m * blk_m, pid_n * group_size], x_shared) + T.copy(x_shared, x_local) + T.reduce_absmax(x_local, amax_local, dim=1) + for i in T.Parallel(blk_m): + amax_local[i] = T.max(amax_local[i], 1e-4) + if round_scale: + s_local[i] = fast_round_scale(amax_local[i], fp8_max_inv) + else: + s_local[i] = amax_local[i] * fp8_max_inv + for i, j in T.Parallel(blk_m, group_size): + y_local[i, j] = T.clamp(x_local[i, j] / s_local[i], fp8_min, fp8_max) + for i in T.Parallel(blk_m): + S[pid_m * blk_m + i, pid_n] = s_local[i] + T.copy(y_local, y_shared) + T.copy(y_shared, Y[pid_m * blk_m, pid_n * group_size]) + + return act_quant_kernel_ + + @tilelang.jit(pass_configs=pass_configs) + def fp8_gemm_kernel(N, K, out_dtype=BF16, accum_dtype="float32"): + assert out_dtype in [BF16, "float32"] + + M = T.symbolic("M") + group_size = 128 + block_M = 32 + block_N = 128 + block_K = 128 + + @T.prim_func + def fp8_gemm_kernel_( + A: T.Tensor[(M, K), FP8], + B: T.Tensor[(N, K), FP8], + C: T.Tensor[(M, N), out_dtype], + scales_a: T.Tensor[(M, T.ceildiv(K, group_size)), FP32], + scales_b: T.Tensor[(T.ceildiv(N, group_size), T.ceildiv(K, group_size)), FP32], + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as ( + bx, + by, + ): + A_shared = T.alloc_shared((block_M, block_K), FP8) + B_shared = T.alloc_shared((block_N, block_K), FP8) + C_shared = T.alloc_shared((block_M, block_N), out_dtype) + Scale_C_shared = T.alloc_shared((block_M), FP32) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + C_local_accum = T.alloc_fragment((block_M, block_N), accum_dtype) + + # Improve L2 Cache + T.use_swizzle(panel_size=10) + + T.clear(C_local) + T.clear(C_local_accum) + K_iters = T.ceildiv(K, block_K) + for k in T.Pipelined(K_iters, num_stages=4): + # Load A into shared memory + T.copy(A[by * block_M, k * block_K], A_shared) + # Load B into shared memory + T.copy(B[bx * block_N, k * block_K], B_shared) + # Load scale into shared memory + Scale_B = scales_b[bx * block_N // group_size, k] + for i in T.Parallel(block_M): + Scale_C_shared[i] = scales_a[by * block_M + i, k] * Scale_B + + T.gemm(A_shared, B_shared, C_local, transpose_B=True) + # Promote to enable 2xAcc + for i, j in T.Parallel(block_M, block_N): + C_local_accum[i, j] += C_local[i, j] * Scale_C_shared[i] + T.clear(C_local) + # TMA store + T.copy(C_local_accum, C_shared) + T.copy(C_shared, C[by * block_M, bx * block_N]) + + return fp8_gemm_kernel_ + + @tilelang.jit(out_idx=[4], pass_configs=pass_configs) + def fp8_index_kernel(h: int, d: int): + b = T.symbolic("b") + m = T.symbolic("m") + n = T.symbolic("n") + + blk_n1 = 512 + blk_n2 = 128 + + @T.prim_func + def fp8_index_kernel_( + q: T.Tensor[(b, m, h, d), FP8], + q_s: T.Tensor[(b, m, h), FP32], + k: T.Tensor[(b, n, d), FP8], + k_s: T.Tensor[(b, n), FP32], + o: T.Tensor[(b, m, n), FP32], + ) -> None: + with T.Kernel(b, m, T.ceildiv(n, blk_n1)) as (i_b, i_m, i1_n): + q_smem = T.alloc_shared((h, d), FP8) + T.copy(q[i_b, i_m, 0, 0], q_smem) + + 
q_s_frag = T.alloc_fragment(h, FP32) + T.copy(q_s[i_b, i_m, 0], q_s_frag) + + for i2_n in T.Pipelined(blk_n1 // blk_n2, num_stages=2): + k_smem = T.alloc_shared((blk_n2, d), FP8) + T.copy(k[i_b, i1_n * blk_n1 + i2_n * blk_n2, 0], k_smem) + + k_s_frag = T.alloc_fragment(blk_n2, FP32) + T.copy(k_s[i_b, i1_n * blk_n1 + i2_n * blk_n2], k_s_frag) + + logits = T.alloc_fragment((blk_n2, h), FP32) + T.gemm( + k_smem, + q_smem, + logits, + transpose_A=False, + transpose_B=True, + clear_accum=True, + ) + + for i_h, i3_n in T.Parallel(h, blk_n2): + logits[i3_n, i_h] = T.max(logits[i3_n, i_h], 0) * q_s_frag[i_h] + + logits_sum = T.alloc_fragment(blk_n2, FP32) + T.reduce_sum(logits, logits_sum, dim=1) + + for i3_n in T.Parallel(blk_n2): + logits_sum[i3_n] *= k_s_frag[i3_n] + + T.copy(logits_sum, o[i_b, i_m, i1_n * blk_n1 + i2_n * blk_n2]) + + return fp8_index_kernel_ + + +# ---- PyTorch fallback implementations ---- + + +def _act_quant_pytorch( + x: torch.Tensor, block_size: int = 128, scale_fmt: str | None = None +) -> tuple[torch.Tensor, torch.Tensor]: + """Pure PyTorch implementation of block-wise FP8 activation quantization. + + Equivalent to the TileLang ``act_quant_kernel``: per-group absmax scaling, + optional power-of-2 rounded scales, clamp to FP8 range. + """ + N = x.size(-1) + assert N % block_size == 0, f"Last dimension size must be divisible by block_size (block_size={block_size})" + num_groups = N // block_size + orig_shape = x.shape + + # Flatten to 2D, then group — mirrors the TileLang kernel's (M, N) layout. + x_flat = x.reshape(-1, N) # [M, N] + x_grouped = x_flat.reshape(-1, num_groups, block_size) # [M, G, BS] + + # Per-group absmax + amax = x_grouped.abs().amax(dim=-1).clamp(min=1e-4) # [M, G] + + if scale_fmt is not None: + # Power-of-2 rounded scale: scale = 2^(ceil(log2(amax / 448))) + scale = torch.pow(2.0, torch.ceil(torch.log2(amax / 448.0))) + else: + scale = amax / 448.0 + + # Quantize: divide each group by its scale, clamp to FP8 range + x_q = (x_grouped / scale.unsqueeze(-1)).clamp(-448.0, 448.0).to(torch.float8_e4m3fn) # [M, G, BS] + x_q = x_q.reshape(orig_shape) + + # Scale shape: (*x.shape[:-1], num_groups) + scale = scale.reshape(*orig_shape[:-1], num_groups) + return x_q, scale + + +def _fp8_index_pytorch( + q: torch.Tensor, + q_s: torch.Tensor, + k: torch.Tensor, + k_s: torch.Tensor, +) -> torch.Tensor: + """Pure PyTorch implementation of FP8 index scoring. + + Equivalent to the TileLang ``fp8_index_kernel``: + logits = k @ q^T (FP8 -> FP32 matmul over D) + logits = relu(logits) * q_s (per-head scale) + result = logits.sum(H) * k_s (reduce heads, scale by k) + """ + q_bf16 = q.to(torch.bfloat16) + k_bf16 = k.to(torch.bfloat16) + # q: [B, M, H, D], k: [B, T, D] -> logits: [B, M, T, H] + # Matches TileLang kernel: logits[n, h] = k[n, :] @ q[h, :]^T + logits = torch.einsum("bmhd,btd->bmth", q_bf16, k_bf16) + logits = logits.clamp(min=0) * q_s.unsqueeze(-2) # q_s: [B,M,H] -> [B,M,1,H] + result = logits.sum(dim=-1) * k_s.unsqueeze(-2) # k_s: [B,T] -> [B,1,T] + return result + + +def _fp8_index_triton( + q: torch.Tensor, + q_s: torch.Tensor, + k: torch.Tensor, + k_s: torch.Tensor, +) -> torch.Tensor: + """Triton FP8 GEMM implementation of FP8 index scoring. + + Uses ``w8a8_fp8_matmul`` from the finegrained-fp8 integration (which dispatches + to Triton on Blackwell) for FP8→FP32 matmul, matching vLLM's computation + granularity. Post-processing (relu, scale, reduce) is done in FP32. 
+ + Equivalent to the TileLang ``fp8_index_kernel``: + logits = dequant(q_fp8, q_scale) @ dequant(k_fp8, k_scale)^T (FP8 dequant + FP32 matmul) + logits = relu(logits) * q_s (per-head weights, already includes q_scale) + result = logits.sum(H) * k_s (reduce heads, scale by k_scale) + """ + from .finegrained_fp8 import w8a8_fp8_matmul + + B, M, H, D = q.shape + T = k.shape[1] + + if B == 1: + # Single batch: one matmul for all (M, H) query vectors against all T keys + q_flat = q.reshape(M * H, D).contiguous() + k_flat = k.reshape(T, D).contiguous() + # Create unit scales: fp8_gemm will compute raw FP8 dot products + # (dequant with scale=1 is equivalent to using FP8 values directly) + ones_q = q_flat.new_ones(M * H, D // 128, dtype=torch.float32) + ones_k = k_flat.new_ones(T, D // 128, dtype=torch.float32) + logits_flat = w8a8_fp8_matmul(q_flat, k_flat, ones_q, ones_k, [128, 128], torch.float32) + # logits_flat: [M*H, T] → reshape to [M, H, T] → transpose to [M, T, H] + logits = logits_flat.reshape(M, H, T).permute(0, 2, 1).unsqueeze(0) # [1, M, T, H] + else: + # Multi-batch: loop over batches + results = [] + for b in range(B): + q_b = q[b].reshape(M * H, D).contiguous() + k_b = k[b].reshape(T, D).contiguous() + ones_q_b = q_b.new_ones(M * H, D // 128, dtype=torch.float32) + ones_k_b = k_b.new_ones(T, D // 128, dtype=torch.float32) + logits_b = w8a8_fp8_matmul(q_b, k_b, ones_q_b, ones_k_b, [128, 128], torch.float32) + logits_b = logits_b.reshape(M, H, T).permute(0, 2, 1) # [M, T, H] + results.append(logits_b) + logits = torch.stack(results, dim=0) # [B, M, T, H] + + # Post-processing in FP32 — matches TileLang kernel + logits = logits.clamp(min=0) * q_s.unsqueeze(-2) # relu * weights + result = logits.sum(dim=-1) * k_s.unsqueeze(-2) # reduce heads * k_scale + return result + + +# ---- Public API: TileLang → Triton → PyTorch fallback ---- + +# One-time flags — once a backend fails, we stop retrying it. +_act_quant_use_tilelang = _tilelang_available +_fp8_index_use_tilelang = _tilelang_available +_fp8_gemm_use_tilelang = _tilelang_available + +# Lazily-loaded Triton kernels from the finegrained-fp8 hub package. +_triton_act_quant = None +_triton_fallbacks_loaded = False + + +def _load_triton_fallbacks(): + """Lazily load Triton FP8 kernels from the finegrained-fp8 hub package.""" + global _triton_fallbacks_loaded, _triton_act_quant + if _triton_fallbacks_loaded: + return + _triton_fallbacks_loaded = True + try: + from .finegrained_fp8 import triton_fp8_act_quant + + _triton_act_quant = triton_fp8_act_quant + except ImportError: + pass + + +def act_quant( + x: torch.Tensor, block_size: int = 128, scale_fmt: str | None = None +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Quantizes the input tensor `x` using block-wise quantization. + + Fallback chain: TileLang → Triton (non-ue8m0 only) → PyTorch. + + Args: + x (torch.Tensor): The input tensor to be quantized. Must be contiguous and its last + dimension size must be divisible by `block_size`. + block_size (int, optional): The size of the blocks to be used for quantization. + Default is 128. + scale_fmt (Optional[str], optional): The format of the scale. Default is None. + When set (e.g. ``"ue8m0"``), scales are rounded to powers of 2 — handled by + the PyTorch fallback since the Triton kernel does not support power-of-2 rounding. + Returns: + Tuple[torch.Tensor, torch.Tensor]: A tuple containing: + - The quantized tensor with dtype `torch.float8_e4m3fn`. + - A tensor of scaling factors with dtype `torch.float32`. 
+ """ + assert x.is_contiguous(), "Input tensor must be contiguous" + assert x.size(-1) % block_size == 0, ( + f"Last dimension size must be divisible by block_size (block_size={block_size})" + ) + + global _act_quant_use_tilelang + if _act_quant_use_tilelang: + try: + N = x.size(-1) + y = torch.empty_like(x, dtype=torch.float8_e4m3fn) + s = x.new_empty(*x.size()[:-1], N // block_size, dtype=torch.float32) + kernel = act_quant_kernel(N, round_scale=scale_fmt is not None) + kernel(x.view(-1, N), y.view(-1, N), s.view(-1, N // block_size)) + return y, s + except Exception: + logger.warning_once("TileLang act_quant compilation failed, falling back to PyTorch implementation") + _act_quant_use_tilelang = False + + # Triton fallback — only for non-ue8m0 scales (Triton kernel lacks power-of-2 rounding) + if scale_fmt is None: + global _triton_act_quant + _load_triton_fallbacks() + if _triton_act_quant is not None: + try: + N = x.size(-1) + x_flat = x.reshape(-1, N).contiguous() + x_q_flat, scale_flat = _triton_act_quant(x_flat, block_size) + x_q = x_q_flat.reshape(x.shape) + scale = scale_flat.reshape(*x.shape[:-1], N // block_size) + return x_q, scale + except Exception: + logger.warning_once("Triton act_quant failed, falling back to PyTorch") + _triton_act_quant = None + + return _act_quant_pytorch(x, block_size, scale_fmt) + + +def fp8_gemm(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Tensor) -> torch.Tensor: + """ + Perform a matrix multiplication using FP8 precision. + + Args: + a (torch.Tensor): The first input matrix, must be contiguous. + a_s (torch.Tensor): The scaling factor for the first input matrix, must be contiguous. + b (torch.Tensor): The second input matrix, must be contiguous. + b_s (torch.Tensor): The scaling factor for the second input matrix, must be contiguous. + + Returns: + torch.Tensor: The result of the matrix multiplication. + """ + assert a.is_contiguous() and b.is_contiguous(), "Input tensors must be contiguous" + assert a_s.is_contiguous() and b_s.is_contiguous(), "Scaling factor tensors must be contiguous" + + global _fp8_gemm_use_tilelang + if _fp8_gemm_use_tilelang: + try: + K = a.size(-1) + M = a.numel() // K + N = b.size(0) + c = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype()) + kernel = fp8_gemm_kernel(N, K) + kernel(a.view(M, K), b, c.view(M, N), a_s.view(M, -1), b_s) + return c + except Exception: + logger.warning_once("TileLang fp8_gemm compilation failed, falling back to PyTorch implementation") + _fp8_gemm_use_tilelang = False + + # PyTorch fallback: dequantize and matmul + group_size = a.shape[-1] // a_s.shape[-1] + a_deq = a.to(torch.bfloat16) * a_s.to(torch.bfloat16).repeat_interleave(group_size, dim=-1) + b_deq = b.to(torch.bfloat16) * b_s.to(torch.bfloat16).repeat_interleave(group_size, dim=-1).repeat_interleave( + group_size, dim=0 + ) + return torch.matmul(a_deq, b_deq.T) + + +def fp8_index( + q: torch.Tensor, + q_s: torch.Tensor, + k: torch.Tensor, + k_s: torch.Tensor, +) -> torch.Tensor: + """ + Perform index score using FP8 precision. + + Fallback chain: TileLang → Triton fp8_gemm → PyTorch bf16 einsum. + + The Triton path uses the fp8_gemm kernel from the finegrained-fp8 hub package + to compute raw FP8 dot products with FP32 accumulation, matching vLLM's + DeepGEMM fp8_mqa_logits computation granularity. + + Args: + q (torch.Tensor): The Q tensor, must be contiguous. + q_s (torch.Tensor): The scaling factor for Q (float), must be contiguous. + k (torch.Tensor): The K tensor, must be contiguous. 
+ k_s (torch.Tensor): The scaling factor for K (e8m0 here), must be contiguous. + + fp8 q @ fp8 k -> fp32 logits + relu(fp32 logits) * q_s (weights) -> fp32 logits + fp32 logits -> fp32 logits_sum + fp32 logits_sum * k_s (e8m0) -> fp32 index_score + """ + global _fp8_index_use_tilelang + if _fp8_index_use_tilelang: + try: + return fp8_index_kernel(q.shape[2], q.shape[3])(q, q_s, k, k_s) + except Exception: + logger.warning_once("TileLang fp8_index compilation failed, falling back to PyTorch implementation") + _fp8_index_use_tilelang = False + + # Triton fallback: FP8 matmul with FP32 accumulation (matches vLLM granularity) + try: + return _fp8_index_triton(q, q_s, k, k_s) + except Exception: + logger.warning_once("Triton fp8_index failed, falling back to PyTorch bf16 implementation") + + return _fp8_index_pytorch(q, q_s, k, k_s) diff --git a/src/transformers/integrations/executorch.py b/src/transformers/integrations/executorch.py index 675a0ea5783a..a835fc44cc71 100644 --- a/src/transformers/integrations/executorch.py +++ b/src/transformers/integrations/executorch.py @@ -13,8 +13,10 @@ import logging import torch +import torch.utils._pytree as pytree from ..cache_utils import ( + Cache, DynamicCache, DynamicLayer, DynamicSlidingWindowLayer, @@ -25,10 +27,7 @@ ) from ..generation.configuration_utils import GenerationConfig from ..modeling_utils import PreTrainedModel -from ..pytorch_utils import ( - is_torch_greater_or_equal, - is_torch_greater_or_equal_than_2_6, -) +from ..pytorch_utils import is_torch_greater_or_equal class TorchExportableModuleForVLM: @@ -881,7 +880,7 @@ def __init__(self, model, max_static_cache_length, batch_size): self.static_cache.early_initialization(batch_size, num_heads, head_dim, torch.float32, model_device) self.cache = EncoderDecoderCache(self.static_cache, DynamicCache(config=self.config)) - register_dynamic_cache_export_support() + register_pytree_cache() # Register cache buffers to make them exportable for i, layer in enumerate(self.static_cache.layers): @@ -889,7 +888,13 @@ def __init__(self, model, max_static_cache_length, batch_size): self.register_buffer(f"value_cache_{i}", layer.values, persistent=False) self.register_buffer(f"cumulative_length_{i}", layer.cumulative_length, persistent=False) - def forward(self, decoder_input_ids, encoder_hidden_states, cache_position): + def forward( + self, + decoder_input_ids: torch.Tensor, + encoder_hidden_states: torch.Tensor, + cache_position: torch.Tensor, + encoder_attention_mask: torch.Tensor | None = None, + ): # Start by resetting static cache (it's needed to be able to run several generations with the same exported program, # as otherwise it's mutated in-place indefinitely - we cannot call reset in-between the `generate` as the program was # already exported) @@ -900,6 +905,7 @@ def forward(self, decoder_input_ids, encoder_hidden_states, cache_position): outputs = self.decoder( input_ids=decoder_input_ids, encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, past_key_values=self.cache, use_cache=True, ) @@ -947,7 +953,7 @@ def _export_encoder(self, encoder_input_ids): return exported_encoder - def _export_decoder(self, decoder_input_ids, encoder_hidden_states, cache_position): + def _export_decoder(self, decoder_input_ids, encoder_hidden_states, cache_position, encoder_attention_mask=None): target_device = self.full_model.device wrapped_decoder = ( Seq2SeqLMDecoderExportableModuleWithStaticCache( @@ -963,27 +969,35 @@ def _export_decoder(self, decoder_input_ids, 
encoder_hidden_states, cache_positi decoder_input_ids = decoder_input_ids.to(target_device) encoder_hidden_states = encoder_hidden_states.to(target_device) cache_position = cache_position.to(target_device) - - # Define dynamic dimension for encoder output sequence length - encoder_seq_len_dim = torch.export.Dim("encoder_hidden_seq_length", max=self.max_hidden_seq_length) - - # Export the decoder + if encoder_attention_mask is not None: + encoder_attention_mask = encoder_attention_mask.to(target_device) + + # Export the decoder. + # encoder_hidden_states uses a static shape to avoid a symbolic-shape + # conflict with the static KV cache size during torch.export. Callers + # that pad encoder inputs to a fixed max length (e.g. max_hidden_seq_length) + # should pass encoder_hidden_states of that shape. with torch.no_grad(): exported_decoder = torch.export.export( wrapped_decoder, - (decoder_input_ids, encoder_hidden_states, cache_position), - dynamic_shapes={ - "decoder_input_ids": None, - "encoder_hidden_states": {1: encoder_seq_len_dim}, - "cache_position": None, - }, + (decoder_input_ids, encoder_hidden_states, cache_position, encoder_attention_mask), + dynamic_shapes=None, strict=True, ) return exported_decoder - def export(self, encoder_input_ids=None, decoder_input_ids=None, encoder_hidden_states=None, cache_position=None): + def export( + self, + encoder_input_ids=None, + decoder_input_ids=None, + encoder_hidden_states=None, + cache_position=None, + encoder_attention_mask=None, + ): device = self.full_model.device + max_cache_len = self.generation_config.cache_config.get("max_cache_len") + batch_size = self.generation_config.cache_config.get("batch_size") example_encoder_input_ids = ( encoder_input_ids if encoder_input_ids is not None @@ -1001,14 +1015,22 @@ def export(self, encoder_input_ids=None, decoder_input_ids=None, encoder_hidden_ encoder_hidden_states if encoder_hidden_states is not None else torch.zeros( - (self.generation_config.cache_config.get("batch_size"), 10, self.config.d_model), + (batch_size, max_cache_len, self.config.d_model), dtype=torch.float32, device=device, ) ) + example_encoder_attention_mask = ( + encoder_attention_mask + if encoder_attention_mask is not None + else torch.ones((batch_size, max_cache_len), dtype=torch.long, device=device) + ) self.exported_encoder = self._export_encoder(example_encoder_input_ids) self.exported_decoder = self._export_decoder( - example_decoder_input_ids, example_encoder_hidden_states, example_cache_position + example_decoder_input_ids, + example_encoder_hidden_states, + example_cache_position, + example_encoder_attention_mask, ) # Return self to allow chaining @@ -1025,6 +1047,22 @@ def generate(self, prompt_token_ids, max_new_tokens): # Run encoder encoder_output = self.exported_encoder.module()(prompt_token_ids) + # Build encoder attention mask: 1 at real token positions, 0 at padding. + # Assumes padding token id is 0 (standard for T5 and most seq2seq models). 
+ max_cache_len = self.generation_config.cache_config.get("max_cache_len") + batch_size = prompt_token_ids.shape[0] + encoder_attention_mask = (prompt_token_ids != 0).long() + # Pad or trim to max_cache_len so shape matches the static export + if encoder_attention_mask.shape[1] < max_cache_len: + pad = torch.zeros( + (batch_size, max_cache_len - encoder_attention_mask.shape[1]), + dtype=torch.long, + device=model_device, + ) + encoder_attention_mask = torch.cat([encoder_attention_mask, pad], dim=1) + else: + encoder_attention_mask = encoder_attention_mask[:, :max_cache_len] + # Initialize with start token (0 for T5) on the correct device decoder_input_ids = torch.tensor([[0]], dtype=torch.long, device=model_device) generated_ids = [0] @@ -1033,7 +1071,10 @@ def generate(self, prompt_token_ids, max_new_tokens): for i in range(max_new_tokens - 1): # Run decoder for next token prediction logits = self.exported_decoder.module()( - decoder_input_ids, encoder_output, torch.tensor([i], dtype=torch.long, device=model_device) + decoder_input_ids, + encoder_output, + torch.tensor([i], dtype=torch.long, device=model_device), + encoder_attention_mask, ) # Get next token @@ -1067,7 +1108,7 @@ def export_with_dynamic_cache( Exported program (`torch.export.ExportedProgram`): The exported program generated via `torch.export`. """ - register_dynamic_cache_export_support() + register_pytree_cache() with torch.no_grad(): exported_program = torch.export.export( @@ -1084,54 +1125,97 @@ def export_with_dynamic_cache( return exported_program -def register_dynamic_cache_export_support(): - """ - Utilities for `DynamicCache` <> torch.export support - """ - +def _register_pytree_node(cls, flatten_fn, unflatten_fn, flatten_with_keys_fn): try: - torch.utils._pytree.register_pytree_node( - DynamicCache, - lambda dynamic_cache: torch.utils._pytree._dict_flatten(_get_cache_dict(dynamic_cache)), - _unflatten_dynamic_cache, - serialized_type_name=f"{DynamicCache.__module__}.{DynamicCache.__name__}", - flatten_with_keys_fn=lambda dynamic_cache: torch.utils._pytree._dict_flatten_with_keys( - _get_cache_dict(dynamic_cache) - ), + pytree.register_pytree_node( + cls, + flatten_fn, + unflatten_fn, + serialized_type_name=f"{cls.__module__}.{cls.__name__}", + flatten_with_keys_fn=flatten_with_keys_fn, ) - # TODO (tmanlaibaatar) This won't be needed in torch 2.7. 
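
Both the `DynamicCache`-only registration being removed here and the generalized helpers that replace it rely on the same underlying pytree hook, which is what lets `torch.export` flatten a cache into plain tensors and rebuild it afterwards. A toy illustration of that hook, independent of the real cache classes (the `KVLayer` container is made up for the example):

```python
import torch
import torch.utils._pytree as pytree

class KVLayer:  # toy stand-in for a cache layer, not a transformers class
    def __init__(self, keys=None, values=None):
        self.keys, self.values = keys, values

# Teach pytree how to take the container apart (children + context) and rebuild it.
pytree.register_pytree_node(
    KVLayer,
    lambda layer: ([layer.keys, layer.values], ["keys", "values"]),
    lambda children, context: KVLayer(*children),
    serialized_type_name="example.KVLayer",
)

leaves, spec = pytree.tree_flatten(KVLayer(torch.zeros(2), torch.ones(2)))
rebuilt = pytree.tree_unflatten(leaves, spec)  # a KVLayer again, holding the same tensors
```
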
- torch.fx._pytree.register_pytree_flatten_spec( - DynamicCache, - lambda cache, spec: torch.fx._pytree._dict_flatten_spec(_get_cache_dict(cache), spec), - ) - # Catching this in case there are multiple runs for some test runs - except ValueError as e: - if "already registered as pytree node" not in str(e): + except ValueError as error: + if "already registered as pytree node" not in str(error): raise -def _get_cache_dict(cache: DynamicCache): - """Convert cache to dictionary format for pytree operations.""" - if any(not isinstance(layer, (DynamicLayer, DynamicSlidingWindowLayer)) for layer in cache.layers): - raise RuntimeError("This pytree flattening function should only be applied to DynamicCache") +def _register_pytree_cache_layer(cache_layer_cls): + def _flatten_layer(layer): + attributes = { + "keys": layer.keys, + "values": layer.values, + "is_initialized": layer.is_initialized, + } + for name in ( + "max_cache_len", + "max_batch_size", + "num_heads", + "k_head_dim", + "v_head_dim", + "cumulative_length", + "cumulative_length_int", + "sliding_window", + ): + if hasattr(layer, name): + attributes[name] = getattr(layer, name) + return list(attributes.values()), list(attributes.keys()) + + def _unflatten_layer(values, context): + attributes = dict(zip(context, values)) + + if cache_layer_cls is StaticLayer: + layer = cache_layer_cls(max_cache_len=attributes["max_cache_len"]) + elif cache_layer_cls is StaticSlidingWindowLayer: + layer = cache_layer_cls( + max_cache_len=attributes["max_cache_len"], + sliding_window=attributes["max_cache_len"], + ) + elif cache_layer_cls is DynamicSlidingWindowLayer: + layer = cache_layer_cls(sliding_window=attributes["sliding_window"]) + else: + layer = cache_layer_cls() + + for name, value in attributes.items(): + setattr(layer, name, value) + return layer + + def _flatten_layer_with_keys(layer): + values, context = _flatten_layer(layer) + return [(pytree.MappingKey(key), value) for key, value in zip(context, values)], context + + _register_pytree_node(cache_layer_cls, _flatten_layer, _unflatten_layer, _flatten_layer_with_keys) + + +def _register_pytree_cache(cache_cls): + def _flatten_cache(cache): + attributes = { + "layers": cache.layers, + "offloading": cache.offloading, + "only_non_sliding": getattr(cache, "only_non_sliding", True), + } + return list(attributes.values()), list(attributes.keys()) + + def _flatten_cache_with_keys(cache): + values, context = _flatten_cache(cache) + return [(pytree.MappingKey(key), value) for key, value in zip(context, values)], context + + def _unflatten_cache(values, context): + attributes = dict(zip(context, values)) + cache = Cache( + layers=attributes["layers"], + offloading=attributes["offloading"], + offload_only_non_sliding=attributes["only_non_sliding"], + ) + cache.__class__ = cache_cls + return cache - if not is_torch_greater_or_equal_than_2_6: - logging.warning("DynamicCache + torch.export is tested on torch 2.6.0+ and may not work on earlier versions.") + _register_pytree_node(cache_cls, _flatten_cache, _unflatten_cache, _flatten_cache_with_keys) - return { - "key_cache": [layer.keys for layer in cache.layers if layer.keys is not None], - "value_cache": [layer.values for layer in cache.layers if layer.values is not None], - } +def register_pytree_cache(): + """Register cache classes as pytrees for torch.export.""" + for cache_layer_cls in (StaticLayer, StaticSlidingWindowLayer, DynamicLayer, DynamicSlidingWindowLayer): + _register_pytree_cache_layer(cache_layer_cls) -def _unflatten_dynamic_cache(values, 
context: torch.utils._pytree.Context): - dictionary = torch.utils._pytree._dict_unflatten(values, context) - cache = DynamicCache() - # Reconstruct layers from keys and values lists - key_list = dictionary.get("key_cache", []) - value_list = dictionary.get("value_cache", []) - for idx in range(max(len(key_list), len(value_list))): - key = key_list[idx] if idx < len(key_list) else None - value = value_list[idx] if idx < len(value_list) else None - cache.update(key, value, idx) - return cache + for cache_cls in (StaticCache, DynamicCache): + _register_pytree_cache(cache_cls) diff --git a/src/transformers/integrations/finegrained_fp8.py b/src/transformers/integrations/finegrained_fp8.py index c64f1ce23ec2..02577b4e85b1 100644 --- a/src/transformers/integrations/finegrained_fp8.py +++ b/src/transformers/integrations/finegrained_fp8.py @@ -128,6 +128,10 @@ def _load_deepgemm_kernel(): # DeepGEMM requires Hopper (SM90) or newer for FP8 WGMMA instructions major = torch.cuda.get_device_capability()[0] + if major >= 10: + raise ImportError( + "DeepGEMM is not yet supported on Blackwell (SM100+) GPUs. Falling back to Triton finegrained-fp8 kernel." + ) if major < 9: raise ImportError( f"DeepGEMM requires a Hopper (SM90+) or newer GPU, but the current device " diff --git a/src/transformers/integrations/ggml.py b/src/transformers/integrations/ggml.py index c9ba021c54db..8bf7bfdba648 100644 --- a/src/transformers/integrations/ggml.py +++ b/src/transformers/integrations/ggml.py @@ -259,6 +259,7 @@ "attention.head_count_kv": "num_key_value_heads", "attention.layer_norm_rms_epsilon": "rms_norm_eps", "attention.sliding_window": "sliding_window", + "attention.logit_softcapping": "attn_logit_softcapping", "vocab_size": "vocab_size", }, "gemma3": { @@ -275,6 +276,7 @@ "attention.head_count_kv": "num_key_value_heads", "attention.layer_norm_rms_epsilon": "rms_norm_eps", "attention.sliding_window": "sliding_window", + "attention.logit_softcapping": "attn_logit_softcapping", "vocab_size": "vocab_size", }, "umt5": { diff --git a/src/transformers/integrations/hqq.py b/src/transformers/integrations/hqq.py index 083ec53a2fd3..f83007410f7d 100755 --- a/src/transformers/integrations/hqq.py +++ b/src/transformers/integrations/hqq.py @@ -127,3 +127,135 @@ def prepare_for_hqq_linear(model, quantization_config=None, modules_to_not_conve logger.warning("No linear modules were found in your model for quantization.") return model + + +class HqqQuantize: + """HQQ quantization operation for the new weight loading flow.""" + + def __init__(self, hf_quantizer): + self.hf_quantizer = hf_quantizer + + def convert( + self, + input_dict, + full_layer_name=None, + model=None, + **kwargs, + ): + from hqq.core.quantize import HQQLinear + + from ..quantizers.quantizers_utils import get_module_from_name + + # input_dict has {param_name: [tensor]} for the weight + value = list(input_dict.values())[0] + value = value[0] if isinstance(value, list) else value + + # full_layer_name is e.g. 
"model.layers.0.self_attn.q_proj.weight" + module_name = full_layer_name.rsplit(".", 1)[0] + module, _ = get_module_from_name(model, full_layer_name) + + # Load weight into the nn.Linear module + module.weight = torch.nn.Parameter(value, requires_grad=False) + + # Get the quant_config that was set in _process_model_before_weight_loading + quant_config = getattr(module, "quant_config", None) + if quant_config is None: + # Module is skipped from quantization, just return the weight as-is + return {full_layer_name: value} + + # Determine target device and compute dtype + target_device = value.device + compute_dtype = self.hf_quantizer.dtype + + # Create HQQLinear from the nn.Linear + hqq_layer = HQQLinear( + module, + quant_config=quant_config, + compute_dtype=compute_dtype, + device=target_device, + del_orig=True, + ) + + if hqq_layer.bias is not None and isinstance(hqq_layer.bias, torch.Tensor): + hqq_layer.bias = torch.nn.Parameter(hqq_layer.bias) + + if self.hf_quantizer.using_multi_gpu: + hqq_layer = self.hf_quantizer._patch_layer_for_multigpu(hqq_layer) + + # Replace the module in the model + parent_module_name, _, child_name = module_name.rpartition(".") + parent_module = model.get_submodule(parent_module_name) if parent_module_name else model + setattr(parent_module, child_name, hqq_layer) + + # Mark as loaded so it's not reported as missing + missing_keys = kwargs.get("missing_keys") + if missing_keys is not None: + missing_keys.discard(full_layer_name) + + # Return empty dict so the loading code doesn't try to set params + return {} + + +class HqqDeserialize: + """Deserialize HQQ pre-quantized weights into an HQQLinear module.""" + + def __init__(self, hf_quantizer): + self.hf_quantizer = hf_quantizer + + def convert( + self, + input_dict, + full_layer_name=None, + model=None, + **kwargs, + ): + from hqq.core.quantize import HQQLinear + + # Unwrap list values + state_dict = {} + for key, value in input_dict.items(): + state_dict[key] = value[0] if isinstance(value, list) else value + + # If W_q is not present, this is not an HQQ-quantized layer — pass through + if "W_q" not in state_dict: + return input_dict + + # full_layer_name is e.g. 
"model.layers.0.self_attn.v_proj.weight" + # (target pattern "weight" appended to module path) + module_name = full_layer_name.rsplit(".", 1)[0] + + parent_name, _, child_name = module_name.rpartition(".") + parent = model.get_submodule(parent_name) if parent_name else model + + # Create empty HQQLinear + hqq_layer = HQQLinear( + None, + None, + compute_dtype=self.hf_quantizer.dtype or torch.float16, + device="cpu", + initialize=False, + ) + + # Make W_q an nn.Parameter as HQQ expects + if "W_q" in state_dict: + state_dict["W_q"] = torch.nn.Parameter(state_dict["W_q"], requires_grad=False) + + hqq_layer.load_state_dict(state_dict) + + if hqq_layer.bias is not None and isinstance(hqq_layer.bias, torch.Tensor): + hqq_layer.bias = torch.nn.Parameter(hqq_layer.bias) + + if self.hf_quantizer.using_multi_gpu: + hqq_layer = self.hf_quantizer._patch_layer_for_multigpu(hqq_layer) + + setattr(parent, child_name, hqq_layer) + + # Mark weight and bias as loaded + missing_keys = kwargs.get("missing_keys") + if missing_keys is not None: + missing_keys.discard(full_layer_name) + # Also discard bias since HQQLinear handles it internally + bias_key = module_name + ".bias" + missing_keys.discard(bias_key) + + return {} diff --git a/src/transformers/integrations/hub_kernels.py b/src/transformers/integrations/hub_kernels.py index 70a343424aa8..c541e939b07b 100644 --- a/src/transformers/integrations/hub_kernels.py +++ b/src/transformers/integrations/hub_kernels.py @@ -359,7 +359,11 @@ def load_and_register_attn_kernel( # Register the kernel as a valid attention ALL_ATTENTION_FUNCTIONS.register(attn_implementation, kernel_function) - ALL_MASK_ATTENTION_FUNCTIONS.register(attn_implementation, ALL_MASK_ATTENTION_FUNCTIONS["flash_attention_2"]) + + # Allow the kernel module to declare its preferred mask function (e.g., MASK_FUNCTION = "sdpa"). + # Falls back to "flash_attention_2" for backward compatibility with existing kernels. + mask_type = getattr(kernel, "MASK_FUNCTION", "flash_attention_2") + ALL_MASK_ATTENTION_FUNCTIONS.register(attn_implementation, ALL_MASK_ATTENTION_FUNCTIONS[mask_type]) return kernel diff --git a/src/transformers/integrations/integration_utils.py b/src/transformers/integrations/integration_utils.py index 2656b7169c62..83a653261cbf 100755 --- a/src/transformers/integrations/integration_utils.py +++ b/src/transformers/integrations/integration_utils.py @@ -2261,7 +2261,7 @@ class SwanLabCallback(TrainerCallback): A [`TrainerCallback`] that logs metrics, media, model checkpoints to [SwanLab](https://swanlab.cn/). """ - def __init__(self): + def __init__(self, **kwargs): if not is_swanlab_available(): raise RuntimeError("SwanLabCallback requires swanlab to be installed. 
Run `pip install swanlab`.") import swanlab @@ -2269,6 +2269,7 @@ def __init__(self): self._swanlab = swanlab self._initialized = False self._log_model = os.getenv("SWANLAB_LOG_MODEL", None) + self._init_kwargs = kwargs def setup(self, args, state, model, **kwargs): """ @@ -2352,6 +2353,7 @@ def setup(self, args, state, model, **kwargs): init_args["resume"] = "allow" if self._swanlab.get_run() is None: + init_args.update(self._init_kwargs) self._swanlab.init( **init_args, ) diff --git a/src/transformers/integrations/moe.py b/src/transformers/integrations/moe.py index c8a8e87f3621..acc9b3575a14 100644 --- a/src/transformers/integrations/moe.py +++ b/src/transformers/integrations/moe.py @@ -15,6 +15,8 @@ from collections.abc import Callable from functools import wraps +from torch.distributed.tensor import DTensor + from ..utils import logging from ..utils.generic import GeneralInterface from ..utils.import_utils import ( @@ -354,12 +356,17 @@ def _grouped_linear( Returns: `torch.Tensor`: Output tensor of shape (S, output_dim). """ + # torch._grouped_mm is not registered for autocast, so we need to ensure + # input and weight have the same dtype (e.g. LayerNorm outputs float32 under + # autocast while weights may be bfloat16). + input = input.to(weight.dtype) + if is_transposed: # (S, input_dim) @ grouped (num_experts, input_dim, output_dim) -> (S, output_dim) out = _grouped_mm(input, weight, offs=offs) else: # (S, input_dim) @ grouped (num_experts, output_dim, input_dim).T -> (S, output_dim) - out = _grouped_mm(input, weight.transpose(-2, -1), offs=offs) + out = _grouped_mm(input, weight.transpose(-2, -1).contiguous(), offs=offs) if bias is not None: # We should be able to pass bias to the grouped_mm call, but it's not yet supported. @@ -401,21 +408,29 @@ def grouped_mm_experts_forward( # Compute offsets for grouped_mm # using histc instead of bincount to avoid cuda graph issues # With deterministic algorithms, CPU only supports float input, CUDA only supports int input. - histc_input = expert_ids_g.float() if device.type == "cpu" else expert_ids_g.int() + + # torch.histc() does not support integer dtypes on CPU and MPS. + # It works well and is more efficient on CUDA when using int. + # For all other backends (XPU, TPU/XLA, HPU, etc.), we conservatively + # use float32 as it has broader operator suppor + histc_input = expert_ids_g.int() if device.type == "cuda" else expert_ids_g.to(torch.float32) tokens_per_expert = torch.histc(histc_input, bins=self.num_experts, min=0, max=self.num_experts - 1) offsets = torch.cumsum(tokens_per_expert, dim=0, dtype=torch.int32) + def _local(p): + return p.to_local() if isinstance(p, DTensor) else p + # Select expert weights and biases # NOTE: We keep all experts here and rely on offsets to target the active ones. # I have already implemented a version that only passes the active experts, but # to do so I had to use torch.unique which breaks the graph capture (data-dependent). # Also there were no speedup gains from it in my experiments, even in eager mode. 
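
The `histc`/`cumsum` comment above is the core of how grouped expert matmuls consume routing information: per-expert token counts become end offsets into the expert-sorted token tensor. A small self-contained sketch of that bookkeeping (the grouped matmul itself, `torch._grouped_mm`, is only mentioned in a comment since it is a private op):

```python
import torch

num_experts = 4
# Expert id of each routed token, already grouped/sorted by expert.
expert_ids = torch.tensor([0, 0, 1, 2, 2, 2])

# histc does not accept integer inputs on CPU/MPS, so cast to float there; on CUDA an
# integer input is accepted (and cheaper), which is the branch taken above.
counts = torch.histc(expert_ids.float(), bins=num_experts, min=0, max=num_experts - 1)
offsets = torch.cumsum(counts, dim=0, dtype=torch.int32)  # tensor([2, 3, 6, 6])

# offsets[i] marks the end of expert i's slice along the token dimension; a grouped matmul
# kernel consumes these boundaries instead of looping over experts in Python.
```
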
if self.has_gate: - selected_weights = self.gate_up_proj - selected_biases = self.gate_up_proj_bias[expert_ids_g] if self.has_bias else None + selected_weights = _local(self.gate_up_proj) + selected_biases = _local(self.gate_up_proj_bias)[expert_ids_g] if self.has_bias else None else: - selected_weights = self.up_proj - selected_biases = self.up_proj_bias[expert_ids_g] if self.has_bias else None + selected_weights = _local(self.up_proj) + selected_biases = _local(self.up_proj_bias)[expert_ids_g] if self.has_bias else None # --- Up projection per expert (grouped) --- proj_out = _grouped_linear( @@ -431,8 +446,8 @@ def grouped_mm_experts_forward( proj_out = self.act_fn(proj_out) # (S, intermediate_dim) # Select down projection weights and biases - selected_weights = self.down_proj - selected_biases = self.down_proj_bias[expert_ids_g] if self.has_bias else None + selected_weights = _local(self.down_proj) + selected_biases = _local(self.down_proj_bias)[expert_ids_g] if self.has_bias else None # --- Down projection per expert (grouped) --- proj_out = _grouped_linear( diff --git a/src/transformers/integrations/mxfp4.py b/src/transformers/integrations/mxfp4.py index 67d9420659af..018507a5134b 100644 --- a/src/transformers/integrations/mxfp4.py +++ b/src/transformers/integrations/mxfp4.py @@ -498,15 +498,18 @@ def mlp_forward(self, hidden_states): else: routing = triton_kernels_hub.routing.routing - batch_size = hidden_states.shape[0] - hidden_states = hidden_states.reshape(-1, self.router.hidden_dim) + is_3d = hidden_states.ndim == 3 + if is_3d: + batch_size, seq_len, _ = hidden_states.shape + hidden_states = hidden_states.reshape(-1, self.router.hidden_dim) router_logits = nn.functional.linear(hidden_states, self.router.weight, self.router.bias) with on_device(router_logits.device): routing_data, gather_idx, scatter_idx = routing(router_logits, self.router.top_k) routed_out = self.experts(hidden_states, routing_data, gather_idx, scatter_idx=scatter_idx) - routed_out = routed_out.reshape(batch_size, -1, self.router.hidden_dim) + if is_3d: + routed_out = routed_out.reshape(batch_size, seq_len, self.router.hidden_dim) return routed_out, router_logits diff --git a/src/transformers/integrations/peft.py b/src/transformers/integrations/peft.py index 7b93e0a134b8..cad07bc2d3fc 100644 --- a/src/transformers/integrations/peft.py +++ b/src/transformers/integrations/peft.py @@ -34,6 +34,7 @@ Transpose, WeightConverter, WeightRenaming, + rename_source_key, ) from ..utils import ( CONFIG_NAME, @@ -47,7 +48,7 @@ logging, ) from ..utils.hub import DownloadKwargs -from ..utils.loading_report import log_state_dict_report +from ..utils.loading_report import LoadStateDictInfo, log_state_dict_report if is_torch_available(): @@ -506,6 +507,7 @@ def load_adapter( `find_adapter_config_file` method. """ from peft import PeftType + from peft.tuners.tuners_utils import BaseTunerLayer from peft.utils.save_and_load import _maybe_shard_state_dict_for_tp from ..modeling_utils import LoadStateDictConfig, _get_resolved_checkpoint_files, load_state_dict @@ -618,45 +620,92 @@ def load_adapter( device_map = getattr(self, "hf_device_map", {"": self.device}) - # If the model is tensor parallel, we handle the sharding of the state dict here since the logic in `self._load_pretrained_model` - # is not compatible with the way PEFT adapter should be sharded. 
- has_tp_adapters = False - for module in self.modules(): - tp_info = getattr(module, "_tp_info", None) - if tp_info is not None: - has_tp_adapters = True - break - - if has_tp_adapters: + def _resolve_adapter_state_dict(): + # Materialize the adapter state dict from `adapter_state_dict` or `checkpoint_files`. Used by paths + # that bypass `self._load_pretrained_model` (which would otherwise read the files itself). all_pointer = set() if adapter_state_dict is not None: - merged_state_dict = adapter_state_dict - elif ( - checkpoint_files is not None - and checkpoint_files[0].endswith(".safetensors") - and adapter_state_dict is None - ): + return adapter_state_dict + if checkpoint_files is not None and checkpoint_files[0].endswith(".safetensors"): merged_state_dict = {} for file in checkpoint_files: file_pointer = safe_open(file, framework="pt", device="cpu") all_pointer.add(file_pointer) for k in file_pointer.keys(): merged_state_dict[k] = file_pointer.get_tensor(k) + return merged_state_dict # Checkpoints are .bin - elif checkpoint_files is not None: + if checkpoint_files is not None: merged_state_dict = {} for ckpt_file in checkpoint_files: merged_state_dict.update(load_state_dict(ckpt_file)) - else: - raise ValueError("Neither a state dict nor checkpoint files were found.") + return merged_state_dict + raise ValueError("Neither a state dict nor checkpoint files were found.") - adapter_state_dict = merged_state_dict + def set_inference_mode(model): + model.eval() + for module in model.modules(): + if isinstance(module, BaseTunerLayer): + module.requires_grad_(False) + + # If the model is tensor parallel, we handle the sharding of the state dict here since the logic in `self._load_pretrained_model` + # is not compatible with the way PEFT adapter should be sharded. + has_tp_adapters = False + for module in self.modules(): + tp_info = getattr(module, "_tp_info", None) + if tp_info is not None: + has_tp_adapters = True + break + + if has_tp_adapters: + adapter_state_dict = _resolve_adapter_state_dict() if any(not isinstance(v, torch.Tensor) for v in adapter_state_dict.values()): raise ValueError("Expected all values in the adapter state dict to be tensors.") _maybe_shard_state_dict_for_tp(self, adapter_state_dict, adapter_name) + if hotswap: + # Bypass the standard loader and use PEFT's hotswap path so that LoRA weights + # whose rank differs from the existing adapter's are copied (and zero-padded) + # in place rather than triggering a "size mismatch" reinit, and so the LoRA + # scaling is updated alongside the weights. 
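
The hotswap branch documented above exists so a second LoRA adapter can replace the first inside an already-compiled model without retracing. A hypothetical end-to-end sketch of that flow; the base model and adapter repos are placeholders, and the `target_rank` keyword passed to `enable_peft_hotswap` is an assumption based on PEFT's `prepare_model_for_compiled_hotswap`:

```python
# Hypothetical usage; model and adapter names are placeholders.
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("org/base-model")
model.enable_peft_hotswap(target_rank=16)     # pad LoRA ranks up-front (assumed keyword)
model.load_adapter("org/lora-adapter-a", adapter_name="default")
compiled = torch.compile(model)

# ... run compiled(...) ...

# Swap the weights in place: smaller ranks are zero-padded to target_rank,
# so the compiled graph is reused instead of being recompiled.
model.load_adapter("org/lora-adapter-b", adapter_name="default", hotswap=True)
```
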
+ from peft.utils.hotswap import check_hotswap_configs_compatible, hotswap_adapter_from_state_dict + + adapter_state_dict = _resolve_adapter_state_dict() + + # need to apply conversions manually as we don't use _load_pretrained_model + renamings = [r for r in peft_weight_conversions if isinstance(r, WeightRenaming)] + converters = [c for c in peft_weight_conversions if isinstance(c, WeightConverter)] + meta_state_dict = self.state_dict() + processed_state_dict = {} + for key, value in adapter_state_dict.items(): + renamed_key, _ = rename_source_key(key, renamings, converters, self.base_model_prefix, meta_state_dict) + processed_state_dict[renamed_key] = value + + check_hotswap_configs_compatible(self.peft_config[adapter_name], peft_config) + try: + hotswap_adapter_from_state_dict( + model=self, + state_dict=processed_state_dict, + adapter_name=adapter_name, + config=peft_config, + ) + except Exception as e: + logger.error(f"Hotswapping {adapter_name} was unsuccessful with the following error:\n{e}") + raise + + if peft_config.inference_mode: + set_inference_mode(self) + + return LoadStateDictInfo( + missing_keys=set(), + unexpected_keys=set(), + mismatched_keys=set(), + error_msgs=[], + conversion_errors={}, + ) + load_config = replace( load_config, pretrained_model_name_or_path=peft_model_id, @@ -676,12 +725,7 @@ def load_adapter( ) if peft_config.inference_mode: - from peft.tuners.tuners_utils import BaseTunerLayer - - self.eval() - for module in self.modules(): - if isinstance(module, BaseTunerLayer): - module.requires_grad_(False) + set_inference_mode(self) adapter_key_markers = {adapter_name} if peft_config is not None and getattr(peft_config, "peft_type", None) is not None: @@ -699,6 +743,16 @@ def is_adapter_key(key: str) -> bool: loading_info=loading_info, logger=logger, ) + + if self._prepare_peft_hotswap_kwargs is not None: + # Apply once, after the first adapter has been loaded but before the model is + # compiled, so the LoRA layers get padded up to target_rank and a later adapter + # with a different rank can be hot-swapped in without recompiling. + from peft.utils.hotswap import prepare_model_for_compiled_hotswap + + prepare_model_for_compiled_hotswap(self, config=peft_config, **self._prepare_peft_hotswap_kwargs) + self._prepare_peft_hotswap_kwargs = None + return loading_info def enable_peft_hotswap( diff --git a/src/transformers/integrations/tensor_parallel.py b/src/transformers/integrations/tensor_parallel.py index bdf82e8490f0..21f0a833ef08 100644 --- a/src/transformers/integrations/tensor_parallel.py +++ b/src/transformers/integrations/tensor_parallel.py @@ -29,6 +29,7 @@ import torch import torch.distributed as dist from torch import nn + from torch.distributed.tensor import DTensor, Shard # Cache this result has it's a C FFI call which can be pretty time-consuming _torch_distributed_available = torch.distributed.is_available() @@ -46,8 +47,11 @@ def initialize_tensor_parallelism( """ if tp_size is not None and tp_plan is None: raise ValueError("tp_plan has to be set when tp_size is passed.") - if tp_plan is not None and device_map is not None: - raise ValueError("`tp_plan` and `device_map` are mutually exclusive. Choose either one for parallelization.") + if tp_plan is not None and device_map is not None and device_map != "meta" and device_mesh is None: + raise ValueError( + "`tp_plan` and `device_map` are mutually exclusive. " + "Choose either one for parallelization or include a `device_mesh`." 
+ ) if device_mesh is None: if not is_torch_greater_or_equal("2.5"): raise OSError("Tensor parallel is only supported for `torch>=2.5`.") @@ -97,7 +101,8 @@ def initialize_tensor_parallelism( ) device_mesh = device_mesh["tp"] tp_size = device_mesh.size() - device_map = torch.device(f"{device_mesh.device_type}:{int(os.environ['LOCAL_RANK'])}") + if device_map is None: + device_map = torch.device(f"{device_mesh.device_type}:{int(os.environ['LOCAL_RANK'])}") return device_map, device_mesh, tp_size @@ -130,6 +135,17 @@ def _get_parameter_tp_plan(parameter_name: str, tp_plan: dict[str, str], is_weig return None +def get_ep_sharded_param_names(model) -> list[str]: + """FQNs of parameters whose data is per-rank unique under EP sharding.""" + if not getattr(model, "has_ep", False): + return [] + return [ + name + for name, _ in model.named_parameters() + if _get_parameter_tp_plan(parameter_name=name, tp_plan=model.tp_plan, is_weight=True) == "grouped_gemm" + ] + + # ============================================================================= # Tensor Sharding Utilities # ============================================================================= @@ -685,6 +701,14 @@ def update_module_attributes(self, module: nn.Module): """ pass + def post_shard_wrap(self, param: nn.Parameter) -> nn.Parameter: + """ + Optional final wrap applied to a parameter after `shard_tensor` and before it is + attached to the module. Default is identity. Subclasses can override to e.g. wrap + the local shard as a DTensor. + """ + return param + class ColwiseParallel(TensorParallelLayer): """ @@ -966,8 +990,8 @@ def _prepare_output_fn(self, mod, outputs, device_mesh): if self.embedding_dim_sharding == 0 and hasattr(mod, "_input_mask"): input_mask = mod._input_mask # Use multiplication instead of in-place assignment to preserve gradients - mask_expanded = input_mask.unsqueeze(-1).expand_as(outputs) - outputs = outputs * (~mask_expanded).to(outputs.dtype) + mask = input_mask.unsqueeze(-1) + outputs = outputs * (~mask).to(outputs.dtype) del mod._input_mask return all_reduce_forward(outputs, device_mesh) @@ -1078,6 +1102,15 @@ def update_module_attributes(self, module: nn.Module): if hasattr(module, "num_experts"): module.num_experts = self.get_expected_sharded_shape((self.empty_param.shape[0],))[0] + def post_shard_wrap(self, param: nn.Parameter) -> nn.Parameter: + """ + Wrap the EP-sharded local tensor as a DTensor on the TP/EP mesh. Without this, the + optimizer's foreach ops error with "mixed Tensor and DTensor" against the + FSDP-wrapped DTensor params on the rest of the model. 
+ """ + dt = DTensor.from_local(param.data, self.device_mesh, [Shard(0)], run_check=False) + return nn.Parameter(dt, requires_grad=param.requires_grad) + class RouterParallel(TensorParallelLayer): """ @@ -1488,6 +1521,8 @@ def shard_and_distribute_module( # otherwise loading is crazy slow if not isinstance(param, torch.nn.Parameter): param = torch.nn.Parameter(param, requires_grad=empty_param.is_floating_point()) + if current_shard_plan is not None: + param = tp_layer.post_shard_wrap(param) setattr(module_to_tp, param_type, param) if tp_layer is not None: tp_layer.update_module_attributes(module_to_tp) diff --git a/src/transformers/integrations/torchao.py b/src/transformers/integrations/torchao.py index 421a004dd6e9..2fa20a3982b9 100644 --- a/src/transformers/integrations/torchao.py +++ b/src/transformers/integrations/torchao.py @@ -35,19 +35,10 @@ logger = logging.get_logger(__name__) -def _quantization_type(weight): - from torchao.dtypes import AffineQuantizedTensor - from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor - - if isinstance(weight, AffineQuantizedTensor): - return f"{weight.__class__.__name__}({weight._quantization_type()})" - - if isinstance(weight, LinearActivationQuantizedTensor): - return f"{weight.__class__.__name__}(activation={weight.input_quant_func}, weight={_quantization_type(weight.original_weight_tensor)})" - - def _linear_extra_repr(self): - weight = _quantization_type(self.weight) + from torchao.utils import TorchAOBaseTensor + + weight = self.weight.__class__.__name__ if isinstance(self.weight, TorchAOBaseTensor) else None if weight is None: return f"in_features={self.weight.shape[1]}, out_features={self.weight.shape[0]}, weight=None" else: diff --git a/src/transformers/integrations/tpu.py b/src/transformers/integrations/tpu.py index a329a7fcdd84..e05776aab7fe 100644 --- a/src/transformers/integrations/tpu.py +++ b/src/transformers/integrations/tpu.py @@ -162,7 +162,9 @@ def patched_optimizer_step(optimizer, barrier=False, optimizer_args={}): return model -def save_tpu_checkpoint(model, args, accelerator, processing_class, is_fsdp_xla_v1_enabled, output_dir=None): +def save_tpu_checkpoint( + model, args, accelerator, processing_class, is_fsdp_xla_v1_enabled, is_fsdp_xla_v2_enabled, output_dir=None +): """ Saves a model checkpoint on TPU/XLA devices. @@ -175,10 +177,13 @@ def save_tpu_checkpoint(model, args, accelerator, processing_class, is_fsdp_xla_ accelerator (`Accelerator`): The accelerator instance. processing_class: The processing class (tokenizer/processor) to save alongside the model. is_fsdp_xla_v1_enabled (`bool`): Whether FSDP XLA v1 is enabled. + is_fsdp_xla_v2_enabled (`bool`): Whether FSDP XLA v2 is enabled. output_dir (`str`, *optional*): The directory to save to. Defaults to `args.output_dir`. 
""" import torch_xla.core.xla_model as xm + from ..modeling_utils import unwrap_model + output_dir = output_dir if output_dir is not None else args.output_dir logger.info(f"Saving model checkpoint to {output_dir}") @@ -219,15 +224,16 @@ def save_tpu_checkpoint(model, args, accelerator, processing_class, is_fsdp_xla_ logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.") xm.save(full_state_dict, os.path.join(output_dir, WEIGHTS_NAME)) elif not isinstance(model, supported_classes): - if isinstance(accelerator.unwrap_model(model), supported_classes): - accelerator.unwrap_model(model).save_pretrained( + unwrapped_model = unwrap_model(model, recursive=is_fsdp_xla_v2_enabled) + if isinstance(unwrapped_model, supported_classes): + unwrapped_model.save_pretrained( output_dir, is_main_process=args.should_save, - state_dict=xm._maybe_convert_to_cpu(model.state_dict()), + state_dict=xm._maybe_convert_to_cpu(unwrapped_model.state_dict()), ) else: logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.") - state_dict = xm._maybe_convert_to_cpu(model.state_dict()) + state_dict = xm._maybe_convert_to_cpu(unwrapped_model.state_dict()) xm.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME)) else: model.save_pretrained( diff --git a/src/transformers/loss/loss_for_object_detection.py b/src/transformers/loss/loss_for_object_detection.py index 52b43f779f35..79469785827d 100644 --- a/src/transformers/loss/loss_for_object_detection.py +++ b/src/transformers/loss/loss_for_object_detection.py @@ -31,7 +31,7 @@ from transformers.image_transforms import center_to_corners_format -def dice_loss(inputs, targets, num_boxes): +def dice_loss(inputs, targets, num_boxes, valid_mask=None): """ Compute the DICE loss, similar to generalized IOU for masks @@ -41,16 +41,25 @@ def dice_loss(inputs, targets, num_boxes): targets: A float tensor with the same shape as inputs. Stores the binary classification label for each element in inputs (0 for the negative class and 1 for the positive class). + valid_mask: Optional boolean tensor with the same shape as inputs. + If provided, only valid (non-padding) areas are considered in the loss. + True means valid, False means padding. """ inputs = inputs.sigmoid() inputs = inputs.flatten(1) + + if valid_mask is not None: + valid_mask = valid_mask.flatten(1).to(dtype=inputs.dtype) + inputs = inputs * valid_mask + targets = targets * valid_mask + numerator = 2 * (inputs * targets).sum(1) denominator = inputs.sum(-1) + targets.sum(-1) loss = 1 - (numerator + 1) / (denominator + 1) return loss.sum() / num_boxes -def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2): +def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2, valid_mask=None): """ Loss used in RetinaNet for dense detection: https://huggingface.co/papers/1708.02002. @@ -64,6 +73,9 @@ def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: f Optional weighting factor in the range (0,1) to balance positive vs. negative examples. gamma (`int`, *optional*, defaults to `2`): Exponent of the modulating factor (1 - p_t) to balance easy vs hard examples. + valid_mask: Optional boolean tensor with the same shape as inputs. + If provided, only valid (non-padding) areas are considered in the loss. + True means valid, False means padding. 
Returns: Loss tensor @@ -78,6 +90,13 @@ def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: f alpha_t = alpha * targets + (1 - alpha) * (1 - targets) loss = alpha_t * loss + if valid_mask is not None: + valid_mask = valid_mask.flatten(1).to(dtype=loss.dtype) + loss = loss * valid_mask + # Average only over valid pixels per sample + valid_count = valid_mask.sum(1).clamp(min=1) + return (loss.sum(1) / valid_count).sum() / num_boxes + return loss.mean(1).sum() / num_boxes @@ -193,11 +212,16 @@ def loss_masks(self, outputs, targets, indices, num_boxes): source_masks = outputs["pred_masks"] source_masks = source_masks[source_idx] masks = [t["masks"] for t in targets] - # TODO use valid to mask invalid areas due to padding in loss target_masks, valid = nested_tensor_from_tensor_list(masks).decompose() target_masks = target_masks.to(source_masks) target_masks = target_masks[target_idx] + # Get valid mask for selected targets (invert: True = valid, False = padding) + # valid has shape (batch, h, w), we need to index by batch indices only + batch_idx = target_idx[0] + valid_mask = ~valid + valid_mask = valid_mask[batch_idx] + # upsample predictions to the target size source_masks = nn.functional.interpolate( source_masks[:, None], size=target_masks.shape[-2:], mode="bilinear", align_corners=False @@ -206,9 +230,12 @@ def loss_masks(self, outputs, targets, indices, num_boxes): target_masks = target_masks.flatten(1) target_masks = target_masks.view(source_masks.shape) + valid_mask = valid_mask.flatten(1) + valid_mask = valid_mask.view(source_masks.shape) + losses = { - "loss_mask": sigmoid_focal_loss(source_masks, target_masks, num_boxes), - "loss_dice": dice_loss(source_masks, target_masks, num_boxes), + "loss_mask": sigmoid_focal_loss(source_masks, target_masks, num_boxes, valid_mask=valid_mask), + "loss_dice": dice_loss(source_masks, target_masks, num_boxes, valid_mask=valid_mask), } return losses diff --git a/src/transformers/loss/loss_rt_detr.py b/src/transformers/loss/loss_rt_detr.py index cf6d6ad05940..69dc1ff67600 100644 --- a/src/transformers/loss/loss_rt_detr.py +++ b/src/transformers/loss/loss_rt_detr.py @@ -270,6 +270,12 @@ def loss_masks(self, outputs, targets, indices, num_boxes): target_masks = target_masks.to(source_masks) target_masks = target_masks[target_idx] + # Get valid mask for selected targets (invert: True = valid, False = padding) + # valid has shape (batch, h, w), we need to index by batch indices only + batch_idx = target_idx[0] + valid_mask = ~valid + valid_mask = valid_mask[batch_idx] + # upsample predictions to the target size source_masks = nn.functional.interpolate( source_masks[:, None], size=target_masks.shape[-2:], mode="bilinear", align_corners=False @@ -278,9 +284,12 @@ def loss_masks(self, outputs, targets, indices, num_boxes): target_masks = target_masks.flatten(1) target_masks = target_masks.view(source_masks.shape) + valid_mask = valid_mask.flatten(1) + valid_mask = valid_mask.view(source_masks.shape) + losses = { - "loss_mask": sigmoid_focal_loss(source_masks, target_masks, num_boxes), - "loss_dice": dice_loss(source_masks, target_masks, num_boxes), + "loss_mask": sigmoid_focal_loss(source_masks, target_masks, num_boxes, valid_mask=valid_mask), + "loss_dice": dice_loss(source_masks, target_masks, num_boxes, valid_mask=valid_mask), } return losses diff --git a/src/transformers/loss/loss_utils.py b/src/transformers/loss/loss_utils.py index 51564d299e55..07a75fbb57b1 100644 --- a/src/transformers/loss/loss_utils.py +++ 
b/src/transformers/loss/loss_utils.py @@ -31,10 +31,14 @@ def fixed_cross_entropy( target: torch.Tensor, num_items_in_batch: torch.Tensor | None = None, ignore_index: int = -100, + weight: torch.Tensor | None = None, + label_smoothing: float = 0.0, **kwargs, ) -> torch.Tensor: reduction = "sum" if num_items_in_batch is not None else "mean" - loss = nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction=reduction) + loss = nn.functional.cross_entropy( + source, target, ignore_index=ignore_index, weight=weight, reduction=reduction, label_smoothing=label_smoothing + ) if reduction == "sum": # just in case users pass an int for num_items_in_batch, which could be the case for custom trainer if torch.is_tensor(num_items_in_batch): @@ -52,9 +56,6 @@ def ForCausalLMLoss( shift_labels: torch.Tensor | None = None, **kwargs, ) -> torch.Tensor: - # Upcast to float if we need to compute the loss to avoid potential precision issues - logits = logits.float() - if shift_labels is None: # Shift so that tokens < n predict n labels = nn.functional.pad(labels, (0, 1), value=ignore_index) @@ -63,6 +64,13 @@ def ForCausalLMLoss( # Flatten the tokens logits = logits.view(-1, vocab_size) shift_labels = shift_labels.view(-1) + # Filter out the ignore_index labels + mask = shift_labels != ignore_index + shift_labels = shift_labels[mask] + logits = logits[mask.to(logits.device)] + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() + # Enable model parallelism shift_labels = shift_labels.to(logits.device) loss = fixed_cross_entropy(logits, shift_labels, num_items_in_batch, ignore_index, **kwargs) return loss diff --git a/src/transformers/modelcard.py b/src/transformers/modelcard.py index e833a5a8a2ab..8a7195f13806 100644 --- a/src/transformers/modelcard.py +++ b/src/transformers/modelcard.py @@ -21,8 +21,7 @@ import httpx import yaml from huggingface_hub import is_offline_mode, model_info -from huggingface_hub.errors import OfflineModeIsEnabled -from huggingface_hub.utils import HFValidationError +from huggingface_hub.errors import HFValidationError, OfflineModeIsEnabled from . import __version__ from .models.auto.modeling_auto import ( diff --git a/src/transformers/modeling_flash_attention_utils.py b/src/transformers/modeling_flash_attention_utils.py index 32642d71d2a3..690254513618 100644 --- a/src/transformers/modeling_flash_attention_utils.py +++ b/src/transformers/modeling_flash_attention_utils.py @@ -328,7 +328,7 @@ def _pad_input(hidden_states, indices, batch, seqlen): return output.view(batch, seqlen, *dim) -def _get_unpad_data(attention_mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, int]: +def _get_unpad_data(attention_mask: torch.Tensor) -> tuple[torch.Tensor, int]: """ Retrieves indexing data required to repad unpadded (ragged) tensors. @@ -337,19 +337,15 @@ def _get_unpad_data(attention_mask: torch.Tensor) -> tuple[torch.Tensor, torch.T Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid. Return: - indices (`torch.Tensor`): - The indices of non-masked tokens from the flattened input sequence. cu_seqlens (`torch.Tensor`): The cumulative sequence lengths, used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,). max_seqlen_in_batch (`int`): Maximum sequence length in batch. 
""" seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) - indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() max_seqlen_in_batch = seqlens_in_batch.max() cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) return ( - indices, cu_seqlens, max_seqlen_in_batch, ) @@ -396,7 +392,8 @@ def _upad_input( (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`tuple[int]`): Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value). """ - indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + flatten_mask = attention_mask.reshape(-1).bool() # With static caches, the k/v states may be larger than the mask -> we need to slice them to avoid generating garbage # It's a bit of an anti-pattern, but otherwise we silently compute wrong attentions scores @@ -405,13 +402,15 @@ def _upad_input( batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape - key_layer = _index_first_axis(key_layer, indices_k) - value_layer = _index_first_axis(value_layer, indices_k) + key_layer = _index_first_axis(key_layer, flatten_mask) + value_layer = _index_first_axis(value_layer, flatten_mask) if query_length == kv_seq_len: - query_layer = _index_first_axis(query_layer, indices_k) + query_layer = _index_first_axis(query_layer, flatten_mask) cu_seqlens_q = cu_seqlens_k - max_seqlen_in_batch_q = max_seqlen_in_batch_k - indices_q = indices_k + # NOTE: Similar to the `.item()` in prepare_fa2_from_position_ids, with torch compile, + # this might cause a graph break + max_seqlen_in_batch_q = max_seqlen_in_batch_k.item() + indices_q = flatten_mask.nonzero(as_tuple=False).flatten() elif query_length == 1: max_seqlen_in_batch_q = 1 cu_seqlens_q = torch.arange( @@ -517,7 +516,7 @@ def _is_packed_sequence(position_ids, batch_size): 2. Flattened sequences only are supported 3. Compile-friendly `not (torch.diff(position_ids, dim=-1) >= 0).all()`, i.e. we have multiple increasing sequences """ - if position_ids is None: + if is_tracing(position_ids) or position_ids is None: return False increasing_position_sequences = ( @@ -616,6 +615,21 @@ def _process_flash_attention_kwargs( flash_kwargs (`dict`): A dict of kwargs that are requested and supported. """ + + user_kwargs = { + "dropout_p": dropout, + "window_size": sliding_window, + "deterministic": deterministic, + "softcap": softcap, + "s_aux": s_aux, + } + # Note 'window_size' in supports_mapping maps to our 'sliding_window' param + for k, v in user_kwargs.items(): + if not supports_mapping[k] and v is not None: + raise ValueError( + f"Parameter `{k}` is not supported by this Flash Attention implementation but was set, please use a different attentionimplementation." 
+ ) + flash_kwargs = { "causal": is_causal and not (use_top_left_mask and query_length == 1), "softmax_scale": softmax_scale, diff --git a/src/transformers/modeling_layers.py b/src/transformers/modeling_layers.py index 1012606fcaaf..2aca6fda0aa3 100644 --- a/src/transformers/modeling_layers.py +++ b/src/transformers/modeling_layers.py @@ -102,7 +102,7 @@ def __init__(self, config): self.num_labels = config.num_labels # Similar to `self.model = AutoModel.from_config(config)` but allows to change the base model name if needed in the child class setattr(self, self.base_model_prefix, AutoModel.from_config(config)) - self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + self.score = nn.Linear(config.get_text_config().hidden_size, self.num_labels, bias=False) # Initialize weights and apply final processing self.post_init() @@ -137,13 +137,13 @@ def forward( else: batch_size = inputs_embeds.shape[0] - if self.config.pad_token_id is None and batch_size != 1: + if self.config.get_text_config().pad_token_id is None and batch_size != 1: raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") - if self.config.pad_token_id is None: + if self.config.get_text_config().pad_token_id is None: last_non_pad_token = -1 elif input_ids is not None: # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id - non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32) + non_pad_mask = (input_ids != self.config.get_text_config().pad_token_id).to(logits.device, torch.int32) token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32) last_non_pad_token = (token_indices * non_pad_mask).argmax(-1) else: diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index b041964bbdfc..b2ea39eeb850 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1316,6 +1316,12 @@ def __init__(self, config: PreTrainedConfig, *inputs, **kwargs): ) self.config = config self.name_or_path = config.name_or_path + quant_config = getattr(config, "quantization_config", None) + if quant_config is not None: + raise NotImplementedError( + "Quantization via `from_config()` is not supported. " + "Quantized models must be created via `from_pretrained()` with an appropriate backend." + ) # Check the attention implementation is supported, or set it if not yet set (on the internal attr, to avoid # setting it recursively) @@ -1368,6 +1374,9 @@ def post_init(self): self._keep_in_fp32_modules_strict = set(self._keep_in_fp32_modules_strict or []) # Current submodel must register its `_no_split_modules` as well self._no_split_modules = set(self._no_split_modules or []) + # Current submodel must register the `_keys_to_ignore_on_load_unexpected/missing` + self._keys_to_ignore_on_load_unexpected = self._keys_to_ignore_on_load_unexpected or [] + self._keys_to_ignore_on_load_missing = self._keys_to_ignore_on_load_missing or [] # Iterate over children only: as the final model is created, this is enough to gather the properties from all submodels. 
# This works because the way the `__init__` and `post_init` are called on all submodules is depth-first in the graph @@ -1390,17 +1399,40 @@ def post_init(self): # Record `_no_split_modules` from the children if no_split := getattr(module, "_no_split_modules", None): self._no_split_modules.update(no_split) + # Record `_keys_to_ignore_on_load_unexpected/missing` from the children + if ignore_unexpected := getattr(module, "_keys_to_ignore_on_load_unexpected", None): + self._keys_to_ignore_on_load_unexpected.extend( + [f"{name}.{child_name}" for child_name in ignore_unexpected] + ) + if ignore_missing := getattr(module, "_keys_to_ignore_on_load_missing", None): + self._keys_to_ignore_on_load_missing.extend([f"{name}.{child_name}" for child_name in ignore_missing]) + + # Preserve the current no-tie scope on this instance so only the model + # being initialized in that scope skips tie_weights(). + self._skip_tie_weights_scope = init._SKIP_TIE_WEIGHTS_SCOPE.get() # Maybe initialize the weights and tie the keys self.init_weights() self._backward_compatibility_gradient_checkpointing() + # Cache the list of (name, submodule) pairs where the submodule is a PreTrainedModel. + # This pattern is used in several places across the codebase; computing it once avoids + # repeated traversal of the full module tree. + self._named_pretrained_submodules: list[tuple[str, PreTrainedModel]] = [ + (name, module) for name, module in self.named_modules() if isinstance(module, PreTrainedModel) + ] + + @property + def has_ep(self) -> bool: + """Whether expert parallelism is enabled for this model.""" + distributed_config = getattr(getattr(self, "config", None), "distributed_config", None) + return distributed_config is not None and getattr(distributed_config, "enable_expert_parallel", False) @property def tp_plan(self) -> dict[str, str]: """ The full tp plan for the model's modules """ - if hasattr(self.config, "distributed_config") and self.config.distributed_config.enable_expert_parallel: + if self.has_ep: return self._ep_plan return self._tp_plan @@ -2371,14 +2403,25 @@ def _init_weights(self, module): if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.ConvTranspose1d, nn.ConvTranspose2d)): if getattr(module, "weight", None) is not None: - init.normal_(module.weight, mean=0.0, std=std) - if module.bias is not None: + if module.weight.dtype in (torch.int8, torch.uint8): + logger.debug( + f"Skipping weight initialization for quantized module {module.__class__.__name__} with dtype " + f"{module.weight.dtype}" + ) + else: + init.normal_(module.weight, mean=0.0, std=std) + if module.bias is not None and module.bias.dtype not in (torch.int8, torch.uint8): init.zeros_(module.bias) elif isinstance(module, nn.Embedding): - init.normal_(module.weight, mean=0.0, std=std) - # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it looses the flag - if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False): - init.zeros_(module.weight[module.padding_idx]) + if module.weight.dtype in (torch.int8, torch.uint8): + logger.debug( + f"Skipping weight initialization for quantized embedding with dtype {module.weight.dtype}" + ) + else: + init.normal_(module.weight, mean=0.0, std=std) + # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it looses the flag + if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False): + init.zeros_(module.weight[module.padding_idx]) elif 
isinstance(module, nn.MultiheadAttention): # This uses torch's original init module._reset_parameters() @@ -2591,6 +2634,9 @@ def tie_weights(self, missing_keys: set[str] | None = None, recompute_mapping: b `source` is missing in the checkpoint while `target` exists, we *swap* source and target so we can still tie everything to the parameter that actually exists. """ + if init.should_skip_tie_weights(self): + return + # In this case, the keys stored in `all_tied_weights_keys` are already correct if not recompute_mapping: tied_keys = self.all_tied_weights_keys @@ -3338,8 +3384,10 @@ def save_pretrained( files_timestamps = self._get_files_timestamps(save_directory) metadata = {} + quantizer_provided_state_dict = False if hf_quantizer is not None: state_dict, metadata = hf_quantizer.get_state_dict_and_metadata(self) + quantizer_provided_state_dict = state_dict is not None metadata["format"] = "pt" # Only save the model itself if we are using distributed training @@ -3428,7 +3476,8 @@ def save_pretrained( state_dict = remove_tied_weights_from_state_dict(state_dict, model_to_save) # Revert all renaming and/or weight operations - if save_original_format and not _hf_peft_config_loaded: + # Skip if the quantizer already provided the state_dict in the correct serialization format + if save_original_format and not _hf_peft_config_loaded and not quantizer_provided_state_dict: state_dict = revert_weight_conversion(model_to_save, state_dict) # Shard the model if it is too big. @@ -3671,14 +3720,27 @@ def float(self, *args): @classmethod def get_init_context( - cls, dtype: torch.dtype, is_quantized: bool, _is_ds_init_called: bool, allow_all_kernels: bool | None + cls, + dtype: torch.dtype, + is_quantized: bool, + _is_ds_init_called: bool, + allow_all_kernels: bool | None, + distributed_config=None, ): # Need to instantiate with correct dtype init_contexts = [local_torch_dtype(dtype, cls.__name__), init.no_tie_weights(), apply_patches()] # Needed as we cannot forward the `allow_all_kernels` arg in the model's __init__ if allow_all_kernels: init_contexts.append(allow_all_hub_kernels()) - if is_deepspeed_zero3_enabled(): + _has_ep = distributed_config is not None and getattr(distributed_config, "enable_expert_parallel", False) + if _has_ep and is_deepspeed_zero3_enabled(): + # EP + DeepSpeed: use meta device (same as the normal non-DS path). + # zero.Init is skipped because EP needs to shard experts via distribute_model() + # hooks, which are incompatible with ZeRO-3 lazy parameters. + # The standard weight loading path (not zero3) handles EP sharding via + # shard_and_distribute_module. deepspeed.initialize() wraps the result later. + init_contexts.extend([torch.device("meta"), init.meta_device_safe_creation_ops()]) + elif is_deepspeed_zero3_enabled(): import deepspeed # We cannot initialize the model on meta device with deepspeed when not quantized @@ -4086,6 +4148,12 @@ def from_pretrained( download_kwargs_with_commit, **adapter_kwargs, ) + # EP + DeepSpeed: clear device_map (set by initialize_tensor_parallelism) so the model + # loads on CPU first. distribute_model() handles GPU placement during EP sharding. + # Without this, device_map triggers accelerate's dispatch path which breaks shard loading. 
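
The EP + DeepSpeed branch above swaps `zero.Init` for plain meta-device construction: parameters get shapes and dtypes but no storage until the checkpoint is loaded and the expert-parallel sharding hooks run. A standalone illustration of that construction pattern (not the actual init context used by `get_init_context`):

```python
import torch
from torch import nn

# Construct on the meta device: instant, no memory allocated for the weights.
with torch.device("meta"):
    layer = nn.Linear(4096, 4096)
print(layer.weight.device)   # meta
print(layer.weight.numel())  # 16777216, but backed by no storage

# Materialize real (uninitialized) storage later, then fill it from a checkpoint.
layer = layer.to_empty(device="cpu")
```
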
+ _has_ep = distributed_config is not None and getattr(distributed_config, "enable_expert_parallel", False) + if _has_ep and is_deepspeed_zero3_enabled(): + device_map = None device_map = check_and_set_device_map(device_map) # warn, error and fix the device map user_agent = {"file_type": "model", "framework": "pytorch", "from_auto_class": from_auto_class} @@ -4194,7 +4262,9 @@ def from_pretrained( register_fusion_patches(cls, config, fusion_config) - model_init_context = cls.get_init_context(dtype, is_quantized, _is_ds_init_called, allow_all_kernels) + model_init_context = cls.get_init_context( + dtype, is_quantized, _is_ds_init_called, allow_all_kernels, distributed_config + ) config = copy.deepcopy(config) # We do not want to modify the config inplace in from_pretrained. with ContextManagers(model_init_context): @@ -4327,7 +4397,11 @@ def _load_pretrained_model( error_msgs = [] - if is_deepspeed_zero3_enabled() and not is_quantized: + # EP + DeepSpeed: skip zero3 loading path. The model was created on meta device + # (not via zero.Init), so params are not zero3-partitioned. The standard loading + # path handles EP sharding via shard_and_distribute_module using the EP plan hooks + # registered by distribute_model(). deepspeed.initialize() wraps the result later. + if is_deepspeed_zero3_enabled() and not is_quantized and not model.has_ep: if state_dict is None: merged_state_dict = {} for ckpt_file in checkpoint_files: @@ -4646,14 +4720,12 @@ def _move_missing_keys_from_meta_to_device( """ is_quantized = hf_quantizer is not None # This is the only case where we do not initialize the model on meta device, so we don't have to do anything here - if is_deepspeed_zero3_enabled() and not is_quantized: + # Exception: EP + DeepSpeed uses meta device (not zero.Init), so it needs the standard move path. + if is_deepspeed_zero3_enabled() and not is_quantized and not self.has_ep: return - # In this case we need to move everything back + # Leave parameters on meta on non-rank-0 FSDP ranks (rank-0 broadcast overwrites them); only buffers need real placeholders. if is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized: - for key, param in self.named_parameters(): - value = torch.zeros_like(param, device="cpu") - _load_parameter_into_model(self, key, value) for key, buffer in self.named_buffers(): value = torch.zeros_like(buffer, device="cpu") _load_parameter_into_model(self, key, value) @@ -4704,7 +4776,7 @@ def _initialize_missing_keys(self, is_quantized: bool) -> None: self._is_hf_initialized = True # This will only initialize submodules that are not marked as initialized by the line above. 
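
The guard added to `_initialize_missing_keys` below mirrors the earlier `_init_weights` change: a floating-point initializer must never run over already-quantized integer tensors. A compact sketch of that rule as a helper (the function name is made up for illustration):

```python
import torch
from torch import nn

def init_if_floating_point(module: nn.Module, std: float = 0.02) -> None:
    """Normal-init float weights; leave int8/uint8 (quantized) tensors untouched."""
    weight = getattr(module, "weight", None)
    if weight is None or not weight.dtype.is_floating_point:
        return
    nn.init.normal_(weight, mean=0.0, std=std)
    bias = getattr(module, "bias", None)
    if bias is not None and bias.dtype.is_floating_point:
        nn.init.zeros_(bias)
```
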
- if is_deepspeed_zero3_enabled() and not is_quantized: + if is_deepspeed_zero3_enabled() and not is_quantized and not self.has_ep: import deepspeed # keep_vars=True as we need the original tensors, so that the "_is_hf_initialized" is present on them @@ -4714,7 +4786,21 @@ def _initialize_missing_keys(self, is_quantized: bool) -> None: with deepspeed.zero.GatheredParameters(not_initialized_parameters, modifier_rank=0): self.initialize_weights() else: - self.initialize_weights() + try: + all_params = [p for p in self.parameters() if p is not None] + if all_params and not any(p.dtype.is_floating_point for p in all_params): + logger.info("Skipping weight initialization for quantized model (non-floating-point dtype).") + skip_weight_initialization = True + else: + skip_weight_initialization = False + except Exception: + skip_weight_initialization = False + + if not skip_weight_initialization: + self.initialize_weights() + else: + logger.info("Weight initialization skipped.") + def _adjust_missing_and_unexpected_keys(self, loading_info: LoadStateDictInfo) -> None: """Adjust the `missing_keys` and `unexpected_keys` based on current model's exception rules, to avoid @@ -4800,7 +4886,19 @@ def get_parameter_or_buffer(self, target: str): ): return module.get_extra_state() - raise AttributeError(f"`{target}` is neither a parameter, buffer, nor extra state.") + def __recursive_getattr(object, attribute, *args): + """Recurse through a parameter name that is '.' separated to get the attribute""" + + def __getattr(object, attribute): + return getattr(object, attribute, *args) + + return functools.reduce(__getattr, [object] + attribute.split(".")) + + try: + # get the actual tensor parameter from a possibly nested list + return __recursive_getattr(module, param_name) + except AttributeError: + raise AttributeError(f"`{target}` is neither a parameter, buffer, nor extra state.") def named_non_persistent_buffers( self, recurse: bool = True, remove_duplicate: bool = True diff --git a/src/transformers/models/afmoe/modeling_afmoe.py b/src/transformers/models/afmoe/modeling_afmoe.py index 421119b33deb..74fc8bc03b6a 100644 --- a/src/transformers/models/afmoe/modeling_afmoe.py +++ b/src/transformers/models/afmoe/modeling_afmoe.py @@ -103,7 +103,7 @@ def forward(self, x, position_ids): device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" with maybe_autocast(device_type=device_type, enabled=False): # Force float32 - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + freqs = (inv_freq_expanded.float() * position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling sin = emb.sin() * self.attention_scaling diff --git a/src/transformers/models/align/processing_align.py b/src/transformers/models/align/processing_align.py index fa15fcce3de6..85b26d160058 100644 --- a/src/transformers/models/align/processing_align.py +++ b/src/transformers/models/align/processing_align.py @@ -17,9 +17,11 @@ from ...processing_utils import ProcessingKwargs, ProcessorMixin from ...utils import auto_docstring +from ..efficientnet.image_processing_efficientnet import EfficientNetImageProcessorKwargs class AlignProcessorKwargs(ProcessingKwargs, total=False): + images_kwargs: EfficientNetImageProcessorKwargs # see processing_utils.ProcessingKwargs documentation for usage.
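The `@` to `*` change in the rotary-embedding hunk above is repeated for many models later in this diff (Apertus, Arcee, Aria, Bamba, BitNet, Chameleon, CSM, CWM, DBRX, DeepseekV3, and others). Because `inv_freq_expanded` has shape `(batch, dim/2, 1)` and `position_ids_expanded` has shape `(batch, 1, seq_len)`, the batched matmul reduces over a single element, so a broadcast elementwise multiply produces the same `(batch, dim/2, seq_len)` tensor. A quick standalone check in plain PyTorch, for illustration only:

```python
import torch

batch, dim_half, seq_len = 2, 8, 5
inv_freq_expanded = torch.rand(batch, dim_half, 1)
position_ids_expanded = torch.rand(batch, 1, seq_len)

# (batch, dim_half, 1) @ (batch, 1, seq_len) is a per-batch outer product:
# each output element is a single multiplication, with nothing to accumulate,
# so broadcasting an elementwise multiply gives the identical result.
via_matmul = inv_freq_expanded @ position_ids_expanded
via_broadcast = inv_freq_expanded * position_ids_expanded

print(via_matmul.shape, via_broadcast.shape)   # torch.Size([2, 8, 5]) twice
print(torch.allclose(via_matmul, via_broadcast))  # True
```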
_defaults = { "text_kwargs": { diff --git a/src/transformers/models/altclip/modeling_altclip.py b/src/transformers/models/altclip/modeling_altclip.py index 6162cb29559e..2ef1a1f30213 100755 --- a/src/transformers/models/altclip/modeling_altclip.py +++ b/src/transformers/models/altclip/modeling_altclip.py @@ -125,7 +125,7 @@ def forward( if token_type_ids is None: if hasattr(self, "token_type_ids"): # NOTE: We assume either pos ids to have bsz == 1 (broadcastable) or bsz == effective bsz (input_shape[0]) - buffered_token_type_ids = self.token_type_ids.expand(position_ids.shape[0], -1) + buffered_token_type_ids = self.token_type_ids.to(position_ids.device).expand(position_ids.shape[0], -1) buffered_token_type_ids = torch.gather(buffered_token_type_ids, dim=1, index=position_ids) token_type_ids = buffered_token_type_ids.expand(batch_size, seq_length) else: @@ -630,7 +630,7 @@ class AltCLIPPreTrainedModel(PreTrainedModel): config: AltCLIPConfig base_model_prefix = "altclip" input_modalities = ("image", "text") - _no_split_modules = ["AltCLIPTextEmbeddings", "AltCLIPEncoderLayer", "AltCLIPVisionEmbeddings"] + _no_split_modules = ["AltRobertaEmbeddings", "AltRobertaLayer", "AltCLIPEncoderLayer", "AltCLIPVisionEmbeddings"] supports_gradient_checkpointing = True _supports_sdpa = True @@ -705,7 +705,7 @@ def __init__(self, config: AltCLIPVisionConfig): embed_dim = config.hidden_size self.embeddings = AltCLIPVisionEmbeddings(config) - self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + self.pre_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) self.encoder = AltCLIPEncoder(config) self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) self.post_init() @@ -742,7 +742,7 @@ def forward( >>> pooled_output = outputs.pooler_output # pooled CLS states ```""" hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) - hidden_states = self.pre_layrnorm(hidden_states) + hidden_states = self.pre_layernorm(hidden_states) encoder_outputs: BaseModelOutput = self.encoder( inputs_embeds=hidden_states, diff --git a/src/transformers/models/altclip/modular_altclip.py b/src/transformers/models/altclip/modular_altclip.py index fe9be6cac92f..ed36ac6e2a48 100644 --- a/src/transformers/models/altclip/modular_altclip.py +++ b/src/transformers/models/altclip/modular_altclip.py @@ -226,6 +226,7 @@ class AltCLIPVisionEmbeddings(CLIPVisionEmbeddings): class AltCLIPPreTrainedModel(CLIPPreTrainedModel): + _no_split_modules = ["AltRobertaEmbeddings", "AltRobertaLayer", "AltCLIPEncoderLayer", "AltCLIPVisionEmbeddings"] _can_record_outputs = { "hidden_states": AltCLIPEncoderLayer, "attentions": AltCLIPAttention, diff --git a/src/transformers/models/apertus/modeling_apertus.py b/src/transformers/models/apertus/modeling_apertus.py index 7d14dd3d14c8..af1a03c7c900 100644 --- a/src/transformers/models/apertus/modeling_apertus.py +++ b/src/transformers/models/apertus/modeling_apertus.py @@ -134,7 +134,7 @@ def forward(self, x, position_ids): device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" with maybe_autocast(device_type=device_type, enabled=False): # Force float32 - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + freqs = (inv_freq_expanded.float() * position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling sin = emb.sin() * self.attention_scaling diff --git 
a/src/transformers/models/arcee/modeling_arcee.py b/src/transformers/models/arcee/modeling_arcee.py index 8d2d05bf2952..4e99339ca294 100644 --- a/src/transformers/models/arcee/modeling_arcee.py +++ b/src/transformers/models/arcee/modeling_arcee.py @@ -139,7 +139,7 @@ def forward(self, x, position_ids): device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" with maybe_autocast(device_type=device_type, enabled=False): # Force float32 - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + freqs = (inv_freq_expanded.float() * position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling sin = emb.sin() * self.attention_scaling diff --git a/src/transformers/models/aria/modeling_aria.py b/src/transformers/models/aria/modeling_aria.py index e66b12438940..76d8459de528 100644 --- a/src/transformers/models/aria/modeling_aria.py +++ b/src/transformers/models/aria/modeling_aria.py @@ -673,7 +673,7 @@ def forward(self, x, position_ids): device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" with maybe_autocast(device_type=device_type, enabled=False): # Force float32 - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + freqs = (inv_freq_expanded.float() * position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling sin = emb.sin() * self.attention_scaling @@ -946,9 +946,9 @@ def get_placeholder_mask( n_image_tokens = special_image_mask.sum() n_image_features = image_features.shape[0] * image_features.shape[1] - special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + special_image_mask = special_image_mask.unsqueeze(-1).to(inputs_embeds.device) torch_compilable_check( - inputs_embeds[special_image_mask].numel() == image_features.numel(), + n_image_tokens * inputs_embeds.shape[-1] == image_features.numel(), f"Image features and image tokens do not match, tokens: {n_image_tokens}, features: {n_image_features}", ) return special_image_mask diff --git a/src/transformers/models/audioflamingo3/convert_audioflamingo3_to_hf.py b/src/transformers/models/audioflamingo3/convert_audioflamingo3_to_hf.py index 246e37edd729..000d786560bb 100644 --- a/src/transformers/models/audioflamingo3/convert_audioflamingo3_to_hf.py +++ b/src/transformers/models/audioflamingo3/convert_audioflamingo3_to_hf.py @@ -233,7 +233,7 @@ def merge_and_shard_weights(src_root: Path, dst_root: Path, processor: AudioFlam --dst_dir audio-flamingo-3-hf ``` -3) Convert and push directly to the Hub (requires `huggingface-cli login` or `HF_TOKEN`): +3) Convert and push directly to the Hub (requires `hf auth login` or `HF_TOKEN`): ``` python src/transformers/models/audioflamingo3/convert_audioflamingo3_to_hf.py \ diff --git a/src/transformers/models/auto/auto_factory.py b/src/transformers/models/auto/auto_factory.py index 5ef09f8eb443..9f9266aadde6 100644 --- a/src/transformers/models/auto/auto_factory.py +++ b/src/transformers/models/auto/auto_factory.py @@ -323,7 +323,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike[str], if kwargs.get("dtype") == "auto": _ = kwargs.pop("dtype") # to not overwrite the quantization_config if config has a quantization_config - if kwargs.get("quantization_config") is not None: + if "quantization_config" in kwargs: _ = kwargs.pop("quantization_config") 
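The switch from `kwargs.get("quantization_config") is not None` to `"quantization_config" in kwargs` here (and in the matching restore just below) matters when a caller passes `quantization_config=None` explicitly: the membership test still sees the key, so it can be popped before `AutoConfig.from_pretrained` runs and put back afterwards. A minimal standalone illustration of the difference:

```python
# Caller explicitly disables quantization by passing None.
kwargs = {"quantization_config": None}

# Value check: an explicit None looks the same as "not passed at all".
print(kwargs.get("quantization_config") is not None)  # False

# Membership check: the key is detected even when its value is None,
# so it can be stripped before the config is loaded and restored later.
print("quantization_config" in kwargs)  # True
```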
config, kwargs = AutoConfig.from_pretrained( @@ -340,7 +340,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike[str], kwargs["torch_dtype"] = "auto" if kwargs_orig.get("dtype", None) == "auto": kwargs["dtype"] = "auto" - if kwargs_orig.get("quantization_config", None) is not None: + if "quantization_config" in kwargs_orig: kwargs["quantization_config"] = kwargs_orig["quantization_config"] has_remote_code = hasattr(config, "auto_map") and cls.__name__ in config.auto_map diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py index c624f49083d2..98447b6d1724 100644 --- a/src/transformers/models/auto/image_processing_auto.py +++ b/src/transformers/models/auto/image_processing_auto.py @@ -583,8 +583,8 @@ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs): feature_extractor_auto_map = config_dict["auto_map"]["AutoFeatureExtractor"] image_processor_auto_map = feature_extractor_auto_map.replace("FeatureExtractor", "ImageProcessor") - # If not in image processor config, try the model config - if image_processor_type is None and image_processor_auto_map is None: + # If not in image processor config, try the model config (override image_processor_auto_map if trust_remote_code is False) + if image_processor_type is None and (image_processor_auto_map is None or trust_remote_code is False): if not isinstance(config, PreTrainedConfig): config = AutoConfig.from_pretrained( pretrained_model_name_or_path, diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index a541f13499b7..0ddf436c42d2 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -704,6 +704,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("ministral", "MinistralForCausalLM"), ("ministral3", "Ministral3ForCausalLM"), ("mistral", "MistralForCausalLM"), + ("mistral4", "Mistral4ForCausalLM"), ("mixtral", "MixtralForCausalLM"), ("mllama", "MllamaForCausalLM"), ("modernbert-decoder", "ModernBertDecoderForCausalLM"), @@ -1217,6 +1218,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): [ ("cohere_asr", "CohereAsrForConditionalGeneration"), ("dia", "DiaForConditionalGeneration"), + ("glmasr", "GlmAsrForConditionalGeneration"), ("granite_speech", "GraniteSpeechForConditionalGeneration"), ("kyutai_speech_to_text", "KyutaiSpeechToTextForConditionalGeneration"), ("moonshine", "MoonshineForConditionalGeneration"), @@ -1329,7 +1331,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("qwen2_moe", "Qwen2MoeForSequenceClassification"), ("qwen3", "Qwen3ForSequenceClassification"), ("qwen3_5", "Qwen3_5ForSequenceClassification"), - ("qwen3_5_text", "Qwen3_5ForSequenceClassification"), + ("qwen3_5_text", "Qwen3_5TextForSequenceClassification"), ("qwen3_moe", "Qwen3MoeForSequenceClassification"), ("qwen3_next", "Qwen3NextForSequenceClassification"), ("reformer", "ReformerForSequenceClassification"), @@ -1688,6 +1690,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): # Model for Text-To-Waveform mapping ("bark", "BarkModel"), ("csm", "CsmForConditionalGeneration"), + ("dia", "DiaForConditionalGeneration"), ("fastspeech2_conformer_with_hifigan", "FastSpeech2ConformerWithHifiGan"), ("higgs_audio_v2", "HiggsAudioV2ForConditionalGeneration"), ("musicgen", "MusicgenForConditionalGeneration"), diff --git a/src/transformers/models/auto/tokenization_auto.py 
b/src/transformers/models/auto/tokenization_auto.py index bb0c13f7dbcc..169691e45a1c 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -172,6 +172,7 @@ ("led", "LEDTokenizer" if is_tokenizers_available() else None), ("lighton_ocr", "Qwen2TokenizerFast" if is_tokenizers_available() else None), ("lilt", "RobertaTokenizer" if is_tokenizers_available() else None), + ("llama", "LlamaTokenizer" if is_tokenizers_available() else None), ("longformer", "RobertaTokenizer" if is_tokenizers_available() else None), ("luke", "LukeTokenizer"), ("lxmert", "LxmertTokenizer" if is_tokenizers_available() else None), @@ -822,6 +823,11 @@ def from_pretrained( model_type = config_class_to_model_type(type(config).__name__) or getattr(config, "model_type", None) if model_type is not None: + if model_type == "voxtral" and not is_mistral_common_available(): + raise ImportError( + "The Voxtral tokenizer requires the 'mistral-common' package. " + "Use `pip install mistral-common` to install the package." + ) tokenizer_class = TOKENIZER_MAPPING.get(type(config), TokenizersBackend) if tokenizer_class is not None: return tokenizer_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) diff --git a/src/transformers/models/auto/video_processing_auto.py b/src/transformers/models/auto/video_processing_auto.py index ae5b6f8b9ed3..49db6d3efc9a 100644 --- a/src/transformers/models/auto/video_processing_auto.py +++ b/src/transformers/models/auto/video_processing_auto.py @@ -80,6 +80,8 @@ def video_processor_class_from_name(class_name: str): for module_name, extractor in VIDEO_PROCESSOR_MAPPING_NAMES.items(): + if extractor is None: + continue if class_name == extractor: module_name = model_type_to_module_name(module_name) diff --git a/src/transformers/models/aya_vision/modeling_aya_vision.py b/src/transformers/models/aya_vision/modeling_aya_vision.py index 38f434d405a5..f74f1a85677a 100644 --- a/src/transformers/models/aya_vision/modeling_aya_vision.py +++ b/src/transformers/models/aya_vision/modeling_aya_vision.py @@ -227,9 +227,9 @@ def get_placeholder_mask( n_image_tokens = special_image_mask.sum() n_image_features = image_features.shape[0] * image_features.shape[1] - special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + special_image_mask = special_image_mask.unsqueeze(-1).to(inputs_embeds.device) torch_compilable_check( - inputs_embeds[special_image_mask].numel() == image_features.numel(), + n_image_tokens * inputs_embeds.shape[-1] == image_features.numel(), f"Image features and image tokens do not match, tokens: {n_image_tokens}, features: {n_image_features}", ) return special_image_mask diff --git a/src/transformers/models/aya_vision/processing_aya_vision.py b/src/transformers/models/aya_vision/processing_aya_vision.py index 90188519aba7..83876f2fc46c 100644 --- a/src/transformers/models/aya_vision/processing_aya_vision.py +++ b/src/transformers/models/aya_vision/processing_aya_vision.py @@ -18,9 +18,11 @@ from ...processing_utils import MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack from ...tokenization_utils_base import PreTokenizedInput, TextInput from ...utils import auto_docstring +from ..got_ocr2.image_processing_got_ocr2 import GotOcr2ImageProcessorKwargs class AyaVisionProcessorKwargs(ProcessingKwargs, total=False): + images_kwargs: GotOcr2ImageProcessorKwargs _defaults = { "text_kwargs": { "padding_side": "left", diff --git 
a/src/transformers/models/bamba/modeling_bamba.py b/src/transformers/models/bamba/modeling_bamba.py index 90129fc998b1..3782da89ce24 100644 --- a/src/transformers/models/bamba/modeling_bamba.py +++ b/src/transformers/models/bamba/modeling_bamba.py @@ -132,7 +132,7 @@ def forward(self, x, position_ids): device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" with maybe_autocast(device_type=device_type, enabled=False): # Force float32 - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + freqs = (inv_freq_expanded.float() * position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling sin = emb.sin() * self.attention_scaling diff --git a/src/transformers/models/beit/image_processing_beit.py b/src/transformers/models/beit/image_processing_beit.py index 53053f644539..a95c8e9752be 100644 --- a/src/transformers/models/beit/image_processing_beit.py +++ b/src/transformers/models/beit/image_processing_beit.py @@ -127,9 +127,10 @@ def reduce_label(self, labels: list["torch.Tensor"]) -> list["torch.Tensor"]: """Reduce label values by 1, replacing 0 with 255.""" for idx in range(len(labels)): label = labels[idx] - label = torch.where(label == 0, torch.tensor(255, dtype=label.dtype, device=label.device), label) - label = label - 1 - label = torch.where(label == 254, torch.tensor(255, dtype=label.dtype, device=label.device), label) + ignore_mask = (label == 0) | (label == 255) + label = label.clone() + label[ignore_mask] = 255 + label[~ignore_mask] = label[~ignore_mask] - 1 labels[idx] = label return labels diff --git a/src/transformers/models/beit/image_processing_pil_beit.py b/src/transformers/models/beit/image_processing_pil_beit.py index e3ccf12e909b..ff78dac96c40 100644 --- a/src/transformers/models/beit/image_processing_pil_beit.py +++ b/src/transformers/models/beit/image_processing_pil_beit.py @@ -120,10 +120,10 @@ def _preprocess_image_like_inputs( def reduce_label(self, image: np.ndarray) -> np.ndarray: """Reduce label values by 1, replacing 0 with 255.""" - # Avoid using underflow conversion - image[image == 0] = 255 - image = image - 1 - image[image == 254] = 255 + image = image.copy() + ignore_mask = (image == 0) | (image == 255) + image[ignore_mask] = 255 + image[~ignore_mask] = image[~ignore_mask] - 1 return image def _preprocess( diff --git a/src/transformers/models/bitnet/modeling_bitnet.py b/src/transformers/models/bitnet/modeling_bitnet.py index 14c1581b250f..6f6f969e9b6b 100644 --- a/src/transformers/models/bitnet/modeling_bitnet.py +++ b/src/transformers/models/bitnet/modeling_bitnet.py @@ -318,7 +318,7 @@ def forward(self, x, position_ids): device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" with maybe_autocast(device_type=device_type, enabled=False): # Force float32 - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + freqs = (inv_freq_expanded.float() * position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling sin = emb.sin() * self.attention_scaling diff --git a/src/transformers/models/blip_2/modeling_blip_2.py b/src/transformers/models/blip_2/modeling_blip_2.py index c5c022d39066..12cc101356d7 100644 --- a/src/transformers/models/blip_2/modeling_blip_2.py +++ b/src/transformers/models/blip_2/modeling_blip_2.py @@ -1240,7 +1240,7 @@ def get_placeholder_mask(self, input_ids: 
torch.LongTensor, inputs_embeds: torch else: special_image_mask = input_ids == self.config.image_token_id - special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + special_image_mask = special_image_mask.unsqueeze(-1).to(inputs_embeds.device) return special_image_mask @can_return_tuple @@ -1686,7 +1686,7 @@ def get_placeholder_mask(self, input_ids: torch.LongTensor, inputs_embeds: torch else: special_image_mask = input_ids == self.config.image_token_id - special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + special_image_mask = special_image_mask.unsqueeze(-1).to(inputs_embeds.device) return special_image_mask @can_return_tuple @@ -1913,7 +1913,7 @@ def generate( else: special_image_mask = input_ids == self.config.image_token_id - special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + special_image_mask = special_image_mask.unsqueeze(-1).to(inputs_embeds.device) language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype) inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs) @@ -2054,13 +2054,7 @@ def forward( if use_image_text_matching_head: query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1) - if self.config.image_token_index is not None: - input_ids = input_ids[:, self.config.num_query_tokens :] - else: - query_attention_mask = torch.ones( - query_tokens.size()[:-1], dtype=torch.long, device=query_tokens.device - ) - attention_mask = torch.cat([query_attention_mask, attention_mask], dim=1) + input_ids = input_ids[:, self.config.num_query_tokens :] query_embeds = self.embeddings( input_ids=input_ids, @@ -2092,9 +2086,8 @@ def forward( image_embeds = query_outputs[0] if not return_dict else query_outputs.last_hidden_state image_embeds = image_embeds.to(dtype=self.vision_projection.weight.dtype) - if self.config.image_token_index is not None: - input_ids = input_ids[:, self.config.num_query_tokens :] - attention_mask = attention_mask[:, self.config.num_query_tokens :] + input_ids = input_ids[:, self.config.num_query_tokens :] + attention_mask = attention_mask[:, self.config.num_query_tokens :] query_embeds = self.embeddings( input_ids=input_ids, diff --git a/src/transformers/models/blip_2/processing_blip_2.py b/src/transformers/models/blip_2/processing_blip_2.py index e339854a6736..c8feeed2b822 100644 --- a/src/transformers/models/blip_2/processing_blip_2.py +++ b/src/transformers/models/blip_2/processing_blip_2.py @@ -77,8 +77,16 @@ def __call__( return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None) max_length = output_kwargs["text_kwargs"].pop("max_length", None) if max_length is not None: - output_kwargs["text_kwargs"]["max_length"] = max_length - self.num_query_tokens - + num_query_tokens = self.num_query_tokens + if num_query_tokens is None: + logger.warning( + "Blip2Processor.num_query_tokens is None. Treating it as 0 for max_length calculations. " + "Consider updating the processor to set num_query_tokens explicitly." 
+ ) + num_query_tokens = 0 + adjusted_max_length = max_length - num_query_tokens + if adjusted_max_length > 0: + output_kwargs["text_kwargs"]["max_length"] = adjusted_max_length encoding = BatchFeature(tensor_type=return_tensors) if text is not None: if isinstance(text, str): diff --git a/src/transformers/models/bridgetower/modeling_bridgetower.py b/src/transformers/models/bridgetower/modeling_bridgetower.py index 225289d8367e..d5d1b1f03a7e 100644 --- a/src/transformers/models/bridgetower/modeling_bridgetower.py +++ b/src/transformers/models/bridgetower/modeling_bridgetower.py @@ -820,7 +820,7 @@ def forward( if token_type_ids is None: if hasattr(self, "token_type_ids"): # NOTE: We assume either pos ids to have bsz == 1 (broadcastable) or bsz == effective bsz (input_shape[0]) - buffered_token_type_ids = self.token_type_ids.expand(position_ids.shape[0], -1) + buffered_token_type_ids = self.token_type_ids.to(position_ids.device).expand(position_ids.shape[0], -1) buffered_token_type_ids = torch.gather(buffered_token_type_ids, dim=1, index=position_ids) token_type_ids = buffered_token_type_ids.expand(batch_size, seq_length) else: diff --git a/src/transformers/models/bridgetower/processing_bridgetower.py b/src/transformers/models/bridgetower/processing_bridgetower.py index aa0ea7b4c4da..9424362e519c 100644 --- a/src/transformers/models/bridgetower/processing_bridgetower.py +++ b/src/transformers/models/bridgetower/processing_bridgetower.py @@ -17,9 +17,11 @@ from ...processing_utils import ProcessingKwargs, ProcessorMixin from ...utils import auto_docstring +from .image_processing_bridgetower import BridgeTowerImageProcessorKwargs class BridgeTowerProcessorKwargs(ProcessingKwargs, total=False): + images_kwargs: BridgeTowerImageProcessorKwargs _defaults = { "text_kwargs": { "add_special_tokens": True, diff --git a/src/transformers/models/camembert/modeling_camembert.py b/src/transformers/models/camembert/modeling_camembert.py index 9d10a8aeaef1..c47245a0ae2b 100644 --- a/src/transformers/models/camembert/modeling_camembert.py +++ b/src/transformers/models/camembert/modeling_camembert.py @@ -106,7 +106,7 @@ def forward( if token_type_ids is None: if hasattr(self, "token_type_ids"): # NOTE: We assume either pos ids to have bsz == 1 (broadcastable) or bsz == effective bsz (input_shape[0]) - buffered_token_type_ids = self.token_type_ids.expand(position_ids.shape[0], -1) + buffered_token_type_ids = self.token_type_ids.to(position_ids.device).expand(position_ids.shape[0], -1) buffered_token_type_ids = torch.gather(buffered_token_type_ids, dim=1, index=position_ids) token_type_ids = buffered_token_type_ids.expand(batch_size, seq_length) else: diff --git a/src/transformers/models/chameleon/modeling_chameleon.py b/src/transformers/models/chameleon/modeling_chameleon.py index af69779959e4..fe3243f01dc8 100644 --- a/src/transformers/models/chameleon/modeling_chameleon.py +++ b/src/transformers/models/chameleon/modeling_chameleon.py @@ -143,7 +143,7 @@ def forward(self, x, position_ids): device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" with maybe_autocast(device_type=device_type, enabled=False): # Force float32 - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + freqs = (inv_freq_expanded.float() * position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling sin = emb.sin() * self.attention_scaling @@ -911,9 +911,9 @@ def get_placeholder_mask( 
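The `get_placeholder_mask` hunks in this diff (Aria, AyaVision, and BLIP-2 above, Chameleon just below, the DeepseekVL variants further down, and others) drop the explicit `expand_as(inputs_embeds)` because `masked_scatter` accepts any mask that is broadcastable to the target, and they count tokens instead of boolean-indexing the embeddings so the check stays compile-friendly. A small self-contained sketch of that behavior in plain PyTorch, separate from the model code:

```python
import torch

hidden_size = 4
inputs_embeds = torch.zeros(1, 5, hidden_size)
special_image_mask = torch.tensor([[False, True, True, False, False]])
image_features = torch.arange(2 * hidden_size, dtype=torch.float).reshape(2, hidden_size)

# A (1, 5, 1) boolean mask broadcasts over the hidden dimension inside
# masked_scatter, so expand_as(inputs_embeds) is not required.
mask = special_image_mask.unsqueeze(-1)
scattered = inputs_embeds.masked_scatter(mask, image_features)

# Counting masked tokens avoids a data-dependent boolean index on the embeddings,
# which is what makes the check usable under torch.compile.
n_image_tokens = special_image_mask.sum()
assert n_image_tokens * inputs_embeds.shape[-1] == image_features.numel()
print(scattered[0, 1])  # tensor([0., 1., 2., 3.]) -- first image token row filled
```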
n_image_tokens = special_image_mask.sum() n_image_features = image_features.shape[0] * image_features.shape[1] - special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + special_image_mask = special_image_mask.unsqueeze(-1).to(inputs_embeds.device) torch_compilable_check( - inputs_embeds[special_image_mask].numel() == image_features.numel(), + n_image_tokens * inputs_embeds.shape[-1] == image_features.numel(), f"Image features and image tokens do not match, tokens: {n_image_tokens}, features: {n_image_features}", ) return special_image_mask diff --git a/src/transformers/models/chinese_clip/modeling_chinese_clip.py b/src/transformers/models/chinese_clip/modeling_chinese_clip.py index 3c2ddef2e7a4..99828afbda36 100644 --- a/src/transformers/models/chinese_clip/modeling_chinese_clip.py +++ b/src/transformers/models/chinese_clip/modeling_chinese_clip.py @@ -517,7 +517,12 @@ class ChineseCLIPPreTrainedModel(PreTrainedModel): config: ChineseCLIPConfig base_model_prefix = "chinese_clip" input_modalities = ("image", "text") - _no_split_modules = ["ChineseCLIPVisionEmbeddings", "ChineseCLIPTextEmbeddings", "ChineseCLIPVisionAttention"] + _no_split_modules = [ + "ChineseCLIPVisionEmbeddings", + "ChineseCLIPTextEmbeddings", + "ChineseCLIPTextLayer", + "ChineseCLIPVisionAttention", + ] supports_gradient_checkpointing = True _supports_sdpa = True @@ -653,7 +658,7 @@ def __init__(self, config: ChineseCLIPVisionConfig): embed_dim = config.hidden_size self.embeddings = ChineseCLIPVisionEmbeddings(config) - self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + self.pre_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) self.encoder = ChineseCLIPVisionEncoder(config) self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) self.post_init() @@ -690,7 +695,7 @@ def forward( >>> pooled_output = outputs.pooler_output # pooled CLS states ```""" hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) - hidden_states = self.pre_layrnorm(hidden_states) + hidden_states = self.pre_layernorm(hidden_states) encoder_outputs: BaseModelOutput = self.encoder( inputs_embeds=hidden_states, diff --git a/src/transformers/models/chinese_clip/modular_chinese_clip.py b/src/transformers/models/chinese_clip/modular_chinese_clip.py index 280cb7bd54ae..bb6b05f9ac92 100644 --- a/src/transformers/models/chinese_clip/modular_chinese_clip.py +++ b/src/transformers/models/chinese_clip/modular_chinese_clip.py @@ -197,7 +197,12 @@ class ChineseCLIPTextPooler(BertPooler): @auto_docstring class ChineseCLIPPreTrainedModel(CLIPPreTrainedModel): - _no_split_modules = ["ChineseCLIPVisionEmbeddings", "ChineseCLIPTextEmbeddings", "ChineseCLIPVisionAttention"] + _no_split_modules = [ + "ChineseCLIPVisionEmbeddings", + "ChineseCLIPTextEmbeddings", + "ChineseCLIPTextLayer", + "ChineseCLIPVisionAttention", + ] _can_record_outputs = { "hidden_states": ChineseCLIPVisionLayer, "attentions": ChineseCLIPVisionAttention, diff --git a/src/transformers/models/chmv2/image_processing_chmv2.py b/src/transformers/models/chmv2/image_processing_chmv2.py index 3bb82b2dea53..067ba5898734 100644 --- a/src/transformers/models/chmv2/image_processing_chmv2.py +++ b/src/transformers/models/chmv2/image_processing_chmv2.py @@ -182,9 +182,10 @@ def reduce_label(self, labels: list["torch.Tensor"]) -> list["torch.Tensor"]: """Reduce label values by 1, replacing 0 with 255.""" for idx in range(len(labels)): label = labels[idx] - label = 
torch.where(label == 0, torch.tensor(255, dtype=label.dtype, device=label.device), label) - label = label - 1 - label = torch.where(label == 254, torch.tensor(255, dtype=label.dtype, device=label.device), label) + ignore_mask = (label == 0) | (label == 255) + label = label.clone() + label[ignore_mask] = 255 + label[~ignore_mask] = label[~ignore_mask] - 1 labels[idx] = label return labels diff --git a/src/transformers/models/chmv2/modular_chmv2.py b/src/transformers/models/chmv2/modular_chmv2.py index f61c6687a351..5f44654876c6 100644 --- a/src/transformers/models/chmv2/modular_chmv2.py +++ b/src/transformers/models/chmv2/modular_chmv2.py @@ -150,6 +150,17 @@ class CHMv2ImageProcessor(DPTImageProcessor): image_std = [0.213, 0.156, 0.143] valid_kwargs = CHMv2ImageProcessorKwargs + def reduce_label(self, labels: list["torch.Tensor"]) -> list["torch.Tensor"]: + """Reduce label values by 1, replacing 0 with 255.""" + for idx in range(len(labels)): + label = labels[idx] + ignore_mask = (label == 0) | (label == 255) + label = label.clone() + label[ignore_mask] = 255 + label[~ignore_mask] = label[~ignore_mask] - 1 + labels[idx] = label + return labels + def post_process_depth_estimation( self, outputs: "DepthEstimatorOutput", diff --git a/src/transformers/models/clap/modeling_clap.py b/src/transformers/models/clap/modeling_clap.py index 96c540a3424f..cf766d53a261 100644 --- a/src/transformers/models/clap/modeling_clap.py +++ b/src/transformers/models/clap/modeling_clap.py @@ -990,7 +990,7 @@ def forward( if token_type_ids is None: if hasattr(self, "token_type_ids"): # NOTE: We assume either pos ids to have bsz == 1 (broadcastable) or bsz == effective bsz (input_shape[0]) - buffered_token_type_ids = self.token_type_ids.expand(position_ids.shape[0], -1) + buffered_token_type_ids = self.token_type_ids.to(position_ids.device).expand(position_ids.shape[0], -1) buffered_token_type_ids = torch.gather(buffered_token_type_ids, dim=1, index=position_ids) token_type_ids = buffered_token_type_ids.expand(batch_size, seq_length) else: diff --git a/src/transformers/models/clip/modeling_clip.py b/src/transformers/models/clip/modeling_clip.py index 2bca67e59a21..daeca0a502b1 100644 --- a/src/transformers/models/clip/modeling_clip.py +++ b/src/transformers/models/clip/modeling_clip.py @@ -481,15 +481,18 @@ def forward( **kwargs: Unpack[TransformersKwargs], ) -> BaseModelOutput: hidden_states = inputs_embeds + all_hidden_states = [hidden_states] if self.config.output_hidden_states else None for encoder_layer in self.layers: hidden_states = encoder_layer( hidden_states, attention_mask, **kwargs, ) + if all_hidden_states: + all_hidden_states.append(hidden_states) return BaseModelOutput( - last_hidden_state=hidden_states, + last_hidden_state=hidden_states, hidden_states=tuple(all_hidden_states) if all_hidden_states else None ) @@ -609,7 +612,7 @@ def __init__(self, config: CLIPVisionConfig): embed_dim = config.hidden_size self.embeddings = CLIPVisionEmbeddings(config) - self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + self.pre_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) self.encoder = CLIPEncoder(config) self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) self.post_init() @@ -646,7 +649,7 @@ def forward( >>> pooled_output = outputs.pooler_output # pooled CLS states ```""" hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) - hidden_states = self.pre_layrnorm(hidden_states) + hidden_states = 
self.pre_layernorm(hidden_states) encoder_outputs: BaseModelOutput = self.encoder( inputs_embeds=hidden_states, diff --git a/src/transformers/models/clipseg/modeling_clipseg.py b/src/transformers/models/clipseg/modeling_clipseg.py index cf17b44b00c2..a462bdc7ef40 100644 --- a/src/transformers/models/clipseg/modeling_clipseg.py +++ b/src/transformers/models/clipseg/modeling_clipseg.py @@ -708,7 +708,7 @@ def __init__(self, config: CLIPSegVisionConfig): embed_dim = config.hidden_size self.embeddings = CLIPSegVisionEmbeddings(config) - self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + self.pre_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) self.encoder = CLIPSegEncoder(config) self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) self.post_init() @@ -745,7 +745,7 @@ def forward( >>> pooled_output = outputs.pooler_output # pooled CLS states ```""" hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) - hidden_states = self.pre_layrnorm(hidden_states) + hidden_states = self.pre_layernorm(hidden_states) encoder_outputs: BaseModelOutput = self.encoder( inputs_embeds=hidden_states, diff --git a/src/transformers/models/cohere2_vision/modeling_cohere2_vision.py b/src/transformers/models/cohere2_vision/modeling_cohere2_vision.py index 048412b383e7..c1915dfcea46 100644 --- a/src/transformers/models/cohere2_vision/modeling_cohere2_vision.py +++ b/src/transformers/models/cohere2_vision/modeling_cohere2_vision.py @@ -188,9 +188,9 @@ def get_placeholder_mask( n_image_tokens = special_image_mask.sum() n_image_features = image_features.shape[0] * image_features.shape[1] - special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + special_image_mask = special_image_mask.unsqueeze(-1).to(inputs_embeds.device) torch_compilable_check( - inputs_embeds[special_image_mask].numel() == image_features.numel(), + n_image_tokens * inputs_embeds.shape[-1] == image_features.numel(), f"Image features and image tokens do not match, tokens: {n_image_tokens}, features: {n_image_features}", ) return special_image_mask diff --git a/src/transformers/models/cohere2_vision/processing_cohere2_vision.py b/src/transformers/models/cohere2_vision/processing_cohere2_vision.py index 7d76f1187733..bd65e67aa1f9 100644 --- a/src/transformers/models/cohere2_vision/processing_cohere2_vision.py +++ b/src/transformers/models/cohere2_vision/processing_cohere2_vision.py @@ -18,9 +18,11 @@ from ...processing_utils import MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack from ...tokenization_utils_base import PreTokenizedInput, TextInput from ...utils import auto_docstring +from .image_processing_cohere2_vision_fast import Cohere2VisionFastImageProcessorKwargs class Cohere2VisionProcessorKwargs(ProcessingKwargs, total=False): + images_kwargs: Cohere2VisionFastImageProcessorKwargs _defaults = { "text_kwargs": { "padding_side": "left", diff --git a/src/transformers/models/cohere_asr/feature_extraction_cohere_asr.py b/src/transformers/models/cohere_asr/feature_extraction_cohere_asr.py index 1192be10606d..42f4bf3117da 100644 --- a/src/transformers/models/cohere_asr/feature_extraction_cohere_asr.py +++ b/src/transformers/models/cohere_asr/feature_extraction_cohere_asr.py @@ -284,17 +284,17 @@ def __call__( f"Only mono-channel audio is supported for input to {self.__class__.__name__}. " "We will take the mean of the channels to convert to mono." 
) - raw_speech = raw_speech.mean(-1) + raw_speech = raw_speech.mean(1) is_batched_sequence = isinstance(raw_speech, (list, tuple)) if is_batched_sequence: - for speech in raw_speech: + for index, speech in enumerate(raw_speech): if len(speech.shape) > 1: logger.warning( f"Only mono-channel audio is supported for input to {self.__class__.__name__}. " "We will take the mean of the channels to convert to mono." ) - speech = speech.mean(-1) + raw_speech[index] = speech.mean(0) if is_batched_torch or is_batched_sequence: raw_speech = [speech.to(torch.float32) for speech in raw_speech] diff --git a/src/transformers/models/colqwen2/modeling_colqwen2.py b/src/transformers/models/colqwen2/modeling_colqwen2.py index 656ad6c758c5..795e6679bc7a 100644 --- a/src/transformers/models/colqwen2/modeling_colqwen2.py +++ b/src/transformers/models/colqwen2/modeling_colqwen2.py @@ -165,9 +165,7 @@ def forward( if pixel_values is not None: image_embeds = self.vlm.visual(pixel_values, grid_thw=image_grid_thw, return_dict=True).pooler_output - image_mask = ( - (input_ids == self.config.vlm_config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds) - ) + image_mask = (input_ids == self.config.vlm_config.image_token_id).unsqueeze(-1) image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) diff --git a/src/transformers/models/colqwen2/modular_colqwen2.py b/src/transformers/models/colqwen2/modular_colqwen2.py index aa7a3f48ca6e..8394607a08de 100644 --- a/src/transformers/models/colqwen2/modular_colqwen2.py +++ b/src/transformers/models/colqwen2/modular_colqwen2.py @@ -304,9 +304,7 @@ def forward( if pixel_values is not None: image_embeds = self.vlm.visual(pixel_values, grid_thw=image_grid_thw, return_dict=True).pooler_output - image_mask = ( - (input_ids == self.config.vlm_config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds) - ) + image_mask = (input_ids == self.config.vlm_config.image_token_id).unsqueeze(-1) image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) diff --git a/src/transformers/models/colqwen2/processing_colqwen2.py b/src/transformers/models/colqwen2/processing_colqwen2.py index 48af99206afe..89b737bd5009 100644 --- a/src/transformers/models/colqwen2/processing_colqwen2.py +++ b/src/transformers/models/colqwen2/processing_colqwen2.py @@ -29,9 +29,11 @@ if is_torch_available(): import torch +from ..qwen2_vl.image_processing_qwen2_vl import Qwen2VLImageProcessorKwargs class ColQwen2ProcessorKwargs(ProcessingKwargs, total=False): + images_kwargs: Qwen2VLImageProcessorKwargs _defaults = { "text_kwargs": { "padding": "longest", diff --git a/src/transformers/models/conditional_detr/image_processing_conditional_detr.py b/src/transformers/models/conditional_detr/image_processing_conditional_detr.py index 539fe152f606..6dd8c22ff207 100644 --- a/src/transformers/models/conditional_detr/image_processing_conditional_detr.py +++ b/src/transformers/models/conditional_detr/image_processing_conditional_detr.py @@ -440,8 +440,13 @@ def __init__(self, **kwargs: Unpack[ConditionalDetrImageProcessorKwargs]) -> Non kwargs.setdefault("do_pad", kwargs.pop("pad_and_return_pixel_mask", self.do_pad)) size = kwargs.pop("size", None) - max_size = None if size is None else kwargs.pop("max_size", 1333) - size = size if size is not None else {"shortest_edge": 800, "longest_edge": 1333} + max_size = kwargs.pop("max_size", None) + + if size is 
None: + size = {"shortest_edge": 800, "longest_edge": max_size if max_size is not None else 1333} + elif isinstance(size, dict) and max_size is not None and "longest_edge" not in size: + size = {**size, "longest_edge": max_size} + # Convert size dict for backwards compat with max_size parameter kwargs["size"] = get_size_dict(size, max_size=max_size, default_to_square=False) diff --git a/src/transformers/models/conditional_detr/image_processing_pil_conditional_detr.py b/src/transformers/models/conditional_detr/image_processing_pil_conditional_detr.py index 30740114d5f0..3f96def66064 100644 --- a/src/transformers/models/conditional_detr/image_processing_pil_conditional_detr.py +++ b/src/transformers/models/conditional_detr/image_processing_pil_conditional_detr.py @@ -443,13 +443,17 @@ def __init__(self, **kwargs: Unpack[ConditionalDetrImageProcessorKwargs]) -> Non kwargs.setdefault("do_pad", kwargs.pop("pad_and_return_pixel_mask", self.do_pad)) size = kwargs.pop("size", None) - max_size = None if size is None else kwargs.pop("max_size", 1333) - size = size if size is not None else {"shortest_edge": 800, "longest_edge": 1333} + max_size = kwargs.pop("max_size", None) + + if size is None: + size = {"shortest_edge": 800, "longest_edge": max_size if max_size is not None else 1333} + elif isinstance(size, dict) and max_size is not None and "longest_edge" not in size: + size = {**size, "longest_edge": max_size} + # Convert size dict for backwards compat with max_size parameter - if size is not None: - from ...image_processing_utils import get_size_dict + from ...image_processing_utils import get_size_dict - kwargs["size"] = get_size_dict(size, max_size=max_size, default_to_square=False) + kwargs["size"] = get_size_dict(size, max_size=max_size, default_to_square=False) # Backwards compatibility do_convert_annotations = kwargs.get("do_convert_annotations") diff --git a/src/transformers/models/csm/modeling_csm.py b/src/transformers/models/csm/modeling_csm.py index eb78dca8faf5..d6a38358f9d7 100644 --- a/src/transformers/models/csm/modeling_csm.py +++ b/src/transformers/models/csm/modeling_csm.py @@ -174,7 +174,7 @@ def forward(self, x, position_ids): device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" with maybe_autocast(device_type=device_type, enabled=False): # Force float32 - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + freqs = (inv_freq_expanded.float() * position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling sin = emb.sin() * self.attention_scaling diff --git a/src/transformers/models/cwm/modeling_cwm.py b/src/transformers/models/cwm/modeling_cwm.py index 3e0eb0504be0..afffc6e41449 100644 --- a/src/transformers/models/cwm/modeling_cwm.py +++ b/src/transformers/models/cwm/modeling_cwm.py @@ -99,7 +99,7 @@ def forward(self, x, position_ids): device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" with maybe_autocast(device_type=device_type, enabled=False): # Force float32 - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + freqs = (inv_freq_expanded.float() * position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling sin = emb.sin() * self.attention_scaling diff --git a/src/transformers/models/dac/convert_dac_checkpoint.py b/src/transformers/models/dac/convert_dac_checkpoint.py index 
b3360fc1706d..acfa4166414e 100644 --- a/src/transformers/models/dac/convert_dac_checkpoint.py +++ b/src/transformers/models/dac/convert_dac_checkpoint.py @@ -17,7 +17,6 @@ import numpy as np import torch -import torch.nn as nn from transformers import ( DacConfig, @@ -186,50 +185,21 @@ def recursively_load_weights(orig_dict, hf_model, model_name): logger.warning(f"Unused weights: {unused_weights}") -def apply_weight_norm(model): - weight_norm = nn.utils.weight_norm - - for layer in model.quantizer.quantizers: - weight_norm(layer.in_proj) - weight_norm(layer.out_proj) - - weight_norm(model.encoder.conv1) - weight_norm(model.encoder.conv2) - - for layer in model.encoder.block: - weight_norm(layer.conv1) - weight_norm(layer.res_unit1.conv1) - weight_norm(layer.res_unit1.conv2) - weight_norm(layer.res_unit2.conv1) - weight_norm(layer.res_unit2.conv2) - weight_norm(layer.res_unit3.conv1) - weight_norm(layer.res_unit3.conv2) - - weight_norm(model.decoder.conv1) - weight_norm(model.decoder.conv2) - - for layer in model.decoder.block: - weight_norm(layer.conv_t1) - weight_norm(layer.res_unit1.conv1) - weight_norm(layer.res_unit1.conv2) - weight_norm(layer.res_unit2.conv1) - weight_norm(layer.res_unit2.conv2) - weight_norm(layer.res_unit3.conv1) - weight_norm(layer.res_unit3.conv2) - - @torch.no_grad() def convert_checkpoint( model_name, checkpoint_path, pytorch_dump_folder_path, - sample_rate=16000, repo_id=None, + legacy_weight_norm=True, ): - model_dict = torch.load(checkpoint_path, "cpu", weights_only=True) + # NOTE: Models on Hub (https://huggingface.co/descript/models) did conversion on CPU. + # However, for equivalent weights after removing weight norm, conversion should be done on GPU. + # torch_device = "cuda" + torch_device = "cpu" + model_dict = torch.load(checkpoint_path, torch_device, weights_only=True) config = DacConfig() - metadata = model_dict["metadata"]["kwargs"] config.encoder_hidden_size = metadata["encoder_dim"] config.downsampling_ratios = metadata["encoder_rates"] @@ -239,18 +209,20 @@ def convert_checkpoint( config.decoder_hidden_size = metadata["decoder_dim"] config.upsampling_ratios = metadata["decoder_rates"] config.quantizer_dropout = float(metadata["quantizer_dropout"]) - config.sampling_rate = sample_rate + config.sampling_rate = int(metadata["sample_rate"]) config.hop_length = int(np.prod(config.downsampling_ratios)) - model = DacModel(config) + model = DacModel(config).to(torch_device) feature_extractor = DacFeatureExtractor() - feature_extractor.sampling_rate = sample_rate + feature_extractor.sampling_rate = config.sampling_rate + feature_extractor.hop_length = config.hop_length original_checkpoint = model_dict["state_dict"] - apply_weight_norm(model) + # original model uses old weight norm function + model.apply_weight_norm(legacy=legacy_weight_norm) recursively_load_weights(original_checkpoint, model, model_name) - model.remove_weight_norm() + model.remove_weight_norm(legacy=legacy_weight_norm) model.save_pretrained(pytorch_dump_folder_path) @@ -275,9 +247,14 @@ def convert_checkpoint( parser.add_argument( "--push_to_hub", default=None, type=str, help="Where to upload the converted model on the Hugging Face hub." 
) - parser.add_argument("--sample_rate", default=None, type=str, help="Sample rate used by DacFeatureExtractor") + parser.add_argument( + "--legacy_weight_norm", + default=True, + type=bool, + help="Whether legacy weight normalization was used by original model.", + ) args = parser.parse_args() convert_checkpoint( - args.model, args.checkpoint_path, args.pytorch_dump_folder_path, args.sample_rate, args.push_to_hub + args.model, args.checkpoint_path, args.pytorch_dump_folder_path, args.push_to_hub, args.legacy_weight_norm ) diff --git a/src/transformers/models/dac/modeling_dac.py b/src/transformers/models/dac/modeling_dac.py index 6ac46f78a4a6..d8f3f6ae5607 100644 --- a/src/transformers/models/dac/modeling_dac.py +++ b/src/transformers/models/dac/modeling_dac.py @@ -85,6 +85,9 @@ class DacDecoderOutput(ModelOutput): class Snake1d(nn.Module): """ A 1-dimensional Snake activation function module. + + Original version from DAC used JIT compilation: https://github.com/descriptinc/descript-audio-codec/blob/main/dac/nn/layers.py#L18-L33 + This leads to slight differences in output. """ def __init__(self, hidden_dim): @@ -490,9 +493,10 @@ def _init_weights(self, module): elif isinstance(module, nn.Embedding): init.normal_(module.weight, mean=0.0, std=0.02) - def apply_weight_norm(self): + def apply_weight_norm(self, legacy=True): + # original version of DAC uses legacy weight norm weight_norm = nn.utils.weight_norm - if hasattr(nn.utils.parametrizations, "weight_norm"): + if hasattr(nn.utils.parametrizations, "weight_norm") and not legacy: weight_norm = nn.utils.parametrizations.weight_norm for layer in self.quantizer.quantizers: @@ -523,34 +527,38 @@ def apply_weight_norm(self): weight_norm(layer.res_unit3.conv1) weight_norm(layer.res_unit3.conv2) - def remove_weight_norm(self): + def remove_weight_norm(self, legacy=True): + remove_weight_norm = nn.utils.remove_weight_norm + if hasattr(nn.utils.parametrizations, "weight_norm") and not legacy: + remove_weight_norm = torch.nn.utils.parametrize.remove_parametrizations + for layer in self.quantizer.quantizers: - nn.utils.remove_weight_norm(layer.in_proj) - nn.utils.remove_weight_norm(layer.out_proj) + remove_weight_norm(layer.in_proj, "weight") + remove_weight_norm(layer.out_proj, "weight") - nn.utils.remove_weight_norm(self.encoder.conv1) - nn.utils.remove_weight_norm(self.encoder.conv2) + remove_weight_norm(self.encoder.conv1, "weight") + remove_weight_norm(self.encoder.conv2, "weight") for layer in self.encoder.block: - nn.utils.remove_weight_norm(layer.conv1) - nn.utils.remove_weight_norm(layer.res_unit1.conv1) - nn.utils.remove_weight_norm(layer.res_unit1.conv2) - nn.utils.remove_weight_norm(layer.res_unit2.conv1) - nn.utils.remove_weight_norm(layer.res_unit2.conv2) - nn.utils.remove_weight_norm(layer.res_unit3.conv1) - nn.utils.remove_weight_norm(layer.res_unit3.conv2) + remove_weight_norm(layer.conv1, "weight") + remove_weight_norm(layer.res_unit1.conv1, "weight") + remove_weight_norm(layer.res_unit1.conv2, "weight") + remove_weight_norm(layer.res_unit2.conv1, "weight") + remove_weight_norm(layer.res_unit2.conv2, "weight") + remove_weight_norm(layer.res_unit3.conv1, "weight") + remove_weight_norm(layer.res_unit3.conv2, "weight") - nn.utils.remove_weight_norm(self.decoder.conv1) - nn.utils.remove_weight_norm(self.decoder.conv2) + remove_weight_norm(self.decoder.conv1, "weight") + remove_weight_norm(self.decoder.conv2, "weight") for layer in self.decoder.block: - nn.utils.remove_weight_norm(layer.conv_t1) - 
nn.utils.remove_weight_norm(layer.res_unit1.conv1) - nn.utils.remove_weight_norm(layer.res_unit1.conv2) - nn.utils.remove_weight_norm(layer.res_unit2.conv1) - nn.utils.remove_weight_norm(layer.res_unit2.conv2) - nn.utils.remove_weight_norm(layer.res_unit3.conv1) - nn.utils.remove_weight_norm(layer.res_unit3.conv2) + remove_weight_norm(layer.conv_t1, "weight") + remove_weight_norm(layer.res_unit1.conv1, "weight") + remove_weight_norm(layer.res_unit1.conv2, "weight") + remove_weight_norm(layer.res_unit2.conv1, "weight") + remove_weight_norm(layer.res_unit2.conv2, "weight") + remove_weight_norm(layer.res_unit3.conv1, "weight") + remove_weight_norm(layer.res_unit3.conv2, "weight") @auto_docstring( diff --git a/src/transformers/models/data2vec/modeling_data2vec_audio.py b/src/transformers/models/data2vec/modeling_data2vec_audio.py index bb3af4ecf25a..dbdaec7efdbc 100755 --- a/src/transformers/models/data2vec/modeling_data2vec_audio.py +++ b/src/transformers/models/data2vec/modeling_data2vec_audio.py @@ -45,6 +45,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, is_peft_available +from ...utils.generic import _conv_out_length from .configuration_data2vec_audio import Data2VecAudioConfig @@ -510,11 +511,6 @@ def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor | int add_adapter = self.config.add_adapter if add_adapter is None else add_adapter - def _conv_out_length(input_length, kernel_size, stride): - # 1D convolutional layer output length formula taken - # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html - return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1 - for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride): input_lengths = _conv_out_length(input_lengths, kernel_size, stride) @@ -1220,11 +1216,6 @@ def _get_tdnn_output_lengths(self, input_lengths: torch.LongTensor | int): Computes the output length of the TDNN layers """ - def _conv_out_length(input_length, kernel_size, stride): - # 1D convolutional layer output length formula taken - # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html - return (input_length - kernel_size) // stride + 1 - for kernel_size in self.config.tdnn_kernel: input_lengths = _conv_out_length(input_lengths, kernel_size, 1) diff --git a/src/transformers/models/data2vec/modeling_data2vec_text.py b/src/transformers/models/data2vec/modeling_data2vec_text.py index 512431cb3b0a..47f9866e9f4f 100644 --- a/src/transformers/models/data2vec/modeling_data2vec_text.py +++ b/src/transformers/models/data2vec/modeling_data2vec_text.py @@ -105,7 +105,7 @@ def forward( if token_type_ids is None: if hasattr(self, "token_type_ids"): # NOTE: We assume either pos ids to have bsz == 1 (broadcastable) or bsz == effective bsz (input_shape[0]) - buffered_token_type_ids = self.token_type_ids.expand(position_ids.shape[0], -1) + buffered_token_type_ids = self.token_type_ids.to(position_ids.device).expand(position_ids.shape[0], -1) buffered_token_type_ids = torch.gather(buffered_token_type_ids, dim=1, index=position_ids) token_type_ids = buffered_token_type_ids.expand(batch_size, seq_length) else: diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py index 58735fb55c0b..a820e61e1113 100644 --- a/src/transformers/models/dbrx/modeling_dbrx.py +++ b/src/transformers/models/dbrx/modeling_dbrx.py @@ -98,7 +98,7 @@ def 
forward(self, x, position_ids): device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" with maybe_autocast(device_type=device_type, enabled=False): # Force float32 - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + freqs = (inv_freq_expanded.float() * position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling sin = emb.sin() * self.attention_scaling @@ -594,7 +594,7 @@ def load_balancing_loss_func( compute_device = gate_logits[0].device concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0) - routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1) + routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dtype=torch.float, dim=-1) _, selected_experts = torch.topk(routing_weights, top_k, dim=-1) @@ -602,7 +602,9 @@ def load_balancing_loss_func( if attention_mask is None: # Compute the percentage of tokens routed to each experts - tokens_per_expert = torch.mean(expert_mask.float(), dim=0) + # Normalize by top_k so that sum(f_i) = 1, matching the distribution of P_i + # See: https://github.com/huggingface/transformers/issues/43688 + tokens_per_expert = torch.mean(expert_mask.float(), dim=0) / top_k # Compute the average probability of routing to these experts router_prob_per_expert = torch.mean(routing_weights, dim=0) @@ -619,8 +621,10 @@ def load_balancing_loss_func( ) # Compute the percentage of tokens routed to each experts - tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum( - expert_attention_mask, dim=0 + # Normalize by top_k so that sum(f_i) = 1, matching the distribution of P_i + # See: https://github.com/huggingface/transformers/issues/43688 + tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / ( + torch.sum(expert_attention_mask, dim=0) * top_k ) # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert diff --git a/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py b/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py index fe3acd9aeddd..b59f7dfcc75a 100644 --- a/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py +++ b/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py @@ -112,7 +112,7 @@ def forward(self, x, position_ids): device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" with maybe_autocast(device_type=device_type, enabled=False): # Force float32 - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + freqs = (inv_freq_expanded.float() * position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling sin = emb.sin() * self.attention_scaling @@ -384,7 +384,7 @@ def __init__(self, config: DeepseekV3Config, layer_idx: int): self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.qk_head_dim, bias=False) else: self.q_a_proj = nn.Linear(config.hidden_size, config.q_lora_rank, bias=config.attention_bias) - self.q_a_layernorm = DeepseekV3RMSNorm(config.q_lora_rank) + self.q_a_layernorm = DeepseekV3RMSNorm(config.q_lora_rank, eps=config.rms_norm_eps) self.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.qk_head_dim, bias=False) self.kv_a_proj_with_mqa = nn.Linear( @@ -392,7 +392,7 @@ def __init__(self, config: DeepseekV3Config, 
layer_idx: int): self.kv_lora_rank + self.qk_rope_head_dim, bias=config.attention_bias, ) - self.kv_a_layernorm = DeepseekV3RMSNorm(self.kv_lora_rank) + self.kv_a_layernorm = DeepseekV3RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps) self.kv_b_proj = nn.Linear( self.kv_lora_rank, self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), diff --git a/src/transformers/models/deepseek_v3/modular_deepseek_v3.py b/src/transformers/models/deepseek_v3/modular_deepseek_v3.py index 2bf7d347e85d..38041e0b9707 100644 --- a/src/transformers/models/deepseek_v3/modular_deepseek_v3.py +++ b/src/transformers/models/deepseek_v3/modular_deepseek_v3.py @@ -189,7 +189,7 @@ def __init__(self, config: DeepseekV3Config, layer_idx: int): self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.qk_head_dim, bias=False) else: self.q_a_proj = nn.Linear(config.hidden_size, config.q_lora_rank, bias=config.attention_bias) - self.q_a_layernorm = DeepseekV3RMSNorm(config.q_lora_rank) + self.q_a_layernorm = DeepseekV3RMSNorm(config.q_lora_rank, eps=config.rms_norm_eps) self.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.qk_head_dim, bias=False) self.kv_a_proj_with_mqa = nn.Linear( @@ -197,7 +197,7 @@ def __init__(self, config: DeepseekV3Config, layer_idx: int): self.kv_lora_rank + self.qk_rope_head_dim, bias=config.attention_bias, ) - self.kv_a_layernorm = DeepseekV3RMSNorm(self.kv_lora_rank) + self.kv_a_layernorm = DeepseekV3RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps) self.kv_b_proj = nn.Linear( self.kv_lora_rank, self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), diff --git a/src/transformers/models/deepseek_vl/modeling_deepseek_vl.py b/src/transformers/models/deepseek_vl/modeling_deepseek_vl.py index c58f56ddfac0..ca2dbdb1ea8b 100644 --- a/src/transformers/models/deepseek_vl/modeling_deepseek_vl.py +++ b/src/transformers/models/deepseek_vl/modeling_deepseek_vl.py @@ -180,9 +180,9 @@ def get_placeholder_mask( n_image_tokens = special_image_mask.sum() n_image_features = image_features.shape[0] * image_features.shape[1] - special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + special_image_mask = special_image_mask.unsqueeze(-1).to(inputs_embeds.device) torch_compilable_check( - inputs_embeds[special_image_mask].numel() == image_features.numel(), + n_image_tokens * inputs_embeds.shape[-1] == image_features.numel(), f"Image features and image tokens do not match, tokens: {n_image_tokens}, features: {n_image_features}", ) return special_image_mask diff --git a/src/transformers/models/deepseek_vl/processing_deepseek_vl.py b/src/transformers/models/deepseek_vl/processing_deepseek_vl.py index 7057ff152a67..be55db718b82 100644 --- a/src/transformers/models/deepseek_vl/processing_deepseek_vl.py +++ b/src/transformers/models/deepseek_vl/processing_deepseek_vl.py @@ -24,9 +24,11 @@ from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack from ...tokenization_utils_base import PreTokenizedInput, TextInput from ...utils import auto_docstring +from .image_processing_deepseek_vl import DeepseekVLImageProcessorKwargs class DeepseekVLProcessorKwargs(ProcessingKwargs, total=False): + images_kwargs: DeepseekVLImageProcessorKwargs _defaults = { "text_kwargs": {"padding": False}, "common_kwargs": {"return_tensors": "pt"}, diff --git a/src/transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py b/src/transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py index eb85a8d02a76..83e0c656e244 100644 --- 
a/src/transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +++ b/src/transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py @@ -331,9 +331,9 @@ def get_placeholder_mask( n_image_tokens = special_image_mask.sum() n_image_features = image_features.shape[0] * image_features.shape[1] - special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + special_image_mask = special_image_mask.unsqueeze(-1).to(inputs_embeds.device) torch_compilable_check( - inputs_embeds[special_image_mask].numel() == image_features.numel(), + n_image_tokens * inputs_embeds.shape[-1] == image_features.numel(), f"Image features and image tokens do not match, tokens: {n_image_tokens}, features: {n_image_features}", ) return special_image_mask @@ -373,7 +373,7 @@ def forward( else: image_attention_mask = input_ids == self.config.image_token_id - image_attention_mask = image_attention_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + image_attention_mask = image_attention_mask.unsqueeze(-1).to(inputs_embeds.device) image_embeds = self.get_image_features(pixel_values, high_res_pixel_values, return_dict=True).pooler_output image_features = image_embeds.reshape(-1, inputs_embeds.shape[-1]) image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) diff --git a/src/transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py b/src/transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py index 99d24c163562..d1567dda59d9 100644 --- a/src/transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +++ b/src/transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py @@ -332,7 +332,7 @@ def forward( else: image_attention_mask = input_ids == self.config.image_token_id - image_attention_mask = image_attention_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + image_attention_mask = image_attention_mask.unsqueeze(-1).to(inputs_embeds.device) image_embeds = self.get_image_features(pixel_values, high_res_pixel_values, return_dict=True).pooler_output image_features = image_embeds.reshape(-1, inputs_embeds.shape[-1]) image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) diff --git a/src/transformers/models/deepseek_vl_hybrid/processing_deepseek_vl_hybrid.py b/src/transformers/models/deepseek_vl_hybrid/processing_deepseek_vl_hybrid.py index 7948b954b6d7..9c1f4f8c012d 100644 --- a/src/transformers/models/deepseek_vl_hybrid/processing_deepseek_vl_hybrid.py +++ b/src/transformers/models/deepseek_vl_hybrid/processing_deepseek_vl_hybrid.py @@ -23,9 +23,11 @@ from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack from ...tokenization_utils_base import PreTokenizedInput, TextInput from ...utils import auto_docstring +from .image_processing_deepseek_vl_hybrid import DeepseekVLHybridImageProcessorKwargs class DeepseekVLHybridProcessorKwargs(ProcessingKwargs, total=False): + images_kwargs: DeepseekVLHybridImageProcessorKwargs _defaults = { "text_kwargs": {"padding": False}, "common_kwargs": {"return_tensors": "pt"}, diff --git a/src/transformers/models/deformable_detr/image_processing_deformable_detr.py b/src/transformers/models/deformable_detr/image_processing_deformable_detr.py index 1bc4775255d6..246828a95756 100644 --- a/src/transformers/models/deformable_detr/image_processing_deformable_detr.py +++ b/src/transformers/models/deformable_detr/image_processing_deformable_detr.py @@ -284,8 +284,13 @@ def __init__(self, **kwargs: 
Unpack[DeformableDetrImageProcessorKwargs]) -> None kwargs.setdefault("do_pad", kwargs.pop("pad_and_return_pixel_mask", self.do_pad)) size = kwargs.pop("size", None) - max_size = None if size is None else kwargs.pop("max_size", 1333) - size = size if size is not None else {"shortest_edge": 800, "longest_edge": 1333} + max_size = kwargs.pop("max_size", None) + + if size is None: + size = {"shortest_edge": 800, "longest_edge": max_size if max_size is not None else 1333} + elif isinstance(size, dict) and max_size is not None and "longest_edge" not in size: + size = {**size, "longest_edge": max_size} + # Convert size dict for backwards compat with max_size parameter kwargs["size"] = get_size_dict(size, max_size=max_size, default_to_square=False) diff --git a/src/transformers/models/deformable_detr/image_processing_pil_deformable_detr.py b/src/transformers/models/deformable_detr/image_processing_pil_deformable_detr.py index 9c7ccc213910..68ec02518d8b 100644 --- a/src/transformers/models/deformable_detr/image_processing_pil_deformable_detr.py +++ b/src/transformers/models/deformable_detr/image_processing_pil_deformable_detr.py @@ -278,13 +278,17 @@ def __init__(self, **kwargs: Unpack[DeformableDetrImageProcessorKwargs]) -> None kwargs.setdefault("do_pad", kwargs.pop("pad_and_return_pixel_mask", self.do_pad)) size = kwargs.pop("size", None) - max_size = None if size is None else kwargs.pop("max_size", 1333) - size = size if size is not None else {"shortest_edge": 800, "longest_edge": 1333} + max_size = kwargs.pop("max_size", None) + + if size is None: + size = {"shortest_edge": 800, "longest_edge": max_size if max_size is not None else 1333} + elif isinstance(size, dict) and max_size is not None and "longest_edge" not in size: + size = {**size, "longest_edge": max_size} + # Convert size dict for backwards compat with max_size parameter - if size is not None: - from ...image_processing_utils import get_size_dict + from ...image_processing_utils import get_size_dict - kwargs["size"] = get_size_dict(size, max_size=max_size, default_to_square=False) + kwargs["size"] = get_size_dict(size, max_size=max_size, default_to_square=False) # Backwards compatibility do_convert_annotations = kwargs.get("do_convert_annotations") diff --git a/src/transformers/models/detr/image_processing_detr.py b/src/transformers/models/detr/image_processing_detr.py index e5cfa7ce14fb..ccfc0ec10216 100644 --- a/src/transformers/models/detr/image_processing_detr.py +++ b/src/transformers/models/detr/image_processing_detr.py @@ -438,8 +438,13 @@ def __init__(self, **kwargs: Unpack[DetrImageProcessorKwargs]) -> None: kwargs.setdefault("do_pad", kwargs.pop("pad_and_return_pixel_mask", self.do_pad)) size = kwargs.pop("size", None) - max_size = None if size is None else kwargs.pop("max_size", 1333) - size = size if size is not None else {"shortest_edge": 800, "longest_edge": 1333} + max_size = kwargs.pop("max_size", None) + + if size is None: + size = {"shortest_edge": 800, "longest_edge": max_size if max_size is not None else 1333} + elif isinstance(size, dict) and max_size is not None and "longest_edge" not in size: + size = {**size, "longest_edge": max_size} + # Convert size dict for backwards compat with max_size parameter kwargs["size"] = get_size_dict(size, max_size=max_size, default_to_square=False) diff --git a/src/transformers/models/detr/image_processing_pil_detr.py b/src/transformers/models/detr/image_processing_pil_detr.py index 14c5769549d8..e5995c70d157 100644 --- 
a/src/transformers/models/detr/image_processing_pil_detr.py +++ b/src/transformers/models/detr/image_processing_pil_detr.py @@ -442,13 +442,17 @@ def __init__(self, **kwargs: Unpack[DetrImageProcessorKwargs]) -> None: kwargs.setdefault("do_pad", kwargs.pop("pad_and_return_pixel_mask", self.do_pad)) size = kwargs.pop("size", None) - max_size = None if size is None else kwargs.pop("max_size", 1333) - size = size if size is not None else {"shortest_edge": 800, "longest_edge": 1333} + max_size = kwargs.pop("max_size", None) + + if size is None: + size = {"shortest_edge": 800, "longest_edge": max_size if max_size is not None else 1333} + elif isinstance(size, dict) and max_size is not None and "longest_edge" not in size: + size = {**size, "longest_edge": max_size} + # Convert size dict for backwards compat with max_size parameter - if size is not None: - from ...image_processing_utils import get_size_dict + from ...image_processing_utils import get_size_dict - kwargs["size"] = get_size_dict(size, max_size=max_size, default_to_square=False) + kwargs["size"] = get_size_dict(size, max_size=max_size, default_to_square=False) # Backwards compatibility do_convert_annotations = kwargs.get("do_convert_annotations") diff --git a/src/transformers/models/dia/modeling_dia.py b/src/transformers/models/dia/modeling_dia.py index 629dfd4cdb35..cc649f4459b4 100644 --- a/src/transformers/models/dia/modeling_dia.py +++ b/src/transformers/models/dia/modeling_dia.py @@ -193,7 +193,7 @@ def forward(self, x, position_ids): device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" with maybe_autocast(device_type=device_type, enabled=False): # Force float32 - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + freqs = (inv_freq_expanded.float() * position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling sin = emb.sin() * self.attention_scaling diff --git a/src/transformers/models/diffllama/modeling_diffllama.py b/src/transformers/models/diffllama/modeling_diffllama.py index d80ccd572dc3..6a6703bcc50f 100644 --- a/src/transformers/models/diffllama/modeling_diffllama.py +++ b/src/transformers/models/diffllama/modeling_diffllama.py @@ -126,7 +126,7 @@ def forward(self, x, position_ids): device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" with maybe_autocast(device_type=device_type, enabled=False): # Force float32 - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + freqs = (inv_freq_expanded.float() * position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling sin = emb.sin() * self.attention_scaling diff --git a/src/transformers/models/doge/modeling_doge.py b/src/transformers/models/doge/modeling_doge.py index 4aad59b52a9a..de518cba9287 100644 --- a/src/transformers/models/doge/modeling_doge.py +++ b/src/transformers/models/doge/modeling_doge.py @@ -128,7 +128,7 @@ def forward(self, x, position_ids): device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" with maybe_autocast(device_type=device_type, enabled=False): # Force float32 - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + freqs = (inv_freq_expanded.float() * position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling sin = 
emb.sin() * self.attention_scaling @@ -266,6 +266,7 @@ def __init__(self, config: DogeConfig, layer_idx: int | None = None): self.scaling = self.head_dim**-0.5 self.attention_dropout = config.attention_dropout self.keep_window_size = config.keep_window_size + self.is_causal = True self.q_proj = nn.Linear( config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias @@ -477,7 +478,7 @@ def forward( # sequence transformation residual = hidden_states hidden_states = self.input_layernorm(hidden_states) - hidden_states, self_attn_weights = self.self_attn( + hidden_states, _ = self.self_attn( hidden_states=hidden_states, position_embeddings=position_embeddings, attention_mask=attention_mask, @@ -493,6 +494,8 @@ def forward( residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.mlp(hidden_states) + if isinstance(hidden_states, tuple): + hidden_states, _ = hidden_states hidden_states = F.dropout(hidden_states, p=self.hidden_dropout, training=self.training) hidden_states = self.post_attention_residual * residual + hidden_states @@ -524,6 +527,9 @@ def _init_weights(self, module): if isinstance(module, DogeAttention): if hasattr(module, "A"): init.zeros_(module.A) + elif isinstance(module, DogeCDMoE): + if hasattr(module, "router_gate"): + init.zeros_(module.router_gate.weight) elif isinstance(module, DogeDecoderLayer): if hasattr(module, "input_residual"): init.ones_(module.input_residual) diff --git a/src/transformers/models/doge/modular_doge.py b/src/transformers/models/doge/modular_doge.py index 8b78126c0a00..840390ec51b8 100644 --- a/src/transformers/models/doge/modular_doge.py +++ b/src/transformers/models/doge/modular_doge.py @@ -204,6 +204,7 @@ def __init__(self, config: DogeConfig, layer_idx: int | None = None): self.scaling = self.head_dim**-0.5 self.attention_dropout = config.attention_dropout self.keep_window_size = config.keep_window_size + self.is_causal = True self.q_proj = nn.Linear( config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias @@ -403,7 +404,7 @@ def forward( # sequence transformation residual = hidden_states hidden_states = self.input_layernorm(hidden_states) - hidden_states, self_attn_weights = self.self_attn( + hidden_states, _ = self.self_attn( hidden_states=hidden_states, position_embeddings=position_embeddings, attention_mask=attention_mask, @@ -419,6 +420,8 @@ def forward( residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.mlp(hidden_states) + if isinstance(hidden_states, tuple): + hidden_states, _ = hidden_states hidden_states = F.dropout(hidden_states, p=self.hidden_dropout, training=self.training) hidden_states = self.post_attention_residual * residual + hidden_states @@ -441,6 +444,9 @@ def _init_weights(self, module): if isinstance(module, DogeAttention): if hasattr(module, "A"): init.zeros_(module.A) + elif isinstance(module, DogeCDMoE): + if hasattr(module, "router_gate"): + init.zeros_(module.router_gate.weight) elif isinstance(module, DogeDecoderLayer): if hasattr(module, "input_residual"): init.ones_(module.input_residual) diff --git a/src/transformers/models/dots1/modeling_dots1.py b/src/transformers/models/dots1/modeling_dots1.py index 95b21258ffd5..7cfde8d1957c 100644 --- a/src/transformers/models/dots1/modeling_dots1.py +++ b/src/transformers/models/dots1/modeling_dots1.py @@ -125,7 +125,7 @@ def forward(self, x, position_ids): device_type = x.device.type if 
isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" with maybe_autocast(device_type=device_type, enabled=False): # Force float32 - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + freqs = (inv_freq_expanded.float() * position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling sin = emb.sin() * self.attention_scaling diff --git a/src/transformers/models/dpt/image_processing_dpt.py b/src/transformers/models/dpt/image_processing_dpt.py index 6d157f6385c0..7969cead3f21 100644 --- a/src/transformers/models/dpt/image_processing_dpt.py +++ b/src/transformers/models/dpt/image_processing_dpt.py @@ -192,9 +192,10 @@ def reduce_label(self, labels: list["torch.Tensor"]) -> list["torch.Tensor"]: """Reduce label values by 1, replacing 0 with 255.""" for idx in range(len(labels)): label = labels[idx] - label = torch.where(label == 0, torch.tensor(255, dtype=label.dtype, device=label.device), label) - label = label - 1 - label = torch.where(label == 254, torch.tensor(255, dtype=label.dtype, device=label.device), label) + ignore_mask = (label == 0) | (label == 255) + label = label.clone() + label[ignore_mask] = 255 + label[~ignore_mask] = label[~ignore_mask] - 1 labels[idx] = label return labels diff --git a/src/transformers/models/dpt/image_processing_pil_dpt.py b/src/transformers/models/dpt/image_processing_pil_dpt.py index 6f770cac4e5f..07e711769829 100644 --- a/src/transformers/models/dpt/image_processing_pil_dpt.py +++ b/src/transformers/models/dpt/image_processing_pil_dpt.py @@ -180,9 +180,10 @@ def _preprocess_image_like_inputs( def reduce_label(self, image: np.ndarray) -> np.ndarray: """Reduce label values by 1, replacing 0 with 255.""" - image[image == 0] = 255 - image = image - 1 - image[image == 254] = 255 + image = image.copy() + ignore_mask = (image == 0) | (image == 255) + image[ignore_mask] = 255 + image[~ignore_mask] = image[~ignore_mask] - 1 return image def resize( diff --git a/src/transformers/models/emu3/modeling_emu3.py b/src/transformers/models/emu3/modeling_emu3.py index 2481decd7aeb..8fa4fd9d9188 100644 --- a/src/transformers/models/emu3/modeling_emu3.py +++ b/src/transformers/models/emu3/modeling_emu3.py @@ -1188,7 +1188,7 @@ def forward(self, x, position_ids): device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" with maybe_autocast(device_type=device_type, enabled=False): # Force float32 - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + freqs = (inv_freq_expanded.float() * position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling sin = emb.sin() * self.attention_scaling @@ -1447,9 +1447,9 @@ def get_placeholder_mask( n_image_tokens = special_image_mask.sum() n_image_features = image_features.shape[0] * image_features.shape[1] - special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + special_image_mask = special_image_mask.unsqueeze(-1).to(inputs_embeds.device) torch_compilable_check( - inputs_embeds[special_image_mask].numel() == image_features.numel(), + n_image_tokens * inputs_embeds.shape[-1] == image_features.numel(), f"Image features and image tokens do not match, tokens: {n_image_tokens}, features: {n_image_features}", ) return special_image_mask diff --git a/src/transformers/models/emu3/modular_emu3.py b/src/transformers/models/emu3/modular_emu3.py 
index 598687892727..e37ce1eb337f 100644 --- a/src/transformers/models/emu3/modular_emu3.py +++ b/src/transformers/models/emu3/modular_emu3.py @@ -1016,9 +1016,9 @@ def get_placeholder_mask( n_image_tokens = special_image_mask.sum() n_image_features = image_features.shape[0] * image_features.shape[1] - special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + special_image_mask = special_image_mask.unsqueeze(-1).to(inputs_embeds.device) torch_compilable_check( - inputs_embeds[special_image_mask].numel() == image_features.numel(), + n_image_tokens * inputs_embeds.shape[-1] == image_features.numel(), f"Image features and image tokens do not match, tokens: {n_image_tokens}, features: {n_image_features}", ) return special_image_mask diff --git a/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py b/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py index f634f89ab89f..9718ec588100 100644 --- a/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py +++ b/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py @@ -35,11 +35,16 @@ logger = logging.get_logger(__name__) -DEPRECATION_WARNING = ( +# Warning about deprecated practice of passing decoder_input_ids when labels are provided +DEPRECATED_DECODER_INPUT_IDS_WARNING = ( + "The decoder_input_ids are created based on the labels, no need to pass them yourself anymore." +) + +# Warning about v4.12.0 loss computation change - always shown when training with labels +V4_12_LOSS_COMPUTATION_WARNING = ( "Version v4.12.0 introduces a better way to train encoder-decoder models by computing the loss inside the" " encoder-decoder framework rather than in the decoder itself. You may observe training discrepancies if" - " fine-tuning a model trained with versions anterior to 4.12.0. The decoder_input_ids are now created based on the" - " labels, no need to pass them yourself anymore." + " fine-tuning a model trained with versions anterior to 4.12.0." ) @@ -423,6 +428,9 @@ def forward( ) if decoder_attention_mask is None: decoder_attention_mask = (decoder_input_ids != self.config.pad_token_id).to(decoder_input_ids.dtype) + elif (labels is not None) and (decoder_input_ids is not None): + # User provided both labels and decoder_input_ids - this is the deprecated path + warnings.warn(DEPRECATED_DECODER_INPUT_IDS_WARNING, FutureWarning) # Decode decoder_outputs = self.decoder( @@ -440,7 +448,8 @@ def forward( # Compute loss independent from decoder (as some shift the logits inside them) loss = None if labels is not None: - warnings.warn(DEPRECATION_WARNING, FutureWarning) + # Always warn about v4.12.0 loss computation change + warnings.warn(V4_12_LOSS_COMPUTATION_WARNING, FutureWarning) logits = decoder_outputs.logits loss_fct = CrossEntropyLoss() loss = loss_fct(logits.reshape(-1, self.decoder.config.vocab_size), labels.view(-1)) diff --git a/src/transformers/models/eomt/modeling_eomt.py b/src/transformers/models/eomt/modeling_eomt.py index 589b023d4db8..554b37da4b03 100644 --- a/src/transformers/models/eomt/modeling_eomt.py +++ b/src/transformers/models/eomt/modeling_eomt.py @@ -1109,6 +1109,14 @@ def forward( list of tuples indicating the image index and start and end positions of patches for semantic segmentation. 
""" + if mask_labels is not None: + target_device = pixel_values.device + mask_labels = [mask.to(target_device) for mask in mask_labels] + + if class_labels is not None: + target_device = pixel_values.device + class_labels = [label.to(target_device) for label in class_labels] + masks_queries_logits_per_layer, class_queries_logits_per_layer = (), () attention_mask = None diff --git a/src/transformers/models/eomt/modular_eomt.py b/src/transformers/models/eomt/modular_eomt.py index e4dafa024861..9cc2d228e24e 100644 --- a/src/transformers/models/eomt/modular_eomt.py +++ b/src/transformers/models/eomt/modular_eomt.py @@ -455,6 +455,14 @@ def forward( list of tuples indicating the image index and start and end positions of patches for semantic segmentation. """ + if mask_labels is not None: + target_device = pixel_values.device + mask_labels = [mask.to(target_device) for mask in mask_labels] + + if class_labels is not None: + target_device = pixel_values.device + class_labels = [label.to(target_device) for label in class_labels] + masks_queries_logits_per_layer, class_queries_logits_per_layer = (), () attention_mask = None diff --git a/src/transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py b/src/transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py index 9c106a90010d..bfa8c79b6bab 100644 --- a/src/transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +++ b/src/transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py @@ -605,7 +605,7 @@ def load_balancing_loss_func( compute_device = gate_logits[0].device concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0) - routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1) + routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dtype=torch.float, dim=-1) _, selected_experts = torch.topk(routing_weights, top_k, dim=-1) @@ -613,13 +613,17 @@ def load_balancing_loss_func( if attention_mask is None: # Compute the percentage of tokens routed to each experts - tokens_per_expert = torch.mean(expert_mask.float(), dim=0) + # Normalize by top_k so that sum(f_i) = 1, matching the distribution of P_i + # See: https://github.com/huggingface/transformers/issues/43688 + tokens_per_expert = torch.mean(expert_mask.float(), dim=0) / top_k # Compute the average probability of routing to these experts router_prob_per_expert = torch.mean(routing_weights, dim=0) else: - batch_size, sequence_length = attention_mask.shape - num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length) + num_hidden_layers = len(gate_logits) + batch_size = attention_mask.shape[0] + sequence_length = gate_logits[0].shape[0] // batch_size + attention_mask = attention_mask[:, -sequence_length:] # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask expert_attention_mask = ( @@ -630,8 +634,10 @@ def load_balancing_loss_func( ) # Compute the percentage of tokens routed to each experts - tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum( - expert_attention_mask, dim=0 + # Normalize by top_k so that sum(f_i) = 1, matching the distribution of P_i + # See: https://github.com/huggingface/transformers/issues/43688 + tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / ( + torch.sum(expert_attention_mask, dim=0) * top_k ) # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert diff --git 
a/src/transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py b/src/transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py index f3d7bc590f5d..0a26d9796c04 100644 --- a/src/transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +++ b/src/transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py @@ -600,7 +600,7 @@ def forward( if is_flash_attention_requested(self.config): # Flash Attention: Use cu_seqlens for variable length attention - max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() attn_output, _ = attention_interface( self, query_states, @@ -1334,18 +1334,18 @@ def get_placeholder_mask( special_video_mask = input_ids == self.config.video_token_id n_image_tokens = special_image_mask.sum() - special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + special_image_mask = special_image_mask.unsqueeze(-1).to(inputs_embeds.device) if image_features is not None: torch_compilable_check( - inputs_embeds[special_image_mask].numel() == image_features.numel(), + n_image_tokens * inputs_embeds.shape[-1] == image_features.numel(), f"Image features and image tokens do not match, tokens: {n_image_tokens}, features: {image_features.shape[0]}", ) n_video_tokens = special_video_mask.sum() - special_video_mask = special_video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + special_video_mask = special_video_mask.unsqueeze(-1).to(inputs_embeds.device) if video_features is not None: torch_compilable_check( - inputs_embeds[special_video_mask].numel() == video_features.numel(), + n_video_tokens * inputs_embeds.shape[-1] == video_features.numel(), f"Video features and video tokens do not match, tokens: {n_video_tokens}, features: {video_features.shape[0]}", ) return special_image_mask, special_video_mask @@ -1517,7 +1517,7 @@ def load_balancing_loss_func( compute_device = gate_logits[0].device concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0) - routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1) + routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dtype=torch.float, dim=-1) _, selected_experts = torch.topk(routing_weights, top_k, dim=-1) @@ -1525,7 +1525,9 @@ def load_balancing_loss_func( if attention_mask is None: # Compute the percentage of tokens routed to each experts - tokens_per_expert = torch.mean(expert_mask.float(), dim=0) + # Normalize by top_k so that sum(f_i) = 1, matching the distribution of P_i + # See: https://github.com/huggingface/transformers/issues/43688 + tokens_per_expert = torch.mean(expert_mask.float(), dim=0) / top_k # Compute the average probability of routing to these experts router_prob_per_expert = torch.mean(routing_weights, dim=0) @@ -1542,8 +1544,10 @@ def load_balancing_loss_func( ) # Compute the percentage of tokens routed to each experts - tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum( - expert_attention_mask, dim=0 + # Normalize by top_k so that sum(f_i) = 1, matching the distribution of P_i + # See: https://github.com/huggingface/transformers/issues/43688 + tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / ( + torch.sum(expert_attention_mask, dim=0) * top_k ) # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert diff --git 
a/src/transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py b/src/transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py index 8413907ef3c2..5eab0158452f 100644 --- a/src/transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +++ b/src/transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py @@ -22,9 +22,11 @@ from ...processing_utils import MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack from ...tokenization_utils_base import PreTokenizedInput, TextInput from ...video_utils import VideoInput +from .image_processing_ernie4_5_vl_moe import Ernie4_5_VLMoeImageProcessorKwargs class Ernie4_5_VLMoeProcessorKwargs(ProcessingKwargs, total=False): + images_kwargs: Ernie4_5_VLMoeImageProcessorKwargs _defaults = { "text_kwargs": { "padding": False, diff --git a/src/transformers/models/evolla/modeling_evolla.py b/src/transformers/models/evolla/modeling_evolla.py index 74fd137882d2..4efa653779ea 100644 --- a/src/transformers/models/evolla/modeling_evolla.py +++ b/src/transformers/models/evolla/modeling_evolla.py @@ -1086,7 +1086,7 @@ def forward(self, x, position_ids): device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" with maybe_autocast(device_type=device_type, enabled=False): # Force float32 - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + freqs = (inv_freq_expanded.float() * position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling sin = emb.sin() * self.attention_scaling diff --git a/src/transformers/models/exaone4/configuration_exaone4.py b/src/transformers/models/exaone4/configuration_exaone4.py index f29cab8dd8ea..f53f2ce2d05d 100644 --- a/src/transformers/models/exaone4/configuration_exaone4.py +++ b/src/transformers/models/exaone4/configuration_exaone4.py @@ -98,15 +98,26 @@ class Exaone4Config(PreTrainedConfig): layer_types: list[str] | None = None def __post_init__(self, **kwargs): - if self.sliding_window is None: - self.sliding_window_pattern = 0 if self.layer_types is None: - self.layer_types = [ - "sliding_attention" - if ((i + 1) % (self.sliding_window_pattern) != 0 and i < self.num_hidden_layers) - else "full_attention" - for i in range(self.num_hidden_layers) - ] + if self.sliding_window in (None, 0): + self.layer_types = ["full_attention"] * self.num_hidden_layers + elif isinstance(self.sliding_window_pattern, str) and self.sliding_window_pattern: + layer_pattern = [ + "sliding_attention" if layer_type.upper() == "L" else "full_attention" + for layer_type in self.sliding_window_pattern + ] + self.layer_types = [ + layer_pattern[i % len(layer_pattern)] for i in range(self.num_hidden_layers - 1) + ] + ["full_attention"] + else: + repeat_period = self.sliding_window_pattern if isinstance(self.sliding_window_pattern, int) else 1 + repeat_period = max(repeat_period, 1) + self.layer_types = [ + "sliding_attention" + if ((i + 1) % repeat_period != 0 and i < self.num_hidden_layers - 1) + else "full_attention" + for i in range(self.num_hidden_layers) + ] super().__post_init__(**kwargs) diff --git a/src/transformers/models/exaone4/modeling_exaone4.py b/src/transformers/models/exaone4/modeling_exaone4.py index fab10b9b6937..2009ee162f7d 100644 --- a/src/transformers/models/exaone4/modeling_exaone4.py +++ b/src/transformers/models/exaone4/modeling_exaone4.py @@ -124,7 +124,7 @@ def forward(self, x, position_ids): device_type = x.device.type if isinstance(x.device.type, str) and 
x.device.type != "mps" else "cpu" with maybe_autocast(device_type=device_type, enabled=False): # Force float32 - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + freqs = (inv_freq_expanded.float() * position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling sin = emb.sin() * self.attention_scaling diff --git a/src/transformers/models/exaone4/modular_exaone4.py b/src/transformers/models/exaone4/modular_exaone4.py index c6d9202170a0..89ccb28a90bd 100644 --- a/src/transformers/models/exaone4/modular_exaone4.py +++ b/src/transformers/models/exaone4/modular_exaone4.py @@ -127,15 +127,26 @@ class Exaone4Config(PreTrainedConfig): layer_types: list[str] | None = None def __post_init__(self, **kwargs): - if self.sliding_window is None: - self.sliding_window_pattern = 0 if self.layer_types is None: - self.layer_types = [ - "sliding_attention" - if ((i + 1) % (self.sliding_window_pattern) != 0 and i < self.num_hidden_layers) - else "full_attention" - for i in range(self.num_hidden_layers) - ] + if self.sliding_window in (None, 0): + self.layer_types = ["full_attention"] * self.num_hidden_layers + elif isinstance(self.sliding_window_pattern, str) and self.sliding_window_pattern: + layer_pattern = [ + "sliding_attention" if layer_type.upper() == "L" else "full_attention" + for layer_type in self.sliding_window_pattern + ] + self.layer_types = [ + layer_pattern[i % len(layer_pattern)] for i in range(self.num_hidden_layers - 1) + ] + ["full_attention"] + else: + repeat_period = self.sliding_window_pattern if isinstance(self.sliding_window_pattern, int) else 1 + repeat_period = max(repeat_period, 1) + self.layer_types = [ + "sliding_attention" + if ((i + 1) % repeat_period != 0 and i < self.num_hidden_layers - 1) + else "full_attention" + for i in range(self.num_hidden_layers) + ] super().__post_init__(**kwargs) diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index 016b3209b6b1..89141439e668 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -157,7 +157,7 @@ def forward(self, x, position_ids): device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" with maybe_autocast(device_type=device_type, enabled=False): # Force float32 - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + freqs = (inv_freq_expanded.float() * position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling sin = emb.sin() * self.attention_scaling @@ -280,15 +280,15 @@ def _split_heads(self, fused_qkv: torch.Tensor) -> tuple[torch.Tensor, torch.Ten return query, key, value elif not self.multi_query: batch_size, seq_length, three_times_hidden_size = fused_qkv.shape - fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads, 3, self.head_dim) + fused_qkv = fused_qkv.view(batch_size, seq_length, -1, 3, self.head_dim) return fused_qkv[..., 0, :], fused_qkv[..., 1, :], fused_qkv[..., 2, :] else: batch_size, seq_length, three_times_hidden_size = fused_qkv.shape - fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads + 2, self.head_dim) + fused_qkv = fused_qkv.view(batch_size, seq_length, -1, self.head_dim) return fused_qkv[..., :-2, :], fused_qkv[..., [-2], :], fused_qkv[..., [-1], :] # Copied from 
transformers.models.bloom.modeling_bloom.BloomAttention._merge_heads - def _merge_heads(self, x: torch.Tensor) -> torch.Tensor: + def _merge_heads(self, x: torch.Tensor, tp_aware_num_heads: int) -> torch.Tensor: """ Merge heads together over the last dimension @@ -301,17 +301,17 @@ def _merge_heads(self, x: torch.Tensor) -> torch.Tensor: # What we want to achieve is: # batch_size * num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads * head_dim batch_size_and_num_heads, seq_length, _ = x.shape - batch_size = batch_size_and_num_heads // self.num_heads + batch_size = batch_size_and_num_heads // tp_aware_num_heads # First view to decompose the batch size # batch_size * num_heads, seq_length, head_dim -> batch_size, num_heads, seq_length, head_dim - x = x.view(batch_size, self.num_heads, seq_length, self.head_dim) + x = x.view(batch_size, tp_aware_num_heads, seq_length, self.head_dim) # batch_size, num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads, head_dim x = x.permute(0, 2, 1, 3) # batch_size, seq_length, num_heads, head_dim -> batch_size, seq_length, num_heads * head_dim - return x.reshape(batch_size, seq_length, self.num_heads * self.head_dim) + return x.reshape(batch_size, seq_length, tp_aware_num_heads * self.head_dim) def forward( self, @@ -326,15 +326,20 @@ def forward( **kwargs, ): fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size] - num_kv_heads = self.num_heads if self.new_decoder_architecture else self.num_kv_heads # 3 x [batch_size, seq_length, num_heads, head_dim] (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv) batch_size, query_length, _, _ = query_layer.shape - query_layer = query_layer.transpose(1, 2).reshape(batch_size, self.num_heads, query_length, self.head_dim) - key_layer = key_layer.transpose(1, 2).reshape(batch_size, num_kv_heads, query_length, self.head_dim) - value_layer = value_layer.transpose(1, 2).reshape(batch_size, num_kv_heads, query_length, self.head_dim) + tp_aware_num_heads = query_layer.shape[2] + tp_aware_key_heads = key_layer.shape[2] + tp_aware_value_heads = value_layer.shape[2] + + query_layer = query_layer.transpose(1, 2).reshape(batch_size, tp_aware_num_heads, query_length, self.head_dim) + key_layer = key_layer.transpose(1, 2).reshape(batch_size, tp_aware_key_heads, query_length, self.head_dim) + value_layer = value_layer.transpose(1, 2).reshape( + batch_size, tp_aware_value_heads, query_length, self.head_dim + ) if alibi is None: cos, sin = position_embeddings @@ -369,9 +374,9 @@ def forward( # It is unclear why dropout is not applied here (while it is with alibi). 
attn_output = attention_scores @ value_layer - attn_output = attn_output.view(batch_size, self.num_heads, query_length, self.head_dim) + attn_output = attn_output.view(batch_size, tp_aware_num_heads, query_length, self.head_dim) attn_output = attn_output.permute(0, 2, 1, 3) - attn_output = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim) + attn_output = attn_output.reshape(batch_size, query_length, tp_aware_num_heads * self.head_dim) attn_output = self.dense(attn_output) @@ -392,14 +397,14 @@ def forward( ) attention_probs = None attn_output = attn_output.transpose(1, 2) - attn_output = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim) + attn_output = attn_output.reshape(batch_size, query_length, tp_aware_num_heads * self.head_dim) attn_output = self.dense(attn_output) else: matmul_result = query_layer @ key_layer.transpose(-1, -2) # change view to [batch_size, num_heads, q_length, kv_length] - attention_scores = matmul_result.view(batch_size, self.num_heads, query_length, kv_length) + attention_scores = matmul_result.view(batch_size, tp_aware_num_heads, query_length, kv_length) # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length] input_dtype = attention_scores.dtype @@ -407,20 +412,22 @@ def forward( if input_dtype == torch.float16 or input_dtype == torch.bfloat16: attention_scores = attention_scores.to(torch.float32) - attention_logits = attention_scores + alibi.view(batch_size, self.num_heads, 1, -1) + attention_logits = attention_scores + alibi.view(batch_size, tp_aware_num_heads, 1, -1) attention_logits *= self.inv_norm_factor attention_probs = F.softmax(attention_logits + attention_mask, dim=-1, dtype=hidden_states.dtype) # [batch_size, num_heads, q_length, kv_length] attention_probs = self.attention_dropout(attention_probs) # change view [batch_size, num_heads, q_length, kv_length] - attention_probs_reshaped = attention_probs.view(batch_size, self.num_heads, query_length, kv_length) + attention_probs_reshaped = attention_probs.view( + batch_size, tp_aware_num_heads, query_length, kv_length + ) # matmul: [batch_size * num_heads, q_length, head_dim] attn_output = (attention_probs_reshaped @ value_layer).flatten(0, 1) # change view [batch_size, q_length, num_heads * head_dim] - attn_output = self._merge_heads(attn_output) + attn_output = self._merge_heads(attn_output, tp_aware_num_heads) attn_output = self.dense(attn_output) @@ -771,7 +778,7 @@ def forward( attention_mask=attention_mask, past_key_values=past_key_values, # Force mask creation for alibi - and_mask_function=lambda *args: torch.tensor(True, dtype=torch.bool), + and_mask_function=(lambda *args: torch.tensor(True, dtype=torch.bool)) if self.use_alibi else None, ) if alibi is not None and causal_mask is not None and causal_mask.ndim == 4: min_dtype = torch.finfo(inputs_embeds.dtype).min diff --git a/src/transformers/models/falcon_h1/modeling_falcon_h1.py b/src/transformers/models/falcon_h1/modeling_falcon_h1.py index 37b5da9df4b3..9e281c1b1c0b 100644 --- a/src/transformers/models/falcon_h1/modeling_falcon_h1.py +++ b/src/transformers/models/falcon_h1/modeling_falcon_h1.py @@ -110,7 +110,7 @@ def forward(self, x, position_ids): device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" with maybe_autocast(device_type=device_type, enabled=False): # Force float32 - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + 
freqs = (inv_freq_expanded.float() * position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling sin = emb.sin() * self.attention_scaling diff --git a/src/transformers/models/fast_vlm/modeling_fast_vlm.py b/src/transformers/models/fast_vlm/modeling_fast_vlm.py index 85c2eeb82b64..53ff29d5b558 100644 --- a/src/transformers/models/fast_vlm/modeling_fast_vlm.py +++ b/src/transformers/models/fast_vlm/modeling_fast_vlm.py @@ -162,9 +162,9 @@ def get_placeholder_mask( n_image_tokens = special_image_mask.sum() n_image_features = image_features.shape[0] * image_features.shape[1] - special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + special_image_mask = special_image_mask.unsqueeze(-1).to(inputs_embeds.device) torch_compilable_check( - inputs_embeds[special_image_mask].numel() == image_features.numel(), + n_image_tokens * inputs_embeds.shape[-1] == image_features.numel(), f"Image features and image tokens do not match, tokens: {n_image_tokens}, features: {n_image_features}", ) return special_image_mask diff --git a/src/transformers/models/flex_olmo/modeling_flex_olmo.py b/src/transformers/models/flex_olmo/modeling_flex_olmo.py index 100e6fa35554..5967e27b691a 100644 --- a/src/transformers/models/flex_olmo/modeling_flex_olmo.py +++ b/src/transformers/models/flex_olmo/modeling_flex_olmo.py @@ -548,7 +548,7 @@ def load_balancing_loss_func( compute_device = gate_logits[0].device concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0) - routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1) + routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dtype=torch.float, dim=-1) _, selected_experts = torch.topk(routing_weights, top_k, dim=-1) @@ -556,7 +556,9 @@ def load_balancing_loss_func( if attention_mask is None: # Compute the percentage of tokens routed to each experts - tokens_per_expert = torch.mean(expert_mask.float(), dim=0) + # Normalize by top_k so that sum(f_i) = 1, matching the distribution of P_i + # See: https://github.com/huggingface/transformers/issues/43688 + tokens_per_expert = torch.mean(expert_mask.float(), dim=0) / top_k # Compute the average probability of routing to these experts router_prob_per_expert = torch.mean(routing_weights, dim=0) @@ -573,8 +575,10 @@ def load_balancing_loss_func( ) # Compute the percentage of tokens routed to each experts - tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum( - expert_attention_mask, dim=0 + # Normalize by top_k so that sum(f_i) = 1, matching the distribution of P_i + # See: https://github.com/huggingface/transformers/issues/43688 + tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / ( + torch.sum(expert_attention_mask, dim=0) * top_k ) # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert diff --git a/src/transformers/models/florence2/modeling_florence2.py b/src/transformers/models/florence2/modeling_florence2.py index fd941b85ce66..6abd07dddca2 100644 --- a/src/transformers/models/florence2/modeling_florence2.py +++ b/src/transformers/models/florence2/modeling_florence2.py @@ -716,9 +716,9 @@ def get_placeholder_mask( n_image_tokens = special_image_mask.sum() n_image_features = image_features.shape[0] * image_features.shape[1] - special_image_mask = 
special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + special_image_mask = special_image_mask.unsqueeze(-1).to(inputs_embeds.device) torch_compilable_check( - inputs_embeds[special_image_mask].numel() == image_features.numel(), + n_image_tokens * inputs_embeds.shape[-1] == image_features.numel(), f"Image features and image tokens do not match, tokens: {n_image_tokens}, features: {n_image_features}", ) return special_image_mask diff --git a/src/transformers/models/fuyu/modeling_fuyu.py b/src/transformers/models/fuyu/modeling_fuyu.py index df57519032b9..e38b4a099ea8 100644 --- a/src/transformers/models/fuyu/modeling_fuyu.py +++ b/src/transformers/models/fuyu/modeling_fuyu.py @@ -141,9 +141,9 @@ def get_placeholder_mask( n_image_tokens = special_image_mask.sum() n_image_features = image_features.shape[0] * image_features.shape[1] - special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + special_image_mask = special_image_mask.unsqueeze(-1).to(inputs_embeds.device) torch_compilable_check( - inputs_embeds[special_image_mask].numel() == image_features.numel(), + n_image_tokens * inputs_embeds.shape[-1] == image_features.numel(), f"Image features and image tokens do not match, tokens: {n_image_tokens}, features: {n_image_features}", ) return special_image_mask diff --git a/src/transformers/models/fuyu/processing_fuyu.py b/src/transformers/models/fuyu/processing_fuyu.py index 76287ae3a5ea..b86a4b6cb4a8 100644 --- a/src/transformers/models/fuyu/processing_fuyu.py +++ b/src/transformers/models/fuyu/processing_fuyu.py @@ -30,6 +30,7 @@ from ...tokenization_utils_base import PreTokenizedInput, TextInput from ...utils import auto_docstring, is_torch_available, logging, requires_backends from ...utils.import_utils import requires +from .image_processing_fuyu import FuyuImagesKwargs if is_torch_available(): @@ -56,6 +57,7 @@ class FuyuProcessorKwargs(ProcessingKwargs, total=False): + images_kwargs: FuyuImagesKwargs _defaults = { "text_kwargs": { "add_special_tokens": True, diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index c6c5a55b8790..13a6451e112f 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -154,7 +154,7 @@ def forward(self, x, position_ids): device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" with maybe_autocast(device_type=device_type, enabled=False): # Force float32 - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + freqs = (inv_freq_expanded.float() * position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling sin = emb.sin() * self.attention_scaling diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index 20673571b2d2..d7f347cd3a01 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -139,7 +139,7 @@ def forward(self, x, position_ids): device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" with maybe_autocast(device_type=device_type, enabled=False): # Force float32 - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + freqs = (inv_freq_expanded.float() * position_ids_expanded.float()).transpose(1, 2) emb = 
torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling sin = emb.sin() * self.attention_scaling diff --git a/src/transformers/models/gemma2/modular_gemma2.py b/src/transformers/models/gemma2/modular_gemma2.py index 2edd9ef5f101..50a176e4b287 100644 --- a/src/transformers/models/gemma2/modular_gemma2.py +++ b/src/transformers/models/gemma2/modular_gemma2.py @@ -170,7 +170,7 @@ def forward(self, x, position_ids): device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" with maybe_autocast(device_type=device_type, enabled=False): # Force float32 - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + freqs = (inv_freq_expanded.float() * position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling sin = emb.sin() * self.attention_scaling diff --git a/src/transformers/models/gemma3/modeling_gemma3.py b/src/transformers/models/gemma3/modeling_gemma3.py index 38f50e95bb6d..247c8788cbe8 100644 --- a/src/transformers/models/gemma3/modeling_gemma3.py +++ b/src/transformers/models/gemma3/modeling_gemma3.py @@ -42,7 +42,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack -from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple, logging, torch_compilable_check +from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple, torch_compilable_check from ...utils.deprecation import deprecate_kwarg from ...utils.generic import maybe_autocast, merge_with_config_defaults from ...utils.output_capturing import capture_outputs @@ -50,9 +50,6 @@ from .configuration_gemma3 import Gemma3Config, Gemma3TextConfig -logger = logging.get_logger(__name__) - - @dataclass @auto_docstring( custom_intro=""" @@ -824,9 +821,9 @@ def get_placeholder_mask( n_image_tokens = special_image_mask.sum() n_image_features = image_features.shape[0] * image_features.shape[1] - special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + special_image_mask = special_image_mask.unsqueeze(-1).to(inputs_embeds.device) torch_compilable_check( - inputs_embeds[special_image_mask].numel() == image_features.numel(), + n_image_tokens * inputs_embeds.shape[-1] == image_features.numel(), f"Image features and image tokens do not match, tokens: {n_image_tokens}, features: {n_image_features}", ) return special_image_mask @@ -1126,24 +1123,17 @@ def create_masks_for_generate( ) -class Gemma3ForSequenceClassification(Gemma3PreTrainedModel): - def __init__(self, config): - super().__init__(config) - self.num_labels = config.num_labels - self.model = Gemma3Model(config) - self.score = nn.Linear(config.text_config.hidden_size, self.num_labels, bias=False) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.get_input_embeddings() +@auto_docstring( + custom_intro=""" +Gemma3TextForSequenceClassification is a text-only sequence classification model that works with Gemma3TextConfig. 
+It uses the generic sequence classification implementation for efficiency and consistency.""" +) +class Gemma3TextForSequenceClassification(GenericForSequenceClassification, Gemma3PreTrainedModel): + config: Gemma3TextConfig + input_modalities = ("text",) - def set_input_embeddings(self, value): - self.model.set_input_embeddings(value) - @can_return_tuple - @auto_docstring +class Gemma3ForSequenceClassification(GenericForSequenceClassification, Gemma3PreTrainedModel): def forward( self, input_ids: torch.LongTensor | None = None, @@ -1151,78 +1141,22 @@ def forward( attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, past_key_values: Cache | None = None, - inputs_embeds: torch.FloatTensor | None = None, token_type_ids: torch.LongTensor | None = None, + inputs_embeds: torch.FloatTensor | None = None, labels: torch.LongTensor | None = None, - use_cache: bool | None = None, **kwargs: Unpack[TransformersKwargs], ) -> SequenceClassifierOutputWithPast: - r""" - labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - """ - - transformer_outputs = self.model( - input_ids, + return super().forward( + input_ids=input_ids, attention_mask=attention_mask, - pixel_values=pixel_values, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, + pixel_values=pixel_values, token_type_ids=token_type_ids, - use_cache=use_cache, - return_dict=True, + labels=labels, **kwargs, ) - hidden_states = transformer_outputs.last_hidden_state - logits = self.score(hidden_states) - - if input_ids is not None: - batch_size = input_ids.shape[0] - else: - batch_size = inputs_embeds.shape[0] - - if self.config.text_config.pad_token_id is None and batch_size != 1: - raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") - if self.config.text_config.pad_token_id is None: - last_non_pad_token = -1 - elif input_ids is not None: - # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id - non_pad_mask = (input_ids != self.config.text_config.pad_token_id).to(logits.device, torch.int32) - token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32) - last_non_pad_token = (token_indices * non_pad_mask).argmax(-1) - else: - last_non_pad_token = -1 - logger.warning_once( - f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " - "unexpected if using padding tokens in conjunction with `inputs_embeds.`" - ) - - pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token] - - loss = None - if labels is not None: - loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config) - - return SequenceClassifierOutputWithPast( - loss=loss, - logits=pooled_logits, - past_key_values=transformer_outputs.past_key_values, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) - - -class Gemma3TextForSequenceClassification(GenericForSequenceClassification, Gemma3PreTrainedModel): - """ - Gemma3TextForSequenceClassification is a text-only sequence classification model that works with Gemma3TextConfig. 
- It uses the generic sequence classification implementation for efficiency and consistency. - """ - - config: Gemma3TextConfig - input_modalities = ("text",) __all__ = [ diff --git a/src/transformers/models/gemma3/modular_gemma3.py b/src/transformers/models/gemma3/modular_gemma3.py index 9de1d8172513..6d965e9f6890 100644 --- a/src/transformers/models/gemma3/modular_gemma3.py +++ b/src/transformers/models/gemma3/modular_gemma3.py @@ -889,24 +889,17 @@ def prepare_inputs_for_generation( return model_inputs -class Gemma3ForSequenceClassification(Gemma3PreTrainedModel): - def __init__(self, config): - super().__init__(config) - self.num_labels = config.num_labels - self.model = Gemma3Model(config) - self.score = nn.Linear(config.text_config.hidden_size, self.num_labels, bias=False) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.get_input_embeddings() +@auto_docstring( + custom_intro=""" +Gemma3TextForSequenceClassification is a text-only sequence classification model that works with Gemma3TextConfig. +It uses the generic sequence classification implementation for efficiency and consistency.""" +) +class Gemma3TextForSequenceClassification(GenericForSequenceClassification, Gemma3PreTrainedModel): + config: Gemma3TextConfig + input_modalities = ("text",) - def set_input_embeddings(self, value): - self.model.set_input_embeddings(value) - @can_return_tuple - @auto_docstring +class Gemma3ForSequenceClassification(GenericForSequenceClassification, Gemma3PreTrainedModel): def forward( self, input_ids: torch.LongTensor | None = None, @@ -914,78 +907,22 @@ def forward( attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, past_key_values: Cache | None = None, - inputs_embeds: torch.FloatTensor | None = None, token_type_ids: torch.LongTensor | None = None, + inputs_embeds: torch.FloatTensor | None = None, labels: torch.LongTensor | None = None, - use_cache: bool | None = None, **kwargs: Unpack[TransformersKwargs], ) -> SequenceClassifierOutputWithPast: - r""" - labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
- """ - - transformer_outputs = self.model( - input_ids, + return super().forward( + input_ids=input_ids, attention_mask=attention_mask, - pixel_values=pixel_values, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, + pixel_values=pixel_values, token_type_ids=token_type_ids, - use_cache=use_cache, - return_dict=True, + labels=labels, **kwargs, ) - hidden_states = transformer_outputs.last_hidden_state - logits = self.score(hidden_states) - - if input_ids is not None: - batch_size = input_ids.shape[0] - else: - batch_size = inputs_embeds.shape[0] - - if self.config.text_config.pad_token_id is None and batch_size != 1: - raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") - if self.config.text_config.pad_token_id is None: - last_non_pad_token = -1 - elif input_ids is not None: - # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id - non_pad_mask = (input_ids != self.config.text_config.pad_token_id).to(logits.device, torch.int32) - token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32) - last_non_pad_token = (token_indices * non_pad_mask).argmax(-1) - else: - last_non_pad_token = -1 - logger.warning_once( - f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " - "unexpected if using padding tokens in conjunction with `inputs_embeds.`" - ) - - pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token] - - loss = None - if labels is not None: - loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config) - - return SequenceClassifierOutputWithPast( - loss=loss, - logits=pooled_logits, - past_key_values=transformer_outputs.past_key_values, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) - - -class Gemma3TextForSequenceClassification(GenericForSequenceClassification, Gemma3PreTrainedModel): - """ - Gemma3TextForSequenceClassification is a text-only sequence classification model that works with Gemma3TextConfig. - It uses the generic sequence classification implementation for efficiency and consistency. 
- """ - - config: Gemma3TextConfig - input_modalities = ("text",) __all__ = [ diff --git a/src/transformers/models/gemma3/processing_gemma3.py b/src/transformers/models/gemma3/processing_gemma3.py index 048fe1adfa66..f24b23db4f55 100644 --- a/src/transformers/models/gemma3/processing_gemma3.py +++ b/src/transformers/models/gemma3/processing_gemma3.py @@ -19,9 +19,11 @@ from ...processing_utils import MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack from ...tokenization_utils_base import PreTokenizedInput, TextInput from ...utils import auto_docstring, to_py_obj +from .image_processing_gemma3 import Gemma3ImageProcessorKwargs class Gemma3ProcessorKwargs(ProcessingKwargs, total=False): + images_kwargs: Gemma3ImageProcessorKwargs _defaults = { "text_kwargs": { "padding": False, diff --git a/src/transformers/models/gemma3n/configuration_gemma3n.py b/src/transformers/models/gemma3n/configuration_gemma3n.py index e61c5f0038e7..60be54dacc18 100644 --- a/src/transformers/models/gemma3n/configuration_gemma3n.py +++ b/src/transformers/models/gemma3n/configuration_gemma3n.py @@ -394,7 +394,7 @@ def to_dict(self) -> dict[str, Any]: @strict class Gemma3nConfig(PreTrainedConfig): r""" - audio_soft_tokens_per_image (`int`, *optional*, defaults to 188): + audio_soft_tokens_per_audio (`int`, *optional*, defaults to 188): The number of soft tokens per audio clip. vision_soft_tokens_per_image (`int`, *optional*, defaults to 256): The number of soft tokens per image. @@ -441,7 +441,7 @@ class Gemma3nConfig(PreTrainedConfig): text_config: Gemma3nTextConfig | dict[str, Any] | None = None vision_config: Gemma3nVisionConfig | dict[str, Any] | None = None audio_config: Gemma3nAudioConfig | dict[str, Any] | None = None - audio_soft_tokens_per_image: int | None = 188 + audio_soft_tokens_per_audio: int | None = 188 vision_soft_tokens_per_image: int | None = 256 boi_token_id: int | None = 255_999 eoi_token_id: int | None = 262_144 diff --git a/src/transformers/models/gemma3n/modeling_gemma3n.py b/src/transformers/models/gemma3n/modeling_gemma3n.py index 8d1c5348d378..039c8c4e84c9 100644 --- a/src/transformers/models/gemma3n/modeling_gemma3n.py +++ b/src/transformers/models/gemma3n/modeling_gemma3n.py @@ -1485,7 +1485,7 @@ def forward( Returns: audio_encodings: a torch.Tensor of shape - `[batch_size, self.config.audio_soft_tokens_per_image, + `[batch_size, self.config.audio_soft_tokens_per_audio, self.config.audio_config.hidden_size]` audio_mel_mask: a torch.BoolTensor of shape [batch, num_frames]. 
""" @@ -2040,18 +2040,18 @@ def get_placeholder_mask( special_audio_mask = input_ids == self.config.audio_token_id n_image_tokens = special_image_mask.sum() - special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + special_image_mask = special_image_mask.unsqueeze(-1).to(inputs_embeds.device) if image_features is not None: torch_compilable_check( - inputs_embeds[special_image_mask].numel() == image_features.numel(), + n_image_tokens * inputs_embeds.shape[-1] == image_features.numel(), f"Image features and image tokens do not match, tokens: {n_image_tokens}, features: {image_features.shape[0] * image_features.shape[1]}", ) n_audio_tokens = special_audio_mask.sum() - special_audio_mask = special_audio_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + special_audio_mask = special_audio_mask.unsqueeze(-1).to(inputs_embeds.device) if audio_features is not None: torch_compilable_check( - inputs_embeds[special_audio_mask].numel() == audio_features.numel(), + n_audio_tokens * inputs_embeds.shape[-1] == audio_features.numel(), f"Audio features and audio tokens do not match, tokens: {n_audio_tokens}, features: {audio_features.shape[0] * audio_features.shape[1]}", ) @@ -2124,7 +2124,7 @@ def forward( vision_input_ids = torch.where(vision_mask, input_ids, dummy_vision_token_id).to(inputs_embeds.device) vision_embeds = self.embed_vision(input_ids=vision_input_ids) vision_embeds = vision_embeds.to(inputs_embeds.device, inputs_embeds.dtype) - expanded_vision_mask = vision_mask.unsqueeze(-1).expand_as(inputs_embeds) + expanded_vision_mask = vision_mask.unsqueeze(-1) inputs_embeds = torch.where(expanded_vision_mask, vision_embeds, inputs_embeds) # Handle audio tokens (>= embed_audio.vocab_offset) @@ -2133,7 +2133,7 @@ def forward( audio_input_ids = torch.where(audio_mask, input_ids, dummy_audio_token_id).to(inputs_embeds.device) audio_embeds = self.embed_audio(input_ids=audio_input_ids) audio_embeds = audio_embeds.to(inputs_embeds.device, inputs_embeds.dtype) - expanded_audio_mask = audio_mask.unsqueeze(-1).expand_as(inputs_embeds) + expanded_audio_mask = audio_mask.unsqueeze(-1) inputs_embeds = torch.where(expanded_audio_mask, audio_embeds, inputs_embeds) else: per_layer_inputs = None @@ -2163,7 +2163,7 @@ def forward( audio_features = torch.where(audio_mask.unsqueeze(-1), audio_padding_embs, audio_features) audio_batch_size, audio_seq_len, audio_embed_dim = audio_features.shape - extra_padding_tokens = self.config.audio_soft_tokens_per_image - audio_seq_len + extra_padding_tokens = self.config.audio_soft_tokens_per_audio - audio_seq_len extra_padding_features = audio_padding_embs.expand(audio_batch_size, extra_padding_tokens, audio_embed_dim) audio_features = torch.cat((audio_features, extra_padding_features), dim=1) diff --git a/src/transformers/models/gemma3n/modular_gemma3n.py b/src/transformers/models/gemma3n/modular_gemma3n.py index e97e1ef4c6d2..181605040330 100644 --- a/src/transformers/models/gemma3n/modular_gemma3n.py +++ b/src/transformers/models/gemma3n/modular_gemma3n.py @@ -357,7 +357,7 @@ class Gemma3nVisionConfig(TimmWrapperConfig): @strict class Gemma3nConfig(PreTrainedConfig): r""" - audio_soft_tokens_per_image (`int`, *optional*, defaults to 188): + audio_soft_tokens_per_audio (`int`, *optional*, defaults to 188): The number of soft tokens per audio clip. vision_soft_tokens_per_image (`int`, *optional*, defaults to 256): The number of soft tokens per image. 
@@ -404,7 +404,7 @@ class Gemma3nConfig(PreTrainedConfig): text_config: Gemma3nTextConfig | dict[str, Any] | None = None vision_config: Gemma3nVisionConfig | dict[str, Any] | None = None audio_config: Gemma3nAudioConfig | dict[str, Any] | None = None - audio_soft_tokens_per_image: int | None = 188 + audio_soft_tokens_per_audio: int | None = 188 vision_soft_tokens_per_image: int | None = 256 boi_token_id: int | None = 255_999 eoi_token_id: int | None = 262_144 @@ -1764,7 +1764,7 @@ def forward( Returns: audio_encodings: a torch.Tensor of shape - `[batch_size, self.config.audio_soft_tokens_per_image, + `[batch_size, self.config.audio_soft_tokens_per_audio, self.config.audio_config.hidden_size]` audio_mel_mask: a torch.BoolTensor of shape [batch, num_frames]. """ @@ -2149,18 +2149,18 @@ def get_placeholder_mask( special_audio_mask = input_ids == self.config.audio_token_id n_image_tokens = special_image_mask.sum() - special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + special_image_mask = special_image_mask.unsqueeze(-1).to(inputs_embeds.device) if image_features is not None: torch_compilable_check( - inputs_embeds[special_image_mask].numel() == image_features.numel(), + n_image_tokens * inputs_embeds.shape[-1] == image_features.numel(), f"Image features and image tokens do not match, tokens: {n_image_tokens}, features: {image_features.shape[0] * image_features.shape[1]}", ) n_audio_tokens = special_audio_mask.sum() - special_audio_mask = special_audio_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + special_audio_mask = special_audio_mask.unsqueeze(-1).to(inputs_embeds.device) if audio_features is not None: torch_compilable_check( - inputs_embeds[special_audio_mask].numel() == audio_features.numel(), + n_audio_tokens * inputs_embeds.shape[-1] == audio_features.numel(), f"Audio features and audio tokens do not match, tokens: {n_audio_tokens}, features: {audio_features.shape[0] * audio_features.shape[1]}", ) @@ -2233,7 +2233,7 @@ def forward( vision_input_ids = torch.where(vision_mask, input_ids, dummy_vision_token_id).to(inputs_embeds.device) vision_embeds = self.embed_vision(input_ids=vision_input_ids) vision_embeds = vision_embeds.to(inputs_embeds.device, inputs_embeds.dtype) - expanded_vision_mask = vision_mask.unsqueeze(-1).expand_as(inputs_embeds) + expanded_vision_mask = vision_mask.unsqueeze(-1) inputs_embeds = torch.where(expanded_vision_mask, vision_embeds, inputs_embeds) # Handle audio tokens (>= embed_audio.vocab_offset) @@ -2242,7 +2242,7 @@ def forward( audio_input_ids = torch.where(audio_mask, input_ids, dummy_audio_token_id).to(inputs_embeds.device) audio_embeds = self.embed_audio(input_ids=audio_input_ids) audio_embeds = audio_embeds.to(inputs_embeds.device, inputs_embeds.dtype) - expanded_audio_mask = audio_mask.unsqueeze(-1).expand_as(inputs_embeds) + expanded_audio_mask = audio_mask.unsqueeze(-1) inputs_embeds = torch.where(expanded_audio_mask, audio_embeds, inputs_embeds) else: per_layer_inputs = None @@ -2272,7 +2272,7 @@ def forward( audio_features = torch.where(audio_mask.unsqueeze(-1), audio_padding_embs, audio_features) audio_batch_size, audio_seq_len, audio_embed_dim = audio_features.shape - extra_padding_tokens = self.config.audio_soft_tokens_per_image - audio_seq_len + extra_padding_tokens = self.config.audio_soft_tokens_per_audio - audio_seq_len extra_padding_features = audio_padding_embs.expand(audio_batch_size, extra_padding_tokens, audio_embed_dim) audio_features = 
torch.cat((audio_features, extra_padding_features), dim=1) diff --git a/src/transformers/models/gemma4/modeling_gemma4.py b/src/transformers/models/gemma4/modeling_gemma4.py index cdc4a6daeafc..d2acb10afae5 100644 --- a/src/transformers/models/gemma4/modeling_gemma4.py +++ b/src/transformers/models/gemma4/modeling_gemma4.py @@ -1442,7 +1442,7 @@ class Gemma4PreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["Gemma4TextDecoderLayer", "Gemma4VisionEncoderLayer", "Gemma4AudioLayer"] _skip_keys_device_placement = ["past_key_values", "shared_kv_states"] - _supports_flash_attn = True + _supports_flash_attn = False # released checkpoints use head_dim=512, which is not supported yet by FA kernels _supports_sdpa = True _supports_flex_attn = True @@ -1941,7 +1941,8 @@ def forward( (self.config.attention_context_left - 1, self.config.attention_context_right) ), ) - attention_mask = self._convert_4d_mask_to_blocked_5d(attention_mask) + if attention_mask is not None: + attention_mask = self._convert_4d_mask_to_blocked_5d(attention_mask) for encoder_layer in self.layers[: self.config.num_hidden_layers]: hidden_states = encoder_layer( diff --git a/src/transformers/models/gemma4/modular_gemma4.py b/src/transformers/models/gemma4/modular_gemma4.py index 739870f2a177..2bce7a9200c6 100644 --- a/src/transformers/models/gemma4/modular_gemma4.py +++ b/src/transformers/models/gemma4/modular_gemma4.py @@ -1159,6 +1159,7 @@ class Gemma4TextScaledWordEmbedding(Gemma3TextScaledWordEmbedding): class Gemma4PreTrainedModel(Gemma3nPreTrainedModel): _no_split_modules = ["Gemma4TextDecoderLayer", "Gemma4VisionEncoderLayer", "Gemma4AudioLayer"] input_modalities = ("image", "text", "video", "audio") + _supports_flash_attn = False # released checkpoints use head_dim=512, which is not supported yet by FA kernels _can_record_outputs = None # override @torch.no_grad() @@ -1511,7 +1512,8 @@ def forward( (self.config.attention_context_left - 1, self.config.attention_context_right) ), ) - attention_mask = self._convert_4d_mask_to_blocked_5d(attention_mask) + if attention_mask is not None: + attention_mask = self._convert_4d_mask_to_blocked_5d(attention_mask) for encoder_layer in self.layers[: self.config.num_hidden_layers]: hidden_states = encoder_layer( diff --git a/src/transformers/models/git/modeling_git.py b/src/transformers/models/git/modeling_git.py index 9be97d01c425..8cfe34aaab49 100644 --- a/src/transformers/models/git/modeling_git.py +++ b/src/transformers/models/git/modeling_git.py @@ -679,7 +679,7 @@ def __init__(self, config: GitVisionConfig): embed_dim = config.hidden_size self.embeddings = GitVisionEmbeddings(config) - self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + self.pre_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) self.encoder = GitVisionEncoder(config) self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) @@ -694,7 +694,7 @@ def forward( raise ValueError("You have to specify pixel_values") hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) - hidden_states = self.pre_layrnorm(hidden_states) + hidden_states = self.pre_layernorm(hidden_states) encoder_outputs = self.encoder( inputs_embeds=hidden_states, diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py index 712202580943..186cbcc238e1 100644 --- a/src/transformers/models/glm/modeling_glm.py +++ b/src/transformers/models/glm/modeling_glm.py @@ -121,7 +121,7 @@ def 
forward(self, x, position_ids): device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" with maybe_autocast(device_type=device_type, enabled=False): # Force float32 - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + freqs = (inv_freq_expanded.float() * position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling sin = emb.sin() * self.attention_scaling diff --git a/src/transformers/models/glm4/modeling_glm4.py b/src/transformers/models/glm4/modeling_glm4.py index e99930ae57f6..64c349c1a5bc 100644 --- a/src/transformers/models/glm4/modeling_glm4.py +++ b/src/transformers/models/glm4/modeling_glm4.py @@ -319,7 +319,7 @@ def forward(self, x, position_ids): device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" with maybe_autocast(device_type=device_type, enabled=False): # Force float32 - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + freqs = (inv_freq_expanded.float() * position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling sin = emb.sin() * self.attention_scaling diff --git a/src/transformers/models/glm46v/modeling_glm46v.py b/src/transformers/models/glm46v/modeling_glm46v.py index 81207e4c8608..9b7e01ec3d93 100644 --- a/src/transformers/models/glm46v/modeling_glm46v.py +++ b/src/transformers/models/glm46v/modeling_glm46v.py @@ -333,18 +333,18 @@ def get_placeholder_mask( special_video_mask = input_ids == self.config.image_token_id n_image_tokens = special_image_mask.sum() - special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + special_image_mask = special_image_mask.unsqueeze(-1).to(inputs_embeds.device) if image_features is not None: torch_compilable_check( - inputs_embeds[special_image_mask].numel() == image_features.numel(), + n_image_tokens * inputs_embeds.shape[-1] == image_features.numel(), f"Image features and image tokens do not match, tokens: {n_image_tokens}, features: {image_features.shape[0]}", ) n_video_tokens = special_video_mask.sum() - special_video_mask = special_video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + special_video_mask = special_video_mask.unsqueeze(-1).to(inputs_embeds.device) if video_features is not None: torch_compilable_check( - inputs_embeds[special_video_mask].numel() == video_features.numel(), + n_video_tokens * inputs_embeds.shape[-1] == video_features.numel(), f"Video features and video tokens do not match, tokens: {n_video_tokens}, features: {video_features.shape[0]}", ) return special_image_mask, special_video_mask diff --git a/src/transformers/models/glm46v/processing_glm46v.py b/src/transformers/models/glm46v/processing_glm46v.py index 9dcf7c4856e6..6c5b561a69b6 100644 --- a/src/transformers/models/glm46v/processing_glm46v.py +++ b/src/transformers/models/glm46v/processing_glm46v.py @@ -27,12 +27,14 @@ from ...tokenization_utils_base import PreTokenizedInput, TextInput from ...utils import auto_docstring, logging from ...video_utils import VideoInput +from .image_processing_glm46v import Glm46VImageProcessorKwargs logger = logging.get_logger(__name__) class Glm46VProcessorKwargs(ProcessingKwargs, total=False): + images_kwargs: Glm46VImageProcessorKwargs _defaults = { "text_kwargs": { "padding": False, diff --git a/src/transformers/models/glm4_moe/modeling_glm4_moe.py 
b/src/transformers/models/glm4_moe/modeling_glm4_moe.py index cc5a564ab86f..3a61d135f417 100644 --- a/src/transformers/models/glm4_moe/modeling_glm4_moe.py +++ b/src/transformers/models/glm4_moe/modeling_glm4_moe.py @@ -102,7 +102,7 @@ def forward(self, x, position_ids): device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" with maybe_autocast(device_type=device_type, enabled=False): # Force float32 - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + freqs = (inv_freq_expanded.float() * position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling sin = emb.sin() * self.attention_scaling diff --git a/src/transformers/models/glm4_moe_lite/modeling_glm4_moe_lite.py b/src/transformers/models/glm4_moe_lite/modeling_glm4_moe_lite.py index 0b8ccc865775..153bad424033 100644 --- a/src/transformers/models/glm4_moe_lite/modeling_glm4_moe_lite.py +++ b/src/transformers/models/glm4_moe_lite/modeling_glm4_moe_lite.py @@ -249,7 +249,7 @@ def __init__(self, config: Glm4MoeLiteConfig, layer_idx: int): self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.qk_head_dim, bias=False) else: self.q_a_proj = nn.Linear(config.hidden_size, config.q_lora_rank, bias=config.attention_bias) - self.q_a_layernorm = Glm4MoeLiteRMSNorm(config.q_lora_rank) + self.q_a_layernorm = Glm4MoeLiteRMSNorm(config.q_lora_rank, eps=config.rms_norm_eps) self.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.qk_head_dim, bias=False) self.kv_a_proj_with_mqa = nn.Linear( @@ -257,7 +257,7 @@ def __init__(self, config: Glm4MoeLiteConfig, layer_idx: int): self.kv_lora_rank + self.qk_rope_head_dim, bias=config.attention_bias, ) - self.kv_a_layernorm = Glm4MoeLiteRMSNorm(self.kv_lora_rank) + self.kv_a_layernorm = Glm4MoeLiteRMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps) self.kv_b_proj = nn.Linear( self.kv_lora_rank, self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), diff --git a/src/transformers/models/glm4v/modeling_glm4v.py b/src/transformers/models/glm4v/modeling_glm4v.py index 6121dc8d3fe8..84d2be8810ac 100644 --- a/src/transformers/models/glm4v/modeling_glm4v.py +++ b/src/transformers/models/glm4v/modeling_glm4v.py @@ -305,7 +305,7 @@ def forward( if is_flash_attention_requested(self.config): # Flash Attention: Use cu_seqlens for variable length attention - max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() attn_output, _ = attention_interface( self, query_states, @@ -1176,18 +1176,18 @@ def get_placeholder_mask( special_video_mask = input_ids == self.config.image_token_id n_image_tokens = special_image_mask.sum() - special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + special_image_mask = special_image_mask.unsqueeze(-1).to(inputs_embeds.device) if image_features is not None: torch_compilable_check( - inputs_embeds[special_image_mask].numel() == image_features.numel(), + n_image_tokens * inputs_embeds.shape[-1] == image_features.numel(), f"Image features and image tokens do not match, tokens: {n_image_tokens}, features: {image_features.shape[0]}", ) n_video_tokens = special_video_mask.sum() - special_video_mask = special_video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + special_video_mask = special_video_mask.unsqueeze(-1).to(inputs_embeds.device) if video_features is not None: torch_compilable_check( - 
inputs_embeds[special_video_mask].numel() == video_features.numel(), + n_video_tokens * inputs_embeds.shape[-1] == video_features.numel(), f"Video features and video tokens do not match, tokens: {n_video_tokens}, features: {video_features.shape[0]}", ) return special_image_mask, special_video_mask diff --git a/src/transformers/models/glm4v/modular_glm4v.py b/src/transformers/models/glm4v/modular_glm4v.py index d4a34a1952ad..d1878f19644f 100644 --- a/src/transformers/models/glm4v/modular_glm4v.py +++ b/src/transformers/models/glm4v/modular_glm4v.py @@ -861,18 +861,18 @@ def get_placeholder_mask( special_video_mask = input_ids == self.config.image_token_id n_image_tokens = special_image_mask.sum() - special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + special_image_mask = special_image_mask.unsqueeze(-1).to(inputs_embeds.device) if image_features is not None: torch_compilable_check( - inputs_embeds[special_image_mask].numel() == image_features.numel(), + n_image_tokens * inputs_embeds.shape[-1] == image_features.numel(), f"Image features and image tokens do not match, tokens: {n_image_tokens}, features: {image_features.shape[0]}", ) n_video_tokens = special_video_mask.sum() - special_video_mask = special_video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + special_video_mask = special_video_mask.unsqueeze(-1).to(inputs_embeds.device) if video_features is not None: torch_compilable_check( - inputs_embeds[special_video_mask].numel() == video_features.numel(), + n_video_tokens * inputs_embeds.shape[-1] == video_features.numel(), f"Video features and video tokens do not match, tokens: {n_video_tokens}, features: {video_features.shape[0]}", ) return special_image_mask, special_video_mask diff --git a/src/transformers/models/glm4v/processing_glm4v.py b/src/transformers/models/glm4v/processing_glm4v.py index 2d3e93aec9ed..cfd3b445d683 100644 --- a/src/transformers/models/glm4v/processing_glm4v.py +++ b/src/transformers/models/glm4v/processing_glm4v.py @@ -26,12 +26,14 @@ from ...tokenization_utils_base import PreTokenizedInput, TextInput from ...utils import auto_docstring, logging from ...video_utils import VideoInput +from .image_processing_glm4v import Glm4vImageProcessorKwargs logger = logging.get_logger(__name__) class Glm4vProcessorKwargs(ProcessingKwargs, total=False): + images_kwargs: Glm4vImageProcessorKwargs _defaults = { "text_kwargs": { "padding": False, diff --git a/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py b/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py index 3bf3dc157d3f..db8f7fdbb447 100644 --- a/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py +++ b/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py @@ -675,7 +675,7 @@ def forward( if is_flash_attention_requested(self.config): # Flash Attention: Use cu_seqlens for variable length attention - max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() attn_output, _ = attention_interface( self, query_states, @@ -1345,18 +1345,18 @@ def get_placeholder_mask( special_video_mask = input_ids == self.config.image_token_id n_image_tokens = special_image_mask.sum() - special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + special_image_mask = special_image_mask.unsqueeze(-1).to(inputs_embeds.device) if image_features is not None: torch_compilable_check( - inputs_embeds[special_image_mask].numel() == image_features.numel(), 
+ n_image_tokens * inputs_embeds.shape[-1] == image_features.numel(), f"Image features and image tokens do not match, tokens: {n_image_tokens}, features: {image_features.shape[0]}", ) n_video_tokens = special_video_mask.sum() - special_video_mask = special_video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + special_video_mask = special_video_mask.unsqueeze(-1).to(inputs_embeds.device) if video_features is not None: torch_compilable_check( - inputs_embeds[special_video_mask].numel() == video_features.numel(), + n_video_tokens * inputs_embeds.shape[-1] == video_features.numel(), f"Video features and video tokens do not match, tokens: {n_video_tokens}, features: {video_features.shape[0]}", ) return special_image_mask, special_video_mask @@ -1515,7 +1515,7 @@ def load_balancing_loss_func( compute_device = gate_logits[0].device concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0) - routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1) + routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dtype=torch.float, dim=-1) _, selected_experts = torch.topk(routing_weights, top_k, dim=-1) @@ -1523,7 +1523,9 @@ def load_balancing_loss_func( if attention_mask is None: # Compute the percentage of tokens routed to each experts - tokens_per_expert = torch.mean(expert_mask.float(), dim=0) + # Normalize by top_k so that sum(f_i) = 1, matching the distribution of P_i + # See: https://github.com/huggingface/transformers/issues/43688 + tokens_per_expert = torch.mean(expert_mask.float(), dim=0) / top_k # Compute the average probability of routing to these experts router_prob_per_expert = torch.mean(routing_weights, dim=0) @@ -1540,8 +1542,10 @@ def load_balancing_loss_func( ) # Compute the percentage of tokens routed to each experts - tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum( - expert_attention_mask, dim=0 + # Normalize by top_k so that sum(f_i) = 1, matching the distribution of P_i + # See: https://github.com/huggingface/transformers/issues/43688 + tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / ( + torch.sum(expert_attention_mask, dim=0) * top_k ) # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert diff --git a/src/transformers/models/glm_image/modeling_glm_image.py b/src/transformers/models/glm_image/modeling_glm_image.py index 012da8513453..967419aa21ad 100644 --- a/src/transformers/models/glm_image/modeling_glm_image.py +++ b/src/transformers/models/glm_image/modeling_glm_image.py @@ -129,7 +129,7 @@ def forward( if "flash" in self.config._attn_implementation: # Flash Attention: Use cu_seqlens for variable length attention - max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() attn_output, _ = attention_interface( self, query_states, diff --git a/src/transformers/models/glm_image/modular_glm_image.py b/src/transformers/models/glm_image/modular_glm_image.py index e72aede3da66..4a1dd37b1b90 100644 --- a/src/transformers/models/glm_image/modular_glm_image.py +++ b/src/transformers/models/glm_image/modular_glm_image.py @@ -226,7 +226,7 @@ def forward( if "flash" in self.config._attn_implementation: # Flash Attention: Use cu_seqlens for variable length attention - max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() attn_output, _ = 
attention_interface( self, query_states, diff --git a/src/transformers/models/glm_moe_dsa/configuration_glm_moe_dsa.py b/src/transformers/models/glm_moe_dsa/configuration_glm_moe_dsa.py index 8f11f42794b3..9d7175de8583 100644 --- a/src/transformers/models/glm_moe_dsa/configuration_glm_moe_dsa.py +++ b/src/transformers/models/glm_moe_dsa/configuration_glm_moe_dsa.py @@ -18,6 +18,7 @@ # See the License for the specific language governing permissions and # limitations under the License. + from huggingface_hub.dataclasses import strict from ...configuration_utils import PreTrainedConfig @@ -116,6 +117,7 @@ class GlmMoeDsaConfig(PreTrainedConfig): mlp_layer_types: list[str] | None = None attention_bias: bool = False attention_dropout: float | int = 0.0 + num_experts: int = 256 index_topk: int = 2048 index_head_dim: int = 128 index_n_heads: int = 32 diff --git a/src/transformers/models/glm_moe_dsa/modeling_glm_moe_dsa.py b/src/transformers/models/glm_moe_dsa/modeling_glm_moe_dsa.py index 736dcdce32c3..ccf67726d089 100644 --- a/src/transformers/models/glm_moe_dsa/modeling_glm_moe_dsa.py +++ b/src/transformers/models/glm_moe_dsa/modeling_glm_moe_dsa.py @@ -18,6 +18,7 @@ # See the License for the specific language governing permissions and # limitations under the License. + from collections.abc import Callable from typing import Optional @@ -30,6 +31,7 @@ from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...integrations import use_experts_implementation, use_kernel_forward_from_hub +from ...integrations.dsa_kernels import act_quant, fp8_index from ...masking_utils import create_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer @@ -64,11 +66,12 @@ def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) +def rotate_activation(x: torch.Tensor) -> torch.Tensor: + assert x.dtype == torch.bfloat16 + from fast_hadamard_transform import hadamard_transform + + hidden_size = x.size(-1) + return hadamard_transform(x, scale=hidden_size**-0.5) def apply_rotary_pos_emb( @@ -93,13 +96,14 @@ def apply_rotary_pos_emb( Returns: `torch.Tensor`: Tensor with rotary embeddings applied, same shape as input. """ - cos = cos.unsqueeze(unsqueeze_dim) - sin = sin.unsqueeze(unsqueeze_dim) - - # Split-half (NeoX/Llama style): (x[:d/2], x[d/2:]) - # This matches llama's apply_rotary_pos_emb logic. - x_rotated = (x * cos) + (rotate_half(x) * sin) - return x_rotated + # Interleaved (GPT-J style): (x[0], x[1]), (x[2], x[3]), ... + # RotaryEmbedding outputs cos/sin with repeated halves for NeoX compatibility, + # while interleaved rotation expects [.., D/2] frequencies. + cos = cos[..., : x.shape[-1] // 2].unsqueeze(unsqueeze_dim) + sin = sin[..., : x.shape[-1] // 2].unsqueeze(unsqueeze_dim) + x1 = x[..., ::2] + x2 = x[..., 1::2] + return torch.stack((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1).flatten(-2) class GlmMoeDsaIndexer(nn.Module): @@ -107,8 +111,7 @@ class GlmMoeDsaIndexer(nn.Module): DeepSeek Sparse Attention (DSA) indexer for selecting top-k tokens. The Indexer has its own lightweight projections (wq_b, wk) separate from the - main MLA attention. It uses non-interleaved (NeoX/Llama) RoPE, unlike the main attention - which uses interleaved RoPE. + main MLA attention. 
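Because the indexer now applies interleaved (GPT-J style) rotation while the shared rotary embedding still emits NeoX-layout `cos`/`sin` (the same half repeated twice), the sketch below restates the new pairing logic on dummy tensors. Names and sizes are illustrative, not the module's actual call sites:

```python
import torch

def apply_interleaved_rope(x, cos, sin, unsqueeze_dim=1):
    # cos/sin arrive with their two halves repeated (NeoX layout: [f, f]);
    # keep only the first D/2 frequencies for the interleaved pairing.
    cos = cos[..., : x.shape[-1] // 2].unsqueeze(unsqueeze_dim)
    sin = sin[..., : x.shape[-1] // 2].unsqueeze(unsqueeze_dim)
    x1, x2 = x[..., ::2], x[..., 1::2]          # even/odd feature pairs
    return torch.stack((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1).flatten(-2)

batch, seq, heads, dim = 1, 4, 2, 8             # toy shapes
x = torch.randn(batch, seq, heads, dim)
freqs = torch.randn(batch, seq, dim // 2)
cos = torch.cat((freqs, freqs), dim=-1).cos()   # NeoX-style repeated halves
sin = torch.cat((freqs, freqs), dim=-1).sin()

rotated = apply_interleaved_rope(x, cos, sin, unsqueeze_dim=2)
assert rotated.shape == x.shape
```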
**Cache strategy**: The Indexer manages its own key cache (`_cached_keys`) separately from the DynamicCache used by MLA attention, since DynamicCache is sized for exactly @@ -137,9 +140,12 @@ def __init__(self, config: "GlmMoeDsaConfig", layer_idx: int): # Keeping it as a plain Linear prevents FP8 conversion (see `_keep_in_fp32_modules`). self.weights_proj = nn.Linear(self.hidden_size, self.n_heads, bias=False) self.softmax_scale = self.head_dim**-0.5 + self.scale_fmt = "ue8m0" + self.quant_block_size = 128 # Indexer maintains its own key cache (not in DynamicCache, which is sized for attention layers only) self.register_buffer("_cached_keys", None, persistent=False) + self.register_buffer("_cached_keys_scales", None, persistent=False) @torch.no_grad() def forward( @@ -187,19 +193,29 @@ def forward( k_pe = apply_rotary_pos_emb(k_pe.unsqueeze(2), cos, sin, unsqueeze_dim=2).squeeze(2) # [B, S, rope_D] k = torch.cat([k_pe, k_nope], dim=-1) # [B, S, D] + q = rotate_activation(q) # [B, S, H, D] + k = rotate_activation(k) # [B, S, D] + q_fp8, q_scale = act_quant(q, self.quant_block_size, self.scale_fmt) + k_fp8, k_scale = act_quant(k, self.quant_block_size, self.scale_fmt) + # === Key cache (managed by the indexer, not DynamicCache) === # Reset cache on prefill (new prompt) to avoid stale keys / batch-size mismatch if seq_len > 1: self._cached_keys = None + self._cached_keys_scales = None if use_cache: if self._cached_keys is not None: - k_cached = torch.cat([self._cached_keys, k], dim=1) # [B, T, D] + k_cached = torch.cat([self._cached_keys, k_fp8], dim=1) # [B, T, D] + k_scale_cached = torch.cat([self._cached_keys_scales, k_scale.squeeze(-1)], dim=1) # [B, T] else: - k_cached = k + k_cached = k_fp8 + k_scale_cached = k_scale.squeeze(-1) self._cached_keys = k_cached + self._cached_keys_scales = k_scale_cached else: - k_cached = k + k_cached = k_fp8 + k_scale_cached = k_scale.squeeze(-1) # === Scoring === # Reference: weights = weights_proj(x.float()) * n_heads^(-0.5) @@ -213,19 +229,17 @@ def forward( # Don't force fp32 inputs here: the checkpoint stores `weights_proj.weight` in bf16. # Use native dtype for matmul, then upcast the result for scoring stability. 
weights = self.weights_proj(hidden_states).float() * (self.n_heads**-0.5) # [B, S, H] + weights = weights * q_scale.squeeze(-1) * self.softmax_scale # [B, S, H] - # q·k^T per head: [B, S, H, D] @ [B, T, D]^T → [B, S, H, T] - scores = torch.einsum("bshd,btd->bsht", q.float(), k_cached.float()) * self.softmax_scale - scores = F.relu(scores) - # Weight per head and sum across heads → [B, S, T] - index_scores = torch.einsum("bsht,bsh->bst", scores, weights) + index_score = fp8_index( + q_fp8.contiguous(), weights.contiguous(), k_cached.contiguous(), k_scale_cached.contiguous() + ) # [B, S, T] if attention_mask is not None: - index_scores = index_scores + attention_mask + index_score = index_score + attention_mask - total_len = index_scores.shape[-1] - topk = min(self.index_topk, total_len) - topk_indices = index_scores.topk(topk, dim=-1).indices # [B, S, topk] + actual_topk = min(self.index_topk, index_score.shape[-1]) + topk_indices = index_score.topk(actual_topk, dim=-1)[1] # [B, S, actual_topk] return topk_indices diff --git a/src/transformers/models/glm_moe_dsa/modular_glm_moe_dsa.py b/src/transformers/models/glm_moe_dsa/modular_glm_moe_dsa.py index 2e7e91200d8b..ab0026c52818 100644 --- a/src/transformers/models/glm_moe_dsa/modular_glm_moe_dsa.py +++ b/src/transformers/models/glm_moe_dsa/modular_glm_moe_dsa.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. + from collections.abc import Callable import torch @@ -21,11 +22,11 @@ from ...cache_utils import Cache, DynamicCache from ...configuration_utils import PreTrainedConfig +from ...integrations.dsa_kernels import act_quant, fp8_index from ...masking_utils import create_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutputWithPast from ...modeling_utils import ALL_ATTENTION_FUNCTIONS -from ...models.llama.modeling_llama import rotate_half from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, logging from ...utils.generic import is_flash_attention_requested @@ -45,6 +46,14 @@ logger = logging.get_logger(__name__) +def rotate_activation(x: torch.Tensor) -> torch.Tensor: + assert x.dtype == torch.bfloat16 + from fast_hadamard_transform import hadamard_transform + + hidden_size = x.size(-1) + return hadamard_transform(x, scale=hidden_size**-0.5) + + def apply_rotary_pos_emb( x: torch.Tensor, cos: torch.Tensor, @@ -67,13 +76,14 @@ def apply_rotary_pos_emb( Returns: `torch.Tensor`: Tensor with rotary embeddings applied, same shape as input. """ - cos = cos.unsqueeze(unsqueeze_dim) - sin = sin.unsqueeze(unsqueeze_dim) - - # Split-half (NeoX/Llama style): (x[:d/2], x[d/2:]) - # This matches llama's apply_rotary_pos_emb logic. - x_rotated = (x * cos) + (rotate_half(x) * sin) - return x_rotated + # Interleaved (GPT-J style): (x[0], x[1]), (x[2], x[3]), ... + # RotaryEmbedding outputs cos/sin with repeated halves for NeoX compatibility, + # while interleaved rotation expects [.., D/2] frequencies. 
+ cos = cos[..., : x.shape[-1] // 2].unsqueeze(unsqueeze_dim) + sin = sin[..., : x.shape[-1] // 2].unsqueeze(unsqueeze_dim) + x1 = x[..., ::2] + x2 = x[..., 1::2] + return torch.stack((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1).flatten(-2) @auto_docstring(checkpoint="zai-org/GLM-5") @@ -128,6 +138,7 @@ class GlmMoeDsaConfig(Glm4MoeLiteConfig): num_hidden_layers: int = 78 num_attention_heads: int = 64 num_key_value_heads: int = 64 + num_experts: int = 256 n_routed_experts: int = 256 routed_scaling_factor: float = 2.5 q_lora_rank: int = 2048 @@ -173,8 +184,7 @@ class GlmMoeDsaIndexer(nn.Module): DeepSeek Sparse Attention (DSA) indexer for selecting top-k tokens. The Indexer has its own lightweight projections (wq_b, wk) separate from the - main MLA attention. It uses non-interleaved (NeoX/Llama) RoPE, unlike the main attention - which uses interleaved RoPE. + main MLA attention. **Cache strategy**: The Indexer manages its own key cache (`_cached_keys`) separately from the DynamicCache used by MLA attention, since DynamicCache is sized for exactly @@ -203,9 +213,12 @@ def __init__(self, config: "GlmMoeDsaConfig", layer_idx: int): # Keeping it as a plain Linear prevents FP8 conversion (see `_keep_in_fp32_modules`). self.weights_proj = nn.Linear(self.hidden_size, self.n_heads, bias=False) self.softmax_scale = self.head_dim**-0.5 + self.scale_fmt = "ue8m0" + self.quant_block_size = 128 # Indexer maintains its own key cache (not in DynamicCache, which is sized for attention layers only) self.register_buffer("_cached_keys", None, persistent=False) + self.register_buffer("_cached_keys_scales", None, persistent=False) @torch.no_grad() def forward( @@ -253,19 +266,29 @@ def forward( k_pe = apply_rotary_pos_emb(k_pe.unsqueeze(2), cos, sin, unsqueeze_dim=2).squeeze(2) # [B, S, rope_D] k = torch.cat([k_pe, k_nope], dim=-1) # [B, S, D] + q = rotate_activation(q) # [B, S, H, D] + k = rotate_activation(k) # [B, S, D] + q_fp8, q_scale = act_quant(q, self.quant_block_size, self.scale_fmt) + k_fp8, k_scale = act_quant(k, self.quant_block_size, self.scale_fmt) + # === Key cache (managed by the indexer, not DynamicCache) === # Reset cache on prefill (new prompt) to avoid stale keys / batch-size mismatch if seq_len > 1: self._cached_keys = None + self._cached_keys_scales = None if use_cache: if self._cached_keys is not None: - k_cached = torch.cat([self._cached_keys, k], dim=1) # [B, T, D] + k_cached = torch.cat([self._cached_keys, k_fp8], dim=1) # [B, T, D] + k_scale_cached = torch.cat([self._cached_keys_scales, k_scale.squeeze(-1)], dim=1) # [B, T] else: - k_cached = k + k_cached = k_fp8 + k_scale_cached = k_scale.squeeze(-1) self._cached_keys = k_cached + self._cached_keys_scales = k_scale_cached else: - k_cached = k + k_cached = k_fp8 + k_scale_cached = k_scale.squeeze(-1) # === Scoring === # Reference: weights = weights_proj(x.float()) * n_heads^(-0.5) @@ -279,19 +302,17 @@ def forward( # Don't force fp32 inputs here: the checkpoint stores `weights_proj.weight` in bf16. # Use native dtype for matmul, then upcast the result for scoring stability. 
weights = self.weights_proj(hidden_states).float() * (self.n_heads**-0.5) # [B, S, H] + weights = weights * q_scale.squeeze(-1) * self.softmax_scale # [B, S, H] - # q·k^T per head: [B, S, H, D] @ [B, T, D]^T → [B, S, H, T] - scores = torch.einsum("bshd,btd->bsht", q.float(), k_cached.float()) * self.softmax_scale - scores = F.relu(scores) - # Weight per head and sum across heads → [B, S, T] - index_scores = torch.einsum("bsht,bsh->bst", scores, weights) + index_score = fp8_index( + q_fp8.contiguous(), weights.contiguous(), k_cached.contiguous(), k_scale_cached.contiguous() + ) # [B, S, T] if attention_mask is not None: - index_scores = index_scores + attention_mask + index_score = index_score + attention_mask - total_len = index_scores.shape[-1] - topk = min(self.index_topk, total_len) - topk_indices = index_scores.topk(topk, dim=-1).indices # [B, S, topk] + actual_topk = min(self.index_topk, index_score.shape[-1]) + topk_indices = index_score.topk(actual_topk, dim=-1)[1] # [B, S, actual_topk] return topk_indices diff --git a/src/transformers/models/glm_ocr/modeling_glm_ocr.py b/src/transformers/models/glm_ocr/modeling_glm_ocr.py index 828a99a705b5..bec157674bad 100644 --- a/src/transformers/models/glm_ocr/modeling_glm_ocr.py +++ b/src/transformers/models/glm_ocr/modeling_glm_ocr.py @@ -429,7 +429,7 @@ def forward( if is_flash_attention_requested(self.config): # Flash Attention: Use cu_seqlens for variable length attention - max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() attn_output, _ = attention_interface( self, query_states, @@ -1092,18 +1092,18 @@ def get_placeholder_mask( special_video_mask = input_ids == self.config.image_token_id n_image_tokens = special_image_mask.sum() - special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + special_image_mask = special_image_mask.unsqueeze(-1).to(inputs_embeds.device) if image_features is not None: torch_compilable_check( - inputs_embeds[special_image_mask].numel() == image_features.numel(), + n_image_tokens * inputs_embeds.shape[-1] == image_features.numel(), f"Image features and image tokens do not match, tokens: {n_image_tokens}, features: {image_features.shape[0]}", ) n_video_tokens = special_video_mask.sum() - special_video_mask = special_video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + special_video_mask = special_video_mask.unsqueeze(-1).to(inputs_embeds.device) if video_features is not None: torch_compilable_check( - inputs_embeds[special_video_mask].numel() == video_features.numel(), + n_video_tokens * inputs_embeds.shape[-1] == video_features.numel(), f"Video features and video tokens do not match, tokens: {n_video_tokens}, features: {video_features.shape[0]}", ) return special_image_mask, special_video_mask diff --git a/src/transformers/models/glm_ocr/modular_glm_ocr.py b/src/transformers/models/glm_ocr/modular_glm_ocr.py index 2f71dded711d..cbd89201179a 100644 --- a/src/transformers/models/glm_ocr/modular_glm_ocr.py +++ b/src/transformers/models/glm_ocr/modular_glm_ocr.py @@ -182,7 +182,7 @@ def forward( if is_flash_attention_requested(self.config): # Flash Attention: Use cu_seqlens for variable length attention - max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() attn_output, _ = attention_interface( self, query_states, diff --git a/src/transformers/models/got_ocr2/modeling_got_ocr2.py 
b/src/transformers/models/got_ocr2/modeling_got_ocr2.py index ab072a8b1f5f..2eaad185933c 100644 --- a/src/transformers/models/got_ocr2/modeling_got_ocr2.py +++ b/src/transformers/models/got_ocr2/modeling_got_ocr2.py @@ -579,9 +579,9 @@ def get_placeholder_mask( n_image_tokens = special_image_mask.sum() n_image_features = image_features.shape[0] * image_features.shape[1] - special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + special_image_mask = special_image_mask.unsqueeze(-1).to(inputs_embeds.device) torch_compilable_check( - inputs_embeds[special_image_mask].numel() == image_features.numel(), + n_image_tokens * inputs_embeds.shape[-1] == image_features.numel(), f"Image features and image tokens do not match, tokens: {n_image_tokens}, features: {n_image_features}", ) return special_image_mask diff --git a/src/transformers/models/gpt2/configuration_gpt2.py b/src/transformers/models/gpt2/configuration_gpt2.py index 709e6ca86a48..13323ab3d83c 100644 --- a/src/transformers/models/gpt2/configuration_gpt2.py +++ b/src/transformers/models/gpt2/configuration_gpt2.py @@ -81,7 +81,7 @@ class GPT2Config(PreTrainedConfig): n_layer: int = 12 n_head: int = 12 n_inner: int | None = None - activation_function: str = "gelu_new" + activation_function: str = "gelu" resid_pdrop: float | int = 0.1 embd_pdrop: float | int = 0.1 attn_pdrop: float | int = 0.1 diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index 10e4b5922add..d227d71120a8 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -108,7 +108,7 @@ def forward(self, x, position_ids): device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" with maybe_autocast(device_type=device_type, enabled=False): # Force float32 - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + freqs = (inv_freq_expanded.float() * position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling sin = emb.sin() * self.attention_scaling diff --git a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py index e334ce023d67..d92020b0152b 100755 --- a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +++ b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py @@ -111,7 +111,7 @@ def forward(self, x, position_ids): device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" with maybe_autocast(device_type=device_type, enabled=False): # Force float32 - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + freqs = (inv_freq_expanded.float() * position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling sin = emb.sin() * self.attention_scaling diff --git a/src/transformers/models/gpt_oss/configuration_gpt_oss.py b/src/transformers/models/gpt_oss/configuration_gpt_oss.py index 47c029a5bca9..66c993a94fdf 100644 --- a/src/transformers/models/gpt_oss/configuration_gpt_oss.py +++ b/src/transformers/models/gpt_oss/configuration_gpt_oss.py @@ -23,9 +23,6 @@ @strict class GptOssConfig(PreTrainedConfig): model_type = "gpt_oss" - attribute_map = { - "num_experts": "num_local_experts", - } 
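On the repeated `@` → `*` change in the rotary forward passes: with `inv_freq_expanded` shaped `[batch, dim/2, 1]` and `position_ids_expanded` shaped `[batch, 1, seq_len]`, the batched matmul degenerates to an outer product, so broadcasted elementwise multiplication produces the same `[batch, dim/2, seq_len]` tensor. A quick numerical check under those assumed shapes:

```python
import torch

batch, half_dim, seq_len = 2, 4, 6               # toy sizes
inv_freq_expanded = torch.randn(batch, half_dim, 1, dtype=torch.float32)
position_ids_expanded = torch.randn(batch, 1, seq_len, dtype=torch.float32)

freqs_matmul = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)
freqs_broadcast = (inv_freq_expanded * position_ids_expanded).transpose(1, 2)

assert freqs_matmul.shape == freqs_broadcast.shape == (batch, seq_len, half_dim)
torch.testing.assert_close(freqs_matmul, freqs_broadcast)
```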
default_theta = 150000.0 base_model_pp_plan = { "embed_tokens": (["input_ids"], ["inputs_embeds"]), diff --git a/src/transformers/models/gpt_oss/modeling_gpt_oss.py b/src/transformers/models/gpt_oss/modeling_gpt_oss.py index 55381a7e3c21..d0191d373238 100644 --- a/src/transformers/models/gpt_oss/modeling_gpt_oss.py +++ b/src/transformers/models/gpt_oss/modeling_gpt_oss.py @@ -475,6 +475,7 @@ def forward( "inputs_embeds": inputs_embeds, "attention_mask": attention_mask, "past_key_values": past_key_values, + "position_ids": position_ids, } causal_mask_mapping = { "full_attention": create_causal_mask(**mask_kwargs), @@ -537,7 +538,7 @@ def load_balancing_loss_func( compute_device = gate_logits[0].device concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0) - routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1) + routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dtype=torch.float, dim=-1) _, selected_experts = torch.topk(routing_weights, top_k, dim=-1) @@ -545,13 +546,17 @@ def load_balancing_loss_func( if attention_mask is None: # Compute the percentage of tokens routed to each experts - tokens_per_expert = torch.mean(expert_mask.float(), dim=0) + # Normalize by top_k so that sum(f_i) = 1, matching the distribution of P_i + # See: https://github.com/huggingface/transformers/issues/43688 + tokens_per_expert = torch.mean(expert_mask.float(), dim=0) / top_k # Compute the average probability of routing to these experts router_prob_per_expert = torch.mean(routing_weights, dim=0) else: - batch_size, sequence_length = attention_mask.shape - num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length) + num_hidden_layers = len(gate_logits) + batch_size = attention_mask.shape[0] + sequence_length = gate_logits[0].shape[0] // batch_size + attention_mask = attention_mask[:, -sequence_length:] # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask expert_attention_mask = ( @@ -562,8 +567,10 @@ def load_balancing_loss_func( ) # Compute the percentage of tokens routed to each experts - tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum( - expert_attention_mask, dim=0 + # Normalize by top_k so that sum(f_i) = 1, matching the distribution of P_i + # See: https://github.com/huggingface/transformers/issues/43688 + tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / ( + torch.sum(expert_attention_mask, dim=0) * top_k ) # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert diff --git a/src/transformers/models/granite/modeling_granite.py b/src/transformers/models/granite/modeling_granite.py index 934345fe6723..25927348cd9a 100644 --- a/src/transformers/models/granite/modeling_granite.py +++ b/src/transformers/models/granite/modeling_granite.py @@ -356,7 +356,7 @@ def forward(self, x, position_ids): device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" with maybe_autocast(device_type=device_type, enabled=False): # Force float32 - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + freqs = (inv_freq_expanded.float() * position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling sin = emb.sin() * self.attention_scaling diff --git a/src/transformers/models/granitemoe/modeling_granitemoe.py 
b/src/transformers/models/granitemoe/modeling_granitemoe.py index 5fb53d6afe49..40277ef1a64d 100644 --- a/src/transformers/models/granitemoe/modeling_granitemoe.py +++ b/src/transformers/models/granitemoe/modeling_granitemoe.py @@ -121,7 +121,7 @@ def forward(self, x, position_ids): device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" with maybe_autocast(device_type=device_type, enabled=False): # Force float32 - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + freqs = (inv_freq_expanded.float() * position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling sin = emb.sin() * self.attention_scaling @@ -577,7 +577,7 @@ def load_balancing_loss_func( compute_device = gate_logits[0].device concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0) - routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1) + routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dtype=torch.float, dim=-1) _, selected_experts = torch.topk(routing_weights, top_k, dim=-1) @@ -585,7 +585,9 @@ def load_balancing_loss_func( if attention_mask is None: # Compute the percentage of tokens routed to each experts - tokens_per_expert = torch.mean(expert_mask.float(), dim=0) + # Normalize by top_k so that sum(f_i) = 1, matching the distribution of P_i + # See: https://github.com/huggingface/transformers/issues/43688 + tokens_per_expert = torch.mean(expert_mask.float(), dim=0) / top_k # Compute the average probability of routing to these experts router_prob_per_expert = torch.mean(routing_weights, dim=0) @@ -602,8 +604,10 @@ def load_balancing_loss_func( ) # Compute the percentage of tokens routed to each experts - tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum( - expert_attention_mask, dim=0 + # Normalize by top_k so that sum(f_i) = 1, matching the distribution of P_i + # See: https://github.com/huggingface/transformers/issues/43688 + tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / ( + torch.sum(expert_attention_mask, dim=0) * top_k ) # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert diff --git a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py index 2e0926f3e5d4..fdabd685fe9a 100644 --- a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +++ b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py @@ -833,7 +833,7 @@ def forward(self, x, position_ids): device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" with maybe_autocast(device_type=device_type, enabled=False): # Force float32 - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + freqs = (inv_freq_expanded.float() * position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling sin = emb.sin() * self.attention_scaling @@ -1258,7 +1258,7 @@ def load_balancing_loss_func( compute_device = gate_logits[0].device concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0) - routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1) + routing_weights = 
diff --git a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py
index 2e0926f3e5d4..fdabd685fe9a 100644
--- a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py
+++ b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py
@@ -833,7 +833,7 @@ def forward(self, x, position_ids):
         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
         with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
-            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+            freqs = (inv_freq_expanded.float() * position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
             sin = emb.sin() * self.attention_scaling
@@ -1258,7 +1258,7 @@ def load_balancing_loss_func(
     compute_device = gate_logits[0].device
     concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)
-    routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
+    routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dtype=torch.float, dim=-1)
     _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
@@ -1266,7 +1266,9 @@ def load_balancing_loss_func(
     if attention_mask is None:
         # Compute the percentage of tokens routed to each experts
-        tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
+        # Normalize by top_k so that sum(f_i) = 1, matching the distribution of P_i
+        # See: https://github.com/huggingface/transformers/issues/43688
+        tokens_per_expert = torch.mean(expert_mask.float(), dim=0) / top_k
         # Compute the average probability of routing to these experts
         router_prob_per_expert = torch.mean(routing_weights, dim=0)
@@ -1283,8 +1285,10 @@ def load_balancing_loss_func(
         )
         # Compute the percentage of tokens routed to each experts
-        tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
-            expert_attention_mask, dim=0
+        # Normalize by top_k so that sum(f_i) = 1, matching the distribution of P_i
+        # See: https://github.com/huggingface/transformers/issues/43688
+        tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / (
+            torch.sum(expert_attention_mask, dim=0) * top_k
         )
         # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
diff --git a/src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py b/src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py
index 741c58e005f8..36a87cc00ed4 100644
--- a/src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py
+++ b/src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py
@@ -23,7 +23,7 @@
 from ...modeling_outputs import BaseModelOutputWithPast, MoeModelOutputWithPast
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
 from ...processing_utils import Unpack
-from ...utils import TransformersKwargs, auto_docstring, logging
+from ...utils import TransformersKwargs, auto_docstring, is_torchdynamo_compiling, logging
 from ...utils.generic import merge_with_config_defaults
 from ...utils.output_capturing import capture_outputs
 from ..bamba.configuration_bamba import BambaConfig
@@ -275,8 +275,9 @@ def _update_mamba_mask(self, attention_mask, past_key_values):
             2. Attending to all inputs
         """
         mamba_mask = attention_mask
-        if (past_key_values is not None and past_key_values.has_previous_state()) or (
-            attention_mask is not None and torch.all(attention_mask == 1)
+        if not is_torchdynamo_compiling() and (
+            (past_key_values is not None and past_key_values.has_previous_state())
+            or (attention_mask is not None and torch.all(attention_mask == 1))
         ):
             mamba_mask = None
         return mamba_mask
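The `_update_mamba_mask` change above guards the value-dependent fast path (dropping the mask when `torch.all(attention_mask == 1)`) behind `not is_torchdynamo_compiling()`, so the decision is never baked into a compiled graph. A minimal sketch of the pattern, using `torch._dynamo.is_compiling` as a stand-in for the transformers helper:

```python
import torch

try:
    # Stand-in for transformers' own is_torchdynamo_compiling utility
    from torch._dynamo import is_compiling as is_torchdynamo_compiling
except ImportError:  # very old torch: assume eager execution
    def is_torchdynamo_compiling():
        return False


def update_mamba_mask(attention_mask, has_previous_state=False):
    # Only drop the mask based on its *values* when running eagerly; under
    # torch.compile this branch would depend on tensor data, so keep the mask.
    if not is_torchdynamo_compiling() and (
        has_previous_state or (attention_mask is not None and torch.all(attention_mask == 1))
    ):
        return None
    return attention_mask
```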
diff --git a/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py b/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py
index 71f8c6eaff7d..2f9533f92e45 100644
--- a/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py
+++ b/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py
@@ -524,7 +524,7 @@ def forward(self, x, position_ids):
         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
         with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
-            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+            freqs = (inv_freq_expanded.float() * position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
             sin = emb.sin() * self.attention_scaling
@@ -646,7 +646,7 @@ def load_balancing_loss_func(
     compute_device = gate_logits[0].device
     concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)
-    routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
+    routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dtype=torch.float, dim=-1)
     _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
@@ -654,7 +654,9 @@ def load_balancing_loss_func(
     if attention_mask is None:
         # Compute the percentage of tokens routed to each experts
-        tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
+        # Normalize by top_k so that sum(f_i) = 1, matching the distribution of P_i
+        # See: https://github.com/huggingface/transformers/issues/43688
+        tokens_per_expert = torch.mean(expert_mask.float(), dim=0) / top_k
         # Compute the average probability of routing to these experts
         router_prob_per_expert = torch.mean(routing_weights, dim=0)
@@ -671,8 +673,10 @@ def load_balancing_loss_func(
         )
         # Compute the percentage of tokens routed to each experts
-        tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
-            expert_attention_mask, dim=0
+        # Normalize by top_k so that sum(f_i) = 1, matching the distribution of P_i
+        # See: https://github.com/huggingface/transformers/issues/43688
+        tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / (
+            torch.sum(expert_attention_mask, dim=0) * top_k
         )
         # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
diff --git a/src/transformers/models/grounding_dino/processing_grounding_dino.py b/src/transformers/models/grounding_dino/processing_grounding_dino.py
index 7835885fd42d..4d6f0201cc7d 100644
--- a/src/transformers/models/grounding_dino/processing_grounding_dino.py
+++ b/src/transformers/models/grounding_dino/processing_grounding_dino.py
@@ -30,6 +30,7 @@
 if TYPE_CHECKING:
     from .modeling_grounding_dino import GroundingDinoObjectDetectionOutput
+from .image_processing_grounding_dino import GroundingDinoImageProcessorKwargs
 AnnotationType = dict[str, int | str | list[dict]]
@@ -98,6 +99,7 @@ def get(self, key, *args, **kwargs):
 class GroundingDinoProcessorKwargs(ProcessingKwargs, total=False):
+    images_kwargs: GroundingDinoImageProcessorKwargs
     _defaults = {
         "text_kwargs": {
             "add_special_tokens": True,
diff --git a/src/transformers/models/helium/modeling_helium.py b/src/transformers/models/helium/modeling_helium.py
index 8283fcb19e28..653867d7c5bd 100644
--- a/src/transformers/models/helium/modeling_helium.py
+++ b/src/transformers/models/helium/modeling_helium.py
@@ -119,7 +119,7 @@ def forward(self, x, position_ids):
         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
         with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
-            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+            freqs = (inv_freq_expanded.float() * position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
             sin = emb.sin() * self.attention_scaling
diff --git a/src/transformers/models/higgs_audio_v2/modeling_higgs_audio_v2.py b/src/transformers/models/higgs_audio_v2/modeling_higgs_audio_v2.py
index a0f106167721..eec49fca3f07 100644
--- a/src/transformers/models/higgs_audio_v2/modeling_higgs_audio_v2.py
+++ b/src/transformers/models/higgs_audio_v2/modeling_higgs_audio_v2.py
@@ -524,7 +524,7 @@ def forward(
                 else audio_embeds
             )
             inputs_embeds = inputs_embeds.masked_scatter(
-                audio_token_mask[..., None].expand_as(inputs_embeds), audio_embeds.to(inputs_embeds.device)
+                audio_token_mask[..., None], audio_embeds.to(inputs_embeds.device)
             )
         elif audio_input_ids is not None:
             inputs_embeds = audio_embeds
diff --git a/src/transformers/models/higgs_audio_v2/modular_higgs_audio_v2.py b/src/transformers/models/higgs_audio_v2/modular_higgs_audio_v2.py
index 03d34b0e3444..8c48760a5a17 100644
--- a/src/transformers/models/higgs_audio_v2/modular_higgs_audio_v2.py
+++ b/src/transformers/models/higgs_audio_v2/modular_higgs_audio_v2.py
@@ -326,7 +326,7 @@ def forward(
                 else audio_embeds
             )
             inputs_embeds = inputs_embeds.masked_scatter(
-                audio_token_mask[..., None].expand_as(inputs_embeds), audio_embeds.to(inputs_embeds.device)
+                audio_token_mask[..., None], audio_embeds.to(inputs_embeds.device)
             )
         elif audio_input_ids is not None:
             inputs_embeds = audio_embeds
diff --git a/src/transformers/models/hubert/modeling_hubert.py b/src/transformers/models/hubert/modeling_hubert.py
index e3934ba80f68..f0c561e93920 100755
--- a/src/transformers/models/hubert/modeling_hubert.py
+++ b/src/transformers/models/hubert/modeling_hubert.py
@@ -36,6 +36,7 @@
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel, get_torch_context_manager_or_global_device
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, logging
+from ...utils.generic import _conv_out_length
 from .configuration_hubert import HubertConfig
@@ -676,11 +677,6 @@ def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor | int
         Computes the output length of the convolutional layers
         """
-        def _conv_out_length(input_length, kernel_size, stride):
-            # 1D convolutional layer output length formula taken
-            # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
-            return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1
-
         for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
             input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
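The `_conv_out_length` helper that these Hubert diffs now import from `...utils.generic` is the standard Conv1d output-length formula, `floor((L - kernel_size) / stride) + 1`, applied once per layer of the feature extractor. A small standalone sketch of how the loop collapses a raw audio length into feature frames (the kernel/stride values are illustrative, not Hubert's actual config):

```python
import torch


def conv_out_length(input_length, kernel_size, stride):
    # 1D convolution output length: floor((L - kernel_size) / stride) + 1
    return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1


conv_kernel = (10, 3, 3)   # illustrative values only
conv_stride = (5, 2, 2)

lengths = torch.tensor([16000])  # e.g. one second of 16 kHz audio
for kernel_size, stride in zip(conv_kernel, conv_stride):
    lengths = conv_out_length(lengths, kernel_size, stride)
print(lengths)  # number of frames produced by this toy conv stack
```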
diff --git a/src/transformers/models/hubert/modular_hubert.py b/src/transformers/models/hubert/modular_hubert.py
index 59a72d3269cb..dac73d85ccb2 100644
--- a/src/transformers/models/hubert/modular_hubert.py
+++ b/src/transformers/models/hubert/modular_hubert.py
@@ -22,6 +22,7 @@
 from ...modeling_outputs import BaseModelOutput
 from ...modeling_utils import PreTrainedModel
 from ...utils import auto_docstring
+from ...utils.generic import _conv_out_length
 from ..wav2vec2.modeling_wav2vec2 import (
     Wav2Vec2Encoder,
     Wav2Vec2EncoderStableLayerNorm,
@@ -174,11 +175,6 @@ def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor | int
         Computes the output length of the convolutional layers
         """
-        def _conv_out_length(input_length, kernel_size, stride):
-            # 1D convolutional layer output length formula taken
-            # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
-            return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1
-
         for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
             input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
diff --git a/src/transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py b/src/transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py
index d1652d78cbbc..1812977963cf 100644
--- a/src/transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py
+++ b/src/transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py
@@ -376,7 +376,7 @@ def forward(self, x, position_ids):
         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
         with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
-            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+            freqs = (inv_freq_expanded.float() * position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
             sin = emb.sin() * self.attention_scaling
diff --git a/src/transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py b/src/transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py
index 19779da0528c..970daefaa2f3 100644
--- a/src/transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py
+++ b/src/transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py
@@ -465,7 +465,7 @@ def forward(self, x, position_ids):
         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
         with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
-            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+            freqs = (inv_freq_expanded.float() * position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
             sin = emb.sin() * self.attention_scaling
diff --git a/src/transformers/models/idefics/modeling_idefics.py b/src/transformers/models/idefics/modeling_idefics.py
index 8f4578e1d0f2..687bc71f30a9 100644
--- a/src/transformers/models/idefics/modeling_idefics.py
+++ b/src/transformers/models/idefics/modeling_idefics.py
@@ -156,7 +156,9 @@ def expand_inputs_for_generation(
     return input_ids, model_kwargs
-def freeze_model(model, module_exceptions=()):
+def freeze_model(model, module_exceptions=None):
+    if module_exceptions is None:
+        module_exceptions = []
     mapping = {
         "LayerNorm": nn.LayerNorm,
         "Linear": nn.Linear,
@@ -927,11 +929,15 @@ def freeze_relevant_params(self, config=None):
         if config.freeze_vision_layers:
             freeze_model(self.vision_model, module_exceptions=config.freeze_vision_module_exceptions)
-    def freeze_text_layers(self, module_exceptions=()):
+    def freeze_text_layers(self, module_exceptions=None):
+        if module_exceptions is None:
+            module_exceptions = []
         for module in [self.layers, self.norm]:
             freeze_model(module, module_exceptions=module_exceptions)
-    def freeze_vision_layers(self, module_exceptions=()):
+    def freeze_vision_layers(self, module_exceptions=None):
+        if module_exceptions is None:
+            module_exceptions = []
         freeze_model(self.vision_model, module_exceptions=module_exceptions)
     @merge_with_config_defaults
diff --git a/src/transformers/models/idefics/processing_idefics.py b/src/transformers/models/idefics/processing_idefics.py
index b774d10b35c7..8d099c3bbcdd 100644
--- a/src/transformers/models/idefics/processing_idefics.py
+++ b/src/transformers/models/idefics/processing_idefics.py
@@ -31,6 +31,7 @@
 if is_torch_available():
     import torch
+from .image_processing_idefics import IdeficsImageProcessorKwargs
 IMAGE_TOKEN = "<image>"
@@ -52,6 +53,7 @@ class IdeficsTextKwargs(TextKwargs, total=False):
 class IdeficsProcessorKwargs(ProcessingKwargs, total=False):
+    images_kwargs: IdeficsImageProcessorKwargs
     text_kwargs: IdeficsTextKwargs
     _defaults = {
         "text_kwargs": {
diff --git a/src/transformers/models/idefics2/modeling_idefics2.py b/src/transformers/models/idefics2/modeling_idefics2.py
index 5d81439e27b6..770b7d6c9fd5 100644
--- a/src/transformers/models/idefics2/modeling_idefics2.py
+++ b/src/transformers/models/idefics2/modeling_idefics2.py
@@ -815,7 +815,7 @@ def inputs_merger(
         else:
             special_image_mask = input_ids == self.config.image_token_id
-        special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+        special_image_mask = special_image_mask.unsqueeze(-1).to(inputs_embeds.device)
         image_hidden_states = image_hidden_states.to(inputs_embeds.device, inputs_embeds.dtype)
         inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_hidden_states)
         return inputs_embeds
diff --git a/src/transformers/models/idefics2/processing_idefics2.py b/src/transformers/models/idefics2/processing_idefics2.py
index dd87290838ff..95a1c41fea03 100644
--- a/src/transformers/models/idefics2/processing_idefics2.py
+++ b/src/transformers/models/idefics2/processing_idefics2.py
@@ -32,6 +32,7 @@
 if TYPE_CHECKING:
     from ...tokenization_utils_base import PreTokenizedInput
+from .image_processing_idefics2 import Idefics2ImageProcessorKwargs
 logger = logging.get_logger(__name__)
@@ -46,6 +47,7 @@ def is_image_or_image_url(elem):
 class Idefics2ProcessorKwargs(ProcessingKwargs, total=False):
+    images_kwargs: Idefics2ImageProcessorKwargs
     _defaults = {
         "text_kwargs": {
             "add_special_tokens": True,
diff --git a/src/transformers/models/idefics3/modeling_idefics3.py b/src/transformers/models/idefics3/modeling_idefics3.py
index 2c58aba032cd..86a8ac50ce04 100644
--- a/src/transformers/models/idefics3/modeling_idefics3.py
+++ b/src/transformers/models/idefics3/modeling_idefics3.py
@@ -559,7 +559,7 @@ def inputs_merger(
         else:
             special_image_mask = input_ids == self.config.image_token_id
-        special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+        special_image_mask = special_image_mask.unsqueeze(-1).to(inputs_embeds.device)
         image_hidden_states = image_hidden_states.to(inputs_embeds.device, inputs_embeds.dtype)
         inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_hidden_states)
         return inputs_embeds
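The repeated removal of `.expand_as(inputs_embeds)` in these diffs works because `Tensor.masked_scatter` only requires the mask to be broadcastable to the target, so the `(batch, seq, 1)` mask no longer has to be materialized at full hidden size. A minimal sketch with toy shapes (illustrative values, not the model code):

```python
import torch

hidden_size = 4
inputs_embeds = torch.zeros(1, 5, hidden_size)
image_features = torch.ones(2, hidden_size)  # embeddings for two placeholder positions

special_image_mask = torch.tensor([[False, True, False, True, False]])

# The (1, 5, 1) boolean mask broadcasts against (1, 5, hidden_size); no expand_as needed
out = inputs_embeds.masked_scatter(special_image_mask.unsqueeze(-1), image_features)
print(out[0, 1], out[0, 3])  # both placeholder rows are filled from image_features
```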
diff --git a/src/transformers/models/idefics3/processing_idefics3.py b/src/transformers/models/idefics3/processing_idefics3.py
index f43ac76bf3ff..24d05f958c35 100644
--- a/src/transformers/models/idefics3/processing_idefics3.py
+++ b/src/transformers/models/idefics3/processing_idefics3.py
@@ -30,6 +30,8 @@
 if TYPE_CHECKING:
     from ...tokenization_utils_base import PreTokenizedInput
+from .image_processing_idefics3 import Idefics3ImageProcessorKwargs
+
 logger = logging.get_logger(__name__)
@@ -87,6 +89,7 @@ def get_image_prompt_string(
 class Idefics3ProcessorKwargs(ProcessingKwargs, total=False):
+    images_kwargs: Idefics3ImageProcessorKwargs
     _defaults = {
         "text_kwargs": {
             "add_special_tokens": True,
diff --git a/src/transformers/models/instructblip/modeling_instructblip.py b/src/transformers/models/instructblip/modeling_instructblip.py
index 29f32f17d6c4..1faaa9f536ba 100644
--- a/src/transformers/models/instructblip/modeling_instructblip.py
+++ b/src/transformers/models/instructblip/modeling_instructblip.py
@@ -998,7 +998,7 @@ def get_placeholder_mask(self, input_ids: torch.LongTensor, inputs_embeds: torch
         else:
             special_image_mask = input_ids == self.config.image_token_id
-        special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+        special_image_mask = special_image_mask.unsqueeze(-1).to(inputs_embeds.device)
         return special_image_mask
     @can_return_tuple
@@ -1257,7 +1257,7 @@ def get_placeholder_mask(self, input_ids: torch.LongTensor, inputs_embeds: torch
         else:
             special_image_mask = input_ids == self.config.image_token_id
-        special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+        special_image_mask = special_image_mask.unsqueeze(-1).to(inputs_embeds.device)
         return special_image_mask
     @can_return_tuple
diff --git a/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py b/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py
index 06d3d28b2c88..955794db2b0b 100644
--- a/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py
+++ b/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py
@@ -982,7 +982,7 @@ def get_placeholder_mask(self, input_ids: torch.LongTensor, inputs_embeds: torch
         else:
             special_image_mask = input_ids == self.config.image_token_id
-        special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+        special_image_mask = special_image_mask.unsqueeze(-1).to(inputs_embeds.device)
         return special_image_mask
     @can_return_tuple
@@ -1074,7 +1074,7 @@ def forward(
         )
         special_image_mask = special_image_mask.all(-1)
-        special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+        special_image_mask = special_image_mask.unsqueeze(-1).to(inputs_embeds.device)
         language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
         inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
@@ -1205,7 +1205,7 @@ def get_placeholder_mask(self, input_ids: torch.LongTensor, inputs_embeds: torch
         else:
             special_image_mask = input_ids == self.config.video_token_id
-        special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+        special_image_mask = special_image_mask.unsqueeze(-1).to(inputs_embeds.device)
         return special_image_mask
     @can_return_tuple
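The `images_kwargs: ...ImageProcessorKwargs` annotations added to these `ProcessorKwargs` classes follow the processors' TypedDict-style pattern, so type checkers and kwargs-merging helpers know which image-specific keys a given processor accepts. A self-contained sketch of the idea using plain `typing.TypedDict` stand-ins rather than the actual transformers classes (all names below are illustrative):

```python
from typing import TypedDict


class MyImageProcessorKwargs(TypedDict, total=False):
    do_resize: bool
    size: dict
    crop_to_patches: bool  # a model-specific image option


class MyProcessorKwargs(TypedDict, total=False):
    text_kwargs: dict
    images_kwargs: MyImageProcessorKwargs  # typed nested kwargs, as in the diffs above


kwargs: MyProcessorKwargs = {
    "text_kwargs": {"add_special_tokens": True},
    "images_kwargs": {"do_resize": True, "crop_to_patches": True},
}
```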
diff --git a/src/transformers/models/instructblipvideo/modular_instructblipvideo.py b/src/transformers/models/instructblipvideo/modular_instructblipvideo.py
index d84f3fd13398..862a812fdeb5 100644
--- a/src/transformers/models/instructblipvideo/modular_instructblipvideo.py
+++ b/src/transformers/models/instructblipvideo/modular_instructblipvideo.py
@@ -209,7 +209,7 @@ def forward(
         )
         special_image_mask = special_image_mask.all(-1)
-        special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+        special_image_mask = special_image_mask.unsqueeze(-1).to(inputs_embeds.device)
         language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
         inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
@@ -324,7 +324,7 @@ def get_placeholder_mask(self, input_ids: torch.LongTensor, inputs_embeds: torch
         else:
             special_image_mask = input_ids == self.config.video_token_id
-        special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+        special_image_mask = special_image_mask.unsqueeze(-1).to(inputs_embeds.device)
         return special_image_mask
     @can_return_tuple
diff --git a/src/transformers/models/internvl/modeling_internvl.py b/src/transformers/models/internvl/modeling_internvl.py
index 284d97406e65..7c61c4eee2b8 100644
--- a/src/transformers/models/internvl/modeling_internvl.py
+++ b/src/transformers/models/internvl/modeling_internvl.py
@@ -609,9 +609,9 @@ def get_placeholder_mask(
         n_image_tokens = special_image_mask.sum()
         n_image_features = image_features.shape[0] * image_features.shape[1]
-        special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+        special_image_mask = special_image_mask.unsqueeze(-1).to(inputs_embeds.device)
         torch_compilable_check(
-            inputs_embeds[special_image_mask].numel() == image_features.numel(),
+            n_image_tokens * inputs_embeds.shape[-1] == image_features.numel(),
             f"Image features and image tokens do not match, tokens: {n_image_tokens}, features: {n_image_features}",
         )
         return special_image_mask
diff --git a/src/transformers/models/internvl/processing_internvl.py b/src/transformers/models/internvl/processing_internvl.py
index 84c611115dcf..36dc8082a4d0 100644
--- a/src/transformers/models/internvl/processing_internvl.py
+++ b/src/transformers/models/internvl/processing_internvl.py
@@ -21,9 +21,11 @@
 from ...tokenization_utils_base import PreTokenizedInput, TextInput
 from ...utils import auto_docstring
 from ...video_utils import VideoInput
+from ..got_ocr2.image_processing_got_ocr2 import GotOcr2ImageProcessorKwargs
 class InternVLProcessorKwargs(ProcessingKwargs, total=False):
+    images_kwargs: GotOcr2ImageProcessorKwargs
     _defaults = {
         "text_kwargs": {
             "padding_side": "left",
@@ -75,7 +77,7 @@ def _insert_media_placeholders(
         video_num_patches: list[int],
         image_num_patches_indices: np.ndarray,
         video_num_patches_indices: np.ndarray,
-        video_patch_indices: np.ndarray,
+        video_frame_indices: np.ndarray,
     ):
         """
         Processes interleaved text with and