如何用pytorch实现图像的k折交叉验证,解决类别不平衡

在用深度学习做分类的时候,常常需要进行交叉验证,目前pytorch没有通用的一套代码来实现这个功能。可以借助 sklearn中的 StratifiedKFold,KFold来实现,其中StratifiedKFold可以根据类别的样本量,进行数据划分。以5折为例,它可以实现每个类别的样本都是4:1划分。

代码简单的示例如下:

from sklearn.model_selection import  StratifiedKFold
skf = StratifiedKFold(n_splits=5)
for i, (train_idx, val_idx) in enumerate(skf.split(imgs, labels)):
    trainset, valset = np.array(imgs)[[train_idx]],np.array(imgs)[[val_idx]]
    traintag, valtag = np.array(labels)[[train_idx]],np.array(labels)[[val_idx]]

以上示例是将所有imgs列表与对应的labels列表进行split,得到train_idx代表训练集的下标,val_idx代表验证集的下标。后续代码只需要将split完成的trainset与valset输入dataset即可。

接下来用我自己数据集的实例来完整地实现整个过程,即从读取数据,到开始训练。如果你的数据集存储方式和我不同,改一下数据读取代码即可。关键是如何获取到imgs和对应的labels。

我的数据存储方式是这样的(类别为文件夹名,属于该类别的图像在该文件夹下):

"""A generic data loader where the images are arranged in this way: ::

    root/dog/xxx.png
    root/dog/xxy.png
    root/dog/xxz.png

    root/cat/123.png
    root/cat/nsdf3.png
    root/cat/asd932_.png

 以下代码是获取imgs与labels的过程:

import os
import numpy as np

IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png')

def is_image_file(filename):
    return filename.lower().endswith(IMG_EXTENSIONS)

def find_classes(dir):
    classes = [d.name for d in os.scandir(dir) if d.is_dir()]
    classes.sort()
    class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
    return classes, class_to_idx

if __name__ == "__main__":
    dir = 'your root path'
    classes, class_to_idx = find_classes(dir)
    imgs = []
    labels = []
    for target_class in sorted(class_to_idx.keys()):
        class_index = class_to_idx[target_class]
        target_dir = os.path.join(dir, target_class)
        if not os.path.isdir(target_dir):
            continue
        for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
            for fname in sorted(fnames):
                path = os.path.join(root, fname)
                if is_image_file(path):
                    imgs.append(path)
                    labels.append(class_index)

上述代码只需要把dir改为自己的root路径即可。接下来对所有数据进行5折split。其中我自己写了MyDataset类,可以直接照搬用。

from sklearn.model_selection import  StratifiedKFold
    skf = StratifiedKFold(n_splits=5) #5折
    for i, (train_idx, val_idx) in enumerate(skf.split(imgs, labels)):
        trainset, valset = np.array(imgs)[[train_idx]],np.array(imgs)[[val_idx]]
        traintag, valtag = np.array(labels)[[train_idx]],np.array(labels)[[val_idx]]
        train_dataset = MyDataset(trainset, traintag, data_transforms['train'] )
        val_dataset = MyDataset(valset, valtag, data_transforms['val'])
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader


class MyDataset(Dataset):

    def __init__(self, imgs, labels, transform=None,target_transform=None):

        self.imgs = imgs
        self.labels = labels
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        path = self.imgs[idx]
        target = self.labels[idx]

        with open(path, 'rb') as f:
            img = Image.open(f)
            img = img.convert('RGB')

        if self.transform:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

有了数据集之后,就可以创建dataloader了,后面就是正常的训练代码:

from sklearn.model_selection import  StratifiedKFold
    skf = StratifiedKFold(n_splits=5) #5折
    for i, (train_idx, val_idx) in enumerate(skf.split(imgs, labels)):
        trainset, valset = np.array(imgs)[[train_idx]],np.array(imgs)[[val_idx]]
        traintag, valtag = np.array(labels)[[train_idx]],np.array(labels)[[val_idx]]
        train_dataset = MyDataset(trainset, traintag, data_transforms['train'] )
        val_dataset = MyDataset(valset, valtag, data_transforms['val'])
        train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size,
                                                  shuffle=True, num_workers=args.workers)
        test_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size,
                                                  shuffle=True, num_workers=args.workers)

        for epoch in range(args.epoch):
            train_acc, train_loss = train(train_dataloader, model, criterion, args)
            test_acc, tect_acc_top5, test_loss = validate(test_dataloader, model, criterion, args)

为了保证每次跑的时候分的数据都是一致的,注意shuffle=False(默认)

StratifiedKFold(n_splits=5,shuffle=False)

以上就是实现的基本代码,之所以在代码层面实现k折而不是在数据层面做,比如预先把数据等分为5份。是因为这个代码可以支持数据样本的随意增减,不需要人为地再去分数据,十分方便。 

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐


学习编程是顺着互联网的发展潮流,是一件好事。新手如何学习编程?其实不难,不过在学习编程之前你得先了解你的目的是什么?这个很重要,因为目的决定你的发展方向、决定你的发展速度。
IT行业是什么工作做什么?IT行业的工作有:产品策划类、页面设计类、前端与移动、开发与测试、营销推广类、数据运营类、运营维护类、游戏相关类等,根据不同的分类下面有细分了不同的岗位。
女生学Java好就业吗?女生适合学Java编程吗?目前有不少女生学习Java开发,但要结合自身的情况,先了解自己适不适合去学习Java,不要盲目的选择不适合自己的Java培训班进行学习。只要肯下功夫钻研,多看、多想、多练
Can’t connect to local MySQL server through socket \'/var/lib/mysql/mysql.sock问题 1.进入mysql路径
oracle基本命令 一、登录操作 1.管理员登录 # 管理员登录 sqlplus / as sysdba 2.普通用户登录
一、背景 因为项目中需要通北京网络,所以需要连vpn,但是服务器有时候会断掉,所以写个shell脚本每五分钟去判断是否连接,于是就有下面的shell脚本。
BETWEEN 操作符选取介于两个值之间的数据范围内的值。这些值可以是数值、文本或者日期。
假如你已经使用过苹果开发者中心上架app,你肯定知道在苹果开发者中心的web界面,无法直接提交ipa文件,而是需要使用第三方工具,将ipa文件上传到构建版本,开...
下面的 SQL 语句指定了两个别名,一个是 name 列的别名,一个是 country 列的别名。**提示:**如果列名称包含空格,要求使用双引号或方括号:
在使用H5混合开发的app打包后,需要将ipa文件上传到appstore进行发布,就需要去苹果开发者中心进行发布。​
+----+--------------+---------------------------+-------+---------+
数组的声明并不是声明一个个单独的变量,比如 number0、number1、...、number99,而是声明一个数组变量,比如 numbers,然后使用 nu...
第一步:到appuploader官网下载辅助工具和iCloud驱动,使用前面创建的AppID登录。
如需删除表中的列,请使用下面的语法(请注意,某些数据库系统不允许这种在数据库表中删除列的方式):
前不久在制作win11pe,制作了一版,1.26GB,太大了,不满意,想再裁剪下,发现这次dism mount正常,commit或discard巨慢,以前都很快...
赛门铁克各个版本概览:https://knowledge.broadcom.com/external/article?legacyId=tech163829
实测Python 3.6.6用pip 21.3.1,再高就报错了,Python 3.10.7用pip 22.3.1是可以的
Broadcom Corporation (博通公司,股票代号AVGO)是全球领先的有线和无线通信半导体公司。其产品实现向家庭、 办公室和移动环境以及在这些环境...
发现个问题,server2016上安装了c4d这些版本,低版本的正常显示窗格,但红色圈出的高版本c4d打开后不显示窗格,
TAT:https://cloud.tencent.com/document/product/1340