概念
快速沃尔什变换是一种类似 FFT,用来加速逻辑位运算卷积的算法。具体地,它被设计为用来解决以下问题:
$$C_k=\sum_{i\oplus j=k}a_i\times b_j$$
其中,$\oplus$ 表示位运算 $\lor, \wedge, \veebar$。直接枚举 $i,j$ 的复杂度是 $\Theta(n^2)$,FWT 能加速这个过程到 $\Theta(n\log n)$。
原理
前面说到它类似 FFT,因为它们都是通过先做一个变换,相乘,然后逆变换得到结果的,而变换的过程都是折半、分治。
我们定义变换为 $tf$,有
$$tf(C)=tf(A)\times tf(B)$$
对于 $tf$,分治的过程只需要前后折半,不像 FFT 需要二进制翻转,更为简洁;麻烦的是,对于每种不同的位运算,$tf$ 和 $utf$ 函数是不同的。
下面列出了三种运算的对应函数:
xor
$$\begin{aligned} tf(A)&= (tf(A_0)+tf(A_1),tf(A_0)-tf(A_1)) \\ utf(A)&= (utf(\frac{A_0+A_1}{2}),utf(\frac{A_0-A_1}{2}))\end{aligned}$$
and
$$\begin{aligned}tf(A)&=(tf({A}{0})+tf({A}{1}),tf({A}{1})) \\ utf(A)&=(utf({A}{0})-utf({A}{1}),utf({A}{1}))\end{aligned}$$
or
$$\begin{aligned}tf(A)&=(tf({A}{0}),tf({A}{1})+tf({A}{0})) \\ utf(A)&=(utf({A}{0}),utf({A}{1})-utf({A}{0})) \end{aligned}$$
所有式子中,$A_0$ 表示 $A$ 的前半部分,$A_1$ 为后半部分。
例题
给定一个序列 $a$,问 $\text{xor}$ 为 $0$ 子序列的最长长度。
思路
我们设所有元素的 $\text{xor}$ 为 $w$,我们实际上要挑出最少的元素,使得它们的 $\text{xor}$ 为 $w$,剩下的元素就是题目中要我们求的最长子序列了。
具体地,我们设 $f(i,x)$ 表示选 $i$ 个数,异或和能否恰好为 $x$。因为在最劣情况下,最大的 $i$ 依然不会超过 $\log$ 级别,我们可以直接 $FWT$ 优化转移。
代码
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 |
#include <bits/stdc++.h> #define mid ((l+r)>>1) #define maxn 2000005 #define md 10007 #define inv2 ((md+1)>>1) using namespace std; int read(){ int x=0;char ch=getchar(); while(!isdigit(ch)) ch=getchar(); while(isdigit(ch)) x=(x<<3)+(x<<1)+ch-'0',ch=getchar(); return x; } int n,f[maxn],g[25][maxn],w,s,bit; void FWT(int *a,int l,int r){ if (l==r) return; FWT(a,l,mid);FWT(a,mid+1,r); int len=mid-l+1; for (int i=l;i<=mid;++i){ int u=a[i],v=a[i+len]; a[i]=(u+v)%md; a[i+len]=(u-v)%md; } } void IFWT(int *a,int l,int r){ if (l==r) return; int len=mid-l+1; for (int i=l;i<=mid;++i){ int u=a[i],v=a[i+len]; a[i]=(u+v)*inv2%md; a[i+len]=(u-v)*inv2%md; } IFWT(a,l,mid);IFWT(a,mid+1,r); } int main(){ freopen("sub.in","r",stdin); freopen("sub.out","w",stdout); n=read(); for (int i=1;i<=n;++i){ int x=read(); f[x]=1,w^=x; } for (s=2,bit=1;s<2*n;s<<=1,bit++); if (!w){printf("%d\n",n);return 0;} f[0]=1;FWT(f,0,s-1); for (int i=0;i<s;++i) g[0][i]=1; for (int j=1;j<=bit;++j) for (int i=0;i<s;++i) g[j][i]=g[j-1][i]*f[i]%md; int l=0,r=bit+1; while(l<r){ IFWT(g[mid],0,s-1); if (g[mid][w]) r=mid; else l=mid+1; } if (l==bit+1) puts("0"); else printf("%d\n",n-l); } |