[多项式学习笔记]快速 Fourier 变换

快速 Fourier 变换是离散 Fourier 的一种优化,在 OI 中常用来加速一些卷积。

算法思路

首先,什么是 Fourier 变换?

连续的 Fourier 变换将可积函数 $f:\R\to\C$ 表示为复指数函数的积分或级数形式:

,其中 $\xi$ 为任意实数;Fourier 逆变换用 $\hat f$ 确定 $f$:

,其中 $x$ 为任意实数。

离散 Fourier 变换就是将连续的积分转为离散的求和。


如果给定多项式 $f(x)$、$g(x)$,求多项式 $h(x)=f(x)\cdot g(x)$,容易给出一种朴素的 $\Theta(\deg f\cdot \deg g)$ 的算法。

考虑用将特定选取的 $x_1,\cdots, x_n$ 代入多项式 $f$ 得到的结果 $f(x_1),\cdots,f(x_n)$ 来表示 $f$,用同样的方法表示 $g$,再将结果直接点乘,再过一次逆变换,就可以得到 $h$ 了。

根据连续 Fourier 变换的形式不难知道,这里的 $x_1,\cdots,x_n$ 为单位根 $\omega_n^0,\cdots, \omega_n^{(n-1)}$。

称代入 $\omega_n^i$ 的过程为 离散 Fourier 变换(Discrete Fourier Transform, DFT)。

朴素的做法仍然是 $\Theta(n^2)$ 的,我们考虑分治:

发现括号内形式一样,令 $f_1(x)=a_0+a_2x+\cdots$,$f_2(x)=a_1+a_3x+\cdots$,则

当然,为了使分治后两个多项式度数相同,我们需要令 $n=2^k$。

将 $\omegan^k$ 代入:$f(\omega_n^k)=f_1(\omega{n/2}^k)+\omega{n/2}^kf_2(\omega{n/2}^k)$。我们仅需要知道 $f1,f_2$ 对 $\omega{n/2}^i$ 的所有取值即可算得 $f$ 对 $\omega_n^i$ 的所有取值。

于是我们只需要先将奇偶次项分离,分别计算后再进行合并,时间复杂度 $T(n)=2T(\dfrac n2)+\Theta(n)=\Theta(n\log n)$。

发现奇偶次项分离的空间复杂度为 $\Theta(n\log n)$,且内存开销较大,难以接受。

观察系数分离情况:

不难证明 $i$ 在结束后的位置 $j$ 一定为 $i$ 在 $k$ 位二进制表达下逆序的值 $\operatorname{rev}(i)$。

利用 $\operatorname{rev}(i)=\left\lfloor\dfrac{\operatorname{rev}\left(\left\lfloor\dfrac i2\right\rfloor\right)}2\right\rfloor+[i\bmod 2=1]\cdot\dfrac n2$ 能够在 $\Theta(n)$ 的时间复杂度递推求得 $\operatorname{rev}(i)$。

离散 Fourier 变换的逆变换称为 逆离散 Fourier 变换(Inverse Discrete Fourier Transform, IDFT)。

容易知道,若视 $[a_1,\cdots, a_n]^{\rm tr}$ 与 $[y_1,\cdots, y_n]^{\rm tr}$ 为列向量,则 DFT 本质上为左乘一个矩阵

容易验证该矩阵的逆矩阵为其每一项取倒数再除以 $n$。

所以 IDFT 时仅需将 DFT 用的单位根取倒数,算完除以 $n$ 即可。

代码

多项式乘法:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
/**
* @file #108. 多项式乘法.cpp
* @author Kuriyama Mirai (hermione_granger@foxmail.com)
* @brief
* @link https://loj.ac/s/1368664
* @date 2022-01-30
*
* @copyright Copyright (c) 2022
*
*/
#include <cstdio>
#include <complex>
#include <algorithm>
#include <cmath>

namespace mirai {

typedef std::complex<long double> comp;

constexpr int MAXN = 262514; // 262144
int rev[MAXN];
comp a[MAXN], b[MAXN], c[MAXN];

void fft(comp* arr, int len, int on) {
for (int i = 0; i < len; ++i) {
if (i < rev[i]) {
std::swap(arr[i], arr[rev[i]]);
}
}
for (int n = 2; n <= len; n <<= 1) {
comp omega_n(std::cos(2 * M_PI / n), std::sin(on * 2 * M_PI / n));
for (int i = 0; i < len; i += n) {
comp omega = 1;
for (int j = i; j < i + n / 2; ++j) { // 注意 j 的取值!
auto u = arr[j], v = arr[j + n / 2];
arr[j] = u + omega * v;
arr[j + n / 2] = u - omega * v;
omega *= omega_n;
}
}
}
if (on == -1) {
for (int i = 0; i < len; ++i) {
arr[i] /= len;
}
}
}

int main(int argc, char** argv) {
int n, m;
std::scanf("%d%d", &n, &m);
for (int i = 0; i <= n; ++i) {
int read;
std::scanf("%d", &read);
a[i] = read;
}
for (int i = 0; i <= m; ++i) {
int read;
std::scanf("%d", &read);
b[i] = read;
}
int len = 1;
while (len < n + m + 2) {
len <<= 1;
}

for (int i = 1; i < len; ++i) {
rev[i] = rev[i >> 1] >> 1;
if (i & 1) {
rev[i] |= len >> 1;
}
// std::printf("%d:%d\n", i, rev[i]);
}

fft(a, len, 1);
fft(b, len, 1);
for (int i = 0; i < len; ++i) {
c[i] = a[i] * b[i];
}
fft(c, len, -1);
for (int i = 0; i <= n + m; ++i) {
std::printf("%lld ", std::llroundl(c[i].real())); // 一定要注意四舍五入而非向下取整
}
std::printf("\n");
return 0;
}

}

int main(int argc, char** argv) {
return mirai::main(argc, argv);
}

洛谷的坑人模板:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
/**
* @file P3803 【模板】多项式乘法(FFT) V.cpp
* @author Kuriyama Mirai (hermione_granger@foxmail.com)
* @brief
* @date 2022-01-30
* @link https://www.luogu.com.cn/record/68373463
*
* @copyright Copyright (c) 2022
*
*/
#include <cstdio>
#include <complex>
#include <algorithm>
#include <cmath>

namespace mirai {

typedef std::complex<long double> comp;

constexpr int MAXN = 2100000; // 1048576*2
int rev[MAXN];
comp a[MAXN], b[MAXN], c[MAXN];

void fft(comp* arr, int len, int on) {
for (int i = 0; i < len; ++i) {
if (i < rev[i]) {
std::swap(arr[i], arr[rev[i]]);
}
}
for (int n = 2; n <= len; n <<= 1) {
comp omega_n(std::cos(2 * M_PI / n), std::sin(on * 2 * M_PI / n));
for (int i = 0; i < len; i += n) {
comp omega = 1;
for (int j = i; j < i + n / 2; ++j) { // 注意 j 的取值!
auto u = arr[j], v = arr[j + n / 2];
arr[j] = u + omega * v;
arr[j + n / 2] = u - omega * v;
omega *= omega_n;
}
}
}
if (on == -1) {
for (int i = 0; i < len; ++i) {
arr[i] /= len;
}
}
}

int main(int argc, char** argv) {
int n, m;
std::scanf("%d%d", &n, &m);
for (int i = 0; i <= n; ++i) {
int read;
std::scanf("%d", &read);
a[i] = read;
}
for (int i = 0; i <= m; ++i) {
int read;
std::scanf("%d", &read);
b[i] = read;
}
int len = 1; // 注意当 len=1 时处理不了 n+m=1 的情况
while (len < n + m + 2) { // 注意 +2
len <<= 1;
}

for (int i = 1; i < len; ++i) {
rev[i] = rev[i >> 1] >> 1;
if (i & 1) {
rev[i] |= len >> 1;
}
// std::printf("%d:%d\n", i, rev[i]);
}

fft(a, len, 1);
fft(b, len, 1);
for (int i = 0; i < len; ++i) {
c[i] = a[i] * b[i];
}
fft(c, len, -1);
for (int i = 0; i <= n + m; ++i) {
std::printf("%lld ", std::llroundl(c[i].real())); // 一定要注意四舍五入而非向下取整
}
std::printf("\n");
return 0;
}

}

int main(int argc, char** argv) {
return mirai::main(argc, argv);
}

高精度乘法模板

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
/**
* @file P1919 【模板】A*B Problem 升级版(FFT 快速傅里叶变换).cpp
* @author Kuriyama Mirai (hermione_granger@foxmail.com)
* @brief
* @date 2022-01-30
* @link https://www.luogu.com.cn/record/68373411
*
* @copyright Copyright (c) 2022
*
*/
#include <cstdio>
#include <complex>
#include <algorithm>
#include <cmath>
#include <cctype>

namespace mirai {

typedef std::complex<long double> comp;

constexpr int MAXN = 2100000; // 1048576*2
int rev[MAXN];
comp a[MAXN], b[MAXN], c[MAXN];
int ans[MAXN];

void fft(comp* arr, int len, int on) {
for (int i = 0; i < len; ++i) {
if (i < rev[i]) {
std::swap(arr[i], arr[rev[i]]);
}
}
for (int n = 2; n <= len; n <<= 1) {
comp omega_n(std::cos(2 * M_PI / n), std::sin(on * 2 * M_PI / n));
for (int i = 0; i < len; i += n) {
comp omega = 1;
for (int j = i; j < i + n / 2; ++j) { // 注意 j 的取值!
auto u = arr[j], v = arr[j + n / 2];
arr[j] = u + omega * v;
arr[j + n / 2] = u - omega * v;
omega *= omega_n;
}
}
}
if (on == -1) {
for (int i = 0; i < len; ++i) {
arr[i] /= len;
}
}
}

int main(int argc, char** argv) {
int n = 0, m = 0;
char ch;
while (std::isdigit(ch = std::getchar())) {
a[n++] = ch - '0';
}
while (!std::isdigit(ch = std::getchar()));
b[m++] = ch - '0';
while (std::isdigit(ch = std::getchar())) {
b[m++] = ch - '0';
}
--n;
--m;
for (int i = 0; i <= n / 2; ++i) {
std::swap(a[i], a[n - i]);
}
for (int i = 0; i <= m / 2; ++i) {
std::swap(b[i], b[m - i]);
}
// for (int i = 0; i <= n; ++i) {
// // std::printf("%lld ", std::llroundl(a[i].real()));
// std::printf("%Lf ", a[i].real());
// }
// std::puts("");
// for (int i = 0; i <= m; ++i) {
// // std::printf("%lld ", std::llroundl(b[i].real()));
// std::printf("%Lf ", b[i].real());
// }
// std::puts("");
int len = 1; // 注意当 len=1 时处理不了 n+m=1 的情况
while (len < n * 2 + 2 || len < m * 2 + 2) { // 注意 +2!
len <<= 1;
}

for (int i = 1; i < len; ++i) {
rev[i] = rev[i >> 1] >> 1;
if (i & 1) {
rev[i] |= len >> 1;
}
// std::printf("%d:%d\n", i, rev[i]);
}

fft(a, len, 1);
fft(b, len, 1);
for (int i = 0; i < len; ++i) {
// std::printf(" %Lf+%Lfi %Lf+%Lfi\n", a[i].real(), a[i].imag(), b[i].real(), b[i].imag());
c[i] = a[i] * b[i];
}
fft(c, len, -1);
long long last = 0;
for (int i = 0; i <= n + m + 1; ++i) { // 注意 +1
long long now = std::llroundl(c[i].real()) + last; // 一定要注意四舍五入而非向下取整
ans[i] = now % 10;
last = now / 10;
}
for (int i = n + m + 1; i >= 0; --i) {
if (i == n + m + 1 && ans[i] == 0) { continue; }
std::printf("%d", ans[i]);
}
std::printf("\n");
return 0;
}

}

int main(int argc, char** argv) {
return mirai::main(argc, argv);
}

三次变两次优化

正常的 FFT 要做三次 DFT,非常不优秀。考虑如下变换:

,其中平方是指点乘自身,于是我们能用两次 DFT 解决:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
/**
* @file #108. 多项式乘法.cpp
* @author Kuriyama Mirai (hermione_granger@foxmail.com)
* @brief
* @link https://loj.ac/s/1369337
* @link https://www.luogu.com.cn/record/68430042
* @date 2022-01-30
*
* @copyright Copyright (c) 2022
*
*/
#include <cstdio>
#include <complex>
#include <algorithm>
#include <cmath>

namespace mirai {

typedef std::complex<long double> comp;

constexpr int MAXN = 2200000; // 262144
int rev[MAXN];
comp a[MAXN], b[MAXN], c[MAXN];

void fft(comp* arr, int len, int on) {
for (int i = 0; i < len; ++i) {
if (i < rev[i]) {
std::swap(arr[i], arr[rev[i]]);
}
}
for (int n = 2; n <= len; n <<= 1) {
comp omega_n(std::cos(2 * M_PI / n), std::sin(on * 2 * M_PI / n));
for (int i = 0; i < len; i += n) {
comp omega = 1;
for (int j = i; j < i + n / 2; ++j) { // 注意 j 的取值!
auto u = arr[j], v = arr[j + n / 2];
arr[j] = u + omega * v;
arr[j + n / 2] = u - omega * v;
omega *= omega_n;
}
}
}
if (on == -1) {
for (int i = 0; i < len; ++i) {
arr[i] /= len;
}
}
}

int main(int argc, char** argv) {
int n, m;
std::scanf("%d%d", &n, &m);
for (int i = 0; i <= n; ++i) {
int read;
std::scanf("%d", &read);
a[i] = read;
}
for (int i = 0; i <= m; ++i) {
int read;
std::scanf("%d", &read);
a[i] += comp(0, read);
}
int len = 1;
while (len < n + m + 2) {
len <<= 1;
}

for (int i = 1; i < len; ++i) {
rev[i] = rev[i >> 1] >> 1;
if (i & 1) {
rev[i] |= len >> 1;
}
// std::printf("%d:%d\n", i, rev[i]);
}

fft(a, len, 1);
for (int i = 0; i < len; ++i) {
a[i] = a[i] * a[i];
}
fft(a, len, -1);
for (int i = 0; i <= n + m; ++i) {
std::printf("%lld ", std::llroundl(a[i].imag() / 2)); // 一定要注意四舍五入而非向下取整
}
std::printf("\n");
return 0;
}

}

int main(int argc, char** argv) {
return mirai::main(argc, argv);
}