Detecting Heart Abnormalities Using 1D CNN On Data You Cannot See

The document describes how to train a split 1D CNN model on heartbeat data to detect heart abnormalities while preserving data privacy using PySyft. Key points: 1. The client loads the ECG dataset and extracts features using part of the 1D CNN model. 2. The client sends the extracted features to the server without sharing the raw data. 3. The server completes the training using the received features and its part of the 1D CNN model without accessing the raw data. This split learning approach allows detecting heart abnormalities from the ECG data accurately while preserving the privacy of the sensitive training data.

Preserve Sensitive Training Data Privacy with Split Neural Network and PySyft

Khoa Nguyen Oct 3 · 8 min read

TL;DR:
Can We Apply the Split Learning Architecture to Train a 1D CNN Model on Heartbeat
Data and Accurately Detect Heart Abnormalities while Preserving Data Privacy?

Well, that’s still too long, and words are cheap, just show me the code!
Here you go. Enjoy!

Introduction
Machine Learning (ML) is a subfield of Artificial Intelligence where algorithms are
trained to find patterns in massive datasets. These patterns are then used to make
decisions and predictions on new data. One of the problems that ML faces today is data
sharing: data scientists need to gather a large amount of data from data owners in order
to train their algorithms. This is often not ideal, especially for sensitive data in sectors
such as healthcare or finance. Split learning is one of the methods in Privacy Preserving
Machine Learning (PPML) that tries to address this data privacy problem.

Split learning refers to the process of cutting a Deep Neural Network (DNN) into two or
more parts. In the simplest scenario, where there is only one data owner (the client) and
one data scientist (the server), the DNN is split into two parts. The first part of the DNN
is deployed on the client’s machine, where the data reside, and the second part is
deployed on the server’s side. During the forward pass, the client’s model extracts a set of
features (also called “activation maps”) from the dataset, then sends those activation maps
to the server, which continues the forward pass. Then, during the backward pass, the server
calculates the loss function and the gradients of the loss up to the split layer, then sends
those gradients back to the client so that the client can continue the backward pass. This
way, the server/data scientist never gets to see the input training data, but can still train
the network. You can learn more about the fundamentals of split learning from this tutorial.
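To make the gradient hand-off concrete, here is a toy sketch in plain Python, not the network used in this post. It assumes a "network" with a single scalar weight on each side and a squared-error loss; all names (`client_forward`, `server_forward_backward`, `client_backward`) are made up for illustration. As in the setup described above, the server sees the label but never the input `x`.

```python
# Toy split "network": the client owns w_client, the server owns w_server.
# prediction = w_server * (w_client * x), loss = (prediction - y) ** 2

def client_forward(w_client, x):
    """Client side: turn the private input into an activation."""
    return w_client * x

def server_forward_backward(w_server, activation, y):
    """Server side: finish the forward pass, then backpropagate to the split."""
    prediction = w_server * activation
    loss = (prediction - y) ** 2
    d_loss_d_pred = 2.0 * (prediction - y)
    grad_w_server = d_loss_d_pred * activation   # stays on the server
    grad_activation = d_loss_d_pred * w_server   # sent back to the client
    return loss, grad_w_server, grad_activation

def client_backward(grad_activation, x):
    """Client side: continue the chain rule with the received gradient."""
    return grad_activation * x

x, y = 2.0, 1.0                  # private sample and its label
w_client, w_server = 0.5, 3.0

activation = client_forward(w_client, x)   # client -> server
loss, grad_w_server, grad_activation = server_forward_backward(w_server, activation, y)
grad_w_client = client_backward(grad_activation, x)   # after server -> client

print(loss, grad_w_server, grad_w_client)  # 4.0 4.0 24.0
```

The client's gradient matches what ordinary backpropagation through the unsplit function would give, 2 * (w_server * w_client * x - y) * w_server * x = 24, even though the server only ever received the activation.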

In this blog post, we will walk through the process of training a split neural network
using OpenMined’s framework PySyft: a Python library for computing on data you do
not own and cannot see. OpenMined’s free course “Foundations of Private
Computation” already includes a tutorial on how to train a split DNN using PySyft’s Duet
with two Jupyter notebooks: one to represent the client, and the other to represent the
server. However, if you are developing a new split learning method, switching back and
forth between two notebooks is cumbersome. Luckily, another feature of PySyft called
VirtualMachine allows us to develop a split DNN in a single Jupyter notebook or Python
file. We will learn how to use it today, along with PySyft’s other features such as
RemoteDataset and RemoteDataLoader to load a custom remote dataset. Most importantly,
we will see how to train a split 1D CNN to detect heart abnormalities on input data that
never leave the client’s machine, based on the work from [1].

Let’s jump into it


First, we need to import the necessary packages and define the paths to the necessary
files. I used torch 1.8.1+cu102 and syft 0.5.0.

    from pathlib import Path

    import h5py
    import matplotlib.pyplot as plt
    import numpy as np
    from icecream import ic  # easy printing for debugging
    from tqdm import tqdm

    plt.style.use('dark_background')

    import syft as sy
    import torch
    import torch.nn as nn
    from torch.optim import SGD, Adam
    from torch.utils.data import DataLoader, Dataset

    print(f'torch version: {torch.__version__}')
    print(f'syft version: {sy.__version__}')

Code for importing packages (Image by Author)

    project_path = Path.cwd()
    print(f'project_path: {project_path}')
    # paths to files and directories
    train_name = 'train_ecg.hdf5'
    test_name = 'test_ecg.hdf5'

Setting the training and test file names for data imports (Image by Author)

Defining the client and the server


Using PySyft’s VirtualMachine, we can define abstract actors in this scenario like in the
code below.

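    # an in-process virtual machine that stands in for a remote node
    server = sy.VirtualMachine(name="server")
    # a handle with root access to that virtual machine
    client = server.get_root_client()
    # remote counterpart of the torch module, used to build objects on the remote side
    remote_torch = client.torch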

Defining the server and client virtual machines (Image by Author)

Client: loading and exploring the dataset


First, let’s take the role of the client (data owner) and explore the dataset. We will use
MIT-BIH arrhythmia, a popular dataset for ECG signal classification and arrhythmia
diagnosis [2]. You can find the original dataset here; however, we use the processed
data from here. Below is the code needed to load the dataset from train_ecg.hdf5 and
test_ecg.hdf5 .
    class ECG(Dataset):
        # The class used to load the ECG dataset
        def __init__(self, mode='train'):
            if mode == 'train':
                with h5py.File(project_path/train_name, 'r') as hdf:
                    self.x = torch.tensor(hdf['x_train'][:], dtype=torch.float)
                    self.y = torch.tensor(hdf['y_train'][:])
            elif mode == 'test':
                with h5py.File(project_path/test_name, 'r') as hdf:
                    self.x = torch.tensor(hdf['x_test'][:], dtype=torch.float)
                    self.y = torch.tensor(hdf['y_test'][:])
            else:
                raise ValueError('Argument of mode should be train or test')

        def __len__(self):
            return len(self.x)

        def __getitem__(self, idx):
            return self.x[idx], self.y[idx]

The class for loading the ECG dataset (Image by Author)

The processed dataset consists of 26,490 heartbeat samples in total, each of which is a
time-series vector of length 128. There are 5 different types of heartbeats as
classification targets: normal beat (class 0), left bundle branch block (class 1), right
bundle branch block (class 2), atrial premature contraction (class 3), and ventricular
premature contraction (class 4). We can see an example of each class in Figure 1 below.
Figure 1: some examples in the ECG dataset (Image by Author)

The client then loads the datasets and saves them into .pt files, whose paths are then
shared with the server, using the code below.

    train_dataset = ECG(mode='train')
    test_dataset = ECG(mode='test')
    torch.save(train_dataset, "train_dataset.pt")
    torch.save(test_dataset, "test_dataset.pt")

The client creates and saves the datasets into .pt files (Image by Author)

If using duet , the client can send the string path to the server with this syntax (note
that we do not use duet this time):

Code to send a string path to a dataset if using duet (Image by Author)

Server: creating the remote dataset and remote data loader


Now, after receiving the .pt path of the dataset from the client, the server creates the
RemoteDataset and RemoteDataLoader on the remote side.
The server creates the remote dataset and remote data loader (Image by Author)

Let’s loop through the remote data loader and see what’s inside. Note that I used ic
from the icecream package to print out variables while debugging; it is quite handy.

Seeing what’s inside the remote data loader for the training data (Image by Author)

Using the code above, we get X and y as pointers to the corresponding torch
tensors rather than the tensors themselves, as in the figure below.

Figure 2: output when looping through the remote data loader (Image by Author)

The server can request access to the tensors by using X.get() or X.get_copy() , but this
needs to be accepted by the client. Here, we assume that the client accepts all requests
from the server for convenience. However, we will see in the training loop later that the
server never requests access to the training input data. Furthermore, as we only
loaded 50 examples, and the batch size is 32, there are only two batches: one with 32
samples, and one with 18 samples.
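The batch sizes follow from simple arithmetic, sketched below in plain Python. The helper name `batch_sizes` is made up for illustration, and the sketch assumes the loader keeps the final, smaller batch (the default behaviour of PyTorch's DataLoader, `drop_last=False`); whether the remote loader exposes the same option is not covered here.

```python
import math

def batch_sizes(num_samples, batch_size):
    """Return the size of each batch when the last, smaller batch is kept."""
    num_batches = math.ceil(num_samples / batch_size)
    return [min(batch_size, num_samples - i * batch_size) for i in range(num_batches)]

print(batch_sizes(50, 32))  # [32, 18]
```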

Similarly, the server makes the remote dataset and data loader for the test dataset.

The server creates the remote dataset and remote data loader for the testing data (Image by Author)

Server: defining the split neural network architecture to train on the ECG dataset
Figure 3 below shows the architecture of the 1D CNN used to train on
the ECG dataset. The model on the client side contains two 1D convolution layers (more
on these below) with Leaky ReLU activation functions. Each convolution layer is
followed by a 1D max pooling operation. The server’s model contains two fully
connected layers, followed by a softmax activation function. The loss function used is
the cross-entropy loss.
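When wiring the two halves together, the server's first fully connected layer must match the flattened size of the client's output. The sketch below tracks the sequence length through two convolution and pooling stages using the output-length formula from the PyTorch documentation. The kernel sizes, paddings and channel count are assumptions chosen for illustration; the post does not list them in the text.

```python
def conv1d_out_len(length, kernel_size, stride=1, padding=0, dilation=1):
    """Output length of a 1D convolution or pooling window (PyTorch's formula)."""
    return (length + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1

# Hypothetical layer settings; the post does not list them.
length = 128                                                # one heartbeat
length = conv1d_out_len(length, kernel_size=7, padding=3)   # conv 1     -> 128
length = conv1d_out_len(length, kernel_size=2, stride=2)    # max pool 1 -> 64
length = conv1d_out_len(length, kernel_size=5, padding=2)   # conv 2     -> 64
length = conv1d_out_len(length, kernel_size=2, stride=2)    # max pool 2 -> 32

out_channels = 16                      # hypothetical channel count
flat_features = out_channels * length  # input size of the server's first linear layer
print(length, flat_features)           # 32 512
```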

Figure 3: the split learning 1DCNN model architecture (Image by Author)

Figure 4: 1D Convolution layer vs 2D Convolution layer (image by the author)

Let’s learn a bit about the 1D convolution layer. It simply slides a weight
kernel along one dimension. Figure 4 shows the 1D convolution vs. the 2D convolution
operation. 1D convolution is suitable for 1D data, such as the time series that we have in
the ECG signals. If you want to learn more about 1D, 2D and 3D convolution, this blog post
offers very clear explanations.
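As a minimal illustration of "sliding a kernel along one dimension", here is a plain-Python sketch with a single channel and no padding. The signal and kernel values are made up. Like the convolution layers in deep learning frameworks, it does not flip the kernel, so strictly speaking it computes a cross-correlation.

```python
def conv1d(signal, kernel, stride=1):
    """Slide `kernel` along `signal` (no padding) and return the weighted sums."""
    k = len(kernel)
    return [
        sum(w * signal[start + j] for j, w in enumerate(kernel))
        for start in range(0, len(signal) - k + 1, stride)
    ]

signal = [0.0, 1.0, 3.0, 1.0, 0.0, 0.0]   # a tiny, made-up "heartbeat"
kernel = [-1.0, 0.0, 1.0]                 # responds to rising and falling edges
print(conv1d(signal, kernel))             # [3.0, 0.0, -3.0, -1.0]
```

The output is positive where the signal rises and negative where it falls, which is the kind of local pattern a learned kernel can pick up in an ECG trace.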

Now we can move on and define the neural network models on the client side with the
code below. It is a class that inherits from syft.Module . Note that in line number 3, we
have torch_ref as an argument in the constructor, which we will pass remote_torch into
later. All the layers are constructed using this torch_ref module.

Code to define the split neural network part on the client side (Image by Author)

The server model also inherits from syft.Module . Its constructor still takes torch_ref as
an argument; however, the layers are defined with the normal torch.nn module, as they
are trained locally.

Code to define the split neural network part on the server side (Image by Author)

The server then sends the client’s model to the remote client side (line 2 in the code
below).

Creating the models and sending the client’s model (Image by Author)

Server and client: training and testing loop


Before the training and testing loop, we need to define some hyperparameters:

Setting hyperparameters and random seed (Image by Author)

Finally, let the fun begin. Below is the code for the training and testing loop:

The training and testing loop for our split 1D CNN model (Image by Author)

In the forward pass, we first get the pointers to the batch data (line 12). After initializing
all gradients to 0 (lines 15 and 16), the client’s model extracts the activation maps from the
training input data (line 18). The server then asks to access these activation maps (line
20) and continues the forward pass (line 22). The server also asks for access to the
ground truth output data (line 24) to calculate the loss (line 26).

In the backward pass, the server runs the backpropagation down to the split layer (line
30), then sends the gradients to the client (line 32). Upon reception, the client continues
the backpropagation and calculates its own gradients (line 34). Finally, when all gradients of
the loss function with respect to the weights are calculated, both the client and server
can update the parameters.
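For readers who want to experiment without PySyft, the same forward and backward hand-off can be sketched in plain PyTorch inside a single process. This is not the code from the post: the class names `ClientNet` and `ServerNet`, the layer sizes, the learning rate and the random stand-in batch are all assumptions for illustration. The sketch also returns raw class scores from the server model, because `nn.CrossEntropyLoss` applies log-softmax internally.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

class ClientNet(nn.Module):
    """Client half: two Conv1d + LeakyReLU + MaxPool1d stages (hypothetical sizes)."""
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv1d(1, 16, kernel_size=7, padding=3), nn.LeakyReLU(), nn.MaxPool1d(2),
            nn.Conv1d(16, 16, kernel_size=5, padding=2), nn.LeakyReLU(), nn.MaxPool1d(2),
        )

    def forward(self, x):
        return self.features(x).flatten(start_dim=1)   # (batch, 16 * 32)

class ServerNet(nn.Module):
    """Server half: two fully connected layers that output class scores."""
    def __init__(self):
        super().__init__()
        self.classifier = nn.Sequential(
            nn.Linear(16 * 32, 128), nn.LeakyReLU(), nn.Linear(128, 5),
        )

    def forward(self, activations):
        return self.classifier(activations)

client_model, server_model = ClientNet(), ServerNet()
client_opt = torch.optim.Adam(client_model.parameters(), lr=1e-3)
server_opt = torch.optim.Adam(server_model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()    # applies log-softmax internally

x = torch.randn(32, 1, 128)          # random stand-in for a batch of heartbeats
y = torch.randint(0, 5, (32,))       # random stand-in labels

client_opt.zero_grad()
server_opt.zero_grad()

# Forward pass on the client; only the activations cross the split.
activations = client_model(x)
received = activations.detach().requires_grad_()   # what the server holds

# Forward and backward pass on the server, down to the split layer.
loss = criterion(server_model(received), y)
loss.backward()

# The server returns the gradient at the split; the client finishes backprop.
activations.backward(received.grad)

client_opt.step()
server_opt.step()
print(f"loss: {loss.item():.4f}")
```

Detaching the activations before handing them to the server mimics the network boundary: the server's autograd graph stops at `received`, and `received.grad` is the only thing that travels back to the client.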

In the testing loop for each epoch, we only need to do the forward pass and calculate the
testing losses.
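The quantities tracked during testing can be sketched in plain Python as follows. The class scores and targets below are made up, and the helper names `cross_entropy` and `evaluate` are hypothetical; the sketch only shows how a softmax cross-entropy loss and an accuracy are derived from per-sample class scores.

```python
import math

def cross_entropy(scores, target):
    """Softmax cross-entropy for one sample, computed stably via log-sum-exp."""
    peak = max(scores)
    log_sum_exp = peak + math.log(sum(math.exp(s - peak) for s in scores))
    return log_sum_exp - scores[target]

def evaluate(batch_scores, targets):
    """Return (mean loss, accuracy) over a list of per-sample class scores."""
    total_loss, correct = 0.0, 0
    for scores, target in zip(batch_scores, targets):
        total_loss += cross_entropy(scores, target)
        predicted = max(range(len(scores)), key=scores.__getitem__)
        correct += int(predicted == target)
    return total_loss / len(targets), correct / len(targets)

# Made-up scores for three heartbeats over the five classes.
batch_scores = [
    [4.0, 0.1, 0.2, 0.0, 0.3],   # highest score: class 0
    [0.2, 0.1, 3.5, 0.0, 0.4],   # highest score: class 2
    [1.0, 0.9, 0.2, 0.1, 0.3],   # highest score: class 0
]
targets = [0, 2, 4]
mean_loss, accuracy = evaluate(batch_scores, targets)
print(f"test loss: {mean_loss:.4f}, test accuracy: {accuracy:.2%}")
```

Two of the three made-up samples are classified correctly, so the accuracy is 2/3.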

Figure 5: the result of the training and testing loop (Image by Author)

Finally, after the 400 epochs are over, we can print out the best test accuracy and plot the
training/testing losses and accuracies, as in Figures 6 and 7. As we can see, the split
learning 1D CNN method achieves 98.85% accuracy on the test dataset after 351
epochs. Not bad at all.

Figure 6: printing out the best test accuracy (Image by Author)


Figure 7: training/testing losses and accuracies (Image by Author)

Drawbacks and Future Directions


While the split learning method achieves promising results, there are several problems
to be addressed. Firstly, the server still needs to access the ground truth output data to
calculate the loss. To solve this problem, we can use the U-shaped split learning
configuration [3]. Secondly, the activation maps sent from the client to the server can
still leak information about the input training data. The authors of [1] have
experimented with differential privacy to solve this problem; however, it greatly reduces
the accuracy of the algorithm. Thirdly, the time needed to train the split network using
PySyft is very long: almost 14 hours on a 6-core Intel Xeon CPU at 2.60 GHz. Training
the same network locally on a GPU takes only a few minutes. For now, PySyft does not
support training on GPU. Tackling these problems will be the focus of future work.
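To give a feel for the privacy and accuracy trade-off mentioned above, here is a generic sketch of the Laplace mechanism applied to a list of activations. It is not necessarily the exact mechanism evaluated in [1]. The function names, the activation values and the `sensitivity` value are assumptions; in practice the activations would have to be bounded (for example by clipping) for a sensitivity to be meaningful.

```python
import random

def laplace_noise(scale, rng):
    """Laplace(0, scale) sample: the difference of two exponential draws."""
    return rng.expovariate(1.0 / scale) - rng.expovariate(1.0 / scale)

def privatize(activations, epsilon, sensitivity=1.0, seed=None):
    """Add Laplace noise with scale = sensitivity / epsilon to each activation."""
    rng = random.Random(seed)
    scale = sensitivity / epsilon
    return [a + laplace_noise(scale, rng) for a in activations]

activations = [0.8, -0.1, 1.5, 0.0]                      # made-up activation values
mild = privatize(activations, epsilon=10.0, seed=0)      # weak privacy, little noise
strong = privatize(activations, epsilon=0.1, seed=0)     # strong privacy, heavy noise
print(mild)
print(strong)
```

A smaller epsilon means a larger noise scale: with the same seed, the noise added for epsilon = 0.1 is 100 times the noise added for epsilon = 10, which is why stronger privacy settings degrade what the server can learn from the activations.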

Conclusions
In this blog post, we walked through the process of training a split 1D CNN model on the
ECG dataset. Employing the split learning architecture, the algorithm can detect heart
abnormalities with up to 98.85% accuracy while keeping the heartbeat data of the patients
private. Thank you for reading, I hope you found something useful. See you in other blog
posts on Secure and Private AI.

References
[1] Sharif Abuadbba et al., Can We Use Split Learning on 1D CNN Models for Privacy
Preserving Training? (2020), ACM ASIA Conference on Computer and Communications
Security (ACM ASIACCS 2020)

[2] Moody GB, Mark RG. The impact of the MIT-BIH Arrhythmia Database (2001), IEEE
Eng in Med and Biol 20(3):45–50 (May-June 2001)

[3] Praneeth Vepakomma et al., Split learning for health: Distributed deep learning
without sharing raw patient data (2018)
