Describe the feature
This is a part of #3839 .
To support more low precision optimizers, we should introduce a new class that controls the shared behavior of all kinds of low precision optimizers.
An fp16 mixed precision optimizer has special behaviors during `backward`, `step`, `clip_grad_norm` and `zero_grad`.
A bf16 mixed precision optimizer has no special behavior besides managing master weights.
To keep this mixin simple:
- Different parallelisms may manage weights in different ways, so this mixin does not manage master weights.
- This mixin can be used across different parallelisms. To keep the class hierarchy simple, we don't write mixed precision optimizers as base classes directly; instead, each mixed precision optimizer composes the appropriate mixin. E.g. Gemini can support both fp16 and bf16: if we wrote a base fp16 optimizer and a base bf16 optimizer directly, the Gemini optimizer would have to inherit from both of them, which is very difficult.
Pseudo-code:
````python
from abc import ABC, abstractmethod

import torch
from torch import Tensor


class MixedPrecisionMixin(ABC):
    """A helper class for mixed precision training. This mixin is used in mixed precision optimizers.

    Attributes:
        dtype (torch.dtype): The expected dtype of the gradients.

    Examples:
        ```python
        class MyMixedPrecisionOptimizer(OptimizerWrapper):
            def __init__(self, optim: Optimizer):
                super().__init__(optim)
                self.mixed_precision = MixedPrecisionMixin()

            def backward(self, loss):
                loss = self.mixed_precision.pre_backward(loss)
                loss.backward()

            def backward_by_grad(self, tensor, grad):
                grad = self.mixed_precision.pre_backward_by_grad(tensor, grad)
                tensor.backward(grad)

            def step(self):
                if self.mixed_precision.should_skip_step():
                    self.zero_grad()
                    return
                div_scale = self.mixed_precision.get_grad_div_scale()
                # maybe clip grad here
                # maybe scale grad here
                self.optim.step()

            def zero_grad(self):
                self.mixed_precision.pre_zero_grad()
                return self.optim.zero_grad()
        ```
    """
    dtype: torch.dtype

    @abstractmethod
    def pre_backward(self, loss: Tensor) -> Tensor:
        """Called before backward.

        Args:
            loss (Tensor): Loss value.

        Returns:
            Tensor: Loss value (possibly scaled).
        """

    @abstractmethod
    def pre_backward_by_grad(self, tensor: Tensor, grad: Tensor) -> Tensor:
        """Called before backward by grad. This is helpful for pipeline parallelism.

        Args:
            tensor (Tensor): Tensor to backward.
            grad (Tensor): Gradient of the tensor.

        Returns:
            Tensor: Gradient of the tensor (possibly scaled).
        """

    @abstractmethod
    def should_skip_step(self) -> bool:
        """Called before step.

        Returns:
            bool: Whether to skip the step (e.g. when fp16 gradients overflow).
        """

    @abstractmethod
    def pre_zero_grad(self) -> None:
        """Called before zero_grad."""

    @abstractmethod
    def get_grad_div_scale(self) -> float:
        """Called before step or clip_grad. To keep computation efficient, this method does not unscale grads itself.

        Returns:
            float: A divisor for gradient clipping or step.
        """
````
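To make the contract concrete, here is a minimal, torch-free sketch of the dynamic loss-scale bookkeeping an fp16 mixin might perform, next to a trivial bf16 counterpart. The names (`Fp16MixinSketch`, `Bf16MixinSketch`) and the `found_inf` flag are illustrative assumptions, not the actual implementation; a real fp16 mixin would scan the gradient tensors for inf/nan itself.

```python
class Fp16MixinSketch:
    """Hypothetical sketch: dynamic loss scaling as an fp16 mixin might do it.

    Plain floats stand in for tensors so the logic is easy to follow.
    """

    def __init__(self, initial_scale: float = 2.0 ** 16, growth_factor: float = 2.0,
                 backoff_factor: float = 0.5, growth_interval: int = 1000):
        self.scale = initial_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self._successful_steps = 0

    def pre_backward(self, loss: float) -> float:
        # Scale the loss up so small fp16 gradients don't underflow.
        return loss * self.scale

    def should_skip_step(self, found_inf: bool) -> bool:
        # On overflow: shrink the scale and skip this step entirely.
        if found_inf:
            self.scale *= self.backoff_factor
            self._successful_steps = 0
            return True
        # After enough clean steps, grow the scale back up.
        self._successful_steps += 1
        if self._successful_steps % self.growth_interval == 0:
            self.scale *= self.growth_factor
        return False

    def get_grad_div_scale(self) -> float:
        # Gradients were computed from the scaled loss; the optimizer divides
        # by this during clipping/step instead of eagerly unscaling every grad.
        return self.scale


class Bf16MixinSketch:
    """Hypothetical sketch: bf16 needs no loss scaling, so every hook is a no-op."""

    def pre_backward(self, loss: float) -> float:
        return loss

    def should_skip_step(self, found_inf: bool = False) -> bool:
        return False

    def get_grad_div_scale(self) -> float:
        return 1.0
```

This contrast is the point of the mixin design: `step()` in the docstring example stays identical for both precisions, and only the composed mixin decides whether gradients must be divided by a scale before the inner optimizer runs.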