Knowledge Distillation với PyTorch
2021-01-24 08:47:13
Knowledge Distillation (KD) là một kỹ thuật nén model (model compression) sao cho độ chính xác (hoặc một thước đo khác như mAP, và F1-score, ...) không thay đổi nhiều so với model gốc.

Random Forests meme

Bài toán thực tế

Nén model phân loại chữ viết tay sử dụng kỹ thuật KD, với bộ dữ liệu MNIST

Bước 1: Import các thư viện cần thiết

1
2
3
4
5
6
7
8
9
10
import torch
from torchvision.datasets import MNIST
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import numpy as np
from numpy import vstack
from sklearn.metrics import accuracy_score

Bước 2: Load dữ liệu

Ở bước này chúng ta sẽ load bộ dữ liệu MNIST,và biến đổi dữ liệu sao cho model có thể sử dụng được. Lưu ý rằng bộ dữ liệu này có sẵn trong thư viện torchvision rồi, nên chỉ cần load lên mà không cần vào trang chủ của bộ dữ liệu để tải nha.

1
2
3
4
5
6
7
8
9
10
11
12
13
# định nghĩa phép chuyển đổi: thay đổi kích thước ảnh thành 32x32, và chuyển đổi
# dữ liệu sang dạng tensor
transform = transforms.Compose([
transforms.Resize((32, 32)),
transforms.ToTensor(),
])
# load tập dữ liệu train và test, và áp dụng phép chuyển đổi cho tập dữ liệu
train_set = MNIST(root='tmp/', train=True, download=True, transform=transform)
test_set = MNIST(root='tmp/', train=False, download=True, transform=transform)

# tạo loader để truyền vào model dữ liệu theo batch
train_loader = torch.utils.data.DataLoader(train_set, batch_size=128, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=False)

Bước 3: Xây dựng các model

Chúng ta sẽ xây dựng 2 model ở bước này:

  • Model giáo viên (teacher): model gốc (LeNet5)
  • Model học sinh (student): model sau khi nén bằng cách giữ nguyên số lượng layer, nhưng giảm số lượng tham số (parameters)
Model giáo viên
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
class Teacher(nn.Module):
def __init__(self):
# Ref: https://pytorch.org/tutorials/beginner/blitz/neural_networks_tutorial.html
super(Teacher, self).__init__()
# 1 input image channel, 6 output channels, 3x3 square convolution
# kernel
self.conv1 = nn.Conv2d(1, 6, 3)
self.conv2 = nn.Conv2d(6, 16, 3)
# an affine operation: y = Wx + b
self.fc1 = nn.Linear(16 * 6 * 6, 120) # 6*6 from image dimension
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)

def forward(self, x):
# Max pooling over a (2, 2) window
x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
# If the size is a square you can only specify a single number
x = F.max_pool2d(F.relu(self.conv2(x)), 2)
x = x.view(-1, self.num_flat_features(x))
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x

def num_flat_features(self, x):
size = x.size()[1:] # all dimensions except the batch dimension
num_features = 1
for s in size:
num_features *= s
return num_features
Model học sinh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
class Student(nn.Module):
def __init__(self):
super(Student, self).__init__()
# 1 input image channel, 6 output channels, 3x3 square convolution
# kernel
self.conv1 = nn.Conv2d(1, 6, 3)
self.conv2 = nn.Conv2d(6, 12, 3)
# an affine operation: y = Wx + b
self.fc1 = nn.Linear(12 * 6 * 6, 90) # 6*6 from image dimension
self.fc2 = nn.Linear(90, 64)
self.fc3 = nn.Linear(64, 10)

def forward(self, x):
# Max pooling over a (2, 2) window
x = F.max_pool2d(F.relu(self.conv1(x)), 2)
x = F.max_pool2d(F.relu(self.conv2(x)), 2)
x = x.view(-1, self.num_flat_features(x))
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x

def num_flat_features(self, x):
size = x.size()[1:] # all dimensions except the batch dimension
num_features = 1
for s in size:
num_features *= s
return num_features

Bước 3: Train các model

Model giáo viên

Đầu tiên chúng ta sẽ định nghĩa hàm loss và optimizer cho model như sau:

1
2
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(teacher.parameters(), lr=0.001, momentum=0.9)

Sau đó train model 100 epochs;

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
for epoch in range(100):
running_loss = 0.0
for i, data in enumerate(train_loader, start=0):
# get the inputs; data is a list of [inputs, labels]
inputs, labels = data

inputs = inputs.to(device)
labels = labels.to(device)

# zero the parameter gradients
optimizer.zero_grad()

# forward + backward + optimize
outputs = teacher(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()

# print statistics
running_loss += loss.item()
if i % 100 == 99: # print every 100 mini-batches
print('[%d, %5d] loss: %.3f' %
(epoch + 1, i + 1, running_loss / 100))
running_loss = 0.0
print('Finished Training')
Model học sinh

Ngoài định nghĩa hàm loss và optimizer như model giáo viên, chúng ta sẽ khai báo thêm 2 tham số cố định $alpha$ và $temperature$

1
2
alpha = 0.1
temperature = 10

Và train model 50 epochs:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
for epoch in range(50):
running_loss = 0.0
for i, data in enumerate(train_loader, start=0):
# get the inputs; data is a list of [inputs, labels]
inputs, labels = data

inputs = inputs.to(device)
labels = labels.to(device)

with torch.no_grad():
teacher_pred = teacher(inputs)

# zero the parameter gradients
optimizer.zero_grad()

# forward + backward + optimize
student_pred = student(inputs)

student_loss = criterion(student_pred, labels)

distillation_loss = F.kl_div(
F.log_softmax(teacher_pred / temperature, dim=1),
F.softmax(student_pred / temperature, dim=1),
reduction='batchmean'
)
loss = alpha * student_loss + (1 - alpha) * distillation_loss

loss.backward()
optimizer.step()

# print statistics
running_loss += loss.item()
if i % 100 == 99: # print every 100 mini-batches
print('[{}, {}] loss: {}'.format(
epoch + 1, i + 1, running_loss / 100))
running_loss = 0.0
print('Finished Training')

Lưu ý một vài thay đổi so với model giáo viên:

Knowledge Distillation
Nguồn: https://intellabs.github.io/distiller/knowledge_distillation.html
  • model học sinh sử dụng thêm 1 loại loss nữa, đó là $distillation\ loss$, được tính bằng KL divergence giữa phân phối của model giáo viên và model học sinh.
  • Tham số $alpha$ dùng để đánh trọng số (weight) cho $student\ loss$ và $distillation\ loss$.
  • Tham số $temperature$ có tác dụng làm phân phối “soft” hơn. Tham số này càng tăng, thì phân phối càng “soft”, như hình dưới đây:

Smooth Distribution

Tại sao không train model học sinh từ đầu?

Chúng ta hoàn toàn có thể train model học sinh từ đầu, tuy nhiên sẽ khó đạt được kết quả như kỳ vọng. Kỹ thuật KD giúp model học sinh:

  • Generalize tốt hơn, do tập train của model giáo viên thông thường sẽ lớn hơn nhiều so với tập train của model học sinh
  • $distillation\ loss$ giảm thiểu ảnh hưởng của việc phân bố tập train của model học sinh khác nhiều so với phân bố thực tế

Tài liệu tham khảo

[1] https://towardsdatascience.com/knowledge-distillation-simplified-dd4973dbc764
[2] https://keras.io/examples/vision/knowledge_distillation/
[3] https://intellabs.github.io/distiller/knowledge_distillation.html
[4] https://arxiv.org/abs/1503.02531
[5] https://thenextweb.com/neural/2020/10/26/how-knowledge-distillation-compresses-neural-networks-syndication/