微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

一个人如何利用torch.topk提供的索引?

如何解决一个人如何利用torch.topk提供的索引?

假设我有一个形状为x的火炬张量[N,N_g,2]。可以将其视为N * N_g 2d向量。具体来说,x[i,j,:]是第j批中第i组的2d向量。

现在我正在尝试获取每个组中前5个长度矢量的坐标。所以我尝试了以下方法

(i)首先,我使用x_len = (x**2).sum(dim=2).sqrt()来计算它们的长度,得出x_len.shape==[N,N_g]

(ii)然后,我使用tk = x_len.topk(5)获取每个组中的前5个长度。

(iii)所需的输出将是形状为x_top5的张量[N,5,2]。自然,我想到了使用tk.indices来索引x以获得x_top5的索引。但是我失败了,因为似乎不支持这种索引。

我该怎么做?


一个最小的例子:

x = torch.randn(10,10,2) # N=10 is the batchsize,N_g=10 is the group size
x_len = (x**2).sum(dim=2).sqrt()
tk = x_len.topk(5)

x_top5 = x[tk.indices]

print(x_top5.shape)
# torch.Size([10,2])

但是,这使x_top5的形状为张[10,2],而不是所需的[10,2]

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。