如何解决在 R 中过滤火炬数据集
我正在努力学习《Deep Learning with PyTorch》这本书。我正在使用新的 R 包 torch
和 torchvision
。
在第 173 页的 7.2.1 节中,我只是不确定如何过滤此数据集以仅包含标签 1 和 3(对应于书中的 0 和 2)。
这是我的代码,我想知道如何按照书中的代码过滤 transformed_cifar10
。含义对其进行过滤,使 transformed_cifar10$y
标签仅包含 1 和 3。然后将 {1,3} 重新映射到 {1,2}。
library(dplyr)
library(torch)
library(torchvision)
data_path <- "./ch7/data" # need to change this?
train_transforms <- function (img) {
img %>%
transform_to_tensor() %>%
transform_normalize(mean = c(0.4915,0.4823,0.4468),std = c(0.2470,0.2435,0.2616))
}
transformed_cifar10 <- cifar10_dataset(data_path,train = TRUE,download = TRUE,transform = train_transforms)
这是书中的python代码:
# In[5]:
label_map = {0: 0,2: 1}
class_names = ['airplane','bird']
cifar2 = [(img,label_map[label])
for img,label in cifar10
if label in [0,2]]
起初我想尝试这样的事情,但显然它不起作用......
tensor_cifar10[tensor_cifar10$y == 1]
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。