[bnb] Fix bnb skip modules #24043
Conversation
)
self.assertTrue(isinstance(seq_classification_model.classifier.dense, nn.Linear))
self.assertTrue(isinstance(seq_classification_model.classifier.out_proj, nn.Linear))
We should also check that at least one other layer not in llm_int8_skip_modules is loaded in 8-bit. Ideally one that effectively exercises the recursion logic.
Awesome yes agreed! Will add that now
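For reference, a minimal sketch of what that extra check could look like (the exact layer path and the Linear8bitLt import are assumptions on my side, not part of this PR's diff):

```python
import torch
from bitsandbytes.nn import Linear8bitLt

# A layer that is NOT in llm_int8_skip_modules and sits deep inside the
# encoder, so reaching it exercises the recursive replacement logic.
inner_dense = seq_classification_model.roberta.encoder.layer[0].output.dense
self.assertTrue(isinstance(inner_dense, Linear8bitLt))
self.assertEqual(inner_dense.weight.dtype, torch.int8)
```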
seq_classification_model = AutoModelForSequenceClassification.from_pretrained(
    "roberta-large-mnli", quantization_config=quantization_config
)
self.assertTrue(isinstance(seq_classification_model.classifier.dense, nn.Linear))
Just for my own understanding (not a comment to address): here we're checking that the classifier layers are nn.Linear. In test_linear_are_8bit, we check that the layers are nn.Linear too and that their dtype is torch.int8 (I didn't know this was possible!). Are we certain that this means these layers are loaded correctly? Do we need a dtype check on the weights?
You are right, we also need a dtype check on the weights! Linear8bitLt has nn.Linear as a superclass, so the isinstance check alone is not enough. Adding new tests!
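A rough sketch of the kind of dtype check being discussed (assuming the skipped classifier weights stay in floating point rather than torch.int8):

```python
import torch

# Because Linear8bitLt subclasses nn.Linear, isinstance alone cannot
# distinguish a skipped module from a quantized one; the weight dtype can.
self.assertNotEqual(seq_classification_model.classifier.dense.weight.dtype, torch.int8)
self.assertNotEqual(seq_classification_model.classifier.out_proj.weight.dtype, torch.int8)
```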
* fix skip modules test
* oops
* address comments
What does this PR do?
Fixes #24037
#23479 mistakenly removed the logic introduced in #21579 for handling modules that should not be converted.
The PR also adds a test to make sure this never happens again.
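For context, a minimal usage sketch of the skip-modules behaviour this PR restores (the module name in llm_int8_skip_modules is illustrative; pick whichever submodules should stay un-quantized for your model):

```python
from transformers import AutoModelForSequenceClassification, BitsAndBytesConfig

# Load the model in 8-bit but keep the classification head in full precision.
quantization_config = BitsAndBytesConfig(
    load_in_8bit=True,
    llm_int8_skip_modules=["classifier"],
)
model = AutoModelForSequenceClassification.from_pretrained(
    "roberta-large-mnli",
    quantization_config=quantization_config,
)
```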