概念
快速沃尔什变换是一种类似 FFT,用来加速逻辑位运算卷积的算法。具体地,它被设计为用来解决以下问题:
$$C_k=\sum_{i\oplus j=k}a_i\times b_j$$
其中,$\oplus$ 表示位运算 $\lor, \wedge, \veebar$。直接枚举 $i,j$ 的复杂度是 $\Theta(n^2)$,FWT 能加速这个过程到 $\Theta(n\log n)$。
原理
前面说到它类似 FFT,因为它们都是通过先做一个变换,相乘,然后逆变换得到结果的,而变换的过程都是折半、分治。
我们定义变换为 $tf$,有
$$tf(C)=tf(A)\times tf(B)$$
对于 $tf$,分治的过程只需要前后折半,不像 FFT 需要二进制翻转,更为简洁;麻烦的是,对于每种不同的位运算,$tf$ 和 $utf$ 函数是不同的。
下面列出了三种运算的对应函数:
xor
$$\begin{aligned} tf(A)&= (tf(A_0)+tf(A_1),tf(A_0)-tf(A_1)) \\ utf(A)&= (utf(\frac{A_0+A_1}{2}),utf(\frac{A_0-A_1}{2}))\end{aligned}$$
and
$$\begin{aligned}tf(A)&=(tf({A}{0})+tf({A}{1}),tf({A}{1})) \\ utf(A)&=(utf({A}{0})-utf({A}{1}),utf({A}{1}))\end{aligned}$$
or
$$\begin{aligned}tf(A)&=(tf({A}{0}),tf({A}{1})+tf({A}{0})) \\ utf(A)&=(utf({A}{0}),utf({A}{1})-utf({A}{0})) \end{aligned}$$
所有式子中,$A_0$ 表示 $A$ 的前半部分,$A_1$ 为后半部分。
例题
给定一个序列 $a$,问 $\text{xor}$ 为 $0$ 子序列的最长长度。
思路
我们设所有元素的 $\text{xor}$ 为 $w$,我们实际上要挑出最少的元素,使得它们的 $\text{xor}$ 为 $w$,剩下的元素就是题目中要我们求的最长子序列了。
具体地,我们设 $f(i,x)$ 表示选 $i$ 个数,异或和能否恰好为 $x$。因为在最劣情况下,最大的 $i$ 依然不会超过 $\log$ 级别,我们可以直接 $FWT$ 优化转移。
代码
#include <bits/stdc++.h>
#define mid ((l+r)>>1)
#define maxn 2000005
#define md 10007
#define inv2 ((md+1)>>1)
using namespace std;
int read(){
int x=0;char ch=getchar();
while(!isdigit(ch)) ch=getchar();
while(isdigit(ch)) x=(x<<3)+(x<<1)+ch-'0',ch=getchar();
return x;
}
int n,f[maxn],g[25][maxn],w,s,bit;
void FWT(int *a,int l,int r){
if (l==r) return;
FWT(a,l,mid);FWT(a,mid+1,r);
int len=mid-l+1;
for (int i=l;i<=mid;++i){
int u=a[i],v=a[i+len];
a[i]=(u+v)%md;
a[i+len]=(u-v)%md;
}
}
void IFWT(int *a,int l,int r){
if (l==r) return;
int len=mid-l+1;
for (int i=l;i<=mid;++i){
int u=a[i],v=a[i+len];
a[i]=(u+v)*inv2%md;
a[i+len]=(u-v)*inv2%md;
}
IFWT(a,l,mid);IFWT(a,mid+1,r);
}
int main(){
freopen("sub.in","r",stdin);
freopen("sub.out","w",stdout);
n=read();
for (int i=1;i<=n;++i){
int x=read();
f[x]=1,w^=x;
}
for (s=2,bit=1;s<2*n;s<<=1,bit++);
if (!w){printf("%d\n",n);return 0;}
f[0]=1;FWT(f,0,s-1);
for (int i=0;i<s;++i)
g[0][i]=1;
for (int j=1;j<=bit;++j)
for (int i=0;i<s;++i)
g[j][i]=g[j-1][i]*f[i]%md;
int l=0,r=bit+1;
while(l<r){
IFWT(g[mid],0,s-1);
if (g[mid][w]) r=mid;
else l=mid+1;
}
if (l==bit+1) puts("0");
else printf("%d\n",n-l);
}