感知机python实现
有用请点赞,没用请差评。
欢迎分享本文,转载请保留出处。
感知机原理参考博客:https://www.cnblogs.com/huangyc/p/9706575.html
算法引用李航博士《统计学习方法》p29.
# -*- coding:utf-8 -*-
# 感知机
import numpy as np
import matplotlib.pyplot as plt
class Perceptron(object):
def __init__(self,eta=1,iter=50):
# eta:学习率;itea:最大迭代次数
self.eta=eta
self.iter=iter
# 根据现有权值和偏置预测分类
def predict(self,xi,w,b):
target=np.dot(w,xi)+b
return target
# 迭代修正权值和偏置
def interation(self,vector,label):
"""
:param vector: 训练数据向量
:param label: 训练数据的原始划分类别
:return:
"""
data_shape=vector.shape
print("data_shape",data_shape)
# 初始化权值为零向量
self.weight=np.zeros(data_shape[1])
# 初始偏置
self.bias=0
# 记录每一轮迭代还没有误分类数据
errors_point=0
# True表示还需要继续迭代
check_inter=True
n=0
print("迭代次数:%d ,初始权值:%s,初始偏置:%s" % (n,str(self.weight), str(self.bias)))
while check_inter and n<self.iter:
n+=1
errors_point = 0
for xi,yi in zip(vector,label):
xi_prediction=self.predict(xi,self.weight,self.bias)
# 感知机中的判断数据有没有被误分类的公式
if yi*xi_prediction<=0:
# 修正权值和偏置
self.weight=self.weight+self.eta*xi*yi
self.bias=self.bias+self.eta*yi
errors_point+=1
print("迭代次数:%d ,误分类点:%s,权值:%s,偏置:%s"%(n,str(xi),str(self.weight),str(self.bias)))
break
if errors_point:
# 如果还存在误分类点,则继续迭代
check_inter=True
else:
# 如果不存误分类点,则停止迭代
check_inter=False
if __name__=="__main__":
x1 = np.array([[3, 3], [4, 3], [1, 1]])
x2 = np.array([1, 1, -1])
perception=Perceptron()
perception.interation(x1,x2)