Alzheimer's Classification 2
import keras
from tensorflow import keras
from keras import Sequential
from keras import layers
import tensorflow as tf
from tensorflow.keras.preprocessing import image_dataset_from_directory
from tensorflow.keras import Sequential
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.layers import Dense, Dropout, Activation,␣
↪BatchNormalization, Flatten, Conv2D, MaxPooling2D
1
# Notebook-wide matplotlib defaults: 10x6-inch figures rendered at 300 dpi.
plt.rcParams.update({"figure.figsize": (10, 6), "figure.dpi": 300})

# Four-colour palette (the model predicts four classes).
colors = ["#B6EE56", "#D85F9C", "#EEA756", "#56EEE8"]
# Report whether TensorFlow can see a GPU. This is purely diagnostic, so any
# failure is printed rather than raised.
try:
    # tf.config.list_physical_devices is the stable API; the
    # tf.config.experimental alias used before is deprecated, and the list
    # itself answers "is there a GPU?" without a separate device-name probe.
    physical_devices = tf.config.list_physical_devices('GPU')
    if physical_devices:
        print('GPU active! -', physical_devices)
    else:
        print('GPU not active!')
except Exception as e:
    print('An error occurred while checking the GPU:', e)
# NOTE(review): fragment of image_counter(path) (called at the end of this
# block). Its `def` line and the loops that define dir_name, file_ext,
# image_extensions, count and class_dist are missing from this extract, and
# indentation was lost, so the nesting of these statements cannot be confirmed.
# Count the file only if its extension is in image_extensions.
if file_ext in image_extensions:
count += 1
# Record the per-folder total. class_dist is read again at module level later,
# so it is presumably a global dict keyed by folder name.
class_dist[dir_name] = count
# The "␣" / "↪" marks are a line-wrap artifact of the export; the f-string is
# a single line in the notebook.
print(f"There are \033[35m{count}\033[0m images in the {dir_name}␣
folder.")
↪
# Folder names and their image counts, in class_dist insertion order.
keys = list(class_dist.keys())
values = list(class_dist.values())
# The bare "2" below appears to be a page number from the export, not code.
2
# One 0.1 offset per folder, presumably the `explode` argument of a pie chart
# whose plotting call is not visible here.
explode = (0.1,)*len(keys)
# Dataset root on the Colab-mounted Google Drive, one sub-folder per class.
PATH = '/content/drive/MyDrive/Dataset'
image_counter(PATH)
The class distribution shows that the dataset is imbalanced: the Non-Demented MRI class makes up 50% of the data (3,200 images), while the Moderate Demented MRI class makes up only 1% (64 images).
3
3 Generate TF Dataset
The data path has a main folder named “Dataset”. Inside it is one sub-folder per class, each holding that class's images:
Dataset/
    Mild_Demented/
        mild_1.jpg
        mild_2.jpg
    Moderate_Demented/
        moderate_1.jpg
        moderate_2.jpg
    ...
TensorFlow provides tf.keras.utils.image_dataset_from_directory for directories laid out this way. It reads the images and infers each image's label from the name of the sub-folder that contains it.
# Load the class-per-folder image tree as a batched tf.data.Dataset of
# (images, labels) pairs. Images are resized to 128x128, and labels default to
# integer class indices that follow the order of `class_names`.
# NOTE(review): with shuffle=True the Keras loader also shuffles within a
# buffer on every iteration. The later take/skip split over this dataset may
# therefore let images cross the train/val/test boundaries between epochs —
# confirm before trusting the test score.
data = tf.keras.utils.image_dataset_from_directory(
    directory=PATH,
    image_size=(128, 128),
    batch_size=32,
    seed=42,
    shuffle=True,
)
class_names = data.class_names
Mounted at /content/drive
Let’s see some samples for each class!
MRI Samples for Each Class
# Display the first `num_samples` entries of image_files for one class folder,
# each with its axis hidden and titled "Sample 1", "Sample 2", ...
# NOTE(review): the lines that define class_path, image_files and the subplot
# axes `ax` are missing from this extract (indentation was lost too).
# class_path is presumably os.path.join(path, target) — confirm. If
# image_files is a list shorter than num_samples, image_files[i] raises
# IndexError.
[36]: def sample_bringer(path, target, num_samples=5):
for i in range(num_samples):
image_path = os.path.join(class_path, image_files[i])
img = mpimg.imread(image_path)
ax[i].imshow(img)
ax[i].axis('off')
ax[i].set_title(f'Sample {i+1}', color="aqua")
plt.tight_layout()
# The bare "4" below appears to be a page number from the export, not code.
4
# Show samples for every class the dataset loader discovered.
for target in class_names:
sample_bringer(PATH, target=target)
5
Normalizing pixel values generally helps a neural network train better, so we rescale the pixel values from the range 0–255 to the range 0–1.
# Integer label -> class-folder name, in the loader's class order.
alz_dict = dict(enumerate(data.class_names))
# Wraps the batched dataset: rescales pixel values and splits it into
# train/validation/test subsets made of whole batches.
class Process:
# Divide each image batch by 255 and leave the labels unchanged. This assumes
# 0..255 input; the printed min/max of one batch (0.0 & 1.0) is consistent
# with that.
def __init__(self, data):
self.data = data.map(lambda x, y: (x/255, y))
# Pull one (images, labels) batch as NumPy arrays and print its pixel min/max
# as a sanity check on the rescaling.
def create_new_batch(self):
self.batch = self.data.as_numpy_iterator().next()
text = "Min and max pixel values in the batch ->"
print(text, self.batch[0].min(), "&", self.batch[0].max())
# NOTE(review): the `def train_test_val_split(self, ...)` header line is
# missing from this extract; the method is called later with train_size,
# val_size and test_size keyword arguments.
# Split sizes are numbers of batches (len() of the batched dataset), floored by
# int(). Any batches left over from the rounding belong to no split.
train = int(len(self.data)*train_size)
test = int(len(self.data)*test_size)
val = int(len(self.data)*val_size)
# Intended as consecutive ranges of batches: train | val | test.
# NOTE(review): take/skip re-iterate self.data for each subset. If the
# underlying dataset reshuffles on every iteration (the loader was created with
# shuffle=True), the subsets are not guaranteed to stay disjoint across
# epochs — confirm.
train_data = self.data.take(train)
val_data = self.data.skip(train).take(val)
test_data = self.data.skip(train+val).take(test)
# The bare "6" below appears to be a page number from the export, not code.
6
return train_data, val_data, test_data
Min and max pixel values in the batch -> 0.0 & 1.0
We will divide the dataset into 80% training data, 10% validation data and 10% test data.
# 80% train / 10% validation / 10% test, measured in batches.
train_data, val_data, test_data = process.train_test_val_split(
    val_size=0.1,
    test_size=0.1,
    train_size=0.8,
)
The target classes are imbalanced. Class weights make the loss count errors on rare classes more heavily, which helps the model learn to recognize the minority classes. Therefore, let's compute a weight for each target class from the training data and pass these weights to the model during training.
# Concatenate the label tensor of every training batch into one flat tensor.
y_train = tf.concat([batch[1] for batch in train_data], axis=0)
# 'balanced' gives n_samples / (n_classes * count_of_class), so rarer classes
# receive larger weights. The result is an array ordered like
# np.unique(y_train).
# NOTE(review): Keras fit(class_weight=...) expects a {class_index: weight}
# dict — confirm the (not shown) fit call converts this array.
class_weight = compute_class_weight(
    class_weight='balanced',
    classes=np.unique(y_train),
    y=y_train.numpy(),
)
4 Model Building
# Build, compile and summarise the CNN; returns the compiled Sequential model.
# NOTE(review): the convolutional layers are missing from this extract, and so
# is the indentation. The summary output lists conv2d with output
# (None, 126, 126, 16) and 448 parameters. That is consistent with
# Conv2D(16, (3, 3)) on 128x128x3 input — confirm against the original
# notebook.
[42]: def build_model():
model = Sequential()
# The bare "7" below appears to be a page number from the export, not code.
7
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Flatten())
model.add(Dense(128, activation="relu", kernel_initializer='he_normal'))
model.add(Dense(64, activation="relu"))
# Four output units, one per class. Because of the softmax here,
# model.predict() already returns probabilities; applying tf.nn.softmax again
# distorts them.
model.add(Dense(4, activation="softmax"))
# sparse_categorical_crossentropy expects integer class labels, not one-hot.
model.compile(optimizer='adam', loss="sparse_categorical_crossentropy",␣
↪metrics=['accuracy'])
model.summary()
return model
model = build_model()
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
conv2d (Conv2D) (None, 126, 126, 16) 448
8
=================================================================
Total params: 3,261,988
Trainable params: 3,261,988
Non-trainable params: 0
_________________________________________________________________
Callbacks
def checkpoint_callback(checkpoint_filepath='/tmp/checkpoint'):
    """Return a ModelCheckpoint that keeps only the best model by val accuracy.

    At the end of each epoch the full model (not only its weights) is written
    to ``checkpoint_filepath`` if ``val_accuracy`` beats the best value seen
    so far.

    Args:
        checkpoint_filepath: Where the model is saved. Defaults to the
            previously hard-coded '/tmp/checkpoint'.

    Returns:
        A configured ``ModelCheckpoint`` callback.
    """
    return ModelCheckpoint(
        filepath=checkpoint_filepath,
        save_weights_only=False,
        # The keyword is `save_freq`. The former `frequency='epoch'` is not a
        # ModelCheckpoint argument; older tf.keras silently swallowed it via
        # **kwargs. 'epoch' is also the default, so behaviour is unchanged.
        save_freq='epoch',
        monitor='val_accuracy',
        save_best_only=True,
        verbose=1,
    )
def early_stopping(patience):
    """Return an EarlyStopping callback that watches validation loss.

    Training halts once ``val_loss`` has failed to improve for ``patience``
    consecutive epochs, and a message is logged when that happens.
    ``restore_best_weights`` is left at its default.
    """
    return tf.keras.callbacks.EarlyStopping(
        patience=patience,
        monitor='val_loss',
        verbose=1,
    )
EPOCHS = 20

# NOTE(review): this rebinds the factory names `checkpoint_callback` and
# `early_stopping` to the callback objects they return. Re-running this cell
# then fails, because those objects are not callable; consider distinct names.
checkpoint_callback, early_stopping = checkpoint_callback(), early_stopping(patience=5)
callbacks = [checkpoint_callback, early_stopping]
Epoch 1/20
160/160 [==============================] - ETA: 0s - loss: 0.8399 - accuracy:
0.5764
Epoch 1: val_accuracy improved from -inf to 0.70156, saving model to
/tmp/checkpoint
WARNING:absl:Found untraced functions such as _jit_compiled_convolution_op,
_jit_compiled_convolution_op, _jit_compiled_convolution_op while saving (showing
3 of 3). These functions will not be directly callable after loading.
160/160 [==============================] - 32s 199ms/step - loss: 0.8399 -
accuracy: 0.5764 - val_loss: 0.6907 - val_accuracy: 0.7016
Epoch 2/20
159/160 [============================>.] - ETA: 0s - loss: 0.3954 - accuracy:
9
0.7531
Epoch 2: val_accuracy improved from 0.70156 to 0.80625, saving model to
/tmp/checkpoint
WARNING:absl:Found untraced functions such as _jit_compiled_convolution_op,
_jit_compiled_convolution_op, _jit_compiled_convolution_op while saving (showing
3 of 3). These functions will not be directly callable after loading.
160/160 [==============================] - 15s 93ms/step - loss: 0.3947 -
accuracy: 0.7529 - val_loss: 0.4266 - val_accuracy: 0.8062
Epoch 3/20
159/160 [============================>.] - ETA: 0s - loss: 0.1618 - accuracy:
0.9135
Epoch 3: val_accuracy improved from 0.80625 to 0.95781, saving model to
/tmp/checkpoint
WARNING:absl:Found untraced functions such as _jit_compiled_convolution_op,
_jit_compiled_convolution_op, _jit_compiled_convolution_op while saving (showing
3 of 3). These functions will not be directly callable after loading.
160/160 [==============================] - 14s 89ms/step - loss: 0.1614 -
accuracy: 0.9139 - val_loss: 0.1392 - val_accuracy: 0.9578
Epoch 4/20
160/160 [==============================] - ETA: 0s - loss: 0.0535 - accuracy:
0.9779
Epoch 4: val_accuracy improved from 0.95781 to 0.96406, saving model to
/tmp/checkpoint
WARNING:absl:Found untraced functions such as _jit_compiled_convolution_op,
_jit_compiled_convolution_op, _jit_compiled_convolution_op while saving (showing
3 of 3). These functions will not be directly callable after loading.
160/160 [==============================] - 19s 118ms/step - loss: 0.0535 -
accuracy: 0.9779 - val_loss: 0.1029 - val_accuracy: 0.9641
Epoch 5/20
160/160 [==============================] - ETA: 0s - loss: 0.0261 - accuracy:
0.9918
Epoch 5: val_accuracy improved from 0.96406 to 0.98438, saving model to
/tmp/checkpoint
WARNING:absl:Found untraced functions such as _jit_compiled_convolution_op,
_jit_compiled_convolution_op, _jit_compiled_convolution_op while saving (showing
3 of 3). These functions will not be directly callable after loading.
160/160 [==============================] - 74s 466ms/step - loss: 0.0261 -
accuracy: 0.9918 - val_loss: 0.0501 - val_accuracy: 0.9844
Epoch 6/20
159/160 [============================>.] - ETA: 0s - loss: 0.0069 - accuracy:
0.9994
Epoch 6: val_accuracy improved from 0.98438 to 0.99375, saving model to
/tmp/checkpoint
10
WARNING:absl:Found untraced functions such as _jit_compiled_convolution_op,
_jit_compiled_convolution_op, _jit_compiled_convolution_op while saving (showing
3 of 3). These functions will not be directly callable after loading.
160/160 [==============================] - 14s 88ms/step - loss: 0.0068 -
accuracy: 0.9994 - val_loss: 0.0271 - val_accuracy: 0.9937
Epoch 7/20
160/160 [==============================] - ETA: 0s - loss: 0.0015 - accuracy:
1.0000
Epoch 7: val_accuracy improved from 0.99375 to 0.99531, saving model to
/tmp/checkpoint
WARNING:absl:Found untraced functions such as _jit_compiled_convolution_op,
_jit_compiled_convolution_op, _jit_compiled_convolution_op while saving (showing
3 of 3). These functions will not be directly callable after loading.
160/160 [==============================] - 14s 89ms/step - loss: 0.0015 -
accuracy: 1.0000 - val_loss: 0.0149 - val_accuracy: 0.9953
Epoch 8/20
160/160 [==============================] - ETA: 0s - loss: 5.9136e-04 -
accuracy: 1.0000
Epoch 8: val_accuracy did not improve from 0.99531
160/160 [==============================] - 17s 106ms/step - loss: 5.9136e-04 -
accuracy: 1.0000 - val_loss: 0.0111 - val_accuracy: 0.9953
Epoch 9/20
160/160 [==============================] - ETA: 0s - loss: 3.5138e-04 -
accuracy: 1.0000
Epoch 9: val_accuracy did not improve from 0.99531
160/160 [==============================] - 14s 84ms/step - loss: 3.5138e-04 -
accuracy: 1.0000 - val_loss: 0.0130 - val_accuracy: 0.9953
Epoch 10/20
160/160 [==============================] - ETA: 0s - loss: 2.7242e-04 -
accuracy: 1.0000
Epoch 10: val_accuracy improved from 0.99531 to 0.99687, saving model to
/tmp/checkpoint
WARNING:absl:Found untraced functions such as _jit_compiled_convolution_op,
_jit_compiled_convolution_op, _jit_compiled_convolution_op while saving (showing
3 of 3). These functions will not be directly callable after loading.
160/160 [==============================] - 14s 86ms/step - loss: 2.7242e-04 -
accuracy: 1.0000 - val_loss: 0.0101 - val_accuracy: 0.9969
Epoch 11/20
160/160 [==============================] - ETA: 0s - loss: 2.1518e-04 -
accuracy: 1.0000
Epoch 11: val_accuracy did not improve from 0.99687
160/160 [==============================] - 18s 109ms/step - loss: 2.1518e-04 -
accuracy: 1.0000 - val_loss: 0.0122 - val_accuracy: 0.9969
Epoch 12/20
160/160 [==============================] - ETA: 0s - loss: 1.7408e-04 -
11
accuracy: 1.0000
Epoch 12: val_accuracy did not improve from 0.99687
160/160 [==============================] - 13s 77ms/step - loss: 1.7408e-04 -
accuracy: 1.0000 - val_loss: 0.0128 - val_accuracy: 0.9937
Epoch 13/20
160/160 [==============================] - ETA: 0s - loss: 1.3730e-04 -
accuracy: 1.0000
Epoch 13: val_accuracy did not improve from 0.99687
160/160 [==============================] - 13s 77ms/step - loss: 1.3730e-04 -
accuracy: 1.0000 - val_loss: 0.0105 - val_accuracy: 0.9953
Epoch 14/20
159/160 [============================>.] - ETA: 0s - loss: 1.1478e-04 -
accuracy: 1.0000
Epoch 14: val_accuracy did not improve from 0.99687
160/160 [==============================] - 13s 77ms/step - loss: 1.1823e-04 -
accuracy: 1.0000 - val_loss: 0.0172 - val_accuracy: 0.9922
Epoch 15/20
160/160 [==============================] - ETA: 0s - loss: 1.0131e-04 -
accuracy: 1.0000
Epoch 15: val_accuracy did not improve from 0.99687
160/160 [==============================] - 17s 104ms/step - loss: 1.0131e-04 -
accuracy: 1.0000 - val_loss: 0.0124 - val_accuracy: 0.9937
Epoch 15: early stopping
Loss and Accuracy
# Per-epoch learning curves from `history` (presumably the History returned by
# the fit call, which is not shown). Validation curves are overlaid when fit()
# recorded them, so the gap between training and validation is visible; the
# previous version plotted the training curves only.
fig, ax = plt.subplots(1, 2, figsize=(12,6), facecolor="khaki")

ax[0].set_facecolor('palegoldenrod')
ax[0].set_title('Loss', fontweight="bold")
ax[0].set_xlabel("Epoch", size=14)
ax[0].plot(history.epoch, history.history["loss"], label="Train Loss", color="navy")
if "val_loss" in history.history:
    ax[0].plot(history.epoch, history.history["val_loss"], label="Validation Loss", color="darkorange")
ax[0].legend()

ax[1].set_facecolor('palegoldenrod')
ax[1].set_title('Accuracy', fontweight="bold")
ax[1].set_xlabel("Epoch", size=14)
ax[1].plot(history.epoch, history.history["accuracy"], label="Train Acc.", color="navy")
if "val_accuracy" in history.history:
    ax[1].plot(history.epoch, history.history["val_accuracy"], label="Validation Acc.", color="darkorange")
ax[1].legend()
12
Evaluating Test Data
# Loss and accuracy on the held-out test batches (returned as [loss, accuracy]).
model.evaluate(x=test_data)
Classification Report
# Collect predicted and true class indices over the whole test set, one batch
# at a time, into flat NumPy arrays.
predictions, labels = [], []
for X, y in test_data.as_numpy_iterator():
    # The model's last layer has one output per class; argmax along axis 1
    # picks the most likely class index for each image in the batch.
    y_pred = model.predict(X, verbose=0)
    y_prediction = y_pred.argmax(axis=1)
    predictions += list(y_prediction)
    labels += list(y)
predictions, labels = np.array(predictions), np.array(labels)
13
Moderate_Demented 1.00 1.00 1.00 5
Non_Demented 0.99 0.99 0.99 327
Very_Mild_Demented 0.98 0.99 0.98 224
Confusion Matrix
# Rows are true classes and columns are predicted classes, both indexed like
# class_names. Passing every class index keeps the matrix
# len(class_names) x len(class_names), matching the DataFrame labels, even if
# some class is absent from this test split. The default would drop such a
# class and make the DataFrame constructor fail.
cm = confusion_matrix(labels, predictions, labels=list(range(len(class_names))))
cm_df = pd.DataFrame(cm, index=class_names, columns=class_names)
cm_df
plt.figure(figsize=(10,6), dpi=300)
# Cells hold integer counts, so format them with "d"; ".1f" rendered e.g.
# 327 as "327.0".
sns.heatmap(cm_df, annot=True, cmap="Greys", fmt="d")
plt.title("Confusion Matrix", fontweight="bold")
plt.xlabel("Predicted", fontweight="bold")
plt.ylabel("True", fontweight="bold")
Let's create a function that picks a random image and draws a pie chart of the model's predicted probability for each target class, shown as percentages. This shows which class the model considers most likely.
Alzheimer Probability of a Random MRI from Test Data
14
# Show the model's class probabilities for one test MRI as a pie chart.
# NOTE(review): the lines that choose the image, compute `pred` and draw the
# chart are missing from this extract, and indentation was lost.
[50]: def random_mri_prob_bringer(image_number=0):
# NOTE(review): build_model() ends in a softmax layer. If `pred` is a
# model.predict() output, this applies softmax a second time and pushes the
# displayed probabilities toward uniform — drop tf.nn.softmax if so.
probs = list(tf.nn.softmax(pred).numpy())
# NOTE(review): model output index i corresponds to data.class_names[i]
# (alphanumeric folder order), but class_dist keys follow the order in which
# image_counter visited the folders. Unless those orders match, probabilities
# are attached to the wrong class names; class_names is the safer key list.
probs_dict = dict(zip(class_dist.keys(), probs))
keys = list(probs_dict.keys())
values = list(probs_dict.values())
# Hide the tick labels on the current axes.
plt.gca().axes.yaxis.set_ticklabels([])
plt.gca().axes.xaxis.set_ticklabels([])
15
Now, let’s see the actual classes and predicted classes of these samples by bringing samples from
our test data.
Comparing Predicted Classes with the Actual Classes from the Test Data
# Plot a 5x5 grid of images from the first test batch, each titled with its
# actual class and labelled with the predicted class. The label is
# springgreen when the prediction is correct and maroon when it is wrong.
# NOTE(review): the loop variables rebind the module-level `labels` and
# `predictions` arrays built for the classification report / confusion matrix.
# Re-running those cells afterwards would use this batch instead; consider
# distinct names.
# NOTE(review): `ok_text` is not defined in this extract. A line creating it
# (presumably a plt.text label) appears to be missing, as does indentation.
[51]: plt.figure(figsize=(20, 20), facecolor="gray")
for images, labels in test_data.take(1):
# Assumes the first test batch holds at least 25 images (the loader uses
# batch_size=32).
for i in range(25):
ax = plt.subplot(5, 5, i + 1)
plt.imshow(images[i])
# One predict() call per image; predicting the whole `images` batch once
# would be cheaper.
predictions = model.predict(tf.expand_dims(images[i], 0), verbose=0)
# The model already ends in softmax, so this second softmax is redundant.
# np.argmax is unaffected because softmax preserves ordering.
score = tf.nn.softmax(predictions[0])
if(class_names[labels[i]]==class_names[np.argmax(score)]):
plt.title("Actual: "+class_names[labels[i]], color="aqua",␣
↪fontweight="bold", fontsize=10)
plt.ylabel("Predicted: "+class_names[np.argmax(score)],␣
↪color="springgreen", fontweight="bold", fontsize=10)
ok_text.set_bbox(dict(facecolor='lime', alpha=0.5))
else:
plt.title("Actual: "+class_names[labels[i]], color="aqua",␣
↪fontweight="bold", fontsize=10)
plt.ylabel("Predicted: "+class_names[np.argmax(score)],␣
↪color="maroon", fontweight="bold", fontsize=10)
16
I hope you enjoyed it. Please upvote if you like this notebook. Any feedback is welcome. Thank you in advance!
17