在上一篇文章中,我们已经对心电信号进行了预处理,将含有噪声的信号变得平滑,以便分类。本篇文章我们将正式开始利用深度学习对心电信号进行分类识别。
卷积神经网络不论是传统机器学习,还是深度学习,分类的依据都是不同类别的数据中包含的不同特征。要进行分类识别就需要对数据的特征进行提取,但是二者的提取方式并不相同。对于传统的机器学习而言,数据的特征需要设计者或专业人员针对其特性进行手动提取,而深度学习则可以自动提取每类数据中的不同特征。对于卷积神经网络CNN而言,能够自动提取特征的关键在于卷积操作。经过卷积操作提取的特征往往会有冗余,并且多次卷积会使神经网络的参数过多不便于训练,所以CNN往往会在卷积层的后面跟上一个池化层。经过多次的卷积和池化后,较低层次的特征就会逐步构成高层次的特征,最后神经网络根据提取出的高层次特征进行分类。
另外需要指出的是,为什么在心电信号分类中可以使用CNN呢。这是因为CNN具有的卷积操作具有局部连接和权值共享的特征。
局部连接:用于区别不同种类的图片所需的特征只是整张图片中的某些局部区域,因此在进行卷积操作时使用的卷积核(感受野)可以只是几个不同小区域,而不必使用整张图片大小的卷积核(全连接)。这样做不仅可以更好地表达不同的特征,还能起到减少参数的作用。例如下图,左边是使用全连接的神经网络,右边是使用局部连接卷积核的网络。
权值共享:对于一类图片而言,他们拥有相似的特征,但是每张图片中特征的位置可能会有偏移。比如不同的人脸照片中眼睛的位置可能会有变化,很少有两张照片眼睛的位置完全重合。对一张图片进行卷积操作时,可以有多个卷积核来提取不同的特征,但一个卷积核在进行移动的过程中其权值是保持不变的(当然不同卷积核的权值不共享)。这样既能保证特征提取不受位置的影响,还能减少参数的数量。
而心电信号虽然是一维的,但是其中的特征也满足局部连接和权值共享的条件,因此我们可以采用卷积神经网络对其分类。
构建深度学习的数据集巧妇难为无米之炊,虽然我们已经有了预处理过的心电数据,但是这样的数据是无法拿来直接进行分类学习的。所以我们要先构建符合深度学习模型使用的数据集。转换的过程是首先从一条心电信号中切分出符合要求的心拍作为样本,然后将python list转为numpy array,再经过乱序和切分,最终构成可供深度学习使用的数据集。这里我们使用tf.keras提供的接口,可以直接使用numpy数组类型,而不用再转成TensorFlow的DataSet对象,对于训练过程而言也更加简单。
心拍的切分需要找到QRS波尖峰所在的位置。由于我们只训练网络模型,我们这里直接使用MIT-BIH数据集提供的人工标注,并在尖峰处向前取99个信号点、向后取200个信号点,构成一个完整的心拍。如果需要对真实测量的信号进行识别分类,还要设计心拍的检测算法,后续我也可能会继续做。
数据集根据用途分为训练集、验证集和测试集。训练集用于训练参数模型,验证集用于模型训练中准确率和误差(损失函数)的检验,测试集用于训练完成后对训练效果的最终检验。可以类比学习、测验和考试。这三者的数据结构都一致,只是包含的数据内容不同,每个训练集都包含数据和标签两部分内容。数据是预处理后切分出的若干心拍的列表,标签是每个心拍样本对应的心电类型。
首先将上一篇的预处理步骤封装成一个函数:
# 小波去噪预处理 def denoise(data): # 小波变换 coeffs = pywt.wavedec(data=data, wavelet='db5', level=9) cA9, cD9, cD8, cD7, cD6, cD5, cD4, cD3, cD2, cD1 = coeffs # 阈值去噪 threshold = (np.median(np.abs(cD1)) / 0.6745) * (np.sqrt(2 * np.log(len(cD1)))) cD1.fill(0) cD2.fill(0) for i in range(1, len(coeffs) - 2): coeffs[i] = pywt.threshold(coeffs[i], threshold) # 小波反变换,获取去噪后的信号 rdata = pywt.waverec(coeffs=coeffs, wavelet='db5') return rdata然后将读取数据和标注、心拍切分封装成一个函数:
# 读取心电数据和对应标签,并对数据进行小波去噪 def getDataSet(number, X_data, Y_data): ecgClassSet = ['N', 'A', 'V', 'L', 'R'] # 读取心电数据记录 print("正在读取 " + number + " 号心电数据...") record = wfdb.rdrecord('ecg_data/' + number, channel_names=['MLII']) data = record.p_signal.flatten() # 小波去噪 rdata = denoise(data=data) # 获取心电数据记录中R波的位置和对应的标签 annotation = wfdb.rdann('ecg_data/' + number, 'atr') Rlocation = annotation.sample Rclass = annotation.symbol # 去掉前后的不稳定数据 start = 10 end = 5 i = start j = len(annotation.symbol) - end # 因为只选择NAVLR五种心电类型,所以要选出该条记录中所需要的那些带有特定标签的数据,舍弃其余标签的点 # X_data在R波前后截取长度为300的数据点 # Y_data将NAVLR按顺序转换为01234 while i < j: try: lable = ecgClassSet.index(Rclass[i]) x_train = rdata[Rlocation[i] - 99:Rlocation[i] + 201] X_data.append(x_train) Y_data.append(lable) i += 1 except ValueError: i += 1 return