如何解决在 R 中绘制图形
我使用这个网站作为参考 https://www.r-bloggers.com/2021/02/how-to-build-a-handwritten-digit-classifier-with-r-and-random-forests/
使用 R 和随机森林编写手写数字分类器。
是否可以构建代码末尾获得的 colMeans 的图? MNIST 训练和测试数据集(您可以在上面的链接中找到)没有任何列标题。 我是 R 新手,仍在学习。任何形式的帮助将不胜感激。
代码如下:
library(readr)
#loading the train and test sets of MNIST dataset
train_set <- read_csv("mnist_train.csv",col_names = FALSE)
test_set <- read_csv("mnist_test.csv",col_names = FALSE)
#extracting the labels
#converting digits to factors
train_labels <- as.factor(train_set[,1]$X1)
test_labels <- as.factor(test_set[,1]$X1)
#printing the first 10 labels
head(train_labels,10)
#printing number of records for each digit (0 to 9)
summary(train_labels)
#importing random forest
library(randomForest)
#training the model
rf <- randomForest(x = train_set,y = train_labels,xtest = test_set,ntree = 50)
rf
#1- error rate
#represents the accuracy
1 - mean(rf$err.rate)
#importing dplyr
library(dplyr)
#error rate for every digit
err_df <- as.data.frame(rf$err.rate)
err_df %>%
select(-"OOB") %>%
colMeans()
colMeans 的输出1
解决方法
我通过对训练集和测试集进行相当多的子集化来稍微修改您的代码以加快分析速度。您可以自由评论/删除相关行。请看看下面的代码,告诉我这是否是您要找的。p>
library(readr)
#importing dplyr
library(dplyr)
#>
#> Attaching package: 'dplyr'
#> The following objects are masked from 'package:stats':
#>
#> filter,lag
#> The following objects are masked from 'package:base':
#>
#> intersect,setdiff,setequal,union
#importing random forest
library(randomForest)
#> randomForest 4.6-14
#> Type rfNews() to see new features/changes/bug fixes.
#>
#> Attaching package: 'randomForest'
#> The following object is masked from 'package:dplyr':
#>
#> combine
library(ggplot2)
#>
#> Attaching package: 'ggplot2'
#> The following object is masked from 'package:randomForest':
#>
#> margin
#loading the train and test sets of MNIST dataset
train_set <- read_csv("~/Downloads/mnist_train.csv",col_names = FALSE)
#>
#> ── Column specification ────────────────────────────────────────────────────────
#> cols(
#> .default = col_double()
#> )
#> ℹ Use `spec()` for the full column specifications.
test_set <- read_csv("~/Downloads/mnist_test.csv",col_names = FALSE)
#>
#> ── Column specification ────────────────────────────────────────────────────────
#> cols(
#> .default = col_double()
#> )
#> ℹ Use `spec()` for the full column specifications.
#extracting the labels
#converting digits to factors
train_labels <- as.factor(train_set[,1]$X1)
test_labels <- as.factor(test_set[,1]$X1)
#printing the first 10 labels
head(train_labels,10)
#> [1] 5 0 4 1 9 2 1 3 1 4
#> Levels: 0 1 2 3 4 5 6 7 8 9
#printing number of records for each digit (0 to 9)
summary(train_labels)
#> 0 1 2 3 4 5 6 7 8 9
#> 5923 6742 5958 6131 5842 5421 5918 6265 5851 5949
# reducing size
train_set <- train_set[ 1:1000,]
train_labels <- train_labels[ 1:1000 ]
test_set <- test_set[ 1:100,]
test_labels <- test_labels[ 1:100 ]
#training the model
rf <- randomForest(x = train_set,y = train_labels,xtest = test_set,ntree = 50)
rf
#>
#> Call:
#> randomForest(x = train_set,ntree = 50)
#> Type of random forest: classification
#> Number of trees: 50
#> No. of variables tried at each split: 28
#>
#> OOB estimate of error rate: 11.6%
#> Confusion matrix:
#> 0 1 2 3 4 5 6 7 8 9 class.error
#> 0 96 0 0 0 0 0 1 0 0 0 0.01030928
#> 1 0 112 1 1 0 1 0 0 0 1 0.03448276
#> 2 2 6 82 0 2 0 1 4 2 0 0.17171717
#> 3 0 1 2 78 2 5 1 1 2 1 0.16129032
#> 4 0 0 1 0 94 1 2 1 1 5 0.10476190
#> 5 0 0 1 8 3 77 1 0 0 2 0.16304348
#> 6 1 0 1 0 2 2 86 1 1 0 0.08510638
#> 7 0 3 3 2 4 0 0 102 0 3 0.12820513
#> 8 0 1 1 3 1 6 1 1 71 2 0.18390805
#> 9 1 0 0 1 4 1 1 5 1 86 0.14000000
#1- error rate
#represents the accuracy
1 - mean(rf$err.rate)
#> [1] 0.8012579
#error rate for every digit
err_df <- as.data.frame(rf$err.rate)
mymeans <- err_df %>%
select(-"OOB") %>%
colMeans()
# I build a data.frame containing the indexes and the means
toplot <- data.frame( index = seq_len( length( mymeans ) ) - 1,col_means = mymeans )
# this is to plot via ggplot2
ggplot( toplot,aes( x = index,y = col_means ) ) +
geom_line() +
geom_point() +
scale_x_continuous(breaks = seq_len( length( mymeans ) ) - 1 )
由 reprex package (v0.3.0) 于 2021 年 2 月 16 日创建
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。