如何解决纯tensorflow2不使用keras中的学习循环应该看起来如何?
我在Tensorflow中有一个非常定制的模型,没有Keras组件。目前,我被困在编写学习循环中。我已经创建了模型(带有占位符),创建了数据集(带有批处理),优化器,图形等。我认为我不需要急切的执行,所以我用tf.compat.v1.disable_eager_execution()
禁用了它。当我尝试遍历数据集以馈送网络时,我遇到了错误,告诉我__iter__
仅在热切模式下可用。当我启用eager模式时,它们会出错,告诉我不能使用占位符。我正在使用张量流2.3
。
代码:
import tensorflow as tf
import tensorflow_datasets as tfds
from fcrn import ResNet50UpProj
tf.compat.v1.disable_eager_execution()
height = 480
width = 640
channels = 3
batch_size = 5
lr = 0.001
dataset = tfds.load('nyu_depth_v2',split='train',shuffle_files=True)
dataset = dataset.batch(batch_size)
input_node = tf.compat.v1.placeholder(tf.float32,shape=(None,height,width,channels))
labels_node = tf.compat.v1.placeholder(tf.float32,int(height/2),int(width/2),1))
net = ResNet50UpProj({'data': input_node},batch_size,1,False)
output_node = net.get_output()
graph = tf.Graph()
sess = tf.compat.v1.Session(graph=graph)
loss = tf.nn.l2_loss(output_node - labels_node)
opt = tf.compat.v1.train.AdamOptimizer(lr)
update_ops = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
grads = opt.compute_gradients(
loss,var_list=tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES))
train_op = opt.apply_gradients(grads,name='optimizer')
with graph.as_default():
sess.run(tf.compat.v1.global_variables_initializer())
for epoch in range(10):
print("Epoch: {}".format(epoch))
for step,(batch_input,batch_labels) in enumerate(dataset):
print(step)
# sess.run(train_op,feed_dict={input_node: batch.image,labels_node: batch.depth})
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。