Using Learning Curves - ML
Last Updated: 17 Jul, 2020
A learning curve of a Machine Learning model shows how the error in the model's predictions (or, equivalently, a score such as accuracy) changes as the size of the training set increases or decreases. It is usually plotted for both the training data and held-out validation data.
Before we continue, we must first understand what bias and variance mean for a Machine Learning model.
Bias:
Bias is the difference between the average prediction of a model and the correct value it is trying to predict. Models with high bias make strong simplifying assumptions about the relationship in the data. This leads to an over-simplified model (underfitting), which may cause a high error on both the training and testing sets. On the other hand, such models are usually fast to train and easy to interpret. Linear algorithms such as Linear Regression generally have high bias.
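To make the definition concrete, the sketch below estimates bias empirically under illustrative assumptions: the "correct" values come from a synthetic function, sin(2πx), invented for this purpose, and the model is a straight line fitted with NumPy's `polyfit`. The line is refitted on many independently drawn noisy training sets, and the bias is the average prediction at one point minus the true value there.

```python
import numpy as np

# Assumption: a synthetic "true" function, chosen only for illustration
def true_function(x):
    return np.sin(2 * np.pi * x)

rng = np.random.default_rng(0)
x_test = 0.25  # the true value here is sin(pi / 2) = 1
n_datasets, n_samples, noise = 200, 30, 0.1

predictions = []
for _ in range(n_datasets):
    # Draw a fresh noisy training set from the same distribution
    x = rng.uniform(0, 1, n_samples)
    y = true_function(x) + rng.normal(0, noise, n_samples)
    # Fit a straight line (a high-bias model for a sine wave)
    slope, intercept = np.polyfit(x, y, deg=1)
    predictions.append(slope * x_test + intercept)

# Bias = average prediction minus the correct value
average_prediction = np.mean(predictions)
bias = average_prediction - true_function(x_test)
print(f"average prediction: {average_prediction:.3f}, bias: {bias:.3f}")
```

Because a straight line cannot follow a sine wave, the average prediction at x = 0.25 stays well below the true value of 1 no matter how many training sets are averaged; that persistent offset is the bias.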
Variance:
Variance is the amount by which a model's predictions change if it is trained on a different training set. Ideally, a machine learning model should not vary too much from one training set to another, i.e., the algorithm should pick up the important underlying patterns in the data rather than details specific to one particular sample. High-variance models tend to overfit: they fit the training data very well but generalize poorly. Examples of algorithms with high variance are Decision Trees and Support Vector Machines (SVM).
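Variance can be estimated in a similar way: retrain on many freshly drawn training sets and measure how much the prediction at a fixed point moves around. This is a minimal sketch on synthetic sine-wave data (an assumption for illustration); `prediction_variance` is a hypothetical helper, and scikit-learn's `LinearRegression` and an unpruned `DecisionTreeRegressor` stand in for a low-variance and a high-variance model.

```python
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.tree import DecisionTreeRegressor

rng = np.random.default_rng(42)
x_test = np.array([[0.25]])  # fixed point at which predictions are compared

def prediction_variance(make_model, n_datasets=200, n_samples=30, noise=0.3):
    """Hypothetical helper: fit a fresh model on many resampled training
    sets and return the variance of its predictions at x_test."""
    preds = []
    for _ in range(n_datasets):
        # Assumption: synthetic data, y = sin(2*pi*x) + Gaussian noise
        X = rng.uniform(0, 1, size=(n_samples, 1))
        y = np.sin(2 * np.pi * X[:, 0]) + rng.normal(0, noise, n_samples)
        preds.append(make_model().fit(X, y).predict(x_test)[0])
    return np.var(preds)

var_linear = prediction_variance(LinearRegression)
var_tree = prediction_variance(DecisionTreeRegressor)  # no depth limit
print(f"linear regression variance:     {var_linear:.4f}")
print(f"unpruned decision tree variance: {var_tree:.4f}")
```

An unpruned tree essentially memorizes the noise of each training set, so its predictions scatter much more from one training set to the next than those of the straight line.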
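Ideally, we would want a model with both low bias and low variance. Lowering bias usually requires a more flexible (more complex) model, but a more flexible model is more sensitive to the particular training set, so its variance increases. Adding more training data, by contrast, tends to reduce variance without increasing bias. So we have to strike a balance between model simplicity and flexibility. This is called the bias-variance trade-off.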
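One common way to see the trade-off is to vary model complexity and estimate both quantities. The sketch below assumes synthetic sine-wave data and uses polynomial degree as the complexity knob (an illustrative choice, not part of this article's method); `bias_variance` is a hypothetical helper that returns squared bias and variance averaged over a grid of test points.

```python
import numpy as np

rng = np.random.default_rng(1)
# Assumption: synthetic sine-wave data; test points avoid the interval edges
x_grid = np.linspace(0.05, 0.95, 50)
true_values = np.sin(2 * np.pi * x_grid)

def bias_variance(degree, n_datasets=200, n_samples=30, noise=0.5):
    """Hypothetical helper: squared bias and variance of a polynomial fit,
    averaged over x_grid, estimated from many resampled training sets."""
    preds = np.empty((n_datasets, x_grid.size))
    for i in range(n_datasets):
        x = rng.uniform(0, 1, n_samples)
        y = np.sin(2 * np.pi * x) + rng.normal(0, noise, n_samples)
        poly = np.polynomial.Polynomial.fit(x, y, deg=degree)
        preds[i] = poly(x_grid)
    bias_sq = np.mean((preds.mean(axis=0) - true_values) ** 2)
    variance = np.mean(preds.var(axis=0))
    return bias_sq, variance

# Polynomial degree acts as the model-complexity knob
results = {degree: bias_variance(degree) for degree in (1, 3, 9)}
for degree, (bias_sq, variance) in results.items():
    print(f"degree {degree}: bias^2 = {bias_sq:.4f}, variance = {variance:.4f}")
```

As the degree grows, the squared bias falls while the variance rises, which is exactly the tension the trade-off describes.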
A learning curve helps us diagnose whether a model suffers mainly from high bias or high variance, and whether adding more training data is likely to improve it. This is why learning curves are so important.
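Conceptually, a learning curve is just a loop: train on progressively larger subsets of the training data and record the score both on those subsets and on held-out data. The following sketch does this by hand for a k-NN classifier on scikit-learn's digits data. It is a simplification that uses a single stratified hold-out split rather than cross-validation, and the subset sizes are arbitrary.

```python
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier

X, y = load_digits(return_X_y=True)
# Simplification: one stratified hold-out split instead of cross-validation
X_train, X_val, y_train, y_val = train_test_split(
    X, y, test_size=0.25, random_state=0, stratify=y)

# Arbitrary, increasing training-subset sizes (the last one uses all training data)
sizes = [50, 100, 200, 400, 800, len(X_train)]
train_acc, val_acc = [], []
for n in sizes:
    model = KNeighborsClassifier().fit(X_train[:n], y_train[:n])
    train_acc.append(model.score(X_train[:n], y_train[:n]))  # data it was trained on
    val_acc.append(model.score(X_val, y_val))                # unseen data

for n, tr, va in zip(sizes, train_acc, val_acc):
    print(f"{n:5d} training samples: train = {tr:.3f}, validation = {va:.3f}")
```

scikit-learn's `learning_curve` function automates the same idea and averages the scores over cross-validation folds.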
Now that we understand the bias-variance trade-off and why learning curves are important, let us see how to plot a learning curve in Python using the scikit-learn library.
Implementation of Learning Curves in Python:
For this example, we will use the popular Digits data set that ships with scikit-learn. For more information on this data set, refer to the link below:
https://scikit-learn.org/stable/auto_examples/datasets/plot_digits_last_image
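A quick look at the data set shows its size and structure. This snippet only inspects attributes of the object returned by `load_digits`.

```python
import numpy as np
from sklearn.datasets import load_digits

digits = load_digits()
print(digits.data.shape)         # (1797, 64): 1797 images, 8 x 8 = 64 pixel features each
print(digits.images.shape)       # (1797, 8, 8): the same images as 2-D arrays
print(np.unique(digits.target))  # [0 1 2 3 4 5 6 7 8 9]: ten digit classes
```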
We will use a k-Nearest Neighbours classifier for this example, and perform 10-fold cross-validation to obtain the validation scores plotted on the graph.
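As a minimal sketch of the 10-fold cross-validation step on its own, `cross_val_score` trains and scores the classifier once per fold. For a classifier with an integer `cv`, scikit-learn uses stratified folds without shuffling by default.

```python
from sklearn.datasets import load_digits
from sklearn.model_selection import cross_val_score
from sklearn.neighbors import KNeighborsClassifier

X, y = load_digits(return_X_y=True)
# One accuracy score per fold; each fold is held out once while the other 9 train the model
scores = cross_val_score(KNeighborsClassifier(), X, y, cv=10, scoring="accuracy")
print(f"fold accuracies: {scores.round(3)}")
print(f"mean accuracy: {scores.mean():.3f} (+/- {scores.std():.3f})")
```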
Code:
# Importing required libraries and modules
import numpy as np
import matplotlib.pyplot as plt
from sklearn.neighbors import KNeighborsClassifier
from sklearn.datasets import load_digits
from sklearn.model_selection import learning_curve
# Load data set
dataset = load_digits()
# X contains data and y contains labels
X, y = dataset.data, dataset.target
# Obtain scores from the learning curve function
# cv is the number of folds used in cross-validation;
# train_sizes are fractions (1% to 100%) of the largest possible training set
sizes, training_scores, testing_scores = learning_curve(
    KNeighborsClassifier(), X, y, cv=10, scoring='accuracy',
    train_sizes=np.linspace(0.01, 1.0, 50))
# Mean and standard deviation of training scores (rows: sizes, columns: folds)
mean_training = np.mean(training_scores, axis=1)
std_training = np.std(training_scores, axis=1)
# Mean and standard deviation of cross-validation scores
mean_testing = np.mean(testing_scores, axis=1)
std_testing = np.std(testing_scores, axis=1)
# Dashed blue line is for training scores, solid green line is for cross-validation scores
plt.plot(sizes, mean_training, '--', color="b", label="Training score")
plt.plot(sizes, mean_testing, color="g", label="Cross-validation score")
# Shaded bands show +/- one standard deviation across the 10 folds
plt.fill_between(sizes, mean_training - std_training, mean_training + std_training, color="b", alpha=0.1)
plt.fill_between(sizes, mean_testing - std_testing, mean_testing + std_testing, color="g", alpha=0.1)
# Drawing plot
plt.title("LEARNING CURVE FOR KNN Classifier")
plt.xlabel("Training Set Size")
plt.ylabel("Accuracy Score")
plt.legend(loc="best")
plt.tight_layout()
plt.show()
Output:
From the curve, we can see that as the size of the training set increases, the training score curve and the cross-validation score curve converge. The cross-validation accuracy increases as we add more training data, so adding training data is useful in this case. Since the training score is very high, the model has low bias. However, the cross-validation score stays noticeably lower than the training score and increases only slowly as the training set grows, which indicates relatively high variance: the model is overfitting the training data to some degree.
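The visual reading above can be backed up with numbers by looking at the gap between the mean training score and the mean cross-validation score at each training size. In the sketch below, the five training sizes are chosen only to keep it fast, and `diagnose` is a hypothetical rule of thumb whose thresholds (a 0.95 target score and a 0.05 maximum gap) are arbitrary assumptions, not standard values.

```python
import numpy as np
from sklearn.datasets import load_digits
from sklearn.model_selection import learning_curve
from sklearn.neighbors import KNeighborsClassifier

X, y = load_digits(return_X_y=True)
sizes, train_scores, val_scores = learning_curve(
    KNeighborsClassifier(), X, y, cv=10, scoring="accuracy",
    train_sizes=np.linspace(0.1, 1.0, 5))

mean_train = train_scores.mean(axis=1)
mean_val = val_scores.mean(axis=1)
gap = mean_train - mean_val  # large gap suggests variance, low training score suggests bias

def diagnose(final_train, final_gap, target=0.95, max_gap=0.05):
    """Hypothetical rule of thumb; the thresholds are illustrative only."""
    if final_train < target:
        return "high bias: training score itself is low"
    if final_gap > max_gap:
        return "high variance: large train/validation gap"
    return "reasonable bias-variance balance"

for n, g in zip(sizes, gap):
    print(f"{n:5d} samples: train - validation gap = {g:.3f}")
print(diagnose(mean_train[-1], gap[-1]))
```

Looking at how the gap changes across sizes, rather than only at the largest size, also shows whether more data is still closing it.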
Conclusion:
Learning curves are a great diagnostic tool for determining bias and variance in a supervised machine learning algorithm. In this article, we have learnt what learning curves are and how to plot them in Python with scikit-learn.