LSTM

Definition

A Long Short-Term Memory (LSTM) network is a recurrent neural network (RNN) architecture designed to process sequences while learning long-range dependencies. Its defining feature is a cell state: a memory vector carried across time steps with mostly additive (rather than repeatedly multiplicative) updates, regulated by three multiplicative gates (forget, input, output). The gates let the network learn what to remember, what to write, and what to read at each step, which largely mitigates the vanishing gradients that cripple vanilla RNNs over long horizons. Exploding gradients can still occur and are usually handled with gradient clipping.

Intuition

A plain RNN updates its hidden state by $h_t = \tanh(W x_t + U h_{t-1} + b)$. Because $h_{t-1}$ is passed through a weight matrix and a saturating nonlinearity at every step, the gradient of a late loss with respect to an early input is a long product of Jacobians: it tends to shrink toward zero (vanishing) or blow up (exploding) exponentially in the sequence length. In practice the network therefore struggles to connect events separated by many steps.
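The shrinking Jacobian product can be observed numerically. The following is a minimal NumPy sketch, not a training setup: it assumes zero input, a random recurrent matrix `U` rescaled so that its spectral norm equals a chosen `scale` below 1, and a hypothetical helper name `jacobian_product_norms`.

```python
import numpy as np

def jacobian_product_norms(T=50, n=16, scale=0.5, seed=0):
    """Spectral norm of d h_t / d h_0 for t = 1..T in a tanh RNN with zero input."""
    rng = np.random.default_rng(seed)
    G = rng.standard_normal((n, n))
    # Assumption for illustration: rescale so the spectral norm of U is exactly `scale`.
    U = scale * G / np.linalg.norm(G, 2)
    h = rng.standard_normal(n)
    J = np.eye(n)  # accumulated Jacobian d h_t / d h_0
    norms = []
    for _ in range(T):
        h = np.tanh(U @ h)
        # One-step Jacobian: d h_t / d h_{t-1} = diag(1 - h_t^2) @ U
        J = (np.diag(1.0 - h ** 2) @ U) @ J
        norms.append(np.linalg.norm(J, 2))
    return np.array(norms)

norms = jacobian_product_norms()
print(norms[[0, 9, 49]])  # norm after 1, 10 and 50 steps
```

Each one-step Jacobian has norm at most `scale`, so the product is bounded by `scale ** t` and decays exponentially with the number of steps. The exploding case is harder to show this cleanly because `tanh` saturation damps the growth.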

The LSTM fixes this with a constant error carousel: the cell state has a near-identity recurrence, $c_t = f_t \odot c_{t-1} + i_t \odot \tilde{c}_t$, so gradient can flow back many steps almost undamped when the forget gate $f_t \approx 1$. The gates are themselves learned sigmoid functions of the input and previous hidden state, so the network decides dynamically how long to hold each piece of information.

In RL this is exactly what is needed under Partial Observability: the LSTM hidden state acts as a learned, recursively updated internal state summarising the whole history of observations, rather than only the last $k$ frames (frame stacking). This is the basis of Deep Recurrent Q-Learning (DRQN). In IR, the same sequence-modelling ability underlies early neural rankers and query/document encoders before Transformers became dominant.

Mathematical Formulation

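At each time step $t$, given input $x_t$, previous hidden state $h_{t-1}$, and previous cell state $c_{t-1}$, the LSTM computes:

$$
\begin{aligned}
f_t &= \sigma(W_f x_t + U_f h_{t-1} + b_f) \\
i_t &= \sigma(W_i x_t + U_i h_{t-1} + b_i) \\
o_t &= \sigma(W_o x_t + U_o h_{t-1} + b_o) \\
\tilde{c}_t &= \tanh(W_c x_t + U_c h_{t-1} + b_c) \\
c_t &= f_t \odot c_{t-1} + i_t \odot \tilde{c}_t \\
h_t &= o_t \odot \tanh(c_t)
\end{aligned}
$$

where: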

  • $x_t$: input at step $t$ (e.g. an embedded token, or CNN features of an observation)
  • $h_t$: hidden state, also the layer’s output at step $t$
  • $c_t$: cell state (the long-term memory carried across steps)
  • $f_t, i_t, o_t$: forget / input / output gates (element-wise sigmoid $\sigma$, range $(0, 1)$)
  • $\tilde{c}_t$: candidate values proposed for writing into the cell ($\tanh$, range $(-1, 1)$)
  • $W_*, U_*$: input and recurrent weight matrices; $b_*$: biases (each gate has its own set)
  • $\odot$: element-wise (Hadamard) product

Why the gradient survives

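The cell recurrence gives $\partial c_t / \partial c_{t-1} = \mathrm{diag}(f_t)$ along the direct cell-state path (ignoring the indirect dependence of the gates on $h_{t-1}$). Backpropagating through $k$ steps multiplies these diagonal terms: $\partial c_t / \partial c_{t-k} = \prod_{j=t-k+1}^{t} \mathrm{diag}(f_j)$. When the forget gate stays near $1$, this product stays near the identity, so error flows back across long spans without vanishing, unlike the dense Jacobian product of a vanilla RNN.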
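To get a feel for how sensitive this product is to the forget-gate value, the sketch below multiplies a constant gate activation over 100 steps. It is a simplification: a real forget gate varies per unit and per step, and `carousel_gain` is a hypothetical helper name.

```python
import numpy as np

def carousel_gain(forget_value, k):
    """Factor applied to the gradient along the direct cell-state path
    after k steps, assuming every forget activation equals forget_value."""
    return float(np.prod(np.full(k, forget_value)))

for f in (0.5, 0.9, 0.99, 0.999):
    print(f, carousel_gain(f, 100))
```

With a gate value of 0.99 roughly 37% of the gradient survives 100 steps, whereas at 0.5 it is effectively zero, which is why the gate needs to sit close to 1 for information that must be kept for a long time.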

Key Properties / Variants

  • Gated additive memory: the additive cell update (vs. the multiplicative hidden-state update of vanilla RNNs) is what tames vanishing gradients and enables long-range memory.
  • Parameter cost: roughly $4\,(n(n + m) + n)$ parameters per layer for hidden size $n$ and input size $m$, i.e. four affine maps (three gates + candidate).
  • Trained with Backpropagation Through Time (BPTT): the network is unrolled over the sequence and gradients are summed across steps; long sequences are usually handled with truncated BPTT.
  • GRU (Gated Recurrent Unit): a lighter variant merging cell and hidden state and using only two gates (reset, update); fewer parameters, often comparable performance.
  • Bidirectional / stacked LSTMs: read the sequence both directions and/or stack layers; common in IR encoders where the full sequence is available offline.
  • Largely superseded by Transformers (self-attention) for both NLP and IR — attention gives direct, parallelisable access to all positions — but LSTMs remain relevant where streaming/online recurrence or low memory is needed, and in RL recurrent agents.
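The parameter-cost bullet can be checked with a few lines of arithmetic. This sketch assumes one bias vector per affine map and the standard three-map GRU (reset gate, update gate, candidate); the function names are hypothetical.

```python
def lstm_param_count(input_size, hidden_size):
    """Four affine maps (forget, input, output, candidate), one bias vector each."""
    per_map = hidden_size * (input_size + hidden_size) + hidden_size
    return 4 * per_map

def gru_param_count(input_size, hidden_size):
    """Three affine maps (reset, update, candidate), one bias vector each."""
    per_map = hidden_size * (input_size + hidden_size) + hidden_size
    return 3 * per_map

print(lstm_param_count(10, 20), gru_param_count(10, 20))
```

Under these assumptions a GRU has three quarters of the parameters of an LSTM of the same size. Some libraries keep separate input and recurrent bias vectors, so their reported counts can be slightly higher than this formula.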

Algorithm: LSTM forward pass over a sequence
─────────────────────────────────────────────
Input: sequence x_1, ..., x_T ; params {W_*, U_*, b_*}
Initialize h_0 = 0,  c_0 = 0
 
for t = 1 .. T:
    f_t  = σ(W_f x_t + U_f h_{t-1} + b_f)      # forget: keep how much of c_{t-1}
    i_t  = σ(W_i x_t + U_i h_{t-1} + b_i)      # input:  how much to write
    o_t  = σ(W_o x_t + U_o h_{t-1} + b_o)      # output: how much of cell to expose
    c~_t = tanh(W_c x_t + U_c h_{t-1} + b_c)   # candidate content
    c_t  = f_t ⊙ c_{t-1} + i_t ⊙ c~_t          # update memory (additive)
    h_t  = o_t ⊙ tanh(c_t)                     # emit hidden state / output
return (h_1..h_T, c_1..c_T)
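
The pseudocode above can be turned into a small runnable NumPy sketch. It is meant for reading alongside the algorithm, not as an efficient implementation: the parameter dictionary layout, the helper `init_params`, and the small random initialisation are assumptions made for this example.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def init_params(input_size, hidden_size, seed=0):
    """Hypothetical helper: small random weights, zero biases, one set per gate."""
    rng = np.random.default_rng(seed)
    params = {}
    for gate in "fioc":  # forget, input, output, candidate
        params["W_" + gate] = 0.1 * rng.standard_normal((hidden_size, input_size))
        params["U_" + gate] = 0.1 * rng.standard_normal((hidden_size, hidden_size))
        params["b_" + gate] = np.zeros(hidden_size)
    return params

def lstm_forward(xs, p):
    """Run an LSTM over xs of shape (T, input_size); return all h_t and c_t."""
    hidden_size = p["b_f"].shape[0]
    h = np.zeros(hidden_size)  # h_0 = 0
    c = np.zeros(hidden_size)  # c_0 = 0
    hs, cs = [], []
    for x in xs:
        f = sigmoid(p["W_f"] @ x + p["U_f"] @ h + p["b_f"])        # forget gate
        i = sigmoid(p["W_i"] @ x + p["U_i"] @ h + p["b_i"])        # input gate
        o = sigmoid(p["W_o"] @ x + p["U_o"] @ h + p["b_o"])        # output gate
        c_tilde = np.tanh(p["W_c"] @ x + p["U_c"] @ h + p["b_c"])  # candidate
        c = f * c + i * c_tilde   # additive memory update
        h = o * np.tanh(c)        # exposed hidden state
        hs.append(h)
        cs.append(c)
    return np.stack(hs), np.stack(cs)

xs = np.random.default_rng(1).standard_normal((5, 3))  # T = 5 steps, input size 3
hs, cs = lstm_forward(xs, init_params(input_size=3, hidden_size=4))
print(hs.shape, cs.shape)
```

Both returned arrays have one row per time step. Because $h_t$ is a sigmoid-gated $\tanh$, every entry of `hs` lies strictly between -1 and 1, while the entries of `cs` are not bounded in this way.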

RNN training is finicky

Even with gates, LSTMs can hit local optima, need gradient clipping for the exploding-gradient direction, and are slow to train because BPTT is inherently sequential (no within-sequence parallelism). In DRQN this shows up as sensitivity to the unrolling strategy (bootstrapped random-start vs. sequential whole-episode updates).
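
Gradient clipping is commonly done by rescaling all gradients together when their combined norm exceeds a threshold. The sketch below shows that idea in NumPy; the function name `clip_by_global_norm`, the list-of-arrays gradient layout, and the threshold values are assumptions for illustration only.

```python
import numpy as np

def clip_by_global_norm(grads, max_norm):
    """Rescale a list of gradient arrays so their joint L2 norm is at most max_norm.

    Returns the (possibly rescaled) gradients and the norm before clipping."""
    total_norm = float(np.sqrt(sum(np.sum(g ** 2) for g in grads)))
    # Scale is 1 when the norm is already within the limit.
    scale = min(1.0, max_norm / (total_norm + 1e-12))
    return [g * scale for g in grads], total_norm

grads = [np.array([3.0, 0.0]), np.array([0.0, 4.0])]  # joint norm is 5
clipped, norm_before = clip_by_global_norm(grads, max_norm=1.0)
print(norm_before, clipped)
```

Rescaling every array by the same factor preserves the direction of the overall gradient and only limits its length, which is usually preferable to clipping each element independently.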

Connections

Appears In