Starcoder2 model #29120
Conversation
return self.weight * hidden_states.to(input_dtype)

# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Starcoder2
`make fix-copies` will not let this pass; this should be copied from Mistral, as we changed Llama for the compiled static cache.
I would also rather we support the static cache, as the API got quite a lot cleaner.
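For readers outside the repo: the `# Copied from` header tells `make fix-copies` which upstream implementation a class must stay identical with (modulo the name substitution). A minimal sketch of the retargeted annotation being asked for here, assuming the rotary embedding really can track Mistral (class body elided):

```python
from torch import nn


# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Starcoder2
class Starcoder2RotaryEmbedding(nn.Module):
    # The body must match MistralRotaryEmbedding exactly (with the Mistral->Starcoder2
    # rename applied); otherwise `make fix-copies` will flag or rewrite it.
    ...
```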
return torch.cat((-x2, x1), dim=-1)

# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
Same here, Llama is different; `make fix-copies` will help you fix this!
return hidden_states

class Starcoder2GatedMLP(nn.Module):
Probably missing a `# Copied from` mention here (Mistral).
It has small changes (bias + dropout I think)
Should we remove the `# Copied from` mention from all the classes/methods where we added dropout?
Yes, otherwise check-copies will fail 😉
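To make that concrete, here is a hedged sketch (attribute and config names are illustrative, not necessarily the PR's) of why the header has to go: once bias and dropout are added, the body no longer matches `MistralMLP`, so keeping a `# Copied from ...MistralMLP` line would make check-copies flag or revert the divergence.

```python
import torch
from torch import nn


# Note the absence of a "# Copied from ...MistralMLP" header: the bias and dropout
# below make the body diverge from Mistral's copy-checked source.
class GatedMLPWithDropoutSketch(nn.Module):
    def __init__(self, hidden_size: int, intermediate_size: int, residual_dropout: float):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=True)  # bias differs from Mistral
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=True)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=True)
        self.act_fn = nn.SiLU()
        self.dropout = nn.Dropout(residual_dropout)  # extra vs. Mistral

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.dropout(self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)))
```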
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
    return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
Suggested change: remove the `_shape` helper entirely, it is not used in Mistral anyway.
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)

class Starcoder2Attention(nn.Module):
It would make sense IMO to follow the Llama implementation for the static cache (with the additional cache positions), but this can go in another PR, no worries 🤗
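For context, a signature-level sketch (names assumed, not the PR's code) of what "the additional cache positions" means: the Llama static-cache path threads a `cache_position` tensor through attention so key/value states can be written into a pre-allocated cache at fixed indices.

```python
from typing import Optional

import torch
from torch import nn


class AttentionWithCachePositionSketch(nn.Module):
    """Sketch only: the relevant difference is the extra `cache_position` argument."""

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value=None,  # a transformers Cache (e.g. a static cache) in the real model
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,  # indices to write into the static cache
    ):
        # The Llama-style implementation uses cache_position both to update the
        # pre-allocated cache and to slice the causal mask, keeping shapes static
        # so the model stays torch.compile-friendly.
        raise NotImplementedError("signature sketch only")
```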
self.self_attn = STARCODER2_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)

self.mlp = STARCODER2_MLP_CLASSES[config.mlp_type](config)

self.input_layernorm = STARCODER2_NORMALIZATION_CLASSES[config.norm_type](
    config.hidden_size, eps=config.norm_epsilon
)
self.post_attention_layernorm = STARCODER2_NORMALIZATION_CLASSES[config.norm_type](
    config.hidden_size, eps=config.norm_epsilon
)
This is not what we usually do in transformers. The attention is a specific case 😅
- Are all of these used in the default Starcoder2?
- If not, then let's not support Mistral. Mistral is a different architecture.
The reason attention is allowed is that it uses the same parameters -> same "Attention" with different forwards, whereas here it's really a different architecture = against the transformers philosophy.
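A hedged sketch of the distinction being drawn, with placeholder bodies: an `*_ATTENTION_CLASSES` dict is accepted because every entry shares the same parameters and only the forward differs, while norm/MLP lookup dicts would switch between genuinely different architectures and should just be hard-coded to whatever Starcoder2 actually uses.

```python
from torch import nn


class Starcoder2Attention(nn.Module):
    """Placeholder for the eager attention (real body omitted)."""
    def __init__(self, config, layer_idx=None):
        super().__init__()


class Starcoder2FlashAttention2(Starcoder2Attention):
    """Same parameters as the eager class; only the forward differs."""


class Starcoder2SdpaAttention(Starcoder2Attention):
    """Same parameters as the eager class; only the forward differs."""


# Accepted pattern: one architecture, several forward implementations over identical weights.
STARCODER2_ATTENTION_CLASSES = {
    "eager": Starcoder2Attention,
    "flash_attention_2": Starcoder2FlashAttention2,
    "sdpa": Starcoder2SdpaAttention,
}


class DecoderLayerSketch(nn.Module):
    def __init__(self, config, layer_idx: int):
        super().__init__()
        self.self_attn = STARCODER2_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
        # Rather than STARCODER2_MLP_CLASSES / STARCODER2_NORMALIZATION_CLASSES lookups,
        # hard-code the variant the released checkpoints actually use:
        self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon)
        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon)
```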
if self._attn_implementation == "flash_attention_2":
    # 2d mask is passed through the layers
    attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
elif self._attn_implementation == "sdpa" and not output_attentions:
    # output_attentions=True can not be supported when using SDPA, and we fall back on
    # the manual implementation that requires a 4D causal mask in all cases.
    attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
        attention_mask,
        (batch_size, seq_length),
        inputs_embeds,
        past_key_values_length,
    )
else:
    # 4d mask is passed through the layers
    attention_mask = _prepare_4d_causal_attention_mask(
        attention_mask,
        (batch_size, seq_length),
        inputs_embeds,
        past_key_values_length,
        sliding_window=self.config.sliding_window,
    )
See the new Llama code for this, which was simplified. I'd rather we take it directly for the attention 😉
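Not the actual Llama code, but a rough sketch of the simplified idea being referenced: keep the 2D mask (or `None`) for flash_attention_2 and otherwise build a single additive 4D causal mask up front, instead of branching per backend with separate helpers.

```python
import torch


def build_4d_causal_mask_sketch(
    attention_mask: torch.Tensor,  # (batch, kv_len), 1 = attend, 0 = padding
    query_len: int,
    dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    batch, kv_len = attention_mask.shape
    min_value = torch.finfo(dtype).min
    # Causal part: query i may only attend to key positions up to its absolute position.
    causal = torch.full((query_len, kv_len), min_value, dtype=dtype)
    causal = torch.triu(causal, diagonal=kv_len - query_len + 1)
    mask = causal[None, None, :, :].expand(batch, 1, query_len, kv_len).clone()
    # Padding part: fully mask out padded key positions.
    return mask.masked_fill(attention_mask[:, None, None, :] == 0, min_value)
```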
@unittest.skip("Starcoder2 buffers include complex numbers, which breaks this test")
def test_save_load_fast_init_from_base(self):
    pass
I might have missed this, but I have not seen where these complex number buffers are?
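For anyone following along: a hypothetical illustration (not code from this PR) of the kind of buffer that skip message would be referring to, i.e. rotary frequencies cached as a complex tensor via `torch.polar`.

```python
import torch
from torch import nn


class RotaryWithComplexBufferSketch(nn.Module):
    def __init__(self, dim: int, max_positions: int = 2048, base: float = 10000.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        t = torch.arange(max_positions).float()
        freqs = torch.outer(t, inv_freq)
        # torch.polar yields a complex64 buffer; save/load fast-init paths that assume
        # real dtypes can break on buffers like this, hence a skip such as the one above.
        self.register_buffer("freqs_cis", torch.polar(torch.ones_like(freqs), freqs), persistent=False)
```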
I re-created a PR here since Joel is on vacation: #29215
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
Closing as #29215 was merged and Starcoder2 is officially supported.
The Starcoder2 model, adapted from Mistral. All changes are done through options, so Mistral itself is still supported. Main changes:
* Embedding and residual dropout
It does not support absolute embeddings, so can't support Santacoder or Starcoder
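A hedged sketch of what "all changes are done through options" means in practice; the field names below are illustrative, taken from the knobs visible in this thread, and are not necessarily the PR's exact config API.

```python
from dataclasses import dataclass


@dataclass
class Starcoder2ConfigSketch:
    hidden_size: int = 3072
    norm_epsilon: float = 1e-5
    embedding_dropout: float = 0.1   # new vs. Mistral
    residual_dropout: float = 0.1    # new vs. Mistral
    use_bias: bool = True            # Mistral linears are bias-free
    norm_type: str = "layer_norm"    # "rms_norm" would keep Mistral behaviour
    mlp_type: str = "default"        # "gated" would keep Mistral's gated MLP


# Mistral-equivalent values reproduce Mistral; flipping the options above yields Starcoder2.
config = Starcoder2ConfigSketch()
```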
Todo:
* [Core generation] Adds support for static KV cache #27931
* [CLeanup] Revert SDPA attention changes that got in the static kv cache PR #29027 (and future changes from Feb. 19)

@younesbelkada