
Conversation

@bukejiyu (Collaborator) commented Aug 14, 2025

Depends on paddle develop builds from 2025.8.17 onward.

- Adds support for a dynamic (inflight) quantization loader.
- Load performance for MoE-series models improves by roughly 30%.
- Supported models: [qwen3 / qwen3moe / deepseekv3 / ernie text / ernie vl]

| Model | Model type | Quant | TP size | Loader | Peak memory | Load time | Accuracy |
|---|---|---|---|---|---|---|---|
| qwen3 32B | qwen3 | wint8 | 4 | default | 4×22.7G | 19.78s | token-by-token aligned |
| qwen3 32B | qwen3 | wint8 | 4 | inflight_quant | 4×7.4G | 18.96s | |
| Qwen3-235B-A22B | qwen3moe | wint4 | 4 | default | 116×4G | 368.767s | token-by-token aligned |
| Qwen3-235B-A22B | qwen3moe | wint4 | 4 | inflight_quant | 4×9.2G | 261.65s | |
| Qwen3-30B-A3B | qwen3moe | wint4 | 4 | default | | 75.764s | token-by-token aligned |
| Qwen3-30B-A3B | qwen3moe | wint4 | 4 | inflight_quant | | 53.599s | |
```python
# offline_demo.py
from fastdeploy.engine.sampling_params import SamplingParams
from fastdeploy.entrypoints.llm import LLM

# Supports deepseek/qwen2/qwen3/qwen3moe/ernie/ernievl models, among others
model_name_or_path = "model_path"
sampling_params = SamplingParams(temperature=0.1, max_tokens=30, top_p=0)
# quantization="wint4" or quantization="wint8"
llm = LLM(
    model=model_name_or_path,
    num_gpu_blocks_override=1024,
    tensor_parallel_size=4,
    quantization="wint4",
    load_choices="default_v1",
)
output = llm.generate(
    prompts="who are you",
    use_tqdm=True,
    sampling_params=sampling_params,
)
print(output)
```

@paddle-bot bot commented Aug 14, 2025

Thanks for your contribution!

@bukejiyu bukejiyu force-pushed the test_dyn_quant branch 3 times, most recently from 9d2f94f to 742318c Compare August 18, 2025 14:00
@bukejiyu bukejiyu changed the title Test dyn quant [feat]support inflight quant Aug 18, 2025
@bukejiyu bukejiyu changed the title [feat]support inflight quant [feat]support dyn quant Aug 18, 2025
YuanRisheng
YuanRisheng previously approved these changes Aug 20, 2025
elif args.quantization != "None":
    quantization_config = {}
    if load_config.load_choices == LoadChoices.DEFAULT_V1:
        quantization_config["is_dyn_quant"] = True
Collaborator:

Is it reasonable for this to always be set to True? Offline quantization also has a quantization_config; perhaps this distinction can only be made inside the quant_config itself?

Collaborator (author):

Offline quantization never reaches the `args.quantization != "None"` branch, and quantization-related weight creation all happens in the quant method, where the quant_config is about the only thing available; it's hard to make this distinction anywhere else.
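The gating discussed above can be sketched as follows. This is a simplified stand-in, not FastDeploy's actual API: `LoadChoices` and `build_quantization_config` here are illustrative names; only the branch condition and the `is_dyn_quant` key come from the diff.

```python
from enum import Enum

class LoadChoices(str, Enum):
    # Illustrative stand-in for FastDeploy's loader-choice enum
    DEFAULT = "default"
    DEFAULT_V1 = "default_v1"

def build_quantization_config(quantization: str, load_choice: LoadChoices) -> dict:
    """Sketch of the gating: only online (CLI-requested) quantization
    reaches this branch, so the dyn-quant flag is derived purely from the
    loader choice. Offline-quantized checkpoints ship their own
    quantization_config and never enter here."""
    config = {}
    if quantization != "None":
        config["quantization"] = quantization
        if load_choice == LoadChoices.DEFAULT_V1:
            config["is_dyn_quant"] = True
    return config
```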

@bukejiyu bukejiyu requested review from YuanRisheng and removed request for risemeup1 August 20, 2025 04:41
Comment on lines -96 to +99
layer.up_gate_proj_weight,
layer.down_proj_weight,
getattr(layer, self.added_weight_attrs[0]),
getattr(layer, self.added_weight_attrs[1]),
Collaborator:

Please keep this as it was; it's clearer.

Collaborator (author):

But the names differ before and after quantization.

Collaborator (author):

They may end up different, because w4a8 also inherits this method, and we haven't decided how to handle w4a8 yet.

Comment on lines 654 to 655
self.ffn1_scale_shape = [layer.num_local_experts, layer.moe_intermediate_size * 2]
self.ffn2_scale_shape = [layer.num_local_experts, layer.hidden_size]
Collaborator:

Don't use ffn1/ffn2 anywhere; use up_gate/down instead.

Collaborator (author):

done

is_channel_wise: bool = False,
has_zero_point: bool = False,
is_permuted: bool = True,
is_dyn_quant: bool = False,
Collaborator:

Remove the is_dyn_quant variable everywhere in this PR.

Collaborator (author):

But then there's no way to tell what kind of weight is being created; the only option is to rename it.

Comment on lines 259 to 260
layer.weight.value().get_tensor()._clear()
del layer.weight
Collaborator:

Wrap these two lines into a single utils function call.
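The suggested utils helper could look like this minimal sketch. The `free_tensor` name and signature are hypothetical; only the two Paddle calls come from the diff above.

```python
def free_tensor(layer, attr_name: str) -> None:
    """Hypothetical utils helper: release a parameter's underlying
    storage and drop the attribute, so the two inline lines in the diff
    become a single call."""
    param = getattr(layer, attr_name)
    # Mirrors `layer.weight.value().get_tensor()._clear()` from the diff
    param.value().get_tensor()._clear()
    # Mirrors `del layer.weight`
    delattr(layer, attr_name)
```

Usage would then be `free_tensor(layer, "weight")` wherever the two lines appeared.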

Comment on lines 765 to 775
if (
    current_platform.is_cuda()
    or current_platform.is_xpu()
    or current_platform.is_iluvatar()
    or current_platform.is_gcu()
    or current_platform.is_dcu()
    or current_platform.is_maca()
):
    self.forward = self.forward_cuda
else:
    raise NotImplementedError
Collaborator:

If all the branches do the same thing, why bother with this at all?

Collaborator (author):

done

if w.dtype != self.weight_dtype:
    w = w.cast(self.weight_dtype)

def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
Collaborator:

Why was this weight_loader removed?

Collaborator:

What's the problem with KVBatchLinear having its own weight_loader?

Collaborator (author), Aug 22, 2025:

These two weights reuse the same on-disk weight; self.kv_b_proj_bmm is just self.kv_b_proj's weight with some extra processing applied. It doesn't need a loader, and arguably shouldn't have its own weight at all.


model_sublayer_name = re.sub(r"\.(up_gate_proj_weight|down_proj_weight|weight)$", "", model_param_name)
if "kv_b_proj" in model_sublayer_name:
    kv_model_sublayer_name = model_sublayer_name.replace("kv_b_proj", "kv_b_proj_bmm")
Collaborator:

Why don't the keys match up here?

Collaborator (author), Aug 21, 2025:

self.kv_b_proj_bmm and self.kv_b_proj share the same weight, so when iterating over the on-disk weights, kv_b_proj_bmm is never found; yet kv_b_proj_bmm still needs some processing after kv_b_proj is loaded. In the quantized case, kv_b_proj must be quantized while kv_b_proj_bmm must stay bf16, so kv_b_proj_bmm has to be handled as soon as loading finishes.
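The key-mapping step can be sketched as a small self-contained function. The regex is copied from the diff above; the function name and the example parameter names are illustrative.

```python
import re

def to_sublayer_names(model_param_name: str) -> list:
    """Strip the trailing weight suffix to get the sublayer name, and,
    for kv_b_proj, also derive the kv_b_proj_bmm twin that shares the
    same on-disk weight and must be post-processed after loading."""
    sublayer = re.sub(
        r"\.(up_gate_proj_weight|down_proj_weight|weight)$", "", model_param_name
    )
    names = [sublayer]
    if "kv_b_proj" in sublayer:
        names.append(sublayer.replace("kv_b_proj", "kv_b_proj_bmm"))
    return names
```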

Comment on lines 163 to 170
if fd_config.model_config.moe_use_aux_free:
    self.e_score_correction_bias = self.create_parameter(
        shape=[1, fd_config.model_config.moe_num_experts],
        dtype="float32",
        default_initializer=paddle.nn.initializer.Constant(0),
    )
else:
    self.e_score_correction_bias = None
Collaborator:

After this change, the key for e_score_correction_bias changes, right? Won't that break RL's names_mapping?

Collaborator (author):

I've reverted this part; reverting it has little impact on the mapping.

@bukejiyu bukejiyu force-pushed the test_dyn_quant branch 2 times, most recently from 4c89684 to 2429cbe Compare August 22, 2025 06:39
self.infer_to_train_mapping[
f"{base_name}.{layer_idx}.mlp.{moe_tag}_fused_moe.experts.gate_correction_bias"
] = f"{base_name}.{layer_idx}.mlp.moe_statics.e_score_correction_bias"
self.infer_to_train_mapping[f"{base_name}.{layer_idx}.mlp.gate_correction_bias"] = (
Collaborator:

So the bias no longer distinguishes text from vision here?

Collaborator (author):

It does distinguish them, but at load time the fused form is loaded, and the network itself splits the bias parameter. RL just needs to set its fused weight here.

**extra_weight_attrs,
"tensor_track": TensorTracker(shape=layer.down_proj_weight.shape, output_dim=False),
},
)
Collaborator:

Are markers like output_dim supported under both TP parallelism and EP parallelism?

Collaborator (author):

EP's on-disk weights don't use the output_dim attribute; TP is supported.
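What an output_dim marker means for TP sharding can be shown with a minimal plain-Python sketch (not FastDeploy's implementation, and the axis convention here is an assumption): the marker tells the loader which axis of the full weight to slice for each tensor-parallel rank.

```python
def tp_shard(weight, tp_rank: int, tp_size: int, output_dim: bool):
    """Slice a 2-D weight (list of rows) for one TP rank: along the
    output axis when output_dim is True, else along the input axis.
    Stand-in for the tensor slicing a real loader would perform."""
    if output_dim:
        # Split rows (output features) across ranks
        n = len(weight) // tp_size
        return weight[tp_rank * n:(tp_rank + 1) * n]
    # Split columns (input features) across ranks
    n = len(weight[0]) // tp_size
    return [row[tp_rank * n:(tp_rank + 1) * n] for row in weight]
```

For example, a down_proj marked with output_dim=False (as in the diff above) would be split along its input dimension.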

@@ -0,0 +1,89 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
Collaborator:

Start the filename with test_.

Collaborator (author):

conftest.py manages pytest configuration and fixtures; it is not a test file, and pytest discovers conftest.py automatically.
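As an illustration of that convention (names here are hypothetical, not from this PR), a conftest.py only defines shared fixtures that pytest injects into test files in the same directory tree:

```python
# conftest.py — discovered automatically by pytest; holds shared
# fixtures and configuration, not tests.
import pytest

@pytest.fixture
def model_path():
    # Any test file below this directory can accept `model_path`
    # as an argument and receive this value.
    return "model_path"
```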

2,
1024,
marks=[pytest.mark.core_model],
),
Contributor:

Please add a "Qwen2-7B-Instruct" model here, so the changes from PR #3502 get tested as well.

all_param_mapping = general_params_mapping + text_expert_params_mapping + image_expert_params_mapping

params_dict = dict(self.named_parameters())
after_loading_fn = process_weights_after_loading(dict(self.named_sublayers()))
Collaborator:

process_weights_after_loading -> process_weights_after_loading_fn

Collaborator (author):

done

yuanlehome
yuanlehome previously approved these changes Aug 22, 2025
@yuanlehome yuanlehome changed the title [feat]support dyn quant [V1 Loader] support weight_only Aug 22, 2025
@Jiang-Jia-Jun Jiang-Jia-Jun merged commit 77514e3 into PaddlePaddle:develop Aug 23, 2025
13 of 16 checks passed
7 participants