From c84902c980854bc2ae1802e19af00d3d53536596 Mon Sep 17 00:00:00 2001 From: 1SAA Date: Thu, 11 Nov 2021 17:09:46 +0800 Subject: [PATCH] Changed API in Schedule, Engine --- colossalai/engine/__init__.py | 2 +- colossalai/engine/_base_engine.py | 153 +++-- colossalai/engine/amp/__init__.py | 2 + colossalai/engine/amp/amp_type.py | 10 + colossalai/engine/amp/grad_scaler.py | 577 +++++++++++++++++++ colossalai/engine/schedule/_base_schedule.py | 146 ++--- colossalai/engine/schedule/_no_pipeline.py | 109 ++-- colossalai/engine/schedule/_pipeline.py | 91 +-- colossalai/engine/schedule/_utils.py | 11 + 9 files changed, 812 insertions(+), 289 deletions(-) create mode 100644 colossalai/engine/amp/__init__.py create mode 100644 colossalai/engine/amp/amp_type.py create mode 100644 colossalai/engine/amp/grad_scaler.py diff --git a/colossalai/engine/__init__.py b/colossalai/engine/__init__.py index c00be7df6e7b..7e55922363d8 100644 --- a/colossalai/engine/__init__.py +++ b/colossalai/engine/__init__.py @@ -1,7 +1,7 @@ -from .amp_type import AMP_TYPE from ._base_engine import Engine from .gradient_handler import * from .schedule import * +from .amp import * __all__ = ['Engine'] diff --git a/colossalai/engine/_base_engine.py b/colossalai/engine/_base_engine.py index 843ef1d4f046..3e87864c8a37 100644 --- a/colossalai/engine/_base_engine.py +++ b/colossalai/engine/_base_engine.py @@ -12,8 +12,6 @@ from torch.nn import Module from torch.nn.modules.loss import _Loss from torch.optim import Optimizer -from torch.optim.lr_scheduler import _LRScheduler -from torch.utils.data import DataLoader from .schedule import BaseSchedule, NoPipelineSchedule @@ -21,39 +19,33 @@ class Engine: """Basic engine class for training and evaluation. It runs a specific process method :meth:`step` which is based on the given :attr:`schedule` over each batch of a dataset. + It controls an iteration in training.
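+ Gradient accumulation and gradient clipping are handled inside :meth:`step`; the wrapped optimizer is only stepped once every ``gradient_accumulation`` calls.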
- :param train_dataloader: Dataloader in training - :param test_dataloader: Dataloader in evaluation :param model: The neural network model - :param criterion: Criterion for calculating loss :param optimizer: Optimizer for updating the parameters - :param lr_scheduler: Learning rate scheduler ajusting learning rate during the training or evaluation - :param schedule: Running schedule in :meth:`step` - :type train_dataloader: DataLoader, optional - :type test_dataloader: DataLoader, optional + :param step_schedule: Running schedule in :meth:`step` + :param gradient_accumulation: Steps of gradient accumulation + :param gradient_clipping: The norm of gradient clipping :type model: Module - :type criterion: _Loss, optional - :type optimizer: Optimizer, optional - :type lr_scheduler: _LRScheduler, optional - :type schedule: BaseSchedule, optional + :type optimizer: Optimizer + :type step_schedule: BaseSchedule, optional + :type gradient_accumulation: int, optional + :type gradient_clipping: float, optional """ + def __init__(self, - train_dataloader: Optional[DataLoader] = None, - test_dataloader: Optional[DataLoader] = None, - model: Module = None, - criterion: _Loss = None, - optimizer: Optimizer = None, - lr_scheduler: Optional[_LRScheduler] = None, - schedule: BaseSchedule = None): - self.train_dataloader = train_dataloader - self.test_dataloader = test_dataloader - assert model is not None, "Engine requires a model" - self.model = model - self.criterion = criterion - self.optimizer = optimizer - self.lr_scheduler = lr_scheduler - self.schedule = schedule if schedule is not None \ + model: Module, + optimizer: Optimizer, + step_schedule: BaseSchedule = None, + gradient_accumulation: int = 1, + gradient_clipping: float = 0.0): + self.schedule = step_schedule if step_schedule is not None \ else NoPipelineSchedule() + self.schedule.initialize(model, optimizer) + self.grad_accum_size = gradient_accumulation + self.grad_accum_cur_step = 0 + self.grad_clip = gradient_clipping + self.training = True # default self._logger = get_global_dist_logger() # build gradient handler @@ -65,8 +57,8 @@ def __init__(self, f'argument gradient_handler_cfg expected type list, ' \ f'but got type {type(gpc.config.gradient_handler)}' gradient_handler_cfg = gpc.config.gradient_handler - elif isinstance(self.optimizer, (ZeroRedundancyOptimizer_Level_2, - ZeroRedundancyOptimizer_Level_3)): + elif isinstance(optimizer, (ZeroRedundancyOptimizer_Level_2, + ZeroRedundancyOptimizer_Level_3)): gradient_handler_cfg = [dict(type='ZeROGradientHandler')] self._logger.info( "Training with zero is detected, ZeROGradientHandler is automatically " @@ -85,86 +77,71 @@ def __init__(self, "to all-reduce the gradients after a training step.", ranks=[0]) for cfg in gradient_handler_cfg: - handler = build_gradient_handler(cfg, self.model, self.optimizer) + handler = build_gradient_handler(cfg, model, optimizer) self._gradient_handlers.append(handler) - self.schedule.initialize(self.train_dataloader, self.model, - self.criterion, self.optimizer, - self.lr_scheduler) - self.forward_only = False - def handle_gradient(self): """Handles all-reduce operations of gradients across different parallel groups. """ for handler in self._gradient_handlers: handler.handle_gradient() - def set_dataloader(self, data: DataLoader, train: bool = True): - """Sets dataloader in training or evaluation. 
- - :param data: Dataloader to be set - :param train: Set training dataloader if True, otherwise evaluation dataloader - :type data: DataLoader - :type train: bool - """ - if train: - self.train_dataloader = data - else: - self.test_dataloader = data - - def get_model(self): - """Returns the neural network model in the engine. - """ - return self.model - def get_optimizer(self): - """Returns optimizier in the engine. - """ - return self.optimizer - - def get_lr_scheduler(self): - """Returns the learning rate scheduler in the engine. - """ - return self.lr_scheduler - def train(self): """Sets the engine to training mode. """ - self.forward_only = False - self.schedule.train(dataloader=self.train_dataloader, mode=True) + self.training = True def eval(self): """Sets the engine to evaluation mode. """ - self.forward_only = True - self.schedule.train(dataloader=self.test_dataloader, mode=False) - - def is_train(self): - """Returns True if it is in training, otherwise False. - """ - return not self.forward_only - - def get_lr(self): - """Gets current learning rate. - """ - return self.schedule.get_lr() - - def step(self, return_loss=True): + self.training = False + + def step(self, + data_iter, + model: Module, + criterion: _Loss, + optimizer: Optimizer = None, + is_last_iteration: bool = False, + return_loss=True): """A running step based on the schedule. Usually, it runs a training or evaluation over a batch of the dataset. + :param data_iter: Data iterator of the dataset + :param model: The neural network model + :param criterion: Loss function used to calculate the loss + :param optimizer: Optimizer for updating the parameters + :param is_last_iteration: If True, this iteration is the last iteration in the epoch :param return_loss: loss will be returned if True - :type return_loss: bool + :type data_iter: Iterator + :type model: Module + :type criterion: _Loss + :type optimizer: Optimizer, optional + :type is_last_iteration: bool, optional + :type return_loss: bool, optional :return: (output, label, loss) """ - self.schedule.zero_grad(forward_only=self.forward_only) + if self.training and self.grad_accum_cur_step == 0: + optimizer.zero_grad() output, label, loss = self.schedule.forward_backward_step( - forward_only=self.forward_only, return_loss=return_loss) - - if not self.forward_only: - # all reduce gradients - self.handle_gradient() - - self.schedule.step() + data_iter, model, criterion, optimizer, + forward_only=not self.training, + grad_accum_size=self.grad_accum_size, + return_loss=return_loss) + + if self.training: + self.grad_accum_cur_step += 1 + if self.grad_accum_cur_step == self.grad_accum_size: + # all reduce gradients + self.handle_gradient() + self.schedule.optimizer_step(model, optimizer, self.grad_clip) + self.grad_accum_cur_step = 0 + + if is_last_iteration: + while True: + try: + trash = next(data_iter) + except StopIteration: + break return output, label, loss diff --git a/colossalai/engine/amp/__init__.py b/colossalai/engine/amp/__init__.py new file mode 100644 index 000000000000..927d5cf09d1a --- /dev/null +++ b/colossalai/engine/amp/__init__.py @@ -0,0 +1,2 @@ +from .grad_scaler import GradScaler +from .amp_type import AMP_TYPE diff --git a/colossalai/engine/amp/amp_type.py b/colossalai/engine/amp/amp_type.py new file mode 100644 index 000000000000..7f7c5a659df0 --- /dev/null +++ b/colossalai/engine/amp/amp_type.py @@ -0,0 +1,10 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +from enum import Enum + + +class AMP_TYPE(Enum): + APEX = 'apex' + TORCH = 'torch' + PARALLEL = 
'parallel' diff --git a/colossalai/engine/amp/grad_scaler.py b/colossalai/engine/amp/grad_scaler.py new file mode 100644 index 000000000000..7859d132db17 --- /dev/null +++ b/colossalai/engine/amp/grad_scaler.py @@ -0,0 +1,577 @@ +# modified from https://github.com/pytorch/pytorch/blob/master/torch/cuda/amp/grad_scaler.py +import torch +from collections import defaultdict, abc +import warnings +from enum import Enum +from typing import Any, Dict, List, Optional, Tuple +from colossalai.context import ParallelMode +import torch.distributed as dist +from colossalai.core import global_context as gpc + + +class _MultiDeviceReplicator(object): + """ + Lazily serves copies of a tensor to requested devices. Copies are cached per-device. + """ + + def __init__(self, master_tensor: torch.Tensor) -> None: + assert master_tensor.is_cuda or master_tensor.device.type == 'xla' + self.master = master_tensor + self._per_device_tensors: Dict[torch.device, torch.Tensor] = {} + + def get(self, device) -> torch.Tensor: + retval = self._per_device_tensors.get(device, None) + if retval is None: + retval = self.master.to( + device=device, non_blocking=True, copy=True) + self._per_device_tensors[device] = retval + return retval + + +# Defines default_factory for GradScaler's _per_optimizer_states defaultdict, +# as well as associated "enum" values. Prefers defining these at top level because +# - Lambdas can't be pickled, so we don't want to supply a lambda as the factory. +# - Defining READY, UNSCALED, STEPPED and _refresh_per_optimizer_state within GradScaler +# causes a circular reference, which we'd rather avoid. +class OptState(Enum): + READY = 0 + UNSCALED = 1 + STEPPED = 2 + + +def _refresh_per_optimizer_state(): + return {"stage": OptState.READY, "found_inf_per_device": {}} + + +class GradScaler(object): + _scale: Optional[torch.Tensor] + _growth_tracker: Optional[torch.Tensor] + _per_optimizer_states: Dict[int, Dict[str, Any]] + """ + An instance ``scaler`` of :class:`GradScaler` helps perform the steps of gradient scaling + conveniently. + + * ``scaler.scale(loss)`` multiplies a given loss by ``scaler``'s current scale factor. + * ``scaler.step(optimizer)`` safely unscales gradients and calls ``optimizer.step()``. + * ``scaler.update()`` updates ``scaler``'s scale factor. + + Example:: + + # Creates a GradScaler once at the beginning of training. + scaler = GradScaler() + + for epoch in epochs: + for input, target in data: + optimizer.zero_grad() + output = model(input) + loss = loss_fn(output, target) + + # Scales loss. Calls backward() on scaled loss to create scaled gradients. + scaler.scale(loss).backward() + + # scaler.step() first unscales gradients of the optimizer's params. + # If gradients don't contain infs/NaNs, optimizer.step() is then called, + # otherwise, optimizer.step() is skipped. + scaler.step(optimizer) + + # Updates the scale for next iteration. + scaler.update() + + See the :ref:`Automatic Mixed Precision examples` for usage + (along with autocasting) in more complex cases like gradient clipping, gradient accumulation, gradient penalty, + and multiple losses/optimizers. + + ``scaler`` dynamically estimates the scale factor each iteration. To minimize gradient underflow, + a large scale factor should be used. However, ``float16`` values can "overflow" (become inf or NaN) if + the scale factor is too large. Therefore, the optimal scale factor is the largest factor that can be used + without incurring inf or NaN gradient values.
+ ``scaler`` approximates the optimal scale factor over time by checking the gradients for infs and NaNs during every + ``scaler.step(optimizer)`` (or optional separate ``scaler.unscale_(optimizer)``, see :meth:`unscale_`). + + * If infs/NaNs are found, ``scaler.step(optimizer)`` skips the underlying ``optimizer.step()`` (so the params + themselves remain uncorrupted) and ``update()`` multiplies the scale by ``backoff_factor``. + + * If no infs/NaNs are found, ``scaler.step(optimizer)`` runs the underlying ``optimizer.step()`` as usual. + If ``growth_interval`` unskipped iterations occur consecutively, ``update()`` multiplies the scale by + ``growth_factor``. + + The scale factor often causes infs/NaNs to appear in gradients for the first few iterations as its + value calibrates. ``scaler.step`` will skip the underlying ``optimizer.step()`` for these + iterations. After that, step skipping should occur rarely (once every few hundred or thousand iterations). + + Args: + init_scale (float, optional, default=2.**16): Initial scale factor. + growth_factor (float, optional, default=2.0): Factor by which the scale is multiplied during + :meth:`update` if no inf/NaN gradients occur for ``growth_interval`` consecutive iterations. + backoff_factor (float, optional, default=0.5): Factor by which the scale is multiplied during + :meth:`update` if inf/NaN gradients occur in an iteration. + growth_interval (int, optional, default=2000): Number of consecutive iterations without inf/NaN gradients + that must occur for the scale to be multiplied by ``growth_factor``. + enabled (bool, optional, default=True): If ``False``, disables gradient scaling. :meth:`step` simply + invokes the underlying ``optimizer.step()``, and other methods become no-ops. + """ + + def __init__(self, + init_scale=2.**16, + growth_factor=2.0, + backoff_factor=0.5, + growth_interval=2000, + enabled=True): + if enabled and not torch.cuda.is_available(): + warnings.warn( + "torch.cuda.amp.GradScaler is enabled, but CUDA is not available. Disabling.") + self._enabled = False + else: + self._enabled = enabled + + if self._enabled: + assert growth_factor > 1.0, "The growth factor must be > 1.0." + assert backoff_factor < 1.0, "The backoff factor must be < 1.0." + + self._init_scale = init_scale + # self._scale will be lazily initialized during the first call to scale() + self._scale = None + self._growth_factor = growth_factor + self._backoff_factor = backoff_factor + self._growth_interval = growth_interval + self._init_growth_tracker = 0 + # self._growth_tracker will be lazily initialized during the first call to scale() + self._growth_tracker = None + self._per_optimizer_states = defaultdict( + _refresh_per_optimizer_state) + + def _check_scale_growth_tracker(self, funcname) -> Tuple[torch.Tensor, torch.Tensor]: + fix = "This may indicate your script did not use scaler.scale(loss or outputs) earlier in the iteration." + assert self._scale is not None, "Attempted {} but _scale is None. ".format( + funcname) + fix + assert self._growth_tracker is not None, "Attempted {} but _growth_tracker is None. 
".format( + funcname) + fix + return (self._scale, self._growth_tracker) + + def _lazy_init_scale_growth_tracker(self, dev): + assert self._growth_tracker is None, "_growth_tracker initialized before _scale" + self._scale = torch.full( + (1,), self._init_scale, dtype=torch.float32, device=dev) + self._growth_tracker = torch.full( + (1,), self._init_growth_tracker, dtype=torch.int32, device=dev) + + def scale(self, outputs): + """ + Multiplies ('scales') a tensor or list of tensors by the scale factor. + + Returns scaled outputs. If this instance of :class:`GradScaler` is not enabled, outputs are returned + unmodified. + + Args: + outputs (Tensor or iterable of Tensors): Outputs to scale. + """ + if not self._enabled: + return outputs + + # Short-circuit for the common case. + if isinstance(outputs, torch.Tensor): + assert outputs.is_cuda or outputs.device.type == 'xla' + if self._scale is None: + self._lazy_init_scale_growth_tracker(outputs.device) + assert self._scale is not None + return outputs * self._scale.to(device=outputs.device, non_blocking=True) + + # Invoke the more complex machinery only if we're treating multiple outputs. + # holds a reference that can be overwritten by apply_scale + stash: List[_MultiDeviceReplicator] = [] + + def apply_scale(val): + if isinstance(val, torch.Tensor): + assert val.is_cuda or val.device.type == 'xla' + if len(stash) == 0: + if self._scale is None: + self._lazy_init_scale_growth_tracker(val.device) + assert self._scale is not None + stash.append(_MultiDeviceReplicator(self._scale)) + return val * stash[0].get(val.device) + elif isinstance(val, abc.Iterable): + iterable = map(apply_scale, val) + if isinstance(val, list) or isinstance(val, tuple): + return type(val)(iterable) + else: + return iterable + else: + raise ValueError( + "outputs must be a Tensor or an iterable of Tensors") + + return apply_scale(outputs) + + def _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_fp16): + per_device_inv_scale = _MultiDeviceReplicator(inv_scale) + per_device_found_inf = _MultiDeviceReplicator(found_inf) + + # To set up _amp_foreach_non_finite_check_and_unscale_, split grads by device and dtype. + # There could be hundreds of grads, so we'd like to iterate through them just once. + # However, we don't know their devices or dtypes in advance. + + # https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict + # Google says mypy struggles with defaultdicts type annotations. + per_device_and_dtype_grads = defaultdict( + lambda: defaultdict(list)) # type: ignore[var-annotated] + with torch.no_grad(): + for group in optimizer.param_groups: + for param in group["params"]: + if param.grad is None: + continue + if (not allow_fp16) and param.grad.dtype == torch.float16: + raise ValueError( + "Attempting to unscale FP16 gradients.") + if param.grad.is_sparse: + # is_coalesced() == False means the sparse grad has values with duplicate indices. + # coalesce() deduplicates indices and adds all values that have the same index. + # For scaled fp16 values, there's a good chance coalescing will cause overflow, + # so we should check the coalesced _values(). + if param.grad.dtype is torch.float16: + param.grad = param.grad.coalesce() + to_unscale = param.grad._values() + else: + to_unscale = param.grad + + # TODO: is there a way to split by device and dtype without appending in the inner loop? 
+ per_device_and_dtype_grads[to_unscale.device][to_unscale.dtype].append( + to_unscale) + + for device, per_dtype_grads in per_device_and_dtype_grads.items(): + for grads in per_dtype_grads.values(): + torch._amp_foreach_non_finite_check_and_unscale_(grads, + per_device_found_inf.get( + device), + per_device_inv_scale.get(device)) + # For tensor parallel parameters it should be all-reduced over the tensor parallel process group + if gpc.is_initialized(ParallelMode.TENSOR) and gpc.get_world_size(ParallelMode.TENSOR) > 1: + for tensor in per_device_found_inf._per_device_tensors.values(): + dist.all_reduce(tensor, op=dist.ReduceOp.MAX, + group=gpc.get_group(ParallelMode.TENSOR)) + return per_device_found_inf._per_device_tensors + + def unscale_(self, optimizer): + """ + Divides ("unscales") the optimizer's gradient tensors by the scale factor. + + :meth:`unscale_` is optional, serving cases where you need to + :ref:`modify or inspect gradients` + between the backward pass(es) and :meth:`step`. + If :meth:`unscale_` is not called explicitly, gradients will be unscaled automatically during :meth:`step`. + + Simple example, using :meth:`unscale_` to enable clipping of unscaled gradients:: + + ... + scaler.scale(loss).backward() + scaler.unscale_(optimizer) + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm) + scaler.step(optimizer) + scaler.update() + + Args: + optimizer (torch.optim.Optimizer): Optimizer that owns the gradients to be unscaled. + + .. note:: + :meth:`unscale_` does not incur a CPU-GPU sync. + + .. warning:: + :meth:`unscale_` should only be called once per optimizer per :meth:`step` call, + and only after all gradients for that optimizer's assigned parameters have been accumulated. + Calling :meth:`unscale_` twice for a given optimizer between each :meth:`step` triggers a RuntimeError. + + .. warning:: + :meth:`unscale_` may unscale sparse gradients out of place, replacing the ``.grad`` attribute. + """ + if not self._enabled: + return + + self._check_scale_growth_tracker("unscale_") + + optimizer_state = self._per_optimizer_states[id(optimizer)] + + if optimizer_state["stage"] is OptState.UNSCALED: + raise RuntimeError( + "unscale_() has already been called on this optimizer since the last update().") + elif optimizer_state["stage"] is OptState.STEPPED: + raise RuntimeError("unscale_() is being called after step().") + + # FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64. + assert self._scale is not None + inv_scale = self._scale.double().reciprocal().float() + found_inf = torch.full( + (1,), 0.0, dtype=torch.float32, device=self._scale.device) + + optimizer_state["found_inf_per_device"] = self._unscale_grads_( + optimizer, inv_scale, found_inf, False) + optimizer_state["stage"] = OptState.UNSCALED + + def _maybe_opt_step(self, optimizer, optimizer_state, *args, **kwargs): + retval = None + if not sum(v.item() for v in optimizer_state["found_inf_per_device"].values()): + retval = optimizer.step(*args, **kwargs) + return retval + + def step(self, optimizer, *args, **kwargs): + """ + :meth:`step` carries out the following two operations: + + 1. Internally invokes ``unscale_(optimizer)`` (unless :meth:`unscale_` was explicitly called for ``optimizer`` + earlier in the iteration). As part of the :meth:`unscale_`, gradients are checked for infs/NaNs. + 2. If no inf/NaN gradients are found, invokes ``optimizer.step()`` using the unscaled + gradients. Otherwise, ``optimizer.step()`` is skipped to avoid corrupting the params.
+ + ``*args`` and ``**kwargs`` are forwarded to ``optimizer.step()``. + + Returns the return value of ``optimizer.step(*args, **kwargs)``. + + Args: + optimizer (torch.optim.Optimizer): Optimizer that applies the gradients. + args: Any arguments. + kwargs: Any keyword arguments. + + .. warning:: + Closure use is not currently supported. + """ + if (not self._enabled): + return optimizer.step(*args, **kwargs) + + if "closure" in kwargs: + raise RuntimeError( + "Closure use is not currently supported if GradScaler is enabled.") + + self._check_scale_growth_tracker("step") + + optimizer_state = self._per_optimizer_states[id(optimizer)] + + if optimizer_state["stage"] is OptState.STEPPED: + raise RuntimeError( + "step() has already been called since the last update().") + + retval = None + + if (hasattr(optimizer, "_step_supports_amp_scaling") and optimizer._step_supports_amp_scaling): + # This optimizer has customized scale-handling logic, so we can call optimizer.step() directly. + # The contract with custom optimizers is that their step() should accept an additional, + # optional grad_scaler kwarg. We append self to the kwargs so the custom optimizer has full information: + # it can query its own state, invoke unscale_ on itself, etc + retval = optimizer.step(*args, **dict(kwargs, grad_scaler=self)) + optimizer_state["stage"] = OptState.STEPPED + return retval + + if optimizer_state["stage"] is OptState.READY: + self.unscale_(optimizer) + + assert len(optimizer_state["found_inf_per_device"] + ) > 0, "No inf checks were recorded for this optimizer." + + retval = self._maybe_opt_step( + optimizer, optimizer_state, *args, **kwargs) + + optimizer_state["stage"] = OptState.STEPPED + + return retval + + def update(self, new_scale=None): + """ + Updates the scale factor. + + If any optimizer steps were skipped the scale is multiplied by ``backoff_factor`` + to reduce it. If ``growth_interval`` unskipped iterations occurred consecutively, + the scale is multiplied by ``growth_factor`` to increase it. + + Passing ``new_scale`` sets the new scale value manually. (``new_scale`` is not + used directly, it's used to fill GradScaler's internal scale tensor. So if + ``new_scale`` was a tensor, later in-place changes to that tensor will not further + affect the scale GradScaler uses internally.) + + Args: + new_scale (float or :class:`torch.cuda.FloatTensor`, optional, default=None): New scale factor. + + .. warning:: + :meth:`update` should only be called at the end of the iteration, after ``scaler.step(optimizer)`` has + been invoked for all optimizers used this iteration. + """ + if not self._enabled: + return + + _scale, _growth_tracker = self._check_scale_growth_tracker("update") + + if new_scale is not None: + # Accept a new user-defined scale. + if isinstance(new_scale, float): + self._scale.fill_(new_scale) # type: ignore[union-attr] + else: + reason = "new_scale should be a float or a 1-element torch.cuda.FloatTensor with requires_grad=False." + # type: ignore[attr-defined] + assert isinstance(new_scale, torch.cuda.FloatTensor), reason + assert new_scale.numel() == 1, reason + assert new_scale.requires_grad is False, reason + self._scale.copy_(new_scale) # type: ignore[union-attr] + else: + # Consume shared inf/nan data collected from optimizers to update the scale. + # If all found_inf tensors are on the same device as self._scale, this operation is asynchronous. 
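+ # a nonzero found_inf flag means some optimizer produced inf/NaN gradients this iteration, so _amp_update_scale_ below backs the scale off instead of growing it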
+ found_infs = [found_inf.to(device=_scale.device, non_blocking=True) + for state in self._per_optimizer_states.values() + for found_inf in state["found_inf_per_device"].values()] + + assert len( + found_infs) > 0, "No inf checks were recorded prior to update." + + found_inf_combined = found_infs[0] + if len(found_infs) > 1: + for i in range(1, len(found_infs)): + found_inf_combined += found_infs[i] + + torch._amp_update_scale_(_scale, + _growth_tracker, + found_inf_combined, + self._growth_factor, + self._backoff_factor, + self._growth_interval) + + # To prepare for next iteration, clear the data collected from optimizers this iteration. + self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state) + + def _get_scale_async(self): + return self._scale + + def get_scale(self): + """ + Returns a Python float containing the current scale, or 1.0 if scaling is disabled. + + .. warning:: + :meth:`get_scale` incurs a CPU-GPU sync. + """ + if self._enabled: + return self._init_scale if self._scale is None else self._get_scale_async().item() + else: + return 1.0 + + def get_growth_factor(self): + r""" + Returns a Python float containing the scale growth factor. + """ + return self._growth_factor + + def set_growth_factor(self, new_factor): + r""" + Args: + new_factor (float): Value to use as the new scale growth factor. + """ + self._growth_factor = new_factor + + def get_backoff_factor(self): + r""" + Returns a Python float containing the scale backoff factor. + """ + return self._backoff_factor + + def set_backoff_factor(self, new_factor): + r""" + Args: + new_factor (float): Value to use as the new scale backoff factor. + """ + self._backoff_factor = new_factor + + def get_growth_interval(self): + r""" + Returns a Python int containing the growth interval. + """ + return self._growth_interval + + def set_growth_interval(self, new_interval): + r""" + Args: + new_interval (int): Value to use as the new growth interval. + """ + self._growth_interval = new_interval + + def _get_growth_tracker(self): + if self._enabled: + return self._init_growth_tracker if self._growth_tracker is None else self._growth_tracker.item() + else: + return 0 + + def is_enabled(self): + r""" + Returns a bool indicating whether this instance is enabled. + """ + return self._enabled + + def state_dict(self): + r""" + Returns the state of the scaler as a :class:`dict`. It contains five entries: + + * ``"scale"`` - a Python float containing the current scale + * ``"growth_factor"`` - a Python float containing the current growth factor + * ``"backoff_factor"`` - a Python float containing the current backoff factor + * ``"growth_interval"`` - a Python int containing the current growth interval + * ``"_growth_tracker"`` - a Python int containing the number of recent consecutive unskipped steps. + + If this instance is not enabled, returns an empty dict. + + .. note:: + If you wish to checkpoint the scaler's state after a particular iteration, :meth:`state_dict` + should be called after :meth:`update`. + """ + return {"scale": self.get_scale(), + "growth_factor": self._growth_factor, + "backoff_factor": self._backoff_factor, + "growth_interval": self._growth_interval, + "_growth_tracker": self._get_growth_tracker()} if self._enabled else {} + + def load_state_dict(self, state_dict): + r""" + Loads the scaler state. If this instance is disabled, :meth:`load_state_dict` is a no-op. + + Args: + state_dict(dict): scaler state. Should be an object returned from a call to :meth:`state_dict`.
+ """ + if not self._enabled: + return + + if len(state_dict) == 0: + raise RuntimeError("The source state dict is empty, possibly because it was saved " + "from a disabled instance of GradScaler.") + + self._init_scale = state_dict["scale"] + if self._scale is not None: + self._scale.fill_(state_dict["scale"]) + self._growth_factor = state_dict["growth_factor"] + self._backoff_factor = state_dict["backoff_factor"] + self._growth_interval = state_dict["growth_interval"] + self._init_growth_tracker = state_dict["_growth_tracker"] + if self._growth_tracker is not None: + self._growth_tracker.fill_(state_dict["_growth_tracker"]) + + def __getstate__(self): + state = self.__dict__.copy() + if self._enabled: + assert len(self._per_optimizer_states) == 0, "A GradScaler instance may only be pickled at the beginning "\ + "of an iteration, or at the end after scaler.update()." + # Pickling _scale and _growth_tracker Tensors directly triggers + # "warnings.warn("pickle support for Storage will be removed in 1.5..." + # so instead, we set the unpickled instance up to reinitialize them lazily. + state['_init_scale'] = self.get_scale() + state['_init_growth_tracker'] = self._get_growth_tracker() + state['_scale'] = None + state['_growth_tracker'] = None + return state + + def __setstate__(self, state): + self.__dict__.update(state) + + def _check_inf_per_device(self, optimizer): + _scale, _ = self._check_scale_growth_tracker("_check_inf_per_device") + + dummy_inv_scale = torch.full( + (1,), 1.0, dtype=torch.float32, device=_scale.device) + found_inf = torch.full( + (1,), 0.0, dtype=torch.float32, device=_scale.device) + + self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] = \ + self._unscale_grads_(optimizer, dummy_inv_scale, found_inf, True) + + return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] + + def _found_inf_per_device(self, optimizer): + return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] diff --git a/colossalai/engine/schedule/_base_schedule.py b/colossalai/engine/schedule/_base_schedule.py index c64031c09409..0583ccbf3d14 100644 --- a/colossalai/engine/schedule/_base_schedule.py +++ b/colossalai/engine/schedule/_base_schedule.py @@ -5,125 +5,85 @@ import torch +from colossalai.core import global_context as gpc from colossalai.logging import get_global_dist_logger from colossalai.utils import get_current_device class BaseSchedule(ABC): """A basic helper class to control the process of training or evaluation. + It mainly composes of forward_backward_step for gradient backward and + optimizer_step for parameters update. + For the convenience to enable FP16, we aggreate all codes that contain the + control of FP16 in class schedule. """ + def __init__(self): - self.initialized = False self.logger = get_global_dist_logger() - @property - @abstractmethod - def num_steps(self): - """The number of batches in training or evaluation. - """ - pass + @staticmethod + def _move_tensor(element): + if torch.is_tensor(element): + if not element.is_cuda: + return element.to(get_current_device()).detach() + return element - def initialize(self, - dataloader=None, - model=None, - criterion=None, - optimizer=None, - lr_scheduler=None): - """Initializes the schedule and set parameters before running. 
- - :param dataloader: DataLoader in training or evaluation - :param model: The neural network model - :param criterion: Criterion for calculating loss - :param optimizer: Optimizer for updating the parameters - :param lr_scheduler: Learning rate scheduler in the process - """ - self.dataloader = dataloader - assert model is not None, "Schedule requires a model" - self.model = model - assert criterion is not None, "Schedule requires a criterion" - self.criterion = criterion - assert optimizer is not None, "Schedule requires an optimizer" - self.optimizer = optimizer - self.lr_scheduler = lr_scheduler - self.initialized = True - - def check_initialized(self): - """Checks whether the schedule is initialized. - """ - assert self.initialized, \ - 'Schedule is not initialized. Call schedule.initialize(...) before using it.' + def _move_to_device(self, data): + if isinstance(data, (tuple, list)): + data = tuple([self._move_tensor(d) for d in data]) + elif torch.is_tensor(data): + data = data.to(get_current_device()).detach() + return data - def load_batch(self): - """Loads a batch of dataset. It returns the data and labels which are + def load_batch(self, data_iter): + """Loads a batch from the data iterator. It returns the data and labels which are already on the same GPU as the model. :return: (data, label) - :rtype: (Tensor, Tensor) + :rtype: (Tensor, Tensor) """ - self.check_initialized() - if self.data_iter is None: + if data_iter is None: raise RuntimeError('Dataloader is not defined.') - data, label = next(self.data_iter) + data, label = next(data_iter) return self._move_to_device(data), self._move_to_device(label) - def _move_to_device(self, data): - if isinstance(data, ( - tuple, - list, - )): - data = tuple([ - d.to(get_current_device()).detach() for d in data - if torch.is_tensor(d) - ]) - elif torch.is_tensor(data): - data = data.to(get_current_device()).detach() - return data - - def train(self, dataloader=None, mode=True): - """Sets the dataloader to be used and turn the model to - training or evaluation mode. + def initialize(self, model, optimizer): + """Initializes the model and the optimizer before training. + This is often used in FP16 training. - :param dataloader: Dataloader to be used - :param mode: If True, the model will set as training mode. Otherwise, evaluation mode. - """ - self.check_initialized() - if mode: - self.model.train() - else: - self.model.eval() - if dataloader is not None: - self.dataloader = dataloader - self.data_iter = iter(dataloader) - - def zero_grad(self, forward_only=False): - """Cleans gradients with the optimizer. + :param model: The neural network model + :param optimizer: Optimizer for updating the parameters """ - if not forward_only: - self.check_initialized() - self.optimizer.zero_grad() + return model, optimizer - def get_lr(self): - """Returns the current learning rate. - """ - if self.lr_scheduler is not None: - return self.lr_scheduler.get_lr()[0] - else: - return self.optimizer.param_groups[0]['lr'] + @abstractmethod + def forward_backward_step(self, + data_iter, + model, + criterion, + optimizer=None, + forward_only=False, + grad_accum_size: int = 1, + return_loss=True): + """The process function over a batch of the dataset for training or evaluation. - def step(self): - """Updates the parameters and learning rate with the optimizer.
+ :param data_iter: Data iterator of the dataset + :param model: Model used in training or evaluation + :param optimizer: Optimizer used in training or evaluation + :param criterion: Loss function + :param forward_only: If True, the process won't include backward + :param grad_accum_size: Steps of gradient accumulation + :param return_loss: If False, the loss won't be returned """ - self.check_initialized() - self.optimizer.step() - # update lr scheduler - if self.lr_scheduler is not None: - self.lr_scheduler.step() + pass @abstractmethod - def forward_backward_step(self, forward_only=False, return_loss=True): - """The process function over a batch of dataset for training or evaluation. + def optimizer_step(self, model, optimizer, grad_clipping: float = 0.0): + """Updates the parameters with the optimizer. - :param forward_only: If True, the process won't include backward. - :param return_loss: If False, the loss won't be returned. + :param model: The neural network model + :param optimizer: Optimizer for updating the parameters + :param grad_clipping: The norm of gradient clipping + :type grad_clipping: float, optional """ pass diff --git a/colossalai/engine/schedule/_no_pipeline.py b/colossalai/engine/schedule/_no_pipeline.py index 3ab1fa2d3ce4..7f62475c4d39 100644 --- a/colossalai/engine/schedule/_no_pipeline.py +++ b/colossalai/engine/schedule/_no_pipeline.py @@ -10,13 +10,12 @@ except: print('PyTorch amp is not supported with the current PyTorch version') -from colossalai.context import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.engine.amp_type import AMP_TYPE from colossalai.nn import (ZeroRedundancyOptimizer_Level_2, ZeroRedundancyOptimizer_Level_3) -from ._utils import convert_to_fp16 +from colossalai.nn.optimizer._utils import clip_grad_norm_fp32 +from ._utils import convert_to_fp16, convert_to_fp32 from ._base_schedule import BaseSchedule +from ..amp import AMP_TYPE, GradScaler class NoPipelineSchedule(BaseSchedule): @@ -30,6 +29,7 @@ class NoPipelineSchedule(BaseSchedule): :type amp_type: AMP_TYPE :type amp_config: dict """ + def __init__( self, amp_type: AMP_TYPE = None, @@ -41,12 +41,6 @@ def __init__( assert amp_type is None or isinstance(amp_type, AMP_TYPE), \ 'unrecognised value for argument fp16, it can only be None, torch or apex' - # LSG: check compatibility - # LSG: torch.cuda.amp and apex.amp cannot be used for tensor parallel - if gpc.is_initialized(ParallelMode.TENSOR) and gpc.get_world_size( - ParallelMode.TENSOR) > 1: - assert amp_type != AMP_TYPE.TORCH and amp_type != AMP_TYPE.APEX, \ - 'You can only AMP_TYPE.PARALLEL for tensor parallel training' self.use_zero_level_2_3 = False if amp_type is not None: @@ -79,34 +73,29 @@ def __init__( self.fp16 = False self.amp_type = None - @property - def num_steps(self): - return len(self.dataloader) - - def initialize(self, - dataloader, - model, - criterion, - optimizer, - lr_scheduler=None): - super().initialize(dataloader, - model, - criterion, - optimizer, - lr_scheduler=lr_scheduler) - if isinstance(self.optimizer, (ZeroRedundancyOptimizer_Level_2, - ZeroRedundancyOptimizer_Level_3)): + def initialize(self, model, optimizer): + if isinstance(optimizer, (ZeroRedundancyOptimizer_Level_2, + ZeroRedundancyOptimizer_Level_3)): self.use_zero_level_2_3 = True - assert self.amp_type != AMP_TYPE.PARALLEL, 'ZeRO Level 2 and 3 are mutually exclusive with AMP_TYPE.PARALLEL' + assert self.amp_type != AMP_TYPE.PARALLEL, \ + 'ZeRO Level 2 and 3 are mutually exclusive with AMP_TYPE.PARALLEL' if 
self.fp16: if self.amp_type == AMP_TYPE.TORCH: - self._torch_amp_scaler = torch_amp.GradScaler(**self.amp_cfg) + self._torch_amp_scaler = GradScaler(**self.amp_cfg) elif self.amp_type == AMP_TYPE.APEX: - self.model, self.optimizer = apex_amp.initialize( - self.model, self.optimizer, **self.amp_cfg) - - def forward_backward_step(self, forward_only=False, return_loss=True): + model, optimizer = apex_amp.initialize(model, optimizer, **self.amp_cfg) + + return model, optimizer + + def forward_backward_step(self, + data_iter, + model, + criterion, + optimizer=None, + forward_only=False, + grad_accum_size: int = 1, + return_loss=True): """The process function that loads a batch of the dataset and feeds it to the model. The returned labels and loss will be None if :attr:`return_loss` is False. @@ -115,71 +104,65 @@ def forward_backward_step(self, forward_only=False, return_loss=True): assert forward_only or return_loss, \ 'The argument \'return_loss\' has to be True when \'forward_only\' is False, but got False.' - data, label = self.load_batch() + data, label = self.load_batch(data_iter) loss = None - # LSG: leave for debug, make sure dataloader is deterministic - # if forward_only: - # img = data[0] - # rank = gpc.get_local_rank(ParallelMode.DATA) - # world_size = gpc.get_world_size(ParallelMode.DATA) - # group = gpc.get_group(ParallelMode.DATA) - # input_list = [img.clone() for _ in range(world_size)] - # output_list = [torch.empty_like(img) for _ in range(world_size)] - # output_list[rank] = img.clone() - # dist.all_to_all(output_tensor_list=output_list, input_tensor_list=input_list, group=group) - # assert torch.equal(output_list[0], output_list[1]) # and torch.equal(output_list[1], output_list[2]) - # forward if self.fp16 and self.amp_type == AMP_TYPE.TORCH: with torch_amp.autocast(): - output = self.model(*data) + output = model(*data) if not isinstance(output, (tuple, list)): output = (output,) if return_loss: - loss = self.criterion(*output, *label) + loss = criterion(*output, *label) else: if self.use_zero_level_2_3 or self.amp_type == AMP_TYPE.PARALLEL: data = convert_to_fp16(data) - output = self.model(*data) + output = model(*data) + + if self.use_zero_level_2_3 or self.amp_type == AMP_TYPE.PARALLEL: + output = convert_to_fp32(output) + if not isinstance(output, (tuple, list)): output = (output,) if return_loss: - loss = self.criterion(*output, *label) + loss = criterion(*output, *label) + + loss /= grad_accum_size if not forward_only: # backward if self.use_zero_level_2_3: - self.optimizer.backward(loss) + optimizer.backward(loss) elif self.fp16: if self.amp_type == AMP_TYPE.APEX: - with apex_amp.scale_loss(loss, - self.optimizer) as scaled_loss: + with apex_amp.scale_loss(loss, optimizer) as scaled_loss: scaled_loss.backward() elif self.amp_type == AMP_TYPE.TORCH: self._torch_amp_scaler.scale(loss).backward() elif self.amp_type == AMP_TYPE.PARALLEL: - loss = self.optimizer.scale_loss(loss) + loss = optimizer.scale_loss(loss) loss.backward() # scale back to display the original value in logs - loss.div_(self.optimizer.grad_scaler.scale) + loss.div_(optimizer.grad_scaler.scale) else: loss.backward() if return_loss: - return output, label, loss + return output, label, loss * grad_accum_size else: return output, None, None - def step(self): + def optimizer_step(self, model, optimizer, grad_clipping: float = 0.0): # step optimizer if self.fp16 and self.amp_type == AMP_TYPE.TORCH: - self._torch_amp_scaler.step(self.optimizer) + if grad_clipping > 0.0: + 
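# unscale first so clipping operates on the true (unscaled) gradient norms + 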
self._torch_amp_scaler.unscale_(optimizer) + clip_grad_norm_fp32(model.parameters(), grad_clipping) + self._torch_amp_scaler.step(optimizer) self._torch_amp_scaler.update() else: - self.optimizer.step() - - # update lr scheduler - if self.lr_scheduler is not None: - self.lr_scheduler.step() + if not self.fp16 and not self.use_zero_level_2_3 and grad_clipping > 0.0: + clip_grad_norm_fp32(model.parameters(), grad_clipping) + optimizer.step() diff --git a/colossalai/engine/schedule/_pipeline.py b/colossalai/engine/schedule/_pipeline.py index 0b477c0d5361..6defea93d57a 100644 --- a/colossalai/engine/schedule/_pipeline.py +++ b/colossalai/engine/schedule/_pipeline.py @@ -15,7 +15,7 @@ from colossalai.utils import get_current_device from ._base_schedule import BaseSchedule from ._utils import convert_to_fp16 -from ..amp_type import AMP_TYPE +from ..amp import AMP_TYPE def squeeze(x: Union[Tensor, tuple, list]): @@ -93,12 +93,11 @@ def _sync_data(self): ) # Pipeline schedule just puts data in memory - def load_batch(self): - self.check_initialized() - if self.data_iter is None: + def load_batch(self, data_iter): + if data_iter is None: raise RuntimeError('Dataloader is not defined.') self.batch_pos = 0 - data, label = next(self.data_iter) + data, label = next(data_iter) self.batch_data, self.batch_label = \ self._move_to_device(data), self._move_to_device(label) batch_size = self.batch_data.shape[0] @@ -117,23 +116,8 @@ def load_micro_batch(self): self.batch_pos += self.microbatch_size return (data,), (label,) - @property - def num_steps(self): - return len(self.dataloader) - - def initialize(self, - dataloader, - model, - criterion, - optimizer, - lr_scheduler=None): - super().initialize(dataloader, - model, - criterion, - optimizer, - lr_scheduler=lr_scheduler) - if isinstance(self.optimizer, (ZeroRedundancyOptimizer_Level_2, - ZeroRedundancyOptimizer_Level_3)): + def initialize(self, model, optimizer): + if isinstance(optimizer, (ZeroRedundancyOptimizer_Level_2, ZeroRedundancyOptimizer_Level_3)): raise TypeError( "Pipeline schedule is currently not compatible with ZeRO Level 2 and Level 3" ) @@ -145,7 +129,8 @@ def initialize(self, 'default tensor dtype is set to torch.half for fp16 training', ranks=[0]) - def forward_step(self, input_tensor, return_tensors, return_loss=True): + def forward_step(self, model, criterion, input_tensor, return_tensors, + grad_accum_size, return_loss=True): """Forward step for passed-in model. If it is the first stage, the input tensor is obtained from data_iterator, otherwise the passed-in input_tensor is used. Returns output tensor. This is a helper function and can be ignored by users. 
@@ -156,14 +141,14 @@ def forward_step(self, input_tensor, return_tensors, return_loss=True): if self.amp_type == AMP_TYPE.PARALLEL: input_tensor = convert_to_fp16(input_tensor) input_tensor = squeeze(input_tensor) - output_tensor = self.model(input_tensor) + output_tensor = model(input_tensor) output_tensor = squeeze(output_tensor) if gpc.is_last_rank(ParallelMode.PIPELINE): if return_loss: input_tensor, label = self.load_micro_batch() - loss_reduced = self.criterion(output_tensor, * - label) / self.num_microbatches + loss_reduced = criterion(output_tensor, *label) \ + / (self.num_microbatches * grad_accum_size) return_tensors.append( tuple((output_tensor, label[0], loss_reduced))) return loss_reduced @@ -174,7 +159,7 @@ def forward_step(self, input_tensor, return_tensors, return_loss=True): else: return output_tensor - def backward_step(self, input_tensor, output_tensor, output_tensor_grad): + def backward_step(self, optimizer, input_tensor, output_tensor, output_tensor_grad): """Backward step through the passed-in output tensor. If it is the last stage, the output_tensor_grad is None, otherwise it is the gradients with respect to stage's output tensor. Returns the gradients with respect to the input tensor (None if first stage). @@ -187,7 +172,7 @@ def backward_step(self, input_tensor, output_tensor, output_tensor_grad): # Backward pass. if output_tensor_grad is None and self.amp_type == AMP_TYPE.PARALLEL: - output_tensor = self.optimizer.scale_loss(output_tensor) + output_tensor = optimizer.scale_loss(output_tensor) torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad) # Collect the grad of the input_tensor. @@ -197,17 +182,24 @@ def backward_step(self, input_tensor, output_tensor, output_tensor_grad): return input_tensor_grad - def forward_backward_step(self, forward_only=True, return_loss=True): + def forward_backward_step(self, + data_iter, + model, + criterion, + optimizer=None, + forward_only=False, + grad_accum_size: int = 1, + return_loss=True): """Runs non-interleaved 1F1B schedule, with communication between pipeline stages. Returns a tuple with losses if the last stage, an empty tuple otherwise. - + :return: (output, label, loss) """ assert forward_only or return_loss, \ 'The argument \'return_loss\' has to be True when \'forward_only\' is False, but got False.' 
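+ # pull one full batch from the iterator; load_micro_batch() then slices it into micro-batches for the 1F1B schedule below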
- self.load_batch() + self.load_batch(data_iter) num_warmup_microbatches = \ (gpc.get_world_size(ParallelMode.PIPELINE) - gpc.get_local_rank(ParallelMode.PIPELINE) - 1) @@ -233,9 +225,11 @@ def forward_backward_step(self, forward_only=True, return_loss=True): if not gpc.is_first_rank(ParallelMode.PIPELINE): ft_shape = recv_tensor_meta(ft_shape) input_tensor = recv_forward(ft_shape) - output_tensor = self.forward_step(input_tensor, - return_tensors, - return_loss=return_loss) + output_tensor = self.forward_step( + model, criterion, + input_tensor, return_tensors, + grad_accum_size, return_loss=return_loss + ) if not gpc.is_last_rank(ParallelMode.PIPELINE): bt_shape = output_tensor.shape fs_checker = send_tensor_meta(output_tensor, fs_checker) @@ -257,9 +251,11 @@ def forward_backward_step(self, forward_only=True, return_loss=True): for i in range(num_microbatches_remaining): last_iteration = (i == (num_microbatches_remaining - 1)) - output_tensor = self.forward_step(input_tensor, - return_tensors, - return_loss=return_loss) + output_tensor = self.forward_step( + model, criterion, + input_tensor, return_tensors, + grad_accum_size, return_loss=return_loss + ) if forward_only: send_forward(output_tensor) @@ -279,9 +275,11 @@ def forward_backward_step(self, forward_only=True, return_loss=True): input_tensor = input_tensors.pop(0) output_tensor = output_tensors.pop(0) - input_tensor_grad = self.backward_step(input_tensor, - output_tensor, - output_tensor_grad) + input_tensor_grad = self.backward_step( + optimizer, + input_tensor, output_tensor, + output_tensor_grad + ) if last_iteration: input_tensor = None @@ -298,9 +296,11 @@ def forward_backward_step(self, forward_only=True, return_loss=True): output_tensor_grad = recv_backward(bt_shape) - input_tensor_grad = self.backward_step(input_tensor, - output_tensor, - output_tensor_grad) + input_tensor_grad = self.backward_step( + optimizer, + input_tensor, output_tensor, + output_tensor_grad + ) send_backward(input_tensor_grad) @@ -309,8 +309,11 @@ def forward_backward_step(self, forward_only=True, return_loss=True): output, label, loss = tuple(map(list, zip(*return_tensors))) return (torch.cat(output, dim=0), torch.cat(label, dim=0), - sum(loss)) + sum(loss) * grad_accum_size) else: return tuple((torch.cat(return_tensors, dim=0), None, None)) else: return tuple((None, None, None)) + + def optimizer_step(self, model, optimizer, grad_clipping: float = 0.0): + optimizer.step() diff --git a/colossalai/engine/schedule/_utils.py b/colossalai/engine/schedule/_utils.py index 9c4a2a19b912..cdfd0246c12d 100644 --- a/colossalai/engine/schedule/_utils.py +++ b/colossalai/engine/schedule/_utils.py @@ -14,3 +14,14 @@ def convert_to_fp16(data: Union[Tensor, List[Tensor]]): else: raise TypeError(f"Expected argument 'data' to be a Tensor or a list/tuple of Tensor, but got {type(data)}") return ret + + +def convert_to_fp32(data: Union[Tensor, List[Tensor]]): + if isinstance(data, Tensor): + ret = data.float() + elif isinstance(data, (list, tuple)): + ret = [val.float() for val in data] + else: + raise TypeError(f"Expected argument 'data' to be a Tensor or a list/tuple of Tensor, but got {type(data)}") + return ret +
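---

Usage sketch (not part of the patch): after this change the Engine owns gradient accumulation and clipping, and step() takes the data iterator, model, criterion and optimizer explicitly. The names model, optimizer, criterion, train_dataloader and num_steps below stand in for whatever the training script already builds; step_schedule falls back to NoPipelineSchedule() when not given.

    from colossalai.engine import Engine

    engine = Engine(model=model,
                    optimizer=optimizer,
                    step_schedule=None,        # defaults to NoPipelineSchedule()
                    gradient_accumulation=4,   # optimizer stepped once every 4 calls
                    gradient_clipping=1.0)     # max grad norm; 0.0 disables clipping

    engine.train()
    data_iter = iter(train_dataloader)
    for i in range(num_steps):
        # zero_grad runs at the start of each accumulation cycle; the gradient
        # handlers and schedule.optimizer_step run once the cycle completes
        output, label, loss = engine.step(data_iter,
                                          model,
                                          criterion,
                                          optimizer,
                                          is_last_iteration=(i == num_steps - 1))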