如何解决在Tensorflow 2.0中使用tf.keras.utils.plot_model
我需要可视化深度学习模型的输入/输出维度以进行调试。我具有Keras功能API的经验,并在使用keras.utils.plot_model()
之前很有帮助。
现在,我正试图转向Tensorflow 2.0-主要是因为模型定义更加模块化等(您好pytorch!)。但不确定如何在此体系结构中使用tf.keras.utils.plot_model()
。下面的代码-
class Encoder(tf.keras.Model):
def __init__(self,vocab_size,embedding_dim,enc_units,batch_sz):
super(Encoder,self).__init__()
...
def call(self,x,hidden):
...
class Decoder(tf.keras.Model):
def __init__(self,dec_units,batch_sz):
super(Decoder,hidden,enc_output):
...
现在,在训练模型时,将保存检查点
for epoch in range(EPOCHS):
start = time.time()
enc_hidden = encoder.initialize_hidden_state()
total_loss = 0
for (batch,(inp,targ)) in enumerate(dataset.take(steps_per_epoch)):
batch_loss = train_step(inp,targ,enc_hidden)
total_loss += batch_loss
if batch % 100 == 0:
print('Epoch {} Batch {} Loss {:.4f}'.format(epoch + 1,batch,batch_loss.numpy()))
# saving (checkpoint) the model
checkpoint.save(file_prefix = checkpoint_prefix)
print('Epoch {} Loss {:.4f}'.format(epoch + 1,total_loss / steps_per_epoch))
print('Time taken for 1 epoch {} sec\n'.format(time.time() - start))
我知道此检查点具有模型信息。但是我不确定如何从该检查点获得tf.keras.utils.plot_model()
之类的similar visualization。
请提出建议。
编辑
这就是我定义检查点的方式
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir,"ckpt")
checkpoint = tf.train.Checkpoint(optimizer=optimizer,encoder=encoder,decoder=decoder)
然后将原始训练代码中显示的检查点另存为
checkpoint.save(file_prefix = checkpoint_prefix)
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。