
Fix PEFT x MoEs #43261

Merged
ArthurZucker merged 67 commits into main from peft-x-moes on Jan 24, 2026
Conversation

@ArthurZucker (Collaborator) commented Jan 13, 2026

What does this PR do?

Fixes #42491

This should serve as an example of how the weight loader can be reused in other projects.
The content will probably be upstreamed to PEFT!
Current status:
[image: current status]

What to expect:

  1. You have already loaded a transformers model.
  2. You want to load only the PEFT adapters:
    a. we check the weight conversion mapping
    b. if there are ops, we replace them with the mapped PEFT ops
    c. we collect lora_A and lora_B together and process them like so:
import torch

# `Concatenate` is the base conversion op from the transformers weight loader
# (assumed import path).
from transformers.core_model_loading import Concatenate


class PeftConcatenate(Concatenate):
    """Convert per-expert LoRA weights to merged (fused gate_up) MoE weights."""

    @torch.no_grad
    def convert(
        self,
        input_dict: dict[str, list[torch.Tensor]],
        source_patterns: list[str],
        target_patterns: list[str],
        **kwargs,
    ) -> dict[str, list[torch.Tensor]]:
        # Group the collected per-expert tensors; the ordering assumes the
        # gate (w1) projection pattern is collected before the up (w3) one.
        lora_a, lora_b = [], []
        for key, tensors in input_dict.items():
            if "lora_A" in key:
                lora_a.append(tensors)
            elif "lora_B" in key:
                lora_b.append(tensors)
        # lora_A: per expert, concatenate gate and up along the rank dimension,
        # so the fused gate_up projection ends up with rank 2 * r.
        lora_a_out = torch.stack(
            [torch.cat([gate, up], dim=0) for gate, up in zip(lora_a[0], lora_a[1])], dim=0
        )
        # lora_B: per expert, a block-diagonal matrix keeps the gate and up
        # halves of the fused output wired to their own rank-r block.
        lora_b_out = torch.stack(
            [torch.block_diag(gate, up) for gate, up in zip(lora_b[0], lora_b[1])], dim=0
        )
        return {
            target_patterns[0] + ".lora_A.weight": [lora_a_out],
            target_patterns[0] + ".lora_B.weight": [lora_b_out],
        }
    d. The fused gate_up lora_A/lora_B outputs are loaded into the model (see the usage sketch below).

Cf. the diagram (credits to @BenjaminBossan for the pic):
[image]
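For reference, a minimal usage sketch of the flow above. The Mixtral checkpoint id is real, but the adapter repo id is hypothetical and only for illustration:

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1")

# Loading the adapter triggers the weight conversion mapping: per-expert
# lora_A/lora_B tensors are collected and fused before being loaded.
model.load_adapter("some-user/mixtral-lora")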

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@ArthurZucker ArthurZucker marked this pull request as ready for review January 14, 2026 15:56
@BenjaminBossan (Member)

Just leaving some general comments on this PR:

  1. I wouldn't override __setattr__ in WeightTransform, as it's error-prone. Instead, change source_patterns and target_patterns to @property with an appropriate @setter and move the logic there (see the sketch after this list).
  2. The hotswap code path has been removed from load_adapter, it needs to be added back in.
  3. Similarly, there was a if peft_config.inference_mode: self.eval() call there that's now missing.
  4. As discussed, the currently hard-coded Mixtral conversion ops need to be moved to a mapping that is only called for Mixtral.
  5. We should add some sanity checks, e.g. if the expert layer is targeted and the PEFT adapter is not LoRA, raise a helpful error message.
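For point 1, a minimal sketch of the property-based approach; the list-wrapping shown is an assumed stand-in for whatever logic __setattr__ currently performs, not the actual WeightTransform internals:

class WeightTransform:
    """Sketch: validating @property setters instead of overriding __setattr__."""

    def __init__(self, source_patterns, target_patterns):
        # Both assignments go through the setters below.
        self.source_patterns = source_patterns
        self.target_patterns = target_patterns

    @property
    def source_patterns(self) -> list[str]:
        return self._source_patterns

    @source_patterns.setter
    def source_patterns(self, value) -> None:
        # Normalization/validation formerly in __setattr__ lives here.
        self._source_patterns = [value] if isinstance(value, str) else list(value)

    @property
    def target_patterns(self) -> list[str]:
        return self._target_patterns

    @target_patterns.setter
    def target_patterns(self, value) -> None:
        self._target_patterns = [value] if isinstance(value, str) else list(value)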

nemo and others added 7 commits January 23, 2026 16:12
* the rank needed to be set to 2*r for the concatenated gate_up projection
  parameter so that PEFT allocates 2*r and matches the converted
  weights (using rank_pattern; see the config sketch below)

* the weights needed to be transposed to match their counterparts

* MoE in PEFT assumes (experts, in, out), but the Mixtral MoE is transposed,
  so we need to patch this assumption in PEFT for now
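A sketch of what the rank fix from the first bullet looks like on the PEFT side; the module names are illustrative of the fused Mixtral expert layer, not necessarily the exact targets used in the PR:

from peft import LoraConfig

r = 8  # per-projection rank before fusion

config = LoraConfig(
    r=r,
    target_modules=["gate_up_proj", "down_proj"],
    # The fused gate_up projection stacks two rank-r adapters (gate + up),
    # so PEFT must allocate 2 * r for it to match the converted weights.
    rank_pattern={r".*experts\.gate_up_proj": 2 * r},
)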
@github-actions (Contributor)

View the CircleCI Test Summary for this PR:

https://huggingface.co/spaces/transformers-community/circle-ci-viz?pr=43261&sha=92d0fa

@ArthurZucker (Collaborator, Author) left a comment:
Ready!

@ArthurZucker ArthurZucker merged commit 3af2eb7 into main Jan 24, 2026
22 of 26 checks passed
@ArthurZucker ArthurZucker deleted the peft-x-moes branch January 24, 2026 10:06
githubnemo pushed a commit to githubnemo/peft that referenced this pull request Feb 27, 2026
Continuation of PR huggingface#2995.
Background: huggingface/transformers#42491 and huggingface/transformers#43261.

This change implements conversion operations for some existing
PEFT checkpoints, mainly dealing with the fusing of MoE layers in transformers v5.

The code added here is currently a copy of the code that exists in transformers;
the transformers copy is supposed to be gated as soon as PEFT v0.19 is released,
at which point transformers will use the code in this PR.

The copying makes testing a bit difficult since there's currently no routing
depending on the PEFT version in transformers. Older transformers versions, therefore,
need patching to forcefully use the PEFT implementation of the conversion.
As soon as the routing is implemented in transformers, we can conditionally
disable the patching (see the version-gate sketch below).
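The version routing described above could look something like this minimal sketch; the helper name is hypothetical and no such gate exists yet in either library:

from packaging import version

import peft

def peft_ships_conversion_ops() -> bool:
    """Hypothetical gate: True once the installed PEFT carries the conversion ops."""
    # PEFT v0.19 is the release expected to ship the upstreamed implementation.
    return version.parse(peft.__version__) >= version.parse("0.19.0")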

Development

Successfully merging this pull request may close these issues.

The LoRA model trained with qwen3_moe on hf4.x cannot be used on the current main branch (hf5.x).
