如何解决为什么在 MNIST 分类器代码中使用 X[0] 会给我一个错误?
我正在学习使用 MNIST 数据集进行分类。我遇到了一个错误,我无法弄清楚,我已经做了很多谷歌搜索,但我什么也做不了,也许你是专家,可以帮助我。这是代码--
>>> from sklearn.datasets import fetch_openml
>>> mnist = fetch_openml('mnist_784',version=1)
>>> mnist.keys()
输出: dict_keys(['data','target','frame','categories','feature_names','target_names','DESCR','details','url'])
>>> X,y = mnist["data"],mnist["target"]
>>> X.shape
输出:(70000,784)
>>> y.shape
输出:(70000)
>>> X[0]
output:KeyError Traceback (most recent call last)
c:\users\khush\appdata\local\programs\python\python39\lib\site-packages\pandas\core\indexes\base.py in get_loc(self,key,method,tolerance)
2897 try:
-> 2898 return self._engine.get_loc(casted_key)
2899 except KeyError as err:
pandas\_libs\index.pyx in pandas._libs.index.IndexEngine.get_loc()
pandas\_libs\index.pyx in pandas._libs.index.IndexEngine.get_loc()
pandas\_libs\hashtable_class_helper.pxi in pandas._libs.hashtable.PyObjectHashTable.get_item()
pandas\_libs\hashtable_class_helper.pxi in pandas._libs.hashtable.PyObjectHashTable.get_item()
KeyError: 0
The above exception was the direct cause of the following exception:
KeyError Traceback (most recent call last)
<ipython-input-10-19c40ecbd036> in <module>
----> 1 X[0]
c:\users\khush\appdata\local\programs\python\python39\lib\site-packages\pandas\core\frame.py in __getitem__(self,key)
2904 if self.columns.nlevels > 1:
2905 return self._getitem_multilevel(key)
-> 2906 indexer = self.columns.get_loc(key)
2907 if is_integer(indexer):
2908 indexer = [indexer]
c:\users\khush\appdata\local\programs\python\python39\lib\site-packages\pandas\core\indexes\base.py in get_loc(self,tolerance)
2898 return self._engine.get_loc(casted_key)
2899 except KeyError as err:
-> 2900 raise KeyError(key) from err
2901
2902 if tolerance is not None:
KeyError: 0
请回答,可能有一个愚蠢的错误,因为我是 ML 的初学者。如果你也给我一些提示,那真的很有帮助。
解决方法
我也遇到了同样的问题。
- scikit-learn:0.24.0
- matplotlib:3.3.3
- Python:3.9.1
我曾经用下面的代码来解决这个问题。
import matplotlib as mpl
import matplotlib.pyplot as plt
# instead of some_digit = X[0]
some_digit = X.to_numpy()[0]
some_digit_image = some_digit.reshape(28,28)
plt.imshow(some_digit_image,cmap="binary")
plt.axis("off")
plt.show()
,
fetch_openml
的 API 在不同版本之间发生了变化。最初,它返回一个 pandas.DataFrame
,这正是您所拥有的。自 0.24.0
(2020 年 12 月)起,as_frame
的 fetch_openml
参数设置为 False,从而为您提供 numpy.ndarray
。您应该将您的 DataFrame 转换为 numpy.ndarray
,请参阅 pandas
的 indexing method 或升级 sklearn
。
如果您遵循以下代码,则无需降级 scikit-learn 库:
from sklearn.datasets import fetch_openml
mnist = fetch_openml('mnist_784',version= 1,as_frame= False)
mnist.keys()
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。