LSTM - Derivation of Backpropagation Through Time


Long Short-Term Memory (LSTM) networks are a type of neural network designed to handle long-term dependencies by mitigating the vanishing gradient problem. The fundamental technique used to train them on sequential data is Backpropagation Through Time (BPTT). In this article we look at how BPTT works in LSTMs.

Why Backpropagation Through Time?

Backpropagation is the main method used to train neural networks. It helps the network learn by adjusting its weights based on how much the predictions differ from the actual values. In feedforward neural networks this process is simple because the data moves in one direction from input to output.

In RNNs and LSTMs, however, the situation is different: they process sequential data such as sentences, time series or speech, where each element depends on the ones before it. This creates a temporal dependency between the steps, so we can't just apply backpropagation normally. Instead we need to modify it so that it accounts for the time dimension, and this is where Backpropagation Through Time (BPTT) comes in.

BPTT works by unfolding the network over time, treating each timestep as a layer in a deep network. It calculates the gradients at each step and propagates them backward to update the model's weights. This helps the network learn how each step in the sequence depends on the others.
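Because the same weight matrices are reused at every timestep, unfolding also means that the total gradient of the loss E with respect to any shared weight W is the sum of the per-timestep contributions:

\frac{dE}{dW} = \sum_{t=1}^{T} \left. \frac{dE}{dW} \right|_{t}

where T is the sequence length. Each per-timestep term is computed with the chain rule as described in the sections below.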

LSTMs introduce structural differences that modify the gradient computations:

  • LSTMs have gates (input, forget and output) that control gradient flow, unlike simple RNNs.
  • The cell state (c_t) enables long-term memory retention and mitigates the vanishing gradient problem.
  • Unlike in simple RNNs, gradients propagate through both the hidden states (h_t) and the cell states (c_t), leading to additional partial derivatives.
  • LSTMs require different gradient computations because of the element-wise multiplications and gate activation functions.

Understanding these differences is important to know how LSTMs solve vanishing gradient issues while maintaining long-term dependencies.

Notations Used:

  • x_t is the input at timestep t.
  • h_{t-1} and h_t are the hidden states at the previous and current timestep.
  • c_{t-1} and c_t are the cell states at the previous and current timestep.
  • \sigma denotes the sigmoid activation function.
  • \tanh denotes the hyperbolic tangent activation function.
  • \odot denotes element-wise (Hadamard) multiplication.

The initial values of c_0 and h_0 are usually set to zero.

Working of BPTT

1. Initialization of Weights

Each LSTM cell has three gates: the Input Gate, Forget Gate and Output Gate, each with its own weights and biases.

Input Gate:

  • Weights: W_{xi}, W_{hi}, b_i
  • Candidate Cell State Weights: W_{xg}, W_{hg}, b_g

Forget Gate:

  • Weights: W_{xf}, W_{hf}, b_f

Output Gate:

  • Weights: W_{xo}, W_{ho}, b_o
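As a concrete illustration, here is a minimal NumPy sketch of this initialization step. The names init_lstm_params, input_dim and hidden_dim are illustrative choices rather than part of the article, and real frameworks typically use Xavier or orthogonal initialization instead of small Gaussian noise.

```python
import numpy as np

def init_lstm_params(input_dim, hidden_dim, seed=0):
    """Create the weight matrices and biases for the four LSTM transformations."""
    rng = np.random.default_rng(seed)
    params = {}
    # i = input gate, f = forget gate, o = output gate, g = candidate cell state
    for gate in ("i", "f", "o", "g"):
        params[f"W_x{gate}"] = rng.normal(0.0, 0.1, size=(hidden_dim, input_dim))   # multiplies x_t
        params[f"W_h{gate}"] = rng.normal(0.0, 0.1, size=(hidden_dim, hidden_dim))  # multiplies h_{t-1}
        params[f"b_{gate}"] = np.zeros(hidden_dim)
    return params
```

Keeping all weights in one dictionary keyed by gate makes the forward and backward sketches later in the article easy to write.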

2. Forward Pass Through LSTM Gates

At each timestep t we compute the gate activations as follows:

  • Forget Gate: f_t = \sigma(W_{xf} x_t + W_{hf} h_{t-1} + b_f)
  • Input Gate: i_t = \sigma(W_{xi} x_t + W_{hi} h_{t-1} + b_i)
  • Candidate Cell State: g_t = \tanh(W_{xg} x_t + W_{hg} h_{t-1} + b_g)
  • Cell State Update: c_t = f_t \odot c_{t-1} + i_t \odot g_t
  • Output Gate: o_t = \sigma(W_{xo} x_t + W_{ho} h_{t-1} + b_o)
  • Hidden State Update: h_t = o_t \odot \tanh(c_t)
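These six equations translate almost line for line into code. The sketch below computes a single timestep using the parameter dictionary from the initialization sketch above; it is a simplified illustration (single example, no batching), not a production implementation.

```python
def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_forward_step(params, x_t, h_prev, c_prev):
    """One LSTM timestep; returns h_t, c_t and a cache of values needed for BPTT."""
    f_t = sigmoid(params["W_xf"] @ x_t + params["W_hf"] @ h_prev + params["b_f"])   # forget gate
    i_t = sigmoid(params["W_xi"] @ x_t + params["W_hi"] @ h_prev + params["b_i"])   # input gate
    g_t = np.tanh(params["W_xg"] @ x_t + params["W_hg"] @ h_prev + params["b_g"])   # candidate cell state
    o_t = sigmoid(params["W_xo"] @ x_t + params["W_ho"] @ h_prev + params["b_o"])   # output gate

    c_t = f_t * c_prev + i_t * g_t   # cell state update (element-wise)
    h_t = o_t * np.tanh(c_t)         # hidden state update

    cache = (x_t, h_prev, c_prev, f_t, i_t, g_t, o_t, c_t)
    return h_t, c_t, cache
```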

3. Performing Backpropagation Through Time (BPTT) in LSTMs

The goal is to compute the gradients for all weights using the chain rule, taking into account the LSTM-specific gates and memory cell.

1. Gradient of Loss w.r.t Output Gate:

  • \frac{dE}{do_t} = \frac{dE}{dh_t} \odot \tanh(c_t)

2. Gradient of Loss w.r.t Cell State:

  • \frac{dE}{dc_t} = \frac{dE}{dh_t} \odot o_t \odot (1 - \tanh^2(c_t))

When the network is unrolled over several timesteps, the gradient flowing back from the next cell state, \frac{dE}{dc_{t+1}} \odot f_{t+1} (step 5 applied at timestep t+1), is added to this term.

3. Gradient of Loss w.r.t Input Gate and Candidate Cell State:

  • \frac{dE}{di_t} = \frac{dE}{dc_t} \odot g_t
  • \frac{dE}{dg_t} = \frac{dE}{dc_t} \odot i_t

4. Gradient of Loss w.r.t Forget Gate:

  • \frac{dE}{df_t} = \frac{dE}{dc_t} \odot c_{t-1}

5. Gradient of Loss w.r.t Previous Cell State:

  • \frac{dE}{dc_{t-1}} = \frac{dE}{dc_t} \odot f_t
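The five gradients above can be implemented almost verbatim. The sketch below (a simplified illustration continuing the earlier snippets) takes the gradient dE/dh_t arriving at the hidden state and the gradient already flowing back into c_t from the next timestep, and returns the gate gradients together with dE/dc_{t-1}:

```python
def lstm_backward_gates(dE_dh_t, dE_dc_next, cache):
    """Backpropagate through the gates and cell state of one timestep."""
    x_t, h_prev, c_prev, f_t, i_t, g_t, o_t, c_t = cache

    # 1. Output gate: dE/do_t = dE/dh_t * tanh(c_t)
    dE_do_t = dE_dh_t * np.tanh(c_t)

    # 2. Cell state: dE/dc_t = dE/dh_t * o_t * (1 - tanh^2(c_t)),
    #    plus the gradient arriving from c_{t+1}
    dE_dc_t = dE_dh_t * o_t * (1.0 - np.tanh(c_t) ** 2) + dE_dc_next

    # 3. Input gate and candidate cell state
    dE_di_t = dE_dc_t * g_t
    dE_dg_t = dE_dc_t * i_t

    # 4. Forget gate
    dE_df_t = dE_dc_t * c_prev

    # 5. Previous cell state
    dE_dc_prev = dE_dc_t * f_t

    return dE_do_t, dE_dc_t, dE_di_t, dE_dg_t, dE_df_t, dE_dc_prev
```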

4. Computing Weight Gradients

Using the chain rule, we compute the gradients for the weights associated with each gate.

Gradients for Output Gate Weights:

  • \frac{dE}{dW_{xo}} = \left(\frac{dE}{do_t} \odot o_t \odot (1 - o_t)\right) x_t^T
  • \frac{dE}{dW_{ho}} = \left(\frac{dE}{do_t} \odot o_t \odot (1 - o_t)\right) h_{t-1}^T
  • \frac{dE}{db_o} = \frac{dE}{do_t} \odot o_t \odot (1 - o_t)

Here the products with x_t^T and h_{t-1}^T are outer products, so each gradient has the same shape as the corresponding weight matrix, and the per-timestep gradients are summed over all timesteps because the weights are shared.
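A minimal NumPy sketch of these three gradients for a single timestep, continuing the earlier snippets (the other gates follow the same pattern, with the tanh derivative 1 - g_t^2 used for the candidate cell state):

```python
def output_gate_weight_grads(dE_do_t, o_t, x_t, h_prev):
    """Per-timestep gradients of the loss w.r.t. the output-gate parameters."""
    delta_o = dE_do_t * o_t * (1.0 - o_t)   # gradient at the gate's pre-activation

    dW_xo = np.outer(delta_o, x_t)      # same shape as W_xo
    dW_ho = np.outer(delta_o, h_prev)   # same shape as W_ho
    db_o = delta_o
    return dW_xo, dW_ho, db_o
```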

Similarly, gradients for other gates' weights and biases are calculated. Backpropagation Through Time for LSTMs involves:

  • Unfolding the LSTM over all timesteps.
  • Calculating gradients for each gate and the cell state using the chain rule.
  • Accounting for both hidden state and cell state dependencies.
  • Propagating gradients backward through time to update weights.

This process enables LSTMs to learn long-term dependencies in sequential data by efficiently updating weights while mitigating vanishing gradients.
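To tie the pieces together, the sketch below combines the earlier snippets into one BPTT pass: unroll the forward step over the sequence, then walk backwards through time, summing the per-timestep weight gradients because the parameters are shared across timesteps. The list dE_dh_list (the gradient of the loss at each hidden state) is assumed to be given, and for brevity the sketch omits the gradient that also flows from the gate pre-activations back into h_{t-1}; a complete implementation would add those terms to the previous timestep's hidden-state gradient.

```python
def lstm_bptt(params, xs, dE_dh_list, hidden_dim):
    """Simplified BPTT over a sequence xs, returning summed weight gradients."""
    h, c = np.zeros(hidden_dim), np.zeros(hidden_dim)   # h_0 and c_0 start at zero
    caches = []
    for x_t in xs:                                      # forward pass, storing caches
        h, c, cache = lstm_forward_step(params, x_t, h, c)
        caches.append(cache)

    grads = {name: np.zeros_like(w) for name, w in params.items()}
    dE_dc_next = np.zeros(hidden_dim)
    for t in reversed(range(len(xs))):                  # backward pass through time
        dE_do, dE_dc, dE_di, dE_dg, dE_df, dE_dc_next = lstm_backward_gates(
            dE_dh_list[t], dE_dc_next, caches[t])
        x_t, h_prev, _, f_t, i_t, g_t, o_t, _ = caches[t]
        # sigmoid gates: push each gate gradient through the derivative act*(1-act)
        for name, d_gate, act in (("o", dE_do, o_t), ("i", dE_di, i_t), ("f", dE_df, f_t)):
            delta = d_gate * act * (1.0 - act)
            grads[f"W_x{name}"] += np.outer(delta, x_t)
            grads[f"W_h{name}"] += np.outer(delta, h_prev)
            grads[f"b_{name}"] += delta
        delta_g = dE_dg * (1.0 - g_t ** 2)              # tanh derivative for the candidate
        grads["W_xg"] += np.outer(delta_g, x_t)
        grads["W_hg"] += np.outer(delta_g, h_prev)
        grads["b_g"] += delta_g
    return grads
```

For example, calling lstm_bptt(init_lstm_params(3, 5), [np.random.randn(3) for _ in range(4)], [np.ones(5)] * 4, 5) returns one gradient array per parameter, each with the same shape as the parameter it corresponds to.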
 

