Pytorch-数学运算

引言

本篇介绍tensor的数学运算。 基本运算

add/minus/multiply/divide

matmul

pow

sqrt/rsqrt

round

基础运算

可以使用 + - * / 推荐

也可以使用 torch.add, mul, sub, div

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
  In[3]: a = torch.rand(3,4)
In[4]: b = torch.rand(4) # 使用broadcast
In[5]: a+b
Out[5]:
tensor([[0.9463, 1.3325, 1.0427, 1.3508],
[1.8552, 0.5614, 0.8546, 1.2186],
[1.4794, 1.3745, 0.7024, 1.1688]])
In[6]: torch.add(a,b)
Out[6]:
tensor([[0.9463, 1.3325, 1.0427, 1.3508],
[1.8552, 0.5614, 0.8546, 1.2186],
[1.4794, 1.3745, 0.7024, 1.1688]])
In[8]: torch.all(torch.eq((a-b),torch.sub(a,b)))
Out[8]: tensor(1, dtype=torch.uint8)
In[9]: torch.all(torch.eq((a*b),torch.mul(a,b)))
Out[9]: tensor(1, dtype=torch.uint8)
In[10]: torch.all(torch.eq((a/b),torch.div(a,b)))
Out[10]: tensor(1, dtype=torch.uint8)
 

torch.all() 判断每个位置的元素是否相同

是否存在为0的元素

1
2
3
4
  In[21]: torch.all(torch.ByteTensor([1,1,1,1]))
Out[21]: tensor(1, dtype=torch.uint8)
In[22]: torch.all(torch.ByteTensor([1,1,1,0]))
Out[22]: tensor(0, dtype=torch.uint8)
 

matmul

matmul 表示 matrix mul

* 表示的是element-wise

torch.mm(a,b) 只能计算2D 不推荐

torch.matmul(a,b) 可以计算更高维度,落脚点依旧在行与列。 推荐

@ 是matmul 的重载形式

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
  In[24]: a = 3*torch.ones(2,2)
In[25]: a
Out[25]:
tensor([[3., 3.],
[3., 3.]])
In[26]: b = torch.ones(2,2)
In[27]: torch.mm(a,b)
Out[27]:
tensor([[6., 6.],
[6., 6.]])
In[28]: torch.matmul(a,b)
Out[28]:
tensor([[6., 6.],
[6., 6.]])
In[29]: [email protected]
Out[29]:
tensor([[6., 6.],
[6., 6.]])
 
例子

线性层的计算 : x @ w.t() + b

x是4张照片且已经打平了 (4, 784)

我们希望 (4, 784) —> (4, 512)

这样的话w因该是 (784, 512)

但由于pytorch默认 第一个维度是 channel-out(目标), 第二个维度是 channel-in (输入) , 所以需要用一个转置

note:.t() 只适合2D,高维用transpose

1
2
3
4
  In[31]: x = torch.rand(4,784)
In[32]: w = torch.rand(512,784)
In[33]: ([email protected]()).shape
Out[33]: torch.Size([4, 512])
 

神经网络 -> 矩阵运算 -> tensor flow

2维以上的tensor matmul

对于2维以上的matrix multiply , torch.mm(a,b)就不行了。

运算规则:只取最后的两维做矩阵乘法

对于 [b, c, h, w] 来说,b,c 是不变的,图片的大小在改变;并且也并行的计算出了b,c。也就是支持多个矩阵并行相乘

内容版权声明:除非注明,否则皆为本站原创文章。

转载注明出处:https://www.heiqu.com/wpwsgf.html