注2:
使用tf.train.string_input_producer() 创建完文件名队列后,文件名并没有被加入到队列中,如果此时开始计算,会导致整个系统处于阻塞状态。
在创建完文件名队列后,应调用 tf.train.start_queue_runners方法才会启动文件名队列的填充,整个程序才能正常运行起来。
代码
import tensorflow as tf # 新建session with tf.Session() as sess: # 要读取的三张图片 filename = [\'img/1.jpg\', \'img/2.jpg\', \'img/3.jpg\'] # 创建文件名队列 filename_queue = tf.train.string_input_producer(filename, num_epochs=5, shuffle=False) reader = tf.WholeFileReader() key, value = reader.read(filename_queue) # 初始化变量(epoch) tf.local_variables_initializer().run() threads = tf.train.start_queue_runners(sess=sess) i = 0 while True: i += 1 # 获取图片保存数据 image_data = sess.run(value) with open(\'read/test_%d.jpg\' % i, \'wb\') as f: f.write(image_data) 五、数据增强对于图像数据来说,数据增强方法就是利用平移、缩放、颜色等变换增大训练集样本个数,从而达到更好的效果(注3),使用数据增强可以大大提高模型的泛化能力,并且能够预防过拟合。
常用的图像数据增强方法如下表
平移 将图像在一定尺度范围内平移
旋转 将图像在一定角度范围内旋转
翻转 水平翻转或者上下翻转图片
裁剪 在原图上裁剪出一块
缩放 将图像在一定尺度内放大或缩小
颜色变换 对图像的RGB颜色空间进行一些变换
噪声扰动 给图像加入一些人工生成的噪声
注3:
使用数据增强的方法前提是,这些数据增强方法不会改变图像的原有标签。比如数字6的图片,经过上下翻转之后就变成了数字9的图片。
建立模型的代码在cifar10.py文件额inference函数中,代码在这里不进行详解,读者可以去阅读代码中的注释。
这里我们通过以下命令训练模型:
这段命令中 –data_dir cifar10_data/ 表示数据保存的位置, –train_dir cifar10_train/ 表示保存模型参数和训练时日志信息的位置
七、查看训练进度在训练的时候我们往往需要知道损失的变化和每层的训练情况,这个时候我们就会用到tensorflow提供的 TensorBoard。打开一个新的命令行,输入如下命令:
tensorboard --logdir cifar10_train/其中 –logdir cifar10_train/ 表示模型训练日志保存的位置,运行该命令后将会在命令行看到类似如下的内容
在浏览器上输入显示的地址,即可访问TensorBoard。简单解释一下常用的几个标签: 标签 说明
total_loss_1 loss 的变化曲线,变化曲线会根据时间实时变化
learning_rate 学习率变化曲线
global_step 美妙训练步数的情况,如果训练速度变化较大,或者越来越慢,就说明程序有可能存在错误
八、检测模型的准确性
在命令行窗口输入如下命令:
python cifar10_eval.py --data_dir cifar10_data/ --eval_dir cifar10_eval/ --checkpoint_dir cifar10_train/