Text Classification using Decision Trees in Python
Last Updated: 18 Mar, 2024
Text classification is the process of assigning text documents to predefined categories. In this article, we explore how decision trees can be used to classify textual data.
Text Classification and Decision Trees
Text classification involves assigning predefined categories or labels to text documents based on their content. Decision trees are hierarchical tree structures that recursively partition the feature space based on the values of input features. They are particularly well-suited for classification tasks due to their simplicity, interpretability, and ability to handle non-linear relationships.
Decision Trees provide a clear and understandable model for text classification, making them an excellent choice for tasks where interpretability is as important as predictive power. Their inherent simplicity, however, might lead to challenges when dealing with very complex or nuanced text data, leading practitioners to explore more sophisticated or ensemble methods for improvement.
Implementation: Text Classification using Decision Trees
For text classification using Decision Trees in Python, we'll use the popular 20 Newsgroups dataset. The original collection comprises around 20,000 newsgroup documents, partitioned across 20 different newsgroups (the date-split version that scikit-learn downloads contains about 18,000 posts). We'll use scikit-learn to fetch the dataset, convert the text into feature vectors using TF-IDF vectorization, and then apply a Decision Tree classifier.
Ensure you have scikit-learn installed in your environment. You can install it using pip if you haven't already:
pip install scikit-learn
Importing Necessary Libraries
Python3
from sklearn.datasets import fetch_20newsgroups
from sklearn.model_selection import train_test_split
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import classification_report, accuracy_score
import matplotlib.pyplot as plt
import numpy as np
Load the Dataset
The 20 Newsgroups dataset is loaded with specific categories for simplification. Headers, footers, and quotes are removed to focus on the text content.
Python3
# Load the dataset
categories = ['alt.atheism', 'soc.religion.christian', 'comp.graphics', 'sci.med']
newsgroups_train = fetch_20newsgroups(subset='train', categories=categories, remove=('headers', 'footers', 'quotes'))
newsgroups_test = fetch_20newsgroups(subset='test', categories=categories, remove=('headers', 'footers', 'quotes'))
Exploratory Data Analysis
These code snippets provide basic exploratory data analysis by visualizing the distribution of classes in the training and test sets.
Python3
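# Display distribution of classes in the training set
# (the labels are read from the dataset object directly, since y_train is only defined in the preprocessing step)
class_distribution = np.bincount(newsgroups_train.target)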
plt.bar(range(len(class_distribution)), class_distribution)
plt.xticks(range(len(class_distribution)), newsgroups_train.target_names, rotation=45)
plt.title('Distribution of Classes in Training Set')
plt.xlabel('Class')
plt.ylabel('Number of Documents')
plt.show()
Output:
[Figure: bar chart "Distribution of Classes in Training Set", number of documents per class]
Python3
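# Display distribution of classes in the test set
# (the labels are read from the dataset object directly, since y_test is only defined in the preprocessing step)
class_distribution = np.bincount(newsgroups_test.target)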
plt.bar(range(len(class_distribution)), class_distribution)
plt.xticks(range(len(class_distribution)), newsgroups_test.target_names, rotation=45)
plt.title('Distribution of Classes in Test Set')
plt.xlabel('Class')
plt.ylabel('Number of Documents')
plt.show()
Output:
[Figure: bar chart "Distribution of Classes in Test Set", number of documents per class]
Data Preprocessing
Text data is converted into TF-IDF feature vectors. TF-IDF (Term Frequency-Inverse Document Frequency) is a numerical statistic that reflects how important a word is to a document in a collection. This step is crucial for converting text data into a format that can be used for machine learning.
Python3
# Data preprocessing
vectorizer = TfidfVectorizer(stop_words='english')
X_train = vectorizer.fit_transform(newsgroups_train.data)
X_test = vectorizer.transform(newsgroups_test.data)
y_train = newsgroups_train.target
y_test = newsgroups_test.target
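To see what the TF-IDF weights look like, the following sketch runs TfidfVectorizer on a hypothetical three-document corpus (invented for illustration). It is run without stop-word removal, unlike the code above, so that the down-weighting of a word that appears in every document stays visible.

```python
from sklearn.feature_extraction.text import TfidfVectorizer

# Hypothetical toy corpus; "the" occurs in every document, "graphics" in only one
docs = [
    "graphics card renders the image",
    "doctor treats the patient",
    "the image shows the patient",
]

vectorizer = TfidfVectorizer()
X = vectorizer.fit_transform(docs)  # sparse matrix: one row per document

vocab = vectorizer.vocabulary_      # maps each word to its column index
row = X[0].toarray().ravel()        # TF-IDF weights of the first document

# Both words occur once in the first document, but the rarer word
# receives the larger weight because its inverse document frequency is higher
print("graphics:", row[vocab["graphics"]])
print("the:     ", row[vocab["the"]])
```

By default each row is also scaled to unit length, so the weights are comparable across documents of different sizes.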
Decision Tree Classifier
A Decision Tree classifier is initialized and trained on the processed training data. Decision Trees are a non-linear predictive modeling tool that can be used for both classification and regression tasks.
Python3
# Initialize and train a Decision Tree classifier
clf = DecisionTreeClassifier(random_state=42)
clf.fit(X_train, y_train)
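With default settings, DecisionTreeClassifier keeps splitting until the leaves are pure, which can yield a deep tree that memorizes the training set. The sketch below, on a hypothetical six-document toy corpus, compares an unconstrained tree with one capped by max_depth. The value 1 is an arbitrary choice for illustration, not a recommended setting.

```python
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.tree import DecisionTreeClassifier

# Hypothetical toy corpus: label 0 = graphics, label 1 = medicine
docs = [
    "gpu renders the image",
    "render the polygon mesh",
    "shader draws pixels",
    "doctor treats the patient",
    "the patient takes medicine",
    "nurse checks symptoms",
]
labels = [0, 0, 0, 1, 1, 1]

X = TfidfVectorizer().fit_transform(docs)

# Unconstrained tree: grows until every leaf is pure
full = DecisionTreeClassifier(random_state=42).fit(X, labels)

# Constrained tree: at most one level of splits
shallow = DecisionTreeClassifier(max_depth=1, random_state=42).fit(X, labels)

print("full:   ", full.get_depth(), "levels,", full.get_n_leaves(), "leaves")
print("shallow:", shallow.get_depth(), "levels,", shallow.get_n_leaves(), "leaves")
```

On the real dataset, a suitable depth would normally be chosen by validation rather than fixed by hand.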
Model Evaluation
The trained model is used to make predictions on the test set, and the model's performance is evaluated using accuracy and a detailed classification report, which includes precision, recall, f1-score, and support for each class.
Python3
# Make predictions
y_pred = clf.predict(X_test)
# Evaluate the model
print("Accuracy:", accuracy_score(y_test, y_pred))
print("\nClassification Report:\n", classification_report(y_test, y_pred, target_names=newsgroups_test.target_names))
Output:
Accuracy: 0.6324900133155792
Classification Report:
                        precision    recall  f1-score   support
           alt.atheism       0.53      0.39      0.45       319
         comp.graphics       0.61      0.82      0.70       389
               sci.med       0.66      0.58      0.62       396
soc.religion.christian       0.70      0.69      0.70       398
              accuracy                           0.63      1502
             macro avg       0.62      0.62      0.62      1502
          weighted avg       0.63      0.63      0.62      1502
The output shows how a Decision Tree classifier performs on the four selected 20 Newsgroups categories. An accuracy of approximately 63.25% means the model assigned the correct category to about 950 of the 1,502 posts in the test set. The precision, recall, and f1-score for each category show how well the model performs for individual classes. Precision is the fraction of posts assigned to a class that actually belong to it, recall is the fraction of posts in a class that the model correctly identifies, and the f1-score is the harmonic mean of precision and recall. Performance varies with the subject matter: 'comp.graphics' and 'soc.religion.christian' score highest (f1-score 0.70 each), while 'alt.atheism' scores lowest (f1-score 0.45), with a recall of only 0.39.
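As a quick check of how these metrics relate, the sketch below recomputes each f1-score as the harmonic mean of the precision and recall values in the report above. The helper name f1_from is introduced here for illustration only. Because the inputs are already rounded to two decimals, the results can differ from the report in the last digit.

```python
def f1_from(precision: float, recall: float) -> float:
    """Harmonic mean of precision and recall (0.0 if both are zero)."""
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)

# Rounded (precision, recall) pairs copied from the Decision Tree report
report = {
    "alt.atheism": (0.53, 0.39),
    "comp.graphics": (0.61, 0.82),
    "sci.med": (0.66, 0.58),
    "soc.religion.christian": (0.70, 0.69),
}

f1_by_class = {name: f1_from(p, r) for name, (p, r) in report.items()}
for name, f1 in f1_by_class.items():
    print(f"{name}: {f1:.3f}")

# The macro average is the unweighted mean of the per-class scores
macro_f1 = sum(f1_by_class.values()) / len(f1_by_class)
print(f"macro avg: {macro_f1:.3f}")
```

The harmonic mean is pulled toward the smaller of the two values, which is why the low recall of 'alt.atheism' drags its f1-score below its precision.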
Comparison with Other Text Classification Techniques
We will compare decision trees with other popular text classification algorithms such as Random Forest and Support Vector Machines.
Text Classification using Random Forest
Python3
from sklearn.ensemble import RandomForestClassifier
# Initialize and train a Random Forest classifier
clf = RandomForestClassifier(n_estimators=100, random_state=42)
clf.fit(X_train, y_train)
# Make predictions
y_pred = clf.predict(X_test)
# Evaluate the model
print("\nClassification Report:\n", classification_report(y_test, y_pred, target_names=newsgroups_test.target_names))
Output:
Classification Report:
                        precision    recall  f1-score   support
           alt.atheism       0.70      0.48      0.57       319
         comp.graphics       0.77      0.93      0.84       389
               sci.med       0.80      0.75      0.77       396
soc.religion.christian       0.74      0.82      0.78       398
              accuracy                           0.76      1502
             macro avg       0.75      0.74      0.74      1502
          weighted avg       0.75      0.76      0.75      1502
Text Classification using SVM
Python3
from sklearn.svm import SVC
# Initialize and train an SVM classifier
clf = SVC(kernel='linear', random_state=42)
clf.fit(X_train, y_train)
# Make predictions
y_pred = clf.predict(X_test)
# Evaluate the model
print("\nClassification Report:\n", classification_report(y_test, y_pred, target_names=newsgroups_test.target_names))
Output:
Classification Report:
                        precision    recall  f1-score   support
           alt.atheism       0.75      0.63      0.68       319
         comp.graphics       0.91      0.90      0.90       389
               sci.med       0.80      0.90      0.85       396
soc.religion.christian       0.80      0.82      0.81       398
              accuracy                           0.82      1502
             macro avg       0.82      0.81      0.81      1502
          weighted avg       0.82      0.82      0.82      1502
Observations
- SVM outperforms both Random Forest and Decision Tree classifiers, with the highest accuracy (0.82) and the highest f1-score in every category.
- Random Forest comes second with an accuracy of 0.76, about six percentage points behind SVM.
- The single Decision Tree shows the lowest performance of the three (accuracy 0.63), indicating the importance of choosing an appropriate algorithm for text classification tasks.
- All three classifiers find 'alt.atheism' the hardest category, with the lowest f1-score in each report.
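The three experiments above can also be run through one helper so that every model sees identical features. The following is a minimal sketch: the function name compare_classifiers and the toy documents are assumptions made for illustration, while the model settings mirror the ones used in this article.

```python
from sklearn.ensemble import RandomForestClassifier
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics import accuracy_score
from sklearn.pipeline import make_pipeline
from sklearn.svm import SVC
from sklearn.tree import DecisionTreeClassifier


def compare_classifiers(train_docs, train_labels, test_docs, test_labels):
    """Fit each model on the same TF-IDF features and return test accuracy per model."""
    models = {
        "Decision Tree": DecisionTreeClassifier(random_state=42),
        "Random Forest": RandomForestClassifier(n_estimators=100, random_state=42),
        "SVM (linear)": SVC(kernel="linear", random_state=42),
    }
    scores = {}
    for name, model in models.items():
        # The pipeline fits the vectorizer on the training documents only
        pipeline = make_pipeline(TfidfVectorizer(stop_words="english"), model)
        pipeline.fit(train_docs, train_labels)
        scores[name] = accuracy_score(test_labels, pipeline.predict(test_docs))
    return scores


# Hypothetical toy data, only to show the call; label 0 = graphics, 1 = medicine
train_docs = [
    "the gpu renders the image",
    "shader draws pixels on screen",
    "the doctor treats the patient",
    "nurse gives the patient medicine",
]
train_labels = [0, 0, 1, 1]
test_docs = ["gpu draws the image", "doctor gives medicine"]
test_labels = [0, 1]

scores = compare_classifiers(train_docs, train_labels, test_docs, test_labels)
for name, acc in scores.items():
    print(f"{name}: {acc:.2f}")
```

To repeat the comparison on the newsgroup data, the same function can be called with newsgroups_train.data, newsgroups_train.target, newsgroups_test.data, and newsgroups_test.target.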