Py学习  »  机器学习算法

使用机器学习 HED 网络优化 SmartCropper 边缘检测

pqpo • 4 年前 • 148 次点击  

使用机器学习 HED 网络优化 SmartCropper 边缘检测

  • 2019-08-02
  • 59
  • 0

SmartCropper 是我写的一个开源库,主要用于卡片及文档的识别与裁剪 。最近主要对 SmartCropper 进行了两次较大升级,一是升级了 OpenCV 框架到官方最新版,解决了饱为诟病的打包问题(ISSUE), 通过升级 OpenCV 自然也支持了 64 位架构(ISSUE), Google 已经向开发者下发了最后通牒:Support 64-bit architectures。二是完成了一个从初版就躺在 todo-list 里面的 feature:优化智能选区算法。

看过 《Android 端基于 OpenCV 的边框识别功能》 的同学应该知道SmartCropper 是通过 OpenCV 的 Canny 算法识别出照片的边缘线条,然后进行后续处理的,但是 Canny 算法并不能很好的提取出我们想要的边框,如果背景稍微复杂一点就会夹杂着很多非识别物体的边缘线条,对后续的处理提出了很大的挑战。

我们理想中的 Canny 算法效果应该是这样的,输入一张图片能精准的识别出我们想要的边缘线条,后续再配合 OpenCV 的线条检测功能可以很容易得识别出目标物体的位置:

很早之前就看过了 FengJian 大神的文章:《手机端运行卷积神经网络的一次实践 — 基于 TensorFlow 和 OpenCV 实现文档检测功能》,了解到 OpenCV 这种传统算法很快会进入到识别瓶颈,机器学习是一条新思路。

使用篇

网络部分使用的是 FengJian 基于 MobileNetV2 改造的 HED 网络,具体原理后面再说,相关代码位于: SmartCropper/edge_detection/

1.使用/验证模型

使用预训练好的模型识别边框:

python evaluate.py 
    --input_img test_image/test.jpg 
    --checkpoint_dir finetuning_model/ 
    --output_img test_image/result.jpg

识别结果:

注意:输入网络的图像会 resize 到 256 * 256,网络输出的图像也是 256 * 256,为了方便观测,我后期处理将图片恢复到了原始尺寸,自己测试的时候得到的是 一张 256 * 256 的图片。

这样的识别效果加上 OpenCV 的线段检测已经基本上可以定位到卡片位置了。针对一些识别不好的图片可以在原模型基础上进行 finetuning 。

2.训练数据准备以及预处理

还是以上面这张图片为例子(实际情况下,这张图片已经算识别良好,不需要 finetuning 了),开始之前,需要准备一张根据原图标注好的图片,目前没有好的标注工具(后续有时间可以做一个),暂时使用 Sketch 制作,由于导出的图片线条不是纯白的,需要使用以下脚本进行二值化处理:

python image_threshold.py 
    --input_img test_image/annotation.jpg
    --output_img test_image/annotation_threshold.jpg

上面是原图和输出图片放大后的对比图,左边是二值化的图片,右边是 Sketch 输出的图片,二值化后的图片只有黑白两种像素,那么这样就得到了一张二值化处理后的标注图:

3.模型训练

输入原始图片,和上方的标注图片开始调优训练:

python finetuning.py 
    --finetuning_dir finetuning_model/ 
    --checkpoint_dir checkpoint/ 
    --image test_image/test.jpg 
    --annotation test_image/annotation_threshold.jpg 
    --iterations 30 
    --lr 0.0004 
    --batch_size 1

脚本运行完之后会将调优后的模型保存到 –checkpoint_dir 指定的目录下 ,并且在 test_image 目录下生成一张结果图片:fine_tuning_output_img.png

左边是使用原始模型识别的输出的图片,右边是调优之后输出的图片。调优之后识别出的边缘线条正好是我们想要的。

需要特别注意的是输入参数 iterations 不能过大(设置在 15 以内),不然很容易发生过拟合,在调优的图片下拟合良好,但是原来拟合好的图片又不能正常拟合了。

当然也可以通过输入 CSV 文件进行批量训练,批量训练标注是一个很大的工作量,FengJian 的文章提出使用合成的方式生成图片进行训练,他是使用 iOS 模拟器进行合成的,我也在尝试只使用 Python 代码合成。由于单张图片的 finetuning 很容易过拟合,所以更推荐批量训练。参考:hed-tutorial-for-document-scanning

4.模型导出与使用

保存的模型为 ckpt 格式不能直接用于移动端,需要转成 tflite 格式,使用如下脚本进行转换:

python freeze_model.py 
    --checkpoint_dir checkpoint/ 
    --output_file hed_lite_model_quantize.tflite

导出过程中设置了量化处理,生成的 tflite 模型文件只有 346 KB。接着将 tflite 文件复制进 assets 目录下供 TensorFlowLite SDK 加载。SmartCropper 可以用如下方式加载自己的模型文件:

SmartCropper.buildImageDetector(this,"hed_lite_model_quantize.tflite")

具体是通过如下方式加载 assets 目录下的 tflite 模型:

MappedByteBuffer tfliteModel = loadModelFile(context, modelFile);
nterpreter.Options tfliteOptions = new Interpreter.Options();
tflite = new Interpreter(tfliteModel, tfliteOptions);

private MappedByteBuffer loadModelFile(Context activity, String modelFile) throws IOException {
    AssetFileDescriptor fileDescriptor = activity.getAssets().openFd(modelFile);
    FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
    FileChannel fileChannel = inputStream.getChannel();
    long startOffset = fileDescriptor.getStartOffset();
    long declaredLength = fileDescriptor.getDeclaredLength();
    return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
}

将输入输出数据包装成 ByteBuffer 之后就可以直接使用模型进行预测了:

public synchronized Bitmap detectImage(Bitmap bitmap) {
    if (bitmap == null) {
        return null;
    }
    imgData.clear();
    outImgData.clear();
    bitmap = Bitmap.createScaledBitmap(bitmap, desiredSize, desiredSize, false);
    convertBitmapToByteBuffer(bitmap);
    tflite.run(imgData, outImgData);
    return convertOutputBufferToBitmap(outImgData);
}

以上就是移动端应用机器学习的整个过程,包括数据预处理,训练,验证及使用模型等。

理解 HED 网络与 MobileNet 网络

HED (Holistically-nested edge detection) 是用于边缘检测的网络,功能与 Canny 算法一致,效果比 Canny 算法好。HED 网络是基于 VGG 网络改进而来的,关于 VGG 网络的介绍可以参考我的上篇文章:《深入理解 VGG 卷积神经网络》 ,以下网络部分代码参考:hed-tutorial-for-document-scanning

下面是 VGG 网络的结构图:

vgg 网络

HED 网络在 VGG 网络的基础上去除了后5层,后面的全连接层与 softmax 层主要用于分类,HED 网络只需要提取图片的特征,保留了前面的卷积层和池化层(注意:去掉最后一层池化层)。下面是 HED 网络的示意图:

hed 网络

分别提取出 VGG 网络的 conv1_2, conv2_2, conv3_3, conv4_3, conv5_3 层,这些输出层的大小分别为 [224, 224, 64],[112, 112, 128],[56, 56, 256],[28, 28, 512],[14, 14, 512],由于需要将这些层的数据和成一张图片,首先需要将深度降维到1,然后再按比例放大 1,2,4,8,16 倍使得每一层的数据大小都为 [224, 224],最终相加就可以得到尺寸为 [224, 224] 的输出图片了。

首先需要去深度,使用输出深度为 1,卷积核为 1*1 的卷积操作,得到深度为 1 的输出:

def _dsn_1x1_conv2d(inputs):
    kernel_size = [1, 1]
    outputs = tf.layers.conv2d(inputs,
                              filters=1,
                              kernel_size=[1, 1], 
                              padding='same', 
                              activation=None, ## no activation
                              use_bias=True, 
                              kernel_initializer=filter_initializer,
                              kernel_regularizer=weights_regularizer)
    return outputs

然后通过反卷积扩大输入层尺寸,反卷积是常用的上采样方法,也叫转置卷积,是一种特殊的正向卷积,先按照一定的比例通过补0来扩大输入图像的尺寸,接着旋转卷积核,再进行正向卷积。

反卷积

将前面的输出通过以下函数反采样到统一的大小,其中 filters 总为 1, upsample_factor 为相应的扩大倍数:

def _dsn_deconv2d_with_upsample_factor(inputs, filters, upsample_factor):
     kernel_size = [2 * upsample_factor, 2 * upsample_factor]
     outputs = tf.layers.conv2d_transpose(inputs,
                                         filters, 
                                         kernel_size, 
                                         strides=(upsample_factor, upsample_factor), 
                                         padding='same', 
                                         activation=None, ## no activation
                                         use_bias=True, ## use bias
                                         kernel_initializer=filter_initializer,
                                         kernel_regularizer=weights_regularizer)
     return outputs

这样每一层的尺寸均为 [224, 224, 1],实际大小还有 batch_size, 假如 batch_size 为 1,那么每一层的输出为 [1, 224, 224, 1],将 5 层输出合并:

dsn_fuse = tf.concat([dsn1, dsn2, dsn3, dsn4, dsn5], axis=3)

axis 设置为 3,那么输出为 [1, 224, 224, 5],再执行一次 1×1 convolution 得到 [1, 224, 224, 1] 的输出。最终这个输出就是边缘检测的结果了,以上就是 HED 网络正向传播的过程。

至于反向传播需要注意的是,由于一个样本中边缘像素远远小于非边缘像素,所以不能使用普通的交叉熵作为损失函数,需要引入 pos_weight:

def class_balanced_sigmoid_cross_entropy(logits, label):
    with tf.name_scope('class_balanced_sigmoid_cross_entropy'):
        count_neg = tf.reduce_sum(1.0 - label) # 样本中0的数量
        count_pos = tf.reduce_sum(label) # 样本中1的数量(远小于count_neg)
        beta = count_neg / (count_neg + count_pos)  ## e.g.  60000 / (60000 + 800) = 0.9868
        pos_weight = beta / (1.0 - beta)  ## 0.9868 / (1.0 - 0.9868) = 0.9868 / 0.0132 = 74.75
        cost = tf.nn.weighted_cross_entropy_with_logits(logits=logits, targets=label, pos_weight=pos_weight)
        cost = tf.reduce_mean(cost * (1 - beta))
        zero = tf.equal(count_pos, 0.0)
        final_cost = tf.where(zero, 0.0, cost) 
    return final_cost

FengJian 的文章中提出使用 MobileNet 网络改造 HED 网络使得更适应于移动端的运行环境。MobileNet 与 VGG 一样也是一种 CNN(卷积神经网络) 网络,可用于图像分类, MobileNet 是 Google 针对手机等嵌入式设备提出的一种轻量级的神经网络。与 VGG 网络不同的是 MobileNet 使用了一种更轻量级的深度可分离卷积(depthwise separable convolution)代替了原来的普通卷积。

深度可分离卷积分为两个部分:Depthwise 卷积和 Pointwise(1*1 Conv) 卷积。基本结构如下:

首先通过 Depthwise 对每一个通道进行分别卷积,然后通过 Pointwise 对各通道进行结合,最终达到类似普通卷积的效果,但是计算量和参数量大大减少。将以上单元替换 HED 网络的卷积层便可得到一个更加适用于移动端的边缘检测网络。

举个例子说明一下什么是深度可分离卷积,对一张 224*224 的彩色 3 通道图片进行卷积操作,卷积核为 3*3,输出四通道矩阵,普通卷积的过程示意如下:

相同效果的深度可分离卷积将此过程分成了两步,如下图所示:

对比两个卷积过程,直观上来看深度可分离卷积的计算量明显比普通卷积少了很多。深度可分离卷积的计算量为: 3*3*3*224*224 + 3*4*224*224,普通卷积的计算量为:3*3*3*4*224*224,为深度可分离卷积的 2.77 倍,如果输出通道数较大,则最终趋近于 9 倍。

之后 Google 又推出了 MobileNet V2 网络,MobileNet V2 的主要贡献是在MobileNet V1 的基础上提出了线性瓶颈层(Linear Bottlenecks)和反转残差块(Inverted Residuals), block 示意图如下:

在 V1 的基础上主要有如下改变:

  1. 在每个单元的深度可分离卷积之前添加 Pointwise 卷积来扩张通道数。
  2. 第二次 Pointwise 未采用非线性激活,保留线性特征。
  3. 特定的 Block 里加入了残差连接。将输入与输出直接进行相加。使得网络在较深的时候依旧可以进行训练。

残差网络 ResNet 中的 residual block 是先压缩,再特征提取,最后扩张,而MobileNet V2 是先扩张,再特征提取,最后压缩,这样一来可以获取更多的特征,因此也叫做 inverted residuals.

>> 转载请注明来源:使用机器学习 HED 网络优化 SmartCropper 边缘检测

●非常感谢您的阅读,欢迎订阅微信公众号(右边扫一扫)以表达对我的认可与支持,我会在第一时间同步文章到公众号上。当然也可点击下方打赏按钮为我打赏。

免费分享,随意打赏

Python社区是高质量的Python/Django开发社区
本文地址:http://www.python88.com/topic/37148
 
148 次点击