社区所有版块导航
Python
python开源   Django   Python   DjangoApp   pycharm  
DATA
docker   Elasticsearch  
aigc
aigc   chatgpt  
WEB开发
linux   MongoDB   Redis   DATABASE   NGINX   其他Web框架   web工具   zookeeper   tornado   NoSql   Bootstrap   js   peewee   Git   bottle   IE   MQ   Jquery  
机器学习
机器学习算法  
Python88.com
反馈   公告   社区推广  
产品
短视频  
印度
印度  
Py学习  »  Python

机器学习Tensorflow笔记3:Python训练MNIST模型,在Android上实现评估

悦跑圈技术 • 8 年前 • 591 次点击  

通常而言我们会通过Python编写代码训练Tensorflow,但是我们训练的数据需要实际应用起来,本文会介绍如何通过Python训练Tensorflow,训练的结果在Android上应用,当前也可以通过传输数据给服务端去识别,然后返回数据,但是这种方式实时性较差,需要上传识别数据,然后等待返回数据,在某些场景下也是适用,可以查看下面的Java中调用文章。

实战

实战的内容是基于MNIST实验,在Android平台实现识别功能。

本文是基于MNIST实验,如果还没有做过MNIST实验,那么可以先看我之前2篇文章
《机器学习Tensorflow笔记1:Hello World到MNIST实验》
《机器学习Tensorflow笔记2:超详细剖析MNIST实验》

1. Python保存训练模型

在MNIST实验中,我们是训练完成模型后马上就调用测试代码,如果我们要应用起来,就不可能在移动端去训练,我们应该把训练好的模型放在手机里面,或者通过URL下载到手机里面,所以我们需要保存我们的训练的模型。

#!/usr/bin/python
# -*- coding: UTF-8 -*-
import gzip
import sys
import struct
import numpy

from tensorflow.python.framework import graph_util
from tensorflow.python.platform import gfile

train_images_file = "MNIST_data/train-images-idx3-ubyte.gz"
train_labels_file = "MNIST_data/train-labels-idx1-ubyte.gz"
t10k_images_file = "MNIST_data/t10k-images-idx3-ubyte.gz"
t10k_labels_file = "MNIST_data/t10k-labels-idx1-ubyte.gz"


def read32(bytestream):
    # 由于网络数据的编码是大端,所以需要加上>
    dt = numpy.dtype(numpy.int32).newbyteorder('>')
    data = bytestream.read(4)
    return numpy.frombuffer(data, dt)[0]


def read_labels(filename):
    with gzip.open(filename) as bytestream:
        magic = read32(bytestream)
        numberOfLabels = read32(bytestream)
        print(magic)
        print(numberOfLabels)
        labels = numpy.frombuffer(bytestream.read(numberOfLabels), numpy.uint8)
        data = numpy.zeros((numberOfLabels, 10))
        for i in xrange(len(labels)):
            data[i][labels[i]] = 1
        bytestream.close()
    return data


def read_images(filename):
    # 把文件解压成字节流
    with gzip.open(filename) as bytestream:
        magic = read32(bytestream)
        numberOfImages = read32(bytestream)
        rows = read32(bytestream)
        columns = read32(bytestream)
        images = numpy.frombuffer(bytestream.read(numberOfImages * rows * columns), numpy.uint8)
        images.shape = (numberOfImages, rows * columns)
        images = images.astype(numpy.float32)
        images = numpy.multiply(images, 1.0 / 255.0)
        bytestream.close()
        print(magic)
        print(numberOfImages)
        print(rows)
        print(columns)
    return images


# 解析labels的内容,train_labels包含了60000个数字标签,返回60000个数字标签的数组
train_labels = read_labels(train_labels_file)
# print(labels)
train_images = read_images(train_images_file)

test_labels = read_labels(t10k_labels_file)
# print(labels)
test_images = read_images(t10k_images_file)

import tensorflow as tf

x = tf.placeholder("float", [None, 784.],name='input/x_input')
W = tf.Variable(tf.zeros([784., 10.]))
b = tf.Variable(tf.zeros([10.]))
y = tf.nn.softmax(tf.matmul(x, W) + b)
y_ = tf.placeholder("float",name='input/y_input')
cross_entropy = -tf.reduce_sum(y_ * tf.log(y))
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)
for i in range(1200):
    batch_xs = train_images[50 * i:50 * i + 50]
    batch_ys = train_labels[50 * i:50 * i + 50]
    sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})


correct_prediction = tf.equal(tf.argmax(y, 1, output_type='int32', name='output'),
                              tf.argmax(y_, 1, output_type='int32'))

# correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
print sess.run(accuracy, feed_dict={x: test_images, y_: test_labels})

# 保存训练好的模型
# 形参output_node_names用于指定输出的节点名称,output_node_names=['output']对应pre_num=tf.argmax(y,1,name="output"),
output_graph_def = graph_util.convert_variables_to_constants(sess, sess.graph_def, output_node_names=['output'])
with tf.gfile.FastGFile('model/mnist.pb', mode='wb') as f:  # ’wb’中w代表写文件,b代表将数据以二进制方式写入文件。
    f.write(output_graph_def.SerializeToString())
sess.close()

通过简单的修改代码,就可以轻松实现保存训练模型到本地。

测试导出的模型是否可用
#!/usr/bin/python
# -*- coding: UTF-8 -*-
import tensorflow as tf
import numpy as np
from PIL import Image

#模型路径
model_path = 'model/mnist.pb'
#测试图片
testImage = Image.open("data/test_image.png")

with tf.Graph().as_default():
    output_graph_def = tf.GraphDef()
    with open(model_path, "rb") as f:
        output_graph_def.ParseFromString(f.read())
        tf.import_graph_def(output_graph_def, name="")

    with tf.Session() as sess:
        tf.global_variables_initializer().run()
        # x_test = x_test.reshape(1, 28 * 28)
        input_x = sess.graph.get_tensor_by_name("input/x_input:0")
        output = sess.graph.get_tensor_by_name("output:0")

        #对图片进行测试
        testImage=testImage.convert('L')
        testImage = testImage.resize((28, 28))
        test_input=np.array(testImage)
        test_input = test_input.reshape(1, 28 * 28)
        pre_num = sess.run(output, feed_dict={input_x: test_input})#利用训练好的模型预测结果
        print('模型预测结果为:',pre_num)

2. 配置项目

  1. 在app目录对于的build.gradle添加Gradle依赖,由于so文件很大,所以建议只支持arm,引入Tensorflow后,apk仅仅只增加了4.9MB,如果人工智能当做重要的业务,这个成本是值得的,后续我也会编写Tensorflow Lite的文章,体积更小,更加适合移动设备。
android {
      //...
    buildTypes {
       debug {
            minifyEnabled false
            debuggable = false  
            proguardFiles getDefaultProguardFile('proguard-android.txt'), 'proguard-rules.pro'
            ndk {
                abiFilters "armeabi-v7a","x86"
            }
        }
        release {
            minifyEnabled false
            debuggable = false
            proguardFiles getDefaultProguardFile('proguard-android.txt'), 'proguard-rules.pro'
            ndk {
                abiFilters "armeabi-v7a"
            }
        }
    }
}
dependencies {
    implementation 'org.tensorflow:tensorflow-android:1.8.0'
}

  1. 把上面保存好的训练模型放到Android项目中的assets文件夹中,同时把需要测试的图片放到drawable文件夹下。
├── main
│   ├── AndroidManifest.xml
│   ├── assets
│   │   └── mnist.pb
│   └── res
│       ├── drawable
│       │   └── test_image.png
test_image.png image.png
image.png
测试模型
class MainActivity : AppCompatActivity() {

    override fun onCreate(savedInstanceState: Bundle?) {
        super.onCreate(savedInstanceState)
        setContentView(R.layout.activity_main)

        val bitmap = BitmapFactory.decodeResource(resources, R.drawable.test_image)
        val tfi = TensorFlowInferenceInterface(assets, "mnist.pb")
        val inputData = bitmapToFloatArray(bitmap, 28f, 28f)
        tfi.feed("input/x_input", inputData, 1, 784)
        val outputNames = arrayOf("output")
        tfi.run(outputNames)
        // 用于存储模型的输出数据
        val outputs = IntArray(1)
        tfi.fetch(outputNames[0], outputs)

        imageView.setImageBitmap(bitmap)
        textView.text = "结果为:" + outputs[0]
    }

    /**
     * 将bitmap转为(按行优先)一个float数组,并且每个像素点都归一化到0~1之间。
     * @param bitmap 输入被测试的bitmap图片
     * @param rx 将图片缩放到指定的大小(列)->28
     * @param ry 将图片缩放到指定的大小(行)->28
     * @return   返回归一化后的一维float数组 ->28*28
     */
    private fun bitmapToFloatArray(bitmap: Bitmap, rx: Float, ry: Float): FloatArray {
        var height = bitmap.height
        var width = bitmap.width
        // 计算缩放比例
        val scaleWidth = rx / width
        val scaleHeight = ry / height
        val matrix = Matrix()
        matrix.postScale(scaleWidth, scaleHeight)
        val bitmap = Bitmap.createBitmap(bitmap, 0, 0, width, height, matrix, true)
        height = bitmap.height
        width = bitmap.width
        val result = FloatArray(height * width)
        var k = 0
        for (row in 0 until height) {
            for (col in 0 until width) {
                val argb = bitmap.getPixel(col, row)
                val r = Color.red(argb)
                val g = Color.green(argb)
                val b = Color.blue(argb)
                //由于是灰度图,所以r,g,b分量是相等的。
                assert(r == g && g == b)
                result[k++] = r / 255.0f
            }
        }
        return result
    }
}

布局文件




    
<?xml version="1.0" encoding="utf-8"?>
<LinearLayout xmlns:android="http://schemas.android.com/apk/res/android"
    android:layout_width="match_parent"
    android:layout_height="match_parent"
    android:padding="10dp"
    android:orientation="vertical">

    <ImageView
        android:id="@+id/imageView"
        android:layout_width="100dp"
        android:layout_height="100dp"
        android:layout_gravity="center"
        android:scaleType="fitXY" />

    <TextView
        android:id="@+id/textView"
        android:layout_width="match_parent"
        android:layout_height="wrap_content"
        android:layout_marginTop="20dp"
        android:gravity="center"
        android:text="结果为:" />
</LinearLayout>
结果
image.png
源码

github.com/taoweiji/Te…


今天看啥 - 高品质阅读平台
本文地址:http://www.jintiankansha.me/t/4a6YdM7ayz
Python社区是高质量的Python/Django开发社区
本文地址:http://www.python88.com/topic/12480