情報系の手考ノート

数学とか情報系の技術とか調べたり勉強したりしてメモしていきます.

FFTを実装してみた (C++)

FFTの式を導いたので,C++で実装してみました。

前準備

前の記事から,fをデータ数N = 2^ kとしてDFTした結果F

\begin{aligned}
F(x + \frac{N}{2} b_0) = f_{\frac{N}{2}}(x) + (-1)^{b_0} w^{x}_{N} f_{\frac{N}{2}}(x + \frac{N}{2})
\qquad (0 \le x \le \frac{N}{2}-1)
\end{aligned}
\begin{aligned}
f_{\frac{N}{2^ l}}(x + \frac{N}{2^{l+1}}(2^ l &b_0 + \cdots + 2b_{l-1} + b_l)) \\
=& f_{\frac{N}{2^{l+1}}}(x + \frac{N}{2^{l+1}}(2^ l b_0 + \cdots + 2b_{l-1})) \\
&+ (-1)^{b_l} w^{x}_{\frac{N}{2^{l}}} f_{\frac{N}{2^{l+1}}}(x + \frac{N}{2^{l+1}}(2^ l b_0 + \cdots + 2b_{l-1} + 1)) \\
& \qquad \qquad (1 \le l \le k-1, \quad 0 \le x \le \frac{N}{2^{l+1}}-1)
\end{aligned}
\begin{aligned}
f_{\frac{N}{2^ k}}(2^{k-1} b_0 + \cdots + b_{k-1}) = f(2^{k-1} b_{k-1} + \cdots + b_0)
\end{aligned}

と書くことができます。 なお,w^{a}_{N} = e^{-\frac{2 \pi a}{N}i}b_0, b_1, \cdots , b_{k-1} \in \{ 0, 1 \}です。

このまま実装してもいいですが,これだと少しわかりにくいしやりにくいので少し式を書き換えます。

まず,2進数に対して

\begin{aligned}
(a_{N-1}, a_{N-2}, \cdots, a_{1}, a{0})_2 = 2^{N-1} a_{N-1} + 2^{N-2} a_{N-2} + \cdots + 2^{1} a_{1} + 2^{0} a_{0}
\end{aligned}

という書き方をすることにします。 すると,FFTの式はf

\begin{aligned}
F(x + \frac{N}{2} b_0) = f_{\frac{N}{2}}(x) + (-1)^{b_0} w^{x}_{N} f_{\frac{N}{2}}(x + \frac{N}{2})
\qquad (0 \le x \le \frac{N}{2}-1)
\end{aligned}
\begin{aligned}
f_{\frac{N}{2^ l}}(x + \frac{N}{2^{l}}(&b_0, b_1, \cdots , b_{l-1})_2 + \frac{N}{2^{l+1}} b_l) \\
=& f_{\frac{N}{2^{l+1}}}(x + \frac{N}{2^{l}}(b_0, b_1, \cdots , 2b_{l-1})_2) \\
&+ (-1)^{b_l} w^{x}_{\frac{N}{2^{l}}} f_{\frac{N}{2^{l+1}}}(x + \frac{N}{2^{l}}(b_0, b_1, \cdots , 2b_{l-1})_2 + \frac{N}{2^{l+1}}) \\
& \qquad \qquad (1 \le l \le k-1, \quad 0 \le x \le \frac{N}{2^{l+1}}-1)
\end{aligned}
\begin{aligned}
f_{\frac{N}{2^ k}}((b_0, b_1, \cdots , b_{k-1})_2) = f((b_{k-1}, b_{k-2}, \cdots , b_0)_2)
\end{aligned}

と書くことができます。 さらに,あらたにjを用いて

\begin{aligned}
F(x + \frac{N}{2} b_0) = f_{\frac{N}{2}}(x) + (-1)^{b_0} w^{x}_{N} f_{\frac{N}{2}}(x + \frac{N}{2})
\qquad (0 \le x \le \frac{N}{2}-1)
\end{aligned}
\begin{aligned}
f_{\frac{N}{2^ l}}(x + \frac{N}{2^{l}} j + \frac{N}{2^{l+1}} b_l) =& f_{\frac{N}{2^{l+1}}}(x + \frac{N}{2^{l}} j) + (-1)^{b_l} w^{x}_{\frac{N}{2^{l}}} f_{\frac{N}{2^{l+1}}}(x + \frac{N}{2^{l}} j + \frac{N}{2^{l+1}}) \\
& (1 \le l \le k-1, \quad 0 \le j \le 2^ l -1, \quad 0 \le x \le \frac{N}{2^{l+1}}-1)
\end{aligned}
\begin{aligned}
f_{\frac{N}{2^ k}}((b_0, b_1, \cdots , b_{k-1})_2) = f((b_{k-1}, b_{k-2}, \cdots , b_0)_2)
\end{aligned}

と書き換えます。

これで,あとはほぼ実装するだけになりました。 しかしw^{x}_{N}をいちいち計算するのは面倒ですし,先に計算しておくにしてもw^{x}_{\frac{N}{2^ l}}l0 \le l \le k-1までと保持して置かなければならない数が多すぎます。 そこで,式中のw^{x}_{\frac{N}{2^ l}}をすべてw^{x}_{N}に書き換えます。 すると

\begin{aligned}
F(x + \frac{N}{2} b_0) = f_{\frac{N}{2}}(x) + (-1)^{b_0} w^{x}_{N} f_{\frac{N}{2}}(x + \frac{N}{2})
\qquad (0 \le x \le \frac{N}{2}-1)
\end{aligned}
\begin{aligned}
f_{\frac{N}{2^ l}}(x + \frac{N}{2^{l}} j + \frac{N}{2^{l+1}} b_l) =& f_{\frac{N}{2^{l+1}}}(x + \frac{N}{2^{l}} j) + (-1)^{b_l} w^{2^{l} x}_{N} f_{\frac{N}{2^{l+1}}}(x + \frac{N}{2^{l}} j + \frac{N}{2^{l+1}}) \\
& (1 \le l \le k-1, \quad 0 \le j \le 2^ l -1, \quad 0 \le x \le \frac{N}{2^{l+1}}-1)
\end{aligned}
\begin{aligned}
f_{\frac{N}{2^ k}}((b_0, b_1, \cdots , b_{k-1})_2) = f((b_{k-1}, b_{k-2}, \cdots , b_0)_2)
\end{aligned}

となります。

あとは,この式を実装していきます。

ビット順序入れ替え

f_{\frac{N}{2^ k}}の式中で,ビットの順序を逆順に入れ替える操作が必要になります。 これはビット演算で実装しました。

unsigned int bit_invert(unsigned int x, int bit_number) {
    x = ((x & 0xffff0000) >> 16) | ((x & 0x0000ffff) << 16);
    x = ((x & 0xff00ff00) >>  8) | ((x & 0x00ff00ff) <<  8);
    x = ((x & 0xf0f0f0f0) >>  4) | ((x & 0x0f0f0f0f) <<  4);
    x = ((x & 0xcccccccc) >>  2) | ((x & 0x33333333) <<  2);
    x = ((x & 0xaaaaaaaa) >>  1) | ((x & 0x55555555) <<  1);

    return x >> (32 - bit_number);
}

プログラム中でのxは入力,bit_numberは入れ替えるビット数です。

たとえば

unsigned int a = bit_invert(3, 2);
unsigned int b = bit_invert(3, 3);

とすれば,aには3,bには6が代入されます。

実装

上で作ったbit_invert関数を使ってFFTを実装すると以下のようになりました。

std::vector<std::complex<long double>> fft(const std::vector<std::complex<long double>> &f) {
    std::vector<std::complex<long double>> F;

    int N = f.size();
    if (N != (N & (-N)))
        return F;
    F.resize(N);

    unsigned int k = static_cast<unsigned int>(std::log2(N));

    std::vector<std::complex<long double>> w(N / 2);
    for (int x = 0; x < w.size(); x++)
        w[x] = std::exp(std::complex<long double>(0.0, -2.0 * pi * x / N));

    for (int x = 0; x < N; x++)
        F[x] = f[bit_invert(x, k)];

    std::complex<long double> a, m;
    for (int n = 1, J = N / 2; n < N; n *= 2, J /= 2) {
        for (int x = 0; x < n; x++) {
            for (int j = 0; j < J; j++) {
                a = F[x + 2 * n * j] + w[J * x] * F[x + 2 * n * j + n];
                m = F[x + 2 * n * j] - w[J * x] * F[x + 2 * n * j + n];

                F[x + 2 * n * j    ] = a;
                F[x + 2 * n * j + n] = m;
            }
        }
    }

    return F;
}

さらに,FFTの逆変換(以下,IFFTと書く)はwの指数の符号が違うのと,最終的な結果をNで割る以外の点ではFFTと同じなので関数は以下のようになります。

std::vector<std::complex<long double>> ifft(const std::vector<std::complex<long double>> &f) {
    std::vector<std::complex<long double>> f;

    int N = F.size();
    if (N != (N & (-N)))
        return f;
    f.resize(N);

    unsigned int k = static_cast<unsigned int>(std::log2(N));

    std::vector<std::complex<long double>> w(N / 2);
    for (int t = 0; t < w.size(); t++)
        w[t] = std::exp(std::complex<long double>(0.0, 2.0 * pi * t / N));

    for (int t = 0; t < N; t++)
        f[t] = F[bit_invert(t, k)];

    std::complex<long double> a, m;
    for (int n = 1, J = N / 2; n < N; n *= 2, J /= 2) {
        for (int t = 0; t < n; t++) {
            for (int j = 0; j < J; j++) {
                a = f[t + 2 * n * j] + w[J * t] * f[t + 2 * n * j + n];
                m = f[t + 2 * n * j] - w[J * t] * f[t + 2 * n * j + n];

                f[t + 2 * n * j    ] = a;
                f[t + 2 * n * j + n] = m;
            }
        }
    }

    for (int t = 0; t < N; t++)
        f[t] = f[t] / static_cast<long double>(N);

    return f;
}

一応,データ数が2の累乗数でない場合はif文で除外し,空のvectorを返すようにしました (実際に使うわけないのであまり意味がない気もしますが...)。

実行結果

fftとifftを実行してみました。 fftに対する入力データは以下のように生成しました。

int N = (1 << 10);
std::vector<comp> f(N);
for (int i = 0; i < f.size(); i++) {
    f[i] = 0.0;
    f[i] +=   10 * std::cos((050) * 2.0 * pi * static_cast<long double>(i) / static_cast<long double>(N));
    f[i] +=    5 * std::cos((200) * 2.0 * pi * static_cast<long double>(i) / static_cast<long double>(N));
    f[i] +=   25 * std::cos((350) * 2.0 * pi * static_cast<long double>(i) / static_cast<long double>(N));
}

この入力用のデータをプロットすると以下のようになります。

f:id:muushi:20190515001755p:plain
入力信号
プロットするさいに,絶対値を取ってからプロットしたため負の値がないです。以降プロットする場合も絶対値を取ります。

このように生成したデータに対してfftを実行し,得られた結果をプロットすると以下のようになりました。

f:id:muushi:20190515001830p:plain
FFT出力
実際に横軸が50, 200, 350のあたりに信号があることが確認できます。

さらに,fftで得られた結果を入力としてifftを実行して得られた結果は以下のようになりました。

f:id:muushi:20190515001856p:plain
IFFT出力
入力した信号とほぼ遜色なさそうなことが見てわかります。

おわりに

今回はFFT(とIFFT)を実装しました。 実際に実装してみると,DFTを愚直に実装したコードの2倍位のコードで速度がものすごく早くなるので,アルゴリズムってすごいなぁって感じがします。

せっかくなので,データ数が2の累乗数に限らないような場合でもやってみたいです。