[20220115杂题选讲]特征多项式

一道裸题。

题目大意

给定方阵 $A$,求 $p_A(\lambda)=|\lambda I-A|$。

$n\le 500$,对 $998244353$ 取模。

source: https://www.luogu.com.cn/problem/P7776

思路

正常求 $\R$ 上的矩阵行列式应该是 $\Theta(n^3)$ 的,但是这里显然不能忽略乘法的复杂度,复杂度会达到惊人的 $\Theta(n^4\log n)$,显然不可行,怎么办呢?

算法一

代 $n+1$ 个值给 $\lambda$,总共 $\Theta(n^4)$ 求出来分别的行列式,$\Theta(n^3)$ 插值,搞定。时间复杂度 $\Theta(n^4)$,空间复杂度 $\Theta(n^3)$。

期望得分 $40$。

算法二

回顾求行列式的方法:将原本的矩阵变为一个容易求解行列式的矩阵。

我们定义相似矩阵:$A\sim B\Leftrightarrow \exist P:A=PBP^{-1}$。

我们有 $A\sim B\Rightarrow p_A=p_B$,因为:

这时肯定有人想,要是能消成上三角就好了!但是非常抱歉,大多数形如 $|\lambda I-A|$ 的矩阵不能被消成上三角。

如果消成上三角,你不就是把这个多项式给分解了么?$\R$ 上分解多项式显然要到 $\C$ 上。

——数学神仙 whd

我们可以退而求其次,如果我们少消一斜线($a{12}-a{n-1n}$)(称这样的矩阵是上 Hessenberg 矩阵),这样我们仍然可以在 $\Theta(n^3)$ 的复杂度内递推:

设前 $i$ 行列的特征多项式为 $pi$,则 $p_0=1$,$p_1=\lambda-A{11}$,$pi=(\lambda-A{ii})p{i-1}{\color{red}-}\sum\limits{j=0}^{i-2}pjA{j+1,i}\prod\limits{k=j+2}^iA{k,k+1}$。

注意红色标注的符号,这里符号恒为 $-$ 是因为虽然常数提供的逆序对个数奇偶性一直在变动导致符号变动,但与此同时 $A$ 前的负号的次幂数奇偶性也一直在变动,符号就恒为负了。

注意到这里我们不需要再进行多项式乘法了,复杂度 $\Theta(n^3)$。

代码

多亏了那个负号,调了好几个小时!

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
#include <cstdio>
#include <algorithm>

using namespace std;

int max;
int sqrt;

namespace mirai {

constexpr int MAXN = 505;
constexpr long long MOD = 998244353;
int m[MAXN][MAXN];

long long pow(long long a, long long p) {
long long ans = 1;
while (p) {
if (p & 1) {
ans = ans * a % MOD;
}
a = a * a % MOD;
p >>= 1;
}
return ans;
}
inline long long inv(long long a) {
return pow(a, MOD - 2);
}

int p[MAXN][MAXN];

//void print(long long a) {
//// std::printf("%lld", a);
//// return;
// if (a <= 10000) {
// std::printf("%lld", a);
// } else if (MOD - a <= 10000) {
// std::printf("%lld", a - MOD);
// } else {
// for (long long i = 1; i <= 10000; ++i) {
// if (a * i % MOD <= 20000) {
// std::printf("%lld/%lld", a * i % MOD, i);
// break;
// }
// if ((MOD - a * i % MOD) <= 20000) {
// std::printf("%lld/%lld", (a * i % MOD) - MOD, i);
// break;
// }
// }
// }
//}

int main(int argc, char** argv) {
#ifdef MIRAI
std::freopen("ala.out", "r", stdin);
#endif
int n;
std::scanf("%d", &n);
for (int i = 1; i <= n; ++i) {
for (int j = 1; j <= n; ++j) {
std::scanf("%lld", &m[i][j]);
}
}
if (n == 1) {
std::printf("%lld %lld\n", (MOD - m[1][1]) % MOD, MOD - 1);
return 0;
}

for (int i = 1; i <= n; ++i) {
if (m[i + 1][i] == 0) {
for (int j = i + 2; j <= n; ++j) {
if (m[j][i] != 0) {
for (int k = i; k <= n; ++k) {
std::swap(m[i + 1][k], m[j][k]);
}
for (int k = 1; k <= n; ++k) {
std::swap(m[k][j], m[k][i + 1]);
}
}
}
}
for (int j = i + 2; j <= n; ++j) {
long long scale = (MOD - 1ll * m[j][i] * inv(m[i + 1][i]) % MOD) % MOD;
for (int k = i; k <= n; ++k) {
m[j][k] = (m[j][k] + 1ll * m[i + 1][k] * scale) % MOD;
}
for (int k = 1; k <= n; ++k) {
m[k][i + 1] = (1ll * m[k][j] * scale - m[k][i + 1]) % MOD;
m[k][i + 1] = !!m[k][i + 1] * (MOD - m[k][i + 1]);
}
}
}

// std::printf("\n");
// for (int i = 1; i <= n; ++i) {
// for (int j = 1; j <= n; ++j) {
// print(m[i][j]);
// std::printf(" ");
// }
// std::printf("\n");
// }

p[0][0] = 1;
p[1][1] = 1;
p[1][0] = !!m[1][1] * (MOD - m[1][1]);
for (int i = 2; i <= n; ++i) {
for (int k = 1; k <= i; ++k) {
p[i][k] = (1ll * m[i][i] * p[i - 1][k] - p[i - 1][k - 1]) % MOD;
p[i][k] = !!p[i][k] * (MOD - p[i][k]);
}
p[i][0] = 1ll * p[i - 1][0] * m[i][i] % MOD;
p[i][0] = !!p[i][0] * (MOD - p[i][0]);
// std::printf(" %d: ", i);
// for (int k = 0; k <= n; ++k) {
// print(p[i][k]);
// std::printf(" ");
// }
// std::printf("\n");
long long k1 = m[i][i - 1];
for (int j = i - 2; j >= 0; --j) {
long long tmp = 1ll * k1 * m[j + 1][i] % MOD;
tmp = !!tmp * (MOD - tmp);
for (int k = 0; k <= i; ++k) {
p[i][k] = (p[i][k] + 1ll * p[j][k] * tmp) % MOD;
}
k1 = 1ll * k1 * m[j + 1][j] % MOD;
// std::printf(" %d, %d: ", i, j);
// for (int k = 0; k <= n; ++k) {
// print(p[i][k]);
// std::printf(" ");
// }
}
// std::printf("\n");
}
for (int i = 0; i <= n; ++i) {
std::printf("%lld ", p[n][i]);
}
std::printf("\n");

return 0;
}

}

int main(int argc, char** argv) {
return mirai::main(argc, argv);
}