pytorch 更改输入图像大小

如何解决pytorch 更改输入图像大小

我是 pytorch 的新手,我正在学习教程,但是当我尝试修改代码以使用 64x64x3 图像而不是 32x32x3 图像时,我遇到了很多错误。这是教程中的代码:

import torch
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder

transform = transforms.Compose(
    [transforms.ToTensor(),transforms.Resize(32),transforms.RandomCrop(32),transforms.Normalize((0.5,0.5,0.5),(0.5,0.5))])

batch_size = 4

trainset = ImageFolder("Train",transform=transform)
trainloader = DataLoader(trainset,shuffle=True,batch_size=batch_size,num_workers=0)

classes = ('Dog','Cat')

import matplotlib.pyplot as plt
import numpy as np

# functions to show an image


def imshow(img):
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg,(1,2,0)))
    plt.show()


# get some random training images
dataiter = iter(trainloader)
images,labels = dataiter.next()

# show images
imshow(torchvision.utils.make_grid(images))
# print labels
print(' '.join('%5s' % classes[labels[j]] for j in range(batch_size)))

import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3,6,5)
        self.pool = nn.MaxPool2d(2,2)
        self.conv2 = nn.Conv2d(6,16,5)
        self.fc1 = nn.Linear(16 * 5 * 5,120)
        self.fc2 = nn.Linear(120,84)
        self.fc3 = nn.Linear(84,10)

    def forward(self,x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1,16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


net = Net()

import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(),lr=0.001,momentum=0.9)

print("training started")

from tqdm import tqdm

for epoch in range(5):  # loop over the dataset multiple times

    running_loss = 0.0
    for i,data in tqdm(enumerate(trainloader,0),desc=f"epoch: {epoch + 1}"):
        # get the inputs; data is a list of [inputs,labels]
        inputs,labels = data

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs,labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % 2000 == 1999:    # print every 2000 mini-batches
            print('[%d,%5d] loss: %.3f' %
                  (epoch + 1,i + 1,running_loss / 2000))
            running_loss = 0.0

print('Finished Training')

PATH = './net.pth'
torch.save(net.state_dict(),PATH)

如果我将 'transforms.Resize(32)' 和 'transforms.RandomCrop(32)' 更改为 64(以获得 64x64x3 图像),我会收到此错误

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
~\Documents\pyth\classifier\train_classifier.py in <module>
     86 
     87         # forward + backward + optimize
---> 88         outputs = net(inputs)
     89         loss = criterion(outputs,labels)
     90         loss.backward()

~\Anaconda3\lib\site-packages\torch\nn\modules\module.py in _call_impl(self,*input,**kwargs)
    887             result = self._slow_forward(*input,**kwargs)
    888         else:
--> 889             result = self.forward(*input,**kwargs)
    890         for hook in itertools.chain(
    891                 _global_forward_hooks.values(),~\Documents\pyth\classifier\train_classifier.py in forward(self,x)
     57         x = self.pool(F.relu(self.conv1(x)))
     58         x = self.pool(F.relu(self.conv2(x)))
---> 59         x = x.view(-1,10816+1)
     60         x = F.relu(self.fc1(x))
     61         x = F.relu(self.fc2(x))

RuntimeError: shape '[-1,10817]' is invalid for input of size 10816
´´´

and if i try to change the parameters of ´x.view(...)´ i get this error

´´´
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
~\Documents\pyth\classifier\train_classifier.py in <module>
     86 
     87         # forward + backward + optimize
---> 88         outputs = net(inputs)
     89         loss = criterion(outputs,x)
     58         x = self.pool(F.relu(self.conv2(x)))
     59         x = x.view(-1,16 * 2 * 5 * 5)
---> 60         x = F.relu(self.fc1(x))
     61         x = F.relu(self.fc2(x))
     62         x = self.fc3(x)

~\Anaconda3\lib\site-packages\torch\nn\modules\module.py in _call_impl(self,~\Anaconda3\lib\site-packages\torch\nn\modules\linear.py in forward(self,input)
     92 
     93     def forward(self,input: Tensor) -> Tensor:
---> 94         return F.linear(input,self.weight,self.bias)
     95 
     96     def extra_repr(self) -> str:

~\Anaconda3\lib\site-packages\torch\nn\functional.py in linear(input,weight,bias)
   1751     if has_torch_function_variadic(input,weight):
   1752         return handle_torch_function(linear,(input,weight),input,bias=bias)
-> 1753     return torch._C._nn.linear(input,bias)
   1754 
   1755 

RuntimeError: mat1 and mat2 shapes cannot be multiplied (2x800 and 400x120)
´´´

解决方法

我认为这应该可行,因为在执行第二次池化操作后,输出特征图是 N x C x 13 x 13

self.fc1 = nn.Linear(16 * 13 * 13,120)

x = x.view(-1,16 * 13 * 13)

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3,6,5)
        self.pool = nn.MaxPool2d(2,2)
        self.conv2 = nn.Conv2d(6,16,5)
        self.fc1 = nn.Linear(16 * 13 * 13,120)
        self.fc2 = nn.Linear(120,84)
        self.fc3 = nn.Linear(84,10)

    def forward(self,x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1,16 * 13 * 13)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐


依赖报错 idea导入项目后依赖报错,解决方案:https://blog.csdn.net/weixin_42420249/article/details/81191861 依赖版本报错:更换其他版本 无法下载依赖可参考:https://blog.csdn.net/weixin_42628809/a
错误1:代码生成器依赖和mybatis依赖冲突 启动项目时报错如下 2021-12-03 13:33:33.927 ERROR 7228 [ main] o.s.b.d.LoggingFailureAnalysisReporter : *************************** APPL
错误1:gradle项目控制台输出为乱码 # 解决方案:https://blog.csdn.net/weixin_43501566/article/details/112482302 # 在gradle-wrapper.properties 添加以下内容 org.gradle.jvmargs=-Df
错误还原:在查询的过程中,传入的workType为0时,该条件不起作用 &lt;select id=&quot;xxx&quot;&gt; SELECT di.id, di.name, di.work_type, di.updated... &lt;where&gt; &lt;if test=&qu
报错如下,gcc版本太低 ^ server.c:5346:31: 错误:‘struct redisServer’没有名为‘server_cpulist’的成员 redisSetCpuAffinity(server.server_cpulist); ^ server.c: 在函数‘hasActiveC
解决方案1 1、改项目中.idea/workspace.xml配置文件,增加dynamic.classpath参数 2、搜索PropertiesComponent,添加如下 &lt;property name=&quot;dynamic.classpath&quot; value=&quot;tru
删除根组件app.vue中的默认代码后报错:Module Error (from ./node_modules/eslint-loader/index.js): 解决方案:关闭ESlint代码检测,在项目根目录创建vue.config.js,在文件中添加 module.exports = { lin
查看spark默认的python版本 [root@master day27]# pyspark /home/software/spark-2.3.4-bin-hadoop2.7/conf/spark-env.sh: line 2: /usr/local/hadoop/bin/hadoop: No s
使用本地python环境可以成功执行 import pandas as pd import matplotlib.pyplot as plt # 设置字体 plt.rcParams[&#39;font.sans-serif&#39;] = [&#39;SimHei&#39;] # 能正确显示负号 p
错误1:Request method ‘DELETE‘ not supported 错误还原:controller层有一个接口,访问该接口时报错:Request method ‘DELETE‘ not supported 错误原因:没有接收到前端传入的参数,修改为如下 参考 错误2:cannot r
错误1:启动docker镜像时报错:Error response from daemon: driver failed programming external connectivity on endpoint quirky_allen 解决方法:重启docker -&gt; systemctl r
错误1:private field ‘xxx‘ is never assigned 按Altʾnter快捷键,选择第2项 参考:https://blog.csdn.net/shi_hong_fei_hei/article/details/88814070 错误2:启动时报错,不能找到主启动类 #
报错如下,通过源不能下载,最后警告pip需升级版本 Requirement already satisfied: pip in c:\users\ychen\appdata\local\programs\python\python310\lib\site-packages (22.0.4) Coll
错误1:maven打包报错 错误还原:使用maven打包项目时报错如下 [ERROR] Failed to execute goal org.apache.maven.plugins:maven-resources-plugin:3.2.0:resources (default-resources)
错误1:服务调用时报错 服务消费者模块assess通过openFeign调用服务提供者模块hires 如下为服务提供者模块hires的控制层接口 @RestController @RequestMapping(&quot;/hires&quot;) public class FeignControl
错误1:运行项目后报如下错误 解决方案 报错2:Failed to execute goal org.apache.maven.plugins:maven-compiler-plugin:3.8.1:compile (default-compile) on project sb 解决方案:在pom.
参考 错误原因 过滤器或拦截器在生效时,redisTemplate还没有注入 解决方案:在注入容器时就生效 @Component //项目运行时就注入Spring容器 public class RedisBean { @Resource private RedisTemplate&lt;String
使用vite构建项目报错 C:\Users\ychen\work&gt;npm init @vitejs/app @vitejs/create-app is deprecated, use npm init vite instead C:\Users\ychen\AppData\Local\npm-