我的Keras卷积模型预测了从不同路径导入的相同图像,但是预测结果不同

如何解决我的Keras卷积模型预测了从不同路径导入的相同图像,但是预测结果不同

我创建了一个CNN模型,用于使用 mnist时尚数据集来预测时尚。训练完模型后,我尝试预测了从Keras加载的测试图像之一和从我的PC导入到我的Google Colab笔记本中相同但又相同的另一图像,事实证明,预测结果并不相同。我该如何解决这个问题?

这是我导入数据集的方式:

import tensorflow as tf
from tensorflow import keras
fashion_mnist = keras.datasets.fashion_mnist
(x_train,y_train),(x_test,y_test) = fashion_mnist.load_data()

数据操作:

from keras.utils import to_categorical
yTest = to_categorical(y_test)
yTrain = to_categorical(y_train)
xTrain = x_train.reshape((60000,28,1))
xTest = x_test.reshape(10000,1)

模型设置:

from keras.layers import Dense,Flatten,Conv2D,Dropout,MaxPool2D,BatchNormalization
from keras.callbacks import ModelCheckpoint

model = keras.Sequential()

#Adding the convolutional layer
model.add(Conv2D(50,kernel_size=3,activation='relu',padding = 'same',input_shape = (28,1)))
model.add(MaxPool2D(pool_size = (2,2),strides = 1,padding = 'valid'))
model.add(Dropout(0.5))
model.add(Conv2D(40,kernel_size = 3,activation = 'relu',padding = 'same'))
model.add(MaxPool2D(pool_size = (2,padding = 'valid'))
model.add(Dropout(0.5))
model.add(Conv2D(30,strides = 2,padding = 'valid'))
model.add(Dropout(0.5))
model.add(Conv2D(10,padding = 'same'))
model.add(Dropout(0.5))

#Connecting the CNN layers to the ANN
model.add(Flatten())
model.add(Dense(60,activation='relu'))
model.add(Dense(40,activation = 'relu'))
model.add(Dense(10,activation = 'softmax'))
model.load_weights('mnist_fashion.h5')

# Compiling the model
opt = tf.keras.optimizers.Adam(learning_rate=0.0001)
model.compile(optimizer=opt,loss = 'categorical_crossentropy',metrics = ['accuracy']

培训模型:

model = keras.Sequential()

#Adding the convolutional layer
model.add(Conv2D(50,activation = 'softmax'))

模型的性能:

            precision    recall  f1-score   support

       0       0.89      0.88      0.88      1000
       1       0.99      0.99      0.99      1000
       2       0.88      0.89      0.89      1000
       3       0.93      0.93      0.93      1000
       4       0.87      0.89      0.88      1000
       5       0.99      0.98      0.99      1000
       6       0.79      0.78      0.78      1000
       7       0.97      0.98      0.97      1000
       8       0.99      0.98      0.99      1000
       9       0.97      0.97      0.97      1000



   accuracy                           0.93     10000
   macro avg       0.93      0.93      0.93     10000
   weighted avg    0.93      0.93      0.93     10000

来自数据集预测的图片

 #From the dataset
    import numpy as np
    image = xTrain[0].reshape(1,1)
    prd = model.predict(image)
    new_prd = np.argmax(prd,axis  = 1)
    print(f"Prediction = {new_prd}")
    print(f"Full Prediction = {prd}")
    print(f"Label = {y_train[0]}")

数据集结果

Prediction = [9]
Full Prediction = [[1.6268513e-07 2.3548612e-08 1.5456487e-07 8.6898848e-07 1.9692785e-09
  4.4544859e-04 6.6932116e-06 1.4004705e-02 4.1784686e-05 9.8550016e-01]]
Label = 9

导入的图片预测

imported_img = plt.imread("mnist fashion sample.png")
yolo = imported_img.reshape(1,1)
super_prd = model.predict(yolo)
prediction = np.argmax(super_prd,axis = 1)
print(f"Prediction = {prediction}")
print(f"Full Prediction = {super_prd}")
print(f"Label = {y_train[0]}")

导入的图片预测结果

Prediction = [8]
Full Prediction = [[2.49403762e-04 1.69450897e-04 4.47237398e-04 3.05729372e-05
  1.10463676e-04 4.34053177e-03 5.16198808e-04 8.16224664e-02
  8.73587310e-01 3.89263593e-02]]
Label = 9

解决方法

我解决了问题!

我做错了,因为我在训练之前没有对图片进行归一化。这可能会导致错误,因为数据像素范围可能太复杂,以至于relu激活函数无法计算或预测。

谢谢!

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐


依赖报错 idea导入项目后依赖报错,解决方案:https://blog.csdn.net/weixin_42420249/article/details/81191861 依赖版本报错:更换其他版本 无法下载依赖可参考:https://blog.csdn.net/weixin_42628809/a
错误1:代码生成器依赖和mybatis依赖冲突 启动项目时报错如下 2021-12-03 13:33:33.927 ERROR 7228 [ main] o.s.b.d.LoggingFailureAnalysisReporter : *************************** APPL
错误1:gradle项目控制台输出为乱码 # 解决方案:https://blog.csdn.net/weixin_43501566/article/details/112482302 # 在gradle-wrapper.properties 添加以下内容 org.gradle.jvmargs=-Df
错误还原:在查询的过程中,传入的workType为0时,该条件不起作用 <select id="xxx"> SELECT di.id, di.name, di.work_type, di.updated... <where> <if test=&qu
报错如下,gcc版本太低 ^ server.c:5346:31: 错误:‘struct redisServer’没有名为‘server_cpulist’的成员 redisSetCpuAffinity(server.server_cpulist); ^ server.c: 在函数‘hasActiveC
解决方案1 1、改项目中.idea/workspace.xml配置文件,增加dynamic.classpath参数 2、搜索PropertiesComponent,添加如下 <property name="dynamic.classpath" value="tru
删除根组件app.vue中的默认代码后报错:Module Error (from ./node_modules/eslint-loader/index.js): 解决方案:关闭ESlint代码检测,在项目根目录创建vue.config.js,在文件中添加 module.exports = { lin
查看spark默认的python版本 [root@master day27]# pyspark /home/software/spark-2.3.4-bin-hadoop2.7/conf/spark-env.sh: line 2: /usr/local/hadoop/bin/hadoop: No s
使用本地python环境可以成功执行 import pandas as pd import matplotlib.pyplot as plt # 设置字体 plt.rcParams['font.sans-serif'] = ['SimHei'] # 能正确显示负号 p
错误1:Request method ‘DELETE‘ not supported 错误还原:controller层有一个接口,访问该接口时报错:Request method ‘DELETE‘ not supported 错误原因:没有接收到前端传入的参数,修改为如下 参考 错误2:cannot r
错误1:启动docker镜像时报错:Error response from daemon: driver failed programming external connectivity on endpoint quirky_allen 解决方法:重启docker -> systemctl r
错误1:private field ‘xxx‘ is never assigned 按Altʾnter快捷键,选择第2项 参考:https://blog.csdn.net/shi_hong_fei_hei/article/details/88814070 错误2:启动时报错,不能找到主启动类 #
报错如下,通过源不能下载,最后警告pip需升级版本 Requirement already satisfied: pip in c:\users\ychen\appdata\local\programs\python\python310\lib\site-packages (22.0.4) Coll
错误1:maven打包报错 错误还原:使用maven打包项目时报错如下 [ERROR] Failed to execute goal org.apache.maven.plugins:maven-resources-plugin:3.2.0:resources (default-resources)
错误1:服务调用时报错 服务消费者模块assess通过openFeign调用服务提供者模块hires 如下为服务提供者模块hires的控制层接口 @RestController @RequestMapping("/hires") public class FeignControl
错误1:运行项目后报如下错误 解决方案 报错2:Failed to execute goal org.apache.maven.plugins:maven-compiler-plugin:3.8.1:compile (default-compile) on project sb 解决方案:在pom.
参考 错误原因 过滤器或拦截器在生效时,redisTemplate还没有注入 解决方案:在注入容器时就生效 @Component //项目运行时就注入Spring容器 public class RedisBean { @Resource private RedisTemplate<String
使用vite构建项目报错 C:\Users\ychen\work>npm init @vitejs/app @vitejs/create-app is deprecated, use npm init vite instead C:\Users\ychen\AppData\Local\npm-