多项式优化常系数齐次线性递推

参考

- https://www.cnblogs.com/Troywar/p/9078013.html
- https://www.cnblogs.com/cjyyb/p/10152566.html
- https://www.cnblogs.com/BAJimH/p/10574975.html
- https://blog.csdn.net/jokerwyt/article/details/85345981

线性递推

给出长为 $k$ 的系数数列 $a_1,a_2,\dots,a_k$ 和一个无穷数列 $f$ 的前 $k$ 项 $f_0,f_1,\dots,f_{k-1}$,求 $f_n$(对 $998244353$ 取模)。

$$f_n=\sum_{i=1}^{k}a_if_{n-i}\qquad(n\ge k)$$

不同做法的复杂度比较

  • 暴力递推:$O(nk)$
  • 矩阵快速幂优化:$O(k^3\log n)$
  • 暴力多项式乘法 + 快速幂优化:$O(k^2\log n)$
  • 快速幂套 NTT/FFT 多项式取模优化:$O(k\log k\log n)$

求解思路

用矩阵快速幂求线性递推时,从一个初始矩阵开始,用矩阵乘法不断递推,最后再和 $f$ 相乘得到答案。这里的主要复杂度取决于矩阵的阶数 $k$:如果 $k$ 很大,$O(k^3\log n)$ 甚至还不如直接暴力的 $O(nk)$,所以就有了多项式的做法。

思路和快速幂一样:把矩阵乘法换成多项式乘法,并在每次乘法后对特征多项式 $G(x)=x^k-\sum_{i=1}^{k}a_ix^{k-i}$ 取模。

多项式乘法可以用NTT加速。

多项式取模: $$A(x)=B(x)D(x)+R(x)$$

已知 $n$ 次多项式 $A(x)$ 和 $m$ 次多项式 $B(x)$,求商 $D(x)$($n-m$ 次)和余数 $R(x)$(次数小于 $m$)。

步骤:

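  • 将多项式系数反转:记 $A_R(x)=x^nA(1/x)$、$B_R(x)=x^mB(1/x)$、$D_R(x)=x^{n-m}D(1/x)$。由于 $R(x)$ 的次数小于 $m$,反转后余数只影响 $x^{n-m+1}$ 及更高次的项,因此 $A_R(x)\equiv B_R(x)D_R(x) \pmod{x^{n-m+1}}$。
  • $D_R(x)\equiv A_R(x)\,B_R^{-1}(x) \pmod{x^{n-m+1}}$,再把 $D_R(x)$ 反转即得 $D(x)$,即 A 乘 B 的逆再反转即可。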
  • R(x)直接用A(x)-B(x)D(x)得到。

然后就到为什么可以用多项式处理常系数齐次线性递推。

由于笔者能力有限,只能看着大佬们的博客敲敲模板,详细解法不再赘述。

整理一下思路:

已知 $x^n \bmod G(x)$(一个次数小于 $k$ 的多项式),通过以下步骤得到 $x^{2n} \bmod G(x)$:

  • 将该多项式平方,使用FFT加速,$O(k\log k)$。
  • 将求得的多项式对特征多项式 $G(x)$ 取模,$O(k\log k)$。

因此,要求 $f_n$,从 $x^1$ 开始倍增即可,就是上文说的多项式快速幂。得到 $x^n \bmod G(x)=\sum_{i=0}^{k-1}c_ix^i$ 后,答案为 $f_n=\sum_{i=0}^{k-1}c_if_i$。而代码里的一些操作就是黑科技了。

笔者没有用NTT,直接用的任意模数MTT。使用方法为:

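```cpp
inline void MTT(ll *x, ll *y, ll *z, int len);
// Multiplies polynomials x and y and writes the product into z (the function returns void).
// len is the FFT length: the power of two returned by getLen(), at least the length of the product.
// Note: x and y are reduced into [0, mod) in place.
```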

Code

#include <bits/stdc++.h>

using namespace std;

typedef long long ll;
typedef long double ld;
typedef pair<int, int> pdd;  // NOTE(review): despite the name this is a pair of ints; not used in the visible code.

#define INF 0x3f3f3f3f
// Lowest set bit of x. Fully parenthesized so the macro composes safely inside larger expressions.
#define lowbit(x) ((x) & (-(x)))
#define mem(a, b) memset(a , b , sizeof(a))
#define FOR(i, x, n) for(int i = x;i <= n; i++)

// Prime modulus used by all modular arithmetic below (998244353 = 119 * 2^23 + 1).
const ll mod = 998244353;
// const ll mod = 1e9 + 7;
// const double eps = 1e-6;
const double PI = acos(-1);
// const double R = 0.57721566490153286060651209;

// Capacity of every global work array; FFT lengths returned by getLen() must not exceed it.
const int N = 3e5 + 10;

// Minimal complex number used by the floating-point FFT.
// Operators are const so they also work on const objects and temporaries.
struct Complex {
    double x, y;  // real and imaginary parts
    Complex(double a = 0, double b = 0): x(a), y(b) {}
    Complex operator + (const Complex &rhs) const { return Complex(x + rhs.x, y + rhs.y); }
    Complex operator - (const Complex &rhs) const { return Complex(x - rhs.x, y - rhs.y); }
    Complex operator * (const Complex &rhs) const { return Complex(x * rhs.x - y * rhs.y, x * rhs.y + y * rhs.x); }
    // Complex conjugate.
    Complex conj() const { return Complex(x, -y); }
} w[N];  // roots of unity for the current FFT length (filled by getLen)

// Bit-reversal permutation table for the current FFT length (filled by getLen).
int tr[N];

// Modular exponentiation by repeated squaring: returns a^b mod `mod`.
// The base is first reduced into [0, mod), so a * a cannot overflow 64 bits even
// when the caller passes an unreduced or negative value. b is expected to be
// non-negative; a negative exponent returns 1 instead of looping forever.
ll quick_pow(ll a, ll b) {
    a %= mod;
    if (a < 0) a += mod;
    ll ans = 1;
    while (b > 0) {
        if (b & 1) ans = ans * a % mod;
        a = a * a % mod;
        b >>= 1;
    }
    return ans;
}

// Prepares the FFT for polynomials with up to n coefficients.
// Picks the smallest power of two len >= 2n, fills the bit-reversal table
// tr[0..len) and the roots of unity w[i] = cos(2*PI*i/len) + i*sin(2*PI*i/len),
// and returns len.
// Side effect: overwrites the global tr/w tables, so every FFT/MTT call must use
// the len returned by the most recent getLen() call. len must not exceed N.
int getLen(int n) {
    int len = 1;
    while (len < (n << 1)) len <<= 1;
    tr[0] = 0;
    for (int i = 1; i < len; i++)
        // Reverse the upper bits via tr[i >> 1], then move bit 0 to the top position.
        tr[i] = (tr[i >> 1] >> 1) | ((i & 1) ? len >> 1 : 0);
    for (int i = 0; i < len; i++)
        w[i] = Complex(cos(2 * PI * i / len), sin(2 * PI * i / len));
    return len;
}

// Reverses the first n coefficients f[0..n) in place; does nothing when n <= 1.
void rever(ll *f, int n) {
    if (n > 1) std::reverse(f, f + n);
}

// In-place iterative radix-2 FFT of A[0..len).
// Uses the global tables prepared by the most recent getLen() call: tr[] (bit-reversal
// permutation) and w[] (w[i] = cos(2*PI*i/len) + i*sin(2*PI*i/len)).
// No 1/len scaling is applied.
void FFT(Complex *A, int len) {
    // Put the input into bit-reversed index order.
    for (int i = 0; i < len; i++) {
        if (i < tr[i]) swap(A[i], A[tr[i]]);
    }
    // Combine pairs of sub-transforms of size `half`; w[k * stride] is the twiddle
    // factor for position k inside a merged block of size 2 * half.
    for (int half = 1, stride = len >> 1; (half << 1) <= len; half <<= 1, stride >>= 1) {
        for (int base = 0; base < len; base += half << 1) {
            for (int k = 0; k < half; k++) {
                Complex &lo = A[base + k];
                Complex &hi = A[base + half + k];
                Complex t = hi * w[k * stride];
                hi = lo - t;
                lo = lo + t;
            }
        }
    }
}

// Multiplies two polynomials whose coefficients are taken modulo `mod`, using a
// split floating-point FFT (arbitrary-modulus multiplication, "MTT"):
//   z = x * y  (cyclic convolution of length len), every coefficient reduced mod `mod`.
// Each coefficient c in [0, mod) is split as c = lo + hi * 2^15 with lo, hi < 2^15
// (mod < 2^30), and the four partial products lo*lo, lo*hi, hi*lo, hi*hi are
// recovered with two forward FFTs plus two more FFTs on index-reversed spectra.
// Parameters:
//   x, y - inputs, read at indices [0, len). NOTE: both are modified in place by
//          (v + mod) % mod, which maps any value >= -mod into [0, mod).
//   z    - output, written at indices [0, len). It may alias x or y, because all
//          inputs are consumed before z is written.
//   len  - transform length. It must equal the value returned by the most recent
//          getLen() call (FFT reads the global tr/w tables), and it should be at least
//          the product length to avoid wrap-around.
// Exactness relies on double rounding error staying below 0.5 in the final rounding.
inline void MTT(ll *x, ll *y, ll *z, int len) {

for (int i = 0; i < len; i++) (x[i] += mod) %= mod, (y[i] += mod) %= mod;
static Complex a[N], b[N];
static Complex dfta[N], dftb[N], dftc[N], dftd[N];

// Pack the low 15 bits into the real part and the high bits into the imaginary part.
for (int i = 0; i < len; i++) a[i] = Complex(x[i] & 32767, x[i] >> 15);
for (int i = 0; i < len; i++) b[i] = Complex(y[i] & 32767, y[i] >> 15);
FFT(a, len), FFT(b, len);
for (int i = 0; i < len; i++) {
// j is the index -i modulo len. A[i] and conj(A[-i]) separate the transforms of the real and imaginary parts.
int j = (len - i) & (len - 1);
static Complex da, db, dc, dd;
da = (a[i] + a[j].conj()) * Complex(0.5, 0);   // transform of x's low part at i
db = (a[i] - a[j].conj()) * Complex(0, -0.5);  // transform of x's high part at i
dc = (b[i] + b[j].conj()) * Complex(0.5, 0);   // transform of y's low part at i
dd = (b[i] - b[j].conj()) * Complex(0, -0.5);  // transform of y's high part at i
// Products are stored at index -i. Running the forward FFT on this index-reversed
// spectrum below yields len times the inverse transform.
dfta[j] = da * dc;
dftb[j] = da * dd;
dftc[j] = db * dc;
dftd[j] = db * dd;
}
// Pack two real-valued results per complex array: real part and imaginary part.
for (int i = 0; i < len; i++) a[i] = dfta[i] + dftb[i] * Complex(0, 1);
for (int i = 0; i < len; i++) b[i] = dftc[i] + dftd[i] * Complex(0, 1);
FFT(a, len), FFT(b, len);
for (int i = 0; i < len; i++) {
// Divide by len and round to recover the integer convolutions (lo*lo, lo*hi, hi*lo, hi*hi).
int da = (ll)(a[i].x / len + 0.5) % mod;
int db = (ll)(a[i].y / len + 0.5) % mod;
int dc = (ll)(b[i].x / len + 0.5) % mod;
int dd = (ll)(b[i].y / len + 0.5) % mod;
// Recombine: lo*lo + (lo*hi + hi*lo) * 2^15 + hi*hi * 2^30. db + dc < 2*mod fits in int.
z[i] = (da + ((ll)(db + dc) << 15) + ((ll)dd << 30)) % mod;
}
}

// Computes g = f^{-1} mod x^n (power-series inverse) by Newton iteration.
// If g0 = f^{-1} mod x^{ceil(n/2)}, then g = g0 * (2 - f * g0) mod x^n.
// Requirements:
//   - f[0] must be invertible mod `mod` (its inverse is taken with Fermat's little theorem).
//   - g is expected to be zero on entry. MTT reads g over [0, len) at each level,
//     but only the prefix from the inner levels is explicitly cleared.
// Writes g[0..n) and zeroes g[n..len).
// Side effect: calls getLen(), so on return the global tr/w tables correspond to getLen(n).
void Get_Inv(ll *f, ll *g, int n) {
if(n == 1) { g[0] = quick_pow(f[0], mod - 2); return ; }
// First obtain the inverse modulo x^{ceil(n/2)}.
Get_Inv(f, g, (n + 1) >> 1);

int len = getLen(n);
static ll c[N];
// c = f mod x^n, zero-padded to len.
for(int i = 0;i < len; i++) c[i] = i < n ? f[i] : 0;
// c = f * g0 * g0 (degree < 2n, so no wrap-around with len >= 2n).
MTT(c, g, c, len); MTT(c, g, c, len);
// g = 2*g0 - f*g0^2, truncated to n terms.
for(int i = 0;i < n; i++) g[i] = (2ll * g[i] - c[i] + mod) % mod;
for(int i = n;i < len; i++) g[i] = 0;
// Leave the static scratch buffer clean for the next call.
for(int i = 0;i < len; i++) c[i] = 0;
}

// Shared state for the recurrence solver (see DITI / fpow / Mod).
int len;      // FFT length, set in DITI via getLen(k + 1)
int n;        // index of the requested term
int k;        // order of the recurrence
ll a[N];      // recurrence coefficients a[1..k]
ll h[N];      // initial terms h[0..k-1]
ll ans[N];    // x^n mod G, produced by fpow
ll s[N];      // running base (x raised to a power of two) mod G, used by fpow
ll invG[N];   // inverse of the reversed characteristic polynomial mod x^(k+1)
ll G[N];      // characteristic polynomial G(x) = x^k - sum a[i] x^(k-i)

// Reduces f modulo the global characteristic polynomial G (degree k) and writes the
// remainder (k coefficients) to g; g may alias f.
// Preconditions, as implied by the indices used below:
//   - f has at most 2k-1 coefficients (deg f <= 2k-2).
//   - invG holds the inverse of reversed G mod x^{k+1}.
//   - The global len and the tr/w tables match getLen(k + 1).
// Method: with f = Q*G + R and deg Q = k-2, reversing gives f_R = Q_R * G_R (mod x^{k-1}).
// So Q_R = f_R * G_R^{-1} truncated to k-1 terms, and R = f - Q*G.
void Mod(ll *f,ll *g) {
static ll tmp[N];
// Reverse f as a polynomial of length 2k-1.
rever(f, k + k - 1);
// Copy k coefficients; only the first k-1 affect the quotient, and the rest is truncated below.
for(int i = 0;i < k; i++) tmp[i] = f[i];
MTT(tmp, invG, tmp, len);
// Keep k-1 terms: this is the reversed quotient Q_R.
for(int i = k - 1; i < len; i++) tmp[i] = 0;
// Restore f and un-reverse the quotient.
rever(f, k + k - 1); rever(tmp, k - 1);
// tmp = Q * G; the remainder is f - Q*G in its lowest k coefficients.
MTT(tmp, G, tmp, len);
for(int i = 0;i < k; i++) g[i] = (f[i] + mod - tmp[i]) % mod;
for(int i = k;i < len; i++) g[i] = 0;
// Leave the static scratch buffer clean for the next call.
for(int i = 0;i < len; i++) tmp[i] = 0;
}

void fpow(int b) {
s[1] = 1; ans[0] = 1;
while(b) {
if(b & 1) { MTT(ans, s, ans, len);
Mod(ans, ans); }
MTT(s, s, s, len);
Mod(s, s);
b >>= 1;
}
}

// Returns the n-th term (mod `mod`) of the linear recurrence
//   f_m = sum_{i=1}^{k} a[i] * f_{m-i}   (m >= k),
// given the initial terms h[0..k) = f_0 .. f_{k-1}.
// Method: f_n = sum_i h[i] * [x^i](x^n mod G(x)), where G(x) = x^k - sum_{i=1}^{k} a[i] x^{k-i}.
// Parameters:
//   a   - coefficients a[1..k] (any integers; they are reduced mod `mod`)
//   h   - initial terms h[0..k) (any integers; they are reduced mod `mod`)
//   ans - receives the k coefficients of x^n mod G
//   n   - index of the requested term (n >= 0)
//   k   - order of the recurrence (k >= 1)
// Uses the global work arrays G, invG, s, ::ans and the globals k and len. They are
// (re)initialized here, so the function can be called more than once.
ll DITI(ll *a, ll *h, ll *ans, int n, int k) {
    // Mod() and fpow() read the global k; keep it in sync with the argument.
    ::k = k;
    // Clear state left over from a previous call. Get_Inv expects invG to start
    // zeroed, and Mod() multiplies by G over [0, len).
    fill(G, G + N, 0LL);
    fill(invG, invG + N, 0LL);

    G[k] = 1;
    for (int i = 1; i <= k; i++) G[k - i] = (mod - a[i] % mod) % mod;
    // Precompute the inverse of reversed G mod x^{k+1}, which Mod() uses for division.
    rever(G, k + 1);
    len = getLen(k + 1);
    Get_Inv(G, invG, k + 1);
    for (int i = k + 1; i < len; i++) invG[i] = 0;
    rever(G, k + 1);

    // Global ::ans = x^n mod G.
    fpow(n);
    // fpow() always writes the global array; copy into the caller's buffer when it differs.
    if (ans != ::ans) copy(::ans, ::ans + k, ans);

    ll Ans = 0;
    for (int i = 0; i < k; i++) {
        ll hi = (h[i] % mod + mod) % mod;
        Ans = (Ans + hi * ans[i]) % mod;
    }
    return Ans;
}

void solve()
{
cin >> n >> k;
for(int i = 1;i <= k; i++){ cin >> a[i]; a[i] = a[i] < 0 ? a[i] + mod : a[i]; }
for(int i = 0;i < k; i++) { cin >> h[i]; h[i] = h[i] < 0 ? h[i] + mod : h[i]; }

ll Ans = DITI(a, h, ans, n, k);
cout << Ans << endl;
}

// Entry point. Normally solves a single test case.
// When FZT_ACM_LOCAL is defined, it runs a local harness instead:
//   - stdin/stdout are redirected to in.txt/out.txt;
//   - solve() runs once per test case until EOF or a '$' marker, with at most 20 cases;
//   - each run's time is reported on stderr.
signed main() {
    ios_base::sync_with_stdio(false);
    //cin.tie(nullptr);
    //cout.tie(nullptr);
#ifdef FZT_ACM_LOCAL
    freopen("in.txt", "r", stdin);
    freopen("out.txt", "w", stdout);
    signed test_index_for_debug = 1;
    char acm_local_for_debug = 0;
    do {
        if (acm_local_for_debug == '$') exit(0);
        if (test_index_for_debug > 20)
            throw runtime_error("Check the stdin!!!");
        auto start_clock_for_debug = clock();
        solve();
        auto end_clock_for_debug = clock();
        cout << "Test " << test_index_for_debug << " successful" << endl;
        cerr << "Test " << test_index_for_debug++ << " Run Time: "
             << double(end_clock_for_debug - start_clock_for_debug) / CLOCKS_PER_SEC << "s" << endl;
        cout << "--------------------------------------------------" << endl;
        // Peek at the next non-whitespace character to decide whether another test follows.
    } while (cin >> acm_local_for_debug && cin.putback(acm_local_for_debug));
#else
    solve();
#endif
    return 0;
}

—— 多项式是真的难!!! ——

本文作者:jujimeizuo
本文地址https://blog.jujimeizuo.cn/2020/11/04/polynomial-optimization-constant-coefficient-homogeneous-linear-recursion/
本博客所有文章除特别声明外,均采用 CC BY-SA 3.0 协议。转载请注明出处!