如何解决使用经过训练的对象检测API模型和TF 2进行批量预测
我在TPU上使用针对TF 2的对象检测API成功地训练了一个模型,该模型另存为.pb(SavedModel格式)。然后,我使用tf.saved_model.load
将其加载回去,当使用将单个图像转换为形状为(1,w,h,3)
的张量来预测框时,它可以很好地工作。
import tensorflow as tf
import numpy as np
# Load Object Detection APIs model
detect_fn = tf.saved_model.load('/path/to/saved_model/')
image = tf.io.read_file(image_path)
image_np = tf.image.decode_jpeg(image,channels=3).numpy()
input_tensor = np.expand_dims(image_np,0)
detections = detect_fn(input_tensor) # This works fine
问题是我需要进行批量预测才能将其缩放到100万张图像,但是此模型的输入签名似乎仅限于处理形状为(1,3)
的数据。
这也意味着我不能在Tensorflow Serving中使用批处理。
我怎么解决这个问题?我可以只更改Model Signature来处理大量数据吗?
所有工作(加载模型+预测)都是在由对象检测API(来自here)发布的官方容器中进行的
解决方法
我最近遇到了这个问题。当您使用exporter_main_v2.py
将检查点文件转换为.pb
文件时,它将调用exporter_lib_v2.py
。我发现在文件exporter_lib_v2.py
(here)中,TF2用形状[1,None,3]
固定了输入签名。我们必须将其更改为[None,3]
需要将该文件(138,162,170,185)中的行从1
修改为None
。然后重建TF2对象检测器API存储库(link),并使用新构建的版本再次导出.pb
。
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。