最近整理了Pytorch中5个常用的 张量乘法 函数和用法,建议收藏学习。 1. 张量的维度在开始今天的学习之前,我们需要先学习一个知识点,即张量的维度 。它包括两方面内容,其一是 维度个数 ,其二是 维度大小 。维度个数可通过张量 .ndim
属性查看,维度大小可通过 .shape
或 .size()
查看。
>>> a=torch.arange(6).reshape(2,3) >>> a tensor([[0, 1, 2], [3, 4, 5]]) >>> a.ndim 2 >>> a.shape torch.Size([2, 3]) >>> a.size() torch.Size([2, 3])
比如上面的张量a:维度个数为2,代表a是一个二维张量;维度大小为[2,3],代表第0维的维度大小为2,第1维为3。
2. torch.matmul接下来,我们先学习最复杂也最灵活的torch.matmul [1] 函数。
2.1 概览功能 :matmul函数实现的是矩阵乘法 ,更确切地说,是"混合" 矩阵乘法 。
参数 :
out
(张量):结果张量,等同于matmul函数的返回值。
返回值 :张量。
2.2 示例代码matmul函数的行为根据输入张量的不同大体可以分为5种情形(以下统称case)。所以,我们也将通过5个case下的示例代码来学习这个函数。
(1) case1若两个张量均为一维张量 ,则执行向量点积 操作,等价于调用torch.dot 函数。
比如,下面我们创建了两个一维张量a、b,维度大小均为2。
>>> a=torch.randn(1) >>> b=torch.randn(1) >>> a.ndim 1 >>> b.ndim 1 >>> a.size() torch.Size([2]) >>> b.size() torch.Size([2]) >>> a tensor([0.8411]) >>> b tensor([-1.1787])
然后,分别对他们进行matmul和dot操作。从结果比对来看,两个操作是等价的,最终生成的都是scalar标量 。
>>> c1=torch.matmul(a,b) >>> c2=torch.dot(a,b) >>> c1.equal(c2) True >>> c1 tensor(-0.9914) >>> c1.ndim 0 >>> c1.size() torch.Size([])
(2) case2若两个张量均为二维张量 ,则执行矩阵乘法 操作,等价于调用torch.mm 函数。
比如下面的例子,a、b均为2维张量,维度大小分别为[2,2]、[2,3]。a.size()[1]=b.size()[0]
满足矩阵乘法约束 ,通过matmul函数或mm函数,我们将获得2维张量,维度大小为[2,3]。
>>> a=torch.randn(2,2) >>> b=torch.randn(2,3) >>> a.ndim 2 >>> b.ndim 2 >>> a.size() torch.Size([2, 2]) >>> b.size() torch.Size([2, 3]) >>> c1=torch.matmul(a,b) >>> c2=torch.mm(a,b) >>> c1.equal(c2) True >>> c1.size() torch.Size([2, 3]) >>> c1.ndim 2
(3) case3若第一个张量为一维张量 ,假设维度为[k],第二个张量为二维张量 ,假设维度为[k,p]。第一个张量会在左边 进行维度扩展 ,维度变为[1,k],然后再进行矩阵乘法,获得维度为[1,p]的张量,然后再去掉扩展的维度 ,最后结果张量维度为[p]。
比如,a是维度大小为[3]的一维矩阵,b是维度大小为[3,4]的二维矩阵,结果张量c1是一维张量,维度大小为[4]。
>>> a=torch.arange(1,4) >>> b=torch.arange(2,14).reshape((3,4)) >>> a.ndim 1 >>> a.size() torch.Size([3]) >>> b.ndim 2 >>> b.size() torch.Size([3, 4]) >>> >>> c1=torch.matmul(a,b) >>> c1.ndim 1 >>> c1.size() torch.Size([4]) >>> >>> a tensor([1, 2, 3]) >>> b tensor([[ 2, 3, 4, 5], [ 6, 7, 8, 9], [10, 11, 12, 13]]) >>> c1 tensor([44, 50, 56, 62])
更简单地记法,可以视为线性代数 中的行向量乘矩阵 ,结果为第二个张量矩阵的行向量的线性组合 ,组合系数为第一个张量中相应的值。
>>> c2=1*b[0]+2*b[1]+3*b[2] >>> c1.equal(c2) True >>> c2 tensor([44, 50, 56, 62])
需要注意的是,虽然matmul在进行维度扩展后执行的是矩阵乘法,但这种情形下,它与torch.mm是不等价的,因为torch.mm函数严格要求输入均为二维张量。
>>> torch.mm(a,b) Traceback (most recent call last): File "" , line 1, in RuntimeError: matrices expected, got 1D, 2D tensors at ../aten/src/TH/generic/THTensorMath.cpp:131
(4) case4若第一个张量为二维张量 ,假设维度为[k,n],第二个张量为一维张量 ,假设维度为[n]。第二个张量会在右边 进行维度扩展 ,维度变为[n,1],然后再执行矩阵乘法,获得维度为[k,1]的张量,最后再去掉扩展的维度 ,获得维度为[k]的结果张量。
比如,a是维度大小为[3,2]的二维矩阵,b是维度大小为[2]的一维矩阵,结果张量c1是一维张量,维度大小为[3]。
>>> a=torch.arange(1,7).reshape(3,2) >>> b=torch.arange(1,3) >>> a.ndim 2 >>> >>> a.size() torch.Size([3, 2]) >>> b.ndim 1 >>> b.size() torch.Size([2]) >>> c1=torch.matmul(a,b) >>> c1.ndim 1 >>> c1.size() torch.Size([3]) >>> >>> a tensor([[1, 2], [3, 4], [5, 6]]) >>> b tensor([1, 2]) >>> c1 tensor([ 5, 11, 17])
更简单地,可以视为线性代数 中的矩阵乘列向量 操作,结果为第一个张量 矩阵的列向量的线性组合 ,组合系数为第二个张量中相应的值。
>>> c2=1*a[:,0]+2*a[:,-1] >>> c1.equal(c2) True >>> c2 tensor([ 5, 11, 17])
当然同case3,这种情况torch.mm也是无法执行的。
>>> >>> torch.mm(a,b) Traceback (most recent call last): File "" , line 1, in IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)
(5)case5如果两个张量的维度均至少为1,且其中至少一个张量维度大于2,那么torch.matmul将执行批矩阵乘法 操作:默认使用两个张量的后两维度执行矩阵乘法 ,其他维度作为batch维 。
若两个张量均为3维张量 ,矩阵个数相等(第0维大小相等 )且后两维满足矩阵乘法约束 ,那么调用torch.matmul等价于调用torch.bmm 函数。
>>> a=torch.arange(12).reshape(2,2,3) >>> a tensor([[[ 0, 1, 2], [ 3, 4, 5]], [[ 6, 7, 8], [ 9, 10, 11]]]) >>> a.size() torch.Size([2, 2, 3]) >>> a.ndim 3
>>> b=torch.arange(1,7).reshape(2,3,1) >>> b tensor([[[1], [2], [3]], [[4], [5], [6]]]) >>> b.size() torch.Size([2, 3, 1]) >>> b.ndim 3
上面的例子中,a、b均为3维张量,维度大小分别为[2,2,3]、[2,3,1]。第0维大小相等为2,后两维满足矩阵乘法约束。这种情况下,两个函数等价,获得的结果为3维张量,维度大小为[2,2,1]。
>>> c1=torch.matmul(a,b) >>> c2=torch.bmm(a,b) >>> c1.equal(c2) True >>> c1 tensor([[[ 8], [ 26]], [[107], [152]]]) >>> c1.size() torch.Size([2, 2, 1]) >>> c1.ndim 3
其他情形,torch.bmm无法执行,但torch.matmul仍可执行:两个张量的后两维需满足矩阵乘法约束,不满足的情形会进行维度扩展 (参考case3,case4),其他维则会通过广播操作 对齐。
来看下面这个例子:张量a为1维张量,维度大小为[2];张量b为3维张量,维度大小为[3,2,1]。
>>> a=torch.arange(2) >>> b=torch.arange(6).reshape(3,2,1) >>> a.ndim 1 >>> a.size() torch.Size([2]) >>> b.ndim 3 >>> b.size() torch.Size([3, 2, 1]) >>> a tensor([0, 1]) >>> b tensor([[[0], [1]], [[2], [3]], [[4], [5]]])
为了进行批矩阵乘法,a经过变换(a的其他维经过广播操作、参与矩阵计算的维则进行类似case3的维度扩展)成为维度为[3,1,2]的张量,再与b(维度大小为[3,2,1])进行批矩阵乘法获得维度为[3,1,1]的张量。最终,去掉扩展维度后,结果的维度为[3,1]。
>>> c=torch.matmul(a,b) >>> c.size() torch.Size([3, 1]) >>> c.ndim 2 >>> c tensor([[1], [3], [5]]) >>>
下面这个例子同理。b:[2]->[2,1 ] (维度扩展) ->[2,1 ,2,1] (广播操作) ,再与a进行批矩阵乘法,结果:[2,1,3,1 ]->[2,1,3] (去扩展维度)。
>>> a=torch.arange(12).reshape(2,1,3,2) >>> b=torch.arange(2) >>> c=torch.matmul(a,b) >>> a.ndim,b.ndim,c.ndim (4, 1, 3) >>> a.shape torch.Size([2, 1, 3, 2]) >>> b.shape torch.Size([2]) >>> c.shape torch.Size([2, 1, 3]) >>> a tensor([[[[ 0, 1], [ 2, 3], [ 4, 5]]], [[[ 6, 7], [ 8, 9], [10, 11]]]]) >>> b tensor([0, 1]) >>> c tensor([[[ 1, 3, 5]], [[ 7, 9, 11]]])
3. torch.dot前面在介绍torch.matmul的case1时,已经知道了torch.dot执行的是向量点积 计算,这节我们来更细节地学习torch.dot [2] 函数。
3.1 概览功能 :向量点积。
参数 :
out
(张量):结果张量,等同于dot函数的返回值。
返回值 :张量(标量)。
重点 :只支持具有相同元素个数 的两个一维张量 做点积操作。
3.2 示例代码 (1)当a、b均为一维张量且维度大小相同时>>> import torch >>> a=torch.tensor([1,2]) >>> b=torch.tensor([3,4]) >>> a.ndim==b.ndim==1 True >>> a.size()==b.size()==torch.Size([2]) True >>> >>> c=torch.dot(a,b) >>> c.ndim 0 >>> >>> a tensor([1, 2]) >>> b tensor([3, 4]) >>> c tensor(11)
(2)当a、b均为一维张量,但维度大小不一致时>>> a=torch.tensor([1,2]) >>> b=torch.tensor([3,4,5]) >>> a.ndim==b.ndim==1 True >>> a.size()==b.size() False >>> c=torch.dot(a,b) Traceback (most recent call last): File "" , line 1, in RuntimeError: inconsistent tensor size, expected tensor [2] and src [3] to have the same number of elements, but got 2 and 3 elements respectively
(3)当a或b不是一维张量时>>> a=torch.arange(4).reshape(2,2) >>> a tensor([[0, 1], [2, 3]]) >>> b=torch.tensor([3,4,5,6]).reshape(2,2) >>> b tensor([[3, 4], [5, 6]]) >>> >>> a.size() torch.Size([2, 2]) >>> a.ndim 2 >>> torch.dot(a,b) Traceback (most recent call last): File "" , line 1, in RuntimeError: 1D tensors expected, got 2D, 2D tensors at ../aten/src/TH/generic/THTensorEvenMoreMath.cpp:733
总结 :从以上几个示例代码的学习,我们可以明确torch.dot函数 限制/约束输入的两个张量必须均为一维张量 ,且元素个数相同 。即要求输入的两个张量a、b维度数满足a.ndim==b.ndim=1
,维度大小满足a.shape==b.shape
。
4. torch.mm和case2下的torch.matmul相同,torch.mm执行的是矩阵乘法 运算。这节,我们来一起学习torch.mm [3] 函数。
4.1 概览功能 :矩阵乘法。
参数 :
返回值: 张量 。
重点 :mm不会进行广播操作 ,它严格要求两个张量满足维度约束 。即,假设两个张量分别为a、b,要求a.size()[1]=b.size()[0]
。
4.2 示例代码 (1)当a、b均为2维张量,且满足维度约束条件下面的例子中,我们创建了维度为[1,3]和[3,2]的二维张量a、b。然后,通过torch.mm(a,b),获得了维度为[1,2]的二维张量c。
>>> import torch >>> a=torch.arange(1,4).unsqueeze(0) >>> b=torch.arange(1,7).reshape(3,2) >>> c=torch.mm(a,b) >>> >>> a.size() torch.Size([1, 3]) >>> b.size() torch.Size([3, 2]) >>> c.size() torch.Size([1, 2]) >>> a.ndim==b.ndim==c.ndim==2 True >>> >>> a tensor([[1, 2, 3]]) >>> b tensor([[1, 2], [3, 4], [5, 6]]) >>> c tensor([[22, 28]])
(2)当某一个张量为非二维张量时下面这个例子中,a是维度大小为[3,2]的二维张量,b是维度大小为[2]的一维张量。torch.mm不会进行广播操作(这里主要是指维度扩展 ),所以不会像case3中的torch.matmul可以成功执行。
>>> a = torch.arange(6).reshape(3,2) >>> a tensor([[0, 1], [2, 3], [4, 5]]) >>> a.size() torch.Size([3, 2]) >>> a.ndim 2 >>> b = torch.arange(1,3) >>> b tensor([1, 2]) >>> b.size() torch.Size([2]) >>> b.ndim 1 >>> >>> torch.mm(a,b) Traceback (most recent call last): File "" , line 1, in IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1) >>> >>> >>> c=torch.matmul(a,b) >>> c tensor([ 2, 8, 14]) >>> c.size() torch.Size([3]) >>> c.ndim 1
总结 :torch.mm 函数限制/约束输入的两个张量必须均为二维张量 ,且维度满足矩阵乘法约束 。即,要求输入的两个张量a、b维度数满足a.ndim==b.ndim=2
,维度大小则满足a.size()[1]==b.size()[0]
。
5. bmm前面,通过对torch.matmul函数在case5下行为的学习,我们了解到torch.bmm [4] 实现的是批量矩阵乘法 计算。本节,我们来具体学习这个函数。
5.1 概览功能 :批量矩阵乘法。
参数 :
input
(张量):第一批矩阵,即3维张量,第0维表示批大小。
mat2
(张量):第二批矩阵,即3维张量,第0维表示批大小。
out
(张量):结果张量,等同于bmm函数返回值。也是3维张量,第0维表示批大小。
返回值 :三维张量,第0维表示批大小。
重点 :bmm不会进行广播操作 ,它严格要求两个张量均为三维张量 ,且第0维大小相等 (表示有多少个矩阵),其他两维满足矩阵乘法约束 。即,假设两个张量分别为a、b,要求a.size()[0]=b.size()[0]
且a.size()[-1]==b.size()[1]
。
5.2 示例代码 (1) 当a、b均为3维张量,且严格满足约束条件。比如,这里a、b均为3维张量,维度大小分别为[2,3,2]、[2,2,5]。a.size()[-1]==b.size()[1]==2
满足维度约束,a.size()[0]==b.size()[0]==2
说明批大小相同,具有相同个数的矩阵。这种情形下,所以 bmm与 matmul完全等价。
>>> a=torch.arange(1,13).reshape(2,3,2) >>> b=torch.arange(20).reshape(2,2,5) >>> a.ndim==b.ndim==3 True >>> a.size()[-1]==b.size()[1] >>> c1=torch.bmm(a,b) >>> c1 tensor([[[ 10, 13, 16, 19, 22], [ 20, 27, 34, 41, 48], [ 30, 41, 52, 63, 74]], [[190, 205, 220, 235, 250], [240, 259, 278, 297, 316], [290, 313, 336, 359, 382]]]) >>> c1.size() torch.Size([2, 3, 5]) >>> c2=torch.matmul(a,b) >>> c2 tensor([[[ 10, 13, 16, 19, 22], [ 20, 27, 34, 41, 48], [ 30, 41, 52, 63, 74]], [[190, 205, 220, 235, 250], [240, 259, 278, 297, 316], [290, 313, 336, 359, 382]]]) >>> c2.size() torch.Size([2, 3, 5]) >>> c2.equal(c1) True >>> c1.ndim 3 >>> torch.mm(a,b) Traceback (most recent call last): File "" , line 1, in RuntimeError: matrices expected, got 3D, 3D tensors at ../aten/src/TH/generic/THTensorMath.cpp:131
(2) 批大小不同时a与b均为3维张量,且满足矩阵乘法约束,但是批大小不同。torch.bmm不会执行广播操作,所以这种情形下它无法成功执行。但支持广播操作的torch.matmul可以成功执行。
>>> a=torch.arange(1,13).reshape(2,3,2) >>> b=torch.arange(10).reshape(2,5).unsqueeze(0) >>> a.size() torch.Size([2, 3, 2]) >>> b.size() torch.Size([1, 2, 5]) >>> a.ndim==b.ndim==3 True >>> torch.bmm(a,b) Traceback (most recent call last): File "" , line 1, in RuntimeError: Expected tensor to have size 2 at dimension 0, but got size 1 for argument #2 'batch2' (while checking arguments for bmm) >>> c=torch.matmul(a,b) >>> c tensor([[[ 10, 13, 16, 19, 22], [ 20, 27, 34, 41, 48], [ 30, 41, 52, 63, 74]], [[ 40, 55, 70, 85, 100], [ 50, 69, 88, 107, 126], [ 60, 83, 106, 129, 152]]]) >>> c.size() torch.Size([2, 3, 5]) >>> c.ndim 3
总结 :torch.bmm 函数 限制/约束输入的两个张量必须均为三维张量 ,其中第0维大小相同 ,其他维满足矩阵乘法约束 。
6. mul与*本节我们来学习最后一个常用的张量乘法函数 torch.mul [5] ,它与 * 等价,实现的是逐元素(即element-wise)相乘 。
6.1 概览
功能 :逐元素相乘。
参数 :
out(张量):结果张量,等同于mul函数的返回值。
返回值 :张量。
重点 :要求两个张量维度相同 ,即a.size()==b.size()
;若不同,则通过广播操 作将相乘的两个张量的维度变得相同。同时,它的广播操作还会将两个张量类型统一 。
6.2 示例代码 (1) 当a、b维度相同且类型相同时下例中,我们先创建了两个类型为torch.LongTensor的张量a、b,他们的维度均为[2,3],然后执行了两个等价的计算操作:a*b
与torch.mul(a,b)
。
>>> a=torch.LongTensor(2,3) >>> b=torch.LongTensor(2,3) >>> c1=torch.mul(a,b) >>> c2=a*b >>> c1.equal(c2) True >>> >>> a.size()==b.size()==c1.size()==torch.Size([2, 3]) True >>> a.type()==b.type()==c1.type()=="torch.LongTensor" True
(2) 当a、b向量维度相同,类型不同时今天我们所学的5个张量乘法函数中,只有torch.matmul和torch.mul支持广播操作 。torch.matmul的广播操作仅针对张量的维度,而torch.mul还支持张量的类型变换 。
下面的例子中,我们首先创建了二维张量a与一维张量b。a的维度为[2,3],类型为torch.LongTensor;b的维度为[3],类型为torch.FloatTensor。
>>> a=torch.arange(2,8).reshape(2,3) >>> a tensor([[2, 3, 4], [5, 6, 7]]) >>> a.size() torch.Size([2, 3]) >>> a.ndim 2 >>> a.type()'torch.LongTensor' >>> >>> b=torch.randn(3) >>> b tensor([ 1.1250, 0.8435, -0.5835]) >>> b.size() torch.Size([3]) >>> b.ndim 1 >>> b.type()'torch.FloatTensor' >>>
然后,我们尝试matmul操作,执行失败。错误信息提示我们a与b类型不一致。也就是说,虽然matmul支持广播操作,但仅针对张量的维度,而不包括张量类型。所以,即使两个张量满足matmul在维度上的要求,但类型不一致,也是无法正确让matmul函数执行的。
>>> torch.matmul(a,b) Traceback (most recent call last): File "" , line 1, in RuntimeError: Expected object of scalar type Long but got scalar type Float for argument #2 'vec' in call to _th_mv
相反,mul函数则是可以正确执行的。
>>> c=torch.mul(a,b) tensor([[ 2.2500, 2.5306, -2.3339], [ 5.6249, 5.0612, -4.0844]]) >>> c.type()'torch.FloatTensor'
如果想让matmul函数正确执行,我们可以手动调整张量b的类型。
>>> b=b.type(torch.long) >>> b tensor([1, 0, 0]) >>> torch.matmul(a,b) tensor([2, 5])
(3) 当a、b向量类型相同,维度不同时下例中,a、b均是类型为torch.LongTensor的张量。
>>> a=torch.arange(1,3).unsqueeze(1) >>> b=torch.arange(1,4) >>> a.type()==b.type()=='torch.LongTensor'
二维张量a的维度大小为[2,1],一维张量b的维度大小为[3]。
>>> a.size() torch.Size([2, 1]) >>> a.ndim 2 >>> b.size() torch.Size([3]) >>> b.ndim 1
mul 通过广播操作将a、b拉伸至具有同样的shape,然后再执行逐元素乘法,最后获得二维张量c1,c1的维度大小为[2,3]。
>>> c1=torch.mul(a,b) >>> c1.size() torch.Size([2, 3]) >>> c1.ndim 2 >>> c1.type()'torch.LongTensor'
根据我们在第一节对维度的说明,我们知道广播操作 包括两个层面:
首先,若维度数不同,维度较少的张量需要在最左边进行维度扩展 ,使维度数相同。
然后,若各维度的维度大小不同,维度大小为1的张量需要在该维上复制元素 ,扩展拉伸至维度大小和另一个张量在该维上的大小相同。
为了更好更直观地理解上述对广播操作的描述,我们接下来尝试手动复现广播操作。
首先,我们先看a、b、c1各自的值:
>>> a tensor([[1], [2]]) >>> b tensor([1, 2, 3]) >>> c1 tensor([[1, 2, 3], [2, 4, 6]])
然后,我们来分析维度数。a与b维度数不同,维度为1的b比维度为2的a少一个维度,所以b需要在最左边扩展一个维度。扩展后,a与b维度数相同,均为2。
>>> b=b.unsqueeze(0) >>> b.ndim 2 >>> b.size() torch.Size([1, 3]) >>> b tensor([[1, 2, 3]])
接着我们来分析维度大小。第0维:a维度大小为2,b维度大小为1;第1维:a维度大小为1,b的维度大小为3。也就是说,a、b在两个维度上大小都不同。所以,a需要在第1维复制元素至维度大小为3,b则需要在第0维复制元素至维度大小为2。
>>> a=a.repeat_interleave(3,dim=-1) >>> a tensor([[1, 1, 1], [2, 2, 2]]) >>> a.size() torch.Size([2, 3]) >>> b=b.repeat_interleave(2,dim=0) >>> b.size() torch.Size([2, 3]) >>> b tensor([[1, 2, 3], [1, 2, 3]])
最后,我们验证下结果是否和之前的一致:
>>> a1*b1 tensor([[1, 2, 3], [2, 4, 6]])
还需要注意 :当两个张量可以通过扩展维度使维度数相同时,若两个张量在相应的维度大小上相等,或者大小不同但其中较小的大小为1时,才可以执行计算。
比如,我们将b保持不变,将a换为维度大小为[2,2]的二维张量,mul就无法正常执行了。
>>> a=torch.arange(4).reshape(2,2) >>> a tensor([[0, 1], [2, 3]]) >>> b tensor([1, 2, 3]) >>> a.size() torch.Size([2, 2]) >>> b.size() torch.Size([3]) >>> torch.mul(a,b) Traceback (most recent call last): File "" , line 1, in RuntimeError: The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 1 >>> a*b Traceback (most recent call last): File "" , line 1, in RuntimeError: The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 1
总结 :不同于前面的那4个函数,torch.mul 实现的是逐元素相乘 。它可以通过广播操作 将输入的两个张量扩展成具有相同的维度大小和维度数,还可以将两个张量变为相同类型。
7. 5个张量乘法函数最后再总结下Pytorch中常用的5个张量乘法函数:
# 向量点积运算。要求输入为一维张量且类型相同、元素个数相同,输出为scaler标量。 torch.dot # 矩阵乘法运算,不支持广播操作。要求输入为二维张量且类型相同,维度大小满足矩阵乘法约束。 torch.mm# 批矩阵乘法运算,不支持广播操作。要求输入为三维张量且类型相同,第0维大小相等,后两维大小满足矩阵乘法约束。 torch.bmm# 混合矩阵乘法运算,包括向量点积、矩阵乘法、批矩阵乘法,且支持广播操作(仅针对维度)。要求输入张量类型相同,具体行为根据维度可以分五种情况。 torch.matmul# 逐元素乘法,等价于*,支持广播操作(包括维度及类型)。无特殊要求或约束。 torch.mul
参考资料 [1] torch.matmul: https://pytorch.org/docs/stable/generated/torch.matmul.html?highlight=torch%20matmul#torch.matmul
[2] torch.dot: https://pytorch.org/docs/stable/generated/torch.dot.html?highlight=torch%20dot#torch.dot
[3] torch.mm: https://pytorch.org/docs/stable/generated/torch.mm.html?highlight=torch%20mm#torch.mm
[4] torch.bmm: https://pytorch.org/docs/stable/generated/torch.bmm.html?highlight=torch%20bmm#torch.bmm
[5] torch.mul: https://pytorch.org/docs/stable/generated/torch.mul.html?highlight=torch%20mul#torch.mul