Tensorflow-数据集简单分类(使用MNIST数据集)

一、数据集分类原理

      数据集中每一张图片包含28*28个像素,我们把这一个数组展开成一个向量,长度是28*28=784。第一个维度数字用 来索引图片,第二个维度数字用来索引每张图片中的像素点。图片里的某个像素的强度值介于0-1 之间。

Tensorflow-数据集简单分类(使用MNIST数据集)

MNIST数据集的标签是介于0-9的数字,我们要把标签转化为“one-hot vectors”。一个onehot向量除了某一位数字是1以外,其余维度数字都是0,比如标签0将表示为([1,0,0,0,0,0,0,0,0,0]) ,标签3将表示为([0,0,0,1,0,0,0,0,0,0]) 。 

因此, mnist.train.labels 是一个 [60000, 10] 的数字矩阵。

Softmax函数

我们知道MNIST的结果是0-9,我们的模型可能推测出一张图片是数字9的概率是80%,是数字8 的概率是10%,然后其他数字的概率更小,总体概率加起来等于1。这是一个使用softmax回归模 型的经典案例。softmax模型可以用来给不同的对象分配概率。

Tensorflow-数据集简单分类(使用MNIST数据集)

 比如输出结果为[1,5,3] 

Tensorflow-数据集简单分类(使用MNIST数据集)

二、实战演练

数据集简单分类流程(自己画的,字有点丑)

Tensorflow-数据集简单分类(使用MNIST数据集)

代码如下:

#载入数据集
mnist = input_data.read_data_sets("MNIST_data",one_hot=True)#加载数据集
#每个批次大小
batch_size = 100#训练模型一次性放一批次去训练
#计算一共有多少个批次
n_batch = mnist.train.num_examples // batch_size#训练数据数量/批次大小=一共有多少批次
#定义两个placeholder
x = tf.placeholder(tf.float32,[None,784])#行784,列跟批次有关
y = tf.placeholder(tf.float32,[None,10])#数字0-9总共十个标签
#创建一个简单的神经网络输入层+输出层10个神经元
W = tf.Variable(tf.zeros([784,10]))#权值
b = tf.Variable(tf.zeros([10]))    #偏置值
prediction = tf.nn.softmax(tf.matmul(x,W)+b)#预测值 使用softmax函数 x*W+b,转换成概率值,存在预测变量里面
#二次代价函数
loss = tf.reduce_mean(tf.square(y-prediction))
#使用梯度下降
train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss)#0.2的学习率,
#初始化变量
init = tf.global_variables_initializer()
#测试模型准确率==方法:结果存放在一个布尔型列表中
#求预测标签最大值在哪个位置,得到标签预测值跟真实值进行对比判断是否一样
correct_prediction = tf.equal(tf.arg_max(y,1),tf.arg_max(prediction,1))#argmax返回一维张量中最大值所在的位置
#求准确率
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))#对比的结果讲布尔型转换成float型,最后求平均值
with tf.Session() as sess:
    sess.run(init)
    for epoch in range(21):#迭代21个周期,把所有图片训练21次
        for batch in range(n_batch):#n_batch一共定义的批次,把每一批次100张图片训练一次
            batch_xs,batch_ys = mnist.train.next_batch(batch_size)#获得一百张图片,标签保存在这里面
            sess.run(train_step,feed_dict={x:batch_xs,y:batch_ys})
           
        acc = sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels})
        print("Iter" + str(epoch) + ",Testing Accuracy " + str(acc))

运行截图:(最后训练准确率达到%91.44)

Tensorflow-数据集简单分类(使用MNIST数据集)