Emotion Classification with DistilBERT
Archive: emotions.zip
inflating: /content/emotions/text.csv
1 Importing
[ ]: import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import matplotlib.pyplot as plt
from matplotlib import cm
import seaborn as sns
import warnings
warnings.filterwarnings('ignore')
[ ]: df = pd.read_csv("/content/emotions/text.csv")
df.head()
[ ]:                                                text  label
0       i just feel really helpless and heavy hearted      4
1     ive enjoyed being able to slouch about relax a…      0
2     i gave up my internship with the dmrg and am f…      4
3                           i dont know i feel so lost      0
4     i am a kindergarten teacher and i am thoroughl…      4
[ ]: emotion_map = {
0: "sadness",
1: "joy",
2: "love",
3: "anger",
4: "fear",
5: "surprise"
}
df["label"] = df["label"].map(emotion_map)
[ ]: df.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 416809 entries, 0 to 416808
Data columns (total 2 columns):
# Column Non-Null Count Dtype
--- ------ -------------- -----
0 text 416809 non-null object
1 label 416809 non-null object
dtypes: object(2)
memory usage: 6.4+ MB
2 Preprocessing
[ ]: # Set the size of the plot
plt.figure(figsize=(10, 6))
[ ]: # Identify duplicate rows
duplicate_rows = df.duplicated()
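The rest of this cell and the cell that created the text_length column are cut off in this export. A plausible reconstruction, assuming text_length is a character count (the column name comes from the describe() output that follows, and the drop in row count from 416809 to 416123 matches the duplicate removal):
print(f"Number of duplicate rows: {duplicate_rows.sum()}")
df = df.drop_duplicates()

# Character length of each text (assumed definition)
df['text_length'] = df['text'].str.len()
df['text_length'].describe()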
count    416123.000000
mean         97.102662
std          56.176302
min           2.000000
25%          54.000000
50%          86.000000
75%         128.000000
max         830.000000
Name: text_length, dtype: float64
[ ]: # Calculate Q1 and Q3
Q1 = df['text_length'].quantile(0.25)
Q3 = df['text_length'].quantile(0.75)
IQR = Q3 - Q1

# Outlier boundaries (the original definitions are cut off in this export;
# the standard 1.5 * IQR rule is assumed)
lower_bound = Q1 - 1.5 * IQR
upper_bound = Q3 + 1.5 * IQR

# Identify outliers
outliers_length = df[(df['text_length'] < lower_bound) | (df['text_length'] > upper_bound)]
plt.figure(figsize=(10,6))
sns.histplot(df['text_length'], bins=50, kde=True, color='blue')
plt.axvline(lower_bound, color='red', linestyle='--', label='Lower Bound')
plt.axvline(upper_bound, color='red', linestyle='--', label='Upper Bound')
plt.title('Distribution of Text Lengths with Outlier Boundaries')
plt.xlabel('Number of Characters')
plt.ylabel('Frequency')
plt.legend()
plt.show()
"""
Analyze and identify outliers based on a numerical feature.
Parameters:
5
- df (pd.DataFrame): The DataFrame containing the data.
- feature (str): The numerical feature to analyze.
- start (int, optional): The starting value of the feature to consider.␣
↪Defaults to the minimum value.
- mode (int, optional): If not zero, saves the subset to CSV. Defaults to 0.
- save_filename (str, optional): The filename to save the subset if mode !=␣
↪0.
Returns:
- None
"""
# Ensure the feature exists in the DataFrame
if feature not in df.columns:
print(f"Feature '{feature}' not found in the DataFrame.")
return
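The sentence_length column used in the next plot is never constructed in the visible cells. Judging from the head(20) output further down (e.g. 8 for "i just feel really helpless and heavy hearted"), it appears to be a word count; a minimal sketch, assuming that definition:
# Assumed definition: number of whitespace-separated tokens per text
df['sentence_length'] = df['text'].str.split().str.len()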
plt.figure(figsize=(14, 8))
sns.boxplot(x='label', y='sentence_length', data=df, palette="coolwarm")
plt.title('Correlation between Labels and Sentence Length', fontsize=16, fontweight='bold')
plt.xlabel("Labels", fontsize=14)
plt.ylabel("Sentence Length", fontsize=14)
plt.grid(True, which='both', linestyle='--', linewidth=0.5)
plt.show()
def sentence_length_outlier(start, end, save_csv=False):
    """
    Parameters:
    - start (int): Starting sentence length (inclusive).
    - end (int): Ending sentence length (exclusive).
    - save_csv (bool): If True, saves each group to a separate CSV file.

    Returns:
    - outliers (dict): A dictionary with sentence lengths as keys and corresponding DataFrames as values.
    """
    outliers = {}
    for length in range(start, end):
        filtered_df = df[df['sentence_length'] == length]
        count = len(filtered_df)
        if count > 0:
            outliers[length] = filtered_df
            if save_csv:
                filename = f"df_sentence_length_{length}.csv"
                filtered_df.to_csv(filename, index=False)
                print(f"Saved {count} records to {filename}")
    return outliers
# Example usage:
outlier_data = sentence_length_outlier(start=1, end=5, save_csv=True)
[ ]: df.shape
[ ]: (415065, 4)
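The cell that produced the following listing of the longest texts is not visible in this export; presumably something along the lines of:
# Hypothetical -- the 20 longest texts by word count
df.nlargest(20, 'sentence_length')[['text', 'sentence_length']]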
text sentence_length
347001 a few days back i was waiting for the bus at t… 178
290349 two years back someone invited me to be the tu… 110
97687 i have been thinking of changing my major for … 101
38584 when i got into a bus i found that my wallet h… 100
22750 my living and working conditions at home were … 100
332276 i had a dream i had a very close friend who ha… 94
249491 i worked with several classmates on a project … 80
158527 i was camping in an old broken hut which had n… 79
56688 last semester when i dated a girl whom ive kno… 78
162121 i was a prefect at secondary school on the spo… 77
174240 when i was in lower six class during the summe… 76
387931 a boy phoned me at night and wanted to talk to… 75
212944 a friend female and i were on holiday on great… 74
35550 i was studying in class at night i was in form… 74
114318 i suddenly found that those whom i considerere… 74
263129 my friend often played a joke on me and someti… 73
337213 a few years ago my mother suffered from cancce… 72
343770 i can get a feel cuz ya make me so horny all i… 71
37247 after attending a song contest proposed by a b… 70
68607 a new gas connection was to be installed and t… 70
[ ]: df.shape
[ ]: (415059, 4)
import re

# The compiled regex patterns used below were cut off in this export; these are
# representative definitions (assumed, not the author's exact patterns).
HTML_RE = re.compile(r'<.*?>')
URL_RE = re.compile(r'https?://\S+|www\.\S+')
EMOJI_RE = re.compile(
    "["
    "\U0001F300-\U0001F5FF"   # symbols & pictographs
    "\U0001F600-\U0001F64F"   # emoticons
    "\U0001F680-\U0001F6FF"   # transport & map symbols
    "\U0001F1E0-\U0001F1FF"   # flags
    "]+",
    flags=re.UNICODE,
)

def processing_data(text):
    """
    Preprocesses the input text by performing the following operations:
    1. Removes HTML tags.
    2. Removes URLs.
    3. Removes emojis and non-standard Unicode characters.
    4. Removes numeric values.
    5. Removes non-alphanumeric characters (excluding spaces).
    6. Converts text to lowercase.
    7. Removes extra whitespaces.
    8. Strips leading and trailing whitespaces.

    Parameters:
    - text (str): The input text to preprocess.

    Returns:
    - str: The cleaned and preprocessed text.
    """
    if pd.isnull(text) or not isinstance(text, str):
        return ""
    # Remove HTML tags
    text = HTML_RE.sub('', text)
    # Remove URLs
    text = URL_RE.sub('', text)
    # Remove emojis
    text = EMOJI_RE.sub('', text)
    # Remove numeric values and non-alphanumeric characters
    # (steps 4, 5 and 7 of the docstring are cut off in this export;
    # standard regexes are assumed here)
    text = re.sub(r'\d+', '', text)
    text = re.sub(r'[^a-zA-Z0-9\s]', '', text)
    # Convert to lowercase
    text = text.lower()
    # Collapse extra whitespace
    text = re.sub(r'\s+', ' ', text)
    # Strip leading and trailing whitespaces
    text = text.strip()
    return text
[ ]: df['text'] = df['text'].apply(processing_data)
[ ]: df.head(20)
sentence_length
0 8
1 45
2 12
3 7
4 42
5 7
6 25
7 25
8 15
9 9
10 34
11 18
12 14
13 37
14 22
15 35
16 5
17 21
18 10
19 25
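The cell that built label_counts and label_distribution (and the colors used by the pie chart below) is not visible in this export; a plausible reconstruction that matches the printed table:
label_counts = df['label'].value_counts()
label_distribution = pd.DataFrame({
    'Count': label_counts,
    'Percentage': (label_counts / len(df) * 100).round(2),
})
colors = sns.color_palette('pastel')  # palette assumed; any list of six colours works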
print(label_distribution)
Count Percentage
label
joy 140474 33.84
sadness 120692 29.08
anger 57008 13.73
fear 47510 11.45
love 34441 8.30
surprise 14934 3.60
plt.figure(figsize=(8, 8))
plt.pie(label_counts, labels=label_counts.index, autopct='%1.1f%%', startangle=140,
        colors=colors, explode=(0.05, 0.05, 0.05, 0.05, 0.05, 0.05))
[ ]: plt.figure(figsize=(10, 6))
sns.countplot(x='label', data=df, palette="Set2", order=label_counts.index)
plt.title('Emotion Distribution in Dataset', fontsize=16, fontweight='bold')
plt.xlabel('Emotion', fontsize=14)
plt.ylabel('Number of Occurrences', fontsize=14)
plt.show()
2.2 Preprocessing: Lemmatization and Stopword Filtering
[ ]: import time
import re
import spacy
from tqdm import tqdm
import pandas as pd

# Load SpaCy model with parser and NER disabled for efficiency
nlp = spacy.load("en_core_web_sm", disable=['parser', 'ner'])

def process_text(doc):
    """
    Processes a SpaCy Doc object by lemmatizing tokens, removing stopwords,
    and filtering non-alphabetic tokens.

    Parameters:
    - doc (spacy.tokens.Doc): The SpaCy Doc object to process.

    Returns:
    - str: The processed text.
    """
    return ' '.join(
        token.lemma_ for token in doc
        if not token.is_stop and token.is_alpha
    )
# The definition line and loop header of this function were cut off in this
# export; a plausible reconstruction (the function name and nlp.pipe batching
# are assumptions).
def process_texts(texts, batch_size=1000):
    """
    Parameters:
    - texts (list of str): The list of text strings to process.
    - batch_size (int): The number of texts to process in each batch.

    Returns:
    - list of str: The list of processed text strings.
    """
    processed_texts = []
    for doc in tqdm(nlp.pipe(texts, batch_size=batch_size), total=len(texts)):
        processed_text = process_text(doc)
        processed_texts.append(processed_text)
    return processed_texts
if __name__ == "__main__":
    # Load your dataset here
    # Example:
    # df = pd.read_csv("your_data.csv")
    # Ensure that the DataFrame 'df' exists and has a 'text' column
    try:
        df
    except NameError:
        print("Error: DataFrame 'df' is not defined. Please load your data before processing.")
        exit()

    # Count missing texts (the original definition is cut off in this export;
    # an isna() count is assumed)
    missing_texts = df['text'].isna().sum()
    if missing_texts > 0:
        print(f"Warning: Found {missing_texts} missing texts. Filling them with empty strings.")
        df['text'] = df['text'].fillna("")
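The cell is cut off before the batch processing is actually applied; the run presumably looked something like this (names taken from the code above, timing via the already-imported time module assumed):
    start_time = time.time()
    df['text'] = process_texts(df['text'].tolist(), batch_size=1000)
    print(f"Processed {len(df)} texts in {time.time() - start_time:.1f} seconds.")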
[ ]: !pip install datasets
Collecting datasets
  Downloading datasets-3.2.0-py3-none-any.whl.metadata (20 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
Collecting multiprocess<0.70.17 (from datasets)
Collecting fsspec<=2024.9.0,>=2023.1.0 (from fsspec[http]<=2024.9.0,>=2023.1.0->datasets)
Collecting xxhash (from datasets)
…
Installing collected packages: xxhash, fsspec, dill, multiprocess, datasets
  Attempting uninstall: fsspec
    Found existing installation: fsspec 2024.10.0
    Uninstalling fsspec-2024.10.0:
      Successfully uninstalled fsspec-2024.10.0
ERROR: pip's dependency resolver does not currently take into account all
the packages that are installed. This behaviour is the source of the following
dependency conflicts.
gcsfs 2024.10.0 requires fsspec==2024.10.0, but you have fsspec 2024.9.0 which
is incompatible.
Successfully installed datasets-3.2.0 dill-0.3.8 fsspec-2024.9.0
multiprocess-0.70.16 xxhash-3.5.0
2.3 Label Encoding
[ ]: # The imports and encoder setup are cut off in this export; a standard
# scikit-learn LabelEncoder is assumed here.
from sklearn.preprocessing import LabelEncoder

label_encoder = LabelEncoder()
df['label_encoded'] = label_encoder.fit_transform(df['label'])

# Mapping from emotion name to encoded id (the name 'label_mapping' is used in later cells)
label_mapping = {label: idx for idx, label in enumerate(label_encoder.classes_)}
2.4 Tokenization
[ ]: from transformers import AutoTokenizer

model_name = "distilbert-base-uncased"

# Initialize the tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)

def tokenize_function(examples):
    return tokenizer(
        examples['text'],
        padding="max_length",
        truncation=True,
        max_length=128  # Adjust based on your data
    )
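The conversion of the DataFrame into Hugging Face datasets and the train/test split are not shown in this export. A minimal sketch under those assumptions (the names train_dataset and test_dataset come from later cells; the 80/20 split is inferred from the train and eval sample counts in the training and evaluation output):
from datasets import Dataset

# Build a dataset from the cleaned DataFrame; the Trainer expects a 'labels' column
dataset = Dataset.from_pandas(df[['text', 'label_encoded']].rename(columns={'label_encoded': 'labels'}))
dataset = dataset.train_test_split(test_size=0.2, seed=42)

# Tokenize both splits in batches
train_dataset = dataset['train'].map(tokenize_function, batched=True)
test_dataset = dataset['test'].map(tokenize_function, batched=True)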
[ ]: import torch

# The matching `if` branch of this check is cut off in this export; a standard
# CUDA/CPU device selection is assumed.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cpu":
    print("CUDA is not available. Training on CPU...")
[ ]: from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from transformers import TrainingArguments

num_labels = len(label_mapping)

def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='weighted')
    # Return value reconstructed to match the metric names in the evaluation output below
    accuracy = accuracy_score(labels, preds)
    return {'accuracy': accuracy, 'f1': f1, 'precision': precision, 'recall': recall}

training_args = TrainingArguments(
    output_dir='./results',              # Output directory
    evaluation_strategy="epoch",         # Evaluation strategy to adopt during training
    # ... remaining arguments are cut off in this export ...
)
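The cells that load DistilBERT and construct the Trainer are not visible in this export; a minimal sketch under the assumptions above (dataset names and the remaining training arguments assumed). Loading distilbert-base-uncased with a fresh classification head is what produces the initialization warning shown below:
from transformers import AutoModelForSequenceClassification, Trainer

# Load DistilBERT with a new classification head for the six emotions
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=num_labels)
model.to(device)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    compute_metrics=compute_metrics,
)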
Some weights of DistilBertForSequenceClassification were not initialized from
the model checkpoint at distilbert-base-uncased and are newly initialized:
['classifier.bias', 'classifier.weight', 'pre_classifier.bias',
'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it
for predictions and inference.
# Start training
trainer.train()
[ ]: TrainOutput(global_step=62259, training_loss=0.1351904600557597,
metrics={'train_runtime': 11252.9092, 'train_samples_per_second': 88.523,
'train_steps_per_second': 5.533, 'total_flos': 3.299140499276851e+16,
'train_loss': 0.1351904600557597, 'epoch': 3.0})
2.6 Evaluation
[ ]: # Evaluate the model on the test set
results = trainer.evaluate()
print("Evaluation Results:", results)
Evaluation Results: {'eval_loss': 0.13016673922538757, 'eval_accuracy':
0.9320821086108032, 'eval_f1': 0.9328861620931987, 'eval_precision':
0.9387182163079509, 'eval_recall': 0.9320821086108032, 'eval_runtime': 271.8072,
'eval_samples_per_second': 305.408, 'eval_steps_per_second': 19.091, 'epoch':
3.0}
[ ]: from sklearn.metrics import classification_report, confusion_matrix

# Get predictions
predictions = trainer.predict(test_dataset)
pred_labels = np.argmax(predictions.predictions, axis=1)
true_labels = predictions.label_ids

# Classification Report
print("=== Classification Report ===\n")
print(classification_report(true_labels, pred_labels, target_names=label_mapping.keys()))
=== Classification Report ===
cm = confusion_matrix(true_labels, pred_labels)
plt.figure(figsize=(10,8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
xticklabels=label_mapping.keys(),
yticklabels=label_mapping.keys())
plt.ylabel('Actual')
plt.xlabel('Predicted')
plt.title('Confusion Matrix')
plt.show()
2.7 Prediction and Inference
[ ]: def predict_emotion(text):
    # Tokenize the input text
    inputs = tokenizer(
        text,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=128
    )
    model.to(device)
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Get model predictions
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits
    predicted_class = torch.argmax(logits, dim=1).cpu().numpy()[0]

    # The tail of the original function is cut off in this export; mapping the
    # class id back to an emotion name via the LabelEncoder is assumed.
    return label_encoder.inverse_transform([predicted_class])[0]
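A quick usage example of the helper (illustrative input, not from the original run):
print(predict_emotion("i cant shake this feeling of dread"))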