diff --git a/tests/test_triton.py b/tests/test_triton.py index e18c7a930..8890193fc 100644 --- a/tests/test_triton.py +++ b/tests/test_triton.py @@ -4,7 +4,9 @@ from bitsandbytes.triton.triton_utils import is_triton_available from bitsandbytes.nn.triton_based_modules import SwitchBackLinear from bitsandbytes.nn import Linear8bitLt +from bitsandbytes.cextension import HIP_ENVIRONMENT +@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet") @pytest.mark.skipif(not is_triton_available() or not torch.cuda.is_available() or not torch.cuda.get_device_capability()[0] >= 8, reason="This test requires triton and a GPU with compute capability 8.0 or higher.") @pytest.mark.parametrize("vector_wise_quantization", [False, True])