How to understand label_smoothing in "The Annotated Transformer"
In "The Annotated Transformer", label smoothing is implemented as follows:
import torch
import torch.nn as nn
from torch.autograd import Variable


class LabelSmoothing(nn.Module):
    "Implement label smoothing."

    def __init__(self, size, padding_idx, smoothing=0.0):
        super(LabelSmoothing, self).__init__()
        # Sum the KL divergence over all elements
        # (reduction="sum" replaces the deprecated size_average=False).
        self.criterion = nn.KLDivLoss(reduction="sum")
        self.padding_idx = padding_idx
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.size = size
        self.true_dist = None

    def forward(self, x, target):
        # x: log-probabilities, shape (batch, size); target: class indices, shape (batch,)
        assert x.size(1) == self.size
        true_dist = x.data.clone()
        true_dist.fill_(self.smoothing / (self.size - 2))
        true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
        true_dist[:, self.padding_idx] = 0
        # Zero out entire rows whose target is the padding token.
        mask = torch.nonzero(target.data == self.padding_idx)
        if mask.dim() > 0:
            true_dist.index_fill_(0, mask.squeeze(), 0.0)
        self.true_dist = true_dist
        return self.criterion(x, Variable(true_dist, requires_grad=False))
In particular, why is the fill value computed like this:
true_dist.fill_(self.smoothing / (self.size - 2))
Where does the 2 in self.size - 2 come from? And why this line:
true_dist[:,self.padding_idx] = 0
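
One way to see where the 2 comes from is to redo the steps of forward() by hand for a single row. The sketch below is a plain-Python restatement, not the code from The Annotated Transformer; the helper name smoothed_row and the example values (size=5, padding_idx=0, smoothing=0.4) are assumptions chosen for illustration. After the fill, two entries are overwritten: the target class receives the confidence mass and the padding column is set to 0. The smoothing mass is therefore shared by the remaining size - 2 classes, which is what makes the row sum to 1.

```python
def smoothed_row(size, target, padding_idx, smoothing):
    """Build one row of the smoothed target distribution with plain lists.

    Hypothetical helper that mirrors the steps in LabelSmoothing.forward().
    """
    confidence = 1.0 - smoothing
    # Step 1: fill every entry with smoothing / (size - 2).
    row = [smoothing / (size - 2)] * size
    # Step 2: the target class is overwritten with the confidence mass.
    row[target] = confidence
    # Step 3: the padding class is overwritten with zero.
    row[padding_idx] = 0.0
    # Rows whose target is the padding token are zeroed entirely.
    if target == padding_idx:
        row = [0.0] * size
    return row


row = smoothed_row(size=5, target=2, padding_idx=0, smoothing=0.4)
# Entries 1, 3 and 4 (size - 2 = 3 of them) share the smoothing mass 0.4.
print(row)
# The total is confidence + smoothing, i.e. 1 up to floating-point rounding.
print(sum(row))
```

Dividing by size or size - 1 instead would leave the row summing to less than 1, because the target and padding entries do not keep their share of the fill value.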