ValueError: Python inputs incompatible with input_signature

How to resolve "ValueError: Python inputs incompatible with input_signature"?

System information

  • OS platform and distribution: CentOS Linux release 7.7.1908
  • TensorFlow version: 2.3.0

I am following this example: https://www.tensorflow.org/tutorials/text/image_captioning?hl=en

It works fine and saves checkpoints; I now want to convert the model to TF Lite.

Here is the link to the full conversion code: https://colab.research.google.com/drive/1GJkGcwWvDAWMooTsECzuSRUSPbirADhb?usp=sharing

Here is the link to the full training code: https://colab.research.google.com/drive/1X2d9WW1EMEzN8Rgva3rtjevP0T_jFccj?usp=sharing

I have also followed issue #32999.

Here is the inference graph I am saving and converting:

@tf.function
def evaluate(image):
    hidden = decoder.reset_states(batch_size=1)

    temp_input = tf.expand_dims(load_image(image)[0],0)
    img_tensor_val = image_features_extract_model(temp_input)
    img_tensor_val = tf.reshape(img_tensor_val,(img_tensor_val.shape[0],-1,img_tensor_val.shape[3]))

    features = encoder(img_tensor_val)

    dec_input = tf.expand_dims([tokenizer.word_index['<start>']],0)
    result = []

    for i in range(max_length):
        predictions,hidden,attention_weights = decoder(dec_input,features,hidden)

        predicted_id = tf.random.categorical(predictions,1)[0][0]
        # print(tokenizer.index_word)
        print(predicted_id,predicted_id.dtype)

        # for key,value in tokenizer.index_word.items():
        #     key = tf.convert_to_tensor(key)
        #     tf.dtypes.cast(key,tf.int64)
        #     print(key)

        # print(tokenizer.index_word)

        result.append(predicted_id)

        # if tokenizer.index_word[predicted_id] == '<end>':
        #     return result

        dec_input = tf.expand_dims([predicted_id],0)

    return result

export_dir = "./"
tflite_enc_input = ''
ckpt.f = evaluate
to_save = evaluate.get_concrete_function('')

converter = tf.lite.TFLiteConverter.from_concrete_functions([to_save])
tflite_model = converter.convert()

But I get this error:

ValueError: in user code:

    convert2savedmodel.py:310 evaluate  *
        predictions,hidden)
    /share/nishome/19930072_0/miniconda3/envs/tf2.3/lib/python3.7/site-packages/tensorflow/python/keras/engine/base_layer.py:985 __call__  **
        outputs = call_fn(inputs,*args,**kwargs)
    /share/nishome/19930072_0/miniconda3/envs/tf2.3/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py:780 __call__
        result = self._call(*args,**kwds)
    /share/nishome/19930072_0/miniconda3/envs/tf2.3/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py:840 _call
        return self._stateless_fn(*args,**kwds)
    /share/nishome/19930072_0/miniconda3/envs/tf2.3/lib/python3.7/site-packages/tensorflow/python/eager/function.py:2828 __call__
        graph_function,args,kwargs = self._maybe_define_function(args,kwargs)
    /share/nishome/19930072_0/miniconda3/envs/tf2.3/lib/python3.7/site-packages/tensorflow/python/eager/function.py:3171 _maybe_define_function
        *args,**kwargs)
    /share/nishome/19930072_0/miniconda3/envs/tf2.3/lib/python3.7/site-packages/tensorflow/python/eager/function.py:2622 canonicalize_function_inputs
        self._flat_input_signature)
    /share/nishome/19930072_0/miniconda3/envs/tf2.3/lib/python3.7/site-packages/tensorflow/python/eager/function.py:2713 _convert_inputs_to_signature
        format_error_message(inputs,input_signature))

    ValueError: Python inputs incompatible with input_signature:
      inputs: (
        Tensor("ExpandDims_1:0", shape=(1, 1), dtype=int32),
        Tensor("cnn__encoder/StatefulPartitionedCall:0", shape=(1, 64, 256), dtype=float32),
        Tensor("zeros:0", shape=(1, 512), dtype=float32))
      input_signature: (
        TensorSpec(shape=(1, 1), dtype=tf.int64, name=None),
        TensorSpec(shape=(1, 64, 256), dtype=tf.float32, name=None),
        TensorSpec(shape=(1, 512), dtype=tf.float32, name=None))

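For reference, a standalone toy function seems to reproduce the same check: a tf.function with an explicit input_signature compares tensor dtypes against the declared TensorSpec and does not cast a mismatched tensor. The little function below is only an illustration, not the tutorial's decoder:

import tensorflow as tf

# Toy reproduction of the signature check, unrelated to the captioning model.
@tf.function(input_signature=[tf.TensorSpec(shape=[1, 1], dtype=tf.int64)])
def double(x):
    return x * 2

print(double(tf.constant([[1]], dtype=tf.int64)))  # OK, dtype matches the spec
try:
    double(tf.constant([[1]], dtype=tf.int32))     # int32 tensor vs int64 spec
except ValueError as e:
    print(e)  # "Python inputs incompatible with input_signature: ..."
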
Encoder model:

class CNN_Encoder(tf.keras.Model):
    def __init__(self,embedding):
        super(CNN_Encoder,self).__init__()
        # shape after fc == (batch_size,embedding_dim)
        self.fc = tf.keras.layers.Dense(embedding_dim)

    @tf.function(input_signature=[tf.TensorSpec(shape=(1, 64, features_shape), dtype=tf.dtypes.float32)])
    def call(self,x):
        x = self.fc(x)
        x = tf.nn.relu(x)
        return x

Decoder model:

class RNN_Decoder(tf.keras.Model):
    def __init__(self,embedding_dim,units,vocab_size):
        super(RNN_Decoder,self).__init__()
        self.units = units

        self.embedding = tf.keras.layers.Embedding(vocab_size,embedding_dim)
        self.gru = tf.keras.layers.GRU(self.units,
                                       return_sequences=True,
                                       return_state=True,
                                       recurrent_initializer='glorot_uniform',
                                       unroll=True)
        self.fc1 = tf.keras.layers.Dense(self.units)
        self.fc2 = tf.keras.layers.Dense(vocab_size)

        self.attention = BahdanauAttention(self.units)


    @tf.function(input_signature=[tf.TensorSpec(shape=[1, 1], dtype=tf.int64),
                                  tf.TensorSpec(shape=[1, 64, 256], dtype=tf.float32),
                                  tf.TensorSpec(shape=[1, 512], dtype=tf.float32)])
    def call(self, x, features, hidden):

        context_vector,attention_weights = self.attention(features,hidden)

        #x shape after passing through embedding == (batch_size,1,embedding_dim)
        x = self.embedding(x)

        # x shape after concatenation == (batch_size, 1, embedding_dim + hidden_size)
        x = tf.concat([tf.expand_dims(context_vector, 1), x], axis=-1)


        output,state = self.gru(x)

        #shape == (batch_size,max_length,hidden_size)
        x = self.fc1(output)

        #x shape == (batch_size,hidden_size)
        x = tf.reshape(x,(-1,x.shape[2]))

        # output shape == (batch_size * max_length,vocab)
        x = self.fc2(x)

        return x,state,attention_weights

    def reset_states(self,batch_size):
        return tf.zeros((batch_size,self.units))

I then just changed the dtype in the tf.function input_signature to int32, as below:

@tf.function(input_signature=[tf.TensorSpec(shape=[1, 1], dtype=tf.int32),
                              tf.TensorSpec(shape=[1, 64, 256], dtype=tf.float32),
                              tf.TensorSpec(shape=[1, 512], dtype=tf.float32)])

But then I get another error:

ValueError: Python inputs incompatible with input_signature:
  inputs: (
    Tensor("ExpandDims_2:0", shape=(1, 1), dtype=int64),
    Tensor("rnn__decoder/StatefulPartitionedCall:1", shape=(1, 512), dtype=float32))
  input_signature: (
    TensorSpec(shape=(1, 1), dtype=tf.int32, name=None))
Why does the dtype of the inputs change between int64 and int32?
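
My current guess about where the two dtypes come from (the token id and vocab size below are made-up stand-ins for the tutorial's values, so treat this only as a sketch):

import tensorflow as tf

# Step 1 of the loop: dec_input is built from a plain Python int
# (tokenizer.word_index['<start>']), which TensorFlow converts to int32.
start_id = 3                                  # hypothetical '<start>' token id
dec_input = tf.expand_dims([start_id], 0)
print(dec_input.dtype)                        # tf.int32

# Steps 2..max_length: dec_input is rebuilt from tf.random.categorical,
# whose output dtype defaults to int64.
logits = tf.zeros([1, 5001])                  # hypothetical (1, vocab_size) predictions
predicted_id = tf.random.categorical(logits, 1)[0][0]
dec_input = tf.expand_dims([predicted_id], 0)
print(dec_input.dtype)                        # tf.int64

# So an int64 signature rejects the first step and an int32 signature rejects
# the later steps. Casting dec_input to a single dtype before every decoder
# call, e.g. dec_input = tf.cast(dec_input, tf.int64), would keep them consistent.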
