当我尝试训练时,Keras CNN模型引发内存问题

如何解决当我尝试训练时,Keras CNN模型引发内存问题

我是CNN的新手,我正在此dataset上使用Keras制作基本的猫对狗CNN模型,该模型包含12500张猫狗图像,每张图像总共25000张。 我当前处理数据的方法如下:

将所有图像转换为128x128大小->将它们转换为numpy数组->将所有图像转换为黑白图像->将它们除以255以进行归一化->使用数据增强->使用以下方法训练CNN他们

(如果我们使用彩色图像,则会出现内存问题)

这是我要训练的模特:

model = Sequential()

model.add(Conv2D(filters = 64,kernel_size = (5,5),padding = 'Same',activation ='relu',input_shape = (128,128,1)))
model.add(Conv2D(filters = 64,activation ='relu'))
model.add(MaxPool2D(pool_size=(2,2)))
model.add(Dropout(0.25))
          
model.add(Conv2D(filters = 128,kernel_size = (3,3),activation ='relu'))
model.add(Conv2D(filters = 128,2),strides=(2,2)))
model.add(Dropout(0.25))
          

model.add(Conv2D(filters = 32,kernel_size = (2,activation ='relu'))
model.add(Conv2D(filters = 32,2)))
model.add(Dropout(0.25))
          
model.add(Flatten())
model.add(Dense(512,activation = "relu"))
model.add(Dropout(0.5))
model.add(Dense(1,activation = "sigmoid"))
          
          
optimizer = RMSprop(lr=0.001,rho=0.9,epsilon=1e-08,decay=0.0)
model.compile(optimizer = optimizer,loss = "binary_crossentropy",metrics=["accuracy"])
          
learning_rate_reduction = ReduceLROnPlateau(monitor='val_acc',patience=3,verbose=1,factor=0.5,min_lr=0.00001)

但是,每当我尝试开始训练时,即call model.fit_generator,它都会打印Epoch(1/30),然后抛出此错误:

ResourceExhaustedError: 2 root error(s) found.
  (0) Resource exhausted: OOM when allocating tensor with shape[86,64,64] and type float on /job:localhost/replica:0/task:0/device:GPU:0 by allocator GPU_0_bfc
     [[{{node conv2d_4/convolution}}]]
Hint: If you want to see a list of allocated tensors when OOM happens,add report_tensor_allocations_upon_oom to RunOptions for current allocation info.

     [[metrics/accuracy/Identity/_117]]
Hint: If you want to see a list of allocated tensors when OOM happens,add report_tensor_allocations_upon_oom to RunOptions for current allocation info.

  (1) Resource exhausted: OOM when allocating tensor with shape[86,add report_tensor_allocations_upon_oom to RunOptions for current allocation info.

0 successful operations.
0 derived errors ignored.

并停止训练。

我知道它与我的PC内存有关,因为我正在尝试在本地Windows系统上对其进行培训。 我的问题是,我应该怎么做才能解决这个问题。

我无法进一步降低图像质量,我很想使用黑白图像以减少内存消耗。

我的系统的内存: 8GB RAM, 2GB Nvidia GeForce 940MX显卡

如果有人需要完整的代码,这是我完整的python笔记本link

此外,当我执行from keras.models import Sequential

时,它还会引发以下警告
FutureWarning: Passing (type,1) or '1type' as a synonym of type is deprecated; in a future version of numpy,it will be understood as (type,(1,)) / '(1,)type'.
  _np_qint8 = np.dtype([("qint8",np.int8,1)])

解决方法

您正在将整个数据集加载到主存储器中。如果您的数据集很大,则不建议这样做,因为您几乎总是会用光内存。

对此的一种解决方案是使用TensorFlow的flow_from_directory方法,该方法允许您在需要批次时加载批次,而不是将整个数据集保存在内存中。
代码:

train_datagen = ImageDataGenerator(
        rescale=1./255,shear_range=0.2,zoom_range=0.2,horizontal_flip=True)
test_datagen = ImageDataGenerator(rescale=1./255)
train_generator = train_datagen.flow_from_directory(
        'data/train',target_size=(150,150),batch_size=32,class_mode='binary')
validation_generator = test_datagen.flow_from_directory(
        'data/validation',class_mode='binary')
model.fit(
        train_generator,steps_per_epoch=2000,epochs=50,validation_data=validation_generator,validation_steps=800)

您的代码将如下所示。

借助此功能,您可以进行图像增强以及加载数据而无需将其存储在主存储器中。

有关图像增强选项,请参见this

有关flow_from_directory选项,请参见this

在这里,标签是从目录名称中推断出来的。 您的目录结构应如下所示。

train
    - cat
        - img1
        - imgn
    - dog
        - img1
        - imgn

This是使用上述方法的完整的端到端示例的链接。

注意:您的steps_per_epoch = total_samples / batch_size

如果仍然出现OOM错误。

  1. 尝试减小批量大小
  2. 尝试缩小图像尺寸
  3. 尝试减少图像通道,即(RGB->灰度)
  4. 尝试减小模型尺寸。

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐


Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其他元素将获得点击?
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。)
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbcDriver发生异常。为什么?
这是用Java进行XML解析的最佳库。
Java的PriorityQueue的内置迭代器不会以任何特定顺序遍历数据结构。为什么?
如何在Java中聆听按键时移动图像。
Java“Program to an interface”。这是什么意思?
Java在半透明框架/面板/组件上重新绘画。
Java“ Class.forName()”和“ Class.forName()。newInstance()”之间有什么区别?
在此环境中不提供编译器。也许是在JRE而不是JDK上运行?
Java用相同的方法在一个类中实现两个接口。哪种接口方法被覆盖?
Java 什么是Runtime.getRuntime()。totalMemory()和freeMemory()?
java.library.path中的java.lang.UnsatisfiedLinkError否*****。dll
JavaFX“位置是必需的。” 即使在同一包装中
Java 导入两个具有相同名称的类。怎么处理?
Java 是否应该在HttpServletResponse.getOutputStream()/。getWriter()上调用.close()?
Java RegEx元字符(。)和普通点?