mean shift聚类算法分析

最近看一个车道线识别的算法LaneNet,其中用到了mean shift进行聚类,然后研究了一下这个聚类算法,主要是从代码中了解的,简单记录一下自己的理解,防止以后忘记。meanshift code

使用mean shift聚类我们不用预先知道数据需要聚集为几类,算法会自动找出几个cluster。

随机数据

在开始使用mean shift算法之前先随机生成几蔟数据,方便后面验证聚类效果。

from sklearn.datasets import make_blobs
data, label = make_blobs(n_samples=500, centers=5, cluster_std=1.2, random_state=5)

这样就生成500个数据,有5个类别,使用不同颜色显示出来,可以看到有两组数据很接近,后面可以看到算法的聚类效果。

mean shift聚类算法分析

mean shift聚类

1.首先找出可能是中心点的一些坐标,做法就是把所有的数据通过np.round规整为几十类,然后把这几十类中属于每个类的点的个数大于3的保留下来,这样筛选出来大概28组可能的中心点。其实还可以用其他的方法选择中心点,或者把每个数据都当做中心点也可以。

  def get_seeds(self, data):
    if self.bin_seeding:
      binsize = self.band_width
    else:
      binsize = 1
    seed_list = []
    seeds_fre = defaultdict(int)
    for sample in data:
      seed = tuple(np.round(sample / binsize))
      seeds_fre[seed] += 1
    for seed, fre in seeds_fre.items():
      if fre >= self.min_fre:
        seed_list.append(np.array(seed))
    if not seed_list:
      raise ValueError('the bin size and min_fre are not proper')
    return seed_list

2.对这些中心点一个一个进行聚类操作。

for seed in seed_list:
  • 其他所有的数据中,找出所有与这个中心点的距离小于某个阀值的点的个数记为tmp_center_score,并且将所有的点以这个中心点为原点进行向量求和,从而得到新的中心点坐标,就是mean shift。
  • 然后用更新后的坐标与更新前的坐标比较,如果他们之间的距离小于一个阀值,就表示已经达到了中心点而不用进一步移动了。

上面两个步骤重复进行,直到不用移动中心点为止

# 对每个中心坐标重复进行
      while True:
        next_center = self.shift_center(current_center, data, tmp_center_score)
        delta_dis = np.linalg.norm(next_center - current_center, 2)
        if delta_dis < self.epsilon:
          break
        current_center = next_center

# 偏移的计算方法
  def shift_center(self, current_center, data, tmp_center_score):
    denominator = 0
    numerator = np.zeros_like(current_center)
    for ind, sample in enumerate(data):
      dis2 = self.euclidean_dis2(current_center, sample)
      if dis2 <= self.radius2:
        tmp_center_score += 1
      d = self.gaussian_kel(dis2)
      denominator += d
      numerator += d * sample
    return numerator / denominator

通过高斯核函数来计算中心点的偏移,高斯核函数的公式如下,其中h就是band_width:

mean shift聚类算法分析

  def gaussian_kel(self, dis2):
    return 1.0 / self.band_width * (2 * math.pi) ** (-1.0 / 2) * math.exp(-dis2 / (2 * self.band_width ** 2))

3.当中心点偏移好了后,跟已经做好了偏移的中心点进行比较,如果现在这个中心点的距离与之前已经偏移好的某个中心点的距离小于一个阀值band_width,然后判断这两个中心点谁的tmp_center_score更大,如果新的中心点的center score更大,就用新的中心点的信息替换旧的中心点信息。

      for i in range(len(self.centers)):
        print(i)
        if np.linalg.norm(current_center - self.centers[i], 2) < self.band_width:
          if tmp_center_score > self.center_score[i]:
            self.centers[i] = current_center
            self.center_score[i] = tmp_center_score
          break
      else:
        self.centers.append(current_center)
        self.center_score.append(tmp_center_score)

4.通过以上步骤就找到了所有中心点,然后对中心点进行聚类,每个点与某个中心点距离最小就属于这一类。

  def classify(self, data):
    center_arr = np.array(self.centers)
    for i in range(self.N):
      delta = center_arr - data[i]
      dis2 = np.sum(delta * delta, axis=1)
      self.labels[i] = np.argmin(dis2)
    return

最后运行的结果如下:

mean shift聚类算法分析

从聚类结果可以看到效果还不错,重叠的那一类是无法分辨的,这个很正常,因为我们自己也无法分辨重叠的那两类的区别。

感觉上面的计算过程应该还可以优化,因为里面多次多数据进行了遍历,会导致效率不高。不过这个代码主要是为了理解mean shift聚类的过程,项目过程中我们更多会使用sklearn中的mean shift算法。