Model not updating when running the code inside a function with Flux

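So, I'm training a neural network on some data. I ran everything in the REPL and training was going in the right direction, but when I tried to wrap the code in a function(), my model(x) stopped updating.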

My code is as follows:

using Random                # provides Random.seed!
using Flux
using Flux.Optimise: update!
using Flux: normalise, onecold, onehotbatch, @epochs, throttle

Random.seed!(125);

ep_max  = 2;  # number of epochs
batch   = 100;   # batch size for training
lr      = 0.001  # learning rate
spt     = 0.01;   # Split ratio: define % to be used as Test data
opt     = ADAM(lr,(0.9,0.8)); # Optimizer
time_show = 5;
dat_groups = 1:10;
dat_num    = 100;
creating   = false;
reading    = true;

if creating
  data_creator(dat_groups,dat_num); # I create data and store it
end
if reading
  xtrain,ytrain = data_reader(dat_groups); #reads data
end

# batching data
datatrain,datatest = getdata(xtrain',ytrain',spt,batch); # DataLoader function in here                                          

xtr,ytr = recoverdata(datatrain); # recovering training data, to be used if needed
xts,yts = recoverdata(datatest);  # recovering test data, to be used if needed
m = layers(size(xtr,1),size(ytr,1)); # creates layers (6 layers, tanh)
ps = Flux.params(m);                 # collect the trainable parameters of m

trainmode!(m,true)

# Helper functions are defined before the training loop that uses them
loss(x,y)   = Flux.mse(m(x),y);   # mean squared error of model m on one batch

function accuracy(dataloader,model)
    acc = 0
    for (x,y) in dataloader
        println()
        mod_x = model(x); # model evaluation

        cpu_mod_x = cpu(mod_x);
        cpu_y     = cpu(y);
        one_cpu_mod_x = onecold(cpu_mod_x);
        one_cpu_y     = onecold(cpu_y)

        @show mod_x # HERE IS WHERE THINGS GO WRONG 
        acc += sum(one_cpu_mod_x .== one_cpu_y)*1 / size(x,2)
    end
    acc/length(dataloader);
end

evalcb = () -> @show(loss_all(datatrain,m))

for i = 1:ep_max # one epoch per iteration (a full pass over datatrain)
  println()
  println("**************")
  println(i)
  println()
  Flux.train!(loss,ps,datatrain,opt,cb = throttle(evalcb,time_show));
  println()
  @show accuracy(datatest,m)
end
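For reference, onecold with its default labels returns, for each column, the index of that column's largest entry, so the accuracy function above counts how often the model's largest output lands on the true class. A stdlib-only sketch of that per-batch computation follows; decode and batch_accuracy are hypothetical names for illustration, not Flux functions.

```julia
# onecold-style decoding (assumed default labels 1:n):
# index of the largest entry in each column, one column per sample.
decode(M) = [argmax(c) for c in eachcol(M)]

# Fraction of samples whose predicted class matches the true class.
batch_accuracy(yhat, y) = count(decode(yhat) .== decode(y)) / size(y, 2)

yhat = [0.1 0.7; 0.9 0.2]   # 2 classes x 2 samples (model output)
y    = [0.0 0.0; 1.0 1.0]   # one-hot targets: true class is 2 for both samples
println(batch_accuracy(yhat, y))   # prints 0.5
```

Because the result depends only on where each column's maximum sits, a model whose outputs never change gives the same accuracy on the same batch every epoch.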

For example, when I run this code for 2 epochs, I get:

julia> include("main.jl")
size of X data is :(1000,63)
size of Y data is :(1000,6)


[ Info: Batching data...
[ Info: splitting into 990.0,10.0

[ Info: Batching train data...
[ Info: Batching test data...
┌ Warning: Number of data points less than batchsize,decreasing the batchsize to 10
└ @ Flux.Data ~/.julia/packages/Flux/Fj3bt/src/data/dataloader.jl:64
[ Info: layers created....


**************
1

loss_all(datatrain,m) = 0.3405524244181338


mod_x = Float32[0.21500134 0.25191692 0.28280517 0.14269947 0.12386108 0.22535957 0.38209096 0.21966429 0.061912186 0.32045293; 0.30720016 0.36394575 0.23585278 0.019663436 0.033996515 0.37338153 0.22447488 0.21927631 0.22822481 0.124495685; 0.039474137 -0.13775912 -0.0623653 0.021980956 -0.028107032 -0.027529262 -0.06072978 -0.13554919 -0.04740917 -0.020533875; 0.05143341 0.13719048 0.08347133 0.008867923 0.09923494 0.058163155 0.13347353 0.14189252 0.001730077 0.14392109; 0.119510576 0.07049953 0.05730217 0.5498258 -0.33574563 0.32612923 0.3832937 -0.06748764 0.2360552 0.15549593; 0.33197474 0.16447222 0.27249426 -0.15527818 0.2785189 0.34654236 0.124443345 0.18982176 0.26248497 0.16329157]
acc += (sum(one_cpu_mod_x .== one_cpu_y) * 1) / size(x,2) = 0.2
accuracy(datatest,m) = 0.2

**************
2

loss_all(datatrain,m) = 0.22800623235686412


mod_x = Float32[0.45594802 0.4247107 0.42235023 0.27602637 0.38965002 0.37095627 0.46256495 0.4262407 0.25281948 0.53176546; 0.31144667 0.32339665 0.24235432 0.16192524 0.18050455 0.41499415 0.21660031 0.38733715 0.43207392 0.23064193; 0.16366918 0.008529371 0.036853492 0.018185081 0.057695292 0.12094624 0.07630184 -0.011614937 0.012737181 0.173724; 0.14474952 0.19103736 0.1090886 0.08852501 0.14772236 0.10033486 0.12594518 0.16158527 0.08090371 0.16053662; 0.32059592 0.13817215 0.25556487 0.3619385 0.1361927 0.34184596 0.42664242 0.20382118 0.15213369 0.428005; 0.38914698 0.34429085 0.43361163 0.1414494 0.38538712 0.49637955 0.32894653 0.38855922 0.5757681 0.2794177]
acc += (sum(one_cpu_mod_x .== one_cpu_y) * 1) / size(x,2) = 0.3
accuracy(datatest,m) = 0.3

model(x) clearly changes (updates), and the cost function does converge. However, if I put all of the code above (main.jl) inside a function like this:

function all_the_code()
    ....
    ...
    for i = 1:ep_max # one epoch per iteration (a full pass over datatrain)
      println()
      println("**************")
      println(i)
      println()
      Flux.train!(loss,ps,datatrain,opt,cb = throttle(evalcb,time_show));
      println()
      @show accuracy(datatest,m)
    end
    return
end

I get:

julia> all_the_code()
size of X data is :(1000,63)
...
loss_all(datatrain,m) = 0.36721910246630546


mod_x = Float32[-0.23175366 -0.0057259724 0.082216755 -0.028256565 0.0046515726 5.131215f-5 0.24094917 0.2069467 -0.12277043 0.25271556; 0.3987115 0.698753 0.41566908 0.04473591 -0.25956866 0.115343675 0.5015237 -0.32195306 0.14039147 -0.6866529; -0.061519355 -0.40676624 -0.23660009 -0.04070711 0.07426359 -0.058668774 -0.32793295 0.096719205 -0.103397384 0.3026058; -0.19165385 0.047038484 0.08517692 -0.10537018 0.107321024 -0.033743735 0.14160846 0.16544174 -0.061927572 0.08228015; 0.080324054 0.095603675 -0.37848675 0.8972529 -0.84439826 0.3704224 0.25929335 -0.44578144 0.2532668 0.09203011; 0.2906073 0.16897422 0.2946054 -0.3970689 0.16302669 0.034337603 0.072528295 -0.09895017 -0.008754879 -0.30081]
acc += (sum(one_cpu_mod_x .== one_cpu_y) * 1) / size(x,2) = ...
...
loss_all(datatrain,m) = 0.3670469086875309


mod_x = Float32[-0.23175366 -0.0057259724 0.082216755 -0.028256565 0.0046515726 5.131215f-5 0.24094917 0.2069467 -0.12277043 0.25271556; 0.3987115 0.698753 0.41566908 0.04473591 -0.25956866 0.115343675 0.5015237 -0.32195306 0.14039147 -0.6866529; -0.061519355 -0.40676624 -0.23660009 -0.04070711 0.07426359 -0.058668774 -0.32793295 0.096719205 -0.103397384 0.3026058; -0.19165385 0.047038484 0.08517692 -0.10537018 0.107321024 -0.033743735 0.14160846 0.16544174 -0.061927572 0.08228015; 0.080324054 0.095603675 -0.37848675 0.8972529 -0.84439826 0.3704224 0.25929335 -0.44578144 0.2532668 0.09203011; 0.2906073 0.16897422 0.2946054 -0.3970689 0.16302669 0.034337603 0.072528295 -0.09895017 -0.008754879 -0.30081]
acc += (sum(one_cpu_mod_x .== one_cpu_y) * 1) / size(x,2) = ...
accuracy(datatest,m) = 0.2

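So, as you can see, the model does not update: mod_x is identical in both epochs and the cost function stays at about 0.367, no matter how many iterations I run. What is going on?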
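The body of all_the_code() is elided in the post, so the cause cannot be confirmed from here, but one Julia scoping rule produces exactly this symptom and is worth checking. A function defined at global scope resolves free names such as `m` in global scope, even when it is called from inside another function that has its own local `m`. If `loss(x,y) = Flux.mse(m(x),y)` still lives at global scope while `m` and `ps = Flux.params(m)` are created inside all_the_code(), the loss would not involve the local model's parameters, so there would be nothing to update. The plain-Julia sketch below (no Flux; `loss` and `all_the_code_sketch` are hypothetical stand-ins, with an array playing the role of the model) shows the rule in isolation.

```julia
# Hypothetical stand-in for a model left over from running main.jl at top level.
m = [1.0]

# Defined at global scope, like loss(x,y) = Flux.mse(m(x),y):
# the name m inside it always refers to the GLOBAL m.
loss() = sum(m)

function all_the_code_sketch()
    # Assigning to m inside a function creates a NEW local variable;
    # the global m that loss() reads is left untouched.
    m = [5.0]
    m[1] += 1.0            # "train" the local model in place
    return loss()          # still evaluates the global model
end

println(all_the_code_sketch())   # prints 1.0, not 6.0
```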

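I don't want to keep working in the REPL. I always try to avoid global variables, so I need to move everything into a function(), but I'm completely in the dark as to why this happens.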
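Keeping everything inside a function is a reasonable goal. One way to make it work, sketched here as an assumption rather than a confirmed fix for this post, is to have the loss refer to the model created inside the function: either define it as a closure after `m` exists (for example, `loss(x,y) = Flux.mse(m(x),y)` written inside all_the_code() after `m = layers(...)`) or pass the model as an explicit argument. A minimal plain-Julia sketch of the closure variant, where `all_the_code_fixed` is a hypothetical name and an array stands in for the model:

```julia
function all_the_code_fixed()
    m = [5.0]              # local stand-in for the model, e.g. m = layers(...)
    loss() = sum(m)        # closure defined AFTER m, so it captures this local m
    m[1] += 1.0            # in-place update, as an optimizer updates weights
    return loss()          # sees the updated local model
end

println(all_the_code_fixed())    # prints 6.0
```

Because the closure captures the same local binding whose contents the optimizer mutates, every call to `loss` sees the current weights.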

PS: the data are the same in both experiments.