Commit d3c82971 by Francois Gygi

cleanup inverse and inverse_det members


git-svn-id: http://qboxcode.org/svn/qb/branches/efield@1588 cba15fb0-1239-40c8-b417-11db7ca47a34
parent e76f5de2
......@@ -71,7 +71,9 @@ using namespace std;
#define pdlacp2 pdlacp2_
#define pdlacp3 pdlacp3_
#define pdgetrf pdgetrf_
#define pzgetrf pzgetrf_
#define pdgetri pdgetri_
#define pzgetri pzgetri_
#define pdlapiv pdlapiv_
#define pzlapiv pzlapiv_
#define pdlapv2 pdlapv2_
......@@ -111,7 +113,9 @@ using namespace std;
#define zheev zheev_
#define idamax idamax_
#define dgetrf dgetrf_
#define zgetrf zgetrf_
#define dgetri dgetri_
#define zgetri zgetri_
#endif
extern "C"
......@@ -247,9 +251,14 @@ extern "C"
const int*, const int*, const int*, int*);
void pdgetrf(const int* m, const int* n, double* val,
int* ia, const int* ja, const int* desca, int* ipiv, int* info);
void pzgetrf(const int* m, const int* n, complex<double>* val,
int* ia, const int* ja, const int* desca, int* ipiv, int* info);
void pdgetri(const int* n, double* val,
const int* ia, const int* ja, int* desca, int* ipiv,
double* work, int* lwork, int* iwork, int* liwork, int* info);
void pzgetri(const int* n, complex<double>* val, const int* ia,
const int* ja, int* desca, int* ipiv, complex<double>* work,
int* lwork, int* iwork, int* liwork, int* info);
void pdlapiv(const char* direc, const char* rowcol, const char* pivroc,
const int* m, const int* n, double *a, const int* ia,
......@@ -355,8 +364,12 @@ extern "C"
void dtrtri(const char*, const char*, const int*, double*, const int*, int* );
void dgetrf(const int* m, const int* n, double* a, const int* lda,
int* ipiv, int*info);
void zgetrf(const int* m, const int* n, complex<double>* a, const int* lda,
int* ipiv, int*info);
void dgetri(const int* m, double* val, const int* lda, int* ipiv,
double* work, int* lwork, int* info);
void zgetri(const int* m, complex<double>* val, const int* lda, int* ipiv,
complex<double>* work, int* lwork, int* info);
}
......@@ -1898,39 +1911,110 @@ void DoubleMatrix::lu(valarray<int>& ipiv)
}
////////////////////////////////////////////////////////////////////////////////
// inverse of a general square matrix
// LU decomposition of a complex matrix
////////////////////////////////////////////////////////////////////////////////
void DoubleMatrix::inverse(void)
void ComplexMatrix::lu(valarray<int>& ipiv)
{
int info;
if ( active() )
{
assert(m_==n_);
valarray<int> ipiv(mloc_+mb_);
ipiv.resize(mloc_+mb_);
// LU decomposition
#ifdef SCALAPACK
int ione=1;
pdgetrf(&m_, &n_, val, &ione, &ione, desc_, &ipiv[0], &info);
pzgetrf(&m_, &n_, val, &ione, &ione, desc_, &ipiv[0], &info);
#else
dgetrf(&m_, &n_, val, &m_, &ipiv[0], &info);
zgetrf(&m_, &n_, val, &m_, &ipiv[0], &info);
#endif
if(info!=0)
{
cout << " ComplexMatrix::lu, info=" << info << endl;
#ifdef USE_MPI
MPI_Abort(MPI_COMM_WORLD, 2);
#else
exit(2);
#endif
}
}
}
////////////////////////////////////////////////////////////////////////////////
// inverse of a general square matrix
////////////////////////////////////////////////////////////////////////////////
void DoubleMatrix::inverse(void)
{
int info;
if ( active() )
{
assert(m_==n_);
valarray<int> ipiv(mloc_+mb_);
// LU decomposition
lu(ipiv);
inverse_from_lu(ipiv);
}
}
////////////////////////////////////////////////////////////////////////////////
// compute inverse and determinant of a square matrix
////////////////////////////////////////////////////////////////////////////////
double DoubleMatrix::inverse_det(void)
{
int info;
if ( active() )
{
assert(m_==n_);
valarray<int> ipiv(mloc_+mb_);
lu(ipiv);
// compute determinant
valarray<double> diag(n_);
for ( int ii = 0; ii < n_; ii++ )
{
int iii = l(ii) * mb_ + x(ii);
int jjj = m(ii) * nb_ + y(ii);
if ( pr(ii) == ctxt_.myrow()
&& pc(ii) == ctxt_.mycol() )
diag[ii] = val[iii+mloc_*jjj];
}
ctxt_.dsum(n_,1,&diag[0],n_);
double det = 1.0;
for ( int ii = 0; ii < n_; ii++ )
det *= diag[ii];
inverse_from_lu(ipiv);
if(info!=0)
{
cout << " DoubleMatrix::inverse, info(getrf)=" << info << endl;
cout << " DoubleMatrix::inverse_det, info(getri)=" << info << endl;
#ifdef USE_MPI
MPI_Abort(MPI_COMM_WORLD, 2);
#else
exit(2);
#endif
}
return det;
}
}
// Compute inverse using LU decomposition
////////////////////////////////////////////////////////////////////////////////
// inverse from an LU decomposed square matrix
////////////////////////////////////////////////////////////////////////////////
void DoubleMatrix::inverse_from_lu(valarray<int>& ipiv)
{
int info;
if ( active() )
{
assert(m_==n_);
// Compute inverse using LU decomposition and array ipiv
#ifdef SCALAPACK
valarray<double> work(1);
valarray<int> iwork(1);
int lwork = -1;
int liwork = -1;
int ione = 1;
// First call to compute dimensions of work arrays lwork and liwork
// dimensions are returned in work[0] and iwork[0];
pdgetri(&n_, val, &ione, &ione, desc_, &ipiv[0],
......@@ -1954,7 +2038,107 @@ void DoubleMatrix::inverse(void)
#endif
if(info!=0)
{
cout << " DoubleMatrix::inverse, info(getri)=" << info << endl;
cout << " DoubleMatrix::inverse_from_lu, info(getri)=" << info << endl;
#ifdef USE_MPI
MPI_Abort(MPI_COMM_WORLD, 2);
#else
exit(2);
#endif
}
}
}
////////////////////////////////////////////////////////////////////////////////
// compute inverse of a general square matrix
////////////////////////////////////////////////////////////////////////////////
void ComplexMatrix::inverse(void)
{
valarray<int> ipiv;
lu(ipiv);
inverse_from_lu(ipiv);
}
////////////////////////////////////////////////////////////////////////////////
// compute inverse and determinant of a complex square matrix
////////////////////////////////////////////////////////////////////////////////
complex<double> ComplexMatrix::inverse_det(void)
{
int info;
if ( active() )
{
assert(m_==n_);
valarray<int> ipiv(mloc_+mb_);
lu(ipiv);
// compute determinant
valarray<complex<double> > diag(n_);
for ( int ii = 0; ii < n_; ii++ )
{
int iii = l(ii) * mb_ + x(ii);
int jjj = m(ii) * nb_ + y(ii);
if ( pr(ii) == ctxt_.myrow()
&& pc(ii) == ctxt_.mycol() )
diag[ii] = val[iii+mloc_*jjj];
}
ctxt_.dsum(n_*2,1,(double*)&diag[0],n_*2);
complex<double> det = 1.0;
for ( int ii = 0; ii < n_; ii++ )
det *= diag[ii];
inverse_from_lu(ipiv);
if(info!=0)
{
cout << " ComplexMatrix::inverse_det, info(getri)=" << info << endl;
#ifdef USE_MPI
MPI_Abort(MPI_COMM_WORLD, 2);
#else
exit(2);
#endif
}
return det;
}
}
////////////////////////////////////////////////////////////////////////////////
// compute inverse of an LU decomposed matrix
void ComplexMatrix::inverse_from_lu(valarray<int>& ipiv)
{
// it is assumed that the current matrix is LU decomposed
int info;
if ( active() )
{
assert(m_==n_);
// Compute inverse using LU decomposition and array ipiv computed in lu()
#ifdef SCALAPACK
valarray< complex<double> > work(1);
valarray<int> iwork(1);
int lwork = -1;
int liwork = -1;
int ione = 1;
pzgetri(&n_, val, &ione, &ione, desc_, &ipiv[0],
&work[0], &lwork, &iwork[0], &liwork, &info);
lwork = (int) work[0].real() + 1;
liwork = iwork[0];
work.resize(lwork);
iwork.resize(liwork);
// Compute inverse
pzgetri(&n_, val, &ione, &ione, desc_, &ipiv[0],
&work[0], &lwork, &iwork[0], &liwork, &info);
#else
valarray< complex<double> > work(1);
int lwork = -1;
// First call to compute optimal size of work array, returned in work[0]
zgetri(&m_, val, &m_, &ipiv[0], &work[0], &lwork, &info);
lwork = (int) work[0] + 1;
work.resize(lwork);
zgetri(&m_, val, &m_, &ipiv[0], &work[0], &lwork, &info);
#endif
if(info!=0)
{
cout << " ComplexMatrix::inverse, info(getri)=" << info << endl;
#ifdef USE_MPI
MPI_Abort(MPI_COMM_WORLD, 2);
#else
......
......@@ -259,6 +259,9 @@ class DoubleMatrix
// compute inverse of a square matrix
void inverse(void);
// compute inverse and determinant of a square matrix
double inverse_det(void);
void inverse_from_lu(std::valarray<int>& ipiv);
// Inverse of triangular matrix
void trtri(char uplo,char diag);
......@@ -529,6 +532,15 @@ class ComplexMatrix
// Inverse of a symmetric matrix from Cholesky factor
void potri(char uplo);
// LU decomposition
void lu(std::valarray<int>& ipiv);
// compute inverse of a square matrix
void inverse(void);
// compute inverse and determinant of a square matrix
std::complex<double> inverse_det(void);
void inverse_from_lu(std::valarray<int>& ipiv);
// Inverse of triangular matrix
void trtri(char uplo,char diag);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment