定义
Splay 是一种可以自我调节的二叉搜索树。它在 $\Theta(\log n)$ 的均摊时间内执行基本操作,例如插入,查找和删除。对于许多非随机操作序列,Splay 比其他搜索树表现更好。
核心操作
我们提到 Splay 可以 “自我调节”,这一性质使得它能维持树形结构,保证复杂度。自我调节分为两部分:旋转 (rotate),以及伸展 (splay)。
旋转
伸展
实际运用
作为二叉搜索树,它能完成所有二叉搜索树都能完成的基本操作。当然,它也可以作为区间树完成诸如区间翻转的操作,正是这一性质使得它可以作为 Link Cut Tree 的辅助树。
代码
普通平衡树
您需要写一种数据结构(可参考题目标题),来维护一些数,其中需要提供以下操作:
- 插入 $x$ 数
- 删除 $x$ 数(若有多个相同的数,因只删除一个)
- 查询 $x$ 数的排名(排名定义为比当前数小的数的个数 $+1$。若有多个相同的数,因输出最小的排名)
- 查询排名为 $x$ 的数
- 求 $x$ 的前驱(前驱定义为小于 $x$,且最大的数)
- 求 $x$ 的后继(后继定义为大于 $x$,且最小的数)
#include <bits/stdc++.h>
using namespace std;
const int inf=INT_MAX;
const int maxn=2e5+10;
int read(){
int x=0,flag=1;char ch=getchar();
while(!isdigit(ch)&&ch!='-') ch=getchar();
if (ch=='-') flag=-1,ch=getchar();
while(isdigit(ch)) x=(x<<3)+(x<<1)+ch-'0',ch=getchar();
return x*flag;
}
class Splay{
private:
class node{
public:
int val,siz,recy;
int fa,ch[2];
void init(int x,int f){
val=x,fa=f;
siz=recy=1,ch[0]=ch[1]=0;
}
};
node t[maxn];
int root,tot,point;
int identify(int x){
return t[t[x].fa].ch[1]==x;
}
void update(int x){
t[x].siz=t[t[x].ch[0]].siz+t[t[x].ch[1]].siz+t[x].recy;
}
void connect(int x,int f,int son){
t[x].fa=f,t[f].ch[son]=x;
}
void rotate(int x){
int y=t[x].fa,z=t[y].fa;
int son=identify(y);
int per=identify(x);
int B=t[x].ch[per^1];
connect(B,y,per);
connect(y,x,per^1);
connect(x,z,son);
update(y),update(x);
}
void splay(int x,int to=0){
while(t[x].fa!=to){
int y=t[x].fa,z=t[y].fa;
if (z!=to)
identify(x)!=identify(y)?rotate(x):rotate(y);
rotate(x);
}
if (!to) root=x;
}
int find(int k){
int u=root;
while(t[u].ch[k>t[u].val]&&t[u].val!=k)
u=t[u].ch[k>t[u].val];
splay(u);return root;
}
void destory(int x){
t[x].init(0,0);
t[x].siz=t[x].recy=0;
if (x==tot) tot--;
}
int atrank(int k){
if (k>point) return -inf;
int u=root;
while(1){
if (t[t[u].ch[0]].siz>=k&&t[u].ch[0])
u=t[u].ch[0];
else if (t[t[u].ch[0]].siz+t[u].recy>=k) return u;
else k-=t[t[u].ch[0]].siz+t[u].recy,u=t[u].ch[1];
}
}
int upper(int v){
find(v);
if (t[root].val>v) return root;
int u=t[root].ch[1];
while(t[u].ch[0]) u=t[u].ch[0];
return u;
}
int lower(int v){
find(v);
if (t[root].val<v) return root;
int u=t[root].ch[0];
while(t[u].ch[1]) u=t[u].ch[1];
return u;
}
public:
void insert(int x){
int u=root,f=0;point++;
while(u&&t[u].val!=x)
f=u,u=t[u].ch[x>t[u].val];
if (u) t[u].recy++;
else{
u=++tot;t[u].init(x,f);
if (f) t[f].ch[x>t[f].val]=u;
}
splay(u);
}
void kickout(int x){
int last=lower(x),next=upper(x);
splay(last);splay(next,last);
int del=t[next].ch[0];
if (t[del].recy>1)
t[del].recy--,splay(del);
else t[next].ch[0]=0,update(next),update(root);
}
int rank_val(int v){
find(v);
return t[t[root].ch[0]].siz+(t[root].val>=v);
}
int atrank_val(int v){
return t[atrank(v)].val;
}
int lower_val(int v){
return t[lower(v)].val;
}
int upper_val(int v){
return t[upper(v)].val;
}
}F;
int main(){
F.insert(inf),F.insert(-inf);
int T_T=read();
while(T_T--){
int op=read(),v=read();
if (op==1) F.insert(v);
if (op==2) F.kickout(v);
if (op==3) printf("%d\n",F.rank_val(v)-1);
if (op==4) printf("%d\n",F.atrank_val(v+1));
if (op==5) printf("%d\n",F.lower_val(v));
if (op==6) printf("%d\n",F.upper_val(v));
}
return 0;
}
区间翻转
点击查看这道题目的 详细描述。
#include <bits/stdc++.h>
using namespace std;
const int inf=INT_MAX;
const int maxn=2e5+10;
int read(){
int x=0,flag=1;char ch=getchar();
while(!isdigit(ch)&&ch!='-') ch=getchar();
if (ch=='-') flag=-1,ch=getchar();
while(isdigit(ch)) x=(x<<3)+(x<<1)+ch-'0',ch=getchar();
return x*flag;
}
int n,m;
class Splay{
private:
class node{
public:
int val,siz,rev;
int fa,ch[2];
void init(int x,int f){
val=x,fa=f;
siz=rev,ch[0]=ch[1]=0;
}
};
node t[maxn];
int root,tot,point;
int identify(int x){
return t[t[x].fa].ch[1]==x;
}
void update(int x){
t[x].siz=t[t[x].ch[0]].siz+t[t[x].ch[1]].siz+1;
}
void down(int x){
if (t[x].rev){
t[x].rev=0;
swap(t[x].ch[1],t[x].ch[0]);
t[t[x].ch[1]].rev^=1;
t[t[x].ch[0]].rev^=1;
}
}
void connect(int x,int f,int son){
t[x].fa=f,t[f].ch[son]=x;
}
void rotate(int x){
int y=t[x].fa,z=t[y].fa;
int son=identify(y);
int per=identify(x);
int B=t[x].ch[per^1];
connect(B,y,per);
connect(y,x,per^1);
connect(x,z,son);
update(y),update(x);
}
void splay(int x,int to=0){
while(t[x].fa!=to){
int y=t[x].fa,z=t[y].fa;
if (z!=to)
identify(x)!=identify(y)?rotate(x):rotate(y);
rotate(x);
}
if (!to) root=x;
}
int atrank(int k){
if (k>point) return -inf;
int u=root;
while(1){
down(u);
if (t[t[u].ch[0]].siz>=k&&t[u].ch[0])
u=t[u].ch[0];
else if (t[t[u].ch[0]].siz+1>=k) return u;
else k-=t[t[u].ch[0]].siz+1,u=t[u].ch[1];
}
}
public:
void insert(int x){
int u=root,f=0;point++;
while(u&&t[u].val!=x)
f=u,u=t[u].ch[x>t[u].val];
u=++tot;t[u].init(x,f);
if (f) t[f].ch[x>t[f].val]=u;
splay(u);
}
void work(int l,int r){
l=atrank(l),r=atrank(r+2);
splay(l),splay(r,l);
t[t[t[root].ch[1]].ch[0]].rev^=1;
}
void write(int u){
down(u);
if (t[u].ch[0]) write(t[u].ch[0]);
if (t[u].val>1&&t[u].val<n+2) printf("%d ",t[u].val-1);
if (t[u].ch[1]) write(t[u].ch[1]);
}
void finish(){
write(root);
}
}F;
int main(){
n=read(),m=read();
for (int i=1;i<=n+2;++i)
F.insert(i);
while(m--){
int l=read(),r=read();
F.work(l,r);
}
F.finish();
return 0;
}