Problem with Keras model predictions and sklearn confusion matrix results


I have a question about a Keras model, its predictions, and the confusion matrix.

I want to adapt this Keras tutorial to work with multiple classes:

https://www.tensorflow.org/tutorials/structured_data/feature_columns

I read in the data and encode the 6 string target classes as integers:

import pandas as pd
from sklearn.preprocessing import LabelEncoder

dataframe = pd.read_csv("my_csv.csv")
target = 'some_target'
labelencoder = LabelEncoder()
dataframe[target] = labelencoder.fit_transform(dataframe[target])
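
For reference, with hypothetical class names (not the real values in my_csv.csv), LabelEncoder just maps the 6 strings to the integers 0-5 in sorted order:

# Hypothetical illustration of the encoding, not the actual targets
le = LabelEncoder()
print(le.fit_transform(['ant','bee','cat','dog','eel','fox','cat']))
# [0 1 2 3 4 5 2]
print(le.classes_)
# ['ant' 'bee' 'cat' 'dog' 'eel' 'fox']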

Then I split the data, create the columns, create the model, and fit:

from sklearn.model_selection import train_test_split
import tensorflow as tf
from tensorflow import feature_column
from tensorflow.keras import layers

train,test = train_test_split(dataframe,test_size=0.2)
train,val = train_test_split(train,test_size=0.2)


# A utility method to create a tf.data dataset from a Pandas Dataframe
def df_to_dataset(dataframe,target,shuffle=True,batch_size=32):
    dataframe = dataframe.copy()
    labels = dataframe.pop(target)
    ds = tf.data.Dataset.from_tensor_slices((dict(dataframe),labels))
    if shuffle:
        ds = ds.shuffle(buffer_size=len(dataframe))
    ds = ds.batch(batch_size)
    return ds,labels

feature_columns = []
 
f = feature_column.categorical_column_with_vocabulary_list(
    field,unique_categories)
feature_columns.append(feature_column.embedding_column(f,dimension=8))

f2 = feature_column.categorical_column_with_vocabulary_list(
    field,unique_categories)
indicator_column = feature_column.indicator_column(f2)
feature_columns.append(indicator_column)

feature_columns.append(feature_column.numeric_column(field))

feature_layer = tf.keras.layers.DenseFeatures(feature_columns)

batch_size = 32
train_ds,train_labels = df_to_dataset(train,target,batch_size=batch_size)
val_ds,val_labels = df_to_dataset(val,target,shuffle=False,batch_size=batch_size)
# shuffle=False so test_labels stay aligned with the predictions on test_ds
test_ds,test_labels = df_to_dataset(test,target,shuffle=False,batch_size=batch_size)

model = tf.keras.Sequential([
    feature_layer,
    layers.Dense(128,activation='relu'),
    layers.Dropout(.1),
    layers.Dense(1,activation='softmax')
])


# get hps
optimizer = 'adam'
loss_function = tf.keras.losses.BinaryCrossentropy(from_logits=True)
metrics = ['accuracy']
epochs = 1

model.compile(optimizer=optimizer,loss=loss_function,metrics=metrics)

model.fit(train_ds,validation_data=val_ds,epochs=epochs)

loss,accuracy = model.evaluate(test_ds)
print("Accuracy",accuracy)

predicted = model.predict(test_ds)

from sklearn.metrics import confusion_matrix

cf = confusion_matrix(test_labels,predicted)

When I run model.predict the output looks strange:

[1.]
[1.]
[1.]
[1.]
[1.]
[1.]
[1.]
[1.]

The confusion matrix is also not right:

[   0   33    0    0    0    0]
[   0  499    0    0    0    0]
[   0   14    0    0    0    0]
[   0 1089    0    0    0    0]
[   0  360    0    0    0    0]
[   0    4    0    0    0    0]

I tried a different encoding for the target and changed the loss, but to no avail:

# mlb = MultiLabelBinarizer()
# dataframe[target] = mlb.fit_transform(dataframe[target])

loss='categorical_crossentropy'

What am I doing wrong here?

I also tried 6 output neurons:

model = tf.keras.Sequential([
    feature_layer,layers.Dense(6,activation='softmax')
])

But then I get an error:

ValueError: logits and labels must have the same shape ((None,6) vs (None,1))

EDIT:

print(type(train_ds))
# <class 'tensorflow.python.data.ops.dataset_ops.BatchDataset'>

print(train_ds)
# <BatchDataset shapes: ({feature1: (None,), feature2: (None,), feature3: (None,), feature4: (None,) ...

print(type(train_labels))
# <class 'pandas.core.series.Series'>

EDIT: Made some progress. It turns out the loss function and the dimensionality of the target depend on each other: Tensorflow : logits and labels must have the same first dimension

If you have a 1D integer-encoded target, you can use sparse_categorical_crossentropy as the loss function.

So I changed the loss to: sparse_categorical_crossentropy
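
Putting it together, a minimal sketch of the combination that produces the 6-probability outputs below (assuming the feature_layer built above; the 128-unit layer and Dropout are just carried over from the earlier model):

# one output unit per class, softmax probabilities,
# integer labels (0-5) + sparse_categorical_crossentropy
model = tf.keras.Sequential([
    feature_layer,
    layers.Dense(128,activation='relu'),
    layers.Dropout(.1),
    layers.Dense(6,activation='softmax')
])

model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])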

Now when I run model.predict the output looks better:

[0.02313532 0.39231667 0.0117254  0.42083895 0.15037686 0.00160678]
[2.3085043e-02 3.3588389e-01 8.1730038e-03 4.8321337e-01 1.4923279e-01 4.1199493e-04]
[8.1658429e-03 3.3901721e-01 2.3666199e-03 5.3861737e-01 1.1167890e-01 1.5400720e-04]
[8.6198252e-04 1.2048376e-01 1.3487167e-02 4.1729528e-01 4.4759643e-01 2.7547608e-04]
[0.06842247 0.31534496 0.02852604 0.40057638 0.17933881 0.0077913 ]
[0.05149424 0.34782204 0.02664029 0.34621894 0.22060096 0.00722347]

Then I take the index of the highest prediction and pass that to the confusion matrix:

import numpy as np

predictions_index = np.argmax(predicted,axis=1)
cf = confusion_matrix(test_labels,predictions_index)

The confusion matrix also looks better:

[   0    3    0   27    2    0]
[   0   37    0  386   54    0]
[   0    0    0   14    1    0]
[   0   13    0  968  124    0]
[   0    4    0  309   49    0]
[   0    0    0    6    2    0]
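
Since the targets were integer-encoded with the LabelEncoder at the top, the row/column indices of this matrix can be mapped back to the original class names. A small sketch, assuming the fitted labelencoder from above is still in scope:

# class name corresponding to each row/column index of the confusion matrix
print(labelencoder.classes_)

# decode the predicted indices back to the original string labels
predicted_labels = labelencoder.inverse_transform(predictions_index)
print(predicted_labels[:10])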
