pocket算法的python实现

  • 构造数据集的一个简便方法

sklearn.datasets.make_circles(n_samples=100, shuffle=True, noise=None, random_state=None, factor=0.8)  ,生成环形

make_moons:生成半环形图,加入一定的噪声之后,可以用于含有噪声的二分类问题

  • 对pocket的分析

这个算法比PLA更加保守,在发现一个错误点之后,需要判断W更新后的分类结果是否优于更新前的,如果是,才选择更新。这个算法的收敛速度要慢于PLA。而且终止条件(全分类正确)是无法实现的,所以一般就是设置迭代次数。

import numpy as np
import matplotlib.pyplot as plt
import sklearn
import sklearn.datasets

# def create_train():
#     data=np.array([[3,-3],[4,-1.5],[2,-2],[3.5,-1],[5,0],[1,1],[1,2],[0,1],[2,2],[4,3]],dtype=float)
#     label=np.array([1,1,1,1,1,-1,-1,-1,-1,-1])
#     return data,label

def create_train(n_samples,noise,shuffle=True):
    np.random.seed(0) #保证每次生成的数据都一样

    data,label=sklearn.datasets.make_moons(n_samples,shuffle,noise)
    label=2*label-1
    return data,label

def pocket(data,label,iter_num):
    data_num,feature_num=data.shape
    w=np.zeros(feature_num+1)
    a=np.ones([data_num,1])
    data=np.concatenate((a,data),1)#为训练数据增加常数项
    for j in range(iter_num):
        for i in range(data_num):
            if(1):
                y=np.sign([email protected][i])
                if(y!=label[i]):
                    w_=w+label[i]*data[i]
                    #false_num+=1
                    if(predict(data,label,w)>predict(data,label,w_)):
                        w=w_
    return w


def predict(data,label,w):
    n=data.shape[0]
    error=0
    for i in range(n):
        if(label[i]*[email protected][i]<=0):
            error+=1
    return error

def plot_dots(data,label,w):
    data_num,feature_num=data.shape
    xcord1=[]
    ycord1=[]
    xcord2=[]
    ycord2=[]
    for i in range(data_num):
        if (label[i]==1):
            xcord1.append(data[i,0])
            ycord1.append(data[i,1])
        else: 
            xcord2.append(data[i,0])
            ycord2.append(data[i,1])
    plt.figure()
    plt.scatter(xcord1,ycord1,s=40,c='red',marker='s')
    plt.scatter(xcord2,ycord2,s=40,c='blue')
    plt.xlabel('x1')
    plt.ylabel('x2')
    #绘制分类线
    x = np.arange(-3.0, 8.0, 0.1)
    y = (-w[0]-w[1] * x)/w[2] 
    plt.plot(x,y)
    plt.xlim(-3, 3)
    plt.ylim(-2,2)
    plt.pause(10)
    
def plot_lines(w):
    x = range(-3.0, 8.0, 0.1)
    y = (-w[0]-w[1] * x)/w[2] 
    plt.plot(x,y)
    
def main():
    data,label=create_train()
    #plot_dots(data,label)
    w=pla(data,label)
    return w
def main():
    iter_num=10
    data,label=create_train(50,0.25)
    #data,label=create_train()
    w=pocket(data,label,iter_num)
    plot_dots(data,label,w)
    return w

if __name__=='__main__':
    main()

实验结果发现:算法不太稳定

pocket算法的python实现迭代五次和五十次的结果是一样的

pocket算法的python实现