这篇文章将提出 回文树 的概念,一种解决回文串相关问题的优美数据结构。
感谢它的发明者 Mikhail Rubinchik,在 Petrozavodsk Summer Camp 2014 中向我们讲述了这一概念,使得这篇文章的创作成为可能。
结构
和其他所有树形结构一样,回文树由节点构成。在这里,每个节点代表了一个回文串。

在节点之间有边相连。一条从 $u$ 连至 $v$ 的树边上会标有一个字母 $x$,表示在 $u$ 处的回文串左右加上一个 $x$ 得到 $v$ 处的回文串。

同时也存在着一些非树边,我们称之为 $fail$ 指针。从 $u$ 指向 $v$ 的 $fail$ 指针表示 $v$ 是 $u$ 的最长回文后缀。当然,$u \neq v$。

在实际使用中,我们不记录每个节点具体表示了什么串,我们记录的是它的长度,$fail$ 指针,以及子节点的编号。
构建
初始化
作为一棵树,回文树有个不寻常之处:它有两个根节点。根节点 $0$ 用来处理长度为偶数的回文串,而根节点 $1$ 用来处理长度为奇数的。
特别地,$fail(0)=1$,而 $len(1)=-1$,因为子节点的长度一定是父节点 $+2$。
插入和匹配
当我们插入一个新的字符 $x$ 时,最终形成的回文串一定是在已有的回文串的左右各加上 $x$。我们设这个新回文串为 $N$,它包含的最长回文子串为 $C$,我们要寻找的是下面式子中的 $xC$。
$$N = x C x$$
也就是,我们从上个字符代表的回文串开始,判断这个回文串的左边一个字符是否等于 $x$,如果不是则跳到 $fail$ 指针指向的后缀字符串,不断进行,直到找到满足的回文串 $C$。这个回文串 $C$ 左右加上 $x$ 之后,就是以现在的 $x$ 结尾的最长回文串 $N$。
找到最长回文串之后,还需要构建它的 $fail$ 指针,也就是 $x$ 结尾的次长 回文串。$N$ 由 $C $ 转移而来,$N$ 的 $fail$ 也由 $C$ 的 $fail$ 转移而来。用下面这个式子表示,$xC’$ 就是 $C$ 的 $fail$ 的一个后缀,$C’$ 是一个回文串。
$fail(N)=xC’x$
int find(int x){
while(s[n-t[x].len-1]!=s[n])
x=t[x].fail;
return x;
}
void add(int x){
s[++n]=x;
int cur=find(last);
if (!t[cur].vis[x]){
t[++siz].fail=t[find(t[cur].fail)].vis[x];
t[cur].vis[x]=siz;t[siz].len=t[cur].len+2;
}
last=t[cur].vis[x];t[last].cnt++;
}
应用
我们发现,一个字符串的回文树的构建,帮助我们找到了它所有本质不同的回文子串的长度和出现次数。我们可以用这个性质解决很多问题。
例:[APIO2014]回文串
给你一个由小写拉丁字母组成的字符串 $s$。我们定义 $s$ 的一个子串的存在值为这个子串在 $s$ 中出现的次数乘以这个子串的长度。
对于给你的这个字符串 $s$,求所有回文子串中的最大存在值。
思路
实际上就是个裸题啦。但有一点要注意,计算次数的时候需要从深度大的点开始向上沿着 $fail$ 累加,因为一个回文串出现也代表着它指向的 $fail$ 的出现。
代码
#include <bits/stdc++.h>
#define chkmax(a,b) (a<b?a=b:0)
using namespace std;
const int maxn=300005;
typedef long long ll;
class Palindrome_Automaton{
private:
class node{
public:
int len,cnt,fail;
int vis[26];
};
node t[maxn];
int last,n,siz,s[maxn];
int find(int x){
while(s[n-t[x].len-1]!=s[n])
x=t[x].fail;
return x;
}
public:
void init(){
s[0]=-1;
t[siz=1].len=-1;
t[0].fail=1;
}
void add(int x){
s[++n]=x;
int cur=find(last);
if (!t[cur].vis[x]){
t[++siz].fail=t[find(t[cur].fail)].vis[x];
t[cur].vis[x]=siz;t[siz].len=t[cur].len+2;
}
last=t[cur].vis[x];t[last].cnt++;
}
ll count(){
ll res=0;
for (int i=siz;~i;--i){
t[t[i].fail].cnt+=t[i].cnt;
chkmax(res,1ll*t[i].cnt*t[i].len);
}
return res;
}
}T;
char s[maxn];
int len;
int main(){
scanf("%s",s),len=strlen(s);
T.init();
for (int i=0;i<len;++i)
T.add(s[i]-'a');
printf("%lld\n",T.count());
return 0;
}