Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions colossalai/booster/plugin/hybrid_parallel_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,13 @@ def _convert_floating_point(x, dtype: torch.dtype = torch.float16):
class HybridParallelModule(ModelWrapper):

def __init__(self, module: Module, precision: str, shard_config: ShardConfig, dp_group: ProcessGroup, use_ddp: bool,
ddp_config: dict) -> None:
Copy link
Copy Markdown
Contributor

@flybird11111 flybird11111 Sep 5, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for your contribution. Could you please make some changes here? It's best not to expose 'Policy' in the plugin; it's just a component of Shardformer.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi,

Thanks for your review!

I have two candidate approaches currently.

  1. Remove the typing hint
  2. Modify the typing hint from Policy to Object

Which one do you think better?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi,

I modify the code with the second approach (modify from Policy to object)

ddp_config: dict, policy: Optional[object]=None) -> None:

self.stage_manager = shard_config.pipeline_stage_manager
self.dp_group = dp_group

shardformer = ShardFormer(shard_config)
module, self.shared_params = shardformer.optimize(module)
module, self.shared_params = shardformer.optimize(module, policy=policy)

# setting process groups for shared parameters
self.shared_param_process_groups = []
Expand Down Expand Up @@ -268,6 +268,7 @@ class HybridParallelPlugin(PipelinePluginBase):
cpu_offload (bool, optional): Whether to open cpu_offload when using ZeRO. Defaults to False.
communication_dtype (torch.dtype, optional): Communication dtype when using ZeRO. If not specified, the dtype of param will be used. Defaults to None.
overlap_communication (bool, optional): Whether to overlap communication and computation when using ZeRO. Defaults to True.
policy (object, optional): Shardformer policy when using custom model. If not specified, ShardFormer will try to fetch the policy automatically.
"""

def __init__(self,
Expand Down Expand Up @@ -300,7 +301,8 @@ def __init__(self,
zero_bucket_size_in_m: int = 12,
cpu_offload: bool = False,
communication_dtype: Optional[torch.dtype] = None,
overlap_communication: bool = True) -> None:
overlap_communication: bool = True,
policy: Optional[object] = None) -> None:

super().__init__()
assert dist.get_world_size() % (
Expand All @@ -324,6 +326,7 @@ def __init__(self,
self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size)
self.stage_manager = None
self.schedule = None
self.policy = policy
assert zero_stage in (0, 1, 2)
if self.pp_size > 1:
assert num_microbatches is not None or microbatch_size is not None, 'num_microbatches or microbatch_size must be specified when using pipeline parallelism'
Expand Down Expand Up @@ -403,7 +406,7 @@ def configure(
if not isinstance(model, ModelWrapper):
use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0
model = HybridParallelModule(model, self.precision, self.shard_config, self.dp_group, use_ddp,
self.ddp_config)
self.ddp_config, self.policy)
if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
if self.zero_stage == 0:
if self.precision in ['fp16', 'bf16']:
Expand Down
18 changes: 12 additions & 6 deletions colossalai/shardformer/policies/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,20 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
self.shard_config.enable_sequence_parallelism = False
warnings.warn("Llama dosen't support sequence parallelism now, will ignore the sequence parallelism flag.")


if self.shard_config.enable_tensor_parallelism:
decoder_attribute_replacement = {
"self_attn.hidden_size":
self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
"self_attn.num_heads":
self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
}
if getattr(self.model.config, "num_key_value_heads", False):
decoder_attribute_replacement["self_attn.num_key_value_heads"] = \
self.model.config.num_key_value_heads // self.shard_config.tensor_parallel_size

policy[LlamaDecoderLayer] = ModulePolicyDescription(
attribute_replacement={
"self_attn.hidden_size":
self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
"self_attn.num_heads":
self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
},
attribute_replacement=decoder_attribute_replacement,
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="self_attn.q_proj",
Expand Down