4 changes: 2 additions & 2 deletions colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -21,7 +21,7 @@
 from colossalai.amp.naive_amp.mixed_precision_optimizer import MixedPrecisionOptimizer
 from colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO
 from colossalai.cluster import ProcessGroupMesh
-from colossalai.interface import ModelWrapper, OptimizerWrapper
+from colossalai.interface import ModelWrapper, OptimizerWrapper, AMPModelMixin
 from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer import ShardConfig, ShardFormer
@@ -42,7 +42,7 @@ def _convert_floating_point(x, dtype: torch.dtype = torch.float16):
     return x


-class HybridParallelModule(ModelWrapper):
+class HybridParallelModule(ModelWrapper, AMPModelMixin):
     def __init__(
         self,
         module: Module,
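For context: `AMPModelMixin` (imported here from `colossalai.interface`) supplies the `update_master_params()` hook that AMP-aware code paths call after weights change, for example when a checkpoint is loaded. Inheriting the mixin means every `HybridParallelModule` exposes that hook, so callers need no isinstance checks. A minimal sketch of the pattern, assuming only the hook's name from the import above (the toy wrapper and its usage are illustrative, not the plugin's actual code):

```python
from torch import nn


class AMPModelMixin:
    """Interface mixin: AMP-aware model wrappers may override this hook."""

    def update_master_params(self):
        # No-op by default; wrappers that keep fp32 master weights override
        # this to resync them from the working fp16/bf16 parameters.
        pass


class ToyWrapper(nn.Module, AMPModelMixin):
    """Illustrative stand-in for HybridParallelModule (hypothetical)."""

    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module


# Callers can invoke the hook uniformly on any wrapped model:
wrapped = ToyWrapper(nn.Linear(4, 4))
wrapped.update_master_params()  # safe no-op here; real wrappers resync masters
```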
3 changes: 3 additions & 0 deletions tests/test_booster/test_plugin/test_gemini_plugin.py
@@ -116,6 +116,9 @@ def check_gemini_plugin(
             "transformers_falcon_for_sequence_classification",
             "transformers_falcon_for_token_classification",
             "transformers_falcon_for_question_answering",
+            "transformers_gptj_lm",  # lead to OOM when running in ci
+            "transformers_gptj_for_question_answering",
+            "transformers_gptj_for_sequence_classification",
         ]:
             continue
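The test-side change extends an existing skip list in `check_gemini_plugin`: any registered model whose name appears in the list is skipped with `continue`, keeping the memory-hungry GPT-J variants off the CI runner. A hedged sketch of that pattern (the names `SKIP_IN_CI`, `iter_registered_models`, and `run_all` are illustrative, not the test harness's real API):

```python
# Illustrative skip-list pattern; SKIP_IN_CI and iter_registered_models are
# hypothetical names, not the actual test-suite API.
SKIP_IN_CI = {
    "transformers_gptj_lm",  # leads to OOM on CI runners
    "transformers_gptj_for_question_answering",
    "transformers_gptj_for_sequence_classification",
}


def run_all(iter_registered_models):
    for name, build_model in iter_registered_models():
        if name in SKIP_IN_CI:
            continue  # too large for the CI machine's memory budget
        model = build_model()
        # ... exercise the Gemini plugin against `model` here ...
```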