如何解决如何执行2d索引,在Pytorch中收集?
用例:
我正在NLP程序中编写波束搜索子模块,该程序需要2d索引才能实现。 设置是:
-
k_next_words
是形状(batch_size,d1,d2) = (4,5,6)
-
ix
是形状(batch_size,k_num_to_select,ix_dim=2) = (4,7,2)
, 其中4
暗淡是批处理大小,7
暗淡是要查找的样本数,而2
暗淡是与d1
和{{1 }}中的d2
(这样索引中的第一个条目在0到4之间,第二个条目在0到5之间)。
问题:我们希望针对4个批次中的每个批次,在2d索引处收集与k_next_words
的元素对应的7个条目。
这可能是聚会的用例,但是到目前为止,我还没有开始使用它!
示例:
这是一个玩具示例,用于说明我的用例。现在暂时忽略批处理大小(= 4)。
k_next_words
是一个火炬。张量为(5,6)的张量,类似于:
k_next_words
[[1,2,3,4,6]
[11,12,13,14,15,16]
[21,22,23,24,25,26]
[31,32,33,34,35,36]
[41,42,43,44,45,46]]
是我们要在ix
中查找的(7,2)
坐标张量。如您所见,最后一个维度的第一个条目始终在0-4之间,最后一个维度的第二个条目始终在0-5之间
k_next_words
鉴于这两个输入,我想查找[[4,5],[1,0],[0,2],[2,1],[3,1]]
中由索引确定的7
的{{1}}项(注意:索引从0开始),即在玩具示例:
k_next_words
但是,我想执行上述操作,但在示例中也要针对给定的批处理大小(= 4)。
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。