fix bug for janus model image generation#45044

Merged
ydshieh merged 19 commits into huggingface:main from kaixuanliu:janus-image-generation
Apr 1, 2026

Conversation

@kaixuanliu
Contributor

Fix issue in #44792. @zucchini-nlp @ydshieh pls help review, thx!

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>
@kaixuanliu kaixuanliu changed the title from "Janus image generation" to "fix bug for janus model image generation" Mar 27, 2026
Member

@zucchini-nlp zucchini-nlp left a comment


left a couple q

Comment thread src/transformers/models/janus/modeling_janus.py Outdated
Comment thread src/transformers/models/janus/modeling_janus.py Outdated
@kaixuanliu kaixuanliu marked this pull request as draft March 27, 2026 08:53
@kaixuanliu kaixuanliu marked this pull request as ready for review March 27, 2026 14:33
@kaixuanliu kaixuanliu marked this pull request as draft March 30, 2026 05:15
@kaixuanliu kaixuanliu marked this pull request as ready for review March 30, 2026 05:42
@kaixuanliu
Contributor Author

@zucchini-nlp, hi, can you help review it again? Thx!!

@ydshieh
Collaborator

ydshieh commented Mar 31, 2026

Investigation notes [Just for the record, no need to read]

9a6df2ce is the last commit where the test passes completely (no fixes needed).

Observations from bisecting the regression, in commit order:

bdaddb6f — Generation works (after removing 2 unrelated asserts at the top of the test, or equivalently with 9daee2e8), but produced output differs from expected values → test assertion mismatch only.

a81e04a9 — After removing the 2 asserts, generation now crashes:

TypeError: '>' not supported between instances of 'int' and 'NoneType'

at max(generation_config.max_length, num_image_tokens + seq_len), because generation_config.max_length became None.

9daee2e8 / 2877e4e2 — Same max_length=None crash, no longer need to remove the 2 asserts to reproduce it.

93d7affd ("Generation config boolean defaults #43000") — New crash:

TypeError: repeat_interleave() received an invalid combination of arguments - got (NoneType, dim=int)

generation_config.num_return_sequences became None, passed as expand_size=None into _expand_inputs_for_generation.

Current main — Same expand_size=None / repeat_interleave(None) crash.
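The two None-valued-config crashes in the bisect above can be reproduced in isolation. Here is a minimal standalone sketch (the values are made up; these are not the actual Transformers call sites):

```python
import torch

# Minimal standalone reproduction (assumed values, not the real call sites)
# of the two crashes hit while bisecting: generation_config fields that
# silently became None break integer arithmetic downstream.

max_length = None            # stands in for generation_config.max_length
num_image_tokens, seq_len = 576, 16
try:
    max(max_length, num_image_tokens + seq_len)
except TypeError as e:
    print(e)  # '>' not supported between instances of 'int' and 'NoneType'

expand_size = None           # stands in for generation_config.num_return_sequences
input_ids = torch.arange(8).reshape(2, 4)
try:
    input_ids.repeat_interleave(expand_size, dim=0)
except TypeError:
    print("repeat_interleave() rejects a NoneType expand size")

# Resolving defaults before the arithmetic avoids both failure modes:
resolved_max_length = max(max_length or 0, num_image_tokens + seq_len)
print(resolved_max_length)  # 592
```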


Experimenting with the PR fix:

  • PR fix applied, but without is_first_iteration=True: the expand_size crash is gone, but generation now fails deep in RoPE:

    CUDA error: device-side assert triggered
    

    at apply_rotary_pos_emb — position_ids are computed incorrectly, causing an out-of-bounds RoPE index error.

  • Full PR fix (with is_first_iteration=True): generation completes successfully. Only remaining issue is the expected output values in the test needing to be updated, which the PR handles.

So is_first_iteration=True is necessary to work around a behavioral change introduced somewhere between bdaddb6f and current main in prepare_inputs_for_generation (likely 3c52b78 — "Always pass full input_ids in prepare_inputs_for_generation"), which changed how input_ids are sliced when use_cache=True. Without is_first_iteration=True, the base class slices input_tokens (the full prompt) to 1 token, producing wrong position_ids.
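The slicing mismatch described above can be illustrated with a toy sketch (the helper below is hypothetical, not the actual prepare_inputs_for_generation logic):

```python
# Illustrative sketch (hypothetical helper, not the Transformers code) of why
# slicing the full prompt down to 1 token yields wrong position ids on the
# first image-generation step.

def sliced_inputs(full_input_ids, past_length, is_first_iteration):
    # Base-class-style slicing when use_cache=True: keep only the tokens
    # not yet covered by the cache.
    if is_first_iteration:
        return full_input_ids               # feed the whole prompt
    return full_input_ids[past_length:]     # decode-style slice: new tokens only

prompt = list(range(10))   # a 10-token prompt; the cache is still empty
past_length = 0

# With is_first_iteration=True: positions 0..9 line up with the 10 embeddings.
first = sliced_inputs(prompt, past_length, is_first_iteration=True)
print(len(first), list(range(past_length, past_length + len(first))))

# Without it, a decode-style slice keeps 1 token while the prepared
# inputs_embeds still cover all 10: the lengths disagree, so position ids
# derived from the cache length no longer match the embeddings fed in,
# which is how the out-of-bounds RoPE index arises.
wrong = sliced_inputs(prompt, past_length=9, is_first_iteration=False)
print(len(wrong))  # 1
```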

ydshieh added a commit that referenced this pull request Mar 31, 2026
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
@ydshieh
Collaborator

ydshieh commented Mar 31, 2026

ok, the failure, if we don't include the is_first_iteration change, comes from

421c7f6 [core] 🚨 Completely remove cache positions (#44181)

I pinged him internally.

@ydshieh

This comment was marked as resolved.

@ydshieh

This comment was marked as resolved.

@ydshieh

This comment was marked as resolved.

Comment on lines +547 to +550
4484, 4015, 15750, 15131, 7551, 7326, 3485, 4845, 376, 9925, 1082, 1457, 15550, 7029, 1482, 11522,
14695, 8587, 6807, 8221, 6807, 6140, 15079, 11766, 705, 11799, 405, 4228, 13153, 3910, 8631, 10037,
12758, 6321, 12249, 1787, 15982, 366, 8811, 6910, 1957, 10597, 8889, 8500, 7068, 2037, 897, 4044,
1762, 4080
Collaborator


hi, on which hardware did you get this value for cuda? On A10, I get

([ 2567, 6155, 6155, 250, 15131, 15797, 15453, 12190, 3351, 10803,

Contributor Author


I use A100 with torch version 2.11.0+cu128. You can adjust the expected token to adapt to your CI env.

# computed incorrectly based on cache length, leading to RoPE index out of bounds errors.
model_inputs = self.prepare_inputs_for_generation(
-    inputs_embeds=inputs_embeds, input_ids=input_tokens, **model_kwargs
+    inputs_embeds=inputs_embeds, input_ids=input_tokens, is_first_iteration=True, **model_kwargs
Collaborator


This image test for janus has been broken for a long time, with different errors introduced over several commits (some of them already resolved).

This is_first_iteration=True not only fixes the crash (I haven't found the root commit for it yet) but also brings the actual outputs back in line with the expected outputs (which should have been updated in Default auto (#42805)).

This fix is thus valid.

Collaborator


so short (or long) history

Comment on lines 1328 to 1332
# Set is_first_iteration=True to force using inputs_embeds instead of input_ids.
# Without this, prepare_inputs_for_generation would use input_ids (the full prompt)
# instead of our prepared inputs_embeds (1 new token). This causes position_ids to be
# computed incorrectly based on cache length, leading to RoPE index out of bounds errors.
model_inputs = self.prepare_inputs_for_generation(
Member


If we are forced to use inputs_embeds, then this fix is correct - otherwise it would indeed use input_ids without is_first_iteration. Is this expected to create inputs_embeds like that @zucchini-nlp ? Can't we let the model do it in forward from input_ids?

Collaborator


I will put a comment alongside the code, but I'd like to move on and merge for now.

Using inputs_embeds doesn't seem like a bad thing here (no need to recompute stuff in the for loop). I do agree that it's strange using input_ids won't work (for the wrong-values part; the crash part I have no idea about).

Member


not doable for this model, because embeddings in image-generation mode are obtained via embed + pooling, and later the lm head is also a bit different. In text mode, generation is simple lm-style though, which is why we have the early exit a few lines above

https://github.com/kaixuanliu/transformers/blob/e634aa1bcb43e81bc12e4977bf2a673838ef7836/src/transformers/models/janus/modeling_janus.py#L1367
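To illustrate the two embedding paths mentioned above, here is a toy sketch (module names and sizes are made up for illustration; the real Janus layers differ):

```python
import torch

# Toy illustration (made-up modules/sizes, not the actual Janus layers):
# in image-generation mode, token embeddings come from a separate embedding
# table plus a projection, so they must be built as inputs_embeds up front,
# while text mode is a plain lm-style lookup.

hidden, image_vocab, text_vocab = 8, 32, 64
text_embed = torch.nn.Embedding(text_vocab, hidden)
gen_embed = torch.nn.Embedding(image_vocab, 2 * hidden)  # image-token table
gen_proj = torch.nn.Linear(2 * hidden, hidden)           # "aligner"-style projection (assumed)

image_token_ids = torch.tensor([[3, 5, 7]])
text_token_ids = torch.tensor([[11, 13, 17]])

image_mode_embeds = gen_proj(gen_embed(image_token_ids))  # must be precomputed
text_mode_embeds = text_embed(text_token_ids)             # simple lookup

print(image_mode_embeds.shape, text_mode_embeds.shape)    # both (1, 3, 8)
```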

@github-actions
Contributor

github-actions Bot commented Apr 1, 2026

[For maintainers] Suggested jobs to run (before merge)

run-slow: janus

@ydshieh
Collaborator

ydshieh commented Apr 1, 2026

run-slow: janus

@github-actions
Contributor

github-actions Bot commented Apr 1, 2026

Workflow Run ⚙️

This comment contains run-slow, running the specified jobs:

models: ["models/janus"]
quantizations: []

@github-actions
Contributor

github-actions Bot commented Apr 1, 2026

CI Results

Workflow Run ⚙️

Commit Info

Context Commit Description
RUN b6c6bbad workflow commit (merge commit)
PR 96779dcf branch commit (from PR)
main 9914a364 base commit (on main)

✅ No failing test specific to this PR 🎉 👏 !

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@ydshieh ydshieh merged commit 6abd972 into huggingface:main Apr 1, 2026
18 checks passed
@kaixuanliu kaixuanliu deleted the janus-image-generation branch April 2, 2026 02:46
SangbumChoi pushed a commit to SangbumChoi/transformers that referenced this pull request Apr 4, 2026
* fix bug for janus model image generation

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>

* update expected tokens

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>

* update

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>

* update comment

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>

* use `_preapre_generation_config`

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>

* update

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>

* update expected token

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>

* update code

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>

* update

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>

* update

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>

* update comments

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>

* update

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>

* update

* update

* update

---------

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>
Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com>
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
sirzechs66 pushed a commit to sirzechs66/transformers that referenced this pull request Apr 18, 2026