Bài 3: Fine-tuning với LoRA/QLoRA¶
Tổng quan¶
Fine-tuning điều chỉnh weights của model pre-trained để tối ưu cho task cụ thể. Bài này bao gồm framework quyết định khi nào nên fine-tune, các phương pháp PEFT, và thực hành fine-tune Llama với TRL.
1. Decision Framework: Prompt vs RAG vs Fine-tuning¶
Đây là câu hỏi quan trọng nhất trước khi đầu tư vào fine-tuning.
graph TD
START[Vấn đề của bạn?] --> Q1{Model cần kiến thức<br/>từ tài liệu cụ thể?}
Q1 -->|Có| RAG[→ RAG Pipeline<br/>Thêm tài liệu vào context]
Q1 -->|Không| Q2{Model cần thay đổi<br/>hành vi / style / format?}
Q2 -->|Có thể làm với prompting| PROMPT[→ Prompt Engineering<br/>System prompt, few-shot]
Q2 -->|Không đủ| Q3{Có đủ training data<br/>không? ≥ 1000 examples}
Q3 -->|Có| FINETUNE[→ Fine-tuning]
Q3 -->|Không| SYNTHETIC[→ Tạo synthetic data<br/>rồi fine-tune]
So sánh chi tiết¶
| Prompt Engineering | RAG | Fine-tuning | |
|---|---|---|---|
| Chi phí setup | Thấp | Trung bình | Cao |
| Cần training data | Không | Không | Có (≥ 100–1000 samples) |
| Kiến thức mới | Không (knowledge cutoff) | ✅ Có (real-time) | Tích hợp vào weights |
| Thay đổi hành vi | Hạn chế | Không | ✅ Sâu, nhất quán |
| Latency | Thấp | Trung bình (retrieval) | Thấp |
| Chi phí inference | Cao (long context) | Trung bình | Thấp |
Khi nào nên Fine-tune?¶
✅ Nên fine-tune khi:
- Cần format output rất cụ thể mà prompting không đạt được nhất quán
- Domain-specific style (giọng văn pháp lý, y tế, v.v.)
- Cần giảm latency và chi phí (context ngắn hơn)
- Cần model hoạt động tốt không có internet (embedded)
- Có task lặp lại hàng triệu lần → tiết kiệm chi phí
❌ Không nên fine-tune khi:
- Vấn đề có thể giải quyết bằng better prompting
- Cần thông tin cập nhật liên tục (dùng RAG)
- Không có đủ labeled data chất lượng cao
- Đang proof-of-concept (chi phí cao, iteration chậm)
2. Full Fine-tuning vs PEFT¶
Full Fine-tuning¶
Update toàn bộ parameters của model.
# Tất cả weights đều được cập nhật
for param in model.parameters():
param.requires_grad = True # Tất cả đều train
Vấn đề:
- Llama 3 8B: ~8B parameters × 2 bytes (FP16) = 16 GB chỉ cho weights
- Optimizer states (AdamW) chiếm thêm 2× weights = +32 GB
- Gradient checkpointing: +8 GB nữa
- Tổng: ~56 GB VRAM - cần nhiều A100 80GB
PEFT (Parameter-Efficient Fine-Tuning)¶
Chỉ train một phần nhỏ parameters mới được thêm vào, giữ nguyên weights gốc.
Full FT: [W_original] → [W_updated] (train 100% params)
PEFT/LoRA: [W_original] + [∆W_small] (train 0.1-1% params)
↑
Chỉ cần train cái này
| Phương pháp | % Params Trainable | Memory | Chất lượng |
|---|---|---|---|
| Full FT | 100% | ~56 GB (8B) | Tốt nhất |
| LoRA | ~0.5–2% | ~6–12 GB | Gần bằng Full FT |
| QLoRA | ~0.5–2% | ~4–6 GB | Tốt (nhỏ hơn LoRA) |
| Prompt Tuning | <0.01% | Rất thấp | Thấp hơn |
3. LoRA - Low-Rank Adaptation¶
Nguyên lý¶
LoRA giả định rằng "knowledge updates" trong fine-tuning có rank thấp - tức là có thể biểu diễn bằng tích hai ma trận nhỏ.
Weight matrix W (d × d): Được freeze, không thay đổi
LoRA matrices: A (d × r) và B (r × d) ← Chỉ train 2 ma trận này
Output = W·x + (B·A)·x · (α/r)
↑ ↑
frozen trainable
Với r = 8 và d = 4096 (Llama 7B):
W: 4096 × 4096 = 16,777,216 params (frozen)
A: 4096 × 8 = 32,768 params
B: 8 × 4096 = 32,768 params
Total trainable: 65,536 params ← chỉ 0.4% của W!
Hyperparameters LoRA¶
| Parameter | Ý nghĩa | Giá trị thường dùng |
|---|---|---|
r (rank) |
Rank của ma trận decomposition - cao hơn = nhiều params hơn | 8, 16, 32, 64 |
lora_alpha |
Scaling factor - ảnh hưởng learning rate hiệu quả | Thường = r hoặc 2×r |
lora_dropout |
Dropout để tránh overfitting | 0.05–0.1 |
target_modules |
Modules nào được áp dụng LoRA | q_proj, v_proj (tối thiểu) hoặc tất cả attention |
from peft import LoraConfig
lora_config = LoraConfig(
r=16,
lora_alpha=32, # α/r = 2 → scaling = 2
target_modules=[
"q_proj", "k_proj", "v_proj", "o_proj", # Attention
"gate_proj", "up_proj", "down_proj", # FFN
],
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM",
)
Chọn r như thế nào:
r=8: Task đơn giản, ít datar=16: Balanced, phổ biến nhấtr=32–64: Task phức tạp, nhiều data; cần nhiều VRAM hơn
4. QLoRA - Quantized LoRA¶
QLoRA kết hợp quantization (4-bit) + LoRA:
- Load model ở 4-bit NF4 (giảm memory đáng kể)
- Thêm LoRA adapters ở FP16/BF16
- Chỉ train LoRA adapters
from transformers import BitsAndBytesConfig
import torch
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_use_double_quant=True,
)
5. Fine-tuning Llama 3.2 3B với TRL SFTTrainer¶
Chuẩn bị Data¶
Dataset format phổ biến: Alpaca format
# Mỗi example có 3 trường
{
"instruction": "Dịch câu sau sang tiếng Anh.",
"input": "Hôm nay trời đẹp quá.",
"output": "The weather is so beautiful today."
}
from datasets import load_dataset
dataset = load_dataset("5CD-AI/Vietnamese-alpaca-gpt4-gg-translated", split="train")
# Format thành prompt
def format_prompt(example):
if example["input"]:
return f"""### Instruction:
{example["instruction"]}
### Input:
{example["input"]}
### Response:
{example["output"]}"""
else:
return f"""### Instruction:
{example["instruction"]}
### Response:
{example["output"]}"""
dataset = dataset.map(lambda x: {"text": format_prompt(x)})
Full Training Script¶
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig
from datasets import load_dataset
MODEL_ID = "meta-llama/Llama-3.2-3B-Instruct"
OUTPUT_DIR = "./llama-3.2-3b-alpaca-vi"
# 1. Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"
# 2. Quantization config (QLoRA)
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_use_double_quant=True,
)
# 3. Load model
model = AutoModelForCausalLM.from_pretrained(
MODEL_ID,
quantization_config=bnb_config,
device_map="auto",
)
model.config.use_cache = False
# 4. LoRA config
peft_config = LoraConfig(
r=16,
lora_alpha=32,
target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
"gate_proj", "up_proj", "down_proj"],
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM",
)
# 5. Training config
sft_config = SFTConfig(
output_dir=OUTPUT_DIR,
num_train_epochs=3,
per_device_train_batch_size=2,
gradient_accumulation_steps=4, # Effective batch = 2×4 = 8
learning_rate=2e-4,
fp16=False,
bf16=True,
max_grad_norm=0.3,
warmup_ratio=0.03,
lr_scheduler_type="cosine",
logging_steps=10,
save_strategy="epoch",
max_seq_length=1024,
dataset_text_field="text",
packing=True, # Đóng gói nhiều examples vào 1 sequence
)
# 6. Dataset
dataset = load_dataset("5CD-AI/Vietnamese-alpaca-gpt4-gg-translated", split="train")
# ... format dataset như trên
# 7. Trainer
trainer = SFTTrainer(
model=model,
args=sft_config,
train_dataset=dataset,
peft_config=peft_config,
tokenizer=tokenizer,
)
trainer.train()
trainer.save_model()
6. Merge Adapter & Push lên HF Hub¶
Merge LoRA adapter vào base model¶
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer
# Load base model (full precision)
base_model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-3.2-3B-Instruct",
torch_dtype=torch.float16,
device_map="auto",
)
# Load và merge adapter
model = PeftModel.from_pretrained(base_model, "./llama-3.2-3b-alpaca-vi")
model = model.merge_and_unload() # ← Merge vào base weights
print("Merged! Model sẵn sàng để deploy.")
Push lên Hugging Face Hub¶
from huggingface_hub import login
login(token="hf_...") # HF token của bạn
repo_id = "your-username/llama-3.2-3b-vietnamese-assistant"
# Push model
model.push_to_hub(repo_id, use_auth_token=True)
tokenizer.push_to_hub(repo_id, use_auth_token=True)
print(f"Model available at: https://huggingface.co/{repo_id}")
Model Card tự động¶
from transformers import TrainingArguments
from trl import ModelCard
# Tạo model card mô tả training details
card = ModelCard.from_template(
card_data={
"base_model": "meta-llama/Llama-3.2-3B-Instruct",
"datasets": ["5CD-AI/Vietnamese-alpaca-gpt4-gg-translated"],
"tags": ["llama", "vietnamese", "qlora", "peft"],
}
)
card.push_to_hub(repo_id)
Tóm tắt¶
| Khái niệm | Key Points |
|---|---|
| Decision Framework | Thử Prompt → RAG → Fine-tune theo thứ tự chi phí tăng dần |
| LoRA | Thêm 2 ma trận nhỏ (A, B) cho mỗi weight; chỉ train A và B |
| QLoRA | LoRA + 4-bit quantization → train 8B model trên GPU 24GB |
| r (rank) | Cao hơn = nhiều params hơn = chất lượng cao hơn = chậm hơn |
| lora_alpha | Thường = 2×r; ảnh hưởng đến tốc độ học |
| Merge | merge_and_unload() → single model không cần PEFT library khi inference |
Pitfalls phổ biến
- Catastrophic forgetting: Fine-tune quá nhiều epoch → model quên kiến thức gốc → dùng
rnhỏ hơn hoặc epoch ít hơn - Data quality > quantity: 500 examples chất lượng cao tốt hơn 5000 examples kém
- Learning rate quá cao: 2e-4 là điểm khởi đầu tốt; quá cao sẽ phá vỡ model