码迷,mamicode.com
首页 > 其他好文 > 详细

【最简单的平衡树】Treap

时间:2021-05-24 02:25:32      阅读:0      评论:0      收藏:0      [点我收藏+]

标签:ref   时间   数据结构   简单的   添加元素   参考   previous   code   就会   

Treap

树是很有用的一种结构,加上不同的约束规则之后,就形成了各种特性鲜明的结构。
最基本的二叉树,加以约束之后,可以形成 BST、AVL、Heap、Splay、R-B Tree …… 等,适用于各种场景。
对于平衡树,种类有很多,有的严格平衡,每次某个子树上任意两个子树的高度差超过1就会进行调整;也有的弱平衡,两子树高度差不超过一倍就不会调整。平衡树很重要,常用于 map、set 等数据结构和数据库等基础设施。但是平衡树大多并不好写,因此一般使用标准库提供的套件来完成工作。这就导致了有些时候,我们想要定制一些操作的时候,难以在封装好的数据结构上进行操作。

本文介绍一种简单好写的平衡树,并给出模板代码,可用于实际解决查元素排名、查排名元素、查前驱、查后继,以及区间操作等。

Treap= Tree + heap

Treap 使用 BST 来提供二分检索特性,使用 Heap 来管理旋转操作,维护二叉树平衡。相比其他平衡树,多用了一个字段来存储堆的权重。堆权随机生成,由随机性来保证二叉树“大概率是平衡的”。

Treap 有多种实现,可以用指针,也可以用数组;可以使用旋转来保证平衡,也可以用分裂合并来做。OI Wiki 解释说无旋 Treap 具有支持序列化的优势,我们在写题的时候用不到这一点,可以直接上数组,减少指针操作。

参考代码

这里使用了 OI Wiki 提供的模板代码,并修正了删除不存在元素时 size-- 的错误,另外还加入了 find / count / delall 操作。
注意一点,查询操作只需要传值,插入删除旋转操作传递的是引用,以便旋转操作修改根节点索引。
为了清晰展示思路,函数用的都是递归操作。代码如下:

#include <cstdio>
#include <algorithm>

#define maxn 100005
#define INF (1 << 30)

struct treap {
    // cnt 是元素重复次数,size 是子树结点数目(计入重复元素),rnd 是堆权
    // l, r 自动初始化为0,作为空值;0 号结点是空结点
    // cnt 默认0个元素,size 默认子树元素为0
    int l[maxn], r[maxn], val[maxn], cnt[maxn], rnd[maxn], size[maxn];
    int sz; // array size, used for insert
    int rt; // tree root index
    int ans;

    void lrotate(int& k)
    {
        int t = r[k];
        r[k] = l[t];
        l[t] = k;
        size[t] = size[k];
        size[k] = size[l[k]] + size[r[k]] + cnt[k];
        k = t;
    }
    void rrotate(int& k)
    {
        int t = l[k];
        l[k] = r[t];
        r[t] = k;
        size[t] = size[k];
        size[k] = size[l[k]] + size[r[k]] + cnt[k];
        k = t;
    }
    void insert(int& k, int x)
    {
        if (!k) { // append to the end
            sz++;
            k = sz;
            val[k] = x;
            cnt[k] = 1;
            size[k] = 1;
            rnd[k] = rand();
            return;
        }
        size[k]++;
        if (val[k] == x) {
            cnt[k]++;
        } else if (val[k] < x) {
            insert(r[k], x);
            if (rnd[r[k]] < rnd[k])
                lrotate(k);
        } else {
            insert(l[k], x);
            if (rnd[l[k]] < rnd[k])
                rrotate(k);
        }
    }

    bool del(int& k, int x)
    {
        if (!k)
            return false;
        if (val[k] == x) {
            if (cnt[k] > 1) {
                cnt[k]--;
                size[k]--;
                return true;
            }
            if (l[k] == 0 || r[k] == 0) { // 元素已调整到链条结点或叶节点
                k = l[k] + r[k];
                return true;
            } else if (rnd[l[k]] < rnd[r[k]]) {
                rrotate(k);
                return del(k, x);
            } else {
                lrotate(k);
                return del(k, x);
            }
        } else if (val[k] < x) {
            bool succ = del(r[k], x);
            if (succ)
                size[k]--;
            return succ;
        } else {
            bool succ = del(l[k], x);
            if (succ)
                size[k]--;
            return succ;
        } // 先把结点转到能删除的位置再删除
    }

    int delall(int& k, int x)
    {
        if (!k)
            return 0;
        if (val[k] == x) {
            if (l[k] == 0 || r[k] == 0) {  // 元素已调整到链条结点或叶节点
                int diff = cnt[k];
                k = l[k] + r[k];
                return diff;
            } else if (rnd[l[k]] < rnd[r[k]]) {
                rrotate(k);
                return delall(k, x);
            } else {
                lrotate(k);
                return delall(k, x);
            }
        } else if (val[k] < x) {
            int diff = delall(r[k], x);
            size[k] -= diff;
            return diff;
        } else {
            int diff = delall(l[k], x);
            size[k] -= diff;
            return diff;
        } // 先把结点转到能删除的位置再删除
    }

    int find(int k, int x) {
        if (!k) return 0;
        if (val[k] == x) {
            return k;
        } else if (x < val[k]) {
            return find(l[k], x);
        } else {
            return find(r[k], x);
        }
    }
    int count(int k, int x) {
        k = find(k, x);
        if (!k) return 0;
        return cnt[k];
    }

    // 查元素排名:x是第几小的数
    int queryrank(int k, int x)
    {
        if (!k)
            return 0;
        if (val[k] == x)
            return size[l[k]] + 1;
        else if (x > val[k]) {
            return size[l[k]] + cnt[k] + queryrank(r[k], x);
        } else
            return queryrank(l[k], x);
    }
    // 查排名元素:第x小数
    int querynum(int k, int x)
    {
        if (!k)
            return 0; // 返回空
        if (x <= size[l[k]])
            return querynum(l[k], x);
        else if (x > size[l[k]] + cnt[k])
            return querynum(r[k], x - size[l[k]] - cnt[k]);
        else
            return val[k];
    }
    // 查前驱:刚好比x小的元素
    void querypre(int k, int x)
    {
        if (!k)
            return;
        if (val[k] < x)
            ans = k, querypre(r[k], x);
        else
            querypre(l[k], x);
    }
    // 查后继:刚好比x大的元素
    void querysub(int k, int x)
    {
        if (!k)
            return;
        if (val[k] > x)
            ans = k, querysub(l[k], x);
        else
            querysub(r[k], x);
    }
} T;

int main()
{
    srand(123);
    int n;
    scanf("%d", &n);
    int opt, x;
    for (int i = 1; i <= n; i++) {
        scanf("%d%d", &opt, &x);
        switch (opt) {
        case 1:
            T.insert(T.rt, x);
            break;
        case 2:
            printf("del %d is %d\n", x, T.del(T.rt, x));
            break;
        case 3:
            printf("delall %d count %d\n", x, T.delall(T.rt, x));
            break;
        case 4:
            printf("rank of %d is %d\n", x, T.queryrank(T.rt, x));
            break;
        case 5:
            printf("value of rank %d is %d\n", x, T.querynum(T.rt, x));
            break;
        case 6:
            T.ans = 0;
            T.querypre(T.rt, x);
            printf("previous value of %d is %d\n", x, T.val[T.ans]);
            break;
        case 7:
            T.ans = 0;
            T.querysub(T.rt, x);
            printf("successor value of %d is %d\n", x, T.val[T.ans]);
            break;
        default:
            printf("invalid opt %d\n", opt);
        }
    }
    return 0;
}

注意,这份代码也并不完美,存在一个问题:
添加元素的时候在数组末尾添加,删除元素并不会回收所占数组位置,形成一个个“空洞”,造成空间浪费。
每次删除的时候,都主动用最后一个元素来填补空洞是不可接受的,这会让 delete 操作的时间复杂度上升到 O(N);可以接收的解决办法是记录空洞个数,当“”空洞率”达到一定阈值之后启动整理,一次性压缩所有空洞。
当然,写题的时候是不需要考虑这么多的。

参考文献

【最简单的平衡树】Treap

标签:ref   时间   数据结构   简单的   添加元素   参考   previous   code   就会   

原文地址:https://www.cnblogs.com/zhcpku/p/14703745.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!