如何解决如何通过 pytorch-lightning 正确使用 Tensorboard 的 TSNE?
我在 MNIST 上运行以下代码
也就是说,我从每个验证步骤(validation_step,即每个批次)返回
return {"val_loss": loss,"recon_batch": recon_batch,"label_batch": label_batch,"label_img": orig_batch.view(-1,1,28,28)}
然后使用
mat = torch.cat([o["recon_batch"] for o in outputs])
metadata = torch.cat([o["label_batch"] for o in outputs]).cpu()
label_img = torch.cat([o["label_img"] for o in outputs]).cpu()
tb.add_embedding(
mat=mat,metadata=metadata,label_img=label_img,global_step=self.current_epoch,)
并期望它能够工作,就像在 the doc 中一样。
似乎只显示了一个批次,在验证期间我得到的日志如下
验证:92%|█████████▏| 49/53 [00:01&lt;00:00, …](进度条输出被截断)
如何为所有时代获得 recon_batch
的有效 TSNE?
完整代码供参考:
def validation_step(self, batch, batch_idx):
    """Run one validation step.

    Computes the reconstruction loss for this batch, logs it to TensorBoard,
    periodically logs original/reconstruction image grids, and returns the
    per-batch tensors that ``validation_epoch_end`` concatenates to build a
    TSNE/embedding over the *whole* validation set.

    Args:
        batch: For the "toy" dataset, ``((orig, noisy), labels)``; for
            "mnist", ``(images, labels)``.
        batch_idx: Index of this batch within the validation loader.

    Returns:
        dict with keys ``val_loss``, ``recon_batch``, ``label_batch`` and
        ``label_img`` — the last three are required by the epoch-end hook
        so the embedding covers every batch, not just one.

    Raises:
        ValueError: If ``self._config.dataset`` is not "toy" or "mnist".
    """
    if self._config.dataset == "toy":
        (orig_batch, noisy_batch), label_batch = batch
        # TODO put in the noise here and not in the dataset?
    elif self._config.dataset == "mnist":
        orig_batch, label_batch = batch
        orig_batch = orig_batch.reshape(-1, 28 * 28)
        noisy_batch = orig_batch
    else:
        raise ValueError("invalid dataset")

    noisy_batch = noisy_batch.view(noisy_batch.size(0), -1)
    recon_batch, mu, logvar = self.forward(noisy_batch)
    loss = self._loss_function(
        recon_batch, orig_batch, logvar, reconstruction_function=self._recon_function
    )

    tb = self.logger.experiment
    tb.add_scalars("losses", {"val_loss": loss}, global_step=self.current_epoch)

    # Log image grids once per epoch, near the end of the loader.
    if batch_idx == len(self.val_dataloader()) - 2:
        # Min-max normalize to [0, 1] for display.
        # NOTE(review): these in-place ops mutate the tensors that are also
        # returned below — clone first if the raw values matter downstream.
        orig_batch -= orig_batch.min()
        orig_batch /= orig_batch.max()
        recon_batch -= recon_batch.min()
        recon_batch /= recon_batch.max()
        # BUG FIX: make_grid expects a (B, C, H, W) tensor; the original
        # view(-1, 28) produced a 2-D tensor and cannot render a grid.
        orig_grid = torchvision.utils.make_grid(orig_batch.view(-1, 1, 28, 28))
        val_recon_grid = torchvision.utils.make_grid(recon_batch.view(-1, 1, 28, 28))
        tb.add_image("original_val", orig_grid, global_step=self.current_epoch)
        tb.add_image(
            "reconstruction_val", val_recon_grid, global_step=self.current_epoch
        )

    # BUG FIX: the return statement was truncated to `{"val_loss": loss,28)}`.
    # Restored from the dict quoted earlier in the question — every key is
    # needed by validation_epoch_end's torch.cat(...) aggregation.
    return {
        "val_loss": loss,
        "recon_batch": recon_batch,
        "label_batch": label_batch,
        "label_img": orig_batch.view(-1, 1, 28, 28),
    }
def validation_epoch_end(self, outputs: List[Any]) -> None:
    """Aggregate all validation-step outputs and log one embedding projection.

    Concatenates the per-batch tensors returned by every ``validation_step``
    call so the TensorBoard projector (TSNE/PCA) sees the entire validation
    set instead of a single batch.

    Args:
        outputs: List of the dicts returned by ``validation_step``, one per
            validation batch.
    """
    # outputs[-1] is the LAST batch's dict (the original name
    # `first_batch_dict` was misleading).
    last_batch_dict = outputs[-1]
    self.log(name="val_epoch_end", value={"val_loss": last_batch_dict["val_loss"]})
    tb = self.logger.experiment

    # Concatenate across ALL batches; row counts of mat, metadata and
    # label_img must match for the projector to accept the embedding.
    mat = torch.cat([o["recon_batch"] for o in outputs])
    metadata = torch.cat([o["label_batch"] for o in outputs]).cpu()
    label_img = torch.cat([o["label_img"] for o in outputs]).cpu()

    # BUG FIX: the call was truncated to `tb.add_embedding(mat=mat,)`.
    # Restored from the call quoted earlier in the question.
    tb.add_embedding(
        mat=mat,
        metadata=metadata,
        label_img=label_img,
        global_step=self.current_epoch,
    )
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。