模型的保存与恢复
介绍一些常见的模型保存与恢复方法,以及如何使用回调函数保存模型。
1.保存完整模型model.save()方法可以保存完整的模型,包括模型的架构、模型的权重以及优化器。
model.save()的参数为保存路径以及文件名。
首先我们构建一个简单的Sequential模型,使用fishion_mnist数据集进行训练,得到一个训练后的模型。
1 import tensorflow as tf 2 import numpy as np 3 4 (train_image, train_label), (test_image, test_label) = tf.keras.datasets.fashion_mnist.load_data() 5 6 train_image = np.expand_dims(train_image, -1) 7 test_image = np.expand_dims(test_image, -1) 8 9 model = tf.keras.Sequential() 10 model.add(tf.keras.layers.Conv2D(32, [3, 3], input_shape = (28, 28, 1), activation = 'relu')) 11 model.add(tf.keras.layers.Conv2D(64, [3, 3], activation = 'relu')) 12 model.add(tf.keras.layers.GlobalAveragePooling2D()) 13 model.add(tf.keras.layers.Dense(64, activation = 'relu')) 14 model.add(tf.keras.layers.Dense(10, activation = 'softmax')) 15 16 model.compile(optimizer = 'adam', 17 loss = 'sparse_categorical_crossentropy', 18 metrics = ['acc']) 19 20 history = model.fit(train_image, train_label, 21 epochs = 10, 22 validation_data = (test_image, test_label))