Merged · 46 commits

Commits
75fa0b6
[moe] support moe fwd and bwd with low level zero (#4421)
oahzxl Aug 14, 2023
4373d06
[moe] support low level zero optim (#4429)
oahzxl Aug 14, 2023
8240463
[moe] refactor code to better adapt to llm (#4469)
oahzxl Aug 25, 2023
75fdcc2
[moe] support local moe and fix bugs (#4574)
oahzxl Sep 3, 2023
61995f8
[moe] support openmoe inference (#4616)
oahzxl Sep 6, 2023
bf53487
[moe] support openmoe train (#4637)
oahzxl Sep 7, 2023
55a81a6
[moe] align train settings and losses (#4655)
oahzxl Sep 8, 2023
84f05b1
[moe] move to moe and remove legacy (#4672)
oahzxl Sep 11, 2023
d1d0de8
[moe]: add top k router (#4597)
cwher Sep 12, 2023
708bf6f
[moe]: modify router loss, polish code (#4693)
cwher Sep 13, 2023
fde57bf
[moe] speed up embed and mlp (#4701)
oahzxl Sep 14, 2023
adb8ebe
[moe] adapt to main modifications
oahzxl Sep 15, 2023
3f02e57
[moe]: add flash attention & optimize top2 router (#4712)
cwher Sep 18, 2023
d12bbe7
[moe] support hybrid parallel (#4748)
oahzxl Sep 21, 2023
b72fa37
[moe] update benchmark (#4770)
oahzxl Sep 21, 2023
5c97a96
[moe] fix ci (#4772)
oahzxl Sep 22, 2023
c68303b
[moe] update benchmark scripts and ckpt io (#4804)
oahzxl Sep 29, 2023
4d74f83
[moe] support overlap for expert tp (#4851)
oahzxl Oct 4, 2023
2481b83
[moe] support hybrid zero strategy. (#4877)
oahzxl Oct 11, 2023
7441a1f
update mm (#4893)
oahzxl Oct 12, 2023
5844f34
[moe] support load balance (#4914)
oahzxl Oct 16, 2023
5f20878
update bench (#4923)
oahzxl Oct 17, 2023
b0e277b
[moe]: add overlap ep, and fix overlap tp (#4925)
cwher Oct 18, 2023
4a7bf29
[moe] polish code (#4952)
oahzxl Oct 20, 2023
c644b47
[moe] update train script (#4959)
oahzxl Oct 26, 2023
5cc3ad0
update
oahzxl Oct 26, 2023
713446b
delete context
oahzxl Oct 26, 2023
1b19a5f
remove moe
oahzxl Oct 26, 2023
ca42bf4
fix bugs
oahzxl Oct 26, 2023
c381e4c
update timeout temporarily
oahzxl Oct 26, 2023
b19fb91
resume time
oahzxl Oct 26, 2023
61df786
fix bug
oahzxl Oct 28, 2023
685c80a
remove tp
oahzxl Oct 28, 2023
9586f61
use kwargs
oahzxl Oct 28, 2023
6c0094c
polish and align with main
oahzxl Oct 30, 2023
b732ab0
fix test
oahzxl Oct 30, 2023
e85122b
update doc
oahzxl Oct 31, 2023
25c329f
Dist (#7)
oahzxl Oct 31, 2023
6b03bd4
update dist script
oahzxl Oct 31, 2023
659c9b1
update cai version
oahzxl Oct 31, 2023
caece56
update fsdp
oahzxl Oct 31, 2023
9fe7680
update zero
oahzxl Oct 31, 2023
0eb5623
fix bug
oahzxl Oct 31, 2023
4be194a
reverse legacy
oahzxl Nov 1, 2023
7e92e7b
update
oahzxl Nov 1, 2023
da6392f
update readme
oahzxl Nov 1, 2023
382 changes: 382 additions & 0 deletions colossalai/booster/plugin/moe_hybrid_parallel_plugin.py

Large diffs are not rendered by default.

2 changes: 0 additions & 2 deletions colossalai/context/__init__.py
@@ -1,7 +1,5 @@
from .config import Config, ConfigException

# from .moe_context import MOE_CONTEXT

__all__ = [
"Config",
"ConfigException",
132 changes: 0 additions & 132 deletions colossalai/context/moe_context.py

This file was deleted.

185 changes: 185 additions & 0 deletions colossalai/kernel/triton/llama_act_combine_kernel.py
@@ -0,0 +1,185 @@
from functools import reduce
from typing import Any, Tuple

import torch
from torch import Tensor
from torch.cuda.amp import custom_bwd, custom_fwd

try:
import triton
import triton.language as tl
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
print("please install triton from https://github.com/openai/triton")

if HAS_TRITON:
PRECISION_MAP = {
"fp32": (0, torch.float32),
"fp16": (1, torch.float16),
"bf16": (2, torch.bfloat16),
}

@triton.jit
def _llama_act_combine_forward(
X_GATE1,
X_GATE2,
X_UP,
Y,
stride, # how much to increase the pointer when moving by 1 row
N, # number of columns in X
BLOCK_SIZE: tl.constexpr,
):
# Map the program id to the row of X and Y it should compute.
row = tl.program_id(0)
X_GATE1 += row * stride
X_GATE2 += row * stride
X_UP += row * stride
Y += row * stride

# do activation and combine, and store in y
for off in range(0, N, BLOCK_SIZE):
cols = off + tl.arange(0, BLOCK_SIZE)
mask = cols < N
x_gate1 = tl.load(X_GATE1 + cols, mask=mask, other=0.)
x_gate2 = tl.load(X_GATE2 + cols, mask=mask, other=0.)
x_up = tl.load(X_UP + cols, mask=mask, other=0.)
x_gate2_sigmoid = tl.sigmoid(x_gate2.to(tl.float32)).to(x_gate2.dtype)
y = x_gate1 * x_gate2 * x_gate2_sigmoid * x_up
# Write output
tl.store(Y + cols, y, mask=mask)

@triton.jit
def _llama_act_combine_backward(
X_GATE1,
X_GATE2,
X_UP,
X_GATE1_GRAD,
X_GATE2_GRAD,
X_UP_GRAD,
Y_GRAD,
stride, # how much to increase the pointer when moving by 1 row
N, # number of columns in X
BLOCK_SIZE: tl.constexpr,
):
# Map the program id to the row of X and Y it should compute.
row = tl.program_id(0)
X_GATE1 += row * stride
X_GATE2 += row * stride
X_UP += row * stride
X_GATE1_GRAD += row * stride
X_GATE2_GRAD += row * stride
X_UP_GRAD += row * stride
Y_GRAD += row * stride

# do activation and combine, and store in y
for off in range(0, N, BLOCK_SIZE):
cols = off + tl.arange(0, BLOCK_SIZE)
mask = cols < N
x_gate1 = tl.load(X_GATE1 + cols, mask=mask, other=0.)
x_gate2 = tl.load(X_GATE2 + cols, mask=mask, other=0.)
x_up = tl.load(X_UP + cols, mask=mask, other=0.)
y_grad = tl.load(Y_GRAD + cols, mask=mask, other=0.)

# forward: y = x_gate1 * x_gate2 * tl.sigmoid(x_gate2) * x_up
x_gate2_sigmoid = tl.sigmoid(x_gate2.to(tl.float32)).to(x_gate2.dtype)
x_gate2_act = y_grad * x_gate2 * x_gate2_sigmoid
x_up_grad = x_gate2_act * x_gate1
x_gate1_grad = x_gate2_act * x_up
# grad(x*sigmoid(x)) = sigmoid(x) + x * sigmoid(x) * [1 − sigmoid(x)]
# = sigmoid(x) * [1 + x * (1 − sigmoid(x))]
x_gate2_grad = (y_grad * x_gate1 * x_up) * x_gate2_sigmoid * (1 + x_gate2 * (1 - x_gate2_sigmoid))

# Write output
tl.store(X_GATE1_GRAD + cols, x_gate1_grad, mask=mask)
tl.store(X_GATE2_GRAD + cols, x_gate2_grad, mask=mask)
tl.store(X_UP_GRAD + cols, x_up_grad, mask=mask)

class LlamaActCombine(torch.autograd.Function):
"""
act(x_gate) * x_up

Args:
x_gate (torch.Tensor): (b, l, 2d) x_gate
x_up (torch.Tensor): (b, l, d) x_up
activation (str): only support swiglu
precision (str): fp32, fp16, bf16
"""

@staticmethod
@custom_fwd
def forward(ctx: Any, x_gate: torch.Tensor, x_up: torch.Tensor, activation: str = "swiglu") -> torch.Tensor:
"""
act(x_gate) * x_up

Args:
x_gate (torch.Tensor): (b, l, 2d) x gate
x_up (torch.Tensor): (b, l, d) x up
activation (str): only support swiglu
"""
assert activation == "swiglu", "Only swiglu is supported"

# split x gate
assert x_gate.shape[-1] % 2 == 0, "axis size must be divisible by 2"
x_gate1, x_gate2 = torch.split(x_gate, x_gate.shape[-1] // 2, -1)
x_gate1 = x_gate1.contiguous()
x_gate2 = x_gate2.contiguous()
if not x_up.is_contiguous():
x_up = x_up.contiguous()
# assert shape
assert x_gate1.shape == x_gate2.shape == x_up.shape

# add ctx for backward
if x_gate.requires_grad:
ctx.save_for_backward(x_gate1, x_gate2, x_up)

# allocate output
y = torch.empty_like(x_up)
M, N = reduce(lambda x, y: x * y, x_up.shape[:-1]), x_up.shape[-1]

# Less than 64KB per feature: enqueue fused kernel
MAX_FUSED_SIZE = 65536 // x_gate.element_size()
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
if N > BLOCK_SIZE:
raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
# heuristics for number of warps
num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
# restore setting
ctx.M, ctx.N, ctx.BLOCK_SIZE, ctx.num_warps = M, N, BLOCK_SIZE, num_warps
# enqueue kernel
_llama_act_combine_forward[(M,)](x_gate1,
x_gate2,
x_up,
y,
x_up.stride(-2),
N,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps)
return y

@staticmethod
@custom_bwd
def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, Tensor, None, None]:
# restore from ctx
(x_gate1, x_gate2, x_up) = ctx.saved_tensors
M, N, BLOCK_SIZE, num_warps = ctx.M, ctx.N, ctx.BLOCK_SIZE, ctx.num_warps

# init grad
y_grad = grad_outputs[0]
x_gate1_grad, x_gate2_grad, x_up_grad = torch.empty_like(x_gate1), torch.empty_like(
x_gate2), torch.empty_like(x_up)

# enqueue kernel
_llama_act_combine_backward[(M,)](x_gate1,
x_gate2,
x_up,
x_gate1_grad,
x_gate2_grad,
x_up_grad,
y_grad,
x_up.stride(-2),
N,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps)
x_gate_grad = torch.cat([x_gate1_grad, x_gate2_grad], dim=-1)
return x_gate_grad, x_up_grad, None, None
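
For reference, here is a minimal usage sketch of the fused kernel added above. It is not part of the diff: it assumes a CUDA device with Triton installed, uses the import path implied by the file location in this diff, and checks the fused forward pass against an eager SwiGLU reference. Shapes follow the docstring, with x_gate of shape (b, l, 2d) and x_up of shape (b, l, d).

```python
# Minimal usage sketch for LlamaActCombine (illustrative, not part of this PR).
# Assumes a CUDA device and a working Triton installation.
import torch

from colossalai.kernel.triton.llama_act_combine_kernel import LlamaActCombine

b, l, d = 2, 16, 128
x_gate = torch.randn(b, l, 2 * d, device="cuda")
x_up = torch.randn(b, l, d, device="cuda")

# Fused Triton path: y = x_gate1 * x_gate2 * sigmoid(x_gate2) * x_up
y = LlamaActCombine.apply(x_gate, x_up)

# Eager reference for comparison
x_gate1, x_gate2 = torch.split(x_gate, d, dim=-1)
y_ref = x_gate1 * x_gate2 * torch.sigmoid(x_gate2) * x_up
torch.testing.assert_close(y, y_ref, rtol=1e-4, atol=1e-4)
```
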
2 changes: 0 additions & 2 deletions colossalai/legacy/engine/gradient_handler/__init__.py
@@ -1,6 +1,5 @@
from ._base_gradient_handler import BaseGradientHandler
from ._data_parallel_gradient_handler import DataParallelGradientHandler
from ._moe_gradient_handler import MoeGradientHandler
from ._pipeline_parallel_gradient_handler import PipelineSharedModuleGradientHandler
from ._sequence_parallel_gradient_handler import SequenceParallelGradientHandler
from ._zero_gradient_handler import ZeROGradientHandler
@@ -10,6 +9,5 @@
"DataParallelGradientHandler",
"ZeROGradientHandler",
"PipelineSharedModuleGradientHandler",
"MoeGradientHandler",
"SequenceParallelGradientHandler",
]
12 changes: 0 additions & 12 deletions colossalai/legacy/initialize.py
@@ -16,7 +16,6 @@
from torch.utils.data import DataLoader

from colossalai.context import Config, ConfigException
from colossalai.context.moe_context import MOE_CONTEXT
from colossalai.interface import OptimizerWrapper
from colossalai.legacy.amp import AMP_TYPE, convert_to_amp
from colossalai.legacy.amp.naive_amp import NaiveAMPModel
Expand All @@ -36,7 +35,6 @@
from colossalai.legacy.zero.gemini.ophooks import BaseOpHook
from colossalai.logging import get_dist_logger
from colossalai.utils import get_current_device
from colossalai.utils.moe import sync_moe_model_param


def get_default_parser():
@@ -323,8 +321,6 @@ def initialize(
if not use_zero:
if is_using_sequence():
sync_model_param(model, ParallelMode.SEQUENCE_DP)
elif MOE_CONTEXT.is_initialized:
sync_moe_model_param(model)
elif is_using_ddp():
sync_model_param(model, ParallelMode.DATA)
else:
@@ -377,14 +373,6 @@
"added even though not specified in the configuration",
ranks=[0],
)
elif is_using_ddp() and MOE_CONTEXT.is_initialized:
gradient_handler_cfg = [dict(type="MoeGradientHandler")]
if verbose:
logger.info(
"Data parallel training is detected with moe parallel, MoeGradientHandler is automatically "
"added even though not specified in the configuration",
ranks=[0],
)
elif is_using_sequence():
model = DDP(
model,
17 changes: 17 additions & 0 deletions colossalai/moe/__init__.py
@@ -0,0 +1,17 @@
from .checkpoint import MoeCheckpintIO
from .experts import MLPExperts
from .layers import SparseMLP
from .routers import MoeRouter, Top1Router, Top2Router, TopKRouter
from .utils import NormalNoiseGenerator, UniformNoiseGenerator

__all__ = [
"MLPExperts",
"MoeRouter",
"Top1Router",
"Top2Router",
"TopKRouter",
"NormalNoiseGenerator",
"UniformNoiseGenerator",
"SparseMLP",
"MoeCheckpintIO",
]
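
For orientation, the sketch below simply imports the new public surface of the package; the one-line notes are inferred from the submodules each export comes from in this __init__, not from documentation, so treat them as assumptions.

```python
# Exports of the new colossalai.moe package (requires a build containing this PR).
# One-line notes are inferred from the submodules they are defined in.
from colossalai.moe import (
    MLPExperts,             # expert weights/compute (from .experts)
    MoeCheckpintIO,         # checkpoint save/load for MoE models (from .checkpoint)
    NormalNoiseGenerator,   # router gating noise (from .utils)
    SparseMLP,              # the sparse MoE MLP layer (from .layers)
    Top1Router,             # top-1 token routing (from .routers)
    Top2Router,             # top-2 token routing (from .routers)
    TopKRouter,             # generic top-k routing (from .routers)
    UniformNoiseGenerator,  # router gating noise (from .utils)
)
```
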