[Pytorch] torch kld-loss (KD-loss) code

Standard Hinton KD loss code:

```python
import torch.nn as nn
import torch.nn.functional as F

# KL(teacher || student) on temperature-softened distributions; T*T rescales the soft-target gradients
KL_loss = nn.KLDivLoss(reduction='batchmean')(F.log_softmax(outputs/T, dim=1), F.softmax(targets/T, dim=1))*(alpha*T*T)
```

outputs: student model outputs (feature -> classifier result), shape: (batch size, class num)
targets: teacher model outputs (feature -> classifier result), shape: (batch size, class num)
T (Temperature): how much to soften the knowledge ..

2023. 7. 14.
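To make the arithmetic of the KD one-liner concrete, here is a dependency-free sketch in plain Python (no PyTorch) that mirrors it term by term. It assumes logits arrive as nested lists of shape (batch size, class num). The helper names `log_softmax` and `kd_loss` are hypothetical and chosen for illustration, and `alpha` is treated only as the plain multiplier it is in the original code. The sketch is meant for checking values by hand, not for training, since it has no autograd.

```python
import math


def log_softmax(logits, T):
    """log(softmax(logits / T)), computed with the max-shift trick for stability."""
    scaled = [z / T for z in logits]
    m = max(scaled)
    log_norm = m + math.log(sum(math.exp(z - m) for z in scaled))
    return [z - log_norm for z in scaled]


def kd_loss(student_logits, teacher_logits, T, alpha):
    """Hypothetical plain-Python mirror of
    nn.KLDivLoss(reduction='batchmean')(F.log_softmax(outputs/T, dim=1),
                                        F.softmax(targets/T, dim=1)) * (alpha*T*T)

    student_logits (outputs) and teacher_logits (targets) are nested lists
    with shape (batch size, class num).
    """
    total = 0.0
    for s_row, t_row in zip(student_logits, teacher_logits):
        log_q = log_softmax(s_row, T)     # student log-probabilities (KLDivLoss input)
        log_p = log_softmax(t_row, T)     # teacher log-probabilities
        p = [math.exp(v) for v in log_p]  # teacher soft targets (KLDivLoss target)
        # Pointwise p * (log p - log q), summed over classes = KL(teacher || student)
        total += sum(pi * (lpi - lqi) for pi, lpi, lqi in zip(p, log_p, log_q))
    batch_size = len(student_logits)
    # 'batchmean' divides the summed KL by the batch size; T*T restores the gradient scale
    return total / batch_size * (alpha * T * T)


# Usage: one sample, three classes
student = [[2.0, 0.5, -1.0]]
teacher = [[1.5, 1.0, -0.5]]
print(kd_loss(student, teacher, T=4.0, alpha=1.0))
```

Because `dim=1` normalizes each row independently, the sketch loops over samples. With identical student and teacher logits the loss is exactly 0, and raising T flattens both distributions before they are compared.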