diff --git a/test/mooncake/eig.jl b/test/mooncake/eig.jl index 2e8a8606..7f0a792f 100644 --- a/test/mooncake/eig.jl +++ b/test/mooncake/eig.jl @@ -15,5 +15,7 @@ for T in (BLASFloats..., GenericFloats...) TestSuite.seed_rng!(123) if !is_buildkite TestSuite.test_mooncake_eig(T, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T)) + AT = Diagonal{T, Vector{T}} + TestSuite.test_mooncake_eig(AT, m; atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) end end diff --git a/test/mooncake/eigh.jl b/test/mooncake/eigh.jl index 5528af0f..38f0b82e 100644 --- a/test/mooncake/eigh.jl +++ b/test/mooncake/eigh.jl @@ -15,5 +15,7 @@ for T in (BLASFloats..., GenericFloats...) TestSuite.seed_rng!(123) if !is_buildkite TestSuite.test_mooncake_eigh(T, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T)) + AT = Diagonal{T, Vector{T}} + TestSuite.test_mooncake_eigh(AT, m; atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) end end diff --git a/test/mooncake/lq.jl b/test/mooncake/lq.jl index 6c9f8fd4..f95a4832 100644 --- a/test/mooncake/lq.jl +++ b/test/mooncake/lq.jl @@ -15,5 +15,9 @@ for T in (BLASFloats..., GenericFloats...), n in (17, m, 23) TestSuite.seed_rng!(123) if !is_buildkite TestSuite.test_mooncake_lq(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) + if m == n + AT = Diagonal{T, Vector{T}} + TestSuite.test_mooncake_lq(AT, m; atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) + end end end diff --git a/test/mooncake/orthnull.jl b/test/mooncake/orthnull.jl index 6f8dac9a..5e7db0ac 100644 --- a/test/mooncake/orthnull.jl +++ b/test/mooncake/orthnull.jl @@ -15,5 +15,9 @@ for T in (BLASFloats..., GenericFloats...), n in (17, m, 23) TestSuite.seed_rng!(123) if !is_buildkite TestSuite.test_mooncake_orthnull(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) + if m == n + AT = Diagonal{T, Vector{T}} + TestSuite.test_mooncake_orthnull(AT, m; atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) + end end end diff --git a/test/mooncake/polar.jl b/test/mooncake/polar.jl index 9e4e366e..4386854d 100644 --- a/test/mooncake/polar.jl +++ b/test/mooncake/polar.jl @@ -17,5 +17,10 @@ for T in (BLASFloats..., GenericFloats...), n in (17, m, 23) atol = rtol = m * n * TestSuite.precision(T) m >= n && TestSuite.test_mooncake_left_polar(T, (m, n); atol, rtol) n >= m && TestSuite.test_mooncake_right_polar(T, (m, n); atol, rtol) + if m == n + AT = Diagonal{T, Vector{T}} + TestSuite.test_mooncake_left_polar(AT, m; atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) + TestSuite.test_mooncake_right_polar(AT, m; atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) + end end end diff --git a/test/mooncake/qr.jl b/test/mooncake/qr.jl index 9ffc4798..f18971e7 100644 --- a/test/mooncake/qr.jl +++ b/test/mooncake/qr.jl @@ -15,5 +15,9 @@ for T in (BLASFloats..., GenericFloats...), n in (17, m, 23) TestSuite.seed_rng!(123) if !is_buildkite TestSuite.test_mooncake_qr(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) + if m == n + AT = Diagonal{T, Vector{T}} + TestSuite.test_mooncake_qr(AT, m; atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) + end end end diff --git a/test/mooncake/svd.jl b/test/mooncake/svd.jl index 982ec040..22e55483 100644 --- a/test/mooncake/svd.jl +++ b/test/mooncake/svd.jl @@ -15,5 +15,9 @@ for T in (BLASFloats..., GenericFloats...), n in (17, m, 23) TestSuite.seed_rng!(123) if !is_buildkite TestSuite.test_mooncake_svd(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) + if m == n + AT = Diagonal{T, Vector{T}} + TestSuite.test_mooncake_svd(AT, m; atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) + end end end