Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 13 additions & 8 deletions src/transformers/models/janus/modeling_janus.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
from collections.abc import Callable
from dataclasses import dataclass

Expand Down Expand Up @@ -1204,23 +1203,23 @@ def generate(
**kwargs,
):
# 1. Handle generation config and model kwargs
generation_config = kwargs.pop("generation_config", self.generation_config)
generation_config = copy.deepcopy(generation_config)
# Pop generation_mode first since it's specific to Janus
generation_mode = kwargs.pop("generation_mode", "text")
generation_config, model_kwargs = self._prepare_generation_config(
kwargs.pop("generation_config", None), **kwargs
)

# Default to "text" generation if mode isn't provided
generation_mode = kwargs.pop("generation_mode", "text")
if generation_mode == "text":
# Set guidance_scale=None to prevent running UnbatchedCFG processor.
return super().generate(
inputs=inputs,
attention_mask=attention_mask,
generation_config=generation_config,
guidance_scale=None,
**kwargs,
**model_kwargs,
)

model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs

# Validate generation mode
if generation_config.get_generation_mode() not in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):
raise ValueError(
Expand Down Expand Up @@ -1326,8 +1325,14 @@ def generate(
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None

for i in range(num_image_tokens):
# Set `is_first_iteration=True` to force using `inputs_embeds` instead of `input_ids`.
# Without this, `prepare_inputs_for_generation` would use `input_ids` (the full prompt)
# instead of our prepared `inputs_embeds` (1 new token).
# This causes a CUDA error (device-side assert triggered), seen around the call to `self.self_attn`.
# Setting this to `True` is also necessary to match the expected output; see the more detailed comment at
# https://github.com/huggingface/transformers/pull/45044#discussion_r3020805374.
model_inputs = self.prepare_inputs_for_generation(
inputs_embeds=inputs_embeds, input_ids=input_tokens, **model_kwargs
inputs_embeds=inputs_embeds, input_ids=input_tokens, is_first_iteration=True, **model_kwargs
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This image test for Janus has been broken for a long time, with different errors introduced over several commits (some of which have since been resolved).

This is_first_iteration=True not only fixes the crash issue (I haven't found the root commit for it yet) but also brings the actual outputs back in line with the expected outputs (which should have been updated in #42805).

This fix is thus valid.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so short (or long) history

)
if "attention_mask" in model_inputs:
model_inputs["attention_mask"] = model_inputs["attention_mask"].to(inputs_embeds.device)
Expand Down
21 changes: 13 additions & 8 deletions src/transformers/models/janus/modular_janus.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
from collections.abc import Callable
from dataclasses import dataclass

Expand Down Expand Up @@ -968,23 +967,23 @@ def generate(
**kwargs,
):
# 1. Handle generation config and model kwargs
generation_config = kwargs.pop("generation_config", self.generation_config)
generation_config = copy.deepcopy(generation_config)
# Pop generation_mode first since it's specific to Janus
generation_mode = kwargs.pop("generation_mode", "text")
generation_config, model_kwargs = self._prepare_generation_config(
kwargs.pop("generation_config", None), **kwargs
)

# Default to "text" generation if mode isn't provided
generation_mode = kwargs.pop("generation_mode", "text")
if generation_mode == "text":
# Set guidance_scale=None to prevent running UnbatchedCFG processor.
return super().generate(
inputs=inputs,
attention_mask=attention_mask,
generation_config=generation_config,
guidance_scale=None,
**kwargs,
**model_kwargs,
)

model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs

# Validate generation mode
if generation_config.get_generation_mode() not in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):
raise ValueError(
Expand Down Expand Up @@ -1090,8 +1089,14 @@ def generate(
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None

for i in range(num_image_tokens):
# Set `is_first_iteration=True` to force using `inputs_embeds` instead of `input_ids`.
# Without this, `prepare_inputs_for_generation` would use `input_ids` (the full prompt)
# instead of our prepared `inputs_embeds` (1 new token).
# This causes a CUDA error (device-side assert triggered), seen around the call to `self.self_attn`.
# Setting this to `True` is also necessary to match the expected output; see the more detailed comment at
# https://github.com/huggingface/transformers/pull/45044#discussion_r3020805374.
model_inputs = self.prepare_inputs_for_generation(
inputs_embeds=inputs_embeds, input_ids=input_tokens, **model_kwargs
inputs_embeds=inputs_embeds, input_ids=input_tokens, is_first_iteration=True, **model_kwargs
)
if "attention_mask" in model_inputs:
model_inputs["attention_mask"] = model_inputs["attention_mask"].to(inputs_embeds.device)
Expand Down
19 changes: 9 additions & 10 deletions tests/models/janus/test_modeling_janus.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,22 +544,21 @@ def test_model_generate_images(self):
897, 4044, 1762, 4676
],
("cuda", None): [
4484, 4015, 15750, 506, 3758, 11651, 8597, 5739, 4861, 971, 14985, 14834, 15438, 7548, 1820, 1465,
13529, 12761, 10503, 12761, 14303, 6155, 4015, 11766, 705, 15736, 14146, 10417, 1951, 7713, 14305,
15617, 6169, 2706, 8006, 14893, 3855, 10188, 15652, 6297, 1097, 12108, 15038, 311, 14998, 15165,
897, 4044, 1762, 4676
2567, 6155, 6155, 250, 15131, 15797, 15453, 12190, 3351, 10803, 10673, 3096, 14485, 5335, 6677,
13743, 9574, 8228, 3679, 11495, 11495, 15342, 11209, 1389, 15628, 6841, 15490, 10301, 12841, 3930,
3396, 10037, 7779, 4517, 3824, 3673, 14408, 4791, 14109, 4929, 2342, 4817, 15531, 4320, 1923, 9530,
13086, 5212, 14575, 4212
],
("xpu", None): [
4484, 4015, 15750, 506, 3758, 11651, 8597, 5739, 4861, 971, 14985, 14834, 15438, 7548, 1820, 1465,
13529, 12761, 10503, 12761, 14303, 6155, 4015, 11766, 705, 15736, 14146, 10417, 1951, 7713, 14305,
15617, 6169, 2706, 8006, 14893, 3855, 10188, 15652, 6297, 1097, 12108, 15038, 311, 14998, 15165,
897, 4044, 1762, 4676
4484, 4015, 15750, 376, 2300, 13791, 3609, 2509, 2418, 6347, 7372, 1006, 14519, 6126, 11908, 14968,
9642, 9490, 14427, 196, 15131, 6155, 4015, 2047, 15628, 4656, 14055, 13908, 3077, 4377, 11641, 4835,
8854, 10351, 7339, 2815, 13634, 8134, 257, 3621, 7739, 9954, 5989, 11578, 8763, 12788, 7571, 13595,
1762, 12683
],
}
)
expected_tokens = torch.tensor(expected_tokens.get_expectation()).to(model.device)
# fmt: on

expected_tokens = torch.tensor(expected_tokens.get_expectation()).to(model.device)
# Compare the first 50 generated tokens.
self.assertTrue(torch.allclose(expected_tokens, out[0][:50]))

Expand Down
Loading