Commit 058730d2 by Francois Gygi

Implemented PSDA for complex wave functions

parent 26c34952
......@@ -56,6 +56,7 @@ using namespace std;
#define pdtrsm pdtrsm_
#define pztrsm pztrsm_
#define pdtrtrs pdtrtrs_
#define pztrtrs pztrtrs_
#define pdpotrf pdpotrf_
#define pzpotrf pzpotrf_
#define pdpotri pdpotri_
......@@ -67,6 +68,7 @@ using namespace std;
#define pzheev pzheev_
#define pzheevd pzheevd_
#define pdtrtri pdtrtri_
#define pztrtri pztrtri_
#define pdlatra pdlatra_
#define pdlacp2 pdlacp2_
#define pdlacp3 pdlacp3_
......@@ -102,6 +104,7 @@ using namespace std;
#define dtrmm dtrmm_
#define dtrsm dtrsm_
#define dtrtri dtrtri_
#define ztrtri ztrtri_
#define ztrsm ztrsm_
#define dtrtrs dtrtrs_
#define dpotrf dpotrf_
......@@ -197,6 +200,9 @@ extern "C"
void pdtrtrs(const char*, const char*, const char*, const int*, const int*,
const double*, const int*, const int*, const int*,
double*, const int*, const int*, const int*, int*);
// Complex (complex<double>) counterpart of the pdtrtrs prototype above.
// Called by ComplexMatrix::trtrs, which passes, in order: uplo, trans, diag,
// the order of the triangular matrix, the number of columns of the right-hand
// side, the matrix values with two index arguments (both 1) and its
// descriptor, the right-hand side values with two index arguments (both 1)
// and its descriptor, and the info status output.
void pztrtrs(const char*, const char*, const char*, const int*, const int*,
const complex<double>*, const int*, const int*, const int*,
complex<double>*, const int*, const int*, const int*, int*);
void pigemr2d(const int*,const int*,
const int*,const int*,const int*, const int*,
int*,const int*,const int*,const int*,const int*);
......@@ -249,6 +255,8 @@ extern "C"
int* iwork, int* liwork, int* info);
void pdtrtri(const char*, const char*, const int*, double*,
const int*, const int*, const int*, int*);
// Complex (complex<double>) counterpart of the pdtrtri prototype above.
// Called by ComplexMatrix::trtri, which passes, in order: uplo, diag, the
// matrix order, the matrix values (overwritten in place), two index
// arguments (both 1), the matrix descriptor, and the info status output.
void pztrtri(const char*, const char*, const int*, complex<double>*,
const int*, const int*, const int*, int*);
void pdgetrf(const int* m, const int* n, double* val,
int* ia, const int* ja, const int* desca, int* ipiv, int* info);
void pzgetrf(const int* m, const int* n, complex<double>* val,
......@@ -1882,6 +1890,41 @@ void DoubleMatrix::trtrs(char uplo, char trans, char diag,
}
////////////////////////////////////////////////////////////////////////////////
// Triangular solve, complex case: solves op(A) * X = B, where A (*this) is a
// triangular matrix of order N and B is an N-by-NRHS matrix.
// op(A) is selected by trans, which is forwarded unchanged to the library
// routine (LAPACK convention: 'n' -> A, 't' -> A**T, 'c' -> A**H).
// uplo and diag are forwarded unchanged as well.
// Output in B: the solution X overwrites b.
// A nonzero info from the library routine is reported on cout and the
// program is terminated with code 2.
////////////////////////////////////////////////////////////////////////////////
void ComplexMatrix::trtrs(char uplo, char trans, char diag,
  ComplexMatrix& b) const
{
  // nothing is done when active() is false
  if ( !active() )
    return;

  // the triangular matrix must be square
  assert(m_==n_);

  int info;
#ifdef SCALAPACK
  // both matrices are addressed starting at global element (1,1)
  int first = 1;
  pztrtrs(&uplo, &trans, &diag, &m_, &b.n_,
          val, &first, &first, desc_,
          b.val, &first, &first, b.desc_, &info);
#else
  // NOTE(review): no ztrtrs name-mangling macro or prototype is visible in
  // this part of the file -- confirm it is declared before building without
  // SCALAPACK.
  ztrtrs(&uplo, &trans, &diag, &m_, &b.n_, val, &m_,
         b.val, &b.m_, &info);
#endif

  // success: the solution is now stored in b
  if ( info == 0 )
    return;

  cout <<" ComplexMatrix::trtrs, info=" << info << endl;
#ifdef USE_MPI
  MPI_Abort(MPI_COMM_WORLD, 2);
#else
  exit(2);
#endif
}
////////////////////////////////////////////////////////////////////////////////
// LU decomposition of a double matrix
////////////////////////////////////////////////////////////////////////////////
void DoubleMatrix::lu(valarray<int>& ipiv)
......@@ -2269,6 +2312,31 @@ void DoubleMatrix::trtri(char uplo, char diag)
}
}
////////////////////////////////////////////////////////////////////////////////
// Inverse of a triangular matrix (complex case), computed in place.
// uplo and diag are forwarded unchanged to pztrtri (SCALAPACK build) or
// ztrtri (serial build). The matrix must be square.
// Nothing is done when active() is false.
// A nonzero info from the library routine is reported on cout and the
// program is terminated with code 2.
////////////////////////////////////////////////////////////////////////////////
void ComplexMatrix::trtri(char uplo, char diag)
{
  int info = 0;
  if ( active() )
  {
    assert(m_==n_);
#ifdef SCALAPACK
    // the matrix is addressed starting at global element (1,1)
    int ione=1;
    pztrtri(&uplo, &diag, &m_, val, &ione, &ione, desc_, &info);
#else
    ztrtri(&uplo, &diag, &m_, val, &m_, &info);
#endif
    if ( info != 0 )
    {
      // name the class in the diagnostic, as ComplexMatrix::trtrs does
      cout << " ComplexMatrix::trtri, info=" << info << endl;
#ifdef USE_MPI
      MPI_Abort(MPI_COMM_WORLD, 2);
#else
      exit(2);
#endif
    }
  }
}
////////////////////////////////////////////////////////////////////////////////
// Polar decomposition A = UH
// Replace *this with its orthogonal polar factor U
......@@ -2352,6 +2420,88 @@ void DoubleMatrix::polar(double tol, int maxiter)
}
////////////////////////////////////////////////////////////////////////////////
// Polar decomposition A = UH (complex case)
// Replace *this with its unitary polar factor U
// Iterates until iter >= maxiter or ||I - X^H*X|| < tol, where the norm is
// the one returned by nrm2() and is evaluated at the start of each iteration.
//
// tol:     convergence threshold on q.nrm2(), with q = I - X^H * X
// maxiter: maximum number of iterations
//
// Only implemented for SCALAPACK builds (compile-time error otherwise).
////////////////////////////////////////////////////////////////////////////////
void ComplexMatrix::polar(double tol, int maxiter)
{
  ComplexMatrix x(ctxt_,m_,n_,mb_,nb_);
  ComplexMatrix xp(ctxt_,m_,n_,mb_,nb_);

  ComplexMatrix q(ctxt_,n_,n_,nb_,nb_);
  ComplexMatrix qt(ctxt_,n_,n_,nb_,nb_);
  ComplexMatrix t(ctxt_,n_,n_,nb_,nb_);

#ifdef SCALAPACK
  // start from the largest double so that the first iteration is always run
  double qnrm2 = numeric_limits<double>::max();
  int iter = 0;
  x = *this;
  while ( iter < maxiter && qnrm2 > tol )
  {
    // q = I - x^H * x
    q.identity();
    q.herk('l','c',-1.0,x,1.0);
    q.symmetrize('l');

    // Update the variable tested in the loop condition. Declaring a new
    // local here would shadow it and disable the tolerance test, so that
    // maxiter iterations would always be performed.
    qnrm2 = q.nrm2();
#ifdef DEBUG
    if ( ctxt_.onpe0() )
      cout << " ComplexMatrix::polar: qnrm2 = " << qnrm2 << endl;
#endif

    // choose Bjork-Bowie or Higham iteration depending on q.nrm2
    // threshold value
    // see A. Bjork and C. Bowie, SIAM J. Num. Anal. 8, 358 (1971) p.363
    if ( qnrm2 < 1.0 )
    {
      // Bjork-Bowie iteration
      // compute xp = x * ( I + 0.5*q * ( I + 0.75 * q ) )

      // t = ( I + 0.75 * q )
      t.identity();
      t.axpy(0.75,q);

      // compute q*t
      qt.gemm('n','n',1.0,q,t,0.0);

      // xp = x * ( I + 0.5*q * ( I + 0.75 * q ) )
      //    = x * ( I + 0.5 * qt )
      // Use t to store (I + 0.5 * qt)
      t.identity();
      t.axpy(0.5,qt);

      // t now contains (I + 0.5 * qt)
      // xp = x * t
      xp.gemm('n','n',1.0,x,t,0.0);

      // update x
      x = xp;
    }
    else
    {
      // Higham iteration (square matrices only)
      assert(m_==n_);
      //if ( ctxt_.onpe0() )
      //  cout << " ComplexMatrix::polar: using Higham algorithm" << endl;
      // t = X^H
      // NOTE(review): this relies on ComplexMatrix::transpose forming the
      // conjugate transpose -- confirm against its implementation.
      t.transpose(1.0,x,0.0);
      t.inverse();
      // t now contains X^-H
      // x = 0.5 * ( x + x^-H ), updated in place element by element
      // NOTE(review): the element-wise loop presumably requires x and t to
      // have identical local layouts (i.e. mb_ == nb_) -- verify.
      for ( int i = 0; i < x.size(); i++ )
        x[i] = 0.5 * ( x[i] + t[i] );
    }
    iter++;
  }
  *this = x;
#else
#error "ComplexMatrix::polar only implemented with SCALAPACK"
#endif
}
////////////////////////////////////////////////////////////////////////////////
// estimate the reciprocal of the condition number (in the 1-norm) of a
// real symmetric positive definite matrix using the Cholesky factorization
// A = U**T*U or A = L*L**T computed by DoubleMatrix::potrf
......
......@@ -537,6 +537,9 @@ class ComplexMatrix
// Inverse of a symmetric matrix from Cholesky factor
void potri(char uplo);
// Polar decomposition, tolerance ||I-X^H*X||<tol or iter<maxiter
void polar(double tol, int maxiter);
// LU decomposition
void lu(std::valarray<int>& ipiv);
......
......@@ -62,9 +62,12 @@ void PSDAWavefunctionStepper::update(Wavefunction& dwf)
}
else
{
// not implemented in the complex case
cout << "PSDA is not implemented for complex wave functions" << endl;
assert(false);
ComplexMatrix& c = wf_.sd(ispin,ikp)->c();
ComplexMatrix& cp = dwf.sd(ispin,ikp)->c();
ComplexMatrix a(c.context(),c.n(),c.n(),c.nb(),c.nb());
a.gemm('c','n',1.0,c,cp,0.0);
// cp = cp - c * a
cp.gemm('n','n',-1.0,c,a,1.0);
}
}
}
......@@ -123,6 +126,9 @@ void PSDAWavefunctionStepper::update(Wavefunction& dwf)
a += f * delta_f;
b += delta_f * delta_f;
}
if ( wf_.sd(ispin,ikp)->basis().real() )
{
// correct for double counting of asum and bsum on first row
// factor 2.0: G and -G
a *= 2.0;
......@@ -140,6 +146,7 @@ void PSDAWavefunctionStepper::update(Wavefunction& dwf)
b -= delta_f0 * delta_f0 + delta_f1 * delta_f1;
}
}
}
// a and b contain the partial sums of a and b
double tmpvec[2] = { a, b };
......@@ -151,9 +158,6 @@ void PSDAWavefunctionStepper::update(Wavefunction& dwf)
if ( b != 0.0 )
theta = - a / b;
if ( wf_.sdcontext()->onpe0() )
cout << " Anderson extrapolation: theta=" << theta;
if ( theta < -1.0 )
{
theta = 0.0;
......@@ -161,9 +165,6 @@ void PSDAWavefunctionStepper::update(Wavefunction& dwf)
theta = min(2.0,theta);
if ( wf_.sdcontext()->onpe0() )
cout <<" (" << theta << ")" << endl;
// extrapolation
for ( int i = 0; i < 2*mloc*nloc; i++ )
{
......
......@@ -669,11 +669,46 @@ void SlaterDet::lowdin(void)
}
else
{
// complex case: not implemented
if ( ctxt_.onpe0() )
cout << " SlaterDet::lowdin: not implemented, reverting to Gram-Schmidt"
<< endl;
gram();
// complex case
ComplexMatrix c_tmp(c_);
ComplexMatrix l(ctxt_,c_.n(),c_.n(),c_.nb(),c_.nb());
ComplexMatrix x(ctxt_,c_.n(),c_.n(),c_.nb(),c_.nb());
ComplexMatrix t(ctxt_,c_.n(),c_.n(),c_.nb(),c_.nb());
l.clear();
l.herk('l','c',1.0,c_,0.0);
//cout << "SlaterDet::lowdin: A=\n" << l << endl;
// Cholesky decomposition of A=Y^H Y
l.potrf('l');
// The lower triangle of l now contains the Cholesky factor of Y^T Y
//cout << "SlaterDet::lowdin: L=\n" << l << endl;
// Compute the polar decomposition of R = L^T
x.transpose(1.0,l,0.0);
// x now contains R
const double tol = 1.e-6;
const int maxiter = 3;
x.polar(tol,maxiter);
// x now contains the unitary polar factor U of the
// polar decomposition R = UH
//cout << " SlaterDet::lowdin: unitary polar factor=\n" << x << endl;
// Compute L^-1
l.trtri('l','n');
// l now contains L^-1
// Form the product L^-T U
t.gemm('c','n',1.0,l,x,0.0);
// Multiply c by L^-T U
c_.gemm('n','n',1.0,c_tmp,t,0.0);
}
}
......@@ -786,15 +821,105 @@ void SlaterDet::ortho_align(const SlaterDet& sd)
#if TIMING
tmap["gemm2"].stop();
#endif
}
else
{
// complex case: not implemented
if ( ctxt_.onpe0() )
cout << " SlaterDet::lowdin: not implemented, reverting to riccati"
<< endl;
riccati(sd);
// complex case
ComplexMatrix c_tmp(c_);
const ComplexMatrix& sdc = sd.c();
ComplexMatrix l(ctxt_,c_.n(),c_.n(),c_.nb(),c_.nb());
ComplexMatrix x(ctxt_,c_.n(),c_.n(),c_.nb(),c_.nb());
#if TIMING
tmap["herk"].reset();
tmap["herk"].start();
#endif
l.clear();
l.herk('l','c',1.0,c_,0.0);
#if TIMING
tmap["herk"].stop();
#endif
// Cholesky decomposition of A=Y^H Y
#if TIMING
tmap["potrf"].reset();
tmap["potrf"].start();
#endif
l.potrf('l');
#if TIMING
tmap["potrf"].stop();
#endif
// The lower triangle of l now contains the Cholesky factor of Y^T Y
//cout << "SlaterDet::ortho_align: L=\n" << l << endl;
// Compute the polar decomposition of L^-1 B
// where B = C^H sd.C
// Compute B: store result in x
#if TIMING
tmap["gemm"].reset();
tmap["gemm"].start();
#endif
x.gemm('c','n',1.0,c_,sdc,0.0);
#if TIMING
tmap["gemm"].stop();
#endif
// Form the product L^-1 B, store result in x
// triangular solve: L X = B
// trtrs: solve op(*this) * X = Z, output in Z
#if TIMING
tmap["trtrs"].reset();
tmap["trtrs"].start();
#endif
l.trtrs('l','n','n',x);
#if TIMING
tmap["trtrs"].stop();
#endif
// x now contains L^-1 B
// compute the polar decomposition of X = L^-1 B
#if TIMING
tmap["polar"].reset();
tmap["polar"].start();
#endif
const double tol = 1.e-6;
const int maxiter = 2;
x.polar(tol,maxiter);
#if TIMING
tmap["polar"].stop();
#endif
// x now contains the unitary polar factor X of the
// polar decomposition L^-1 B = XH
//cout << " SlaterDet::ortho_align: unitary polar factor=\n"
// << x << endl;
// Form the product L^-T Q
// Solve trans(L) Z = X
#if TIMING
tmap["trtrs2"].reset();
tmap["trtrs2"].start();
#endif
l.trtrs('l','c','n',x);
#if TIMING
tmap["trtrs2"].stop();
#endif
// x now contains L^-H Q
// Multiply c by L^-H Q
#if TIMING
tmap["gemm2"].reset();
tmap["gemm2"].start();
#endif
c_.gemm('n','n',1.0,c_tmp,x,0.0);
#if TIMING
tmap["gemm2"].stop();
#endif
}
#if TIMING
for ( TimerMap::iterator i = tmap.begin(); i != tmap.end(); i++ )
......@@ -831,7 +956,7 @@ void SlaterDet::align(const SlaterDet& sd)
DoubleMatrix t(ctxt_,c_.n(),c_.n(),c_.nb(),c_.nb());
// Compute the polar decomposition of B
// where B = C^T sd.C
// where B = C^H sd.C
#if TIMING
tmap["align_gemm1"].start();
......@@ -890,10 +1015,66 @@ void SlaterDet::align(const SlaterDet& sd)
}
else
{
// complex case: not implemented
if ( ctxt_.onpe0() )
cout << " SlaterDet::align: not implemented, alignment skipped"
<< endl;
// complex case
ComplexMatrix c_tmp(c_);
const ComplexMatrix& sdc = sd.c();
ComplexMatrix x(ctxt_,c_.n(),c_.n(),c_.nb(),c_.nb());
ComplexMatrix t(ctxt_,c_.n(),c_.n(),c_.nb(),c_.nb());
// Compute the polar decomposition of B
// where B = C^H sd.C
#if TIMING
tmap["align_gemm1"].start();
#endif
// Compute B: store result in x
x.gemm('c','n',1.0,c_,sdc,0.0);
#if TIMING
tmap["align_gemm1"].stop();
#endif
// x now contains B
//cout << "SlaterDet::align: B=\n" << x << endl;
// Compute the distance | c - sdc | before alignment
//for ( int i = 0; i < c_proxy.size(); i++ )
// c_tmp_proxy[i] = c_proxy[i] - sdc_proxy[i];
//cout << " SlaterDet::align: distance before: "
// << c_tmp_proxy.nrm2() << endl;
// compute the polar decomposition of B
double tol = 1.e-6;
const int maxiter = 3;
#if TIMING
tmap["align_polar"].start();
#endif
x.polar(tol,maxiter);
#if TIMING
tmap["align_while"].stop();
#endif
// x now contains the unitary polar factor X of the
// polar decomposition B = XH
//cout << " SlaterDet::align: unitary polar factor=\n" << x << endl;
#if TIMING
tmap["align_gemm2"].start();
#endif
// Multiply c by X
c_tmp = c_;
c_.gemm('n','n',1.0,c_tmp,x,0.0);
#if TIMING
tmap["align_gemm2"].stop();
#endif
// Compute the distance | c - sdc | after alignment
//for ( int i = 0; i < c_proxy.size(); i++ )
// c_tmp_proxy[i] = c_proxy[i] - sdc_proxy[i];
//cout << " SlaterDet::align: distance after: "
// << c_tmp_proxy.nrm2() << endl;
}
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment