python Deep learning 学习笔记(6)

本节介绍循环神经网络及其优化
循环神经网络(RNN,recurrent neural network)处理序列的方式是,遍历所有序列元素,并保存一个状态(state),其中包含与已查看内容相关的信息。在处理两个不同的独立序列(比如两条不同的 IMDB 评论)之间,RNN 状态会被重置,因此,你仍可以将一个序列看作单个数据点,即网络的单个输入。真正改变的是,数据点不再是在单个步骤中进行处理,相反,网络内部会对序列元素进行遍历,RNN 的特征在于其时间步函数

python Deep learning 学习笔记(6)

Keras 中的循环层

from keras.layers import SimpleRNN

它接收形状为 (batch_size, timesteps, input_features) 的输入
与 Keras 中的所有循环层一样,SimpleRNN 可以在两种不同的模式下运行:一种是返回每个时间步连续输出的完整序列,即形状为 (batch_size, timesteps, output_features)的三维张量;另一种是只返回每个输入序列的最终输出,即形状为 (batch_size, output_features) 的二维张量。这两种模式由return_sequences 这个构造函数参数来控制。为了提高网络的表示能力,将多个循环层逐个堆叠有时也是很有用的。在这种情况下,你需要让所有中间层都返回完整的输出序列,即将return_sequences设置为True

简单Demo with SimpleRNN

from keras.datasets import imdb from keras.preprocessing import sequence from keras.layers import Dense, Embedding, SimpleRNN from keras.models import Sequential import matplotlib.pyplot as plt max_features = 10000 maxlen = 500 batch_size = 32 (input_train, y_train), (input_test, y_test) = imdb.load_data(num_words=max_features, path='E:\\study\\dataset\\imdb.npz') print(len(input_train), 'train sequences') print(len(input_test), 'test sequences') print('Pad sequences (samples x time)') input_train = sequence.pad_sequences(input_train, maxlen=maxlen) input_test = sequence.pad_sequences(input_test, maxlen=maxlen) print('input_train shape:', input_train.shape) print('input_test shape:', input_test.shape) # 用 Embedding 层和 SimpleRNN 层来训练模型 model = Sequential() model.add(Embedding(max_features, 32)) model.add(SimpleRNN(32)) model.add(Dense(1, activation='sigmoid')) model.compile(optimizer='rmsprop', loss='binary_crossentropy', metrics=['acc']) history = model.fit(input_train, y_train, epochs=10, batch_size=128, validation_split=0.2) acc = history.history['acc'] val_acc = history.history['val_acc'] loss = history.history['loss'] val_loss = history.history['val_loss'] epochs = range(1, len(acc) + 1) plt.plot(epochs, acc, 'bo', label='Training acc') plt.plot(epochs, val_acc, 'b', label='Validation acc') plt.title('Training and validation accuracy') plt.legend() plt.figure() plt.plot(epochs, loss, 'bo', label='Training loss') plt.plot(epochs, val_loss, 'b', label='Validation loss') plt.title('Training and validation loss') plt.legend() plt.show()

结果

python Deep learning 学习笔记(6)

python Deep learning 学习笔记(6)

Keras同时还内置了另外两个循环层:LSTM GRU
SimpleRNN 的最大问题不能学到长期依赖,其原因在于梯度消失问题。LSTM 层和 GRU 层都是为了解决这个问题而设计的
LSTM(long short-term memory)层是 SimpleRNN 层的一种变体,它增加了一种携带信息跨越多个时间步的方法,保存信息以便后面使用,从而防止较早期的信号在处理过程中逐渐消失

简单Demo with LSTM

from keras.datasets import imdb from keras.preprocessing import sequence from keras.layers import Dense, Embedding, LSTM from keras.models import Sequential import matplotlib.pyplot as plt max_features = 10000 maxlen = 500 batch_size = 32 (input_train, y_train), (input_test, y_test) = imdb.load_data(num_words=max_features, path='E:\\study\\dataset\\imdb.npz') print(len(input_train), 'train sequences') print(len(input_test), 'test sequences') print('Pad sequences (samples x time)') input_train = sequence.pad_sequences(input_train, maxlen=maxlen) input_test = sequence.pad_sequences(input_test, maxlen=maxlen) print('input_train shape:', input_train.shape) print('input_test shape:', input_test.shape) model = Sequential() model.add(Embedding(max_features, 32)) model.add(LSTM(32)) model.add(Dense(1, activation='sigmoid')) model.compile(optimizer='rmsprop', loss='binary_crossentropy', metrics=['acc']) history = model.fit(input_train, y_train, epochs=10, batch_size=128, validation_split=0.2) acc = history.history['acc'] val_acc = history.history['val_acc'] loss = history.history['loss'] val_loss = history.history['val_loss'] epochs = range(1, len(acc) + 1) plt.plot(epochs, acc, 'bo', label='Training acc') plt.plot(epochs, val_acc, 'b', label='Validation acc') plt.title('Training and validation accuracy') plt.legend() plt.figure() plt.plot(epochs, loss, 'bo', label='Training loss') plt.plot(epochs, val_loss, 'b', label='Validation loss') plt.title('Training and validation loss') plt.legend() plt.show()

内容版权声明:除非注明,否则皆为本站原创文章。

转载注明出处:https://www.heiqu.com/zywjgz.html