如何解决无法将形状3,200,200,3的输入数组广播到形状3
我正在尝试将预训练模型用于我的项目;
from keras.preprocessing import image
from keras.preprocessing.image import ImageDataGenerator
datagen = ImageDataGenerator(
rotation_range=40,width_shift_range=0.2,height_shift_range=0.2,shear_range=0.2,zoom_range=0.2,horizontal_flip=True,fill_mode='nearest')
test_img = train_dataset[1]
img = image.img_to_array(test_img) # convert image to numpy arry
img = img.reshape((1,) + img.shape) # reshape image
i = 0
for batch in datagen.flow(img,save_prefix='test',save_format='jpeg'): # this loops runs forever until we break,saving images to current directory with specified prefix
plt.figure(i)
plot = plt.imshow(image.img_to_array(batch[0]))
i += 1
if i > 4: # show 4 images
break
plt.show()
我在上面运行代码时收到此错误 无法将形状(3,200,3)的输入数组广播到形状(3)
ValueError Traceback (most recent call last)
<ipython-input-12-546c359868af> in <module>
14 # pick an image to transform
15 test_img = train_dataset[1]
---> 16 img = image.img_to_array(test_img) # convert image to numpy arry
17 img = img.reshape((1,) + img.shape) # reshape image
18
~\miniconda3\envs\tensorflownew\lib\site-packages\keras\preprocessing\image.py in img_to_array(img,data_format,dtype)
73 if dtype is None:
74 dtype = backend.floatx()
---> 75 return image.img_to_array(img,data_format=data_format,dtype=dtype)
76
77
~\miniconda3\envs\tensorflownew\lib\site-packages\keras_preprocessing\image\utils.py in img_to_array(img,dtype)
297 # or (channel,height,width)
298 # but original PIL image has format (width,channel)
--> 299 x = np.asarray(img,dtype=dtype)
300 if len(x.shape) == 3:
301 if data_format == 'channels_first':
~\miniconda3\envs\tensorflownew\lib\site-packages\numpy\core\_asarray.py in asarray(a,dtype,order)
81
82 """
---> 83 return array(a,copy=False,order=order)
84
85
ValueError: could not broadcast input array from shape (3,3) into shape (3)
我的目标尺寸如下
train_dataset = train.flow_from_directory('MNIST_data/train',target_size = (200,200),batch_size = 3,class_mode = 'binary')
validation_dataset = train.flow_from_directory('MNIST_data/validation',class_mode = 'binary')
我应该添加一个变压器,还是代码有其他问题?
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。