Changes for transformers 5 weight conversion #3083

Merged

BenjaminBossan merged 15 commits into huggingface:main from BenjaminBossan:transformers-weight-conversion-additions on Apr 1, 2026

Conversation

@BenjaminBossan
Member

@BenjaminBossan BenjaminBossan commented Mar 5, 2026

See accompanying huggingface/transformers#44478.

  • better handling of swapped in and out features
  • move PEFT config update functions to PEFT
  • move PEFT-specific weight conversion logic to PEFT

Note that the newly added tests will fail until a new transformers release with the linked PR is out. This should be v5.4, so the corresponding tests only run with that transformers version. I locally tested with the current main branch and the tests pass.
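
For context, version-gating like this is usually implemented with a pytest skip marker; a minimal sketch of the pattern (the marker name and version bound are assumptions, not the PR's actual test code):

```python
import pytest
import transformers
from packaging import version

# Hypothetical marker: skip unless transformers is new enough to contain
# the weight conversion changes from huggingface/transformers#44478.
requires_transformers_v5_4 = pytest.mark.skipif(
    version.parse(transformers.__version__) < version.parse("5.4.0"),
    reason="requires transformers >= 5.4 with the weight conversion PR",
)

@requires_transformers_v5_4
def test_weight_conversion_roundtrip():
    ...
```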

- better handling of swapped in and out features
- move PEFT config update functions to PEFT
This allows the weight conversion to be correctly applied without going
through transformer_model.load_adapter.
@BenjaminBossan BenjaminBossan marked this pull request as ready for review March 6, 2026 17:03
Move weight conversion code to its own module.
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@BenjaminBossan
Member Author

@githubnemo The PR should now be ready for review.

@BenjaminBossan BenjaminBossan changed the title [WIP] Changes for transformers 5 weight conversion Changes for transformers 5 weight conversion Mar 11, 2026
- always apply in/out feature swapping for MoE params
- add a test for this with Qwen3 MoE
- expose swapping argument to provide escape hatch
Comment thread on src/peft/tuners/lora/config.py (outdated)
Whether to tie weights or not after peft initialization. This will ensure that the adapters added to the
tied layers are also tied. This is only applicable for layers passed via `modules_to_save` and
`target_modules`.
param_wrapper_swap_in_out_features (`bool`, *optional*)

Is this parameter used to resolve #3112? If so, maybe automatic detection would be better?

Member Author

In my latest commit, I changed the code to use module.is_transposed.
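
For reference, a minimal sketch of what detection based on such a flag could look like (all names besides module.is_transposed are hypothetical; this is not the actual PEFT code):

```python
def get_in_out_features(module):
    """Return (in_features, out_features) for a wrapped 2-dim weight.

    Hypothetical helper: assumes `module` exposes `weight` and the
    `is_transposed` attribute mentioned above.
    """
    rows, cols = module.weight.shape
    if module.is_transposed:
        # Weight is stored as (in_features, out_features).
        return rows, cols
    # Default nn.Linear layout: weight is (out_features, in_features).
    return cols, rows
```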

@jeejeelee

jeejeelee commented Mar 23, 2026

I tested your branch; the saved LoRA weights for qwen35-moe still have the same issues, see #3112.

@BenjaminBossan
Member Author

> I tested your branch; the saved LoRA weights for qwen35-moe still have the same issues, see #3112.

Could you please show a small reproducer for that error?

@jeejeelee

@BenjaminBossan You should be able to reproduce this issue easily by training Qwen3.5 MoE with LoRA.

@BenjaminBossan
Member Author

@jeejeelee Could you at least describe what goes wrong? Loading the trained LoRA weights or something else? What error do you get? What transformers version are you using?

@jeejeelee

@BenjaminBossan Hmm, I think I've already described it clearly in #3112.

@BenjaminBossan
Member Author

@jeejeelee LMK if I overlooked something, but I didn't find the information that I would need to try to reproduce the error and fix it. You mentioned:

> I found that the generated LoRA weights were incorrect

and

> the saved LoRA weights for qwen35-moe still have the same issues

and linked to a couple of those weights. But I'm missing the information on how these weights were created, what error you got (full stacktrace), and what versions were used (especially of transformers). I can understand if it's not possible for you to provide a full reproducer, but if you want me to take a look at your issue, you have to provide these missing pieces of information or point me to where you've shared them.

@BenjaminBossan
Member Author

@jeejeelee I wrote a test that loads a small Qwen 3.5 MoE model and applies LoRA to normal linear layers as well as to MoE layers. Then it saves the LoRA weights and loads them again, checking that the outputs remain the same. The test passes (I used transformers 5.4.0). So without further information, I cannot replicate any error with this model architecture.

Qwen 3.5 MoE test:

```python
import torch
from transformers import Qwen3_5MoeForConditionalGeneration, Qwen3_5MoeConfig
from peft import LoraConfig, PeftModel, get_peft_model

def create_small_qwen():
    torch.manual_seed(0)
    config = Qwen3_5MoeConfig(
        image_token_id=248056,
        video_token_id=248057,
        vision_start_token_id=248053,
        vision_end_token_id=248054,
        tie_word_embeddings=False,
        text_config={
            "attention_bias": False,
            "attention_dropout": 0.0,
            "attn_output_gate": True,
            "eos_token_id": 248044,
            "full_attention_interval": 4,
            "head_dim": 16,
            "hidden_act": "silu",
            "hidden_size": 64,
            "initializer_range": 0.02,
            "layer_types": [
                "linear_attention",
                "linear_attention",
                "linear_attention",
                "full_attention",
                "linear_attention",
                "linear_attention",
                "linear_attention",
                "full_attention",
            ],
            "linear_conv_kernel_dim": 4,
            "linear_key_head_dim": 16,
            "linear_num_key_heads": 2,
            "linear_num_value_heads": 4,
            "linear_value_head_dim": 16,
            "max_position_embeddings": 1024,
            "mlp_only_layers": [],
            "model_type": "qwen3_5_moe_text",
            "moe_intermediate_size": 32,
            "mtp_num_hidden_layers": 1,
            "mtp_use_dedicated_embeddings": False,
            "num_attention_heads": 4,
            "num_experts": 8,
            "num_experts_per_tok": 2,
            "num_hidden_layers": 8,
            "num_key_value_heads": 2,
            "rms_norm_eps": 1e-06,
            "router_aux_loss_coef": 0.001,
            "shared_expert_intermediate_size": 32,
            "use_cache": True,
            "vocab_size": 248320,
            "mamba_ssm_dtype": "float32",
            "rope_parameters": {
                "mrope_interleaved": True,
                "mrope_section": [3, 3, 2],
                "rope_type": "default",
                "rope_theta": 10000000,
                "partial_rotary_factor": 0.25,
            },
        },
        vision_config={
            "deepstack_visual_indexes": [],
            "depth": 2,
            "hidden_act": "gelu_pytorch_tanh",
            "hidden_size": 64,
            "in_channels": 3,
            "initializer_range": 0.02,
            "intermediate_size": 128,
            "model_type": "qwen3_5_moe",
            "num_heads": 4,
            "num_position_embeddings": 2304,
            "out_hidden_size": 64,
            "patch_size": 16,
            "spatial_merge_size": 2,
            "temporal_patch_size": 2,
        },
    )

    model = Qwen3_5MoeForConditionalGeneration(config).to(0)
    return model

def main():
    inputs = torch.arange(10).view(1, -1).to(0)
    model = create_small_qwen()
    with torch.inference_mode():
        out_base = model(inputs).logits

    config = LoraConfig(
        target_modules=["q_proj", "v_proj", "in_proj_qkv"],
        target_parameters=["experts.gate_up_proj", "experts.down_proj"],
        init_lora_weights=False,
    )
    torch.manual_seed(0)
    model = get_peft_model(model, config)
    model.print_trainable_parameters()
    with torch.inference_mode():
        out_lora = model(inputs).logits

    path = "/tmp/peft/qwen3_5moe"
    model.save_pretrained(path)

    del model

    model = create_small_qwen()
    model = PeftModel.from_pretrained(model, path)
    with torch.inference_mode():
        out_loaded = model(inputs).logits

    assert not torch.allclose(out_base, out_lora)
    assert torch.allclose(out_lora, out_loaded)

if __name__ == "__main__":
    main()
```

@jeejeelee

jeejeelee commented Apr 1, 2026

GPT-OSS

Let me describe gpt-oss-20b first.

The gpt-oss config is: rank=8, moe_number=32, moe_intermediate_size=2880, hidden_size=2880,
and the shapes of the expert LoRA weights are as follows; these shapes are correct:

| name | shape | note |
| --- | --- | --- |
| prefix.experts.base_layer.lora_A.weight | [256, 2880] | [8x32, 2880] |
| prefix.experts.base_layer.lora_B.weight | [5760, 256] | [2880x2, 8x32] |
| prefix.experts.lora_A.weight | [256, 2880] | [8x32, 2880] |
| prefix.experts.lora_B.weight | [2880, 256] | [2880, 8x32] |

Qwen3.5

When I tried to generate LoRA weights for qwen35-moe, I got incorrect shapes.

  • Qwen3.5-35B-A3B config: rank=8, moe_number=256, moe_intermediate_size=512, hidden_size=2048

| name | shape | note |
| --- | --- | --- |
| prefix.experts.base_layer.lora_A.weight | [2048, 1024] | should be [8x256, 2048] |
| prefix.experts.base_layer.lora_B.weight | [2048, 2048] | should be [512x2, 8x256] |
| prefix.experts.lora_A.weight | [2048, 2048] | should be [256x8, 512] |
| prefix.experts.lora_B.weight | [512, 2048] | should be [2048, 8x256] |

Reproduction code:

```python
import torch
from transformers import Qwen3_5MoeForConditionalGeneration
from peft import LoraConfig, TaskType, get_peft_model


model_name_or_path="Qwen/Qwen3.5-35B-A3B"

peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    inference_mode=True,
    r=8,
    lora_alpha=32,
    target_modules="all-linear",
    use_rslora=False,
    use_dora=False,
    target_parameters=[
        "mlp.experts.gate_up_proj",
        "mlp.experts.down_proj",
    ],
)

model = Qwen3_5MoeForConditionalGeneration.from_pretrained(
    model_name_or_path,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
).eval()
print(model)
model = get_peft_model(model, peft_config)

print(model)

model.save_pretrained(
    save_directory="qwen35-moe-lora-moe",
    safe_serialization=True,
    save_embedding_layers=False,
)
```

  • transformers version: 5.4.0
  • peft: this branch

@BenjaminBossan
Member Author

@jeejeelee Thank you for providing further information. So IIUC, what you take issue with is that the shapes of the LoRA weights are not what you expect -- there is no actual error from running the code, the shapes just look incorrect.

Regarding these shapes, they can differ from what you expect for a few reasons. First of all, even for expert layers with 3-dim parameters, PEFT flattens the parameters to 2-dim. This is in keeping with the PEFT convention of using an nn.Linear:

```python
self.lora_A[adapter_name] = nn.Linear(self.in_features, r * self.num_experts, bias=False)
self.lora_B[adapter_name] = nn.Linear(r * self.num_experts, self.out_features, bias=lora_bias)
```

Furthermore, for some models, we have to deal with situations where the original weights were fused and we have to keep the checkpoint compatible. In that case, we employ different fusing strategies, which are defined in transformers_weight_conversion.py of this PR. The picture below illustrates that:

[Figure: lora-fusing-2, illustrating the different weight fusing strategies]

This shouldn't be relevant to Qwen3.5, as it didn't have a change in weight structure, but it's good to be aware that it can factor in.
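
To make the fusing idea concrete, here is a minimal sketch of one such strategy: splitting a LoRA delta trained on a fused gate_up projection into separate gate and up deltas. The sizes and the row-wise concatenation layout are illustrative assumptions, not the exact logic of transformers_weight_conversion.py:

```python
import torch

# Assumed illustrative sizes: hidden=2048, intermediate=512, rank=8.
hidden, intermediate, r = 2048, 512, 8

# LoRA trained on a fused gate_up projection: out_features = 2 * intermediate.
lora_A = torch.randn(r, hidden)            # maps input -> rank
lora_B = torch.randn(2 * intermediate, r)  # maps rank -> fused output

# Only the output dimension is fused, so lora_B can be split row-wise
# while lora_A is shared by both halves.
lora_B_gate, lora_B_up = lora_B.split(intermediate, dim=0)

# The split deltas stack back up to the fused delta.
delta_fused = lora_B @ lora_A
delta_split = torch.cat([lora_B_gate @ lora_A, lora_B_up @ lora_A], dim=0)
assert torch.allclose(delta_fused, delta_split)
```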

Let's look at one concrete example of the experts in Qwen3.5 MoE, model.model.language_model.layers[0].mlp.experts.gate_up_proj. We have:

  • gate_up_proj.shape: 256, 1024, 2048 (experts, in_features, moe_intermediate_size)
  • lora_A.default.weight.shape: 2048, 1024 (2048 = 8 * 256 = lora rank * num_experts, in_features)
  • lora_B.default.weight.shape: 2048, 2048 (moe_intermediate_size, 2048 = 8 * 256 = lora rank * num_experts)

These shapes thus look correct to me.
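
As a quick sanity check, the expected shapes follow directly from the nn.Linear construction quoted above (plugging in the numbers from this example):

```python
# Numbers from the gate_up_proj example above.
num_experts, r, in_features, out_features = 256, 8, 1024, 2048

# lora_A is nn.Linear(in_features, r * num_experts),
# so its weight has shape (r * num_experts, in_features).
assert (r * num_experts, in_features) == (2048, 1024)

# lora_B is nn.Linear(r * num_experts, out_features),
# so its weight has shape (out_features, r * num_experts).
assert (out_features, r * num_experts) == (2048, 2048)
```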

All of the above can lead to shapes that are unexpected at first. However, with the testing we do, we should hopefully ensure that these operations are correctly applied. If you have a concrete example where the model output is incorrect or the model doesn't train as expected, please let us know; just be aware that shapes alone can be misleading.

(Btw.: trust_remote_code=True is unnecessary here and it should only be used if absolutely required)

Collaborator

@githubnemo githubnemo left a comment
We discussed this offline often enough. LGTM.

@BenjaminBossan BenjaminBossan merged commit 5356277 into huggingface:main on Apr 1, 2026
10 checks passed
@BenjaminBossan BenjaminBossan deleted the transformers-weight-conversion-additions branch April 1, 2026 13:04