pocket算法的python实现
- 构造数据集的一个简便方法
`sklearn.datasets.make_circles(n_samples=100, shuffle=True, noise=None, random_state=None, factor=0.8)`:生成环形数据集
make_moons:生成半环形图,加入一定的噪声之后,可以用于含有噪声的二分类问题
- 对pocket的分析
这个算法比PLA更加保守:在发现一个错误点之后,需要判断W更新后的分类结果是否优于更新前的,如果是,才选择更新,因此收敛速度要慢于PLA。而且当数据线性不可分时,终止条件(全部分类正确)是无法满足的,所以一般通过设置最大迭代次数来终止。
import numpy as np
import matplotlib.pyplot as plt
import sklearn
import sklearn.datasets
# def create_train():
# data=np.array([[3,-3],[4,-1.5],[2,-2],[3.5,-1],[5,0],[1,1],[1,2],[0,1],[2,2],[4,3]],dtype=float)
# label=np.array([1,1,1,1,1,-1,-1,-1,-1,-1])
# return data,label
def create_train(n_samples, noise, shuffle=True):
    """Generate a reproducible two-moons dataset with labels in {-1, +1}.

    Args:
        n_samples: Number of points passed to ``make_moons``.
        noise: Noise level passed to ``make_moons`` (``None`` for no noise).
        shuffle: Whether ``make_moons`` should shuffle the samples.

    Returns:
        A ``(data, label)`` tuple as produced by ``make_moons``, with the
        0/1 class labels remapped to -1/+1.
    """
    # Seed NumPy's global RNG so every call produces the same dataset.
    np.random.seed(0)
    # ``shuffle`` and ``noise`` are keyword-only in current scikit-learn;
    # passing them positionally raises TypeError there.
    data, label = sklearn.datasets.make_moons(
        n_samples, shuffle=shuffle, noise=noise
    )
    # Map the {0, 1} labels to {-1, +1} as the perceptron update expects.
    label = 2 * label - 1
    return data, label
def pocket(data, label, iter_num):
    """Train a linear classifier with the pocket algorithm.

    The perceptron (PLA) update ``w += y_i * x_i`` is applied on every
    misclassified sample, while the weights with the fewest training errors
    seen so far are kept "in the pocket" and returned at the end.

    Args:
        data: Array of shape ``(data_num, feature_num)`` without a bias column.
        label: Sequence of ``data_num`` labels, each -1 or +1.
        iter_num: Maximum number of passes over the training set.

    Returns:
        Weight vector of length ``feature_num + 1``; index 0 is the bias.
    """
    data = np.asarray(data, dtype=float)
    label = np.asarray(label)
    data_num, feature_num = data.shape
    # Prepend a constant-1 column so that w[0] acts as the bias term.
    samples = np.concatenate((np.ones((data_num, 1)), data), axis=1)

    def count_errors(weights):
        # A sample is an error when label * score <= 0 (score 0 counts too).
        return int(np.count_nonzero(label * (samples @ weights) <= 0))

    w = np.zeros(feature_num + 1)
    best_w = w
    best_errors = count_errors(w)

    for _ in range(iter_num):
        updated = False
        for x_i, y_i in zip(samples, label):
            if y_i * (w @ x_i) > 0:
                continue
            # Always take the PLA step; only accepting strictly improving
            # steps makes the search stall in a local minimum.
            w = w + y_i * x_i  # creates a new array, so best_w is not aliased
            updated = True
            errors = count_errors(w)
            if errors < best_errors:
                best_w, best_errors = w, errors
                if best_errors == 0:
                    return best_w
        if not updated:
            # A full pass without mistakes: nothing left to learn.
            break
    return best_w
def predict(data, label, w):
    """Count how many samples the weight vector ``w`` misclassifies.

    Args:
        data: Array of shape ``(n, d)``; each row is dotted with ``w``, so it
            must already contain the constant bias column if ``w`` has one.
        label: Sequence of ``n`` labels, each -1 or +1.
        w: Weight vector of length ``d``.

    Returns:
        The number of rows with ``label * (w @ row) <= 0`` as an ``int``;
        a score of exactly zero counts as an error.
    """
    data = np.asarray(data)
    label = np.asarray(label)
    return int(np.count_nonzero(label * (data @ w) <= 0))
def plot_dots(data, label, w):
    """Scatter-plot both classes and draw the decision boundary of ``w``.

    Args:
        data: 2-D array; columns 0 and 1 are used as the x1 / x2 coordinates
            (no bias column).
        label: Per-sample labels; samples equal to 1 are drawn as red
            squares, all others as blue dots.
        w: Weights ``(bias, w1, w2)`` of the line ``w0 + w1*x1 + w2*x2 = 0``.
    """
    data = np.asarray(data)
    label = np.asarray(label)
    positive = label == 1

    plt.figure()
    plt.scatter(data[positive, 0], data[positive, 1], s=40, c='red', marker='s')
    plt.scatter(data[~positive, 0], data[~positive, 1], s=40, c='blue')
    plt.xlabel('x1')
    plt.ylabel('x2')

    # Draw the separating line.
    if w[2] != 0:
        x = np.arange(-3.0, 8.0, 0.1)
        y = (-w[0] - w[1] * x) / w[2]
        plt.plot(x, y)
    elif w[1] != 0:
        # w2 == 0 means a vertical boundary; solving for x2 would divide by 0.
        plt.axvline(-w[0] / w[1])
    # If w1 and w2 are both zero there is no line to draw.

    plt.xlim(-3, 3)
    plt.ylim(-2, 2)
    plt.pause(10)
def plot_lines(w):
    """Draw the line ``w[0] + w[1]*x1 + w[2]*x2 = 0`` on the current axes.

    Args:
        w: Weights ``(bias, w1, w2)`` of the line to draw.
    """
    if w[2] != 0:
        # The builtin range() rejects float arguments; np.arange is required.
        x = np.arange(-3.0, 8.0, 0.1)
        y = (-w[0] - w[1] * x) / w[2]
        plt.plot(x, y)
    elif w[1] != 0:
        # w2 == 0 means a vertical boundary; solving for x2 would divide by 0.
        plt.axvline(-w[0] / w[1])
def main():
    """Train the pocket classifier on a noisy two-moons set and plot it.

    Returns:
        The weight vector returned by ``pocket``.
    """
    # An earlier definition of main() that called an undefined ``pla`` and
    # ``create_train()`` without arguments was shadowed by this one and has
    # been removed.
    iter_num = 10
    data, label = create_train(50, 0.25)
    w = pocket(data, label, iter_num)
    plot_dots(data, label, w)
    return w


if __name__ == '__main__':
    main()
实验结果发现:算法不太稳定
迭代五次和五十次的结果是一样的