【Note】Palindromic Tree / 回文树

【Note】Palindromic Tree / 回文树

这篇文章将提出 回文树 的概念,一种解决回文串相关问题的优美数据结构。

感谢它的发明者 Mikhail Rubinchik,在 Petrozavodsk Summer Camp 2014 中向我们讲述了这一概念,使得这篇文章的创作成为可能。

结构

和其他所有树形结构一样,回文树由节点构成。在这里,每个节点代表了一个回文串。

四个节点的回文树
四节点回文树的样例

在节点之间有边相连。一条从 $u$ 连至 $v$ 的树边上会标有一个字母 $x$,表示在 $u$ 处的回文串左右加上一个 $x$ 得到 $v$ 处的回文串。

‘aba’ 由 ‘b’ 在左右两边各添加一个 ‘a’ 得来

同时也存在着一些非树边,我们称之为 $fail$ 指针。从 $u$ 指向 $v$ 的 $fail$ 指针表示 $v$ 是 $u$ 的最长回文后缀。当然,$u \neq v$。

‘a’ 是 ‘aba’ 的最长回文后缀

在实际使用中,我们不记录每个节点具体表示了什么串,我们记录的是它的长度,$fail$ 指针,以及子节点的编号。

构建

初始化

作为一棵树,回文树有个不寻常之处:它有两个根节点。根节点 $0$ 用来处理长度为偶数的回文串,而根节点 $1$ 用来处理长度为奇数的。

特别地,$fail(0)=1$,而 $len(1)=-1$,因为子节点的长度一定是父节点 $+2$。

插入和匹配

当我们插入一个新的字符 $x$ 时,最终形成的回文串一定是在已有的回文串的左右各加上 $x$。我们设这个新回文串为 $N$,它包含的最长回文子串为 $C$,我们要寻找的是下面式子中的 $xC$。

$$N = x C x$$

也就是,我们从上个字符代表的回文串开始,判断这个回文串的左边一个字符是否等于 $x$,如果不是则跳到 $fail$ 指针指向的后缀字符串,不断进行,直到找到满足的回文串 $C$。这个回文串 $C$ 左右加上 $x$ 之后,就是以现在的 $x$ 结尾的最长回文串 $N$。

找到最长回文串之后,还需要构建它的 $fail$ 指针,也就是 $x$ 结尾的次长 回文串。$N$ 由 $C $ 转移而来,$N$ 的 $fail$ 也由 $C$ 的 $fail$ 转移而来。用下面这个式子表示,$xC’$ 就是 $C$ 的 $fail$ 的一个后缀,$C’$ 是一个回文串。

$fail(N)=xC’x$

int find(int x){
    while(s[n-t[x].len-1]!=s[n])
        x=t[x].fail;
    return x;
}
void add(int x){
    s[++n]=x;
    int cur=find(last);
    if (!t[cur].vis[x]){
        t[++siz].fail=t[find(t[cur].fail)].vis[x];
        t[cur].vis[x]=siz;t[siz].len=t[cur].len+2;
    }
    last=t[cur].vis[x];t[last].cnt++;
}

应用

我们发现,一个字符串的回文树的构建,帮助我们找到了它所有本质不同的回文子串的长度和出现次数。我们可以用这个性质解决很多问题。

例:[APIO2014]回文串

给你一个由小写拉丁字母组成的字符串 $s$。我们定义 $s$ 的一个子串的存在值为这个子串在 $s$ 中出现的次数乘以这个子串的长度。

对于给你的这个字符串 $s$,求所有回文子串中的最大存在值。

思路

实际上就是个裸题啦。但有一点要注意,计算次数的时候需要从深度大的点开始向上沿着 $fail$ 累加,因为一个回文串出现也代表着它指向的 $fail$ 的出现。

代码
#include <bits/stdc++.h>
#define chkmax(a,b) (a<b?a=b:0)
using namespace std;
const int maxn=300005;
typedef long long ll;
class Palindrome_Automaton{
    private:
        class node{
            public:
                int len,cnt,fail;
                int vis[26];
        };
        node t[maxn];
        int last,n,siz,s[maxn];
        int find(int x){
            while(s[n-t[x].len-1]!=s[n])
                x=t[x].fail;
            return x;
        }
    public:
        void init(){
            s[0]=-1;
            t[siz=1].len=-1;
            t[0].fail=1;
        }
        void add(int x){
            s[++n]=x;
            int cur=find(last);
            if (!t[cur].vis[x]){
                t[++siz].fail=t[find(t[cur].fail)].vis[x];
                t[cur].vis[x]=siz;t[siz].len=t[cur].len+2;
            }
            last=t[cur].vis[x];t[last].cnt++;
        }
        ll count(){
            ll res=0;
            for (int i=siz;~i;--i){
                t[t[i].fail].cnt+=t[i].cnt;
                chkmax(res,1ll*t[i].cnt*t[i].len);
            }
            return res;
        }
}T;
char s[maxn];
int len;
int main(){
    scanf("%s",s),len=strlen(s);
    T.init();
    for (int i=0;i<len;++i)
        T.add(s[i]-'a');
    printf("%lld\n",T.count());
    return 0;
}