Transfer Learning and Pre-trained Models
Transfer learning has revolutionized NLP. Pre-train a large model on massive text data, then fine-tune for specific tasks. This chapter covers BERT and the pre-train/fine-tune paradigm.
The Big Picture
The Old Way: Train a model from scratch for each task.
- Requires lots of labeled data
- Each task starts from zero
The New Way: Pre-train → Fine-tune.
- Pre-train once on huge unlabeled corpus
- Fine-tune with small labeled dataset per task
- Transfer linguistic knowledge across tasks
Key Concepts
Contextual Embeddings
Static embeddings (Word2Vec): Same vector for "bank" regardless of context.
Contextual embeddings (BERT): Different vector for "river bank" vs. "bank account".
The same word gets different representations based on surrounding words.
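To make this concrete, here is a minimal sketch (assuming the Hugging Face `transformers` library and the public `bert-base-uncased` checkpoint) that compares the vector BERT assigns to "bank" in two different sentences:

```python
# Minimal sketch: extract BERT's contextual vector for "bank" in two contexts.
import torch
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModel.from_pretrained("bert-base-uncased")

def bank_vector(sentence):
    """Return the hidden state BERT produces for the 'bank' token."""
    inputs = tokenizer(sentence, return_tensors="pt")
    with torch.no_grad():
        hidden = model(**inputs).last_hidden_state[0]            # (seq_len, 768)
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0].tolist())
    return hidden[tokens.index("bank")]

v_river = bank_vector("He sat on the river bank.")
v_money = bank_vector("She opened a bank account.")
print(torch.cosine_similarity(v_river, v_money, dim=0).item())   # noticeably below 1.0
```

A static embedding would give both occurrences exactly the same vector (cosine similarity 1.0); the two contextual vectors differ.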
The Pre-train → Fine-tune Paradigm
[MASSIVE UNLABELED TEXT] → Pre-training → [GENERAL LANGUAGE MODEL]
↓
[SMALL LABELED DATA] → Fine-tuning → [TASK-SPECIFIC MODEL]
Pre-training: Self-supervised learning on vast text (books, Wikipedia, web).
Fine-tuning: Supervised learning on task-specific data.
Language Model Types
| Type | Direction | Example | Best For |
|---|---|---|---|
| Causal | Left-to-right | GPT | Generation |
| Bidirectional | Both directions | BERT | Understanding |
| Encoder-Decoder | Bidirectional encoder + causal decoder | T5 | Both |
Bidirectional Transformers (BERT)
Why Bidirectional?
For many tasks, we can see the entire input at once.
Causal models (GPT): Can only see left context.
"The cat sat on the [???]" - what comes next?
Bidirectional models (BERT): See full context.
"The cat sat on the [MASK]" - what's missing?
BERT Architecture
Self-attention across entire sequence: $$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V$$
BERT Base:
- Vocabulary: 30K subwords (WordPiece)
- Hidden size: 768 (12 heads × 64 dims)
- Layers: 12
- Parameters: 110M
- Max sequence: 512 tokens
Compute note: Self-attention is O(n²) in sequence length, which limits the maximum sequence length.
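A small sketch of the attention formula above in plain PyTorch (single head, dimensions chosen to match BERT Base) makes the quadratic cost visible: the score matrix has one entry per pair of tokens.

```python
import math
import torch

def scaled_dot_product_attention(Q, K, V):
    """Single-head attention: Q, K, V are (seq_len, d_k) tensors."""
    d_k = Q.size(-1)
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)  # (seq_len, seq_len): the O(n^2) part
    weights = torch.softmax(scores, dim=-1)            # each row sums to 1
    return weights @ V                                  # (seq_len, d_k)

n, d_k = 512, 64                      # BERT Base: max sequence length, per-head dimension
Q, K, V = (torch.randn(n, d_k) for _ in range(3))
print(scaled_dot_product_attention(Q, K, V).shape)      # torch.Size([512, 64])
```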
Pre-training Objectives
Masked Language Modeling (MLM)
The task: Predict randomly masked tokens.
Input: The cat [MASK] on the mat
Target: sat
Masking strategy (for 15% of tokens):
- 80%: Replace with [MASK]
- 10%: Replace with random word
- 10%: Keep original
Why this mix?
- [MASK] never appears during fine-tuning, so training on real words too reduces the pre-train/fine-tune mismatch
- Random replacement forces the model to rely on context rather than trusting the input token
- Keeping some tokens unchanged biases the representation toward the actual observed word
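A sketch of the 80/10/10 strategy above, modeled loosely on the standard Hugging Face data-collator logic; `tokenizer` is assumed to be a BERT WordPiece tokenizer and `input_ids` a 1-D tensor of token ids.

```python
import torch

def mask_tokens(input_ids, tokenizer, mlm_prob=0.15):
    """Apply BERT-style MLM masking to a 1-D tensor of token ids."""
    labels = input_ids.clone()

    # Select 15% of non-special positions as prediction targets.
    prob = torch.full(input_ids.shape, mlm_prob)
    special = torch.tensor(
        tokenizer.get_special_tokens_mask(input_ids.tolist(),
                                          already_has_special_tokens=True),
        dtype=torch.bool)
    prob.masked_fill_(special, 0.0)
    masked = torch.bernoulli(prob).bool()
    labels[~masked] = -100                       # loss is computed only on masked positions

    # 80% of targets: replace with [MASK].
    use_mask = torch.bernoulli(torch.full(input_ids.shape, 0.8)).bool() & masked
    input_ids[use_mask] = tokenizer.mask_token_id

    # 10% of targets: replace with a random vocabulary token.
    use_random = (torch.bernoulli(torch.full(input_ids.shape, 0.5)).bool()
                  & masked & ~use_mask)
    random_ids = torch.randint(len(tokenizer), input_ids.shape)
    input_ids[use_random] = random_ids[use_random]

    # Remaining 10% of targets: keep the original token unchanged.
    return input_ids, labels
```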
Span Masking (SpanBERT)
Mask contiguous spans instead of random tokens.
Input: The [MASK] [MASK] [MASK] the mat
Target: cat sat on
Benefits:
- Better for tasks requiring span understanding (QA, NER)
- Span Boundary Objective (SBO): predict each token in the span from the representations of the span's boundary tokens
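A simplified sketch of span masking follows (SpanBERT samples span lengths from a geometric distribution; a fixed-length span is used here to keep the idea visible):

```python
import random

def mask_span(tokens, span_len=3, mask_token="[MASK]"):
    """Replace one contiguous span of `span_len` tokens with [MASK] tokens."""
    start = random.randrange(len(tokens) - span_len + 1)
    target = tokens[start:start + span_len]                  # tokens to predict
    masked = tokens[:start] + [mask_token] * span_len + tokens[start + span_len:]
    return masked, target, (start, start + span_len - 1)     # span boundaries for the SBO

tokens = "The cat sat on the mat".split()
print(mask_span(tokens))
```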
Next Sentence Prediction (NSP)
The task: Are these sentences adjacent in the original text?
[CLS] The cat sat [SEP] It was happy [SEP] → IsNext
[CLS] The cat sat [SEP] I like pizza [SEP] → NotNext
Training data: 50% real pairs, 50% random pairs.
Note: Later models (RoBERTa, ALBERT) found NSP less helpful than expected.
Pre-training Data
BERT was trained on:
- BooksCorpus: 800M words
- English Wikipedia: 2.5B words
Later models use more:
- CommonCrawl (filtered web text)
- News articles
- Code repositories
Compute requirement: Days to weeks on TPU/GPU clusters.
Fine-tuning
The Basic Recipe
- Load pre-trained model
- Add task-specific classification head (usually 1-2 layers)
- Train on labeled data with small learning rate
Fine-tuning Strategies
| Strategy | What's Updated | Best For |
|---|---|---|
| Full fine-tuning | All parameters | Lots of data, maximum performance |
| Feature extraction | Only head | Very small data |
| Adapter tuning | Small inserted modules | Efficient multi-task |
| Prompt tuning | Soft prompts only | Very large models |
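As an example of the feature-extraction row above, a minimal sketch (assuming the Hugging Face `transformers` library) that freezes the pre-trained encoder so only the classification head is trained:

```python
from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained(
    "bert-base-uncased", num_labels=2)

# Freeze the pre-trained encoder; only the freshly added head stays trainable.
for param in model.bert.parameters():
    param.requires_grad = False

trainable = [name for name, p in model.named_parameters() if p.requires_grad]
print(trainable)    # only the classifier weights and bias remain
```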
Hyperparameters
Learning rate: 2e-5 to 5e-5 (much smaller than training from scratch!)
Epochs: 2-4 (often sufficient)
Batch size: 16-32
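Putting the recipe and these hyperparameters together, a minimal fine-tuning sketch with the Hugging Face `Trainer` (this assumes `transformers` is installed and that `train_dataset` is an already-tokenized dataset with `input_ids`, `attention_mask`, and `label` columns):

```python
from transformers import (AutoModelForSequenceClassification, Trainer,
                          TrainingArguments)

model = AutoModelForSequenceClassification.from_pretrained(
    "bert-base-uncased", num_labels=2)

args = TrainingArguments(
    output_dir="bert-finetuned",
    learning_rate=2e-5,                  # small: we only nudge the pre-trained weights
    num_train_epochs=3,                  # 2-4 epochs is usually enough
    per_device_train_batch_size=16,
)

trainer = Trainer(model=model, args=args, train_dataset=train_dataset)
trainer.train()
```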
Task-Specific Architectures
Sequence Classification
Task: Classify entire input (sentiment, topic).
Input: [CLS] I loved this movie [SEP]
Output: Use [CLS] representation → linear → softmax → class
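A sketch of this head in plain PyTorch; `encoder` stands in for a pre-trained BERT encoder (e.g. a `transformers` model) that returns hidden states for every token:

```python
import torch
import torch.nn as nn

class ClsClassifier(nn.Module):
    """[CLS] representation -> linear -> softmax -> class probabilities."""

    def __init__(self, encoder, hidden_size=768, num_classes=2):
        super().__init__()
        self.encoder = encoder                    # assumed pre-trained BERT encoder
        self.head = nn.Linear(hidden_size, num_classes)

    def forward(self, input_ids, attention_mask):
        hidden = self.encoder(input_ids=input_ids,
                              attention_mask=attention_mask).last_hidden_state
        cls_vec = hidden[:, 0]                    # position 0 holds the [CLS] token
        return torch.softmax(self.head(cls_vec), dim=-1)
```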
Sentence Pair Classification
Task: Classify relationship between two texts (NLI, similarity).
Input: [CLS] Premise text [SEP] Hypothesis text [SEP]
Output: [CLS] representation → linear → entailment/contradiction/neutral
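The main practical difference from single-sequence classification is how the input is packed: passing two texts to a BERT tokenizer produces the [CLS] ... [SEP] ... [SEP] layout with segment (token type) ids. A short sketch, again assuming `transformers`:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
enc = tokenizer("A man is eating food.",       # premise    (segment 0)
                "A man is eating.",            # hypothesis (segment 1)
                return_tensors="pt")

print(tokenizer.convert_ids_to_tokens(enc["input_ids"][0].tolist()))
print(enc["token_type_ids"][0])    # 0s for the premise, 1s for the hypothesis
```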
Token Classification (NER, POS)
Task: Label each token.
Input: [CLS] John lives in New York [SEP]
Labels: B-PER O O B-LOC I-LOC
Each token gets its own classification head output.
WordPiece handling:
- Training: Expand labels to all subword tokens
- Evaluation: Use label of first subword
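A sketch of the training-time label expansion (assuming a fast Hugging Face tokenizer, whose `word_ids()` maps each subword back to its source word):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
words = ["John", "lives", "in", "New", "York"]
word_labels = ["B-PER", "O", "O", "B-LOC", "I-LOC"]

enc = tokenizer(words, is_split_into_words=True)
aligned = []
for word_id in enc.word_ids():               # one entry per subword token
    # Special tokens ([CLS], [SEP]) have no source word; they get no label
    # and are ignored in the loss (typically via -100).
    aligned.append("IGNORE" if word_id is None else word_labels[word_id])

print(list(zip(tokenizer.convert_ids_to_tokens(enc["input_ids"]), aligned)))
```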
Span Prediction (QA)
Task: Find answer span in context.
Input: [CLS] Where is Paris? [SEP] Paris is in France [SEP]
Output: Predict the start position of the answer span (the token "France")
Predict the end position of the answer span (also "France"; the answer here is a single-token span)
Two classifiers: one for start, one for end position.
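A sketch of span prediction in code; this assumes an already fine-tuned public QA checkpoint (`bert-large-uncased-whole-word-masking-finetuned-squad`) so that the start and end heads produce meaningful logits:

```python
import torch
from transformers import AutoModelForQuestionAnswering, AutoTokenizer

name = "bert-large-uncased-whole-word-masking-finetuned-squad"  # assumed QA checkpoint
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForQuestionAnswering.from_pretrained(name)

inputs = tokenizer("Where is Paris?", "Paris is in France.", return_tensors="pt")
with torch.no_grad():
    out = model(**inputs)

start = out.start_logits.argmax()    # index of the most likely answer-start token
end = out.end_logits.argmax()        # index of the most likely answer-end token
print(tokenizer.decode(inputs["input_ids"][0][start:end + 1]))   # expected: "france"
```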
Modern Variants
RoBERTa (Robustly Optimized BERT)
Key changes:
- Remove NSP objective
- Larger batches, more data
- Dynamic masking (different masks each epoch)
ALBERT (A Lite BERT)
Parameter reduction:
- Factorized embedding parameters
- Cross-layer parameter sharing
- Sentence order prediction instead of NSP
DistilBERT
Knowledge distillation:
- 40% smaller, 60% faster
- 97% of BERT performance
Summary
| Concept | Key Point |
|---|---|
| Contextual embeddings | Same word → different vectors in different contexts |
| Pre-training | Self-supervised on massive text |
| Fine-tuning | Task-specific with small labeled data |
| MLM | Predict masked tokens (BERT's main objective) |
| [CLS] token | Aggregate representation for classification |
| [SEP] token | Separate segments in input |
The Revolution
Transfer learning changed NLP:
| Before | After |
|---|---|
| Train from scratch per task | Pre-train once, fine-tune many times |
| Need lots of labeled data | Works with small labeled data |
| Shallow features | Deep contextual understanding |
| Task-specific architectures | One architecture, many tasks |
Practical Tips
- Start with pre-trained models: it is rarely worth training from scratch
- Try multiple learning rates: this is the most important hyperparameter
- Don't over-fine-tune: 2-4 epochs is often enough
- Consider model size: DistilBERT for production, BERT-large for best accuracy
- Domain matters: SciBERT for science, BioBERT for biomedical, etc.