How to fix an error when running a Keras h5 model
I'm trying to run this h5 model, found here, from the ALASKA2 Image Steganalysis competition.
I want to predict the label of the RGB image c1.bmp with the following code:
import efficientnet.tfkeras as efn
import tensorflow as tf
from tensorflow import keras
import numpy as np
def decode_image(filename, image_size=(512, 512)):
    bits = tf.io.read_file(filename)
    image = tf.image.decode_bmp(bits, channels=3)
    image = tf.cast(image, tf.float32) / 255.0
    image = tf.image.resize(image, image_size)
    return image
img = decode_image('imgs/c1.bmp')
model = keras.models.load_model("model.h5")
print(model.predict(img, verbose=1))
However, running this code results in the following error:
File "alaska.py",line 20,in <module>
print(model.predict(img,verbose=1))
File "Python38\lib\site-packages\tensorflow\python\keras\engine\training.py",line 1629,in predict
tmp_batch_outputs = self.predict_function(iterator)
File "Python38\lib\site-packages\tensorflow\python\eager\def_function.py",line 828,in __call__
result = self._call(*args,**kwds)
File "Python38\lib\site-packages\tensorflow\python\eager\def_function.py",line 871,in _call
self._initialize(args,kwds,add_initializers_to=initializers)
File "Python38\lib\site-packages\tensorflow\python\eager\def_function.py",line 725,in _initialize
self._stateful_fn._get_concrete_function_internal_garbage_collected( # pylint: disable=protected-access
File "Python38\lib\site-packages\tensorflow\python\eager\function.py",line 2969,in _get_concrete_function_internal_garbage_collected
graph_function,_ = self._maybe_define_function(args,kwargs)
File "Python38\lib\site-packages\tensorflow\python\eager\function.py",line 3361,in _maybe_define_function
graph_function = self._create_graph_function(args,line 3196,in _create_graph_function
func_graph_module.func_graph_from_py_func(
File "Python38\lib\site-packages\tensorflow\python\framework\func_graph.py",line 990,in func_graph_from_py_func
func_outputs = python_func(*func_args,**func_kwargs)
File "Python38\lib\site-packages\tensorflow\python\eager\def_function.py",line 634,in wrapped_fn
out = weak_wrapped_fn().__wrapped__(*args,**kwds)
File "Python38\lib\site-packages\tensorflow\python\framework\func_graph.py",line 977,in wrapper
raise e.ag_error_metadata.to_exception(e)
ValueError: in user code:
Python38\lib\site-packages\tensorflow\python\keras\engine\training.py:1478 predict_function *
return step_function(self,iterator)
Python38\lib\site-packages\tensorflow\python\keras\engine\training.py:1468 step_function **
outputs = model.distribute_strategy.run(run_step,args=(data,))
Python38\lib\site-packages\tensorflow\python\distribute\distribute_lib.py:1259 run
return self._extended.call_for_each_replica(fn,args=args,kwargs=kwargs)
Python38\lib\site-packages\tensorflow\python\distribute\distribute_lib.py:2730 call_for_each_replica
return self._call_for_each_replica(fn,args,kwargs)
Python38\lib\site-packages\tensorflow\python\distribute\distribute_lib.py:3417 _call_for_each_replica
return fn(*args,**kwargs)
Python38\lib\site-packages\tensorflow\python\keras\engine\training.py:1461 run_step **
outputs = model.predict_step(data)
Python38\lib\site-packages\tensorflow\python\keras\engine\training.py:1434 predict_step
return self(x,training=False)
Python38\lib\site-packages\tensorflow\python\keras\engine\base_layer.py:998 __call__
input_spec.assert_input_compatibility(self.input_spec,inputs,self.name)
Python38\lib\site-packages\tensorflow\python\keras\engine\input_spec.py:271 assert_input_compatibility
raise ValueError('Input ' + str(input_index) +
ValueError: Input 0 is incompatible with layer sequential: expected shape=(None,512,3),found shape=(32,3)
I have Python 3.8.7 and TensorFlow 2.4.1, and I'm using PyCharm on Windows 8.
What does this error mean, and how can I fix it?
Solution
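You forgot to add the batch dimension. Just add the following transformation to the decode_image function, right before the return statement:
    image = tf.expand_dims(image, axis=0)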
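Why the error mentions a batch of 32: model.predict() treats the first axis of its input as the batch axis and, when batch_size is not given, splits that axis into batches of 32. A single decoded image of shape (512, 512, 3) is therefore read as 512 "samples" of shape (512, 3), which does not match the model's 4D input. The minimal NumPy sketch below mimics that slicing, assuming a 512×512 RGB image as in the question; the helper batch_shapes is hypothetical and only a simplified stand-in for what Keras does, not its actual implementation.

```python
import numpy as np

# Stand-in for the decoded image: height x width x channels (dummy values).
image = np.zeros((512, 512, 3), dtype=np.float32)


def batch_shapes(x, batch_size=32):
    """Hypothetical helper: shapes of the batches obtained by slicing x along axis 0."""
    return [x[i:i + batch_size].shape for i in range(0, len(x), batch_size)]


# Without a batch axis, the 512 image rows are treated as 512 separate samples.
shapes = batch_shapes(image)
print(len(shapes), shapes[0])  # 16 (32, 512, 3)

# Adding a leading batch axis turns the image into a batch containing one sample.
batched = np.expand_dims(image, axis=0)
print(batched.shape)           # (1, 512, 512, 3)
print(batch_shapes(batched))   # [(1, 512, 512, 3)]
```

In the question's TensorFlow code, tf.expand_dims(image, axis=0) or image[tf.newaxis, ...] has the same effect, so model.predict() receives a batch of one image and returns one row of predictions.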