然后使用model.compile()构建;model.fit()训练30轮,批大小为128,划分验证集的比例为0.3,设置callback进行训练记录的保存;model.save()保存模型;model.predict_classes()预测。完整代码可以取本人的GitHub仓库查看,地址在文章(一)中。
def main(): # X,Y为所有的数据集和标签集 # X_test,Y_test为拆分的测试集和标签集 X, Y, X_test, Y_test = loadData() if os.path.exists(model_path): # 导入训练好的模型 model = tf.keras.models.load_model(filepath=model_path) else: # 构建CNN模型 model = buildModel() model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy']) model.summary() # 定义TensorBoard对象 tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1) # 训练与验证 model.fit(X, Y, epochs=30, batch_size=128, validation_split=RATIO, callbacks=[tensorboard_callback]) model.save(filepath=model_path) # 预测 Y_pred = model.predict_classes(X_test)对心电信号的深度学习识别分类至此结束,识别率可达99%左右。