如何解决CSV MNIST 数据集:ValueError: Shapes (None, 10) 和 (None, 28, 10) 不兼容
我想用 Keras 对 MNIST 数据集(csv)进行分类。这是我的代码,但运行后出现了下面的错误。你知道我该如何解决吗?ValueError: Shapes (None, 10) and (None, 28, 10) are incompatible
from keras import models
import numpy as np
from keras import layers
import tensorflow as tf
from tensorflow.keras.models import Sequential
from keras.utils import np_utils
from tensorflow.keras.layers import Dense,Dropout,LSTM,BatchNormalization
from keras.utils import to_categorical,plot_model
# Question's reproduction script, kept as posted: it builds a small dense
# classifier on MNIST, trains it, and plots the accuracy curves.
mnist = tf.keras.datasets.mnist
# Load the train/test splits bundled with Keras.
(x_train,y_train),(x_test,y_test) = mnist.load_data()
# Convert the labels with to_categorical; the question reports
# y_train.shape == (60000, 10) afterwards.
y_train = to_categorical(y_train)
y_test = to_categorical(y_test)
model = Sequential()
# NOTE(review): the first Dense layer receives x_train.shape[1:] directly, with
# no flattening step. The reported error shows the model output as
# (None, 28, 10) against labels of (None, 10) -- presumably Dense is applied
# along the last axis only; the answer further down fixes this with a Flatten
# layer.
model.add(Dense(units=32,activation='sigmoid',input_shape=(x_train.shape[1:])))
model.add(Dense(units=64,activation='sigmoid'))
model.add(Dense(units=10,activation='softmax'))
model.compile(optimizer="sgd",loss='categorical_crossentropy',metrics=['accuracy'])
# Hold out 30% of the training data for validation; per the traceback, this
# fit() call is where the ValueError is raised.
history = model.fit(x_train,y_train,batch_size=32,epochs=100,validation_split=.3)
# Plot training vs. validation accuracy per epoch.
# NOTE(review): `plt` is not imported in the visible code -- presumably
# matplotlib.pyplot; confirm.
plt.plot(history.history['accuracy'])
plt.plot(history.history['val_accuracy'])
plt.title('model accuracy')
plt.ylabel('accuracy')
plt.xlabel('epoch')
plt.legend(['training','validation'],loc='best')
plt.show()
运行上面的代码时出现了下面的错误。我知道这是输入形状造成的,但不知道应该如何定义它。x_train.shape 是 (60000, 28, 28),y_train.shape 是 (60000, 10)。
ValueError Traceback (most recent call last)
<ipython-input-112-7c9220a71c0e> in <module>
1 model.compile(optimizer="sgd",loss='categorical_crossentropy',metrics=['accuracy'])
----> 2 history = model.fit(x_train,y_train,batch_size=32,epochs=100,validation_split=.3)
3
4 plt.plot(history.history['accuracy'])
5 plt.plot(history.history['val_accuracy'])
~\anaconda3\lib\site-packages\tensorflow\python\keras\engine\training.py in _method_wrapper(self,*args,**kwargs)
64 def _method_wrapper(self,**kwargs):
65 if not self._in_multi_worker_mode(): # pylint: disable=protected-access
---> 66 return method(self,**kwargs)
67
68 # Running inside `run_distribute_coordinator` already.
~\anaconda3\lib\site-packages\tensorflow\python\keras\engine\training.py in fit(self,x,y,batch_size,epochs,verbose,callbacks,validation_split,validation_data,shuffle,class_weight,sample_weight,initial_epoch,steps_per_epoch,validation_steps,validation_batch_size,validation_freq,max_queue_size,workers,use_multiprocessing)
846 batch_size=batch_size):
847 callbacks.on_train_batch_begin(step)
--> 848 tmp_logs = train_function(iterator)
849 # Catch OutOfRangeError for Datasets of unknown size.
850 # This blocks until the batch has finished executing.
~\anaconda3\lib\site-packages\tensorflow\python\eager\def_function.py in __call__(self,**kwds)
578 xla_context.Exit()
579 else:
--> 580 result = self._call(*args,**kwds)
581
582 if tracing_count == self._get_tracing_count():
~\anaconda3\lib\site-packages\tensorflow\python\eager\def_function.py in _call(self,**kwds)
625 # This is the first call of __call__,so we have to initialize.
626 initializers = []
--> 627 self._initialize(args,kwds,add_initializers_to=initializers)
628 finally:
629 # At this point we know that the initialization is complete (or less
~\anaconda3\lib\site-packages\tensorflow\python\eager\def_function.py in _initialize(self,args,add_initializers_to)
503 self._graph_deleter = FunctionDeleter(self._lifted_initializer_graph)
504 self._concrete_stateful_fn = (
--> 505 self._stateful_fn._get_concrete_function_internal_garbage_collected( # pylint: disable=protected-access
506 *args,**kwds))
507
~\anaconda3\lib\site-packages\tensorflow\python\eager\function.py in _get_concrete_function_internal_garbage_collected(self,**kwargs)
2444 args,kwargs = None,None
2445 with self._lock:
-> 2446 graph_function,_,_ = self._maybe_define_function(args,kwargs)
2447 return graph_function
2448
~\anaconda3\lib\site-packages\tensorflow\python\eager\function.py in _maybe_define_function(self,kwargs)
2775
2776 self._function_cache.missed.add(call_context_key)
-> 2777 graph_function = self._create_graph_function(args,kwargs)
2778 self._function_cache.primary[cache_key] = graph_function
2779 return graph_function,kwargs
~\anaconda3\lib\site-packages\tensorflow\python\eager\function.py in _create_graph_function(self,kwargs,override_flat_arg_shapes)
2655 arg_names = base_arg_names + missing_arg_names
2656 graph_function = ConcreteFunction(
-> 2657 func_graph_module.func_graph_from_py_func(
2658 self._name,2659 self._python_function,~\anaconda3\lib\site-packages\tensorflow\python\framework\func_graph.py in func_graph_from_py_func(name,python_func,signature,func_graph,autograph,autograph_options,add_control_dependencies,arg_names,op_return_value,collections,capture_by_value,override_flat_arg_shapes)
979 _,original_func = tf_decorator.unwrap(python_func)
980
--> 981 func_outputs = python_func(*func_args,**func_kwargs)
982
983 # invariant: `func_outputs` contains only Tensors,CompositeTensors,~\anaconda3\lib\site-packages\tensorflow\python\eager\def_function.py in wrapped_fn(*args,**kwds)
439 # __wrapped__ allows AutoGraph to swap in a converted function. We give
440 # the function a weak reference to itself to avoid a reference cycle.
--> 441 return weak_wrapped_fn().__wrapped__(*args,**kwds)
442 weak_wrapped_fn = weakref.ref(wrapped_fn)
443
~\anaconda3\lib\site-packages\tensorflow\python\framework\func_graph.py in wrapper(*args,**kwargs)
966 except Exception as e: # pylint:disable=broad-except
967 if hasattr(e,"ag_error_metadata"):
--> 968 raise e.ag_error_metadata.to_exception(e)
969 else:
970 raise
th
raise ValueError("Shapes %s and %s are incompatible" % (self,other))
ValueError: Shapes (None, 10) and (None, 28, 10) are incompatible
解决方法
Dense(全连接)层只作用于输入的最后一个维度,不会把图像这样的二维数据自动当作一个向量来处理,因此应该先把每个样本展平为一维向量,再传给后面的层。否则模型输出会多出一个维度(这里是 (None, 28, 10)),与标签的形状 (None, 10) 不兼容,于是就会出现上面的错误。
像这样向模型添加一个展平层:
# Flatten is not among the layers imported in the question, so import it here.
from tensorflow.keras.layers import Flatten

# Rebuild the model from scratch: these lines replace the question's
# `model = Sequential()` and its three `model.add(Dense(...))` calls.
model = Sequential()
# New first layer: flatten every sample into a 1-D vector so the Dense stack
# yields one 10-way prediction per sample, matching the one-hot labels.
model.add(Flatten(input_shape=x_train.shape[1:]))
model.add(Dense(units=32, activation='sigmoid'))
model.add(Dense(units=64, activation='sigmoid'))
model.add(Dense(units=10, activation='softmax'))
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。