如何解决CNN训练的模型似乎不起作用
我已经训练了CNN模型,并且希望针对新数据运行训练后的模型。但是,似乎训练后的模型无法像训练中那样正确地预测计数。我感觉该模型未使用PTH文件。有人可以告诉我我在做什么错吗?
import argparse
import datetime
import glob
import os
import random
import shutil
import time
from os.path import join
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision.transforms import ToTensor
from tqdm import tqdm
import torch.optim as optim
from convnet3_eval import Convnet
from dataset2_eval import CellsDataset
parser = argparse.ArgumentParser('Predicting hits from pixels')
parser.add_argument('name',type=str,help='Name of experiment')
parser.add_argument('data_dir',help='Path to data directory containing images and gt.csv')
parser.add_argument('--weight_decay',type=float,default=0.0,help='Weight decay coefficient (something like 10^-5)')
parser.add_argument('--lr',default=0.0001,help='Learning rate')
args = parser.parse_args()
metadata = pd.read_csv(join(args.data_dir,'gt.csv'))
metadata.set_index('filename',inplace=True)
dataset = CellsDataset(args.data_dir,transform=ToTensor(),return_filenames=True)
dataset = DataLoader(dataset,num_workers=4,pin_memory=True)
model_path = '/base_model.pth'
model = Convnet()
optimizer = torch.optim.Adam(model.parameters(),lr=args.lr,weight_decay=args.weight_decay)
for images,paths in tqdm(dataset):
targets = torch.tensor([metadata['count'][os.path.split(path)[-1]] for path in paths]) # B
targets = targets.float()
# code to print training data to a csv file
filename=CellsDataset(args.data_dir,return_filenames=True)
output = model(images) # B x 1 x 9 x 9 (analogous to a heatmap)
preds = output.sum(dim=[1,2,3]) # predicted cell counts (vector of length B)
print(preds)
paths_test = np.array([paths])
names_preds = np.hstack(paths)
print(names_preds)
df=pd.DataFrame({'Image_Name':names_preds,'Target':targets.detach(),'Prediction':preds.detach()})
print(df)
# save image name,targets,and predictions
df.to_csv(r'model.csv',index=False,mode='a')
model.load_state_dict(torch.load(model_path))
model.eval()
解决方法
将最后两行移动到加载权重的位置
model.load_state_dict(torch.load(model_path))
model.eval()
在下面初始化模型的for循环上方。
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。