From 9a51d4f0e274cc8ec3bd7e63ea652190857446cc Mon Sep 17 00:00:00 2001
From: flybird11111 <1829166702@qq.com>
Date: Wed, 29 Nov 2023 16:57:29 +0800
Subject: [PATCH 1/4] fix 3d checkpoint load when booster boost without
 optimizer

fix 3d checkpoint load when booster boost without optimizer
---
 colossalai/booster/plugin/hybrid_parallel_plugin.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py
index bbc36ceab2ec..040fb7d57796 100644
--- a/colossalai/booster/plugin/hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -21,7 +21,7 @@
 from colossalai.amp.naive_amp.mixed_precision_optimizer import MixedPrecisionOptimizer
 from colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO
 from colossalai.cluster import ProcessGroupMesh
-from colossalai.interface import ModelWrapper, OptimizerWrapper
+from colossalai.interface import ModelWrapper, OptimizerWrapper, AMPModelMixin
 from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer import ShardConfig, ShardFormer
@@ -42,7 +42,7 @@ def _convert_floating_point(x, dtype: torch.dtype = torch.float16):
     return x
 
 
-class HybridParallelModule(ModelWrapper):
+class HybridParallelModule(ModelWrapper, AMPModelMixin):
     def __init__(
         self,
         module: Module,

From 17f0a43648ca259300e597d3bcd9c5e0a65b50e2 Mon Sep 17 00:00:00 2001
From: flybird11111 <1829166702@qq.com>
Date: Thu, 30 Nov 2023 11:55:16 +0800
Subject: [PATCH 2/4] test ci

---
 .github/workflows/build_on_pr.yml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.github/workflows/build_on_pr.yml b/.github/workflows/build_on_pr.yml
index e2114d43bcd0..05e2d396c2dd 100644
--- a/.github/workflows/build_on_pr.yml
+++ b/.github/workflows/build_on_pr.yml
@@ -208,7 +208,7 @@ jobs:
       - name: Execute Unit Testing
         run: |
-          CURL_CA_BUNDLE="" PYTHONPATH=$PWD pytest -m "not largedist" --testmon --testmon-forceselect --testmon-cov=. --durations=10 tests/
+          CURL_CA_BUNDLE="" PYTHONPATH=$PWD pytest -m "not largedist" --testmon --testmon-forceselect --testmon-cov=. --durations=10 tests/test_booster/test_plugin/test_gemini_plugin.py
         env:
           DATA: /data/scratch/cifar-10
           NCCL_SHM_DISABLE: 1

From f68d36cd2fc4c6d6d7dff6924c5936d8e2474eb7 Mon Sep 17 00:00:00 2001
From: flybird11111 <1829166702@qq.com>
Date: Thu, 30 Nov 2023 13:45:15 +0800
Subject: [PATCH 3/4] revert ci

---
 .github/workflows/build_on_pr.yml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.github/workflows/build_on_pr.yml b/.github/workflows/build_on_pr.yml
index 05e2d396c2dd..e2114d43bcd0 100644
--- a/.github/workflows/build_on_pr.yml
+++ b/.github/workflows/build_on_pr.yml
@@ -208,7 +208,7 @@ jobs:
       - name: Execute Unit Testing
         run: |
-          CURL_CA_BUNDLE="" PYTHONPATH=$PWD pytest -m "not largedist" --testmon --testmon-forceselect --testmon-cov=. --durations=10 tests/test_booster/test_plugin/test_gemini_plugin.py
+          CURL_CA_BUNDLE="" PYTHONPATH=$PWD pytest -m "not largedist" --testmon --testmon-forceselect --testmon-cov=. --durations=10 tests/
         env:
           DATA: /data/scratch/cifar-10
           NCCL_SHM_DISABLE: 1

From 9561e4ff15218a33ceab5c44691b3d9af0c202c2 Mon Sep 17 00:00:00 2001
From: flybird11111 <1829166702@qq.com>
Date: Thu, 30 Nov 2023 14:04:14 +0800
Subject: [PATCH 4/4] fix

fix
---
 tests/test_booster/test_plugin/test_gemini_plugin.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/tests/test_booster/test_plugin/test_gemini_plugin.py b/tests/test_booster/test_plugin/test_gemini_plugin.py
index 3c496ff64755..d4205e1f9d73 100644
--- a/tests/test_booster/test_plugin/test_gemini_plugin.py
+++ b/tests/test_booster/test_plugin/test_gemini_plugin.py
@@ -116,6 +116,9 @@ def check_gemini_plugin(
             "transformers_falcon_for_sequence_classification",
             "transformers_falcon_for_token_classification",
             "transformers_falcon_for_question_answering",
+            "transformers_gptj_lm",  # leads to OOM when running in CI
+            "transformers_gptj_for_question_answering",
+            "transformers_gptj_for_sequence_classification",
         ]:
             continue
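
Note (reviewer addition, not part of the patches): below is a minimal sketch of the
scenario patch 1/4 addresses, namely calling booster.boost() on a model with no
optimizer and then loading a checkpoint. With HybridParallelModule now declaring
AMPModelMixin, AMP-aware checkpoint code can treat the wrapper uniformly. The toy
model, plugin arguments, and checkpoint path are hypothetical; the colossalai calls
shown (launch_from_torch, Booster.boost, Booster.load_model) exist in the library,
though exact signatures may vary by version. The sketch assumes a distributed
launch, e.g. via torchrun, with at least one GPU.

    import torch.nn as nn

    import colossalai
    from colossalai.booster import Booster
    from colossalai.booster.plugin import HybridParallelPlugin

    # Initialize the distributed environment (assumes torchrun set the env vars).
    colossalai.launch_from_torch(config={})

    model = nn.Linear(16, 16)  # hypothetical toy model

    plugin = HybridParallelPlugin(tp_size=1, pp_size=1, precision="fp16")
    booster = Booster(plugin=plugin)

    # Boost WITHOUT an optimizer -- the case patch 1/4 fixes. The returned
    # wrapper is a HybridParallelModule, which now inherits AMPModelMixin.
    model, _, _, _, _ = booster.boost(model)

    # Loading a checkpoint should now succeed even though no optimizer was
    # boosted alongside the model ("./ckpt" is a hypothetical path).
    booster.load_model(model, "./ckpt")

Before patch 1/4, this load path could fail because the wrapper returned by
boost() was not recognized as an AMP-capable model.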