如何解决NotImplementedError:图层注意在__init__中具有参数,因此必须覆盖get_config
我已实现此链接中建议的自定义注意层: How to add attention layer to a Bi-LSTM
class attention(Layer):
def __init__(self,return_sequences=True):
self.return_sequences = return_sequences
super(attention,self).__init__()
def build(self,input_shape):
self.W=self.add_weight(name="att_weight",shape=(input_shape[-1],1),initializer="normal")
self.b=self.add_weight(name="att_bias",shape=(input_shape[1],initializer="zeros")
super(attention,self).build(input_shape)
def call(self,x):
e = K.tanh(K.dot(x,self.W)+self.b)
a = K.softmax(e,axis=1)
output = x*a
if self.return_sequences:
return output
return K.sum(output,axis=1)
代码运行了,但是当需要保存模型时出现了这个错误。
NotImplementedError:图层注意在__init__
中具有参数,因此必须覆盖get_config
。
一些评论建议覆盖get_config。
“此错误使您知道tensorflow无法保存模型,因为它无法加载模型。 具体来说,它将无法重新实例化自定义的Layer类。
要解决此问题,只需根据您添加的新参数覆盖其get_config方法即可。”
查看链接:NotImplementedError: Layers with arguments in `__init__` must override `get_config`
我的问题是,基于上面的自定义关注层,如何编写get_config来解决此错误?
解决方法
您需要这样的配置方法:
def get_config(self):
config = super().get_config().copy()
config.update({
'return_sequences': self.return_sequences
})
return config
所需的所有信息都在您链接的其他post中。
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。