如何解决如何扩展该对象检测pytorch程序以检测和分类多个类?
我遵循了this教程来构建对象检测器,现在我正尝试对其进行扩展以识别更多的类。
此对象检测算法仅适用于一类(浣熊),但是可以用更大的数据集训练神经网络以识别更多的类。
它使用了预先训练的fast_rcnn_resnet50网络。
有人可以帮助我修改代码,以便区分多个类吗?
谢谢!
class RaccoonDataset(torch.utils.data.Dataset):
def __init__(self,root,data_file,transforms=None):
self.root = root
self.transforms = transforms
self.imgs = sorted(os.listdir(os.path.join(root,"images")))
self.path_to_data_file = data_file
def __getitem__(self,idx):
# load images and bounding boxes
img_path = os.path.join(self.root,"images",self.imgs[idx])
img = Image.open(img_path).convert("RGB")
box_list = parse_one_annot(self.path_to_data_file,self.imgs[idx])
boxes = torch.as_tensor(box_list,dtype=torch.float32)
num_objs = len(box_list)
# there is only one class
labels = torch.ones((num_objs,),dtype=torch.int64)
image_id = torch.tensor([idx])
area = (boxes[:,3] - boxes[:,1]) * (boxes[:,2] - boxes[:,0])
# suppose all instances are not crowd
iscrowd = torch.zeros((num_objs,dtype=torch.int64)
target = {}
target["boxes"] = boxes
target["labels"] = labels
target["image_id"] = image_id
target["area"] = area
target["iscrowd"] = iscrowd
if self.transforms is not None:
img,target = self.transforms(img,target)
return img,target
def __len__(self):
return len(self.imgs)
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。