book
归档: solution 
flag

好题好做法…

感谢 gjx 的讲解。

Solution

要用到的结论: $\frac{1}{a}\frac{1}{b}=(\frac{1}{a}-\frac{1}{b})\frac{1}{b-a}$ 。

我们要求的就是下面这个生成函数的第 $n$ 项: $$\prod \limits_{i=1}^{m} \frac{1}{1-(ui+v)x}$$

上面这个生成函数可以化为这样的形式: $$(\frac{1}{ux})^{m-1} \sum \limits_{i=1}^m \frac{a_i}{1-(ui-v)x}$$

其中 $a_i$ 是该项的不定系数。

至于为什么是这样,我们下面继续。

$p_i=1-(ui+v)x$ 。我们从 $1\to m$ 处理这个生成函数,即假设我们已经求出了 $1\to m$ 的生成函数,然后我们要把这个函数乘上一个 $\frac{1}{p_{m+1}}$ 来推出 $1\to m+1$ 的生成函数。

那我们考虑对于第二种形式的生成函数的每个 $i$ ,乘上这个之后的变化。

即: $$\frac{a_i}{p_i}\frac{1}{p_{m+1}} = a_i(\frac{1}{p_i}-\frac{1}{p_{m+1}})\frac{1}{u(i-m-1)x}$$

此时我们发现对于每一项都会多乘上一个 $\frac{1}{ux}$ ,因此我们可以将其整体提出,变成上面生成函数的形式。

我们考虑第一次插入 $\frac{1}{p_{i}}$ 这一项之后,令它的系数为 $f_i$ 。那么在插入 $i+1$ 项时,$f_i$ 会乘上一个 $\frac{1}{ux(-1)}$ ,插入第 $i+2$ 项时会乘上一个 $\frac{1}{ux(-2)}$ 。我们惊讶的发现有: $$a_i=(\frac{1}{ux})^{m-i}(-1)^{m-i}\frac{1}{(m-i)!}f_i$$

再考虑 $f_{m+1}$ 的值: $$f_{m+1}=\sum \limits_{i=1}^{m} (-1)^{m-i}\frac{1}{(m+1-i)!}f_i$$

我们考虑 $f_i$ 的生成函数为 $F$ ,则有: $$F=F*(-e^{-x}+1)+x \Rightarrow F=xe^x$$

则我们可以得到最终生成函数的每一部分的第 $n$ 项的和: $$\sum \limits_{i=1}^{m} (ui+v)^{n+m-1} u^{-m+1} (-1)^{m-i} \frac{1}{(m-i)!} f_i$$

复杂度 $O(m\log n)$ 。

Code

// Code by ajcxsu
// Problem: IOer

#include<bits/stdc++.h>
#define MOD (998244353ll)
using namespace std;

template<typename T> inline void gn(T &x) {
    char ch=getchar(), pl=0; x=0;
    while(!isdigit(ch)) pl=(ch=='-'), ch=getchar();
    while(isdigit(ch)) x=x*10+ch-'0', ch=getchar(); x*=pl?-1:1;
}

const int N=2e5+10;
typedef long long ll;
ll finv[N], f[N];
ll qpow(ll x, ll y) {
    ll ret=1;
    while(y) {
        if(y&1) ret=ret*x%MOD;
        x=x*x%MOD, y>>=1;
    }
    return ret;
}
ll n, m, u, v;
ll pl[]={1, -1};
int solve() {
    int ans=0;
    ll invu=qpow(u, (MOD-2)*(m-1));
    for(int i=1; i<=m; i++)
        ans=(ans+qpow((u*i+v)%MOD, n+m-1)*invu%MOD*f[i]%MOD*pl[(m-i)&1]*finv[m-i]%MOD+MOD)%MOD;
    return ans;
}

int main() {
    finv[0]=finv[1]=1;
    for(int i=2; i<N; i++) finv[i]=MOD-1ll*MOD/i*finv[MOD%i]%MOD;
    for(int i=2; i<N; i++) finv[i]=finv[i]*finv[i-1]%MOD;
    for(int i=0; i<N-1; i++) f[i+1]=finv[i];
    int T;
    gn(T);
    while(T--) {
        gn(n), gn(m), gn(u), gn(v);
        printf("%d\n", solve());
    }
    return 0;
}
navigate_before navigate_next