
[FEATURE]: mixed precision mixin #3863

@ver217

Description

Describe the feature

This is part of #3839.

To support more low-precision optimizers, we should introduce a new class that controls the precision-specific behavior shared by all of them.

An fp16 mixed precision optimizer has special behavior during backward, step, clip_grad_norm, and zero_grad.

A bf16 mixed precision optimizer has no special behavior besides managing master weights.

To keep this mixin simple:

  1. Different parallelisms may manage weights in different ways, so this mixin does not manage master weights.
  2. This mixin can be used with different parallelisms. To keep the class relationships simple, we don't write mixed precision optimizer base classes directly. Instead, each mixed precision optimizer composes the mixin it needs. E.g. Gemini can support both fp16 and bf16; if we wrote base fp16 and bf16 optimizer classes directly, the Gemini optimizer would have to inherit from both of them, which is very hard to maintain. A hypothetical selection sketch is shown after this list.
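
For illustration of point 2, a Gemini-style optimizer could pick the mixin by dtype at construction time. This is only a sketch: `GeminiLikeOptimizer`, `NaiveFP16MixedPrecisionMixin`, and `BF16MixedPrecisionMixin` are assumed names (the two mixins are sketched after the pseudo-code below), not the actual API.

```python
import torch
from torch.optim import Optimizer


class GeminiLikeOptimizer:
    """Hypothetical parallelism-specific optimizer that composes a mixin
    instead of inheriting from separate fp16/bf16 base optimizer classes."""

    def __init__(self, optim: Optimizer, precision: torch.dtype = torch.float16):
        self.optim = optim
        working_params = [p for group in optim.param_groups for p in group["params"]]
        if precision == torch.float16:
            # fp16 needs loss scaling and overflow checks
            self.mixed_precision = NaiveFP16MixedPrecisionMixin(working_params)
        elif precision == torch.bfloat16:
            # bf16 needs no special handling beyond master weights
            self.mixed_precision = BF16MixedPrecisionMixin()
        else:
            raise ValueError(f"Unsupported mixed precision dtype: {precision}")
```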

Pseudo-code:

from abc import ABC, abstractmethod

import torch
from torch import Tensor


class MixedPrecisionMixin(ABC):
    """A helper class for mixed precision training. This mixin is used in mixed precision optimizers.

    Attributes:
        dtype (torch.dtype): The expected dtype of the gradients.

    Examples:
        ```python
        class MyMixedPrecisionOptimizer(OptimizerWrapper):
            def __init__(self, optim: Optimizer):
                super().__init__(optim)
                self.mixed_precision = MixedPrecisionMixin()

            def backward(self, loss):
                loss = self.mixed_precision.pre_backward(loss)
                loss.backward()

            def backward_by_grad(self, tensor, grad):
                grad = self.mixed_precision.pre_backward_by_grad(tensor, grad)
                tensor.backward(grad)

            def step(self):
                if self.mixed_precision.should_skip_step():
                    self.zero_grad()
                    return
                div_scale = self.mixed_precision.get_grad_div_scale()
                # maybe clip grad here
                # maybe scale grad here
                self.optim.step()

            def zero_grad(self):
                self.mixed_precision.pre_zero_grad()
                return self.optim.zero_grad()
        ```
    """
    dtype: torch.dtype

    @abstractmethod
    def pre_backward(self, loss: Tensor) -> Tensor:
        """Called before backward.

        Args:
            loss (Tensor): Loss value.

        Returns:
            Tensor: Loss value (possibly scaled).
        """
        pass

    @abstractmethod
    def pre_backward_by_grad(self, tensor: Tensor, grad: Tensor) -> Tensor:
        """Called before backward by grad. This is helpful for pipeline parallelism.

        Args:
            tensor (Tensor): Tensor to backward.
            grad (Tensor): Gradient of the tensor.

        Returns:
            Tensor: Gradient of the tensor (possibly scaled).
        """
        pass

    @abstractmethod
    def should_skip_step(self) -> bool:
        """Called before step.

        Returns:
            bool: Whether to skip the step.
        """
        pass

    @abstractmethod
    def pre_zero_grad(self) -> None:
        """Called before zero_grad.
        """
        pass

    @abstractmethod
    def get_grad_div_scale(self) -> float:
        """Called before step or clip_grad. To keep computation efficiency, this method does not (maybe) unscale grads.

        Returns:
            float: A divisor for gradient clipping or step.
        """
        pass
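
To make the fp16/bf16 contrast above concrete, here is a minimal sketch of what the two mixins could look like, assuming a fixed loss scale for fp16. The class names, constructor arguments, and overflow check are assumptions for illustration; a real fp16 mixin would typically use a dynamic loss scaler.

```python
import torch
from torch import Tensor


class NaiveFP16MixedPrecisionMixin(MixedPrecisionMixin):
    """Hypothetical fp16 mixin with a fixed loss scale: it scales the loss before
    backward and skips the step when any working gradient overflows."""
    dtype = torch.float16

    def __init__(self, working_params, loss_scale: float = 2.0 ** 16):
        self.working_params = list(working_params)
        self.loss_scale = loss_scale

    def pre_backward(self, loss: Tensor) -> Tensor:
        # Scale the loss so that small fp16 gradients do not underflow.
        return loss * self.loss_scale

    def pre_backward_by_grad(self, tensor: Tensor, grad: Tensor) -> Tensor:
        # Scale the incoming gradient for pipeline-parallel backward.
        return grad * self.loss_scale

    def should_skip_step(self) -> bool:
        # Skip the step if any gradient contains inf/nan after backward.
        return any(p.grad is not None and not torch.isfinite(p.grad).all()
                   for p in self.working_params)

    def pre_zero_grad(self) -> None:
        # Nothing to reset for a fixed loss scale; a dynamic scaler would clear
        # its overflow flag here.
        pass

    def get_grad_div_scale(self) -> float:
        # Gradients are still multiplied by loss_scale; the optimizer divides by this.
        return self.loss_scale


class BF16MixedPrecisionMixin(MixedPrecisionMixin):
    """Hypothetical bf16 mixin: no loss scaling is needed, so every hook is a no-op."""
    dtype = torch.bfloat16

    def pre_backward(self, loss: Tensor) -> Tensor:
        return loss

    def pre_backward_by_grad(self, tensor: Tensor, grad: Tensor) -> Tensor:
        return grad

    def should_skip_step(self) -> bool:
        return False

    def pre_zero_grad(self) -> None:
        pass

    def get_grad_div_scale(self) -> float:
        return 1.0
```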

Metadata

Labels

enhancement (New feature or request)
