1 #ifndef GEN_EIGS_SOLVER_H
2 #define GEN_EIGS_SOLVER_H
6 #include <Eigen/Eigenvalues>
16 #include "UpperHessenbergQR.h"
17 #include "MatOp/DenseGenMatProd.h"
18 #include "MatOp/DenseGenRealShiftSolve.h"
19 #include "MatOp/DenseGenComplexShiftSolve.h"
77 template <
typename Scalar = double,
83 typedef Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic> Matrix;
84 typedef Eigen::Matrix<Scalar, Eigen::Dynamic, 1> Vector;
85 typedef Eigen::Array<Scalar, Eigen::Dynamic, 1> Array;
86 typedef Eigen::Array<bool, Eigen::Dynamic, 1> BoolArray;
87 typedef Eigen::Map<const Matrix> MapMat;
88 typedef Eigen::Map<const Vector> MapVec;
90 typedef std::complex<Scalar> Complex;
91 typedef Eigen::Matrix<Complex, Eigen::Dynamic, Eigen::Dynamic> ComplexMatrix;
92 typedef Eigen::Matrix<Complex, Eigen::Dynamic, 1> ComplexVector;
94 typedef Eigen::EigenSolver<Matrix> EigenSolver;
95 typedef Eigen::HouseholderQR<Matrix> QRdecomp;
96 typedef Eigen::HouseholderSequence<Matrix, Vector> QRQ;
98 typedef std::pair<Complex, int> SortPair;
116 ComplexVector ritz_val;
117 ComplexMatrix ritz_vec;
128 void factorize_from(
int from_k,
int to_m,
const Vector &fk)
130 if(to_m <= from_k)
return;
134 Vector v(dim_n), w(dim_n);
137 fac_H.rightCols(ncv - from_k).setZero();
138 fac_H.block(from_k, 0, ncv - from_k, from_k).setZero();
139 for(
int i = from_k; i <= to_m - 1; i++)
142 v.noalias() = fac_f / beta;
144 fac_H.block(i, 0, 1, i).setZero();
145 fac_H(i, i - 1) = beta;
147 op->perform_op(v.data(), w.data());
150 Vector h = fac_V.leftCols(i + 1).transpose() * w;
151 fac_H.block(0, i, i + 1, 1) = h;
153 fac_f = w - fac_V.leftCols(i + 1) * h;
158 Scalar v1f = fac_f.dot(fac_V.col(0));
159 if(v1f > prec || v1f < -prec)
162 Vf.tail(i) = fac_V.block(0, 1, dim_n, i).transpose() * fac_f;
164 fac_f -= fac_V.leftCols(i + 1) * Vf;
169 static bool is_complex(Complex v, Scalar eps)
171 return std::abs(v.imag()) > eps;
174 static bool is_conj(Complex v1, Complex v2, Scalar eps)
176 return std::abs(v1 - std::conj(v2)) < eps;
191 for(
int i = k; i < ncv; i++)
193 if(is_complex(ritz_val[i], prec) && is_conj(ritz_val[i], ritz_val[i + 1], prec))
201 Scalar re = ritz_val[i].real();
202 Scalar s = std::norm(ritz_val[i]);
204 HH.diagonal().array() -= 2 * re;
206 HH.diagonal().array() += s;
209 decomp_gen.compute(HH);
210 QRQ Q = decomp_gen.householderQ();
213 fac_V.applyOnTheRight(Q);
215 fac_H.applyOnTheRight(Q);
216 fac_H.applyOnTheLeft(Q.adjoint());
218 em.applyOnTheLeft(Q.adjoint());
223 fac_H.diagonal().array() -= ritz_val[i].real();
230 fac_H.diagonal().array() += ritz_val[i].real();
236 Vector fk = fac_f * em[k - 1] + fac_V.col(k) * fac_H(k, k - 1);
237 factorize_from(k, ncv, fk);
242 int num_converged(Scalar tol)
245 Array thresh = tol * ritz_val.head(nev).array().abs().max(prec);
246 Array resid = ritz_vec.template bottomRows<1>().transpose().array().abs() * fac_f.norm();
248 ritz_conv = (resid < thresh);
250 return ritz_conv.cast<
int>().sum();
254 int nev_adjusted(
int nconv)
260 if(is_complex(ritz_val[nev - 1], prec) &&
261 is_conj(ritz_val[nev - 1], ritz_val[nev], prec))
266 nev_new = nev_new + std::min(nconv, (ncv - nev_new) / 2);
267 if(nev_new == 1 && ncv >= 6)
269 else if(nev_new == 1 && ncv > 3)
272 if(nev_new > ncv - 2)
276 if(is_complex(ritz_val[nev_new - 1], prec) &&
277 is_conj(ritz_val[nev_new - 1], ritz_val[nev_new], prec))
286 void retrieve_ritzpair()
288 EigenSolver eig(fac_H);
289 ComplexVector evals = eig.eigenvalues();
290 ComplexMatrix evecs = eig.eigenvectors();
292 std::vector<SortPair> pairs(ncv);
293 EigenvalueComparator<Complex, SelectionRule> comp;
294 for(
int i = 0; i < ncv; i++)
296 pairs[i].first = evals[i];
299 std::sort(pairs.begin(), pairs.end(), comp);
302 for(
int i = 0; i < ncv; i++)
304 ritz_val[i] = pairs[i].first;
306 for(
int i = 0; i < nev; i++)
308 ritz_vec.col(i) = evecs.col(pairs[i].second);
315 virtual void sort_ritzpair()
317 std::vector<SortPair> pairs(nev);
318 EigenvalueComparator<Complex, LARGEST_MAGN> comp;
319 for(
int i = 0; i < nev; i++)
321 pairs[i].first = ritz_val[i];
324 std::sort(pairs.begin(), pairs.end(), comp);
326 ComplexMatrix new_ritz_vec(ncv, nev);
327 BoolArray new_ritz_conv(nev);
329 for(
int i = 0; i < nev; i++)
331 ritz_val[i] = pairs[i].first;
332 new_ritz_vec.col(i) = ritz_vec.col(pairs[i].second);
333 new_ritz_conv[i] = ritz_conv[pairs[i].second];
336 ritz_vec.swap(new_ritz_vec);
337 ritz_conv.swap(new_ritz_conv);
362 ncv(ncv_ > dim_n ? dim_n : ncv_),
365 prec(std::pow(std::numeric_limits<Scalar>::epsilon(), Scalar(2.0 / 3)))
367 if(nev_ < 1 || nev_ > dim_n - 2)
368 throw std::invalid_argument(
"nev must satisfy 1 <= nev <= n - 2, n is the size of matrix");
370 if(ncv_ < nev_ + 2 || ncv_ > dim_n)
371 throw std::invalid_argument(
"ncv must satisfy nev + 2 <= ncv <= n, n is the size of matrix");
383 void init(
const Scalar *init_resid)
386 fac_V.resize(dim_n, ncv);
387 fac_H.resize(ncv, ncv);
389 ritz_val.resize(ncv);
390 ritz_vec.resize(ncv, nev);
391 ritz_conv.resize(nev);
401 std::copy(init_resid, init_resid + dim_n, v.data());
402 Scalar vnorm = v.norm();
404 throw std::invalid_argument(
"initial residual vector cannot be zero");
408 op->perform_op(v.data(), w.data());
411 fac_H(0, 0) = v.dot(w);
412 fac_f = w - v * fac_H(0, 0);
425 Vector init_resid = Vector::Random(dim_n);
426 init_resid.array() -= 0.5;
427 init(init_resid.data());
438 int compute(
int maxit = 1000, Scalar tol = 1e-10)
441 factorize_from(1, ncv, fac_f);
444 int i, nconv, nev_adj;
445 for(i = 0; i < maxit; i++)
447 nconv = num_converged(tol);
451 nev_adj = nev_adjusted(nconv);
459 return std::min(nev, nconv);
481 int nconv = ritz_conv.cast<
int>().sum();
482 ComplexVector res(nconv);
488 for(
int i = 0; i < nev; i++)
492 res[j] = ritz_val[i];
509 int nconv = ritz_conv.cast<
int>().sum();
510 ComplexMatrix res(dim_n, nconv);
515 ComplexMatrix ritz_vec_conv(ncv, nconv);
517 for(
int i = 0; i < nev; i++)
521 ritz_vec_conv.col(j) = ritz_vec.col(i);
526 res.noalias() = fac_V * ritz_vec_conv;
555 template <
typename Scalar = double,
561 typedef std::complex<Scalar> Complex;
562 typedef Eigen::Array<Complex, Eigen::Dynamic, 1> ComplexArray;
571 ComplexArray ritz_val_org = Scalar(1.0) / this->ritz_val.head(this->nev).array() + sigma;
572 this->ritz_val.head(this->nev) = ritz_val_org;
595 GenEigsSolver<Scalar, SelectionRule, OpType>(op_, nev_, ncv_),
598 this->op->set_shift(sigma);
625 template <
typename Scalar = double,
631 typedef Eigen::Array<Scalar, Eigen::Dynamic, 1> Array;
632 typedef std::complex<Scalar> Complex;
633 typedef Eigen::Array<Complex, Eigen::Dynamic, 1> ComplexArray;
652 ComplexArray nu = this->ritz_val.head(this->nev).array();
653 ComplexArray tmp1 = Scalar(0.5) / nu + sigmar;
654 ComplexArray tmp2 = (Scalar(1) / nu / nu - 4 * sigmai * sigmai).sqrt() * Scalar(0.5);
656 ComplexArray root1 = tmp1 + tmp2;
657 ComplexArray root2 = tmp1 - tmp2;
659 ComplexArray v = this->fac_V * this->ritz_vec.col(0);
660 Array v_real = v.real();
661 Array v_imag = v.imag();
662 Array lhs_real(this->dim_n), lhs_imag(this->dim_n);
664 this->op->set_shift(sigmar, 0);
665 this->op->perform_op(v_real.data(), lhs_real.data());
666 this->op->perform_op(v_imag.data(), lhs_imag.data());
668 ComplexArray rhs1 = v / (root1[0] - Complex(sigmar, 0));
669 ComplexArray rhs2 = v / (root2[0] - Complex(sigmar, 0));
671 Scalar err1 = (rhs1.real() - lhs_real).abs().sum() + (rhs1.imag() - lhs_imag).abs().sum();
672 Scalar err2 = (rhs2.real() - lhs_real).abs().sum() + (rhs2.imag() - lhs_imag).abs().sum();
676 this->ritz_val.head(this->nev) = root1;
678 this->ritz_val.head(this->nev) = root2;
703 GenEigsSolver<Scalar, SelectionRule, OpType>(op_, nev_, ncv_),
704 sigmar(sigmar_), sigmai(sigmai_)
706 this->op->set_shift(sigmar, sigmai);
710 #endif // GEN_EIGS_SOLVER_H
virtual Matrix matrix_RQ()
virtual void compute(ConstGenericMatrix &mat)
GenEigsRealShiftSolver(OpType *op_, int nev_, int ncv_, Scalar sigma_)
void apply_YQ(GenericMatrix Y)
void init(const Scalar *init_resid)
GenEigsSolver(OpType *op_, int nev_, int ncv_)
ComplexVector eigenvalues()
int compute(int maxit=1000, Scalar tol=1e-10)
ComplexMatrix eigenvectors()
GenEigsComplexShiftSolver(OpType *op_, int nev_, int ncv_, Scalar sigmar_, Scalar sigmai_)
void apply_QtY(Vector &Y)