社区所有版块导航
Python
python开源   Django   Python   DjangoApp   pycharm  
DATA
docker   Elasticsearch  
aigc
aigc   chatgpt  
WEB开发
linux   MongoDB   Redis   DATABASE   NGINX   其他Web框架   web工具   zookeeper   tornado   NoSql   Bootstrap   js   peewee   Git   bottle   IE   MQ   Jquery  
机器学习
机器学习算法  
Python88.com
反馈   公告   社区推广  
产品
短视频  
印度
印度  
Py学习  »  Python

使用Python+OpenCV进行数据增广方法综述(附代码演练)

小白学视觉 • 2 年前 • 578 次点击  

点击上方小白学视觉”,选择加"星标"或“置顶

重磅干货,第一时间送达

数据扩充是一种增加数据集多样性的技术,无需收集更多的真实数据,但仍然有助于提高模型的准确性和防止模型过度拟合。在这篇文章中,你将学习使用Python和OpenCV实现最流行和最有效的对象检测任务的数据扩充过程。
介绍的数据扩充方法包括:
  1. 随机剪裁
  2. Cutout
  3. ColorJitter
  4. 添加噪声
  5. 过滤
首先,让我们导入几个库并准备一些必要的子例程。
import os
import cv2
import numpy as np
import random

def file_lines_to_list(path):
    '''
    ### Convert Lines in TXT File to List ###
    path: path to file
    '''

    with open(path) as f:
        content = f.readlines()
    content = [(x.strip()).split() for x in content]
    return content

def get_file_name(path):
    '''
    ### Get Filename of Filepath ###
    path: path to file
    '''

    basename = os.path.basename(path)
    onlyname = os.path.splitext(basename)[0]
    return onlyname

def write_anno_to_txt (boxes, filepath):
    '''
    ### Write Annotation to TXT File ###
    boxes: format [[obj x1 y1 x2 y2],...]
    filepath: path/to/file.txt
    '''

    txt_file = open(filepath, "w")
    for box in boxes:
        print(box[0], int(box[1]), int(box[2]), int(box[3]), int(box[4]), file=txt_file)
    txt_file.close()
下图在本文中用作示例图像。

随机剪裁

随机剪裁:随机选择一个区域并将其裁剪出来,形成一个新的数据样本,被裁剪的区域应与原始图像具有相同的宽高比,以保持对象的形状。
在上图中,左边的图像是带有真实边界框的原始图像(红色部分),右边的图像是通过裁剪橙色框中的区域创建的新样本。
在新样本的标注中,去除所有与左侧图像中橙色框不重叠的对象,并将橙色框边界上的对象的坐标进行细化,使之与新样本相匹配。对原始图像进行随机裁剪的输出是新的裁剪后的图像及其注释。
def randomcrop(img, gt_boxes, scale=0.5):
    '''
    ### Random Crop ###
    img: image
    gt_boxes: format [[obj x1 y1 x2 y2],...]
    scale: percentage of cropped area
    '''

    
    # Crop image
    height, width = int(img.shape[0]*scale), int(img.shape[1]*scale)
    x = random.randint(0, img.shape[1] - int(width))
    y = random.randint(0, img.shape[0] - int(height))
    cropped = img[y:y+height, x:x+width]
    resized = cv2.resize(cropped, (img.shape[1], img.shape[0]))
    
    # Modify annotation
    new_boxes=[]
    for box in gt_boxes:
        obj_name = box[0]
        x1 = int(box[1])
        y1 = int(box[2])
        x2 = int(box[3])
        y2 = int(box[4])
        x1, x2 = x1-x, x2-x
        y1, y2 = y1-y, y2-y
        x1, y1, x2, y2 = x1/scale, y1/scale, x2/scale, y2/scale
        if (x11
and y10]) and (x2>0 and y2>0):
            if x1<0: x1=0
            if y1<0: y1=0
            if x2>img.shape[1]: x2=img.shape[1]
            if y2>img.shape[0]: y2=img.shape[0]
            new_boxes.append([obj_name, x1, y1, x2, y2])
    return resized, new_boxes

Cutout

Cutout是2017年由Terrance DeVries和Graham W. Taylor在他们的论文中介绍的,是一种简单的正则化技术,在训练过程中随机掩盖输入的正方形区域,可以用来提高卷积神经网络的鲁棒性和整体性能。这种方法不仅非常容易实现,而且表明它可以与现有形式的数据扩充和其他正则化器一起使用,进一步提高模型的性能。
  • 论文地址:https://arxiv.org/abs/1708.04552
与本文一样,我们使用了cutout来提高图像识别(分类)的精度,因此,如果我们将相同的方案部署到目标检测数据集中,可能会导致丢失目标(特别是小目标)的问题。在下图中,删除了剪切区域(黑色区域)内的大量小对象,这不符合数据增强的精神。
为了使这种方式适合对象检测,我们可以做一个简单的修改,而不是仅使用一个遮罩并将其放置在图像中的随机位置,而是随机选择一半的对象,并将裁剪应用于每个目标区域,效果更佳。增强后的图像如下图所示。
Cutout的输出是一个新生成的图像,我们不删除对象或改变图像大小,那么生成的图像的注释就是原始注释。
def cutout(img, gt_boxes, amount=0.5):
    '''
    ### Cutout ###
    img: image
    gt_boxes: format [[obj x1 y1 x2 y2],...]
    amount: num of masks / num of objects 
    '''

    out = img.copy()
    ran_select = random.sample(gt_boxes, round(amount*len(gt_boxes)))

    for box in ran_select:
        x1 = int(box[1])
        y1 = int(box[2])
        x2 = int(box[3])
        y2 = int(box[4])
        mask_w = int((x2 - x1)*0.5)
        mask_h = int((y2 - y1)*0.5)
        mask_x1 = random.randint(x1, x2 - mask_w)
        mask_y1 = random.randint(y1, y2 - mask_h)
        mask_x2 = mask_x1 + mask_w
        mask_y2 = mask_y1 + mask_h
        cv2.rectangle(out, (mask_x1, mask_y1), (mask_x2, mask_y2), (000), thickness=-1)
    return out

ColorJitter

ColorJitter是另一种简单的图像数据扩充类型,我们随机改变图像的亮度、对比度和饱和度。我相信这个“家伙”很容易被大多数读者理解。
def colorjitter(img, cj_type="b"):
    '''
    ### Different Color Jitter ###
    img: image
    cj_type: {b: brightness, s: saturation, c: constast}
    '''

    if cj_type == "b":
        # value = random.randint(-50, 50)
        value = np.random.choice(np.array([-50-40-30304050]))
        hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
        h, s, v = cv2.split(hsv)
        if value >= 0:
            lim = 255 - value
            v[v > lim] = 255
            v[v <= lim] += value
        else:
            lim = np.absolute(value)
            v[v 0

            v[v >= lim] -= np.absolute(value)

        final_hsv = cv2.merge((h, s, v))
        img = cv2.cvtColor(final_hsv, cv2.COLOR_HSV2BGR)
        return img
    
    elif cj_type == "s":
        # value = random.randint(-50, 50)
        value = np.random.choice(np.array([-50-40-30304050]))
        hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
        h, s, v = cv2.split(hsv)
        if value >= 0:
            lim = 255 - value
            s[s > lim] = 255
            s[s <= lim] += value
        else:
            lim = np.absolute(value)
            s[s 0
            s[s >= lim] -= np.absolute(value)

        final_hsv = cv2.merge((h, s, v))
        img = cv2.cvtColor(final_hsv, cv2.COLOR_HSV2BGR)
        return img
    
    elif cj_type == "c":
        brightness = 10
        contrast = random.randint(40100)
        dummy = np.int16(img)
        dummy = dummy * (contrast/127+1) - contrast + brightness
        dummy = np.clip(dummy, 0255)
        img = np.uint8(dummy)
        return img

添加噪声

通常,噪声被认为是图像中不可预料的因素,然而,有几种类型的噪声(如高斯噪声、椒盐噪声)可以用于数据扩充,在深度学习中,添加噪声是一种非常简单而有益的数据扩充方法。在下面的例子中,为了增强数据,将高斯噪声和椒盐噪声添加到原始图像中。
对于那些无法识别高斯噪声和椒盐噪声区别的人,高斯噪声的取值范围取决于配置,从0到255,因此,在RGB图像中,高斯噪声像素可以是任何颜色。相反,椒盐噪声像素只能有两个值:0或255,分别为黑色(椒)或白色(盐)。
def noisy(img, noise_type="gauss"):
    '''
    ### Adding Noise ###
    img: image
    cj_type: {gauss: gaussian, sp: salt & pepper}
    '''

    if noise_type == "gauss":
        image=img.copy() 
        mean=0
        st=0.7
        gauss = np.random.normal(mean,st,image.shape)
        gauss = gauss.astype('uint8')
        image = cv2.add(image,gauss)
        return image
    
    elif noise_type == "sp":
        image=img.copy() 
        prob = 0.05
        if len(image.shape) == 2:
            black = 0
            white = 255            
        else:
            colorspace = image.shape[2]
            if colorspace == 3:  # RGB
                black = np.array([000], dtype='uint8')
                white = np.array([255255255], dtype='uint8')
            else:  # RGBA
                black = np.array([000255], dtype='uint8')
                white = np.array([255255255255], dtype='uint8')
        probs = np.random.random(image.shape[:2])
        image[probs 2
)] = black
        image[probs > 1 - (prob / 2)] = white
        return image

过滤

本文介绍的最后一个数据扩充过程是过滤。与添加噪声类似,过滤也很简单,易于实现。在实现中使用的三种滤波类型包括模糊(均值)、高斯和中值。
def filters(img, f_type = "blur"):
    '''
    ### Filtering ###
    img: image
    f_type: {blur: blur, gaussian: gaussian, median: median}
    '''

    if f_type == "blur":
        image=img.copy()
        fsize = 9
        return cv2.blur(image,(fsize,fsize))
    
    elif f_type == "gaussian":
        image=img.copy()
        fsize = 9
        return cv2.GaussianBlur(image, (fsize, fsize), 0)
    
    elif f_type == "median":
        image=img.copy()
        fsize = 9
        return cv2.medianBlur(image, fsize)

总结

在这篇文章中,主要向大家介绍了一个关于对象检测任务中数据扩充实现的教程。你们可以在这里找到完整实现。
  • https://github.com/tranleanh/data-augmentation
下载1:OpenCV-Contrib扩展模块中文版教程
在「小白学视觉」公众号后台回复:扩展模块中文教程即可下载全网第一份OpenCV扩展模块教程中文版,涵盖扩展模块安装、SFM算法、立体视觉、目标跟踪、生物视觉、超分辨率处理等二十多章内容。

下载2:Python视觉实战项目52讲
小白学视觉公众号后台回复:Python视觉实战项目即可下载包括图像分割、口罩检测、车道线检测、车辆计数、添加眼线、车牌识别、字符识别、情绪检测、文本内容提取、面部识别等31个视觉实战项目,助力快速学校计算机视觉。

下载3:OpenCV实战项目20讲
小白学视觉公众号后台回复:OpenCV实战项目20讲即可下载含有20个基于OpenCV实现20个实战项目,实现OpenCV学习进阶。

交流群


欢迎加入公众号读者群一起和同行交流,目前有SLAM、三维视觉、传感器自动驾驶、计算摄影、检测、分割、识别、医学影像、GAN算法竞赛等微信群(以后会逐渐细分),请扫描下面微信号加群,备注:”昵称+学校/公司+研究方向“,例如:”张三 + 上海交大 + 视觉SLAM“。请按照格式备注,否则不予通过。添加成功后会根据研究方向邀请进入相关微信群。请勿在群内发送广告,否则会请出群,谢谢理解~


Python社区是高质量的Python/Django开发社区
本文地址:http://www.python88.com/topic/156179
 
578 次点击