如何解决张量流CTC损失函数tf.nn.ctc_loss的确切输入参数是什么?
class CTCLoss(keras.losses.Loss):
def __init__(self,logits_time_major=False,blank_index=-1,reduction=keras.losses.Reduction.AUTO,name='ctc_loss'):
super().__init__(reduction=reduction,name=name)
self.logits_time_major = logits_time_major
self.blank_index = blank_index
def call(self,y_true,y_pred):
y_true = tf.cast(y_true,tf.int32)
y_true = tf.reshape(y_true,[batch_size,max_label_seq_length])
y_pred = tf.reshape(y_pred,[frames,batch_size,num_labels])
loss = tf.nn.ctc_loss(
labels=y_true,logits=y_pred,label_length=4480,logit_length=4480)
return tf.reduce_mean(loss)
model = Sequential()
model.add(Bidirectional(LSTM(35,input_shape=X_train.shape,return_sequences=True)))
# didn't add the hidden layers in this code snippet.
model.add(Flatten())
model.add(Dense((4480),activation='softmax'))
model.compile(optimizer='adam',loss=CTCLoss(),metrics=['accuracy'])
我正在尝试解决在线手写识别问题,并且尝试使用CTC丢失功能。我在上面的代码中尝试使用此类作为CTC损失函数。但是关于抛出的尺寸有一个错误。有人可以解释一下这些参数是什么吗?特别是[frames,batch_size,num_labels]中的“框架”是什么意思。请让我知道此特定代码在哪里出错。我的X_train的形状为(1311,919,3)。谢谢。
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。