Mistral-related models for QnA#34045
Conversation
# Copied from transformers.models.llama.modeling_llama.LlamaForQuestionAnswering.__init__ with Llama->Mistral,transformer->model
def __init__(self, config):
    super().__init__(config)
    self.model = MistralModel(config)
    self.qa_outputs = nn.Linear(config.hidden_size, 2)

    # Initialize weights and apply final processing
    self.post_init()

# Copied from transformers.models.llama.modeling_llama.LlamaForQuestionAnswering.get_input_embeddings with transformer->model
def get_input_embeddings(self):
    return self.model.embed_tokens

# Copied from transformers.models.llama.modeling_llama.LlamaForQuestionAnswering.set_input_embeddings with transformer->model
def set_input_embeddings(self, value):
    self.model.embed_tokens = value

@add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
# Copied from transformers.models.llama.modeling_llama.LlamaForQuestionAnswering.forward with Llama->Mistral, transformer->model
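For readers unfamiliar with the head above: nn.Linear(config.hidden_size, 2) produces one (start, end) score pair per token, and the forward pass then splits that last dimension into start and end logits. A minimal plain-Python sketch of that idea (names and shapes are illustrative, this is not the actual transformers code):

```python
def qa_head(hidden_states, weight, bias):
    """Toy stand-in for nn.Linear(hidden_size, 2) followed by the logits split.

    hidden_states: list of per-token hidden vectors
    weight: 2 rows (one per output) of length hidden_size
    bias: length-2 list
    """
    # Project each token's hidden state to a (start_score, end_score) pair.
    logits = [
        [sum(w * h for w, h in zip(row, hidden)) + b for row, b in zip(weight, bias)]
        for hidden in hidden_states
    ]
    # The QA model splits the last dimension into two per-token sequences.
    start_logits = [token[0] for token in logits]
    end_logits = [token[1] for token in logits]
    return start_logits, end_logits
```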
So it's more of a stylistic choice: individual copies vs. including the (unnecessary) base model prefix.
If #34061 gets merged, we can top-level copy from llama without any problems.
You can also use # Ignore copy on the single place where the copy does not match!
Ah ok, perfect. I'll change it later and ping you when ready ;)
ArthurZucker left a comment
LGTM in general! Would be nice to have a single # Copied from at the top of the class (either use # Ignore copy or just don't copy from llama for one of them!)
@ArthurZucker Changed it to top-level copied from now. Lmk if I should change something else.
)
# Copied from transformers.models.llama.modeling_llama.LlamaForQuestionAnswering with Llama->Mistral,LLAMA->MISTRAL,transformer->model
class MistralForQuestionAnswering(MistralPreTrainedModel):
    base_model_prefix = "model"
base_model_prefix = "model" is due to the llama stuff I mentioned; otherwise the classes have different structures and copied from will fail with an error.
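To illustrate why base_model_prefix matters (a toy sketch under my own assumptions, not the actual from_pretrained logic): the prefix names the attribute that holds the backbone, so checkpoint keys saved under one attribute name (e.g. transformer in the llama QA class) can be remapped when another class stores the backbone under a different one (model here). A hypothetical helper:

```python
def remap_checkpoint_keys(keys, old_prefix, new_prefix):
    """Rename keys under old_prefix to new_prefix.

    Head weights such as qa_outputs.* don't carry the prefix and pass
    through unchanged. Purely illustrative, not transformers internals.
    """
    return [
        new_prefix + key[len(old_prefix):] if key.startswith(old_prefix + ".") else key
        for key in keys
    ]
```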
That's it! Merging 🤗
* mistral qna start
* mixtral qna
* oops
* qwen2 qna
* qwen2moe qna
* add missing input embed methods
* add copied to all methods, can't directly from llama due to the prefix
* make top level copied from
What does this PR do?
Adds question answering to mistral, mixtral, qwen2, and qwen2moe. Because of the copied-from statements, we either touch every one of these models or need to ignore them in the copy checks. Based on #29168 but using copied from instead.
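For context on what these heads get used for downstream (a hedged sketch, not part of this PR): extractive QA picks the answer span by combining the per-token start and end logits, typically taking the highest-scoring valid pair, e.g.:

```python
def best_span(start_logits, end_logits, max_answer_len=15):
    """Return (start, end) token indices maximizing start_logit + end_logit.

    Only spans with end >= start and a bounded length are considered.
    Illustrative only; real pipelines also mask out question/special tokens.
    """
    best, best_score = (0, 0), float("-inf")
    for s, s_logit in enumerate(start_logits):
        for e in range(s, min(s + max_answer_len, len(end_logits))):
            score = s_logit + end_logits[e]
            if score > best_score:
                best_score, best = score, (s, e)
    return best
```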
Motivation: We have a benchmark paper at https://github.com/LSX-UniWue/SuperGLEBer which uses the transformers QnA models for simplicity, but since question answering is not available on main, it is manually patched in. Would be great to see it get into main!
Fixes #28908
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@LysandreJik @ArthurZucker