如何解决Keras使用keras.predict预测单个示例吗?
我确信/希望这是一个非常简单的问题,但是我没有找到任何答案。
我在Keras中建立了一个顺序模型,其中包含366个输入神经元和一个输出神经元。似乎训练和评估都不错,但是无论何时我尝试预测一个示例,尽管(366,1)
是model.output_shape
,我都会得到一个形状为(None,1)
的小数组。
我知道here存在一个非常相似的问题,但不幸的是,所有提议的解决方案都无法解决我的问题。
根据我所读的内容,这是因为Keras将每个输入作为一个单独的示例进行预测,但是到目前为止,这对我没有帮助。
我尝试将输入作为大小为(366,1)
,(1,366)
,(366,)
的numpy数组以及包含每个变量的列表进行传递,但是没有任何效果((1,366)
抛出错误,其他所有输出的大小为(366,1)
)。
如果有人可以帮助您,将不胜感激。谢谢。
以下是代码:(很抱歉,如果它不是很整齐)
培训:
import example_generator as eg
import numpy as np
from tensorflow.keras.layers import Dense
from tensorflow.keras.models import Sequential
training_data = np.load("../data/training_data.npy",allow_pickle=True)
testing_data = np.load("../data/testing_data.npy",allow_pickle=True)
model = Sequential([
Dense(100,activation="relu"),Dense(30,Dense(1,activation="linear")
])
model.compile(
optimizer = "Adam",loss="mean_squared_error",metrics="mean_absolute_error"
)
model.fit(
eg.yield_training_example(training_data,366),epochs=1,steps_per_epoch = 14590,batch_size=50
)
model.evaluate(
eg.yield_training_example(testing_data,steps = 50
)
model.save("../models/model")
用于测试:
import example_generator as eg
import numpy as np
from tensorflow.keras.models import load_model
testing_data = np.load("../../data/testing_data.npy",allow_pickle=True)
model = load_model("../../data/models/feedforward")
testing_example = next(eg.yield_training_example(testing_data,366))
X = testing_example[0]
prediction = model.predict(
X
)
print(f"Prediction: {prediction}\nAnswer: {testing_example[1]}\n\n")
对于我用来返回示例的生成器:
import numpy as np
# Acts as an iterable of training examples
def yield_training_example(data,num_nn_inputs = 366):
for eg in data[:,1]:
inc = 0
while (inc <= (len(eg) - num_nn_inputs - 1)):
yield (
np.array(eg[inc:(inc + num_nn_inputs),:]).astype(np.float32),np.array(eg[(inc + num_nn_inputs),:]).astype(np.float32)
)
inc += 1
解决方法
好的,我最终解决了这个问题。在制作模型时,我需要指定输入尺寸,例如:
model = Sequential([
Dense(100,activation="relu",input_dim=366),Dense(30,activation="relu"),Dense(1,activation="linear")
])
从那里,我只需要确保对网络的所有输入都以形状为(1,366)
的数组形式传递。
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。