Entity Extraction

Named-Entity Recognition with Fine-tuned RoBERTa Models


Overview

This repository contains code for Named-Entity Recognition (NER) using fine-tuned RoBERTa models. The dataset is taken from Kaggle. The objective is to accurately predict the entity tag for each word in a sentence, distinguishing between entity types such as people (per), organizations (org), locations (geo), and more.

Dataset Structure

The input dataset has the following structure:

Sentence #     Word            POS   Tag
Sentence: 1    Thousands       NNS   O
               of              IN    O
               demonstrators   NNS   O
               have            VBP   O
               marched         VBN   O
               through         IN    O
               London          NNP   B-geo
               to              TO    O

Where:

  • Sentence # identifies the sentence; it is filled only for the first word of each sentence and left blank for the remaining words
  • Word is the raw token
  • POS is the part-of-speech tag of the word
  • Tag is the entity label in BIO format (e.g. B-geo marks the first word of a geographical entity, O marks a non-entity word)
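Because Sentence # is only populated on the first word of each sentence, the file has to be grouped back into sentences before tokenization. Below is a minimal sketch of this step using pandas; the encoding and helper code are assumptions for illustration, and the repository's actual loading logic lives in src/data.py.

```python
import pandas as pd

# Read the raw Kaggle file (the public dataset is commonly latin-1 encoded;
# adjust the encoding if the local copy differs).
df = pd.read_csv("data/data.csv", encoding="latin-1")

# "Sentence #" is only filled on the first word of each sentence,
# so forward-fill it before grouping.
df["Sentence #"] = df["Sentence #"].ffill()

# Collect (words, tags) pairs per sentence, preserving file order.
sentences = (
    df.groupby("Sentence #", sort=False)
      .apply(lambda g: (g["Word"].tolist(), g["Tag"].tolist()))
      .tolist()
)

words, tags = sentences[0]
print(words[:7])   # ['Thousands', 'of', 'demonstrators', 'have', 'marched', 'through', 'London']
print(tags[:7])    # ['O', 'O', 'O', 'O', 'O', 'O', 'B-geo']
```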

ML Approach

Model Architecture

We use the RoBERTa base model, an optimized variant of BERT with an improved training methodology. Compared to BERT, RoBERTa features:

  • dynamic masking, where the masked positions change across epochs instead of being fixed during preprocessing
  • removal of the next-sentence-prediction objective
  • larger mini-batches, a byte-level BPE vocabulary, and substantially more pre-training data

For NER, we add a token classification head on top of the RoBERTa encoder, which consists of a dropout layer followed by a linear layer mapping to the number of entity classes.
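A minimal sketch of building such a model with the HuggingFace transformers API is shown below. The label list is illustrative only (the actual label-to-index mapping is stored in data/mapping.json), and the repository's own model code lives in src/model.py.

```python
from transformers import AutoTokenizer, AutoModelForTokenClassification

# Illustrative label set; the real mapping lives in data/mapping.json.
labels = ["O", "B-geo", "I-geo", "B-org", "I-org", "B-per", "I-per",
          "B-gpe", "I-gpe", "B-tim", "I-tim", "B-art", "I-art",
          "B-eve", "I-eve", "B-nat", "I-nat"]
label2id = {l: i for i, l in enumerate(labels)}
id2label = {i: l for l, i in label2id.items()}

# add_prefix_space=True is required when feeding pre-split words to the
# RoBERTa byte-level BPE tokenizer.
tokenizer = AutoTokenizer.from_pretrained("roberta-base", add_prefix_space=True)

# This places a dropout + linear classification layer over the encoder's
# per-token hidden states, as described above.
model = AutoModelForTokenClassification.from_pretrained(
    "roberta-base",
    num_labels=len(labels),
    id2label=id2label,
    label2id=label2id,
)
```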

Training Process

The model is fine-tuned end to end on the training split; the training code lives in src/train.py.
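For orientation, here is a sketch of a typical fine-tuning setup using the HuggingFace Trainer, building on the model and tokenizer from the sketch above. The hyperparameter values are placeholders, not the values used to produce the reported results, and the actual code in src/train.py may be organized differently.

```python
from transformers import (Trainer, TrainingArguments,
                          DataCollatorForTokenClassification)

# Placeholder hyperparameters -- not taken from this repository's config.
args = TrainingArguments(
    output_dir="data/model",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    num_train_epochs=3,
)

trainer = Trainer(
    model=model,                  # token classification model from the sketch above
    args=args,
    train_dataset=train_dataset,  # tokenized train.csv (see preprocessing below)
    eval_dataset=valid_dataset,   # tokenized valid.csv
    data_collator=DataCollatorForTokenClassification(tokenizer),
)
trainer.train()
```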

Data Preprocessing
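The key preprocessing step for token classification with RoBERTa is aligning the word-level tags from the dataset to the subword tokens produced by the byte-level BPE tokenizer. Below is a minimal sketch of one common alignment scheme (label only the first subword of each word and mask the rest with -100 so they are ignored by the loss); the actual preprocessing in src/data.py may use a different scheme.

```python
def tokenize_and_align(words, tags, tokenizer, label2id, max_length=128):
    """Tokenize a pre-split sentence and align word-level tags to subword tokens."""
    encoding = tokenizer(
        words,
        is_split_into_words=True,
        truncation=True,
        max_length=max_length,
    )
    labels = []
    previous_word = None
    for word_idx in encoding.word_ids():
        if word_idx is None:
            labels.append(-100)                        # special tokens: ignored by the loss
        elif word_idx != previous_word:
            labels.append(label2id[tags[word_idx]])    # first subword carries the word's tag
        else:
            labels.append(-100)                        # remaining subwords are masked out
        previous_word = word_idx
    encoding["labels"] = labels
    return encoding
```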

Evaluation Metrics

We evaluate the model using:

  • Precision: the fraction of predicted entities that are correct
  • Recall: the fraction of ground-truth entities that are found
  • F1-score: the harmonic mean of precision and recall
  • Support: the number of ground-truth entities of each type

We report metrics at multiple levels:

  • Per entity type (geo, gpe, per, org, tim, art, eve, nat)
  • Micro average, computed over all entities regardless of type
  • Macro average, the unweighted mean over entity types
  • Weighted average, the mean over entity types weighted by support

For NER specifically, we use a strict match evaluation - an entity prediction is considered correct only if both the entity type and its exact boundary (start and end positions) match the ground truth.
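The report format shown in the Performance section below matches the output of the seqeval library's classification_report, so an entity-level, strict-match evaluation might look like the sketch below; whether src/score.py actually uses seqeval is an assumption.

```python
from seqeval.metrics import classification_report
from seqeval.scheme import IOB2

# Tag sequences per sentence: ground truth vs. predictions.
y_true = [["O", "O", "B-geo", "I-geo", "O"],
          ["B-per", "I-per", "O", "B-org"]]
y_pred = [["O", "O", "B-geo", "O", "O"],
          ["B-per", "I-per", "O", "B-org"]]

# mode="strict" counts an entity as correct only if both its type and
# its exact span match the ground truth.
print(classification_report(y_true, y_pred, mode="strict", scheme=IOB2, digits=2))
```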

Code Structure

├── data                        # DATA FILES
│   ├── data.csv                # Raw dataset from Kaggle
│   ├── train.csv               # Training split from data.csv
│   ├── valid.csv               # Validation split from data.csv
│   ├── mapping.json            # Mapping of labels to indices
│   ├── score.csv               # Model predictions on valid.csv
│   ├── report.txt              # Evaluation on valid.csv
│   └── model                   # Fine-tuned model
│       └── model.pt-v1.ckpt
├── src                         # SOURCE CODE
│   ├── data.py                 # Data loaders
│   ├── model.py                # HuggingFace model
│   ├── score.py                # Scoring model
│   └── train.py                # Training model
├── main.py
└── conf.yaml

Instructions

Download the data.csv file from Kaggle and place it in the data folder.

Command to preprocess the data file and split it into training and validation sets:

python main.py --mode prepare-data

Command to fine-tune the model:

python main.py --mode train-model

Command to evaluate the model:

python main.py --mode score-model

Performance

The final model performance is saved in the report.txt file. It looks like this:

label          precision   recall   f1-score   support
art                 0.38     0.06       0.10       161
eve                 0.22     0.23       0.23        61
geo                 0.86     0.88       0.87     12744
gpe                 0.94     0.95       0.95      5020
nat                 0.35     0.44       0.39        73
org                 0.70     0.70       0.70      6655
per                 0.76     0.81       0.78      5195
tim                 0.83     0.82       0.82      3942

micro avg           0.82     0.83       0.82     33851
macro avg           0.63     0.61       0.60     33851
weighted avg        0.82     0.83       0.82     33851

Performance Analysis

The performance metrics reveal several important insights:

  1. Class imbalance impact:
    • High-support classes (geo: 12,744, org: 6,655, per: 5,195) achieve F1 scores of 0.87, 0.70, and 0.78 respectively
    • Low-support classes (art: 161, eve: 61, nat: 73) achieve much lower F1 scores of 0.10, 0.23, and 0.39
  2. Precision-recall balance:
    • For most entity types, precision and recall are fairly balanced
    • The exception is 'art' entities, where precision (0.38) significantly exceeds recall (0.06), indicating the model is conservative in predicting this class
  3. Overall performance:
    • Micro-average F1 of 0.82 indicates strong general performance
    • Gap between micro-average (0.82) and macro-average (0.60) F1 scores confirms the model performs substantially better on common entity types
  4. Areas for improvement:
    • Targeted strategies for low-resource classes:
      • Data augmentation for minority classes
      • Class weighting in the loss function (see the sketch after this list)
      • Few-shot or meta-learning approaches
    • Entity boundary detection could be improved for complex entities
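As one concrete instance of the class-weighting idea, inverse-frequency weights can be passed to the token-classification loss. This is a minimal sketch assuming a plain PyTorch cross-entropy loss over the token logits; it is not taken from the repository's training code.

```python
import torch
from collections import Counter

def inverse_frequency_weights(all_label_ids, num_labels):
    """Weight each class by the inverse of its frequency in the training labels."""
    counts = Counter(i for i in all_label_ids if i != -100)   # skip masked subwords
    total = sum(counts.values())
    weights = torch.ones(num_labels)
    for label_id, count in counts.items():
        weights[label_id] = total / (num_labels * count)
    return weights

# Usage sketch: plug the weights into the loss applied to the token logits.
# weights = inverse_frequency_weights(train_label_ids, num_labels=len(label2id))
# loss_fn = torch.nn.CrossEntropyLoss(weight=weights, ignore_index=-100)
# loss = loss_fn(logits.view(-1, len(label2id)), labels.view(-1))
```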

The weighted average metrics closely match the micro-average, confirming that performance is dominated by high-frequency classes.

Training Progress

The Weights & Biases (wandb) training logs are shown below:

[Figure: training logs from Weights & Biases showing model performance]