
Working of Convolutional Neural Network (CNN) in Tensorflow

Last Updated : 18 May, 2025

Convolutional Neural Networks (CNNs) are deep learning models used primarily for image-processing tasks. In this article, we'll see how CNNs work using TensorFlow. To understand how Convolutional Neural Networks function, it is important to break down the process into three core operations:

  1. Convolution
  2. Pooling
  3. Flattening

These operations are the foundation of CNNs and enable them to perform significantly better than traditional artificial neural networks, particularly on visual data.

1. Convolution

Convolution is the process of scanning an image with a set of filters, also known as kernels, to extract features such as edges, textures and shapes. These filters are small, learnable matrices that slide over the image, producing feature maps that highlight the regions where the filter detects a pattern.

Let’s consider the MNIST dataset, which contains 28x28 grayscale images of handwritten digits. Feeding these images into a traditional fully connected network would require 784 input nodes (one per pixel), resulting in a large number of weights and no spatial awareness. Now imagine handling a 1920x1080 high-resolution image: over 2 million input nodes and more than 130 million weights even with a single small hidden layer. This does not scale.
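
The weight counts above come from simple arithmetic; a quick sketch (the 64-unit hidden layer is an illustrative assumption, and biases are ignored):

Python
# Weights in a fully connected first layer = inputs * hidden units
mnist_inputs = 28 * 28      # 784 pixels
hd_inputs = 1920 * 1080     # 2,073,600 pixels
hidden_units = 64           # one small hidden layer (assumed for illustration)

print(mnist_inputs * hidden_units)  # 50,176 weights
print(hd_inputs * hidden_units)     # 132,710,400 weights - over 130 million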

CNNs solve this by applying convolutional layers that learn small, reusable local patterns instead of weighting every pixel independently, dramatically reducing the number of trainable parameters.
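
To make this concrete, here is a minimal sketch (separate from the model we build later) that applies a hand-crafted 3x3 vertical-edge kernel to a tiny image with tf.nn.conv2d. In a real CNN the kernel values are learned during training:

Python
import tensorflow as tf

# A 5x5 grayscale "image" with a vertical edge down the middle
image = tf.constant([[0., 0., 1., 1., 1.],
                     [0., 0., 1., 1., 1.],
                     [0., 0., 1., 1., 1.],
                     [0., 0., 1., 1., 1.],
                     [0., 0., 1., 1., 1.]])
image = tf.reshape(image, [1, 5, 5, 1])    # batch, height, width, channels

# A hand-crafted 3x3 vertical-edge kernel (learned automatically in a CNN)
kernel = tf.constant([[-1., 0., 1.],
                      [-1., 0., 1.],
                      [-1., 0., 1.]])
kernel = tf.reshape(kernel, [3, 3, 1, 1])  # height, width, in_channels, out_channels

feature_map = tf.nn.conv2d(image, kernel, strides=1, padding='VALID')
print(tf.squeeze(feature_map))  # large values where the kernel crosses the edge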

2. Pooling

Once features are extracted using convolution, the resulting feature maps are passed through a pooling layer to reduce their dimensionality. This helps simplify the data and ensures the most significant features are retained.

The most common type is MaxPooling, which selects the highest value in each window (for example a 2x2 window), summarizing the presence of a feature without losing critical information. This not only speeds up training but also adds a degree of translation invariance.
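
A minimal sketch of 2x2 max pooling on a small hand-made feature map (the values are illustrative, not taken from the model below):

Python
import tensorflow as tf

# A 4x4 feature map, reshaped to (batch, height, width, channels)
feature_map = tf.constant([[1., 3., 2., 4.],
                           [5., 6., 1., 2.],
                           [7., 2., 8., 3.],
                           [1., 4., 6., 5.]])
feature_map = tf.reshape(feature_map, [1, 4, 4, 1])

# 2x2 max pooling halves each spatial dimension,
# keeping only the largest value per window
pooled = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))(feature_map)
print(tf.squeeze(pooled))  # [[6. 4.], [7. 8.]]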

3. Flattening

After pooling, the multi-dimensional feature maps are converted into a one-dimensional vector. This step is known as flattening. The resulting vector is then fed into fully connected (dense) layers that perform high-level reasoning and output predictions.
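
A quick sketch of what flattening does to tensor shapes (the 2x2x64 shape assumes the last pooling stage of the architecture we define below):

Python
import tensorflow as tf

# A batch of one 2x2 feature map with 64 channels,
# e.g. the output of the final pooling layer
pooled = tf.random.normal([1, 2, 2, 64])

flat = tf.keras.layers.Flatten()(pooled)
print(flat.shape)  # (1, 256): 2 * 2 * 64 values, ready for dense layers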

Implementation of CNN models

We will now implement a CNN model step by step to see these operations in practice.

1. Importing Required Libraries

We are importing TensorFlow and Keras modules to train deep learning models.

Python
import tensorflow as tf
from tensorflow import keras

2. Loading and Preprocessing Dataset

We are loading the CIFAR-10 dataset into memory and normalizing the pixel values to a 0-1 range which helps the model train faster and improves convergence.

Python
(train_images, train_labels), (test_images, test_labels) = keras.datasets.cifar10.load_data()

train_images, test_images = train_images / 255.0, test_images / 255.0
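
To confirm the data loaded as expected, we can inspect the array shapes; this check is optional and not part of the original steps (CIFAR-10 has 50,000 training and 10,000 test images of size 32x32x3):

Python
print(train_images.shape)  # (50000, 32, 32, 3)
print(test_images.shape)   # (10000, 32, 32, 3)
print(train_labels.shape)  # (50000, 1) - integer class labels 0-9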

3. Defining CNN Architecture

We are stacking convolutional and pooling layers to extract features from images then flattening the output before passing it through dense layers for classification.

Python
model = tf.keras.models.Sequential([
    # Feature extraction: three convolution + pooling stages
    tf.keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)),
    tf.keras.layers.MaxPooling2D((2, 2)),

    tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
    tf.keras.layers.MaxPooling2D((2, 2)),

    tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
    tf.keras.layers.MaxPooling2D((2, 2)),

    # Classification head: flatten the feature maps, then dense layers
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(256, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax')  # one probability per CIFAR-10 class
])
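
Optionally, model.summary() prints each layer's output shape and parameter count, a handy sanity check on the architecture:

Python
model.summary()  # layer-by-layer output shapes and trainable parameter counts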

4. Compiling the Model

We are specifying the optimizer, loss function and evaluation metric the model uses to minimize prediction errors during training. Because our final layer already applies softmax, the loss must be configured with from_logits=False; combining softmax outputs with from_logits=True would miscompute the loss.

Python
model.compile(optimizer='adam',
              # the final layer applies softmax, so the loss receives probabilities
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['accuracy'])
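
Note that from_logits=True expects raw scores, so it only pairs with a final layer that has no activation. A sketch of that equivalent, often numerically stabler, setup:

Python
# Alternative: output raw logits and let the loss apply softmax internally
logits_layer = tf.keras.layers.Dense(10)  # final layer with no activation
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)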

5. Training the Model

We are fitting the model on the training data for a set number of epochs while validating its performance on unseen test data to monitor progress.

Python
history = model.fit(train_images, train_labels, epochs=10,
                    validation_data=(test_images, test_labels))

Output:

The per-epoch logs show the loss decreasing and accuracy rising on both the training and validation sets, confirming that the model is trained.
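
Once training finishes, we can also evaluate the model on the test set explicitly; this is a common follow-up step, not shown in the original snippet:

Python
test_loss, test_acc = model.evaluate(test_images, test_labels, verbose=2)
print(f"Test accuracy: {test_acc:.3f}")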

