Commit 8dc7100d by Francois Gygi

Rewrite BasisMapping and FourierTransform

Functions performing the various mappings and transposes are moved
to BasisMapping.
parent 57f8d918
......@@ -22,18 +22,16 @@
#include <iostream>
#include <cassert>
#include <cstring> // memset
using namespace std;
////////////////////////////////////////////////////////////////////////////////
BasisMapping::BasisMapping (const Basis &basis) : basis_(basis)
BasisMapping::BasisMapping (const Basis &basis, int np0, int np1, int np2) :
basis_(basis), np0_(np0), np1_(np1), np2_(np2)
{
nprocs_ = basis_.npes();
myproc_ = basis_.mype();
np0_ = basis.np(0);
np1_ = basis.np(1);
np2_ = basis.np(2);
np2_loc_.resize(nprocs_);
np2_first_.resize(nprocs_);
......@@ -127,6 +125,9 @@ BasisMapping::BasisMapping (const Basis &basis) : basis_(basis)
rdispl[iproc] = rdispl[iproc-1] + rcounts[iproc-1];
}
// check if the basis_ fits in the grid np0, np1, np2
assert(basis_.fits_in_grid(np0_,np1_,np2_));
if ( basis_.real() )
{
// compute index arrays ip_ and im_ for mapping vector->zvec
......@@ -389,7 +390,7 @@ BasisMapping::BasisMapping (const Basis &basis) : basis_(basis)
////////////////////////////////////////////////////////////////////////////////
void BasisMapping::transpose_bwd(const complex<double> *zvec,
complex<double> *ct)
complex<double> *ct) const
{
// Transpose zvec to ct
// scatter zvec to sbuf for transpose
......@@ -476,7 +477,7 @@ void BasisMapping::transpose_bwd(const complex<double> *zvec,
////////////////////////////////////////////////////////////////////////////////
void BasisMapping::transpose_fwd(const complex<double> *ct,
complex<double> *zvec)
complex<double> *zvec) const
{
// transpose ct to zvec
// gather ct into rbuf
......@@ -545,7 +546,7 @@ void BasisMapping::transpose_fwd(const complex<double> *ct,
////////////////////////////////////////////////////////////////////////////////
void BasisMapping::vector_to_zvec(const complex<double> *c,
complex<double> *zvec)
complex<double> *zvec) const
{
// map coefficients from the basis order to a zvec
const int ng = basis_.localsize();
......@@ -584,9 +585,43 @@ void BasisMapping::vector_to_zvec(const complex<double> *c,
pz[2*ip+1] = b;
}
}
////////////////////////////////////////////////////////////////////////////////
void BasisMapping::doublevector_to_zvec(const complex<double> *c1,
const complex<double> *c2, complex<double> *zvec) const
{
// map two real functions to zvec
assert(basis_.real());
memset((void*)&zvec[0],0,zvec_size()*sizeof(complex<double>));
double* const pz = (double*) &zvec[0];
const int ng = basis_.localsize();
const double* const pc1 = (double*) &c1[0];
const double* const pc2 = (double*) &c2[0];
#pragma omp parallel for
for ( int ig = 0; ig < ng; ig++ )
{
// const double a = c1[ig].real();
// const double b = c1[ig].imag();
// const double c = c2[ig].real();
// const double d = c2[ig].imag();
// zvec_[ip] = complex<double>(a-d, b+c);
// zvec_[im] = complex<double>(a+d, c-b);
const double a = pc1[2*ig];
const double b = pc1[2*ig+1];
const double c = pc2[2*ig];
const double d = pc2[2*ig+1];
const int ip = ip_[ig];
const int im = im_[ig];
pz[2*ip] = a - d;
pz[2*ip+1] = b + c;
pz[2*im] = a + d;
pz[2*im+1] = c - b;
}
}
////////////////////////////////////////////////////////////////////////////////
void BasisMapping::zvec_to_vector(const complex<double> *zvec,
complex<double> *c)
complex<double> *c) const
{
const int ng = basis_.localsize();
const double* const pz = (const double*) zvec;
......@@ -601,3 +636,35 @@ void BasisMapping::zvec_to_vector(const complex<double> *zvec,
pc[2*ig+1] = pz1;
}
}
////////////////////////////////////////////////////////////////////////////////
void BasisMapping::zvec_to_doublevector(const complex<double> *zvec,
complex<double> *c1, complex<double> *c2 ) const
{
// Mapping of zvec onto two real functions
assert(basis_.real());
const int ng = basis_.localsize();
const double* const pz = (double*) &zvec[0];
double* const pc1 = (double*) &c1[0];
double* const pc2 = (double*) &c2[0];
#pragma omp parallel for
for ( int ig = 0; ig < ng; ig++ )
{
// const double a = 0.5*zvec_[ip].real();
// const double b = 0.5*zvec_[ip].imag();
// const double c = 0.5*zvec_[im].real();
// const double d = 0.5*zvec_[im].imag();
// c1[ig] = complex<double>(a+c, b-d);
// c2[ig] = complex<double>(b+d, c-a);
const int ip = ip_[ig];
const int im = im_[ig];
const double a = pz[2*ip];
const double b = pz[2*ip+1];
const double c = pz[2*im];
const double d = pz[2*im+1];
pc1[2*ig] = 0.5 * ( a + c );
pc1[2*ig+1] = 0.5 * ( b - d );
pc2[2*ig] = 0.5 * ( b + d );
pc2[2*ig+1] = 0.5 * ( c - a );
}
}
......@@ -45,7 +45,7 @@ class BasisMapping
public:
BasisMapping (const Basis &basis);
BasisMapping (const Basis &basis, int np0, int np1, int np2);
int np0(void) const { return np0_; }
int np1(void) const { return np1_; }
int np2(void) const { return np2_; }
......@@ -56,14 +56,20 @@ class BasisMapping
// map a function c(G) to zvec_
void vector_to_zvec(const std::complex<double> *c,
std::complex<double> *zvec);
std::complex<double> *zvec) const;
// map two real functions c1(G) and c2(G) to zvec_
void doublevector_to_zvec(const std::complex<double> *c1,
const std::complex<double> *c2,std::complex<double> *zvec) const;
// map zvec_ to a function c(G)
void zvec_to_vector(const std::complex<double> *zvec,
std::complex<double> *c);
std::complex<double> *c) const;
// map zvec_ to two real functions c1(G) and c2(G)
void zvec_to_doublevector(const std::complex<double> *zvec,
std::complex<double> *c1, std::complex<double> *c2) const;
void transpose_bwd(const std::complex<double> *zvec,
std::complex<double> *ct);
std::complex<double> *ct) const;
void transpose_fwd(const std::complex<double> *ct,
std::complex<double> *zvec);
std::complex<double> *zvec) const;
};
#endif
......@@ -50,6 +50,7 @@
#endif
#include "Timer.h"
#include "BasisMapping.h"
class Basis;
......@@ -59,24 +60,17 @@ class FourierTransform
MPI_Comm comm_;
const Basis& basis_;
const BasisMapping bm_;
int nprocs_, myproc_;
int np0_,np1_,np2_;
const int np0_,np1_,np2_;
int ntrans0_,ntrans1_,ntrans2_;
int nvec_;
bool basis_fits_in_grid_;
std::vector<int> np2_loc_; // np2_loc_[iproc], iproc=0, nprocs_-1
std::vector<int> np2_first_; // np2_first_[iproc], iproc=0, nprocs_-1
std::vector<std::complex<double> > zvec_;
std::vector<int> scounts, sdispl, rcounts, rdispl;
std::vector<std::complex<double> > sbuf, rbuf;
std::vector<int> ifftp_, ifftm_;
std::vector<int> ipack_, iunpack_;
void init_lib(void);
#if USE_ESSL_FFT
......@@ -107,11 +101,6 @@ class FourierTransform
#error "Must define USE_FFTW2, USE_FFTW3, USE_ESSL_FFT or FFT_NOLIB"
#endif
void vector_to_zvec(const std::complex<double>* c);
void zvec_to_vector(std::complex<double>* c);
void doublevector_to_zvec(const std::complex<double>* c1,
const std::complex<double> *c2);
void zvec_to_doublevector(std::complex<double>* c1, std::complex<double>* c2);
void fwd(std::complex<double>* val);
void bwd(std::complex<double>* val);
......
......@@ -32,8 +32,9 @@ using namespace std;
////////////////////////////////////////////////////////////////////////////////
MLWFTransform::MLWFTransform(const SlaterDet& sd) : sd_(sd),
cell_(sd.basis().cell()), ctxt_(sd.context()), bm_(BasisMapping(sd.basis())),
maxsweep_(50), tol_(1.e-8)
cell_(sd.basis().cell()), ctxt_(sd.context()),
bm_(BasisMapping(sd.basis(),sd.basis().np(0),sd.basis().np(1),
sd.basis().np(2))), maxsweep_(50), tol_(1.e-8)
{
a_.resize(6);
adiag_.resize(6);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment