diff --git a/examples/tutorial/opt/opt/run_clm.py b/examples/tutorial/opt/opt/run_clm.py index 9f2aa7e645f3..91380e243fb8 100755 --- a/examples/tutorial/opt/opt/run_clm.py +++ b/examples/tutorial/opt/opt/run_clm.py @@ -424,10 +424,7 @@ def main(): PLACEMENT_POLICY = 'auto' if version.parse(cai_version) >= version.parse("0.3.1"): from colossalai.zero import GeminiDDP - model = GeminiDDP(model, - chunk_init_device=get_current_device(), - placement_policy=PLACEMENT_POLICY, - pin_memory=True) + model = GeminiDDP(model, offload_optim_frac=1.0, pin_memory=True) elif version.parse(cai_version) > version.parse("0.1.10"): try: from colossalai.nn.parallel import GeminiDDP diff --git a/examples/tutorial/opt/opt/test_ci.sh b/examples/tutorial/opt/opt/test_ci.sh index e505da1364de..431b37c12004 100755 --- a/examples/tutorial/opt/opt/test_ci.sh +++ b/examples/tutorial/opt/opt/test_ci.sh @@ -4,9 +4,9 @@ set -xue pip install -r requirements.txt -BS=8 +BS=4 MEMCAP=0 -GPUNUM=2 +GPUNUM=4 MODLE="facebook/opt-125m" torchrun \