如何解决如何在不破坏结构的情况下重命名Keras模型的各层?
对于some library functionality,我正在尝试重命名给定模型的图层(包括输入图层)。
以下最小示例显示了我使用当前方法(使用TensorFlow 2.3)时遇到的错误:
from tensorflow.keras.models import load_model
model = load_model("model.h5")
for layer in model.layers:
layer._name = layer.name + "_renamed"
model.to_json()
ValueError: The target structure is of type `<class 'tensorflow.python.framework.ops.Tensor'>`
Tensor("input_1:0",shape=(None,4),dtype=float32)
However the input structure is a sequence (<class 'list'>) of length 0.
model.h5
文件可能是这样创建的,例如:
from tensorflow.keras.layers import Input,Dense
from tensorflow.keras.models import Model
inputs = Input(shape=(4,))
x = Dense(5,activation='relu',name='a')(inputs)
x = Dense(3,activation='softmax',name='b')(x)
model = Model(inputs=inputs,outputs=x)
model.compile(loss='categorical_crossentropy',optimizer='nadam')
model.save("model.h5")
关于如何解决此问题的任何想法?
解决方法
问题:Keras通过遍历layer._inbound_nodes
和comparing against model._network_nodes
来序列化网络;设置layer._name
时,后者会保留原始名称。
解决方案:相应地重命名_network_nodes
。工作功能位于底部,示例如下:
from tensorflow.keras.models import load_model
from tensorflow.keras.layers import Input,Dense
from tensorflow.keras.models import Model
ipt = Input((16,))
out = Dense(16)(ipt)
model = Model(ipt,out)
model.compile('sgd','mse')
rename(model,model.layers[1],'new_name')
model.save('model.h5')
loaded = load_model('model.h5')
注意:layer.name
是没有.setter
的{{3}},这意味着(显然)它不是要设置的。此外,@property
被覆盖,除了设置属性外还执行其他步骤-可能是必要的,但不能确切确定它可能还具有什么其他效果。我提供了一个绕过这些的替代方法。最好将其视为临时解决方案;我建议在Github上打开一个Issue,因为API方面的更改已到。
功能:
并非万无一失-_get_node_suffix
的命名逻辑需要工作(例如dense_1
可能与dense_11
混淆)。
def rename(model,layer,new_name):
def _get_node_suffix(name):
for old_name in old_nodes:
if old_name.startswith(name):
return old_name[len(name):]
old_name = layer.name
old_nodes = list(model._network_nodes)
new_nodes = []
for l in model.layers:
if l.name == old_name:
l._name = new_name
# vars(l).__setitem__('_name',new) # bypasses .__setattr__
new_nodes.append(new_name + _get_node_suffix(old_name))
else:
new_nodes.append(l.name + _get_node_suffix(l.name))
model._network_nodes = set(new_nodes)
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。