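Why doesn't my ResNet50 model in Keras converge?
I am currently trying to classify images of integrated circuits as defective or non-defective. I have already tried VGG16 and InceptionV3, and both gave very good results (95% validation accuracy and low validation loss). Now I want to try ResNet50, but my model does not converge. Its training accuracy also reaches 95%, but the validation loss keeps increasing while the validation accuracy stays stuck at 50%.
This is my script so far: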
from keras.applications.resnet50 import ResNet50
from keras.optimizers import Adam
from keras.preprocessing import image
from keras.models import Model
from keras.layers import Dense, GlobalAveragePooling2D, Dropout
from keras import backend as K
from keras_preprocessing.image import ImageDataGenerator
import tensorflow as tf


class ResNet:
    def __init__(self):
        self.img_width, self.img_height = 224, 224  # Dimensions of cropped image
        self.classes_num = 2  # Number of classifications
        # Training configurations
        self.epochs = 32
        self.batch_size = 16  # Play with this to determine number of images to train on per epoch
        self.lr = 0.0001

    def build_model(self, train_path):
        train_data_path = train_path
        train_datagen = ImageDataGenerator(rescale=1. / 255, validation_split=0.25)
        train_generator = train_datagen.flow_from_directory(
            train_data_path,
            target_size=(self.img_height, self.img_width),
            color_mode="rgb",
            batch_size=self.batch_size,
            class_mode='categorical',
            subset='training')
        validation_generator = train_datagen.flow_from_directory(
            train_data_path,
            subset='validation')

        # create the base pre-trained model
        base_model = ResNet50(weights='imagenet', include_top=False,
                              input_shape=(self.img_height, self.img_width, 3))
        # add a global spatial average pooling layer
        x = base_model.output
        x = GlobalAveragePooling2D()(x)
        # let's add a fully-connected layer
        x = Dense(1024, activation='relu')(x)
        #x = Dropout(0.3)(x)
        # and a logistic layer -- let's say we have 200 classes
        predictions = Dense(2, activation='softmax')(x)
        # this is the model we will train
        model = Model(inputs=base_model.input, outputs=predictions)

        # first: train only the top layers (which were randomly initialized)
        # i.e. freeze all convolutional InceptionV3 layers
        for layer in base_model.layers:
            layer.trainable = True

        # compile the model (should be done *after* setting layers to non-trainable)
        opt = Adam(self.lr)  #,decay=self.INIT_LR / self.NUM_EPOCHS)
        model.compile(opt, loss='binary_crossentropy', metrics=["accuracy"])

        # train the model on the new data for a few epochs
        from keras.callbacks import ModelCheckpoint, EarlyStopping
        import matplotlib.pyplot as plt
        checkpoint = ModelCheckpoint('resnetModel.h5', monitor='val_accuracy', verbose=1,
                                     save_best_only=True, save_weights_only=False,
                                     mode='auto', period=1)
        early = EarlyStopping(monitor='val_accuracy', min_delta=0, patience=16, mode='auto')
        hist = model.fit_generator(steps_per_epoch=self.batch_size,
                                   generator=train_generator,
                                   validation_data=validation_generator,
                                   validation_steps=self.batch_size,
                                   epochs=self.epochs,
                                   callbacks=[checkpoint, early])

        plt.plot(hist.history['accuracy'])
        plt.plot(hist.history['val_accuracy'])
        plt.plot(hist.history['loss'])
        plt.plot(hist.history['val_loss'])
        plt.title("model accuracy")
        plt.ylabel("Accuracy")
        plt.xlabel("Epoch")
        plt.legend(["Accuracy", "Validation Accuracy", "loss", "Validation Loss"])
        plt.show()
        plt.figure(1)


if __name__ == '__main__':
    x = ResNet()
    config = tf.compat.v1.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.compat.v1.Session(config=config)
    x.build_model("C:/Users/but/Desktop/dataScratch/Train")
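This is the model's training:
What could be the reason that ResNet fails while VGG and Inception work? Are there any mistakes in my script?
Solution
At least as far as the code goes, I cannot see any error that would affect the training process.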
# and a logistic layer -- let's say we have 200 classes
predictions = Dense(2,activation='softmax')(x)
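These lines look a little suspicious. But the mistake seems to be only in the comment (it says 200 classes while the layer has 2 units), so it should be fine.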
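As a quick sanity check on the quoted output layer, the sketch below compares the loss the script compiles with (`binary_crossentropy`) against categorical cross-entropy for a 2-unit softmax output and one-hot labels. It is a plain NumPy re-implementation written for illustration, not Keras's own code; the function names and the sample `logits` are made up for this example, and it assumes the binary loss is averaged over the two output units.

```python
import numpy as np


def softmax(logits):
    # Subtract the row maximum for numerical stability before exponentiating.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)


def categorical_crossentropy(y_true, y_pred, eps=1e-7):
    # Negative log-probability of the true class, one value per sample.
    y_pred = np.clip(y_pred, eps, 1 - eps)
    return -(y_true * np.log(y_pred)).sum(axis=-1)


def binary_crossentropy(y_true, y_pred, eps=1e-7):
    # Element-wise binary cross-entropy, averaged over the output units.
    y_pred = np.clip(y_pred, eps, 1 - eps)
    per_unit = -(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred))
    return per_unit.mean(axis=-1)


# Hypothetical logits for three samples and their one-hot labels.
logits = np.array([[2.0, -1.0], [0.3, 0.9], [-4.0, 4.0]])
y_true = np.array([[1.0, 0.0], [1.0, 0.0], [0.0, 1.0]])

probs = softmax(logits)
cce = categorical_crossentropy(y_true, probs)
bce = binary_crossentropy(y_true, probs)
print("categorical:", cce)
print("binary:     ", bce)
print("same values:", np.allclose(cce, bce))
```

Under these assumptions the two losses coincide, because the two softmax outputs are `p` and `1 - p`, so the choice of loss in the script is unlikely to explain the behavior by itself.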
# first: train only the top layers (which were randomly initialized)
# i.e. freeze all convolutional InceptionV3 layers
for layer in base_model.layers:
    layer.trainable = True
These are also suspicious. If you want to freeze the layers of ResNet-50, what you need to do is
...
base_model = ResNet50(weights='imagenet',include_top=False,input_shape=(self.img_height,self.img_width,3))
for layer in base_model.layers:
    layer.trainable = False
...
But it turns out that layer.trainable = True is actually what you intended, so that does not matter either.
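To confirm which of the two settings is actually in effect, one option is to count the trainable parameters before and after flipping the flag. The following is a minimal sketch, assuming TensorFlow's bundled Keras (`tf.keras`) is installed. The tiny `backbone` is a hypothetical stand-in for ResNet50 so that nothing has to be downloaded, and the helper names `set_backbone_trainable` and `count_trainable_params` are invented for this example.

```python
import numpy as np
import tensorflow as tf


def set_backbone_trainable(backbone, trainable):
    # Flip the trainable flag on every layer of the backbone.
    for layer in backbone.layers:
        layer.trainable = trainable


def count_trainable_params(model):
    # Sum the number of elements over all trainable weight tensors.
    return sum(int(np.prod(tuple(w.shape))) for w in model.trainable_weights)


# Hypothetical stand-in for the pretrained base model (NOT ResNet50).
inputs = tf.keras.Input(shape=(8, 8, 3))
features = tf.keras.layers.Conv2D(4, 3)(inputs)
features = tf.keras.layers.GlobalAveragePooling2D()(features)
backbone = tf.keras.Model(inputs, features)

# New classification head on top, as in the question's script.
predictions = tf.keras.layers.Dense(2, activation="softmax")(backbone.output)
model = tf.keras.Model(backbone.input, predictions)

set_backbone_trainable(backbone, True)
params_unfrozen = count_trainable_params(model)

set_backbone_trainable(backbone, False)
params_frozen = count_trainable_params(model)

print("trainable parameters, backbone unfrozen:", params_unfrozen)
print("trainable parameters, backbone frozen:  ", params_frozen)
```

With the real model, the same two helpers could be called on `base_model` and `model` from the question. As the comment in the script notes, `model.compile` should come after changing `trainable`.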
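First of all, if you used the same code to train VGG16 and Inception V3, the code is unlikely to be the problem.
Why not check the following possible causes?
- The model may be too small or too large, so that it underfits or overfits (number of parameters).
- The model may need more time to converge (train for more epochs).
- ResNet may not be suited to this classification task.
- The pretrained weights you are using may not be suited to this classification task.
- The learning rate may be too small or too large.
- etc.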
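On the point about pretrained weights, one thing that may be worth checking is whether the inputs are scaled the way the ImageNet weights expect. The script uses `rescale=1. / 255`, whereas Keras's `keras.applications.resnet50.preprocess_input` uses the "caffe" convention: RGB is converted to BGR and the ImageNet channel means are subtracted, without any scaling. The NumPy sketch below re-implements both by hand on random pixels purely to show how different the value ranges are. The function names and the random `batch` are assumptions for illustration, and this is not a confirmed cause of the problem in the question.

```python
import numpy as np

# ImageNet per-channel means in BGR order, as used by the "caffe" convention.
IMAGENET_BGR_MEAN = np.array([103.939, 116.779, 123.68], dtype=np.float32)


def rescale_unit(images_rgb):
    # What rescale=1/255 does: map pixel values from [0, 255] to [0, 1].
    return images_rgb.astype(np.float32) / 255.0


def caffe_style(images_rgb):
    # Reverse the channel axis (RGB -> BGR), then subtract the channel means.
    bgr = images_rgb.astype(np.float32)[..., ::-1]
    return bgr - IMAGENET_BGR_MEAN


# Hypothetical batch of random 224x224 RGB images standing in for real data.
rng = np.random.default_rng(0)
batch = rng.integers(0, 256, size=(4, 224, 224, 3), dtype=np.uint8)

unit_scaled = rescale_unit(batch)
caffe_scaled = caffe_style(batch)

print("rescale 1/255 range:", unit_scaled.min(), unit_scaled.max())
print("caffe-style range:  ", caffe_scaled.min(), caffe_scaled.max())
```

If this turns out to matter, `ImageDataGenerator` accepts a `preprocessing_function` argument, so `preprocess_input` could be passed there in place of `rescale`.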