From f775f404dacd60394b6bc248c53073b967231f4a Mon Sep 17 00:00:00 2001 From: Alexis Montoison Date: Mon, 2 Dec 2024 15:18:03 -0600 Subject: [PATCH 1/3] [oneMKL] Fix gesvd! --- lib/mkl/wrappers_lapack.jl | 17 +++++++++-------- test/onemkl.jl | 10 ++++++++++ 2 files changed, 19 insertions(+), 8 deletions(-) diff --git a/lib/mkl/wrappers_lapack.jl b/lib/mkl/wrappers_lapack.jl index f6a64d0b..a47060bb 100644 --- a/lib/mkl/wrappers_lapack.jl +++ b/lib/mkl/wrappers_lapack.jl @@ -304,30 +304,31 @@ for (bname, fname, elty, relty) in ((:onemklSgesvd_scratchpad_size, :onemklSgesv jobvt::Char, A::oneStridedMatrix{$elty}) m, n = size(A) + k = min(m, n) lda = max(1, stride(A, 2)) U = if jobu === 'A' oneMatrix{$elty}(undef, m, m) - elseif jobu == 'S' || jobu === 'O' - oneMatrix{$elty}(undef, m, min(m, n)) - elseif jobu === 'N' + elseif jobu === 'S' + oneMatrix{$elty}(undef, m, k) + elseif jobu === 'N' || jobu === 'O' oneMatrix{$elty}(undef, 0, 0) # Equivalence of CU_NULL? else error("jobu must be one of 'A', 'S', 'O', or 'N'") end ldu = U == oneMatrix{$elty}(undef, 0, 0) ? 1 : max(1, stride(U, 2)) - S = oneVector{$relty}(undef, min(m, n)) + S = oneVector{$relty}(undef, k) Vt = if jobvt === 'A' oneMatrix{$elty}(undef, n, n) - elseif jobvt === 'S' || jobvt === 'O' - oneMatrix{$elty}(undef, min(m, n), n) - elseif jobvt === 'N' + elseif jobvt === 'S' + oneMatrix{$elty}(undef, k, n) + elseif jobvt === 'N' || jobvt === 'O' oneMatrix{$elty}(undef, 0, 0) else error("jobvt must be one of 'A', 'S', 'O', or 'N'") end - ldvt = Vt == oneArray{$elty}(undef, 0, 0) ? 1 : max(1, stride(Vt, 2)) + ldvt = Vt == oneMatrix{$elty}(undef, 0, 0) ? 1 : max(1, stride(Vt, 2)) queue = global_queue(context(A), device()) scratchpad_size = $bname(sycl_queue(queue), jobu, jobvt, m, n, lda, ldu, ldvt) diff --git a/test/onemkl.jl b/test/onemkl.jl index 19f06926..e0c20918 100644 --- a/test/onemkl.jl +++ b/test/onemkl.jl @@ -1421,6 +1421,16 @@ end d_A = oneMatrix(A) U, Σ, Vt = oneMKL.gesvd!('A', 'A', d_A) @test A ≈ collect(U[:,1:n] * Diagonal(Σ) * Vt) + + for jobu in ('A', 'S', 'N', 'O') + for jobvt in ('A', 'S', 'N', 'O') + (jobu == 'A') && (jobvt == 'A') && continue + (jobu == 'O') && (jobvt == 'O') && continue + d_A = CuMatrix(A) + U2, Σ2, Vt2 = oneMKL.gesvd!(jobu, jobvt, d_A) + @test Σ ≈ Σ2 + end + end end @testset "syevd! -- heevd!" begin From 165a0bec8020c2562471b4c12de2c62f865a9f3d Mon Sep 17 00:00:00 2001 From: Alexis Montoison Date: Mon, 2 Dec 2024 15:40:31 -0600 Subject: [PATCH 2/3] Fix a typo --- test/onemkl.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/onemkl.jl b/test/onemkl.jl index e0c20918..50651584 100644 --- a/test/onemkl.jl +++ b/test/onemkl.jl @@ -1426,7 +1426,7 @@ end for jobvt in ('A', 'S', 'N', 'O') (jobu == 'A') && (jobvt == 'A') && continue (jobu == 'O') && (jobvt == 'O') && continue - d_A = CuMatrix(A) + d_A = oneMatrix(A) U2, Σ2, Vt2 = oneMKL.gesvd!(jobu, jobvt, d_A) @test Σ ≈ Σ2 end From 9dd54b7c4d21ffa4bf9171bbfac79c4ef4e2dc85 Mon Sep 17 00:00:00 2001 From: Alexis Montoison Date: Mon, 2 Dec 2024 21:17:29 -0600 Subject: [PATCH 3/3] Use ZE_NULL --- lib/mkl/wrappers_lapack.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/lib/mkl/wrappers_lapack.jl b/lib/mkl/wrappers_lapack.jl index a47060bb..177c2510 100644 --- a/lib/mkl/wrappers_lapack.jl +++ b/lib/mkl/wrappers_lapack.jl @@ -312,11 +312,11 @@ for (bname, fname, elty, relty) in ((:onemklSgesvd_scratchpad_size, :onemklSgesv elseif jobu === 'S' oneMatrix{$elty}(undef, m, k) elseif jobu === 'N' || jobu === 'O' - oneMatrix{$elty}(undef, 0, 0) # Equivalence of CU_NULL? + ZE_NULL else error("jobu must be one of 'A', 'S', 'O', or 'N'") end - ldu = U == oneMatrix{$elty}(undef, 0, 0) ? 1 : max(1, stride(U, 2)) + ldu = U == ZE_NULL ? 1 : max(1, stride(U, 2)) S = oneVector{$relty}(undef, k) Vt = if jobvt === 'A' @@ -324,11 +324,11 @@ for (bname, fname, elty, relty) in ((:onemklSgesvd_scratchpad_size, :onemklSgesv elseif jobvt === 'S' oneMatrix{$elty}(undef, k, n) elseif jobvt === 'N' || jobvt === 'O' - oneMatrix{$elty}(undef, 0, 0) + ZE_NULL else error("jobvt must be one of 'A', 'S', 'O', or 'N'") end - ldvt = Vt == oneMatrix{$elty}(undef, 0, 0) ? 1 : max(1, stride(Vt, 2)) + ldvt = Vt == ZE_NULL ? 1 : max(1, stride(Vt, 2)) queue = global_queue(context(A), device()) scratchpad_size = $bname(sycl_queue(queue), jobu, jobvt, m, n, lda, ldu, ldvt)