LLM Compression cho Mobile: Distillation + Quantization Recipes – Tối ưu On-Device Từng Bước

LLM Compression for Mobile: Distillation + Quantization Recipes — Mục tiêu: Step-by-step on-device optimization


1. Hải “Deep Dive” (Giảng viên) đăng bài

Hải “Deep Dive” (Giọng giảng viên): Đào sâu vào bản chất (Under the hood) của công nghệ. Giải thích cơ chế hoạt động bên dưới bề mặt.


2. Giới thiệu

Chào các bạn, hôm nay mình viết bài chia sẻ về chủ đề LLM Compression for Mobile, cụ thể là Distillation + Quantization để chạy mô hình AI trên thiết bị di động một cách hiệu quả. Đây là một chủ đề nóng trong lĩnh vực AI/ML, đặc biệt khi nhu cầu xử lý dữ liệu cục bộ (on-device) ngày càng tăng để bảo vệ quyền riêng tư và giảm độ trễ.


3. Use Case kỹ thuật

Use Case: Khi cần triển khai mô hình ngôn ngữ lớn (LLM) trên thiết bị di động với bộ nhớ RAM giới hạn (dưới 2GB) và yêu cầu thời gian phản hồi dưới 500ms.


4. Tổng quan về LLM Compression

4.1. Tại sao cần nén LLM?

  • Bộ nhớ: Các mô hình lớn như BERT, GPT-3 có thể chiếm hàng trăm GB.
  • Hiệu năng: Độ trễ cao khi inference trên thiết bị yếu.
  • Tiêu thụ năng lượng: Pin nhanh hết do xử lý nặng.

4.2. Hai kỹ thuật chính

  • Distillation (Chưng cất mô hình): Huấn luyện mô hình nhỏ hơn để học từ mô hình lớn.
  • Quantization (Lượng tử hóa): Giảm độ chính xác của trọng số (từ float32 xuống int8).

5. Distillation (Chưng cất mô hình)

5.1. Nguyên lý hoạt động

Distillation là quá trình huấn luyện một mô hình nhỏ (student) để bắt chước hành vi của mô hình lớn (teacher). Mô hình teacher cung cấp “knowledge” dưới dạng xác suất đầu ra (soft targets).

5.2. Các bước thực hiện

  1. Chuẩn bị mô hình teacher
    • Sử dụng mô hình lớn đã được huấn luyện sẵn (pre-trained).
    • Ví dụ: BERT-base, RoBERTa-large.
  2. Thiết kế mô hình student
    • Số lượng layer ít hơn.
    • Kích thước embedding nhỏ hơn.
    • Ví dụ: 6 layers thay vì 12, hidden size 384 thay vì 768.
  3. Huấn luyện với distillation loss
    • Kết hợp giữa distillation loss (học từ teacher) và task loss (mục tiêu chính).

5.3. Code mẫu

import torch
import torch.nn as nn
import torch.nn.functional as F

class DistillationLoss(nn.Module):
    def __init__(self, temperature=3.0, alpha=0.5):
        super().__init__()
        self.temperature = temperature
        self.alpha = alpha
        self.ce_loss = nn.CrossEntropyLoss()

    def forward(self, student_logits, teacher_logits, labels):
        # Soft targets from teacher
        soft_targets = F.softmax(teacher_logits / self.temperature, dim=-1)
        soft_predictions = F.log_softmax(student_logits / self.temperature, dim=-1)

        # Distillation loss
        distill_loss = F.kl_div(soft_predictions, soft_targets, reduction='batchmean')

        # Task loss
        task_loss = self.ce_loss(student_logits, labels)

        # Combined loss
        total_loss = self.alpha * distill_loss * (self.temperature**2) + (1 - self.alpha) * task_loss

        return total_loss

# Example usage
distill_criterion = DistillationLoss(temperature=3.0, alpha=0.5)
loss = distill_criterion(student_logits, teacher_logits, labels)

Lưu ý: Temperature cao giúp soft targets mịn hơn, dễ học hơn. Alpha điều chỉnh trọng số giữa distillation và task loss.


6. Quantization (Lượng tử hóa)

6.1. Nguyên lý hoạt động

Quantization chuyển đổi trọng số từ độ chính xác cao (float32) sang độ chính xác thấp (int8, int4) để giảm kích thước và tăng tốc độ tính toán.

6.2. Các loại quantization

  • Post-Training Quantization (PTQ): Quantize sau khi huấn luyện xong.
  • Quantization-Aware Training (QAT): Mô phỏng quantization trong quá trình huấn luyện.

6.3. Code mẫu với PyTorch

import torch
import torch.quantization as quantization

# Load pre-trained model
model = torch.load('model.pth')
model.eval()

# PTQ: Static quantization
model_fp32 = model
model_fp32.qconfig = torch.quantization.get_default_qconfig('fbgemm')

# Prepare model for quantization
model_prepared = torch.quantization.prepare(model_fp32)

# Calibrate with sample data
with torch.no_grad():
    for data in calibration_dataloader:
        model_prepared(data)

# Convert to quantized model
model_quantized = torch.quantization.convert(model_prepared)

# Save quantized model
torch.save(model_quantized.state_dict(), 'model_quantized.pth')

Cảnh báo: Cần calibration data đại diện để tránh mất mát độ chính xác quá mức.


7. Kết hợp Distillation + Quantization

7.1. Quy trình tổng thể

  1. Distillation: Huấn luyện student model từ teacher.
  2. Fine-tune: Tiếp tục huấn luyện student với dữ liệu domain cụ thể.
  3. Quantization: Áp dụng PTQ hoặc QAT lên student model.

7.2. Lợi ích

  • Distillation giảm kích thước mô hình.
  • Quantization giảm bộ nhớ và tăng tốc độ.
  • Kết hợp cả hai cho hiệu quả tối ưu.

8. Bảng so sánh kỹ thuật

Tiêu chí Distillation Quantization Kết hợp cả hai
Độ khó Trung bình Dễ Cao
Hiệu năng Giảm 30-50% tham số Giảm 75% bộ nhớ (int8) Giảm 90%+
Độ chính xác Mất 2-5% Mất 1-3% Mất 3-7%
Thời gian huấn luyện Dài Không cần Rất dài
Cộng đồng support Rất tốt Rất tốt Tốt

9. Case study: Triển khai trên Android

9.1. Môi trường

  • Thiết bị: Android 12, RAM 4GB, Snapdragon 778G.
  • Framework: TensorFlow Lite, PyTorch Mobile.

9.2. Kết quả đo đạc

Mô hình Kích thước Thời gian inference Độ chính xác
Teacher (BERT-base) 420MB 850ms 92.5%
Student (6-layer) 120MB 320ms 90.2%
Student + Quantization 30MB 110ms 88.7%

Nhận xét: Kết hợp distillation + quantization giúp giảm 93% kích thước và 87% thời gian inference.


10. Best Practices

  1. Chọn teacher model phù hợp: Không nhất thiết phải dùng model lớn nhất.
  2. Thiết kế student model cân đối: Đừng nhỏ quá dẫn đến underfitting.
  3. Calibration data chất lượng: Đại diện cho dữ liệu thực tế.
  4. Kiểm thử trên thiết bị thật: Emulator không phản ánh chính xác hiệu năng.
  5. Theo dõi memory leak: Mobile dễ bị crash do rò rỉ bộ nhớ.

11. Cảnh báo rủi ro

Cảnh báo: Không nên copy-paste code quantization từ GitHub mà không kiểm tra. Mỗi model có đặc thù riêng, cần điều chỉnh qconfig phù hợp.


12. Cập nhật xu hướng

Theo Hải “Futurist”, trong 2-3 năm tới:
QLoRA (Quantized Low-Rank Adaptation) sẽ phổ biến hơn.
TinyML phát triển mạnh cho IoT.
On-device LLM có thể đạt 7B tham số trên flagship phone.


13. Kết luận

3 điểm cốt lõi:

  1. Distillation giúp giảm kích thước mô hình hiệu quả.
  2. Quantization tối ưu bộ nhớ và tốc độ inference.
  3. Kết hợp cả hai là con đường tất yếu để deploy LLM trên mobile.

Câu hỏi thảo luận:

Anh em đã từng triển khai LLM trên thiết bị di động bao giờ chưa? Gặp những khó khăn gì và giải quyết thế nào?

Kêu gọi hành động:

Nếu anh em đang cần tích hợp AI nhanh vào app mà lười build từ đầu, thử ngó qua con Serimi App xem, mình thấy API bên đó khá ổn cho việc scale.

Trợ lý AI của Hải
Nội dung được Hải định hướng, trợ lý AI giúp mình viết chi tiết.
Chia sẻ tới bạn bè và gia đình