Low-Rank Adaptation (LoRA) for BERT Classification

- 10 mins

Introduction

A Parameter-Efficient Alternative to Full Fine-Tuning

Modern deep learning workflows rely heavily on pretrained models. Models such as BERT are trained on massive corpora and subsequently adapted to downstream tasks such as classification, retrieval, and ranking. The dominant paradigm for adaptation has historically been full fine-tuning, where all parameters of the pretrained model are updated using task-specific data.

While effective, full fine-tuning is computationally inefficient, memory intensive, and fundamentally redundant. This inefficiency becomes more pronounced as model size increases. Low-Rank Adaptation (LoRA) addresses this inefficiency by constraining weight updates to a low-dimensional subspace, dramatically reducing the number of trainable parameters while preserving performance.

Although originally introduced for optimizing large language models (LLMs), LoRA is not specific to autoregressive transformers. It is a general method applicable to any neural network containing linear transformations. This includes encoder-only architectures such as BERT, as well as recommendation systems, retrieval towers, and vision transformers.

This article develops the mathematical foundation of LoRA and applies it to a BERT classification model, comparing it with standard fine-tuning in terms of compute, memory, and performance.


Fine-Tuning as Weight Perturbation

Consider a linear layer in a neural network:

\[y = Wx\]

where:

\[W \in \mathbb{R}^{d \times k}\]

During fine-tuning, the pretrained weights are modified:

\[W' = W + \Delta W\]

The fine-tuning process learns the update matrix:

\[\Delta W \in \mathbb{R}^{d \times k}\]

This matrix contains $d \times k$ trainable parameters.

For example, in BERT-base:

\[d = k = 768\]

Thus:

\[|\Delta W| = 768 \times 768 = 589,824\]

per projection matrix.

Since transformers contain many such matrices, the total number of trainable parameters during full fine-tuning reaches on the order of a hundred million for BERT-base, and billions for larger models.

We can reduce the number of trainable parameters by constraining $\Delta W$ to have low rank.

This leads to significant savings in both memory and computation.


Low-Rank Structure of Weight Updates

The key insight underlying LoRA is that weight updates lie in a low-dimensional subspace.

Formally, LoRA assumes:

\[\text{rank}(\Delta W) \le r\]

where:

\[r \ll \min(d,k)\]

Any matrix of rank at most $r$ can be factorized as:

\[\Delta W = BA\]

where:

\[B \in \mathbb{R}^{d \times r}\] \[A \in \mathbb{R}^{r \times k}\]

Thus, instead of learning:

\[dk\]

parameters, we learn:

\[r(d + k)\]

parameters.

For example:

\[d = k = 768, \quad r = 8\]

Full fine-tuning:

\[589,824\]

LoRA:

\[8(768 + 768) = 12,288\]

This is a reduction of approximately 48×.
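
As a quick sanity check of these counts, the following PyTorch sketch (variable names are arbitrary, not part of any library) builds the two factors for a 768 × 768 update and compares parameter counts.

import torch

d, k, r = 768, 768, 8

# Low-rank factors of the update: Delta W = B @ A
B = torch.randn(d, r)
A = torch.randn(r, k)
delta_W = B @ A                              # shape (768, 768), rank <= 8

print(d * k)                                 # 589,824 entries in a full update
print(B.numel() + A.numel())                 # 12,288 trainable values in the factorization
print((d * k) // (B.numel() + A.numel()))    # 48x fewer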


LoRA Forward Pass Formulation

The original linear transformation:

\[y = Wx\]

becomes:

\[y = Wx + BAx\]

The pretrained matrix:

\[W\]

is frozen.

Only:

\[A, B\]

are trainable.

The LoRA paper introduces a scaling factor:

\[y = Wx + \frac{\alpha}{r} BAx\]

where $\alpha$ controls the magnitude of the update.

This scaling stabilizes optimization.
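
To make the formulation concrete, here is a minimal, illustrative PyTorch implementation of a LoRA-wrapped linear layer. The class name LoRALinear and the initialization constants are our own choices for this sketch, not the reference implementation.

import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """Wraps a pretrained nn.Linear with a trainable low-rank update."""

    def __init__(self, base_linear: nn.Linear, rank: int = 8, alpha: float = 16.0):
        super().__init__()
        self.base = base_linear
        # Freeze the pretrained weight (and bias): W receives no gradient updates.
        for p in self.base.parameters():
            p.requires_grad_(False)
        d, k = base_linear.out_features, base_linear.in_features
        # A starts small and random, B starts at zero, so BA = 0 initially
        # and the wrapped layer behaves exactly like the pretrained one.
        self.A = nn.Parameter(torch.randn(rank, k) * 0.01)
        self.B = nn.Parameter(torch.zeros(d, rank))
        self.scaling = alpha / rank

    def forward(self, x):
        # y = Wx + (alpha / r) * B A x
        return self.base(x) + self.scaling * (x @ self.A.T @ self.B.T)

# Wrap a single 768x768 projection and count trainable parameters.
layer = LoRALinear(nn.Linear(768, 768), rank=8, alpha=16.0)
trainable = sum(p.numel() for p in layer.parameters() if p.requires_grad)
print(trainable)  # 8*768 + 768*8 = 12,288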


Application to Transformer Architectures

Transformers rely heavily on linear projections.

In particular, attention layers contain projections:

\[Q = W_Q x\] \[K = W_K x\] \[V = W_V x\]

LoRA modifies these as:

\[Q = W_Q x + B_Q A_Q x\] \[K = W_K x + B_K A_K x\] \[V = W_V x + B_V A_V x\]

This enables adaptation without modifying pretrained weights.

The frozen weights retain general knowledge, while LoRA learns task-specific adjustments.
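
To see where these projections live in an actual BERT checkpoint, the short sketch below (assuming the Hugging Face transformers library is installed) prints the attention projection modules; their names, query/key/value, are the handles that the PEFT configuration targets later in this article.

from transformers import AutoModel

# Print every attention projection in BERT-base.
bert = AutoModel.from_pretrained("bert-base-uncased")
for name, module in bert.named_modules():
    if name.endswith(("query", "key", "value")):
        print(name, tuple(module.weight.shape))
# encoder.layer.0.attention.self.query (768, 768)
# encoder.layer.0.attention.self.key   (768, 768)
# ... 36 projection matrices in total across 12 layers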


Applying LoRA to BERT Classification

The standard BERT classification architecture consists of:

\[\text{Input} \rightarrow \text{BERT Encoder} \rightarrow \text{CLS embedding} \rightarrow \text{Classifier}\]

During full fine-tuning, all encoder weights are updated.

Total parameters:

\[110M\]

With LoRA, the encoder weights remain frozen and only the low-rank adapter matrices (together with the classification head) are trained. Let us quantify the parameter reduction.

BERT-base attention layers:

\[12 \text{ layers}\]

Each layer contains 3 projection matrices (Q, K, V), contributing:

\[3 \times 768 \times 768 = 1{,}769{,}472\]

parameters per layer. Across all 12 layers, the attention projections total:

\[12 \times 1{,}769{,}472 \approx 21\text{M}\]

With LoRA rank:

\[r = 8\]

Trainable parameters become:

\[3 \times 12 \times 8 \times (768 + 768)\] \[= 442,368\]

This represents a reduction of approximately 48× relative to the attention projection parameters alone, and of roughly 250× relative to full-model fine-tuning (110M parameters).
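
The arithmetic above can be reproduced in a few lines; the variable names below are arbitrary.

# Reproduce the parameter arithmetic (BERT-base, r = 8, Q/K/V adapted).
layers, d, r = 12, 768, 8
full_attn = layers * 3 * d * d         # 21,233,664 (~21M) attention projection weights
lora_attn = layers * 3 * r * (d + d)   # 442,368 adapter weights
print(full_attn // lora_attn)          # 48x reduction within attention
print(round(110_000_000 / lora_attn))  # ~249x relative to all of BERT-base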


Memory and Compute Efficiency

Training memory consists of the model parameters, their gradients, the optimizer state, and activations. For the Adam optimizer, the gradients and the two moment estimates together add roughly three values per trainable parameter:

\[\text{training state} \approx 3 \times \text{trainable parameters}\]

Full fine-tuning:

\[110M \rightarrow 330M\]

LoRA fine-tuning:

\[0.44M \rightarrow 1.32M\]

This is a reduction of approximately:

\[250×\]

The backward pass must still propagate activation gradients through the frozen layers, but gradients with respect to the frozen weights are never computed or stored, and the optimizer updates only the adapter parameters. This reduces training compute and, above all, gradient and optimizer-state memory.

Forward-pass compute remains nearly identical, since the low-rank path adds only $r(d+k)$ weights per projection and can be merged into $W$ after training.
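
As a rough illustration, the sketch below estimates the extra Adam training state in bytes, assuming fp32 optimizer state and ignoring activation memory; the helper name adam_state_bytes is our own.

# Rough estimate of the extra Adam training state (gradient + two moment
# estimates, i.e. ~3 values per trainable parameter), fp32 state assumed.
def adam_state_bytes(trainable_params: int, bytes_per_value: int = 4) -> int:
    return 3 * trainable_params * bytes_per_value

print(adam_state_bytes(110_000_000) / 1e9)  # ~1.32 GB of state for full fine-tuning
print(adam_state_bytes(442_368) / 1e6)      # ~5.3 MB of state for the LoRA adapters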


Experiment: Full Fine-Tuning vs. LoRA on IMDb

BERT consists of multiple attention layers containing projection matrices. LoRA inserts low-rank adapters into these projections.

We compare two training regimes: (1) full fine-tuning, in which all encoder and classifier weights are updated, and (2) LoRA fine-tuning, in which the encoder is frozen and only the low-rank adapters and the classification head are trained.

The task is binary sentiment classification on IMDb.


Script 1: Full Fine-Tuning

import torch
import numpy as np
import evaluate

from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
    DataCollatorWithPadding,
)

MODEL_NAME = "bert-base-uncased"

# Load the IMDb binary sentiment dataset.
ds = load_dataset("imdb")

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

def tokenize(batch):
    # Truncate reviews to 256 tokens; padding is applied dynamically by the collator.
    return tokenizer(batch["text"], truncation=True, max_length=256)

ds_tok = ds.map(tokenize, batched=True, remove_columns=["text"])

collator = DataCollatorWithPadding(tokenizer)

metric = evaluate.load("accuracy")

def compute_metrics(p):
    logits, labels = p
    preds = np.argmax(logits, axis=-1)
    return metric.compute(predictions=preds, references=labels)

# BERT-base with a freshly initialized 2-class classification head;
# every parameter (~110M) is trainable.
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME,
    num_labels=2,
)

def count_trainable_params(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print("Full FT trainable params:", count_trainable_params(model))

args = TrainingArguments(
    output_dir="./bert_full",
    eval_strategy="epoch",
    save_strategy="epoch",
    per_device_train_batch_size=16,
    per_device_eval_batch_size=32,
    learning_rate=2e-5,
    num_train_epochs=2,
    fp16=torch.cuda.is_available(),
    report_to="none",
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=ds_tok["train"],
    eval_dataset=ds_tok["test"],
    data_collator=collator,
    compute_metrics=compute_metrics,
)

trainer.train()
full_metrics = trainer.evaluate()

print(full_metrics)

Script 2: LoRA Fine-Tuning

# Continues from Script 1: reuses MODEL_NAME, ds_tok, collator,
# compute_metrics, and count_trainable_params defined above.
import torch
from peft import LoraConfig, get_peft_model, TaskType

base_model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME,
    num_labels=2,
)

# Rank-8 adapters with alpha=16 (scaling alpha/r = 2), injected into the
# query and value projections of every attention layer.
lora_config = LoraConfig(
    task_type=TaskType.SEQ_CLS,
    r=8,
    lora_alpha=16,
    lora_dropout=0.1,
    target_modules=["query", "value"],
)

model = get_peft_model(base_model, lora_config)

print("LoRA trainable params:", count_trainable_params(model))

args = TrainingArguments(
    output_dir="./bert_lora",
    eval_strategy="epoch",
    save_strategy="epoch",
    per_device_train_batch_size=16,
    per_device_eval_batch_size=32,
    learning_rate=5e-4,  # LoRA typically tolerates a larger learning rate than full fine-tuning
    num_train_epochs=2,
    fp16=torch.cuda.is_available(),
    report_to="none",
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=ds_tok["train"],
    eval_dataset=ds_tok["test"],
    data_collator=collator,
    compute_metrics=compute_metrics,
)

trainer.train()
lora_metrics = trainer.evaluate()

print(lora_metrics)

Comparison: Trainable Parameters and Accuracy

The key difference between the two approaches is the number of trainable parameters.

| Method           | Trainable Parameters | Trainable % | Accuracy |
|------------------|----------------------|-------------|----------|
| Full Fine-Tuning | 110M                 | 100%        | 0.919720 |
| LoRA Fine-Tuning | 0.29M                | 0.26%       | 0.908040 |

With a parameter reduction of over 200×, LoRA reaches 90.8% accuracy, compared with 91.9% for full fine-tuning, a gap of only about 1.2 percentage points. (The 0.29M trainable count is smaller than the 0.44M estimated earlier because the script injects adapters only into the query and value projections, plus the small classification head, rather than all three attention projections.)

This demonstrates that downstream adaptation requires only a small structured update, not full weight modification.


Interpretation

Pretrained transformer weights already encode a rich representation space. Fine-tuning does not need to reconstruct this space. It only needs to adjust projections along task-relevant directions.

LoRA explicitly models this adjustment as a low-rank update.

This reduces optimization dimensionality while preserving expressive capacity.


Optimization Perspective

Standard fine-tuning solves:

\[\min_W L(W)\]

LoRA solves:

\[\min_{A,B} L(W + BA)\]

subject to rank constraint:

\[\text{rank}(BA) \le r\]

This acts as a structural regularizer.

It prevents overfitting while preserving expressive adaptation.


Storage and Deployment Benefits

Full fine-tuned BERT model:

~420MB

LoRA adapter:

~3MB

Adapters can be stored separately and merged at inference.

This enables efficient multi-task deployment.
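
A minimal sketch of this workflow with the PEFT library, assuming the trained model object from Script 2 and an adapter directory name of our own choosing, might look like the following (API names as in recent peft releases).

from transformers import AutoModelForSequenceClassification
from peft import PeftModel

# After LoRA training: save only the adapter weights (a few MB), not the full base model.
# `model` here is the PEFT-wrapped model trained in Script 2.
model.save_pretrained("./bert_lora_adapter")

# At deployment: reload the frozen base model, attach the adapter, and merge
# BA into W so that inference incurs no extra compute.
base = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)
deployed = PeftModel.from_pretrained(base, "./bert_lora_adapter")
deployed = deployed.merge_and_unload()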


Generalization Beyond LLMs

Although LoRA was introduced for large language models, it applies universally to neural networks containing linear layers.

This includes encoder-only models such as BERT, recommendation and retrieval towers, and vision transformers.

The underlying principle remains identical: weight updates are low rank.


Conclusion

Low-Rank Adaptation transforms fine-tuning from a high-dimensional optimization problem into a constrained low-dimensional one.

Instead of updating:

\[dk\]

parameters, LoRA updates:

\[r(d+k)\]

parameters, where:

\[r \ll d,k\]

This yields substantial reductions in memory, compute, and storage while preserving performance.

LoRA demonstrates that pretrained models already contain most of the necessary representational capacity. Downstream tasks require only small structured adjustments, not full parameter updates.

This insight has broad implications for scalable training and deployment of neural networks.


References

  1. Hu et al., 2021, LoRA: Low-Rank Adaptation of Large Language Models, https://arxiv.org/abs/2106.09685

  2. Devlin et al., 2018, BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding, https://arxiv.org/abs/1810.04805

  3. Vaswani et al., 2017, Attention Is All You Need, https://arxiv.org/abs/1706.03762

  4. Dettmers et al., 2023, QLoRA: Efficient Finetuning of Quantized LLMs, https://arxiv.org/abs/2305.14314

  5. Microsoft LoRA implementation, https://github.com/microsoft/LoRA

  6. HuggingFace PEFT library, https://github.com/huggingface/peft