diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index 6fd2570b8..e62d0c49d 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -461,11 +461,7 @@ def get_transform_buffer(
         state = (shape[::-1], to_order)
 
     if to_order == "row" or to_order == "col":
-        if HIP_ENVIRONMENT and to_order == "col":
-            # row to col transformation transposes output shape, so change buffer allocation accordingly
-            return init_func(shape[::-1], dtype=dtype, device=device), state
-        else:
-            return init_func(shape, dtype=dtype, device=device), state
+        return init_func(shape, dtype=dtype, device=device), state
     elif to_order == "col32":
         # blocks of 32 columns (padded)
         cols = 32 * ((cols + 31) // 32)
@@ -503,7 +499,7 @@ def nvidia_transform(
         from_order = state[1]
     if out is None:
         out, new_state = get_transform_buffer(
-            state[0], A.dtype, A.device, to_order, state[1]
+            state[0], A.dtype, A.device, to_order, state[1], transpose
         )
     else:
         new_state = (state[1], to_order)
diff --git a/csrc/ops.hip b/csrc/ops.hip
index 54743d111..cfb268dec 100644
--- a/csrc/ops.hip
+++ b/csrc/ops.hip
@@ -383,7 +383,12 @@ template <typename T, int SRC, int TARGET, bool transpose, int DTYPE> void trans
   hipblasLtOrder_t orderA = get_order<SRC>();
   hipblasLtOrder_t orderOut = get_order<TARGET>();
   int ldA = get_leading_dim<SRC>(dim1, dim2);
-  int ldOut = get_leading_dim<TARGET>(dim1, dim2);
+  int ldOut;
+  if (TARGET==COL && transpose) {
+    ldOut = dim2;
+  } else {
+    ldOut = get_leading_dim<TARGET>(dim1, dim2);
+  }
 
   hipblasLtMatrixLayout_t A_desc = NULL, out_desc = NULL, B_desc = NULL;
   T B = T(0);
@@ -395,13 +400,21 @@ template <typename T, int SRC, int TARGET, bool transpose, int DTYPE> void trans
   {
     checkHipblasStatus(hipblasLtMatrixLayoutCreate(&A_desc, HIP_R_8I, dim1, dim2, ldA));
     checkHipblasStatus(hipblasLtMatrixLayoutCreate(&B_desc, HIP_R_8I, 0, 0, 0));
-    checkHipblasStatus(hipblasLtMatrixLayoutCreate(&out_desc, HIP_R_8I, dim1, dim2, ldOut));
+    if (TARGET==COL && transpose) {
+      checkHipblasStatus(hipblasLtMatrixLayoutCreate(&out_desc, HIP_R_8I, dim2, dim1, ldOut));
+    } else {
+      checkHipblasStatus(hipblasLtMatrixLayoutCreate(&out_desc, HIP_R_8I, dim1, dim2, ldOut));
+    }
   }
   else if(DTYPE == 32)
   {
     checkHipblasStatus(hipblasLtMatrixLayoutCreate(&A_desc, HIP_R_32I, dim1, dim2, ldA));
     checkHipblasStatus(hipblasLtMatrixLayoutCreate(&B_desc, HIP_R_32I, 0, 0, 0));
-    checkHipblasStatus(hipblasLtMatrixLayoutCreate(&out_desc, HIP_R_32I, dim1, dim2, ldOut));
+    if (TARGET==COL && transpose) {
+      checkHipblasStatus(hipblasLtMatrixLayoutCreate(&out_desc, HIP_R_32I, dim2, dim1, ldOut));
+    } else {
+      checkHipblasStatus(hipblasLtMatrixLayoutCreate(&out_desc, HIP_R_32I, dim1, dim2, ldOut));
+    }
   }
   else
   {
@@ -424,6 +437,9 @@ template <typename T, int SRC, int TARGET, bool transpose, int DTYPE> void trans
 }
 
 template void transform<int8_t, ROW, COL, false, 8>(hipblasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
+template void transform<int8_t, ROW, COL, true, 8>(hipblasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
+template void transform<int32_t, ROW, COL, false, 32>(hipblasLtHandle_t ltHandle, int32_t *A, int32_t *out, int dim1, int dim2);
+template void transform<int32_t, ROW, COL, true, 32>(hipblasLtHandle_t ltHandle, int32_t *A, int32_t *out, int dim1, int dim2);
 template void transform<int8_t, ROW, ROW, false, 8>(hipblasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
 template void transform<int8_t, ROW, COL32, false, 8>(hipblasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
 template void transform<int32_t, ROW, COL32, false, 32>(hipblasLtHandle_t ltHandle, int32_t *A, int32_t *out, int dim1, int dim2);
diff --git a/csrc/pythonInterface.c b/csrc/pythonInterface.c
index ba551dcc3..b7fdf113e 100644
--- a/csrc/pythonInterface.c
+++ b/csrc/pythonInterface.c
@@ -158,6 +158,9 @@ void transform_##fbits##_##fsrc##_to_##ftrgt##_##ftranspose(hipblasLtHandle_t lt
 #endif
 
 MAKE_FUNC_TRANSFORM(8, row, col, n, int8_t, ROW, COL, false, 8);
+MAKE_FUNC_TRANSFORM(8, row, col, t, int8_t, ROW, COL, true, 8);
+MAKE_FUNC_TRANSFORM(32, row, col, n, int32_t, ROW, COL, false, 32);
+MAKE_FUNC_TRANSFORM(32, row, col, t, int32_t, ROW, COL, true, 32);
 MAKE_FUNC_TRANSFORM(8, row, row, n, int8_t, ROW, ROW, false, 8);
 MAKE_FUNC_TRANSFORM(8, row, col32, n, int8_t, ROW, COL32, false, 8);
 MAKE_FUNC_TRANSFORM(32, row, col32, n, int32_t, ROW, COL32, false, 32);
@@ -406,6 +409,9 @@ extern "C"
 
 	MAKE_FUNC_CTRANSFORM(8, row, col, n, int8_t, ROW, COL, false, 8)
+	MAKE_FUNC_CTRANSFORM(8, row, col, t, int8_t, ROW, COL, true, 8)
+	MAKE_FUNC_CTRANSFORM(32, row, col, n, int32_t, ROW, COL, false, 32)
+	MAKE_FUNC_CTRANSFORM(32, row, col, t, int32_t, ROW, COL, true, 32)
 	MAKE_FUNC_CTRANSFORM(8, row, row, n, int8_t, ROW, ROW, false, 8)
 	MAKE_FUNC_CTRANSFORM(8, row, col32, n, int8_t, ROW, COL32, false, 8)
 	MAKE_FUNC_CTRANSFORM(32, row, col32, n, int32_t, ROW, COL32, false, 32)
diff --git a/tests/test_functional.py b/tests/test_functional.py
index 8d04a650f..f914820fe 100644
--- a/tests/test_functional.py
+++ b/tests/test_functional.py
@@ -719,19 +719,16 @@ def test_igemmlt_int(dim1, dim2, dim3, dim4, dims, ldb):
         C3, S = F.nvidia_transform(C2, "row", state=SC)
         torch.testing.assert_close(C1, C3.float())
 
-        # Since ROCm supports row to col transformation only which is same as transpose,
-        # skipping this for HIP environment
-        if not HIP_ENVIRONMENT:
-            ## transpose
-            B = torch.randint(-128, 127, size=(dim3, dim4), device="cuda").to(
-                torch.int8
-            )
-            C1 = torch.matmul(A.float(), B.float())
+        ## transpose
+        B = torch.randint(-128, 127, size=(dim3, dim4), device="cuda").to(
+            torch.int8
+        )
+        C1 = torch.matmul(A.float(), B.float())
 
-            B2t, SBt = F.transform(B, "col_turing", transpose=True)
-            C2, SC = F.igemmlt(A2, B2t, SA, SBt)
-            C3, S = F.nvidia_transform(C2, "row", state=SC)
-            torch.testing.assert_close(C1, C3.float())
+        B2t, SBt = F.transform(B, "col_turing", transpose=True)
+        C2, SC = F.igemmlt(A2, B2t, SA, SBt)
+        C3, S = F.nvidia_transform(C2, "row", state=SC)
+        torch.testing.assert_close(C1, C3.float())
 
 
 dim1 = [32]
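
Note, not part of the patch: a minimal sketch of how the transpose path
restored above is driven from Python. It mirrors the re-enabled block in
tests/test_functional.py::test_igemmlt_int; the shapes are illustrative only,
and it assumes a GPU build of bitsandbytes (ROCm or CUDA) with a device
visible as "cuda".

    import torch
    import bitsandbytes.functional as F

    dim1, dim3, dim4 = 32, 64, 48
    A = torch.randint(-128, 127, size=(dim1, dim3), device="cuda").to(torch.int8)
    B = torch.randint(-128, 127, size=(dim3, dim4), device="cuda").to(torch.int8)

    # A goes to col32; B goes row -> col(_turing) with transpose=True, the
    # case ops.hip now handles by swapping dim1/dim2 in out_desc and ldOut.
    A2, SA = F.transform(A, "col32")
    B2t, SBt = F.transform(B, "col_turing", transpose=True)

    # int8 GEMM on the transformed operands, then back to row-major.
    C2, SC = F.igemmlt(A2, B2t, SA, SBt)
    C3, _ = F.nvidia_transform(C2, "row", state=SC)

    torch.testing.assert_close(torch.matmul(A.float(), B.float()), C3.float())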