如何解决在TF 1.14和TF 2.1下恢复张量流模型问题
在TF 1.14和TF 2.1之间使用tf.saved_model.simple_save和tf.saved_model.load时遇到了一些麻烦
如您所见,我附上了代码,
我想看看权重(W),其值必须是节省时间时初始化的状态。
在TF 2.1下, 保存和加载tensorflow模型(pb文件)没有问题。 保存后加载时,我能够识别出相同的重量(W)值
但是,当我使用TF 1.14时, 保存模型还可以..但是,当我加载保存的模型时,结果不是我期望的。 看来tf.saved_model.load无法加载节省的重量,只能随机初始化。
我附上了下面的代码, 您可以通过切换TF_VERSION = 2.1和1.14,SAVE = True和False来运行
TF_VERSION = 2.1
# TF_VERSION = 1.14
SAVE = False
model_dir_path = "./pb"
if TF_VERSION == 1.14:
import tensorflow as tf
else:
import tensorflow.compat.v1 as tf
tf.compat.v1.disable_eager_execution()
X = tf.placeholder(tf.float32,shape=[None,2],name='input')
# weight
weight_initer = tf.truncated_normal_initializer(mean=0.0,stddev=0.01)
W = tf.get_variable(name="Weight",dtype=tf.float32,shape=[2,1],initializer=weight_initer)
# bias
bias_initer = tf.constant(0.,shape=[1],dtype=tf.float32)
b = tf.get_variable(name="Bias",initializer=bias_initer)
x_w = tf.matmul(X,W,name="MatMul")
x_w_b = tf.add(x_w,b,name="Add")
#save
if SAVE:
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
x_batch = [[2,[3,5]]
feed_dict = {X: x_batch}
output = sess.run(x_w_b,feed_dict=feed_dict)
tf.saved_model.simple_save(sess,model_dir_path,inputs={"inputs": X},outputs={"outputs": W})
# restore
else:
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
tf.saved_model.load(sess,[tf.saved_model.tag_constants.SERVING],model_dir_path)
x_batch = [[2,5]]
feed_dict = {X: x_batch}
weight = sess.run(W,feed_dict=feed_dict)
print(weight)
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。