
[BUG]: colossalai 0.3.3 + torch 2.0.1 + baichuan-2 7b training raises NotImplementedError when saving lr_scheduler #4829

@airlsyn


🐛 Describe the bug

When training baichuan-2 7b with colossalai 0.3.3 + torch 2.0.1, saving the lr_scheduler raises NotImplementedError from colossalai/nn/lr_scheduler/delayed.py.

In [25]: lr_scheduler
Out[25]: <colossalai.nn.lr_scheduler.cosine.CosineAnnealingWarmupLR at 0x7f01cd616e00>
In [26]: booster.save_lr_scheduler(lr_scheduler, "/data/checkpoint/lr_scheduler")

 in <module>:1                                                                                    
                                                                                                  
 python3.10/site-packages/colossalai/booster/booster.py:308 in save_lr_scheduler
                                                                                                  
   305          lr_scheduler (LRScheduler): A lr scheduler boosted by Booster.                 
   306          checkpoint (str): Path to the checkpoint. It must be a local file path.        
   307       """                                                                                
 ❱ 308       self.checkpoint_io.save_lr_scheduler(lr_scheduler, checkpoint)                     
   309                                                                                           
   310    def load_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str) -> None:       
   311       """Load lr scheduler from checkpoint.                                              
                                                                                                  
 python3.10/site-packages/colossalai/booster/plugin/gemini_plugin.py:225 in save_lr_scheduler
                                                                                                  
   222       Save model to checkpoint but only on master process.                               
   223       """                                                                                
   224       if self.coordinator.is_master():                                                   
 ❱ 225          super().save_lr_scheduler(lr_scheduler, checkpoint)                            
   226                                                                                            
   227                                                                                            
   228 class GeminiPlugin(DPPluginBase):                                                          
                                                                                                  
 python3.10/site-packages/colossalai/checkpoint_io/checkpoint_io_base.py:318 in save_lr_scheduler
                                                                                                  
   315          lr_scheduler (LRScheduler): lr scheduler to be saved.                          
   316          checkpoint: checkpoint path. The checkpoint path can only be a file path.      
   317       """                                                                                
 ❱ 318       torch.save(lr_scheduler.state_dict(), checkpoint)                                  
   319                                                                                           
   320    def load_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):               
   321       """                                                                                
                                                                                                  
 python3.10/site-packages/colossalai/nn/lr_scheduler/delayed.py:93 in state_dict
                                                                                                  
    90          state_dict["after_scheduler_dict"] = state_dict["after_scheduler"].state_dic   
    91          del state_dict["after_scheduler"]                                              
    92       else:                                                                              
 ❱  93          raise NotImplementedError()                                                    
    94       return state_dict                                                                  
    95                                                                                           
    96    def get_lr(self):

Further inspection of what the lr_scheduler holds:

state_dict = {key: value for key, value in lr_scheduler.__dict__.items() if key not in "optimizer"}

# =>
{
 'warmup_epochs': 2000,
 'after_scheduler': <torch.optim.lr_scheduler.CosineAnnealingLR at 0x7f01cd6173a0>,
 'finished': False,
 'base_lrs': [0.0003],
 'last_epoch': 1,
 'verbose': False,
 '_step_count': 2,
 '_get_lr_called_within_step': False,
 '_last_lr': [3e-07]
}
  • after_scheduler is an instance of torch.optim.lr_scheduler.CosineAnnealingLR, and on torch 2.x CosineAnnealingLR inherits from LRScheduler, so the base class of after_scheduler is LRScheduler.

  • _LRScheduler in turn inherits from LRScheduler (torch 2.0 keeps it only as a backward-compatibility subclass), so a scheduler that derives directly from LRScheduler is not an instance of _LRScheduler.

  • Yet when saving the lr scheduler, delayed.py checks isinstance(state_dict['after_scheduler'], _LRScheduler), which fails and falls through to the NotImplementedError branch:

from torch.optim.lr_scheduler import _LRScheduler, LRScheduler

isinstance(state_dict['after_scheduler'], LRScheduler)

# => True

isinstance(state_dict['after_scheduler'], _LRScheduler)

# => False

Given this, does it mean the check should use LRScheduler rather than _LRScheduler?
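If so, one possible fix (a minimal sketch, not the actual ColossalAI patch; the method body below is reconstructed from the traceback above) would be to import the public LRScheduler base class when it exists and fall back to _LRScheduler on older torch, then check against that in delayed.py's state_dict:

# Sketch only: the exact surrounding class and source lines are assumptions.
try:
    from torch.optim.lr_scheduler import LRScheduler as BaseLRScheduler  # torch >= 2.0
except ImportError:
    from torch.optim.lr_scheduler import _LRScheduler as BaseLRScheduler  # torch < 2.0

def state_dict(self):
    # Drop the optimizer reference, as the current implementation does.
    state_dict = {key: value for key, value in self.__dict__.items() if key != "optimizer"}
    if isinstance(state_dict["after_scheduler"], BaseLRScheduler):
        # CosineAnnealingLR now passes this check on torch 2.x as well.
        state_dict["after_scheduler_dict"] = state_dict["after_scheduler"].state_dict()
        del state_dict["after_scheduler"]
    else:
        raise NotImplementedError()
    return state_dict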

Note: baichuan-2 depends on torch 2.0+ and cannot be downgraded below 2.0 (on 1.13 it fails with TypeError: sdp_kernel() got an unexpected keyword argument 'enable_mem_efficient').
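In the meantime, a possible workaround (an untested sketch that simply mirrors the non-raising branch shown in the traceback) is to serialize the scheduler state by hand instead of calling booster.save_lr_scheduler; loading it back would need a matching manual step:

import torch

# lr_scheduler is the CosineAnnealingWarmupLR instance from the session above.
state = {key: value for key, value in lr_scheduler.__dict__.items() if key != "optimizer"}
state["after_scheduler_dict"] = state.pop("after_scheduler").state_dict()
torch.save(state, "/data/checkpoint/lr_scheduler")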

Environment

  • colossalai 0.3.3
  • torch 2.0.1
  • baichuan-2 7b
