Commit 90f6ddce by Francois Gygi

Update gather/scatter in BasisMapping for MKL

parent 4f1caa41
......@@ -24,6 +24,14 @@
#include <cassert>
#include <cstring> // memset
using namespace std;
#if USE_GATHER_SCATTER
extern "C" {
// zgthr: x(i) = y(indx(i))
void zgthr_(int* n, complex<double>* y, complex<double>* x, int *indx);
// zsctr: y(indx(i)) = x(i)
void zsctr_(int* n, complex<double>* x, int* indx, complex<double>* y);
}
#endif
////////////////////////////////////////////////////////////////////////////////
BasisMapping::BasisMapping (const Basis &basis, int np0, int np1, int np2) :
......@@ -396,12 +404,11 @@ void BasisMapping::transpose_bwd(const complex<double> *zvec,
// scatter zvec to sbuf for transpose
#if USE_GATHER_SCATTER
// zsctr: y(indx(i)) = x(i)
// void zsctr_(int* n, complex<double>* x, int* indx, complex<double>* y);
{
complex<double>* y = &sbuf[0];
complex<double>* y = const_cast<complex<double>*>(&sbuf[0]);
complex<double>* x = const_cast<complex<double>*>(zvec);
int n = zvec_.size();
zsctr_(&n,x,&ipack_[0],y);
int n = zvec_size();
zsctr_(&n,x,const_cast<int*>(&ipack_[0]),y);
}
#else
const int len = zvec_size();
......@@ -451,9 +458,9 @@ void BasisMapping::transpose_bwd(const complex<double> *zvec,
// zsctr(n,x,indx,y): y(indx(i)) = x(i)
{
complex<double>* y = ct;
complex<double>* x = &rbuf[0];
complex<double>* x = const_cast<complex<double>*>(&rbuf[0]);
int n = rbuf.size();
zsctr_(&n,x,&iunpack_[0],y);
zsctr_(&n,x,const_cast<int*>(&iunpack_[0]),y);
}
#else
{
......@@ -483,12 +490,11 @@ void BasisMapping::transpose_fwd(const complex<double> *ct,
// gather ct into rbuf
#if USE_GATHER_SCATTER
// zgthr: x(i) = y(indx(i))
// void zgthr_(int* n, complex<double>* y, complex<double>* x, int*indx);
{
complex<double>* y = const_cast<complex<double>*>(ct);
complex<double>* x = &rbuf[0];
complex<double>* x = const_cast<complex<double>*>(&rbuf[0]);
int n = rbuf.size();
zgthr_(&n,y,x,&iunpack_[0]);
zgthr_(&n,y,x,const_cast<int*>(&iunpack_[0]));
}
#else
const int rbuf_size = rbuf.size();
......@@ -521,12 +527,11 @@ void BasisMapping::transpose_fwd(const complex<double> *ct,
#if USE_GATHER_SCATTER
// zgthr: x(i) = y(indx(i))
// void zgthr_(int* n, complex<double>* y, complex<double>* x, int*indx);
{
complex<double>* y = &sbuf[0];
complex<double>* y = const_cast<complex<double>*>(&sbuf[0]);
complex<double>* x = zvec;
int n = zvec_.size();
zgthr_(&n,y,x,&ipack_[0]);
int n = zvec_size();
zgthr_(&n,y,x,const_cast<int*>(&ipack_[0]));
}
#else
const int len = zvec_size();
......
......@@ -62,7 +62,7 @@ int main(int argc, char **argv)
<< " np1=" << basis.np(1)
<< " np2=" << basis.np(2) << endl;
cout << " basis.size=" << basis.size() << endl;
BasisMapping bmap(basis);
BasisMapping bmap(basis,basis.np(0),basis.np(1),basis.np(2));
cout << " zvec_size=" << bmap.zvec_size() << endl;
cout << " np012loc=" << bmap.np012loc() << endl;
......@@ -82,7 +82,7 @@ int main(int argc, char **argv)
bmap.vector_to_zvec(&f[0],&zvec[0]);
bmap.transpose_bwd(&zvec[0],&ct[0]);
for ( int k = 0; k < bmap.np2loc(); k++ )
for ( int k = 0; k < bmap.np2_loc(); k++ )
for ( int j = 0; j < bmap.np1(); j++ )
for ( int i = 0; i < bmap.np0(); i++ )
{
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment