Building and Implementing Decision Tree Classifiers with Scikit-Learn: A Comprehensive Guide
Last Updated: 27 Jan, 2025
Decision Tree Classifiers are a fundamental machine learning algorithm for classification tasks. They organize data into a tree-like structure where internal nodes represent decisions, branches represent outcomes and leaf nodes represent class labels. This article introduces how to build and implement these classifiers using Scikit-Learn, a Python library for machine learning.
Implementing Decision Tree Classifiers with Scikit-Learn
The DecisionTreeClassifier from Sklearn can perform multi-class classification on a dataset. The syntax for DecisionTreeClassifier is as follows:
class sklearn.tree.DecisionTreeClassifier(*, criterion='gini', splitter='best', max_depth=None, min_samples_split=2, min_samples_leaf=1, min_weight_fraction_leaf=0.0, max_features=None, random_state=None, max_leaf_nodes=None, min_impurity_decrease=0.0, class_weight=None, ccp_alpha=0.0, monotonic_cst=None)
Let's go through the parameters:
- criterion: It measures the quality of a split. Supported values are 'gini', 'entropy' and 'log_loss'. The default value is 'gini'.
- splitter: This parameter chooses the strategy used to split at each node. Supported values are 'best' and 'random'. The default value is 'best'.
- max_features: It defines the number of features to consider when looking for the best split.
- max_depth: This parameter denotes maximum depth of the tree (default=None).
- min_samples_split: It defines the minimum number of samples required to split an internal node (default=2).
- min_samples_leaf: The minimum number of samples required to be at a leaf node (default=1)
- max_leaf_nodes: It defines the maximum number of possible leaf nodes.
- min_impurity_decrease: A node will be split if the split decreases the impurity by at least this value (default=0.0).
- class_weight: It defines the weights associated with classes.
- ccp_alpha: It is a complexity parameter used for minimal cost-complexity pruning
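To make the parameters concrete, here is a minimal sketch of how a few of them might be combined to constrain tree growth (the specific values are illustrative, not recommendations):
Python
from sklearn.tree import DecisionTreeClassifier

# illustrative settings that limit tree growth to curb overfitting
clf = DecisionTreeClassifier(
    criterion='entropy',  # measure split quality with information gain
    max_depth=4,          # cap the depth of the tree
    min_samples_leaf=3,   # require at least 3 samples in each leaf
    random_state=1        # make results reproducible
)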
Steps to train a DecisionTreeClassifier Using Sklearn
Let's look at how to train a DecisionTreeClassifier using Sklearn on the Iris dataset. The step-by-step execution is as follows:
Step 1: Import Libraries
To start, import the libraries you'll need, such as Scikit-Learn (sklearn), for machine learning tasks.
Python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score
Step 2: Data Loading
To perform classification, load a dataset. For demonstration you can use sample datasets from Scikit-Learn such as Iris or Breast Cancer. Here we load the Iris dataset and separate its features (X) and target labels (y).
Python
iris = load_iris()
X = iris.data
y = iris.target
Step 3: Splitting Data
Use the train_test_split method from sklearn.model_selection to split the dataset into training and testing sets.
Python
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=99)
Step 4: Initializing the Model
Using DecisionTreeClassifier from sklearn.tree, create an object for the Decision Tree Classifier.
Python
clf = DecisionTreeClassifier(random_state=1)
Step 5: Training the Model
Apply the fit method to train the classifier on the training data.
Python
clf.fit(X_train, y_train)
Step 6: Making Predictions
Apply the predict method to the test data to generate predictions from the trained model, then evaluate them with accuracy_score.
Python
y_pred = clf.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
print(f'Accuracy: {accuracy}')
Let's put together the complete code based on the above steps:
Python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score
# load iris dataset
iris = load_iris()
X = iris.data
y = iris.target
# split dataset to training and test set
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=99)
# initialize decision tree classifier
clf = DecisionTreeClassifier(random_state=1)
# train the classifier
clf.fit(X_train, y_train)
# predict using classifier
y_pred = clf.predict(X_test)
# calculate accuracy
accuracy = accuracy_score(y_test, y_pred)
print(f'Accuracy: {accuracy}')
Output:
Accuracy: 0.9555555555555556
Hyperparameter Tuning with Decision Tree Classifier
Hyperparameters are configuration settings that control the behavior of a decision tree model and significantly affect its performance. Proper tuning can improve accuracy, reduce overfitting and enhance generalization of model.
Popular methods for tuning include Grid Search, Random Search, and Bayesian Optimization, which explore different combinations to find the best configuration.
For more details, you can check out the article How to tune a Decision Tree in Hyperparameter tuning?
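For illustration, here is a minimal sketch of random search using Scikit-Learn's RandomizedSearchCV, which samples a fixed number of parameter combinations instead of trying all of them; the ranges and n_iter value below are illustrative. The next subsection walks through grid search in detail.
Python
from sklearn.model_selection import RandomizedSearchCV
from sklearn.tree import DecisionTreeClassifier

# illustrative search space; adjust the ranges to your data
param_dist = {
    'max_depth': range(1, 10),
    'min_samples_leaf': range(1, 20),
    'criterion': ['gini', 'entropy']
}
tree = DecisionTreeClassifier(random_state=1)
# sample 20 random combinations instead of exhaustively trying all of them
random_search = RandomizedSearchCV(tree, param_distributions=param_dist,
                                   n_iter=20, cv=5, random_state=1)
random_search.fit(X_train, y_train)
print(random_search.best_params_)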
Hyperparameter tuning using GridSearchCV
Let's make use of Scikit-Learn's GridSearchCV to find the best combination of hyperparameter values. The code is as follows:
Python
from sklearn.model_selection import GridSearchCV
# Hyperparameter to fine tune
param_grid = {
    'max_depth': range(1, 10, 1),
    'min_samples_leaf': range(1, 20, 2),
    'min_samples_split': range(2, 20, 2),
    'criterion': ["entropy", "gini"]
}
tree = DecisionTreeClassifier(random_state=1)
# GridSearchCV
grid_search = GridSearchCV(estimator=tree, param_grid=param_grid,
                           cv=5, verbose=True)
grid_search.fit(X_train, y_train)
print("best accuracy", grid_search.best_score_)
print(grid_search.best_estimator_)
Output:
Fitting 5 folds for each of 1620 candidates, totalling 8100 fits
best accuracy 0.9714285714285715
DecisionTreeClassifier(criterion='entropy', max_depth=4, min_samples_leaf=3, random_state=1)
Here we defined the parameter grid with a set of hyperparameters and a list of possible values for each. GridSearchCV evaluates the different hyperparameter combinations for the DecisionTreeClassifier and selects the best combination based on performance across all k folds.
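Note that best_score_ is a cross-validation score computed on the training data. A quick way to check how the tuned model generalizes is to evaluate the best estimator on the held-out test set:
Python
# evaluate the tuned classifier on the held-out test set
best_clf = grid_search.best_estimator_
print(f'Test accuracy: {best_clf.score(X_test, y_test)}')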
Visualizing the Decision Tree Classifier
Decision Tree visualization helps interpret and understand the model's choices. We'll plot the tree learned by the model to see which features it splits on and how it reaches each prediction. Here we fetch the best estimator obtained from GridSearchCV as the decision tree classifier.
Python
from sklearn.tree import plot_tree
import matplotlib.pyplot as plt
tree_clf = grid_search.best_estimator_
plt.figure(figsize=(18, 15))
plot_tree(tree_clf, filled=True, feature_names=iris.feature_names,
          class_names=iris.target_names)
plt.show()
Output:
Decision Tree Visualization
We can see that the tree starts from the root node (depth 0, at the top).
- The root node checks whether the petal width is less than or equal to 0.75. If it is, we move to the root's left child node (depth 1, left). The left node has no children, so the classifier predicts the class for that node as setosa.
- If the petal width is greater than 0.75, we move down to the root's right child node (depth 1, right). This node is not a leaf, so the classifier keeps checking conditions until it reaches a leaf node.
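The same traversal can also be inspected as plain text. Here is a small sketch using Scikit-Learn's export_text on the fitted tree_clf from above:
Python
from sklearn.tree import export_text

# print the learned decision rules as indented text
print(export_text(tree_clf, feature_names=iris.feature_names))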
Decision Tree Classifier With Spam Email Detection Dataset
Next, let's train a decision tree on a spam email detection dataset to predict whether an e-mail is spam or safe to open (ham). The scikit-learn library is again used for this implementation.
Let's load the spam email dataset and plot the count of spam and ham emails using matplotlib. The code is as follows:
Python
import pandas as pd
import matplotlib.pyplot as plt
dataset_link = 'https://media.geeksforgeeks.org/wp-content/uploads/20240620175612/spam_email.csv'
df = pd.read_csv(dataset_link)
df['Category'].value_counts().plot.bar(color=["g", "r"])
plt.title('Total number of ham and spam in the dataset')
plt.show()
Output:
Spam and Ham Email Count
As a next step, let's prepare the data for the decision tree classifier. The code is as follows:
Python
from nltk.tokenize import RegexpTokenizer
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.model_selection import train_test_split
def clean_str(string):
    reg = RegexpTokenizer(r'[a-z]+')
    string = string.lower()
    tokens = reg.tokenize(string)
    return " ".join(tokens)
df['Category'] = df['Category'].map({'ham': 0, 'spam': 1})
df['text_clean'] = df['Message'].apply(clean_str)
cv = CountVectorizer()
X = cv.fit_transform(df.text_clean)
y = df.Category
x_train, x_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42)
As part of data preparation, the Category string is mapped to a numeric attribute, RegexpTokenizer is used for message cleaning and CountVectorizer() is used to convert the text documents into a matrix of token counts. Finally, the dataset is separated into training and test sets.
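To see what the cleaning step does, here is a quick check on a made-up message (the text is illustrative, not taken from the dataset):
Python
# illustrative message: lowercased, with punctuation stripped
print(clean_str("FREE entry!! Win a PRIZE now"))
# output: free entry win a prize now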
Now we can use the prepared data to train a DecisionTreeClassifier. The code is as follows:
Python
from sklearn.metrics import (
    accuracy_score, classification_report,
    confusion_matrix)
from sklearn.tree import DecisionTreeClassifier
model = DecisionTreeClassifier()
model.fit(x_train, y_train)
pred = model.predict(x_test)
print(classification_report(y_test, pred))
Output:
              precision    recall  f1-score   support

           0       0.98      0.98      0.98       966
           1       0.89      0.89      0.89       149

    accuracy                           0.97      1115
   macro avg       0.94      0.93      0.94      1115
weighted avg       0.97      0.97      0.97      1115
Here we used DecisionTreeClassifier from sklearn to train our model, and classification_report() is used for evaluating the predictions. Let's check the confusion matrix for the decision tree classifier.
Python
import seaborn as sns
cmat = confusion_matrix(y_test, pred)
sns.heatmap(cmat, annot=True, cmap='Paired',
            cbar=False, fmt="d",
            xticklabels=['Not Spam', 'Spam'],
            yticklabels=['Not Spam', 'Spam'])
Output:
Heatmap
Advantages and Disadvantages of Decision Tree Classifier
Advantages of Decision Tree Classifier
- Interpretability: Decision trees are simple to understand and visualize, which makes the model's decisions easy to explain.
- Handles Non-linearity: They can capture non-linear relationships between features and target variables without requiring feature scaling.
- Handles Mixed Data Types: In principle, decision trees can handle both numerical and categorical data without one-hot encoding (note that scikit-learn's implementation expects numeric input).
- Feature Selection: By placing the most informative features near the top of the tree, they implicitly perform feature selection.
- No Assumptions about Data Distribution: Unlike parametric models, decision trees make no assumptions about the distribution of the data.
- Robust to Outliers: Because splits depend on thresholds rather than distances, they are resistant to outliers in the data.
Drawbacks of Decision Tree Classifier
- Overfitting: If the tree isn't restricted or pruned, decision trees tend to overfit and capture noise and anomalies in the training data (see the pruning sketch after this list).
- Instability: Minor changes in the data can result in entirely different tree structures, which makes them unstable.
- Difficulty in Capturing Complex Interactions: Deeper trees may be necessary to capture complex feature interactions such as XOR.
- Bias towards Dominant Classes: Decision trees may exhibit bias towards dominant classes in datasets with an uneven class distribution.
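To address overfitting, one option is minimal cost-complexity pruning via the ccp_alpha parameter described earlier. A minimal sketch, assuming the Iris train/test split (X_train, X_test) from above; in practice, pick alpha with a validation set or cross-validation rather than the test set:
Python
from sklearn.tree import DecisionTreeClassifier

# compute the effective alphas along the pruning path
clf = DecisionTreeClassifier(random_state=1)
path = clf.cost_complexity_pruning_path(X_train, y_train)

# train one tree per alpha and record its test accuracy
scores = []
for alpha in path.ccp_alphas:
    pruned = DecisionTreeClassifier(random_state=1, ccp_alpha=alpha)
    pruned.fit(X_train, y_train)
    scores.append(pruned.score(X_test, y_test))
print(max(scores))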
Decision Tree Classifiers are a powerful and interpretable tool in machine learning, and we can implement them using the Scikit-Learn Python library. By using hyperparameter tuning methods like GridSearchCV we can optimize their performance.