社区所有版块导航
Python
python开源   Django   Python   DjangoApp   pycharm  
DATA
docker   Elasticsearch  
aigc
aigc   chatgpt  
WEB开发
linux   MongoDB   Redis   DATABASE   NGINX   其他Web框架   web工具   zookeeper   tornado   NoSql   Bootstrap   js   peewee   Git   bottle   IE   MQ   Jquery  
机器学习
机器学习算法  
Python88.com
反馈   公告   社区推广  
产品
短视频  
印度
印度  
Py学习  »  机器学习算法

深度学习Pytorch框架Tensor张量

极市平台 • 2 年前 • 206 次点击  
↑ 点击蓝字 关注极市平台

作者 | 秦一@知乎(已授权)
来源 | https://zhuanlan.zhihu.com/p/399350505
编辑 | 极市平台

极市导读

 

本文主要介绍了Tensor的裁剪运算、索引与数据筛选、组合/拼接、切片、变形操作、填充操作和Tensor的频谱操作(傅里叶变换)。>>加入极市CV技术交流群,走在计算机视觉的最前沿

1 Tensor的裁剪运算

  • 对Tensor中的元素进行范围过滤
  • 常用于梯度裁剪(gradient clipping),即在发生梯度离散或者梯度爆炸时对梯度的处理
  • torch.clamp(input, min, max, out=None) → Tensor:将输入input张量每个元素的夹紧到区间 [min,max],并返回结果到一个新张量。

2 Tensor的索引与数据筛选

  • torch.where(codition,x,y):按照条件从x和y中选出满足条件的元素组成新的tensor,输入参数condition:条件限制,如果满足条件,则选择a,否则选择b作为输出。
  • torch.gather(input,dim,index,out=None):在指定维度上按照索引赋值输出tensor
  • torch.inex_select(input,dim,index,out=None):按照指定索引赋值输出tensor
  • torch.masked_select(input,mask,out=None):按照mask输出tensor,输出为向量
  • torch.take(input,indices):将输入看成1D-tensor,按照索引得到输出tensor
  • torch.nonzero(input,out=None):输出非0元素的坐标
import torch
#torch.where

a = torch.rand(44)
b = torch.rand(44)

print(a)
print(b)

out = torch.where(a > 0.5, a, b)

print(out)
print("torch.index_select")
a = torch.rand(44)
print(a)
out = torch.index_select(a, dim=0,
                   index=torch.tensor([032]))
#dim=0按列,index取的是行
print(out, out.shape)
print("torch.gather")
a = torch.linspace(11616).view(44)

print(a)

out = torch.gather(a, dim=0,
             index=torch.tensor([[0111],
                                 [0122],
                                 [0133]]))
print(out)
print(out.shape)
#注:从0开始,第0列的第0个,第一列的第1个,第二列的第1个,第三列的第1个,,,以此类推
#dim=0, out[i, j, k] = input[index[i, j, k], j, k]
#dim=1, out[i, j, k] = input[i, index[i, j, k], k]
#dim=2, out[i, j, k] = input[i, j, index[i, j, k]]
print("torch.masked_index")
a = torch.linspace(11616).view(44)
mask = torch.gt(a, 8)
print(a)
print(mask)
out = torch.masked_select(a, mask)
print(out)
print("torch.take")
a = torch.linspace(11616).view(44)

b = torch.take(a, index=torch.tensor([0151310]))

print(b)



    
#torch.nonzero
print("torch.take")
a = torch.tensor([[0120], [2301]])
out = torch.nonzero(a)
print(out)
#稀疏表示

3 Tensor的组合/拼接

  • torch.cat(seq,dim=0,out=None):按照已经存在的维度进行拼接
  • torch.stack(seq,dim=0,out=None):沿着一个新维度对输入张量序列进行连接。序列中所有的张量都应该为相同形状。
print("torch.stack")
a = torch.linspace(166).view(23)
b = torch.linspace(7126).view(23)
print(a, b)
out = torch.stack((a, b), dim=2)
print(out)
print(out.shape)

print(out[:, :, 0])
print(out[:, :, 1])

4 Tensor的切片

  • torch.chunk(tensor,chunks,dim=0):按照某个维度平均分块(最后一个可能小于平均值)
  • torch.split(tensor,split_size_or_sections,dim=0):按照某个维度依照第二个参数给出的list或者int进行分割tensor

5 Tensor的变形操作

  • torch().reshape(input,shape)
  • torch().t(input):只针对2D tensor转置
  • torch().transpose(input,dim0,dim1):交换两个维度
  • torch().squeeze(input,dim=None,out=None):去除那些维度大小为1的维度
  • torch().unbind(tensor,dim=0):去除某个维度
  • torch().unsqueeze(input,dim,out=None):在指定位置添加维度,dim=-1在最后添加
  • torch().flip(input,dims):按照给定维度翻转张量
  • torch().rot90(input,k,dims):按照指定维度和旋转次数进行张量旋转
import torch
a = torch.rand(23)
print(a)
out = torch.reshape(a, (32))
print(out)
print(a)
print(torch.flip(a, dims=[21]))

print(a)
print(a.shape)
out = torch.rot90(a, -1, dims=[02]) #顺时针旋转90°  
print(out)
print(out.shape)

6 Tensor的填充操作

  • torch.full((2,3),3.14)

7 Tensor的频谱操作(傅里叶变换)

如果觉得有用,就请分享到朋友圈吧!

△点击卡片关注极市平台,获取最新CV干货

公众号后台回复“CVPR21检测”获取CVPR2021目标检测论文下载~


极市干货
项目/比赛:珠港澳人工智能算法大赛算法打榜
算法trick目标检测比赛中的tricks集锦从39个kaggle竞赛中总结出来的图像分割的Tips和Tricks
技术综述:一文弄懂各种loss function工业图像异常检测最新研究总结(2019-2020)


CV技术社群邀请函 #

△长按添加极市小助手
添加极市小助手微信(ID : cvmart4)

备注:姓名-学校/公司-研究方向-城市(如:小极-北大-目标检测-深圳)


即可申请加入极市目标检测/图像分割/工业检测/人脸/医学影像/3D/SLAM/自动驾驶/超分辨率/姿态估计/ReID/GAN/图像增强/OCR/视频理解等技术交流群


每月大咖直播分享、真实项目需求对接、求职内推、算法竞赛、干货资讯汇总、与 10000+来自港科大、北大、清华、中科院、CMU、腾讯、百度等名校名企视觉开发者互动交流~



觉得有用麻烦给个在看啦~  
Python社区是高质量的Python/Django开发社区
本文地址:http://www.python88.com/topic/122340
 
206 次点击