【笔记】从傅里叶变换到数论变换

【笔记】从傅里叶变换到数论变换

概述

我们可以使用傅里叶变换来快速计算卷积。然而快速傅立叶变换具有一些实现上的缺点,我们必须做复数而且是浮点数的运算,因此计算量会比较大,而且浮点数运算产生的误差会比较大。

这个问题可以用数论变换解决。在模意义下,一些整数可以发挥单位复根的作用,使我们更精确地计算只有整数参加的卷积运算结果。

原根

概念

在 $FFT$ 中,我们利用了单位复根的以下性质:

  1. $\omega_n^k$ 互不相同;
  2. $\omega_{2n}^{2k}=\omega_n^k$;
  3. $\omega_n^{k+\frac{n}{2}}=-\omega_n^k$
  4. 当 $k \neq 0$ 时,$1+\omega_n^k+(\omega_n^k)^2+\cdots+(\omega_n^k)^{n-1}=0$。

在数论变换中,若当前的模数 $p$ 能被表示成 $p=qn+1$(其中 $n$ 为 $2$ 的幂),它有原根 $g$,使得 $g^i(0\leq i\leq p-1)$ 互不相同。令 $\omega_n = g^q$,我们来一一验证这四个性质。

性质一

$1,g^q,g^{2q},\cdots,g^{(n-1)q}$ 互不相同。

性质二

$\omega_{2n}=g^{\frac{q}{2}}(p=\frac{q}{2}\times2n+1)$,有 $\omega_{2n}^{2k}=g^{2k\frac{q}{2}}=g^{kq}=\omega_n^k$。

性质三

根据费马小定理,有

$$\omega_n^n=g^{nq}=g^{p-1}\equiv1\pmod p$$

可得 $\omega_n^{\frac{n}{2}}\equiv \pm1 \pmod p$。根据性质一,$\omega_n^{\frac{n}{2}}\neq \omega_n^0$,所以 $\omega_n^{\frac{n}{2}}\equiv -1\pmod p$,$\omega_n^{k+\frac{n}{2}}=\omega_n^k \times \omega_n^{\frac{n}{2}}=-\omega_n^k$。

性质四

当 $k\neq 0$ 时,

$$\begin{aligned} S(\omega_n^k)&=1+\omega_n^k+(\omega_n^k)^2+\cdots+(\omega_n^k)^{n-1}\\ \omega_n^k S(\omega_n^k)&=\omega_n^k+(\omega_n^k)^2+(\omega_n^k)^3+\cdots+(\omega_n^k)^{n}\\ S(\omega_n^k)&=\frac{(\omega_n^k)^n-1}{\omega_n^k-1}\end{aligned}$$

根据性质三,$(\omega_n^k)^n-1 \equiv 0 \pmod p$,故 $S(\omega_n^k)=0$。

构造

对于 $p$,我们可以从小到大枚举所有的 $g$。

对于一个 $g$,满足 $g^k\equiv1\pmod p$ 的最小的 $k$ 一定是 $p-1$ 的约数。

所以,判断 $g$ 合法性的方法就是枚举 $p-1$ 的约数 $k$,使得所有 $g^k \not\equiv 1 \pmod p$ 的 $g$ 合法。

实现

把 $FFT$ 中所有复数运算修改为模意义下的整数运算即可。

例题

定义一个无向图权值为所有结点度数的 $k$ 次方之和(规定 $0^0=1$)。

求所有 $n$ 个点的简单无向图(共有 $2^{C(n,2)}$ 个)的权值之和,对 $998244353$ 取模。

思路

考虑每个点的贡献,答案为 $2^{C(n-1,2)}n\sum_{i=0}^{n-1}i^kC(n-1,i)$。

我们要求的实际上是 $\sum_{i=0}^ni^kC(n,i)$。

根据第二类斯特林数反演,我们有 $x^k=\sum_{i=0}^kS(k,i)C(x,i)i!$。

证明:左边相当于 $k$ 个不同的小球放入 $x$ 个不同的箱子的方案数,而右边相当于枚举有几个箱子非空。

代入原式,得到

$$\begin{aligned}&\sum_{i=0}^ni^kC(n,i) \\ =& \sum_{i=0}^nC(n,i)\sum_{j=0}^kS(k,j)C(i,j)j! \\ =& \sum_{j=0}^kS(k,j)j!\sum_{i=0}^n C(n,j)C(n-j,i-j) \\ =& \sum_{j=0}^k S(k,j)j!C(n,j)2^{n-j}\end{aligned}$$

考虑二项式反演,

令 $f(n)=\sum_{i=0}^nC(n,i)g(i)$,则 $g(n)=\sum_{i=0}^n(-1)^{n-i}C(n,i)f(i)$。

可得 $S(k,i)i!=\sum_{j=0}^i(-1)^{i-j}C(i,j)j^k$,$NTT$ 即可。

代码

#include <bits/stdc++.h>
#define maxm 1048576
#define maxn 600005
#define ll long long
#define md 998244353
using namespace std;
int read(){
    int x=0,flag=1;char ch=getchar();
    while(!isdigit(ch)&&ch!='-') ch=getchar();
    if (ch=='-') flag=-1,ch=getchar();
    while(isdigit(ch)) x=(x<<3)+(x<<1)+(ch-'0'),ch=getchar();
    return x*flag;
}
int n,m,rev[maxm+1],bit,s;
ll fac[maxn],inv[maxn],inv2;
ll wi[maxm+1],a[maxm+1],b[maxm+1];
ll ksm(ll a,ll b){
    ll res=1;
    while(b){
        if (b&1) res=res*a%md;
        a=a*a%md;b>>=1;
    }
    return res;
}
void init(){
    fac[0]=wi[0]=1;
    for (int i=1;i<=m+1;++i)
        fac[i]=fac[i-1]*i%md;
    inv[m+1]=ksm(fac[m+1],md-2);
    for (int i=m+1;i>=1;--i)
        inv[i-1]=inv[i]*i%md;
    for (bit=1,s=2;(1<<bit)<2*m+2;++bit)
        s<<=1;
    ll c=ksm(3,(md-1)/s);
    inv2=ksm(s,md-2);
    for (int i=1;i<=s;++i){
        if(i!=s) rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
        wi[i]=wi[i-1]*c%md; 
    }
}
void ntt(ll *a,int n,int dft){
    for (int i=0;i<n;++i)
        if (i<rev[i]) swap(a[i],a[rev[i]]);
    for (int step=1;step<n;step<<=1){
        ll wn=dft?wi[n-n/(step<<1)]:wi[n/(step<<1)];
        for (int j=0;j<n;j+=step<<1){
            ll wnk=1;
            for (int k=j;k<j+step;++k,(wnk*=wn)%=md){
                ll t=wnk*a[k+step]%md;
                a[k+step]=(a[k]-t+md)%md;
                (a[k]+=t)%=md;
            }
        }
    }
    if (dft)
        for (int i=0;i<n;++i)
            (a[i]*=inv2)%=md;
}
int main(){
    n=read(),m=read();
    if(m==0){
        printf("%lld\n",ksm(2,n-1)*n%md*ksm(2,(ll)(n-1)*(ll)(n-2)/2)%md);
        return 0;
    }
    init();
    ll ans=0,v=1,c=1;
    int l=max(0,n-1-m);
    for (int i=0;i<=m;++i)
        a[i]=(inv[i]*v+md)%md,
        b[i]=ksm(i,m)*inv[i]%md,v=-v;
    ntt(a,s,0),ntt(b,s,0);
    for (int i=0;i<s;++i)
        (a[i]*=b[i])%=md;
    ntt(a,s,1);
    ll vs=ksm(2,n-1);
    for (int i=n-1;i>=l;--i){
        ans=(ans+vs*a[n-1-i]%md*c%md)%md;
        vs=vs*inv[2]%md;
        c=c*(ll)i%md;
    }
    ans=ans*(ll)n%md*ksm(2,(ll)(n-1)*(ll)(n-2)/2)%md;
    printf("%lld\n",(ans+md)%md);
    return 0;
}

Show Comments