【学习笔记】KD-Tree

前言

早有耳闻,只是一直懒得学,搞完军训没事做,就顺便来学一学。


正题

$\texttt{参考OI Wiki}$

$\texttt{KD-Tree}$ 是一种可以高效处理 $\texttt{k}$ 维空间信息的数据结构,虽然可以被卡(最低 $O(\log n)$,最高 $O(n^{1-\frac{1}{k}})$),但在大多数情况下不失为一种优秀的部分分甚至正解算法。

$\texttt{KD-Tree}$ 具有二叉搜索树的形态,每次会根据某一维度将高维平面进行划分,例如在二维平面上,每一次可能以纵坐标或横坐标对所有的点进行划分,从而实现对整个二维平面的划分。

维护信息

$\texttt{KD-Tree}$ 中有很多需要维护在某一个节点及其子树下,每一维度的 $\texttt{max,min}$,相当于维护了一个包含该节点及其子树下所有节点的高维平面,部分需要维护删除 $\texttt{tag}$,同时对于维护该节点的 $\texttt{ls,rs}$,我们直接保存其在数组中的位置,因此可以发现空间复杂度是 $O(n)$ 的级别。当然还需要维护题目所要求的信息 废话

建树

单独考虑某一个点上的划分方法时,假设我们已经选择完了根据哪个维度进行划分,那么为了保持整个树尽量优秀,我们希望左右子树尽量平均,因此应该选择在该维度下的 中位数 所代表的点,将区间内的点进行划分。

可以使用 nth_element 函数进行操作,具体来说他可以将第 $k$ 项放在第 $k$ 的位置,并且大于他的在右边,小于的在左边。复杂度 $O(n)$。

需要注意与线段树将点存在叶子节点上不同,$\texttt{KD-Tree}$ 每次划分之后,都将中位数这个点保存在当前节点上,然后左右儿子递归。

如何选择该节点根据哪个维度进行划分?可以选择 方差最大 的维度 并不懂原理

则构建复杂度为 $O(n\log n)$。

插入

当 $\texttt{KD-Tree}$ 需要支持动态加入一个点,则需要进行插入操作。具体来说,我们在每个节点上记录该节点是根据哪一维度划分的,然后将加入的新点从根节点开始,每次根据所在节点的维度考虑应去左儿子或右儿子。

删除

需要支持动态删除时,可以使用惰性删除的方法,也就是在该点上打一个 $\texttt{tag}$,删除这个点的贡献即可。

重构

不难发现,在上面插入与删除的操作中,$\texttt{KD-Tree}$ 的形态可以发生很大的变化,因此很有可能导致不平衡。为了尽可能保证效率,当一个节点的子树下方过度不平衡时,我们需要进行重构。

考虑类似替罪羊树的方法,定义一个重构常数 $\alpha$,当一个节点某一儿子的大小占比超过 $\alpha$ ,则对其进行重构。当然,不论是插入还是删除都有可能造成不平衡的发生,因此两种操作中都需要考虑是否重构。通过重构,我们可以保持 $\texttt{KD-Tree}$ 树高保持在 $O(\log n)$ 级别。

剪枝

由于 $\texttt{KD-Tree}$ 可以被卡,剪枝是十分重要的。在某一些题目中,可以在每个节点上计算该子树的估价函数,需要注意的是该估价函数 不能比该子树的最优答案更差,否则可能剪去最优答案。则若某节点的估价函数比当前最优解要优,则进入节点,否则不用进入了。


习题

[CQOI2016]K远点对

已知平面内 $n$ 个点,求欧氏距离第 $k$ 远的点对的距离。

$n\le 10^5,k\le 100$

$\texttt{Solution}:$

首先考虑 $k=1$ 时的做法,其实也就是平面最远点对。将所有点加入 $\texttt{KD-Tree}$ 中,每次对某一个点查询与其最近的点的距离。我们可以设计一个估价函数。

考虑在某一节点上,当前最优解为 $\texttt{ans}$,则可以将估价函数设计为目标点到该节点所维护的长方形中,四个顶点中最远的顶点的距离。不难发现这样的估价函数一定不劣于该子树内的最优解,因此可以这样设计。

那拓展到第 $k$ 远呢?可以考虑维护一个小根堆,首先插入 $k$ 个 $-\infty$,然后每当到一个节点时,考虑该节点所代表的点与目标点之间的距离与堆顶的大小关系,并及时进行弹出和插入新元素。不难发现小根堆中元素个数恒为 $k$ 个,并且最后的堆顶便是第 $k$ 远的点对。

需要注意的是,该点对是无序的,因此 $k$ 需要 $\times 2$。由于点对距离不可能小于 $0$,因此初始化小根堆时,我插入的是 $0$。

$\texttt{Code:}$

//Code By CXY07
#include<bits/stdc++.h>
using namespace std;

//#define FILE
#define int long long
//#define LL long long
#define pii pair<int,int>
#define mp make_pair
#define pb push_back
#define fst first
#define scd second
#define inv(x) ksm(x,mod - 2)
#define lowbit(x) (x & (-x))
#define p2(x) ((x) * (x))

const int MAXN = 1e5 + 10;
const int INF = 2e9;
const double PI = acos(-1);
//const int mod = 1e9 + 7;
//const int mod = 998244353;
//const int G = 3;

struct NODE {
    int x,y;
} s[MAXN];

template<typename T> inline void read(T &a) {
    a = 0; char c = getchar(); int f = 1;
    while(c < '0' || c > '9') {if(c == '-') f = -1; c = getchar();}
    while(c >= '0' && c <= '9') {a = a * 10 + (c ^ 48); c = getchar();}
    a *= f;
}

int n,k,root;
int ls[MAXN],rs[MAXN];
int xmax[MAXN],xmin[MAXN],ymax[MAXN],ymin[MAXN];
priority_queue<int,vector<int>,greater<int> > Q;

bool cmpx(NODE &a,NODE &b) {return a.x < b.x;}
bool cmpy(NODE &a,NODE &b) {return a.y < b.y;}

void pushup(int x) {
    xmax[x] = xmin[x] = s[x].x;
    ymax[x] = ymin[x] = s[x].y;//首先设为只有自己
    if(ls[x]) {
        xmax[x] = max(xmax[x],xmax[ls[x]]);
        xmin[x] = min(xmin[x],xmin[ls[x]]);
        ymax[x] = max(ymax[x],ymax[ls[x]]);
        ymin[x] = min(ymin[x],ymin[ls[x]]);
    } 
    if(rs[x]) {
        xmax[x] = max(xmax[x],xmax[rs[x]]);
        xmin[x] = min(xmin[x],xmin[rs[x]]);
        ymax[x] = max(ymax[x],ymax[rs[x]]);
        ymin[x] = min(ymin[x],ymin[rs[x]]);    
    }
}

int build(int l,int r) {
    if(l > r) return 0;
    int mid = (l + r) >> 1;
    double x1 = 0,xv = 0,y1 = 0,yv = 0;
    for(int i = l;i <= r; ++i) {
        x1 += s[i].x;
        y1 += s[i].y;
    }
    x1 /= 1. * (r - l + 1);
    y1 /= 1. * (r - l + 1);
    for(int i = l;i <= r; ++i) {
        xv += p2(s[i].x - x1);
        yv += p2(s[i].y - y1);
    }//找方差
    if(xv > yv) nth_element(s + l,s + mid,s + r + 1,cmpx);
    else nth_element(s + l,s + mid,s + r + 1,cmpy);//进行划分
    ls[mid] = build(l,mid - 1);//注意此处右边界为mid-1
    rs[mid] = build(mid + 1,r);
    pushup(mid);
    return mid;
}

int f(int a,int b) {
    return max(p2(s[a].x - xmax[b]),p2(s[a].x - xmin[b])) + 
           max(p2(s[a].y - ymax[b]),p2(s[a].y - ymin[b]));
}//估价函数

void query(int l,int r,int id) {
    if(l > r) return;
    int mid = (l + r) >> 1;
    int val = p2(s[mid].x - s[id].x) + p2(s[mid].y - s[id].y);
    if(val > Q.top()) {
        Q.pop();
        Q.push(val);
    }
    int d1 = f(id,ls[mid]),d2 = f(id,rs[mid]);//利用估价函数剪枝
    if(d1 > Q.top() && d2 > Q.top()) {
        if(d1 > d2) {
            query(l,mid - 1,id);
            if(d2 > Q.top()) query(mid + 1,r,id);
        } else {
            query(mid + 1,r,id);
            if(d1 > Q.top()) query(l,mid - 1,id);
        }
    } else {
        if(d1 > Q.top()) query(l,mid - 1,id);
        if(d2 > Q.top()) query(mid + 1,r,id);
    }
}

signed main () {
#ifdef FILE
    freopen(".in","r",stdin);
    freopen(".out","w",stdout);
#endif
    read(n); read(k);
    for(int i = 1;i <= n; ++i)
        read(s[i].x),read(s[i].y);
    k <<= 1;
    for(int i = 1;i <= k; ++i) Q.push(0);
    build(1,n);
    for(int i = 1;i <= n; ++i) query(1,n,i);//对于每一个点,更新小根堆内的值
    printf("%lld\n",Q.top());
    return 0;
}

咕咕咕

仅有 1 条评论
  1. Chaos1018 Chaos1018

    %%%
    cxy txdy

    丑小鸭挺小的呀

添加新评论