From d13f02c0781bb57ab5f44f7badfa165bf2fb0977 Mon Sep 17 00:00:00 2001 From: ver217 Date: Tue, 11 Jun 2024 16:10:40 +0800 Subject: [PATCH] [test] fix chatglm test kit --- tests/kit/model_zoo/transformers/chatglm2.py | 31 ++++++++++---------- 1 file changed, 15 insertions(+), 16 deletions(-) diff --git a/tests/kit/model_zoo/transformers/chatglm2.py b/tests/kit/model_zoo/transformers/chatglm2.py index f443553bbd32..9a7cf34c1195 100644 --- a/tests/kit/model_zoo/transformers/chatglm2.py +++ b/tests/kit/model_zoo/transformers/chatglm2.py @@ -33,22 +33,6 @@ def data_gen_for_conditional_generation(): ) loss_fn = lambda x: x["loss"] -config = AutoConfig.from_pretrained( - "THUDM/chatglm2-6b", - trust_remote_code=True, - num_layers=2, - padded_vocab_size=65024, - hidden_size=64, - ffn_hidden_size=214, - num_attention_heads=8, - kv_channels=16, - rmsnorm=True, - original_rope=True, - use_cache=True, - multi_query_attention=False, - torch_dtype=torch.float32, -) - infer_config = AutoConfig.from_pretrained( "THUDM/chatglm2-6b", @@ -68,6 +52,21 @@ def data_gen_for_conditional_generation(): def init_chatglm(): + config = AutoConfig.from_pretrained( + "THUDM/chatglm2-6b", + trust_remote_code=True, + num_layers=2, + padded_vocab_size=65024, + hidden_size=64, + ffn_hidden_size=214, + num_attention_heads=8, + kv_channels=16, + rmsnorm=True, + original_rope=True, + use_cache=True, + multi_query_attention=False, + torch_dtype=torch.float32, + ) model = AutoModelForCausalLM.from_config(config, empty_init=False, trust_remote_code=True) for m in model.modules(): if m.__class__.__name__ == "RMSNorm":