728x90
일반적인 힌튼 KD loss 코드
import torch.nn as nn
import torch.nn.functional as F
KL_loss = nn.KLDivLoss(reduction='batchmean')(F.log_softmax(outputs/T, dim=1),
F.softmax(targets/T, dim=1))*(alpha*T*T)
outputs
- student model output (feature -> classifer 결과)
- shape : (batch size, class num)
targets
- teacher model outputs (feature -> classifier 결과)
- shape : (batch size, class num)
T (Temperature)
- 얼마나 지식을 부드럽게 증류(전달) 할 것인지 결정함.
- T 값이 클수록 더 부드러워짐. (보통 T=5, alpha=0.2)
- 간단히, kld loss의 계수를 결정한다고 보면 됨.
Total loss = CEloss * 0.8 + Kld loss
torch 공식 홈페이지에 따라, log_softmax(outputs), softmax(targets)으로 설정,
reduction='batchmean' 은 batchsize만큼 나누어줌.
KD (Knowledge Distillation)의 효과
- classification에서 사용하는 one-hot label을 이용한 Cross entropy loss는 label noise가 발생함. 정답만 1로 설정되어 학습되기 때문. 따라서 비슷한 특징을 가지는 다른 카테고리라고 할지라도 그 둘을 완전히 분리하도록 학습됨. 다시말해, 학습 데이터에 오버피팅이 심함.
- KD (Knowledge Distillation)는 미리 학습된 teacher의 예측분포를 T(Temperature)로 증류하기 때문에, label smoothing 효과를 줄 수 있다. 뿐만 아니라, 정답 외의 다른 카테고리의 예측에도 loss를 흘려주어 보다 강인한 학습효과가 있음.
728x90
'Deep Learning (AI) > Training Loss' 카테고리의 다른 글
[Pytorch] Cosin Similarity Loss +역전파 (4) | 2023.07.30 |
---|---|
[Pytorch] Negative Pair Loss 코드 (4) | 2023.07.21 |
[Pytorch] L1-loss, L2-loss, KLD-loss 역전파 (13) | 2023.07.20 |
[Pytorch] Softmax with CrossEntropyLoss 역전파 (3) | 2023.07.20 |