import torch
import torch.nn as nn
class LabelSmoothingLoss(nn.Module):
    """Cross-entropy loss computed against label-smoothed targets.

    The target class receives probability ``1 - smoothing`` and the
    remaining ``smoothing`` mass is split evenly over the other
    ``classes - 1`` classes (i.e. ``smoothing / (classes - 1)`` each).

    Args:
        classes: Number of classes ``K``; must be at least 2 by the time
            ``forward`` is called.
        smoothing: Smoothing factor ``alpha``. ``0.0`` reduces to ordinary
            cross-entropy with one-hot targets.
        dim: Dimension of ``pred`` that indexes the classes.
    """

    def __init__(self, classes, smoothing=0.0, dim=-1):
        super().__init__()
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.cls = classes
        self.dim = dim

    def forward(self, pred, target):
        """Return the mean label-smoothed cross-entropy.

        Args:
            pred: Unnormalised scores (logits) with the class axis at
                ``self.dim``.
            target: Integer class indices shaped like ``pred`` with the
                class axis removed (``scatter_`` requires an int64 index).

        Returns:
            A scalar tensor: the loss averaged over every non-class position.

        Raises:
            ValueError: If the module was built with fewer than 2 classes,
                because the off-target mass ``smoothing / (classes - 1)``
                is undefined in that case.
        """
        if self.cls < 2:
            raise ValueError(
                "LabelSmoothingLoss requires at least 2 classes, "
                f"got {self.cls}"
            )

        # Log-softmax is the log-probability term of cross-entropy.
        log_probs = pred.log_softmax(dim=self.dim)

        # The smoothed target distribution is a constant w.r.t. the model,
        # so it is built without tracking gradients.
        with torch.no_grad():
            # Every class starts at alpha / (K - 1) ...
            true_dist = torch.full_like(
                log_probs, self.smoothing / (self.cls - 1)
            )
            # ... and the true class is overwritten with 1 - alpha. Scatter
            # along the class axis (self.dim) so that inputs with extra
            # leading/trailing dimensions are handled consistently with the
            # log-softmax and the reduction below.
            true_dist.scatter_(
                self.dim, target.unsqueeze(self.dim), self.confidence
            )

        # Cross-entropy: -sum_k q_k * log p_k, averaged over all positions.
        return torch.mean(torch.sum(-true_dist * log_probs, dim=self.dim))