【Note】Splay / 文艺平衡树

【Note】Splay / 文艺平衡树

定义

Splay 是一种可以自我调节的二叉搜索树。它在 $\Theta(\log n)$ ​ 的均摊时间内执行基本操作,例如插入,查找和删除。对于许多非随机操作序列,Splay 比其他搜索树表现更好。

核心操作

我们提到 Splay 可以 “自我调节”,这一性质使得它能维持树形结构,保证复杂度。自我调节分为两部分:旋转 (rotate),以及伸展 (splay)。

旋转

伸展

实际运用

作为二叉搜索树,它能完成所有二叉搜索树都能完成的基本操作。当然,它也可以作为区间树完成诸如区间翻转的操作,正是这一性质使得它可以作为 Link Cut Tree 的辅助树。

代码

普通平衡树

您需要写一种数据结构(可参考题目标题),来维护一些数,其中需要提供以下操作:

  1. 插入 $x$ 数
  2. 删除 $x$ 数(若有多个相同的数,因只删除一个)
  3. 查询 $x$ 数的排名(排名定义为比当前数小的数的个数 $+1$。若有多个相同的数,因输出最小的排名)
  4. 查询排名为 $x$ 的数
  5. 求 $x$ 的前驱(前驱定义为小于 $x$,且最大的数)
  6. 求 $x$ 的后继(后继定义为大于 $x$,且最小的数)
#include <bits/stdc++.h>
using namespace std;
const int inf=INT_MAX;
const int maxn=2e5+10;

int read(){
	int x=0,flag=1;char ch=getchar();
	while(!isdigit(ch)&&ch!='-') ch=getchar();
	if (ch=='-') flag=-1,ch=getchar();
	while(isdigit(ch)) x=(x<<3)+(x<<1)+ch-'0',ch=getchar();
	return x*flag;
}

class Splay{
	private:
		class node{
			public:
				int val,siz,recy;
				int fa,ch[2];
				void init(int x,int f){
					val=x,fa=f;
					siz=recy=1,ch[0]=ch[1]=0;
				}
		};
		node t[maxn];
		int root,tot,point;
		int identify(int x){
			return t[t[x].fa].ch[1]==x;
		}
		void update(int x){
			t[x].siz=t[t[x].ch[0]].siz+t[t[x].ch[1]].siz+t[x].recy;
		}
		void connect(int x,int f,int son){
			t[x].fa=f,t[f].ch[son]=x;
		}
		void rotate(int x){
			int y=t[x].fa,z=t[y].fa;
			int son=identify(y);
			int per=identify(x);
			int B=t[x].ch[per^1];
			connect(B,y,per);
			connect(y,x,per^1);
			connect(x,z,son);
			update(y),update(x);
		}
		void splay(int x,int to=0){
			while(t[x].fa!=to){
				int y=t[x].fa,z=t[y].fa;
				if (z!=to)
					identify(x)!=identify(y)?rotate(x):rotate(y);
				rotate(x);
			}
			if (!to) root=x;
		}
		int find(int k){
			int u=root;
			while(t[u].ch[k>t[u].val]&&t[u].val!=k)
			u=t[u].ch[k>t[u].val];
			splay(u);return root;
		}
		void destory(int x){
			t[x].init(0,0);
			t[x].siz=t[x].recy=0;
			if (x==tot) tot--;
		}
		int atrank(int k){
			if (k>point) return -inf;
			int u=root;
			while(1){
				if (t[t[u].ch[0]].siz>=k&&t[u].ch[0])
				u=t[u].ch[0];
				else if (t[t[u].ch[0]].siz+t[u].recy>=k) return u;
				else k-=t[t[u].ch[0]].siz+t[u].recy,u=t[u].ch[1];
			}
		}
		int upper(int v){
			find(v);
			if (t[root].val>v) return root;
			int u=t[root].ch[1];
			while(t[u].ch[0]) u=t[u].ch[0];
			return u;
		}
		int lower(int v){
			find(v);
			if (t[root].val<v) return root;
			int u=t[root].ch[0];
			while(t[u].ch[1]) u=t[u].ch[1];
			return u;
		}
	public:
		void insert(int x){
			int u=root,f=0;point++;
			while(u&&t[u].val!=x)
			f=u,u=t[u].ch[x>t[u].val];
			if (u) t[u].recy++;
			else{
				u=++tot;t[u].init(x,f);
				if (f) t[f].ch[x>t[f].val]=u;
			}
			splay(u);
		}
		void kickout(int x){
			int last=lower(x),next=upper(x);
			splay(last);splay(next,last);
			int del=t[next].ch[0];
			if (t[del].recy>1)
			t[del].recy--,splay(del);
			else t[next].ch[0]=0,update(next),update(root);
		}
		int rank_val(int v){
			find(v);
			return t[t[root].ch[0]].siz+(t[root].val>=v);
		}
		int atrank_val(int v){
			return t[atrank(v)].val;
		}
		int lower_val(int v){
			return t[lower(v)].val;
		}
		int upper_val(int v){
			return t[upper(v)].val;
		}
}F;

int main(){
	F.insert(inf),F.insert(-inf);
	int T_T=read();
	while(T_T--){
		int op=read(),v=read();
		if (op==1) F.insert(v);
		if (op==2) F.kickout(v);
		if (op==3) printf("%d\n",F.rank_val(v)-1);
		if (op==4) printf("%d\n",F.atrank_val(v+1));
		if (op==5) printf("%d\n",F.lower_val(v));
		if (op==6) printf("%d\n",F.upper_val(v));
	}
	return 0;
}

区间翻转

点击查看这道题目的 详细描述

#include <bits/stdc++.h>
using namespace std;
const int inf=INT_MAX;
const int maxn=2e5+10;

int read(){
	int x=0,flag=1;char ch=getchar();
	while(!isdigit(ch)&&ch!='-') ch=getchar();
	if (ch=='-') flag=-1,ch=getchar();
	while(isdigit(ch)) x=(x<<3)+(x<<1)+ch-'0',ch=getchar();
	return x*flag;
}

int n,m;

class Splay{
	private:
		class node{
			public:
				int val,siz,rev;
				int fa,ch[2];
				void init(int x,int f){
					val=x,fa=f;
					siz=rev,ch[0]=ch[1]=0;
				}
		};
		node t[maxn];
		int root,tot,point;
		int identify(int x){
			return t[t[x].fa].ch[1]==x;
		}
		void update(int x){
			t[x].siz=t[t[x].ch[0]].siz+t[t[x].ch[1]].siz+1;
		}
		void down(int x){
			if (t[x].rev){
				t[x].rev=0;
				swap(t[x].ch[1],t[x].ch[0]);
				t[t[x].ch[1]].rev^=1;
				t[t[x].ch[0]].rev^=1;
			}
		}
		void connect(int x,int f,int son){
			t[x].fa=f,t[f].ch[son]=x;
		}
		void rotate(int x){
			int y=t[x].fa,z=t[y].fa;
			int son=identify(y);
			int per=identify(x);
			int B=t[x].ch[per^1];
			connect(B,y,per);
			connect(y,x,per^1);
			connect(x,z,son);
			update(y),update(x);
		}
		void splay(int x,int to=0){
			while(t[x].fa!=to){
				int y=t[x].fa,z=t[y].fa;
				if (z!=to)
					identify(x)!=identify(y)?rotate(x):rotate(y);
				rotate(x);
			}
			if (!to) root=x;
		}
		int atrank(int k){
			if (k>point) return -inf;
			int u=root;
			while(1){
				down(u);
				if (t[t[u].ch[0]].siz>=k&&t[u].ch[0])
				u=t[u].ch[0];
				else if (t[t[u].ch[0]].siz+1>=k) return u;
				else k-=t[t[u].ch[0]].siz+1,u=t[u].ch[1];
			}
		}
	public:
		void insert(int x){
			int u=root,f=0;point++;
			while(u&&t[u].val!=x)
			f=u,u=t[u].ch[x>t[u].val];	
			u=++tot;t[u].init(x,f);
			if (f) t[f].ch[x>t[f].val]=u;
			splay(u);
		}
		void work(int l,int r){
			l=atrank(l),r=atrank(r+2);
			splay(l),splay(r,l);
			t[t[t[root].ch[1]].ch[0]].rev^=1;
		}
		void write(int u){
			down(u);
			if (t[u].ch[0]) write(t[u].ch[0]);
			if (t[u].val>1&&t[u].val<n+2) printf("%d ",t[u].val-1);
			if (t[u].ch[1]) write(t[u].ch[1]);
		}
		void finish(){
			write(root);
		}
}F;

int main(){
	n=read(),m=read();
	for (int i=1;i<=n+2;++i)
	F.insert(i);
	while(m--){
		int l=read(),r=read();
		F.work(l,r);
	}
	F.finish();
	return 0;
}