Jamba: update integration tests#32250
Conversation
|
cc @ydshieh |
| # This variable is used to determine which CUDA device are we using for our runners (A10 or T4) | ||
| # Depending on the hardware we get different logits / generations | ||
| cuda_compute_capability_major_version = None |
There was a problem hiding this comment.
This cuda_compute_capability_major_version pattern is copied from other models like e.g. gemma
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
|
(thank you for trigger the tests on the runner 🙏 ) |
amyeroberts
left a comment
There was a problem hiding this comment.
Thanks for digging into this and fixing, and for writing a detailed PR description ❤️
Agreed it's not worth digging into given jamba usage, and as the generated texts appear similar despite the logic differences
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
| torch.testing.assert_close(logits[0, -1, :40].cpu(), EXPECTED_LOGITS_NO_GRAD_0, rtol=1e-3, atol=1e-3) | ||
| torch.testing.assert_close(logits[1, -1, :40].cpu(), EXPECTED_LOGITS_NO_GRAD_1, rtol=1e-3, atol=1e-3) | ||
| # TODO: there are significant differences in the logits across major cuda versions, which shouldn't exist | ||
| if self.cuda_compute_capability_major_version == 8: |
There was a problem hiding this comment.
maybe better to use
self.skipTest(reason="Skipping for T4 runners because ...")
There was a problem hiding this comment.
oops, merged before seeing this comment!
You have a good point, in fact we should split the test in two to test (/skip) the logits separately
* try test updates * a few more changes * a few more changes * a few more changes * [run slow] jamba * skip logits checks on older gpus * [run slow] jamba * oops * [run slow] jamba * Update tests/models/jamba/test_modeling_jamba.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update tests/models/jamba/test_modeling_jamba.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
What does this PR do?
🟢 Fixes
generate-related integration tests for jamba 🟢I've checked them against:
Detective work 🕵️
ai21labs/Jamba-tiny-random, the generation text quality doesn't matter.