社区所有版块导航
Python
python开源   Django   Python   DjangoApp   pycharm  
DATA
docker   Elasticsearch  
aigc
aigc   chatgpt  
WEB开发
linux   MongoDB   Redis   DATABASE   NGINX   其他Web框架   web工具   zookeeper   tornado   NoSql   Bootstrap   js   peewee   Git   bottle   IE   MQ   Jquery  
机器学习
机器学习算法  
Python88.com
反馈   公告   社区推广  
产品
短视频  
印度
印度  
Py学习  »  机器学习算法

深度学习中的类别激活热图可视化

小白学视觉 • 2 年前 • 318 次点击  

点击上方小白学视觉”,选择加"星标"或“置顶

重磅干货,第一时间送达

导读

使用Keras实现图像分类中的激活热图的可视化,帮助更有针对性的改进模型。

类别激活图(CAM)是一种用于计算机视觉分类任务的强大技术。它允许研究人员检查被分类的图像,并了解图像的哪些部分/像素对模型的最终输出有更大的贡献。

基本上,假设我们构建一个CNN,目标是将人的照片分类为“男人”和“女人”,然后我们给它提供一个新照片,它返回标签“男人”。有了CAM工具,我们就能看到图片的哪一部分最能激活“Man”类。如果我们想提高模型的准确性,必须了解需要修改哪些层,或者我们是否想用不同的方式预处理训练集图像,这将非常有用。

在本文中,我将向你展示这个过程背后的思想。为了达到这个目的,我会使用一个在ImageNet上预训练好的CNN, Resnet50。

我在这个实验中要用到的图像是,这只金毛猎犬:

首先,让我们在这张图上尝试一下我们预训练模型,让它返回三个最有可能的类别:

from keras.applications.resnet50 import ResNet50
from keras.preprocessing import image
from keras.applications.resnet50 import preprocess_input, decode_predictions
import numpy as npmodel = ResNet50(weights='imagenet')img_path = 'golden.jpg'
img = image.load_img(img_path, target_size=(224224))
x = image.img_to_array(img)
x = np.expand_dims(x, axis=0)
x = preprocess_input(x)preds = model.predict(x)
# decode the results into a list of tuples (class, description, probability)

print('Predicted:', decode_predictions(preds, top=3)[0])

如你所见,第一个结果恰好返回了我们正在寻找的类别:Golden retriver。

现在我们的目标是识别出我们的照片中最能激活黄金标签的部分。为此,我们将使用一种称为“梯度加权类别激活映射(Grad-CAM)”的技术(官方论文:https://arxiv.org/abs/1610.02391)。

这个想法是这样的:想象我们有一个训练好的CNN,我们给它提供一个新的图像。它将为该图像返回一个类。然后,如果我们取最后一个卷积层的输出特征图,并根据输出类别对每个通道的梯度对每个通道加权,我们就得到了一个热图,它表明了输入图像中哪些部分对该类别激活程度最大。

让我们看看使用Keras的实现。首先,让我们检查一下我们预先训练过的ResNet50的结构,以确定我们想要检查哪个层。由于网络结构很长,我将在这里只显示最后的block:

from keras.utils import plot_model
plot_model(model)

让我们使用最后一个激活层activation_49来提取我们的feature map。

golden = model.output[:, np.argmax(preds[0])]
last_conv_layer = model.get_layer('activation_49')

from keras import backend as K

grads = K.gradients(golden, last_conv_layer.output)[0]
pooled_grads = K.mean(grads, axis=(012))
iterate = K.function([model.input], [pooled_grads, last_conv_layer.output[0]])
pooled_grads_value, conv_layer_output_value = iterate([x])
for i in range(pooled_grads.shape[0]):
    conv_layer_output_value[:, :, i] *= pooled_grads_value[i]
heatmap = np.mean(conv_layer_output_value, axis=-1)

import matplotlib.pyplot as plt

heatmap = np.maximum(heatmap, 0)
heatmap /= np.max(heatmap)
plt.matshow(heatmap)

这个热图上看不出什么东西出来。因此,我们将该热图与输入图像合并如下:

import cv2
img = cv2.imread(img_path)
heatmap = cv2.resize(heatmap, (img.shape[1], img.shape[0]))
heatmap = np.uint8(255 * heatmap)
heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
merged= heatmap * 0.4 + imgplt.imshow(merged)

如你所见,图像的某些部分(如鼻子部分)特别的指示出了输入图像的类别。

好消息!

小白学视觉知识星球

开始面向外开放啦👇👇👇



下载1:OpenCV-Contrib扩展模块中文版教程
在「小白学视觉」公众号后台回复:扩展模块中文教程即可下载全网第一份OpenCV扩展模块教程中文版,涵盖扩展模块安装、SFM算法、立体视觉、目标跟踪、生物视觉、超分辨率处理等二十多章内容。

下载2:Python视觉实战项目52讲
小白学视觉公众号后台回复:Python视觉实战项目即可下载包括图像分割、口罩检测、车道线检测、车辆计数、添加眼线、车牌识别、字符识别、情绪检测、文本内容提取、面部识别等31个视觉实战项目,助力快速学校计算机视觉。

下载3:OpenCV实战项目20讲
小白学视觉公众号后台回复:OpenCV实战项目20讲即可下载含有20个基于OpenCV实现20个实战项目,实现OpenCV学习进阶。

交流群


欢迎加入公众号读者群一起和同行交流,目前有SLAM、三维视觉、传感器自动驾驶、计算摄影、检测、分割、识别、医学影像、GAN算法竞赛等微信群(以后会逐渐细分),请扫描下面微信号加群,备注:”昵称+学校/公司+研究方向“,例如:”张三 + 上海交大 + 视觉SLAM“。请按照格式备注,否则不予通过。添加成功后会根据研究方向邀请进入相关微信群。请勿在群内发送广告,否则会请出群,谢谢理解~


Python社区是高质量的Python/Django开发社区
本文地址:http://www.python88.com/topic/148619
 
318 次点击