diff --git a/lax/src/eig.rs b/lax/src/eig.rs index 53245de7..a6d6a16e 100644 --- a/lax/src/eig.rs +++ b/lax/src/eig.rs @@ -23,8 +23,16 @@ macro_rules! impl_eig_complex { mut a: &mut [Self], ) -> Result<(Vec, Vec)> { let (n, _) = l.size(); - // Because LAPACK assumes F-continious array, C-continious array should be taken Hermitian conjugate. - // However, we utilize a fact that left eigenvector of A^H corresponds to the right eigenvector of A + // LAPACK assumes a column-major input. A row-major input can + // be interpreted as the transpose of a column-major input. So, + // for row-major inputs, we we want to solve the following, + // given the column-major input `A`: + // + // A^T V = V Λ ⟺ V^T A = Λ V^T ⟺ conj(V)^H A = Λ conj(V)^H + // + // So, in this case, the right eigenvectors are the conjugates + // of the left eigenvectors computed with `A`, and the + // eigenvalues are the eigenvalues computed with `A`. let (jobvl, jobvr) = if calc_v { match l { MatrixLayout::C { .. } => (b'V', b'N'), @@ -118,8 +126,22 @@ macro_rules! impl_eig_real { mut a: &mut [Self], ) -> Result<(Vec, Vec)> { let (n, _) = l.size(); - // Because LAPACK assumes F-continious array, C-continious array should be taken Hermitian conjugate. - // However, we utilize a fact that left eigenvector of A^H corresponds to the right eigenvector of A + // LAPACK assumes a column-major input. A row-major input can + // be interpreted as the transpose of a column-major input. So, + // for row-major inputs, we we want to solve the following, + // given the column-major input `A`: + // + // A^T V = V Λ ⟺ V^T A = Λ V^T ⟺ conj(V)^H A = Λ conj(V)^H + // + // So, in this case, the right eigenvectors are the conjugates + // of the left eigenvectors computed with `A`, and the + // eigenvalues are the eigenvalues computed with `A`. + // + // We could conjugate the eigenvalues instead of the + // eigenvectors, but we have to reconstruct the eigenvectors + // into new matrices anyway, and by not modifying the + // eigenvalues, we preserve the nice ordering specified by + // `sgeev`/`dgeev`. let (jobvl, jobvr) = if calc_v { match l { MatrixLayout::C { .. } => (b'V', b'N'), @@ -211,40 +233,34 @@ macro_rules! impl_eig_real { // - v(j) = VR(:,j) + i*VR(:,j+1) // - v(j+1) = VR(:,j) - i*VR(:,j+1). // - // ``` - // j -> <----pair----> <----pair----> - // [ ... (real), (imag), (imag), (imag), (imag), ... ] : eigs - // ^ ^ ^ ^ ^ - // false false true false true : is_conjugate_pair - // ``` + // In the C-layout case, we need the conjugates of the left + // eigenvectors, so the signs should be reversed. + let n = n as usize; let v = vr.or(vl).unwrap(); let mut eigvecs = unsafe { vec_uninit(n * n) }; - let mut is_conjugate_pair = false; // flag for check `j` is complex conjugate - for j in 0..n { - if eig_im[j] == 0.0 { - // j-th eigenvalue is real - for i in 0..n { - eigvecs[i + j * n] = Self::complex(v[i + j * n], 0.0); + let mut col = 0; + while col < n { + if eig_im[col] == 0. { + // The corresponding eigenvalue is real. + for row in 0..n { + let re = v[row + col * n]; + eigvecs[row + col * n] = Self::complex(re, 0.); } + col += 1; } else { - // j-th eigenvalue is complex - // complex conjugated pair can be `j-1` or `j+1` - if is_conjugate_pair { - let j_pair = j - 1; - assert!(j_pair < n); - for i in 0..n { - eigvecs[i + j * n] = Self::complex(v[i + j_pair * n], v[i + j * n]); - } - } else { - let j_pair = j + 1; - assert!(j_pair < n); - for i in 0..n { - eigvecs[i + j * n] = - Self::complex(v[i + j * n], -v[i + j_pair * n]); + // This is a complex conjugate pair. + assert!(col + 1 < n); + for row in 0..n { + let re = v[row + col * n]; + let mut im = v[row + (col + 1) * n]; + if jobvl == b'V' { + im = -im; } + eigvecs[row + col * n] = Self::complex(re, im); + eigvecs[row + (col + 1) * n] = Self::complex(re, -im); } - is_conjugate_pair = !is_conjugate_pair; + col += 2; } } diff --git a/ndarray-linalg/tests/eig.rs b/ndarray-linalg/tests/eig.rs index 28314b8a..d8144e68 100644 --- a/ndarray-linalg/tests/eig.rs +++ b/ndarray-linalg/tests/eig.rs @@ -1,6 +1,19 @@ use ndarray::*; use ndarray_linalg::*; +fn sorted_eigvals(eigvals: ArrayView1<'_, T>) -> Array1 { + let mut indices: Vec = (0..eigvals.len()).collect(); + indices.sort_by(|&ind1, &ind2| { + let e1 = eigvals[ind1]; + let e2 = eigvals[ind2]; + e1.re() + .partial_cmp(&e2.re()) + .unwrap() + .then(e1.im().partial_cmp(&e2.im()).unwrap()) + }); + indices.iter().map(|&ind| eigvals[ind]).collect() +} + // Test Av_i = e_i v_i for i = 0..n fn test_eig(a: Array2, eigs: Array1, vecs: Array2) where @@ -87,7 +100,10 @@ fn test_matrix_real() -> Array2 { } fn test_matrix_real_t() -> Array2 { - test_matrix_real::().t().permuted_axes([1, 0]).to_owned() + let orig = test_matrix_real::(); + let mut out = Array2::zeros(orig.raw_dim().f()); + out.assign(&orig); + out } fn answer_eig_real() -> Array1 { @@ -154,10 +170,10 @@ fn test_matrix_complex() -> Array2 { } fn test_matrix_complex_t() -> Array2 { - test_matrix_complex::() - .t() - .permuted_axes([1, 0]) - .to_owned() + let orig = test_matrix_complex::(); + let mut out = Array2::zeros(orig.raw_dim().f()); + out.assign(&orig); + out } fn answer_eig_complex() -> Array1 { @@ -213,7 +229,11 @@ macro_rules! impl_test_real { fn [<$real _eigvals_t>]() { let a = test_matrix_real_t::<$real>(); let (e, _vecs) = a.eig().unwrap(); - assert_close_l2!(&e, &answer_eig_real::<$real>(), 1.0e-3); + assert_close_l2!( + &sorted_eigvals(e.view()), + &sorted_eigvals(answer_eig_real::<$real>().view()), + 1.0e-3 + ); } #[test]