Sequence Architectures
Sequence Modeling
- FFNNs are a poor fit because they condition on a fixed-size context window
- Language exhibits long-range dependencies that can span arbitrary context lengths
- Language models assign a conditional probability to the next word given the preceding words
- $P(W_{1:n}) = \prod_{i=1}^{n} P(W_i | W_{1:i-1})$
- Quality of a language model is assessed by perplexity
- $PP = P(W_{1:n})^{-1/n}$
- Inverse probability that the model assigns to the test sequence normalized by the length
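As a quick illustration of the perplexity formula, here is a minimal Python sketch; the per-word probabilities are made-up values, not from any real model.

```python
import numpy as np

# Made-up conditional probabilities P(w_i | w_1:i-1) a model assigns to a 5-word test sequence.
cond_probs = np.array([0.2, 0.1, 0.05, 0.3, 0.15])

# Sequence probability from the chain rule: product of the conditional probabilities.
seq_prob = np.prod(cond_probs)

# Perplexity: inverse sequence probability normalized by length, PP = P(W_1:n)^(-1/n).
n = len(cond_probs)
pp = seq_prob ** (-1.0 / n)

# Equivalent, numerically safer form: exp of the average negative log-probability.
pp_log = np.exp(-np.mean(np.log(cond_probs)))

print(pp, pp_log)  # both ≈ 7.4
```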
Recurrent Neural Networks
- NN architecture that contains a cycle in its network connections
- The hidden state from the previous step is fed into the computation of the current hidden state
- Predict using current input and previous hidden state
- Removes the fixed-size context limitation of FFNNs
- In principle, the hidden state can carry information forward for arbitrarily many steps
- Inference
- $h_t = g(Uh_{t-1} + Wx_t)$
- $y_t = \text{softmax}(Vh_t)$ (a NumPy sketch of these two equations follows the Training bullets below)
- Training
- Chain rule for backpropagation
- Output depends on hidden state and hidden state depends on previous time step
- BPTT: backpropagation through time
- In terms of computational graph, the network is "unrolled" for the entire sequence
- For very long sequences, use truncated BPTT
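A minimal NumPy sketch of the inference equations above, unrolled over a toy sequence; the dimensions and random weights are illustrative assumptions, not values from these notes.

```python
import numpy as np

rng = np.random.default_rng(0)

# Illustrative sizes (assumptions): input dim 4, hidden dim 3, output dim 5.
d_in, d_h, d_out = 4, 3, 5
W = rng.normal(scale=0.1, size=(d_h, d_in))   # input -> hidden
U = rng.normal(scale=0.1, size=(d_h, d_h))    # hidden -> hidden (the recurrent "cycle")
V = rng.normal(scale=0.1, size=(d_out, d_h))  # hidden -> output

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

# Unrolled forward pass over a toy sequence of T random input vectors.
T = 6
xs = rng.normal(size=(T, d_in))
h = np.zeros(d_h)                 # h_0
ys = []
for x in xs:
    h = np.tanh(U @ h + W @ x)    # h_t = g(U h_{t-1} + W x_t)
    y = softmax(V @ h)            # y_t = softmax(V h_t)
    ys.append(y)

print(np.stack(ys).shape)         # (6, 5): one output distribution per time step
```

Training would backpropagate through this same unrolled loop (BPTT), which is why very long sequences are usually truncated.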
- RNNs and Language Models
- Predict next word using current word and previous hidden state
- Removes the limited context problem
- Use word embeddings to enhance the model's generalization ability
- $e_t = Ex_t$
- $h_t = g(Uh_{t-1} + We_t)$
- $y_t = \text{softmax}(Vh_t)$
- Output is a probability distribution over the entire vocabulary
- Loss function: cross-entropy between the predicted distribution and the true next-word distribution
- Minimize the error in predicting the next word
- Teacher forcing for training
- During training, ignore the model's own prediction when choosing the next input
- Feed the actual (gold) next word instead
- Weight tying
- The input embedding matrix and the output projection matrix both have a vocabulary dimension |V| (their shapes are transposes of each other)
- Instead of learning two separate matrices, share a single one (use E for the lookup and its transpose for the output logits)
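The sketch below pulls the RNN-LM pieces together: embedding lookup, recurrence, a tied output projection, teacher forcing, and a cross-entropy loss. The vocabulary size, dimensions, random weights, and word ids are all illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)

# Illustrative sizes (assumptions): vocabulary 10, embedding/hidden dim 8.
V_size, d = 10, 8
E = rng.normal(scale=0.1, size=(V_size, d))  # embedding matrix, also reused as the tied output matrix
W = rng.normal(scale=0.1, size=(d, d))       # embedding -> hidden
U = rng.normal(scale=0.1, size=(d, d))       # hidden -> hidden

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

# Toy training sequence of word ids; teacher forcing means the *gold* word at
# position t is fed as input when predicting the word at position t+1.
words = [3, 7, 1, 4, 2]
h = np.zeros(d)
loss = 0.0
for t in range(len(words) - 1):
    e_t = E[words[t]]                 # e_t = E x_t (embedding lookup of the gold word)
    h = np.tanh(U @ h + W @ e_t)      # h_t = g(U h_{t-1} + W e_t)
    y = softmax(E @ h)                # weight tying: output logits reuse E instead of a separate matrix
    loss += -np.log(y[words[t + 1]])  # cross-entropy against the true next word

print(loss / (len(words) - 1))        # average per-word loss (≈ log |V| for random weights)
```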
- RNN Tasks
- Sequence Labeling
- NER tasks, POS tagging
- At each step predict the current tag rather than the next word
- Use softmax over tagset with CE loss function
- Sequence Classification
- Classify entire sequences rather than the tokens
- Use the hidden state from the last step and pass it to an FFNN classifier
- Backprop through this final hidden state also updates the recurrent (cycle) weights
- Use pooling to enhance performance
- Element-wise mean or max over all intermediate hidden states
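A sketch of sequence classification under the same toy-RNN assumptions as above, comparing the last-state summary with element-wise mean and max pooling; all weights are random and illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, d_h, n_classes = 4, 3, 2

# Illustrative RNN and classifier weights (assumptions, randomly initialized).
W = rng.normal(scale=0.1, size=(d_h, d_in))
U = rng.normal(scale=0.1, size=(d_h, d_h))
W_cls = rng.normal(scale=0.1, size=(n_classes, d_h))

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

xs = rng.normal(size=(7, d_in))        # a toy 7-step input sequence
h = np.zeros(d_h)
hs = []
for x in xs:
    h = np.tanh(U @ h + W @ x)
    hs.append(h)
hs = np.stack(hs)                      # (T, d_h): all intermediate hidden states

# Three ways to summarize the sequence before the feed-forward classifier:
last = hs[-1]                          # final hidden state only
mean_pool = hs.mean(axis=0)            # element-wise mean over time
max_pool = hs.max(axis=0)              # element-wise max over time

for name, rep in [("last", last), ("mean", mean_pool), ("max", max_pool)]:
    print(name, softmax(W_cls @ rep))  # class distribution from each summary
```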
- Sequence Generation
- Encoder-decoder architecture
- Autoregressive generation
- Use <s> (BOS) as the first token and the encoder's final hidden state as the initial decoder state
- Sample from the RNN's output softmax
- Use the embedding from the generated token as next input
- Keep sampling till </s> (EOS) token is sampled
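A sketch of the autoregressive decoding loop; the vocabulary, the reserved BOS/EOS ids, and the random weights standing in for a trained encoder-decoder are assumptions for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy vocabulary: ids 0 and 1 reserved for <s> (BOS) and </s> (EOS) (assumption).
V_size, d = 12, 8
BOS, EOS = 0, 1
E = rng.normal(scale=0.1, size=(V_size, d))   # decoder embedding matrix
W = rng.normal(scale=0.1, size=(d, d))
U = rng.normal(scale=0.1, size=(d, d))
V_out = rng.normal(scale=0.5, size=(V_size, d))

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

h = rng.normal(size=d)        # stands in for the encoder's final hidden state
token = BOS                   # generation starts from the <s> token
generated = []
for _ in range(20):           # cap the length in case </s> is never sampled
    h = np.tanh(U @ h + W @ E[token])         # advance the decoder state
    probs = softmax(V_out @ h)                # distribution over the vocabulary
    token = rng.choice(V_size, p=probs)       # sample the next token from the softmax
    if token == EOS:
        break
    generated.append(int(token))              # its embedding becomes the next input

print(generated)
```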
- RNN Architectures
- Stacked RNNs
- Multiple RNNs "stacked together"
- Output from one layer serves as input to another layer
- Differing levels of abstraction across layers
- Bidirectional RNNs
- Many applications have full access to input sequence
- Process the sequence from left-to-right and right-to-left
- Concatenate the outputs from the forward and backward passes
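A small sketch of a bidirectional RNN: run the same kind of forward pass left-to-right and right-to-left with separate (here random, illustrative) weights, then concatenate the per-step hidden states.

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, d_h = 4, 3

def run_rnn(xs, W, U):
    """Plain forward pass: returns the hidden state at every time step."""
    h = np.zeros(d_h)
    hs = []
    for x in xs:
        h = np.tanh(U @ h + W @ x)
        hs.append(h)
    return np.stack(hs)

# Separate (illustrative, randomly initialized) weights for each direction.
W_f, U_f = rng.normal(scale=0.1, size=(d_h, d_in)), rng.normal(scale=0.1, size=(d_h, d_h))
W_b, U_b = rng.normal(scale=0.1, size=(d_h, d_in)), rng.normal(scale=0.1, size=(d_h, d_h))

xs = rng.normal(size=(5, d_in))
h_fwd = run_rnn(xs, W_f, U_f)              # left-to-right pass
h_bwd = run_rnn(xs[::-1], W_b, U_b)[::-1]  # right-to-left pass, realigned to the original order

h_bi = np.concatenate([h_fwd, h_bwd], axis=-1)  # (T, 2*d_h): concatenated representation per step
print(h_bi.shape)
```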
LSTM
- RNNs are hard to train due to vanishing/exploding gradients
- Hidden state tends to be fairly local in practice, limiting long-term dependencies
- Vanishing gradients: Signal from far-away timesteps gets lost
- Repeated multiplications in backpropagation step
- Sigmoid derivatives lie in (0, 0.25] and tanh derivatives in (0, 1]
- Gradients diminish exponentially over long sequence lengths
- LSTMs introduce explicit memory management
- Enable network to learn to forget information no longer needed
- Persist information likely needed for decisions yet to come
- Use gating mechanism (through additional parameters) to control the flow of information
- Architecture
- Memory cell (long-term memory) + hidden state (working memory)
- Three gates control information flow:
- Forget gate: What to remove from cell state
- Input gate: What new information to store
- Output gate: What to output based on cell state
- Input Gate Logic
- Candidate values: $g_t = \tanh(W_g x_t + U_g h_{t-1} + b_g)$
- Input gate: $i_t = \sigma(W_i x_t + U_i h_{t-1} + b_i)$
- New information: $j_t = i_t \odot g_t$
- Forget Gate Logic
- Forget gate: $f_t = \sigma(W_f x_t + U_f h_{t-1} + b_f)$
- Retained memory: $k_t = f_t \odot c_{t-1}$
- Cell State Update
- $c_t = j_t + k_t$ (add new information to retained memory)
- Output Gate Logic
- Output gate: $o_t = \sigma(W_o x_t + U_o h_{t-1} + b_o)$
- Hidden state: $h_t = o_t \odot \tanh(c_t)$
- LSTMs maintain two states: cell state (c) for long-term memory and hidden state (h) for output
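A minimal NumPy sketch of one LSTM step implementing the gate equations above; dimensions and randomly initialized weights are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, d_h = 4, 3

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

# One (W, U, b) triple per gate plus one for the candidate; sizes are illustrative.
def init():
    return (rng.normal(scale=0.1, size=(d_h, d_in)),
            rng.normal(scale=0.1, size=(d_h, d_h)),
            np.zeros(d_h))

(W_g, U_g, b_g), (W_i, U_i, b_i) = init(), init()
(W_f, U_f, b_f), (W_o, U_o, b_o) = init(), init()

def lstm_step(x_t, h_prev, c_prev):
    g_t = np.tanh(W_g @ x_t + U_g @ h_prev + b_g)     # candidate values
    i_t = sigmoid(W_i @ x_t + U_i @ h_prev + b_i)     # input gate
    f_t = sigmoid(W_f @ x_t + U_f @ h_prev + b_f)     # forget gate
    o_t = sigmoid(W_o @ x_t + U_o @ h_prev + b_o)     # output gate
    c_t = i_t * g_t + f_t * c_prev                    # cell update: new info + retained memory
    h_t = o_t * np.tanh(c_t)                          # working memory exposed as output
    return h_t, c_t

h, c = np.zeros(d_h), np.zeros(d_h)
for x in rng.normal(size=(6, d_in)):                  # run the cell over a toy sequence
    h, c = lstm_step(x, h, c)
print(h, c)
```

Note that the cell update is additive ($c_t = i_t \odot g_t + f_t \odot c_{t-1}$), which gives gradients a path across many steps without the repeated squashing that drives vanishing gradients.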
Self Attention
- LSTMs still have limitations:
- Difficult to parallelize (sequential processing)
- Still not fully effective for very long dependencies
- Transformers - Replace recurrent layers with self-attention layers
- Self-Attention Mechanism
- Create three projections of each input vector:
- Query (Q): What the token is looking for
- Key (K): What the token offers for matching
- Value (V): The actual information to be aggregated
- Compute attention scores between each token and all other tokens
- Weight values according to attention scores
- Crucial innovation: allows direct connections between any tokens regardless of distance
- Computation Steps
- Project input sequence X into Q, K, V matrices using learned weight matrices
- $Q = XW^Q$, $K = XW^K$, $V = XW^V$
- Compute attention scores: $S = QK^T$
- Scale to stabilize gradients: $S' = S/\sqrt{d_k}$, where $d_k$ is the dimension of the keys
- Apply softmax to get attention weights: $A = \text{softmax}(S')$
- Compute weighted values: $Z = AV$
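A NumPy sketch of the computation steps above for a single attention head; the sequence length, dimensions, and random weights are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)

# Illustrative sizes (assumptions): 5 tokens, model dim 8, query/key/value dim 4.
T, d_model, d_k = 5, 8, 4
X = rng.normal(size=(T, d_model))          # input sequence of token vectors
W_Q = rng.normal(scale=0.1, size=(d_model, d_k))
W_K = rng.normal(scale=0.1, size=(d_model, d_k))
W_V = rng.normal(scale=0.1, size=(d_model, d_k))

def softmax_rows(S):
    e = np.exp(S - S.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

Q, K, V = X @ W_Q, X @ W_K, X @ W_V        # project inputs into queries, keys, values
S = Q @ K.T                                # raw attention scores between every token pair
S_scaled = S / np.sqrt(d_k)                # scale by sqrt(d_k) to stabilize gradients
A = softmax_rows(S_scaled)                 # each row: attention weights over all tokens
Z = A @ V                                  # weighted sum of values, one output per token

print(A.shape, Z.shape)                    # (5, 5) attention map, (5, 4) outputs
```

Every token attends to every other token in a single matrix multiplication, which is what makes this parallelizable, unlike the step-by-step RNN loop.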
- Multi-Head Attention
- Multiple parallel attention mechanisms
- Each head can capture different types of relationships
- Concatenate outputs and project back to original dimension
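Extending the previous sketch to multiple heads: each head attends with its own projections, the head outputs are concatenated, and a final matrix (here an assumed W_O) projects back to the model dimension.

```python
import numpy as np

rng = np.random.default_rng(0)
T, d_model, n_heads = 5, 8, 2
d_head = d_model // n_heads                        # 4 dims per head

X = rng.normal(size=(T, d_model))
# One set of projection matrices per head, plus an output projection W_O (all illustrative).
W_Q = rng.normal(scale=0.1, size=(n_heads, d_model, d_head))
W_K = rng.normal(scale=0.1, size=(n_heads, d_model, d_head))
W_V = rng.normal(scale=0.1, size=(n_heads, d_model, d_head))
W_O = rng.normal(scale=0.1, size=(d_model, d_model))

def attention(Q, K, V):
    S = Q @ K.T / np.sqrt(Q.shape[-1])
    A = np.exp(S - S.max(axis=-1, keepdims=True))
    A = A / A.sum(axis=-1, keepdims=True)
    return A @ V

# Each head attends independently, potentially capturing a different relationship.
heads = [attention(X @ W_Q[h], X @ W_K[h], X @ W_V[h]) for h in range(n_heads)]
Z = np.concatenate(heads, axis=-1) @ W_O           # concatenate heads, project back to d_model

print(Z.shape)                                     # (5, 8): same shape as the input
```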
- Positional Encodings
- Unlike RNNs, self-attention operations are order-invariant
- Add position information to input embeddings
- Using sinusoidal functions: $PE_{(pos,2i)} = \sin(pos/10000^{2i/d})$, $PE_{(pos,2i+1)} = \cos(pos/10000^{2i/d})$
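A small sketch of the sinusoidal encoding formula; the maximum length and dimension are arbitrary illustrative choices.

```python
import numpy as np

def sinusoidal_pe(max_len, d):
    """PE[pos, 2i] = sin(pos / 10000^(2i/d)), PE[pos, 2i+1] = cos(pos / 10000^(2i/d))."""
    pos = np.arange(max_len)[:, None]              # (max_len, 1) positions
    i = np.arange(d // 2)[None, :]                 # (1, d/2) dimension-pair index
    angle = pos / np.power(10000, 2 * i / d)       # one frequency per dimension pair
    pe = np.zeros((max_len, d))
    pe[:, 0::2] = np.sin(angle)                    # even dimensions
    pe[:, 1::2] = np.cos(angle)                    # odd dimensions
    return pe

pe = sinusoidal_pe(max_len=50, d=16)
print(pe.shape)        # (50, 16); these vectors are added to the input embeddings
```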
- BERT Architecture
- Base model - 12 layers, 12 heads, head dimension 64, hidden size 768 (12 * 64)
- Large model - 24 layers, 16 heads, head dimension 64, hidden size 1024 (16 * 64)