如何解决一个人如何利用torch.topk提供的索引?
假设我有一个形状为x
的火炬张量[N,N_g,2]
。可以将其视为N * N_g
2d向量。具体来说,x[i,j,:]
是第j
批中第i
组的2d向量。
现在我正在尝试获取每个组中前5个长度矢量的坐标。所以我尝试了以下方法:
(i)首先,我使用x_len = (x**2).sum(dim=2).sqrt()
来计算它们的长度,得出x_len.shape==[N,N_g]
。
(ii)然后,我使用tk = x_len.topk(5)
来获取每个组中的前5个长度。
(iii)所需的输出将是形状为x_top5
的张量[N,5,2]
。自然,我想到了使用tk.indices
来索引x
以获得x_top5
的索引。但是我失败了,因为似乎不支持这种索引。
我该怎么做?
一个最小的例子:
x = torch.randn(10,10,2) # N=10 is the batchsize,N_g=10 is the group size
x_len = (x**2).sum(dim=2).sqrt()
tk = x_len.topk(5)
x_top5 = x[tk.indices]
print(x_top5.shape)
# torch.Size([10,2])
但是,这使x_top5
的形状为张[10,2]
,而不是所需的[10,2]
。
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。