
Implementing Decision Tree Classifiers with Scikit-Learn

Last Updated : 18 May, 2025

Decision Tree Classifier is a method used to classify data into categories such as "Yes" or "No", or different types such as "Spam" or "Not Spam". It works by building a tree-like structure that asks a series of questions to split the data step by step. These splits are based on the input features and help the model make accurate predictions. At the end of each branch, called a leaf node, the model assigns a class label based on the majority class of the samples in that group.

Implementation of Decision Tree Classifier

The syntax for DecisionTreeClassifier is as follows:

class sklearn.tree.DecisionTreeClassifier(*, criterion='gini', splitter='best', max_depth=None, min_samples_split=2, min_samples_leaf=1, min_weight_fraction_leaf=0.0, max_features=None, random_state=None, max_leaf_nodes=None, min_impurity_decrease=0.0, class_weight=None, ccp_alpha=0.0, monotonic_cst=None)

Let's go through the parameters (a short usage sketch follows the list):

  • criterion: Measures the quality of a split. Supported values are 'gini', 'entropy' and 'log_loss'. The default value is 'gini'.
  • splitter: The strategy used to choose the split at each node. Supported values are 'best' and 'random'. The default value is 'best'.
  • max_features: The number of features to consider when looking for the best split.
  • max_depth: The maximum depth of the tree (default=None).
  • min_samples_split: The minimum number of samples required to split an internal node (default=2).
  • min_samples_leaf: The minimum number of samples required to be at a leaf node (default=1).
  • max_leaf_nodes: The maximum number of leaf nodes the tree can have.
  • min_impurity_decrease: A node is split only if the split decreases the impurity by at least this value.
  • class_weight: The weights associated with the classes.
  • ccp_alpha: The complexity parameter used for Minimal Cost-Complexity Pruning.
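As an example, a classifier that limits tree depth and requires a minimum number of samples per leaf can be created as follows. This is a minimal sketch; the specific values are illustrative, not recommendations.

Python
from sklearn.tree import DecisionTreeClassifier

clf = DecisionTreeClassifier(
    criterion='entropy',      # measure split quality with entropy instead of the default 'gini'
    max_depth=4,              # limit tree depth to reduce overfitting
    min_samples_leaf=5,       # each leaf must contain at least 5 samples
    class_weight='balanced',  # weight classes inversely to their frequencies
    random_state=42           # make the fitted tree reproducible
)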

Let's see the step-by-step implementation using the scikit-learn library in Python. We will use the Iris dataset to train the model.

Step 1: Import Libraries

We import the modules needed for this example: the Iris dataset loader, the train/test splitting utility, the DecisionTreeClassifier and the accuracy metric.

Python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score

Step 2: Data Loading

To perform classification, first load a dataset. For demonstration we use the Iris dataset; other sample datasets from Scikit-Learn, such as Breast Cancer, work the same way.

Python
data = load_iris()
X = data.data  
y = data.target 
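Optionally, you can take a quick look at what was loaded before modelling; load_iris returns a Bunch object that also carries the feature and class names:

Python
print(X.shape)              # (150, 4): 150 samples, 4 features
print(data.feature_names)   # sepal/petal length and width (in cm)
print(data.target_names)    # ['setosa' 'versicolor' 'virginica']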

Step 3: Splitting Data

Use the train_test_split method from sklearn.model_selection to split the dataset into training and testing sets.

Python
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=99)
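The split above is purely random. If you want the class proportions preserved in both subsets, train_test_split also accepts a stratify argument; an optional variant (not used in the rest of this article) looks like this:

Python
# Optional variant: keep the class distribution identical in train and test sets
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=99, stratify=y)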

Step 4: Creating the Model

Create a Decision Tree Classifier object using DecisionTreeClassifier from sklearn.tree.

Python
clf = DecisionTreeClassifier(random_state=1)

Step 5: Training the Model

Call the fit method to train the classifier on the training data.

Python
clf.fit(X_train, y_train)

Output:

Decision Tree Classifier (fitted estimator)
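Before making predictions, you can optionally inspect the fitted tree; DecisionTreeClassifier exposes get_depth and get_n_leaves for this:

Python
print(clf.get_depth())     # depth of the learned tree
print(clf.get_n_leaves())  # number of leaf nodes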

Step 6: Making Predictions

Use the trained model's predict method to generate predictions on the test data, then measure accuracy with accuracy_score.

Python
y_pred = clf.predict(X_test)

accuracy = accuracy_score(y_test, y_pred)
print(f'Accuracy: {accuracy}')

Output:

Accuracy: 0.9555555555555556
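Accuracy alone can hide per-class behaviour. As an optional follow-up, classification_report and confusion_matrix from sklearn.metrics give a per-class breakdown:

Python
from sklearn.metrics import classification_report, confusion_matrix

# Per-class precision, recall and F1-score
print(classification_report(y_test, y_pred, target_names=data.target_names))

# Rows are true classes, columns are predicted classes
print(confusion_matrix(y_test, y_pred))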

Hyperparameter Tuning with Decision Tree Classifier

Hyperparameters are configuration settings that control the behavior of a decision tree model and significantly affect its performance. Proper tuning can improve accuracy, reduce overfitting and enhance the model's generalization. Popular tuning methods include Grid Search, Random Search and Bayesian Optimization, which explore different combinations of values to find the best configuration.

For more details you can check out the article How to tune a Decision Tree in Hyperparameter tuning?
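Scikit-Learn also provides RandomizedSearchCV for the Random Search approach mentioned above; instead of trying every combination it samples a fixed number of candidate configurations. A minimal sketch (the parameter ranges are illustrative only):

Python
from sklearn.model_selection import RandomizedSearchCV
from sklearn.tree import DecisionTreeClassifier

# Candidate values to sample from
param_distributions = {
    'max_depth': range(1, 10),
    'min_samples_leaf': range(1, 20),
    'criterion': ['gini', 'entropy']
}

# Try 20 randomly sampled combinations with 5-fold cross-validation
random_search = RandomizedSearchCV(
    DecisionTreeClassifier(random_state=1),
    param_distributions=param_distributions,
    n_iter=20, cv=5, random_state=1)
random_search.fit(X_train, y_train)
print(random_search.best_params_)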

Hyperparameter Tuning using GridSearchCV

Let's make use of Scikit-Learn's GridSearchCV to find the best combination of hyperparameter values. The code is as follows:

Python
from sklearn.model_selection import GridSearchCV

# Hyperparameter to fine tune
param_grid = {
    'max_depth': range(1, 10, 1),
    'min_samples_leaf': range(1, 20, 2),
    'min_samples_split': range(2, 20, 2),
    'criterion': ["entropy", "gini"]
}

tree = DecisionTreeClassifier(random_state=1)
# GridSearchCV
grid_search = GridSearchCV(estimator=tree, param_grid=param_grid, 
                           cv=5, verbose=True)
grid_search.fit(X_train, y_train)

print("best accuracy", grid_search.best_score_)
print(grid_search.best_estimator_)

Output:

Hyperparameter Tuning (GridSearchCV output)

Here we define the parameter grid with a set of hyperparameters and a list of possible values for each. GridSearchCV evaluates the different hyperparameter combinations for the DecisionTreeClassifier and selects the best combination based on the cross-validated performance across all folds.
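To see how the tuned tree generalizes, you can optionally print the best parameter combination and score the refitted best estimator on the held-out test set:

Python
# Best hyperparameter combination found by the grid search
print(grid_search.best_params_)

# Evaluate the refitted best estimator on the held-out test data
best_tree = grid_search.best_estimator_
print(best_tree.score(X_test, y_test))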

Visualizing the Decision Tree Classifier

Decision tree visualization is used to interpret and understand the model's choices. We'll plot the learned tree to see how the features split the data at each node. Here we take the best estimator obtained from GridSearchCV as the decision tree classifier.

Python
from sklearn.tree import plot_tree
import matplotlib.pyplot as plt
 
tree_clf = grid_search.best_estimator_

plt.figure(figsize=(18, 15))
plot_tree(tree_clf, filled=True, feature_names=data.feature_names,
          class_names=data.target_names)
plt.show()

Output:

Decision Tree Visualization

We can see that the tree starts from the root node (depth 0, at the top).

  • The root node checks whether the petal width is less than or equal to 0.75. If it is, we move to the root's left child node (depth 1, left). This left node has no children, so the classifier predicts the class setosa for any sample that reaches it.
  • If the petal width is greater than 0.75, we move to the root's right child node (depth 1, right). This node is not a leaf, so the tree keeps checking conditions until a leaf node is reached.
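Beyond the tree plot, the fitted tree also exposes a feature_importances_ attribute that summarizes which features carry the most predictive power. A minimal sketch of plotting it as a bar chart:

Python
import matplotlib.pyplot as plt

# Impurity-based importance of each input feature in the fitted tree
importances = tree_clf.feature_importances_

plt.figure(figsize=(8, 4))
plt.bar(data.feature_names, importances)
plt.ylabel('Importance')
plt.title('Decision Tree feature importances')
plt.xticks(rotation=30)
plt.tight_layout()
plt.show()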

By using hyperparameter tuning methods like GridSearchCV we can optimize the model's performance.

