
Fix gemma4 has flash-attention incompatible head-dim=512 #45202

Open

Qubitium wants to merge 10 commits into huggingface:main from Qubitium:gemma4-fa-fix

Conversation

Contributor

@Qubitium Qubitium commented Apr 2, 2026

What does this PR do?

Disable FlashAttention support for Gemma4, which FA cannot support because the global-attention layers use head_dim=512.

I am very confused by the current code/tests for Gemma4. I ran real inference using transformers main and FA throws head-dim errors, yet I see tests in transformers with comments noting that FA2 generates gibberish, and the tests force the head-dim value so they pass. Do the Gemma4 models uploaded by Google have a broken config.json when it comes to head_dim, was Transformers testing an unreleased/alpha version of Gemma4, or was the test code just bad? The tests should never have passed.

Gemma 4 sliding-attention layers: head_dim = 256, OK
Gemma 4 full-attention layers: head_dim = 512, not OK for FA2

Bad testing code forced global_head_dim to match self.head_dim, which allowed FA to run. Again, even the comments mention garbage output with FA2 enabled.

/root/transformers/tests/models/gemma4/test_modeling_gemma4.py:77: self.global_head_dim = self.head_dim
# TODO: raushan FA2 generates gibberish for no reason, check later
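For context, released FA2 kernels cap head_dim at 256, which is why the 512-dim full-attention layers fail. A minimal sketch of that constraint check (illustrative only; the config field names head_dim and global_head_dim are assumptions based on this thread, not a specific transformers API):

FA2_MAX_HEAD_DIM = 256  # head_dim limit of released FA2 kernels

def fa2_usable(config) -> bool:
    # Reject FA2 if any attention layer's head_dim exceeds the kernel limit.
    sliding_dim = getattr(config, "head_dim", None)        # sliding-attention layers (256 on released Gemma4)
    global_dim = getattr(config, "global_head_dim", None)  # full-attention layers (512 on released Gemma4)
    dims = [d for d in (sliding_dim, global_dim) if d is not None]
    return all(d <= FA2_MAX_HEAD_DIM for d in dims)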
  • I confirm that this is not a pure code agent PR.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@ArthurZucker @Cyrilvallez @vasqu
@douglas-reid

"sliding_attention",
"full_attention",
] # similarly we want to test sharing on both types
self.global_head_dim = self.head_dim # gemma4 use a different head_dim for full and sliding layers
Contributor Author


Why was this overridden? This caused tests to pass but does not reflect reality.

Contributor Author

Qubitium commented Apr 3, 2026

View the CircleCI Test Summary for this PR:

https://huggingface.co/spaces/transformers-community/circle-ci-viz?pr=45202&sha=3919a9

My fix broke more gemma4 tests, which assumed a single head_dim for all layers, when the gemma4 models actually switch between head_dim sizes in different layers.

Fixing these failed tests now.

Comment on lines +143 to +145
@unittest.skip("Float8 quantization + TP numerical noise exceeds match threshold")
def test_tp_generation_quantized(self):
pass
Contributor Author


FP8 quantization is failing CI due to very low output accuracy post-quantization. The fix is outside the scope of this minimal (disable FA2 support in Gemma4) PR.

Comment on lines +95 to +96
tensor_parallel_atol = 2e-4
tensor_parallel_rtol = 2e-4
Contributor Author


Fixes a CI test failing due to overly strict tolerance. Verified the fix on an A100.
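
For reference, these class attributes only loosen the comparison tolerance used by the tensor-parallel test; conceptually the check behaves like the following (a sketch, not the actual test code):

import torch

# Sketch only: small TP-vs-reference numerical noise should pass at atol/rtol = 2e-4.
reference_logits = torch.randn(2, 8)
tp_logits = reference_logits + 1e-5 * torch.randn_like(reference_logits)
torch.testing.assert_close(tp_logits, reference_logits, atol=2e-4, rtol=2e-4)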

Contributor Author

Qubitium commented Apr 3, 2026

All relevant Gemma4 CI tests are passing. The remaining failed test is some tokenizer issue that appears to be unrelated. Ready for full review and merge. The actual fix is 2 lines; the rest fix bad unit test assumptions.
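
For readers following along, disabling FA support on a model class in transformers typically comes down to flipping a class-level flag; a hypothetical sketch of what such a two-line change might look like (the attribute name is assumed here, this is not the literal diff in this PR):

from transformers import PreTrainedModel

class Gemma4PreTrainedModel(PreTrainedModel):
    # Released FA2 kernels cap head_dim at 256, but Gemma4 full-attention layers
    # ship with head_dim=512, so flash attention is advertised as unsupported.
    _supports_flash_attn = False  # attribute name assumed for illustration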

HuiyingLi added a commit to NVIDIA-NeMo/Automodel that referenced this pull request Apr 5, 2026
Gemma4 cannot use flash_attention_2 (global layers have head_dim=512,
FA2 max is 256 — huggingface/transformers#45202). SDPA has a known
NaN bug with sliding window + padding (huggingface/transformers#32390)
that was fixed for Gemma2 but not ported to Gemma3/4. Eager attention
handles all padding correctly.

Also add padding_side: right to the 4 original configs that were
missing it (the mock configs already had it).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@Qubitium Qubitium changed the title from "fix gemma4 has flash-attention incompatible head-dim=512" to "Fix gemma4 has flash-attention incompatible head-dim=512" Apr 7, 2026
Contributor

vasqu commented Apr 9, 2026

A lot has changed with the recent changes on gemma4 @Qubitium. I think the best way would be to have this layer-wise solution where we set granularly what attention type we want. It is non-trivial though, because it needs to touch some core code.

Dao-AILab/flash-attention#2427 suggests that FA3 and FA4 will, at least soon-ish, have support for these head dims. So it might be smarter to wait for these instead and then only allow those FA versions to be used.

Contributor

vasqu commented Apr 9, 2026

Tl;dr: Not sure about the timeline; if FA3/4 is faster I'd rather change the scope of allowed FAs than have the layer-wise solution (because it has a few non-trivial things IMO).

Contributor Author

Qubitium commented Apr 9, 2026

A lot has changed with the recent changes on gemma4 @Qubitium. I think the best way would be to have this layer-wise solution where we set granularly what attention type we want. It is non-trivial though, because it needs to touch some core code.

Dao-AILab/flash-attention#2427 suggests that FA3 and FA4 will, at least soon-ish, have support for these head dims. So it might be smarter to wait for these instead and then only allow those FA versions to be used.

Agree on the per-layer control now that gemma has dual attention-type layers. Another issue is that the current on/off FA toggle has no FA version scope, so the question is no longer whether it supports FA but which versions are allowed for this model/layer config.

Contributor

vasqu commented Apr 10, 2026

We have

# Model's compatible flash kernels (e.g., "kernels-community/flash-mla") defaulting to the first in the list
_compatible_flash_implementations: list[str] | None = None

which allows specifying the compatible FAs to be used (in the future).
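
A minimal sketch of how a model class might fill in that attribute (illustrative; the attribute and the example kernel name come from the comment above, the class itself is a placeholder):

from transformers import PreTrainedModel

class SomeFlashRestrictedModel(PreTrainedModel):
    # Only these flash kernels are considered compatible; None (the default)
    # would mean any flash implementation is allowed.
    _compatible_flash_implementations = ["kernels-community/flash-mla"]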

Contributor Author

We have

# Model's compatible flash kernels (e.g., "kernels-community/flash-mla") defaulting to the first in the list
_compatible_flash_implementations: list[str] | None = None

which allows specifying the compatible FAs to be used (in the future).

Nice! I see you are the one who added the new toggle 2 months ago. =) In that case, can we just remove the boolean state and only use the full control map?

Contributor

vasqu commented Apr 10, 2026

It's more of an accompanying flag: If it is set to None, we assume general FA support for all flavors, otherwise only for those listed there (and we raise an error if it is not compatible).

Contributor Author

Qubitium commented Apr 10, 2026

It's more of an accompanying flag: If it is set to None, we assume general FA support for all flavors, otherwise only for those listed there (and we raise an error if it is not compatible).

Forward-thinking wise, perhaps we should just make it generic and use an enum mask/list to store all the compatible attention implementations? FA is only one of many that will exist, so instead of _supports_flash_attn, do _supports_attention_impl = FA | FA2 | FA3 | FA4 | ... other attention. This way we don't need to add multiple class states whenever a new attn_impl is invented and added.
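
A rough sketch of what such a generic flag could look like (purely hypothetical; neither the enum nor the attribute exists in transformers today):

from enum import Flag, auto

class AttnImpl(Flag):
    # Hypothetical enumeration of attention backends.
    EAGER = auto()
    SDPA = auto()
    FA2 = auto()
    FA3 = auto()
    FA4 = auto()

class SomeModel:
    # Declare every backend this model/checkpoint combination supports.
    _supports_attention_impl = AttnImpl.EAGER | AttnImpl.SDPA | AttnImpl.FA3 | AttnImpl.FA4

# A compatibility check then reduces to a membership test:
assert AttnImpl.FA3 in SomeModel._supports_attention_impl
assert AttnImpl.FA2 not in SomeModel._supports_attention_impl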

Member

@Cyrilvallez Cyrilvallez left a comment


Hey @Qubitium! Setting both head_dim to the same value in tests is simply an easy way to have the usual tests pass, as they expect the same shapes all the time. There's basically no way that the code can break or see a regression for different head_dim per layer, so that's why. Those tests run on a small version of the model, and FA technically works perfectly for gemma4 (it's only on released checkpoints that the head_dim is too large on full layers), so even if we updated the tests, they would still pass as we don't want a head_dim of 512 on the small models. It's a bit of an awkward situation, where the model supports FA, but FA itself does not support the values used in real checkpoints.
Of course I thought about setting the flag to False, but I decided not to in the end (at least for now), as it's not a Transformers problem, it's a FA/checkpoints problem.

Also, I'm not in favor of adding a complicated per-layer attention system, as this will likely only ever be used for gemma4, and FA kernels will soon be updated to handle gemma4's 512 head_dim, making the effort useless. It will only bloat attention dispatch IMO. SDPA should not really be slower than FA in most cases.

Comment on lines +440 to +442
@unittest.skip("Gemma4 multimodal tiny test config exceeds the 1M common-test size cap")
def test_model_is_small(self):
pass
Member


See this; we really cannot afford it on our CI, which is why we use smaller head dims.

Contributor Author

Qubitium commented Apr 12, 2026

Hey @Qubitium! Setting both head_dim to the same value in tests is simply an easy way to have the usual tests pass, as they expect the same shapes all the time. There's basically no way that the code can break or see a regression for different head_dim per layer, so that's why. Those tests run on a small version of the model, and FA technically works perfectly for gemma4 (it's only on released checkpoints that the head_dim is too large on full layers), so even if we updated the tests, they would still pass as we don't want a head_dim of 512 on the small models. It's a bit of an awkward situation, where the model supports FA, but FA itself does not support the values used in real checkpoints. Of course I thought about setting the flag to False, but I decided not to in the end (at least for now), as it's not a Transformers problem, it's a FA/checkpoints problem.

Also, I'm not in favor of adding a complicated per-layer attention system, as this will likely only ever be used for gemma4, and FA kernels will soon be updated to handle gemma4's 512 head_dim, making the effort useless. It will only bloat attention dispatch IMO. SDPA should not really be slower than FA in most cases.

I have reverted all test-related changes, as you have shown the tests are valid in their context.

My remaining point is only about the runtime contract exposed to users today. Released Gemma4 checkpoints use full-attention layers with global_head_dim=512, and currently released FA kernels do not support this. In the present situation, advertising Gemma4 as FA-compatible is a misconfiguration, even if future FA releases will fix it.

I agree a blanket global disable is not ideal, and FA may release a fix tomorrow for all we know. But until a released flash kernel supports Gemma4's published checkpoint, the current True value is misleading for users. Once FA releases the appropriate kernels, we can simply flip this value back to True.

akoumpa added a commit to NVIDIA-NeMo/Automodel that referenced this pull request Apr 16, 2026
* feat: add mock VLM dataset and Gemma4 pretokenize support

Add build_mock_vlm_dataset for VLM benchmarking and testing without
real data downloads. Generates random PIL images + dummy text in the
standard Automodel conversation format.

Key changes:
- New mock VLM dataset with max_length-driven response generation
- PreTokenizedDatasetWrapper: add truncate mode (labels built before
  truncation) and Gemma4 tensor support (image_position_ids,
  mm_token_type_ids)
- pad_collate_fn: handle mm_token_type_ids padding and
  image_position_ids concatenation for Gemma4
- Recipe: auto-enable pretokenize/truncate when max_length is set
  on the dataset config
- Gemma4 4B mock config with right-padding (fixes SDPA + left-padding
  NaN gradient bug)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>

* fix: rename test_mock.py to test_mock_vlm.py to avoid pytest collection conflict

pytest uses flat module names when __init__.py is absent, causing a
name collision between tests/unit_tests/datasets/llm/test_mock.py
and tests/unit_tests/datasets/vlm/test_mock.py.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>

* fix: handle 1D tensor truncation and add truncate mode test

Address Claude review comments:
- Truncation dict comprehension now also slices 1D tensors whose
  length matches seq_len (e.g. mm_token_type_ids returned as (seq_len,))
- Add test_pretokenized_wrapper_truncate_mode to verify truncation
  clips input_ids, labels, attention_mask, and 1D mm_token_type_ids

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>

* fix: correct mock patch paths for local imports in truncate test

Local imports inside __getitem__ bind from source modules at call time,
so patches must target the source (collate_fns, fake_image) not the
datasets module.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>

* fix: remove unreliable label-content assertion from truncate test

The build_labels_from_template mock cannot reliably override the local
import inside __getitem__. Drop the label-content check and keep the
shape-based assertions which verify truncation works correctly.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>

* fix: remove PreTokenizedDatasetWrapper truncate test

The test requires mocking too many internal imports (local imports
inside __getitem__) to work reliably in CI. The 1D tensor truncation
fix is a straightforward change covered by GPU integration tests.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>

* fix: switch Gemma4 configs to eager attention and right-padding

Gemma4 cannot use flash_attention_2 (global layers have head_dim=512,
FA2 max is 256 — huggingface/transformers#45202). SDPA has a known
NaN bug with sliding window + padding (huggingface/transformers#32390)
that was fixed for Gemma2 but not ported to Gemma3/4. Eager attention
handles all padding correctly.

Also add padding_side: right to the 4 original configs that were
missing it (the mock configs already had it).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>

* fix: remove activation_checkpointing from mock configs to match originals

The original gemma4_4b.yaml and gemma4_26b_a4b_moe.yaml configs do not
use activation checkpointing, so the mock variants should not either.
Also bump 4B mock nproc-per-node from 2 to 8 (needed without AC at
seq_len=2048).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>

* test: add unit tests for PreTokenizedDatasetWrapper truncation and pad_collate_fn mm_token_type_ids

- Test that truncate=True produces exact max_length shapes for input_ids,
  attention_mask, and labels
- Test that labels are not all -100 after truncation
- Test that mm_token_type_ids (1D) is truncated correctly
- Test that pad_collate_fn pads and trims mm_token_type_ids with
  autoregressive shift for both 1D and 2D inputs

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>

* fix: assert mm_token_type_ids presence instead of conditional guard

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>

* fix: mock internal helpers in PreTokenizedDatasetWrapper tests

The wrapper's __getitem__ calls _preload_media, _conversation_has_media,
build_labels_from_template, etc. which need a full processor. Mock these
internals so the tests can run without a real HF processor/model.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>

* fix: place mock labels in first quarter so they survive truncation

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>

* fix: compare inputs_embeds generate against input_ids generate, not uncached decode

The Mamba mixer uses different CUDA kernels for cached (chunk_scan +
selective_state_update) vs uncached (split_conv1d_scan_combined) paths.
These are mathematically equivalent but not bit-identical in bf16,
causing token divergence after the first step. Compare against
generate(input_ids=...) instead, which uses the same cached kernel path.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>

* revert

Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>

---------

Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-authored-by: Alexandros Koumparoulis <153118171+akoumpa@users.noreply.github.com>
Co-authored-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>
Member

@Cyrilvallez Cyrilvallez left a comment


Alright, we can set it to False for now since the kernels are not yet ready.
cc @vasqu @Qubitium let's make a mental note to change this back ASAP when FA kernels are released for head_dim 512!

@github-actions
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: gemma4

@Cyrilvallez Cyrilvallez enabled auto-merge April 23, 2026 04:03
linnanwang pushed a commit to NVIDIA-NeMo/Automodel that referenced this pull request Apr 24, 2026