如何解决注意下Keras序列与序列模型中的维度不匹配
我正在尝试使用attention
构建神经机器翻译模型。我正在关注Keras blog上的教程,该教程显示了如何使用序列到序列方法(无需注意)来构建NMT模型。我通过以下方式扩展了模型以合并attention
-
latent_dim = 300
embedding_dim=100
batch_size = 128
# Encoder
encoder_inputs = keras.Input(shape=(None,num_encoder_tokens))
#encoder lstm 1
encoder_lstm = tf.keras.layers.LSTM(latent_dim,return_sequences=True,return_state=True,dropout=0.4,recurrent_dropout=0.4)
encoder_output,state_h,state_c = encoder_lstm(encoder_inputs)
print(encoder_output.shape)
# Set up the decoder,using `encoder_states` as initial state.
decoder_inputs = keras.Input(shape=(None,num_decoder_tokens))
decoder_lstm = tf.keras.layers.LSTM(latent_dim,recurrent_dropout=0.2)
decoder_output,decoder_fwd_state,decoder_back_state = decoder_lstm(decoder_inputs,initial_state=[state_h,state_c])
# Attention layer
attn_out = tf.keras.layers.Attention()([encoder_output,decoder_output])
# Concat attention input and decoder LSTM output
decoder_concat_input = tf.keras.layers.Concatenate(axis=-1,name='concat_layer')([decoder_output,attn_out])
#dense layer
decoder_dense = tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(num_decoder_tokens,activation='softmax'))
decoder_outputs = decoder_dense(decoder_concat_input)
# Define the model
attn_model = tf.keras.Model([encoder_inputs,decoder_inputs],decoder_outputs)
attn_model.summary()
训练模型-
attn_model.compile(
optimizer="rmsprop",loss="categorical_crossentropy",metrics=["accuracy"]
)
history = attn_model.fit(
[encoder_input_data,decoder_input_data],decoder_target_data,batch_size=batch_size,epochs=5,validation_split=0.2,)
我的身材矮小
encoder_input_data.shape
是(10000,16,71)
decoder_input_data.shape
是(10000,59,92)
decoder_target_data.shape
是(10000,92)
训练该模型时,出现以下错误:
InvalidArgumentError: Dimension 1 in both shapes must be equal,but are 59 and 16. Shapes are [?,59] and [?,16]. for 'model/concat_layer/concat' (op: 'ConcatV2') with input shapes: [?,300],[?,[] and with computed input tensors: input[2] = <2>.
我了解到它在抱怨encoder_input_data
和decoder_input_data
的尺寸,但是当我们运行{{3 }}。在这种情况下,由于Concatenation
层而引发错误。
有人可以建议如何解决此问题吗?
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。