社区所有版块导航
Python
python开源   Django   Python   DjangoApp   pycharm  
DATA
docker   Elasticsearch  
aigc
aigc   chatgpt  
WEB开发
linux   MongoDB   Redis   DATABASE   NGINX   其他Web框架   web工具   zookeeper   tornado   NoSql   Bootstrap   js   peewee   Git   bottle   IE   MQ   Jquery  
机器学习
机器学习算法  
Python88.com
反馈   公告   社区推广  
产品
短视频  
印度
印度  
Py学习  »  aigc

AIGC之Text2Image(一) | CLIP模型原理与代码实现详解

ArronAI • 1 年前 • 345 次点击  

前言

       目前,大模型十分活跃,openai公司呈现GPT系列,特别是ChatGPT给人深刻印象,意识到大模型厉害之处,随后推出GPT4模型,更是将大模型进一步推到一个高度,并将多模态融合技术留下深刻印象,同时,学者也对多模态融合技术研究呈现百花齐放之势。然而,多模态模型大多以CLIP所提方法或思路实现多模态融合。为此,本文将重新回顾CLIP论文相关理论,也重点梳理其源码,并附其代码供读者参考(本文会涉及VIT与BERT代码解读)。

PS:代码环境安装、重点部分代码解释(如:image encode(VIT),text encode(BERT)等)

论文地址:https://arxiv.org/pdf/2103.00020.pdf

官网源代码:https://github.com/openai/CLIP

我的代码:https://pan.baidu.com/s/1ujX19IUV0EPSIMyIcBnClA?pwd=r63z 名称为:CLIP模型.zip 提取码:r63z

一、CLIP模型原理

1.1 背景介绍

       CLIP算是在跨模态训练无监督中的开创性工作,作者提到早在2017年之后就陆续有工作提出和本文类似的想法,但数据量太少,而无好结果。本文收集4亿数据的大数据集,才得到很好的效果。这种现象最近好像在机器学习领域越来越突出。本文采用对比方式,图像使用vit结构编码、文本使用bert编码,实现视觉与语言多模态融合。

1.2 对比训练方式

       本文并非像图像caption方式,而是通过对比学习实现模型训练,我想也是这种对比学习才被目前多模态融合方法所借鉴。其采用对比学习原因如下:

  1. OpenAI是不愁计算资源的公司,喜欢将一切都gpt化(就是做生成式模型);

  2. 以往工作在1000类ImageNet数据训练方法,非常耗费资源,而CLIP要做的是开发世界的视觉识别任务,所以训练的效率对于自监督的模型至关重要;

  3. 如果任务改为给定一张图片去预测一个文本(或者给定一个文本去预测一张图片),那么训练效率将会非常低下(因为一个图片可能对应很多种说法,一个文本也对应着很多种场景);

  4. 与其做默写古诗词,不如做选择题!(只要判断哪一个文本与图片配对即可);

  5. 通过从预测任务改为只预测某个单词到只选出配对的答案,模型的训练效率一下提升了4倍;

       为此,本文训练阶段使用对比学习,让模型学习文本-图像对的匹配关系,也就是下面模型原理图中,蓝色对角线为匹配的图文对。训练集用的他们自己采集的包含4亿个图文对的 WIT数据集。

1.3 prompt推理方式

       使用某种固定prompt结构,正如训练获得特征,通过图像与prompt特征相似度匹配,实现clip分类,如:图像猫、狗二分类,可分别输入 “ A photo of cat ” 和 “ A photo of dog ”,分别与图像特征算相似度,确定其图像类被。

1.4 图像与文本编码结构

       CLIP为多模态模型是指图像维度与文本维度融合,那么需要对图像特征化与文本特征化,本文选择图像编码结构为VIT,文本编码结构为BERT。后面,代码讲解,我将有大量笔墨说明。

1.5 特征CLS token结构

       对于图像数据而言,其数据格式为[H, W, C],分别代表的是图片的通道数Channel,图片的高Height和宽Width。但很明显的是三维数据并不是Transformer所需要的。所以需要通过使用一个Embedding层来对原始的图片数据进行变换。

vit划分patch原理

       vit论文做法为将给定的一堆图片按照给定的大小分成一堆Patches。本文将输入的图片尺寸为(224×224)按照16×16大小的Patch进行划分。其中(224×224)/(16×16)=196,因此我们会得到196个patches。到这里我们可以知道每一个Patches数据的shape为[16, 16, 3]。为了满足Transformer的需求,在这里,对每个Patch进行投影变化,映射到一维向量中。即完成如下转化。[16, 16, 3]->[768],那么这样一来,就将原始的[224, 224, 3]转化为[196, 768]。

cls token原理

       在输入Transformer Encoder之前,值得注意的是需要加上[class] token。在原论文中,作者的意思是参考BERT,在上述得到的一堆tokens中插入一个专门用于分类操作的[class] token,这个[class] token是一个可训练的参数,数据格式和其他token保持一致,均为一个向量。

       以本文为例,其维度大小为[1, 768]。注意的是,这里采取的是Concat操作。即cat cls token [1, 768]与图像pathch [196, 768] -> [197, 768],此时正好变成了二维矩阵。最终将图像patch变成维度是[197, 768],而本文是将cls token放在第一位,后面分类也是通过cls token给出,如下图。

PS:cls token是一个可学习参数。

二、CLIP环境安装

       本小节介绍如何使用官网代码安装环境,而不同电脑或cuda版本不一样,所安装也有所不同,但基本不影响,我的电脑相关属性:

gpu:RTX 3060显卡

CUDA:11.1

2.1 官方环境安装

官网代码安装如下命令:

$ conda install --yes -c pytorch pytorch=1.7.1 torchvision cudatoolkit=11.0$ pip install ftfy regex tqdm$ pip install git+https://github.com/openai/CLIP.git

2.2 CLIP环境安装

构建虚拟环境

conda create -n clip python=3.8

安装torch相关包:

pip install torch==1.8.0+cu111 torchvision==0.9.0+cu111 torchaudio==0.8.0 -f https://download.pytorch.org/whl/torch_stable.html  -i https://pypi.mirrors.ustc.edu.cn/simple/

安装相关依赖包:

pip install ftfy regex tqdm  -i https://pypi.mirrors.ustc.edu.cn/simple/

       运行源码setup.py,其一为install运行,该操作是一个包安装虚拟环境,其二为develop运行,该操作是开发安装,指向了源代码而不是安装它的位置,方便调试,其命令如下:

# 方法一安装命令python setup.py install# 方法二安装命令python setup.py develop  # 我采用该命令

PS:建议使用方法二指向源码

2.3 CLIP运行结果

以上安装即可运行检测命令,可测试安装成功,其结果如下:

三、CLIP的Transformer结构代码解读

       无论是文本text或图像image的编码encode均大量使用Transformer结构(以VIT与BERT编码),其实质是Q K V结构,可参考文章点击这里,为此我将单独使用一小节介绍。

改代码在源码model.py文件中,其调用类如下代码:

class Transformer(nn.Module):    def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):        super().__init__()        self.width = width        self.layers = layers        self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
def forward(self, x: torch.Tensor): return self.resblocks(x)

以上代码可知,该类为一个包装结构,重点是重复调用ResidualAttentionBlock结构,其结构如下代码:

class ResidualAttentionBlock(nn.Module):    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):        super().__init__()
self.attn = nn.MultiheadAttention(d_model, n_head) # n_head 头,d_model 表示维度。 self.ln_1 = LayerNorm(d_model) self.mlp = nn.Sequential(OrderedDict([ ("c_fc", nn.Linear(d_model, d_model * 4)), ("gelu", QuickGELU()), ("c_proj", nn.Linear(d_model * 4, d_model)) ])) self.ln_2 = LayerNorm(d_model) self.attn_mask = attn_mask
def attention(self, x: torch.Tensor): self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] # 三个x表示Q K V计算值,x最后维度=n_head*d_model
def forward(self, x: torch.Tensor): x = x + self.attention(self.ln_1(x)) x = x + self.mlp(self.ln_2(x)) return x

从上面forward代码结构可知。

      首先使用 x = x + self.attention(self.ln_1(x)),类似残差方式x+transform后的结果,该结构类似进行了attention方法,等同于transform结构的attention,该结构也被torch所集成,可直接调用其源码,如下:

self.attn = nn.MultiheadAttention(d_model, n_head)  # n_head 头,d_model 表示维度。

      其次又调用 x = x + self.mlp(self.ln_2(x)),类似FFN结构,进行nn.Linear常规线性操作,在来一个激活GELU结构,最后在来一次线性操作,符合mlp结构,具体如下:

self.mlp = nn.Sequential(OrderedDict([            ("c_fc", nn.Linear(d_model, d_model * 4)),            ("gelu", QuickGELU()),            ("c_proj", nn.Linear(d_model * 4, d_model))        ]))

其中GELU使用QuickGELU方法,其代码如下:

class QuickGELU(nn.Module):    def forward(self, x: torch.Tensor):


    
        return x * torch.sigmoid(1.702 * x)

注:该部分结构类似transformer结构,并n次使用于image与text的编码。

四、CLIP模型主函数代码解读

CLIP模型主函数也在源码model.py文件中,如下图所示:

其中forward为模型流走向,其代码如下:

    def forward(self, image, text):        image_features = self.encode_image(image)        text_features = self.encode_text(text)
# normalized features,# 每一行sqr(a1^2+a2^2+...) image_features = image_features / image_features.norm(dim=1, keepdim=True) # [batch_img,512] text_features = text_features / text_features.norm(dim=1, keepdim=True) # [batch_text,512]
# cosine similarity as logits logit_scale = self.logit_scale.exp() # 可学习参数 logits_per_image = logit_scale * image_features @ text_features.t() # 特征相乘获得相似度 logits_per_text = logits_per_image.t() # 变成文本
# shape = [global_batch_size, global_batch_size] return logits_per_image, logits_per_text

       以上可知,CLIP实现多模态融合,实际是对图像编码与文本编码,使其分别获得对应的特征表达,在将表达特征进行norm(我的理解减小偏差,是一个常规操作),随后将图像特征与对应文本特相差,便可获得相似值。

       假设以2个图像与3个文本表示,其图像特征获得对应文本特征得到相似值,简易说明如下:

        将其转职获得文本特征获得对应图像特征相似值,简易说明如下:

       其中,每个图像与文本特征表达维度为512(CLIP使用此维度),获得对应相似值如上图V**,每一行的最大值分别是CLIP模型认为最相似的,也得到图像获得文本标签,或文本获得匹配的图像。

五、CLIP的image encode代码解读

       图像编码使用VIT编码结构,将图片划分为多个patch,然后使用transformer结构编码提取特征,最终获得特征表达。接下来,我将详细阐述。

5.1、主函数代码解读

CLIP使用encode_image函数调用,如下:

image_features = self.encode_image(image)

而encode_image函数如下:

def encode_image(self, image):    return self.visual(image.type(self.dtype))

      CLIP使用图像编码有ResNet结构与VisionTransformer,前者是CNN方式,后者是transformer方式,我将以transformer方式解读,如下代码:

        if isinstance(vision_layers, (tuple, list)):            vision_heads = vision_width * 32 // 64            self.visual = ModifiedResNet(                layers=vision_layers,                output_dim=embed_dim,                heads=vision_heads,                input_resolution=image_resolution,                width=vision_width            )        else:            vision_heads = vision_width // 64            self.visual = VisionTransformer(                input_resolution=image_resolution,                patch_size=vision_patch_size,                width=vision_width,                layers=vision_layers,                heads=vision_heads,                output_dim=embed_dim            )

5.2、VisionTransformer结构代码解读

       该类是图像encode的所有精华所在,代码已有我的注释,其代码如下:

class VisionTransformer(nn.Module):    def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):        super().__init__()        self.input_resolution = input_resolution        self.output_dim = output_dim        self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)        # width相当于transform中的d_model        scale = width ** -0.5        self.class_embedding = nn.Parameter(scale * torch.randn(width))        self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))        self.ln_pre = LayerNorm(width)
self.transformer = Transformer(width, layers, heads)
self.ln_post = LayerNorm(width) self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
def forward(self, x: torch.Tensor): # x=[1,3,224,224] x = self.conv1(x) # shape = [*, width, grid, grid] # 将图片分成[32,32]个patch [1,768,7,7] x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2],合并高宽 [1,768,49] x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] ,更换位置 [1,49,768] x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width],添加cls token[1,50,768] x = x + self.positional_embedding.to(x.dtype) # 这里位置编码是可学习的参数,可能是切了path顺序让模型自己学习吧 [1,50,768] x = self.ln_pre(x) # [1,50,768]
x = x.permute(1, 0, 2) # NLD -> LND # [pixel,b,d_model]=[50,1,768] x = self.transformer(x) # 多头transformer [50,1,768] x = x.permute(1, 0, 2) # LND -> NLD # [1,50,768]
x = self.ln_post(x[:, 0, :]) # x[:, 0, :] 将所有信息汇聚到cls token中,只需前面来做下游任务 [1,768]
if self.proj is not None: # self.proj是可学习参数,维度为[768,512] x = x @ self.proj # 通过学习参数将维度再次融合变成512特征,最终为[1,512]
return x

       以上可知,图片首先切成patch块,然后转成transformer能使用的结构,该结构可参考这里,同时,代码也有位置编码模块与特征结合,随后将所有信息汇聚到cls token,可实现下游任务,最后也通过可学习参数实现最终图像特征提取。我将在下面具体解读。

5.3、图像patch方法代码解读

       将图像划分patch实际是VIT最重要思想,意在解决训练和推理速度问题,代码层面处理,实际为卷积核与步长来处理,代码如下:

self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)

       以上代码简单一句,即可将如[1,3,224,224]的一个图片分成3232尺寸(vit使用1616,这个根据模型而定,仅是一个参数而已)化成768个patch,高宽分别为7,格式为[1,768,7,7]:

# x=[1,3,224,224]x = self.conv1(x)  # shape = [*, width, grid, grid] # 将图片分成[32,32]个patch [1,768,7,7]

结果如图:

       768来源:VIT模型将输入224224尺寸化成1616像素的patch,那么每个patch为16163=768,其中3为图像通道,将每个patch投影为768维度表示,也就是本文中self.conv1通道为768的缘故。

       196与49区别:196也是来源VIT将224变成16尺寸的patch,那么共有224224/(1616)=196,而本文的patch尺寸为32,变成224224/(3232)=49。

      最终图像使用reshape将宽高7*7合并转为49的像素,成为[1,49,768],可理解1为batch在NLP中表示一句话,49为像素在NLP中表示文字,768为每个patch投影表达在NLP中表示d_model为每个文字使用d_model表达特征。其代码如下:

x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2],合并高宽 [1,768,49]x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width] ,更换位置 [1,49,768]

5.4、图像cls token编码代码解读

       cls token为VIT较为特殊设置,是一个可学习参数,我已在上面原理中介绍,不在细说,只解读实现方式,实现代码如下:

scale = width ** -0.5self.class_embedding = nn.Parameter(scale * torch.randn(width))

       将cls token嵌入,原来[1,49,768]变为[1,50,768],其代码中如下:

x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)  # shape = [*, grid ** 2 + 1, width],添加cls token[1,50,768]

       若在VIT模型cls token嵌入,将[1,196,768]变成[1,197,768]。

5.5、图像位置编码代码解读

       位置编码也是一个可学习参数,实现代码如下:

self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))

      将位置编码嵌入,实际是x加上了位置信息,和我之前attention is all you need文章解释类似,该结构代码如下:

x = x + self.positional_embedding.to(x.dtype)  # 这里位置编码是可学习的参数,可能是切了path顺序让模型自己学习吧  [1,50,768]

5.6、图像cls token特征表达代码解读

       最终每张图像特征表达直接使用cls token来代替,直接取前第一个,如下图显示:

5.7、图像特殊结构代码解读

       proj特殊结构,该结构若使用将进一步将图像特征表达进行变换,该变换的self.proj是可学习参数,代码如下:

self.proj = nn.Parameter(scale * torch.randn(width, output_dim))

      将该结构嵌入,我理解可进一步特征混合整合或组合获得图像特征表达,该结构代码如下:

if self.proj is not None:  # self.proj是可学习参数,维度为[768,512]   x = x @ self.proj  # 通过学习参数将维度再次融合变成512特征,最终为[1,512]

代码运行图像显示如下:

我个人觉得该结构可被借鉴。

六、CLIP的text encode代码解读

       文本编码使用BERT编码结构,显然使用transformer结构编码提取文本特征,最终获得特征表达。接下来,我将详细阐述。

6.1、主函数代码解读

CLIP使用encode_text函数调用,如下:

text_features = self.encode_text(text)

而encode_text函数如下:

def encode_text(self, text):    # x 每个句子前面有值,有2个特殊符号[CLS]与[Seq]    x = self.token_embedding(text).type(self.dtype)  # [batch_size, n_ctx, d_model],[3,77,512]    x = x + self.positional_embedding.type(self.dtype)  # 位置编码直接赋可学习位置,添加位置信息[3,77,512]    x = x.permute(1, 0, 2)  # NLD -> LND,[77,3,512]    x = self.transformer(x)  # 共11个 和图像encode结构一致 [77,3,512]    x = x.permute(1, 0, 2)  # LND -> NLD,[3,77,512]    x = self.ln_final(x).type(self.dtype)    # x.shape = [batch_size, n_ctx, transformer.width]    # take features from the eot embedding (eot_token is the highest number in each sequence)    # text.argmax(dim=-1) 句子最后有一个seq字段,是最大的,因此能获得句子个数数量    x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
    return x

6.2、文本token代码解读

       文本编码和我之前文章点击这里解释transform的encode基本相同,读者可查看。很多与我之前文章相同内容将不在解释,该小节说明如何使用文本token。首先文本为text_language = ["a diagram", "a dog", "a black cat"],也就是三句话,每句话大概几个词,其转码为下图计算机可识别符号方法,查阅我的博客点击这里。其代码如下:

x = self.token_embedding(text).type(self.dtype)  # [batch_size, n_ctx, d_model],[3,77,512]

         其结果如下图:

     以上可知,文本变成[3,77]结构,如输入text第一行文本为"a diagram",理论映射只有2个,但有四个数字,其中第一个为[CLS]值,最后一个为[Seq]值,本文设置每个句子长度为77,不足使用0表示,最终变成[3,77]表示为3个句子有77个文字(不足用0表示)。最终使用512维度表达,成为[3,77,512]结构,该部分与我之前文章内容一致,详情可参考之前文章。

6.3、文本位置编码代码解读

       位置编码也是一个可学习参数,实现代码如下:

self.positional_embedding = nn.Parameter(torch.empty(self


    
.context_length, transformer_width))

       将位置编码嵌入,实际是x加上了位置信息,和我之前attention is all you need文章解释类似,该结构代码如下:

x = x + self.positional_embedding.type(self.dtype)  # 位置编码直接赋可学习位置,添加位置信息[3,77,512]

6.4、文本特殊结构代码解读

       self.text_projection特殊结构,该结构若使用将进一步将文本特征表达进行变换,该变换的self.text_projection是可学习参数,代码如下:

self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))

        将该结构嵌入,与图像变啊特殊结构类似,该结构代码如下:

# text.argmax(dim=-1) 句子最后有一个seq字段,是最大的,因此能获得句子个数数量x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection

PS:x[torch.arange(x.shape[0]), text.argmax(dim=-1)]改代码表达取x为[3,77,512]维度索引分别[0,3],[1,3],[2,4],得到三个句子512维度特征表达,而每个句子都是取第二个维度77文字最大那一个,我的理解是每句话都是从第一个文字[CLS]叠加到最后一个文字[Seq],因此使用最后一个就有时序表达该句话的特征。

代码运行图像显示如下:

至于文本encode过程可参考代码走向,因其过于简单,我不在说明。

七、CLIP多模态融合代码解读

       在上面小节中我们已然知晓图像编码与文本编码方式,该小节说明获得图像、文本特征表达融合方式,其代码如下:

    def forward(self, image, text):        image_features = self.encode_image(image)        text_features = self.encode_text(text)
# normalized features,# 每一行sqr(a1^2+a2^2+...) image_features = image_features / image_features.norm(dim=1, keepdim=True) # [batch_img,512] text_features = text_features / text_features.norm(dim=1, keepdim=True) # [batch_text,512]
# cosine similarity as logits logit_scale = self.logit_scale.exp() # 可学习参数 logits_per_image = logit_scale * image_features @ text_features.t() # 特征相乘获得相似度 logits_per_text = logits_per_image.t() # 变成文本
# shape = [global_batch_size, global_batch_size] return logits_per_image, logits_per_text

       从代码可知,图像特征与文本特征进行norm(其作用在上面已说明),然后求解其相似度获得图像与文本匹配结果。其过程也较为简单,可直接参考以上源码,其图示如下:

图像特征为[1,512]表示一个图像被512维度表达;

文本特征[3,512]表示3个句子分别被512维度表达;

八、CLIP推理结构解读

        推理代码官网也有提供,直接官网下载权重便可实现,我使用VIT-B-32模型结构,实现推理分类任务。该模型使用对比学习,可定义很多文本,让每个图像与多个文本特征相似匹配,匹配值越高,自然就是那个类。如同,我在上面CLIP模型主函数代码解读说明一样。其代码如下:

import torchimport clipfrom PIL import Imageimport numpy as np
def class_demo(): # 测试分类的demo device = "cuda" if torch.cuda.is_available() else "cpu" # 模型选择['RN50', 'RN101', 'RN50x4', 'RN50x16', 'ViT-B/32', 'ViT-B/16'],对应不同权重 model, preprocess = clip.load("../ViT-B-32.pt", device=device) # 载入模型 image = preprocess(Image.open("../CLIP.png")).unsqueeze(0).to(device) text_language = ["a diagram", "a dog", "a black cat"] text = clip.tokenize(text_language).to(device)
with torch.no_grad(): logits_per_image, logits_per_text = model(image, text) # 第一个值是图像,第二个是第一个的转置 probs = logits_per_image.softmax(dim=-1).cpu().numpy()
idx = np.argmax(probs, axis=1) for i in range(image.shape[0]): id = idx[i] print('image {}\tlabel\t{}:\t{}'.format(i, text_language[id],probs[i,id])) print('image {}:\t{}'.format(i, [v for v in zip(text_language,probs[i])]))

if __name__ == '__main__': class_demo()

其结果如下:

九、CLIP训练结构解读

       分类的CLIP训练实际是交叉熵方法,我们获得匹配值,可看成每个图像分别与不同文本相似值为预测类别值,进行类似交叉熵运算即可,另外反过来也可看成每个文本与分别与不同图像相似值为预测值,亦可进行交叉熵运算。我大概查了github其它训练方法,可供参考,其代码如下:

        with torch.no_grad():            for i, batch in enumerate(dataloader):                images, texts = batch                images = images.to(device=device, non_blocking=True)                texts = texts.to(device=device, non_blocking=True)
with autocast(): image_features, text_features, logit_scale = model(images, texts) # features are accumulated in CPU tensors, otherwise GPU memory exhausted quickly # however, system RAM is easily exceeded and compute time becomes problematic all_image_features.append(image_features.cpu()) all_text_features.append(text_features.cpu()) logit_scale = logit_scale.mean() logits_per_image = logit_scale * image_features @ text_features.t() logits_per_text = logits_per_image.t()
batch_size = images.shape[0] labels = torch.arange(batch_size, device=device).long() total_loss = ( F.cross_entropy(logits_per_image, labels) + F.cross_entropy(logits_per_text, labels)                    ) / 2

参考文献:

[1] https://blog.csdn.net/weixin_38252409/article/details/133828294

[2] https://blog.csdn.net/caroline_wendy/article/details/125088243

Python社区是高质量的Python/Django开发社区
本文地址:http://www.python88.com/topic/165835
 
345 次点击