177 changes: 177 additions & 0 deletions colossalai/auto_parallel/offload/amp_optimizer.py
@@ -0,0 +1,177 @@
from typing import Dict, Tuple
from enum import Enum
import torch
from torch.optim import Optimizer

from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
from colossalai.utils import get_current_device

from .base_offload_module import BaseOffloadModule
from .region_manager import RegionManager
from .region import Region


class OptimState(Enum):
SCALED = 0
UNSCALED = 1

class AMPOptimizer(ColossalaiOptimizer):
    """
    A wrapper around ``Optimizer`` that adds dynamic loss scaling and
    overflow handling for parameter-offloaded training.
    Code reference: https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/optimizer/zero_optimizer.py

Args:
optimizer (Optimizer): An Optimizer instance.
module (BaseOffloadModule): A ``BaseOffloadModule`` instance.
initial_scale (float, optional): Initial scale used by DynamicGradScaler. Defaults to 2**16.
growth_factor (float, optional): growth_factor used by DynamicGradScaler. Defaults to 2.
backoff_factor (float, optional): backoff_factor used by DynamicGradScaler. Defaults to 0.5.
        growth_interval (int, optional): growth_interval used by DynamicGradScaler. Defaults to 1000.
        hysteresis (int, optional): hysteresis used by DynamicGradScaler. Defaults to 2.
        min_scale (float, optional): Min scale used by DynamicGradScaler. Defaults to 1.
        max_scale (float, optional): Max scale used by DynamicGradScaler. Defaults to 2**32.
        clipping_norm (float, optional): The norm value used to clip gradients. Defaults to 0.0 (no clipping).
        norm_type (float, optional): norm_type used for `clip_grad_norm`. Defaults to 2.0.
"""

def __init__(self,
optimizer: Optimizer,
module: BaseOffloadModule,
initial_scale: float = 2**16,
growth_factor: float = 2,
backoff_factor: float = 0.5,
growth_interval: int = 1000,
hysteresis: int = 2,
min_scale: float = 1,
max_scale: float = 2**32,
clipping_norm: float = 0.0,
norm_type: float = 2.0):

super().__init__(optimizer)

self.module = module
self.optim_state = OptimState.UNSCALED
self.clipping_flag = clipping_norm > 0.0
self.max_norm = clipping_norm

self.region_manager: RegionManager = self.module.region_manager
self.param_to_range: Dict[torch.nn.Parameter, Tuple[int, int]] = dict()
self.param_to_region: Dict[torch.nn.Parameter, Region] = dict()

self.fp32_to_fp16_params: Dict[torch.Tensor, torch.nn.Parameter] = dict()

if self.clipping_flag:
assert norm_type == 2.0, "AMPOptimizer only supports L2 norm now"

self.__init__optimizer()

# Grad scaler
self.grad_scaler = DynamicGradScaler(initial_scale=initial_scale,
min_scale=min_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval,
hysteresis=hysteresis,
max_scale=max_scale)
self._found_overflow: torch.Tensor = torch.zeros(1, dtype=torch.int64, device=get_current_device())
self._logger = get_dist_logger()

def _set_grad_ptr(self):
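        # Rebind each fake parameter so the inner optimizer reads the CPU
        # gradient slice through .grad while updating the fp32 master slice
        # through .data.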
for group in self.param_groups:
for fake_param in group['params']:
region = self.param_to_region[fake_param]
begin, end = self.param_to_range[fake_param]

fake_param.data = region.cpu_grad[begin:end]
fake_param.grad = fake_param.data
fake_param.data = region.fp32_data[begin:end]

def _update_fp16_params(self):
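        # After the step, detach the fake parameters from real storage and
        # release the CPU gradient buffers held by their regions.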
none_tensor = torch.empty([0])
for group in self.param_groups:
for fake_param in group['params']:
assert fake_param.grad is None
fake_param.data = none_tensor
self.param_to_region[fake_param].cpu_grad = None

def _check_overflow(self):
# clear previous overflow record
self._found_overflow.fill_(self.module.overflow_counter.item())
return self._found_overflow.item() > 0

def _get_combined_scale(self):
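        # Returns the factor by which gradients must be divided before the
        # update; -1 tells the inner optimizer's step(div_scale=...) that
        # no division is needed.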
loss_scale = 1

if self.optim_state == OptimState.SCALED:
loss_scale = self.loss_scale
self.optim_state = OptimState.UNSCALED

combined_scale = loss_scale

if combined_scale == 1:
return -1
else:
return combined_scale

@property
def loss_scale(self):
return self.grad_scaler.scale.item()

def zero_grad(self, *args, **kwargs):
self.module.overflow_counter = torch.cuda.IntTensor([0])
return self.optim.zero_grad(set_to_none=True)

def step(self, *args, **kwargs):
# Copy gradients from model params to main params.
self._set_grad_ptr()

found_inf = self._check_overflow()
if found_inf:
self.optim_state = OptimState.UNSCALED # no need to unscale grad
self.grad_scaler.update(found_inf) # update gradient scaler
            self._logger.info('Found overflow. Skipping step.')
self.zero_grad() # reset all gradients
self._update_fp16_params()
return

        # Get the combined scale (currently just the loss scale; gradient
        # clipping is not yet folded in) so that inside the optimizer step
        # gradient = gradient / combined_scale.
combined_scale = self._get_combined_scale()
self.grad_scaler.update(found_inf)

ret = self.optim.step(div_scale=combined_scale, *args, **kwargs)
self.zero_grad()
self._update_fp16_params()
return ret

def clip_grad_norm(self, model: torch.nn.Module, max_norm: float, norm_type: float = 2.0):
raise NotImplementedError

def backward(self, loss: torch.Tensor):
loss = self.loss_scale * loss
self.optim_state = OptimState.SCALED
self.module.backward(loss)

def __init__optimizer(self):

for group in self.optim.param_groups:
fake_params_list = list()

for param in group['params']:
region = self.region_manager.get_region(param)
fake_param = torch.nn.Parameter(torch.empty([0]))
self.param_to_range[fake_param] = region.param_to_range[param]
self.param_to_region[fake_param] = region
fake_params_list.append(fake_param)

# Reset existing state dict key to the new main param.
if param in self.optim.state:
self.optim.state[fake_param] = self.optim.state.pop(param)

group['params'] = fake_params_list

# Leverage state_dict() and load_state_dict() to
# recast preexisting per-param state tensors
self.optim.load_state_dict(self.optim.state_dict())
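
For reference, a minimal sketch of how this wrapper is meant to be driven; the data and hyperparameters below are illustrative, and the inner optimizer must accept a ``div_scale`` keyword in ``step()``, as ColossalAI's ``HybridAdam`` does.

# Hypothetical usage sketch -- not part of this diff.
import torch
from colossalai.nn.optimizer import HybridAdam

# `module` is assumed to be a BaseOffloadModule, e.g. produced by memory_optimize().
optim = HybridAdam(module.parameters(), lr=1e-3)
amp_optim = AMPOptimizer(optim, module)

data = torch.randn(4, 1024).cuda()
loss = module(data).sum()    # forward runs in fp16
amp_optim.backward(loss)     # scales the loss, then calls module.backward()
amp_optim.step()             # unscales and steps, or skips the step on overflow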
109 changes: 109 additions & 0 deletions colossalai/auto_parallel/offload/base_offload_module.py
@@ -0,0 +1,109 @@
from typing import Optional, Set
from functools import partial
import torch
import torch.nn as nn

from colossalai.nn.parallel.data_parallel import _cast_float
from colossalai.gemini.tensor_utils import free_storage

from .region_manager import RegionManager
from .util import GlobalRuntimeInfo


class BaseOffloadModule:
"""
BaseOffloadModule: A model wrapper for parameter offloading.

Args:
model (nn.Module): model to apply offloading.
region_manager (RegionManager): a ``RegionManager`` instance.
        is_sync (bool): if True, offload gradients synchronously on the current
            stream; if False, asynchronously on a separate device-to-host stream.
"""

def __init__(self,
model: nn.Module,
region_manager: RegionManager,
is_sync=True):

self.model = model
self.region_manager = region_manager
self.grad_hook_list = []
self.overflow_counter = torch.cuda.IntTensor([0])

self.grad_offload_stream = torch.cuda.current_stream() if is_sync else GlobalRuntimeInfo.d2h_stream

self._cast_buffers()

def register_grad_hook(self):
for p in self.model.parameters():
if p.requires_grad:
self.grad_hook_list.append(p.register_hook(partial(self.grad_handle, p)))

def remove_grad_hook(self):
for hook in self.grad_hook_list:
hook.remove()

def __call__(self, *args, **kwargs):
return self.forward(*args, **kwargs)

def _pre_forward(self):
self.register_grad_hook()
for region in self.region_manager.region_list:
region.cpu_grad = None

def forward(self, *args, **kwargs):
args, kwargs = _cast_float(args, torch.half), _cast_float(kwargs, torch.half)
self.model.zero_grad(set_to_none=True)
self._pre_forward()
outputs = self.model(*args, **kwargs)
return outputs

def backward(self, loss):
loss.backward()
self._post_backward()

def _post_backward(self):
torch.cuda.synchronize()
self.remove_grad_hook()

for p in self.model.parameters():
p.grad = None

GlobalRuntimeInfo.fwd_prefetch_event_map.clear()
GlobalRuntimeInfo.bwd_prefetch_event_map.clear()

def grad_handle(self, p, grad):
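        # Called once per parameter gradient: the grad is copied into its
        # region's buffer, and a storage-freed tensor is returned so autograd
        # does not keep the original fp16 gradient alive.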
empty_grad = torch.empty_like(grad)
free_storage(empty_grad)
with torch._C.DisableTorchFunction():
region = self.region_manager.get_region(p)
region.copy_grad_to_region_slice(p, grad)
if region.can_release:
self.overflow_counter += region.has_inf_or_nan
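                # Let the offload stream wait on compute so the gradient is
                # fully produced before the device-to-host copy starts.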
master_stream = torch.cuda.current_stream()
with torch.cuda.stream(self.grad_offload_stream):
GlobalRuntimeInfo.d2h_stream.wait_stream(master_stream)
region.move_grad_to_cpu()
return empty_grad

def _cast_buffers(self):
for buffer in self.model.buffers():
buffer.data = buffer.cuda()

def parameters(self, recurse: bool = True):
return self.model.parameters(recurse)

def named_parameters(self, prefix: str = '', recurse: bool = True):
return self.model.named_parameters(prefix, recurse)

def named_buffers(self, prefix: str = '', recurse: bool = True):
return self.model.named_buffers(prefix, recurse)

def named_children(self):
return self.model.named_children()

def named_modules(self,
memo: Optional[Set[torch.nn.Module]] = None,
prefix: str = '',
remove_duplicate: bool = True):
return self.model.named_modules(memo, prefix, remove_duplicate)
49 changes: 49 additions & 0 deletions colossalai/auto_parallel/offload/mem_optimize.py
@@ -0,0 +1,49 @@
from typing import Dict
import torch
import torch.fx
from torch.fx import GraphModule
from torch.utils._pytree import tree_map

from colossalai.fx import ColoTracer, is_compatible_with_meta
from colossalai.fx.passes.meta_info_prop import MetaInfoProp

from .region_manager import RegionManager
from .runtime import runtime_syn_offload_apply_pass, runtime_asyn_offload_apply_pass
from .base_offload_module import BaseOffloadModule
from .util import compute_max_param_mem, compute_total_param_mem, compute_act_peak_mem, GlobalRuntimeInfo

def memory_optimize(model: torch.nn.Module,
inps: Dict[str, torch.Tensor],
memory_budget: float = -1.0,
solver_name: str = 'asyn'):
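    """
    Trace ``model`` with meta tensors, partition it into offload regions
    under ``memory_budget`` (in bytes), apply the chosen runtime offload
    pass, and return the result wrapped in a ``BaseOffloadModule``.
    ``solver_name`` selects the synchronous ('syn') or asynchronous
    ('asyn') solver.
    """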

model = model.cpu().half()
tracer = ColoTracer()
assert is_compatible_with_meta()
wrap_fn = lambda x: x.to("meta") if isinstance(x, torch.Tensor) else x
meta_args = tree_map(wrap_fn, inps)
graph = tracer.trace(model, meta_args=meta_args)
gm = GraphModule(model, graph, model.__class__.__name__)
interp = MetaInfoProp(gm)
interp.propagate(*meta_args.values())

region_manager = RegionManager(graph, solver_name=solver_name, memory_budget=memory_budget)
region_manager._build_regions()
GlobalRuntimeInfo.region_list = region_manager.region_list

act_peak_mem = compute_act_peak_mem(region_manager.region_list) / 1024 ** 2
max_param_mem = compute_max_param_mem(region_manager.region_list) / 1024 ** 2
total_param_mem = compute_total_param_mem(region_manager.region_list) / 1024 ** 2
    print(
        f"act_peak_mem={act_peak_mem:.3f} MB | max_param_mem={max_param_mem:.3f} MB | total_param_mem={total_param_mem:.3f} MB")

if solver_name == 'syn':
gm = runtime_syn_offload_apply_pass(gm, region_manager.region_list)
elif solver_name == 'asyn':
gm = runtime_asyn_offload_apply_pass(gm, region_manager.region_list)
else:
        raise ValueError(f"Unknown solver name {solver_name}!")

gm.recompile()
    optimized_model = BaseOffloadModule(gm, region_manager, is_sync=(solver_name == 'syn'))
return optimized_model
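
A minimal usage sketch of this entry point; the model class, input shape, and memory budget below are placeholders rather than values from this diff.

# Hypothetical usage sketch -- not part of this diff.
import torch

model = MyModel()                      # any ColoTracer-traceable nn.Module
inps = {'x': torch.randn(4, 1024)}     # keyword inputs used for tracing
module = memory_optimize(model, inps,
                         memory_budget=1024 ** 3,  # ~1 GiB, in bytes
                         solver_name='asyn')
# `module` can now be wrapped by AMPOptimizer for mixed-precision training.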