RuntimeError:向后Cudnn RNN只能在训练模式下调用

如何解决RuntimeError:向后Cudnn RNN只能在训练模式下调用

我第一次看到这个问题,在以前的Python项目中我从未遇到过这样的错误。这是我的训练代码:

def train(net,opt,criterion,ucf_train,batchsize,i):
    opt.zero_grad()
    total_loss = 0
    net=net.eval()
    net=net.train()
    for vid in range(i*batchsize,i*batchsize+batchsize,1):
    
        output=infer(net,ucf_train[vid])
        m=get_label_no(ucf_train[vid])
        m=m.cuda( )
        loss = criterion(output,m)
        loss.backward(retain_graph=True)
        total_loss += loss 
        opt.step()       #updates wghts and biases

    return total_loss/n_points

推断代码(网络,输入)

def infer(net,name):
    net.eval()
    hidden_0 = net.init_hidden()
    hidden_1 = net.init_hidden()
    hidden_2 = net.init_hidden()
    video_path = fetch_ucf_video(name)
    cap = cv2.VideoCapture(video_path)
    resize=(224,224)
    T=FrameCapture(video_path)
    print(T)
    lim=T-(T%20)-2
    i=0
    while(1):
      ret,frame2 = cap.read()
      frame2= cv2.resize(frame2,resize)
    #  print(type(frame2))
      if (i%20==0 and i<lim):
          input=normalize(frame2)     
          input=input.cuda()       
          output,hidden_0,hidden_1,hidden_2  = net(input,hidden_2)
      elif (i>=lim):
          break
      i=i+1 
    op=output  
    torch.cuda.empty_cache() 
    op=op.cuda() 
    return op 

我收到此错误,我在this之后尝试model.train(),其中net是我的模型:

 RuntimeError                              Traceback (most recent call last)
<ipython-input-62-42238f3f6877> in <module>()
----> 1 train(net1,1,0)

2 frames
/usr/local/lib/python3.6/dist-packages/torch/autograd/__init__.py in backward(tensors,grad_tensors,retain_graph,create_graph,grad_variables)
    125     Variable._execution_engine.run_backward(
    126         tensors,--> 127         allow_unreachable=True)  # allow_unreachable flag
    128 
    129 

RuntimeError: cudnn RNN backward can only be called in training mode

解决方法

您应该删除net.eval()之后的def infer(net,name):呼叫

需要删除它,因为您在训练代码中调用了此推断函数。在整个训练过程中,您的模型都必须处于训练模式。

在调用eval之后,您也永远不会将模型重新设置为可训练,因此这是所得到异常的根源。如果要在测试用例中使用此推断代码,则可以使用if覆盖该用例。

net.eval()赋值之后紧接的total_loss=0也没有用,因为在此之后立即调用net.train()。您也可以删除它,因为它会在下一行中和。

更新的代码

def train(net,opt,criterion,ucf_train,batchsize,i):
    opt.zero_grad()
    total_loss = 0
    net=net.train()
    for vid in range(i*batchsize,i*batchsize+batchsize,1):
        output=infer(net,ucf_train[vid])
        m=get_label_no(ucf_train[vid])
        m=m.cuda( )
        loss = criterion(output,m)
        loss.backward(retain_graph=True)
        total_loss += loss 
        opt.step()       #updates wghts and biases

    return total_loss/n_points

推断代码(净值,输入)

def infer(net,name,is_train=True):
    if not is_train:
        net.eval()
    hidden_0 = net.init_hidden()
    hidden_1 = net.init_hidden()
    hidden_2 = net.init_hidden()
    video_path = fetch_ucf_video(name)
    cap = cv2.VideoCapture(video_path)
    resize=(224,224)
    T=FrameCapture(video_path)
    print(T)
    lim=T-(T%20)-2
    i=0
    while(1):
      ret,frame2 = cap.read()
      frame2= cv2.resize(frame2,resize)
      #  print(type(frame2))
      if (i%20==0 and i<lim):
          input=normalize(frame2)     
          input=input.cuda()       
          output,hidden_0,hidden_1,hidden_2  = net(input,hidden_2)
      elif (i>=lim):
          break
      i=i+1 
    op=output  
    torch.cuda.empty_cache() 
    op=op.cuda() 
    return op 

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐


依赖报错 idea导入项目后依赖报错,解决方案:https://blog.csdn.net/weixin_42420249/article/details/81191861 依赖版本报错:更换其他版本 无法下载依赖可参考:https://blog.csdn.net/weixin_42628809/a
错误1:代码生成器依赖和mybatis依赖冲突 启动项目时报错如下 2021-12-03 13:33:33.927 ERROR 7228 [ main] o.s.b.d.LoggingFailureAnalysisReporter : *************************** APPL
错误1:gradle项目控制台输出为乱码 # 解决方案:https://blog.csdn.net/weixin_43501566/article/details/112482302 # 在gradle-wrapper.properties 添加以下内容 org.gradle.jvmargs=-Df
错误还原:在查询的过程中,传入的workType为0时,该条件不起作用 &lt;select id=&quot;xxx&quot;&gt; SELECT di.id, di.name, di.work_type, di.updated... &lt;where&gt; &lt;if test=&qu
报错如下,gcc版本太低 ^ server.c:5346:31: 错误:‘struct redisServer’没有名为‘server_cpulist’的成员 redisSetCpuAffinity(server.server_cpulist); ^ server.c: 在函数‘hasActiveC
解决方案1 1、改项目中.idea/workspace.xml配置文件,增加dynamic.classpath参数 2、搜索PropertiesComponent,添加如下 &lt;property name=&quot;dynamic.classpath&quot; value=&quot;tru
删除根组件app.vue中的默认代码后报错:Module Error (from ./node_modules/eslint-loader/index.js): 解决方案:关闭ESlint代码检测,在项目根目录创建vue.config.js,在文件中添加 module.exports = { lin
查看spark默认的python版本 [root@master day27]# pyspark /home/software/spark-2.3.4-bin-hadoop2.7/conf/spark-env.sh: line 2: /usr/local/hadoop/bin/hadoop: No s
使用本地python环境可以成功执行 import pandas as pd import matplotlib.pyplot as plt # 设置字体 plt.rcParams[&#39;font.sans-serif&#39;] = [&#39;SimHei&#39;] # 能正确显示负号 p
错误1:Request method ‘DELETE‘ not supported 错误还原:controller层有一个接口,访问该接口时报错:Request method ‘DELETE‘ not supported 错误原因:没有接收到前端传入的参数,修改为如下 参考 错误2:cannot r
错误1:启动docker镜像时报错:Error response from daemon: driver failed programming external connectivity on endpoint quirky_allen 解决方法:重启docker -&gt; systemctl r
错误1:private field ‘xxx‘ is never assigned 按Altʾnter快捷键,选择第2项 参考:https://blog.csdn.net/shi_hong_fei_hei/article/details/88814070 错误2:启动时报错,不能找到主启动类 #
报错如下,通过源不能下载,最后警告pip需升级版本 Requirement already satisfied: pip in c:\users\ychen\appdata\local\programs\python\python310\lib\site-packages (22.0.4) Coll
错误1:maven打包报错 错误还原:使用maven打包项目时报错如下 [ERROR] Failed to execute goal org.apache.maven.plugins:maven-resources-plugin:3.2.0:resources (default-resources)
错误1:服务调用时报错 服务消费者模块assess通过openFeign调用服务提供者模块hires 如下为服务提供者模块hires的控制层接口 @RestController @RequestMapping(&quot;/hires&quot;) public class FeignControl
错误1:运行项目后报如下错误 解决方案 报错2:Failed to execute goal org.apache.maven.plugins:maven-compiler-plugin:3.8.1:compile (default-compile) on project sb 解决方案:在pom.
参考 错误原因 过滤器或拦截器在生效时,redisTemplate还没有注入 解决方案:在注入容器时就生效 @Component //项目运行时就注入Spring容器 public class RedisBean { @Resource private RedisTemplate&lt;String
使用vite构建项目报错 C:\Users\ychen\work&gt;npm init @vitejs/app @vitejs/create-app is deprecated, use npm init vite instead C:\Users\ychen\AppData\Local\npm-