使用Python+TensorFlow2构建基于卷积神经网络(CNN)的ECG心电信号识别分类(四) (3)

然后使用model.compile()构建;model.fit()训练30轮,批大小为128,划分验证集的比例为0.3,设置callback进行训练记录的保存;model.save()保存模型;model.predict_classes()预测。完整代码可以取本人的GitHub仓库查看,地址在文章(一)中。

def main(): # X,Y为所有的数据集和标签集 # X_test,Y_test为拆分的测试集和标签集 X, Y, X_test, Y_test = loadData() if os.path.exists(model_path): # 导入训练好的模型 model = tf.keras.models.load_model(filepath=model_path) else: # 构建CNN模型 model = buildModel() model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy']) model.summary() # 定义TensorBoard对象 tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1) # 训练与验证 model.fit(X, Y, epochs=30, batch_size=128, validation_split=RATIO, callbacks=[tensorboard_callback]) model.save(filepath=model_path) # 预测 Y_pred = model.predict_classes(X_test)

对心电信号的深度学习识别分类至此结束,识别率可达99%左右。

内容版权声明:除非注明,否则皆为本站原创文章。

转载注明出处:https://www.heiqu.com/zydgyz.html