diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py
index ed3a61dede56..91fcba55a0aa 100644
--- a/colossalai/booster/plugin/hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -21,7 +21,7 @@
 from colossalai.amp.naive_amp.mixed_precision_optimizer import MixedPrecisionOptimizer
 from colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO
 from colossalai.cluster import ProcessGroupMesh
-from colossalai.interface import ModelWrapper, OptimizerWrapper
+from colossalai.interface import ModelWrapper, OptimizerWrapper, AMPModelMixin
 from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer import ShardConfig, ShardFormer
@@ -42,7 +42,7 @@ def _convert_floating_point(x, dtype: torch.dtype = torch.float16):
     return x
 
 
-class HybridParallelModule(ModelWrapper):
+class HybridParallelModule(ModelWrapper, AMPModelMixin):
     def __init__(
         self,
         module: Module,
diff --git a/tests/test_booster/test_plugin/test_gemini_plugin.py b/tests/test_booster/test_plugin/test_gemini_plugin.py
index 3c496ff64755..d4205e1f9d73 100644
--- a/tests/test_booster/test_plugin/test_gemini_plugin.py
+++ b/tests/test_booster/test_plugin/test_gemini_plugin.py
@@ -116,6 +116,9 @@ def check_gemini_plugin(
             "transformers_falcon_for_sequence_classification",
             "transformers_falcon_for_token_classification",
             "transformers_falcon_for_question_answering",
+            "transformers_gptj_lm",  # lead to OOM when running in ci
+            "transformers_gptj_for_question_answering",
+            "transformers_gptj_for_sequence_classification",
         ]:
             continue
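
For readers unfamiliar with the mixin being added above, here is a minimal, self-contained sketch of the pattern: AMPModelMixin (exported from colossalai.interface alongside ModelWrapper and OptimizerWrapper) declares an AMP hook whose default is a no-op, so a wrapper that keeps fp32 master weights can override it while other wrappers need no changes. The update_master_params name and the HybridParallelModule body below are illustrative assumptions, not the actual ColossalAI implementation.

```python
# Minimal sketch of the ModelWrapper + AMPModelMixin pattern (names mirror
# colossalai.interface; the bodies here are illustrative, not the real code).
import torch
from torch import nn


class ModelWrapper(nn.Module):
    """Thin wrapper holding the user module (stand-in for colossalai.interface.ModelWrapper)."""

    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module

    def unwrap(self):
        return self.module

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)


class AMPModelMixin:
    """Declares the AMP hook; the default is a no-op so non-AMP wrappers stay untouched."""

    def update_master_params(self):
        pass


class HybridParallelModule(ModelWrapper, AMPModelMixin):
    """Illustrative wrapper: fp16 working params plus fp32 master copies."""

    def __init__(self, module: nn.Module):
        super().__init__(module.half())
        # fp32 master copies kept alongside the fp16 working parameters
        self.master_params = [p.detach().clone().float() for p in self.module.parameters()]

    def update_master_params(self):
        # copy the (possibly just-loaded) working params back into the master copies
        with torch.no_grad():
            for master, working in zip(self.master_params, self.module.parameters()):
                master.copy_(working.float())
```

With such a hook in place, a booster or checkpoint loader can call wrapper.update_master_params() unconditionally after loading a state dict, without checking the concrete wrapper type.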