社区所有版块导航
Python
python开源   Django   Python   DjangoApp   pycharm  
DATA
docker   Elasticsearch  
aigc
aigc   chatgpt  
WEB开发
linux   MongoDB   Redis   DATABASE   NGINX   其他Web框架   web工具   zookeeper   tornado   NoSql   Bootstrap   js   peewee   Git   bottle   IE   MQ   Jquery  
机器学习
机器学习算法  
Python88.com
反馈   公告   社区推广  
产品
短视频  
印度
印度  
Py学习  »  机器学习算法

基于深度学习的车牌检测识别(Pytorch)(ResNet +Transformer)

新机器视觉 • 2 年前 • 326 次点击  

点击下方卡片,关注“新机器视觉”公众号

重磅干货,第一时间送达

车牌识别


概述


基于深度学习的车牌识别,其中,车辆检测网络直接使用YOLO侦测。而后,才是使用网络侦测车牌与识别车牌号。


车牌的侦测网络,采用的是resnet18,网络输出检测边框的仿射变换矩阵,可检测任意形状的四边形。


车牌号序列模型,采用Resnet18+transformer模型,直接输出车牌号序列。

数据集上,车牌检测使用CCPD 2019数据集,在训练检测模型的时候,会使用程序生成虚假的车牌,覆盖于数据集图片上,来加强检测的能力。


车牌号的序列识别,直接使用程序生成的车牌图片训练,并佐以适当的图像增强手段。模型的训练直接采用端到端的训练方式,输入图片,直接输出车牌号序列,损失采用CTCLoss。


一、网络模型

1、车牌的侦测网络模型:

网络代码定义如下:


该网络,相当于直接对图片划分cell,即在16X16的格子中,侦测车牌,输出的为该车牌边框的反射变换矩阵。


2、车牌号的序列识别网络:

车牌号序列识别的主干网络:采用的是ResNet18+transformer,其中有ResNet18完成对图片的编码工作,再由transformer解码为对应的字符。

网络代码定义如下:


其中的Block类的代码如下:


位置编码的代码如下:


Block类使用的自注意力代码如下:


二、数据加载

1、车牌号的数据加载

同过程序生成一组车牌号:


再通过数据增强,

主要包括:



三、训练

分别训练即可

其中,侦测网络的损失计算,如下:


侦测网络输出的反射变换矩阵,但对车牌位置的标签给的是四个角点的位置,所以需要响应转换后,做损失。其中,该cell是否有目标,使用CrossEntropyLoss,而对车牌位置损失,采用的则是L1Loss。


四、推理

1、侦测网络的推理

按照一般侦测网络,推理即可。只是,多了一步将反射变换矩阵转换为边框位置的计算。


另外,在YOLO侦测到得测量图片传入该级进行车牌检测的时候,会做一步操作。代码见下,将车辆检测框的图片扣出,然后resize到长宽均为16的整数倍。


2、序列检测网络的推理

对网络输出的序列,进行去重操作即可,如间隔标识符为“*”时:


完整代码

https://github.com/HibikiJie/LicensePlate


原文地址

https://blog.csdn.net/weixin_48866452/article/details/120319588

本文仅做学术分享,如有侵权,请联系删文。

—THE END—

Python社区是高质量的Python/Django开发社区
本文地址:http://www.python88.com/topic/152283
 
326 次点击