Force eager attention in the llama model-zoo test config and adjust the
gradient-checkpointing setup in the llama shardformer test.

NOTE(review): the second hunk comments out
`org_model.gradient_checkpointing_enable()` instead of deleting it. If
disabling checkpointing on the reference model is intentional (e.g. to keep
the non-sharded forward numerically comparable under eager attention), the
dead line should be removed and replaced with a one-line comment saying why;
if it is a debugging leftover, the call should be restored. Confirm before
merging — as written, only the sharded model runs with gradient
checkpointing enabled.

diff --git a/tests/kit/model_zoo/transformers/llama.py b/tests/kit/model_zoo/transformers/llama.py
index 58b5b0487a82..08c05e9063bf 100644
--- a/tests/kit/model_zoo/transformers/llama.py
+++ b/tests/kit/model_zoo/transformers/llama.py
@@ -65,6 +65,7 @@ def data_gen_for_casual_lm():
     num_attention_heads=4,
     max_position_embeddings=128,
     num_labels=16,
+    attn_implementation="eager",
 )
 
 if hasattr(config, "pad_token_id"):
diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py
index 27f904292597..2a10d86c79bb 100644
--- a/tests/test_shardformer/test_model/test_shard_llama.py
+++ b/tests/test_shardformer/test_model/test_shard_llama.py
@@ -32,7 +32,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
         model_fn, loss_fn, test_config
     )
     if enable_gradient_checkpointing:
-        org_model.gradient_checkpointing_enable()
+        # org_model.gradient_checkpointing_enable()
         sharded_model.unwrap().gradient_checkpointing_enable()
 
     org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(