如何解决如何正确地从keras fit_generator迁移到fit?
我有2个数据集和一个权重数组。
(train_X,validation_X,train_Y,validation_Y
和sampleW
)
X集是3维的,而Y集是2维的numpy数组。
sampleW
是一维numpy
数组。
如何成功从fit_generator()
迁移到fit()
功能?
在以下方面:
- “ {{1}”是“
fit(x=None,y=None
”吗? - 如何分别传递验证数据? (
train_X,train_Y
) - 我可以像以前一样通过
validation_X,validation_Y
吗? - 如何在
sampleW
上训练分段数据? - 最重要的是:如何在没有生成器的情况下做到这一点?
这是一个最小的可重复性(我目前正在努力找出为什么除1以外的其他任何批处理大小都会产生错误,但> 1也应该可用)
fit()
解决方法
这是由于模型输出和提供的标签的形状不匹配。
模型架构:
如您所见,模型的输出形状为(batch_size,20,6)
,标签的形状为(batch_size,6)
,这是不兼容的。
为什么对batch_size = 1起作用?
这是因为TensorFlow使用了一种称为广播的技术。
例如:
x = np.ones(shape = (1,6))
array([[[1.,1.,1.],[1.,1.]]])
y = np.ones(shape = (1,6))
array([[1.,1.]])
y-x
array([[[0.,0.,0.],[0.,0.]]])
有关更多信息,请参见this。
但是当您使用batch_size = 10
时,广播不再可用。
代码:
x = np.ones(shape = (10,6))
y = np.ones(shape = (10,6))
y-x
输出:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-102-4a65323a80fa> in <module>
1 x = np.ones(shape = (10,6))
2 y = np.ones(shape = (10,6))
----> 3 y-x
ValueError: operands could not be broadcast together with shapes (10,6) (10,6)
可以通过在lstm层之后添加一个平坦层以将2d向量转换为1d向量来固定模型的形状。
代码:
model = Sequential()
model.add(LSTM(242,input_shape=Input_shape,return_sequences=True))
model.add(Dropout(0.3)); model.add(BatchNormalization())
model.add(LSTM(242,return_sequences=True))
model.add(Dropout(0.3)); model.add(BatchNormalization())
model.add(Flatten())
model.add(Dropout(0.3))
model.add(Dense(labels,activation='tanh'))
opt = tf.keras.optimizers.Adam(lr=0.001,decay=1e-6)
model.compile(loss='mean_absolute_error',optimizer=opt,metrics=['mse'])
tf.keras.utils.plot_model(model,'my_first_model.png',show_shapes=True)
模型架构:
最后使用model.fit()
:
model.fit(train_batch_gen,epochs=EPOCHS,validation_data = validation_batch_gen)
输出:
Epoch 1/3
2/2 [==============================] - 1s 708ms/step - loss: 0.2891 - mse: 0.5739 - val_loss: 0.4078 - val_mse: 0.2461
Epoch 2/3
2/2 [==============================] - 0s 46ms/step - loss: 0.2229 - mse: 0.3151 - val_loss: 0.3867 - val_mse: 0.2225
Epoch 3/3
2/2 [==============================] - 0s 49ms/step - loss: 0.2315 - mse: 0.3341 - val_loss: 0.3813 - val_mse: 0.2161
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。