Pytorch-数学运算 (3)

trunc、frac 裁剪

1
2
3
4
5
6
7
8
9
  In[24]: a = torch.tensor(3.14)
In[25]: a.floor(),a.ceil(),a.trunc(),a.frac()
Out[25]: (tensor(3.), tensor(4.), tensor(3.), tensor(0.1400))
In[26]: a = torch.tensor(3.499)
In[27]: a.round()
Out[27]: tensor(3.)
In[28]: a = torch.tensor(3.5)
In[29]: a.round()
Out[29]: tensor(4.)
 
clamp

近似相关2 (用的更多一些)

gradient clipping 梯度裁剪

(min) 小于min的都变为某某值

(min, max) 不在这个区间的都变为某某值

梯度爆炸:一般来说,当梯度达到100左右的时候,就已经很大了,正常在10左右,通过打印梯度的模来查看 w.grad.norm(2)

对于w的限制叫做weight clipping,对于weight gradient clipping称为 gradient clipping。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
  In[30]: grad = torch.rand(2,3)*15
In[31]: grad.max()
Out[31]: tensor(10.6977)
In[32]: grad.clamp(10)
Out[32]:
tensor([[10.0000, 10.6977, 10.0000],
[10.0000, 10.0000, 10.0000]])
In[33]: grad
Out[33]:
tensor([[ 6.7738, 10.6977, 4.4314],
[ 7.8088, 4.8236, 3.6213]])
In[34]: grad.clamp(0,10)
Out[34]:
tensor([[ 6.7738, 10.0000, 4.4314],
[ 7.8088, 4.8236, 3.6213]])
 

内容版权声明:除非注明,否则皆为本站原创文章。

转载注明出处:https://www.heiqu.com/wpwsgf.html