RNN总结 (2)

RNN总结

"""Sine-wave prediction with a single-layer RNN.

Trains an ``nn.RNN`` + linear read-out to predict the next sample of sin(t)
from the current one. It then rolls the trained model forward autoregressively
(feeding each prediction back in as the next input) and plots the result
against the true curve.
"""
import numpy as np
import torch
from torch import nn

num_time_steps = 50
input_size = 1
hidden_size = 16
output_size = 1
lr = 0.001


class Net(nn.Module):
    """Single-layer RNN followed by a linear read-out applied at every time step."""

    def __init__(self):
        super().__init__()
        self.model = nn.RNN(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=1,
            batch_first=True,
        )
        # Initialise every RNN weight and bias from N(0, 0.001^2).
        for p in self.model.parameters():
            nn.init.normal_(p, mean=0.0, std=0.001)
        self.linear = nn.Linear(hidden_size, output_size)

    def forward(self, x, hidden_prev):
        """Run the RNN over ``x`` and project each hidden state to an output.

        Args:
            x: input of shape (batch, seq_len, input_size) -- batch_first=True.
            hidden_prev: initial hidden state, shape (1, batch, hidden_size).

        Returns:
            (out, hidden): ``out`` has shape (batch, seq_len, output_size);
            ``hidden`` is the final hidden state, shape (1, batch, hidden_size).
        """
        out, hidden_prev = self.model(x, hidden_prev)
        # nn.Linear acts on the last dimension, so (batch, seq, hidden) maps
        # straight to (batch, seq, output). Flattening to (-1, hidden) and
        # unsqueezing dim 0 would merge the batch axis into the time axis
        # whenever batch > 1.
        out = self.linear(out)
        return out, hidden_prev


def make_sine_window(start, num_steps=num_time_steps, span=10):
    """Sample sin(t) at ``num_steps`` evenly spaced points on [start, start + span].

    Returns:
        (time_steps, x, y): ``time_steps`` is the numpy array of sample times.
        ``x`` holds the first ``num_steps - 1`` samples and ``y`` the last
        ``num_steps - 1`` samples, i.e. ``y`` is ``x`` shifted one step ahead.
        Both are float32 tensors of shape (1, num_steps - 1, 1).
    """
    time_steps = np.linspace(start, start + span, num_steps)
    data = np.sin(time_steps).reshape(num_steps, 1)
    x = torch.tensor(data[:-1], dtype=torch.float32).reshape(1, num_steps - 1, 1)
    y = torch.tensor(data[1:], dtype=torch.float32).reshape(1, num_steps - 1, 1)
    return time_steps, x, y


def train(model, num_iters=10000, log_every=100):
    """Train ``model`` on random sine windows starting at t in {0, 1, 2}.

    The hidden state is carried over from one window to the next.

    Returns:
        The final (detached) hidden state, shape (1, 1, hidden_size).
    """
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr)
    # Initial hidden state h0: (num_layers=1, batch=1, hidden_size).
    hidden_prev = torch.zeros(1, 1, hidden_size)
    for step in range(num_iters):
        start = np.random.randint(3)
        _, x, y = make_sine_window(start)
        output, hidden_prev = model(x, hidden_prev)
        # Truncated BPTT: the hidden state is reused as next iteration's h0.
        # Detaching cuts it from this iteration's graph. Without this,
        # backward() would try to flow into graphs that were already freed.
        hidden_prev = hidden_prev.detach()
        loss = criterion(output, y)
        # The optimizer was built from model.parameters(), so this clears
        # exactly the same gradients model.zero_grad() would.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if step % log_every == 0:
            print("Iteration: {} loss {}".format(step, loss.item()))
    return hidden_prev


def predict(model, x, hidden_prev):
    """Autoregressively generate ``x.shape[1]`` predictions.

    Only the first sample ``x[:, 0, :]`` is taken from ``x``. Every subsequent
    input is the model's own previous prediction.

    Returns:
        (predictions, hidden): a list of Python floats and the final hidden state.
    """
    predictions = []
    step_input = x[:, 0, :]
    # Inference only: no graph needs to be recorded for the rollout.
    with torch.no_grad():
        for _ in range(x.shape[1]):
            # x[:, 0, :] has shape (1, 1); the RNN needs (batch=1, seq_len=1,
            # features=1). pred already has that shape on later steps.
            step_input = step_input.view(1, 1, 1)
            pred, hidden_prev = model(step_input, hidden_prev)
            # Feed the prediction back in as the next input.
            step_input = pred
            predictions.append(pred.item())
    return predictions, hidden_prev


def plot_prediction(time_steps, x, predictions):
    """Plot the true samples (large dots + line) and predictions (small dots)."""
    # Local import so the module (model, data, training) can be imported and
    # tested without matplotlib installed.
    import matplotlib.pyplot as plt

    truth = x.numpy().ravel()
    plt.scatter(time_steps[:-1], truth, s=90)
    plt.plot(time_steps[:-1], truth)
    plt.scatter(time_steps[1:], predictions)
    plt.show()


def main():
    """Train the RNN, then predict a window starting at t in {6, ..., 9} and plot it."""
    model = Net()
    hidden_prev = train(model)
    start = np.random.randint(6, 10)
    time_steps, x, _ = make_sine_window(start)
    predictions, _ = predict(model, x, hidden_prev)
    plot_prediction(time_steps, x, predictions)


if __name__ == "__main__":
    main()

内容版权声明:除非注明,否则皆为本站原创文章。

转载注明出处:https://www.heiqu.com/wpzjwx.html