深度学习实战:tensorflow训练循环神经网络让AI创作出模仿莎士比亚风格的作品 (2)

现在,数据集已经变成了我们想要的输入和输出。

Input data: \'First Citizen:\nBefore we proceed any further, hear me speak.\n\nAll:\nSpeak, speak.\n\nFirst Citizen:\nYou\' Target data: \'irst Citizen:\nBefore we proceed any further, hear me speak.\n\nAll:\nSpeak, speak.\n\nFirst Citizen:\nYou \'

对向量的每个索引进行一次性处理;对于第0步的输入,模型接收“F”的数值索引,并尝试预测“i”作为下一个字符。在下一个时序步骤中,它做同样的事情,但是RNN不仅考虑前面的步骤,而且还考虑它刚才预测的字符。

for i, (input_idx, target_idx) in enumerate(zip(input_example[:5], target_example[:5])): print("Step {:4d}".format(i)) print(" input: {} ({:s})".format(input_idx, repr(idx2char[input_idx]))) print(" expected output: {} ({:s})".format(target_idx, repr(idx2char[target_idx])))

[输出]:

Step 0
input: 18 (\'F\')
expected output: 47 (\'i\')
Step 1
input: 47 (\'i\')
expected output: 56 (\'r\')
Step 2
input: 56 (\'r\')
expected output: 57 (\'s\')
Step 3
input: 57 (\'s\')
expected output: 58 (\'t\')
Step 4
input: 58 (\'t\')
expected output: 1 (\' \')

Tensorflow的 tf.data 可以用来将文本分割成更易于管理的序列——但首先,需要将数据打乱并打包成批。

# Batch size BATCH_SIZE = 64 # Buffer size to shuffle the dataset BUFFER_SIZE = 10000 dataset = dataset.shuffle(BUFFER_SIZE).batch(BATCH_SIZE, drop_remainder=True) dataset

[输出]:

<BatchDataset shapes: ((64, 100), (64, 100)), types: (tf.int64, tf.int64)>

构建模型

最后,我们可以构建模型。让我们先设定一些重要的变量:

# Length of the vocabulary in chars vocab_size = len(vocab) # The embedding dimension embedding_dim = 256 # Number of RNN units rnn_units = 1024

模型将有一个嵌入层或输入层,该层将每个字符的数量映射到一个具有变量embedding_dim维数的向量。它将有一个GRU层(可以用LSTM层代替),大小为units = rnn_units。最后,输出层将是一个标准的全连接层,带有vocab_size输出。

下面的函数帮助我们快速而清晰地创建一个模型。

def build_model(vocab_size, embedding_dim, rnn_units, batch_size): model = tf.keras.Sequential([ tf.keras.layers.Embedding(vocab_size, embedding_dim, batch_input_shape=[batch_size, None]), tf.keras.layers.GRU(rnn_units, return_sequences=True, stateful=True, recurrent_initializer=\'glorot_uniform\'), tf.keras.layers.Dense(vocab_size) ]) return model

通过调用函数组合模型架构。

model = build_model( vocab_size = len(vocab), embedding_dim=embedding_dim, rnn_units=rnn_units, batch_size=BATCH_SIZE)

让我们总结一下我们的模型,看看有多少参数。

Model: "sequential" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= embedding (Embedding) (64, None, 256) 16640 _________________________________________________________________ gru (GRU) (64, None, 1024) 3938304 _________________________________________________________________ dense (Dense) (64, None, 65) 66625 ================================================================= Total params: 4,021,569 Trainable params: 4,021,569 Non-trainable params: 0 _________________________________________________________________

400万的参数!我们希望把它训练的久一点。

汇集

这个问题现在可以作为一个分类问题来处理。
给定先前的RNN状态和时间步长的输入,预测表示下一个字符的类。
因此,我们将附加一个稀疏分类熵损失函数和Adam优化器。

def loss(labels, logits): return tf.keras.losses.sparse_categorical_crossentropy(labels, logits, from_logits=True) example_batch_loss = loss(target_example_batch, example_batch_predictions) print("Prediction shape: ", example_batch_predictions.shape, " # (batch_size, sequence_length, vocab_size)") print("scalar_loss: ", example_batch_loss.numpy().mean()) model.compile(optimizer=\'adam\', loss=loss)

[输出]:

Prediction shape: (64, 100, 65) # (batch_size, sequence_length, vocab_size)
scalar_loss: 4.1746616

配置检查点

模型训练,尤其是像莎士比亚戏剧这样的大型数据集,需要很长时间。理想情况下,我们不会为了做出预测而反复训练它。tf.keras.callbacks.ModelCheckpoint函数可以在训练期间将某些检查点的权重保存到一个文件中,该文件可以在一个空白模型被后续检索。这在训练因任何原因中断时也很方便。

# Directory where the checkpoints will be saved checkpoint_dir = \'./training_checkpoints\' # Name of the checkpoint files checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}") checkpoint_callback=tf.keras.callbacks.ModelCheckpoint( filepath=checkpoint_prefix, save_weights_only=True) 最后,执行训练 EPOCHS=30 history = model.fit(dataset, epochs=EPOCHS, callbacks=[checkpoint_callback])

这应该需要大约6个小时的时间来获得不那么令人印象深刻但更快的结果,epochs可以调整到10(任何小于5的都会完全变成垃圾)。

生成文本

冲检查点中恢复权重参数

tf.train.latest_checkpoint(checkpoint_dir)

内容版权声明:除非注明,否则皆为本站原创文章。

转载注明出处:https://www.heiqu.com/zwwjjy.html