본문 바로가기
Deep Learning (AI)/Training Loss

[Pytorch] Negative Pair Loss 코드

by 스프링섬머 2023. 7. 21.
728x90

두 feature가 멀어지도록 학습하는 loss

 

Negative Pair Loss = max(margin - distance_metric(feature_1, feature_2), 0)   -------- (1)

 

where,

max(f(x),0)은 f(x) > 0 이면 backpropagation 수행,

margin은 feature_1과 feature_2를 얼마나 멀게 할 것인지,

distance_metric은 두 feature의 거리를 구하는 function.

 

즉, 두 feature의 거리를 구하고 margin을 더했을 시, 0보다 크면 역전파하여 멀어지게 한다.

이 term은 triplet loss에서 positive pair과 같이 사용되기도 한다.

 

그렇다면 어떻게 두 feature를 멀어지게 한다는 것일까? 

 

예를 들어, feature_1 = 0.5, feature_2 = 0.9라고 해보자.

distance_metric = MSE()를 사용하여 구하면 (0.5-0.9)^2 = 0.16이 된다. 이 값은 forward에서의 출력값이고, 딥러닝 모델에 학습되는 gradient를 구하게 되면(feature_1로의 역전파를 생각할 때), 2(0.5-0.9) = -0.8이 나온다.

이러한 -0.8의 미분값을 모델의 학습 weight의 업데이트에 이용한다면, 

W = W -learning_rate * (-0.8)이 되고, 이는 W가 더 커지는 방향으로 학습하게 된다. 

하지만, 우리는 맨 위의 수식 (1)에서 마이너스를 붙여주었기 때문에, 반대 방향으로 학습이 되어 W가 작아지는 방향으로 업데이트 된다. 즉, feature_1이 feature_2와 멀어지도록 된다.

 

다시 forward 값이였던 0.16을 살펴보자. 이 값을 수식 (1)에 넣게 된다면,

max(margin -0.16, 0)이 되고, margin=0.3이라고 가정했을 때는, max(0.14,0)이 되어 양수이므로 역전파하게 된다. 

 

학습이 진행되어 두 feature가 margin=0.3보다 크게 멀어졌다면 역전파하지 않게 된다. 

 

 

 

# pytorh 코드

import torch.nn as nn
import torch

class Negative_pair_loss(nn.Module):
    def __init__(self, margin=0.3):
        super(Negative_pair_loss, self).__init__()
        self.m = margin
        self.mse = nn.MSELoss(reduction='none')

    def forward(self, feat_1, feat_2):
        mse_result = self.mse(feat_1, feat_2)
        mse_result = self.m + -1*torch.mean(mse_result, dim=1)
        loss = torch.clamp(mse_result, min=0.0)
        loss = torch.mean(loss)

        return loss

 

 

 

728x90