如何解决如何获得模型预测的最后一个索引?
我是 PyTorch 的新手。我有一个变量 pred,它有一个张量列表。
print(pred)
output: [tensor([[176.64380,193.86154,273.84702,306.30405,0.83492,2.00000]])]
所以我想访问最后一个元素,即类。我首先将列表转换为张量。
x = torch.stack(pred)
output: tensor([[[176.64380,2.00000]]])
现在,我如何访问最后一个元素,或者有什么更好/更有效的方法来做到这一点?
编辑:为了进一步参考这里是执行分类任务的代码。
def classify_face(image):
device = torch.device("cpu")
img = process_image(image)
print('Image processed')
# img = image.unsqueeze_(0)
# img = image.float()
pred = model(img)[0]
# Apply NMS
pred = non_max_suppression(pred,0.4,0.5,classes = [0,1,2],agnostic = None )
if classify:
pred = apply_classifier(pred,modelc,img,im0s)
#print(pred)
model.eval()
model.cpu()
print(pred)
# output = non_max_suppression(output,classes = class_names,agnostic = False)
#_,predicted = torch.max(output[0],1)
#print(predicted.data[0],"predicted")
classification = torch.cat(pred)[:,-1]
index = int(classification)
print(names[index])
return names[index]
在预测期间,pred 由 x1
、y1
、x2
、y2
、conf
和 class
组成。
例如pred = [tensor([[176.64380,2.00000]])]
如果模型没有做出任何预测,那么 pred
就是空的。
例如pred = [tensor([],size=(0,6))]
目前,如果我的程序收到一个空张量并抛出错误,它就会停止预测:
Traceback (most recent call last):
File "WEBCAM_DETECT.py",line 168,in <module>
label = classify_face(frame)
File "WEBCAM_DETECT.py",line 150,in classify_face
index = int(classification)
ValueError: only one element tensors can be converted to Python scalars
编辑 1: 当我检查 pred 的长度时它似乎有效,但是当张量中有两行或更多行时我收到此错误。
[tensor([[212.38568,117.47020,339.35773,266.00513,0.74144,2.00000],[214.60651,118.50694,339.90192,265.91696,0.44277,0.00000]])]
#################
#################
Traceback (most recent call last):
File "WEBCAM_DETECT.py",line 172,line 154,in classify_face
index = int(classification)
ValueError: only one element tensors can be converted to Python scalars
如果在某一帧没有进行预测并继续下一帧,我如何让我的程序忽略?
解决方法
您可以选择第 3 轴上带有索引符号的最后一个元素,然后广播到一维张量:
x[:,:,-1].view(-1)
但是,我宁愿在 pred
上使用 torch.cat
,这样可以避免创建新轴:
torch.cat(pred)[:,-1]
编辑 - 您可以预先检查张量是否为空:
if len(pred) == 0:
return None
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。