如何解决TF-Keras-用于多输入功能API模型的Dataset.from_generator
我有一个生成器,生成三个变量。前两个变量是两输入Keras模型(功能性API)的两个输入。我正在使用TF-Dataset来填充我的模型。代码如下:
train_dataset = tf.data.Dataset.from_generator(generator=make_generator_train,args=[train_x_paths,train_y_int],output_types=(tf.tuple((tf.float16,tf.float16)),tf.int8),output_shapes=(tf.TensorShape([2]),tf.TensorShape([1]))).batch(batch_size=batch_size)
我得到一个TypeError
:
TypeError:如果浅层结构是一个序列,则输入也必须是一个序列。输入具有类型:
。
解决方法
尝试这样:
train_dataset = tf.data.Dataset.from_generator(
generator=make_generator_train,args=[train_x_paths,train_y_int],output_types=(tf.float16,tf.int8)
).batch(batch_size=batch_size)
大多数时候,您不需要指定output_shapes。它在运行时决定。此外,您只需要在output_types中指定输出张量的整体dtype。不是每个张量维的dtype。
,解决方案:生成器应为输入和输出按原样生成字典。
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。