如何解决RuntimeError:对象没有属性nms:
所以我正在按照本教程构建对象检测,但遇到一个错误,我的朋友都没有得到(请注意,代码正在MacOS中运行)。我附上了错误消息的屏幕截图,如果遇到任何帮助,我会不断得到它。
import torchvision
import cv2
import matplotlib.pyplot as plt
from PIL import Image
from torchvision import transforms as T
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
model.eval()
COCO_INSTANCE_CATEGORY_NAMES = [
'__background__','person','bicycle','car','motorcycle','airplane','bus','train','truck','boat','traffic light','fire hydrant','N/A','stop sign','parking meter','bench','bird','cat','dog','horse','sheep','cow','elephant','bear','zebra','giraffe','backpack','umbrella','handbag','tie','suitcase','frisbee','skis','snowboard','sports ball','kite','baseball bat','baseball glove','skateboard','surfboard','tennis racket','bottle','wine glass','cup','fork','knife','spoon','bowl','banana','apple','sandwich','orange','broccoli','carrot','hot dog','pizza','donut','cake','chair','couch','potted plant','bed','dining table','toilet','tv','laptop','mouse','remote','keyboard','cell phone','microwave','oven','toaster','sink','refrigerator','book','clock','vase','scissors','teddy bear','hair drier','toothbrush'
]
def get_prediction(img_path,threshold):
img = Image.open(img_path) # Load the image
transform = T.Compose([T.ToTensor()]) # Defing PyTorch Transform
img = transform(img) # Apply the transform to the image
pred = model([img]) # Pass the image to the model
pred_class = [COCO_INSTANCE_CATEGORY_NAMES[i] for i in list(pred[0]['labels'].numpy())] # Get the Prediction Score
pred_boxes = [[(i[0],i[1]),(i[2],i[3])] for i in list(pred[0]['boxes'].detach().numpy())] # Bounding boxes
pred_score = list(pred[0]['scores'].detach().numpy())
pred_t = [pred_score.index(x) for x in pred_score if x > threshold][
-1] # Get list of index with score greater than threshold.
pred_boxes = pred_boxes[:pred_t + 1]
pred_class = pred_class[:pred_t + 1]
return pred_boxes,pred_class
def object_detection_api(img_path,threshold=0.5,rect_th=3,text_size=3,text_th=3):
boxes,pred_cls = get_prediction(img_path,threshold) # Get predictions
img = cv2.imread(img_path) # Read image with cv2
img = cv2.cvtColor(img,cv2.COLOR_BGR2RGB) # Convert to RGB
for i in range(len(boxes)):
cv2.rectangle(img,boxes[i][0],boxes[i][1],color=(0,255,0),thickness=rect_th) # Draw Rectangle with the coordinates
cv2.putText(img,pred_cls[i],cv2.FONT_HERSHEY_SIMPLEX,text_size,(0,thickness=text_th) # Write the prediction class
plt.figure(figsize=(20,30)) # display the output image
plt.imshow(img)
plt.xticks([])
plt.yticks([])
plt.show()
object_detection_api('./people.jpg',threshold=0.8)
解决方法
我不知道为什么它不起作用,但是在更新了我的python版本之后,它终于起作用了(3.7.6-> 3.8.5)。
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。