在查询到左儿子的时候,我们发现,现在最小的距离是 r = 10 ,当回溯到父亲节点的时候,我们发现,以目标点(10,1)为圆心,现在的最小距离 r = 10 为半径做圆,与分割平面 y = 8 相交,这时候,如果我们不在父亲节点的右儿子进行一次查找的话,就会漏掉 (10,9) 这个点,实际上,这个点才是距离目标点 (10,1) 最近的点
由于每次查询的时候可能会把左右两边的子树都查询完,所以,查询并不是简单的 log(n) 的,最坏的时候能够达到 sqrt(n)
好了,到此,K-D tree 就差不多了,写法上与很多值得优化的地方,至于怎么把最邻近查询变换到 K 邻近查询,我们用一个数组记录一个点是否可以用来更新最近距离即可,下面贴上 K-D tree 一个模板
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
#include <algorithm>
#include <vector>
#include <string>
#include <queue>
#include <stack>
#define INT_INF 0x3fffffff
#define LL_INF 0x3fffffffffffffff
#define EPS 1e-12
#define MOD 1000000007
#define PI 3.141592653579798
#define N 60000
using namespace std;
typedef long long LL;
typedef unsigned long long ULL;
typedef double DB;
struct data
{
LL pos[10];
int id;
} T[N] , op , point;
int split[N],now,n,demension;
bool use[N];
LL ans,id;
DB var[10];
bool cmp(data a,data b)
{
return a.pos[split[now]]<b.pos[split[now]];
}
void build(int L,int R)
{
if(L>R) return;
int mid=(L+R)>>1;
//求出 每一维 上面的方差
for(int pos=0;pos<demension;pos++)
{
DB ave=var[pos]=0.0;
for(int i=L;i<=R;i++)
ave+=T[i].pos[pos];
ave/=(R-L+1);
for(int i=L;i<=R;i++)
var[pos]+=(T[i].pos[pos]-ave)*(T[i].pos[pos]-ave);
var[pos]/=(R-L+1);
}
//找到方差最大的那一维,用它来作为当前区间的 split_method
split[now=mid]=0;
for(int i=1;i<demension;i++)
if(var[split[mid]]<var[i]) split[mid]=i;
//对区间排排序,找到中间点
nth_element(T+L,T+mid,T+R+1,cmp);
build(L,mid-1);
build(mid+1,R);
}
void query(int L,int R)
{
if(L>R) return;
int mid=(L+R)>>1;
//求出目标点 op 到现在的根节点的距离
LL dis=0;
for(int i=0;i<demension;i++)
dis+=(op.pos[i]-T[mid].pos[i])*(op.pos[i]-T[mid].pos[i]);
//如果当前区间的根节点能够用来更新最近距离,并且 dis 小于已经求得的 ans
if(!use[T[mid].id] && dis<ans)
{
ans=dis; //更新最近距离
point=T[mid]; //更新取得最近距离下的点
id=T[mid].id; //更新取得最近距离的点的 id
}
//计算 op 到分裂平面的距离
LL radius=(op.pos[split[mid]]-T[mid].pos[split[mid]])*(op.pos[split[mid]]-T[mid].pos[split[mid]]);
//对子区间进行查询
if(op.pos[split[mid]]<T[mid].pos[split[mid]])
{
query(L,mid-1);
if(radius<=ans) query(mid+1,R);
}
else
{
query(mid+1,R);
if(radius<=ans) query(L,mid-1);
}
}
int main()
{
while(scanf("%d%d",&n,&demension)!=EOF)
{
//读入 n 个点
for(int i=1;i<=n;i++)
{
for(int j=0;j<demension;j++)
scanf("%I64d",&T[i].pos[j]);
T[i].id=i;
}
build(1,n); //建树
int m,q; scanf("%d",&q); // q 个询问
while(q--)
{
memset(use,0,sizeof(use));
for(int i=0;i<demension;i++)
scanf("%I64d",&op.pos[i]);
scanf("%d",&m);
printf("the closest %d points are:\n",m);
while(m--)
{
ans=(((LL)INT_INF)*INT_INF);
query(1,n);
for(int i=0;i<demension;i++)
{
printf("%I64d",point.pos[i]);
if(i==demension-1) printf("\n");
else printf(" ");
}
use[id]=1;
}
}
}
return 0;
}