////////////////////////////////////////////////////////////////////////////////
//
// Copyright (c) 2008 The Regents of the University of California
//
// This file is part of Qbox
//
// Qbox is distributed under the terms of the GNU General Public License
// as published by the Free Software Foundation, either version 2 of
// the License, or (at your option) any later version.
// See the file COPYING in the root directory of this distribution
// or <http://www.gnu.org/licenses/>.
//
////////////////////////////////////////////////////////////////////////////////
//
// MLWFTransform.C
//
////////////////////////////////////////////////////////////////////////////////
#include
#include
#include
#include
#include "MLWFTransform.h"
#include "D3vector.h"
#include "Basis.h"
#include "SlaterDet.h"
#include "UnitCell.h"
#include "jade.h"
#include "blas.h"
using namespace std;
////////////////////////////////////////////////////////////////////////////////
MLWFTransform::MLWFTransform(const SlaterDet& sd)
  : sd_(sd),
    cell_(sd.basis().cell()),
    ctxt_(sd.context()),
    bm_(BasisMapping(sd.basis())),
    maxsweep_(50),
    tol_(1.e-8)
{
  // All square matrices below take their dimensions from the coefficient
  // matrix of sd: nst = c().n(), nbl = c().nb() (presumably global size and
  // block size, given DoubleMatrix's (ctxt,m,n,mb,nb) argument order).
  const int nst = sd.c().n();
  const int nbl = sd.c().nb();

  // Six matrices a_[k] with matching diagonal buffers adiag_[k] of length
  // nst; update() fills them as cos/sin pairs for x, y and z.
  a_.resize(6);
  adiag_.resize(6);
  for ( int k = 0; k != 6; ++k )
  {
    a_[k] = new DoubleMatrix(ctxt_,nst,nst,nbl,nbl);
    adiag_[k].resize(nst);
  }

  // matrix handed to jade() by compute_transform()
  u_ = new DoubleMatrix(ctxt_,nst,nst,nbl,nbl);

  // Work copies of the Slater determinant; update() overwrites their
  // coefficients with the cos/sin-transformed states.
  SlaterDet** const work[] =
    { &sdcosx_, &sdcosy_, &sdcosz_, &sdsinx_, &sdsiny_, &sdsinz_ };
  for ( int k = 0; k != 6; ++k )
    *work[k] = new SlaterDet(sd_);
}
////////////////////////////////////////////////////////////////////////////////
MLWFTransform::~MLWFTransform(void)
{
  // Release everything the constructor allocated with new: the six
  // matrices a_[k], the matrix u_ and the six SlaterDet work copies.
  // NOTE(review): the class owns raw pointers; unless MLWFTransform.h
  // disables copy construction/assignment, a copy would double-delete.
  for ( int k = 0; k != 6; ++k )
    delete a_[k];
  delete u_;

  SlaterDet* const work[] =
    { sdcosx_, sdcosy_, sdcosz_, sdsinx_, sdsiny_, sdsinz_ };
  for ( int k = 0; k != 6; ++k )
    delete work[k];
}
////////////////////////////////////////////////////////////////////////////////
void MLWFTransform::update(void)
{
// recompute cos and sin matrices
const ComplexMatrix& c = sd_.c();
ComplexMatrix& ccosx = sdcosx_->c();
ComplexMatrix& csinx = sdsinx_->c();
ComplexMatrix& ccosy = sdcosy_->c();
ComplexMatrix& csiny = sdsiny_->c();
ComplexMatrix& ccosz = sdcosz_->c();
ComplexMatrix& csinz = sdsinz_->c();
// proxy real matrices cr, cc, cs
DoubleMatrix cr(c);
DoubleMatrix ccx(ccosx);
DoubleMatrix csx(csinx);
DoubleMatrix ccy(ccosy);
DoubleMatrix csy(csiny);
DoubleMatrix ccz(ccosz);
DoubleMatrix csz(csinz);
vector > zvec(bm_.zvec_size()),
zvec_cos(bm_.zvec_size()), zvec_sin(bm_.zvec_size()),
ct(bm_.np012loc()), ct_cos(bm_.np012loc()), ct_sin(bm_.np012loc());
for ( int i = 0; i < 6; i++ )
{
a_[i]->resize(c.n(), c.n(), c.nb(), c.nb());
adiag_[i].resize(c.n());
}
u_->resize(c.n(), c.n(), c.nb(), c.nb());
// loop over all local states
const int np0 = bm_.np0();
const int np1 = bm_.np1();
const int np2 = bm_.np2();
const int np01 = np0 * np1;
const int np2loc = bm_.np2loc();
const int nvec = bm_.nvec();
for ( int n = 0; n < c.nloc(); n++ )
{
const complex* f = c.cvalptr(n*c.mloc());
complex* fcx = ccosx.valptr(n*c.mloc());
complex* fsx = csinx.valptr(n*c.mloc());
complex* fcy = ccosy.valptr(n*c.mloc());
complex* fsy = csiny.valptr(n*c.mloc());
complex* fcz = ccosz.valptr(n*c.mloc());
complex* fsz = csinz.valptr(n*c.mloc());
// direction z
// map state to array zvec_
bm_.vector_to_zvec(&f[0],&zvec[0]);
for ( int ivec = 0; ivec < nvec; ivec++ )
{
const int ibase = ivec * np2;
compute_sincos(np2,&zvec[ibase],&zvec_cos[ibase],&zvec_sin[ibase]);
}
// map back zvec_cos to sdcos and zvec_sin to sdsin
bm_.zvec_to_vector(&zvec_cos[0],&fcz[0]);
bm_.zvec_to_vector(&zvec_sin[0],&fsz[0]);
// x direction
// map zvec to ct
bm_.transpose_fwd(&zvec[0],&ct[0]);
for ( int iz = 0; iz < np2loc; iz++ )
{
for ( int iy = 0; iy < np1; iy++ )
{
const int ibase = iz * np01 + iy * np0;
compute_sincos(np0,&ct[ibase],&ct_cos[ibase],&ct_sin[ibase]);
}
}
// transpose back ct_cos to zvec_cos
bm_.transpose_bwd(&ct_cos[0],&zvec_cos[0]);
// transpose back ct_sin to zvec_sin
bm_.transpose_bwd(&ct_sin[0],&zvec_sin[0]);
// map back zvec_cos to sdcos and zvec_sin to sdsin
bm_.zvec_to_vector(&zvec_cos[0],&fcx[0]);
bm_.zvec_to_vector(&zvec_sin[0],&fsx[0]);
// y direction
vector > c_tmp(np1),ccos_tmp(np1),csin_tmp(np1);
int one = 1;
int len = np1;
int stride = np0;
for ( int iz = 0; iz < np2loc; iz++ )
{
for ( int ix = 0; ix < np0; ix++ )
{
const int ibase = iz * np01 + ix;
zcopy(&len,&ct[ibase],&stride,&c_tmp[0],&one);
compute_sincos(np1,&c_tmp[0],&ccos_tmp[0],&csin_tmp[0]);
zcopy(&len,&ccos_tmp[0],&one,&ct_cos[ibase],&stride);
zcopy(&len,&csin_tmp[0],&one,&ct_sin[ibase],&stride);
}
}
// transpose back ct_cos to zvec_cos
bm_.transpose_bwd(&ct_cos[0],&zvec_cos[0]);
// transpose back ct_sin to zvec_sin
bm_.transpose_bwd(&ct_sin[0],&zvec_sin[0]);
// map back zvec_cos and zvec_sin
bm_.zvec_to_vector(&zvec_cos[0],&fcy[0]);
bm_.zvec_to_vector(&zvec_sin[0],&fsy[0]);
}
// dot products a_[0] = , a_[1] =
a_[0]->gemm('t','n',2.0,cr,ccx,0.0);
a_[0]->ger(-1.0,cr,0,ccx,0);
a_[1]->gemm('t','n',2.0,cr,csx,0.0);
a_[1]->ger(-1.0,cr,0,csx,0);
// dot products a_[2] = , a_[3] =
a_[2]->gemm('t','n',2.0,cr,ccy,0.0);
a_[2]->ger(-1.0,cr,0,ccy,0);
a_[3]->gemm('t','n',2.0,cr,csy,0.0);
a_[3]->ger(-1.0,cr,0,csy,0);
// dot products a_[4] = , a_[5] =
a_[4]->gemm('t','n',2.0,cr,ccz,0.0);
a_[4]->ger(-1.0,cr,0,ccz,0);
a_[5]->gemm('t','n',2.0,cr,csz,0.0);
a_[5]->ger(-1.0,cr,0,csz,0);
}
////////////////////////////////////////////////////////////////////////////////
void MLWFTransform::compute_transform(void)
{
  // Pass the six matrices a_, the matrix u_ and the diagonal buffers
  // adiag_ to jade() (declared in jade.h), with maxsweep_ and tol_ as its
  // iteration limits. Presumably jade performs the joint approximate
  // diagonalization, leaving the transform in u_ and the diagonals in
  // adiag_; confirm in jade.h.
  // The sweep count returned by jade() was previously stored in an unused
  // local (a -Wunused-variable warning); it is now discarded explicitly.
  (void) jade(maxsweep_,tol_,a_,*u_,adiag_);
}
////////////////////////////////////////////////////////////////////////////////
void MLWFTransform::compute_sincos(const int n, const complex* f,
complex* fc, complex* fs)
{
// fc[i] = 0.5 * ( f[i-1] + f[i+1] )
// fs[i] = (0.5/i) * ( f[i-1] - f[i+1] )
// i = 0
complex zp = f[n-1];
complex zm = f[1];
fc[0] = 0.5 * ( zp + zm );
complex zdiff = zp - zm;
fs[0] = 0.5 * complex(imag(zdiff),-real(zdiff));
for ( int i = 1; i < n-1; i++ )
{
const complex zzp = f[i-1];
const complex zzm = f[i+1];
fc[i] = 0.5 * ( zzp + zzm );
const complex zzdiff = zzp - zzm;
fs[i] = 0.5 * complex(imag(zzdiff),-real(zzdiff));
}
// i = n-1
zp = f[n-2];
zm = f[0];
fc[n-1] = 0.5 * ( zp + zm );
zdiff = zp - zm;
fs[n-1] = 0.5 * complex(imag(zdiff),-real(zdiff));
}
////////////////////////////////////////////////////////////////////////////////
// NOTE(review): this definition appears corrupted. Text between '<' and '>'
// seems to have been stripped, as elsewhere in this file (empty #include
// lines, missing template arguments). Evidence:
//  - the signature is center(int i), but the body uses an undeclared 'j';
//  - the body returns a scalar expression although the declared return type
//    is D3vector;
//  - the assert below appears to fuse several separate range checks
//    ('i=0' is an assignment).
// The visible body looks like a per-direction spread computation that
// belongs to a different function. Restore the original definitions (e.g.
// from version control) before building.
D3vector MLWFTransform::center(int i)
{
assert(i>=0 && i=0 && i=0 && j<3);
// c, s: entry i of adiag_[2*j] and adiag_[2*j+1]. update() fills a_[2*j]
// and a_[2*j+1] as the cos/sin pair for direction j; presumably adiag_
// follows the same ordering. Confirm in jade.h.
const double c = adiag_[2*j][i];
const double s = adiag_[2*j+1][i];
// fac = 1/|cell_.b(j)|. cell_.b(j) is presumably the j-th reciprocal
// lattice vector (confirm in UnitCell.h). No 1/pi (M_1_PI) factor is
// applied here.
const double fac = 1.0 / length(cell_.b(j));
return fac*fac * ( 1.0 - c*c - s*s );
}
////////////////////////////////////////////////////////////////////////////////
double MLWFTransform::spread2(int i)
{
assert(i>=0 & i