KD-tree详解 (3)

                                                                

KD-tree详解

         在查询到左儿子的时候,我们发现,现在最小的距离是 r = 10 ,当回溯到父亲节点的时候,我们发现,以目标点(10,1)为圆心,现在的最小距离 r = 10 为半径做圆,与分割平面 y = 8 相交,这时候,如果我们不在父亲节点的右儿子进行一次查找的话,就会漏掉 (10,9) 这个点,实际上,这个点才是距离目标点 (10,1) 最近的点

由于每次查询的时候可能会把左右两边的子树都查询完,所以,查询并不是简单的 log(n) 的,最坏的时候能够达到 sqrt(n)


        好了,到此,K-D tree 就差不多了,写法上与很多值得优化的地方,至于怎么把最邻近查询变换到 K 邻近查询,我们用一个数组记录一个点是否可以用来更新最近距离即可,下面贴上 K-D tree 一个模板


#include <iostream> #include <cstdio> #include <cstring> #include <cmath> #include <algorithm> #include <vector> #include <string> #include <queue> #include <stack> #define INT_INF 0x3fffffff #define LL_INF 0x3fffffffffffffff #define EPS 1e-12 #define MOD 1000000007 #define PI 3.141592653579798 #define N 60000 using namespace std; typedef long long LL; typedef unsigned long long ULL; typedef double DB; struct data { LL pos[10]; int id; } T[N] , op , point; int split[N],now,n,demension; bool use[N]; LL ans,id; DB var[10]; bool cmp(data a,data b) { return a.pos[split[now]]<b.pos[split[now]]; } void build(int L,int R) { if(L>R) return; int mid=(L+R)>>1; //求出 每一维 上面的方差 for(int pos=0;pos<demension;pos++) { DB ave=var[pos]=0.0; for(int i=L;i<=R;i++) ave+=T[i].pos[pos]; ave/=(R-L+1); for(int i=L;i<=R;i++) var[pos]+=(T[i].pos[pos]-ave)*(T[i].pos[pos]-ave); var[pos]/=(R-L+1); } //找到方差最大的那一维,用它来作为当前区间的 split_method split[now=mid]=0; for(int i=1;i<demension;i++) if(var[split[mid]]<var[i]) split[mid]=i; //对区间排排序,找到中间点 nth_element(T+L,T+mid,T+R+1,cmp); build(L,mid-1); build(mid+1,R); } void query(int L,int R) { if(L>R) return; int mid=(L+R)>>1; //求出目标点 op 到现在的根节点的距离 LL dis=0; for(int i=0;i<demension;i++) dis+=(op.pos[i]-T[mid].pos[i])*(op.pos[i]-T[mid].pos[i]); //如果当前区间的根节点能够用来更新最近距离,并且 dis 小于已经求得的 ans if(!use[T[mid].id] && dis<ans) { ans=dis; //更新最近距离 point=T[mid]; //更新取得最近距离下的点 id=T[mid].id; //更新取得最近距离的点的 id } //计算 op 到分裂平面的距离 LL radius=(op.pos[split[mid]]-T[mid].pos[split[mid]])*(op.pos[split[mid]]-T[mid].pos[split[mid]]); //对子区间进行查询 if(op.pos[split[mid]]<T[mid].pos[split[mid]]) { query(L,mid-1); if(radius<=ans) query(mid+1,R); } else { query(mid+1,R); if(radius<=ans) query(L,mid-1); } } int main() { while(scanf("%d%d",&n,&demension)!=EOF) { //读入 n 个点 for(int i=1;i<=n;i++) { for(int j=0;j<demension;j++) scanf("%I64d",&T[i].pos[j]); T[i].id=i; } build(1,n); //建树 int m,q; scanf("%d",&q); // q 个询问 while(q--) { memset(use,0,sizeof(use)); for(int i=0;i<demension;i++) scanf("%I64d",&op.pos[i]); scanf("%d",&m); printf("the closest %d points are:\n",m); while(m--) { ans=(((LL)INT_INF)*INT_INF); query(1,n); for(int i=0;i<demension;i++) { printf("%I64d",point.pos[i]); if(i==demension-1) printf("\n"); else printf(" "); } use[id]=1; } } } return 0; }

内容版权声明:除非注明,否则皆为本站原创文章。

转载注明出处:https://www.heiqu.com/zwgjss.html