본문 바로가기
Deep Learning (AI)/Training Loss

[Pytorch] torch kld-loss (KD-loss) 코드

by 스프링섬머 2023. 7. 14.
728x90

일반적인 힌튼 KD loss 코드

import torch.nn as nn
import torch.nn.functional as F
KL_loss = nn.KLDivLoss(reduction='batchmean')(F.log_softmax(outputs/T, dim=1), 
                                              F.softmax(targets/T, dim=1))*(alpha*T*T)

outputs

  • student model output (feature -> classifer 결과)
  • shape : (batch size, class num)

targets 

  • teacher model outputs (feature -> classifier 결과)
  • shape : (batch size, class num)

T (Temperature)

  • 얼마나 지식을 부드럽게 증류(전달) 할 것인지 결정함.
  • T 값이 클수록 더 부드러워짐. (보통 T=5, alpha=0.2)
  • 간단히, kld loss의 계수를 결정한다고 보면 됨. 

Total loss = CEloss * 0.8 + Kld loss

 

torch 공식 홈페이지에 따라, log_softmax(outputs), softmax(targets)으로 설정,

reduction='batchmean' 은 batchsize만큼 나누어줌. 

 

KD (Knowledge Distillation)의 효과

  • classification에서 사용하는 one-hot label을 이용한 Cross entropy loss는 label noise가 발생함. 정답만 1로 설정되어 학습되기 때문. 따라서 비슷한 특징을 가지는 다른 카테고리라고 할지라도 그 둘을 완전히 분리하도록 학습됨. 다시말해, 학습 데이터에 오버피팅이 심함.
  • KD (Knowledge Distillation)는 미리 학습된 teacher의 예측분포를 T(Temperature)로 증류하기 때문에, label smoothing 효과를 줄 수 있다. 뿐만 아니라, 정답 외의 다른 카테고리의 예측에도 loss를 흘려주어 보다 강인한 학습효과가 있음. 
728x90