如何解决理解 PyTorch 的 RNN 实现
我想参考 Schmidt 的这篇论文,其中对 RNN 的一般描述:https://arxiv.org/pdf/1912.05911.pdf。所以,根据方程。论文(1)和(2)中,我们需要三个权重矩阵W_hh、W_xh和W_ho。但是,在打印简单RNN的参数数量时,我没有看到矩阵W_ho,我不明白(矩阵W_xh在打印输出中被称为W_ih):
将不胜感激!
解决方法
w_ho
将是从隐藏到输出的矩阵。在您的设置中,这很可能是 fc.weight
和 fc.bias
。您也可以通过检查参数数量或维度来验证这一点。您应该检查矩阵的维数而不是参数的数量来验证这一点。
更新:从 OP 的评论中,我了解到 OP 在理解 PyTorch 的 RNN 模块的输出方面存在问题。所以我在下面解释。
RNN 更新可以写成(没有偏差和非线性):
h(t,l) = h(t-1,l)Whh(l) + h(t,l-1)Wxh(l)
其中 t
表示时间,l
表示层。 h(.,0)
即在 l=0
处,h
与输入相同。
现在,RNN 模块实现这一点并输出最后一层的隐藏状态,即所有 t 的 h(t,L)
和每层的最后一个隐藏状态,即 h(N,1)
到 {{1} }(如h(N,L)
)。它没有实现上面链接的论文中提到的全连接输出层。
为什么只有这两个输出?
- 许多使用 RNN 的问题要么输出与序列相同的大小(例如 POS 标记),要么输出单个输出(例如分类)。对于前者,您可能每次都使用最后一个隐藏状态,而对于后者,您可能会使用最后一层(或可能所有层)的隐藏状态。
- 仅使用这些输出,就可以添加更多 RNN 层(使用输出)或继续处理序列(使用最后一个隐藏状态)。
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。