Commit 0b54b012 by Francois Gygi

Merge branch 'develop'

parents 1a76f0a4 04bf3838
...@@ -22,18 +22,24 @@ ...@@ -22,18 +22,24 @@
#include <iostream> #include <iostream>
#include <cassert> #include <cassert>
#include <cstring> // memset
using namespace std; using namespace std;
#if USE_GATHER_SCATTER
extern "C" {
// zgthr: x(i) = y(indx(i))
void zgthr_(int* n, complex<double>* y, complex<double>* x, int *indx);
// zsctr: y(indx(i)) = x(i)
void zsctr_(int* n, complex<double>* x, int* indx, complex<double>* y);
}
#endif
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
BasisMapping::BasisMapping (const Basis &basis) : basis_(basis) BasisMapping::BasisMapping (const Basis &basis, int np0, int np1, int np2) :
basis_(basis), np0_(np0), np1_(np1), np2_(np2)
{ {
nprocs_ = basis_.npes(); nprocs_ = basis_.npes();
myproc_ = basis_.mype(); myproc_ = basis_.mype();
np0_ = basis.np(0);
np1_ = basis.np(1);
np2_ = basis.np(2);
np2_loc_.resize(nprocs_); np2_loc_.resize(nprocs_);
np2_first_.resize(nprocs_); np2_first_.resize(nprocs_);
...@@ -127,6 +133,9 @@ BasisMapping::BasisMapping (const Basis &basis) : basis_(basis) ...@@ -127,6 +133,9 @@ BasisMapping::BasisMapping (const Basis &basis) : basis_(basis)
rdispl[iproc] = rdispl[iproc-1] + rcounts[iproc-1]; rdispl[iproc] = rdispl[iproc-1] + rcounts[iproc-1];
} }
// check if the basis_ fits in the grid np0, np1, np2
assert(basis_.fits_in_grid(np0_,np1_,np2_));
if ( basis_.real() ) if ( basis_.real() )
{ {
// compute index arrays ip_ and im_ for mapping vector->zvec // compute index arrays ip_ and im_ for mapping vector->zvec
...@@ -388,19 +397,18 @@ BasisMapping::BasisMapping (const Basis &basis) : basis_(basis) ...@@ -388,19 +397,18 @@ BasisMapping::BasisMapping (const Basis &basis) : basis_(basis)
} }
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
void BasisMapping::transpose_fwd(const complex<double> *zvec, void BasisMapping::transpose_bwd(const complex<double> *zvec,
complex<double> *ct) complex<double> *ct) const
{ {
// Transpose zvec to ct // Transpose zvec to ct
// scatter zvec to sbuf for transpose // scatter zvec to sbuf for transpose
#if USE_GATHER_SCATTER #if USE_GATHER_SCATTER
// zsctr: y(indx(i)) = x(i) // zsctr: y(indx(i)) = x(i)
// void zsctr_(int* n, complex<double>* x, int* indx, complex<double>* y);
{ {
complex<double>* y = &sbuf[0]; complex<double>* y = const_cast<complex<double>*>(&sbuf[0]);
complex<double>* x = const_cast<complex<double>*>(zvec); complex<double>* x = const_cast<complex<double>*>(zvec);
int n = zvec_.size(); int n = zvec_size();
zsctr_(&n,x,&ipack_[0],y); zsctr_(&n,x,const_cast<int*>(&ipack_[0]),y);
} }
#else #else
const int len = zvec_size(); const int len = zvec_size();
...@@ -420,39 +428,36 @@ void BasisMapping::transpose_fwd(const complex<double> *zvec, ...@@ -420,39 +428,36 @@ void BasisMapping::transpose_fwd(const complex<double> *zvec,
// segments of z-vectors are now in sbuf // segments of z-vectors are now in sbuf
// transpose // transpose
#if USE_MPI if ( nprocs_ == 1 )
int status = MPI_Alltoallv((double*)&sbuf[0],&scounts[0],&sdispl[0],
MPI_DOUBLE,(double*)&rbuf[0],&rcounts[0],&rdispl[0],MPI_DOUBLE,
basis_.comm());
if ( status != 0 )
{ {
cout << " BasisMapping: status = " << status << endl; assert(sbuf.size()==rbuf.size());
MPI_Abort(basis_.comm(),2); rbuf.swap(sbuf);
} }
#else else
assert(sbuf.size()==rbuf.size());
rbuf = sbuf;
#endif
// copy from rbuf to ct
// scatter index array iunpack
{ {
const int len = np012loc_; int status =
double* const pv = (double*) ct; MPI_Alltoallv((double*)&sbuf[0],(int*)&scounts[0],(int*)&sdispl[0],
for ( int i = 0; i < len; i++ ) MPI_DOUBLE,(double*)&rbuf[0],(int*)&rcounts[0],(int*)&rdispl[0],
MPI_DOUBLE, basis_.comm());
if ( status != 0 )
{ {
pv[2*i] = 0.0; cout << " BasisMapping: status = " << status << endl;
pv[2*i+1] = 0.0; MPI_Abort(basis_.comm(),2);
} }
} }
// clear ct
memset((void*)ct,0,np012loc_*sizeof(complex<double>));
// copy from rbuf to ct
// using scatter index array iunpack
#if USE_GATHER_SCATTER #if USE_GATHER_SCATTER
// zsctr(n,x,indx,y): y(indx(i)) = x(i) // zsctr(n,x,indx,y): y(indx(i)) = x(i)
{ {
complex<double>* y = ct; complex<double>* y = ct;
complex<double>* x = &rbuf[0]; complex<double>* x = const_cast<complex<double>*>(&rbuf[0]);
int n = rbuf.size(); int n = rbuf.size();
zsctr_(&n,x,&iunpack_[0],y); zsctr_(&n,x,const_cast<int*>(&iunpack_[0]),y);
} }
#else #else
{ {
...@@ -475,19 +480,18 @@ void BasisMapping::transpose_fwd(const complex<double> *zvec, ...@@ -475,19 +480,18 @@ void BasisMapping::transpose_fwd(const complex<double> *zvec,
} }
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
void BasisMapping::transpose_bwd(const complex<double> *ct, void BasisMapping::transpose_fwd(const complex<double> *ct,
complex<double> *zvec) complex<double> *zvec) const
{ {
// transpose back distributed array ct into zvec // transpose ct to zvec
// gather ct into rbuf // gather ct into rbuf
#if USE_GATHER_SCATTER #if USE_GATHER_SCATTER
// zgthr: x(i) = y(indx(i)) // zgthr: x(i) = y(indx(i))
// void zgthr_(int* n, complex<double>* y, complex<double>* x, int*indx);
{ {
complex<double>* y = const_cast<complex<double>*>(ct); complex<double>* y = const_cast<complex<double>*>(ct);
complex<double>* x = &rbuf[0]; complex<double>* x = const_cast<complex<double>*>(&rbuf[0]);
int n = rbuf.size(); int n = rbuf.size();
zgthr_(&n,y,x,&iunpack_[0]); zgthr_(&n,y,x,const_cast<int*>(&iunpack_[0]));
} }
#else #else
const int rbuf_size = rbuf.size(); const int rbuf_size = rbuf.size();
...@@ -505,27 +509,34 @@ void BasisMapping::transpose_bwd(const complex<double> *ct, ...@@ -505,27 +509,34 @@ void BasisMapping::transpose_bwd(const complex<double> *ct,
#endif #endif
// transpose // transpose
#if USE_MPI if ( nprocs_ == 1 )
int status = MPI_Alltoallv((double*)&rbuf[0],&rcounts[0],&rdispl[0], {
MPI_DOUBLE,(double*)&sbuf[0],&scounts[0],&sdispl[0],MPI_DOUBLE, assert(sbuf.size()==rbuf.size());
basis_.comm()); sbuf.swap(rbuf);
assert ( status == 0 ); }
#else else
assert(sbuf.size()==rbuf.size()); {
sbuf = rbuf; int status =
#endif MPI_Alltoallv((double*)&rbuf[0],(int*)&rcounts[0],(int*)&rdispl[0],
MPI_DOUBLE,(double*)&sbuf[0],(int*)&scounts[0],(int*)&sdispl[0],
MPI_DOUBLE, basis_.comm());
if ( status != 0 )
{
cout << " BasisMapping: status = " << status << endl;
MPI_Abort(basis_.comm(),2);
}
}
// segments of z-vectors are now in sbuf // segments of z-vectors are now in sbuf
// gather sbuf into zvec_ // gather sbuf into zvec_
#if USE_GATHER_SCATTER #if USE_GATHER_SCATTER
// zgthr: x(i) = y(indx(i)) // zgthr: x(i) = y(indx(i))
// void zgthr_(int* n, complex<double>* y, complex<double>* x, int*indx);
{ {
complex<double>* y = &sbuf[0]; complex<double>* y = const_cast<complex<double>*>(&sbuf[0]);
complex<double>* x = zvec; complex<double>* x = zvec;
int n = zvec_.size(); int n = zvec_size();
zgthr_(&n,y,x,&ipack_[0]); zgthr_(&n,y,x,const_cast<int*>(&ipack_[0]));
} }
#else #else
const int len = zvec_size(); const int len = zvec_size();
...@@ -545,17 +556,14 @@ void BasisMapping::transpose_bwd(const complex<double> *ct, ...@@ -545,17 +556,14 @@ void BasisMapping::transpose_bwd(const complex<double> *ct,
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
void BasisMapping::vector_to_zvec(const complex<double> *c, void BasisMapping::vector_to_zvec(const complex<double> *c,
complex<double> *zvec) complex<double> *zvec) const
{ {
// clear zvec
memset((void*)&zvec[0],0,zvec_size()*sizeof(complex<double>));
// map coefficients from the basis order to a zvec // map coefficients from the basis order to a zvec
const int ng = basis_.localsize();
const int len = zvec_size();
double* const pz = (double*) zvec; double* const pz = (double*) zvec;
for ( int i = 0; i < len; i++ ) const int ng = basis_.localsize();
{
pz[2*i] = 0.0;
pz[2*i+1] = 0.0;
}
const double* const pc = (const double*) c; const double* const pc = (const double*) c;
if ( basis_.real() ) if ( basis_.real() )
{ {
...@@ -584,9 +592,46 @@ void BasisMapping::vector_to_zvec(const complex<double> *c, ...@@ -584,9 +592,46 @@ void BasisMapping::vector_to_zvec(const complex<double> *c,
pz[2*ip+1] = b; pz[2*ip+1] = b;
} }
} }
////////////////////////////////////////////////////////////////////////////////
void BasisMapping::doublevector_to_zvec(const complex<double> *c1,
const complex<double> *c2, complex<double> *zvec) const
{
assert(basis_.real());
// clear zvec
memset((void*)&zvec[0],0,zvec_size()*sizeof(complex<double>));
// map two real functions to zvec
double* const pz = (double*) &zvec[0];
const int ng = basis_.localsize();
const double* const pc1 = (double*) &c1[0];
const double* const pc2 = (double*) &c2[0];
#pragma omp parallel for
for ( int ig = 0; ig < ng; ig++ )
{
// const double a = c1[ig].real();
// const double b = c1[ig].imag();
// const double c = c2[ig].real();
// const double d = c2[ig].imag();
// zvec_[ip] = complex<double>(a-d, b+c);
// zvec_[im] = complex<double>(a+d, c-b);
const double a = pc1[2*ig];
const double b = pc1[2*ig+1];
const double c = pc2[2*ig];
const double d = pc2[2*ig+1];
const int ip = ip_[ig];
const int im = im_[ig];
pz[2*ip] = a - d;
pz[2*ip+1] = b + c;
pz[2*im] = a + d;
pz[2*im+1] = c - b;
}
}
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
void BasisMapping::zvec_to_vector(const complex<double> *zvec, void BasisMapping::zvec_to_vector(const complex<double> *zvec,
complex<double> *c) complex<double> *c) const
{ {
const int ng = basis_.localsize(); const int ng = basis_.localsize();
const double* const pz = (const double*) zvec; const double* const pz = (const double*) zvec;
...@@ -601,3 +646,35 @@ void BasisMapping::zvec_to_vector(const complex<double> *zvec, ...@@ -601,3 +646,35 @@ void BasisMapping::zvec_to_vector(const complex<double> *zvec,
pc[2*ig+1] = pz1; pc[2*ig+1] = pz1;
} }
} }
////////////////////////////////////////////////////////////////////////////////
void BasisMapping::zvec_to_doublevector(const complex<double> *zvec,
complex<double> *c1, complex<double> *c2 ) const
{
// Mapping of zvec onto two real functions
assert(basis_.real());
const int ng = basis_.localsize();
const double* const pz = (double*) &zvec[0];
double* const pc1 = (double*) &c1[0];
double* const pc2 = (double*) &c2[0];
#pragma omp parallel for
for ( int ig = 0; ig < ng; ig++ )
{
// const double a = 0.5*zvec_[ip].real();
// const double b = 0.5*zvec_[ip].imag();
// const double c = 0.5*zvec_[im].real();
// const double d = 0.5*zvec_[im].imag();
// c1[ig] = complex<double>(a+c, b-d);
// c2[ig] = complex<double>(b+d, c-a);
const int ip = ip_[ig];
const int im = im_[ig];
const double a = pz[2*ip];
const double b = pz[2*ip+1];
const double c = pz[2*im];
const double d = pz[2*im+1];
pc1[2*ig] = 0.5 * ( a + c );
pc1[2*ig+1] = 0.5 * ( b - d );
pc2[2*ig] = 0.5 * ( b + d );
pc2[2*ig+1] = 0.5 * ( c - a );
}
}
...@@ -38,32 +38,41 @@ class BasisMapping ...@@ -38,32 +38,41 @@ class BasisMapping
std::vector<int> np2_first_; // np2_first_[iproc], iproc=0, nprocs_-1 std::vector<int> np2_first_; // np2_first_[iproc], iproc=0, nprocs_-1
std::vector<int> scounts, sdispl, rcounts, rdispl; std::vector<int> scounts, sdispl, rcounts, rdispl;
std::vector<std::complex<double> > sbuf, rbuf; mutable std::vector<std::complex<double> > sbuf, rbuf;
std::vector<int> ip_, im_; std::vector<int> ip_, im_;
std::vector<int> ipack_, iunpack_; std::vector<int> ipack_, iunpack_;
public: public:
BasisMapping (const Basis &basis); BasisMapping (const Basis &basis, int np0, int np1, int np2);
int np0(void) const { return np0_; } int np0(void) const { return np0_; }
int np1(void) const { return np1_; } int np1(void) const { return np1_; }
int np2(void) const { return np2_; } int np2(void) const { return np2_; }
int np2loc(void) const { return np2_loc_[myproc_]; } int np2_loc(void) const { return np2_loc_[myproc_]; }
int np2_loc(int iproc) const { return np2_loc_[iproc]; }
int np2_first(void) const { return np2_first_[myproc_]; }
int np2_first(int iproc) const { return np2_first_[iproc]; }
int np012loc(void) const { return np012loc_; } int np012loc(void) const { return np012loc_; }
int nvec(void) const { return nvec_; } int nvec(void) const { return nvec_; }
int zvec_size(void) const { return nvec_ * np2_; } int zvec_size(void) const { return nvec_ * np2_; }
// map a function c(G) to zvec_ // map a function c(G) to zvec_
void vector_to_zvec(const std::complex<double> *c, void vector_to_zvec(const std::complex<double> *c,
std::complex<double> *zvec); std::complex<double> *zvec) const;
// map two real functions c1(G) and c2(G) to zvec_
void doublevector_to_zvec(const std::complex<double> *c1,
const std::complex<double> *c2,std::complex<double> *zvec) const;
// map zvec_ to a function c(G) // map zvec_ to a function c(G)
void zvec_to_vector(const std::complex<double> *zvec, void zvec_to_vector(const std::complex<double> *zvec,
std::complex<double> *c); std::complex<double> *c) const;
// map zvec_ to two real functions c1(G) and c2(G)
void zvec_to_doublevector(const std::complex<double> *zvec,
std::complex<double> *c1, std::complex<double> *c2) const;
void transpose_fwd(const std::complex<double> *zvec, void transpose_bwd(const std::complex<double> *zvec,
std::complex<double> *ct); std::complex<double> *ct) const;
void transpose_bwd(const std::complex<double> *ct, void transpose_fwd(const std::complex<double> *ct,
std::complex<double> *zvec); std::complex<double> *zvec) const;
}; };
#endif #endif
...@@ -406,7 +406,7 @@ void Bisection::compute_transform(const SlaterDet& sd, int maxsweep, double tol) ...@@ -406,7 +406,7 @@ void Bisection::compute_transform(const SlaterDet& sd, int maxsweep, double tol)
jade(maxsweep,tol,amat_,*u_,adiag_); jade(maxsweep,tol,amat_,*u_,adiag_);
#ifdef TIMING #ifdef TIMING
if ( ctxt_.onpe0() ) if ( ctxt_.onpe0() )
cout << "Bisection::compute_transform: nsweep=" << nsweep cout << "Bisection::compute_transform:"
<< " maxsweep=" << maxsweep << " tol=" << tol << endl; << " maxsweep=" << maxsweep << " tol=" << tol << endl;
#endif #endif
......
...@@ -82,15 +82,15 @@ CPSampleStepper::~CPSampleStepper(void) ...@@ -82,15 +82,15 @@ CPSampleStepper::~CPSampleStepper(void)
void CPSampleStepper::step(int niter) void CPSampleStepper::step(int niter)
{ {
const bool onpe0 = s_.ctxt_.onpe0(); const bool onpe0 = s_.ctxt_.onpe0();
// CP dynamics is allowed only for all doubly occupied states
// check if states are all doubly occupied // check that there are no fractionally occupied states
const bool wf_double_occ = (s_.wf.nel() == 2 * s_.wf.nst()); // next line: (3-nspin) = 2 if nspin==1 and 1 if nspin==2
if ( !wf_double_occ ) if ( s_.wf.nel() != (( 3 - s_.wf.nspin() ) * s_.wf.nst()) )
{ {
if ( s_.ctxt_.onpe0() ) if ( s_.ctxt_.onpe0() )
{ {
cout << " CPSampleStepper::step:" cout << " CPSampleStepper::step:"
" not all states doubly occupied: cannot run" << endl; " fractionally occupied or empty states: cannot run" << endl;
} }
return; return;
} }
......
...@@ -1007,6 +1007,7 @@ double ExchangeOperator::compute_exchange_at_gamma_(const Wavefunction &wf, ...@@ -1007,6 +1007,7 @@ double ExchangeOperator::compute_exchange_at_gamma_(const Wavefunction &wf,
// if using bisection, localize the wave functions // if using bisection, localize the wave functions
if ( use_bisection_ ) if ( use_bisection_ )
{ {
tmb.reset();
tmb.start(); tmb.start();
int maxsweep = 50; int maxsweep = 50;
if ( s_.ctrl.debug.find("BISECTION_MAXSWEEP") != string::npos ) if ( s_.ctrl.debug.find("BISECTION_MAXSWEEP") != string::npos )
......
...@@ -45,11 +45,8 @@ ...@@ -45,11 +45,8 @@
#endif #endif
#endif #endif
#if USE_MPI
#include <mpi.h>
#endif
#include "Timer.h" #include "Timer.h"
#include "BasisMapping.h"
class Basis; class Basis;
...@@ -57,26 +54,15 @@ class FourierTransform ...@@ -57,26 +54,15 @@ class FourierTransform
{ {
private: private:
MPI_Comm comm_;
const Basis& basis_; const Basis& basis_;
int nprocs_, myproc_; const BasisMapping bm_;
int np0_,np1_,np2_; const int np0_,np1_,np2_;
int ntrans0_,ntrans1_,ntrans2_; const int nvec_;
int nvec_; int ntrans0_,ntrans1_,ntrans2_;
bool basis_fits_in_grid_;
std::vector<int> np2_loc_; // np2_loc_[iproc], iproc=0, nprocs_-1
std::vector<int> np2_first_; // np2_first_[iproc], iproc=0, nprocs_-1
std::vector<std::complex<double> > zvec_; std::vector<std::complex<double> > zvec_;
std::vector<int> scounts, sdispl, rcounts, rdispl;
std::vector<std::complex<double> > sbuf, rbuf;
std::vector<int> ifftp_, ifftm_;
std::vector<int> ipack_, iunpack_;
void init_lib(void); void init_lib(void);
#if USE_ESSL_FFT #if USE_ESSL_FFT
...@@ -107,11 +93,10 @@ class FourierTransform ...@@ -107,11 +93,10 @@ class FourierTransform
#error "Must define USE_FFTW2, USE_FFTW3, USE_ESSL_FFT or FFT_NOLIB" #error "Must define USE_FFTW2, USE_FFTW3, USE_ESSL_FFT or FFT_NOLIB"
#endif #endif
void vector_to_zvec(const std::complex<double>* c); void fxy(std::complex<double>* val);
void zvec_to_vector(std::complex<double>* c); void fxy_inv(std::complex<double>* val);
void doublevector_to_zvec(const std::complex<double>* c1, void fz(void);
const std::complex<double> *c2); void fz_inv(void);
void zvec_to_doublevector(std::complex<double>* c1, std::complex<double>* c2);
void fwd(std::complex<double>* val); void fwd(std::complex<double>* val);
void bwd(std::complex<double>* val); void bwd(std::complex<double>* val);
...@@ -119,7 +104,6 @@ class FourierTransform ...@@ -119,7 +104,6 @@ class FourierTransform
FourierTransform (const Basis &basis, int np0, int np1, int np2); FourierTransform (const Basis &basis, int np0, int np1, int np2);
~FourierTransform (); ~FourierTransform ();
MPI_Comm comm(void) const { return comm_; }
// backward: Fourier synthesis, compute real-space function // backward: Fourier synthesis, compute real-space function
// forward: Fourier analysis, compute Fourier coefficients // forward: Fourier analysis, compute Fourier coefficients
...@@ -139,13 +123,13 @@ class FourierTransform ...@@ -139,13 +123,13 @@ class FourierTransform
int np0() const { return np0_; } int np0() const { return np0_; }
int np1() const { return np1_; } int np1() const { return np1_; }
int np2() const { return np2_; } int np2() const { return np2_; }
int np2_loc() const { return np2_loc_[myproc_]; } int np2_loc(void) const { return bm_.np2_loc(); }
int np2_loc(int iproc) const { return np2_loc_[iproc]; } int np2_loc(int iproc) const { return bm_.np2_loc(iproc); }
int np2_first() const { return np2_first_[myproc_]; } int np2_first(void) const { return bm_.np2_first(); }
int np2_first(int iproc) const { return np2_first_[iproc]; } int np2_first(int iproc) const { return bm_.np2_first(iproc); }
long int np012() const { return ((long int)np0_) * np1_ * np2_; } long int np012() const { return ((long int)np0_) * np1_ * np2_; }
int np012loc(int iproc) const { return np0_ * np1_ * np2_loc_[iproc]; } int np012loc(int iproc) const { return np0_ * np1_ * np2_loc(iproc); }
int np012loc() const { return np0_ * np1_ * np2_loc_[myproc_]; } int np012loc(void) const { return np0_ * np1_ * np2_loc(); }
int index(int i, int j, int k) const int index(int i, int j, int k) const
{ return i + np0_ * ( j + np1_ * k ); } { return i + np0_ * ( j + np1_ * k ); }
int i(int ind) const { return ind % np0_; } int i(int ind) const { return ind % np0_; }
...@@ -153,11 +137,7 @@ class FourierTransform ...@@ -153,11 +137,7 @@ class FourierTransform
int k(int ind) const { return (ind / np0_) / np1_ + np2_first(); } int k(int ind) const { return (ind / np0_) / np1_ + np2_first(); }
void reset_timers(void); void reset_timers(void);
Timer tm_f_map, tm_f_fft, tm_f_pack, tm_f_mpi, tm_f_zero, tm_f_unpack, Timer tm_fwd, tm_bwd, tm_map_fwd, tm_map_bwd, tm_trans_fwd, tm_trans_bwd,
tm_b_map, tm_b_fft, tm_b_pack, tm_b_mpi, tm_b_zero, tm_b_unpack, tm_fxy, tm_fxy_inv, tm_fz, tm_fz_inv, tm_init;
tm_f_xy, tm_f_z, tm_f_x, tm_f_y,
tm_b_xy, tm_b_z, tm_b_x, tm_b_y,
tm_init, tm_b_com, tm_f_com;
}; };
#endif #endif
...@@ -32,8 +32,9 @@ using namespace std; ...@@ -32,8 +32,9 @@ using namespace std;
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
MLWFTransform::MLWFTransform(const SlaterDet& sd) : sd_(sd), MLWFTransform::MLWFTransform(const SlaterDet& sd) : sd_(sd),
cell_(sd.basis().cell()), ctxt_(sd.context()), bm_(BasisMapping(sd.basis())), cell_(sd.basis().cell()), ctxt_(sd.context()),
maxsweep_(50), tol_(1.e-8) bm_(BasisMapping(sd.basis(),sd.basis().np(0),sd.basis().np(1),
sd.basis().np(2))), maxsweep_(50), tol_(1.e-8)
{ {
a_.resize(6); a_.resize(6);
adiag_.resize(6); adiag_.resize(6);
...@@ -105,7 +106,7 @@ void MLWFTransform::update(void) ...@@ -105,7 +106,7 @@ void MLWFTransform::update(void)
const int np1 = bm_.np1(); const int np1 = bm_.np1();
const int np2 = bm_.np2(); const int np2 = bm_.np2();
const int np01 = np0 * np1; const int np01 = np0 * np1;
const int np2loc = bm_.np2loc(); const int np2loc = bm_.np2_loc();
const int nvec = bm_.nvec(); const int nvec = bm_.nvec();
for ( int n = 0; n < c.nloc(); n++ ) for ( int n = 0; n < c.nloc(); n++ )
{ {
...@@ -132,7 +133,7 @@ void MLWFTransform::update(void) ...@@ -132,7 +133,7 @@ void MLWFTransform::update(void)
// x direction // x direction
// map zvec to ct // map zvec to ct
bm_.transpose_fwd(&zvec[0],&ct[0]); bm_.transpose_bwd(&zvec[0],&ct[0]);
for ( int iz = 0; iz < np2loc; iz++ ) for ( int iz = 0; iz < np2loc; iz++ )
{ {
...@@ -142,10 +143,10 @@ void MLWFTransform::update(void) ...@@ -142,10 +143,10 @@ void MLWFTransform::update(void)
compute_sincos(np0,&ct[ibase],&ct_cos[ibase],&ct_sin[ibase]); compute_sincos(np0,&ct[ibase],&ct_cos[ibase],&ct_sin[ibase]);
} }
} }
// transpose back ct_cos to zvec_cos // transpose ct_cos to zvec_cos
bm_.transpose_bwd(&ct_cos[0],&zvec_cos[0]); bm_.transpose_fwd(&ct_cos[0],&zvec_cos[0]);
// transpose back ct_sin to zvec_sin // transpose ct_sin to zvec_sin
bm_.transpose_bwd(&ct_sin[0],&zvec_sin[0]); bm_.transpose_fwd(&ct_sin[0],&zvec_sin[0]);
// map back zvec_cos to sdcos and zvec_sin to sdsin // map back zvec_cos to sdcos and zvec_sin to sdsin
bm_.zvec_to_vector(&zvec_cos[0],&fcx[0]); bm_.zvec_to_vector(&zvec_cos[0],&fcx[0]);
...@@ -167,10 +168,10 @@ void MLWFTransform::update(void) ...@@ -167,10 +168,10 @@ void MLWFTransform::update(void)
zcopy(&len,&csin_tmp[0],&one,&ct_sin[ibase],&stride); zcopy(&len,&csin_tmp[0],&one,&ct_sin[ibase],&stride);
} }
} }
// transpose back ct_cos to zvec_cos // transpose ct_cos to zvec_cos
bm_.transpose_bwd(&ct_cos[0],&zvec_cos[0]); bm_.transpose_fwd(&ct_cos[0],&zvec_cos[0]);
// transpose back ct_sin to zvec_sin // transpose ct_sin to zvec_sin
bm_.transpose_bwd(&ct_sin[0],&zvec_sin[0]); bm_.transpose_fwd(&ct_sin[0],&zvec_sin[0]);
// map back zvec_cos and zvec_sin // map back zvec_cos and zvec_sin
bm_.zvec_to_vector(&zvec_cos[0],&fcy[0]); bm_.zvec_to_vector(&zvec_cos[0],&fcy[0]);
......
...@@ -19,5 +19,5 @@ ...@@ -19,5 +19,5 @@
#include "release.h" #include "release.h"
std::string release(void)