Merged
9 changes: 3 additions & 6 deletions colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
@@ -22,7 +22,7 @@
 )
 from colossalai.cluster import ProcessGroupMesh
 from colossalai.interface import ModelWrapper, OptimizerWrapper
-from colossalai.moe import MoeCheckpintIO
+from colossalai.moe import MoECheckpintIO
 from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer import ShardConfig
@@ -322,8 +322,8 @@ def seed_worker(worker_id):
             **_kwargs,
         )

-    def get_checkpoint_io(self) -> MoeCheckpintIO:
-        self.checkpoint_io = MoeCheckpintIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
+    def get_checkpoint_io(self) -> MoECheckpintIO:
+        self.checkpoint_io = MoECheckpintIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
         return self.checkpoint_io

     def configure(
@@ -359,9 +359,6 @@ def configure(
                     max_norm=self.max_norm,
                     **self.amp_config,
                 )
-                self.checkpoint_io.link_master_and_working_param(
-                    optimizer.working_to_master_map, optimizer.master_to_working_map
-                )
             else:
                 optimizer = HybridParallelNaiveOptimizer(
                     optimizer, model, use_pipeline=self.enable_pipeline_parallelism, param_info=param_info
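For context on the three deleted lines in configure: with fp16/bf16 mixed precision, the AMP optimizer keeps fp32 "master" copies of the low-precision "working" parameters, and the two lookup tables passed to link_master_and_working_param let a checkpoint IO translate between the two when saving or loading optimizer state. A minimal, self-contained sketch of what such a pair of maps looks like (plain PyTorch with hypothetical variable names, not the ColossalAI implementation):

import torch

# Working params: the low-precision tensors the model computes with.
# Master params: fp32 copies that the optimizer actually updates.
working_params = [torch.randn(4, 4, dtype=torch.float16) for _ in range(2)]
master_params = [p.detach().float().clone() for p in working_params]

working_to_master = {id(w): m for w, m in zip(working_params, master_params)}
master_to_working = {id(m): w for w, m in zip(working_params, master_params)}

# A checkpoint IO holding both maps can recover the fp32 state that belongs
# to a given working parameter (and vice versa) at save/load time.
assert working_to_master[id(working_params[0])] is master_params[0]
assert master_to_working[id(master_params[0])] is working_params[0]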
10 changes: 6 additions & 4 deletions colossalai/inference/tensor_parallel/engine.py
@@ -79,13 +79,15 @@ def __init__(

         self.multi_query_group_num = model.config.num_attention_heads
         # default to attention_heads
-        self.multi_query_attention = model.config.multi_query_attention
+        if hasattr(model.config, "multi_query_attention"):
+            self.multi_query_attention = getattr(model.config, "multi_query_attention")

         if hasattr(model.config, "multi_query_group_num"):
-            self.multi_query_group_num = model.config.multi_query_group_num
+            self.multi_query_group_num = getattr(model.config, "multi_query_group_num")

         if hasattr(model.config, "num_key_value_heads"):
-            self.multi_query_group_num = model.config.num_key_value_heads
+            self.multi_query_group_num = getattr(model.config, "num_key_value_heads")

         self.tp_size = -1 # to be set with given shard config in self.prepare_shard_config
         self.cache_manager = None

@@ -108,7 +110,7 @@ def _init_manager(self) -> None:
         assert self.head_num % self.tp_size == 0, f"Cannot shard {self.head_num} heads with tp size {self.tp_size}"
         self.head_num //= self.tp_size # update sharded number of heads

-        if self.multi_query_attention:
+        if hasattr(self, "multi_query_attention"):
             # NOTE the logic of MQA tensor parallelism should be specified.
             assert (
                 self.multi_query_group_num % self.tp_size == 0
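The refactor above makes each of these attributes optional: the engine only reads multi_query_attention, multi_query_group_num, or num_key_value_heads when the model config actually defines them, instead of assuming every config carries a multi_query_attention field. A self-contained sketch of the probing pattern (SimpleNamespace stands in for a HuggingFace config; the field values are made up for illustration):

from types import SimpleNamespace

# Stand-in for a model config; real configs define only some of these fields.
config = SimpleNamespace(num_attention_heads=32, num_key_value_heads=8)

multi_query_group_num = config.num_attention_heads # default to attention heads
if hasattr(config, "multi_query_group_num"):
    multi_query_group_num = getattr(config, "multi_query_group_num")
if hasattr(config, "num_key_value_heads"):
    # the later check wins, so GQA-style configs override the default
    multi_query_group_num = getattr(config, "num_key_value_heads")

print(multi_query_group_num) # 8 for this stand-in config

An equivalent, terser form is getattr(config, "num_key_value_heads", fallback); the hasattr/getattr pairing shown here simply mirrors the structure used in the diff.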
4 changes: 2 additions & 2 deletions colossalai/moe/__init__.py
@@ -1,4 +1,4 @@
-from .checkpoint import MoeCheckpintIO
+from .checkpoint import MoECheckpintIO
 from .experts import MLPExperts
 from .layers import SparseMLP
 from .routers import MoeRouter, Top1Router, Top2Router, TopKRouter
@@ -13,5 +13,5 @@
"NormalNoiseGenerator",
"UniformNoiseGenerator",
"SparseMLP",
"MoeCheckpintIO",
"MoECheckpintIO",
]
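Since no alias for the old spelling appears in this diff, downstream imports of the old name break once the rename lands; callers would switch to the new export, e.g.:

# before this PR: from colossalai.moe import MoeCheckpintIO
from colossalai.moe import MoECheckpintIO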