Implementing Decision Tree Classifiers with Scikit-Learn
Last Updated: 18 May, 2025
A Decision Tree Classifier is a method for classifying data into categories such as "Yes" or "No", or "Spam" or "Not Spam". It uses a tree-like structure that asks a sequence of questions to split the data step by step. Each split is based on an input feature and is chosen to help the model make accurate predictions. At the end of each branch, called a leaf node, the model assigns a class label based on the majority class of the training samples in that group.
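To make the idea of a split followed by a majority vote concrete, here is a minimal sketch of a one-question tree written by hand. The toy emails, the "number of links" feature and the threshold of 3 are all hypothetical values chosen for illustration; a real decision tree learns the question from the data.

```python
from collections import Counter


def majority_class(labels):
    """Return the most common label in a group (what a leaf node predicts)."""
    return Counter(labels).most_common(1)[0][0]


# Hypothetical toy data: (number of links in an email, label)
emails = [(0, "Not Spam"), (1, "Not Spam"), (2, "Not Spam"),
          (5, "Spam"), (6, "Not Spam"), (7, "Spam")]

threshold = 3  # hypothetical split point, picked by hand for this sketch

# One question ("links <= threshold?") splits the data into two groups
left_labels = [label for links, label in emails if links <= threshold]
right_labels = [label for links, label in emails if links > threshold]


def predict(links):
    """Follow the single split, then answer with the leaf's majority class."""
    group = left_labels if links <= threshold else right_labels
    return majority_class(group)


print(predict(1))  # falls in the left leaf
print(predict(8))  # falls in the right leaf
```

The right-hand group is not pure (two "Spam", one "Not Spam"), so the leaf still predicts "Spam" by majority.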
Implementation of Decision Tree Classifier
The syntax for DecisionTreeClassifier is as follows:
class sklearn.tree.DecisionTreeClassifier(*, criterion='gini', splitter='best', max_depth=None, min_samples_split=2, min_samples_leaf=1, min_weight_fraction_leaf=0.0, max_features=None, random_state=None, max_leaf_nodes=None, min_impurity_decrease=0.0, class_weight=None, ccp_alpha=0.0, monotonic_cst=None)
Let's go through the parameters:
- criterion: The function used to measure the quality of a split. Supported values are 'gini', 'entropy' and 'log_loss'. The default value is 'gini'.
- splitter: The strategy used to choose the split at each node. Supported values are 'best' and 'random'. The default value is 'best'.
- max_features: The number of features to consider when looking for the best split (default=None, meaning all features).
- max_depth: The maximum depth of the tree (default=None, meaning nodes are expanded until the other stopping rules apply).
- min_samples_split: The minimum number of samples required to split an internal node (default=2).
- min_samples_leaf: The minimum number of samples required to be at a leaf node (default=1).
- max_leaf_nodes: The maximum number of leaf nodes (default=None, meaning unlimited).
- min_impurity_decrease: A node is split only if the split decreases the impurity by at least this value (default=0.0). It replaces the older min_impurity_split parameter, which is no longer part of the signature shown above.
- class_weight: The weights associated with classes (default=None, meaning all classes have weight one).
- ccp_alpha: The complexity parameter used for minimal cost-complexity pruning (default=0.0, meaning no pruning).
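The criterion parameter scores candidate splits by how mixed the class labels are on each side. The following is a rough sketch of the two common impurity measures written in plain Python, not scikit-learn's internal implementation. The helper names and the toy labels are hypothetical, and the entropy is computed in bits (log base 2) as an assumption of this sketch.

```python
import math
from collections import Counter


def class_proportions(labels):
    """Fraction of samples per class in a non-empty list of labels."""
    n = len(labels)
    return [count / n for count in Counter(labels).values()]


def gini(labels):
    """Gini impurity: 1 - sum of squared class proportions."""
    return 1.0 - sum(p ** 2 for p in class_proportions(labels))


def entropy(labels):
    """Shannon entropy in bits: sum of -p * log2(p) over classes."""
    return sum(-p * math.log2(p) for p in class_proportions(labels))


def weighted_impurity(left, right, measure):
    """Impurity of a split: child impurities weighted by child size."""
    n = len(left) + len(right)
    return len(left) / n * measure(left) + len(right) / n * measure(right)


parent = [0, 0, 1, 1]          # hypothetical node with two classes, evenly mixed
left, right = [0, 0], [1, 1]   # a split that separates the classes perfectly

print(gini(parent), entropy(parent))
print(weighted_impurity(left, right, gini))
```

A split is attractive when the weighted impurity of the children is much lower than the impurity of the parent; here it drops from 0.5 to 0.0 under the Gini measure.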
Let's walk through a step-by-step implementation using the scikit-learn library in Python. We will use the Iris dataset to train the model.
Step 1: Import Libraries
We import the dataset loader, the train/test splitting helper, the classifier and the accuracy metric from scikit-learn.
Python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score
Step 2: Data Loading
To perform classification, first load a dataset. For demonstration, you can use one of scikit-learn's sample datasets such as Iris or Breast Cancer. Here we load Iris and separate the feature matrix X from the target vector y.
Python
data = load_iris()
X = data.data
y = data.target
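Before splitting, it can help to check what was loaded. The short sketch below only reads attributes of the object returned by load_iris and is independent of the rest of the tutorial.

```python
from sklearn.datasets import load_iris

data = load_iris()
X = data.data
y = data.target

# 150 flowers, each described by 4 numeric measurements
print(X.shape)

# Column names of X and the class names that the integer labels in y refer to
print(data.feature_names)
print(list(data.target_names))
```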
Step 3: Splitting Data
Use the train_test_split method from sklearn.model_selection to split the dataset into training and testing sets.
Python
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=99)
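As a quick check on the split, the sketch below prints the resulting sizes and the class counts in the test set. The second call uses the stratify argument of train_test_split, which the tutorial does not use; it is shown only as an optional variant that keeps the class proportions equal in both sets.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

X, y = load_iris(return_X_y=True)

# Same split as in the tutorial: 30% of the 150 rows go to the test set
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=99)
print(X_train.shape, X_test.shape)

# Optional variant (not used in the tutorial): stratify on the labels
X_train_s, X_test_s, y_train_s, y_test_s = train_test_split(
    X, y, test_size=0.3, random_state=99, stratify=y)

# Class counts in each test set; the stratified one is balanced
print(np.bincount(y_test, minlength=3))
print(np.bincount(y_test_s, minlength=3))
```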
Step 4: Creating the Model
Create a DecisionTreeClassifier object from sklearn.tree. Setting random_state makes the result reproducible.
Python
clf = DecisionTreeClassifier(random_state=1)
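The constructor parameters listed earlier can be passed at this point. As a rough sketch, the example below compares the default tree with a deliberately constrained one; the specific values (entropy, max_depth=2, min_samples_leaf=5) are hypothetical choices for illustration, not recommendations, and both trees are fitted on the full Iris data only to show the effect on tree size.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)

# Default settings: the tree grows until the leaves are pure
default_tree = DecisionTreeClassifier(random_state=1)

# Hypothetical constrained settings, chosen only for illustration
small_tree = DecisionTreeClassifier(
    criterion="entropy", max_depth=2, min_samples_leaf=5, random_state=1)

# get_params() returns the constructor arguments as a dict
print(small_tree.get_params()["criterion"], small_tree.get_params()["max_depth"])

default_tree.fit(X, y)
small_tree.fit(X, y)

# Depth and number of leaves of each fitted tree
print(default_tree.get_depth(), default_tree.get_n_leaves())
print(small_tree.get_depth(), small_tree.get_n_leaves())
```

The constrained tree can never exceed depth 2 (at most four leaves), which trades some training accuracy for a simpler model.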
Step 5: Training the Model
Call the fit method to train the classifier on the training data.
Python
clf.fit(X_train, y_train)
Output:
[Figure: Decision Tree Classifier]
Step 6: Making Predictions
Apply the predict method to the test data to generate predictions with the trained model, then compare them with the true labels using accuracy_score.
Python
y_pred = clf.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
print(f'Accuracy: {accuracy}')
Output:
Accuracy: 0.9555555555555556
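Accuracy is a single number, so it does not show which classes are being confused. As a sketch of a slightly more detailed evaluation, the example below repeats the steps above and adds confusion_matrix and classification_report from sklearn.metrics; these two functions are not used in the tutorial itself.

```python
from sklearn.datasets import load_iris
from sklearn.metrics import (accuracy_score, classification_report,
                             confusion_matrix)
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

data = load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    data.data, data.target, test_size=0.3, random_state=99)

clf = DecisionTreeClassifier(random_state=1).fit(X_train, y_train)
y_pred = clf.predict(X_test)

# Rows are true classes, columns are predicted classes
cm = confusion_matrix(y_test, y_pred, labels=[0, 1, 2])
print(cm)

# Per-class precision, recall and F1 score
print(classification_report(
    y_test, y_pred, labels=[0, 1, 2],
    target_names=list(data.target_names)))

# The diagonal of the confusion matrix holds the correct predictions
print(cm.trace() / cm.sum(), accuracy_score(y_test, y_pred))
```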
Hyperparameter Tuning with Decision Tree Classifier
Hyperparameters are configuration settings, such as max_depth or min_samples_leaf, that are chosen before training and control how the decision tree is grown. Proper tuning can improve accuracy, reduce overfitting and help the model generalize to unseen data. Popular tuning methods include Grid Search, Random Search and Bayesian Optimization, each of which explores different combinations of values to find the best configuration.
For more details, see the article "How to tune a Decision Tree in Hyperparameter tuning?"
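As an illustration of the Random Search approach mentioned above, here is a minimal sketch using scikit-learn's RandomizedSearchCV. Instead of trying every combination, it samples a fixed number of candidates from the given distributions. The ranges and n_iter=20 are assumptions made for this sketch, not tuned values.

```python
from scipy.stats import randint
from sklearn.datasets import load_iris
from sklearn.model_selection import RandomizedSearchCV, train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=99)

# Distributions to sample from; randint(a, b) draws integers in [a, b)
param_distributions = {
    "max_depth": randint(1, 10),
    "min_samples_leaf": randint(1, 20),
    "min_samples_split": randint(2, 20),
    "criterion": ["gini", "entropy"],
}

random_search = RandomizedSearchCV(
    estimator=DecisionTreeClassifier(random_state=1),
    param_distributions=param_distributions,
    n_iter=20,        # number of sampled candidates (assumed value)
    cv=5,
    random_state=1)
random_search.fit(X_train, y_train)

print(random_search.best_params_)
print(random_search.best_score_)
```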
Hyperparameter tuning using GridSearchCV
Let's use scikit-learn's GridSearchCV to find the best combination of hyperparameter values. The code is as follows:
Python
from sklearn.model_selection import GridSearchCV

# Hyperparameters to tune and the candidate values for each
param_grid = {
    'max_depth': range(1, 10, 1),
    'min_samples_leaf': range(1, 20, 2),
    'min_samples_split': range(2, 20, 2),
    'criterion': ["entropy", "gini"]
}

tree = DecisionTreeClassifier(random_state=1)

# Exhaustive search over the grid with 5-fold cross-validation
grid_search = GridSearchCV(estimator=tree, param_grid=param_grid,
                           cv=5, verbose=True)
grid_search.fit(X_train, y_train)

print("best accuracy", grid_search.best_score_)
print(grid_search.best_estimator_)
Output:
[Figure: Hyperparameter Tuning]
Here we defined a parameter grid containing each hyperparameter and a list of candidate values. GridSearchCV evaluates every combination for the DecisionTreeClassifier and selects the one with the best mean cross-validated score across all k folds (k = 5 here, set by cv=5).
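The cross-validated score is computed on the training data only, so it is worth checking the selected model on the held-out test set as well. The sketch below does this with a deliberately reduced grid (an assumption, to keep the example fast); the attribute names best_params_, best_estimator_ and cv_results_ are the standard GridSearchCV ones.

```python
from sklearn.datasets import load_iris
from sklearn.metrics import accuracy_score
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=99)

# Reduced grid for illustration: 3 * 2 * 2 = 12 combinations
param_grid = {
    "max_depth": [2, 3, 4],
    "min_samples_leaf": [1, 5],
    "criterion": ["gini", "entropy"],
}

grid_search = GridSearchCV(
    DecisionTreeClassifier(random_state=1), param_grid, cv=5)
grid_search.fit(X_train, y_train)

# Winning combination as a plain dict
print(grid_search.best_params_)

# best_estimator_ is refit on the whole training set (refit=True by default)
best_tree = grid_search.best_estimator_
test_accuracy = accuracy_score(y_test, best_tree.predict(X_test))
print(test_accuracy)
```

If the test accuracy is much lower than best_score_, the search may have overfit the training folds.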
Visualizing the Decision Tree Classifier
Decision tree visualization helps us interpret and understand the model's choices. We'll plot the fitted tree with plot_tree to see which features and thresholds are used at each split. Here we take the best estimator found by GridSearchCV as the decision tree classifier.
Python
from sklearn.tree import plot_tree
import matplotlib.pyplot as plt
tree_clf = grid_search.best_estimator_
plt.figure(figsize=(18, 15))
# data is the object returned by load_iris() in Step 2
plot_tree(tree_clf, filled=True, feature_names=data.feature_names,
          class_names=data.target_names)
plt.show()
Output:
[Figure: Decision Tree Visualization]
We can see that the tree starts from the root node (depth 0, at the top).
- The root node checks whether the flower's petal width is less than or equal to 0.75. If it is, we move to the root's left child node (depth 1, left). This left node has no child nodes, so it is a leaf and the classifier predicts setosa for every sample that reaches it.
- If the petal width is greater than 0.75, we move down to the root's right child node (depth 1, right). This right node is not a leaf, so it checks a further condition, and the process repeats until a leaf node is reached.
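The same walk from the root can be followed without a plot. As a sketch, the example below prints the tree as text rules with export_text and reads the root split from the fitted tree_ attribute. It trains a fresh tree with max_depth=3 (a hypothetical setting, not the tuned estimator above) so that it runs on its own, and the exact threshold it prints may differ from the figure.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, export_text

data = load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    data.data, data.target, test_size=0.3, random_state=99)

# Hypothetical shallow tree, used here only to keep the printout short
clf = DecisionTreeClassifier(max_depth=3, random_state=1).fit(X_train, y_train)

# Indented text version of the tree: one line per test or leaf
rules = export_text(clf, feature_names=list(data.feature_names))
print(rules)

# Node 0 is the root; children_left is -1 for leaf nodes
tree = clf.tree_
root_feature = data.feature_names[tree.feature[0]]
root_threshold = tree.threshold[0]
left_child = tree.children_left[0]
left_is_leaf = tree.children_left[left_child] == -1

print(root_feature, root_threshold, left_is_leaf)
```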
Decision tree classifiers are easy to train and interpret, and by using hyperparameter tuning methods like GridSearchCV we can optimize their performance.