Nearly constant training and validation accuracy

How to fix nearly constant training and validation accuracy

I'm new to PyTorch, so my question may be a bit naive. I'm training a pretrained VGG16 network on my dataset, which contains roughly 33,000 images split across 8 classes with labels [1,2,...,8], and the classes are imbalanced. My problem is that during training both the validation and training accuracy stay low and do not increase. Is there a problem with my code? If not, what do you suggest to improve the training?

import torch
import time
import torch.nn as nn
import numpy as np
from sklearn.model_selection import train_test_split
from torch.optim import Adam
import cv2
import torchvision.models as models
from classify_dataset import Classification_dataset
from torchvision import transforms

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
    transforms.RandomRotation(degrees=45),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])
dataset = Classification_dataset(root_dir=r'//home/arisa/Desktop/Hamid/IQA/Hamid_Dataset',csv_file=r'/home/arisa/Desktop/Hamid/IQA/new_label.csv',transform=transform)


target = dataset.labels - 1

train_indices,test_indices = train_test_split(np.arange(target.shape[0]),stratify=target)
test_dataset = torch.utils.data.Subset(dataset,indices=test_indices)
train_dataset = torch.utils.data.Subset(dataset,indices=train_indices)

class_sample_count = np.array([len(np.where(target[train_indices] == t)[0]) for t in np.unique(target)])
weight = 1. / class_sample_count
samples_weight = np.array([weight[t] for t in target[train_indices]])
samples_weight = torch.from_numpy(samples_weight)
samples_weight = samples_weight.double()
  
sampler = torch.utils.data.WeightedRandomSampler(samples_weight,len(samples_weight),replacement = True)


train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=64,sampler=sampler)
test_loader = torch.utils.data.DataLoader(test_dataset,shuffle=False)
model = models.vgg16(pretrained=True)
for param in model.parameters():
    param.requires_grad = False

num_ftrs = model.classifier[0].in_features
model.classifier = nn.Linear(num_ftrs,8)    


optimizer = Adam(model.parameters(),lr = 0.0001 )
criterion = nn.CrossEntropyLoss()
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,step_size=50,gamma=0.01)

path = '/home/arisa/Desktop/Hamid/IQA/'
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)
def train_model(model,train_loader,valid_loader,optimizer,criterion,scheduler=None,num_epochs=10 ):
        
    
        min_valid_loss = np.inf
        model.train()
        start = time.time()
        TrainLoss = []
        model = model.to(device)
        for epoch in range(num_epochs):
            total = 0
            correct = 0 
            train_loss = 0
            #lr_scheduler.step()
            print('Epoch {}/{}'.format(epoch+1,num_epochs))
            print('-' * 10)

            train_loss = 0.0
            for x,y in train_loader:
                x = x.to(device)
                #print(y.shape)
                y = y.view(y.shape[0],).to(device)
                y = y.to(device)
                y -= 1
                out = model(x)

                loss = criterion(out,y)
                optimizer.zero_grad()
                loss.backward()
                
                TrainLoss.append(loss.item()* y.shape[0])
                train_loss += loss.item() * y.shape[0]
                _,predicted = torch.max(out.data,1)
                total += y.size(0)
                correct += (predicted == y).sum().item()
                optimizer.step()
                lr_scheduler.step()
            accuracy = 100*correct/total
            valid_loss = 0.0
            val_loss = []
            model.eval()
            val_correct = 0
            val_total = 0

            with torch.no_grad():
                
                for x_val,y_val in test_loader:
                    x_val = x_val.to(device)
                    y_val = y_val.view(y_val.shape[0],).to(device)
                    y_val -= 1
                    target = model(x_val)
                    loss = criterion(target,y_val)
                    valid_loss += loss.item() * y_val.shape[0]

                    _,predicted = torch.max(target.data,1)
                    val_total += y_val.size(0)
                    val_correct += (predicted == y_val).sum().item()



                    val_loss.append(loss.item()* y_val.shape[0])
                val_acc = 100*val_correct / val_total




                print(f'Epoch {epoch + 1} \t\t Training Loss: {train_loss / len(train_loader)} \t\t Validation Loss: {valid_loss / len(test_loader)} \t\t Train Acc:{accuracy} \t\t Validation Acc:{val_acc}')
                if min_valid_loss > (valid_loss / len(test_loader)):
                    print(f'Validation Loss Decreased({min_valid_loss:.6f}--->{valid_loss / len(test_loader):.6f}) \t Saving The Model')
                    min_valid_loss = valid_loss / len(test_loader)
                    state = {'state_dict': model.state_dict(),'optimizer': optimizer.state_dict(),}
                    torch.save(state,'/home/arisa/Desktop/Hamid/IQA/checkpoint.t7')

        end = time.time()
        print('TRAIN TIME:')
        print('%.2gs'%(end-start))

train_model(model=model, train_loader=train_loader, optimizer=optimizer, criterion=criterion, valid_loader=test_loader, num_epochs=500)

Thanks in advance. Here are the results for the first 15 epochs:

Epoch 1/500
----------
Epoch 1          Training Loss: 205.63448420514916       Validation Loss: 233.89266112356475         Train Acc:39.36360386127994         Validation Acc:24.142040038131555
Epoch 2/500
----------
Epoch 2          Training Loss: 199.05699240435197       Validation Loss: 235.08799531243065         Train Acc:41.90998291820601         Validation Acc:24.27311725452812
Epoch 3/500
----------
Epoch 3          Training Loss: 199.15626737127448       Validation Loss: 236.00033430619672         Train Acc:41.1035633416756          Validation Acc:23.677311725452814
Epoch 4/500
----------
Epoch 4          Training Loss: 199.02581041173886       Validation Loss: 233.60767459869385         Train Acc:41.86628530568466         Validation Acc:24.606768350810295
Epoch 5/500
----------
Epoch 5          Training Loss: 198.61493769454472       Validation Loss: 233.7503859202067          Train Acc:41.53656695665991         Validation Acc:25.0
Epoch 6/500
----------
Epoch 6          Training Loss: 198.71323942956585       Validation Loss: 234.17176149830675         Train Acc:41.639852222619474        Validation Acc:25.369399428026693
Epoch 7/500
----------
Epoch 7          Training Loss: 199.9395153770592        Validation Loss: 234.1744423635078          Train Acc:40.98041552456998         Validation Acc:24.84509056244042
Epoch 8/500
----------
Epoch 8          Training Loss: 199.3533399020355        Validation Loss: 235.4645173188412          Train Acc:41.26643626107337         Validation Acc:24.165872259294567
Epoch 9/500
----------
Epoch 9          Training Loss: 199.6451746921249        Validation Loss: 233.33387595956975         Train Acc:40.96452548365312         Validation Acc:24.59485224022879
Epoch 10/500
----------
Epoch 10         Training Loss: 197.9305159737011        Validation Loss: 233.76405122063377         Train Acc:41.8782028363723          Validation Acc:24.6186844613918
Epoch 11/500
----------
Epoch 11         Training Loss: 199.33247244055502       Validation Loss: 234.41085289463854         Train Acc:41.59218209986891         Validation Acc:25.119161105815063
Epoch 12/500
----------
Epoch 12         Training Loss: 199.87399289874256       Validation Loss: 234.23621463775635         Train Acc:41.028085647320545        Validation Acc:24.49952335557674
Epoch 13/500
----------
Epoch 13         Training Loss: 198.85540591944292       Validation Loss: 234.33149099349976         Train Acc:41.206848607635166        Validation Acc:24.857006673021925
Epoch 14/500
----------
Epoch 14         Training Loss: 199.92641723337513       Validation Loss: 233.37722391070741         Train Acc:41.15520597465539         Validation Acc:24.988083889418494
Epoch 15/500
----------
Epoch 15         Training Loss: 197.82172771698328       Validation Loss: 234.4943131533536          Train Acc:41.69943987605768         Validation Acc:24.380362249761678

Solution

You froze your model:

for param in model.parameters():
    param.requires_grad = False

This basically says "do not compute gradients for any weight", which is equivalent to never updating the weights - hence no optimization.
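For fine-tuning, the usual pattern is to freeze only the pretrained backbone and leave the new classification head trainable, then hand the optimizer just the trainable parameters. A minimal sketch of that pattern, using a hypothetical `TinyNet` stand-in with the same features/classifier layout as VGG16 (not the real VGG16, to keep it small):

```python
import torch
import torch.nn as nn

class TinyNet(nn.Module):
    """Toy stand-in: a 'features' backbone plus a 'classifier' head."""
    def __init__(self):
        super().__init__()
        self.features = nn.Linear(4, 4)    # pretend pretrained backbone
        self.classifier = nn.Linear(4, 2)  # new head to fine-tune

    def forward(self, x):
        return self.classifier(self.features(x))

model = TinyNet()

# Freeze only the backbone; the classifier head stays trainable.
for param in model.features.parameters():
    param.requires_grad = False

# Give the optimizer only the parameters that can still learn.
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(trainable, lr=1e-4)

print(all(not p.requires_grad for p in model.features.parameters()))  # True
print(all(p.requires_grad for p in model.classifier.parameters()))    # True
```

Note the ordering subtlety in the original code: because `model.classifier = nn.Linear(num_ftrs, 8)` is assigned *after* the freeze loop, that new layer is created with `requires_grad=True`, so freezing in the right order (or explicitly, as above) avoids ambiguity about what is actually training.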


My problem was with model.train(). This statement should go inside the training loop, but in my case I placed it outside the loop, so once model.eval() was called the model stayed in eval mode for all later epochs.
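A minimal sketch of the fix: call model.train() at the top of every epoch and switch to model.eval() only for the validation pass, so the next epoch's model.train() restores training behavior (dropout, batch-norm updates). The tiny model and random data here are illustrative assumptions, not the original VGG16 pipeline:

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(0.5), nn.Linear(4, 2))
optimizer = torch.optim.Adam(model.parameters())
criterion = nn.CrossEntropyLoss()
x, y = torch.randn(8, 4), torch.randint(0, 2, (8,))

for epoch in range(2):
    model.train()               # inside the loop: re-enables dropout each epoch
    out = model(x)
    loss = criterion(out, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    model.eval()                # validation: deterministic forward passes
    with torch.no_grad():
        val_out = model(x)

print(model.training)  # False - last mode set was eval()
```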

