直入正题,我们需要解决的问题是多项式乘法。给定两个最高次为nnmm 的多项式A(x)A(x)B(x)B(x),我们需要计算它们的乘积C(x)=A(x)B(x)C(x) = A(x) \cdot B(x)

直接计算的方法需要O(nm)O(nm) 的时间复杂度,而FFT可以将其优化到O((n+m)log(n+m))O((n+m)\log(n+m))


时域和频域

相信我,你不需要有 EE 背景也能理解这个问题。图片可以参考这篇文章:知乎

总而言之,一般的多项式表示办法:A(x)=a0+a1x+a2x2++an1xn1A(x) = a_0 + a_1 x + a_2 x^2 + \ldots + a_{n-1} x^{n-1},就是一种时域的表示方法。这种方法的优点是直观,缺点是乘法需要O(nm)O(nm) 的时间复杂度。

与此对应的,还有一种表示办法:A={(x0,A(x0)),(x1,A(x1)),,(xk,A(xk))}A = \{(x_0, A(x_0)), (x_1, A(x_1)), \ldots, (x_k, A(x_k))\},其中xix_i 是一些特定的点。这种方法叫做频域的表示方法。它的优点是 (i) 要计算多项式乘法,将各个点的A(xi)A(x_i)B(xi)B(x_i) 相乘就行了,时间复杂度是O(k)O(k);(ii) 有kk 个点就能唯一确定一个最高次为k1k-1 的多项式。

时域和频域的对应关系 —— 如果假设A(x)A(x) 是一堆正弦波的叠加,在时域中我们看到的是一个复杂的波形(如矩形波),代表最终的多项式值;在频域中,我们看到的是构成A(x)A(x) 的各个正弦波,在不同频率上的波有着不同的振幅(即A(xi)A(x_i) 的值)。因此,时域和频域是同一个多项式的两种不同表示方法。

转载自上面的知乎链接,感谢原作者

因此,我们要进行多项式乘法,其实思路是很简单的:

  1. DFT(离散傅里叶变换):将A(x)A(x)B(x)B(x) 从时域转换到频域,得到A(xi)A(x_i)B(xi)B(x_i)
  2. 点乘:在频域中,对应点相乘,得到C(xi)=A(xi)B(xi)C(x_i) = A(x_i) \cdot B(x_i)
  3. IDFT(离散逆傅里叶变换):将C(xi)C(x_i) 从频域转换回时域,得到C(x)C(x)

单位根

不过说起来,如果真的只是随便选点AiA_i 的话,计算起来并不会更快。考虑一下,A(x)A(x) 就要选n+1n+1 个点,计算这些点的A(xi)A(x_i) 已经是O(n2)O(n^2) 的时间复杂度了。但是就不能更快了吗?显然不是。

FFT 的核心就是选择了一些特殊的点,这些点叫做单位根(roots of unity)。我相信你有一定的复数基础,总之至少应该见过三次主单位根ω3=1+i32\omega_3 = \frac{-1 + i\sqrt{3}}{2},它的平方ω32=1i32\omega_3^2 = \frac{-1 - i\sqrt{3}}{2},和ω30=1\omega_3^0 = 1。它们一起构成了一个三次单位根的集合,平分了单位圆。也因此,可以写成:ω3=e2πi/3\omega_3 = e^{2\pi i / 3}ω32=e4πi/3\omega_3^2 = e^{4\pi i / 3}ω30=e0=1\omega_3^0 = e^{0} = 1

一般地,nn 次单位根的集合可以表示为{ωn0,ωn1,ωn2,,ωnn1}\{\omega_n^0, \omega_n^1, \omega_n^2, \ldots, \omega_n^{n-1}\},其中ωn=e2πi/n=cos(2π/n)+isin(2π/n)\omega_n = e^{2\pi i / n} = \cos(2\pi / n) + i \sin(2\pi / n)

很明显,单位根有强烈的对称性:

  • ωnn=1\omega_n^n = 1
  • 对称性:ωnn/2=1\omega_n^{n/2} = -1(当nn 是偶数时),所以ωnk+n/2=ωnk\omega_n^{k + n/2} = -\omega_n^k
  • 折半引理:(ωnk)2=ωn2k=ωn/2k(\omega_n^k)^2 = \omega_n^{2k} = \omega_{n/2}^k(当nn 是偶数时)。

分治

因此我们就可以通过分治的方法来计算 DFT 了。

我们可以将A(x)=a0+a1x+a2x2++an1xn1A(x) = a_0 + a_1 x + a_2 x^2 + \ldots + a_{n-1} x^{n-1} 分成两部分:

A(x)=(a0+a2x2+a4x4+)+x(a1+a3x2+a5x4+)A(x) = (a_0 + a_2 x^2 + a_4 x^4 + \ldots) + x (a_1 + a_3 x^2 + a_5 x^4 + \ldots)

前后两部分是相似的结构,我们可以定义:

  • Aeven(x)=a0+a2x+a4x2+A_{\text{even}}(x) = a_0 + a_2 x + a_4 x^2 + \ldots(偶数项)
  • Aodd(x)=a1+a3x+a5x2+A_{\text{odd}}(x) = a_1 + a_3 x + a_5 x^2 + \ldots(奇数项)

A(x)=Aeven(x2)+xAodd(x2)A(x) = A_{\text{even}}(x^2) + x A_{\text{odd}}(x^2)

将单位根ωnk\omega_n^k 代入A(x)A(x) 中,我们得到:

A(ωnk)=Aeven((ωnk)2)+ωnkAodd((ωnk)2)=Aeven(ωn/2k)+ωnkAodd(ωn/2k)\begin{aligned} A(\omega_n^k) &= A_{\text{even}}((\omega_n^k)^2) + \omega_n^k A_{\text{odd}}((\omega_n^k)^2) \\ &= A_{\text{even}}(\omega_{n/2}^k) + \omega_n^k A_{\text{odd}}(\omega_{n/2}^k) \end{aligned}

好消息是,当我们计算A(ωnk+n/2)A(\omega_n^{k + n/2}) 时,利用对称性:

A(ωnk+n/2)=Aeven((ωnk+n/2)2)+ωnk+n/2Aodd((ωnk+n/2)2)=Aeven(ωn/2k+n/2)+ωnk+n/2Aodd(ωn/2k+n/2)=Aeven(ωn/2k)+ωnk+n/2Aodd(ωn/2k)()=Aeven(ωn/2k)ωnkAodd(ωn/2k)\begin{aligned} A(\omega_n^{k + n/2}) &= A_{\text{even}}((\omega_n^{k + n/2})^2) + \omega_n^{k + n/2} A_{\text{odd}}((\omega_n^{k + n/2})^2) \\ &= A_{\text{even}}(\omega_{n/2}^{k + n/2}) + \omega_n^{k + n/2} A_{\text{odd}}(\omega_{n/2}^{k + n/2}) \\ &= A_{\text{even}}(\omega_{n/2}^k) + \omega_n^{k + n/2} A_{\text{odd}}(\omega_{n/2}^k) \qquad (\ast) \\ &= A_{\text{even}}(\omega_{n/2}^k) - \omega_n^k A_{\text{odd}}(\omega_{n/2}^k) \end{aligned}

()(\ast) 注意这一步中,单位根的周期已经变成n2\frac{n}{2} 了,所以ωn/2k+n/2=ωn/2k\omega_{n/2}^{k + n/2} = \omega_{n/2}^k

这样一来,要计算A(ωnk)A(\omega_n^k)A(ωnk+n/2)A(\omega_n^{k + n/2}),只需要递归地计算Aeven(ωn/2k)A_{\text{even}}(\omega_{n/2}^k)Aodd(ωn/2k)A_{\text{odd}}(\omega_{n/2}^k) 就行了。

如果假设n+1=2kn + 1 = 2^k,也就是说总项数是 2 的幂次的话,那么得到的AevenA_{\text{even}}AoddA_{\text{odd}} 的项数都是2k12^{k-1},因此可以完美地进行递归。每一层递归的时间复杂度是O(n)O(n),总共有O(logn)O(\log n) 层递归,因此总的时间复杂度是O(nlogn)O(n \log n)

递归的边界条件是当n=1n = 1 的时候,直接返回A(x)=a0A(x) = a_0

从 DFT 到 IDFT

IDFT 的问题与 DFT 正好相反,我们已知A={(ωn0,A(ωn0)),(ωn1,A(ωn1)),,(ωnn1,A(ωnn1))}A = \{(\omega_n^0, A(\omega_n^0)), (\omega_n^1, A(\omega_n^1)), \ldots, (\omega_n^{n-1}, A(\omega_n^{n-1}))\},我们需要计算A(x)A(x) 的系数a0,a1,,an1a_0, a_1, \ldots, a_{n-1}

这就用到了 DFT 的逆变换公式:

ak=1nj=0n1A(ωnj)ωnjka_k = \frac{1}{n} \sum_{j=0}^{n-1} A(\omega_n^j) \cdot \omega_n^{-jk}

这是如何得到的呢?我们可以将这nn 个多项式的值写成线性方程组:

[11111ωn1ωn2ωnn11ωn2ωn4ωn2(n1)1ωnn1ωn2(n1)ωn(n1)(n1)][a0a1a2an1]=[A(ωn0)A(ωn1)A(ωn2)A(ωnn1)]\begin{bmatrix} 1 & 1 & 1 & \ldots & 1 \\ 1 & \omega_n^1 & \omega_n^2 & \ldots & \omega_n^{n-1} \\ 1 & \omega_n^2 & \omega_n^4 & \ldots & \omega_n^{2(n-1)} \\ \vdots & \vdots & \vdots & \ddots & \vdots \\ 1 & \omega_n^{n-1} & \omega_n^{2(n-1)} & \ldots & \omega_n^{(n-1)(n-1)} \end{bmatrix} \begin{bmatrix}a_0 \\ a_1 \\ a_2 \\ \vdots \\ a_{n-1}\end{bmatrix} = \begin{bmatrix}A(\omega_n^0) \\ A(\omega_n^1) \\ A(\omega_n^2) \\ \vdots \\ A(\omega_n^{n-1})\end{bmatrix}

观察其中与aka_k 相关的列:

[1ωnkωn2kωn(n1)k]\begin{bmatrix} 1 \\ \omega_n^k \\ \omega_n^{2k} \\ \vdots \\ \omega_n^{(n-1)k} \end{bmatrix}

我们可以将这个列向量与A(ωnj)A(\omega_n^j) 的行向量进行点积:

j=0n1A(ωnj)ωnjk=A(ωn0)1+A(ωn1)ωnk+A(ωn2)ωn2k++A(ωnn1)ωn(n1)k=a0j=0n1ωn0+a1j=0n1ωnj+a2j=0n1ωn2j++an1j=0n1ωn(n1)j\begin{aligned} \sum_{j=0}^{n-1} A(\omega_n^j) \cdot \omega_n^{-jk} &= A(\omega_n^0) \cdot 1 + A(\omega_n^1) \cdot \omega_n^{-k} + A(\omega_n^2) \cdot \omega_n^{-2k} + \ldots + A(\omega_n^{n-1}) \cdot \omega_n^{-(n-1)k} \\ &= a_0 \sum_{j=0}^{n-1} \omega_n^{0} + a_1 \sum_{j=0}^{n-1} \omega_n^{j} + a_2 \sum_{j=0}^{n-1} \omega_n^{2j} + \ldots + a_{n-1} \sum_{j=0}^{n-1} \omega_n^{(n-1)j} \end{aligned}

由于单位根的性质:

  • k=0k = 0 时,j=0n1ωn0=1+1++1=n\sum_{j=0}^{n-1} \omega_n^{0} = 1 + 1 + \ldots + 1 = n
  • k0k \neq 0 时,j=0n1ωnjk=ωn0+ωnk+ωn2k++ωn(n1)k\sum_{j=0}^{n-1} \omega_n^{jk} = \omega_n^{0} + \omega_n^{k} + \omega_n^{2k} + \ldots + \omega_n^{(n-1)k},无论kk 的值是多少,这都是所有单位根的一种排列,等于ωn0+ωn1+ωn2++ωnn1=0\omega_n^{0} + \omega_n^{1} + \omega_n^{2} + \ldots + \omega_n^{n-1} = 0

因此,只有当k=0k = 0 时,才会有非零的贡献:

j=0n1A(ωnj)ωnjk=akn\sum_{j=0}^{n-1} A(\omega_n^j) \cdot \omega_n^{-jk} = a_k \cdot n

因此,我们可以得到

ak=1nj=0n1A(ωnj)ωnjka_k = \frac{1}{n} \sum_{j=0}^{n-1} A(\omega_n^j) \cdot \omega_n^{-jk}

再观察这个公式,我们发现它与 DFT 的公式 $$A(\omega_n^k) = \sum_{j=0}^{n-1} a_j \cdot \omega_n^{jk}$$

的结构非常相似;唯一的区别是ωnjk\omega_n^{jk} 变成了ωnjk\omega_n^{-jk},以及前面多了一个1n\frac{1}{n} 的系数。

因此,我们可以通过相同的分治方法来计算 IDFT。


可能你发现了,以上思路和计算神经网络的输出层使用 softmax、误差使用 cross-entropy 时,计算Lwij\frac{\partial L}{\partial w_{ij}} 的思路有点共通之处,最终只有i=ki = k 的项会有非零的贡献。这玩意学名叫做克罗内克 delta 函数,反映的是函数的正交性。

除此之外类似的,就是有不止一种方法推导,以下是另一种:

要证明ak=1nj=0n1A(ωnj)ωnjka_k = \frac{1}{n} \sum_{j=0}^{n-1} A(\omega_n^j) \cdot \omega_n^{-jk},我们可以将A(ωnj)A(\omega_n^j) 的定义代入:

S=1nj=0n1A(ωnj)ωnjk=1nj=0n1(m=0n1amωnjm)ωnjk=1nm=0n1am(j=0n1ωnj(mk))\begin{aligned} S &= \frac{1}{n} \sum_{j=0}^{n-1} A(\omega_n^j) \cdot \omega_n^{-jk} \\ &= \frac{1}{n} \sum_{j=0}^{n-1} \left( \sum_{m=0}^{n-1} a_m \cdot \omega_n^{jm} \right) \cdot \omega_n^{-jk} \\ &= \frac{1}{n} \sum_{m=0}^{n-1} a_m \left( \sum_{j=0}^{n-1} \omega_n^{j(m-k)} \right) \end{aligned}

  • m=km = k 时,j=0n1ωn0=n\sum_{j=0}^{n-1} \omega_n^{0} = n
  • mkm \neq k 时,j=0n1ωnj(mk)=ωn0+ωnmk+ωn2(mk)++ωn(n1)(mk)\sum_{j=0}^{n-1} \omega_n^{j(m-k)} = \omega_n^{0} + \omega_n^{m-k} + \omega_n^{2(m-k)} + \ldots + \omega_n^{(n-1)(m-k)},和上面分析的一样,结果为00

所以,

S=1nm=0n1am{nif m=k,0if mk.S = \frac{1}{n} \sum_{m=0}^{n-1} a_m \cdot \begin{cases} n & \text{if } m = k, \\ 0 & \text{if } m \neq k. \end{cases}

因此,S=aknS = a_k \cdot n,从而得到:

ak=1nj=0n1A(ωnj)ωnjka_k = \frac{1}{n} \sum_{j=0}^{n-1} A(\omega_n^j) \cdot \omega_n^{-jk}

实现

不过实际上,众所周知,递归的开销比较大,因此我们通常会使用迭代的方式来实现 FFT 和 IFFT。迭代的核心思想是先进行位逆序(bit-reversal)排列,然后在每一轮迭代中,按照当前的子问题大小进行合并。

原索引二进制逆序二进制逆序索引
00000000
10011004
20100102
30111106
41000011
51011015
61100113
71111117
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
#include <iostream>
#include <vector>
#include <complex>
#include <cmath>

using namespace std;
using cd = complex<double>;
using int64 = long long;
const double PI = acos(-1);

void fft(vector<cd> &a, bool invert) {
int n = a.size();
for (int i = 1, j = 0; i < n; i++) { // 位逆序排列
int bit = n >> 1;
for (; j & bit; bit >>= 1)
j ^= bit;
j |= bit;
if (i < j)
swap(a[i], a[j]);
}

for (int len = 2; len <= n; len <<= 1) {
double angle = 2 * PI / len * (invert ? -1 : 1);
cd wlen(cos(angle), sin(angle));
for (int i = 0; i < n; i += len) {
cd w(1);
for (int j = 0; j < len / 2; j++) {
cd u = a[i + j], v = a[i + j + len / 2] * w;
a[i + j] = u + v;
a[i + j + len / 2] = u - v;
w *= wlen;
}
}
}

if (invert) {
for (cd &x : a)
x /= n;
}
}

vector<int64> multiply(const vector<int64> &a, const vector<int64> &b) {
vector<cd> fa(a.begin(), a.end()), fb(b.begin(), b.end());

// 长度 N >= n + m - 1,且 N 是 2 的幂次
int n = 1;
while (n < a.size() + b.size())
n <<= 1;
fa.resize(n);
fb.resize(n);

fft(fa, false);
fft(fb, false);
for (int i = 0; i < n; i++)
fa[i] *= fb[i];
fft(fa, true);

vector<int64> result(n);
for (int i = 0; i < n; i++)
result[i] = round(fa[i].real());
return result;
}

int main() {
// (1 + 2x + 3x^2) * (4 + 5x) = 4 + 13x + 22x^2 + 15x^3
vector<int64> a = {1, 2, 3};
vector<int64> b = {4, 5};
vector<int64> result = multiply(a, b);
for (int64 coeff : result)
cout << coeff << " ";
cout << endl;
}