kd树搜索与线性搜索对比

在做毕业设计的时候,遇到这样一个需求:

给定一万五千个点,再给定一个目标点,要求离目标点的最近点,说白了就是求“最近邻”问题

传统的方式,就是从第一个点开始算距离,把一万五千个点都算完,再取最小值

但是这样的方式比较慢,所以利用了knn算法中的kd树进行搜索

kd树的原理在李航的《机器学习》书籍中有详细的介绍,包括kd树的构建和kd树的搜索,但是李航的书里面只有kd树搜索最近邻

关于原理性的东西网上有很多资料,在此就不再累赘了

但是我想探究的是,究竟kd树搜索,跟传统线性搜索相比,能够快多少,所以我就写代码验证了

import numpy as np
import time
import random
from sklearn.neighbors import NearestNeighbors
import matplotlib.pyplot as plt

N = 100000
k = 1

method1_time = []
method2_time = []

print("N=", N)
print("k=", k)
for time_i in range(3):
    X = [[np.random.random() * 100 for _ in range(2)] for _ in range(N)]
    target = [[np.random.random() * 100 for _ in range(2)] for _ in range(k)]

    X = np.array(X)
    target = np.array(target)
    # plt.scatter(X[:,0], X[:,1], color='b')
    # plt.scatter(target[:,0], target[:,1], color='r')
    # plt.show()

    # 1、kd-tree搜索
    tm = time.time()
    for i, index in enumerate(target):
        min_distance = (index[0] - X[0][0]) ** 2 + (index[1] - X[0][1]) ** 2
        min_j = 0
        for j in range(N):
            distance = (index[0] - X[j][0]) ** 2 + (index[1] - X[j][1]) ** 2
            if min_distance > distance:
                min_distance = distance
                min_j = j
        print('线性搜索结果:', np.sqrt(min_distance), min_j)
    print('线性搜索耗时: {}s'.format(time.time() - tm))
    method1_time.append(time.time() - tm)

    tm = time.time()
    nbrs = NearestNeighbors(n_neighbors=1, algorithm='ball_tree').fit(X)
    distances, indices = nbrs.kneighbors(target)
    print('kd树搜索结果:', distances, indices)
    print('kd树搜索耗时: {}s'.format(time.time() - tm))
    method2_time.append(time.time() - tm)

print(method1_time, np.average(method1_time))
print(method2_time, np.average(method2_time))

当样本容量是十万的时候,输出如下:

kd树搜索与线性搜索对比

可以换不同的样本容量,看最后的结果,分析结果得出结论

 

最后的结论是,当样本容量比较少的时候,比如只有100个点或者1000个点,其实用线性搜索的效果可能比kd树的效果还要好,因为kd树的建立需要花费时间,而当样本容量变大的时候,kd树的优势就慢慢提现出来了。