Llama Based Punctuation Restoration With Forward Pass Only Decoding
Yutong Pang1, Debjyoti Paul1, Kevin Jiang1, Xuedong Zhang1, Xin Lei1
1Meta, USA
Our first contribution is the application of LLaMA for punctuation restoration, which demonstrates superior performance compared to the established benchmark. Despite its impressive quality, LLaMA faces challenges regarding inference speed and hallucinations. To address this, our second contribution presents Forward Pass Only Decoding (FPOD), a novel decoding approach for annotation tasks. This innovative method results in a substantial 19.8x improvement in inference speed, effectively addressing a critical bottleneck and enhancing the practical utility of LLaMA for large-scale data annotation tasks without hallucinations. The combination of these contributions not only solidifies LLaMA as a powerful tool for punctuation restoration but also highlights FPOD as a crucial strategy for overcoming speed constraints.

Index Terms: speech recognition, human-computer interaction, computational paralinguistics, punctuation, LLM
1. Introduction

Automatic Speech Recognition (ASR) plays a vital role in numerous domains involving human-computer interaction [1] [2] [3]. However, the outputs of many ASR systems often lack punctuation. Punctuation restoration in the context of ASR output is a crucial component [4] [5], essential for enhancing the overall utility, user experience, and comprehensibility of transcribed speech. Restoring the punctuation makes the raw ASR output more coherent and improves its readability.

The field of punctuation restoration encompasses two distinct techniques. Cascade methods, exemplified by models like BERT [6], are commonly applied independently to Automatic Speech Recognition (ASR) outputs in spoken domains without punctuation [7]. These cascade models function as standalone systems, addressing the punctuation restoration task sequentially. On the other hand, the End-to-End (E2E) approach, represented by models such as the Recurrent Neural Network Transducer (RNNT) or Whisper [8], trained in an end-to-end fashion, incorporates built-in punctuation output. This category of techniques streamlines the punctuation restoration process. However, both approaches face challenges: the former requires independent but domain-aligned training data and evaluation effort, and the latter is compelled to use large amounts of high-quality supervised data containing punctuation paired with audio, which is a bottleneck for scaling ASR systems to new domains and languages requiring punctuation restoration.

Recognizing the significance of the punctuation restoration task and the challenges posed by previous models, our work introduces a LoRA fine-tuned LLaMA model for punctuation restoration. With LoRA fine-tuning [11], which demands significantly less supervised training data, we achieve comparable and even superior performance for punctuation restoration compared to the traditional methods introduced previously. This approach addresses both the quality effectiveness and scale-up concerns associated with punctuation restoration in diverse languages and domains.

In our exploration of LLaMA-based punctuation restoration, we present a range of strategies. Initially, we delve into the traditional approach with auto-regressive generation. Subsequently, we explore techniques to address the inherent challenge of inference speed in LLaMA. The first of these strategies involves speculative decoding, showcasing improvements in inference speed while keeping the quality of generated outputs exactly the same as the original base model. Finally, we present a new forward pass only approach, eliminating the need for auto-regressive generation entirely. This novel approach results in a substantial boost in inference speed.

Our contribution not only establishes LLaMA as a potent alternative for achieving high-quality punctuation restoration but also introduces practical enhancements to overcome the challenges associated with inference speed.

2. Proposed Method

In this section, we describe the proposed forward pass only method for restoring punctuation and compare it with other decoding methods: auto-regressive decoding and speculative decoding.

2.1. Auto Regressive Generation

Auto-regressive generation refers to a process in which a language model generates output (e.g., text) one token at a time in a sequential manner. At each step, the model predicts the next token based on the context of the preceding tokens it has generated. This process is "auto-regressive" because the model's own outputs become part of the input for predicting subsequent tokens. Inference from LLaMA auto-regressively is slow: decoding K tokens takes K sequential runs of the model.
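To make that cost concrete, the following is a minimal sketch (not the paper's implementation) of greedy auto-regressive decoding in PyTorch, assuming a causal language model that maps a token-id tensor to next-token logits; each generated token requires one additional forward pass.

```python
import torch

@torch.no_grad()
def greedy_generate(model, input_ids: torch.Tensor, max_new_tokens: int, eos_id: int) -> torch.Tensor:
    """Greedy auto-regressive decoding: one forward pass per generated token.

    Assumes `model(input_ids)` returns logits of shape [batch, seq_len, vocab].
    """
    ids = input_ids
    for _ in range(max_new_tokens):
        logits = model(ids)                                       # full forward pass (no KV cache in this sketch)
        next_id = logits[:, -1, :].argmax(dim=-1, keepdim=True)   # most likely next token
        ids = torch.cat([ids, next_id], dim=-1)                   # the model's own output joins the input
        if (next_id == eos_id).all():
            break
    return ids
```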
2.2. Speculative Decoding

We have already seen that auto-regressive generation is a very slow process. Speculative decoding was introduced to improve it [12]. It refers to the process of using an assistant model to help the decoding process so that, in most cases, full auto-regressive decoding with the large model can be avoided. Here is a brief explanation of how speculative decoding works:
• We first use the assistant model (usually a small distilled student model) to generate the output auto-regressively.
• Then we send the output to the large main model (usually a large teacher model) and only perform verification forward passes.
• If the verification is successful (the main model agrees with the assistant model), we directly use the assistant model output as the final output. Otherwise, we need to run the full auto-regressive generation with the large main model to get a "better" output.
• Since, for the cases with successfully verified results, we only run auto-regressive generation with the fast assistant model and only perform a verification forward pass with the slow main model, the decoding process is sped up substantially.

Speculative decoding can help us improve the inference speed; however, we still need to train the distilled student model, and auto-regressive generation is still required for every student model pass and for some of the base model passes (the cases that fail forward verification). The inference speed limit depends entirely on the quality and size of the student model, and the general inference speed improvement is usually less than 2x [13].
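As an illustration of the draft-and-verify scheme summarized above (a simplified sketch, not the token-level acceptance algorithm of [12]), the assistant model drafts the full continuation and the main model checks it with a single batched forward pass; `greedy_generate` refers to the auto-regressive sketch given earlier.

```python
import torch

@torch.no_grad()
def speculative_generate(main_model, assistant_model, input_ids, max_new_tokens, eos_id):
    """Draft with the small assistant model, verify with one forward pass of the main model.

    Falls back to full auto-regressive decoding with the main model if verification fails.
    """
    # 1) Draft: cheap auto-regressive generation with the assistant model.
    draft = greedy_generate(assistant_model, input_ids, max_new_tokens, eos_id)

    # 2) Verify: one forward pass of the main model over the drafted sequence.
    logits = main_model(draft)
    prompt_len = input_ids.shape[-1]
    # The main model "agrees" if its greedy prediction at each drafted position
    # matches the token the assistant actually produced there.
    predicted = logits[:, prompt_len - 1:-1, :].argmax(dim=-1)
    if torch.equal(predicted, draft[:, prompt_len:]):
        return draft

    # 3) Fallback: the slow path, full auto-regressive decoding with the main model.
    return greedy_generate(main_model, input_ids, max_new_tokens, eos_id)
```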
2.3. Forward Pass Only Decoding (FPOD)

Figure 2: Directly feeding input as response in prompt for forward pass only decoding (FPOD) scheme.
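The scheme in Figure 2 can be sketched as follows (an illustrative sketch under our own assumptions, not the authors' released code): the unpunctuated input is placed directly in the response slot of the prompt, a single forward pass scores every position, and a punctuation mark is attached to a word whenever the model's top next-token prediction at that word's last token is a punctuation token. Token/word alignment details are simplified here, and the punctuation marks are assumed to be single tokens.

```python
import torch

PUNCTUATION_MARKS = [",", ".", "?"]

@torch.no_grad()
def fpod_single_pass(model, tokenizer, words, prompt):
    """Forward Pass Only Decoding (FPOD), single pass.

    Assumes `model(ids)` returns logits of shape [1, seq_len, vocab].
    """
    # Ids of punctuation marks (assumed single-token for this sketch).
    punct_ids = {tokenizer.convert_tokens_to_ids(tokenizer.tokenize(p)[0]): p
                 for p in PUNCTUATION_MARKS}

    ids = tokenizer(prompt, return_tensors="pt").input_ids
    word_end = []                       # sequence position of the last token of each word
    for w in words:
        w_ids = tokenizer(" " + w, add_special_tokens=False, return_tensors="pt").input_ids
        ids = torch.cat([ids, w_ids], dim=-1)
        word_end.append(ids.shape[-1] - 1)

    logits = model(ids)                 # single forward pass over the whole sequence
    preds = logits[0].argmax(dim=-1)    # greedy next-token prediction at every position

    out = []
    for w, pos in zip(words, word_end):
        nxt = preds[pos].item()         # what the model would emit right after this word
        out.append(w + punct_ids[nxt] if nxt in punct_ids else w)
    return " ".join(out)
```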
FPOD, applied to the punctuation task, significantly enhances the inference speed of punctuation restoration compared to traditional auto-regressive methods. Furthermore, we utilize the frozen LoRA fine-tuned model, eliminating the need for additional training such as token classification for the punctuation task.

In addition to enhancing speed, the use of forward pass only decoding ensures that the token lengths and the original sentence structure (with punctuation modifications only) remain unaltered. This method effectively mitigates the issue of hallucination, a common problem associated with the auto-regressive approach.
Limitations. Decoding solely through the forward pass appears highly efficient and straightforward; nevertheless, certain details require careful consideration:
• In general, the performance of the large language model (LLM) is usually worse for "super long" input context, which is often more important for punctuation restoration.
• With forward only decoding, a given token prediction depends only on the previous token history. For example, if the history is "hello how are you" and we want to predict the next token after "you", the prediction ideally should be "?". However, because the previous history does not contain any punctuation, the model behavior may differ from the auto-regressive generation process.
Solutions. To address the first limitation, i.e., decoding longer text, we use a simple sliding window with padding approach, illustrated in Figure 4. To address the second limitation, i.e., context-dependent decoding, instead of one-pass forward decoding we split the process into the following steps as Recursive FPOD (sketched in code below):
• Iterate through the input tokens "hello how are you"; we update the sentence once we find a punctuation prediction.
• In this case, "hello how are you" → "hello, how are you" → "hello, how are you?"
• So instead of one-pass forward decoding, we pass the input sentence two times through the forward decoding process.

Figure 4: Sliding window with padding approach for long input text.
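A minimal sketch of Recursive FPOD follows (a simplified fixed-point variant of the procedure above, assuming some single-pass FPOD routine such as the `fpod_single_pass` sketch earlier): the single pass is repeated on its own output until no new punctuation is inserted, so each committed mark becomes part of the history for the next pass.

```python
def recursive_fpod(single_pass, text: str, max_passes: int = 10) -> str:
    """Repeat single-pass FPOD until the output stops changing.

    `single_pass` is any callable mapping a (partially punctuated) string to a
    re-punctuated string, e.g. a wrapper around fpod_single_pass above.
    """
    current = text
    for _ in range(max_passes):
        updated = single_pass(current)  # one full forward pass over the sentence
        if updated == current:          # no new punctuation inserted -> converged
            return updated
        current = updated               # committed marks join the history for the next pass
    return current


# Example with a toy single-pass stand-in:
# recursive_fpod(lambda s: s.replace("hello how", "hello, how", 1), "hello how are you")
```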
Improvement Factor. Let us analyze the improvement factor of running recursive FPOD with respect to auto-regressive generation. Let α be the probability that a given token is not followed by a punctuation mark, and let L be the decoding window length; the expected number of tokens committed per forward pass is then the truncated geometric sum

E(#token) = (1 − α^L) / (1 − α).

The above factor is similar to the expected token generation of speculative decoding [12]. Moreover, for Algorithm 1 we need to consider a time-efficiency factor for running forward pass decoding w.r.t. regressive generation. We introduce η, a time-efficiency factor relating one step of forward pass decoding to one step of regressive generation, where η ≤ 1 but very close to 1; in the experiment section we will give an estimate of η. Since the forward pass predicts L tokens in parallel, it ideally takes the same time as one token prediction with regressive generation, with a slight overhead for the parallel processing. Then, the overall expected improvement factor in token generation is

Improvement Factor (IF) = η · (1 − α^L) / (1 − α).

Let us conduct an empirical analysis to gauge the enhancement factor of FPOD for the punctuation task. Drawing from a frequency distribution analysis of punctuation marks in English across extensive corpora [14], we can approximate around 91,000 punctuation marks (including commas, periods, and question marks) per million words, equating to roughly 9 punctuation marks per 100 words. We can reasonably assume an average number of tokens per word for LLaMA models, denoted as κ, where κ ≥ 1. This implies we expect to encounter approximately 9 punctuation marks every 100κ tokens. However, for the sake of simplicity, we consider κ = 1. Hence, α is

α = 1 − 9/100 = 0.91   (letting κ = 1).

Therefore, the improvement factor for the punctuation task is

IF = η · (1 − α^L) / (1 − α) = η · (1 − 0.91^50) / (1 − 0.91) ≈ 11η   for L = 50.
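As a quick numerical check of the estimate above (a standalone sketch, not from the paper's codebase; the value 0.98 is the η estimated later in the experiments):

```python
# Expected tokens committed per forward pass and the resulting improvement factor.
alpha, L = 0.91, 50
expected_tokens = (1 - alpha**L) / (1 - alpha)
print(f"E(#token) = {expected_tokens:.2f}")   # ~11.0
print(f"IF = {0.98 * expected_tokens:.1f}")   # ~10.8 with eta ~= 0.98
```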
Applications. As mentioned earlier, FPOD can be promising in various applications such as tagging, verification, and text enhancement. For instance, it can be utilized for tasks like entity tagging, verifying speech recognition transcriptions for quality control, and text normalization or inverse-text normalization.

3. Experiments

In this section, we verify the effectiveness of the proposed forward pass only decoding method for punctuation restoration. We will compare the F1 score as a metric for punctuation restoration quality [7]; an illustrative scoring sketch is given below. In all the experiments, punctuation restoration is applied directly to the ASR output reference (without punctuation). We will also compare the inference speed, measured in tokens/second for each decoding method. The detailed setup and results of each experiment are also described.
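For concreteness, a per-mark F1 can be computed as follows (a minimal sketch under the assumption that FPOD leaves the word sequence unchanged, so reference and hypothesis can be compared word by word; this is our illustration, not necessarily the paper's exact scoring script).

```python
def punctuation_f1(reference: str, hypothesis: str, marks=(",", ".", "?")) -> dict:
    """Per-mark F1, comparing the punctuation attached to each word.

    Assumes reference and hypothesis contain the same words in the same order,
    differing only in the punctuation attached to them.
    """
    def attached(words):
        # For each word, record which mark (if any) it ends with.
        return [next((m for m in marks if w.endswith(m)), None) for w in words]

    ref = attached(reference.split())
    hyp = attached(hypothesis.split())
    assert len(ref) == len(hyp), "word sequences must align"

    scores = {}
    for m in marks:
        tp = sum(r == m and h == m for r, h in zip(ref, hyp))
        fp = sum(r != m and h == m for r, h in zip(ref, hyp))
        fn = sum(r == m and h != m for r, h in zip(ref, hyp))
        prec = tp / (tp + fp) if tp + fp else 0.0
        rec = tp / (tp + fn) if tp + fn else 0.0
        scores[m] = 2 * prec * rec / (prec + rec) if prec + rec else 0.0
    return scores


# Example:
# punctuation_f1("hello, how are you?", "hello how are you?")  # {',': 0.0, '.': 0.0, '?': 1.0}
```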
3.1. LoRA Finetuned Model for Punctuation Restoration

The punctuation restoration model is trained with LoRA fine-tuning on the 13B LLaMA2 model, and the training data is 20k of the train-clean-360 data from the Librispeech-PC dataset [15]. The prompt template for LoRA fine-tuning is described in Figure 1. After the LoRA fine-tuning process, we obtain the merged model for the punctuation restoration task. We then run knowledge distillation (KD) with the same training dataset to get the distilled assistant model (350MB) [16].
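For orientation, LoRA fine-tuning of a 13B LLaMA2 checkpoint along these lines could look roughly as follows with the Hugging Face peft library (a hedged sketch: the model id, target modules, ranks, and the prompt template are our assumptions, since Figure 1 and the exact training configuration are not reproduced here).

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model

base_id = "meta-llama/Llama-2-13b-hf"            # assumed checkpoint id
tokenizer = AutoTokenizer.from_pretrained(base_id)
model = AutoModelForCausalLM.from_pretrained(base_id)

# Low-rank adapters on the attention projections; hyperparameters are illustrative.
lora_cfg = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_cfg)
model.print_trainable_parameters()

# Hypothetical prompt template pairing unpunctuated input with punctuated target.
example = (
    "### Instruction: Restore punctuation.\n"
    "### Input: hello how are you\n"
    "### Response: hello, how are you?"
)
batch = tokenizer(example, return_tensors="pt")
# ... train with a standard causal-LM loss, then merge the adapters, e.g.:
# merged = model.merge_and_unload()
```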
F1 scores are reported separately for periods (.), question marks (?), and commas (,). We employ a recursive forward decoding technique with a window for this study. As shown in Table 2, the results suggest that the LLaMA2-based model, both with the FPOD and recursive FPOD methods, can improve the F1 score for all punctuation marks. These improvements surpass the performance of both the RNNT and Whisper models. Notably, the utilization of recursive FPOD further amplifies the F1 score by a substantial margin. Regarding inference speed, the adoption of recursive FPOD achieves an impressive rate of 959.1 tokens/s for long input texts. This represents a remarkable 10.8x improvement compared to the auto-regressive baseline of 88.72 tokens/s, as demonstrated in Table 1. Here we can estimate η as 10.8/11 = 0.98.