Fix mamba regression#39728
Conversation
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
|
run-slow: falcon_mamba |
|
This comment contains run-slow, running the specified jobs: models: ['models/falcon_mamba'] |
|
[For maintainers] Suggested jobs to run (before merge) run-slow: falcon_mamba, mamba |
|
run-slow: falcon_mamba, mamba |
|
This comment contains run-slow, running the specified jobs: models: ['models/falcon_mamba', 'models/mamba'] |
| if is_mambapy_available(): | ||
| from mambapy.pscan import pscan | ||
| else: | ||
| pscan = None |
There was a problem hiding this comment.
does this model require mambapy.pscan or it could work with pscan = None too?
There was a problem hiding this comment.
it uses pscan when the fast path (mamby library) is available. Otherwise it defaults to a slow (python code) forward pass.
The import was at the top before, I moved it for debugging and it slipped in 😅 , now I am reverting the change.
There was a problem hiding this comment.
sure.
(FYI, so for the test of this model, we are testing the slow path I think)
| # This is needed since mamba overrides the intermediate_size attribute | ||
| self.intermediate_size = ( | ||
| int(expand * self.hidden_size) | ||
| if kwargs.get("intermediate_size") is None | ||
| else kwargs.get("intermediate_size") | ||
| ) |
There was a problem hiding this comment.
Could you explain this part a bit more for me 🙏 , but I believe you are right.
There was a problem hiding this comment.
When using modular_converter, super.init() unravels MambaConfig.init which sets intermediate_size to int(expand * self.hidden_size), overriding any value passed via kwargs.
Before #38086, setting intermediate_size to int(expand * self.hidden_size) wasn't an issue because PretrainedConfig.__init__() was called last, and the kwargs value prevailed.
However, #38086 reversed that order due to modular_converter, which forces PretrainedConfig.__init__() to run first, thus overwriting the kwargsintermediate_size.
The new code explicitly assigns intermediate_size to ensure the kwargs value takes precedence again.
There was a problem hiding this comment.
OK, I understand better. And here, since intermediate_size is stored in tiiuae/falcon-mamba-7b, during loading, it is passed as kwargs, and causing the issue.
Looks like this (issue) is something that would happen quite frequently and we have to be careful (as modular_converter force it as you mentioned)
There was a problem hiding this comment.
One final nit question: do you know why we have a config (i.e. tiiuae/falcon-mamba-7b)
that have intermediate_size != int(expand * self.hidden_size)
sounds a bit strange
There was a problem hiding this comment.
Yeah, we have to be careful. In general, it is counterintuitive to have values hardcoded in the config initialization. I think kwargs should always take precedence there.
So to me, the good fix would be to change MambaConfig. As for the question, it is just their design choice. hidden size was very big, and they just chose to make the intermediates and the convs smaller.
ydshieh
left a comment
There was a problem hiding this comment.
OK for me from the explanation, but better for a second review.
* fix mamba regression * fix compile test
* fix mamba regression * fix compile test
* fix mamba regression * fix compile test
* fix mamba regression * fix compile test
* fix mamba regression * fix compile test
* fix mamba regression * fix compile test
* fix mamba regression * fix compile test
* fix mamba regression * fix compile test
This fixes the sneaky regression introduced in #38086 causing loading errors for falcon_mamba:
The gist of the problem is: modular forces the
super().__init__call to be on top of FalconMambaConfig. However, before modular rewrite, it was at the bottom, which was critical for theintermediate_sizeproperty from the config file to take effect.Second bug fixed:
tests/models/mamba/test_modeling_mamba.py::MambaIntegrationTests::test_compile_mamba_cachewas failing due to a missplaced import.