本部分的CNN模型采用https://blog.csdn.net/qq_43750573/article/details/105930152所用的模型,代码也是在该博主的代码基础上修改得到,CNN的模型介绍可参考论文(《面向心血管疾病识别的心电信号分类研究》)
代码及流程
1.下载数据
通过读取二进制文件,得到训练集和测试集的数据和标签
# 获取数据
def load_data(filename, sample_num, sample_len):
fp = open(filename, "rb")
if sample_len == data_len:
context = fp.read(2 * sample_num * sample_len)
fmt_unpack = '%d' % (sample_num * sample_len) + 'h'
dat_arr = np.array(struct.unpack(fmt_unpack, context), dtype=float)
dat_arr = dat_arr.reshape((sample_num, sample_len))
for i in range(0, sample_num):
dat_arr[i] = normalization(dat_arr[i])
dat_arr = dat_arr.reshape(-1, data_len, 1)
else:
context = fp.read(sample_num * sample_len)
fmt_unpack = '%d' % (sample_num * sample_len) + 'B'
dat_arr = np.array(struct.unpack(fmt_unpack, context))
fp.close()
return dat_arr
# 获取训练和测试的数据
def get_data():
data_path = os.path.join(os.getcwd(), "data")
# 训练数据,35个个体,168000个心拍,N与V比例1:1,前半部分为N,后半部分为V
data_train = load_data(os.path.join(data_path, data_train_name), train_num, data_len)
label_train