Commits (51)
57d4fab
add pipeline policy and bert forward to be done
CjhHa1 Jun 30, 2023
8300f45
add bertmodel pipeline forward and make tests
CjhHa1 Jul 3, 2023
246b6d3
add Bert_Policy and test for policy
CjhHa1 Jul 3, 2023
db0a1f1
update formatting
CjhHa1 Jul 3, 2023
9f57067
update formatting
CjhHa1 Jul 3, 2023
dac6427
update the code
CjhHa1 Jul 3, 2023
8b30a02
fix bugs
CjhHa1 Jul 4, 2023
585eb9d
fix name confilt
CjhHa1 Jul 4, 2023
27fb804
add bloom model and policy ,revise the base class of policy
CjhHa1 Jul 4, 2023
86f3d90
Merge branch 'feature/pipeline' into feature/pipeline
CjhHa1 Jul 4, 2023
3ea0ba4
revise
CjhHa1 Jul 4, 2023
cd99e13
Merge branch 'feature/pipeline' of github.com:CjhHa1/ColossalAI into …
CjhHa1 Jul 4, 2023
edb02b2
revision
CjhHa1 Jul 4, 2023
369df2c
add bert_for_pretraining
CjhHa1 Jul 4, 2023
ac114f3
Merge branch 'hpcaitech:feature/pipeline' into feature/pipeline
CjhHa1 Jul 5, 2023
0319c8b
add bert_for_pretraining forward and policy
CjhHa1 Jul 5, 2023
29ef380
fix typos
CjhHa1 Jul 6, 2023
5cd2478
cancel warning
CjhHa1 Jul 6, 2023
ef528e6
change the imediate output to default dict
CjhHa1 Jul 6, 2023
e3e6c3b
change the default output of get_shared_params
CjhHa1 Jul 6, 2023
9eb6047
update
CjhHa1 Aug 1, 2023
6d21b49
rewrite bert test
CjhHa1 Aug 2, 2023
c5e974e
rewrite bert test
CjhHa1 Aug 2, 2023
3bfdd53
[test] Hotfix/fix some model test and refactor check util api (#4369)
FoolPlayer Aug 3, 2023
21c6bb0
[shardformer] add util functions for shardformer tests/fix sync_share…
Aug 3, 2023
c5f4844
[pipeline] add chatglm (#4363)
CjhHa1 Aug 4, 2023
7c84f51
[Shardformer] Merge flash attention branch to pipeline branch (#4362)
flybird11111 Aug 7, 2023
2e77e57
[pipeline] rewrite t5 tests & support multi-tensor transmitting in pi…
Aug 8, 2023
eecef52
add pipeline policy and bert forward to be done
CjhHa1 Jun 30, 2023
e4f25f5
add bertmodel pipeline forward and make tests
CjhHa1 Jul 3, 2023
e7b2a57
add Bert_Policy and test for policy
CjhHa1 Jul 3, 2023
88e5660
update formatting
CjhHa1 Jul 3, 2023
a4116a0
update formatting
CjhHa1 Jul 3, 2023
fe50399
update the code
CjhHa1 Jul 3, 2023
3c4b78d
fix bugs
CjhHa1 Jul 4, 2023
4d2605c
fix name confilt
CjhHa1 Jul 4, 2023
7ac03ae
add bloom model and policy ,revise the base class of policy
CjhHa1 Jul 4, 2023
e7431a7
revise
CjhHa1 Jul 4, 2023
13b9d52
revision
CjhHa1 Jul 4, 2023
79b4bbb
add bert_for_pretraining
CjhHa1 Jul 4, 2023
41f930a
add bert_for_pretraining forward and policy
CjhHa1 Jul 5, 2023
13c86ad
fix typos
CjhHa1 Jul 6, 2023
bc806a9
cancel warning
CjhHa1 Jul 6, 2023
5d99ae0
change the imediate output to default dict
CjhHa1 Jul 6, 2023
63968b3
change the default output of get_shared_params
CjhHa1 Jul 6, 2023
ad2a938
rewrite bert test
CjhHa1 Aug 2, 2023
c0740a7
rewrite bert test
CjhHa1 Aug 2, 2023
040499f
fix some bugs
CjhHa1 Aug 10, 2023
01cc2c0
del pipeline tests
CjhHa1 Aug 10, 2023
179b98f
del pipeline tests
CjhHa1 Aug 10, 2023
491aed8
del some useless prints
CjhHa1 Aug 10, 2023
9 changes: 6 additions & 3 deletions colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -37,7 +37,8 @@ def __init__(self, module: Module, precision: str, shard_config: ShardConfig, dp
self.shared_param_process_groups = []
for shared_param in self.shared_params:
if len(shared_param) > 0:
self.stage_manager.init_process_group_by_stages(list(shared_param.keys()))
self.shared_param_process_groups.append(
self.stage_manager.init_process_group_by_stages(list(shared_param.keys())))
if precision == 'fp16':
module = module.half().cuda()
elif precision == 'bf16':
@@ -49,8 +50,10 @@ def __init__(self, module: Module, precision: str, shard_config: ShardConfig, dp

def sync_shared_params(self):
for shared_param, group in zip(self.shared_params, self.shared_param_process_groups):
param = shared_param[self.stage_manager.stage]
dist.all_reduce(param.grad, group=group)
if self.stage_manager.stage in shared_param:
param = shared_param[self.stage_manager.stage]
dist.all_reduce(param.grad, group=group)
dist.barrier()

def no_sync(self) -> Iterator[None]:
# no sync grads across data parallel
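The `sync_shared_params` change above makes the all-reduce conditional: a stage only joins the collective for shared parameters it actually holds, and a barrier keeps the remaining stages in step. Below is a minimal sketch of that pattern, assuming `torch.distributed` is already initialized; the function and argument names are illustrative, not the plugin's real attributes.

```python
import torch.distributed as dist

def sync_shared_params(stage: int, shared_params: list, shared_param_groups: list) -> None:
    """Sketch of the guarded gradient sync for pipeline-shared parameters.

    shared_params:       list of {stage_index: parameter} dicts
    shared_param_groups: matching process groups built from those stage indices
    """
    for shared_param, group in zip(shared_params, shared_param_groups):
        if stage in shared_param:
            # only stages that own a copy of this parameter join the all-reduce
            param = shared_param[stage]
            dist.all_reduce(param.grad, group=group)
        # keep non-participating stages aligned with the participating ones
        dist.barrier()
```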
26 changes: 19 additions & 7 deletions colossalai/kernel/cuda_native/flash_attention.py
@@ -6,6 +6,7 @@
import math
import os
import subprocess
import warnings

import torch

@@ -14,15 +15,20 @@
HAS_MEM_EFF_ATTN = True
except ImportError:
HAS_MEM_EFF_ATTN = False
print('please install xformers from https://github.com/facebookresearch/xformers')
warnings.warn(f'please install xformers from https://github.com/facebookresearch/xformers')

if HAS_MEM_EFF_ATTN:

from typing import Optional

from einops import rearrange
from xformers.ops.fmha import MemoryEfficientAttentionCutlassOp
from xformers.ops.fmha.attn_bias import BlockDiagonalMask, LowerTriangularMask, LowerTriangularMaskWithTensorBias
from xformers.ops.fmha.attn_bias import (
BlockDiagonalCausalMask,
BlockDiagonalMask,
LowerTriangularMask,
LowerTriangularMaskWithTensorBias,
)

from .scaled_softmax import AttnMaskType

@@ -86,11 +92,14 @@ def backward(ctx, grad_output):

class ColoAttention(torch.nn.Module):

def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0):
def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0, scale=None):
super().__init__()
assert embed_dim % num_heads == 0, \
f"the embed dim ({embed_dim}) is not divisible by the number of attention heads ({num_heads})."
self.scale = 1 / math.sqrt(embed_dim // num_heads)
if scale is not None:
self.scale = scale
else:
self.scale = 1 / math.sqrt(embed_dim // num_heads)
self.dropout = dropout

@staticmethod
@@ -116,7 +125,7 @@ def forward(self,
bias: Optional[torch.Tensor] = None):
batch_size, tgt_len, src_len = query.shape[0], query.shape[1], key.shape[1]
attn_bias = None
if attn_mask_type == AttnMaskType.padding: # bert style
if attn_mask_type and attn_mask_type.value % 2 == 1: # bert style
assert attn_mask is not None, \
f"attention mask {attn_mask} is not valid for attention mask type {attn_mask_type}."
assert attn_mask.dim() == 2, \
@@ -134,7 +143,10 @@ def forward(self,
if batch_size > 1:
query = rearrange(query, "b s ... -> c (b s) ...", c=1)
key, value = self.unpad(torch.stack([query, key, value], dim=2), kv_indices).unbind(dim=2)
attn_bias = BlockDiagonalMask.from_seqlens(q_seqlen, kv_seqlen)
if attn_mask_type == AttnMaskType.padding:
attn_bias = BlockDiagonalMask.from_seqlens(q_seqlen, kv_seqlen)
elif attn_mask_type == AttnMaskType.paddedcausal:
attn_bias = BlockDiagonalCausalMask.from_seqlens(q_seqlen, kv_seqlen)
elif attn_mask_type == AttnMaskType.causal: # gpt style
attn_bias = LowerTriangularMask()

@@ -146,7 +158,7 @@ def forward(self,

out = memory_efficient_attention(query, key, value, attn_bias=attn_bias, p=self.dropout, scale=self.scale)

if attn_mask_type == AttnMaskType.padding and batch_size > 1:
if attn_mask_type and attn_mask_type.value % 2 == 1 and batch_size > 1:
out = self.repad(out, q_indices, batch_size, tgt_len)

out = rearrange(out, 'b s h d -> b s (h d)')
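Two of the changes above are easy to check in isolation: the softmax scale now defaults to `1/sqrt(head_dim)` but can be overridden, and the padding-mask branch triggers on any `AttnMaskType` with an odd value, so it covers `paddedcausal` as well as `padding`. A small sketch of the scale logic, with an illustrative function name that is not part of `ColoAttention`'s API:

```python
import math
from typing import Optional

def resolve_attention_scale(embed_dim: int, num_heads: int, scale: Optional[float] = None) -> float:
    """Default to 1/sqrt(head_dim) unless the caller supplies an explicit scale."""
    assert embed_dim % num_heads == 0, \
        f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})"
    head_dim = embed_dim // num_heads
    return scale if scale is not None else 1.0 / math.sqrt(head_dim)

# e.g. a 768-dim model with 12 heads keeps the usual 1 / sqrt(64) = 0.125,
# while passing scale=1.0 disables the rescaling entirely
assert resolve_attention_scale(768, 12) == 0.125
assert resolve_attention_scale(768, 12, scale=1.0) == 1.0
```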
5 changes: 3 additions & 2 deletions colossalai/kernel/cuda_native/scaled_softmax.py
@@ -19,6 +19,7 @@
class AttnMaskType(enum.Enum):
padding = 1
causal = 2
paddedcausal = 3


class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
@@ -139,7 +140,7 @@ def is_kernel_available(self, mask, b, np, sq, sk):
if 0 <= sk <= 2048:
batch_per_block = self.get_batch_per_block(sq, sk, b, np)

if self.attn_mask_type == AttnMaskType.causal:
if self.attn_mask_type.value > 1:
if attn_batches % batch_per_block == 0:
return True
else:
@@ -151,7 +152,7 @@ def forward_fused_softmax(self, input, mask):
b, np, sq, sk = input.size()
scale = self.scale if self.scale is not None else 1.0

if self.attn_mask_type == AttnMaskType.causal:
if self.attn_mask_type.value > 1:
assert sq == sk, "causal mask is only for self attention"

# input is 3D tensor (attn_batches, sq, sk)
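Adding `paddedcausal = 3` is what makes the two value-based checks work: in `flash_attention.py`, `value % 2 == 1` selects the mask types that carry a padding mask (`padding`, `paddedcausal`), while here `value > 1` selects the ones that need causal masking (`causal`, `paddedcausal`). A self-contained sketch using the same enum values; the helper names are illustrative only.

```python
import enum

class AttnMaskType(enum.Enum):
    padding = 1
    causal = 2
    paddedcausal = 3

def needs_padding_mask(mask_type: AttnMaskType) -> bool:
    # padding (1) and paddedcausal (3): the odd-valued members
    return mask_type.value % 2 == 1

def needs_causal_mask(mask_type: AttnMaskType) -> bool:
    # causal (2) and paddedcausal (3)
    return mask_type.value > 1

assert needs_padding_mask(AttnMaskType.paddedcausal) and needs_causal_mask(AttnMaskType.paddedcausal)
assert needs_padding_mask(AttnMaskType.padding) and not needs_causal_mask(AttnMaskType.padding)
assert needs_causal_mask(AttnMaskType.causal) and not needs_padding_mask(AttnMaskType.causal)
```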
6 changes: 5 additions & 1 deletion colossalai/pipeline/p2p.py
@@ -3,6 +3,7 @@

import io
import pickle
import re
from typing import Any, List, Optional, Union

import torch
@@ -31,7 +32,10 @@ def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) -
if b'cuda' in buf:
buf_array = bytearray(buf)
device_index = torch.cuda.current_device()
buf_array[buf_array.find(b'cuda') + 5] = 48 + device_index
# There might be more than one output tensors during forward
for cuda_str in re.finditer(b'cuda', buf_array):
pos = cuda_str.start()
buf_array[pos + 5] = 48 + device_index
buf = bytes(buf_array)

io_bytes = io.BytesIO(buf)
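The p2p fix above matters because a forward step can now transmit several tensors in one pickled object, so every `cuda:<index>` reference in the serialized buffer has to be redirected to the receiving rank's device, not just the first. A standalone sketch of that byte patch (illustrative function name, not the module's API):

```python
import re

def retarget_cuda_device(buf: bytes, device_index: int) -> bytes:
    """Rewrite every b'cuda:<digit>' occurrence to point at device_index."""
    buf_array = bytearray(buf)
    for match in re.finditer(b'cuda', buf_array):
        pos = match.start()
        # layout is b'cuda:<digit>', so the digit sits 5 bytes after the match start;
        # like the original, this assumes single-digit device indices
        buf_array[pos + 5] = ord('0') + device_index
    return bytes(buf_array)

# every serialized device reference is redirected, not just the first one
assert retarget_cuda_device(b'...cuda:0...cuda:3...', 1) == b'...cuda:1...cuda:1...'
```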
22 changes: 22 additions & 0 deletions colossalai/pipeline/policy/__init__.py
@@ -0,0 +1,22 @@
from typing import Any, Dict, List, Optional, Tuple, Type

from torch import Tensor
from torch.nn import Module, Parameter

from colossalai.pipeline.stage_manager import PipelineStageManager

from .base import Policy
from .bert import BertModel, BertModelPolicy

POLICY_MAP: Dict[Type[Module], Type[Policy]] = {
BertModel: BertModelPolicy,
}


def pipeline_parallelize(
model: Module,
stage_manager: PipelineStageManager) -> Tuple[Dict[str, Parameter], Dict[str, Tensor], List[Dict[int, Tensor]]]:
if type(model) not in POLICY_MAP:
raise NotImplementedError(f"Policy for {type(model)} not implemented")
policy = POLICY_MAP[type(model)](stage_manager)
return policy.parallelize_model(model)
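`pipeline_parallelize` is a plain type-keyed registry dispatch: the model's concrete class is looked up in `POLICY_MAP`, the matching policy is instantiated with the stage manager, and unknown model types fail loudly. A toy, self-contained sketch of the same pattern (the classes below are placeholders, not the real BERT policy):

```python
from typing import Dict, Tuple, Type

class ToyModel:
    """Stand-in for a transformers model class."""

class ToyPolicy:
    """Stand-in for a pipeline policy; returns empty hold/shared structures."""

    def __init__(self, stage_manager) -> None:
        self.stage_manager = stage_manager

    def parallelize_model(self, model) -> Tuple[dict, dict, list]:
        return {}, {}, []

TOY_POLICY_MAP: Dict[Type, Type[ToyPolicy]] = {ToyModel: ToyPolicy}

def toy_parallelize(model, stage_manager):
    if type(model) not in TOY_POLICY_MAP:
        raise NotImplementedError(f"Policy for {type(model)} not implemented")
    return TOY_POLICY_MAP[type(model)](stage_manager).parallelize_model(model)

print(toy_parallelize(ToyModel(), stage_manager=None))   # ({}, {}, [])
```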
141 changes: 141 additions & 0 deletions colossalai/pipeline/policy/base.py
@@ -0,0 +1,141 @@
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
from torch import Tensor
from torch.nn import Module, Parameter

from colossalai.lazy import LazyTensor
from colossalai.pipeline.stage_manager import PipelineStageManager


class Policy:

def __init__(self, stage_manager: PipelineStageManager) -> None:
self.stage_manager = stage_manager

def setup_model(self, module: Module) -> Tuple[Dict[str, Parameter], Dict[str, Tensor]]:
"""Setup model for pipeline parallel

Args:
module (Module): Module to be setup

Returns:
Tuple[Dict[str, Parameter], Dict[str, Tensor]]: Hold parameters and buffers
"""
hold_params = set()
hold_buffers = set()

def init_layer(layer: Module):
for p in layer.parameters():
if isinstance(p, LazyTensor):
p.materialize()
p.data = p.cuda()
hold_params.add(p)
for b in layer.buffers():
if isinstance(b, LazyTensor):
b.materialize()
b.data = b.cuda()
hold_buffers.add(b)

hold_layers = self.get_hold_layers(module)

for layer in hold_layers:
init_layer(layer)

hold_params_dict = {}
hold_buffers_dict = {}

# release other tensors
for n, p in module.named_parameters():
if p in hold_params:
hold_params_dict[n] = p
else:
if isinstance(p, LazyTensor):
p.materialize()
p.data = p.cuda()
p.storage().resize_(0)
for n, b in module.named_buffers():
if b in hold_buffers:
hold_buffers_dict[n] = b
else:
if isinstance(b, LazyTensor):
b.materialize()
b.data = b.cuda()
# FIXME(ver217): use meta tensor may be better
b.storage().resize_(0)
return hold_params_dict, hold_buffers_dict

def replace_forward(self, module: Module) -> None:
"""Replace module forward in place. This method should be implemented by subclass. The output of internal layers must be a dict

Args:
module (Module): _description_
"""
raise NotImplementedError

def get_hold_layers(self, module: Module) -> List[Module]:
"""Get layers that should be hold in current stage. This method should be implemented by subclass.

Args:
module (Module): Module to be setup

Returns:
List[Module]: List of layers that should be hold in current stage
"""
raise NotImplementedError

def get_shared_params(self, module: Module) -> List[Dict[int, Tensor]]:
"""Get parameters that should be shared across stages. This method should be implemented by subclass.

Args:
module (Module): Module to be setup

Returns:
List[Module]: List of parameters that should be shared across stages. E.g. [{0: module.model.embed_tokens.weight, 3: module.lm_head.weight}]
"""
raise NotImplementedError

def parallelize_model(self,
module: Module) -> Tuple[Dict[str, Parameter], Dict[str, Tensor], List[Dict[int, Tensor]]]:
"""Parallelize model for pipeline parallel

Args:
module (Module): Module to be setup

Returns:
Tuple[Dict[str, Parameter], Dict[str, Tensor], List[Dict[int, Tensor]]]: Hold parameters, buffers and shared parameters
"""
hold_params, hold_buffers = self.setup_model(module)
self.replace_forward(module)
shared_params = self.get_shared_params(module)
return hold_params, hold_buffers, shared_params

@staticmethod
def distribute_layers(num_layers: int, num_stages: int) -> List[int]:
"""
divide layers into stages
"""
quotient = num_layers // num_stages
remainder = num_layers % num_stages

# calculate the num_layers per stage
layers_per_stage = [quotient] * num_stages

# deal with the rest layers
if remainder > 0:
start_position = num_layers // 2 - remainder // 2
for i in range(start_position, start_position + remainder):
layers_per_stage[i] += 1
return layers_per_stage

@staticmethod
def get_stage_index(layers_per_stage: List[int], stage: int) -> List[int]:
"""
get the start index and end index of layers for each stage.
"""
num_layers_per_stage_accumulated = np.insert(np.cumsum(layers_per_stage), 0, 0)

start_idx = num_layers_per_stage_accumulated[stage]
end_idx = num_layers_per_stage_accumulated[stage + 1]

return [start_idx, end_idx]
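The two static helpers at the bottom of `Policy` decide how many transformer blocks each pipeline stage owns and turn that split into `[start, end)` index pairs. A standalone reproduction with a small worked example (logic copied from the diff; not an import of the real class):

```python
from typing import List

import numpy as np

def distribute_layers(num_layers: int, num_stages: int) -> List[int]:
    """Spread num_layers blocks over num_stages stages, centring the remainder."""
    quotient, remainder = divmod(num_layers, num_stages)
    layers_per_stage = [quotient] * num_stages
    if remainder > 0:
        # as in the diff, the offset is based on num_layers; for some shapes
        # (e.g. 10 layers over 4 stages) this index runs past the stage list,
        # so the worked example below sticks to a shape where it stays in range
        start_position = num_layers // 2 - remainder // 2
        for i in range(start_position, start_position + remainder):
            layers_per_stage[i] += 1
    return layers_per_stage

def get_stage_index(layers_per_stage: List[int], stage: int) -> List[int]:
    """Return the [start, end) layer indices owned by a given stage."""
    accumulated = np.insert(np.cumsum(layers_per_stage), 0, 0)
    return [accumulated[stage], accumulated[stage + 1]]

print(distribute_layers(4, 3))        # [1, 1, 2] -- the leftover layer lands mid-pipeline
start, end = get_stage_index(distribute_layers(4, 3), 1)
print(start, end)                     # 1 2 -- stage 1 owns layer indices [1, 2)
```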