如何解决NotImplementedError:在__init__中带有参数的图层必须覆盖get_config
这不是一个错误,这是一个功能。
此错误使您知道TF无法保存模型,因为它无法加载模型。
具体来说,它将无法重新实例化您的自定义Layer
类:
和
。
层配置是包含层配置的Python字典(可序列化)。稍后可以从此配置中重新实例化同一层(没有经过训练的权重)。
例如,如果您的encoder
班级看起来像这样:
class encoder(tf.keras.layers.Layer):
def __init__(
self,
vocab_size, num_layers, units, d_model, num_heads, dropout,
**kwargs,
):
super().__init__(**kwargs)
self.vocab_size = vocab_size
self.num_layers = num_layers
self.units = units
self.d_model = d_model
self.num_heads = num_heads
self.dropout = dropout
# Other methods etc.
那么您只需要重写此方法:
def get_config(self):
config = super().get_config().copy()
config.update({
'vocab_size': self.vocab_size,
'num_layers': self.num_layers,
'units': self.units,
'd_model': self.d_model,
'num_heads': self.num_heads,
'dropout': self.dropout,
})
return config
当TF看到这一点(针对两个类)时,您将能够保存模型。
因为现在加载模型时,TF将能够从config重新实例化同一层。
的源代码可以更好地了解其工作方式:
@classmethod
def from_config(cls, config):
return cls(**config)
解决方法
我正在尝试使用来保存我的TensorFlow模型model.save()
-我收到此错误。
变压器模型的代码:
def transformer(vocab_size,num_layers,units,d_model,num_heads,dropout,name="transformer"):
inputs = tf.keras.Input(shape=(None,),name="inputs")
dec_inputs = tf.keras.Input(shape=(None,name="dec_inputs")
enc_padding_mask = tf.keras.layers.Lambda(
create_padding_mask,output_shape=(1,1,None),name='enc_padding_mask')(inputs)
# mask the future tokens for decoder inputs at the 1st attention block
look_ahead_mask = tf.keras.layers.Lambda(
create_look_ahead_mask,None,name='look_ahead_mask')(dec_inputs)
# mask the encoder outputs for the 2nd attention block
dec_padding_mask = tf.keras.layers.Lambda(
create_padding_mask,name='dec_padding_mask')(inputs)
enc_outputs = encoder(
vocab_size=vocab_size,num_layers=num_layers,units=units,d_model=d_model,num_heads=num_heads,dropout=dropout,)(inputs=[inputs,enc_padding_mask])
dec_outputs = decoder(
vocab_size=vocab_size,)(inputs=[dec_inputs,enc_outputs,look_ahead_mask,dec_padding_mask])
outputs = tf.keras.layers.Dense(units=vocab_size,name="outputs")(dec_outputs)
return tf.keras.Model(inputs=[inputs,dec_inputs],outputs=outputs,name=name)
我不明白为什么会出现此错误,因为模型训练得很好。任何帮助,将不胜感激。
我的保存代码供参考:
print("Saving the model.")
saveloc = "C:/tmp/solar.h5"
model.save(saveloc)
print("Model saved to: " + saveloc + " succesfully.")
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。