5 changes: 4 additions & 1 deletion colossalai/booster/plugin/gemini_plugin.py
@@ -97,7 +97,7 @@ def save_sharded_model(

Path(checkpoint_path).mkdir(parents=True, exist_ok=True)

-state_dict_shard = model.state_dict_shard(max_shard_size=max_shard_size, only_rank_0=True, dtype=torch.float32)
+state_dict_shard = model.state_dict_shard(max_shard_size=max_shard_size, only_rank_0=True)
weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
index_file = CheckpointIndexFile(checkpoint_path)

@@ -257,6 +257,7 @@ class GeminiPlugin(DPPluginBase):
warmup_non_model_data_ratio (float, optional): ratio of expected non-model data memory during warmup. Only for "auto" placement. Defaults to 0.8.
steady_cuda_cap_ratio (float, optional): ratio of allowed cuda capacity for model data during steady state. Only for "auto" placement. Defaults to 0.9.
precision (str, optional): precision. Support 'fp16' and 'bf16'. Defaults to 'fp16'.
+master_weights (bool, optional): master weights. Defaults to True.
pin_memory (bool, optional): use pin memory on CPU. Defaults to False.
force_outputs_fp32 (bool, optional): force outputs are fp32. Defaults to False.
strict_ddp_mode (bool, optional): use strict ddp mode (only use dp without other parallelism). Defaults to False.
@@ -296,6 +297,7 @@ def __init__(
warmup_non_model_data_ratio: float = 0.8, # only for auto placement
steady_cuda_cap_ratio: float = 0.9, # only for auto placement
precision: str = "fp16",
+master_weights: bool = True,
pin_memory: bool = False,
force_outputs_fp32: bool = False,
strict_ddp_mode: bool = False,
@@ -334,6 +336,7 @@ def __init__(
min_chunk_size_m=min_chunk_size_m,
memstats=memstats,
mixed_precision=PRECISION_STR_TO_DTYPE[precision],
+master_weights=master_weights,
)
self.zero_optim_config = dict(
gpu_margin_mem_ratio=gpu_margin_mem_ratio,
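The new `master_weights` option threads through to the Gemini wrapper: when set to False, the fp32 master copies are skipped and optimizer states can end up in the parameters' half-precision dtype (which is why the fp32-only asserts in the CPU/Hybrid Adam optimizers below had to go). A minimal usage sketch, assuming a torchrun-launched ColossalAI process; the model and hyperparameters are placeholders:

```python
import torch

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.nn.optimizer import HybridAdam

colossalai.launch_from_torch(config={})  # assumes the script is started via torchrun

# master_weights=False: skip fp32 master copies; optimizer states stay in bf16
plugin = GeminiPlugin(precision="bf16", master_weights=False)
booster = Booster(plugin=plugin)

model = torch.nn.Linear(32, 32)
optimizer = HybridAdam(model.parameters(), lr=1e-3)
model, optimizer, *_ = booster.boost(model, optimizer)
```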
6 changes: 2 additions & 4 deletions colossalai/nn/optimizer/cpu_adam.py
@@ -131,9 +131,6 @@ def step(self, closure=None, div_scale: float = -1):
target_device = p.device
if len(state) == 0:
state["step"] = 0

-# FIXME(ver217): CPU adam kernel only supports fp32 states now
-assert p.dtype is torch.float, "CPUAdam only support fp32 parameters"
# gradient momentums
state["exp_avg"] = torch.zeros_like(p, device=target_device)
# gradient variances
@@ -148,7 +145,8 @@ def step(self, closure=None, div_scale: float = -1):
assert state["exp_avg"].device.type == "cpu", "exp_avg should stay on cpu"
assert state["exp_avg_sq"].device.type == "cpu", "exp_avg should stay on cpu"
self._pre_update(p, "exp_avg", "exp_avg_sq")
-if p.grad.dtype is torch.bfloat16:
+# FIXME(ver217): CPU adam kernel only supports fp32 states now
+if p.grad.dtype is torch.bfloat16 or p.dtype is not torch.float:
# cpu adam kernel does not support bf16 now
bias_correction1 = 1 - beta1 ** state["step"]
bias_correction2 = 1 - beta2 ** state["step"]
6 changes: 2 additions & 4 deletions colossalai/nn/optimizer/hybrid_adam.py
@@ -108,9 +108,6 @@ def step(self, closure=None, div_scale: float = -1):
target_device = p.device
if len(state) == 0:
state["step"] = 0

-# FIXME(ver217): CPU adam kernel only supports fp32 states now
-assert p.dtype is torch.float, "HybridAdam only support fp32 parameters"
# gradient momentums
state["exp_avg"] = torch.zeros_like(p, device=target_device)
# gradient variances
@@ -125,7 +122,8 @@ def step(self, closure=None, div_scale: float = -1):
assert state["exp_avg"].device.type == "cpu", "exp_avg should stay on cpu"
assert state["exp_avg_sq"].device.type == "cpu", "exp_avg should stay on cpu"
self._pre_update(p, "exp_avg", "exp_avg_sq")
-if p.grad.dtype is torch.bfloat16:
+# FIXME(ver217): CPU adam kernel only supports fp32 states now
+if p.grad.dtype is torch.bfloat16 or p.dtype is not torch.float:
# cpu adam kernel does not support bf16 now
bias_correction1 = 1 - beta1 ** state["step"]
bias_correction2 = 1 - beta2 ** state["step"]
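In both CPUAdam and HybridAdam above, the fp32-only assertion is dropped and the widened condition instead routes bf16 gradients and non-fp32 parameters through the non-fused update path, since the fused CPU kernel still only supports fp32 states. For reference, a self-contained sketch of the bias-corrected Adam step such a fallback computes; this is a generic illustration in plain PyTorch, not the library's internal helper:

```python
import math

import torch


def adam_step(p, grad, exp_avg, exp_avg_sq, step, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=0.0):
    """One bias-corrected Adam update on plain tensors; works for any floating dtype."""
    bias_correction1 = 1 - beta1 ** step
    bias_correction2 = 1 - beta2 ** step
    if weight_decay != 0:
        grad = grad.add(p, alpha=weight_decay)  # classic (non-decoupled) weight decay
    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)  # first moment
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)  # second moment
    denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)
    p.addcdiv_(exp_avg, denom, value=-lr / bias_correction1)
```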
6 changes: 5 additions & 1 deletion colossalai/testing/comparison.py
@@ -40,7 +40,7 @@ def assert_equal_in_group(tensor: Tensor, process_group: ProcessGroup = None):
assert torch.all(a == b), f"expected tensors on rank {i} and {i + 1} to be equal but they are not, {a} vs {b}"


-def check_state_dict_equal(d1: OrderedDict, d2: OrderedDict, ignore_device: bool = True):
+def check_state_dict_equal(d1: OrderedDict, d2: OrderedDict, ignore_device: bool = True, ignore_dtype: bool = False):
assert len(list(d1.keys())) == len(
list(d2.keys())
), f"Number of keys unequal: {len(list(d1.keys()))} vs {len(list(d2.keys()))}"
@@ -58,6 +58,8 @@ def check_state_dict_equal(d1: OrderedDict, d2: OrderedDict, ignore_device: bool
if not ignore_device:
v1_i = v1_i.to("cpu")
v2_i = v2_i.to("cpu")
+if ignore_dtype:
+    v1_i = v1_i.to(v2_i.dtype)
assert_close_loose(v1_i, v2_i)
elif isinstance(v1_i, dict):
assert isinstance(v2_i, dict)
@@ -69,6 +71,8 @@ def check_state_dict_equal(d1: OrderedDict, d2: OrderedDict, ignore_device: bool
if not ignore_device:
v1 = v1.to("cpu")
v2 = v2.to("cpu")
+if ignore_dtype:
+    v1 = v1.to(v2.dtype)
assert_close_loose(v1, v2)
else:
assert v1 == v2, f"{v1} not equals to {v2}"
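A small usage sketch of the new `ignore_dtype` flag: values from the first dict are cast to the second dict's dtype before the loose comparison, so a bf16 checkpoint can be checked against an fp32 reference. The tensors below are placeholders:

```python
from collections import OrderedDict

import torch

from colossalai.testing.comparison import check_state_dict_equal

d_bf16 = OrderedDict([("fc.weight", torch.ones(4, dtype=torch.bfloat16))])
d_ref = OrderedDict([("fc.weight", torch.ones(4, dtype=torch.float32))])

# Without ignore_dtype the dtype mismatch may trip the underlying close-check;
# with it, d_bf16's values are cast to float32 first.
check_state_dict_equal(d_bf16, d_ref, ignore_dtype=True)
```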
52 changes: 50 additions & 2 deletions colossalai/zero/gemini/chunk/chunk.py
@@ -160,6 +160,8 @@ def __init__(
self.l2_norm_flag = False
self.l2_norm = None

+self.grad_chunk = None

@property
def memory_usage(self) -> Dict[str, int]:
cuda_memory = 0
@@ -414,7 +416,9 @@ def tensor_trans_state(self, tensor: torch.Tensor, tensor_state: TensorState) ->
return
self.__update_one_tensor_info(self.tensors_info[tensor], tensor_state)

-def copy_tensor_to_chunk_slice(self, tensor: torch.Tensor, data_slice: torch.Tensor) -> None:
+def copy_tensor_to_chunk_slice(
+    self, tensor: torch.Tensor, data_slice: torch.Tensor, update_ptr: bool = True
+) -> None:
"""
Copy data slice to the memory space indexed by the input tensor in the chunk.

@@ -427,7 +431,8 @@ def copy_tensor_to_chunk_slice(self, tensor: torch.Tensor, data_slice: torch.Ten

tensor_info = self.tensors_info[tensor]
self.cuda_global_chunk[tensor_info.offset : tensor_info.end].copy_(data_slice.data.flatten())
-tensor.data = self.cuda_global_chunk[tensor_info.offset : tensor_info.end].view(tensor.shape)
+if update_ptr:
+    tensor.data = self.cuda_global_chunk[tensor_info.offset : tensor_info.end].view(tensor.shape)

def get_valid_length(self) -> int:
"""Get the valid length of the chunk's payload."""
@@ -577,3 +582,46 @@ def print_tensor(tensor, prefix=""):
output.append("\t\t# of {}: {}\n".format(st, self.tensor_state_cnter[st]))

return "".join(output)

+    def init_grad_chunk(self) -> "Chunk":
+        """Init grad chunk. This should be called in grad handler.
+
+        Returns:
+            Chunk: Grad chunk
+        """
+        if self.grad_chunk is None:
+            # grad chunk is not initialized
+            grad_chunk = Chunk(
+                chunk_size=self.chunk_size,
+                process_group=self.torch_pg,
+                dtype=self.dtype,
+                keep_gathered=self.keep_gathered,
+                pin_memory=self.pin_memory,
+            )
+            grad_chunk.num_tensors = self.num_tensors
+            grad_chunk.utilized_size = self.utilized_size
+            grad_chunk.tensor_state_cnter[TensorState.HOLD] = self.num_tensors
+            for tensor, state in self.tensors_info.items():
+                grad_chunk.tensors_info[tensor] = TensorInfo(TensorState.HOLD, state.offset, state.end)
+
+            grad_chunk.valid_end = self.valid_end
+
+            if grad_chunk.chunk_temp.device.type == "cpu":
+                grad_chunk.cuda_global_chunk = grad_chunk.chunk_temp.to(get_current_device())
+            else:
+                grad_chunk.cuda_global_chunk = grad_chunk.chunk_temp
+            grad_chunk.chunk_temp = None
+
+            if grad_chunk.pin_memory:
+                grad_chunk.cpu_shard = torch.empty(
+                    grad_chunk.shard_size, dtype=grad_chunk.dtype, pin_memory=grad_chunk.pin_memory
+                )
+
+            self.grad_chunk = grad_chunk
+        else:
+            # grad chunk is initialized, just reallocate cuda global chunk
+            self.grad_chunk.cuda_shard = None
+            self.grad_chunk.is_gathered = True
+            alloc_storage(self.grad_chunk.cuda_global_chunk)
+
+        return self.grad_chunk
10 changes: 10 additions & 0 deletions colossalai/zero/gemini/chunk/manager.py
@@ -245,3 +245,13 @@ def __sub_accessed_chunk(self, chunk: Chunk):
chunk.release_chunk()
self.accessed_chunks.remove(chunk)
self.accessed_mem -= chunk.chunk_mem

+    def init_grad_chunk(self, chunk: Chunk) -> Chunk:
+        if chunk.grad_chunk is not None:
+            self.__sub_memory_usage(chunk.grad_chunk.memory_usage)
+        grad_chunk = chunk.init_grad_chunk()
+        self.__add_memory_usage(grad_chunk.memory_usage)
+        if grad_chunk not in self.accessed_chunks:
+            self.accessed_chunks.add(grad_chunk)
+            self.accessed_mem += grad_chunk.chunk_mem
+        return grad_chunk
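Taken together with `Chunk.init_grad_chunk` above, this lets a gradient handler lazily create (or re-allocate) the gradient chunk that mirrors a parameter chunk and write gradients into it without re-pointing the parameter tensor. A rough sketch of the call pattern, assuming an already-built `ChunkManager` and a Gemini-managed parameter; the `get_chunk` accessor and this wiring are an assumption for illustration, not code from this PR:

```python
import torch

from colossalai.zero.gemini.chunk import Chunk, ChunkManager


def write_param_grad(chunk_manager: ChunkManager, param: torch.nn.Parameter) -> Chunk:
    # Chunk that owns this parameter's payload (assumed accessor on ChunkManager).
    chunk = chunk_manager.get_chunk(param)
    # First call builds the grad chunk; later calls just re-allocate its CUDA storage,
    # with the manager keeping its memory accounting in sync.
    grad_chunk = chunk_manager.init_grad_chunk(chunk)
    # update_ptr=False copies the gradient into the grad chunk's slice without
    # redirecting param.data away from the parameter chunk.
    grad_chunk.copy_tensor_to_chunk_slice(param, param.grad, update_ptr=False)
    return grad_chunk
```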