
Add EXAONE-MoE implementations#43080

Merged
vasqu merged 25 commits into huggingface:main from nuxlear:add-exaone-moe
Feb 3, 2026

Conversation

@nuxlear
Contributor

@nuxlear nuxlear commented Jan 2, 2026

What does this PR do?

Add the EXAONE-MoE architecture for the K-EXAONE model released by LG AI Research.

This PR adds the modeling code for EXAONE-MoE (K-EXAONE), which is available in LG AI Research's fork:
https://github.com/Aim-Highest/transformers
Tests and documentation will be updated.
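
For reference, a minimal usage sketch (the repo ID is the one used in the integration tests later in this PR; the chat template call and generation settings are assumptions, not part of this description):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "LGAI-EXAONE/K-EXAONE-236B-A23B"  # assumption: final hub repo ID
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")

messages = [{"role": "user", "content": "Explain mixture-of-experts in one sentence."}]
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt").to(model.device)

with torch.no_grad():
    output_ids = model.generate(inputs, max_new_tokens=64, do_sample=False)
print(tokenizer.decode(output_ids[0][inputs.shape[-1]:], skip_special_tokens=True))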

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@ArthurZucker

@nuxlear nuxlear marked this pull request as ready for review January 3, 2026 12:46
@nuxlear nuxlear force-pushed the add-exaone-moe branch 2 times, most recently from 930a3b7 to cf89e66 Compare January 10, 2026 11:42
Contributor

@vasqu vasqu left a comment


Leaving some initial comments

  • Missing tests but you already say that they will be added
  • Our MoE implementation has changed for v5 <-- this is the biggest thing to change IMO, but it comes with nice benefits (fullgraph compile, boosted MoE performance, fp8 support out of the box, etc.)

Comment thread docs/source/en/model_doc/exaone_moe.md Outdated
@@ -0,0 +1,200 @@
<!--Copyright 2025 The LG AI Research and The HuggingFace Team. All rights reserved.
Contributor

Suggested change
<!--Copyright 2025 The LG AI Research and The HuggingFace Team. All rights reserved.
<!--Copyright 2026 The LG AI Research and The HuggingFace Team. All rights reserved.

probably elsewhere as well then, happy new year :D

Contributor

To replace :p

("ernie4_5_vl_moe", "TokenizersBackend" if is_tokenizers_available() else None),
("esm", "EsmTokenizer"),
("exaone4", "GPT2Tokenizer" if is_tokenizers_available() else None),
("exaone_moe", "GPT2Tokenizer" if is_tokenizers_available() else None),
Contributor

Suggested change
("exaone_moe", "GPT2Tokenizer" if is_tokenizers_available() else None),

I suspect that you need the tokenizers backend, please see #42894 for more details. Can you double-check?

As a side note, this does not require any changes on the hub repo (we autodetect this). Only if you notice that you indeed need the GPT2 tokenizer will we need to add it to the mapping here.
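
A quick way to double-check which backend AutoTokenizer resolves to (a sketch; the repo ID is taken from the integration tests in this PR, and the sample string is arbitrary):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("LGAI-EXAONE/K-EXAONE-236B-A23B")
print(type(tok).__name__)  # shows whether a tokenizers-backed class or GPT2Tokenizer is picked up

# Round-trip a sample string to sanity-check encode/decode
text = "안녕하세요, EXAONE!"
ids = tok(text)["input_ids"]
print(tok.decode(ids, skip_special_tokens=True))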

Contributor Author

Sure. I will check whether the tokenizer backend works well with EXAONE MoE (and EXAONE 4 as well).

Contributor

Bumping

Contributor

Any update here? Can this be removed?

for i in range(self.num_hidden_layers)
]
if "sliding_window" in self.layer_types:
if "sliding_attention" in self.layer_types:
Contributor

Oh wow, that's a good catch 😅

Comment on lines +111 to +115
self.is_moe_layer = is_moe_layer
if self.is_moe_layer is None:
    self.is_moe_layer = [0] * self.first_k_dense_replace + [1] * (
        self.num_hidden_layers - self.first_k_dense_replace
    )
Contributor

Similar to attention layers (sliding window, full, etc.), we introduced the same mechanism for MoE layers, see

# Default to MoE from the second layer and on
self.mlp_layer_types = mlp_layer_types
if self.mlp_layer_types is None:
self.mlp_layer_types = ["dense"] + ["sparse"] * (self.num_hidden_layers - 1)
layer_type_validation(self.mlp_layer_types, self.num_hidden_layers, attention=False)

Can you change it to that logic?
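
For illustration, the __init__ logic could look roughly like this while keeping the old first_k_dense_replace semantics as the default (a sketch mirroring the snippet above, not the final code):

# Sketch: first k layers dense, remaining layers sparse MoE
self.mlp_layer_types = mlp_layer_types
if self.mlp_layer_types is None:
    self.mlp_layer_types = ["dense"] * self.first_k_dense_replace + ["sparse"] * (
        self.num_hidden_layers - self.first_k_dense_replace
    )
layer_type_validation(self.mlp_layer_types, self.num_hidden_layers, attention=False)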

Comment on lines +127 to +128
if "sliding_attention" in self.layer_types:
self.cache_implementation = "hybrid"
Contributor

Unsure if we still need this

class ExaoneMoEDecoderLayer(OlmoeDecoderLayer):
    def __init__(self, config: ExaoneMoEConfig, layer_idx: int):
        super().__init__(config, layer_idx)
        self.self_attn = ExaoneMoEAttention(config=config, layer_idx=layer_idx)
Contributor

Any reason we need this? It should also be inheritable with modular, no?

def __init__(self, config: ExaoneMoEConfig, layer_idx: int):
    super().__init__(config, layer_idx)
    self.self_attn = ExaoneMoEAttention(config=config, layer_idx=layer_idx)
    self.mlp = ExaoneMoESparseMoEBlock(config) if config.is_moe_layer[layer_idx] else ExaoneMoEMLP(config)
Contributor

See my comment about mlp_layer_types (in the config)
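
i.e. the decoder layer would pick its MLP from the config list, along these lines (a sketch, assuming the config ends up exposing mlp_layer_types):

# Sketch: select the MLP flavor from config.mlp_layer_types instead of config.is_moe_layer
if config.mlp_layer_types[layer_idx] == "sparse":
    self.mlp = ExaoneMoESparseMoEBlock(config)
else:
    self.mlp = ExaoneMoEMLP(config)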

"attentions": ExaoneMoEAttention,
"router_logits": ExaoneMoESparseMoEBlock,
}
_can_compile_fullgraph = False
Contributor

See

_can_compile_fullgraph = (
    is_grouped_mm_available()
)  # https://huggingface.co/docs/transformers/experts_interface#torchcompile

If we get the conversion working, we can compile fullgraph
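
For context, fullgraph here means the whole forward pass can be captured by torch.compile without graph breaks, e.g. (a generic sketch, not code from this PR; the repo ID is an assumption):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "LGAI-EXAONE/K-EXAONE-236B-A23B"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")

inputs = tokenizer("Hello", return_tensors="pt").to(model.device)
model.forward = torch.compile(model.forward, fullgraph=True)  # fails fast on any graph break
with torch.no_grad():
    outputs = model(**inputs)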

Comment thread src/transformers/models/exaone_moe/modular_exaone_moe.py
Comment on lines +300 to +309
class ExaoneMoEForSequenceClassification(Exaone4ForSequenceClassification):
    pass


class ExaoneMoEForTokenClassification(Exaone4ForTokenClassification):
    pass


class ExaoneMoEForQuestionAnswering(Exaone4ForQuestionAnswering):
    pass
Contributor

Nit: Do we really need this? If we can, I'd like to avoid these

@github-actions
Contributor

View the CircleCI Test Summary for this PR:

https://huggingface.co/spaces/transformers-community/circle-ci-viz?pr=43080&sha=e7d79e

@vasqu
Contributor

vasqu commented Jan 26, 2026

@nuxlear just ping me again when it's ready for review

@nuxlear
Contributor Author

nuxlear commented Jan 26, 2026

@vasqu I think it's ready for review, but make fix-repo does not seem to be consistent (it adds extra tabs after overwriting configuration_exaone_moe.py, which causes a CI failure).

Contributor

@vasqu vasqu left a comment


Looks already super clean, just a few small nits plus a dummy model for our CI.

Comment thread docs/source/en/model_doc/exaone_moe.md Outdated
@@ -0,0 +1,200 @@
<!--Copyright 2025 The LG AI Research and The HuggingFace Team. All rights reserved.
Contributor

To replace :p

Comment thread docs/source/en/model_doc/exaone_moe.md Outdated
The K-EXAONE model is compatible with both OpenAI and HuggingFace tool calling specifications.
The example below demonstrates tool calling using HuggingFace’s docstring-to-tool-schema utility.

Please check the [example file](examples/example_output_search.txt) for an example of a search agent conversation using K-EXAONE.
Contributor

Wrong link?

Contributor Author

Oh, I'll fix it :)

("ernie4_5_vl_moe", "TokenizersBackend" if is_tokenizers_available() else None),
("esm", "EsmTokenizer"),
("exaone4", "GPT2Tokenizer" if is_tokenizers_available() else None),
("exaone_moe", "GPT2Tokenizer" if is_tokenizers_available() else None),
Contributor

Bumping

Comment thread src/transformers/models/exaone_moe/modular_exaone_moe.py
Comment thread src/transformers/models/exaone_moe/modular_exaone_moe.py
Comment thread src/transformers/models/exaone_moe/modular_exaone_moe.py

@require_torch
class ExaoneMoeIntegrationTest(unittest.TestCase):
    TEST_MODEL_ID = "LGAI-EXAONE/K-EXAONE-236B-A23B"
Contributor

This will be too big for our CI, can we create a dummy model instead? (up to 24GB of VRAM, as it's an A10 GPU)
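
In case it helps, dummy checkpoints are usually produced by instantiating a randomly initialized tiny config and pushing it, roughly like this (a sketch; the class names follow this PR, all sizes and the repo name are placeholders, and the exact argument names should be taken from the final ExaoneMoeConfig):

from transformers import ExaoneMoeConfig, ExaoneMoeForCausalLM

# Tiny placeholder sizes; MoE-specific fields (number of experts, etc.) would
# also need small values in the real dummy config.
config = ExaoneMoeConfig(
    vocab_size=1024,
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=4,
    num_attention_heads=4,
    num_key_value_heads=2,
)
model = ExaoneMoeForCausalLM(config)
model.push_to_hub("your-username/exaone-moe-dummy")  # placeholder repo name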

Contributor

The tests are nice tho!

Contributor Author

Is it necessary to upload a dummy model to the HF hub?
We don't have a proper model for this, and it feels a bit awkward to upload dummy weights under our official organization.
Would it be okay if I uploaded it under my personal account instead?

Contributor

Yes sure, I can also move it to our internal testing repo afterwards

Contributor Author

https://huggingface.co/nuxlear/EXAONE-MoE-Dummy-7B-A1B
Just uploaded, but I need to run more tests with it.

Comment on lines +119 to +129
@slow
@require_torch_large_accelerator
def test_model_generation_beyond_sliding_window_flash(self):
    EXPECTED_OUTPUT_TOKEN_IDS = [21605, 2711]
    input_ids = [72861, 2711] * 2048
    model = self.get_model()
    input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device)

    with torch.no_grad():
        generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0)
    self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-2:].tolist())
Contributor

We would need to change get_model to pass the attention implementation; it currently loads with sdpa this way. We could also just rename the test.
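
Roughly, the helper would have to request FlashAttention explicitly when loading (a sketch; attn_implementation is the standard from_pretrained argument, and the model ID mirrors the current TEST_MODEL_ID):

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "LGAI-EXAONE/K-EXAONE-236B-A23B",
    attn_implementation="flash_attention_2",  # otherwise the test silently runs with sdpa
    device_map="auto",
)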

input_ids = input_ids.to(model.model.embed_tokens.weight.device)

with torch.no_grad():
    generated_ids = model.generate(**input_ids, max_new_tokens=20, temperature=0)
Contributor

Suggested change
generated_ids = model.generate(**input_ids, max_new_tokens=20, temperature=0)
generated_ids = model.generate(**input_ids, max_new_tokens=20, do_sample=False)

nit: just our preferred way to do it

Contributor

below as well

sliding_window_pattern=4,
layer_types=None,
mlp_layer_types=None,
first_k_dense_replace=1,
Contributor

Ah missed this: this should be mlp layer types with a list of the types. (Similar to layer types for attention)

Contributor Author

You mean one of 'dense' and 'sparse', right?

Contributor

Yes, exactly

@vasqu
Contributor

vasqu commented Jan 28, 2026

You can ping me when it's ready for review

@nuxlear
Contributor Author

nuxlear commented Jan 28, 2026

Should I update the test code with a dummy model? I think everything else is ready.

@vasqu
Contributor

vasqu commented Jan 28, 2026

Should I update the test code with a dummy model? I think everything else is ready.

Yes, please 🙏 taking a look in a second then

@nuxlear
Contributor Author

nuxlear commented Jan 28, 2026

It seems the current dummy model needs to be updated, so I’ll notify you when it’s ready.

Contributor

@vasqu vasqu left a comment


Leaving some small last comments, IMO it looks very much ready! Let's clean up the config a tad more and wrap up the integration tests, then we are good to go.

Just ping me again when ready, great work

("ernie4_5_vl_moe", "TokenizersBackend" if is_tokenizers_available() else None),
("esm", "EsmTokenizer"),
("exaone4", "GPT2Tokenizer" if is_tokenizers_available() else None),
("exaone_moe", "GPT2Tokenizer" if is_tokenizers_available() else None),
Contributor

Any update here? Can this be removed?

@@ -0,0 +1,27 @@
# Copyright 2025 The LG AI Research and The HuggingFace Team. All rights reserved.
Contributor

Suggested change
# Copyright 2025 The LG AI Research and The HuggingFace Team. All rights reserved.
# Copyright 2026 The LG AI Research and The HuggingFace Team. All rights reserved.

Comment on lines +79 to +89
sliding_window_pattern (`str`, *optional*, defaults to 4):
    The pattern to use for sliding window attention. Can be one of:
        - `None`: No sliding window attention is used
        - `int`: Every `sliding_window` layers, use global attention, else use local attention.
        - `str`: A sequence of "L" (local attention) and "G" (global attention) characters that defines the
          attention pattern. The pattern starts from layer 0 and repeats every `sliding_window` layers. The
          final layer always uses global attention regardless of the pattern.
    For instance, sliding_window_pattern="LLLG" same as sliding_window=4, which means:
        - Layer 0, 1, 2: local attention,
        - Layer 3: global attention,
        ...(repeated)
Contributor

I'd like to avoid this if possible and just use layer_types directly. We are also starting to do the same for MLP layers (MoE), and it gives more flexibility with other attention flavors (e.g. linear attention, gated delta net).
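
For illustration, a pattern like sliding_window_pattern="LLLG" over 8 layers translates into an explicit layer_types list along these lines (a sketch; "sliding_attention" matches the check already used in this config, and "full_attention" is assumed to be the counterpart name):

# Sketch: expand an "LLLG" pattern into explicit layer_types, forcing the final
# layer to full attention as the docstring describes
num_hidden_layers = 8  # example value
pattern = "LLLG"
layer_types = [
    "sliding_attention" if pattern[i % len(pattern)] == "L" else "full_attention"
    for i in range(num_hidden_layers)
]
layer_types[-1] = "full_attention"
print(layer_types)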

Contributor Author

I understand, and it would be better to remove them.

However, since these config fields (including the ones below) are often consumed by other libraries such as llama.cpp, they should remain in the model's config.json.

If that is acceptable, we have no reason to keep them in the config implementation. 😃

Contributor

Yea, no worries, not super important 👍 it would just be the ideal case.

Comment on lines +94 to +96
first_k_dense_replace (`int`, *optional*, defaults to 1):
    Number of dense layers in shallow layers (embed->dense->dense->...->dense->moe->moe...->lm_head).
                                                    \--k dense layers--/
Contributor

In the same spirit to my comment before, let's remove this and only use mlp layer types directly

from ...configuration_utils import PreTrainedConfig, layer_type_validation


class ExaoneMoeConfig(PreTrainedConfig):
Contributor

Probably needs to sync with main, I recently made the Rope mixin explicit for models that use it - can you check

E.g.

class Exaone4Config(PreTrainedConfig, RotaryEmbeddingConfigMixin):
(modular should do it automatically for you, just need to merge with main and apply modular again)

Comment on lines +226 to +228
PreTrainedConfig.__init__(
    self, bos_token_id=bos_token_id, eos_token_id=eos_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs
)
Contributor

Sorry I commented directly on the config file but should be done here ofc

@vasqu
Contributor

vasqu commented Jan 29, 2026

Also sorry about the CI, it's still flaky here and there but it should be more stable on main

@nuxlear
Contributor Author

nuxlear commented Jan 30, 2026

I’ve updated the dummy test model and the docstrings.
Could you please do a final check? @vasqu

@nuxlear
Contributor Author

nuxlear commented Feb 2, 2026

@ArthurZucker @Rocketknight1 could you kindly review this PR?

@ArthurZucker
Collaborator

Yes!

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@vasqu
Contributor

vasqu commented Feb 2, 2026

run-slow: exaone4, exaone_moe

@github-actions
Contributor

github-actions Bot commented Feb 2, 2026

This comment contains run-slow, running the specified jobs:

models: ["models/exaone4", "models/exaone_moe"]
quantizations: []

@github-actions
Contributor

github-actions Bot commented Feb 2, 2026

CI Results

Workflow Run ⚙️

Commit Info

Context Commit Description
RUN b1f6aff2 merge commit
PR 0e1e5bc6 branch commit
main 751cff7c base commit

Model CI Report

7 new failed tests from this PR 😭

  • exaone4:
    tests/models/exaone4/test_modeling_exaone4.py::Exaone4ModelTest::test_generate_compilation_all_outputs
    tests/models/exaone4/test_modeling_exaone4.py::Exaone4ModelTest::test_generate_compile_model_forward_fullgraph

  • exaone_moe:
    tests/models/exaone_moe/test_modeling_exaone_moe.py::ExaoneMoeModelTest::test_cpu_offload
    tests/models/exaone_moe/test_modeling_exaone_moe.py::ExaoneMoeModelTest::test_disk_offload_bin
    tests/models/exaone_moe/test_modeling_exaone_moe.py::ExaoneMoeModelTest::test_disk_offload_safetensors
    tests/models/exaone_moe/test_modeling_exaone_moe.py::ExaoneMoeIntegrationTest::test_model_generation_beyond_sliding_window_flash
    tests/models/exaone_moe/test_modeling_exaone_moe.py::ExaoneMoeIntegrationTest::test_model_logits

Contributor

@vasqu vasqu left a comment


Some last comments from my side, fixed a few smaller issues (checking with run slow again in a second)

Comment on lines +89 to +94
bos_token_id (`int`, *optional*, defaults to 1):
    Beginning of stream token id.
eos_token_id (`int`, *optional*, defaults to 53):
    End of stream token id.
pad_token_id (`int`, *optional*, defaults to 0):
    Padding token id.
Contributor

Took this from https://huggingface.co/LGAI-EXAONE/K-EXAONE-236B-A23B/blob/main/generation_config.json

A bit confused since the values were different; it would be nice if you could confirm these, or whether it should be the previous values, see 0e1e5bc

Contributor Author

We use 53 as the end-of-turn token, while 2 is used as EOS.
Either can be used as the default value, so you can set it to 53.


    return cls.model

def test_model_logits(self):
Contributor

Logits don't match on our CI, I think it's a GPU diff so let me know if I should update them myself

Contributor Author

I agree with that. It looks like you’ll need to update them in your CI environment.

Contributor

Gotcha, let me update them tomorrow then 👍 (and also copy the repo to our internal testing)

@vasqu
Contributor

vasqu commented Feb 2, 2026

run-slow: exaone4, exaone_moe

@github-actions
Contributor

github-actions Bot commented Feb 2, 2026

This comment contains run-slow, running the specified jobs:

models: ["models/exaone4", "models/exaone_moe"]
quantizations: []

@github-actions
Contributor

github-actions Bot commented Feb 2, 2026

CI Results

Workflow Run ⚙️

Commit Info

Context Commit Description
RUN 24f1d500 merge commit
PR 1b3b159f branch commit
main 78e4f885 base commit

Model CI Report

1 new failed tests from this PR 😭

  • exaone_moe:
    tests/models/exaone_moe/test_modeling_exaone_moe.py::ExaoneMoeIntegrationTest::test_model_logits

Collaborator

@ArthurZucker ArthurZucker left a comment


Nice! 🤗

@github-actions
Contributor

github-actions Bot commented Feb 3, 2026

[For maintainers] Suggested jobs to run (before merge)

run-slow: auto, exaone4, exaone_moe

@vasqu
Contributor

vasqu commented Feb 3, 2026

run-slow: exaone_moe

@github-actions
Contributor

github-actions Bot commented Feb 3, 2026

This comment contains run-slow, running the specified jobs:

models: ["models/exaone_moe"]
quantizations: []

@github-actions
Contributor

github-actions Bot commented Feb 3, 2026

CI Results

Workflow Run ⚙️

Commit Info

Context Commit Description
RUN 7910da44 merge commit
PR 87f6ca4f branch commit
main 01e860eb base commit

✅ No failing test specific to this PR 🎉 👏 !

Contributor

@vasqu vasqu left a comment


I updated the values and made a copy of the repo to our internal testing repos so feel free to remove your private one @nuxlear

Merging in a second, thanks a lot for iterating 🤗

@vasqu vasqu enabled auto-merge (squash) February 3, 2026 16:32
@vasqu vasqu disabled auto-merge February 3, 2026 16:34
@vasqu vasqu merged commit 379ec6b into huggingface:main Feb 3, 2026
26 checks passed