微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

如何在pytorch中使用索引张量索引中间维度?

如何解决如何在pytorch中使用索引张量索引中间维度?

如何索引具有n个维度的张量t和m个index张量,从而保留t的最后一个维度?对于尺寸m之前的所有尺寸,index张量的形状等于张量t。换句话说,我想索引张量的中间维度,同时保留选定索引的以下所有维度。

例如,假设我们有两个张量:

t = torch.randn([3,5,2]) * 10
index = torch.tensor([[1,3],[0,4],[3,2]]).long()

带有t:

tensor([[[ 15.2165,-7.9702],[  0.6646,5.2844],[-22.0657,-5.9876],[ -9.7319,11.7384],[  4.3985,-6.7058]],[[-15.6854,-11.9362],[ 11.3054,3.3068],[ -4.7756,-7.4524],[  5.0977,-17.3831],[  3.9152,-11.5047]],[[ -5.4265,-22.6456],[  1.6639,10.1483],[ 13.2129,3.7850],[  3.8543,-4.3496],[ -8.7577,-12.9722]]])

然后我想要的输出将具有(3,2,2)的形状,并且是:

tensor([[[  0.6646,11.7384]],[[  3.8543,3.7850]]])

一个示例是我有一个形状为t的张量(40,10,6,2)一个形状为(40,3)的索引张量。这应该查询张量t的维度3,并且预期的输出形状将为(40,3,2)

如何在不使用循环的情况下以通用方式实现这一目标?

解决方法

在这种情况下,您可以执行以下操作:

t[torch.arange(t.shape[0]).unsqueeze(1),index,...]

完整代码:

import torch

t = torch.tensor([[[ 15.2165,-7.9702],[  0.6646,5.2844],[-22.0657,-5.9876],[ -9.7319,11.7384],[  4.3985,-6.7058]],[[-15.6854,-11.9362],[ 11.3054,3.3068],[ -4.7756,-7.4524],[  5.0977,-17.3831],[  3.9152,-11.5047]],[[ -5.4265,-22.6456],[  1.6639,10.1483],[ 13.2129,3.7850],[  3.8543,-4.3496],[ -8.7577,-12.9722]]])

index = torch.tensor([[1,3],[0,4],[3,2]]).long()

output = t[torch.arange(t.shape[0]).unsqueeze(1),...]

# tensor([[[  0.6646,#          [ -9.7319,11.7384]],# 
#         [[-15.6854,#          [  3.9152,# 
#         [[  3.8543,#          [ 13.2129,3.7850]]])

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。