
[BUG]: TypeError when saving model with tensor parallelism using ColossalAI #5006

@imgaojun

Description

🐛 Describe the bug

I am encountering a TypeError when trying to save a model that uses tensor parallelism with ColossalAI. Specifically, the error occurs when calling booster.save_model() with the shard=True option.

booster.save_model(model, os.path.join(save_dir, 'model'), shard=True)
  File "/usr/local/lib/python3.10/dist-packages/colossalai/booster/booster.py", line 261, in save_model
    self.checkpoint_io.save_model(
  File "/usr/local/lib/python3.10/dist-packages/colossalai/checkpoint_io/checkpoint_io_base.py", line 135, in save_model
    self.save_sharded_model(model, checkpoint, gather_dtensor, prefix, size_per_shard, use_safetensors)
  File "/usr/local/lib/python3.10/dist-packages/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py", line 243, in save_sharded_model
    total_size = save_state_dict_shards(
  File "/usr/local/lib/python3.10/dist-packages/colossalai/checkpoint_io/utils.py", line 245, in save_state_dict_shards
    save_state_dict(shard, checkpoint_file_path, use_safetensors=use_safetensors)
  File "/usr/local/lib/python3.10/dist-packages/colossalai/checkpoint_io/utils.py", line 315, in save_state_dict
    _save(obj, opened_zipfile, pickle_module, pickle_protocol, _disable_byteorder_record)
  File "/usr/local/lib/python3.10/dist-packages/torch/serialization.py", line 714, in _save
    pickler.dump(obj)

TypeError: cannot pickle 'torch._C._distributed_c10d.ProcessGroup' object
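The error means that some value inside the shard being written holds a reference to a ProcessGroup, and because torch's serializer is pickle-based, that object cannot be written out. As a debugging aid, here is a minimal sketch that reports which state-dict entries fail to pickle. It assumes the shard is a plain dict mapping parameter names to values. `find_unpicklable` and `ShardedParamStandIn` are hypothetical names, and a `threading.Lock` stands in for the unpicklable ProcessGroup so the example runs without torch or a distributed setup.

```python
import pickle
import threading
from typing import Any, Dict, List


def find_unpicklable(state_dict: Dict[str, Any]) -> List[str]:
    """Return the keys whose values cannot be pickled (hypothetical helper)."""
    bad = []
    for key, value in state_dict.items():
        try:
            pickle.dumps(value)
        except (TypeError, AttributeError, pickle.PicklingError):
            # e.g. "cannot pickle '_thread.lock' object"
            bad.append(key)
    return bad


class ShardedParamStandIn:
    """Hypothetical stand-in for a parameter wrapper that keeps a handle
    to its process group. A threading.Lock plays the role of the
    unpicklable ProcessGroup here."""

    def __init__(self, data):
        self.data = data
        self.process_group = threading.Lock()


shard = {
    "embed.weight": [0.1, 0.2],
    "layer0.weight": ShardedParamStandIn([0.3, 0.4]),
}
print(find_unpicklable(shard))  # ['layer0.weight']
```

Plain `pickle.dumps` is only a rough proxy for what `torch.save` does, since torch also routes tensor storages through its own serialization path. An empty result therefore does not rule the problem out. The helper is just a quick way to narrow down which key carries the process-group handle before the shard reaches `save_state_dict`.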

Environment

No response


Labels: bug (Something isn't working)
