Add GraniteMoeHybrid support for 4.0 #37658
Conversation
cc @ArthurZucker for text models!
class GraniteMoeHybridSdpaAttention(GraniteMoeSharedSdpaAttention):
    pass

GRANITEMOEHYBRID_ATTENTION_CLASSES = {
Just as a heads up, I think it would be nice to follow the new attention interface (see #35235 for the original PR). Llama can also provide a good first pointer for this.
(Except I'm missing that this is a more special kind of attention here :D)
Thanks for the heads up @vasqu! We are still cleaning up this branch a bit, will take a look at this once the tests are in a better state 🙂
Thanks for the pointer @vasqu! Refactored this PR to the new attention interface 😄
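For readers following along: the refactor replaces per-backend subclasses like the GraniteMoeHybridSdpaAttention above with a single attention class that dispatches through ALL_ATTENTION_FUNCTIONS. Below is a minimal sketch of that pattern modeled on Llama, not copied from this PR; the grouped-query repeat_kv step is omitted for brevity.

```python
import torch
from torch import nn
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS


def eager_attention_forward(module, query, key, value, attention_mask, scaling, dropout=0.0, **kwargs):
    # Plain softmax attention, used when config._attn_implementation == "eager".
    # Grouped-query repeat_kv is omitted here for brevity.
    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask[:, :, :, : key.shape[-2]]
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value).transpose(1, 2).contiguous()
    return attn_output, attn_weights


# Inside the attention module's forward, the backend is resolved once instead
# of subclassing per backend:
#
#     attention_interface = eager_attention_forward
#     if self.config._attn_implementation != "eager":
#         attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
#     attn_output, attn_weights = attention_interface(
#         self, query, key, value, attention_mask, scaling=self.scaling
#     )
```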
ccing @molbap for mamba2/bamba (feels like I'm pinging you constantly 😆)
Force-pushed ac9b018 to d751d26
Force-pushed a70d949 to 8274d2c
Thanks @ArthurZucker! It's ready for another look when you get the chance!
ArthurZucker left a comment
Very nice use of modular, thanks a lot! 🤗
hidden_states = self.input_layernorm(hidden_states)
self_attn_weights = None
if self.layer_type == "mamba":
I am thinking let's remove the check on type and rely instead on checking self.self_attn is not None?
I agree, I also didn't like self.mamba being conditionally undefined. Updated this to define both in __init__ and just run mamba if self.mamba is not None, attention otherwise 🙂
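A rough sketch of the shape this lands in: both attributes always exist, exactly one is a module, and the forward branches on None. The stand-in submodules and names below are illustrative only, not the PR's actual classes.

```python
import torch
from torch import nn


class _MambaStandIn(nn.Module):
    # Stand-in for the real mamba mixer; only here so the sketch runs.
    def __init__(self, hidden_size):
        super().__init__()
        self.proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states, **kwargs):
        return self.proj(hidden_states)


class _AttnStandIn(nn.Module):
    # Stand-in for the real attention module; returns (output, attn_weights).
    def __init__(self, hidden_size):
        super().__init__()
        self.proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states, **kwargs):
        return self.proj(hidden_states), None


class DecoderLayerSketch(nn.Module):
    def __init__(self, hidden_size, layer_type):
        super().__init__()
        self.input_layernorm = nn.LayerNorm(hidden_size)  # the real model uses RMSNorm
        # Both attributes are always defined; exactly one is a module.
        self.mamba = _MambaStandIn(hidden_size) if layer_type == "mamba" else None
        self.self_attn = _AttnStandIn(hidden_size) if layer_type != "mamba" else None

    def forward(self, hidden_states, **kwargs):
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        self_attn_weights = None
        if self.mamba is not None:
            hidden_states = self.mamba(hidden_states, **kwargs)
        else:
            hidden_states, self_attn_weights = self.self_attn(hidden_states, **kwargs)
        return residual + hidden_states, self_attn_weights
```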
else:
    raise ValueError(f"Expected layer type in ['attention', 'mamba'], got {self.layer_type}")
hidden_states = self.post_attention_layernorm(hidden_states)
moe_hidden_states, router_logits = self.block_sparse_moe(hidden_states)
if self.shared_mlp is None:
I don't know if you answered this already: are there two different checkpoints being released, one with and one without this?
The models that are about to come out do use it! I think there are likely experiments ongoing without it, but am not sure about concrete plans for when they'll be released since I'm not the one training the models 🙂
In that case let's remove what's uncertain! 🤗
Sounds good! Removed the case with 0 experts, I'll open a follow-up PR if it ends up being used in a model to be released 😄
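Per the final "Always create the shared mlp" commit, the branch disappears and the two paths are combined unconditionally. A hedged sketch of that outcome, mirroring the GraniteMoeShared pattern with stand-in submodules rather than the PR's actual classes:

```python
from torch import nn


class MoeBlockSketch(nn.Module):
    """Sketch of the post-attention MoE block once shared_mlp always exists."""

    def __init__(self, hidden_size):
        super().__init__()
        # Stand-ins: the real modules are the sparse MoE and the shared-expert MLP.
        self.block_sparse_moe = nn.Linear(hidden_size, hidden_size)
        self.shared_mlp = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states):
        # No `if self.shared_mlp is None` branch anymore: the expert output and
        # the shared-expert output are simply summed.
        moe_hidden_states = self.block_sparse_moe(hidden_states)
        return moe_hidden_states + self.shared_mlp(hidden_states)
```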
if self.gradient_checkpointing and self.training:
    layer_outputs = self._gradient_checkpointing_func(
        decoder_layer.__call__,
        hidden_states,
        layer_mask,
        past_key_values,
        output_attentions,
        use_cache,
        cache_position,
        output_router_logits,
        position_embeddings,
    )
else:
    layer_outputs = decoder_layer(
        hidden_states,
        attention_mask=layer_mask,
        past_key_value=past_key_values,
        output_attentions=output_attentions,
        use_cache=use_cache,
        cache_position=cache_position,
        output_router_logits=output_router_logits,
        position_embeddings=position_embeddings,
    )
let's use the new GradientCheckpointingLayer, wdyt?
Definitely, that is a lot cleaner! I updated the models in the chain for modular to all use the gradient checkpointing layer (GraniteMoe/GraniteMoeShared/GraniteMoeHybrid)
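For context, the idea is that the layer inherits from GradientCheckpointingLayer so the two-way branch above collapses into a single call. A minimal sketch; the import path is as in recent transformers releases, double-check it against your installed version:

```python
from torch import nn
from transformers.modeling_layers import GradientCheckpointingLayer


class LayerSketch(GradientCheckpointingLayer):
    def __init__(self, hidden_size=8):
        super().__init__()
        self.mlp = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states, attention_mask=None):
        return self.mlp(hidden_states)


# The model loop then needs only the plain call; when gradient checkpointing
# is enabled during training, __call__ routes through the checkpointing
# function transparently:
#
#     layer_outputs = decoder_layer(hidden_states, attention_mask=layer_mask, ...)
```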
if not return_dict:
    return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
we have a @can_return_tuple for the forward
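For context, the decorator intercepts return_dict and converts the ModelOutput to a plain tuple itself. A minimal runnable sketch of the usage, with a stubbed config instead of a real PreTrainedModel:

```python
from types import SimpleNamespace

import torch
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.utils import can_return_tuple


class ModelSketch(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # can_return_tuple consults self.config.use_return_dict when the caller
        # doesn't pass return_dict; a stub config keeps this sketch runnable.
        self.config = SimpleNamespace(use_return_dict=True)

    @can_return_tuple
    def forward(self, hidden_states):
        # Model body elided. The decorator converts the ModelOutput to a plain
        # tuple when the caller asks for it, so the manual
        # `if not return_dict: return tuple(...)` branch can be deleted.
        return BaseModelOutputWithPast(last_hidden_state=hidden_states)


model = ModelSketch()
out = model(torch.ones(1, 2, 4))                      # ModelOutput
tup = model(torch.ones(1, 2, 4), return_dict=False)   # plain tuple via the decorator
```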
class GraniteMoeHybridModelTester:
can we try to inherit tests from the closest model, so mamba, in the same fashion as here
Good idea! The closest model for the tests is Bamba. Consolidated a bit to use the Bamba tests, should be way easier to look at now 🤞
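The rough shape of that consolidation, as a hypothetical sketch: the relative import assumes the usual tests/models/<model>/ layout, and the exact base class and overridable attributes in the merged test file may differ.

```python
from transformers import GraniteMoeHybridConfig

from ..bamba.test_modeling_bamba import BambaModelTester


class GraniteMoeHybridModelTester(BambaModelTester):
    # Override only the MoE-specific knobs (config class, expert counts,
    # shared sizes, ...); the Bamba tester provides the mamba/attention
    # plumbing that would otherwise be copied verbatim.
    config_class = GraniteMoeHybridConfig
```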
berserkr left a comment
std is initialized twice: std = self.config.initializer_range
Thanks @berserkr! There were two because of …
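The fix lands as the "avoid redundant intermediate std var" commit: read initializer_range once and reuse it for every branch. A minimal sketch of the resulting weight init, not the exact merged code:

```python
from torch import nn


def _init_weights(self, module):
    # Single read of initializer_range, reused below; the duplicated
    # `std = self.config.initializer_range` was the redundancy flagged above.
    std = self.config.initializer_range
    if isinstance(module, nn.Linear):
        module.weight.data.normal_(mean=0.0, std=std)
        if module.bias is not None:
            module.bias.data.zero_()
    elif isinstance(module, nn.Embedding):
        module.weight.data.normal_(mean=0.0, std=std)
```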
Thank you very much for the fast review @ArthurZucker! I've made all the changes 🙂
Force-pushed 6b0ba0c to 1c0272a
Thanks @ArthurZucker! Added the missing TOC entry and removed the currently unused shared condition for the MLP, should pass now! 🤞
* initial config and MLA layer
* first pass at decoder
* completion of layers
* modeling class
* adding hybrid class to imports
* fix imports granitemoehybrid
* fix granitehybrid imports
* fix granitehybrid import
* fix generated modeling file
* add some comments
* minor fixes in layers
* add sharedMLP layer
* correct layer names
* fixes in mamba config
* fix mamba config
* change name of MLP layer
* fix seq mizer layers
* correct mamba config
* fixes in param names
* enable hybrid model
* update config
* fix config granite hybrid
* fix attention layer
* cleanup to re-use mamba code
* keep layer types
* attention bias cleanup
* update mamba layer name
* first pass at tests
* first pass at tests
* use granite attention
* fix: self attn weights
* pass at making pos_emb optional
* initialize self_attn only as needed
* overwrite forward to create HybridMambaCache
* Log invalid layer types
* Add attention outputs test
* Only emit attentions/logits if not None
* Fix config test hidden size divisibility
* mark granitmoehybrid as stateful
* Initialize mamba convolutional layers
* Formatting fixes
* config docstring, removed some unused attrs
* Fix missing arg in models test
* Fix create and check decoder model test
* support logits to keep in granitemoe
* regen to pass logits_to_keep
* Allow None or rope
* Fix gradient checkpointing
* Add granitemoehybrid as special cache for generate check
* Remove unused MLA refs
* Fix mamba layer mask
* Remove logits to keep from config
* Minor docstring nits
* Update licenses
* Enable cache by default
* map layer types to layer block type
* First pass at granite moe hybrid docs
* Ignore granite moe hybrid in valid checkpoint check
* Align attention interfaces
* regenerate modular granitemoeshared attention interface
* Align granite moe hybrid attn interface
* run formatting
* Handle mamba initialization
* avoid conditional attr defs
* Move hybrid layer validation to config
* Add placeholder integration tests
* Docs nits / Update model names
* Clean up forward conditions
* Use gradient checkpointing layer
* Remove some copied bamba tests + inherit: align test init, delete more tests, use common layer init with bamba tests, finish test consolidation
* avoid redundant intermediate std var
* use @can_return_tuple
* Remove unused moe state
* make skipped test names consistent
* Fix docstring order
* Add missing toc
* Always create the shared mlp
* Fix name in docstring
* link preview model in docs

---------
Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>
Co-authored-by: Alex-Brooks <Alex.Brooks@ibm.com>
What does this PR do?
This PR adds support for the upcoming Granite 4.0 models. In terms of model architecture, it is a hybrid class with a shared MLP layer and Bamba layers. A configuration sketch follows.
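A hypothetical illustration of that hybrid layout: a config interleaving mamba and attention blocks. The attribute name and defaults below are assumptions; check the merged GraniteMoeHybridConfig for the exact spelling.

```python
from transformers import GraniteMoeHybridConfig

# Four layers, mostly mamba with one attention block mixed in.
config = GraniteMoeHybridConfig(
    num_hidden_layers=4,
    layer_types=["mamba", "mamba", "attention", "mamba"],
)
```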
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.