diff --git a/src/transformers/models/janus/modeling_janus.py b/src/transformers/models/janus/modeling_janus.py index 358765259be1..ff2292a9153e 100644 --- a/src/transformers/models/janus/modeling_janus.py +++ b/src/transformers/models/janus/modeling_janus.py @@ -18,7 +18,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy from collections.abc import Callable from dataclasses import dataclass @@ -1204,11 +1203,13 @@ def generate( **kwargs, ): # 1. Handle generation config and model kwargs - generation_config = kwargs.pop("generation_config", self.generation_config) - generation_config = copy.deepcopy(generation_config) + # Pop generation_mode first since it's specific to Janus + generation_mode = kwargs.pop("generation_mode", "text") + generation_config, model_kwargs = self._prepare_generation_config( + kwargs.pop("generation_config", None), **kwargs + ) # Default to "text" generation if mode isn't provided - generation_mode = kwargs.pop("generation_mode", "text") if generation_mode == "text": # Set guidance_scale=None to prevent running UnbatchedCFG processor. return super().generate( @@ -1216,11 +1217,9 @@ def generate( attention_mask=attention_mask, generation_config=generation_config, guidance_scale=None, - **kwargs, + **model_kwargs, ) - model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs - # Validate generation mode if generation_config.get_generation_mode() not in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH): raise ValueError( @@ -1326,8 +1325,14 @@ def generate( decoder_attentions = () if (return_dict_in_generate and output_attentions) else None for i in range(num_image_tokens): + # Set `is_first_iteration=True` to force using `inputs_embeds` instead of `input_ids`. + # Without this, `prepare_inputs_for_generation` would use `input_ids` (the full prompt) + # instead of our prepared `inputs_embeds` (1 new token). 
+ # This causes CUDA error: device-side assert triggered, seen around the call to `self.self_attn`. + # Setting this to `True` is also necessary to match the expected output, see the more detailed comment + # https://github.com/huggingface/transformers/pull/45044#discussion_r3020805374. model_inputs = self.prepare_inputs_for_generation( - inputs_embeds=inputs_embeds, input_ids=input_tokens, **model_kwargs + inputs_embeds=inputs_embeds, input_ids=input_tokens, is_first_iteration=True, **model_kwargs ) if "attention_mask" in model_inputs: model_inputs["attention_mask"] = model_inputs["attention_mask"].to(inputs_embeds.device) diff --git a/src/transformers/models/janus/modular_janus.py b/src/transformers/models/janus/modular_janus.py index debb4fb25954..5e08abbfffc1 100644 --- a/src/transformers/models/janus/modular_janus.py +++ b/src/transformers/models/janus/modular_janus.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy from collections.abc import Callable from dataclasses import dataclass @@ -968,11 +967,13 @@ def generate( **kwargs, ): # 1. Handle generation config and model kwargs - generation_config = kwargs.pop("generation_config", self.generation_config) - generation_config = copy.deepcopy(generation_config) + # Pop generation_mode first since it's specific to Janus + generation_mode = kwargs.pop("generation_mode", "text") + generation_config, model_kwargs = self._prepare_generation_config( + kwargs.pop("generation_config", None), **kwargs + ) # Default to "text" generation if mode isn't provided - generation_mode = kwargs.pop("generation_mode", "text") if generation_mode == "text": # Set guidance_scale=None to prevent running UnbatchedCFG processor. 
return super().generate( @@ -980,11 +981,9 @@ attention_mask=attention_mask, generation_config=generation_config, guidance_scale=None, - **kwargs, + **model_kwargs, ) - model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs - # Validate generation mode if generation_config.get_generation_mode() not in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH): raise ValueError( @@ -1090,8 +1089,14 @@ decoder_attentions = () if (return_dict_in_generate and output_attentions) else None for i in range(num_image_tokens): + # Set `is_first_iteration=True` to force using `inputs_embeds` instead of `input_ids`. + # Without this, `prepare_inputs_for_generation` would use `input_ids` (the full prompt) + # instead of our prepared `inputs_embeds` (1 new token). + # This causes CUDA error: device-side assert triggered, seen around the call to `self.self_attn`. + # Setting this to `True` is also necessary to match the expected output, see the more detailed comment + # https://github.com/huggingface/transformers/pull/45044#discussion_r3020805374. 
model_inputs = self.prepare_inputs_for_generation( - inputs_embeds=inputs_embeds, input_ids=input_tokens, **model_kwargs + inputs_embeds=inputs_embeds, input_ids=input_tokens, is_first_iteration=True, **model_kwargs ) if "attention_mask" in model_inputs: model_inputs["attention_mask"] = model_inputs["attention_mask"].to(inputs_embeds.device) diff --git a/tests/models/janus/test_modeling_janus.py b/tests/models/janus/test_modeling_janus.py index e4988348f5b8..13eabc162b1c 100644 --- a/tests/models/janus/test_modeling_janus.py +++ b/tests/models/janus/test_modeling_janus.py @@ -544,22 +544,21 @@ def test_model_generate_images(self): 897, 4044, 1762, 4676 ], ("cuda", None): [ - 4484, 4015, 15750, 506, 3758, 11651, 8597, 5739, 4861, 971, 14985, 14834, 15438, 7548, 1820, 1465, - 13529, 12761, 10503, 12761, 14303, 6155, 4015, 11766, 705, 15736, 14146, 10417, 1951, 7713, 14305, - 15617, 6169, 2706, 8006, 14893, 3855, 10188, 15652, 6297, 1097, 12108, 15038, 311, 14998, 15165, - 897, 4044, 1762, 4676 + 2567, 6155, 6155, 250, 15131, 15797, 15453, 12190, 3351, 10803, 10673, 3096, 14485, 5335, 6677, + 13743, 9574, 8228, 3679, 11495, 11495, 15342, 11209, 1389, 15628, 6841, 15490, 10301, 12841, 3930, + 3396, 10037, 7779, 4517, 3824, 3673, 14408, 4791, 14109, 4929, 2342, 4817, 15531, 4320, 1923, 9530, + 13086, 5212, 14575, 4212 ], ("xpu", None): [ - 4484, 4015, 15750, 506, 3758, 11651, 8597, 5739, 4861, 971, 14985, 14834, 15438, 7548, 1820, 1465, - 13529, 12761, 10503, 12761, 14303, 6155, 4015, 11766, 705, 15736, 14146, 10417, 1951, 7713, 14305, - 15617, 6169, 2706, 8006, 14893, 3855, 10188, 15652, 6297, 1097, 12108, 15038, 311, 14998, 15165, - 897, 4044, 1762, 4676 + 4484, 4015, 15750, 376, 2300, 13791, 3609, 2509, 2418, 6347, 7372, 1006, 14519, 6126, 11908, 14968, + 9642, 9490, 14427, 196, 15131, 6155, 4015, 2047, 15628, 4656, 14055, 13908, 3077, 4377, 11641, 4835, + 8854, 10351, 7339, 2815, 13634, 8134, 257, 3621, 7739, 9954, 5989, 11578, 8763, 12788, 7571, 13595, + 1762, 
12683 ], } ) - expected_tokens = torch.tensor(expected_tokens.get_expectation()).to(model.device) # fmt: on - + expected_tokens = torch.tensor(expected_tokens.get_expectation()).to(model.device) # Compare the first 50 generated tokens. self.assertTrue(torch.allclose(expected_tokens, out[0][:50]))