Commit 3c5fea3f by Francois Gygi

merged fftw3 feature branch.

enables USE_FFTW2, USEFFTW3 and USE_ESSL_FFT


git-svn-id: http://qboxcode.org/svn/qb/trunk@1548 cba15fb0-1239-40c8-b417-11db7ca47a34
parent 19452efb
......@@ -110,6 +110,9 @@ Bisection::Bisection(const SlaterDet& sd, int nlevels[3])
if ( np_[idim] % base != 0 ) np_[idim] += base/2;
}
}
while (!sd.basis().factorizable(np_[0])) np_[0] += (1<<nlevels[0]);
while (!sd.basis().factorizable(np_[1])) np_[1] += (1<<nlevels[1]);
while (!sd.basis().factorizable(np_[2])) np_[2] += (1<<nlevels[2]);
// number of grid points of augmented grid for normalization
ft_ = new FourierTransform(sd.basis(),np_[0],np_[1],np_[2]);
......
......@@ -22,7 +22,15 @@
#include <complex>
#include <vector>
#if USE_FFTW
#if !( defined(USE_FFTW2) || defined(USE_FFTW3) || defined(USE_ESSL_FFT) || defined(FFT_NOLIB) )
#error "Must define USE_FFTW2, USE_FFTW3, USE_ESSL_FFT or FFT_NOLIB"
#endif
#if defined(USE_FFTW2) && defined(USE_FFTW3)
#error "Cannot define USE_FFTW2 and USE_FFTW3"
#endif
#if USE_FFTW2
#if USE_DFFTW
#include "dfftw.h"
#else
......@@ -30,6 +38,13 @@
#endif
#endif
#if USE_FFTW3
#include "fftw3.h"
#if USE_FFTW3MKL
#include "fftw3_mkl.h"
#endif
#endif
#if USE_MPI
#include <mpi.h>
#endif
......@@ -63,21 +78,32 @@ class FourierTransform
void init_lib(void);
#if USE_ESSL
#if USE_ESSL_FFT
#if USE_ESSL_2DFFT
std::vector<double> aux1xyf;
std::vector<double> aux1xyb;
int naux1xy;
std::vector<double> aux1xyf,aux1zf;
std::vector<double> aux1xyb,aux1zb;
std::vector<double> aux2;
int naux1xy,naux1z,naux2;
#else
std::vector<double> aux1xf, aux1yf, aux1zf;
std::vector<double> aux1xb, aux1yb, aux1zb;
std::vector<double> aux2;
int naux1x,naux1y,naux1z,naux2;
#endif
#elif USE_FFTW || USE_FFTW3
#elif USE_FFTW2
fftw_plan fwplan0,fwplan1,fwplan2,bwplan0,bwplan1,bwplan2;
#elif USE_FFTW3
//plans for np2_
fftw_plan fwplan, bwplan;
#if defined(USE_FFTW3_2D) || defined(USE_FFTW3_THREADS)
fftw_plan fwplan2d, bwplan2d;
#else
fftw_plan fwplanx, fwplany, bwplanx, bwplany;
#endif
#elif defined(FFT_NOLIB)
// no library
#else
#error "Must define USE_FFTW2, USE_FFTW3, USE_ESSL_FFT or FFT_NOLIB"
#endif
void vector_to_zvec(const std::complex<double>* c);
......@@ -124,6 +150,10 @@ class FourierTransform
void reset_timers(void);
Timer tm_f_map, tm_f_fft, tm_f_pack, tm_f_mpi, tm_f_zero, tm_f_unpack,
tm_b_map, tm_b_fft, tm_b_pack, tm_b_mpi, tm_b_zero, tm_b_unpack;
tm_b_map, tm_b_fft, tm_b_pack, tm_b_mpi, tm_b_zero, tm_b_unpack,
tm_f_xy, tm_f_z, tm_f_x, tm_f_y,
tm_b_xy, tm_b_z, tm_b_x, tm_b_y,
tm_init, tm_b_com, tm_f_com;
};
#endif
......@@ -15,27 +15,61 @@
// sinft.C
//
////////////////////////////////////////////////////////////////////////////////
// $Id: sinft.C,v 1.4 2008-09-08 15:56:20 fgygi Exp $
#include "sinft.h"
#include <math.h>
#include <assert.h>
#if USE_FFTW
#include <vector>
#include <complex>
using namespace std;
#if USE_FFTW2
#if USE_DFFTW
#include "dfftw.h"
#else
#include "fftw.h"
#endif
#elif USE_FFTW3
#include "fftw3.h"
#elif USE_ESSL_FFT
extern "C" {
void dcft_(int *initflag, std::complex<double> *x, int *inc2x, int *inc3x,
std::complex<double> *y, int *inc2y, int *inc3y,
int *length, int *ntrans, int *isign,
double *scale, double *aux1, int *naux1,
double *aux2, int *naux2);
}
#else
// no FFT library
void cfft ( int idir, complex<double> *z1, complex<double> *z2, int n,
int *inzee );
void fftstp ( int idir, complex<double> *zin, int after,
int now, int before, complex<double> *zout );
#endif
#include <vector>
#include <complex>
using namespace std;
void sinft(int n, double *y)
{
vector<complex<double> > zin(2*n), zout(2*n);
#if defined(USE_FFTW2) || defined(USE_FFTW3)
fftw_plan fwplan;
#endif
#if USE_FFTW2
fwplan = fftw_create_plan(2*n,FFTW_FORWARD,FFTW_ESTIMATE);
vector<complex<double> > zin(2*n), zout(2*n);
#elif USE_FFTW3
fwplan = fftw_plan_dft_1d(2*n,(fftw_complex*)&zin[0],(fftw_complex*)&zout[0],
FFTW_FORWARD, FFTW_ESTIMATE);
#elif USE_ESSL_FFT
int np = 2 * n;
int naux1 = (int) (30000 + 2.28 * np);
std::vector<double> aux1(naux1);
int ntrans = 1;
int naux2 = (int) (20000 + 2.28 * np + (256 + 2*np)*min(64,ntrans));
std::vector<double> aux2(naux2);
#else
// no FFT library
// no initialization needed
#endif
zin[0] = 0.0;
for ( int i = 1; i < n; i++ )
{
......@@ -43,20 +77,61 @@ void sinft(int n, double *y)
zin[i] = t;
zin[2*n-i] = -t;
}
#if USE_FFTW2
fftw_one(fwplan,(fftw_complex*)&zin[0],(fftw_complex*)&zout[0]);
#elif USE_FFTW3
fftw_execute(fwplan);
#elif USE_ESSL_FFT
// initialize forward transform
int initflag = 1;
int inc1 = 1, inc2 = np;
int isign = 1;
double scale = 1.0;
complex<double> *p = 0;
dcft_(&initflag,p,&inc1,&inc2,p,&inc1,&inc2,&np,&ntrans,
&isign,&scale,&aux1[0],&naux1,&aux2[0],&naux2);
// call transform
initflag = 0;
dcft_(&initflag,&zin[0],&inc1,&inc2,&zout[0],&inc1,&inc2,&np,&ntrans,
&isign,&scale,&aux1[0],&naux1,&aux2[0],&naux2);
#else
// no FFT library
int idir = 1, inzee = 1, np = 2*n;
cfft ( idir, &zin[0],&zout[0],np, &inzee );
#endif
for ( int i = 0; i < n; i++ )
{
y[i] = -0.5 * imag(zout[i]);
}
#if defined(USE_FFTW2) || defined(USE_FFTW3)
fftw_destroy_plan(fwplan);
#endif
}
void cosft1(int n, double *y)
{
/* Note: the array y contains n+1 elements */
vector<complex<double> > zin(2*n), zout(2*n);
#if defined(USE_FFTW2) || defined(USE_FFTW3)
fftw_plan fwplan;
#endif
#if USE_FFTW2
fwplan = fftw_create_plan(2*n,FFTW_FORWARD,FFTW_ESTIMATE);
vector<complex<double> > zin(2*n), zout(2*n);
#elif USE_FFTW3
fwplan = fftw_plan_dft_1d(2*n,(fftw_complex*)&zin[0],(fftw_complex*)&zout[0],
FFTW_FORWARD, FFTW_ESTIMATE);
#elif USE_ESSL_FFT
int np = 2 * n;
int naux1 = (int) (30000 + 2.28 * np);
std::vector<double> aux1(naux1);
int ntrans = 1;
int naux2 = (int) (20000 + 2.28 * np + (256 + 2*np)*min(64,ntrans));
std::vector<double> aux2(naux2);
#else
// no FFT library
// no initialization needed
#endif
zin[0] = y[0];
for ( int i = 1; i < n+1; i++ )
......@@ -65,11 +140,35 @@ void cosft1(int n, double *y)
zin[i] = t;
zin[2*n-i] = t;
}
#if USE_FFTW2
fftw_one(fwplan,(fftw_complex*)&zin[0],(fftw_complex*)&zout[0]);
#elif USE_FFTW3
fftw_execute(fwplan);
#elif USE_ESSL_FFT
// initialize forward transform
int initflag = 1;
int inc1 = 1, inc2 = np;
int isign = 1;
double scale = 1.0;
complex<double> *p = 0;
dcft_(&initflag,p,&inc1,&inc2,p,&inc1,&inc2,&np,&ntrans,
&isign,&scale,&aux1[0],&naux1,&aux2[0],&naux2);
// call transform
initflag = 0;
dcft_(&initflag,&zin[0],&inc1,&inc2,&zout[0],&inc1,&inc2,&np,&ntrans,
&isign,&scale,&aux1[0],&naux1,&aux2[0],&naux2);
#else
// no FFT library
int idir = 1, inzee = 1, np = 2*n;
cfft ( idir, &zin[0],&zout[0],np, &inzee );
#endif
y[0] = 0.5 * real(zout[0]);
for ( int i = 1; i < n; i++ )
{
y[i] = 0.5 * real(zout[i]);
}
#if defined(USE_FFTW2) || defined(USE_FFTW3)
fftw_destroy_plan(fwplan);
#endif
}
......@@ -16,6 +16,7 @@
#include <iostream>
#include <iomanip>
#include <cstdlib>
using namespace std;
#include "Basis.h"
......@@ -142,7 +143,7 @@ int main(int argc, char **argv)
cout << " backward done " << endl;
MPI_Barrier(MPI_COMM_WORLD);
#if 0
#if 1
tm.reset();
ft2.reset_timers();
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment