Tidymodels: slow hyperparameter tuning with a dataset containing many predictors


I am currently trying to fit a random forest model with hyperparameter tuning, using the tidymodels framework, on a data frame with 101,064 rows and 64 columns. I have a mix of categorical and continuous predictors, and my outcome variable is a categorical variable with 3 categories, so this is a multiclass classification problem.

The problem I am running into is that even with parallel processing, this process takes about 6 to 8 hours to complete on my machine. Since 101,064 rows is not a huge amount of data, I suspect I am not doing something correctly or efficiently (or both!). Unfortunately, I cannot share the exact dataset due to confidentiality, but the code below generates a very close replica of the original, from the number of levels in each categorical variable down to the number of NAs present in each column.

A few comments on the code below to explain why I set things up this way. First, I split the training and test sets by group ID rather than by row. The dataset is nested, with multiple rows corresponding to the same group ID. Ideally, I want a model that can learn patterns across group IDs. Therefore, there should be no group IDs shared between the training and test sets, and no group IDs shared between the analysis and assessment sets within the cross-validation folds.
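As an aside, recent versions of rsample ship a helper that performs exactly this kind of group-aware initial split. This is a minimal sketch, assuming rsample >= 1.0 (where `group_initial_split()` is available) and the `df` built later in this post:

```r
library(rsample)

# Group-aware split: all rows sharing a group_id land on the same side
set.seed(123)
split <- group_initial_split(df, group = group_id, prop = 0.9)
df_train <- training(split)
df_test  <- testing(split)

# Sanity check: no group_id should appear in both sets
intersect(unique(df_train$group_id), unique(df_test$group_id))
```

The manual `make_splits()` approach below is equivalent; the helper just removes the bookkeeping with `.row`.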

Second, I included step_unknown because random forests do not like NA values. I included step_novel as a safeguard in case future data contains categorical levels not seen in the current data. I am not sure when to use step_unknown versus step_novel, and I am not sure whether it is wise to use them together, so advice here would be appreciated. I used step_other and step_dummy to one-hot encode the categorical predictors. step_impute_median is included so the data contain no NAs, to keep the random forest from complaining. step_downsample is used to handle the class imbalance in the outcome variable; I chose downsampling partly to reduce the number of observations in the model-building step, but it does not seem to have reduced the training time.
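To clarify the distinction raised above: `step_unknown()` recodes NA values in a factor to an explicit "unknown" level at training time, while `step_novel()` reserves a placeholder level so that factor values first seen at predict time do not cause an error. A toy sketch (the column name `x` is hypothetical):

```r
library(recipes)

train <- tibble::tibble(x = factor(c("a", "b", NA)), y = 1:3)
new   <- tibble::tibble(x = factor(c("a", "c")), y = 4:5)  # "c" was never seen in training

rec <- recipe(y ~ x, data = train) %>%
  step_unknown(x) %>%  # NA -> "unknown" level
  step_novel(x) %>%    # adds a reserved "new" level for unseen values
  prep()

bake(rec, new_data = train)  # the NA row becomes "unknown"
bake(rec, new_data = new)    # the "c" row becomes "new"
```

Using both together is reasonable: they handle two different problems (missingness vs. unseen levels) and do not conflict.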

My questions are:

  1. What is causing the model tuning to take about 6 hours, and can I optimize it further? I am open to using dimensionality reduction, and would appreciate pointers to tutorials that include it as part of a supervised ML pipeline in the tidymodels framework.
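For completeness, dimensionality reduction can be added to the same recipe. This is a hedged sketch, not a recommendation: `step_pca()` applied to the numeric predictors of the `df_train` built below, with normalization first since PCA is scale-sensitive:

```r
library(recipes)

df_recipe_pca <- recipe(outcome_variable ~ ., data = df_train) %>%
  update_role(group_id, new_role = "ID") %>%
  step_impute_median(all_numeric_predictors()) %>%   # PCA cannot handle NAs
  step_normalize(all_numeric_predictors()) %>%       # center/scale before PCA
  step_pca(all_numeric_predictors(), threshold = 0.9) # keep components explaining 90% of variance
```

Whether this speeds things up depends on how many components survive; with a random forest the bigger wins are usually elsewhere (see the answer below on permutation importance).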

  2. Have I specified and used the recipe correctly? This is the part I am least sure about. I described above what I think each step is doing, but is that actually what it does? Is this the best approach? I am open to restructuring the recipe steps.

Any help would be greatly appreciated. This dataset is not large, so substantially reducing the training time would let me put the model into production.

I am running this code on my local machine, a MacBook Pro with a 2.4 GHz 8-core processor and 32 GB of RAM.

library(tidyverse)
library(tidymodels)
library(themis)
library(finetune)
library(doParallel)
library(parallel)
library(ranger)
library(future)
library(doFuture)


# Create Synthetic data that closely mimics actual dataset ----
## Categorical predictors
categorical_predictor1 <- rep(c("cat1","cat2","cat3","cat4","cat5"),times = c(43281,29088,9881,8874,9940))
categorical_predictor2 <- rep(c("cat1","cat2","cat3","cat4","cat5"),times = c(2522,21302,20955,36859,19426))
categorical_predictor3 <- rep(c("cat1","cat2"),times = c(15950,85114))
categorical_predictor4 <- rep(c("cat1","cat2","cat3","cat4","cat5","cat6","cat7"),times = c(52023,16666,13662,7045,2644,1798,7226))
categorical_predictor5 <- rep(c("cat1","cat2","cat3"),times = c(52613,14903,33548))
## NOTE: the level names (and one count each for predictors 6, 8 and 10) were garbled in
## the original post; they are reconstructed here so every vector has length 101064
categorical_predictor6 <- rep(c("cat1","cat2","cat3","cat4"),times = c(13662,18713,52023,16666))
categorical_predictor7 <- rep(c("cat1","cat2","cat3","cat4","cat5","cat6",NA),times = c(44210,11062,8846,4638,1778,4595,25935))
categorical_predictor8 <- rep(c("cat1","cat2","cat3","cat4",NA),times = c(11062,11011,44210,8846,25935))
categorical_predictor9 <- rep(c("cat1","cat2","cat3","cat4","cat5","cat6",NA),times = c(11649,10215,9783,7580,5649,30253,25935))
categorical_predictor10 <- rep(c("cat1","cat2","cat3","cat4",NA),times = c(12563,11649,23339,27578,25935))
categorical_predictor11 <- rep(c("cat1","cat2",NA),times = c(14037,61092,25935))
categorical_predictor12 <- rep(c("cat1","cat2","cat3",NA),times = c(15042,35676,23861,26485))


# Outcome variable
outcome_variable <- rep(c("cat1","cat2","cat3"),times = c(21375,49824,29865))

## Continuous Predictors: Values are not normalized
continuous_predictor1 <- runif(n = 101064,min = 0,max = 90)
continuous_predictor2 <- runif(n = 101064,max = 95.4)
continuous_predictor3 <- runif(n = 101064,max = 14.1515)
continuous_predictor4 <- runif(n = 101064,max = 85)
continuous_predictor5 <- runif(n = 101064,max = 71)
continuous_predictor6 <- runif(n = 101064,min = -236,max = 97)
continuous_predictor7 <- runif(n = 101064,min = -40,max = 84)
continuous_predictor8 <- runif(n = 101064,min = 2015,max = 2019)
continuous_predictor9 <- runif(n = 101064,max = 6)
continuous_predictor10 <- runif(n = 101064,min = 2,max = 26)
continuous_predictor11 <- runif(n = 101064,max = 26)
continuous_predictor12 <- runif(n = 101064,min = 0.1365,max = 0.4352)
continuous_predictor13 <- runif(n = 101064,min = 0.1282,max = 0.4860)
continuous_predictor14 <- runif(n = 101064,min = 0.1232,max = 0.4643)
continuous_predictor15 <- runif(n = 101064,max = 0.4885)
continuous_predictor16 <- runif(n = 101064,min = 107,max = 218.6)
continuous_predictor17 <- runif(n = 101064,min = 0.6667,max = 16.333)
continuous_predictor18 <- runif(n = 101064,min = 3.479,max = 7.177)
continuous_predictor19 <- runif(n = 101064,min = 0.8292,max = 3.3100)
continuous_predictor20 <- runif(n = 101064,min = 49.33,max = 101.70)
continuous_predictor21 <- runif(n = 101064,min = 0.07333,max = 0.42534)
continuous_predictor22 <- runif(n = 101064,min = 0.08727,max = 0.41762)
continuous_predictor23 <- runif(n = 101064,min = 0.1241,max = 0.4673)
continuous_predictor24 <- runif(n = 101064,min = 0.07483,max = 0.41192)
continuous_predictor25 <- runif(n = 101064,min = 446.1,max = 561.0)
continuous_predictor26 <- runif(n = 101064,min = 2.333,max = 24)
continuous_predictor27 <- runif(n = 101064,min = 14.52,max = 18.23)
continuous_predictor28 <- runif(n = 101064,min = 0.5463,max = 3.488)
continuous_predictor29 <- runif(n = 101064,min = 150.7,max = 251.9)
continuous_predictor30 <- runif(n = 101064,min = 0.1120,max = 0.4603)
continuous_predictor31 <- runif(n = 101064,min = 0.1231,max = 0.4766)
continuous_predictor32 <- runif(n = 101064,min = 0.1271,max = 0.4857)
continuous_predictor33 <- runif(n = 101064,min = 0.1152,max = 0.4613)
continuous_predictor34 <- runif(n = 101064,min = 238.6,max = 329.4)
continuous_predictor35 <- runif(n = 101064,min = 5.333,max = 19.667)
continuous_predictor36 <- runif(n = 101064,min = 7.815,max = 10.929)
continuous_predictor37 <- runif(n = 101064,min = 0.8323,max = 2.8035)
continuous_predictor38 <- runif(n = 101064,min = 140.9,max = 195.5)
continuous_predictor39 <- runif(n = 101064,min = 0.1098,max = 0.4581)
continuous_predictor40 <- runif(n = 101064,min = 0.08825,max = 0.41360)
continuous_predictor41 <- runif(n = 101064,min = 0.1209,max = 0.4510)
continuous_predictor42 <- runif(n = 101064,min = 0.1048,max = 0.4498)
continuous_predictor43 <- runif(n = 101064,min = 312.2,max = 382.2)
continuous_predictor44 <- runif(n = 101064,min = 2.667,max = 18)
continuous_predictor45 <- runif(n = 101064,min = 10.22,max = 12.49)
continuous_predictor46 <- runif(n = 101064,min = 1.077,max = 2.968)
continuous_predictor47 <- runif(n = 101064,min = 72.18,max = 155.71)

## Continuous Predictors: Values have NAs
continuous_predictor_withNA1 <- c(runif(n = 101064 - 26485,min = 1,max = 3),rep(NA,times = 26485))
continuous_predictor_withNA2 <- c(runif(n = 101064 - 26485,min = 1,max = 3),rep(NA,times = 26485)) # range lost in the original post; mirrors withNA1

## Group ID
set.seed(123)
group_id <- sample(c(1,2,3,4,5,6,7,9,10,11,13,14,16,17,18,19,20,21,22,24,25,26,27,28,29,30,31,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,64,65,66,67,68,69,70,71,72,73,74,75,76,77,78,79,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,97,98,99,100,101,102,103,104,105,107,109,110,111,112,125,126,161,162,163,164,165,178,179,180,184,185,186,187,188,189,197,198,199,209,210,211,212,213,214,231,232,233,234,239,240,250,251,252,255,256,257,258,259,260,261,508,509,510,602,721,730),size = 101064,replace = TRUE,prob = c(0.010300404,0.003661047,0.005758727,0.002849679,0.005976411,0.006738304,0.004957255,0.008727143,0.007757461,0.00530357,0.00867767,0.003839151,0.007836618,0.004531782,0.007678303,0.013150083,0.003364205,0.005194728,0.002750732,0.005778517,0.009825457,0.010488403,0.009399984,0.006105042,0.011101876,0.006490936,0.008459986,0.003918309,0.009083353,0.001583155,0.005382728,0.013832819,0.004828623,0.004670308,0.007213251,0.006570094,0.006035779,0.007322093,0.002077891,0.000979577,0.006926304,0.007124199,0.005521254,0.007618935,0.00335431,0.002968416,0.005442096,0.016069026,0.005174939,0.001820629,0.008578722,0.00213726,0.00142484,0.014644186,0.006688831,0.003799573,0.008430302,0.004581255,0.002552838,0.012833452,0.00620399,0.004729676,0.005639991,0.010824824,0.010735771,0.004343782,0.008934932,0.005679569,0.004096414,0.011141455,0.011853875,0.00354231,0.006312832,0.001553471,0.009162511,0.006550305,0.007688198,0.002354943,0.002730943,0.005085886,0.004808834,0.013634924,0.006233674,0.007915776,0.006431568,0.003957888,0.005422307,0.002394522,0.00865788,0.008093881,0.002592417,0.001157682,0.004897887,0.002364838,0.004749466,0.009795773,0.007054936,0.003601678,0.006362305,0.00848967,0.011448191,0.005224412,0.007282514,0.007242935,0.008074092,0.009686931,0.00670862,0.003571994,0.008717249,0.007806934,0.004135993,0.006253463,0.006302937,0.007846513,0.003680836,0.006095148,0.00264189,0.004838518,0.001454524,0.004571361,0.005926937,0.002236207,0.007361672,0.006332621,0.011952822,0.013852608,0.009775984,0.013733872,0.007143988,0.006827357,0.00425473,0.007094514,0.013308399,0.007480409,0.007737671,0.004551571,0.00744083,0.012576189,0.008796406,0.010884192,0.0063722,0.01006293))


## Join to make a dataframe
df <- tibble(group_id,categorical_predictor1,categorical_predictor2,categorical_predictor3,categorical_predictor4,categorical_predictor5,categorical_predictor6,categorical_predictor7,categorical_predictor8,categorical_predictor9,categorical_predictor10,categorical_predictor11,categorical_predictor12,continuous_predictor1,continuous_predictor2,continuous_predictor3,continuous_predictor4,continuous_predictor5,continuous_predictor6,continuous_predictor7,continuous_predictor8,continuous_predictor9,continuous_predictor10,continuous_predictor11,continuous_predictor12,continuous_predictor13,continuous_predictor14,continuous_predictor15,continuous_predictor16,continuous_predictor17,continuous_predictor18,continuous_predictor19,continuous_predictor20,continuous_predictor21,continuous_predictor22,continuous_predictor23,continuous_predictor24,continuous_predictor25,continuous_predictor26,continuous_predictor27,continuous_predictor28,continuous_predictor29,continuous_predictor30,continuous_predictor31,continuous_predictor32,continuous_predictor33,continuous_predictor34,continuous_predictor35,continuous_predictor36,continuous_predictor37,continuous_predictor38,continuous_predictor39,continuous_predictor40,continuous_predictor41,continuous_predictor42,continuous_predictor43,continuous_predictor44,continuous_predictor45,continuous_predictor46,continuous_predictor47,continuous_predictor_withNA1,continuous_predictor_withNA2,outcome_variable)

df <- df %>% 
  mutate_if(is.character,as.factor) %>% 
  mutate(.row = row_number())

# Split Data ----
## Split the data by group id; groups will not be split across training and testing sets
set.seed(123)
holdout_group_id <- sample(unique(df$group_id),size = 5)

indices <- list(
  analysis   = df %>% filter(!(group_id %in% holdout_group_id)) %>% pull(.row),
  assessment = df %>% filter(group_id %in% holdout_group_id) %>% pull(.row)
)

## Remove row column - no longer required
df <- df %>% 
  select(-.row)

split <- make_splits(indices,df)
df_train <- training(split)
df_test <- testing(split)

## Create Cross Validation Folds
set.seed(123)
folds <- group_vfold_cv(df_train,group = "group_id",v = 5)

# Create Recipe ----
## Define a recipe to be applied to the data
df_recipe <- recipe(outcome_variable ~ .,data = df_train) %>% 
  update_role(group_id,new_role = "ID") %>% 
  step_unknown(all_nominal_predictors()) %>% 
  step_novel(all_nominal_predictors()) %>% 
  step_other(all_nominal_predictors(),threshold = 0.1,other = "other_category") %>% 
  step_dummy(all_nominal_predictors()) %>% 
  step_impute_median(continuous_predictor_withNA1,continuous_predictor_withNA2) %>% 
  themis::step_downsample(all_outcomes(),skip = TRUE) 


# Define Model ----
## Initialise model with tuneable hyperparameters
rf_spec <- rand_forest(trees = tune(),mtry = tune()  ) %>% 
  set_engine("ranger",importance = "permutation") %>% 
  set_mode("classification")

# Define Workflow to connect Recipe and Model ----
rf_workflow <- workflow() %>% 
  add_recipe(df_recipe) %>% 
  add_model(rf_spec)

# Train and Tune Model ----
## Define a random grid for hyperparameters to vary over
set.seed(123)
rf_grid <- grid_latin_hypercube(
  trees(),
  mtry() %>% finalize(df_train %>% dplyr::select(-group_id,-outcome_variable)),
  size = 20
)

## Tune Model using Parallel Processing
all_cores <- parallel::detectCores(logical = FALSE) - 1
registerDoFuture() # Register the doFuture backend
plan(multisession, workers = all_cores) # a future plan is required; makeCluster() alone is never used by doFuture

set.seed(123)
rf_tuned <- rf_workflow %>% 
  tune_race_win_loss(
    resamples = folds,
    grid = rf_grid,
    control = control_race(save_pred = TRUE),
    metrics = metric_set(roc_auc, accuracy)
  )

Solution

I have a few ideas that may help.

  • I suggest starting without tuning, so you get a good sense of how long things take and what baseline metrics an untuned random forest gives you. You may have done this already, but tuning a random forest often does not buy much improvement anyway. Use fit(rf_workflow, df_train) so you know what you are working with before you tune.
  • You don't actually need step_dummy() with a random forest. It probably isn't slowing you down much, but there is no reason to include it.
  • You almost certainly do not want importance = "permutation" during resampling or tuning. You are not keeping those models for prediction anyway, and computing the importance scores takes much longer than the fit itself.
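Putting the three suggestions above together, a minimal baseline sketch: one untuned fit, with no step_dummy() and the ranger engine's importance left at its default, timed with system.time() (object names here are illustrative):

```r
library(tidymodels)

rf_baseline_spec <- rand_forest() %>%
  set_engine("ranger") %>%   # no importance = "permutation" during exploration
  set_mode("classification")

baseline_recipe <- recipe(outcome_variable ~ ., data = df_train) %>%
  update_role(group_id, new_role = "ID") %>%
  step_unknown(all_nominal_predictors()) %>%
  step_impute_median(all_numeric_predictors()) %>%
  themis::step_downsample(all_outcomes(), skip = TRUE)

rf_baseline_wf <- workflow() %>%
  add_recipe(baseline_recipe) %>%
  add_model(rf_baseline_spec)

# Time a single fit; multiply by folds x grid size to estimate a full tuning run
system.time(rf_baseline_fit <- fit(rf_baseline_wf, df_train))
```

If this single fit is fast, the per-fit cost of tuning is not the bottleneck and the permutation importance is the likely culprit.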

With step_dummy() and the importance scores removed, I can fit your model workflow to this example data in under a minute. Multiply that by 5 for folds and by 20 for your grid, and you are at roughly 100 minutes, without parallel processing or the racing methods (which will certainly help). I expect the importance scoring is the big problem, but you should be able to explore this a bit and find out.
