diff --git a/Megatron-LM-v1.1.5-ZeRO3/pretrain_gpt2.py b/Megatron-LM-v1.1.5-ZeRO3/pretrain_gpt2.py index af7e3476d..d88782ede 100644 --- a/Megatron-LM-v1.1.5-ZeRO3/pretrain_gpt2.py +++ b/Megatron-LM-v1.1.5-ZeRO3/pretrain_gpt2.py @@ -34,10 +34,9 @@ def model_provider(): """Build the model.""" print_rank_0('building GPT2 model ...') - with deepspeed.zero.InitContext(data_parallel_group=mpu.get_data_parallel_group(), - zero_modules=True, - remote_device=get_args().remote_device, - enabled=get_args().zero_stage==3): + with deepspeed.zero.Init(data_parallel_group=mpu.get_data_parallel_group(), + remote_device=get_args().remote_device, + enabled=get_args().zero_stage==3): model = GPT2Model(num_tokentypes=0, parallel_output=True) return model