如何解决设计转换器类/nlp/pytorch 时出现 NotImplementedError
我在使用 nn.Transformer 执行机器翻译任务时遇到了一个令人困惑的错误。下面是我的代码显示变压器的定义:
class MyTrans(nn.Module):
def __init__(
self,d_model,vocab_size,padding_key,nhead,num_encoder_layers,num_decoder_layers,dim_feedforward,dropout,max_len,device,):
super(MyTrans,self).__init__()
self.device = device
self.positionalencoding = PositionalEncoding(d_model,dropout=dropout,max_len=max_len)
self.embedding = nn.Embedding(vocab_size,d_model)
self.padding_key = padding_key
self.transfomer = nn.Transformer(
d_model,)
self.fc = nn.Linear(d_model,vocab_size)
self.dropout = nn.Dropout(dropout)
def padding_masks(self,src_tgt_memory):
mask = src_tgt_memory.transpose(0,1) == self.padding_key
return mask.to(self.device)
def square_masks(self,src_tgt):
mask = self.transformer.generate_square_subsequent_mask(src_tgt.size()[1])
return mask.to(self.device)
def forward(self,src,tgt):
S,N = src.size()
T,N = tgt.size()
# N = 64,src_seq = S,tgt_seq = T
embed_src = self.embedding(src)
new_src = self.positionalencoding(embed_src)
embed_tgt = self.embedding(tgt)
new_tgt = self.positionalencoding(embed_tgt)
src_key_padding_mask = self.padding_masks(src)
tgt_mask = self.square_masks(tgt)
output = self.transfomer(
src = new_src,tgt = new_tgt,src_key_padding_mask = src_key_padding_mask,tgt_mask = tgt_mask,)
output = self.fc(output)
output = F.log_softmax(output,dim=-1)
return output
然后当我使用这个模型运行训练过程时,它给出了 NotImplementedError:
/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in _call_impl(self,*input,**kwargs)
725 result = self._slow_forward(*input,**kwargs)
726 else:
--> 727 result = self.forward(*input,**kwargs)
728 for hook in itertools.chain(
729 _global_forward_hooks.values(),/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in _forward_unimplemented(self,*input)
173 registered hooks while the latter silently ignores them.
174 """
--> 175 raise NotImplementedError
176
177
NotImplementedError:
我知道当有缩进或错字问题时会出现这个错误,但我详细检查了没有这样的问题。这个错误困扰了我一整天,我仍然无法修复它。
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。