Py学习  »  机器学习算法

机器学习系列-第1篇-感知机识别手写数字(mnist例子分析)

javiepong • 5 年前 • 365 次点击  

机器学习系列-第1篇-感知机识别手写数字(mnist例子分析)

系列目录:

  1. 机器学习系列-第0篇-开发工具与tensorflow环境搭建  
  2. 机器学习系列-第1篇-感知机识别手写数字(mnist例子分析)
  3. 机器学习系列-第2篇-CNN识别手写数字(mnist例子分析)


      第0篇已经搭好了开发环境,本文详细介绍用感知机识别手写数字(mnist例子)的过程, 希望依照本文的步骤,每个人能清楚理解mnist例子,动手实践。继续阅读之前,最好你已经了解以下知识点(不然你迟早会回来的):

  1. 向量和矩阵基础运算规则
  2. tensorflow基础(session,graph,tensor等)
  3. 感知机
  4. 交叉熵
  5. 最小梯度下降算法


一 图片与数据分析

mnist例子会自动下载下列数据


train-*.gz是用来训练的数据, t10k-*gz是用来测试的数据,这些数据都不是原始的图片,而是经过处理变成了二级制的文件,这里重点分析一下这个数据格式。

1 image文件

      手写数字的图片是28*28的灰度图片,图片中每个像素点的值范围是0-255(黑色是0,白色是255), 图片文件是按照这样格式写的:

 魔法值(32位)+图片数量(32位)+图片宽(32位)+图片长(32位)+ 所有图数据

(1) 魔法值: 文件标识,train-images-idx3-ubyte文件的magic值是2051

(2) 所有图数据:单张图数据28*28=784个 uint8, 所以所有图N,就是N*784个uint8

用16进制查看image文件,结果如下图所示:


2 label文件

label文件记录的是与image顺序一一对应的图片实际值,范围是0-9。文件的格式是:

魔法值(32位)+标签数量(32位)+ 所有标签数据

(1)train-labels-idx1-ubyte文件的magic值是2049

(2) 所有标签数据:每个标签是一个 uint8, 所以所有标签 就是N个uint8

用16进制查看label文件,结果如下图所示:


从上可以看到, 图片数据从ea60之后开始,前4张图片分别是 5 0 4 1,我把它们还原成图片:   

还原图片的代码如下:


通过对文件数据的分析与还原,我们能清楚知道数据格式是怎样的,帮组我们在编码过程中处理数据时不会产生困惑,同时在文章后面会用到自己制作手写字体来验证模型的准确性。


二 mnist分析

1.文件import

mnist_reader是我写的,用来读取自己制作的手写字体(在第三部分有给出代码):



2.训练过程


(1) 读取数据

   one_hot参数为True的意思是把图片28*28的二维数组处理为一维数组[784],这样处理之后,所有像素点都是一个特征输入,最终变成只是分析像素点对预测结果的影响,丢失了图片的结构信息。因为这只是练习,所以知道问题所在就好,暂时不讨论该方法的优劣。

(2)模型构建

    weight为什么是一个[784*10]的数组,因为图片数据是一个 784的数组,而每一张图片可能是 0-9这十个数字中的一个,每一张图片预测的结果有10种可能, 对应这张图片是0,1, ...9的的概率值。同理,偏置b也与实际输出的维数相同。

   x变量是在分批训练过程中用来存放一批图片,所以它没有指定行数量,让系统自动推导。

   y_变量是在分批训练过程中,用来存放一批标签数据,它与x一一对应

   y 变量是我们的模型函数f(x)=wx+b的结果 经过softmax函数处理的输出

(3)初始化session

  tensorflow的变量需要先初始化才能用。

(4)模型训练

每次100张图片为一批,循环1000次训练。每一批的数量大小怎么定没有规定,太小会导致训练出来的模型预测结果不理想,太大会需要训练很久。需要根据自己的样本数量来定批大小和循环次数。主要考虑的如何合理设定这两个值的大小, 在有限的样本数量时,尽量避免梯度下降算法 可能找到局部最小值 而非全局最小值的问题。

(5)模型评估

模型评估的规则是取预测值概率最大的对应数字与实际值相比较, 然后统计预测正常的占比,为了方便理解,假设训练9个样本, 下列代码的输出结果:

correct_prediction=tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1)) ->

           [ True True True False False True False False False]

tf.reduce_mean(tf.cast(correct_prediction, "float")) -> 0.44444445

根据测试样本,样本的准确率可以达到 91%左右:

training prediction: 0.9138 

training done

然后再通过不断调整 批量大小和循环次数,发现最高的准确率基本都在91%-92%之间,所以基本可以认定该模型的极限准确率在92%左右。

(6)模型保存

一个模型保存了如下4个文件:


checkpoint: 记录当前目录下有哪些模型

mnist.data-00000-of-00001:模型的所有参数值 (如w和b)

mnist.index:

mnist.meta:模型的图结构信息

三.模型使用

1.制作自己的手写数字图片

根据 第一 部分最后 输出mnist样子的图片,发现了它是黑色底,即大部分像素点都是0,组成数字的部分是白色,像素值都基本在128-255,所以参照它来做成的图片,才能用训练好的模型来识别数字。此前我自己学习过程, 按tensorflow中文社区http://www.tensorfly.cn/tfdoc/tutorials/mnist_download.html里面的例子是白底黑字,制作手写字根本识别不了。

正确的只需按照下面两步就可以制作跟mnist例子一样的图片:



最后保存为png图片就可以了.

以下是我制作的图片,0-9数字各9张,名字规则是 数字+序号,即名字的第一字符标识图片里面写的是什么数字,好处是不用写label文件,在读取图片的时候可以自动解析出label:

         



2.加载图片数据

加载图片数据的重点是理解mnist里面的数据时怎样的格式,这个我不去解释,大家自己调试一下, 看内存数据。 读取图片的代码如下:


3.识别

(1)使用训练好的模型


(2)可以不指定 specify,全部测试, 我这个分开为每个数字单独测试,看看各自的准确率


(3)结果

(4)结果分析

看到以上预测结果,惊不惊喜,意不意外? 跟训练评估时的91%准确率相差甚远。是什么原因导致的呢? 回头去看我上面制作的手写图片, 你会发现跟它例子自带的数据有非常大的区别,我制作的时候,故意把数字写的大小不一,位置不同。

再看 "mnist分析"部分第(1)点"读取数据"里面说到的, 把图片数据转为一个[784]数组,丢失图片原有的结构信息, 单纯分析各像素点的值对图片的影响, 当在28*28大小的区域里面,书写的数字大小和位置有很大差异的时候,肯定是识别不出来。

四.总结

至此,我们完成了整个感知机识别手写数字的训练和预测过程,在这过程中,对于一个新手,训练和测试的代码并不难, 难点是要理解它背后实现的基础,就是文章一开始我列出的知识点。如果你已经完全懂了本文中每一行代码的意思, 那么恭喜你, 机器学习的Hello World已经完成了。



今天看啥 - 高品质阅读平台
本文地址:http://www.jintiankansha.me/t/vsK9KNPJot
Python社区是高质量的Python/Django开发社区
本文地址:http://www.python88.com/topic/11520
 
365 次点击