//////////////////////////////////////////////////////////////////////////////// // // Copyright (c) 2008 The Regents of the University of California // // This file is part of Qbox // // Qbox is distributed under the terms of the GNU General Public License // as published by the Free Software Foundation, either version 2 of // the License, or (at your option) any later version. // See the file COPYING in the root directory of this distribution // or . // //////////////////////////////////////////////////////////////////////////////// // // MLWFTransform.C // //////////////////////////////////////////////////////////////////////////////// #include #include #include #include #include "MLWFTransform.h" #include "D3vector.h" #include "Basis.h" #include "SlaterDet.h" #include "UnitCell.h" #include "jade.h" #include "blas.h" using namespace std; //////////////////////////////////////////////////////////////////////////////// MLWFTransform::MLWFTransform(const SlaterDet& sd) : sd_(sd), cell_(sd.basis().cell()), ctxt_(sd.context()), bm_(BasisMapping(sd.basis())), maxsweep_(50), tol_(1.e-8) { a_.resize(6); adiag_.resize(6); const int n = sd.c().n(); const int nb = sd.c().nb(); for ( int k = 0; k < 6; k++ ) { a_[k] = new DoubleMatrix(ctxt_,n,n,nb,nb); adiag_[k].resize(n); } u_ = new DoubleMatrix(ctxt_,n,n,nb,nb); sdcosx_ = new SlaterDet(sd_); sdcosy_ = new SlaterDet(sd_); sdcosz_ = new SlaterDet(sd_); sdsinx_ = new SlaterDet(sd_); sdsiny_ = new SlaterDet(sd_); sdsinz_ = new SlaterDet(sd_); } //////////////////////////////////////////////////////////////////////////////// MLWFTransform::~MLWFTransform(void) { for ( int k = 0; k < 6; k++ ) delete a_[k]; delete u_; delete sdcosx_; delete sdcosy_; delete sdcosz_; delete sdsinx_; delete sdsiny_; delete sdsinz_; } //////////////////////////////////////////////////////////////////////////////// void MLWFTransform::update(void) { // recompute cos and sin matrices const ComplexMatrix& c = sd_.c(); ComplexMatrix& ccosx = sdcosx_->c(); ComplexMatrix& csinx = sdsinx_->c(); ComplexMatrix& ccosy = sdcosy_->c(); ComplexMatrix& csiny = sdsiny_->c(); ComplexMatrix& ccosz = sdcosz_->c(); ComplexMatrix& csinz = sdsinz_->c(); // proxy real matrices cr, cc, cs DoubleMatrix cr(c); DoubleMatrix ccx(ccosx); DoubleMatrix csx(csinx); DoubleMatrix ccy(ccosy); DoubleMatrix csy(csiny); DoubleMatrix ccz(ccosz); DoubleMatrix csz(csinz); vector > zvec(bm_.zvec_size()), zvec_cos(bm_.zvec_size()), zvec_sin(bm_.zvec_size()), ct(bm_.np012loc()), ct_cos(bm_.np012loc()), ct_sin(bm_.np012loc()); for ( int i = 0; i < 6; i++ ) { a_[i]->resize(c.n(), c.n(), c.nb(), c.nb()); adiag_[i].resize(c.n()); } u_->resize(c.n(), c.n(), c.nb(), c.nb()); // loop over all local states const int np0 = bm_.np0(); const int np1 = bm_.np1(); const int np2 = bm_.np2(); const int np01 = np0 * np1; const int np2loc = bm_.np2loc(); const int nvec = bm_.nvec(); for ( int n = 0; n < c.nloc(); n++ ) { const complex* f = c.cvalptr(n*c.mloc()); complex* fcx = ccosx.valptr(n*c.mloc()); complex* fsx = csinx.valptr(n*c.mloc()); complex* fcy = ccosy.valptr(n*c.mloc()); complex* fsy = csiny.valptr(n*c.mloc()); complex* fcz = ccosz.valptr(n*c.mloc()); complex* fsz = csinz.valptr(n*c.mloc()); // direction z // map state to array zvec_ bm_.vector_to_zvec(&f[0],&zvec[0]); for ( int ivec = 0; ivec < nvec; ivec++ ) { const int ibase = ivec * np2; compute_sincos(np2,&zvec[ibase],&zvec_cos[ibase],&zvec_sin[ibase]); } // map back zvec_cos to sdcos and zvec_sin to sdsin bm_.zvec_to_vector(&zvec_cos[0],&fcz[0]); bm_.zvec_to_vector(&zvec_sin[0],&fsz[0]); // x direction // map zvec to ct bm_.transpose_fwd(&zvec[0],&ct[0]); for ( int iz = 0; iz < np2loc; iz++ ) { for ( int iy = 0; iy < np1; iy++ ) { const int ibase = iz * np01 + iy * np0; compute_sincos(np0,&ct[ibase],&ct_cos[ibase],&ct_sin[ibase]); } } // transpose back ct_cos to zvec_cos bm_.transpose_bwd(&ct_cos[0],&zvec_cos[0]); // transpose back ct_sin to zvec_sin bm_.transpose_bwd(&ct_sin[0],&zvec_sin[0]); // map back zvec_cos to sdcos and zvec_sin to sdsin bm_.zvec_to_vector(&zvec_cos[0],&fcx[0]); bm_.zvec_to_vector(&zvec_sin[0],&fsx[0]); // y direction vector > c_tmp(np1),ccos_tmp(np1),csin_tmp(np1); int one = 1; int len = np1; int stride = np0; for ( int iz = 0; iz < np2loc; iz++ ) { for ( int ix = 0; ix < np0; ix++ ) { const int ibase = iz * np01 + ix; zcopy(&len,&ct[ibase],&stride,&c_tmp[0],&one); compute_sincos(np1,&c_tmp[0],&ccos_tmp[0],&csin_tmp[0]); zcopy(&len,&ccos_tmp[0],&one,&ct_cos[ibase],&stride); zcopy(&len,&csin_tmp[0],&one,&ct_sin[ibase],&stride); } } // transpose back ct_cos to zvec_cos bm_.transpose_bwd(&ct_cos[0],&zvec_cos[0]); // transpose back ct_sin to zvec_sin bm_.transpose_bwd(&ct_sin[0],&zvec_sin[0]); // map back zvec_cos and zvec_sin bm_.zvec_to_vector(&zvec_cos[0],&fcy[0]); bm_.zvec_to_vector(&zvec_sin[0],&fsy[0]); } // dot products a_[0] = , a_[1] = a_[0]->gemm('t','n',2.0,cr,ccx,0.0); a_[0]->ger(-1.0,cr,0,ccx,0); a_[1]->gemm('t','n',2.0,cr,csx,0.0); a_[1]->ger(-1.0,cr,0,csx,0); // dot products a_[2] = , a_[3] = a_[2]->gemm('t','n',2.0,cr,ccy,0.0); a_[2]->ger(-1.0,cr,0,ccy,0); a_[3]->gemm('t','n',2.0,cr,csy,0.0); a_[3]->ger(-1.0,cr,0,csy,0); // dot products a_[4] = , a_[5] = a_[4]->gemm('t','n',2.0,cr,ccz,0.0); a_[4]->ger(-1.0,cr,0,ccz,0); a_[5]->gemm('t','n',2.0,cr,csz,0.0); a_[5]->ger(-1.0,cr,0,csz,0); } //////////////////////////////////////////////////////////////////////////////// void MLWFTransform::compute_transform(void) { int nsweep = jade(maxsweep_,tol_,a_,*u_,adiag_); } //////////////////////////////////////////////////////////////////////////////// void MLWFTransform::compute_sincos(const int n, const complex* f, complex* fc, complex* fs) { // fc[i] = 0.5 * ( f[i-1] + f[i+1] ) // fs[i] = (0.5/i) * ( f[i-1] - f[i+1] ) // i = 0 complex zp = f[n-1]; complex zm = f[1]; fc[0] = 0.5 * ( zp + zm ); complex zdiff = zp - zm; fs[0] = 0.5 * complex(imag(zdiff),-real(zdiff)); for ( int i = 1; i < n-1; i++ ) { const complex zzp = f[i-1]; const complex zzm = f[i+1]; fc[i] = 0.5 * ( zzp + zzm ); const complex zzdiff = zzp - zzm; fs[i] = 0.5 * complex(imag(zzdiff),-real(zzdiff)); } // i = n-1 zp = f[n-2]; zm = f[0]; fc[n-1] = 0.5 * ( zp + zm ); zdiff = zp - zm; fs[n-1] = 0.5 * complex(imag(zdiff),-real(zdiff)); } //////////////////////////////////////////////////////////////////////////////// D3vector MLWFTransform::center(int i) { assert(i>=0 && i=0 && i=0 && j<3); const double c = adiag_[2*j][i]; const double s = adiag_[2*j+1][i]; // Next line: M_1_PI = 1.0/pi const double fac = 1.0 / length(cell_.b(j)); return fac*fac * ( 1.0 - c*c - s*s ); } //////////////////////////////////////////////////////////////////////////////// double MLWFTransform::spread2(int i) { assert(i>=0 & i