22 changes: 18 additions & 4 deletions colossalai/nn/parallel/data_parallel.py
@@ -233,7 +233,7 @@ def __init__(self,
            assert isinstance(p, ColoParameter)

            if is_ddp_ignored(p):
-               p.data = p.data.half()
+               p.data = p.data.to(device=get_current_device(), dtype=torch.float16)
                continue

            fp32_data = p.data.float()
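
This hunk changes two things rather than one: `.half()` only changed the dtype and left a ddp-ignored parameter on whatever device it already occupied, while `.to(device=get_current_device(), dtype=torch.float16)` also moves it to the active device, alongside the managed parameters. A minimal sketch of the difference, assuming `get_current_device` is imported from `colossalai.utils` as elsewhere in this file:

    import torch

    from colossalai.utils import get_current_device

    p = torch.nn.Parameter(torch.randn(4, 4))    # fp32, on CPU by default

    # old behaviour: dtype cast only; the tensor stays on its current device
    old = p.data.half()

    # new behaviour: device move and fp16 cast in a single call
    new = p.data.to(device=get_current_device(), dtype=torch.float16)

    assert old.device.type == "cpu"
    assert new.dtype == torch.float16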
@@ -451,8 +451,14 @@ def _save_to_state_dict(self, destination, prefix, keep_vars, only_rank_0=True):
        assert keep_vars is False, "`state_dict` with parameter, `keep_vars=True`, is not supported now."

        param_to_save_data = self._get_param_to_save_data(self.fp32_params, only_rank_0)
-       # TODO: (HELSON) deal with ddp ignored parameters
-       for (name, p), fp32_p in zip(self.named_parameters(), self.fp32_params):
+       ddp_param_list = []
+       for name, param in self.named_parameters():
+           if is_ddp_ignored(param):
+               # deal with ddp ignored parameters
+               destination[prefix + name] = param if keep_vars else param.detach()
+           else:
+               ddp_param_list.append((name, param))
+       for (name, p), fp32_p in zip(ddp_param_list, self.fp32_params):
            if p is not None:
                assert fp32_p in param_to_save_data, "Parameter '{}' is neglected in the chunk list".format(name)
                record_parameter = param_to_save_data[fp32_p]
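
Zipping `self.named_parameters()` directly against `self.fp32_params` breaks once some parameters are ddp-ignored, because ignored parameters have no fp32 shadow copy in the chunk list. The new code saves them verbatim and zips only the remainder. A standalone sketch of that split, with a hypothetical helper name:

    # Hypothetical stand-in for the branch above, not part of the ZeroDDP API.
    # Ignored parameters go straight into the state dict; the rest keep their
    # order, so zip(managed, fp32_params) pairs each one with its fp32 shadow.
    def split_ignored(named_params, is_ignored, destination, prefix):
        managed = []
        for name, param in named_params:
            if is_ignored(param):
                destination[prefix + name] = param.detach()
            else:
                managed.append((name, param))
        return managed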
@@ -588,8 +594,16 @@ def load(param_name, dest_tensor, copy_func):
        def load_fp32_parameter(chunk_slice, data):
            chunk_slice.copy_(data.flatten())

+       ddp_param_list = []
+       for name, param in self.named_parameters():
+           if is_ddp_ignored(param):
+               # deal with ddp ignored parameters
+               load(name, param, param.copy_)
+           else:
+               ddp_param_list.append((name, param))
+
        fp32_to_name = dict()
-       for (name, p), fp32_p in zip(self.named_parameters(), self.fp32_params):
+       for (name, p), fp32_p in zip(ddp_param_list, self.fp32_params):
            if p is not None:
                fp32_to_name[fp32_p] = name

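The load path mirrors the save path: ignored parameters bypass the chunk machinery entirely and are restored through the existing `load` helper, which receives `param.copy_` as its copy function. Assuming `load` ultimately calls that function with the tensor recovered from the checkpoint, the key property is that `Tensor.copy_` writes in place:

    import torch

    param = torch.zeros(3)
    saved = torch.tensor([1.0, 2.0, 3.0])
    param.copy_(saved)    # in-place copy; keeps param's storage and device
    assert torch.equal(param, saved)
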
12 changes: 10 additions & 2 deletions tests/test_gemini/update/test_zeroddp_state_dict.py
@@ -4,6 +4,7 @@
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
+from torch.testing import assert_close

import colossalai
from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration
@@ -17,6 +18,13 @@
from tests.test_tensor.common_utils import debug_print, set_seed


+def ignore_the_first_parameter(model: torch.nn.Module):
+    for name, param in model.named_parameters():
+        print(f"parameter `{name}` is set ignored")
+        ZeroDDP.set_params_to_ignore([param])
+        return
+
+
@parameterize('placement_policy', ['cuda', 'cpu', 'auto'])
@parameterize('keep_gathered', [True, False])
@parameterize('model_name', ['gpt2', 'bert'])
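
Note that `ignore_the_first_parameter` marks exactly one parameter: the `return` inside the loop exits after the first iteration, which is what the name promises. A hedged usage sketch (the module and the comment about wrapping are illustrative, not taken from this test):

    import torch

    model = torch.nn.Linear(4, 4)
    ignore_the_first_parameter(model)    # registers `weight` as ddp-ignored
    # A ZeroDDP wrapper built afterwards would skip `weight` when building
    # chunks, which is the case the new state-dict branches above handle.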
@@ -47,7 +55,7 @@ def exam_state_dict(placement_policy, keep_gathered, model_name: str):
    for key, value in torch_dict.items():
        assert key in zero_dict, "{} not in ZeRO dictionary.".format(key)
        temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype)
-       assert torch.equal(value, temp_zero_value), "parameter '{}' has problem.".format(key)
+       assert_close(value, temp_zero_value, rtol=1e-3, atol=1e-5)


@parameterize('placement_policy', ['cuda', 'cpu', 'auto'])
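
Both assertions move from bit-exact `torch.equal` to `assert_close` with explicit tolerances (the same change appears in `exam_load_state_dict` below), since the fp16 round trip through the ZeRO state dict can introduce small numeric drift. A quick illustration at these tolerances:

    import torch
    from torch.testing import assert_close

    value = torch.tensor([1.0, 2.0, 3.0])
    restored = value + 1e-6    # drift well inside atol=1e-5

    assert not torch.equal(value, restored)    # exact comparison fails
    assert_close(value, restored, rtol=1e-3, atol=1e-5)    # tolerant check passes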
@@ -84,7 +92,7 @@ def exam_load_state_dict(placement_policy, keep_gathered, model_name: str):
    for key, value in torch_dict.items():
        assert key in zero_dict, "{} not in ZeRO dictionary.".format(key)
        temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype)
-       assert torch.equal(value, temp_zero_value), "parameter '{}' has problem.".format(key)
+       assert_close(value, temp_zero_value, rtol=1e-3, atol=1e-5)


def run_dist(rank, world_size, port):