How to fix the retain_graph problem with a GRU in PyTorch 1.6
I know that we need to specify retain_graph=True when calling loss.backward() if there are multiple networks with multiple loss functions, each optimizing its own network separately. But even with (or without) this parameter specified, I get errors. The following is an MWE to reproduce the issue (on PyTorch 1.6).
import torch
from torch import nn
from torch import optim

torch.autograd.set_detect_anomaly(True)

class GRU1(nn.Module):
    def __init__(self):
        super(GRU1, self).__init__()
        self.brnn = nn.GRU(input_size=2, bidirectional=True, num_layers=1, hidden_size=100)

    def forward(self, x):
        return self.brnn(x)

class GRU2(nn.Module):
    def __init__(self):
        super(GRU2, self).__init__()
        self.brnn = nn.GRU(input_size=200, bidirectional=True, num_layers=1, hidden_size=1)

    def forward(self, x):
        return self.brnn(x)

gru1 = GRU1()
gru2 = GRU2()

gru1_opt = optim.Adam(gru1.parameters())
gru2_opt = optim.Adam(gru2.parameters())

criterion = nn.MSELoss()

for i in range(100):
    gru1_opt.zero_grad()
    gru2_opt.zero_grad()
    vector = torch.randn((15, 100, 2))

    gru1_output, _ = gru1(vector)  # (15, 100, 200)
    loss_gru1 = criterion(gru1_output, torch.randn((15, 100, 200)))
    loss_gru1.backward(retain_graph=True)
    gru1_opt.step()

    gru1_output, _ = gru1(vector)  # (15, 100, 200)
    gru2_output, _ = gru2(gru1_output)  # (15, 100, 2)
    loss_gru2 = criterion(gru2_output, torch.randn((15, 100, 2)))
    loss_gru2.backward(retain_graph=True)
    gru2_opt.step()

    print(f"GRU1 loss: {loss_gru1.item()}, GRU2 loss: {loss_gru2.item()}")
The error with retain_graph set to True is
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [100, 300]], which is output 0 of TBackward, is at version 2; expected version 1 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!
and the error without the parameter is
RuntimeError: Trying to backward through the graph a second time, but the saved intermediate results have already been freed. Specify retain_graph=True when calling backward the first time.
which is expected.
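As a minimal illustration of this second error (independent of the GRUs above): calling backward twice on the same graph only succeeds if the first call retains it.

import torch

x = torch.randn(3, requires_grad=True)
y = (x * x).sum()               # the multiplication saves x for the backward pass

y.backward(retain_graph=True)   # keeps the graph and its saved tensors alive
y.backward()                    # succeeds; without retain_graph above, this raises the RuntimeError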
Please point out what changes need to be made in the above code for it to start training. Any help is appreciated.
Solution
In this case, one can detach the computation graph to exclude the parameters that don't need to be optimized. Here, the graph should be detached after the second forward pass through gru1, i.e.
....
gru1_opt.step()
gru1_output, _ = gru1(vector)
gru1_output = gru1_output.detach()
....
This way, you won't "try to backward through the graph a second time", as the error mentioned.
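Putting it together, here is a minimal sketch of the corrected training loop, assuming the same models, optimizers, criterion, and shapes as the MWE above. Note that once the graph is detached, retain_graph=True is no longer needed on either backward call.

for i in range(100):
    gru1_opt.zero_grad()
    gru2_opt.zero_grad()
    vector = torch.randn((15, 100, 2))

    # Step 1: optimize gru1 on its own loss; its graph can be freed afterwards.
    gru1_output, _ = gru1(vector)
    loss_gru1 = criterion(gru1_output, torch.randn((15, 100, 200)))
    loss_gru1.backward()
    gru1_opt.step()

    # Step 2: re-run gru1 with its updated weights, then cut the graph so
    # that loss_gru2 only flows into gru2's parameters.
    gru1_output, _ = gru1(vector)
    gru1_output = gru1_output.detach()
    gru2_output, _ = gru2(gru1_output)
    loss_gru2 = criterion(gru2_output, torch.randn((15, 100, 2)))
    loss_gru2.backward()
    gru2_opt.step()

    print(f"GRU1 loss: {loss_gru1.item()}, GRU2 loss: {loss_gru2.item()}")

If, by contrast, you wanted loss_gru2 to also update gru1 (a different design, not what the question asks for), you would keep the graph attached and call gru1_opt.step() only after both backward passes, so that no parameter is modified in place between a forward pass and its backward pass.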