RNN总结 (2)
使用 RNN 实现正弦函数预测
import torch
import numpy as np
from torch import nn
import matplotlib.pyplot as plt
# Hyperparameters shared by the model, the training loop and the prediction loop.
num_time_steps = 50  # number of sample points in each sine window
input_size = 1       # one scalar feature per time step
hidden_size = 16     # width of the RNN hidden state
output_size = 1      # one scalar prediction per time step
lr = 0.001           # Adam learning rate


class Net(nn.Module):
    """Single-layer RNN followed by a per-time-step linear read-out.

    Maps an input sequence of shape (batch, seq_len, input_size) to a
    prediction sequence of shape (batch, seq_len, output_size).
    """

    def __init__(self, input_size=input_size, hidden_size=hidden_size,
                 output_size=output_size):
        """Build the RNN and the linear head.

        Args:
            input_size: number of features per time step.
            hidden_size: size of the RNN hidden state.
            output_size: number of outputs per time step.
        """
        super().__init__()
        self.model = nn.RNN(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=1,
            batch_first=True,  # tensors are laid out as (batch, seq, feature)
        )
        # Initialise every RNN weight and bias from N(0, 0.001^2).
        for p in self.model.parameters():
            nn.init.normal_(p, mean=0.0, std=0.001)
        self.linear = nn.Linear(hidden_size, output_size)

    def forward(self, x, hidden_prev=None):
        """Run the RNN over ``x`` and project each hidden state.

        Args:
            x: input of shape (batch, seq_len, input_size).
            hidden_prev: initial hidden state of shape (1, batch, hidden_size),
                or None to let nn.RNN start from zeros.

        Returns:
            (out, hidden_prev): ``out`` has shape (batch, seq_len, output_size);
            ``hidden_prev`` is the final hidden state, shape (1, batch, hidden_size).
        """
        out, hidden_prev = self.model(x, hidden_prev)
        # nn.Linear acts on the last dimension, so it applies to the whole
        # (batch, seq_len, hidden_size) tensor at once. Flattening with
        # view(-1, hidden_size) and re-adding a leading axis would merge the
        # batch axis into the time axis and is only correct for batch == 1.
        out = self.linear(out)
        return out, hidden_prev
model = Net()
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr)
num_iterations = 10000

# Initial hidden state h0, shape (num_layers=1, batch=1, hidden_size).
hidden_prev = torch.zeros(1, 1, hidden_size)
for step in range(num_iterations):
    # Sample a sine window over [start, start + 10] with an integer start in [0, 3).
    start = np.random.randint(3, size=1)[0]
    time_steps = np.linspace(start, start + 10, num_time_steps)
    data = np.sin(time_steps).reshape(num_time_steps, 1)
    # The target at each time step is the next sample of the same window.
    x = torch.tensor(data[:-1]).float().reshape(1, num_time_steps - 1, 1)
    y = torch.tensor(data[1:]).float().reshape(1, num_time_steps - 1, 1)

    # NOTE(review): the final hidden state of one random window becomes h0 of
    # the next, unrelated window. Resetting h0 to zeros every iteration may be
    # what was intended; confirm.
    output, hidden_prev = model(x, hidden_prev)
    # detach() cuts hidden_prev out of this iteration's autograd graph.
    # Without it, the next iteration's loss.backward() would try to
    # backpropagate into this iteration's graph. That graph's buffers have
    # already been freed by backward(), so PyTorch raises a RuntimeError.
    # The graph would also grow across iterations.
    hidden_prev = hidden_prev.detach()

    loss = criterion(output, y)
    # model.zero_grad() clears .grad on every parameter of the module.
    # optimizer.zero_grad() clears .grad on every parameter the optimizer was
    # given. The optimizer was built from model.parameters(), so both clear
    # exactly the same tensors here; optimizer.zero_grad() is the usual form.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if step % 100 == 0:
        print(f"Iteration: {step} loss {loss.item()}")
# Test window: integer start in [6, 10), outside the training start range [0, 3).
start = np.random.randint(6, 10, size=1)[0]
time_steps = np.linspace(start, start + 10, num_time_steps)
data = np.sin(time_steps).reshape(num_time_steps, 1)
x = torch.tensor(data[:-1]).float().reshape(1, num_time_steps - 1, 1)
# Ground-truth next-step targets; not used by the plot below.
y = torch.tensor(data[1:]).float().reshape(1, num_time_steps - 1, 1)

# Autoregressive (free-running) generation. Only the first true sample is
# given. Each later step feeds the model's own previous prediction back in as
# the next input, together with the hidden state the model returned.
# NOTE(review): generation starts from the hidden state left over from
# training, not from zeros; confirm this is intended.
predictions = []
step_input = x[:, 0, :]
with torch.no_grad():  # inference only: no autograd graph needed
    for _ in range(x.shape[1]):
        step_input = step_input.view(1, 1, 1)  # (batch=1, seq=1, feature=1)
        pred, hidden_prev = model(step_input, hidden_prev)
        step_input = pred
        predictions.append(pred.item())

x = x.numpy().ravel()
plt.scatter(time_steps[:-1], x, s=90)
plt.plot(time_steps[:-1], x)
plt.scatter(time_steps[1:], predictions)
plt.show()
内容版权声明：除非注明，否则皆为本站原创文章。