Skip to content

[shardformer] Support customized policy for llamav2 based model with HybridParallelPlugin#4614

Closed
eric8607242 wants to merge 3 commits intohpcaitech:feature/shardformerfrom
eric8607242:feature/shardformer
Closed

[shardformer] Support customized policy for llamav2 based model with HybridParallelPlugin#4614
eric8607242 wants to merge 3 commits intohpcaitech:feature/shardformerfrom
eric8607242:feature/shardformer

Conversation

@eric8607242
Copy link
Copy Markdown
Contributor

@eric8607242 eric8607242 commented Sep 5, 2023

📌 Checklist before creating the PR

  • I have created an issue for this PR for traceability
  • The title follows the standard format: [doc/gemini/tensor/...]: A concise description
  • I have added relevant tags if possible for us to better distinguish different PRs

🚨 Issue number

#4613

📝 What does this PR do?

In this PR, I have two modifications for ShardFormer.

  1. I created a new argument policy for HybridParallelPlugin, which enables user can apply HybridParallel to their own model with customized Policy.
  2. I add a new attribute replacement self_attn.num_key_value_heads for LlamaPolicy. The attribute is new for LLaMAv2, without this attribute replacement, I can not apply tensor parallelism on LLaMAv2 successfully.

💥 Checklist before requesting a review

  • I have linked my PR to an issue (instruction)
  • My issue clearly describes the problem/feature/proposal, with diagrams/charts/table/code if possible
  • I have performed a self-review of my code
  • I have added thorough tests.
  • I have added docstrings for all the functions/methods I implemented

⭐️ Do you enjoy contributing to Colossal-AI?

  • 🌝 Yes, I do.
  • 🌚 No, I don't.

Tell us more if you don't enjoy contributing to Colossal-AI.

@eric8607242 eric8607242 changed the title [shardformer] Enable policy assignment in HybridParallelPlugin and enable llama policy for … [shardformer] Support customized policy for llamav2 based model Sep 5, 2023
@eric8607242 eric8607242 changed the title [shardformer] Support customized policy for llamav2 based model [shardformer] Support customized policy for llamav2 based model with HybridParallelPlugin Sep 5, 2023
@flybird11111 flybird11111 marked this pull request as draft September 5, 2023 08:37
@flybird11111 flybird11111 marked this pull request as ready for review September 5, 2023 08:37
class HybridParallelModule(ModelWrapper):

def __init__(self, module: Module, precision: str, shard_config: ShardConfig, dp_group: ProcessGroup, use_ddp: bool,
ddp_config: dict) -> None:
Copy link
Copy Markdown
Contributor

@flybird11111 flybird11111 Sep 5, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for your contribution. Could you please make some changes here? It's best not to expose 'Policy' in the plugin; it's just a component of Shardformer.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi,

Thanks for your review!

I have two candidate approaches currently.

  1. Remove the typing hint
  2. Modify the typing hint from Policy to Object

Which one do you think better?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi,

I modify the code with the second approach (modify from Policy to object)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants