In the ever-evolving landscape of artificial intelligence, the quest for efficient and versatile models has led researchers to explore innovative training paradigms. Among these, self-supervised learning has emerged as a frontrunner, offering a promising solution to the perennial challenge of acquiring labelled data for diverse tasks. One remarkable stride in this direction comes with Wav2Vec2, a groundbreaking model designed for self-supervised speech representation learning.
What is the Wav2Vec2 Model?
Wav2Vec2 stands as a testament to the transformative potential of self-supervised training, bringing to speech the pre-training paradigm popularized in Natural Language Processing (NLP). Its architecture is tailored to harness vast amounts of unlabeled speech data, distilling intricate patterns and nuances to create a rich and generalized understanding of spoken language. This self-supervised pre-training phase sets the stage for subsequent fine-tuning, where the model refines its knowledge on specific downstream tasks using limited labeled datasets.
- Pre-training: The model first learns broad, general representations from large amounts of unlabeled data. This step is believed to improve performance on downstream tasks where labeled data is scarce.
- Fine-tuning: The pre-trained model is then fine-tuned on a small labeled dataset. Fine-tuning can be done for a variety of downstream tasks.
In other words, Wav2Vec2 is a general model trained to learn a discretized representation of speech audio. It is trained on a large amount of unlabeled data to encode the raw audio into a discretized vector space, and these discrete encodings can be thought of as speech units.
Why discretized?
Since the speech signal is continuous, the focus is on discretizing it so that transformer architectures developed for text processing, such as BERT, which take discrete inputs, can be reused for further processing. BERT takes token embeddings as input and learns to predict masked tokens; those embeddings are discrete because they correspond to words in a vocabulary. This step is called pre-training.
After the pre-training stage, the model can be fine-tuned for various downstream tasks using a very small amount of labeled data. In the paper, the model was fine-tuned on small labeled datasets with CTC loss for the ASR (automatic speech recognition) task.
Let us understand the architecture and training process of the Wav2Vec2 model.
Architecture of Wav2Vec2 Model
Wav2Vec2 Model
Let's take a closer look at each component.
Feature encoder
The input to the feature encoder is a sound waveform sampled at 16 kHz. The feature encoder has seven blocks, and each block's temporal convolution has 512 channels, with strides (5,2,2,2,2,2,2) and kernel widths (10,3,3,3,3,2,2). This yields an encoder output frequency of 49 Hz, i.e. a stride of around 20 ms between output samples (Time Interval = \frac{1}{49} \times 1000\,ms \approx 20.41\,ms), and a receptive field of 400 input samples, or 25 ms of audio. The detailed calculation is shown below.
| Layer | Channels × Input Dimension | Kernel/Filter Width | Stride | Channels × Output Dimension |
|---|---|---|---|---|
| 1 | 1 × 16000 | 10 | 5 | 512 × 3199 |
| 2 | 512 × 3199 | 3 | 2 | 512 × 1599 |
| 3 | 512 × 1599 | 3 | 2 | 512 × 799 |
| 4 | 512 × 799 | 3 | 2 | 512 × 399 |
| 5 | 512 × 399 | 3 | 2 | 512 × 199 |
| 6 | 512 × 199 | 2 | 2 | 512 × 99 |
| 7 | 512 × 99 | 2 | 2 | 512 × 49 |

Total Stride = 5 x 2 x 2 x 2 x 2 x 2 x 2 = 320
Time per Sample = \frac{1}{16000}\,s = 0.0625 milliseconds (ms) per sample
Duration per Output Sample = 320 x 0.0625 ms = 20 ms
Feature Encoder of Wav2Vec2
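To sanity-check these numbers, here is a minimal PyTorch sketch (not the actual wav2vec 2.0 implementation) that stacks seven 1-D convolutions with the kernel widths and strides listed above and measures the output length for one second of 16 kHz audio. The real feature encoder also applies normalization and GELU activations, which are omitted here for brevity.

Python

import torch
import torch.nn as nn

# Kernel widths and strides of the seven temporal convolutions (from the paper)
kernels = (10, 3, 3, 3, 3, 2, 2)
strides = (5, 2, 2, 2, 2, 2, 2)

# Build a toy feature encoder: 1-D convolutions with 512 output channels each
layers = []
in_channels = 1
for k, s in zip(kernels, strides):
    layers.append(nn.Conv1d(in_channels, 512, kernel_size=k, stride=s))
    in_channels = 512
encoder = nn.Sequential(*layers)

# One second of 16 kHz audio: shape (batch, channels, samples)
waveform = torch.randn(1, 1, 16000)
with torch.no_grad():
    latent = encoder(waveform)
print(latent.shape)  # torch.Size([1, 512, 49]) -> 49 frames of 512-dim features per second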
Contextualized representations with Transformers
The core of wav2vec 2.0 is its Transformer encoder, which takes the latent feature vectors produced by the feature encoder as input and processes them through a stack of transformer blocks. The input sequence first passes through a feature projection layer, which increases the dimension from 512 (the feature encoder output) to 768 for the BASE variant or 1,024 for the LARGE variant, matching the model dimension of the Transformer encoder.
BASE contains 12 transformer blocks, model dimension 768, inner dimension (FFN) 3,072 and 8 attention heads. The LARGE model is made up of 24 transformer blocks with model dimensions of 1,024, inner dimensions of 4,096 and 16 attention heads.
One difference with respect to the BERT architecture is how positional information is incorporated. Instead of fixed positional embeddings, which encode absolute positions, wav2vec 2.0 uses a grouped convolutional layer that lets the model learn relative positional embeddings by itself.
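If you work with the Hugging Face implementation, these architectural details can be read directly off the model configuration. The sketch below assumes the facebook/wav2vec2-base-960h checkpoint is downloadable in your environment; the attribute names are those of the transformers Wav2Vec2Config.

Python

from transformers import Wav2Vec2Config

# Load the configuration of the pre-trained BASE checkpoint
config = Wav2Vec2Config.from_pretrained("facebook/wav2vec2-base-960h")

print(config.conv_dim)                         # channels of the feature encoder convolutions
print(config.conv_kernel, config.conv_stride)  # kernel widths and strides per block
print(config.hidden_size)                      # transformer model dimension (768 for BASE)
print(config.num_hidden_layers)                # number of transformer blocks (12 for BASE)
print(config.intermediate_size)                # inner FFN dimension (3,072 for BASE)
print(config.num_conv_pos_embeddings)          # kernel size of the convolutional positional embedding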
The output of the Transformer is a context vector. The Transformer builds context representations over the continuous latent speech representations, and these are compared against the output of the quantization module. The output of the quantization module (the quantized vector) represents the discrete targets that the Transformer encoder has to learn. Both the quantized vectors and the context vectors are learned jointly using a contrastive loss; more details about this are in the training section.
Quantization module
The quantization module of Wav2Vec2 is adopted from the vq-wav2vec architecture. The diagram below shows the overall quantization process.
Wav2Vec2 Quantization Process
The output of the feature encoder (rather than the context transformer) is discretized in parallel using a product-quantization-based quantization module. Quantization is the process of mapping values from a continuous space to a discrete set. A codebook in product quantization is a set of representative points that help us discretize; these representative vectors can be thought of as speech units. Here are the steps of quantization:
- For one second of audio we get a 512 × 49 output from the feature encoder, i.e. 49 latent feature vectors, each of size 512.
- A linear layer projects each 512-dimensional feature to 640 logits. These 640 logits are divided into two groups (G = 2) of 320 (V) logits each, and each group of 320 logits indexes a codebook of 320 discrete vectors. The codebooks are randomly initialized, and their entries are learned during training through the contrastive loss. Since each feature vector is mapped into two groups, we get a total of 320 × 320 = 102,400 possible speech units.
- Using Gumbel-Softmax, a one-hot vector is produced for each group G. We therefore get two one-hot vectors, each selecting one of the 320 discrete vectors in its codebook.
- Gumbel-Softmax is a popular technique for sampling from a discrete space. The method introduces stochasticity (via Gumbel noise) into the discrete decision-making process and uses a differentiable approximation (softmax) to the argmax operation, which makes it possible to backpropagate through random samples of discrete variables. The Gumbel-Max trick is very similar to the reparameterization trick, in that it combines a deterministic part (the model logits) with a stochastic part (the Gumbel noise). During the forward pass or inference, the index with the largest value is picked and the corresponding codebook vector is used; during the backward pass, the soft logits are used for backpropagation. A small sketch of this sampling step follows this list.
- Each vector in a codebook is of size d/2. We obtain two codebook vectors (e1 and e2) for each latent feature vector z, and e1 and e2 are concatenated to get a d-dimensional vector. This is then passed through a linear transformation R^d → R^f to obtain the quantized vector q ∈ R^f, whose dimension matches the Transformer output.
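Below is a minimal, illustrative sketch of this product-quantization step built around PyTorch's torch.nn.functional.gumbel_softmax. It is not the fairseq implementation; the dimensions and variable names (G, V, d, f, codebooks, proj) are assumptions for the example.

Python

import torch
import torch.nn as nn
import torch.nn.functional as F

G, V = 2, 320        # number of groups and codebook entries per group
d, f = 256, 768      # concatenated codeword dimension and transformer dimension (illustrative)

# Codebooks: G groups, each with V entries of size d / G
codebooks = nn.Parameter(torch.randn(G, V, d // G))
to_logits = nn.Linear(512, G * V)   # projects a latent feature to G * V logits
proj = nn.Linear(d, f)              # final R^d -> R^f projection

z = torch.randn(49, 512)            # 49 latent features for one second of audio

logits = to_logits(z).view(-1, G, V)                  # (49, G, V)
# Hard (one-hot) samples in the forward pass, soft gradients in the backward pass
one_hot = F.gumbel_softmax(logits, tau=2.0, hard=True, dim=-1)

# Look up one codeword per group and concatenate the groups
codewords = torch.einsum("tgv,gvc->tgc", one_hot, codebooks)   # (49, G, d/G)
q = proj(codewords.reshape(-1, d))                             # quantized vectors, shape (49, f)
print(q.shape)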
Training Process
First, let us understand what a contrastive score and a contrastive loss are, since they are central to the training procedure of the wav2vec model.
Contrastive Score typically involves computing a similarity metric between pairs of samples. Commonly used similarity metrics include cosine similarity or dot product. The idea is to compare the representations of two instances in the embedding space. For positive pairs (examples that should be similar), the contrastive score should be high, indicating high similarity. For negative pairs (examples that should be dissimilar), the contrastive score should be low, indicating low similarity.
Contrastive Loss turns these scores into a training objective: it encourages the model to bring positive pairs closer together in the embedding space while pushing negative pairs apart.
L(i,j) = -\log \frac{\exp\left(sim(z_i, z_j)/\tau\right)}{\sum_{k} \exp\left(sim(z_i, z_k)/\tau\right)}
where
- L(i,j) is the contrastive loss for samples i and j
- sim(z_i, z_j) is the similarity score between samples i and j
- \tau is a temperature hyperparameter that scales the similarities
- the sum in the denominator is over all samples k in the batch
One important thing to note is that the representations of a positive pair are pulled together to make them more similar, while those of negative pairs are pushed apart to make them more dissimilar.
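As a concrete illustration, here is a minimal PyTorch sketch of such a contrastive loss, assuming cosine similarity, one positive and K negatives per anchor. The function and variable names are illustrative, not taken from the wav2vec 2.0 codebase.

Python

import torch
import torch.nn.functional as F

def contrastive_loss(context, positive, negatives, tau=0.1):
    """context: (B, D) transformer outputs at masked positions,
    positive: (B, D) true quantized vectors, negatives: (B, K, D) distractors."""
    # Stack the positive in front of the negatives: candidate 0 is the true target
    candidates = torch.cat([positive.unsqueeze(1), negatives], dim=1)            # (B, K+1, D)
    sims = F.cosine_similarity(context.unsqueeze(1), candidates, dim=-1) / tau   # (B, K+1)
    targets = torch.zeros(context.size(0), dtype=torch.long)                     # positive is index 0
    return F.cross_entropy(sims, targets)

# Toy usage: batch of 4 masked positions, 768-dim vectors, 100 negatives each
loss = contrastive_loss(torch.randn(4, 768), torch.randn(4, 768), torch.randn(4, 100, 768))
print(loss.item())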
In the context of Wav2Vec2
- The output from the feature encoder is passed through the quantization module to get a quantized representation from the codebook. This is the positive sample.
- The same feature encoder output is also passed through the transformer encoder, but before that a proportion of the feature frames (~50%) is masked. The objective is to learn, at each masked position, a representation of the discrete speech unit by comparing it with the true quantized latent speech representation. For each masked position, 100 negative distractors (negative samples) are uniformly sampled from other positions in the same utterance; these distractors are quantized codebook representations other than the positive vector.
- The model compares the similarities using the contrastive loss equation shown above.
- The loss is then backpropagated through the transformer as well as the quantization module, making the transformer encoder output more similar to the positive codebook sample and more dissimilar to the negative codebook samples.
Diversity Loss: to encourage equal use of all the codebook entries, a diversity loss is added during training. It works by maximizing the entropy of the averaged softmax distribution over codebook entries, preventing the model from always choosing from a small subset of the available entries.
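A rough sketch of the idea (not the exact formulation used in the paper or in fairseq): average the softmax distribution over all frames for each codebook group and penalize low entropy.

Python

import torch
import torch.nn.functional as F

def diversity_penalty(logits):
    """logits: (N, G, V) codebook logits for N frames, G groups, V entries per group.
    Returns the negative entropy of the averaged softmax; minimizing it spreads
    codebook usage across all entries."""
    probs = F.softmax(logits, dim=-1)        # (N, G, V)
    avg_probs = probs.mean(dim=0)            # (G, V), averaged over frames
    entropy = -(avg_probs * torch.log(avg_probs + 1e-7)).sum(dim=-1)  # (G,)
    return -entropy.mean()                   # maximizing entropy == minimizing its negative

penalty = diversity_penalty(torch.randn(49, 2, 320))
print(penalty.item())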
Wav2Vec2 Model Implementation
Install Libraries
Install the below libraries if not available in your environment. These are required to run the subsequent code.
!pip install datasets
!pip install transformers
!pip install torch
!pip install evaluate
!pip install transformers[torch]
Import Libraries
And then import the libraries into your notebook. Required libraries include numpy, transformers and pytorch.
Python
# Imports required
import numpy as np
from datasets import load_dataset, Audio
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
import torch
import evaluate
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union
from transformers import TrainingArguments, Trainer
Loading Dataset and Preprocessing
Load the MINDS-14 dataset and split it in an 80:20 ratio.
Python
# Load the PolyAI dataset.
dataset = load_dataset("PolyAI/minds14", name="en-US", split="train[:80]")
# Remove unnecessary columns
dataset = dataset.remove_columns(['path','english_transcription','intent_class'])
# Split the dataset into train and test
dataset = dataset.train_test_split(test_size = 0.2, shuffle=False)
Resampling data
We need to resample the data to 16 kHz because the Wav2Vec2 model was trained on 16 kHz audio while the dataset is sampled at 8 kHz. For this we will use the Audio feature of the datasets library.
Python
# Declare device variable
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Resample the dataset to 16 kHz as the Wav2Vec2 model is trained on 16 kHz audio
dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
model.to(device)
Drawing Inferences
We format an input and use the base model to infer its transcription. The model produces output in logits, and we decode it by selecting the maximum value among the logits. The use of 'torch.no_grad()' ensures that these operations do not contribute to gradient computation, which is particularly helpful when there's no need to update the model weights.
Python
# Let's process one example from the train dataset
inputs = processor(dataset['train'][3]["audio"]["array"], sampling_rate=16000, return_tensors="pt")
# getting the predictions
with torch.no_grad():
    logits = model(**inputs.to(device)).logits
predicted_ids = torch.argmax(logits, dim=-1)
transcription = processor.batch_decode(predicted_ids)
transcription
Output:
['HOW DO I FURN A JOINA COUT']
The actual text of the audio is 'how do I start a joint account'.
Fine Tuning the model
We want to prepare our data to match the format expected by the Wav2Vec2 model using the Dataset map function. For this, we create two columns: 'input_values', which holds the raw 16 kHz waveform array processed by the feature extractor, and 'labels', which holds the transcription encoded in the format expected by the tokenizer. To achieve this, we pass each example through the processor in the function defined below.
Python
# Preparing a function to process the entire dataset
# We need to create two columns: 'input_values'
# (the raw sound wave array) and 'labels' (the encoded transcription)
def prepare_dataset(batch):
    audio = batch["audio"]
    batch["input_values"] = processor(
        audio["array"], sampling_rate=audio["sampling_rate"]).input_values[0]
    with processor.as_target_processor():
        batch["labels"] = processor(batch["transcription"].upper()).input_ids
    return batch
encoded_dataset = dataset.map(prepare_dataset, num_proc=1)
Creating a specialized class for Data
We're crafting a DataCollator class specifically designed for fine-tuning Wav2Vec2. Unlike most text tasks, ASR does not come with a built-in data collator in Transformers, so we adapt the idea behind DataCollatorWithPadding to create batches of examples that match the elements found in the training or evaluation datasets.
It's worth highlighting that 'input_values' and 'labels' need different padding strategies since they can have varying lengths. In ASR tasks with potentially large input sizes, it's more efficient to dynamically pad training batches. This means each training sample only gets padded to match the length of the longest sample within its batch, rather than padding to the overall longest sample.
So, in essence, for fine-tuning Wav2Vec2, we're crafting a specialized padding data collator, and we'll define it below:
Python
@dataclass
class DataCollatorCTCWithPadding:
    processor: Wav2Vec2Processor
    padding: Union[bool, str] = "longest"

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need
        # different padding methods
        input_values = [{"input_values": feature["input_values"]} for feature in features]
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        batch = self.processor.pad(input_values, padding=self.padding, return_tensors="pt")
        with self.processor.as_target_processor():
            labels_batch = self.processor.pad(label_features, padding=self.padding, return_tensors="pt")
        # replace padding with -100 to ignore padded tokens in the loss
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
        batch["labels"] = labels
        return batch

data_collator = DataCollatorCTCWithPadding(processor=processor, padding="longest")
Evaluation Metric
For our task, we'll be using the word error rate metric. To measure this, we need to define a 'compute_metrics' function. Each logit vector has a length equal to the configured vocabulary size, which is noted as 'config.vocab_size.' Our main focus is on figuring out the model's prediction, and we do this by calculating the argmax(...) of the logits.
To make sense of the predictions, we convert the encoded labels back into their original string form. This involves a couple of steps. First, we replace instances of -100 with the 'pad_token_id.' Then, we decode the IDs while making sure that consecutive tokens are not incorrectly grouped together. This decoding process aligns with the CTC (Connectionist Temporal Classification) style, ensuring accuracy in the representation of the original string.
Python
# Load the word error rate (WER) metric once
wer_metric = evaluate.load("wer")

def compute_metrics(pred):
    pred_logits = pred.predictions
    pred_ids = np.argmax(pred_logits, axis=-1)
    # Replace -100 with the pad token id so the labels can be decoded
    pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id
    pred_str = processor.batch_decode(pred_ids)
    # Do not group tokens when decoding the reference labels
    label_str = processor.batch_decode(pred.label_ids, group_tokens=False)
    return {"wer": wer_metric.compute(predictions=pred_str, references=label_str)}
Model Training
Wav2Vec2 is a sizable model that demands a significant amount of memory, making GPU training a necessity. If your system lacks sufficient memory, there's a risk of encountering out-of-memory issues. The learning rate has been fine-tuned through heuristic methods to ensure stable fine-tuning. It's crucial to note that these parameters are highly dependent on the dataset, so experimenting with various values is essential.
To initiate the training process, pass these training arguments, along with the dataset, model, tokenizer, and data collator, to the Trainer. Once set up, call the '.train()' method to kickstart the training.
Python
del model
model = Wav2Vec2ForCTC.from_pretrained(
"facebook/wav2vec2-base-960h",
ctc_loss_reduction="mean",
pad_token_id=processor.tokenizer.pad_token_id)
model.to(device)
# defining training arguments and trainer
training_args = TrainingArguments(
output_dir="wav2vec2_finetuned",
gradient_checkpointing=True,
per_device_train_batch_size=1,
learning_rate=1e-5,
warmup_steps=2,
max_steps=2000,
fp16=True,
optim='adafactor',
group_by_length=True,
evaluation_strategy="steps",
per_device_eval_batch_size=1,
eval_steps=100,
load_best_model_at_end=True,
metric_for_best_model="wer",
)
trainer = Trainer(
model=model,
args=training_args,
train_dataset=encoded_dataset["train"],
eval_dataset=encoded_dataset["test"],
tokenizer=processor.feature_extractor,
data_collator=data_collator,
compute_metrics=compute_metrics,
)
trainer.train()
Output:
Step Training Loss Validation Loss Wer
100 No log 1.422148 0.354839
200 No log 1.584326 0.379032
300 No log 1.595137 0.346774
400 No log 1.534755 0.314516
500 1.022900 1.548012 0.322581
600 1.022900 1.525821 0.322581
Getting Predictions from the Fine-tuned Model
Python
## getting test data
i2 = processor(dataset['test'][6]["audio"]["array"], sampling_rate=16000, return_tensors="pt")
print(f"The input test audio is: {dataset['test'][6]['transcription']}")
# prediction for test data
with torch.no_grad():
    logits = model(**i2.to(device)).logits
predicted_ids = torch.argmax(logits, dim=-1)
transcription = processor.batch_decode(predicted_ids)
print(f'The output prediction is : {transcription[0]}')
Output :
The input test audio is: so you spent the money I'd like to see my new account balance
The output prediction is : SO JUS SPEND SOME MONEY I'D LIKE TO SEE MY NEW ACCOUNT BALANCE
The output is better this time.
Conclusion
Self-supervised learning, exemplified by models like Wav2Vec2, offers a robust approach for representation learning in domains with limited labeled data. Fine-tuning on specific tasks further refines the model's performance, showcasing the adaptability and effectiveness of this training methodology.