【笔记】快速沃尔什变换入门

【笔记】快速沃尔什变换入门

概念

快速沃尔什变换是一种类似 FFT,用来加速逻辑位运算卷积的算法。具体地,它被设计为用来解决以下问题:

$$C_k=\sum_{i\oplus j=k}a_i\times b_j$$

其中,$\oplus$ 表示位运算 $\lor, \wedge, \veebar$。直接枚举 $i,j$ 的复杂度是 $\Theta(n^2)$,FWT 能加速这个过程到 $\Theta(n\log n)$。

原理

前面说到它类似 FFT,因为它们都是通过先做一个变换,相乘,然后逆变换得到结果的,而变换的过程都是折半、分治。

我们定义变换为 $tf$,有

$$tf(C)=tf(A)\times tf(B)$$

对于 $tf$,分治的过程只需要前后折半,不像 FFT 需要二进制翻转,更为简洁;麻烦的是,对于每种不同的位运算,$tf$ 和 $utf$ 函数是不同的。

下面列出了三种运算的对应函数:

xor

$$\begin{aligned} tf(A)&= (tf(A_0)+tf(A_1),tf(A_0)-tf(A_1)) \\ utf(A)&= (utf(\frac{A_0+A_1}{2}),utf(\frac{A_0-A_1}{2}))\end{aligned}$$

and

$$\begin{aligned}tf(A)&=(tf({A}{0})+tf({A}{1}),tf({A}{1})) \\ utf(A)&=(utf({A}{0})-utf({A}{1}),utf({A}{1}))\end{aligned}$$

or

$$\begin{aligned}tf(A)&=(tf({A}{0}),tf({A}{1})+tf({A}{0})) \\ utf(A)&=(utf({A}{0}),utf({A}{1})-utf({A}{0})) \end{aligned}$$

所有式子中,$A_0$ 表示 $A$ 的前半部分,$A_1$ 为后半部分。

例题

给定一个序列 $a$,问 $\text{xor}$ 为 $0$ 子序列的最长长度。

思路

我们设所有元素的 $\text{xor}$ 为 $w$,我们实际上要挑出最少的元素,使得它们的 $\text{xor}$ 为 $w$,剩下的元素就是题目中要我们求的最长子序列了。

具体地,我们设 $f(i,x)$ 表示选 $i$ 个数,异或和能否恰好为 $x$。因为在最劣情况下,最大的 $i$ 依然不会超过 $\log$ 级别,我们可以直接 $FWT$ 优化转移。

代码

#include <bits/stdc++.h>
#define mid ((l+r)>>1)
#define maxn 2000005
#define md 10007
#define inv2 ((md+1)>>1) 
using namespace std;
int read(){
    int x=0;char ch=getchar();
    while(!isdigit(ch)) ch=getchar();
    while(isdigit(ch)) x=(x<<3)+(x<<1)+ch-'0',ch=getchar();
    return x;
}
int n,f[maxn],g[25][maxn],w,s,bit;
void FWT(int *a,int l,int r){
    if (l==r) return;
    FWT(a,l,mid);FWT(a,mid+1,r);
    int len=mid-l+1;
    for (int i=l;i<=mid;++i){
        int u=a[i],v=a[i+len];
        a[i]=(u+v)%md;
        a[i+len]=(u-v)%md;
    }
}
void IFWT(int *a,int l,int r){
    if (l==r) return;
    int len=mid-l+1;
    for (int i=l;i<=mid;++i){
        int u=a[i],v=a[i+len];
        a[i]=(u+v)*inv2%md;
        a[i+len]=(u-v)*inv2%md;
    }
    IFWT(a,l,mid);IFWT(a,mid+1,r);
}
int main(){
    freopen("sub.in","r",stdin);
    freopen("sub.out","w",stdout);
    n=read();
    for (int i=1;i<=n;++i){
        int x=read();
        f[x]=1,w^=x;
    }
    for (s=2,bit=1;s<2*n;s<<=1,bit++);
    if (!w){printf("%d\n",n);return 0;}
    f[0]=1;FWT(f,0,s-1);
    for (int i=0;i<s;++i)
        g[0][i]=1;
    for (int j=1;j<=bit;++j)
        for (int i=0;i<s;++i)
            g[j][i]=g[j-1][i]*f[i]%md;
    int l=0,r=bit+1;
    while(l<r){
        IFWT(g[mid],0,s-1);
        if (g[mid][w]) r=mid;
        else l=mid+1;
    }
    if (l==bit+1) puts("0");
    else printf("%d\n",n-l);
}