【Note】FHQTreap / 非旋 Treap

【Note】FHQTreap / 非旋 Treap

定义

当我们提到 “Treap” 的时候,我们实际上说的是 “Tree” + “Heap”。Treap 的每个节点拥有两个值:”Key” 和 “Val”。其中 “Key” 满足堆的性质,而 “Val” 满足二叉搜索树的性质。

朴素的 Treap 在插入的时候给每个点一个随机的 Key,通过旋转来满足堆的性质。对于 FHQTreap,我们不再使用旋转维护,转而使用拆分和合并。

核心操作

拆分

我们将一个 Treap 按特定的值 Val 拆分成两个 Treap – 左树和右树,使得左树的所有值小于等于 Val,而右树的所有值大于 Val。

void split(int now,int &a,int &b,int val){
    if (!now){
        a=b=0;
        return;
    }
    if (t[now].val<=val)
        a=now,split(t[now].rch,t[a].rch,b,val);
    else b=now,split(t[now].lch,a,t[b].lch,val);
    update(now);

合并

我们可以按照堆的性质,将两个 Treap a, b ,其中 a 包含的所有值小于 b 中所有元素,合并到一个 Treap 上。

void merge(int &now,int a,int b){
    if (!a||!b){
        now=a+b;
        return;
    }
    if (t[a].key<t[b].key)
        now=a,merge(t[now].rch,t[a].rch,b);
    else now=b,merge(t[now].lch,a,t[b].lch);
    update(now);
}

实际运用

当我们想要修改 / 获取某个点的权值时,我们希望操作的影响范围尽可能地小。比如,当我们期望删除一个点的时候,如果它恰好在根节点,我们可以通过合并它的左右子树来达成这一目的。上文提到的拆分和合并恰好可以完成这个操作。

同样的思路适用于以下所有问题:

  • 插入一个数值;
  • 删除一个数值;
  • 查询给定数值的排名;
  • 查询给定排名的数值;
  • 查询给定数值的前驱 / 后继。

代码

#include <bits/stdc++.h>
#define inf 2147483647
#define maxn 100010
using namespace std;
int seed=19260817;
int myrand(){
    return seed=int(seed*482811ll%inf);
}
int read(){
    int x=0,flag=1;char ch=getchar();
    while(!isdigit(ch)&&ch!='-') ch=getchar();
    if (ch=='-') flag=-1,ch=getchar();
    while(isdigit(ch)) x=(x<<3)+(x<<1)+(ch-'0'),ch=getchar();
    return x*flag;
}
class FHQTreap{
    private:
        class node{
            public:
            int siz,val,key;
            int lch,rch;
        };
        node t[maxn];
        int tot,root;
        int build(int val){
            t[++tot].siz=1;
            t[tot].val=val;
            t[tot].key=myrand();
            t[tot].lch=t[tot].rch=0;
            return tot;
        }
        void update(int now){
            t[now].siz=t[t[now].lch].siz+t[t[now].rch].siz+1;
        }
        void split(int now,int &a,int &b,int val){
            if (!now){
                a=b=0;
                return;
            }
            if (t[now].val<=val)
                a=now,split(t[now].rch,t[a].rch,b,val);
            else b=now,split(t[now].lch,a,t[b].lch,val);
            update(now);
        }
        void merge(int &now,int a,int b){
            if (!a||!b){
                now=a+b;
                return;
            }
            if (t[a].key<t[b].key)
                now=a,merge(t[now].rch,t[a].rch,b);
            else now=b,merge(t[now].lch,a,t[b].lch);
            update(now);
        }
        int find(int now,int rank){
            while(t[t[now].lch].siz+1!=rank){
                if (t[t[now].lch].siz>=rank)
                    now=t[now].lch;
                else rank-=t[t[now].lch].siz+1,
                    now=t[now].rch;
            }
            return t[now].val;
        }
    public:
        void setup(){
            build(inf);
            t[root=1].siz=0;
        }
        void insert(int val){
            int x=0,y=0,z;
            z=build(val);
            split(root,x,y,val);
            merge(x,x,z);
            merge(root,x,y);
        }
        void delet(int val){
            int x=0,y=0,z=0;
            split(root,x,y,val);
            split(x,x,z,val-1);
            merge(z,t[z].lch,t[z].rch);
            merge(x,x,z);
            merge(root,x,y);
        }
        int rank(int val){
            int x=0,y=0;
            split(root,x,y,val-1);
            int res=t[x].siz+1;
            merge(root,x,y);
            return res;
        }
        int atrank(int rank){
            return find(root,rank);
        }
        int lower(int val){
            int x=0,y=0;
            split(root,x,y,val-1);
            int res=find(x,t[x].siz);
            merge(root,x,y);
            return res;
        }
        int upper(int val){
            int x=0,y=0;
            split(root,x,y,val);
            int res=find(y,1);
            merge(root,x,y);
            return res;
        }
}ft;
void out(int x){
    printf("%d\n",x);
}
int main(){
    int n=read();ft.setup();
    for (int i=1;i<=n;++i){
        int opt=read(),x=read();
        switch(opt){
            case 1:ft.insert(x);break;
            case 2:ft.delet(x);break;
            case 3:out(ft.rank(x));break;
            case 4:out(ft.atrank(x));break;
            case 5:out(ft.lower(x));break;
            case 6:out(ft.upper(x));break;
        }
    }
    return 0;
}