
How to choose ideal Decision Tree depth without overfitting?

Last Updated : 09 Jan, 2025

Choosing the ideal depth for a decision tree is crucial to avoid overfitting, a common issue where the model fits the training data, including its noise, so closely that it fails to generalize to new data. Each extra level lets the tree split the data into smaller, more specific regions, so depth directly controls how much the model can memorize. The core idea is to balance the complexity of the model with its ability to generalize. Here, we will explore how to set the optimal depth for decision trees to prevent overfitting. Let's discuss a few techniques for preventing overfitting in decision trees:

1. Use Cross-Validation

Split the training data into k subsets (folds). For each candidate depth, train a decision tree on k-1 folds and validate it on the remaining fold, repeating until every fold has served as the validation set once, then average the validation scores. The depth with the best average score is the one that generalizes best to unseen data. For instance, training trees with depths from 1 to 15 might reveal that depth 7 achieves the best validation accuracy.
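To make the fold rotation concrete, here is a minimal sketch that writes the cross-validation loop out by hand on the Iris dataset used later in this article. It assumes plain shuffled `KFold` with 5 folds (note that `cross_val_score` uses stratified, unshuffled folds for classifiers by default, so its numbers can differ slightly).

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import KFold
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)

# Shuffle before splitting: Iris rows are sorted by class, so unshuffled
# folds would each be missing whole classes.
kf = KFold(n_splits=5, shuffle=True, random_state=42)

mean_val_acc = {}
for depth in range(1, 16):
    fold_scores = []
    for train_idx, val_idx in kf.split(X):
        # Train on the other four folds...
        clf = DecisionTreeClassifier(max_depth=depth, random_state=42)
        clf.fit(X[train_idx], y[train_idx])
        # ...and score on the held-out fold
        fold_scores.append(clf.score(X[val_idx], y[val_idx]))
    mean_val_acc[depth] = np.mean(fold_scores)

# max() returns the first maximum it meets, so ties go to the shallower tree
best_depth = max(mean_val_acc, key=mean_val_acc.get)
print(f"Best depth: {best_depth}, mean validation accuracy: {mean_val_acc[best_depth]:.3f}")
```

Breaking ties toward the shallower tree is a deliberate choice: when two depths validate equally well, the simpler model is usually the safer bet.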

2. Set a Maximum Depth Parameter

Set a maximum depth for the tree (the max_depth parameter in scikit-learn). A range of 3 to 10 is a common starting point, with the right value depending on the complexity of the data. Limiting depth prevents the model from capturing noise or irrelevant patterns. For example, a tree with a maximum depth of 5 may generalize better than a deeper tree that overfits by learning minor irregularities in the data.
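The effect of a depth cap is easiest to see on noisy data. This sketch assumes a synthetic dataset from `make_classification` with 10% of labels flipped (`flip_y=0.1`), chosen only for illustration; it fits one unconstrained tree and one capped at `max_depth=5`, then reports each tree's actual depth and its training and test accuracy.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Synthetic data with 10% of labels flipped, so there is noise to memorize
X, y = make_classification(n_samples=500, n_features=10, flip_y=0.1, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=0)

# No depth limit: the tree keeps splitting until every leaf is pure
unlimited = DecisionTreeClassifier(random_state=0).fit(X_train, y_train)
# Depth capped at 5 levels
limited = DecisionTreeClassifier(max_depth=5, random_state=0).fit(X_train, y_train)

for name, model in [("unlimited", unlimited), ("max_depth=5", limited)]:
    print(f"{name:>12}: depth={model.get_depth():2d}  "
          f"train={model.score(X_train, y_train):.3f}  "
          f"test={model.score(X_test, y_test):.3f}")
```

Typically the unconstrained tree grows much deeper and fits the training set almost perfectly, including the flipped labels; comparing its test score with the capped tree's shows whether that extra depth was worth it on this data.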

3. Monitor Training and Validation Accuracy Trends

Track training and validation accuracy as tree depth increases. Overfitting becomes evident when validation accuracy stops improving (or starts to fall) while training accuracy continues to rise. For example, if validation accuracy plateaus at depth 8 but training accuracy keeps improving beyond it, depth 8 is a sensible choice: it is the shallowest tree that reaches the best validation performance.
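scikit-learn's `validation_curve` computes both curves in one call. The sketch below prints the train/validation gap per depth and then applies the "shallowest depth on the plateau" rule; the tolerance of 0.01 that defines the plateau is an assumed value, not something prescribed by the library.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import validation_curve
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
depths = list(range(1, 16))

# Both arrays have shape (len(depths), n_folds)
train_scores, val_scores = validation_curve(
    DecisionTreeClassifier(random_state=42), X, y,
    param_name="max_depth", param_range=depths, cv=5,
)
train_mean = train_scores.mean(axis=1)
val_mean = val_scores.mean(axis=1)

# Assumed tolerance: depths within 0.01 of the best score count as the plateau
tolerance = 0.01
# argmax on a boolean array returns the index of the first True value
plateau_depth = depths[int(np.argmax(val_mean >= val_mean.max() - tolerance))]

for d, tr, va in zip(depths, train_mean, val_mean):
    print(f"depth {d:2d}: train={tr:.3f}  val={va:.3f}  gap={tr - va:.3f}")
print(f"Shallowest depth on the validation plateau: {plateau_depth}")
```

A widening `gap` column alongside a flat `val` column is the overfitting signature described above.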

4. Automated Depth Optimization

Use scikit-learn's GridSearchCV or RandomizedSearchCV to identify the best tree depth automatically. GridSearchCV evaluates every value in a range such as 1 to 15, while RandomizedSearchCV samples a fixed number of candidates, which is cheaper when several parameters are tuned together. Both select the setting with the best cross-validation score. For example, grid search might suggest depth 6 as the best balance between accuracy and generalization.
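The main example in this article uses GridSearchCV, so here is a sketch of the randomized alternative. It assumes we tune `max_depth` together with `min_samples_leaf`, sampling 20 configurations from integer distributions built with `scipy.stats.randint`; the ranges and `n_iter=20` are illustrative choices.

```python
from scipy.stats import randint
from sklearn.datasets import load_iris
from sklearn.model_selection import RandomizedSearchCV, train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# randint(low, high) samples integers in [low, high), so 1..15 and 1..10 here
param_distributions = {
    "max_depth": randint(1, 16),
    "min_samples_leaf": randint(1, 11),
}
search = RandomizedSearchCV(
    DecisionTreeClassifier(random_state=42),
    param_distributions,
    n_iter=20,          # number of sampled configurations
    cv=5,               # 5-fold cross-validation per configuration
    random_state=42,    # makes the sampled configurations reproducible
)
search.fit(X_train, y_train)

print("Best parameters:", search.best_params_)
print(f"Mean CV accuracy: {search.best_score_:.3f}")
# The search refits the best configuration on all of X_train by default
print(f"Held-out test accuracy: {search.score(X_test, y_test):.3f}")
```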

5. Pruning Techniques

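Apply pruning parameters such as min_samples_split, min_samples_leaf, or ccp_alpha to simplify the tree. min_samples_split and min_samples_leaf act as pre-pruning: they stop the tree from making splits that would leave too few samples in a node. ccp_alpha applies minimal cost-complexity post-pruning: the tree is grown fully, then subtrees whose impurity reduction per additional leaf is smaller than ccp_alpha are removed. Either way, pruning reduces complexity and improves generalization by eliminating branches that do not significantly improve predictions. For instance, pruning might reduce a tree from a depth of 15 to an effective depth of 8, resulting in better performance on validation data.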
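The main example in this article only uses the pre-pruning parameters, so this sketch shows ccp_alpha. It uses `cost_complexity_pruning_path` to list the candidate alphas for the fully grown tree, picks the one with the best 5-fold cross-validation score on the training data, and compares the depth of the full and pruned trees. Dropping the last alpha (which prunes the tree to a single node) and clipping at zero are assumed safeguards.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import cross_val_score, train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Effective alphas at which subtrees of the fully grown tree get pruned
path = DecisionTreeClassifier(random_state=42).cost_complexity_pruning_path(X_train, y_train)
# Drop the last alpha (it reduces the tree to its root) and guard against
# tiny negative values from floating-point round-off
ccp_alphas = np.clip(path.ccp_alphas[:-1], 0.0, None)

# Mean 5-fold CV accuracy for each candidate alpha
cv_means = [
    cross_val_score(DecisionTreeClassifier(ccp_alpha=a, random_state=42),
                    X_train, y_train, cv=5).mean()
    for a in ccp_alphas
]
best_alpha = ccp_alphas[int(np.argmax(cv_means))]

full = DecisionTreeClassifier(random_state=42).fit(X_train, y_train)
pruned = DecisionTreeClassifier(ccp_alpha=best_alpha, random_state=42).fit(X_train, y_train)
print(f"best ccp_alpha = {best_alpha:.4f}")
print(f"depth: full={full.get_depth()}  pruned={pruned.get_depth()}")
print(f"test accuracy: full={full.score(X_test, y_test):.3f}  "
      f"pruned={pruned.score(X_test, y_test):.3f}")
```

Because post-pruning only removes subtrees from the grown tree, the pruned tree is never deeper than the full one.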

Let's understand with the below example:

This code demonstrates five techniques to prevent overfitting in decision trees.

  • First, it uses cross-validation to evaluate model accuracy across depths, helping identify an optimal depth.
  • Next, it limits the tree's depth to prevent overfitting, then monitors training and validation accuracy to observe overfitting trends.
  • It also applies grid search to automatically find the best depth.
  • Finally, it applies the pre-pruning parameters `min_samples_split` and `min_samples_leaf` to simplify the tree. Each method either plots or prints its results, showing its effect on training, validation, and test accuracy.
Python
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split, cross_val_score, GridSearchCV
from sklearn.metrics import accuracy_score

# Load the Iris dataset and hold out 20% of it as a test set
data = load_iris()
X, y = data.data, data.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

depths = range(1, 16)
cv_accuracy = []

# Method 1: mean 5-fold cross-validated accuracy for each candidate depth
print("Method 1: Cross-Validation")
for depth in depths:
    clf = DecisionTreeClassifier(max_depth=depth, random_state=42)
    scores = cross_val_score(clf, X_train, y_train, cv=5)
    cv_accuracy.append(scores.mean())

plt.figure(figsize=(10, 6))
plt.plot(depths, cv_accuracy, marker='o', label="Cross-Validation Accuracy")
plt.xlabel('Tree Depth')
plt.ylabel('Cross-Validated Accuracy')
plt.title('Tree Depth vs. Cross-Validated Accuracy')
plt.legend()
plt.grid(True)
plt.show()

# Method 2: train a single tree with a fixed maximum depth
max_depth = 7
clf = DecisionTreeClassifier(max_depth=max_depth, random_state=42)
clf.fit(X_train, y_train)
train_accuracy_fixed_depth = accuracy_score(y_train, clf.predict(X_train))
test_accuracy_fixed_depth = accuracy_score(y_test, clf.predict(X_test))

print(f"Method 2: Set Max Depth to {max_depth}")
print(f"Training Accuracy (Max Depth {max_depth}): {train_accuracy_fixed_depth}")
print(f"Test Accuracy (Max Depth {max_depth}): {test_accuracy_fixed_depth}")

# Method 3: compare training accuracy with cross-validated accuracy per depth;
# a widening gap between the two curves signals overfitting
print("Method 3: Training vs. Validation Accuracy")
train_accuracy = []
validation_accuracy = []
for depth in depths:
    clf = DecisionTreeClassifier(max_depth=depth, random_state=42)
    clf.fit(X_train, y_train)
    train_accuracy.append(clf.score(X_train, y_train))
    # cross_val_score fits its own clones of clf on each fold
    validation_scores = cross_val_score(clf, X_train, y_train, cv=5)
    validation_accuracy.append(validation_scores.mean())

plt.figure(figsize=(10, 6))
plt.plot(depths, train_accuracy, marker='o', label="Training Accuracy")
plt.plot(depths, validation_accuracy, marker='o', label="Validation Accuracy")
plt.xlabel('Tree Depth')
plt.ylabel('Accuracy')
plt.title('Tree Depth vs. Training and Validation Accuracy')
plt.legend()
plt.grid(True)
plt.show()

# Method 4: GridSearchCV tries every depth from 1 to 15 with 5-fold CV
param_grid = {'max_depth': range(1, 16)}
grid_search = GridSearchCV(DecisionTreeClassifier(random_state=42), param_grid, cv=5)
grid_search.fit(X_train, y_train)
best_depth = grid_search.best_params_['max_depth']
print(f"Method 4: Best Depth found by Grid Search is {best_depth}")

# Method 5: pre-pruning; a node needs at least 4 samples to be split,
# and every leaf must keep at least 2 samples
clf_pruned = DecisionTreeClassifier(max_depth=best_depth, min_samples_split=4, min_samples_leaf=2, random_state=42)
clf_pruned.fit(X_train, y_train)
train_accuracy_pruned = accuracy_score(y_train, clf_pruned.predict(X_train))
test_accuracy_pruned = accuracy_score(y_test, clf_pruned.predict(X_test))

print("Method 5: Pruning with min_samples_split=4, min_samples_leaf=2")
print(f"Training Accuracy (Pruned): {train_accuracy_pruned}")
print(f"Test Accuracy (Pruned): {test_accuracy_pruned}")

Output:

[Figure: Effect of cross-validation and max depth on decision trees]

Method 4: Best Depth found by Grid Search is 4
Method 5: Pruning with min_samples_split=4, min_samples_leaf=2
Training Accuracy (Pruned): 0.9666666666666667
Test Accuracy (Pruned): 1.0

Key Takeaways for Preventing Overfitting in Decision Trees

  • Limit Tree Depth: Set a maximum depth to prevent the tree from becoming too complex.
  • Minimum Samples: Require a minimum number of samples for splits and leaf nodes.
  • Pruning: Remove non-contributing branches to simplify the model.
  • Cross-Validation: Use multiple subsets of data to evaluate and tune the model's performance.
  • Ensemble Methods: Combine multiple decision trees to reduce variance and improve robustness.
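For the ensemble takeaway, one common option in scikit-learn is `RandomForestClassifier`, which averages many randomized trees (bagging and boosting are other ensemble families). This sketch compares a single unconstrained tree with a 200-tree forest using 5-fold cross-validation; the noisy synthetic dataset and `n_estimators=200` are assumptions made for illustration.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_score
from sklearn.tree import DecisionTreeClassifier

# Assumed synthetic data with 10% label noise
X, y = make_classification(n_samples=500, n_features=10, flip_y=0.1, random_state=0)

tree = DecisionTreeClassifier(random_state=0)                         # one deep tree
forest = RandomForestClassifier(n_estimators=200, random_state=0)    # many averaged trees

tree_scores = cross_val_score(tree, X, y, cv=5)
forest_scores = cross_val_score(forest, X, y, cv=5)
print(f"single tree:   {tree_scores.mean():.3f} +/- {tree_scores.std():.3f}")
print(f"random forest: {forest_scores.mean():.3f} +/- {forest_scores.std():.3f}")
```

Averaging many decorrelated trees tends to cancel out the noise each individual tree memorizes, which is why forests often tolerate deep trees better than a single model does.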
