如何解决如何在Pytorch中使用张量板可视化所有预测?
我尝试通过预测使实际图像可视化,以弄清楚我的算法如何执行以及错误地预测了哪些标签。但是,当我设置步骤时,在张量板上可视化期间,它不会显示所有步骤。因此,它不会显示所有训练图像及其标签。相反,我只能从所有训练图像中看到几个例子。
writer = SummaryWriter(log_dir='graphs')
def matplotlib_imshow(img):
npimg = img.cpu().numpy()
npimg = np.transpose(npimg,(1,2,0))
plt.imshow((npimg * 255).astype(np.uint8))
def images_to_probs(net,images):
output = net(images)
_,preds_tensor = torch.max(output,1)
preds = np.squeeze(preds_tensor.cpu().numpy())
return preds,[F.softmax(el,dim=0)[i].item() for i,el in zip(preds,output)]
def plot_classes_preds(net,images,labels):
preds,probs = images_to_probs(net,images)
fig = plt.figure(figsize=(6,6))
for idx in np.arange(4):
ax = fig.add_subplot(1,4,idx+1,xticks=[],yticks=[])
matplotlib_imshow(images[idx])
ax.set_title("{0},{1:.1f}%\n(label: {2})".format(
classes[preds[idx]],probs[idx] * 100.0,classes[labels[idx]]),color=("green" if preds[idx]==labels[idx].item() else "red"))
return fig
以下是我的训练循环,其中我将全局步骤用作步骤。
for epoch in range(epochs):
epoch_start_time = time.time()
losses = []
total_batch_images = 0
batch_correct_pred = 0
step = 0
#save model
# if batch_accuracy>best_acc:
# best_acc = batch_accuracy
# checkpoint = {'state_dict': model.state_dict(),'acc' : batch_accuracy,'epoch' : epoch,'optimizer': optimizer.state_dict()}
# save_checkpoint(checkpoint)
model.train()
for batch_idx,(images,labels) in enumerate(train_loader):
# Get data to cuda if possible
images = images.to(device=device)
labels = labels.to(device=device)
# forward
scores = model(images)
loss = criterion(scores,labels)
losses.append(loss.item())
# backward
optimizer.zero_grad()
loss.backward()
# gradient descent or adam step
optimizer.step()
# visualizing Dataset images
# img_grid = torchvision.utils.make_grid(images)
# writer.add_image('Xray_images',img_grid,global_step = step)
# calculation running accuracy
model.eval()
_,predictions = scores.max(1)
num_correct = (predictions == labels).sum()
batch_correct_pred += float(num_correct)
total_batch_images += predictions.size(0)
writer.add_figure('predictions vs. actuals',plot_classes_preds(model,labels),global_step=step)
step += 1
我认为问题出在author.add_figure()的最后一行,其中定义了global_step。但是,在这方面的任何帮助将不胜感激。
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。