如何解决使用torch.max
我正在逐步采样大小为batch
的{{1}}。
我还有一个长度为n的列表torch.Size([n,8])
,其中包含对该批处理中的每个条目均有效的索引元组。
例如valid_indices
可能看起来像这样:valid_indices[0]
,这表明应该将索引2和6从(0,1,3,4,5,7)
的第一个条目中沿暗号1排除。
特别是在使用batch
时需要排除这些值。
要排除的指数(如果有的话)在批次中可能因条目而异。
有什么想法吗?预先感谢。
解决方法
我认为你已经老了。
IndexError: too many indices for tensor of dimension 1
直接在张量上使用元组索引时出现错误。 至少那是我执行以下行时能够重现的错误
t[0][valid_idx0]
其中t是大小为(10,8)的随机张量,而valid_idx0是具有4个元素的元组。
但是,当您将元组转换为以下列表时,同一行就可以正常工作
t[0][list(valid_idx0)]
>>> tensor([0.1847,0.1028,0.7130,0.5093])
但是在将这些索引应用于2D张量时,情况有所不同,因为我们需要保留张量的结构以进行批处理。
因此,将索引转换为掩码数组是合理的。
假设我们手头有一个元组valid_indices
的列表。首先是将其转换为列表列表。
valid_idx_list = [list(tup) for tup in valid_indices]
第二件事是将它们转换为掩码数组。
masks = np.zeros((t.size()))
for i,indices in enumerate(valid_idx_list):
masks[i][indices] = 1
完成。现在,我们可以应用蒙版并在蒙版张量上使用torch.max。
torch.max(t*masks)
请亲切地看一下我用来重现该问题的colab笔记本。
https://colab.research.google.com/drive/1BhKKgxk3gRwUjM8ilmiqgFvo0sfXMGiK?usp=sharing
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。