Supervised Fine-Tuning (SFT) for LLMs
Introduction
Supervised Fine-Tuning (SFT) is the most common way to adapt a pre-trained Large Language Model (LLM) to follow instructions, match a domain-specific style, or improve task performance. At a high level, SFT looks simple: collect prompt–response pairs and train the model with next-token prediction.
In practice, most SFT failures are not caused by model architecture or optimizers. They come from:
- incorrect label construction (learning on the wrong tokens)
- poor data formatting (template mismatches)
- inefficient batching (wasted compute)
- evaluation setups that hide regressions
This post focuses on SFT from an engineering perspective: how it works, what can go wrong, and how to implement it cleanly. We will start by explaining what SFT is, and then walk through the details of the implementation.
What is Supervised Fine-Tuning (SFT)?
Supervised Fine-Tuning (SFT) is the process of adapting a pre-trained language model using human-written (or curated) prompt–response pairs. The model is trained to produce the desired assistant response when given an instruction, question, or conversation context.
At a high level, SFT answers this question:
Given a prompt that represents what the user says, how do we train the model to respond like a helpful assistant?
This is usually the first alignment step after pre-training.
Pre-training teaches the model general language understanding and world knowledge.
SFT teaches the model how to behave: follow instructions, format answers correctly, maintain tone, and respond consistently.
SFT vs Pre-training (Key Differences)
Pre-training
- Objective: predict the next token on large-scale raw text
- Data: internet-scale corpora, books, code, etc.
- Outcome: the model learns general patterns and representations
SFT
- Objective: predict the assistant’s response conditioned on a prompt
- Data: curated instruction datasets (chat-like samples)
- Outcome: the model becomes instruction-following and conversational
In other words:
- Pre-training learns knowledge + language
- SFT learns interaction behavior
What does SFT Data Look Like?
SFT data typically looks like:
- Instruction / user message: what the user asks
- Assistant response: the desired output
For chat models, each example is often structured as:
system: global behavior guidelines
user: instruction / query
assistant: target response
Example:
<System>
You are a helpful assistant.
<User>
Explain gradient clipping in deep learning.
<Assistant>
Gradient clipping limits the norm of gradients to prevent unstable updates...
During SFT, the model is trained to generate the assistant part, conditioned on everything before it.
Optimization Objective for SFT
A causal LLM models a sequence of tokens $x_{1:T}$ using:
\[p_\theta(x_{1:T}) = \prod_{t=1}^{T} p_\theta(x_t \mid x_{<t})\]
In SFT, we build training samples that include:
- the prompt / instruction
- the assistant response
and optimize the negative log-likelihood over tokens that belong to the assistant response.
The loss for a single training example becomes:
\[\mathcal{L}(\theta) = -\sum_{t \in \mathcal{T}_{\text{assistant}}} \log p_\theta(x_t \mid x_{<t})\]
where $\mathcal{T}_{\text{assistant}}$ is the set of token positions we actually want to learn from.
If you compute loss over the entire concatenated sequence (prompt + response), the model learns to predict:
- user instructions
- system prompts
- formatting delimiters
- assistant response
This dilutes the learning signal. Instead, we mask out tokens that do not belong to the assistant response during loss computation.
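As a tiny worked illustration (the token IDs below are made up, not from a real tokenizer), only the assistant tokens keep real labels; every prompt position is set to the ignore index -100:

prompt_ids   = [12, 7, 99]   # system + user tokens
response_ids = [5, 8, 2]     # assistant tokens

input_ids = prompt_ids + response_ids                  # [12, 7, 99, 5, 8, 2]
labels    = [-100] * len(prompt_ids) + response_ids    # [-100, -100, -100, 5, 8, 2]
# the loss is computed only at positions where labels != -100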
Data Formatting
A large portion of instruction-following capability comes from consistent formatting. For chat models, training examples should follow the same template used during inference.
A typical chat sample looks like:
system: high-level behavior policy
user: instruction / input
assistant: expected response
SFT should use the exact message structure, special tokens, and separators expected by the model tokenizer. Mismatched templates degrade quality because the model learns a distribution that does not match inference-time prompts.
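As a sketch, assuming a Hugging Face tokenizer that ships a chat template (the checkpoint name below is only a placeholder), the inference-time format can be reproduced at training time with apply_chat_template:

from transformers import AutoTokenizer

# placeholder checkpoint; use the model you are actually fine-tuning
tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Explain gradient clipping in deep learning."},
]

# everything up to (and including) the assistant header becomes the prompt;
# the target response is appended after it and is the only span we compute loss on
prompt_text = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)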
An example of data formatting used for SFT
We assume the dataset consists of:
prompt: string (system + user messages)
response: string (assistant message)
At training time, we concatenate:
\[\text{tokens} = \text{tokenize}(\text{prompt}) \; || \; \text{tokenize}(\text{response})\]
and create labels such that:
- prompt tokens are ignored (label -100)
- response tokens have true labels
Minimal PyTorch Implementation
1) Tokenization and Label Creation
import torch
from torch.utils.data import Dataset, DataLoader

IGNORE_INDEX = -100

class SFTDataset(Dataset):
    def __init__(self, data, tokenizer, max_length=512):
        """
        data: list of dicts:
            { "prompt": str, "response": str }
        """
        self.data = data
        self.tok = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        ex = self.data[idx]
        prompt = ex["prompt"]
        response = ex["response"]

        prompt_ids = self.tok.encode(prompt, add_special_tokens=False)
        response_ids = self.tok.encode(response, add_special_tokens=False)

        # concatenate prompt and response into one sequence
        input_ids = prompt_ids + response_ids
        input_ids = input_ids[: self.max_length]

        # labels: ignore prompt tokens, learn only response tokens
        labels = [IGNORE_INDEX] * len(prompt_ids) + response_ids
        labels = labels[: self.max_length]

        return {
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
            "labels": torch.tensor(labels, dtype=torch.long),
        }
2) Padding Collator
def pad_collate(batch, pad_id: int):
    """
    Pads input_ids with pad_id.
    Pads labels with IGNORE_INDEX so padding doesn't contribute to loss.
    """
    max_len = max(x["input_ids"].size(0) for x in batch)

    input_ids = []
    labels = []
    for x in batch:
        ids = x["input_ids"]
        lab = x["labels"]
        pad_len = max_len - ids.size(0)
        input_ids.append(torch.cat([ids, torch.full((pad_len,), pad_id, dtype=torch.long)]))
        labels.append(torch.cat([lab, torch.full((pad_len,), IGNORE_INDEX, dtype=torch.long)]))

    input_ids = torch.stack(input_ids)
    labels = torch.stack(labels)

    return {
        "input_ids": input_ids,
        "labels": labels,
        "attention_mask": (input_ids != pad_id).long(),
    }
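Putting the dataset and collator together, a minimal usage sketch (the tokenizer checkpoint and the toy example below are placeholders):

from functools import partial
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")        # placeholder checkpoint
pad_id = tok.pad_token_id or tok.eos_token_id      # gpt2 has no pad token, so fall back to EOS

data = [
    {"prompt": "user: What is gradient clipping?\nassistant: ", "response": "It caps the gradient norm."},
]

train_ds = SFTDataset(data, tok, max_length=512)
loader = DataLoader(
    train_ds,
    batch_size=8,
    shuffle=True,
    collate_fn=partial(pad_collate, pad_id=pad_id),
)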
Training Objective and Forward Pass
Given input_ids, a causal model outputs logits:
- shape: (batch, seq_len, vocab_size)
For next-token prediction, we shift by one:
predict token $t$ using logits at position $t-1$
Minimal training step
import torch.nn.functional as F

def sft_loss(logits, labels):
    """
    logits: (B, T, V)
    labels: (B, T) with IGNORE_INDEX for masked tokens
    """
    # shift for causal LM loss: logits at position t predict the token at t+1
    shift_logits = logits[:, :-1, :].contiguous()
    shift_labels = labels[:, 1:].contiguous()

    loss = F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
        ignore_index=IGNORE_INDEX,
    )
    return loss
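To complete the training step, here is a minimal sketch using a Hugging Face causal LM (the checkpoint and learning rate are placeholders; loader is the DataLoader from the collator sketch above):

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")   # placeholder checkpoint
model.train()
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)

for batch in loader:
    outputs = model(
        input_ids=batch["input_ids"],
        attention_mask=batch["attention_mask"],
    )
    loss = sft_loss(outputs.logits, batch["labels"])

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()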
Efficient Batching: Packing vs Padding
Padding wastes compute because all sequences are padded to the maximum length in the batch.
Packing reduces waste by concatenating multiple shorter samples into one long sequence until max_length is reached.
If the average sample length is 250 tokens and the maximum length is 2048:
- padding wastes roughly 88% of token positions, since (2048 - 250) / 2048 ≈ 0.88
- training becomes slower and more expensive
Packing improves throughput significantly, especially for long-context training.
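A minimal greedy packing sketch, assuming the tokenized dicts produced by SFTDataset above (each already truncated to max_length); a production implementation would also build block-diagonal attention masks or reset position IDs so that packed samples cannot attend to each other:

def pack_examples(examples, max_length):
    """
    Greedily concatenates tokenized examples into sequences of up to max_length.
    examples: list of dicts with "input_ids" and "labels" (1-D LongTensors)
    """
    packed, cur_ids, cur_labels = [], [], []
    cur_len = 0
    for ex in examples:
        n = ex["input_ids"].size(0)
        # flush the current pack if adding this example would overflow it
        if cur_len + n > max_length and cur_len > 0:
            packed.append({
                "input_ids": torch.cat(cur_ids),
                "labels": torch.cat(cur_labels),
            })
            cur_ids, cur_labels, cur_len = [], [], 0
        cur_ids.append(ex["input_ids"])
        cur_labels.append(ex["labels"])
        cur_len += n
    if cur_len > 0:
        packed.append({
            "input_ids": torch.cat(cur_ids),
            "labels": torch.cat(cur_labels),
        })
    return packed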
Common Failures in SFT
- Incorrect Label Masking: Failing to mask out prompt tokens leads to learning the wrong distribution.
- Data Formatting Mismatches: Using templates that differ from inference prompts confuses the model.
- Inefficient Batching: Excessive padding wastes compute and slows training.
- Overfitting on Synthetic Data: Relying too much on synthetic or low-quality data degrades generalization.
- Low-Quality Responses in the Data: If the responses in the dataset are not well-aligned with desired behavior, the model will learn to replicate those flaws and hallucinate more.
Evaluation Strategies
SFT success should be measured using:
- task-specific metrics (accuracy, F1, extraction exact match)
- instruction-following win-rate on curated prompts
- safety regression tests
- qualitative checks for refusal and tone control
A small evaluation set with strong coverage is more useful than a large generic benchmark.
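As one concrete example, a small exact-match check over a held-out set might look like the sketch below (the field names, decoding settings, and the greedy-decoding choice are assumptions, not a fixed recipe):

def exact_match_eval(model, tok, eval_set, max_new_tokens=64):
    """
    eval_set: list of dicts with "prompt" and "reference" strings.
    Returns the fraction of prompts whose generated answer matches the reference exactly.
    """
    model.eval()
    hits = 0
    for ex in eval_set:
        inputs = tok(ex["prompt"], return_tensors="pt")
        with torch.no_grad():
            out = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
        # keep only the newly generated tokens after the prompt
        gen = out[0, inputs["input_ids"].size(1):]
        answer = tok.decode(gen, skip_special_tokens=True).strip()
        hits += int(answer == ex["reference"].strip())
    return hits / len(eval_set)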
Conclusion
Supervised Fine-Tuning (SFT) is a critical step in adapting Large Language Models to behave as helpful assistants. While the core idea is straightforward, successful SFT requires careful attention to data formatting, label masking, batching efficiency, and evaluation strategies.