TensorFlow学习笔记--CIFAR-10 图像识别 (2)

注2:
使用tf.train.string_input_producer() 创建完文件名队列后,文件名并没有被加入到队列中,如果此时开始计算,会导致整个系统处于阻塞状态。
在创建完文件名队列后,应调用 tf.train.start_queue_runners方法才会启动文件名队列的填充,整个程序才能正常运行起来。

代码

import tensorflow as tf # 新建session with tf.Session() as sess: # 要读取的三张图片 filename = [\'img/1.jpg\', \'img/2.jpg\', \'img/3.jpg\'] # 创建文件名队列 filename_queue = tf.train.string_input_producer(filename, num_epochs=5, shuffle=False) reader = tf.WholeFileReader() key, value = reader.read(filename_queue) # 初始化变量(epoch) tf.local_variables_initializer().run() threads = tf.train.start_queue_runners(sess=sess) i = 0 while True: i += 1 # 获取图片保存数据 image_data = sess.run(value) with open(\'read/test_%d.jpg\' % i, \'wb\') as f: f.write(image_data) 五、数据增强

对于图像数据来说,数据增强方法就是利用平移、缩放、颜色等变换增大训练集样本个数,从而达到更好的效果(注3),使用数据增强可以大大提高模型的泛化能力,并且能够预防过拟合。
常用的图像数据增强方法如下表

方法 说明
平移   将图像在一定尺度范围内平移  
旋转   将图像在一定角度范围内旋转  
翻转   水平翻转或者上下翻转图片  
裁剪   在原图上裁剪出一块  
缩放   将图像在一定尺度内放大或缩小  
颜色变换   对图像的RGB颜色空间进行一些变换  
噪声扰动   给图像加入一些人工生成的噪声  

注3:
使用数据增强的方法前提是,这些数据增强方法不会改变图像的原有标签。比如数字6的图片,经过上下翻转之后就变成了数字9的图片。

六、CIFAR-10识别模型

建立模型的代码在cifar10.py文件额inference函数中,代码在这里不进行详解,读者可以去阅读代码中的注释。
这里我们通过以下命令训练模型:

python cifar10_train.py --train_dir cifar10_train/ --data_dir cifar10_data/

这段命令中 –data_dir cifar10_data/ 表示数据保存的位置, –train_dir cifar10_train/ 表示保存模型参数和训练时日志信息的位置

七、查看训练进度

在训练的时候我们往往需要知道损失的变化和每层的训练情况,这个时候我们就会用到tensorflow提供的 TensorBoard。打开一个新的命令行,输入如下命令:

tensorboard --logdir cifar10_train/

其中 –logdir cifar10_train/ 表示模型训练日志保存的位置,运行该命令后将会在命令行看到类似如下的内容

命令行反馈


在浏览器上输入显示的地址,即可访问TensorBoard。简单解释一下常用的几个标签:

标签 说明
total_loss_1   loss 的变化曲线,变化曲线会根据时间实时变化  
learning_rate   学习率变化曲线  
global_step   美妙训练步数的情况,如果训练速度变化较大,或者越来越慢,就说明程序有可能存在错误  
八、检测模型的准确性

在命令行窗口输入如下命令:

python cifar10_eval.py --data_dir cifar10_data/ --eval_dir cifar10_eval/ --checkpoint_dir cifar10_train/

内容版权声明:除非注明,否则皆为本站原创文章。

转载注明出处:https://www.heiqu.com/zgzswd.html