From 9522674ccc6754a4cf8fc1ba001abbf84e120506 Mon Sep 17 00:00:00 2001 From: "Liu, Kaixuan" Date: Fri, 27 Mar 2026 06:57:33 +0000 Subject: [PATCH 01/15] fix bug for janus model image generation Signed-off-by: Liu, Kaixuan --- .../models/janus/modeling_janus.py | 24 ++++++++++++++++--- .../models/janus/modular_janus.py | 24 ++++++++++++++++--- tests/models/janus/test_modeling_janus.py | 8 +++---- 3 files changed, 46 insertions(+), 10 deletions(-) diff --git a/src/transformers/models/janus/modeling_janus.py b/src/transformers/models/janus/modeling_janus.py index 358765259be1..f2272ad1962b 100644 --- a/src/transformers/models/janus/modeling_janus.py +++ b/src/transformers/models/janus/modeling_janus.py @@ -1207,6 +1207,11 @@ def generate( generation_config = kwargs.pop("generation_config", self.generation_config) generation_config = copy.deepcopy(generation_config) + # Ensure generation_config has defaults from model's config (e.g., num_return_sequences, max_length) + global_defaults = self.generation_config._get_default_generation_params() + generation_config.update(**self.generation_config.to_dict(), defaults_only=True, allow_custom_entries=True) + generation_config.update(**global_defaults, defaults_only=True) + # Default to "text" generation if mode isn't provided generation_mode = kwargs.pop("generation_mode", "text") if generation_mode == "text": @@ -1299,7 +1304,13 @@ def generate( inputs_embeds = self.get_input_embeddings()(input_tokens) - if model_kwargs.get("past_key_values", None) is None: + # Get the device of language model's embed_tokens for proper tensor placement + language_model_device = self.get_input_embeddings().weight.device + + # Only prepare static cache if model is not distributed across devices + # (static cache doesn't work well with device_map="auto") + is_model_distributed = hasattr(self, "hf_device_map") and len(set(self.hf_device_map.values())) > 1 + if model_kwargs.get("past_key_values", None) is None and not is_model_distributed: # 
Prepare cache if not provided. model_kwargs["past_key_values"] = self._prepare_static_cache( cache_implementation=generation_config.cache_implementation or "static", @@ -1326,11 +1337,18 @@ def generate( decoder_attentions = () if (return_dict_in_generate and output_attentions) else None for i in range(num_image_tokens): + # Ensure inputs_embeds is on the language model's device (important for device_map="auto") + if inputs_embeds.device != language_model_device: + inputs_embeds = inputs_embeds.to(language_model_device) + # Set is_first_iteration=True to force using inputs_embeds instead of input_ids. + # Without this, prepare_inputs_for_generation would use input_ids (full prompt) in every + # iteration, causing the cache to overflow. We need it to use our prepared inputs_embeds + # which contains exactly 1 new token embedding per iteration. model_inputs = self.prepare_inputs_for_generation( - inputs_embeds=inputs_embeds, input_ids=input_tokens, **model_kwargs + inputs_embeds=inputs_embeds, input_ids=input_tokens, is_first_iteration=True, **model_kwargs ) if "attention_mask" in model_inputs: - model_inputs["attention_mask"] = model_inputs["attention_mask"].to(inputs_embeds.device) + model_inputs["attention_mask"] = model_inputs["attention_mask"].to(language_model_device) outputs = self.model.language_model( **model_inputs, diff --git a/src/transformers/models/janus/modular_janus.py b/src/transformers/models/janus/modular_janus.py index debb4fb25954..213f7f76a2a0 100644 --- a/src/transformers/models/janus/modular_janus.py +++ b/src/transformers/models/janus/modular_janus.py @@ -971,6 +971,11 @@ def generate( generation_config = kwargs.pop("generation_config", self.generation_config) generation_config = copy.deepcopy(generation_config) + # Ensure generation_config has defaults from model's config (e.g., num_return_sequences, max_length) + global_defaults = self.generation_config._get_default_generation_params() + 
generation_config.update(**self.generation_config.to_dict(), defaults_only=True, allow_custom_entries=True) + generation_config.update(**global_defaults, defaults_only=True) + # Default to "text" generation if mode isn't provided generation_mode = kwargs.pop("generation_mode", "text") if generation_mode == "text": @@ -1063,7 +1068,13 @@ def generate( inputs_embeds = self.get_input_embeddings()(input_tokens) - if model_kwargs.get("past_key_values", None) is None: + # Get the device of language model's embed_tokens for proper tensor placement + language_model_device = self.get_input_embeddings().weight.device + + # Only prepare static cache if model is not distributed across devices + # (static cache doesn't work well with device_map="auto") + is_model_distributed = hasattr(self, "hf_device_map") and len(set(self.hf_device_map.values())) > 1 + if model_kwargs.get("past_key_values", None) is None and not is_model_distributed: # Prepare cache if not provided. model_kwargs["past_key_values"] = self._prepare_static_cache( cache_implementation=generation_config.cache_implementation or "static", @@ -1090,11 +1101,18 @@ def generate( decoder_attentions = () if (return_dict_in_generate and output_attentions) else None for i in range(num_image_tokens): + # Ensure inputs_embeds is on the language model's device (important for device_map="auto") + if inputs_embeds.device != language_model_device: + inputs_embeds = inputs_embeds.to(language_model_device) + # Set is_first_iteration=True to force using inputs_embeds instead of input_ids. + # Without this, prepare_inputs_for_generation would use input_ids (full prompt) in every + # iteration, causing the cache to overflow. We need it to use our prepared inputs_embeds + # which contains exactly 1 new token embedding per iteration. 
model_inputs = self.prepare_inputs_for_generation( - inputs_embeds=inputs_embeds, input_ids=input_tokens, **model_kwargs + inputs_embeds=inputs_embeds, input_ids=input_tokens, is_first_iteration=True, **model_kwargs ) if "attention_mask" in model_inputs: - model_inputs["attention_mask"] = model_inputs["attention_mask"].to(inputs_embeds.device) + model_inputs["attention_mask"] = model_inputs["attention_mask"].to(language_model_device) outputs = self.model.language_model( **model_inputs, diff --git a/tests/models/janus/test_modeling_janus.py b/tests/models/janus/test_modeling_janus.py index e5c713920980..06c509ca7365 100644 --- a/tests/models/janus/test_modeling_janus.py +++ b/tests/models/janus/test_modeling_janus.py @@ -516,10 +516,10 @@ def test_model_generate_images(self): 897, 4044, 1762, 4676 ], ("cuda", None): [ - 4484, 4015, 15750, 506, 3758, 11651, 8597, 5739, 4861, 971, 14985, 14834, 15438, 7548, 1820, 1465, - 13529, 12761, 10503, 12761, 14303, 6155, 4015, 11766, 705, 15736, 14146, 10417, 1951, 7713, 14305, - 15617, 6169, 2706, 8006, 14893, 3855, 10188, 15652, 6297, 1097, 12108, 15038, 311, 14998, 15165, - 897, 4044, 1762, 4676 + 4484, 4015, 15750, 376, 15131, 4538, 8081, 7910, 5461, 3109, 8272, 11294, 7618, 14526, 9074, 7138, + 5396, 14184, 3300, 9106, 7762, 1760, 2231, 15059, 8282, 14776, 7332, 5481, 1071, 14292, 6348, 934, + 94, 10136, 8463, 8540, 7724, 4982, 8682, 9911, 7901, 1644, 4867, 8786, 5857, 7677, 14448, 6594, + 1762, 722 ], ("xpu", None): [ 4484, 4015, 15750, 506, 3758, 11651, 8597, 5739, 4861, 971, 14985, 14834, 15438, 7548, 1820, 1465, From 5283acbcba4185bfca096e0beabaf7e92dedd37b Mon Sep 17 00:00:00 2001 From: "Liu, Kaixuan" Date: Fri, 27 Mar 2026 07:28:03 +0000 Subject: [PATCH 02/15] update expected tokens Signed-off-by: Liu, Kaixuan --- tests/models/janus/test_modeling_janus.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/models/janus/test_modeling_janus.py 
b/tests/models/janus/test_modeling_janus.py index 06c509ca7365..0a8c1659b897 100644 --- a/tests/models/janus/test_modeling_janus.py +++ b/tests/models/janus/test_modeling_janus.py @@ -516,16 +516,16 @@ def test_model_generate_images(self): 897, 4044, 1762, 4676 ], ("cuda", None): [ - 4484, 4015, 15750, 376, 15131, 4538, 8081, 7910, 5461, 3109, 8272, 11294, 7618, 14526, 9074, 7138, - 5396, 14184, 3300, 9106, 7762, 1760, 2231, 15059, 8282, 14776, 7332, 5481, 1071, 14292, 6348, 934, - 94, 10136, 8463, 8540, 7724, 4982, 8682, 9911, 7901, 1644, 4867, 8786, 5857, 7677, 14448, 6594, - 1762, 722 + 4484, 4015, 15750, 15131, 7551, 7326, 3485, 4845, 376, 9925, 1082, 1457, 15550, 7029, 1482, 11522, + 14695, 8587, 6807, 8221, 6807, 6140, 15079, 11766, 705, 11799, 405, 4228, 13153, 3910, 8631, 10037, + 12758, 6321, 12249, 1787, 15982, 366, 8811, 6910, 1957, 10597, 8889, 8500, 7068, 2037, 897, 4044, + 1762, 4080 ], ("xpu", None): [ - 4484, 4015, 15750, 506, 3758, 11651, 8597, 5739, 4861, 971, 14985, 14834, 15438, 7548, 1820, 1465, - 13529, 12761, 10503, 12761, 14303, 6155, 4015, 11766, 705, 15736, 14146, 10417, 1951, 7713, 14305, - 15617, 6169, 2706, 8006, 14893, 3855, 10188, 15652, 6297, 1097, 12108, 15038, 311, 14998, 15165, - 897, 4044, 1762, 4676 + 2567, 4015, 6155, 250, 14695, 3880, 11648, 14407, 8900, 13489, 6305, 2546, 9714, 9599, 6882, 9808, + 8618, 12636, 1151, 250, 250, 7311, 5176, 2273, 15628, 4462, 14906, 4221, 4320, 8389, 6415, 13006, + 12259, 13430, 6528, 13060, 9178, 12352, 11990, 1552, 11091, 16037, 12295, 3641, 8348, 8348, 3738, + 11697, 1762, 12683 ], } ) From 234f3d80cd1205a9ce550fdf0d1bb1fb48ba0972 Mon Sep 17 00:00:00 2001 From: "Liu, Kaixuan" Date: Fri, 27 Mar 2026 07:48:59 +0000 Subject: [PATCH 03/15] update Signed-off-by: Liu, Kaixuan --- tests/models/janus/test_modeling_janus.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/tests/models/janus/test_modeling_janus.py b/tests/models/janus/test_modeling_janus.py index 
0a8c1659b897..55086ba6b260 100644 --- a/tests/models/janus/test_modeling_janus.py +++ b/tests/models/janus/test_modeling_janus.py @@ -35,6 +35,7 @@ from transformers.models.auto.modeling_auto import MODEL_FOR_BACKBONE_MAPPING_NAMES, MODEL_MAPPING_NAMES from transformers.testing_utils import ( Expectations, + require_deterministic_for_xpu, require_torch, slow, torch_device, @@ -478,6 +479,7 @@ def test_model_text_generation_with_multi_image(self): self.assertEqual(EXPECTED_TEXT_COMPLETION, text) @slow + @require_deterministic_for_xpu def test_model_generate_images(self): model = JanusForConditionalGeneration.from_pretrained(self.model_id, device_map="auto") processor = AutoProcessor.from_pretrained(self.model_id) @@ -522,16 +524,15 @@ def test_model_generate_images(self): 1762, 4080 ], ("xpu", None): [ - 2567, 4015, 6155, 250, 14695, 3880, 11648, 14407, 8900, 13489, 6305, 2546, 9714, 9599, 6882, 9808, - 8618, 12636, 1151, 250, 250, 7311, 5176, 2273, 15628, 4462, 14906, 4221, 4320, 8389, 6415, 13006, - 12259, 13430, 6528, 13060, 9178, 12352, 11990, 1552, 11091, 16037, 12295, 3641, 8348, 8348, 3738, - 11697, 1762, 12683 + 4484, 4015, 15750, 376, 2300, 13791, 3609, 2509, 2418, 6347, 7372, 1006, 14519, 6126, 11908, + 14968, 9642, 9490, 14427, 196, 15131, 6155, 4015, 2047, 15628, 4656, 14055, 13908, 3077, 4377, + 11641, 4835, 8854, 10351, 7339, 2815, 13634, 8134, 257, 3621, 7739, 9954, 5989, 11578, 8763, + 12788, 7571, 13595, 1762, 12683 ], } ) - expected_tokens = torch.tensor(expected_tokens.get_expectation()).to(model.device) # fmt: on - + expected_tokens = torch.tensor(expected_tokens.get_expectation()).to(model.device) # Compare the first 50 generated tokens. 
self.assertTrue(torch.allclose(expected_tokens, out[0][:50])) From bb5bbddd807cd7029a159a3ac5938eae4d23d217 Mon Sep 17 00:00:00 2001 From: "Liu, Kaixuan" Date: Fri, 27 Mar 2026 07:58:34 +0000 Subject: [PATCH 04/15] update comment Signed-off-by: Liu, Kaixuan --- src/transformers/models/janus/modeling_janus.py | 3 +-- src/transformers/models/janus/modular_janus.py | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/janus/modeling_janus.py b/src/transformers/models/janus/modeling_janus.py index f2272ad1962b..0680320e7a05 100644 --- a/src/transformers/models/janus/modeling_janus.py +++ b/src/transformers/models/janus/modeling_janus.py @@ -1308,7 +1308,6 @@ def generate( language_model_device = self.get_input_embeddings().weight.device # Only prepare static cache if model is not distributed across devices - # (static cache doesn't work well with device_map="auto") is_model_distributed = hasattr(self, "hf_device_map") and len(set(self.hf_device_map.values())) > 1 if model_kwargs.get("past_key_values", None) is None and not is_model_distributed: # Prepare cache if not provided. @@ -1337,7 +1336,7 @@ def generate( decoder_attentions = () if (return_dict_in_generate and output_attentions) else None for i in range(num_image_tokens): - # Ensure inputs_embeds is on the language model's device (important for device_map="auto") + # Ensure inputs_embeds is on the language model's device (important for multi devices setting) if inputs_embeds.device != language_model_device: inputs_embeds = inputs_embeds.to(language_model_device) # Set is_first_iteration=True to force using inputs_embeds instead of input_ids. 
diff --git a/src/transformers/models/janus/modular_janus.py b/src/transformers/models/janus/modular_janus.py index 213f7f76a2a0..f5d2bad6b1b7 100644 --- a/src/transformers/models/janus/modular_janus.py +++ b/src/transformers/models/janus/modular_janus.py @@ -1072,7 +1072,6 @@ def generate( language_model_device = self.get_input_embeddings().weight.device # Only prepare static cache if model is not distributed across devices - # (static cache doesn't work well with device_map="auto") is_model_distributed = hasattr(self, "hf_device_map") and len(set(self.hf_device_map.values())) > 1 if model_kwargs.get("past_key_values", None) is None and not is_model_distributed: # Prepare cache if not provided. @@ -1101,7 +1100,7 @@ def generate( decoder_attentions = () if (return_dict_in_generate and output_attentions) else None for i in range(num_image_tokens): - # Ensure inputs_embeds is on the language model's device (important for device_map="auto") + # Ensure inputs_embeds is on the language model's device (important for multi devices setting) if inputs_embeds.device != language_model_device: inputs_embeds = inputs_embeds.to(language_model_device) # Set is_first_iteration=True to force using inputs_embeds instead of input_ids. 
From a53ec209e4b0a85b90d85781c92b922d2adb04dd Mon Sep 17 00:00:00 2001 From: "Liu, Kaixuan" Date: Fri, 27 Mar 2026 08:53:08 +0000 Subject: [PATCH 05/15] use `_preapre_generation_config` Signed-off-by: Liu, Kaixuan --- .../models/janus/modeling_janus.py | 18 ++++++------------ src/transformers/models/janus/modular_janus.py | 18 ++++++------------ src/transformers/utils/network_logging.py | 2 ++ tests/models/janus/test_modeling_janus.py | 16 ++++++++-------- 4 files changed, 22 insertions(+), 32 deletions(-) diff --git a/src/transformers/models/janus/modeling_janus.py b/src/transformers/models/janus/modeling_janus.py index 0680320e7a05..3fff676fe2d7 100644 --- a/src/transformers/models/janus/modeling_janus.py +++ b/src/transformers/models/janus/modeling_janus.py @@ -18,7 +18,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy from collections.abc import Callable from dataclasses import dataclass @@ -1204,16 +1203,13 @@ def generate( **kwargs, ): # 1. 
Handle generation config and model kwargs - generation_config = kwargs.pop("generation_config", self.generation_config) - generation_config = copy.deepcopy(generation_config) - - # Ensure generation_config has defaults from model's config (e.g., num_return_sequences, max_length) - global_defaults = self.generation_config._get_default_generation_params() - generation_config.update(**self.generation_config.to_dict(), defaults_only=True, allow_custom_entries=True) - generation_config.update(**global_defaults, defaults_only=True) + # Pop generation_mode first since it's specific to Janus + generation_mode = kwargs.pop("generation_mode", "text") + generation_config, model_kwargs = self._prepare_generation_config( + kwargs.pop("generation_config", None), **kwargs + ) # Default to "text" generation if mode isn't provided - generation_mode = kwargs.pop("generation_mode", "text") if generation_mode == "text": # Set guidance_scale=None to prevent running UnbatchedCFG processor. return super().generate( @@ -1221,11 +1217,9 @@ def generate( attention_mask=attention_mask, generation_config=generation_config, guidance_scale=None, - **kwargs, + **model_kwargs, ) - model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs - # Validate generation mode if generation_config.get_generation_mode() not in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH): raise ValueError( diff --git a/src/transformers/models/janus/modular_janus.py b/src/transformers/models/janus/modular_janus.py index f5d2bad6b1b7..61a5d18f87e5 100644 --- a/src/transformers/models/janus/modular_janus.py +++ b/src/transformers/models/janus/modular_janus.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy from collections.abc import Callable from dataclasses import dataclass @@ -968,16 +967,13 @@ def generate( **kwargs, ): # 1. 
Handle generation config and model kwargs - generation_config = kwargs.pop("generation_config", self.generation_config) - generation_config = copy.deepcopy(generation_config) - - # Ensure generation_config has defaults from model's config (e.g., num_return_sequences, max_length) - global_defaults = self.generation_config._get_default_generation_params() - generation_config.update(**self.generation_config.to_dict(), defaults_only=True, allow_custom_entries=True) - generation_config.update(**global_defaults, defaults_only=True) + # Pop generation_mode first since it's specific to Janus + generation_mode = kwargs.pop("generation_mode", "text") + generation_config, model_kwargs = self._prepare_generation_config( + kwargs.pop("generation_config", None), **kwargs + ) # Default to "text" generation if mode isn't provided - generation_mode = kwargs.pop("generation_mode", "text") if generation_mode == "text": # Set guidance_scale=None to prevent running UnbatchedCFG processor. return super().generate( @@ -985,11 +981,9 @@ def generate( attention_mask=attention_mask, generation_config=generation_config, guidance_scale=None, - **kwargs, + **model_kwargs, ) - model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs - # Validate generation mode if generation_config.get_generation_mode() not in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH): raise ValueError( diff --git a/src/transformers/utils/network_logging.py b/src/transformers/utils/network_logging.py index 92f74ccd6d18..ac3b6bb02115 100644 --- a/src/transformers/utils/network_logging.py +++ b/src/transformers/utils/network_logging.py @@ -25,6 +25,7 @@ from typing import Any import httpx +import pytest from .generic import strtobool @@ -437,6 +438,7 @@ def pytest_configure(self, config): if shared_dir: _NETWORK_DEBUG_PROFILER.set_shared_dir(shared_dir) + @pytest.hookimpl(optionalhook=True) def pytest_configure_node(self, node): """xdist hook: called on the controller to configure 
each worker node.""" shared_dir = getattr(node.config, "_network_debug_shared_dir", None) diff --git a/tests/models/janus/test_modeling_janus.py b/tests/models/janus/test_modeling_janus.py index 55086ba6b260..0e1d0c106c07 100644 --- a/tests/models/janus/test_modeling_janus.py +++ b/tests/models/janus/test_modeling_janus.py @@ -518,16 +518,16 @@ def test_model_generate_images(self): 897, 4044, 1762, 4676 ], ("cuda", None): [ - 4484, 4015, 15750, 15131, 7551, 7326, 3485, 4845, 376, 9925, 1082, 1457, 15550, 7029, 1482, 11522, - 14695, 8587, 6807, 8221, 6807, 6140, 15079, 11766, 705, 11799, 405, 4228, 13153, 3910, 8631, 10037, - 12758, 6321, 12249, 1787, 15982, 366, 8811, 6910, 1957, 10597, 8889, 8500, 7068, 2037, 897, 4044, - 1762, 4080 + 4484, 4015, 15750, 376, 15131, 4538, 8081, 7910, 5461, 3109, 8272, 11294, 7618, 14526, 9074, 7138, + 5396, 14184, 3300, 9106, 7762, 1760, 2231, 15059, 8282, 14776, 7332, 5481, 1071, 14292, 6348, 934, + 94, 10136, 8463, 8540, 7724, 4982, 8682, 9911, 7901, 1644, 4867, 8786, 5857, 7677, 14448, 6594, + 1762, 722 ], ("xpu", None): [ - 4484, 4015, 15750, 376, 2300, 13791, 3609, 2509, 2418, 6347, 7372, 1006, 14519, 6126, 11908, - 14968, 9642, 9490, 14427, 196, 15131, 6155, 4015, 2047, 15628, 4656, 14055, 13908, 3077, 4377, - 11641, 4835, 8854, 10351, 7339, 2815, 13634, 8134, 257, 3621, 7739, 9954, 5989, 11578, 8763, - 12788, 7571, 13595, 1762, 12683 + 4484, 4015, 15750, 506, 3758, 11651, 8597, 5739, 4861, 971, 14985, 14834, 15438, 7548, 1820, 1465, + 13529, 12761, 10503, 12761, 14303, 6155, 4015, 11766, 705, 15736, 14146, 10417, 1951, 7713, 14305, + 15617, 6169, 2706, 8006, 14893, 3855, 10188, 15652, 6297, 1097, 12108, 15038, 311, 14998, 15165, + 897, 4044, 1762, 4676 ], } ) From 68e71adb39300876ecad7c8ca90764624466e49a Mon Sep 17 00:00:00 2001 From: "Liu, Kaixuan" Date: Fri, 27 Mar 2026 08:56:37 +0000 Subject: [PATCH 06/15] update Signed-off-by: Liu, Kaixuan --- src/transformers/utils/network_logging.py | 2 -- 1 file changed, 2 
deletions(-) diff --git a/src/transformers/utils/network_logging.py b/src/transformers/utils/network_logging.py index ac3b6bb02115..92f74ccd6d18 100644 --- a/src/transformers/utils/network_logging.py +++ b/src/transformers/utils/network_logging.py @@ -25,7 +25,6 @@ from typing import Any import httpx -import pytest from .generic import strtobool @@ -438,7 +437,6 @@ def pytest_configure(self, config): if shared_dir: _NETWORK_DEBUG_PROFILER.set_shared_dir(shared_dir) - @pytest.hookimpl(optionalhook=True) def pytest_configure_node(self, node): """xdist hook: called on the controller to configure each worker node.""" shared_dir = getattr(node.config, "_network_debug_shared_dir", None) From 1c1c25dad0d9b94f1a53ba28761511fb2cda18ef Mon Sep 17 00:00:00 2001 From: "Liu, Kaixuan" Date: Fri, 27 Mar 2026 09:02:39 +0000 Subject: [PATCH 07/15] update expected token Signed-off-by: Liu, Kaixuan --- tests/models/janus/test_modeling_janus.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/models/janus/test_modeling_janus.py b/tests/models/janus/test_modeling_janus.py index 0e1d0c106c07..378022421535 100644 --- a/tests/models/janus/test_modeling_janus.py +++ b/tests/models/janus/test_modeling_janus.py @@ -518,10 +518,10 @@ def test_model_generate_images(self): 897, 4044, 1762, 4676 ], ("cuda", None): [ - 4484, 4015, 15750, 376, 15131, 4538, 8081, 7910, 5461, 3109, 8272, 11294, 7618, 14526, 9074, 7138, - 5396, 14184, 3300, 9106, 7762, 1760, 2231, 15059, 8282, 14776, 7332, 5481, 1071, 14292, 6348, 934, - 94, 10136, 8463, 8540, 7724, 4982, 8682, 9911, 7901, 1644, 4867, 8786, 5857, 7677, 14448, 6594, - 1762, 722 + 4484, 4015, 15750, 15131, 7551, 7326, 3485, 4845, 376, 9925, 1082, 1457, 15550, 7029, 1482, 11522, + 14695, 8587, 6807, 8221, 6807, 6140, 15079, 11766, 705, 11799, 405, 4228, 13153, 3910, 8631, 10037, + 12758, 6321, 12249, 1787, 15982, 366, 8811, 6910, 1957, 10597, 8889, 8500, 7068, 2037, 897, 4044, + 1762, 4080 ], ("xpu", None): [ 4484, 4015, 
15750, 506, 3758, 11651, 8597, 5739, 4861, 971, 14985, 14834, 15438, 7548, 1820, 1465, From ff3b9f821a7b66308b73bc722f4d6ff3b3369989 Mon Sep 17 00:00:00 2001 From: "Liu, Kaixuan" Date: Fri, 27 Mar 2026 14:10:32 +0000 Subject: [PATCH 08/15] update code Signed-off-by: Liu, Kaixuan --- src/transformers/models/janus/modeling_janus.py | 8 -------- src/transformers/models/janus/modular_janus.py | 8 -------- tests/models/janus/test_modeling_janus.py | 8 ++++---- 3 files changed, 4 insertions(+), 20 deletions(-) diff --git a/src/transformers/models/janus/modeling_janus.py b/src/transformers/models/janus/modeling_janus.py index 3fff676fe2d7..86b7ebdb7e51 100644 --- a/src/transformers/models/janus/modeling_janus.py +++ b/src/transformers/models/janus/modeling_janus.py @@ -1298,9 +1298,6 @@ def generate( inputs_embeds = self.get_input_embeddings()(input_tokens) - # Get the device of language model's embed_tokens for proper tensor placement - language_model_device = self.get_input_embeddings().weight.device - # Only prepare static cache if model is not distributed across devices is_model_distributed = hasattr(self, "hf_device_map") and len(set(self.hf_device_map.values())) > 1 if model_kwargs.get("past_key_values", None) is None and not is_model_distributed: @@ -1330,9 +1327,6 @@ def generate( decoder_attentions = () if (return_dict_in_generate and output_attentions) else None for i in range(num_image_tokens): - # Ensure inputs_embeds is on the language model's device (important for multi devices setting) - if inputs_embeds.device != language_model_device: - inputs_embeds = inputs_embeds.to(language_model_device) # Set is_first_iteration=True to force using inputs_embeds instead of input_ids. # Without this, prepare_inputs_for_generation would use input_ids (full prompt) in every # iteration, causing the cache to overflow. 
We need it to use our prepared inputs_embeds @@ -1340,8 +1334,6 @@ def generate( model_inputs = self.prepare_inputs_for_generation( inputs_embeds=inputs_embeds, input_ids=input_tokens, is_first_iteration=True, **model_kwargs ) - if "attention_mask" in model_inputs: - model_inputs["attention_mask"] = model_inputs["attention_mask"].to(language_model_device) outputs = self.model.language_model( **model_inputs, diff --git a/src/transformers/models/janus/modular_janus.py b/src/transformers/models/janus/modular_janus.py index 61a5d18f87e5..1428d4ef8a6f 100644 --- a/src/transformers/models/janus/modular_janus.py +++ b/src/transformers/models/janus/modular_janus.py @@ -1062,9 +1062,6 @@ def generate( inputs_embeds = self.get_input_embeddings()(input_tokens) - # Get the device of language model's embed_tokens for proper tensor placement - language_model_device = self.get_input_embeddings().weight.device - # Only prepare static cache if model is not distributed across devices is_model_distributed = hasattr(self, "hf_device_map") and len(set(self.hf_device_map.values())) > 1 if model_kwargs.get("past_key_values", None) is None and not is_model_distributed: @@ -1094,9 +1091,6 @@ def generate( decoder_attentions = () if (return_dict_in_generate and output_attentions) else None for i in range(num_image_tokens): - # Ensure inputs_embeds is on the language model's device (important for multi devices setting) - if inputs_embeds.device != language_model_device: - inputs_embeds = inputs_embeds.to(language_model_device) # Set is_first_iteration=True to force using inputs_embeds instead of input_ids. # Without this, prepare_inputs_for_generation would use input_ids (full prompt) in every # iteration, causing the cache to overflow. 
We need it to use our prepared inputs_embeds @@ -1104,8 +1098,6 @@ def generate( model_inputs = self.prepare_inputs_for_generation( inputs_embeds=inputs_embeds, input_ids=input_tokens, is_first_iteration=True, **model_kwargs ) - if "attention_mask" in model_inputs: - model_inputs["attention_mask"] = model_inputs["attention_mask"].to(language_model_device) outputs = self.model.language_model( **model_inputs, diff --git a/tests/models/janus/test_modeling_janus.py b/tests/models/janus/test_modeling_janus.py index 378022421535..0e1d0c106c07 100644 --- a/tests/models/janus/test_modeling_janus.py +++ b/tests/models/janus/test_modeling_janus.py @@ -518,10 +518,10 @@ def test_model_generate_images(self): 897, 4044, 1762, 4676 ], ("cuda", None): [ - 4484, 4015, 15750, 15131, 7551, 7326, 3485, 4845, 376, 9925, 1082, 1457, 15550, 7029, 1482, 11522, - 14695, 8587, 6807, 8221, 6807, 6140, 15079, 11766, 705, 11799, 405, 4228, 13153, 3910, 8631, 10037, - 12758, 6321, 12249, 1787, 15982, 366, 8811, 6910, 1957, 10597, 8889, 8500, 7068, 2037, 897, 4044, - 1762, 4080 + 4484, 4015, 15750, 376, 15131, 4538, 8081, 7910, 5461, 3109, 8272, 11294, 7618, 14526, 9074, 7138, + 5396, 14184, 3300, 9106, 7762, 1760, 2231, 15059, 8282, 14776, 7332, 5481, 1071, 14292, 6348, 934, + 94, 10136, 8463, 8540, 7724, 4982, 8682, 9911, 7901, 1644, 4867, 8786, 5857, 7677, 14448, 6594, + 1762, 722 ], ("xpu", None): [ 4484, 4015, 15750, 506, 3758, 11651, 8597, 5739, 4861, 971, 14985, 14834, 15438, 7548, 1820, 1465, From f99704b36643b5d852e5ff8efdc10d061dc40e27 Mon Sep 17 00:00:00 2001 From: "Liu, Kaixuan" Date: Fri, 27 Mar 2026 14:17:03 +0000 Subject: [PATCH 09/15] update Signed-off-by: Liu, Kaixuan --- tests/models/janus/test_modeling_janus.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/models/janus/test_modeling_janus.py b/tests/models/janus/test_modeling_janus.py index 0e1d0c106c07..773268a067af 100644 --- a/tests/models/janus/test_modeling_janus.py +++ 
b/tests/models/janus/test_modeling_janus.py @@ -524,10 +524,10 @@ def test_model_generate_images(self): 1762, 722 ], ("xpu", None): [ - 4484, 4015, 15750, 506, 3758, 11651, 8597, 5739, 4861, 971, 14985, 14834, 15438, 7548, 1820, 1465, - 13529, 12761, 10503, 12761, 14303, 6155, 4015, 11766, 705, 15736, 14146, 10417, 1951, 7713, 14305, - 15617, 6169, 2706, 8006, 14893, 3855, 10188, 15652, 6297, 1097, 12108, 15038, 311, 14998, 15165, - 897, 4044, 1762, 4676 + 4484, 4015, 15750, 376, 2300, 13791, 3609, 2509, 2418, 6347, 7372, 1006, 14519, 6126, 11908, 14968, + 9642, 9490, 14427, 196, 15131, 6155, 4015, 2047, 15628, 4656, 14055, 13908, 3077, 4377, 11641, 4835, + 8854, 10351, 7339, 2815, 13634, 8134, 257, 3621, 7739, 9954, 5989, 11578, 8763, 12788, 7571, 13595, + 1762, 12683 ], } ) From 7a123d1f50f380022571ec39d8866c38b36091cd Mon Sep 17 00:00:00 2001 From: "Liu, Kaixuan" Date: Fri, 27 Mar 2026 14:31:40 +0000 Subject: [PATCH 10/15] update Signed-off-by: Liu, Kaixuan --- src/transformers/models/janus/modeling_janus.py | 2 ++ src/transformers/models/janus/modular_janus.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/src/transformers/models/janus/modeling_janus.py b/src/transformers/models/janus/modeling_janus.py index 86b7ebdb7e51..6151065d1da7 100644 --- a/src/transformers/models/janus/modeling_janus.py +++ b/src/transformers/models/janus/modeling_janus.py @@ -1334,6 +1334,8 @@ def generate( model_inputs = self.prepare_inputs_for_generation( inputs_embeds=inputs_embeds, input_ids=input_tokens, is_first_iteration=True, **model_kwargs ) + if "attention_mask" in model_inputs: + model_inputs["attention_mask"] = model_inputs["attention_mask"].to(inputs_embeds.device) outputs = self.model.language_model( **model_inputs, diff --git a/src/transformers/models/janus/modular_janus.py b/src/transformers/models/janus/modular_janus.py index 1428d4ef8a6f..c1e48c75237c 100644 --- a/src/transformers/models/janus/modular_janus.py +++ b/src/transformers/models/janus/modular_janus.py 
@@ -1098,6 +1098,8 @@ def generate( model_inputs = self.prepare_inputs_for_generation( inputs_embeds=inputs_embeds, input_ids=input_tokens, is_first_iteration=True, **model_kwargs ) + if "attention_mask" in model_inputs: + model_inputs["attention_mask"] = model_inputs["attention_mask"].to(inputs_embeds.device) outputs = self.model.language_model( **model_inputs, From 1b032fe022f9c87a46aa3a83b250b48f39fbb3b0 Mon Sep 17 00:00:00 2001 From: "Liu, Kaixuan" Date: Mon, 30 Mar 2026 03:24:32 +0000 Subject: [PATCH 11/15] update comments Signed-off-by: Liu, Kaixuan --- src/transformers/models/janus/modeling_janus.py | 6 +++--- src/transformers/models/janus/modular_janus.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/transformers/models/janus/modeling_janus.py b/src/transformers/models/janus/modeling_janus.py index 6151065d1da7..258e235bb93d 100644 --- a/src/transformers/models/janus/modeling_janus.py +++ b/src/transformers/models/janus/modeling_janus.py @@ -1328,9 +1328,9 @@ def generate( for i in range(num_image_tokens): # Set is_first_iteration=True to force using inputs_embeds instead of input_ids. - # Without this, prepare_inputs_for_generation would use input_ids (full prompt) in every - # iteration, causing the cache to overflow. We need it to use our prepared inputs_embeds - # which contains exactly 1 new token embedding per iteration. + # Without this, prepare_inputs_for_generation would use input_ids (the full prompt) + # instead of our prepared inputs_embeds (1 new token). This causes position_ids to be + # computed incorrectly based on cache length, leading to RoPE index out of bounds errors. 
model_inputs = self.prepare_inputs_for_generation( inputs_embeds=inputs_embeds, input_ids=input_tokens, is_first_iteration=True, **model_kwargs ) diff --git a/src/transformers/models/janus/modular_janus.py b/src/transformers/models/janus/modular_janus.py index c1e48c75237c..7a889acec771 100644 --- a/src/transformers/models/janus/modular_janus.py +++ b/src/transformers/models/janus/modular_janus.py @@ -1092,9 +1092,9 @@ def generate( for i in range(num_image_tokens): # Set is_first_iteration=True to force using inputs_embeds instead of input_ids. - # Without this, prepare_inputs_for_generation would use input_ids (full prompt) in every - # iteration, causing the cache to overflow. We need it to use our prepared inputs_embeds - # which contains exactly 1 new token embedding per iteration. + # Without this, prepare_inputs_for_generation would use input_ids (the full prompt) + # instead of our prepared inputs_embeds (1 new token). This causes position_ids to be + # computed incorrectly based on cache length, leading to RoPE index out of bounds errors. 
model_inputs = self.prepare_inputs_for_generation( inputs_embeds=inputs_embeds, input_ids=input_tokens, is_first_iteration=True, **model_kwargs ) From 7b7518693ae6cf692af541310ce11b49001b97c4 Mon Sep 17 00:00:00 2001 From: "Liu, Kaixuan" Date: Mon, 30 Mar 2026 05:25:49 +0000 Subject: [PATCH 12/15] update Signed-off-by: Liu, Kaixuan --- src/transformers/models/janus/modeling_janus.py | 4 +--- src/transformers/models/janus/modular_janus.py | 4 +--- tests/models/janus/test_modeling_janus.py | 8 ++++---- 3 files changed, 6 insertions(+), 10 deletions(-) diff --git a/src/transformers/models/janus/modeling_janus.py b/src/transformers/models/janus/modeling_janus.py index 258e235bb93d..edd47bc1dcbc 100644 --- a/src/transformers/models/janus/modeling_janus.py +++ b/src/transformers/models/janus/modeling_janus.py @@ -1298,9 +1298,7 @@ def generate( inputs_embeds = self.get_input_embeddings()(input_tokens) - # Only prepare static cache if model is not distributed across devices - is_model_distributed = hasattr(self, "hf_device_map") and len(set(self.hf_device_map.values())) > 1 - if model_kwargs.get("past_key_values", None) is None and not is_model_distributed: + if model_kwargs.get("past_key_values", None) is None: # Prepare cache if not provided. 
model_kwargs["past_key_values"] = self._prepare_static_cache( cache_implementation=generation_config.cache_implementation or "static", diff --git a/src/transformers/models/janus/modular_janus.py b/src/transformers/models/janus/modular_janus.py index 7a889acec771..dcc710c2fa54 100644 --- a/src/transformers/models/janus/modular_janus.py +++ b/src/transformers/models/janus/modular_janus.py @@ -1062,9 +1062,7 @@ def generate( inputs_embeds = self.get_input_embeddings()(input_tokens) - # Only prepare static cache if model is not distributed across devices - is_model_distributed = hasattr(self, "hf_device_map") and len(set(self.hf_device_map.values())) > 1 - if model_kwargs.get("past_key_values", None) is None and not is_model_distributed: + if model_kwargs.get("past_key_values", None) is None: # Prepare cache if not provided. model_kwargs["past_key_values"] = self._prepare_static_cache( cache_implementation=generation_config.cache_implementation or "static", diff --git a/tests/models/janus/test_modeling_janus.py b/tests/models/janus/test_modeling_janus.py index 773268a067af..b7e80e11b525 100644 --- a/tests/models/janus/test_modeling_janus.py +++ b/tests/models/janus/test_modeling_janus.py @@ -518,10 +518,10 @@ def test_model_generate_images(self): 897, 4044, 1762, 4676 ], ("cuda", None): [ - 4484, 4015, 15750, 376, 15131, 4538, 8081, 7910, 5461, 3109, 8272, 11294, 7618, 14526, 9074, 7138, - 5396, 14184, 3300, 9106, 7762, 1760, 2231, 15059, 8282, 14776, 7332, 5481, 1071, 14292, 6348, 934, - 94, 10136, 8463, 8540, 7724, 4982, 8682, 9911, 7901, 1644, 4867, 8786, 5857, 7677, 14448, 6594, - 1762, 722 + 4484, 4015, 15750, 15131, 7551, 7326, 3485, 4845, 376, 9925, 1082, 1457, 15550, 7029, 1482, 11522, + 14695, 8587, 6807, 8221, 6807, 6140, 15079, 11766, 705, 11799, 405, 4228, 13153, 3910, 8631, 10037, + 12758, 6321, 12249, 1787, 15982, 366, 8811, 6910, 1957, 10597, 8889, 8500, 7068, 2037, 897, 4044, + 1762, 4080 ], ("xpu", None): [ 4484, 4015, 15750, 376, 2300, 13791, 3609, 
2509, 2418, 6347, 7372, 1006, 14519, 6126, 11908, 14968, From 96779dcff177dfa3d06a575d2dd2a96b845c2fb7 Mon Sep 17 00:00:00 2001 From: ydshieh Date: Wed, 1 Apr 2026 11:20:41 +0200 Subject: [PATCH 13/15] update --- tests/models/janus/test_modeling_janus.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/models/janus/test_modeling_janus.py b/tests/models/janus/test_modeling_janus.py index 39b4b9253e4a..13eabc162b1c 100644 --- a/tests/models/janus/test_modeling_janus.py +++ b/tests/models/janus/test_modeling_janus.py @@ -544,10 +544,10 @@ def test_model_generate_images(self): 897, 4044, 1762, 4676 ], ("cuda", None): [ - 4484, 4015, 15750, 15131, 7551, 7326, 3485, 4845, 376, 9925, 1082, 1457, 15550, 7029, 1482, 11522, - 14695, 8587, 6807, 8221, 6807, 6140, 15079, 11766, 705, 11799, 405, 4228, 13153, 3910, 8631, 10037, - 12758, 6321, 12249, 1787, 15982, 366, 8811, 6910, 1957, 10597, 8889, 8500, 7068, 2037, 897, 4044, - 1762, 4080 + 2567, 6155, 6155, 250, 15131, 15797, 15453, 12190, 3351, 10803, 10673, 3096, 14485, 5335, 6677, + 13743, 9574, 8228, 3679, 11495, 11495, 15342, 11209, 1389, 15628, 6841, 15490, 10301, 12841, 3930, + 3396, 10037, 7779, 4517, 3824, 3673, 14408, 4791, 14109, 4929, 2342, 4817, 15531, 4320, 1923, 9530, + 13086, 5212, 14575, 4212 ], ("xpu", None): [ 4484, 4015, 15750, 376, 2300, 13791, 3609, 2509, 2418, 6347, 7372, 1006, 14519, 6126, 11908, 14968, From 439b0cc52b5aeea13386b7a63a40118503ac1b2a Mon Sep 17 00:00:00 2001 From: ydshieh Date: Wed, 1 Apr 2026 11:45:29 +0200 Subject: [PATCH 14/15] update --- src/transformers/models/janus/modeling_janus.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/janus/modeling_janus.py b/src/transformers/models/janus/modeling_janus.py index edd47bc1dcbc..ff2292a9153e 100644 --- a/src/transformers/models/janus/modeling_janus.py +++ b/src/transformers/models/janus/modeling_janus.py @@ -1325,10 +1325,12 @@ def generate( 
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None for i in range(num_image_tokens): - # Set is_first_iteration=True to force using inputs_embeds instead of input_ids. - # Without this, prepare_inputs_for_generation would use input_ids (the full prompt) - # instead of our prepared inputs_embeds (1 new token). This causes position_ids to be - # computed incorrectly based on cache length, leading to RoPE index out of bounds errors. + # Set `is_first_iteration=True` to force using `inputs_embeds` instead of `input_ids`. + # Without this, `prepare_inputs_for_generation` would use `input_ids` (the full prompt) + # instead of our prepared `inputs_embeds` (1 new token). + # This causes a CUDA error (device-side assert triggered), seen around the call to `self.self_attn`. + # Setting this to `True` is also necessary to match the expected output; see the more detailed comment at + # https://github.com/huggingface/transformers/pull/45044#discussion_r3020805374. model_inputs = self.prepare_inputs_for_generation( inputs_embeds=inputs_embeds, input_ids=input_tokens, is_first_iteration=True, **model_kwargs ) From e634aa1bcb43e81bc12e4977bf2a673838ef7836 Mon Sep 17 00:00:00 2001 From: ydshieh Date: Wed, 1 Apr 2026 11:50:17 +0200 Subject: [PATCH 15/15] update --- src/transformers/models/janus/modular_janus.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/janus/modular_janus.py b/src/transformers/models/janus/modular_janus.py index dcc710c2fa54..5e08abbfffc1 100644 --- a/src/transformers/models/janus/modular_janus.py +++ b/src/transformers/models/janus/modular_janus.py @@ -1089,10 +1089,12 @@ def generate( decoder_attentions = () if (return_dict_in_generate and output_attentions) else None for i in range(num_image_tokens): - # Set is_first_iteration=True to force using inputs_embeds instead of input_ids.
- # Without this, prepare_inputs_for_generation would use input_ids (the full prompt) - # instead of our prepared inputs_embeds (1 new token). This causes position_ids to be - # computed incorrectly based on cache length, leading to RoPE index out of bounds errors. + # Set `is_first_iteration=True` to force using `inputs_embeds` instead of `input_ids`. + # Without this, `prepare_inputs_for_generation` would use `input_ids` (the full prompt) + # instead of our prepared `inputs_embeds` (1 new token). + # This causes a CUDA error (device-side assert triggered), seen around the call to `self.self_attn`. + # Setting this to `True` is also necessary to match the expected output; see the more detailed comment at + # https://github.com/huggingface/transformers/pull/45044#discussion_r3020805374. model_inputs = self.prepare_inputs_for_generation( inputs_embeds=inputs_embeds, input_ids=input_tokens, is_first_iteration=True, **model_kwargs )