如何解决Pytorch 的 nn.TransformerEncoder "src_key_padding_mask" 未按预期运行
我正在使用 Pytorch 的 nn.TransformerEncoder
模块。我得到了(正常)形状(batch-size,seq-len,emb-dim
)的输入样本。一批中的所有样本都已零填充到该批中最大样本的大小。因此,我希望忽略所有零值的注意力。
文档说,要向 src_key_padding_mask
模块的 forward
函数添加参数 nn.TransformerEncoder
。这个掩码应该是一个形状为 (batch-size,seq-len
) 的张量,并且对于每个索引,要么是填充零的 True
,要么是其他任何东西的 False
。
我做到了:
. . .
def forward(self,x):
# x.size -> i.e.: (200,28,200)
mask = (x == 0).cuda().reshape(x.shape[0],x.shape[1])
# mask.size -> i.e.: (200,20)
x = self.embed(x.type(torch.LongTensor).to(device=device))
x = self.pe(x)
x = self.transformer_encoder(x,src_key_padding_mask=mask)
. . .
当我不设置 src_key_padding_mask
时,一切正常。但是我这样做时得到的错误如下:
File "/home/me/.conda/envs/py37/lib/python3.7/site-packages/torch/nn/functional.py",line 4282,in multi_head_attention_forward
assert key_padding_mask.size(0) == bsz
AssertionError
似乎是在比较掩码的第一个维度,即批量大小,与可能代表批量大小的 bsz
。但是为什么它会失败呢?非常感谢帮助!
解决方法
我遇到了同样的问题,这不是错误:pytorch's Transformer implementation 要求输入 FirstStep
为 x
,而您的输入似乎为 (seq-len x batch-size x emb-dim)
。
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。