ARPACK-Armadillo
SymmetricLDL.h
1 // Copyright (C) 2015 Yixuan Qiu
2 //
3 // This Source Code Form is subject to the terms of the Mozilla Public
4 // License, v. 2.0. If a copy of the MPL was not distributed with this
5 // file, You can obtain one at http://mozilla.org/MPL/2.0/.
6 
7 #ifndef SYMMETRIC_LDL_H
8 #define SYMMETRIC_LDL_H
9 
10 #include <armadillo>
11 #include <stdexcept>
12 #include "LapackWrapperExtra.h"
13 
24 template <typename Scalar = double>
26 {
27 private:
28  typedef arma::Mat<Scalar> Matrix;
29  typedef arma::Col<Scalar> Vector;
30  typedef arma::Col<int> IntVector;
31 
32  int dim_n; // size of the matrix
33  char mat_uplo; // whether using lower triangle or upper triangle
34  Matrix mat_fac; // storing factorization structures
35  IntVector vec_fac; // storing factorization structures
36  bool computed; // whether factorization has been computed
37 
38 public:
45  dim_n(0), mat_uplo('L'), computed(false)
46  {}
47 
57  SymmetricLDL(const Matrix &mat, const char uplo = 'L') :
58  dim_n(mat.n_rows),
59  mat_uplo(uplo),
60  computed(false)
61  {
62  compute(mat, uplo);
63  }
64 
73  void compute(const Matrix &mat, const char uplo = 'L')
74  {
75  if(!mat.is_square())
76  throw std::invalid_argument("SymmetricLDL: matrix must be square");
77 
78  dim_n = mat.n_rows;
79  mat_uplo = (uplo == 'L' ? 'L' : 'U'); // force to be one of 'L' and 'U'
80  mat_fac = mat;
81  vec_fac.set_size(dim_n);
82 
83  Scalar lwork_query;
84  int lwork = -1, info;
85  arma::lapack::sytrf(&mat_uplo, &dim_n, mat_fac.memptr(), &dim_n,
86  vec_fac.memptr(), &lwork_query, &lwork, &info);
87  lwork = int(lwork_query);
88 
89  Scalar *work = new Scalar[lwork];
90  arma::lapack::sytrf(&mat_uplo, &dim_n, mat_fac.memptr(), &dim_n,
91  vec_fac.memptr(), work, &lwork, &info);
92  delete [] work;
93 
94  if(info < 0)
95  throw std::invalid_argument("Lapack sytrf: illegal value");
96  if(info > 0)
97  throw std::logic_error("SymmetricLDL: matrix is singular");
98 
99  computed = true;
100  }
101 
113  void solve(Vector &vec_in, Vector &vec_out)
114  {
115  if(!computed)
116  return;
117 
118  vec_out = vec_in;
119 
120  int one = 1;
121  int info;
122  arma::lapack::sytrs(&mat_uplo, &dim_n, &one, mat_fac.memptr(), &dim_n,
123  vec_fac.memptr(), vec_out.memptr(), &dim_n, &info);
124  if(info < 0)
125  throw std::invalid_argument("Lapack sytrs: illegal value");
126  }
127 };
128 
129 
130 
131 #endif // SYMMETRIC_LDL_H
void compute(const Matrix &mat, const char uplo= 'L')
Definition: SymmetricLDL.h:73
SymmetricLDL(const Matrix &mat, const char uplo= 'L')
Definition: SymmetricLDL.h:57
void solve(Vector &vec_in, Vector &vec_out)
Definition: SymmetricLDL.h:113