Merged
372 changes: 341 additions & 31 deletions colossalai/booster/plugin/hybrid_parallel_plugin.py

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion colossalai/inference/tensor_parallel/engine.py
@@ -213,7 +213,7 @@ def _shard_model_by(self, shardformer: ShardFormer, model: nn.Module) -> None:
), "Discrepancy between the tp size of TPInferEngine and the tp size of shard config"
model_name = model.__class__.__name__
assert model_name in self.supported_models, f"Unsupported model cls {model_name} for TP inference."
policy = get_autopolicy(model, inference_only=True)
policy = get_autopolicy(model, shard_config=self.shard_config)
self.model, _ = shardformer.optimize(model, policy)

if self.shard_config.inference_gptq:
23 changes: 18 additions & 5 deletions colossalai/shardformer/README.md
@@ -235,15 +235,28 @@ class SubModuleReplacementDescription:


class Policy(ABC):
r"""
The base class for all the policies. For each different model, it should have a different policy class,
like BertPolicy for Bert Model or OPTPolicy for OPT model.

def __init__(self)
self.model = None
Shardformer provides many built-in sharding policies for mainstream models. You can use the
built-in policies by setting `policy = None`, which is already the default argument for `Shardformer.optimize`.
If you want to define your own policy, you can inherit from this class and overwrite the methods you want to modify.
"""

def set_model(self, model: nn.Module) -> None:
def __init__(self, model: Optional[Module] = None, shard_config: Optional[ShardConfig] = None) -> None:
"""
Set model as an attribute of the Policy object so that we can access the model's attributes.
Initialize a Policy object.

This method sets the model and shard configuration for the policy and performs a configuration sanity check.

Args:
model (Optional[Module]): The model to be used with this policy.
shard_config (Optional[ShardConfig]): The sharding configuration for the policy.
"""
self.model = model
self.model: Optional[Module] = model
self.shard_config: Optional[ShardConfig] = shard_config
self.config_sanity_check()

@abstractmethod
def preprocess(self) -> nn.Module:
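With the new constructor, a policy receives both the model and the `ShardConfig` when it is built (see `get_autopolicy` further down). Below is a minimal sketch of a custom policy against the new base class; `MyModelPolicy` is a hypothetical name, only a few representative methods are shown, and a real policy must implement every abstract method of `Policy` (consult the built-in policies for the full interface).

```python
from colossalai.shardformer.policies.base_policy import Policy


class MyModelPolicy(Policy):
    def config_sanity_check(self):
        # Called by Policy.__init__ after self.model / self.shard_config are stored.
        pass

    def preprocess(self):
        # Adjust the model before sharding (e.g. resize embeddings); a no-op here.
        return self.model

    def module_policy(self):
        # Describe which submodules to replace; an empty mapping shards nothing.
        return {}

    def postprocess(self):
        return self.model
```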
5 changes: 4 additions & 1 deletion colossalai/shardformer/layer/__init__.py
@@ -2,7 +2,7 @@
from .embedding import Embedding1D, VocabParallelEmbedding1D
from .linear import Linear1D_Col, Linear1D_Row
from .loss import cross_entropy_1d
from .normalization import FusedLayerNorm, FusedRMSNorm
from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm
from .parallel_module import ParallelModule
from .qkv_fused_linear import FusedLinear1D_Col, GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row

@@ -16,6 +16,9 @@
"DropoutForParallelInput",
"DropoutForReplicatedInput",
"cross_entropy_1d",
"BaseLayerNorm",
"LayerNorm",
"RMSNorm",
"FusedLayerNorm",
"FusedRMSNorm",
"FusedLinear1D_Col",
154 changes: 146 additions & 8 deletions colossalai/shardformer/layer/normalization.py
@@ -1,11 +1,14 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from abc import ABC, abstractmethod

import torch.nn as nn

from colossalai.lazy import LazyInitContext

__all__ = ["FusedLayerNorm", "FusedRMSNorm"]
from .utils import SeqParallelUtils

__all__ = ["FusedLayerNorm", "FusedRMSNorm", "LayerNorm", "RMSNorm", "BaseLayerNorm"]

FAST_LAYERNORM_SUPPORTED_SIZE = [
1024,
@@ -35,23 +38,133 @@
]


class FusedLayerNorm:
class BaseLayerNorm(ABC):
@abstractmethod
def from_native_module(module: nn.Module, sp_partial_derived: bool = False):
"""
Convert a native PyTorch layer normalization module to a specific layer normalization module,
and optionally mark parameters for gradient aggregation.

Args:
module (nn.Module): The native PyTorch layer normalization module to be converted.
sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism.

Returns:
nn.Module: The specific layer normalization module.

Raises:
AssertionError: If the provided module is not an instance of the supported layer normalization type.
"""


class RMSNorm(BaseLayerNorm):
r"""
This is a wrapper around the RMSNorm. It is meant to be used only with the from_native_module interface.
"""

def __init__(self) -> None:
raise NotImplementedError(
"FusedLayerNorm is not implemented as a physical class. "
"It is meant to be used only with the from_native_module interface to convert a native RMSNorm module to colossalai layer norm module."
)

@staticmethod
def from_native_module(module: nn.Module, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module:
"""
Convert a native RMSNorm module to colossalai layer norm module,
and optionally mark parameters for gradient aggregation.

Args:
module (nn.Module): The native RMSNorm module to be converted.
sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism.

Returns:
nn.Module: The RMSNorm module.
"""

LazyInitContext.materialize(module)

if sp_partial_derived:
# Since gradients are computed using only a subset of the data,
# aggregation of these gradients is necessary during backpropagation.
# Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation.
SeqParallelUtils.marked_as_sp_partial_derived_param(module.weight)

return module


class LayerNorm(BaseLayerNorm):
r"""
This is a wrapper around the torch.nn.LayerNorm. It is meant to be used only with the from_native_module interface.
"""

def __init__(self) -> None:
raise NotImplementedError(
"LayerNorm is not implemented as a physical class. "
"It is meant to be used only with the from_native_module interface to convert a native pytorch layer norm module to colossalai layer norm module."
)

@staticmethod
def from_native_module(module: nn.LayerNorm, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module:
r"""
Convert a native PyTorch layer norm module to a colossalai layer norm module,
and optionally mark parameters for gradient aggregation.

Args:
module (nn.LayerNorm): The native PyTorch LayerNorm module to be converted.
sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism.

Returns:
nn.Module: The LayerNorm module.

Raises:
AssertionError: If the provided module is not an instance of nn.LayerNorm.
"""
assert isinstance(module, nn.LayerNorm), "Only support conversion from nn.LayerNorm."

LazyInitContext.materialize(module)

if sp_partial_derived:
# Since gradients are computed using only a subset of the data,
# aggregation of these gradients is necessary during backpropagation.
# Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation.
SeqParallelUtils.marked_as_sp_partial_derived_param(module.weight)
SeqParallelUtils.marked_as_sp_partial_derived_param(module.bias)

return module
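Illustrative usage of the new plain (non-fused) wrapper; a sketch assuming an eagerly initialized `nn.LayerNorm`, run outside any distributed context:

```python
import torch.nn as nn

from colossalai.shardformer.layer import LayerNorm

# Wrapping keeps the native module but, with sp_partial_derived=True, tags its
# weight and bias so sequence parallelism later all-reduces their partial gradients.
norm = LayerNorm.from_native_module(nn.LayerNorm(1024), sp_partial_derived=True)
assert getattr(norm.weight, "partial_derived", False)
assert getattr(norm.bias, "partial_derived", False)
```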


class FusedLayerNorm(BaseLayerNorm):
r"""
This is a wrapper around the apex fused layernorm implementation. It is meant to be used only with the from_native_module interface.
"""

def __init__(self) -> None:
raise NotImplementedError(
"FusedLayerNorm is not implemented as a physical class. "
"It is meant to be used only with the from_native_module interface to wrap the fused layernorm implementation provided by apex."
"It is meant to be used only with the from_native_module interface convert a native pytorch layer norm module to FusedLayerNorm module provided by apex."
)

@staticmethod
def from_native_module(module: nn.LayerNorm, *args, **kwargs) -> nn.Module:
def from_native_module(module: nn.LayerNorm, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module:
r"""
Convert a native pytorch layer norm module to colossalai layer norm module
Convert a native PyTorch layer norm module to the FusedLayerNorm module provided by apex,
and optionally mark parameters for gradient aggregation.

Args:
module (nn.LayerNorm): The native PyTorch LayerNorm module to be converted.
sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism.

Returns:
nn.Module: The converted module, either an apex FastLayerNorm or FusedLayerNorm.

Raises:
AssertionError: If the provided module is not an instance of nn.LayerNorm.
"""
# check if apex is installed

assert isinstance(module, nn.LayerNorm), "Only support conversion from nn.LayerNorm."

try:
pass
except ImportError:
@@ -85,22 +198,41 @@ def from_native_module(module: nn.LayerNorm, *args, **kwargs) -> nn.Module:

layernorm.weight = module.weight
layernorm.bias = module.bias

if sp_partial_derived:
# Since gradients are computed using only a subset of the data,
# aggregation of these gradients is necessary during backpropagation.
# Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation.
SeqParallelUtils.marked_as_sp_partial_derived_param(layernorm.weight)
SeqParallelUtils.marked_as_sp_partial_derived_param(layernorm.bias)

return layernorm


class FusedRMSNorm:
class FusedRMSNorm(BaseLayerNorm):
"""
This is a wrapper around the apex fused rms norm implementation. It is meant to be used only with the from_native_module interface.
"""

def __init__(self) -> None:
raise NotImplementedError(
"FusedRMSNorm is not implemented as a physical class. "
"It is meant to be used only with the from_native_module interface to wrap the fused rms norm implementation provided by apex."
"It is meant to be used only with the from_native_module interface to Convert a native RMSNorm module to FusedRMSNorm module provided by apex."
)

@staticmethod
def from_native_module(module: nn.Module, *args, **kwargs) -> nn.Module:
def from_native_module(module: nn.Module, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module:
r"""
Convert a native RMSNorm module to the FusedRMSNorm module provided by apex,
and optionally mark parameters for gradient aggregation.

Args:
module (nn.Module): The native RMSNorm module to be converted.
sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism.

Returns:
nn.Module: FusedRMSNorm module.
"""
try:
from apex.normalization import FusedRMSNorm as ApexFusedRMSNorm
except ImportError:
@@ -124,4 +256,10 @@ def from_native_module(module: nn.Module, *args, **kwargs) -> nn.Module:

rmsnorm.weight = module.weight

if sp_partial_derived:
# Since gradients are computed using only a subset of the data,
# aggregation of these gradients is necessary during backpropagation.
# Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation.
SeqParallelUtils.marked_as_sp_partial_derived_param(rmsnorm.weight)

return rmsnorm
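The fused wrappers require apex, so a policy typically falls back to the plain wrappers when fused normalization is disabled. A hedged sketch of that selection; the `use_fused` flag stands in for whatever the shard config exposes (e.g. `enable_fused_normalization`, which is an assumption here):

```python
import torch.nn as nn

from colossalai.shardformer.layer import FusedLayerNorm, LayerNorm


def convert_norm(module: nn.LayerNorm, use_fused: bool, sp_partial_derived: bool) -> nn.Module:
    # Use the apex-backed wrapper only when requested (and apex is installed);
    # both paths honor the sp_partial_derived flag added in this PR.
    norm_cls = FusedLayerNorm if use_fused else LayerNorm
    return norm_cls.from_native_module(module, sp_partial_derived=sp_partial_derived)


ln = convert_norm(nn.LayerNorm(768), use_fused=False, sp_partial_derived=True)
```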
76 changes: 75 additions & 1 deletion colossalai/shardformer/layer/utils.py
@@ -1,8 +1,82 @@
from contextlib import contextmanager
from typing import List

import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from torch import nn
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from torch.distributed import ProcessGroup, get_world_size


class SeqParallelUtils:
@staticmethod
def marked_as_sp_partial_derived_param(param):
"""
Mark a parameter as partially derived in sequence parallelism.

Args:
param: The parameter to mark as partially derived.
"""
setattr(param, "partial_derived", True)

@staticmethod
def is_sp_partial_derived_param(param):
"""
Check if a parameter is marked as partially derived in sequence parallelism.

Args:
param: The parameter to check.

Returns:
bool: True if the parameter is marked as partially derived, False otherwise.
"""
return getattr(param, "partial_derived", False)

@staticmethod
def allreduce_partial_data_grad(tp_group: ProcessGroup, model: nn.Module = None, grads: List[torch.Tensor] = None):
"""
Allreduce partial derived gradients across the specified process group.

This function performs gradient synchronization for parameters that are marked as partially derived in sequence parallelism.

Args:
tp_group (ProcessGroup): The process group for gradient synchronization.
model (nn.Module): The model from which gradients will be synchronized.
grads (List[torch.Tensor]): The list of gradients to be synchronized.

Raises:
AssertionError: If both `model` and `grads` are provided or neither is provided.
"""
# Ensure that exactly one of `model` and `grads` is provided for gradient synchronization.
assert (model is not None) ^ (grads is not None), "Exactly one of model and grads must be provided."

# Get the size of the process group, which determines whether synchronization is needed.
tp_size = get_world_size(tp_group) if tp_group is not None else 1

if tp_size == 1:
# If the process group size is 1, no synchronization is required.
return

if model is not None:
# If `model` is provided, extract partial derived gradients from the model's parameters.
grads = []
for p in model.parameters():
if p.grad is not None and SeqParallelUtils.is_sp_partial_derived_param(p):
grads.append(p.grad.data)

# Flatten and reduce the gradients using the specified process group.
coalesced = _flatten_dense_tensors(grads)
dist.all_reduce(coalesced, op=dist.ReduceOp.SUM, group=tp_group)

# Unflatten the synchronized gradients and update the model's gradients.
for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
buf.copy_(synced)
else:
# If `grads` are provided explicitly, synchronize those gradients directly.
coalesced = _flatten_dense_tensors(grads)
dist.all_reduce(coalesced, op=dist.ReduceOp.SUM, group=tp_group)
for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
buf.copy_(synced)
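A hedged, single-process sketch of the mark-then-sync flow (with `tp_group=None` the world size is treated as 1 and the call returns early; in a real run `tp_group` is the tensor-parallel `ProcessGroup`):

```python
import torch
import torch.nn as nn

from colossalai.shardformer.layer.utils import SeqParallelUtils

layer = nn.LayerNorm(16)
SeqParallelUtils.marked_as_sp_partial_derived_param(layer.weight)
SeqParallelUtils.marked_as_sp_partial_derived_param(layer.bias)

layer(torch.randn(4, 16)).sum().backward()

# With a real TP group, the marked gradients are flattened, all-reduced,
# and copied back in place; here the call is a no-op.
SeqParallelUtils.allreduce_partial_data_grad(tp_group=None, model=layer)
```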


class Randomizer:
9 changes: 5 additions & 4 deletions colossalai/shardformer/policies/auto_policy.py
@@ -4,6 +4,7 @@

import torch.nn as nn

from ..shard.shard_config import ShardConfig
from .base_policy import Policy

__all__ = ["PolicyLocation", "get_autopolicy", "import_policy"]
@@ -197,7 +198,7 @@ def _fullname(obj):
return module + "." + klass.__qualname__


def get_autopolicy(model: nn.Module, inference_only: Optional[bool] = False) -> Policy:
def get_autopolicy(model: nn.Module, shard_config: ShardConfig = None) -> Policy:
r"""
Return the auto policy for the model

@@ -208,7 +209,7 @@ def get_autopolicy(model: nn.Module, inference_only: Optional[bool] = False) ->
:class:`Policy`: The auto policy for the model
"""
full_name = _fullname(model)
if inference_only:
if shard_config.inference_only:
policy_location = _INFER_POLICY_LIST.get(full_name, None)
else:
policy_location = _POLICY_LIST.get(full_name, None)
@@ -218,5 +219,5 @@
f"Auto policy for {model.__class__.__qualname__} is not implemented\n. Supported models are {list(_POLICY_LIST.keys())} and {list(_INFER_POLICY_LIST.keys())}"
)
else:
policy = import_policy(policy_location, inference_only)
return policy()
policy = import_policy(policy_location, shard_config.inference_only)
return policy(model, shard_config)
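A hedged sketch of the new call convention, mirroring the engine.py change above; the `ShardConfig` field names used here (e.g. `enable_tensor_parallelism`) are assumptions based on this PR:

```python
from transformers import BertConfig, BertForSequenceClassification

from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.shardformer.policies.auto_policy import get_autopolicy

model = BertForSequenceClassification(BertConfig())
shard_config = ShardConfig(enable_tensor_parallelism=False)

# The policy is now constructed with both the model and the shard config,
# so it can read flags such as fused normalization or sequence parallelism.
policy = get_autopolicy(model, shard_config=shard_config)

# In a launched distributed job, the policy is then applied via ShardFormer:
# sharded_model, _ = ShardFormer(shard_config).optimize(model, policy)
```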