【OTFFT: High Speed FFT library】

【Stockham FFT の AVX による最適化】

 前ページ で配列へのアクセスを減らした Stockham のアルゴリズムを示しましが、 これに Intel AVX を適用するとさらに最適化することができます。とりあえず、 リスト10:配列へのアクセスを減らした Stockham のアルゴリズムを再掲しておきす。このアルゴリズムをベースに、各演算を AVX 化します。

リスト10:配列へのアクセスを減らした Stockham のアルゴリズム
#include <complex>
#include <cmath>

typedef std::complex<double> complex_t;

void fft0(int n, int s, bool eo, complex_t* x, complex_t* y)
// n  : 系列長
// s  : ストライド
// eo : eo == 0 か false なら x が出力、eo == 1 か true なら y が出力
// x  : フーリエ変換する入力系列(eo == 0 のとき出力)
// y  : 作業用配列(eo == 1 のとき出力)
{
    const int m = n/2;
    const double theta0 = 2*M_PI/n;

    if (n == 2) {
        complex_t* z = eo ? y : x;
        for (int q = 0; q < s; q++) {
            const complex_t a = x[q + 0];
            const complex_t b = x[q + s];
            z[q + 0] = a + b;
            z[q + s] = a - b;
        }
    }
    else if (n >= 4) {
        for (int p = 0; p < m; p++) {
            const complex_t wp = complex_t(cos(p*theta0), -sin(p*theta0));
            for (int q = 0; q < s; q++) {
                const complex_t a = x[q + s*(p + 0)];
                const complex_t b = x[q + s*(p + m)];
                y[q + s*(2*p + 0)] =  a + b;
                y[q + s*(2*p + 1)] = (a - b) * wp;
            }
        }
        fft0(n/2, 2*s, !eo, y, x);
    }
}

void fft(int N, complex_t* x) // フーリエ変換
// N : 系列長
// x : フーリエ変換する系列(入出力)
{
    complex_t* y = new complex_t[N];
    fft0(N, 1, 0, x, y);
    delete[] y;
    for (int k = 0; k < N; k++) x[k] /= N;
}

void ifft(int N, complex_t* x) // 逆フーリエ変換
// N : 系列長
// x : 逆フーリエ変換する系列(入出力)
{
    for (int p = 0; p < N; p++) x[p] = conj(x[p]);
    complex_t* y = new complex_t[N];
    fft0(N, 1, 0, x, y);
    delete[] y;
    for (int k = 0; k < N; k++) x[k] = conj(x[k]);
}

 まずは、比較的素直に AVX 化できる部分のみ AVX 化したバージョンを示します。 以下のようになります。

リスト11:AVX 化した Stockham のアルゴリズム
#include <complex>
#include <cmath>
#include <immintrin.h>

struct complex_t {
    double Re, Im;
    complex_t(const double& x, const double& y) : Re(x), Im(y) {}
};

inline complex_t operator+(const complex_t& x, const complex_t& y)
{
    return complex_t(x.Re + y.Re, x.Im + y.Im);
}

inline complex_t operator-(const complex_t& x, const complex_t& y)
{
    return complex_t(x.Re - y.Re, x.Im - y.Im);
}

inline complex_t operator*(const complex_t& x, const complex_t& y)
{
    return complex_t(x.Re*y.Re - x.Im*y.Im, x.Re*y.Im + x.Im*y.Re);
}

__m256d mulpz2(const __m256d ab, const __m256d xy) // __m256d 型の複素数の乗算
{
    const __m256d aa = _mm256_unpacklo_pd(ab, ab);
    const __m256d bb = _mm256_unpackhi_pd(ab, ab);
    const __m256d yx = _mm256_shuffle_pd(xy, xy, 5);
    return _mm256_addsub_pd(_mm256_mul_pd(aa, xy), _mm256_mul_pd(bb, yx));
}

void fft0(int n, int s, bool eo, complex_t* x, complex_t* y)
// n  : 系列長
// s  : ストライド
// eo : eo == 0 か false なら x が出力、eo == 1 か true なら y が出力
// x  : フーリエ変換する入力系列(eo == 0 のとき出力)
// y  : 作業用配列(eo == 1 のとき出力)
{
    const int m = n/2;
    const double theta0 = 2*M_PI/n;

    if (n == 2) {
        complex_t* z = eo ? y : x;
        if (s == 1) {
            double* xd = &x->Re;
            double* zd = &z->Re;
            const __m128d a = _mm_load_pd(xd + 2*0);
            const __m128d b = _mm_load_pd(xd + 2*1);
            _mm_store_pd(zd + 2*0, _mm_add_pd(a, b));
            _mm_store_pd(zd + 2*1, _mm_sub_pd(a, b));
        }
        else {
            for (int q = 0; q < s; q += 2) {
                double* xd = &(x + q)->Re;
                double* zd = &(z + q)->Re;
                const __m256d a = _mm256_load_pd(xd + 2*0);
                const __m256d b = _mm256_load_pd(xd + 2*s);
                _mm256_store_pd(zd + 2*0, _mm256_add_pd(a, b));
                _mm256_store_pd(zd + 2*s, _mm256_sub_pd(a, b));
            }
        }
    }
    else if (n >= 4) {
        if (s == 1) {
            for (int p = 0; p < m; p++) {
                const complex_t wp = complex_t(cos(p*theta0), -sin(p*theta0));
                const complex_t a = x[p + 0];
                const complex_t b = x[p + m];
                y[2*p + 0] =  a + b;
                y[2*p + 1] = (a - b) * wp;
            }
        }
        else {
            for (int p = 0; p < m; p++) {
                const double cs = cos(p*theta0);
                const double sn = sin(p*theta0);
                const __m256d wp = _mm256_setr_pd(cs, -sn, cs, -sn);
                for (int q = 0; q < s; q += 2) {
                    double* xd = &(x + q)->Re;
                    double* yd = &(y + q)->Re;
                    const __m256d a = _mm256_load_pd(xd + 2*s*(p + 0));
                    const __m256d b = _mm256_load_pd(xd + 2*s*(p + m));
                    _mm256_store_pd(yd + 2*s*(2*p + 0),            _mm256_add_pd(a, b));
                    _mm256_store_pd(yd + 2*s*(2*p + 1), mulpz2(wp, _mm256_sub_pd(a, b)));
                }
            }
        }
        fft0(n/2, 2*s, !eo, y, x);
    }
}

void fft(int N, std::complex<double>* x) // フーリエ変換
// N : 系列長
// x : フーリエ変換する系列(入出力)
{
    complex_t* y = (complex_t*) _mm_malloc(N*sizeof(complex_t), 32);
    complex_t* z = (complex_t*) _mm_malloc(N*sizeof(complex_t), 32);
    for (int p = 0; p < N; p++) {
        y[p].Re = x[p].real();
        y[p].Im = x[p].imag();
    }
    fft0(N, 1, 0, y, z);
    for (int k = 0; k < N; k++)
        x[k] = std::complex<double>(y[k].Re/N, y[k].Im/N);
    _mm_free(z);
    _mm_free(y);
}

void ifft(int N, std::complex<double>* x) // 逆フーリエ変換
// N : 系列長
// x : 逆フーリエ変換する系列(入出力)
{
    complex_t* y = (complex_t*) _mm_malloc(N*sizeof(complex_t), 32);
    complex_t* z = (complex_t*) _mm_malloc(N*sizeof(complex_t), 32);
    for (int p = 0; p < N; p++) {
        y[p].Re =  x[p].real();
        y[p].Im = -x[p].imag();
    }
    fft0(N, 1, 0, y, z);
    for (int k = 0; k < N; k++)
        x[k] = std::complex<double>(y[k].Re, -y[k].Im);
    _mm_free(z);
    _mm_free(y);
}

 次に、できうる限り AVX 化したバージョンを示します。以下のようになります。

リスト12:完全に AVX 化した Stockham のアルゴリズム
#include <complex>
#include <cmath>
#include <immintrin.h>

struct complex_t { double Re, Im; };

__m256d mulpz2(const __m256d ab, const __m256d xy) // __m256d 型の複素数の乗算
{
    const __m256d aa = _mm256_unpacklo_pd(ab, ab);
    const __m256d bb = _mm256_unpackhi_pd(ab, ab);
    const __m256d yx = _mm256_shuffle_pd(xy, xy, 5);
    return _mm256_addsub_pd(_mm256_mul_pd(aa, xy), _mm256_mul_pd(bb, yx));
}

void fft0(int n, int s, bool eo, complex_t* x, complex_t* y)
// n  : 系列長
// s  : ストライド
// eo : eo == 0 か false なら x が出力、eo == 1 か true なら y が出力
// x  : フーリエ変換する入力系列(eo == 0 のとき出力)
// y  : 作業用配列(eo == 1 のとき出力)
{
    const int m = n/2;
    const double theta0 = 2*M_PI/n;

    if (n == 2) {
        complex_t* z = eo ? y : x;
        if (s == 1) {
            double* xd = &x->Re;
            double* zd = &z->Re;
            const __m128d a = _mm_load_pd(xd + 2*0);
            const __m128d b = _mm_load_pd(xd + 2*1);
            _mm_store_pd(zd + 2*0, _mm_add_pd(a, b));
            _mm_store_pd(zd + 2*1, _mm_sub_pd(a, b));
        }
        else {
            for (int q = 0; q < s; q += 2) {
                double* xd = &(x + q)->Re;
                double* zd = &(z + q)->Re;
                const __m256d a = _mm256_load_pd(xd + 2*0);
                const __m256d b = _mm256_load_pd(xd + 2*s);
                _mm256_store_pd(zd + 2*0, _mm256_add_pd(a, b));
                _mm256_store_pd(zd + 2*s, _mm256_sub_pd(a, b));
            }
        }
    }
    else if (n >= 4) {
        if (s == 1) {
            for (int p = 0; p < m; p += 2) {
                const double cs0 = cos((p+0)*theta0);
                const double sn0 = sin((p+0)*theta0);
                const double cs1 = cos((p+1)*theta0);
                const double sn1 = sin((p+1)*theta0);
                const __m256d wp = _mm256_setr_pd(cs0, -sn0, cs1, -sn1);
                double* xd = &(x + p)->Re;
                double* yd = &(y + 2*p)->Re;
                const __m256d a = _mm256_load_pd(xd + 2*0);
                const __m256d b = _mm256_load_pd(xd + 2*m);
                const __m256d aA =            _mm256_add_pd(a, b);
                const __m256d bB = mulpz2(wp, _mm256_sub_pd(a, b));
                const __m256d ab = _mm256_permute2f128_pd(aA, bB, 0x20);
                const __m256d AB = _mm256_permute2f128_pd(aA, bB, 0x31);
                _mm256_store_pd(yd + 2*0, ab);
                _mm256_store_pd(yd + 2*2, AB);
            }
        }
        else {
            for (int p = 0; p < m; p++) {
                const double cs = cos(p*theta0);
                const double sn = sin(p*theta0);
                const __m256d wp = _mm256_setr_pd(cs, -sn, cs, -sn);
                for (int q = 0; q < s; q += 2) {
                    double* xd = &(x + q)->Re;
                    double* yd = &(y + q)->Re;
                    const __m256d a = _mm256_load_pd(xd + 2*s*(p + 0));
                    const __m256d b = _mm256_load_pd(xd + 2*s*(p + m));
                    _mm256_store_pd(yd + 2*s*(2*p + 0),            _mm256_add_pd(a, b));
                    _mm256_store_pd(yd + 2*s*(2*p + 1), mulpz2(wp, _mm256_sub_pd(a, b)));
                }
            }
        }
        fft0(n/2, 2*s, !eo, y, x);
    }
}

void fft(int N, std::complex<double>* x) // フーリエ変換
// N : 系列長
// x : フーリエ変換する系列(入出力)
{
    complex_t* y = (complex_t*) _mm_malloc(N*sizeof(complex_t), 32);
    complex_t* z = (complex_t*) _mm_malloc(N*sizeof(complex_t), 32);
    for (int p = 0; p < N; p++) {
        y[p].Re = x[p].real();
        y[p].Im = x[p].imag();
    }
    fft0(N, 1, 0, y, z);
    for (int k = 0; k < N; k++)
        x[k] = std::complex<double>(y[k].Re/N, y[k].Im/N);
    _mm_free(z);
    _mm_free(y);
}

void ifft(int N, std::complex<double>* x) // 逆フーリエ変換
// N : 系列長
// x : 逆フーリエ変換する系列(入出力)
{
    complex_t* y = (complex_t*) _mm_malloc(N*sizeof(complex_t), 32);
    complex_t* z = (complex_t*) _mm_malloc(N*sizeof(complex_t), 32);
    for (int p = 0; p < N; p++) {
        y[p].Re =  x[p].real();
        y[p].Im = -x[p].imag();
    }
    fft0(N, 1, 0, y, z);
    for (int k = 0; k < N; k++)
        x[k] = std::complex<double>(y[k].Re, -y[k].Im);
    _mm_free(z);
    _mm_free(y);
}