如何解决在pytorch中,自制数据集和测试数据集似乎耗尽了所有RAM
在pytorch中,自制数据集和测试数据集似乎耗尽了所有RAM
我是 pytorch 的新手,我在 MNIST 上的 pytorch 中编写了一个 ResNet 程序用于实验。
如果我使用如下的数据加载器,那就没问题了:
import torch as pt
from torch.utils.data import DataLoader,TensorDataset
import torchvision as ptv
mnist_train = ptv.datasets.MNIST(ROOT_DIR,train=True,transform=ptv.transforms.ToTensor(),download=False)
dl = pt.utils.data.DataLoader(dataset=mnist_train,batch_size=BATCH_SIZE,shuffle=True,drop_last=True)
如果我使用如下自制数据集在每次迭代时使用验证集,程序将耗尽我所有的 RAM。测试集不是在每次迭代中使用,而是在最后评估模型。
mnist_test = ptv.datasets.MNIST(ROOT_DIR,train=False,download=False)
M_TEST,PIC_H,PIC_W = mnist_test.data.shape
x_test = mnist_test.data.double() / 255.
y_test = mnist_test.targets
a = pt.randperm(M_TEST) # ATTENTION pt.randperm
x_test = x_test[a]
y_test = y_test[a]
VAL_RATE = 0.1
M_VAL = int(np.ceil(M_TEST * VAL_RATE))
M_TEST -= M_VAL
x_test,x_val = pt.split(x_test,(M_TEST,M_VAL))
y_test,y_val = pt.split(y_test,M_VAL))
x_test = x_test.view(-1,1,PIC_W).double()
x_val = x_val.view(-1,PIC_W).double()
dl_test = DataLoader(TensorDataset(x_test,y_test),batch_size=BATCH_SIZE)
def acc(ht,yt):
return (pt.argmax(ht,1) == yt.long()).double().mean()
# in iteration:
for epoch in range(N_EPOCHS):
for i,(bx,by) in enumerate(dl):
model.train(True)
optim.zero_grad()
bx = bx.view(-1,PIC_W).double()
ht = model(bx)
cost = criterion(ht,by)
cost.backward()
optim.step()
model.train(False)
accv = acc(ht,by)
ht_val = model(x_val)
val_cost = criterion(ht_val,y_val)
val_acc = acc(ht_val,y_val)
所以我怀疑只有 ptv.datasets.MNIST 和 pt.utils.data.DataLoader 可用,所以我在每次迭代时删除了我自制验证集的使用;并且移除后内存使用正常。但是即使我只使用 ptv.datasets.MNIST 和 pt.utils.data.DataLoader ,测试进度仍然耗尽了我所有的 RAM:
mnist_test = ptv.datasets.MNIST(ROOT_DIR,download=False)
dl_test = pt.utils.data.DataLoader(dataset=mnist_test,shuffle=False,drop_last=True)
test_cost_avg = 0.
test_acc_avg = 0.
GROUP = int(np.ceil(M_TEST / BATCH_SIZE / 10))
for i,by) in enumerate(dl_test):
bx = bx.view(-1,PIC_W).double()
ht = model(bx)
test_cost_avg += criterion(ht,by)
test_acc_avg += acc(ht,by)
if i % GROUP == 0:
print(f'Testing # {i + 1}')
if i % GROUP != 0:
print(f'Testing # {i + 1}')
test_cost_avg /= i + 1
test_acc_avg /= i + 1
print(f'Tested: cost = {test_cost_avg},acc = {test_acc_avg}')
print('Over')
请帮帮我。非常感谢!
更新:
我怀疑我的模型有问题,因为我在 pytorchvision 的 MNIST 自制数据集上有一个简单的 CNN 模型,没有这个 RAM 耗尽问题。所以我将我的模型粘贴到这个问题中,如下仅供参考:
def my_conv(in_side,in_ch,out_ch,kernel,stride,padding='same'):
if 'same' == padding:
ps = kernel - 1
padding = ps // 2
else:
padding = 0
print(padding) # tmp
return pt.nn.Conv2d(in_ch,kernel_size=kernel,stride=stride,padding=padding)
class MyResnetBlock(pt.nn.Module):
def __init__(self,residual,in_side,kernel=3,stride=1,**kwargs):
super().__init__(**kwargs)
self.residual = residual
self.in_side = in_side
self.in_ch = in_ch
self.out_ch = out_ch
self.kernel = kernel
self.stride = stride
self.conv1 = my_conv(in_side,stride)
self.bn1 = pt.nn.BatchNorm2d(out_ch)
self.relu1 = pt.nn.ReLU()
self.conv2 = my_conv(np.ceil(in_side / stride),1)
self.bn2 = pt.nn.BatchNorm2d(out_ch)
self.relu2 = pt.nn.ReLU()
if residual:
self.conv_down = my_conv(in_side,stride)
def forward(self,input):
x = input
x = self.conv1(x)
x = self.bn1(x)
x = self.relu1(x)
x = self.conv2(x)
x = self.bn2(x)
if self.residual:
res = self.conv_down(input)
else:
res = input
x += res
x = self.relu2(x)
return x
class MyResnetByPt(pt.nn.Module):
def __init__(self,blocks_spec_list,init_in_ch,init_out_ch,**kwargs):
super().__init__(**kwargs)
self.conv1 = my_conv(in_side,3,1)
in_ch = out_ch = init_out_ch
blocks = []
for block_id,n_blocks in enumerate(blocks_spec_list):
for layer_id in range(n_blocks):
if layer_id == 0:
if block_id != 0:
out_ch *= 2
block = MyResnetBlock(True,2)
in_ch = out_ch
in_side = int(np.ceil(in_side / 2))
else:
block = MyResnetBlock(False,1)
blocks.append(block)
self.blocks = pt.nn.Sequential(*blocks)
self.final_ch = out_ch
self.avg_pool = pt.nn.AvgPool2d(kernel_size=(in_side,in_side),stride=(1,1),padding=(0,0))
self.fc = pt.nn.Linear(out_ch,N_CLS)
def forward(self,input):
x = input
x = self.conv1(x)
x = self.blocks(x)
x = self.avg_pool(x)
x = x.view(-1,self.final_ch)
x = self.fc(x)
return x
model = MyResnetByPt([2,2,2],16)
model = model.double()
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。