# -*- coding: utf-8 -*-
"""TensorFlow 2 [Link]
Automatically generated by Colaboratory.
Original file is located at
[Link]
"""
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow import keras
# Load the MNIST digit dataset and scale pixel values from [0, 255] to [0, 1].
mnist = keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# Dividing the raw integer arrays by 255.0 would produce float64; casting to
# float32 first halves memory and matches Keras' default float dtype.
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0
def visualize(images, labels, w=2, h=5, width=6, height=3, color='gray',
              show_axis='off'):
    """Display images in a ``w`` x ``h`` grid, each titled with its label.

    Args:
        images: Sequence of images accepted by ``Axes.imshow``. At most
            ``w * h`` of them are drawn.
        labels: Titles, paired with ``images`` by position.
        w: Number of grid rows.
        h: Number of grid columns.
        width: Figure width in inches.
        height: Figure height in inches.
        color: Colormap name forwarded to ``imshow`` as ``cmap``.
        show_axis: Argument forwarded to ``Axes.axis`` for every drawn image
            (``'off'`` hides ticks and frame).
    """
    # squeeze=False always yields a 2-D array of Axes, so flatten() also
    # works for a 1 x 1 grid (where the default returns a bare Axes).
    _, axes = plt.subplots(w, h, figsize=(width, height), squeeze=False)
    cells = axes.flatten()
    drawn = 0
    # zip stops at the shortest input, so passing more images than grid
    # cells no longer raises IndexError.
    for ax, image, label in zip(cells, images, labels):
        ax.imshow(image, cmap=color)
        ax.axis(show_axis)
        ax.set_title(label)
        drawn += 1
    # Blank out grid cells that received no image.
    for ax in cells[drawn:]:
        ax.axis('off')
    plt.tight_layout()
    plt.show()
# Preview the first ten digits of each split; ten fills the default 2 x 5
# grid drawn by visualize().
for banner, split_images, split_labels in (
    ('Display sample of train images', x_train, y_train),
    ('Display sample of test images', x_test, y_test),
):
    print(banner)
    visualize(split_images[:10], split_labels[:10])
def model_(input_shape=(28, 28), activations=('relu', 'softmax'),
           hidden_units=128, num_classes=10):
    """Build an uncompiled feed-forward classifier.

    The network flattens each input sample, applies one fully connected
    hidden layer, and ends in a fully connected output layer.

    Args:
        input_shape: Shape of a single input sample, excluding the batch
            dimension.
        activations: Pair of activation names: ``activations[0]`` for the
            hidden layer, ``activations[1]`` for the output layer. A tuple
            default avoids sharing a mutable list between calls.
        hidden_units: Number of units in the hidden layer.
        num_classes: Number of units in the output layer.

    Returns:
        A ``keras.Sequential`` model; call ``compile`` before training.
    """
    hidden_activation, output_activation = activations[0], activations[1]
    return keras.Sequential([
        # An explicit Input layer replaces the deprecated `input_shape=`
        # keyword on the first layer.
        keras.Input(shape=input_shape),
        keras.layers.Flatten(),
        keras.layers.Dense(hidden_units, activation=hidden_activation),
        keras.layers.Dense(num_classes, activation=output_activation),
    ])
def model_train(model, data=None, epochs=10, optimizer='adam',
                loss='sparse_categorical_crossentropy', metrics=None):
    """Compile ``model`` and fit it on ``data``.

    Args:
        model: Object exposing Keras-style ``compile`` and ``fit`` methods.
        data: Two-item sequence; ``data[0]`` (inputs) and ``data[1]``
            (targets) are passed positionally to ``model.fit``.
        epochs: Number of training epochs forwarded to ``fit``.
        optimizer: Optimizer identifier or instance forwarded to ``compile``.
        loss: Loss identifier or callable forwarded to ``compile``.
        metrics: Metrics forwarded to ``compile``; ``None`` means
            ``['accuracy']``. A fresh list is built per call, so a mutated
            default cannot leak into later calls.

    Returns:
        The same ``model`` object, now compiled and trained.

    Raises:
        ValueError: If ``data`` is not provided.
    """
    if data is None:
        # The old default of empty lists could never be fitted; fail with a
        # clear message instead.
        raise ValueError('data must be an (inputs, targets) pair')
    if metrics is None:
        metrics = ['accuracy']
    model.compile(optimizer=optimizer,
                  loss=loss,
                  metrics=metrics)
    model.fit(data[0], data[1], epochs=epochs)
    return model
# Build the default classifier, then compile and fit it on the normalized
# training split for ten epochs.
model = model_()
print('Train the model')
model = model_train(model, [x_train, y_train], 10)
print('Test Evaluation')
# Score the trained model on the held-out test split; verbose=2 prints one
# summary line per epoch instead of a progress bar.
test_loss, test_acc = model.evaluate(x_test, y_test, verbose=2)
print('Test accuracy:', test_acc)
print('Make predictions on the test dataset')
# One row per test image holding the output layer's values (softmax scores
# with model_'s default activations).
predictions = model.predict(x_test)
print('Getting the predicted labels')
# The predicted class is the index of the highest-scoring output unit.
predicted_labels = np.argmax(predictions, axis=1)
# Function to visualize original images along with predicted labels
def visualize_predictions(images, true_labels, predicted_labels,
                          w=2, h=5, width=8, height=3, color='gray',
                          show_axis='off'):
    """Display images in a ``w`` x ``h`` grid titled with true and predicted labels.

    Args:
        images: Sequence of images accepted by ``Axes.imshow``. At most
            ``w * h`` of them are drawn.
        true_labels: Ground-truth labels, paired with ``images`` by position.
        predicted_labels: Model predictions, paired with ``images`` by
            position.
        w: Number of grid rows.
        h: Number of grid columns.
        width: Figure width in inches.
        height: Figure height in inches.
        color: Colormap name forwarded to ``imshow`` as ``cmap``.
        show_axis: Argument forwarded to ``Axes.axis`` for every drawn image
            (``'off'`` hides ticks and frame).
    """
    # squeeze=False always yields a 2-D array of Axes, so flatten() also
    # works for a 1 x 1 grid (where the default returns a bare Axes).
    _, axes = plt.subplots(w, h, figsize=(width, height), squeeze=False)
    cells = axes.flatten()
    drawn = 0
    # zip stops at the shortest input, so passing more images than grid
    # cells no longer raises IndexError.
    for ax, image, actual, predicted in zip(cells, images, true_labels,
                                            predicted_labels):
        ax.imshow(image, cmap=color)
        # Two-line title: true label on top, prediction underneath.
        ax.set_title(f"Actual: {actual}\nPrediction: {predicted}")
        ax.axis(show_axis)
        drawn += 1
    # Blank out grid cells that received no image.
    for ax in cells[drawn:]:
        ax.axis('off')
    plt.tight_layout()
    plt.show()
# Show the first ten test digits alongside their true and predicted labels.
print('Visualizing original images along with predicted labels')
visualize_predictions(
    images=x_test[:10],
    true_labels=y_test[:10],
    predicted_labels=predicted_labels[:10],
)