Recurrent Neural Networks
RNNs are designed to process sequential data by maintaining an internal state
Unlike feedforward networks, RNNs share parameters across different time steps
The hidden state carries information across the sequence, acting as memory
Core Recurrent Cell
- Basic RNN update: $h_t = \phi(W_{xh}x_t + W_{hh}h_{t-1} + b_h)$
- $W_{xh}$: Input-to-hidden weights
- $W_{hh}$: Hidden-to-hidden weights (recurrent connections)
- $h_{t-1}$: Previous hidden state
- $\phi$: Activation function (typically tanh)
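A minimal NumPy sketch of this update; the function name, toy dimensions, and random weights below are assumptions for illustration, with names mirroring the equation above.

```python
import numpy as np

def rnn_step(x_t, h_prev, W_xh, W_hh, b_h):
    """One step of the basic RNN update: h_t = tanh(W_xh x_t + W_hh h_{t-1} + b_h)."""
    return np.tanh(W_xh @ x_t + W_hh @ h_prev + b_h)

# Assumed toy sizes: input dim 4, hidden dim 3
rng = np.random.default_rng(0)
W_xh = rng.normal(size=(3, 4))
W_hh = rng.normal(size=(3, 3))
b_h = np.zeros(3)

h = np.zeros(3)                       # initial hidden state
for x_t in rng.normal(size=(5, 4)):   # a length-5 input sequence
    h = rnn_step(x_t, h, W_xh, W_hh, b_h)
print(h)                              # final hidden state summarizes the sequence
```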
Types of Sequence Processing Tasks
Vec2Seq (sequence generation)
- Maps a fixed-length input vector to a variable-length output sequence
- Examples: Image captioning, generating text conditioned on a topic vector
- Autoregressive generation: Each output token depends on previously generated outputs
Seq2Vec (sequence classification)
- Maps a variable-length input sequence to a fixed-length output vector
- Examples: Sentiment analysis, document classification
- Often uses the final hidden state or an aggregation of all hidden states
Seq2Seq (sequence-to-sequence)
- Maps variable-length input to variable-length output
- Examples: Machine translation, summarization
- Typically employs encoder-decoder architecture
Bidirectional RNNs
- Process sequence in both forward and backward directions
- Captures both past and future context for each position
- Forward hidden states: $\overrightarrow{h}_t = \phi(W_{xh}^{\rightarrow}x_t + W_{hh}^{\rightarrow}\overrightarrow{h}_{t-1})$
- Backward hidden states: $\overleftarrow{h}_t = \phi(W_{xh}^{\leftarrow}x_t + W_{hh}^{\leftarrow}\overleftarrow{h}_{t+1})$
- Final representation combines both directions: $h_t = [\overrightarrow{h}_t; \overleftarrow{h}_t]$
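A sketch of the bidirectional pass under assumed toy dimensions: one left-to-right and one right-to-left recurrence with separate (hypothetical) weight matrices, concatenated per position.

```python
import numpy as np

def rnn_step(x_t, h_prev, W_xh, W_hh):
    return np.tanh(W_xh @ x_t + W_hh @ h_prev)

rng = np.random.default_rng(0)
d_in, d_h, T = 4, 3, 6
X = rng.normal(size=(T, d_in))                     # input sequence

# Separate parameters for the forward and backward directions (assumed shapes)
Wf_xh, Wf_hh = rng.normal(size=(d_h, d_in)), rng.normal(size=(d_h, d_h))
Wb_xh, Wb_hh = rng.normal(size=(d_h, d_in)), rng.normal(size=(d_h, d_h))

h_fwd, h_bwd = np.zeros((T, d_h)), np.zeros((T, d_h))
h = np.zeros(d_h)
for t in range(T):                                 # left-to-right pass
    h = rnn_step(X[t], h, Wf_xh, Wf_hh)
    h_fwd[t] = h
h = np.zeros(d_h)
for t in reversed(range(T)):                       # right-to-left pass
    h = rnn_step(X[t], h, Wb_xh, Wb_hh)
    h_bwd[t] = h

H = np.concatenate([h_fwd, h_bwd], axis=1)         # [forward_t ; backward_t] per position
print(H.shape)                                     # (6, 6)
```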
Challenges with Basic RNNs
- Vanishing Gradients: Signal from distant time steps diminishes exponentially
- Exploding Gradients: Gradients grow uncontrollably (solved with gradient clipping)
- Limited context window: Difficulty capturing long-range dependencies
Advanced RNN Architectures
LSTM (Long Short-Term Memory)
- Explicitly designed to capture long-term dependencies
- Cell state ($C_t$) acts as conveyor belt of information through time
- Three gates control information flow:
- Input gate ($I_t$): Controls what new information enters the cell
- Forget gate ($F_t$): Controls what information is discarded
- Output gate ($O_t$): Controls what information is exposed as output
- LSTM equations:
- $I_t = \sigma(W_{ix}X_t + W_{ih}H_{t-1})$
- $F_t = \sigma(W_{fx}X_t + W_{fh}H_{t-1})$
- $O_t = \sigma(W_{ox}X_t + W_{oh}H_{t-1})$
- $\tilde{C}_t = \tanh(W_{cx}X_t + W_{ch}H_{t-1})$ (candidate cell state)
- $C_t = F_t \odot C_{t-1} + I_t \odot \tilde{C}_t$ (cell state update)
- $H_t = O_t \odot \tanh(C_t)$ (hidden state)
- Solves vanishing gradient through additive updates and gating
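A minimal NumPy sketch of one LSTM step following the gate equations above; the weight dictionary and toy sizes are assumptions, and biases are omitted as in the notes.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x_t, h_prev, c_prev, W):
    """One LSTM step: gates, candidate cell state, additive cell update, hidden state."""
    i = sigmoid(W["ix"] @ x_t + W["ih"] @ h_prev)        # input gate
    f = sigmoid(W["fx"] @ x_t + W["fh"] @ h_prev)        # forget gate
    o = sigmoid(W["ox"] @ x_t + W["oh"] @ h_prev)        # output gate
    c_tilde = np.tanh(W["cx"] @ x_t + W["ch"] @ h_prev)  # candidate cell state
    c = f * c_prev + i * c_tilde                         # additive cell-state update
    h = o * np.tanh(c)
    return h, c

# Assumed toy sizes: input dim 4, hidden dim 3
rng = np.random.default_rng(0)
d_in, d_h = 4, 3
W = {k: rng.normal(size=(d_h, d_in if k.endswith("x") else d_h))
     for k in ["ix", "ih", "fx", "fh", "ox", "oh", "cx", "ch"]}

h, c = np.zeros(d_h), np.zeros(d_h)
for x_t in rng.normal(size=(5, d_in)):
    h, c = lstm_step(x_t, h, c, W)
print(h, c)
```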
GRU (Gated Recurrent Unit)
- Simplified version of LSTM with fewer parameters
- Has two gates: update gate and reset gate
- Update gate controls how much previous state is retained
- Reset gate controls how much previous state influences candidate state
- Competitive performance with LSTM but more efficient
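A matching sketch of one GRU step (standard Cho et al. 2014 formulation, biases omitted); the weight names and toy sizes are assumptions.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def gru_step(x_t, h_prev, W):
    """One GRU step with update gate z and reset gate r."""
    z = sigmoid(W["zx"] @ x_t + W["zh"] @ h_prev)          # how much of the old state is replaced
    r = sigmoid(W["rx"] @ x_t + W["rh"] @ h_prev)          # how much the old state shapes the candidate
    h_tilde = np.tanh(W["hx"] @ x_t + W["hh"] @ (r * h_prev))
    return (1 - z) * h_prev + z * h_tilde                  # interpolate old state and candidate

rng = np.random.default_rng(0)
d_in, d_h = 4, 3
W = {k: rng.normal(size=(d_h, d_in if k.endswith("x") else d_h))
     for k in ["zx", "zh", "rx", "rh", "hx", "hh"]}

h = np.zeros(d_h)
for x_t in rng.normal(size=(5, d_in)):
    h = gru_step(x_t, h, W)
print(h)
```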
Backpropagation through Time (BPTT)
- Unrolling the computation graph along time axis
- Simplified (linear) recurrence: $h_t = W_{hx}x_t + W_{hh}h_{t-1} = f(x_t, h_{t-1}, w_h)$
- Output: $o_t = W_{ho}h_t = g(h_t, w_{oh})$
- Loss: $L = {1 \over T}\sum_{t=1}^T l(y_t, o_t)$
- ${\partial L \over \partial w_h} = {1 \over T} \sum_{t=1}^T {\partial l \over \partial w_h}$
- ${\partial L \over \partial w_h} = {1 \over T} \sum_{t=1}^T {\partial l \over \partial o_t} {\partial o_t \over \partial h_t} {\partial h_t \over \partial w_h}$
- ${\partial h_t \over \partial w_h} = {\partial f \over \partial w_h} + {\partial f \over \partial h_{t-1}} {\partial h_{t-1} \over \partial w_h}$ (expanding this recursion chains products of ${\partial h_t \over \partial h_{t-1}}$ back through time)
- Common to truncate the update to length of the longest subsequence in the batch
- As the sequence unrolls, the hidden state is repeatedly multiplied by $W_{hh}$
- Gradients can decay or explode as we go backwards in time
- Solution is to use additive rather than multiplicative updates
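A small numeric illustration of the decay/explosion effect under the simplified linear recurrence above: backpropagating k steps multiplies the gradient by roughly $W_{hh}^T$ k times, so its norm shrinks or grows geometrically. The matrix scales here are arbitrary choices for illustration, and a norm-based clipping helper is included as the standard fix for explosion.

```python
import numpy as np

rng = np.random.default_rng(0)
d_h, T = 3, 50

for scale, label in [(0.3, "small W_hh -> vanishing"), (1.5, "large W_hh -> exploding")]:
    W_hh = scale * rng.normal(size=(d_h, d_h)) / np.sqrt(d_h)
    grad = np.eye(d_h)                 # stands in for d h_T / d h_T
    norms = []
    for _ in range(T):                 # one factor of W_hh^T per step back in time
        grad = grad @ W_hh.T
        norms.append(np.linalg.norm(grad))
    print(label, "| norm after 1 / 25 / 50 steps:",
          norms[0], norms[24], norms[49])

# Gradient clipping (the usual fix for explosion): rescale if the norm exceeds a threshold
def clip_by_norm(g, max_norm=5.0):
    n = np.linalg.norm(g)
    return g * (max_norm / n) if n > max_norm else g
```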
Decoding
- Output is generated one token at a time
- Simple Solution: Greedy Decoding
- Argmax over vocab at each step
- Keep generating tokens until the end-of-sequence (EOS) token is output
- May not be globally optimal path
- Alternative: Beam Search
- Compute top-K candidate outputs at each step
- Expand each one in V possible ways
- A total of $VK$ candidates is generated; keep only the top K and repeat
- GPT-style models use top-K and top-P sampling
- Top-K sampling: Redistribute the probability mass over the K most likely tokens, then sample
- Top-P (nucleus) sampling: Sample from the smallest set of tokens whose cumulative probability exceeds p
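A sketch of greedy decoding and the two sampling schemes applied to a single next-token distribution; the vocabulary and probabilities below are made up for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
vocab = ["the", "cat", "sat", "on", "mat", "<eos>"]
probs = np.array([0.40, 0.25, 0.15, 0.10, 0.07, 0.03])   # assumed next-token distribution

# Greedy decoding: argmax over the vocabulary at each step
print("greedy:", vocab[int(np.argmax(probs))])

# Top-K sampling: keep the K most likely tokens, renormalise, then sample
def top_k_sample(probs, k):
    idx = np.argsort(probs)[-k:]
    p = probs[idx] / probs[idx].sum()
    return int(rng.choice(idx, p=p))

# Top-P (nucleus) sampling: keep the smallest set whose cumulative probability reaches p
def top_p_sample(probs, p):
    order = np.argsort(probs)[::-1]
    cum = np.cumsum(probs[order])
    keep = order[: int(np.searchsorted(cum, p)) + 1]
    q = probs[keep] / probs[keep].sum()
    return int(rng.choice(keep, p=q))

print("top-k:", vocab[top_k_sample(probs, k=3)])
print("top-p:", vocab[top_p_sample(probs, p=0.8)])
```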
Attention
- In RNNs, the hidden state linearly combines the inputs and passes the result through an activation function
- The attention mechanism allows for more flexibility
- Suppose there are m feature vectors or values
- Model decides which to use based on the input query vector q and its similarity to a set of m keys
- If query is most similar to key i, then we use value i.
- Attention acts as a soft dictionary lookup
- Compare query q to each key $k_i$
- Retrieve the corresponding value $v_i$
- To make the operation differentiable:
- Compute a convex combination
- $\text{Attn}(q, (k_1,v_1), (k_2,v_2), \ldots, (k_m,v_m)) = \sum_{i=1}^m \alpha_i(q, k_i)\, v_i$
- $\alpha_i(q, k_i)$ are the attention weights
- Attention weights are computed from an attention score function $a(q,k_i)$
- Computes the similarity between query and key
- Once the scores are computed, use softmax to turn them into a probability distribution
- Masking ignores invalid indices (e.g. padding positions) when computing the softmax
- For computational efficiency, set the dimensions of the query and key to be the same (say d)
- The similarity is then given by the dot product $q^Tk_i$
- At initialization the weights are random, so if the query and key entries are independent with zero mean and unit variance, the dot product has variance d
- Divide the dot product by $\sqrt d$ to keep the scores at unit scale
- Scaled Dot-Product Attention
- Attention score: $a(q,k) = {q^Tk \over \sqrt d}$
- Scaled dot-product attention: $\text{Attn}(Q,K,V) = \text{softmax}\left({QK^T \over \sqrt d}\right)V$
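A NumPy sketch of scaled dot-product attention with an optional mask, matching the formula above; the softmax helper and toy shapes are assumptions.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)      # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def scaled_dot_product_attention(Q, K, V, mask=None):
    """softmax(Q K^T / sqrt(d)) V, with invalid positions masked out before the softmax."""
    d = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d)                # (n_queries, n_keys) similarity scores
    if mask is not None:
        scores = np.where(mask, scores, -1e9)    # masked positions get ~zero weight
    weights = softmax(scores, axis=-1)           # attention weights: a convex combination
    return weights @ V, weights

# Assumed toy sizes: 2 queries, 4 key/value pairs, dimension 8
rng = np.random.default_rng(0)
Q, K, V = rng.normal(size=(2, 8)), rng.normal(size=(4, 8)), rng.normal(size=(4, 8))
mask = np.array([[1, 1, 1, 0],                   # e.g. the last key is padding for query 0
                 [1, 1, 1, 1]], dtype=bool)
out, w = scaled_dot_product_attention(Q, K, V, mask)
print(out.shape, w.sum(axis=-1))                 # (2, 8), weight rows sum to 1
```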
- Example: Seq2Seq with Attention
- Consider encoder-decoder architecture
- In the decoder:
- $h_t = f(h_{t-1}, c)$
- c is the context vector from encoder
- Usually the last hidden state of the encoder
- Attention allows the decoder to look at all the input words
- Better alignment between source and target
- Makes the context vector dynamic, recomputed at every decoder step
- Query: previous hidden state of the decoder
- Key: all the hidden states from the encoder
- Value: all the hidden states from the encoder
- $c_t = \sum_{i=1}^T \alpha_i(h_{t-1}^d, h_i^e)\, h_i^e$
- If the RNN has multiple hidden layers, usually take the topmost layer
- Can be extended to Seq2Vec models
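A tiny sketch of this dynamic context vector: the previous decoder hidden state acts as the query and the encoder hidden states serve as both keys and values; all shapes here are assumed.

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

rng = np.random.default_rng(0)
T, d = 6, 8
enc_states = rng.normal(size=(T, d))          # encoder hidden states h_1^e ... h_T^e (keys and values)
dec_prev = rng.normal(size=d)                 # previous decoder hidden state h_{t-1}^d (the query)

scores = enc_states @ dec_prev / np.sqrt(d)   # similarity of the query to each encoder state
alpha = softmax(scores)                       # attention weights over source positions
c_t = alpha @ enc_states                      # context vector c_t, recomputed at every decoder step
print(alpha.round(2), c_t.shape)
```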
Transformers
- Transformers are seq2seq models using attention in both encoder and decoder steps
- Eliminate the need for RNNs
- Self Attention:
- Modify the encoder such that it attends to itself
- Given a sequence of input tokens $[x_1, x_2, \ldots, x_n]$
- Sequence of output tokens: $y_i = \text{Attn}(x_i, (x_1,x_1), (x_2,x_2), \ldots, (x_n,x_n))$
- Query is $x_i$
- Keys and values are $x_1, x_2, \ldots, x_n$ (all valid inputs)
- In the decoder step:
- $y_i = \text{Attn}(y_{i-1}, (y_1,y_1), (y_2,y_2), \ldots, (y_{i-1},y_{i-1}))$
- Each new token generated has access to all the previously generated outputs
- Multi-Head Attention
- Use multiple attention matrices to capture different nuances and similarities
- $h_i = \text{Attn}(W_i^q q, \{(W_i^k k_j, W_i^v v_j)\}_{j=1}^m)$ for head i
- Stack all the heads together and use a projection matrix $W_o$ to get the output
- Setting $p_q h = p_k h = p_v h = p_o$ (h heads, projection dimensions $p_q, p_k, p_v$, output dimension $p_o$) lets all heads be computed in parallel with a single matrix multiply
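A sketch of multi-head self-attention along these lines; the head count, projection sizes, and the single output projection Wo are assumptions chosen so the concatenated heads match the model dimension.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def attention(Q, K, V):
    return softmax(Q @ K.T / np.sqrt(Q.shape[-1])) @ V

def multi_head_self_attention(X, Wq, Wk, Wv, Wo):
    """Each head projects X with its own W^q, W^k, W^v; heads are concatenated and projected by Wo."""
    heads = [attention(X @ Wq[i], X @ Wk[i], X @ Wv[i]) for i in range(len(Wq))]
    return np.concatenate(heads, axis=-1) @ Wo

# Assumed sizes: sequence length 5, model dim 8, 2 heads of dim 4 (2 * 4 = 8 = model dim)
rng = np.random.default_rng(0)
n, d_model, n_heads, d_head = 5, 8, 2, 4
X = rng.normal(size=(n, d_model))
Wq = rng.normal(size=(n_heads, d_model, d_head))
Wk = rng.normal(size=(n_heads, d_model, d_head))
Wv = rng.normal(size=(n_heads, d_model, d_head))
Wo = rng.normal(size=(n_heads * d_head, d_model))
print(multi_head_self_attention(X, Wq, Wk, Wv, Wo).shape)   # (5, 8)
```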
- Positional Encoding
- Attention is permutation invariant
- Positional encodings help overcome this
- Sinusoidal Basis
- Positional Embeddings are combined with original input X → X + P
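A sketch of the standard sinusoidal encoding and the X + P combination; the sequence length and embedding size are assumed toy values.

```python
import numpy as np

def sinusoidal_positional_encoding(n_positions, d_model):
    """P[t, 2i] = sin(t / 10000^(2i/d)), P[t, 2i+1] = cos(t / 10000^(2i/d))."""
    pos = np.arange(n_positions)[:, None]
    i = np.arange(0, d_model, 2)[None, :]
    angles = pos / np.power(10000.0, i / d_model)
    P = np.zeros((n_positions, d_model))
    P[:, 0::2] = np.sin(angles)
    P[:, 1::2] = np.cos(angles)
    return P

# Combine with the (assumed) token embeddings: X <- X + P
rng = np.random.default_rng(0)
X = rng.normal(size=(10, 16))                      # 10 tokens, embedding dim 16
X = X + sinusoidal_positional_encoding(10, 16)
print(X.shape)
```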
- Combining All the Blocks
- Encoder
- Input: $Z = LN(MHA(X,X,X) + X)$
- Encoder output: $E = LN(FF(Z) + Z)$
- For the first layer, the input is the embedded sequence plus positional encodings:
- $Z = \text{POS}(\text{Embed}(X))$
- In general, the model stacks N copies of the encoder block
- Decoder
- Has access to both the encoder output and the previously generated tokens
- Masked self-attention over previous outputs: $Z = LN(MHA(X,X,X) + X)$
- Cross-attention over the encoder output: $Z = LN(MHA(Z,E,E) + Z)$
- Followed by the same feed-forward + LN block as the encoder; the model stacks N copies of the decoder block
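A sketch of one encoder block wired from the equations above (residual + LayerNorm around self-attention, then around a ReLU feed-forward layer); single-head attention and all weight shapes are simplifying assumptions.

```python
import numpy as np

def layer_norm(Z, eps=1e-5):
    mu = Z.mean(axis=-1, keepdims=True)
    var = Z.var(axis=-1, keepdims=True)
    return (Z - mu) / np.sqrt(var + eps)

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def self_attention(X, Wq, Wk, Wv):
    Q, K, V = X @ Wq, X @ Wk, X @ Wv
    return softmax(Q @ K.T / np.sqrt(Q.shape[-1])) @ V

def encoder_block(X, Wq, Wk, Wv, W1, W2):
    Z = layer_norm(self_attention(X, Wq, Wk, Wv) + X)    # Z = LN(MHA(X, X, X) + X)
    E = layer_norm(np.maximum(0, Z @ W1) @ W2 + Z)       # E = LN(FF(Z) + Z), ReLU feed-forward
    return E

# Assumed sizes: 5 tokens, model dim 8, feed-forward dim 16; a full model stacks N such blocks
rng = np.random.default_rng(0)
n, d, d_ff = 5, 8, 16
X = rng.normal(size=(n, d))
params = [rng.normal(size=s) for s in [(d, d), (d, d), (d, d), (d, d_ff), (d_ff, d)]]
print(encoder_block(X, *params).shape)                   # (5, 8)
```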
Representation Learning
- Contextual Word Embeddings
- Hidden state depends on all previous tokens
- Use the latent representation for classification / other downstream tasks
- Pre-train on a large corpus
- Fine-tune on small task specific dataset
- Transfer Learning
- ELMo
- Embeddings from Language Model
- Fit two RNN models
- Left to Right
- Right to Left
- Combine the hidden state representations to obtain an embedding for each word
- BERT
- Bi-Directional Encoder Representations from Transformers
- Pre-trained using Cloze task (MLM i.e. Masked Language Modeling)
- Additional Objective: Next sentence Prediction
- GPT
- Generative Pre-trained Transformer
- Causal model using Masked Decoder
- Train it as a language model on web text
- T5
- Text-to-Text Transfer Transformer
- Single model to perform multiple tasks
- Tell the task to perform as part of input sequence