cs231n-线性分类

线性分类

线性分类是一个非常简单的学习算法,同时也有助于帮我们建立起整个神经网络和卷积网络。

你可以把神经网络想成玩乐高,搭建神经网络就是把不同的组件组合到一起,而线性分类算法就是最基本的组件之一

线性分类器

线性分类器是参数模型中最简单的例子,参数模型中,我们将总结所有对训练数据的认识并应用到参数W上面,因此预测阶段我们就可以丢开训练数据,直接用训练好的W来预测。参数模型实际上有两个不同的组成部分,如下图,x代表了输入数据,w表示权重,f(x,w)表示包含了输入数据x和参数w的函数,and this will spit out 10 numbers describing what are the scores corresponding to each of those 10 categories in CIFAR-10(有十个分类的数据集)
cs231n-线性分类
在机器学习中,最关键的是函数F的结构(可以用不同的方法组合x和w),最简单的就是F=wx,这就是线性分类器,前提是我们的图像是32x32x3(=3072)的输入值,我们最后想得到每个类别的分数即10x1,而我们通常会加一个bias term,常数向量,不与训练数据交互,而只会给我们一些数据独立的偏好值,所以当你的数据集不平衡时,比如狗的数量远多于猫的数量
那么与猫对应的偏差元素的值就会比其他高。
cs231n-线性分类

一个例子

如下图我们可以清晰的得到线性分类器是如何工作的,首先将输入图片伸展成列向量,权重矩阵W为3x4,即有三个类别。所以你可以把线性分类看作一种模板匹配,权重矩阵的每一行对应于图像的某个模板,and now the enter product or dot product between the row of the matrix and the column,computing this dot product give a similarity between this template for the classes and the pixels of our image,then bias gives you this data independence scaling and offset to each of the classes

cs231n-线性分类
实际上我们可以取权重矩阵的行向量,并且将他们还原成图像,实际上是将这些模板可视化成图像,如下图所示车辆这个类别实际想要找类似有玻璃的等特征的类别,由此我们可以看出线性分类器的问题即每个类别只能学习一个模板
所以如果这个类别出现了变体,那么线性分类器尝试求取所有不同的变体的平均值并且只使用一个单独的模板。而例如CNN网络可以取得更高的精确率是因为没有每个类别只学习一个单独的模板的限制。

cs231n-线性分类

线性分类器的另一个观点,images as points and high dimensional space.每一张图像都是类似于高维空间中一个点,而线性分类器尝试画一个线性分类面来划分一个类别。如下图,蓝色的线尝试分出飞机的类别和其他的类别。
cs231n-线性分类
线性分类器失败的例子,即无法画出直线进行分类
cs231n-线性分类