
Back Propagation through time - RNN

Last Updated : 29 May, 2025

Recurrent Neural Networks (RNNs) are designed to process sequential data. Unlike traditional neural networks, RNN outputs depend not only on the current input but also on previous inputs through a memory element. This memory allows RNNs to capture temporal dependencies in data such as time series or language.

Training an RNN involves backpropagation, but instead of updating the weights based only on the current timestep t, we also account for all previous timesteps t-1, t-2, t-3, \ldots

This method is called Backpropagation Through Time (BPTT) which extends traditional backpropagation to sequential data by unfolding the network over time and summing gradients across all relevant time steps. This method enables RNNs to learn complex temporal patterns.

RNN Architecture

At each timestep t, the RNN maintains a hidden state S_t, which acts as the network's memory and summarizes information from previous inputs. The hidden state S_t is updated by combining the current input X_t with the previous hidden state S_{t-1} and applying an activation function to introduce non-linearity. The output Y_t is then generated by transforming this hidden state.

S_t = g_1(W_x X_t + W_s S_{t-1})

Y_t = g_2(W_y S_t)

where:

  • S_t is the hidden state (memory) at time t.
  • X_t is the input at time t.
  • Y_t is the output at time t.
  • W_x, W_s, W_y are the weight matrices for inputs, hidden states and outputs, respectively.
  • g_1 and g_2 are activation functions.
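To make the forward pass concrete, here is a minimal NumPy sketch of these two equations. The choice of tanh for g_1, an identity g_2, and the tiny dimensions are assumptions made only for illustration; they are not fixed by the equations above.

```python
import numpy as np

# Minimal RNN forward pass for the equations above.
# Assumed for illustration: g_1 = tanh, g_2 = identity,
# 2-dimensional inputs, 3-dimensional hidden state, 1-dimensional output.
rng = np.random.default_rng(0)

W_x = rng.normal(size=(3, 2))   # input  -> hidden
W_s = rng.normal(size=(3, 3))   # hidden -> hidden
W_y = rng.normal(size=(1, 3))   # hidden -> output

def rnn_forward(X):
    """X has shape (T, 2): one input vector per timestep t = 1..T."""
    T = X.shape[0]
    S = np.zeros((T + 1, 3))                              # S[0] is the initial hidden state
    Y = np.zeros((T, 1))
    for t in range(1, T + 1):
        S[t] = np.tanh(W_x @ X[t - 1] + W_s @ S[t - 1])   # S_t = g_1(W_x X_t + W_s S_{t-1})
        Y[t - 1] = W_y @ S[t]                             # Y_t = g_2(W_y S_t), with g_2 = identity
    return S, Y

X = rng.normal(size=(3, 2))                               # a length-3 input sequence
S, Y = rnn_forward(X)
```

Indexing note: S[t] holds S_t, so S[3] is the hidden state that produces Y_3 in the sections below.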


Error Function at Time t=3

To train the network, we measure how far the predicted output Y_t is from the desired output d_t using an error function. Here we use the squared error:

E_t = (d_t - Y_t)^2

At t=3:

E_3 = (d_3 - Y_3)^2

This error quantifies how far the prediction Y_3 is from the desired output d_3 at time t=3.
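Continuing the NumPy sketch above, E_3 can be computed directly from the third output. The target values d below are placeholders, since no dataset is specified here.

```python
d = np.array([[0.5], [-0.2], [0.1]])   # placeholder desired outputs d_1, d_2, d_3
E_3 = float((d[2] - Y[2]) ** 2)        # E_3 = (d_3 - Y_3)^2
```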

Updating Weights Using BPTT

BPTT updates the weights W_y, W_s and W_x to minimize the error by computing gradients. Unlike standard backpropagation, BPTT unfolds the network across time steps, considering how the error at time t depends on all previous states. Below, we derive the gradient of E_3 with respect to each weight.

1. Adjusting Output Weight W_y

The output weight W_y​ affects the output directly at time 3. This means we calculate how the error changes as Y_3​ changes, then how Y_3​ changes with respect to W_y​. Updating W_y​ is straightforward because it only influences the current output.

Using the chain rule:

\frac{\partial E_3}{\partial W_y} = \frac{\partial E_3}{\partial Y_3} \times \frac{\partial Y_3}{\partial W_y}

  • E_3 depends on Y_3​, so we differentiate E_3​ w.r.t. Y_3​.
  • Y_3​ depends on W_y​, so we differentiate Y_3​ w.r.t. W_y​.
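Continuing the sketch, and assuming the squared error and identity output activation from before, the two chain-rule factors can be written out directly: dE_3/dY_3 = -2(d_3 - Y_3) and dY_3/dW_y = S_3.

```python
# dE_3/dW_y = dE_3/dY_3 * dY_3/dW_y
dE3_dY3 = -2.0 * (d[2] - Y[2])        # dE_3/dY_3 for squared error, shape (1,)
grad_W_y = np.outer(dE3_dY3, S[3])    # dY_3/dW_y = S_3; result has the same shape as W_y (1, 3)
```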

2. Adjusting Hidden State Weight W_s

The hidden state weight W_s influences not just the current hidden state but all previous ones, because each hidden state depends on the one before it. To update W_s, we must consider how changes to W_s affect all hidden states S_1, S_2, S_3 and, consequently, the output at time 3.

The gradient for W_s​ considers all previous hidden states because each hidden state depends on the previous one:

\frac{\partial E_3}{\partial W_s} = \sum_{i=1}^3 \frac{\partial E_3}{\partial Y_3} \times \frac{\partial Y_3}{\partial S_i} \times \frac{\partial S_i}{\partial W_s}

Breaking down:

  • Start with the error gradient at output Y_3​.
  • Propagate gradients back through all hidden states S_3, S_2, S_1 since they affect Y_3​.
  • Each S_i​ depends on W_s​, so we differentiate accordingly.
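A minimal sketch of this sum, continuing the example above (tanh hidden activation assumed): the loop walks backwards from i = 3 to i = 1, accumulating the explicit contribution of W_s at each step and then pushing the gradient one timestep further back.

```python
def dtanh(s):
    # derivative of tanh written in terms of its output: tanh'(z) = 1 - tanh(z)^2
    return 1.0 - s ** 2

delta = (dE3_dY3 @ W_y) * dtanh(S[3])            # dE_3 w.r.t. the pre-activation of S_3, shape (3,)
grad_W_s = np.zeros_like(W_s)
for i in range(3, 0, -1):                        # i = 3, 2, 1
    grad_W_s += np.outer(delta, S[i - 1])        # explicit part of dS_i/dW_s is S_{i-1}
    if i > 1:
        delta = (delta @ W_s) * dtanh(S[i - 1])  # propagate the gradient back through S_{i-1}
```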

3. Adjusting Input Weight W_x

Similar to W_s​, the input weight W_x​ affects all hidden states because the input at each timestep shapes the hidden state. The process considers how every input in the sequence impacts the hidden states leading to the output at time 3.

\frac{\partial E_3}{\partial W_x} = \sum_{i=1}^3 \frac{\partial E_3}{\partial Y_3} \times \frac{\partial Y_3}{\partial S_i} \times \frac{\partial S_i}{\partial W_x}

As with W_s, the gradient sums contributions from every timestep, since the input at each step shapes the corresponding hidden state and, through it, the output at time 3.

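Continuing the example, the corresponding sketch changes only the explicit term: at each step the gradient accumulates the input X_i instead of the previous hidden state S_{i-1}.

```python
delta = (dE3_dY3 @ W_y) * dtanh(S[3])            # same starting point as for W_s
grad_W_x = np.zeros_like(W_x)
for i in range(3, 0, -1):
    grad_W_x += np.outer(delta, X[i - 1])        # explicit part of dS_i/dW_x is the input X_i
    if i > 1:
        delta = (delta @ W_s) * dtanh(S[i - 1])
```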

Advantages of Backpropagation Through Time (BPTT)

  • Captures Temporal Dependencies: BPTT allows RNNs to learn relationships across time steps, crucial for sequential data like speech, text and time series.
  • Unfolding over Time: By considering all previous states during training, BPTT helps the model understand how past inputs influence future outputs.
  • Foundation for Modern RNNs: BPTT forms the basis for training advanced architectures such as LSTMs and GRUs, enabling effective learning of long sequences.
  • Flexible for Variable Length Sequences: It can handle input sequences of varying lengths, adapting gradient calculations accordingly.

Limitations of BPTT

  • Vanishing Gradient Problem: When backpropagating over many time steps, gradients tend to shrink exponentially, making early time steps contribute very little to weight updates. This causes the network to “forget” long-term dependencies.
  • Exploding Gradient Problem: Gradients may also grow uncontrollably large, causing unstable updates and making training difficult.

Solutions

  • Long Short-Term Memory (LSTM): Special RNN cells designed to maintain information over longer sequences and mitigate vanishing gradients.
  • Gradient Clipping: Limits the magnitude of gradients during backpropagation to prevent explosion by rescaling them when they exceed a threshold (see the sketch below).
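A minimal sketch of gradient clipping by global norm, applied to the gradients computed earlier; the threshold of 5.0 is an arbitrary illustrative choice.

```python
def clip_gradients(grads, max_norm=5.0):
    """Rescale a list of gradient arrays so their combined L2 norm does not exceed max_norm."""
    total_norm = np.sqrt(sum(float(np.sum(g ** 2)) for g in grads))
    if total_norm > max_norm:
        grads = [g * (max_norm / total_norm) for g in grads]
    return grads

grad_W_y, grad_W_s, grad_W_x = clip_gradients([grad_W_y, grad_W_s, grad_W_x])
```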

In this article, we saw how Backpropagation Through Time (BPTT) enables recurrent neural networks to capture temporal dependencies by updating weights across multiple time steps, along with its main challenges and their solutions.

