
SVM with Cross Validation in R

Last Updated : 11 Jul, 2024

Support Vector Machine (SVM) is a powerful and versatile machine learning model used for classification and regression tasks. In this article, we'll go through the steps to implement an SVM with cross-validation in R using the caret package.

Cross Validation in R

Cross-validation involves splitting the data into multiple parts (folds), training the model on some parts, and testing it on the remaining parts. The most common type is k-fold cross-validation, where the data is divided into k subsets (folds), and the model is trained k times, each time leaving out one of the folds for validation and using the rest for training.
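For intuition, the fold-by-fold loop can be sketched by hand before handing the work over to caret. The sketch below is a minimal illustration only (5 folds, e1071's svm with its default radial kernel); the fold count and seed are arbitrary choices:

```r
library(caret)   # for createFolds
library(e1071)   # for svm

data(iris)
set.seed(123)
folds <- createFolds(iris$Species, k = 5)  # list of 5 held-out index vectors

accuracies <- sapply(folds, function(test_idx) {
  train_data <- iris[-test_idx, ]
  test_data  <- iris[test_idx, ]
  fit  <- svm(Species ~ ., data = train_data)  # default radial kernel
  pred <- predict(fit, test_data)
  mean(pred == test_data$Species)              # accuracy on the held-out fold
})

mean(accuracies)  # cross-validated accuracy estimate
```

Each observation is held out exactly once, and averaging the fold accuracies gives a single performance estimate. The caret workflow in the steps below automates exactly this loop, and adds parameter tuning on top of it.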

The following step-by-step guide shows how to perform SVM with cross-validation in R.

Step 1: Install and Load Necessary Packages

Ensure you have the caret, e1071, and ggplot2 packages installed. These packages provide the tools for creating and evaluating SVM models.

R
install.packages("caret")
install.packages("e1071")
install.packages("ggplot2")
library(caret)
library(e1071)
library(ggplot2)

Step 2: Load and Prepare the Data

For this example, we'll use the built-in iris dataset.

R
data(iris)
head(iris)

Output:

  Sepal.Length Sepal.Width Petal.Length Petal.Width Species
1          5.1         3.5          1.4         0.2  setosa
2          4.9         3.0          1.4         0.2  setosa
3          4.7         3.2          1.3         0.2  setosa
4          4.6         3.1          1.5         0.2  setosa
5          5.0         3.6          1.4         0.2  setosa
6          5.4         3.9          1.7         0.4  setosa

Step 3: Set Up Cross-Validation Control

Define the control function for cross-validation using the trainControl function.

R
custom_control <- trainControl(
  method = "cv",       # Cross-validation
  number = 5,          # Number of folds
  verboseIter = TRUE,  # Output the progress
  savePredictions = "final",
  classProbs = TRUE,   # Estimate class probabilities
  summaryFunction = multiClassSummary # Use for multi-class classification
)

Step 4: Train and Evaluate SVM Model

Use the train function to train the SVM model with cross-validation. Print the summary of the trained model to see the results of cross-validation.

R
set.seed(123)  # Set seed for reproducibility
svm_model <- train(
  Species ~ .,            # Formula for the model
  data = iris,            # Data to be used
  method = "svmRadial",   # SVM with radial basis function kernel
  trControl = custom_control,
  tuneLength = 10         # Tune over 10 different parameter values
)
print(svm_model)

Output:

Support Vector Machines with Radial Basis Function Kernel 

150 samples
4 predictor
3 classes: 'setosa', 'versicolor', 'virginica'

No pre-processing
Resampling: Cross-Validated (5 fold)
Summary of sample sizes: 120, 120, 120, 120, 120
Resampling results across tuning parameters:

    C    logLoss    AUC        prAUC      Accuracy   Kappa  Mean_F1    Mean_Sensitivity
    0.25 0.1436566  0.9973333  0.8950640  0.9466667  0.92   0.9463806  0.9466667
    0.50 0.1353177  0.9976667  0.8956340  0.9466667  0.92   0.9463806  0.9466667
    1.00 0.1338090  0.9966667  0.8937293  0.9533333  0.93   0.9534371  0.9533333
    2.00 0.1363814  0.9950000  0.8900847  0.9533333  0.93   0.9534371  0.9533333
    4.00 0.1325292  0.9950000  0.8896610  0.9533333  0.93   0.9535384  0.9533333
    8.00 0.1500452  0.9933333  0.8867857  0.9533333  0.93   0.9531318  0.9533333
   16.00 0.1773016  0.9896667  0.8771895  0.9466667  0.92   0.9463472  0.9466667
   32.00 0.1792608  0.9893333  0.8772480  0.9533333  0.93   0.9531318  0.9533333
   64.00 0.1829900  0.9853333  0.8752865  0.9600000  0.94   0.9601883  0.9600000
  128.00 0.1872609  0.9883333  0.8742261  0.9533333  0.93   0.9535384  0.9533333
  Mean_Specificity  Mean_Pos_Pred_Value  Mean_Neg_Pred_Value  Mean_Precision  Mean_Recall
  0.9733333         0.9517172            0.9747042            0.9517172       0.9466667
  0.9733333         0.9517172            0.9747042            0.9517172       0.9466667
  0.9766667         0.9583838            0.9777489            0.9583838       0.9533333
  0.9766667         0.9583838            0.9777489            0.9583838       0.9533333
  0.9766667         0.9573737            0.9774603            0.9573737       0.9533333
  0.9766667         0.9579798            0.9779076            0.9579798       0.9533333
  0.9733333         0.9529293            0.9750216            0.9529293       0.9466667
  0.9766667         0.9579798            0.9779076            0.9579798       0.9533333
  0.9800000         0.9646465            0.9809524            0.9646465       0.9600000
  0.9766667         0.9573737            0.9774603            0.9573737       0.9533333
  Mean_Detection_Rate  Mean_Balanced_Accuracy
  0.3155556            0.960
  0.3155556            0.960
  0.3177778            0.965
  0.3177778            0.965
  0.3177778            0.965
  0.3177778            0.965
  0.3155556            0.960
  0.3177778            0.965
  0.3200000            0.970
  0.3177778            0.965

Tuning parameter 'sigma' was held constant at a value of 0.7668493
Accuracy was used to select the optimal model using the largest value.
The final values used for the model were sigma = 0.7668493 and C = 64.
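Because savePredictions = "final" was set in the control object, the held-out predictions for the winning parameter combination are stored on the fitted object. As a fragment continuing from the trained svm_model above (not standalone code), the chosen parameters and a cross-validated confusion matrix can be inspected like this:

```r
# Fragment: assumes `svm_model` from the training step is in the workspace.

# Best tuning parameters chosen by cross-validation
print(svm_model$bestTune)

# Confusion matrix built from the saved held-out predictions,
# i.e. an honest cross-validated view of the errors
confusionMatrix(svm_model$pred$pred, svm_model$pred$obs)
```

Note that this confusion matrix aggregates predictions made on held-out folds, so it reflects generalization rather than training fit.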

Step 5: Plot the Results

Visualize the cross-validation results. caret's plot method for train objects shows the resampled accuracy against the tuned cost parameter C.

R
plot(svm_model)
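If you prefer a ggplot2-based version of the same tuning curve, caret also provides a ggplot method for train objects, which returns a plot you can restyle with the usual ggplot2 layers. A small sketch (the labels are illustrative):

```r
# Fragment: assumes `svm_model` from the training step is in the workspace.
ggplot(svm_model) +
  labs(title = "Cross-validated accuracy vs. cost (C)",
       y = "Accuracy (5-fold CV)")
```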

Output:

[Plot: resampled accuracy versus the cost parameter C — SVM with Cross Validation in R]

Using cross-validation, you can reliably estimate the performance of your SVM model and tune its parameters to achieve the best results. This approach ensures that your model generalizes well to unseen data.

Conclusion

Cross-validation is an essential technique in machine learning for assessing model performance and ensuring generalizability to unseen data. The caret package in R provides a comprehensive framework for implementing various cross-validation methods with ease. By following the steps outlined, you can effectively train and evaluate models using cross-validation, choose the best model parameters, and visualize the results for better understanding and interpretation.

