Detecting Heart Abnormalities Using 1D CNN On Data You Cannot See
TL;DR:
Can We Apply the Split Learning Architecture to Train a 1D CNN Model on Heartbeat
Data and Accurately Detect Heart Abnormalities while Preserving Data Privacy?
Well, that’s still too long, and words are cheap, just show me the code!
Here you go. Enjoy!
Introduction
Machine Learning (ML) is a subfield of Artificial Intelligence where algorithms are
trained to find patterns in massive datasets. These patterns are then used to make
decisions and predictions on new data. One of the problems that ML faces today is data
sharing: data scientists need to gather a large amount of data from data owners in order
to train their algorithms. This is often not ideal, especially for sensitive data in sectors
such as healthcare or finance. Split learning is one of the methods in Privacy Preserving
Machine Learning (PPML) that tries to address this data privacy problem.
Split learning refers to the process of cutting a Deep Neural Network (DNN) into two or
more parts. In the simplest scenario, with only one data owner (the client) and
one data scientist (the server), the DNN is split into two parts. The first part of the DNN
is deployed on the client’s machine, where the data reside, and the second part is
deployed on the server’s side. The client’s model learns a set of features (also called
“activation maps”) from the dataset and then sends those activation maps to the server to
continue the training process. Then, during the backward pass, the server computes the
loss and backpropagates its gradients down to the split layer, then sends the gradients at
that layer back to the client so he can continue the backward pass. This way, the
server/data scientist never gets to see the input training data, but can still train the
network. You can learn more about the fundamentals of split learning from this tutorial.
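The exchange described above can be sketched in plain PyTorch, without PySyft or a network: both halves live in one process, and detaching the activation maps stands in for sending them. The toy layer sizes and random data are assumptions for illustration only. The sketch checks that the client's gradients, computed from the gradient the server sends back, match those of the same network trained end to end.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy halves of one network: the client holds the first layers,
# the server holds the rest. Sizes are arbitrary.
client_part = nn.Sequential(nn.Linear(8, 4), nn.LeakyReLU())
server_part = nn.Linear(4, 3)
loss_fn = nn.CrossEntropyLoss()

x = torch.randn(5, 8)           # private inputs, only touched by the client
y = torch.randint(0, 3, (5,))   # labels, assumed to be shared with the server

# Client: forward pass up to the split layer.
activations = client_part(x)

# "Send" the activation maps: the server gets a leaf tensor with no link back
# to the client's graph, so it cannot backpropagate into the client's weights.
received = activations.detach().requires_grad_()

# Server: finish the forward pass, compute the loss, backprop to the split layer.
loss = loss_fn(server_part(received), y)
loss.backward()
grad_to_client = received.grad  # the only thing that travels back to the client

# Client: continue backpropagation using the gradient it received.
activations.backward(grad_to_client)
split_grads = [p.grad.clone() for p in client_part.parameters()]

# Reference: the same network trained end to end in a single graph.
client_part.zero_grad()
loss_fn(server_part(client_part(x)), y).backward()
full_grads = [p.grad for p in client_part.parameters()]

print(all(torch.allclose(a, b) for a, b in zip(split_grads, full_grads)))  # True
```

In the PySyft setup used in this post, roughly the same hand-off happens through pointers to remote tensors rather than local detach() calls.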
In this blog post, we will walk through the process of training a split neural network
using OpenMined’s framework PySyft: a Python library for computing on data you do
not own and cannot see. In OpenMined’s free course “Foundations of Private
Computation”, there is already a tutorial on how to train a split DNN using PySyft’s Duet
with two Jupyter notebooks: one to represent the client, and the other to represent the
server. However, if you are developing a new split learning method, using two notebooks
is quite bothersome as you have to switch back and forth. Luckily, there is another
feature of PySyft called VirtualMachine that allows us to develop a split DNN in only one
Jupyter notebook or Python file. We will learn how to use it today, along with PySyft’s
other features such as RemoteDataset and RemoteDataLoader to load a custom remote
dataset. Most importantly, we will discover how to train a split 1D CNN
to detect heart abnormalities on input data that never leave the client’s machine, based
on the work from [1].
```python
from pathlib import Path

project_path = Path.cwd()
print(f'project_path: {project_path}')
# paths to files and directories
train_name = 'train_ecg.hdf5'
test_name = 'test_ecg.hdf5'
```
Setting the training and test file names for data imports (Image by Author)
```python
import syft as sy

server = sy.VirtualMachine(name="server")  # in-process PySyft virtual machine
client = server.get_root_client()           # root client handle to that VM
remote_torch = client.torch                 # torch as seen on the VM
```
The processed dataset consists of 26,490 heartbeat samples in total, each a
time-series vector of length 128. There are 5 different types of heartbeats as
classification targets: normal beat (class 0), left bundle branch block (class 1), right
bundle branch block (class 2), atrial premature contraction (class 3) and ventricular
premature contraction (class 4). We can see an example of each class in Figure 1 below.
Figure 1: some examples in the ECG dataset (Image by Author)
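The ECG class used by the post isn't reproduced in this text, and neither are the key names inside the HDF5 files. As a hypothetical stand-in, the sketch below wraps heartbeats that are already loaded as tensors and returns each one with a channel dimension, which is the per-sample shape a 1D convolution layer expects. The class name HeartbeatDataset and the random placeholder data are assumptions; only the label mapping comes from the description above.

```python
import torch
from torch.utils.data import Dataset

# Class labels as described in the text.
CLASS_NAMES = {
    0: "normal beat",
    1: "left bundle branch block",
    2: "right bundle branch block",
    3: "atrial premature contraction",
    4: "ventricular premature contraction",
}


class HeartbeatDataset(Dataset):
    """Hypothetical stand-in for the post's ECG class, wrapping pre-loaded tensors.

    signals: float tensor of shape (N, 128), one heartbeat per row
    labels:  integer tensor of shape (N,), values in 0..4
    """

    def __init__(self, signals: torch.Tensor, labels: torch.Tensor):
        assert signals.ndim == 2 and signals.shape[1] == 128
        assert signals.shape[0] == labels.shape[0]
        self.signals = signals.float()
        self.labels = labels.long()

    def __len__(self) -> int:
        return self.signals.shape[0]

    def __getitem__(self, idx):
        # Add a channel dimension: Conv1d expects (channels, length) per sample.
        return self.signals[idx].unsqueeze(0), self.labels[idx]


# Usage with random placeholder data (not real ECG recordings).
ds = HeartbeatDataset(torch.randn(10, 128), torch.randint(0, 5, (10,)))
x, y = ds[0]
print(len(ds), tuple(x.shape), CLASS_NAMES[int(y)])
```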
Using the code below, the client then loads the datasets and saves them into .pt files,
which will be sent to the server.
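```python
import torch

# ECG is the custom Dataset class defined earlier in ecg_dataset.py (not shown here)
train_dataset = ECG(mode='train')
test_dataset = ECG(mode='test')
# serialize the Dataset objects so they can be turned into remote datasets later
torch.save(train_dataset, "train_dataset.pt")
torch.save(test_dataset, "test_dataset.pt")
```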
The client creates and saves the datasets into .pt files (Image by Author)
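Because torch.save pickles the whole Dataset object, it can be worth checking that the files load back before sending them anywhere. A minimal local sketch, using a TensorDataset with random placeholder data in place of the post's ECG class and assuming a recent PyTorch release, where torch.load defaults to weights_only=True and therefore needs weights_only=False to unpickle arbitrary objects:

```python
import os
import tempfile

import torch
from torch.utils.data import TensorDataset

# Placeholder dataset standing in for ECG(mode='train'); shapes are assumptions.
dataset = TensorDataset(torch.randn(50, 1, 128), torch.randint(0, 5, (50,)))

path = os.path.join(tempfile.mkdtemp(), "train_dataset.pt")
torch.save(dataset, path)

# weights_only=False allows unpickling the full Dataset object.
restored = torch.load(path, weights_only=False)
print(len(restored), restored[0][0].shape)
```

Only set weights_only=False for files from a source you trust, since unpickling can execute arbitrary code.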
If using Duet, he can send the string path to the server with this syntax (note that we do
Let’s loop through the remote data loader and see what’s inside. Note that I used ic
from the icecream package to print out variables while debugging; it is quite handy.
Seeing what’s inside the remote data loader for the training data (Image by Author)
Using the code above, we would get X and y as pointers to the corresponding torch
Tensors, but not the real tensors themselves, like in the figure below.
Figure 2: output when looping through the remote data loader (Image by Author)
The server can request access to the tensors by using X.get() or X.get_copy(), but this
needs to be accepted by the client. Here, we assume that the client accepts all requests
from the server for convenience. However, we will see in the training loop later that the
server never requests access to the training input data. Furthermore, as we only
loaded 50 examples and the batch size is 32, there are only two batches: one with 32
samples and one with 18 samples.
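The batch arithmetic is the same as for an ordinary local PyTorch data loader. This sketch uses torch.utils.data.DataLoader rather than PySyft's RemoteDataLoader, with 50 random placeholder samples, to show where the 32 + 18 split comes from:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# 50 placeholder heartbeats of length 128 with labels in 0..4 (random data).
dataset = TensorDataset(torch.randn(50, 1, 128), torch.randint(0, 5, (50,)))
loader = DataLoader(dataset, batch_size=32, shuffle=True)

batch_sizes = [X.shape[0] for X, y in loader]
print(batch_sizes)  # [32, 18]: drop_last defaults to False, so the remainder is kept
print(len(loader))  # 2, i.e. ceil(50 / 32)
```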
Similarly, the server makes the remote dataset and data loader for the test dataset.
The server creates the remote dataset and remote data loader for the testing data (Image by Author)
Server: defining the split neural network architecture to train on the ECG dataset
Figure 3 below shows the architecture of the 1D CNN neural network used to train on
the ECG dataset. The model on the client side contains two 1D convolution layers (more
on these later) with Leaky ReLU activation functions. Each conv layer is
followed by a 1D Max Pooling operation. The server’s model contains two fully
connected layers, followed by a softmax activation function. The loss function used is
the cross-entropy loss.
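One detail worth checking when reproducing this architecture in PyTorch: nn.CrossEntropyLoss already applies log-softmax internally, so it is usually fed the raw scores (logits) of the last fully connected layer, with softmax applied separately when probabilities are needed. Whether the softmax output or the raw scores go into the loss is not stated here; the sketch below, with random placeholder scores, only verifies the equivalence.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 5)            # placeholder server scores: 4 heartbeats, 5 classes
targets = torch.tensor([0, 2, 4, 1])

# nn.CrossEntropyLoss takes raw logits and applies log-softmax internally...
ce = nn.CrossEntropyLoss()(logits, targets)
# ...so it equals the negative log-likelihood of the softmax probabilities.
nll = F.nll_loss(torch.log(F.softmax(logits, dim=1)), targets)
print(torch.allclose(ce, nll))        # True

# For inference, class probabilities come from an explicit softmax.
probs = F.softmax(logits, dim=1)
print(probs.sum(dim=1))               # each row sums to 1
```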
Let’s learn a bit about the 1D convolution layer. It simply slides a weight kernel along a
single dimension of the input. Figure 4 shows the 1D convolution vs. 2D convolution
operation. 1D convolution is suitable for 1D data, such as the time series of the
ECG signals. If you want to learn more about 1D, 2D and 3D convolution, this blog post
offers very clear explanations.
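To make "sliding a kernel along one dimension" concrete, here is a dependency-free Python sketch of a 1D convolution. Like PyTorch's Conv1d, it computes a cross-correlation (the kernel is not flipped), and it includes the standard output-length formula floor((L + 2p - k) / s) + 1. The example signal, kernel and sizes are arbitrary.

```python
def conv1d(signal, kernel, stride=1, padding=0):
    """Slide `kernel` along `signal` (1D cross-correlation, as in PyTorch's Conv1d)."""
    padded = [0] * padding + list(signal) + [0] * padding
    k = len(kernel)
    out = []
    for start in range(0, len(padded) - k + 1, stride):
        window = padded[start:start + k]
        out.append(sum(w * v for w, v in zip(kernel, window)))
    return out


def conv1d_out_len(length, kernel_size, stride=1, padding=0):
    """Output length: floor((L + 2p - k) / s) + 1."""
    return (length + 2 * padding - kernel_size) // stride + 1


signal = [1, 2, 3, 4, 5]
kernel = [1, 0, -1]
print(conv1d(signal, kernel))             # [-2, -2, -2]
print(conv1d(signal, kernel, padding=1))  # [-2, -2, -2, -2, 4]
print(conv1d_out_len(128, 7, padding=3))  # 128: this padding keeps the length unchanged
```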
Now we can move on and define the neural network model on the client side with the
code below. It is a class that inherits from syft.Module. Note that on line 3, we
have torch_ref as an argument in the constructor, into which we will later pass
remote_torch. All the layers are constructed using this torch_ref module.
Code to define the split neural network part on the client side (Image by Author)
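The client's code is only available as an image, so the following is a hypothetical re-creation of the pattern it describes, not the author's exact code: a constructor that receives torch_ref and builds every layer from it. Here plain torch is passed in and the base class is torch.nn.Module; in the post, the base class is syft.Module and torch_ref is remote_torch. The channel counts and kernel sizes are assumptions chosen so that a length-128 heartbeat becomes 16 activation maps of length 32.

```python
import torch


class ClientNet(torch.nn.Module):
    """Sketch of the client-side half: two Conv1d + LeakyReLU + MaxPool1d blocks.

    Channel counts and kernel sizes are assumptions, not taken from the post.
    In the post, the base class is syft.Module and torch_ref is remote_torch.
    """

    def __init__(self, torch_ref):
        super().__init__()
        self.conv1 = torch_ref.nn.Conv1d(1, 16, kernel_size=7, padding=3)   # 1x128 -> 16x128
        self.relu1 = torch_ref.nn.LeakyReLU()
        self.pool1 = torch_ref.nn.MaxPool1d(2)                              # -> 16x64
        self.conv2 = torch_ref.nn.Conv1d(16, 16, kernel_size=5, padding=2)  # -> 16x64
        self.relu2 = torch_ref.nn.LeakyReLU()
        self.pool2 = torch_ref.nn.MaxPool1d(2)                              # -> 16x32

    def forward(self, x):
        x = self.pool1(self.relu1(self.conv1(x)))
        x = self.pool2(self.relu2(self.conv2(x)))
        return x  # activation maps that would be sent to the server


client_model = ClientNet(torch_ref=torch)
activation_maps = client_model(torch.randn(32, 1, 128))
print(activation_maps.shape)  # torch.Size([32, 16, 32])
```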
The server model also inherits from syft.Module; its constructor still takes torch_ref as
an argument; however, the layers are defined with the regular torch.nn module, as they
are trained locally on the server.
Code to define the split neural network part on the server side (Image by Author)
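A matching sketch of the server half, again hypothetical and with assumed sizes: it expects the client to send 16 activation maps of length 32 per heartbeat, and the hidden width of 128 and the LeakyReLU between the two fully connected layers are also assumptions. Following the usual PyTorch convention, forward returns raw scores and softmax is applied only when probabilities are requested.

```python
import torch
import torch.nn as nn


class ServerNet(nn.Module):
    """Sketch of the server-side half: two fully connected layers.

    Assumes incoming activation maps of shape (16, 32); hidden width is an assumption.
    """

    def __init__(self, in_features=16 * 32, hidden=128, num_classes=5):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden)
        self.relu = nn.LeakyReLU()
        self.fc2 = nn.Linear(hidden, num_classes)

    def forward(self, activation_maps):
        x = activation_maps.flatten(start_dim=1)  # (N, 16, 32) -> (N, 512)
        return self.fc2(self.relu(self.fc1(x)))   # raw class scores (logits)

    def predict_proba(self, activation_maps):
        # Softmax is applied explicitly when probabilities are needed.
        return torch.softmax(self.forward(activation_maps), dim=1)


server_model = ServerNet()
probs = server_model.predict_proba(torch.randn(32, 16, 32))
print(probs.shape)  # torch.Size([32, 5])
```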
The server then sends the client’s model to the remote client side (line 2 in the code
below).
Creating the models and sending the client’s model (Image by Author)
Finally, let the fun begin. Below is the code for the training and testing loop:
The training and testing loop for our split 1D CNN model (Image by Author)
In the forward pass, we first get the pointers to the batch data (line 12). After resetting
all gradients to zero (lines 15 and 16), the client’s model extracts the activation maps from the
training input data (line 18). The server then asks for access to these activation maps (line
20) and continues the forward pass (line 22). The server also asks for access to the
ground truth labels (line 24) to calculate the loss (line 26).
In the backward pass, the server backpropagates down to the split layer (line
30), then sends the gradients to the client (line 32). Upon receiving them, the client continues
the backpropagation and calculates his gradients (line 34). Finally, when all gradients of
the loss function with respect to the weights are calculated, both the client and server
can update the parameters.
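The steps above can be condensed into one split training step. This is a minimal single-process sketch in plain PyTorch, not the post's PySyft loop: detaching the activation maps simulates sending them to the server, and passing received.grad back simulates returning the gradients. The tiny layers, the Adam optimizer, the learning rate and the random batch are all assumptions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical small halves standing in for the client and server models.
client_model = nn.Sequential(nn.Conv1d(1, 4, 5, padding=2), nn.LeakyReLU(), nn.MaxPool1d(2))
server_model = nn.Sequential(nn.Flatten(), nn.Linear(4 * 64, 5))
client_opt = torch.optim.Adam(client_model.parameters(), lr=1e-3)
server_opt = torch.optim.Adam(server_model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()


def split_train_step(x, y):
    # Both parties reset their gradients.
    client_opt.zero_grad()
    server_opt.zero_grad()
    # Client: forward pass up to the split layer.
    activations = client_model(x)
    # Server: receives the activation maps (simulated with detach) and the labels.
    received = activations.detach().requires_grad_()
    loss = loss_fn(server_model(received), y)
    # Server: backpropagates down to the split layer.
    loss.backward()
    # Client: receives d(loss)/d(activations) and finishes backpropagation.
    activations.backward(received.grad)
    # Each party updates only its own parameters.
    server_opt.step()
    client_opt.step()
    return loss.item()


# Random placeholder batch (32 heartbeats of length 128); repeated steps on the
# same batch should drive its loss down.
x = torch.randn(32, 1, 128)
y = torch.randint(0, 5, (32,))
losses = [split_train_step(x, y) for _ in range(50)]
print(losses[0], losses[-1])
```

Each party keeps its own optimizer, so the server only ever updates server_model and the client only ever updates client_model.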
In the testing loop for each epoch, we only need to do the forward pass and calculate the
testing losses.
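A sketch of the matching evaluation pass, under the same kind of assumptions (tiny placeholder halves, random test data, plain PyTorch instead of PySyft): only the forward pass runs, inside torch.no_grad(), and the loop accumulates the test loss and accuracy.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Hypothetical halves and random placeholder test data (50 heartbeats).
client_model = nn.Sequential(nn.Conv1d(1, 4, 5, padding=2), nn.LeakyReLU(), nn.MaxPool1d(2))
server_model = nn.Sequential(nn.Flatten(), nn.Linear(4 * 64, 5))
test_loader = DataLoader(
    TensorDataset(torch.randn(50, 1, 128), torch.randint(0, 5, (50,))), batch_size=32
)
loss_fn = nn.CrossEntropyLoss(reduction="sum")


def evaluate(client_model, server_model, loader):
    client_model.eval()
    server_model.eval()
    total_loss, correct, seen = 0.0, 0, 0
    with torch.no_grad():  # forward pass only: no gradients are computed or exchanged
        for x, y in loader:
            logits = server_model(client_model(x))
            total_loss += loss_fn(logits, y).item()
            correct += (logits.argmax(dim=1) == y).sum().item()
            seen += y.shape[0]
    return total_loss / seen, correct / seen


test_loss, test_acc = evaluate(client_model, server_model, test_loader)
print(f"test loss {test_loss:.4f}, accuracy {test_acc:.2%}")
```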
Figure 5: the result of the training and testing loop (Image by Author)
Finally, after all 400 epochs are over, we can print out the best test accuracy and plot the
training/testing losses and accuracies, as in Figures 6 and 7. As we can see, the split
learning 1D CNN achieves 98.85% accuracy on the test dataset after 351
epochs. Not bad at all.
Conclusions
In this blog post, we walked through the process of training a split 1D CNN model on the
ECG dataset. Employing the split learning architecture, the algorithm can predict heart
abnormalities with up to 98.85% accuracy while keeping the patients’ heartbeat data
private. Thank you for reading; I hope you found something useful. See you in other blog
posts on Secure and Private AI.
References
[1] Sharif Abuadbba et al., Can We Use Split Learning on 1D CNN Models for Privacy
Preserving Training? (2020), ACM ASIA Conference on Computer and Communications
Security (ACM ASIACCS 2020)
[2] George B. Moody and Roger G. Mark, The impact of the MIT-BIH Arrhythmia Database (2001),
IEEE Engineering in Medicine and Biology Magazine 20(3):45–50 (May-June 2001)
[3] Praneeth Vepakomma et al., Split learning for health: Distributed deep learning
without sharing raw patient data (2018)