
Alzheimers Classification 2

The document discusses classifying Alzheimer's disease from MRI images using a convolutional neural network (CNN). It covers the following points: (1) the dataset contains 6,400 MRIs labeled as mild, moderate, non or very mild dementia; (2) the data is imbalanced, with non-demented cases making up 50% of the images but moderate dementia only 1%; (3) samples of each class are shown, and the data is split into 80% training, 10% validation and 10% test data. Pixel values are normalized to the range 0 to 1, class weights are calculated to address the class imbalance during training, and the CNN model is built from convolutional and max pooling layers.
Copyright © All Rights Reserved

eimer-mri-classification-using-cnn

July 23, 2023

Alzheimer MR Images Classification


<ul>
<p>6400 different MRIs (Magnetic Resonance Image) collected from different sources were giv
<li>Mild Demented</li>
<li>Moderate Demented</li>
<li>Non Demented</li>
<li>Very Mild Demented</li>
<br>
<p>After examining the target value distribution, we will split the data into training, tes
</ul>

1 Importing Necessary Libraries


[1]: import os
import warnings

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import seaborn as sns

from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import classification_report, confusion_matrix

# Import Keras through TensorFlow only, instead of mixing it with the standalone `keras` package.
import tensorflow as tf
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Dense, Flatten, Conv2D, MaxPooling2D
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping

warnings.filterwarnings('ignore')

plt.rcParams["figure.figsize"] = (10,6)
plt.rcParams['figure.dpi'] = 300
colors = ["#B6EE56", "#D85F9C", "#EEA756", "#56EEE8"]

[2]: try:
    if tf.test.gpu_device_name():
        physical_devices = tf.config.list_physical_devices('GPU')
        print('GPU active! -', physical_devices)
    else:
        print('GPU not active!')
except Exception as e:
    print('An error occurred while checking the GPU:', e)

GPU active! - [PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

2 Target Value Distribution


The distribution of the target values determines which methods we need in the following steps (for example, whether class imbalance has to be handled). Therefore, let’s first go to the data folder and count the number of images in each class folder.
[34]: class_dist = {}

def image_counter(folder_path):
    """Count the image files in each class sub-folder and plot the class distribution."""
    basename = os.path.basename(folder_path)
    print('\033[92m' + f"A search has been initiated within the folder named '{basename}'." + '\033[0m')

    image_extensions = ['.jpg', '.jpeg', '.png']

    for root, dirs, _ in os.walk(folder_path):
        for dir_name in dirs:
            dir_path = os.path.join(root, dir_name)
            count = 0

            # Count only files with an image extension.
            for filename in os.listdir(dir_path):
                file_ext = os.path.splitext(filename)[1].lower()
                if file_ext in image_extensions:
                    count += 1

            class_dist[dir_name] = count
            print(f"There are \033[35m{count}\033[0m images in the {dir_name} folder.")

    print('\033[92m' + "The search has been completed." + '\033[0m')

    # Pie chart of the number of images per class.
    keys = list(class_dist.keys())
    values = list(class_dist.values())
    explode = (0.1,) * len(keys)

    labels = [f'{key} ({value} images)' for key, value in zip(keys, values)]

    plt.pie(values, explode=explode, labels=labels, autopct='%1.1f%%',
            shadow=True, startangle=90, colors=colors,
            textprops={'fontsize': 12, "fontweight": "bold", "color": "darkblue"},
            wedgeprops={'edgecolor': 'darkblue'}, labeldistance=1.15)
    plt.title("Distribution of \nAlzheimer MRI Images", size=12, fontweight="bold")

PATH = '/content/drive/MyDrive/Dataset'

image_counter(PATH)

A search has been initiated within the folder named 'Dataset'.


There are 2240 images in the Very_Mild_Demented folder.
There are 64 images in the Moderate_Demented folder.
There are 3200 images in the Non_Demented folder.
There are 896 images in the Mild_Demented folder.
The search has been completed.

As the class distribution shows, the dataset is imbalanced: the Non Demented class makes up 50% of the data with 3200 images, while the Moderate Demented class makes up only 1% of the dataset with 64 images.
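The percentages above follow directly from the folder counts printed by image_counter. A minimal stdlib sketch, with the counts copied by hand into a dictionary (the name `counts` is only illustrative), reproduces them together with the ratio between the largest and the smallest class:

```python
# Image counts per class folder, copied from the image_counter output above.
counts = {
    "Very_Mild_Demented": 2240,
    "Moderate_Demented": 64,
    "Non_Demented": 3200,
    "Mild_Demented": 896,
}

total = sum(counts.values())

# Share of each class as a percentage of the whole dataset.
shares = {name: 100 * n / total for name, n in counts.items()}

for name, pct in shares.items():
    print(f"{name:<20} {pct:5.1f}%")

# Ratio between the majority class and the minority class.
imbalance_ratio = max(counts.values()) / min(counts.values())
print(f"Imbalance ratio: {imbalance_ratio:.0f}:1")
```

A 50:1 ratio also means that a model which always predicted Non_Demented would already reach about 50% accuracy, so per-class metrics matter more than overall accuracy on this data.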

3 Generate TF Dataset
When examining the data path, we can see that the folder named “Dataset” is the main folder, and within it there is a separate folder for each class containing its images:

Dataset/
    Mild_Demented/
        mild_1.jpg
        mild_2.jpg
    Moderate_Demented/
        moderate_1.jpg
        moderate_2.jpg
For a directory laid out like this, TensorFlow provides tf.keras.utils.image_dataset_from_directory, which reads the images, infers each image's label from the name of its class folder, and returns a batched tf.data.Dataset.
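With the default labels='inferred', image_dataset_from_directory treats each sub-folder as a class and, unless class_names is passed explicitly, numbers the classes in alphanumeric order of the folder names. The sketch below mimics only that naming step with the standard library, on a throwaway directory tree; the helper `infer_class_indices` is hypothetical and does not read or decode any images.

```python
import os
import tempfile

def infer_class_indices(root):
    """Map each immediate sub-folder of `root` to an integer label,
    using alphanumeric order of the folder names (hypothetical helper)."""
    class_names = sorted(
        entry for entry in os.listdir(root)
        if os.path.isdir(os.path.join(root, entry))
    )
    return {name: index for index, name in enumerate(class_names)}

# Build a tiny throwaway tree shaped like the Dataset/ folder (no image files needed).
with tempfile.TemporaryDirectory() as root:
    for folder in ["Very_Mild_Demented", "Non_Demented",
                   "Moderate_Demented", "Mild_Demented"]:
        os.makedirs(os.path.join(root, folder))
    mapping = infer_class_indices(root)

print(mapping)
```

This alphanumeric order is the order of data.class_names. It differs from the os.walk order in which image_counter filled class_dist, so class_names, not class_dist, is the list to use whenever an integer label or a prediction index is turned back into a class name.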
[35]: data = tf.keras.utils.image_dataset_from_directory(PATH,
batch_size = 32,
image_size=(128, 128),
shuffle=True,
seed=42,)

class_names = data.class_names

Found 6400 files belonging to 4 classes.

[5]: from google.colab import drive


drive.mount('/content/drive')

Mounted at /content/drive
Let’s see some samples for each class!
MRI Samples for Each Class
[36]: def sample_bringer(path, target, num_samples=5):
    class_path = os.path.join(path, target)

    image_files = [image for image in os.listdir(class_path) if image.endswith('.jpg')]

    fig, ax = plt.subplots(1, num_samples, facecolor="gray")
    fig.suptitle(f'{target} Brain MRI Samples', color="yellow", fontsize=16,
                 fontweight='bold', y=0.75)

    for i in range(num_samples):
        image_path = os.path.join(class_path, image_files[i])
        img = mpimg.imread(image_path)

        ax[i].imshow(img)
        ax[i].axis('off')
        ax[i].set_title(f'Sample {i+1}', color="aqua")

    plt.tight_layout()

for target in class_names:
    sample_bringer(PATH, target=target)

Scaling pixel values to a small, common range generally helps a neural network train more stably. Therefore, we will rescale the pixel values from the range 0 to 255 to the range 0 to 1.
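The rescaling is a single element-wise division by 255, which maps 0 to 0.0 and 255 to 1.0. A tiny stdlib sketch of the same arithmetic on a few made-up 8-bit intensities:

```python
# A few example 8-bit pixel intensities (illustrative values only).
pixels = [0, 64, 128, 255]

# Same operation as the x / 255 mapping applied to every image batch.
scaled = [p / 255 for p in pixels]

print(scaled)
print(min(scaled), max(scaled))
```

image_dataset_from_directory already yields float32 image tensors, so the division in the Process class below operates on floats and needs no explicit cast.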
[37]: # Map integer labels back to class names.
alz_dict = {index: img for index, img in enumerate(data.class_names)}

class Process:
    def __init__(self, data):
        # Rescale pixel values from [0, 255] to [0, 1].
        self.data = data.map(lambda x, y: (x / 255, y))

    def create_new_batch(self):
        self.batch = next(self.data.as_numpy_iterator())
        text = "Min and max pixel values in the batch ->"
        print(text, self.batch[0].min(), "&", self.batch[0].max())

    def show_batch_images(self, number_of_images=5):
        fig, ax = plt.subplots(ncols=number_of_images, figsize=(20, 20), facecolor="gray")
        fig.suptitle("Brain MRI (Alzheimer) Samples in the Batch",
                     color="yellow", fontsize=18, fontweight='bold', y=0.6)

        for idx, img in enumerate(self.batch[0][:number_of_images]):
            ax[idx].imshow(img)
            class_no = self.batch[1][idx]
            ax[idx].set_title(alz_dict[class_no], color="aqua")
            ax[idx].set_xticklabels([])
            ax[idx].set_yticklabels([])

    def train_test_val_split(self, train_size, val_size, test_size):
        # The fractions are applied to the number of batches, not images.
        train = int(len(self.data) * train_size)
        test = int(len(self.data) * test_size)
        val = int(len(self.data) * val_size)

        train_data = self.data.take(train)
        val_data = self.data.skip(train).take(val)
        test_data = self.data.skip(train + val).take(test)

        return train_data, val_data, test_data

[38]: process = Process(data)


process.create_new_batch()
process.show_batch_images(number_of_images=5)

Min and max pixel values in the batch -> 0.0 & 1.0

We will divide the dataset into 80% training data, 10% validation data and 10% test data.
[39]: train_data, val_data, test_data = process.train_test_val_split(train_size=0.8, val_size=0.1, test_size=0.1)
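Because train_test_val_split works on the batched dataset, the fractions are applied to the number of batches, not to individual images. With 6400 images and batch_size=32 there are 200 batches, giving 160 training, 20 validation and 20 test batches, which matches the 160 steps per epoch and the 20 evaluation steps in the logs further down. A quick stdlib check of that arithmetic (`split_batches` is a hypothetical helper that reproduces the int() truncation used in the method):

```python
import math

def split_batches(num_images, batch_size, train_size, val_size, test_size):
    """Reproduce the batch-level split arithmetic of train_test_val_split."""
    num_batches = math.ceil(num_images / batch_size)
    train = int(num_batches * train_size)
    val = int(num_batches * val_size)
    test = int(num_batches * test_size)
    return num_batches, train, val, test

num_batches, train, val, test = split_batches(6400, 32, 0.8, 0.1, 0.1)
print(num_batches, train, val, test)
```

One caveat worth verifying for your TensorFlow version: with shuffle=True, image_dataset_from_directory applies a local shuffle that is repeated on every pass over the dataset, so take/skip splits may not contain exactly the same images each time they are iterated, and a few images can cross between the training, validation and test splits. This may be why model.evaluate reports a test accuracy of 0.9844 while the classification report, computed from a second pass over test_data with the same model, rounds to 0.99. Splitting the list of image files once, with a fixed seed, before building the datasets avoids this.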

The target classes are imbalanced. In this situation, class weights can help the model recognize the minority classes: during training, each sample's contribution to the loss is multiplied by the weight of its class, so mistakes on rare classes are penalized more heavily. Therefore, let’s calculate the class weights from our training data and pass them to the model during training.
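[41]: # Collect the integer labels of all training batches into one tensor.
y_train = tf.concat(list(map(lambda x: x[1], train_data)), axis=0)
class_weight = compute_class_weight('balanced', classes=np.unique(y_train), y=y_train.numpy())

# Keras expects a {class_index: weight} dictionary.
class_weights = dict(zip(np.unique(y_train), class_weight))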
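With the 'balanced' option, compute_class_weight gives each class the weight n_samples / (n_classes * count_of_that_class), so rare classes receive proportionally larger weights. The notebook applies this to the training labels; purely as an illustration, the sketch below applies the same formula to the full-dataset counts from the class distribution section, so its numbers are close to, but not exactly, the weights used in training.

```python
# Full-dataset counts per class, in the label order used by the model
# (alphanumeric order of the folder names).
counts = {
    "Mild_Demented": 896,
    "Moderate_Demented": 64,
    "Non_Demented": 3200,
    "Very_Mild_Demented": 2240,
}

n_samples = sum(counts.values())
n_classes = len(counts)

# 'balanced' heuristic: n_samples / (n_classes * class_count), keyed by integer label.
balanced_weights = {
    label: n_samples / (n_classes * count)
    for label, count in enumerate(counts.values())
}

for name, (label, weight) in zip(counts, balanced_weights.items()):
    print(f"{label} {name:<20} {weight:.3f}")
```

In other words, every class contributes the same total weight (n_samples / n_classes, i.e. 1600 here) to the loss, and a single Moderate_Demented image counts 50 times as much as a single Non_Demented image.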

4 Model Building
[42]: def build_model():
    model = Sequential()

    # Three convolution + max-pooling blocks (3x3 kernels, default 'valid' padding).
    model.add(Conv2D(filters=16, kernel_size=(3, 3), strides=(1, 1), activation="relu",
                     kernel_initializer='he_normal', input_shape=(128, 128, 3)))
    model.add(MaxPooling2D(pool_size=(2, 2)))

    model.add(Conv2D(filters=32, kernel_size=(3, 3), strides=(1, 1), activation="relu",
                     kernel_initializer='he_normal'))
    model.add(MaxPooling2D(pool_size=(2, 2)))

    model.add(Conv2D(filters=128, kernel_size=(3, 3), strides=(1, 1), activation="relu",
                     kernel_initializer='he_normal'))
    model.add(MaxPooling2D(pool_size=(2, 2)))

    # Classifier head ending in 4 softmax outputs, one per class.
    model.add(Flatten())
    model.add(Dense(128, activation="relu", kernel_initializer='he_normal'))
    model.add(Dense(64, activation="relu"))
    model.add(Dense(4, activation="softmax"))

    # Labels are integers (not one-hot), hence sparse categorical cross-entropy.
    model.compile(optimizer='adam', loss="sparse_categorical_crossentropy", metrics=['accuracy'])

    model.summary()

    return model

model = build_model()

Model: "sequential"
_________________________________________________________________
 Layer (type)                      Output Shape           Param #
=================================================================
 conv2d (Conv2D)                   (None, 126, 126, 16)   448
 max_pooling2d (MaxPooling2D)      (None, 63, 63, 16)     0
 conv2d_1 (Conv2D)                 (None, 61, 61, 32)     4640
 max_pooling2d_1 (MaxPooling2D)    (None, 30, 30, 32)     0
 conv2d_2 (Conv2D)                 (None, 28, 28, 128)    36992
 max_pooling2d_2 (MaxPooling2D)    (None, 14, 14, 128)    0
 flatten (Flatten)                 (None, 25088)          0
 dense (Dense)                     (None, 128)            3211392
 dense_1 (Dense)                   (None, 64)             8256
 dense_2 (Dense)                   (None, 4)              260
=================================================================
Total params: 3,261,988
Trainable params: 3,261,988
Non-trainable params: 0
_________________________________________________________________
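The shapes and parameter counts in the summary can be checked by hand: a 3x3 convolution with 'valid' padding and stride 1 shrinks each spatial side by 2, 2x2 max pooling halves it (rounding down), a Conv2D layer has (3 * 3 * in_channels + 1) * filters parameters, and a Dense layer has (inputs + 1) * units. A stdlib sketch of that arithmetic for this architecture (the helper names are illustrative):

```python
def conv_out(size, kernel=3, stride=1):
    """Output side length of a 'valid' (unpadded) convolution."""
    return (size - kernel) // stride + 1

def pool_out(size, pool=2):
    """Output side length of max pooling whose stride equals the pool size."""
    return size // pool

size, channels = 128, 3
total_params = 0

# Three Conv2D (3x3) + MaxPooling2D (2x2) stages with 16, 32 and 128 filters.
for filters in (16, 32, 128):
    total_params += (3 * 3 * channels + 1) * filters  # kernel weights + one bias per filter
    size = pool_out(conv_out(size))
    channels = filters

flat = size * size * channels  # length of the Flatten output

# Dense layers with 128, 64 and 4 units, each with one bias per unit.
inputs = flat
for units in (128, 64, 4):
    total_params += (inputs + 1) * units
    inputs = units

print(size, flat, total_params)
```

The first Dense layer alone holds 3,211,392 of the 3,261,988 parameters (about 98%), so most of the model's capacity sits in the layer directly after Flatten.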
Callbacks
[43]: def make_checkpoint_callback():
    checkpoint_filepath = '/tmp/checkpoint'

    # Save the full model whenever val_accuracy reaches a new best value.
    model_checkpoint_callback = ModelCheckpoint(filepath=checkpoint_filepath,
                                                save_weights_only=False,
                                                save_freq='epoch',
                                                monitor='val_accuracy',
                                                save_best_only=True,
                                                verbose=1)
    return model_checkpoint_callback

def make_early_stopping(patience):
    # Stop when val_loss has not improved for `patience` consecutive epochs.
    # restore_best_weights is left at its default (False), so the model keeps
    # the weights of the last epoch it trained.
    es_callback = EarlyStopping(monitor='val_loss', patience=patience, verbose=1)
    return es_callback

EPOCHS = 20
checkpoint_callback = make_checkpoint_callback()
early_stopping = make_early_stopping(patience=5)
callbacks = [checkpoint_callback, early_stopping]
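EarlyStopping with patience=5 stops training once val_loss has failed to improve on its best value for five consecutive epochs. A simplified stdlib sketch of that rule (not Keras's actual implementation; min_delta is assumed to be 0, the default), fed with the val_loss values printed in the training log below (rounded to four decimals), reproduces the stop at epoch 15:

```python
def early_stop_epoch(val_losses, patience):
    """Return the 1-based epoch at which training would stop, or None.

    Simplified patience-based early stopping for a metric that should
    decrease, with min_delta = 0."""
    best = float("inf")
    wait = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best:
            best = loss
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None

# val_loss per epoch, copied from the training log (4 decimals).
val_losses = [0.6907, 0.4266, 0.1392, 0.1029, 0.0501, 0.0271, 0.0149,
              0.0111, 0.0130, 0.0101, 0.0122, 0.0128, 0.0105, 0.0172, 0.0124]

stop = early_stop_epoch(val_losses, patience=5)
best_epoch = val_losses.index(min(val_losses)) + 1
print(stop, best_epoch)
```

Because restore_best_weights stays at its default (False), the model object evaluated afterwards carries the epoch-15 weights; the best model by val_accuracy (epoch 10) exists only as the checkpoint saved to /tmp/checkpoint.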

[45]: history = model.fit(train_data, epochs=EPOCHS, validation_data=val_data,
                    class_weight=class_weights, callbacks=callbacks)

Epoch 1/20
160/160 [==============================] - ETA: 0s - loss: 0.8399 - accuracy:
0.5764
Epoch 1: val_accuracy improved from -inf to 0.70156, saving model to
/tmp/checkpoint
WARNING:absl:Found untraced functions such as _jit_compiled_convolution_op,
_jit_compiled_convolution_op, _jit_compiled_convolution_op while saving (showing
3 of 3). These functions will not be directly callable after loading.
160/160 [==============================] - 32s 199ms/step - loss: 0.8399 -
accuracy: 0.5764 - val_loss: 0.6907 - val_accuracy: 0.7016
Epoch 2/20
159/160 [============================>.] - ETA: 0s - loss: 0.3954 - accuracy:
0.7531
Epoch 2: val_accuracy improved from 0.70156 to 0.80625, saving model to
/tmp/checkpoint
WARNING:absl:Found untraced functions such as _jit_compiled_convolution_op,
_jit_compiled_convolution_op, _jit_compiled_convolution_op while saving (showing
3 of 3). These functions will not be directly callable after loading.
160/160 [==============================] - 15s 93ms/step - loss: 0.3947 -
accuracy: 0.7529 - val_loss: 0.4266 - val_accuracy: 0.8062
Epoch 3/20
159/160 [============================>.] - ETA: 0s - loss: 0.1618 - accuracy:
0.9135
Epoch 3: val_accuracy improved from 0.80625 to 0.95781, saving model to
/tmp/checkpoint
WARNING:absl:Found untraced functions such as _jit_compiled_convolution_op,
_jit_compiled_convolution_op, _jit_compiled_convolution_op while saving (showing
3 of 3). These functions will not be directly callable after loading.
160/160 [==============================] - 14s 89ms/step - loss: 0.1614 -
accuracy: 0.9139 - val_loss: 0.1392 - val_accuracy: 0.9578
Epoch 4/20
160/160 [==============================] - ETA: 0s - loss: 0.0535 - accuracy:
0.9779
Epoch 4: val_accuracy improved from 0.95781 to 0.96406, saving model to
/tmp/checkpoint
WARNING:absl:Found untraced functions such as _jit_compiled_convolution_op,
_jit_compiled_convolution_op, _jit_compiled_convolution_op while saving (showing
3 of 3). These functions will not be directly callable after loading.
160/160 [==============================] - 19s 118ms/step - loss: 0.0535 -
accuracy: 0.9779 - val_loss: 0.1029 - val_accuracy: 0.9641
Epoch 5/20
160/160 [==============================] - ETA: 0s - loss: 0.0261 - accuracy:
0.9918
Epoch 5: val_accuracy improved from 0.96406 to 0.98438, saving model to
/tmp/checkpoint
WARNING:absl:Found untraced functions such as _jit_compiled_convolution_op,
_jit_compiled_convolution_op, _jit_compiled_convolution_op while saving (showing
3 of 3). These functions will not be directly callable after loading.
160/160 [==============================] - 74s 466ms/step - loss: 0.0261 -
accuracy: 0.9918 - val_loss: 0.0501 - val_accuracy: 0.9844
Epoch 6/20
159/160 [============================>.] - ETA: 0s - loss: 0.0069 - accuracy:
0.9994
Epoch 6: val_accuracy improved from 0.98438 to 0.99375, saving model to
/tmp/checkpoint
WARNING:absl:Found untraced functions such as _jit_compiled_convolution_op,
_jit_compiled_convolution_op, _jit_compiled_convolution_op while saving (showing
3 of 3). These functions will not be directly callable after loading.
160/160 [==============================] - 14s 88ms/step - loss: 0.0068 -
accuracy: 0.9994 - val_loss: 0.0271 - val_accuracy: 0.9937
Epoch 7/20
160/160 [==============================] - ETA: 0s - loss: 0.0015 - accuracy:
1.0000
Epoch 7: val_accuracy improved from 0.99375 to 0.99531, saving model to
/tmp/checkpoint
WARNING:absl:Found untraced functions such as _jit_compiled_convolution_op,
_jit_compiled_convolution_op, _jit_compiled_convolution_op while saving (showing
3 of 3). These functions will not be directly callable after loading.
160/160 [==============================] - 14s 89ms/step - loss: 0.0015 -
accuracy: 1.0000 - val_loss: 0.0149 - val_accuracy: 0.9953
Epoch 8/20
160/160 [==============================] - ETA: 0s - loss: 5.9136e-04 -
accuracy: 1.0000
Epoch 8: val_accuracy did not improve from 0.99531
160/160 [==============================] - 17s 106ms/step - loss: 5.9136e-04 -
accuracy: 1.0000 - val_loss: 0.0111 - val_accuracy: 0.9953
Epoch 9/20
160/160 [==============================] - ETA: 0s - loss: 3.5138e-04 -
accuracy: 1.0000
Epoch 9: val_accuracy did not improve from 0.99531
160/160 [==============================] - 14s 84ms/step - loss: 3.5138e-04 -
accuracy: 1.0000 - val_loss: 0.0130 - val_accuracy: 0.9953
Epoch 10/20
160/160 [==============================] - ETA: 0s - loss: 2.7242e-04 -
accuracy: 1.0000
Epoch 10: val_accuracy improved from 0.99531 to 0.99687, saving model to
/tmp/checkpoint
WARNING:absl:Found untraced functions such as _jit_compiled_convolution_op,
_jit_compiled_convolution_op, _jit_compiled_convolution_op while saving (showing
3 of 3). These functions will not be directly callable after loading.
160/160 [==============================] - 14s 86ms/step - loss: 2.7242e-04 -
accuracy: 1.0000 - val_loss: 0.0101 - val_accuracy: 0.9969
Epoch 11/20
160/160 [==============================] - ETA: 0s - loss: 2.1518e-04 -
accuracy: 1.0000
Epoch 11: val_accuracy did not improve from 0.99687
160/160 [==============================] - 18s 109ms/step - loss: 2.1518e-04 -
accuracy: 1.0000 - val_loss: 0.0122 - val_accuracy: 0.9969
Epoch 12/20
160/160 [==============================] - ETA: 0s - loss: 1.7408e-04 -
accuracy: 1.0000
Epoch 12: val_accuracy did not improve from 0.99687
160/160 [==============================] - 13s 77ms/step - loss: 1.7408e-04 -
accuracy: 1.0000 - val_loss: 0.0128 - val_accuracy: 0.9937
Epoch 13/20
160/160 [==============================] - ETA: 0s - loss: 1.3730e-04 -
accuracy: 1.0000
Epoch 13: val_accuracy did not improve from 0.99687
160/160 [==============================] - 13s 77ms/step - loss: 1.3730e-04 -
accuracy: 1.0000 - val_loss: 0.0105 - val_accuracy: 0.9953
Epoch 14/20
159/160 [============================>.] - ETA: 0s - loss: 1.1478e-04 -
accuracy: 1.0000
Epoch 14: val_accuracy did not improve from 0.99687
160/160 [==============================] - 13s 77ms/step - loss: 1.1823e-04 -
accuracy: 1.0000 - val_loss: 0.0172 - val_accuracy: 0.9922
Epoch 15/20
160/160 [==============================] - ETA: 0s - loss: 1.0131e-04 -
accuracy: 1.0000
Epoch 15: val_accuracy did not improve from 0.99687
160/160 [==============================] - 17s 104ms/step - loss: 1.0131e-04 -
accuracy: 1.0000 - val_loss: 0.0124 - val_accuracy: 0.9937
Epoch 15: early stopping
Loss and Accuracy
[46]: fig, ax = plt.subplots(1, 2, figsize=(12, 6), facecolor="khaki")

ax[0].set_facecolor('palegoldenrod')
ax[0].set_title('Loss', fontweight="bold")
ax[0].set_xlabel("Epoch", size=14)
ax[0].plot(history.epoch, history.history["loss"], label="Train Loss", color="navy")
ax[0].plot(history.epoch, history.history["val_loss"], label="Validation Loss",
           color="crimson", linestyle="dashed")
ax[0].legend()

ax[1].set_facecolor('palegoldenrod')
ax[1].set_title('Accuracy', fontweight="bold")
ax[1].set_xlabel("Epoch", size=14)
ax[1].plot(history.epoch, history.history["accuracy"], label="Train Acc.", color="navy")
ax[1].plot(history.epoch, history.history["val_accuracy"], label="Validation Acc.",
           color="crimson", linestyle="dashed")
ax[1].legend()

[46]: <matplotlib.legend.Legend at 0x783ebd55a950>

Evaluating Test Data
[47]: model.evaluate(test_data)

20/20 [==============================] - 45s 2s/step - loss: 0.0348 - accuracy: 0.9844

[47]: [0.0348445363342762, 0.984375]

Classification Report
[48]: predictions = []
labels = []

# Collect predicted and true labels batch by batch.
for X, y in test_data.as_numpy_iterator():
    y_pred = model.predict(X, verbose=0)
    y_prediction = np.argmax(y_pred, axis=1)
    predictions.extend(y_prediction)
    labels.extend(y)

predictions = np.array(predictions)
labels = np.array(labels)

print(classification_report(labels, predictions, target_names=class_names))

                   precision    recall  f1-score   support

     Mild_Demented      1.00      0.98      0.99        84
 Moderate_Demented      1.00      1.00      1.00         5
      Non_Demented      0.99      0.99      0.99       327
Very_Mild_Demented      0.98      0.99      0.98       224

          accuracy                          0.99       640
         macro avg      0.99      0.99      0.99       640
      weighted avg      0.99      0.99      0.99       640

Confusion Matrix
[49]: cm = confusion_matrix(labels, predictions)
cm_df = pd.DataFrame(cm, index=class_names, columns=class_names)

plt.figure(figsize=(10, 6), dpi=300)
# The cells hold integer counts, so annotate them as integers.
sns.heatmap(cm_df, annot=True, cmap="Greys", fmt="d")
plt.title("Confusion Matrix", fontweight="bold")
plt.xlabel("Predicted", fontweight="bold")
plt.ylabel("True", fontweight="bold")

[49]: Text(286.1666666666666, 0.5, 'True')
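Each row of the confusion matrix is a true class and each column a predicted class (the orientation used by sklearn's confusion_matrix), so the precision in the classification report is a diagonal entry divided by its column sum, and the recall is the diagonal entry divided by its row sum (the support). A stdlib sketch on a small made-up 3-class matrix, not the notebook's results:

```python
# Hypothetical 3-class confusion matrix (made-up counts, not the notebook's results).
# Rows are true classes, columns are predicted classes.
cm = [
    [9, 1, 0],
    [2, 7, 1],
    [0, 0, 5],
]

n = len(cm)
report = []  # (precision, recall, f1, support) per class

for k in range(n):
    tp = cm[k][k]                                  # correctly predicted as class k
    predicted_k = sum(cm[i][k] for i in range(n))  # column sum: all predictions of class k
    support_k = sum(cm[k])                         # row sum: all true members of class k
    precision = tp / predicted_k
    recall = tp / support_k
    f1 = 2 * precision * recall / (precision + recall)
    report.append((precision, recall, f1, support_k))
    print(f"class {k}: precision={precision:.2f} recall={recall:.2f} "
          f"f1={f1:.2f} support={support_k}")

accuracy = sum(cm[k][k] for k in range(n)) / sum(sum(row) for row in cm)
print(f"accuracy={accuracy:.2f}")
```

With a support of only 5 Moderate_Demented images in the test split, a single misclassification would already move that class's recall from 1.00 to 0.80, so its perfect scores should be read with that sample size in mind.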

Let’s create a function that takes an image from the test data and displays a pie chart of the model's predicted class probabilities as percentages. This shows which class the model assigns the highest probability to, and how confident it is.
Alzheimer Probability of a Random MRI from Test Data

[50]: def random_mri_prob_bringer(image_number=0):
    # Take the 6th batch of the test split and pick one image from it.
    for images, _ in test_data.skip(5).take(1):
        image = images[image_number]

    # The model's last layer already applies softmax, so this is a probability
    # vector; applying tf.nn.softmax again would distort the percentages.
    probs = model.predict(tf.expand_dims(image, 0))[0]

    # Use class_names (the label order the model was trained with), not
    # class_dist, whose keys follow os.walk order.
    keys = list(class_names)
    values = list(probs)

    fig, (ax1, ax2) = plt.subplots(1, 2, facecolor='black')
    plt.subplots_adjust(wspace=0.4)

    ax1.imshow(image)
    ax1.set_title('Brain MRI', color="yellow", fontweight="bold", fontsize=16)

    # Draw a colored frame around the MRI and hide its ticks.
    for edge in ['left', 'bottom', 'right', 'top']:
        ax1.spines[edge].set_linewidth(3)
        ax1.spines[edge].set_edgecolor("greenyellow")
    ax1.set_xticks([])
    ax1.set_yticks([])

    wedges, texts, autotexts = ax2.pie(values, labels=keys, autopct='%1.1f%%',
                                       shadow=True, startangle=90, colors=colors,
                                       textprops={'fontsize': 8, "fontweight": "bold", "color": "white"},
                                       wedgeprops={'edgecolor': 'black'}, labeldistance=1.15)
    for autotext in autotexts:
        autotext.set_color('black')

    ax2.set_title('Alzheimer Probabilities', color="yellow", fontweight="bold", fontsize=16)

# The upper bound of np.random.randint is exclusive, so this picks one of the
# 32 images (indices 0 to 31) in a full batch.
rand_img_no = np.random.randint(0, 32)

random_mri_prob_bringer(image_number=rand_img_no)

1/1 [==============================] - 0s 177ms/step
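The model's last Dense layer already uses a softmax activation, so model.predict returns probabilities that sum to 1. Passing those probabilities through another softmax (for example tf.nn.softmax(pred)) does not leave them intact: it squeezes them toward a uniform distribution, which would make a confident prediction look uncertain in the pie chart, although the argmax, and therefore the predicted class, stays the same. A stdlib sketch with made-up scores:

```python
import math

def softmax(values):
    """Numerically stable softmax over a list of floats."""
    m = max(values)
    exps = [math.exp(v - m) for v in values]
    total = sum(exps)
    return [e / total for e in exps]

# Made-up raw scores (logits) for 4 classes.
logits = [6.0, 1.0, 0.5, 0.0]

once = softmax(logits)   # what a softmax output layer produces
twice = softmax(once)    # applying softmax again to those probabilities

print([round(p, 3) for p in once])
print([round(p, 3) for p in twice])
```

Since softmax outputs lie between 0 and 1, a second pass can never recreate large gaps between the exponents; with four classes, its largest value always stays below e / (e + 3), about 0.475.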

Now, let’s see the actual classes and predicted classes of these samples by bringing samples from
our test data.
Comparing Predicted Classes with the Actual Classes from the Test Data
[51]: plt.figure(figsize=(20, 20), facecolor="gray")
# `batch_labels` avoids overwriting the `labels` array used for the confusion matrix.
for images, batch_labels in test_data.take(1):
    # Predict the whole batch at once; the softmax output is already a probability vector.
    preds = model.predict(images, verbose=0)
    for i in range(25):
        ax = plt.subplot(5, 5, i + 1)
        plt.imshow(images[i])
        actual = class_names[int(batch_labels[i])]
        predicted = class_names[np.argmax(preds[i])]

        plt.title("Actual: " + actual, color="aqua", fontweight="bold", fontsize=10)
        if actual == predicted:
            plt.ylabel("Predicted: " + predicted, color="springgreen", fontweight="bold", fontsize=10)
            ok_text = plt.text(2, 10, "OK \u2714", color="springgreen", fontsize=14)
            ok_text.set_bbox(dict(facecolor='lime', alpha=0.5))
        else:
            plt.ylabel("Predicted: " + predicted, color="maroon", fontweight="bold", fontsize=10)
            nok_text = plt.text(2, 10, "NOK \u2718", color="red", fontsize=14)
            nok_text.set_bbox(dict(facecolor='maroon', alpha=0.5))

        plt.gca().axes.yaxis.set_ticklabels([])
        plt.gca().axes.xaxis.set_ticklabels([])

I hope you enjoyed it. Please upvote if you like this notebook. Any feedback is welcome. Thank you in advance!
