This repository contains code for Named-Entity Recognition (NER) using finetuned RoBERTa models. The dataset is sourced from Kaggle. The objective is to accurately predict entity tags for the words in each sentence, distinguishing between entity types such as people (per), organizations (org), geographical locations (geo), and more.
The input dataset has the following structure:
| Sentence #  | Word          | POS | Tag   |
|-------------|---------------|-----|-------|
| Sentence: 1 | Thousands     | NNS | O     |
|             | of            | IN  | O     |
|             | demonstrators | NNS | O     |
|             | have          | VBP | O     |
|             | marched       | VBN | O     |
|             | through       | IN  | O     |
|             | London        | NNP | B-geo |
|             | to            | TO  | O     |
Where:
- **Sentence #**: the sentence identifier, present only on the first word of each sentence
- **Word**: the token itself
- **POS**: the part-of-speech tag of the word
- **Tag**: the entity label in BIO format (`B-`/`I-` prefixes mark the beginning/inside of an entity span, `O` marks non-entity tokens)
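To make the format concrete, sentences can be reconstructed by forward-filling the `Sentence #` column and grouping rows. The snippet below is only a sketch of that idea (the actual loading code lives in `src/data.py`, and the `latin1` encoding is an assumption about the Kaggle file):

```python
import pandas as pd

# Rebuild sentences from the flat word-per-row CSV shown above.
df = pd.read_csv("data/data.csv", encoding="latin1")  # encoding is an assumption
df["Sentence #"] = df["Sentence #"].ffill()           # sentence id appears only on the first word

sentences = (
    df.groupby("Sentence #", sort=False)
      .apply(lambda g: (g["Word"].tolist(), g["Tag"].tolist()))
      .tolist()
)

words, tags = sentences[0]
# e.g. ['Thousands', 'of', 'demonstrators', ...] and ['O', 'O', 'O', ...]
```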
We use the RoBERTa base model, which is an optimized version of BERT with improved training methodology. RoBERTa features:
- Dynamic masking (the masking pattern changes across epochs instead of being fixed during preprocessing)
- Removal of BERT's Next Sentence Prediction (NSP) objective
- Pre-training with larger batches, more data, and longer training
- A byte-level BPE tokenizer
For NER, we add a token classification head on top of the RoBERTa encoder, which consists of a dropout layer followed by a linear layer mapping to the number of entity classes.
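For illustration, such a head can be instantiated with the HuggingFace `transformers` library as shown below. This is a sketch only; the repository's actual model definition is in `src/model.py`, and the label count here is an assumption (the real one comes from `data/mapping.json`):

```python
from transformers import AutoModelForTokenClassification, AutoTokenizer

# AutoModelForTokenClassification wraps the roberta-base encoder with a
# dropout layer and a linear layer that maps each token's hidden state to
# num_labels logits.
num_labels = 17  # assumption: 8 entity types x (B-/I-) + "O"
tokenizer = AutoTokenizer.from_pretrained("roberta-base", add_prefix_space=True)
model = AutoModelForTokenClassification.from_pretrained("roberta-base", num_labels=num_labels)
```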
The model is trained with the training script in `src/train.py`, using the settings defined in `conf.yaml`.
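The fine-tuning step itself follows the standard token-classification recipe. The sketch below shows a single training step on a toy batch; the hyperparameters, label ids, and label-alignment strategy are illustrative assumptions, not the repository's actual settings:

```python
import torch
from torch.optim import AdamW
from transformers import AutoModelForTokenClassification, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("roberta-base", add_prefix_space=True)
model = AutoModelForTokenClassification.from_pretrained("roberta-base", num_labels=17)
optimizer = AdamW(model.parameters(), lr=2e-5)  # assumption: a typical fine-tuning learning rate

words = ["Thousands", "of", "demonstrators", "have", "marched", "through", "London"]
word_labels = [0, 0, 0, 0, 0, 0, 1]  # illustrative ids, e.g. 0 = "O", 1 = "B-geo"
enc = tokenizer([words], is_split_into_words=True, return_tensors="pt")

# Align word-level labels to sub-word tokens; special tokens get -100 so the
# cross-entropy loss ignores them.
labels = torch.tensor([[word_labels[w] if w is not None else -100
                        for w in enc.word_ids(0)]])

model.train()
optimizer.zero_grad()
loss = model(**enc, labels=labels).loss
loss.backward()
optimizer.step()
```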
We evaluate the model on the validation split using precision, recall, and F1-score.
We report metrics at multiple levels: per entity label, and aggregated as micro, macro, and weighted averages.
For NER specifically, we use strict-match evaluation: an entity prediction is considered correct only if both the entity type and its exact boundaries (start and end positions) match the ground truth.
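This is the behaviour of entity-level evaluation libraries such as seqeval, whose `classification_report` produces output in the same shape as `report.txt` below. Using seqeval here is an assumption about the tooling, not a statement about what `src/score.py` does:

```python
from seqeval.metrics import classification_report
from seqeval.scheme import IOB2

# Toy example of strict, entity-level matching.
y_true = [["B-per", "I-per", "O", "O", "O", "O", "B-geo"]]
y_pred = [["B-per", "O",     "O", "O", "O", "O", "B-geo"]]

# The "geo" entity matches exactly and counts as correct; the predicted "per"
# entity covers only part of the gold span, so under strict matching it counts
# as both a false positive and a false negative.
print(classification_report(y_true, y_pred, mode="strict", scheme=IOB2))
```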
```
├── data                     # DATA FILES
│   ├── data.csv             # Raw dataset from Kaggle
│   ├── train.csv            # Training split from data.csv
│   ├── valid.csv            # Validation split from data.csv
│   ├── mapping.json         # Mapping of labels to indices
│   ├── score.csv            # Model predictions on valid.csv
│   ├── model                # Finetuned model
│   │   └── model.pt-v1.ckpt
│   └── report.txt           # Evaluation on valid.csv
├── src                      # SOURCE CODE
│   ├── data.py              # Data loaders
│   ├── model.py             # HuggingFace model
│   ├── score.py             # Scoring model
│   └── train.py             # Training model
├── main.py
└── conf.yaml
```
Download the `data.csv` file from Kaggle and place it in the `data` folder.
Command to preprocess the data file and split it into training and validation sets:
Command to finetune the model:
Command to evaluate the model:
The final model performance is saved in the `report.txt` file. It looks like:
label | precision | recall | f1-score | support |
---|---|---|---|---|
art | 0.38 | 0.06 | 0.10 | 161 |
eve | 0.22 | 0.23 | 0.23 | 61 |
geo | 0.86 | 0.88 | 0.87 | 12744 |
gpe | 0.94 | 0.95 | 0.95 | 5020 |
nat | 0.35 | 0.44 | 0.39 | 73 |
org | 0.70 | 0.70 | 0.70 | 6655 |
per | 0.76 | 0.81 | 0.78 | 5195 |
tim | 0.83 | 0.82 | 0.82 | 3942 |
micro avg | 0.82 | 0.83 | 0.82 | 33851 |
macro avg | 0.63 | 0.61 | 0.60 | 33851 |
weighted avg | 0.82 | 0.83 | 0.82 | 33851 |
The performance metrics reveal several important insights:
- High-frequency labels such as geo, gpe, and tim achieve strong F1-scores (0.82-0.95), while rare labels such as art, eve, and nat, each with fewer than 200 support examples, score below 0.40.
- The gap between the macro-average F1 (0.60) and the micro-average F1 (0.82) is a direct consequence of this class imbalance.
- The weighted average metrics closely match the micro-average, confirming that performance is dominated by high-frequency classes.
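These averages can be reproduced directly from the per-label rows above: the macro average is the unweighted mean of the per-label scores, while the weighted average weights each label by its support. A quick sanity check:

```python
# Per-label F1 and support, copied from report.txt above.
f1 = {"art": 0.10, "eve": 0.23, "geo": 0.87, "gpe": 0.95,
      "nat": 0.39, "org": 0.70, "per": 0.78, "tim": 0.82}
support = {"art": 161, "eve": 61, "geo": 12744, "gpe": 5020,
           "nat": 73, "org": 6655, "per": 5195, "tim": 3942}

macro_f1 = sum(f1.values()) / len(f1)
weighted_f1 = sum(f1[k] * support[k] for k in f1) / sum(support.values())
print(macro_f1, weighted_f1)
# ~0.60 and ~0.82 respectively (up to rounding), matching the macro avg and
# weighted avg rows, with the latter sitting right on the micro average.
```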
The wandb logs are below: