Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 67 additions & 1 deletion colossalai/zero/low_level/low_level_optim.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# this code is inspired by the DeepSpeed library and implemented with our own design from scratch
import copy
from contextlib import contextmanager
from functools import partial
from typing import Optional
Expand Down Expand Up @@ -198,7 +199,7 @@ def _create_master_param_current_rank(self, param_list):
params_current_rank = []
device = 'cpu' if self._cpu_offload else get_current_device()

for param in reversed(param_list):
for param in param_list:
padding_size = (self._world_size - param.numel() % self._world_size) % self._world_size
self._param_store.record_param_padding_size(param, padding_size)

Expand Down Expand Up @@ -468,3 +469,68 @@ def no_sync(self):
yield
finally:
self.require_grad_sync = old_require_grad_sync

##############
# State Dict #
##############
def _pack_state(self, state: dict) -> dict:
    """Convert a parameter-keyed state mapping into the PyTorch state_dict layout.

    The packing mirrors what ``torch.optim.Optimizer.state_dict()`` does:
    parameters are replaced by integer indices that follow their order of
    appearance in ``self.param_groups``.

    Args:
        state (dict): per-parameter optimizer state. Tensor keys are replaced
            by their packed index; any other key is kept as it is.

    Returns:
        dict: ``{'state': ..., 'param_groups': ...}`` where every group carries
            its options plus a ``'params'`` list of integer indices.
    """
    index_of = {}    # id(param) -> packed index
    next_index = 0
    packed_groups = []

    for group in self.param_groups:
        members = group['params']

        # Copy every group option except the parameter list itself.
        entry = {}
        for option, value in group.items():
            if option != 'params':
                entry[option] = value

        # Number the parameters not seen in an earlier group. The numbering is
        # collected separately and merged afterwards, so a parameter repeated
        # inside one group ends up with its last position.
        fresh = {}
        for position, tensor in enumerate(members, next_index):
            if id(tensor) not in index_of:
                fresh[id(tensor)] = position
        index_of.update(fresh)

        entry['params'] = [index_of[id(tensor)] for tensor in members]
        next_index += len(entry['params'])
        packed_groups.append(entry)

    # Re-key the state by packed index instead of by parameter tensor.
    packed_state = {}
    for key, value in state.items():
        if isinstance(key, torch.Tensor):
            key = index_of[id(key)]
        packed_state[key] = value

    return {'state': packed_state, 'param_groups': packed_groups}

def state_dict(self) -> dict:
    """Return the optimizer state in the regular (unsharded) PyTorch form.

    Every tensor state except ``'step'`` is all-gathered across ``self.dp_pg``,
    truncated to the number of elements of the matching working parameter
    (which drops the padding) and reshaped like that working parameter.
    All other entries are deep-copied so the result does not alias the live
    optimizer state.

    Returns:
        dict: the pytorch form state_dict
    """
    zero_state = dict()
    for param, state in self.optim.state.items():
        working_param = None
        full_state = dict()
        for k, v in state.items():
            if isinstance(v, torch.Tensor) and k != 'step':
                # Resolve the working parameter once per master parameter, and
                # only when there is actually a tensor state to gather.
                if working_param is None:
                    working_param = self._param_store.master_to_working_param[id(param)]
                # NOTE(review): `v` is gathered on whatever device the optimizer
                # state lives on -- confirm the process-group backend accepts
                # that device when the master parameters are offloaded to CPU.
                gather_tensor = [torch.zeros_like(v) for _ in range(self._world_size)]
                dist.all_gather(gather_tensor, v, group=self.dp_pg)
                full_state[k] = torch.stack(gather_tensor).view(-1)[:working_param.numel()].reshape_as(
                    working_param)
            else:
                # Only entries that are kept get copied; the gathered tensors
                # above are new, so their shards need no throw-away copy.
                full_state[k] = copy.deepcopy(v)
        zero_state[param] = full_state

    return self._pack_state(zero_state)

def load_state_dict(self, state_dict: dict):
    """Load state dict, requires the state_dict be the pytorch form

    Each tensor state except ``'step'`` is flattened, zero-padded so its
    length is a multiple of the world size, and cut into equal shards; only
    the shard belonging to this rank is handed to the wrapped optimizer.
    The caller's ``state_dict`` is not modified.

    Args:
        state_dict (dict): A pytorch form state_dict
    """
    zero_state_dict = copy.deepcopy(state_dict)
    for state in zero_state_dict['state'].values():
        for k, v in state.items():
            if not isinstance(v, torch.Tensor) or k == 'step':
                continue
            padding_size = (self._world_size - v.numel() % self._world_size) % self._world_size
            with torch.no_grad():
                flat = v.flatten()
                if padding_size > 0:
                    flat = torch.nn.functional.pad(flat, [0, padding_size])
                shard_size = flat.numel() // self._world_size
                begin = self._local_rank * shard_size
                # Slicing (instead of split + index) also works for zero-element
                # states. clone() gives the shard its own storage; a bare view
                # would keep the whole padded full-size tensor alive on every rank.
                state[k] = flat[begin:begin + shard_size].detach().clone()

    self.optim.load_state_dict(zero_state_dict)
121 changes: 121 additions & 0 deletions tests/test_zero/test_low_level/test_zero_ckpt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
import copy

import pytest
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing import assert_close

import colossalai
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing.random import seed_all
from colossalai.zero import LowLevelZeroOptimizer


class MlpModel(nn.Module):
    """Small two-layer linear network (12 -> 24 -> 12) with no activation."""

    def __init__(self):
        super().__init__()
        # linear1 is created before linear2, as in the forward pass.
        self.linear1 = nn.Linear(in_features=12, out_features=24)
        self.linear2 = nn.Linear(in_features=24, out_features=12)

    def forward(self, x):
        """Apply ``linear1`` followed by ``linear2`` to ``x``."""
        hidden = self.linear1(x)
        return self.linear2(hidden)


def loose_close(a, b, dtype: torch.dtype = torch.float32):
    """Assert that ``a`` and ``b`` are close after casting both to ``dtype``.

    Args:
        a: first tensor; it is detached and cast to ``dtype``.
        b: second tensor; it is detached and cast to ``dtype``.
        dtype (torch.dtype): comparison dtype. ``float16`` and ``bfloat16``
            use the relaxed tolerances below; any other dtype passes
            ``rtol=None, atol=None`` to ``assert_close``.

    Raises:
        AssertionError: raised by ``assert_close`` when the tensors differ
            beyond the tolerance.
    """
    # (rtol, atol) for the reduced-precision dtypes.
    tolerances = {
        torch.float16: (5e-2, 5e-4),
        torch.bfloat16: (4e-3, 4e-3),
    }
    rtol, atol = tolerances.get(dtype, (None, None))

    lhs = a.detach().to(dtype)
    rhs = b.detach().to(dtype)
    assert_close(lhs, rhs, rtol=rtol, atol=atol)


def _assert_optim_states_close(torch_state_dict, zero_state_dict):
    """Check that two PyTorch-form optimizer state dicts hold matching states."""
    torch_states = list(torch_state_dict['state'].values())
    zero_states = list(zero_state_dict['state'].values())
    # zip() stops at the shorter input, so compare the sizes first; otherwise
    # a missing or empty state on either side would pass vacuously.
    assert len(torch_states) == len(zero_states)
    for torch_state, zero_state in zip(torch_states, zero_states):
        assert len(torch_state) == len(zero_state)
        for t_v, z_v in zip(torch_state.values(), zero_state.values()):
            loose_close(t_v, z_v)


def exam_zero_1_torch_ddp_ckpt():
    """
    We examine the state_dict of zero and DDP.
    Moreover, we examine the zero's loading checkpoint of a torch ckpt.
    """
    local_rank = torch.distributed.get_rank()
    seed_all(1453)

    # create models
    torch_model = MlpModel().cuda()
    zero_model = copy.deepcopy(torch_model)

    torch_model = DDP(torch_model.cuda(), static_graph=True).cuda()

    # create optimizer
    zero_optimizer = torch.optim.Adam(zero_model.parameters(), lr=1)

    # we only test stage 1 here
    # the state dicts of stage 1 and stage 2 are the same
    zero_optimizer = LowLevelZeroOptimizer(zero_optimizer,
                                           overlap_communication=True,
                                           initial_scale=1,
                                           reduce_bucket_size=262144)

    torch_optimizer = torch.optim.Adam(torch_model.parameters(), lr=1)

    # every rank draws its own input batch
    seed_all(1453 + local_rank)
    input_data = torch.rand(4, 12).cuda()

    # forward
    zero_output = zero_model(input_data)
    torch_output = torch_model(input_data)

    # backward
    zero_optimizer.backward(zero_output.mean().float())
    torch_output.mean().backward()

    # step
    zero_optimizer.step()
    torch_optimizer.step()

    torch_state_dict = torch_optimizer.state_dict()
    zero_state_dict = zero_optimizer.state_dict()

    # examine the original state dict
    _assert_optim_states_close(torch_state_dict, zero_state_dict)

    # empty the optimizer state in place, so it stays a mapping
    zero_optimizer.optim.state.clear()

    # zero load a torch checkpoint
    zero_optimizer.load_state_dict(copy.deepcopy(torch_state_dict))
    zero_state_dict = zero_optimizer.state_dict()

    # examine the loaded state dict
    _assert_optim_states_close(torch_state_dict, zero_state_dict)


def run_dist(rank, world_size, port):
    """Launch colossalai for this process, then run the checkpoint check.

    Args:
        rank: rank forwarded to ``colossalai.launch``.
        world_size: world size forwarded to ``colossalai.launch``.
        port: port forwarded to ``colossalai.launch``; the host is ``localhost``.
    """
    launch_options = {
        'config': {},
        'rank': rank,
        'world_size': world_size,
        'port': port,
        'host': 'localhost',
    }
    colossalai.launch(**launch_options)

    exam_zero_1_torch_ddp_ckpt()


@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_zero_ckpt():
    """Run ``run_dist`` through ``spawn`` with 2 (presumably the process count)."""
    nprocs = 2
    spawn(run_dist, nprocs)


# Allow running this test module directly as a script, without pytest.
if __name__ == '__main__':
    test_zero_ckpt()