如何在PyTorch中找到混淆矩阵并将其绘制为图像分类器

如何解决如何在PyTorch中找到混淆矩阵并将其绘制为图像分类器

基本上这是VGG-16模型,我已经进行了转移学习并且对模型进行了微调,我在2周前对模型进行了训练,发现测试和训练的准确性都很高,但现在我也需要模型的逐级准确性,我想找出混淆矩阵,也想绘制矩阵。培训代码:

# Training the model again from the last CNN Block to The End of the Network
dataset = 'C:\\Users\\Sara Latif Khan\\OneDrive\\Desktop\\FYP_\\Scene15\\15-Scene'
model = model.to(device)
optimizer = Adam(filter(lambda p: p.requires_grad,model.parameters()))

#Training Fixed Feature Extractor for 15 epochs
num_epochs = 5
batch_loss = 0
cum_epoch_loss = 0 #cumulative loss for each batch

for e in range(num_epochs):
    cum_epoch_loss = 0
  
    for batch,(images,labels) in enumerate(trainloader,1):
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        logps = model(images)
        loss = criterion(logps,labels)
        loss.backward()
        optimizer.step()
    
        batch_loss += loss.item()
        print(f'Epoch({e}/{num_epochs} : Batch number({batch}/{len(trainloader)}) : Batch loss : {loss.item()}')
        torch.save(model,dataset+'_model_'+str(e)+'.pt')
    
print(f'Training loss : {batch_loss/len(trainloader)}')

这是我用来根据测试加载器中的数据检查模型准确性的代码。

model. to('cpu')

model.eval()
with torch.no_grad():
    num_correct = 0
    total = 0
    
    #set_trace ()
    for batch,labels) in enumerate(testloader,1):
        
        logps = model(images)
        output = torch.exp(logps)
        
        pred = torch.argmax(output,1)
        total += labels.size(0)
        
        num_correct += (pred==labels).sum().item()
        print(f'Batch ({batch} / {len(testloader)})')
        
        # to check the accuracy of model on 5 batches
        # if batch == 5:
            # break
            
    print(f'Accuracy of the model on {total} test images: {num_correct * 100 / total }% ')  

接下来,我需要找到模型的类精度。我正在使用Jupyter Notebook。我是否应该重新加载保存的模型并找到cm或执行该操作的适当方法。

解决方法

您必须保存测试集的所有预测和目标。

predictions,targets = [],[]
for images,labels in testloader:
    logps = model(images)
    output = torch.exp(logps)
    pred = torch.argmax(output,1)

    # convert to numpy arrays
    pred = pred.detach().cpu().numpy()
    labels = labels.detach().cpu().numpy()
    
    for i in range(len(pred)):
        predictions.append(pred[i])
        targets.append(labels[i])

现在,您已存储了测试集的所有预测和实际目标。 下一步是创建混淆矩阵。我想我可以给我我经常使用的功能:

def create_confusion_matrix(y_true,y_pred,classes):
    """ creates and plots a confusion matrix given two list (targets and predictions)
    :param list y_true: list of all targets (in this case integers bc. they are indices)
    :param list y_pred: list of all predictions (in this case one-hot encoded)
    :param dict classes: a dictionary of the countries with they index representation
    """

    amount_classes = len(classes)

    confusion_matrix = np.zeros((amount_classes,amount_classes))
    for idx in range(len(y_true)):
        target = y_true[idx][0]

        output = y_pred[idx]
        output = list(output).index(max(output))

        confusion_matrix[target][output] += 1

    fig,ax = plt.subplots(1)

    ax.matshow(confusion_matrix)
    ax.set_xticks(np.arange(len(list(classes.keys()))))
    ax.set_yticks(np.arange(len(list(classes.keys()))))

    ax.set_xticklabels(list(classes.keys()))
    ax.set_yticklabels(list(classes.keys()))

    plt.setp(ax.get_xticklabels(),rotation=45,ha="left",rotation_mode="anchor")
    plt.setp(ax.get_yticklabels(),ha="right",rotation_mode="anchor")

    plt.show()

所以y_true是所有目标,y_pred是所有预测,classes是一个字典,将标签映射到实际的类名,例如:

classes = {"dog": [1,0],"cat": [0,1]}

然后只需致电:

create_confusion_matrix(targets,predictions,classes)

可能您将不得不对其稍作修改以适应您的代码,但我希望这对您有用。 :)

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐


依赖报错 idea导入项目后依赖报错,解决方案:https://blog.csdn.net/weixin_42420249/article/details/81191861 依赖版本报错:更换其他版本 无法下载依赖可参考:https://blog.csdn.net/weixin_42628809/a
错误1:代码生成器依赖和mybatis依赖冲突 启动项目时报错如下 2021-12-03 13:33:33.927 ERROR 7228 [ main] o.s.b.d.LoggingFailureAnalysisReporter : *************************** APPL
错误1:gradle项目控制台输出为乱码 # 解决方案:https://blog.csdn.net/weixin_43501566/article/details/112482302 # 在gradle-wrapper.properties 添加以下内容 org.gradle.jvmargs=-Df
错误还原:在查询的过程中,传入的workType为0时,该条件不起作用 <select id="xxx"> SELECT di.id, di.name, di.work_type, di.updated... <where> <if test=&qu
报错如下,gcc版本太低 ^ server.c:5346:31: 错误:‘struct redisServer’没有名为‘server_cpulist’的成员 redisSetCpuAffinity(server.server_cpulist); ^ server.c: 在函数‘hasActiveC
解决方案1 1、改项目中.idea/workspace.xml配置文件,增加dynamic.classpath参数 2、搜索PropertiesComponent,添加如下 <property name="dynamic.classpath" value="tru
删除根组件app.vue中的默认代码后报错:Module Error (from ./node_modules/eslint-loader/index.js): 解决方案:关闭ESlint代码检测,在项目根目录创建vue.config.js,在文件中添加 module.exports = { lin
查看spark默认的python版本 [root@master day27]# pyspark /home/software/spark-2.3.4-bin-hadoop2.7/conf/spark-env.sh: line 2: /usr/local/hadoop/bin/hadoop: No s
使用本地python环境可以成功执行 import pandas as pd import matplotlib.pyplot as plt # 设置字体 plt.rcParams['font.sans-serif'] = ['SimHei'] # 能正确显示负号 p
错误1:Request method ‘DELETE‘ not supported 错误还原:controller层有一个接口,访问该接口时报错:Request method ‘DELETE‘ not supported 错误原因:没有接收到前端传入的参数,修改为如下 参考 错误2:cannot r
错误1:启动docker镜像时报错:Error response from daemon: driver failed programming external connectivity on endpoint quirky_allen 解决方法:重启docker -> systemctl r
错误1:private field ‘xxx‘ is never assigned 按Altʾnter快捷键,选择第2项 参考:https://blog.csdn.net/shi_hong_fei_hei/article/details/88814070 错误2:启动时报错,不能找到主启动类 #
报错如下,通过源不能下载,最后警告pip需升级版本 Requirement already satisfied: pip in c:\users\ychen\appdata\local\programs\python\python310\lib\site-packages (22.0.4) Coll
错误1:maven打包报错 错误还原:使用maven打包项目时报错如下 [ERROR] Failed to execute goal org.apache.maven.plugins:maven-resources-plugin:3.2.0:resources (default-resources)
错误1:服务调用时报错 服务消费者模块assess通过openFeign调用服务提供者模块hires 如下为服务提供者模块hires的控制层接口 @RestController @RequestMapping("/hires") public class FeignControl
错误1:运行项目后报如下错误 解决方案 报错2:Failed to execute goal org.apache.maven.plugins:maven-compiler-plugin:3.8.1:compile (default-compile) on project sb 解决方案:在pom.
参考 错误原因 过滤器或拦截器在生效时,redisTemplate还没有注入 解决方案:在注入容器时就生效 @Component //项目运行时就注入Spring容器 public class RedisBean { @Resource private RedisTemplate<String
使用vite构建项目报错 C:\Users\ychen\work>npm init @vitejs/app @vitejs/create-app is deprecated, use npm init vite instead C:\Users\ychen\AppData\Local\npm-