社区所有版块导航
Python
python开源   Django   Python   DjangoApp   pycharm  
DATA
docker   Elasticsearch  
aigc
aigc   chatgpt  
WEB开发
linux   MongoDB   Redis   DATABASE   NGINX   其他Web框架   web工具   zookeeper   tornado   NoSql   Bootstrap   js   peewee   Git   bottle   IE   MQ   Jquery  
机器学习
机器学习算法  
Python88.com
反馈   公告   社区推广  
产品
短视频  
印度
印度  
Py学习  »  机器学习算法

机器学习系列-第1篇-感知机识别手写数字(mnist例子分析)

javiepong • 7 年前 • 374 次点击  

机器学习系列-第1篇-感知机识别手写数字(mnist例子分析)

系列目录:

  1. 机器学习系列-第0篇-开发工具与tensorflow环境搭建  
  2. 机器学习系列-第1篇-感知机识别手写数字(mnist例子分析)
  3. 机器学习系列-第2篇-CNN识别手写数字(mnist例子分析)


      第0篇已经搭好了开发环境,本文详细介绍用感知机识别手写数字(mnist例子)的过程, 希望依照本文的步骤,每个人能清楚理解mnist例子,动手实践。继续阅读之前,最好你已经了解以下知识点(不然你迟早会回来的):

  1. 向量和矩阵基础运算规则
  2. tensorflow基础(session,graph,tensor等)
  3. 感知机
  4. 交叉熵
  5. 最小梯度下降算法


一 图片与数据分析

mnist例子会自动下载下列数据


train-*.gz是用来训练的数据, t10k-*gz是用来测试的数据,这些数据都不是原始的图片,而是经过处理变成了二级制的文件,这里重点分析一下这个数据格式。

1 image文件

      手写数字的图片是28*28的灰度图片,图片中每个像素点的值范围是0-255(黑色是0,白色是255), 图片文件是按照这样格式写的:

 魔法值(32位)+图片数量(32位)+图片宽(32位)+图片长(32位)+ 所有图数据

(1) 魔法值: 文件标识,train-images-idx3-ubyte文件的magic值是2051

(2) 所有图数据:单张图数据28*28=784个 uint8, 所以所有图N,就是N*784个uint8

用16进制查看image文件,结果如下图所示:


2 label文件

label文件记录的是与image顺序一一对应的图片实际值,范围是0-9。文件的格式是:

魔法值(32位)+标签数量(32位)+ 所有标签数据

(1)train-labels-idx1-ubyte文件的magic值是2049

(2) 所有标签数据:每个标签是一个 uint8, 所以所有标签 就是N个uint8

用16进制查看label文件,结果如下图所示:


从上可以看到, 图片数据从ea60之后开始,前4张图片分别是 5 0 4 1,我把它们还原成图片:   

还原图片的代码如下:


通过对文件数据的分析与还原,我们能清楚知道数据格式是怎样的,帮组我们在编码过程中处理数据时不会产生困惑,同时在文章后面会用到自己制作手写字体来验证模型的准确性。


二 mnist分析

1.文件import

mnist_reader是我写的,用来读取自己制作的手写字体(在第三部分有给出代码):



2.训练过程


(1) 读取数据

   one_hot参数为True的意思是把图片28*28的二维数组处理为一维数组[784],这样处理之后,所有像素点都是一个特征输入,最终变成只是分析像素点对预测结果的影响,丢失了图片的结构信息。因为这只是练习,所以知道问题所在就好,暂时不讨论该方法的优劣。

(2)模型构建

    weight为什么是一个[784*10]的数组,因为图片数据是一个 784的数组,而每一张图片可能是 0-9这十个数字中的一个,每一张图片预测的结果有10种可能, 对应这张图片是0,1, ...9的的概率值。同理,偏置b也与实际输出的维数相同。

   x变量是在分批训练过程中用来存放一批图片,所以它没有指定行数量,让系统自动推导。

   y_变量是在分批训练过程中,用来存放一批标签数据,它与x一一对应

   y 变量是我们的模型函数f(x)=wx+b的结果 经过softmax函数处理的输出

(3)初始化session

  tensorflow的变量需要先初始化才能用。

(4)模型训练

每次100张图片为一批,循环1000次训练。每一批的数量大小怎么定没有规定,太小会导致训练出来的模型预测结果不理想,太大会需要训练很久。需要根据自己的样本数量来定批大小和循环次数。主要考虑的如何合理设定这两个值的大小, 在有限的样本数量时,尽量避免梯度下降算法 可能找到局部最小值 而非全局最小值的问题。

(5)模型评估

模型评估的规则是取预测值概率最大的对应数字与实际值相比较, 然后统计预测正常的占比,为了方便理解,假设训练9个样本, 下列代码的输出结果:

correct_prediction=tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1)) ->

           [ True True True False False True False False False]

tf.reduce_mean(tf.cast(correct_prediction, "float")) -> 0.44444445

根据测试样本,样本的准确率可以达到 91%左右:

training prediction: 0.9138 

training done

然后再通过不断调整 批量大小和循环次数,发现最高的准确率基本都在91%-92%之间,所以基本可以认定该模型的极限准确率在92%左右。

(6)模型保存

一个模型保存了如下4个文件:


checkpoint: 记录当前目录下有哪些模型

mnist.data-00000-of-00001:模型的所有参数值 (如w和b)

mnist.index:

mnist.meta:模型的图结构信息

三.模型使用

1.制作自己的手写数字图片

根据 第一 部分最后 输出mnist样子的图片,发现了它是黑色底,即大部分像素点都是0,组成数字的部分是白色,像素值都基本在128-255,所以参照它来做成的图片,才能用训练好的模型来识别数字。此前我自己学习过程, 按tensorflow中文社区http://www.tensorfly.cn/tfdoc/tutorials/mnist_download.html里面的例子是白底黑字,制作手写字根本识别不了。

正确的只需按照下面两步就可以制作跟mnist例子一样的图片:



最后保存为png图片就可以了.

以下是我制作的图片,0-9数字各9张,名字规则是 数字+序号,即名字的第一字符标识图片里面写的是什么数字,好处是不用写label文件,在读取图片的时候可以自动解析出label:

         



2.加载图片数据

加载图片数据的重点是理解mnist里面的数据时怎样的格式,这个我不去解释,大家自己调试一下, 看内存数据。 读取图片的代码如下:


3.识别

(1)使用训练好的模型


(2)可以不指定 specify,全部测试, 我这个分开为每个数字单独测试,看看各自的准确率


(3)结果

(4)结果分析

看到以上预测结果,惊不惊喜,意不意外? 跟训练评估时的91%准确率相差甚远。是什么原因导致的呢? 回头去看我上面制作的手写图片, 你会发现跟它例子自带的数据有非常大的区别,我制作的时候,故意把数字写的大小不一,位置不同。

再看 "mnist分析"部分第(1)点"读取数据"里面说到的, 把图片数据转为一个[784]数组,丢失图片原有的结构信息, 单纯分析各像素点的值对图片的影响, 当在28*28大小的区域里面,书写的数字大小和位置有很大差异的时候,肯定是识别不出来。

四.总结

至此,我们完成了整个感知机识别手写数字的训练和预测过程,在这过程中,对于一个新手,训练和测试的代码并不难, 难点是要理解它背后实现的基础,就是文章一开始我列出的知识点。如果你已经完全懂了本文中每一行代码的意思, 那么恭喜你, 机器学习的Hello World已经完成了。



今天看啥 - 高品质阅读平台
本文地址:http://www.jintiankansha.me/t/vsK9KNPJot
Python社区是高质量的Python/Django开发社区
本文地址:http://www.python88.com/topic/11520
 
374 次点击