ML | Transfer Learning with Convolutional Neural Networks
Last Updated :
20 Mar, 2024
Transfer learning as a general term refers to reusing the knowledge learned from one task for another. Specifically for convolutional neural networks (CNNs), many image features are common to a variety of datasets (e.g. lines, edges are seen in almost every image). It is for this reason that, especially for large structures, CNNs are very rarely trained completely from scratch as large datasets and heavy computational resources are hard to come by.
A common pretraining dataset is ImageNet, which consists of 1.2 million images. The actual model used varies from task to task (people often simply choose whatever performs best on the ImageNet challenge), but the ResNet50 model is used in this article. The pre-trained model can usually be loaded through whatever library is being used, which in this case is Keras.
ResNet Introduction
ResNet was initially designed as a method to solve the vanishing gradient problem. This is a problem where backpropagated gradients become extremely small as they are multiplied over and over again, which limits how deep a neural network can be trained. The ResNet architecture attempts to solve this by employing skip connections, that is, shortcuts that allow data to skip past layers.
The model consists of a series of convolutional layers with skip connections, followed by average pooling and a fully connected (dense) output layer. For transfer learning, we only want the convolutional layers, as those contain the features we are interested in, so we omit the output layers when importing the model. Finally, because we are removing the output layers, we need to replace them with our own series of layers.
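To make the vanishing-gradient argument concrete, here is a small numeric sketch. It is a scalar toy, not a real network: it assumes every layer contributes the same local derivative (0.01 in the demo, an arbitrary illustrative value) and compares a plain chain of layers with a chain of residual blocks, where each block computes x + f(x) and therefore has local derivative 1 + f'(x). The function names are hypothetical and are not part of Keras.

```python
def plain_chain_gradient(local_derivative, depth):
    """Gradient through `depth` stacked layers: the product of local derivatives."""
    grad = 1.0
    for _ in range(depth):
        grad *= local_derivative
    return grad


def residual_chain_gradient(local_derivative, depth):
    """Gradient through `depth` residual blocks y = x + f(x).

    Each block has derivative 1 + f'(x), so the identity path keeps the
    gradient from shrinking even when f'(x) itself is very small.
    """
    grad = 1.0
    for _ in range(depth):
        grad *= 1.0 + local_derivative
    return grad


if __name__ == "__main__":
    # Illustrative value only: each layer scales the gradient by 0.01.
    for depth in (5, 20, 50):
        print(depth,
              plain_chain_gradient(0.01, depth),
              residual_chain_gradient(0.01, depth))
```

In the plain chain the gradient collapses towards zero as depth grows, while the residual chain stays at or above 1 because of the identity term.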
Problem Statement
To show the process of transfer learning, I’ll be using the Caltech-101 dataset, an image dataset with 101 categories and about 40-800 images per category.
Data Processing
First download and extract the dataset. Make sure to remove the “BACKGROUND_Google” folder after extraction.
Code : To properly evaluate, we need to split the data into training and testing sets as well. Here, we need to split within each category to ensure proper representation in the test set.
import os
import math
TEST_SPLIT = 0.2
VALIDATION_SPLIT = 0.2
# Create a separate directory tree for the test images
os.mkdir("caltech_test")
for cat in os.listdir("101_ObjectCategories/"):
    os.mkdir("caltech_test/" + cat)
    imgs = os.listdir("101_ObjectCategories/" + cat)
    # Move the first 20% of each category into the test set
    split = math.floor(len(imgs) * TEST_SPLIT)
    test_imgs = imgs[:split]
    for t_img in test_imgs:
        os.rename("101_ObjectCategories/" + cat + "/" + t_img,
                  "caltech_test/" + cat + "/" + t_img)
Output:
The above code creates the following file structure:
101_ObjectCategories/
-- accordion
-- airplanes
-- anchor
-- ...
caltech_test/
-- accordion
-- airplanes
-- anchor
-- ...
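The per-category split can also be expressed as a small pure function, which is easier to test than code that moves files directly. The sketch below is not the article's script: the function name `split_filenames` is hypothetical, and the sort-then-shuffle step with a fixed seed is an added assumption to make the split reproducible (the script above simply takes the first files in `os.listdir` order).

```python
import math
import random


def split_filenames(filenames, test_split=0.2, seed=0):
    """Return (train, test) lists of file names for one category.

    The test share is floor(len(filenames) * test_split), as in the
    script above. Sorting and then shuffling with a fixed seed is an
    added assumption so that the result does not depend on input order.
    """
    names = sorted(filenames)
    random.Random(seed).shuffle(names)
    n_test = math.floor(len(names) * test_split)
    return names[n_test:], names[:n_test]


if __name__ == "__main__":
    # Hypothetical file names, for illustration only.
    demo = ["image_%04d.jpg" % i for i in range(1, 11)]
    train, test = split_filenames(demo)
    print(len(train), "train files,", len(test), "test files")
```

The returned test list could then be passed to the same `os.rename` loop used above.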
The first folder contains the train images, the second contains test images. Each subfolder includes images belonging to that category. To input the data, we’re going to use Keras’s ImageDataGenerator class. ImageDataGenerator allows for the easy processing of image data, having options for augmentation as well.
from keras.applications.resnet50 import preprocess_input
from keras.preprocessing.image import ImageDataGenerator
# Training and validation images come from the same directory
train_gen = ImageDataGenerator(
    validation_split=0.2,
    preprocessing_function=preprocess_input)
train_flow = train_gen.flow_from_directory("101_ObjectCategories/",
                                           target_size=(256, 256),
                                           batch_size=32,
                                           subset="training")
valid_flow = train_gen.flow_from_directory("101_ObjectCategories/",
                                           target_size=(256, 256),
                                           batch_size=32,
                                           subset="validation")
# Test images are read from the separate test directory
test_gen = ImageDataGenerator(
    preprocessing_function=preprocess_input)
test_flow = test_gen.flow_from_directory("caltech_test",
                                         target_size=(256, 256),
                                         batch_size=32)
The above code takes the file path of the image directory and creates an object for data generation.
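Each flow yields the images in batches, so one epoch consists of the number of images divided by the batch size, rounded up. The short sketch below checks this arithmetic against the figures reported later in the article (around 5628 training images and 176 steps per epoch in the training log). The helper name `steps_per_epoch` is hypothetical and only mirrors the calculation; it does not call Keras.

```python
import math


def steps_per_epoch(num_images, batch_size=32):
    """Number of batches needed to see every image once.

    The last batch may be smaller than batch_size, hence the ceiling.
    """
    return math.ceil(num_images / batch_size)


if __name__ == "__main__":
    # 5628 is the approximate training-set size quoted in the article.
    print(steps_per_epoch(5628))
```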
Model Building
Code : To add the base pretrained model.
from keras.applications.resnet50 import ResNet50
from keras.layers import GlobalAveragePooling2D, Dense
from keras.layers import BatchNormalization, Dropout
from keras.models import Model
res = ResNet50(weights='imagenet', include_top=False,
               input_shape=(256, 256, 3))
After splitting, the training set is relatively small at around 5628 images, with most categories having only 50 images, so fine-tuning the convolutional layers may result in overfitting. The new dataset is also quite similar to ImageNet, so we can be confident that the pre-trained weights already encode useful features. We can therefore freeze the pre-trained convolutional layers so that they are not changed while the rest of the classifier is trained.
Other situations call for different strategies. If you have a small dataset that is significantly different from the original, fine-tuning may still cause overfitting, but the later layers will not contain the right features. In that case, you could again freeze the convolutional layers but use only the output of earlier layers, as those contain more general features. With a large dataset, overfitting is less of a concern, so you can often fine-tune the entire network.
# Freeze the pre-trained convolutional layers so that
# their weights are not updated during training
for layer in res.layers:
    layer.trainable = False
Now, we can add the rest of the classifier. This takes the output from the pre-trained convolutional layers and inputs it into a separate classifier that gets trained on the new dataset.
x = res.output
x = GlobalAveragePooling2D()(x)
x = BatchNormalization()(x)
x = Dropout(0.5)(x)
x = Dense(512, activation='relu')(x)
x = BatchNormalization()(x)
x = Dropout(0.5)(x)
x = Dense(101, activation='softmax')(x)
model = Model(res.input, x)
model.compile(optimizer='Adam',
              loss='categorical_crossentropy',
              metrics=['accuracy'])
model.summary()
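The size of the new classifier can be checked by hand against what `model.summary()` prints. The sketch below is only arithmetic, under stated assumptions: the base network's last convolutional block outputs 2048 channels (so global average pooling yields a 2048-element vector), Dense layers use a bias, and BatchNormalization uses its default scale and offset. The helper names are hypothetical.

```python
FEATURES = 2048  # channels after ResNet50's last convolutional block
HIDDEN = 512     # units in the first Dense layer
CLASSES = 101    # Caltech-101 categories


def dense_params(n_inputs, n_units):
    """One weight per input-unit pair plus one bias per unit."""
    return n_inputs * n_units + n_units


def batchnorm_params(n_features):
    """Return (trainable, non_trainable) parameter counts.

    gamma and beta are trained; the moving mean and variance are not.
    """
    return 2 * n_features, 2 * n_features


def head_trainable_params():
    """Trainable parameters in the layers added on top of the base model."""
    bn1_trainable, _ = batchnorm_params(FEATURES)
    bn2_trainable, _ = batchnorm_params(HIDDEN)
    return (bn1_trainable + dense_params(FEATURES, HIDDEN)
            + bn2_trainable + dense_params(HIDDEN, CLASSES))


if __name__ == "__main__":
    print("Dense(512):", dense_params(FEATURES, HIDDEN))
    print("Dense(101):", dense_params(HIDDEN, CLASSES))
    print("Trainable parameters in the new head:", head_trainable_params())
```

Pooling and dropout layers add no parameters, so almost all of the head's weights sit in the first Dense layer.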
Code : Train the model
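# In recent Keras versions, fit() accepts generators directly;
# older versions use model.fit_generator() with the same arguments
model.fit(train_flow, epochs=5, validation_data=valid_flow)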
Output:
Epoch 1/5
176/176 [==============================] - 27s 156ms/step - loss: 1.6601 - acc: 0.6338 - val_loss: 0.3799 - val_acc: 0.8922
Epoch 2/5
176/176 [==============================] - 19s 107ms/step - loss: 0.4637 - acc: 0.8696 - val_loss: 0.2841 - val_acc: 0.9225
Epoch 3/5
176/176 [==============================] - 19s 107ms/step - loss: 0.2777 - acc: 0.9211 - val_loss: 0.2714 - val_acc: 0.9225
Epoch 4/5
176/176 [==============================] - 19s 107ms/step - loss: 0.2223 - acc: 0.9327 - val_loss: 0.2419 - val_acc: 0.9284
Epoch 5/5
176/176 [==============================] - 19s 106ms/step - loss: 0.1784 - acc: 0.9461 - val_loss: 0.2499 - val_acc: 0.9239
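The log above can be summarized by comparing training and validation loss per epoch. The sketch below copies the loss values from the log and computes the gap (validation minus training) and the epoch with the lowest validation loss. The helper names are hypothetical. A negative gap in early epochs is plausible here because dropout is active only during training.

```python
# Loss values copied from the training log above, epochs 1 to 5.
train_loss = [1.6601, 0.4637, 0.2777, 0.2223, 0.1784]
val_loss = [0.3799, 0.2841, 0.2714, 0.2419, 0.2499]


def generalization_gaps(train, val):
    """Validation loss minus training loss for each epoch."""
    return [round(v - t, 4) for t, v in zip(train, val)]


def best_epoch(val):
    """1-based index of the epoch with the lowest validation loss."""
    return min(range(len(val)), key=val.__getitem__) + 1


if __name__ == "__main__":
    for epoch, gap in enumerate(generalization_gaps(train_loss, val_loss), 1):
        print("epoch", epoch, "gap", gap)
    print("lowest validation loss at epoch", best_epoch(val_loss))
```

A gap that turns positive and keeps growing while validation loss stops improving is the usual sign that the model is beginning to overfit.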
Code: To evaluate the test set
result = model.evaluate(test_flow)
# result[0] is the loss, result[1] is the accuracy
print('The model achieved a loss of %.2f and '
      'accuracy of %.2f%%.' % (result[0], result[1] * 100))
Output:
53/53 [==============================] - 5s 95ms/step
The model achieved a loss of 0.23 and accuracy of 92.80%.
For a 101 class dataset, we have achieved a 92.8% accuracy after only 5 epochs. For perspective, the original ResNet was trained on an ~1 million image dataset, for 120 epochs.
There are a couple of things that could be improved upon. For one, looking at the discrepancy between validation loss and training loss in the last epoch, you can see that the model is starting to overfit. One way to solve this is to add image augmentation. Simple image augmentation can be easily implemented with the ImageDataGenerator class. You could also play around with adding/removing layers or changing hyperparameters such as the dropout or the size of the Dense layer.
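As an illustration of what simple augmentation does, the sketch below implements a random horizontal flip in plain Python on an image stored as a list of pixel rows. It is not Keras's implementation and the function names are hypothetical; in Keras, the corresponding option is `horizontal_flip=True` on `ImageDataGenerator`.

```python
import random


def horizontal_flip(image):
    """Mirror an image given as a list of rows (each row a list of pixels)."""
    return [row[::-1] for row in image]


def random_flip(image, rng, probability=0.5):
    """Flip the image with the given probability, otherwise return a copy."""
    if rng.random() < probability:
        return horizontal_flip(image)
    return [row[:] for row in image]


if __name__ == "__main__":
    tiny_image = [[1, 2, 3],
                  [4, 5, 6]]
    rng = random.Random(0)  # fixed seed so the demo is repeatable
    print(random_flip(tiny_image, rng))
```

Because each epoch then sees slightly different versions of the same images, the model has less opportunity to memorize the training set.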
Run this code here with Google Colab’s free GPU compute resources.