From 4a0b2de5d2ba6a2b75282ca3a4dc919777c38893 Mon Sep 17 00:00:00 2001
From: ver217
Date: Tue, 23 Apr 2024 23:25:11 +0800
Subject: [PATCH] [test] fix llama model test

---
 tests/kit/model_zoo/transformers/llama.py             | 1 +
 tests/test_shardformer/test_model/test_shard_llama.py | 2 +-
 2 files changed, 2 insertions(+), 1 deletion(-)

diff --git a/tests/kit/model_zoo/transformers/llama.py b/tests/kit/model_zoo/transformers/llama.py
index 58b5b0487a82..08c05e9063bf 100644
--- a/tests/kit/model_zoo/transformers/llama.py
+++ b/tests/kit/model_zoo/transformers/llama.py
@@ -65,6 +65,7 @@ def data_gen_for_casual_lm():
         num_attention_heads=4,
         max_position_embeddings=128,
         num_labels=16,
+        attn_implementation="eager",
     )
 
     if hasattr(config, "pad_token_id"):
diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py
index 27f904292597..2a10d86c79bb 100644
--- a/tests/test_shardformer/test_model/test_shard_llama.py
+++ b/tests/test_shardformer/test_model/test_shard_llama.py
@@ -32,7 +32,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
         model_fn, loss_fn, test_config
     )
     if enable_gradient_checkpointing:
-        org_model.gradient_checkpointing_enable()
+        # org_model.gradient_checkpointing_enable()
         sharded_model.unwrap().gradient_checkpointing_enable()
 
     org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(