Training a Neural Network using Keras API in TensorFlow
Last Updated: 11 Jun, 2024
The field of machine learning and deep learning has been significantly transformed by tools like TensorFlow and Keras. TensorFlow, developed by Google, is an open-source platform that provides a comprehensive ecosystem for machine learning. Keras, now fully integrated into TensorFlow, offers a user-friendly, high-level API for building and training neural networks. This article guides you through training a neural network using the Keras API within TensorFlow.
Prerequisite:
pip install tensorflow
Step By Step Implementation of Training a Neural Network using Keras API in TensorFlow
Training a neural network involves several steps, including data preprocessing, model building, compiling, training, and evaluating the model. Here’s a step-by-step guide using the Keras API in TensorFlow.
Step 1: Import Libraries
Python
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Conv2D, MaxPooling2D, Flatten, Dropout
from tensorflow.keras.optimizers import Adam
Step 2: Prepare the Data
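Load and preprocess the dataset. For demonstration, we’ll use the MNIST dataset of 28×28 grayscale images of handwritten digits (60,000 for training and 10,000 for testing). The images are reshaped to add a channel dimension and scaled to the range [0, 1], and the labels are one-hot encoded: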
Python
from tensorflow.keras.datasets import mnist
# Load data
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# Preprocess data
x_train = x_train.reshape(-1, 28, 28, 1).astype('float32') / 255.0
x_test = x_test.reshape(-1, 28, 28, 1).astype('float32') / 255.0
# One-hot encode the labels
y_train = tf.keras.utils.to_categorical(y_train, 10)
y_test = tf.keras.utils.to_categorical(y_test, 10)
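To make the preprocessing concrete, here is a minimal plain-Python sketch of the two transformations applied above: scaling pixel values from the 0-255 range to 0-1, and one-hot encoding a digit label. The helper names scale_pixels and one_hot are hypothetical and introduced only for illustration. They are not part of Keras, and tf.keras.utils.to_categorical works on whole NumPy arrays rather than single labels.

```python
def scale_pixels(pixels):
    """Scale raw 8-bit pixel intensities (0-255) to floats in [0, 1]."""
    return [p / 255.0 for p in pixels]


def one_hot(label, num_classes=10):
    """Return a list with 1.0 at index `label` and 0.0 everywhere else."""
    if not 0 <= label < num_classes:
        raise ValueError(f"label {label} is outside 0..{num_classes - 1}")
    return [1.0 if i == label else 0.0 for i in range(num_classes)]


# Illustrative values only (not taken from MNIST).
row = scale_pixels([0, 51, 255])
encoded = one_hot(3)
print(row)      # [0.0, 0.2, 1.0]
print(encoded)  # 1.0 at index 3, 0.0 elsewhere
```

Scaling keeps the inputs in a small, consistent range, and the one-hot vector has the same length as the 10-unit output layer used later, so each output unit can be compared with one entry of the label.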
Step 3: Build the Model
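Define the architecture of the neural network. The model below is a small convolutional network: three convolution and max-pooling blocks, followed by a dense layer with dropout and a 10-way softmax output: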
Python
model = Sequential([
    Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=(28, 28, 1)),
    MaxPooling2D(pool_size=(2, 2)),
    Conv2D(64, kernel_size=(3, 3), activation='relu'),
    MaxPooling2D(pool_size=(2, 2)),
    Conv2D(128, kernel_size=(3, 3), activation='relu'),
    MaxPooling2D(pool_size=(2, 2)),
    Flatten(),
    Dense(128, activation='relu'),
    Dropout(0.5),
    Dense(10, activation='softmax')
])
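It can help to trace how the image size shrinks through this stack. The following plain-Python sketch does the arithmetic by hand, assuming the Keras defaults that the model above relies on: 'valid' padding and stride 1 for Conv2D, and a pooling stride equal to pool_size for MaxPooling2D. The helper names conv_out and pool_out are hypothetical.

```python
def conv_out(size, kernel, stride=1):
    """Output size along one axis for a convolution with 'valid' padding."""
    return (size - kernel) // stride + 1


def pool_out(size, pool):
    """Output size for max pooling with 'valid' padding and stride == pool."""
    return (size - pool) // pool + 1


size = 28            # MNIST images are 28x28
sizes = [size]
for _ in range(3):   # three Conv2D(3x3) + MaxPooling2D(2x2) blocks
    size = conv_out(size, kernel=3)
    sizes.append(size)
    size = pool_out(size, pool=2)
    sizes.append(size)

filters = 128        # channels produced by the last Conv2D layer
flattened = size * size * filters

# Trainable parameters: weights (kernel_h * kernel_w * in_ch * out_ch) + biases
conv_params = [3 * 3 * c_in * c_out + c_out
               for c_in, c_out in [(1, 32), (32, 64), (64, 128)]]
dense_params = [n_in * n_out + n_out
                for n_in, n_out in [(flattened, 128), (128, 10)]]
total_params = sum(conv_params) + sum(dense_params)

print(sizes)         # [28, 26, 13, 11, 5, 3, 1]
print(flattened)     # 128
print(total_params)  # 110474
```

By the last pooling layer the feature map is down to 1×1, so Flatten passes 128 values to the dense layer. Calling model.summary() on the model above prints the shapes and parameter counts Keras actually uses, which can be checked against this hand calculation.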
Step 4: Compile the Model
Compile the model with an optimizer, loss function, and metrics. Categorical cross-entropy is used here because the labels are one-hot encoded:
Python
model.compile(optimizer=Adam(),
              loss='categorical_crossentropy',
              metrics=['accuracy'])
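To see what the loss measures, here is a minimal plain-Python sketch of softmax followed by categorical cross-entropy for a single sample. It illustrates the formula only and is not Keras's implementation. The function names and the eps clipping value are assumptions made for this sketch.

```python
import math


def softmax(logits):
    """Convert raw scores into probabilities that sum to 1."""
    largest = max(logits)
    exps = [math.exp(z - largest) for z in logits]  # shift for numerical stability
    total = sum(exps)
    return [e / total for e in exps]


def categorical_crossentropy(y_true, y_pred, eps=1e-7):
    """Cross-entropy between a one-hot target and predicted probabilities."""
    # Clip probabilities away from 0 so log() is always defined.
    return -sum(t * math.log(min(max(p, eps), 1.0))
                for t, p in zip(y_true, y_pred))


target = [0.0] * 10
target[3] = 1.0                                         # true class is digit 3

uniform = softmax([0.0] * 10)                           # equal scores: 0.1 per class
confident = softmax([0.0, 0.0, 0.0, 8.0] + [0.0] * 6)   # strongly favors class 3

loss_uniform = categorical_crossentropy(target, uniform)
loss_confident = categorical_crossentropy(target, confident)
print(loss_uniform, loss_confident)
```

A uniform guess over 10 classes gives a loss of ln(10), about 2.30, while a confident correct prediction drives the loss toward 0. During training, the optimizer adjusts the weights to reduce this quantity averaged over each batch.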
Step 5: Train the Model
Train the model for 10 epochs with a batch size of 128, holding out 20% of the training data for validation:
Python
model.fit(x_train, y_train, epochs=10, batch_size=128, validation_split=0.2)
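The arguments passed to fit determine how much data the model sees per step. As a rough sketch of the arithmetic, assuming the standard 60,000-image MNIST training set, the numbers work out as follows (variable names are illustrative):

```python
import math

num_samples = 60000        # size of the MNIST training set
validation_split = 0.2
batch_size = 128
epochs = 10

num_val = int(num_samples * validation_split)   # samples held out for validation
num_train = num_samples - num_val               # samples used for weight updates
steps_per_epoch = math.ceil(num_train / batch_size)
total_updates = steps_per_epoch * epochs

print(num_train, num_val, steps_per_epoch, total_updates)  # 48000 12000 375 3750
```

Keras takes the validation portion from the last samples of the arrays passed to fit, before any shuffling, so the model is never trained on those 12,000 images. The validation loss and accuracy reported after each epoch are computed on them.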
Step 6: Evaluate the Model
Evaluate the model using the test data to check its performance:
Python
test_loss, test_accuracy = model.evaluate(x_test, y_test)
print(f'Test accuracy: {test_accuracy}')
Output:
Test accuracy: 0.78
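With one-hot labels, accuracy amounts to comparing the index of the largest predicted probability with the index of the 1 in the label. The following plain-Python sketch shows that computation on made-up data. The helper names are hypothetical and the numbers are not real model output.

```python
def argmax(values):
    """Index of the largest value (the first one wins on ties)."""
    return max(range(len(values)), key=lambda i: values[i])


def accuracy(y_true, y_pred):
    """Fraction of samples whose predicted class matches the true class."""
    correct = sum(argmax(t) == argmax(p) for t, p in zip(y_true, y_pred))
    return correct / len(y_true)


# Made-up 3-class example, not real model output.
labels = [[1, 0, 0], [0, 1, 0], [0, 0, 1], [0, 1, 0]]
predictions = [
    [0.7, 0.2, 0.1],   # predicts class 0: correct
    [0.1, 0.8, 0.1],   # predicts class 1: correct
    [0.3, 0.4, 0.3],   # predicts class 1: wrong (true class is 2)
    [0.2, 0.5, 0.3],   # predicts class 1: correct
]
print(accuracy(labels, predictions))   # 0.75
```

The same idea applies to the real model: model.predict(x_test) returns one row of 10 probabilities per image, and the predicted digit is the index of the largest entry in each row.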
In conclusion, the integration of TensorFlow and Keras has significantly streamlined the process of training neural networks, making it more accessible to both beginners and experienced practitioners in the field of machine learning and deep learning.
With TensorFlow providing a robust open-source platform and Keras offering a user-friendly interface through its high-level API, developers can efficiently build, train, and evaluate neural network models.
Through the step-by-step implementation outlined in this guide, we've seen how to preprocess data, define the neural network architecture, compile the model with appropriate parameters, train the model using training data, and evaluate its performance using test data.
However, the test accuracy of 0.78 reported in this example leaves room for improvement. Achieving high accuracy often requires experimentation with different architectures, hyperparameters, and optimization techniques. Continuous learning and experimentation are key to refining models and pushing the boundaries of what is achievable in machine learning.