这可能是国内最全面的char RNN注释

本人在学习char RNN的过程中,遇到了很多的问题,但是依然选择一行代码一行代码的啃下来,并且注释好,我在啃代码的过程中,就想要是有一位大神在我旁边就好了,我在看代码的过程中,不懂那里,就问那里,可是现实中并没有,所有问题都要自己解决,今日我终于把代码全部弄懂了,也把代码分享给下一位想要学习char RNN的人。开源才能进步,中国加油。觉有有用希望大家可以点个赞,关注我,这将给我莫大的动力。如果我文中有错误的地方,欢迎指出,我也需要学习和进步。多一点包容,多一点努力。

详细代码注释

train.py

# -*- coding:utf-8 -*- import tensorflow as tf from read_utils import TextConverter, batch_generator from model import CharRNN import os import codecs FLAGS = tf.flags.FLAGS tf.flags.DEFINE_string('name', 'default', '模型名') tf.flags.DEFINE_integer('num_seqs', 32, '一个batch里面的序列数量') # 32 tf.flags.DEFINE_integer('num_steps', 26, '序列的长度') # 26 tf.flags.DEFINE_integer('lstm_size', 128, 'LSTM隐层的大小') tf.flags.DEFINE_integer('num_layers', 2, 'LSTM的层数') tf.flags.DEFINE_boolean('use_embedding', False, '是否使用 embedding') tf.flags.DEFINE_integer('embedding_size', 128, 'embedding的大小') tf.flags.DEFINE_float('learning_rate', 0.001, '学习率') tf.flags.DEFINE_float('train_keep_prob', 0.5, '训练期间的dropout比率') tf.flags.DEFINE_string('input_file', '', 'utf8编码过的text文件') tf.flags.DEFINE_integer('max_steps', 10000, '一个step 是运行一个batch, max_steps固定了最大的运行步数') tf.flags.DEFINE_integer('save_every_n', 1000, '每隔1000步会将模型保存下来') tf.flags.DEFINE_integer('log_every_n', 10, '每隔10步会在屏幕上打出曰志') # 使用的字母(汉字)的最大个数。默认为3500 。程序会自动挑选出使用最多的字,井将剩下的字归为一类,并标记为<unk> tf.flags.DEFINE_integer('max_vocab', 10000, '最大字符数量') # python train.py --use_embedding --input_file data/poetry.txt --name poetry --learning_rate 0.005 --num_steps 26 --num_seqs 32 --max_steps 10000 # python train.py \ # --use_embedding \ # --input_file data/poetry.txt \ # --name poetry \ # --learning_rate 0.005 \ # --num_steps 26 \ # --num_seqs 32 \ # --max_steps 10000 def main(_): model_path = os.path.join('model', FLAGS.name) if os.path.exists(model_path) is False: os.makedirs(model_path) with codecs.open(FLAGS.input_file, encoding='utf-8') as f: # 打开训练数据集poetry.txt text = f.read() converter = TextConverter(text, FLAGS.max_vocab) # 最大字符数量10000 converter.save_to_file(os.path.join(model_path, 'converter.pkl')) arr = converter.text_to_arr(text) g = batch_generator(arr, FLAGS.num_seqs, FLAGS.num_steps) # 句子数量、句子长度 print(converter.vocab_size) # 3501 model = CharRNN(converter.vocab_size, num_seqs=FLAGS.num_seqs, num_steps=FLAGS.num_steps, lstm_size=FLAGS.lstm_size, num_layers=FLAGS.num_layers, learning_rate=FLAGS.learning_rate, train_keep_prob=FLAGS.train_keep_prob, use_embedding=FLAGS.use_embedding, embedding_size=FLAGS.embedding_size) model.train(g, FLAGS.max_steps, model_path, FLAGS.save_every_n, FLAGS.log_every_n) if __name__ == '__main__': tf.app.run()

内容版权声明:除非注明,否则皆为本站原创文章。

转载注明出处:https://www.heiqu.com/zzdsss.html