[BUG]: Parameters missing in the state_dict output of ZeroDDP module #2361

@eric8607242

🐛 Describe the bug

Hello,
I am currently fine-tuning the Hugging Face GPT2 model with ColossalAI, following the example that uses GeminiDDP and ZeroOptimizer. However, when I load a checkpoint saved by ColossalAI into the GPT2 model, some keys are reported as missing.

The following code reproduces the issue:

import colossalai
from colossalai.tensor import ProcessGroup, ShardSpec, ColoParameter
from colossalai.nn.parallel import GeminiDDP
from colossalai.utils import save_checkpoint
from colossalai.utils.model.colo_init_context import ColoInitContext

from transformers import GPT2LMHeadModel
import torch

if __name__ == "__main__":
    device = "cuda:0"
    tp_degree = 1
    shardinit = True
    placement_policy = "cpu"
    path_to_checkpoint = "./test.pth"

    colossalai.launch_from_torch(config={})

    default_pg = ProcessGroup(tp_degree=tp_degree)
    default_dist_spec = ShardSpec([-1], [tp_degree]) if shardinit else None
    with ColoInitContext(
        device,
        dtype=torch.half,
        default_dist_spec=default_dist_spec,
        default_pg=default_pg
    ):
        model = GPT2LMHeadModel.from_pretrained("gpt2")

    model = GeminiDDP(
        model, device=device,
        placement_policy=placement_policy,
        pin_memory=True, hidden_dim=768,
        search_range_mb=64
    )

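    # Save the GeminiDDP-wrapped model with ColossalAI's checkpoint utility.
    save_checkpoint(path_to_checkpoint, epoch=0, model=model)

    # Reload the checkpoint into a plain Hugging Face GPT2 model; the last call raises the error.
    original_model = GPT2LMHeadModel.from_pretrained("gpt2")
    checkpoint = torch.load(path_to_checkpoint)
    original_model.load_state_dict(checkpoint["model"])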

The error message is as follows:

RuntimeError: Error(s) in loading state_dict for GPT2LMHeadModel:
        Missing key(s) in state_dict: "lm_head.weight". 
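
One possible reading of this error, offered only as a sketch: in the Hugging Face GPT2 implementation, `lm_head.weight` is tied to `transformer.wte.weight`, so the saved state_dict may simply have kept one entry for the shared tensor. Whether that is what ColossalAI does here is an assumption, not something confirmed by this report. The helpers below (`diff_keys`, `fill_tied_keys`, and the `GPT2_TIED_KEYS` mapping are hypothetical names, not part of ColossalAI or Transformers) compare the key sets and add the missing alias back before loading:

```python
from typing import Dict, Iterable, List, Mapping, Tuple


def diff_keys(expected: Iterable[str], found: Iterable[str]) -> Tuple[List[str], List[str]]:
    """Return (missing, unexpected) key lists, each sorted alphabetically."""
    expected_set, found_set = set(expected), set(found)
    missing = sorted(expected_set - found_set)
    unexpected = sorted(found_set - expected_set)
    return missing, unexpected


def fill_tied_keys(state: Mapping[str, object], tied: Mapping[str, str]) -> Dict[str, object]:
    """Copy `state`, adding each missing tied key as an alias of its source entry.

    The input mapping is left unmodified; the alias points at the same object
    as the source entry, mirroring how tied parameters share one tensor.
    """
    patched = dict(state)
    for alias, source in tied.items():
        if alias not in patched and source in patched:
            patched[alias] = patched[source]
    return patched


# Assumption: the only tied pair in GPT2LMHeadModel is the output head and the
# token embedding. Adjust this mapping for other architectures.
GPT2_TIED_KEYS = {"lm_head.weight": "transformer.wte.weight"}
```

With the reproduction script above, `diff_keys(original_model.state_dict().keys(), checkpoint["model"].keys())` would list exactly which keys differ, and `original_model.load_state_dict(fill_tied_keys(checkpoint["model"], GPT2_TIED_KEYS))` would be one way to load the checkpoint if the tied-weight assumption holds.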

Environment

OS:
Ubuntu 22.04

GPU:
NVIDIA GeForce RTX 3090

Package list:
pytorch 2.0.0_cuda11.6_cudnn8.3.2_0
cuda-toolkit 11.6.1
colossalai 0.2.0+torch2.0cu11.7
